From cae94570df467d6073f791a80bafc5f3f36773f9 Mon Sep 17 00:00:00 2001 From: deorth-kku Date: Sat, 3 Feb 2024 19:45:37 +0800 Subject: [PATCH] Fixing tcp connestions leak - always use HandshakeContext instead of Handshake - pickup dailer dropped ctx - rename HandshakeContextAddress to HandshakeAddressContext --- proxy/dokodemo/dokodemo.go | 8 +++--- proxy/http/client.go | 2 +- transport/internet/http/dialer.go | 2 +- transport/internet/tcp/dialer.go | 2 +- transport/internet/tls/grpc.go | 2 +- transport/internet/tls/tls.go | 36 ++++++++++++++++++++------ transport/internet/websocket/dialer.go | 4 +-- 7 files changed, 38 insertions(+), 18 deletions(-) diff --git a/proxy/dokodemo/dokodemo.go b/proxy/dokodemo/dokodemo.go index 4a4735e8..1c59fe62 100644 --- a/proxy/dokodemo/dokodemo.go +++ b/proxy/dokodemo/dokodemo.go @@ -71,8 +71,8 @@ func (d *DokodemoDoor) policy() policy.Session { return p } -type hasHandshakeAddress interface { - HandshakeAddress() net.Address +type hasHandshakeAddressContext interface { + HandshakeAddressContext(ctx context.Context) net.Address } // Process implements proxy.Inbound. @@ -89,8 +89,8 @@ func (d *DokodemoDoor) Process(ctx context.Context, network net.Network, conn st if outbound := session.OutboundFromContext(ctx); outbound != nil && outbound.Target.IsValid() { dest = outbound.Target destinationOverridden = true - } else if handshake, ok := conn.(hasHandshakeAddress); ok { - addr := handshake.HandshakeAddress() + } else if handshake, ok := conn.(hasHandshakeAddressContext); ok { + addr := handshake.HandshakeAddressContext(ctx) if addr != nil { dest.Address = addr destinationOverridden = true diff --git a/proxy/http/client.go b/proxy/http/client.go index 302e521d..72060c4d 100644 --- a/proxy/http/client.go +++ b/proxy/http/client.go @@ -308,7 +308,7 @@ func setUpHTTPTunnel(ctx context.Context, dest net.Destination, target string, u nextProto := "" if tlsConn, ok := iConn.(*tls.Conn); ok { - if err := tlsConn.Handshake(); err != nil { + if err := tlsConn.HandshakeContext(ctx); err != nil { rawConn.Close() return nil, err } diff --git a/transport/internet/http/dialer.go b/transport/internet/http/dialer.go index 1ea3a738..513962d3 100644 --- a/transport/internet/http/dialer.go +++ b/transport/internet/http/dialer.go @@ -87,7 +87,7 @@ func getHTTPClient(ctx context.Context, dest net.Destination, streamSettings *in } else { cn = tls.Client(pconn, tlsConfig).(*tls.Conn) } - if err := cn.Handshake(); err != nil { + if err := cn.HandshakeContext(ctx); err != nil { newError("failed to dial to " + addr).Base(err).AtError().WriteToLog() return nil, err } diff --git a/transport/internet/tcp/dialer.go b/transport/internet/tcp/dialer.go index 840062b1..06ee3ecf 100644 --- a/transport/internet/tcp/dialer.go +++ b/transport/internet/tcp/dialer.go @@ -24,7 +24,7 @@ func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.Me tlsConfig := config.GetTLSConfig(tls.WithDestination(dest)) if fingerprint := tls.GetFingerprint(config.Fingerprint); fingerprint != nil { conn = tls.UClient(conn, tlsConfig, fingerprint) - if err := conn.(*tls.UConn).Handshake(); err != nil { + if err := conn.(*tls.UConn).HandshakeContext(ctx); err != nil { return nil, err } } else { diff --git a/transport/internet/tls/grpc.go b/transport/internet/tls/grpc.go index a698196b..6e5dc578 100644 --- a/transport/internet/tls/grpc.go +++ b/transport/internet/tls/grpc.go @@ -65,7 +65,7 @@ func (c *grpcUtls) ClientHandshake(ctx context.Context, authority string, rawCon conn := UClient(rawConn, cfg, c.fingerprint).(*UConn) errChannel := make(chan error, 1) go func() { - errChannel <- conn.Handshake() + errChannel <- conn.HandshakeContext(ctx) close(errChannel) }() select { diff --git a/transport/internet/tls/tls.go b/transport/internet/tls/tls.go index e73a495b..73631ef9 100644 --- a/transport/internet/tls/tls.go +++ b/transport/internet/tls/tls.go @@ -1,9 +1,11 @@ package tls import ( + "context" "crypto/rand" "crypto/tls" "math/big" + "time" utls "github.com/refraction-networking/utls" "github.com/xtls/xray-core/common/buf" @@ -14,7 +16,7 @@ import ( type Interface interface { net.Conn - Handshake() error + HandshakeContext(ctx context.Context) error VerifyHostname(host string) error NegotiatedProtocol() (name string, mutual bool) } @@ -25,6 +27,16 @@ type Conn struct { *tls.Conn } +const tlsCloseTimeout = 250 * time.Millisecond + +func (c *Conn) Close() error { + timer := time.AfterFunc(tlsCloseTimeout, func() { + c.Conn.NetConn().Close() + }) + defer timer.Stop() + return c.Conn.Close() +} + func (c *Conn) WriteMultiBuffer(mb buf.MultiBuffer) error { mb = buf.Compact(mb) mb, err := buf.WriteMultiBuffer(c, mb) @@ -32,8 +44,8 @@ func (c *Conn) WriteMultiBuffer(mb buf.MultiBuffer) error { return err } -func (c *Conn) HandshakeAddress() net.Address { - if err := c.Handshake(); err != nil { +func (c *Conn) HandshakeAddressContext(ctx context.Context) net.Address { + if err := c.HandshakeContext(ctx); err != nil { return nil } state := c.ConnectionState() @@ -64,8 +76,16 @@ type UConn struct { *utls.UConn } -func (c *UConn) HandshakeAddress() net.Address { - if err := c.Handshake(); err != nil { +func (c *UConn) Close() error { + timer := time.AfterFunc(tlsCloseTimeout, func() { + c.Conn.NetConn().Close() + }) + defer timer.Stop() + return c.Conn.Close() +} + +func (c *UConn) HandshakeAddressContext(ctx context.Context) net.Address { + if err := c.HandshakeContext(ctx); err != nil { return nil } state := c.ConnectionState() @@ -77,7 +97,7 @@ func (c *UConn) HandshakeAddress() net.Address { // WebsocketHandshake basically calls UConn.Handshake inside it but it will only send // http/1.1 in its ALPN. -func (c *UConn) WebsocketHandshake() error { +func (c *UConn) WebsocketHandshakeContext(ctx context.Context) error { // Build the handshake state. This will apply every variable of the TLS of the // fingerprint in the UConn if err := c.BuildHandshakeState(); err != nil { @@ -99,7 +119,7 @@ func (c *UConn) WebsocketHandshake() error { if err := c.BuildHandshakeState(); err != nil { return err } - return c.Handshake() + return c.HandshakeContext(ctx) } func (c *UConn) NegotiatedProtocol() (name string, mutual bool) { @@ -118,7 +138,7 @@ func copyConfig(c *tls.Config) *utls.Config { ServerName: c.ServerName, InsecureSkipVerify: c.InsecureSkipVerify, VerifyPeerCertificate: c.VerifyPeerCertificate, - KeyLogWriter: c.KeyLogWriter, + KeyLogWriter: c.KeyLogWriter, } } diff --git a/transport/internet/websocket/dialer.go b/transport/internet/websocket/dialer.go index 02b73a66..4ef27831 100644 --- a/transport/internet/websocket/dialer.go +++ b/transport/internet/websocket/dialer.go @@ -96,7 +96,7 @@ func dialWebSocket(ctx context.Context, dest net.Destination, streamSettings *in } // TLS and apply the handshake cn := tls.UClient(pconn, tlsConfig, fingerprint).(*tls.UConn) - if err := cn.WebsocketHandshake(); err != nil { + if err := cn.WebsocketHandshakeContext(ctx); err != nil { newError("failed to dial to " + addr).Base(err).AtError().WriteToLog() return nil, err } @@ -147,7 +147,7 @@ func dialWebSocket(ctx context.Context, dest net.Destination, streamSettings *in header.Set("Sec-WebSocket-Protocol", base64.RawURLEncoding.EncodeToString(ed)) } - conn, resp, err := dialer.Dial(uri, header) + conn, resp, err := dialer.DialContext(ctx, uri, header) if err != nil { var reason string if resp != nil {