From 27224868aba63db8180ee2f2db71aa995930a74d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Tue, 28 Sep 2021 14:41:31 +0800 Subject: [PATCH] Override destination if replaced in hosts --- app/dispatcher/default.go | 30 ++++++++++++++++++++++++------ app/dns/dns.go | 16 ++++++++++++++++ features/dns/client.go | 4 ++++ 3 files changed, 44 insertions(+), 6 deletions(-) diff --git a/app/dispatcher/default.go b/app/dispatcher/default.go index b32013e8..17010759 100644 --- a/app/dispatcher/default.go +++ b/app/dispatcher/default.go @@ -4,6 +4,7 @@ package dispatcher import ( "context" + "github.com/xtls/xray-core/features/dns" "strings" "sync" "time" @@ -15,7 +16,6 @@ import ( "github.com/xtls/xray-core/common/protocol" "github.com/xtls/xray-core/common/session" "github.com/xtls/xray-core/core" - "github.com/xtls/xray-core/features/dns" "github.com/xtls/xray-core/features/outbound" "github.com/xtls/xray-core/features/policy" "github.com/xtls/xray-core/features/routing" @@ -92,13 +92,14 @@ type DefaultDispatcher struct { router routing.Router policy policy.Manager stats stats.Manager + hosts dns.HostsLookup } func init() { common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) { d := new(DefaultDispatcher) - if err := core.RequireFeatures(ctx, func(om outbound.Manager, router routing.Router, pm policy.Manager, sm stats.Manager) error { - return d.Init(config.(*Config), om, router, pm, sm) + if err := core.RequireFeatures(ctx, func(om outbound.Manager, router routing.Router, pm policy.Manager, sm stats.Manager, dc dns.Client) error { + return d.Init(config.(*Config), om, router, pm, sm, dc) }); err != nil { return nil, err } @@ -107,11 +108,14 @@ func init() { } // Init initializes DefaultDispatcher. -func (d *DefaultDispatcher) Init(config *Config, om outbound.Manager, router routing.Router, pm policy.Manager, sm stats.Manager) error { +func (d *DefaultDispatcher) Init(config *Config, om outbound.Manager, router routing.Router, pm policy.Manager, sm stats.Manager, dc dns.Client) error { d.ohm = om d.router = router d.policy = pm d.stats = sm + if hosts, ok := dc.(dns.HostsLookup); ok { + d.hosts = hosts + } return nil } @@ -294,7 +298,7 @@ func (d *DefaultDispatcher) DispatchLink(ctx context.Context, destination net.De result, err := sniffer(ctx, nil, true) if err == nil { content.Protocol = result.Protocol() - if shouldOverride(result, sniffingRequest.OverrideDestinationForProtocol) { + if shouldOverride(ctx, result, sniffingRequest, destination) { domain := result.Domain() newError("sniffed domain: ", domain).WriteToLog(session.ExportIDToError(ctx)) destination.Address = net.ParseAddress(domain) @@ -316,7 +320,7 @@ func (d *DefaultDispatcher) DispatchLink(ctx context.Context, destination net.De if err == nil { content.Protocol = result.Protocol() } - if err == nil && shouldOverride(result, sniffingRequest.OverrideDestinationForProtocol) { + if err == nil && shouldOverride(ctx, result, sniffingRequest, destination) { domain := result.Domain() newError("sniffed domain: ", domain).WriteToLog(session.ExportIDToError(ctx)) destination.Address = net.ParseAddress(domain) @@ -379,6 +383,20 @@ func sniffer(ctx context.Context, cReader *cachedReader, metadataOnly bool) (Sni } func (d *DefaultDispatcher) routedDispatch(ctx context.Context, link *transport.Link, destination net.Destination) { + ob := session.OutboundFromContext(ctx) + if d.hosts != nil && destination.Address.Family().IsDomain() { + proxied := d.hosts.LookupHosts(ob.Target.String()) + if proxied != nil { + ro := ob.RouteTarget == destination + destination.Address = *proxied + if ro { + ob.RouteTarget = destination + } else { + ob.Target = destination + } + } + } + var handler outbound.Handler if d.router != nil { diff --git a/app/dns/dns.go b/app/dns/dns.go index a669bf10..af5f285b 100644 --- a/app/dns/dns.go +++ b/app/dns/dns.go @@ -223,6 +223,22 @@ func (s *DNS) LookupIP(domain string, option dns.IPOption) ([]net.IP, error) { return nil, newError("returning nil for domain ", domain).Base(errors.Combine(errs...)) } +// LookupHosts implements dns.HostsLookup. +func (s *DNS) LookupHosts(domain string) *net.Address { + domain = strings.TrimSuffix(domain, ".") + if domain == "" { + return nil + } + // Normalize the FQDN form query + addrs := s.hosts.Lookup(domain, *s.ipOption) + if len(addrs) > 0 { + newError("domain replaced: ", domain, " -> ", addrs[0].String()).AtInfo().WriteToLog() + return &addrs[0] + } + + return nil +} + // GetIPOption implements ClientWithIPOption. func (s *DNS) GetIPOption() *dns.IPOption { return s.ipOption diff --git a/features/dns/client.go b/features/dns/client.go index 584e24f8..d3a911da 100644 --- a/features/dns/client.go +++ b/features/dns/client.go @@ -24,6 +24,10 @@ type Client interface { LookupIP(domain string, option IPOption) ([]net.IP, error) } +type HostsLookup interface { + LookupHosts(domain string) *net.Address +} + // ClientType returns the type of Client interface. Can be used for implementing common.HasType. // // xray:api:beta