QUIC sniffer: Full support for handling multiple initial packets (#4642)

Co-authored-by: RPRX <63339210+RPRX@users.noreply.github.com>
Co-authored-by: Vigilans <vigilans@foxmail.com>
Co-authored-by: Shelikhoo <xiaokangwang@outlook.com>
Co-authored-by: dyhkwong <50692134+dyhkwong@users.noreply.github.com>
This commit is contained in:
j2rong4cn 2025-04-28 18:03:03 +08:00 committed by GitHub
parent a608c5a1db
commit 58c48664e2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 371 additions and 95 deletions

View file

@ -33,23 +33,21 @@ type cachedReader struct {
cache buf.MultiBuffer
}
func (r *cachedReader) Cache(b *buf.Buffer) {
mb, _ := r.reader.ReadMultiBufferTimeout(time.Millisecond * 100)
func (r *cachedReader) Cache(b *buf.Buffer, deadline time.Duration) error {
mb, err := r.reader.ReadMultiBufferTimeout(deadline)
if err != nil {
return err
}
r.Lock()
if !mb.IsEmpty() {
r.cache, _ = buf.MergeMulti(r.cache, mb)
}
cacheLen := r.cache.Len()
if cacheLen <= b.Cap() {
b.Clear()
} else {
b.Release()
*b = *buf.NewWithSize(cacheLen)
}
rawBytes := b.Extend(cacheLen)
b.Clear()
rawBytes := b.Extend(b.Cap())
n := r.cache.Copy(rawBytes)
b.Resize(0, int32(n))
r.Unlock()
return nil
}
func (r *cachedReader) readInternal() buf.MultiBuffer {
@ -355,7 +353,7 @@ func (d *DefaultDispatcher) DispatchLink(ctx context.Context, destination net.De
}
func sniffer(ctx context.Context, cReader *cachedReader, metadataOnly bool, network net.Network) (SniffResult, error) {
payload := buf.New()
payload := buf.NewWithSize(32767)
defer payload.Release()
sniffer := NewSniffer(ctx)
@ -367,26 +365,33 @@ func sniffer(ctx context.Context, cReader *cachedReader, metadataOnly bool, netw
}
contentResult, contentErr := func() (SniffResult, error) {
cacheDeadline := 200 * time.Millisecond
totalAttempt := 0
for {
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
totalAttempt++
if totalAttempt > 2 {
return nil, errSniffingTimeout
}
cachingStartingTimeStamp := time.Now()
cacheErr := cReader.Cache(payload, cacheDeadline)
cachingTimeElapsed := time.Since(cachingStartingTimeStamp)
cacheDeadline -= cachingTimeElapsed
cReader.Cache(payload)
if !payload.IsEmpty() {
result, err := sniffer.Sniff(ctx, payload.Bytes(), network)
if err != common.ErrNoClue {
switch err {
case common.ErrNoClue: // No Clue: protocol not matches, and sniffer cannot determine whether there will be a match or not
totalAttempt++
case protocol.ErrProtoNeedMoreData: // Protocol Need More Data: protocol matches, but need more data to complete sniffing
if cacheErr != nil { // Cache error (e.g. timeout) counts for failed attempt
totalAttempt++
}
default:
return result, err
}
}
if payload.IsFull() {
return nil, errUnknownContent
if totalAttempt >= 2 || cacheDeadline <= 0 {
return nil, errSniffingTimeout
}
}
}

View file

@ -6,6 +6,7 @@ import (
"github.com/xtls/xray-core/common"
"github.com/xtls/xray-core/common/errors"
"github.com/xtls/xray-core/common/net"
"github.com/xtls/xray-core/common/protocol"
"github.com/xtls/xray-core/common/protocol/bittorrent"
"github.com/xtls/xray-core/common/protocol/http"
"github.com/xtls/xray-core/common/protocol/quic"
@ -58,14 +59,17 @@ var errUnknownContent = errors.New("unknown content")
func (s *Sniffer) Sniff(c context.Context, payload []byte, network net.Network) (SniffResult, error) {
var pendingSniffer []protocolSnifferWithMetadata
for _, si := range s.sniffer {
s := si.protocolSniffer
protocolSniffer := si.protocolSniffer
if si.metadataSniffer || si.network != network {
continue
}
result, err := s(c, payload)
result, err := protocolSniffer(c, payload)
if err == common.ErrNoClue {
pendingSniffer = append(pendingSniffer, si)
continue
} else if err == protocol.ErrProtoNeedMoreData { // Sniffer protocol matched, but need more data to complete sniffing
s.sniffer = []protocolSnifferWithMetadata{si}
return nil, err
}
if err == nil && result != nil {

View file

@ -15,6 +15,15 @@ const (
var pool = bytespool.GetPool(Size)
// ownership represents the data owner of the buffer.
type ownership uint8
const (
managed ownership = iota
unmanaged
bytespools
)
// Buffer is a recyclable allocation of a byte array. Buffer.Release() recycles
// the buffer into an internal buffer pool, in order to recreate a buffer more
// quickly.
@ -22,11 +31,11 @@ type Buffer struct {
v []byte
start int32
end int32
unmanaged bool
ownership ownership
UDP *net.Destination
}
// New creates a Buffer with 0 length and 8K capacity.
// New creates a Buffer with 0 length and 8K capacity, managed.
func New() *Buffer {
buf := pool.Get().([]byte)
if cap(buf) >= Size {
@ -40,7 +49,7 @@ func New() *Buffer {
}
}
// NewExisted creates a managed, standard size Buffer with an existed bytearray
// NewExisted creates a standard size Buffer with an existed bytearray, managed.
func NewExisted(b []byte) *Buffer {
if cap(b) < Size {
panic("Invalid buffer")
@ -57,16 +66,16 @@ func NewExisted(b []byte) *Buffer {
}
}
// FromBytes creates a Buffer with an existed bytearray
// FromBytes creates a Buffer with an existed bytearray, unmanaged.
func FromBytes(b []byte) *Buffer {
return &Buffer{
v: b,
end: int32(len(b)),
unmanaged: true,
ownership: unmanaged,
}
}
// StackNew creates a new Buffer object on stack.
// StackNew creates a new Buffer object on stack, managed.
// This method is for buffers that is released in the same function.
func StackNew() Buffer {
buf := pool.Get().([]byte)
@ -81,9 +90,17 @@ func StackNew() Buffer {
}
}
// NewWithSize creates a Buffer with 0 length and capacity with at least the given size, bytespool's.
func NewWithSize(size int32) *Buffer {
return &Buffer{
v: bytespool.Alloc(size),
ownership: bytespools,
}
}
// Release recycles the buffer into an internal buffer pool.
func (b *Buffer) Release() {
if b == nil || b.v == nil || b.unmanaged {
if b == nil || b.v == nil || b.ownership == unmanaged {
return
}
@ -91,8 +108,13 @@ func (b *Buffer) Release() {
b.v = nil
b.Clear()
if cap(p) == Size {
pool.Put(p)
switch b.ownership {
case managed:
if cap(p) == Size {
pool.Put(p)
}
case bytespools:
bytespool.Free(p)
}
b.UDP = nil
}
@ -215,13 +237,6 @@ func (b *Buffer) Cap() int32 {
return int32(len(b.v))
}
// NewWithSize creates a Buffer with 0 length and capacity with at least the given size.
func NewWithSize(size int32) *Buffer {
return &Buffer{
v: bytespool.Alloc(size),
}
}
// IsEmpty returns true if the buffer is empty.
func (b *Buffer) IsEmpty() bool {
return b.Len() == 0

View file

@ -1 +1,7 @@
package protocol // import "github.com/xtls/xray-core/common/protocol"
import (
"errors"
)
var ErrProtoNeedMoreData = errors.New("protocol matches, but need more data to complete sniffing")

View file

@ -1,7 +1,6 @@
package quic
import (
"context"
"crypto"
"crypto/aes"
"crypto/tls"
@ -13,6 +12,7 @@ import (
"github.com/xtls/xray-core/common/buf"
"github.com/xtls/xray-core/common/bytespool"
"github.com/xtls/xray-core/common/errors"
"github.com/xtls/xray-core/common/protocol"
ptls "github.com/xtls/xray-core/common/protocol/tls"
"golang.org/x/crypto/hkdf"
)
@ -47,22 +47,17 @@ var (
errNotQuicInitial = errors.New("not initial packet")
)
func SniffQUIC(b []byte) (resultReturn *SniffHeader, errorReturn error) {
// In extremely rare cases, this sniffer may cause slice error
// and we set recover() here to prevent crash.
// TODO: Thoroughly fix this panic
defer func() {
if r := recover(); r != nil {
errors.LogError(context.Background(), "Failed to sniff QUIC: ", r)
resultReturn = nil
errorReturn = common.ErrNoClue
}
}()
func SniffQUIC(b []byte) (*SniffHeader, error) {
if len(b) == 0 {
return nil, common.ErrNoClue
}
// Crypto data separated across packets
cryptoLen := 0
cryptoData := bytespool.Alloc(int32(len(b)))
cryptoData := bytespool.Alloc(32767)
defer bytespool.Free(cryptoData)
cache := buf.New()
defer cache.Release()
// Parse QUIC packets
for len(b) > 0 {
@ -105,13 +100,15 @@ func SniffQUIC(b []byte) (resultReturn *SniffHeader, errorReturn error) {
return nil, errNotQuic
}
tokenLen, err := quicvarint.Read(buffer)
if err != nil || tokenLen > uint64(len(b)) {
return nil, errNotQuic
}
if isQuicInitial { // Only initial packets have token, see https://datatracker.ietf.org/doc/html/rfc9000#section-17.2.2
tokenLen, err := quicvarint.Read(buffer)
if err != nil || tokenLen > uint64(len(b)) {
return nil, errNotQuic
}
if _, err = buffer.ReadBytes(int32(tokenLen)); err != nil {
return nil, errNotQuic
if _, err = buffer.ReadBytes(int32(tokenLen)); err != nil {
return nil, errNotQuic
}
}
packetLen, err := quicvarint.Read(buffer)
@ -130,9 +127,6 @@ func SniffQUIC(b []byte) (resultReturn *SniffHeader, errorReturn error) {
continue
}
origPNBytes := make([]byte, 4)
copy(origPNBytes, b[hdrLen:hdrLen+4])
var salt []byte
if versionNumber == version1 {
salt = quicSalt
@ -147,44 +141,34 @@ func SniffQUIC(b []byte) (resultReturn *SniffHeader, errorReturn error) {
return nil, err
}
cache := buf.New()
defer cache.Release()
cache.Clear()
mask := cache.Extend(int32(block.BlockSize()))
block.Encrypt(mask, b[hdrLen+4:hdrLen+4+16])
b[0] ^= mask[0] & 0xf
for i := range b[hdrLen : hdrLen+4] {
packetNumberLength := int(b[0]&0x3 + 1)
for i := range packetNumberLength {
b[hdrLen+i] ^= mask[i+1]
}
packetNumberLength := b[0]&0x3 + 1
if packetNumberLength != 1 {
return nil, errNotQuicInitial
}
var packetNumber uint32
{
n, err := buffer.ReadByte()
if err != nil {
return nil, err
}
packetNumber = uint32(n)
}
extHdrLen := hdrLen + int(packetNumberLength)
copy(b[extHdrLen:hdrLen+4], origPNBytes[packetNumberLength:])
data := b[extHdrLen : int(packetLen)+hdrLen]
key := hkdfExpandLabel(crypto.SHA256, secret, []byte{}, "quic key", 16)
iv := hkdfExpandLabel(crypto.SHA256, secret, []byte{}, "quic iv", 12)
cipher := AEADAESGCMTLS13(key, iv)
nonce := cache.Extend(int32(cipher.NonceSize()))
binary.BigEndian.PutUint64(nonce[len(nonce)-8:], uint64(packetNumber))
_, err = buffer.Read(nonce[len(nonce)-packetNumberLength:])
if err != nil {
return nil, err
}
extHdrLen := hdrLen + packetNumberLength
data := b[extHdrLen : int(packetLen)+hdrLen]
decrypted, err := cipher.Open(b[extHdrLen:extHdrLen], nonce, data, b[:extHdrLen])
if err != nil {
return nil, err
}
buffer = buf.FromBytes(decrypted)
for i := 0; !buffer.IsEmpty(); i++ {
frameType := byte(0x0) // Default to PADDING frame
for !buffer.IsEmpty() {
frameType, _ := buffer.ReadByte()
for frameType == 0x0 && !buffer.IsEmpty() {
frameType, _ = buffer.ReadByte()
}
@ -234,13 +218,12 @@ func SniffQUIC(b []byte) (resultReturn *SniffHeader, errorReturn error) {
return nil, io.ErrUnexpectedEOF
}
if cryptoLen < int(offset+length) {
cryptoLen = int(offset + length)
if len(cryptoData) < cryptoLen {
newCryptoData := bytespool.Alloc(int32(cryptoLen))
copy(newCryptoData, cryptoData)
bytespool.Free(cryptoData)
cryptoData = newCryptoData
newCryptoLen := int(offset + length)
if len(cryptoData) < newCryptoLen {
return nil, io.ErrShortBuffer
}
wipeBytes(cryptoData[cryptoLen:newCryptoLen])
cryptoLen = newCryptoLen
}
if _, err := buffer.Read(cryptoData[offset : offset+length]); err != nil { // Field: Crypto Data
return nil, io.ErrUnexpectedEOF
@ -276,7 +259,14 @@ func SniffQUIC(b []byte) (resultReturn *SniffHeader, errorReturn error) {
}
return &SniffHeader{domain: tlsHdr.Domain()}, nil
}
return nil, common.ErrNoClue
// All payload is parsed as valid QUIC packets, but we need more packets for crypto data to read client hello.
return nil, protocol.ErrProtoNeedMoreData
}
func wipeBytes(b []byte) {
for i := range len(b) {
b[i] = 0x0
}
}
func hkdfExpandLabel(hash crypto.Hash, secret, context []byte, label string, length int) []byte {

File diff suppressed because one or more lines are too long

View file

@ -3,9 +3,9 @@ package tls
import (
"encoding/binary"
"errors"
"strings"
"github.com/xtls/xray-core/common"
"github.com/xtls/xray-core/common/protocol"
)
type SniffHeader struct {
@ -59,9 +59,6 @@ func ReadClientHello(data []byte, h *SniffHeader) error {
}
data = data[1+compressionMethodsLen:]
if len(data) == 0 {
return errNotClientHello
}
if len(data) < 2 {
return errNotClientHello
}
@ -104,13 +101,21 @@ func ReadClientHello(data []byte, h *SniffHeader) error {
return errNotClientHello
}
if nameType == 0 {
serverName := string(d[:nameLen])
// QUIC separated across packets
// May cause the serverName to be incomplete
b := byte(0)
for _, b = range d[:nameLen] {
if b <= ' ' {
return protocol.ErrProtoNeedMoreData
}
}
// An SNI value may not include a
// trailing dot. See
// https://tools.ietf.org/html/rfc6066#section-3.
if strings.HasSuffix(serverName, ".") {
if b == '.' {
return errNotClientHello
}
serverName := string(d[:nameLen])
h.domain = serverName
return nil
}