Move cert.go to a subpackage

This commit is contained in:
Adnan Maolood 2021-01-14 20:42:12 -05:00
parent 7a00539f75
commit 14d89f304a
2 changed files with 49 additions and 56 deletions

View File

@ -1,4 +1,5 @@
package gemini // Package certificate provides utility functions for TLS certificates.
package certificate
import ( import (
"crypto" "crypto"
@ -19,27 +20,23 @@ import (
"time" "time"
) )
// CertificateDir maps certificate scopes to certificates. // Dir represents a directory of certificates.
type CertificateStore map[string]tls.Certificate // The zero value of Dir is an empty directory ready to use.
// CertificateDir represents a certificate store optionally loaded from a directory.
// The zero value of CertificateDir is an empty store ready to use.
// //
// CertificateDir is safe for concurrent use by multiple goroutines. // Dir is safe for concurrent use by multiple goroutines.
type CertificateDir struct { type Dir struct {
CertificateStore certs map[string]tls.Certificate
dir bool path *string
path string mu sync.RWMutex
mu sync.RWMutex
} }
// Add adds a certificate for the given scope to the store. // Add adds a certificate for the given scope to the directory.
// It tries to parse the certificate if it is not already parsed. // It tries to parse the certificate if it is not already parsed.
func (c *CertificateDir) Add(scope string, cert tls.Certificate) { func (d *Dir) Add(scope string, cert tls.Certificate) error {
c.mu.Lock() d.mu.Lock()
defer c.mu.Unlock() defer d.mu.Unlock()
if c.CertificateStore == nil { if d.certs == nil {
c.CertificateStore = CertificateStore{} d.certs = map[string]tls.Certificate{}
} }
// Parse certificate if not already parsed // Parse certificate if not already parsed
if cert.Leaf == nil { if cert.Leaf == nil {
@ -48,30 +45,26 @@ func (c *CertificateDir) Add(scope string, cert tls.Certificate) {
cert.Leaf = parsed cert.Leaf = parsed
} }
} }
c.CertificateStore[scope] = cert
}
// Write writes the provided certificate to the certificate directory. if d.path != nil {
func (c *CertificateDir) Write(scope string, cert tls.Certificate) error {
c.mu.RLock()
defer c.mu.RUnlock()
if c.dir {
// Escape slash character // Escape slash character
scope = strings.ReplaceAll(scope, "/", ":") scope = strings.ReplaceAll(scope, "/", ":")
certPath := filepath.Join(c.path, scope+".crt") certPath := filepath.Join(*d.path, scope+".crt")
keyPath := filepath.Join(c.path, scope+".key") keyPath := filepath.Join(*d.path, scope+".key")
if err := WriteCertificate(cert, certPath, keyPath); err != nil { if err := Write(cert, certPath, keyPath); err != nil {
return err return err
} }
} }
d.certs[scope] = cert
return nil return nil
} }
// Lookup returns the certificate for the given scope. // Lookup returns the certificate for the provided scope.
func (c *CertificateDir) Lookup(scope string) (tls.Certificate, bool) { func (d *Dir) Lookup(scope string) (tls.Certificate, bool) {
c.mu.RLock() d.mu.RLock()
defer c.mu.RUnlock() defer d.mu.RUnlock()
cert, ok := c.CertificateStore[scope] cert, ok := d.certs[scope]
return cert, ok return cert, ok
} }
@ -81,7 +74,7 @@ func (c *CertificateDir) Lookup(scope string) (tls.Certificate, bool) {
// For example, the hostname "localhost" would have the corresponding files // For example, the hostname "localhost" would have the corresponding files
// localhost.crt (certificate) and localhost.key (private key). // localhost.crt (certificate) and localhost.key (private key).
// New certificates will be written to this directory. // New certificates will be written to this directory.
func (c *CertificateDir) Load(path string) error { func (d *Dir) Load(path string) error {
matches, err := filepath.Glob(filepath.Join(path, "*.crt")) matches, err := filepath.Glob(filepath.Join(path, "*.crt"))
if err != nil { if err != nil {
return err return err
@ -95,31 +88,30 @@ func (c *CertificateDir) Load(path string) error {
scope := strings.TrimSuffix(filepath.Base(crtPath), ".crt") scope := strings.TrimSuffix(filepath.Base(crtPath), ".crt")
// Unescape slash character // Unescape slash character
scope = strings.ReplaceAll(scope, ":", "/") scope = strings.ReplaceAll(scope, ":", "/")
c.Add(scope, cert) d.Add(scope, cert)
} }
c.SetDir(path) d.SetPath(path)
return nil return nil
} }
// SetDir sets the directory that new certificates will be written to. // SetPath sets the directory that new certificates will be written to.
func (c *CertificateDir) SetDir(path string) { func (d *Dir) SetPath(path string) {
c.mu.Lock() d.mu.Lock()
defer c.mu.Unlock() defer d.mu.Unlock()
c.dir = true d.path = &path
c.path = path
} }
// CertificateOptions configures the creation of a certificate. // CreateOptions configures the creation of a TLS certificate.
type CertificateOptions struct { type CreateOptions struct {
// Subject Alternate Name values.
// Should contain the IP addresses that the certificate is valid for.
IPAddresses []net.IP
// Subject Alternate Name values. // Subject Alternate Name values.
// Should contain the DNS names that this certificate is valid for. // Should contain the DNS names that this certificate is valid for.
// E.g. example.com, *.example.com // E.g. example.com, *.example.com
DNSNames []string DNSNames []string
// Subject Alternate Name values.
// Should contain the IP addresses that the certificate is valid for.
IPAddresses []net.IP
// Subject specifies the certificate Subject. // Subject specifies the certificate Subject.
// //
// Subject.CommonName can contain the DNS name that this certificate // Subject.CommonName can contain the DNS name that this certificate
@ -136,8 +128,8 @@ type CertificateOptions struct {
Ed25519 bool Ed25519 bool
} }
// CreateCertificate creates a new TLS certificate. // Create creates a new TLS certificate.
func CreateCertificate(options CertificateOptions) (tls.Certificate, error) { func Create(options CreateOptions) (tls.Certificate, error) {
crt, priv, err := newX509KeyPair(options) crt, priv, err := newX509KeyPair(options)
if err != nil { if err != nil {
return tls.Certificate{}, err return tls.Certificate{}, err
@ -150,7 +142,7 @@ func CreateCertificate(options CertificateOptions) (tls.Certificate, error) {
} }
// newX509KeyPair creates and returns a new certificate and private key. // newX509KeyPair creates and returns a new certificate and private key.
func newX509KeyPair(options CertificateOptions) (*x509.Certificate, crypto.PrivateKey, error) { func newX509KeyPair(options CreateOptions) (*x509.Certificate, crypto.PrivateKey, error) {
var pub crypto.PublicKey var pub crypto.PublicKey
var priv crypto.PrivateKey var priv crypto.PrivateKey
if options.Ed25519 { if options.Ed25519 {
@ -206,9 +198,9 @@ func newX509KeyPair(options CertificateOptions) (*x509.Certificate, crypto.Priva
return cert, priv, nil return cert, priv, nil
} }
// WriteCertificate writes the provided certificate and private key // Write writes the provided certificate and its private key
// to certPath and keyPath respectively. // to certPath and keyPath respectively.
func WriteCertificate(cert tls.Certificate, certPath, keyPath string) error { func Write(cert tls.Certificate, certPath, keyPath string) error {
certOut, err := os.OpenFile(certPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600) certOut, err := os.OpenFile(certPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
if err != nil { if err != nil {
return err return err

View File

@ -7,6 +7,8 @@ import (
"net" "net"
"strings" "strings"
"time" "time"
"git.sr.ht/~adnano/go-gemini/certificate"
) )
// Server is a Gemini server. // Server is a Gemini server.
@ -23,7 +25,7 @@ type Server struct {
WriteTimeout time.Duration WriteTimeout time.Duration
// Certificates contains the certificates used by the server. // Certificates contains the certificates used by the server.
Certificates CertificateDir Certificates certificate.Dir
// CreateCertificate, if not nil, will be called to create a new certificate // CreateCertificate, if not nil, will be called to create a new certificate
// if the current one is expired or missing. // if the current one is expired or missing.
@ -157,8 +159,7 @@ func (s *Server) getCertificateFor(hostname string) (*tls.Certificate, error) {
if s.CreateCertificate != nil { if s.CreateCertificate != nil {
cert, err := s.CreateCertificate(hostname) cert, err := s.CreateCertificate(hostname)
if err == nil { if err == nil {
s.Certificates.Add(hostname, cert) if err := s.Certificates.Add(hostname, cert); err != nil {
if err := s.Certificates.Write(hostname, cert); err != nil {
s.logf("gemini: Failed to write new certificate for %s: %s", hostname, err) s.logf("gemini: Failed to write new certificate for %s: %s", hostname, err)
} }
} }