From d20a835016b9c0a22b8a698a22a5669d1c5edae6 Mon Sep 17 00:00:00 2001 From: nobody <59990325+vrnobody@users.noreply.github.com> Date: Tue, 16 Jan 2024 23:52:01 +0800 Subject: [PATCH] Fix concurrent map writes error in ohm.Select(). (#2943) * Add unit test for ohm.tagsCache. * Fix concurrent map writes in ohm.Select(). --------- Co-authored-by: nobody --- app/proxyman/outbound/handler_test.go | 93 +++++++++++++++++++++++++++ app/proxyman/outbound/outbound.go | 19 +++--- 2 files changed, 103 insertions(+), 9 deletions(-) diff --git a/app/proxyman/outbound/handler_test.go b/app/proxyman/outbound/handler_test.go index c5afea70..e5b67308 100644 --- a/app/proxyman/outbound/handler_test.go +++ b/app/proxyman/outbound/handler_test.go @@ -2,9 +2,14 @@ package outbound_test import ( "context" + "fmt" + "sync" + "sync/atomic" "testing" + "time" "github.com/xtls/xray-core/app/policy" + "github.com/xtls/xray-core/app/proxyman" . "github.com/xtls/xray-core/app/proxyman/outbound" "github.com/xtls/xray-core/app/stats" "github.com/xtls/xray-core/common/net" @@ -78,3 +83,91 @@ func TestOutboundWithStatCounter(t *testing.T) { t.Errorf("Expected conn to be CounterConnection") } } + +func TestTagsCache(t *testing.T) { + + test_duration := 10 * time.Second + threads_num := 50 + delay := 10 * time.Millisecond + tags_prefix := "node" + + tags := sync.Map{} + counter := atomic.Uint64{} + + ohm, err := New(context.Background(), &proxyman.OutboundConfig{}) + if err != nil { + t.Error("failed to create outbound handler manager") + } + config := &core.Config{ + App: []*serial.TypedMessage{}, + } + v, _ := core.New(config) + v.AddFeature(ohm) + ctx := context.WithValue(context.Background(), xrayKey, v) + + stop_add_rm := false + wg_add_rm := sync.WaitGroup{} + addHandlers := func() { + defer wg_add_rm.Done() + for !stop_add_rm { + time.Sleep(delay) + idx := counter.Add(1) + tag := fmt.Sprintf("%s%d", tags_prefix, idx) + cfg := &core.OutboundHandlerConfig{ + Tag: tag, + ProxySettings: serial.ToTypedMessage(&freedom.Config{}), + } + if h, err := NewHandler(ctx, cfg); err == nil { + if err := ohm.AddHandler(ctx, h); err == nil { + // t.Log("add handler:", tag) + tags.Store(tag, nil) + } else { + t.Error("failed to add handler:", tag) + } + } else { + t.Error("failed to create handler:", tag) + } + } + } + + rmHandlers := func() { + defer wg_add_rm.Done() + for !stop_add_rm { + time.Sleep(delay) + tags.Range(func(key interface{}, value interface{}) bool { + if _, ok := tags.LoadAndDelete(key); ok { + // t.Log("remove handler:", key) + ohm.RemoveHandler(ctx, key.(string)) + return false + } + return true + }) + } + } + + selectors := []string{tags_prefix} + wg_get := sync.WaitGroup{} + stop_get := false + getTags := func() { + defer wg_get.Done() + for !stop_get { + time.Sleep(delay) + _ = ohm.Select(selectors) + // t.Logf("get tags: %v", tag) + } + } + + for i := 0; i < threads_num; i++ { + wg_add_rm.Add(2) + go rmHandlers() + go addHandlers() + wg_get.Add(1) + go getTags() + } + + time.Sleep(test_duration) + stop_add_rm = true + wg_add_rm.Wait() + stop_get = true + wg_get.Wait() +} diff --git a/app/proxyman/outbound/outbound.go b/app/proxyman/outbound/outbound.go index 3bd0d85c..40f32965 100644 --- a/app/proxyman/outbound/outbound.go +++ b/app/proxyman/outbound/outbound.go @@ -22,14 +22,14 @@ type Manager struct { taggedHandler map[string]outbound.Handler untaggedHandlers []outbound.Handler running bool - tagsCache map[string][]string + tagsCache *sync.Map } // New creates a new Manager. func New(ctx context.Context, config *proxyman.OutboundConfig) (*Manager, error) { m := &Manager{ taggedHandler: make(map[string]outbound.Handler), - tagsCache: make(map[string][]string), + tagsCache: &sync.Map{}, } return m, nil } @@ -106,7 +106,7 @@ func (m *Manager) AddHandler(ctx context.Context, handler outbound.Handler) erro m.access.Lock() defer m.access.Unlock() - m.tagsCache = make(map[string][]string) + m.tagsCache = &sync.Map{} if m.defaultHandler == nil { m.defaultHandler = handler @@ -137,7 +137,7 @@ func (m *Manager) RemoveHandler(ctx context.Context, tag string) error { m.access.Lock() defer m.access.Unlock() - m.tagsCache = make(map[string][]string) + m.tagsCache = &sync.Map{} delete(m.taggedHandler, tag) if m.defaultHandler != nil && m.defaultHandler.Tag() == tag { @@ -149,14 +149,15 @@ func (m *Manager) RemoveHandler(ctx context.Context, tag string) error { // Select implements outbound.HandlerSelector. func (m *Manager) Select(selectors []string) []string { - m.access.RLock() - defer m.access.RUnlock() key := strings.Join(selectors, ",") - if cache, ok := m.tagsCache[key]; ok { - return cache + if cache, ok := m.tagsCache.Load(key); ok { + return cache.([]string) } + m.access.RLock() + defer m.access.RUnlock() + tags := make([]string, 0, len(selectors)) for tag := range m.taggedHandler { @@ -169,7 +170,7 @@ func (m *Manager) Select(selectors []string) []string { } sort.Strings(tags) - m.tagsCache[key] = tags + m.tagsCache.Store(key, tags) return tags }