From 7463561856f75ed42564e8a6fad722c6bafc74a6 Mon Sep 17 00:00:00 2001 From: RPRX <63339210+RPRX@users.noreply.github.com> Date: Thu, 12 Dec 2024 12:19:18 +0000 Subject: [PATCH] XHTTP client: Add decideHTTPVersion() and more logs https://github.com/XTLS/Xray-core/pull/4150#issuecomment-2537981368 --- transport/internet/splithttp/client.go | 5 +- transport/internet/splithttp/dialer.go | 99 +++++++++++--------- transport/internet/splithttp/hub.go | 18 ++-- transport/internet/splithttp/upload_queue.go | 2 +- 4 files changed, 69 insertions(+), 55 deletions(-) diff --git a/transport/internet/splithttp/client.go b/transport/internet/splithttp/client.go index 35aad1ec..925d3b0d 100644 --- a/transport/internet/splithttp/client.go +++ b/transport/internet/splithttp/client.go @@ -39,8 +39,7 @@ type DialerClient interface { type DefaultDialerClient struct { transportConfig *Config client *http.Client - isH2 bool - isH3 bool + httpVersion string // pool of net.Conn, created using dialUploadConn uploadRawPool *sync.Pool dialUploadConn func(ctxInner context.Context) (net.Conn, error) @@ -172,7 +171,7 @@ func (c *DefaultDialerClient) SendUploadRequest(ctx context.Context, url string, req.ContentLength = contentLength req.Header = c.transportConfig.GetRequestHeader() - if c.isH2 || c.isH3 { + if c.httpVersion != "1.1" { resp, err := c.client.Do(req) if err != nil { return err diff --git a/transport/internet/splithttp/dialer.go b/transport/internet/splithttp/dialer.go index 4f9ee01e..cd0abc85 100644 --- a/transport/internet/splithttp/dialer.go +++ b/transport/internet/splithttp/dialer.go @@ -3,6 +3,7 @@ package splithttp import ( "context" gotls "crypto/tls" + "fmt" "io" "net/http" "net/http/httptrace" @@ -83,23 +84,32 @@ func getHTTPClient(ctx context.Context, dest net.Destination, streamSettings *in return res.Resource.(DialerClient), res } +func decideHTTPVersion(tlsConfig *tls.Config, realityConfig *reality.Config) string { + if realityConfig != nil { + return "2" + } + if tlsConfig == nil { + return "1.1" + } + if len(tlsConfig.NextProtocol) != 1 { + return "2" + } + if tlsConfig.NextProtocol[0] == "http/1.1" { + return "1.1" + } + if tlsConfig.NextProtocol[0] == "h3" { + return "3" + } + return "2" +} + func createHTTPClient(dest net.Destination, streamSettings *internet.MemoryStreamConfig) DialerClient { tlsConfig := tls.ConfigFromStreamSettings(streamSettings) realityConfig := reality.ConfigFromStreamSettings(streamSettings) - isH2 := false - isH3 := false - - if tlsConfig != nil { - isH2 = !(len(tlsConfig.NextProtocol) == 1 && tlsConfig.NextProtocol[0] == "http/1.1") - isH3 = len(tlsConfig.NextProtocol) == 1 && tlsConfig.NextProtocol[0] == "h3" - } else if realityConfig != nil { - isH2 = true - isH3 = false - } - - if isH3 { - dest.Network = net.Network_UDP + httpVersion := decideHTTPVersion(tlsConfig, realityConfig) + if httpVersion == "3" { + dest.Network = net.Network_UDP // better to keep this line } var gotlsConfig *gotls.Config @@ -138,7 +148,7 @@ func createHTTPClient(dest net.Destination, streamSettings *internet.MemoryStrea var transport http.RoundTripper - if isH3 { + if httpVersion == "3" { if keepAlivePeriod == 0 { keepAlivePeriod = quicgoH3KeepAlivePeriod } @@ -194,7 +204,7 @@ func createHTTPClient(dest net.Destination, streamSettings *internet.MemoryStrea return quic.DialEarly(ctx, udpConn, udpAddr, tlsCfg, cfg) }, } - } else if isH2 { + } else if httpVersion == "2" { if keepAlivePeriod == 0 { keepAlivePeriod = chromeH2KeepAlivePeriod } @@ -228,8 +238,7 @@ func createHTTPClient(dest net.Destination, streamSettings *internet.MemoryStrea client: &http.Client{ Transport: transport, }, - isH2: isH2, - isH3: isH3, + httpVersion: httpVersion, uploadRawPool: &sync.Pool{}, dialUploadConn: dialContext, } @@ -242,16 +251,16 @@ func init() { } func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (stat.Connection, error) { - errors.LogInfo(ctx, "dialing splithttp to ", dest) - - var requestURL url.URL - - transportConfiguration := streamSettings.ProtocolSettings.(*Config) tlsConfig := tls.ConfigFromStreamSettings(streamSettings) realityConfig := reality.ConfigFromStreamSettings(streamSettings) - scMaxEachPostBytes := transportConfiguration.GetNormalizedScMaxEachPostBytes() - scMinPostsIntervalMs := transportConfiguration.GetNormalizedScMinPostsIntervalMs() + httpVersion := decideHTTPVersion(tlsConfig, realityConfig) + if httpVersion == "3" { + dest.Network = net.Network_UDP + } + + transportConfiguration := streamSettings.ProtocolSettings.(*Config) + var requestURL url.URL if tlsConfig != nil || realityConfig != nil { requestURL.Scheme = "https" @@ -275,8 +284,21 @@ func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.Me httpClient, muxRes := getHTTPClient(ctx, dest, streamSettings) - httpClient2 := httpClient + mode := transportConfiguration.Mode + if mode == "" || mode == "auto" { + mode = "packet-up" + if httpVersion == "2" { + mode = "stream-up" + } + if realityConfig != nil && transportConfiguration.DownloadSettings == nil { + mode = "stream-one" + } + } + + errors.LogInfo(ctx, fmt.Sprintf("XHTTP is dialing to %s, mode %s, HTTP version %s, host %s", dest, mode, httpVersion, requestURL.Host)) + requestURL2 := requestURL + httpClient2 := httpClient var muxRes2 *muxResource if transportConfiguration.DownloadSettings != nil { globalDialerAccess.Lock() @@ -286,9 +308,12 @@ func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.Me globalDialerAccess.Unlock() memory2 := streamSettings.DownloadSettings dest2 := *memory2.Destination // just panic - httpClient2, muxRes2 = getHTTPClient(ctx, dest2, memory2) tlsConfig2 := tls.ConfigFromStreamSettings(memory2) realityConfig2 := reality.ConfigFromStreamSettings(memory2) + httpVersion2 := decideHTTPVersion(tlsConfig2, realityConfig2) + if httpVersion2 == "3" { + dest2.Network = net.Network_UDP + } if tlsConfig2 != nil || realityConfig2 != nil { requestURL2.Scheme = "https" } else { @@ -307,20 +332,10 @@ func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.Me } requestURL2.Path = config2.GetNormalizedPath() + sessionIdUuid.String() requestURL2.RawQuery = config2.GetNormalizedQuery() + httpClient2, muxRes2 = getHTTPClient(ctx, dest2, memory2) + errors.LogInfo(ctx, fmt.Sprintf("XHTTP is downloading from %s, mode %s, HTTP version %s, host %s", dest2, "stream-down", httpVersion2, requestURL2.Host)) } - mode := transportConfiguration.Mode - if mode == "" || mode == "auto" { - mode = "packet-up" - if (tlsConfig != nil && (len(tlsConfig.NextProtocol) != 1 || tlsConfig.NextProtocol[0] == "h2")) || realityConfig != nil { - mode = "stream-up" - } - if realityConfig != nil && transportConfiguration.DownloadSettings == nil { - mode = "stream-one" - } - } - errors.LogInfo(ctx, "XHTTP is using mode: "+mode) - var writer io.WriteCloser var reader io.ReadCloser var remoteAddr, localAddr net.Addr @@ -373,6 +388,9 @@ func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.Me return stat.Connection(&conn), nil } + scMaxEachPostBytes := transportConfiguration.GetNormalizedScMaxEachPostBytes() + scMinPostsIntervalMs := transportConfiguration.GetNormalizedScMinPostsIntervalMs() + maxUploadSize := scMaxEachPostBytes.roll() // WithSizeLimit(0) will still allow single bytes to pass, and a lot of // code relies on this behavior. Subtract 1 so that together with @@ -408,10 +426,7 @@ func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.Me seq += 1 if scMinPostsIntervalMs.From > 0 { - sleep := time.Duration(scMinPostsIntervalMs.roll())*time.Millisecond - time.Since(lastWrite) - if sleep > 0 { - time.Sleep(sleep) - } + time.Sleep(time.Duration(scMinPostsIntervalMs.roll())*time.Millisecond - time.Since(lastWrite)) } // by offloading the uploads into a buffered pipe, multiple conn.Write diff --git a/transport/internet/splithttp/hub.go b/transport/internet/splithttp/hub.go index 5fef37d9..1df029dd 100644 --- a/transport/internet/splithttp/hub.go +++ b/transport/internet/splithttp/hub.go @@ -333,30 +333,30 @@ func ListenSH(ctx context.Context, address net.Address, port net.Port, streamSet Net: "unix", }, streamSettings.SocketSettings) if err != nil { - return nil, errors.New("failed to listen unix domain socket(for SH) on ", address).Base(err) + return nil, errors.New("failed to listen UNIX domain socket for XHTTP on ", address).Base(err) } - errors.LogInfo(ctx, "listening unix domain socket(for SH) on ", address) + errors.LogInfo(ctx, "listening UNIX domain socket for XHTTP on ", address) } else if l.isH3 { // quic Conn, err := internet.ListenSystemPacket(context.Background(), &net.UDPAddr{ IP: address.IP(), Port: int(port), }, streamSettings.SocketSettings) if err != nil { - return nil, errors.New("failed to listen UDP(for SH3) on ", address, ":", port).Base(err) + return nil, errors.New("failed to listen UDP for XHTTP/3 on ", address, ":", port).Base(err) } h3listener, err := quic.ListenEarly(Conn, tlsConfig, nil) if err != nil { - return nil, errors.New("failed to listen QUIC(for SH3) on ", address, ":", port).Base(err) + return nil, errors.New("failed to listen QUIC for XHTTP/3 on ", address, ":", port).Base(err) } l.h3listener = h3listener - errors.LogInfo(ctx, "listening QUIC(for SH3) on ", address, ":", port) + errors.LogInfo(ctx, "listening QUIC for XHTTP/3 on ", address, ":", port) l.h3server = &http3.Server{ Handler: handler, } go func() { if err := l.h3server.ServeListener(l.h3listener); err != nil { - errors.LogWarningInner(ctx, err, "failed to serve http3 for splithttp") + errors.LogWarningInner(ctx, err, "failed to serve HTTP/3 for XHTTP/3") } }() } else { // tcp @@ -369,9 +369,9 @@ func ListenSH(ctx context.Context, address net.Address, port net.Port, streamSet Port: int(port), }, streamSettings.SocketSettings) if err != nil { - return nil, errors.New("failed to listen TCP(for SH) on ", address, ":", port).Base(err) + return nil, errors.New("failed to listen TCP for XHTTP on ", address, ":", port).Base(err) } - errors.LogInfo(ctx, "listening TCP(for SH) on ", address, ":", port) + errors.LogInfo(ctx, "listening TCP for XHTTP on ", address, ":", port) } // tcp/unix (h1/h2) @@ -397,7 +397,7 @@ func ListenSH(ctx context.Context, address net.Address, port net.Port, streamSet go func() { if err := l.server.Serve(l.listener); err != nil { - errors.LogWarningInner(ctx, err, "failed to serve http for splithttp") + errors.LogWarningInner(ctx, err, "failed to serve HTTP for XHTTP") } }() } diff --git a/transport/internet/splithttp/upload_queue.go b/transport/internet/splithttp/upload_queue.go index f8dd87b0..382e381e 100644 --- a/transport/internet/splithttp/upload_queue.go +++ b/transport/internet/splithttp/upload_queue.go @@ -52,7 +52,7 @@ func (h *uploadQueue) Push(p Packet) error { if p.Reader != nil { p.Reader.Close() } - return errors.New("splithttp packet queue closed") + return errors.New("packet queue closed") } h.pushedPackets <- p