diff --git a/transport/internet/splithttp/client.go b/transport/internet/splithttp/client.go index f9b7cba1..e491ef3e 100644 --- a/transport/internet/splithttp/client.go +++ b/transport/internet/splithttp/client.go @@ -49,6 +49,8 @@ func (c *DefaultDialerClient) OpenDownload(ctx context.Context, baseURL string) var downResponse io.ReadCloser gotDownResponse := done.New() + ctx, ctxCancel := context.WithCancel(ctx) + go func() { trace := &httptrace.ClientTrace{ GotConn: func(connInfo httptrace.GotConnInfo) { @@ -61,8 +63,10 @@ func (c *DefaultDialerClient) OpenDownload(ctx context.Context, baseURL string) // in case we hit an error, we want to unblock this part defer gotConn.Close() + ctx = httptrace.WithClientTrace(ctx, trace) + req, err := http.NewRequestWithContext( - httptrace.WithClientTrace(ctx, trace), + ctx, "GET", baseURL, nil, @@ -94,16 +98,17 @@ func (c *DefaultDialerClient) OpenDownload(ctx context.Context, baseURL string) gotDownResponse.Close() }() - if c.isH3 { - gotConn.Close() + if !c.isH3 { + // in quic-go, sometimes gotConn is never closed for the lifetime of + // the entire connection, and the download locks up + // https://github.com/quic-go/quic-go/issues/3342 + // for other HTTP versions, we want to block Dial until we know the + // remote address of the server, for logging purposes + <-gotConn.Wait() } - // we want to block Dial until we know the remote address of the server, - // for logging purposes - <-gotConn.Wait() - lazyDownload := &LazyReader{ - CreateReader: func() (io.ReadCloser, error) { + CreateReader: func() (io.Reader, error) { <-gotDownResponse.Wait() if downResponse == nil { return nil, errors.New("downResponse failed") @@ -112,7 +117,15 @@ func (c *DefaultDialerClient) OpenDownload(ctx context.Context, baseURL string) }, } - return lazyDownload, remoteAddr, localAddr, nil + // workaround for https://github.com/quic-go/quic-go/issues/2143 -- + // always cancel request context so that Close cancels any Read. + // Should then match the behavior of http2 and http1. + reader := downloadBody{ + lazyDownload, + ctxCancel, + } + + return reader, remoteAddr, localAddr, nil } func (c *DefaultDialerClient) SendUploadRequest(ctx context.Context, url string, payload io.ReadWriteCloser, contentLength int64) error { @@ -172,3 +185,13 @@ func (c *DefaultDialerClient) SendUploadRequest(ctx context.Context, url string, return nil } + +type downloadBody struct { + io.Reader + cancel context.CancelFunc +} + +func (c downloadBody) Close() error { + c.cancel() + return nil +} diff --git a/transport/internet/splithttp/dialer.go b/transport/internet/splithttp/dialer.go index a95ab34b..35e8c7b4 100644 --- a/transport/internet/splithttp/dialer.go +++ b/transport/internet/splithttp/dialer.go @@ -1,10 +1,8 @@ package splithttp import ( - "bytes" "context" gotls "crypto/tls" - "io" "net/http" "net/url" "strconv" @@ -292,35 +290,7 @@ func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.Me return nil, err } - lazyDownload := &LazyReader{ - CreateReader: func() (io.ReadCloser, error) { - // skip "ok" response - trashHeader := []byte{0, 0} - _, err := io.ReadFull(lazyRawDownload, trashHeader) - if err != nil { - return nil, errors.New("failed to read initial response").Base(err) - } - - if bytes.Equal(trashHeader, []byte("ok")) { - return lazyRawDownload, nil - } - - // we read some garbage byte that may not have been "ok" at - // all. return a reader that replays what we have read so far - reader := io.MultiReader( - bytes.NewReader(trashHeader), - lazyRawDownload, - ) - readCloser := struct { - io.Reader - io.Closer - }{ - Reader: reader, - Closer: lazyRawDownload, - } - return readCloser, nil - }, - } + reader := &stripOkReader{ReadCloser: lazyRawDownload} writer := uploadWriter{ uploadPipeWriter, @@ -329,7 +299,7 @@ func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.Me conn := splitConn{ writer: writer, - reader: lazyDownload, + reader: reader, remoteAddr: remoteAddr, localAddr: localAddr, } diff --git a/transport/internet/splithttp/hub.go b/transport/internet/splithttp/hub.go index 9fdd15b4..4a05e1c3 100644 --- a/transport/internet/splithttp/hub.go +++ b/transport/internet/splithttp/hub.go @@ -222,8 +222,12 @@ func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Req h.ln.addConn(stat.Connection(&conn)) // "A ResponseWriter may not be used after [Handler.ServeHTTP] has returned." - <-downloadDone.Wait() + select { + case <-request.Context().Done(): + case <-downloadDone.Wait(): + } + conn.Close() } else { writer.WriteHeader(http.StatusMethodNotAllowed) } diff --git a/transport/internet/splithttp/lazy_reader.go b/transport/internet/splithttp/lazy_reader.go index 8641cbac..35d1f436 100644 --- a/transport/internet/splithttp/lazy_reader.go +++ b/transport/internet/splithttp/lazy_reader.go @@ -3,18 +3,20 @@ package splithttp import ( "io" "sync" - - "github.com/xtls/xray-core/common/errors" ) +// Close is intentionally not supported by LazyReader because it's not clear +// how CreateReader should be aborted in case of Close. It's best to wrap +// LazyReader in another struct that handles Close correctly, or better, stop +// using LazyReader entirely. type LazyReader struct { readerSync sync.Mutex - CreateReader func() (io.ReadCloser, error) - reader io.ReadCloser + CreateReader func() (io.Reader, error) + reader io.Reader readerError error } -func (r *LazyReader) getReader() (io.ReadCloser, error) { +func (r *LazyReader) getReader() (io.Reader, error) { r.readerSync.Lock() defer r.readerSync.Unlock() if r.reader != nil { @@ -43,17 +45,3 @@ func (r *LazyReader) Read(b []byte) (int, error) { n, err := reader.Read(b) return n, err } - -func (r *LazyReader) Close() error { - r.readerSync.Lock() - defer r.readerSync.Unlock() - - var err error - if r.reader != nil { - err = r.reader.Close() - r.reader = nil - r.readerError = errors.New("closed reader") - } - - return err -} diff --git a/transport/internet/splithttp/splithttp_test.go b/transport/internet/splithttp/splithttp_test.go index 30f92c7f..b01727e0 100644 --- a/transport/internet/splithttp/splithttp_test.go +++ b/transport/internet/splithttp/splithttp_test.go @@ -248,6 +248,8 @@ func Test_listenSHAndDial_QUIC(t *testing.T) { NextProtocol: []string{"h3"}, }, } + + serverClosed := false listen, err := ListenSH(context.Background(), net.LocalHostIP, listenPort, streamSettings, func(conn stat.Connection) { go func() { defer conn.Close() @@ -258,10 +260,12 @@ func Test_listenSHAndDial_QUIC(t *testing.T) { for { b.Clear() if _, err := b.ReadFrom(conn); err != nil { - return + break } common.Must2(conn.Write(b.Bytes())) } + + serverClosed = true }() }) common.Must(err) @@ -271,7 +275,6 @@ func Test_listenSHAndDial_QUIC(t *testing.T) { conn, err := Dial(context.Background(), net.UDPDestination(net.DomainAddress("localhost"), listenPort), streamSettings) common.Must(err) - defer conn.Close() const N = 1024 b1 := make([]byte, N) @@ -294,6 +297,12 @@ func Test_listenSHAndDial_QUIC(t *testing.T) { t.Error(r) } + conn.Close() + time.Sleep(100 * time.Millisecond) + if !serverClosed { + t.Error("server did not get closed") + } + end := time.Now() if !end.Before(start.Add(time.Second * 5)) { t.Error("end: ", end, " start: ", start) diff --git a/transport/internet/splithttp/strip_ok_reader.go b/transport/internet/splithttp/strip_ok_reader.go new file mode 100644 index 00000000..5dbbe22b --- /dev/null +++ b/transport/internet/splithttp/strip_ok_reader.go @@ -0,0 +1,48 @@ +package splithttp + +import ( + "bytes" + "io" + + "github.com/xtls/xray-core/common/errors" +) + +// in older versions of splithttp, the server would respond with `ok` to flush +// out HTTP response headers early. Response headers and a 200 OK were required +// to initiate the connection. Later versions of splithttp dropped this +// requirement, and in xray 1.8.24 the server stopped sending "ok" if it sees +// x_padding. For compatibility, we need to remove "ok" from the underlying +// reader if it exists, and otherwise forward the stream as-is. +type stripOkReader struct { + io.ReadCloser + firstDone bool + prefixRead []byte +} + +func (r *stripOkReader) Read(b []byte) (int, error) { + if !r.firstDone { + r.firstDone = true + + // skip "ok" response + prefixRead := []byte{0, 0} + _, err := io.ReadFull(r.ReadCloser, prefixRead) + if err != nil { + return 0, errors.New("failed to read initial response").Base(err) + } + + if !bytes.Equal(prefixRead, []byte("ok")) { + // we read some garbage byte that may not have been "ok" at + // all. return a reader that replays what we have read so far + r.prefixRead = prefixRead + } + } + + if len(r.prefixRead) > 0 { + n := copy(b, r.prefixRead) + r.prefixRead = r.prefixRead[n:] + return n, nil + } + + n, err := r.ReadCloser.Read(b) + return n, err +} diff --git a/transport/internet/splithttp/upload_queue.go b/transport/internet/splithttp/upload_queue.go index 23124c4f..9ac38f8b 100644 --- a/transport/internet/splithttp/upload_queue.go +++ b/transport/internet/splithttp/upload_queue.go @@ -51,8 +51,10 @@ func (h *uploadQueue) Close() error { h.writeCloseMutex.Lock() defer h.writeCloseMutex.Unlock() - h.closed = true - close(h.pushedPackets) + if !h.closed { + h.closed = true + close(h.pushedPackets) + } return nil }