From aeafd57956a73899c207e3577ea4dadda07567aa Mon Sep 17 00:00:00 2001 From: Adnan Maolood Date: Thu, 17 Dec 2020 16:52:08 -0500 Subject: [PATCH] Make CertificateDir safe for concurrent use by multiple goroutines --- cert.go | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/cert.go b/cert.go index 7cadbd9..d129b6d 100644 --- a/cert.go +++ b/cert.go @@ -23,15 +23,20 @@ type CertificateStore map[string]tls.Certificate // CertificateDir represents a certificate store optionally loaded from a directory. // The zero value of CertificateDir is an empty store ready to use. +// +// CertificateDir is safe for concurrent use by multiple goroutines. type CertificateDir struct { CertificateStore dir bool path string + mu sync.Mutex } // Add adds a certificate for the given scope to the store. // It tries to parse the certificate if it is not already parsed. func (c *CertificateDir) Add(scope string, cert tls.Certificate) { + c.mu.Lock() + defer c.mu.Unlock() if c.CertificateStore == nil { c.CertificateStore = CertificateStore{} } @@ -47,6 +52,8 @@ func (c *CertificateDir) Add(scope string, cert tls.Certificate) { // Write writes the provided certificate to the certificate directory. func (c *CertificateDir) Write(scope string, cert tls.Certificate) error { + c.mu.RLock() + defer c.mu.RUnlock() if c.dir { // Escape slash character scope = strings.ReplaceAll(scope, "/", ":") @@ -61,6 +68,8 @@ func (c *CertificateDir) Write(scope string, cert tls.Certificate) error { // Lookup returns the certificate for the given scope. func (c *CertificateDir) Lookup(scope string) (tls.Certificate, bool) { + c.mu.RLock() + defer c.mu.RUnlock() cert, ok := c.CertificateStore[scope] return cert, ok } @@ -72,6 +81,8 @@ func (c *CertificateDir) Lookup(scope string) (tls.Certificate, bool) { // localhost.crt (certificate) and localhost.key (private key). // New certificates will be written to this directory. func (c *CertificateDir) Load(path string) error { + c.mu.Lock() + defer c.mu.Unlock() matches, err := filepath.Glob(filepath.Join(path, "*.crt")) if err != nil { return err @@ -94,6 +105,8 @@ func (c *CertificateDir) Load(path string) error { // SetDir sets the directory that new certificates will be written to. func (c *CertificateDir) SetDir(path string) { + c.mu.Lock() + defer c.mu.Unlock() c.dir = true c.path = path }