diff --git a/cert.go b/cert.go index 7c151f6..9399876 100644 --- a/cert.go +++ b/cert.go @@ -6,8 +6,11 @@ import ( "crypto/rand" "crypto/tls" "crypto/x509" + "encoding/pem" + "log" "math/big" "net" + "os" "path/filepath" "strings" "time" @@ -17,6 +20,8 @@ import ( // The zero value of CertificateStore is an empty store ready to use. type CertificateStore struct { store map[string]tls.Certificate + dir bool + path string } // Add adds a certificate for the given scope to the store. @@ -32,6 +37,15 @@ func (c *CertificateStore) Add(scope string, cert tls.Certificate) { cert.Leaf = parsed } } + if c.dir { + // Write certificates + log.Printf("Writing certificate for %s to disk", scope) + certPath := filepath.Join(c.path, scope+".crt") + keyPath := filepath.Join(c.path, scope+".key") + if err := WriteCertificate(cert, certPath, keyPath); err != nil { + log.Printf("Failed to write certificate to disk: %s", err) + } + } c.store[scope] = cert } @@ -53,6 +67,7 @@ func (c *CertificateStore) Lookup(scope string) (*tls.Certificate, error) { // in the form scope.crt and scope.key. // For example, the hostname "localhost" would have the corresponding files // localhost.crt (certificate) and localhost.key (private key). +// New certificates will be written to this directory. func (c *CertificateStore) Load(path string) error { matches, err := filepath.Glob(filepath.Join(path, "*.crt")) if err != nil { @@ -67,6 +82,8 @@ func (c *CertificateStore) Load(path string) error { scope := strings.TrimSuffix(filepath.Base(crtPath), ".crt") c.Add(scope, cert) } + c.dir = true + c.path = path return nil } @@ -133,3 +150,30 @@ func newX509KeyPair(options CertificateOptions) (*x509.Certificate, crypto.Priva } return cert, priv, nil } + +// WriteCertificate writes the provided certificate and private key +// to certPath and keyPath respectively. +func WriteCertificate(cert tls.Certificate, certPath, keyPath string) error { + certOut, err := os.OpenFile(certPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600) + if err != nil { + return err + } + defer certOut.Close() + if err := pem.Encode(certOut, &pem.Block{ + Type: "CERTIFICATE", + Bytes: cert.Leaf.Raw, + }); err != nil { + return err + } + + keyOut, err := os.OpenFile(keyPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600) + if err != nil { + return err + } + defer keyOut.Close() + privBytes, err := x509.MarshalPKCS8PrivateKey(cert.PrivateKey) + if err != nil { + return err + } + return pem.Encode(keyOut, &pem.Block{Type: "PRIVATE KEY", Bytes: privBytes}) +} diff --git a/examples/server.go b/examples/server.go index 7dfd331..c26e66b 100644 --- a/examples/server.go +++ b/examples/server.go @@ -3,14 +3,8 @@ package main import ( - "crypto" "crypto/tls" - "crypto/x509" - "encoding/pem" - "fmt" - "io" "log" - "os" "time" "git.sr.ht/~adnano/go-gemini" @@ -22,16 +16,10 @@ func main() { log.Fatal(err) } server.CreateCertificate = func(hostname string) (tls.Certificate, error) { - fmt.Println("Generating certificate for", hostname) - cert, err := gemini.CreateCertificate(gemini.CertificateOptions{ + return gemini.CreateCertificate(gemini.CertificateOptions{ DNSNames: []string{hostname}, Duration: time.Minute, // for testing purposes }) - if err == nil { - // Write the new certificate to disk - err = writeCertificate("/var/lib/gemini/certs/"+hostname, cert) - } - return cert, err } var mux gemini.ServeMux @@ -42,39 +30,3 @@ func main() { log.Fatal(err) } } - -// writeCertificate writes the provided certificate and private key -// to path.crt and path.key respectively. -func writeCertificate(path string, cert tls.Certificate) error { - // Write the certificate - crtPath := path + ".crt" - crtOut, err := os.OpenFile(crtPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600) - if err != nil { - return err - } - if err := marshalX509Certificate(crtOut, cert.Leaf.Raw); err != nil { - return err - } - - // Write the private key - keyPath := path + ".key" - keyOut, err := os.OpenFile(keyPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600) - if err != nil { - return err - } - return marshalPrivateKey(keyOut, cert.PrivateKey) -} - -// marshalX509Certificate writes a PEM-encoded version of the given certificate. -func marshalX509Certificate(w io.Writer, cert []byte) error { - return pem.Encode(w, &pem.Block{Type: "CERTIFICATE", Bytes: cert}) -} - -// marshalPrivateKey writes a PEM-encoded version of the given private key. -func marshalPrivateKey(w io.Writer, priv crypto.PrivateKey) error { - privBytes, err := x509.MarshalPKCS8PrivateKey(priv) - if err != nil { - return err - } - return pem.Encode(w, &pem.Block{Type: "PRIVATE KEY", Bytes: privBytes}) -} diff --git a/server.go b/server.go index 1834bfa..41ddd75 100644 --- a/server.go +++ b/server.go @@ -39,7 +39,7 @@ type responderKey struct { // // Patterns must be in the form of hostname or scheme://hostname // (e.g. gemini://example.com). -// If no scheme is specified, a default scheme of gemini:// is assumed. +// If no scheme is specified, a default scheme of gemini:// is implied. // // Wildcard patterns are supported (e.g. *.example.com). // To register a certificate for a wildcard hostname, call Certificates.Add: @@ -132,27 +132,29 @@ func (s *Server) Serve(l net.Listener) error { } func (s *Server) getCertificate(h *tls.ClientHelloInfo) (*tls.Certificate, error) { - hostname := h.ServerName - cert, err := s.Certificates.Lookup(hostname) - if err == ErrCertificateUnknown { + cert, err := s.getCertificateFor(h.ServerName) + if err != nil { + // Try wildcard wildcard := strings.SplitN(h.ServerName, ".", 2) if len(wildcard) == 2 { - cert, err = s.Certificates.Lookup("*." + wildcard[1]) + cert, err = s.getCertificateFor("*." + wildcard[1]) } } + return cert, err +} + +func (s *Server) getCertificateFor(hostname string) (*tls.Certificate, error) { + if _, ok := s.hosts[hostname]; !ok { + return nil, ErrCertificateUnknown + } + cert, err := s.Certificates.Lookup(hostname) switch err { - case ErrCertificateUnknown: - if _, ok := s.hosts[hostname]; !ok { - break - } - fallthrough - - case ErrCertificateExpired: + case ErrCertificateUnknown, ErrCertificateExpired: if s.CreateCertificate != nil { - cert, err := s.CreateCertificate(h.ServerName) + cert, err := s.CreateCertificate(hostname) if err == nil { - s.Certificates.Add(h.ServerName, cert) + s.Certificates.Add(hostname, cert) } return &cert, err }