Fix concurrent map writes error in ohm.Select(). (#2943)

* Add unit test for ohm.tagsCache.

* Fix concurrent map writes in ohm.Select().

---------

Co-authored-by: nobody <nobody@nowhere.mars>
This commit is contained in:
nobody 2024-01-16 23:52:01 +08:00 committed by GitHub
parent 10255bca83
commit d20a835016
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 103 additions and 9 deletions

View file

@ -22,14 +22,14 @@ type Manager struct {
taggedHandler map[string]outbound.Handler
untaggedHandlers []outbound.Handler
running bool
tagsCache map[string][]string
tagsCache *sync.Map
}
// New creates a new Manager.
func New(ctx context.Context, config *proxyman.OutboundConfig) (*Manager, error) {
m := &Manager{
taggedHandler: make(map[string]outbound.Handler),
tagsCache: make(map[string][]string),
tagsCache: &sync.Map{},
}
return m, nil
}
@ -106,7 +106,7 @@ func (m *Manager) AddHandler(ctx context.Context, handler outbound.Handler) erro
m.access.Lock()
defer m.access.Unlock()
m.tagsCache = make(map[string][]string)
m.tagsCache = &sync.Map{}
if m.defaultHandler == nil {
m.defaultHandler = handler
@ -137,7 +137,7 @@ func (m *Manager) RemoveHandler(ctx context.Context, tag string) error {
m.access.Lock()
defer m.access.Unlock()
m.tagsCache = make(map[string][]string)
m.tagsCache = &sync.Map{}
delete(m.taggedHandler, tag)
if m.defaultHandler != nil && m.defaultHandler.Tag() == tag {
@ -149,14 +149,15 @@ func (m *Manager) RemoveHandler(ctx context.Context, tag string) error {
// Select implements outbound.HandlerSelector.
func (m *Manager) Select(selectors []string) []string {
m.access.RLock()
defer m.access.RUnlock()
key := strings.Join(selectors, ",")
if cache, ok := m.tagsCache[key]; ok {
return cache
if cache, ok := m.tagsCache.Load(key); ok {
return cache.([]string)
}
m.access.RLock()
defer m.access.RUnlock()
tags := make([]string, 0, len(selectors))
for tag := range m.taggedHandler {
@ -169,7 +170,7 @@ func (m *Manager) Select(selectors []string) []string {
}
sort.Strings(tags)
m.tagsCache[key] = tags
m.tagsCache.Store(key, tags)
return tags
}