mirror of
https://github.com/XTLS/Xray-core.git
synced 2025-04-30 09:18:34 +00:00
feat : upgrade wireguard go sdk (#2716)
Co-authored-by: kunson <kunson@kunsondeMacBook-Pro-3.local> Co-authored-by: 世界 <i@sekai.icu>
This commit is contained in:
parent
ea67c98eaf
commit
5ae3791a8e
7 changed files with 514 additions and 362 deletions
|
@ -9,7 +9,8 @@ import (
|
|||
"strconv"
|
||||
"sync"
|
||||
|
||||
"github.com/sagernet/wireguard-go/conn"
|
||||
"golang.zx2c4.com/wireguard/conn"
|
||||
|
||||
xnet "github.com/xtls/xray-core/common/net"
|
||||
"github.com/xtls/xray-core/features/dns"
|
||||
"github.com/xtls/xray-core/transport/internet"
|
||||
|
@ -36,7 +37,7 @@ type netBindClient struct {
|
|||
readQueue chan *netReadInfo
|
||||
}
|
||||
|
||||
func (n *netBindClient) ParseEndpoint(s string) (conn.Endpoint, error) {
|
||||
func (bind *netBindClient) ParseEndpoint(s string) (conn.Endpoint, error) {
|
||||
ipStr, port, _, err := splitAddrPort(s)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -44,7 +45,7 @@ func (n *netBindClient) ParseEndpoint(s string) (conn.Endpoint, error) {
|
|||
|
||||
var addr net.IP
|
||||
if IsDomainName(ipStr) {
|
||||
ips, err := n.dns.LookupIP(ipStr, n.dnsOption)
|
||||
ips, err := bind.dns.LookupIP(ipStr, bind.dnsOption)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if len(ips) == 0 {
|
||||
|
@ -79,22 +80,22 @@ func (n *netBindClient) ParseEndpoint(s string) (conn.Endpoint, error) {
|
|||
func (bind *netBindClient) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) {
|
||||
bind.readQueue = make(chan *netReadInfo)
|
||||
|
||||
fun := func(buff []byte) (cap int, ep conn.Endpoint, err error) {
|
||||
fun := func(bufs [][]byte, sizes []int, eps []conn.Endpoint) (n int, err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
cap = 0
|
||||
ep = nil
|
||||
n = 0
|
||||
err = errors.New("channel closed")
|
||||
}
|
||||
}()
|
||||
|
||||
r := &netReadInfo{
|
||||
buff: buff,
|
||||
buff: bufs[0],
|
||||
}
|
||||
r.waiter.Add(1)
|
||||
bind.readQueue <- r
|
||||
r.waiter.Wait() // wait read goroutine done, or we will miss the result
|
||||
return r.bytes, r.endpoint, r.err
|
||||
sizes[0], eps[0] = r.bytes, r.endpoint
|
||||
return 1, r.err
|
||||
}
|
||||
workers := bind.workers
|
||||
if workers <= 0 {
|
||||
|
@ -150,7 +151,7 @@ func (bind *netBindClient) connectTo(endpoint *netEndpoint) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (bind *netBindClient) Send(buff []byte, endpoint conn.Endpoint) error {
|
||||
func (bind *netBindClient) Send(buff [][]byte, endpoint conn.Endpoint) error {
|
||||
var err error
|
||||
|
||||
nend, ok := endpoint.(*netEndpoint)
|
||||
|
@ -165,19 +166,25 @@ func (bind *netBindClient) Send(buff []byte, endpoint conn.Endpoint) error {
|
|||
}
|
||||
}
|
||||
|
||||
if len(buff) > 3 && len(bind.reserved) == 3 {
|
||||
copy(buff[1:], bind.reserved)
|
||||
for _, buff := range buff {
|
||||
if len(buff) > 3 && len(bind.reserved) == 3 {
|
||||
copy(buff[1:], bind.reserved)
|
||||
}
|
||||
if _, err = nend.conn.Write(buff); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
_, err = nend.conn.Write(buff)
|
||||
|
||||
return err
|
||||
return nil
|
||||
}
|
||||
|
||||
func (bind *netBindClient) SetMark(mark uint32) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (bind *netBindClient) BatchSize() int {
|
||||
return 1
|
||||
}
|
||||
|
||||
type netEndpoint struct {
|
||||
dst xnet.Destination
|
||||
conn net.Conn
|
||||
|
@ -264,3 +271,44 @@ func splitAddrPort(s string) (ip string, port uint16, v6 bool, err error) {
|
|||
|
||||
return ip, port, v6, nil
|
||||
}
|
||||
|
||||
func IsDomainName(s string) bool {
|
||||
l := len(s)
|
||||
if l == 0 || l > 254 || l == 254 && s[l-1] != '.' {
|
||||
return false
|
||||
}
|
||||
last := byte('.')
|
||||
nonNumeric := false
|
||||
partlen := 0
|
||||
for i := 0; i < len(s); i++ {
|
||||
c := s[i]
|
||||
switch {
|
||||
default:
|
||||
return false
|
||||
case 'a' <= c && c <= 'z' || 'A' <= c && c <= 'Z' || c == '_':
|
||||
nonNumeric = true
|
||||
partlen++
|
||||
case '0' <= c && c <= '9':
|
||||
partlen++
|
||||
case c == '-':
|
||||
if last == '.' {
|
||||
return false
|
||||
}
|
||||
partlen++
|
||||
nonNumeric = true
|
||||
case c == '.':
|
||||
if last == '.' || last == '-' {
|
||||
return false
|
||||
}
|
||||
if partlen > 63 || partlen == 0 {
|
||||
return false
|
||||
}
|
||||
partlen = 0
|
||||
}
|
||||
last = c
|
||||
}
|
||||
if last == '-' || partlen > 63 {
|
||||
return false
|
||||
}
|
||||
return nonNumeric
|
||||
}
|
||||
|
|
|
@ -1,303 +1,105 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2022 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package wireguard
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/sagernet/wireguard-go/tun"
|
||||
"github.com/xtls/xray-core/features/dns"
|
||||
"gvisor.dev/gvisor/pkg/buffer"
|
||||
"gvisor.dev/gvisor/pkg/tcpip"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/header"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
|
||||
"github.com/xtls/xray-core/common/log"
|
||||
|
||||
"golang.zx2c4.com/wireguard/conn"
|
||||
"golang.zx2c4.com/wireguard/device"
|
||||
"golang.zx2c4.com/wireguard/tun"
|
||||
)
|
||||
|
||||
type netTun struct {
|
||||
ep *channel.Endpoint
|
||||
stack *stack.Stack
|
||||
events chan tun.Event
|
||||
incomingPacket chan *buffer.View
|
||||
mtu int
|
||||
dnsClient dns.Client
|
||||
hasV4, hasV6 bool
|
||||
type Tunnel interface {
|
||||
BuildDevice(ipc string, bind conn.Bind) error
|
||||
DialContextTCPAddrPort(ctx context.Context, addr netip.AddrPort) (net.Conn, error)
|
||||
DialUDPAddrPort(laddr, raddr netip.AddrPort) (net.Conn, error)
|
||||
Close() error
|
||||
}
|
||||
|
||||
type Net netTun
|
||||
|
||||
func CreateNetTUN(localAddresses []netip.Addr, dnsClient dns.Client, mtu int) (tun.Device, *Net, error) {
|
||||
opts := stack.Options{
|
||||
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
|
||||
TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol},
|
||||
HandleLocal: true,
|
||||
}
|
||||
dev := &netTun{
|
||||
ep: channel.New(1024, uint32(mtu), ""),
|
||||
stack: stack.New(opts),
|
||||
events: make(chan tun.Event, 10),
|
||||
incomingPacket: make(chan *buffer.View),
|
||||
dnsClient: dnsClient,
|
||||
mtu: mtu,
|
||||
}
|
||||
dev.ep.AddNotify(dev)
|
||||
tcpipErr := dev.stack.CreateNIC(1, dev.ep)
|
||||
if tcpipErr != nil {
|
||||
return nil, nil, fmt.Errorf("CreateNIC: %v", tcpipErr)
|
||||
}
|
||||
for _, ip := range localAddresses {
|
||||
var protoNumber tcpip.NetworkProtocolNumber
|
||||
if ip.Is4() {
|
||||
protoNumber = ipv4.ProtocolNumber
|
||||
} else if ip.Is6() {
|
||||
protoNumber = ipv6.ProtocolNumber
|
||||
}
|
||||
protoAddr := tcpip.ProtocolAddress{
|
||||
Protocol: protoNumber,
|
||||
AddressWithPrefix: tcpip.AddrFromSlice(ip.AsSlice()).WithPrefix(),
|
||||
}
|
||||
tcpipErr := dev.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{})
|
||||
if tcpipErr != nil {
|
||||
return nil, nil, fmt.Errorf("AddProtocolAddress(%v): %v", ip, tcpipErr)
|
||||
}
|
||||
if ip.Is4() {
|
||||
dev.hasV4 = true
|
||||
} else if ip.Is6() {
|
||||
dev.hasV6 = true
|
||||
}
|
||||
}
|
||||
if dev.hasV4 {
|
||||
dev.stack.AddRoute(tcpip.Route{Destination: header.IPv4EmptySubnet, NIC: 1})
|
||||
}
|
||||
if dev.hasV6 {
|
||||
dev.stack.AddRoute(tcpip.Route{Destination: header.IPv6EmptySubnet, NIC: 1})
|
||||
}
|
||||
|
||||
dev.events <- tun.EventUp
|
||||
return dev, (*Net)(dev), nil
|
||||
type tunnel struct {
|
||||
tun tun.Device
|
||||
device *device.Device
|
||||
rw sync.Mutex
|
||||
}
|
||||
|
||||
func (tun *netTun) Name() (string, error) {
|
||||
return "go", nil
|
||||
}
|
||||
func (t *tunnel) BuildDevice(ipc string, bind conn.Bind) (err error) {
|
||||
t.rw.Lock()
|
||||
defer t.rw.Unlock()
|
||||
|
||||
func (tun *netTun) File() *os.File {
|
||||
if t.device != nil {
|
||||
return errors.New("device is already initialized")
|
||||
}
|
||||
|
||||
logger := &device.Logger{
|
||||
Verbosef: func(format string, args ...any) {
|
||||
log.Record(&log.GeneralMessage{
|
||||
Severity: log.Severity_Debug,
|
||||
Content: fmt.Sprintf(format, args...),
|
||||
})
|
||||
},
|
||||
Errorf: func(format string, args ...any) {
|
||||
log.Record(&log.GeneralMessage{
|
||||
Severity: log.Severity_Error,
|
||||
Content: fmt.Sprintf(format, args...),
|
||||
})
|
||||
},
|
||||
}
|
||||
|
||||
t.device = device.NewDevice(t.tun, bind, logger)
|
||||
if err = t.device.IpcSet(ipc); err != nil {
|
||||
return err
|
||||
}
|
||||
if err = t.device.Up(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tun *netTun) Events() chan tun.Event {
|
||||
return tun.events
|
||||
}
|
||||
func (t *tunnel) Close() (err error) {
|
||||
t.rw.Lock()
|
||||
defer t.rw.Unlock()
|
||||
|
||||
func (tun *netTun) Read(buf []byte, offset int) (int, error) {
|
||||
view, ok := <-tun.incomingPacket
|
||||
if !ok {
|
||||
return 0, os.ErrClosed
|
||||
if t.device == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return view.Read(buf[offset:])
|
||||
t.device.Close()
|
||||
t.device = nil
|
||||
err = t.tun.Close()
|
||||
t.tun = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tun *netTun) Write(buf []byte, offset int) (int, error) {
|
||||
packet := buf[offset:]
|
||||
if len(packet) == 0 {
|
||||
return 0, nil
|
||||
func CalculateInterfaceName(name string) (tunName string) {
|
||||
if runtime.GOOS == "darwin" {
|
||||
tunName = "utun"
|
||||
} else if name != "" {
|
||||
tunName = name
|
||||
} else {
|
||||
tunName = "tun"
|
||||
}
|
||||
|
||||
pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{Payload: buffer.MakeWithData(packet)})
|
||||
switch packet[0] >> 4 {
|
||||
case 4:
|
||||
tun.ep.InjectInbound(header.IPv4ProtocolNumber, pkb)
|
||||
case 6:
|
||||
tun.ep.InjectInbound(header.IPv6ProtocolNumber, pkb)
|
||||
}
|
||||
|
||||
return len(buf), nil
|
||||
}
|
||||
|
||||
func (tun *netTun) WriteNotify() {
|
||||
pkt := tun.ep.Read()
|
||||
if pkt == nil {
|
||||
interfaces, err := net.Interfaces()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
view := pkt.ToView()
|
||||
pkt.DecRef()
|
||||
|
||||
tun.incomingPacket <- view
|
||||
}
|
||||
|
||||
func (tun *netTun) Flush() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tun *netTun) Close() error {
|
||||
tun.stack.RemoveNIC(1)
|
||||
|
||||
if tun.events != nil {
|
||||
close(tun.events)
|
||||
}
|
||||
|
||||
tun.ep.Close()
|
||||
|
||||
if tun.incomingPacket != nil {
|
||||
close(tun.incomingPacket)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tun *netTun) MTU() (int, error) {
|
||||
return tun.mtu, nil
|
||||
}
|
||||
|
||||
func convertToFullAddr(endpoint netip.AddrPort) (tcpip.FullAddress, tcpip.NetworkProtocolNumber) {
|
||||
var protoNumber tcpip.NetworkProtocolNumber
|
||||
if endpoint.Addr().Is4() {
|
||||
protoNumber = ipv4.ProtocolNumber
|
||||
} else {
|
||||
protoNumber = ipv6.ProtocolNumber
|
||||
}
|
||||
return tcpip.FullAddress{
|
||||
NIC: 1,
|
||||
Addr: tcpip.AddrFromSlice(endpoint.Addr().AsSlice()),
|
||||
Port: endpoint.Port(),
|
||||
}, protoNumber
|
||||
}
|
||||
|
||||
func (net *Net) DialContextTCPAddrPort(ctx context.Context, addr netip.AddrPort) (*gonet.TCPConn, error) {
|
||||
fa, pn := convertToFullAddr(addr)
|
||||
return gonet.DialContextTCP(ctx, net.stack, fa, pn)
|
||||
}
|
||||
|
||||
func (net *Net) DialContextTCP(ctx context.Context, addr *net.TCPAddr) (*gonet.TCPConn, error) {
|
||||
if addr == nil {
|
||||
return net.DialContextTCPAddrPort(ctx, netip.AddrPort{})
|
||||
}
|
||||
ip, _ := netip.AddrFromSlice(addr.IP)
|
||||
return net.DialContextTCPAddrPort(ctx, netip.AddrPortFrom(ip, uint16(addr.Port)))
|
||||
}
|
||||
|
||||
func (net *Net) DialTCPAddrPort(addr netip.AddrPort) (*gonet.TCPConn, error) {
|
||||
fa, pn := convertToFullAddr(addr)
|
||||
return gonet.DialTCP(net.stack, fa, pn)
|
||||
}
|
||||
|
||||
func (net *Net) DialTCP(addr *net.TCPAddr) (*gonet.TCPConn, error) {
|
||||
if addr == nil {
|
||||
return net.DialTCPAddrPort(netip.AddrPort{})
|
||||
}
|
||||
ip, _ := netip.AddrFromSlice(addr.IP)
|
||||
return net.DialTCPAddrPort(netip.AddrPortFrom(ip, uint16(addr.Port)))
|
||||
}
|
||||
|
||||
func (net *Net) ListenTCPAddrPort(addr netip.AddrPort) (*gonet.TCPListener, error) {
|
||||
fa, pn := convertToFullAddr(addr)
|
||||
return gonet.ListenTCP(net.stack, fa, pn)
|
||||
}
|
||||
|
||||
func (net *Net) ListenTCP(addr *net.TCPAddr) (*gonet.TCPListener, error) {
|
||||
if addr == nil {
|
||||
return net.ListenTCPAddrPort(netip.AddrPort{})
|
||||
}
|
||||
ip, _ := netip.AddrFromSlice(addr.IP)
|
||||
return net.ListenTCPAddrPort(netip.AddrPortFrom(ip, uint16(addr.Port)))
|
||||
}
|
||||
|
||||
func (net *Net) DialUDPAddrPort(laddr, raddr netip.AddrPort) (*gonet.UDPConn, error) {
|
||||
var lfa, rfa *tcpip.FullAddress
|
||||
var pn tcpip.NetworkProtocolNumber
|
||||
if laddr.IsValid() || laddr.Port() > 0 {
|
||||
var addr tcpip.FullAddress
|
||||
addr, pn = convertToFullAddr(laddr)
|
||||
lfa = &addr
|
||||
}
|
||||
if raddr.IsValid() || raddr.Port() > 0 {
|
||||
var addr tcpip.FullAddress
|
||||
addr, pn = convertToFullAddr(raddr)
|
||||
rfa = &addr
|
||||
}
|
||||
return gonet.DialUDP(net.stack, lfa, rfa, pn)
|
||||
}
|
||||
|
||||
func (net *Net) ListenUDPAddrPort(laddr netip.AddrPort) (*gonet.UDPConn, error) {
|
||||
return net.DialUDPAddrPort(laddr, netip.AddrPort{})
|
||||
}
|
||||
|
||||
func (net *Net) DialUDP(laddr, raddr *net.UDPAddr) (*gonet.UDPConn, error) {
|
||||
var la, ra netip.AddrPort
|
||||
if laddr != nil {
|
||||
ip, _ := netip.AddrFromSlice(laddr.IP)
|
||||
la = netip.AddrPortFrom(ip, uint16(laddr.Port))
|
||||
}
|
||||
if raddr != nil {
|
||||
ip, _ := netip.AddrFromSlice(raddr.IP)
|
||||
ra = netip.AddrPortFrom(ip, uint16(raddr.Port))
|
||||
}
|
||||
return net.DialUDPAddrPort(la, ra)
|
||||
}
|
||||
|
||||
func (net *Net) ListenUDP(laddr *net.UDPAddr) (*gonet.UDPConn, error) {
|
||||
return net.DialUDP(laddr, nil)
|
||||
}
|
||||
|
||||
func (n *Net) HasV4() bool {
|
||||
return n.hasV4
|
||||
}
|
||||
|
||||
func (n *Net) HasV6() bool {
|
||||
return n.hasV6
|
||||
}
|
||||
|
||||
func IsDomainName(s string) bool {
|
||||
l := len(s)
|
||||
if l == 0 || l > 254 || l == 254 && s[l-1] != '.' {
|
||||
return false
|
||||
}
|
||||
last := byte('.')
|
||||
nonNumeric := false
|
||||
partlen := 0
|
||||
for i := 0; i < len(s); i++ {
|
||||
c := s[i]
|
||||
switch {
|
||||
default:
|
||||
return false
|
||||
case 'a' <= c && c <= 'z' || 'A' <= c && c <= 'Z' || c == '_':
|
||||
nonNumeric = true
|
||||
partlen++
|
||||
case '0' <= c && c <= '9':
|
||||
partlen++
|
||||
case c == '-':
|
||||
if last == '.' {
|
||||
return false
|
||||
var tunIndex int
|
||||
for _, netInterface := range interfaces {
|
||||
if strings.HasPrefix(netInterface.Name, tunName) {
|
||||
index, parseErr := strconv.ParseInt(netInterface.Name[len(tunName):], 10, 16)
|
||||
if parseErr == nil {
|
||||
tunIndex = int(index) + 1
|
||||
}
|
||||
partlen++
|
||||
nonNumeric = true
|
||||
case c == '.':
|
||||
if last == '.' || last == '-' {
|
||||
return false
|
||||
}
|
||||
if partlen > 63 || partlen == 0 {
|
||||
return false
|
||||
}
|
||||
partlen = 0
|
||||
}
|
||||
last = c
|
||||
}
|
||||
if last == '-' || partlen > 63 {
|
||||
return false
|
||||
}
|
||||
return nonNumeric
|
||||
tunName = fmt.Sprintf("%s%d", tunName, tunIndex)
|
||||
return
|
||||
}
|
||||
|
|
42
proxy/wireguard/tun_default.go
Normal file
42
proxy/wireguard/tun_default.go
Normal file
|
@ -0,0 +1,42 @@
|
|||
//go:build !linux
|
||||
|
||||
package wireguard
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/netip"
|
||||
|
||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||
)
|
||||
|
||||
var _ Tunnel = (*gvisorNet)(nil)
|
||||
|
||||
type gvisorNet struct {
|
||||
tunnel
|
||||
net *netstack.Net
|
||||
}
|
||||
|
||||
func (g *gvisorNet) Close() error {
|
||||
return g.tunnel.Close()
|
||||
}
|
||||
|
||||
func (g *gvisorNet) DialContextTCPAddrPort(ctx context.Context, addr netip.AddrPort) (
|
||||
net.Conn, error,
|
||||
) {
|
||||
return g.net.DialContextTCPAddrPort(ctx, addr)
|
||||
}
|
||||
|
||||
func (g *gvisorNet) DialUDPAddrPort(laddr, raddr netip.AddrPort) (net.Conn, error) {
|
||||
return g.net.DialUDPAddrPort(laddr, raddr)
|
||||
}
|
||||
|
||||
func CreateTun(localAddresses []netip.Addr, mtu int) (Tunnel, error) {
|
||||
out := &gvisorNet{}
|
||||
tun, n, err := netstack.CreateNetTUN(localAddresses, nil, mtu)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out.tun, out.net = tun, n
|
||||
return out, nil
|
||||
}
|
223
proxy/wireguard/tun_linux.go
Normal file
223
proxy/wireguard/tun_linux.go
Normal file
|
@ -0,0 +1,223 @@
|
|||
package wireguard
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
"github.com/sagernet/sing/common/control"
|
||||
"github.com/vishvananda/netlink"
|
||||
wgtun "golang.zx2c4.com/wireguard/tun"
|
||||
)
|
||||
|
||||
type deviceNet struct {
|
||||
tunnel
|
||||
dialer net.Dialer
|
||||
|
||||
handle *netlink.Handle
|
||||
linkAddrs []netlink.Addr
|
||||
routes []*netlink.Route
|
||||
rules []*netlink.Rule
|
||||
}
|
||||
|
||||
func newDeviceNet(interfaceName string) *deviceNet {
|
||||
var dialer net.Dialer
|
||||
bindControl := control.BindToInterface(control.DefaultInterfaceFinder(), interfaceName, -1)
|
||||
dialer.Control = control.Append(dialer.Control, bindControl)
|
||||
return &deviceNet{dialer: dialer}
|
||||
}
|
||||
|
||||
func (d *deviceNet) DialContextTCPAddrPort(ctx context.Context, addr netip.AddrPort) (
|
||||
net.Conn, error,
|
||||
) {
|
||||
return d.dialer.DialContext(ctx, "tcp", addr.String())
|
||||
}
|
||||
|
||||
func (d *deviceNet) DialUDPAddrPort(laddr, raddr netip.AddrPort) (net.Conn, error) {
|
||||
dialer := d.dialer
|
||||
dialer.LocalAddr = &net.UDPAddr{IP: laddr.Addr().AsSlice(), Port: int(laddr.Port())}
|
||||
return dialer.DialContext(context.Background(), "udp", raddr.String())
|
||||
}
|
||||
|
||||
func (d *deviceNet) Close() (err error) {
|
||||
var errs []error
|
||||
for _, rule := range d.rules {
|
||||
if err = d.handle.RuleDel(rule); err != nil {
|
||||
errs = append(errs, fmt.Errorf("failed to delete rule: %w", err))
|
||||
}
|
||||
}
|
||||
for _, route := range d.routes {
|
||||
if err = d.handle.RouteDel(route); err != nil {
|
||||
errs = append(errs, fmt.Errorf("failed to delete route: %w", err))
|
||||
}
|
||||
}
|
||||
if err = d.tunnel.Close(); err != nil {
|
||||
errs = append(errs, fmt.Errorf("failed to close tunnel: %w", err))
|
||||
}
|
||||
if d.handle != nil {
|
||||
d.handle.Close()
|
||||
d.handle = nil
|
||||
}
|
||||
if len(errs) == 0 {
|
||||
return nil
|
||||
}
|
||||
return errors.Join(errs...)
|
||||
}
|
||||
|
||||
func CreateTun(localAddresses []netip.Addr, mtu int) (t Tunnel, err error) {
|
||||
var v4, v6 *netip.Addr
|
||||
for _, prefixes := range localAddresses {
|
||||
if v4 == nil && prefixes.Is4() {
|
||||
x := prefixes
|
||||
v4 = &x
|
||||
}
|
||||
if v6 == nil && prefixes.Is6() {
|
||||
x := prefixes
|
||||
v6 = &x
|
||||
}
|
||||
}
|
||||
|
||||
writeSysctlZero := func(path string) error {
|
||||
_, err := os.Stat(path)
|
||||
if os.IsNotExist(err) {
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return os.WriteFile(path, []byte("0"), 0o644)
|
||||
}
|
||||
|
||||
// system configs.
|
||||
if v4 != nil {
|
||||
if err = writeSysctlZero("/proc/sys/net/ipv4/conf/all/rp_filter"); err != nil {
|
||||
return nil, fmt.Errorf("failed to disable ipv4 rp_filter for all: %w", err)
|
||||
}
|
||||
}
|
||||
if v6 != nil {
|
||||
if err = writeSysctlZero("/proc/sys/net/ipv6/conf/all/disable_ipv6"); err != nil {
|
||||
return nil, fmt.Errorf("failed to enable ipv6: %w", err)
|
||||
}
|
||||
if err = writeSysctlZero("/proc/sys/net/ipv6/conf/all/rp_filter"); err != nil {
|
||||
return nil, fmt.Errorf("failed to disable ipv6 rp_filter for all: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
n := CalculateInterfaceName("wg")
|
||||
wgt, err := wgtun.CreateTUN(n, mtu)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() {
|
||||
if err != nil {
|
||||
_ = wgt.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
// disable linux rp_filter for tunnel device to avoid packet drop.
|
||||
// the operation require root privilege on container require '--privileged' flag.
|
||||
if v4 != nil {
|
||||
if err = writeSysctlZero("/proc/sys/net/ipv4/conf/" + n + "/rp_filter"); err != nil {
|
||||
return nil, fmt.Errorf("failed to disable ipv4 rp_filter for tunnel: %w", err)
|
||||
}
|
||||
}
|
||||
if v6 != nil {
|
||||
if err = writeSysctlZero("/proc/sys/net/ipv6/conf/" + n + "/rp_filter"); err != nil {
|
||||
return nil, fmt.Errorf("failed to disable ipv6 rp_filter for tunnel: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
ipv6TableIndex := 1023
|
||||
if v6 != nil {
|
||||
r := &netlink.Route{Table: ipv6TableIndex}
|
||||
for {
|
||||
routeList, fErr := netlink.RouteListFiltered(netlink.FAMILY_V6, r, netlink.RT_FILTER_TABLE)
|
||||
if len(routeList) == 0 || fErr != nil {
|
||||
break
|
||||
}
|
||||
ipv6TableIndex--
|
||||
if ipv6TableIndex < 0 {
|
||||
return nil, fmt.Errorf("failed to find available ipv6 table index")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
out := newDeviceNet(n)
|
||||
out.handle, err = netlink.NewHandle()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() {
|
||||
if err != nil {
|
||||
_ = out.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
l, err := netlink.LinkByName(n)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if v4 != nil {
|
||||
addr := netlink.Addr{
|
||||
IPNet: &net.IPNet{
|
||||
IP: v4.AsSlice(),
|
||||
Mask: net.CIDRMask(v4.BitLen(), v4.BitLen()),
|
||||
},
|
||||
}
|
||||
out.linkAddrs = append(out.linkAddrs, addr)
|
||||
}
|
||||
if v6 != nil {
|
||||
addr := netlink.Addr{
|
||||
IPNet: &net.IPNet{
|
||||
IP: v6.AsSlice(),
|
||||
Mask: net.CIDRMask(v6.BitLen(), v6.BitLen()),
|
||||
},
|
||||
}
|
||||
out.linkAddrs = append(out.linkAddrs, addr)
|
||||
|
||||
rt := &netlink.Route{
|
||||
LinkIndex: l.Attrs().Index,
|
||||
Dst: &net.IPNet{
|
||||
IP: net.IPv6zero,
|
||||
Mask: net.CIDRMask(0, 128),
|
||||
},
|
||||
Table: ipv6TableIndex,
|
||||
}
|
||||
out.routes = append(out.routes, rt)
|
||||
|
||||
r := netlink.NewRule()
|
||||
r.Table, r.Family, r.Src = ipv6TableIndex, unix.AF_INET6, addr.IPNet
|
||||
out.rules = append(out.rules, r)
|
||||
}
|
||||
|
||||
for _, addr := range out.linkAddrs {
|
||||
if err = out.handle.AddrAdd(l, &addr); err != nil {
|
||||
return nil, fmt.Errorf("failed to add address %s to %s: %w", addr, n, err)
|
||||
}
|
||||
}
|
||||
if err = out.handle.LinkSetMTU(l, mtu); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = out.handle.LinkSetUp(l); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, route := range out.routes {
|
||||
if err = out.handle.RouteAdd(route); err != nil {
|
||||
return nil, fmt.Errorf("failed to add route %s: %w", route, err)
|
||||
}
|
||||
}
|
||||
for _, rule := range out.rules {
|
||||
if err = out.handle.RuleAdd(rule); err != nil {
|
||||
return nil, fmt.Errorf("failed to add rule %s: %w", rule, err)
|
||||
}
|
||||
}
|
||||
out.tun = wgt
|
||||
return out, nil
|
||||
}
|
|
@ -24,10 +24,11 @@ import (
|
|||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
stdnet "net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/sagernet/wireguard-go/device"
|
||||
"github.com/xtls/xray-core/common"
|
||||
"github.com/xtls/xray-core/common/buf"
|
||||
"github.com/xtls/xray-core/common/log"
|
||||
|
@ -46,13 +47,15 @@ import (
|
|||
// Handler is an outbound connection that silently swallow the entire payload.
|
||||
type Handler struct {
|
||||
conf *DeviceConfig
|
||||
net *Net
|
||||
net Tunnel
|
||||
bind *netBindClient
|
||||
policyManager policy.Manager
|
||||
dns dns.Client
|
||||
// cached configuration
|
||||
ipc string
|
||||
endpoints []netip.Addr
|
||||
ipc string
|
||||
endpoints []netip.Addr
|
||||
hasIPv4, hasIPv6 bool
|
||||
wgLock sync.Mutex
|
||||
}
|
||||
|
||||
// New creates a new wireguard handler.
|
||||
|
@ -64,15 +67,71 @@ func New(ctx context.Context, conf *DeviceConfig) (*Handler, error) {
|
|||
return nil, err
|
||||
}
|
||||
|
||||
hasIPv4, hasIPv6 := false, false
|
||||
for _, e := range endpoints {
|
||||
if e.Is4() {
|
||||
hasIPv4 = true
|
||||
}
|
||||
if e.Is6() {
|
||||
hasIPv6 = true
|
||||
}
|
||||
}
|
||||
|
||||
d := v.GetFeature(dns.ClientType()).(dns.Client)
|
||||
return &Handler{
|
||||
conf: conf,
|
||||
policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager),
|
||||
dns: v.GetFeature(dns.ClientType()).(dns.Client),
|
||||
ipc: createIPCRequest(conf),
|
||||
dns: d,
|
||||
ipc: createIPCRequest(conf, d, hasIPv6),
|
||||
endpoints: endpoints,
|
||||
hasIPv4: hasIPv4,
|
||||
hasIPv6: hasIPv6,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (h *Handler) processWireGuard(dialer internet.Dialer) (err error) {
|
||||
h.wgLock.Lock()
|
||||
defer h.wgLock.Unlock()
|
||||
|
||||
if h.bind != nil && h.bind.dialer == dialer && h.net != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Record(&log.GeneralMessage{
|
||||
Severity: log.Severity_Info,
|
||||
Content: "switching dialer",
|
||||
})
|
||||
|
||||
if h.net != nil {
|
||||
_ = h.net.Close()
|
||||
h.net = nil
|
||||
}
|
||||
if h.bind != nil {
|
||||
_ = h.bind.Close()
|
||||
h.bind = nil
|
||||
}
|
||||
|
||||
// bind := conn.NewStdNetBind() // TODO: conn.Bind wrapper for dialer
|
||||
bind := &netBindClient{
|
||||
dialer: dialer,
|
||||
workers: int(h.conf.NumWorkers),
|
||||
dns: h.dns,
|
||||
reserved: h.conf.Reserved,
|
||||
}
|
||||
defer func() {
|
||||
if err != nil {
|
||||
_ = bind.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
h.net, err = h.makeVirtualTun(bind)
|
||||
if err != nil {
|
||||
return newError("failed to create virtual tun interface").Base(err)
|
||||
}
|
||||
h.bind = bind
|
||||
return nil
|
||||
}
|
||||
|
||||
// Process implements OutboundHandler.Dispatch().
|
||||
func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer internet.Dialer) error {
|
||||
outbound := session.OutboundFromContext(ctx)
|
||||
|
@ -85,30 +144,8 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte
|
|||
inbound.SetCanSpliceCopy(3)
|
||||
}
|
||||
|
||||
if h.bind == nil || h.bind.dialer != dialer || h.net == nil {
|
||||
log.Record(&log.GeneralMessage{
|
||||
Severity: log.Severity_Info,
|
||||
Content: "switching dialer",
|
||||
})
|
||||
// bind := conn.NewStdNetBind() // TODO: conn.Bind wrapper for dialer
|
||||
bind := &netBindClient{
|
||||
dialer: dialer,
|
||||
workers: int(h.conf.NumWorkers),
|
||||
dns: h.dns,
|
||||
reserved: h.conf.Reserved,
|
||||
}
|
||||
|
||||
net, err := h.makeVirtualTun(bind)
|
||||
if err != nil {
|
||||
bind.Close()
|
||||
return newError("failed to create virtual tun interface").Base(err)
|
||||
}
|
||||
|
||||
h.net = net
|
||||
if h.bind != nil {
|
||||
h.bind.Close()
|
||||
}
|
||||
h.bind = bind
|
||||
if err := h.processWireGuard(dialer); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Destination of the inner request.
|
||||
|
@ -122,8 +159,8 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte
|
|||
addr := destination.Address
|
||||
if addr.Family().IsDomain() {
|
||||
ips, err := h.dns.LookupIP(addr.Domain(), dns.IPOption{
|
||||
IPv4Enable: h.net.HasV4(),
|
||||
IPv6Enable: h.net.HasV6(),
|
||||
IPv4Enable: h.hasIPv4,
|
||||
IPv6Enable: h.hasIPv6,
|
||||
})
|
||||
if err != nil {
|
||||
return newError("failed to lookup DNS").Base(err)
|
||||
|
@ -200,14 +237,26 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte
|
|||
}
|
||||
|
||||
// serialize the config into an IPC request
|
||||
func createIPCRequest(conf *DeviceConfig) string {
|
||||
func createIPCRequest(conf *DeviceConfig, d dns.Client, resolveEndPointToV4 bool) string {
|
||||
var request bytes.Buffer
|
||||
|
||||
request.WriteString(fmt.Sprintf("private_key=%s\n", conf.SecretKey))
|
||||
|
||||
for _, peer := range conf.Peers {
|
||||
endpoint := peer.Endpoint
|
||||
host, port, err := net.SplitHostPort(endpoint)
|
||||
if resolveEndPointToV4 && err == nil {
|
||||
_, err = netip.ParseAddr(host)
|
||||
if err != nil {
|
||||
ipList, err := d.LookupIP(host, dns.IPOption{IPv4Enable: true, IPv6Enable: false})
|
||||
if err == nil && len(ipList) > 0 {
|
||||
endpoint = stdnet.JoinHostPort(ipList[0].String(), port)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
request.WriteString(fmt.Sprintf("public_key=%s\nendpoint=%s\npersistent_keepalive_interval=%d\npreshared_key=%s\n",
|
||||
peer.PublicKey, peer.Endpoint, peer.KeepAlive, peer.PreSharedKey))
|
||||
peer.PublicKey, endpoint, peer.KeepAlive, peer.PreSharedKey))
|
||||
|
||||
for _, ip := range peer.AllowedIps {
|
||||
request.WriteString(fmt.Sprintf("allowed_ip=%s\n", ip))
|
||||
|
@ -245,41 +294,20 @@ func parseEndpoints(conf *DeviceConfig) ([]netip.Addr, error) {
|
|||
}
|
||||
|
||||
// creates a tun interface on netstack given a configuration
|
||||
func (h *Handler) makeVirtualTun(bind *netBindClient) (*Net, error) {
|
||||
tun, tnet, err := CreateNetTUN(h.endpoints, h.dns, int(h.conf.Mtu))
|
||||
func (h *Handler) makeVirtualTun(bind *netBindClient) (Tunnel, error) {
|
||||
t, err := CreateTun(h.endpoints, int(h.conf.Mtu))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
bind.dnsOption.IPv4Enable = tnet.HasV4()
|
||||
bind.dnsOption.IPv6Enable = tnet.HasV6()
|
||||
bind.dnsOption.IPv4Enable = h.hasIPv4
|
||||
bind.dnsOption.IPv6Enable = h.hasIPv6
|
||||
|
||||
// dev := device.NewDevice(tun, conn.NewDefaultBind(), nil /* device.NewLogger(device.LogLevelVerbose, "") */)
|
||||
dev := device.NewDevice(tun, bind, &device.Logger{
|
||||
Verbosef: func(format string, args ...any) {
|
||||
log.Record(&log.GeneralMessage{
|
||||
Severity: log.Severity_Debug,
|
||||
Content: fmt.Sprintf(format, args...),
|
||||
})
|
||||
},
|
||||
Errorf: func(format string, args ...any) {
|
||||
log.Record(&log.GeneralMessage{
|
||||
Severity: log.Severity_Error,
|
||||
Content: fmt.Sprintf(format, args...),
|
||||
})
|
||||
},
|
||||
}, int(h.conf.NumWorkers))
|
||||
err = dev.IpcSet(h.ipc)
|
||||
if err != nil {
|
||||
if err = t.BuildDevice(h.ipc, bind); err != nil {
|
||||
_ = t.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = dev.Up()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return tnet, nil
|
||||
return t, nil
|
||||
}
|
||||
|
||||
func init() {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue