feat : upgrade wireguard go sdk (#2716)

Co-authored-by: kunson <kunson@kunsondeMacBook-Pro-3.local>
Co-authored-by: 世界 <i@sekai.icu>
This commit is contained in:
yuhan6665 2023-11-12 15:10:01 -05:00 committed by GitHub
parent ea67c98eaf
commit 5ae3791a8e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 514 additions and 362 deletions

View file

@ -24,10 +24,11 @@ import (
"bytes"
"context"
"fmt"
stdnet "net"
"net/netip"
"strings"
"sync"
"github.com/sagernet/wireguard-go/device"
"github.com/xtls/xray-core/common"
"github.com/xtls/xray-core/common/buf"
"github.com/xtls/xray-core/common/log"
@ -46,13 +47,15 @@ import (
// Handler is an outbound connection that silently swallow the entire payload.
type Handler struct {
conf *DeviceConfig
net *Net
net Tunnel
bind *netBindClient
policyManager policy.Manager
dns dns.Client
// cached configuration
ipc string
endpoints []netip.Addr
ipc string
endpoints []netip.Addr
hasIPv4, hasIPv6 bool
wgLock sync.Mutex
}
// New creates a new wireguard handler.
@ -64,15 +67,71 @@ func New(ctx context.Context, conf *DeviceConfig) (*Handler, error) {
return nil, err
}
hasIPv4, hasIPv6 := false, false
for _, e := range endpoints {
if e.Is4() {
hasIPv4 = true
}
if e.Is6() {
hasIPv6 = true
}
}
d := v.GetFeature(dns.ClientType()).(dns.Client)
return &Handler{
conf: conf,
policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager),
dns: v.GetFeature(dns.ClientType()).(dns.Client),
ipc: createIPCRequest(conf),
dns: d,
ipc: createIPCRequest(conf, d, hasIPv6),
endpoints: endpoints,
hasIPv4: hasIPv4,
hasIPv6: hasIPv6,
}, nil
}
func (h *Handler) processWireGuard(dialer internet.Dialer) (err error) {
h.wgLock.Lock()
defer h.wgLock.Unlock()
if h.bind != nil && h.bind.dialer == dialer && h.net != nil {
return nil
}
log.Record(&log.GeneralMessage{
Severity: log.Severity_Info,
Content: "switching dialer",
})
if h.net != nil {
_ = h.net.Close()
h.net = nil
}
if h.bind != nil {
_ = h.bind.Close()
h.bind = nil
}
// bind := conn.NewStdNetBind() // TODO: conn.Bind wrapper for dialer
bind := &netBindClient{
dialer: dialer,
workers: int(h.conf.NumWorkers),
dns: h.dns,
reserved: h.conf.Reserved,
}
defer func() {
if err != nil {
_ = bind.Close()
}
}()
h.net, err = h.makeVirtualTun(bind)
if err != nil {
return newError("failed to create virtual tun interface").Base(err)
}
h.bind = bind
return nil
}
// Process implements OutboundHandler.Dispatch().
func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer internet.Dialer) error {
outbound := session.OutboundFromContext(ctx)
@ -85,30 +144,8 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte
inbound.SetCanSpliceCopy(3)
}
if h.bind == nil || h.bind.dialer != dialer || h.net == nil {
log.Record(&log.GeneralMessage{
Severity: log.Severity_Info,
Content: "switching dialer",
})
// bind := conn.NewStdNetBind() // TODO: conn.Bind wrapper for dialer
bind := &netBindClient{
dialer: dialer,
workers: int(h.conf.NumWorkers),
dns: h.dns,
reserved: h.conf.Reserved,
}
net, err := h.makeVirtualTun(bind)
if err != nil {
bind.Close()
return newError("failed to create virtual tun interface").Base(err)
}
h.net = net
if h.bind != nil {
h.bind.Close()
}
h.bind = bind
if err := h.processWireGuard(dialer); err != nil {
return err
}
// Destination of the inner request.
@ -122,8 +159,8 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte
addr := destination.Address
if addr.Family().IsDomain() {
ips, err := h.dns.LookupIP(addr.Domain(), dns.IPOption{
IPv4Enable: h.net.HasV4(),
IPv6Enable: h.net.HasV6(),
IPv4Enable: h.hasIPv4,
IPv6Enable: h.hasIPv6,
})
if err != nil {
return newError("failed to lookup DNS").Base(err)
@ -200,14 +237,26 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte
}
// serialize the config into an IPC request
func createIPCRequest(conf *DeviceConfig) string {
func createIPCRequest(conf *DeviceConfig, d dns.Client, resolveEndPointToV4 bool) string {
var request bytes.Buffer
request.WriteString(fmt.Sprintf("private_key=%s\n", conf.SecretKey))
for _, peer := range conf.Peers {
endpoint := peer.Endpoint
host, port, err := net.SplitHostPort(endpoint)
if resolveEndPointToV4 && err == nil {
_, err = netip.ParseAddr(host)
if err != nil {
ipList, err := d.LookupIP(host, dns.IPOption{IPv4Enable: true, IPv6Enable: false})
if err == nil && len(ipList) > 0 {
endpoint = stdnet.JoinHostPort(ipList[0].String(), port)
}
}
}
request.WriteString(fmt.Sprintf("public_key=%s\nendpoint=%s\npersistent_keepalive_interval=%d\npreshared_key=%s\n",
peer.PublicKey, peer.Endpoint, peer.KeepAlive, peer.PreSharedKey))
peer.PublicKey, endpoint, peer.KeepAlive, peer.PreSharedKey))
for _, ip := range peer.AllowedIps {
request.WriteString(fmt.Sprintf("allowed_ip=%s\n", ip))
@ -245,41 +294,20 @@ func parseEndpoints(conf *DeviceConfig) ([]netip.Addr, error) {
}
// creates a tun interface on netstack given a configuration
func (h *Handler) makeVirtualTun(bind *netBindClient) (*Net, error) {
tun, tnet, err := CreateNetTUN(h.endpoints, h.dns, int(h.conf.Mtu))
func (h *Handler) makeVirtualTun(bind *netBindClient) (Tunnel, error) {
t, err := CreateTun(h.endpoints, int(h.conf.Mtu))
if err != nil {
return nil, err
}
bind.dnsOption.IPv4Enable = tnet.HasV4()
bind.dnsOption.IPv6Enable = tnet.HasV6()
bind.dnsOption.IPv4Enable = h.hasIPv4
bind.dnsOption.IPv6Enable = h.hasIPv6
// dev := device.NewDevice(tun, conn.NewDefaultBind(), nil /* device.NewLogger(device.LogLevelVerbose, "") */)
dev := device.NewDevice(tun, bind, &device.Logger{
Verbosef: func(format string, args ...any) {
log.Record(&log.GeneralMessage{
Severity: log.Severity_Debug,
Content: fmt.Sprintf(format, args...),
})
},
Errorf: func(format string, args ...any) {
log.Record(&log.GeneralMessage{
Severity: log.Severity_Error,
Content: fmt.Sprintf(format, args...),
})
},
}, int(h.conf.NumWorkers))
err = dev.IpcSet(h.ipc)
if err != nil {
if err = t.BuildDevice(h.ipc, bind); err != nil {
_ = t.Close()
return nil, err
}
err = dev.Up()
if err != nil {
return nil, err
}
return tnet, nil
return t, nil
}
func init() {