CertificateObject: Enable auto-reload for cacert & Add buildChain & Fixes (#3607)

This commit is contained in:
lelemka0 2024-07-29 14:58:58 +08:00 committed by GitHub
parent a342db3e28
commit 4531a7e228
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 163 additions and 125 deletions

View file

@ -10,6 +10,7 @@ import (
"strings"
"sync"
"time"
"bytes"
"github.com/xtls/xray-core/common/errors"
"github.com/xtls/xray-core/common/net"
@ -50,72 +51,84 @@ func (c *Config) BuildCertificates() []*tls.Certificate {
if entry.Usage != Certificate_ENCIPHERMENT {
continue
}
keyPair, err := tls.X509KeyPair(entry.Certificate, entry.Key)
if err != nil {
errors.LogWarningInner(context.Background(), err, "ignoring invalid X509 key pair")
continue
}
keyPair.Leaf, err = x509.ParseCertificate(keyPair.Certificate[0])
if err != nil {
errors.LogWarningInner(context.Background(), err, "ignoring invalid certificate")
continue
}
certs = append(certs, &keyPair)
if !entry.OneTimeLoading {
var isOcspstapling bool
hotReloadCertInterval := uint64(3600)
if entry.OcspStapling != 0 {
hotReloadCertInterval = entry.OcspStapling
isOcspstapling = true
getX509KeyPair := func() *tls.Certificate {
keyPair, err := tls.X509KeyPair(entry.Certificate, entry.Key)
if err != nil {
errors.LogWarningInner(context.Background(), err, "ignoring invalid X509 key pair")
return nil
}
index := len(certs) - 1
go func(entry *Certificate, cert *tls.Certificate, index int) {
t := time.NewTicker(time.Duration(hotReloadCertInterval) * time.Second)
for {
if entry.CertificatePath != "" && entry.KeyPath != "" {
newCert, err := filesystem.ReadFile(entry.CertificatePath)
if err != nil {
errors.LogErrorInner(context.Background(), err, "failed to parse certificate")
<-t.C
continue
}
newKey, err := filesystem.ReadFile(entry.KeyPath)
if err != nil {
errors.LogErrorInner(context.Background(), err, "failed to parse key")
<-t.C
continue
}
if string(newCert) != string(entry.Certificate) && string(newKey) != string(entry.Key) {
newKeyPair, err := tls.X509KeyPair(newCert, newKey)
if err != nil {
errors.LogErrorInner(context.Background(), err, "ignoring invalid X509 key pair")
<-t.C
continue
}
if newKeyPair.Leaf, err = x509.ParseCertificate(newKeyPair.Certificate[0]); err != nil {
errors.LogErrorInner(context.Background(), err, "ignoring invalid certificate")
<-t.C
continue
}
cert = &newKeyPair
}
}
if isOcspstapling {
if newOCSPData, err := ocsp.GetOCSPForCert(cert.Certificate); err != nil {
errors.LogWarningInner(context.Background(), err, "ignoring invalid OCSP")
} else if string(newOCSPData) != string(cert.OCSPStaple) {
cert.OCSPStaple = newOCSPData
}
}
certs[index] = cert
<-t.C
}
}(entry, certs[index], index)
keyPair.Leaf, err = x509.ParseCertificate(keyPair.Certificate[0])
if err != nil {
errors.LogWarningInner(context.Background(), err, "ignoring invalid certificate")
return nil
}
return &keyPair
}
if keyPair := getX509KeyPair(); keyPair != nil {
certs = append(certs, keyPair)
} else {
continue
}
index := len(certs) - 1
setupOcspTicker(entry, func(isReloaded, isOcspstapling bool){
cert := certs[index]
if isReloaded {
if newKeyPair := getX509KeyPair(); newKeyPair != nil {
cert = newKeyPair
} else {
return
}
}
if isOcspstapling {
if newOCSPData, err := ocsp.GetOCSPForCert(cert.Certificate); err != nil {
errors.LogWarningInner(context.Background(), err, "ignoring invalid OCSP")
} else if string(newOCSPData) != string(cert.OCSPStaple) {
cert.OCSPStaple = newOCSPData
}
}
certs[index] = cert
})
}
return certs
}
func setupOcspTicker(entry *Certificate, callback func(isReloaded, isOcspstapling bool)) {
go func() {
if entry.OneTimeLoading {
return
}
var isOcspstapling bool
hotReloadCertInterval := uint64(3600)
if entry.OcspStapling != 0 {
hotReloadCertInterval = entry.OcspStapling
isOcspstapling = true
}
t := time.NewTicker(time.Duration(hotReloadCertInterval) * time.Second)
for {
var isReloaded bool
if entry.CertificatePath != "" && entry.KeyPath != "" {
newCert, err := filesystem.ReadFile(entry.CertificatePath)
if err != nil {
errors.LogErrorInner(context.Background(), err, "failed to parse certificate")
return
}
newKey, err := filesystem.ReadFile(entry.KeyPath)
if err != nil {
errors.LogErrorInner(context.Background(), err, "failed to parse key")
return
}
if string(newCert) != string(entry.Certificate) || string(newKey) != string(entry.Key) {
entry.Certificate = newCert
entry.Key = newKey
isReloaded = true
}
}
callback(isReloaded, isOcspstapling)
<-t.C
}
}()
}
func isCertificateExpired(c *tls.Certificate) bool {
if c.Leaf == nil && len(c.Certificate) > 0 {
if pc, err := x509.ParseCertificate(c.Certificate[0]); err == nil {
@ -137,6 +150,9 @@ func issueCertificate(rawCA *Certificate, domain string) (*tls.Certificate, erro
return nil, errors.New("failed to generate new certificate for ", domain).Base(err)
}
newCertPEM, newKeyPEM := newCert.ToPEM()
if rawCA.BuildChain {
newCertPEM = bytes.Join([][]byte{newCertPEM, rawCA.Certificate}, []byte("\n"))
}
cert, err := tls.X509KeyPair(newCertPEM, newKeyPEM)
return &cert, err
}
@ -146,6 +162,7 @@ func (c *Config) getCustomCA() []*Certificate {
for _, certificate := range c.Certificate {
if certificate.Usage == Certificate_AUTHORITY_ISSUE {
certs = append(certs, certificate)
setupOcspTicker(certificate, func(isReloaded, isOcspstapling bool){ })
}
}
return certs