mirror of
https://github.com/XTLS/Xray-core.git
synced 2025-02-22 08:40:44 +00:00
XHTTP server: Set remoteAddr & localAddr correctly
Completes 22c50a70c6
This commit is contained in:
parent
eef74b2c7d
commit
8cb63db6c0
@ -76,8 +76,9 @@ type (
|
||||
)
|
||||
|
||||
var (
|
||||
ResolveUnixAddr = net.ResolveUnixAddr
|
||||
ResolveTCPAddr = net.ResolveTCPAddr
|
||||
ResolveUDPAddr = net.ResolveUDPAddr
|
||||
ResolveUnixAddr = net.ResolveUnixAddr
|
||||
)
|
||||
|
||||
type Resolver = net.Resolver
|
||||
|
@ -113,12 +113,12 @@ type TrafficState struct {
|
||||
|
||||
type InboundState struct {
|
||||
// reader link state
|
||||
WithinPaddingBuffers bool
|
||||
UplinkReaderDirectCopy bool
|
||||
RemainingCommand int32
|
||||
RemainingContent int32
|
||||
RemainingPadding int32
|
||||
CurrentCommand int
|
||||
WithinPaddingBuffers bool
|
||||
UplinkReaderDirectCopy bool
|
||||
RemainingCommand int32
|
||||
RemainingContent int32
|
||||
RemainingPadding int32
|
||||
CurrentCommand int
|
||||
// write link state
|
||||
IsPadding bool
|
||||
DownlinkWriterDirectCopy bool
|
||||
@ -133,19 +133,19 @@ type OutboundState struct {
|
||||
RemainingPadding int32
|
||||
CurrentCommand int
|
||||
// write link state
|
||||
IsPadding bool
|
||||
UplinkWriterDirectCopy bool
|
||||
IsPadding bool
|
||||
UplinkWriterDirectCopy bool
|
||||
}
|
||||
|
||||
func NewTrafficState(userUUID []byte) *TrafficState {
|
||||
return &TrafficState{
|
||||
UserUUID: userUUID,
|
||||
NumberOfPacketToFilter: 8,
|
||||
EnableXtls: false,
|
||||
IsTLS12orAbove: false,
|
||||
IsTLS: false,
|
||||
Cipher: 0,
|
||||
RemainingServerHello: -1,
|
||||
UserUUID: userUUID,
|
||||
NumberOfPacketToFilter: 8,
|
||||
EnableXtls: false,
|
||||
IsTLS12orAbove: false,
|
||||
IsTLS: false,
|
||||
Cipher: 0,
|
||||
RemainingServerHello: -1,
|
||||
Inbound: InboundState{
|
||||
WithinPaddingBuffers: true,
|
||||
UplinkReaderDirectCopy: false,
|
||||
@ -524,7 +524,7 @@ func XtlsFilterTls(buffer buf.MultiBuffer, trafficState *TrafficState, ctx conte
|
||||
}
|
||||
}
|
||||
|
||||
// UnwrapRawConn support unwrap stats, tls, utls, reality and proxyproto conn and get raw tcp conn from it
|
||||
// UnwrapRawConn support unwrap stats, tls, utls, reality, proxyproto, uds-wrapper conn and get raw tcp/uds conn from it
|
||||
func UnwrapRawConn(conn net.Conn) (net.Conn, stats.Counter, stats.Counter) {
|
||||
var readCounter, writerCounter stats.Counter
|
||||
if conn != nil {
|
||||
@ -547,6 +547,9 @@ func UnwrapRawConn(conn net.Conn) (net.Conn, stats.Counter, stats.Counter) {
|
||||
conn = pc.Raw()
|
||||
// 8192 > 4096, there is no need to process pc's bufReader
|
||||
}
|
||||
if uc, ok := conn.(*internet.UDSWrapperConn); ok {
|
||||
conn = uc.Conn
|
||||
}
|
||||
}
|
||||
return conn, readCounter, writerCounter
|
||||
}
|
||||
|
@ -3,9 +3,8 @@ package splithttp
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
gotls "crypto/tls"
|
||||
"io"
|
||||
gonet "net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
@ -24,7 +23,7 @@ import (
|
||||
"github.com/xtls/xray-core/transport/internet"
|
||||
"github.com/xtls/xray-core/transport/internet/reality"
|
||||
"github.com/xtls/xray-core/transport/internet/stat"
|
||||
v2tls "github.com/xtls/xray-core/transport/internet/tls"
|
||||
"github.com/xtls/xray-core/transport/internet/tls"
|
||||
"golang.org/x/net/http2"
|
||||
"golang.org/x/net/http2/h2c"
|
||||
)
|
||||
@ -36,7 +35,7 @@ type requestHandler struct {
|
||||
ln *Listener
|
||||
sessionMu *sync.Mutex
|
||||
sessions sync.Map
|
||||
localAddr gonet.TCPAddr
|
||||
localAddr net.Addr
|
||||
}
|
||||
|
||||
type httpSession struct {
|
||||
@ -144,14 +143,25 @@ func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Req
|
||||
}
|
||||
|
||||
forwardedAddrs := http_proto.ParseXForwardedFor(request.Header)
|
||||
remoteAddr, err := gonet.ResolveTCPAddr("tcp", request.RemoteAddr)
|
||||
var remoteAddr net.Addr
|
||||
var err error
|
||||
remoteAddr, err = net.ResolveTCPAddr("tcp", request.RemoteAddr)
|
||||
if err != nil {
|
||||
remoteAddr = &gonet.TCPAddr{}
|
||||
remoteAddr = &net.TCPAddr{
|
||||
IP: []byte{0, 0, 0, 0},
|
||||
Port: 0,
|
||||
}
|
||||
}
|
||||
if request.ProtoMajor == 3 {
|
||||
remoteAddr = &net.UDPAddr{
|
||||
IP: remoteAddr.(*net.TCPAddr).IP,
|
||||
Port: remoteAddr.(*net.TCPAddr).Port,
|
||||
}
|
||||
}
|
||||
if len(forwardedAddrs) > 0 && forwardedAddrs[0].Family().IsIP() {
|
||||
remoteAddr = &net.TCPAddr{
|
||||
IP: forwardedAddrs[0].IP(),
|
||||
Port: int(0),
|
||||
Port: 0,
|
||||
}
|
||||
}
|
||||
|
||||
@ -289,6 +299,7 @@ func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Req
|
||||
responseFlusher: responseFlusher,
|
||||
},
|
||||
reader: request.Body,
|
||||
localAddr: h.localAddr,
|
||||
remoteAddr: remoteAddr,
|
||||
}
|
||||
if sessionId != "" { // if not stream-one
|
||||
@ -362,34 +373,30 @@ type Listener struct {
|
||||
isH3 bool
|
||||
}
|
||||
|
||||
func ListenSH(ctx context.Context, address net.Address, port net.Port, streamSettings *internet.MemoryStreamConfig, addConn internet.ConnHandler) (internet.Listener, error) {
|
||||
func ListenXH(ctx context.Context, address net.Address, port net.Port, streamSettings *internet.MemoryStreamConfig, addConn internet.ConnHandler) (internet.Listener, error) {
|
||||
l := &Listener{
|
||||
addConn: addConn,
|
||||
}
|
||||
shSettings := streamSettings.ProtocolSettings.(*Config)
|
||||
l.config = shSettings
|
||||
l.config = streamSettings.ProtocolSettings.(*Config)
|
||||
if l.config != nil {
|
||||
if streamSettings.SocketSettings == nil {
|
||||
streamSettings.SocketSettings = &internet.SocketConfig{}
|
||||
}
|
||||
}
|
||||
var listener net.Listener
|
||||
var err error
|
||||
var localAddr = gonet.TCPAddr{}
|
||||
handler := &requestHandler{
|
||||
config: shSettings,
|
||||
host: shSettings.Host,
|
||||
path: shSettings.GetNormalizedPath(),
|
||||
config: l.config,
|
||||
host: l.config.Host,
|
||||
path: l.config.GetNormalizedPath(),
|
||||
ln: l,
|
||||
sessionMu: &sync.Mutex{},
|
||||
sessions: sync.Map{},
|
||||
localAddr: localAddr,
|
||||
}
|
||||
tlsConfig := getTLSConfig(streamSettings)
|
||||
l.isH3 = len(tlsConfig.NextProtos) == 1 && tlsConfig.NextProtos[0] == "h3"
|
||||
|
||||
var err error
|
||||
if port == net.Port(0) { // unix
|
||||
listener, err = internet.ListenSystem(ctx, &net.UnixAddr{
|
||||
l.listener, err = internet.ListenSystem(ctx, &net.UnixAddr{
|
||||
Name: address.Domain(),
|
||||
Net: "unix",
|
||||
}, streamSettings.SocketSettings)
|
||||
@ -405,13 +412,14 @@ func ListenSH(ctx context.Context, address net.Address, port net.Port, streamSet
|
||||
if err != nil {
|
||||
return nil, errors.New("failed to listen UDP for XHTTP/3 on ", address, ":", port).Base(err)
|
||||
}
|
||||
h3listener, err := quic.ListenEarly(Conn, tlsConfig, nil)
|
||||
l.h3listener, err = quic.ListenEarly(Conn, tlsConfig, nil)
|
||||
if err != nil {
|
||||
return nil, errors.New("failed to listen QUIC for XHTTP/3 on ", address, ":", port).Base(err)
|
||||
}
|
||||
l.h3listener = h3listener
|
||||
errors.LogInfo(ctx, "listening QUIC for XHTTP/3 on ", address, ":", port)
|
||||
|
||||
handler.localAddr = l.h3listener.Addr()
|
||||
|
||||
l.h3server = &http3.Server{
|
||||
Handler: handler,
|
||||
}
|
||||
@ -421,11 +429,7 @@ func ListenSH(ctx context.Context, address net.Address, port net.Port, streamSet
|
||||
}
|
||||
}()
|
||||
} else { // tcp
|
||||
localAddr = gonet.TCPAddr{
|
||||
IP: address.IP(),
|
||||
Port: int(port),
|
||||
}
|
||||
listener, err = internet.ListenSystem(ctx, &net.TCPAddr{
|
||||
l.listener, err = internet.ListenSystem(ctx, &net.TCPAddr{
|
||||
IP: address.IP(),
|
||||
Port: int(port),
|
||||
}, streamSettings.SocketSettings)
|
||||
@ -436,26 +440,24 @@ func ListenSH(ctx context.Context, address net.Address, port net.Port, streamSet
|
||||
}
|
||||
|
||||
// tcp/unix (h1/h2)
|
||||
if listener != nil {
|
||||
if config := v2tls.ConfigFromStreamSettings(streamSettings); config != nil {
|
||||
if l.listener != nil {
|
||||
if config := tls.ConfigFromStreamSettings(streamSettings); config != nil {
|
||||
if tlsConfig := config.GetTLSConfig(); tlsConfig != nil {
|
||||
listener = tls.NewListener(listener, tlsConfig)
|
||||
l.listener = gotls.NewListener(l.listener, tlsConfig)
|
||||
}
|
||||
}
|
||||
|
||||
if config := reality.ConfigFromStreamSettings(streamSettings); config != nil {
|
||||
listener = goreality.NewListener(listener, config.GetREALITYConfig())
|
||||
l.listener = goreality.NewListener(l.listener, config.GetREALITYConfig())
|
||||
}
|
||||
|
||||
handler.localAddr = l.listener.Addr()
|
||||
|
||||
// h2cHandler can handle both plaintext HTTP/1.1 and h2c
|
||||
h2cHandler := h2c.NewHandler(handler, &http2.Server{})
|
||||
l.listener = listener
|
||||
l.server = http.Server{
|
||||
Handler: h2cHandler,
|
||||
Handler: h2c.NewHandler(handler, &http2.Server{}),
|
||||
ReadHeaderTimeout: time.Second * 4,
|
||||
MaxHeaderBytes: 8192,
|
||||
}
|
||||
|
||||
go func() {
|
||||
if err := l.server.Serve(l.listener); err != nil {
|
||||
errors.LogWarningInner(ctx, err, "failed to serve HTTP for XHTTP")
|
||||
@ -488,13 +490,13 @@ func (ln *Listener) Close() error {
|
||||
}
|
||||
return errors.New("listener does not have an HTTP/3 server or a net.listener")
|
||||
}
|
||||
func getTLSConfig(streamSettings *internet.MemoryStreamConfig) *tls.Config {
|
||||
config := v2tls.ConfigFromStreamSettings(streamSettings)
|
||||
func getTLSConfig(streamSettings *internet.MemoryStreamConfig) *gotls.Config {
|
||||
config := tls.ConfigFromStreamSettings(streamSettings)
|
||||
if config == nil {
|
||||
return &tls.Config{}
|
||||
return &gotls.Config{}
|
||||
}
|
||||
return config.GetTLSConfig()
|
||||
}
|
||||
func init() {
|
||||
common.Must(internet.RegisterTransportListener(protocolName, ListenSH))
|
||||
common.Must(internet.RegisterTransportListener(protocolName, ListenXH))
|
||||
}
|
||||
|
@ -26,9 +26,9 @@ import (
|
||||
"golang.org/x/net/http2"
|
||||
)
|
||||
|
||||
func Test_listenSHAndDial(t *testing.T) {
|
||||
func Test_ListenXHAndDial(t *testing.T) {
|
||||
listenPort := tcp.PickPort()
|
||||
listen, err := ListenSH(context.Background(), net.LocalHostIP, listenPort, &internet.MemoryStreamConfig{
|
||||
listen, err := ListenXH(context.Background(), net.LocalHostIP, listenPort, &internet.MemoryStreamConfig{
|
||||
ProtocolName: "splithttp",
|
||||
ProtocolSettings: &Config{
|
||||
Path: "/sh",
|
||||
@ -85,7 +85,7 @@ func Test_listenSHAndDial(t *testing.T) {
|
||||
|
||||
func TestDialWithRemoteAddr(t *testing.T) {
|
||||
listenPort := tcp.PickPort()
|
||||
listen, err := ListenSH(context.Background(), net.LocalHostIP, listenPort, &internet.MemoryStreamConfig{
|
||||
listen, err := ListenXH(context.Background(), net.LocalHostIP, listenPort, &internet.MemoryStreamConfig{
|
||||
ProtocolName: "splithttp",
|
||||
ProtocolSettings: &Config{
|
||||
Path: "sh",
|
||||
@ -125,7 +125,7 @@ func TestDialWithRemoteAddr(t *testing.T) {
|
||||
common.Must(listen.Close())
|
||||
}
|
||||
|
||||
func Test_listenSHAndDial_TLS(t *testing.T) {
|
||||
func Test_ListenXHAndDial_TLS(t *testing.T) {
|
||||
if runtime.GOARCH == "arm64" {
|
||||
return
|
||||
}
|
||||
@ -145,7 +145,7 @@ func Test_listenSHAndDial_TLS(t *testing.T) {
|
||||
Certificate: []*tls.Certificate{tls.ParseCertificate(cert.MustGenerate(nil, cert.CommonName("localhost")))},
|
||||
},
|
||||
}
|
||||
listen, err := ListenSH(context.Background(), net.LocalHostIP, listenPort, streamSettings, func(conn stat.Connection) {
|
||||
listen, err := ListenXH(context.Background(), net.LocalHostIP, listenPort, streamSettings, func(conn stat.Connection) {
|
||||
go func() {
|
||||
defer conn.Close()
|
||||
|
||||
@ -180,7 +180,7 @@ func Test_listenSHAndDial_TLS(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func Test_listenSHAndDial_H2C(t *testing.T) {
|
||||
func Test_ListenXHAndDial_H2C(t *testing.T) {
|
||||
if runtime.GOARCH == "arm64" {
|
||||
return
|
||||
}
|
||||
@ -193,7 +193,7 @@ func Test_listenSHAndDial_H2C(t *testing.T) {
|
||||
Path: "shs",
|
||||
},
|
||||
}
|
||||
listen, err := ListenSH(context.Background(), net.LocalHostIP, listenPort, streamSettings, func(conn stat.Connection) {
|
||||
listen, err := ListenXH(context.Background(), net.LocalHostIP, listenPort, streamSettings, func(conn stat.Connection) {
|
||||
go func() {
|
||||
_ = conn.Close()
|
||||
}()
|
||||
@ -227,7 +227,7 @@ func Test_listenSHAndDial_H2C(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func Test_listenSHAndDial_QUIC(t *testing.T) {
|
||||
func Test_ListenXHAndDial_QUIC(t *testing.T) {
|
||||
if runtime.GOARCH == "arm64" {
|
||||
return
|
||||
}
|
||||
@ -250,7 +250,7 @@ func Test_listenSHAndDial_QUIC(t *testing.T) {
|
||||
}
|
||||
|
||||
serverClosed := false
|
||||
listen, err := ListenSH(context.Background(), net.LocalHostIP, listenPort, streamSettings, func(conn stat.Connection) {
|
||||
listen, err := ListenXH(context.Background(), net.LocalHostIP, listenPort, streamSettings, func(conn stat.Connection) {
|
||||
go func() {
|
||||
defer conn.Close()
|
||||
|
||||
@ -309,11 +309,11 @@ func Test_listenSHAndDial_QUIC(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func Test_listenSHAndDial_Unix(t *testing.T) {
|
||||
func Test_ListenXHAndDial_Unix(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
tempSocket := tempDir + "/server.sock"
|
||||
|
||||
listen, err := ListenSH(context.Background(), net.DomainAddress(tempSocket), 0, &internet.MemoryStreamConfig{
|
||||
listen, err := ListenXH(context.Background(), net.DomainAddress(tempSocket), 0, &internet.MemoryStreamConfig{
|
||||
ProtocolName: "splithttp",
|
||||
ProtocolSettings: &Config{
|
||||
Path: "/sh",
|
||||
@ -373,7 +373,7 @@ func Test_listenSHAndDial_Unix(t *testing.T) {
|
||||
|
||||
func Test_queryString(t *testing.T) {
|
||||
listenPort := tcp.PickPort()
|
||||
listen, err := ListenSH(context.Background(), net.LocalHostIP, listenPort, &internet.MemoryStreamConfig{
|
||||
listen, err := ListenXH(context.Background(), net.LocalHostIP, listenPort, &internet.MemoryStreamConfig{
|
||||
ProtocolName: "splithttp",
|
||||
ProtocolSettings: &Config{
|
||||
// this querystring does not have any effect, but sometimes people blindly copy it from websocket config. make sure the outbound doesn't break
|
||||
@ -431,7 +431,7 @@ func Test_maxUpload(t *testing.T) {
|
||||
}
|
||||
|
||||
var uploadSize int
|
||||
listen, err := ListenSH(context.Background(), net.LocalHostIP, listenPort, streamSettings, func(conn stat.Connection) {
|
||||
listen, err := ListenXH(context.Background(), net.LocalHostIP, listenPort, streamSettings, func(conn stat.Connection) {
|
||||
go func(c stat.Connection) {
|
||||
defer c.Close()
|
||||
var b [10240]byte
|
||||
|
@ -54,7 +54,7 @@ func (l *listenUDSWrapper) Accept() (net.Conn, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &listenUDSWrapperConn{Conn: conn}, nil
|
||||
return &UDSWrapperConn{Conn: conn}, nil
|
||||
}
|
||||
|
||||
func (l *listenUDSWrapper) Close() error {
|
||||
@ -65,11 +65,11 @@ func (l *listenUDSWrapper) Close() error {
|
||||
return l.Listener.Close()
|
||||
}
|
||||
|
||||
type listenUDSWrapperConn struct {
|
||||
type UDSWrapperConn struct {
|
||||
net.Conn
|
||||
}
|
||||
|
||||
func (conn *listenUDSWrapperConn) RemoteAddr() net.Addr {
|
||||
func (conn *UDSWrapperConn) RemoteAddr() net.Addr {
|
||||
return &net.TCPAddr{
|
||||
IP: []byte{0, 0, 0, 0},
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user