fix(proxy): removed the udp payload length check when encryption is disabled

This commit is contained in:
cty123 2023-08-19 22:10:59 +02:00 committed by yuhan6665
parent f67167bb3b
commit a343d68944
3 changed files with 107 additions and 53 deletions

View File

@ -4,6 +4,7 @@ import (
"crypto/hmac" "crypto/hmac"
"crypto/rand" "crypto/rand"
"crypto/sha256" "crypto/sha256"
"errors"
"hash/crc32" "hash/crc32"
"io" "io"
@ -236,37 +237,37 @@ func EncodeUDPPacket(request *protocol.RequestHeader, payload []byte) (*buf.Buff
} }
func DecodeUDPPacket(validator *Validator, payload *buf.Buffer) (*protocol.RequestHeader, *buf.Buffer, error) { func DecodeUDPPacket(validator *Validator, payload *buf.Buffer) (*protocol.RequestHeader, *buf.Buffer, error) {
bs := payload.Bytes() rawPayload := payload.Bytes()
if len(bs) <= 32 { user, _, d, _, err := validator.Get(rawPayload, protocol.RequestCommandUDP)
return nil, nil, newError("len(bs) <= 32")
}
user, _, d, _, err := validator.Get(bs, protocol.RequestCommandUDP) if errors.Is(err, ErrIVNotUnique) {
switch err {
case ErrIVNotUnique:
return nil, nil, newError("failed iv check").Base(err) return nil, nil, newError("failed iv check").Base(err)
case ErrNotFound:
return nil, nil, newError("failed to match an user").Base(err)
default:
account := user.Account.(*MemoryAccount)
if account.Cipher.IsAEAD() {
payload.Clear()
payload.Write(d)
} else {
if account.Cipher.IVSize() > 0 {
iv := make([]byte, account.Cipher.IVSize())
copy(iv, payload.BytesTo(account.Cipher.IVSize()))
}
if err = account.Cipher.DecodePacket(account.Key, payload); err != nil {
return nil, nil, newError("failed to decrypt UDP payload").Base(err)
}
}
} }
request := &protocol.RequestHeader{ if errors.Is(err, ErrNotFound) {
Version: Version, return nil, nil, newError("failed to match an user").Base(err)
User: user, }
Command: protocol.RequestCommandUDP,
if err != nil {
return nil, nil, newError("unexpected error").Base(err)
}
account, ok := user.Account.(*MemoryAccount)
if !ok {
return nil, nil, newError("expected MemoryAccount returned from validator")
}
if account.Cipher.IsAEAD() {
payload.Clear()
payload.Write(d)
} else {
if account.Cipher.IVSize() > 0 {
iv := make([]byte, account.Cipher.IVSize())
copy(iv, payload.BytesTo(account.Cipher.IVSize()))
}
if err = account.Cipher.DecodePacket(account.Key, payload); err != nil {
return nil, nil, newError("failed to decrypt UDP payload").Base(err)
}
} }
payload.SetByte(0, payload.Byte(0)&0x0F) payload.SetByte(0, payload.Byte(0)&0x0F)
@ -276,8 +277,13 @@ func DecodeUDPPacket(validator *Validator, payload *buf.Buffer) (*protocol.Reque
return nil, nil, newError("failed to parse address").Base(err) return nil, nil, newError("failed to parse address").Base(err)
} }
request.Address = addr request := &protocol.RequestHeader{
request.Port = port Version: Version,
User: user,
Command: protocol.RequestCommandUDP,
Address: addr,
Port: port,
}
return request, payload, nil return request, payload, nil
} }

View File

@ -23,37 +23,80 @@ func equalRequestHeader(x, y *protocol.RequestHeader) bool {
})) }))
} }
func TestUDPEncoding(t *testing.T) { func TestUDPEncodingDecoding(t *testing.T) {
request := &protocol.RequestHeader{ testRequests := []protocol.RequestHeader{
Version: Version, {
Command: protocol.RequestCommandUDP, Version: Version,
Address: net.LocalHostIP, Command: protocol.RequestCommandUDP,
Port: 1234, Address: net.LocalHostIP,
User: &protocol.MemoryUser{ Port: 1234,
Email: "love@example.com", User: &protocol.MemoryUser{
Account: toAccount(&Account{ Email: "love@example.com",
Password: "password", Account: toAccount(&Account{
CipherType: CipherType_AES_128_GCM, Password: "password",
}), CipherType: CipherType_AES_128_GCM,
}),
},
},
{
Version: Version,
Command: protocol.RequestCommandUDP,
Address: net.LocalHostIP,
Port: 1234,
User: &protocol.MemoryUser{
Email: "love@example.com",
Account: toAccount(&Account{
Password: "123",
CipherType: CipherType_NONE,
}),
},
}, },
} }
data := buf.New() for _, request := range testRequests {
common.Must2(data.WriteString("test string")) data := buf.New()
encodedData, err := EncodeUDPPacket(request, data.Bytes()) common.Must2(data.WriteString("test string"))
common.Must(err) encodedData, err := EncodeUDPPacket(&request, data.Bytes())
common.Must(err)
validator := new(Validator) validator := new(Validator)
validator.Add(request.User) validator.Add(request.User)
decodedRequest, decodedData, err := DecodeUDPPacket(validator, encodedData) decodedRequest, decodedData, err := DecodeUDPPacket(validator, encodedData)
common.Must(err) common.Must(err)
if r := cmp.Diff(decodedData.Bytes(), data.Bytes()); r != "" { if r := cmp.Diff(decodedData.Bytes(), data.Bytes()); r != "" {
t.Error("data: ", r) t.Error("data: ", r)
}
if equalRequestHeader(decodedRequest, &request) == false {
t.Error("different request")
}
}
}
func TestUDPDecodingWithPayloadTooShort(t *testing.T) {
testAccounts := []protocol.Account{
toAccount(&Account{
Password: "password",
CipherType: CipherType_AES_128_GCM,
}),
toAccount(&Account{
Password: "password",
CipherType: CipherType_NONE,
}),
} }
if equalRequestHeader(decodedRequest, request) == false { for _, account := range testAccounts {
t.Error("different request") data := buf.New()
data.WriteString("short payload")
validator := new(Validator)
validator.Add(&protocol.MemoryUser{
Account: account,
})
_, _, err := DecodeUDPPacket(validator, data)
if err == nil {
t.Fatal("expected error")
}
} }
} }

View File

@ -80,6 +80,11 @@ func (v *Validator) Get(bs []byte, command protocol.RequestCommand) (u *protocol
for _, user := range v.users { for _, user := range v.users {
if account := user.Account.(*MemoryAccount); account.Cipher.IsAEAD() { if account := user.Account.(*MemoryAccount); account.Cipher.IsAEAD() {
// AEAD payload decoding requires the payload to be over 32 bytes
if len(bs) < 32 {
continue
}
aeadCipher := account.Cipher.(*AEADCipher) aeadCipher := account.Cipher.(*AEADCipher)
ivLen = aeadCipher.IVSize() ivLen = aeadCipher.IVSize()
iv := bs[:ivLen] iv := bs[:ivLen]