mirror of
https://github.com/XTLS/Xray-core.git
synced 2025-01-30 13:34:12 +00:00
XHTTP Client: Race Dialer
This commit is contained in:
parent
ca9a902213
commit
c2ac58d93e
@ -8,6 +8,7 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptrace"
|
"net/http/httptrace"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
"slices"
|
||||||
"strconv"
|
"strconv"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
@ -93,6 +94,9 @@ func decideHTTPVersion(tlsConfig *tls.Config, realityConfig *reality.Config) str
|
|||||||
return "1.1"
|
return "1.1"
|
||||||
}
|
}
|
||||||
if len(tlsConfig.NextProtocol) != 1 {
|
if len(tlsConfig.NextProtocol) != 1 {
|
||||||
|
if slices.Contains(tlsConfig.NextProtocol, "h3") && slices.Contains(tlsConfig.NextProtocol, "h2") {
|
||||||
|
return "3+2"
|
||||||
|
}
|
||||||
return "2"
|
return "2"
|
||||||
}
|
}
|
||||||
if tlsConfig.NextProtocol[0] == "http/1.1" {
|
if tlsConfig.NextProtocol[0] == "http/1.1" {
|
||||||
@ -101,6 +105,7 @@ func decideHTTPVersion(tlsConfig *tls.Config, realityConfig *reality.Config) str
|
|||||||
if tlsConfig.NextProtocol[0] == "h3" {
|
if tlsConfig.NextProtocol[0] == "h3" {
|
||||||
return "3"
|
return "3"
|
||||||
}
|
}
|
||||||
|
|
||||||
return "2"
|
return "2"
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -109,14 +114,27 @@ func createHTTPClient(dest net.Destination, streamSettings *internet.MemoryStrea
|
|||||||
realityConfig := reality.ConfigFromStreamSettings(streamSettings)
|
realityConfig := reality.ConfigFromStreamSettings(streamSettings)
|
||||||
|
|
||||||
httpVersion := decideHTTPVersion(tlsConfig, realityConfig)
|
httpVersion := decideHTTPVersion(tlsConfig, realityConfig)
|
||||||
if httpVersion == "3" {
|
|
||||||
dest.Network = net.Network_UDP // better to keep this line
|
|
||||||
}
|
|
||||||
|
|
||||||
var gotlsConfig *gotls.Config
|
var gotlsConfig *gotls.Config
|
||||||
|
var h3gotlsConfig *gotls.Config
|
||||||
|
|
||||||
if tlsConfig != nil {
|
if tlsConfig != nil {
|
||||||
gotlsConfig = tlsConfig.GetTLSConfig(tls.WithDestination(dest))
|
gotlsConfig = tlsConfig.GetTLSConfig(tls.WithDestination(dest))
|
||||||
|
h3gotlsConfig = gotlsConfig
|
||||||
|
|
||||||
|
if httpVersion == "3+2" {
|
||||||
|
h3gotlsConfig = &gotls.Config{}
|
||||||
|
*h3gotlsConfig = *gotlsConfig
|
||||||
|
|
||||||
|
// Make QUIC ALPN only contains h3, and remove h3 from TCP TLS ALPN
|
||||||
|
h3gotlsConfig.NextProtos = []string{"h3"}
|
||||||
|
h3idx := slices.Index(h3gotlsConfig.NextProtos, "h3")
|
||||||
|
// Don't modify original tlsConfig.NextProtocol
|
||||||
|
nextProtos := gotlsConfig.NextProtos
|
||||||
|
gotlsConfig.NextProtos = make([]string, 0, len(nextProtos)-1)
|
||||||
|
gotlsConfig.NextProtos = append(gotlsConfig.NextProtos, nextProtos[:h3idx]...)
|
||||||
|
gotlsConfig.NextProtos = append(gotlsConfig.NextProtos, nextProtos[h3idx+1:]...)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
transportConfig := streamSettings.ProtocolSettings.(*Config)
|
transportConfig := streamSettings.ProtocolSettings.(*Config)
|
||||||
@ -152,7 +170,7 @@ func createHTTPClient(dest net.Destination, streamSettings *internet.MemoryStrea
|
|||||||
|
|
||||||
var transport http.RoundTripper
|
var transport http.RoundTripper
|
||||||
|
|
||||||
if httpVersion == "3" {
|
makeH3Transport := func() *http3.Transport {
|
||||||
if keepAlivePeriod == 0 {
|
if keepAlivePeriod == 0 {
|
||||||
keepAlivePeriod = quicgoH3KeepAlivePeriod
|
keepAlivePeriod = quicgoH3KeepAlivePeriod
|
||||||
}
|
}
|
||||||
@ -168,9 +186,11 @@ func createHTTPClient(dest net.Destination, streamSettings *internet.MemoryStrea
|
|||||||
MaxIncomingStreams: -1,
|
MaxIncomingStreams: -1,
|
||||||
KeepAlivePeriod: keepAlivePeriod,
|
KeepAlivePeriod: keepAlivePeriod,
|
||||||
}
|
}
|
||||||
transport = &http3.RoundTripper{
|
dest := dest
|
||||||
|
dest.Network = net.Network_UDP
|
||||||
|
return &http3.Transport{
|
||||||
QUICConfig: quicConfig,
|
QUICConfig: quicConfig,
|
||||||
TLSClientConfig: gotlsConfig,
|
TLSClientConfig: h3gotlsConfig,
|
||||||
Dial: func(ctx context.Context, addr string, tlsCfg *gotls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
|
Dial: func(ctx context.Context, addr string, tlsCfg *gotls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
|
||||||
conn, err := internet.DialSystem(ctx, dest, streamSettings.SocketSettings)
|
conn, err := internet.DialSystem(ctx, dest, streamSettings.SocketSettings)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -208,26 +228,30 @@ func createHTTPClient(dest net.Destination, streamSettings *internet.MemoryStrea
|
|||||||
return quic.DialEarly(ctx, udpConn, udpAddr, tlsCfg, cfg)
|
return quic.DialEarly(ctx, udpConn, udpAddr, tlsCfg, cfg)
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
} else if httpVersion == "2" {
|
}
|
||||||
|
|
||||||
|
makeH2Transport := func() *http2.Transport {
|
||||||
if keepAlivePeriod == 0 {
|
if keepAlivePeriod == 0 {
|
||||||
keepAlivePeriod = chromeH2KeepAlivePeriod
|
keepAlivePeriod = chromeH2KeepAlivePeriod
|
||||||
}
|
}
|
||||||
if keepAlivePeriod < 0 {
|
if keepAlivePeriod < 0 {
|
||||||
keepAlivePeriod = 0
|
keepAlivePeriod = 0
|
||||||
}
|
}
|
||||||
transport = &http2.Transport{
|
return &http2.Transport{
|
||||||
DialTLSContext: func(ctxInner context.Context, network string, addr string, cfg *gotls.Config) (net.Conn, error) {
|
DialTLSContext: func(ctxInner context.Context, network string, addr string, cfg *gotls.Config) (net.Conn, error) {
|
||||||
return dialContext(ctxInner)
|
return dialContext(ctxInner)
|
||||||
},
|
},
|
||||||
IdleConnTimeout: connIdleTimeout,
|
IdleConnTimeout: connIdleTimeout,
|
||||||
ReadIdleTimeout: keepAlivePeriod,
|
ReadIdleTimeout: keepAlivePeriod,
|
||||||
}
|
}
|
||||||
} else {
|
}
|
||||||
|
|
||||||
|
makeTransport := func() *http.Transport {
|
||||||
httpDialContext := func(ctxInner context.Context, network string, addr string) (net.Conn, error) {
|
httpDialContext := func(ctxInner context.Context, network string, addr string) (net.Conn, error) {
|
||||||
return dialContext(ctxInner)
|
return dialContext(ctxInner)
|
||||||
}
|
}
|
||||||
|
|
||||||
transport = &http.Transport{
|
return &http.Transport{
|
||||||
DialTLSContext: httpDialContext,
|
DialTLSContext: httpDialContext,
|
||||||
DialContext: httpDialContext,
|
DialContext: httpDialContext,
|
||||||
IdleConnTimeout: connIdleTimeout,
|
IdleConnTimeout: connIdleTimeout,
|
||||||
@ -237,6 +261,22 @@ func createHTTPClient(dest net.Destination, streamSettings *internet.MemoryStrea
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
switch httpVersion {
|
||||||
|
case "3":
|
||||||
|
transport = makeH3Transport()
|
||||||
|
case "2":
|
||||||
|
transport = makeH2Transport()
|
||||||
|
case "3+2":
|
||||||
|
raceTransport := &raceTransport{
|
||||||
|
h3: makeH3Transport(),
|
||||||
|
h2: makeH2Transport(),
|
||||||
|
dest: dest,
|
||||||
|
}
|
||||||
|
transport = raceTransport.setup()
|
||||||
|
default:
|
||||||
|
transport = makeTransport()
|
||||||
|
}
|
||||||
|
|
||||||
client := &DefaultDialerClient{
|
client := &DefaultDialerClient{
|
||||||
transportConfig: transportConfig,
|
transportConfig: transportConfig,
|
||||||
client: &http.Client{
|
client: &http.Client{
|
||||||
|
511
transport/internet/splithttp/race_dialer.go
Normal file
511
transport/internet/splithttp/race_dialer.go
Normal file
@ -0,0 +1,511 @@
|
|||||||
|
package splithttp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
gotls "crypto/tls"
|
||||||
|
goerrors "errors"
|
||||||
|
gonet "net"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
|
"github.com/xtls/quic-go"
|
||||||
|
"github.com/xtls/quic-go/http3"
|
||||||
|
"github.com/xtls/xray-core/common/errors"
|
||||||
|
"github.com/xtls/xray-core/common/net"
|
||||||
|
"golang.org/x/net/http2"
|
||||||
|
"golang.org/x/net/idna"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// QuicSessionPool::GetTimeDelayForWaitingJob > kDefaultRTT
|
||||||
|
chromeH2TryDelay = 300 * time.Millisecond
|
||||||
|
|
||||||
|
// kDefaultBrokenAlternativeProtocolDelay
|
||||||
|
chromeH3BrokenDelay = 5 * time.Minute
|
||||||
|
)
|
||||||
|
|
||||||
|
type raceKeyType struct{}
|
||||||
|
|
||||||
|
var raceKey raceKeyType
|
||||||
|
|
||||||
|
type noDialKeyType struct{}
|
||||||
|
|
||||||
|
var noDialKey noDialKeyType
|
||||||
|
|
||||||
|
var (
|
||||||
|
loseRaceError = goerrors.New("lose race")
|
||||||
|
brokenSpanError = goerrors.New("protocol temporarily broken")
|
||||||
|
)
|
||||||
|
|
||||||
|
type raceState int
|
||||||
|
|
||||||
|
const (
|
||||||
|
raceInitialized raceState = 0
|
||||||
|
raceEstablished raceState = 1
|
||||||
|
raceErrored raceState = -1
|
||||||
|
)
|
||||||
|
|
||||||
|
type raceResult int
|
||||||
|
|
||||||
|
const (
|
||||||
|
raceInflight raceResult = 0
|
||||||
|
raceH3 raceResult = 1
|
||||||
|
raceH2 raceResult = 2
|
||||||
|
raceFailed raceResult = -1
|
||||||
|
raceInactive raceResult = -2
|
||||||
|
)
|
||||||
|
|
||||||
|
type raceNotify struct {
|
||||||
|
c chan struct{}
|
||||||
|
result raceResult
|
||||||
|
|
||||||
|
// left is the remove counter. This entry should be removed from race map when it reached 0
|
||||||
|
left atomic.Int32
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *raceNotify) wait() raceResult {
|
||||||
|
<-r.c
|
||||||
|
return r.result
|
||||||
|
}
|
||||||
|
|
||||||
|
type raceInfo struct {
|
||||||
|
key uintptr
|
||||||
|
host string
|
||||||
|
notify *raceNotify
|
||||||
|
}
|
||||||
|
|
||||||
|
type raceTransport struct {
|
||||||
|
h3 *http3.Transport
|
||||||
|
h2 *http2.Transport
|
||||||
|
dest net.Destination
|
||||||
|
|
||||||
|
muFlag sync.Mutex
|
||||||
|
// flag must be read or write with muFlag holding.
|
||||||
|
flag map[uintptr]raceState
|
||||||
|
|
||||||
|
muRace sync.RWMutex
|
||||||
|
// race must be read or write with muRace holding.
|
||||||
|
// Entries are:
|
||||||
|
// added by RoundTrip before racing start
|
||||||
|
// removed by the last finished racer
|
||||||
|
race map[string]*raceNotify
|
||||||
|
|
||||||
|
h3broken atomic.Int64
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *raceTransport) isH3Broken() bool {
|
||||||
|
return time.Now().UnixNano()-t.h3broken.Load() < int64(chromeH3BrokenDelay)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *raceTransport) updateH3Broken(broken time.Time) {
|
||||||
|
value := broken.UnixNano()
|
||||||
|
old := t.h3broken.Load()
|
||||||
|
for old < value {
|
||||||
|
if t.h3broken.CompareAndSwap(old, value) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
old = t.h3broken.Load()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *raceTransport) setup() *raceTransport {
|
||||||
|
t.flag = make(map[uintptr]raceState)
|
||||||
|
t.race = make(map[string]*raceNotify)
|
||||||
|
|
||||||
|
h3dial := t.h3.Dial
|
||||||
|
h2dial := t.h2.DialTLSContext
|
||||||
|
|
||||||
|
t.h3.Dial = func(ctx context.Context, addr string, tlsCfg *gotls.Config, cfg *quic.Config) (conn quic.EarlyConnection, err error) {
|
||||||
|
if ctx.Value(noDialKey) != nil {
|
||||||
|
return nil, http3.ErrNoCachedConn
|
||||||
|
}
|
||||||
|
|
||||||
|
info := ctx.Value(raceKey).(raceInfo)
|
||||||
|
key := info.key
|
||||||
|
defer func() {
|
||||||
|
// We can safely remove the raceNotify here, since both h2 and h3 Transport
|
||||||
|
// hold mutex while dialing.
|
||||||
|
// So another request can't slip in after we removed raceNotify but before
|
||||||
|
// Transport put the returned conn into pool - they will always reuse the conn we returned.
|
||||||
|
if err == nil {
|
||||||
|
info.notify.result = raceH3
|
||||||
|
close(info.notify.c)
|
||||||
|
}
|
||||||
|
if info.notify.left.Add(-1) == 0 {
|
||||||
|
errors.LogDebug(ctx, "Race Dialer: h3 cleaning race wait")
|
||||||
|
t.muRace.Lock()
|
||||||
|
if t.race[info.host] == info.notify {
|
||||||
|
delete(t.race, info.host)
|
||||||
|
}
|
||||||
|
t.muRace.Unlock()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
if t.isH3Broken() {
|
||||||
|
return nil, brokenSpanError
|
||||||
|
}
|
||||||
|
|
||||||
|
t.muFlag.Lock()
|
||||||
|
established := t.flag[key]
|
||||||
|
t.muFlag.Unlock()
|
||||||
|
if established == raceEstablished {
|
||||||
|
errors.LogDebug(ctx, "Race Dialer: h3 lose (h2 established)")
|
||||||
|
return nil, loseRaceError
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err = h3dial(ctx, addr, tlsCfg, cfg)
|
||||||
|
|
||||||
|
t.muFlag.Lock()
|
||||||
|
established = t.flag[key]
|
||||||
|
if err != nil {
|
||||||
|
// We fail.
|
||||||
|
// Record if we are the first, cleanup if we are the last.
|
||||||
|
if established == raceInitialized {
|
||||||
|
t.flag[key] = raceErrored
|
||||||
|
errors.LogDebug(ctx, "Race Dialer: h3 lose (h3 error)")
|
||||||
|
} else {
|
||||||
|
delete(t.flag, key)
|
||||||
|
errors.LogDebug(ctx, "Race Dialer: h3 draw (both error)")
|
||||||
|
}
|
||||||
|
t.muFlag.Unlock()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
switch established {
|
||||||
|
case raceEstablished:
|
||||||
|
// h2 wins.
|
||||||
|
delete(t.flag, key)
|
||||||
|
t.muFlag.Unlock()
|
||||||
|
_ = conn.CloseWithError(quic.ApplicationErrorCode(quic.NoError), "lose race")
|
||||||
|
errors.LogDebug(ctx, "Race Dialer: h3 lose (h2 established)")
|
||||||
|
return nil, loseRaceError
|
||||||
|
case raceErrored:
|
||||||
|
// h2 errored first. We will always be used.
|
||||||
|
delete(t.flag, key)
|
||||||
|
t.muFlag.Unlock()
|
||||||
|
errors.LogDebug(ctx, "Race Dialer: h3 win (h2 error)")
|
||||||
|
return conn, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Don't consider we win until handshake completed.
|
||||||
|
t.muFlag.Unlock()
|
||||||
|
<-conn.HandshakeComplete()
|
||||||
|
errors.LogDebug(ctx, "Race Dialer: h3 handshake complete")
|
||||||
|
t.muFlag.Lock()
|
||||||
|
|
||||||
|
established = t.flag[key]
|
||||||
|
switch established {
|
||||||
|
case raceEstablished:
|
||||||
|
// h2 wins.
|
||||||
|
delete(t.flag, key)
|
||||||
|
t.muFlag.Unlock()
|
||||||
|
_ = conn.CloseWithError(quic.ApplicationErrorCode(quic.NoError), "lose race")
|
||||||
|
errors.LogDebug(ctx, "Race Dialer: h3 lose (h2 established)")
|
||||||
|
return nil, loseRaceError
|
||||||
|
case raceErrored:
|
||||||
|
// h2 errored first. We clean up.
|
||||||
|
delete(t.flag, key)
|
||||||
|
errors.LogDebug(ctx, "Race Dialer: h3 win (h2 error)")
|
||||||
|
case raceInitialized:
|
||||||
|
// Check if handshake failed.
|
||||||
|
if err = conn.Context().Err(); err != nil {
|
||||||
|
conn = nil
|
||||||
|
t.flag[key] = raceErrored
|
||||||
|
errors.LogDebug(ctx, "Race Dialer: h3 lose (h3 error first)")
|
||||||
|
} else {
|
||||||
|
t.flag[key] = raceEstablished
|
||||||
|
errors.LogDebug(ctx, "Race Dialer: h3 win (h3 first)")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
t.muFlag.Unlock()
|
||||||
|
return conn, err
|
||||||
|
}
|
||||||
|
|
||||||
|
t.h2.DialTLSContext = func(ctx context.Context, network, addr string, cfg *gotls.Config) (conn net.Conn, err error) {
|
||||||
|
if ctx.Value(noDialKey) != nil {
|
||||||
|
return nil, http2.ErrNoCachedConn
|
||||||
|
}
|
||||||
|
|
||||||
|
info := ctx.Value(raceKey).(raceInfo)
|
||||||
|
key := info.key
|
||||||
|
defer func() {
|
||||||
|
if err == nil {
|
||||||
|
info.notify.result = raceH2
|
||||||
|
close(info.notify.c)
|
||||||
|
}
|
||||||
|
if info.notify.left.Add(-1) == 0 {
|
||||||
|
errors.LogDebug(ctx, "Race Dialer: h2 cleaning race wait")
|
||||||
|
t.muRace.Lock()
|
||||||
|
if t.race[info.host] == info.notify {
|
||||||
|
delete(t.race, info.host)
|
||||||
|
}
|
||||||
|
t.muRace.Unlock()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
if !t.isH3Broken() {
|
||||||
|
time.Sleep(chromeH2TryDelay)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.muFlag.Lock()
|
||||||
|
established := t.flag[key]
|
||||||
|
t.muFlag.Unlock()
|
||||||
|
if established == raceEstablished {
|
||||||
|
errors.LogDebug(ctx, "Race Dialer: h2 lose (h3 established)")
|
||||||
|
return nil, loseRaceError
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err = h2dial(ctx, network, addr, cfg)
|
||||||
|
|
||||||
|
t.muFlag.Lock()
|
||||||
|
established = t.flag[key]
|
||||||
|
if err != nil {
|
||||||
|
// We fail.
|
||||||
|
// Record if we are the first, cleanup if we are the last.
|
||||||
|
if established == raceInitialized {
|
||||||
|
t.flag[key] = raceErrored
|
||||||
|
errors.LogDebug(ctx, "Race Dialer: h2 lose (h2 error first)")
|
||||||
|
} else {
|
||||||
|
delete(t.flag, key)
|
||||||
|
errors.LogDebug(ctx, "Race Dialer: h2 draw (both error)")
|
||||||
|
}
|
||||||
|
t.muFlag.Unlock()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
switch established {
|
||||||
|
case raceEstablished:
|
||||||
|
// h3 wins.
|
||||||
|
delete(t.flag, key)
|
||||||
|
t.muFlag.Unlock()
|
||||||
|
_ = conn.Close()
|
||||||
|
errors.LogDebug(ctx, "Race Dialer: h2 lose (h3 established)")
|
||||||
|
return nil, loseRaceError
|
||||||
|
case raceErrored:
|
||||||
|
// h3 errored first. We clean up.
|
||||||
|
delete(t.flag, key)
|
||||||
|
errors.LogDebug(ctx, "Race Dialer: h2 win (h3 error)")
|
||||||
|
case raceInitialized:
|
||||||
|
// We win, record.
|
||||||
|
t.flag[key] = raceEstablished
|
||||||
|
errors.LogDebug(ctx, "Race Dialer: h2 win (h2 first)")
|
||||||
|
}
|
||||||
|
|
||||||
|
t.muFlag.Unlock()
|
||||||
|
return conn, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return t
|
||||||
|
}
|
||||||
|
|
||||||
|
func authorityAddr(scheme string, authority string) (addr string) {
|
||||||
|
host, port, err := net.SplitHostPort(authority)
|
||||||
|
if err != nil { // authority didn't have a port
|
||||||
|
host = authority
|
||||||
|
port = ""
|
||||||
|
}
|
||||||
|
if port == "" { // authority's port was empty
|
||||||
|
port = "443"
|
||||||
|
if scheme == "http" {
|
||||||
|
port = "80"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if a, err := idna.ToASCII(host); err == nil {
|
||||||
|
host = a
|
||||||
|
}
|
||||||
|
// IPv6 address literal, without a port:
|
||||||
|
if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") {
|
||||||
|
return host + ":" + port
|
||||||
|
}
|
||||||
|
return gonet.JoinHostPort(host, port)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *raceTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
|
ctx := req.Context()
|
||||||
|
|
||||||
|
// If there is inflight racing for current host, let it finish first,
|
||||||
|
// so we can know and reuse winner's conn.
|
||||||
|
host := authorityAddr(req.URL.Scheme, req.URL.Host)
|
||||||
|
t.muRace.RLock()
|
||||||
|
notify, ok := t.race[host]
|
||||||
|
t.muRace.RUnlock()
|
||||||
|
|
||||||
|
WaitRace:
|
||||||
|
raceResult := raceInactive
|
||||||
|
if ok {
|
||||||
|
errors.LogDebug(ctx, "Race Dialer: found inflight race to ", t.dest.NetAddr(), ", waiting race winner")
|
||||||
|
raceResult = notify.wait()
|
||||||
|
errors.LogDebug(ctx, "Race Dialer: winner for ", t.dest.NetAddr(), " resolved, continue handling request")
|
||||||
|
}
|
||||||
|
|
||||||
|
reqNoDial := req.WithContext(context.WithValue(ctx, noDialKey, struct{}{}))
|
||||||
|
|
||||||
|
// First see if there's cached connection, for both h3 and h2.
|
||||||
|
// - raceInactive: no inflight race. Try both.
|
||||||
|
// - raceH3/raceH2: another request just decided race winner.
|
||||||
|
// Losing Transport may not yet fail, so avoid trying it.
|
||||||
|
// - raceFailed: both failed. There won't be cached conn, no need to try.
|
||||||
|
// - raceInflight: should not see this state.
|
||||||
|
if raceResult == raceH3 || raceResult == raceInactive {
|
||||||
|
if resp, err := t.h3.RoundTripOpt(reqNoDial, http3.RoundTripOpt{OnlyCachedConn: true}); err == nil {
|
||||||
|
errors.LogInfo(ctx, "Race Dialer: use h3 connection for ", t.dest.NetAddr(), " (reusing conn)")
|
||||||
|
return resp, nil
|
||||||
|
} else if !goerrors.Is(err, http3.ErrNoCachedConn) {
|
||||||
|
return nil, err
|
||||||
|
} else if raceResult == raceH3 {
|
||||||
|
return nil, errors.New("Race Dialer: h3 just succeeded, but no cached conn")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if raceResult == raceH2 || raceResult == raceInactive {
|
||||||
|
// http2.RoundTripOpt.OnlyCachedConn is not effective. However, our noDialKey will block dialing anyway.
|
||||||
|
if resp, err := t.h2.RoundTripOpt(reqNoDial, http2.RoundTripOpt{OnlyCachedConn: true}); err == nil {
|
||||||
|
errors.LogInfo(ctx, "Race Dialer: use h2 connection for ", t.dest.NetAddr(), " (reusing conn)")
|
||||||
|
return resp, nil
|
||||||
|
} else if !goerrors.Is(err, http2.ErrNoCachedConn) {
|
||||||
|
return nil, err
|
||||||
|
} else if raceResult == raceH2 {
|
||||||
|
return nil, errors.New("Race Dialer: h2 just succeeded, but no cached conn")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Both don't have cached conn. Now race between h2 and h3.
|
||||||
|
t.muRace.Lock()
|
||||||
|
// Recheck
|
||||||
|
notify, ok = t.race[host]
|
||||||
|
if ok {
|
||||||
|
// Some other request started racing before us, we wait for them to finish.
|
||||||
|
t.muRace.Unlock()
|
||||||
|
goto WaitRace
|
||||||
|
}
|
||||||
|
// We are the goroutine to initialize racing.
|
||||||
|
notify = &raceNotify{c: make(chan struct{})}
|
||||||
|
notify.left.Store(2)
|
||||||
|
t.race[host] = notify
|
||||||
|
t.muRace.Unlock()
|
||||||
|
|
||||||
|
errors.LogDebug(ctx, "Race Dialer: start race to ", t.dest.NetAddr())
|
||||||
|
|
||||||
|
// Both RoundTripper can share req.Body, because only one can dial successfully,
|
||||||
|
// and proceed to read request body.
|
||||||
|
key := uintptr(unsafe.Pointer(req))
|
||||||
|
raceCtx := context.WithValue(ctx, raceKey, raceInfo{
|
||||||
|
key: key,
|
||||||
|
host: host,
|
||||||
|
notify: notify,
|
||||||
|
})
|
||||||
|
req = req.WithContext(raceCtx)
|
||||||
|
|
||||||
|
t.muFlag.Lock()
|
||||||
|
t.flag[key] = raceInitialized
|
||||||
|
t.muFlag.Unlock()
|
||||||
|
|
||||||
|
h2resp := make(chan any)
|
||||||
|
h3resp := make(chan any)
|
||||||
|
raceDone := make(chan struct{})
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
if notify.result == raceInflight {
|
||||||
|
notify.result = raceFailed
|
||||||
|
close(notify.c)
|
||||||
|
}
|
||||||
|
close(raceDone)
|
||||||
|
}()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
resp, err := t.h3.RoundTrip(req)
|
||||||
|
|
||||||
|
var result any
|
||||||
|
if err == nil {
|
||||||
|
result = resp
|
||||||
|
old := t.h3broken.Swap(0)
|
||||||
|
if old != 0 {
|
||||||
|
errors.LogDebug(ctx, "Race Dialer: h3 connection succeed, clear broken state")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
result = err
|
||||||
|
t.updateH3Broken(time.Now())
|
||||||
|
errors.LogDebug(ctx, "Race Dialer: h3 connection failed, set broken state")
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case h3resp <- result:
|
||||||
|
case <-raceDone:
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
resp, err := t.h2.RoundTrip(req)
|
||||||
|
if notify.left.Add(-1) == 0 {
|
||||||
|
errors.LogDebug(ctx, "Race Dialer: h2 cleaning race wait")
|
||||||
|
t.muRace.Lock()
|
||||||
|
delete(t.race, host)
|
||||||
|
t.muRace.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
var result any
|
||||||
|
if err == nil {
|
||||||
|
result = resp
|
||||||
|
} else {
|
||||||
|
result = err
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case h2resp <- result:
|
||||||
|
case <-raceDone:
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
reportState := func(isH3 bool) {
|
||||||
|
winner := "h2"
|
||||||
|
if isH3 {
|
||||||
|
winner = "h3"
|
||||||
|
}
|
||||||
|
errors.LogInfo(ctx, "Race Dialer: use ", winner, " connection for ", t.dest.NetAddr(), " (race winner)")
|
||||||
|
}
|
||||||
|
|
||||||
|
handleResult := func(respErr any, other chan any, isH3 bool) (*http.Response, error) {
|
||||||
|
switch value := respErr.(type) {
|
||||||
|
case *http.Response:
|
||||||
|
// we win
|
||||||
|
reportState(isH3)
|
||||||
|
return value, nil
|
||||||
|
case error:
|
||||||
|
switch otherValue := (<-other).(type) {
|
||||||
|
case *http.Response:
|
||||||
|
// other win
|
||||||
|
reportState(!isH3)
|
||||||
|
return otherValue, nil
|
||||||
|
case error:
|
||||||
|
switch {
|
||||||
|
// hide internal error
|
||||||
|
case goerrors.Is(value, loseRaceError) || goerrors.Is(value, brokenSpanError):
|
||||||
|
return nil, otherValue
|
||||||
|
case goerrors.Is(otherValue, loseRaceError) || goerrors.Is(otherValue, brokenSpanError):
|
||||||
|
return nil, value
|
||||||
|
// prefer h3 error
|
||||||
|
case isH3:
|
||||||
|
return nil, value
|
||||||
|
default:
|
||||||
|
return nil, otherValue
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
panic("unreachable: unexpected response type")
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
panic("unreachable: unexpected response type")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case respErr := <-h3resp:
|
||||||
|
return handleResult(respErr, h2resp, true)
|
||||||
|
case respErr := <-h2resp:
|
||||||
|
return handleResult(respErr, h3resp, false)
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user