From 20825f6f1ad978bc5b078bf69039da67a756d9e4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=A3=8E=E6=89=87=E6=BB=91=E7=BF=94=E7=BF=BC?= Date: Sat, 26 Jul 2025 12:05:21 +0000 Subject: [PATCH] Change to TypedSyncMap --- common/utils/typed_sync_map.go | 4 ++-- proxy/freedom/freedom.go | 2 +- proxy/trojan/validator.go | 18 +++++++++--------- proxy/vless/validator.go | 18 +++++++++--------- transport/internet/grpc/dial.go | 5 +---- transport/internet/splithttp/hub.go | 13 +++++++------ 6 files changed, 29 insertions(+), 31 deletions(-) diff --git a/common/utils/typed_sync_map.go b/common/utils/typed_sync_map.go index 39524a6f..11654372 100644 --- a/common/utils/typed_sync_map.go +++ b/common/utils/typed_sync_map.go @@ -15,8 +15,8 @@ type TypedSyncMap[K, V any] struct { // K is key type, V is value type // It is recommended to use pointer types for V because sync.Map might return nil // If sync.Map methods really returned nil, it will return the zero value of the type V -func NewTypedSyncMap[K any, V any]() *TypedSyncMap[K, V] { - return &TypedSyncMap[K, V]{ +func NewTypedSyncMap[K any, V any]() TypedSyncMap[K, V] { + return TypedSyncMap[K, V]{ syncMap: &sync.Map{}, } } diff --git a/proxy/freedom/freedom.go b/proxy/freedom/freedom.go index 1f9d5ae5..f9a00579 100644 --- a/proxy/freedom/freedom.go +++ b/proxy/freedom/freedom.go @@ -363,7 +363,7 @@ func NewPacketWriter(conn net.Conn, h *Handler, ctx context.Context, UDPOverride Handler: h, Context: ctx, UDPOverride: UDPOverride, - resolvedUDPAddr: resolvedUDPAddr, + resolvedUDPAddr: &resolvedUDPAddr, } } diff --git a/proxy/trojan/validator.go b/proxy/trojan/validator.go index 7841a249..77f3e602 100644 --- a/proxy/trojan/validator.go +++ b/proxy/trojan/validator.go @@ -2,17 +2,17 @@ package trojan import ( "strings" - "sync" "github.com/xtls/xray-core/common/errors" "github.com/xtls/xray-core/common/protocol" + "github.com/xtls/xray-core/common/utils" ) // Validator stores valid trojan users. type Validator struct { // Considering email's usage here, map + sync.Mutex/RWMutex may have better performance. - email sync.Map - users sync.Map + email utils.TypedSyncMap[string, *protocol.MemoryUser] + users utils.TypedSyncMap[string, *protocol.MemoryUser] } // Add a trojan user, Email must be empty or unique. @@ -38,7 +38,7 @@ func (v *Validator) Del(e string) error { return errors.New("User ", e, " not found.") } v.email.Delete(le) - v.users.Delete(hexString(u.(*protocol.MemoryUser).Account.(*MemoryAccount).Key)) + v.users.Delete(hexString(u.Account.(*MemoryAccount).Key)) return nil } @@ -46,7 +46,7 @@ func (v *Validator) Del(e string) error { func (v *Validator) Get(hash string) *protocol.MemoryUser { u, _ := v.users.Load(hash) if u != nil { - return u.(*protocol.MemoryUser) + return u } return nil } @@ -56,7 +56,7 @@ func (v *Validator) GetByEmail(email string) *protocol.MemoryUser { email = strings.ToLower(email) u, _ := v.email.Load(email) if u != nil { - return u.(*protocol.MemoryUser) + return u } return nil } @@ -64,8 +64,8 @@ func (v *Validator) GetByEmail(email string) *protocol.MemoryUser { // Get all users func (v *Validator) GetAll() []*protocol.MemoryUser { var u = make([]*protocol.MemoryUser, 0, 100) - v.email.Range(func(key, value interface{}) bool { - u = append(u, value.(*protocol.MemoryUser)) + v.email.Range(func(key string, value *protocol.MemoryUser) bool { + u = append(u, value) return true }) return u @@ -74,7 +74,7 @@ func (v *Validator) GetAll() []*protocol.MemoryUser { // Get users count func (v *Validator) GetCount() int64 { var c int64 = 0 - v.email.Range(func(key, value interface{}) bool { + v.email.Range(func(key string, value *protocol.MemoryUser) bool { c++ return true }) diff --git a/proxy/vless/validator.go b/proxy/vless/validator.go index d1356c5f..272d32d0 100644 --- a/proxy/vless/validator.go +++ b/proxy/vless/validator.go @@ -2,10 +2,10 @@ package vless import ( "strings" - "sync" "github.com/xtls/xray-core/common/errors" "github.com/xtls/xray-core/common/protocol" + "github.com/xtls/xray-core/common/utils" "github.com/xtls/xray-core/common/uuid" ) @@ -21,8 +21,8 @@ type Validator interface { // MemoryValidator stores valid VLESS users. type MemoryValidator struct { // Considering email's usage here, map + sync.Mutex/RWMutex may have better performance. - email sync.Map - users sync.Map + email utils.TypedSyncMap[string, *protocol.MemoryUser] + users utils.TypedSyncMap[uuid.UUID, *protocol.MemoryUser] } // Add a VLESS user, Email must be empty or unique. @@ -48,7 +48,7 @@ func (v *MemoryValidator) Del(e string) error { return errors.New("User ", e, " not found.") } v.email.Delete(le) - v.users.Delete(u.(*protocol.MemoryUser).Account.(*MemoryAccount).ID.UUID()) + v.users.Delete(u.Account.(*MemoryAccount).ID.UUID()) return nil } @@ -56,7 +56,7 @@ func (v *MemoryValidator) Del(e string) error { func (v *MemoryValidator) Get(id uuid.UUID) *protocol.MemoryUser { u, _ := v.users.Load(id) if u != nil { - return u.(*protocol.MemoryUser) + return u } return nil } @@ -66,7 +66,7 @@ func (v *MemoryValidator) GetByEmail(email string) *protocol.MemoryUser { email = strings.ToLower(email) u, _ := v.email.Load(email) if u != nil { - return u.(*protocol.MemoryUser) + return u } return nil } @@ -74,8 +74,8 @@ func (v *MemoryValidator) GetByEmail(email string) *protocol.MemoryUser { // Get all users func (v *MemoryValidator) GetAll() []*protocol.MemoryUser { var u = make([]*protocol.MemoryUser, 0, 100) - v.email.Range(func(key, value interface{}) bool { - u = append(u, value.(*protocol.MemoryUser)) + v.email.Range(func(key string, value *protocol.MemoryUser) bool { + u = append(u, value) return true }) return u @@ -84,7 +84,7 @@ func (v *MemoryValidator) GetAll() []*protocol.MemoryUser { // Get users count func (v *MemoryValidator) GetCount() int64 { var c int64 = 0 - v.email.Range(func(key, value interface{}) bool { + v.email.Range(func(key string, value *protocol.MemoryUser) bool { c++ return true }) diff --git a/transport/internet/grpc/dial.go b/transport/internet/grpc/dial.go index b8740dae..09ecbd68 100644 --- a/transport/internet/grpc/dial.go +++ b/transport/internet/grpc/dial.go @@ -43,7 +43,7 @@ type dialerConf struct { } var ( - globalDialerMap map[dialerConf]*grpc.ClientConn + globalDialerMap = make(map[dialerConf]*grpc.ClientConn) globalDialerAccess sync.Mutex ) @@ -77,9 +77,6 @@ func getGrpcClient(ctx context.Context, dest net.Destination, streamSettings *in globalDialerAccess.Lock() defer globalDialerAccess.Unlock() - if globalDialerMap == nil { - globalDialerMap = make(map[dialerConf]*grpc.ClientConn) - } tlsConfig := tls.ConfigFromStreamSettings(streamSettings) realityConfig := reality.ConfigFromStreamSettings(streamSettings) sockopt := streamSettings.SocketSettings diff --git a/transport/internet/splithttp/hub.go b/transport/internet/splithttp/hub.go index d161741a..2a580f67 100644 --- a/transport/internet/splithttp/hub.go +++ b/transport/internet/splithttp/hub.go @@ -20,6 +20,7 @@ import ( "github.com/xtls/xray-core/common/net" http_proto "github.com/xtls/xray-core/common/protocol/http" "github.com/xtls/xray-core/common/signal/done" + "github.com/xtls/xray-core/common/utils" "github.com/xtls/xray-core/transport/internet" "github.com/xtls/xray-core/transport/internet/reality" "github.com/xtls/xray-core/transport/internet/stat" @@ -32,7 +33,7 @@ type requestHandler struct { path string ln *Listener sessionMu *sync.Mutex - sessions sync.Map + sessions utils.TypedSyncMap[string, *httpSession] localAddr net.Addr } @@ -47,18 +48,18 @@ type httpSession struct { func (h *requestHandler) upsertSession(sessionId string) *httpSession { // fast path - currentSessionAny, ok := h.sessions.Load(sessionId) + currentSession, ok := h.sessions.Load(sessionId) if ok { - return currentSessionAny.(*httpSession) + return currentSession } // slow path h.sessionMu.Lock() defer h.sessionMu.Unlock() - currentSessionAny, ok = h.sessions.Load(sessionId) + currentSession, ok = h.sessions.Load(sessionId) if ok { - return currentSessionAny.(*httpSession) + return currentSession } s := &httpSession{ @@ -361,7 +362,7 @@ func ListenXH(ctx context.Context, address net.Address, port net.Port, streamSet path: l.config.GetNormalizedPath(), ln: l, sessionMu: &sync.Mutex{}, - sessions: sync.Map{}, + sessions: utils.NewTypedSyncMap[string, *httpSession](), } tlsConfig := getTLSConfig(streamSettings) l.isH3 = len(tlsConfig.NextProtos) == 1 && tlsConfig.NextProtos[0] == "h3"