From 24b637cd5e726fbc22a553457bb3582b8d13a0e9 Mon Sep 17 00:00:00 2001 From: Arthur Morgan <4637240+badO1a5A90@users.noreply.github.com> Date: Mon, 20 Sep 2021 20:11:21 +0800 Subject: [PATCH] Fix: CounterConnection with ReadV/WriteV (#720) Co-authored-by: JimhHan <50871214+JimhHan@users.noreply.github.com> --- app/proxyman/inbound/worker.go | 14 ++++---- app/proxyman/outbound/handler.go | 8 +++-- app/proxyman/outbound/handler_test.go | 11 +++--- common/buf/io.go | 36 ++++++++++++++++--- common/buf/readv_reader.go | 14 ++++++-- common/buf/readv_test.go | 2 +- common/buf/writer.go | 17 ++++++--- common/crypto/io.go | 2 +- proxy/dns/dns.go | 6 ++-- proxy/dokodemo/dokodemo.go | 5 +-- proxy/freedom/freedom.go | 8 +++-- proxy/http/client.go | 8 +++-- proxy/http/server.go | 7 ++-- proxy/mtproto/server.go | 5 +-- proxy/proxy.go | 4 ++- proxy/shadowsocks/client.go | 4 ++- proxy/shadowsocks/protocol.go | 4 +-- proxy/shadowsocks/server.go | 9 ++--- proxy/socks/client.go | 4 ++- proxy/socks/protocol.go | 12 +++---- proxy/socks/server.go | 9 ++--- proxy/trojan/client.go | 6 ++-- proxy/trojan/protocol.go | 7 ++-- proxy/trojan/server.go | 11 +++--- proxy/vless/encoding/encoding.go | 7 ++-- proxy/vless/inbound/inbound.go | 7 ++-- proxy/vless/outbound/outbound.go | 6 ++-- proxy/vmess/inbound/inbound.go | 7 ++-- proxy/vmess/outbound/outbound.go | 4 ++- testing/mocks/proxy.go | 6 ++-- transport/internet/dialer.go | 8 +++-- transport/internet/domainsocket/dial.go | 4 ++- transport/internet/domainsocket/listener.go | 4 ++- .../internet/domainsocket/listener_test.go | 6 ++-- transport/internet/grpc/dial.go | 5 +-- transport/internet/headers/http/http.go | 2 +- transport/internet/http/dialer.go | 4 ++- transport/internet/http/http_test.go | 4 ++- transport/internet/kcp/dialer.go | 6 ++-- transport/internet/kcp/kcp_test.go | 6 ++-- transport/internet/kcp/listener.go | 4 ++- transport/internet/quic/dialer.go | 6 ++-- transport/internet/quic/quic_test.go | 8 +++-- transport/internet/{ => stat}/connection.go | 8 ++--- transport/internet/tcp/dialer.go | 6 ++-- transport/internet/tcp/hub.go | 4 ++- transport/internet/tcp/sockopt_freebsd.go | 3 +- transport/internet/tcp/sockopt_linux.go | 5 +-- transport/internet/tcp/sockopt_other.go | 4 +-- transport/internet/tcp_hub.go | 4 ++- transport/internet/udp/dialer.go | 6 ++-- transport/internet/websocket/dialer.go | 6 ++-- transport/internet/websocket/ws_test.go | 12 ++++--- 53 files changed, 247 insertions(+), 128 deletions(-) rename transport/internet/{ => stat}/connection.go (71%) diff --git a/app/proxyman/inbound/worker.go b/app/proxyman/inbound/worker.go index 04094a4b..806a59fe 100644 --- a/app/proxyman/inbound/worker.go +++ b/app/proxyman/inbound/worker.go @@ -6,6 +6,8 @@ import ( "sync/atomic" "time" + "github.com/xtls/xray-core/transport/internet/stat" + "github.com/xtls/xray-core/app/proxyman" "github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common/buf" @@ -54,7 +56,7 @@ func getTProxyType(s *internet.MemoryStreamConfig) internet.SocketConfig_TProxyM return s.SocketSettings.Tproxy } -func (w *tcpWorker) callback(conn internet.Connection) { +func (w *tcpWorker) callback(conn stat.Connection) { ctx, cancel := context.WithCancel(w.ctx) sid := session.NewID() ctx = session.ContextWithID(ctx, sid) @@ -80,7 +82,7 @@ func (w *tcpWorker) callback(conn internet.Connection) { } if w.uplinkCounter != nil || w.downlinkCounter != nil { - conn = &internet.StatCouterConnection{ + conn = &stat.CounterConnection{ Connection: conn, ReadCounter: w.uplinkCounter, WriteCounter: w.downlinkCounter, @@ -117,7 +119,7 @@ func (w *tcpWorker) Proxy() proxy.Inbound { func (w *tcpWorker) Start() error { ctx := context.Background() - hub, err := internet.ListenTCP(ctx, w.address, w.port, w.stream, func(conn internet.Connection) { + hub, err := internet.ListenTCP(ctx, w.address, w.port, w.stream, func(conn stat.Connection) { go w.callback(conn) }) if err != nil { @@ -436,13 +438,13 @@ type dsWorker struct { ctx context.Context } -func (w *dsWorker) callback(conn internet.Connection) { +func (w *dsWorker) callback(conn stat.Connection) { ctx, cancel := context.WithCancel(w.ctx) sid := session.NewID() ctx = session.ContextWithID(ctx, sid) if w.uplinkCounter != nil || w.downlinkCounter != nil { - conn = &internet.StatCouterConnection{ + conn = &stat.CounterConnection{ Connection: conn, ReadCounter: w.uplinkCounter, WriteCounter: w.downlinkCounter, @@ -482,7 +484,7 @@ func (w *dsWorker) Port() net.Port { } func (w *dsWorker) Start() error { ctx := context.Background() - hub, err := internet.ListenUnix(ctx, w.address, w.stream, func(conn internet.Connection) { + hub, err := internet.ListenUnix(ctx, w.address, w.stream, func(conn stat.Connection) { go w.callback(conn) }) if err != nil { diff --git a/app/proxyman/outbound/handler.go b/app/proxyman/outbound/handler.go index 0452ff67..d5cc59f1 100644 --- a/app/proxyman/outbound/handler.go +++ b/app/proxyman/outbound/handler.go @@ -3,6 +3,8 @@ package outbound import ( "context" + "github.com/xtls/xray-core/transport/internet/stat" + "github.com/xtls/xray-core/app/proxyman" "github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common/mux" @@ -158,7 +160,7 @@ func (h *Handler) Address() net.Address { } // Dial implements internet.Dialer. -func (h *Handler) Dial(ctx context.Context, dest net.Destination) (internet.Connection, error) { +func (h *Handler) Dial(ctx context.Context, dest net.Destination) (stat.Connection, error) { if h.senderSettings != nil { if h.senderSettings.ProxySettings.HasTag() { tag := h.senderSettings.ProxySettings.Tag @@ -201,9 +203,9 @@ func (h *Handler) Dial(ctx context.Context, dest net.Destination) (internet.Conn return h.getStatCouterConnection(conn), err } -func (h *Handler) getStatCouterConnection(conn internet.Connection) internet.Connection { +func (h *Handler) getStatCouterConnection(conn stat.Connection) stat.Connection { if h.uplinkCounter != nil || h.downlinkCounter != nil { - return &internet.StatCouterConnection{ + return &stat.CounterConnection{ Connection: conn, ReadCounter: h.downlinkCounter, WriteCounter: h.uplinkCounter, diff --git a/app/proxyman/outbound/handler_test.go b/app/proxyman/outbound/handler_test.go index a557ddc0..9a9bf45c 100644 --- a/app/proxyman/outbound/handler_test.go +++ b/app/proxyman/outbound/handler_test.go @@ -4,6 +4,8 @@ import ( "context" "testing" + "github.com/xtls/xray-core/transport/internet/stat" + "github.com/xtls/xray-core/app/policy" . "github.com/xtls/xray-core/app/proxyman/outbound" "github.com/xtls/xray-core/app/stats" @@ -12,7 +14,6 @@ import ( core "github.com/xtls/xray-core/core" "github.com/xtls/xray-core/features/outbound" "github.com/xtls/xray-core/proxy/freedom" - "github.com/xtls/xray-core/transport/internet" ) func TestInterfaces(t *testing.T) { @@ -44,9 +45,9 @@ func TestOutboundWithoutStatCounter(t *testing.T) { ProxySettings: serial.ToTypedMessage(&freedom.Config{}), }) conn, _ := h.(*Handler).Dial(ctx, net.TCPDestination(net.DomainAddress("localhost"), 13146)) - _, ok := conn.(*internet.StatCouterConnection) + _, ok := conn.(*stat.CounterConnection) if ok { - t.Errorf("Expected conn to not be StatCouterConnection") + t.Errorf("Expected conn to not be CounterConnection") } } @@ -73,8 +74,8 @@ func TestOutboundWithStatCounter(t *testing.T) { ProxySettings: serial.ToTypedMessage(&freedom.Config{}), }) conn, _ := h.(*Handler).Dial(ctx, net.TCPDestination(net.DomainAddress("localhost"), 13146)) - _, ok := conn.(*internet.StatCouterConnection) + _, ok := conn.(*stat.CounterConnection) if !ok { - t.Errorf("Expected conn to be StatCouterConnection") + t.Errorf("Expected conn to be CounterConnection") } } diff --git a/common/buf/io.go b/common/buf/io.go index 2a4cd670..a1d4a861 100644 --- a/common/buf/io.go +++ b/common/buf/io.go @@ -6,6 +6,9 @@ import ( "os" "syscall" "time" + + "github.com/xtls/xray-core/features/stats" + "github.com/xtls/xray-core/transport/internet/stat" ) // Reader extends io.Reader with MultiBuffer. @@ -29,9 +32,17 @@ type Writer interface { } // WriteAllBytes ensures all bytes are written into the given writer. -func WriteAllBytes(writer io.Writer, payload []byte) error { +func WriteAllBytes(writer io.Writer, payload []byte, c stats.Counter) error { + wc := 0 + defer func() { + if c != nil { + c.Add(int64(wc)) + } + }() + for len(payload) > 0 { n, err := writer.Write(payload) + wc += n if err != nil { return err } @@ -60,12 +71,18 @@ func NewReader(reader io.Reader) Reader { _, isFile := reader.(*os.File) if !isFile && useReadv { + var counter stats.Counter + + if statConn, ok := reader.(*stat.CounterConnection); ok { + reader = statConn.Connection + counter = statConn.ReadCounter + } if sc, ok := reader.(syscall.Conn); ok { rawConn, err := sc.SyscallConn() if err != nil { newError("failed to get sysconn").Base(err).WriteToLog() } else { - return NewReadVReader(reader, rawConn) + return NewReadVReader(reader, rawConn, counter) } } } @@ -104,13 +121,24 @@ func NewWriter(writer io.Writer) Writer { return mw } - if isPacketWriter(writer) { + var iConn = writer + if statConn, ok := writer.(*stat.CounterConnection); ok { + iConn = statConn.Connection + } + + if isPacketWriter(iConn) { return &SequentialWriter{ Writer: writer, } } + var counter stats.Counter + + if statConn, ok := writer.(*stat.CounterConnection); ok { + counter = statConn.WriteCounter + } return &BufferToBytesWriter{ - Writer: writer, + Writer: iConn, + counter: counter, } } diff --git a/common/buf/readv_reader.go b/common/buf/readv_reader.go index 0813f230..14e93571 100644 --- a/common/buf/readv_reader.go +++ b/common/buf/readv_reader.go @@ -6,6 +6,8 @@ import ( "io" "syscall" + "github.com/xtls/xray-core/features/stats" + "github.com/xtls/xray-core/common/platform" ) @@ -53,17 +55,19 @@ type ReadVReader struct { rawConn syscall.RawConn mr multiReader alloc allocStrategy + counter stats.Counter } // NewReadVReader creates a new ReadVReader. -func NewReadVReader(reader io.Reader, rawConn syscall.RawConn) *ReadVReader { +func NewReadVReader(reader io.Reader, rawConn syscall.RawConn, counter stats.Counter) *ReadVReader { return &ReadVReader{ Reader: reader, rawConn: rawConn, alloc: allocStrategy{ current: 1, }, - mr: newMultiReader(), + mr: newMultiReader(), + counter: counter, } } @@ -122,10 +126,16 @@ func (r *ReadVReader) ReadMultiBuffer() (MultiBuffer, error) { if b.IsFull() { r.alloc.Adjust(1) } + if r.counter != nil && b != nil { + r.counter.Add(int64(b.Len())) + } return MultiBuffer{b}, err } mb, err := r.readMulti() + if r.counter != nil && mb != nil { + r.counter.Add(int64(mb.Len())) + } if err != nil { return nil, err } diff --git a/common/buf/readv_test.go b/common/buf/readv_test.go index d8186659..6bca9491 100644 --- a/common/buf/readv_test.go +++ b/common/buf/readv_test.go @@ -50,7 +50,7 @@ func TestReadvReader(t *testing.T) { rawConn, err := conn.(*net.TCPConn).SyscallConn() common.Must(err) - reader := NewReadVReader(conn, rawConn) + reader := NewReadVReader(conn, rawConn, nil) var rmb MultiBuffer for { mb, err := reader.ReadMultiBuffer() diff --git a/common/buf/writer.go b/common/buf/writer.go index a3bfe560..6faa1959 100644 --- a/common/buf/writer.go +++ b/common/buf/writer.go @@ -5,6 +5,8 @@ import ( "net" "sync" + "github.com/xtls/xray-core/features/stats" + "github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common/errors" ) @@ -13,7 +15,8 @@ import ( type BufferToBytesWriter struct { io.Writer - cache [][]byte + counter stats.Counter + cache [][]byte } // WriteMultiBuffer implements Writer. This method takes ownership of the given buffer. @@ -26,7 +29,7 @@ func (w *BufferToBytesWriter) WriteMultiBuffer(mb MultiBuffer) error { } if len(mb) == 1 { - return WriteAllBytes(w.Writer, mb[0].Bytes()) + return WriteAllBytes(w.Writer, mb[0].Bytes(), w.counter) } if cap(w.cache) < len(mb) { @@ -45,9 +48,15 @@ func (w *BufferToBytesWriter) WriteMultiBuffer(mb MultiBuffer) error { }() nb := net.Buffers(bs) - + wc := int64(0) + defer func() { + if w.counter != nil { + w.counter.Add(wc) + } + }() for size > 0 { n, err := nb.WriteTo(w.Writer) + wc += n if err != nil { return err } @@ -173,7 +182,7 @@ func (w *BufferedWriter) flushInternal() error { w.buffer = nil if writer, ok := w.writer.(io.Writer); ok { - err := WriteAllBytes(writer, b.Bytes()) + err := WriteAllBytes(writer, b.Bytes(), nil) b.Release() return err } diff --git a/common/crypto/io.go b/common/crypto/io.go index 9800c6a5..59e16660 100644 --- a/common/crypto/io.go +++ b/common/crypto/io.go @@ -50,7 +50,7 @@ func NewCryptionWriter(stream cipher.Stream, writer io.Writer) *CryptionWriter { func (w *CryptionWriter) Write(data []byte) (int, error) { w.stream.XORKeyStream(data, data) - if err := buf.WriteAllBytes(w.writer, data); err != nil { + if err := buf.WriteAllBytes(w.writer, data, nil); err != nil { return 0, err } return len(data), nil diff --git a/proxy/dns/dns.go b/proxy/dns/dns.go index 01c02af6..ccce6283 100644 --- a/proxy/dns/dns.go +++ b/proxy/dns/dns.go @@ -5,6 +5,8 @@ import ( "io" "sync" + "github.com/xtls/xray-core/transport/internet/stat" + "golang.org/x/net/dns/dnsmessage" "github.com/xtls/xray-core/common" @@ -104,7 +106,7 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, d internet. newError("handling DNS traffic to ", dest).WriteToLog(session.ExportIDToError(ctx)) conn := &outboundConn{ - dialer: func() (internet.Connection, error) { + dialer: func() (stat.Connection, error) { return d.Dial(ctx, dest) }, connReady: make(chan struct{}, 1), @@ -266,7 +268,7 @@ func (h *Handler) handleIPQuery(id uint16, qType dnsmessage.Type, domain string, type outboundConn struct { access sync.Mutex - dialer func() (internet.Connection, error) + dialer func() (stat.Connection, error) conn net.Conn connReady chan struct{} diff --git a/proxy/dokodemo/dokodemo.go b/proxy/dokodemo/dokodemo.go index 5b6b0948..999c546f 100644 --- a/proxy/dokodemo/dokodemo.go +++ b/proxy/dokodemo/dokodemo.go @@ -7,6 +7,8 @@ import ( "sync/atomic" "time" + "github.com/xtls/xray-core/transport/internet/stat" + "github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common/buf" "github.com/xtls/xray-core/common/log" @@ -18,7 +20,6 @@ import ( "github.com/xtls/xray-core/core" "github.com/xtls/xray-core/features/policy" "github.com/xtls/xray-core/features/routing" - "github.com/xtls/xray-core/transport/internet" ) func init() { @@ -76,7 +77,7 @@ type hasHandshakeAddress interface { } // Process implements proxy.Inbound. -func (d *DokodemoDoor) Process(ctx context.Context, network net.Network, conn internet.Connection, dispatcher routing.Dispatcher) error { +func (d *DokodemoDoor) Process(ctx context.Context, network net.Network, conn stat.Connection, dispatcher routing.Dispatcher) error { newError("processing connection from: ", conn.RemoteAddr()).AtDebug().WriteToLog(session.ExportIDToError(ctx)) dest := net.Destination{ Network: network, diff --git a/proxy/freedom/freedom.go b/proxy/freedom/freedom.go index a1f99cc4..21d00414 100644 --- a/proxy/freedom/freedom.go +++ b/proxy/freedom/freedom.go @@ -6,6 +6,8 @@ import ( "context" "time" + "github.com/xtls/xray-core/transport/internet/stat" + "github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common/buf" "github.com/xtls/xray-core/common/dice" @@ -121,7 +123,7 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte input := link.Reader output := link.Writer - var conn internet.Connection + var conn stat.Connection err := retry.ExponentialBackoff(5, 100).On(func() error { dialDest := destination if h.config.useIP() && dialDest.Address.Family().IsDomain() { @@ -194,7 +196,7 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte func NewPacketReader(conn net.Conn, UDPOverride net.Destination) buf.Reader { iConn := conn - statConn, ok := iConn.(*internet.StatCouterConnection) + statConn, ok := iConn.(*stat.CounterConnection) if ok { iConn = statConn.Connection } @@ -238,7 +240,7 @@ func (r *PacketReader) ReadMultiBuffer() (buf.MultiBuffer, error) { func NewPacketWriter(conn net.Conn, h *Handler, ctx context.Context, UDPOverride net.Destination) buf.Writer { iConn := conn - statConn, ok := iConn.(*internet.StatCouterConnection) + statConn, ok := iConn.(*stat.CounterConnection) if ok { iConn = statConn.Connection } diff --git a/proxy/http/client.go b/proxy/http/client.go index 2d7535a5..21247985 100644 --- a/proxy/http/client.go +++ b/proxy/http/client.go @@ -9,6 +9,8 @@ import ( "net/url" "sync" + "github.com/xtls/xray-core/transport/internet/stat" + "golang.org/x/net/http2" "github.com/xtls/xray-core/common" @@ -77,7 +79,7 @@ func (c *Client) Process(ctx context.Context, link *transport.Link, dialer inter } var user *protocol.MemoryUser - var conn internet.Connection + var conn stat.Connection mbuf, _ := link.Reader.ReadMultiBuffer() len := mbuf.Len() @@ -101,7 +103,7 @@ func (c *Client) Process(ctx context.Context, link *transport.Link, dialer inter return err } } - conn = internet.Connection(netConn) + conn = stat.Connection(netConn) } return err }); err != nil { @@ -231,7 +233,7 @@ func setUpHTTPTunnel(ctx context.Context, dest net.Destination, target string, u } iConn := rawConn - if statConn, ok := iConn.(*internet.StatCouterConnection); ok { + if statConn, ok := iConn.(*stat.CounterConnection); ok { iConn = statConn.Connection } diff --git a/proxy/http/server.go b/proxy/http/server.go index 6d33cdac..414eb2bc 100644 --- a/proxy/http/server.go +++ b/proxy/http/server.go @@ -9,6 +9,8 @@ import ( "strings" "time" + "github.com/xtls/xray-core/transport/internet/stat" + "github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common/buf" "github.com/xtls/xray-core/common/errors" @@ -22,7 +24,6 @@ import ( "github.com/xtls/xray-core/core" "github.com/xtls/xray-core/features/policy" "github.com/xtls/xray-core/features/routing" - "github.com/xtls/xray-core/transport/internet" ) // Server is an HTTP proxy server. @@ -82,7 +83,7 @@ type readerOnly struct { io.Reader } -func (s *Server) Process(ctx context.Context, network net.Network, conn internet.Connection, dispatcher routing.Dispatcher) error { +func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Connection, dispatcher routing.Dispatcher) error { inbound := session.InboundFromContext(ctx) if inbound != nil { inbound.User = &protocol.MemoryUser{ @@ -157,7 +158,7 @@ Start: return err } -func (s *Server) handleConnect(ctx context.Context, _ *http.Request, reader *bufio.Reader, conn internet.Connection, dest net.Destination, dispatcher routing.Dispatcher, inbound *session.Inbound) error { +func (s *Server) handleConnect(ctx context.Context, _ *http.Request, reader *bufio.Reader, conn stat.Connection, dest net.Destination, dispatcher routing.Dispatcher, inbound *session.Inbound) error { _, err := conn.Write([]byte("HTTP/1.1 200 Connection established\r\n\r\n")) if err != nil { return newError("failed to write back OK response").Base(err) diff --git a/proxy/mtproto/server.go b/proxy/mtproto/server.go index 8812bb6f..c6c7b7ce 100644 --- a/proxy/mtproto/server.go +++ b/proxy/mtproto/server.go @@ -5,6 +5,8 @@ import ( "context" "time" + "github.com/xtls/xray-core/transport/internet/stat" + "github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common/buf" "github.com/xtls/xray-core/common/crypto" @@ -16,7 +18,6 @@ import ( "github.com/xtls/xray-core/core" "github.com/xtls/xray-core/features/policy" "github.com/xtls/xray-core/features/routing" - "github.com/xtls/xray-core/transport/internet" ) var ( @@ -76,7 +77,7 @@ func isValidConnectionType(c [4]byte) bool { return false } -func (s *Server) Process(ctx context.Context, network net.Network, conn internet.Connection, dispatcher routing.Dispatcher) error { +func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Connection, dispatcher routing.Dispatcher) error { sPolicy := s.policy.ForLevel(s.user.Level) if err := conn.SetDeadline(time.Now().Add(sPolicy.Timeouts.Handshake)); err != nil { diff --git a/proxy/proxy.go b/proxy/proxy.go index a32a1ef1..01960cdf 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -8,6 +8,8 @@ package proxy import ( "context" + "github.com/xtls/xray-core/transport/internet/stat" + "github.com/xtls/xray-core/common/net" "github.com/xtls/xray-core/common/protocol" "github.com/xtls/xray-core/features/routing" @@ -21,7 +23,7 @@ type Inbound interface { Network() []net.Network // Process processes a connection of given network. If necessary, the Inbound can dispatch the connection to an Outbound. - Process(context.Context, net.Network, internet.Connection, routing.Dispatcher) error + Process(context.Context, net.Network, stat.Connection, routing.Dispatcher) error } // An Outbound process outbound connections. diff --git a/proxy/shadowsocks/client.go b/proxy/shadowsocks/client.go index 374bf8d4..67b2216d 100644 --- a/proxy/shadowsocks/client.go +++ b/proxy/shadowsocks/client.go @@ -3,6 +3,8 @@ package shadowsocks import ( "context" + "github.com/xtls/xray-core/transport/internet/stat" + "github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common/buf" "github.com/xtls/xray-core/common/net" @@ -55,7 +57,7 @@ func (c *Client) Process(ctx context.Context, link *transport.Link, dialer inter network := destination.Network var server *protocol.ServerSpec - var conn internet.Connection + var conn stat.Connection err := retry.ExponentialBackoff(5, 100).On(func() error { server = c.serverPicker.PickServer() diff --git a/proxy/shadowsocks/protocol.go b/proxy/shadowsocks/protocol.go index 69db3ba2..6c2f7c65 100644 --- a/proxy/shadowsocks/protocol.go +++ b/proxy/shadowsocks/protocol.go @@ -173,7 +173,7 @@ func WriteTCPRequest(request *protocol.RequestHeader, writer io.Writer) (buf.Wri if account.Cipher.IVSize() > 0 { iv = make([]byte, account.Cipher.IVSize()) common.Must2(rand.Read(iv)) - if err := buf.WriteAllBytes(writer, iv); err != nil { + if err := buf.WriteAllBytes(writer, iv, nil); err != nil { return nil, newError("failed to write IV") } } @@ -218,7 +218,7 @@ func WriteTCPResponse(request *protocol.RequestHeader, writer io.Writer) (buf.Wr if account.Cipher.IVSize() > 0 { iv = make([]byte, account.Cipher.IVSize()) common.Must2(rand.Read(iv)) - if err := buf.WriteAllBytes(writer, iv); err != nil { + if err := buf.WriteAllBytes(writer, iv, nil); err != nil { return nil, newError("failed to write IV.").Base(err) } } diff --git a/proxy/shadowsocks/server.go b/proxy/shadowsocks/server.go index c34798f6..523bac9f 100644 --- a/proxy/shadowsocks/server.go +++ b/proxy/shadowsocks/server.go @@ -4,6 +4,8 @@ import ( "context" "time" + "github.com/xtls/xray-core/transport/internet/stat" + "github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common/buf" "github.com/xtls/xray-core/common/log" @@ -16,7 +18,6 @@ import ( "github.com/xtls/xray-core/core" "github.com/xtls/xray-core/features/policy" "github.com/xtls/xray-core/features/routing" - "github.com/xtls/xray-core/transport/internet" "github.com/xtls/xray-core/transport/internet/udp" ) @@ -70,7 +71,7 @@ func (s *Server) Network() []net.Network { return list } -func (s *Server) Process(ctx context.Context, network net.Network, conn internet.Connection, dispatcher routing.Dispatcher) error { +func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Connection, dispatcher routing.Dispatcher) error { switch network { case net.Network_TCP: return s.handleConnection(ctx, conn, dispatcher) @@ -81,7 +82,7 @@ func (s *Server) Process(ctx context.Context, network net.Network, conn internet } } -func (s *Server) handleUDPPayload(ctx context.Context, conn internet.Connection, dispatcher routing.Dispatcher) error { +func (s *Server) handleUDPPayload(ctx context.Context, conn stat.Connection, dispatcher routing.Dispatcher) error { udpServer := udp.NewDispatcher(dispatcher, func(ctx context.Context, packet *udp_proto.Packet) { request := protocol.RequestHeaderFromContext(ctx) if request == nil { @@ -185,7 +186,7 @@ func (s *Server) handleUDPPayload(ctx context.Context, conn internet.Connection, return nil } -func (s *Server) handleConnection(ctx context.Context, conn internet.Connection, dispatcher routing.Dispatcher) error { +func (s *Server) handleConnection(ctx context.Context, conn stat.Connection, dispatcher routing.Dispatcher) error { sessionPolicy := s.policyManager.ForLevel(0) if err := conn.SetReadDeadline(time.Now().Add(sessionPolicy.Timeouts.Handshake)); err != nil { return newError("unable to set read deadline").Base(err).AtWarning() diff --git a/proxy/socks/client.go b/proxy/socks/client.go index 77c306b5..22eb3650 100644 --- a/proxy/socks/client.go +++ b/proxy/socks/client.go @@ -4,6 +4,8 @@ import ( "context" "time" + "github.com/xtls/xray-core/transport/internet/stat" + "github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common/buf" "github.com/xtls/xray-core/common/net" @@ -59,7 +61,7 @@ func (c *Client) Process(ctx context.Context, link *transport.Link, dialer inter // Outbound server's destination. var dest net.Destination // Connection to the outbound server. - var conn internet.Connection + var conn stat.Connection if err := retry.ExponentialBackoff(5, 100).On(func() error { server = c.serverPicker.PickServer() diff --git a/proxy/socks/protocol.go b/proxy/socks/protocol.go index 29b9c913..3f371d57 100644 --- a/proxy/socks/protocol.go +++ b/proxy/socks/protocol.go @@ -293,7 +293,7 @@ func hasAuthMethod(expectedAuth byte, authCandidates []byte) bool { } func writeSocks5AuthenticationResponse(writer io.Writer, version byte, auth byte) error { - return buf.WriteAllBytes(writer, []byte{version, auth}) + return buf.WriteAllBytes(writer, []byte{version, auth}, nil) } func writeSocks5Response(writer io.Writer, errCode byte, address net.Address, port net.Port) error { @@ -305,7 +305,7 @@ func writeSocks5Response(writer io.Writer, errCode byte, address net.Address, po return err } - return buf.WriteAllBytes(writer, buffer.Bytes()) + return buf.WriteAllBytes(writer, buffer.Bytes(), nil) } func writeSocks4Response(writer io.Writer, errCode byte, address net.Address, port net.Port) error { @@ -317,7 +317,7 @@ func writeSocks4Response(writer io.Writer, errCode byte, address net.Address, po portBytes := buffer.Extend(2) binary.BigEndian.PutUint16(portBytes, port.Value()) common.Must2(buffer.Write(address.IP())) - return buf.WriteAllBytes(writer, buffer.Bytes()) + return buf.WriteAllBytes(writer, buffer.Bytes(), nil) } func DecodeUDPPacket(packet *buf.Buffer) (*protocol.RequestHeader, error) { @@ -422,7 +422,7 @@ func ClientHandshake(request *protocol.RequestHeader, reader io.Reader, writer i defer b.Release() common.Must2(b.Write([]byte{socks5Version, 0x01, authByte})) - if err := buf.WriteAllBytes(writer, b.Bytes()); err != nil { + if err := buf.WriteAllBytes(writer, b.Bytes(), nil); err != nil { return nil, err } @@ -446,7 +446,7 @@ func ClientHandshake(request *protocol.RequestHeader, reader io.Reader, writer i common.Must2(b.WriteString(account.Username)) common.Must(b.WriteByte(byte(len(account.Password)))) common.Must2(b.WriteString(account.Password)) - if err := buf.WriteAllBytes(writer, b.Bytes()); err != nil { + if err := buf.WriteAllBytes(writer, b.Bytes(), nil); err != nil { return nil, err } @@ -474,7 +474,7 @@ func ClientHandshake(request *protocol.RequestHeader, reader io.Reader, writer i } } - if err := buf.WriteAllBytes(writer, b.Bytes()); err != nil { + if err := buf.WriteAllBytes(writer, b.Bytes(), nil); err != nil { return nil, err } diff --git a/proxy/socks/server.go b/proxy/socks/server.go index 8aeb568a..0225c79c 100644 --- a/proxy/socks/server.go +++ b/proxy/socks/server.go @@ -5,6 +5,8 @@ import ( "io" "time" + "github.com/xtls/xray-core/transport/internet/stat" + "github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common/buf" "github.com/xtls/xray-core/common/log" @@ -18,7 +20,6 @@ import ( "github.com/xtls/xray-core/features" "github.com/xtls/xray-core/features/policy" "github.com/xtls/xray-core/features/routing" - "github.com/xtls/xray-core/transport/internet" "github.com/xtls/xray-core/transport/internet/udp" ) @@ -62,7 +63,7 @@ func (s *Server) Network() []net.Network { } // Process implements proxy.Inbound. -func (s *Server) Process(ctx context.Context, network net.Network, conn internet.Connection, dispatcher routing.Dispatcher) error { +func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Connection, dispatcher routing.Dispatcher) error { if inbound := session.InboundFromContext(ctx); inbound != nil { inbound.User = &protocol.MemoryUser{ Level: s.config.UserLevel, @@ -79,7 +80,7 @@ func (s *Server) Process(ctx context.Context, network net.Network, conn internet } } -func (s *Server) processTCP(ctx context.Context, conn internet.Connection, dispatcher routing.Dispatcher) error { +func (s *Server) processTCP(ctx context.Context, conn stat.Connection, dispatcher routing.Dispatcher) error { plcy := s.policy() if err := conn.SetReadDeadline(time.Now().Add(plcy.Timeouts.Handshake)); err != nil { newError("failed to set deadline").Base(err).WriteToLog(session.ExportIDToError(ctx)) @@ -191,7 +192,7 @@ func (s *Server) transport(ctx context.Context, reader io.Reader, writer io.Writ return nil } -func (s *Server) handleUDPPayload(ctx context.Context, conn internet.Connection, dispatcher routing.Dispatcher) error { +func (s *Server) handleUDPPayload(ctx context.Context, conn stat.Connection, dispatcher routing.Dispatcher) error { udpServer := udp.NewDispatcher(dispatcher, func(ctx context.Context, packet *udp_proto.Packet) { payload := packet.Payload newError("writing back UDP response with ", payload.Len(), " bytes").AtDebug().WriteToLog(session.ExportIDToError(ctx)) diff --git a/proxy/trojan/client.go b/proxy/trojan/client.go index 272c2f04..cc674d83 100644 --- a/proxy/trojan/client.go +++ b/proxy/trojan/client.go @@ -5,6 +5,8 @@ import ( "syscall" "time" + "github.com/xtls/xray-core/transport/internet/stat" + "github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common/buf" "github.com/xtls/xray-core/common/errors" @@ -61,7 +63,7 @@ func (c *Client) Process(ctx context.Context, link *transport.Link, dialer inter network := destination.Network var server *protocol.ServerSpec - var conn internet.Connection + var conn stat.Connection err := retry.ExponentialBackoff(5, 100).On(func() error { server = c.serverPicker.PickServer() @@ -81,7 +83,7 @@ func (c *Client) Process(ctx context.Context, link *transport.Link, dialer inter defer conn.Close() iConn := conn - statConn, ok := iConn.(*internet.StatCouterConnection) + statConn, ok := iConn.(*stat.CounterConnection) if ok { iConn = statConn.Connection } diff --git a/proxy/trojan/protocol.go b/proxy/trojan/protocol.go index 1273b0a2..0befb927 100644 --- a/proxy/trojan/protocol.go +++ b/proxy/trojan/protocol.go @@ -8,6 +8,8 @@ import ( "runtime" "syscall" + "github.com/xtls/xray-core/transport/internet/stat" + "github.com/xtls/xray-core/common/buf" "github.com/xtls/xray-core/common/errors" "github.com/xtls/xray-core/common/net" @@ -15,7 +17,6 @@ import ( "github.com/xtls/xray-core/common/session" "github.com/xtls/xray-core/common/signal" "github.com/xtls/xray-core/features/stats" - "github.com/xtls/xray-core/transport/internet" "github.com/xtls/xray-core/transport/internet/xtls" ) @@ -298,7 +299,7 @@ func ReadV(reader buf.Reader, writer buf.Writer, timer signal.ActivityUpdater, c if sctx != nil { if inbound := session.InboundFromContext(sctx); inbound != nil && inbound.Conn != nil { iConn := inbound.Conn - statConn, ok := iConn.(*internet.StatCouterConnection) + statConn, ok := iConn.(*stat.CounterConnection) if ok { iConn = statConn.Connection } @@ -325,7 +326,7 @@ func ReadV(reader buf.Reader, writer buf.Writer, timer signal.ActivityUpdater, c //panic("XTLS Splice: nil inbound or nil inbound.Conn") } } - reader = buf.NewReadVReader(conn.Connection, rawConn) + reader = buf.NewReadVReader(conn.Connection, rawConn, nil) ct = counter if conn.SHOW { fmt.Println(conn.MARK, "ReadV") diff --git a/proxy/trojan/server.go b/proxy/trojan/server.go index 53056307..0642b764 100644 --- a/proxy/trojan/server.go +++ b/proxy/trojan/server.go @@ -9,6 +9,8 @@ import ( "syscall" "time" + "github.com/xtls/xray-core/transport/internet/stat" + "github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common/buf" "github.com/xtls/xray-core/common/errors" @@ -25,7 +27,6 @@ import ( "github.com/xtls/xray-core/features/policy" "github.com/xtls/xray-core/features/routing" "github.com/xtls/xray-core/features/stats" - "github.com/xtls/xray-core/transport/internet" "github.com/xtls/xray-core/transport/internet/udp" "github.com/xtls/xray-core/transport/internet/xtls" ) @@ -141,11 +142,11 @@ func (s *Server) Network() []net.Network { } // Process implements proxy.Inbound.Process(). -func (s *Server) Process(ctx context.Context, network net.Network, conn internet.Connection, dispatcher routing.Dispatcher) error { +func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Connection, dispatcher routing.Dispatcher) error { sid := session.ExportIDToError(ctx) iConn := conn - statConn, ok := iConn.(*internet.StatCouterConnection) + statConn, ok := iConn.(*stat.CounterConnection) if ok { iConn = statConn.Connection } @@ -343,7 +344,7 @@ func (s *Server) handleUDPPayload(ctx context.Context, clientReader *PacketReade func (s *Server) handleConnection(ctx context.Context, sessionPolicy policy.Session, destination net.Destination, clientReader buf.Reader, - clientWriter buf.Writer, dispatcher routing.Dispatcher, iConn internet.Connection, rawConn syscall.RawConn, statConn *internet.StatCouterConnection) error { + clientWriter buf.Writer, dispatcher routing.Dispatcher, iConn stat.Connection, rawConn syscall.RawConn, statConn *stat.CounterConnection) error { ctx, cancel := context.WithCancel(ctx) timer := signal.CancelAfterInactivity(ctx, cancel, sessionPolicy.Timeouts.ConnectionIdle) ctx = policy.ContextWithBufferPolicy(ctx, sessionPolicy.Buffer) @@ -391,7 +392,7 @@ func (s *Server) handleConnection(ctx context.Context, sessionPolicy policy.Sess return nil } -func (s *Server) fallback(ctx context.Context, sid errors.ExportOption, err error, sessionPolicy policy.Session, connection internet.Connection, iConn internet.Connection, napfb map[string]map[string]map[string]*Fallback, first *buf.Buffer, firstLen int64, reader buf.Reader) error { +func (s *Server) fallback(ctx context.Context, sid errors.ExportOption, err error, sessionPolicy policy.Session, connection stat.Connection, iConn stat.Connection, napfb map[string]map[string]map[string]*Fallback, first *buf.Buffer, firstLen int64, reader buf.Reader) error { if err := connection.SetReadDeadline(time.Time{}); err != nil { newError("unable to set back read deadline").Base(err).AtWarning().WriteToLog(sid) } diff --git a/proxy/vless/encoding/encoding.go b/proxy/vless/encoding/encoding.go index 77646e6b..eb7999ce 100644 --- a/proxy/vless/encoding/encoding.go +++ b/proxy/vless/encoding/encoding.go @@ -9,6 +9,8 @@ import ( "runtime" "syscall" + "github.com/xtls/xray-core/transport/internet/stat" + "github.com/xtls/xray-core/common/buf" "github.com/xtls/xray-core/common/errors" "github.com/xtls/xray-core/common/net" @@ -17,7 +19,6 @@ import ( "github.com/xtls/xray-core/common/signal" "github.com/xtls/xray-core/features/stats" "github.com/xtls/xray-core/proxy/vless" - "github.com/xtls/xray-core/transport/internet" "github.com/xtls/xray-core/transport/internet/xtls" ) @@ -185,7 +186,7 @@ func ReadV(reader buf.Reader, writer buf.Writer, timer signal.ActivityUpdater, c if sctx != nil { if inbound := session.InboundFromContext(sctx); inbound != nil && inbound.Conn != nil { iConn := inbound.Conn - statConn, ok := iConn.(*internet.StatCouterConnection) + statConn, ok := iConn.(*stat.CounterConnection) if ok { iConn = statConn.Connection } @@ -212,7 +213,7 @@ func ReadV(reader buf.Reader, writer buf.Writer, timer signal.ActivityUpdater, c //panic("XTLS Splice: nil inbound or nil inbound.Conn") } } - reader = buf.NewReadVReader(conn.Connection, rawConn) + reader = buf.NewReadVReader(conn.Connection, rawConn, nil) ct = counter if conn.SHOW { fmt.Println(conn.MARK, "ReadV") diff --git a/proxy/vless/inbound/inbound.go b/proxy/vless/inbound/inbound.go index 73097526..5b21a109 100644 --- a/proxy/vless/inbound/inbound.go +++ b/proxy/vless/inbound/inbound.go @@ -10,6 +10,8 @@ import ( "syscall" "time" + "github.com/xtls/xray-core/transport/internet/stat" + "github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common/buf" "github.com/xtls/xray-core/common/errors" @@ -29,7 +31,6 @@ import ( "github.com/xtls/xray-core/features/stats" "github.com/xtls/xray-core/proxy/vless" "github.com/xtls/xray-core/proxy/vless/encoding" - "github.com/xtls/xray-core/transport/internet" "github.com/xtls/xray-core/transport/internet/tls" "github.com/xtls/xray-core/transport/internet/xtls" ) @@ -172,11 +173,11 @@ func (*Handler) Network() []net.Network { } // Process implements proxy.Inbound.Process(). -func (h *Handler) Process(ctx context.Context, network net.Network, connection internet.Connection, dispatcher routing.Dispatcher) error { +func (h *Handler) Process(ctx context.Context, network net.Network, connection stat.Connection, dispatcher routing.Dispatcher) error { sid := session.ExportIDToError(ctx) iConn := connection - statConn, ok := iConn.(*internet.StatCouterConnection) + statConn, ok := iConn.(*stat.CounterConnection) if ok { iConn = statConn.Connection } diff --git a/proxy/vless/outbound/outbound.go b/proxy/vless/outbound/outbound.go index 750a3940..ffb36837 100644 --- a/proxy/vless/outbound/outbound.go +++ b/proxy/vless/outbound/outbound.go @@ -7,6 +7,8 @@ import ( "syscall" "time" + "github.com/xtls/xray-core/transport/internet/stat" + "github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common/buf" "github.com/xtls/xray-core/common/net" @@ -77,7 +79,7 @@ func New(ctx context.Context, config *Config) (*Handler, error) { // Process implements proxy.Outbound.Process(). func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer internet.Dialer) error { var rec *protocol.ServerSpec - var conn internet.Connection + var conn stat.Connection if err := retry.ExponentialBackoff(5, 200).On(func() error { rec = h.serverPicker.PickServer() @@ -93,7 +95,7 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte defer conn.Close() iConn := conn - statConn, ok := iConn.(*internet.StatCouterConnection) + statConn, ok := iConn.(*stat.CounterConnection) if ok { iConn = statConn.Connection } diff --git a/proxy/vmess/inbound/inbound.go b/proxy/vmess/inbound/inbound.go index 3448326b..056f32e6 100644 --- a/proxy/vmess/inbound/inbound.go +++ b/proxy/vmess/inbound/inbound.go @@ -9,6 +9,8 @@ import ( "sync" "time" + "github.com/xtls/xray-core/transport/internet/stat" + "github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common/buf" "github.com/xtls/xray-core/common/errors" @@ -26,7 +28,6 @@ import ( "github.com/xtls/xray-core/features/routing" "github.com/xtls/xray-core/proxy/vmess" "github.com/xtls/xray-core/proxy/vmess/encoding" - "github.com/xtls/xray-core/transport/internet" ) var ( @@ -220,14 +221,14 @@ func isInsecureEncryption(s protocol.SecurityType) bool { } // Process implements proxy.Inbound.Process(). -func (h *Handler) Process(ctx context.Context, network net.Network, connection internet.Connection, dispatcher routing.Dispatcher) error { +func (h *Handler) Process(ctx context.Context, network net.Network, connection stat.Connection, dispatcher routing.Dispatcher) error { sessionPolicy := h.policyManager.ForLevel(0) if err := connection.SetReadDeadline(time.Now().Add(sessionPolicy.Timeouts.Handshake)); err != nil { return newError("unable to set read deadline").Base(err).AtWarning() } iConn := connection - if statConn, ok := iConn.(*internet.StatCouterConnection); ok { + if statConn, ok := iConn.(*stat.CounterConnection); ok { iConn = statConn.Connection } _, isDrain := iConn.(*net.TCPConn) diff --git a/proxy/vmess/outbound/outbound.go b/proxy/vmess/outbound/outbound.go index 48462e4a..fbdca3c5 100644 --- a/proxy/vmess/outbound/outbound.go +++ b/proxy/vmess/outbound/outbound.go @@ -6,6 +6,8 @@ import ( "context" "time" + "github.com/xtls/xray-core/transport/internet/stat" + "github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common/buf" "github.com/xtls/xray-core/common/net" @@ -57,7 +59,7 @@ func New(ctx context.Context, config *Config) (*Handler, error) { // Process implements proxy.Outbound.Process(). func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer internet.Dialer) error { var rec *protocol.ServerSpec - var conn internet.Connection + var conn stat.Connection err := retry.ExponentialBackoff(5, 200).On(func() error { rec = h.serverPicker.PickServer() diff --git a/testing/mocks/proxy.go b/testing/mocks/proxy.go index cba5b3ca..19c36aae 100644 --- a/testing/mocks/proxy.go +++ b/testing/mocks/proxy.go @@ -6,12 +6,14 @@ package mocks import ( context "context" + reflect "reflect" + gomock "github.com/golang/mock/gomock" net "github.com/xtls/xray-core/common/net" routing "github.com/xtls/xray-core/features/routing" transport "github.com/xtls/xray-core/transport" internet "github.com/xtls/xray-core/transport/internet" - reflect "reflect" + "github.com/xtls/xray-core/transport/internet/stat" ) // ProxyInbound is a mock of Inbound interface @@ -52,7 +54,7 @@ func (mr *ProxyInboundMockRecorder) Network() *gomock.Call { } // Process mocks base method -func (m *ProxyInbound) Process(arg0 context.Context, arg1 net.Network, arg2 internet.Connection, arg3 routing.Dispatcher) error { +func (m *ProxyInbound) Process(arg0 context.Context, arg1 net.Network, arg2 stat.Connection, arg3 routing.Dispatcher) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Process", arg0, arg1, arg2, arg3) ret0, _ := ret[0].(error) diff --git a/transport/internet/dialer.go b/transport/internet/dialer.go index 68c88752..fda06a52 100644 --- a/transport/internet/dialer.go +++ b/transport/internet/dialer.go @@ -5,6 +5,8 @@ import ( "github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common/dice" + "github.com/xtls/xray-core/transport/internet/stat" + "github.com/xtls/xray-core/common/net" "github.com/xtls/xray-core/common/net/cnc" "github.com/xtls/xray-core/common/session" @@ -17,14 +19,14 @@ import ( // Dialer is the interface for dialing outbound connections. type Dialer interface { // Dial dials a system connection to the given destination. - Dial(ctx context.Context, destination net.Destination) (Connection, error) + Dial(ctx context.Context, destination net.Destination) (stat.Connection, error) // Address returns the address used by this Dialer. Maybe nil if not known. Address() net.Address } // dialFunc is an interface to dial network connection to a specific destination. -type dialFunc func(ctx context.Context, dest net.Destination, streamSettings *MemoryStreamConfig) (Connection, error) +type dialFunc func(ctx context.Context, dest net.Destination, streamSettings *MemoryStreamConfig) (stat.Connection, error) var ( transportDialerCache = make(map[string]dialFunc) @@ -40,7 +42,7 @@ func RegisterTransportDialer(protocol string, dialer dialFunc) error { } // Dial dials a internet connection towards the given destination. -func Dial(ctx context.Context, dest net.Destination, streamSettings *MemoryStreamConfig) (Connection, error) { +func Dial(ctx context.Context, dest net.Destination, streamSettings *MemoryStreamConfig) (stat.Connection, error) { if dest.Network == net.Network_TCP { if streamSettings == nil { s, err := ToMemoryStreamConfig(nil) diff --git a/transport/internet/domainsocket/dial.go b/transport/internet/domainsocket/dial.go index 91e1941b..a539b769 100644 --- a/transport/internet/domainsocket/dial.go +++ b/transport/internet/domainsocket/dial.go @@ -6,6 +6,8 @@ package domainsocket import ( "context" + "github.com/xtls/xray-core/transport/internet/stat" + "github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common/net" "github.com/xtls/xray-core/transport/internet" @@ -13,7 +15,7 @@ import ( "github.com/xtls/xray-core/transport/internet/xtls" ) -func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (internet.Connection, error) { +func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (stat.Connection, error) { settings := streamSettings.ProtocolSettings.(*Config) addr, err := settings.GetUnixAddr() if err != nil { diff --git a/transport/internet/domainsocket/listener.go b/transport/internet/domainsocket/listener.go index 377a0223..492a8cb4 100644 --- a/transport/internet/domainsocket/listener.go +++ b/transport/internet/domainsocket/listener.go @@ -9,6 +9,8 @@ import ( "os" "strings" + "github.com/xtls/xray-core/transport/internet/stat" + goxtls "github.com/xtls/go" "golang.org/x/sys/unix" @@ -98,7 +100,7 @@ func (ln *Listener) run() { conn = xtls.Server(conn, ln.xtlsConfig) } - ln.addConn(internet.Connection(conn)) + ln.addConn(stat.Connection(conn)) } } diff --git a/transport/internet/domainsocket/listener_test.go b/transport/internet/domainsocket/listener_test.go index 8c14604f..0f04e240 100644 --- a/transport/internet/domainsocket/listener_test.go +++ b/transport/internet/domainsocket/listener_test.go @@ -8,6 +8,8 @@ import ( "runtime" "testing" + "github.com/xtls/xray-core/transport/internet/stat" + "github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common/buf" "github.com/xtls/xray-core/common/net" @@ -23,7 +25,7 @@ func TestListen(t *testing.T) { Path: "/tmp/ts3", }, } - listener, err := Listen(ctx, nil, net.Port(0), streamSettings, func(conn internet.Connection) { + listener, err := Listen(ctx, nil, net.Port(0), streamSettings, func(conn stat.Connection) { defer conn.Close() b := buf.New() @@ -64,7 +66,7 @@ func TestListenAbstract(t *testing.T) { Abstract: true, }, } - listener, err := Listen(ctx, nil, net.Port(0), streamSettings, func(conn internet.Connection) { + listener, err := Listen(ctx, nil, net.Port(0), streamSettings, func(conn stat.Connection) { defer conn.Close() b := buf.New() diff --git a/transport/internet/grpc/dial.go b/transport/internet/grpc/dial.go index 4d69816b..3b1c6c36 100644 --- a/transport/internet/grpc/dial.go +++ b/transport/internet/grpc/dial.go @@ -17,17 +17,18 @@ import ( "github.com/xtls/xray-core/common/session" "github.com/xtls/xray-core/transport/internet" "github.com/xtls/xray-core/transport/internet/grpc/encoding" + "github.com/xtls/xray-core/transport/internet/stat" "github.com/xtls/xray-core/transport/internet/tls" ) -func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (internet.Connection, error) { +func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (stat.Connection, error) { newError("creating connection to ", dest).WriteToLog(session.ExportIDToError(ctx)) conn, err := dialgRPC(ctx, dest, streamSettings) if err != nil { return nil, newError("failed to dial gRPC").Base(err) } - return internet.Connection(conn), nil + return stat.Connection(conn), nil } func init() { diff --git a/transport/internet/headers/http/http.go b/transport/internet/headers/http/http.go index 9e7f0f60..b9afb230 100644 --- a/transport/internet/headers/http/http.go +++ b/transport/internet/headers/http/http.go @@ -151,7 +151,7 @@ func (w *HeaderWriter) Write(writer io.Writer) error { if w.header == nil { return nil } - err := buf.WriteAllBytes(writer, w.header.Bytes()) + err := buf.WriteAllBytes(writer, w.header.Bytes(), nil) w.header.Release() w.header = nil return err diff --git a/transport/internet/http/dialer.go b/transport/internet/http/dialer.go index db3be736..60034b7b 100644 --- a/transport/internet/http/dialer.go +++ b/transport/internet/http/dialer.go @@ -8,6 +8,8 @@ import ( "sync" "time" + "github.com/xtls/xray-core/transport/internet/stat" + "github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common/buf" "github.com/xtls/xray-core/common/net" @@ -110,7 +112,7 @@ func getHTTPClient(ctx context.Context, dest net.Destination, streamSettings *in } // Dial dials a new TCP connection to the given destination. -func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (internet.Connection, error) { +func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (stat.Connection, error) { httpSettings := streamSettings.ProtocolSettings.(*Config) client, err := getHTTPClient(ctx, dest, streamSettings) if err != nil { diff --git a/transport/internet/http/http_test.go b/transport/internet/http/http_test.go index ca1d9f2b..91aab580 100644 --- a/transport/internet/http/http_test.go +++ b/transport/internet/http/http_test.go @@ -6,6 +6,8 @@ import ( "testing" "time" + "github.com/xtls/xray-core/transport/internet/stat" + "github.com/google/go-cmp/cmp" "github.com/xtls/xray-core/common" @@ -28,7 +30,7 @@ func TestHTTPConnection(t *testing.T) { SecuritySettings: &tls.Config{ Certificate: []*tls.Certificate{tls.ParseCertificate(cert.MustGenerate(nil, cert.CommonName("www.example.com")))}, }, - }, func(conn internet.Connection) { + }, func(conn stat.Connection) { go func() { defer conn.Close() diff --git a/transport/internet/kcp/dialer.go b/transport/internet/kcp/dialer.go index 9a60efaf..cd30b6ff 100644 --- a/transport/internet/kcp/dialer.go +++ b/transport/internet/kcp/dialer.go @@ -5,6 +5,8 @@ import ( "io" "sync/atomic" + "github.com/xtls/xray-core/transport/internet/stat" + "github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common/buf" "github.com/xtls/xray-core/common/dice" @@ -46,7 +48,7 @@ func fetchInput(_ context.Context, input io.Reader, reader PacketReader, conn *C } // DialKCP dials a new KCP connections to the specific destination. -func DialKCP(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (internet.Connection, error) { +func DialKCP(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (stat.Connection, error) { dest.Network = net.Network_UDP newError("dialing mKCP to ", dest).WriteToLog() @@ -84,7 +86,7 @@ func DialKCP(ctx context.Context, dest net.Destination, streamSettings *internet go fetchInput(ctx, rawConn, reader, session) - var iConn internet.Connection = session + var iConn stat.Connection = session if config := tls.ConfigFromStreamSettings(streamSettings); config != nil { iConn = tls.Client(iConn, config.GetTLSConfig(tls.WithDestination(dest))) diff --git a/transport/internet/kcp/kcp_test.go b/transport/internet/kcp/kcp_test.go index c6086bc5..5346c0ae 100644 --- a/transport/internet/kcp/kcp_test.go +++ b/transport/internet/kcp/kcp_test.go @@ -7,6 +7,8 @@ import ( "testing" "time" + "github.com/xtls/xray-core/transport/internet/stat" + "github.com/google/go-cmp/cmp" "golang.org/x/sync/errgroup" @@ -21,8 +23,8 @@ func TestDialAndListen(t *testing.T) { listerner, err := NewListener(context.Background(), net.LocalHostIP, net.Port(0), &internet.MemoryStreamConfig{ ProtocolName: "mkcp", ProtocolSettings: &Config{}, - }, func(conn internet.Connection) { - go func(c internet.Connection) { + }, func(conn stat.Connection) { + go func(c stat.Connection) { payload := make([]byte, 4096) for { nBytes, err := c.Read(payload) diff --git a/transport/internet/kcp/listener.go b/transport/internet/kcp/listener.go index 0c1e1790..e6f67315 100644 --- a/transport/internet/kcp/listener.go +++ b/transport/internet/kcp/listener.go @@ -6,6 +6,8 @@ import ( gotls "crypto/tls" "sync" + "github.com/xtls/xray-core/transport/internet/stat" + goxtls "github.com/xtls/go" "github.com/xtls/xray-core/common" @@ -134,7 +136,7 @@ func (l *Listener) OnReceive(payload *buf.Buffer, src net.Destination) { Security: l.security, Writer: writer, }, writer, l.config) - var netConn internet.Connection = conn + var netConn stat.Connection = conn if l.tlsConfig != nil { netConn = tls.Server(conn, l.tlsConfig) } else if l.xtlsConfig != nil { diff --git a/transport/internet/quic/dialer.go b/transport/internet/quic/dialer.go index a1d5488f..3c5dd0ec 100644 --- a/transport/internet/quic/dialer.go +++ b/transport/internet/quic/dialer.go @@ -5,6 +5,8 @@ import ( "sync" "time" + "github.com/xtls/xray-core/transport/internet/stat" + "github.com/lucas-clemente/quic-go" "github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common/net" @@ -114,7 +116,7 @@ func (s *clientSessions) cleanSessions() error { return nil } -func (s *clientSessions) openConnection(destAddr net.Addr, config *Config, tlsConfig *tls.Config, sockopt *internet.SocketConfig) (internet.Connection, error) { +func (s *clientSessions) openConnection(destAddr net.Addr, config *Config, tlsConfig *tls.Config, sockopt *internet.SocketConfig) (stat.Connection, error) { s.access.Lock() defer s.access.Unlock() @@ -182,7 +184,7 @@ func init() { common.Must(client.cleanup.Start()) } -func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (internet.Connection, error) { +func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (stat.Connection, error) { tlsConfig := tls.ConfigFromStreamSettings(streamSettings) if tlsConfig == nil { tlsConfig = &tls.Config{ diff --git a/transport/internet/quic/quic_test.go b/transport/internet/quic/quic_test.go index bf353d83..3354f9cc 100644 --- a/transport/internet/quic/quic_test.go +++ b/transport/internet/quic/quic_test.go @@ -6,6 +6,8 @@ import ( "testing" "time" + "github.com/xtls/xray-core/transport/internet/stat" + "github.com/google/go-cmp/cmp" "github.com/xtls/xray-core/common" @@ -37,7 +39,7 @@ func TestQuicConnection(t *testing.T) { ), }, }, - }, func(conn internet.Connection) { + }, func(conn stat.Connection) { go func() { defer conn.Close() @@ -100,7 +102,7 @@ func TestQuicConnectionWithoutTLS(t *testing.T) { listener, err := quic.Listen(context.Background(), net.LocalHostIP, port, &internet.MemoryStreamConfig{ ProtocolName: "quic", ProtocolSettings: &quic.Config{}, - }, func(conn internet.Connection) { + }, func(conn stat.Connection) { go func() { defer conn.Close() @@ -164,7 +166,7 @@ func TestQuicConnectionAuthHeader(t *testing.T) { Type: protocol.SecurityType_AES128_GCM, }, }, - }, func(conn internet.Connection) { + }, func(conn stat.Connection) { go func() { defer conn.Close() diff --git a/transport/internet/connection.go b/transport/internet/stat/connection.go similarity index 71% rename from transport/internet/connection.go rename to transport/internet/stat/connection.go index 4c04f89e..6921943d 100644 --- a/transport/internet/connection.go +++ b/transport/internet/stat/connection.go @@ -1,4 +1,4 @@ -package internet +package stat import ( "net" @@ -10,13 +10,13 @@ type Connection interface { net.Conn } -type StatCouterConnection struct { +type CounterConnection struct { Connection ReadCounter stats.Counter WriteCounter stats.Counter } -func (c *StatCouterConnection) Read(b []byte) (int, error) { +func (c *CounterConnection) Read(b []byte) (int, error) { nBytes, err := c.Connection.Read(b) if c.ReadCounter != nil { c.ReadCounter.Add(int64(nBytes)) @@ -25,7 +25,7 @@ func (c *StatCouterConnection) Read(b []byte) (int, error) { return nBytes, err } -func (c *StatCouterConnection) Write(b []byte) (int, error) { +func (c *CounterConnection) Write(b []byte) (int, error) { nBytes, err := c.Connection.Write(b) if c.WriteCounter != nil { c.WriteCounter.Add(int64(nBytes)) diff --git a/transport/internet/tcp/dialer.go b/transport/internet/tcp/dialer.go index 968acbeb..3d75a438 100644 --- a/transport/internet/tcp/dialer.go +++ b/transport/internet/tcp/dialer.go @@ -3,6 +3,8 @@ package tcp import ( "context" + "github.com/xtls/xray-core/transport/internet/stat" + "github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common/net" "github.com/xtls/xray-core/common/session" @@ -12,7 +14,7 @@ import ( ) // Dial dials a new TCP connection to the given destination. -func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (internet.Connection, error) { +func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (stat.Connection, error) { newError("dialing TCP to ", dest).WriteToLog(session.ExportIDToError(ctx)) conn, err := internet.DialSystem(ctx, dest, streamSettings.SocketSettings) if err != nil { @@ -46,7 +48,7 @@ func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.Me } conn = auth.Client(conn) } - return internet.Connection(conn), nil + return stat.Connection(conn), nil } func init() { diff --git a/transport/internet/tcp/hub.go b/transport/internet/tcp/hub.go index 02042c3c..841638a8 100644 --- a/transport/internet/tcp/hub.go +++ b/transport/internet/tcp/hub.go @@ -6,6 +6,8 @@ import ( "strings" "time" + "github.com/xtls/xray-core/transport/internet/stat" + goxtls "github.com/xtls/go" "github.com/xtls/xray-core/common" @@ -120,7 +122,7 @@ func (v *Listener) keepAccepting() { conn = v.authConfig.Server(conn) } - v.addConn(internet.Connection(conn)) + v.addConn(stat.Connection(conn)) } } diff --git a/transport/internet/tcp/sockopt_freebsd.go b/transport/internet/tcp/sockopt_freebsd.go index 2a87c9e6..277c3ae2 100644 --- a/transport/internet/tcp/sockopt_freebsd.go +++ b/transport/internet/tcp/sockopt_freebsd.go @@ -5,10 +5,11 @@ package tcp import ( "github.com/xtls/xray-core/common/net" "github.com/xtls/xray-core/transport/internet" + "github.com/xtls/xray-core/transport/internet/stat" ) // GetOriginalDestination from tcp conn -func GetOriginalDestination(conn internet.Connection) (net.Destination, error) { +func GetOriginalDestination(conn stat.Connection) (net.Destination, error) { la := conn.LocalAddr() ra := conn.RemoteAddr() ip, port, err := internet.OriginalDst(la, ra) diff --git a/transport/internet/tcp/sockopt_linux.go b/transport/internet/tcp/sockopt_linux.go index ebe53123..783cd5fe 100644 --- a/transport/internet/tcp/sockopt_linux.go +++ b/transport/internet/tcp/sockopt_linux.go @@ -6,13 +6,14 @@ import ( "syscall" "unsafe" + "github.com/xtls/xray-core/transport/internet/stat" + "github.com/xtls/xray-core/common/net" - "github.com/xtls/xray-core/transport/internet" ) const SO_ORIGINAL_DST = 80 -func GetOriginalDestination(conn internet.Connection) (net.Destination, error) { +func GetOriginalDestination(conn stat.Connection) (net.Destination, error) { sysrawconn, f := conn.(syscall.Conn) if !f { return net.Destination{}, newError("unable to get syscall.Conn") diff --git a/transport/internet/tcp/sockopt_other.go b/transport/internet/tcp/sockopt_other.go index ca64e03d..f1ad3f8b 100644 --- a/transport/internet/tcp/sockopt_other.go +++ b/transport/internet/tcp/sockopt_other.go @@ -4,9 +4,9 @@ package tcp import ( "github.com/xtls/xray-core/common/net" - "github.com/xtls/xray-core/transport/internet" + "github.com/xtls/xray-core/transport/internet/stat" ) -func GetOriginalDestination(conn internet.Connection) (net.Destination, error) { +func GetOriginalDestination(conn stat.Connection) (net.Destination, error) { return net.Destination{}, nil } diff --git a/transport/internet/tcp_hub.go b/transport/internet/tcp_hub.go index fb5562dd..ffb81e95 100644 --- a/transport/internet/tcp_hub.go +++ b/transport/internet/tcp_hub.go @@ -3,6 +3,8 @@ package internet import ( "context" + "github.com/xtls/xray-core/transport/internet/stat" + "github.com/xtls/xray-core/common/net" ) @@ -18,7 +20,7 @@ func RegisterTransportListener(protocol string, listener ListenFunc) error { return nil } -type ConnHandler func(Connection) +type ConnHandler func(stat.Connection) type ListenFunc func(ctx context.Context, address net.Address, port net.Port, settings *MemoryStreamConfig, handler ConnHandler) (Listener, error) diff --git a/transport/internet/udp/dialer.go b/transport/internet/udp/dialer.go index 28306112..6babe3af 100644 --- a/transport/internet/udp/dialer.go +++ b/transport/internet/udp/dialer.go @@ -3,6 +3,8 @@ package udp import ( "context" + "github.com/xtls/xray-core/transport/internet/stat" + "github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common/net" "github.com/xtls/xray-core/transport/internet" @@ -10,7 +12,7 @@ import ( func init() { common.Must(internet.RegisterTransportDialer(protocolName, - func(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (internet.Connection, error) { + func(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (stat.Connection, error) { var sockopt *internet.SocketConfig if streamSettings != nil { sockopt = streamSettings.SocketSettings @@ -20,6 +22,6 @@ func init() { return nil, err } // TODO: handle dialer options - return internet.Connection(conn), nil + return stat.Connection(conn), nil })) } diff --git a/transport/internet/websocket/dialer.go b/transport/internet/websocket/dialer.go index 3dc56cc7..fd23177d 100644 --- a/transport/internet/websocket/dialer.go +++ b/transport/internet/websocket/dialer.go @@ -10,6 +10,8 @@ import ( "os" "time" + "github.com/xtls/xray-core/transport/internet/stat" + "github.com/gorilla/websocket" "github.com/xtls/xray-core/common" @@ -41,7 +43,7 @@ func init() { } // Dial dials a WebSocket connection to the given destination. -func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (internet.Connection, error) { +func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (stat.Connection, error) { newError("creating connection to ", dest).WriteToLog(session.ExportIDToError(ctx)) var conn net.Conn if streamSettings.ProtocolSettings.(*Config).Ed > 0 { @@ -59,7 +61,7 @@ func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.Me return nil, newError("failed to dial WebSocket").Base(err) } } - return internet.Connection(conn), nil + return stat.Connection(conn), nil } func init() { diff --git a/transport/internet/websocket/ws_test.go b/transport/internet/websocket/ws_test.go index 15a11449..6e82109d 100644 --- a/transport/internet/websocket/ws_test.go +++ b/transport/internet/websocket/ws_test.go @@ -6,6 +6,8 @@ import ( "testing" "time" + "github.com/xtls/xray-core/transport/internet/stat" + "github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common/net" "github.com/xtls/xray-core/common/protocol/tls/cert" @@ -20,8 +22,8 @@ func Test_listenWSAndDial(t *testing.T) { ProtocolSettings: &Config{ Path: "ws", }, - }, func(conn internet.Connection) { - go func(c internet.Connection) { + }, func(conn stat.Connection) { + go func(c stat.Connection) { defer c.Close() var b [1024]byte @@ -75,8 +77,8 @@ func TestDialWithRemoteAddr(t *testing.T) { ProtocolSettings: &Config{ Path: "ws", }, - }, func(conn internet.Connection) { - go func(c internet.Connection) { + }, func(conn stat.Connection) { + go func(c stat.Connection) { defer c.Close() var b [1024]byte @@ -129,7 +131,7 @@ func Test_listenWSAndDial_TLS(t *testing.T) { Certificate: []*tls.Certificate{tls.ParseCertificate(cert.MustGenerate(nil, cert.CommonName("localhost")))}, }, } - listen, err := ListenWS(context.Background(), net.LocalHostIP, 13143, streamSettings, func(conn internet.Connection) { + listen, err := ListenWS(context.Background(), net.LocalHostIP, 13143, streamSettings, func(conn stat.Connection) { go func() { _ = conn.Close() }()