mirror of
https://github.com/XTLS/Xray-core.git
synced 2025-04-30 01:08:33 +00:00
feat : upgrade wireguard go sdk (#2716)
Co-authored-by: kunson <kunson@kunsondeMacBook-Pro-3.local> Co-authored-by: 世界 <i@sekai.icu>
This commit is contained in:
parent
ea67c98eaf
commit
5ae3791a8e
7 changed files with 514 additions and 362 deletions
|
@ -24,10 +24,11 @@ import (
|
|||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
stdnet "net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/sagernet/wireguard-go/device"
|
||||
"github.com/xtls/xray-core/common"
|
||||
"github.com/xtls/xray-core/common/buf"
|
||||
"github.com/xtls/xray-core/common/log"
|
||||
|
@ -46,13 +47,15 @@ import (
|
|||
// Handler is an outbound connection that silently swallow the entire payload.
|
||||
type Handler struct {
|
||||
conf *DeviceConfig
|
||||
net *Net
|
||||
net Tunnel
|
||||
bind *netBindClient
|
||||
policyManager policy.Manager
|
||||
dns dns.Client
|
||||
// cached configuration
|
||||
ipc string
|
||||
endpoints []netip.Addr
|
||||
ipc string
|
||||
endpoints []netip.Addr
|
||||
hasIPv4, hasIPv6 bool
|
||||
wgLock sync.Mutex
|
||||
}
|
||||
|
||||
// New creates a new wireguard handler.
|
||||
|
@ -64,15 +67,71 @@ func New(ctx context.Context, conf *DeviceConfig) (*Handler, error) {
|
|||
return nil, err
|
||||
}
|
||||
|
||||
hasIPv4, hasIPv6 := false, false
|
||||
for _, e := range endpoints {
|
||||
if e.Is4() {
|
||||
hasIPv4 = true
|
||||
}
|
||||
if e.Is6() {
|
||||
hasIPv6 = true
|
||||
}
|
||||
}
|
||||
|
||||
d := v.GetFeature(dns.ClientType()).(dns.Client)
|
||||
return &Handler{
|
||||
conf: conf,
|
||||
policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager),
|
||||
dns: v.GetFeature(dns.ClientType()).(dns.Client),
|
||||
ipc: createIPCRequest(conf),
|
||||
dns: d,
|
||||
ipc: createIPCRequest(conf, d, hasIPv6),
|
||||
endpoints: endpoints,
|
||||
hasIPv4: hasIPv4,
|
||||
hasIPv6: hasIPv6,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (h *Handler) processWireGuard(dialer internet.Dialer) (err error) {
|
||||
h.wgLock.Lock()
|
||||
defer h.wgLock.Unlock()
|
||||
|
||||
if h.bind != nil && h.bind.dialer == dialer && h.net != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Record(&log.GeneralMessage{
|
||||
Severity: log.Severity_Info,
|
||||
Content: "switching dialer",
|
||||
})
|
||||
|
||||
if h.net != nil {
|
||||
_ = h.net.Close()
|
||||
h.net = nil
|
||||
}
|
||||
if h.bind != nil {
|
||||
_ = h.bind.Close()
|
||||
h.bind = nil
|
||||
}
|
||||
|
||||
// bind := conn.NewStdNetBind() // TODO: conn.Bind wrapper for dialer
|
||||
bind := &netBindClient{
|
||||
dialer: dialer,
|
||||
workers: int(h.conf.NumWorkers),
|
||||
dns: h.dns,
|
||||
reserved: h.conf.Reserved,
|
||||
}
|
||||
defer func() {
|
||||
if err != nil {
|
||||
_ = bind.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
h.net, err = h.makeVirtualTun(bind)
|
||||
if err != nil {
|
||||
return newError("failed to create virtual tun interface").Base(err)
|
||||
}
|
||||
h.bind = bind
|
||||
return nil
|
||||
}
|
||||
|
||||
// Process implements OutboundHandler.Dispatch().
|
||||
func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer internet.Dialer) error {
|
||||
outbound := session.OutboundFromContext(ctx)
|
||||
|
@ -85,30 +144,8 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte
|
|||
inbound.SetCanSpliceCopy(3)
|
||||
}
|
||||
|
||||
if h.bind == nil || h.bind.dialer != dialer || h.net == nil {
|
||||
log.Record(&log.GeneralMessage{
|
||||
Severity: log.Severity_Info,
|
||||
Content: "switching dialer",
|
||||
})
|
||||
// bind := conn.NewStdNetBind() // TODO: conn.Bind wrapper for dialer
|
||||
bind := &netBindClient{
|
||||
dialer: dialer,
|
||||
workers: int(h.conf.NumWorkers),
|
||||
dns: h.dns,
|
||||
reserved: h.conf.Reserved,
|
||||
}
|
||||
|
||||
net, err := h.makeVirtualTun(bind)
|
||||
if err != nil {
|
||||
bind.Close()
|
||||
return newError("failed to create virtual tun interface").Base(err)
|
||||
}
|
||||
|
||||
h.net = net
|
||||
if h.bind != nil {
|
||||
h.bind.Close()
|
||||
}
|
||||
h.bind = bind
|
||||
if err := h.processWireGuard(dialer); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Destination of the inner request.
|
||||
|
@ -122,8 +159,8 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte
|
|||
addr := destination.Address
|
||||
if addr.Family().IsDomain() {
|
||||
ips, err := h.dns.LookupIP(addr.Domain(), dns.IPOption{
|
||||
IPv4Enable: h.net.HasV4(),
|
||||
IPv6Enable: h.net.HasV6(),
|
||||
IPv4Enable: h.hasIPv4,
|
||||
IPv6Enable: h.hasIPv6,
|
||||
})
|
||||
if err != nil {
|
||||
return newError("failed to lookup DNS").Base(err)
|
||||
|
@ -200,14 +237,26 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte
|
|||
}
|
||||
|
||||
// serialize the config into an IPC request
|
||||
func createIPCRequest(conf *DeviceConfig) string {
|
||||
func createIPCRequest(conf *DeviceConfig, d dns.Client, resolveEndPointToV4 bool) string {
|
||||
var request bytes.Buffer
|
||||
|
||||
request.WriteString(fmt.Sprintf("private_key=%s\n", conf.SecretKey))
|
||||
|
||||
for _, peer := range conf.Peers {
|
||||
endpoint := peer.Endpoint
|
||||
host, port, err := net.SplitHostPort(endpoint)
|
||||
if resolveEndPointToV4 && err == nil {
|
||||
_, err = netip.ParseAddr(host)
|
||||
if err != nil {
|
||||
ipList, err := d.LookupIP(host, dns.IPOption{IPv4Enable: true, IPv6Enable: false})
|
||||
if err == nil && len(ipList) > 0 {
|
||||
endpoint = stdnet.JoinHostPort(ipList[0].String(), port)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
request.WriteString(fmt.Sprintf("public_key=%s\nendpoint=%s\npersistent_keepalive_interval=%d\npreshared_key=%s\n",
|
||||
peer.PublicKey, peer.Endpoint, peer.KeepAlive, peer.PreSharedKey))
|
||||
peer.PublicKey, endpoint, peer.KeepAlive, peer.PreSharedKey))
|
||||
|
||||
for _, ip := range peer.AllowedIps {
|
||||
request.WriteString(fmt.Sprintf("allowed_ip=%s\n", ip))
|
||||
|
@ -245,41 +294,20 @@ func parseEndpoints(conf *DeviceConfig) ([]netip.Addr, error) {
|
|||
}
|
||||
|
||||
// creates a tun interface on netstack given a configuration
|
||||
func (h *Handler) makeVirtualTun(bind *netBindClient) (*Net, error) {
|
||||
tun, tnet, err := CreateNetTUN(h.endpoints, h.dns, int(h.conf.Mtu))
|
||||
func (h *Handler) makeVirtualTun(bind *netBindClient) (Tunnel, error) {
|
||||
t, err := CreateTun(h.endpoints, int(h.conf.Mtu))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
bind.dnsOption.IPv4Enable = tnet.HasV4()
|
||||
bind.dnsOption.IPv6Enable = tnet.HasV6()
|
||||
bind.dnsOption.IPv4Enable = h.hasIPv4
|
||||
bind.dnsOption.IPv6Enable = h.hasIPv6
|
||||
|
||||
// dev := device.NewDevice(tun, conn.NewDefaultBind(), nil /* device.NewLogger(device.LogLevelVerbose, "") */)
|
||||
dev := device.NewDevice(tun, bind, &device.Logger{
|
||||
Verbosef: func(format string, args ...any) {
|
||||
log.Record(&log.GeneralMessage{
|
||||
Severity: log.Severity_Debug,
|
||||
Content: fmt.Sprintf(format, args...),
|
||||
})
|
||||
},
|
||||
Errorf: func(format string, args ...any) {
|
||||
log.Record(&log.GeneralMessage{
|
||||
Severity: log.Severity_Error,
|
||||
Content: fmt.Sprintf(format, args...),
|
||||
})
|
||||
},
|
||||
}, int(h.conf.NumWorkers))
|
||||
err = dev.IpcSet(h.ipc)
|
||||
if err != nil {
|
||||
if err = t.BuildDevice(h.ipc, bind); err != nil {
|
||||
_ = t.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = dev.Up()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return tnet, nil
|
||||
return t, nil
|
||||
}
|
||||
|
||||
func init() {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue