server: Automatically write new certificates to disk

This commit is contained in:
Adnan Maolood 2020-10-31 16:33:56 -04:00
parent 3aa254870a
commit 14ef3be6fe
3 changed files with 61 additions and 63 deletions

44
cert.go
View File

@ -6,8 +6,11 @@ import (
"crypto/rand" "crypto/rand"
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"encoding/pem"
"log"
"math/big" "math/big"
"net" "net"
"os"
"path/filepath" "path/filepath"
"strings" "strings"
"time" "time"
@ -17,6 +20,8 @@ import (
// The zero value of CertificateStore is an empty store ready to use. // The zero value of CertificateStore is an empty store ready to use.
type CertificateStore struct { type CertificateStore struct {
store map[string]tls.Certificate store map[string]tls.Certificate
dir bool
path string
} }
// Add adds a certificate for the given scope to the store. // Add adds a certificate for the given scope to the store.
@ -32,6 +37,15 @@ func (c *CertificateStore) Add(scope string, cert tls.Certificate) {
cert.Leaf = parsed cert.Leaf = parsed
} }
} }
if c.dir {
// Write certificates
log.Printf("Writing certificate for %s to disk", scope)
certPath := filepath.Join(c.path, scope+".crt")
keyPath := filepath.Join(c.path, scope+".key")
if err := WriteCertificate(cert, certPath, keyPath); err != nil {
log.Printf("Failed to write certificate to disk: %s", err)
}
}
c.store[scope] = cert c.store[scope] = cert
} }
@ -53,6 +67,7 @@ func (c *CertificateStore) Lookup(scope string) (*tls.Certificate, error) {
// in the form scope.crt and scope.key. // in the form scope.crt and scope.key.
// For example, the hostname "localhost" would have the corresponding files // For example, the hostname "localhost" would have the corresponding files
// localhost.crt (certificate) and localhost.key (private key). // localhost.crt (certificate) and localhost.key (private key).
// New certificates will be written to this directory.
func (c *CertificateStore) Load(path string) error { func (c *CertificateStore) Load(path string) error {
matches, err := filepath.Glob(filepath.Join(path, "*.crt")) matches, err := filepath.Glob(filepath.Join(path, "*.crt"))
if err != nil { if err != nil {
@ -67,6 +82,8 @@ func (c *CertificateStore) Load(path string) error {
scope := strings.TrimSuffix(filepath.Base(crtPath), ".crt") scope := strings.TrimSuffix(filepath.Base(crtPath), ".crt")
c.Add(scope, cert) c.Add(scope, cert)
} }
c.dir = true
c.path = path
return nil return nil
} }
@ -133,3 +150,30 @@ func newX509KeyPair(options CertificateOptions) (*x509.Certificate, crypto.Priva
} }
return cert, priv, nil return cert, priv, nil
} }
// WriteCertificate writes the provided certificate and private key
// to certPath and keyPath respectively.
func WriteCertificate(cert tls.Certificate, certPath, keyPath string) error {
certOut, err := os.OpenFile(certPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
if err != nil {
return err
}
defer certOut.Close()
if err := pem.Encode(certOut, &pem.Block{
Type: "CERTIFICATE",
Bytes: cert.Leaf.Raw,
}); err != nil {
return err
}
keyOut, err := os.OpenFile(keyPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
if err != nil {
return err
}
defer keyOut.Close()
privBytes, err := x509.MarshalPKCS8PrivateKey(cert.PrivateKey)
if err != nil {
return err
}
return pem.Encode(keyOut, &pem.Block{Type: "PRIVATE KEY", Bytes: privBytes})
}

View File

@ -3,14 +3,8 @@
package main package main
import ( import (
"crypto"
"crypto/tls" "crypto/tls"
"crypto/x509"
"encoding/pem"
"fmt"
"io"
"log" "log"
"os"
"time" "time"
"git.sr.ht/~adnano/go-gemini" "git.sr.ht/~adnano/go-gemini"
@ -22,16 +16,10 @@ func main() {
log.Fatal(err) log.Fatal(err)
} }
server.CreateCertificate = func(hostname string) (tls.Certificate, error) { server.CreateCertificate = func(hostname string) (tls.Certificate, error) {
fmt.Println("Generating certificate for", hostname) return gemini.CreateCertificate(gemini.CertificateOptions{
cert, err := gemini.CreateCertificate(gemini.CertificateOptions{
DNSNames: []string{hostname}, DNSNames: []string{hostname},
Duration: time.Minute, // for testing purposes Duration: time.Minute, // for testing purposes
}) })
if err == nil {
// Write the new certificate to disk
err = writeCertificate("/var/lib/gemini/certs/"+hostname, cert)
}
return cert, err
} }
var mux gemini.ServeMux var mux gemini.ServeMux
@ -42,39 +30,3 @@ func main() {
log.Fatal(err) log.Fatal(err)
} }
} }
// writeCertificate writes the provided certificate and private key
// to path.crt and path.key respectively.
func writeCertificate(path string, cert tls.Certificate) error {
// Write the certificate
crtPath := path + ".crt"
crtOut, err := os.OpenFile(crtPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
if err != nil {
return err
}
if err := marshalX509Certificate(crtOut, cert.Leaf.Raw); err != nil {
return err
}
// Write the private key
keyPath := path + ".key"
keyOut, err := os.OpenFile(keyPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
if err != nil {
return err
}
return marshalPrivateKey(keyOut, cert.PrivateKey)
}
// marshalX509Certificate writes a PEM-encoded version of the given certificate.
func marshalX509Certificate(w io.Writer, cert []byte) error {
return pem.Encode(w, &pem.Block{Type: "CERTIFICATE", Bytes: cert})
}
// marshalPrivateKey writes a PEM-encoded version of the given private key.
func marshalPrivateKey(w io.Writer, priv crypto.PrivateKey) error {
privBytes, err := x509.MarshalPKCS8PrivateKey(priv)
if err != nil {
return err
}
return pem.Encode(w, &pem.Block{Type: "PRIVATE KEY", Bytes: privBytes})
}

View File

@ -39,7 +39,7 @@ type responderKey struct {
// //
// Patterns must be in the form of hostname or scheme://hostname // Patterns must be in the form of hostname or scheme://hostname
// (e.g. gemini://example.com). // (e.g. gemini://example.com).
// If no scheme is specified, a default scheme of gemini:// is assumed. // If no scheme is specified, a default scheme of gemini:// is implied.
// //
// Wildcard patterns are supported (e.g. *.example.com). // Wildcard patterns are supported (e.g. *.example.com).
// To register a certificate for a wildcard hostname, call Certificates.Add: // To register a certificate for a wildcard hostname, call Certificates.Add:
@ -132,27 +132,29 @@ func (s *Server) Serve(l net.Listener) error {
} }
func (s *Server) getCertificate(h *tls.ClientHelloInfo) (*tls.Certificate, error) { func (s *Server) getCertificate(h *tls.ClientHelloInfo) (*tls.Certificate, error) {
hostname := h.ServerName cert, err := s.getCertificateFor(h.ServerName)
cert, err := s.Certificates.Lookup(hostname) if err != nil {
if err == ErrCertificateUnknown { // Try wildcard
wildcard := strings.SplitN(h.ServerName, ".", 2) wildcard := strings.SplitN(h.ServerName, ".", 2)
if len(wildcard) == 2 { if len(wildcard) == 2 {
cert, err = s.Certificates.Lookup("*." + wildcard[1]) cert, err = s.getCertificateFor("*." + wildcard[1])
} }
} }
return cert, err
}
func (s *Server) getCertificateFor(hostname string) (*tls.Certificate, error) {
if _, ok := s.hosts[hostname]; !ok {
return nil, ErrCertificateUnknown
}
cert, err := s.Certificates.Lookup(hostname)
switch err { switch err {
case ErrCertificateUnknown: case ErrCertificateUnknown, ErrCertificateExpired:
if _, ok := s.hosts[hostname]; !ok {
break
}
fallthrough
case ErrCertificateExpired:
if s.CreateCertificate != nil { if s.CreateCertificate != nil {
cert, err := s.CreateCertificate(h.ServerName) cert, err := s.CreateCertificate(hostname)
if err == nil { if err == nil {
s.Certificates.Add(h.ServerName, cert) s.Certificates.Add(hostname, cert)
} }
return &cert, err return &cert, err
} }