diff --git a/app/dns/cache_controller.go b/app/dns/cache_controller.go new file mode 100644 index 00000000..f23c414d --- /dev/null +++ b/app/dns/cache_controller.go @@ -0,0 +1,188 @@ +package dns + +import ( + "context" + go_errors "errors" + "github.com/xtls/xray-core/common" + "github.com/xtls/xray-core/common/errors" + "github.com/xtls/xray-core/common/net" + "github.com/xtls/xray-core/common/signal/pubsub" + "github.com/xtls/xray-core/common/task" + dns_feature "github.com/xtls/xray-core/features/dns" + "golang.org/x/net/dns/dnsmessage" + "sync" + "time" +) + +type CacheController struct { + sync.RWMutex + ips map[string]*record + pub *pubsub.Service + cacheCleanup *task.Periodic + name string + disableCache bool +} + +func NewCacheController(name string, disableCache bool) *CacheController { + c := &CacheController{ + name: name, + disableCache: disableCache, + ips: make(map[string]*record), + pub: pubsub.NewService(), + } + + c.cacheCleanup = &task.Periodic{ + Interval: time.Minute, + Execute: c.CacheCleanup, + } + return c +} + +// CacheCleanup clears expired items from cache +func (c *CacheController) CacheCleanup() error { + now := time.Now() + c.Lock() + defer c.Unlock() + + if len(c.ips) == 0 { + return errors.New("nothing to do. stopping...") + } + + for domain, record := range c.ips { + if record.A != nil && record.A.Expire.Before(now) { + record.A = nil + } + if record.AAAA != nil && record.AAAA.Expire.Before(now) { + record.AAAA = nil + } + + if record.A == nil && record.AAAA == nil { + errors.LogDebug(context.Background(), c.name, "cache cleanup ", domain) + delete(c.ips, domain) + } else { + c.ips[domain] = record + } + } + + if len(c.ips) == 0 { + c.ips = make(map[string]*record) + } + + return nil +} + +func (c *CacheController) updateIP(req *dnsRequest, ipRec *IPRecord) { + elapsed := time.Since(req.start) + + c.Lock() + rec, found := c.ips[req.domain] + if !found { + rec = &record{} + } + + switch req.reqType { + case dnsmessage.TypeA: + rec.A = ipRec + case dnsmessage.TypeAAAA: + rec.AAAA = ipRec + } + + errors.LogInfo(context.Background(), c.name, " got answer: ", req.domain, " ", req.reqType, " -> ", ipRec.IP, " ", elapsed) + c.ips[req.domain] = rec + + switch req.reqType { + case dnsmessage.TypeA: + c.pub.Publish(req.domain+"4", nil) + if !c.disableCache { + _, _, err := rec.AAAA.getIPs() + if !go_errors.Is(err, errRecordNotFound) { + c.pub.Publish(req.domain+"6", nil) + } + } + case dnsmessage.TypeAAAA: + c.pub.Publish(req.domain+"6", nil) + if !c.disableCache { + _, _, err := rec.A.getIPs() + if !go_errors.Is(err, errRecordNotFound) { + c.pub.Publish(req.domain+"4", nil) + } + } + } + + c.Unlock() + common.Must(c.cacheCleanup.Start()) +} + +func (c *CacheController) findIPsForDomain(domain string, option dns_feature.IPOption) ([]net.IP, uint32, error) { + c.RLock() + record, found := c.ips[domain] + c.RUnlock() + + if !found { + return nil, 0, errRecordNotFound + } + + var errs []error + var allIPs []net.IP + var rTTL uint32 = dns_feature.DefaultTTL + + mergeReq := option.IPv4Enable && option.IPv6Enable + + if option.IPv4Enable { + ips, ttl, err := record.A.getIPs() + if !mergeReq || go_errors.Is(err, errRecordNotFound) { + return ips, ttl, err + } + if ttl < rTTL { + rTTL = ttl + } + if len(ips) > 0 { + allIPs = append(allIPs, ips...) + } else { + errs = append(errs, err) + } + } + + if option.IPv6Enable { + ips, ttl, err := record.AAAA.getIPs() + if !mergeReq || go_errors.Is(err, errRecordNotFound) { + return ips, ttl, err + } + if ttl < rTTL { + rTTL = ttl + } + if len(ips) > 0 { + allIPs = append(allIPs, ips...) + } else { + errs = append(errs, err) + } + } + + if len(allIPs) > 0 { + return allIPs, rTTL, nil + } + if go_errors.Is(errs[0], errs[1]) { + return nil, rTTL, errs[0] + } + return nil, rTTL, errors.Combine(errs...) +} + +func (c *CacheController) registerSubscribers(domain string, option dns_feature.IPOption) (sub4 *pubsub.Subscriber, sub6 *pubsub.Subscriber) { + // ipv4 and ipv6 belong to different subscription groups + if option.IPv4Enable { + sub4 = c.pub.Subscribe(domain + "4") + } + if option.IPv6Enable { + sub6 = c.pub.Subscribe(domain + "6") + } + return +} + +func closeSubscribers(sub4 *pubsub.Subscriber, sub6 *pubsub.Subscriber) { + if sub4 != nil { + sub4.Close() + } + if sub6 != nil { + sub6.Close() + } +} diff --git a/app/dns/dns.go b/app/dns/dns.go index 9b84106c..3b9cfbcb 100644 --- a/app/dns/dns.go +++ b/app/dns/dns.go @@ -3,12 +3,12 @@ package dns import ( "context" + go_errors "errors" "fmt" "sort" "strings" "sync" - "github.com/xtls/xray-core/app/router" "github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common/errors" "github.com/xtls/xray-core/common/net" @@ -20,8 +20,6 @@ import ( // DNS is a DNS rely server. type DNS struct { sync.Mutex - tag string - disableCache bool disableFallback bool disableFallbackIfMatch bool ipOption *dns.IPOption @@ -40,13 +38,6 @@ type DomainMatcherInfo struct { // New creates a new DNS server with given configuration. func New(ctx context.Context, config *Config) (*DNS, error) { - var tag string - if len(config.Tag) > 0 { - tag = config.Tag - } else { - tag = generateRandomTag() - } - var clientIP net.IP switch len(config.ClientIp) { case 0, net.IPv4len, net.IPv6len: @@ -55,26 +46,28 @@ func New(ctx context.Context, config *Config) (*DNS, error) { return nil, errors.New("unexpected client IP length ", len(config.ClientIp)) } - var ipOption *dns.IPOption + var ipOption dns.IPOption switch config.QueryStrategy { case QueryStrategy_USE_IP: - ipOption = &dns.IPOption{ + ipOption = dns.IPOption{ IPv4Enable: true, IPv6Enable: true, FakeEnable: false, } case QueryStrategy_USE_IP4: - ipOption = &dns.IPOption{ + ipOption = dns.IPOption{ IPv4Enable: true, IPv6Enable: false, FakeEnable: false, } case QueryStrategy_USE_IP6: - ipOption = &dns.IPOption{ + ipOption = dns.IPOption{ IPv4Enable: false, IPv6Enable: true, FakeEnable: false, } + default: + return nil, errors.New("unexpected query strategy ", config.QueryStrategy) } hosts, err := NewStaticHosts(config.StaticHosts) @@ -82,8 +75,14 @@ func New(ctx context.Context, config *Config) (*DNS, error) { return nil, errors.New("failed to create hosts").Base(err) } - clients := []*Client{} + var clients []*Client domainRuleCount := 0 + + var defaultTag = config.Tag + if len(config.Tag) == 0 { + defaultTag = generateRandomTag() + } + for _, ns := range config.NameServer { domainRuleCount += len(ns.PrioritizedDomain) } @@ -91,7 +90,6 @@ func New(ctx context.Context, config *Config) (*DNS, error) { // MatcherInfos is ensured to cover the maximum index domainMatcher could return, where matcher's index starts from 1 matcherInfos := make([]*DomainMatcherInfo, domainRuleCount+1) domainMatcher := &strmatcher.MatcherGroup{} - geoipContainer := router.GeoIPMatcherContainer{} for _, ns := range config.NameServer { clientIdx := len(clients) @@ -109,7 +107,18 @@ func New(ctx context.Context, config *Config) (*DNS, error) { case net.IPv4len, net.IPv6len: myClientIP = net.IP(ns.ClientIp) } - client, err := NewClient(ctx, ns, myClientIP, geoipContainer, &matcherInfos, updateDomain) + + disableCache := config.DisableCache + + var tag = defaultTag + if len(ns.Tag) > 0 { + tag = ns.Tag + } + clientIPOption := ResolveIpOptionOverride(ns.QueryStrategy, ipOption) + if !clientIPOption.IPv4Enable && !clientIPOption.IPv6Enable { + return nil, errors.New("no QueryStrategy available for ", ns.Address) + } + client, err := NewClient(ctx, ns, myClientIP, disableCache, tag, clientIPOption, &matcherInfos, updateDomain) if err != nil { return nil, errors.New("failed to create client").Base(err) } @@ -118,18 +127,16 @@ func New(ctx context.Context, config *Config) (*DNS, error) { // If there is no DNS client in config, add a `localhost` DNS client if len(clients) == 0 { - clients = append(clients, NewLocalDNSClient()) + clients = append(clients, NewLocalDNSClient(ipOption)) } return &DNS{ - tag: tag, hosts: hosts, - ipOption: ipOption, + ipOption: &ipOption, clients: clients, ctx: ctx, domainMatcher: domainMatcher, matcherInfos: matcherInfos, - disableCache: config.DisableCache, disableFallback: config.DisableFallback, disableFallbackIfMatch: config.DisableFallbackIfMatch, }, nil @@ -153,11 +160,21 @@ func (s *DNS) Close() error { // IsOwnLink implements proxy.dns.ownLinkVerifier func (s *DNS) IsOwnLink(ctx context.Context) bool { inbound := session.InboundFromContext(ctx) - return inbound != nil && inbound.Tag == s.tag + if inbound == nil { + return false + } + for _, client := range s.clients { + if client.tag == inbound.Tag { + return true + } + } + return false } // LookupIP implements dns.Client. func (s *DNS) LookupIP(domain string, option dns.IPOption) ([]net.IP, uint32, error) { + // Normalize the FQDN form query + domain = strings.TrimSuffix(domain, ".") if domain == "" { return nil, 0, errors.New("empty domain name") } @@ -169,9 +186,6 @@ func (s *DNS) LookupIP(domain string, option dns.IPOption) ([]net.IP, uint32, er return nil, 0, dns.ErrEmptyResponse } - // Normalize the FQDN form query - domain = strings.TrimSuffix(domain, ".") - // Static host lookup switch addrs := s.hosts.Lookup(domain, option); { case addrs == nil: // Domain not recorded in static host @@ -184,32 +198,49 @@ func (s *DNS) LookupIP(domain string, option dns.IPOption) ([]net.IP, uint32, er default: // Successfully found ip records in static host errors.LogInfo(s.ctx, "returning ", len(addrs), " IP(s) for domain ", domain, " -> ", addrs) ips, err := toNetIP(addrs) - return ips, 10, err // Hosts ttl is 10 + if err != nil { + return nil, 0, err + } + return ips, 10, nil // Hosts ttl is 10 } // Name servers lookup - errs := []error{} - ctx := session.ContextWithInbound(s.ctx, &session.Inbound{Tag: s.tag}) + var errs []error for _, client := range s.sortClients(domain) { if !option.FakeEnable && strings.EqualFold(client.Name(), "FakeDNS") { errors.LogDebug(s.ctx, "skip DNS resolution for domain ", domain, " at server ", client.Name()) continue } - ips, ttl, err := client.QueryIP(ctx, domain, option, s.disableCache) + + ips, ttl, err := client.QueryIP(s.ctx, domain, option) + if len(ips) > 0 { + if ttl == 0 { + ttl = 1 + } return ips, ttl, nil } - if err != nil { - errors.LogInfoInner(s.ctx, err, "failed to lookup ip for domain ", domain, " at server ", client.Name()) - errs = append(errs, err) - } - // 5 for RcodeRefused in miekg/dns, hardcode to reduce binary size - if err != context.Canceled && err != context.DeadlineExceeded && err != errExpectedIPNonMatch && err != dns.ErrEmptyResponse && dns.RCodeFromError(err) != 5 { - return nil, 0, err + + errors.LogInfoInner(s.ctx, err, "failed to lookup ip for domain ", domain, " at server ", client.Name()) + if err == nil { + err = dns.ErrEmptyResponse } + errs = append(errs, err) + } - return nil, 0, errors.New("returning nil for domain ", domain).Base(errors.Combine(errs...)) + if len(errs) > 0 { + allErrs := errors.Combine(errs...) + err0 := errs[0] + if errors.AllEqual(err0, allErrs) { + if go_errors.Is(err0, dns.ErrEmptyResponse) { + return nil, 0, dns.ErrEmptyResponse + } + return nil, 0, errors.New("returning nil for domain ", domain).Base(err0) + } + return nil, 0, errors.New("returning nil for domain ", domain).Base(allErrs) + } + return nil, 0, dns.ErrEmptyResponse } // LookupHosts implements dns.HostsLookup. @@ -228,22 +259,6 @@ func (s *DNS) LookupHosts(domain string) *net.Address { return nil } -// GetIPOption implements ClientWithIPOption. -func (s *DNS) GetIPOption() *dns.IPOption { - return s.ipOption -} - -// SetQueryOption implements ClientWithIPOption. -func (s *DNS) SetQueryOption(isIPv4Enable, isIPv6Enable bool) { - s.ipOption.IPv4Enable = isIPv4Enable - s.ipOption.IPv6Enable = isIPv6Enable -} - -// SetFakeDNSOption implements ClientWithIPOption. -func (s *DNS) SetFakeDNSOption(isFakeEnable bool) { - s.ipOption.FakeEnable = isFakeEnable -} - func (s *DNS) sortClients(domain string) []*Client { clients := make([]*Client, 0, len(s.clients)) clientUsed := make([]bool, len(s.clients)) diff --git a/app/dns/dns_test.go b/app/dns/dns_test.go index 4bdc9ae3..7ea6fcf8 100644 --- a/app/dns/dns_test.go +++ b/app/dns/dns_test.go @@ -76,6 +76,9 @@ func (*staticHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { case q.Name == "notexist.google.com." && q.Qtype == dns.TypeAAAA: ans.MsgHdr.Rcode = dns.RcodeNameError + case q.Name == "notexist.google.com." && q.Qtype == dns.TypeA: + ans.MsgHdr.Rcode = dns.RcodeNameError + case q.Name == "hostname." && q.Qtype == dns.TypeA: rr, _ := dns.NewRR("hostname. IN A 127.0.0.1") ans.Answer = append(ans.Answer, rr) @@ -117,7 +120,6 @@ func TestUDPServerSubnet(t *testing.T) { Handler: &staticHandler{}, UDPSize: 1200, } - go dnsServer.ListenAndServe() time.Sleep(time.Second) diff --git a/app/dns/dnscommon.go b/app/dns/dnscommon.go index 0dc07f72..0bd712ff 100644 --- a/app/dns/dnscommon.go +++ b/app/dns/dnscommon.go @@ -32,31 +32,30 @@ type record struct { // IPRecord is a cacheable item for a resolved domain type IPRecord struct { ReqID uint16 - IP []net.Address + IP []net.IP Expire time.Time RCode dnsmessage.RCode RawHeader *dnsmessage.Header } -func (r *IPRecord) getIPs() ([]net.Address, uint32, error) { - if r == nil || r.Expire.Before(time.Now()) { +func (r *IPRecord) getIPs() ([]net.IP, uint32, error) { + if r == nil { return nil, 0, errRecordNotFound } - if r.RCode != dnsmessage.RCodeSuccess { - return nil, 0, dns_feature.RCodeError(r.RCode) + untilExpire := time.Until(r.Expire) + if untilExpire <= 0 { + return nil, 0, errRecordNotFound } - ttl := uint32(time.Until(r.Expire) / time.Second) - return r.IP, ttl, nil -} -func isNewer(baseRec *IPRecord, newRec *IPRecord) bool { - if newRec == nil { - return false + ttl := uint32(untilExpire/time.Second) + uint32(1) + if r.RCode != dnsmessage.RCodeSuccess { + return nil, ttl, dns_feature.RCodeError(r.RCode) } - if baseRec == nil { - return true + if len(r.IP) == 0 { + return nil, ttl, dns_feature.ErrEmptyResponse } - return baseRec.Expire.Before(newRec.Expire) + + return r.IP, ttl, nil } var errRecordNotFound = errors.New("record not found") @@ -193,7 +192,7 @@ func parseResponse(payload []byte) (*IPRecord, error) { ipRecord := &IPRecord{ ReqID: h.ID, RCode: h.RCode, - Expire: now.Add(time.Second * 600), + Expire: now.Add(time.Second * dns_feature.DefaultTTL), RawHeader: &h, } @@ -209,7 +208,7 @@ L: ttl := ah.TTL if ttl == 0 { - ttl = 600 + ttl = 1 } expire := now.Add(time.Duration(ttl) * time.Second) if ipRecord.Expire.After(expire) { @@ -223,14 +222,17 @@ L: errors.LogInfoInner(context.Background(), err, "failed to parse A record for domain: ", ah.Name) break L } - ipRecord.IP = append(ipRecord.IP, net.IPAddress(ans.A[:])) + ipRecord.IP = append(ipRecord.IP, net.IPAddress(ans.A[:]).IP()) case dnsmessage.TypeAAAA: ans, err := parser.AAAAResource() if err != nil { errors.LogInfoInner(context.Background(), err, "failed to parse AAAA record for domain: ", ah.Name) break L } - ipRecord.IP = append(ipRecord.IP, net.IPAddress(ans.AAAA[:])) + newIP := net.IPAddress(ans.AAAA[:]).IP() + if len(newIP) == net.IPv6len { + ipRecord.IP = append(ipRecord.IP, newIP) + } default: if err := parser.SkipAnswer(); err != nil { errors.LogInfoInner(context.Background(), err, "failed to skip answer") diff --git a/app/dns/dnscommon_test.go b/app/dns/dnscommon_test.go index 2affb2a3..bbaa9a21 100644 --- a/app/dns/dnscommon_test.go +++ b/app/dns/dnscommon_test.go @@ -51,7 +51,7 @@ func Test_parseResponse(t *testing.T) { }{ { "empty", - &IPRecord{0, []net.Address(nil), time.Time{}, dnsmessage.RCodeSuccess, nil}, + &IPRecord{0, []net.IP(nil), time.Time{}, dnsmessage.RCodeSuccess, nil}, false, }, { @@ -63,7 +63,7 @@ func Test_parseResponse(t *testing.T) { "a record", &IPRecord{ 1, - []net.Address{net.ParseAddress("8.8.8.8"), net.ParseAddress("8.8.4.4")}, + []net.IP{net.ParseIP("8.8.8.8"), net.ParseIP("8.8.4.4")}, time.Time{}, dnsmessage.RCodeSuccess, nil, @@ -72,7 +72,7 @@ func Test_parseResponse(t *testing.T) { }, { "aaaa record", - &IPRecord{2, []net.Address{net.ParseAddress("2001::123:8888"), net.ParseAddress("2001::123:8844")}, time.Time{}, dnsmessage.RCodeSuccess, nil}, + &IPRecord{2, []net.IP{net.ParseIP("2001::123:8888"), net.ParseIP("2001::123:8844")}, time.Time{}, dnsmessage.RCodeSuccess, nil}, false, }, } diff --git a/app/dns/nameserver.go b/app/dns/nameserver.go index f1d02616..31681c4a 100644 --- a/app/dns/nameserver.go +++ b/app/dns/nameserver.go @@ -21,25 +21,23 @@ type Server interface { // Name of the Client. Name() string // QueryIP sends IP queries to its configured server. - QueryIP(ctx context.Context, domain string, clientIP net.IP, option dns.IPOption, disableCache bool) ([]net.IP, uint32, error) + QueryIP(ctx context.Context, domain string, option dns.IPOption) ([]net.IP, uint32, error) } // Client is the interface for DNS client. type Client struct { server Server - clientIP net.IP skipFallback bool domains []string expectedIPs []*router.GeoIPMatcher allowUnexpectedIPs bool tag string timeoutMs time.Duration + ipOption *dns.IPOption } -var errExpectedIPNonMatch = errors.New("expectedIPs not match") - // NewServer creates a name server object according to the network destination url. -func NewServer(ctx context.Context, dest net.Destination, dispatcher routing.Dispatcher, queryStrategy QueryStrategy) (Server, error) { +func NewServer(ctx context.Context, dest net.Destination, dispatcher routing.Dispatcher, disableCache bool, clientIP net.IP) (Server, error) { if address := dest.Address; address.Family().IsDomain() { u, err := url.Parse(address.Domain()) if err != nil { @@ -47,26 +45,29 @@ func NewServer(ctx context.Context, dest net.Destination, dispatcher routing.Dis } switch { case strings.EqualFold(u.String(), "localhost"): - return NewLocalNameServer(queryStrategy), nil + return NewLocalNameServer(), nil case strings.EqualFold(u.Scheme, "https"): // DNS-over-HTTPS Remote mode - return NewDoHNameServer(u, queryStrategy, dispatcher, false), nil + return NewDoHNameServer(u, dispatcher, false, disableCache, clientIP), nil case strings.EqualFold(u.Scheme, "h2c"): // DNS-over-HTTPS h2c Remote mode - return NewDoHNameServer(u, queryStrategy, dispatcher, true), nil + return NewDoHNameServer(u, dispatcher, true, disableCache, clientIP), nil case strings.EqualFold(u.Scheme, "https+local"): // DNS-over-HTTPS Local mode - return NewDoHNameServer(u, queryStrategy, nil, false), nil + return NewDoHNameServer(u, nil, false, disableCache, clientIP), nil case strings.EqualFold(u.Scheme, "h2c+local"): // DNS-over-HTTPS h2c Local mode - return NewDoHNameServer(u, queryStrategy, nil, true), nil + return NewDoHNameServer(u, nil, true, disableCache, clientIP), nil case strings.EqualFold(u.Scheme, "quic+local"): // DNS-over-QUIC Local mode - return NewQUICNameServer(u, queryStrategy) + return NewQUICNameServer(u, disableCache, clientIP) case strings.EqualFold(u.Scheme, "tcp"): // DNS-over-TCP Remote mode - return NewTCPNameServer(u, dispatcher, queryStrategy) + return NewTCPNameServer(u, dispatcher, disableCache, clientIP) case strings.EqualFold(u.Scheme, "tcp+local"): // DNS-over-TCP Local mode - return NewTCPLocalNameServer(u, queryStrategy) + return NewTCPLocalNameServer(u, disableCache, clientIP) case strings.EqualFold(u.String(), "fakedns"): var fd dns.FakeDNSEngine - core.RequireFeatures(ctx, func(fdns dns.FakeDNSEngine) { + err = core.RequireFeatures(ctx, func(fdns dns.FakeDNSEngine) { fd = fdns }) + if err != nil { + return nil, err + } return NewFakeDNSServer(fd), nil } } @@ -74,7 +75,7 @@ func NewServer(ctx context.Context, dest net.Destination, dispatcher routing.Dis dest.Network = net.Network_UDP } if dest.Network == net.Network_UDP { // UDP classic DNS mode - return NewClassicNameServer(dest, dispatcher, queryStrategy), nil + return NewClassicNameServer(dest, dispatcher, disableCache, clientIP), nil } return nil, errors.New("No available name server could be created from ", dest).AtWarning() } @@ -84,7 +85,9 @@ func NewClient( ctx context.Context, ns *NameServer, clientIP net.IP, - container router.GeoIPMatcherContainer, + disableCache bool, + tag string, + ipOption dns.IPOption, matcherInfos *[]*DomainMatcherInfo, updateDomainRule func(strmatcher.Matcher, int, []*DomainMatcherInfo) error, ) (*Client, error) { @@ -92,7 +95,7 @@ func NewClient( err := core.RequireFeatures(ctx, func(dispatcher routing.Dispatcher) error { // Create a new server for each client for now - server, err := NewServer(ctx, ns.Address.AsDestination(), dispatcher, ns.GetQueryStrategy()) + server, err := NewServer(ctx, ns.Address.AsDestination(), dispatcher, disableCache, clientIP) if err != nil { return errors.New("failed to create nameserver").Base(err).AtWarning() } @@ -149,7 +152,7 @@ func NewClient( // Establish expected IPs var matchers []*router.GeoIPMatcher for _, geoip := range ns.Geoip { - matcher, err := container.Add(geoip) + matcher, err := router.GlobalGeoIPContainer.Add(geoip) if err != nil { return errors.New("failed to create ip matcher").Base(err).AtWarning() } @@ -169,15 +172,15 @@ func NewClient( if ns.TimeoutMs > 0 { timeoutMs = time.Duration(ns.TimeoutMs) * time.Millisecond } - + client.server = server - client.clientIP = clientIP client.skipFallback = ns.SkipFallback client.domains = rules client.expectedIPs = matchers client.allowUnexpectedIPs = ns.AllowUnexpectedIPs - client.tag = ns.Tag + client.tag = tag client.timeoutMs = timeoutMs + client.ipOption = &ipOption return nil }) return client, err @@ -189,31 +192,43 @@ func (c *Client) Name() string { } // QueryIP sends DNS query to the name server with the client's IP. -func (c *Client) QueryIP(ctx context.Context, domain string, option dns.IPOption, disableCache bool) ([]net.IP, uint32, error) { - ctx, cancel := context.WithTimeout(ctx, c.timeoutMs) - if len(c.tag) != 0 { - content := session.InboundFromContext(ctx) - errors.LogDebug(ctx, "DNS: client override tag from ", content.Tag, " to ", c.tag) - // create a new context to override the tag - // do not direct set *content.Tag, it might be used by other clients - ctx = session.ContextWithInbound(ctx, &session.Inbound{Tag: c.tag}) +func (c *Client) QueryIP(ctx context.Context, domain string, option dns.IPOption) ([]net.IP, uint32, error) { + option.IPv4Enable = option.IPv4Enable && c.ipOption.IPv4Enable + option.IPv6Enable = option.IPv6Enable && c.ipOption.IPv6Enable + if !option.IPv4Enable && !option.IPv6Enable { + return nil, 0, dns.ErrEmptyResponse } - ips, ttl, err := c.server.QueryIP(ctx, domain, c.clientIP, option, disableCache) + + ctx, cancel := context.WithTimeout(ctx, c.timeoutMs) + ctx = session.ContextWithInbound(ctx, &session.Inbound{Tag: c.tag}) + ips, ttl, err := c.server.QueryIP(ctx, domain, option) cancel() if err != nil { - return ips, ttl, err + return nil, 0, err } - netips, err := c.MatchExpectedIPs(domain, ips) - return netips, ttl, err + + if len(ips) == 0 { + return nil, 0, dns.ErrEmptyResponse + } + + if len(c.expectedIPs) > 0 { + newIps := c.MatchExpectedIPs(domain, ips) + if len(newIps) == 0 { + if !c.allowUnexpectedIPs { + return nil, 0, dns.ErrEmptyResponse + } + } else { + ips = newIps + } + } + + return ips, ttl, nil } // MatchExpectedIPs matches queried domain IPs with expected IPs and returns matched ones. -func (c *Client) MatchExpectedIPs(domain string, ips []net.IP) ([]net.IP, error) { - if len(c.expectedIPs) == 0 { - return ips, nil - } - newIps := []net.IP{} +func (c *Client) MatchExpectedIPs(domain string, ips []net.IP) []net.IP { + var newIps []net.IP for _, ip := range ips { for _, matcher := range c.expectedIPs { if matcher.Match(ip) { @@ -222,14 +237,8 @@ func (c *Client) MatchExpectedIPs(domain string, ips []net.IP) ([]net.IP, error) } } } - if len(newIps) == 0 { - if c.allowUnexpectedIPs { - return ips, nil - } - return nil, errExpectedIPNonMatch - } errors.LogDebug(context.Background(), "domain ", domain, " expectedIPs ", newIps, " matched at server ", c.Name()) - return newIps, nil + return newIps } func ResolveIpOptionOverride(queryStrategy QueryStrategy, ipOption dns.IPOption) dns.IPOption { diff --git a/app/dns/nameserver_doh.go b/app/dns/nameserver_doh.go index 6cdb8ee7..cba59423 100644 --- a/app/dns/nameserver_doh.go +++ b/app/dns/nameserver_doh.go @@ -4,12 +4,12 @@ import ( "bytes" "context" "crypto/tls" + go_errors "errors" "fmt" "io" "net/http" "net/url" "strings" - "sync" "time" utls "github.com/refraction-networking/utls" @@ -21,12 +21,9 @@ import ( "github.com/xtls/xray-core/common/net/cnc" "github.com/xtls/xray-core/common/protocol/dns" "github.com/xtls/xray-core/common/session" - "github.com/xtls/xray-core/common/signal/pubsub" - "github.com/xtls/xray-core/common/task" dns_feature "github.com/xtls/xray-core/features/dns" "github.com/xtls/xray-core/features/routing" "github.com/xtls/xray-core/transport/internet" - "golang.org/x/net/dns/dnsmessage" "golang.org/x/net/http2" ) @@ -34,18 +31,14 @@ import ( // which is compatible with traditional dns over udp(RFC1035), // thus most of the DOH implementation is copied from udpns.go type DoHNameServer struct { - sync.RWMutex - ips map[string]*record - pub *pubsub.Service - cleanup *task.Periodic - httpClient *http.Client - dohURL string - name string - queryStrategy QueryStrategy + cacheController *CacheController + httpClient *http.Client + dohURL string + clientIP net.IP } // NewDoHNameServer creates DOH/DOHL client object for remote/local resolving. -func NewDoHNameServer(url *url.URL, queryStrategy QueryStrategy, dispatcher routing.Dispatcher, h2c bool) *DoHNameServer { +func NewDoHNameServer(url *url.URL, dispatcher routing.Dispatcher, h2c bool, disableCache bool, clientIP net.IP) *DoHNameServer { url.Scheme = "https" mode := "DOH" if dispatcher == nil { @@ -53,15 +46,9 @@ func NewDoHNameServer(url *url.URL, queryStrategy QueryStrategy, dispatcher rout } errors.LogInfo(context.Background(), "DNS: created ", mode, " client for ", url.String(), ", with h2c ", h2c) s := &DoHNameServer{ - ips: make(map[string]*record), - pub: pubsub.NewService(), - name: mode + "//" + url.Host, - dohURL: url.String(), - queryStrategy: queryStrategy, - } - s.cleanup = &task.Periodic{ - Interval: time.Minute, - Execute: s.Cleanup, + cacheController: NewCacheController(mode+"//"+url.Host, disableCache), + dohURL: url.String(), + clientIP: clientIP, } s.httpClient = &http.Client{ Transport: &http2.Transport{ @@ -127,101 +114,25 @@ func NewDoHNameServer(url *url.URL, queryStrategy QueryStrategy, dispatcher rout // Name implements Server. func (s *DoHNameServer) Name() string { - return s.name -} - -// Cleanup clears expired items from cache -func (s *DoHNameServer) Cleanup() error { - now := time.Now() - s.Lock() - defer s.Unlock() - - if len(s.ips) == 0 { - return errors.New("nothing to do. stopping...") - } - - for domain, record := range s.ips { - if record.A != nil && record.A.Expire.Before(now) { - record.A = nil - } - if record.AAAA != nil && record.AAAA.Expire.Before(now) { - record.AAAA = nil - } - - if record.A == nil && record.AAAA == nil { - errors.LogDebug(context.Background(), s.name, " cleanup ", domain) - delete(s.ips, domain) - } else { - s.ips[domain] = record - } - } - - if len(s.ips) == 0 { - s.ips = make(map[string]*record) - } - - return nil -} - -func (s *DoHNameServer) updateIP(req *dnsRequest, ipRec *IPRecord) { - elapsed := time.Since(req.start) - - s.Lock() - rec, found := s.ips[req.domain] - if !found { - rec = &record{} - } - updated := false - - switch req.reqType { - case dnsmessage.TypeA: - if isNewer(rec.A, ipRec) { - rec.A = ipRec - updated = true - } - case dnsmessage.TypeAAAA: - addr := make([]net.Address, 0, len(ipRec.IP)) - for _, ip := range ipRec.IP { - if len(ip.IP()) == net.IPv6len { - addr = append(addr, ip) - } - } - ipRec.IP = addr - if isNewer(rec.AAAA, ipRec) { - rec.AAAA = ipRec - updated = true - } - } - errors.LogInfo(context.Background(), s.name, " got answer: ", req.domain, " ", req.reqType, " -> ", ipRec.IP, " ", elapsed) - - if updated { - s.ips[req.domain] = rec - } - switch req.reqType { - case dnsmessage.TypeA: - s.pub.Publish(req.domain+"4", nil) - case dnsmessage.TypeAAAA: - s.pub.Publish(req.domain+"6", nil) - } - s.Unlock() - common.Must(s.cleanup.Start()) + return s.cacheController.name } func (s *DoHNameServer) newReqID() uint16 { return 0 } -func (s *DoHNameServer) sendQuery(ctx context.Context, domain string, clientIP net.IP, option dns_feature.IPOption) { - errors.LogInfo(ctx, s.name, " querying: ", domain) +func (s *DoHNameServer) sendQuery(ctx context.Context, noResponseErrCh chan<- error, domain string, option dns_feature.IPOption) { + errors.LogInfo(ctx, s.Name(), " querying: ", domain) - if s.name+"." == "DOH//"+domain { - errors.LogError(ctx, s.name, " tries to resolve itself! Use IP or set \"hosts\" instead.") + if s.Name()+"." == "DOH//"+domain { + errors.LogError(ctx, s.Name(), " tries to resolve itself! Use IP or set \"hosts\" instead.") + noResponseErrCh <- errors.New("tries to resolve itself!", s.Name()) return } // As we don't want our traffic pattern looks like DoH, we use Random-Length Padding instead of Block-Length Padding recommended in RFC 8467 // Although DoH server like 1.1.1.1 will pad the response to Block-Length 468, at least it is better than no padding for response at all - reqs := buildReqMsgs(domain, option, s.newReqID, genEDNS0Options(clientIP, int(crypto.RandBetween(100, 300)))) + reqs := buildReqMsgs(domain, option, s.newReqID, genEDNS0Options(s.clientIP, int(crypto.RandBetween(100, 300)))) var deadline time.Time if d, ok := ctx.Deadline(); ok { @@ -256,19 +167,22 @@ func (s *DoHNameServer) sendQuery(ctx context.Context, domain string, clientIP n b, err := dns.PackMessage(r.msg) if err != nil { errors.LogErrorInner(ctx, err, "failed to pack dns query for ", domain) + noResponseErrCh <- err return } resp, err := s.dohHTTPSContext(dnsCtx, b.Bytes()) if err != nil { errors.LogErrorInner(ctx, err, "failed to retrieve response for ", domain) + noResponseErrCh <- err return } rec, err := parseResponse(resp) if err != nil { errors.LogErrorInner(ctx, err, "failed to handle DOH response for ", domain) + noResponseErrCh <- err return } - s.updateIP(r, rec) + s.cacheController.updateIP(r, rec) }(req) } } @@ -301,109 +215,50 @@ func (s *DoHNameServer) dohHTTPSContext(ctx context.Context, b []byte) ([]byte, return io.ReadAll(resp.Body) } -func (s *DoHNameServer) findIPsForDomain(domain string, option dns_feature.IPOption) ([]net.IP, uint32, error) { - s.RLock() - record, found := s.ips[domain] - s.RUnlock() - - if !found { - return nil, 0, errRecordNotFound - } - - var err4 error - var err6 error - var ips []net.Address - var ip6 []net.Address - var ttl uint32 - - if option.IPv4Enable { - ips, ttl, err4 = record.A.getIPs() - } - - if option.IPv6Enable { - ip6, ttl, err6 = record.AAAA.getIPs() - ips = append(ips, ip6...) - } - - if len(ips) > 0 { - netips, err := toNetIP(ips) - return netips, ttl, err - } - - if err4 != nil { - return nil, 0, err4 - } - - if err6 != nil { - return nil, 0, err6 - } - - if (option.IPv4Enable && record.A != nil) || (option.IPv6Enable && record.AAAA != nil) { - return nil, 0, dns_feature.ErrEmptyResponse - } - - return nil, 0, errRecordNotFound -} - // QueryIP implements Server. -func (s *DoHNameServer) QueryIP(ctx context.Context, domain string, clientIP net.IP, option dns_feature.IPOption, disableCache bool) ([]net.IP, uint32, error) { // nolint: dupl +func (s *DoHNameServer) QueryIP(ctx context.Context, domain string, option dns_feature.IPOption) ([]net.IP, uint32, error) { // nolint: dupl fqdn := Fqdn(domain) - option = ResolveIpOptionOverride(s.queryStrategy, option) - if !option.IPv4Enable && !option.IPv6Enable { - return nil, 0, dns_feature.ErrEmptyResponse - } + sub4, sub6 := s.cacheController.registerSubscribers(fqdn, option) + defer closeSubscribers(sub4, sub6) - if disableCache { - errors.LogDebug(ctx, "DNS cache is disabled. Querying IP for ", domain, " at ", s.name) + if s.cacheController.disableCache { + errors.LogDebug(ctx, "DNS cache is disabled. Querying IP for ", domain, " at ", s.Name()) } else { - ips, ttl, err := s.findIPsForDomain(fqdn, option) - if err == nil || err == dns_feature.ErrEmptyResponse || dns_feature.RCodeFromError(err) == 3 { - errors.LogDebugInner(ctx, err, s.name, " cache HIT ", domain, " -> ", ips) - log.Record(&log.DNSLog{Server: s.name, Domain: domain, Result: ips, Status: log.DNSCacheHit, Elapsed: 0, Error: err}) + ips, ttl, err := s.cacheController.findIPsForDomain(fqdn, option) + if !go_errors.Is(err, errRecordNotFound) { + errors.LogDebugInner(ctx, err, s.Name(), " cache HIT ", domain, " -> ", ips) + log.Record(&log.DNSLog{Server: s.Name(), Domain: domain, Result: ips, Status: log.DNSCacheHit, Elapsed: 0, Error: err}) return ips, ttl, err } } - // ipv4 and ipv6 belong to different subscription groups - var sub4, sub6 *pubsub.Subscriber - if option.IPv4Enable { - sub4 = s.pub.Subscribe(fqdn + "4") - defer sub4.Close() - } - if option.IPv6Enable { - sub6 = s.pub.Subscribe(fqdn + "6") - defer sub6.Close() - } - done := make(chan interface{}) - go func() { - if sub4 != nil { - select { - case <-sub4.Wait(): - case <-ctx.Done(): - } - } - if sub6 != nil { - select { - case <-sub6.Wait(): - case <-ctx.Done(): - } - } - close(done) - }() - s.sendQuery(ctx, fqdn, clientIP, option) + noResponseErrCh := make(chan error, 2) + s.sendQuery(ctx, noResponseErrCh, fqdn, option) start := time.Now() - for { - ips, ttl, err := s.findIPsForDomain(fqdn, option) - if err != errRecordNotFound { - log.Record(&log.DNSLog{Server: s.name, Domain: domain, Result: ips, Status: log.DNSQueried, Elapsed: time.Since(start), Error: err}) - return ips, ttl, err - } - + if sub4 != nil { select { case <-ctx.Done(): return nil, 0, ctx.Err() - case <-done: + case err := <-noResponseErrCh: + return nil, 0, err + case <-sub4.Wait(): + sub4.Close() } } + if sub6 != nil { + select { + case <-ctx.Done(): + return nil, 0, ctx.Err() + case err := <-noResponseErrCh: + return nil, 0, err + case <-sub6.Wait(): + sub6.Close() + } + } + + ips, ttl, err := s.cacheController.findIPsForDomain(fqdn, option) + log.Record(&log.DNSLog{Server: s.Name(), Domain: domain, Result: ips, Status: log.DNSQueried, Elapsed: time.Since(start), Error: err}) + return ips, ttl, err + } diff --git a/app/dns/nameserver_doh_test.go b/app/dns/nameserver_doh_test.go index a27a5e9f..96412c22 100644 --- a/app/dns/nameserver_doh_test.go +++ b/app/dns/nameserver_doh_test.go @@ -17,12 +17,12 @@ func TestDOHNameServer(t *testing.T) { url, err := url.Parse("https+local://1.1.1.1/dns-query") common.Must(err) - s := NewDoHNameServer(url, QueryStrategy_USE_IP, nil, false) + s := NewDoHNameServer(url, nil, false, false, net.IP(nil)) ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) - ips, _, err := s.QueryIP(ctx, "google.com", net.IP(nil), dns_feature.IPOption{ + ips, _, err := s.QueryIP(ctx, "google.com", dns_feature.IPOption{ IPv4Enable: true, IPv6Enable: true, - }, false) + }) cancel() common.Must(err) if len(ips) == 0 { @@ -34,12 +34,12 @@ func TestDOHNameServerWithCache(t *testing.T) { url, err := url.Parse("https+local://1.1.1.1/dns-query") common.Must(err) - s := NewDoHNameServer(url, QueryStrategy_USE_IP, nil, false) + s := NewDoHNameServer(url, nil, false, false, net.IP(nil)) ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) - ips, _, err := s.QueryIP(ctx, "google.com", net.IP(nil), dns_feature.IPOption{ + ips, _, err := s.QueryIP(ctx, "google.com", dns_feature.IPOption{ IPv4Enable: true, IPv6Enable: true, - }, false) + }) cancel() common.Must(err) if len(ips) == 0 { @@ -47,10 +47,10 @@ func TestDOHNameServerWithCache(t *testing.T) { } ctx2, cancel := context.WithTimeout(context.Background(), time.Second*5) - ips2, _, err := s.QueryIP(ctx2, "google.com", net.IP(nil), dns_feature.IPOption{ + ips2, _, err := s.QueryIP(ctx2, "google.com", dns_feature.IPOption{ IPv4Enable: true, IPv6Enable: true, - }, true) + }) cancel() common.Must(err) if r := cmp.Diff(ips2, ips); r != "" { @@ -62,12 +62,12 @@ func TestDOHNameServerWithIPv4Override(t *testing.T) { url, err := url.Parse("https+local://1.1.1.1/dns-query") common.Must(err) - s := NewDoHNameServer(url, QueryStrategy_USE_IP4, nil, false) + s := NewDoHNameServer(url, nil, false, false, net.IP(nil)) ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) - ips, _, err := s.QueryIP(ctx, "google.com", net.IP(nil), dns_feature.IPOption{ + ips, _, err := s.QueryIP(ctx, "google.com", dns_feature.IPOption{ IPv4Enable: true, - IPv6Enable: true, - }, false) + IPv6Enable: false, + }) cancel() common.Must(err) if len(ips) == 0 { @@ -85,12 +85,12 @@ func TestDOHNameServerWithIPv6Override(t *testing.T) { url, err := url.Parse("https+local://1.1.1.1/dns-query") common.Must(err) - s := NewDoHNameServer(url, QueryStrategy_USE_IP6, nil, false) + s := NewDoHNameServer(url, nil, false, false, net.IP(nil)) ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) - ips, _, err := s.QueryIP(ctx, "google.com", net.IP(nil), dns_feature.IPOption{ - IPv4Enable: true, + ips, _, err := s.QueryIP(ctx, "google.com", dns_feature.IPOption{ + IPv4Enable: false, IPv6Enable: true, - }, false) + }) cancel() common.Must(err) if len(ips) == 0 { diff --git a/app/dns/nameserver_fakedns.go b/app/dns/nameserver_fakedns.go index 37c2f723..8c598ac8 100644 --- a/app/dns/nameserver_fakedns.go +++ b/app/dns/nameserver_fakedns.go @@ -20,7 +20,7 @@ func (FakeDNSServer) Name() string { return "FakeDNS" } -func (f *FakeDNSServer) QueryIP(ctx context.Context, domain string, _ net.IP, opt dns.IPOption, _ bool) ([]net.IP, uint32, error) { +func (f *FakeDNSServer) QueryIP(ctx context.Context, domain string, opt dns.IPOption) ([]net.IP, uint32, error) { if f.fakeDNSEngine == nil { return nil, 0, errors.New("Unable to locate a fake DNS Engine").AtError() } diff --git a/app/dns/nameserver_local.go b/app/dns/nameserver_local.go index 1b45e5f0..91b003e3 100644 --- a/app/dns/nameserver_local.go +++ b/app/dns/nameserver_local.go @@ -2,7 +2,6 @@ package dns import ( "context" - "strings" "time" "github.com/xtls/xray-core/common/errors" @@ -14,26 +13,15 @@ import ( // LocalNameServer is an wrapper over local DNS feature. type LocalNameServer struct { - client *localdns.Client - queryStrategy QueryStrategy + client *localdns.Client } -const errEmptyResponse = "No address associated with hostname" - // QueryIP implements Server. -func (s *LocalNameServer) QueryIP(ctx context.Context, domain string, _ net.IP, option dns.IPOption, _ bool) (ips []net.IP, ttl uint32, err error) { - option = ResolveIpOptionOverride(s.queryStrategy, option) - if !option.IPv4Enable && !option.IPv6Enable { - return nil, 0, dns.ErrEmptyResponse - } +func (s *LocalNameServer) QueryIP(ctx context.Context, domain string, option dns.IPOption) (ips []net.IP, ttl uint32, err error) { start := time.Now() ips, ttl, err = s.client.LookupIP(domain, option) - if err != nil && strings.HasSuffix(err.Error(), errEmptyResponse) { - err = dns.ErrEmptyResponse - } - if len(ips) > 0 { errors.LogInfo(ctx, "Localhost got answer: ", domain, " -> ", ips) log.Record(&log.DNSLog{Server: s.Name(), Domain: domain, Result: ips, Status: log.DNSQueried, Elapsed: time.Since(start), Error: err}) @@ -48,15 +36,14 @@ func (s *LocalNameServer) Name() string { } // NewLocalNameServer creates localdns server object for directly lookup in system DNS. -func NewLocalNameServer(queryStrategy QueryStrategy) *LocalNameServer { +func NewLocalNameServer() *LocalNameServer { errors.LogInfo(context.Background(), "DNS: created localhost client") return &LocalNameServer{ - queryStrategy: queryStrategy, - client: localdns.New(), + client: localdns.New(), } } // NewLocalDNSClient creates localdns client object for directly lookup in system DNS. -func NewLocalDNSClient() *Client { - return &Client{server: NewLocalNameServer(QueryStrategy_USE_IP)} +func NewLocalDNSClient(ipOption dns.IPOption) *Client { + return &Client{server: NewLocalNameServer(), ipOption: &ipOption} } diff --git a/app/dns/nameserver_local_test.go b/app/dns/nameserver_local_test.go index a32c69e6..71aa08c4 100644 --- a/app/dns/nameserver_local_test.go +++ b/app/dns/nameserver_local_test.go @@ -7,18 +7,17 @@ import ( . "github.com/xtls/xray-core/app/dns" "github.com/xtls/xray-core/common" - "github.com/xtls/xray-core/common/net" "github.com/xtls/xray-core/features/dns" ) func TestLocalNameServer(t *testing.T) { - s := NewLocalNameServer(QueryStrategy_USE_IP) + s := NewLocalNameServer() ctx, cancel := context.WithTimeout(context.Background(), time.Second*2) - ips, _, err := s.QueryIP(ctx, "google.com", net.IP{}, dns.IPOption{ + ips, _, err := s.QueryIP(ctx, "google.com", dns.IPOption{ IPv4Enable: true, IPv6Enable: true, FakeEnable: false, - }, false) + }) cancel() common.Must(err) if len(ips) == 0 { diff --git a/app/dns/nameserver_quic.go b/app/dns/nameserver_quic.go index 6ce5809b..5512edc4 100644 --- a/app/dns/nameserver_quic.go +++ b/app/dns/nameserver_quic.go @@ -4,23 +4,20 @@ import ( "bytes" "context" "encoding/binary" + go_errors "errors" "net/url" "sync" "time" "github.com/quic-go/quic-go" - "github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common/buf" "github.com/xtls/xray-core/common/errors" "github.com/xtls/xray-core/common/log" "github.com/xtls/xray-core/common/net" "github.com/xtls/xray-core/common/protocol/dns" "github.com/xtls/xray-core/common/session" - "github.com/xtls/xray-core/common/signal/pubsub" - "github.com/xtls/xray-core/common/task" dns_feature "github.com/xtls/xray-core/features/dns" "github.com/xtls/xray-core/transport/internet/tls" - "golang.org/x/net/dns/dnsmessage" "golang.org/x/net/http2" ) @@ -33,17 +30,14 @@ const handshakeTimeout = time.Second * 8 // QUICNameServer implemented DNS over QUIC type QUICNameServer struct { sync.RWMutex - ips map[string]*record - pub *pubsub.Service - cleanup *task.Periodic - name string - destination *net.Destination - connection quic.Connection - queryStrategy QueryStrategy + cacheController *CacheController + destination *net.Destination + connection quic.Connection + clientIP net.IP } // NewQUICNameServer creates DNS-over-QUIC client object for local resolving -func NewQUICNameServer(url *url.URL, queryStrategy QueryStrategy) (*QUICNameServer, error) { +func NewQUICNameServer(url *url.URL, disableCache bool, clientIP net.IP) (*QUICNameServer, error) { errors.LogInfo(context.Background(), "DNS: created Local DNS-over-QUIC client for ", url.String()) var err error @@ -57,15 +51,9 @@ func NewQUICNameServer(url *url.URL, queryStrategy QueryStrategy) (*QUICNameServ dest := net.UDPDestination(net.ParseAddress(url.Hostname()), port) s := &QUICNameServer{ - ips: make(map[string]*record), - pub: pubsub.NewService(), - name: url.String(), - destination: &dest, - queryStrategy: queryStrategy, - } - s.cleanup = &task.Periodic{ - Interval: time.Minute, - Execute: s.Cleanup, + cacheController: NewCacheController(url.String(), disableCache), + destination: &dest, + clientIP: clientIP, } return s, nil @@ -73,94 +61,17 @@ func NewQUICNameServer(url *url.URL, queryStrategy QueryStrategy) (*QUICNameServ // Name returns client name func (s *QUICNameServer) Name() string { - return s.name -} - -// Cleanup clears expired items from cache -func (s *QUICNameServer) Cleanup() error { - now := time.Now() - s.Lock() - defer s.Unlock() - - if len(s.ips) == 0 { - return errors.New("nothing to do. stopping...") - } - - for domain, record := range s.ips { - if record.A != nil && record.A.Expire.Before(now) { - record.A = nil - } - if record.AAAA != nil && record.AAAA.Expire.Before(now) { - record.AAAA = nil - } - - if record.A == nil && record.AAAA == nil { - errors.LogDebug(context.Background(), s.name, " cleanup ", domain) - delete(s.ips, domain) - } else { - s.ips[domain] = record - } - } - - if len(s.ips) == 0 { - s.ips = make(map[string]*record) - } - - return nil -} - -func (s *QUICNameServer) updateIP(req *dnsRequest, ipRec *IPRecord) { - elapsed := time.Since(req.start) - - s.Lock() - rec, found := s.ips[req.domain] - if !found { - rec = &record{} - } - updated := false - - switch req.reqType { - case dnsmessage.TypeA: - if isNewer(rec.A, ipRec) { - rec.A = ipRec - updated = true - } - case dnsmessage.TypeAAAA: - addr := make([]net.Address, 0) - for _, ip := range ipRec.IP { - if len(ip.IP()) == net.IPv6len { - addr = append(addr, ip) - } - } - ipRec.IP = addr - if isNewer(rec.AAAA, ipRec) { - rec.AAAA = ipRec - updated = true - } - } - errors.LogInfo(context.Background(), s.name, " got answer: ", req.domain, " ", req.reqType, " -> ", ipRec.IP, " ", elapsed) - - if updated { - s.ips[req.domain] = rec - } - switch req.reqType { - case dnsmessage.TypeA: - s.pub.Publish(req.domain+"4", nil) - case dnsmessage.TypeAAAA: - s.pub.Publish(req.domain+"6", nil) - } - s.Unlock() - common.Must(s.cleanup.Start()) + return s.cacheController.name } func (s *QUICNameServer) newReqID() uint16 { return 0 } -func (s *QUICNameServer) sendQuery(ctx context.Context, domain string, clientIP net.IP, option dns_feature.IPOption) { - errors.LogInfo(ctx, s.name, " querying: ", domain) +func (s *QUICNameServer) sendQuery(ctx context.Context, noResponseErrCh chan<- error, domain string, option dns_feature.IPOption) { + errors.LogInfo(ctx, s.Name(), " querying: ", domain) - reqs := buildReqMsgs(domain, option, s.newReqID, genEDNS0Options(clientIP, 0)) + reqs := buildReqMsgs(domain, option, s.newReqID, genEDNS0Options(s.clientIP, 0)) var deadline time.Time if d, ok := ctx.Deadline(); ok { @@ -192,23 +103,36 @@ func (s *QUICNameServer) sendQuery(ctx context.Context, domain string, clientIP b, err := dns.PackMessage(r.msg) if err != nil { errors.LogErrorInner(ctx, err, "failed to pack dns query") + noResponseErrCh <- err return } dnsReqBuf := buf.New() - binary.Write(dnsReqBuf, binary.BigEndian, uint16(b.Len())) - dnsReqBuf.Write(b.Bytes()) + err = binary.Write(dnsReqBuf, binary.BigEndian, uint16(b.Len())) + if err != nil { + errors.LogErrorInner(ctx, err, "binary write failed") + noResponseErrCh <- err + return + } + _, err = dnsReqBuf.Write(b.Bytes()) + if err != nil { + errors.LogErrorInner(ctx, err, "buffer write failed") + noResponseErrCh <- err + return + } b.Release() conn, err := s.openStream(dnsCtx) if err != nil { errors.LogErrorInner(ctx, err, "failed to open quic connection") + noResponseErrCh <- err return } _, err = conn.Write(dnsReqBuf.Bytes()) if err != nil { errors.LogErrorInner(ctx, err, "failed to send query") + noResponseErrCh <- err return } @@ -219,136 +143,81 @@ func (s *QUICNameServer) sendQuery(ctx context.Context, domain string, clientIP n, err := respBuf.ReadFullFrom(conn, 2) if err != nil && n == 0 { errors.LogErrorInner(ctx, err, "failed to read response length") + noResponseErrCh <- err return } var length int16 err = binary.Read(bytes.NewReader(respBuf.Bytes()), binary.BigEndian, &length) if err != nil { errors.LogErrorInner(ctx, err, "failed to parse response length") + noResponseErrCh <- err return } respBuf.Clear() n, err = respBuf.ReadFullFrom(conn, int32(length)) if err != nil && n == 0 { errors.LogErrorInner(ctx, err, "failed to read response length") + noResponseErrCh <- err return } rec, err := parseResponse(respBuf.Bytes()) if err != nil { errors.LogErrorInner(ctx, err, "failed to handle response") + noResponseErrCh <- err return } - s.updateIP(r, rec) + s.cacheController.updateIP(r, rec) }(req) } } -func (s *QUICNameServer) findIPsForDomain(domain string, option dns_feature.IPOption) ([]net.IP, uint32, error) { - s.RLock() - record, found := s.ips[domain] - s.RUnlock() - - if !found { - return nil, 0, errRecordNotFound - } - - var err4 error - var err6 error - var ips []net.Address - var ip6 []net.Address - var ttl uint32 - - if option.IPv4Enable { - ips, ttl, err4 = record.A.getIPs() - } - - if option.IPv6Enable { - ip6, ttl, err6 = record.AAAA.getIPs() - ips = append(ips, ip6...) - } - - if len(ips) > 0 { - netips, err := toNetIP(ips) - return netips, ttl, err - } - - if err4 != nil { - return nil, 0, err4 - } - - if err6 != nil { - return nil, 0, err6 - } - - if (option.IPv4Enable && record.A != nil) || (option.IPv6Enable && record.AAAA != nil) { - return nil, 0, dns_feature.ErrEmptyResponse - } - - return nil, 0, errRecordNotFound -} - // QueryIP is called from dns.Server->queryIPTimeout -func (s *QUICNameServer) QueryIP(ctx context.Context, domain string, clientIP net.IP, option dns_feature.IPOption, disableCache bool) ([]net.IP, uint32, error) { +func (s *QUICNameServer) QueryIP(ctx context.Context, domain string, option dns_feature.IPOption) ([]net.IP, uint32, error) { fqdn := Fqdn(domain) - option = ResolveIpOptionOverride(s.queryStrategy, option) - if !option.IPv4Enable && !option.IPv6Enable { - return nil, 0, dns_feature.ErrEmptyResponse - } + sub4, sub6 := s.cacheController.registerSubscribers(fqdn, option) + defer closeSubscribers(sub4, sub6) - if disableCache { - errors.LogDebug(ctx, "DNS cache is disabled. Querying IP for ", domain, " at ", s.name) + if s.cacheController.disableCache { + errors.LogDebug(ctx, "DNS cache is disabled. Querying IP for ", domain, " at ", s.Name()) } else { - ips, ttl, err := s.findIPsForDomain(fqdn, option) - if err == nil || err == dns_feature.ErrEmptyResponse || dns_feature.RCodeFromError(err) == 3 { - errors.LogDebugInner(ctx, err, s.name, " cache HIT ", domain, " -> ", ips) - log.Record(&log.DNSLog{Server: s.name, Domain: domain, Result: ips, Status: log.DNSCacheHit, Elapsed: 0, Error: err}) + ips, ttl, err := s.cacheController.findIPsForDomain(fqdn, option) + if !go_errors.Is(err, errRecordNotFound) { + errors.LogDebugInner(ctx, err, s.Name(), " cache HIT ", domain, " -> ", ips) + log.Record(&log.DNSLog{Server: s.Name(), Domain: domain, Result: ips, Status: log.DNSCacheHit, Elapsed: 0, Error: err}) return ips, ttl, err } } - // ipv4 and ipv6 belong to different subscription groups - var sub4, sub6 *pubsub.Subscriber - if option.IPv4Enable { - sub4 = s.pub.Subscribe(fqdn + "4") - defer sub4.Close() - } - if option.IPv6Enable { - sub6 = s.pub.Subscribe(fqdn + "6") - defer sub6.Close() - } - done := make(chan interface{}) - go func() { - if sub4 != nil { - select { - case <-sub4.Wait(): - case <-ctx.Done(): - } - } - if sub6 != nil { - select { - case <-sub6.Wait(): - case <-ctx.Done(): - } - } - close(done) - }() - s.sendQuery(ctx, fqdn, clientIP, option) + noResponseErrCh := make(chan error, 2) + s.sendQuery(ctx, noResponseErrCh, fqdn, option) start := time.Now() - for { - ips, ttl, err := s.findIPsForDomain(fqdn, option) - if err != errRecordNotFound { - log.Record(&log.DNSLog{Server: s.name, Domain: domain, Result: ips, Status: log.DNSQueried, Elapsed: time.Since(start), Error: err}) - return ips, ttl, err - } - + if sub4 != nil { select { case <-ctx.Done(): return nil, 0, ctx.Err() - case <-done: + case err := <-noResponseErrCh: + return nil, 0, err + case <-sub4.Wait(): + sub4.Close() } } + if sub6 != nil { + select { + case <-ctx.Done(): + return nil, 0, ctx.Err() + case err := <-noResponseErrCh: + return nil, 0, err + case <-sub6.Wait(): + sub6.Close() + } + } + + ips, ttl, err := s.cacheController.findIPsForDomain(fqdn, option) + log.Record(&log.DNSLog{Server: s.Name(), Domain: domain, Result: ips, Status: log.DNSQueried, Elapsed: time.Since(start), Error: err}) + return ips, ttl, err + } func isActive(s quic.Connection) bool { diff --git a/app/dns/nameserver_quic_test.go b/app/dns/nameserver_quic_test.go index 56f9c3ee..fd11d2e6 100644 --- a/app/dns/nameserver_quic_test.go +++ b/app/dns/nameserver_quic_test.go @@ -16,24 +16,23 @@ import ( func TestQUICNameServer(t *testing.T) { url, err := url.Parse("quic://dns.adguard-dns.com") common.Must(err) - s, err := NewQUICNameServer(url, QueryStrategy_USE_IP) + s, err := NewQUICNameServer(url, false, net.IP(nil)) common.Must(err) ctx, cancel := context.WithTimeout(context.Background(), time.Second*2) - ips, _, err := s.QueryIP(ctx, "google.com", net.IP(nil), dns.IPOption{ + ips, _, err := s.QueryIP(ctx, "google.com", dns.IPOption{ IPv4Enable: true, IPv6Enable: true, - }, false) + }) cancel() common.Must(err) if len(ips) == 0 { t.Error("expect some ips, but got 0") } - ctx2, cancel := context.WithTimeout(context.Background(), time.Second*5) - ips2, _, err := s.QueryIP(ctx2, "google.com", net.IP(nil), dns.IPOption{ + ips2, _, err := s.QueryIP(ctx2, "google.com", dns.IPOption{ IPv4Enable: true, IPv6Enable: true, - }, true) + }) cancel() common.Must(err) if r := cmp.Diff(ips2, ips); r != "" { @@ -44,13 +43,13 @@ func TestQUICNameServer(t *testing.T) { func TestQUICNameServerWithIPv4Override(t *testing.T) { url, err := url.Parse("quic://dns.adguard-dns.com") common.Must(err) - s, err := NewQUICNameServer(url, QueryStrategy_USE_IP4) + s, err := NewQUICNameServer(url, false, net.IP(nil)) common.Must(err) ctx, cancel := context.WithTimeout(context.Background(), time.Second*2) - ips, _, err := s.QueryIP(ctx, "google.com", net.IP(nil), dns.IPOption{ + ips, _, err := s.QueryIP(ctx, "google.com", dns.IPOption{ IPv4Enable: true, - IPv6Enable: true, - }, false) + IPv6Enable: false, + }) cancel() common.Must(err) if len(ips) == 0 { @@ -67,13 +66,13 @@ func TestQUICNameServerWithIPv4Override(t *testing.T) { func TestQUICNameServerWithIPv6Override(t *testing.T) { url, err := url.Parse("quic://dns.adguard-dns.com") common.Must(err) - s, err := NewQUICNameServer(url, QueryStrategy_USE_IP6) + s, err := NewQUICNameServer(url, false, net.IP(nil)) common.Must(err) ctx, cancel := context.WithTimeout(context.Background(), time.Second*2) - ips, _, err := s.QueryIP(ctx, "google.com", net.IP(nil), dns.IPOption{ - IPv4Enable: true, + ips, _, err := s.QueryIP(ctx, "google.com", dns.IPOption{ + IPv4Enable: false, IPv6Enable: true, - }, false) + }) cancel() common.Must(err) if len(ips) == 0 { diff --git a/app/dns/nameserver_tcp.go b/app/dns/nameserver_tcp.go index 49854312..1937c25c 100644 --- a/app/dns/nameserver_tcp.go +++ b/app/dns/nameserver_tcp.go @@ -4,12 +4,11 @@ import ( "bytes" "context" "encoding/binary" + go_errors "errors" "net/url" - "sync" "sync/atomic" "time" - "github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common/buf" "github.com/xtls/xray-core/common/errors" "github.com/xtls/xray-core/common/log" @@ -17,34 +16,28 @@ import ( "github.com/xtls/xray-core/common/net/cnc" "github.com/xtls/xray-core/common/protocol/dns" "github.com/xtls/xray-core/common/session" - "github.com/xtls/xray-core/common/signal/pubsub" - "github.com/xtls/xray-core/common/task" dns_feature "github.com/xtls/xray-core/features/dns" "github.com/xtls/xray-core/features/routing" "github.com/xtls/xray-core/transport/internet" - "golang.org/x/net/dns/dnsmessage" ) // TCPNameServer implemented DNS over TCP (RFC7766). type TCPNameServer struct { - sync.RWMutex - name string - destination *net.Destination - ips map[string]*record - pub *pubsub.Service - cleanup *task.Periodic - reqID uint32 - dial func(context.Context) (net.Conn, error) - queryStrategy QueryStrategy + cacheController *CacheController + destination *net.Destination + reqID uint32 + dial func(context.Context) (net.Conn, error) + clientIP net.IP } // NewTCPNameServer creates DNS over TCP server object for remote resolving. func NewTCPNameServer( url *url.URL, dispatcher routing.Dispatcher, - queryStrategy QueryStrategy, + disableCache bool, + clientIP net.IP, ) (*TCPNameServer, error) { - s, err := baseTCPNameServer(url, "TCP", queryStrategy) + s, err := baseTCPNameServer(url, "TCP", disableCache, clientIP) if err != nil { return nil, err } @@ -65,8 +58,8 @@ func NewTCPNameServer( } // NewTCPLocalNameServer creates DNS over TCP client object for local resolving -func NewTCPLocalNameServer(url *url.URL, queryStrategy QueryStrategy) (*TCPNameServer, error) { - s, err := baseTCPNameServer(url, "TCPL", queryStrategy) +func NewTCPLocalNameServer(url *url.URL, disableCache bool, clientIP net.IP) (*TCPNameServer, error) { + s, err := baseTCPNameServer(url, "TCPL", disableCache, clientIP) if err != nil { return nil, err } @@ -78,7 +71,7 @@ func NewTCPLocalNameServer(url *url.URL, queryStrategy QueryStrategy) (*TCPNameS return s, nil } -func baseTCPNameServer(url *url.URL, prefix string, queryStrategy QueryStrategy) (*TCPNameServer, error) { +func baseTCPNameServer(url *url.URL, prefix string, disableCache bool, clientIP net.IP) (*TCPNameServer, error) { port := net.Port(53) if url.Port() != "" { var err error @@ -89,15 +82,9 @@ func baseTCPNameServer(url *url.URL, prefix string, queryStrategy QueryStrategy) dest := net.TCPDestination(net.ParseAddress(url.Hostname()), port) s := &TCPNameServer{ - destination: &dest, - ips: make(map[string]*record), - pub: pubsub.NewService(), - name: prefix + "//" + dest.NetAddr(), - queryStrategy: queryStrategy, - } - s.cleanup = &task.Periodic{ - Interval: time.Minute, - Execute: s.Cleanup, + cacheController: NewCacheController(prefix+"//"+dest.NetAddr(), disableCache), + destination: &dest, + clientIP: clientIP, } return s, nil @@ -105,94 +92,17 @@ func baseTCPNameServer(url *url.URL, prefix string, queryStrategy QueryStrategy) // Name implements Server. func (s *TCPNameServer) Name() string { - return s.name -} - -// Cleanup clears expired items from cache -func (s *TCPNameServer) Cleanup() error { - now := time.Now() - s.Lock() - defer s.Unlock() - - if len(s.ips) == 0 { - return errors.New("nothing to do. stopping...") - } - - for domain, record := range s.ips { - if record.A != nil && record.A.Expire.Before(now) { - record.A = nil - } - if record.AAAA != nil && record.AAAA.Expire.Before(now) { - record.AAAA = nil - } - - if record.A == nil && record.AAAA == nil { - errors.LogDebug(context.Background(), s.name, " cleanup ", domain) - delete(s.ips, domain) - } else { - s.ips[domain] = record - } - } - - if len(s.ips) == 0 { - s.ips = make(map[string]*record) - } - - return nil -} - -func (s *TCPNameServer) updateIP(req *dnsRequest, ipRec *IPRecord) { - elapsed := time.Since(req.start) - - s.Lock() - rec, found := s.ips[req.domain] - if !found { - rec = &record{} - } - updated := false - - switch req.reqType { - case dnsmessage.TypeA: - if isNewer(rec.A, ipRec) { - rec.A = ipRec - updated = true - } - case dnsmessage.TypeAAAA: - addr := make([]net.Address, 0) - for _, ip := range ipRec.IP { - if len(ip.IP()) == net.IPv6len { - addr = append(addr, ip) - } - } - ipRec.IP = addr - if isNewer(rec.AAAA, ipRec) { - rec.AAAA = ipRec - updated = true - } - } - errors.LogInfo(context.Background(), s.name, " got answer: ", req.domain, " ", req.reqType, " -> ", ipRec.IP, " ", elapsed) - - if updated { - s.ips[req.domain] = rec - } - switch req.reqType { - case dnsmessage.TypeA: - s.pub.Publish(req.domain+"4", nil) - case dnsmessage.TypeAAAA: - s.pub.Publish(req.domain+"6", nil) - } - s.Unlock() - common.Must(s.cleanup.Start()) + return s.cacheController.name } func (s *TCPNameServer) newReqID() uint16 { return uint16(atomic.AddUint32(&s.reqID, 1)) } -func (s *TCPNameServer) sendQuery(ctx context.Context, domain string, clientIP net.IP, option dns_feature.IPOption) { - errors.LogDebug(ctx, s.name, " querying DNS for: ", domain) +func (s *TCPNameServer) sendQuery(ctx context.Context, noResponseErrCh chan<- error, domain string, option dns_feature.IPOption) { + errors.LogDebug(ctx, s.Name(), " querying DNS for: ", domain) - reqs := buildReqMsgs(domain, option, s.newReqID, genEDNS0Options(clientIP, 0)) + reqs := buildReqMsgs(domain, option, s.newReqID, genEDNS0Options(s.clientIP, 0)) var deadline time.Time if d, ok := ctx.Deadline(); ok { @@ -221,23 +131,36 @@ func (s *TCPNameServer) sendQuery(ctx context.Context, domain string, clientIP n b, err := dns.PackMessage(r.msg) if err != nil { errors.LogErrorInner(ctx, err, "failed to pack dns query") + noResponseErrCh <- err return } conn, err := s.dial(dnsCtx) if err != nil { errors.LogErrorInner(ctx, err, "failed to dial namesever") + noResponseErrCh <- err return } defer conn.Close() dnsReqBuf := buf.New() - binary.Write(dnsReqBuf, binary.BigEndian, uint16(b.Len())) - dnsReqBuf.Write(b.Bytes()) + err = binary.Write(dnsReqBuf, binary.BigEndian, uint16(b.Len())) + if err != nil { + errors.LogErrorInner(ctx, err, "binary write failed") + noResponseErrCh <- err + return + } + _, err = dnsReqBuf.Write(b.Bytes()) + if err != nil { + errors.LogErrorInner(ctx, err, "buffer write failed") + noResponseErrCh <- err + return + } b.Release() _, err = conn.Write(dnsReqBuf.Bytes()) if err != nil { errors.LogErrorInner(ctx, err, "failed to send query") + noResponseErrCh <- err return } dnsReqBuf.Release() @@ -247,131 +170,80 @@ func (s *TCPNameServer) sendQuery(ctx context.Context, domain string, clientIP n n, err := respBuf.ReadFullFrom(conn, 2) if err != nil && n == 0 { errors.LogErrorInner(ctx, err, "failed to read response length") + noResponseErrCh <- err return } var length int16 err = binary.Read(bytes.NewReader(respBuf.Bytes()), binary.BigEndian, &length) if err != nil { errors.LogErrorInner(ctx, err, "failed to parse response length") + noResponseErrCh <- err return } respBuf.Clear() n, err = respBuf.ReadFullFrom(conn, int32(length)) if err != nil && n == 0 { errors.LogErrorInner(ctx, err, "failed to read response length") + noResponseErrCh <- err return } rec, err := parseResponse(respBuf.Bytes()) if err != nil { errors.LogErrorInner(ctx, err, "failed to parse DNS over TCP response") + noResponseErrCh <- err return } - s.updateIP(r, rec) + s.cacheController.updateIP(r, rec) }(req) } } -func (s *TCPNameServer) findIPsForDomain(domain string, option dns_feature.IPOption) ([]net.IP, uint32, error) { - s.RLock() - record, found := s.ips[domain] - s.RUnlock() - - if !found { - return nil, 0, errRecordNotFound - } - - var err4 error - var err6 error - var ips []net.Address - var ip6 []net.Address - var ttl uint32 - - if option.IPv4Enable { - ips, ttl, err4 = record.A.getIPs() - } - - if option.IPv6Enable { - ip6, ttl, err6 = record.AAAA.getIPs() - ips = append(ips, ip6...) - } - - if len(ips) > 0 { - netips, err := toNetIP(ips) - return netips, ttl, err - } - - if err4 != nil { - return nil, 0, err4 - } - - if err6 != nil { - return nil, 0, err6 - } - - return nil, 0, dns_feature.ErrEmptyResponse -} - // QueryIP implements Server. -func (s *TCPNameServer) QueryIP(ctx context.Context, domain string, clientIP net.IP, option dns_feature.IPOption, disableCache bool) ([]net.IP, uint32, error) { +func (s *TCPNameServer) QueryIP(ctx context.Context, domain string, option dns_feature.IPOption) ([]net.IP, uint32, error) { fqdn := Fqdn(domain) - option = ResolveIpOptionOverride(s.queryStrategy, option) - if !option.IPv4Enable && !option.IPv6Enable { - return nil, 0, dns_feature.ErrEmptyResponse - } + sub4, sub6 := s.cacheController.registerSubscribers(fqdn, option) + defer closeSubscribers(sub4, sub6) - if disableCache { - errors.LogDebug(ctx, "DNS cache is disabled. Querying IP for ", domain, " at ", s.name) + if s.cacheController.disableCache { + errors.LogDebug(ctx, "DNS cache is disabled. Querying IP for ", domain, " at ", s.Name()) } else { - ips, ttl, err := s.findIPsForDomain(fqdn, option) - if err == nil || err == dns_feature.ErrEmptyResponse || dns_feature.RCodeFromError(err) == 3 { - errors.LogDebugInner(ctx, err, s.name, " cache HIT ", domain, " -> ", ips) - log.Record(&log.DNSLog{Server: s.name, Domain: domain, Result: ips, Status: log.DNSCacheHit, Elapsed: 0, Error: err}) + ips, ttl, err := s.cacheController.findIPsForDomain(fqdn, option) + if !go_errors.Is(err, errRecordNotFound) { + errors.LogDebugInner(ctx, err, s.Name(), " cache HIT ", domain, " -> ", ips) + log.Record(&log.DNSLog{Server: s.Name(), Domain: domain, Result: ips, Status: log.DNSCacheHit, Elapsed: 0, Error: err}) return ips, ttl, err } } - // ipv4 and ipv6 belong to different subscription groups - var sub4, sub6 *pubsub.Subscriber - if option.IPv4Enable { - sub4 = s.pub.Subscribe(fqdn + "4") - defer sub4.Close() - } - if option.IPv6Enable { - sub6 = s.pub.Subscribe(fqdn + "6") - defer sub6.Close() - } - done := make(chan interface{}) - go func() { - if sub4 != nil { - select { - case <-sub4.Wait(): - case <-ctx.Done(): - } - } - if sub6 != nil { - select { - case <-sub6.Wait(): - case <-ctx.Done(): - } - } - close(done) - }() - s.sendQuery(ctx, fqdn, clientIP, option) + noResponseErrCh := make(chan error, 2) + s.sendQuery(ctx, noResponseErrCh, fqdn, option) start := time.Now() - for { - ips, ttl, err := s.findIPsForDomain(fqdn, option) - if err != errRecordNotFound { - log.Record(&log.DNSLog{Server: s.name, Domain: domain, Result: ips, Status: log.DNSQueried, Elapsed: time.Since(start), Error: err}) - return ips, ttl, err - } - + if sub4 != nil { select { case <-ctx.Done(): return nil, 0, ctx.Err() - case <-done: + case err := <-noResponseErrCh: + return nil, 0, err + case <-sub4.Wait(): + sub4.Close() } } + if sub6 != nil { + select { + case <-ctx.Done(): + return nil, 0, ctx.Err() + case err := <-noResponseErrCh: + return nil, 0, err + case <-sub6.Wait(): + sub6.Close() + } + } + + ips, ttl, err := s.cacheController.findIPsForDomain(fqdn, option) + log.Record(&log.DNSLog{Server: s.Name(), Domain: domain, Result: ips, Status: log.DNSQueried, Elapsed: time.Since(start), Error: err}) + return ips, ttl, err + } diff --git a/app/dns/nameserver_tcp_test.go b/app/dns/nameserver_tcp_test.go index de4ecb89..074896ca 100644 --- a/app/dns/nameserver_tcp_test.go +++ b/app/dns/nameserver_tcp_test.go @@ -16,13 +16,13 @@ import ( func TestTCPLocalNameServer(t *testing.T) { url, err := url.Parse("tcp+local://8.8.8.8") common.Must(err) - s, err := NewTCPLocalNameServer(url, QueryStrategy_USE_IP) + s, err := NewTCPLocalNameServer(url, false, net.IP(nil)) common.Must(err) ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) - ips, _, err := s.QueryIP(ctx, "google.com", net.IP(nil), dns_feature.IPOption{ + ips, _, err := s.QueryIP(ctx, "google.com", dns_feature.IPOption{ IPv4Enable: true, IPv6Enable: true, - }, false) + }) cancel() common.Must(err) if len(ips) == 0 { @@ -33,13 +33,13 @@ func TestTCPLocalNameServer(t *testing.T) { func TestTCPLocalNameServerWithCache(t *testing.T) { url, err := url.Parse("tcp+local://8.8.8.8") common.Must(err) - s, err := NewTCPLocalNameServer(url, QueryStrategy_USE_IP) + s, err := NewTCPLocalNameServer(url, false, net.IP(nil)) common.Must(err) ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) - ips, _, err := s.QueryIP(ctx, "google.com", net.IP(nil), dns_feature.IPOption{ + ips, _, err := s.QueryIP(ctx, "google.com", dns_feature.IPOption{ IPv4Enable: true, IPv6Enable: true, - }, false) + }) cancel() common.Must(err) if len(ips) == 0 { @@ -47,10 +47,10 @@ func TestTCPLocalNameServerWithCache(t *testing.T) { } ctx2, cancel := context.WithTimeout(context.Background(), time.Second*5) - ips2, _, err := s.QueryIP(ctx2, "google.com", net.IP(nil), dns_feature.IPOption{ + ips2, _, err := s.QueryIP(ctx2, "google.com", dns_feature.IPOption{ IPv4Enable: true, IPv6Enable: true, - }, true) + }) cancel() common.Must(err) if r := cmp.Diff(ips2, ips); r != "" { @@ -61,13 +61,13 @@ func TestTCPLocalNameServerWithCache(t *testing.T) { func TestTCPLocalNameServerWithIPv4Override(t *testing.T) { url, err := url.Parse("tcp+local://8.8.8.8") common.Must(err) - s, err := NewTCPLocalNameServer(url, QueryStrategy_USE_IP4) + s, err := NewTCPLocalNameServer(url, false, net.IP(nil)) common.Must(err) ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) - ips, _, err := s.QueryIP(ctx, "google.com", net.IP(nil), dns_feature.IPOption{ + ips, _, err := s.QueryIP(ctx, "google.com", dns_feature.IPOption{ IPv4Enable: true, - IPv6Enable: true, - }, false) + IPv6Enable: false, + }) cancel() common.Must(err) @@ -85,13 +85,13 @@ func TestTCPLocalNameServerWithIPv4Override(t *testing.T) { func TestTCPLocalNameServerWithIPv6Override(t *testing.T) { url, err := url.Parse("tcp+local://8.8.8.8") common.Must(err) - s, err := NewTCPLocalNameServer(url, QueryStrategy_USE_IP6) + s, err := NewTCPLocalNameServer(url, false, net.IP(nil)) common.Must(err) ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) - ips, _, err := s.QueryIP(ctx, "google.com", net.IP(nil), dns_feature.IPOption{ - IPv4Enable: true, + ips, _, err := s.QueryIP(ctx, "google.com", dns_feature.IPOption{ + IPv4Enable: false, IPv6Enable: true, - }, false) + }) cancel() common.Must(err) diff --git a/app/dns/nameserver_udp.go b/app/dns/nameserver_udp.go index 23803efa..3c25e612 100644 --- a/app/dns/nameserver_udp.go +++ b/app/dns/nameserver_udp.go @@ -2,6 +2,7 @@ package dns import ( "context" + go_errors "errors" "strings" "sync" "sync/atomic" @@ -13,7 +14,6 @@ import ( "github.com/xtls/xray-core/common/net" "github.com/xtls/xray-core/common/protocol/dns" udp_proto "github.com/xtls/xray-core/common/protocol/udp" - "github.com/xtls/xray-core/common/signal/pubsub" "github.com/xtls/xray-core/common/task" dns_feature "github.com/xtls/xray-core/features/dns" "github.com/xtls/xray-core/features/routing" @@ -24,15 +24,13 @@ import ( // ClassicNameServer implemented traditional UDP DNS. type ClassicNameServer struct { sync.RWMutex - name string - address *net.Destination - ips map[string]*record - requests map[uint16]*udpDnsRequest - pub *pubsub.Service - udpServer *udp.Dispatcher - cleanup *task.Periodic - reqID uint32 - queryStrategy QueryStrategy + cacheController *CacheController + address *net.Destination + requests map[uint16]*udpDnsRequest + udpServer *udp.Dispatcher + requestsCleanup *task.Periodic + reqID uint32 + clientIP net.IP } type udpDnsRequest struct { @@ -41,23 +39,21 @@ type udpDnsRequest struct { } // NewClassicNameServer creates udp server object for remote resolving. -func NewClassicNameServer(address net.Destination, dispatcher routing.Dispatcher, queryStrategy QueryStrategy) *ClassicNameServer { +func NewClassicNameServer(address net.Destination, dispatcher routing.Dispatcher, disableCache bool, clientIP net.IP) *ClassicNameServer { // default to 53 if unspecific if address.Port == 0 { address.Port = net.Port(53) } s := &ClassicNameServer{ - address: &address, - ips: make(map[string]*record), - requests: make(map[uint16]*udpDnsRequest), - pub: pubsub.NewService(), - name: strings.ToUpper(address.String()), - queryStrategy: queryStrategy, + cacheController: NewCacheController(strings.ToUpper(address.String()), disableCache), + address: &address, + requests: make(map[uint16]*udpDnsRequest), + clientIP: clientIP, } - s.cleanup = &task.Periodic{ + s.requestsCleanup = &task.Periodic{ Interval: time.Minute, - Execute: s.Cleanup, + Execute: s.RequestsCleanup, } s.udpServer = udp.NewDispatcher(dispatcher, s.HandleResponse) errors.LogInfo(context.Background(), "DNS: created UDP client initialized for ", address.NetAddr()) @@ -66,37 +62,17 @@ func NewClassicNameServer(address net.Destination, dispatcher routing.Dispatcher // Name implements Server. func (s *ClassicNameServer) Name() string { - return s.name + return s.cacheController.name } -// Cleanup clears expired items from cache -func (s *ClassicNameServer) Cleanup() error { +// RequestsCleanup clears expired items from cache +func (s *ClassicNameServer) RequestsCleanup() error { now := time.Now() s.Lock() defer s.Unlock() - if len(s.ips) == 0 && len(s.requests) == 0 { - return errors.New(s.name, " nothing to do. stopping...") - } - - for domain, record := range s.ips { - if record.A != nil && record.A.Expire.Before(now) { - record.A = nil - } - if record.AAAA != nil && record.AAAA.Expire.Before(now) { - record.AAAA = nil - } - - if record.A == nil && record.AAAA == nil { - errors.LogDebug(context.Background(), s.name, " cleanup ", domain) - delete(s.ips, domain) - } else { - s.ips[domain] = record - } - } - - if len(s.ips) == 0 { - s.ips = make(map[string]*record) + if len(s.requests) == 0 { + return errors.New(s.Name(), " nothing to do. stopping...") } for id, req := range s.requests { @@ -116,7 +92,7 @@ func (s *ClassicNameServer) Cleanup() error { func (s *ClassicNameServer) HandleResponse(ctx context.Context, packet *udp_proto.Packet) { ipRec, err := parseResponse(packet.Payload.Bytes()) if err != nil { - errors.LogError(ctx, s.name, " fail to parse responded DNS udp") + errors.LogError(ctx, s.Name(), " fail to parse responded DNS udp") return } @@ -129,14 +105,14 @@ func (s *ClassicNameServer) HandleResponse(ctx context.Context, packet *udp_prot } s.Unlock() if !ok { - errors.LogError(ctx, s.name, " cannot find the pending request") + errors.LogError(ctx, s.Name(), " cannot find the pending request") return } // if truncated, retry with EDNS0 option(udp payload size: 1350) if ipRec.RawHeader.Truncated { // if already has EDNS0 option, no need to retry - if ok && len(req.msg.Additionals) == 0 { + if len(req.msg.Additionals) == 0 { // copy necessary meta data from original request // and add EDNS0 option opt := new(dnsmessage.Resource) @@ -154,51 +130,7 @@ func (s *ClassicNameServer) HandleResponse(ctx context.Context, packet *udp_prot } } - var rec record - switch req.reqType { - case dnsmessage.TypeA: - rec.A = ipRec - case dnsmessage.TypeAAAA: - rec.AAAA = ipRec - } - - elapsed := time.Since(req.start) - errors.LogInfo(ctx, s.name, " got answer: ", req.domain, " ", req.reqType, " -> ", ipRec.IP, " ", elapsed) - if len(req.domain) > 0 && (rec.A != nil || rec.AAAA != nil) { - s.updateIP(req.domain, &rec) - } -} - -func (s *ClassicNameServer) updateIP(domain string, newRec *record) { - s.Lock() - - rec, found := s.ips[domain] - if !found { - rec = &record{} - } - - updated := false - if isNewer(rec.A, newRec.A) { - rec.A = newRec.A - updated = true - } - if isNewer(rec.AAAA, newRec.AAAA) { - rec.AAAA = newRec.AAAA - updated = true - } - - if updated { - errors.LogDebug(context.Background(), s.name, " updating IP records for domain:", domain) - s.ips[domain] = rec - } - if newRec.A != nil { - s.pub.Publish(domain+"4", nil) - } - if newRec.AAAA != nil { - s.pub.Publish(domain+"6", nil) - } - s.Unlock() - common.Must(s.cleanup.Start()) + s.cacheController.updateIP(&req.dnsRequest, ipRec) } func (s *ClassicNameServer) newReqID() uint16 { @@ -207,17 +139,17 @@ func (s *ClassicNameServer) newReqID() uint16 { func (s *ClassicNameServer) addPendingRequest(req *udpDnsRequest) { s.Lock() - defer s.Unlock() - id := req.msg.ID req.expire = time.Now().Add(time.Second * 8) s.requests[id] = req + s.Unlock() + common.Must(s.requestsCleanup.Start()) } -func (s *ClassicNameServer) sendQuery(ctx context.Context, domain string, clientIP net.IP, option dns_feature.IPOption) { - errors.LogDebug(ctx, s.name, " querying DNS for: ", domain) +func (s *ClassicNameServer) sendQuery(ctx context.Context, _ chan<- error, domain string, option dns_feature.IPOption) { + errors.LogDebug(ctx, s.Name(), " querying DNS for: ", domain) - reqs := buildReqMsgs(domain, option, s.newReqID, genEDNS0Options(clientIP, 0)) + reqs := buildReqMsgs(domain, option, s.newReqID, genEDNS0Options(s.clientIP, 0)) for _, req := range reqs { udpReq := &udpDnsRequest{ @@ -230,105 +162,50 @@ func (s *ClassicNameServer) sendQuery(ctx context.Context, domain string, client } } -func (s *ClassicNameServer) findIPsForDomain(domain string, option dns_feature.IPOption) ([]net.IP, uint32, error) { - s.RLock() - record, found := s.ips[domain] - s.RUnlock() - - if !found { - return nil, 0, errRecordNotFound - } - - var err4 error - var err6 error - var ips []net.Address - var ip6 []net.Address - var ttl uint32 - - if option.IPv4Enable { - ips, ttl, err4 = record.A.getIPs() - } - - if option.IPv6Enable { - ip6, ttl, err6 = record.AAAA.getIPs() - ips = append(ips, ip6...) - } - - if len(ips) > 0 { - netips, err := toNetIP(ips) - return netips, ttl, err - } - - if err4 != nil { - return nil, 0, err4 - } - - if err6 != nil { - return nil, 0, err6 - } - - return nil, 0, dns_feature.ErrEmptyResponse -} - // QueryIP implements Server. -func (s *ClassicNameServer) QueryIP(ctx context.Context, domain string, clientIP net.IP, option dns_feature.IPOption, disableCache bool) ([]net.IP, uint32, error) { +func (s *ClassicNameServer) QueryIP(ctx context.Context, domain string, option dns_feature.IPOption) ([]net.IP, uint32, error) { fqdn := Fqdn(domain) - option = ResolveIpOptionOverride(s.queryStrategy, option) - if !option.IPv4Enable && !option.IPv6Enable { - return nil, 0, dns_feature.ErrEmptyResponse - } + sub4, sub6 := s.cacheController.registerSubscribers(fqdn, option) + defer closeSubscribers(sub4, sub6) - if disableCache { - errors.LogDebug(ctx, "DNS cache is disabled. Querying IP for ", domain, " at ", s.name) + if s.cacheController.disableCache { + errors.LogDebug(ctx, "DNS cache is disabled. Querying IP for ", domain, " at ", s.Name()) } else { - ips, ttl, err := s.findIPsForDomain(fqdn, option) - if err == nil || err == dns_feature.ErrEmptyResponse || dns_feature.RCodeFromError(err) == 3 { - errors.LogDebugInner(ctx, err, s.name, " cache HIT ", domain, " -> ", ips) - log.Record(&log.DNSLog{Server: s.name, Domain: domain, Result: ips, Status: log.DNSCacheHit, Elapsed: 0, Error: err}) + ips, ttl, err := s.cacheController.findIPsForDomain(fqdn, option) + if !go_errors.Is(err, errRecordNotFound) { + errors.LogDebugInner(ctx, err, s.Name(), " cache HIT ", domain, " -> ", ips) + log.Record(&log.DNSLog{Server: s.Name(), Domain: domain, Result: ips, Status: log.DNSCacheHit, Elapsed: 0, Error: err}) return ips, ttl, err } } - // ipv4 and ipv6 belong to different subscription groups - var sub4, sub6 *pubsub.Subscriber - if option.IPv4Enable { - sub4 = s.pub.Subscribe(fqdn + "4") - defer sub4.Close() - } - if option.IPv6Enable { - sub6 = s.pub.Subscribe(fqdn + "6") - defer sub6.Close() - } - done := make(chan interface{}) - go func() { - if sub4 != nil { - select { - case <-sub4.Wait(): - case <-ctx.Done(): - } - } - if sub6 != nil { - select { - case <-sub6.Wait(): - case <-ctx.Done(): - } - } - close(done) - }() - s.sendQuery(ctx, fqdn, clientIP, option) + noResponseErrCh := make(chan error, 2) + s.sendQuery(ctx, noResponseErrCh, fqdn, option) start := time.Now() - for { - ips, ttl, err := s.findIPsForDomain(fqdn, option) - if err != errRecordNotFound { - log.Record(&log.DNSLog{Server: s.name, Domain: domain, Result: ips, Status: log.DNSQueried, Elapsed: time.Since(start), Error: err}) - return ips, ttl, err - } - + if sub4 != nil { select { case <-ctx.Done(): return nil, 0, ctx.Err() - case <-done: + case err := <-noResponseErrCh: + return nil, 0, err + case <-sub4.Wait(): + sub4.Close() } } + if sub6 != nil { + select { + case <-ctx.Done(): + return nil, 0, ctx.Err() + case err := <-noResponseErrCh: + return nil, 0, err + case <-sub6.Wait(): + sub6.Close() + } + } + + ips, ttl, err := s.cacheController.findIPsForDomain(fqdn, option) + log.Record(&log.DNSLog{Server: s.Name(), Domain: domain, Result: ips, Status: log.DNSQueried, Elapsed: time.Since(start), Error: err}) + return ips, ttl, err + } diff --git a/app/router/condition.go b/app/router/condition.go index 89da1adf..35e2424d 100644 --- a/app/router/condition.go +++ b/app/router/condition.go @@ -119,7 +119,7 @@ type MultiGeoIPMatcher struct { func NewMultiGeoIPMatcher(geoips []*GeoIP, onSource bool) (*MultiGeoIPMatcher, error) { var matchers []*GeoIPMatcher for _, geoip := range geoips { - matcher, err := globalGeoIPContainer.Add(geoip) + matcher, err := GlobalGeoIPContainer.Add(geoip) if err != nil { return nil, err } diff --git a/app/router/condition_geoip.go b/app/router/condition_geoip.go index 09c81fa8..0140cdfb 100644 --- a/app/router/condition_geoip.go +++ b/app/router/condition_geoip.go @@ -115,4 +115,4 @@ func (c *GeoIPMatcherContainer) Add(geoip *GeoIP) (*GeoIPMatcher, error) { return m, nil } -var globalGeoIPContainer GeoIPMatcherContainer +var GlobalGeoIPContainer GeoIPMatcherContainer diff --git a/common/errors/multi_error.go b/common/errors/multi_error.go index 8066ac9e..cdfec9cd 100644 --- a/common/errors/multi_error.go +++ b/common/errors/multi_error.go @@ -1,6 +1,7 @@ package errors import ( + "errors" "strings" ) @@ -36,12 +37,12 @@ func AllEqual(expected error, actual error) bool { return false } for _, err := range errs { - if err != expected { + if !errors.Is(err, expected) { return false } } return true default: - return errs == expected + return errors.Is(errs, expected) } } diff --git a/features/dns/client.go b/features/dns/client.go index 9e7be91d..7dc576fb 100644 --- a/features/dns/client.go +++ b/features/dns/client.go @@ -38,6 +38,8 @@ func ClientType() interface{} { // ErrEmptyResponse indicates that DNS query succeeded but no answer was returned. var ErrEmptyResponse = errors.New("empty response") +const DefaultTTL = 300 + type RCodeError uint16 func (e RCodeError) Error() string { diff --git a/features/dns/localdns/client.go b/features/dns/localdns/client.go index a94c9474..48e740ee 100644 --- a/features/dns/localdns/client.go +++ b/features/dns/localdns/client.go @@ -30,29 +30,31 @@ func (*Client) LookupIP(host string, option dns.IPOption) ([]net.IP, uint32, err ipv6 := make([]net.IP, 0, len(ips)) for _, ip := range ips { parsed := net.IPAddress(ip) - if parsed != nil { - parsedIPs = append(parsedIPs, parsed.IP()) + if parsed == nil { + continue } - if len(ip) == net.IPv4len { - ipv4 = append(ipv4, ip) - } - if len(ip) == net.IPv6len { - ipv6 = append(ipv6, ip) + parsedIP := parsed.IP() + parsedIPs = append(parsedIPs, parsedIP) + + if len(parsedIP) == net.IPv4len { + ipv4 = append(ipv4, parsedIP) + } else { + ipv6 = append(ipv6, parsedIP) } } - // Local DNS ttl is 600 + switch { case option.IPv4Enable && option.IPv6Enable: if len(parsedIPs) > 0 { - return parsedIPs, 600, nil + return parsedIPs, dns.DefaultTTL, nil } case option.IPv4Enable: if len(ipv4) > 0 { - return ipv4, 600, nil + return ipv4, dns.DefaultTTL, nil } case option.IPv6Enable: if len(ipv6) > 0 { - return ipv6, 600, nil + return ipv6, dns.DefaultTTL, nil } } return nil, 0, dns.ErrEmptyResponse diff --git a/proxy/dns/dns.go b/proxy/dns/dns.go index 2344d412..3308faef 100644 --- a/proxy/dns/dns.go +++ b/proxy/dns/dns.go @@ -2,6 +2,7 @@ package dns import ( "context" + go_errors "errors" "io" "sync" "time" @@ -255,7 +256,7 @@ func (h *Handler) handleIPQuery(id uint16, qType dnsmessage.Type, domain string, } rcode := dns.RCodeFromError(err) - if rcode == 0 && len(ips) == 0 && !errors.AllEqual(dns.ErrEmptyResponse, errors.Cause(err)) { + if rcode == 0 && len(ips) == 0 && !go_errors.Is(err, dns.ErrEmptyResponse) { errors.LogInfoInner(context.Background(), err, "ip query") return }