mirror of
https://github.com/XTLS/Xray-core.git
synced 2024-12-23 22:19:49 +00:00
WireGuard Inbound (User-space WireGuard server) (#2477)
* feat: wireguard inbound * feat(command): generate wireguard compatible keypair * feat(wireguard): connection idle timeout * fix(wireguard): close endpoint after connection closed * fix(wireguard): resolve conflicts * feat(wireguard): set cubic as default cc algorithm in gVisor TUN * chore(wireguard): resolve conflict * chore(wireguard): remove redurant code * chore(wireguard): remove redurant code * feat: rework server for gvisor tun * feat: keep user-space tun as an option * fix: exclude android from native tun build * feat: auto kernel tun * fix: build * fix: regulate function name & fix test
This commit is contained in:
parent
f1c81557dc
commit
0ac7da2fc8
4
go.mod
4
go.mod
@ -27,6 +27,7 @@ require (
|
||||
golang.zx2c4.com/wireguard v0.0.0-20231022001213-2e0774f246fb
|
||||
google.golang.org/grpc v1.59.0
|
||||
google.golang.org/protobuf v1.31.0
|
||||
gvisor.dev/gvisor v0.0.0-20231104011432-48a6d7d5bd0b
|
||||
h12.io/socks v1.0.3
|
||||
lukechampine.com/blake3 v1.2.1
|
||||
)
|
||||
@ -48,7 +49,7 @@ require (
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/quic-go/qtls-go1-20 v0.4.1 // indirect
|
||||
github.com/riobard/go-bloom v0.0.0-20200614022211-cdc8013cb5b3 // indirect
|
||||
github.com/vishvananda/netns v0.0.4 // indirect
|
||||
github.com/vishvananda/netns v0.0.0-20200728191858-db3c7e526aae // indirect
|
||||
go.uber.org/mock v0.3.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20231110203233-9a3e6036ecaa // indirect
|
||||
golang.org/x/mod v0.14.0 // indirect
|
||||
@ -59,5 +60,4 @@ require (
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20231106174013-bbf56f31fb17 // indirect
|
||||
gopkg.in/yaml.v2 v2.4.0 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
gvisor.dev/gvisor v0.0.0-20231104011432-48a6d7d5bd0b // indirect
|
||||
)
|
||||
|
3
go.sum
3
go.sum
@ -168,9 +168,8 @@ github.com/viant/assertly v0.4.8/go.mod h1:aGifi++jvCrUaklKEKT0BU95igDNaqkvz+49u
|
||||
github.com/viant/toolbox v0.24.0/go.mod h1:OxMCG57V0PXuIP2HNQrtJf2CjqdmbrOx5EkMILuUhzM=
|
||||
github.com/vishvananda/netlink v1.2.1-beta.2.0.20230316163032-ced5aaba43e3 h1:tkMT5pTye+1NlKIXETU78NXw0fyjnaNHmJyyLyzw8+U=
|
||||
github.com/vishvananda/netlink v1.2.1-beta.2.0.20230316163032-ced5aaba43e3/go.mod h1:cAAsePK2e15YDAMJNyOpGYEWNe4sIghTY7gpz4cX/Ik=
|
||||
github.com/vishvananda/netns v0.0.0-20200728191858-db3c7e526aae h1:4hwBBUfQCFe3Cym0ZtKyq7L16eZUtYKs+BaHDN6mAns=
|
||||
github.com/vishvananda/netns v0.0.0-20200728191858-db3c7e526aae/go.mod h1:DD4vA1DwXk04H54A1oHXtwZmA0grkVMdPxx/VGLCah0=
|
||||
github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8=
|
||||
github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
|
||||
github.com/xtls/reality v0.0.0-20231112171332-de1173cf2b19 h1:capMfFYRgH9BCLd6A3Er/cH3A9Nz3CU2KwxwOQZIePI=
|
||||
github.com/xtls/reality v0.0.0-20231112171332-de1173cf2b19/go.mod h1:dm4y/1QwzjGaK17ofi0Vs6NpKAHegZky8qk6J2JJZAE=
|
||||
github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
|
||||
|
@ -13,7 +13,7 @@ type WireGuardPeerConfig struct {
|
||||
PublicKey string `json:"publicKey"`
|
||||
PreSharedKey string `json:"preSharedKey"`
|
||||
Endpoint string `json:"endpoint"`
|
||||
KeepAlive int `json:"keepAlive"`
|
||||
KeepAlive uint32 `json:"keepAlive"`
|
||||
AllowedIPs []string `json:"allowedIPs,omitempty"`
|
||||
}
|
||||
|
||||
@ -21,9 +21,11 @@ func (c *WireGuardPeerConfig) Build() (proto.Message, error) {
|
||||
var err error
|
||||
config := new(wireguard.PeerConfig)
|
||||
|
||||
config.PublicKey, err = parseWireGuardKey(c.PublicKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
if c.PublicKey != "" {
|
||||
config.PublicKey, err = parseWireGuardKey(c.PublicKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if c.PreSharedKey != "" {
|
||||
@ -31,13 +33,11 @@ func (c *WireGuardPeerConfig) Build() (proto.Message, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
config.PreSharedKey = "0000000000000000000000000000000000000000000000000000000000000000"
|
||||
}
|
||||
|
||||
config.Endpoint = c.Endpoint
|
||||
// default 0
|
||||
config.KeepAlive = int32(c.KeepAlive)
|
||||
config.KeepAlive = c.KeepAlive
|
||||
if c.AllowedIPs == nil {
|
||||
config.AllowedIps = []string{"0.0.0.0/0", "::0/0"}
|
||||
} else {
|
||||
@ -48,11 +48,14 @@ func (c *WireGuardPeerConfig) Build() (proto.Message, error) {
|
||||
}
|
||||
|
||||
type WireGuardConfig struct {
|
||||
IsClient bool `json:""`
|
||||
|
||||
KernelMode *bool `json:"kernelMode"`
|
||||
SecretKey string `json:"secretKey"`
|
||||
Address []string `json:"address"`
|
||||
Peers []*WireGuardPeerConfig `json:"peers"`
|
||||
MTU int `json:"mtu"`
|
||||
NumWorkers int `json:"workers"`
|
||||
MTU int32 `json:"mtu"`
|
||||
NumWorkers int32 `json:"workers"`
|
||||
Reserved []byte `json:"reserved"`
|
||||
DomainStrategy string `json:"domainStrategy"`
|
||||
}
|
||||
@ -87,11 +90,11 @@ func (c *WireGuardConfig) Build() (proto.Message, error) {
|
||||
if c.MTU == 0 {
|
||||
config.Mtu = 1420
|
||||
} else {
|
||||
config.Mtu = int32(c.MTU)
|
||||
config.Mtu = c.MTU
|
||||
}
|
||||
// these a fallback code exists in github.com/nanoda0523/wireguard-go code,
|
||||
// these a fallback code exists in wireguard-go code,
|
||||
// we don't need to process fallback manually
|
||||
config.NumWorkers = int32(c.NumWorkers)
|
||||
config.NumWorkers = c.NumWorkers
|
||||
|
||||
if len(c.Reserved) != 0 && len(c.Reserved) != 3 {
|
||||
return nil, newError(`"reserved" should be empty or 3 bytes`)
|
||||
@ -113,22 +116,42 @@ func (c *WireGuardConfig) Build() (proto.Message, error) {
|
||||
return nil, newError("unsupported domain strategy: ", c.DomainStrategy)
|
||||
}
|
||||
|
||||
config.IsClient = c.IsClient
|
||||
if c.KernelMode != nil {
|
||||
config.KernelMode = *c.KernelMode
|
||||
if config.KernelMode && !wireguard.KernelTunSupported() {
|
||||
newError("kernel mode is not supported on your OS or permission is insufficient").AtWarning().WriteToLog()
|
||||
}
|
||||
} else {
|
||||
config.KernelMode = wireguard.KernelTunSupported()
|
||||
if config.KernelMode {
|
||||
newError("kernel mode is enabled as it's supported and permission is sufficient").AtDebug().WriteToLog()
|
||||
}
|
||||
}
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
func parseWireGuardKey(str string) (string, error) {
|
||||
if len(str) != 64 {
|
||||
// may in base64 form
|
||||
dat, err := base64.StdEncoding.DecodeString(str)
|
||||
if err != nil {
|
||||
return "", err
|
||||
var err error
|
||||
|
||||
if len(str)%2 == 0 {
|
||||
_, err = hex.DecodeString(str)
|
||||
if err == nil {
|
||||
return str, nil
|
||||
}
|
||||
if len(dat) != 32 {
|
||||
return "", newError("key should be 32 bytes: " + str)
|
||||
}
|
||||
return hex.EncodeToString(dat), err
|
||||
} else {
|
||||
// already hex form
|
||||
return str, nil
|
||||
}
|
||||
|
||||
var dat []byte
|
||||
str = strings.TrimSuffix(str, "=")
|
||||
if strings.ContainsRune(str, '+') || strings.ContainsRune(str, '/') {
|
||||
dat, err = base64.RawStdEncoding.DecodeString(str)
|
||||
} else {
|
||||
dat, err = base64.RawURLEncoding.DecodeString(str)
|
||||
}
|
||||
if err == nil {
|
||||
return hex.EncodeToString(dat), nil
|
||||
}
|
||||
|
||||
return "", newError("failed to deserialize key").Base(err)
|
||||
}
|
||||
|
@ -7,7 +7,7 @@ import (
|
||||
"github.com/xtls/xray-core/proxy/wireguard"
|
||||
)
|
||||
|
||||
func TestWireGuardOutbound(t *testing.T) {
|
||||
func TestWireGuardConfig(t *testing.T) {
|
||||
creator := func() Buildable {
|
||||
return new(WireGuardConfig)
|
||||
}
|
||||
@ -25,7 +25,8 @@ func TestWireGuardOutbound(t *testing.T) {
|
||||
],
|
||||
"mtu": 1300,
|
||||
"workers": 2,
|
||||
"domainStrategy": "ForceIPv6v4"
|
||||
"domainStrategy": "ForceIPv6v4",
|
||||
"kernelMode": false
|
||||
}`,
|
||||
Parser: loadJSON(creator),
|
||||
Output: &wireguard.DeviceConfig{
|
||||
@ -35,16 +36,16 @@ func TestWireGuardOutbound(t *testing.T) {
|
||||
Peers: []*wireguard.PeerConfig{
|
||||
{
|
||||
// also can read from hex form directly
|
||||
PublicKey: "6e65ce0be17517110c17d77288ad87e7fd5252dcc7d09b95a39d61db03df832a",
|
||||
PreSharedKey: "0000000000000000000000000000000000000000000000000000000000000000",
|
||||
Endpoint: "127.0.0.1:1234",
|
||||
KeepAlive: 0,
|
||||
AllowedIps: []string{"0.0.0.0/0", "::0/0"},
|
||||
PublicKey: "6e65ce0be17517110c17d77288ad87e7fd5252dcc7d09b95a39d61db03df832a",
|
||||
Endpoint: "127.0.0.1:1234",
|
||||
KeepAlive: 0,
|
||||
AllowedIps: []string{"0.0.0.0/0", "::0/0"},
|
||||
},
|
||||
},
|
||||
Mtu: 1300,
|
||||
NumWorkers: 2,
|
||||
DomainStrategy: wireguard.DeviceConfig_FORCE_IP64,
|
||||
KernelMode: false,
|
||||
},
|
||||
},
|
||||
})
|
||||
|
@ -24,6 +24,7 @@ var (
|
||||
"vless": func() interface{} { return new(VLessInboundConfig) },
|
||||
"vmess": func() interface{} { return new(VMessInboundConfig) },
|
||||
"trojan": func() interface{} { return new(TrojanServerConfig) },
|
||||
"wireguard": func() interface{} { return &WireGuardConfig{IsClient: false} },
|
||||
}, "protocol", "settings")
|
||||
|
||||
outboundConfigLoader = NewJSONConfigLoader(ConfigCreatorCache{
|
||||
@ -37,7 +38,7 @@ var (
|
||||
"vmess": func() interface{} { return new(VMessOutboundConfig) },
|
||||
"trojan": func() interface{} { return new(TrojanClientConfig) },
|
||||
"dns": func() interface{} { return new(DNSOutboundConfig) },
|
||||
"wireguard": func() interface{} { return new(WireGuardConfig) },
|
||||
"wireguard": func() interface{} { return &WireGuardConfig{IsClient: true} },
|
||||
}, "protocol", "settings")
|
||||
|
||||
ctllog = log.New(os.Stderr, "xctl> ", 0)
|
||||
|
@ -10,7 +10,7 @@ import (
|
||||
)
|
||||
|
||||
var cmdX25519 = &base.Command{
|
||||
UsageLine: `{{.Exec}} x25519 [-i "private key (base64.RawURLEncoding)"]`,
|
||||
UsageLine: `{{.Exec}} x25519 [-i "private key (base64.RawURLEncoding)"] [--std-encoding]`,
|
||||
Short: `Generate key pair for x25519 key exchange`,
|
||||
Long: `
|
||||
Generate key pair for x25519 key exchange.
|
||||
@ -18,6 +18,7 @@ Generate key pair for x25519 key exchange.
|
||||
Random: {{.Exec}} x25519
|
||||
|
||||
From private key: {{.Exec}} x25519 -i "private key (base64.RawURLEncoding)"
|
||||
For Std Encoding: {{.Exec}} x25519 --std-encoding
|
||||
`,
|
||||
}
|
||||
|
||||
@ -26,12 +27,14 @@ func init() {
|
||||
}
|
||||
|
||||
var input_base64 = cmdX25519.Flag.String("i", "", "")
|
||||
var input_stdEncoding = cmdX25519.Flag.Bool("std-encoding", false, "")
|
||||
|
||||
func executeX25519(cmd *base.Command, args []string) {
|
||||
var output string
|
||||
var err error
|
||||
var privateKey []byte
|
||||
var publicKey []byte
|
||||
var encoding *base64.Encoding
|
||||
if len(*input_base64) > 0 {
|
||||
privateKey, err = base64.RawURLEncoding.DecodeString(*input_base64)
|
||||
if err != nil {
|
||||
@ -63,9 +66,15 @@ func executeX25519(cmd *base.Command, args []string) {
|
||||
goto out
|
||||
}
|
||||
|
||||
if *input_stdEncoding {
|
||||
encoding = base64.StdEncoding
|
||||
} else {
|
||||
encoding = base64.RawURLEncoding
|
||||
}
|
||||
|
||||
output = fmt.Sprintf("Private key: %v\nPublic key: %v",
|
||||
base64.RawURLEncoding.EncodeToString(privateKey),
|
||||
base64.RawURLEncoding.EncodeToString(publicKey))
|
||||
encoding.EncodeToString(privateKey),
|
||||
encoding.EncodeToString(publicKey))
|
||||
out:
|
||||
fmt.Println(output)
|
||||
}
|
||||
|
@ -27,48 +27,45 @@ type netReadInfo struct {
|
||||
err error
|
||||
}
|
||||
|
||||
type netBindClient struct {
|
||||
workers int
|
||||
dialer internet.Dialer
|
||||
// reduce duplicated code
|
||||
type netBind struct {
|
||||
dns dns.Client
|
||||
dnsOption dns.IPOption
|
||||
reserved []byte
|
||||
|
||||
workers int
|
||||
readQueue chan *netReadInfo
|
||||
}
|
||||
|
||||
func (bind *netBindClient) ParseEndpoint(s string) (conn.Endpoint, error) {
|
||||
ipStr, port, _, err := splitAddrPort(s)
|
||||
// SetMark implements conn.Bind
|
||||
func (bind *netBind) SetMark(mark uint32) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// ParseEndpoint implements conn.Bind
|
||||
func (n *netBind) ParseEndpoint(s string) (conn.Endpoint, error) {
|
||||
ipStr, port, err := net.SplitHostPort(s)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
portNum, err := strconv.Atoi(port)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var addr net.IP
|
||||
if IsDomainName(ipStr) {
|
||||
ips, err := bind.dns.LookupIP(ipStr, bind.dnsOption)
|
||||
addr := xnet.ParseAddress(ipStr)
|
||||
if addr.Family() == xnet.AddressFamilyDomain {
|
||||
ips, err := n.dns.LookupIP(addr.Domain(), n.dnsOption)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if len(ips) == 0 {
|
||||
return nil, dns.ErrEmptyResponse
|
||||
}
|
||||
addr = ips[0]
|
||||
} else {
|
||||
addr = net.ParseIP(ipStr)
|
||||
}
|
||||
if addr == nil {
|
||||
return nil, errors.New("failed to parse ip: " + ipStr)
|
||||
}
|
||||
|
||||
var ip xnet.Address
|
||||
if p4 := addr.To4(); len(p4) == net.IPv4len {
|
||||
ip = xnet.IPAddress(p4[:])
|
||||
} else {
|
||||
ip = xnet.IPAddress(addr[:])
|
||||
addr = xnet.IPAddress(ips[0])
|
||||
}
|
||||
|
||||
dst := xnet.Destination{
|
||||
Address: ip,
|
||||
Port: xnet.Port(port),
|
||||
Address: addr,
|
||||
Port: xnet.Port(portNum),
|
||||
Network: xnet.Network_UDP,
|
||||
}
|
||||
|
||||
@ -77,7 +74,13 @@ func (bind *netBindClient) ParseEndpoint(s string) (conn.Endpoint, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (bind *netBindClient) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) {
|
||||
// BatchSize implements conn.Bind
|
||||
func (bind *netBind) BatchSize() int {
|
||||
return 1
|
||||
}
|
||||
|
||||
// Open implements conn.Bind
|
||||
func (bind *netBind) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) {
|
||||
bind.readQueue = make(chan *netReadInfo)
|
||||
|
||||
fun := func(bufs [][]byte, sizes []int, eps []conn.Endpoint) (n int, err error) {
|
||||
@ -109,13 +112,21 @@ func (bind *netBindClient) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error
|
||||
return arr, uint16(uport), nil
|
||||
}
|
||||
|
||||
func (bind *netBindClient) Close() error {
|
||||
// Close implements conn.Bind
|
||||
func (bind *netBind) Close() error {
|
||||
if bind.readQueue != nil {
|
||||
close(bind.readQueue)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type netBindClient struct {
|
||||
netBind
|
||||
|
||||
dialer internet.Dialer
|
||||
reserved []byte
|
||||
}
|
||||
|
||||
func (bind *netBindClient) connectTo(endpoint *netEndpoint) error {
|
||||
c, err := bind.dialer.Dial(context.Background(), endpoint.dst)
|
||||
if err != nil {
|
||||
@ -177,12 +188,29 @@ func (bind *netBindClient) Send(buff [][]byte, endpoint conn.Endpoint) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (bind *netBindClient) SetMark(mark uint32) error {
|
||||
return nil
|
||||
type netBindServer struct {
|
||||
netBind
|
||||
}
|
||||
|
||||
func (bind *netBindClient) BatchSize() int {
|
||||
return 1
|
||||
func (bind *netBindServer) Send(buff [][]byte, endpoint conn.Endpoint) error {
|
||||
var err error
|
||||
|
||||
nend, ok := endpoint.(*netEndpoint)
|
||||
if !ok {
|
||||
return conn.ErrWrongEndpointType
|
||||
}
|
||||
|
||||
if nend.conn == nil {
|
||||
return newError("connection not open yet")
|
||||
}
|
||||
|
||||
for _, buff := range buff {
|
||||
if _, err = nend.conn.Write(buff); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
type netEndpoint struct {
|
||||
@ -193,7 +221,7 @@ type netEndpoint struct {
|
||||
func (netEndpoint) ClearSrc() {}
|
||||
|
||||
func (e netEndpoint) DstIP() netip.Addr {
|
||||
return toNetIpAddr(e.dst.Address)
|
||||
return netip.Addr{}
|
||||
}
|
||||
|
||||
func (e netEndpoint) SrcIP() netip.Addr {
|
||||
@ -232,83 +260,3 @@ func toNetIpAddr(addr xnet.Address) netip.Addr {
|
||||
return netip.AddrFrom16(arr)
|
||||
}
|
||||
}
|
||||
|
||||
func stringsLastIndexByte(s string, b byte) int {
|
||||
for i := len(s) - 1; i >= 0; i-- {
|
||||
if s[i] == b {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
func splitAddrPort(s string) (ip string, port uint16, v6 bool, err error) {
|
||||
i := stringsLastIndexByte(s, ':')
|
||||
if i == -1 {
|
||||
return "", 0, false, errors.New("not an ip:port")
|
||||
}
|
||||
|
||||
ip = s[:i]
|
||||
portStr := s[i+1:]
|
||||
if len(ip) == 0 {
|
||||
return "", 0, false, errors.New("no IP")
|
||||
}
|
||||
if len(portStr) == 0 {
|
||||
return "", 0, false, errors.New("no port")
|
||||
}
|
||||
port64, err := strconv.ParseUint(portStr, 10, 16)
|
||||
if err != nil {
|
||||
return "", 0, false, errors.New("invalid port " + strconv.Quote(portStr) + " parsing " + strconv.Quote(s))
|
||||
}
|
||||
port = uint16(port64)
|
||||
if ip[0] == '[' {
|
||||
if len(ip) < 2 || ip[len(ip)-1] != ']' {
|
||||
return "", 0, false, errors.New("missing ]")
|
||||
}
|
||||
ip = ip[1 : len(ip)-1]
|
||||
v6 = true
|
||||
}
|
||||
|
||||
return ip, port, v6, nil
|
||||
}
|
||||
|
||||
func IsDomainName(s string) bool {
|
||||
l := len(s)
|
||||
if l == 0 || l > 254 || l == 254 && s[l-1] != '.' {
|
||||
return false
|
||||
}
|
||||
last := byte('.')
|
||||
nonNumeric := false
|
||||
partlen := 0
|
||||
for i := 0; i < len(s); i++ {
|
||||
c := s[i]
|
||||
switch {
|
||||
default:
|
||||
return false
|
||||
case 'a' <= c && c <= 'z' || 'A' <= c && c <= 'Z' || c == '_':
|
||||
nonNumeric = true
|
||||
partlen++
|
||||
case '0' <= c && c <= '9':
|
||||
partlen++
|
||||
case c == '-':
|
||||
if last == '.' {
|
||||
return false
|
||||
}
|
||||
partlen++
|
||||
nonNumeric = true
|
||||
case c == '.':
|
||||
if last == '.' || last == '-' {
|
||||
return false
|
||||
}
|
||||
if partlen > 63 || partlen == 0 {
|
||||
return false
|
||||
}
|
||||
partlen = 0
|
||||
}
|
||||
last = c
|
||||
}
|
||||
if last == '-' || partlen > 63 {
|
||||
return false
|
||||
}
|
||||
return nonNumeric
|
||||
}
|
||||
|
255
proxy/wireguard/client.go
Normal file
255
proxy/wireguard/client.go
Normal file
@ -0,0 +1,255 @@
|
||||
/*
|
||||
|
||||
Some of codes are copied from https://github.com/octeep/wireproxy, license below.
|
||||
|
||||
Copyright (c) 2022 Wind T.F. Wong <octeep@pm.me>
|
||||
|
||||
Permission to use, copy, modify, and distribute this software for any
|
||||
purpose with or without fee is hereby granted, provided that the above
|
||||
copyright notice and this permission notice appear in all copies.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
|
||||
WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
|
||||
MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
|
||||
ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
|
||||
WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
|
||||
ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
|
||||
OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||
|
||||
*/
|
||||
|
||||
package wireguard
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/netip"
|
||||
"sync"
|
||||
|
||||
"github.com/xtls/xray-core/common"
|
||||
"github.com/xtls/xray-core/common/buf"
|
||||
"github.com/xtls/xray-core/common/dice"
|
||||
"github.com/xtls/xray-core/common/log"
|
||||
"github.com/xtls/xray-core/common/net"
|
||||
"github.com/xtls/xray-core/common/protocol"
|
||||
"github.com/xtls/xray-core/common/session"
|
||||
"github.com/xtls/xray-core/common/signal"
|
||||
"github.com/xtls/xray-core/common/task"
|
||||
"github.com/xtls/xray-core/core"
|
||||
"github.com/xtls/xray-core/features/dns"
|
||||
"github.com/xtls/xray-core/features/policy"
|
||||
"github.com/xtls/xray-core/transport"
|
||||
"github.com/xtls/xray-core/transport/internet"
|
||||
)
|
||||
|
||||
// Handler is an outbound connection that silently swallow the entire payload.
|
||||
type Handler struct {
|
||||
conf *DeviceConfig
|
||||
net Tunnel
|
||||
bind *netBindClient
|
||||
policyManager policy.Manager
|
||||
dns dns.Client
|
||||
// cached configuration
|
||||
ipc string
|
||||
endpoints []netip.Addr
|
||||
hasIPv4, hasIPv6 bool
|
||||
wgLock sync.Mutex
|
||||
}
|
||||
|
||||
// New creates a new wireguard handler.
|
||||
func New(ctx context.Context, conf *DeviceConfig) (*Handler, error) {
|
||||
v := core.MustFromContext(ctx)
|
||||
|
||||
endpoints, hasIPv4, hasIPv6, err := parseEndpoints(conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
d := v.GetFeature(dns.ClientType()).(dns.Client)
|
||||
return &Handler{
|
||||
conf: conf,
|
||||
policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager),
|
||||
dns: d,
|
||||
ipc: createIPCRequest(conf),
|
||||
endpoints: endpoints,
|
||||
hasIPv4: hasIPv4,
|
||||
hasIPv6: hasIPv6,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (h *Handler) processWireGuard(dialer internet.Dialer) (err error) {
|
||||
h.wgLock.Lock()
|
||||
defer h.wgLock.Unlock()
|
||||
|
||||
if h.bind != nil && h.bind.dialer == dialer && h.net != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Record(&log.GeneralMessage{
|
||||
Severity: log.Severity_Info,
|
||||
Content: "switching dialer",
|
||||
})
|
||||
|
||||
if h.net != nil {
|
||||
_ = h.net.Close()
|
||||
h.net = nil
|
||||
}
|
||||
if h.bind != nil {
|
||||
_ = h.bind.Close()
|
||||
h.bind = nil
|
||||
}
|
||||
|
||||
// bind := conn.NewStdNetBind() // TODO: conn.Bind wrapper for dialer
|
||||
bind := &netBindClient{
|
||||
netBind: netBind{
|
||||
dns: h.dns,
|
||||
dnsOption: dns.IPOption{
|
||||
IPv4Enable: h.hasIPv4,
|
||||
IPv6Enable: h.hasIPv6,
|
||||
},
|
||||
workers: int(h.conf.NumWorkers),
|
||||
},
|
||||
dialer: dialer,
|
||||
reserved: h.conf.Reserved,
|
||||
}
|
||||
defer func() {
|
||||
if err != nil {
|
||||
_ = bind.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
h.net, err = h.makeVirtualTun(bind)
|
||||
if err != nil {
|
||||
return newError("failed to create virtual tun interface").Base(err)
|
||||
}
|
||||
h.bind = bind
|
||||
return nil
|
||||
}
|
||||
|
||||
// Process implements OutboundHandler.Dispatch().
|
||||
func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer internet.Dialer) error {
|
||||
outbound := session.OutboundFromContext(ctx)
|
||||
if outbound == nil || !outbound.Target.IsValid() {
|
||||
return newError("target not specified")
|
||||
}
|
||||
outbound.Name = "wireguard"
|
||||
inbound := session.InboundFromContext(ctx)
|
||||
if inbound != nil {
|
||||
inbound.SetCanSpliceCopy(3)
|
||||
}
|
||||
|
||||
if err := h.processWireGuard(dialer); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Destination of the inner request.
|
||||
destination := outbound.Target
|
||||
command := protocol.RequestCommandTCP
|
||||
if destination.Network == net.Network_UDP {
|
||||
command = protocol.RequestCommandUDP
|
||||
}
|
||||
|
||||
// resolve dns
|
||||
addr := destination.Address
|
||||
if addr.Family().IsDomain() {
|
||||
ips, err := h.dns.LookupIP(addr.Domain(), dns.IPOption{
|
||||
IPv4Enable: h.hasIPv4 && h.conf.preferIP4(),
|
||||
IPv6Enable: h.hasIPv6 && h.conf.preferIP6(),
|
||||
})
|
||||
{ // Resolve fallback
|
||||
if (len(ips) == 0 || err != nil) && h.conf.hasFallback() {
|
||||
ips, err = h.dns.LookupIP(addr.Domain(), dns.IPOption{
|
||||
IPv4Enable: h.hasIPv4 && h.conf.fallbackIP4(),
|
||||
IPv6Enable: h.hasIPv6 && h.conf.fallbackIP6(),
|
||||
})
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return newError("failed to lookup DNS").Base(err)
|
||||
} else if len(ips) == 0 {
|
||||
return dns.ErrEmptyResponse
|
||||
}
|
||||
addr = net.IPAddress(ips[dice.Roll(len(ips))])
|
||||
}
|
||||
|
||||
var newCtx context.Context
|
||||
var newCancel context.CancelFunc
|
||||
if session.TimeoutOnlyFromContext(ctx) {
|
||||
newCtx, newCancel = context.WithCancel(context.Background())
|
||||
}
|
||||
|
||||
p := h.policyManager.ForLevel(0)
|
||||
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
timer := signal.CancelAfterInactivity(ctx, func() {
|
||||
cancel()
|
||||
if newCancel != nil {
|
||||
newCancel()
|
||||
}
|
||||
}, p.Timeouts.ConnectionIdle)
|
||||
addrPort := netip.AddrPortFrom(toNetIpAddr(addr), destination.Port.Value())
|
||||
|
||||
var requestFunc func() error
|
||||
var responseFunc func() error
|
||||
|
||||
if command == protocol.RequestCommandTCP {
|
||||
conn, err := h.net.DialContextTCPAddrPort(ctx, addrPort)
|
||||
if err != nil {
|
||||
return newError("failed to create TCP connection").Base(err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
requestFunc = func() error {
|
||||
defer timer.SetTimeout(p.Timeouts.DownlinkOnly)
|
||||
return buf.Copy(link.Reader, buf.NewWriter(conn), buf.UpdateActivity(timer))
|
||||
}
|
||||
responseFunc = func() error {
|
||||
defer timer.SetTimeout(p.Timeouts.UplinkOnly)
|
||||
return buf.Copy(buf.NewReader(conn), link.Writer, buf.UpdateActivity(timer))
|
||||
}
|
||||
} else if command == protocol.RequestCommandUDP {
|
||||
conn, err := h.net.DialUDPAddrPort(netip.AddrPort{}, addrPort)
|
||||
if err != nil {
|
||||
return newError("failed to create UDP connection").Base(err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
requestFunc = func() error {
|
||||
defer timer.SetTimeout(p.Timeouts.DownlinkOnly)
|
||||
return buf.Copy(link.Reader, buf.NewWriter(conn), buf.UpdateActivity(timer))
|
||||
}
|
||||
responseFunc = func() error {
|
||||
defer timer.SetTimeout(p.Timeouts.UplinkOnly)
|
||||
return buf.Copy(buf.NewReader(conn), link.Writer, buf.UpdateActivity(timer))
|
||||
}
|
||||
}
|
||||
|
||||
if newCtx != nil {
|
||||
ctx = newCtx
|
||||
}
|
||||
|
||||
responseDonePost := task.OnSuccess(responseFunc, task.Close(link.Writer))
|
||||
if err := task.Run(ctx, requestFunc, responseDonePost); err != nil {
|
||||
common.Interrupt(link.Reader)
|
||||
common.Interrupt(link.Writer)
|
||||
return newError("connection ends").Base(err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// creates a tun interface on netstack given a configuration
|
||||
func (h *Handler) makeVirtualTun(bind *netBindClient) (Tunnel, error) {
|
||||
t, err := h.conf.createTun()(h.endpoints, int(h.conf.Mtu), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
bind.dnsOption.IPv4Enable = h.hasIPv4
|
||||
bind.dnsOption.IPv6Enable = h.hasIPv6
|
||||
|
||||
if err = t.BuildDevice(h.ipc, bind); err != nil {
|
||||
_ = t.Close()
|
||||
return nil, err
|
||||
}
|
||||
return t, nil
|
||||
}
|
@ -23,3 +23,10 @@ func (c *DeviceConfig) fallbackIP4() bool {
|
||||
func (c *DeviceConfig) fallbackIP6() bool {
|
||||
return c.DomainStrategy == DeviceConfig_FORCE_IP46
|
||||
}
|
||||
|
||||
func (c *DeviceConfig) createTun() tunCreator {
|
||||
if c.KernelMode {
|
||||
return createKernelTun
|
||||
}
|
||||
return createGVisorTun
|
||||
}
|
||||
|
@ -1,7 +1,7 @@
|
||||
// Code generated by protoc-gen-go. DO NOT EDIT.
|
||||
// versions:
|
||||
// protoc-gen-go v1.31.0
|
||||
// protoc v4.23.1
|
||||
// protoc-gen-go v1.28.1
|
||||
// protoc v4.25.0
|
||||
// source: proxy/wireguard/config.proto
|
||||
|
||||
package wireguard
|
||||
@ -83,7 +83,7 @@ type PeerConfig struct {
|
||||
PublicKey string `protobuf:"bytes,1,opt,name=public_key,json=publicKey,proto3" json:"public_key,omitempty"`
|
||||
PreSharedKey string `protobuf:"bytes,2,opt,name=pre_shared_key,json=preSharedKey,proto3" json:"pre_shared_key,omitempty"`
|
||||
Endpoint string `protobuf:"bytes,3,opt,name=endpoint,proto3" json:"endpoint,omitempty"`
|
||||
KeepAlive int32 `protobuf:"varint,4,opt,name=keep_alive,json=keepAlive,proto3" json:"keep_alive,omitempty"`
|
||||
KeepAlive uint32 `protobuf:"varint,4,opt,name=keep_alive,json=keepAlive,proto3" json:"keep_alive,omitempty"`
|
||||
AllowedIps []string `protobuf:"bytes,5,rep,name=allowed_ips,json=allowedIps,proto3" json:"allowed_ips,omitempty"`
|
||||
}
|
||||
|
||||
@ -140,7 +140,7 @@ func (x *PeerConfig) GetEndpoint() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *PeerConfig) GetKeepAlive() int32 {
|
||||
func (x *PeerConfig) GetKeepAlive() uint32 {
|
||||
if x != nil {
|
||||
return x.KeepAlive
|
||||
}
|
||||
@ -166,6 +166,8 @@ type DeviceConfig struct {
|
||||
NumWorkers int32 `protobuf:"varint,5,opt,name=num_workers,json=numWorkers,proto3" json:"num_workers,omitempty"`
|
||||
Reserved []byte `protobuf:"bytes,6,opt,name=reserved,proto3" json:"reserved,omitempty"`
|
||||
DomainStrategy DeviceConfig_DomainStrategy `protobuf:"varint,7,opt,name=domain_strategy,json=domainStrategy,proto3,enum=xray.proxy.wireguard.DeviceConfig_DomainStrategy" json:"domain_strategy,omitempty"`
|
||||
IsClient bool `protobuf:"varint,8,opt,name=is_client,json=isClient,proto3" json:"is_client,omitempty"`
|
||||
KernelMode bool `protobuf:"varint,9,opt,name=kernel_mode,json=kernelMode,proto3" json:"kernel_mode,omitempty"`
|
||||
}
|
||||
|
||||
func (x *DeviceConfig) Reset() {
|
||||
@ -249,6 +251,20 @@ func (x *DeviceConfig) GetDomainStrategy() DeviceConfig_DomainStrategy {
|
||||
return DeviceConfig_FORCE_IP
|
||||
}
|
||||
|
||||
func (x *DeviceConfig) GetIsClient() bool {
|
||||
if x != nil {
|
||||
return x.IsClient
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (x *DeviceConfig) GetKernelMode() bool {
|
||||
if x != nil {
|
||||
return x.KernelMode
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
var File_proxy_wireguard_config_proto protoreflect.FileDescriptor
|
||||
|
||||
var file_proxy_wireguard_config_proto_rawDesc = []byte{
|
||||
@ -263,10 +279,10 @@ var file_proxy_wireguard_config_proto_rawDesc = []byte{
|
||||
0x68, 0x61, 0x72, 0x65, 0x64, 0x4b, 0x65, 0x79, 0x12, 0x1a, 0x0a, 0x08, 0x65, 0x6e, 0x64, 0x70,
|
||||
0x6f, 0x69, 0x6e, 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x65, 0x6e, 0x64, 0x70,
|
||||
0x6f, 0x69, 0x6e, 0x74, 0x12, 0x1d, 0x0a, 0x0a, 0x6b, 0x65, 0x65, 0x70, 0x5f, 0x61, 0x6c, 0x69,
|
||||
0x76, 0x65, 0x18, 0x04, 0x20, 0x01, 0x28, 0x05, 0x52, 0x09, 0x6b, 0x65, 0x65, 0x70, 0x41, 0x6c,
|
||||
0x76, 0x65, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x09, 0x6b, 0x65, 0x65, 0x70, 0x41, 0x6c,
|
||||
0x69, 0x76, 0x65, 0x12, 0x1f, 0x0a, 0x0b, 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x65, 0x64, 0x5f, 0x69,
|
||||
0x70, 0x73, 0x18, 0x05, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0a, 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x65,
|
||||
0x64, 0x49, 0x70, 0x73, 0x22, 0x8a, 0x03, 0x0a, 0x0c, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x43,
|
||||
0x64, 0x49, 0x70, 0x73, 0x22, 0xc8, 0x03, 0x0a, 0x0c, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x43,
|
||||
0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x1d, 0x0a, 0x0a, 0x73, 0x65, 0x63, 0x72, 0x65, 0x74, 0x5f,
|
||||
0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x73, 0x65, 0x63, 0x72, 0x65,
|
||||
0x74, 0x4b, 0x65, 0x79, 0x12, 0x1a, 0x0a, 0x08, 0x65, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74,
|
||||
@ -285,19 +301,23 @@ var file_proxy_wireguard_config_proto_rawDesc = []byte{
|
||||
0x72, 0x65, 0x67, 0x75, 0x61, 0x72, 0x64, 0x2e, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x43, 0x6f,
|
||||
0x6e, 0x66, 0x69, 0x67, 0x2e, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x53, 0x74, 0x72, 0x61, 0x74,
|
||||
0x65, 0x67, 0x79, 0x52, 0x0e, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x53, 0x74, 0x72, 0x61, 0x74,
|
||||
0x65, 0x67, 0x79, 0x22, 0x5c, 0x0a, 0x0e, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x53, 0x74, 0x72,
|
||||
0x61, 0x74, 0x65, 0x67, 0x79, 0x12, 0x0c, 0x0a, 0x08, 0x46, 0x4f, 0x52, 0x43, 0x45, 0x5f, 0x49,
|
||||
0x50, 0x10, 0x00, 0x12, 0x0d, 0x0a, 0x09, 0x46, 0x4f, 0x52, 0x43, 0x45, 0x5f, 0x49, 0x50, 0x34,
|
||||
0x10, 0x01, 0x12, 0x0d, 0x0a, 0x09, 0x46, 0x4f, 0x52, 0x43, 0x45, 0x5f, 0x49, 0x50, 0x36, 0x10,
|
||||
0x02, 0x12, 0x0e, 0x0a, 0x0a, 0x46, 0x4f, 0x52, 0x43, 0x45, 0x5f, 0x49, 0x50, 0x34, 0x36, 0x10,
|
||||
0x03, 0x12, 0x0e, 0x0a, 0x0a, 0x46, 0x4f, 0x52, 0x43, 0x45, 0x5f, 0x49, 0x50, 0x36, 0x34, 0x10,
|
||||
0x04, 0x42, 0x5e, 0x0a, 0x18, 0x63, 0x6f, 0x6d, 0x2e, 0x78, 0x72, 0x61, 0x79, 0x2e, 0x70, 0x72,
|
||||
0x6f, 0x78, 0x79, 0x2e, 0x77, 0x69, 0x72, 0x65, 0x67, 0x75, 0x61, 0x72, 0x64, 0x50, 0x01, 0x5a,
|
||||
0x29, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x78, 0x74, 0x6c, 0x73,
|
||||
0x2f, 0x78, 0x72, 0x61, 0x79, 0x2d, 0x63, 0x6f, 0x72, 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x78, 0x79,
|
||||
0x2f, 0x77, 0x69, 0x72, 0x65, 0x67, 0x75, 0x61, 0x72, 0x64, 0xaa, 0x02, 0x14, 0x58, 0x72, 0x61,
|
||||
0x79, 0x2e, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x2e, 0x57, 0x69, 0x72, 0x65, 0x47, 0x75, 0x61, 0x72,
|
||||
0x64, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
|
||||
0x65, 0x67, 0x79, 0x12, 0x1b, 0x0a, 0x09, 0x69, 0x73, 0x5f, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74,
|
||||
0x18, 0x08, 0x20, 0x01, 0x28, 0x08, 0x52, 0x08, 0x69, 0x73, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74,
|
||||
0x12, 0x1f, 0x0a, 0x0b, 0x6b, 0x65, 0x72, 0x6e, 0x65, 0x6c, 0x5f, 0x6d, 0x6f, 0x64, 0x65, 0x18,
|
||||
0x09, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x6b, 0x65, 0x72, 0x6e, 0x65, 0x6c, 0x4d, 0x6f, 0x64,
|
||||
0x65, 0x22, 0x5c, 0x0a, 0x0e, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x53, 0x74, 0x72, 0x61, 0x74,
|
||||
0x65, 0x67, 0x79, 0x12, 0x0c, 0x0a, 0x08, 0x46, 0x4f, 0x52, 0x43, 0x45, 0x5f, 0x49, 0x50, 0x10,
|
||||
0x00, 0x12, 0x0d, 0x0a, 0x09, 0x46, 0x4f, 0x52, 0x43, 0x45, 0x5f, 0x49, 0x50, 0x34, 0x10, 0x01,
|
||||
0x12, 0x0d, 0x0a, 0x09, 0x46, 0x4f, 0x52, 0x43, 0x45, 0x5f, 0x49, 0x50, 0x36, 0x10, 0x02, 0x12,
|
||||
0x0e, 0x0a, 0x0a, 0x46, 0x4f, 0x52, 0x43, 0x45, 0x5f, 0x49, 0x50, 0x34, 0x36, 0x10, 0x03, 0x12,
|
||||
0x0e, 0x0a, 0x0a, 0x46, 0x4f, 0x52, 0x43, 0x45, 0x5f, 0x49, 0x50, 0x36, 0x34, 0x10, 0x04, 0x42,
|
||||
0x5e, 0x0a, 0x18, 0x63, 0x6f, 0x6d, 0x2e, 0x78, 0x72, 0x61, 0x79, 0x2e, 0x70, 0x72, 0x6f, 0x78,
|
||||
0x79, 0x2e, 0x77, 0x69, 0x72, 0x65, 0x67, 0x75, 0x61, 0x72, 0x64, 0x50, 0x01, 0x5a, 0x29, 0x67,
|
||||
0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x78, 0x74, 0x6c, 0x73, 0x2f, 0x78,
|
||||
0x72, 0x61, 0x79, 0x2d, 0x63, 0x6f, 0x72, 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x78, 0x79, 0x2f, 0x77,
|
||||
0x69, 0x72, 0x65, 0x67, 0x75, 0x61, 0x72, 0x64, 0xaa, 0x02, 0x14, 0x58, 0x72, 0x61, 0x79, 0x2e,
|
||||
0x50, 0x72, 0x6f, 0x78, 0x79, 0x2e, 0x57, 0x69, 0x72, 0x65, 0x47, 0x75, 0x61, 0x72, 0x64, 0x62,
|
||||
0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
|
||||
}
|
||||
|
||||
var (
|
||||
|
@ -7,26 +7,28 @@ option java_package = "com.xray.proxy.wireguard";
|
||||
option java_multiple_files = true;
|
||||
|
||||
message PeerConfig {
|
||||
string public_key = 1;
|
||||
string pre_shared_key = 2;
|
||||
string endpoint = 3;
|
||||
int32 keep_alive = 4;
|
||||
repeated string allowed_ips = 5;
|
||||
string public_key = 1;
|
||||
string pre_shared_key = 2;
|
||||
string endpoint = 3;
|
||||
uint32 keep_alive = 4;
|
||||
repeated string allowed_ips = 5;
|
||||
}
|
||||
|
||||
message DeviceConfig {
|
||||
enum DomainStrategy {
|
||||
FORCE_IP = 0;
|
||||
FORCE_IP4 = 1;
|
||||
FORCE_IP6 = 2;
|
||||
FORCE_IP46 = 3;
|
||||
FORCE_IP64 = 4;
|
||||
}
|
||||
string secret_key = 1;
|
||||
repeated string endpoint = 2;
|
||||
repeated PeerConfig peers = 3;
|
||||
int32 mtu = 4;
|
||||
int32 num_workers = 5;
|
||||
bytes reserved = 6;
|
||||
DomainStrategy domain_strategy = 7;
|
||||
enum DomainStrategy {
|
||||
FORCE_IP = 0;
|
||||
FORCE_IP4 = 1;
|
||||
FORCE_IP6 = 2;
|
||||
FORCE_IP46 = 3;
|
||||
FORCE_IP64 = 4;
|
||||
}
|
||||
string secret_key = 1;
|
||||
repeated string endpoint = 2;
|
||||
repeated PeerConfig peers = 3;
|
||||
int32 mtu = 4;
|
||||
int32 num_workers = 5;
|
||||
bytes reserved = 6;
|
||||
DomainStrategy domain_strategy = 7;
|
||||
bool is_client = 8;
|
||||
bool kernel_mode = 9;
|
||||
}
|
230
proxy/wireguard/gvisortun/tun.go
Normal file
230
proxy/wireguard/gvisortun/tun.go
Normal file
@ -0,0 +1,230 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2022 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package gvisortun
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"os"
|
||||
"syscall"
|
||||
|
||||
"golang.zx2c4.com/wireguard/tun"
|
||||
"gvisor.dev/gvisor/pkg/buffer"
|
||||
"gvisor.dev/gvisor/pkg/tcpip"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/header"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
|
||||
)
|
||||
|
||||
type netTun struct {
|
||||
ep *channel.Endpoint
|
||||
stack *stack.Stack
|
||||
events chan tun.Event
|
||||
incomingPacket chan *buffer.View
|
||||
mtu int
|
||||
hasV4, hasV6 bool
|
||||
}
|
||||
|
||||
type Net netTun
|
||||
|
||||
func CreateNetTUN(localAddresses []netip.Addr, mtu int, promiscuousMode bool) (tun.Device, *Net, *stack.Stack, error) {
|
||||
opts := stack.Options{
|
||||
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
|
||||
TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4},
|
||||
HandleLocal: !promiscuousMode,
|
||||
}
|
||||
dev := &netTun{
|
||||
ep: channel.New(1024, uint32(mtu), ""),
|
||||
stack: stack.New(opts),
|
||||
events: make(chan tun.Event, 1),
|
||||
incomingPacket: make(chan *buffer.View),
|
||||
mtu: mtu,
|
||||
}
|
||||
dev.ep.AddNotify(dev)
|
||||
tcpipErr := dev.stack.CreateNIC(1, dev.ep)
|
||||
if tcpipErr != nil {
|
||||
return nil, nil, dev.stack, fmt.Errorf("CreateNIC: %v", tcpipErr)
|
||||
}
|
||||
for _, ip := range localAddresses {
|
||||
var protoNumber tcpip.NetworkProtocolNumber
|
||||
if ip.Is4() {
|
||||
protoNumber = ipv4.ProtocolNumber
|
||||
} else if ip.Is6() {
|
||||
protoNumber = ipv6.ProtocolNumber
|
||||
}
|
||||
protoAddr := tcpip.ProtocolAddress{
|
||||
Protocol: protoNumber,
|
||||
AddressWithPrefix: tcpip.AddrFromSlice(ip.AsSlice()).WithPrefix(),
|
||||
}
|
||||
tcpipErr := dev.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{})
|
||||
if tcpipErr != nil {
|
||||
return nil, nil, dev.stack, fmt.Errorf("AddProtocolAddress(%v): %v", ip, tcpipErr)
|
||||
}
|
||||
if ip.Is4() {
|
||||
dev.hasV4 = true
|
||||
} else if ip.Is6() {
|
||||
dev.hasV6 = true
|
||||
}
|
||||
}
|
||||
if dev.hasV4 {
|
||||
dev.stack.AddRoute(tcpip.Route{Destination: header.IPv4EmptySubnet, NIC: 1})
|
||||
}
|
||||
if dev.hasV6 {
|
||||
dev.stack.AddRoute(tcpip.Route{Destination: header.IPv6EmptySubnet, NIC: 1})
|
||||
}
|
||||
if promiscuousMode {
|
||||
// enable promiscuous mode to handle all packets processed by netstack
|
||||
dev.stack.SetPromiscuousMode(1, true)
|
||||
dev.stack.SetSpoofing(1, true)
|
||||
}
|
||||
|
||||
opt := tcpip.CongestionControlOption("cubic")
|
||||
if err := dev.stack.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
|
||||
return nil, nil, dev.stack, fmt.Errorf("SetTransportProtocolOption(%d, &%T(%s)): %s", tcp.ProtocolNumber, opt, opt, err)
|
||||
}
|
||||
|
||||
dev.events <- tun.EventUp
|
||||
return dev, (*Net)(dev), dev.stack, nil
|
||||
}
|
||||
|
||||
// BatchSize implements tun.Device
|
||||
func (tun *netTun) BatchSize() int {
|
||||
return 1
|
||||
}
|
||||
|
||||
// Name implements tun.Device
|
||||
func (tun *netTun) Name() (string, error) {
|
||||
return "go", nil
|
||||
}
|
||||
|
||||
// File implements tun.Device
|
||||
func (tun *netTun) File() *os.File {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Events implements tun.Device
|
||||
func (tun *netTun) Events() <-chan tun.Event {
|
||||
return tun.events
|
||||
}
|
||||
|
||||
// Read implements tun.Device
|
||||
|
||||
func (tun *netTun) Read(buf [][]byte, sizes []int, offset int) (int, error) {
|
||||
view, ok := <-tun.incomingPacket
|
||||
if !ok {
|
||||
return 0, os.ErrClosed
|
||||
}
|
||||
|
||||
n, err := view.Read(buf[0][offset:])
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
sizes[0] = n
|
||||
return 1, nil
|
||||
}
|
||||
|
||||
// Write implements tun.Device
|
||||
func (tun *netTun) Write(buf [][]byte, offset int) (int, error) {
|
||||
for _, buf := range buf {
|
||||
packet := buf[offset:]
|
||||
if len(packet) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{Payload: buffer.MakeWithData(packet)})
|
||||
switch packet[0] >> 4 {
|
||||
case 4:
|
||||
tun.ep.InjectInbound(header.IPv4ProtocolNumber, pkb)
|
||||
case 6:
|
||||
tun.ep.InjectInbound(header.IPv6ProtocolNumber, pkb)
|
||||
default:
|
||||
return 0, syscall.EAFNOSUPPORT
|
||||
}
|
||||
}
|
||||
return len(buf), nil
|
||||
}
|
||||
|
||||
// WriteNotify implements channel.Notification
|
||||
func (tun *netTun) WriteNotify() {
|
||||
pkt := tun.ep.Read()
|
||||
if pkt.IsNil() {
|
||||
return
|
||||
}
|
||||
|
||||
view := pkt.ToView()
|
||||
pkt.DecRef()
|
||||
|
||||
tun.incomingPacket <- view
|
||||
}
|
||||
|
||||
// Flush implements tun.Device
|
||||
func (tun *netTun) Flush() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close implements tun.Device
|
||||
func (tun *netTun) Close() error {
|
||||
tun.stack.RemoveNIC(1)
|
||||
|
||||
if tun.events != nil {
|
||||
close(tun.events)
|
||||
}
|
||||
|
||||
tun.ep.Close()
|
||||
|
||||
if tun.incomingPacket != nil {
|
||||
close(tun.incomingPacket)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// MTU implements tun.Device
|
||||
func (tun *netTun) MTU() (int, error) {
|
||||
return tun.mtu, nil
|
||||
}
|
||||
|
||||
func convertToFullAddr(endpoint netip.AddrPort) (tcpip.FullAddress, tcpip.NetworkProtocolNumber) {
|
||||
var protoNumber tcpip.NetworkProtocolNumber
|
||||
if endpoint.Addr().Is4() {
|
||||
protoNumber = ipv4.ProtocolNumber
|
||||
} else {
|
||||
protoNumber = ipv6.ProtocolNumber
|
||||
}
|
||||
return tcpip.FullAddress{
|
||||
NIC: 1,
|
||||
Addr: tcpip.AddrFromSlice(endpoint.Addr().AsSlice()),
|
||||
Port: endpoint.Port(),
|
||||
}, protoNumber
|
||||
}
|
||||
|
||||
func (net *Net) DialContextTCPAddrPort(ctx context.Context, addr netip.AddrPort) (*gonet.TCPConn, error) {
|
||||
fa, pn := convertToFullAddr(addr)
|
||||
return gonet.DialContextTCP(ctx, net.stack, fa, pn)
|
||||
}
|
||||
|
||||
func (net *Net) DialUDPAddrPort(laddr, raddr netip.AddrPort) (*gonet.UDPConn, error) {
|
||||
var lfa, rfa *tcpip.FullAddress
|
||||
var pn tcpip.NetworkProtocolNumber
|
||||
if laddr.IsValid() || laddr.Port() > 0 {
|
||||
var addr tcpip.FullAddress
|
||||
addr, pn = convertToFullAddr(laddr)
|
||||
lfa = &addr
|
||||
}
|
||||
if raddr.IsValid() || raddr.Port() > 0 {
|
||||
var addr tcpip.FullAddress
|
||||
addr, pn = convertToFullAddr(raddr)
|
||||
rfa = &addr
|
||||
}
|
||||
return gonet.DialUDP(net.stack, lfa, rfa, pn)
|
||||
}
|
181
proxy/wireguard/server.go
Normal file
181
proxy/wireguard/server.go
Normal file
@ -0,0 +1,181 @@
|
||||
package wireguard
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
|
||||
"github.com/xtls/xray-core/common"
|
||||
"github.com/xtls/xray-core/common/buf"
|
||||
"github.com/xtls/xray-core/common/log"
|
||||
"github.com/xtls/xray-core/common/net"
|
||||
"github.com/xtls/xray-core/common/session"
|
||||
"github.com/xtls/xray-core/common/signal"
|
||||
"github.com/xtls/xray-core/common/task"
|
||||
"github.com/xtls/xray-core/core"
|
||||
"github.com/xtls/xray-core/features/dns"
|
||||
"github.com/xtls/xray-core/features/policy"
|
||||
"github.com/xtls/xray-core/features/routing"
|
||||
"github.com/xtls/xray-core/transport/internet/stat"
|
||||
)
|
||||
|
||||
var nullDestination = net.TCPDestination(net.AnyIP, 0)
|
||||
|
||||
type Server struct {
|
||||
bindServer *netBindServer
|
||||
|
||||
info routingInfo
|
||||
policyManager policy.Manager
|
||||
}
|
||||
|
||||
type routingInfo struct {
|
||||
ctx context.Context
|
||||
dispatcher routing.Dispatcher
|
||||
inboundTag *session.Inbound
|
||||
outboundTag *session.Outbound
|
||||
contentTag *session.Content
|
||||
}
|
||||
|
||||
func NewServer(ctx context.Context, conf *DeviceConfig) (*Server, error) {
|
||||
v := core.MustFromContext(ctx)
|
||||
|
||||
endpoints, hasIPv4, hasIPv6, err := parseEndpoints(conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
server := &Server{
|
||||
bindServer: &netBindServer{
|
||||
netBind: netBind{
|
||||
dns: v.GetFeature(dns.ClientType()).(dns.Client),
|
||||
dnsOption: dns.IPOption{
|
||||
IPv4Enable: hasIPv4,
|
||||
IPv6Enable: hasIPv6,
|
||||
},
|
||||
},
|
||||
},
|
||||
policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager),
|
||||
}
|
||||
|
||||
tun, err := conf.createTun()(endpoints, int(conf.Mtu), server.forwardConnection)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err = tun.BuildDevice(createIPCRequest(conf), server.bindServer); err != nil {
|
||||
_ = tun.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return server, nil
|
||||
}
|
||||
|
||||
// Network implements proxy.Inbound.
|
||||
func (*Server) Network() []net.Network {
|
||||
return []net.Network{net.Network_UDP}
|
||||
}
|
||||
|
||||
// Process implements proxy.Inbound.
|
||||
func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Connection, dispatcher routing.Dispatcher) error {
|
||||
s.info = routingInfo{
|
||||
ctx: core.ToBackgroundDetachedContext(ctx),
|
||||
dispatcher: dispatcher,
|
||||
inboundTag: session.InboundFromContext(ctx),
|
||||
outboundTag: session.OutboundFromContext(ctx),
|
||||
contentTag: session.ContentFromContext(ctx),
|
||||
}
|
||||
|
||||
ep, err := s.bindServer.ParseEndpoint(conn.RemoteAddr().String())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
nep := ep.(*netEndpoint)
|
||||
nep.conn = conn
|
||||
|
||||
reader := buf.NewPacketReader(conn)
|
||||
for {
|
||||
mpayload, err := reader.ReadMultiBuffer()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, payload := range mpayload {
|
||||
v, ok := <-s.bindServer.readQueue
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
i, err := payload.Read(v.buff)
|
||||
|
||||
v.bytes = i
|
||||
v.endpoint = nep
|
||||
v.err = err
|
||||
v.waiter.Done()
|
||||
if err != nil && errors.Is(err, io.EOF) {
|
||||
nep.conn = nil
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) forwardConnection(dest net.Destination, conn net.Conn) {
|
||||
if s.info.dispatcher == nil {
|
||||
newError("unexpected: dispatcher == nil").AtError().WriteToLog()
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
ctx, cancel := context.WithCancel(core.ToBackgroundDetachedContext(s.info.ctx))
|
||||
plcy := s.policyManager.ForLevel(0)
|
||||
timer := signal.CancelAfterInactivity(ctx, cancel, plcy.Timeouts.ConnectionIdle)
|
||||
|
||||
ctx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{
|
||||
From: nullDestination,
|
||||
To: dest,
|
||||
Status: log.AccessAccepted,
|
||||
Reason: "",
|
||||
})
|
||||
|
||||
if s.info.inboundTag != nil {
|
||||
ctx = session.ContextWithInbound(ctx, s.info.inboundTag)
|
||||
}
|
||||
if s.info.outboundTag != nil {
|
||||
ctx = session.ContextWithOutbound(ctx, s.info.outboundTag)
|
||||
}
|
||||
if s.info.contentTag != nil {
|
||||
ctx = session.ContextWithContent(ctx, s.info.contentTag)
|
||||
}
|
||||
|
||||
link, err := s.info.dispatcher.Dispatch(ctx, dest)
|
||||
if err != nil {
|
||||
newError("dispatch connection").Base(err).AtError().WriteToLog()
|
||||
}
|
||||
defer cancel()
|
||||
|
||||
requestDone := func() error {
|
||||
defer timer.SetTimeout(plcy.Timeouts.DownlinkOnly)
|
||||
if err := buf.Copy(buf.NewReader(conn), link.Writer, buf.UpdateActivity(timer)); err != nil {
|
||||
return newError("failed to transport all TCP request").Base(err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
responseDone := func() error {
|
||||
defer timer.SetTimeout(plcy.Timeouts.UplinkOnly)
|
||||
if err := buf.Copy(link.Reader, buf.NewWriter(conn), buf.UpdateActivity(timer)); err != nil {
|
||||
return newError("failed to transport all TCP response").Base(err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
requestDonePost := task.OnSuccess(requestDone, task.Close(link.Writer))
|
||||
if err := task.Run(ctx, requestDonePost, responseDone); err != nil {
|
||||
common.Interrupt(link.Reader)
|
||||
common.Interrupt(link.Writer)
|
||||
newError("connection ends").Base(err).AtDebug().WriteToLog()
|
||||
return
|
||||
}
|
||||
}
|
@ -10,14 +10,26 @@ import (
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/xtls/xray-core/common/log"
|
||||
xnet "github.com/xtls/xray-core/common/net"
|
||||
"github.com/xtls/xray-core/proxy/wireguard/gvisortun"
|
||||
"gvisor.dev/gvisor/pkg/tcpip"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
|
||||
"gvisor.dev/gvisor/pkg/waiter"
|
||||
|
||||
"golang.zx2c4.com/wireguard/conn"
|
||||
"golang.zx2c4.com/wireguard/device"
|
||||
"golang.zx2c4.com/wireguard/tun"
|
||||
)
|
||||
|
||||
type tunCreator func(localAddresses []netip.Addr, mtu int, handler promiscuousModeHandler) (Tunnel, error)
|
||||
|
||||
type promiscuousModeHandler func(dest xnet.Destination, conn net.Conn)
|
||||
|
||||
type Tunnel interface {
|
||||
BuildDevice(ipc string, bind conn.Bind) error
|
||||
DialContextTCPAddrPort(ctx context.Context, addr netip.AddrPort) (net.Conn, error)
|
||||
@ -103,3 +115,91 @@ func CalculateInterfaceName(name string) (tunName string) {
|
||||
tunName = fmt.Sprintf("%s%d", tunName, tunIndex)
|
||||
return
|
||||
}
|
||||
|
||||
var _ Tunnel = (*gvisorNet)(nil)
|
||||
|
||||
type gvisorNet struct {
|
||||
tunnel
|
||||
net *gvisortun.Net
|
||||
}
|
||||
|
||||
func (g *gvisorNet) Close() error {
|
||||
return g.tunnel.Close()
|
||||
}
|
||||
|
||||
func (g *gvisorNet) DialContextTCPAddrPort(ctx context.Context, addr netip.AddrPort) (
|
||||
net.Conn, error,
|
||||
) {
|
||||
return g.net.DialContextTCPAddrPort(ctx, addr)
|
||||
}
|
||||
|
||||
func (g *gvisorNet) DialUDPAddrPort(laddr, raddr netip.AddrPort) (net.Conn, error) {
|
||||
return g.net.DialUDPAddrPort(laddr, raddr)
|
||||
}
|
||||
|
||||
func createGVisorTun(localAddresses []netip.Addr, mtu int, handler promiscuousModeHandler) (Tunnel, error) {
|
||||
out := &gvisorNet{}
|
||||
tun, n, stack, err := gvisortun.CreateNetTUN(localAddresses, mtu, handler != nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if handler != nil {
|
||||
// handler is only used for promiscuous mode
|
||||
// capture all packets and send to handler
|
||||
|
||||
tcpForwarder := tcp.NewForwarder(stack, 0, 65535, func(r *tcp.ForwarderRequest) {
|
||||
go func(r *tcp.ForwarderRequest) {
|
||||
var (
|
||||
wq waiter.Queue
|
||||
id = r.ID()
|
||||
)
|
||||
|
||||
// Perform a TCP three-way handshake.
|
||||
ep, err := r.CreateEndpoint(&wq)
|
||||
if err != nil {
|
||||
newError(err.String()).AtError().WriteToLog()
|
||||
r.Complete(true)
|
||||
return
|
||||
}
|
||||
r.Complete(false)
|
||||
defer ep.Close()
|
||||
|
||||
// enable tcp keep-alive to prevent hanging connections
|
||||
ep.SocketOptions().SetKeepAlive(true)
|
||||
|
||||
// local address is actually destination
|
||||
handler(xnet.TCPDestination(xnet.IPAddress(id.LocalAddress.AsSlice()), xnet.Port(id.LocalPort)), gonet.NewTCPConn(&wq, ep))
|
||||
}(r)
|
||||
})
|
||||
stack.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpForwarder.HandlePacket)
|
||||
|
||||
udpForwarder := udp.NewForwarder(stack, func(r *udp.ForwarderRequest) {
|
||||
go func(r *udp.ForwarderRequest) {
|
||||
var (
|
||||
wq waiter.Queue
|
||||
id = r.ID()
|
||||
)
|
||||
|
||||
ep, err := r.CreateEndpoint(&wq)
|
||||
if err != nil {
|
||||
newError(err.String()).AtError().WriteToLog()
|
||||
return
|
||||
}
|
||||
defer ep.Close()
|
||||
|
||||
// prevents hanging connections and ensure timely release
|
||||
ep.SocketOptions().SetLinger(tcpip.LingerOption{
|
||||
Enabled: true,
|
||||
Timeout: 15 * time.Second,
|
||||
})
|
||||
|
||||
handler(xnet.UDPDestination(xnet.IPAddress(id.LocalAddress.AsSlice()), xnet.Port(id.LocalPort)), gonet.NewUDPConn(stack, &wq, ep))
|
||||
}(r)
|
||||
})
|
||||
stack.SetTransportProtocolHandler(udp.ProtocolNumber, udpForwarder.HandlePacket)
|
||||
}
|
||||
|
||||
out.tun, out.net = tun, n
|
||||
return out, nil
|
||||
}
|
||||
|
@ -1,42 +1,16 @@
|
||||
//go:build !linux
|
||||
//go:build !linux || android
|
||||
|
||||
package wireguard
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"errors"
|
||||
"net/netip"
|
||||
|
||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||
)
|
||||
|
||||
var _ Tunnel = (*gvisorNet)(nil)
|
||||
|
||||
type gvisorNet struct {
|
||||
tunnel
|
||||
net *netstack.Net
|
||||
func createKernelTun(localAddresses []netip.Addr, mtu int, handler promiscuousModeHandler) (t Tunnel, err error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (g *gvisorNet) Close() error {
|
||||
return g.tunnel.Close()
|
||||
}
|
||||
|
||||
func (g *gvisorNet) DialContextTCPAddrPort(ctx context.Context, addr netip.AddrPort) (
|
||||
net.Conn, error,
|
||||
) {
|
||||
return g.net.DialContextTCPAddrPort(ctx, addr)
|
||||
}
|
||||
|
||||
func (g *gvisorNet) DialUDPAddrPort(laddr, raddr netip.AddrPort) (net.Conn, error) {
|
||||
return g.net.DialUDPAddrPort(laddr, raddr)
|
||||
}
|
||||
|
||||
func CreateTun(localAddresses []netip.Addr, mtu int) (Tunnel, error) {
|
||||
out := &gvisorNet{}
|
||||
tun, n, err := netstack.CreateNetTUN(localAddresses, nil, mtu)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out.tun, out.net = tun, n
|
||||
return out, nil
|
||||
func KernelTunSupported() bool {
|
||||
return false
|
||||
}
|
||||
|
@ -1,3 +1,5 @@
|
||||
//go:build linux && !android
|
||||
|
||||
package wireguard
|
||||
|
||||
import (
|
||||
@ -69,7 +71,11 @@ func (d *deviceNet) Close() (err error) {
|
||||
return errors.Join(errs...)
|
||||
}
|
||||
|
||||
func CreateTun(localAddresses []netip.Addr, mtu int) (t Tunnel, err error) {
|
||||
func createKernelTun(localAddresses []netip.Addr, mtu int, handler promiscuousModeHandler) (t Tunnel, err error) {
|
||||
if handler != nil {
|
||||
return nil, newError("TODO: support promiscuous mode")
|
||||
}
|
||||
|
||||
var v4, v6 *netip.Addr
|
||||
for _, prefixes := range localAddresses {
|
||||
if v4 == nil && prefixes.Is4() {
|
||||
@ -221,3 +227,11 @@ func CreateTun(localAddresses []netip.Addr, mtu int) (t Tunnel, err error) {
|
||||
out.tun = wgt
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func KernelTunSupported() bool {
|
||||
// run a superuser permission check to check
|
||||
// if the current user has the sufficient permission
|
||||
// to create a tun device.
|
||||
|
||||
return unix.Geteuid() == 0 // 0 means root
|
||||
}
|
||||
|
@ -1,326 +1,111 @@
|
||||
/*
|
||||
|
||||
Some of codes are copied from https://github.com/octeep/wireproxy, license below.
|
||||
|
||||
Copyright (c) 2022 Wind T.F. Wong <octeep@pm.me>
|
||||
|
||||
Permission to use, copy, modify, and distribute this software for any
|
||||
purpose with or without fee is hereby granted, provided that the above
|
||||
copyright notice and this permission notice appear in all copies.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
|
||||
WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
|
||||
MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
|
||||
ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
|
||||
WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
|
||||
ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
|
||||
OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||
|
||||
*/
|
||||
|
||||
package wireguard
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
stdnet "net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/xtls/xray-core/common"
|
||||
"github.com/xtls/xray-core/common/buf"
|
||||
"github.com/xtls/xray-core/common/dice"
|
||||
"github.com/xtls/xray-core/common/log"
|
||||
"github.com/xtls/xray-core/common/net"
|
||||
"github.com/xtls/xray-core/common/protocol"
|
||||
"github.com/xtls/xray-core/common/session"
|
||||
"github.com/xtls/xray-core/common/signal"
|
||||
"github.com/xtls/xray-core/common/task"
|
||||
"github.com/xtls/xray-core/core"
|
||||
"github.com/xtls/xray-core/features/dns"
|
||||
"github.com/xtls/xray-core/features/policy"
|
||||
"github.com/xtls/xray-core/transport"
|
||||
"github.com/xtls/xray-core/transport/internet"
|
||||
"golang.zx2c4.com/wireguard/device"
|
||||
)
|
||||
|
||||
// Handler is an outbound connection that silently swallow the entire payload.
|
||||
type Handler struct {
|
||||
conf *DeviceConfig
|
||||
net Tunnel
|
||||
bind *netBindClient
|
||||
policyManager policy.Manager
|
||||
dns dns.Client
|
||||
// cached configuration
|
||||
ipc string
|
||||
endpoints []netip.Addr
|
||||
hasIPv4, hasIPv6 bool
|
||||
wgLock sync.Mutex
|
||||
}
|
||||
//go:generate go run github.com/xtls/xray-core/common/errors/errorgen
|
||||
|
||||
// New creates a new wireguard handler.
|
||||
func New(ctx context.Context, conf *DeviceConfig) (*Handler, error) {
|
||||
v := core.MustFromContext(ctx)
|
||||
|
||||
endpoints, err := parseEndpoints(conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
hasIPv4, hasIPv6 := false, false
|
||||
for _, e := range endpoints {
|
||||
if e.Is4() {
|
||||
hasIPv4 = true
|
||||
}
|
||||
if e.Is6() {
|
||||
hasIPv6 = true
|
||||
}
|
||||
}
|
||||
|
||||
d := v.GetFeature(dns.ClientType()).(dns.Client)
|
||||
return &Handler{
|
||||
conf: conf,
|
||||
policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager),
|
||||
dns: d,
|
||||
ipc: createIPCRequest(conf, d, hasIPv6),
|
||||
endpoints: endpoints,
|
||||
hasIPv4: hasIPv4,
|
||||
hasIPv6: hasIPv6,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (h *Handler) processWireGuard(dialer internet.Dialer) (err error) {
|
||||
h.wgLock.Lock()
|
||||
defer h.wgLock.Unlock()
|
||||
|
||||
if h.bind != nil && h.bind.dialer == dialer && h.net != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Record(&log.GeneralMessage{
|
||||
Severity: log.Severity_Info,
|
||||
Content: "switching dialer",
|
||||
})
|
||||
|
||||
if h.net != nil {
|
||||
_ = h.net.Close()
|
||||
h.net = nil
|
||||
}
|
||||
if h.bind != nil {
|
||||
_ = h.bind.Close()
|
||||
h.bind = nil
|
||||
}
|
||||
|
||||
// bind := conn.NewStdNetBind() // TODO: conn.Bind wrapper for dialer
|
||||
bind := &netBindClient{
|
||||
dialer: dialer,
|
||||
workers: int(h.conf.NumWorkers),
|
||||
dns: h.dns,
|
||||
reserved: h.conf.Reserved,
|
||||
}
|
||||
defer func() {
|
||||
if err != nil {
|
||||
_ = bind.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
h.net, err = h.makeVirtualTun(bind)
|
||||
if err != nil {
|
||||
return newError("failed to create virtual tun interface").Base(err)
|
||||
}
|
||||
h.bind = bind
|
||||
return nil
|
||||
}
|
||||
|
||||
// Process implements OutboundHandler.Dispatch().
|
||||
func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer internet.Dialer) error {
|
||||
outbound := session.OutboundFromContext(ctx)
|
||||
if outbound == nil || !outbound.Target.IsValid() {
|
||||
return newError("target not specified")
|
||||
}
|
||||
outbound.Name = "wireguard"
|
||||
inbound := session.InboundFromContext(ctx)
|
||||
if inbound != nil {
|
||||
inbound.SetCanSpliceCopy(3)
|
||||
}
|
||||
|
||||
if err := h.processWireGuard(dialer); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Destination of the inner request.
|
||||
destination := outbound.Target
|
||||
command := protocol.RequestCommandTCP
|
||||
if destination.Network == net.Network_UDP {
|
||||
command = protocol.RequestCommandUDP
|
||||
}
|
||||
|
||||
// resolve dns
|
||||
addr := destination.Address
|
||||
if addr.Family().IsDomain() {
|
||||
ips, err := h.dns.LookupIP(addr.Domain(), dns.IPOption{
|
||||
IPv4Enable: h.hasIPv4 && h.conf.preferIP4(),
|
||||
IPv6Enable: h.hasIPv6 && h.conf.preferIP6(),
|
||||
var wgLogger = &device.Logger{
|
||||
Verbosef: func(format string, args ...any) {
|
||||
log.Record(&log.GeneralMessage{
|
||||
Severity: log.Severity_Debug,
|
||||
Content: fmt.Sprintf(format, args...),
|
||||
})
|
||||
{ // Resolve fallback
|
||||
if (len(ips) == 0 || err != nil) && h.conf.hasFallback() {
|
||||
ips, err = h.dns.LookupIP(addr.Domain(), dns.IPOption{
|
||||
IPv4Enable: h.hasIPv4 && h.conf.fallbackIP4(),
|
||||
IPv6Enable: h.hasIPv6 && h.conf.fallbackIP6(),
|
||||
})
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return newError("failed to lookup DNS").Base(err)
|
||||
} else if len(ips) == 0 {
|
||||
return dns.ErrEmptyResponse
|
||||
}
|
||||
addr = net.IPAddress(ips[dice.Roll(len(ips))])
|
||||
}
|
||||
|
||||
var newCtx context.Context
|
||||
var newCancel context.CancelFunc
|
||||
if session.TimeoutOnlyFromContext(ctx) {
|
||||
newCtx, newCancel = context.WithCancel(context.Background())
|
||||
}
|
||||
|
||||
p := h.policyManager.ForLevel(0)
|
||||
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
timer := signal.CancelAfterInactivity(ctx, func() {
|
||||
cancel()
|
||||
if newCancel != nil {
|
||||
newCancel()
|
||||
}
|
||||
}, p.Timeouts.ConnectionIdle)
|
||||
addrPort := netip.AddrPortFrom(toNetIpAddr(addr), destination.Port.Value())
|
||||
|
||||
var requestFunc func() error
|
||||
var responseFunc func() error
|
||||
|
||||
if command == protocol.RequestCommandTCP {
|
||||
conn, err := h.net.DialContextTCPAddrPort(ctx, addrPort)
|
||||
if err != nil {
|
||||
return newError("failed to create TCP connection").Base(err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
requestFunc = func() error {
|
||||
defer timer.SetTimeout(p.Timeouts.DownlinkOnly)
|
||||
return buf.Copy(link.Reader, buf.NewWriter(conn), buf.UpdateActivity(timer))
|
||||
}
|
||||
responseFunc = func() error {
|
||||
defer timer.SetTimeout(p.Timeouts.UplinkOnly)
|
||||
return buf.Copy(buf.NewReader(conn), link.Writer, buf.UpdateActivity(timer))
|
||||
}
|
||||
} else if command == protocol.RequestCommandUDP {
|
||||
conn, err := h.net.DialUDPAddrPort(netip.AddrPort{}, addrPort)
|
||||
if err != nil {
|
||||
return newError("failed to create UDP connection").Base(err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
requestFunc = func() error {
|
||||
defer timer.SetTimeout(p.Timeouts.DownlinkOnly)
|
||||
return buf.Copy(link.Reader, buf.NewWriter(conn), buf.UpdateActivity(timer))
|
||||
}
|
||||
responseFunc = func() error {
|
||||
defer timer.SetTimeout(p.Timeouts.UplinkOnly)
|
||||
return buf.Copy(buf.NewReader(conn), link.Writer, buf.UpdateActivity(timer))
|
||||
}
|
||||
}
|
||||
|
||||
if newCtx != nil {
|
||||
ctx = newCtx
|
||||
}
|
||||
|
||||
responseDonePost := task.OnSuccess(responseFunc, task.Close(link.Writer))
|
||||
if err := task.Run(ctx, requestFunc, responseDonePost); err != nil {
|
||||
common.Interrupt(link.Reader)
|
||||
common.Interrupt(link.Writer)
|
||||
return newError("connection ends").Base(err)
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
Errorf: func(format string, args ...any) {
|
||||
log.Record(&log.GeneralMessage{
|
||||
Severity: log.Severity_Error,
|
||||
Content: fmt.Sprintf(format, args...),
|
||||
})
|
||||
},
|
||||
}
|
||||
|
||||
// serialize the config into an IPC request
|
||||
func createIPCRequest(conf *DeviceConfig, d dns.Client, resolveEndPointToV4 bool) string {
|
||||
var request bytes.Buffer
|
||||
|
||||
request.WriteString(fmt.Sprintf("private_key=%s\n", conf.SecretKey))
|
||||
|
||||
for _, peer := range conf.Peers {
|
||||
endpoint := peer.Endpoint
|
||||
host, port, err := net.SplitHostPort(endpoint)
|
||||
if resolveEndPointToV4 && err == nil {
|
||||
_, err = netip.ParseAddr(host)
|
||||
if err != nil {
|
||||
ipList, err := d.LookupIP(host, dns.IPOption{IPv4Enable: true, IPv6Enable: false})
|
||||
if err == nil && len(ipList) > 0 {
|
||||
endpoint = stdnet.JoinHostPort(ipList[0].String(), port)
|
||||
}
|
||||
}
|
||||
func init() {
|
||||
common.Must(common.RegisterConfig((*DeviceConfig)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
|
||||
deviceConfig := config.(*DeviceConfig)
|
||||
if deviceConfig.IsClient {
|
||||
return New(ctx, deviceConfig)
|
||||
} else {
|
||||
return NewServer(ctx, deviceConfig)
|
||||
}
|
||||
|
||||
request.WriteString(fmt.Sprintf("public_key=%s\nendpoint=%s\npersistent_keepalive_interval=%d\npreshared_key=%s\n",
|
||||
peer.PublicKey, endpoint, peer.KeepAlive, peer.PreSharedKey))
|
||||
|
||||
for _, ip := range peer.AllowedIps {
|
||||
request.WriteString(fmt.Sprintf("allowed_ip=%s\n", ip))
|
||||
}
|
||||
}
|
||||
|
||||
return request.String()[:request.Len()]
|
||||
}))
|
||||
}
|
||||
|
||||
// convert endpoint string to netip.Addr
|
||||
func parseEndpoints(conf *DeviceConfig) ([]netip.Addr, error) {
|
||||
func parseEndpoints(conf *DeviceConfig) ([]netip.Addr, bool, bool, error) {
|
||||
var hasIPv4, hasIPv6 bool
|
||||
|
||||
endpoints := make([]netip.Addr, len(conf.Endpoint))
|
||||
for i, str := range conf.Endpoint {
|
||||
var addr netip.Addr
|
||||
if strings.Contains(str, "/") {
|
||||
prefix, err := netip.ParsePrefix(str)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, false, false, err
|
||||
}
|
||||
addr = prefix.Addr()
|
||||
if prefix.Bits() != addr.BitLen() {
|
||||
return nil, newError("interface address subnet should be /32 for IPv4 and /128 for IPv6")
|
||||
return nil, false, false, newError("interface address subnet should be /32 for IPv4 and /128 for IPv6")
|
||||
}
|
||||
} else {
|
||||
var err error
|
||||
addr, err = netip.ParseAddr(str)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, false, false, err
|
||||
}
|
||||
}
|
||||
endpoints[i] = addr
|
||||
|
||||
if addr.Is4() {
|
||||
hasIPv4 = true
|
||||
} else if addr.Is6() {
|
||||
hasIPv6 = true
|
||||
}
|
||||
}
|
||||
|
||||
return endpoints, nil
|
||||
return endpoints, hasIPv4, hasIPv6, nil
|
||||
}
|
||||
|
||||
// creates a tun interface on netstack given a configuration
|
||||
func (h *Handler) makeVirtualTun(bind *netBindClient) (Tunnel, error) {
|
||||
t, err := CreateTun(h.endpoints, int(h.conf.Mtu))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
// serialize the config into an IPC request
|
||||
func createIPCRequest(conf *DeviceConfig) string {
|
||||
var request strings.Builder
|
||||
|
||||
request.WriteString(fmt.Sprintf("private_key=%s\n", conf.SecretKey))
|
||||
|
||||
if !conf.IsClient {
|
||||
// placeholder, we'll handle actual port listening on Xray
|
||||
request.WriteString("listen_port=1337\n")
|
||||
}
|
||||
|
||||
bind.dnsOption.IPv4Enable = h.hasIPv4
|
||||
bind.dnsOption.IPv6Enable = h.hasIPv6
|
||||
for _, peer := range conf.Peers {
|
||||
if peer.PublicKey != "" {
|
||||
request.WriteString(fmt.Sprintf("public_key=%s\n", peer.PublicKey))
|
||||
}
|
||||
|
||||
if err = t.BuildDevice(h.ipc, bind); err != nil {
|
||||
_ = t.Close()
|
||||
return nil, err
|
||||
if peer.PreSharedKey != "" {
|
||||
request.WriteString(fmt.Sprintf("preshared_key=%s\n", peer.PreSharedKey))
|
||||
}
|
||||
|
||||
if peer.Endpoint != "" {
|
||||
request.WriteString(fmt.Sprintf("endpoint=%s\n", peer.Endpoint))
|
||||
}
|
||||
|
||||
for _, ip := range peer.AllowedIps {
|
||||
request.WriteString(fmt.Sprintf("allowed_ip=%s\n", ip))
|
||||
}
|
||||
|
||||
if peer.KeepAlive != 0 {
|
||||
request.WriteString(fmt.Sprintf("persistent_keepalive_interval=%d\n", peer.KeepAlive))
|
||||
}
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
|
||||
func init() {
|
||||
common.Must(common.RegisterConfig((*DeviceConfig)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
|
||||
return New(ctx, config.(*DeviceConfig))
|
||||
}))
|
||||
return request.String()[:request.Len()]
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user