From c2ac58d93ee08d2df66839ecf72285291ea3e781 Mon Sep 17 00:00:00 2001 From: rPDmYQ <195319688+rPDmYQ@users.noreply.github.com> Date: Thu, 23 Jan 2025 15:41:02 +0000 Subject: [PATCH] XHTTP Client: Race Dialer --- transport/internet/splithttp/dialer.go | 60 ++- transport/internet/splithttp/race_dialer.go | 511 ++++++++++++++++++++ 2 files changed, 561 insertions(+), 10 deletions(-) create mode 100644 transport/internet/splithttp/race_dialer.go diff --git a/transport/internet/splithttp/dialer.go b/transport/internet/splithttp/dialer.go index 6a276e8f..d31d44a8 100644 --- a/transport/internet/splithttp/dialer.go +++ b/transport/internet/splithttp/dialer.go @@ -8,6 +8,7 @@ import ( "net/http" "net/http/httptrace" "net/url" + "slices" "strconv" "sync" "sync/atomic" @@ -93,6 +94,9 @@ func decideHTTPVersion(tlsConfig *tls.Config, realityConfig *reality.Config) str return "1.1" } if len(tlsConfig.NextProtocol) != 1 { + if slices.Contains(tlsConfig.NextProtocol, "h3") && slices.Contains(tlsConfig.NextProtocol, "h2") { + return "3+2" + } return "2" } if tlsConfig.NextProtocol[0] == "http/1.1" { @@ -101,6 +105,7 @@ func decideHTTPVersion(tlsConfig *tls.Config, realityConfig *reality.Config) str if tlsConfig.NextProtocol[0] == "h3" { return "3" } + return "2" } @@ -109,14 +114,27 @@ func createHTTPClient(dest net.Destination, streamSettings *internet.MemoryStrea realityConfig := reality.ConfigFromStreamSettings(streamSettings) httpVersion := decideHTTPVersion(tlsConfig, realityConfig) - if httpVersion == "3" { - dest.Network = net.Network_UDP // better to keep this line - } var gotlsConfig *gotls.Config + var h3gotlsConfig *gotls.Config if tlsConfig != nil { gotlsConfig = tlsConfig.GetTLSConfig(tls.WithDestination(dest)) + h3gotlsConfig = gotlsConfig + + if httpVersion == "3+2" { + h3gotlsConfig = &gotls.Config{} + *h3gotlsConfig = *gotlsConfig + + // Make QUIC ALPN only contains h3, and remove h3 from TCP TLS ALPN + h3gotlsConfig.NextProtos = []string{"h3"} + h3idx := slices.Index(h3gotlsConfig.NextProtos, "h3") + // Don't modify original tlsConfig.NextProtocol + nextProtos := gotlsConfig.NextProtos + gotlsConfig.NextProtos = make([]string, 0, len(nextProtos)-1) + gotlsConfig.NextProtos = append(gotlsConfig.NextProtos, nextProtos[:h3idx]...) + gotlsConfig.NextProtos = append(gotlsConfig.NextProtos, nextProtos[h3idx+1:]...) + } } transportConfig := streamSettings.ProtocolSettings.(*Config) @@ -152,7 +170,7 @@ func createHTTPClient(dest net.Destination, streamSettings *internet.MemoryStrea var transport http.RoundTripper - if httpVersion == "3" { + makeH3Transport := func() *http3.Transport { if keepAlivePeriod == 0 { keepAlivePeriod = quicgoH3KeepAlivePeriod } @@ -168,9 +186,11 @@ func createHTTPClient(dest net.Destination, streamSettings *internet.MemoryStrea MaxIncomingStreams: -1, KeepAlivePeriod: keepAlivePeriod, } - transport = &http3.RoundTripper{ + dest := dest + dest.Network = net.Network_UDP + return &http3.Transport{ QUICConfig: quicConfig, - TLSClientConfig: gotlsConfig, + TLSClientConfig: h3gotlsConfig, Dial: func(ctx context.Context, addr string, tlsCfg *gotls.Config, cfg *quic.Config) (quic.EarlyConnection, error) { conn, err := internet.DialSystem(ctx, dest, streamSettings.SocketSettings) if err != nil { @@ -208,26 +228,30 @@ func createHTTPClient(dest net.Destination, streamSettings *internet.MemoryStrea return quic.DialEarly(ctx, udpConn, udpAddr, tlsCfg, cfg) }, } - } else if httpVersion == "2" { + } + + makeH2Transport := func() *http2.Transport { if keepAlivePeriod == 0 { keepAlivePeriod = chromeH2KeepAlivePeriod } if keepAlivePeriod < 0 { keepAlivePeriod = 0 } - transport = &http2.Transport{ + return &http2.Transport{ DialTLSContext: func(ctxInner context.Context, network string, addr string, cfg *gotls.Config) (net.Conn, error) { return dialContext(ctxInner) }, IdleConnTimeout: connIdleTimeout, ReadIdleTimeout: keepAlivePeriod, } - } else { + } + + makeTransport := func() *http.Transport { httpDialContext := func(ctxInner context.Context, network string, addr string) (net.Conn, error) { return dialContext(ctxInner) } - transport = &http.Transport{ + return &http.Transport{ DialTLSContext: httpDialContext, DialContext: httpDialContext, IdleConnTimeout: connIdleTimeout, @@ -237,6 +261,22 @@ func createHTTPClient(dest net.Destination, streamSettings *internet.MemoryStrea } } + switch httpVersion { + case "3": + transport = makeH3Transport() + case "2": + transport = makeH2Transport() + case "3+2": + raceTransport := &raceTransport{ + h3: makeH3Transport(), + h2: makeH2Transport(), + dest: dest, + } + transport = raceTransport.setup() + default: + transport = makeTransport() + } + client := &DefaultDialerClient{ transportConfig: transportConfig, client: &http.Client{ diff --git a/transport/internet/splithttp/race_dialer.go b/transport/internet/splithttp/race_dialer.go new file mode 100644 index 00000000..8acd66dc --- /dev/null +++ b/transport/internet/splithttp/race_dialer.go @@ -0,0 +1,511 @@ +package splithttp + +import ( + "context" + gotls "crypto/tls" + goerrors "errors" + gonet "net" + "net/http" + "strings" + "sync" + "sync/atomic" + "time" + "unsafe" + + "github.com/xtls/quic-go" + "github.com/xtls/quic-go/http3" + "github.com/xtls/xray-core/common/errors" + "github.com/xtls/xray-core/common/net" + "golang.org/x/net/http2" + "golang.org/x/net/idna" +) + +const ( + // QuicSessionPool::GetTimeDelayForWaitingJob > kDefaultRTT + chromeH2TryDelay = 300 * time.Millisecond + + // kDefaultBrokenAlternativeProtocolDelay + chromeH3BrokenDelay = 5 * time.Minute +) + +type raceKeyType struct{} + +var raceKey raceKeyType + +type noDialKeyType struct{} + +var noDialKey noDialKeyType + +var ( + loseRaceError = goerrors.New("lose race") + brokenSpanError = goerrors.New("protocol temporarily broken") +) + +type raceState int + +const ( + raceInitialized raceState = 0 + raceEstablished raceState = 1 + raceErrored raceState = -1 +) + +type raceResult int + +const ( + raceInflight raceResult = 0 + raceH3 raceResult = 1 + raceH2 raceResult = 2 + raceFailed raceResult = -1 + raceInactive raceResult = -2 +) + +type raceNotify struct { + c chan struct{} + result raceResult + + // left is the remove counter. This entry should be removed from race map when it reached 0 + left atomic.Int32 +} + +func (r *raceNotify) wait() raceResult { + <-r.c + return r.result +} + +type raceInfo struct { + key uintptr + host string + notify *raceNotify +} + +type raceTransport struct { + h3 *http3.Transport + h2 *http2.Transport + dest net.Destination + + muFlag sync.Mutex + // flag must be read or write with muFlag holding. + flag map[uintptr]raceState + + muRace sync.RWMutex + // race must be read or write with muRace holding. + // Entries are: + // added by RoundTrip before racing start + // removed by the last finished racer + race map[string]*raceNotify + + h3broken atomic.Int64 +} + +func (t *raceTransport) isH3Broken() bool { + return time.Now().UnixNano()-t.h3broken.Load() < int64(chromeH3BrokenDelay) +} + +func (t *raceTransport) updateH3Broken(broken time.Time) { + value := broken.UnixNano() + old := t.h3broken.Load() + for old < value { + if t.h3broken.CompareAndSwap(old, value) { + break + } + old = t.h3broken.Load() + } +} + +func (t *raceTransport) setup() *raceTransport { + t.flag = make(map[uintptr]raceState) + t.race = make(map[string]*raceNotify) + + h3dial := t.h3.Dial + h2dial := t.h2.DialTLSContext + + t.h3.Dial = func(ctx context.Context, addr string, tlsCfg *gotls.Config, cfg *quic.Config) (conn quic.EarlyConnection, err error) { + if ctx.Value(noDialKey) != nil { + return nil, http3.ErrNoCachedConn + } + + info := ctx.Value(raceKey).(raceInfo) + key := info.key + defer func() { + // We can safely remove the raceNotify here, since both h2 and h3 Transport + // hold mutex while dialing. + // So another request can't slip in after we removed raceNotify but before + // Transport put the returned conn into pool - they will always reuse the conn we returned. + if err == nil { + info.notify.result = raceH3 + close(info.notify.c) + } + if info.notify.left.Add(-1) == 0 { + errors.LogDebug(ctx, "Race Dialer: h3 cleaning race wait") + t.muRace.Lock() + if t.race[info.host] == info.notify { + delete(t.race, info.host) + } + t.muRace.Unlock() + } + }() + + if t.isH3Broken() { + return nil, brokenSpanError + } + + t.muFlag.Lock() + established := t.flag[key] + t.muFlag.Unlock() + if established == raceEstablished { + errors.LogDebug(ctx, "Race Dialer: h3 lose (h2 established)") + return nil, loseRaceError + } + + conn, err = h3dial(ctx, addr, tlsCfg, cfg) + + t.muFlag.Lock() + established = t.flag[key] + if err != nil { + // We fail. + // Record if we are the first, cleanup if we are the last. + if established == raceInitialized { + t.flag[key] = raceErrored + errors.LogDebug(ctx, "Race Dialer: h3 lose (h3 error)") + } else { + delete(t.flag, key) + errors.LogDebug(ctx, "Race Dialer: h3 draw (both error)") + } + t.muFlag.Unlock() + return nil, err + } + + switch established { + case raceEstablished: + // h2 wins. + delete(t.flag, key) + t.muFlag.Unlock() + _ = conn.CloseWithError(quic.ApplicationErrorCode(quic.NoError), "lose race") + errors.LogDebug(ctx, "Race Dialer: h3 lose (h2 established)") + return nil, loseRaceError + case raceErrored: + // h2 errored first. We will always be used. + delete(t.flag, key) + t.muFlag.Unlock() + errors.LogDebug(ctx, "Race Dialer: h3 win (h2 error)") + return conn, nil + } + + // Don't consider we win until handshake completed. + t.muFlag.Unlock() + <-conn.HandshakeComplete() + errors.LogDebug(ctx, "Race Dialer: h3 handshake complete") + t.muFlag.Lock() + + established = t.flag[key] + switch established { + case raceEstablished: + // h2 wins. + delete(t.flag, key) + t.muFlag.Unlock() + _ = conn.CloseWithError(quic.ApplicationErrorCode(quic.NoError), "lose race") + errors.LogDebug(ctx, "Race Dialer: h3 lose (h2 established)") + return nil, loseRaceError + case raceErrored: + // h2 errored first. We clean up. + delete(t.flag, key) + errors.LogDebug(ctx, "Race Dialer: h3 win (h2 error)") + case raceInitialized: + // Check if handshake failed. + if err = conn.Context().Err(); err != nil { + conn = nil + t.flag[key] = raceErrored + errors.LogDebug(ctx, "Race Dialer: h3 lose (h3 error first)") + } else { + t.flag[key] = raceEstablished + errors.LogDebug(ctx, "Race Dialer: h3 win (h3 first)") + } + } + + t.muFlag.Unlock() + return conn, err + } + + t.h2.DialTLSContext = func(ctx context.Context, network, addr string, cfg *gotls.Config) (conn net.Conn, err error) { + if ctx.Value(noDialKey) != nil { + return nil, http2.ErrNoCachedConn + } + + info := ctx.Value(raceKey).(raceInfo) + key := info.key + defer func() { + if err == nil { + info.notify.result = raceH2 + close(info.notify.c) + } + if info.notify.left.Add(-1) == 0 { + errors.LogDebug(ctx, "Race Dialer: h2 cleaning race wait") + t.muRace.Lock() + if t.race[info.host] == info.notify { + delete(t.race, info.host) + } + t.muRace.Unlock() + } + }() + + if !t.isH3Broken() { + time.Sleep(chromeH2TryDelay) + } + + t.muFlag.Lock() + established := t.flag[key] + t.muFlag.Unlock() + if established == raceEstablished { + errors.LogDebug(ctx, "Race Dialer: h2 lose (h3 established)") + return nil, loseRaceError + } + + conn, err = h2dial(ctx, network, addr, cfg) + + t.muFlag.Lock() + established = t.flag[key] + if err != nil { + // We fail. + // Record if we are the first, cleanup if we are the last. + if established == raceInitialized { + t.flag[key] = raceErrored + errors.LogDebug(ctx, "Race Dialer: h2 lose (h2 error first)") + } else { + delete(t.flag, key) + errors.LogDebug(ctx, "Race Dialer: h2 draw (both error)") + } + t.muFlag.Unlock() + return nil, err + } + + switch established { + case raceEstablished: + // h3 wins. + delete(t.flag, key) + t.muFlag.Unlock() + _ = conn.Close() + errors.LogDebug(ctx, "Race Dialer: h2 lose (h3 established)") + return nil, loseRaceError + case raceErrored: + // h3 errored first. We clean up. + delete(t.flag, key) + errors.LogDebug(ctx, "Race Dialer: h2 win (h3 error)") + case raceInitialized: + // We win, record. + t.flag[key] = raceEstablished + errors.LogDebug(ctx, "Race Dialer: h2 win (h2 first)") + } + + t.muFlag.Unlock() + return conn, nil + } + + return t +} + +func authorityAddr(scheme string, authority string) (addr string) { + host, port, err := net.SplitHostPort(authority) + if err != nil { // authority didn't have a port + host = authority + port = "" + } + if port == "" { // authority's port was empty + port = "443" + if scheme == "http" { + port = "80" + } + } + if a, err := idna.ToASCII(host); err == nil { + host = a + } + // IPv6 address literal, without a port: + if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") { + return host + ":" + port + } + return gonet.JoinHostPort(host, port) +} + +func (t *raceTransport) RoundTrip(req *http.Request) (*http.Response, error) { + ctx := req.Context() + + // If there is inflight racing for current host, let it finish first, + // so we can know and reuse winner's conn. + host := authorityAddr(req.URL.Scheme, req.URL.Host) + t.muRace.RLock() + notify, ok := t.race[host] + t.muRace.RUnlock() + +WaitRace: + raceResult := raceInactive + if ok { + errors.LogDebug(ctx, "Race Dialer: found inflight race to ", t.dest.NetAddr(), ", waiting race winner") + raceResult = notify.wait() + errors.LogDebug(ctx, "Race Dialer: winner for ", t.dest.NetAddr(), " resolved, continue handling request") + } + + reqNoDial := req.WithContext(context.WithValue(ctx, noDialKey, struct{}{})) + + // First see if there's cached connection, for both h3 and h2. + // - raceInactive: no inflight race. Try both. + // - raceH3/raceH2: another request just decided race winner. + // Losing Transport may not yet fail, so avoid trying it. + // - raceFailed: both failed. There won't be cached conn, no need to try. + // - raceInflight: should not see this state. + if raceResult == raceH3 || raceResult == raceInactive { + if resp, err := t.h3.RoundTripOpt(reqNoDial, http3.RoundTripOpt{OnlyCachedConn: true}); err == nil { + errors.LogInfo(ctx, "Race Dialer: use h3 connection for ", t.dest.NetAddr(), " (reusing conn)") + return resp, nil + } else if !goerrors.Is(err, http3.ErrNoCachedConn) { + return nil, err + } else if raceResult == raceH3 { + return nil, errors.New("Race Dialer: h3 just succeeded, but no cached conn") + } + } + + if raceResult == raceH2 || raceResult == raceInactive { + // http2.RoundTripOpt.OnlyCachedConn is not effective. However, our noDialKey will block dialing anyway. + if resp, err := t.h2.RoundTripOpt(reqNoDial, http2.RoundTripOpt{OnlyCachedConn: true}); err == nil { + errors.LogInfo(ctx, "Race Dialer: use h2 connection for ", t.dest.NetAddr(), " (reusing conn)") + return resp, nil + } else if !goerrors.Is(err, http2.ErrNoCachedConn) { + return nil, err + } else if raceResult == raceH2 { + return nil, errors.New("Race Dialer: h2 just succeeded, but no cached conn") + } + } + + // Both don't have cached conn. Now race between h2 and h3. + t.muRace.Lock() + // Recheck + notify, ok = t.race[host] + if ok { + // Some other request started racing before us, we wait for them to finish. + t.muRace.Unlock() + goto WaitRace + } + // We are the goroutine to initialize racing. + notify = &raceNotify{c: make(chan struct{})} + notify.left.Store(2) + t.race[host] = notify + t.muRace.Unlock() + + errors.LogDebug(ctx, "Race Dialer: start race to ", t.dest.NetAddr()) + + // Both RoundTripper can share req.Body, because only one can dial successfully, + // and proceed to read request body. + key := uintptr(unsafe.Pointer(req)) + raceCtx := context.WithValue(ctx, raceKey, raceInfo{ + key: key, + host: host, + notify: notify, + }) + req = req.WithContext(raceCtx) + + t.muFlag.Lock() + t.flag[key] = raceInitialized + t.muFlag.Unlock() + + h2resp := make(chan any) + h3resp := make(chan any) + raceDone := make(chan struct{}) + + defer func() { + if notify.result == raceInflight { + notify.result = raceFailed + close(notify.c) + } + close(raceDone) + }() + + go func() { + resp, err := t.h3.RoundTrip(req) + + var result any + if err == nil { + result = resp + old := t.h3broken.Swap(0) + if old != 0 { + errors.LogDebug(ctx, "Race Dialer: h3 connection succeed, clear broken state") + } + } else { + result = err + t.updateH3Broken(time.Now()) + errors.LogDebug(ctx, "Race Dialer: h3 connection failed, set broken state") + } + + select { + case h3resp <- result: + case <-raceDone: + } + }() + + go func() { + resp, err := t.h2.RoundTrip(req) + if notify.left.Add(-1) == 0 { + errors.LogDebug(ctx, "Race Dialer: h2 cleaning race wait") + t.muRace.Lock() + delete(t.race, host) + t.muRace.Unlock() + } + + var result any + if err == nil { + result = resp + } else { + result = err + } + + select { + case h2resp <- result: + case <-raceDone: + } + }() + + reportState := func(isH3 bool) { + winner := "h2" + if isH3 { + winner = "h3" + } + errors.LogInfo(ctx, "Race Dialer: use ", winner, " connection for ", t.dest.NetAddr(), " (race winner)") + } + + handleResult := func(respErr any, other chan any, isH3 bool) (*http.Response, error) { + switch value := respErr.(type) { + case *http.Response: + // we win + reportState(isH3) + return value, nil + case error: + switch otherValue := (<-other).(type) { + case *http.Response: + // other win + reportState(!isH3) + return otherValue, nil + case error: + switch { + // hide internal error + case goerrors.Is(value, loseRaceError) || goerrors.Is(value, brokenSpanError): + return nil, otherValue + case goerrors.Is(otherValue, loseRaceError) || goerrors.Is(otherValue, brokenSpanError): + return nil, value + // prefer h3 error + case isH3: + return nil, value + default: + return nil, otherValue + } + default: + panic("unreachable: unexpected response type") + } + default: + panic("unreachable: unexpected response type") + } + } + + select { + case respErr := <-h3resp: + return handleResult(respErr, h2resp, true) + case respErr := <-h2resp: + return handleResult(respErr, h3resp, false) + } +}