package wireguard import ( "context" "errors" "io" "net" "net/netip" "strconv" "sync" "github.com/sagernet/wireguard-go/conn" xnet "github.com/xtls/xray-core/common/net" "github.com/xtls/xray-core/features/dns" "github.com/xtls/xray-core/transport/internet" ) type netReadInfo struct { // status waiter sync.WaitGroup // param buff []byte // result bytes int endpoint conn.Endpoint err error } type netBindClient struct { workers int dialer internet.Dialer dns dns.Client dnsOption dns.IPOption reserved []byte readQueue chan *netReadInfo } func (n *netBindClient) ParseEndpoint(s string) (conn.Endpoint, error) { ipStr, port, _, err := splitAddrPort(s) if err != nil { return nil, err } var addr net.IP if IsDomainName(ipStr) { ips, err := n.dns.LookupIP(ipStr, n.dnsOption) if err != nil { return nil, err } else if len(ips) == 0 { return nil, dns.ErrEmptyResponse } addr = ips[0] } else { addr = net.ParseIP(ipStr) } if addr == nil { return nil, errors.New("failed to parse ip: " + ipStr) } var ip xnet.Address if p4 := addr.To4(); len(p4) == net.IPv4len { ip = xnet.IPAddress(p4[:]) } else { ip = xnet.IPAddress(addr[:]) } dst := xnet.Destination{ Address: ip, Port: xnet.Port(port), Network: xnet.Network_UDP, } return &netEndpoint{ dst: dst, }, nil } func (bind *netBindClient) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) { bind.readQueue = make(chan *netReadInfo) fun := func(buff []byte) (cap int, ep conn.Endpoint, err error) { defer func() { if r := recover(); r != nil { cap = 0 ep = nil err = errors.New("channel closed") } }() r := &netReadInfo{ buff: buff, } r.waiter.Add(1) bind.readQueue <- r r.waiter.Wait() // wait read goroutine done, or we will miss the result return r.bytes, r.endpoint, r.err } workers := bind.workers if workers <= 0 { workers = 1 } arr := make([]conn.ReceiveFunc, workers) for i := 0; i < workers; i++ { arr[i] = fun } return arr, uint16(uport), nil } func (bind *netBindClient) Close() error { if bind.readQueue != nil { close(bind.readQueue) } return nil } func (bind *netBindClient) connectTo(endpoint *netEndpoint) error { c, err := bind.dialer.Dial(context.Background(), endpoint.dst) if err != nil { return err } endpoint.conn = c go func(readQueue <-chan *netReadInfo, endpoint *netEndpoint) { for { v, ok := <-readQueue if !ok { return } i, err := c.Read(v.buff) v.bytes = i v.endpoint = endpoint v.err = err v.waiter.Done() if err != nil && errors.Is(err, io.EOF) { endpoint.conn = nil return } } }(bind.readQueue, endpoint) return nil } func (bind *netBindClient) Send(buff []byte, endpoint conn.Endpoint) error { var err error nend, ok := endpoint.(*netEndpoint) if !ok { return conn.ErrWrongEndpointType } if nend.conn == nil { err = bind.connectTo(nend) if err != nil { return err } } if len(buff) > 3 && len(bind.reserved) == 3 { copy(buff[1:], bind.reserved) } _, err = nend.conn.Write(buff) return err } func (bind *netBindClient) SetMark(mark uint32) error { return nil } type netEndpoint struct { dst xnet.Destination conn net.Conn } func (netEndpoint) ClearSrc() {} func (e netEndpoint) DstIP() netip.Addr { return toNetIpAddr(e.dst.Address) } func (e netEndpoint) SrcIP() netip.Addr { return netip.Addr{} } func (e netEndpoint) DstToBytes() []byte { var dat []byte if e.dst.Address.Family().IsIPv4() { dat = e.dst.Address.IP().To4()[:] } else { dat = e.dst.Address.IP().To16()[:] } dat = append(dat, byte(e.dst.Port), byte(e.dst.Port>>8)) return dat } func (e netEndpoint) DstToString() string { return e.dst.NetAddr() } func (e netEndpoint) SrcToString() string { return "" } func toNetIpAddr(addr xnet.Address) netip.Addr { if addr.Family().IsIPv4() { ip := addr.IP() return netip.AddrFrom4([4]byte{ip[0], ip[1], ip[2], ip[3]}) } else { ip := addr.IP() arr := [16]byte{} for i := 0; i < 16; i++ { arr[i] = ip[i] } return netip.AddrFrom16(arr) } } func stringsLastIndexByte(s string, b byte) int { for i := len(s) - 1; i >= 0; i-- { if s[i] == b { return i } } return -1 } func splitAddrPort(s string) (ip string, port uint16, v6 bool, err error) { i := stringsLastIndexByte(s, ':') if i == -1 { return "", 0, false, errors.New("not an ip:port") } ip = s[:i] portStr := s[i+1:] if len(ip) == 0 { return "", 0, false, errors.New("no IP") } if len(portStr) == 0 { return "", 0, false, errors.New("no port") } port64, err := strconv.ParseUint(portStr, 10, 16) if err != nil { return "", 0, false, errors.New("invalid port " + strconv.Quote(portStr) + " parsing " + strconv.Quote(s)) } port = uint16(port64) if ip[0] == '[' { if len(ip) < 2 || ip[len(ip)-1] != ']' { return "", 0, false, errors.New("missing ]") } ip = ip[1 : len(ip)-1] v6 = true } return ip, port, v6, nil }