mirror of
https://github.com/XTLS/Xray-core.git
synced 2025-01-28 20:44:13 +00:00
Adapt to current single connection per client
This commit is contained in:
parent
c2ac58d93e
commit
9006e3b35c
@ -270,7 +270,7 @@ func createHTTPClient(dest net.Destination, streamSettings *internet.MemoryStrea
|
|||||||
raceTransport := &raceTransport{
|
raceTransport := &raceTransport{
|
||||||
h3: makeH3Transport(),
|
h3: makeH3Transport(),
|
||||||
h2: makeH2Transport(),
|
h2: makeH2Transport(),
|
||||||
dest: dest,
|
dest: dest.NetAddr(),
|
||||||
}
|
}
|
||||||
transport = raceTransport.setup()
|
transport = raceTransport.setup()
|
||||||
default:
|
default:
|
||||||
|
@ -4,28 +4,37 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
gotls "crypto/tls"
|
gotls "crypto/tls"
|
||||||
goerrors "errors"
|
goerrors "errors"
|
||||||
gonet "net"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
"unsafe"
|
|
||||||
|
|
||||||
"github.com/xtls/quic-go"
|
"github.com/xtls/quic-go"
|
||||||
"github.com/xtls/quic-go/http3"
|
"github.com/xtls/quic-go/http3"
|
||||||
"github.com/xtls/xray-core/common/errors"
|
"github.com/xtls/xray-core/common/errors"
|
||||||
"github.com/xtls/xray-core/common/net"
|
"github.com/xtls/xray-core/common/net"
|
||||||
"golang.org/x/net/http2"
|
"golang.org/x/net/http2"
|
||||||
"golang.org/x/net/idna"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
// net/quic/quic_session_pool.cc
|
||||||
// QuicSessionPool::GetTimeDelayForWaitingJob > kDefaultRTT
|
// QuicSessionPool::GetTimeDelayForWaitingJob > kDefaultRTT
|
||||||
chromeH2TryDelay = 300 * time.Millisecond
|
chromeH2DefaultTryDelay = 300 * time.Millisecond
|
||||||
|
// QuicSessionPool::GetTimeDelayForWaitingJob > srtt
|
||||||
|
chromeH2TryDelayScale = 1.5
|
||||||
|
|
||||||
|
// net/http/broken_alternative_services.cc
|
||||||
// kDefaultBrokenAlternativeProtocolDelay
|
// kDefaultBrokenAlternativeProtocolDelay
|
||||||
chromeH3BrokenDelay = 5 * time.Minute
|
chromeH3BrokenInitialDelay = 5 * time.Minute
|
||||||
|
// kMaxBrokenAlternativeProtocolDelay
|
||||||
|
chromeH3BrokenMaxDelay = 48 * time.Hour
|
||||||
|
// kBrokenDelayMaxShift
|
||||||
|
chromeH3BrokenMaxShift = 18
|
||||||
|
|
||||||
|
// net/third_party/quiche/src/quiche/quic/core/congestion_control/rtt_stats.cc
|
||||||
|
// kAlpha
|
||||||
|
chromeH3SmoothRTTAlpha = 0.125
|
||||||
)
|
)
|
||||||
|
|
||||||
type raceKeyType struct{}
|
type raceKeyType struct{}
|
||||||
@ -41,12 +50,14 @@ var (
|
|||||||
brokenSpanError = goerrors.New("protocol temporarily broken")
|
brokenSpanError = goerrors.New("protocol temporarily broken")
|
||||||
)
|
)
|
||||||
|
|
||||||
type raceState int
|
func isRaceInternalError(err error) bool {
|
||||||
|
return goerrors.Is(err, loseRaceError) || goerrors.Is(err, brokenSpanError)
|
||||||
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
raceInitialized raceState = 0
|
raceInitialized = 0
|
||||||
raceEstablished raceState = 1
|
raceEstablished = 1
|
||||||
raceErrored raceState = -1
|
raceErrored = -1
|
||||||
)
|
)
|
||||||
|
|
||||||
type raceResult int
|
type raceResult int
|
||||||
@ -59,11 +70,118 @@ const (
|
|||||||
raceInactive raceResult = -2
|
raceInactive raceResult = -2
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type endpointInfo struct {
|
||||||
|
lastFail time.Time
|
||||||
|
failCount int
|
||||||
|
rtt atomic.Int64
|
||||||
|
}
|
||||||
|
|
||||||
|
var h3EndpointCatalog map[string]*endpointInfo
|
||||||
|
var h3EndpointCatalogLock sync.RWMutex
|
||||||
|
|
||||||
|
func isH3Broken(endpoint string) bool {
|
||||||
|
h3EndpointCatalogLock.RLock()
|
||||||
|
defer h3EndpointCatalogLock.RUnlock()
|
||||||
|
info, ok := h3EndpointCatalog[endpoint]
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
brokenDuration := min(chromeH3BrokenInitialDelay<<min(info.failCount, chromeH3BrokenMaxShift), chromeH3BrokenMaxDelay)
|
||||||
|
return time.Since(info.lastFail) < brokenDuration
|
||||||
|
}
|
||||||
|
|
||||||
|
func getH2Delay(endpoint string) time.Duration {
|
||||||
|
h3EndpointCatalogLock.RLock()
|
||||||
|
defer h3EndpointCatalogLock.RUnlock()
|
||||||
|
info, ok := h3EndpointCatalog[endpoint]
|
||||||
|
if !ok {
|
||||||
|
return chromeH2DefaultTryDelay
|
||||||
|
}
|
||||||
|
if info.failCount > 0 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
rtt := info.rtt.Load()
|
||||||
|
if rtt == 0 {
|
||||||
|
return chromeH2DefaultTryDelay
|
||||||
|
}
|
||||||
|
return time.Duration(chromeH2TryDelayScale * float64(rtt))
|
||||||
|
}
|
||||||
|
|
||||||
|
func updateH3Broken(endpoint string, brokenAt time.Time) int {
|
||||||
|
h3EndpointCatalogLock.Lock()
|
||||||
|
defer h3EndpointCatalogLock.Unlock()
|
||||||
|
if h3EndpointCatalog == nil {
|
||||||
|
h3EndpointCatalog = make(map[string]*endpointInfo)
|
||||||
|
}
|
||||||
|
|
||||||
|
info, ok := h3EndpointCatalog[endpoint]
|
||||||
|
if !ok {
|
||||||
|
h3EndpointCatalog[endpoint] = &endpointInfo{
|
||||||
|
lastFail: brokenAt,
|
||||||
|
failCount: 1,
|
||||||
|
}
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
|
||||||
|
info.failCount++
|
||||||
|
if brokenAt.After(info.lastFail) {
|
||||||
|
info.lastFail = brokenAt
|
||||||
|
}
|
||||||
|
|
||||||
|
return info.failCount
|
||||||
|
}
|
||||||
|
|
||||||
|
func smoothedRtt(oldRtt, newRtt int64) int64 {
|
||||||
|
if oldRtt == 0 {
|
||||||
|
return newRtt
|
||||||
|
}
|
||||||
|
|
||||||
|
return int64((1-chromeH3SmoothRTTAlpha)*float64(oldRtt) + chromeH3SmoothRTTAlpha*float64(newRtt))
|
||||||
|
}
|
||||||
|
|
||||||
|
func updateH3RTT(endpoint string, rtt time.Duration) {
|
||||||
|
h3EndpointCatalogLock.RLock()
|
||||||
|
info, ok := h3EndpointCatalog[endpoint]
|
||||||
|
if !ok || info.failCount > 0 {
|
||||||
|
h3EndpointCatalogLock.RUnlock()
|
||||||
|
updateH3RTTSlow(endpoint, rtt)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
defer h3EndpointCatalogLock.RUnlock()
|
||||||
|
for {
|
||||||
|
oldRtt := info.rtt.Load()
|
||||||
|
newRtt := smoothedRtt(oldRtt, int64(rtt))
|
||||||
|
if info.rtt.CompareAndSwap(oldRtt, newRtt) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func updateH3RTTSlow(endpoint string, rtt time.Duration) {
|
||||||
|
h3EndpointCatalogLock.Lock()
|
||||||
|
defer h3EndpointCatalogLock.Unlock()
|
||||||
|
|
||||||
|
info, ok := h3EndpointCatalog[endpoint]
|
||||||
|
switch {
|
||||||
|
case !ok:
|
||||||
|
info = &endpointInfo{}
|
||||||
|
info.rtt.Store(int64(rtt))
|
||||||
|
case info.failCount > 0:
|
||||||
|
info.failCount = 0
|
||||||
|
info.lastFail = time.Time{}
|
||||||
|
info.rtt.Store(int64(rtt))
|
||||||
|
default:
|
||||||
|
info.rtt.Store(smoothedRtt(info.rtt.Load(), int64(rtt)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
type raceNotify struct {
|
type raceNotify struct {
|
||||||
c chan struct{}
|
c chan struct{}
|
||||||
result raceResult
|
result raceResult
|
||||||
|
|
||||||
// left is the remove counter. This entry should be removed from race map when it reached 0
|
// left is the remove counter. It should be released when it reached 0
|
||||||
left atomic.Int32
|
left atomic.Int32
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -72,50 +190,16 @@ func (r *raceNotify) wait() raceResult {
|
|||||||
return r.result
|
return r.result
|
||||||
}
|
}
|
||||||
|
|
||||||
type raceInfo struct {
|
|
||||||
key uintptr
|
|
||||||
host string
|
|
||||||
notify *raceNotify
|
|
||||||
}
|
|
||||||
|
|
||||||
type raceTransport struct {
|
type raceTransport struct {
|
||||||
h3 *http3.Transport
|
h3 *http3.Transport
|
||||||
h2 *http2.Transport
|
h2 *http2.Transport
|
||||||
dest net.Destination
|
dest string
|
||||||
|
|
||||||
muFlag sync.Mutex
|
flag atomic.Int64
|
||||||
// flag must be read or write with muFlag holding.
|
notify atomic.Pointer[raceNotify]
|
||||||
flag map[uintptr]raceState
|
|
||||||
|
|
||||||
muRace sync.RWMutex
|
|
||||||
// race must be read or write with muRace holding.
|
|
||||||
// Entries are:
|
|
||||||
// added by RoundTrip before racing start
|
|
||||||
// removed by the last finished racer
|
|
||||||
race map[string]*raceNotify
|
|
||||||
|
|
||||||
h3broken atomic.Int64
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *raceTransport) isH3Broken() bool {
|
|
||||||
return time.Now().UnixNano()-t.h3broken.Load() < int64(chromeH3BrokenDelay)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *raceTransport) updateH3Broken(broken time.Time) {
|
|
||||||
value := broken.UnixNano()
|
|
||||||
old := t.h3broken.Load()
|
|
||||||
for old < value {
|
|
||||||
if t.h3broken.CompareAndSwap(old, value) {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
old = t.h3broken.Load()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *raceTransport) setup() *raceTransport {
|
func (t *raceTransport) setup() *raceTransport {
|
||||||
t.flag = make(map[uintptr]raceState)
|
|
||||||
t.race = make(map[string]*raceNotify)
|
|
||||||
|
|
||||||
h3dial := t.h3.Dial
|
h3dial := t.h3.Dial
|
||||||
h2dial := t.h2.DialTLSContext
|
h2dial := t.h2.DialTLSContext
|
||||||
|
|
||||||
@ -124,106 +208,108 @@ func (t *raceTransport) setup() *raceTransport {
|
|||||||
return nil, http3.ErrNoCachedConn
|
return nil, http3.ErrNoCachedConn
|
||||||
}
|
}
|
||||||
|
|
||||||
info := ctx.Value(raceKey).(raceInfo)
|
var dialStart time.Time
|
||||||
key := info.key
|
|
||||||
defer func() {
|
defer func() {
|
||||||
|
notify := t.notify.Load()
|
||||||
|
if err == nil {
|
||||||
|
updateH3RTT(t.dest, time.Since(dialStart))
|
||||||
|
notify.result = raceH3
|
||||||
|
close(notify.c)
|
||||||
|
} else if !isRaceInternalError(err) {
|
||||||
|
failed := updateH3Broken(t.dest, time.Now())
|
||||||
|
errors.LogDebug(ctx, "Race Dialer: h3 connection to ", t.dest, " failed ", failed, "times")
|
||||||
|
}
|
||||||
|
|
||||||
// We can safely remove the raceNotify here, since both h2 and h3 Transport
|
// We can safely remove the raceNotify here, since both h2 and h3 Transport
|
||||||
// hold mutex while dialing.
|
// hold mutex while dialing.
|
||||||
// So another request can't slip in after we removed raceNotify but before
|
// So another request can't slip in after we removed raceNotify but before
|
||||||
// Transport put the returned conn into pool - they will always reuse the conn we returned.
|
// Transport put the returned conn into pool - they will always reuse the conn we returned.
|
||||||
if err == nil {
|
if notify.left.Add(-1) == 0 {
|
||||||
info.notify.result = raceH3
|
|
||||||
close(info.notify.c)
|
|
||||||
}
|
|
||||||
if info.notify.left.Add(-1) == 0 {
|
|
||||||
errors.LogDebug(ctx, "Race Dialer: h3 cleaning race wait")
|
errors.LogDebug(ctx, "Race Dialer: h3 cleaning race wait")
|
||||||
t.muRace.Lock()
|
t.notify.Store(nil)
|
||||||
if t.race[info.host] == info.notify {
|
|
||||||
delete(t.race, info.host)
|
|
||||||
}
|
|
||||||
t.muRace.Unlock()
|
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
if t.isH3Broken() {
|
if isH3Broken(t.dest) {
|
||||||
return nil, brokenSpanError
|
return nil, brokenSpanError
|
||||||
}
|
}
|
||||||
|
|
||||||
t.muFlag.Lock()
|
established := t.flag.Load()
|
||||||
established := t.flag[key]
|
|
||||||
t.muFlag.Unlock()
|
|
||||||
if established == raceEstablished {
|
if established == raceEstablished {
|
||||||
errors.LogDebug(ctx, "Race Dialer: h3 lose (h2 established)")
|
errors.LogDebug(ctx, "Race Dialer: h3 lose (h2 established before try)")
|
||||||
return nil, loseRaceError
|
return nil, loseRaceError
|
||||||
}
|
}
|
||||||
|
|
||||||
|
dialStart = time.Now()
|
||||||
conn, err = h3dial(ctx, addr, tlsCfg, cfg)
|
conn, err = h3dial(ctx, addr, tlsCfg, cfg)
|
||||||
|
|
||||||
t.muFlag.Lock()
|
|
||||||
established = t.flag[key]
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// We fail.
|
// We fail.
|
||||||
// Record if we are the first, cleanup if we are the last.
|
// Record if we are the first.
|
||||||
if established == raceInitialized {
|
if t.flag.CompareAndSwap(raceInitialized, raceErrored) {
|
||||||
t.flag[key] = raceErrored
|
|
||||||
errors.LogDebug(ctx, "Race Dialer: h3 lose (h3 error)")
|
errors.LogDebug(ctx, "Race Dialer: h3 lose (h3 error)")
|
||||||
} else {
|
} else {
|
||||||
delete(t.flag, key)
|
|
||||||
errors.LogDebug(ctx, "Race Dialer: h3 draw (both error)")
|
errors.LogDebug(ctx, "Race Dialer: h3 draw (both error)")
|
||||||
}
|
}
|
||||||
t.muFlag.Unlock()
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
switch established {
|
flag := t.flag.Load()
|
||||||
|
switch flag {
|
||||||
case raceEstablished:
|
case raceEstablished:
|
||||||
// h2 wins.
|
// h2 wins.
|
||||||
delete(t.flag, key)
|
|
||||||
t.muFlag.Unlock()
|
|
||||||
_ = conn.CloseWithError(quic.ApplicationErrorCode(quic.NoError), "lose race")
|
_ = conn.CloseWithError(quic.ApplicationErrorCode(quic.NoError), "lose race")
|
||||||
errors.LogDebug(ctx, "Race Dialer: h3 lose (h2 established)")
|
errors.LogDebug(ctx, "Race Dialer: h3 lose (h2 established before handshake complete)")
|
||||||
return nil, loseRaceError
|
return nil, loseRaceError
|
||||||
case raceErrored:
|
case raceErrored:
|
||||||
// h2 errored first. We will always be used.
|
// h2 errored first. We will always be used.
|
||||||
delete(t.flag, key)
|
|
||||||
t.muFlag.Unlock()
|
|
||||||
errors.LogDebug(ctx, "Race Dialer: h3 win (h2 error)")
|
errors.LogDebug(ctx, "Race Dialer: h3 win (h2 error)")
|
||||||
return conn, nil
|
return conn, nil
|
||||||
|
case raceInitialized:
|
||||||
|
// continue
|
||||||
|
default:
|
||||||
|
panic(fmt.Sprintf("unreachable: unknown race flag: %d", flag))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Don't consider we win until handshake completed.
|
// Don't consider we win until handshake completed.
|
||||||
t.muFlag.Unlock()
|
|
||||||
<-conn.HandshakeComplete()
|
<-conn.HandshakeComplete()
|
||||||
errors.LogDebug(ctx, "Race Dialer: h3 handshake complete")
|
errors.LogDebug(ctx, "Race Dialer: h3 handshake complete")
|
||||||
t.muFlag.Lock()
|
|
||||||
|
|
||||||
established = t.flag[key]
|
if err = conn.Context().Err(); err != nil {
|
||||||
switch established {
|
if t.flag.CompareAndSwap(raceInitialized, raceErrored) {
|
||||||
|
errors.LogDebug(ctx, "Race Dialer: h3 lose (h3 error first)")
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
_ = conn.CloseWithError(quic.ApplicationErrorCode(quic.NoError), "lose race")
|
||||||
|
conn = nil
|
||||||
|
} else {
|
||||||
|
if t.flag.CompareAndSwap(raceInitialized, raceEstablished) {
|
||||||
|
errors.LogDebug(ctx, "Race Dialer: h3 win (h3 first)")
|
||||||
|
return conn, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
flag = t.flag.Load()
|
||||||
|
switch flag {
|
||||||
case raceEstablished:
|
case raceEstablished:
|
||||||
// h2 wins.
|
// h2 wins.
|
||||||
delete(t.flag, key)
|
|
||||||
t.muFlag.Unlock()
|
|
||||||
_ = conn.CloseWithError(quic.ApplicationErrorCode(quic.NoError), "lose race")
|
_ = conn.CloseWithError(quic.ApplicationErrorCode(quic.NoError), "lose race")
|
||||||
errors.LogDebug(ctx, "Race Dialer: h3 lose (h2 established)")
|
errors.LogDebug(ctx, "Race Dialer: h3 lose (h2 established)")
|
||||||
return nil, loseRaceError
|
return nil, loseRaceError
|
||||||
case raceErrored:
|
case raceErrored:
|
||||||
// h2 errored first. We clean up.
|
// h2 errored first.
|
||||||
delete(t.flag, key)
|
if err == nil {
|
||||||
errors.LogDebug(ctx, "Race Dialer: h3 win (h2 error)")
|
errors.LogDebug(ctx, "Race Dialer: h3 win (h2 error)")
|
||||||
case raceInitialized:
|
|
||||||
// Check if handshake failed.
|
|
||||||
if err = conn.Context().Err(); err != nil {
|
|
||||||
conn = nil
|
|
||||||
t.flag[key] = raceErrored
|
|
||||||
errors.LogDebug(ctx, "Race Dialer: h3 lose (h3 error first)")
|
|
||||||
} else {
|
} else {
|
||||||
t.flag[key] = raceEstablished
|
errors.LogDebug(ctx, "Race Dialer: h3 draw (both error)")
|
||||||
errors.LogDebug(ctx, "Race Dialer: h3 win (h3 first)")
|
|
||||||
}
|
}
|
||||||
|
return conn, err
|
||||||
|
case raceInitialized:
|
||||||
|
panic("unreachable: race flag should not revert to raceInitialized")
|
||||||
|
default:
|
||||||
|
panic(fmt.Sprintf("unreachable: unknown race flag: %d", flag))
|
||||||
}
|
}
|
||||||
|
|
||||||
t.muFlag.Unlock()
|
|
||||||
return conn, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
t.h2.DialTLSContext = func(ctx context.Context, network, addr string, cfg *gotls.Config) (conn net.Conn, err error) {
|
t.h2.DialTLSContext = func(ctx context.Context, network, addr string, cfg *gotls.Config) (conn net.Conn, err error) {
|
||||||
@ -231,116 +317,84 @@ func (t *raceTransport) setup() *raceTransport {
|
|||||||
return nil, http2.ErrNoCachedConn
|
return nil, http2.ErrNoCachedConn
|
||||||
}
|
}
|
||||||
|
|
||||||
info := ctx.Value(raceKey).(raceInfo)
|
|
||||||
key := info.key
|
|
||||||
defer func() {
|
defer func() {
|
||||||
|
notify := t.notify.Load()
|
||||||
if err == nil {
|
if err == nil {
|
||||||
info.notify.result = raceH2
|
notify.result = raceH2
|
||||||
close(info.notify.c)
|
close(notify.c)
|
||||||
}
|
}
|
||||||
if info.notify.left.Add(-1) == 0 {
|
if notify.left.Add(-1) == 0 {
|
||||||
errors.LogDebug(ctx, "Race Dialer: h2 cleaning race wait")
|
errors.LogDebug(ctx, "Race Dialer: h2 cleaning race wait")
|
||||||
t.muRace.Lock()
|
t.notify.Store(nil)
|
||||||
if t.race[info.host] == info.notify {
|
|
||||||
delete(t.race, info.host)
|
|
||||||
}
|
|
||||||
t.muRace.Unlock()
|
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
if !t.isH3Broken() {
|
delay := getH2Delay(t.dest)
|
||||||
time.Sleep(chromeH2TryDelay)
|
errors.LogDebug(ctx, "Race Dialer: h2 dial delay: ", delay)
|
||||||
}
|
time.Sleep(delay)
|
||||||
|
|
||||||
t.muFlag.Lock()
|
established := t.flag.Load()
|
||||||
established := t.flag[key]
|
|
||||||
t.muFlag.Unlock()
|
|
||||||
if established == raceEstablished {
|
if established == raceEstablished {
|
||||||
errors.LogDebug(ctx, "Race Dialer: h2 lose (h3 established)")
|
errors.LogDebug(ctx, "Race Dialer: h2 lose (h3 established before try)")
|
||||||
return nil, loseRaceError
|
return nil, loseRaceError
|
||||||
}
|
}
|
||||||
|
|
||||||
conn, err = h2dial(ctx, network, addr, cfg)
|
conn, err = h2dial(ctx, network, addr, cfg)
|
||||||
|
|
||||||
t.muFlag.Lock()
|
|
||||||
established = t.flag[key]
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// We fail.
|
// We fail.
|
||||||
// Record if we are the first, cleanup if we are the last.
|
// Record if we are the first.
|
||||||
if established == raceInitialized {
|
if t.flag.CompareAndSwap(raceInitialized, raceErrored) {
|
||||||
t.flag[key] = raceErrored
|
|
||||||
errors.LogDebug(ctx, "Race Dialer: h2 lose (h2 error first)")
|
errors.LogDebug(ctx, "Race Dialer: h2 lose (h2 error first)")
|
||||||
} else {
|
return nil, err
|
||||||
delete(t.flag, key)
|
}
|
||||||
errors.LogDebug(ctx, "Race Dialer: h2 draw (both error)")
|
_ = conn.Close()
|
||||||
|
conn = nil
|
||||||
|
} else {
|
||||||
|
if t.flag.CompareAndSwap(raceInitialized, raceEstablished) {
|
||||||
|
errors.LogDebug(ctx, "Race Dialer: h2 win (h2 first)")
|
||||||
|
return conn, nil
|
||||||
}
|
}
|
||||||
t.muFlag.Unlock()
|
|
||||||
return nil, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
switch established {
|
flag := t.flag.Load()
|
||||||
|
switch flag {
|
||||||
case raceEstablished:
|
case raceEstablished:
|
||||||
// h3 wins.
|
// h3 wins.
|
||||||
delete(t.flag, key)
|
|
||||||
t.muFlag.Unlock()
|
|
||||||
_ = conn.Close()
|
_ = conn.Close()
|
||||||
errors.LogDebug(ctx, "Race Dialer: h2 lose (h3 established)")
|
errors.LogDebug(ctx, "Race Dialer: h2 lose (h3 established)")
|
||||||
return nil, loseRaceError
|
return nil, loseRaceError
|
||||||
case raceErrored:
|
case raceErrored:
|
||||||
// h3 errored first. We clean up.
|
// h3 errored first.
|
||||||
delete(t.flag, key)
|
if err == nil {
|
||||||
errors.LogDebug(ctx, "Race Dialer: h2 win (h3 error)")
|
errors.LogDebug(ctx, "Race Dialer: h2 win (h3 error)")
|
||||||
|
} else {
|
||||||
|
errors.LogDebug(ctx, "Race Dialer: h2 draw (both error)")
|
||||||
|
}
|
||||||
|
return conn, err
|
||||||
case raceInitialized:
|
case raceInitialized:
|
||||||
// We win, record.
|
panic("unreachable: race flag should not revert to raceInitialized")
|
||||||
t.flag[key] = raceEstablished
|
default:
|
||||||
errors.LogDebug(ctx, "Race Dialer: h2 win (h2 first)")
|
panic(fmt.Sprintf("unreachable: unknown race flag: %d", flag))
|
||||||
}
|
}
|
||||||
|
|
||||||
t.muFlag.Unlock()
|
|
||||||
return conn, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return t
|
return t
|
||||||
}
|
}
|
||||||
|
|
||||||
func authorityAddr(scheme string, authority string) (addr string) {
|
|
||||||
host, port, err := net.SplitHostPort(authority)
|
|
||||||
if err != nil { // authority didn't have a port
|
|
||||||
host = authority
|
|
||||||
port = ""
|
|
||||||
}
|
|
||||||
if port == "" { // authority's port was empty
|
|
||||||
port = "443"
|
|
||||||
if scheme == "http" {
|
|
||||||
port = "80"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if a, err := idna.ToASCII(host); err == nil {
|
|
||||||
host = a
|
|
||||||
}
|
|
||||||
// IPv6 address literal, without a port:
|
|
||||||
if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") {
|
|
||||||
return host + ":" + port
|
|
||||||
}
|
|
||||||
return gonet.JoinHostPort(host, port)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *raceTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
func (t *raceTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
ctx := req.Context()
|
ctx := req.Context()
|
||||||
|
|
||||||
// If there is inflight racing for current host, let it finish first,
|
// If there is inflight racing, let it finish first,
|
||||||
// so we can know and reuse winner's conn.
|
// so we can know and reuse winner's conn.
|
||||||
host := authorityAddr(req.URL.Scheme, req.URL.Host)
|
notify := t.notify.Load()
|
||||||
t.muRace.RLock()
|
raceResult := raceInactive
|
||||||
notify, ok := t.race[host]
|
|
||||||
t.muRace.RUnlock()
|
|
||||||
|
|
||||||
WaitRace:
|
WaitRace:
|
||||||
raceResult := raceInactive
|
if notify != nil {
|
||||||
if ok {
|
errors.LogDebug(ctx, "Race Dialer: found inflight race to ", t.dest, ", waiting race winner")
|
||||||
errors.LogDebug(ctx, "Race Dialer: found inflight race to ", t.dest.NetAddr(), ", waiting race winner")
|
|
||||||
raceResult = notify.wait()
|
raceResult = notify.wait()
|
||||||
errors.LogDebug(ctx, "Race Dialer: winner for ", t.dest.NetAddr(), " resolved, continue handling request")
|
errors.LogDebug(ctx, "Race Dialer: winner for ", t.dest, " resolved, continue handling request")
|
||||||
}
|
}
|
||||||
|
|
||||||
reqNoDial := req.WithContext(context.WithValue(ctx, noDialKey, struct{}{}))
|
reqNoDial := req.WithContext(context.WithValue(ctx, noDialKey, struct{}{}))
|
||||||
@ -353,7 +407,7 @@ WaitRace:
|
|||||||
// - raceInflight: should not see this state.
|
// - raceInflight: should not see this state.
|
||||||
if raceResult == raceH3 || raceResult == raceInactive {
|
if raceResult == raceH3 || raceResult == raceInactive {
|
||||||
if resp, err := t.h3.RoundTripOpt(reqNoDial, http3.RoundTripOpt{OnlyCachedConn: true}); err == nil {
|
if resp, err := t.h3.RoundTripOpt(reqNoDial, http3.RoundTripOpt{OnlyCachedConn: true}); err == nil {
|
||||||
errors.LogInfo(ctx, "Race Dialer: use h3 connection for ", t.dest.NetAddr(), " (reusing conn)")
|
errors.LogInfo(ctx, "Race Dialer: use h3 connection for ", t.dest, " (reusing conn)")
|
||||||
return resp, nil
|
return resp, nil
|
||||||
} else if !goerrors.Is(err, http3.ErrNoCachedConn) {
|
} else if !goerrors.Is(err, http3.ErrNoCachedConn) {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -365,7 +419,7 @@ WaitRace:
|
|||||||
if raceResult == raceH2 || raceResult == raceInactive {
|
if raceResult == raceH2 || raceResult == raceInactive {
|
||||||
// http2.RoundTripOpt.OnlyCachedConn is not effective. However, our noDialKey will block dialing anyway.
|
// http2.RoundTripOpt.OnlyCachedConn is not effective. However, our noDialKey will block dialing anyway.
|
||||||
if resp, err := t.h2.RoundTripOpt(reqNoDial, http2.RoundTripOpt{OnlyCachedConn: true}); err == nil {
|
if resp, err := t.h2.RoundTripOpt(reqNoDial, http2.RoundTripOpt{OnlyCachedConn: true}); err == nil {
|
||||||
errors.LogInfo(ctx, "Race Dialer: use h2 connection for ", t.dest.NetAddr(), " (reusing conn)")
|
errors.LogInfo(ctx, "Race Dialer: use h2 connection for ", t.dest, " (reusing conn)")
|
||||||
return resp, nil
|
return resp, nil
|
||||||
} else if !goerrors.Is(err, http2.ErrNoCachedConn) {
|
} else if !goerrors.Is(err, http2.ErrNoCachedConn) {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -375,35 +429,18 @@ WaitRace:
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Both don't have cached conn. Now race between h2 and h3.
|
// Both don't have cached conn. Now race between h2 and h3.
|
||||||
t.muRace.Lock()
|
// Recheck first.
|
||||||
// Recheck
|
|
||||||
notify, ok = t.race[host]
|
|
||||||
if ok {
|
|
||||||
// Some other request started racing before us, we wait for them to finish.
|
|
||||||
t.muRace.Unlock()
|
|
||||||
goto WaitRace
|
|
||||||
}
|
|
||||||
// We are the goroutine to initialize racing.
|
|
||||||
notify = &raceNotify{c: make(chan struct{})}
|
notify = &raceNotify{c: make(chan struct{})}
|
||||||
notify.left.Store(2)
|
notify.left.Store(2)
|
||||||
t.race[host] = notify
|
if !t.notify.CompareAndSwap(nil, notify) {
|
||||||
t.muRace.Unlock()
|
// Some other request started racing before us, we wait for them to finish.
|
||||||
|
goto WaitRace
|
||||||
|
}
|
||||||
|
|
||||||
errors.LogDebug(ctx, "Race Dialer: start race to ", t.dest.NetAddr())
|
// We are the goroutine to initialize racing.
|
||||||
|
errors.LogDebug(ctx, "Race Dialer: start race to ", t.dest)
|
||||||
|
|
||||||
// Both RoundTripper can share req.Body, because only one can dial successfully,
|
t.flag.Store(raceInitialized)
|
||||||
// and proceed to read request body.
|
|
||||||
key := uintptr(unsafe.Pointer(req))
|
|
||||||
raceCtx := context.WithValue(ctx, raceKey, raceInfo{
|
|
||||||
key: key,
|
|
||||||
host: host,
|
|
||||||
notify: notify,
|
|
||||||
})
|
|
||||||
req = req.WithContext(raceCtx)
|
|
||||||
|
|
||||||
t.muFlag.Lock()
|
|
||||||
t.flag[key] = raceInitialized
|
|
||||||
t.muFlag.Unlock()
|
|
||||||
|
|
||||||
h2resp := make(chan any)
|
h2resp := make(chan any)
|
||||||
h3resp := make(chan any)
|
h3resp := make(chan any)
|
||||||
@ -417,36 +454,10 @@ WaitRace:
|
|||||||
close(raceDone)
|
close(raceDone)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
go func() {
|
// Both RoundTripper can share req.Body, because only one can dial successfully,
|
||||||
resp, err := t.h3.RoundTrip(req)
|
// and proceed to read request body.
|
||||||
|
roundTrip := func(r http.RoundTripper, respChan chan any) {
|
||||||
var result any
|
resp, err := r.RoundTrip(req)
|
||||||
if err == nil {
|
|
||||||
result = resp
|
|
||||||
old := t.h3broken.Swap(0)
|
|
||||||
if old != 0 {
|
|
||||||
errors.LogDebug(ctx, "Race Dialer: h3 connection succeed, clear broken state")
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
result = err
|
|
||||||
t.updateH3Broken(time.Now())
|
|
||||||
errors.LogDebug(ctx, "Race Dialer: h3 connection failed, set broken state")
|
|
||||||
}
|
|
||||||
|
|
||||||
select {
|
|
||||||
case h3resp <- result:
|
|
||||||
case <-raceDone:
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
resp, err := t.h2.RoundTrip(req)
|
|
||||||
if notify.left.Add(-1) == 0 {
|
|
||||||
errors.LogDebug(ctx, "Race Dialer: h2 cleaning race wait")
|
|
||||||
t.muRace.Lock()
|
|
||||||
delete(t.race, host)
|
|
||||||
t.muRace.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
var result any
|
var result any
|
||||||
if err == nil {
|
if err == nil {
|
||||||
@ -456,17 +467,20 @@ WaitRace:
|
|||||||
}
|
}
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case h2resp <- result:
|
case respChan <- result:
|
||||||
case <-raceDone:
|
case <-raceDone:
|
||||||
}
|
}
|
||||||
}()
|
}
|
||||||
|
|
||||||
|
go roundTrip(t.h3, h3resp)
|
||||||
|
go roundTrip(t.h2, h2resp)
|
||||||
|
|
||||||
reportState := func(isH3 bool) {
|
reportState := func(isH3 bool) {
|
||||||
winner := "h2"
|
winner := "h2"
|
||||||
if isH3 {
|
if isH3 {
|
||||||
winner = "h3"
|
winner = "h3"
|
||||||
}
|
}
|
||||||
errors.LogInfo(ctx, "Race Dialer: use ", winner, " connection for ", t.dest.NetAddr(), " (race winner)")
|
errors.LogInfo(ctx, "Race Dialer: use ", winner, " connection for ", t.dest, " (race winner)")
|
||||||
}
|
}
|
||||||
|
|
||||||
handleResult := func(respErr any, other chan any, isH3 bool) (*http.Response, error) {
|
handleResult := func(respErr any, other chan any, isH3 bool) (*http.Response, error) {
|
||||||
@ -484,9 +498,9 @@ WaitRace:
|
|||||||
case error:
|
case error:
|
||||||
switch {
|
switch {
|
||||||
// hide internal error
|
// hide internal error
|
||||||
case goerrors.Is(value, loseRaceError) || goerrors.Is(value, brokenSpanError):
|
case isRaceInternalError(value):
|
||||||
return nil, otherValue
|
return nil, otherValue
|
||||||
case goerrors.Is(otherValue, loseRaceError) || goerrors.Is(otherValue, brokenSpanError):
|
case isRaceInternalError(otherValue):
|
||||||
return nil, value
|
return nil, value
|
||||||
// prefer h3 error
|
// prefer h3 error
|
||||||
case isH3:
|
case isH3:
|
||||||
@ -495,10 +509,10 @@ WaitRace:
|
|||||||
return nil, otherValue
|
return nil, otherValue
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
panic("unreachable: unexpected response type")
|
panic(fmt.Sprintf("unreachable: unexpected response type %T", otherValue))
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
panic("unreachable: unexpected response type")
|
panic(fmt.Sprintf("unreachable: unexpected response type %T", value))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user