From d6d225c6981812f024f90adab8cefc26a5ac0937 Mon Sep 17 00:00:00 2001 From: yuhan6665 <1588741+yuhan6665@users.noreply.github.com> Date: Sat, 2 Sep 2023 11:37:50 -0400 Subject: [PATCH] Refactor Vision reader writer - Vision now use traffic states to capture two-way info about a connection - XTLS is de-couple with Vision, it only read traffic states to switch to direct copy mode - fix a edge case error when Vision unpadding read 5 command bytes --- proxy/proxy.go | 386 +++++++++++++++++++++++++++++++ proxy/vless/encoding/addons.go | 19 +- proxy/vless/encoding/encoding.go | 356 +++------------------------- proxy/vless/inbound/inbound.go | 25 +- proxy/vless/outbound/outbound.go | 28 +-- 5 files changed, 440 insertions(+), 374 deletions(-) diff --git a/proxy/proxy.go b/proxy/proxy.go index 12b9631b..142acb77 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -6,10 +6,14 @@ package proxy import ( + "bytes" "context" + "crypto/rand" gotls "crypto/tls" "io" + "math/big" "runtime" + "strconv" "github.com/pires/go-proxyproto" "github.com/xtls/xray-core/common/buf" @@ -27,6 +31,30 @@ import ( "github.com/xtls/xray-core/transport/internet/tls" ) +var ( + Tls13SupportedVersions = []byte{0x00, 0x2b, 0x00, 0x02, 0x03, 0x04} + TlsClientHandShakeStart = []byte{0x16, 0x03} + TlsServerHandShakeStart = []byte{0x16, 0x03, 0x03} + TlsApplicationDataStart = []byte{0x17, 0x03, 0x03} + + Tls13CipherSuiteDic = map[uint16]string{ + 0x1301: "TLS_AES_128_GCM_SHA256", + 0x1302: "TLS_AES_256_GCM_SHA384", + 0x1303: "TLS_CHACHA20_POLY1305_SHA256", + 0x1304: "TLS_AES_128_CCM_SHA256", + 0x1305: "TLS_AES_128_CCM_8_SHA256", + } +) + +const ( + TlsHandshakeTypeClientHello byte = 0x01 + TlsHandshakeTypeServerHello byte = 0x02 + + CommandPaddingContinue byte = 0x00 + CommandPaddingEnd byte = 0x01 + CommandPaddingDirect byte = 0x02 +) + // An Inbound processes inbound connections. type Inbound interface { // Network returns a list of networks that this inbound supports. Connections with not-supported networks will not be passed into Process(). @@ -59,6 +87,364 @@ type GetOutbound interface { GetOutbound() Outbound } +// TrafficState is used to track uplink and downlink of one connection +// It is used by XTLS to determine if switch to raw copy mode, It is used by Vision to calculate padding +type TrafficState struct { + UserUUID []byte + NumberOfPacketToFilter int + EnableXtls bool + IsTLS12orAbove bool + IsTLS bool + Cipher uint16 + RemainingServerHello int32 + + // reader link state + WithinPaddingBuffers bool + ReaderSwitchToDirectCopy bool + RemainingCommand int32 + RemainingContent int32 + RemainingPadding int32 + CurrentCommand int + + // write link state + IsPadding bool + WriterSwitchToDirectCopy bool +} + +func NewTrafficState(userUUID []byte) *TrafficState { + return &TrafficState{ + UserUUID: userUUID, + NumberOfPacketToFilter: 8, + EnableXtls: false, + IsTLS12orAbove: false, + IsTLS: false, + Cipher: 0, + RemainingServerHello: -1, + WithinPaddingBuffers: true, + ReaderSwitchToDirectCopy: false, + RemainingCommand: -1, + RemainingContent: -1, + RemainingPadding: -1, + CurrentCommand: 0, + IsPadding: true, + WriterSwitchToDirectCopy: false, + } +} + +// VisionReader is used to read xtls vision protocol +// Note Vision probably only make sense as the inner most layer of reader, since it need assess traffic state from origin proxy traffic +type VisionReader struct { + buf.Reader + trafficState *TrafficState + ctx context.Context +} + +func NewVisionReader(reader buf.Reader, state *TrafficState, context context.Context) *VisionReader { + return &VisionReader{ + Reader: reader, + trafficState: state, + ctx: context, + } +} + +func (w *VisionReader) ReadMultiBuffer() (buf.MultiBuffer, error) { + buffer, err := w.Reader.ReadMultiBuffer() + if !buffer.IsEmpty() { + if w.trafficState.WithinPaddingBuffers || w.trafficState.NumberOfPacketToFilter > 0 { + mb2 := make(buf.MultiBuffer, 0, len(buffer)) + for _, b := range buffer { + newbuffer := XtlsUnpadding(b, w.trafficState, w.ctx) + if newbuffer.Len() > 0 { + mb2 = append(mb2, newbuffer) + } + } + buffer = mb2 + if w.trafficState.RemainingContent == 0 && w.trafficState.RemainingPadding == 0 { + if w.trafficState.CurrentCommand == 1 { + w.trafficState.WithinPaddingBuffers = false + } else if w.trafficState.CurrentCommand == 2 { + w.trafficState.WithinPaddingBuffers = false + w.trafficState.ReaderSwitchToDirectCopy = true + } else if w.trafficState.CurrentCommand == 0 { + w.trafficState.WithinPaddingBuffers = true + } else { + newError("XtlsRead unknown command ", w.trafficState.CurrentCommand, buffer.Len()).WriteToLog(session.ExportIDToError(w.ctx)) + } + } else if w.trafficState.RemainingContent > 0 || w.trafficState.RemainingPadding > 0 { + w.trafficState.WithinPaddingBuffers = true + } else { + w.trafficState.WithinPaddingBuffers = false + } + } + if w.trafficState.NumberOfPacketToFilter > 0 { + XtlsFilterTls(buffer, w.trafficState, w.ctx) + } + } + return buffer, err +} + +// VisionWriter is used to write xtls vision protocol +// Note Vision probably only make sense as the inner most layer of writer, since it need assess traffic state from origin proxy traffic +type VisionWriter struct { + buf.Writer + trafficState *TrafficState + ctx context.Context + writeOnceUserUUID []byte +} + +func NewVisionWriter(writer buf.Writer, state *TrafficState, context context.Context) *VisionWriter { + w := make([]byte, len(state.UserUUID)) + copy(w, state.UserUUID) + return &VisionWriter{ + Writer: writer, + trafficState: state, + ctx: context, + writeOnceUserUUID: w, + } +} + +func (w *VisionWriter) WriteMultiBuffer(mb buf.MultiBuffer) error { + if w.trafficState.NumberOfPacketToFilter > 0 { + XtlsFilterTls(mb, w.trafficState, w.ctx) + } + if w.trafficState.IsPadding { + if len(mb) == 1 && mb[0] == nil { + mb[0] = XtlsPadding(nil, CommandPaddingContinue, &w.writeOnceUserUUID, true, w.ctx) // we do a long padding to hide vless header + return w.Writer.WriteMultiBuffer(mb) + } + mb = ReshapeMultiBuffer(w.ctx, mb) + longPadding := w.trafficState.IsTLS + for i, b := range mb { + if w.trafficState.IsTLS && b.Len() >= 6 && bytes.Equal(TlsApplicationDataStart, b.BytesTo(3)) { + if w.trafficState.EnableXtls { + w.trafficState.WriterSwitchToDirectCopy = true + } + var command byte = CommandPaddingContinue + if i == len(mb) - 1 { + command = CommandPaddingEnd + if w.trafficState.EnableXtls { + command = CommandPaddingDirect + } + } + mb[i] = XtlsPadding(b, command, &w.writeOnceUserUUID, true, w.ctx) + w.trafficState.IsPadding = false // padding going to end + longPadding = false + continue + } else if !w.trafficState.IsTLS12orAbove && w.trafficState.NumberOfPacketToFilter <= 1 { // For compatibility with earlier vision receiver, we finish padding 1 packet early + w.trafficState.IsPadding = false + mb[i] = XtlsPadding(b, CommandPaddingEnd, &w.writeOnceUserUUID, longPadding, w.ctx) + break + } + var command byte = CommandPaddingContinue + if i == len(mb) - 1 && !w.trafficState.IsPadding { + command = CommandPaddingEnd + if w.trafficState.EnableXtls { + command = CommandPaddingDirect + } + } + mb[i] = XtlsPadding(b, command, &w.writeOnceUserUUID, longPadding, w.ctx) + } + } + return w.Writer.WriteMultiBuffer(mb) +} + +// ReshapeMultiBuffer prepare multi buffer for padding stucture (max 21 bytes) +func ReshapeMultiBuffer(ctx context.Context, buffer buf.MultiBuffer) buf.MultiBuffer { + needReshape := 0 + for _, b := range buffer { + if b.Len() >= buf.Size-21 { + needReshape += 1 + } + } + if needReshape == 0 { + return buffer + } + mb2 := make(buf.MultiBuffer, 0, len(buffer)+needReshape) + toPrint := "" + for i, buffer1 := range buffer { + if buffer1.Len() >= buf.Size-21 { + index := int32(bytes.LastIndex(buffer1.Bytes(), TlsApplicationDataStart)) + if index <= 0 || index > buf.Size-21 { + index = buf.Size / 2 + } + buffer2 := buf.New() + buffer2.Write(buffer1.BytesFrom(index)) + buffer1.Resize(0, index) + mb2 = append(mb2, buffer1, buffer2) + toPrint += " " + strconv.Itoa(int(buffer1.Len())) + " " + strconv.Itoa(int(buffer2.Len())) + } else { + mb2 = append(mb2, buffer1) + toPrint += " " + strconv.Itoa(int(buffer1.Len())) + } + buffer[i] = nil + } + buffer = buffer[:0] + newError("ReshapeMultiBuffer ", toPrint).WriteToLog(session.ExportIDToError(ctx)) + return mb2 +} + +// XtlsPadding add padding to eliminate length siganature during tls handshake +func XtlsPadding(b *buf.Buffer, command byte, userUUID *[]byte, longPadding bool, ctx context.Context) *buf.Buffer { + var contentLen int32 = 0 + var paddingLen int32 = 0 + if b != nil { + contentLen = b.Len() + } + if contentLen < 900 && longPadding { + l, err := rand.Int(rand.Reader, big.NewInt(500)) + if err != nil { + newError("failed to generate padding").Base(err).WriteToLog(session.ExportIDToError(ctx)) + } + paddingLen = int32(l.Int64()) + 900 - contentLen + } else { + l, err := rand.Int(rand.Reader, big.NewInt(256)) + if err != nil { + newError("failed to generate padding").Base(err).WriteToLog(session.ExportIDToError(ctx)) + } + paddingLen = int32(l.Int64()) + } + if paddingLen > buf.Size-21-contentLen { + paddingLen = buf.Size - 21 - contentLen + } + newbuffer := buf.New() + if userUUID != nil { + newbuffer.Write(*userUUID) + *userUUID = nil + } + newbuffer.Write([]byte{command, byte(contentLen >> 8), byte(contentLen), byte(paddingLen >> 8), byte(paddingLen)}) + if b != nil { + newbuffer.Write(b.Bytes()) + b.Release() + b = nil + } + newbuffer.Extend(paddingLen) + newError("XtlsPadding ", contentLen, " ", paddingLen, " ", command).WriteToLog(session.ExportIDToError(ctx)) + return newbuffer +} + +// XtlsUnpadding remove padding and parse command +func XtlsUnpadding(b *buf.Buffer, s *TrafficState, ctx context.Context) *buf.Buffer { + if s.RemainingCommand == -1 && s.RemainingContent == -1 && s.RemainingPadding == -1 { // inital state + if b.Len() >= 21 && bytes.Equal(s.UserUUID, b.BytesTo(16)) { + b.Advance(16) + s.RemainingCommand = 5 + } else { + return b + } + } + newbuffer := buf.New() + for b.Len() > 0 { + if s.RemainingCommand > 0 { + data, err := b.ReadByte() + if err != nil { + return newbuffer + } + switch s.RemainingCommand { + case 5: + s.CurrentCommand = int(data) + case 4: + s.RemainingContent = int32(data)<<8 + case 3: + s.RemainingContent = s.RemainingContent | int32(data) + case 2: + s.RemainingPadding = int32(data)<<8 + case 1: + s.RemainingPadding = s.RemainingPadding | int32(data) + newError("Xtls Unpadding new block, content ", s.RemainingContent, " padding ", s.RemainingPadding, " command ", s.CurrentCommand).WriteToLog(session.ExportIDToError(ctx)) + } + s.RemainingCommand-- + } else if s.RemainingContent > 0 { + len := s.RemainingContent + if b.Len() < len { + len = b.Len() + } + data, err := b.ReadBytes(len) + if err != nil { + return newbuffer + } + newbuffer.Write(data) + s.RemainingContent -= len + } else { // remainingPadding > 0 + len := s.RemainingPadding + if b.Len() < len { + len = b.Len() + } + b.Advance(len) + s.RemainingPadding -= len + } + if s.RemainingCommand <= 0 && s.RemainingContent <= 0 && s.RemainingPadding <= 0 { // this block done + if s.CurrentCommand == 0 { + s.RemainingCommand = 5 + } else { + s.RemainingCommand = -1 // set to initial state + s.RemainingContent = -1 + s.RemainingPadding = -1 + if b.Len() > 0 { // shouldn't happen + newbuffer.Write(b.Bytes()) + } + break + } + } + } + b.Release() + b = nil + return newbuffer +} + +// XtlsFilterTls filter and recognize tls 1.3 and other info +func XtlsFilterTls(buffer buf.MultiBuffer, trafficState *TrafficState, ctx context.Context) { + for _, b := range buffer { + if b == nil { + continue + } + trafficState.NumberOfPacketToFilter-- + if b.Len() >= 6 { + startsBytes := b.BytesTo(6) + if bytes.Equal(TlsServerHandShakeStart, startsBytes[:3]) && startsBytes[5] == TlsHandshakeTypeServerHello { + trafficState.RemainingServerHello = (int32(startsBytes[3])<<8 | int32(startsBytes[4])) + 5 + trafficState.IsTLS12orAbove = true + trafficState.IsTLS = true + if b.Len() >= 79 && trafficState.RemainingServerHello >= 79 { + sessionIdLen := int32(b.Byte(43)) + cipherSuite := b.BytesRange(43+sessionIdLen+1, 43+sessionIdLen+3) + trafficState.Cipher = uint16(cipherSuite[0])<<8 | uint16(cipherSuite[1]) + } else { + newError("XtlsFilterTls short server hello, tls 1.2 or older? ", b.Len(), " ", trafficState.RemainingServerHello).WriteToLog(session.ExportIDToError(ctx)) + } + } else if bytes.Equal(TlsClientHandShakeStart, startsBytes[:2]) && startsBytes[5] == TlsHandshakeTypeClientHello { + trafficState.IsTLS = true + newError("XtlsFilterTls found tls client hello! ", buffer.Len()).WriteToLog(session.ExportIDToError(ctx)) + } + } + if trafficState.RemainingServerHello > 0 { + end := trafficState.RemainingServerHello + if end > b.Len() { + end = b.Len() + } + trafficState.RemainingServerHello -= b.Len() + if bytes.Contains(b.BytesTo(end), Tls13SupportedVersions) { + v, ok := Tls13CipherSuiteDic[trafficState.Cipher] + if !ok { + v = "Old cipher: " + strconv.FormatUint(uint64(trafficState.Cipher), 16) + } else if v != "TLS_AES_128_CCM_8_SHA256" { + trafficState.EnableXtls = true + } + newError("XtlsFilterTls found tls 1.3! ", b.Len(), " ", v).WriteToLog(session.ExportIDToError(ctx)) + trafficState.NumberOfPacketToFilter = 0 + return + } else if trafficState.RemainingServerHello <= 0 { + newError("XtlsFilterTls found tls 1.2! ", b.Len()).WriteToLog(session.ExportIDToError(ctx)) + trafficState.NumberOfPacketToFilter = 0 + return + } + newError("XtlsFilterTls inconclusive server hello ", b.Len(), " ", trafficState.RemainingServerHello).WriteToLog(session.ExportIDToError(ctx)) + } + if trafficState.NumberOfPacketToFilter <= 0 { + newError("XtlsFilterTls stop filtering", buffer.Len()).WriteToLog(session.ExportIDToError(ctx)) + } + } +} + // UnwrapRawConn support unwrap stats, tls, utls, reality and proxyproto conn and get raw tcp conn from it func UnwrapRawConn(conn net.Conn) (net.Conn, stats.Counter, stats.Counter) { var readCounter, writerCounter stats.Counter diff --git a/proxy/vless/encoding/addons.go b/proxy/vless/encoding/addons.go index fc8ddc2a..e3e5071b 100644 --- a/proxy/vless/encoding/addons.go +++ b/proxy/vless/encoding/addons.go @@ -1,10 +1,12 @@ package encoding import ( + "context" "io" "github.com/xtls/xray-core/common/buf" "github.com/xtls/xray-core/common/protocol" + "github.com/xtls/xray-core/proxy" "github.com/xtls/xray-core/proxy/vless" "google.golang.org/protobuf/proto" ) @@ -58,14 +60,19 @@ func DecodeHeaderAddons(buffer *buf.Buffer, reader io.Reader) (*Addons, error) { } // EncodeBodyAddons returns a Writer that auto-encrypt content written by caller. -func EncodeBodyAddons(writer io.Writer, request *protocol.RequestHeader, addons *Addons) buf.Writer { - switch addons.Flow { - default: - if request.Command == protocol.RequestCommandUDP { - return NewMultiLengthPacketWriter(writer.(buf.Writer)) +func EncodeBodyAddons(writer io.Writer, request *protocol.RequestHeader, requestAddons *Addons, state *proxy.TrafficState, context context.Context) buf.Writer { + if request.Command == protocol.RequestCommandUDP { + w := writer.(buf.Writer) + if requestAddons.Flow == vless.XRV { + w = proxy.NewVisionWriter(w, state, context) } + return NewMultiLengthPacketWriter(w) } - return buf.NewWriter(writer) + w := buf.NewWriter(writer) + if requestAddons.Flow == vless.XRV { + w = proxy.NewVisionWriter(w, state, context) + } + return w } // DecodeBodyAddons returns a Reader from which caller can fetch decrypted body. diff --git a/proxy/vless/encoding/encoding.go b/proxy/vless/encoding/encoding.go index 48bda497..b7fb66f5 100644 --- a/proxy/vless/encoding/encoding.go +++ b/proxy/vless/encoding/encoding.go @@ -5,11 +5,7 @@ package encoding import ( "bytes" "context" - "crypto/rand" "io" - "math/big" - "strconv" - "time" "github.com/xtls/xray-core/common/buf" "github.com/xtls/xray-core/common/errors" @@ -26,30 +22,6 @@ const ( Version = byte(0) ) -var ( - tls13SupportedVersions = []byte{0x00, 0x2b, 0x00, 0x02, 0x03, 0x04} - tlsClientHandShakeStart = []byte{0x16, 0x03} - tlsServerHandShakeStart = []byte{0x16, 0x03, 0x03} - tlsApplicationDataStart = []byte{0x17, 0x03, 0x03} - - Tls13CipherSuiteDic = map[uint16]string{ - 0x1301: "TLS_AES_128_GCM_SHA256", - 0x1302: "TLS_AES_256_GCM_SHA384", - 0x1303: "TLS_CHACHA20_POLY1305_SHA256", - 0x1304: "TLS_AES_128_CCM_SHA256", - 0x1305: "TLS_AES_128_CCM_8_SHA256", - } -) - -const ( - tlsHandshakeTypeClientHello byte = 0x01 - tlsHandshakeTypeServerHello byte = 0x02 - - CommandPaddingContinue byte = 0x00 - CommandPaddingEnd byte = 0x01 - CommandPaddingDirect byte = 0x02 -) - var addrParser = protocol.NewAddressParser( protocol.AddressFamilyByte(byte(protocol.AddressTypeIPv4), net.AddressFamilyIPv4), protocol.AddressFamilyByte(byte(protocol.AddressTypeDomain), net.AddressFamilyDomain), @@ -202,18 +174,11 @@ func DecodeResponseHeader(reader io.Reader, request *protocol.RequestHeader) (*A } // XtlsRead filter and read xtls protocol -func XtlsRead(reader buf.Reader, writer buf.Writer, timer signal.ActivityUpdater, conn net.Conn, input *bytes.Reader, rawInput *bytes.Buffer, - ctx context.Context, userUUID []byte, numberOfPacketToFilter *int, enableXtls *bool, - isTLS12orAbove *bool, isTLS *bool, cipher *uint16, remainingServerHello *int32, -) error { +func XtlsRead(reader buf.Reader, writer buf.Writer, timer signal.ActivityUpdater, conn net.Conn, input *bytes.Reader, rawInput *bytes.Buffer, trafficState *proxy.TrafficState, ctx context.Context) error { err := func() error { - withinPaddingBuffers := true - shouldSwitchToDirectCopy := false - var remainingContent int32 = -1 - var remainingPadding int32 = -1 - currentCommand := 0 + visionReader := proxy.NewVisionReader(reader, trafficState, ctx) for { - if shouldSwitchToDirectCopy { + if trafficState.ReaderSwitchToDirectCopy { var writerConn net.Conn if inbound := session.InboundFromContext(ctx); inbound != nil && inbound.Conn != nil { writerConn = inbound.Conn @@ -223,44 +188,22 @@ func XtlsRead(reader buf.Reader, writer buf.Writer, timer signal.ActivityUpdater } return proxy.CopyRawConnIfExist(ctx, conn, writerConn, writer, timer) } - buffer, err := reader.ReadMultiBuffer() + buffer, err := visionReader.ReadMultiBuffer() if !buffer.IsEmpty() { - if withinPaddingBuffers || *numberOfPacketToFilter > 0 { - buffer = XtlsUnpadding(ctx, buffer, userUUID, &remainingContent, &remainingPadding, ¤tCommand) - if remainingContent == 0 && remainingPadding == 0 { - if currentCommand == 1 { - withinPaddingBuffers = false - remainingContent = -1 - remainingPadding = -1 // set to initial state to parse the next padding - } else if currentCommand == 2 { - withinPaddingBuffers = false - shouldSwitchToDirectCopy = true - // XTLS Vision processes struct TLS Conn's input and rawInput - if inputBuffer, err := buf.ReadFrom(input); err == nil { - if !inputBuffer.IsEmpty() { - buffer, _ = buf.MergeMulti(buffer, inputBuffer) - } - } - if rawInputBuffer, err := buf.ReadFrom(rawInput); err == nil { - if !rawInputBuffer.IsEmpty() { - buffer, _ = buf.MergeMulti(buffer, rawInputBuffer) - } - } - } else if currentCommand == 0 { - withinPaddingBuffers = true - } else { - newError("XtlsRead unknown command ", currentCommand, buffer.Len()).WriteToLog(session.ExportIDToError(ctx)) + timer.Update() + if trafficState.ReaderSwitchToDirectCopy { + // XTLS Vision processes struct TLS Conn's input and rawInput + if inputBuffer, err := buf.ReadFrom(input); err == nil { + if !inputBuffer.IsEmpty() { + buffer, _ = buf.MergeMulti(buffer, inputBuffer) + } + } + if rawInputBuffer, err := buf.ReadFrom(rawInput); err == nil { + if !rawInputBuffer.IsEmpty() { + buffer, _ = buf.MergeMulti(buffer, rawInputBuffer) } - } else if remainingContent > 0 || remainingPadding > 0 { - withinPaddingBuffers = true - } else { - withinPaddingBuffers = false } } - if *numberOfPacketToFilter > 0 { - XtlsFilterTls(buffer, numberOfPacketToFilter, enableXtls, isTLS12orAbove, isTLS, cipher, remainingServerHello, ctx) - } - timer.Update() if werr := writer.WriteMultiBuffer(buffer); werr != nil { return werr } @@ -277,68 +220,27 @@ func XtlsRead(reader buf.Reader, writer buf.Writer, timer signal.ActivityUpdater } // XtlsWrite filter and write xtls protocol -func XtlsWrite(reader buf.Reader, writer buf.Writer, timer signal.ActivityUpdater, conn net.Conn, - ctx context.Context, numberOfPacketToFilter *int, enableXtls *bool, isTLS12orAbove *bool, isTLS *bool, - cipher *uint16, remainingServerHello *int32, -) error { +func XtlsWrite(reader buf.Reader, writer buf.Writer, timer signal.ActivityUpdater, conn net.Conn, trafficState *proxy.TrafficState, ctx context.Context) error { err := func() error { var ct stats.Counter - isPadding := true - shouldSwitchToDirectCopy := false for { buffer, err := reader.ReadMultiBuffer() + if trafficState.WriterSwitchToDirectCopy { + if inbound := session.InboundFromContext(ctx); inbound != nil && inbound.CanSpliceCopy == 2 { + inbound.CanSpliceCopy = 1 // force the value to 1, don't use setter + } + rawConn, _, writerCounter := proxy.UnwrapRawConn(conn) + writer = buf.NewWriter(rawConn) + ct = writerCounter + trafficState.WriterSwitchToDirectCopy = false + } if !buffer.IsEmpty() { - if *numberOfPacketToFilter > 0 { - XtlsFilterTls(buffer, numberOfPacketToFilter, enableXtls, isTLS12orAbove, isTLS, cipher, remainingServerHello, ctx) + if ct != nil { + ct.Add(int64(buffer.Len())) } - if isPadding { - buffer = ReshapeMultiBuffer(ctx, buffer) - var xtlsSpecIndex int - for i, b := range buffer { - if *isTLS && b.Len() >= 6 && bytes.Equal(tlsApplicationDataStart, b.BytesTo(3)) { - var command byte = CommandPaddingEnd - if *enableXtls { - shouldSwitchToDirectCopy = true - xtlsSpecIndex = i - command = CommandPaddingDirect - } - isPadding = false - buffer[i] = XtlsPadding(b, command, nil, *isTLS, ctx) - break - } else if !*isTLS12orAbove && *numberOfPacketToFilter <= 1 { // For compatibility with earlier vision receiver, we finish padding 1 packet early - isPadding = false - buffer[i] = XtlsPadding(b, CommandPaddingEnd, nil, *isTLS, ctx) - break - } - buffer[i] = XtlsPadding(b, CommandPaddingContinue, nil, *isTLS, ctx) - } - if shouldSwitchToDirectCopy { - encryptBuffer, directBuffer := buf.SplitMulti(buffer, xtlsSpecIndex+1) - if !encryptBuffer.IsEmpty() { - timer.Update() - if werr := writer.WriteMultiBuffer(encryptBuffer); werr != nil { - return werr - } - } - time.Sleep(5 * time.Millisecond) // for some device, the first xtls direct packet fails without this delay - - if inbound := session.InboundFromContext(ctx); inbound != nil && inbound.CanSpliceCopy == 2 { - inbound.CanSpliceCopy = 1 // force the value to 1, don't use setter - } - buffer = directBuffer - rawConn, _, writerCounter := proxy.UnwrapRawConn(conn) - writer = buf.NewWriter(rawConn) - ct = writerCounter - } - } - if !buffer.IsEmpty() { - if ct != nil { - ct.Add(int64(buffer.Len())) - } - timer.Update() - if werr := writer.WriteMultiBuffer(buffer); werr != nil { - return werr - } + timer.Update() + if werr := writer.WriteMultiBuffer(buffer); werr != nil { + return werr } } if err != nil { @@ -351,201 +253,3 @@ func XtlsWrite(reader buf.Reader, writer buf.Writer, timer signal.ActivityUpdate } return nil } - -// XtlsFilterTls filter and recognize tls 1.3 and other info -func XtlsFilterTls(buffer buf.MultiBuffer, numberOfPacketToFilter *int, enableXtls *bool, isTLS12orAbove *bool, isTLS *bool, - cipher *uint16, remainingServerHello *int32, ctx context.Context, -) { - for _, b := range buffer { - *numberOfPacketToFilter-- - if b.Len() >= 6 { - startsBytes := b.BytesTo(6) - if bytes.Equal(tlsServerHandShakeStart, startsBytes[:3]) && startsBytes[5] == tlsHandshakeTypeServerHello { - *remainingServerHello = (int32(startsBytes[3])<<8 | int32(startsBytes[4])) + 5 - *isTLS12orAbove = true - *isTLS = true - if b.Len() >= 79 && *remainingServerHello >= 79 { - sessionIdLen := int32(b.Byte(43)) - cipherSuite := b.BytesRange(43+sessionIdLen+1, 43+sessionIdLen+3) - *cipher = uint16(cipherSuite[0])<<8 | uint16(cipherSuite[1]) - } else { - newError("XtlsFilterTls short server hello, tls 1.2 or older? ", b.Len(), " ", *remainingServerHello).WriteToLog(session.ExportIDToError(ctx)) - } - } else if bytes.Equal(tlsClientHandShakeStart, startsBytes[:2]) && startsBytes[5] == tlsHandshakeTypeClientHello { - *isTLS = true - newError("XtlsFilterTls found tls client hello! ", buffer.Len()).WriteToLog(session.ExportIDToError(ctx)) - } - } - if *remainingServerHello > 0 { - end := *remainingServerHello - if end > b.Len() { - end = b.Len() - } - *remainingServerHello -= b.Len() - if bytes.Contains(b.BytesTo(end), tls13SupportedVersions) { - v, ok := Tls13CipherSuiteDic[*cipher] - if !ok { - v = "Old cipher: " + strconv.FormatUint(uint64(*cipher), 16) - } else if v != "TLS_AES_128_CCM_8_SHA256" { - *enableXtls = true - } - newError("XtlsFilterTls found tls 1.3! ", b.Len(), " ", v).WriteToLog(session.ExportIDToError(ctx)) - *numberOfPacketToFilter = 0 - return - } else if *remainingServerHello <= 0 { - newError("XtlsFilterTls found tls 1.2! ", b.Len()).WriteToLog(session.ExportIDToError(ctx)) - *numberOfPacketToFilter = 0 - return - } - newError("XtlsFilterTls inconclusive server hello ", b.Len(), " ", *remainingServerHello).WriteToLog(session.ExportIDToError(ctx)) - } - if *numberOfPacketToFilter <= 0 { - newError("XtlsFilterTls stop filtering", buffer.Len()).WriteToLog(session.ExportIDToError(ctx)) - } - } -} - -// ReshapeMultiBuffer prepare multi buffer for padding stucture (max 21 bytes) -func ReshapeMultiBuffer(ctx context.Context, buffer buf.MultiBuffer) buf.MultiBuffer { - needReshape := 0 - for _, b := range buffer { - if b.Len() >= buf.Size-21 { - needReshape += 1 - } - } - if needReshape == 0 { - return buffer - } - mb2 := make(buf.MultiBuffer, 0, len(buffer)+needReshape) - toPrint := "" - for i, buffer1 := range buffer { - if buffer1.Len() >= buf.Size-21 { - index := int32(bytes.LastIndex(buffer1.Bytes(), tlsApplicationDataStart)) - if index <= 0 || index > buf.Size-21 { - index = buf.Size / 2 - } - buffer2 := buf.New() - buffer2.Write(buffer1.BytesFrom(index)) - buffer1.Resize(0, index) - mb2 = append(mb2, buffer1, buffer2) - toPrint += " " + strconv.Itoa(int(buffer1.Len())) + " " + strconv.Itoa(int(buffer2.Len())) - } else { - mb2 = append(mb2, buffer1) - toPrint += " " + strconv.Itoa(int(buffer1.Len())) - } - buffer[i] = nil - } - buffer = buffer[:0] - newError("ReshapeMultiBuffer ", toPrint).WriteToLog(session.ExportIDToError(ctx)) - return mb2 -} - -// XtlsPadding add padding to eliminate length siganature during tls handshake -func XtlsPadding(b *buf.Buffer, command byte, userUUID *[]byte, longPadding bool, ctx context.Context) *buf.Buffer { - var contentLen int32 = 0 - var paddingLen int32 = 0 - if b != nil { - contentLen = b.Len() - } - if contentLen < 900 && longPadding { - l, err := rand.Int(rand.Reader, big.NewInt(500)) - if err != nil { - newError("failed to generate padding").Base(err).WriteToLog(session.ExportIDToError(ctx)) - } - paddingLen = int32(l.Int64()) + 900 - contentLen - } else { - l, err := rand.Int(rand.Reader, big.NewInt(256)) - if err != nil { - newError("failed to generate padding").Base(err).WriteToLog(session.ExportIDToError(ctx)) - } - paddingLen = int32(l.Int64()) - } - if paddingLen > buf.Size-21-contentLen { - paddingLen = buf.Size - 21 - contentLen - } - newbuffer := buf.New() - if userUUID != nil { - newbuffer.Write(*userUUID) - *userUUID = nil - } - newbuffer.Write([]byte{command, byte(contentLen >> 8), byte(contentLen), byte(paddingLen >> 8), byte(paddingLen)}) - if b != nil { - newbuffer.Write(b.Bytes()) - b.Release() - b = nil - } - newbuffer.Extend(paddingLen) - newError("XtlsPadding ", contentLen, " ", paddingLen, " ", command).WriteToLog(session.ExportIDToError(ctx)) - return newbuffer -} - -// XtlsUnpadding remove padding and parse command -func XtlsUnpadding(ctx context.Context, buffer buf.MultiBuffer, userUUID []byte, remainingContent *int32, remainingPadding *int32, currentCommand *int) buf.MultiBuffer { - posindex := 0 - var posByte int32 = 0 - if *remainingContent == -1 && *remainingPadding == -1 { - for i, b := range buffer { - if b.Len() >= 21 && bytes.Equal(userUUID, b.BytesTo(16)) { - posindex = i - posByte = 16 - *remainingContent = 0 - *remainingPadding = 0 - *currentCommand = 0 - break - } - } - } - if *remainingContent == -1 && *remainingPadding == -1 { - return buffer - } - mb2 := make(buf.MultiBuffer, 0, len(buffer)) - for i := 0; i < posindex; i++ { - newbuffer := buf.New() - newbuffer.Write(buffer[i].Bytes()) - mb2 = append(mb2, newbuffer) - } - for i := posindex; i < len(buffer); i++ { - b := buffer[i] - for posByte < b.Len() { - if *remainingContent <= 0 && *remainingPadding <= 0 { - if *currentCommand == 1 { // possible buffer after padding, no need to worry about xtls (command 2) - len := b.Len() - posByte - newbuffer := buf.New() - newbuffer.Write(b.BytesRange(posByte, posByte+len)) - mb2 = append(mb2, newbuffer) - posByte += len - } else { - paddingInfo := b.BytesRange(posByte, posByte+5) - *currentCommand = int(paddingInfo[0]) - *remainingContent = int32(paddingInfo[1])<<8 | int32(paddingInfo[2]) - *remainingPadding = int32(paddingInfo[3])<<8 | int32(paddingInfo[4]) - newError("Xtls Unpadding new block", i, " ", posByte, " content ", *remainingContent, " padding ", *remainingPadding, " ", paddingInfo[0]).WriteToLog(session.ExportIDToError(ctx)) - posByte += 5 - } - } else if *remainingContent > 0 { - len := *remainingContent - if b.Len() < posByte+*remainingContent { - len = b.Len() - posByte - } - newbuffer := buf.New() - newbuffer.Write(b.BytesRange(posByte, posByte+len)) - mb2 = append(mb2, newbuffer) - *remainingContent -= len - posByte += len - } else { // remainingPadding > 0 - len := *remainingPadding - if b.Len() < posByte+*remainingPadding { - len = b.Len() - posByte - } - *remainingPadding -= len - posByte += len - } - if posByte == b.Len() { - posByte = 0 - break - } - } - } - buf.ReleaseMulti(buffer) - return mb2 -} diff --git a/proxy/vless/inbound/inbound.go b/proxy/vless/inbound/inbound.go index 388aeecb..4cd3fcb1 100644 --- a/proxy/vless/inbound/inbound.go +++ b/proxy/vless/inbound/inbound.go @@ -28,6 +28,7 @@ import ( feature_inbound "github.com/xtls/xray-core/features/inbound" "github.com/xtls/xray-core/features/policy" "github.com/xtls/xray-core/features/routing" + "github.com/xtls/xray-core/proxy" "github.com/xtls/xray-core/proxy/vless" "github.com/xtls/xray-core/proxy/vless/encoding" "github.com/xtls/xray-core/transport/internet/reality" @@ -510,13 +511,7 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection s serverReader := link.Reader // .(*pipe.Reader) serverWriter := link.Writer // .(*pipe.Writer) - enableXtls := false - isTLS12orAbove := false - isTLS := false - var cipher uint16 = 0 - var remainingServerHello int32 = -1 - numberOfPacketToFilter := 8 - + trafficState := proxy.NewTrafficState(account.ID.Bytes()) postRequest := func() error { defer timer.SetTimeout(sessionPolicy.Timeouts.DownlinkOnly) @@ -527,8 +522,7 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection s if requestAddons.Flow == vless.XRV { ctx1 := session.ContextWithInbound(ctx, nil) // TODO enable splice - err = encoding.XtlsRead(clientReader, serverWriter, timer, connection, input, rawInput, ctx1, account.ID.Bytes(), - &numberOfPacketToFilter, &enableXtls, &isTLS12orAbove, &isTLS, &cipher, &remainingServerHello) + err = encoding.XtlsRead(clientReader, serverWriter, timer, connection, input, rawInput, trafficState, ctx1) } else { // from clientReader.ReadMultiBuffer to serverWriter.WriteMultiBufer err = buf.Copy(clientReader, serverWriter, buf.UpdateActivity(timer)) @@ -550,19 +544,11 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection s } // default: clientWriter := bufferWriter - clientWriter := encoding.EncodeBodyAddons(bufferWriter, request, responseAddons) - userUUID := account.ID.Bytes() + clientWriter := encoding.EncodeBodyAddons(bufferWriter, request, requestAddons, trafficState, ctx) multiBuffer, err1 := serverReader.ReadMultiBuffer() if err1 != nil { return err1 // ... } - if requestAddons.Flow == vless.XRV { - encoding.XtlsFilterTls(multiBuffer, &numberOfPacketToFilter, &enableXtls, &isTLS12orAbove, &isTLS, &cipher, &remainingServerHello, ctx) - multiBuffer = encoding.ReshapeMultiBuffer(ctx, multiBuffer) - for i, b := range multiBuffer { - multiBuffer[i] = encoding.XtlsPadding(b, encoding.CommandPaddingContinue, &userUUID, isTLS, ctx) - } - } if err := clientWriter.WriteMultiBuffer(multiBuffer); err != nil { return err // ... } @@ -573,8 +559,7 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection s var err error if requestAddons.Flow == vless.XRV { - err = encoding.XtlsWrite(serverReader, clientWriter, timer, connection, ctx, &numberOfPacketToFilter, - &enableXtls, &isTLS12orAbove, &isTLS, &cipher, &remainingServerHello) + err = encoding.XtlsWrite(serverReader, clientWriter, timer, connection, trafficState, ctx) } else { // from serverReader.ReadMultiBuffer to clientWriter.WriteMultiBufer err = buf.Copy(serverReader, clientWriter, buf.UpdateActivity(timer)) diff --git a/proxy/vless/outbound/outbound.go b/proxy/vless/outbound/outbound.go index bc2e6625..cd30617c 100644 --- a/proxy/vless/outbound/outbound.go +++ b/proxy/vless/outbound/outbound.go @@ -22,6 +22,7 @@ import ( "github.com/xtls/xray-core/common/xudp" "github.com/xtls/xray-core/core" "github.com/xtls/xray-core/features/policy" + "github.com/xtls/xray-core/proxy" "github.com/xtls/xray-core/proxy/vless" "github.com/xtls/xray-core/proxy/vless/encoding" "github.com/xtls/xray-core/transport" @@ -183,13 +184,7 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte clientReader := link.Reader // .(*pipe.Reader) clientWriter := link.Writer // .(*pipe.Writer) - enableXtls := false - isTLS12orAbove := false - isTLS := false - var cipher uint16 = 0 - var remainingServerHello int32 = -1 - numberOfPacketToFilter := 8 - + trafficState := proxy.NewTrafficState(account.ID.Bytes()) if request.Command == protocol.RequestCommandUDP && h.cone && request.Port != 53 && request.Port != 443 { request.Command = protocol.RequestCommandMux request.Address = net.DomainAddress("v1.mux.cool") @@ -205,22 +200,14 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte } // default: serverWriter := bufferWriter - serverWriter := encoding.EncodeBodyAddons(bufferWriter, request, requestAddons) + serverWriter := encoding.EncodeBodyAddons(bufferWriter, request, requestAddons, trafficState, ctx) if request.Command == protocol.RequestCommandMux && request.Port == 666 { serverWriter = xudp.NewPacketWriter(serverWriter, target, xudp.GetGlobalID(ctx)) } - userUUID := account.ID.Bytes() timeoutReader, ok := clientReader.(buf.TimeoutReader) if ok { multiBuffer, err1 := timeoutReader.ReadMultiBufferTimeout(time.Millisecond * 500) if err1 == nil { - if requestAddons.Flow == vless.XRV { - encoding.XtlsFilterTls(multiBuffer, &numberOfPacketToFilter, &enableXtls, &isTLS12orAbove, &isTLS, &cipher, &remainingServerHello, ctx) - multiBuffer = encoding.ReshapeMultiBuffer(ctx, multiBuffer) - for i, b := range multiBuffer { - multiBuffer[i] = encoding.XtlsPadding(b, encoding.CommandPaddingContinue, &userUUID, isTLS, ctx) - } - } if err := serverWriter.WriteMultiBuffer(multiBuffer); err != nil { return err // ... } @@ -228,10 +215,9 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte return err1 } else if requestAddons.Flow == vless.XRV { mb := make(buf.MultiBuffer, 1) - mb[0] = encoding.XtlsPadding(nil, encoding.CommandPaddingContinue, &userUUID, true, ctx) // we do a long padding to hide vless header newError("Insert padding with empty content to camouflage VLESS header ", mb.Len()).WriteToLog(session.ExportIDToError(ctx)) if err := serverWriter.WriteMultiBuffer(mb); err != nil { - return err + return err // ... } } } else { @@ -254,8 +240,7 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte } } ctx1 := session.ContextWithOutbound(ctx, nil) // TODO enable splice - err = encoding.XtlsWrite(clientReader, serverWriter, timer, conn, ctx1, &numberOfPacketToFilter, - &enableXtls, &isTLS12orAbove, &isTLS, &cipher, &remainingServerHello) + err = encoding.XtlsWrite(clientReader, serverWriter, timer, conn, trafficState, ctx1) } else { // from clientReader.ReadMultiBuffer to serverWriter.WriteMultiBufer err = buf.Copy(clientReader, serverWriter, buf.UpdateActivity(timer)) @@ -286,8 +271,7 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte } if requestAddons.Flow == vless.XRV { - err = encoding.XtlsRead(serverReader, clientWriter, timer, conn, input, rawInput, ctx, account.ID.Bytes(), - &numberOfPacketToFilter, &enableXtls, &isTLS12orAbove, &isTLS, &cipher, &remainingServerHello) + err = encoding.XtlsRead(serverReader, clientWriter, timer, conn, input, rawInput, trafficState, ctx) } else { // from serverReader.ReadMultiBuffer to clientWriter.WriteMultiBufer err = buf.Copy(serverReader, clientWriter, buf.UpdateActivity(timer))