diff --git a/cert.go b/cert.go index 39d7f52..0fba5fd 100644 --- a/cert.go +++ b/cert.go @@ -15,7 +15,40 @@ import ( ) // CertificateStore maps hostnames to certificates. -type CertificateStore map[string]*tls.Certificate +type CertificateStore struct { + store map[string]tls.Certificate +} + +// Add adds a certificate for the given hostname to the store. +func (c *CertificateStore) Add(hostname string, cert tls.Certificate) { + if c.store == nil { + c.store = map[string]tls.Certificate{} + } + c.store[hostname] = cert +} + +// Lookup returns the certificate for the given hostname. +func (c *CertificateStore) Lookup(hostname string) (*tls.Certificate, error) { + cert, ok := c.store[hostname] + if !ok { + return nil, ErrUnknownCertificate + } + // TODO: Ensure that the certificate is not expired + // if expired { + // return nil, ErrInvalidCertificate + // } + return &cert, nil +} + +// Load loads certificates from the given path. +// The path should lead to a directory containing certificates and private keys +// in the form hostname.crt and hostname.key. +// For example, the hostname "localhost" would have the corresponding files +// localhost.crt (certificate) and localhost.key (private key). +func (c *CertificateStore) Load(path string) error { + // TODO: Implement this + return nil +} // NewCertificate creates and returns a new parsed certificate. func NewCertificate(host string, duration time.Duration) (tls.Certificate, error) { diff --git a/client.go b/client.go index ff968e4..5a2512c 100644 --- a/client.go +++ b/client.go @@ -218,10 +218,8 @@ func (c *Client) Send(req *Request) (*Response, error) { return req.Certificate, nil } // If we have already stored the certificate, return it - if c.CertificateStore != nil { - if cert, ok := c.CertificateStore[req.Hostname()]; ok { - return cert, nil - } + if cert, err := c.CertificateStore.Lookup(req.Hostname()); err == nil { + return cert, nil } return &tls.Certificate{}, nil }, @@ -280,10 +278,6 @@ func (c *Client) Send(req *Request) (*Response, error) { if req.Certificate != nil { return resp, nil } - // Create the certificate store if it does not exist - if c.CertificateStore == nil { - c.CertificateStore = CertificateStore{} - } if c.GetCertificate != nil { if cert := c.GetCertificate(req.Hostname(), c.CertificateStore); cert != nil { req.Certificate = cert diff --git a/examples/auth/auth.go b/examples/auth/auth.go index d9ad463..1dc5697 100644 --- a/examples/auth/auth.go +++ b/examples/auth/auth.go @@ -52,9 +52,8 @@ func main() { handler.HandleFunc("/admin", admin) handler.HandleFunc("/logout", logout) - server := &gmi.Server{ - Certificate: cert, - } + server := &gmi.Server{} + server.CertificateStore.Add("localhost", cert) server.Handle("localhost", handler) if err := server.ListenAndServe(); err != nil { diff --git a/examples/client/client.go b/examples/client/client.go index 4d61868..ee65c9d 100644 --- a/examples/client/client.go +++ b/examples/client/client.go @@ -48,7 +48,7 @@ func init() { } client.GetCertificate = func(hostname string, store gmi.CertificateStore) *tls.Certificate { // If the certificate is in the store, return it - if cert, ok := store[hostname]; ok { + if cert, err := store.Lookup(hostname); err == nil { return cert } // Otherwise, generate a certificate @@ -59,7 +59,7 @@ func init() { return nil } // Store and return the certificate - store[hostname] = &cert + store.Add(hostname, cert) return &cert } } diff --git a/examples/server/server.go b/examples/server/server.go index 80341b9..df405a1 100644 --- a/examples/server/server.go +++ b/examples/server/server.go @@ -23,9 +23,8 @@ func main() { mux := &gmi.ServeMux{} mux.Handle("/", gmi.FileServer(gmi.Dir("/var/www"))) - server := gmi.Server{ - Certificate: cert, - } + server := gmi.Server{} + server.CertificateStore.Add("localhost", cert) server.Handle("localhost", mux) server.ListenAndServe() } diff --git a/gemini.go b/gemini.go index e3f2562..f03af9e 100644 --- a/gemini.go +++ b/gemini.go @@ -58,7 +58,7 @@ func init() { } DefaultClient.GetCertificate = func(hostname string, store CertificateStore) *tls.Certificate { // If the certificate is in the store, return it - if cert, ok := store[hostname]; ok { + if cert, err := store.Lookup(hostname); err == nil { return cert } // Otherwise, generate a certificate @@ -68,7 +68,7 @@ func init() { return nil } // Store and return the certificate - store[hostname] = &cert + store.Add(hostname, cert) return &cert } } diff --git a/server.go b/server.go index f21863f..a131fba 100644 --- a/server.go +++ b/server.go @@ -27,9 +27,13 @@ type Server struct { // If Addr is empty, the server will listen on the address ":1965". Addr string - // Certificate provides a TLS certificate for use by the server. - // A self-signed certificate is recommended. - Certificate tls.Certificate + // CertificateStore contains the certificates used by the server. + CertificateStore CertificateStore + + // GetCertificate, if not nil, will be called to retrieve the certificate + // to use for a given hostname. + // If the certificate is nil, the connection will be aborted. + GetCertificate func(hostname string) *tls.Certificate // registered handlers handlers []handlerEntry @@ -44,7 +48,6 @@ func (s *Server) Handle(host string, handler Handler) { if handler == nil { panic("gmi: nil handler") } - s.HandleScheme("gemini", host, handler) } @@ -79,8 +82,13 @@ func (s *Server) ListenAndServe() error { config := &tls.Config{ InsecureSkipVerify: true, MinVersion: tls.VersionTLS12, - Certificates: []tls.Certificate{s.Certificate}, - ClientAuth: tls.RequestClientCert, + GetCertificate: func(h *tls.ClientHelloInfo) (*tls.Certificate, error) { + if s.GetCertificate != nil { + return s.GetCertificate(h.ServerName), nil + } + return s.CertificateStore.Lookup(h.ServerName) + }, + ClientAuth: tls.RequestClientCert, } tlsListener := tls.NewListener(ln, config) return s.Serve(tlsListener)