mirror of
https://github.com/XTLS/Xray-core.git
synced 2025-04-30 01:08:33 +00:00
Implement WireGuard protocol as outbound (client) (#1344)
* implement WireGuard protocol for Outbound * upload license * fix build for openbsd & dragonfly os * updated wireguard-go * fix up * switch to another wireguard fork * fix * switch to upstream * open connection through internet.Dialer (#1) * use internet.Dialer * maybe better code * fix * real fix Co-authored-by: nanoda0523 <nanoda0523@users.noreply.github.com> * fix bugs & add ability to recover during connection reset on UDP over TCP parent protocols * improve performance improve performance * dns lookup endpoint && remove unused code * interface address fallback * better code && add config test case Co-authored-by: nanoda0523 <nanoda0523@users.noreply.github.com>
This commit is contained in:
parent
691b2b1c73
commit
e18b52a5df
12 changed files with 1326 additions and 1 deletions
254
proxy/wireguard/bind.go
Normal file
254
proxy/wireguard/bind.go
Normal file
|
@ -0,0 +1,254 @@
|
|||
package wireguard
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strconv"
|
||||
"sync"
|
||||
|
||||
"github.com/sagernet/wireguard-go/conn"
|
||||
xnet "github.com/xtls/xray-core/common/net"
|
||||
"github.com/xtls/xray-core/features/dns"
|
||||
"github.com/xtls/xray-core/transport/internet"
|
||||
)
|
||||
|
||||
type netReadInfo struct {
|
||||
// status
|
||||
waiter sync.WaitGroup
|
||||
// param
|
||||
buff []byte
|
||||
// result
|
||||
bytes int
|
||||
endpoint conn.Endpoint
|
||||
err error
|
||||
}
|
||||
|
||||
type netBindClient struct {
|
||||
workers int
|
||||
dialer internet.Dialer
|
||||
dns dns.Client
|
||||
dnsOption dns.IPOption
|
||||
|
||||
readQueue chan *netReadInfo
|
||||
}
|
||||
|
||||
func (n *netBindClient) ParseEndpoint(s string) (conn.Endpoint, error) {
|
||||
ipStr, port, _, err := splitAddrPort(s)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var addr net.IP
|
||||
if IsDomainName(ipStr) {
|
||||
ips, err := n.dns.LookupIP(ipStr, n.dnsOption)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if len(ips) == 0 {
|
||||
return nil, dns.ErrEmptyResponse
|
||||
}
|
||||
addr = ips[0]
|
||||
} else {
|
||||
addr = net.ParseIP(ipStr)
|
||||
}
|
||||
if addr == nil {
|
||||
return nil, errors.New("failed to parse ip: " + ipStr)
|
||||
}
|
||||
|
||||
var ip xnet.Address
|
||||
if p4 := addr.To4(); len(p4) == net.IPv4len {
|
||||
ip = xnet.IPAddress(p4[:])
|
||||
} else {
|
||||
ip = xnet.IPAddress(addr[:])
|
||||
}
|
||||
|
||||
dst := xnet.Destination{
|
||||
Address: ip,
|
||||
Port: xnet.Port(port),
|
||||
Network: xnet.Network_UDP,
|
||||
}
|
||||
|
||||
return &netEndpoint{
|
||||
dst: dst,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (bind *netBindClient) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) {
|
||||
bind.readQueue = make(chan *netReadInfo)
|
||||
|
||||
fun := func(buff []byte) (cap int, ep conn.Endpoint, err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
cap = 0
|
||||
ep = nil
|
||||
err = errors.New("channel closed")
|
||||
}
|
||||
}()
|
||||
|
||||
r := &netReadInfo{
|
||||
buff: buff,
|
||||
}
|
||||
r.waiter.Add(1)
|
||||
bind.readQueue <- r
|
||||
r.waiter.Wait() // wait read goroutine done, or we will miss the result
|
||||
return r.bytes, r.endpoint, r.err
|
||||
}
|
||||
workers := bind.workers
|
||||
if workers <= 0 {
|
||||
workers = 1
|
||||
}
|
||||
arr := make([]conn.ReceiveFunc, workers)
|
||||
for i := 0; i < workers; i++ {
|
||||
arr[i] = fun
|
||||
}
|
||||
|
||||
return arr, uint16(uport), nil
|
||||
}
|
||||
|
||||
func (bind *netBindClient) Close() error {
|
||||
if bind.readQueue != nil {
|
||||
close(bind.readQueue)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (bind *netBindClient) connectTo(endpoint *netEndpoint) error {
|
||||
c, err := bind.dialer.Dial(context.Background(), endpoint.dst)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
endpoint.conn = c
|
||||
|
||||
go func(readQueue <-chan *netReadInfo, endpoint *netEndpoint) {
|
||||
for {
|
||||
v, ok := <-readQueue
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
i, err := c.Read(v.buff)
|
||||
v.bytes = i
|
||||
v.endpoint = endpoint
|
||||
v.err = err
|
||||
v.waiter.Done()
|
||||
if err != nil && errors.Is(err, io.EOF) {
|
||||
endpoint.conn = nil
|
||||
return
|
||||
}
|
||||
}
|
||||
}(bind.readQueue, endpoint)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (bind *netBindClient) Send(buff []byte, endpoint conn.Endpoint) error {
|
||||
var err error
|
||||
|
||||
nend, ok := endpoint.(*netEndpoint)
|
||||
if !ok {
|
||||
return conn.ErrWrongEndpointType
|
||||
}
|
||||
|
||||
if nend.conn == nil {
|
||||
err = bind.connectTo(nend)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
_, err = nend.conn.Write(buff)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (bind *netBindClient) SetMark(mark uint32) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type netEndpoint struct {
|
||||
dst xnet.Destination
|
||||
conn net.Conn
|
||||
}
|
||||
|
||||
func (netEndpoint) ClearSrc() {}
|
||||
|
||||
func (e netEndpoint) DstIP() netip.Addr {
|
||||
return toNetIpAddr(e.dst.Address)
|
||||
}
|
||||
|
||||
func (e netEndpoint) SrcIP() netip.Addr {
|
||||
return netip.Addr{}
|
||||
}
|
||||
|
||||
func (e netEndpoint) DstToBytes() []byte {
|
||||
var dat []byte
|
||||
if e.dst.Address.Family().IsIPv4() {
|
||||
dat = e.dst.Address.IP().To4()[:]
|
||||
} else {
|
||||
dat = e.dst.Address.IP().To16()[:]
|
||||
}
|
||||
dat = append(dat, byte(e.dst.Port), byte(e.dst.Port>>8))
|
||||
return dat
|
||||
}
|
||||
|
||||
func (e netEndpoint) DstToString() string {
|
||||
return e.dst.NetAddr()
|
||||
}
|
||||
|
||||
func (e netEndpoint) SrcToString() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func toNetIpAddr(addr xnet.Address) netip.Addr {
|
||||
if addr.Family().IsIPv4() {
|
||||
ip := addr.IP()
|
||||
return netip.AddrFrom4([4]byte{ip[0], ip[1], ip[2], ip[3]})
|
||||
} else {
|
||||
ip := addr.IP()
|
||||
arr := [16]byte{}
|
||||
for i := 0; i < 16; i++ {
|
||||
arr[i] = ip[i]
|
||||
}
|
||||
return netip.AddrFrom16(arr)
|
||||
}
|
||||
}
|
||||
|
||||
func stringsLastIndexByte(s string, b byte) int {
|
||||
for i := len(s) - 1; i >= 0; i-- {
|
||||
if s[i] == b {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
func splitAddrPort(s string) (ip string, port uint16, v6 bool, err error) {
|
||||
i := stringsLastIndexByte(s, ':')
|
||||
if i == -1 {
|
||||
return "", 0, false, errors.New("not an ip:port")
|
||||
}
|
||||
|
||||
ip = s[:i]
|
||||
portStr := s[i+1:]
|
||||
if len(ip) == 0 {
|
||||
return "", 0, false, errors.New("no IP")
|
||||
}
|
||||
if len(portStr) == 0 {
|
||||
return "", 0, false, errors.New("no port")
|
||||
}
|
||||
port64, err := strconv.ParseUint(portStr, 10, 16)
|
||||
if err != nil {
|
||||
return "", 0, false, errors.New("invalid port " + strconv.Quote(portStr) + " parsing " + strconv.Quote(s))
|
||||
}
|
||||
port = uint16(port64)
|
||||
if ip[0] == '[' {
|
||||
if len(ip) < 2 || ip[len(ip)-1] != ']' {
|
||||
return "", 0, false, errors.New("missing ]")
|
||||
}
|
||||
ip = ip[1 : len(ip)-1]
|
||||
v6 = true
|
||||
}
|
||||
|
||||
return ip, port, v6, nil
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue