package dns

import (
	"context"
	"io"
	"sync"
	"time"

	"github.com/xtls/xray-core/common"
	"github.com/xtls/xray-core/common/buf"
	"github.com/xtls/xray-core/common/errors"
	"github.com/xtls/xray-core/common/net"
	dns_proto "github.com/xtls/xray-core/common/protocol/dns"
	"github.com/xtls/xray-core/common/session"
	"github.com/xtls/xray-core/common/signal"
	"github.com/xtls/xray-core/common/task"
	"github.com/xtls/xray-core/core"
	"github.com/xtls/xray-core/features/dns"
	"github.com/xtls/xray-core/features/policy"
	"github.com/xtls/xray-core/transport"
	"github.com/xtls/xray-core/transport/internet"
	"github.com/xtls/xray-core/transport/internet/stat"
	"golang.org/x/net/dns/dnsmessage"
)

// init registers the constructor that builds a Handler from a *Config.
func init() {
	common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
		h := new(Handler)
		if err := core.RequireFeatures(ctx, func(dnsClient dns.Client, policyManager policy.Manager) error {
			// The fake DNS engine is optional: the result of this inner call is
			// intentionally ignored, and h.fdns simply stays nil when the
			// callback never runs.
			core.RequireFeatures(ctx, func(fdns dns.FakeDNSEngine) {
				h.fdns = fdns
			})
			return h.Init(config.(*Config), dnsClient, policyManager)
		}); err != nil {
			return nil, err
		}
		return h, nil
	}))
}

// ownLinkVerifier is an optional capability of the dns.Client handed to Init.
// NOTE(review): presumably IsOwnLink reports whether ctx belongs to a
// connection opened by the DNS client itself; confirm against the implementer.
type ownLinkVerifier interface {
	IsOwnLink(ctx context.Context) bool
}

// Handler is the DNS outbound. It reads DNS messages from the inbound link,
// answers A/AAAA queries through the configured dns.Client, and forwards or
// drops every other message.
type Handler struct {
	// client resolves A/AAAA queries in handleIPQuery.
	client dns.Client
	// fdns is the optional fake DNS engine; nil when none was provided.
	fdns dns.FakeDNSEngine
	// ownLinkVerifier is set when client implements ownLinkVerifier.
	ownLinkVerifier ownLinkVerifier
	// server overrides the network, address and/or port of the outbound
	// target; zero-valued fields leave the corresponding target field as is.
	server net.Destination
	// timeout is the inactivity timeout applied to one Process call.
	timeout time.Duration
	// nonIPQuery selects the handling of non-A/AAAA queries; the value "drop"
	// discards them instead of forwarding them to the server.
	nonIPQuery string
	// blockTypes lists DNS query types that terminate the request loop.
	blockTypes []int32
}

// Init populates the handler from config, the DNS client and the policy
// manager. It always returns nil.
func (h *Handler) Init(config *Config, dnsClient dns.Client, policyManager policy.Manager) error {
	h.client = dnsClient
	h.timeout = policyManager.ForLevel(config.UserLevel).Timeouts.ConnectionIdle

	if v, ok := dnsClient.(ownLinkVerifier); ok {
		h.ownLinkVerifier = v
	}

	if config.Server != nil {
		h.server = config.Server.AsDestination()
	}
	h.nonIPQuery = config.Non_IPQuery
	h.blockTypes = config.BlockTypes
	return nil
}

// isOwnLink reports whether the DNS client claims ctx as its own link. It is
// false when the client does not implement ownLinkVerifier.
func (h *Handler) isOwnLink(ctx context.Context) bool {
	return h.ownLinkVerifier != nil && h.ownLinkVerifier.IsOwnLink(ctx)
}

// parseIPQuery parses the header and first question of the DNS message in b.
//
// r is true only when the question type is A or AAAA. domain, id and qType
// carry whatever was parsed before a failure: all are zero if the header is
// unreadable, and only id is set if the question is unreadable. Parse errors
// are logged, not returned.
func parseIPQuery(b []byte) (r bool, domain string, id uint16, qType dnsmessage.Type) {
	var parser dnsmessage.Parser
	header, err := parser.Start(b)
	if err != nil {
		errors.LogInfoInner(context.Background(), err, "parser start")
		return
	}

	id = header.ID
	q, err := parser.Question()
	if err != nil {
		errors.LogInfoInner(context.Background(), err, "question")
		return
	}
	domain = q.Name.String()
	qType = q.Type
	if qType != dnsmessage.TypeA && qType != dnsmessage.TypeAAAA {
		return
	}

	r = true
	return
}

// Process implements proxy.Outbound.
//
// It runs two loops until one fails, the link reaches EOF, or the inactivity
// timer fires: a request loop reading DNS messages from link, and a response
// loop copying messages from the upstream connection back to link. The
// upstream connection is dialed lazily on the first forwarded message.
func (h *Handler) Process(ctx context.Context, link *transport.Link, d internet.Dialer) error {
	outbounds := session.OutboundsFromContext(ctx)
	// Guard the index below: an empty slice would otherwise panic.
	if len(outbounds) == 0 {
		return errors.New("invalid outbound")
	}
	ob := outbounds[len(outbounds)-1]
	if !ob.Target.IsValid() {
		return errors.New("invalid outbound")
	}
	ob.Name = "dns"

	// The inbound side keeps the original network even if the server override
	// changes the network used upstream.
	srcNetwork := ob.Target.Network

	dest := ob.Target
	if h.server.Network != net.Network_Unknown {
		dest.Network = h.server.Network
	}
	if h.server.Address != nil {
		dest.Address = h.server.Address
	}
	if h.server.Port != 0 {
		dest.Port = h.server.Port
	}

	errors.LogInfo(ctx, "handling DNS traffic to ", dest)

	conn := &outboundConn{
		// The closure reads ctx when it is invoked, so it sees the context
		// that is in effect after the reassignments further down.
		dialer: func() (stat.Connection, error) {
			return d.Dial(ctx, dest)
		},
		connReady: make(chan struct{}, 1),
	}

	var reader dns_proto.MessageReader
	var writer dns_proto.MessageWriter
	if srcNetwork == net.Network_TCP {
		reader = dns_proto.NewTCPReader(link.Reader)
		writer = &dns_proto.TCPWriter{
			Writer: link.Writer,
		}
	} else {
		reader = &dns_proto.UDPReader{
			Reader: link.Reader,
		}
		writer = &dns_proto.UDPWriter{
			Writer: link.Writer,
		}
	}

	var connReader dns_proto.MessageReader
	var connWriter dns_proto.MessageWriter
	if dest.Network == net.Network_TCP {
		connReader = dns_proto.NewTCPReader(buf.NewReader(conn))
		connWriter = &dns_proto.TCPWriter{
			Writer: buf.NewWriter(conn),
		}
	} else {
		connReader = &dns_proto.UDPReader{
			Reader: buf.NewPacketReader(conn),
		}
		connWriter = &dns_proto.UDPWriter{
			Writer: buf.NewWriter(conn),
		}
	}

	if session.TimeoutOnlyFromContext(ctx) {
		// Detach from the caller's context so that only the inactivity timer
		// below can end this session.
		ctx = context.Background()
	}

	ctx, cancel := context.WithCancel(ctx)
	timer := signal.CancelAfterInactivity(ctx, cancel, h.timeout)

	request := func() error {
		defer conn.Close()

		for {
			b, err := reader.ReadMessage()
			if err == io.EOF {
				return nil
			}
			if err != nil {
				return err
			}

			timer.Update()

			if !h.isOwnLink(ctx) {
				isIPQuery, domain, id, qType := parseIPQuery(b.Bytes())
				for _, blocktype := range h.blockTypes {
					if blocktype == int32(qType) {
						// The message is neither answered nor forwarded, so
						// hand the buffer back before leaving the loop.
						b.Release()
						errors.LogInfo(ctx, "blocked type ", qType, " query for domain ", domain)
						return nil
					}
				}
				if isIPQuery {
					// Answered locally and asynchronously; the reply is
					// written straight to the inbound writer.
					go h.handleIPQuery(id, qType, domain, writer)
				}
				if isIPQuery || h.nonIPQuery == "drop" {
					b.Release()
					continue
				}
			}

			if err := connWriter.WriteMessage(b); err != nil {
				return err
			}
		}
	}

	response := func() error {
		for {
			b, err := connReader.ReadMessage()
			if err == io.EOF {
				return nil
			}
			if err != nil {
				return err
			}

			timer.Update()

			if err := writer.WriteMessage(b); err != nil {
				return err
			}
		}
	}

	if err := task.Run(ctx, request, response); err != nil {
		return errors.New("connection ends").Base(err)
	}

	return nil
}

// handleIPQuery resolves an A or AAAA query for domain through h.client and
// writes the DNS response, carrying the given message id, to writer. Failures
// are logged; nothing is returned to the caller.
func (h *Handler) handleIPQuery(id uint16, qType dnsmessage.Type, domain string, writer dns_proto.MessageWriter) {
	var ips []net.IP
	var err error

	var ttl uint32 = 600

	switch qType {
	case dnsmessage.TypeA:
		ips, err = h.client.LookupIP(domain, dns.IPOption{
			IPv4Enable: true,
			IPv6Enable: false,
			FakeEnable: true,
		})
	case dnsmessage.TypeAAAA:
		ips, err = h.client.LookupIP(domain, dns.IPOption{
			IPv4Enable: false,
			IPv6Enable: true,
			FakeEnable: true,
		})
	}

	rcode := dns.RCodeFromError(err)
	// No rcode, no addresses, and not an explicit empty response: log the
	// error and send no reply at all.
	if rcode == 0 && len(ips) == 0 && !errors.AllEqual(dns.ErrEmptyResponse, errors.Cause(err)) {
		errors.LogInfoInner(context.Background(), err, "ip query")
		return
	}

	// Addresses from the fake DNS pool get a TTL of 1 instead of 600.
	if fkr0, ok := h.fdns.(dns.FakeDNSEngineRev0); ok && len(ips) > 0 && fkr0.IsIPInIPPool(net.IPAddress(ips[0])) {
		ttl = 1
	}

	// Normalise slice lengths so the length check in the answer loop below
	// selects the matching record type.
	switch qType {
	case dnsmessage.TypeA:
		for i, ip := range ips {
			ips[i] = ip.To4()
		}
	case dnsmessage.TypeAAAA:
		for i, ip := range ips {
			ips[i] = ip.To16()
		}
	}

	b := buf.New()
	rawBytes := b.Extend(buf.Size)
	// The builder appends into the buffer's own storage, starting at length 0.
	builder := dnsmessage.NewBuilder(rawBytes[:0], dnsmessage.Header{
		ID:                 id,
		RCode:              dnsmessage.RCode(rcode),
		RecursionAvailable: true,
		RecursionDesired:   true,
		Response:           true,
		Authoritative:      true,
	})
	builder.EnableCompression()
	common.Must(builder.StartQuestions())
	common.Must(builder.Question(dnsmessage.Question{
		Name:  dnsmessage.MustNewName(domain),
		Class: dnsmessage.ClassINET,
		Type:  qType,
	}))
	common.Must(builder.StartAnswers())

	rHeader := dnsmessage.ResourceHeader{Name: dnsmessage.MustNewName(domain), Class: dnsmessage.ClassINET, TTL: ttl}
	for _, ip := range ips {
		if len(ip) == net.IPv4len {
			var r dnsmessage.AResource
			copy(r.A[:], ip)
			common.Must(builder.AResource(rHeader, r))
		} else {
			var r dnsmessage.AAAAResource
			copy(r.AAAA[:], ip)
			common.Must(builder.AAAAResource(rHeader, r))
		}
	}
	msgBytes, err := builder.Finish()
	if err != nil {
		errors.LogInfoInner(context.Background(), err, "pack message")
		b.Release()
		return
	}
	// Shrink the buffer to the bytes the builder actually produced.
	b.Resize(0, int32(len(msgBytes)))

	if err := writer.WriteMessage(b); err != nil {
		errors.LogInfoInner(context.Background(), err, "write IP answer")
	}
}

// outboundConn is a connection to the upstream server that is dialed on the
// first Write. Read blocks until the connection exists or Close is called.
type outboundConn struct {
	// access guards conn and serialises dial and Close.
	access sync.Mutex
	dialer func() (stat.Connection, error)

	conn net.Conn
	// connReady receives one value once conn is set and is closed by Close.
	connReady chan struct{}
}

// dial opens the connection, stores it and signals connReady. Write calls it
// while holding c.access.
func (c *outboundConn) dial() error {
	conn, err := c.dialer()
	if err != nil {
		return err
	}
	c.conn = conn
	c.connReady <- struct{}{}
	return nil
}

// Write dials the connection if none exists yet, then writes b to it. A dial
// failure is logged and reported as a successful write of len(b) bytes, so the
// message is dropped without ending the caller's loop.
func (c *outboundConn) Write(b []byte) (int, error) {
	c.access.Lock()

	if c.conn == nil {
		if err := c.dial(); err != nil {
			c.access.Unlock()
			errors.LogWarningInner(context.Background(), err, "failed to dial outbound connection")
			return len(b), nil
		}
	}

	c.access.Unlock()

	return c.conn.Write(b)
}

// Read reads from the connection, first waiting for it to be dialed. It
// returns io.EOF if Close is called before a connection exists.
func (c *outboundConn) Read(b []byte) (int, error) {
	var conn net.Conn
	c.access.Lock()
	conn = c.conn
	c.access.Unlock()

	if conn == nil {
		_, open := <-c.connReady
		if !open {
			return 0, io.EOF
		}
		// The receive above is ordered after dial's assignment to c.conn.
		conn = c.conn
	}

	return conn.Read(b)
}

// Close closes connReady, which wakes a Read still waiting for the dial, and
// closes the connection if one was opened. It always returns nil and must be
// called at most once, since it closes the channel unconditionally.
func (c *outboundConn) Close() error {
	c.access.Lock()
	close(c.connReady)
	if c.conn != nil {
		c.conn.Close()
	}
	c.access.Unlock()
	return nil
}