diff --git a/app/dispatcher/default.go b/app/dispatcher/default.go index f38b56c9..de8d1913 100644 --- a/app/dispatcher/default.go +++ b/app/dispatcher/default.go @@ -41,8 +41,14 @@ func (r *cachedReader) Cache(b *buf.Buffer) { if !mb.IsEmpty() { r.cache, _ = buf.MergeMulti(r.cache, mb) } - b.Clear() - rawBytes := b.Extend(buf.Size) + cacheLen := r.cache.Len() + if cacheLen <= b.Cap() { + b.Clear() + } else { + b.Release() + *b = *buf.NewWithSize(cacheLen) + } + rawBytes := b.Extend(cacheLen) n := r.cache.Copy(rawBytes) b.Resize(0, int32(n)) r.Unlock() diff --git a/common/buf/buffer.go b/common/buf/buffer.go index 57105d67..15d4d4a5 100644 --- a/common/buf/buffer.go +++ b/common/buf/buffer.go @@ -207,6 +207,21 @@ func (b *Buffer) Len() int32 { return b.end - b.start } +// Cap returns the capacity of the buffer content. +func (b *Buffer) Cap() int32 { + if b == nil { + return 0 + } + return int32(len(b.v)) +} + +// NewWithSize creates a Buffer with 0 length and capacity with at least the given size. +func NewWithSize(size int32) *Buffer { + return &Buffer{ + v: bytespool.Alloc(size), + } +} + // IsEmpty returns true if the buffer is empty. func (b *Buffer) IsEmpty() bool { return b.Len() == 0 diff --git a/common/protocol/quic/sniff.go b/common/protocol/quic/sniff.go index bf461464..8719a085 100644 --- a/common/protocol/quic/sniff.go +++ b/common/protocol/quic/sniff.go @@ -47,206 +47,224 @@ var ( ) func SniffQUIC(b []byte) (*SniffHeader, error) { - buffer := buf.FromBytes(b) - typeByte, err := buffer.ReadByte() - if err != nil { - return nil, errNotQuic - } - isLongHeader := typeByte&0x80 > 0 - if !isLongHeader || typeByte&0x40 == 0 { - return nil, errNotQuicInitial - } + // Crypto data separated across packets + cryptoLen := 0 + cryptoData := bytespool.Alloc(int32(len(b))) + defer bytespool.Free(cryptoData) - vb, err := buffer.ReadBytes(4) - if err != nil { - return nil, errNotQuic - } + // Parse QUIC packets + for len(b) > 0 { + buffer := buf.FromBytes(b) + typeByte, err := buffer.ReadByte() + if err != nil { + return nil, errNotQuic + } - versionNumber := binary.BigEndian.Uint32(vb) + isLongHeader := typeByte&0x80 > 0 + if !isLongHeader || typeByte&0x40 == 0 { + return nil, errNotQuicInitial + } - if versionNumber != 0 && typeByte&0x40 == 0 { - return nil, errNotQuic - } else if versionNumber != versionDraft29 && versionNumber != version1 { - return nil, errNotQuic - } + vb, err := buffer.ReadBytes(4) + if err != nil { + return nil, errNotQuic + } - if (typeByte&0x30)>>4 != 0x0 { - return nil, errNotQuicInitial - } + versionNumber := binary.BigEndian.Uint32(vb) + if versionNumber != 0 && typeByte&0x40 == 0 { + return nil, errNotQuic + } else if versionNumber != versionDraft29 && versionNumber != version1 { + return nil, errNotQuic + } - var destConnID []byte - if l, err := buffer.ReadByte(); err != nil { - return nil, errNotQuic - } else if destConnID, err = buffer.ReadBytes(int32(l)); err != nil { - return nil, errNotQuic - } + packetType := (typeByte & 0x30) >> 4 + isQuicInitial := packetType == 0x0 - if l, err := buffer.ReadByte(); err != nil { - return nil, errNotQuic - } else if common.Error2(buffer.ReadBytes(int32(l))) != nil { - return nil, errNotQuic - } + var destConnID []byte + if l, err := buffer.ReadByte(); err != nil { + return nil, errNotQuic + } else if destConnID, err = buffer.ReadBytes(int32(l)); err != nil { + return nil, errNotQuic + } - tokenLen, err := quicvarint.Read(buffer) - if err != nil || tokenLen > uint64(len(b)) { - return nil, errNotQuic - } + if l, err := buffer.ReadByte(); err != nil { + return nil, errNotQuic + } else if common.Error2(buffer.ReadBytes(int32(l))) != nil { + return nil, errNotQuic + } - if _, err = buffer.ReadBytes(int32(tokenLen)); err != nil { - return nil, errNotQuic - } + tokenLen, err := quicvarint.Read(buffer) + if err != nil || tokenLen > uint64(len(b)) { + return nil, errNotQuic + } - packetLen, err := quicvarint.Read(buffer) - if err != nil { - return nil, errNotQuic - } + if _, err = buffer.ReadBytes(int32(tokenLen)); err != nil { + return nil, errNotQuic + } - hdrLen := len(b) - int(buffer.Len()) + packetLen, err := quicvarint.Read(buffer) + if err != nil { + return nil, errNotQuic + } - origPNBytes := make([]byte, 4) - copy(origPNBytes, b[hdrLen:hdrLen+4]) + hdrLen := len(b) - int(buffer.Len()) + if len(b) < hdrLen+int(packetLen) { + return nil, common.ErrNoClue // Not enough data to read as a QUIC packet. QUIC is UDP-based, so this is unlikely to happen. + } - var salt []byte - if versionNumber == version1 { - salt = quicSalt - } else { - salt = quicSaltOld - } - initialSecret := hkdf.Extract(crypto.SHA256.New, destConnID, salt) - secret := hkdfExpandLabel(crypto.SHA256, initialSecret, []byte{}, "client in", crypto.SHA256.Size()) - hpKey := hkdfExpandLabel(initialSuite.Hash, secret, []byte{}, "quic hp", initialSuite.KeyLen) - block, err := aes.NewCipher(hpKey) - if err != nil { - return nil, err - } + restPayload := b[hdrLen+int(packetLen):] + if !isQuicInitial { // Skip this packet if it's not initial packet + b = restPayload + continue + } - cache := buf.New() - defer cache.Release() + origPNBytes := make([]byte, 4) + copy(origPNBytes, b[hdrLen:hdrLen+4]) - mask := cache.Extend(int32(block.BlockSize())) - block.Encrypt(mask, b[hdrLen+4:hdrLen+4+16]) - b[0] ^= mask[0] & 0xf - for i := range b[hdrLen : hdrLen+4] { - b[hdrLen+i] ^= mask[i+1] - } - packetNumberLength := b[0]&0x3 + 1 - if packetNumberLength != 1 { - return nil, errNotQuicInitial - } - var packetNumber uint32 - { - n, err := buffer.ReadByte() + var salt []byte + if versionNumber == version1 { + salt = quicSalt + } else { + salt = quicSaltOld + } + initialSecret := hkdf.Extract(crypto.SHA256.New, destConnID, salt) + secret := hkdfExpandLabel(crypto.SHA256, initialSecret, []byte{}, "client in", crypto.SHA256.Size()) + hpKey := hkdfExpandLabel(initialSuite.Hash, secret, []byte{}, "quic hp", initialSuite.KeyLen) + block, err := aes.NewCipher(hpKey) if err != nil { return nil, err } - packetNumber = uint32(n) - } - if packetNumber != 0 && packetNumber != 1 { - return nil, errNotQuicInitial - } + cache := buf.New() + defer cache.Release() - extHdrLen := hdrLen + int(packetNumberLength) - copy(b[extHdrLen:hdrLen+4], origPNBytes[packetNumberLength:]) - data := b[extHdrLen : int(packetLen)+hdrLen] - - key := hkdfExpandLabel(crypto.SHA256, secret, []byte{}, "quic key", 16) - iv := hkdfExpandLabel(crypto.SHA256, secret, []byte{}, "quic iv", 12) - cipher := AEADAESGCMTLS13(key, iv) - nonce := cache.Extend(int32(cipher.NonceSize())) - binary.BigEndian.PutUint64(nonce[len(nonce)-8:], uint64(packetNumber)) - decrypted, err := cipher.Open(b[extHdrLen:extHdrLen], nonce, data, b[:extHdrLen]) - if err != nil { - return nil, err - } - buffer = buf.FromBytes(decrypted) - - cryptoLen := uint(0) - cryptoData := bytespool.Alloc(buffer.Len()) - defer bytespool.Free(cryptoData) - for i := 0; !buffer.IsEmpty(); i++ { - frameType := byte(0x0) // Default to PADDING frame - for frameType == 0x0 && !buffer.IsEmpty() { - frameType, _ = buffer.ReadByte() + mask := cache.Extend(int32(block.BlockSize())) + block.Encrypt(mask, b[hdrLen+4:hdrLen+4+16]) + b[0] ^= mask[0] & 0xf + for i := range b[hdrLen : hdrLen+4] { + b[hdrLen+i] ^= mask[i+1] } - switch frameType { - case 0x00: // PADDING frame - case 0x01: // PING frame - case 0x02, 0x03: // ACK frame - if _, err = quicvarint.Read(buffer); err != nil { // Field: Largest Acknowledged - return nil, io.ErrUnexpectedEOF - } - if _, err = quicvarint.Read(buffer); err != nil { // Field: ACK Delay - return nil, io.ErrUnexpectedEOF - } - ackRangeCount, err := quicvarint.Read(buffer) // Field: ACK Range Count - if err != nil { - return nil, io.ErrUnexpectedEOF - } - if _, err = quicvarint.Read(buffer); err != nil { // Field: First ACK Range - return nil, io.ErrUnexpectedEOF - } - for i := 0; i < int(ackRangeCount); i++ { // Field: ACK Range - if _, err = quicvarint.Read(buffer); err != nil { // Field: ACK Range -> Gap - return nil, io.ErrUnexpectedEOF - } - if _, err = quicvarint.Read(buffer); err != nil { // Field: ACK Range -> ACK Range Length - return nil, io.ErrUnexpectedEOF - } - } - if frameType == 0x03 { - if _, err = quicvarint.Read(buffer); err != nil { // Field: ECN Counts -> ECT0 Count - return nil, io.ErrUnexpectedEOF - } - if _, err = quicvarint.Read(buffer); err != nil { // Field: ECN Counts -> ECT1 Count - return nil, io.ErrUnexpectedEOF - } - if _, err = quicvarint.Read(buffer); err != nil { //nolint:misspell // Field: ECN Counts -> ECT-CE Count - return nil, io.ErrUnexpectedEOF - } - } - case 0x06: // CRYPTO frame, we will use this frame - offset, err := quicvarint.Read(buffer) // Field: Offset - if err != nil { - return nil, io.ErrUnexpectedEOF - } - length, err := quicvarint.Read(buffer) // Field: Length - if err != nil || length > uint64(buffer.Len()) { - return nil, io.ErrUnexpectedEOF - } - if cryptoLen < uint(offset+length) { - cryptoLen = uint(offset + length) - } - if _, err := buffer.Read(cryptoData[offset : offset+length]); err != nil { // Field: Crypto Data - return nil, io.ErrUnexpectedEOF - } - case 0x1c: // CONNECTION_CLOSE frame, only 0x1c is permitted in initial packet - if _, err = quicvarint.Read(buffer); err != nil { // Field: Error Code - return nil, io.ErrUnexpectedEOF - } - if _, err = quicvarint.Read(buffer); err != nil { // Field: Frame Type - return nil, io.ErrUnexpectedEOF - } - length, err := quicvarint.Read(buffer) // Field: Reason Phrase Length - if err != nil { - return nil, io.ErrUnexpectedEOF - } - if _, err := buffer.ReadBytes(int32(length)); err != nil { // Field: Reason Phrase - return nil, io.ErrUnexpectedEOF - } - default: - // Only above frame types are permitted in initial packet. - // See https://www.rfc-editor.org/rfc/rfc9000.html#section-17.2.2-8 + packetNumberLength := b[0]&0x3 + 1 + if packetNumberLength != 1 { return nil, errNotQuicInitial } - } + var packetNumber uint32 + { + n, err := buffer.ReadByte() + if err != nil { + return nil, err + } + packetNumber = uint32(n) + } - tlsHdr := &ptls.SniffHeader{} - err = ptls.ReadClientHello(cryptoData[:cryptoLen], tlsHdr) - if err != nil { - return nil, err + extHdrLen := hdrLen + int(packetNumberLength) + copy(b[extHdrLen:hdrLen+4], origPNBytes[packetNumberLength:]) + data := b[extHdrLen : int(packetLen)+hdrLen] + + key := hkdfExpandLabel(crypto.SHA256, secret, []byte{}, "quic key", 16) + iv := hkdfExpandLabel(crypto.SHA256, secret, []byte{}, "quic iv", 12) + cipher := AEADAESGCMTLS13(key, iv) + nonce := cache.Extend(int32(cipher.NonceSize())) + binary.BigEndian.PutUint64(nonce[len(nonce)-8:], uint64(packetNumber)) + decrypted, err := cipher.Open(b[extHdrLen:extHdrLen], nonce, data, b[:extHdrLen]) + if err != nil { + return nil, err + } + buffer = buf.FromBytes(decrypted) + for i := 0; !buffer.IsEmpty(); i++ { + frameType := byte(0x0) // Default to PADDING frame + for frameType == 0x0 && !buffer.IsEmpty() { + frameType, _ = buffer.ReadByte() + } + switch frameType { + case 0x00: // PADDING frame + case 0x01: // PING frame + case 0x02, 0x03: // ACK frame + if _, err = quicvarint.Read(buffer); err != nil { // Field: Largest Acknowledged + return nil, io.ErrUnexpectedEOF + } + if _, err = quicvarint.Read(buffer); err != nil { // Field: ACK Delay + return nil, io.ErrUnexpectedEOF + } + ackRangeCount, err := quicvarint.Read(buffer) // Field: ACK Range Count + if err != nil { + return nil, io.ErrUnexpectedEOF + } + if _, err = quicvarint.Read(buffer); err != nil { // Field: First ACK Range + return nil, io.ErrUnexpectedEOF + } + for i := 0; i < int(ackRangeCount); i++ { // Field: ACK Range + if _, err = quicvarint.Read(buffer); err != nil { // Field: ACK Range -> Gap + return nil, io.ErrUnexpectedEOF + } + if _, err = quicvarint.Read(buffer); err != nil { // Field: ACK Range -> ACK Range Length + return nil, io.ErrUnexpectedEOF + } + } + if frameType == 0x03 { + if _, err = quicvarint.Read(buffer); err != nil { // Field: ECN Counts -> ECT0 Count + return nil, io.ErrUnexpectedEOF + } + if _, err = quicvarint.Read(buffer); err != nil { // Field: ECN Counts -> ECT1 Count + return nil, io.ErrUnexpectedEOF + } + if _, err = quicvarint.Read(buffer); err != nil { //nolint:misspell // Field: ECN Counts -> ECT-CE Count + return nil, io.ErrUnexpectedEOF + } + } + case 0x06: // CRYPTO frame, we will use this frame + offset, err := quicvarint.Read(buffer) // Field: Offset + if err != nil { + return nil, io.ErrUnexpectedEOF + } + length, err := quicvarint.Read(buffer) // Field: Length + if err != nil || length > uint64(buffer.Len()) { + return nil, io.ErrUnexpectedEOF + } + if cryptoLen < int(offset+length) { + cryptoLen = int(offset + length) + if len(cryptoData) < cryptoLen { + newCryptoData := bytespool.Alloc(int32(cryptoLen)) + copy(newCryptoData, cryptoData) + bytespool.Free(cryptoData) + cryptoData = newCryptoData + } + } + if _, err := buffer.Read(cryptoData[offset : offset+length]); err != nil { // Field: Crypto Data + return nil, io.ErrUnexpectedEOF + } + case 0x1c: // CONNECTION_CLOSE frame, only 0x1c is permitted in initial packet + if _, err = quicvarint.Read(buffer); err != nil { // Field: Error Code + return nil, io.ErrUnexpectedEOF + } + if _, err = quicvarint.Read(buffer); err != nil { // Field: Frame Type + return nil, io.ErrUnexpectedEOF + } + length, err := quicvarint.Read(buffer) // Field: Reason Phrase Length + if err != nil { + return nil, io.ErrUnexpectedEOF + } + if _, err := buffer.ReadBytes(int32(length)); err != nil { // Field: Reason Phrase + return nil, io.ErrUnexpectedEOF + } + default: + // Only above frame types are permitted in initial packet. + // See https://www.rfc-editor.org/rfc/rfc9000.html#section-17.2.2-8 + return nil, errNotQuicInitial + } + } + + tlsHdr := &ptls.SniffHeader{} + err = ptls.ReadClientHello(cryptoData[:cryptoLen], tlsHdr) + if err != nil { + // The crypto data may have not been fully recovered in current packets, + // So we continue to sniff rest packets. + b = restPayload + continue + } + return &SniffHeader{domain: tlsHdr.Domain()}, nil } - return &SniffHeader{domain: tlsHdr.Domain()}, nil + return nil, common.ErrNoClue } func hkdfExpandLabel(hash crypto.Hash, secret, context []byte, label string, length int) []byte {