XHTTP server: Set remoteAddr & localAddr correctly

Completes 22c50a70c6
This commit is contained in:
RPRX 2025-02-18 10:50:50 +00:00 committed by GitHub
parent eef74b2c7d
commit 8cb63db6c0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 77 additions and 71 deletions

View File

@ -76,8 +76,9 @@ type (
) )
var ( var (
ResolveUnixAddr = net.ResolveUnixAddr ResolveTCPAddr = net.ResolveTCPAddr
ResolveUDPAddr = net.ResolveUDPAddr ResolveUDPAddr = net.ResolveUDPAddr
ResolveUnixAddr = net.ResolveUnixAddr
) )
type Resolver = net.Resolver type Resolver = net.Resolver

View File

@ -113,12 +113,12 @@ type TrafficState struct {
type InboundState struct { type InboundState struct {
// reader link state // reader link state
WithinPaddingBuffers bool WithinPaddingBuffers bool
UplinkReaderDirectCopy bool UplinkReaderDirectCopy bool
RemainingCommand int32 RemainingCommand int32
RemainingContent int32 RemainingContent int32
RemainingPadding int32 RemainingPadding int32
CurrentCommand int CurrentCommand int
// write link state // write link state
IsPadding bool IsPadding bool
DownlinkWriterDirectCopy bool DownlinkWriterDirectCopy bool
@ -133,19 +133,19 @@ type OutboundState struct {
RemainingPadding int32 RemainingPadding int32
CurrentCommand int CurrentCommand int
// write link state // write link state
IsPadding bool IsPadding bool
UplinkWriterDirectCopy bool UplinkWriterDirectCopy bool
} }
func NewTrafficState(userUUID []byte) *TrafficState { func NewTrafficState(userUUID []byte) *TrafficState {
return &TrafficState{ return &TrafficState{
UserUUID: userUUID, UserUUID: userUUID,
NumberOfPacketToFilter: 8, NumberOfPacketToFilter: 8,
EnableXtls: false, EnableXtls: false,
IsTLS12orAbove: false, IsTLS12orAbove: false,
IsTLS: false, IsTLS: false,
Cipher: 0, Cipher: 0,
RemainingServerHello: -1, RemainingServerHello: -1,
Inbound: InboundState{ Inbound: InboundState{
WithinPaddingBuffers: true, WithinPaddingBuffers: true,
UplinkReaderDirectCopy: false, UplinkReaderDirectCopy: false,
@ -524,7 +524,7 @@ func XtlsFilterTls(buffer buf.MultiBuffer, trafficState *TrafficState, ctx conte
} }
} }
// UnwrapRawConn support unwrap stats, tls, utls, reality and proxyproto conn and get raw tcp conn from it // UnwrapRawConn support unwrap stats, tls, utls, reality, proxyproto, uds-wrapper conn and get raw tcp/uds conn from it
func UnwrapRawConn(conn net.Conn) (net.Conn, stats.Counter, stats.Counter) { func UnwrapRawConn(conn net.Conn) (net.Conn, stats.Counter, stats.Counter) {
var readCounter, writerCounter stats.Counter var readCounter, writerCounter stats.Counter
if conn != nil { if conn != nil {
@ -547,6 +547,9 @@ func UnwrapRawConn(conn net.Conn) (net.Conn, stats.Counter, stats.Counter) {
conn = pc.Raw() conn = pc.Raw()
// 8192 > 4096, there is no need to process pc's bufReader // 8192 > 4096, there is no need to process pc's bufReader
} }
if uc, ok := conn.(*internet.UDSWrapperConn); ok {
conn = uc.Conn
}
} }
return conn, readCounter, writerCounter return conn, readCounter, writerCounter
} }

View File

@ -3,9 +3,8 @@ package splithttp
import ( import (
"bytes" "bytes"
"context" "context"
"crypto/tls" gotls "crypto/tls"
"io" "io"
gonet "net"
"net/http" "net/http"
"net/url" "net/url"
"strconv" "strconv"
@ -24,7 +23,7 @@ import (
"github.com/xtls/xray-core/transport/internet" "github.com/xtls/xray-core/transport/internet"
"github.com/xtls/xray-core/transport/internet/reality" "github.com/xtls/xray-core/transport/internet/reality"
"github.com/xtls/xray-core/transport/internet/stat" "github.com/xtls/xray-core/transport/internet/stat"
v2tls "github.com/xtls/xray-core/transport/internet/tls" "github.com/xtls/xray-core/transport/internet/tls"
"golang.org/x/net/http2" "golang.org/x/net/http2"
"golang.org/x/net/http2/h2c" "golang.org/x/net/http2/h2c"
) )
@ -36,7 +35,7 @@ type requestHandler struct {
ln *Listener ln *Listener
sessionMu *sync.Mutex sessionMu *sync.Mutex
sessions sync.Map sessions sync.Map
localAddr gonet.TCPAddr localAddr net.Addr
} }
type httpSession struct { type httpSession struct {
@ -144,14 +143,25 @@ func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Req
} }
forwardedAddrs := http_proto.ParseXForwardedFor(request.Header) forwardedAddrs := http_proto.ParseXForwardedFor(request.Header)
remoteAddr, err := gonet.ResolveTCPAddr("tcp", request.RemoteAddr) var remoteAddr net.Addr
var err error
remoteAddr, err = net.ResolveTCPAddr("tcp", request.RemoteAddr)
if err != nil { if err != nil {
remoteAddr = &gonet.TCPAddr{} remoteAddr = &net.TCPAddr{
IP: []byte{0, 0, 0, 0},
Port: 0,
}
}
if request.ProtoMajor == 3 {
remoteAddr = &net.UDPAddr{
IP: remoteAddr.(*net.TCPAddr).IP,
Port: remoteAddr.(*net.TCPAddr).Port,
}
} }
if len(forwardedAddrs) > 0 && forwardedAddrs[0].Family().IsIP() { if len(forwardedAddrs) > 0 && forwardedAddrs[0].Family().IsIP() {
remoteAddr = &net.TCPAddr{ remoteAddr = &net.TCPAddr{
IP: forwardedAddrs[0].IP(), IP: forwardedAddrs[0].IP(),
Port: int(0), Port: 0,
} }
} }
@ -289,6 +299,7 @@ func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Req
responseFlusher: responseFlusher, responseFlusher: responseFlusher,
}, },
reader: request.Body, reader: request.Body,
localAddr: h.localAddr,
remoteAddr: remoteAddr, remoteAddr: remoteAddr,
} }
if sessionId != "" { // if not stream-one if sessionId != "" { // if not stream-one
@ -362,34 +373,30 @@ type Listener struct {
isH3 bool isH3 bool
} }
func ListenSH(ctx context.Context, address net.Address, port net.Port, streamSettings *internet.MemoryStreamConfig, addConn internet.ConnHandler) (internet.Listener, error) { func ListenXH(ctx context.Context, address net.Address, port net.Port, streamSettings *internet.MemoryStreamConfig, addConn internet.ConnHandler) (internet.Listener, error) {
l := &Listener{ l := &Listener{
addConn: addConn, addConn: addConn,
} }
shSettings := streamSettings.ProtocolSettings.(*Config) l.config = streamSettings.ProtocolSettings.(*Config)
l.config = shSettings
if l.config != nil { if l.config != nil {
if streamSettings.SocketSettings == nil { if streamSettings.SocketSettings == nil {
streamSettings.SocketSettings = &internet.SocketConfig{} streamSettings.SocketSettings = &internet.SocketConfig{}
} }
} }
var listener net.Listener
var err error
var localAddr = gonet.TCPAddr{}
handler := &requestHandler{ handler := &requestHandler{
config: shSettings, config: l.config,
host: shSettings.Host, host: l.config.Host,
path: shSettings.GetNormalizedPath(), path: l.config.GetNormalizedPath(),
ln: l, ln: l,
sessionMu: &sync.Mutex{}, sessionMu: &sync.Mutex{},
sessions: sync.Map{}, sessions: sync.Map{},
localAddr: localAddr,
} }
tlsConfig := getTLSConfig(streamSettings) tlsConfig := getTLSConfig(streamSettings)
l.isH3 = len(tlsConfig.NextProtos) == 1 && tlsConfig.NextProtos[0] == "h3" l.isH3 = len(tlsConfig.NextProtos) == 1 && tlsConfig.NextProtos[0] == "h3"
var err error
if port == net.Port(0) { // unix if port == net.Port(0) { // unix
listener, err = internet.ListenSystem(ctx, &net.UnixAddr{ l.listener, err = internet.ListenSystem(ctx, &net.UnixAddr{
Name: address.Domain(), Name: address.Domain(),
Net: "unix", Net: "unix",
}, streamSettings.SocketSettings) }, streamSettings.SocketSettings)
@ -405,13 +412,14 @@ func ListenSH(ctx context.Context, address net.Address, port net.Port, streamSet
if err != nil { if err != nil {
return nil, errors.New("failed to listen UDP for XHTTP/3 on ", address, ":", port).Base(err) return nil, errors.New("failed to listen UDP for XHTTP/3 on ", address, ":", port).Base(err)
} }
h3listener, err := quic.ListenEarly(Conn, tlsConfig, nil) l.h3listener, err = quic.ListenEarly(Conn, tlsConfig, nil)
if err != nil { if err != nil {
return nil, errors.New("failed to listen QUIC for XHTTP/3 on ", address, ":", port).Base(err) return nil, errors.New("failed to listen QUIC for XHTTP/3 on ", address, ":", port).Base(err)
} }
l.h3listener = h3listener
errors.LogInfo(ctx, "listening QUIC for XHTTP/3 on ", address, ":", port) errors.LogInfo(ctx, "listening QUIC for XHTTP/3 on ", address, ":", port)
handler.localAddr = l.h3listener.Addr()
l.h3server = &http3.Server{ l.h3server = &http3.Server{
Handler: handler, Handler: handler,
} }
@ -421,11 +429,7 @@ func ListenSH(ctx context.Context, address net.Address, port net.Port, streamSet
} }
}() }()
} else { // tcp } else { // tcp
localAddr = gonet.TCPAddr{ l.listener, err = internet.ListenSystem(ctx, &net.TCPAddr{
IP: address.IP(),
Port: int(port),
}
listener, err = internet.ListenSystem(ctx, &net.TCPAddr{
IP: address.IP(), IP: address.IP(),
Port: int(port), Port: int(port),
}, streamSettings.SocketSettings) }, streamSettings.SocketSettings)
@ -436,26 +440,24 @@ func ListenSH(ctx context.Context, address net.Address, port net.Port, streamSet
} }
// tcp/unix (h1/h2) // tcp/unix (h1/h2)
if listener != nil { if l.listener != nil {
if config := v2tls.ConfigFromStreamSettings(streamSettings); config != nil { if config := tls.ConfigFromStreamSettings(streamSettings); config != nil {
if tlsConfig := config.GetTLSConfig(); tlsConfig != nil { if tlsConfig := config.GetTLSConfig(); tlsConfig != nil {
listener = tls.NewListener(listener, tlsConfig) l.listener = gotls.NewListener(l.listener, tlsConfig)
} }
} }
if config := reality.ConfigFromStreamSettings(streamSettings); config != nil { if config := reality.ConfigFromStreamSettings(streamSettings); config != nil {
listener = goreality.NewListener(listener, config.GetREALITYConfig()) l.listener = goreality.NewListener(l.listener, config.GetREALITYConfig())
} }
handler.localAddr = l.listener.Addr()
// h2cHandler can handle both plaintext HTTP/1.1 and h2c // h2cHandler can handle both plaintext HTTP/1.1 and h2c
h2cHandler := h2c.NewHandler(handler, &http2.Server{})
l.listener = listener
l.server = http.Server{ l.server = http.Server{
Handler: h2cHandler, Handler: h2c.NewHandler(handler, &http2.Server{}),
ReadHeaderTimeout: time.Second * 4, ReadHeaderTimeout: time.Second * 4,
MaxHeaderBytes: 8192, MaxHeaderBytes: 8192,
} }
go func() { go func() {
if err := l.server.Serve(l.listener); err != nil { if err := l.server.Serve(l.listener); err != nil {
errors.LogWarningInner(ctx, err, "failed to serve HTTP for XHTTP") errors.LogWarningInner(ctx, err, "failed to serve HTTP for XHTTP")
@ -488,13 +490,13 @@ func (ln *Listener) Close() error {
} }
return errors.New("listener does not have an HTTP/3 server or a net.listener") return errors.New("listener does not have an HTTP/3 server or a net.listener")
} }
func getTLSConfig(streamSettings *internet.MemoryStreamConfig) *tls.Config { func getTLSConfig(streamSettings *internet.MemoryStreamConfig) *gotls.Config {
config := v2tls.ConfigFromStreamSettings(streamSettings) config := tls.ConfigFromStreamSettings(streamSettings)
if config == nil { if config == nil {
return &tls.Config{} return &gotls.Config{}
} }
return config.GetTLSConfig() return config.GetTLSConfig()
} }
func init() { func init() {
common.Must(internet.RegisterTransportListener(protocolName, ListenSH)) common.Must(internet.RegisterTransportListener(protocolName, ListenXH))
} }

View File

@ -26,9 +26,9 @@ import (
"golang.org/x/net/http2" "golang.org/x/net/http2"
) )
func Test_listenSHAndDial(t *testing.T) { func Test_ListenXHAndDial(t *testing.T) {
listenPort := tcp.PickPort() listenPort := tcp.PickPort()
listen, err := ListenSH(context.Background(), net.LocalHostIP, listenPort, &internet.MemoryStreamConfig{ listen, err := ListenXH(context.Background(), net.LocalHostIP, listenPort, &internet.MemoryStreamConfig{
ProtocolName: "splithttp", ProtocolName: "splithttp",
ProtocolSettings: &Config{ ProtocolSettings: &Config{
Path: "/sh", Path: "/sh",
@ -85,7 +85,7 @@ func Test_listenSHAndDial(t *testing.T) {
func TestDialWithRemoteAddr(t *testing.T) { func TestDialWithRemoteAddr(t *testing.T) {
listenPort := tcp.PickPort() listenPort := tcp.PickPort()
listen, err := ListenSH(context.Background(), net.LocalHostIP, listenPort, &internet.MemoryStreamConfig{ listen, err := ListenXH(context.Background(), net.LocalHostIP, listenPort, &internet.MemoryStreamConfig{
ProtocolName: "splithttp", ProtocolName: "splithttp",
ProtocolSettings: &Config{ ProtocolSettings: &Config{
Path: "sh", Path: "sh",
@ -125,7 +125,7 @@ func TestDialWithRemoteAddr(t *testing.T) {
common.Must(listen.Close()) common.Must(listen.Close())
} }
func Test_listenSHAndDial_TLS(t *testing.T) { func Test_ListenXHAndDial_TLS(t *testing.T) {
if runtime.GOARCH == "arm64" { if runtime.GOARCH == "arm64" {
return return
} }
@ -145,7 +145,7 @@ func Test_listenSHAndDial_TLS(t *testing.T) {
Certificate: []*tls.Certificate{tls.ParseCertificate(cert.MustGenerate(nil, cert.CommonName("localhost")))}, Certificate: []*tls.Certificate{tls.ParseCertificate(cert.MustGenerate(nil, cert.CommonName("localhost")))},
}, },
} }
listen, err := ListenSH(context.Background(), net.LocalHostIP, listenPort, streamSettings, func(conn stat.Connection) { listen, err := ListenXH(context.Background(), net.LocalHostIP, listenPort, streamSettings, func(conn stat.Connection) {
go func() { go func() {
defer conn.Close() defer conn.Close()
@ -180,7 +180,7 @@ func Test_listenSHAndDial_TLS(t *testing.T) {
} }
} }
func Test_listenSHAndDial_H2C(t *testing.T) { func Test_ListenXHAndDial_H2C(t *testing.T) {
if runtime.GOARCH == "arm64" { if runtime.GOARCH == "arm64" {
return return
} }
@ -193,7 +193,7 @@ func Test_listenSHAndDial_H2C(t *testing.T) {
Path: "shs", Path: "shs",
}, },
} }
listen, err := ListenSH(context.Background(), net.LocalHostIP, listenPort, streamSettings, func(conn stat.Connection) { listen, err := ListenXH(context.Background(), net.LocalHostIP, listenPort, streamSettings, func(conn stat.Connection) {
go func() { go func() {
_ = conn.Close() _ = conn.Close()
}() }()
@ -227,7 +227,7 @@ func Test_listenSHAndDial_H2C(t *testing.T) {
} }
} }
func Test_listenSHAndDial_QUIC(t *testing.T) { func Test_ListenXHAndDial_QUIC(t *testing.T) {
if runtime.GOARCH == "arm64" { if runtime.GOARCH == "arm64" {
return return
} }
@ -250,7 +250,7 @@ func Test_listenSHAndDial_QUIC(t *testing.T) {
} }
serverClosed := false serverClosed := false
listen, err := ListenSH(context.Background(), net.LocalHostIP, listenPort, streamSettings, func(conn stat.Connection) { listen, err := ListenXH(context.Background(), net.LocalHostIP, listenPort, streamSettings, func(conn stat.Connection) {
go func() { go func() {
defer conn.Close() defer conn.Close()
@ -309,11 +309,11 @@ func Test_listenSHAndDial_QUIC(t *testing.T) {
} }
} }
func Test_listenSHAndDial_Unix(t *testing.T) { func Test_ListenXHAndDial_Unix(t *testing.T) {
tempDir := t.TempDir() tempDir := t.TempDir()
tempSocket := tempDir + "/server.sock" tempSocket := tempDir + "/server.sock"
listen, err := ListenSH(context.Background(), net.DomainAddress(tempSocket), 0, &internet.MemoryStreamConfig{ listen, err := ListenXH(context.Background(), net.DomainAddress(tempSocket), 0, &internet.MemoryStreamConfig{
ProtocolName: "splithttp", ProtocolName: "splithttp",
ProtocolSettings: &Config{ ProtocolSettings: &Config{
Path: "/sh", Path: "/sh",
@ -373,7 +373,7 @@ func Test_listenSHAndDial_Unix(t *testing.T) {
func Test_queryString(t *testing.T) { func Test_queryString(t *testing.T) {
listenPort := tcp.PickPort() listenPort := tcp.PickPort()
listen, err := ListenSH(context.Background(), net.LocalHostIP, listenPort, &internet.MemoryStreamConfig{ listen, err := ListenXH(context.Background(), net.LocalHostIP, listenPort, &internet.MemoryStreamConfig{
ProtocolName: "splithttp", ProtocolName: "splithttp",
ProtocolSettings: &Config{ ProtocolSettings: &Config{
// this querystring does not have any effect, but sometimes people blindly copy it from websocket config. make sure the outbound doesn't break // this querystring does not have any effect, but sometimes people blindly copy it from websocket config. make sure the outbound doesn't break
@ -431,7 +431,7 @@ func Test_maxUpload(t *testing.T) {
} }
var uploadSize int var uploadSize int
listen, err := ListenSH(context.Background(), net.LocalHostIP, listenPort, streamSettings, func(conn stat.Connection) { listen, err := ListenXH(context.Background(), net.LocalHostIP, listenPort, streamSettings, func(conn stat.Connection) {
go func(c stat.Connection) { go func(c stat.Connection) {
defer c.Close() defer c.Close()
var b [10240]byte var b [10240]byte

View File

@ -54,7 +54,7 @@ func (l *listenUDSWrapper) Accept() (net.Conn, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &listenUDSWrapperConn{Conn: conn}, nil return &UDSWrapperConn{Conn: conn}, nil
} }
func (l *listenUDSWrapper) Close() error { func (l *listenUDSWrapper) Close() error {
@ -65,11 +65,11 @@ func (l *listenUDSWrapper) Close() error {
return l.Listener.Close() return l.Listener.Close()
} }
type listenUDSWrapperConn struct { type UDSWrapperConn struct {
net.Conn net.Conn
} }
func (conn *listenUDSWrapperConn) RemoteAddr() net.Addr { func (conn *UDSWrapperConn) RemoteAddr() net.Addr {
return &net.TCPAddr{ return &net.TCPAddr{
IP: []byte{0, 0, 0, 0}, IP: []byte{0, 0, 0, 0},
} }