From ce6c0dc690c32056bc1faabac97112d306b9f0c3 Mon Sep 17 00:00:00 2001 From: RPRX <63339210+RPRX@users.noreply.github.com> Date: Mon, 6 Jan 2025 14:06:11 +0000 Subject: [PATCH] XHTTP XMUX: Abandon `client` if `client.Do(req)` failed (#4253) https://github.com/XTLS/Xray-core/commit/51769fdde1ca663dcb08d942618e480bee13109f --- transport/internet/splithttp/client.go | 25 +++++++++++++++---------- transport/internet/splithttp/config.go | 5 ++++- transport/internet/splithttp/dialer.go | 6 +++--- transport/internet/splithttp/hub.go | 14 ++++++++++++++ 4 files changed, 36 insertions(+), 14 deletions(-) diff --git a/transport/internet/splithttp/client.go b/transport/internet/splithttp/client.go index c6cfee55..23166536 100644 --- a/transport/internet/splithttp/client.go +++ b/transport/internet/splithttp/client.go @@ -55,11 +55,11 @@ func (c *DefaultDialerClient) OpenStream(ctx context.Context, url string, body i }, }) - method := "GET" + method := "GET" // stream-down if body != nil { - method = "POST" + method = "POST" // stream-up/one } - req, _ := http.NewRequestWithContext(ctx, method, url, body) + req, _ := http.NewRequestWithContext(context.WithoutCancel(ctx), method, url, body) req.Header = c.transportConfig.GetRequestHeader() if method == "POST" && !c.transportConfig.NoGRPCHeader { req.Header.Set("Content-Type", "application/grpc") @@ -69,17 +69,20 @@ func (c *DefaultDialerClient) OpenStream(ctx context.Context, url string, body i go func() { resp, err := c.client.Do(req) if err != nil { + if !uploadOnly { + c.closed = true + } errors.LogInfoInner(ctx, err, "failed to "+method+" "+url) gotConn.Close() wrc.Close() return } if resp.StatusCode != 200 && !uploadOnly { - // c.closed = true errors.LogInfo(ctx, "unexpected status ", resp.StatusCode) } - if resp.StatusCode != 200 || uploadOnly { - resp.Body.Close() + if resp.StatusCode != 200 || uploadOnly { // stream-up + io.Copy(io.Discard, resp.Body) + resp.Body.Close() // if it is called immediately, the upload will be interrupted also wrc.Close() return } @@ -91,7 +94,7 @@ func (c *DefaultDialerClient) OpenStream(ctx context.Context, url string, body i } func (c *DefaultDialerClient) PostPacket(ctx context.Context, url string, body io.Reader, contentLength int64) error { - req, err := http.NewRequestWithContext(ctx, "POST", url, body) + req, err := http.NewRequestWithContext(context.WithoutCancel(ctx), "POST", url, body) if err != nil { return err } @@ -101,13 +104,14 @@ func (c *DefaultDialerClient) PostPacket(ctx context.Context, url string, body i if c.httpVersion != "1.1" { resp, err := c.client.Do(req) if err != nil { + c.closed = true return err } + io.Copy(io.Discard, resp.Body) defer resp.Body.Close() if resp.StatusCode != 200 { - // c.closed = true return errors.New("bad status code:", resp.Status) } } else { @@ -139,11 +143,12 @@ func (c *DefaultDialerClient) PostPacket(ctx context.Context, url string, body i if h1UploadConn.UnreadedResponsesCount > 0 { resp, err := http.ReadResponse(h1UploadConn.RespBufReader, req) if err != nil { + c.closed = true return fmt.Errorf("error while reading response: %s", err.Error()) } + io.Copy(io.Discard, resp.Body) + defer resp.Body.Close() if resp.StatusCode != 200 { - // c.closed = true - // resp.Body.Close() // I'm not sure return fmt.Errorf("got non-200 error response code: %d", resp.StatusCode) } } diff --git a/transport/internet/splithttp/config.go b/transport/internet/splithttp/config.go index 190cb633..a76bf0e4 100644 --- a/transport/internet/splithttp/config.go +++ b/transport/internet/splithttp/config.go @@ -7,6 +7,7 @@ import ( "strings" "github.com/xtls/xray-core/common" + "github.com/xtls/xray-core/core" "github.com/xtls/xray-core/transport/internet" ) @@ -36,10 +37,11 @@ func (c *Config) GetNormalizedQuery() string { if query != "" { query += "&" } + query += "x_version=" + core.Version() paddingLen := c.GetNormalizedXPaddingBytes().rand() if paddingLen > 0 { - query += "x_padding=" + strings.Repeat("0", int(paddingLen)) + query += "&x_padding=" + strings.Repeat("0", int(paddingLen)) } return query @@ -58,6 +60,7 @@ func (c *Config) WriteResponseHeader(writer http.ResponseWriter) { // CORS headers for the browser dialer writer.Header().Set("Access-Control-Allow-Origin", "*") writer.Header().Set("Access-Control-Allow-Methods", "GET, POST") + writer.Header().Set("X-Version", core.Version()) paddingLen := c.GetNormalizedXPaddingBytes().rand() if paddingLen > 0 { writer.Header().Set("X-Padding", strings.Repeat("0", int(paddingLen))) diff --git a/transport/internet/splithttp/dialer.go b/transport/internet/splithttp/dialer.go index f33bb9e4..22f854cd 100644 --- a/transport/internet/splithttp/dialer.go +++ b/transport/internet/splithttp/dialer.go @@ -372,14 +372,14 @@ func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.Me if xmuxClient != nil { xmuxClient.LeftRequests.Add(-1) } - conn.reader, conn.remoteAddr, conn.localAddr, _ = httpClient.OpenStream(context.WithoutCancel(ctx), requestURL.String(), reader, false) + conn.reader, conn.remoteAddr, conn.localAddr, _ = httpClient.OpenStream(ctx, requestURL.String(), reader, false) return stat.Connection(&conn), nil } else { // stream-down var err error if xmuxClient2 != nil { xmuxClient2.LeftRequests.Add(-1) } - conn.reader, conn.remoteAddr, conn.localAddr, err = httpClient2.OpenStream(context.WithoutCancel(ctx), requestURL2.String(), nil, false) + conn.reader, conn.remoteAddr, conn.localAddr, err = httpClient2.OpenStream(ctx, requestURL2.String(), nil, false) if err != nil { // browser dialer only return nil, err } @@ -454,7 +454,7 @@ func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.Me go func() { err := httpClient.PostPacket( - context.WithoutCancel(ctx), + ctx, url.String(), &buf.MultiBufferContainer{MultiBuffer: chunk}, int64(chunk.Len()), diff --git a/transport/internet/splithttp/hub.go b/transport/internet/splithttp/hub.go index b7c5098b..0d8c20da 100644 --- a/transport/internet/splithttp/hub.go +++ b/transport/internet/splithttp/hub.go @@ -1,6 +1,7 @@ package splithttp import ( + "bytes" "context" "crypto/tls" "io" @@ -102,6 +103,12 @@ func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Req h.config.WriteResponseHeader(writer) + clientVer := []int{0, 0, 0} + x_version := strings.Split(request.URL.Query().Get("x_version"), ".") + for j := 0; j < 3 && len(x_version) > j; j++ { + clientVer[j], _ = strconv.Atoi(x_version[j]) + } + validRange := h.config.GetNormalizedXPaddingBytes() x_padding := int32(len(request.URL.Query().Get("x_padding"))) if validRange.To > 0 && (x_padding < validRange.From || x_padding > validRange.To) { @@ -160,6 +167,13 @@ func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Req writer.WriteHeader(http.StatusConflict) } else { writer.WriteHeader(http.StatusOK) + if request.ProtoMajor != 1 && len(clientVer) > 0 && clientVer[0] >= 25 { + paddingLen := h.config.GetNormalizedXPaddingBytes().rand() + if paddingLen > 0 { + writer.Write(bytes.Repeat([]byte{'0'}, int(paddingLen))) + } + writer.(http.Flusher).Flush() + } <-request.Context().Done() } return