WireGuard Inbound (User-space WireGuard server) (#2477)

* feat: wireguard inbound

* feat(command): generate wireguard compatible keypair

* feat(wireguard): connection idle timeout

* fix(wireguard): close endpoint after connection closed

* fix(wireguard): resolve conflicts

* feat(wireguard): set cubic as default cc algorithm in gVisor TUN

* chore(wireguard): resolve conflict

* chore(wireguard): remove redurant code

* chore(wireguard): remove redurant code

* feat: rework server for gvisor tun

* feat: keep user-space tun as an option

* fix: exclude android from native tun build

* feat: auto kernel tun

* fix: build

* fix: regulate function name & fix test
This commit is contained in:
hax0r31337 2023-11-18 11:27:17 +08:00 committed by GitHub
parent f1c81557dc
commit 0ac7da2fc8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
17 changed files with 1049 additions and 500 deletions

4
go.mod
View File

@ -27,6 +27,7 @@ require (
golang.zx2c4.com/wireguard v0.0.0-20231022001213-2e0774f246fb golang.zx2c4.com/wireguard v0.0.0-20231022001213-2e0774f246fb
google.golang.org/grpc v1.59.0 google.golang.org/grpc v1.59.0
google.golang.org/protobuf v1.31.0 google.golang.org/protobuf v1.31.0
gvisor.dev/gvisor v0.0.0-20231104011432-48a6d7d5bd0b
h12.io/socks v1.0.3 h12.io/socks v1.0.3
lukechampine.com/blake3 v1.2.1 lukechampine.com/blake3 v1.2.1
) )
@ -48,7 +49,7 @@ require (
github.com/pmezard/go-difflib v1.0.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/quic-go/qtls-go1-20 v0.4.1 // indirect github.com/quic-go/qtls-go1-20 v0.4.1 // indirect
github.com/riobard/go-bloom v0.0.0-20200614022211-cdc8013cb5b3 // indirect github.com/riobard/go-bloom v0.0.0-20200614022211-cdc8013cb5b3 // indirect
github.com/vishvananda/netns v0.0.4 // indirect github.com/vishvananda/netns v0.0.0-20200728191858-db3c7e526aae // indirect
go.uber.org/mock v0.3.0 // indirect go.uber.org/mock v0.3.0 // indirect
golang.org/x/exp v0.0.0-20231110203233-9a3e6036ecaa // indirect golang.org/x/exp v0.0.0-20231110203233-9a3e6036ecaa // indirect
golang.org/x/mod v0.14.0 // indirect golang.org/x/mod v0.14.0 // indirect
@ -59,5 +60,4 @@ require (
google.golang.org/genproto/googleapis/rpc v0.0.0-20231106174013-bbf56f31fb17 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20231106174013-bbf56f31fb17 // indirect
gopkg.in/yaml.v2 v2.4.0 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect
gvisor.dev/gvisor v0.0.0-20231104011432-48a6d7d5bd0b // indirect
) )

3
go.sum
View File

@ -168,9 +168,8 @@ github.com/viant/assertly v0.4.8/go.mod h1:aGifi++jvCrUaklKEKT0BU95igDNaqkvz+49u
github.com/viant/toolbox v0.24.0/go.mod h1:OxMCG57V0PXuIP2HNQrtJf2CjqdmbrOx5EkMILuUhzM= github.com/viant/toolbox v0.24.0/go.mod h1:OxMCG57V0PXuIP2HNQrtJf2CjqdmbrOx5EkMILuUhzM=
github.com/vishvananda/netlink v1.2.1-beta.2.0.20230316163032-ced5aaba43e3 h1:tkMT5pTye+1NlKIXETU78NXw0fyjnaNHmJyyLyzw8+U= github.com/vishvananda/netlink v1.2.1-beta.2.0.20230316163032-ced5aaba43e3 h1:tkMT5pTye+1NlKIXETU78NXw0fyjnaNHmJyyLyzw8+U=
github.com/vishvananda/netlink v1.2.1-beta.2.0.20230316163032-ced5aaba43e3/go.mod h1:cAAsePK2e15YDAMJNyOpGYEWNe4sIghTY7gpz4cX/Ik= github.com/vishvananda/netlink v1.2.1-beta.2.0.20230316163032-ced5aaba43e3/go.mod h1:cAAsePK2e15YDAMJNyOpGYEWNe4sIghTY7gpz4cX/Ik=
github.com/vishvananda/netns v0.0.0-20200728191858-db3c7e526aae h1:4hwBBUfQCFe3Cym0ZtKyq7L16eZUtYKs+BaHDN6mAns=
github.com/vishvananda/netns v0.0.0-20200728191858-db3c7e526aae/go.mod h1:DD4vA1DwXk04H54A1oHXtwZmA0grkVMdPxx/VGLCah0= github.com/vishvananda/netns v0.0.0-20200728191858-db3c7e526aae/go.mod h1:DD4vA1DwXk04H54A1oHXtwZmA0grkVMdPxx/VGLCah0=
github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8=
github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
github.com/xtls/reality v0.0.0-20231112171332-de1173cf2b19 h1:capMfFYRgH9BCLd6A3Er/cH3A9Nz3CU2KwxwOQZIePI= github.com/xtls/reality v0.0.0-20231112171332-de1173cf2b19 h1:capMfFYRgH9BCLd6A3Er/cH3A9Nz3CU2KwxwOQZIePI=
github.com/xtls/reality v0.0.0-20231112171332-de1173cf2b19/go.mod h1:dm4y/1QwzjGaK17ofi0Vs6NpKAHegZky8qk6J2JJZAE= github.com/xtls/reality v0.0.0-20231112171332-de1173cf2b19/go.mod h1:dm4y/1QwzjGaK17ofi0Vs6NpKAHegZky8qk6J2JJZAE=
github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=

View File

@ -13,7 +13,7 @@ type WireGuardPeerConfig struct {
PublicKey string `json:"publicKey"` PublicKey string `json:"publicKey"`
PreSharedKey string `json:"preSharedKey"` PreSharedKey string `json:"preSharedKey"`
Endpoint string `json:"endpoint"` Endpoint string `json:"endpoint"`
KeepAlive int `json:"keepAlive"` KeepAlive uint32 `json:"keepAlive"`
AllowedIPs []string `json:"allowedIPs,omitempty"` AllowedIPs []string `json:"allowedIPs,omitempty"`
} }
@ -21,23 +21,23 @@ func (c *WireGuardPeerConfig) Build() (proto.Message, error) {
var err error var err error
config := new(wireguard.PeerConfig) config := new(wireguard.PeerConfig)
if c.PublicKey != "" {
config.PublicKey, err = parseWireGuardKey(c.PublicKey) config.PublicKey, err = parseWireGuardKey(c.PublicKey)
if err != nil { if err != nil {
return nil, err return nil, err
} }
}
if c.PreSharedKey != "" { if c.PreSharedKey != "" {
config.PreSharedKey, err = parseWireGuardKey(c.PreSharedKey) config.PreSharedKey, err = parseWireGuardKey(c.PreSharedKey)
if err != nil { if err != nil {
return nil, err return nil, err
} }
} else {
config.PreSharedKey = "0000000000000000000000000000000000000000000000000000000000000000"
} }
config.Endpoint = c.Endpoint config.Endpoint = c.Endpoint
// default 0 // default 0
config.KeepAlive = int32(c.KeepAlive) config.KeepAlive = c.KeepAlive
if c.AllowedIPs == nil { if c.AllowedIPs == nil {
config.AllowedIps = []string{"0.0.0.0/0", "::0/0"} config.AllowedIps = []string{"0.0.0.0/0", "::0/0"}
} else { } else {
@ -48,11 +48,14 @@ func (c *WireGuardPeerConfig) Build() (proto.Message, error) {
} }
type WireGuardConfig struct { type WireGuardConfig struct {
IsClient bool `json:""`
KernelMode *bool `json:"kernelMode"`
SecretKey string `json:"secretKey"` SecretKey string `json:"secretKey"`
Address []string `json:"address"` Address []string `json:"address"`
Peers []*WireGuardPeerConfig `json:"peers"` Peers []*WireGuardPeerConfig `json:"peers"`
MTU int `json:"mtu"` MTU int32 `json:"mtu"`
NumWorkers int `json:"workers"` NumWorkers int32 `json:"workers"`
Reserved []byte `json:"reserved"` Reserved []byte `json:"reserved"`
DomainStrategy string `json:"domainStrategy"` DomainStrategy string `json:"domainStrategy"`
} }
@ -87,11 +90,11 @@ func (c *WireGuardConfig) Build() (proto.Message, error) {
if c.MTU == 0 { if c.MTU == 0 {
config.Mtu = 1420 config.Mtu = 1420
} else { } else {
config.Mtu = int32(c.MTU) config.Mtu = c.MTU
} }
// these a fallback code exists in github.com/nanoda0523/wireguard-go code, // these a fallback code exists in wireguard-go code,
// we don't need to process fallback manually // we don't need to process fallback manually
config.NumWorkers = int32(c.NumWorkers) config.NumWorkers = c.NumWorkers
if len(c.Reserved) != 0 && len(c.Reserved) != 3 { if len(c.Reserved) != 0 && len(c.Reserved) != 3 {
return nil, newError(`"reserved" should be empty or 3 bytes`) return nil, newError(`"reserved" should be empty or 3 bytes`)
@ -113,22 +116,42 @@ func (c *WireGuardConfig) Build() (proto.Message, error) {
return nil, newError("unsupported domain strategy: ", c.DomainStrategy) return nil, newError("unsupported domain strategy: ", c.DomainStrategy)
} }
config.IsClient = c.IsClient
if c.KernelMode != nil {
config.KernelMode = *c.KernelMode
if config.KernelMode && !wireguard.KernelTunSupported() {
newError("kernel mode is not supported on your OS or permission is insufficient").AtWarning().WriteToLog()
}
} else {
config.KernelMode = wireguard.KernelTunSupported()
if config.KernelMode {
newError("kernel mode is enabled as it's supported and permission is sufficient").AtDebug().WriteToLog()
}
}
return config, nil return config, nil
} }
func parseWireGuardKey(str string) (string, error) { func parseWireGuardKey(str string) (string, error) {
if len(str) != 64 { var err error
// may in base64 form
dat, err := base64.StdEncoding.DecodeString(str) if len(str)%2 == 0 {
if err != nil { _, err = hex.DecodeString(str)
return "", err if err == nil {
}
if len(dat) != 32 {
return "", newError("key should be 32 bytes: " + str)
}
return hex.EncodeToString(dat), err
} else {
// already hex form
return str, nil return str, nil
} }
}
var dat []byte
str = strings.TrimSuffix(str, "=")
if strings.ContainsRune(str, '+') || strings.ContainsRune(str, '/') {
dat, err = base64.RawStdEncoding.DecodeString(str)
} else {
dat, err = base64.RawURLEncoding.DecodeString(str)
}
if err == nil {
return hex.EncodeToString(dat), nil
}
return "", newError("failed to deserialize key").Base(err)
} }

View File

@ -7,7 +7,7 @@ import (
"github.com/xtls/xray-core/proxy/wireguard" "github.com/xtls/xray-core/proxy/wireguard"
) )
func TestWireGuardOutbound(t *testing.T) { func TestWireGuardConfig(t *testing.T) {
creator := func() Buildable { creator := func() Buildable {
return new(WireGuardConfig) return new(WireGuardConfig)
} }
@ -25,7 +25,8 @@ func TestWireGuardOutbound(t *testing.T) {
], ],
"mtu": 1300, "mtu": 1300,
"workers": 2, "workers": 2,
"domainStrategy": "ForceIPv6v4" "domainStrategy": "ForceIPv6v4",
"kernelMode": false
}`, }`,
Parser: loadJSON(creator), Parser: loadJSON(creator),
Output: &wireguard.DeviceConfig{ Output: &wireguard.DeviceConfig{
@ -36,7 +37,6 @@ func TestWireGuardOutbound(t *testing.T) {
{ {
// also can read from hex form directly // also can read from hex form directly
PublicKey: "6e65ce0be17517110c17d77288ad87e7fd5252dcc7d09b95a39d61db03df832a", PublicKey: "6e65ce0be17517110c17d77288ad87e7fd5252dcc7d09b95a39d61db03df832a",
PreSharedKey: "0000000000000000000000000000000000000000000000000000000000000000",
Endpoint: "127.0.0.1:1234", Endpoint: "127.0.0.1:1234",
KeepAlive: 0, KeepAlive: 0,
AllowedIps: []string{"0.0.0.0/0", "::0/0"}, AllowedIps: []string{"0.0.0.0/0", "::0/0"},
@ -45,6 +45,7 @@ func TestWireGuardOutbound(t *testing.T) {
Mtu: 1300, Mtu: 1300,
NumWorkers: 2, NumWorkers: 2,
DomainStrategy: wireguard.DeviceConfig_FORCE_IP64, DomainStrategy: wireguard.DeviceConfig_FORCE_IP64,
KernelMode: false,
}, },
}, },
}) })

View File

@ -24,6 +24,7 @@ var (
"vless": func() interface{} { return new(VLessInboundConfig) }, "vless": func() interface{} { return new(VLessInboundConfig) },
"vmess": func() interface{} { return new(VMessInboundConfig) }, "vmess": func() interface{} { return new(VMessInboundConfig) },
"trojan": func() interface{} { return new(TrojanServerConfig) }, "trojan": func() interface{} { return new(TrojanServerConfig) },
"wireguard": func() interface{} { return &WireGuardConfig{IsClient: false} },
}, "protocol", "settings") }, "protocol", "settings")
outboundConfigLoader = NewJSONConfigLoader(ConfigCreatorCache{ outboundConfigLoader = NewJSONConfigLoader(ConfigCreatorCache{
@ -37,7 +38,7 @@ var (
"vmess": func() interface{} { return new(VMessOutboundConfig) }, "vmess": func() interface{} { return new(VMessOutboundConfig) },
"trojan": func() interface{} { return new(TrojanClientConfig) }, "trojan": func() interface{} { return new(TrojanClientConfig) },
"dns": func() interface{} { return new(DNSOutboundConfig) }, "dns": func() interface{} { return new(DNSOutboundConfig) },
"wireguard": func() interface{} { return new(WireGuardConfig) }, "wireguard": func() interface{} { return &WireGuardConfig{IsClient: true} },
}, "protocol", "settings") }, "protocol", "settings")
ctllog = log.New(os.Stderr, "xctl> ", 0) ctllog = log.New(os.Stderr, "xctl> ", 0)

View File

@ -10,7 +10,7 @@ import (
) )
var cmdX25519 = &base.Command{ var cmdX25519 = &base.Command{
UsageLine: `{{.Exec}} x25519 [-i "private key (base64.RawURLEncoding)"]`, UsageLine: `{{.Exec}} x25519 [-i "private key (base64.RawURLEncoding)"] [--std-encoding]`,
Short: `Generate key pair for x25519 key exchange`, Short: `Generate key pair for x25519 key exchange`,
Long: ` Long: `
Generate key pair for x25519 key exchange. Generate key pair for x25519 key exchange.
@ -18,6 +18,7 @@ Generate key pair for x25519 key exchange.
Random: {{.Exec}} x25519 Random: {{.Exec}} x25519
From private key: {{.Exec}} x25519 -i "private key (base64.RawURLEncoding)" From private key: {{.Exec}} x25519 -i "private key (base64.RawURLEncoding)"
For Std Encoding: {{.Exec}} x25519 --std-encoding
`, `,
} }
@ -26,12 +27,14 @@ func init() {
} }
var input_base64 = cmdX25519.Flag.String("i", "", "") var input_base64 = cmdX25519.Flag.String("i", "", "")
var input_stdEncoding = cmdX25519.Flag.Bool("std-encoding", false, "")
func executeX25519(cmd *base.Command, args []string) { func executeX25519(cmd *base.Command, args []string) {
var output string var output string
var err error var err error
var privateKey []byte var privateKey []byte
var publicKey []byte var publicKey []byte
var encoding *base64.Encoding
if len(*input_base64) > 0 { if len(*input_base64) > 0 {
privateKey, err = base64.RawURLEncoding.DecodeString(*input_base64) privateKey, err = base64.RawURLEncoding.DecodeString(*input_base64)
if err != nil { if err != nil {
@ -63,9 +66,15 @@ func executeX25519(cmd *base.Command, args []string) {
goto out goto out
} }
if *input_stdEncoding {
encoding = base64.StdEncoding
} else {
encoding = base64.RawURLEncoding
}
output = fmt.Sprintf("Private key: %v\nPublic key: %v", output = fmt.Sprintf("Private key: %v\nPublic key: %v",
base64.RawURLEncoding.EncodeToString(privateKey), encoding.EncodeToString(privateKey),
base64.RawURLEncoding.EncodeToString(publicKey)) encoding.EncodeToString(publicKey))
out: out:
fmt.Println(output) fmt.Println(output)
} }

View File

@ -27,48 +27,45 @@ type netReadInfo struct {
err error err error
} }
type netBindClient struct { // reduce duplicated code
workers int type netBind struct {
dialer internet.Dialer
dns dns.Client dns dns.Client
dnsOption dns.IPOption dnsOption dns.IPOption
reserved []byte
workers int
readQueue chan *netReadInfo readQueue chan *netReadInfo
} }
func (bind *netBindClient) ParseEndpoint(s string) (conn.Endpoint, error) { // SetMark implements conn.Bind
ipStr, port, _, err := splitAddrPort(s) func (bind *netBind) SetMark(mark uint32) error {
return nil
}
// ParseEndpoint implements conn.Bind
func (n *netBind) ParseEndpoint(s string) (conn.Endpoint, error) {
ipStr, port, err := net.SplitHostPort(s)
if err != nil {
return nil, err
}
portNum, err := strconv.Atoi(port)
if err != nil { if err != nil {
return nil, err return nil, err
} }
var addr net.IP addr := xnet.ParseAddress(ipStr)
if IsDomainName(ipStr) { if addr.Family() == xnet.AddressFamilyDomain {
ips, err := bind.dns.LookupIP(ipStr, bind.dnsOption) ips, err := n.dns.LookupIP(addr.Domain(), n.dnsOption)
if err != nil { if err != nil {
return nil, err return nil, err
} else if len(ips) == 0 { } else if len(ips) == 0 {
return nil, dns.ErrEmptyResponse return nil, dns.ErrEmptyResponse
} }
addr = ips[0] addr = xnet.IPAddress(ips[0])
} else {
addr = net.ParseIP(ipStr)
}
if addr == nil {
return nil, errors.New("failed to parse ip: " + ipStr)
}
var ip xnet.Address
if p4 := addr.To4(); len(p4) == net.IPv4len {
ip = xnet.IPAddress(p4[:])
} else {
ip = xnet.IPAddress(addr[:])
} }
dst := xnet.Destination{ dst := xnet.Destination{
Address: ip, Address: addr,
Port: xnet.Port(port), Port: xnet.Port(portNum),
Network: xnet.Network_UDP, Network: xnet.Network_UDP,
} }
@ -77,7 +74,13 @@ func (bind *netBindClient) ParseEndpoint(s string) (conn.Endpoint, error) {
}, nil }, nil
} }
func (bind *netBindClient) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) { // BatchSize implements conn.Bind
func (bind *netBind) BatchSize() int {
return 1
}
// Open implements conn.Bind
func (bind *netBind) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) {
bind.readQueue = make(chan *netReadInfo) bind.readQueue = make(chan *netReadInfo)
fun := func(bufs [][]byte, sizes []int, eps []conn.Endpoint) (n int, err error) { fun := func(bufs [][]byte, sizes []int, eps []conn.Endpoint) (n int, err error) {
@ -109,13 +112,21 @@ func (bind *netBindClient) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error
return arr, uint16(uport), nil return arr, uint16(uport), nil
} }
func (bind *netBindClient) Close() error { // Close implements conn.Bind
func (bind *netBind) Close() error {
if bind.readQueue != nil { if bind.readQueue != nil {
close(bind.readQueue) close(bind.readQueue)
} }
return nil return nil
} }
type netBindClient struct {
netBind
dialer internet.Dialer
reserved []byte
}
func (bind *netBindClient) connectTo(endpoint *netEndpoint) error { func (bind *netBindClient) connectTo(endpoint *netEndpoint) error {
c, err := bind.dialer.Dial(context.Background(), endpoint.dst) c, err := bind.dialer.Dial(context.Background(), endpoint.dst)
if err != nil { if err != nil {
@ -177,12 +188,29 @@ func (bind *netBindClient) Send(buff [][]byte, endpoint conn.Endpoint) error {
return nil return nil
} }
func (bind *netBindClient) SetMark(mark uint32) error { type netBindServer struct {
return nil netBind
} }
func (bind *netBindClient) BatchSize() int { func (bind *netBindServer) Send(buff [][]byte, endpoint conn.Endpoint) error {
return 1 var err error
nend, ok := endpoint.(*netEndpoint)
if !ok {
return conn.ErrWrongEndpointType
}
if nend.conn == nil {
return newError("connection not open yet")
}
for _, buff := range buff {
if _, err = nend.conn.Write(buff); err != nil {
return err
}
}
return err
} }
type netEndpoint struct { type netEndpoint struct {
@ -193,7 +221,7 @@ type netEndpoint struct {
func (netEndpoint) ClearSrc() {} func (netEndpoint) ClearSrc() {}
func (e netEndpoint) DstIP() netip.Addr { func (e netEndpoint) DstIP() netip.Addr {
return toNetIpAddr(e.dst.Address) return netip.Addr{}
} }
func (e netEndpoint) SrcIP() netip.Addr { func (e netEndpoint) SrcIP() netip.Addr {
@ -232,83 +260,3 @@ func toNetIpAddr(addr xnet.Address) netip.Addr {
return netip.AddrFrom16(arr) return netip.AddrFrom16(arr)
} }
} }
func stringsLastIndexByte(s string, b byte) int {
for i := len(s) - 1; i >= 0; i-- {
if s[i] == b {
return i
}
}
return -1
}
func splitAddrPort(s string) (ip string, port uint16, v6 bool, err error) {
i := stringsLastIndexByte(s, ':')
if i == -1 {
return "", 0, false, errors.New("not an ip:port")
}
ip = s[:i]
portStr := s[i+1:]
if len(ip) == 0 {
return "", 0, false, errors.New("no IP")
}
if len(portStr) == 0 {
return "", 0, false, errors.New("no port")
}
port64, err := strconv.ParseUint(portStr, 10, 16)
if err != nil {
return "", 0, false, errors.New("invalid port " + strconv.Quote(portStr) + " parsing " + strconv.Quote(s))
}
port = uint16(port64)
if ip[0] == '[' {
if len(ip) < 2 || ip[len(ip)-1] != ']' {
return "", 0, false, errors.New("missing ]")
}
ip = ip[1 : len(ip)-1]
v6 = true
}
return ip, port, v6, nil
}
func IsDomainName(s string) bool {
l := len(s)
if l == 0 || l > 254 || l == 254 && s[l-1] != '.' {
return false
}
last := byte('.')
nonNumeric := false
partlen := 0
for i := 0; i < len(s); i++ {
c := s[i]
switch {
default:
return false
case 'a' <= c && c <= 'z' || 'A' <= c && c <= 'Z' || c == '_':
nonNumeric = true
partlen++
case '0' <= c && c <= '9':
partlen++
case c == '-':
if last == '.' {
return false
}
partlen++
nonNumeric = true
case c == '.':
if last == '.' || last == '-' {
return false
}
if partlen > 63 || partlen == 0 {
return false
}
partlen = 0
}
last = c
}
if last == '-' || partlen > 63 {
return false
}
return nonNumeric
}

255
proxy/wireguard/client.go Normal file
View File

@ -0,0 +1,255 @@
/*
Some of codes are copied from https://github.com/octeep/wireproxy, license below.
Copyright (c) 2022 Wind T.F. Wong <octeep@pm.me>
Permission to use, copy, modify, and distribute this software for any
purpose with or without fee is hereby granted, provided that the above
copyright notice and this permission notice appear in all copies.
THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
*/
package wireguard
import (
"context"
"net/netip"
"sync"
"github.com/xtls/xray-core/common"
"github.com/xtls/xray-core/common/buf"
"github.com/xtls/xray-core/common/dice"
"github.com/xtls/xray-core/common/log"
"github.com/xtls/xray-core/common/net"
"github.com/xtls/xray-core/common/protocol"
"github.com/xtls/xray-core/common/session"
"github.com/xtls/xray-core/common/signal"
"github.com/xtls/xray-core/common/task"
"github.com/xtls/xray-core/core"
"github.com/xtls/xray-core/features/dns"
"github.com/xtls/xray-core/features/policy"
"github.com/xtls/xray-core/transport"
"github.com/xtls/xray-core/transport/internet"
)
// Handler is an outbound connection that silently swallow the entire payload.
type Handler struct {
conf *DeviceConfig
net Tunnel
bind *netBindClient
policyManager policy.Manager
dns dns.Client
// cached configuration
ipc string
endpoints []netip.Addr
hasIPv4, hasIPv6 bool
wgLock sync.Mutex
}
// New creates a new wireguard handler.
func New(ctx context.Context, conf *DeviceConfig) (*Handler, error) {
v := core.MustFromContext(ctx)
endpoints, hasIPv4, hasIPv6, err := parseEndpoints(conf)
if err != nil {
return nil, err
}
d := v.GetFeature(dns.ClientType()).(dns.Client)
return &Handler{
conf: conf,
policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager),
dns: d,
ipc: createIPCRequest(conf),
endpoints: endpoints,
hasIPv4: hasIPv4,
hasIPv6: hasIPv6,
}, nil
}
func (h *Handler) processWireGuard(dialer internet.Dialer) (err error) {
h.wgLock.Lock()
defer h.wgLock.Unlock()
if h.bind != nil && h.bind.dialer == dialer && h.net != nil {
return nil
}
log.Record(&log.GeneralMessage{
Severity: log.Severity_Info,
Content: "switching dialer",
})
if h.net != nil {
_ = h.net.Close()
h.net = nil
}
if h.bind != nil {
_ = h.bind.Close()
h.bind = nil
}
// bind := conn.NewStdNetBind() // TODO: conn.Bind wrapper for dialer
bind := &netBindClient{
netBind: netBind{
dns: h.dns,
dnsOption: dns.IPOption{
IPv4Enable: h.hasIPv4,
IPv6Enable: h.hasIPv6,
},
workers: int(h.conf.NumWorkers),
},
dialer: dialer,
reserved: h.conf.Reserved,
}
defer func() {
if err != nil {
_ = bind.Close()
}
}()
h.net, err = h.makeVirtualTun(bind)
if err != nil {
return newError("failed to create virtual tun interface").Base(err)
}
h.bind = bind
return nil
}
// Process implements OutboundHandler.Dispatch().
func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer internet.Dialer) error {
outbound := session.OutboundFromContext(ctx)
if outbound == nil || !outbound.Target.IsValid() {
return newError("target not specified")
}
outbound.Name = "wireguard"
inbound := session.InboundFromContext(ctx)
if inbound != nil {
inbound.SetCanSpliceCopy(3)
}
if err := h.processWireGuard(dialer); err != nil {
return err
}
// Destination of the inner request.
destination := outbound.Target
command := protocol.RequestCommandTCP
if destination.Network == net.Network_UDP {
command = protocol.RequestCommandUDP
}
// resolve dns
addr := destination.Address
if addr.Family().IsDomain() {
ips, err := h.dns.LookupIP(addr.Domain(), dns.IPOption{
IPv4Enable: h.hasIPv4 && h.conf.preferIP4(),
IPv6Enable: h.hasIPv6 && h.conf.preferIP6(),
})
{ // Resolve fallback
if (len(ips) == 0 || err != nil) && h.conf.hasFallback() {
ips, err = h.dns.LookupIP(addr.Domain(), dns.IPOption{
IPv4Enable: h.hasIPv4 && h.conf.fallbackIP4(),
IPv6Enable: h.hasIPv6 && h.conf.fallbackIP6(),
})
}
}
if err != nil {
return newError("failed to lookup DNS").Base(err)
} else if len(ips) == 0 {
return dns.ErrEmptyResponse
}
addr = net.IPAddress(ips[dice.Roll(len(ips))])
}
var newCtx context.Context
var newCancel context.CancelFunc
if session.TimeoutOnlyFromContext(ctx) {
newCtx, newCancel = context.WithCancel(context.Background())
}
p := h.policyManager.ForLevel(0)
ctx, cancel := context.WithCancel(ctx)
timer := signal.CancelAfterInactivity(ctx, func() {
cancel()
if newCancel != nil {
newCancel()
}
}, p.Timeouts.ConnectionIdle)
addrPort := netip.AddrPortFrom(toNetIpAddr(addr), destination.Port.Value())
var requestFunc func() error
var responseFunc func() error
if command == protocol.RequestCommandTCP {
conn, err := h.net.DialContextTCPAddrPort(ctx, addrPort)
if err != nil {
return newError("failed to create TCP connection").Base(err)
}
defer conn.Close()
requestFunc = func() error {
defer timer.SetTimeout(p.Timeouts.DownlinkOnly)
return buf.Copy(link.Reader, buf.NewWriter(conn), buf.UpdateActivity(timer))
}
responseFunc = func() error {
defer timer.SetTimeout(p.Timeouts.UplinkOnly)
return buf.Copy(buf.NewReader(conn), link.Writer, buf.UpdateActivity(timer))
}
} else if command == protocol.RequestCommandUDP {
conn, err := h.net.DialUDPAddrPort(netip.AddrPort{}, addrPort)
if err != nil {
return newError("failed to create UDP connection").Base(err)
}
defer conn.Close()
requestFunc = func() error {
defer timer.SetTimeout(p.Timeouts.DownlinkOnly)
return buf.Copy(link.Reader, buf.NewWriter(conn), buf.UpdateActivity(timer))
}
responseFunc = func() error {
defer timer.SetTimeout(p.Timeouts.UplinkOnly)
return buf.Copy(buf.NewReader(conn), link.Writer, buf.UpdateActivity(timer))
}
}
if newCtx != nil {
ctx = newCtx
}
responseDonePost := task.OnSuccess(responseFunc, task.Close(link.Writer))
if err := task.Run(ctx, requestFunc, responseDonePost); err != nil {
common.Interrupt(link.Reader)
common.Interrupt(link.Writer)
return newError("connection ends").Base(err)
}
return nil
}
// creates a tun interface on netstack given a configuration
func (h *Handler) makeVirtualTun(bind *netBindClient) (Tunnel, error) {
t, err := h.conf.createTun()(h.endpoints, int(h.conf.Mtu), nil)
if err != nil {
return nil, err
}
bind.dnsOption.IPv4Enable = h.hasIPv4
bind.dnsOption.IPv6Enable = h.hasIPv6
if err = t.BuildDevice(h.ipc, bind); err != nil {
_ = t.Close()
return nil, err
}
return t, nil
}

View File

@ -23,3 +23,10 @@ func (c *DeviceConfig) fallbackIP4() bool {
func (c *DeviceConfig) fallbackIP6() bool { func (c *DeviceConfig) fallbackIP6() bool {
return c.DomainStrategy == DeviceConfig_FORCE_IP46 return c.DomainStrategy == DeviceConfig_FORCE_IP46
} }
func (c *DeviceConfig) createTun() tunCreator {
if c.KernelMode {
return createKernelTun
}
return createGVisorTun
}

View File

@ -1,7 +1,7 @@
// Code generated by protoc-gen-go. DO NOT EDIT. // Code generated by protoc-gen-go. DO NOT EDIT.
// versions: // versions:
// protoc-gen-go v1.31.0 // protoc-gen-go v1.28.1
// protoc v4.23.1 // protoc v4.25.0
// source: proxy/wireguard/config.proto // source: proxy/wireguard/config.proto
package wireguard package wireguard
@ -83,7 +83,7 @@ type PeerConfig struct {
PublicKey string `protobuf:"bytes,1,opt,name=public_key,json=publicKey,proto3" json:"public_key,omitempty"` PublicKey string `protobuf:"bytes,1,opt,name=public_key,json=publicKey,proto3" json:"public_key,omitempty"`
PreSharedKey string `protobuf:"bytes,2,opt,name=pre_shared_key,json=preSharedKey,proto3" json:"pre_shared_key,omitempty"` PreSharedKey string `protobuf:"bytes,2,opt,name=pre_shared_key,json=preSharedKey,proto3" json:"pre_shared_key,omitempty"`
Endpoint string `protobuf:"bytes,3,opt,name=endpoint,proto3" json:"endpoint,omitempty"` Endpoint string `protobuf:"bytes,3,opt,name=endpoint,proto3" json:"endpoint,omitempty"`
KeepAlive int32 `protobuf:"varint,4,opt,name=keep_alive,json=keepAlive,proto3" json:"keep_alive,omitempty"` KeepAlive uint32 `protobuf:"varint,4,opt,name=keep_alive,json=keepAlive,proto3" json:"keep_alive,omitempty"`
AllowedIps []string `protobuf:"bytes,5,rep,name=allowed_ips,json=allowedIps,proto3" json:"allowed_ips,omitempty"` AllowedIps []string `protobuf:"bytes,5,rep,name=allowed_ips,json=allowedIps,proto3" json:"allowed_ips,omitempty"`
} }
@ -140,7 +140,7 @@ func (x *PeerConfig) GetEndpoint() string {
return "" return ""
} }
func (x *PeerConfig) GetKeepAlive() int32 { func (x *PeerConfig) GetKeepAlive() uint32 {
if x != nil { if x != nil {
return x.KeepAlive return x.KeepAlive
} }
@ -166,6 +166,8 @@ type DeviceConfig struct {
NumWorkers int32 `protobuf:"varint,5,opt,name=num_workers,json=numWorkers,proto3" json:"num_workers,omitempty"` NumWorkers int32 `protobuf:"varint,5,opt,name=num_workers,json=numWorkers,proto3" json:"num_workers,omitempty"`
Reserved []byte `protobuf:"bytes,6,opt,name=reserved,proto3" json:"reserved,omitempty"` Reserved []byte `protobuf:"bytes,6,opt,name=reserved,proto3" json:"reserved,omitempty"`
DomainStrategy DeviceConfig_DomainStrategy `protobuf:"varint,7,opt,name=domain_strategy,json=domainStrategy,proto3,enum=xray.proxy.wireguard.DeviceConfig_DomainStrategy" json:"domain_strategy,omitempty"` DomainStrategy DeviceConfig_DomainStrategy `protobuf:"varint,7,opt,name=domain_strategy,json=domainStrategy,proto3,enum=xray.proxy.wireguard.DeviceConfig_DomainStrategy" json:"domain_strategy,omitempty"`
IsClient bool `protobuf:"varint,8,opt,name=is_client,json=isClient,proto3" json:"is_client,omitempty"`
KernelMode bool `protobuf:"varint,9,opt,name=kernel_mode,json=kernelMode,proto3" json:"kernel_mode,omitempty"`
} }
func (x *DeviceConfig) Reset() { func (x *DeviceConfig) Reset() {
@ -249,6 +251,20 @@ func (x *DeviceConfig) GetDomainStrategy() DeviceConfig_DomainStrategy {
return DeviceConfig_FORCE_IP return DeviceConfig_FORCE_IP
} }
func (x *DeviceConfig) GetIsClient() bool {
if x != nil {
return x.IsClient
}
return false
}
func (x *DeviceConfig) GetKernelMode() bool {
if x != nil {
return x.KernelMode
}
return false
}
var File_proxy_wireguard_config_proto protoreflect.FileDescriptor var File_proxy_wireguard_config_proto protoreflect.FileDescriptor
var file_proxy_wireguard_config_proto_rawDesc = []byte{ var file_proxy_wireguard_config_proto_rawDesc = []byte{
@ -263,10 +279,10 @@ var file_proxy_wireguard_config_proto_rawDesc = []byte{
0x68, 0x61, 0x72, 0x65, 0x64, 0x4b, 0x65, 0x79, 0x12, 0x1a, 0x0a, 0x08, 0x65, 0x6e, 0x64, 0x70, 0x68, 0x61, 0x72, 0x65, 0x64, 0x4b, 0x65, 0x79, 0x12, 0x1a, 0x0a, 0x08, 0x65, 0x6e, 0x64, 0x70,
0x6f, 0x69, 0x6e, 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x65, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x65, 0x6e, 0x64, 0x70,
0x6f, 0x69, 0x6e, 0x74, 0x12, 0x1d, 0x0a, 0x0a, 0x6b, 0x65, 0x65, 0x70, 0x5f, 0x61, 0x6c, 0x69, 0x6f, 0x69, 0x6e, 0x74, 0x12, 0x1d, 0x0a, 0x0a, 0x6b, 0x65, 0x65, 0x70, 0x5f, 0x61, 0x6c, 0x69,
0x76, 0x65, 0x18, 0x04, 0x20, 0x01, 0x28, 0x05, 0x52, 0x09, 0x6b, 0x65, 0x65, 0x70, 0x41, 0x6c, 0x76, 0x65, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x09, 0x6b, 0x65, 0x65, 0x70, 0x41, 0x6c,
0x69, 0x76, 0x65, 0x12, 0x1f, 0x0a, 0x0b, 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x65, 0x64, 0x5f, 0x69, 0x69, 0x76, 0x65, 0x12, 0x1f, 0x0a, 0x0b, 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x65, 0x64, 0x5f, 0x69,
0x70, 0x73, 0x18, 0x05, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0a, 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x65, 0x70, 0x73, 0x18, 0x05, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0a, 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x65,
0x64, 0x49, 0x70, 0x73, 0x22, 0x8a, 0x03, 0x0a, 0x0c, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x43, 0x64, 0x49, 0x70, 0x73, 0x22, 0xc8, 0x03, 0x0a, 0x0c, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x43,
0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x1d, 0x0a, 0x0a, 0x73, 0x65, 0x63, 0x72, 0x65, 0x74, 0x5f, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x1d, 0x0a, 0x0a, 0x73, 0x65, 0x63, 0x72, 0x65, 0x74, 0x5f,
0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x73, 0x65, 0x63, 0x72, 0x65, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x73, 0x65, 0x63, 0x72, 0x65,
0x74, 0x4b, 0x65, 0x79, 0x12, 0x1a, 0x0a, 0x08, 0x65, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x74, 0x4b, 0x65, 0x79, 0x12, 0x1a, 0x0a, 0x08, 0x65, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74,
@ -285,19 +301,23 @@ var file_proxy_wireguard_config_proto_rawDesc = []byte{
0x72, 0x65, 0x67, 0x75, 0x61, 0x72, 0x64, 0x2e, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x43, 0x6f, 0x72, 0x65, 0x67, 0x75, 0x61, 0x72, 0x64, 0x2e, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x43, 0x6f,
0x6e, 0x66, 0x69, 0x67, 0x2e, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x53, 0x74, 0x72, 0x61, 0x74, 0x6e, 0x66, 0x69, 0x67, 0x2e, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x53, 0x74, 0x72, 0x61, 0x74,
0x65, 0x67, 0x79, 0x52, 0x0e, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x53, 0x74, 0x72, 0x61, 0x74, 0x65, 0x67, 0x79, 0x52, 0x0e, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x53, 0x74, 0x72, 0x61, 0x74,
0x65, 0x67, 0x79, 0x22, 0x5c, 0x0a, 0x0e, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x53, 0x74, 0x72, 0x65, 0x67, 0x79, 0x12, 0x1b, 0x0a, 0x09, 0x69, 0x73, 0x5f, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74,
0x61, 0x74, 0x65, 0x67, 0x79, 0x12, 0x0c, 0x0a, 0x08, 0x46, 0x4f, 0x52, 0x43, 0x45, 0x5f, 0x49, 0x18, 0x08, 0x20, 0x01, 0x28, 0x08, 0x52, 0x08, 0x69, 0x73, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74,
0x50, 0x10, 0x00, 0x12, 0x0d, 0x0a, 0x09, 0x46, 0x4f, 0x52, 0x43, 0x45, 0x5f, 0x49, 0x50, 0x34, 0x12, 0x1f, 0x0a, 0x0b, 0x6b, 0x65, 0x72, 0x6e, 0x65, 0x6c, 0x5f, 0x6d, 0x6f, 0x64, 0x65, 0x18,
0x10, 0x01, 0x12, 0x0d, 0x0a, 0x09, 0x46, 0x4f, 0x52, 0x43, 0x45, 0x5f, 0x49, 0x50, 0x36, 0x10, 0x09, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x6b, 0x65, 0x72, 0x6e, 0x65, 0x6c, 0x4d, 0x6f, 0x64,
0x02, 0x12, 0x0e, 0x0a, 0x0a, 0x46, 0x4f, 0x52, 0x43, 0x45, 0x5f, 0x49, 0x50, 0x34, 0x36, 0x10, 0x65, 0x22, 0x5c, 0x0a, 0x0e, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x53, 0x74, 0x72, 0x61, 0x74,
0x03, 0x12, 0x0e, 0x0a, 0x0a, 0x46, 0x4f, 0x52, 0x43, 0x45, 0x5f, 0x49, 0x50, 0x36, 0x34, 0x10, 0x65, 0x67, 0x79, 0x12, 0x0c, 0x0a, 0x08, 0x46, 0x4f, 0x52, 0x43, 0x45, 0x5f, 0x49, 0x50, 0x10,
0x04, 0x42, 0x5e, 0x0a, 0x18, 0x63, 0x6f, 0x6d, 0x2e, 0x78, 0x72, 0x61, 0x79, 0x2e, 0x70, 0x72, 0x00, 0x12, 0x0d, 0x0a, 0x09, 0x46, 0x4f, 0x52, 0x43, 0x45, 0x5f, 0x49, 0x50, 0x34, 0x10, 0x01,
0x6f, 0x78, 0x79, 0x2e, 0x77, 0x69, 0x72, 0x65, 0x67, 0x75, 0x61, 0x72, 0x64, 0x50, 0x01, 0x5a, 0x12, 0x0d, 0x0a, 0x09, 0x46, 0x4f, 0x52, 0x43, 0x45, 0x5f, 0x49, 0x50, 0x36, 0x10, 0x02, 0x12,
0x29, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x78, 0x74, 0x6c, 0x73, 0x0e, 0x0a, 0x0a, 0x46, 0x4f, 0x52, 0x43, 0x45, 0x5f, 0x49, 0x50, 0x34, 0x36, 0x10, 0x03, 0x12,
0x2f, 0x78, 0x72, 0x61, 0x79, 0x2d, 0x63, 0x6f, 0x72, 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x78, 0x79, 0x0e, 0x0a, 0x0a, 0x46, 0x4f, 0x52, 0x43, 0x45, 0x5f, 0x49, 0x50, 0x36, 0x34, 0x10, 0x04, 0x42,
0x2f, 0x77, 0x69, 0x72, 0x65, 0x67, 0x75, 0x61, 0x72, 0x64, 0xaa, 0x02, 0x14, 0x58, 0x72, 0x61, 0x5e, 0x0a, 0x18, 0x63, 0x6f, 0x6d, 0x2e, 0x78, 0x72, 0x61, 0x79, 0x2e, 0x70, 0x72, 0x6f, 0x78,
0x79, 0x2e, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x2e, 0x57, 0x69, 0x72, 0x65, 0x47, 0x75, 0x61, 0x72, 0x79, 0x2e, 0x77, 0x69, 0x72, 0x65, 0x67, 0x75, 0x61, 0x72, 0x64, 0x50, 0x01, 0x5a, 0x29, 0x67,
0x64, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x78, 0x74, 0x6c, 0x73, 0x2f, 0x78,
0x72, 0x61, 0x79, 0x2d, 0x63, 0x6f, 0x72, 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x78, 0x79, 0x2f, 0x77,
0x69, 0x72, 0x65, 0x67, 0x75, 0x61, 0x72, 0x64, 0xaa, 0x02, 0x14, 0x58, 0x72, 0x61, 0x79, 0x2e,
0x50, 0x72, 0x6f, 0x78, 0x79, 0x2e, 0x57, 0x69, 0x72, 0x65, 0x47, 0x75, 0x61, 0x72, 0x64, 0x62,
0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
} }
var ( var (

View File

@ -10,7 +10,7 @@ message PeerConfig {
string public_key = 1; string public_key = 1;
string pre_shared_key = 2; string pre_shared_key = 2;
string endpoint = 3; string endpoint = 3;
int32 keep_alive = 4; uint32 keep_alive = 4;
repeated string allowed_ips = 5; repeated string allowed_ips = 5;
} }
@ -29,4 +29,6 @@ message DeviceConfig {
int32 num_workers = 5; int32 num_workers = 5;
bytes reserved = 6; bytes reserved = 6;
DomainStrategy domain_strategy = 7; DomainStrategy domain_strategy = 7;
bool is_client = 8;
bool kernel_mode = 9;
} }

View File

@ -0,0 +1,230 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2022 WireGuard LLC. All Rights Reserved.
*/
package gvisortun
import (
"context"
"fmt"
"net/netip"
"os"
"syscall"
"golang.zx2c4.com/wireguard/tun"
"gvisor.dev/gvisor/pkg/buffer"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
)
type netTun struct {
ep *channel.Endpoint
stack *stack.Stack
events chan tun.Event
incomingPacket chan *buffer.View
mtu int
hasV4, hasV6 bool
}
type Net netTun
func CreateNetTUN(localAddresses []netip.Addr, mtu int, promiscuousMode bool) (tun.Device, *Net, *stack.Stack, error) {
opts := stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4},
HandleLocal: !promiscuousMode,
}
dev := &netTun{
ep: channel.New(1024, uint32(mtu), ""),
stack: stack.New(opts),
events: make(chan tun.Event, 1),
incomingPacket: make(chan *buffer.View),
mtu: mtu,
}
dev.ep.AddNotify(dev)
tcpipErr := dev.stack.CreateNIC(1, dev.ep)
if tcpipErr != nil {
return nil, nil, dev.stack, fmt.Errorf("CreateNIC: %v", tcpipErr)
}
for _, ip := range localAddresses {
var protoNumber tcpip.NetworkProtocolNumber
if ip.Is4() {
protoNumber = ipv4.ProtocolNumber
} else if ip.Is6() {
protoNumber = ipv6.ProtocolNumber
}
protoAddr := tcpip.ProtocolAddress{
Protocol: protoNumber,
AddressWithPrefix: tcpip.AddrFromSlice(ip.AsSlice()).WithPrefix(),
}
tcpipErr := dev.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{})
if tcpipErr != nil {
return nil, nil, dev.stack, fmt.Errorf("AddProtocolAddress(%v): %v", ip, tcpipErr)
}
if ip.Is4() {
dev.hasV4 = true
} else if ip.Is6() {
dev.hasV6 = true
}
}
if dev.hasV4 {
dev.stack.AddRoute(tcpip.Route{Destination: header.IPv4EmptySubnet, NIC: 1})
}
if dev.hasV6 {
dev.stack.AddRoute(tcpip.Route{Destination: header.IPv6EmptySubnet, NIC: 1})
}
if promiscuousMode {
// enable promiscuous mode to handle all packets processed by netstack
dev.stack.SetPromiscuousMode(1, true)
dev.stack.SetSpoofing(1, true)
}
opt := tcpip.CongestionControlOption("cubic")
if err := dev.stack.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
return nil, nil, dev.stack, fmt.Errorf("SetTransportProtocolOption(%d, &%T(%s)): %s", tcp.ProtocolNumber, opt, opt, err)
}
dev.events <- tun.EventUp
return dev, (*Net)(dev), dev.stack, nil
}
// BatchSize implements tun.Device
func (tun *netTun) BatchSize() int {
return 1
}
// Name implements tun.Device
func (tun *netTun) Name() (string, error) {
return "go", nil
}
// File implements tun.Device
func (tun *netTun) File() *os.File {
return nil
}
// Events implements tun.Device
func (tun *netTun) Events() <-chan tun.Event {
return tun.events
}
// Read implements tun.Device
func (tun *netTun) Read(buf [][]byte, sizes []int, offset int) (int, error) {
view, ok := <-tun.incomingPacket
if !ok {
return 0, os.ErrClosed
}
n, err := view.Read(buf[0][offset:])
if err != nil {
return 0, err
}
sizes[0] = n
return 1, nil
}
// Write implements tun.Device
func (tun *netTun) Write(buf [][]byte, offset int) (int, error) {
for _, buf := range buf {
packet := buf[offset:]
if len(packet) == 0 {
continue
}
pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{Payload: buffer.MakeWithData(packet)})
switch packet[0] >> 4 {
case 4:
tun.ep.InjectInbound(header.IPv4ProtocolNumber, pkb)
case 6:
tun.ep.InjectInbound(header.IPv6ProtocolNumber, pkb)
default:
return 0, syscall.EAFNOSUPPORT
}
}
return len(buf), nil
}
// WriteNotify implements channel.Notification
func (tun *netTun) WriteNotify() {
pkt := tun.ep.Read()
if pkt.IsNil() {
return
}
view := pkt.ToView()
pkt.DecRef()
tun.incomingPacket <- view
}
// Flush implements tun.Device
func (tun *netTun) Flush() error {
return nil
}
// Close implements tun.Device
func (tun *netTun) Close() error {
tun.stack.RemoveNIC(1)
if tun.events != nil {
close(tun.events)
}
tun.ep.Close()
if tun.incomingPacket != nil {
close(tun.incomingPacket)
}
return nil
}
// MTU implements tun.Device
func (tun *netTun) MTU() (int, error) {
return tun.mtu, nil
}
func convertToFullAddr(endpoint netip.AddrPort) (tcpip.FullAddress, tcpip.NetworkProtocolNumber) {
var protoNumber tcpip.NetworkProtocolNumber
if endpoint.Addr().Is4() {
protoNumber = ipv4.ProtocolNumber
} else {
protoNumber = ipv6.ProtocolNumber
}
return tcpip.FullAddress{
NIC: 1,
Addr: tcpip.AddrFromSlice(endpoint.Addr().AsSlice()),
Port: endpoint.Port(),
}, protoNumber
}
func (net *Net) DialContextTCPAddrPort(ctx context.Context, addr netip.AddrPort) (*gonet.TCPConn, error) {
fa, pn := convertToFullAddr(addr)
return gonet.DialContextTCP(ctx, net.stack, fa, pn)
}
func (net *Net) DialUDPAddrPort(laddr, raddr netip.AddrPort) (*gonet.UDPConn, error) {
var lfa, rfa *tcpip.FullAddress
var pn tcpip.NetworkProtocolNumber
if laddr.IsValid() || laddr.Port() > 0 {
var addr tcpip.FullAddress
addr, pn = convertToFullAddr(laddr)
lfa = &addr
}
if raddr.IsValid() || raddr.Port() > 0 {
var addr tcpip.FullAddress
addr, pn = convertToFullAddr(raddr)
rfa = &addr
}
return gonet.DialUDP(net.stack, lfa, rfa, pn)
}

181
proxy/wireguard/server.go Normal file
View File

@ -0,0 +1,181 @@
package wireguard
import (
"context"
"errors"
"io"
"github.com/xtls/xray-core/common"
"github.com/xtls/xray-core/common/buf"
"github.com/xtls/xray-core/common/log"
"github.com/xtls/xray-core/common/net"
"github.com/xtls/xray-core/common/session"
"github.com/xtls/xray-core/common/signal"
"github.com/xtls/xray-core/common/task"
"github.com/xtls/xray-core/core"
"github.com/xtls/xray-core/features/dns"
"github.com/xtls/xray-core/features/policy"
"github.com/xtls/xray-core/features/routing"
"github.com/xtls/xray-core/transport/internet/stat"
)
var nullDestination = net.TCPDestination(net.AnyIP, 0)
type Server struct {
bindServer *netBindServer
info routingInfo
policyManager policy.Manager
}
type routingInfo struct {
ctx context.Context
dispatcher routing.Dispatcher
inboundTag *session.Inbound
outboundTag *session.Outbound
contentTag *session.Content
}
func NewServer(ctx context.Context, conf *DeviceConfig) (*Server, error) {
v := core.MustFromContext(ctx)
endpoints, hasIPv4, hasIPv6, err := parseEndpoints(conf)
if err != nil {
return nil, err
}
server := &Server{
bindServer: &netBindServer{
netBind: netBind{
dns: v.GetFeature(dns.ClientType()).(dns.Client),
dnsOption: dns.IPOption{
IPv4Enable: hasIPv4,
IPv6Enable: hasIPv6,
},
},
},
policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager),
}
tun, err := conf.createTun()(endpoints, int(conf.Mtu), server.forwardConnection)
if err != nil {
return nil, err
}
if err = tun.BuildDevice(createIPCRequest(conf), server.bindServer); err != nil {
_ = tun.Close()
return nil, err
}
return server, nil
}
// Network implements proxy.Inbound.
func (*Server) Network() []net.Network {
return []net.Network{net.Network_UDP}
}
// Process implements proxy.Inbound.
func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Connection, dispatcher routing.Dispatcher) error {
s.info = routingInfo{
ctx: core.ToBackgroundDetachedContext(ctx),
dispatcher: dispatcher,
inboundTag: session.InboundFromContext(ctx),
outboundTag: session.OutboundFromContext(ctx),
contentTag: session.ContentFromContext(ctx),
}
ep, err := s.bindServer.ParseEndpoint(conn.RemoteAddr().String())
if err != nil {
return err
}
nep := ep.(*netEndpoint)
nep.conn = conn
reader := buf.NewPacketReader(conn)
for {
mpayload, err := reader.ReadMultiBuffer()
if err != nil {
return err
}
for _, payload := range mpayload {
v, ok := <-s.bindServer.readQueue
if !ok {
return nil
}
i, err := payload.Read(v.buff)
v.bytes = i
v.endpoint = nep
v.err = err
v.waiter.Done()
if err != nil && errors.Is(err, io.EOF) {
nep.conn = nil
return nil
}
}
}
}
func (s *Server) forwardConnection(dest net.Destination, conn net.Conn) {
if s.info.dispatcher == nil {
newError("unexpected: dispatcher == nil").AtError().WriteToLog()
return
}
defer conn.Close()
ctx, cancel := context.WithCancel(core.ToBackgroundDetachedContext(s.info.ctx))
plcy := s.policyManager.ForLevel(0)
timer := signal.CancelAfterInactivity(ctx, cancel, plcy.Timeouts.ConnectionIdle)
ctx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{
From: nullDestination,
To: dest,
Status: log.AccessAccepted,
Reason: "",
})
if s.info.inboundTag != nil {
ctx = session.ContextWithInbound(ctx, s.info.inboundTag)
}
if s.info.outboundTag != nil {
ctx = session.ContextWithOutbound(ctx, s.info.outboundTag)
}
if s.info.contentTag != nil {
ctx = session.ContextWithContent(ctx, s.info.contentTag)
}
link, err := s.info.dispatcher.Dispatch(ctx, dest)
if err != nil {
newError("dispatch connection").Base(err).AtError().WriteToLog()
}
defer cancel()
requestDone := func() error {
defer timer.SetTimeout(plcy.Timeouts.DownlinkOnly)
if err := buf.Copy(buf.NewReader(conn), link.Writer, buf.UpdateActivity(timer)); err != nil {
return newError("failed to transport all TCP request").Base(err)
}
return nil
}
responseDone := func() error {
defer timer.SetTimeout(plcy.Timeouts.UplinkOnly)
if err := buf.Copy(link.Reader, buf.NewWriter(conn), buf.UpdateActivity(timer)); err != nil {
return newError("failed to transport all TCP response").Base(err)
}
return nil
}
requestDonePost := task.OnSuccess(requestDone, task.Close(link.Writer))
if err := task.Run(ctx, requestDonePost, responseDone); err != nil {
common.Interrupt(link.Reader)
common.Interrupt(link.Writer)
newError("connection ends").Base(err).AtDebug().WriteToLog()
return
}
}

View File

@ -10,14 +10,26 @@ import (
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
"time"
"github.com/xtls/xray-core/common/log" "github.com/xtls/xray-core/common/log"
xnet "github.com/xtls/xray-core/common/net"
"github.com/xtls/xray-core/proxy/wireguard/gvisortun"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
"gvisor.dev/gvisor/pkg/waiter"
"golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun" "golang.zx2c4.com/wireguard/tun"
) )
type tunCreator func(localAddresses []netip.Addr, mtu int, handler promiscuousModeHandler) (Tunnel, error)
type promiscuousModeHandler func(dest xnet.Destination, conn net.Conn)
type Tunnel interface { type Tunnel interface {
BuildDevice(ipc string, bind conn.Bind) error BuildDevice(ipc string, bind conn.Bind) error
DialContextTCPAddrPort(ctx context.Context, addr netip.AddrPort) (net.Conn, error) DialContextTCPAddrPort(ctx context.Context, addr netip.AddrPort) (net.Conn, error)
@ -103,3 +115,91 @@ func CalculateInterfaceName(name string) (tunName string) {
tunName = fmt.Sprintf("%s%d", tunName, tunIndex) tunName = fmt.Sprintf("%s%d", tunName, tunIndex)
return return
} }
var _ Tunnel = (*gvisorNet)(nil)
type gvisorNet struct {
tunnel
net *gvisortun.Net
}
func (g *gvisorNet) Close() error {
return g.tunnel.Close()
}
func (g *gvisorNet) DialContextTCPAddrPort(ctx context.Context, addr netip.AddrPort) (
net.Conn, error,
) {
return g.net.DialContextTCPAddrPort(ctx, addr)
}
func (g *gvisorNet) DialUDPAddrPort(laddr, raddr netip.AddrPort) (net.Conn, error) {
return g.net.DialUDPAddrPort(laddr, raddr)
}
func createGVisorTun(localAddresses []netip.Addr, mtu int, handler promiscuousModeHandler) (Tunnel, error) {
out := &gvisorNet{}
tun, n, stack, err := gvisortun.CreateNetTUN(localAddresses, mtu, handler != nil)
if err != nil {
return nil, err
}
if handler != nil {
// handler is only used for promiscuous mode
// capture all packets and send to handler
tcpForwarder := tcp.NewForwarder(stack, 0, 65535, func(r *tcp.ForwarderRequest) {
go func(r *tcp.ForwarderRequest) {
var (
wq waiter.Queue
id = r.ID()
)
// Perform a TCP three-way handshake.
ep, err := r.CreateEndpoint(&wq)
if err != nil {
newError(err.String()).AtError().WriteToLog()
r.Complete(true)
return
}
r.Complete(false)
defer ep.Close()
// enable tcp keep-alive to prevent hanging connections
ep.SocketOptions().SetKeepAlive(true)
// local address is actually destination
handler(xnet.TCPDestination(xnet.IPAddress(id.LocalAddress.AsSlice()), xnet.Port(id.LocalPort)), gonet.NewTCPConn(&wq, ep))
}(r)
})
stack.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpForwarder.HandlePacket)
udpForwarder := udp.NewForwarder(stack, func(r *udp.ForwarderRequest) {
go func(r *udp.ForwarderRequest) {
var (
wq waiter.Queue
id = r.ID()
)
ep, err := r.CreateEndpoint(&wq)
if err != nil {
newError(err.String()).AtError().WriteToLog()
return
}
defer ep.Close()
// prevents hanging connections and ensure timely release
ep.SocketOptions().SetLinger(tcpip.LingerOption{
Enabled: true,
Timeout: 15 * time.Second,
})
handler(xnet.UDPDestination(xnet.IPAddress(id.LocalAddress.AsSlice()), xnet.Port(id.LocalPort)), gonet.NewUDPConn(stack, &wq, ep))
}(r)
})
stack.SetTransportProtocolHandler(udp.ProtocolNumber, udpForwarder.HandlePacket)
}
out.tun, out.net = tun, n
return out, nil
}

View File

@ -1,42 +1,16 @@
//go:build !linux //go:build !linux || android
package wireguard package wireguard
import ( import (
"context" "errors"
"net"
"net/netip" "net/netip"
"golang.zx2c4.com/wireguard/tun/netstack"
) )
var _ Tunnel = (*gvisorNet)(nil) func createKernelTun(localAddresses []netip.Addr, mtu int, handler promiscuousModeHandler) (t Tunnel, err error) {
return nil, errors.New("not implemented")
type gvisorNet struct {
tunnel
net *netstack.Net
} }
func (g *gvisorNet) Close() error { func KernelTunSupported() bool {
return g.tunnel.Close() return false
}
func (g *gvisorNet) DialContextTCPAddrPort(ctx context.Context, addr netip.AddrPort) (
net.Conn, error,
) {
return g.net.DialContextTCPAddrPort(ctx, addr)
}
func (g *gvisorNet) DialUDPAddrPort(laddr, raddr netip.AddrPort) (net.Conn, error) {
return g.net.DialUDPAddrPort(laddr, raddr)
}
func CreateTun(localAddresses []netip.Addr, mtu int) (Tunnel, error) {
out := &gvisorNet{}
tun, n, err := netstack.CreateNetTUN(localAddresses, nil, mtu)
if err != nil {
return nil, err
}
out.tun, out.net = tun, n
return out, nil
} }

View File

@ -1,3 +1,5 @@
//go:build linux && !android
package wireguard package wireguard
import ( import (
@ -69,7 +71,11 @@ func (d *deviceNet) Close() (err error) {
return errors.Join(errs...) return errors.Join(errs...)
} }
func CreateTun(localAddresses []netip.Addr, mtu int) (t Tunnel, err error) { func createKernelTun(localAddresses []netip.Addr, mtu int, handler promiscuousModeHandler) (t Tunnel, err error) {
if handler != nil {
return nil, newError("TODO: support promiscuous mode")
}
var v4, v6 *netip.Addr var v4, v6 *netip.Addr
for _, prefixes := range localAddresses { for _, prefixes := range localAddresses {
if v4 == nil && prefixes.Is4() { if v4 == nil && prefixes.Is4() {
@ -221,3 +227,11 @@ func CreateTun(localAddresses []netip.Addr, mtu int) (t Tunnel, err error) {
out.tun = wgt out.tun = wgt
return out, nil return out, nil
} }
func KernelTunSupported() bool {
// run a superuser permission check to check
// if the current user has the sufficient permission
// to create a tun device.
return unix.Geteuid() == 0 // 0 means root
}

View File

@ -1,326 +1,111 @@
/*
Some of codes are copied from https://github.com/octeep/wireproxy, license below.
Copyright (c) 2022 Wind T.F. Wong <octeep@pm.me>
Permission to use, copy, modify, and distribute this software for any
purpose with or without fee is hereby granted, provided that the above
copyright notice and this permission notice appear in all copies.
THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
*/
package wireguard package wireguard
import ( import (
"bytes"
"context" "context"
"fmt" "fmt"
stdnet "net"
"net/netip" "net/netip"
"strings" "strings"
"sync"
"github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common"
"github.com/xtls/xray-core/common/buf"
"github.com/xtls/xray-core/common/dice"
"github.com/xtls/xray-core/common/log" "github.com/xtls/xray-core/common/log"
"github.com/xtls/xray-core/common/net" "golang.zx2c4.com/wireguard/device"
"github.com/xtls/xray-core/common/protocol"
"github.com/xtls/xray-core/common/session"
"github.com/xtls/xray-core/common/signal"
"github.com/xtls/xray-core/common/task"
"github.com/xtls/xray-core/core"
"github.com/xtls/xray-core/features/dns"
"github.com/xtls/xray-core/features/policy"
"github.com/xtls/xray-core/transport"
"github.com/xtls/xray-core/transport/internet"
) )
// Handler is an outbound connection that silently swallow the entire payload. //go:generate go run github.com/xtls/xray-core/common/errors/errorgen
type Handler struct {
conf *DeviceConfig
net Tunnel
bind *netBindClient
policyManager policy.Manager
dns dns.Client
// cached configuration
ipc string
endpoints []netip.Addr
hasIPv4, hasIPv6 bool
wgLock sync.Mutex
}
// New creates a new wireguard handler.
func New(ctx context.Context, conf *DeviceConfig) (*Handler, error) {
v := core.MustFromContext(ctx)
endpoints, err := parseEndpoints(conf)
if err != nil {
return nil, err
}
hasIPv4, hasIPv6 := false, false
for _, e := range endpoints {
if e.Is4() {
hasIPv4 = true
}
if e.Is6() {
hasIPv6 = true
}
}
d := v.GetFeature(dns.ClientType()).(dns.Client)
return &Handler{
conf: conf,
policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager),
dns: d,
ipc: createIPCRequest(conf, d, hasIPv6),
endpoints: endpoints,
hasIPv4: hasIPv4,
hasIPv6: hasIPv6,
}, nil
}
func (h *Handler) processWireGuard(dialer internet.Dialer) (err error) {
h.wgLock.Lock()
defer h.wgLock.Unlock()
if h.bind != nil && h.bind.dialer == dialer && h.net != nil {
return nil
}
var wgLogger = &device.Logger{
Verbosef: func(format string, args ...any) {
log.Record(&log.GeneralMessage{ log.Record(&log.GeneralMessage{
Severity: log.Severity_Info, Severity: log.Severity_Debug,
Content: "switching dialer", Content: fmt.Sprintf(format, args...),
}) })
},
if h.net != nil { Errorf: func(format string, args ...any) {
_ = h.net.Close() log.Record(&log.GeneralMessage{
h.net = nil Severity: log.Severity_Error,
} Content: fmt.Sprintf(format, args...),
if h.bind != nil { })
_ = h.bind.Close() },
h.bind = nil
}
// bind := conn.NewStdNetBind() // TODO: conn.Bind wrapper for dialer
bind := &netBindClient{
dialer: dialer,
workers: int(h.conf.NumWorkers),
dns: h.dns,
reserved: h.conf.Reserved,
}
defer func() {
if err != nil {
_ = bind.Close()
}
}()
h.net, err = h.makeVirtualTun(bind)
if err != nil {
return newError("failed to create virtual tun interface").Base(err)
}
h.bind = bind
return nil
} }
// Process implements OutboundHandler.Dispatch(). func init() {
func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer internet.Dialer) error { common.Must(common.RegisterConfig((*DeviceConfig)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
outbound := session.OutboundFromContext(ctx) deviceConfig := config.(*DeviceConfig)
if outbound == nil || !outbound.Target.IsValid() { if deviceConfig.IsClient {
return newError("target not specified") return New(ctx, deviceConfig)
} else {
return NewServer(ctx, deviceConfig)
} }
outbound.Name = "wireguard" }))
inbound := session.InboundFromContext(ctx)
if inbound != nil {
inbound.SetCanSpliceCopy(3)
}
if err := h.processWireGuard(dialer); err != nil {
return err
}
// Destination of the inner request.
destination := outbound.Target
command := protocol.RequestCommandTCP
if destination.Network == net.Network_UDP {
command = protocol.RequestCommandUDP
}
// resolve dns
addr := destination.Address
if addr.Family().IsDomain() {
ips, err := h.dns.LookupIP(addr.Domain(), dns.IPOption{
IPv4Enable: h.hasIPv4 && h.conf.preferIP4(),
IPv6Enable: h.hasIPv6 && h.conf.preferIP6(),
})
{ // Resolve fallback
if (len(ips) == 0 || err != nil) && h.conf.hasFallback() {
ips, err = h.dns.LookupIP(addr.Domain(), dns.IPOption{
IPv4Enable: h.hasIPv4 && h.conf.fallbackIP4(),
IPv6Enable: h.hasIPv6 && h.conf.fallbackIP6(),
})
}
}
if err != nil {
return newError("failed to lookup DNS").Base(err)
} else if len(ips) == 0 {
return dns.ErrEmptyResponse
}
addr = net.IPAddress(ips[dice.Roll(len(ips))])
}
var newCtx context.Context
var newCancel context.CancelFunc
if session.TimeoutOnlyFromContext(ctx) {
newCtx, newCancel = context.WithCancel(context.Background())
}
p := h.policyManager.ForLevel(0)
ctx, cancel := context.WithCancel(ctx)
timer := signal.CancelAfterInactivity(ctx, func() {
cancel()
if newCancel != nil {
newCancel()
}
}, p.Timeouts.ConnectionIdle)
addrPort := netip.AddrPortFrom(toNetIpAddr(addr), destination.Port.Value())
var requestFunc func() error
var responseFunc func() error
if command == protocol.RequestCommandTCP {
conn, err := h.net.DialContextTCPAddrPort(ctx, addrPort)
if err != nil {
return newError("failed to create TCP connection").Base(err)
}
defer conn.Close()
requestFunc = func() error {
defer timer.SetTimeout(p.Timeouts.DownlinkOnly)
return buf.Copy(link.Reader, buf.NewWriter(conn), buf.UpdateActivity(timer))
}
responseFunc = func() error {
defer timer.SetTimeout(p.Timeouts.UplinkOnly)
return buf.Copy(buf.NewReader(conn), link.Writer, buf.UpdateActivity(timer))
}
} else if command == protocol.RequestCommandUDP {
conn, err := h.net.DialUDPAddrPort(netip.AddrPort{}, addrPort)
if err != nil {
return newError("failed to create UDP connection").Base(err)
}
defer conn.Close()
requestFunc = func() error {
defer timer.SetTimeout(p.Timeouts.DownlinkOnly)
return buf.Copy(link.Reader, buf.NewWriter(conn), buf.UpdateActivity(timer))
}
responseFunc = func() error {
defer timer.SetTimeout(p.Timeouts.UplinkOnly)
return buf.Copy(buf.NewReader(conn), link.Writer, buf.UpdateActivity(timer))
}
}
if newCtx != nil {
ctx = newCtx
}
responseDonePost := task.OnSuccess(responseFunc, task.Close(link.Writer))
if err := task.Run(ctx, requestFunc, responseDonePost); err != nil {
common.Interrupt(link.Reader)
common.Interrupt(link.Writer)
return newError("connection ends").Base(err)
}
return nil
}
// serialize the config into an IPC request
func createIPCRequest(conf *DeviceConfig, d dns.Client, resolveEndPointToV4 bool) string {
var request bytes.Buffer
request.WriteString(fmt.Sprintf("private_key=%s\n", conf.SecretKey))
for _, peer := range conf.Peers {
endpoint := peer.Endpoint
host, port, err := net.SplitHostPort(endpoint)
if resolveEndPointToV4 && err == nil {
_, err = netip.ParseAddr(host)
if err != nil {
ipList, err := d.LookupIP(host, dns.IPOption{IPv4Enable: true, IPv6Enable: false})
if err == nil && len(ipList) > 0 {
endpoint = stdnet.JoinHostPort(ipList[0].String(), port)
}
}
}
request.WriteString(fmt.Sprintf("public_key=%s\nendpoint=%s\npersistent_keepalive_interval=%d\npreshared_key=%s\n",
peer.PublicKey, endpoint, peer.KeepAlive, peer.PreSharedKey))
for _, ip := range peer.AllowedIps {
request.WriteString(fmt.Sprintf("allowed_ip=%s\n", ip))
}
}
return request.String()[:request.Len()]
} }
// convert endpoint string to netip.Addr // convert endpoint string to netip.Addr
func parseEndpoints(conf *DeviceConfig) ([]netip.Addr, error) { func parseEndpoints(conf *DeviceConfig) ([]netip.Addr, bool, bool, error) {
var hasIPv4, hasIPv6 bool
endpoints := make([]netip.Addr, len(conf.Endpoint)) endpoints := make([]netip.Addr, len(conf.Endpoint))
for i, str := range conf.Endpoint { for i, str := range conf.Endpoint {
var addr netip.Addr var addr netip.Addr
if strings.Contains(str, "/") { if strings.Contains(str, "/") {
prefix, err := netip.ParsePrefix(str) prefix, err := netip.ParsePrefix(str)
if err != nil { if err != nil {
return nil, err return nil, false, false, err
} }
addr = prefix.Addr() addr = prefix.Addr()
if prefix.Bits() != addr.BitLen() { if prefix.Bits() != addr.BitLen() {
return nil, newError("interface address subnet should be /32 for IPv4 and /128 for IPv6") return nil, false, false, newError("interface address subnet should be /32 for IPv4 and /128 for IPv6")
} }
} else { } else {
var err error var err error
addr, err = netip.ParseAddr(str) addr, err = netip.ParseAddr(str)
if err != nil { if err != nil {
return nil, err return nil, false, false, err
} }
} }
endpoints[i] = addr endpoints[i] = addr
if addr.Is4() {
hasIPv4 = true
} else if addr.Is6() {
hasIPv6 = true
}
} }
return endpoints, nil return endpoints, hasIPv4, hasIPv6, nil
} }
// creates a tun interface on netstack given a configuration // serialize the config into an IPC request
func (h *Handler) makeVirtualTun(bind *netBindClient) (Tunnel, error) { func createIPCRequest(conf *DeviceConfig) string {
t, err := CreateTun(h.endpoints, int(h.conf.Mtu)) var request strings.Builder
if err != nil {
return nil, err request.WriteString(fmt.Sprintf("private_key=%s\n", conf.SecretKey))
if !conf.IsClient {
// placeholder, we'll handle actual port listening on Xray
request.WriteString("listen_port=1337\n")
} }
bind.dnsOption.IPv4Enable = h.hasIPv4 for _, peer := range conf.Peers {
bind.dnsOption.IPv6Enable = h.hasIPv6 if peer.PublicKey != "" {
request.WriteString(fmt.Sprintf("public_key=%s\n", peer.PublicKey))
if err = t.BuildDevice(h.ipc, bind); err != nil {
_ = t.Close()
return nil, err
} }
return t, nil
}
func init() { if peer.PreSharedKey != "" {
common.Must(common.RegisterConfig((*DeviceConfig)(nil), func(ctx context.Context, config interface{}) (interface{}, error) { request.WriteString(fmt.Sprintf("preshared_key=%s\n", peer.PreSharedKey))
return New(ctx, config.(*DeviceConfig)) }
}))
if peer.Endpoint != "" {
request.WriteString(fmt.Sprintf("endpoint=%s\n", peer.Endpoint))
}
for _, ip := range peer.AllowedIps {
request.WriteString(fmt.Sprintf("allowed_ip=%s\n", ip))
}
if peer.KeepAlive != 0 {
request.WriteString(fmt.Sprintf("persistent_keepalive_interval=%d\n", peer.KeepAlive))
}
}
return request.String()[:request.Len()]
} }