mirror of
https://github.com/XTLS/Xray-core.git
synced 2025-01-27 12:04:13 +00:00
Adapt to current single connection per client
This commit is contained in:
parent
c2ac58d93e
commit
9006e3b35c
@ -270,7 +270,7 @@ func createHTTPClient(dest net.Destination, streamSettings *internet.MemoryStrea
|
||||
raceTransport := &raceTransport{
|
||||
h3: makeH3Transport(),
|
||||
h2: makeH2Transport(),
|
||||
dest: dest,
|
||||
dest: dest.NetAddr(),
|
||||
}
|
||||
transport = raceTransport.setup()
|
||||
default:
|
||||
|
@ -4,28 +4,37 @@ import (
|
||||
"context"
|
||||
gotls "crypto/tls"
|
||||
goerrors "errors"
|
||||
gonet "net"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"github.com/xtls/quic-go"
|
||||
"github.com/xtls/quic-go/http3"
|
||||
"github.com/xtls/xray-core/common/errors"
|
||||
"github.com/xtls/xray-core/common/net"
|
||||
"golang.org/x/net/http2"
|
||||
"golang.org/x/net/idna"
|
||||
)
|
||||
|
||||
const (
|
||||
// net/quic/quic_session_pool.cc
|
||||
// QuicSessionPool::GetTimeDelayForWaitingJob > kDefaultRTT
|
||||
chromeH2TryDelay = 300 * time.Millisecond
|
||||
chromeH2DefaultTryDelay = 300 * time.Millisecond
|
||||
// QuicSessionPool::GetTimeDelayForWaitingJob > srtt
|
||||
chromeH2TryDelayScale = 1.5
|
||||
|
||||
// net/http/broken_alternative_services.cc
|
||||
// kDefaultBrokenAlternativeProtocolDelay
|
||||
chromeH3BrokenDelay = 5 * time.Minute
|
||||
chromeH3BrokenInitialDelay = 5 * time.Minute
|
||||
// kMaxBrokenAlternativeProtocolDelay
|
||||
chromeH3BrokenMaxDelay = 48 * time.Hour
|
||||
// kBrokenDelayMaxShift
|
||||
chromeH3BrokenMaxShift = 18
|
||||
|
||||
// net/third_party/quiche/src/quiche/quic/core/congestion_control/rtt_stats.cc
|
||||
// kAlpha
|
||||
chromeH3SmoothRTTAlpha = 0.125
|
||||
)
|
||||
|
||||
type raceKeyType struct{}
|
||||
@ -41,12 +50,14 @@ var (
|
||||
brokenSpanError = goerrors.New("protocol temporarily broken")
|
||||
)
|
||||
|
||||
type raceState int
|
||||
func isRaceInternalError(err error) bool {
|
||||
return goerrors.Is(err, loseRaceError) || goerrors.Is(err, brokenSpanError)
|
||||
}
|
||||
|
||||
const (
|
||||
raceInitialized raceState = 0
|
||||
raceEstablished raceState = 1
|
||||
raceErrored raceState = -1
|
||||
raceInitialized = 0
|
||||
raceEstablished = 1
|
||||
raceErrored = -1
|
||||
)
|
||||
|
||||
type raceResult int
|
||||
@ -59,11 +70,118 @@ const (
|
||||
raceInactive raceResult = -2
|
||||
)
|
||||
|
||||
type endpointInfo struct {
|
||||
lastFail time.Time
|
||||
failCount int
|
||||
rtt atomic.Int64
|
||||
}
|
||||
|
||||
var h3EndpointCatalog map[string]*endpointInfo
|
||||
var h3EndpointCatalogLock sync.RWMutex
|
||||
|
||||
func isH3Broken(endpoint string) bool {
|
||||
h3EndpointCatalogLock.RLock()
|
||||
defer h3EndpointCatalogLock.RUnlock()
|
||||
info, ok := h3EndpointCatalog[endpoint]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
brokenDuration := min(chromeH3BrokenInitialDelay<<min(info.failCount, chromeH3BrokenMaxShift), chromeH3BrokenMaxDelay)
|
||||
return time.Since(info.lastFail) < brokenDuration
|
||||
}
|
||||
|
||||
func getH2Delay(endpoint string) time.Duration {
|
||||
h3EndpointCatalogLock.RLock()
|
||||
defer h3EndpointCatalogLock.RUnlock()
|
||||
info, ok := h3EndpointCatalog[endpoint]
|
||||
if !ok {
|
||||
return chromeH2DefaultTryDelay
|
||||
}
|
||||
if info.failCount > 0 {
|
||||
return 0
|
||||
}
|
||||
rtt := info.rtt.Load()
|
||||
if rtt == 0 {
|
||||
return chromeH2DefaultTryDelay
|
||||
}
|
||||
return time.Duration(chromeH2TryDelayScale * float64(rtt))
|
||||
}
|
||||
|
||||
func updateH3Broken(endpoint string, brokenAt time.Time) int {
|
||||
h3EndpointCatalogLock.Lock()
|
||||
defer h3EndpointCatalogLock.Unlock()
|
||||
if h3EndpointCatalog == nil {
|
||||
h3EndpointCatalog = make(map[string]*endpointInfo)
|
||||
}
|
||||
|
||||
info, ok := h3EndpointCatalog[endpoint]
|
||||
if !ok {
|
||||
h3EndpointCatalog[endpoint] = &endpointInfo{
|
||||
lastFail: brokenAt,
|
||||
failCount: 1,
|
||||
}
|
||||
return 1
|
||||
}
|
||||
|
||||
info.failCount++
|
||||
if brokenAt.After(info.lastFail) {
|
||||
info.lastFail = brokenAt
|
||||
}
|
||||
|
||||
return info.failCount
|
||||
}
|
||||
|
||||
func smoothedRtt(oldRtt, newRtt int64) int64 {
|
||||
if oldRtt == 0 {
|
||||
return newRtt
|
||||
}
|
||||
|
||||
return int64((1-chromeH3SmoothRTTAlpha)*float64(oldRtt) + chromeH3SmoothRTTAlpha*float64(newRtt))
|
||||
}
|
||||
|
||||
func updateH3RTT(endpoint string, rtt time.Duration) {
|
||||
h3EndpointCatalogLock.RLock()
|
||||
info, ok := h3EndpointCatalog[endpoint]
|
||||
if !ok || info.failCount > 0 {
|
||||
h3EndpointCatalogLock.RUnlock()
|
||||
updateH3RTTSlow(endpoint, rtt)
|
||||
return
|
||||
}
|
||||
|
||||
defer h3EndpointCatalogLock.RUnlock()
|
||||
for {
|
||||
oldRtt := info.rtt.Load()
|
||||
newRtt := smoothedRtt(oldRtt, int64(rtt))
|
||||
if info.rtt.CompareAndSwap(oldRtt, newRtt) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func updateH3RTTSlow(endpoint string, rtt time.Duration) {
|
||||
h3EndpointCatalogLock.Lock()
|
||||
defer h3EndpointCatalogLock.Unlock()
|
||||
|
||||
info, ok := h3EndpointCatalog[endpoint]
|
||||
switch {
|
||||
case !ok:
|
||||
info = &endpointInfo{}
|
||||
info.rtt.Store(int64(rtt))
|
||||
case info.failCount > 0:
|
||||
info.failCount = 0
|
||||
info.lastFail = time.Time{}
|
||||
info.rtt.Store(int64(rtt))
|
||||
default:
|
||||
info.rtt.Store(smoothedRtt(info.rtt.Load(), int64(rtt)))
|
||||
}
|
||||
}
|
||||
|
||||
type raceNotify struct {
|
||||
c chan struct{}
|
||||
result raceResult
|
||||
|
||||
// left is the remove counter. This entry should be removed from race map when it reached 0
|
||||
// left is the remove counter. It should be released when it reached 0
|
||||
left atomic.Int32
|
||||
}
|
||||
|
||||
@ -72,50 +190,16 @@ func (r *raceNotify) wait() raceResult {
|
||||
return r.result
|
||||
}
|
||||
|
||||
type raceInfo struct {
|
||||
key uintptr
|
||||
host string
|
||||
notify *raceNotify
|
||||
}
|
||||
|
||||
type raceTransport struct {
|
||||
h3 *http3.Transport
|
||||
h2 *http2.Transport
|
||||
dest net.Destination
|
||||
dest string
|
||||
|
||||
muFlag sync.Mutex
|
||||
// flag must be read or write with muFlag holding.
|
||||
flag map[uintptr]raceState
|
||||
|
||||
muRace sync.RWMutex
|
||||
// race must be read or write with muRace holding.
|
||||
// Entries are:
|
||||
// added by RoundTrip before racing start
|
||||
// removed by the last finished racer
|
||||
race map[string]*raceNotify
|
||||
|
||||
h3broken atomic.Int64
|
||||
}
|
||||
|
||||
func (t *raceTransport) isH3Broken() bool {
|
||||
return time.Now().UnixNano()-t.h3broken.Load() < int64(chromeH3BrokenDelay)
|
||||
}
|
||||
|
||||
func (t *raceTransport) updateH3Broken(broken time.Time) {
|
||||
value := broken.UnixNano()
|
||||
old := t.h3broken.Load()
|
||||
for old < value {
|
||||
if t.h3broken.CompareAndSwap(old, value) {
|
||||
break
|
||||
}
|
||||
old = t.h3broken.Load()
|
||||
}
|
||||
flag atomic.Int64
|
||||
notify atomic.Pointer[raceNotify]
|
||||
}
|
||||
|
||||
func (t *raceTransport) setup() *raceTransport {
|
||||
t.flag = make(map[uintptr]raceState)
|
||||
t.race = make(map[string]*raceNotify)
|
||||
|
||||
h3dial := t.h3.Dial
|
||||
h2dial := t.h2.DialTLSContext
|
||||
|
||||
@ -124,106 +208,108 @@ func (t *raceTransport) setup() *raceTransport {
|
||||
return nil, http3.ErrNoCachedConn
|
||||
}
|
||||
|
||||
info := ctx.Value(raceKey).(raceInfo)
|
||||
key := info.key
|
||||
var dialStart time.Time
|
||||
|
||||
defer func() {
|
||||
notify := t.notify.Load()
|
||||
if err == nil {
|
||||
updateH3RTT(t.dest, time.Since(dialStart))
|
||||
notify.result = raceH3
|
||||
close(notify.c)
|
||||
} else if !isRaceInternalError(err) {
|
||||
failed := updateH3Broken(t.dest, time.Now())
|
||||
errors.LogDebug(ctx, "Race Dialer: h3 connection to ", t.dest, " failed ", failed, "times")
|
||||
}
|
||||
|
||||
// We can safely remove the raceNotify here, since both h2 and h3 Transport
|
||||
// hold mutex while dialing.
|
||||
// So another request can't slip in after we removed raceNotify but before
|
||||
// Transport put the returned conn into pool - they will always reuse the conn we returned.
|
||||
if err == nil {
|
||||
info.notify.result = raceH3
|
||||
close(info.notify.c)
|
||||
}
|
||||
if info.notify.left.Add(-1) == 0 {
|
||||
if notify.left.Add(-1) == 0 {
|
||||
errors.LogDebug(ctx, "Race Dialer: h3 cleaning race wait")
|
||||
t.muRace.Lock()
|
||||
if t.race[info.host] == info.notify {
|
||||
delete(t.race, info.host)
|
||||
}
|
||||
t.muRace.Unlock()
|
||||
t.notify.Store(nil)
|
||||
}
|
||||
}()
|
||||
|
||||
if t.isH3Broken() {
|
||||
if isH3Broken(t.dest) {
|
||||
return nil, brokenSpanError
|
||||
}
|
||||
|
||||
t.muFlag.Lock()
|
||||
established := t.flag[key]
|
||||
t.muFlag.Unlock()
|
||||
established := t.flag.Load()
|
||||
if established == raceEstablished {
|
||||
errors.LogDebug(ctx, "Race Dialer: h3 lose (h2 established)")
|
||||
errors.LogDebug(ctx, "Race Dialer: h3 lose (h2 established before try)")
|
||||
return nil, loseRaceError
|
||||
}
|
||||
|
||||
dialStart = time.Now()
|
||||
conn, err = h3dial(ctx, addr, tlsCfg, cfg)
|
||||
|
||||
t.muFlag.Lock()
|
||||
established = t.flag[key]
|
||||
if err != nil {
|
||||
// We fail.
|
||||
// Record if we are the first, cleanup if we are the last.
|
||||
if established == raceInitialized {
|
||||
t.flag[key] = raceErrored
|
||||
// Record if we are the first.
|
||||
if t.flag.CompareAndSwap(raceInitialized, raceErrored) {
|
||||
errors.LogDebug(ctx, "Race Dialer: h3 lose (h3 error)")
|
||||
} else {
|
||||
delete(t.flag, key)
|
||||
errors.LogDebug(ctx, "Race Dialer: h3 draw (both error)")
|
||||
}
|
||||
t.muFlag.Unlock()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch established {
|
||||
flag := t.flag.Load()
|
||||
switch flag {
|
||||
case raceEstablished:
|
||||
// h2 wins.
|
||||
delete(t.flag, key)
|
||||
t.muFlag.Unlock()
|
||||
_ = conn.CloseWithError(quic.ApplicationErrorCode(quic.NoError), "lose race")
|
||||
errors.LogDebug(ctx, "Race Dialer: h3 lose (h2 established)")
|
||||
errors.LogDebug(ctx, "Race Dialer: h3 lose (h2 established before handshake complete)")
|
||||
return nil, loseRaceError
|
||||
case raceErrored:
|
||||
// h2 errored first. We will always be used.
|
||||
delete(t.flag, key)
|
||||
t.muFlag.Unlock()
|
||||
errors.LogDebug(ctx, "Race Dialer: h3 win (h2 error)")
|
||||
return conn, nil
|
||||
case raceInitialized:
|
||||
// continue
|
||||
default:
|
||||
panic(fmt.Sprintf("unreachable: unknown race flag: %d", flag))
|
||||
}
|
||||
|
||||
// Don't consider we win until handshake completed.
|
||||
t.muFlag.Unlock()
|
||||
<-conn.HandshakeComplete()
|
||||
errors.LogDebug(ctx, "Race Dialer: h3 handshake complete")
|
||||
t.muFlag.Lock()
|
||||
|
||||
established = t.flag[key]
|
||||
switch established {
|
||||
if err = conn.Context().Err(); err != nil {
|
||||
if t.flag.CompareAndSwap(raceInitialized, raceErrored) {
|
||||
errors.LogDebug(ctx, "Race Dialer: h3 lose (h3 error first)")
|
||||
return nil, err
|
||||
}
|
||||
_ = conn.CloseWithError(quic.ApplicationErrorCode(quic.NoError), "lose race")
|
||||
conn = nil
|
||||
} else {
|
||||
if t.flag.CompareAndSwap(raceInitialized, raceEstablished) {
|
||||
errors.LogDebug(ctx, "Race Dialer: h3 win (h3 first)")
|
||||
return conn, nil
|
||||
}
|
||||
}
|
||||
|
||||
flag = t.flag.Load()
|
||||
switch flag {
|
||||
case raceEstablished:
|
||||
// h2 wins.
|
||||
delete(t.flag, key)
|
||||
t.muFlag.Unlock()
|
||||
_ = conn.CloseWithError(quic.ApplicationErrorCode(quic.NoError), "lose race")
|
||||
errors.LogDebug(ctx, "Race Dialer: h3 lose (h2 established)")
|
||||
return nil, loseRaceError
|
||||
case raceErrored:
|
||||
// h2 errored first. We clean up.
|
||||
delete(t.flag, key)
|
||||
errors.LogDebug(ctx, "Race Dialer: h3 win (h2 error)")
|
||||
case raceInitialized:
|
||||
// Check if handshake failed.
|
||||
if err = conn.Context().Err(); err != nil {
|
||||
conn = nil
|
||||
t.flag[key] = raceErrored
|
||||
errors.LogDebug(ctx, "Race Dialer: h3 lose (h3 error first)")
|
||||
// h2 errored first.
|
||||
if err == nil {
|
||||
errors.LogDebug(ctx, "Race Dialer: h3 win (h2 error)")
|
||||
} else {
|
||||
t.flag[key] = raceEstablished
|
||||
errors.LogDebug(ctx, "Race Dialer: h3 win (h3 first)")
|
||||
errors.LogDebug(ctx, "Race Dialer: h3 draw (both error)")
|
||||
}
|
||||
return conn, err
|
||||
case raceInitialized:
|
||||
panic("unreachable: race flag should not revert to raceInitialized")
|
||||
default:
|
||||
panic(fmt.Sprintf("unreachable: unknown race flag: %d", flag))
|
||||
}
|
||||
|
||||
t.muFlag.Unlock()
|
||||
return conn, err
|
||||
}
|
||||
|
||||
t.h2.DialTLSContext = func(ctx context.Context, network, addr string, cfg *gotls.Config) (conn net.Conn, err error) {
|
||||
@ -231,116 +317,84 @@ func (t *raceTransport) setup() *raceTransport {
|
||||
return nil, http2.ErrNoCachedConn
|
||||
}
|
||||
|
||||
info := ctx.Value(raceKey).(raceInfo)
|
||||
key := info.key
|
||||
defer func() {
|
||||
notify := t.notify.Load()
|
||||
if err == nil {
|
||||
info.notify.result = raceH2
|
||||
close(info.notify.c)
|
||||
notify.result = raceH2
|
||||
close(notify.c)
|
||||
}
|
||||
if info.notify.left.Add(-1) == 0 {
|
||||
if notify.left.Add(-1) == 0 {
|
||||
errors.LogDebug(ctx, "Race Dialer: h2 cleaning race wait")
|
||||
t.muRace.Lock()
|
||||
if t.race[info.host] == info.notify {
|
||||
delete(t.race, info.host)
|
||||
}
|
||||
t.muRace.Unlock()
|
||||
t.notify.Store(nil)
|
||||
}
|
||||
}()
|
||||
|
||||
if !t.isH3Broken() {
|
||||
time.Sleep(chromeH2TryDelay)
|
||||
}
|
||||
delay := getH2Delay(t.dest)
|
||||
errors.LogDebug(ctx, "Race Dialer: h2 dial delay: ", delay)
|
||||
time.Sleep(delay)
|
||||
|
||||
t.muFlag.Lock()
|
||||
established := t.flag[key]
|
||||
t.muFlag.Unlock()
|
||||
established := t.flag.Load()
|
||||
if established == raceEstablished {
|
||||
errors.LogDebug(ctx, "Race Dialer: h2 lose (h3 established)")
|
||||
errors.LogDebug(ctx, "Race Dialer: h2 lose (h3 established before try)")
|
||||
return nil, loseRaceError
|
||||
}
|
||||
|
||||
conn, err = h2dial(ctx, network, addr, cfg)
|
||||
|
||||
t.muFlag.Lock()
|
||||
established = t.flag[key]
|
||||
if err != nil {
|
||||
// We fail.
|
||||
// Record if we are the first, cleanup if we are the last.
|
||||
if established == raceInitialized {
|
||||
t.flag[key] = raceErrored
|
||||
// Record if we are the first.
|
||||
if t.flag.CompareAndSwap(raceInitialized, raceErrored) {
|
||||
errors.LogDebug(ctx, "Race Dialer: h2 lose (h2 error first)")
|
||||
} else {
|
||||
delete(t.flag, key)
|
||||
errors.LogDebug(ctx, "Race Dialer: h2 draw (both error)")
|
||||
return nil, err
|
||||
}
|
||||
_ = conn.Close()
|
||||
conn = nil
|
||||
} else {
|
||||
if t.flag.CompareAndSwap(raceInitialized, raceEstablished) {
|
||||
errors.LogDebug(ctx, "Race Dialer: h2 win (h2 first)")
|
||||
return conn, nil
|
||||
}
|
||||
t.muFlag.Unlock()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch established {
|
||||
flag := t.flag.Load()
|
||||
switch flag {
|
||||
case raceEstablished:
|
||||
// h3 wins.
|
||||
delete(t.flag, key)
|
||||
t.muFlag.Unlock()
|
||||
_ = conn.Close()
|
||||
errors.LogDebug(ctx, "Race Dialer: h2 lose (h3 established)")
|
||||
return nil, loseRaceError
|
||||
case raceErrored:
|
||||
// h3 errored first. We clean up.
|
||||
delete(t.flag, key)
|
||||
errors.LogDebug(ctx, "Race Dialer: h2 win (h3 error)")
|
||||
// h3 errored first.
|
||||
if err == nil {
|
||||
errors.LogDebug(ctx, "Race Dialer: h2 win (h3 error)")
|
||||
} else {
|
||||
errors.LogDebug(ctx, "Race Dialer: h2 draw (both error)")
|
||||
}
|
||||
return conn, err
|
||||
case raceInitialized:
|
||||
// We win, record.
|
||||
t.flag[key] = raceEstablished
|
||||
errors.LogDebug(ctx, "Race Dialer: h2 win (h2 first)")
|
||||
panic("unreachable: race flag should not revert to raceInitialized")
|
||||
default:
|
||||
panic(fmt.Sprintf("unreachable: unknown race flag: %d", flag))
|
||||
}
|
||||
|
||||
t.muFlag.Unlock()
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
return t
|
||||
}
|
||||
|
||||
func authorityAddr(scheme string, authority string) (addr string) {
|
||||
host, port, err := net.SplitHostPort(authority)
|
||||
if err != nil { // authority didn't have a port
|
||||
host = authority
|
||||
port = ""
|
||||
}
|
||||
if port == "" { // authority's port was empty
|
||||
port = "443"
|
||||
if scheme == "http" {
|
||||
port = "80"
|
||||
}
|
||||
}
|
||||
if a, err := idna.ToASCII(host); err == nil {
|
||||
host = a
|
||||
}
|
||||
// IPv6 address literal, without a port:
|
||||
if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") {
|
||||
return host + ":" + port
|
||||
}
|
||||
return gonet.JoinHostPort(host, port)
|
||||
}
|
||||
|
||||
func (t *raceTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
ctx := req.Context()
|
||||
|
||||
// If there is inflight racing for current host, let it finish first,
|
||||
// If there is inflight racing, let it finish first,
|
||||
// so we can know and reuse winner's conn.
|
||||
host := authorityAddr(req.URL.Scheme, req.URL.Host)
|
||||
t.muRace.RLock()
|
||||
notify, ok := t.race[host]
|
||||
t.muRace.RUnlock()
|
||||
notify := t.notify.Load()
|
||||
raceResult := raceInactive
|
||||
|
||||
WaitRace:
|
||||
raceResult := raceInactive
|
||||
if ok {
|
||||
errors.LogDebug(ctx, "Race Dialer: found inflight race to ", t.dest.NetAddr(), ", waiting race winner")
|
||||
if notify != nil {
|
||||
errors.LogDebug(ctx, "Race Dialer: found inflight race to ", t.dest, ", waiting race winner")
|
||||
raceResult = notify.wait()
|
||||
errors.LogDebug(ctx, "Race Dialer: winner for ", t.dest.NetAddr(), " resolved, continue handling request")
|
||||
errors.LogDebug(ctx, "Race Dialer: winner for ", t.dest, " resolved, continue handling request")
|
||||
}
|
||||
|
||||
reqNoDial := req.WithContext(context.WithValue(ctx, noDialKey, struct{}{}))
|
||||
@ -353,7 +407,7 @@ WaitRace:
|
||||
// - raceInflight: should not see this state.
|
||||
if raceResult == raceH3 || raceResult == raceInactive {
|
||||
if resp, err := t.h3.RoundTripOpt(reqNoDial, http3.RoundTripOpt{OnlyCachedConn: true}); err == nil {
|
||||
errors.LogInfo(ctx, "Race Dialer: use h3 connection for ", t.dest.NetAddr(), " (reusing conn)")
|
||||
errors.LogInfo(ctx, "Race Dialer: use h3 connection for ", t.dest, " (reusing conn)")
|
||||
return resp, nil
|
||||
} else if !goerrors.Is(err, http3.ErrNoCachedConn) {
|
||||
return nil, err
|
||||
@ -365,7 +419,7 @@ WaitRace:
|
||||
if raceResult == raceH2 || raceResult == raceInactive {
|
||||
// http2.RoundTripOpt.OnlyCachedConn is not effective. However, our noDialKey will block dialing anyway.
|
||||
if resp, err := t.h2.RoundTripOpt(reqNoDial, http2.RoundTripOpt{OnlyCachedConn: true}); err == nil {
|
||||
errors.LogInfo(ctx, "Race Dialer: use h2 connection for ", t.dest.NetAddr(), " (reusing conn)")
|
||||
errors.LogInfo(ctx, "Race Dialer: use h2 connection for ", t.dest, " (reusing conn)")
|
||||
return resp, nil
|
||||
} else if !goerrors.Is(err, http2.ErrNoCachedConn) {
|
||||
return nil, err
|
||||
@ -375,35 +429,18 @@ WaitRace:
|
||||
}
|
||||
|
||||
// Both don't have cached conn. Now race between h2 and h3.
|
||||
t.muRace.Lock()
|
||||
// Recheck
|
||||
notify, ok = t.race[host]
|
||||
if ok {
|
||||
// Some other request started racing before us, we wait for them to finish.
|
||||
t.muRace.Unlock()
|
||||
goto WaitRace
|
||||
}
|
||||
// We are the goroutine to initialize racing.
|
||||
// Recheck first.
|
||||
notify = &raceNotify{c: make(chan struct{})}
|
||||
notify.left.Store(2)
|
||||
t.race[host] = notify
|
||||
t.muRace.Unlock()
|
||||
if !t.notify.CompareAndSwap(nil, notify) {
|
||||
// Some other request started racing before us, we wait for them to finish.
|
||||
goto WaitRace
|
||||
}
|
||||
|
||||
errors.LogDebug(ctx, "Race Dialer: start race to ", t.dest.NetAddr())
|
||||
// We are the goroutine to initialize racing.
|
||||
errors.LogDebug(ctx, "Race Dialer: start race to ", t.dest)
|
||||
|
||||
// Both RoundTripper can share req.Body, because only one can dial successfully,
|
||||
// and proceed to read request body.
|
||||
key := uintptr(unsafe.Pointer(req))
|
||||
raceCtx := context.WithValue(ctx, raceKey, raceInfo{
|
||||
key: key,
|
||||
host: host,
|
||||
notify: notify,
|
||||
})
|
||||
req = req.WithContext(raceCtx)
|
||||
|
||||
t.muFlag.Lock()
|
||||
t.flag[key] = raceInitialized
|
||||
t.muFlag.Unlock()
|
||||
t.flag.Store(raceInitialized)
|
||||
|
||||
h2resp := make(chan any)
|
||||
h3resp := make(chan any)
|
||||
@ -417,36 +454,10 @@ WaitRace:
|
||||
close(raceDone)
|
||||
}()
|
||||
|
||||
go func() {
|
||||
resp, err := t.h3.RoundTrip(req)
|
||||
|
||||
var result any
|
||||
if err == nil {
|
||||
result = resp
|
||||
old := t.h3broken.Swap(0)
|
||||
if old != 0 {
|
||||
errors.LogDebug(ctx, "Race Dialer: h3 connection succeed, clear broken state")
|
||||
}
|
||||
} else {
|
||||
result = err
|
||||
t.updateH3Broken(time.Now())
|
||||
errors.LogDebug(ctx, "Race Dialer: h3 connection failed, set broken state")
|
||||
}
|
||||
|
||||
select {
|
||||
case h3resp <- result:
|
||||
case <-raceDone:
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
resp, err := t.h2.RoundTrip(req)
|
||||
if notify.left.Add(-1) == 0 {
|
||||
errors.LogDebug(ctx, "Race Dialer: h2 cleaning race wait")
|
||||
t.muRace.Lock()
|
||||
delete(t.race, host)
|
||||
t.muRace.Unlock()
|
||||
}
|
||||
// Both RoundTripper can share req.Body, because only one can dial successfully,
|
||||
// and proceed to read request body.
|
||||
roundTrip := func(r http.RoundTripper, respChan chan any) {
|
||||
resp, err := r.RoundTrip(req)
|
||||
|
||||
var result any
|
||||
if err == nil {
|
||||
@ -456,17 +467,20 @@ WaitRace:
|
||||
}
|
||||
|
||||
select {
|
||||
case h2resp <- result:
|
||||
case respChan <- result:
|
||||
case <-raceDone:
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
go roundTrip(t.h3, h3resp)
|
||||
go roundTrip(t.h2, h2resp)
|
||||
|
||||
reportState := func(isH3 bool) {
|
||||
winner := "h2"
|
||||
if isH3 {
|
||||
winner = "h3"
|
||||
}
|
||||
errors.LogInfo(ctx, "Race Dialer: use ", winner, " connection for ", t.dest.NetAddr(), " (race winner)")
|
||||
errors.LogInfo(ctx, "Race Dialer: use ", winner, " connection for ", t.dest, " (race winner)")
|
||||
}
|
||||
|
||||
handleResult := func(respErr any, other chan any, isH3 bool) (*http.Response, error) {
|
||||
@ -484,9 +498,9 @@ WaitRace:
|
||||
case error:
|
||||
switch {
|
||||
// hide internal error
|
||||
case goerrors.Is(value, loseRaceError) || goerrors.Is(value, brokenSpanError):
|
||||
case isRaceInternalError(value):
|
||||
return nil, otherValue
|
||||
case goerrors.Is(otherValue, loseRaceError) || goerrors.Is(otherValue, brokenSpanError):
|
||||
case isRaceInternalError(otherValue):
|
||||
return nil, value
|
||||
// prefer h3 error
|
||||
case isH3:
|
||||
@ -495,10 +509,10 @@ WaitRace:
|
||||
return nil, otherValue
|
||||
}
|
||||
default:
|
||||
panic("unreachable: unexpected response type")
|
||||
panic(fmt.Sprintf("unreachable: unexpected response type %T", otherValue))
|
||||
}
|
||||
default:
|
||||
panic("unreachable: unexpected response type")
|
||||
panic(fmt.Sprintf("unreachable: unexpected response type %T", value))
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user