Transport: Add HTTP3 to HTTP (#3819)

This commit is contained in:
yuhan6665 2024-09-25 21:29:41 -04:00 committed by GitHub
parent 7086d286be
commit 3632e83faa
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 316 additions and 128 deletions

View File

@ -650,7 +650,7 @@ func (p TransportProtocol) Build() (string, error) {
return "mkcp", nil return "mkcp", nil
case "ws", "websocket": case "ws", "websocket":
return "websocket", nil return "websocket", nil
case "h2", "http": case "h2", "h3", "http":
return "http", nil return "http", nil
case "grpc", "gun": case "grpc", "gun":
return "grpc", nil return "grpc", nil

View File

@ -9,6 +9,8 @@ import (
"sync" "sync"
"time" "time"
"github.com/quic-go/quic-go"
"github.com/quic-go/quic-go/http3"
"github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common"
"github.com/xtls/xray-core/common/buf" "github.com/xtls/xray-core/common/buf"
c "github.com/xtls/xray-core/common/ctx" c "github.com/xtls/xray-core/common/ctx"
@ -24,6 +26,13 @@ import (
"golang.org/x/net/http2" "golang.org/x/net/http2"
) )
// defines the maximum time an idle TCP session can survive in the tunnel, so
// it should be consistent across HTTP versions and with other transports.
const connIdleTimeout = 300 * time.Second
// consistent with quic-go
const h3KeepalivePeriod = 10 * time.Second
type dialerConf struct { type dialerConf struct {
net.Destination net.Destination
*internet.MemoryStreamConfig *internet.MemoryStreamConfig
@ -48,72 +57,129 @@ func getHTTPClient(ctx context.Context, dest net.Destination, streamSettings *in
if tlsConfigs == nil && realityConfigs == nil { if tlsConfigs == nil && realityConfigs == nil {
return nil, errors.New("TLS or REALITY must be enabled for http transport.").AtWarning() return nil, errors.New("TLS or REALITY must be enabled for http transport.").AtWarning()
} }
isH3 := tlsConfigs != nil && (len(tlsConfigs.NextProtocol) == 1 && tlsConfigs.NextProtocol[0] == "h3")
if isH3 {
dest.Network = net.Network_UDP
}
sockopt := streamSettings.SocketSettings sockopt := streamSettings.SocketSettings
if client, found := globalDialerMap[dialerConf{dest, streamSettings}]; found { if client, found := globalDialerMap[dialerConf{dest, streamSettings}]; found {
return client, nil return client, nil
} }
transport := &http2.Transport{ var transport http.RoundTripper
DialTLSContext: func(hctx context.Context, string, addr string, tlsConfig *gotls.Config) (net.Conn, error) { if isH3 {
rawHost, rawPort, err := net.SplitHostPort(addr) quicConfig := &quic.Config{
if err != nil { MaxIdleTimeout: connIdleTimeout,
return nil, err
}
if len(rawPort) == 0 {
rawPort = "443"
}
port, err := net.PortFromString(rawPort)
if err != nil {
return nil, err
}
address := net.ParseAddress(rawHost)
hctx = c.ContextWithID(hctx, c.IDFromContext(ctx)) // these two are defaults of quic-go/http3. the default of quic-go (no
hctx = session.ContextWithOutbounds(hctx, session.OutboundsFromContext(ctx)) // http3) is different, so it is hardcoded here for clarity.
hctx = session.ContextWithTimeoutOnly(hctx, true) // https://github.com/quic-go/quic-go/blob/b8ea5c798155950fb5bbfdd06cad1939c9355878/http3/client.go#L36-L39
MaxIncomingStreams: -1,
KeepAlivePeriod: h3KeepalivePeriod,
}
roundTripper := &http3.RoundTripper{
QUICConfig: quicConfig,
TLSClientConfig: tlsConfigs.GetTLSConfig(tls.WithDestination(dest)),
Dial: func(ctx context.Context, addr string, tlsCfg *gotls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
conn, err := internet.DialSystem(ctx, dest, streamSettings.SocketSettings)
if err != nil {
return nil, err
}
pconn, err := internet.DialSystem(hctx, net.TCPDestination(address, port), sockopt) var udpConn net.PacketConn
if err != nil { var udpAddr *net.UDPAddr
errors.LogErrorInner(ctx, err, "failed to dial to " + addr)
return nil, err
}
if realityConfigs != nil { switch c := conn.(type) {
return reality.UClient(pconn, realityConfigs, hctx, dest) case *internet.PacketConnWrapper:
} var ok bool
udpConn, ok = c.Conn.(*net.UDPConn)
if !ok {
return nil, errors.New("PacketConnWrapper does not contain a UDP connection")
}
udpAddr, err = net.ResolveUDPAddr("udp", c.Dest.String())
if err != nil {
return nil, err
}
case *net.UDPConn:
udpConn = c
udpAddr, err = net.ResolveUDPAddr("udp", c.RemoteAddr().String())
if err != nil {
return nil, err
}
default:
udpConn = &internet.FakePacketConn{c}
udpAddr, err = net.ResolveUDPAddr("udp", c.RemoteAddr().String())
if err != nil {
return nil, err
}
}
var cn tls.Interface return quic.DialEarly(ctx, udpConn, udpAddr, tlsCfg, cfg)
if fingerprint := tls.GetFingerprint(tlsConfigs.Fingerprint); fingerprint != nil { },
cn = tls.UClient(pconn, tlsConfig, fingerprint).(*tls.UConn) }
} else { transport = roundTripper
cn = tls.Client(pconn, tlsConfig).(*tls.Conn) } else {
} transportH2 := &http2.Transport{
if err := cn.HandshakeContext(ctx); err != nil { DialTLSContext: func(hctx context.Context, string, addr string, tlsConfig *gotls.Config) (net.Conn, error) {
errors.LogErrorInner(ctx, err, "failed to dial to " + addr) rawHost, rawPort, err := net.SplitHostPort(addr)
return nil, err if err != nil {
} return nil, err
if !tlsConfig.InsecureSkipVerify { }
if err := cn.VerifyHostname(tlsConfig.ServerName); err != nil { if len(rawPort) == 0 {
rawPort = "443"
}
port, err := net.PortFromString(rawPort)
if err != nil {
return nil, err
}
address := net.ParseAddress(rawHost)
hctx = c.ContextWithID(hctx, c.IDFromContext(ctx))
hctx = session.ContextWithOutbounds(hctx, session.OutboundsFromContext(ctx))
hctx = session.ContextWithTimeoutOnly(hctx, true)
pconn, err := internet.DialSystem(hctx, net.TCPDestination(address, port), sockopt)
if err != nil {
errors.LogErrorInner(ctx, err, "failed to dial to " + addr) errors.LogErrorInner(ctx, err, "failed to dial to " + addr)
return nil, err return nil, err
} }
}
negotiatedProtocol := cn.NegotiatedProtocol()
if negotiatedProtocol != http2.NextProtoTLS {
return nil, errors.New("http2: unexpected ALPN protocol " + negotiatedProtocol + "; want q" + http2.NextProtoTLS).AtError()
}
return cn, nil
},
}
if tlsConfigs != nil { if realityConfigs != nil {
transport.TLSClientConfig = tlsConfigs.GetTLSConfig(tls.WithDestination(dest)) return reality.UClient(pconn, realityConfigs, hctx, dest)
} }
if httpSettings.IdleTimeout > 0 || httpSettings.HealthCheckTimeout > 0 { var cn tls.Interface
transport.ReadIdleTimeout = time.Second * time.Duration(httpSettings.IdleTimeout) if fingerprint := tls.GetFingerprint(tlsConfigs.Fingerprint); fingerprint != nil {
transport.PingTimeout = time.Second * time.Duration(httpSettings.HealthCheckTimeout) cn = tls.UClient(pconn, tlsConfig, fingerprint).(*tls.UConn)
} else {
cn = tls.Client(pconn, tlsConfig).(*tls.Conn)
}
if err := cn.HandshakeContext(ctx); err != nil {
errors.LogErrorInner(ctx, err, "failed to dial to " + addr)
return nil, err
}
if !tlsConfig.InsecureSkipVerify {
if err := cn.VerifyHostname(tlsConfig.ServerName); err != nil {
errors.LogErrorInner(ctx, err, "failed to dial to " + addr)
return nil, err
}
}
negotiatedProtocol := cn.NegotiatedProtocol()
if negotiatedProtocol != http2.NextProtoTLS {
return nil, errors.New("http2: unexpected ALPN protocol " + negotiatedProtocol + "; want q" + http2.NextProtoTLS).AtError()
}
return cn, nil
},
}
if tlsConfigs != nil {
transportH2.TLSClientConfig = tlsConfigs.GetTLSConfig(tls.WithDestination(dest))
}
if httpSettings.IdleTimeout > 0 || httpSettings.HealthCheckTimeout > 0 {
transportH2.ReadIdleTimeout = time.Second * time.Duration(httpSettings.IdleTimeout)
transportH2.PingTimeout = time.Second * time.Duration(httpSettings.HealthCheckTimeout)
}
transport = transportH2
} }
client := &http.Client{ client := &http.Client{
@ -158,9 +224,6 @@ func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.Me
Host: dest.NetAddr(), Host: dest.NetAddr(),
Path: httpSettings.getNormalizedPath(), Path: httpSettings.getNormalizedPath(),
}, },
Proto: "HTTP/2",
ProtoMajor: 2,
ProtoMinor: 0,
Header: httpHeaders, Header: httpHeaders,
} }
// Disable any compression method from server. // Disable any compression method from server.

View File

@ -12,6 +12,7 @@ import (
"github.com/xtls/xray-core/common/net" "github.com/xtls/xray-core/common/net"
"github.com/xtls/xray-core/common/protocol/tls/cert" "github.com/xtls/xray-core/common/protocol/tls/cert"
"github.com/xtls/xray-core/testing/servers/tcp" "github.com/xtls/xray-core/testing/servers/tcp"
"github.com/xtls/xray-core/testing/servers/udp"
"github.com/xtls/xray-core/transport/internet" "github.com/xtls/xray-core/transport/internet"
. "github.com/xtls/xray-core/transport/internet/http" . "github.com/xtls/xray-core/transport/internet/http"
"github.com/xtls/xray-core/transport/internet/stat" "github.com/xtls/xray-core/transport/internet/stat"
@ -92,3 +93,80 @@ func TestHTTPConnection(t *testing.T) {
t.Error(r) t.Error(r)
} }
} }
func TestH3Connection(t *testing.T) {
port := udp.PickPort()
listener, err := Listen(context.Background(), net.LocalHostIP, port, &internet.MemoryStreamConfig{
ProtocolName: "http",
ProtocolSettings: &Config{},
SecurityType: "tls",
SecuritySettings: &tls.Config{
NextProtocol: []string{"h3"},
Certificate: []*tls.Certificate{tls.ParseCertificate(cert.MustGenerate(nil, cert.CommonName("www.example.com")))},
},
}, func(conn stat.Connection) {
go func() {
defer conn.Close()
b := buf.New()
defer b.Release()
for {
if _, err := b.ReadFrom(conn); err != nil {
return
}
_, err := conn.Write(b.Bytes())
common.Must(err)
}
}()
})
common.Must(err)
defer listener.Close()
time.Sleep(time.Second)
dctx := context.Background()
conn, err := Dial(dctx, net.TCPDestination(net.LocalHostIP, port), &internet.MemoryStreamConfig{
ProtocolName: "http",
ProtocolSettings: &Config{},
SecurityType: "tls",
SecuritySettings: &tls.Config{
NextProtocol: []string{"h3"},
ServerName: "www.example.com",
AllowInsecure: true,
},
})
common.Must(err)
defer conn.Close()
const N = 1024
b1 := make([]byte, N)
common.Must2(rand.Read(b1))
b2 := buf.New()
nBytes, err := conn.Write(b1)
common.Must(err)
if nBytes != N {
t.Error("write: ", nBytes)
}
b2.Clear()
common.Must2(b2.ReadFullFrom(conn, N))
if r := cmp.Diff(b2.Bytes(), b1); r != "" {
t.Error(r)
}
nBytes, err = conn.Write(b1)
common.Must(err)
if nBytes != N {
t.Error("write: ", nBytes)
}
b2.Clear()
common.Must2(b2.ReadFullFrom(conn, N))
if r := cmp.Diff(b2.Bytes(), b1); r != "" {
t.Error(r)
}
}

View File

@ -2,11 +2,14 @@ package http
import ( import (
"context" "context"
gotls "crypto/tls"
"io" "io"
"net/http" "net/http"
"strings" "strings"
"time" "time"
"github.com/quic-go/quic-go"
"github.com/quic-go/quic-go/http3"
goreality "github.com/xtls/reality" goreality "github.com/xtls/reality"
"github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common"
"github.com/xtls/xray-core/common/errors" "github.com/xtls/xray-core/common/errors"
@ -23,10 +26,12 @@ import (
) )
type Listener struct { type Listener struct {
server *http.Server server *http.Server
handler internet.ConnHandler h3server *http3.Server
local net.Addr handler internet.ConnHandler
config *Config local net.Addr
config *Config
isH3 bool
} }
func (l *Listener) Addr() net.Addr { func (l *Listener) Addr() net.Addr {
@ -34,7 +39,14 @@ func (l *Listener) Addr() net.Addr {
} }
func (l *Listener) Close() error { func (l *Listener) Close() error {
return l.server.Close() if l.h3server != nil {
if err := l.h3server.Close(); err != nil {
return err
}
} else if l.server != nil {
return l.server.Close()
}
return errors.New("listener does not have an HTTP/3 server or h2 server")
} }
type flushWriter struct { type flushWriter struct {
@ -119,43 +131,33 @@ func (l *Listener) ServeHTTP(writer http.ResponseWriter, request *http.Request)
func Listen(ctx context.Context, address net.Address, port net.Port, streamSettings *internet.MemoryStreamConfig, handler internet.ConnHandler) (internet.Listener, error) { func Listen(ctx context.Context, address net.Address, port net.Port, streamSettings *internet.MemoryStreamConfig, handler internet.ConnHandler) (internet.Listener, error) {
httpSettings := streamSettings.ProtocolSettings.(*Config) httpSettings := streamSettings.ProtocolSettings.(*Config)
var listener *Listener
if port == net.Port(0) { // unix
listener = &Listener{
handler: handler,
local: &net.UnixAddr{
Name: address.Domain(),
Net: "unix",
},
config: httpSettings,
}
} else { // tcp
listener = &Listener{
handler: handler,
local: &net.TCPAddr{
IP: address.IP(),
Port: int(port),
},
config: httpSettings,
}
}
var server *http.Server
config := tls.ConfigFromStreamSettings(streamSettings) config := tls.ConfigFromStreamSettings(streamSettings)
var tlsConfig *gotls.Config
if config == nil { if config == nil {
h2s := &http2.Server{} tlsConfig = &gotls.Config{}
} else {
server = &http.Server{ tlsConfig = config.GetTLSConfig()
Addr: serial.Concat(address, ":", port), }
Handler: h2c.NewHandler(listener, h2s), isH3 := len(tlsConfig.NextProtos) == 1 && tlsConfig.NextProtos[0] == "h3"
ReadHeaderTimeout: time.Second * 4, listener := &Listener{
handler: handler,
config: httpSettings,
isH3: isH3,
}
if port == net.Port(0) { // unix
listener.local = &net.UnixAddr{
Name: address.Domain(),
Net: "unix",
}
} else if isH3 { // udp
listener.local = &net.UDPAddr{
IP: address.IP(),
Port: int(port),
} }
} else { } else {
server = &http.Server{ listener.local = &net.TCPAddr{
Addr: serial.Concat(address, ":", port), IP: address.IP(),
TLSConfig: config.GetTLSConfig(tls.WithNextProto("h2")), Port: int(port),
Handler: listener,
ReadHeaderTimeout: time.Second * 4,
} }
} }
@ -163,45 +165,84 @@ func Listen(ctx context.Context, address net.Address, port net.Port, streamSetti
errors.LogWarning(ctx, "accepting PROXY protocol") errors.LogWarning(ctx, "accepting PROXY protocol")
} }
listener.server = server if isH3 {
go func() { Conn, err := internet.ListenSystemPacket(context.Background(), listener.local, streamSettings.SocketSettings)
var streamListener net.Listener if err != nil {
var err error return nil, errors.New("failed to listen UDP(for SH3) on ", address, ":", port).Base(err)
if port == net.Port(0) { // unix }
streamListener, err = internet.ListenSystem(ctx, &net.UnixAddr{ h3listener, err := quic.ListenEarly(Conn, tlsConfig, nil)
Name: address.Domain(), if err != nil {
Net: "unix", return nil, errors.New("failed to listen QUIC(for SH3) on ", address, ":", port).Base(err)
}, streamSettings.SocketSettings) }
if err != nil { errors.LogInfo(ctx, "listening QUIC(for SH3) on ", address, ":", port)
errors.LogErrorInner(ctx, err, "failed to listen on ", address)
return listener.h3server = &http3.Server{
Handler: listener,
}
go func() {
if err := listener.h3server.ServeListener(h3listener); err != nil {
errors.LogWarningInner(ctx, err, "failed to serve http3 for splithttp")
} }
} else { // tcp }()
streamListener, err = internet.ListenSystem(ctx, &net.TCPAddr{ } else {
IP: address.IP(), var server *http.Server
Port: int(port), if config == nil {
}, streamSettings.SocketSettings) h2s := &http2.Server{}
if err != nil {
errors.LogErrorInner(ctx, err, "failed to listen on ", address, ":", port) server = &http.Server{
return Addr: serial.Concat(address, ":", port),
Handler: h2c.NewHandler(listener, h2s),
ReadHeaderTimeout: time.Second * 4,
}
} else {
server = &http.Server{
Addr: serial.Concat(address, ":", port),
TLSConfig: config.GetTLSConfig(tls.WithNextProto("h2")),
Handler: listener,
ReadHeaderTimeout: time.Second * 4,
} }
} }
if config == nil { listener.server = server
if config := reality.ConfigFromStreamSettings(streamSettings); config != nil { go func() {
streamListener = goreality.NewListener(streamListener, config.GetREALITYConfig()) var streamListener net.Listener
var err error
if port == net.Port(0) { // unix
streamListener, err = internet.ListenSystem(ctx, &net.UnixAddr{
Name: address.Domain(),
Net: "unix",
}, streamSettings.SocketSettings)
if err != nil {
errors.LogErrorInner(ctx, err, "failed to listen on ", address)
return
}
} else { // tcp
streamListener, err = internet.ListenSystem(ctx, &net.TCPAddr{
IP: address.IP(),
Port: int(port),
}, streamSettings.SocketSettings)
if err != nil {
errors.LogErrorInner(ctx, err, "failed to listen on ", address, ":", port)
return
}
} }
err = server.Serve(streamListener)
if err != nil { if config == nil {
errors.LogInfoInner(ctx, err, "stopping serving H2C or REALITY H2") if config := reality.ConfigFromStreamSettings(streamSettings); config != nil {
streamListener = goreality.NewListener(streamListener, config.GetREALITYConfig())
}
err = server.Serve(streamListener)
if err != nil {
errors.LogInfoInner(ctx, err, "stopping serving H2C or REALITY H2")
}
} else {
err = server.ServeTLS(streamListener, "", "")
if err != nil {
errors.LogInfoInner(ctx, err, "stopping serving TLS H2")
}
} }
} else { }()
err = server.ServeTLS(streamListener, "", "") }
if err != nil {
errors.LogInfoInner(ctx, err, "stopping serving TLS H2")
}
}
}()
return listener, nil return listener, nil
} }

View File

@ -365,7 +365,13 @@ func ListenSH(ctx context.Context, address net.Address, port net.Port, streamSet
// Addr implements net.Listener.Addr(). // Addr implements net.Listener.Addr().
func (ln *Listener) Addr() net.Addr { func (ln *Listener) Addr() net.Addr {
return ln.listener.Addr() if ln.h3listener != nil {
return ln.h3listener.Addr()
}
if ln.listener != nil {
return ln.listener.Addr()
}
return nil
} }
// Close implements net.Listener.Close(). // Close implements net.Listener.Close().