XTLS: More separate uplink/downlink flags for splice copy (#4407)

- In 03131c72dbbfc13ba4ce8e1f9f65f43f3dda7372 new flags were added for uplink/downlink, but that was not suffcient
- Now that the traffic state contains all possible info
- Each inbound and outbound is responsible to set their own CanSpliceCopy flag. Note that this also open up more splice usage. E.g. socks in -> freedom out
- Fixes https://github.com/XTLS/Xray-core/issues/4033
This commit is contained in:
yuhan6665 2025-02-18 03:37:52 -05:00 committed by GitHub
parent a1714cc4ce
commit eef74b2c7d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 132 additions and 67 deletions

View File

@ -151,6 +151,7 @@ func (c *Client) Process(ctx context.Context, link *transport.Link, dialer inter
return buf.Copy(link.Reader, buf.NewWriter(conn), buf.UpdateActivity(timer)) return buf.Copy(link.Reader, buf.NewWriter(conn), buf.UpdateActivity(timer))
} }
responseFunc := func() error { responseFunc := func() error {
ob.CanSpliceCopy = 1
defer timer.SetTimeout(p.Timeouts.UplinkOnly) defer timer.SetTimeout(p.Timeouts.UplinkOnly)
return buf.Copy(buf.NewReader(conn), link.Writer, buf.UpdateActivity(timer)) return buf.Copy(buf.NewReader(conn), link.Writer, buf.UpdateActivity(timer))
} }

View File

@ -207,6 +207,7 @@ func (s *Server) handleConnect(ctx context.Context, _ *http.Request, reader *buf
} }
responseDone := func() error { responseDone := func() error {
inbound.CanSpliceCopy = 1
defer timer.SetTimeout(plcy.Timeouts.UplinkOnly) defer timer.SetTimeout(plcy.Timeouts.UplinkOnly)
v2writer := buf.NewWriter(conn) v2writer := buf.NewWriter(conn)

View File

@ -107,19 +107,33 @@ type TrafficState struct {
IsTLS bool IsTLS bool
Cipher uint16 Cipher uint16
RemainingServerHello int32 RemainingServerHello int32
Inbound InboundState
Outbound OutboundState
}
type InboundState struct {
// reader link state // reader link state
WithinPaddingBuffers bool WithinPaddingBuffers bool
DownlinkReaderDirectCopy bool
UplinkReaderDirectCopy bool UplinkReaderDirectCopy bool
RemainingCommand int32 RemainingCommand int32
RemainingContent int32 RemainingContent int32
RemainingPadding int32 RemainingPadding int32
CurrentCommand int CurrentCommand int
// write link state // write link state
IsPadding bool IsPadding bool
DownlinkWriterDirectCopy bool DownlinkWriterDirectCopy bool
}
type OutboundState struct {
// reader link state
WithinPaddingBuffers bool
DownlinkReaderDirectCopy bool
RemainingCommand int32
RemainingContent int32
RemainingPadding int32
CurrentCommand int
// write link state
IsPadding bool
UplinkWriterDirectCopy bool UplinkWriterDirectCopy bool
} }
@ -132,16 +146,26 @@ func NewTrafficState(userUUID []byte) *TrafficState {
IsTLS: false, IsTLS: false,
Cipher: 0, Cipher: 0,
RemainingServerHello: -1, RemainingServerHello: -1,
WithinPaddingBuffers: true, Inbound: InboundState{
DownlinkReaderDirectCopy: false, WithinPaddingBuffers: true,
UplinkReaderDirectCopy: false, UplinkReaderDirectCopy: false,
RemainingCommand: -1, RemainingCommand: -1,
RemainingContent: -1, RemainingContent: -1,
RemainingPadding: -1, RemainingPadding: -1,
CurrentCommand: 0, CurrentCommand: 0,
IsPadding: true, IsPadding: true,
DownlinkWriterDirectCopy: false, DownlinkWriterDirectCopy: false,
UplinkWriterDirectCopy: false, },
Outbound: OutboundState{
WithinPaddingBuffers: true,
DownlinkReaderDirectCopy: false,
RemainingCommand: -1,
RemainingContent: -1,
RemainingPadding: -1,
CurrentCommand: 0,
IsPadding: true,
UplinkWriterDirectCopy: false,
},
} }
} }
@ -166,28 +190,43 @@ func NewVisionReader(reader buf.Reader, state *TrafficState, isUplink bool, cont
func (w *VisionReader) ReadMultiBuffer() (buf.MultiBuffer, error) { func (w *VisionReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
buffer, err := w.Reader.ReadMultiBuffer() buffer, err := w.Reader.ReadMultiBuffer()
if !buffer.IsEmpty() { if !buffer.IsEmpty() {
if w.trafficState.WithinPaddingBuffers || w.trafficState.NumberOfPacketToFilter > 0 { var withinPaddingBuffers *bool
var remainingContent *int32
var remainingPadding *int32
var currentCommand *int
var switchToDirectCopy *bool
if w.isUplink {
withinPaddingBuffers = &w.trafficState.Inbound.WithinPaddingBuffers
remainingContent = &w.trafficState.Inbound.RemainingContent
remainingPadding = &w.trafficState.Inbound.RemainingPadding
currentCommand = &w.trafficState.Inbound.CurrentCommand
switchToDirectCopy = &w.trafficState.Inbound.UplinkReaderDirectCopy
} else {
withinPaddingBuffers = &w.trafficState.Outbound.WithinPaddingBuffers
remainingContent = &w.trafficState.Outbound.RemainingContent
remainingPadding = &w.trafficState.Outbound.RemainingPadding
currentCommand = &w.trafficState.Outbound.CurrentCommand
switchToDirectCopy = &w.trafficState.Outbound.DownlinkReaderDirectCopy
}
if *withinPaddingBuffers || w.trafficState.NumberOfPacketToFilter > 0 {
mb2 := make(buf.MultiBuffer, 0, len(buffer)) mb2 := make(buf.MultiBuffer, 0, len(buffer))
for _, b := range buffer { for _, b := range buffer {
newbuffer := XtlsUnpadding(b, w.trafficState, w.ctx) newbuffer := XtlsUnpadding(b, w.trafficState, w.isUplink, w.ctx)
if newbuffer.Len() > 0 { if newbuffer.Len() > 0 {
mb2 = append(mb2, newbuffer) mb2 = append(mb2, newbuffer)
} }
} }
buffer = mb2 buffer = mb2
if w.trafficState.RemainingContent > 0 || w.trafficState.RemainingPadding > 0 || w.trafficState.CurrentCommand == 0 { if *remainingContent > 0 || *remainingPadding > 0 || *currentCommand == 0 {
w.trafficState.WithinPaddingBuffers = true *withinPaddingBuffers = true
} else if w.trafficState.CurrentCommand == 1 { } else if *currentCommand == 1 {
w.trafficState.WithinPaddingBuffers = false *withinPaddingBuffers = false
} else if w.trafficState.CurrentCommand == 2 { } else if *currentCommand == 2 {
w.trafficState.WithinPaddingBuffers = false *withinPaddingBuffers = false
if w.isUplink { *switchToDirectCopy = true
w.trafficState.UplinkReaderDirectCopy = true
} else {
w.trafficState.DownlinkReaderDirectCopy = true
}
} else { } else {
errors.LogInfo(w.ctx, "XtlsRead unknown command ", w.trafficState.CurrentCommand, buffer.Len()) errors.LogInfo(w.ctx, "XtlsRead unknown command ", *currentCommand, buffer.Len())
} }
} }
if w.trafficState.NumberOfPacketToFilter > 0 { if w.trafficState.NumberOfPacketToFilter > 0 {
@ -223,7 +262,16 @@ func (w *VisionWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
if w.trafficState.NumberOfPacketToFilter > 0 { if w.trafficState.NumberOfPacketToFilter > 0 {
XtlsFilterTls(mb, w.trafficState, w.ctx) XtlsFilterTls(mb, w.trafficState, w.ctx)
} }
if w.trafficState.IsPadding { var isPadding *bool
var switchToDirectCopy *bool
if w.isUplink {
isPadding = &w.trafficState.Outbound.IsPadding
switchToDirectCopy = &w.trafficState.Outbound.UplinkWriterDirectCopy
} else {
isPadding = &w.trafficState.Inbound.IsPadding
switchToDirectCopy = &w.trafficState.Inbound.DownlinkWriterDirectCopy
}
if *isPadding {
if len(mb) == 1 && mb[0] == nil { if len(mb) == 1 && mb[0] == nil {
mb[0] = XtlsPadding(nil, CommandPaddingContinue, &w.writeOnceUserUUID, true, w.ctx) // we do a long padding to hide vless header mb[0] = XtlsPadding(nil, CommandPaddingContinue, &w.writeOnceUserUUID, true, w.ctx) // we do a long padding to hide vless header
return w.Writer.WriteMultiBuffer(mb) return w.Writer.WriteMultiBuffer(mb)
@ -233,11 +281,7 @@ func (w *VisionWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
for i, b := range mb { for i, b := range mb {
if w.trafficState.IsTLS && b.Len() >= 6 && bytes.Equal(TlsApplicationDataStart, b.BytesTo(3)) { if w.trafficState.IsTLS && b.Len() >= 6 && bytes.Equal(TlsApplicationDataStart, b.BytesTo(3)) {
if w.trafficState.EnableXtls { if w.trafficState.EnableXtls {
if w.isUplink { *switchToDirectCopy = true
w.trafficState.UplinkWriterDirectCopy = true
} else {
w.trafficState.DownlinkWriterDirectCopy = true
}
} }
var command byte = CommandPaddingContinue var command byte = CommandPaddingContinue
if i == len(mb)-1 { if i == len(mb)-1 {
@ -247,16 +291,16 @@ func (w *VisionWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
} }
} }
mb[i] = XtlsPadding(b, command, &w.writeOnceUserUUID, true, w.ctx) mb[i] = XtlsPadding(b, command, &w.writeOnceUserUUID, true, w.ctx)
w.trafficState.IsPadding = false // padding going to end *isPadding = false // padding going to end
longPadding = false longPadding = false
continue continue
} else if !w.trafficState.IsTLS12orAbove && w.trafficState.NumberOfPacketToFilter <= 1 { // For compatibility with earlier vision receiver, we finish padding 1 packet early } else if !w.trafficState.IsTLS12orAbove && w.trafficState.NumberOfPacketToFilter <= 1 { // For compatibility with earlier vision receiver, we finish padding 1 packet early
w.trafficState.IsPadding = false *isPadding = false
mb[i] = XtlsPadding(b, CommandPaddingEnd, &w.writeOnceUserUUID, longPadding, w.ctx) mb[i] = XtlsPadding(b, CommandPaddingEnd, &w.writeOnceUserUUID, longPadding, w.ctx)
break break
} }
var command byte = CommandPaddingContinue var command byte = CommandPaddingContinue
if i == len(mb)-1 && !w.trafficState.IsPadding { if i == len(mb)-1 && !*isPadding {
command = CommandPaddingEnd command = CommandPaddingEnd
if w.trafficState.EnableXtls { if w.trafficState.EnableXtls {
command = CommandPaddingDirect command = CommandPaddingDirect
@ -343,38 +387,53 @@ func XtlsPadding(b *buf.Buffer, command byte, userUUID *[]byte, longPadding bool
} }
// XtlsUnpadding remove padding and parse command // XtlsUnpadding remove padding and parse command
func XtlsUnpadding(b *buf.Buffer, s *TrafficState, ctx context.Context) *buf.Buffer { func XtlsUnpadding(b *buf.Buffer, s *TrafficState, isUplink bool, ctx context.Context) *buf.Buffer {
if s.RemainingCommand == -1 && s.RemainingContent == -1 && s.RemainingPadding == -1 { // initial state var remainingCommand *int32
var remainingContent *int32
var remainingPadding *int32
var currentCommand *int
if isUplink {
remainingCommand = &s.Inbound.RemainingCommand
remainingContent = &s.Inbound.RemainingContent
remainingPadding = &s.Inbound.RemainingPadding
currentCommand = &s.Inbound.CurrentCommand
} else {
remainingCommand = &s.Outbound.RemainingCommand
remainingContent = &s.Outbound.RemainingContent
remainingPadding = &s.Outbound.RemainingPadding
currentCommand = &s.Outbound.CurrentCommand
}
if *remainingCommand == -1 && *remainingContent == -1 && *remainingPadding == -1 { // initial state
if b.Len() >= 21 && bytes.Equal(s.UserUUID, b.BytesTo(16)) { if b.Len() >= 21 && bytes.Equal(s.UserUUID, b.BytesTo(16)) {
b.Advance(16) b.Advance(16)
s.RemainingCommand = 5 *remainingCommand = 5
} else { } else {
return b return b
} }
} }
newbuffer := buf.New() newbuffer := buf.New()
for b.Len() > 0 { for b.Len() > 0 {
if s.RemainingCommand > 0 { if *remainingCommand > 0 {
data, err := b.ReadByte() data, err := b.ReadByte()
if err != nil { if err != nil {
return newbuffer return newbuffer
} }
switch s.RemainingCommand { switch *remainingCommand {
case 5: case 5:
s.CurrentCommand = int(data) *currentCommand = int(data)
case 4: case 4:
s.RemainingContent = int32(data) << 8 *remainingContent = int32(data) << 8
case 3: case 3:
s.RemainingContent = s.RemainingContent | int32(data) *remainingContent = *remainingContent | int32(data)
case 2: case 2:
s.RemainingPadding = int32(data) << 8 *remainingPadding = int32(data) << 8
case 1: case 1:
s.RemainingPadding = s.RemainingPadding | int32(data) *remainingPadding = *remainingPadding | int32(data)
errors.LogInfo(ctx, "Xtls Unpadding new block, content ", s.RemainingContent, " padding ", s.RemainingPadding, " command ", s.CurrentCommand) errors.LogInfo(ctx, "Xtls Unpadding new block, content ", *remainingContent, " padding ", *remainingPadding, " command ", *currentCommand)
} }
s.RemainingCommand-- *remainingCommand--
} else if s.RemainingContent > 0 { } else if *remainingContent > 0 {
len := s.RemainingContent len := *remainingContent
if b.Len() < len { if b.Len() < len {
len = b.Len() len = b.Len()
} }
@ -383,22 +442,22 @@ func XtlsUnpadding(b *buf.Buffer, s *TrafficState, ctx context.Context) *buf.Buf
return newbuffer return newbuffer
} }
newbuffer.Write(data) newbuffer.Write(data)
s.RemainingContent -= len *remainingContent -= len
} else { // remainingPadding > 0 } else { // remainingPadding > 0
len := s.RemainingPadding len := *remainingPadding
if b.Len() < len { if b.Len() < len {
len = b.Len() len = b.Len()
} }
b.Advance(len) b.Advance(len)
s.RemainingPadding -= len *remainingPadding -= len
} }
if s.RemainingCommand <= 0 && s.RemainingContent <= 0 && s.RemainingPadding <= 0 { // this block done if *remainingCommand <= 0 && *remainingContent <= 0 && *remainingPadding <= 0 { // this block done
if s.CurrentCommand == 0 { if *currentCommand == 0 {
s.RemainingCommand = 5 *remainingCommand = 5
} else { } else {
s.RemainingCommand = -1 // set to initial state *remainingCommand = -1 // set to initial state
s.RemainingContent = -1 *remainingContent = -1
s.RemainingPadding = -1 *remainingPadding = -1
if b.Len() > 0 { // shouldn't happen if b.Len() > 0 { // shouldn't happen
newbuffer.Write(b.Bytes()) newbuffer.Write(b.Bytes())
} }

View File

@ -146,6 +146,7 @@ func (c *Client) Process(ctx context.Context, link *transport.Link, dialer inter
return buf.Copy(link.Reader, buf.NewWriter(conn), buf.UpdateActivity(timer)) return buf.Copy(link.Reader, buf.NewWriter(conn), buf.UpdateActivity(timer))
} }
responseFunc = func() error { responseFunc = func() error {
ob.CanSpliceCopy = 1
defer timer.SetTimeout(p.Timeouts.UplinkOnly) defer timer.SetTimeout(p.Timeouts.UplinkOnly)
return buf.Copy(buf.NewReader(conn), link.Writer, buf.UpdateActivity(timer)) return buf.Copy(buf.NewReader(conn), link.Writer, buf.UpdateActivity(timer))
} }
@ -161,6 +162,7 @@ func (c *Client) Process(ctx context.Context, link *transport.Link, dialer inter
return buf.Copy(link.Reader, writer, buf.UpdateActivity(timer)) return buf.Copy(link.Reader, writer, buf.UpdateActivity(timer))
} }
responseFunc = func() error { responseFunc = func() error {
ob.CanSpliceCopy = 1
defer timer.SetTimeout(p.Timeouts.UplinkOnly) defer timer.SetTimeout(p.Timeouts.UplinkOnly)
reader := &UDPReader{Reader: udpConn} reader := &UDPReader{Reader: udpConn}
return buf.Copy(reader, link.Writer, buf.UpdateActivity(timer)) return buf.Copy(reader, link.Writer, buf.UpdateActivity(timer))

View File

@ -199,6 +199,7 @@ func (s *Server) transport(ctx context.Context, reader io.Reader, writer io.Writ
} }
responseDone := func() error { responseDone := func() error {
inbound.CanSpliceCopy = 1
defer timer.SetTimeout(plcy.Timeouts.UplinkOnly) defer timer.SetTimeout(plcy.Timeouts.UplinkOnly)
v2writer := buf.NewWriter(writer) v2writer := buf.NewWriter(writer)
@ -256,6 +257,7 @@ func (s *Server) handleUDPPayload(ctx context.Context, conn stat.Connection, dis
if inbound != nil && inbound.Source.IsValid() { if inbound != nil && inbound.Source.IsValid() {
errors.LogInfo(ctx, "client UDP connection from ", inbound.Source) errors.LogInfo(ctx, "client UDP connection from ", inbound.Source)
} }
inbound.CanSpliceCopy = 1
var dest *net.Destination var dest *net.Destination

View File

@ -175,16 +175,16 @@ func DecodeResponseHeader(reader io.Reader, request *protocol.RequestHeader) (*A
func XtlsRead(reader buf.Reader, writer buf.Writer, timer *signal.ActivityTimer, conn net.Conn, input *bytes.Reader, rawInput *bytes.Buffer, trafficState *proxy.TrafficState, ob *session.Outbound, isUplink bool, ctx context.Context) error { func XtlsRead(reader buf.Reader, writer buf.Writer, timer *signal.ActivityTimer, conn net.Conn, input *bytes.Reader, rawInput *bytes.Buffer, trafficState *proxy.TrafficState, ob *session.Outbound, isUplink bool, ctx context.Context) error {
err := func() error { err := func() error {
for { for {
if isUplink && trafficState.UplinkReaderDirectCopy || !isUplink && trafficState.DownlinkReaderDirectCopy { if isUplink && trafficState.Inbound.UplinkReaderDirectCopy || !isUplink && trafficState.Outbound.DownlinkReaderDirectCopy {
var writerConn net.Conn var writerConn net.Conn
var inTimer *signal.ActivityTimer var inTimer *signal.ActivityTimer
if inbound := session.InboundFromContext(ctx); inbound != nil && inbound.Conn != nil { if inbound := session.InboundFromContext(ctx); inbound != nil && inbound.Conn != nil {
writerConn = inbound.Conn writerConn = inbound.Conn
inTimer = inbound.Timer inTimer = inbound.Timer
if inbound.CanSpliceCopy == 2 { if isUplink && inbound.CanSpliceCopy == 2 {
inbound.CanSpliceCopy = 1 inbound.CanSpliceCopy = 1
} }
if ob != nil && ob.CanSpliceCopy == 2 { // ob need to be passed in due to context can change if !isUplink && ob != nil && ob.CanSpliceCopy == 2 { // ob need to be passed in due to context can change
ob.CanSpliceCopy = 1 ob.CanSpliceCopy = 1
} }
} }
@ -193,7 +193,7 @@ func XtlsRead(reader buf.Reader, writer buf.Writer, timer *signal.ActivityTimer,
buffer, err := reader.ReadMultiBuffer() buffer, err := reader.ReadMultiBuffer()
if !buffer.IsEmpty() { if !buffer.IsEmpty() {
timer.Update() timer.Update()
if isUplink && trafficState.UplinkReaderDirectCopy || !isUplink && trafficState.DownlinkReaderDirectCopy { if isUplink && trafficState.Inbound.UplinkReaderDirectCopy || !isUplink && trafficState.Outbound.DownlinkReaderDirectCopy {
// XTLS Vision processes struct TLS Conn's input and rawInput // XTLS Vision processes struct TLS Conn's input and rawInput
if inputBuffer, err := buf.ReadFrom(input); err == nil { if inputBuffer, err := buf.ReadFrom(input); err == nil {
if !inputBuffer.IsEmpty() { if !inputBuffer.IsEmpty() {
@ -227,12 +227,12 @@ func XtlsWrite(reader buf.Reader, writer buf.Writer, timer signal.ActivityUpdate
var ct stats.Counter var ct stats.Counter
for { for {
buffer, err := reader.ReadMultiBuffer() buffer, err := reader.ReadMultiBuffer()
if isUplink && trafficState.UplinkWriterDirectCopy || !isUplink && trafficState.DownlinkWriterDirectCopy { if isUplink && trafficState.Outbound.UplinkWriterDirectCopy || !isUplink && trafficState.Inbound.DownlinkWriterDirectCopy {
if inbound := session.InboundFromContext(ctx); inbound != nil { if inbound := session.InboundFromContext(ctx); inbound != nil {
if inbound.CanSpliceCopy == 2 { if !isUplink && inbound.CanSpliceCopy == 2 {
inbound.CanSpliceCopy = 1 inbound.CanSpliceCopy = 1
} }
if ob != nil && ob.CanSpliceCopy == 2 { if isUplink && ob != nil && ob.CanSpliceCopy == 2 {
ob.CanSpliceCopy = 1 ob.CanSpliceCopy = 1
} }
} }
@ -240,9 +240,9 @@ func XtlsWrite(reader buf.Reader, writer buf.Writer, timer signal.ActivityUpdate
writer = buf.NewWriter(rawConn) writer = buf.NewWriter(rawConn)
ct = writerCounter ct = writerCounter
if isUplink { if isUplink {
trafficState.UplinkWriterDirectCopy = false trafficState.Outbound.UplinkWriterDirectCopy = false
} else { } else {
trafficState.DownlinkWriterDirectCopy = false trafficState.Inbound.DownlinkWriterDirectCopy = false
} }
} }
if !buffer.IsEmpty() { if !buffer.IsEmpty() {