From 8cb63db6c0c83f73333d033f8a30bf8730955e65 Mon Sep 17 00:00:00 2001 From: RPRX <63339210+RPRX@users.noreply.github.com> Date: Tue, 18 Feb 2025 10:50:50 +0000 Subject: [PATCH] XHTTP server: Set remoteAddr & localAddr correctly Completes https://github.com/XTLS/Xray-core/commit/22c50a70c61f18b54f9e9de82962a053261a398c --- common/net/system.go | 3 +- proxy/proxy.go | 35 +++++---- transport/internet/splithttp/hub.go | 78 ++++++++++--------- .../internet/splithttp/splithttp_test.go | 26 +++---- transport/internet/system_listener.go | 6 +- 5 files changed, 77 insertions(+), 71 deletions(-) diff --git a/common/net/system.go b/common/net/system.go index e5bded04..7e1c4b01 100644 --- a/common/net/system.go +++ b/common/net/system.go @@ -76,8 +76,9 @@ type ( ) var ( - ResolveUnixAddr = net.ResolveUnixAddr + ResolveTCPAddr = net.ResolveTCPAddr ResolveUDPAddr = net.ResolveUDPAddr + ResolveUnixAddr = net.ResolveUnixAddr ) type Resolver = net.Resolver diff --git a/proxy/proxy.go b/proxy/proxy.go index a90aa340..57649795 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -113,12 +113,12 @@ type TrafficState struct { type InboundState struct { // reader link state - WithinPaddingBuffers bool - UplinkReaderDirectCopy bool - RemainingCommand int32 - RemainingContent int32 - RemainingPadding int32 - CurrentCommand int + WithinPaddingBuffers bool + UplinkReaderDirectCopy bool + RemainingCommand int32 + RemainingContent int32 + RemainingPadding int32 + CurrentCommand int // write link state IsPadding bool DownlinkWriterDirectCopy bool @@ -133,19 +133,19 @@ type OutboundState struct { RemainingPadding int32 CurrentCommand int // write link state - IsPadding bool - UplinkWriterDirectCopy bool + IsPadding bool + UplinkWriterDirectCopy bool } func NewTrafficState(userUUID []byte) *TrafficState { return &TrafficState{ - UserUUID: userUUID, - NumberOfPacketToFilter: 8, - EnableXtls: false, - IsTLS12orAbove: false, - IsTLS: false, - Cipher: 0, - RemainingServerHello: -1, + UserUUID: userUUID, + NumberOfPacketToFilter: 8, + EnableXtls: false, + IsTLS12orAbove: false, + IsTLS: false, + Cipher: 0, + RemainingServerHello: -1, Inbound: InboundState{ WithinPaddingBuffers: true, UplinkReaderDirectCopy: false, @@ -524,7 +524,7 @@ func XtlsFilterTls(buffer buf.MultiBuffer, trafficState *TrafficState, ctx conte } } -// UnwrapRawConn support unwrap stats, tls, utls, reality and proxyproto conn and get raw tcp conn from it +// UnwrapRawConn support unwrap stats, tls, utls, reality, proxyproto, uds-wrapper conn and get raw tcp/uds conn from it func UnwrapRawConn(conn net.Conn) (net.Conn, stats.Counter, stats.Counter) { var readCounter, writerCounter stats.Counter if conn != nil { @@ -547,6 +547,9 @@ func UnwrapRawConn(conn net.Conn) (net.Conn, stats.Counter, stats.Counter) { conn = pc.Raw() // 8192 > 4096, there is no need to process pc's bufReader } + if uc, ok := conn.(*internet.UDSWrapperConn); ok { + conn = uc.Conn + } } return conn, readCounter, writerCounter } diff --git a/transport/internet/splithttp/hub.go b/transport/internet/splithttp/hub.go index 1c5ace05..e205ce8d 100644 --- a/transport/internet/splithttp/hub.go +++ b/transport/internet/splithttp/hub.go @@ -3,9 +3,8 @@ package splithttp import ( "bytes" "context" - "crypto/tls" + gotls "crypto/tls" "io" - gonet "net" "net/http" "net/url" "strconv" @@ -24,7 +23,7 @@ import ( "github.com/xtls/xray-core/transport/internet" "github.com/xtls/xray-core/transport/internet/reality" "github.com/xtls/xray-core/transport/internet/stat" - v2tls "github.com/xtls/xray-core/transport/internet/tls" + "github.com/xtls/xray-core/transport/internet/tls" "golang.org/x/net/http2" "golang.org/x/net/http2/h2c" ) @@ -36,7 +35,7 @@ type requestHandler struct { ln *Listener sessionMu *sync.Mutex sessions sync.Map - localAddr gonet.TCPAddr + localAddr net.Addr } type httpSession struct { @@ -144,14 +143,25 @@ func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Req } forwardedAddrs := http_proto.ParseXForwardedFor(request.Header) - remoteAddr, err := gonet.ResolveTCPAddr("tcp", request.RemoteAddr) + var remoteAddr net.Addr + var err error + remoteAddr, err = net.ResolveTCPAddr("tcp", request.RemoteAddr) if err != nil { - remoteAddr = &gonet.TCPAddr{} + remoteAddr = &net.TCPAddr{ + IP: []byte{0, 0, 0, 0}, + Port: 0, + } + } + if request.ProtoMajor == 3 { + remoteAddr = &net.UDPAddr{ + IP: remoteAddr.(*net.TCPAddr).IP, + Port: remoteAddr.(*net.TCPAddr).Port, + } } if len(forwardedAddrs) > 0 && forwardedAddrs[0].Family().IsIP() { remoteAddr = &net.TCPAddr{ IP: forwardedAddrs[0].IP(), - Port: int(0), + Port: 0, } } @@ -289,6 +299,7 @@ func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Req responseFlusher: responseFlusher, }, reader: request.Body, + localAddr: h.localAddr, remoteAddr: remoteAddr, } if sessionId != "" { // if not stream-one @@ -362,34 +373,30 @@ type Listener struct { isH3 bool } -func ListenSH(ctx context.Context, address net.Address, port net.Port, streamSettings *internet.MemoryStreamConfig, addConn internet.ConnHandler) (internet.Listener, error) { +func ListenXH(ctx context.Context, address net.Address, port net.Port, streamSettings *internet.MemoryStreamConfig, addConn internet.ConnHandler) (internet.Listener, error) { l := &Listener{ addConn: addConn, } - shSettings := streamSettings.ProtocolSettings.(*Config) - l.config = shSettings + l.config = streamSettings.ProtocolSettings.(*Config) if l.config != nil { if streamSettings.SocketSettings == nil { streamSettings.SocketSettings = &internet.SocketConfig{} } } - var listener net.Listener - var err error - var localAddr = gonet.TCPAddr{} handler := &requestHandler{ - config: shSettings, - host: shSettings.Host, - path: shSettings.GetNormalizedPath(), + config: l.config, + host: l.config.Host, + path: l.config.GetNormalizedPath(), ln: l, sessionMu: &sync.Mutex{}, sessions: sync.Map{}, - localAddr: localAddr, } tlsConfig := getTLSConfig(streamSettings) l.isH3 = len(tlsConfig.NextProtos) == 1 && tlsConfig.NextProtos[0] == "h3" + var err error if port == net.Port(0) { // unix - listener, err = internet.ListenSystem(ctx, &net.UnixAddr{ + l.listener, err = internet.ListenSystem(ctx, &net.UnixAddr{ Name: address.Domain(), Net: "unix", }, streamSettings.SocketSettings) @@ -405,13 +412,14 @@ func ListenSH(ctx context.Context, address net.Address, port net.Port, streamSet if err != nil { return nil, errors.New("failed to listen UDP for XHTTP/3 on ", address, ":", port).Base(err) } - h3listener, err := quic.ListenEarly(Conn, tlsConfig, nil) + l.h3listener, err = quic.ListenEarly(Conn, tlsConfig, nil) if err != nil { return nil, errors.New("failed to listen QUIC for XHTTP/3 on ", address, ":", port).Base(err) } - l.h3listener = h3listener errors.LogInfo(ctx, "listening QUIC for XHTTP/3 on ", address, ":", port) + handler.localAddr = l.h3listener.Addr() + l.h3server = &http3.Server{ Handler: handler, } @@ -421,11 +429,7 @@ func ListenSH(ctx context.Context, address net.Address, port net.Port, streamSet } }() } else { // tcp - localAddr = gonet.TCPAddr{ - IP: address.IP(), - Port: int(port), - } - listener, err = internet.ListenSystem(ctx, &net.TCPAddr{ + l.listener, err = internet.ListenSystem(ctx, &net.TCPAddr{ IP: address.IP(), Port: int(port), }, streamSettings.SocketSettings) @@ -436,26 +440,24 @@ func ListenSH(ctx context.Context, address net.Address, port net.Port, streamSet } // tcp/unix (h1/h2) - if listener != nil { - if config := v2tls.ConfigFromStreamSettings(streamSettings); config != nil { + if l.listener != nil { + if config := tls.ConfigFromStreamSettings(streamSettings); config != nil { if tlsConfig := config.GetTLSConfig(); tlsConfig != nil { - listener = tls.NewListener(listener, tlsConfig) + l.listener = gotls.NewListener(l.listener, tlsConfig) } } - if config := reality.ConfigFromStreamSettings(streamSettings); config != nil { - listener = goreality.NewListener(listener, config.GetREALITYConfig()) + l.listener = goreality.NewListener(l.listener, config.GetREALITYConfig()) } + handler.localAddr = l.listener.Addr() + // h2cHandler can handle both plaintext HTTP/1.1 and h2c - h2cHandler := h2c.NewHandler(handler, &http2.Server{}) - l.listener = listener l.server = http.Server{ - Handler: h2cHandler, + Handler: h2c.NewHandler(handler, &http2.Server{}), ReadHeaderTimeout: time.Second * 4, MaxHeaderBytes: 8192, } - go func() { if err := l.server.Serve(l.listener); err != nil { errors.LogWarningInner(ctx, err, "failed to serve HTTP for XHTTP") @@ -488,13 +490,13 @@ func (ln *Listener) Close() error { } return errors.New("listener does not have an HTTP/3 server or a net.listener") } -func getTLSConfig(streamSettings *internet.MemoryStreamConfig) *tls.Config { - config := v2tls.ConfigFromStreamSettings(streamSettings) +func getTLSConfig(streamSettings *internet.MemoryStreamConfig) *gotls.Config { + config := tls.ConfigFromStreamSettings(streamSettings) if config == nil { - return &tls.Config{} + return &gotls.Config{} } return config.GetTLSConfig() } func init() { - common.Must(internet.RegisterTransportListener(protocolName, ListenSH)) + common.Must(internet.RegisterTransportListener(protocolName, ListenXH)) } diff --git a/transport/internet/splithttp/splithttp_test.go b/transport/internet/splithttp/splithttp_test.go index 20043b4a..f566ecc4 100644 --- a/transport/internet/splithttp/splithttp_test.go +++ b/transport/internet/splithttp/splithttp_test.go @@ -26,9 +26,9 @@ import ( "golang.org/x/net/http2" ) -func Test_listenSHAndDial(t *testing.T) { +func Test_ListenXHAndDial(t *testing.T) { listenPort := tcp.PickPort() - listen, err := ListenSH(context.Background(), net.LocalHostIP, listenPort, &internet.MemoryStreamConfig{ + listen, err := ListenXH(context.Background(), net.LocalHostIP, listenPort, &internet.MemoryStreamConfig{ ProtocolName: "splithttp", ProtocolSettings: &Config{ Path: "/sh", @@ -85,7 +85,7 @@ func Test_listenSHAndDial(t *testing.T) { func TestDialWithRemoteAddr(t *testing.T) { listenPort := tcp.PickPort() - listen, err := ListenSH(context.Background(), net.LocalHostIP, listenPort, &internet.MemoryStreamConfig{ + listen, err := ListenXH(context.Background(), net.LocalHostIP, listenPort, &internet.MemoryStreamConfig{ ProtocolName: "splithttp", ProtocolSettings: &Config{ Path: "sh", @@ -125,7 +125,7 @@ func TestDialWithRemoteAddr(t *testing.T) { common.Must(listen.Close()) } -func Test_listenSHAndDial_TLS(t *testing.T) { +func Test_ListenXHAndDial_TLS(t *testing.T) { if runtime.GOARCH == "arm64" { return } @@ -145,7 +145,7 @@ func Test_listenSHAndDial_TLS(t *testing.T) { Certificate: []*tls.Certificate{tls.ParseCertificate(cert.MustGenerate(nil, cert.CommonName("localhost")))}, }, } - listen, err := ListenSH(context.Background(), net.LocalHostIP, listenPort, streamSettings, func(conn stat.Connection) { + listen, err := ListenXH(context.Background(), net.LocalHostIP, listenPort, streamSettings, func(conn stat.Connection) { go func() { defer conn.Close() @@ -180,7 +180,7 @@ func Test_listenSHAndDial_TLS(t *testing.T) { } } -func Test_listenSHAndDial_H2C(t *testing.T) { +func Test_ListenXHAndDial_H2C(t *testing.T) { if runtime.GOARCH == "arm64" { return } @@ -193,7 +193,7 @@ func Test_listenSHAndDial_H2C(t *testing.T) { Path: "shs", }, } - listen, err := ListenSH(context.Background(), net.LocalHostIP, listenPort, streamSettings, func(conn stat.Connection) { + listen, err := ListenXH(context.Background(), net.LocalHostIP, listenPort, streamSettings, func(conn stat.Connection) { go func() { _ = conn.Close() }() @@ -227,7 +227,7 @@ func Test_listenSHAndDial_H2C(t *testing.T) { } } -func Test_listenSHAndDial_QUIC(t *testing.T) { +func Test_ListenXHAndDial_QUIC(t *testing.T) { if runtime.GOARCH == "arm64" { return } @@ -250,7 +250,7 @@ func Test_listenSHAndDial_QUIC(t *testing.T) { } serverClosed := false - listen, err := ListenSH(context.Background(), net.LocalHostIP, listenPort, streamSettings, func(conn stat.Connection) { + listen, err := ListenXH(context.Background(), net.LocalHostIP, listenPort, streamSettings, func(conn stat.Connection) { go func() { defer conn.Close() @@ -309,11 +309,11 @@ func Test_listenSHAndDial_QUIC(t *testing.T) { } } -func Test_listenSHAndDial_Unix(t *testing.T) { +func Test_ListenXHAndDial_Unix(t *testing.T) { tempDir := t.TempDir() tempSocket := tempDir + "/server.sock" - listen, err := ListenSH(context.Background(), net.DomainAddress(tempSocket), 0, &internet.MemoryStreamConfig{ + listen, err := ListenXH(context.Background(), net.DomainAddress(tempSocket), 0, &internet.MemoryStreamConfig{ ProtocolName: "splithttp", ProtocolSettings: &Config{ Path: "/sh", @@ -373,7 +373,7 @@ func Test_listenSHAndDial_Unix(t *testing.T) { func Test_queryString(t *testing.T) { listenPort := tcp.PickPort() - listen, err := ListenSH(context.Background(), net.LocalHostIP, listenPort, &internet.MemoryStreamConfig{ + listen, err := ListenXH(context.Background(), net.LocalHostIP, listenPort, &internet.MemoryStreamConfig{ ProtocolName: "splithttp", ProtocolSettings: &Config{ // this querystring does not have any effect, but sometimes people blindly copy it from websocket config. make sure the outbound doesn't break @@ -431,7 +431,7 @@ func Test_maxUpload(t *testing.T) { } var uploadSize int - listen, err := ListenSH(context.Background(), net.LocalHostIP, listenPort, streamSettings, func(conn stat.Connection) { + listen, err := ListenXH(context.Background(), net.LocalHostIP, listenPort, streamSettings, func(conn stat.Connection) { go func(c stat.Connection) { defer c.Close() var b [10240]byte diff --git a/transport/internet/system_listener.go b/transport/internet/system_listener.go index cdabd3bf..1086aef5 100644 --- a/transport/internet/system_listener.go +++ b/transport/internet/system_listener.go @@ -54,7 +54,7 @@ func (l *listenUDSWrapper) Accept() (net.Conn, error) { if err != nil { return nil, err } - return &listenUDSWrapperConn{Conn: conn}, nil + return &UDSWrapperConn{Conn: conn}, nil } func (l *listenUDSWrapper) Close() error { @@ -65,11 +65,11 @@ func (l *listenUDSWrapper) Close() error { return l.Listener.Close() } -type listenUDSWrapperConn struct { +type UDSWrapperConn struct { net.Conn } -func (conn *listenUDSWrapperConn) RemoteAddr() net.Addr { +func (conn *UDSWrapperConn) RemoteAddr() net.Addr { return &net.TCPAddr{ IP: []byte{0, 0, 0, 0}, }