mirror of
https://github.com/XTLS/Xray-core.git
synced 2025-04-29 16:58:34 +00:00
Fix: CounterConnection with ReadV/WriteV (#720)
Co-authored-by: JimhHan <50871214+JimhHan@users.noreply.github.com>
This commit is contained in:
parent
f2cb13a8ec
commit
24b637cd5e
53 changed files with 247 additions and 128 deletions
|
@ -6,6 +6,9 @@ import (
|
|||
"os"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/xtls/xray-core/features/stats"
|
||||
"github.com/xtls/xray-core/transport/internet/stat"
|
||||
)
|
||||
|
||||
// Reader extends io.Reader with MultiBuffer.
|
||||
|
@ -29,9 +32,17 @@ type Writer interface {
|
|||
}
|
||||
|
||||
// WriteAllBytes ensures all bytes are written into the given writer.
|
||||
func WriteAllBytes(writer io.Writer, payload []byte) error {
|
||||
func WriteAllBytes(writer io.Writer, payload []byte, c stats.Counter) error {
|
||||
wc := 0
|
||||
defer func() {
|
||||
if c != nil {
|
||||
c.Add(int64(wc))
|
||||
}
|
||||
}()
|
||||
|
||||
for len(payload) > 0 {
|
||||
n, err := writer.Write(payload)
|
||||
wc += n
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -60,12 +71,18 @@ func NewReader(reader io.Reader) Reader {
|
|||
|
||||
_, isFile := reader.(*os.File)
|
||||
if !isFile && useReadv {
|
||||
var counter stats.Counter
|
||||
|
||||
if statConn, ok := reader.(*stat.CounterConnection); ok {
|
||||
reader = statConn.Connection
|
||||
counter = statConn.ReadCounter
|
||||
}
|
||||
if sc, ok := reader.(syscall.Conn); ok {
|
||||
rawConn, err := sc.SyscallConn()
|
||||
if err != nil {
|
||||
newError("failed to get sysconn").Base(err).WriteToLog()
|
||||
} else {
|
||||
return NewReadVReader(reader, rawConn)
|
||||
return NewReadVReader(reader, rawConn, counter)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -104,13 +121,24 @@ func NewWriter(writer io.Writer) Writer {
|
|||
return mw
|
||||
}
|
||||
|
||||
if isPacketWriter(writer) {
|
||||
var iConn = writer
|
||||
if statConn, ok := writer.(*stat.CounterConnection); ok {
|
||||
iConn = statConn.Connection
|
||||
}
|
||||
|
||||
if isPacketWriter(iConn) {
|
||||
return &SequentialWriter{
|
||||
Writer: writer,
|
||||
}
|
||||
}
|
||||
|
||||
var counter stats.Counter
|
||||
|
||||
if statConn, ok := writer.(*stat.CounterConnection); ok {
|
||||
counter = statConn.WriteCounter
|
||||
}
|
||||
return &BufferToBytesWriter{
|
||||
Writer: writer,
|
||||
Writer: iConn,
|
||||
counter: counter,
|
||||
}
|
||||
}
|
||||
|
|
|
@ -6,6 +6,8 @@ import (
|
|||
"io"
|
||||
"syscall"
|
||||
|
||||
"github.com/xtls/xray-core/features/stats"
|
||||
|
||||
"github.com/xtls/xray-core/common/platform"
|
||||
)
|
||||
|
||||
|
@ -53,17 +55,19 @@ type ReadVReader struct {
|
|||
rawConn syscall.RawConn
|
||||
mr multiReader
|
||||
alloc allocStrategy
|
||||
counter stats.Counter
|
||||
}
|
||||
|
||||
// NewReadVReader creates a new ReadVReader.
|
||||
func NewReadVReader(reader io.Reader, rawConn syscall.RawConn) *ReadVReader {
|
||||
func NewReadVReader(reader io.Reader, rawConn syscall.RawConn, counter stats.Counter) *ReadVReader {
|
||||
return &ReadVReader{
|
||||
Reader: reader,
|
||||
rawConn: rawConn,
|
||||
alloc: allocStrategy{
|
||||
current: 1,
|
||||
},
|
||||
mr: newMultiReader(),
|
||||
mr: newMultiReader(),
|
||||
counter: counter,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -122,10 +126,16 @@ func (r *ReadVReader) ReadMultiBuffer() (MultiBuffer, error) {
|
|||
if b.IsFull() {
|
||||
r.alloc.Adjust(1)
|
||||
}
|
||||
if r.counter != nil && b != nil {
|
||||
r.counter.Add(int64(b.Len()))
|
||||
}
|
||||
return MultiBuffer{b}, err
|
||||
}
|
||||
|
||||
mb, err := r.readMulti()
|
||||
if r.counter != nil && mb != nil {
|
||||
r.counter.Add(int64(mb.Len()))
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -50,7 +50,7 @@ func TestReadvReader(t *testing.T) {
|
|||
rawConn, err := conn.(*net.TCPConn).SyscallConn()
|
||||
common.Must(err)
|
||||
|
||||
reader := NewReadVReader(conn, rawConn)
|
||||
reader := NewReadVReader(conn, rawConn, nil)
|
||||
var rmb MultiBuffer
|
||||
for {
|
||||
mb, err := reader.ReadMultiBuffer()
|
||||
|
|
|
@ -5,6 +5,8 @@ import (
|
|||
"net"
|
||||
"sync"
|
||||
|
||||
"github.com/xtls/xray-core/features/stats"
|
||||
|
||||
"github.com/xtls/xray-core/common"
|
||||
"github.com/xtls/xray-core/common/errors"
|
||||
)
|
||||
|
@ -13,7 +15,8 @@ import (
|
|||
type BufferToBytesWriter struct {
|
||||
io.Writer
|
||||
|
||||
cache [][]byte
|
||||
counter stats.Counter
|
||||
cache [][]byte
|
||||
}
|
||||
|
||||
// WriteMultiBuffer implements Writer. This method takes ownership of the given buffer.
|
||||
|
@ -26,7 +29,7 @@ func (w *BufferToBytesWriter) WriteMultiBuffer(mb MultiBuffer) error {
|
|||
}
|
||||
|
||||
if len(mb) == 1 {
|
||||
return WriteAllBytes(w.Writer, mb[0].Bytes())
|
||||
return WriteAllBytes(w.Writer, mb[0].Bytes(), w.counter)
|
||||
}
|
||||
|
||||
if cap(w.cache) < len(mb) {
|
||||
|
@ -45,9 +48,15 @@ func (w *BufferToBytesWriter) WriteMultiBuffer(mb MultiBuffer) error {
|
|||
}()
|
||||
|
||||
nb := net.Buffers(bs)
|
||||
|
||||
wc := int64(0)
|
||||
defer func() {
|
||||
if w.counter != nil {
|
||||
w.counter.Add(wc)
|
||||
}
|
||||
}()
|
||||
for size > 0 {
|
||||
n, err := nb.WriteTo(w.Writer)
|
||||
wc += n
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -173,7 +182,7 @@ func (w *BufferedWriter) flushInternal() error {
|
|||
w.buffer = nil
|
||||
|
||||
if writer, ok := w.writer.(io.Writer); ok {
|
||||
err := WriteAllBytes(writer, b.Bytes())
|
||||
err := WriteAllBytes(writer, b.Bytes(), nil)
|
||||
b.Release()
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -50,7 +50,7 @@ func NewCryptionWriter(stream cipher.Stream, writer io.Writer) *CryptionWriter {
|
|||
func (w *CryptionWriter) Write(data []byte) (int, error) {
|
||||
w.stream.XORKeyStream(data, data)
|
||||
|
||||
if err := buf.WriteAllBytes(w.writer, data); err != nil {
|
||||
if err := buf.WriteAllBytes(w.writer, data, nil); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return len(data), nil
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue