From 9357bed8cfae57b59838f06312ffff8edf40ab85 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=A3=8E=E6=89=87=E6=BB=91=E7=BF=94=E7=BF=BC?= Date: Sun, 15 Sep 2024 13:37:57 +0000 Subject: [PATCH] Use server returned TTL for cache, instead of 600 --- transport/internet/tls/ech.go | 28 ++++++++++++++++------------ 1 file changed, 16 insertions(+), 12 deletions(-) diff --git a/transport/internet/tls/ech.go b/transport/internet/tls/ech.go index 177aeb6a..9471384d 100644 --- a/transport/internet/tls/ech.go +++ b/transport/internet/tls/ech.go @@ -26,7 +26,7 @@ func ApplyECH(c *Config, config *tls.Config) error { if err != nil { return errors.New("invalid ECH config") } - } else { + } else { // ECH config > DOH lookup if c.ServerName == "" { return errors.New("Using DOH for ECH needs serverName") } @@ -58,54 +58,58 @@ func QueryRecord(domain string, server string) (string, error) { } mutex.Lock() defer mutex.Unlock() - record, err := dohQuery(server, domain) + record, ttl, err := dohQuery(server, domain) if err != nil { return "", err } + // Use TTL for good, but many HTTPS records have TTL 60, too short + if ttl < 600 { + ttl = 600 + } rec.record = record - rec.expire = time.Now().Add(time.Second * 600) + rec.expire = time.Now().Add(time.Second * time.Duration(ttl)) dnsCache[domain] = rec return record, nil } -func dohQuery(server string, domain string) (string, error) { +func dohQuery(server string, domain string) (string, uint32, error) { m := new(dns.Msg) m.SetQuestion(dns.Fqdn(domain), dns.TypeHTTPS) msg, err := m.Pack() if err != nil { - return "", err + return "", 0, err } client := &http.Client{ Timeout: 5 * time.Second, } req, err := http.NewRequest("POST", server, bytes.NewReader(msg)) if err != nil { - return "", err + return "", 0, err } req.Header.Set("Content-Type", "application/dns-message") resp, err := client.Do(req) if err != nil { - return "", err + return "", 0, err } defer resp.Body.Close() respBody, err := io.ReadAll(resp.Body) if err != nil { - return "", err + return "", 0, err } if resp.StatusCode != http.StatusOK { - return "", errors.New("query failed with response code:", resp.StatusCode) + return "", 0, errors.New("query failed with response code:", resp.StatusCode) } respMsg := new(dns.Msg) err = respMsg.Unpack(respBody) if err != nil { - return "", err + return "", 0, err } if len(respMsg.Answer) > 0 { re := regexp.MustCompile(`ech="([^"]+)"`) match := re.FindStringSubmatch(respMsg.Answer[0].String()) if match[1] != "" { - return match[1], nil + return match[1], respMsg.Answer[0].Header().Ttl, nil } } - return "", errors.New("no ech record found") + return "", 0, errors.New("no ech record found") }