From b786a50aee65359c8bbab6fc8ab39f33f4bf0d9a Mon Sep 17 00:00:00 2001 From: RPRX <63339210+RPRX@users.noreply.github.com> Date: Thu, 20 Feb 2025 16:28:06 +0000 Subject: [PATCH] XHTTP server: Fix stream-up "single POST problem", Use united httpServerConn instead of recover() https://github.com/XTLS/Xray-core/issues/4373#issuecomment-2671795675 https://github.com/XTLS/Xray-core/issues/4406#issuecomment-2668041926 --- transport/internet/splithttp/hub.go | 114 +++++++------------ transport/internet/splithttp/upload_queue.go | 30 +++-- 2 files changed, 63 insertions(+), 81 deletions(-) diff --git a/transport/internet/splithttp/hub.go b/transport/internet/splithttp/hub.go index 86a750f7..179d696a 100644 --- a/transport/internet/splithttp/hub.go +++ b/transport/internet/splithttp/hub.go @@ -47,21 +47,6 @@ type httpSession struct { isFullyConnected *done.Instance } -func (h *requestHandler) maybeReapSession(isFullyConnected *done.Instance, sessionId string) { - shouldReap := done.New() - go func() { - time.Sleep(30 * time.Second) - shouldReap.Close() - }() - - select { - case <-isFullyConnected.Wait(): - return - case <-shouldReap.Wait(): - h.sessions.Delete(sessionId) - } -} - func (h *requestHandler) upsertSession(sessionId string) *httpSession { // fast path currentSessionAny, ok := h.sessions.Load(sessionId) @@ -84,7 +69,21 @@ func (h *requestHandler) upsertSession(sessionId string) *httpSession { } h.sessions.Store(sessionId, s) - go h.maybeReapSession(s.isFullyConnected, sessionId) + + shouldReap := done.New() + go func() { + time.Sleep(30 * time.Second) + shouldReap.Close() + }() + go func() { + select { + case <-shouldReap.Wait(): + h.sessions.Delete(sessionId) + s.uploadQueue.Close() + case <-s.isFullyConnected.Wait(): + } + }() + return s } @@ -183,12 +182,13 @@ func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Req writer.WriteHeader(http.StatusBadRequest) return } - uploadDone := done.New() + httpSC := &httpServerConn{ + Instance: done.New(), + Reader: request.Body, + ResponseWriter: writer, + } err = currentSession.uploadQueue.Push(Packet{ - Reader: &httpRequestBodyReader{ - requestReader: request.Body, - uploadDone: uploadDone, - }, + Reader: httpSC, }) if err != nil { errors.LogInfoInner(context.Background(), err, "failed to upload (PushReader)") @@ -200,25 +200,21 @@ func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Req scStreamUpServerSecs := h.config.GetNormalizedScStreamUpServerSecs() if referrer != "" && scStreamUpServerSecs.To > 0 { go func() { - defer func() { - recover() - }() for { - _, err := writer.Write(bytes.Repeat([]byte{'X'}, int(h.config.GetNormalizedXPaddingBytes().rand()))) + _, err := httpSC.Write(bytes.Repeat([]byte{'X'}, int(h.config.GetNormalizedXPaddingBytes().rand()))) if err != nil { break } - writer.(http.Flusher).Flush() time.Sleep(time.Duration(scStreamUpServerSecs.rand()) * time.Second) } }() } select { case <-request.Context().Done(): - case <-uploadDone.Wait(): + case <-httpSC.Wait(): } } - uploadDone.Close() + httpSC.Close() return } @@ -262,11 +258,6 @@ func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Req writer.WriteHeader(http.StatusOK) } else if request.Method == "GET" || sessionId == "" { // stream-down, stream-one - responseFlusher, ok := writer.(http.Flusher) - if !ok { - panic("expected http.ResponseWriter to be an http.Flusher") - } - if sessionId != "" { // after GET is done, the connection is finished. disable automatic // session reaping, and handle it in defer @@ -287,20 +278,18 @@ func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Req } writer.WriteHeader(http.StatusOK) + writer.(http.Flusher).Flush() - responseFlusher.Flush() - - downloadDone := done.New() - + httpSC := &httpServerConn{ + Instance: done.New(), + Reader: request.Body, + ResponseWriter: writer, + } conn := splitConn{ - writer: &httpResponseBodyWriter{ - responseWriter: writer, - downloadDone: downloadDone, - responseFlusher: responseFlusher, - }, - reader: request.Body, - localAddr: h.localAddr, + writer: httpSC, + reader: httpSC, remoteAddr: remoteAddr, + localAddr: h.localAddr, } if sessionId != "" { // if not stream-one conn.reader = currentSession.uploadQueue @@ -311,7 +300,7 @@ func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Req // "A ResponseWriter may not be used after [Handler.ServeHTTP] has returned." select { case <-request.Context().Done(): - case <-downloadDone.Wait(): + case <-httpSC.Wait(): } conn.Close() @@ -321,45 +310,30 @@ func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Req } } -type httpRequestBodyReader struct { - requestReader io.ReadCloser - uploadDone *done.Instance -} - -func (c *httpRequestBodyReader) Read(b []byte) (int, error) { - return c.requestReader.Read(b) -} - -func (c *httpRequestBodyReader) Close() error { - defer c.uploadDone.Close() - return c.requestReader.Close() -} - -type httpResponseBodyWriter struct { +type httpServerConn struct { sync.Mutex - responseWriter http.ResponseWriter - responseFlusher http.Flusher - downloadDone *done.Instance + *done.Instance + io.Reader // no need to Close request.Body + http.ResponseWriter } -func (c *httpResponseBodyWriter) Write(b []byte) (int, error) { +func (c *httpServerConn) Write(b []byte) (int, error) { c.Lock() defer c.Unlock() - if c.downloadDone.Done() { + if c.Done() { return 0, io.ErrClosedPipe } - n, err := c.responseWriter.Write(b) + n, err := c.ResponseWriter.Write(b) if err == nil { - c.responseFlusher.Flush() + c.ResponseWriter.(http.Flusher).Flush() } return n, err } -func (c *httpResponseBodyWriter) Close() error { +func (c *httpServerConn) Close() error { c.Lock() defer c.Unlock() - c.downloadDone.Close() - return nil + return c.Instance.Close() } type Listener struct { diff --git a/transport/internet/splithttp/upload_queue.go b/transport/internet/splithttp/upload_queue.go index 382e381e..69b9a972 100644 --- a/transport/internet/splithttp/upload_queue.go +++ b/transport/internet/splithttp/upload_queue.go @@ -20,6 +20,7 @@ type Packet struct { type uploadQueue struct { reader io.ReadCloser + nomore bool pushedPackets chan Packet writeCloseMutex sync.Mutex heap uploadHeap @@ -42,19 +43,15 @@ func (h *uploadQueue) Push(p Packet) error { h.writeCloseMutex.Lock() defer h.writeCloseMutex.Unlock() - runtime.Gosched() - if h.reader != nil && p.Reader != nil { - p.Reader.Close() - return errors.New("h.reader already exists") - } - if h.closed { - if p.Reader != nil { - p.Reader.Close() - } return errors.New("packet queue closed") } - + if h.nomore { + return errors.New("h.reader already exists") + } + if p.Reader != nil { + h.nomore = true + } h.pushedPackets <- p return nil } @@ -65,9 +62,20 @@ func (h *uploadQueue) Close() error { if !h.closed { h.closed = true + runtime.Gosched() // hope Read() gets the packet + f: + for { + select { + case p := <-h.pushedPackets: + if p.Reader != nil { + h.reader = p.Reader + } + default: + break f + } + } close(h.pushedPackets) } - runtime.Gosched() if h.reader != nil { return h.reader.Close() }