DNS: Retry with EDNS0 when response is truncated (#4516)

This commit is contained in:
风扇滑翔翼 2025-03-21 08:22:04 +08:00 committed by RPRX
parent 6f8e253dec
commit 86a225cda1
3 changed files with 49 additions and 15 deletions

View File

@ -31,10 +31,11 @@ type record struct {
// IPRecord is a cacheable item for a resolved domain
type IPRecord struct {
ReqID uint16
IP []net.Address
Expire time.Time
RCode dnsmessage.RCode
ReqID uint16
IP []net.Address
Expire time.Time
RCode dnsmessage.RCode
RawHeader *dnsmessage.Header
}
func (r *IPRecord) getIPs() ([]net.Address, error) {
@ -179,9 +180,10 @@ func parseResponse(payload []byte) (*IPRecord, error) {
now := time.Now()
ipRecord := &IPRecord{
ReqID: h.ID,
RCode: h.RCode,
Expire: now.Add(time.Second * 600),
ReqID: h.ID,
RCode: h.RCode,
Expire: now.Add(time.Second * 600),
RawHeader: &h,
}
L:

View File

@ -51,7 +51,7 @@ func Test_parseResponse(t *testing.T) {
}{
{
"empty",
&IPRecord{0, []net.Address(nil), time.Time{}, dnsmessage.RCodeSuccess},
&IPRecord{0, []net.Address(nil), time.Time{}, dnsmessage.RCodeSuccess, nil},
false,
},
{
@ -66,12 +66,13 @@ func Test_parseResponse(t *testing.T) {
[]net.Address{net.ParseAddress("8.8.8.8"), net.ParseAddress("8.8.4.4")},
time.Time{},
dnsmessage.RCodeSuccess,
nil,
},
false,
},
{
"aaaa record",
&IPRecord{2, []net.Address{net.ParseAddress("2001::123:8888"), net.ParseAddress("2001::123:8844")}, time.Time{}, dnsmessage.RCodeSuccess},
&IPRecord{2, []net.Address{net.ParseAddress("2001::123:8888"), net.ParseAddress("2001::123:8844")}, time.Time{}, dnsmessage.RCodeSuccess, nil},
false,
},
}
@ -84,8 +85,9 @@ func Test_parseResponse(t *testing.T) {
}
if got != nil {
// reset the time
// reset the time and RawHeader
got.Expire = time.Time{}
got.RawHeader = nil
}
if cmp.Diff(got, tt.want) != "" {
t.Error(cmp.Diff(got, tt.want))

View File

@ -27,7 +27,7 @@ type ClassicNameServer struct {
name string
address *net.Destination
ips map[string]*record
requests map[uint16]*dnsRequest
requests map[uint16]*udpDnsRequest
pub *pubsub.Service
udpServer *udp.Dispatcher
cleanup *task.Periodic
@ -35,6 +35,11 @@ type ClassicNameServer struct {
queryStrategy QueryStrategy
}
type udpDnsRequest struct {
dnsRequest
ctx context.Context
}
// NewClassicNameServer creates udp server object for remote resolving.
func NewClassicNameServer(address net.Destination, dispatcher routing.Dispatcher, queryStrategy QueryStrategy) *ClassicNameServer {
// default to 53 if unspecific
@ -45,7 +50,7 @@ func NewClassicNameServer(address net.Destination, dispatcher routing.Dispatcher
s := &ClassicNameServer{
address: &address,
ips: make(map[string]*record),
requests: make(map[uint16]*dnsRequest),
requests: make(map[uint16]*udpDnsRequest),
pub: pubsub.NewService(),
name: strings.ToUpper(address.String()),
queryStrategy: queryStrategy,
@ -101,7 +106,7 @@ func (s *ClassicNameServer) Cleanup() error {
}
if len(s.requests) == 0 {
s.requests = make(map[uint16]*dnsRequest)
s.requests = make(map[uint16]*udpDnsRequest)
}
return nil
@ -128,6 +133,27 @@ func (s *ClassicNameServer) HandleResponse(ctx context.Context, packet *udp_prot
return
}
// if truncated, retry with EDNS0 option(udp payload size: 1350)
if ipRec.RawHeader.Truncated {
// if already has EDNS0 option, no need to retry
if ok && len(req.msg.Additionals) == 0 {
// copy necessary meta data from original request
// and add EDNS0 option
opt := new(dnsmessage.Resource)
common.Must(opt.Header.SetEDNS0(1350, 0xfe00, true))
opt.Body = &dnsmessage.OPTResource{}
newMsg := *req.msg
newReq := *req
newMsg.Additionals = append(newMsg.Additionals, *opt)
newMsg.ID = s.newReqID()
newReq.msg = &newMsg
s.addPendingRequest(&newReq)
b, _ := dns.PackMessage(newReq.msg)
s.udpServer.Dispatch(toDnsContext(newReq.ctx, s.address.String()), *s.address, b)
return
}
}
var rec record
switch req.reqType {
case dnsmessage.TypeA:
@ -179,7 +205,7 @@ func (s *ClassicNameServer) newReqID() uint16 {
return uint16(atomic.AddUint32(&s.reqID, 1))
}
func (s *ClassicNameServer) addPendingRequest(req *dnsRequest) {
func (s *ClassicNameServer) addPendingRequest(req *udpDnsRequest) {
s.Lock()
defer s.Unlock()
@ -194,7 +220,11 @@ func (s *ClassicNameServer) sendQuery(ctx context.Context, domain string, client
reqs := buildReqMsgs(domain, option, s.newReqID, genEDNS0Options(clientIP))
for _, req := range reqs {
s.addPendingRequest(req)
udpReq := &udpDnsRequest{
dnsRequest: *req,
ctx: ctx,
}
s.addPendingRequest(udpReq)
b, _ := dns.PackMessage(req.msg)
s.udpServer.Dispatch(toDnsContext(ctx, s.address.String()), *s.address, b)
}