diff --git a/transport/internet/httpupgrade/httpupgrade_test.go b/transport/internet/httpupgrade/httpupgrade_test.go index f94298ca..6fcb7a97 100644 --- a/transport/internet/httpupgrade/httpupgrade_test.go +++ b/transport/internet/httpupgrade/httpupgrade_test.go @@ -151,7 +151,7 @@ func TestDialWithRemoteAddr(t *testing.T) { return } - _, err = c.Write([]byte("Response")) + _, err = c.Write([]byte(c.RemoteAddr().String())) common.Must(err) }(conn) }) @@ -169,7 +169,7 @@ func TestDialWithRemoteAddr(t *testing.T) { var b [1024]byte n, err := conn.Read(b[:]) common.Must(err) - if string(b[:n]) != "Response" { + if string(b[:n]) != "1.1.1.1:0" { t.Error("response: ", string(b[:n])) } diff --git a/transport/internet/splithttp/splithttp_test.go b/transport/internet/splithttp/splithttp_test.go index 7e22c9ad..d125cedd 100644 --- a/transport/internet/splithttp/splithttp_test.go +++ b/transport/internet/splithttp/splithttp_test.go @@ -96,7 +96,7 @@ func TestDialWithRemoteAddr(t *testing.T) { return } - _, err = c.Write([]byte("Response")) + _, err = c.Write([]byte(c.RemoteAddr().String())) common.Must(err) }(conn) }) @@ -113,7 +113,7 @@ func TestDialWithRemoteAddr(t *testing.T) { var b [1024]byte n, _ := conn.Read(b[:]) - if string(b[:n]) != "Response" { + if string(b[:n]) != "1.1.1.1:0" { t.Error("response: ", string(b[:n])) } diff --git a/transport/internet/websocket/connection.go b/transport/internet/websocket/connection.go index 0bb5dd7b..3ccead47 100644 --- a/transport/internet/websocket/connection.go +++ b/transport/internet/websocket/connection.go @@ -14,15 +14,19 @@ import ( var _ buf.Writer = (*connection)(nil) // connection is a wrapper for net.Conn over WebSocket connection. +// remoteAddr is used to pass "virtual" remote IP addresses in X-Forwarded-For. +// so we shouldn't directly read it form conn. type connection struct { - conn *websocket.Conn - reader io.Reader + conn *websocket.Conn + reader io.Reader + remoteAddr net.Addr } func NewConnection(conn *websocket.Conn, remoteAddr net.Addr, extraReader io.Reader) *connection { return &connection{ - conn: conn, - reader: extraReader, + conn: conn, + remoteAddr: remoteAddr, + reader: extraReader, } } @@ -90,7 +94,7 @@ func (c *connection) LocalAddr() net.Addr { } func (c *connection) RemoteAddr() net.Addr { - return c.conn.RemoteAddr() + return c.remoteAddr } func (c *connection) SetDeadline(t time.Time) error { diff --git a/transport/internet/websocket/ws_test.go b/transport/internet/websocket/ws_test.go index 637b7f72..5fcc0a47 100644 --- a/transport/internet/websocket/ws_test.go +++ b/transport/internet/websocket/ws_test.go @@ -91,7 +91,7 @@ func TestDialWithRemoteAddr(t *testing.T) { return } - _, err = c.Write([]byte("Response")) + _, err = c.Write([]byte(c.RemoteAddr().String())) common.Must(err) }(conn) }) @@ -109,7 +109,7 @@ func TestDialWithRemoteAddr(t *testing.T) { var b [1024]byte n, err := conn.Read(b[:]) common.Must(err) - if string(b[:n]) != "Response" { + if string(b[:n]) != "1.1.1.1:0" { t.Error("response: ", string(b[:n])) }