Make KnownHostsFile safe for concurrent use
This commit is contained in:
parent
846fa2ac41
commit
e687a05170
17
tofu.go
17
tofu.go
@ -25,18 +25,25 @@ type KnownHosts map[string]Fingerprint
|
|||||||
|
|
||||||
// KnownHostsFile represents a list of known hosts optionally loaded from a file.
|
// KnownHostsFile represents a list of known hosts optionally loaded from a file.
|
||||||
// The zero value for KnownHostsFile represents an empty list ready to use.
|
// The zero value for KnownHostsFile represents an empty list ready to use.
|
||||||
|
//
|
||||||
|
// KnownHostsFile is safe for concurrent use by multiple goroutines.
|
||||||
type KnownHostsFile struct {
|
type KnownHostsFile struct {
|
||||||
KnownHosts
|
KnownHosts
|
||||||
out io.Writer
|
out io.Writer
|
||||||
|
mu sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetOutput sets the output to which new known hosts will be written to.
|
// SetOutput sets the output to which new known hosts will be written to.
|
||||||
func (k *KnownHostsFile) SetOutput(w io.Writer) {
|
func (k *KnownHostsFile) SetOutput(w io.Writer) {
|
||||||
|
k.mu.Lock()
|
||||||
|
defer k.mu.Unlock()
|
||||||
k.out = w
|
k.out = w
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add adds a known host to the list of known hosts.
|
// Add adds a known host to the list of known hosts.
|
||||||
func (k *KnownHostsFile) Add(hostname string, fingerprint Fingerprint) {
|
func (k *KnownHostsFile) Add(hostname string, fingerprint Fingerprint) {
|
||||||
|
k.mu.Lock()
|
||||||
|
defer k.mu.Unlock()
|
||||||
if k.KnownHosts == nil {
|
if k.KnownHosts == nil {
|
||||||
k.KnownHosts = KnownHosts{}
|
k.KnownHosts = KnownHosts{}
|
||||||
}
|
}
|
||||||
@ -46,12 +53,16 @@ func (k *KnownHostsFile) Add(hostname string, fingerprint Fingerprint) {
|
|||||||
// Lookup returns the fingerprint of the certificate corresponding to
|
// Lookup returns the fingerprint of the certificate corresponding to
|
||||||
// the given hostname.
|
// the given hostname.
|
||||||
func (k *KnownHostsFile) Lookup(hostname string) (Fingerprint, bool) {
|
func (k *KnownHostsFile) Lookup(hostname string) (Fingerprint, bool) {
|
||||||
|
k.mu.RLock()
|
||||||
|
defer k.mu.RUnlock()
|
||||||
c, ok := k.KnownHosts[hostname]
|
c, ok := k.KnownHosts[hostname]
|
||||||
return c, ok
|
return c, ok
|
||||||
}
|
}
|
||||||
|
|
||||||
// Write writes a known hosts entry to the configured output.
|
// Write writes a known hosts entry to the configured output.
|
||||||
func (k *KnownHostsFile) Write(hostname string, fingerprint Fingerprint) {
|
func (k *KnownHostsFile) Write(hostname string, fingerprint Fingerprint) {
|
||||||
|
k.mu.RLock()
|
||||||
|
defer k.mu.RUnlock()
|
||||||
if k.out != nil {
|
if k.out != nil {
|
||||||
k.writeKnownHost(k.out, hostname, fingerprint)
|
k.writeKnownHost(k.out, hostname, fingerprint)
|
||||||
}
|
}
|
||||||
@ -59,6 +70,8 @@ func (k *KnownHostsFile) Write(hostname string, fingerprint Fingerprint) {
|
|||||||
|
|
||||||
// WriteAll writes all of the known hosts to the provided io.Writer.
|
// WriteAll writes all of the known hosts to the provided io.Writer.
|
||||||
func (k *KnownHostsFile) WriteAll(w io.Writer) error {
|
func (k *KnownHostsFile) WriteAll(w io.Writer) error {
|
||||||
|
k.mu.RLock()
|
||||||
|
defer k.mu.RUnlock()
|
||||||
for h, c := range k.KnownHosts {
|
for h, c := range k.KnownHosts {
|
||||||
if _, err := k.writeKnownHost(w, h, c); err != nil {
|
if _, err := k.writeKnownHost(w, h, c); err != nil {
|
||||||
return err
|
return err
|
||||||
@ -76,6 +89,8 @@ func (k *KnownHostsFile) writeKnownHost(w io.Writer, hostname string, f Fingerpr
|
|||||||
// It creates the file if it does not exist.
|
// It creates the file if it does not exist.
|
||||||
// New known hosts will be appended to the file.
|
// New known hosts will be appended to the file.
|
||||||
func (k *KnownHostsFile) Load(path string) error {
|
func (k *KnownHostsFile) Load(path string) error {
|
||||||
|
k.mu.Lock()
|
||||||
|
defer k.mu.Unlock()
|
||||||
f, err := os.OpenFile(path, os.O_CREATE|os.O_RDONLY, 0644)
|
f, err := os.OpenFile(path, os.O_CREATE|os.O_RDONLY, 0644)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@ -94,6 +109,8 @@ func (k *KnownHostsFile) Load(path string) error {
|
|||||||
// Parse parses the provided reader and adds the parsed known hosts to the list.
|
// Parse parses the provided reader and adds the parsed known hosts to the list.
|
||||||
// Invalid entries are ignored.
|
// Invalid entries are ignored.
|
||||||
func (k *KnownHostsFile) Parse(r io.Reader) {
|
func (k *KnownHostsFile) Parse(r io.Reader) {
|
||||||
|
k.mu.Lock()
|
||||||
|
defer k.mu.Unlock()
|
||||||
if k.KnownHosts == nil {
|
if k.KnownHosts == nil {
|
||||||
k.KnownHosts = map[string]Fingerprint{}
|
k.KnownHosts = map[string]Fingerprint{}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user