From 02cd3b8c74335681ea4bb3664833cdef5298d994 Mon Sep 17 00:00:00 2001 From: yuhan6665 <1588741+yuhan6665@users.noreply.github.com> Date: Wed, 17 Jul 2024 07:41:17 -0400 Subject: [PATCH] Fix SplitHTTP race condition when creating new sessions (#3533) Co-authored-by: nobody Co-authored-by: mmmray <142015632+mmmray@users.noreply.github.com> --- transport/internet/splithttp/hub.go | 12 ++++++++++++ transport/internet/splithttp/splithttp_test.go | 2 +- 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/transport/internet/splithttp/hub.go b/transport/internet/splithttp/hub.go index 27fab68d..b31ec707 100644 --- a/transport/internet/splithttp/hub.go +++ b/transport/internet/splithttp/hub.go @@ -27,6 +27,7 @@ type requestHandler struct { host string path string ln *Listener + sessionMu *sync.Mutex sessions sync.Map localAddr gonet.TCPAddr } @@ -56,11 +57,21 @@ func (h *requestHandler) maybeReapSession(isFullyConnected *done.Instance, sessi } func (h *requestHandler) upsertSession(sessionId string) *httpSession { + // fast path currentSessionAny, ok := h.sessions.Load(sessionId) if ok { return currentSessionAny.(*httpSession) } + // slow path + h.sessionMu.Lock() + defer h.sessionMu.Unlock() + + currentSessionAny, ok = h.sessions.Load(sessionId) + if ok { + return currentSessionAny.(*httpSession) + } + s := &httpSession{ uploadQueue: NewUploadQueue(int(2 * h.ln.config.GetNormalizedMaxConcurrentUploads())), isFullyConnected: done.New(), @@ -277,6 +288,7 @@ func ListenSH(ctx context.Context, address net.Address, port net.Port, streamSet host: shSettings.Host, path: shSettings.GetNormalizedPath(), ln: l, + sessionMu: &sync.Mutex{}, sessions: sync.Map{}, localAddr: localAddr, } diff --git a/transport/internet/splithttp/splithttp_test.go b/transport/internet/splithttp/splithttp_test.go index d125cedd..c1c527ef 100644 --- a/transport/internet/splithttp/splithttp_test.go +++ b/transport/internet/splithttp/splithttp_test.go @@ -63,8 +63,8 @@ func Test_listenSHAndDial(t *testing.T) { } common.Must(conn.Close()) - <-time.After(time.Second * 5) conn, err = Dial(ctx, net.TCPDestination(net.DomainAddress("localhost"), listenPort), streamSettings) + common.Must(err) _, err = conn.Write([]byte("Test connection 2")) common.Must(err)