Make KnownHostsFile safe for concurrent use

This commit is contained in:
Adnan Maolood 2020-12-17 16:49:59 -05:00
parent 846fa2ac41
commit e687a05170

17
tofu.go
View File

@ -25,18 +25,25 @@ type KnownHosts map[string]Fingerprint
// KnownHostsFile represents a list of known hosts optionally loaded from a file.
// The zero value for KnownHostsFile represents an empty list ready to use.
//
// KnownHostsFile is safe for concurrent use by multiple goroutines.
type KnownHostsFile struct {
KnownHosts
out io.Writer
mu sync.Mutex
}
// SetOutput sets the output to which new known hosts will be written to.
func (k *KnownHostsFile) SetOutput(w io.Writer) {
k.mu.Lock()
defer k.mu.Unlock()
k.out = w
}
// Add adds a known host to the list of known hosts.
func (k *KnownHostsFile) Add(hostname string, fingerprint Fingerprint) {
k.mu.Lock()
defer k.mu.Unlock()
if k.KnownHosts == nil {
k.KnownHosts = KnownHosts{}
}
@ -46,12 +53,16 @@ func (k *KnownHostsFile) Add(hostname string, fingerprint Fingerprint) {
// Lookup returns the fingerprint of the certificate corresponding to
// the given hostname.
func (k *KnownHostsFile) Lookup(hostname string) (Fingerprint, bool) {
k.mu.RLock()
defer k.mu.RUnlock()
c, ok := k.KnownHosts[hostname]
return c, ok
}
// Write writes a known hosts entry to the configured output.
func (k *KnownHostsFile) Write(hostname string, fingerprint Fingerprint) {
k.mu.RLock()
defer k.mu.RUnlock()
if k.out != nil {
k.writeKnownHost(k.out, hostname, fingerprint)
}
@ -59,6 +70,8 @@ func (k *KnownHostsFile) Write(hostname string, fingerprint Fingerprint) {
// WriteAll writes all of the known hosts to the provided io.Writer.
func (k *KnownHostsFile) WriteAll(w io.Writer) error {
k.mu.RLock()
defer k.mu.RUnlock()
for h, c := range k.KnownHosts {
if _, err := k.writeKnownHost(w, h, c); err != nil {
return err
@ -76,6 +89,8 @@ func (k *KnownHostsFile) writeKnownHost(w io.Writer, hostname string, f Fingerpr
// It creates the file if it does not exist.
// New known hosts will be appended to the file.
func (k *KnownHostsFile) Load(path string) error {
k.mu.Lock()
defer k.mu.Unlock()
f, err := os.OpenFile(path, os.O_CREATE|os.O_RDONLY, 0644)
if err != nil {
return err
@ -94,6 +109,8 @@ func (k *KnownHostsFile) Load(path string) error {
// Parse parses the provided reader and adds the parsed known hosts to the list.
// Invalid entries are ignored.
func (k *KnownHostsFile) Parse(r io.Reader) {
k.mu.Lock()
defer k.mu.Unlock()
if k.KnownHosts == nil {
k.KnownHosts = map[string]Fingerprint{}
}