Implement WireGuard protocol as outbound (client) (#1344)

* implement WireGuard protocol for Outbound

* upload license

* fix build for openbsd & dragonfly os

* updated wireguard-go

* fix up

* switch to another wireguard fork

* fix

* switch to upstream

* open connection through internet.Dialer (#1)

* use internet.Dialer

* maybe better code

* fix

* real fix

Co-authored-by: nanoda0523 <nanoda0523@users.noreply.github.com>

* fix bugs & add ability to recover during connection reset on UDP over TCP parent protocols

* improve performance

improve performance

* dns lookup endpoint && remove unused code

* interface address fallback

* better code && add config test case

Co-authored-by: nanoda0523 <nanoda0523@users.noreply.github.com>
This commit is contained in:
nanoda0523 2022-11-22 09:05:54 +08:00 committed by GitHub
parent 691b2b1c73
commit e18b52a5df
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 1326 additions and 1 deletions

254
proxy/wireguard/bind.go Normal file
View file

@ -0,0 +1,254 @@
package wireguard
import (
"context"
"errors"
"io"
"net"
"net/netip"
"strconv"
"sync"
"github.com/sagernet/wireguard-go/conn"
xnet "github.com/xtls/xray-core/common/net"
"github.com/xtls/xray-core/features/dns"
"github.com/xtls/xray-core/transport/internet"
)
type netReadInfo struct {
// status
waiter sync.WaitGroup
// param
buff []byte
// result
bytes int
endpoint conn.Endpoint
err error
}
type netBindClient struct {
workers int
dialer internet.Dialer
dns dns.Client
dnsOption dns.IPOption
readQueue chan *netReadInfo
}
func (n *netBindClient) ParseEndpoint(s string) (conn.Endpoint, error) {
ipStr, port, _, err := splitAddrPort(s)
if err != nil {
return nil, err
}
var addr net.IP
if IsDomainName(ipStr) {
ips, err := n.dns.LookupIP(ipStr, n.dnsOption)
if err != nil {
return nil, err
} else if len(ips) == 0 {
return nil, dns.ErrEmptyResponse
}
addr = ips[0]
} else {
addr = net.ParseIP(ipStr)
}
if addr == nil {
return nil, errors.New("failed to parse ip: " + ipStr)
}
var ip xnet.Address
if p4 := addr.To4(); len(p4) == net.IPv4len {
ip = xnet.IPAddress(p4[:])
} else {
ip = xnet.IPAddress(addr[:])
}
dst := xnet.Destination{
Address: ip,
Port: xnet.Port(port),
Network: xnet.Network_UDP,
}
return &netEndpoint{
dst: dst,
}, nil
}
func (bind *netBindClient) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) {
bind.readQueue = make(chan *netReadInfo)
fun := func(buff []byte) (cap int, ep conn.Endpoint, err error) {
defer func() {
if r := recover(); r != nil {
cap = 0
ep = nil
err = errors.New("channel closed")
}
}()
r := &netReadInfo{
buff: buff,
}
r.waiter.Add(1)
bind.readQueue <- r
r.waiter.Wait() // wait read goroutine done, or we will miss the result
return r.bytes, r.endpoint, r.err
}
workers := bind.workers
if workers <= 0 {
workers = 1
}
arr := make([]conn.ReceiveFunc, workers)
for i := 0; i < workers; i++ {
arr[i] = fun
}
return arr, uint16(uport), nil
}
func (bind *netBindClient) Close() error {
if bind.readQueue != nil {
close(bind.readQueue)
}
return nil
}
func (bind *netBindClient) connectTo(endpoint *netEndpoint) error {
c, err := bind.dialer.Dial(context.Background(), endpoint.dst)
if err != nil {
return err
}
endpoint.conn = c
go func(readQueue <-chan *netReadInfo, endpoint *netEndpoint) {
for {
v, ok := <-readQueue
if !ok {
return
}
i, err := c.Read(v.buff)
v.bytes = i
v.endpoint = endpoint
v.err = err
v.waiter.Done()
if err != nil && errors.Is(err, io.EOF) {
endpoint.conn = nil
return
}
}
}(bind.readQueue, endpoint)
return nil
}
func (bind *netBindClient) Send(buff []byte, endpoint conn.Endpoint) error {
var err error
nend, ok := endpoint.(*netEndpoint)
if !ok {
return conn.ErrWrongEndpointType
}
if nend.conn == nil {
err = bind.connectTo(nend)
if err != nil {
return err
}
}
_, err = nend.conn.Write(buff)
return err
}
func (bind *netBindClient) SetMark(mark uint32) error {
return nil
}
type netEndpoint struct {
dst xnet.Destination
conn net.Conn
}
func (netEndpoint) ClearSrc() {}
func (e netEndpoint) DstIP() netip.Addr {
return toNetIpAddr(e.dst.Address)
}
func (e netEndpoint) SrcIP() netip.Addr {
return netip.Addr{}
}
func (e netEndpoint) DstToBytes() []byte {
var dat []byte
if e.dst.Address.Family().IsIPv4() {
dat = e.dst.Address.IP().To4()[:]
} else {
dat = e.dst.Address.IP().To16()[:]
}
dat = append(dat, byte(e.dst.Port), byte(e.dst.Port>>8))
return dat
}
func (e netEndpoint) DstToString() string {
return e.dst.NetAddr()
}
func (e netEndpoint) SrcToString() string {
return ""
}
func toNetIpAddr(addr xnet.Address) netip.Addr {
if addr.Family().IsIPv4() {
ip := addr.IP()
return netip.AddrFrom4([4]byte{ip[0], ip[1], ip[2], ip[3]})
} else {
ip := addr.IP()
arr := [16]byte{}
for i := 0; i < 16; i++ {
arr[i] = ip[i]
}
return netip.AddrFrom16(arr)
}
}
func stringsLastIndexByte(s string, b byte) int {
for i := len(s) - 1; i >= 0; i-- {
if s[i] == b {
return i
}
}
return -1
}
func splitAddrPort(s string) (ip string, port uint16, v6 bool, err error) {
i := stringsLastIndexByte(s, ':')
if i == -1 {
return "", 0, false, errors.New("not an ip:port")
}
ip = s[:i]
portStr := s[i+1:]
if len(ip) == 0 {
return "", 0, false, errors.New("no IP")
}
if len(portStr) == 0 {
return "", 0, false, errors.New("no port")
}
port64, err := strconv.ParseUint(portStr, 10, 16)
if err != nil {
return "", 0, false, errors.New("invalid port " + strconv.Quote(portStr) + " parsing " + strconv.Quote(s))
}
port = uint16(port64)
if ip[0] == '[' {
if len(ip) < 2 || ip[len(ip)-1] != ']' {
return "", 0, false, errors.New("missing ]")
}
ip = ip[1 : len(ip)-1]
v6 = true
}
return ip, port, v6, nil
}

View file

@ -0,0 +1,294 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.28.1
// protoc v3.21.9
// source: proxy/wireguard/config.proto
package wireguard
import (
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
reflect "reflect"
sync "sync"
)
const (
// Verify that this generated code is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
// Verify that runtime/protoimpl is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
)
type PeerConfig struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
PublicKey string `protobuf:"bytes,1,opt,name=public_key,json=publicKey,proto3" json:"public_key,omitempty"`
PreSharedKey string `protobuf:"bytes,2,opt,name=pre_shared_key,json=preSharedKey,proto3" json:"pre_shared_key,omitempty"`
Endpoint string `protobuf:"bytes,3,opt,name=endpoint,proto3" json:"endpoint,omitempty"`
KeepAlive int32 `protobuf:"varint,4,opt,name=keep_alive,json=keepAlive,proto3" json:"keep_alive,omitempty"`
AllowedIps []string `protobuf:"bytes,5,rep,name=allowed_ips,json=allowedIps,proto3" json:"allowed_ips,omitempty"`
}
func (x *PeerConfig) Reset() {
*x = PeerConfig{}
if protoimpl.UnsafeEnabled {
mi := &file_proxy_wireguard_config_proto_msgTypes[0]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *PeerConfig) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*PeerConfig) ProtoMessage() {}
func (x *PeerConfig) ProtoReflect() protoreflect.Message {
mi := &file_proxy_wireguard_config_proto_msgTypes[0]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use PeerConfig.ProtoReflect.Descriptor instead.
func (*PeerConfig) Descriptor() ([]byte, []int) {
return file_proxy_wireguard_config_proto_rawDescGZIP(), []int{0}
}
func (x *PeerConfig) GetPublicKey() string {
if x != nil {
return x.PublicKey
}
return ""
}
func (x *PeerConfig) GetPreSharedKey() string {
if x != nil {
return x.PreSharedKey
}
return ""
}
func (x *PeerConfig) GetEndpoint() string {
if x != nil {
return x.Endpoint
}
return ""
}
func (x *PeerConfig) GetKeepAlive() int32 {
if x != nil {
return x.KeepAlive
}
return 0
}
func (x *PeerConfig) GetAllowedIps() []string {
if x != nil {
return x.AllowedIps
}
return nil
}
type DeviceConfig struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
SecretKey string `protobuf:"bytes,1,opt,name=secret_key,json=secretKey,proto3" json:"secret_key,omitempty"`
Endpoint []string `protobuf:"bytes,2,rep,name=endpoint,proto3" json:"endpoint,omitempty"`
Peers []*PeerConfig `protobuf:"bytes,3,rep,name=peers,proto3" json:"peers,omitempty"`
Mtu int32 `protobuf:"varint,4,opt,name=mtu,proto3" json:"mtu,omitempty"`
NumWorkers int32 `protobuf:"varint,5,opt,name=num_workers,json=numWorkers,proto3" json:"num_workers,omitempty"`
}
func (x *DeviceConfig) Reset() {
*x = DeviceConfig{}
if protoimpl.UnsafeEnabled {
mi := &file_proxy_wireguard_config_proto_msgTypes[1]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *DeviceConfig) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*DeviceConfig) ProtoMessage() {}
func (x *DeviceConfig) ProtoReflect() protoreflect.Message {
mi := &file_proxy_wireguard_config_proto_msgTypes[1]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use DeviceConfig.ProtoReflect.Descriptor instead.
func (*DeviceConfig) Descriptor() ([]byte, []int) {
return file_proxy_wireguard_config_proto_rawDescGZIP(), []int{1}
}
func (x *DeviceConfig) GetSecretKey() string {
if x != nil {
return x.SecretKey
}
return ""
}
func (x *DeviceConfig) GetEndpoint() []string {
if x != nil {
return x.Endpoint
}
return nil
}
func (x *DeviceConfig) GetPeers() []*PeerConfig {
if x != nil {
return x.Peers
}
return nil
}
func (x *DeviceConfig) GetMtu() int32 {
if x != nil {
return x.Mtu
}
return 0
}
func (x *DeviceConfig) GetNumWorkers() int32 {
if x != nil {
return x.NumWorkers
}
return 0
}
var File_proxy_wireguard_config_proto protoreflect.FileDescriptor
var file_proxy_wireguard_config_proto_rawDesc = []byte{
0x0a, 0x1c, 0x70, 0x72, 0x6f, 0x78, 0x79, 0x2f, 0x77, 0x69, 0x72, 0x65, 0x67, 0x75, 0x61, 0x72,
0x64, 0x2f, 0x63, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x14,
0x78, 0x72, 0x61, 0x79, 0x2e, 0x70, 0x72, 0x6f, 0x78, 0x79, 0x2e, 0x77, 0x69, 0x72, 0x65, 0x67,
0x75, 0x61, 0x72, 0x64, 0x22, 0xad, 0x01, 0x0a, 0x0a, 0x50, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e,
0x66, 0x69, 0x67, 0x12, 0x1d, 0x0a, 0x0a, 0x70, 0x75, 0x62, 0x6c, 0x69, 0x63, 0x5f, 0x6b, 0x65,
0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x70, 0x75, 0x62, 0x6c, 0x69, 0x63, 0x4b,
0x65, 0x79, 0x12, 0x24, 0x0a, 0x0e, 0x70, 0x72, 0x65, 0x5f, 0x73, 0x68, 0x61, 0x72, 0x65, 0x64,
0x5f, 0x6b, 0x65, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x70, 0x72, 0x65, 0x53,
0x68, 0x61, 0x72, 0x65, 0x64, 0x4b, 0x65, 0x79, 0x12, 0x1a, 0x0a, 0x08, 0x65, 0x6e, 0x64, 0x70,
0x6f, 0x69, 0x6e, 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x65, 0x6e, 0x64, 0x70,
0x6f, 0x69, 0x6e, 0x74, 0x12, 0x1d, 0x0a, 0x0a, 0x6b, 0x65, 0x65, 0x70, 0x5f, 0x61, 0x6c, 0x69,
0x76, 0x65, 0x18, 0x04, 0x20, 0x01, 0x28, 0x05, 0x52, 0x09, 0x6b, 0x65, 0x65, 0x70, 0x41, 0x6c,
0x69, 0x76, 0x65, 0x12, 0x1f, 0x0a, 0x0b, 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x65, 0x64, 0x5f, 0x69,
0x70, 0x73, 0x18, 0x05, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0a, 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x65,
0x64, 0x49, 0x70, 0x73, 0x22, 0xb4, 0x01, 0x0a, 0x0c, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x43,
0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x1d, 0x0a, 0x0a, 0x73, 0x65, 0x63, 0x72, 0x65, 0x74, 0x5f,
0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x73, 0x65, 0x63, 0x72, 0x65,
0x74, 0x4b, 0x65, 0x79, 0x12, 0x1a, 0x0a, 0x08, 0x65, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74,
0x18, 0x02, 0x20, 0x03, 0x28, 0x09, 0x52, 0x08, 0x65, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74,
0x12, 0x36, 0x0a, 0x05, 0x70, 0x65, 0x65, 0x72, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32,
0x20, 0x2e, 0x78, 0x72, 0x61, 0x79, 0x2e, 0x70, 0x72, 0x6f, 0x78, 0x79, 0x2e, 0x77, 0x69, 0x72,
0x65, 0x67, 0x75, 0x61, 0x72, 0x64, 0x2e, 0x50, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69,
0x67, 0x52, 0x05, 0x70, 0x65, 0x65, 0x72, 0x73, 0x12, 0x10, 0x0a, 0x03, 0x6d, 0x74, 0x75, 0x18,
0x04, 0x20, 0x01, 0x28, 0x05, 0x52, 0x03, 0x6d, 0x74, 0x75, 0x12, 0x1f, 0x0a, 0x0b, 0x6e, 0x75,
0x6d, 0x5f, 0x77, 0x6f, 0x72, 0x6b, 0x65, 0x72, 0x73, 0x18, 0x05, 0x20, 0x01, 0x28, 0x05, 0x52,
0x0a, 0x6e, 0x75, 0x6d, 0x57, 0x6f, 0x72, 0x6b, 0x65, 0x72, 0x73, 0x42, 0x5e, 0x0a, 0x18, 0x63,
0x6f, 0x6d, 0x2e, 0x78, 0x72, 0x61, 0x79, 0x2e, 0x70, 0x72, 0x6f, 0x78, 0x79, 0x2e, 0x77, 0x69,
0x72, 0x65, 0x67, 0x75, 0x61, 0x72, 0x64, 0x50, 0x01, 0x5a, 0x29, 0x67, 0x69, 0x74, 0x68, 0x75,
0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x78, 0x74, 0x6c, 0x73, 0x2f, 0x78, 0x72, 0x61, 0x79, 0x2d,
0x63, 0x6f, 0x72, 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x78, 0x79, 0x2f, 0x77, 0x69, 0x72, 0x65, 0x67,
0x75, 0x61, 0x72, 0x64, 0xaa, 0x02, 0x14, 0x58, 0x72, 0x61, 0x79, 0x2e, 0x50, 0x72, 0x6f, 0x78,
0x79, 0x2e, 0x57, 0x69, 0x72, 0x65, 0x47, 0x75, 0x61, 0x72, 0x64, 0x62, 0x06, 0x70, 0x72, 0x6f,
0x74, 0x6f, 0x33,
}
var (
file_proxy_wireguard_config_proto_rawDescOnce sync.Once
file_proxy_wireguard_config_proto_rawDescData = file_proxy_wireguard_config_proto_rawDesc
)
func file_proxy_wireguard_config_proto_rawDescGZIP() []byte {
file_proxy_wireguard_config_proto_rawDescOnce.Do(func() {
file_proxy_wireguard_config_proto_rawDescData = protoimpl.X.CompressGZIP(file_proxy_wireguard_config_proto_rawDescData)
})
return file_proxy_wireguard_config_proto_rawDescData
}
var file_proxy_wireguard_config_proto_msgTypes = make([]protoimpl.MessageInfo, 2)
var file_proxy_wireguard_config_proto_goTypes = []interface{}{
(*PeerConfig)(nil), // 0: xray.proxy.wireguard.PeerConfig
(*DeviceConfig)(nil), // 1: xray.proxy.wireguard.DeviceConfig
}
var file_proxy_wireguard_config_proto_depIdxs = []int32{
0, // 0: xray.proxy.wireguard.DeviceConfig.peers:type_name -> xray.proxy.wireguard.PeerConfig
1, // [1:1] is the sub-list for method output_type
1, // [1:1] is the sub-list for method input_type
1, // [1:1] is the sub-list for extension type_name
1, // [1:1] is the sub-list for extension extendee
0, // [0:1] is the sub-list for field type_name
}
func init() { file_proxy_wireguard_config_proto_init() }
func file_proxy_wireguard_config_proto_init() {
if File_proxy_wireguard_config_proto != nil {
return
}
if !protoimpl.UnsafeEnabled {
file_proxy_wireguard_config_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*PeerConfig); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_proxy_wireguard_config_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*DeviceConfig); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
}
type x struct{}
out := protoimpl.TypeBuilder{
File: protoimpl.DescBuilder{
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: file_proxy_wireguard_config_proto_rawDesc,
NumEnums: 0,
NumMessages: 2,
NumExtensions: 0,
NumServices: 0,
},
GoTypes: file_proxy_wireguard_config_proto_goTypes,
DependencyIndexes: file_proxy_wireguard_config_proto_depIdxs,
MessageInfos: file_proxy_wireguard_config_proto_msgTypes,
}.Build()
File_proxy_wireguard_config_proto = out.File
file_proxy_wireguard_config_proto_rawDesc = nil
file_proxy_wireguard_config_proto_goTypes = nil
file_proxy_wireguard_config_proto_depIdxs = nil
}

View file

@ -0,0 +1,23 @@
syntax = "proto3";
package xray.proxy.wireguard;
option csharp_namespace = "Xray.Proxy.WireGuard";
option go_package = "github.com/xtls/xray-core/proxy/wireguard";
option java_package = "com.xray.proxy.wireguard";
option java_multiple_files = true;
message PeerConfig {
string public_key = 1;
string pre_shared_key = 2;
string endpoint = 3;
int32 keep_alive = 4;
repeated string allowed_ips = 5;
}
message DeviceConfig {
string secret_key = 1;
repeated string endpoint = 2;
repeated PeerConfig peers = 3;
int32 mtu = 4;
int32 num_workers = 5;
}

View file

@ -0,0 +1,9 @@
package wireguard
import "github.com/xtls/xray-core/common/errors"
type errPathObjHolder struct{}
func newError(values ...interface{}) *errors.Error {
return errors.New(values...).WithPathObj(errPathObjHolder{})
}

303
proxy/wireguard/tun.go Normal file
View file

@ -0,0 +1,303 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2022 WireGuard LLC. All Rights Reserved.
*/
package wireguard
import (
"context"
"fmt"
"net"
"net/netip"
"os"
"github.com/sagernet/wireguard-go/tun"
"github.com/xtls/xray-core/features/dns"
"gvisor.dev/gvisor/pkg/bufferv2"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
)
type netTun struct {
ep *channel.Endpoint
stack *stack.Stack
events chan tun.Event
incomingPacket chan *bufferv2.View
mtu int
dnsClient dns.Client
hasV4, hasV6 bool
}
type Net netTun
func CreateNetTUN(localAddresses []netip.Addr, dnsClient dns.Client, mtu int) (tun.Device, *Net, error) {
opts := stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol},
HandleLocal: true,
}
dev := &netTun{
ep: channel.New(1024, uint32(mtu), ""),
stack: stack.New(opts),
events: make(chan tun.Event, 10),
incomingPacket: make(chan *bufferv2.View),
dnsClient: dnsClient,
mtu: mtu,
}
dev.ep.AddNotify(dev)
tcpipErr := dev.stack.CreateNIC(1, dev.ep)
if tcpipErr != nil {
return nil, nil, fmt.Errorf("CreateNIC: %v", tcpipErr)
}
for _, ip := range localAddresses {
var protoNumber tcpip.NetworkProtocolNumber
if ip.Is4() {
protoNumber = ipv4.ProtocolNumber
} else if ip.Is6() {
protoNumber = ipv6.ProtocolNumber
}
protoAddr := tcpip.ProtocolAddress{
Protocol: protoNumber,
AddressWithPrefix: tcpip.Address(ip.AsSlice()).WithPrefix(),
}
tcpipErr := dev.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{})
if tcpipErr != nil {
return nil, nil, fmt.Errorf("AddProtocolAddress(%v): %v", ip, tcpipErr)
}
if ip.Is4() {
dev.hasV4 = true
} else if ip.Is6() {
dev.hasV6 = true
}
}
if dev.hasV4 {
dev.stack.AddRoute(tcpip.Route{Destination: header.IPv4EmptySubnet, NIC: 1})
}
if dev.hasV6 {
dev.stack.AddRoute(tcpip.Route{Destination: header.IPv6EmptySubnet, NIC: 1})
}
dev.events <- tun.EventUp
return dev, (*Net)(dev), nil
}
func (tun *netTun) Name() (string, error) {
return "go", nil
}
func (tun *netTun) File() *os.File {
return nil
}
func (tun *netTun) Events() chan tun.Event {
return tun.events
}
func (tun *netTun) Read(buf []byte, offset int) (int, error) {
view, ok := <-tun.incomingPacket
if !ok {
return 0, os.ErrClosed
}
return view.Read(buf[offset:])
}
func (tun *netTun) Write(buf []byte, offset int) (int, error) {
packet := buf[offset:]
if len(packet) == 0 {
return 0, nil
}
pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{Payload: bufferv2.MakeWithData(packet)})
switch packet[0] >> 4 {
case 4:
tun.ep.InjectInbound(header.IPv4ProtocolNumber, pkb)
case 6:
tun.ep.InjectInbound(header.IPv6ProtocolNumber, pkb)
}
return len(buf), nil
}
func (tun *netTun) WriteNotify() {
pkt := tun.ep.Read()
if pkt == nil {
return
}
view := pkt.ToView()
pkt.DecRef()
tun.incomingPacket <- view
}
func (tun *netTun) Flush() error {
return nil
}
func (tun *netTun) Close() error {
tun.stack.RemoveNIC(1)
if tun.events != nil {
close(tun.events)
}
tun.ep.Close()
if tun.incomingPacket != nil {
close(tun.incomingPacket)
}
return nil
}
func (tun *netTun) MTU() (int, error) {
return tun.mtu, nil
}
func convertToFullAddr(endpoint netip.AddrPort) (tcpip.FullAddress, tcpip.NetworkProtocolNumber) {
var protoNumber tcpip.NetworkProtocolNumber
if endpoint.Addr().Is4() {
protoNumber = ipv4.ProtocolNumber
} else {
protoNumber = ipv6.ProtocolNumber
}
return tcpip.FullAddress{
NIC: 1,
Addr: tcpip.Address(endpoint.Addr().AsSlice()),
Port: endpoint.Port(),
}, protoNumber
}
func (net *Net) DialContextTCPAddrPort(ctx context.Context, addr netip.AddrPort) (*gonet.TCPConn, error) {
fa, pn := convertToFullAddr(addr)
return gonet.DialContextTCP(ctx, net.stack, fa, pn)
}
func (net *Net) DialContextTCP(ctx context.Context, addr *net.TCPAddr) (*gonet.TCPConn, error) {
if addr == nil {
return net.DialContextTCPAddrPort(ctx, netip.AddrPort{})
}
ip, _ := netip.AddrFromSlice(addr.IP)
return net.DialContextTCPAddrPort(ctx, netip.AddrPortFrom(ip, uint16(addr.Port)))
}
func (net *Net) DialTCPAddrPort(addr netip.AddrPort) (*gonet.TCPConn, error) {
fa, pn := convertToFullAddr(addr)
return gonet.DialTCP(net.stack, fa, pn)
}
func (net *Net) DialTCP(addr *net.TCPAddr) (*gonet.TCPConn, error) {
if addr == nil {
return net.DialTCPAddrPort(netip.AddrPort{})
}
ip, _ := netip.AddrFromSlice(addr.IP)
return net.DialTCPAddrPort(netip.AddrPortFrom(ip, uint16(addr.Port)))
}
func (net *Net) ListenTCPAddrPort(addr netip.AddrPort) (*gonet.TCPListener, error) {
fa, pn := convertToFullAddr(addr)
return gonet.ListenTCP(net.stack, fa, pn)
}
func (net *Net) ListenTCP(addr *net.TCPAddr) (*gonet.TCPListener, error) {
if addr == nil {
return net.ListenTCPAddrPort(netip.AddrPort{})
}
ip, _ := netip.AddrFromSlice(addr.IP)
return net.ListenTCPAddrPort(netip.AddrPortFrom(ip, uint16(addr.Port)))
}
func (net *Net) DialUDPAddrPort(laddr, raddr netip.AddrPort) (*gonet.UDPConn, error) {
var lfa, rfa *tcpip.FullAddress
var pn tcpip.NetworkProtocolNumber
if laddr.IsValid() || laddr.Port() > 0 {
var addr tcpip.FullAddress
addr, pn = convertToFullAddr(laddr)
lfa = &addr
}
if raddr.IsValid() || raddr.Port() > 0 {
var addr tcpip.FullAddress
addr, pn = convertToFullAddr(raddr)
rfa = &addr
}
return gonet.DialUDP(net.stack, lfa, rfa, pn)
}
func (net *Net) ListenUDPAddrPort(laddr netip.AddrPort) (*gonet.UDPConn, error) {
return net.DialUDPAddrPort(laddr, netip.AddrPort{})
}
func (net *Net) DialUDP(laddr, raddr *net.UDPAddr) (*gonet.UDPConn, error) {
var la, ra netip.AddrPort
if laddr != nil {
ip, _ := netip.AddrFromSlice(laddr.IP)
la = netip.AddrPortFrom(ip, uint16(laddr.Port))
}
if raddr != nil {
ip, _ := netip.AddrFromSlice(raddr.IP)
ra = netip.AddrPortFrom(ip, uint16(raddr.Port))
}
return net.DialUDPAddrPort(la, ra)
}
func (net *Net) ListenUDP(laddr *net.UDPAddr) (*gonet.UDPConn, error) {
return net.DialUDP(laddr, nil)
}
func (n *Net) HasV4() bool {
return n.hasV4
}
func (n *Net) HasV6() bool {
return n.hasV6
}
func IsDomainName(s string) bool {
l := len(s)
if l == 0 || l > 254 || l == 254 && s[l-1] != '.' {
return false
}
last := byte('.')
nonNumeric := false
partlen := 0
for i := 0; i < len(s); i++ {
c := s[i]
switch {
default:
return false
case 'a' <= c && c <= 'z' || 'A' <= c && c <= 'Z' || c == '_':
nonNumeric = true
partlen++
case '0' <= c && c <= '9':
partlen++
case c == '-':
if last == '.' {
return false
}
partlen++
nonNumeric = true
case c == '.':
if last == '.' || last == '-' {
return false
}
if partlen > 63 || partlen == 0 {
return false
}
partlen = 0
}
last = c
}
if last == '-' || partlen > 63 {
return false
}
return nonNumeric
}

View file

@ -0,0 +1,263 @@
/*
Some of codes are copied from https://github.com/octeep/wireproxy, license below.
Copyright (c) 2022 Wind T.F. Wong <octeep@pm.me>
Permission to use, copy, modify, and distribute this software for any
purpose with or without fee is hereby granted, provided that the above
copyright notice and this permission notice appear in all copies.
THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
*/
package wireguard
import (
"bytes"
"context"
"fmt"
"net/netip"
"strings"
"github.com/sagernet/wireguard-go/device"
"github.com/xtls/xray-core/common"
"github.com/xtls/xray-core/common/buf"
"github.com/xtls/xray-core/common/log"
"github.com/xtls/xray-core/common/net"
"github.com/xtls/xray-core/common/protocol"
"github.com/xtls/xray-core/common/session"
"github.com/xtls/xray-core/common/signal"
"github.com/xtls/xray-core/common/task"
"github.com/xtls/xray-core/core"
"github.com/xtls/xray-core/features/dns"
"github.com/xtls/xray-core/features/policy"
"github.com/xtls/xray-core/transport"
"github.com/xtls/xray-core/transport/internet"
)
// Handler is an outbound connection that silently swallow the entire payload.
type Handler struct {
conf *DeviceConfig
net *Net
bind *netBindClient
policyManager policy.Manager
dns dns.Client
// cached configuration
ipc string
endpoints []netip.Addr
}
// New creates a new wireguard handler.
func New(ctx context.Context, conf *DeviceConfig) (*Handler, error) {
v := core.MustFromContext(ctx)
endpoints, err := parseEndpoints(conf)
if err != nil {
return nil, err
}
return &Handler{
conf: conf,
policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager),
dns: v.GetFeature(dns.ClientType()).(dns.Client),
ipc: createIPCRequest(conf),
endpoints: endpoints,
}, nil
}
// Process implements OutboundHandler.Dispatch().
func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer internet.Dialer) error {
if h.bind == nil || h.bind.dialer != dialer || h.net == nil {
log.Record(&log.GeneralMessage{
Severity: log.Severity_Info,
Content: "switching dialer",
})
// bind := conn.NewStdNetBind() // TODO: conn.Bind wrapper for dialer
bind := &netBindClient{
dialer: dialer,
workers: int(h.conf.NumWorkers),
dns: h.dns,
}
net, err := h.makeVirtualTun(bind)
if err != nil {
bind.Close()
return newError("failed to create virtual tun interface").Base(err)
}
h.net = net
if h.bind != nil {
h.bind.Close()
}
h.bind = bind
}
outbound := session.OutboundFromContext(ctx)
if outbound == nil || !outbound.Target.IsValid() {
return newError("target not specified")
}
// Destination of the inner request.
destination := outbound.Target
command := protocol.RequestCommandTCP
if destination.Network == net.Network_UDP {
command = protocol.RequestCommandUDP
}
// resolve dns
addr := destination.Address
if addr.Family().IsDomain() {
ips, err := h.dns.LookupIP(addr.Domain(), dns.IPOption{
IPv4Enable: h.net.HasV4(),
IPv6Enable: h.net.HasV6(),
})
if err != nil {
return newError("failed to lookup DNS").Base(err)
} else if len(ips) == 0 {
return dns.ErrEmptyResponse
}
addr = net.IPAddress(ips[0])
}
p := h.policyManager.ForLevel(0)
ctx, cancel := context.WithCancel(ctx)
timer := signal.CancelAfterInactivity(ctx, cancel, p.Timeouts.ConnectionIdle)
addrPort := netip.AddrPortFrom(toNetIpAddr(addr), destination.Port.Value())
var requestFunc func() error
var responseFunc func() error
if command == protocol.RequestCommandTCP {
conn, err := h.net.DialContextTCPAddrPort(ctx, addrPort)
if err != nil {
return newError("failed to create TCP connection").Base(err)
}
requestFunc = func() error {
defer timer.SetTimeout(p.Timeouts.DownlinkOnly)
return buf.Copy(link.Reader, buf.NewWriter(conn), buf.UpdateActivity(timer))
}
responseFunc = func() error {
defer timer.SetTimeout(p.Timeouts.UplinkOnly)
return buf.Copy(buf.NewReader(conn), link.Writer, buf.UpdateActivity(timer))
}
} else if command == protocol.RequestCommandUDP {
conn, err := h.net.DialUDPAddrPort(netip.AddrPort{}, addrPort)
if err != nil {
return newError("failed to create UDP connection").Base(err)
}
requestFunc = func() error {
defer timer.SetTimeout(p.Timeouts.DownlinkOnly)
return buf.Copy(link.Reader, buf.NewWriter(conn), buf.UpdateActivity(timer))
}
responseFunc = func() error {
defer timer.SetTimeout(p.Timeouts.UplinkOnly)
return buf.Copy(buf.NewReader(conn), link.Writer, buf.UpdateActivity(timer))
}
}
responseDonePost := task.OnSuccess(responseFunc, task.Close(link.Writer))
if err := task.Run(ctx, requestFunc, responseDonePost); err != nil {
return newError("connection ends").Base(err)
}
return nil
}
// serialize the config into an IPC request
func createIPCRequest(conf *DeviceConfig) string {
var request bytes.Buffer
request.WriteString(fmt.Sprintf("private_key=%s\n", conf.SecretKey))
for _, peer := range conf.Peers {
request.WriteString(fmt.Sprintf("public_key=%s\nendpoint=%s\npersistent_keepalive_interval=%d\npreshared_key=%s\n",
peer.PublicKey, peer.Endpoint, peer.KeepAlive, peer.PreSharedKey))
for _, ip := range peer.AllowedIps {
request.WriteString(fmt.Sprintf("allowed_ip=%s\n", ip))
}
}
return request.String()[:request.Len()]
}
// convert endpoint string to netip.Addr
func parseEndpoints(conf *DeviceConfig) ([]netip.Addr, error) {
endpoints := make([]netip.Addr, len(conf.Endpoint))
for i, str := range conf.Endpoint {
var addr netip.Addr
if strings.Contains(str, "/") {
prefix, err := netip.ParsePrefix(str)
if err != nil {
return nil, err
}
addr = prefix.Addr()
if prefix.Bits() != addr.BitLen() {
return nil, newError("interface address subnet should be /32 for IPv4 and /128 for IPv6")
}
} else {
var err error
addr, err = netip.ParseAddr(str)
if err != nil {
return nil, err
}
}
endpoints[i] = addr
}
return endpoints, nil
}
// creates a tun interface on netstack given a configuration
func (h *Handler) makeVirtualTun(bind *netBindClient) (*Net, error) {
tun, tnet, err := CreateNetTUN(h.endpoints, h.dns, int(h.conf.Mtu))
if err != nil {
return nil, err
}
bind.dnsOption.IPv4Enable = tnet.HasV4()
bind.dnsOption.IPv6Enable = tnet.HasV6()
// dev := device.NewDevice(tun, conn.NewDefaultBind(), nil /* device.NewLogger(device.LogLevelVerbose, "") */)
dev := device.NewDevice(tun, bind, &device.Logger{
Verbosef: func(format string, args ...any) {
log.Record(&log.GeneralMessage{
Severity: log.Severity_Debug,
Content: fmt.Sprintf(format, args...),
})
},
Errorf: func(format string, args ...any) {
log.Record(&log.GeneralMessage{
Severity: log.Severity_Error,
Content: fmt.Sprintf(format, args...),
})
},
}, int(h.conf.NumWorkers))
err = dev.IpcSet(h.ipc)
if err != nil {
return nil, err
}
err = dev.Up()
if err != nil {
return nil, err
}
return tnet, nil
}
func init() {
common.Must(common.RegisterConfig((*DeviceConfig)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
return New(ctx, config.(*DeviceConfig))
}))
}