Add separate uplink/downlink flag for direct copy

For each connection, xtls need 4 flags for uplink/downlink reader/writer to decide when it switch to direct copy.
In the past, there were only one for read and one for write.
If service has xtls inbound and xtls outbound, the two flags may be corrupted by signal from different directions.
This commit is contained in:
yuhan6665 2025-01-26 12:17:17 -05:00
parent 7b59379d73
commit 009037b1e2
5 changed files with 44 additions and 24 deletions

View File

@ -110,7 +110,8 @@ type TrafficState struct {
// reader link state
WithinPaddingBuffers bool
ReaderSwitchToDirectCopy bool
DownlinkReaderDirectCopy bool
UplinkReaderDirectCopy bool
RemainingCommand int32
RemainingContent int32
RemainingPadding int32
@ -118,7 +119,8 @@ type TrafficState struct {
// write link state
IsPadding bool
WriterSwitchToDirectCopy bool
DownlinkWriterDirectCopy bool
UplinkWriterDirectCopy bool
}
func NewTrafficState(userUUID []byte) *TrafficState {
@ -131,13 +133,15 @@ func NewTrafficState(userUUID []byte) *TrafficState {
Cipher: 0,
RemainingServerHello: -1,
WithinPaddingBuffers: true,
ReaderSwitchToDirectCopy: false,
DownlinkReaderDirectCopy: false,
UplinkReaderDirectCopy: false,
RemainingCommand: -1,
RemainingContent: -1,
RemainingPadding: -1,
CurrentCommand: 0,
IsPadding: true,
WriterSwitchToDirectCopy: false,
DownlinkWriterDirectCopy: false,
UplinkWriterDirectCopy: false,
}
}
@ -147,13 +151,15 @@ type VisionReader struct {
buf.Reader
trafficState *TrafficState
ctx context.Context
isUplink bool
}
func NewVisionReader(reader buf.Reader, state *TrafficState, context context.Context) *VisionReader {
func NewVisionReader(reader buf.Reader, state *TrafficState, isUplink bool, context context.Context) *VisionReader {
return &VisionReader{
Reader: reader,
trafficState: state,
ctx: context,
isUplink: isUplink,
}
}
@ -175,7 +181,11 @@ func (w *VisionReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
w.trafficState.WithinPaddingBuffers = false
} else if w.trafficState.CurrentCommand == 2 {
w.trafficState.WithinPaddingBuffers = false
w.trafficState.ReaderSwitchToDirectCopy = true
if w.isUplink {
w.trafficState.UplinkReaderDirectCopy = true
} else {
w.trafficState.DownlinkReaderDirectCopy = true
}
} else {
errors.LogInfo(w.ctx, "XtlsRead unknown command ", w.trafficState.CurrentCommand, buffer.Len())
}
@ -194,9 +204,10 @@ type VisionWriter struct {
trafficState *TrafficState
ctx context.Context
writeOnceUserUUID []byte
isUplink bool
}
func NewVisionWriter(writer buf.Writer, state *TrafficState, context context.Context) *VisionWriter {
func NewVisionWriter(writer buf.Writer, state *TrafficState, isUplink bool, context context.Context) *VisionWriter {
w := make([]byte, len(state.UserUUID))
copy(w, state.UserUUID)
return &VisionWriter{
@ -204,6 +215,7 @@ func NewVisionWriter(writer buf.Writer, state *TrafficState, context context.Con
trafficState: state,
ctx: context,
writeOnceUserUUID: w,
isUplink: isUplink,
}
}
@ -221,7 +233,11 @@ func (w *VisionWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
for i, b := range mb {
if w.trafficState.IsTLS && b.Len() >= 6 && bytes.Equal(TlsApplicationDataStart, b.BytesTo(3)) {
if w.trafficState.EnableXtls {
w.trafficState.WriterSwitchToDirectCopy = true
if w.isUplink {
w.trafficState.UplinkWriterDirectCopy = true
} else {
w.trafficState.DownlinkWriterDirectCopy = true
}
}
var command byte = CommandPaddingContinue
if i == len(mb)-1 {

View File

@ -61,13 +61,13 @@ func DecodeHeaderAddons(buffer *buf.Buffer, reader io.Reader) (*Addons, error) {
}
// EncodeBodyAddons returns a Writer that auto-encrypt content written by caller.
func EncodeBodyAddons(writer io.Writer, request *protocol.RequestHeader, requestAddons *Addons, state *proxy.TrafficState, context context.Context) buf.Writer {
func EncodeBodyAddons(writer io.Writer, request *protocol.RequestHeader, requestAddons *Addons, state *proxy.TrafficState, isUplink bool, context context.Context) buf.Writer {
if request.Command == protocol.RequestCommandUDP {
return NewMultiLengthPacketWriter(writer.(buf.Writer))
}
w := buf.NewWriter(writer)
if requestAddons.Flow == vless.XRV {
w = proxy.NewVisionWriter(w, state, context)
w = proxy.NewVisionWriter(w, state, isUplink, context)
}
return w
}

View File

@ -172,10 +172,10 @@ func DecodeResponseHeader(reader io.Reader, request *protocol.RequestHeader) (*A
}
// XtlsRead filter and read xtls protocol
func XtlsRead(reader buf.Reader, writer buf.Writer, timer *signal.ActivityTimer, conn net.Conn, input *bytes.Reader, rawInput *bytes.Buffer, trafficState *proxy.TrafficState, ob *session.Outbound, ctx context.Context) error {
func XtlsRead(reader buf.Reader, writer buf.Writer, timer *signal.ActivityTimer, conn net.Conn, input *bytes.Reader, rawInput *bytes.Buffer, trafficState *proxy.TrafficState, ob *session.Outbound, isUplink bool, ctx context.Context) error {
err := func() error {
for {
if trafficState.ReaderSwitchToDirectCopy {
if isUplink && trafficState.UplinkReaderDirectCopy || !isUplink && trafficState.DownlinkReaderDirectCopy {
var writerConn net.Conn
var inTimer *signal.ActivityTimer
if inbound := session.InboundFromContext(ctx); inbound != nil && inbound.Conn != nil {
@ -193,7 +193,7 @@ func XtlsRead(reader buf.Reader, writer buf.Writer, timer *signal.ActivityTimer,
buffer, err := reader.ReadMultiBuffer()
if !buffer.IsEmpty() {
timer.Update()
if trafficState.ReaderSwitchToDirectCopy {
if isUplink && trafficState.UplinkReaderDirectCopy || !isUplink && trafficState.DownlinkReaderDirectCopy {
// XTLS Vision processes struct TLS Conn's input and rawInput
if inputBuffer, err := buf.ReadFrom(input); err == nil {
if !inputBuffer.IsEmpty() {
@ -222,12 +222,12 @@ func XtlsRead(reader buf.Reader, writer buf.Writer, timer *signal.ActivityTimer,
}
// XtlsWrite filter and write xtls protocol
func XtlsWrite(reader buf.Reader, writer buf.Writer, timer signal.ActivityUpdater, conn net.Conn, trafficState *proxy.TrafficState, ob *session.Outbound, ctx context.Context) error {
func XtlsWrite(reader buf.Reader, writer buf.Writer, timer signal.ActivityUpdater, conn net.Conn, trafficState *proxy.TrafficState, ob *session.Outbound, isUplink bool, ctx context.Context) error {
err := func() error {
var ct stats.Counter
for {
buffer, err := reader.ReadMultiBuffer()
if trafficState.WriterSwitchToDirectCopy {
if isUplink && trafficState.UplinkWriterDirectCopy || !isUplink && trafficState.DownlinkWriterDirectCopy {
if inbound := session.InboundFromContext(ctx); inbound != nil {
if inbound.CanSpliceCopy == 2 {
inbound.CanSpliceCopy = 1
@ -239,7 +239,11 @@ func XtlsWrite(reader buf.Reader, writer buf.Writer, timer signal.ActivityUpdate
rawConn, _, writerCounter := proxy.UnwrapRawConn(conn)
writer = buf.NewWriter(rawConn)
ct = writerCounter
trafficState.WriterSwitchToDirectCopy = false
if isUplink {
trafficState.UplinkWriterDirectCopy = false
} else {
trafficState.DownlinkWriterDirectCopy = false
}
}
if !buffer.IsEmpty() {
if ct != nil {

View File

@ -538,8 +538,8 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection s
if requestAddons.Flow == vless.XRV {
ctx1 := session.ContextWithInbound(ctx, nil) // TODO enable splice
clientReader = proxy.NewVisionReader(clientReader, trafficState, ctx1)
err = encoding.XtlsRead(clientReader, serverWriter, timer, connection, input, rawInput, trafficState, nil, ctx1)
clientReader = proxy.NewVisionReader(clientReader, trafficState, true, ctx1)
err = encoding.XtlsRead(clientReader, serverWriter, timer, connection, input, rawInput, trafficState, nil, true, ctx1)
} else {
// from clientReader.ReadMultiBuffer to serverWriter.WriteMultiBuffer
err = buf.Copy(clientReader, serverWriter, buf.UpdateActivity(timer))
@ -561,7 +561,7 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection s
}
// default: clientWriter := bufferWriter
clientWriter := encoding.EncodeBodyAddons(bufferWriter, request, requestAddons, trafficState, ctx)
clientWriter := encoding.EncodeBodyAddons(bufferWriter, request, requestAddons, trafficState, false, ctx)
multiBuffer, err1 := serverReader.ReadMultiBuffer()
if err1 != nil {
return err1 // ...
@ -576,7 +576,7 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection s
var err error
if requestAddons.Flow == vless.XRV {
err = encoding.XtlsWrite(serverReader, clientWriter, timer, connection, trafficState, nil, ctx)
err = encoding.XtlsWrite(serverReader, clientWriter, timer, connection, trafficState, nil, false, ctx)
} else {
// from serverReader.ReadMultiBuffer to clientWriter.WriteMultiBuffer
err = buf.Copy(serverReader, clientWriter, buf.UpdateActivity(timer))

View File

@ -194,7 +194,7 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte
}
// default: serverWriter := bufferWriter
serverWriter := encoding.EncodeBodyAddons(bufferWriter, request, requestAddons, trafficState, ctx)
serverWriter := encoding.EncodeBodyAddons(bufferWriter, request, requestAddons, trafficState, true, ctx)
if request.Command == protocol.RequestCommandMux && request.Port == 666 {
serverWriter = xudp.NewPacketWriter(serverWriter, target, xudp.GetGlobalID(ctx))
}
@ -234,7 +234,7 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte
}
}
ctx1 := session.ContextWithInbound(ctx, nil) // TODO enable splice
err = encoding.XtlsWrite(clientReader, serverWriter, timer, conn, trafficState, ob, ctx1)
err = encoding.XtlsWrite(clientReader, serverWriter, timer, conn, trafficState, ob, true, ctx1)
} else {
// from clientReader.ReadMultiBuffer to serverWriter.WriteMultiBuffer
err = buf.Copy(clientReader, serverWriter, buf.UpdateActivity(timer))
@ -261,7 +261,7 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte
// default: serverReader := buf.NewReader(conn)
serverReader := encoding.DecodeBodyAddons(conn, request, responseAddons)
if requestAddons.Flow == vless.XRV {
serverReader = proxy.NewVisionReader(serverReader, trafficState, ctx)
serverReader = proxy.NewVisionReader(serverReader, trafficState, false, ctx)
}
if request.Command == protocol.RequestCommandMux && request.Port == 666 {
if requestAddons.Flow == vless.XRV {
@ -272,7 +272,7 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte
}
if requestAddons.Flow == vless.XRV {
err = encoding.XtlsRead(serverReader, clientWriter, timer, conn, input, rawInput, trafficState, ob, ctx)
err = encoding.XtlsRead(serverReader, clientWriter, timer, conn, input, rawInput, trafficState, ob, false, ctx)
} else {
// from serverReader.ReadMultiBuffer to clientWriter.WriteMultiBuffer
err = buf.Copy(serverReader, clientWriter, buf.UpdateActivity(timer))