Fix SplitHTTP race condition when creating new sessions (#3533)

Co-authored-by: nobody <nobody@nowhere.mars>
Co-authored-by: mmmray <142015632+mmmray@users.noreply.github.com>
This commit is contained in:
yuhan6665 2024-07-17 07:41:17 -04:00 committed by GitHub
parent a7e198e1e2
commit 02cd3b8c74
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 13 additions and 1 deletions

View File

@ -27,6 +27,7 @@ type requestHandler struct {
host string host string
path string path string
ln *Listener ln *Listener
sessionMu *sync.Mutex
sessions sync.Map sessions sync.Map
localAddr gonet.TCPAddr localAddr gonet.TCPAddr
} }
@ -56,11 +57,21 @@ func (h *requestHandler) maybeReapSession(isFullyConnected *done.Instance, sessi
} }
func (h *requestHandler) upsertSession(sessionId string) *httpSession { func (h *requestHandler) upsertSession(sessionId string) *httpSession {
// fast path
currentSessionAny, ok := h.sessions.Load(sessionId) currentSessionAny, ok := h.sessions.Load(sessionId)
if ok { if ok {
return currentSessionAny.(*httpSession) return currentSessionAny.(*httpSession)
} }
// slow path
h.sessionMu.Lock()
defer h.sessionMu.Unlock()
currentSessionAny, ok = h.sessions.Load(sessionId)
if ok {
return currentSessionAny.(*httpSession)
}
s := &httpSession{ s := &httpSession{
uploadQueue: NewUploadQueue(int(2 * h.ln.config.GetNormalizedMaxConcurrentUploads())), uploadQueue: NewUploadQueue(int(2 * h.ln.config.GetNormalizedMaxConcurrentUploads())),
isFullyConnected: done.New(), isFullyConnected: done.New(),
@ -277,6 +288,7 @@ func ListenSH(ctx context.Context, address net.Address, port net.Port, streamSet
host: shSettings.Host, host: shSettings.Host,
path: shSettings.GetNormalizedPath(), path: shSettings.GetNormalizedPath(),
ln: l, ln: l,
sessionMu: &sync.Mutex{},
sessions: sync.Map{}, sessions: sync.Map{},
localAddr: localAddr, localAddr: localAddr,
} }

View File

@ -63,8 +63,8 @@ func Test_listenSHAndDial(t *testing.T) {
} }
common.Must(conn.Close()) common.Must(conn.Close())
<-time.After(time.Second * 5)
conn, err = Dial(ctx, net.TCPDestination(net.DomainAddress("localhost"), listenPort), streamSettings) conn, err = Dial(ctx, net.TCPDestination(net.DomainAddress("localhost"), listenPort), streamSettings)
common.Must(err) common.Must(err)
_, err = conn.Write([]byte("Test connection 2")) _, err = conn.Write([]byte("Test connection 2"))
common.Must(err) common.Must(err)