diff --git a/common/session/session.go b/common/session/session.go index 1e9ec8d6..afa1fc50 100644 --- a/common/session/session.go +++ b/common/session/session.go @@ -53,6 +53,8 @@ type Inbound struct { Uid uint32 // SagerNet private: AppStatus is the android app's status for the inbound connection AppStatus []string + // SagerNet private + SkipFakeDNS bool } // Outbound is the metadata of an outbound connection. diff --git a/features/routing/session/context.go b/features/routing/session/context.go index 1cebd82e..7b70b806 100644 --- a/features/routing/session/context.go +++ b/features/routing/session/context.go @@ -137,6 +137,13 @@ func (ctx *Context) GetAppStatus() []string { return ctx.Inbound.AppStatus } +func (ctx Context) GetSkipFakeDNS() bool { + if ctx.Inbound == nil { + return false + } + return ctx.Inbound.SkipFakeDNS +} + // AsRoutingContext creates a context from context.context with session info. func AsRoutingContext(ctx context.Context) routing.Context { return &Context{ diff --git a/proxy/dns/dns.go b/proxy/dns/dns.go index 0fcd502d..39506066 100644 --- a/proxy/dns/dns.go +++ b/proxy/dns/dns.go @@ -96,6 +96,12 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, d internet. return newError("invalid outbound") } + fakeDNS := true + inbound := session.InboundFromContext(ctx) + if inbound != nil && inbound.SkipFakeDNS { + fakeDNS = false + } + srcNetwork := outbound.Target.Network dest := outbound.Target @@ -171,7 +177,7 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, d internet. if !h.isOwnLink(ctx) { isIPQuery, domain, id, qType := parseIPQuery(b.Bytes()) if isIPQuery { - go h.handleIPQuery(id, qType, domain, writer) + go h.handleIPQuery(id, qType, domain, writer, fakeDNS) continue } } @@ -208,7 +214,7 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, d internet. return nil } -func (h *Handler) handleIPQuery(id uint16, qType dnsmessage.Type, domain string, writer dns_proto.MessageWriter) { +func (h *Handler) handleIPQuery(id uint16, qType dnsmessage.Type, domain string, writer dns_proto.MessageWriter, fakedns bool) { var ips []net.IP var err error @@ -219,13 +225,13 @@ func (h *Handler) handleIPQuery(id uint16, qType dnsmessage.Type, domain string, ips, err = h.client.LookupIP(domain, dns.IPOption{ IPv4Enable: true, IPv6Enable: false, - FakeEnable: true, + FakeEnable: fakedns, }) case dnsmessage.TypeAAAA: ips, err = h.client.LookupIP(domain, dns.IPOption{ IPv4Enable: false, IPv6Enable: true, - FakeEnable: true, + FakeEnable: fakedns, }) }