diff --git a/common/protocol/quic/sniff.go b/common/protocol/quic/sniff.go index 3a22d454..0691bad6 100644 --- a/common/protocol/quic/sniff.go +++ b/common/protocol/quic/sniff.go @@ -114,6 +114,10 @@ func SniffQUIC(b []byte) (*SniffHeader, error) { if err != nil { return nil, errNotQuic } + // packetLen is impossible to be shorter than this + if packetLen < 4 { + return nil, errNotQuic + } hdrLen := len(b) - int(buffer.Len()) if len(b) < hdrLen+int(packetLen) { diff --git a/common/protocol/quic/sniff_test.go b/common/protocol/quic/sniff_test.go index 121279b5..8a85ea87 100644 --- a/common/protocol/quic/sniff_test.go +++ b/common/protocol/quic/sniff_test.go @@ -267,3 +267,21 @@ func TestSniffQUICPacketNumberLength4(t *testing.T) { t.Error("failed") } } + +func TestSniffFakeQUICPacketWithInvalidPacketNumberLength(t *testing.T) { + pkt, err := hex.DecodeString("cb00000001081c8c6d5aeb53d54400000090709b8600000000000000000000000000000000") + common.Must(err) + _, err = quic.SniffQUIC(pkt) + if err == nil { + t.Error("failed") + } +} + +func TestSniffFakeQUICPacketWithTooShortData(t *testing.T) { + pkt, err := hex.DecodeString("cb00000001081c8c6d5aeb53d54400000090709b86") + common.Must(err) + _, err = quic.SniffQUIC(pkt) + if err == nil { + t.Error("failed") + } +}