From eef74b2c7dd6d988b8baab6b5d15062e6497f372 Mon Sep 17 00:00:00 2001 From: yuhan6665 <1588741+yuhan6665@users.noreply.github.com> Date: Tue, 18 Feb 2025 03:37:52 -0500 Subject: [PATCH] XTLS: More separate uplink/downlink flags for splice copy (#4407) - In 03131c72dbbfc13ba4ce8e1f9f65f43f3dda7372 new flags were added for uplink/downlink, but that was not suffcient - Now that the traffic state contains all possible info - Each inbound and outbound is responsible to set their own CanSpliceCopy flag. Note that this also open up more splice usage. E.g. socks in -> freedom out - Fixes https://github.com/XTLS/Xray-core/issues/4033 --- proxy/http/client.go | 1 + proxy/http/server.go | 1 + proxy/proxy.go | 175 +++++++++++++++++++++---------- proxy/socks/client.go | 2 + proxy/socks/server.go | 2 + proxy/vless/encoding/encoding.go | 18 ++-- 6 files changed, 132 insertions(+), 67 deletions(-) diff --git a/proxy/http/client.go b/proxy/http/client.go index 862ca418..b1326bec 100644 --- a/proxy/http/client.go +++ b/proxy/http/client.go @@ -151,6 +151,7 @@ func (c *Client) Process(ctx context.Context, link *transport.Link, dialer inter return buf.Copy(link.Reader, buf.NewWriter(conn), buf.UpdateActivity(timer)) } responseFunc := func() error { + ob.CanSpliceCopy = 1 defer timer.SetTimeout(p.Timeouts.UplinkOnly) return buf.Copy(buf.NewReader(conn), link.Writer, buf.UpdateActivity(timer)) } diff --git a/proxy/http/server.go b/proxy/http/server.go index 01216513..24708e69 100644 --- a/proxy/http/server.go +++ b/proxy/http/server.go @@ -207,6 +207,7 @@ func (s *Server) handleConnect(ctx context.Context, _ *http.Request, reader *buf } responseDone := func() error { + inbound.CanSpliceCopy = 1 defer timer.SetTimeout(plcy.Timeouts.UplinkOnly) v2writer := buf.NewWriter(conn) diff --git a/proxy/proxy.go b/proxy/proxy.go index a3d3fccb..a90aa340 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -107,19 +107,33 @@ type TrafficState struct { IsTLS bool Cipher uint16 RemainingServerHello int32 + Inbound InboundState + Outbound OutboundState +} +type InboundState struct { // reader link state WithinPaddingBuffers bool - DownlinkReaderDirectCopy bool UplinkReaderDirectCopy bool RemainingCommand int32 RemainingContent int32 RemainingPadding int32 CurrentCommand int - // write link state IsPadding bool DownlinkWriterDirectCopy bool +} + +type OutboundState struct { + // reader link state + WithinPaddingBuffers bool + DownlinkReaderDirectCopy bool + RemainingCommand int32 + RemainingContent int32 + RemainingPadding int32 + CurrentCommand int + // write link state + IsPadding bool UplinkWriterDirectCopy bool } @@ -132,16 +146,26 @@ func NewTrafficState(userUUID []byte) *TrafficState { IsTLS: false, Cipher: 0, RemainingServerHello: -1, - WithinPaddingBuffers: true, - DownlinkReaderDirectCopy: false, - UplinkReaderDirectCopy: false, - RemainingCommand: -1, - RemainingContent: -1, - RemainingPadding: -1, - CurrentCommand: 0, - IsPadding: true, - DownlinkWriterDirectCopy: false, - UplinkWriterDirectCopy: false, + Inbound: InboundState{ + WithinPaddingBuffers: true, + UplinkReaderDirectCopy: false, + RemainingCommand: -1, + RemainingContent: -1, + RemainingPadding: -1, + CurrentCommand: 0, + IsPadding: true, + DownlinkWriterDirectCopy: false, + }, + Outbound: OutboundState{ + WithinPaddingBuffers: true, + DownlinkReaderDirectCopy: false, + RemainingCommand: -1, + RemainingContent: -1, + RemainingPadding: -1, + CurrentCommand: 0, + IsPadding: true, + UplinkWriterDirectCopy: false, + }, } } @@ -166,28 +190,43 @@ func NewVisionReader(reader buf.Reader, state *TrafficState, isUplink bool, cont func (w *VisionReader) ReadMultiBuffer() (buf.MultiBuffer, error) { buffer, err := w.Reader.ReadMultiBuffer() if !buffer.IsEmpty() { - if w.trafficState.WithinPaddingBuffers || w.trafficState.NumberOfPacketToFilter > 0 { + var withinPaddingBuffers *bool + var remainingContent *int32 + var remainingPadding *int32 + var currentCommand *int + var switchToDirectCopy *bool + if w.isUplink { + withinPaddingBuffers = &w.trafficState.Inbound.WithinPaddingBuffers + remainingContent = &w.trafficState.Inbound.RemainingContent + remainingPadding = &w.trafficState.Inbound.RemainingPadding + currentCommand = &w.trafficState.Inbound.CurrentCommand + switchToDirectCopy = &w.trafficState.Inbound.UplinkReaderDirectCopy + } else { + withinPaddingBuffers = &w.trafficState.Outbound.WithinPaddingBuffers + remainingContent = &w.trafficState.Outbound.RemainingContent + remainingPadding = &w.trafficState.Outbound.RemainingPadding + currentCommand = &w.trafficState.Outbound.CurrentCommand + switchToDirectCopy = &w.trafficState.Outbound.DownlinkReaderDirectCopy + } + + if *withinPaddingBuffers || w.trafficState.NumberOfPacketToFilter > 0 { mb2 := make(buf.MultiBuffer, 0, len(buffer)) for _, b := range buffer { - newbuffer := XtlsUnpadding(b, w.trafficState, w.ctx) + newbuffer := XtlsUnpadding(b, w.trafficState, w.isUplink, w.ctx) if newbuffer.Len() > 0 { mb2 = append(mb2, newbuffer) } } buffer = mb2 - if w.trafficState.RemainingContent > 0 || w.trafficState.RemainingPadding > 0 || w.trafficState.CurrentCommand == 0 { - w.trafficState.WithinPaddingBuffers = true - } else if w.trafficState.CurrentCommand == 1 { - w.trafficState.WithinPaddingBuffers = false - } else if w.trafficState.CurrentCommand == 2 { - w.trafficState.WithinPaddingBuffers = false - if w.isUplink { - w.trafficState.UplinkReaderDirectCopy = true - } else { - w.trafficState.DownlinkReaderDirectCopy = true - } + if *remainingContent > 0 || *remainingPadding > 0 || *currentCommand == 0 { + *withinPaddingBuffers = true + } else if *currentCommand == 1 { + *withinPaddingBuffers = false + } else if *currentCommand == 2 { + *withinPaddingBuffers = false + *switchToDirectCopy = true } else { - errors.LogInfo(w.ctx, "XtlsRead unknown command ", w.trafficState.CurrentCommand, buffer.Len()) + errors.LogInfo(w.ctx, "XtlsRead unknown command ", *currentCommand, buffer.Len()) } } if w.trafficState.NumberOfPacketToFilter > 0 { @@ -223,7 +262,16 @@ func (w *VisionWriter) WriteMultiBuffer(mb buf.MultiBuffer) error { if w.trafficState.NumberOfPacketToFilter > 0 { XtlsFilterTls(mb, w.trafficState, w.ctx) } - if w.trafficState.IsPadding { + var isPadding *bool + var switchToDirectCopy *bool + if w.isUplink { + isPadding = &w.trafficState.Outbound.IsPadding + switchToDirectCopy = &w.trafficState.Outbound.UplinkWriterDirectCopy + } else { + isPadding = &w.trafficState.Inbound.IsPadding + switchToDirectCopy = &w.trafficState.Inbound.DownlinkWriterDirectCopy + } + if *isPadding { if len(mb) == 1 && mb[0] == nil { mb[0] = XtlsPadding(nil, CommandPaddingContinue, &w.writeOnceUserUUID, true, w.ctx) // we do a long padding to hide vless header return w.Writer.WriteMultiBuffer(mb) @@ -233,11 +281,7 @@ func (w *VisionWriter) WriteMultiBuffer(mb buf.MultiBuffer) error { for i, b := range mb { if w.trafficState.IsTLS && b.Len() >= 6 && bytes.Equal(TlsApplicationDataStart, b.BytesTo(3)) { if w.trafficState.EnableXtls { - if w.isUplink { - w.trafficState.UplinkWriterDirectCopy = true - } else { - w.trafficState.DownlinkWriterDirectCopy = true - } + *switchToDirectCopy = true } var command byte = CommandPaddingContinue if i == len(mb)-1 { @@ -247,16 +291,16 @@ func (w *VisionWriter) WriteMultiBuffer(mb buf.MultiBuffer) error { } } mb[i] = XtlsPadding(b, command, &w.writeOnceUserUUID, true, w.ctx) - w.trafficState.IsPadding = false // padding going to end + *isPadding = false // padding going to end longPadding = false continue } else if !w.trafficState.IsTLS12orAbove && w.trafficState.NumberOfPacketToFilter <= 1 { // For compatibility with earlier vision receiver, we finish padding 1 packet early - w.trafficState.IsPadding = false + *isPadding = false mb[i] = XtlsPadding(b, CommandPaddingEnd, &w.writeOnceUserUUID, longPadding, w.ctx) break } var command byte = CommandPaddingContinue - if i == len(mb)-1 && !w.trafficState.IsPadding { + if i == len(mb)-1 && !*isPadding { command = CommandPaddingEnd if w.trafficState.EnableXtls { command = CommandPaddingDirect @@ -343,38 +387,53 @@ func XtlsPadding(b *buf.Buffer, command byte, userUUID *[]byte, longPadding bool } // XtlsUnpadding remove padding and parse command -func XtlsUnpadding(b *buf.Buffer, s *TrafficState, ctx context.Context) *buf.Buffer { - if s.RemainingCommand == -1 && s.RemainingContent == -1 && s.RemainingPadding == -1 { // initial state +func XtlsUnpadding(b *buf.Buffer, s *TrafficState, isUplink bool, ctx context.Context) *buf.Buffer { + var remainingCommand *int32 + var remainingContent *int32 + var remainingPadding *int32 + var currentCommand *int + if isUplink { + remainingCommand = &s.Inbound.RemainingCommand + remainingContent = &s.Inbound.RemainingContent + remainingPadding = &s.Inbound.RemainingPadding + currentCommand = &s.Inbound.CurrentCommand + } else { + remainingCommand = &s.Outbound.RemainingCommand + remainingContent = &s.Outbound.RemainingContent + remainingPadding = &s.Outbound.RemainingPadding + currentCommand = &s.Outbound.CurrentCommand + } + if *remainingCommand == -1 && *remainingContent == -1 && *remainingPadding == -1 { // initial state if b.Len() >= 21 && bytes.Equal(s.UserUUID, b.BytesTo(16)) { b.Advance(16) - s.RemainingCommand = 5 + *remainingCommand = 5 } else { return b } } newbuffer := buf.New() for b.Len() > 0 { - if s.RemainingCommand > 0 { + if *remainingCommand > 0 { data, err := b.ReadByte() if err != nil { return newbuffer } - switch s.RemainingCommand { + switch *remainingCommand { case 5: - s.CurrentCommand = int(data) + *currentCommand = int(data) case 4: - s.RemainingContent = int32(data) << 8 + *remainingContent = int32(data) << 8 case 3: - s.RemainingContent = s.RemainingContent | int32(data) + *remainingContent = *remainingContent | int32(data) case 2: - s.RemainingPadding = int32(data) << 8 + *remainingPadding = int32(data) << 8 case 1: - s.RemainingPadding = s.RemainingPadding | int32(data) - errors.LogInfo(ctx, "Xtls Unpadding new block, content ", s.RemainingContent, " padding ", s.RemainingPadding, " command ", s.CurrentCommand) + *remainingPadding = *remainingPadding | int32(data) + errors.LogInfo(ctx, "Xtls Unpadding new block, content ", *remainingContent, " padding ", *remainingPadding, " command ", *currentCommand) } - s.RemainingCommand-- - } else if s.RemainingContent > 0 { - len := s.RemainingContent + *remainingCommand-- + } else if *remainingContent > 0 { + len := *remainingContent if b.Len() < len { len = b.Len() } @@ -383,22 +442,22 @@ func XtlsUnpadding(b *buf.Buffer, s *TrafficState, ctx context.Context) *buf.Buf return newbuffer } newbuffer.Write(data) - s.RemainingContent -= len + *remainingContent -= len } else { // remainingPadding > 0 - len := s.RemainingPadding + len := *remainingPadding if b.Len() < len { len = b.Len() } b.Advance(len) - s.RemainingPadding -= len + *remainingPadding -= len } - if s.RemainingCommand <= 0 && s.RemainingContent <= 0 && s.RemainingPadding <= 0 { // this block done - if s.CurrentCommand == 0 { - s.RemainingCommand = 5 + if *remainingCommand <= 0 && *remainingContent <= 0 && *remainingPadding <= 0 { // this block done + if *currentCommand == 0 { + *remainingCommand = 5 } else { - s.RemainingCommand = -1 // set to initial state - s.RemainingContent = -1 - s.RemainingPadding = -1 + *remainingCommand = -1 // set to initial state + *remainingContent = -1 + *remainingPadding = -1 if b.Len() > 0 { // shouldn't happen newbuffer.Write(b.Bytes()) } diff --git a/proxy/socks/client.go b/proxy/socks/client.go index 2ed5740a..232215ec 100644 --- a/proxy/socks/client.go +++ b/proxy/socks/client.go @@ -146,6 +146,7 @@ func (c *Client) Process(ctx context.Context, link *transport.Link, dialer inter return buf.Copy(link.Reader, buf.NewWriter(conn), buf.UpdateActivity(timer)) } responseFunc = func() error { + ob.CanSpliceCopy = 1 defer timer.SetTimeout(p.Timeouts.UplinkOnly) return buf.Copy(buf.NewReader(conn), link.Writer, buf.UpdateActivity(timer)) } @@ -161,6 +162,7 @@ func (c *Client) Process(ctx context.Context, link *transport.Link, dialer inter return buf.Copy(link.Reader, writer, buf.UpdateActivity(timer)) } responseFunc = func() error { + ob.CanSpliceCopy = 1 defer timer.SetTimeout(p.Timeouts.UplinkOnly) reader := &UDPReader{Reader: udpConn} return buf.Copy(reader, link.Writer, buf.UpdateActivity(timer)) diff --git a/proxy/socks/server.go b/proxy/socks/server.go index 472b23a0..dd6f3953 100644 --- a/proxy/socks/server.go +++ b/proxy/socks/server.go @@ -199,6 +199,7 @@ func (s *Server) transport(ctx context.Context, reader io.Reader, writer io.Writ } responseDone := func() error { + inbound.CanSpliceCopy = 1 defer timer.SetTimeout(plcy.Timeouts.UplinkOnly) v2writer := buf.NewWriter(writer) @@ -256,6 +257,7 @@ func (s *Server) handleUDPPayload(ctx context.Context, conn stat.Connection, dis if inbound != nil && inbound.Source.IsValid() { errors.LogInfo(ctx, "client UDP connection from ", inbound.Source) } + inbound.CanSpliceCopy = 1 var dest *net.Destination diff --git a/proxy/vless/encoding/encoding.go b/proxy/vless/encoding/encoding.go index 3fce3290..38043e68 100644 --- a/proxy/vless/encoding/encoding.go +++ b/proxy/vless/encoding/encoding.go @@ -175,16 +175,16 @@ func DecodeResponseHeader(reader io.Reader, request *protocol.RequestHeader) (*A func XtlsRead(reader buf.Reader, writer buf.Writer, timer *signal.ActivityTimer, conn net.Conn, input *bytes.Reader, rawInput *bytes.Buffer, trafficState *proxy.TrafficState, ob *session.Outbound, isUplink bool, ctx context.Context) error { err := func() error { for { - if isUplink && trafficState.UplinkReaderDirectCopy || !isUplink && trafficState.DownlinkReaderDirectCopy { + if isUplink && trafficState.Inbound.UplinkReaderDirectCopy || !isUplink && trafficState.Outbound.DownlinkReaderDirectCopy { var writerConn net.Conn var inTimer *signal.ActivityTimer if inbound := session.InboundFromContext(ctx); inbound != nil && inbound.Conn != nil { writerConn = inbound.Conn inTimer = inbound.Timer - if inbound.CanSpliceCopy == 2 { + if isUplink && inbound.CanSpliceCopy == 2 { inbound.CanSpliceCopy = 1 } - if ob != nil && ob.CanSpliceCopy == 2 { // ob need to be passed in due to context can change + if !isUplink && ob != nil && ob.CanSpliceCopy == 2 { // ob need to be passed in due to context can change ob.CanSpliceCopy = 1 } } @@ -193,7 +193,7 @@ func XtlsRead(reader buf.Reader, writer buf.Writer, timer *signal.ActivityTimer, buffer, err := reader.ReadMultiBuffer() if !buffer.IsEmpty() { timer.Update() - if isUplink && trafficState.UplinkReaderDirectCopy || !isUplink && trafficState.DownlinkReaderDirectCopy { + if isUplink && trafficState.Inbound.UplinkReaderDirectCopy || !isUplink && trafficState.Outbound.DownlinkReaderDirectCopy { // XTLS Vision processes struct TLS Conn's input and rawInput if inputBuffer, err := buf.ReadFrom(input); err == nil { if !inputBuffer.IsEmpty() { @@ -227,12 +227,12 @@ func XtlsWrite(reader buf.Reader, writer buf.Writer, timer signal.ActivityUpdate var ct stats.Counter for { buffer, err := reader.ReadMultiBuffer() - if isUplink && trafficState.UplinkWriterDirectCopy || !isUplink && trafficState.DownlinkWriterDirectCopy { + if isUplink && trafficState.Outbound.UplinkWriterDirectCopy || !isUplink && trafficState.Inbound.DownlinkWriterDirectCopy { if inbound := session.InboundFromContext(ctx); inbound != nil { - if inbound.CanSpliceCopy == 2 { + if !isUplink && inbound.CanSpliceCopy == 2 { inbound.CanSpliceCopy = 1 } - if ob != nil && ob.CanSpliceCopy == 2 { + if isUplink && ob != nil && ob.CanSpliceCopy == 2 { ob.CanSpliceCopy = 1 } } @@ -240,9 +240,9 @@ func XtlsWrite(reader buf.Reader, writer buf.Writer, timer signal.ActivityUpdate writer = buf.NewWriter(rawConn) ct = writerCounter if isUplink { - trafficState.UplinkWriterDirectCopy = false + trafficState.Outbound.UplinkWriterDirectCopy = false } else { - trafficState.DownlinkWriterDirectCopy = false + trafficState.Inbound.DownlinkWriterDirectCopy = false } } if !buffer.IsEmpty() {