diff --git a/transport/internet/splithttp/dialer.go b/transport/internet/splithttp/dialer.go index d31d44a8..9f18e9ab 100644 --- a/transport/internet/splithttp/dialer.go +++ b/transport/internet/splithttp/dialer.go @@ -270,7 +270,7 @@ func createHTTPClient(dest net.Destination, streamSettings *internet.MemoryStrea raceTransport := &raceTransport{ h3: makeH3Transport(), h2: makeH2Transport(), - dest: dest, + dest: dest.NetAddr(), } transport = raceTransport.setup() default: diff --git a/transport/internet/splithttp/race_dialer.go b/transport/internet/splithttp/race_dialer.go index 8acd66dc..0efbc477 100644 --- a/transport/internet/splithttp/race_dialer.go +++ b/transport/internet/splithttp/race_dialer.go @@ -4,28 +4,37 @@ import ( "context" gotls "crypto/tls" goerrors "errors" - gonet "net" + "fmt" "net/http" - "strings" "sync" "sync/atomic" "time" - "unsafe" "github.com/xtls/quic-go" "github.com/xtls/quic-go/http3" "github.com/xtls/xray-core/common/errors" "github.com/xtls/xray-core/common/net" "golang.org/x/net/http2" - "golang.org/x/net/idna" ) const ( + // net/quic/quic_session_pool.cc // QuicSessionPool::GetTimeDelayForWaitingJob > kDefaultRTT - chromeH2TryDelay = 300 * time.Millisecond + chromeH2DefaultTryDelay = 300 * time.Millisecond + // QuicSessionPool::GetTimeDelayForWaitingJob > srtt + chromeH2TryDelayScale = 1.5 + // net/http/broken_alternative_services.cc // kDefaultBrokenAlternativeProtocolDelay - chromeH3BrokenDelay = 5 * time.Minute + chromeH3BrokenInitialDelay = 5 * time.Minute + // kMaxBrokenAlternativeProtocolDelay + chromeH3BrokenMaxDelay = 48 * time.Hour + // kBrokenDelayMaxShift + chromeH3BrokenMaxShift = 18 + + // net/third_party/quiche/src/quiche/quic/core/congestion_control/rtt_stats.cc + // kAlpha + chromeH3SmoothRTTAlpha = 0.125 ) type raceKeyType struct{} @@ -41,12 +50,14 @@ var ( brokenSpanError = goerrors.New("protocol temporarily broken") ) -type raceState int +func isRaceInternalError(err error) bool { + return goerrors.Is(err, loseRaceError) || goerrors.Is(err, brokenSpanError) +} const ( - raceInitialized raceState = 0 - raceEstablished raceState = 1 - raceErrored raceState = -1 + raceInitialized = 0 + raceEstablished = 1 + raceErrored = -1 ) type raceResult int @@ -59,11 +70,118 @@ const ( raceInactive raceResult = -2 ) +type endpointInfo struct { + lastFail time.Time + failCount int + rtt atomic.Int64 +} + +var h3EndpointCatalog map[string]*endpointInfo +var h3EndpointCatalogLock sync.RWMutex + +func isH3Broken(endpoint string) bool { + h3EndpointCatalogLock.RLock() + defer h3EndpointCatalogLock.RUnlock() + info, ok := h3EndpointCatalog[endpoint] + if !ok { + return false + } + + brokenDuration := min(chromeH3BrokenInitialDelay< 0 { + return 0 + } + rtt := info.rtt.Load() + if rtt == 0 { + return chromeH2DefaultTryDelay + } + return time.Duration(chromeH2TryDelayScale * float64(rtt)) +} + +func updateH3Broken(endpoint string, brokenAt time.Time) int { + h3EndpointCatalogLock.Lock() + defer h3EndpointCatalogLock.Unlock() + if h3EndpointCatalog == nil { + h3EndpointCatalog = make(map[string]*endpointInfo) + } + + info, ok := h3EndpointCatalog[endpoint] + if !ok { + h3EndpointCatalog[endpoint] = &endpointInfo{ + lastFail: brokenAt, + failCount: 1, + } + return 1 + } + + info.failCount++ + if brokenAt.After(info.lastFail) { + info.lastFail = brokenAt + } + + return info.failCount +} + +func smoothedRtt(oldRtt, newRtt int64) int64 { + if oldRtt == 0 { + return newRtt + } + + return int64((1-chromeH3SmoothRTTAlpha)*float64(oldRtt) + chromeH3SmoothRTTAlpha*float64(newRtt)) +} + +func updateH3RTT(endpoint string, rtt time.Duration) { + h3EndpointCatalogLock.RLock() + info, ok := h3EndpointCatalog[endpoint] + if !ok || info.failCount > 0 { + h3EndpointCatalogLock.RUnlock() + updateH3RTTSlow(endpoint, rtt) + return + } + + defer h3EndpointCatalogLock.RUnlock() + for { + oldRtt := info.rtt.Load() + newRtt := smoothedRtt(oldRtt, int64(rtt)) + if info.rtt.CompareAndSwap(oldRtt, newRtt) { + return + } + } +} + +func updateH3RTTSlow(endpoint string, rtt time.Duration) { + h3EndpointCatalogLock.Lock() + defer h3EndpointCatalogLock.Unlock() + + info, ok := h3EndpointCatalog[endpoint] + switch { + case !ok: + info = &endpointInfo{} + info.rtt.Store(int64(rtt)) + case info.failCount > 0: + info.failCount = 0 + info.lastFail = time.Time{} + info.rtt.Store(int64(rtt)) + default: + info.rtt.Store(smoothedRtt(info.rtt.Load(), int64(rtt))) + } +} + type raceNotify struct { c chan struct{} result raceResult - // left is the remove counter. This entry should be removed from race map when it reached 0 + // left is the remove counter. It should be released when it reached 0 left atomic.Int32 } @@ -72,50 +190,16 @@ func (r *raceNotify) wait() raceResult { return r.result } -type raceInfo struct { - key uintptr - host string - notify *raceNotify -} - type raceTransport struct { h3 *http3.Transport h2 *http2.Transport - dest net.Destination + dest string - muFlag sync.Mutex - // flag must be read or write with muFlag holding. - flag map[uintptr]raceState - - muRace sync.RWMutex - // race must be read or write with muRace holding. - // Entries are: - // added by RoundTrip before racing start - // removed by the last finished racer - race map[string]*raceNotify - - h3broken atomic.Int64 -} - -func (t *raceTransport) isH3Broken() bool { - return time.Now().UnixNano()-t.h3broken.Load() < int64(chromeH3BrokenDelay) -} - -func (t *raceTransport) updateH3Broken(broken time.Time) { - value := broken.UnixNano() - old := t.h3broken.Load() - for old < value { - if t.h3broken.CompareAndSwap(old, value) { - break - } - old = t.h3broken.Load() - } + flag atomic.Int64 + notify atomic.Pointer[raceNotify] } func (t *raceTransport) setup() *raceTransport { - t.flag = make(map[uintptr]raceState) - t.race = make(map[string]*raceNotify) - h3dial := t.h3.Dial h2dial := t.h2.DialTLSContext @@ -124,106 +208,108 @@ func (t *raceTransport) setup() *raceTransport { return nil, http3.ErrNoCachedConn } - info := ctx.Value(raceKey).(raceInfo) - key := info.key + var dialStart time.Time + defer func() { + notify := t.notify.Load() + if err == nil { + updateH3RTT(t.dest, time.Since(dialStart)) + notify.result = raceH3 + close(notify.c) + } else if !isRaceInternalError(err) { + failed := updateH3Broken(t.dest, time.Now()) + errors.LogDebug(ctx, "Race Dialer: h3 connection to ", t.dest, " failed ", failed, "times") + } + // We can safely remove the raceNotify here, since both h2 and h3 Transport // hold mutex while dialing. // So another request can't slip in after we removed raceNotify but before // Transport put the returned conn into pool - they will always reuse the conn we returned. - if err == nil { - info.notify.result = raceH3 - close(info.notify.c) - } - if info.notify.left.Add(-1) == 0 { + if notify.left.Add(-1) == 0 { errors.LogDebug(ctx, "Race Dialer: h3 cleaning race wait") - t.muRace.Lock() - if t.race[info.host] == info.notify { - delete(t.race, info.host) - } - t.muRace.Unlock() + t.notify.Store(nil) } }() - if t.isH3Broken() { + if isH3Broken(t.dest) { return nil, brokenSpanError } - t.muFlag.Lock() - established := t.flag[key] - t.muFlag.Unlock() + established := t.flag.Load() if established == raceEstablished { - errors.LogDebug(ctx, "Race Dialer: h3 lose (h2 established)") + errors.LogDebug(ctx, "Race Dialer: h3 lose (h2 established before try)") return nil, loseRaceError } + dialStart = time.Now() conn, err = h3dial(ctx, addr, tlsCfg, cfg) - t.muFlag.Lock() - established = t.flag[key] if err != nil { // We fail. - // Record if we are the first, cleanup if we are the last. - if established == raceInitialized { - t.flag[key] = raceErrored + // Record if we are the first. + if t.flag.CompareAndSwap(raceInitialized, raceErrored) { errors.LogDebug(ctx, "Race Dialer: h3 lose (h3 error)") } else { - delete(t.flag, key) errors.LogDebug(ctx, "Race Dialer: h3 draw (both error)") } - t.muFlag.Unlock() return nil, err } - switch established { + flag := t.flag.Load() + switch flag { case raceEstablished: // h2 wins. - delete(t.flag, key) - t.muFlag.Unlock() _ = conn.CloseWithError(quic.ApplicationErrorCode(quic.NoError), "lose race") - errors.LogDebug(ctx, "Race Dialer: h3 lose (h2 established)") + errors.LogDebug(ctx, "Race Dialer: h3 lose (h2 established before handshake complete)") return nil, loseRaceError case raceErrored: // h2 errored first. We will always be used. - delete(t.flag, key) - t.muFlag.Unlock() errors.LogDebug(ctx, "Race Dialer: h3 win (h2 error)") return conn, nil + case raceInitialized: + // continue + default: + panic(fmt.Sprintf("unreachable: unknown race flag: %d", flag)) } // Don't consider we win until handshake completed. - t.muFlag.Unlock() <-conn.HandshakeComplete() errors.LogDebug(ctx, "Race Dialer: h3 handshake complete") - t.muFlag.Lock() - established = t.flag[key] - switch established { + if err = conn.Context().Err(); err != nil { + if t.flag.CompareAndSwap(raceInitialized, raceErrored) { + errors.LogDebug(ctx, "Race Dialer: h3 lose (h3 error first)") + return nil, err + } + _ = conn.CloseWithError(quic.ApplicationErrorCode(quic.NoError), "lose race") + conn = nil + } else { + if t.flag.CompareAndSwap(raceInitialized, raceEstablished) { + errors.LogDebug(ctx, "Race Dialer: h3 win (h3 first)") + return conn, nil + } + } + + flag = t.flag.Load() + switch flag { case raceEstablished: // h2 wins. - delete(t.flag, key) - t.muFlag.Unlock() _ = conn.CloseWithError(quic.ApplicationErrorCode(quic.NoError), "lose race") errors.LogDebug(ctx, "Race Dialer: h3 lose (h2 established)") return nil, loseRaceError case raceErrored: - // h2 errored first. We clean up. - delete(t.flag, key) - errors.LogDebug(ctx, "Race Dialer: h3 win (h2 error)") - case raceInitialized: - // Check if handshake failed. - if err = conn.Context().Err(); err != nil { - conn = nil - t.flag[key] = raceErrored - errors.LogDebug(ctx, "Race Dialer: h3 lose (h3 error first)") + // h2 errored first. + if err == nil { + errors.LogDebug(ctx, "Race Dialer: h3 win (h2 error)") } else { - t.flag[key] = raceEstablished - errors.LogDebug(ctx, "Race Dialer: h3 win (h3 first)") + errors.LogDebug(ctx, "Race Dialer: h3 draw (both error)") } + return conn, err + case raceInitialized: + panic("unreachable: race flag should not revert to raceInitialized") + default: + panic(fmt.Sprintf("unreachable: unknown race flag: %d", flag)) } - - t.muFlag.Unlock() - return conn, err } t.h2.DialTLSContext = func(ctx context.Context, network, addr string, cfg *gotls.Config) (conn net.Conn, err error) { @@ -231,116 +317,84 @@ func (t *raceTransport) setup() *raceTransport { return nil, http2.ErrNoCachedConn } - info := ctx.Value(raceKey).(raceInfo) - key := info.key defer func() { + notify := t.notify.Load() if err == nil { - info.notify.result = raceH2 - close(info.notify.c) + notify.result = raceH2 + close(notify.c) } - if info.notify.left.Add(-1) == 0 { + if notify.left.Add(-1) == 0 { errors.LogDebug(ctx, "Race Dialer: h2 cleaning race wait") - t.muRace.Lock() - if t.race[info.host] == info.notify { - delete(t.race, info.host) - } - t.muRace.Unlock() + t.notify.Store(nil) } }() - if !t.isH3Broken() { - time.Sleep(chromeH2TryDelay) - } + delay := getH2Delay(t.dest) + errors.LogDebug(ctx, "Race Dialer: h2 dial delay: ", delay) + time.Sleep(delay) - t.muFlag.Lock() - established := t.flag[key] - t.muFlag.Unlock() + established := t.flag.Load() if established == raceEstablished { - errors.LogDebug(ctx, "Race Dialer: h2 lose (h3 established)") + errors.LogDebug(ctx, "Race Dialer: h2 lose (h3 established before try)") return nil, loseRaceError } conn, err = h2dial(ctx, network, addr, cfg) - t.muFlag.Lock() - established = t.flag[key] if err != nil { // We fail. - // Record if we are the first, cleanup if we are the last. - if established == raceInitialized { - t.flag[key] = raceErrored + // Record if we are the first. + if t.flag.CompareAndSwap(raceInitialized, raceErrored) { errors.LogDebug(ctx, "Race Dialer: h2 lose (h2 error first)") - } else { - delete(t.flag, key) - errors.LogDebug(ctx, "Race Dialer: h2 draw (both error)") + return nil, err + } + _ = conn.Close() + conn = nil + } else { + if t.flag.CompareAndSwap(raceInitialized, raceEstablished) { + errors.LogDebug(ctx, "Race Dialer: h2 win (h2 first)") + return conn, nil } - t.muFlag.Unlock() - return nil, err } - switch established { + flag := t.flag.Load() + switch flag { case raceEstablished: // h3 wins. - delete(t.flag, key) - t.muFlag.Unlock() _ = conn.Close() errors.LogDebug(ctx, "Race Dialer: h2 lose (h3 established)") return nil, loseRaceError case raceErrored: - // h3 errored first. We clean up. - delete(t.flag, key) - errors.LogDebug(ctx, "Race Dialer: h2 win (h3 error)") + // h3 errored first. + if err == nil { + errors.LogDebug(ctx, "Race Dialer: h2 win (h3 error)") + } else { + errors.LogDebug(ctx, "Race Dialer: h2 draw (both error)") + } + return conn, err case raceInitialized: - // We win, record. - t.flag[key] = raceEstablished - errors.LogDebug(ctx, "Race Dialer: h2 win (h2 first)") + panic("unreachable: race flag should not revert to raceInitialized") + default: + panic(fmt.Sprintf("unreachable: unknown race flag: %d", flag)) } - - t.muFlag.Unlock() - return conn, nil } return t } -func authorityAddr(scheme string, authority string) (addr string) { - host, port, err := net.SplitHostPort(authority) - if err != nil { // authority didn't have a port - host = authority - port = "" - } - if port == "" { // authority's port was empty - port = "443" - if scheme == "http" { - port = "80" - } - } - if a, err := idna.ToASCII(host); err == nil { - host = a - } - // IPv6 address literal, without a port: - if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") { - return host + ":" + port - } - return gonet.JoinHostPort(host, port) -} - func (t *raceTransport) RoundTrip(req *http.Request) (*http.Response, error) { ctx := req.Context() - // If there is inflight racing for current host, let it finish first, + // If there is inflight racing, let it finish first, // so we can know and reuse winner's conn. - host := authorityAddr(req.URL.Scheme, req.URL.Host) - t.muRace.RLock() - notify, ok := t.race[host] - t.muRace.RUnlock() + notify := t.notify.Load() + raceResult := raceInactive WaitRace: - raceResult := raceInactive - if ok { - errors.LogDebug(ctx, "Race Dialer: found inflight race to ", t.dest.NetAddr(), ", waiting race winner") + if notify != nil { + errors.LogDebug(ctx, "Race Dialer: found inflight race to ", t.dest, ", waiting race winner") raceResult = notify.wait() - errors.LogDebug(ctx, "Race Dialer: winner for ", t.dest.NetAddr(), " resolved, continue handling request") + errors.LogDebug(ctx, "Race Dialer: winner for ", t.dest, " resolved, continue handling request") } reqNoDial := req.WithContext(context.WithValue(ctx, noDialKey, struct{}{})) @@ -353,7 +407,7 @@ WaitRace: // - raceInflight: should not see this state. if raceResult == raceH3 || raceResult == raceInactive { if resp, err := t.h3.RoundTripOpt(reqNoDial, http3.RoundTripOpt{OnlyCachedConn: true}); err == nil { - errors.LogInfo(ctx, "Race Dialer: use h3 connection for ", t.dest.NetAddr(), " (reusing conn)") + errors.LogInfo(ctx, "Race Dialer: use h3 connection for ", t.dest, " (reusing conn)") return resp, nil } else if !goerrors.Is(err, http3.ErrNoCachedConn) { return nil, err @@ -365,7 +419,7 @@ WaitRace: if raceResult == raceH2 || raceResult == raceInactive { // http2.RoundTripOpt.OnlyCachedConn is not effective. However, our noDialKey will block dialing anyway. if resp, err := t.h2.RoundTripOpt(reqNoDial, http2.RoundTripOpt{OnlyCachedConn: true}); err == nil { - errors.LogInfo(ctx, "Race Dialer: use h2 connection for ", t.dest.NetAddr(), " (reusing conn)") + errors.LogInfo(ctx, "Race Dialer: use h2 connection for ", t.dest, " (reusing conn)") return resp, nil } else if !goerrors.Is(err, http2.ErrNoCachedConn) { return nil, err @@ -375,35 +429,18 @@ WaitRace: } // Both don't have cached conn. Now race between h2 and h3. - t.muRace.Lock() - // Recheck - notify, ok = t.race[host] - if ok { - // Some other request started racing before us, we wait for them to finish. - t.muRace.Unlock() - goto WaitRace - } - // We are the goroutine to initialize racing. + // Recheck first. notify = &raceNotify{c: make(chan struct{})} notify.left.Store(2) - t.race[host] = notify - t.muRace.Unlock() + if !t.notify.CompareAndSwap(nil, notify) { + // Some other request started racing before us, we wait for them to finish. + goto WaitRace + } - errors.LogDebug(ctx, "Race Dialer: start race to ", t.dest.NetAddr()) + // We are the goroutine to initialize racing. + errors.LogDebug(ctx, "Race Dialer: start race to ", t.dest) - // Both RoundTripper can share req.Body, because only one can dial successfully, - // and proceed to read request body. - key := uintptr(unsafe.Pointer(req)) - raceCtx := context.WithValue(ctx, raceKey, raceInfo{ - key: key, - host: host, - notify: notify, - }) - req = req.WithContext(raceCtx) - - t.muFlag.Lock() - t.flag[key] = raceInitialized - t.muFlag.Unlock() + t.flag.Store(raceInitialized) h2resp := make(chan any) h3resp := make(chan any) @@ -417,36 +454,10 @@ WaitRace: close(raceDone) }() - go func() { - resp, err := t.h3.RoundTrip(req) - - var result any - if err == nil { - result = resp - old := t.h3broken.Swap(0) - if old != 0 { - errors.LogDebug(ctx, "Race Dialer: h3 connection succeed, clear broken state") - } - } else { - result = err - t.updateH3Broken(time.Now()) - errors.LogDebug(ctx, "Race Dialer: h3 connection failed, set broken state") - } - - select { - case h3resp <- result: - case <-raceDone: - } - }() - - go func() { - resp, err := t.h2.RoundTrip(req) - if notify.left.Add(-1) == 0 { - errors.LogDebug(ctx, "Race Dialer: h2 cleaning race wait") - t.muRace.Lock() - delete(t.race, host) - t.muRace.Unlock() - } + // Both RoundTripper can share req.Body, because only one can dial successfully, + // and proceed to read request body. + roundTrip := func(r http.RoundTripper, respChan chan any) { + resp, err := r.RoundTrip(req) var result any if err == nil { @@ -456,17 +467,20 @@ WaitRace: } select { - case h2resp <- result: + case respChan <- result: case <-raceDone: } - }() + } + + go roundTrip(t.h3, h3resp) + go roundTrip(t.h2, h2resp) reportState := func(isH3 bool) { winner := "h2" if isH3 { winner = "h3" } - errors.LogInfo(ctx, "Race Dialer: use ", winner, " connection for ", t.dest.NetAddr(), " (race winner)") + errors.LogInfo(ctx, "Race Dialer: use ", winner, " connection for ", t.dest, " (race winner)") } handleResult := func(respErr any, other chan any, isH3 bool) (*http.Response, error) { @@ -484,9 +498,9 @@ WaitRace: case error: switch { // hide internal error - case goerrors.Is(value, loseRaceError) || goerrors.Is(value, brokenSpanError): + case isRaceInternalError(value): return nil, otherValue - case goerrors.Is(otherValue, loseRaceError) || goerrors.Is(otherValue, brokenSpanError): + case isRaceInternalError(otherValue): return nil, value // prefer h3 error case isH3: @@ -495,10 +509,10 @@ WaitRace: return nil, otherValue } default: - panic("unreachable: unexpected response type") + panic(fmt.Sprintf("unreachable: unexpected response type %T", otherValue)) } default: - panic("unreachable: unexpected response type") + panic(fmt.Sprintf("unreachable: unexpected response type %T", value)) } }