Make KnownHostsFile safe for concurrent use

This commit is contained in:
Adnan Maolood 2020-12-17 16:49:59 -05:00
parent 846fa2ac41
commit e687a05170

17
tofu.go
View File

@ -25,18 +25,25 @@ type KnownHosts map[string]Fingerprint
// KnownHostsFile represents a list of known hosts optionally loaded from a file. // KnownHostsFile represents a list of known hosts optionally loaded from a file.
// The zero value for KnownHostsFile represents an empty list ready to use. // The zero value for KnownHostsFile represents an empty list ready to use.
//
// KnownHostsFile is safe for concurrent use by multiple goroutines.
type KnownHostsFile struct { type KnownHostsFile struct {
KnownHosts KnownHosts
out io.Writer out io.Writer
mu sync.Mutex
} }
// SetOutput sets the output to which new known hosts will be written to. // SetOutput sets the output to which new known hosts will be written to.
func (k *KnownHostsFile) SetOutput(w io.Writer) { func (k *KnownHostsFile) SetOutput(w io.Writer) {
k.mu.Lock()
defer k.mu.Unlock()
k.out = w k.out = w
} }
// Add adds a known host to the list of known hosts. // Add adds a known host to the list of known hosts.
func (k *KnownHostsFile) Add(hostname string, fingerprint Fingerprint) { func (k *KnownHostsFile) Add(hostname string, fingerprint Fingerprint) {
k.mu.Lock()
defer k.mu.Unlock()
if k.KnownHosts == nil { if k.KnownHosts == nil {
k.KnownHosts = KnownHosts{} k.KnownHosts = KnownHosts{}
} }
@ -46,12 +53,16 @@ func (k *KnownHostsFile) Add(hostname string, fingerprint Fingerprint) {
// Lookup returns the fingerprint of the certificate corresponding to // Lookup returns the fingerprint of the certificate corresponding to
// the given hostname. // the given hostname.
func (k *KnownHostsFile) Lookup(hostname string) (Fingerprint, bool) { func (k *KnownHostsFile) Lookup(hostname string) (Fingerprint, bool) {
k.mu.RLock()
defer k.mu.RUnlock()
c, ok := k.KnownHosts[hostname] c, ok := k.KnownHosts[hostname]
return c, ok return c, ok
} }
// Write writes a known hosts entry to the configured output. // Write writes a known hosts entry to the configured output.
func (k *KnownHostsFile) Write(hostname string, fingerprint Fingerprint) { func (k *KnownHostsFile) Write(hostname string, fingerprint Fingerprint) {
k.mu.RLock()
defer k.mu.RUnlock()
if k.out != nil { if k.out != nil {
k.writeKnownHost(k.out, hostname, fingerprint) k.writeKnownHost(k.out, hostname, fingerprint)
} }
@ -59,6 +70,8 @@ func (k *KnownHostsFile) Write(hostname string, fingerprint Fingerprint) {
// WriteAll writes all of the known hosts to the provided io.Writer. // WriteAll writes all of the known hosts to the provided io.Writer.
func (k *KnownHostsFile) WriteAll(w io.Writer) error { func (k *KnownHostsFile) WriteAll(w io.Writer) error {
k.mu.RLock()
defer k.mu.RUnlock()
for h, c := range k.KnownHosts { for h, c := range k.KnownHosts {
if _, err := k.writeKnownHost(w, h, c); err != nil { if _, err := k.writeKnownHost(w, h, c); err != nil {
return err return err
@ -76,6 +89,8 @@ func (k *KnownHostsFile) writeKnownHost(w io.Writer, hostname string, f Fingerpr
// It creates the file if it does not exist. // It creates the file if it does not exist.
// New known hosts will be appended to the file. // New known hosts will be appended to the file.
func (k *KnownHostsFile) Load(path string) error { func (k *KnownHostsFile) Load(path string) error {
k.mu.Lock()
defer k.mu.Unlock()
f, err := os.OpenFile(path, os.O_CREATE|os.O_RDONLY, 0644) f, err := os.OpenFile(path, os.O_CREATE|os.O_RDONLY, 0644)
if err != nil { if err != nil {
return err return err
@ -94,6 +109,8 @@ func (k *KnownHostsFile) Load(path string) error {
// Parse parses the provided reader and adds the parsed known hosts to the list. // Parse parses the provided reader and adds the parsed known hosts to the list.
// Invalid entries are ignored. // Invalid entries are ignored.
func (k *KnownHostsFile) Parse(r io.Reader) { func (k *KnownHostsFile) Parse(r io.Reader) {
k.mu.Lock()
defer k.mu.Unlock()
if k.KnownHosts == nil { if k.KnownHosts == nil {
k.KnownHosts = map[string]Fingerprint{} k.KnownHosts = map[string]Fingerprint{}
} }