WireGuard Inbound (User-space WireGuard server) (#2477)

* feat: wireguard inbound

* feat(command): generate wireguard compatible keypair

* feat(wireguard): connection idle timeout

* fix(wireguard): close endpoint after connection closed

* fix(wireguard): resolve conflicts

* feat(wireguard): set cubic as default cc algorithm in gVisor TUN

* chore(wireguard): resolve conflict

* chore(wireguard): remove redurant code

* chore(wireguard): remove redurant code

* feat: rework server for gvisor tun

* feat: keep user-space tun as an option

* fix: exclude android from native tun build

* feat: auto kernel tun

* fix: build

* fix: regulate function name & fix test
This commit is contained in:
hax0r31337 2023-11-18 11:27:17 +08:00 committed by GitHub
parent f1c81557dc
commit 0ac7da2fc8
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
17 changed files with 1049 additions and 500 deletions

View file

@ -27,48 +27,45 @@ type netReadInfo struct {
err error
}
type netBindClient struct {
workers int
dialer internet.Dialer
// reduce duplicated code
type netBind struct {
dns dns.Client
dnsOption dns.IPOption
reserved []byte
workers int
readQueue chan *netReadInfo
}
func (bind *netBindClient) ParseEndpoint(s string) (conn.Endpoint, error) {
ipStr, port, _, err := splitAddrPort(s)
// SetMark implements conn.Bind
func (bind *netBind) SetMark(mark uint32) error {
return nil
}
// ParseEndpoint implements conn.Bind
func (n *netBind) ParseEndpoint(s string) (conn.Endpoint, error) {
ipStr, port, err := net.SplitHostPort(s)
if err != nil {
return nil, err
}
portNum, err := strconv.Atoi(port)
if err != nil {
return nil, err
}
var addr net.IP
if IsDomainName(ipStr) {
ips, err := bind.dns.LookupIP(ipStr, bind.dnsOption)
addr := xnet.ParseAddress(ipStr)
if addr.Family() == xnet.AddressFamilyDomain {
ips, err := n.dns.LookupIP(addr.Domain(), n.dnsOption)
if err != nil {
return nil, err
} else if len(ips) == 0 {
return nil, dns.ErrEmptyResponse
}
addr = ips[0]
} else {
addr = net.ParseIP(ipStr)
}
if addr == nil {
return nil, errors.New("failed to parse ip: " + ipStr)
}
var ip xnet.Address
if p4 := addr.To4(); len(p4) == net.IPv4len {
ip = xnet.IPAddress(p4[:])
} else {
ip = xnet.IPAddress(addr[:])
addr = xnet.IPAddress(ips[0])
}
dst := xnet.Destination{
Address: ip,
Port: xnet.Port(port),
Address: addr,
Port: xnet.Port(portNum),
Network: xnet.Network_UDP,
}
@ -77,7 +74,13 @@ func (bind *netBindClient) ParseEndpoint(s string) (conn.Endpoint, error) {
}, nil
}
func (bind *netBindClient) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) {
// BatchSize implements conn.Bind
func (bind *netBind) BatchSize() int {
return 1
}
// Open implements conn.Bind
func (bind *netBind) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) {
bind.readQueue = make(chan *netReadInfo)
fun := func(bufs [][]byte, sizes []int, eps []conn.Endpoint) (n int, err error) {
@ -109,13 +112,21 @@ func (bind *netBindClient) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error
return arr, uint16(uport), nil
}
func (bind *netBindClient) Close() error {
// Close implements conn.Bind
func (bind *netBind) Close() error {
if bind.readQueue != nil {
close(bind.readQueue)
}
return nil
}
type netBindClient struct {
netBind
dialer internet.Dialer
reserved []byte
}
func (bind *netBindClient) connectTo(endpoint *netEndpoint) error {
c, err := bind.dialer.Dial(context.Background(), endpoint.dst)
if err != nil {
@ -177,12 +188,29 @@ func (bind *netBindClient) Send(buff [][]byte, endpoint conn.Endpoint) error {
return nil
}
func (bind *netBindClient) SetMark(mark uint32) error {
return nil
type netBindServer struct {
netBind
}
func (bind *netBindClient) BatchSize() int {
return 1
func (bind *netBindServer) Send(buff [][]byte, endpoint conn.Endpoint) error {
var err error
nend, ok := endpoint.(*netEndpoint)
if !ok {
return conn.ErrWrongEndpointType
}
if nend.conn == nil {
return newError("connection not open yet")
}
for _, buff := range buff {
if _, err = nend.conn.Write(buff); err != nil {
return err
}
}
return err
}
type netEndpoint struct {
@ -193,7 +221,7 @@ type netEndpoint struct {
func (netEndpoint) ClearSrc() {}
func (e netEndpoint) DstIP() netip.Addr {
return toNetIpAddr(e.dst.Address)
return netip.Addr{}
}
func (e netEndpoint) SrcIP() netip.Addr {
@ -232,83 +260,3 @@ func toNetIpAddr(addr xnet.Address) netip.Addr {
return netip.AddrFrom16(arr)
}
}
func stringsLastIndexByte(s string, b byte) int {
for i := len(s) - 1; i >= 0; i-- {
if s[i] == b {
return i
}
}
return -1
}
func splitAddrPort(s string) (ip string, port uint16, v6 bool, err error) {
i := stringsLastIndexByte(s, ':')
if i == -1 {
return "", 0, false, errors.New("not an ip:port")
}
ip = s[:i]
portStr := s[i+1:]
if len(ip) == 0 {
return "", 0, false, errors.New("no IP")
}
if len(portStr) == 0 {
return "", 0, false, errors.New("no port")
}
port64, err := strconv.ParseUint(portStr, 10, 16)
if err != nil {
return "", 0, false, errors.New("invalid port " + strconv.Quote(portStr) + " parsing " + strconv.Quote(s))
}
port = uint16(port64)
if ip[0] == '[' {
if len(ip) < 2 || ip[len(ip)-1] != ']' {
return "", 0, false, errors.New("missing ]")
}
ip = ip[1 : len(ip)-1]
v6 = true
}
return ip, port, v6, nil
}
func IsDomainName(s string) bool {
l := len(s)
if l == 0 || l > 254 || l == 254 && s[l-1] != '.' {
return false
}
last := byte('.')
nonNumeric := false
partlen := 0
for i := 0; i < len(s); i++ {
c := s[i]
switch {
default:
return false
case 'a' <= c && c <= 'z' || 'A' <= c && c <= 'Z' || c == '_':
nonNumeric = true
partlen++
case '0' <= c && c <= '9':
partlen++
case c == '-':
if last == '.' {
return false
}
partlen++
nonNumeric = true
case c == '.':
if last == '.' || last == '-' {
return false
}
if partlen > 63 || partlen == 0 {
return false
}
partlen = 0
}
last = c
}
if last == '-' || partlen > 63 {
return false
}
return nonNumeric
}