diff --git a/proxy/vless/encoding/encoding.go b/proxy/vless/encoding/encoding.go index 038c0ed3..f0699b96 100644 --- a/proxy/vless/encoding/encoding.go +++ b/proxy/vless/encoding/encoding.go @@ -64,7 +64,7 @@ func EncodeRequestHeader(writer io.Writer, request *protocol.RequestHeader, requ } // DecodeRequestHeader decodes and returns (if successful) a RequestHeader from an input stream. -func DecodeRequestHeader(isfb bool, first *buf.Buffer, reader io.Reader, validator *vless.Validator) (*protocol.RequestHeader, *Addons, bool, error) { +func DecodeRequestHeader(isfb bool, first *buf.Buffer, reader io.Reader, validator vless.Validator) (*protocol.RequestHeader, *Addons, bool, error) { buffer := buf.StackNew() defer buffer.Release() diff --git a/proxy/vless/encoding/encoding_test.go b/proxy/vless/encoding/encoding_test.go index ee7c6df0..9180154a 100644 --- a/proxy/vless/encoding/encoding_test.go +++ b/proxy/vless/encoding/encoding_test.go @@ -42,7 +42,7 @@ func TestRequestSerialization(t *testing.T) { buffer := buf.StackNew() common.Must(EncodeRequestHeader(&buffer, expectedRequest, expectedAddons)) - Validator := new(vless.Validator) + Validator := new(vless.MemoryValidator) Validator.Add(user) actualRequest, actualAddons, _, err := DecodeRequestHeader(false, nil, &buffer, Validator) @@ -83,7 +83,7 @@ func TestInvalidRequest(t *testing.T) { buffer := buf.StackNew() common.Must(EncodeRequestHeader(&buffer, expectedRequest, expectedAddons)) - Validator := new(vless.Validator) + Validator := new(vless.MemoryValidator) Validator.Add(user) _, _, _, err := DecodeRequestHeader(false, nil, &buffer, Validator) @@ -114,7 +114,7 @@ func TestMuxRequest(t *testing.T) { buffer := buf.StackNew() common.Must(EncodeRequestHeader(&buffer, expectedRequest, expectedAddons)) - Validator := new(vless.Validator) + Validator := new(vless.MemoryValidator) Validator.Add(user) actualRequest, actualAddons, _, err := DecodeRequestHeader(false, nil, &buffer, Validator) diff --git a/proxy/vless/inbound/inbound.go b/proxy/vless/inbound/inbound.go index 3bb2c09c..bcd8a24b 100644 --- a/proxy/vless/inbound/inbound.go +++ b/proxy/vless/inbound/inbound.go @@ -45,7 +45,21 @@ func init() { }); err != nil { return nil, err } - return New(ctx, config.(*Config), dc) + + c := config.(*Config) + + validator := new(vless.MemoryValidator) + for _, user := range c.Clients { + u, err := user.ToMemoryUser() + if err != nil { + return nil, errors.New("failed to get VLESS user").Base(err).AtError() + } + if err := validator.Add(u); err != nil { + return nil, errors.New("failed to initiate user").Base(err).AtError() + } + } + + return New(ctx, c, dc, validator) })) } @@ -53,30 +67,20 @@ func init() { type Handler struct { inboundHandlerManager feature_inbound.Manager policyManager policy.Manager - validator *vless.Validator + validator vless.Validator dns dns.Client fallbacks map[string]map[string]map[string]*Fallback // or nil // regexps map[string]*regexp.Regexp // or nil } // New creates a new VLess inbound handler. -func New(ctx context.Context, config *Config, dc dns.Client) (*Handler, error) { +func New(ctx context.Context, config *Config, dc dns.Client, validator vless.Validator) (*Handler, error) { v := core.MustFromContext(ctx) handler := &Handler{ inboundHandlerManager: v.GetFeature(feature_inbound.ManagerType()).(feature_inbound.Manager), policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager), - validator: new(vless.Validator), dns: dc, - } - - for _, user := range config.Clients { - u, err := user.ToMemoryUser() - if err != nil { - return nil, errors.New("failed to get VLESS user").Base(err).AtError() - } - if err := handler.AddUser(ctx, u); err != nil { - return nil, errors.New("failed to initiate user").Base(err).AtError() - } + validator: validator, } if config.Fallbacks != nil { diff --git a/proxy/vless/validator.go b/proxy/vless/validator.go index 72038cab..596cb62f 100644 --- a/proxy/vless/validator.go +++ b/proxy/vless/validator.go @@ -9,15 +9,21 @@ import ( "github.com/xtls/xray-core/common/uuid" ) -// Validator stores valid VLESS users. -type Validator struct { +type Validator interface { + Get(id uuid.UUID) *protocol.MemoryUser + Add(u *protocol.MemoryUser) error + Del(email string) error +} + +// MemoryValidator stores valid VLESS users. +type MemoryValidator struct { // Considering email's usage here, map + sync.Mutex/RWMutex may have better performance. email sync.Map users sync.Map } // Add a VLESS user, Email must be empty or unique. -func (v *Validator) Add(u *protocol.MemoryUser) error { +func (v *MemoryValidator) Add(u *protocol.MemoryUser) error { if u.Email != "" { _, loaded := v.email.LoadOrStore(strings.ToLower(u.Email), u) if loaded { @@ -29,7 +35,7 @@ func (v *Validator) Add(u *protocol.MemoryUser) error { } // Del a VLESS user with a non-empty Email. -func (v *Validator) Del(e string) error { +func (v *MemoryValidator) Del(e string) error { if e == "" { return errors.New("Email must not be empty.") } @@ -44,7 +50,7 @@ func (v *Validator) Del(e string) error { } // Get a VLESS user with UUID, nil if user doesn't exist. -func (v *Validator) Get(id uuid.UUID) *protocol.MemoryUser { +func (v *MemoryValidator) Get(id uuid.UUID) *protocol.MemoryUser { u, _ := v.users.Load(id) if u != nil { return u.(*protocol.MemoryUser)