From 5ae3791a8e62cdb6ef7efdf6489a74679d8528b0 Mon Sep 17 00:00:00 2001 From: yuhan6665 <1588741+yuhan6665@users.noreply.github.com> Date: Sun, 12 Nov 2023 15:10:01 -0500 Subject: [PATCH] feat : upgrade wireguard go sdk (#2716) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: kunson Co-authored-by: 世界 --- go.mod | 10 +- go.sum | 23 ++- proxy/wireguard/bind.go | 78 ++++++-- proxy/wireguard/tun.go | 348 +++++++-------------------------- proxy/wireguard/tun_default.go | 42 ++++ proxy/wireguard/tun_linux.go | 223 +++++++++++++++++++++ proxy/wireguard/wireguard.go | 152 ++++++++------ 7 files changed, 514 insertions(+), 362 deletions(-) create mode 100644 proxy/wireguard/tun_default.go create mode 100644 proxy/wireguard/tun_linux.go diff --git a/go.mod b/go.mod index 81d90e10..f3abbad1 100644 --- a/go.mod +++ b/go.mod @@ -14,19 +14,19 @@ require ( github.com/refraction-networking/utls v1.5.4 github.com/sagernet/sing v0.2.17 github.com/sagernet/sing-shadowsocks v0.2.5 - github.com/sagernet/wireguard-go v0.0.0-20221116151939-c99467f53f2c github.com/seiflotfy/cuckoofilter v0.0.0-20220411075957-e3b120b3f5fb github.com/stretchr/testify v1.8.4 github.com/v2fly/ss-bloomring v0.0.0-20210312155135-28617310f63e + github.com/vishvananda/netlink v1.2.1-beta.2.0.20230316163032-ced5aaba43e3 github.com/xtls/reality v0.0.0-20230828171259-e426190d57f6 go4.org/netipx v0.0.0-20230824141953-6213f710f925 golang.org/x/crypto v0.15.0 golang.org/x/net v0.18.0 golang.org/x/sync v0.5.0 golang.org/x/sys v0.14.0 + golang.zx2c4.com/wireguard v0.0.0-20231022001213-2e0774f246fb google.golang.org/grpc v1.59.0 google.golang.org/protobuf v1.31.0 - gvisor.dev/gvisor v0.0.0-20230822212503-5bf4e5f98744 h12.io/socks v1.0.3 lukechampine.com/blake3 v1.2.1 ) @@ -48,14 +48,16 @@ require ( github.com/pmezard/go-difflib v1.0.0 // indirect github.com/quic-go/qtls-go1-20 v0.4.1 // indirect github.com/riobard/go-bloom v0.0.0-20200614022211-cdc8013cb5b3 // indirect - go.uber.org/atomic v1.11.0 // indirect + github.com/vishvananda/netns v0.0.0-20211101163701-50045581ed74 // indirect go.uber.org/mock v0.3.0 // indirect golang.org/x/exp v0.0.0-20230905200255-921286631fa9 // indirect golang.org/x/mod v0.12.0 // indirect golang.org/x/text v0.14.0 // indirect golang.org/x/time v0.3.0 // indirect golang.org/x/tools v0.13.0 // indirect - google.golang.org/genproto/googleapis/rpc v0.0.0-20230920204549-e6e6cdab5c13 // indirect + golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20230822172742-b8732ec3820d // indirect gopkg.in/yaml.v2 v2.4.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect + gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259 // indirect ) diff --git a/go.sum b/go.sum index 4876802d..3195b2b6 100644 --- a/go.sum +++ b/go.sum @@ -129,8 +129,6 @@ github.com/sagernet/sing v0.2.17 h1:vMPKb3MV0Aa5ws4dCJkRI8XEjrsUcDn810czd0FwmzI= github.com/sagernet/sing v0.2.17/go.mod h1:OL6k2F0vHmEzXz2KW19qQzu172FDgSbUSODylighuVo= github.com/sagernet/sing-shadowsocks v0.2.5 h1:qxIttos4xu6ii7MTVJYA8EFQR7Q3KG6xMqmLJIFtBaY= github.com/sagernet/sing-shadowsocks v0.2.5/go.mod h1:MGWGkcU2xW2G2mfArT9/QqpVLOGU+dBaahZCtPHdt7A= -github.com/sagernet/wireguard-go v0.0.0-20221116151939-c99467f53f2c h1:vK2wyt9aWYHHvNLWniwijBu/n4pySypiKRhN32u/JGo= -github.com/sagernet/wireguard-go v0.0.0-20221116151939-c99467f53f2c/go.mod h1:euOmN6O5kk9dQmgSS8Df4psAl3TCjxOz0NW60EWkSaI= github.com/seiflotfy/cuckoofilter v0.0.0-20220411075957-e3b120b3f5fb h1:XfLJSPIOUX+osiMraVgIrMR27uMXnRJWGm1+GL8/63U= github.com/seiflotfy/cuckoofilter v0.0.0-20220411075957-e3b120b3f5fb/go.mod h1:bR6DqgcAl1zTcOX8/pE2Qkj9XO00eCNqmKb7lXP8EAg= github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAmXWZgo= @@ -168,12 +166,15 @@ github.com/v2fly/ss-bloomring v0.0.0-20210312155135-28617310f63e h1:5QefA066A1tF github.com/v2fly/ss-bloomring v0.0.0-20210312155135-28617310f63e/go.mod h1:5t19P9LBIrNamL6AcMQOncg/r10y3Pc01AbHeMhwlpU= github.com/viant/assertly v0.4.8/go.mod h1:aGifi++jvCrUaklKEKT0BU95igDNaqkvz+49uaYMPRU= github.com/viant/toolbox v0.24.0/go.mod h1:OxMCG57V0PXuIP2HNQrtJf2CjqdmbrOx5EkMILuUhzM= +github.com/vishvananda/netlink v1.2.1-beta.2.0.20230316163032-ced5aaba43e3 h1:tkMT5pTye+1NlKIXETU78NXw0fyjnaNHmJyyLyzw8+U= +github.com/vishvananda/netlink v1.2.1-beta.2.0.20230316163032-ced5aaba43e3/go.mod h1:cAAsePK2e15YDAMJNyOpGYEWNe4sIghTY7gpz4cX/Ik= +github.com/vishvananda/netns v0.0.0-20200728191858-db3c7e526aae/go.mod h1:DD4vA1DwXk04H54A1oHXtwZmA0grkVMdPxx/VGLCah0= +github.com/vishvananda/netns v0.0.0-20211101163701-50045581ed74 h1:gga7acRE695APm9hlsSMoOoE65U4/TcqNj90mc69Rlg= +github.com/vishvananda/netns v0.0.0-20211101163701-50045581ed74/go.mod h1:DD4vA1DwXk04H54A1oHXtwZmA0grkVMdPxx/VGLCah0= github.com/xtls/reality v0.0.0-20230828171259-e426190d57f6 h1:T+YCYGfFdzyaKTDCdZn/hEiKvsw6yUfd+e4hze0rCUw= github.com/xtls/reality v0.0.0-20230828171259-e426190d57f6/go.mod h1:rkuAY1S9F8eI8gDiPDYvACE8e2uwkyg8qoOTuwWov7Y= github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= go.opencensus.io v0.18.0/go.mod h1:vKdFvxhtzZ9onBp9VKHK8z/sRpBMnKAsufL7wlDrCOA= -go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE= -go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= go.uber.org/mock v0.3.0 h1:3mUxI1No2/60yUYax92Pt8eNOEecx2D3lcXZh2NEZJo= go.uber.org/mock v0.3.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc= go4.org v0.0.0-20180809161055-417644f6feb5/go.mod h1:MkTOUMDaeVYJUOUsaDXIhWPZYa1yOyC1qaOBpL57BhE= @@ -227,9 +228,11 @@ golang.org/x/sys v0.0.0-20181029174526-d69651ed3497/go.mod h1:STP8DvDyc/dI5b8T5h golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190316082340-a2f829d7f35f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200217220822-9197077df867/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220804214406-8e32c043e418/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.14.0 h1:Vz7Qs629MkJkGyHxUlRHizWJRG2j8fbQKjELVSNhy7Q= golang.org/x/sys v0.14.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= @@ -256,6 +259,10 @@ golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8T golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg= +golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI= +golang.zx2c4.com/wireguard v0.0.0-20231022001213-2e0774f246fb h1:c5tyN8sSp8jSDxdCCDXVOpJwYXXhmTkNMt+g0zTSOic= +golang.zx2c4.com/wireguard v0.0.0-20231022001213-2e0774f246fb/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA= google.golang.org/api v0.0.0-20180910000450-7ca32eb868bf/go.mod h1:4mhQ8q/RsB7i+udVvVy5NUi08OU8ZlA0gRVgrF7VFY0= google.golang.org/api v0.0.0-20181030000543-1d582fd0359e/go.mod h1:4mhQ8q/RsB7i+udVvVy5NUi08OU8ZlA0gRVgrF7VFY0= google.golang.org/api v0.1.0/go.mod h1:UGEZY7KEX120AnNLIHFMKIo4obdJhkp2tPbaPlQx13Y= @@ -268,8 +275,8 @@ google.golang.org/genproto v0.0.0-20180831171423-11092d34479b/go.mod h1:JiN7NxoA google.golang.org/genproto v0.0.0-20181029155118-b69ba1387ce2/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= google.golang.org/genproto v0.0.0-20181202183823-bd91e49a0898/go.mod h1:7Ep/1NZk928CDR8SjdVbjWNpdIf6nzjE3BTgJDr2Atg= google.golang.org/genproto v0.0.0-20190306203927-b5d61aea6440/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= -google.golang.org/genproto/googleapis/rpc v0.0.0-20230920204549-e6e6cdab5c13 h1:N3bU/SQDCDyD6R528GJ/PwW9KjYcJA3dgyH+MovAkIM= -google.golang.org/genproto/googleapis/rpc v0.0.0-20230920204549-e6e6cdab5c13/go.mod h1:KSqppvjFjtoCI+KGd4PELB0qLNxdJHRGqRI09mB6pQA= +google.golang.org/genproto/googleapis/rpc v0.0.0-20230822172742-b8732ec3820d h1:uvYuEyMHKNt+lT4K3bN6fGswmK8qSvcreM3BwjDh+y4= +google.golang.org/genproto/googleapis/rpc v0.0.0-20230822172742-b8732ec3820d/go.mod h1:+Bk1OCOj40wS2hwAMA+aCW9ypzm63QTBBHp6lQ3p+9M= google.golang.org/grpc v1.14.0/go.mod h1:yo6s7OP7yaDglbqo1J04qKzAhqBH6lvTonzMVmEdcZw= google.golang.org/grpc v1.16.0/go.mod h1:0JHn/cJsOMiMfNA9+DeHDlAU7KAAB5GDlYFpa9MZMio= google.golang.org/grpc v1.17.0/go.mod h1:6QZJwpn2B+Zp71q/5VxRsJ6NXXVCE5NRUHRo+f3cWCs= @@ -293,8 +300,8 @@ gopkg.in/yaml.v3 v3.0.0-20200605160147-a5ece683394c/go.mod h1:K4uyk7z7BCEPqu6E+C gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= grpc.go4.org v0.0.0-20170609214715-11d0a25b4919/go.mod h1:77eQGdRu53HpSqPFJFmuJdjuHRquDANNeA4x7B8WQ9o= -gvisor.dev/gvisor v0.0.0-20230822212503-5bf4e5f98744 h1:tE44CyJgxEGzoPtHs9GI7ddKdgEGCREQBP54AmaVM+I= -gvisor.dev/gvisor v0.0.0-20230822212503-5bf4e5f98744/go.mod h1:lYEMhXbxgudVhALYsMQrBaUAjM3NMinh8mKL1CJv7rc= +gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259 h1:TbRPT0HtzFP3Cno1zZo7yPzEEnfu8EjLfl6IU9VfqkQ= +gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259/go.mod h1:AVgIgHMwK63XvmAzWG9vLQ41YnVHN0du0tEC46fI7yY= h12.io/socks v1.0.3 h1:Ka3qaQewws4j4/eDQnOdpr4wXsC//dXtWvftlIcCQUo= h12.io/socks v1.0.3/go.mod h1:AIhxy1jOId/XCz9BO+EIgNL2rQiPTBNnOfnVnQ+3Eck= honnef.co/go/tools v0.0.0-20180728063816-88497007e858/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= diff --git a/proxy/wireguard/bind.go b/proxy/wireguard/bind.go index 527f0e74..c224dc56 100644 --- a/proxy/wireguard/bind.go +++ b/proxy/wireguard/bind.go @@ -9,7 +9,8 @@ import ( "strconv" "sync" - "github.com/sagernet/wireguard-go/conn" + "golang.zx2c4.com/wireguard/conn" + xnet "github.com/xtls/xray-core/common/net" "github.com/xtls/xray-core/features/dns" "github.com/xtls/xray-core/transport/internet" @@ -36,7 +37,7 @@ type netBindClient struct { readQueue chan *netReadInfo } -func (n *netBindClient) ParseEndpoint(s string) (conn.Endpoint, error) { +func (bind *netBindClient) ParseEndpoint(s string) (conn.Endpoint, error) { ipStr, port, _, err := splitAddrPort(s) if err != nil { return nil, err @@ -44,7 +45,7 @@ func (n *netBindClient) ParseEndpoint(s string) (conn.Endpoint, error) { var addr net.IP if IsDomainName(ipStr) { - ips, err := n.dns.LookupIP(ipStr, n.dnsOption) + ips, err := bind.dns.LookupIP(ipStr, bind.dnsOption) if err != nil { return nil, err } else if len(ips) == 0 { @@ -79,22 +80,22 @@ func (n *netBindClient) ParseEndpoint(s string) (conn.Endpoint, error) { func (bind *netBindClient) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) { bind.readQueue = make(chan *netReadInfo) - fun := func(buff []byte) (cap int, ep conn.Endpoint, err error) { + fun := func(bufs [][]byte, sizes []int, eps []conn.Endpoint) (n int, err error) { defer func() { if r := recover(); r != nil { - cap = 0 - ep = nil + n = 0 err = errors.New("channel closed") } }() r := &netReadInfo{ - buff: buff, + buff: bufs[0], } r.waiter.Add(1) bind.readQueue <- r r.waiter.Wait() // wait read goroutine done, or we will miss the result - return r.bytes, r.endpoint, r.err + sizes[0], eps[0] = r.bytes, r.endpoint + return 1, r.err } workers := bind.workers if workers <= 0 { @@ -150,7 +151,7 @@ func (bind *netBindClient) connectTo(endpoint *netEndpoint) error { return nil } -func (bind *netBindClient) Send(buff []byte, endpoint conn.Endpoint) error { +func (bind *netBindClient) Send(buff [][]byte, endpoint conn.Endpoint) error { var err error nend, ok := endpoint.(*netEndpoint) @@ -165,19 +166,25 @@ func (bind *netBindClient) Send(buff []byte, endpoint conn.Endpoint) error { } } - if len(buff) > 3 && len(bind.reserved) == 3 { - copy(buff[1:], bind.reserved) + for _, buff := range buff { + if len(buff) > 3 && len(bind.reserved) == 3 { + copy(buff[1:], bind.reserved) + } + if _, err = nend.conn.Write(buff); err != nil { + return err + } } - - _, err = nend.conn.Write(buff) - - return err + return nil } func (bind *netBindClient) SetMark(mark uint32) error { return nil } +func (bind *netBindClient) BatchSize() int { + return 1 +} + type netEndpoint struct { dst xnet.Destination conn net.Conn @@ -264,3 +271,44 @@ func splitAddrPort(s string) (ip string, port uint16, v6 bool, err error) { return ip, port, v6, nil } + +func IsDomainName(s string) bool { + l := len(s) + if l == 0 || l > 254 || l == 254 && s[l-1] != '.' { + return false + } + last := byte('.') + nonNumeric := false + partlen := 0 + for i := 0; i < len(s); i++ { + c := s[i] + switch { + default: + return false + case 'a' <= c && c <= 'z' || 'A' <= c && c <= 'Z' || c == '_': + nonNumeric = true + partlen++ + case '0' <= c && c <= '9': + partlen++ + case c == '-': + if last == '.' { + return false + } + partlen++ + nonNumeric = true + case c == '.': + if last == '.' || last == '-' { + return false + } + if partlen > 63 || partlen == 0 { + return false + } + partlen = 0 + } + last = c + } + if last == '-' || partlen > 63 { + return false + } + return nonNumeric +} diff --git a/proxy/wireguard/tun.go b/proxy/wireguard/tun.go index ed6e434f..c320d0d0 100644 --- a/proxy/wireguard/tun.go +++ b/proxy/wireguard/tun.go @@ -1,303 +1,105 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2017-2022 WireGuard LLC. All Rights Reserved. - */ - package wireguard import ( "context" + "errors" "fmt" "net" "net/netip" - "os" + "runtime" + "strconv" + "strings" + "sync" - "github.com/sagernet/wireguard-go/tun" - "github.com/xtls/xray-core/features/dns" - "gvisor.dev/gvisor/pkg/buffer" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/channel" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" - "gvisor.dev/gvisor/pkg/tcpip/transport/udp" + "github.com/xtls/xray-core/common/log" + + "golang.zx2c4.com/wireguard/conn" + "golang.zx2c4.com/wireguard/device" + "golang.zx2c4.com/wireguard/tun" ) -type netTun struct { - ep *channel.Endpoint - stack *stack.Stack - events chan tun.Event - incomingPacket chan *buffer.View - mtu int - dnsClient dns.Client - hasV4, hasV6 bool +type Tunnel interface { + BuildDevice(ipc string, bind conn.Bind) error + DialContextTCPAddrPort(ctx context.Context, addr netip.AddrPort) (net.Conn, error) + DialUDPAddrPort(laddr, raddr netip.AddrPort) (net.Conn, error) + Close() error } -type Net netTun - -func CreateNetTUN(localAddresses []netip.Addr, dnsClient dns.Client, mtu int) (tun.Device, *Net, error) { - opts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol}, - HandleLocal: true, - } - dev := &netTun{ - ep: channel.New(1024, uint32(mtu), ""), - stack: stack.New(opts), - events: make(chan tun.Event, 10), - incomingPacket: make(chan *buffer.View), - dnsClient: dnsClient, - mtu: mtu, - } - dev.ep.AddNotify(dev) - tcpipErr := dev.stack.CreateNIC(1, dev.ep) - if tcpipErr != nil { - return nil, nil, fmt.Errorf("CreateNIC: %v", tcpipErr) - } - for _, ip := range localAddresses { - var protoNumber tcpip.NetworkProtocolNumber - if ip.Is4() { - protoNumber = ipv4.ProtocolNumber - } else if ip.Is6() { - protoNumber = ipv6.ProtocolNumber - } - protoAddr := tcpip.ProtocolAddress{ - Protocol: protoNumber, - AddressWithPrefix: tcpip.AddrFromSlice(ip.AsSlice()).WithPrefix(), - } - tcpipErr := dev.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{}) - if tcpipErr != nil { - return nil, nil, fmt.Errorf("AddProtocolAddress(%v): %v", ip, tcpipErr) - } - if ip.Is4() { - dev.hasV4 = true - } else if ip.Is6() { - dev.hasV6 = true - } - } - if dev.hasV4 { - dev.stack.AddRoute(tcpip.Route{Destination: header.IPv4EmptySubnet, NIC: 1}) - } - if dev.hasV6 { - dev.stack.AddRoute(tcpip.Route{Destination: header.IPv6EmptySubnet, NIC: 1}) - } - - dev.events <- tun.EventUp - return dev, (*Net)(dev), nil +type tunnel struct { + tun tun.Device + device *device.Device + rw sync.Mutex } -func (tun *netTun) Name() (string, error) { - return "go", nil -} +func (t *tunnel) BuildDevice(ipc string, bind conn.Bind) (err error) { + t.rw.Lock() + defer t.rw.Unlock() -func (tun *netTun) File() *os.File { + if t.device != nil { + return errors.New("device is already initialized") + } + + logger := &device.Logger{ + Verbosef: func(format string, args ...any) { + log.Record(&log.GeneralMessage{ + Severity: log.Severity_Debug, + Content: fmt.Sprintf(format, args...), + }) + }, + Errorf: func(format string, args ...any) { + log.Record(&log.GeneralMessage{ + Severity: log.Severity_Error, + Content: fmt.Sprintf(format, args...), + }) + }, + } + + t.device = device.NewDevice(t.tun, bind, logger) + if err = t.device.IpcSet(ipc); err != nil { + return err + } + if err = t.device.Up(); err != nil { + return err + } return nil } -func (tun *netTun) Events() chan tun.Event { - return tun.events -} +func (t *tunnel) Close() (err error) { + t.rw.Lock() + defer t.rw.Unlock() -func (tun *netTun) Read(buf []byte, offset int) (int, error) { - view, ok := <-tun.incomingPacket - if !ok { - return 0, os.ErrClosed + if t.device == nil { + return nil } - return view.Read(buf[offset:]) + t.device.Close() + t.device = nil + err = t.tun.Close() + t.tun = nil + return nil } -func (tun *netTun) Write(buf []byte, offset int) (int, error) { - packet := buf[offset:] - if len(packet) == 0 { - return 0, nil +func CalculateInterfaceName(name string) (tunName string) { + if runtime.GOOS == "darwin" { + tunName = "utun" + } else if name != "" { + tunName = name + } else { + tunName = "tun" } - - pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{Payload: buffer.MakeWithData(packet)}) - switch packet[0] >> 4 { - case 4: - tun.ep.InjectInbound(header.IPv4ProtocolNumber, pkb) - case 6: - tun.ep.InjectInbound(header.IPv6ProtocolNumber, pkb) - } - - return len(buf), nil -} - -func (tun *netTun) WriteNotify() { - pkt := tun.ep.Read() - if pkt == nil { + interfaces, err := net.Interfaces() + if err != nil { return } - - view := pkt.ToView() - pkt.DecRef() - - tun.incomingPacket <- view -} - -func (tun *netTun) Flush() error { - return nil -} - -func (tun *netTun) Close() error { - tun.stack.RemoveNIC(1) - - if tun.events != nil { - close(tun.events) - } - - tun.ep.Close() - - if tun.incomingPacket != nil { - close(tun.incomingPacket) - } - - return nil -} - -func (tun *netTun) MTU() (int, error) { - return tun.mtu, nil -} - -func convertToFullAddr(endpoint netip.AddrPort) (tcpip.FullAddress, tcpip.NetworkProtocolNumber) { - var protoNumber tcpip.NetworkProtocolNumber - if endpoint.Addr().Is4() { - protoNumber = ipv4.ProtocolNumber - } else { - protoNumber = ipv6.ProtocolNumber - } - return tcpip.FullAddress{ - NIC: 1, - Addr: tcpip.AddrFromSlice(endpoint.Addr().AsSlice()), - Port: endpoint.Port(), - }, protoNumber -} - -func (net *Net) DialContextTCPAddrPort(ctx context.Context, addr netip.AddrPort) (*gonet.TCPConn, error) { - fa, pn := convertToFullAddr(addr) - return gonet.DialContextTCP(ctx, net.stack, fa, pn) -} - -func (net *Net) DialContextTCP(ctx context.Context, addr *net.TCPAddr) (*gonet.TCPConn, error) { - if addr == nil { - return net.DialContextTCPAddrPort(ctx, netip.AddrPort{}) - } - ip, _ := netip.AddrFromSlice(addr.IP) - return net.DialContextTCPAddrPort(ctx, netip.AddrPortFrom(ip, uint16(addr.Port))) -} - -func (net *Net) DialTCPAddrPort(addr netip.AddrPort) (*gonet.TCPConn, error) { - fa, pn := convertToFullAddr(addr) - return gonet.DialTCP(net.stack, fa, pn) -} - -func (net *Net) DialTCP(addr *net.TCPAddr) (*gonet.TCPConn, error) { - if addr == nil { - return net.DialTCPAddrPort(netip.AddrPort{}) - } - ip, _ := netip.AddrFromSlice(addr.IP) - return net.DialTCPAddrPort(netip.AddrPortFrom(ip, uint16(addr.Port))) -} - -func (net *Net) ListenTCPAddrPort(addr netip.AddrPort) (*gonet.TCPListener, error) { - fa, pn := convertToFullAddr(addr) - return gonet.ListenTCP(net.stack, fa, pn) -} - -func (net *Net) ListenTCP(addr *net.TCPAddr) (*gonet.TCPListener, error) { - if addr == nil { - return net.ListenTCPAddrPort(netip.AddrPort{}) - } - ip, _ := netip.AddrFromSlice(addr.IP) - return net.ListenTCPAddrPort(netip.AddrPortFrom(ip, uint16(addr.Port))) -} - -func (net *Net) DialUDPAddrPort(laddr, raddr netip.AddrPort) (*gonet.UDPConn, error) { - var lfa, rfa *tcpip.FullAddress - var pn tcpip.NetworkProtocolNumber - if laddr.IsValid() || laddr.Port() > 0 { - var addr tcpip.FullAddress - addr, pn = convertToFullAddr(laddr) - lfa = &addr - } - if raddr.IsValid() || raddr.Port() > 0 { - var addr tcpip.FullAddress - addr, pn = convertToFullAddr(raddr) - rfa = &addr - } - return gonet.DialUDP(net.stack, lfa, rfa, pn) -} - -func (net *Net) ListenUDPAddrPort(laddr netip.AddrPort) (*gonet.UDPConn, error) { - return net.DialUDPAddrPort(laddr, netip.AddrPort{}) -} - -func (net *Net) DialUDP(laddr, raddr *net.UDPAddr) (*gonet.UDPConn, error) { - var la, ra netip.AddrPort - if laddr != nil { - ip, _ := netip.AddrFromSlice(laddr.IP) - la = netip.AddrPortFrom(ip, uint16(laddr.Port)) - } - if raddr != nil { - ip, _ := netip.AddrFromSlice(raddr.IP) - ra = netip.AddrPortFrom(ip, uint16(raddr.Port)) - } - return net.DialUDPAddrPort(la, ra) -} - -func (net *Net) ListenUDP(laddr *net.UDPAddr) (*gonet.UDPConn, error) { - return net.DialUDP(laddr, nil) -} - -func (n *Net) HasV4() bool { - return n.hasV4 -} - -func (n *Net) HasV6() bool { - return n.hasV6 -} - -func IsDomainName(s string) bool { - l := len(s) - if l == 0 || l > 254 || l == 254 && s[l-1] != '.' { - return false - } - last := byte('.') - nonNumeric := false - partlen := 0 - for i := 0; i < len(s); i++ { - c := s[i] - switch { - default: - return false - case 'a' <= c && c <= 'z' || 'A' <= c && c <= 'Z' || c == '_': - nonNumeric = true - partlen++ - case '0' <= c && c <= '9': - partlen++ - case c == '-': - if last == '.' { - return false + var tunIndex int + for _, netInterface := range interfaces { + if strings.HasPrefix(netInterface.Name, tunName) { + index, parseErr := strconv.ParseInt(netInterface.Name[len(tunName):], 10, 16) + if parseErr == nil { + tunIndex = int(index) + 1 } - partlen++ - nonNumeric = true - case c == '.': - if last == '.' || last == '-' { - return false - } - if partlen > 63 || partlen == 0 { - return false - } - partlen = 0 } - last = c } - if last == '-' || partlen > 63 { - return false - } - return nonNumeric + tunName = fmt.Sprintf("%s%d", tunName, tunIndex) + return } diff --git a/proxy/wireguard/tun_default.go b/proxy/wireguard/tun_default.go new file mode 100644 index 00000000..07f21272 --- /dev/null +++ b/proxy/wireguard/tun_default.go @@ -0,0 +1,42 @@ +//go:build !linux + +package wireguard + +import ( + "context" + "net" + "net/netip" + + "golang.zx2c4.com/wireguard/tun/netstack" +) + +var _ Tunnel = (*gvisorNet)(nil) + +type gvisorNet struct { + tunnel + net *netstack.Net +} + +func (g *gvisorNet) Close() error { + return g.tunnel.Close() +} + +func (g *gvisorNet) DialContextTCPAddrPort(ctx context.Context, addr netip.AddrPort) ( + net.Conn, error, +) { + return g.net.DialContextTCPAddrPort(ctx, addr) +} + +func (g *gvisorNet) DialUDPAddrPort(laddr, raddr netip.AddrPort) (net.Conn, error) { + return g.net.DialUDPAddrPort(laddr, raddr) +} + +func CreateTun(localAddresses []netip.Addr, mtu int) (Tunnel, error) { + out := &gvisorNet{} + tun, n, err := netstack.CreateNetTUN(localAddresses, nil, mtu) + if err != nil { + return nil, err + } + out.tun, out.net = tun, n + return out, nil +} diff --git a/proxy/wireguard/tun_linux.go b/proxy/wireguard/tun_linux.go new file mode 100644 index 00000000..ec940c56 --- /dev/null +++ b/proxy/wireguard/tun_linux.go @@ -0,0 +1,223 @@ +package wireguard + +import ( + "context" + "errors" + "fmt" + "net" + "net/netip" + "os" + + "golang.org/x/sys/unix" + + "github.com/sagernet/sing/common/control" + "github.com/vishvananda/netlink" + wgtun "golang.zx2c4.com/wireguard/tun" +) + +type deviceNet struct { + tunnel + dialer net.Dialer + + handle *netlink.Handle + linkAddrs []netlink.Addr + routes []*netlink.Route + rules []*netlink.Rule +} + +func newDeviceNet(interfaceName string) *deviceNet { + var dialer net.Dialer + bindControl := control.BindToInterface(control.DefaultInterfaceFinder(), interfaceName, -1) + dialer.Control = control.Append(dialer.Control, bindControl) + return &deviceNet{dialer: dialer} +} + +func (d *deviceNet) DialContextTCPAddrPort(ctx context.Context, addr netip.AddrPort) ( + net.Conn, error, +) { + return d.dialer.DialContext(ctx, "tcp", addr.String()) +} + +func (d *deviceNet) DialUDPAddrPort(laddr, raddr netip.AddrPort) (net.Conn, error) { + dialer := d.dialer + dialer.LocalAddr = &net.UDPAddr{IP: laddr.Addr().AsSlice(), Port: int(laddr.Port())} + return dialer.DialContext(context.Background(), "udp", raddr.String()) +} + +func (d *deviceNet) Close() (err error) { + var errs []error + for _, rule := range d.rules { + if err = d.handle.RuleDel(rule); err != nil { + errs = append(errs, fmt.Errorf("failed to delete rule: %w", err)) + } + } + for _, route := range d.routes { + if err = d.handle.RouteDel(route); err != nil { + errs = append(errs, fmt.Errorf("failed to delete route: %w", err)) + } + } + if err = d.tunnel.Close(); err != nil { + errs = append(errs, fmt.Errorf("failed to close tunnel: %w", err)) + } + if d.handle != nil { + d.handle.Close() + d.handle = nil + } + if len(errs) == 0 { + return nil + } + return errors.Join(errs...) +} + +func CreateTun(localAddresses []netip.Addr, mtu int) (t Tunnel, err error) { + var v4, v6 *netip.Addr + for _, prefixes := range localAddresses { + if v4 == nil && prefixes.Is4() { + x := prefixes + v4 = &x + } + if v6 == nil && prefixes.Is6() { + x := prefixes + v6 = &x + } + } + + writeSysctlZero := func(path string) error { + _, err := os.Stat(path) + if os.IsNotExist(err) { + return nil + } + if err != nil { + return err + } + return os.WriteFile(path, []byte("0"), 0o644) + } + + // system configs. + if v4 != nil { + if err = writeSysctlZero("/proc/sys/net/ipv4/conf/all/rp_filter"); err != nil { + return nil, fmt.Errorf("failed to disable ipv4 rp_filter for all: %w", err) + } + } + if v6 != nil { + if err = writeSysctlZero("/proc/sys/net/ipv6/conf/all/disable_ipv6"); err != nil { + return nil, fmt.Errorf("failed to enable ipv6: %w", err) + } + if err = writeSysctlZero("/proc/sys/net/ipv6/conf/all/rp_filter"); err != nil { + return nil, fmt.Errorf("failed to disable ipv6 rp_filter for all: %w", err) + } + } + + n := CalculateInterfaceName("wg") + wgt, err := wgtun.CreateTUN(n, mtu) + if err != nil { + return nil, err + } + defer func() { + if err != nil { + _ = wgt.Close() + } + }() + + // disable linux rp_filter for tunnel device to avoid packet drop. + // the operation require root privilege on container require '--privileged' flag. + if v4 != nil { + if err = writeSysctlZero("/proc/sys/net/ipv4/conf/" + n + "/rp_filter"); err != nil { + return nil, fmt.Errorf("failed to disable ipv4 rp_filter for tunnel: %w", err) + } + } + if v6 != nil { + if err = writeSysctlZero("/proc/sys/net/ipv6/conf/" + n + "/rp_filter"); err != nil { + return nil, fmt.Errorf("failed to disable ipv6 rp_filter for tunnel: %w", err) + } + } + + ipv6TableIndex := 1023 + if v6 != nil { + r := &netlink.Route{Table: ipv6TableIndex} + for { + routeList, fErr := netlink.RouteListFiltered(netlink.FAMILY_V6, r, netlink.RT_FILTER_TABLE) + if len(routeList) == 0 || fErr != nil { + break + } + ipv6TableIndex-- + if ipv6TableIndex < 0 { + return nil, fmt.Errorf("failed to find available ipv6 table index") + } + } + } + + out := newDeviceNet(n) + out.handle, err = netlink.NewHandle() + if err != nil { + return nil, err + } + defer func() { + if err != nil { + _ = out.Close() + } + }() + + l, err := netlink.LinkByName(n) + if err != nil { + return nil, err + } + + if v4 != nil { + addr := netlink.Addr{ + IPNet: &net.IPNet{ + IP: v4.AsSlice(), + Mask: net.CIDRMask(v4.BitLen(), v4.BitLen()), + }, + } + out.linkAddrs = append(out.linkAddrs, addr) + } + if v6 != nil { + addr := netlink.Addr{ + IPNet: &net.IPNet{ + IP: v6.AsSlice(), + Mask: net.CIDRMask(v6.BitLen(), v6.BitLen()), + }, + } + out.linkAddrs = append(out.linkAddrs, addr) + + rt := &netlink.Route{ + LinkIndex: l.Attrs().Index, + Dst: &net.IPNet{ + IP: net.IPv6zero, + Mask: net.CIDRMask(0, 128), + }, + Table: ipv6TableIndex, + } + out.routes = append(out.routes, rt) + + r := netlink.NewRule() + r.Table, r.Family, r.Src = ipv6TableIndex, unix.AF_INET6, addr.IPNet + out.rules = append(out.rules, r) + } + + for _, addr := range out.linkAddrs { + if err = out.handle.AddrAdd(l, &addr); err != nil { + return nil, fmt.Errorf("failed to add address %s to %s: %w", addr, n, err) + } + } + if err = out.handle.LinkSetMTU(l, mtu); err != nil { + return nil, err + } + if err = out.handle.LinkSetUp(l); err != nil { + return nil, err + } + + for _, route := range out.routes { + if err = out.handle.RouteAdd(route); err != nil { + return nil, fmt.Errorf("failed to add route %s: %w", route, err) + } + } + for _, rule := range out.rules { + if err = out.handle.RuleAdd(rule); err != nil { + return nil, fmt.Errorf("failed to add rule %s: %w", rule, err) + } + } + out.tun = wgt + return out, nil +} diff --git a/proxy/wireguard/wireguard.go b/proxy/wireguard/wireguard.go index 899dcac5..231776e7 100644 --- a/proxy/wireguard/wireguard.go +++ b/proxy/wireguard/wireguard.go @@ -24,10 +24,11 @@ import ( "bytes" "context" "fmt" + stdnet "net" "net/netip" "strings" + "sync" - "github.com/sagernet/wireguard-go/device" "github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common/buf" "github.com/xtls/xray-core/common/log" @@ -46,13 +47,15 @@ import ( // Handler is an outbound connection that silently swallow the entire payload. type Handler struct { conf *DeviceConfig - net *Net + net Tunnel bind *netBindClient policyManager policy.Manager dns dns.Client // cached configuration - ipc string - endpoints []netip.Addr + ipc string + endpoints []netip.Addr + hasIPv4, hasIPv6 bool + wgLock sync.Mutex } // New creates a new wireguard handler. @@ -64,15 +67,71 @@ func New(ctx context.Context, conf *DeviceConfig) (*Handler, error) { return nil, err } + hasIPv4, hasIPv6 := false, false + for _, e := range endpoints { + if e.Is4() { + hasIPv4 = true + } + if e.Is6() { + hasIPv6 = true + } + } + + d := v.GetFeature(dns.ClientType()).(dns.Client) return &Handler{ conf: conf, policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager), - dns: v.GetFeature(dns.ClientType()).(dns.Client), - ipc: createIPCRequest(conf), + dns: d, + ipc: createIPCRequest(conf, d, hasIPv6), endpoints: endpoints, + hasIPv4: hasIPv4, + hasIPv6: hasIPv6, }, nil } +func (h *Handler) processWireGuard(dialer internet.Dialer) (err error) { + h.wgLock.Lock() + defer h.wgLock.Unlock() + + if h.bind != nil && h.bind.dialer == dialer && h.net != nil { + return nil + } + + log.Record(&log.GeneralMessage{ + Severity: log.Severity_Info, + Content: "switching dialer", + }) + + if h.net != nil { + _ = h.net.Close() + h.net = nil + } + if h.bind != nil { + _ = h.bind.Close() + h.bind = nil + } + + // bind := conn.NewStdNetBind() // TODO: conn.Bind wrapper for dialer + bind := &netBindClient{ + dialer: dialer, + workers: int(h.conf.NumWorkers), + dns: h.dns, + reserved: h.conf.Reserved, + } + defer func() { + if err != nil { + _ = bind.Close() + } + }() + + h.net, err = h.makeVirtualTun(bind) + if err != nil { + return newError("failed to create virtual tun interface").Base(err) + } + h.bind = bind + return nil +} + // Process implements OutboundHandler.Dispatch(). func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer internet.Dialer) error { outbound := session.OutboundFromContext(ctx) @@ -85,30 +144,8 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte inbound.SetCanSpliceCopy(3) } - if h.bind == nil || h.bind.dialer != dialer || h.net == nil { - log.Record(&log.GeneralMessage{ - Severity: log.Severity_Info, - Content: "switching dialer", - }) - // bind := conn.NewStdNetBind() // TODO: conn.Bind wrapper for dialer - bind := &netBindClient{ - dialer: dialer, - workers: int(h.conf.NumWorkers), - dns: h.dns, - reserved: h.conf.Reserved, - } - - net, err := h.makeVirtualTun(bind) - if err != nil { - bind.Close() - return newError("failed to create virtual tun interface").Base(err) - } - - h.net = net - if h.bind != nil { - h.bind.Close() - } - h.bind = bind + if err := h.processWireGuard(dialer); err != nil { + return err } // Destination of the inner request. @@ -122,8 +159,8 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte addr := destination.Address if addr.Family().IsDomain() { ips, err := h.dns.LookupIP(addr.Domain(), dns.IPOption{ - IPv4Enable: h.net.HasV4(), - IPv6Enable: h.net.HasV6(), + IPv4Enable: h.hasIPv4, + IPv6Enable: h.hasIPv6, }) if err != nil { return newError("failed to lookup DNS").Base(err) @@ -200,14 +237,26 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte } // serialize the config into an IPC request -func createIPCRequest(conf *DeviceConfig) string { +func createIPCRequest(conf *DeviceConfig, d dns.Client, resolveEndPointToV4 bool) string { var request bytes.Buffer request.WriteString(fmt.Sprintf("private_key=%s\n", conf.SecretKey)) for _, peer := range conf.Peers { + endpoint := peer.Endpoint + host, port, err := net.SplitHostPort(endpoint) + if resolveEndPointToV4 && err == nil { + _, err = netip.ParseAddr(host) + if err != nil { + ipList, err := d.LookupIP(host, dns.IPOption{IPv4Enable: true, IPv6Enable: false}) + if err == nil && len(ipList) > 0 { + endpoint = stdnet.JoinHostPort(ipList[0].String(), port) + } + } + } + request.WriteString(fmt.Sprintf("public_key=%s\nendpoint=%s\npersistent_keepalive_interval=%d\npreshared_key=%s\n", - peer.PublicKey, peer.Endpoint, peer.KeepAlive, peer.PreSharedKey)) + peer.PublicKey, endpoint, peer.KeepAlive, peer.PreSharedKey)) for _, ip := range peer.AllowedIps { request.WriteString(fmt.Sprintf("allowed_ip=%s\n", ip)) @@ -245,41 +294,20 @@ func parseEndpoints(conf *DeviceConfig) ([]netip.Addr, error) { } // creates a tun interface on netstack given a configuration -func (h *Handler) makeVirtualTun(bind *netBindClient) (*Net, error) { - tun, tnet, err := CreateNetTUN(h.endpoints, h.dns, int(h.conf.Mtu)) +func (h *Handler) makeVirtualTun(bind *netBindClient) (Tunnel, error) { + t, err := CreateTun(h.endpoints, int(h.conf.Mtu)) if err != nil { return nil, err } - bind.dnsOption.IPv4Enable = tnet.HasV4() - bind.dnsOption.IPv6Enable = tnet.HasV6() + bind.dnsOption.IPv4Enable = h.hasIPv4 + bind.dnsOption.IPv6Enable = h.hasIPv6 - // dev := device.NewDevice(tun, conn.NewDefaultBind(), nil /* device.NewLogger(device.LogLevelVerbose, "") */) - dev := device.NewDevice(tun, bind, &device.Logger{ - Verbosef: func(format string, args ...any) { - log.Record(&log.GeneralMessage{ - Severity: log.Severity_Debug, - Content: fmt.Sprintf(format, args...), - }) - }, - Errorf: func(format string, args ...any) { - log.Record(&log.GeneralMessage{ - Severity: log.Severity_Error, - Content: fmt.Sprintf(format, args...), - }) - }, - }, int(h.conf.NumWorkers)) - err = dev.IpcSet(h.ipc) - if err != nil { + if err = t.BuildDevice(h.ipc, bind); err != nil { + _ = t.Close() return nil, err } - - err = dev.Up() - if err != nil { - return nil, err - } - - return tnet, nil + return t, nil } func init() {