From e687a05170a8f49754cef51430296ce4679fac1d Mon Sep 17 00:00:00 2001 From: Adnan Maolood Date: Thu, 17 Dec 2020 16:49:59 -0500 Subject: [PATCH] Make KnownHostsFile safe for concurrent use --- tofu.go | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/tofu.go b/tofu.go index 0c15504..b410863 100644 --- a/tofu.go +++ b/tofu.go @@ -25,18 +25,25 @@ type KnownHosts map[string]Fingerprint // KnownHostsFile represents a list of known hosts optionally loaded from a file. // The zero value for KnownHostsFile represents an empty list ready to use. +// +// KnownHostsFile is safe for concurrent use by multiple goroutines. type KnownHostsFile struct { KnownHosts out io.Writer + mu sync.Mutex } // SetOutput sets the output to which new known hosts will be written to. func (k *KnownHostsFile) SetOutput(w io.Writer) { + k.mu.Lock() + defer k.mu.Unlock() k.out = w } // Add adds a known host to the list of known hosts. func (k *KnownHostsFile) Add(hostname string, fingerprint Fingerprint) { + k.mu.Lock() + defer k.mu.Unlock() if k.KnownHosts == nil { k.KnownHosts = KnownHosts{} } @@ -46,12 +53,16 @@ func (k *KnownHostsFile) Add(hostname string, fingerprint Fingerprint) { // Lookup returns the fingerprint of the certificate corresponding to // the given hostname. func (k *KnownHostsFile) Lookup(hostname string) (Fingerprint, bool) { + k.mu.RLock() + defer k.mu.RUnlock() c, ok := k.KnownHosts[hostname] return c, ok } // Write writes a known hosts entry to the configured output. func (k *KnownHostsFile) Write(hostname string, fingerprint Fingerprint) { + k.mu.RLock() + defer k.mu.RUnlock() if k.out != nil { k.writeKnownHost(k.out, hostname, fingerprint) } @@ -59,6 +70,8 @@ func (k *KnownHostsFile) Write(hostname string, fingerprint Fingerprint) { // WriteAll writes all of the known hosts to the provided io.Writer. func (k *KnownHostsFile) WriteAll(w io.Writer) error { + k.mu.RLock() + defer k.mu.RUnlock() for h, c := range k.KnownHosts { if _, err := k.writeKnownHost(w, h, c); err != nil { return err @@ -76,6 +89,8 @@ func (k *KnownHostsFile) writeKnownHost(w io.Writer, hostname string, f Fingerpr // It creates the file if it does not exist. // New known hosts will be appended to the file. func (k *KnownHostsFile) Load(path string) error { + k.mu.Lock() + defer k.mu.Unlock() f, err := os.OpenFile(path, os.O_CREATE|os.O_RDONLY, 0644) if err != nil { return err @@ -94,6 +109,8 @@ func (k *KnownHostsFile) Load(path string) error { // Parse parses the provided reader and adds the parsed known hosts to the list. // Invalid entries are ignored. func (k *KnownHostsFile) Parse(r io.Reader) { + k.mu.Lock() + defer k.mu.Unlock() if k.KnownHosts == nil { k.KnownHosts = map[string]Fingerprint{} }