diff --git a/app/dispatcher/default.go b/app/dispatcher/default.go index dbf58dad..7bc58056 100644 --- a/app/dispatcher/default.go +++ b/app/dispatcher/default.go @@ -106,7 +106,7 @@ func init() { common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) { d := new(DefaultDispatcher) if err := core.RequireFeatures(ctx, func(om outbound.Manager, router routing.Router, pm policy.Manager, sm stats.Manager, dc dns.Client) error { - core.RequireFeatures(ctx, func(fdns dns.FakeDNSEngine) { // FakeDNSEngine is optional + core.OptionalFeatures(ctx, func(fdns dns.FakeDNSEngine) { d.fdns = fdns }) return d.Init(config.(*Config), om, router, pm, sm, dc) diff --git a/app/dns/nameserver.go b/app/dns/nameserver.go index de6e1686..9c2668d9 100644 --- a/app/dns/nameserver.go +++ b/app/dns/nameserver.go @@ -56,7 +56,7 @@ func NewServer(ctx context.Context, dest net.Destination, dispatcher routing.Dis return NewTCPLocalNameServer(u, queryStrategy) case strings.EqualFold(u.String(), "fakedns"): var fd dns.FakeDNSEngine - core.RequireFeatures(ctx, func(fdns dns.FakeDNSEngine) { // FakeDNSEngine is optional + core.RequireFeatures(ctx, func(fdns dns.FakeDNSEngine) { fd = fdns }) return NewFakeDNSServer(fd), nil diff --git a/app/observatory/command/command.go b/app/observatory/command/command.go index aab85e80..f9bb58e3 100644 --- a/app/observatory/command/command.go +++ b/app/observatory/command/command.go @@ -38,7 +38,7 @@ func init() { sv := &service{v: s} err := s.RequireFeatures(func(Observatory extension.Observatory) { sv.observatory = Observatory - }) + }, false) if err != nil { return nil, err } diff --git a/app/proxyman/command/command.go b/app/proxyman/command/command.go index 3c7824d2..ef710521 100644 --- a/app/proxyman/command/command.go +++ b/app/proxyman/command/command.go @@ -177,7 +177,7 @@ func (s *service) Register(server *grpc.Server) { common.Must(s.v.RequireFeatures(func(im inbound.Manager, om outbound.Manager) { hs.ihm = im hs.ohm = om - })) + }, false)) RegisterHandlerServiceServer(server, hs) // For compatibility purposes diff --git a/app/router/balancing.go b/app/router/balancing.go index 7d8bb022..5f8cb1c2 100644 --- a/app/router/balancing.go +++ b/app/router/balancing.go @@ -5,6 +5,7 @@ import ( sync "sync" "github.com/xtls/xray-core/app/observatory" + "github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common/errors" "github.com/xtls/xray-core/core" "github.com/xtls/xray-core/features/extension" @@ -31,9 +32,10 @@ type RoundRobinStrategy struct { func (s *RoundRobinStrategy) InjectContext(ctx context.Context) { s.ctx = ctx if len(s.FallbackTag) > 0 { - core.RequireFeaturesAsync(s.ctx, func(observatory extension.Observatory) { + common.Must(core.RequireFeatures(s.ctx, func(observatory extension.Observatory) error { s.observatory = observatory - }) + return nil + })) } } diff --git a/app/router/command/command.go b/app/router/command/command.go index baf76b8b..fd9caa22 100644 --- a/app/router/command/command.go +++ b/app/router/command/command.go @@ -135,7 +135,7 @@ func (s *service) Register(server *grpc.Server) { vCoreDesc := RoutingService_ServiceDesc vCoreDesc.ServiceName = "v2ray.core.app.router.command.RoutingService" server.RegisterService(&vCoreDesc, rs) - })) + }, false)) } func init() { diff --git a/app/router/strategy_leastload.go b/app/router/strategy_leastload.go index a4ef1c12..1bf3cbc0 100644 --- a/app/router/strategy_leastload.go +++ b/app/router/strategy_leastload.go @@ -7,6 +7,7 @@ import ( "time" "github.com/xtls/xray-core/app/observatory" + "github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common/dice" "github.com/xtls/xray-core/common/errors" "github.com/xtls/xray-core/core" @@ -59,9 +60,10 @@ type node struct { func (s *LeastLoadStrategy) InjectContext(ctx context.Context) { s.ctx = ctx - core.RequireFeaturesAsync(s.ctx, func(observatory extension.Observatory) { + common.Must(core.RequireFeatures(s.ctx, func(observatory extension.Observatory) error { s.observer = observatory - }) + return nil + })) } func (s *LeastLoadStrategy) PickOutbound(candidates []string) string { diff --git a/app/router/strategy_leastping.go b/app/router/strategy_leastping.go index b13d1a7d..ada3492d 100644 --- a/app/router/strategy_leastping.go +++ b/app/router/strategy_leastping.go @@ -4,6 +4,7 @@ import ( "context" "github.com/xtls/xray-core/app/observatory" + "github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common/errors" "github.com/xtls/xray-core/core" "github.com/xtls/xray-core/features/extension" @@ -20,9 +21,10 @@ func (l *LeastPingStrategy) GetPrincipleTarget(strings []string) []string { func (l *LeastPingStrategy) InjectContext(ctx context.Context) { l.ctx = ctx - core.RequireFeaturesAsync(l.ctx, func(observatory extension.Observatory) { + common.Must(core.RequireFeatures(l.ctx, func(observatory extension.Observatory) error { l.observatory = observatory - }) + return nil + })) } func (l *LeastPingStrategy) PickOutbound(strings []string) string { diff --git a/app/router/strategy_random.go b/app/router/strategy_random.go index 9f4cdd77..ea9b7add 100644 --- a/app/router/strategy_random.go +++ b/app/router/strategy_random.go @@ -4,6 +4,7 @@ import ( "context" "github.com/xtls/xray-core/app/observatory" + "github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common/dice" "github.com/xtls/xray-core/core" "github.com/xtls/xray-core/features/extension" @@ -20,9 +21,10 @@ type RandomStrategy struct { func (s *RandomStrategy) InjectContext(ctx context.Context) { s.ctx = ctx if len(s.FallbackTag) > 0 { - core.RequireFeaturesAsync(s.ctx, func(observatory extension.Observatory) { + common.Must(core.RequireFeatures(s.ctx, func(observatory extension.Observatory) error { s.observatory = observatory - }) + return nil + })) } } diff --git a/core/xray.go b/core/xray.go index 5ab10603..f6ccc27d 100644 --- a/core/xray.go +++ b/core/xray.go @@ -4,7 +4,6 @@ import ( "context" "reflect" "sync" - "time" "github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common/errors" @@ -45,22 +44,13 @@ func getFeature(allFeatures []features.Feature, t reflect.Type) features.Feature return nil } -func (r *resolution) resolve(allFeatures []features.Feature) (bool, error) { - var fs []features.Feature - for _, d := range r.deps { - f := getFeature(allFeatures, d) - if f == nil { - return false, nil - } - fs = append(fs, f) - } - +func (r *resolution) callbackResolution(allFeatures []features.Feature) error { callback := reflect.ValueOf(r.callback) var input []reflect.Value callbackType := callback.Type() for i := 0; i < callbackType.NumIn(); i++ { pt := callbackType.In(i) - for _, f := range fs { + for _, f := range allFeatures { if reflect.TypeOf(f).AssignableTo(pt) { input = append(input, reflect.ValueOf(f)) break @@ -85,15 +75,17 @@ func (r *resolution) resolve(allFeatures []features.Feature) (bool, error) { } } - return true, err + return err } // Instance combines all Xray features. type Instance struct { - access sync.Mutex - features []features.Feature - featureResolutions []resolution - running bool + statusLock sync.Mutex + features []features.Feature + pendingResolutions []resolution + pendingOptionalResolutions []resolution + running bool + resolveLock sync.Mutex ctx context.Context } @@ -154,13 +146,14 @@ func addOutboundHandlers(server *Instance, configs []*OutboundHandlerConfig) err // See Instance.RequireFeatures for more information. func RequireFeatures(ctx context.Context, callback interface{}) error { v := MustFromContext(ctx) - return v.RequireFeatures(callback) + return v.RequireFeatures(callback, false) } -// RequireFeaturesAsync registers a callback, which will be called when all dependent features are registered. The order of app init doesn't matter -func RequireFeaturesAsync(ctx context.Context, callback interface{}) { +// OptionalFeatures is a helper function to aquire features from Instance in context. +// See Instance.RequireFeatures for more information. +func OptionalFeatures(ctx context.Context, callback interface{}) error { v := MustFromContext(ctx) - v.RequireFeaturesAsync(callback) + return v.RequireFeatures(callback, true) } // New returns a new Xray instance based on given configuration. @@ -234,9 +227,12 @@ func initInstanceWithConfig(config *Config, server *Instance) (bool, error) { }(), ) - if server.featureResolutions != nil { + server.resolveLock.Lock() + if server.pendingResolutions != nil { + server.resolveLock.Unlock() return true, errors.New("not all dependencies are resolved.") } + server.resolveLock.Unlock() if err := addInboundHandlers(server, config.Inbound); err != nil { return true, err @@ -255,8 +251,8 @@ func (s *Instance) Type() interface{} { // Close shutdown the Xray instance. func (s *Instance) Close() error { - s.access.Lock() - defer s.access.Unlock() + s.statusLock.Lock() + defer s.statusLock.Unlock() s.running = false @@ -275,7 +271,7 @@ func (s *Instance) Close() error { // RequireFeatures registers a callback, which will be called when all dependent features are registered. // The callback must be a func(). All its parameters must be features.Feature. -func (s *Instance) RequireFeatures(callback interface{}) error { +func (s *Instance) RequireFeatures(callback interface{}, optional bool) error { callbackType := reflect.TypeOf(callback) if callbackType.Kind() != reflect.Func { panic("not a function") @@ -290,47 +286,32 @@ func (s *Instance) RequireFeatures(callback interface{}) error { deps: featureTypes, callback: callback, } - if finished, err := r.resolve(s.features); finished { - return err - } - s.featureResolutions = append(s.featureResolutions, r) - return nil -} -// RequireFeaturesAsync registers a callback, which will be called when all dependent features are registered. The order of app init doesn't matter -func (s *Instance) RequireFeaturesAsync(callback interface{}) { - callbackType := reflect.TypeOf(callback) - if callbackType.Kind() != reflect.Func { - panic("not a function") - } - - var featureTypes []reflect.Type - for i := 0; i < callbackType.NumIn(); i++ { - featureTypes = append(featureTypes, reflect.PtrTo(callbackType.In(i))) - } - - r := resolution{ - deps: featureTypes, - callback: callback, - } - go func() { - var finished = false - for i := 0; !finished; i++ { - if i > 100000 { - errors.LogError(s.ctx, "RequireFeaturesAsync failed after count ", i) - break; - } - finished, _ = r.resolve(s.features) - time.Sleep(time.Millisecond) + s.resolveLock.Lock() + foundAll := true + for _, d := range r.deps { + f := getFeature(s.features, d) + if f == nil { + foundAll = false + break } - s.featureResolutions = append(s.featureResolutions, r) - }() + } + if foundAll { + s.resolveLock.Unlock() + return r.callbackResolution(s.features) + } else { + if optional { + s.pendingOptionalResolutions = append(s.pendingOptionalResolutions, r) + } else { + s.pendingResolutions = append(s.pendingResolutions, r) + } + s.resolveLock.Unlock() + return nil + } } // AddFeature registers a feature into current Instance. func (s *Instance) AddFeature(feature features.Feature) error { - s.features = append(s.features, feature) - if s.running { if err := feature.Start(); err != nil { errors.LogInfoInner(s.ctx, err, "failed to start feature") @@ -338,27 +319,52 @@ func (s *Instance) AddFeature(feature features.Feature) error { return nil } - if s.featureResolutions == nil { - return nil - } + s.resolveLock.Lock() + s.features = append(s.features, feature) - var pendingResolutions []resolution - for _, r := range s.featureResolutions { - finished, err := r.resolve(s.features) - if finished && err != nil { - return err + var availableResolution []resolution + var pending []resolution + for _, r := range s.pendingResolutions { + foundAll := true + for _, d := range r.deps { + f := getFeature(s.features, d) + if f == nil { + foundAll = false + break + } } - if !finished { - pendingResolutions = append(pendingResolutions, r) + if foundAll { + availableResolution = append(availableResolution, r) + } else { + pending = append(pending, r) } } - if len(pendingResolutions) == 0 { - s.featureResolutions = nil - } else if len(pendingResolutions) < len(s.featureResolutions) { - s.featureResolutions = pendingResolutions - } + s.pendingResolutions = pending - return nil + var pendingOptional []resolution + for _, r := range s.pendingOptionalResolutions { + foundAll := true + for _, d := range r.deps { + f := getFeature(s.features, d) + if f == nil { + foundAll = false + break + } + } + if foundAll { + availableResolution = append(availableResolution, r) + } else { + pendingOptional = append(pendingOptional, r) + } + } + s.pendingOptionalResolutions = pendingOptional + s.resolveLock.Unlock() + + var err error + for _, r := range availableResolution { + err = r.callbackResolution(s.features) // only return the last error for now + } + return err } // GetFeature returns a feature of the given type, or nil if such feature is not registered. @@ -371,8 +377,8 @@ func (s *Instance) GetFeature(featureType interface{}) features.Feature { // // xray:api:stable func (s *Instance) Start() error { - s.access.Lock() - defer s.access.Unlock() + s.statusLock.Lock() + defer s.statusLock.Unlock() s.running = true for _, f := range s.features { diff --git a/core/xray_test.go b/core/xray_test.go index 43d021ef..f4cb11ab 100644 --- a/core/xray_test.go +++ b/core/xray_test.go @@ -30,7 +30,7 @@ func TestXrayDependency(t *testing.T) { t.Error("expected dns client fulfilled, but actually nil") } wait <- true - }) + }, false) instance.AddFeature(localdns.New()) <-wait } diff --git a/proxy/dns/dns.go b/proxy/dns/dns.go index 790c80c1..b7a3264a 100644 --- a/proxy/dns/dns.go +++ b/proxy/dns/dns.go @@ -27,7 +27,7 @@ func init() { common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) { h := new(Handler) if err := core.RequireFeatures(ctx, func(dnsClient dns.Client, policyManager policy.Manager) error { - core.RequireFeatures(ctx, func(fdns dns.FakeDNSEngine) { // FakeDNSEngine is optional + core.OptionalFeatures(ctx, func(fdns dns.FakeDNSEngine) { h.fdns = fdns }) return h.Init(config.(*Config), dnsClient, policyManager)