mirror of
https://github.com/XTLS/Xray-core.git
synced 2025-04-30 01:08:33 +00:00
DNS: Fix some bugs; Refactors; Optimizations (#4659)
This commit is contained in:
parent
1c4e246788
commit
aa4134f4a6
22 changed files with 658 additions and 978 deletions
188
app/dns/cache_controller.go
Normal file
188
app/dns/cache_controller.go
Normal file
|
@ -0,0 +1,188 @@
|
|||
package dns
|
||||
|
||||
import (
|
||||
"context"
|
||||
go_errors "errors"
|
||||
"github.com/xtls/xray-core/common"
|
||||
"github.com/xtls/xray-core/common/errors"
|
||||
"github.com/xtls/xray-core/common/net"
|
||||
"github.com/xtls/xray-core/common/signal/pubsub"
|
||||
"github.com/xtls/xray-core/common/task"
|
||||
dns_feature "github.com/xtls/xray-core/features/dns"
|
||||
"golang.org/x/net/dns/dnsmessage"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type CacheController struct {
|
||||
sync.RWMutex
|
||||
ips map[string]*record
|
||||
pub *pubsub.Service
|
||||
cacheCleanup *task.Periodic
|
||||
name string
|
||||
disableCache bool
|
||||
}
|
||||
|
||||
func NewCacheController(name string, disableCache bool) *CacheController {
|
||||
c := &CacheController{
|
||||
name: name,
|
||||
disableCache: disableCache,
|
||||
ips: make(map[string]*record),
|
||||
pub: pubsub.NewService(),
|
||||
}
|
||||
|
||||
c.cacheCleanup = &task.Periodic{
|
||||
Interval: time.Minute,
|
||||
Execute: c.CacheCleanup,
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
// CacheCleanup clears expired items from cache
|
||||
func (c *CacheController) CacheCleanup() error {
|
||||
now := time.Now()
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
|
||||
if len(c.ips) == 0 {
|
||||
return errors.New("nothing to do. stopping...")
|
||||
}
|
||||
|
||||
for domain, record := range c.ips {
|
||||
if record.A != nil && record.A.Expire.Before(now) {
|
||||
record.A = nil
|
||||
}
|
||||
if record.AAAA != nil && record.AAAA.Expire.Before(now) {
|
||||
record.AAAA = nil
|
||||
}
|
||||
|
||||
if record.A == nil && record.AAAA == nil {
|
||||
errors.LogDebug(context.Background(), c.name, "cache cleanup ", domain)
|
||||
delete(c.ips, domain)
|
||||
} else {
|
||||
c.ips[domain] = record
|
||||
}
|
||||
}
|
||||
|
||||
if len(c.ips) == 0 {
|
||||
c.ips = make(map[string]*record)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *CacheController) updateIP(req *dnsRequest, ipRec *IPRecord) {
|
||||
elapsed := time.Since(req.start)
|
||||
|
||||
c.Lock()
|
||||
rec, found := c.ips[req.domain]
|
||||
if !found {
|
||||
rec = &record{}
|
||||
}
|
||||
|
||||
switch req.reqType {
|
||||
case dnsmessage.TypeA:
|
||||
rec.A = ipRec
|
||||
case dnsmessage.TypeAAAA:
|
||||
rec.AAAA = ipRec
|
||||
}
|
||||
|
||||
errors.LogInfo(context.Background(), c.name, " got answer: ", req.domain, " ", req.reqType, " -> ", ipRec.IP, " ", elapsed)
|
||||
c.ips[req.domain] = rec
|
||||
|
||||
switch req.reqType {
|
||||
case dnsmessage.TypeA:
|
||||
c.pub.Publish(req.domain+"4", nil)
|
||||
if !c.disableCache {
|
||||
_, _, err := rec.AAAA.getIPs()
|
||||
if !go_errors.Is(err, errRecordNotFound) {
|
||||
c.pub.Publish(req.domain+"6", nil)
|
||||
}
|
||||
}
|
||||
case dnsmessage.TypeAAAA:
|
||||
c.pub.Publish(req.domain+"6", nil)
|
||||
if !c.disableCache {
|
||||
_, _, err := rec.A.getIPs()
|
||||
if !go_errors.Is(err, errRecordNotFound) {
|
||||
c.pub.Publish(req.domain+"4", nil)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
c.Unlock()
|
||||
common.Must(c.cacheCleanup.Start())
|
||||
}
|
||||
|
||||
func (c *CacheController) findIPsForDomain(domain string, option dns_feature.IPOption) ([]net.IP, uint32, error) {
|
||||
c.RLock()
|
||||
record, found := c.ips[domain]
|
||||
c.RUnlock()
|
||||
|
||||
if !found {
|
||||
return nil, 0, errRecordNotFound
|
||||
}
|
||||
|
||||
var errs []error
|
||||
var allIPs []net.IP
|
||||
var rTTL uint32 = dns_feature.DefaultTTL
|
||||
|
||||
mergeReq := option.IPv4Enable && option.IPv6Enable
|
||||
|
||||
if option.IPv4Enable {
|
||||
ips, ttl, err := record.A.getIPs()
|
||||
if !mergeReq || go_errors.Is(err, errRecordNotFound) {
|
||||
return ips, ttl, err
|
||||
}
|
||||
if ttl < rTTL {
|
||||
rTTL = ttl
|
||||
}
|
||||
if len(ips) > 0 {
|
||||
allIPs = append(allIPs, ips...)
|
||||
} else {
|
||||
errs = append(errs, err)
|
||||
}
|
||||
}
|
||||
|
||||
if option.IPv6Enable {
|
||||
ips, ttl, err := record.AAAA.getIPs()
|
||||
if !mergeReq || go_errors.Is(err, errRecordNotFound) {
|
||||
return ips, ttl, err
|
||||
}
|
||||
if ttl < rTTL {
|
||||
rTTL = ttl
|
||||
}
|
||||
if len(ips) > 0 {
|
||||
allIPs = append(allIPs, ips...)
|
||||
} else {
|
||||
errs = append(errs, err)
|
||||
}
|
||||
}
|
||||
|
||||
if len(allIPs) > 0 {
|
||||
return allIPs, rTTL, nil
|
||||
}
|
||||
if go_errors.Is(errs[0], errs[1]) {
|
||||
return nil, rTTL, errs[0]
|
||||
}
|
||||
return nil, rTTL, errors.Combine(errs...)
|
||||
}
|
||||
|
||||
func (c *CacheController) registerSubscribers(domain string, option dns_feature.IPOption) (sub4 *pubsub.Subscriber, sub6 *pubsub.Subscriber) {
|
||||
// ipv4 and ipv6 belong to different subscription groups
|
||||
if option.IPv4Enable {
|
||||
sub4 = c.pub.Subscribe(domain + "4")
|
||||
}
|
||||
if option.IPv6Enable {
|
||||
sub6 = c.pub.Subscribe(domain + "6")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func closeSubscribers(sub4 *pubsub.Subscriber, sub6 *pubsub.Subscriber) {
|
||||
if sub4 != nil {
|
||||
sub4.Close()
|
||||
}
|
||||
if sub6 != nil {
|
||||
sub6.Close()
|
||||
}
|
||||
}
|
121
app/dns/dns.go
121
app/dns/dns.go
|
@ -3,12 +3,12 @@ package dns
|
|||
|
||||
import (
|
||||
"context"
|
||||
go_errors "errors"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/xtls/xray-core/app/router"
|
||||
"github.com/xtls/xray-core/common"
|
||||
"github.com/xtls/xray-core/common/errors"
|
||||
"github.com/xtls/xray-core/common/net"
|
||||
|
@ -20,8 +20,6 @@ import (
|
|||
// DNS is a DNS rely server.
|
||||
type DNS struct {
|
||||
sync.Mutex
|
||||
tag string
|
||||
disableCache bool
|
||||
disableFallback bool
|
||||
disableFallbackIfMatch bool
|
||||
ipOption *dns.IPOption
|
||||
|
@ -40,13 +38,6 @@ type DomainMatcherInfo struct {
|
|||
|
||||
// New creates a new DNS server with given configuration.
|
||||
func New(ctx context.Context, config *Config) (*DNS, error) {
|
||||
var tag string
|
||||
if len(config.Tag) > 0 {
|
||||
tag = config.Tag
|
||||
} else {
|
||||
tag = generateRandomTag()
|
||||
}
|
||||
|
||||
var clientIP net.IP
|
||||
switch len(config.ClientIp) {
|
||||
case 0, net.IPv4len, net.IPv6len:
|
||||
|
@ -55,26 +46,28 @@ func New(ctx context.Context, config *Config) (*DNS, error) {
|
|||
return nil, errors.New("unexpected client IP length ", len(config.ClientIp))
|
||||
}
|
||||
|
||||
var ipOption *dns.IPOption
|
||||
var ipOption dns.IPOption
|
||||
switch config.QueryStrategy {
|
||||
case QueryStrategy_USE_IP:
|
||||
ipOption = &dns.IPOption{
|
||||
ipOption = dns.IPOption{
|
||||
IPv4Enable: true,
|
||||
IPv6Enable: true,
|
||||
FakeEnable: false,
|
||||
}
|
||||
case QueryStrategy_USE_IP4:
|
||||
ipOption = &dns.IPOption{
|
||||
ipOption = dns.IPOption{
|
||||
IPv4Enable: true,
|
||||
IPv6Enable: false,
|
||||
FakeEnable: false,
|
||||
}
|
||||
case QueryStrategy_USE_IP6:
|
||||
ipOption = &dns.IPOption{
|
||||
ipOption = dns.IPOption{
|
||||
IPv4Enable: false,
|
||||
IPv6Enable: true,
|
||||
FakeEnable: false,
|
||||
}
|
||||
default:
|
||||
return nil, errors.New("unexpected query strategy ", config.QueryStrategy)
|
||||
}
|
||||
|
||||
hosts, err := NewStaticHosts(config.StaticHosts)
|
||||
|
@ -82,8 +75,14 @@ func New(ctx context.Context, config *Config) (*DNS, error) {
|
|||
return nil, errors.New("failed to create hosts").Base(err)
|
||||
}
|
||||
|
||||
clients := []*Client{}
|
||||
var clients []*Client
|
||||
domainRuleCount := 0
|
||||
|
||||
var defaultTag = config.Tag
|
||||
if len(config.Tag) == 0 {
|
||||
defaultTag = generateRandomTag()
|
||||
}
|
||||
|
||||
for _, ns := range config.NameServer {
|
||||
domainRuleCount += len(ns.PrioritizedDomain)
|
||||
}
|
||||
|
@ -91,7 +90,6 @@ func New(ctx context.Context, config *Config) (*DNS, error) {
|
|||
// MatcherInfos is ensured to cover the maximum index domainMatcher could return, where matcher's index starts from 1
|
||||
matcherInfos := make([]*DomainMatcherInfo, domainRuleCount+1)
|
||||
domainMatcher := &strmatcher.MatcherGroup{}
|
||||
geoipContainer := router.GeoIPMatcherContainer{}
|
||||
|
||||
for _, ns := range config.NameServer {
|
||||
clientIdx := len(clients)
|
||||
|
@ -109,7 +107,18 @@ func New(ctx context.Context, config *Config) (*DNS, error) {
|
|||
case net.IPv4len, net.IPv6len:
|
||||
myClientIP = net.IP(ns.ClientIp)
|
||||
}
|
||||
client, err := NewClient(ctx, ns, myClientIP, geoipContainer, &matcherInfos, updateDomain)
|
||||
|
||||
disableCache := config.DisableCache
|
||||
|
||||
var tag = defaultTag
|
||||
if len(ns.Tag) > 0 {
|
||||
tag = ns.Tag
|
||||
}
|
||||
clientIPOption := ResolveIpOptionOverride(ns.QueryStrategy, ipOption)
|
||||
if !clientIPOption.IPv4Enable && !clientIPOption.IPv6Enable {
|
||||
return nil, errors.New("no QueryStrategy available for ", ns.Address)
|
||||
}
|
||||
client, err := NewClient(ctx, ns, myClientIP, disableCache, tag, clientIPOption, &matcherInfos, updateDomain)
|
||||
if err != nil {
|
||||
return nil, errors.New("failed to create client").Base(err)
|
||||
}
|
||||
|
@ -118,18 +127,16 @@ func New(ctx context.Context, config *Config) (*DNS, error) {
|
|||
|
||||
// If there is no DNS client in config, add a `localhost` DNS client
|
||||
if len(clients) == 0 {
|
||||
clients = append(clients, NewLocalDNSClient())
|
||||
clients = append(clients, NewLocalDNSClient(ipOption))
|
||||
}
|
||||
|
||||
return &DNS{
|
||||
tag: tag,
|
||||
hosts: hosts,
|
||||
ipOption: ipOption,
|
||||
ipOption: &ipOption,
|
||||
clients: clients,
|
||||
ctx: ctx,
|
||||
domainMatcher: domainMatcher,
|
||||
matcherInfos: matcherInfos,
|
||||
disableCache: config.DisableCache,
|
||||
disableFallback: config.DisableFallback,
|
||||
disableFallbackIfMatch: config.DisableFallbackIfMatch,
|
||||
}, nil
|
||||
|
@ -153,11 +160,21 @@ func (s *DNS) Close() error {
|
|||
// IsOwnLink implements proxy.dns.ownLinkVerifier
|
||||
func (s *DNS) IsOwnLink(ctx context.Context) bool {
|
||||
inbound := session.InboundFromContext(ctx)
|
||||
return inbound != nil && inbound.Tag == s.tag
|
||||
if inbound == nil {
|
||||
return false
|
||||
}
|
||||
for _, client := range s.clients {
|
||||
if client.tag == inbound.Tag {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// LookupIP implements dns.Client.
|
||||
func (s *DNS) LookupIP(domain string, option dns.IPOption) ([]net.IP, uint32, error) {
|
||||
// Normalize the FQDN form query
|
||||
domain = strings.TrimSuffix(domain, ".")
|
||||
if domain == "" {
|
||||
return nil, 0, errors.New("empty domain name")
|
||||
}
|
||||
|
@ -169,9 +186,6 @@ func (s *DNS) LookupIP(domain string, option dns.IPOption) ([]net.IP, uint32, er
|
|||
return nil, 0, dns.ErrEmptyResponse
|
||||
}
|
||||
|
||||
// Normalize the FQDN form query
|
||||
domain = strings.TrimSuffix(domain, ".")
|
||||
|
||||
// Static host lookup
|
||||
switch addrs := s.hosts.Lookup(domain, option); {
|
||||
case addrs == nil: // Domain not recorded in static host
|
||||
|
@ -184,32 +198,49 @@ func (s *DNS) LookupIP(domain string, option dns.IPOption) ([]net.IP, uint32, er
|
|||
default: // Successfully found ip records in static host
|
||||
errors.LogInfo(s.ctx, "returning ", len(addrs), " IP(s) for domain ", domain, " -> ", addrs)
|
||||
ips, err := toNetIP(addrs)
|
||||
return ips, 10, err // Hosts ttl is 10
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
return ips, 10, nil // Hosts ttl is 10
|
||||
}
|
||||
|
||||
// Name servers lookup
|
||||
errs := []error{}
|
||||
ctx := session.ContextWithInbound(s.ctx, &session.Inbound{Tag: s.tag})
|
||||
var errs []error
|
||||
for _, client := range s.sortClients(domain) {
|
||||
if !option.FakeEnable && strings.EqualFold(client.Name(), "FakeDNS") {
|
||||
errors.LogDebug(s.ctx, "skip DNS resolution for domain ", domain, " at server ", client.Name())
|
||||
continue
|
||||
}
|
||||
ips, ttl, err := client.QueryIP(ctx, domain, option, s.disableCache)
|
||||
|
||||
ips, ttl, err := client.QueryIP(s.ctx, domain, option)
|
||||
|
||||
if len(ips) > 0 {
|
||||
if ttl == 0 {
|
||||
ttl = 1
|
||||
}
|
||||
return ips, ttl, nil
|
||||
}
|
||||
if err != nil {
|
||||
errors.LogInfoInner(s.ctx, err, "failed to lookup ip for domain ", domain, " at server ", client.Name())
|
||||
errs = append(errs, err)
|
||||
}
|
||||
// 5 for RcodeRefused in miekg/dns, hardcode to reduce binary size
|
||||
if err != context.Canceled && err != context.DeadlineExceeded && err != errExpectedIPNonMatch && err != dns.ErrEmptyResponse && dns.RCodeFromError(err) != 5 {
|
||||
return nil, 0, err
|
||||
|
||||
errors.LogInfoInner(s.ctx, err, "failed to lookup ip for domain ", domain, " at server ", client.Name())
|
||||
if err == nil {
|
||||
err = dns.ErrEmptyResponse
|
||||
}
|
||||
errs = append(errs, err)
|
||||
|
||||
}
|
||||
|
||||
return nil, 0, errors.New("returning nil for domain ", domain).Base(errors.Combine(errs...))
|
||||
if len(errs) > 0 {
|
||||
allErrs := errors.Combine(errs...)
|
||||
err0 := errs[0]
|
||||
if errors.AllEqual(err0, allErrs) {
|
||||
if go_errors.Is(err0, dns.ErrEmptyResponse) {
|
||||
return nil, 0, dns.ErrEmptyResponse
|
||||
}
|
||||
return nil, 0, errors.New("returning nil for domain ", domain).Base(err0)
|
||||
}
|
||||
return nil, 0, errors.New("returning nil for domain ", domain).Base(allErrs)
|
||||
}
|
||||
return nil, 0, dns.ErrEmptyResponse
|
||||
}
|
||||
|
||||
// LookupHosts implements dns.HostsLookup.
|
||||
|
@ -228,22 +259,6 @@ func (s *DNS) LookupHosts(domain string) *net.Address {
|
|||
return nil
|
||||
}
|
||||
|
||||
// GetIPOption implements ClientWithIPOption.
|
||||
func (s *DNS) GetIPOption() *dns.IPOption {
|
||||
return s.ipOption
|
||||
}
|
||||
|
||||
// SetQueryOption implements ClientWithIPOption.
|
||||
func (s *DNS) SetQueryOption(isIPv4Enable, isIPv6Enable bool) {
|
||||
s.ipOption.IPv4Enable = isIPv4Enable
|
||||
s.ipOption.IPv6Enable = isIPv6Enable
|
||||
}
|
||||
|
||||
// SetFakeDNSOption implements ClientWithIPOption.
|
||||
func (s *DNS) SetFakeDNSOption(isFakeEnable bool) {
|
||||
s.ipOption.FakeEnable = isFakeEnable
|
||||
}
|
||||
|
||||
func (s *DNS) sortClients(domain string) []*Client {
|
||||
clients := make([]*Client, 0, len(s.clients))
|
||||
clientUsed := make([]bool, len(s.clients))
|
||||
|
|
|
@ -76,6 +76,9 @@ func (*staticHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|||
case q.Name == "notexist.google.com." && q.Qtype == dns.TypeAAAA:
|
||||
ans.MsgHdr.Rcode = dns.RcodeNameError
|
||||
|
||||
case q.Name == "notexist.google.com." && q.Qtype == dns.TypeA:
|
||||
ans.MsgHdr.Rcode = dns.RcodeNameError
|
||||
|
||||
case q.Name == "hostname." && q.Qtype == dns.TypeA:
|
||||
rr, _ := dns.NewRR("hostname. IN A 127.0.0.1")
|
||||
ans.Answer = append(ans.Answer, rr)
|
||||
|
@ -117,7 +120,6 @@ func TestUDPServerSubnet(t *testing.T) {
|
|||
Handler: &staticHandler{},
|
||||
UDPSize: 1200,
|
||||
}
|
||||
|
||||
go dnsServer.ListenAndServe()
|
||||
time.Sleep(time.Second)
|
||||
|
||||
|
|
|
@ -32,31 +32,30 @@ type record struct {
|
|||
// IPRecord is a cacheable item for a resolved domain
|
||||
type IPRecord struct {
|
||||
ReqID uint16
|
||||
IP []net.Address
|
||||
IP []net.IP
|
||||
Expire time.Time
|
||||
RCode dnsmessage.RCode
|
||||
RawHeader *dnsmessage.Header
|
||||
}
|
||||
|
||||
func (r *IPRecord) getIPs() ([]net.Address, uint32, error) {
|
||||
if r == nil || r.Expire.Before(time.Now()) {
|
||||
func (r *IPRecord) getIPs() ([]net.IP, uint32, error) {
|
||||
if r == nil {
|
||||
return nil, 0, errRecordNotFound
|
||||
}
|
||||
if r.RCode != dnsmessage.RCodeSuccess {
|
||||
return nil, 0, dns_feature.RCodeError(r.RCode)
|
||||
untilExpire := time.Until(r.Expire)
|
||||
if untilExpire <= 0 {
|
||||
return nil, 0, errRecordNotFound
|
||||
}
|
||||
ttl := uint32(time.Until(r.Expire) / time.Second)
|
||||
return r.IP, ttl, nil
|
||||
}
|
||||
|
||||
func isNewer(baseRec *IPRecord, newRec *IPRecord) bool {
|
||||
if newRec == nil {
|
||||
return false
|
||||
ttl := uint32(untilExpire/time.Second) + uint32(1)
|
||||
if r.RCode != dnsmessage.RCodeSuccess {
|
||||
return nil, ttl, dns_feature.RCodeError(r.RCode)
|
||||
}
|
||||
if baseRec == nil {
|
||||
return true
|
||||
if len(r.IP) == 0 {
|
||||
return nil, ttl, dns_feature.ErrEmptyResponse
|
||||
}
|
||||
return baseRec.Expire.Before(newRec.Expire)
|
||||
|
||||
return r.IP, ttl, nil
|
||||
}
|
||||
|
||||
var errRecordNotFound = errors.New("record not found")
|
||||
|
@ -193,7 +192,7 @@ func parseResponse(payload []byte) (*IPRecord, error) {
|
|||
ipRecord := &IPRecord{
|
||||
ReqID: h.ID,
|
||||
RCode: h.RCode,
|
||||
Expire: now.Add(time.Second * 600),
|
||||
Expire: now.Add(time.Second * dns_feature.DefaultTTL),
|
||||
RawHeader: &h,
|
||||
}
|
||||
|
||||
|
@ -209,7 +208,7 @@ L:
|
|||
|
||||
ttl := ah.TTL
|
||||
if ttl == 0 {
|
||||
ttl = 600
|
||||
ttl = 1
|
||||
}
|
||||
expire := now.Add(time.Duration(ttl) * time.Second)
|
||||
if ipRecord.Expire.After(expire) {
|
||||
|
@ -223,14 +222,17 @@ L:
|
|||
errors.LogInfoInner(context.Background(), err, "failed to parse A record for domain: ", ah.Name)
|
||||
break L
|
||||
}
|
||||
ipRecord.IP = append(ipRecord.IP, net.IPAddress(ans.A[:]))
|
||||
ipRecord.IP = append(ipRecord.IP, net.IPAddress(ans.A[:]).IP())
|
||||
case dnsmessage.TypeAAAA:
|
||||
ans, err := parser.AAAAResource()
|
||||
if err != nil {
|
||||
errors.LogInfoInner(context.Background(), err, "failed to parse AAAA record for domain: ", ah.Name)
|
||||
break L
|
||||
}
|
||||
ipRecord.IP = append(ipRecord.IP, net.IPAddress(ans.AAAA[:]))
|
||||
newIP := net.IPAddress(ans.AAAA[:]).IP()
|
||||
if len(newIP) == net.IPv6len {
|
||||
ipRecord.IP = append(ipRecord.IP, newIP)
|
||||
}
|
||||
default:
|
||||
if err := parser.SkipAnswer(); err != nil {
|
||||
errors.LogInfoInner(context.Background(), err, "failed to skip answer")
|
||||
|
|
|
@ -51,7 +51,7 @@ func Test_parseResponse(t *testing.T) {
|
|||
}{
|
||||
{
|
||||
"empty",
|
||||
&IPRecord{0, []net.Address(nil), time.Time{}, dnsmessage.RCodeSuccess, nil},
|
||||
&IPRecord{0, []net.IP(nil), time.Time{}, dnsmessage.RCodeSuccess, nil},
|
||||
false,
|
||||
},
|
||||
{
|
||||
|
@ -63,7 +63,7 @@ func Test_parseResponse(t *testing.T) {
|
|||
"a record",
|
||||
&IPRecord{
|
||||
1,
|
||||
[]net.Address{net.ParseAddress("8.8.8.8"), net.ParseAddress("8.8.4.4")},
|
||||
[]net.IP{net.ParseIP("8.8.8.8"), net.ParseIP("8.8.4.4")},
|
||||
time.Time{},
|
||||
dnsmessage.RCodeSuccess,
|
||||
nil,
|
||||
|
@ -72,7 +72,7 @@ func Test_parseResponse(t *testing.T) {
|
|||
},
|
||||
{
|
||||
"aaaa record",
|
||||
&IPRecord{2, []net.Address{net.ParseAddress("2001::123:8888"), net.ParseAddress("2001::123:8844")}, time.Time{}, dnsmessage.RCodeSuccess, nil},
|
||||
&IPRecord{2, []net.IP{net.ParseIP("2001::123:8888"), net.ParseIP("2001::123:8844")}, time.Time{}, dnsmessage.RCodeSuccess, nil},
|
||||
false,
|
||||
},
|
||||
}
|
||||
|
|
|
@ -21,25 +21,23 @@ type Server interface {
|
|||
// Name of the Client.
|
||||
Name() string
|
||||
// QueryIP sends IP queries to its configured server.
|
||||
QueryIP(ctx context.Context, domain string, clientIP net.IP, option dns.IPOption, disableCache bool) ([]net.IP, uint32, error)
|
||||
QueryIP(ctx context.Context, domain string, option dns.IPOption) ([]net.IP, uint32, error)
|
||||
}
|
||||
|
||||
// Client is the interface for DNS client.
|
||||
type Client struct {
|
||||
server Server
|
||||
clientIP net.IP
|
||||
skipFallback bool
|
||||
domains []string
|
||||
expectedIPs []*router.GeoIPMatcher
|
||||
allowUnexpectedIPs bool
|
||||
tag string
|
||||
timeoutMs time.Duration
|
||||
ipOption *dns.IPOption
|
||||
}
|
||||
|
||||
var errExpectedIPNonMatch = errors.New("expectedIPs not match")
|
||||
|
||||
// NewServer creates a name server object according to the network destination url.
|
||||
func NewServer(ctx context.Context, dest net.Destination, dispatcher routing.Dispatcher, queryStrategy QueryStrategy) (Server, error) {
|
||||
func NewServer(ctx context.Context, dest net.Destination, dispatcher routing.Dispatcher, disableCache bool, clientIP net.IP) (Server, error) {
|
||||
if address := dest.Address; address.Family().IsDomain() {
|
||||
u, err := url.Parse(address.Domain())
|
||||
if err != nil {
|
||||
|
@ -47,26 +45,29 @@ func NewServer(ctx context.Context, dest net.Destination, dispatcher routing.Dis
|
|||
}
|
||||
switch {
|
||||
case strings.EqualFold(u.String(), "localhost"):
|
||||
return NewLocalNameServer(queryStrategy), nil
|
||||
return NewLocalNameServer(), nil
|
||||
case strings.EqualFold(u.Scheme, "https"): // DNS-over-HTTPS Remote mode
|
||||
return NewDoHNameServer(u, queryStrategy, dispatcher, false), nil
|
||||
return NewDoHNameServer(u, dispatcher, false, disableCache, clientIP), nil
|
||||
case strings.EqualFold(u.Scheme, "h2c"): // DNS-over-HTTPS h2c Remote mode
|
||||
return NewDoHNameServer(u, queryStrategy, dispatcher, true), nil
|
||||
return NewDoHNameServer(u, dispatcher, true, disableCache, clientIP), nil
|
||||
case strings.EqualFold(u.Scheme, "https+local"): // DNS-over-HTTPS Local mode
|
||||
return NewDoHNameServer(u, queryStrategy, nil, false), nil
|
||||
return NewDoHNameServer(u, nil, false, disableCache, clientIP), nil
|
||||
case strings.EqualFold(u.Scheme, "h2c+local"): // DNS-over-HTTPS h2c Local mode
|
||||
return NewDoHNameServer(u, queryStrategy, nil, true), nil
|
||||
return NewDoHNameServer(u, nil, true, disableCache, clientIP), nil
|
||||
case strings.EqualFold(u.Scheme, "quic+local"): // DNS-over-QUIC Local mode
|
||||
return NewQUICNameServer(u, queryStrategy)
|
||||
return NewQUICNameServer(u, disableCache, clientIP)
|
||||
case strings.EqualFold(u.Scheme, "tcp"): // DNS-over-TCP Remote mode
|
||||
return NewTCPNameServer(u, dispatcher, queryStrategy)
|
||||
return NewTCPNameServer(u, dispatcher, disableCache, clientIP)
|
||||
case strings.EqualFold(u.Scheme, "tcp+local"): // DNS-over-TCP Local mode
|
||||
return NewTCPLocalNameServer(u, queryStrategy)
|
||||
return NewTCPLocalNameServer(u, disableCache, clientIP)
|
||||
case strings.EqualFold(u.String(), "fakedns"):
|
||||
var fd dns.FakeDNSEngine
|
||||
core.RequireFeatures(ctx, func(fdns dns.FakeDNSEngine) {
|
||||
err = core.RequireFeatures(ctx, func(fdns dns.FakeDNSEngine) {
|
||||
fd = fdns
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewFakeDNSServer(fd), nil
|
||||
}
|
||||
}
|
||||
|
@ -74,7 +75,7 @@ func NewServer(ctx context.Context, dest net.Destination, dispatcher routing.Dis
|
|||
dest.Network = net.Network_UDP
|
||||
}
|
||||
if dest.Network == net.Network_UDP { // UDP classic DNS mode
|
||||
return NewClassicNameServer(dest, dispatcher, queryStrategy), nil
|
||||
return NewClassicNameServer(dest, dispatcher, disableCache, clientIP), nil
|
||||
}
|
||||
return nil, errors.New("No available name server could be created from ", dest).AtWarning()
|
||||
}
|
||||
|
@ -84,7 +85,9 @@ func NewClient(
|
|||
ctx context.Context,
|
||||
ns *NameServer,
|
||||
clientIP net.IP,
|
||||
container router.GeoIPMatcherContainer,
|
||||
disableCache bool,
|
||||
tag string,
|
||||
ipOption dns.IPOption,
|
||||
matcherInfos *[]*DomainMatcherInfo,
|
||||
updateDomainRule func(strmatcher.Matcher, int, []*DomainMatcherInfo) error,
|
||||
) (*Client, error) {
|
||||
|
@ -92,7 +95,7 @@ func NewClient(
|
|||
|
||||
err := core.RequireFeatures(ctx, func(dispatcher routing.Dispatcher) error {
|
||||
// Create a new server for each client for now
|
||||
server, err := NewServer(ctx, ns.Address.AsDestination(), dispatcher, ns.GetQueryStrategy())
|
||||
server, err := NewServer(ctx, ns.Address.AsDestination(), dispatcher, disableCache, clientIP)
|
||||
if err != nil {
|
||||
return errors.New("failed to create nameserver").Base(err).AtWarning()
|
||||
}
|
||||
|
@ -149,7 +152,7 @@ func NewClient(
|
|||
// Establish expected IPs
|
||||
var matchers []*router.GeoIPMatcher
|
||||
for _, geoip := range ns.Geoip {
|
||||
matcher, err := container.Add(geoip)
|
||||
matcher, err := router.GlobalGeoIPContainer.Add(geoip)
|
||||
if err != nil {
|
||||
return errors.New("failed to create ip matcher").Base(err).AtWarning()
|
||||
}
|
||||
|
@ -169,15 +172,15 @@ func NewClient(
|
|||
if ns.TimeoutMs > 0 {
|
||||
timeoutMs = time.Duration(ns.TimeoutMs) * time.Millisecond
|
||||
}
|
||||
|
||||
|
||||
client.server = server
|
||||
client.clientIP = clientIP
|
||||
client.skipFallback = ns.SkipFallback
|
||||
client.domains = rules
|
||||
client.expectedIPs = matchers
|
||||
client.allowUnexpectedIPs = ns.AllowUnexpectedIPs
|
||||
client.tag = ns.Tag
|
||||
client.tag = tag
|
||||
client.timeoutMs = timeoutMs
|
||||
client.ipOption = &ipOption
|
||||
return nil
|
||||
})
|
||||
return client, err
|
||||
|
@ -189,31 +192,43 @@ func (c *Client) Name() string {
|
|||
}
|
||||
|
||||
// QueryIP sends DNS query to the name server with the client's IP.
|
||||
func (c *Client) QueryIP(ctx context.Context, domain string, option dns.IPOption, disableCache bool) ([]net.IP, uint32, error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, c.timeoutMs)
|
||||
if len(c.tag) != 0 {
|
||||
content := session.InboundFromContext(ctx)
|
||||
errors.LogDebug(ctx, "DNS: client override tag from ", content.Tag, " to ", c.tag)
|
||||
// create a new context to override the tag
|
||||
// do not direct set *content.Tag, it might be used by other clients
|
||||
ctx = session.ContextWithInbound(ctx, &session.Inbound{Tag: c.tag})
|
||||
func (c *Client) QueryIP(ctx context.Context, domain string, option dns.IPOption) ([]net.IP, uint32, error) {
|
||||
option.IPv4Enable = option.IPv4Enable && c.ipOption.IPv4Enable
|
||||
option.IPv6Enable = option.IPv6Enable && c.ipOption.IPv6Enable
|
||||
if !option.IPv4Enable && !option.IPv6Enable {
|
||||
return nil, 0, dns.ErrEmptyResponse
|
||||
}
|
||||
ips, ttl, err := c.server.QueryIP(ctx, domain, c.clientIP, option, disableCache)
|
||||
|
||||
ctx, cancel := context.WithTimeout(ctx, c.timeoutMs)
|
||||
ctx = session.ContextWithInbound(ctx, &session.Inbound{Tag: c.tag})
|
||||
ips, ttl, err := c.server.QueryIP(ctx, domain, option)
|
||||
cancel()
|
||||
|
||||
if err != nil {
|
||||
return ips, ttl, err
|
||||
return nil, 0, err
|
||||
}
|
||||
netips, err := c.MatchExpectedIPs(domain, ips)
|
||||
return netips, ttl, err
|
||||
|
||||
if len(ips) == 0 {
|
||||
return nil, 0, dns.ErrEmptyResponse
|
||||
}
|
||||
|
||||
if len(c.expectedIPs) > 0 {
|
||||
newIps := c.MatchExpectedIPs(domain, ips)
|
||||
if len(newIps) == 0 {
|
||||
if !c.allowUnexpectedIPs {
|
||||
return nil, 0, dns.ErrEmptyResponse
|
||||
}
|
||||
} else {
|
||||
ips = newIps
|
||||
}
|
||||
}
|
||||
|
||||
return ips, ttl, nil
|
||||
}
|
||||
|
||||
// MatchExpectedIPs matches queried domain IPs with expected IPs and returns matched ones.
|
||||
func (c *Client) MatchExpectedIPs(domain string, ips []net.IP) ([]net.IP, error) {
|
||||
if len(c.expectedIPs) == 0 {
|
||||
return ips, nil
|
||||
}
|
||||
newIps := []net.IP{}
|
||||
func (c *Client) MatchExpectedIPs(domain string, ips []net.IP) []net.IP {
|
||||
var newIps []net.IP
|
||||
for _, ip := range ips {
|
||||
for _, matcher := range c.expectedIPs {
|
||||
if matcher.Match(ip) {
|
||||
|
@ -222,14 +237,8 @@ func (c *Client) MatchExpectedIPs(domain string, ips []net.IP) ([]net.IP, error)
|
|||
}
|
||||
}
|
||||
}
|
||||
if len(newIps) == 0 {
|
||||
if c.allowUnexpectedIPs {
|
||||
return ips, nil
|
||||
}
|
||||
return nil, errExpectedIPNonMatch
|
||||
}
|
||||
errors.LogDebug(context.Background(), "domain ", domain, " expectedIPs ", newIps, " matched at server ", c.Name())
|
||||
return newIps, nil
|
||||
return newIps
|
||||
}
|
||||
|
||||
func ResolveIpOptionOverride(queryStrategy QueryStrategy, ipOption dns.IPOption) dns.IPOption {
|
||||
|
|
|
@ -4,12 +4,12 @@ import (
|
|||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
go_errors "errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
utls "github.com/refraction-networking/utls"
|
||||
|
@ -21,12 +21,9 @@ import (
|
|||
"github.com/xtls/xray-core/common/net/cnc"
|
||||
"github.com/xtls/xray-core/common/protocol/dns"
|
||||
"github.com/xtls/xray-core/common/session"
|
||||
"github.com/xtls/xray-core/common/signal/pubsub"
|
||||
"github.com/xtls/xray-core/common/task"
|
||||
dns_feature "github.com/xtls/xray-core/features/dns"
|
||||
"github.com/xtls/xray-core/features/routing"
|
||||
"github.com/xtls/xray-core/transport/internet"
|
||||
"golang.org/x/net/dns/dnsmessage"
|
||||
"golang.org/x/net/http2"
|
||||
)
|
||||
|
||||
|
@ -34,18 +31,14 @@ import (
|
|||
// which is compatible with traditional dns over udp(RFC1035),
|
||||
// thus most of the DOH implementation is copied from udpns.go
|
||||
type DoHNameServer struct {
|
||||
sync.RWMutex
|
||||
ips map[string]*record
|
||||
pub *pubsub.Service
|
||||
cleanup *task.Periodic
|
||||
httpClient *http.Client
|
||||
dohURL string
|
||||
name string
|
||||
queryStrategy QueryStrategy
|
||||
cacheController *CacheController
|
||||
httpClient *http.Client
|
||||
dohURL string
|
||||
clientIP net.IP
|
||||
}
|
||||
|
||||
// NewDoHNameServer creates DOH/DOHL client object for remote/local resolving.
|
||||
func NewDoHNameServer(url *url.URL, queryStrategy QueryStrategy, dispatcher routing.Dispatcher, h2c bool) *DoHNameServer {
|
||||
func NewDoHNameServer(url *url.URL, dispatcher routing.Dispatcher, h2c bool, disableCache bool, clientIP net.IP) *DoHNameServer {
|
||||
url.Scheme = "https"
|
||||
mode := "DOH"
|
||||
if dispatcher == nil {
|
||||
|
@ -53,15 +46,9 @@ func NewDoHNameServer(url *url.URL, queryStrategy QueryStrategy, dispatcher rout
|
|||
}
|
||||
errors.LogInfo(context.Background(), "DNS: created ", mode, " client for ", url.String(), ", with h2c ", h2c)
|
||||
s := &DoHNameServer{
|
||||
ips: make(map[string]*record),
|
||||
pub: pubsub.NewService(),
|
||||
name: mode + "//" + url.Host,
|
||||
dohURL: url.String(),
|
||||
queryStrategy: queryStrategy,
|
||||
}
|
||||
s.cleanup = &task.Periodic{
|
||||
Interval: time.Minute,
|
||||
Execute: s.Cleanup,
|
||||
cacheController: NewCacheController(mode+"//"+url.Host, disableCache),
|
||||
dohURL: url.String(),
|
||||
clientIP: clientIP,
|
||||
}
|
||||
s.httpClient = &http.Client{
|
||||
Transport: &http2.Transport{
|
||||
|
@ -127,101 +114,25 @@ func NewDoHNameServer(url *url.URL, queryStrategy QueryStrategy, dispatcher rout
|
|||
|
||||
// Name implements Server.
|
||||
func (s *DoHNameServer) Name() string {
|
||||
return s.name
|
||||
}
|
||||
|
||||
// Cleanup clears expired items from cache
|
||||
func (s *DoHNameServer) Cleanup() error {
|
||||
now := time.Now()
|
||||
s.Lock()
|
||||
defer s.Unlock()
|
||||
|
||||
if len(s.ips) == 0 {
|
||||
return errors.New("nothing to do. stopping...")
|
||||
}
|
||||
|
||||
for domain, record := range s.ips {
|
||||
if record.A != nil && record.A.Expire.Before(now) {
|
||||
record.A = nil
|
||||
}
|
||||
if record.AAAA != nil && record.AAAA.Expire.Before(now) {
|
||||
record.AAAA = nil
|
||||
}
|
||||
|
||||
if record.A == nil && record.AAAA == nil {
|
||||
errors.LogDebug(context.Background(), s.name, " cleanup ", domain)
|
||||
delete(s.ips, domain)
|
||||
} else {
|
||||
s.ips[domain] = record
|
||||
}
|
||||
}
|
||||
|
||||
if len(s.ips) == 0 {
|
||||
s.ips = make(map[string]*record)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *DoHNameServer) updateIP(req *dnsRequest, ipRec *IPRecord) {
|
||||
elapsed := time.Since(req.start)
|
||||
|
||||
s.Lock()
|
||||
rec, found := s.ips[req.domain]
|
||||
if !found {
|
||||
rec = &record{}
|
||||
}
|
||||
updated := false
|
||||
|
||||
switch req.reqType {
|
||||
case dnsmessage.TypeA:
|
||||
if isNewer(rec.A, ipRec) {
|
||||
rec.A = ipRec
|
||||
updated = true
|
||||
}
|
||||
case dnsmessage.TypeAAAA:
|
||||
addr := make([]net.Address, 0, len(ipRec.IP))
|
||||
for _, ip := range ipRec.IP {
|
||||
if len(ip.IP()) == net.IPv6len {
|
||||
addr = append(addr, ip)
|
||||
}
|
||||
}
|
||||
ipRec.IP = addr
|
||||
if isNewer(rec.AAAA, ipRec) {
|
||||
rec.AAAA = ipRec
|
||||
updated = true
|
||||
}
|
||||
}
|
||||
errors.LogInfo(context.Background(), s.name, " got answer: ", req.domain, " ", req.reqType, " -> ", ipRec.IP, " ", elapsed)
|
||||
|
||||
if updated {
|
||||
s.ips[req.domain] = rec
|
||||
}
|
||||
switch req.reqType {
|
||||
case dnsmessage.TypeA:
|
||||
s.pub.Publish(req.domain+"4", nil)
|
||||
case dnsmessage.TypeAAAA:
|
||||
s.pub.Publish(req.domain+"6", nil)
|
||||
}
|
||||
s.Unlock()
|
||||
common.Must(s.cleanup.Start())
|
||||
return s.cacheController.name
|
||||
}
|
||||
|
||||
func (s *DoHNameServer) newReqID() uint16 {
|
||||
return 0
|
||||
}
|
||||
|
||||
func (s *DoHNameServer) sendQuery(ctx context.Context, domain string, clientIP net.IP, option dns_feature.IPOption) {
|
||||
errors.LogInfo(ctx, s.name, " querying: ", domain)
|
||||
func (s *DoHNameServer) sendQuery(ctx context.Context, noResponseErrCh chan<- error, domain string, option dns_feature.IPOption) {
|
||||
errors.LogInfo(ctx, s.Name(), " querying: ", domain)
|
||||
|
||||
if s.name+"." == "DOH//"+domain {
|
||||
errors.LogError(ctx, s.name, " tries to resolve itself! Use IP or set \"hosts\" instead.")
|
||||
if s.Name()+"." == "DOH//"+domain {
|
||||
errors.LogError(ctx, s.Name(), " tries to resolve itself! Use IP or set \"hosts\" instead.")
|
||||
noResponseErrCh <- errors.New("tries to resolve itself!", s.Name())
|
||||
return
|
||||
}
|
||||
|
||||
// As we don't want our traffic pattern looks like DoH, we use Random-Length Padding instead of Block-Length Padding recommended in RFC 8467
|
||||
// Although DoH server like 1.1.1.1 will pad the response to Block-Length 468, at least it is better than no padding for response at all
|
||||
reqs := buildReqMsgs(domain, option, s.newReqID, genEDNS0Options(clientIP, int(crypto.RandBetween(100, 300))))
|
||||
reqs := buildReqMsgs(domain, option, s.newReqID, genEDNS0Options(s.clientIP, int(crypto.RandBetween(100, 300))))
|
||||
|
||||
var deadline time.Time
|
||||
if d, ok := ctx.Deadline(); ok {
|
||||
|
@ -256,19 +167,22 @@ func (s *DoHNameServer) sendQuery(ctx context.Context, domain string, clientIP n
|
|||
b, err := dns.PackMessage(r.msg)
|
||||
if err != nil {
|
||||
errors.LogErrorInner(ctx, err, "failed to pack dns query for ", domain)
|
||||
noResponseErrCh <- err
|
||||
return
|
||||
}
|
||||
resp, err := s.dohHTTPSContext(dnsCtx, b.Bytes())
|
||||
if err != nil {
|
||||
errors.LogErrorInner(ctx, err, "failed to retrieve response for ", domain)
|
||||
noResponseErrCh <- err
|
||||
return
|
||||
}
|
||||
rec, err := parseResponse(resp)
|
||||
if err != nil {
|
||||
errors.LogErrorInner(ctx, err, "failed to handle DOH response for ", domain)
|
||||
noResponseErrCh <- err
|
||||
return
|
||||
}
|
||||
s.updateIP(r, rec)
|
||||
s.cacheController.updateIP(r, rec)
|
||||
}(req)
|
||||
}
|
||||
}
|
||||
|
@ -301,109 +215,50 @@ func (s *DoHNameServer) dohHTTPSContext(ctx context.Context, b []byte) ([]byte,
|
|||
return io.ReadAll(resp.Body)
|
||||
}
|
||||
|
||||
func (s *DoHNameServer) findIPsForDomain(domain string, option dns_feature.IPOption) ([]net.IP, uint32, error) {
|
||||
s.RLock()
|
||||
record, found := s.ips[domain]
|
||||
s.RUnlock()
|
||||
|
||||
if !found {
|
||||
return nil, 0, errRecordNotFound
|
||||
}
|
||||
|
||||
var err4 error
|
||||
var err6 error
|
||||
var ips []net.Address
|
||||
var ip6 []net.Address
|
||||
var ttl uint32
|
||||
|
||||
if option.IPv4Enable {
|
||||
ips, ttl, err4 = record.A.getIPs()
|
||||
}
|
||||
|
||||
if option.IPv6Enable {
|
||||
ip6, ttl, err6 = record.AAAA.getIPs()
|
||||
ips = append(ips, ip6...)
|
||||
}
|
||||
|
||||
if len(ips) > 0 {
|
||||
netips, err := toNetIP(ips)
|
||||
return netips, ttl, err
|
||||
}
|
||||
|
||||
if err4 != nil {
|
||||
return nil, 0, err4
|
||||
}
|
||||
|
||||
if err6 != nil {
|
||||
return nil, 0, err6
|
||||
}
|
||||
|
||||
if (option.IPv4Enable && record.A != nil) || (option.IPv6Enable && record.AAAA != nil) {
|
||||
return nil, 0, dns_feature.ErrEmptyResponse
|
||||
}
|
||||
|
||||
return nil, 0, errRecordNotFound
|
||||
}
|
||||
|
||||
// QueryIP implements Server.
|
||||
func (s *DoHNameServer) QueryIP(ctx context.Context, domain string, clientIP net.IP, option dns_feature.IPOption, disableCache bool) ([]net.IP, uint32, error) { // nolint: dupl
|
||||
func (s *DoHNameServer) QueryIP(ctx context.Context, domain string, option dns_feature.IPOption) ([]net.IP, uint32, error) { // nolint: dupl
|
||||
fqdn := Fqdn(domain)
|
||||
option = ResolveIpOptionOverride(s.queryStrategy, option)
|
||||
if !option.IPv4Enable && !option.IPv6Enable {
|
||||
return nil, 0, dns_feature.ErrEmptyResponse
|
||||
}
|
||||
sub4, sub6 := s.cacheController.registerSubscribers(fqdn, option)
|
||||
defer closeSubscribers(sub4, sub6)
|
||||
|
||||
if disableCache {
|
||||
errors.LogDebug(ctx, "DNS cache is disabled. Querying IP for ", domain, " at ", s.name)
|
||||
if s.cacheController.disableCache {
|
||||
errors.LogDebug(ctx, "DNS cache is disabled. Querying IP for ", domain, " at ", s.Name())
|
||||
} else {
|
||||
ips, ttl, err := s.findIPsForDomain(fqdn, option)
|
||||
if err == nil || err == dns_feature.ErrEmptyResponse || dns_feature.RCodeFromError(err) == 3 {
|
||||
errors.LogDebugInner(ctx, err, s.name, " cache HIT ", domain, " -> ", ips)
|
||||
log.Record(&log.DNSLog{Server: s.name, Domain: domain, Result: ips, Status: log.DNSCacheHit, Elapsed: 0, Error: err})
|
||||
ips, ttl, err := s.cacheController.findIPsForDomain(fqdn, option)
|
||||
if !go_errors.Is(err, errRecordNotFound) {
|
||||
errors.LogDebugInner(ctx, err, s.Name(), " cache HIT ", domain, " -> ", ips)
|
||||
log.Record(&log.DNSLog{Server: s.Name(), Domain: domain, Result: ips, Status: log.DNSCacheHit, Elapsed: 0, Error: err})
|
||||
return ips, ttl, err
|
||||
}
|
||||
}
|
||||
|
||||
// ipv4 and ipv6 belong to different subscription groups
|
||||
var sub4, sub6 *pubsub.Subscriber
|
||||
if option.IPv4Enable {
|
||||
sub4 = s.pub.Subscribe(fqdn + "4")
|
||||
defer sub4.Close()
|
||||
}
|
||||
if option.IPv6Enable {
|
||||
sub6 = s.pub.Subscribe(fqdn + "6")
|
||||
defer sub6.Close()
|
||||
}
|
||||
done := make(chan interface{})
|
||||
go func() {
|
||||
if sub4 != nil {
|
||||
select {
|
||||
case <-sub4.Wait():
|
||||
case <-ctx.Done():
|
||||
}
|
||||
}
|
||||
if sub6 != nil {
|
||||
select {
|
||||
case <-sub6.Wait():
|
||||
case <-ctx.Done():
|
||||
}
|
||||
}
|
||||
close(done)
|
||||
}()
|
||||
s.sendQuery(ctx, fqdn, clientIP, option)
|
||||
noResponseErrCh := make(chan error, 2)
|
||||
s.sendQuery(ctx, noResponseErrCh, fqdn, option)
|
||||
start := time.Now()
|
||||
|
||||
for {
|
||||
ips, ttl, err := s.findIPsForDomain(fqdn, option)
|
||||
if err != errRecordNotFound {
|
||||
log.Record(&log.DNSLog{Server: s.name, Domain: domain, Result: ips, Status: log.DNSQueried, Elapsed: time.Since(start), Error: err})
|
||||
return ips, ttl, err
|
||||
}
|
||||
|
||||
if sub4 != nil {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, 0, ctx.Err()
|
||||
case <-done:
|
||||
case err := <-noResponseErrCh:
|
||||
return nil, 0, err
|
||||
case <-sub4.Wait():
|
||||
sub4.Close()
|
||||
}
|
||||
}
|
||||
if sub6 != nil {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, 0, ctx.Err()
|
||||
case err := <-noResponseErrCh:
|
||||
return nil, 0, err
|
||||
case <-sub6.Wait():
|
||||
sub6.Close()
|
||||
}
|
||||
}
|
||||
|
||||
ips, ttl, err := s.cacheController.findIPsForDomain(fqdn, option)
|
||||
log.Record(&log.DNSLog{Server: s.Name(), Domain: domain, Result: ips, Status: log.DNSQueried, Elapsed: time.Since(start), Error: err})
|
||||
return ips, ttl, err
|
||||
|
||||
}
|
||||
|
|
|
@ -17,12 +17,12 @@ func TestDOHNameServer(t *testing.T) {
|
|||
url, err := url.Parse("https+local://1.1.1.1/dns-query")
|
||||
common.Must(err)
|
||||
|
||||
s := NewDoHNameServer(url, QueryStrategy_USE_IP, nil, false)
|
||||
s := NewDoHNameServer(url, nil, false, false, net.IP(nil))
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
|
||||
ips, _, err := s.QueryIP(ctx, "google.com", net.IP(nil), dns_feature.IPOption{
|
||||
ips, _, err := s.QueryIP(ctx, "google.com", dns_feature.IPOption{
|
||||
IPv4Enable: true,
|
||||
IPv6Enable: true,
|
||||
}, false)
|
||||
})
|
||||
cancel()
|
||||
common.Must(err)
|
||||
if len(ips) == 0 {
|
||||
|
@ -34,12 +34,12 @@ func TestDOHNameServerWithCache(t *testing.T) {
|
|||
url, err := url.Parse("https+local://1.1.1.1/dns-query")
|
||||
common.Must(err)
|
||||
|
||||
s := NewDoHNameServer(url, QueryStrategy_USE_IP, nil, false)
|
||||
s := NewDoHNameServer(url, nil, false, false, net.IP(nil))
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
|
||||
ips, _, err := s.QueryIP(ctx, "google.com", net.IP(nil), dns_feature.IPOption{
|
||||
ips, _, err := s.QueryIP(ctx, "google.com", dns_feature.IPOption{
|
||||
IPv4Enable: true,
|
||||
IPv6Enable: true,
|
||||
}, false)
|
||||
})
|
||||
cancel()
|
||||
common.Must(err)
|
||||
if len(ips) == 0 {
|
||||
|
@ -47,10 +47,10 @@ func TestDOHNameServerWithCache(t *testing.T) {
|
|||
}
|
||||
|
||||
ctx2, cancel := context.WithTimeout(context.Background(), time.Second*5)
|
||||
ips2, _, err := s.QueryIP(ctx2, "google.com", net.IP(nil), dns_feature.IPOption{
|
||||
ips2, _, err := s.QueryIP(ctx2, "google.com", dns_feature.IPOption{
|
||||
IPv4Enable: true,
|
||||
IPv6Enable: true,
|
||||
}, true)
|
||||
})
|
||||
cancel()
|
||||
common.Must(err)
|
||||
if r := cmp.Diff(ips2, ips); r != "" {
|
||||
|
@ -62,12 +62,12 @@ func TestDOHNameServerWithIPv4Override(t *testing.T) {
|
|||
url, err := url.Parse("https+local://1.1.1.1/dns-query")
|
||||
common.Must(err)
|
||||
|
||||
s := NewDoHNameServer(url, QueryStrategy_USE_IP4, nil, false)
|
||||
s := NewDoHNameServer(url, nil, false, false, net.IP(nil))
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
|
||||
ips, _, err := s.QueryIP(ctx, "google.com", net.IP(nil), dns_feature.IPOption{
|
||||
ips, _, err := s.QueryIP(ctx, "google.com", dns_feature.IPOption{
|
||||
IPv4Enable: true,
|
||||
IPv6Enable: true,
|
||||
}, false)
|
||||
IPv6Enable: false,
|
||||
})
|
||||
cancel()
|
||||
common.Must(err)
|
||||
if len(ips) == 0 {
|
||||
|
@ -85,12 +85,12 @@ func TestDOHNameServerWithIPv6Override(t *testing.T) {
|
|||
url, err := url.Parse("https+local://1.1.1.1/dns-query")
|
||||
common.Must(err)
|
||||
|
||||
s := NewDoHNameServer(url, QueryStrategy_USE_IP6, nil, false)
|
||||
s := NewDoHNameServer(url, nil, false, false, net.IP(nil))
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
|
||||
ips, _, err := s.QueryIP(ctx, "google.com", net.IP(nil), dns_feature.IPOption{
|
||||
IPv4Enable: true,
|
||||
ips, _, err := s.QueryIP(ctx, "google.com", dns_feature.IPOption{
|
||||
IPv4Enable: false,
|
||||
IPv6Enable: true,
|
||||
}, false)
|
||||
})
|
||||
cancel()
|
||||
common.Must(err)
|
||||
if len(ips) == 0 {
|
||||
|
|
|
@ -20,7 +20,7 @@ func (FakeDNSServer) Name() string {
|
|||
return "FakeDNS"
|
||||
}
|
||||
|
||||
func (f *FakeDNSServer) QueryIP(ctx context.Context, domain string, _ net.IP, opt dns.IPOption, _ bool) ([]net.IP, uint32, error) {
|
||||
func (f *FakeDNSServer) QueryIP(ctx context.Context, domain string, opt dns.IPOption) ([]net.IP, uint32, error) {
|
||||
if f.fakeDNSEngine == nil {
|
||||
return nil, 0, errors.New("Unable to locate a fake DNS Engine").AtError()
|
||||
}
|
||||
|
|
|
@ -2,7 +2,6 @@ package dns
|
|||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/xtls/xray-core/common/errors"
|
||||
|
@ -14,26 +13,15 @@ import (
|
|||
|
||||
// LocalNameServer is an wrapper over local DNS feature.
|
||||
type LocalNameServer struct {
|
||||
client *localdns.Client
|
||||
queryStrategy QueryStrategy
|
||||
client *localdns.Client
|
||||
}
|
||||
|
||||
const errEmptyResponse = "No address associated with hostname"
|
||||
|
||||
// QueryIP implements Server.
|
||||
func (s *LocalNameServer) QueryIP(ctx context.Context, domain string, _ net.IP, option dns.IPOption, _ bool) (ips []net.IP, ttl uint32, err error) {
|
||||
option = ResolveIpOptionOverride(s.queryStrategy, option)
|
||||
if !option.IPv4Enable && !option.IPv6Enable {
|
||||
return nil, 0, dns.ErrEmptyResponse
|
||||
}
|
||||
func (s *LocalNameServer) QueryIP(ctx context.Context, domain string, option dns.IPOption) (ips []net.IP, ttl uint32, err error) {
|
||||
|
||||
start := time.Now()
|
||||
ips, ttl, err = s.client.LookupIP(domain, option)
|
||||
|
||||
if err != nil && strings.HasSuffix(err.Error(), errEmptyResponse) {
|
||||
err = dns.ErrEmptyResponse
|
||||
}
|
||||
|
||||
if len(ips) > 0 {
|
||||
errors.LogInfo(ctx, "Localhost got answer: ", domain, " -> ", ips)
|
||||
log.Record(&log.DNSLog{Server: s.Name(), Domain: domain, Result: ips, Status: log.DNSQueried, Elapsed: time.Since(start), Error: err})
|
||||
|
@ -48,15 +36,14 @@ func (s *LocalNameServer) Name() string {
|
|||
}
|
||||
|
||||
// NewLocalNameServer creates localdns server object for directly lookup in system DNS.
|
||||
func NewLocalNameServer(queryStrategy QueryStrategy) *LocalNameServer {
|
||||
func NewLocalNameServer() *LocalNameServer {
|
||||
errors.LogInfo(context.Background(), "DNS: created localhost client")
|
||||
return &LocalNameServer{
|
||||
queryStrategy: queryStrategy,
|
||||
client: localdns.New(),
|
||||
client: localdns.New(),
|
||||
}
|
||||
}
|
||||
|
||||
// NewLocalDNSClient creates localdns client object for directly lookup in system DNS.
|
||||
func NewLocalDNSClient() *Client {
|
||||
return &Client{server: NewLocalNameServer(QueryStrategy_USE_IP)}
|
||||
func NewLocalDNSClient(ipOption dns.IPOption) *Client {
|
||||
return &Client{server: NewLocalNameServer(), ipOption: &ipOption}
|
||||
}
|
||||
|
|
|
@ -7,18 +7,17 @@ import (
|
|||
|
||||
. "github.com/xtls/xray-core/app/dns"
|
||||
"github.com/xtls/xray-core/common"
|
||||
"github.com/xtls/xray-core/common/net"
|
||||
"github.com/xtls/xray-core/features/dns"
|
||||
)
|
||||
|
||||
func TestLocalNameServer(t *testing.T) {
|
||||
s := NewLocalNameServer(QueryStrategy_USE_IP)
|
||||
s := NewLocalNameServer()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*2)
|
||||
ips, _, err := s.QueryIP(ctx, "google.com", net.IP{}, dns.IPOption{
|
||||
ips, _, err := s.QueryIP(ctx, "google.com", dns.IPOption{
|
||||
IPv4Enable: true,
|
||||
IPv6Enable: true,
|
||||
FakeEnable: false,
|
||||
}, false)
|
||||
})
|
||||
cancel()
|
||||
common.Must(err)
|
||||
if len(ips) == 0 {
|
||||
|
|
|
@ -4,23 +4,20 @@ import (
|
|||
"bytes"
|
||||
"context"
|
||||
"encoding/binary"
|
||||
go_errors "errors"
|
||||
"net/url"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go"
|
||||
"github.com/xtls/xray-core/common"
|
||||
"github.com/xtls/xray-core/common/buf"
|
||||
"github.com/xtls/xray-core/common/errors"
|
||||
"github.com/xtls/xray-core/common/log"
|
||||
"github.com/xtls/xray-core/common/net"
|
||||
"github.com/xtls/xray-core/common/protocol/dns"
|
||||
"github.com/xtls/xray-core/common/session"
|
||||
"github.com/xtls/xray-core/common/signal/pubsub"
|
||||
"github.com/xtls/xray-core/common/task"
|
||||
dns_feature "github.com/xtls/xray-core/features/dns"
|
||||
"github.com/xtls/xray-core/transport/internet/tls"
|
||||
"golang.org/x/net/dns/dnsmessage"
|
||||
"golang.org/x/net/http2"
|
||||
)
|
||||
|
||||
|
@ -33,17 +30,14 @@ const handshakeTimeout = time.Second * 8
|
|||
// QUICNameServer implemented DNS over QUIC
|
||||
type QUICNameServer struct {
|
||||
sync.RWMutex
|
||||
ips map[string]*record
|
||||
pub *pubsub.Service
|
||||
cleanup *task.Periodic
|
||||
name string
|
||||
destination *net.Destination
|
||||
connection quic.Connection
|
||||
queryStrategy QueryStrategy
|
||||
cacheController *CacheController
|
||||
destination *net.Destination
|
||||
connection quic.Connection
|
||||
clientIP net.IP
|
||||
}
|
||||
|
||||
// NewQUICNameServer creates DNS-over-QUIC client object for local resolving
|
||||
func NewQUICNameServer(url *url.URL, queryStrategy QueryStrategy) (*QUICNameServer, error) {
|
||||
func NewQUICNameServer(url *url.URL, disableCache bool, clientIP net.IP) (*QUICNameServer, error) {
|
||||
errors.LogInfo(context.Background(), "DNS: created Local DNS-over-QUIC client for ", url.String())
|
||||
|
||||
var err error
|
||||
|
@ -57,15 +51,9 @@ func NewQUICNameServer(url *url.URL, queryStrategy QueryStrategy) (*QUICNameServ
|
|||
dest := net.UDPDestination(net.ParseAddress(url.Hostname()), port)
|
||||
|
||||
s := &QUICNameServer{
|
||||
ips: make(map[string]*record),
|
||||
pub: pubsub.NewService(),
|
||||
name: url.String(),
|
||||
destination: &dest,
|
||||
queryStrategy: queryStrategy,
|
||||
}
|
||||
s.cleanup = &task.Periodic{
|
||||
Interval: time.Minute,
|
||||
Execute: s.Cleanup,
|
||||
cacheController: NewCacheController(url.String(), disableCache),
|
||||
destination: &dest,
|
||||
clientIP: clientIP,
|
||||
}
|
||||
|
||||
return s, nil
|
||||
|
@ -73,94 +61,17 @@ func NewQUICNameServer(url *url.URL, queryStrategy QueryStrategy) (*QUICNameServ
|
|||
|
||||
// Name returns client name
|
||||
func (s *QUICNameServer) Name() string {
|
||||
return s.name
|
||||
}
|
||||
|
||||
// Cleanup clears expired items from cache
|
||||
func (s *QUICNameServer) Cleanup() error {
|
||||
now := time.Now()
|
||||
s.Lock()
|
||||
defer s.Unlock()
|
||||
|
||||
if len(s.ips) == 0 {
|
||||
return errors.New("nothing to do. stopping...")
|
||||
}
|
||||
|
||||
for domain, record := range s.ips {
|
||||
if record.A != nil && record.A.Expire.Before(now) {
|
||||
record.A = nil
|
||||
}
|
||||
if record.AAAA != nil && record.AAAA.Expire.Before(now) {
|
||||
record.AAAA = nil
|
||||
}
|
||||
|
||||
if record.A == nil && record.AAAA == nil {
|
||||
errors.LogDebug(context.Background(), s.name, " cleanup ", domain)
|
||||
delete(s.ips, domain)
|
||||
} else {
|
||||
s.ips[domain] = record
|
||||
}
|
||||
}
|
||||
|
||||
if len(s.ips) == 0 {
|
||||
s.ips = make(map[string]*record)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *QUICNameServer) updateIP(req *dnsRequest, ipRec *IPRecord) {
|
||||
elapsed := time.Since(req.start)
|
||||
|
||||
s.Lock()
|
||||
rec, found := s.ips[req.domain]
|
||||
if !found {
|
||||
rec = &record{}
|
||||
}
|
||||
updated := false
|
||||
|
||||
switch req.reqType {
|
||||
case dnsmessage.TypeA:
|
||||
if isNewer(rec.A, ipRec) {
|
||||
rec.A = ipRec
|
||||
updated = true
|
||||
}
|
||||
case dnsmessage.TypeAAAA:
|
||||
addr := make([]net.Address, 0)
|
||||
for _, ip := range ipRec.IP {
|
||||
if len(ip.IP()) == net.IPv6len {
|
||||
addr = append(addr, ip)
|
||||
}
|
||||
}
|
||||
ipRec.IP = addr
|
||||
if isNewer(rec.AAAA, ipRec) {
|
||||
rec.AAAA = ipRec
|
||||
updated = true
|
||||
}
|
||||
}
|
||||
errors.LogInfo(context.Background(), s.name, " got answer: ", req.domain, " ", req.reqType, " -> ", ipRec.IP, " ", elapsed)
|
||||
|
||||
if updated {
|
||||
s.ips[req.domain] = rec
|
||||
}
|
||||
switch req.reqType {
|
||||
case dnsmessage.TypeA:
|
||||
s.pub.Publish(req.domain+"4", nil)
|
||||
case dnsmessage.TypeAAAA:
|
||||
s.pub.Publish(req.domain+"6", nil)
|
||||
}
|
||||
s.Unlock()
|
||||
common.Must(s.cleanup.Start())
|
||||
return s.cacheController.name
|
||||
}
|
||||
|
||||
func (s *QUICNameServer) newReqID() uint16 {
|
||||
return 0
|
||||
}
|
||||
|
||||
func (s *QUICNameServer) sendQuery(ctx context.Context, domain string, clientIP net.IP, option dns_feature.IPOption) {
|
||||
errors.LogInfo(ctx, s.name, " querying: ", domain)
|
||||
func (s *QUICNameServer) sendQuery(ctx context.Context, noResponseErrCh chan<- error, domain string, option dns_feature.IPOption) {
|
||||
errors.LogInfo(ctx, s.Name(), " querying: ", domain)
|
||||
|
||||
reqs := buildReqMsgs(domain, option, s.newReqID, genEDNS0Options(clientIP, 0))
|
||||
reqs := buildReqMsgs(domain, option, s.newReqID, genEDNS0Options(s.clientIP, 0))
|
||||
|
||||
var deadline time.Time
|
||||
if d, ok := ctx.Deadline(); ok {
|
||||
|
@ -192,23 +103,36 @@ func (s *QUICNameServer) sendQuery(ctx context.Context, domain string, clientIP
|
|||
b, err := dns.PackMessage(r.msg)
|
||||
if err != nil {
|
||||
errors.LogErrorInner(ctx, err, "failed to pack dns query")
|
||||
noResponseErrCh <- err
|
||||
return
|
||||
}
|
||||
|
||||
dnsReqBuf := buf.New()
|
||||
binary.Write(dnsReqBuf, binary.BigEndian, uint16(b.Len()))
|
||||
dnsReqBuf.Write(b.Bytes())
|
||||
err = binary.Write(dnsReqBuf, binary.BigEndian, uint16(b.Len()))
|
||||
if err != nil {
|
||||
errors.LogErrorInner(ctx, err, "binary write failed")
|
||||
noResponseErrCh <- err
|
||||
return
|
||||
}
|
||||
_, err = dnsReqBuf.Write(b.Bytes())
|
||||
if err != nil {
|
||||
errors.LogErrorInner(ctx, err, "buffer write failed")
|
||||
noResponseErrCh <- err
|
||||
return
|
||||
}
|
||||
b.Release()
|
||||
|
||||
conn, err := s.openStream(dnsCtx)
|
||||
if err != nil {
|
||||
errors.LogErrorInner(ctx, err, "failed to open quic connection")
|
||||
noResponseErrCh <- err
|
||||
return
|
||||
}
|
||||
|
||||
_, err = conn.Write(dnsReqBuf.Bytes())
|
||||
if err != nil {
|
||||
errors.LogErrorInner(ctx, err, "failed to send query")
|
||||
noResponseErrCh <- err
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -219,136 +143,81 @@ func (s *QUICNameServer) sendQuery(ctx context.Context, domain string, clientIP
|
|||
n, err := respBuf.ReadFullFrom(conn, 2)
|
||||
if err != nil && n == 0 {
|
||||
errors.LogErrorInner(ctx, err, "failed to read response length")
|
||||
noResponseErrCh <- err
|
||||
return
|
||||
}
|
||||
var length int16
|
||||
err = binary.Read(bytes.NewReader(respBuf.Bytes()), binary.BigEndian, &length)
|
||||
if err != nil {
|
||||
errors.LogErrorInner(ctx, err, "failed to parse response length")
|
||||
noResponseErrCh <- err
|
||||
return
|
||||
}
|
||||
respBuf.Clear()
|
||||
n, err = respBuf.ReadFullFrom(conn, int32(length))
|
||||
if err != nil && n == 0 {
|
||||
errors.LogErrorInner(ctx, err, "failed to read response length")
|
||||
noResponseErrCh <- err
|
||||
return
|
||||
}
|
||||
|
||||
rec, err := parseResponse(respBuf.Bytes())
|
||||
if err != nil {
|
||||
errors.LogErrorInner(ctx, err, "failed to handle response")
|
||||
noResponseErrCh <- err
|
||||
return
|
||||
}
|
||||
s.updateIP(r, rec)
|
||||
s.cacheController.updateIP(r, rec)
|
||||
}(req)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *QUICNameServer) findIPsForDomain(domain string, option dns_feature.IPOption) ([]net.IP, uint32, error) {
|
||||
s.RLock()
|
||||
record, found := s.ips[domain]
|
||||
s.RUnlock()
|
||||
|
||||
if !found {
|
||||
return nil, 0, errRecordNotFound
|
||||
}
|
||||
|
||||
var err4 error
|
||||
var err6 error
|
||||
var ips []net.Address
|
||||
var ip6 []net.Address
|
||||
var ttl uint32
|
||||
|
||||
if option.IPv4Enable {
|
||||
ips, ttl, err4 = record.A.getIPs()
|
||||
}
|
||||
|
||||
if option.IPv6Enable {
|
||||
ip6, ttl, err6 = record.AAAA.getIPs()
|
||||
ips = append(ips, ip6...)
|
||||
}
|
||||
|
||||
if len(ips) > 0 {
|
||||
netips, err := toNetIP(ips)
|
||||
return netips, ttl, err
|
||||
}
|
||||
|
||||
if err4 != nil {
|
||||
return nil, 0, err4
|
||||
}
|
||||
|
||||
if err6 != nil {
|
||||
return nil, 0, err6
|
||||
}
|
||||
|
||||
if (option.IPv4Enable && record.A != nil) || (option.IPv6Enable && record.AAAA != nil) {
|
||||
return nil, 0, dns_feature.ErrEmptyResponse
|
||||
}
|
||||
|
||||
return nil, 0, errRecordNotFound
|
||||
}
|
||||
|
||||
// QueryIP is called from dns.Server->queryIPTimeout
|
||||
func (s *QUICNameServer) QueryIP(ctx context.Context, domain string, clientIP net.IP, option dns_feature.IPOption, disableCache bool) ([]net.IP, uint32, error) {
|
||||
func (s *QUICNameServer) QueryIP(ctx context.Context, domain string, option dns_feature.IPOption) ([]net.IP, uint32, error) {
|
||||
fqdn := Fqdn(domain)
|
||||
option = ResolveIpOptionOverride(s.queryStrategy, option)
|
||||
if !option.IPv4Enable && !option.IPv6Enable {
|
||||
return nil, 0, dns_feature.ErrEmptyResponse
|
||||
}
|
||||
sub4, sub6 := s.cacheController.registerSubscribers(fqdn, option)
|
||||
defer closeSubscribers(sub4, sub6)
|
||||
|
||||
if disableCache {
|
||||
errors.LogDebug(ctx, "DNS cache is disabled. Querying IP for ", domain, " at ", s.name)
|
||||
if s.cacheController.disableCache {
|
||||
errors.LogDebug(ctx, "DNS cache is disabled. Querying IP for ", domain, " at ", s.Name())
|
||||
} else {
|
||||
ips, ttl, err := s.findIPsForDomain(fqdn, option)
|
||||
if err == nil || err == dns_feature.ErrEmptyResponse || dns_feature.RCodeFromError(err) == 3 {
|
||||
errors.LogDebugInner(ctx, err, s.name, " cache HIT ", domain, " -> ", ips)
|
||||
log.Record(&log.DNSLog{Server: s.name, Domain: domain, Result: ips, Status: log.DNSCacheHit, Elapsed: 0, Error: err})
|
||||
ips, ttl, err := s.cacheController.findIPsForDomain(fqdn, option)
|
||||
if !go_errors.Is(err, errRecordNotFound) {
|
||||
errors.LogDebugInner(ctx, err, s.Name(), " cache HIT ", domain, " -> ", ips)
|
||||
log.Record(&log.DNSLog{Server: s.Name(), Domain: domain, Result: ips, Status: log.DNSCacheHit, Elapsed: 0, Error: err})
|
||||
return ips, ttl, err
|
||||
}
|
||||
}
|
||||
|
||||
// ipv4 and ipv6 belong to different subscription groups
|
||||
var sub4, sub6 *pubsub.Subscriber
|
||||
if option.IPv4Enable {
|
||||
sub4 = s.pub.Subscribe(fqdn + "4")
|
||||
defer sub4.Close()
|
||||
}
|
||||
if option.IPv6Enable {
|
||||
sub6 = s.pub.Subscribe(fqdn + "6")
|
||||
defer sub6.Close()
|
||||
}
|
||||
done := make(chan interface{})
|
||||
go func() {
|
||||
if sub4 != nil {
|
||||
select {
|
||||
case <-sub4.Wait():
|
||||
case <-ctx.Done():
|
||||
}
|
||||
}
|
||||
if sub6 != nil {
|
||||
select {
|
||||
case <-sub6.Wait():
|
||||
case <-ctx.Done():
|
||||
}
|
||||
}
|
||||
close(done)
|
||||
}()
|
||||
s.sendQuery(ctx, fqdn, clientIP, option)
|
||||
noResponseErrCh := make(chan error, 2)
|
||||
s.sendQuery(ctx, noResponseErrCh, fqdn, option)
|
||||
start := time.Now()
|
||||
|
||||
for {
|
||||
ips, ttl, err := s.findIPsForDomain(fqdn, option)
|
||||
if err != errRecordNotFound {
|
||||
log.Record(&log.DNSLog{Server: s.name, Domain: domain, Result: ips, Status: log.DNSQueried, Elapsed: time.Since(start), Error: err})
|
||||
return ips, ttl, err
|
||||
}
|
||||
|
||||
if sub4 != nil {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, 0, ctx.Err()
|
||||
case <-done:
|
||||
case err := <-noResponseErrCh:
|
||||
return nil, 0, err
|
||||
case <-sub4.Wait():
|
||||
sub4.Close()
|
||||
}
|
||||
}
|
||||
if sub6 != nil {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, 0, ctx.Err()
|
||||
case err := <-noResponseErrCh:
|
||||
return nil, 0, err
|
||||
case <-sub6.Wait():
|
||||
sub6.Close()
|
||||
}
|
||||
}
|
||||
|
||||
ips, ttl, err := s.cacheController.findIPsForDomain(fqdn, option)
|
||||
log.Record(&log.DNSLog{Server: s.Name(), Domain: domain, Result: ips, Status: log.DNSQueried, Elapsed: time.Since(start), Error: err})
|
||||
return ips, ttl, err
|
||||
|
||||
}
|
||||
|
||||
func isActive(s quic.Connection) bool {
|
||||
|
|
|
@ -16,24 +16,23 @@ import (
|
|||
func TestQUICNameServer(t *testing.T) {
|
||||
url, err := url.Parse("quic://dns.adguard-dns.com")
|
||||
common.Must(err)
|
||||
s, err := NewQUICNameServer(url, QueryStrategy_USE_IP)
|
||||
s, err := NewQUICNameServer(url, false, net.IP(nil))
|
||||
common.Must(err)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*2)
|
||||
ips, _, err := s.QueryIP(ctx, "google.com", net.IP(nil), dns.IPOption{
|
||||
ips, _, err := s.QueryIP(ctx, "google.com", dns.IPOption{
|
||||
IPv4Enable: true,
|
||||
IPv6Enable: true,
|
||||
}, false)
|
||||
})
|
||||
cancel()
|
||||
common.Must(err)
|
||||
if len(ips) == 0 {
|
||||
t.Error("expect some ips, but got 0")
|
||||
}
|
||||
|
||||
ctx2, cancel := context.WithTimeout(context.Background(), time.Second*5)
|
||||
ips2, _, err := s.QueryIP(ctx2, "google.com", net.IP(nil), dns.IPOption{
|
||||
ips2, _, err := s.QueryIP(ctx2, "google.com", dns.IPOption{
|
||||
IPv4Enable: true,
|
||||
IPv6Enable: true,
|
||||
}, true)
|
||||
})
|
||||
cancel()
|
||||
common.Must(err)
|
||||
if r := cmp.Diff(ips2, ips); r != "" {
|
||||
|
@ -44,13 +43,13 @@ func TestQUICNameServer(t *testing.T) {
|
|||
func TestQUICNameServerWithIPv4Override(t *testing.T) {
|
||||
url, err := url.Parse("quic://dns.adguard-dns.com")
|
||||
common.Must(err)
|
||||
s, err := NewQUICNameServer(url, QueryStrategy_USE_IP4)
|
||||
s, err := NewQUICNameServer(url, false, net.IP(nil))
|
||||
common.Must(err)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*2)
|
||||
ips, _, err := s.QueryIP(ctx, "google.com", net.IP(nil), dns.IPOption{
|
||||
ips, _, err := s.QueryIP(ctx, "google.com", dns.IPOption{
|
||||
IPv4Enable: true,
|
||||
IPv6Enable: true,
|
||||
}, false)
|
||||
IPv6Enable: false,
|
||||
})
|
||||
cancel()
|
||||
common.Must(err)
|
||||
if len(ips) == 0 {
|
||||
|
@ -67,13 +66,13 @@ func TestQUICNameServerWithIPv4Override(t *testing.T) {
|
|||
func TestQUICNameServerWithIPv6Override(t *testing.T) {
|
||||
url, err := url.Parse("quic://dns.adguard-dns.com")
|
||||
common.Must(err)
|
||||
s, err := NewQUICNameServer(url, QueryStrategy_USE_IP6)
|
||||
s, err := NewQUICNameServer(url, false, net.IP(nil))
|
||||
common.Must(err)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*2)
|
||||
ips, _, err := s.QueryIP(ctx, "google.com", net.IP(nil), dns.IPOption{
|
||||
IPv4Enable: true,
|
||||
ips, _, err := s.QueryIP(ctx, "google.com", dns.IPOption{
|
||||
IPv4Enable: false,
|
||||
IPv6Enable: true,
|
||||
}, false)
|
||||
})
|
||||
cancel()
|
||||
common.Must(err)
|
||||
if len(ips) == 0 {
|
||||
|
|
|
@ -4,12 +4,11 @@ import (
|
|||
"bytes"
|
||||
"context"
|
||||
"encoding/binary"
|
||||
go_errors "errors"
|
||||
"net/url"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/xtls/xray-core/common"
|
||||
"github.com/xtls/xray-core/common/buf"
|
||||
"github.com/xtls/xray-core/common/errors"
|
||||
"github.com/xtls/xray-core/common/log"
|
||||
|
@ -17,34 +16,28 @@ import (
|
|||
"github.com/xtls/xray-core/common/net/cnc"
|
||||
"github.com/xtls/xray-core/common/protocol/dns"
|
||||
"github.com/xtls/xray-core/common/session"
|
||||
"github.com/xtls/xray-core/common/signal/pubsub"
|
||||
"github.com/xtls/xray-core/common/task"
|
||||
dns_feature "github.com/xtls/xray-core/features/dns"
|
||||
"github.com/xtls/xray-core/features/routing"
|
||||
"github.com/xtls/xray-core/transport/internet"
|
||||
"golang.org/x/net/dns/dnsmessage"
|
||||
)
|
||||
|
||||
// TCPNameServer implemented DNS over TCP (RFC7766).
|
||||
type TCPNameServer struct {
|
||||
sync.RWMutex
|
||||
name string
|
||||
destination *net.Destination
|
||||
ips map[string]*record
|
||||
pub *pubsub.Service
|
||||
cleanup *task.Periodic
|
||||
reqID uint32
|
||||
dial func(context.Context) (net.Conn, error)
|
||||
queryStrategy QueryStrategy
|
||||
cacheController *CacheController
|
||||
destination *net.Destination
|
||||
reqID uint32
|
||||
dial func(context.Context) (net.Conn, error)
|
||||
clientIP net.IP
|
||||
}
|
||||
|
||||
// NewTCPNameServer creates DNS over TCP server object for remote resolving.
|
||||
func NewTCPNameServer(
|
||||
url *url.URL,
|
||||
dispatcher routing.Dispatcher,
|
||||
queryStrategy QueryStrategy,
|
||||
disableCache bool,
|
||||
clientIP net.IP,
|
||||
) (*TCPNameServer, error) {
|
||||
s, err := baseTCPNameServer(url, "TCP", queryStrategy)
|
||||
s, err := baseTCPNameServer(url, "TCP", disableCache, clientIP)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -65,8 +58,8 @@ func NewTCPNameServer(
|
|||
}
|
||||
|
||||
// NewTCPLocalNameServer creates DNS over TCP client object for local resolving
|
||||
func NewTCPLocalNameServer(url *url.URL, queryStrategy QueryStrategy) (*TCPNameServer, error) {
|
||||
s, err := baseTCPNameServer(url, "TCPL", queryStrategy)
|
||||
func NewTCPLocalNameServer(url *url.URL, disableCache bool, clientIP net.IP) (*TCPNameServer, error) {
|
||||
s, err := baseTCPNameServer(url, "TCPL", disableCache, clientIP)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -78,7 +71,7 @@ func NewTCPLocalNameServer(url *url.URL, queryStrategy QueryStrategy) (*TCPNameS
|
|||
return s, nil
|
||||
}
|
||||
|
||||
func baseTCPNameServer(url *url.URL, prefix string, queryStrategy QueryStrategy) (*TCPNameServer, error) {
|
||||
func baseTCPNameServer(url *url.URL, prefix string, disableCache bool, clientIP net.IP) (*TCPNameServer, error) {
|
||||
port := net.Port(53)
|
||||
if url.Port() != "" {
|
||||
var err error
|
||||
|
@ -89,15 +82,9 @@ func baseTCPNameServer(url *url.URL, prefix string, queryStrategy QueryStrategy)
|
|||
dest := net.TCPDestination(net.ParseAddress(url.Hostname()), port)
|
||||
|
||||
s := &TCPNameServer{
|
||||
destination: &dest,
|
||||
ips: make(map[string]*record),
|
||||
pub: pubsub.NewService(),
|
||||
name: prefix + "//" + dest.NetAddr(),
|
||||
queryStrategy: queryStrategy,
|
||||
}
|
||||
s.cleanup = &task.Periodic{
|
||||
Interval: time.Minute,
|
||||
Execute: s.Cleanup,
|
||||
cacheController: NewCacheController(prefix+"//"+dest.NetAddr(), disableCache),
|
||||
destination: &dest,
|
||||
clientIP: clientIP,
|
||||
}
|
||||
|
||||
return s, nil
|
||||
|
@ -105,94 +92,17 @@ func baseTCPNameServer(url *url.URL, prefix string, queryStrategy QueryStrategy)
|
|||
|
||||
// Name implements Server.
|
||||
func (s *TCPNameServer) Name() string {
|
||||
return s.name
|
||||
}
|
||||
|
||||
// Cleanup clears expired items from cache
|
||||
func (s *TCPNameServer) Cleanup() error {
|
||||
now := time.Now()
|
||||
s.Lock()
|
||||
defer s.Unlock()
|
||||
|
||||
if len(s.ips) == 0 {
|
||||
return errors.New("nothing to do. stopping...")
|
||||
}
|
||||
|
||||
for domain, record := range s.ips {
|
||||
if record.A != nil && record.A.Expire.Before(now) {
|
||||
record.A = nil
|
||||
}
|
||||
if record.AAAA != nil && record.AAAA.Expire.Before(now) {
|
||||
record.AAAA = nil
|
||||
}
|
||||
|
||||
if record.A == nil && record.AAAA == nil {
|
||||
errors.LogDebug(context.Background(), s.name, " cleanup ", domain)
|
||||
delete(s.ips, domain)
|
||||
} else {
|
||||
s.ips[domain] = record
|
||||
}
|
||||
}
|
||||
|
||||
if len(s.ips) == 0 {
|
||||
s.ips = make(map[string]*record)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *TCPNameServer) updateIP(req *dnsRequest, ipRec *IPRecord) {
|
||||
elapsed := time.Since(req.start)
|
||||
|
||||
s.Lock()
|
||||
rec, found := s.ips[req.domain]
|
||||
if !found {
|
||||
rec = &record{}
|
||||
}
|
||||
updated := false
|
||||
|
||||
switch req.reqType {
|
||||
case dnsmessage.TypeA:
|
||||
if isNewer(rec.A, ipRec) {
|
||||
rec.A = ipRec
|
||||
updated = true
|
||||
}
|
||||
case dnsmessage.TypeAAAA:
|
||||
addr := make([]net.Address, 0)
|
||||
for _, ip := range ipRec.IP {
|
||||
if len(ip.IP()) == net.IPv6len {
|
||||
addr = append(addr, ip)
|
||||
}
|
||||
}
|
||||
ipRec.IP = addr
|
||||
if isNewer(rec.AAAA, ipRec) {
|
||||
rec.AAAA = ipRec
|
||||
updated = true
|
||||
}
|
||||
}
|
||||
errors.LogInfo(context.Background(), s.name, " got answer: ", req.domain, " ", req.reqType, " -> ", ipRec.IP, " ", elapsed)
|
||||
|
||||
if updated {
|
||||
s.ips[req.domain] = rec
|
||||
}
|
||||
switch req.reqType {
|
||||
case dnsmessage.TypeA:
|
||||
s.pub.Publish(req.domain+"4", nil)
|
||||
case dnsmessage.TypeAAAA:
|
||||
s.pub.Publish(req.domain+"6", nil)
|
||||
}
|
||||
s.Unlock()
|
||||
common.Must(s.cleanup.Start())
|
||||
return s.cacheController.name
|
||||
}
|
||||
|
||||
func (s *TCPNameServer) newReqID() uint16 {
|
||||
return uint16(atomic.AddUint32(&s.reqID, 1))
|
||||
}
|
||||
|
||||
func (s *TCPNameServer) sendQuery(ctx context.Context, domain string, clientIP net.IP, option dns_feature.IPOption) {
|
||||
errors.LogDebug(ctx, s.name, " querying DNS for: ", domain)
|
||||
func (s *TCPNameServer) sendQuery(ctx context.Context, noResponseErrCh chan<- error, domain string, option dns_feature.IPOption) {
|
||||
errors.LogDebug(ctx, s.Name(), " querying DNS for: ", domain)
|
||||
|
||||
reqs := buildReqMsgs(domain, option, s.newReqID, genEDNS0Options(clientIP, 0))
|
||||
reqs := buildReqMsgs(domain, option, s.newReqID, genEDNS0Options(s.clientIP, 0))
|
||||
|
||||
var deadline time.Time
|
||||
if d, ok := ctx.Deadline(); ok {
|
||||
|
@ -221,23 +131,36 @@ func (s *TCPNameServer) sendQuery(ctx context.Context, domain string, clientIP n
|
|||
b, err := dns.PackMessage(r.msg)
|
||||
if err != nil {
|
||||
errors.LogErrorInner(ctx, err, "failed to pack dns query")
|
||||
noResponseErrCh <- err
|
||||
return
|
||||
}
|
||||
|
||||
conn, err := s.dial(dnsCtx)
|
||||
if err != nil {
|
||||
errors.LogErrorInner(ctx, err, "failed to dial namesever")
|
||||
noResponseErrCh <- err
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
dnsReqBuf := buf.New()
|
||||
binary.Write(dnsReqBuf, binary.BigEndian, uint16(b.Len()))
|
||||
dnsReqBuf.Write(b.Bytes())
|
||||
err = binary.Write(dnsReqBuf, binary.BigEndian, uint16(b.Len()))
|
||||
if err != nil {
|
||||
errors.LogErrorInner(ctx, err, "binary write failed")
|
||||
noResponseErrCh <- err
|
||||
return
|
||||
}
|
||||
_, err = dnsReqBuf.Write(b.Bytes())
|
||||
if err != nil {
|
||||
errors.LogErrorInner(ctx, err, "buffer write failed")
|
||||
noResponseErrCh <- err
|
||||
return
|
||||
}
|
||||
b.Release()
|
||||
|
||||
_, err = conn.Write(dnsReqBuf.Bytes())
|
||||
if err != nil {
|
||||
errors.LogErrorInner(ctx, err, "failed to send query")
|
||||
noResponseErrCh <- err
|
||||
return
|
||||
}
|
||||
dnsReqBuf.Release()
|
||||
|
@ -247,131 +170,80 @@ func (s *TCPNameServer) sendQuery(ctx context.Context, domain string, clientIP n
|
|||
n, err := respBuf.ReadFullFrom(conn, 2)
|
||||
if err != nil && n == 0 {
|
||||
errors.LogErrorInner(ctx, err, "failed to read response length")
|
||||
noResponseErrCh <- err
|
||||
return
|
||||
}
|
||||
var length int16
|
||||
err = binary.Read(bytes.NewReader(respBuf.Bytes()), binary.BigEndian, &length)
|
||||
if err != nil {
|
||||
errors.LogErrorInner(ctx, err, "failed to parse response length")
|
||||
noResponseErrCh <- err
|
||||
return
|
||||
}
|
||||
respBuf.Clear()
|
||||
n, err = respBuf.ReadFullFrom(conn, int32(length))
|
||||
if err != nil && n == 0 {
|
||||
errors.LogErrorInner(ctx, err, "failed to read response length")
|
||||
noResponseErrCh <- err
|
||||
return
|
||||
}
|
||||
|
||||
rec, err := parseResponse(respBuf.Bytes())
|
||||
if err != nil {
|
||||
errors.LogErrorInner(ctx, err, "failed to parse DNS over TCP response")
|
||||
noResponseErrCh <- err
|
||||
return
|
||||
}
|
||||
|
||||
s.updateIP(r, rec)
|
||||
s.cacheController.updateIP(r, rec)
|
||||
}(req)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *TCPNameServer) findIPsForDomain(domain string, option dns_feature.IPOption) ([]net.IP, uint32, error) {
|
||||
s.RLock()
|
||||
record, found := s.ips[domain]
|
||||
s.RUnlock()
|
||||
|
||||
if !found {
|
||||
return nil, 0, errRecordNotFound
|
||||
}
|
||||
|
||||
var err4 error
|
||||
var err6 error
|
||||
var ips []net.Address
|
||||
var ip6 []net.Address
|
||||
var ttl uint32
|
||||
|
||||
if option.IPv4Enable {
|
||||
ips, ttl, err4 = record.A.getIPs()
|
||||
}
|
||||
|
||||
if option.IPv6Enable {
|
||||
ip6, ttl, err6 = record.AAAA.getIPs()
|
||||
ips = append(ips, ip6...)
|
||||
}
|
||||
|
||||
if len(ips) > 0 {
|
||||
netips, err := toNetIP(ips)
|
||||
return netips, ttl, err
|
||||
}
|
||||
|
||||
if err4 != nil {
|
||||
return nil, 0, err4
|
||||
}
|
||||
|
||||
if err6 != nil {
|
||||
return nil, 0, err6
|
||||
}
|
||||
|
||||
return nil, 0, dns_feature.ErrEmptyResponse
|
||||
}
|
||||
|
||||
// QueryIP implements Server.
|
||||
func (s *TCPNameServer) QueryIP(ctx context.Context, domain string, clientIP net.IP, option dns_feature.IPOption, disableCache bool) ([]net.IP, uint32, error) {
|
||||
func (s *TCPNameServer) QueryIP(ctx context.Context, domain string, option dns_feature.IPOption) ([]net.IP, uint32, error) {
|
||||
fqdn := Fqdn(domain)
|
||||
option = ResolveIpOptionOverride(s.queryStrategy, option)
|
||||
if !option.IPv4Enable && !option.IPv6Enable {
|
||||
return nil, 0, dns_feature.ErrEmptyResponse
|
||||
}
|
||||
sub4, sub6 := s.cacheController.registerSubscribers(fqdn, option)
|
||||
defer closeSubscribers(sub4, sub6)
|
||||
|
||||
if disableCache {
|
||||
errors.LogDebug(ctx, "DNS cache is disabled. Querying IP for ", domain, " at ", s.name)
|
||||
if s.cacheController.disableCache {
|
||||
errors.LogDebug(ctx, "DNS cache is disabled. Querying IP for ", domain, " at ", s.Name())
|
||||
} else {
|
||||
ips, ttl, err := s.findIPsForDomain(fqdn, option)
|
||||
if err == nil || err == dns_feature.ErrEmptyResponse || dns_feature.RCodeFromError(err) == 3 {
|
||||
errors.LogDebugInner(ctx, err, s.name, " cache HIT ", domain, " -> ", ips)
|
||||
log.Record(&log.DNSLog{Server: s.name, Domain: domain, Result: ips, Status: log.DNSCacheHit, Elapsed: 0, Error: err})
|
||||
ips, ttl, err := s.cacheController.findIPsForDomain(fqdn, option)
|
||||
if !go_errors.Is(err, errRecordNotFound) {
|
||||
errors.LogDebugInner(ctx, err, s.Name(), " cache HIT ", domain, " -> ", ips)
|
||||
log.Record(&log.DNSLog{Server: s.Name(), Domain: domain, Result: ips, Status: log.DNSCacheHit, Elapsed: 0, Error: err})
|
||||
return ips, ttl, err
|
||||
}
|
||||
}
|
||||
|
||||
// ipv4 and ipv6 belong to different subscription groups
|
||||
var sub4, sub6 *pubsub.Subscriber
|
||||
if option.IPv4Enable {
|
||||
sub4 = s.pub.Subscribe(fqdn + "4")
|
||||
defer sub4.Close()
|
||||
}
|
||||
if option.IPv6Enable {
|
||||
sub6 = s.pub.Subscribe(fqdn + "6")
|
||||
defer sub6.Close()
|
||||
}
|
||||
done := make(chan interface{})
|
||||
go func() {
|
||||
if sub4 != nil {
|
||||
select {
|
||||
case <-sub4.Wait():
|
||||
case <-ctx.Done():
|
||||
}
|
||||
}
|
||||
if sub6 != nil {
|
||||
select {
|
||||
case <-sub6.Wait():
|
||||
case <-ctx.Done():
|
||||
}
|
||||
}
|
||||
close(done)
|
||||
}()
|
||||
s.sendQuery(ctx, fqdn, clientIP, option)
|
||||
noResponseErrCh := make(chan error, 2)
|
||||
s.sendQuery(ctx, noResponseErrCh, fqdn, option)
|
||||
start := time.Now()
|
||||
|
||||
for {
|
||||
ips, ttl, err := s.findIPsForDomain(fqdn, option)
|
||||
if err != errRecordNotFound {
|
||||
log.Record(&log.DNSLog{Server: s.name, Domain: domain, Result: ips, Status: log.DNSQueried, Elapsed: time.Since(start), Error: err})
|
||||
return ips, ttl, err
|
||||
}
|
||||
|
||||
if sub4 != nil {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, 0, ctx.Err()
|
||||
case <-done:
|
||||
case err := <-noResponseErrCh:
|
||||
return nil, 0, err
|
||||
case <-sub4.Wait():
|
||||
sub4.Close()
|
||||
}
|
||||
}
|
||||
if sub6 != nil {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, 0, ctx.Err()
|
||||
case err := <-noResponseErrCh:
|
||||
return nil, 0, err
|
||||
case <-sub6.Wait():
|
||||
sub6.Close()
|
||||
}
|
||||
}
|
||||
|
||||
ips, ttl, err := s.cacheController.findIPsForDomain(fqdn, option)
|
||||
log.Record(&log.DNSLog{Server: s.Name(), Domain: domain, Result: ips, Status: log.DNSQueried, Elapsed: time.Since(start), Error: err})
|
||||
return ips, ttl, err
|
||||
|
||||
}
|
||||
|
|
|
@ -16,13 +16,13 @@ import (
|
|||
func TestTCPLocalNameServer(t *testing.T) {
|
||||
url, err := url.Parse("tcp+local://8.8.8.8")
|
||||
common.Must(err)
|
||||
s, err := NewTCPLocalNameServer(url, QueryStrategy_USE_IP)
|
||||
s, err := NewTCPLocalNameServer(url, false, net.IP(nil))
|
||||
common.Must(err)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
|
||||
ips, _, err := s.QueryIP(ctx, "google.com", net.IP(nil), dns_feature.IPOption{
|
||||
ips, _, err := s.QueryIP(ctx, "google.com", dns_feature.IPOption{
|
||||
IPv4Enable: true,
|
||||
IPv6Enable: true,
|
||||
}, false)
|
||||
})
|
||||
cancel()
|
||||
common.Must(err)
|
||||
if len(ips) == 0 {
|
||||
|
@ -33,13 +33,13 @@ func TestTCPLocalNameServer(t *testing.T) {
|
|||
func TestTCPLocalNameServerWithCache(t *testing.T) {
|
||||
url, err := url.Parse("tcp+local://8.8.8.8")
|
||||
common.Must(err)
|
||||
s, err := NewTCPLocalNameServer(url, QueryStrategy_USE_IP)
|
||||
s, err := NewTCPLocalNameServer(url, false, net.IP(nil))
|
||||
common.Must(err)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
|
||||
ips, _, err := s.QueryIP(ctx, "google.com", net.IP(nil), dns_feature.IPOption{
|
||||
ips, _, err := s.QueryIP(ctx, "google.com", dns_feature.IPOption{
|
||||
IPv4Enable: true,
|
||||
IPv6Enable: true,
|
||||
}, false)
|
||||
})
|
||||
cancel()
|
||||
common.Must(err)
|
||||
if len(ips) == 0 {
|
||||
|
@ -47,10 +47,10 @@ func TestTCPLocalNameServerWithCache(t *testing.T) {
|
|||
}
|
||||
|
||||
ctx2, cancel := context.WithTimeout(context.Background(), time.Second*5)
|
||||
ips2, _, err := s.QueryIP(ctx2, "google.com", net.IP(nil), dns_feature.IPOption{
|
||||
ips2, _, err := s.QueryIP(ctx2, "google.com", dns_feature.IPOption{
|
||||
IPv4Enable: true,
|
||||
IPv6Enable: true,
|
||||
}, true)
|
||||
})
|
||||
cancel()
|
||||
common.Must(err)
|
||||
if r := cmp.Diff(ips2, ips); r != "" {
|
||||
|
@ -61,13 +61,13 @@ func TestTCPLocalNameServerWithCache(t *testing.T) {
|
|||
func TestTCPLocalNameServerWithIPv4Override(t *testing.T) {
|
||||
url, err := url.Parse("tcp+local://8.8.8.8")
|
||||
common.Must(err)
|
||||
s, err := NewTCPLocalNameServer(url, QueryStrategy_USE_IP4)
|
||||
s, err := NewTCPLocalNameServer(url, false, net.IP(nil))
|
||||
common.Must(err)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
|
||||
ips, _, err := s.QueryIP(ctx, "google.com", net.IP(nil), dns_feature.IPOption{
|
||||
ips, _, err := s.QueryIP(ctx, "google.com", dns_feature.IPOption{
|
||||
IPv4Enable: true,
|
||||
IPv6Enable: true,
|
||||
}, false)
|
||||
IPv6Enable: false,
|
||||
})
|
||||
cancel()
|
||||
common.Must(err)
|
||||
|
||||
|
@ -85,13 +85,13 @@ func TestTCPLocalNameServerWithIPv4Override(t *testing.T) {
|
|||
func TestTCPLocalNameServerWithIPv6Override(t *testing.T) {
|
||||
url, err := url.Parse("tcp+local://8.8.8.8")
|
||||
common.Must(err)
|
||||
s, err := NewTCPLocalNameServer(url, QueryStrategy_USE_IP6)
|
||||
s, err := NewTCPLocalNameServer(url, false, net.IP(nil))
|
||||
common.Must(err)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
|
||||
ips, _, err := s.QueryIP(ctx, "google.com", net.IP(nil), dns_feature.IPOption{
|
||||
IPv4Enable: true,
|
||||
ips, _, err := s.QueryIP(ctx, "google.com", dns_feature.IPOption{
|
||||
IPv4Enable: false,
|
||||
IPv6Enable: true,
|
||||
}, false)
|
||||
})
|
||||
cancel()
|
||||
common.Must(err)
|
||||
|
||||
|
|
|
@ -2,6 +2,7 @@ package dns
|
|||
|
||||
import (
|
||||
"context"
|
||||
go_errors "errors"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
@ -13,7 +14,6 @@ import (
|
|||
"github.com/xtls/xray-core/common/net"
|
||||
"github.com/xtls/xray-core/common/protocol/dns"
|
||||
udp_proto "github.com/xtls/xray-core/common/protocol/udp"
|
||||
"github.com/xtls/xray-core/common/signal/pubsub"
|
||||
"github.com/xtls/xray-core/common/task"
|
||||
dns_feature "github.com/xtls/xray-core/features/dns"
|
||||
"github.com/xtls/xray-core/features/routing"
|
||||
|
@ -24,15 +24,13 @@ import (
|
|||
// ClassicNameServer implemented traditional UDP DNS.
|
||||
type ClassicNameServer struct {
|
||||
sync.RWMutex
|
||||
name string
|
||||
address *net.Destination
|
||||
ips map[string]*record
|
||||
requests map[uint16]*udpDnsRequest
|
||||
pub *pubsub.Service
|
||||
udpServer *udp.Dispatcher
|
||||
cleanup *task.Periodic
|
||||
reqID uint32
|
||||
queryStrategy QueryStrategy
|
||||
cacheController *CacheController
|
||||
address *net.Destination
|
||||
requests map[uint16]*udpDnsRequest
|
||||
udpServer *udp.Dispatcher
|
||||
requestsCleanup *task.Periodic
|
||||
reqID uint32
|
||||
clientIP net.IP
|
||||
}
|
||||
|
||||
type udpDnsRequest struct {
|
||||
|
@ -41,23 +39,21 @@ type udpDnsRequest struct {
|
|||
}
|
||||
|
||||
// NewClassicNameServer creates udp server object for remote resolving.
|
||||
func NewClassicNameServer(address net.Destination, dispatcher routing.Dispatcher, queryStrategy QueryStrategy) *ClassicNameServer {
|
||||
func NewClassicNameServer(address net.Destination, dispatcher routing.Dispatcher, disableCache bool, clientIP net.IP) *ClassicNameServer {
|
||||
// default to 53 if unspecific
|
||||
if address.Port == 0 {
|
||||
address.Port = net.Port(53)
|
||||
}
|
||||
|
||||
s := &ClassicNameServer{
|
||||
address: &address,
|
||||
ips: make(map[string]*record),
|
||||
requests: make(map[uint16]*udpDnsRequest),
|
||||
pub: pubsub.NewService(),
|
||||
name: strings.ToUpper(address.String()),
|
||||
queryStrategy: queryStrategy,
|
||||
cacheController: NewCacheController(strings.ToUpper(address.String()), disableCache),
|
||||
address: &address,
|
||||
requests: make(map[uint16]*udpDnsRequest),
|
||||
clientIP: clientIP,
|
||||
}
|
||||
s.cleanup = &task.Periodic{
|
||||
s.requestsCleanup = &task.Periodic{
|
||||
Interval: time.Minute,
|
||||
Execute: s.Cleanup,
|
||||
Execute: s.RequestsCleanup,
|
||||
}
|
||||
s.udpServer = udp.NewDispatcher(dispatcher, s.HandleResponse)
|
||||
errors.LogInfo(context.Background(), "DNS: created UDP client initialized for ", address.NetAddr())
|
||||
|
@ -66,37 +62,17 @@ func NewClassicNameServer(address net.Destination, dispatcher routing.Dispatcher
|
|||
|
||||
// Name implements Server.
|
||||
func (s *ClassicNameServer) Name() string {
|
||||
return s.name
|
||||
return s.cacheController.name
|
||||
}
|
||||
|
||||
// Cleanup clears expired items from cache
|
||||
func (s *ClassicNameServer) Cleanup() error {
|
||||
// RequestsCleanup clears expired items from cache
|
||||
func (s *ClassicNameServer) RequestsCleanup() error {
|
||||
now := time.Now()
|
||||
s.Lock()
|
||||
defer s.Unlock()
|
||||
|
||||
if len(s.ips) == 0 && len(s.requests) == 0 {
|
||||
return errors.New(s.name, " nothing to do. stopping...")
|
||||
}
|
||||
|
||||
for domain, record := range s.ips {
|
||||
if record.A != nil && record.A.Expire.Before(now) {
|
||||
record.A = nil
|
||||
}
|
||||
if record.AAAA != nil && record.AAAA.Expire.Before(now) {
|
||||
record.AAAA = nil
|
||||
}
|
||||
|
||||
if record.A == nil && record.AAAA == nil {
|
||||
errors.LogDebug(context.Background(), s.name, " cleanup ", domain)
|
||||
delete(s.ips, domain)
|
||||
} else {
|
||||
s.ips[domain] = record
|
||||
}
|
||||
}
|
||||
|
||||
if len(s.ips) == 0 {
|
||||
s.ips = make(map[string]*record)
|
||||
if len(s.requests) == 0 {
|
||||
return errors.New(s.Name(), " nothing to do. stopping...")
|
||||
}
|
||||
|
||||
for id, req := range s.requests {
|
||||
|
@ -116,7 +92,7 @@ func (s *ClassicNameServer) Cleanup() error {
|
|||
func (s *ClassicNameServer) HandleResponse(ctx context.Context, packet *udp_proto.Packet) {
|
||||
ipRec, err := parseResponse(packet.Payload.Bytes())
|
||||
if err != nil {
|
||||
errors.LogError(ctx, s.name, " fail to parse responded DNS udp")
|
||||
errors.LogError(ctx, s.Name(), " fail to parse responded DNS udp")
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -129,14 +105,14 @@ func (s *ClassicNameServer) HandleResponse(ctx context.Context, packet *udp_prot
|
|||
}
|
||||
s.Unlock()
|
||||
if !ok {
|
||||
errors.LogError(ctx, s.name, " cannot find the pending request")
|
||||
errors.LogError(ctx, s.Name(), " cannot find the pending request")
|
||||
return
|
||||
}
|
||||
|
||||
// if truncated, retry with EDNS0 option(udp payload size: 1350)
|
||||
if ipRec.RawHeader.Truncated {
|
||||
// if already has EDNS0 option, no need to retry
|
||||
if ok && len(req.msg.Additionals) == 0 {
|
||||
if len(req.msg.Additionals) == 0 {
|
||||
// copy necessary meta data from original request
|
||||
// and add EDNS0 option
|
||||
opt := new(dnsmessage.Resource)
|
||||
|
@ -154,51 +130,7 @@ func (s *ClassicNameServer) HandleResponse(ctx context.Context, packet *udp_prot
|
|||
}
|
||||
}
|
||||
|
||||
var rec record
|
||||
switch req.reqType {
|
||||
case dnsmessage.TypeA:
|
||||
rec.A = ipRec
|
||||
case dnsmessage.TypeAAAA:
|
||||
rec.AAAA = ipRec
|
||||
}
|
||||
|
||||
elapsed := time.Since(req.start)
|
||||
errors.LogInfo(ctx, s.name, " got answer: ", req.domain, " ", req.reqType, " -> ", ipRec.IP, " ", elapsed)
|
||||
if len(req.domain) > 0 && (rec.A != nil || rec.AAAA != nil) {
|
||||
s.updateIP(req.domain, &rec)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ClassicNameServer) updateIP(domain string, newRec *record) {
|
||||
s.Lock()
|
||||
|
||||
rec, found := s.ips[domain]
|
||||
if !found {
|
||||
rec = &record{}
|
||||
}
|
||||
|
||||
updated := false
|
||||
if isNewer(rec.A, newRec.A) {
|
||||
rec.A = newRec.A
|
||||
updated = true
|
||||
}
|
||||
if isNewer(rec.AAAA, newRec.AAAA) {
|
||||
rec.AAAA = newRec.AAAA
|
||||
updated = true
|
||||
}
|
||||
|
||||
if updated {
|
||||
errors.LogDebug(context.Background(), s.name, " updating IP records for domain:", domain)
|
||||
s.ips[domain] = rec
|
||||
}
|
||||
if newRec.A != nil {
|
||||
s.pub.Publish(domain+"4", nil)
|
||||
}
|
||||
if newRec.AAAA != nil {
|
||||
s.pub.Publish(domain+"6", nil)
|
||||
}
|
||||
s.Unlock()
|
||||
common.Must(s.cleanup.Start())
|
||||
s.cacheController.updateIP(&req.dnsRequest, ipRec)
|
||||
}
|
||||
|
||||
func (s *ClassicNameServer) newReqID() uint16 {
|
||||
|
@ -207,17 +139,17 @@ func (s *ClassicNameServer) newReqID() uint16 {
|
|||
|
||||
func (s *ClassicNameServer) addPendingRequest(req *udpDnsRequest) {
|
||||
s.Lock()
|
||||
defer s.Unlock()
|
||||
|
||||
id := req.msg.ID
|
||||
req.expire = time.Now().Add(time.Second * 8)
|
||||
s.requests[id] = req
|
||||
s.Unlock()
|
||||
common.Must(s.requestsCleanup.Start())
|
||||
}
|
||||
|
||||
func (s *ClassicNameServer) sendQuery(ctx context.Context, domain string, clientIP net.IP, option dns_feature.IPOption) {
|
||||
errors.LogDebug(ctx, s.name, " querying DNS for: ", domain)
|
||||
func (s *ClassicNameServer) sendQuery(ctx context.Context, _ chan<- error, domain string, option dns_feature.IPOption) {
|
||||
errors.LogDebug(ctx, s.Name(), " querying DNS for: ", domain)
|
||||
|
||||
reqs := buildReqMsgs(domain, option, s.newReqID, genEDNS0Options(clientIP, 0))
|
||||
reqs := buildReqMsgs(domain, option, s.newReqID, genEDNS0Options(s.clientIP, 0))
|
||||
|
||||
for _, req := range reqs {
|
||||
udpReq := &udpDnsRequest{
|
||||
|
@ -230,105 +162,50 @@ func (s *ClassicNameServer) sendQuery(ctx context.Context, domain string, client
|
|||
}
|
||||
}
|
||||
|
||||
func (s *ClassicNameServer) findIPsForDomain(domain string, option dns_feature.IPOption) ([]net.IP, uint32, error) {
|
||||
s.RLock()
|
||||
record, found := s.ips[domain]
|
||||
s.RUnlock()
|
||||
|
||||
if !found {
|
||||
return nil, 0, errRecordNotFound
|
||||
}
|
||||
|
||||
var err4 error
|
||||
var err6 error
|
||||
var ips []net.Address
|
||||
var ip6 []net.Address
|
||||
var ttl uint32
|
||||
|
||||
if option.IPv4Enable {
|
||||
ips, ttl, err4 = record.A.getIPs()
|
||||
}
|
||||
|
||||
if option.IPv6Enable {
|
||||
ip6, ttl, err6 = record.AAAA.getIPs()
|
||||
ips = append(ips, ip6...)
|
||||
}
|
||||
|
||||
if len(ips) > 0 {
|
||||
netips, err := toNetIP(ips)
|
||||
return netips, ttl, err
|
||||
}
|
||||
|
||||
if err4 != nil {
|
||||
return nil, 0, err4
|
||||
}
|
||||
|
||||
if err6 != nil {
|
||||
return nil, 0, err6
|
||||
}
|
||||
|
||||
return nil, 0, dns_feature.ErrEmptyResponse
|
||||
}
|
||||
|
||||
// QueryIP implements Server.
|
||||
func (s *ClassicNameServer) QueryIP(ctx context.Context, domain string, clientIP net.IP, option dns_feature.IPOption, disableCache bool) ([]net.IP, uint32, error) {
|
||||
func (s *ClassicNameServer) QueryIP(ctx context.Context, domain string, option dns_feature.IPOption) ([]net.IP, uint32, error) {
|
||||
fqdn := Fqdn(domain)
|
||||
option = ResolveIpOptionOverride(s.queryStrategy, option)
|
||||
if !option.IPv4Enable && !option.IPv6Enable {
|
||||
return nil, 0, dns_feature.ErrEmptyResponse
|
||||
}
|
||||
sub4, sub6 := s.cacheController.registerSubscribers(fqdn, option)
|
||||
defer closeSubscribers(sub4, sub6)
|
||||
|
||||
if disableCache {
|
||||
errors.LogDebug(ctx, "DNS cache is disabled. Querying IP for ", domain, " at ", s.name)
|
||||
if s.cacheController.disableCache {
|
||||
errors.LogDebug(ctx, "DNS cache is disabled. Querying IP for ", domain, " at ", s.Name())
|
||||
} else {
|
||||
ips, ttl, err := s.findIPsForDomain(fqdn, option)
|
||||
if err == nil || err == dns_feature.ErrEmptyResponse || dns_feature.RCodeFromError(err) == 3 {
|
||||
errors.LogDebugInner(ctx, err, s.name, " cache HIT ", domain, " -> ", ips)
|
||||
log.Record(&log.DNSLog{Server: s.name, Domain: domain, Result: ips, Status: log.DNSCacheHit, Elapsed: 0, Error: err})
|
||||
ips, ttl, err := s.cacheController.findIPsForDomain(fqdn, option)
|
||||
if !go_errors.Is(err, errRecordNotFound) {
|
||||
errors.LogDebugInner(ctx, err, s.Name(), " cache HIT ", domain, " -> ", ips)
|
||||
log.Record(&log.DNSLog{Server: s.Name(), Domain: domain, Result: ips, Status: log.DNSCacheHit, Elapsed: 0, Error: err})
|
||||
return ips, ttl, err
|
||||
}
|
||||
}
|
||||
|
||||
// ipv4 and ipv6 belong to different subscription groups
|
||||
var sub4, sub6 *pubsub.Subscriber
|
||||
if option.IPv4Enable {
|
||||
sub4 = s.pub.Subscribe(fqdn + "4")
|
||||
defer sub4.Close()
|
||||
}
|
||||
if option.IPv6Enable {
|
||||
sub6 = s.pub.Subscribe(fqdn + "6")
|
||||
defer sub6.Close()
|
||||
}
|
||||
done := make(chan interface{})
|
||||
go func() {
|
||||
if sub4 != nil {
|
||||
select {
|
||||
case <-sub4.Wait():
|
||||
case <-ctx.Done():
|
||||
}
|
||||
}
|
||||
if sub6 != nil {
|
||||
select {
|
||||
case <-sub6.Wait():
|
||||
case <-ctx.Done():
|
||||
}
|
||||
}
|
||||
close(done)
|
||||
}()
|
||||
s.sendQuery(ctx, fqdn, clientIP, option)
|
||||
noResponseErrCh := make(chan error, 2)
|
||||
s.sendQuery(ctx, noResponseErrCh, fqdn, option)
|
||||
start := time.Now()
|
||||
|
||||
for {
|
||||
ips, ttl, err := s.findIPsForDomain(fqdn, option)
|
||||
if err != errRecordNotFound {
|
||||
log.Record(&log.DNSLog{Server: s.name, Domain: domain, Result: ips, Status: log.DNSQueried, Elapsed: time.Since(start), Error: err})
|
||||
return ips, ttl, err
|
||||
}
|
||||
|
||||
if sub4 != nil {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, 0, ctx.Err()
|
||||
case <-done:
|
||||
case err := <-noResponseErrCh:
|
||||
return nil, 0, err
|
||||
case <-sub4.Wait():
|
||||
sub4.Close()
|
||||
}
|
||||
}
|
||||
if sub6 != nil {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, 0, ctx.Err()
|
||||
case err := <-noResponseErrCh:
|
||||
return nil, 0, err
|
||||
case <-sub6.Wait():
|
||||
sub6.Close()
|
||||
}
|
||||
}
|
||||
|
||||
ips, ttl, err := s.cacheController.findIPsForDomain(fqdn, option)
|
||||
log.Record(&log.DNSLog{Server: s.Name(), Domain: domain, Result: ips, Status: log.DNSQueried, Elapsed: time.Since(start), Error: err})
|
||||
return ips, ttl, err
|
||||
|
||||
}
|
||||
|
|
|
@ -119,7 +119,7 @@ type MultiGeoIPMatcher struct {
|
|||
func NewMultiGeoIPMatcher(geoips []*GeoIP, onSource bool) (*MultiGeoIPMatcher, error) {
|
||||
var matchers []*GeoIPMatcher
|
||||
for _, geoip := range geoips {
|
||||
matcher, err := globalGeoIPContainer.Add(geoip)
|
||||
matcher, err := GlobalGeoIPContainer.Add(geoip)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -115,4 +115,4 @@ func (c *GeoIPMatcherContainer) Add(geoip *GeoIP) (*GeoIPMatcher, error) {
|
|||
return m, nil
|
||||
}
|
||||
|
||||
var globalGeoIPContainer GeoIPMatcherContainer
|
||||
var GlobalGeoIPContainer GeoIPMatcherContainer
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package errors
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strings"
|
||||
)
|
||||
|
||||
|
@ -36,12 +37,12 @@ func AllEqual(expected error, actual error) bool {
|
|||
return false
|
||||
}
|
||||
for _, err := range errs {
|
||||
if err != expected {
|
||||
if !errors.Is(err, expected) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
default:
|
||||
return errs == expected
|
||||
return errors.Is(errs, expected)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -38,6 +38,8 @@ func ClientType() interface{} {
|
|||
// ErrEmptyResponse indicates that DNS query succeeded but no answer was returned.
|
||||
var ErrEmptyResponse = errors.New("empty response")
|
||||
|
||||
const DefaultTTL = 300
|
||||
|
||||
type RCodeError uint16
|
||||
|
||||
func (e RCodeError) Error() string {
|
||||
|
|
|
@ -30,29 +30,31 @@ func (*Client) LookupIP(host string, option dns.IPOption) ([]net.IP, uint32, err
|
|||
ipv6 := make([]net.IP, 0, len(ips))
|
||||
for _, ip := range ips {
|
||||
parsed := net.IPAddress(ip)
|
||||
if parsed != nil {
|
||||
parsedIPs = append(parsedIPs, parsed.IP())
|
||||
if parsed == nil {
|
||||
continue
|
||||
}
|
||||
if len(ip) == net.IPv4len {
|
||||
ipv4 = append(ipv4, ip)
|
||||
}
|
||||
if len(ip) == net.IPv6len {
|
||||
ipv6 = append(ipv6, ip)
|
||||
parsedIP := parsed.IP()
|
||||
parsedIPs = append(parsedIPs, parsedIP)
|
||||
|
||||
if len(parsedIP) == net.IPv4len {
|
||||
ipv4 = append(ipv4, parsedIP)
|
||||
} else {
|
||||
ipv6 = append(ipv6, parsedIP)
|
||||
}
|
||||
}
|
||||
// Local DNS ttl is 600
|
||||
|
||||
switch {
|
||||
case option.IPv4Enable && option.IPv6Enable:
|
||||
if len(parsedIPs) > 0 {
|
||||
return parsedIPs, 600, nil
|
||||
return parsedIPs, dns.DefaultTTL, nil
|
||||
}
|
||||
case option.IPv4Enable:
|
||||
if len(ipv4) > 0 {
|
||||
return ipv4, 600, nil
|
||||
return ipv4, dns.DefaultTTL, nil
|
||||
}
|
||||
case option.IPv6Enable:
|
||||
if len(ipv6) > 0 {
|
||||
return ipv6, 600, nil
|
||||
return ipv6, dns.DefaultTTL, nil
|
||||
}
|
||||
}
|
||||
return nil, 0, dns.ErrEmptyResponse
|
||||
|
|
|
@ -2,6 +2,7 @@ package dns
|
|||
|
||||
import (
|
||||
"context"
|
||||
go_errors "errors"
|
||||
"io"
|
||||
"sync"
|
||||
"time"
|
||||
|
@ -255,7 +256,7 @@ func (h *Handler) handleIPQuery(id uint16, qType dnsmessage.Type, domain string,
|
|||
}
|
||||
|
||||
rcode := dns.RCodeFromError(err)
|
||||
if rcode == 0 && len(ips) == 0 && !errors.AllEqual(dns.ErrEmptyResponse, errors.Cause(err)) {
|
||||
if rcode == 0 && len(ips) == 0 && !go_errors.Is(err, dns.ErrEmptyResponse) {
|
||||
errors.LogInfoInner(context.Background(), err, "ip query")
|
||||
return
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue