diff --git a/transport/internet/httpupgrade/dialer.go b/transport/internet/httpupgrade/dialer.go index b54af9a6..9f909503 100644 --- a/transport/internet/httpupgrade/dialer.go +++ b/transport/internet/httpupgrade/dialer.go @@ -100,7 +100,7 @@ func dialhttpUpgrade(ctx context.Context, dest net.Destination, streamSettings * return connRF, nil } -func dial(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (stat.Connection, error) { +func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (stat.Connection, error) { newError("creating connection to ", dest).WriteToLog(session.ExportIDToError(ctx)) conn, err := dialhttpUpgrade(ctx, dest, streamSettings) @@ -111,5 +111,5 @@ func dial(ctx context.Context, dest net.Destination, streamSettings *internet.Me } func init() { - common.Must(internet.RegisterTransportDialer(protocolName, dial)) + common.Must(internet.RegisterTransportDialer(protocolName, Dial)) } diff --git a/transport/internet/httpupgrade/httpupgrade_test.go b/transport/internet/httpupgrade/httpupgrade_test.go new file mode 100644 index 00000000..d8d3ad84 --- /dev/null +++ b/transport/internet/httpupgrade/httpupgrade_test.go @@ -0,0 +1,153 @@ +package httpupgrade_test + +import ( + "context" + "runtime" + "testing" + "time" + + "github.com/xtls/xray-core/common" + "github.com/xtls/xray-core/common/net" + "github.com/xtls/xray-core/common/protocol/tls/cert" + "github.com/xtls/xray-core/testing/servers/tcp" + "github.com/xtls/xray-core/transport/internet" + . "github.com/xtls/xray-core/transport/internet/httpupgrade" + "github.com/xtls/xray-core/transport/internet/stat" + "github.com/xtls/xray-core/transport/internet/tls" +) + +func Test_listenHTTPUpgradeAndDial(t *testing.T) { + listenPort := tcp.PickPort() + listen, err := ListenHTTPUpgrade(context.Background(), net.LocalHostIP, listenPort, &internet.MemoryStreamConfig{ + ProtocolName: "httpupgrade", + ProtocolSettings: &Config{ + Path: "httpupgrade", + }, + }, func(conn stat.Connection) { + go func(c stat.Connection) { + defer c.Close() + + var b [1024]byte + _, err := c.Read(b[:]) + if err != nil { + return + } + + common.Must2(c.Write([]byte("Response"))) + }(conn) + }) + common.Must(err) + + ctx := context.Background() + streamSettings := &internet.MemoryStreamConfig{ + ProtocolName: "httpupgrade", + ProtocolSettings: &Config{Path: "httpupgrade"}, + } + conn, err := Dial(ctx, net.TCPDestination(net.DomainAddress("localhost"), listenPort), streamSettings) + + common.Must(err) + _, err = conn.Write([]byte("Test connection 1")) + common.Must(err) + + var b [1024]byte + n, err := conn.Read(b[:]) + common.Must(err) + if string(b[:n]) != "Response" { + t.Error("response: ", string(b[:n])) + } + + common.Must(conn.Close()) + <-time.After(time.Second * 5) + conn, err = Dial(ctx, net.TCPDestination(net.DomainAddress("localhost"), listenPort), streamSettings) + common.Must(err) + _, err = conn.Write([]byte("Test connection 2")) + common.Must(err) + n, err = conn.Read(b[:]) + common.Must(err) + if string(b[:n]) != "Response" { + t.Error("response: ", string(b[:n])) + } + common.Must(conn.Close()) + + common.Must(listen.Close()) +} + +func TestDialWithRemoteAddr(t *testing.T) { + listenPort := tcp.PickPort() + listen, err := ListenHTTPUpgrade(context.Background(), net.LocalHostIP, listenPort, &internet.MemoryStreamConfig{ + ProtocolName: "httpupgrade", + ProtocolSettings: &Config{ + Path: "httpupgrade", + }, + }, func(conn stat.Connection) { + go func(c stat.Connection) { + defer c.Close() + + var b [1024]byte + _, err := c.Read(b[:]) + // common.Must(err) + if err != nil { + return + } + + _, err = c.Write([]byte("Response")) + common.Must(err) + }(conn) + }) + common.Must(err) + + conn, err := Dial(context.Background(), net.TCPDestination(net.DomainAddress("localhost"), listenPort), &internet.MemoryStreamConfig{ + ProtocolName: "httpupgrade", + ProtocolSettings: &Config{Path: "httpupgrade", Header: map[string]string{"X-Forwarded-For": "1.1.1.1"}}, + }) + + common.Must(err) + _, err = conn.Write([]byte("Test connection 1")) + common.Must(err) + + var b [1024]byte + n, err := conn.Read(b[:]) + common.Must(err) + if string(b[:n]) != "Response" { + t.Error("response: ", string(b[:n])) + } + + common.Must(listen.Close()) +} + +func Test_listenHTTPUpgradeAndDial_TLS(t *testing.T) { + listenPort := tcp.PickPort() + if runtime.GOARCH == "arm64" { + return + } + + start := time.Now() + + streamSettings := &internet.MemoryStreamConfig{ + ProtocolName: "httpupgrade", + ProtocolSettings: &Config{ + Path: "httpupgrades", + }, + SecurityType: "tls", + SecuritySettings: &tls.Config{ + AllowInsecure: true, + Certificate: []*tls.Certificate{tls.ParseCertificate(cert.MustGenerate(nil, cert.CommonName("localhost")))}, + }, + } + listen, err := ListenHTTPUpgrade(context.Background(), net.LocalHostIP, listenPort, streamSettings, func(conn stat.Connection) { + go func() { + _ = conn.Close() + }() + }) + common.Must(err) + defer listen.Close() + + conn, err := Dial(context.Background(), net.TCPDestination(net.DomainAddress("localhost"), listenPort), streamSettings) + common.Must(err) + _ = conn.Close() + + end := time.Now() + if !end.Before(start.Add(time.Second * 5)) { + t.Error("end: ", end, " start: ", start) + } +} \ No newline at end of file diff --git a/transport/internet/httpupgrade/hub.go b/transport/internet/httpupgrade/hub.go index 816730bc..5c3ccbd3 100644 --- a/transport/internet/httpupgrade/hub.go +++ b/transport/internet/httpupgrade/hub.go @@ -97,7 +97,7 @@ func (s *server) keepAccepting() { } } -func listenHTTPUpgrade(ctx context.Context, address net.Address, port net.Port, streamSettings *internet.MemoryStreamConfig, addConn internet.ConnHandler) (internet.Listener, error) { +func ListenHTTPUpgrade(ctx context.Context, address net.Address, port net.Port, streamSettings *internet.MemoryStreamConfig, addConn internet.ConnHandler) (internet.Listener, error) { transportConfiguration := streamSettings.ProtocolSettings.(*Config) if transportConfiguration != nil { if streamSettings.SocketSettings == nil { @@ -147,5 +147,5 @@ func listenHTTPUpgrade(ctx context.Context, address net.Address, port net.Port, } func init() { - common.Must(internet.RegisterTransportListener(protocolName, listenHTTPUpgrade)) + common.Must(internet.RegisterTransportListener(protocolName, ListenHTTPUpgrade)) }