Fix: CounterConnection with ReadV/WriteV (#720)

Co-authored-by: JimhHan <50871214+JimhHan@users.noreply.github.com>
This commit is contained in:
Arthur Morgan 2021-09-20 20:11:21 +08:00 committed by GitHub
parent f2cb13a8ec
commit 24b637cd5e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
53 changed files with 247 additions and 128 deletions

View file

@ -6,6 +6,8 @@ import (
"sync/atomic"
"time"
"github.com/xtls/xray-core/transport/internet/stat"
"github.com/xtls/xray-core/app/proxyman"
"github.com/xtls/xray-core/common"
"github.com/xtls/xray-core/common/buf"
@ -54,7 +56,7 @@ func getTProxyType(s *internet.MemoryStreamConfig) internet.SocketConfig_TProxyM
return s.SocketSettings.Tproxy
}
func (w *tcpWorker) callback(conn internet.Connection) {
func (w *tcpWorker) callback(conn stat.Connection) {
ctx, cancel := context.WithCancel(w.ctx)
sid := session.NewID()
ctx = session.ContextWithID(ctx, sid)
@ -80,7 +82,7 @@ func (w *tcpWorker) callback(conn internet.Connection) {
}
if w.uplinkCounter != nil || w.downlinkCounter != nil {
conn = &internet.StatCouterConnection{
conn = &stat.CounterConnection{
Connection: conn,
ReadCounter: w.uplinkCounter,
WriteCounter: w.downlinkCounter,
@ -117,7 +119,7 @@ func (w *tcpWorker) Proxy() proxy.Inbound {
func (w *tcpWorker) Start() error {
ctx := context.Background()
hub, err := internet.ListenTCP(ctx, w.address, w.port, w.stream, func(conn internet.Connection) {
hub, err := internet.ListenTCP(ctx, w.address, w.port, w.stream, func(conn stat.Connection) {
go w.callback(conn)
})
if err != nil {
@ -436,13 +438,13 @@ type dsWorker struct {
ctx context.Context
}
func (w *dsWorker) callback(conn internet.Connection) {
func (w *dsWorker) callback(conn stat.Connection) {
ctx, cancel := context.WithCancel(w.ctx)
sid := session.NewID()
ctx = session.ContextWithID(ctx, sid)
if w.uplinkCounter != nil || w.downlinkCounter != nil {
conn = &internet.StatCouterConnection{
conn = &stat.CounterConnection{
Connection: conn,
ReadCounter: w.uplinkCounter,
WriteCounter: w.downlinkCounter,
@ -482,7 +484,7 @@ func (w *dsWorker) Port() net.Port {
}
func (w *dsWorker) Start() error {
ctx := context.Background()
hub, err := internet.ListenUnix(ctx, w.address, w.stream, func(conn internet.Connection) {
hub, err := internet.ListenUnix(ctx, w.address, w.stream, func(conn stat.Connection) {
go w.callback(conn)
})
if err != nil {

View file

@ -3,6 +3,8 @@ package outbound
import (
"context"
"github.com/xtls/xray-core/transport/internet/stat"
"github.com/xtls/xray-core/app/proxyman"
"github.com/xtls/xray-core/common"
"github.com/xtls/xray-core/common/mux"
@ -158,7 +160,7 @@ func (h *Handler) Address() net.Address {
}
// Dial implements internet.Dialer.
func (h *Handler) Dial(ctx context.Context, dest net.Destination) (internet.Connection, error) {
func (h *Handler) Dial(ctx context.Context, dest net.Destination) (stat.Connection, error) {
if h.senderSettings != nil {
if h.senderSettings.ProxySettings.HasTag() {
tag := h.senderSettings.ProxySettings.Tag
@ -201,9 +203,9 @@ func (h *Handler) Dial(ctx context.Context, dest net.Destination) (internet.Conn
return h.getStatCouterConnection(conn), err
}
func (h *Handler) getStatCouterConnection(conn internet.Connection) internet.Connection {
func (h *Handler) getStatCouterConnection(conn stat.Connection) stat.Connection {
if h.uplinkCounter != nil || h.downlinkCounter != nil {
return &internet.StatCouterConnection{
return &stat.CounterConnection{
Connection: conn,
ReadCounter: h.downlinkCounter,
WriteCounter: h.uplinkCounter,

View file

@ -4,6 +4,8 @@ import (
"context"
"testing"
"github.com/xtls/xray-core/transport/internet/stat"
"github.com/xtls/xray-core/app/policy"
. "github.com/xtls/xray-core/app/proxyman/outbound"
"github.com/xtls/xray-core/app/stats"
@ -12,7 +14,6 @@ import (
core "github.com/xtls/xray-core/core"
"github.com/xtls/xray-core/features/outbound"
"github.com/xtls/xray-core/proxy/freedom"
"github.com/xtls/xray-core/transport/internet"
)
func TestInterfaces(t *testing.T) {
@ -44,9 +45,9 @@ func TestOutboundWithoutStatCounter(t *testing.T) {
ProxySettings: serial.ToTypedMessage(&freedom.Config{}),
})
conn, _ := h.(*Handler).Dial(ctx, net.TCPDestination(net.DomainAddress("localhost"), 13146))
_, ok := conn.(*internet.StatCouterConnection)
_, ok := conn.(*stat.CounterConnection)
if ok {
t.Errorf("Expected conn to not be StatCouterConnection")
t.Errorf("Expected conn to not be CounterConnection")
}
}
@ -73,8 +74,8 @@ func TestOutboundWithStatCounter(t *testing.T) {
ProxySettings: serial.ToTypedMessage(&freedom.Config{}),
})
conn, _ := h.(*Handler).Dial(ctx, net.TCPDestination(net.DomainAddress("localhost"), 13146))
_, ok := conn.(*internet.StatCouterConnection)
_, ok := conn.(*stat.CounterConnection)
if !ok {
t.Errorf("Expected conn to be StatCouterConnection")
t.Errorf("Expected conn to be CounterConnection")
}
}