XHTTP Client: Race Dialer

This commit is contained in:
rPDmYQ 2025-01-23 15:41:02 +00:00 committed by GitHub
parent ca9a902213
commit c2ac58d93e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 561 additions and 10 deletions

View File

@ -8,6 +8,7 @@ import (
"net/http"
"net/http/httptrace"
"net/url"
"slices"
"strconv"
"sync"
"sync/atomic"
@ -93,6 +94,9 @@ func decideHTTPVersion(tlsConfig *tls.Config, realityConfig *reality.Config) str
return "1.1"
}
if len(tlsConfig.NextProtocol) != 1 {
if slices.Contains(tlsConfig.NextProtocol, "h3") && slices.Contains(tlsConfig.NextProtocol, "h2") {
return "3+2"
}
return "2"
}
if tlsConfig.NextProtocol[0] == "http/1.1" {
@ -101,6 +105,7 @@ func decideHTTPVersion(tlsConfig *tls.Config, realityConfig *reality.Config) str
if tlsConfig.NextProtocol[0] == "h3" {
return "3"
}
return "2"
}
@ -109,14 +114,27 @@ func createHTTPClient(dest net.Destination, streamSettings *internet.MemoryStrea
realityConfig := reality.ConfigFromStreamSettings(streamSettings)
httpVersion := decideHTTPVersion(tlsConfig, realityConfig)
if httpVersion == "3" {
dest.Network = net.Network_UDP // better to keep this line
}
var gotlsConfig *gotls.Config
var h3gotlsConfig *gotls.Config
if tlsConfig != nil {
gotlsConfig = tlsConfig.GetTLSConfig(tls.WithDestination(dest))
h3gotlsConfig = gotlsConfig
if httpVersion == "3+2" {
h3gotlsConfig = &gotls.Config{}
*h3gotlsConfig = *gotlsConfig
// Make QUIC ALPN only contains h3, and remove h3 from TCP TLS ALPN
h3gotlsConfig.NextProtos = []string{"h3"}
h3idx := slices.Index(h3gotlsConfig.NextProtos, "h3")
// Don't modify original tlsConfig.NextProtocol
nextProtos := gotlsConfig.NextProtos
gotlsConfig.NextProtos = make([]string, 0, len(nextProtos)-1)
gotlsConfig.NextProtos = append(gotlsConfig.NextProtos, nextProtos[:h3idx]...)
gotlsConfig.NextProtos = append(gotlsConfig.NextProtos, nextProtos[h3idx+1:]...)
}
}
transportConfig := streamSettings.ProtocolSettings.(*Config)
@ -152,7 +170,7 @@ func createHTTPClient(dest net.Destination, streamSettings *internet.MemoryStrea
var transport http.RoundTripper
if httpVersion == "3" {
makeH3Transport := func() *http3.Transport {
if keepAlivePeriod == 0 {
keepAlivePeriod = quicgoH3KeepAlivePeriod
}
@ -168,9 +186,11 @@ func createHTTPClient(dest net.Destination, streamSettings *internet.MemoryStrea
MaxIncomingStreams: -1,
KeepAlivePeriod: keepAlivePeriod,
}
transport = &http3.RoundTripper{
dest := dest
dest.Network = net.Network_UDP
return &http3.Transport{
QUICConfig: quicConfig,
TLSClientConfig: gotlsConfig,
TLSClientConfig: h3gotlsConfig,
Dial: func(ctx context.Context, addr string, tlsCfg *gotls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
conn, err := internet.DialSystem(ctx, dest, streamSettings.SocketSettings)
if err != nil {
@ -208,26 +228,30 @@ func createHTTPClient(dest net.Destination, streamSettings *internet.MemoryStrea
return quic.DialEarly(ctx, udpConn, udpAddr, tlsCfg, cfg)
},
}
} else if httpVersion == "2" {
}
makeH2Transport := func() *http2.Transport {
if keepAlivePeriod == 0 {
keepAlivePeriod = chromeH2KeepAlivePeriod
}
if keepAlivePeriod < 0 {
keepAlivePeriod = 0
}
transport = &http2.Transport{
return &http2.Transport{
DialTLSContext: func(ctxInner context.Context, network string, addr string, cfg *gotls.Config) (net.Conn, error) {
return dialContext(ctxInner)
},
IdleConnTimeout: connIdleTimeout,
ReadIdleTimeout: keepAlivePeriod,
}
} else {
}
makeTransport := func() *http.Transport {
httpDialContext := func(ctxInner context.Context, network string, addr string) (net.Conn, error) {
return dialContext(ctxInner)
}
transport = &http.Transport{
return &http.Transport{
DialTLSContext: httpDialContext,
DialContext: httpDialContext,
IdleConnTimeout: connIdleTimeout,
@ -237,6 +261,22 @@ func createHTTPClient(dest net.Destination, streamSettings *internet.MemoryStrea
}
}
switch httpVersion {
case "3":
transport = makeH3Transport()
case "2":
transport = makeH2Transport()
case "3+2":
raceTransport := &raceTransport{
h3: makeH3Transport(),
h2: makeH2Transport(),
dest: dest,
}
transport = raceTransport.setup()
default:
transport = makeTransport()
}
client := &DefaultDialerClient{
transportConfig: transportConfig,
client: &http.Client{

View File

@ -0,0 +1,511 @@
package splithttp
import (
"context"
gotls "crypto/tls"
goerrors "errors"
gonet "net"
"net/http"
"strings"
"sync"
"sync/atomic"
"time"
"unsafe"
"github.com/xtls/quic-go"
"github.com/xtls/quic-go/http3"
"github.com/xtls/xray-core/common/errors"
"github.com/xtls/xray-core/common/net"
"golang.org/x/net/http2"
"golang.org/x/net/idna"
)
const (
// QuicSessionPool::GetTimeDelayForWaitingJob > kDefaultRTT
chromeH2TryDelay = 300 * time.Millisecond
// kDefaultBrokenAlternativeProtocolDelay
chromeH3BrokenDelay = 5 * time.Minute
)
type raceKeyType struct{}
var raceKey raceKeyType
type noDialKeyType struct{}
var noDialKey noDialKeyType
var (
loseRaceError = goerrors.New("lose race")
brokenSpanError = goerrors.New("protocol temporarily broken")
)
type raceState int
const (
raceInitialized raceState = 0
raceEstablished raceState = 1
raceErrored raceState = -1
)
type raceResult int
const (
raceInflight raceResult = 0
raceH3 raceResult = 1
raceH2 raceResult = 2
raceFailed raceResult = -1
raceInactive raceResult = -2
)
type raceNotify struct {
c chan struct{}
result raceResult
// left is the remove counter. This entry should be removed from race map when it reached 0
left atomic.Int32
}
func (r *raceNotify) wait() raceResult {
<-r.c
return r.result
}
type raceInfo struct {
key uintptr
host string
notify *raceNotify
}
type raceTransport struct {
h3 *http3.Transport
h2 *http2.Transport
dest net.Destination
muFlag sync.Mutex
// flag must be read or write with muFlag holding.
flag map[uintptr]raceState
muRace sync.RWMutex
// race must be read or write with muRace holding.
// Entries are:
// added by RoundTrip before racing start
// removed by the last finished racer
race map[string]*raceNotify
h3broken atomic.Int64
}
func (t *raceTransport) isH3Broken() bool {
return time.Now().UnixNano()-t.h3broken.Load() < int64(chromeH3BrokenDelay)
}
func (t *raceTransport) updateH3Broken(broken time.Time) {
value := broken.UnixNano()
old := t.h3broken.Load()
for old < value {
if t.h3broken.CompareAndSwap(old, value) {
break
}
old = t.h3broken.Load()
}
}
func (t *raceTransport) setup() *raceTransport {
t.flag = make(map[uintptr]raceState)
t.race = make(map[string]*raceNotify)
h3dial := t.h3.Dial
h2dial := t.h2.DialTLSContext
t.h3.Dial = func(ctx context.Context, addr string, tlsCfg *gotls.Config, cfg *quic.Config) (conn quic.EarlyConnection, err error) {
if ctx.Value(noDialKey) != nil {
return nil, http3.ErrNoCachedConn
}
info := ctx.Value(raceKey).(raceInfo)
key := info.key
defer func() {
// We can safely remove the raceNotify here, since both h2 and h3 Transport
// hold mutex while dialing.
// So another request can't slip in after we removed raceNotify but before
// Transport put the returned conn into pool - they will always reuse the conn we returned.
if err == nil {
info.notify.result = raceH3
close(info.notify.c)
}
if info.notify.left.Add(-1) == 0 {
errors.LogDebug(ctx, "Race Dialer: h3 cleaning race wait")
t.muRace.Lock()
if t.race[info.host] == info.notify {
delete(t.race, info.host)
}
t.muRace.Unlock()
}
}()
if t.isH3Broken() {
return nil, brokenSpanError
}
t.muFlag.Lock()
established := t.flag[key]
t.muFlag.Unlock()
if established == raceEstablished {
errors.LogDebug(ctx, "Race Dialer: h3 lose (h2 established)")
return nil, loseRaceError
}
conn, err = h3dial(ctx, addr, tlsCfg, cfg)
t.muFlag.Lock()
established = t.flag[key]
if err != nil {
// We fail.
// Record if we are the first, cleanup if we are the last.
if established == raceInitialized {
t.flag[key] = raceErrored
errors.LogDebug(ctx, "Race Dialer: h3 lose (h3 error)")
} else {
delete(t.flag, key)
errors.LogDebug(ctx, "Race Dialer: h3 draw (both error)")
}
t.muFlag.Unlock()
return nil, err
}
switch established {
case raceEstablished:
// h2 wins.
delete(t.flag, key)
t.muFlag.Unlock()
_ = conn.CloseWithError(quic.ApplicationErrorCode(quic.NoError), "lose race")
errors.LogDebug(ctx, "Race Dialer: h3 lose (h2 established)")
return nil, loseRaceError
case raceErrored:
// h2 errored first. We will always be used.
delete(t.flag, key)
t.muFlag.Unlock()
errors.LogDebug(ctx, "Race Dialer: h3 win (h2 error)")
return conn, nil
}
// Don't consider we win until handshake completed.
t.muFlag.Unlock()
<-conn.HandshakeComplete()
errors.LogDebug(ctx, "Race Dialer: h3 handshake complete")
t.muFlag.Lock()
established = t.flag[key]
switch established {
case raceEstablished:
// h2 wins.
delete(t.flag, key)
t.muFlag.Unlock()
_ = conn.CloseWithError(quic.ApplicationErrorCode(quic.NoError), "lose race")
errors.LogDebug(ctx, "Race Dialer: h3 lose (h2 established)")
return nil, loseRaceError
case raceErrored:
// h2 errored first. We clean up.
delete(t.flag, key)
errors.LogDebug(ctx, "Race Dialer: h3 win (h2 error)")
case raceInitialized:
// Check if handshake failed.
if err = conn.Context().Err(); err != nil {
conn = nil
t.flag[key] = raceErrored
errors.LogDebug(ctx, "Race Dialer: h3 lose (h3 error first)")
} else {
t.flag[key] = raceEstablished
errors.LogDebug(ctx, "Race Dialer: h3 win (h3 first)")
}
}
t.muFlag.Unlock()
return conn, err
}
t.h2.DialTLSContext = func(ctx context.Context, network, addr string, cfg *gotls.Config) (conn net.Conn, err error) {
if ctx.Value(noDialKey) != nil {
return nil, http2.ErrNoCachedConn
}
info := ctx.Value(raceKey).(raceInfo)
key := info.key
defer func() {
if err == nil {
info.notify.result = raceH2
close(info.notify.c)
}
if info.notify.left.Add(-1) == 0 {
errors.LogDebug(ctx, "Race Dialer: h2 cleaning race wait")
t.muRace.Lock()
if t.race[info.host] == info.notify {
delete(t.race, info.host)
}
t.muRace.Unlock()
}
}()
if !t.isH3Broken() {
time.Sleep(chromeH2TryDelay)
}
t.muFlag.Lock()
established := t.flag[key]
t.muFlag.Unlock()
if established == raceEstablished {
errors.LogDebug(ctx, "Race Dialer: h2 lose (h3 established)")
return nil, loseRaceError
}
conn, err = h2dial(ctx, network, addr, cfg)
t.muFlag.Lock()
established = t.flag[key]
if err != nil {
// We fail.
// Record if we are the first, cleanup if we are the last.
if established == raceInitialized {
t.flag[key] = raceErrored
errors.LogDebug(ctx, "Race Dialer: h2 lose (h2 error first)")
} else {
delete(t.flag, key)
errors.LogDebug(ctx, "Race Dialer: h2 draw (both error)")
}
t.muFlag.Unlock()
return nil, err
}
switch established {
case raceEstablished:
// h3 wins.
delete(t.flag, key)
t.muFlag.Unlock()
_ = conn.Close()
errors.LogDebug(ctx, "Race Dialer: h2 lose (h3 established)")
return nil, loseRaceError
case raceErrored:
// h3 errored first. We clean up.
delete(t.flag, key)
errors.LogDebug(ctx, "Race Dialer: h2 win (h3 error)")
case raceInitialized:
// We win, record.
t.flag[key] = raceEstablished
errors.LogDebug(ctx, "Race Dialer: h2 win (h2 first)")
}
t.muFlag.Unlock()
return conn, nil
}
return t
}
func authorityAddr(scheme string, authority string) (addr string) {
host, port, err := net.SplitHostPort(authority)
if err != nil { // authority didn't have a port
host = authority
port = ""
}
if port == "" { // authority's port was empty
port = "443"
if scheme == "http" {
port = "80"
}
}
if a, err := idna.ToASCII(host); err == nil {
host = a
}
// IPv6 address literal, without a port:
if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") {
return host + ":" + port
}
return gonet.JoinHostPort(host, port)
}
func (t *raceTransport) RoundTrip(req *http.Request) (*http.Response, error) {
ctx := req.Context()
// If there is inflight racing for current host, let it finish first,
// so we can know and reuse winner's conn.
host := authorityAddr(req.URL.Scheme, req.URL.Host)
t.muRace.RLock()
notify, ok := t.race[host]
t.muRace.RUnlock()
WaitRace:
raceResult := raceInactive
if ok {
errors.LogDebug(ctx, "Race Dialer: found inflight race to ", t.dest.NetAddr(), ", waiting race winner")
raceResult = notify.wait()
errors.LogDebug(ctx, "Race Dialer: winner for ", t.dest.NetAddr(), " resolved, continue handling request")
}
reqNoDial := req.WithContext(context.WithValue(ctx, noDialKey, struct{}{}))
// First see if there's cached connection, for both h3 and h2.
// - raceInactive: no inflight race. Try both.
// - raceH3/raceH2: another request just decided race winner.
// Losing Transport may not yet fail, so avoid trying it.
// - raceFailed: both failed. There won't be cached conn, no need to try.
// - raceInflight: should not see this state.
if raceResult == raceH3 || raceResult == raceInactive {
if resp, err := t.h3.RoundTripOpt(reqNoDial, http3.RoundTripOpt{OnlyCachedConn: true}); err == nil {
errors.LogInfo(ctx, "Race Dialer: use h3 connection for ", t.dest.NetAddr(), " (reusing conn)")
return resp, nil
} else if !goerrors.Is(err, http3.ErrNoCachedConn) {
return nil, err
} else if raceResult == raceH3 {
return nil, errors.New("Race Dialer: h3 just succeeded, but no cached conn")
}
}
if raceResult == raceH2 || raceResult == raceInactive {
// http2.RoundTripOpt.OnlyCachedConn is not effective. However, our noDialKey will block dialing anyway.
if resp, err := t.h2.RoundTripOpt(reqNoDial, http2.RoundTripOpt{OnlyCachedConn: true}); err == nil {
errors.LogInfo(ctx, "Race Dialer: use h2 connection for ", t.dest.NetAddr(), " (reusing conn)")
return resp, nil
} else if !goerrors.Is(err, http2.ErrNoCachedConn) {
return nil, err
} else if raceResult == raceH2 {
return nil, errors.New("Race Dialer: h2 just succeeded, but no cached conn")
}
}
// Both don't have cached conn. Now race between h2 and h3.
t.muRace.Lock()
// Recheck
notify, ok = t.race[host]
if ok {
// Some other request started racing before us, we wait for them to finish.
t.muRace.Unlock()
goto WaitRace
}
// We are the goroutine to initialize racing.
notify = &raceNotify{c: make(chan struct{})}
notify.left.Store(2)
t.race[host] = notify
t.muRace.Unlock()
errors.LogDebug(ctx, "Race Dialer: start race to ", t.dest.NetAddr())
// Both RoundTripper can share req.Body, because only one can dial successfully,
// and proceed to read request body.
key := uintptr(unsafe.Pointer(req))
raceCtx := context.WithValue(ctx, raceKey, raceInfo{
key: key,
host: host,
notify: notify,
})
req = req.WithContext(raceCtx)
t.muFlag.Lock()
t.flag[key] = raceInitialized
t.muFlag.Unlock()
h2resp := make(chan any)
h3resp := make(chan any)
raceDone := make(chan struct{})
defer func() {
if notify.result == raceInflight {
notify.result = raceFailed
close(notify.c)
}
close(raceDone)
}()
go func() {
resp, err := t.h3.RoundTrip(req)
var result any
if err == nil {
result = resp
old := t.h3broken.Swap(0)
if old != 0 {
errors.LogDebug(ctx, "Race Dialer: h3 connection succeed, clear broken state")
}
} else {
result = err
t.updateH3Broken(time.Now())
errors.LogDebug(ctx, "Race Dialer: h3 connection failed, set broken state")
}
select {
case h3resp <- result:
case <-raceDone:
}
}()
go func() {
resp, err := t.h2.RoundTrip(req)
if notify.left.Add(-1) == 0 {
errors.LogDebug(ctx, "Race Dialer: h2 cleaning race wait")
t.muRace.Lock()
delete(t.race, host)
t.muRace.Unlock()
}
var result any
if err == nil {
result = resp
} else {
result = err
}
select {
case h2resp <- result:
case <-raceDone:
}
}()
reportState := func(isH3 bool) {
winner := "h2"
if isH3 {
winner = "h3"
}
errors.LogInfo(ctx, "Race Dialer: use ", winner, " connection for ", t.dest.NetAddr(), " (race winner)")
}
handleResult := func(respErr any, other chan any, isH3 bool) (*http.Response, error) {
switch value := respErr.(type) {
case *http.Response:
// we win
reportState(isH3)
return value, nil
case error:
switch otherValue := (<-other).(type) {
case *http.Response:
// other win
reportState(!isH3)
return otherValue, nil
case error:
switch {
// hide internal error
case goerrors.Is(value, loseRaceError) || goerrors.Is(value, brokenSpanError):
return nil, otherValue
case goerrors.Is(otherValue, loseRaceError) || goerrors.Is(otherValue, brokenSpanError):
return nil, value
// prefer h3 error
case isH3:
return nil, value
default:
return nil, otherValue
}
default:
panic("unreachable: unexpected response type")
}
default:
panic("unreachable: unexpected response type")
}
}
select {
case respErr := <-h3resp:
return handleResult(respErr, h2resp, true)
case respErr := <-h2resp:
return handleResult(respErr, h3resp, false)
}
}