Adapt to current single connection per client

This commit is contained in:
rPDmYQ 2025-01-24 09:56:47 +00:00 committed by GitHub
parent c2ac58d93e
commit 9006e3b35c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 247 additions and 233 deletions

View File

@ -270,7 +270,7 @@ func createHTTPClient(dest net.Destination, streamSettings *internet.MemoryStrea
raceTransport := &raceTransport{
h3: makeH3Transport(),
h2: makeH2Transport(),
dest: dest,
dest: dest.NetAddr(),
}
transport = raceTransport.setup()
default:

View File

@ -4,28 +4,37 @@ import (
"context"
gotls "crypto/tls"
goerrors "errors"
gonet "net"
"fmt"
"net/http"
"strings"
"sync"
"sync/atomic"
"time"
"unsafe"
"github.com/xtls/quic-go"
"github.com/xtls/quic-go/http3"
"github.com/xtls/xray-core/common/errors"
"github.com/xtls/xray-core/common/net"
"golang.org/x/net/http2"
"golang.org/x/net/idna"
)
const (
// net/quic/quic_session_pool.cc
// QuicSessionPool::GetTimeDelayForWaitingJob > kDefaultRTT
chromeH2TryDelay = 300 * time.Millisecond
chromeH2DefaultTryDelay = 300 * time.Millisecond
// QuicSessionPool::GetTimeDelayForWaitingJob > srtt
chromeH2TryDelayScale = 1.5
// net/http/broken_alternative_services.cc
// kDefaultBrokenAlternativeProtocolDelay
chromeH3BrokenDelay = 5 * time.Minute
chromeH3BrokenInitialDelay = 5 * time.Minute
// kMaxBrokenAlternativeProtocolDelay
chromeH3BrokenMaxDelay = 48 * time.Hour
// kBrokenDelayMaxShift
chromeH3BrokenMaxShift = 18
// net/third_party/quiche/src/quiche/quic/core/congestion_control/rtt_stats.cc
// kAlpha
chromeH3SmoothRTTAlpha = 0.125
)
type raceKeyType struct{}
@ -41,12 +50,14 @@ var (
brokenSpanError = goerrors.New("protocol temporarily broken")
)
type raceState int
// isRaceInternalError reports whether err is, or wraps, one of the sentinel
// errors the race dialer produces for its own bookkeeping (loseRaceError or
// brokenSpanError) rather than a genuine connection failure.
func isRaceInternalError(err error) bool {
	if goerrors.Is(err, loseRaceError) {
		return true
	}
	return goerrors.Is(err, brokenSpanError)
}
const (
raceInitialized raceState = 0
raceEstablished raceState = 1
raceErrored raceState = -1
raceInitialized = 0
raceEstablished = 1
raceErrored = -1
)
type raceResult int
@ -59,11 +70,118 @@ const (
raceInactive raceResult = -2
)
// endpointInfo is the HTTP/3 health record kept for a single endpoint in
// h3EndpointCatalog.
type endpointInfo struct {
// lastFail is the most recent failure time recorded by updateH3Broken. It is
// reset to the zero time when updateH3RTTSlow clears the failure state.
lastFail time.Time
// failCount is the number of failures recorded since the last reset; it is
// written only while the catalog write lock is held.
failCount int
// rtt is the smoothed h3 dial time, stored as the int64 value of a
// time.Duration. Zero means no sample has been recorded yet. It is updated
// atomically, so the fast path in updateH3RTT only needs the read lock.
rtt atomic.Int64
}
// h3EndpointCatalog maps an endpoint string to its health record. It is nil
// until it is first allocated under the write lock (updateH3Broken does this).
var h3EndpointCatalog map[string]*endpointInfo
// h3EndpointCatalogLock guards h3EndpointCatalog and the non-atomic fields of
// the records stored in it.
var h3EndpointCatalogLock sync.RWMutex
// isH3Broken reports whether endpoint is still inside its HTTP/3 back-off
// window. An endpoint without a catalog entry is never considered broken.
//
// The window is chromeH3BrokenInitialDelay doubled once per recorded failure
// (the shift is capped at chromeH3BrokenMaxShift) and never exceeds
// chromeH3BrokenMaxDelay; it is measured from the last recorded failure.
//
// NOTE(review): updateH3Broken records the first failure with failCount == 1,
// so the first window is already 2x chromeH3BrokenInitialDelay — confirm this
// is intended.
func isH3Broken(endpoint string) bool {
	h3EndpointCatalogLock.RLock()
	defer h3EndpointCatalogLock.RUnlock()

	entry, found := h3EndpointCatalog[endpoint]
	if !found {
		return false
	}

	shift := min(entry.failCount, chromeH3BrokenMaxShift)
	window := min(chromeH3BrokenInitialDelay<<shift, chromeH3BrokenMaxDelay)
	return time.Since(entry.lastFail) < window
}
// getH2Delay returns how long the h2 dialer should wait before dialing, giving
// h3 a head start in the race.
//
//   - endpoint unknown, or no RTT sample yet: chromeH2DefaultTryDelay
//   - endpoint has recorded h3 failures: 0 (h2 starts immediately)
//   - otherwise: the stored smoothed RTT scaled by chromeH2TryDelayScale
func getH2Delay(endpoint string) time.Duration {
	h3EndpointCatalogLock.RLock()
	defer h3EndpointCatalogLock.RUnlock()

	entry, known := h3EndpointCatalog[endpoint]
	if !known {
		return chromeH2DefaultTryDelay
	}
	if entry.failCount > 0 {
		return 0
	}
	if measured := entry.rtt.Load(); measured != 0 {
		return time.Duration(chromeH2TryDelayScale * float64(measured))
	}
	return chromeH2DefaultTryDelay
}
// updateH3Broken records an HTTP/3 failure for endpoint that happened at
// brokenAt and returns the endpoint's failure count after the update.
//
// The catalog is allocated on first use. For an existing entry the count is
// incremented, and lastFail only moves forward: an older brokenAt does not
// overwrite a newer recorded failure time.
func updateH3Broken(endpoint string, brokenAt time.Time) int {
	h3EndpointCatalogLock.Lock()
	defer h3EndpointCatalogLock.Unlock()

	if h3EndpointCatalog == nil {
		h3EndpointCatalog = make(map[string]*endpointInfo)
	}

	if entry, exists := h3EndpointCatalog[endpoint]; exists {
		entry.failCount++
		if entry.lastFail.Before(brokenAt) {
			entry.lastFail = brokenAt
		}
		return entry.failCount
	}

	// First failure seen for this endpoint.
	h3EndpointCatalog[endpoint] = &endpointInfo{
		failCount: 1,
		lastFail:  brokenAt,
	}
	return 1
}
// smoothedRtt folds a new RTT sample into the previous smoothed value using an
// exponentially weighted moving average with weight chromeH3SmoothRTTAlpha on
// the new sample. An oldRtt of 0 means there is no previous value, so the new
// sample is returned unchanged. The result is truncated to an int64.
func smoothedRtt(oldRtt, newRtt int64) int64 {
	if oldRtt == 0 {
		return newRtt
	}
	previousPart := (1 - chromeH3SmoothRTTAlpha) * float64(oldRtt)
	samplePart := chromeH3SmoothRTTAlpha * float64(newRtt)
	return int64(previousPart + samplePart)
}
// updateH3RTT feeds a new h3 dial-time sample for endpoint into the catalog.
//
// Fast path: when the endpoint already has an entry with no recorded failures,
// the smoothed RTT is updated with a compare-and-swap loop while holding only
// the read lock.
//
// Slow path: when the entry is missing or has failures recorded, the read lock
// is released first and updateH3RTTSlow redoes the update under the write lock.
func updateH3RTT(endpoint string, rtt time.Duration) {
	sample := int64(rtt)

	h3EndpointCatalogLock.RLock()
	entry, known := h3EndpointCatalog[endpoint]
	if known && entry.failCount <= 0 {
		defer h3EndpointCatalogLock.RUnlock()
		// Retry until no concurrent writer changed rtt between Load and swap.
		for swapped := false; !swapped; {
			previous := entry.rtt.Load()
			swapped = entry.rtt.CompareAndSwap(previous, smoothedRtt(previous, sample))
		}
		return
	}

	// The slow path takes the write lock, so the read lock must be dropped first.
	h3EndpointCatalogLock.RUnlock()
	updateH3RTTSlow(endpoint, rtt)
}
// updateH3RTTSlow records an h3 dial-time sample for endpoint while holding the
// catalog write lock. updateH3RTT falls back to it when, under the read lock,
// the endpoint had no entry or had failures recorded.
//
// The entry is looked up again here because the catalog may have changed
// between releasing the read lock and acquiring the write lock:
//
//   - no entry: a new one seeded with rtt is created and inserted into the
//     catalog (allocating the catalog if it is still nil);
//   - failures recorded: the failure state is cleared and rtt replaces the
//     stored value, discarding the old average;
//   - otherwise: rtt is folded into the existing smoothed value.
func updateH3RTTSlow(endpoint string, rtt time.Duration) {
	h3EndpointCatalogLock.Lock()
	defer h3EndpointCatalogLock.Unlock()

	if h3EndpointCatalog == nil {
		h3EndpointCatalog = make(map[string]*endpointInfo)
	}

	info, ok := h3EndpointCatalog[endpoint]
	switch {
	case !ok:
		info = &endpointInfo{}
		info.rtt.Store(int64(rtt))
		// Publish the new record; without this the first sample would be lost.
		h3EndpointCatalog[endpoint] = info
	case info.failCount > 0:
		info.failCount = 0
		info.lastFail = time.Time{}
		info.rtt.Store(int64(rtt))
	default:
		info.rtt.Store(smoothedRtt(info.rtt.Load(), int64(rtt)))
	}
}
type raceNotify struct {
c chan struct{}
result raceResult
// left is the remove counter. This entry should be removed from race map when it reached 0
// left is the remove counter. It should be released when it reached 0
left atomic.Int32
}
@ -72,50 +190,16 @@ func (r *raceNotify) wait() raceResult {
return r.result
}
// raceInfo is the per-race value RoundTrip stores in the request context under
// raceKey; the h2 and h3 dial hooks read it back from the context.
type raceInfo struct {
// key identifies the race; RoundTrip derives it from the request pointer.
key uintptr
// host is the authority address RoundTrip computes with authorityAddr from the
// request URL.
host string
// notify is the raceNotify created by RoundTrip for this race.
notify *raceNotify
}
type raceTransport struct {
h3 *http3.Transport
h2 *http2.Transport
dest net.Destination
dest string
muFlag sync.Mutex
// flag must be read or write with muFlag holding.
flag map[uintptr]raceState
muRace sync.RWMutex
// race must be read or write with muRace holding.
// Entries are:
// added by RoundTrip before racing start
// removed by the last finished racer
race map[string]*raceNotify
h3broken atomic.Int64
}
func (t *raceTransport) isH3Broken() bool {
return time.Now().UnixNano()-t.h3broken.Load() < int64(chromeH3BrokenDelay)
}
func (t *raceTransport) updateH3Broken(broken time.Time) {
value := broken.UnixNano()
old := t.h3broken.Load()
for old < value {
if t.h3broken.CompareAndSwap(old, value) {
break
}
old = t.h3broken.Load()
}
flag atomic.Int64
notify atomic.Pointer[raceNotify]
}
func (t *raceTransport) setup() *raceTransport {
t.flag = make(map[uintptr]raceState)
t.race = make(map[string]*raceNotify)
h3dial := t.h3.Dial
h2dial := t.h2.DialTLSContext
@ -124,106 +208,108 @@ func (t *raceTransport) setup() *raceTransport {
return nil, http3.ErrNoCachedConn
}
info := ctx.Value(raceKey).(raceInfo)
key := info.key
var dialStart time.Time
defer func() {
notify := t.notify.Load()
if err == nil {
updateH3RTT(t.dest, time.Since(dialStart))
notify.result = raceH3
close(notify.c)
} else if !isRaceInternalError(err) {
failed := updateH3Broken(t.dest, time.Now())
errors.LogDebug(ctx, "Race Dialer: h3 connection to ", t.dest, " failed ", failed, "times")
}
// We can safely remove the raceNotify here, since both h2 and h3 Transport
// hold mutex while dialing.
// So another request can't slip in after we removed raceNotify but before
// Transport put the returned conn into pool - they will always reuse the conn we returned.
if err == nil {
info.notify.result = raceH3
close(info.notify.c)
}
if info.notify.left.Add(-1) == 0 {
if notify.left.Add(-1) == 0 {
errors.LogDebug(ctx, "Race Dialer: h3 cleaning race wait")
t.muRace.Lock()
if t.race[info.host] == info.notify {
delete(t.race, info.host)
}
t.muRace.Unlock()
t.notify.Store(nil)
}
}()
if t.isH3Broken() {
if isH3Broken(t.dest) {
return nil, brokenSpanError
}
t.muFlag.Lock()
established := t.flag[key]
t.muFlag.Unlock()
established := t.flag.Load()
if established == raceEstablished {
errors.LogDebug(ctx, "Race Dialer: h3 lose (h2 established)")
errors.LogDebug(ctx, "Race Dialer: h3 lose (h2 established before try)")
return nil, loseRaceError
}
dialStart = time.Now()
conn, err = h3dial(ctx, addr, tlsCfg, cfg)
t.muFlag.Lock()
established = t.flag[key]
if err != nil {
// We fail.
// Record if we are the first, cleanup if we are the last.
if established == raceInitialized {
t.flag[key] = raceErrored
// Record if we are the first.
if t.flag.CompareAndSwap(raceInitialized, raceErrored) {
errors.LogDebug(ctx, "Race Dialer: h3 lose (h3 error)")
} else {
delete(t.flag, key)
errors.LogDebug(ctx, "Race Dialer: h3 draw (both error)")
}
t.muFlag.Unlock()
return nil, err
}
switch established {
flag := t.flag.Load()
switch flag {
case raceEstablished:
// h2 wins.
delete(t.flag, key)
t.muFlag.Unlock()
_ = conn.CloseWithError(quic.ApplicationErrorCode(quic.NoError), "lose race")
errors.LogDebug(ctx, "Race Dialer: h3 lose (h2 established)")
errors.LogDebug(ctx, "Race Dialer: h3 lose (h2 established before handshake complete)")
return nil, loseRaceError
case raceErrored:
// h2 errored first. We will always be used.
delete(t.flag, key)
t.muFlag.Unlock()
errors.LogDebug(ctx, "Race Dialer: h3 win (h2 error)")
return conn, nil
case raceInitialized:
// continue
default:
panic(fmt.Sprintf("unreachable: unknown race flag: %d", flag))
}
// Don't consider we win until handshake completed.
t.muFlag.Unlock()
<-conn.HandshakeComplete()
errors.LogDebug(ctx, "Race Dialer: h3 handshake complete")
t.muFlag.Lock()
established = t.flag[key]
switch established {
if err = conn.Context().Err(); err != nil {
if t.flag.CompareAndSwap(raceInitialized, raceErrored) {
errors.LogDebug(ctx, "Race Dialer: h3 lose (h3 error first)")
return nil, err
}
_ = conn.CloseWithError(quic.ApplicationErrorCode(quic.NoError), "lose race")
conn = nil
} else {
if t.flag.CompareAndSwap(raceInitialized, raceEstablished) {
errors.LogDebug(ctx, "Race Dialer: h3 win (h3 first)")
return conn, nil
}
}
flag = t.flag.Load()
switch flag {
case raceEstablished:
// h2 wins.
delete(t.flag, key)
t.muFlag.Unlock()
_ = conn.CloseWithError(quic.ApplicationErrorCode(quic.NoError), "lose race")
errors.LogDebug(ctx, "Race Dialer: h3 lose (h2 established)")
return nil, loseRaceError
case raceErrored:
// h2 errored first. We clean up.
delete(t.flag, key)
errors.LogDebug(ctx, "Race Dialer: h3 win (h2 error)")
case raceInitialized:
// Check if handshake failed.
if err = conn.Context().Err(); err != nil {
conn = nil
t.flag[key] = raceErrored
errors.LogDebug(ctx, "Race Dialer: h3 lose (h3 error first)")
// h2 errored first.
if err == nil {
errors.LogDebug(ctx, "Race Dialer: h3 win (h2 error)")
} else {
t.flag[key] = raceEstablished
errors.LogDebug(ctx, "Race Dialer: h3 win (h3 first)")
errors.LogDebug(ctx, "Race Dialer: h3 draw (both error)")
}
return conn, err
case raceInitialized:
panic("unreachable: race flag should not revert to raceInitialized")
default:
panic(fmt.Sprintf("unreachable: unknown race flag: %d", flag))
}
t.muFlag.Unlock()
return conn, err
}
t.h2.DialTLSContext = func(ctx context.Context, network, addr string, cfg *gotls.Config) (conn net.Conn, err error) {
@ -231,116 +317,84 @@ func (t *raceTransport) setup() *raceTransport {
return nil, http2.ErrNoCachedConn
}
info := ctx.Value(raceKey).(raceInfo)
key := info.key
defer func() {
notify := t.notify.Load()
if err == nil {
info.notify.result = raceH2
close(info.notify.c)
notify.result = raceH2
close(notify.c)
}
if info.notify.left.Add(-1) == 0 {
if notify.left.Add(-1) == 0 {
errors.LogDebug(ctx, "Race Dialer: h2 cleaning race wait")
t.muRace.Lock()
if t.race[info.host] == info.notify {
delete(t.race, info.host)
}
t.muRace.Unlock()
t.notify.Store(nil)
}
}()
if !t.isH3Broken() {
time.Sleep(chromeH2TryDelay)
}
delay := getH2Delay(t.dest)
errors.LogDebug(ctx, "Race Dialer: h2 dial delay: ", delay)
time.Sleep(delay)
t.muFlag.Lock()
established := t.flag[key]
t.muFlag.Unlock()
established := t.flag.Load()
if established == raceEstablished {
errors.LogDebug(ctx, "Race Dialer: h2 lose (h3 established)")
errors.LogDebug(ctx, "Race Dialer: h2 lose (h3 established before try)")
return nil, loseRaceError
}
conn, err = h2dial(ctx, network, addr, cfg)
t.muFlag.Lock()
established = t.flag[key]
if err != nil {
// We fail.
// Record if we are the first, cleanup if we are the last.
if established == raceInitialized {
t.flag[key] = raceErrored
// Record if we are the first.
if t.flag.CompareAndSwap(raceInitialized, raceErrored) {
errors.LogDebug(ctx, "Race Dialer: h2 lose (h2 error first)")
} else {
delete(t.flag, key)
errors.LogDebug(ctx, "Race Dialer: h2 draw (both error)")
return nil, err
}
_ = conn.Close()
conn = nil
} else {
if t.flag.CompareAndSwap(raceInitialized, raceEstablished) {
errors.LogDebug(ctx, "Race Dialer: h2 win (h2 first)")
return conn, nil
}
t.muFlag.Unlock()
return nil, err
}
switch established {
flag := t.flag.Load()
switch flag {
case raceEstablished:
// h3 wins.
delete(t.flag, key)
t.muFlag.Unlock()
_ = conn.Close()
errors.LogDebug(ctx, "Race Dialer: h2 lose (h3 established)")
return nil, loseRaceError
case raceErrored:
// h3 errored first. We clean up.
delete(t.flag, key)
errors.LogDebug(ctx, "Race Dialer: h2 win (h3 error)")
// h3 errored first.
if err == nil {
errors.LogDebug(ctx, "Race Dialer: h2 win (h3 error)")
} else {
errors.LogDebug(ctx, "Race Dialer: h2 draw (both error)")
}
return conn, err
case raceInitialized:
// We win, record.
t.flag[key] = raceEstablished
errors.LogDebug(ctx, "Race Dialer: h2 win (h2 first)")
panic("unreachable: race flag should not revert to raceInitialized")
default:
panic(fmt.Sprintf("unreachable: unknown race flag: %d", flag))
}
t.muFlag.Unlock()
return conn, nil
}
return t
}
// authorityAddr turns a URL scheme and authority into a dialable "host:port"
// address.
//
// When authority carries no port (or an empty one) the port defaults to "80"
// for the "http" scheme and "443" for everything else. The host is converted
// with idna.ToASCII when that conversion succeeds. A host that is still wrapped
// in square brackets (an IPv6 literal given without a port) is joined with the
// port directly so that it is not bracketed a second time.
func authorityAddr(scheme string, authority string) (addr string) {
	host, port, splitErr := net.SplitHostPort(authority)
	if splitErr != nil {
		// authority had no port component at all.
		host, port = authority, ""
	}

	if port == "" {
		switch scheme {
		case "http":
			port = "80"
		default:
			port = "443"
		}
	}

	if ascii, convErr := idna.ToASCII(host); convErr == nil {
		host = ascii
	}

	bracketed := strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]")
	if !bracketed {
		return gonet.JoinHostPort(host, port)
	}
	// IPv6 address literal, without a port: already bracketed.
	return host + ":" + port
}
func (t *raceTransport) RoundTrip(req *http.Request) (*http.Response, error) {
ctx := req.Context()
// If there is inflight racing for current host, let it finish first,
// If there is inflight racing, let it finish first,
// so we can know and reuse winner's conn.
host := authorityAddr(req.URL.Scheme, req.URL.Host)
t.muRace.RLock()
notify, ok := t.race[host]
t.muRace.RUnlock()
notify := t.notify.Load()
raceResult := raceInactive
WaitRace:
raceResult := raceInactive
if ok {
errors.LogDebug(ctx, "Race Dialer: found inflight race to ", t.dest.NetAddr(), ", waiting race winner")
if notify != nil {
errors.LogDebug(ctx, "Race Dialer: found inflight race to ", t.dest, ", waiting race winner")
raceResult = notify.wait()
errors.LogDebug(ctx, "Race Dialer: winner for ", t.dest.NetAddr(), " resolved, continue handling request")
errors.LogDebug(ctx, "Race Dialer: winner for ", t.dest, " resolved, continue handling request")
}
reqNoDial := req.WithContext(context.WithValue(ctx, noDialKey, struct{}{}))
@ -353,7 +407,7 @@ WaitRace:
// - raceInflight: should not see this state.
if raceResult == raceH3 || raceResult == raceInactive {
if resp, err := t.h3.RoundTripOpt(reqNoDial, http3.RoundTripOpt{OnlyCachedConn: true}); err == nil {
errors.LogInfo(ctx, "Race Dialer: use h3 connection for ", t.dest.NetAddr(), " (reusing conn)")
errors.LogInfo(ctx, "Race Dialer: use h3 connection for ", t.dest, " (reusing conn)")
return resp, nil
} else if !goerrors.Is(err, http3.ErrNoCachedConn) {
return nil, err
@ -365,7 +419,7 @@ WaitRace:
if raceResult == raceH2 || raceResult == raceInactive {
// http2.RoundTripOpt.OnlyCachedConn is not effective. However, our noDialKey will block dialing anyway.
if resp, err := t.h2.RoundTripOpt(reqNoDial, http2.RoundTripOpt{OnlyCachedConn: true}); err == nil {
errors.LogInfo(ctx, "Race Dialer: use h2 connection for ", t.dest.NetAddr(), " (reusing conn)")
errors.LogInfo(ctx, "Race Dialer: use h2 connection for ", t.dest, " (reusing conn)")
return resp, nil
} else if !goerrors.Is(err, http2.ErrNoCachedConn) {
return nil, err
@ -375,35 +429,18 @@ WaitRace:
}
// Both don't have cached conn. Now race between h2 and h3.
t.muRace.Lock()
// Recheck
notify, ok = t.race[host]
if ok {
// Some other request started racing before us, we wait for them to finish.
t.muRace.Unlock()
goto WaitRace
}
// We are the goroutine to initialize racing.
// Recheck first.
notify = &raceNotify{c: make(chan struct{})}
notify.left.Store(2)
t.race[host] = notify
t.muRace.Unlock()
if !t.notify.CompareAndSwap(nil, notify) {
// Some other request started racing before us, we wait for them to finish.
goto WaitRace
}
errors.LogDebug(ctx, "Race Dialer: start race to ", t.dest.NetAddr())
// We are the goroutine to initialize racing.
errors.LogDebug(ctx, "Race Dialer: start race to ", t.dest)
// Both RoundTripper can share req.Body, because only one can dial successfully,
// and proceed to read request body.
key := uintptr(unsafe.Pointer(req))
raceCtx := context.WithValue(ctx, raceKey, raceInfo{
key: key,
host: host,
notify: notify,
})
req = req.WithContext(raceCtx)
t.muFlag.Lock()
t.flag[key] = raceInitialized
t.muFlag.Unlock()
t.flag.Store(raceInitialized)
h2resp := make(chan any)
h3resp := make(chan any)
@ -417,36 +454,10 @@ WaitRace:
close(raceDone)
}()
go func() {
resp, err := t.h3.RoundTrip(req)
var result any
if err == nil {
result = resp
old := t.h3broken.Swap(0)
if old != 0 {
errors.LogDebug(ctx, "Race Dialer: h3 connection succeed, clear broken state")
}
} else {
result = err
t.updateH3Broken(time.Now())
errors.LogDebug(ctx, "Race Dialer: h3 connection failed, set broken state")
}
select {
case h3resp <- result:
case <-raceDone:
}
}()
go func() {
resp, err := t.h2.RoundTrip(req)
if notify.left.Add(-1) == 0 {
errors.LogDebug(ctx, "Race Dialer: h2 cleaning race wait")
t.muRace.Lock()
delete(t.race, host)
t.muRace.Unlock()
}
// Both RoundTripper can share req.Body, because only one can dial successfully,
// and proceed to read request body.
roundTrip := func(r http.RoundTripper, respChan chan any) {
resp, err := r.RoundTrip(req)
var result any
if err == nil {
@ -456,17 +467,20 @@ WaitRace:
}
select {
case h2resp <- result:
case respChan <- result:
case <-raceDone:
}
}()
}
go roundTrip(t.h3, h3resp)
go roundTrip(t.h2, h2resp)
reportState := func(isH3 bool) {
winner := "h2"
if isH3 {
winner = "h3"
}
errors.LogInfo(ctx, "Race Dialer: use ", winner, " connection for ", t.dest.NetAddr(), " (race winner)")
errors.LogInfo(ctx, "Race Dialer: use ", winner, " connection for ", t.dest, " (race winner)")
}
handleResult := func(respErr any, other chan any, isH3 bool) (*http.Response, error) {
@ -484,9 +498,9 @@ WaitRace:
case error:
switch {
// hide internal error
case goerrors.Is(value, loseRaceError) || goerrors.Is(value, brokenSpanError):
case isRaceInternalError(value):
return nil, otherValue
case goerrors.Is(otherValue, loseRaceError) || goerrors.Is(otherValue, brokenSpanError):
case isRaceInternalError(otherValue):
return nil, value
// prefer h3 error
case isH3:
@ -495,10 +509,10 @@ WaitRace:
return nil, otherValue
}
default:
panic("unreachable: unexpected response type")
panic(fmt.Sprintf("unreachable: unexpected response type %T", otherValue))
}
default:
panic("unreachable: unexpected response type")
panic(fmt.Sprintf("unreachable: unexpected response type %T", value))
}
}