Call CreateCertificate for missing certificates

This commit is contained in:
Adnan Maolood 2020-10-31 15:38:39 -04:00
parent a89065babb
commit 3aa254870a

View File

@ -22,11 +22,12 @@ type Server struct {
Certificates CertificateStore Certificates CertificateStore
// CreateCertificate, if not nil, will be called to create a new certificate // CreateCertificate, if not nil, will be called to create a new certificate
// if the current one is expired. // if the current one is expired or missing.
CreateCertificate func(hostname string) (tls.Certificate, error) CreateCertificate func(hostname string) (tls.Certificate, error)
// registered responders // registered responders
responders map[responderKey]Responder responders map[responderKey]Responder
hosts map[string]bool
} }
type responderKey struct { type responderKey struct {
@ -54,6 +55,7 @@ func (s *Server) Register(pattern string, responder Responder) {
} }
if s.responders == nil { if s.responders == nil {
s.responders = map[responderKey]Responder{} s.responders = map[responderKey]Responder{}
s.hosts = map[string]bool{}
} }
split := strings.SplitN(pattern, "://", 2) split := strings.SplitN(pattern, "://", 2)
@ -70,6 +72,7 @@ func (s *Server) Register(pattern string, responder Responder) {
panic("gemini: multiple registrations for " + pattern) panic("gemini: multiple registrations for " + pattern)
} }
s.responders[key] = responder s.responders[key] = responder
s.hosts[key.hostname] = true
} }
// RegisterFunc registers a responder function for the given pattern. // RegisterFunc registers a responder function for the given pattern.
@ -129,7 +132,8 @@ func (s *Server) Serve(l net.Listener) error {
} }
func (s *Server) getCertificate(h *tls.ClientHelloInfo) (*tls.Certificate, error) { func (s *Server) getCertificate(h *tls.ClientHelloInfo) (*tls.Certificate, error) {
cert, err := s.Certificates.Lookup(h.ServerName) hostname := h.ServerName
cert, err := s.Certificates.Lookup(hostname)
if err == ErrCertificateUnknown { if err == ErrCertificateUnknown {
wildcard := strings.SplitN(h.ServerName, ".", 2) wildcard := strings.SplitN(h.ServerName, ".", 2)
if len(wildcard) == 2 { if len(wildcard) == 2 {
@ -137,7 +141,14 @@ func (s *Server) getCertificate(h *tls.ClientHelloInfo) (*tls.Certificate, error
} }
} }
if err == ErrCertificateExpired { switch err {
case ErrCertificateUnknown:
if _, ok := s.hosts[hostname]; !ok {
break
}
fallthrough
case ErrCertificateExpired:
if s.CreateCertificate != nil { if s.CreateCertificate != nil {
cert, err := s.CreateCertificate(h.ServerName) cert, err := s.CreateCertificate(h.ServerName)
if err == nil { if err == nil {