certificate.Store: Don't check parent scopes in Lookup

Limit the scopes of client certificates to hostnames only instead of
hostnames and paths.
This commit is contained in:
Adnan Maolood 2021-03-06 12:57:42 -05:00
parent d1cb8967b6
commit 504da9afd8

View File

@ -6,9 +6,7 @@ import (
"crypto/x509/pkix" "crypto/x509/pkix"
"errors" "errors"
"fmt" "fmt"
"io/fs"
"os" "os"
"path"
"path/filepath" "path/filepath"
"strings" "strings"
"sync" "sync"
@ -23,8 +21,7 @@ import (
// Servers will most likely use the methods Register, Load and Get. // Servers will most likely use the methods Register, Load and Get.
// //
// Store can also be used to store client certificates. // Store can also be used to store client certificates.
// Clients should provide the hostname and path of a URL as a certificate scope // Clients should provide a hostname as a certificate scope.
// (without a trailing slash).
// Clients will most likely use the methods Add, Load, and Lookup. // Clients will most likely use the methods Add, Load, and Lookup.
// //
// Store is safe for concurrent use by multiple goroutines. // Store is safe for concurrent use by multiple goroutines.
@ -142,25 +139,10 @@ func (s *Store) Get(hostname string) (*tls.Certificate, error) {
} }
// Lookup returns the certificate for the provided scope. // Lookup returns the certificate for the provided scope.
// Lookup also checks for certificates in parent scopes.
// For example, given the scope "example.com/a/b/c", Lookup will first check
// "example.com/a/b/c", then "example.com/a/b", then "example.com/a", and
// finally "example.com" for a certificate. As a result, a certificate with
// scope "example.com" will match all scopes beginning with "example.com".
func (s *Store) Lookup(scope string) (tls.Certificate, bool) { func (s *Store) Lookup(scope string) (tls.Certificate, bool) {
s.mu.RLock() s.mu.RLock()
defer s.mu.RUnlock() defer s.mu.RUnlock()
cert, ok := s.certs[scope] cert, ok := s.certs[scope]
if !ok {
scope = path.Dir(scope)
for scope != "." {
cert, ok = s.certs[scope]
if ok {
break
}
scope = path.Dir(scope)
}
}
return cert, ok return cert, ok
} }
@ -183,7 +165,11 @@ func (s *Store) createCertificate(scope string) (tls.Certificate, error) {
// and private keys named "scope.crt" and "scope.key" respectively, // and private keys named "scope.crt" and "scope.key" respectively,
// where "scope" is the scope of the certificate. // where "scope" is the scope of the certificate.
func (s *Store) Load(path string) error { func (s *Store) Load(path string) error {
matches := findCertificates(path) path = filepath.Clean(path)
matches, err := filepath.Glob(filepath.Join(path, "*.crt"))
if err != nil {
return err
}
for _, crtPath := range matches { for _, crtPath := range matches {
keyPath := strings.TrimSuffix(crtPath, ".crt") + ".key" keyPath := strings.TrimSuffix(crtPath, ".crt") + ".key"
cert, err := tls.LoadX509KeyPair(crtPath, keyPath) cert, err := tls.LoadX509KeyPair(crtPath, keyPath)
@ -191,26 +177,17 @@ func (s *Store) Load(path string) error {
continue continue
} }
scope := filepath.Clean(crtPath) scope := strings.TrimPrefix(crtPath, path)
scope = strings.TrimPrefix(crtPath, filepath.Clean(path))
scope = strings.TrimPrefix(scope, "/") scope = strings.TrimPrefix(scope, "/")
scope = strings.TrimSuffix(scope, ".crt") scope = strings.TrimSuffix(scope, ".crt")
s.Add(scope, cert) s.Add(scope, cert)
} }
s.SetPath(path) s.mu.Lock()
defer s.mu.Unlock()
s.path = path
return nil return nil
} }
func findCertificates(path string) (matches []string) {
filepath.Walk(path, func(path string, _ fs.FileInfo, err error) error {
if filepath.Ext(path) == ".crt" {
matches = append(matches, path)
}
return nil
})
return
}
// Entries returns a map of scopes to certificates. // Entries returns a map of scopes to certificates.
func (s *Store) Entries() map[string]tls.Certificate { func (s *Store) Entries() map[string]tls.Certificate {
s.mu.RLock() s.mu.RLock()
@ -226,5 +203,5 @@ func (s *Store) Entries() map[string]tls.Certificate {
func (s *Store) SetPath(path string) { func (s *Store) SetPath(path string) {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
s.path = path s.path = filepath.Clean(path)
} }