XHTTP server: Set remoteAddr & localAddr correctly

Completes 22c50a70c6
This commit is contained in:
RPRX 2025-02-18 10:50:50 +00:00 committed by GitHub
parent eef74b2c7d
commit 8cb63db6c0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 77 additions and 71 deletions

View File

@ -76,8 +76,9 @@ type (
)
var (
ResolveUnixAddr = net.ResolveUnixAddr
ResolveTCPAddr = net.ResolveTCPAddr
ResolveUDPAddr = net.ResolveUDPAddr
ResolveUnixAddr = net.ResolveUnixAddr
)
type Resolver = net.Resolver

View File

@ -524,7 +524,7 @@ func XtlsFilterTls(buffer buf.MultiBuffer, trafficState *TrafficState, ctx conte
}
}
// UnwrapRawConn support unwrap stats, tls, utls, reality and proxyproto conn and get raw tcp conn from it
// UnwrapRawConn support unwrap stats, tls, utls, reality, proxyproto, uds-wrapper conn and get raw tcp/uds conn from it
func UnwrapRawConn(conn net.Conn) (net.Conn, stats.Counter, stats.Counter) {
var readCounter, writerCounter stats.Counter
if conn != nil {
@ -547,6 +547,9 @@ func UnwrapRawConn(conn net.Conn) (net.Conn, stats.Counter, stats.Counter) {
conn = pc.Raw()
// 8192 > 4096, there is no need to process pc's bufReader
}
if uc, ok := conn.(*internet.UDSWrapperConn); ok {
conn = uc.Conn
}
}
return conn, readCounter, writerCounter
}

View File

@ -3,9 +3,8 @@ package splithttp
import (
"bytes"
"context"
"crypto/tls"
gotls "crypto/tls"
"io"
gonet "net"
"net/http"
"net/url"
"strconv"
@ -24,7 +23,7 @@ import (
"github.com/xtls/xray-core/transport/internet"
"github.com/xtls/xray-core/transport/internet/reality"
"github.com/xtls/xray-core/transport/internet/stat"
v2tls "github.com/xtls/xray-core/transport/internet/tls"
"github.com/xtls/xray-core/transport/internet/tls"
"golang.org/x/net/http2"
"golang.org/x/net/http2/h2c"
)
@ -36,7 +35,7 @@ type requestHandler struct {
ln *Listener
sessionMu *sync.Mutex
sessions sync.Map
localAddr gonet.TCPAddr
localAddr net.Addr
}
type httpSession struct {
@ -144,14 +143,25 @@ func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Req
}
forwardedAddrs := http_proto.ParseXForwardedFor(request.Header)
remoteAddr, err := gonet.ResolveTCPAddr("tcp", request.RemoteAddr)
var remoteAddr net.Addr
var err error
remoteAddr, err = net.ResolveTCPAddr("tcp", request.RemoteAddr)
if err != nil {
remoteAddr = &gonet.TCPAddr{}
remoteAddr = &net.TCPAddr{
IP: []byte{0, 0, 0, 0},
Port: 0,
}
}
if request.ProtoMajor == 3 {
remoteAddr = &net.UDPAddr{
IP: remoteAddr.(*net.TCPAddr).IP,
Port: remoteAddr.(*net.TCPAddr).Port,
}
}
if len(forwardedAddrs) > 0 && forwardedAddrs[0].Family().IsIP() {
remoteAddr = &net.TCPAddr{
IP: forwardedAddrs[0].IP(),
Port: int(0),
Port: 0,
}
}
@ -289,6 +299,7 @@ func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Req
responseFlusher: responseFlusher,
},
reader: request.Body,
localAddr: h.localAddr,
remoteAddr: remoteAddr,
}
if sessionId != "" { // if not stream-one
@ -362,34 +373,30 @@ type Listener struct {
isH3 bool
}
func ListenSH(ctx context.Context, address net.Address, port net.Port, streamSettings *internet.MemoryStreamConfig, addConn internet.ConnHandler) (internet.Listener, error) {
func ListenXH(ctx context.Context, address net.Address, port net.Port, streamSettings *internet.MemoryStreamConfig, addConn internet.ConnHandler) (internet.Listener, error) {
l := &Listener{
addConn: addConn,
}
shSettings := streamSettings.ProtocolSettings.(*Config)
l.config = shSettings
l.config = streamSettings.ProtocolSettings.(*Config)
if l.config != nil {
if streamSettings.SocketSettings == nil {
streamSettings.SocketSettings = &internet.SocketConfig{}
}
}
var listener net.Listener
var err error
var localAddr = gonet.TCPAddr{}
handler := &requestHandler{
config: shSettings,
host: shSettings.Host,
path: shSettings.GetNormalizedPath(),
config: l.config,
host: l.config.Host,
path: l.config.GetNormalizedPath(),
ln: l,
sessionMu: &sync.Mutex{},
sessions: sync.Map{},
localAddr: localAddr,
}
tlsConfig := getTLSConfig(streamSettings)
l.isH3 = len(tlsConfig.NextProtos) == 1 && tlsConfig.NextProtos[0] == "h3"
var err error
if port == net.Port(0) { // unix
listener, err = internet.ListenSystem(ctx, &net.UnixAddr{
l.listener, err = internet.ListenSystem(ctx, &net.UnixAddr{
Name: address.Domain(),
Net: "unix",
}, streamSettings.SocketSettings)
@ -405,13 +412,14 @@ func ListenSH(ctx context.Context, address net.Address, port net.Port, streamSet
if err != nil {
return nil, errors.New("failed to listen UDP for XHTTP/3 on ", address, ":", port).Base(err)
}
h3listener, err := quic.ListenEarly(Conn, tlsConfig, nil)
l.h3listener, err = quic.ListenEarly(Conn, tlsConfig, nil)
if err != nil {
return nil, errors.New("failed to listen QUIC for XHTTP/3 on ", address, ":", port).Base(err)
}
l.h3listener = h3listener
errors.LogInfo(ctx, "listening QUIC for XHTTP/3 on ", address, ":", port)
handler.localAddr = l.h3listener.Addr()
l.h3server = &http3.Server{
Handler: handler,
}
@ -421,11 +429,7 @@ func ListenSH(ctx context.Context, address net.Address, port net.Port, streamSet
}
}()
} else { // tcp
localAddr = gonet.TCPAddr{
IP: address.IP(),
Port: int(port),
}
listener, err = internet.ListenSystem(ctx, &net.TCPAddr{
l.listener, err = internet.ListenSystem(ctx, &net.TCPAddr{
IP: address.IP(),
Port: int(port),
}, streamSettings.SocketSettings)
@ -436,26 +440,24 @@ func ListenSH(ctx context.Context, address net.Address, port net.Port, streamSet
}
// tcp/unix (h1/h2)
if listener != nil {
if config := v2tls.ConfigFromStreamSettings(streamSettings); config != nil {
if l.listener != nil {
if config := tls.ConfigFromStreamSettings(streamSettings); config != nil {
if tlsConfig := config.GetTLSConfig(); tlsConfig != nil {
listener = tls.NewListener(listener, tlsConfig)
l.listener = gotls.NewListener(l.listener, tlsConfig)
}
}
if config := reality.ConfigFromStreamSettings(streamSettings); config != nil {
l.listener = goreality.NewListener(l.listener, config.GetREALITYConfig())
}
if config := reality.ConfigFromStreamSettings(streamSettings); config != nil {
listener = goreality.NewListener(listener, config.GetREALITYConfig())
}
handler.localAddr = l.listener.Addr()
// h2cHandler can handle both plaintext HTTP/1.1 and h2c
h2cHandler := h2c.NewHandler(handler, &http2.Server{})
l.listener = listener
l.server = http.Server{
Handler: h2cHandler,
Handler: h2c.NewHandler(handler, &http2.Server{}),
ReadHeaderTimeout: time.Second * 4,
MaxHeaderBytes: 8192,
}
go func() {
if err := l.server.Serve(l.listener); err != nil {
errors.LogWarningInner(ctx, err, "failed to serve HTTP for XHTTP")
@ -488,13 +490,13 @@ func (ln *Listener) Close() error {
}
return errors.New("listener does not have an HTTP/3 server or a net.listener")
}
func getTLSConfig(streamSettings *internet.MemoryStreamConfig) *tls.Config {
config := v2tls.ConfigFromStreamSettings(streamSettings)
func getTLSConfig(streamSettings *internet.MemoryStreamConfig) *gotls.Config {
config := tls.ConfigFromStreamSettings(streamSettings)
if config == nil {
return &tls.Config{}
return &gotls.Config{}
}
return config.GetTLSConfig()
}
func init() {
common.Must(internet.RegisterTransportListener(protocolName, ListenSH))
common.Must(internet.RegisterTransportListener(protocolName, ListenXH))
}

View File

@ -26,9 +26,9 @@ import (
"golang.org/x/net/http2"
)
func Test_listenSHAndDial(t *testing.T) {
func Test_ListenXHAndDial(t *testing.T) {
listenPort := tcp.PickPort()
listen, err := ListenSH(context.Background(), net.LocalHostIP, listenPort, &internet.MemoryStreamConfig{
listen, err := ListenXH(context.Background(), net.LocalHostIP, listenPort, &internet.MemoryStreamConfig{
ProtocolName: "splithttp",
ProtocolSettings: &Config{
Path: "/sh",
@ -85,7 +85,7 @@ func Test_listenSHAndDial(t *testing.T) {
func TestDialWithRemoteAddr(t *testing.T) {
listenPort := tcp.PickPort()
listen, err := ListenSH(context.Background(), net.LocalHostIP, listenPort, &internet.MemoryStreamConfig{
listen, err := ListenXH(context.Background(), net.LocalHostIP, listenPort, &internet.MemoryStreamConfig{
ProtocolName: "splithttp",
ProtocolSettings: &Config{
Path: "sh",
@ -125,7 +125,7 @@ func TestDialWithRemoteAddr(t *testing.T) {
common.Must(listen.Close())
}
func Test_listenSHAndDial_TLS(t *testing.T) {
func Test_ListenXHAndDial_TLS(t *testing.T) {
if runtime.GOARCH == "arm64" {
return
}
@ -145,7 +145,7 @@ func Test_listenSHAndDial_TLS(t *testing.T) {
Certificate: []*tls.Certificate{tls.ParseCertificate(cert.MustGenerate(nil, cert.CommonName("localhost")))},
},
}
listen, err := ListenSH(context.Background(), net.LocalHostIP, listenPort, streamSettings, func(conn stat.Connection) {
listen, err := ListenXH(context.Background(), net.LocalHostIP, listenPort, streamSettings, func(conn stat.Connection) {
go func() {
defer conn.Close()
@ -180,7 +180,7 @@ func Test_listenSHAndDial_TLS(t *testing.T) {
}
}
func Test_listenSHAndDial_H2C(t *testing.T) {
func Test_ListenXHAndDial_H2C(t *testing.T) {
if runtime.GOARCH == "arm64" {
return
}
@ -193,7 +193,7 @@ func Test_listenSHAndDial_H2C(t *testing.T) {
Path: "shs",
},
}
listen, err := ListenSH(context.Background(), net.LocalHostIP, listenPort, streamSettings, func(conn stat.Connection) {
listen, err := ListenXH(context.Background(), net.LocalHostIP, listenPort, streamSettings, func(conn stat.Connection) {
go func() {
_ = conn.Close()
}()
@ -227,7 +227,7 @@ func Test_listenSHAndDial_H2C(t *testing.T) {
}
}
func Test_listenSHAndDial_QUIC(t *testing.T) {
func Test_ListenXHAndDial_QUIC(t *testing.T) {
if runtime.GOARCH == "arm64" {
return
}
@ -250,7 +250,7 @@ func Test_listenSHAndDial_QUIC(t *testing.T) {
}
serverClosed := false
listen, err := ListenSH(context.Background(), net.LocalHostIP, listenPort, streamSettings, func(conn stat.Connection) {
listen, err := ListenXH(context.Background(), net.LocalHostIP, listenPort, streamSettings, func(conn stat.Connection) {
go func() {
defer conn.Close()
@ -309,11 +309,11 @@ func Test_listenSHAndDial_QUIC(t *testing.T) {
}
}
func Test_listenSHAndDial_Unix(t *testing.T) {
func Test_ListenXHAndDial_Unix(t *testing.T) {
tempDir := t.TempDir()
tempSocket := tempDir + "/server.sock"
listen, err := ListenSH(context.Background(), net.DomainAddress(tempSocket), 0, &internet.MemoryStreamConfig{
listen, err := ListenXH(context.Background(), net.DomainAddress(tempSocket), 0, &internet.MemoryStreamConfig{
ProtocolName: "splithttp",
ProtocolSettings: &Config{
Path: "/sh",
@ -373,7 +373,7 @@ func Test_listenSHAndDial_Unix(t *testing.T) {
func Test_queryString(t *testing.T) {
listenPort := tcp.PickPort()
listen, err := ListenSH(context.Background(), net.LocalHostIP, listenPort, &internet.MemoryStreamConfig{
listen, err := ListenXH(context.Background(), net.LocalHostIP, listenPort, &internet.MemoryStreamConfig{
ProtocolName: "splithttp",
ProtocolSettings: &Config{
// this querystring does not have any effect, but sometimes people blindly copy it from websocket config. make sure the outbound doesn't break
@ -431,7 +431,7 @@ func Test_maxUpload(t *testing.T) {
}
var uploadSize int
listen, err := ListenSH(context.Background(), net.LocalHostIP, listenPort, streamSettings, func(conn stat.Connection) {
listen, err := ListenXH(context.Background(), net.LocalHostIP, listenPort, streamSettings, func(conn stat.Connection) {
go func(c stat.Connection) {
defer c.Close()
var b [10240]byte

View File

@ -54,7 +54,7 @@ func (l *listenUDSWrapper) Accept() (net.Conn, error) {
if err != nil {
return nil, err
}
return &listenUDSWrapperConn{Conn: conn}, nil
return &UDSWrapperConn{Conn: conn}, nil
}
func (l *listenUDSWrapper) Close() error {
@ -65,11 +65,11 @@ func (l *listenUDSWrapper) Close() error {
return l.Listener.Close()
}
type listenUDSWrapperConn struct {
type UDSWrapperConn struct {
net.Conn
}
func (conn *listenUDSWrapperConn) RemoteAddr() net.Addr {
func (conn *UDSWrapperConn) RemoteAddr() net.Addr {
return &net.TCPAddr{
IP: []byte{0, 0, 0, 0},
}