diff --git a/proxy/proxy.go b/proxy/proxy.go index 1e4c69f5..a3d3fccb 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -110,7 +110,8 @@ type TrafficState struct { // reader link state WithinPaddingBuffers bool - ReaderSwitchToDirectCopy bool + DownlinkReaderDirectCopy bool + UplinkReaderDirectCopy bool RemainingCommand int32 RemainingContent int32 RemainingPadding int32 @@ -118,7 +119,8 @@ type TrafficState struct { // write link state IsPadding bool - WriterSwitchToDirectCopy bool + DownlinkWriterDirectCopy bool + UplinkWriterDirectCopy bool } func NewTrafficState(userUUID []byte) *TrafficState { @@ -131,13 +133,15 @@ func NewTrafficState(userUUID []byte) *TrafficState { Cipher: 0, RemainingServerHello: -1, WithinPaddingBuffers: true, - ReaderSwitchToDirectCopy: false, + DownlinkReaderDirectCopy: false, + UplinkReaderDirectCopy: false, RemainingCommand: -1, RemainingContent: -1, RemainingPadding: -1, CurrentCommand: 0, IsPadding: true, - WriterSwitchToDirectCopy: false, + DownlinkWriterDirectCopy: false, + UplinkWriterDirectCopy: false, } } @@ -147,13 +151,15 @@ type VisionReader struct { buf.Reader trafficState *TrafficState ctx context.Context + isUplink bool } -func NewVisionReader(reader buf.Reader, state *TrafficState, context context.Context) *VisionReader { +func NewVisionReader(reader buf.Reader, state *TrafficState, isUplink bool, context context.Context) *VisionReader { return &VisionReader{ Reader: reader, trafficState: state, ctx: context, + isUplink: isUplink, } } @@ -175,7 +181,11 @@ func (w *VisionReader) ReadMultiBuffer() (buf.MultiBuffer, error) { w.trafficState.WithinPaddingBuffers = false } else if w.trafficState.CurrentCommand == 2 { w.trafficState.WithinPaddingBuffers = false - w.trafficState.ReaderSwitchToDirectCopy = true + if w.isUplink { + w.trafficState.UplinkReaderDirectCopy = true + } else { + w.trafficState.DownlinkReaderDirectCopy = true + } } else { errors.LogInfo(w.ctx, "XtlsRead unknown command ", w.trafficState.CurrentCommand, buffer.Len()) } @@ -194,9 +204,10 @@ type VisionWriter struct { trafficState *TrafficState ctx context.Context writeOnceUserUUID []byte + isUplink bool } -func NewVisionWriter(writer buf.Writer, state *TrafficState, context context.Context) *VisionWriter { +func NewVisionWriter(writer buf.Writer, state *TrafficState, isUplink bool, context context.Context) *VisionWriter { w := make([]byte, len(state.UserUUID)) copy(w, state.UserUUID) return &VisionWriter{ @@ -204,6 +215,7 @@ func NewVisionWriter(writer buf.Writer, state *TrafficState, context context.Con trafficState: state, ctx: context, writeOnceUserUUID: w, + isUplink: isUplink, } } @@ -221,7 +233,11 @@ func (w *VisionWriter) WriteMultiBuffer(mb buf.MultiBuffer) error { for i, b := range mb { if w.trafficState.IsTLS && b.Len() >= 6 && bytes.Equal(TlsApplicationDataStart, b.BytesTo(3)) { if w.trafficState.EnableXtls { - w.trafficState.WriterSwitchToDirectCopy = true + if w.isUplink { + w.trafficState.UplinkWriterDirectCopy = true + } else { + w.trafficState.DownlinkWriterDirectCopy = true + } } var command byte = CommandPaddingContinue if i == len(mb)-1 { diff --git a/proxy/vless/encoding/addons.go b/proxy/vless/encoding/addons.go index 1bf1817d..4474e3c9 100644 --- a/proxy/vless/encoding/addons.go +++ b/proxy/vless/encoding/addons.go @@ -61,13 +61,13 @@ func DecodeHeaderAddons(buffer *buf.Buffer, reader io.Reader) (*Addons, error) { } // EncodeBodyAddons returns a Writer that auto-encrypt content written by caller. -func EncodeBodyAddons(writer io.Writer, request *protocol.RequestHeader, requestAddons *Addons, state *proxy.TrafficState, context context.Context) buf.Writer { +func EncodeBodyAddons(writer io.Writer, request *protocol.RequestHeader, requestAddons *Addons, state *proxy.TrafficState, isUplink bool, context context.Context) buf.Writer { if request.Command == protocol.RequestCommandUDP { return NewMultiLengthPacketWriter(writer.(buf.Writer)) } w := buf.NewWriter(writer) if requestAddons.Flow == vless.XRV { - w = proxy.NewVisionWriter(w, state, context) + w = proxy.NewVisionWriter(w, state, isUplink, context) } return w } diff --git a/proxy/vless/encoding/encoding.go b/proxy/vless/encoding/encoding.go index 8b067a96..3fce3290 100644 --- a/proxy/vless/encoding/encoding.go +++ b/proxy/vless/encoding/encoding.go @@ -172,10 +172,10 @@ func DecodeResponseHeader(reader io.Reader, request *protocol.RequestHeader) (*A } // XtlsRead filter and read xtls protocol -func XtlsRead(reader buf.Reader, writer buf.Writer, timer *signal.ActivityTimer, conn net.Conn, input *bytes.Reader, rawInput *bytes.Buffer, trafficState *proxy.TrafficState, ob *session.Outbound, ctx context.Context) error { +func XtlsRead(reader buf.Reader, writer buf.Writer, timer *signal.ActivityTimer, conn net.Conn, input *bytes.Reader, rawInput *bytes.Buffer, trafficState *proxy.TrafficState, ob *session.Outbound, isUplink bool, ctx context.Context) error { err := func() error { for { - if trafficState.ReaderSwitchToDirectCopy { + if isUplink && trafficState.UplinkReaderDirectCopy || !isUplink && trafficState.DownlinkReaderDirectCopy { var writerConn net.Conn var inTimer *signal.ActivityTimer if inbound := session.InboundFromContext(ctx); inbound != nil && inbound.Conn != nil { @@ -193,7 +193,7 @@ func XtlsRead(reader buf.Reader, writer buf.Writer, timer *signal.ActivityTimer, buffer, err := reader.ReadMultiBuffer() if !buffer.IsEmpty() { timer.Update() - if trafficState.ReaderSwitchToDirectCopy { + if isUplink && trafficState.UplinkReaderDirectCopy || !isUplink && trafficState.DownlinkReaderDirectCopy { // XTLS Vision processes struct TLS Conn's input and rawInput if inputBuffer, err := buf.ReadFrom(input); err == nil { if !inputBuffer.IsEmpty() { @@ -222,12 +222,12 @@ func XtlsRead(reader buf.Reader, writer buf.Writer, timer *signal.ActivityTimer, } // XtlsWrite filter and write xtls protocol -func XtlsWrite(reader buf.Reader, writer buf.Writer, timer signal.ActivityUpdater, conn net.Conn, trafficState *proxy.TrafficState, ob *session.Outbound, ctx context.Context) error { +func XtlsWrite(reader buf.Reader, writer buf.Writer, timer signal.ActivityUpdater, conn net.Conn, trafficState *proxy.TrafficState, ob *session.Outbound, isUplink bool, ctx context.Context) error { err := func() error { var ct stats.Counter for { buffer, err := reader.ReadMultiBuffer() - if trafficState.WriterSwitchToDirectCopy { + if isUplink && trafficState.UplinkWriterDirectCopy || !isUplink && trafficState.DownlinkWriterDirectCopy { if inbound := session.InboundFromContext(ctx); inbound != nil { if inbound.CanSpliceCopy == 2 { inbound.CanSpliceCopy = 1 @@ -239,7 +239,11 @@ func XtlsWrite(reader buf.Reader, writer buf.Writer, timer signal.ActivityUpdate rawConn, _, writerCounter := proxy.UnwrapRawConn(conn) writer = buf.NewWriter(rawConn) ct = writerCounter - trafficState.WriterSwitchToDirectCopy = false + if isUplink { + trafficState.UplinkWriterDirectCopy = false + } else { + trafficState.DownlinkWriterDirectCopy = false + } } if !buffer.IsEmpty() { if ct != nil { diff --git a/proxy/vless/inbound/inbound.go b/proxy/vless/inbound/inbound.go index a2415a44..1da2e091 100644 --- a/proxy/vless/inbound/inbound.go +++ b/proxy/vless/inbound/inbound.go @@ -538,8 +538,8 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection s if requestAddons.Flow == vless.XRV { ctx1 := session.ContextWithInbound(ctx, nil) // TODO enable splice - clientReader = proxy.NewVisionReader(clientReader, trafficState, ctx1) - err = encoding.XtlsRead(clientReader, serverWriter, timer, connection, input, rawInput, trafficState, nil, ctx1) + clientReader = proxy.NewVisionReader(clientReader, trafficState, true, ctx1) + err = encoding.XtlsRead(clientReader, serverWriter, timer, connection, input, rawInput, trafficState, nil, true, ctx1) } else { // from clientReader.ReadMultiBuffer to serverWriter.WriteMultiBuffer err = buf.Copy(clientReader, serverWriter, buf.UpdateActivity(timer)) @@ -561,7 +561,7 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection s } // default: clientWriter := bufferWriter - clientWriter := encoding.EncodeBodyAddons(bufferWriter, request, requestAddons, trafficState, ctx) + clientWriter := encoding.EncodeBodyAddons(bufferWriter, request, requestAddons, trafficState, false, ctx) multiBuffer, err1 := serverReader.ReadMultiBuffer() if err1 != nil { return err1 // ... @@ -576,7 +576,7 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection s var err error if requestAddons.Flow == vless.XRV { - err = encoding.XtlsWrite(serverReader, clientWriter, timer, connection, trafficState, nil, ctx) + err = encoding.XtlsWrite(serverReader, clientWriter, timer, connection, trafficState, nil, false, ctx) } else { // from serverReader.ReadMultiBuffer to clientWriter.WriteMultiBuffer err = buf.Copy(serverReader, clientWriter, buf.UpdateActivity(timer)) diff --git a/proxy/vless/outbound/outbound.go b/proxy/vless/outbound/outbound.go index ed9e07dc..e1a727eb 100644 --- a/proxy/vless/outbound/outbound.go +++ b/proxy/vless/outbound/outbound.go @@ -194,7 +194,7 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte } // default: serverWriter := bufferWriter - serverWriter := encoding.EncodeBodyAddons(bufferWriter, request, requestAddons, trafficState, ctx) + serverWriter := encoding.EncodeBodyAddons(bufferWriter, request, requestAddons, trafficState, true, ctx) if request.Command == protocol.RequestCommandMux && request.Port == 666 { serverWriter = xudp.NewPacketWriter(serverWriter, target, xudp.GetGlobalID(ctx)) } @@ -234,7 +234,7 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte } } ctx1 := session.ContextWithInbound(ctx, nil) // TODO enable splice - err = encoding.XtlsWrite(clientReader, serverWriter, timer, conn, trafficState, ob, ctx1) + err = encoding.XtlsWrite(clientReader, serverWriter, timer, conn, trafficState, ob, true, ctx1) } else { // from clientReader.ReadMultiBuffer to serverWriter.WriteMultiBuffer err = buf.Copy(clientReader, serverWriter, buf.UpdateActivity(timer)) @@ -261,7 +261,7 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte // default: serverReader := buf.NewReader(conn) serverReader := encoding.DecodeBodyAddons(conn, request, responseAddons) if requestAddons.Flow == vless.XRV { - serverReader = proxy.NewVisionReader(serverReader, trafficState, ctx) + serverReader = proxy.NewVisionReader(serverReader, trafficState, false, ctx) } if request.Command == protocol.RequestCommandMux && request.Port == 666 { if requestAddons.Flow == vless.XRV { @@ -272,7 +272,7 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte } if requestAddons.Flow == vless.XRV { - err = encoding.XtlsRead(serverReader, clientWriter, timer, conn, input, rawInput, trafficState, ob, ctx) + err = encoding.XtlsRead(serverReader, clientWriter, timer, conn, input, rawInput, trafficState, ob, false, ctx) } else { // from serverReader.ReadMultiBuffer to clientWriter.WriteMultiBuffer err = buf.Copy(serverReader, clientWriter, buf.UpdateActivity(timer))