XHTTP XMUX: Add hMaxRequestTimes and hKeepAlivePeriod (#4163)

Fixes https://github.com/XTLS/Xray-core/discussions/4113#discussioncomment-11492833
This commit is contained in:
RPRX 2024-12-15 05:43:10 +00:00 committed by GitHub
parent 7463561856
commit 73e0d4a666
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 515 additions and 430 deletions

View file

@ -10,6 +10,7 @@ import (
"net/url"
"strconv"
"sync"
"sync/atomic"
"time"
"github.com/quic-go/quic-go"
@ -45,11 +46,11 @@ type dialerConf struct {
}
var (
globalDialerMap map[dialerConf]*muxManager
globalDialerMap map[dialerConf]*XmuxManager
globalDialerAccess sync.Mutex
)
func getHTTPClient(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (DialerClient, *muxResource) {
func getHTTPClient(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (DialerClient, *XmuxClient) {
realityConfig := reality.ConfigFromStreamSettings(streamSettings)
if browser_dialer.HasBrowserDialer() && realityConfig != nil {
@ -60,28 +61,28 @@ func getHTTPClient(ctx context.Context, dest net.Destination, streamSettings *in
defer globalDialerAccess.Unlock()
if globalDialerMap == nil {
globalDialerMap = make(map[dialerConf]*muxManager)
globalDialerMap = make(map[dialerConf]*XmuxManager)
}
key := dialerConf{dest, streamSettings}
muxManager, found := globalDialerMap[key]
xmuxManager, found := globalDialerMap[key]
if !found {
transportConfig := streamSettings.ProtocolSettings.(*Config)
var mux Multiplexing
var xmuxConfig XmuxConfig
if transportConfig.Xmux != nil {
mux = *transportConfig.Xmux
xmuxConfig = *transportConfig.Xmux
}
muxManager = NewMuxManager(mux, func() interface{} {
xmuxManager = NewXmuxManager(xmuxConfig, func() XmuxConn {
return createHTTPClient(dest, streamSettings)
})
globalDialerMap[key] = muxManager
globalDialerMap[key] = xmuxManager
}
res := muxManager.GetResource(ctx)
return res.Resource.(DialerClient), res
xmuxClient := xmuxManager.GetXmuxClient(ctx)
return xmuxClient.XmuxConn.(DialerClient), xmuxClient
}
func decideHTTPVersion(tlsConfig *tls.Config, realityConfig *reality.Config) string {
@ -144,7 +145,10 @@ func createHTTPClient(dest net.Destination, streamSettings *internet.MemoryStrea
return conn, nil
}
keepAlivePeriod := time.Duration(streamSettings.ProtocolSettings.(*Config).KeepAlivePeriod) * time.Second
var keepAlivePeriod time.Duration
if streamSettings.ProtocolSettings.(*Config).Xmux != nil {
keepAlivePeriod = time.Duration(streamSettings.ProtocolSettings.(*Config).Xmux.HKeepAlivePeriod) * time.Second
}
var transport http.RoundTripper
@ -282,7 +286,7 @@ func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.Me
requestURL.Path = transportConfiguration.GetNormalizedPath() + sessionIdUuid.String()
requestURL.RawQuery = transportConfiguration.GetNormalizedQuery()
httpClient, muxRes := getHTTPClient(ctx, dest, streamSettings)
httpClient, xmuxClient := getHTTPClient(ctx, dest, streamSettings)
mode := transportConfiguration.Mode
if mode == "" || mode == "auto" {
@ -299,7 +303,7 @@ func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.Me
requestURL2 := requestURL
httpClient2 := httpClient
var muxRes2 *muxResource
xmuxClient2 := xmuxClient
if transportConfiguration.DownloadSettings != nil {
globalDialerAccess.Lock()
if streamSettings.DownloadSettings == nil {
@ -332,7 +336,7 @@ func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.Me
}
requestURL2.Path = config2.GetNormalizedPath() + sessionIdUuid.String()
requestURL2.RawQuery = config2.GetNormalizedQuery()
httpClient2, muxRes2 = getHTTPClient(ctx, dest2, memory2)
httpClient2, xmuxClient2 = getHTTPClient(ctx, dest2, memory2)
errors.LogInfo(ctx, fmt.Sprintf("XHTTP is downloading from %s, mode %s, HTTP version %s, host %s", dest2, "stream-down", httpVersion2, requestURL2.Host))
}
@ -343,23 +347,29 @@ func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.Me
if mode == "stream-one" {
requestURL.Path = transportConfiguration.GetNormalizedPath()
if xmuxClient != nil {
xmuxClient.LeftRequests.Add(-1)
}
writer, reader = httpClient.Open(context.WithoutCancel(ctx), requestURL.String())
remoteAddr = &net.TCPAddr{}
localAddr = &net.TCPAddr{}
} else {
if xmuxClient2 != nil {
xmuxClient2.LeftRequests.Add(-1)
}
reader, remoteAddr, localAddr, err = httpClient2.OpenDownload(context.WithoutCancel(ctx), requestURL2.String())
if err != nil {
return nil, err
}
}
if muxRes != nil {
muxRes.OpenRequests.Add(1)
if xmuxClient != nil {
xmuxClient.OpenUsage.Add(1)
}
if muxRes2 != nil {
muxRes2.OpenRequests.Add(1)
if xmuxClient2 != nil && xmuxClient2 != xmuxClient {
xmuxClient2.OpenUsage.Add(1)
}
closed := false
var once atomic.Int32
conn := splitConn{
writer: writer,
@ -367,23 +377,28 @@ func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.Me
remoteAddr: remoteAddr,
localAddr: localAddr,
onClose: func() {
if closed {
if once.Add(-1) < 0 {
return
}
closed = true
if muxRes != nil {
muxRes.OpenRequests.Add(-1)
if xmuxClient != nil {
xmuxClient.OpenUsage.Add(-1)
}
if muxRes2 != nil {
muxRes2.OpenRequests.Add(-1)
if xmuxClient2 != nil && xmuxClient2 != xmuxClient {
xmuxClient2.OpenUsage.Add(-1)
}
},
}
if mode == "stream-one" {
if xmuxClient != nil {
xmuxClient.LeftRequests.Add(-1)
}
return stat.Connection(&conn), nil
}
if mode == "stream-up" {
if xmuxClient != nil {
xmuxClient.LeftRequests.Add(-1)
}
conn.writer = httpClient.OpenUpload(ctx, requestURL.String())
return stat.Connection(&conn), nil
}
@ -391,7 +406,7 @@ func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.Me
scMaxEachPostBytes := transportConfiguration.GetNormalizedScMaxEachPostBytes()
scMinPostsIntervalMs := transportConfiguration.GetNormalizedScMinPostsIntervalMs()
maxUploadSize := scMaxEachPostBytes.roll()
maxUploadSize := scMaxEachPostBytes.rand()
// WithSizeLimit(0) will still allow single bytes to pass, and a lot of
// code relies on this behavior. Subtract 1 so that together with
// uploadWriter wrapper, exact size limits can be enforced
@ -426,7 +441,7 @@ func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.Me
seq += 1
if scMinPostsIntervalMs.From > 0 {
time.Sleep(time.Duration(scMinPostsIntervalMs.roll())*time.Millisecond - time.Since(lastWrite))
time.Sleep(time.Duration(scMinPostsIntervalMs.rand())*time.Millisecond - time.Since(lastWrite))
}
// by offloading the uploads into a buffered pipe, multiple conn.Write
@ -439,6 +454,10 @@ func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.Me
lastWrite = time.Now()
if xmuxClient != nil && xmuxClient.LeftRequests.Add(-1) <= 0 {
httpClient, xmuxClient = getHTTPClient(ctx, dest, streamSettings)
}
go func() {
err := httpClient.SendUploadRequest(
context.WithoutCancel(ctx),