From 66e03ef1e4115cf4a8f136baf3f107d070bbcb3f Mon Sep 17 00:00:00 2001 From: Adnan Maolood Date: Fri, 19 Feb 2021 18:37:28 -0500 Subject: [PATCH] certificate: Implement GetCertificate callback --- certificate/{certificate.go => create.go} | 94 ------------- certificate/store.go | 155 ++++++++++++++++++++++ 2 files changed, 155 insertions(+), 94 deletions(-) rename certificate/{certificate.go => create.go} (61%) create mode 100644 certificate/store.go diff --git a/certificate/certificate.go b/certificate/create.go similarity index 61% rename from certificate/certificate.go rename to certificate/create.go index 6e77599..d9c2bb5 100644 --- a/certificate/certificate.go +++ b/certificate/create.go @@ -14,103 +14,9 @@ import ( "math/big" "net" "os" - "path/filepath" - "strings" - "sync" "time" ) -// Dir represents a directory of certificates. -// The zero value for Dir is an empty directory ready to use. -// -// Dir is safe for concurrent use by multiple goroutines. -type Dir struct { - certs map[string]tls.Certificate - path *string - mu sync.RWMutex -} - -// Add adds a certificate for the given scope to the directory. -// It tries to parse the certificate if it is not already parsed. -func (d *Dir) Add(scope string, cert tls.Certificate) error { - d.mu.Lock() - defer d.mu.Unlock() - if d.certs == nil { - d.certs = map[string]tls.Certificate{} - } - // Parse certificate if not already parsed - if cert.Leaf == nil { - parsed, err := x509.ParseCertificate(cert.Certificate[0]) - if err == nil { - cert.Leaf = parsed - } - } - - if d.path != nil { - // Escape slash character - scope = strings.ReplaceAll(scope, "/", ":") - certPath := filepath.Join(*d.path, scope+".crt") - keyPath := filepath.Join(*d.path, scope+".key") - if err := Write(cert, certPath, keyPath); err != nil { - return err - } - } - - d.certs[scope] = cert - return nil -} - -// Lookup returns the certificate for the provided scope. -func (d *Dir) Lookup(scope string) (tls.Certificate, bool) { - d.mu.RLock() - defer d.mu.RUnlock() - cert, ok := d.certs[scope] - return cert, ok -} - -// Entries returns a map of hostnames to certificates. -func (d *Dir) Entries() map[string]tls.Certificate { - certs := map[string]tls.Certificate{} - for key := range d.certs { - certs[key] = d.certs[key] - } - return certs -} - -// Load loads certificates from the provided path. -// Add will write certificates to this path. -// -// The directory should contain certificates and private keys -// named scope.crt and scope.key respectively, where scope is -// the scope of the certificate. -func (d *Dir) Load(path string) error { - matches, err := filepath.Glob(filepath.Join(path, "*.crt")) - if err != nil { - return err - } - for _, crtPath := range matches { - keyPath := strings.TrimSuffix(crtPath, ".crt") + ".key" - cert, err := tls.LoadX509KeyPair(crtPath, keyPath) - if err != nil { - continue - } - scope := strings.TrimSuffix(filepath.Base(crtPath), ".crt") - // Unescape slash character - scope = strings.ReplaceAll(scope, ":", "/") - d.Add(scope, cert) - } - d.SetPath(path) - return nil -} - -// SetPath sets the directory path. -// Add will write certificates to this path. -func (d *Dir) SetPath(path string) { - d.mu.Lock() - defer d.mu.Unlock() - d.path = &path -} - // CreateOptions configures the creation of a TLS certificate. type CreateOptions struct { // Subject Alternate Name values. diff --git a/certificate/store.go b/certificate/store.go new file mode 100644 index 0000000..6f57d91 --- /dev/null +++ b/certificate/store.go @@ -0,0 +1,155 @@ +package certificate + +import ( + "crypto/tls" + "crypto/x509" + "errors" + "fmt" + "path/filepath" + "strings" + "sync" + "time" +) + +// Store represents a certificate store. +// The zero value for Store is an empty store ready to use. +// +// Store is safe for concurrent use by multiple goroutines. +type Store struct { + // CreateCertificate, if not nil, is called to create a new certificate + // to replace a missing or expired certificate. + CreateCertificate func(scope string) (tls.Certificate, error) + + certs map[string]tls.Certificate + path string + mu sync.RWMutex +} + +// Register registers the provided scope in the certificate store. +// The certificate will be created upon calling GetCertificate. +func (s *Store) Register(scope string) { + s.mu.Lock() + defer s.mu.Unlock() + if s.certs == nil { + s.certs = make(map[string]tls.Certificate) + } + s.certs[scope] = tls.Certificate{} +} + +// Add adds a certificate for the given scope to the certificate store. +func (s *Store) Add(scope string, cert tls.Certificate) error { + s.mu.Lock() + defer s.mu.Unlock() + if s.certs == nil { + s.certs = make(map[string]tls.Certificate) + } + + // Parse certificate if not already parsed + if cert.Leaf == nil { + parsed, err := x509.ParseCertificate(cert.Certificate[0]) + if err != nil { + return err + } + cert.Leaf = parsed + } + + if s.path != "" { + // Escape slash character + path := strings.ReplaceAll(scope, "/", ":") + certPath := filepath.Join(s.path, path+".crt") + keyPath := filepath.Join(s.path, path+".key") + if err := Write(cert, certPath, keyPath); err != nil { + return err + } + } + + s.certs[scope] = cert + return nil +} + +// Lookup returns the certificate for the provided scope. +func (s *Store) Lookup(scope string) (tls.Certificate, bool) { + s.mu.RLock() + defer s.mu.RUnlock() + cert, ok := s.certs[scope] + return cert, ok +} + +// GetCertificate retrieves the certificate for the given scope. +// If the retrieved certificate is expired or the scope is registered but +// has no certificate, it calls CreateCertificate to create a new certificate. +func (s *Store) GetCertificate(scope string) (*tls.Certificate, error) { + cert, ok := s.Lookup(scope) + if !ok { + // Try wildcard + wildcard := strings.SplitN(scope, ".", 2) + if len(wildcard) == 2 { + cert, ok = s.Lookup("*." + wildcard[1]) + } + } + if !ok { + return nil, errors.New("unrecognized scope") + } + + // If the certificate is empty or expired, generate a new one. + // TODO: Add sane defaults for certificate generation + if cert.Leaf == nil || cert.Leaf.NotAfter.Before(time.Now()) { + if s.CreateCertificate != nil { + cert, err := s.CreateCertificate(scope) + if err != nil { + return nil, err + } + if err := s.Add(scope, cert); err != nil { + return nil, fmt.Errorf("failed to write new certificate for %s: %w", scope, err) + } + return &cert, nil + } + return nil, errors.New("no suitable certificate found") + } + + return &cert, nil +} + +// Load loads certificates from the provided path. +// New certificates will be written to this path. +// +// The path should lead to a directory containing certificates +// and private keys named "scope.crt" and "scope.key" respectively, +// where "scope" is the scope of the certificate. +func (s *Store) Load(path string) error { + matches, err := filepath.Glob(filepath.Join(path, "*.crt")) + if err != nil { + return err + } + for _, crtPath := range matches { + keyPath := strings.TrimSuffix(crtPath, ".crt") + ".key" + cert, err := tls.LoadX509KeyPair(crtPath, keyPath) + if err != nil { + continue + } + scope := strings.TrimSuffix(filepath.Base(crtPath), ".crt") + // Unescape slash character + scope = strings.ReplaceAll(scope, ":", "/") + s.Add(scope, cert) + } + s.SetPath(path) + return nil +} + +// Entries returns a map of scopes to certificates. +func (s *Store) Entries() map[string]tls.Certificate { + s.mu.RLock() + defer s.mu.RUnlock() + certs := make(map[string]tls.Certificate) + for key := range s.certs { + certs[key] = s.certs[key] + } + return certs +} + +// SetPath sets the path that new certificates will be written to. +func (s *Store) SetPath(path string) { + s.mu.Lock() + defer s.mu.Unlock() + s.path = path +}