diff --git a/transport/internet/tls/ech.go b/transport/internet/tls/ech.go index 177aeb6a..9471384d 100644 --- a/transport/internet/tls/ech.go +++ b/transport/internet/tls/ech.go @@ -26,7 +26,7 @@ func ApplyECH(c *Config, config *tls.Config) error { if err != nil { return errors.New("invalid ECH config") } - } else { + } else { // ECH config > DOH lookup if c.ServerName == "" { return errors.New("Using DOH for ECH needs serverName") } @@ -58,54 +58,58 @@ func QueryRecord(domain string, server string) (string, error) { } mutex.Lock() defer mutex.Unlock() - record, err := dohQuery(server, domain) + record, ttl, err := dohQuery(server, domain) if err != nil { return "", err } + // Use TTL for good, but many HTTPS records have TTL 60, too short + if ttl < 600 { + ttl = 600 + } rec.record = record - rec.expire = time.Now().Add(time.Second * 600) + rec.expire = time.Now().Add(time.Second * time.Duration(ttl)) dnsCache[domain] = rec return record, nil } -func dohQuery(server string, domain string) (string, error) { +func dohQuery(server string, domain string) (string, uint32, error) { m := new(dns.Msg) m.SetQuestion(dns.Fqdn(domain), dns.TypeHTTPS) msg, err := m.Pack() if err != nil { - return "", err + return "", 0, err } client := &http.Client{ Timeout: 5 * time.Second, } req, err := http.NewRequest("POST", server, bytes.NewReader(msg)) if err != nil { - return "", err + return "", 0, err } req.Header.Set("Content-Type", "application/dns-message") resp, err := client.Do(req) if err != nil { - return "", err + return "", 0, err } defer resp.Body.Close() respBody, err := io.ReadAll(resp.Body) if err != nil { - return "", err + return "", 0, err } if resp.StatusCode != http.StatusOK { - return "", errors.New("query failed with response code:", resp.StatusCode) + return "", 0, errors.New("query failed with response code:", resp.StatusCode) } respMsg := new(dns.Msg) err = respMsg.Unpack(respBody) if err != nil { - return "", err + return "", 0, err } if len(respMsg.Answer) > 0 { re := regexp.MustCompile(`ech="([^"]+)"`) match := re.FindStringSubmatch(respMsg.Answer[0].String()) if match[1] != "" { - return match[1], nil + return match[1], respMsg.Answer[0].Header().Ttl, nil } } - return "", errors.New("no ech record found") + return "", 0, errors.New("no ech record found") }