diff --git a/common/log/access.go b/common/log/access.go index 87bc1afe..22b84cf6 100644 --- a/common/log/access.go +++ b/common/log/access.go @@ -36,19 +36,23 @@ func (m *AccessMessage) String() string { builder.WriteString(string(m.Status)) builder.WriteByte(' ') builder.WriteString(serial.ToString(m.To)) - builder.WriteByte(' ') + if len(m.Detour) > 0 { - builder.WriteByte('[') + builder.WriteString(" [") builder.WriteString(m.Detour) - builder.WriteString("] ") + builder.WriteByte(']') + } + + if reason := serial.ToString(m.Reason); len(reason) > 0 { + builder.WriteString(" ") + builder.WriteString(reason) } - builder.WriteString(serial.ToString(m.Reason)) if len(m.Email) > 0 { - builder.WriteString("email:") + builder.WriteString(" email: ") builder.WriteString(m.Email) - builder.WriteByte(' ') } + return builder.String() } diff --git a/common/serial/string.go b/common/serial/string.go index 70de87a7..e220e543 100644 --- a/common/serial/string.go +++ b/common/serial/string.go @@ -8,7 +8,7 @@ import ( // ToString serialize an arbitrary value into string. func ToString(v interface{}) string { if v == nil { - return " " + return "" } switch value := v.(type) { diff --git a/infra/conf/dns.go b/infra/conf/dns.go index 7e7bc0df..ea86359e 100644 --- a/infra/conf/dns.go +++ b/infra/conf/dns.go @@ -84,7 +84,7 @@ func (c *NameServerConfig) Build() (*dns.NameServer, error) { geoipList, err := toCidrList(c.ExpectIPs) if err != nil { - return nil, newError("invalid ip rule: ", c.ExpectIPs).Base(err) + return nil, newError("invalid IP rule: ", c.ExpectIPs).Base(err) } return &dns.NameServer{ @@ -142,7 +142,7 @@ func (c *DNSConfig) Build() (*dns.Config, error) { for _, server := range c.Servers { ns, err := server.Build() if err != nil { - return nil, newError("failed to build name server").Base(err) + return nil, newError("failed to build nameserver").Base(err) } config.NameServer = append(config.NameServer, ns) } @@ -159,15 +159,23 @@ func (c *DNSConfig) Build() (*dns.Config, error) { var mappings []*dns.Config_HostMapping switch { case strings.HasPrefix(domain, "domain:"): + domainName := domain[7:] + if len(domainName) == 0 { + return nil, newError("empty domain type of rule: ", domain) + } mapping := getHostMapping(addr) mapping.Type = dns.DomainMatchingType_Subdomain - mapping.Domain = domain[7:] + mapping.Domain = domainName mappings = append(mappings, mapping) case strings.HasPrefix(domain, "geosite:"): - domains, err := loadGeositeWithAttr("geosite.dat", strings.ToUpper(domain[8:])) + listName := domain[8:] + if len(listName) == 0 { + return nil, newError("empty geosite rule: ", domain) + } + domains, err := loadGeositeWithAttr("geosite.dat", listName) if err != nil { - return nil, newError("invalid geosite settings: ", domain).Base(err) + return nil, newError("failed to load geosite: ", listName).Base(err) } for _, d := range domains { mapping := getHostMapping(addr) @@ -177,21 +185,33 @@ func (c *DNSConfig) Build() (*dns.Config, error) { } case strings.HasPrefix(domain, "regexp:"): + regexpVal := domain[7:] + if len(regexpVal) == 0 { + return nil, newError("empty regexp type of rule: ", domain) + } mapping := getHostMapping(addr) mapping.Type = dns.DomainMatchingType_Regex - mapping.Domain = domain[7:] + mapping.Domain = regexpVal mappings = append(mappings, mapping) case strings.HasPrefix(domain, "keyword:"): + keywordVal := domain[8:] + if len(keywordVal) == 0 { + return nil, newError("empty keyword type of rule: ", domain) + } mapping := getHostMapping(addr) mapping.Type = dns.DomainMatchingType_Keyword - mapping.Domain = domain[8:] + mapping.Domain = keywordVal mappings = append(mappings, mapping) case strings.HasPrefix(domain, "full:"): + fullVal := domain[5:] + if len(fullVal) == 0 { + return nil, newError("empty full domain type of rule: ", domain) + } mapping := getHostMapping(addr) mapping.Type = dns.DomainMatchingType_Full - mapping.Domain = domain[5:] + mapping.Domain = fullVal mappings = append(mappings, mapping) case strings.HasPrefix(domain, "dotless:"): @@ -213,10 +233,10 @@ func (c *DNSConfig) Build() (*dns.Config, error) { return nil, newError("invalid external resource: ", domain) } filename := kv[0] - country := kv[1] - domains, err := loadGeositeWithAttr(filename, country) + list := kv[1] + domains, err := loadGeositeWithAttr(filename, list) if err != nil { - return nil, newError("failed to load domains: ", country, " from ", filename).Base(err) + return nil, newError("failed to load domain list: ", list, " from ", filename).Base(err) } for _, d := range domains { mapping := getHostMapping(addr) diff --git a/infra/conf/trojan.go b/infra/conf/trojan.go index 1e5e8e89..2e0dd378 100644 --- a/infra/conf/trojan.go +++ b/infra/conf/trojan.go @@ -167,7 +167,7 @@ func (c *TrojanServerConfig) Build() (proto.Message, error) { switch fb.Dest[0] { case '@', '/': fb.Type = "unix" - if fb.Dest[0] == '@' && len(fb.Dest) > 1 && fb.Dest[1] == '@' && runtime.GOOS == "linux" { + if fb.Dest[0] == '@' && len(fb.Dest) > 1 && fb.Dest[1] == '@' && (runtime.GOOS == "linux" || runtime.GOOS == "android") { fullAddr := make([]byte, len(syscall.RawSockaddrUnix{}.Path)) // may need padding to work with haproxy copy(fullAddr, fb.Dest[1:]) fb.Dest = string(fullAddr) diff --git a/infra/conf/vless.go b/infra/conf/vless.go index 605759fc..88ff1c1b 100644 --- a/infra/conf/vless.go +++ b/infra/conf/vless.go @@ -101,7 +101,7 @@ func (c *VLessInboundConfig) Build() (proto.Message, error) { switch fb.Dest[0] { case '@', '/': fb.Type = "unix" - if fb.Dest[0] == '@' && len(fb.Dest) > 1 && fb.Dest[1] == '@' && runtime.GOOS == "linux" { + if fb.Dest[0] == '@' && len(fb.Dest) > 1 && fb.Dest[1] == '@' && (runtime.GOOS == "linux" || runtime.GOOS == "android") { fullAddr := make([]byte, len(syscall.RawSockaddrUnix{}.Path)) // may need padding to work with haproxy copy(fullAddr, fb.Dest[1:]) fb.Dest = string(fullAddr) diff --git a/proxy/http/client.go b/proxy/http/client.go index 11d99690..2d7535a5 100644 --- a/proxy/http/client.go +++ b/proxy/http/client.go @@ -168,6 +168,7 @@ func setUpHTTPTunnel(ctx context.Context, dest net.Destination, target string, u rawConn.Close() return nil, err } + defer resp.Body.Close() if resp.StatusCode != http.StatusOK { rawConn.Close() diff --git a/proxy/http/server.go b/proxy/http/server.go index 4b894c3b..6d33cdac 100644 --- a/proxy/http/server.go +++ b/proxy/http/server.go @@ -293,6 +293,7 @@ func (s *Server) handlePlainHTTP(ctx context.Context, request *http.Request, wri response.Close = true result = nil } + defer response.Body.Close() } else { newError("failed to read response from ", request.Host).Base(err).AtWarning().WriteToLog(session.ExportIDToError(ctx)) response = &http.Response{ diff --git a/proxy/socks/client.go b/proxy/socks/client.go index f9248de8..e409cc75 100644 --- a/proxy/socks/client.go +++ b/proxy/socks/client.go @@ -51,14 +51,19 @@ func (c *Client) Process(ctx context.Context, link *transport.Link, dialer inter if outbound == nil || !outbound.Target.IsValid() { return newError("target not specified.") } + // Destination of the inner request. destination := outbound.Target + // Outbound server. var server *protocol.ServerSpec + // Outbound server's destination. + var dest net.Destination + // Connection to the outbound server. var conn internet.Connection if err := retry.ExponentialBackoff(5, 100).On(func() error { server = c.serverPicker.PickServer() - dest := server.Destination() + dest = server.Destination() rawConn, err := dialer.Dial(ctx, dest) if err != nil { return err @@ -101,6 +106,11 @@ func (c *Client) Process(ctx context.Context, link *transport.Link, dialer inter if err != nil { return newError("failed to establish connection to server").AtWarning().Base(err) } + if udpRequest != nil { + if udpRequest.Address == net.AnyIP || udpRequest.Address == net.AnyIPv6 { + udpRequest.Address = dest.Address + } + } if err := conn.SetDeadline(time.Time{}); err != nil { newError("failed to clear deadline after handshake").Base(err).WriteToLog(session.ExportIDToError(ctx)) diff --git a/proxy/socks/protocol.go b/proxy/socks/protocol.go index ca66e336..b32166b5 100644 --- a/proxy/socks/protocol.go +++ b/proxy/socks/protocol.go @@ -16,7 +16,7 @@ const ( cmdTCPConnect = 0x01 cmdTCPBind = 0x02 - cmdUDPPort = 0x03 + cmdUDPAssociate = 0x03 cmdTorResolve = 0xF0 cmdTorResolvePTR = 0xF1 @@ -39,8 +39,10 @@ var addrParser = protocol.NewAddressParser( ) type ServerSession struct { - config *ServerConfig - port net.Port + config *ServerConfig + address net.Address + port net.Port + clientAddress net.Address } func (s *ServerSession) handshake4(cmd byte, reader io.Reader, writer io.Writer) (*protocol.RequestHeader, error) { @@ -162,7 +164,7 @@ func (s *ServerSession) handshake5(nMethod byte, reader io.Reader, writer io.Wri case cmdTCPConnect, cmdTorResolve, cmdTorResolvePTR: // We don't have a solution for Tor case now. Simply treat it as connect command. request.Command = protocol.RequestCommandTCP - case cmdUDPPort: + case cmdUDPAssociate: if !s.config.UdpEnabled { writeSocks5Response(writer, statusCmdNotSupport, net.AnyIP, net.Port(0)) return nil, newError("UDP is not enabled.") @@ -185,15 +187,20 @@ func (s *ServerSession) handshake5(nMethod byte, reader io.Reader, writer io.Wri request.Address = addr request.Port = port - responseAddress := net.AnyIP - responsePort := net.Port(1717) + responseAddress := s.address + responsePort := s.port + //nolint:gocritic // Use if else chain for clarity if request.Command == protocol.RequestCommandUDP { - addr := s.config.Address.AsAddress() - if addr == nil { - addr = net.LocalHostIP + if s.config.Address != nil { + // Use configured IP as remote address in the response to UdpAssociate + responseAddress = s.config.Address.AsAddress() + } else if s.clientAddress == net.LocalHostIP || s.clientAddress == net.LocalHostIPv6 { + // For localhost clients use loopback IP + responseAddress = s.clientAddress + } else { + // For non-localhost clients use inbound listening address + responseAddress = s.address } - responseAddress = addr - responsePort = s.port } if err := writeSocks5Response(writer, statusSuccess, responseAddress, responsePort); err != nil { return nil, err @@ -446,7 +453,7 @@ func ClientHandshake(request *protocol.RequestHeader, reader io.Reader, writer i command := byte(cmdTCPConnect) if request.Command == protocol.RequestCommandUDP { - command = byte(cmdUDPPort) + command = byte(cmdUDPAssociate) } common.Must2(b.Write([]byte{socks5Version, command, 0x00 /* reserved */})) if err := addrParser.WriteAddressPort(b, request.Address, request.Port); err != nil { diff --git a/proxy/socks/server.go b/proxy/socks/server.go index 620d5dae..829277e4 100644 --- a/proxy/socks/server.go +++ b/proxy/socks/server.go @@ -89,8 +89,10 @@ func (s *Server) processTCP(ctx context.Context, conn internet.Connection, dispa } svrSession := &ServerSession{ - config: s.config, - port: inbound.Gateway.Port, + config: s.config, + address: inbound.Gateway.Address, + port: inbound.Gateway.Port, + clientAddress: inbound.Source.Address, } reader := &buf.BufferedReader{Reader: buf.NewReader(conn)} diff --git a/testing/scenarios/transport_test.go b/testing/scenarios/transport_test.go index 5ffd4569..ad447d9d 100644 --- a/testing/scenarios/transport_test.go +++ b/testing/scenarios/transport_test.go @@ -136,8 +136,8 @@ func TestHTTPConnectionHeader(t *testing.T) { } func TestDomainSocket(t *testing.T) { - if runtime.GOOS == "windows" { - t.Skip("Not supported on windows") + if runtime.GOOS == "windows" || runtime.GOOS == "android" { + t.Skip("Not supported on windows and android") return } tcpServer := tcp.Server{ diff --git a/transport/internet/domainsocket/listener_test.go b/transport/internet/domainsocket/listener_test.go index ee8f4a0b..8c14604f 100644 --- a/transport/internet/domainsocket/listener_test.go +++ b/transport/internet/domainsocket/listener_test.go @@ -1,4 +1,5 @@ // +build !windows +// +build !android package domainsocket_test diff --git a/transport/internet/system_listener.go b/transport/internet/system_listener.go index 2dc8a15a..eb753d60 100644 --- a/transport/internet/system_listener.go +++ b/transport/internet/system_listener.go @@ -54,7 +54,7 @@ func (dl *DefaultListener) Listen(ctx context.Context, addr net.Addr, sockopt *S lc.Control = nil network = addr.Network() address = addr.Name - if runtime.GOOS == "linux" && address[0] == '@' { + if (runtime.GOOS == "linux" || runtime.GOOS == "android") && address[0] == '@' { // linux abstract unix domain socket is lockfree if len(address) > 1 && address[1] == '@' { // but may need padding to work with haproxy diff --git a/transport/internet/tcp/hub.go b/transport/internet/tcp/hub.go index bb8cef46..142579bc 100644 --- a/transport/internet/tcp/hub.go +++ b/transport/internet/tcp/hub.go @@ -48,7 +48,7 @@ func ListenTCP(ctx context.Context, address net.Address, port net.Port, streamSe Net: "unix", }, streamSettings.SocketSettings) if err != nil { - return nil, newError("failed to listen Unix Doman Socket on ", address).Base(err) + return nil, newError("failed to listen Unix Domain Socket on ", address).Base(err) } newError("listening Unix Domain Socket on ", address).WriteToLog(session.ExportIDToError(ctx)) locker := ctx.Value(address.Domain())