24 Commits

Author SHA1 Message Date
Adnan Maolood
31de8d49b0 Guarantee that (*Response).Body is not nil 2021-01-15 15:18:00 -05:00
Adnan Maolood
2b17f3d8eb fs: Remove unused import 2021-01-14 22:45:09 -05:00
Adnan Maolood
f36a1c5c87 client: Add note about TOFU 2021-01-14 22:34:12 -05:00
Adnan Maolood
af61c1b60a fs: Update comments 2021-01-14 22:27:56 -05:00
Adnan Maolood
ad18ae601c fs: Don't panic on mime.AddExtensionType error
It's probably best not to panic if this fails.
2021-01-14 22:25:09 -05:00
Adnan Maolood
8473f3b9d4 fs: Update comments 2021-01-14 22:24:26 -05:00
Adnan Maolood
06c53cc5b1 server: Rename Register to Handle 2021-01-14 22:12:07 -05:00
Adnan Maolood
4b643523fb Update examples 2021-01-14 21:23:13 -05:00
Adnan Maolood
79a4dfd43f certificate: Add Dir.Entries function 2021-01-14 21:19:27 -05:00
Adnan Maolood
14d89f304a Move cert.go to a subpackage 2021-01-14 20:42:12 -05:00
Adnan Maolood
7a00539f75 tofu: Fix example 2021-01-14 19:57:52 -05:00
Adnan Maolood
a0adc42c95 tofu: Update documentation 2021-01-14 19:56:04 -05:00
Adnan Maolood
da8af5dbcb tofu: Update documentation 2021-01-14 19:40:19 -05:00
Adnan Maolood
ced6b06d76 Update examples/auth.go 2021-01-14 19:04:11 -05:00
Adnan Maolood
4a0f8e5e73 tofu: Rename KnownHosts.Hosts to Entries 2021-01-14 18:52:43 -05:00
Adnan Maolood
e701ceff71 Add KnownHosts.Hosts function 2021-01-14 18:50:03 -05:00
Adnan Maolood
1a3974b3a3 Update examples/client.go 2021-01-14 17:28:03 -05:00
Adnan Maolood
3fd55c5cee tofu: Add KnownHosts.Load function 2021-01-14 17:09:31 -05:00
Adnan Maolood
6f11910dff tofu: Add NewHostsFile function 2021-01-14 16:54:38 -05:00
Adnan Maolood
da3e9ac0fe tofu: Protect HostWriter with a mutex 2021-01-14 16:35:54 -05:00
Adnan Maolood
9fe837ffac tofu: Refactor known hosts
This commit introduces the KnownHosts struct, whose purpose is simply to
store known hosts entries. The HostWriter struct is now in charge of
appending hosts to files, and the two are not dependent on each other.
Users are now responsible for opening the known hosts file and closing
it when they are finished with it.
2021-01-14 16:26:43 -05:00
Adnan Maolood
4b8bb16a3d tofu: Rename KnownHost to Host 2021-01-14 14:15:08 -05:00
Hugo Wetterberg
95aff9c573 tofu: Refactor
This commit changes underlying file handling and known hosts parsing.

A known hosts file opened through Load() never closed the underlying
file. During known hosts parsing most errors were unchecked, or just
led to the line being skipped.

I removed the KnownHosts type, which didn't really have a role after
the refactor. The embedding of KnownHosts in KnownHosts file has been
removed as it also leaked the map unprotected by the mutex.

The Fingerprint type is now KnownHost and has taken over the
responsibility of marshalling and unmarshalling.

SetOutput now takes a WriteCloser so that we can close the underlying
writer when it's replaced, or when it's explicitly closed through the
new Close() function.

KnownHostsFile.Add() now also writes the known host to the output if
set. I think that makes sense expectation-wise for the type.

Turned WriteAll() into WriteTo() to conform with the io.WriterTo
interface.

Load() is now Open() to better reflect the fact that a file is opened,
and kept open. It can now also return errors from the parsing process.

The parser does a lot more error checking, and this might be an area
where I've changed a desired behaviour as invalid entries no longer
are ignored, but aborts the parsing process. That could be changed to
a warning, or some kind of parsing feedback.

I added KnownHostsFile.TOFU() to fill the developer experience gap
that was left after the client no longer knows about
KnownHostsFile. It implements a basic non-interactive TOFU flow.
2021-01-14 13:48:57 -05:00
Hugo Wetterberg
de042e4724 client: set the client timout on the dialer, close connection on err
Client.Timout isn't respected for the dial. Requests will hang on dial
until OS-level timouts kick in unless there is a Request.Context with
a deadline. We also fail to close the connection on errors.

This change sets the client timeout as the dialer timeout so that it
will be respected. It also ensures that we close the connection if we
fail to make the request.
2021-01-13 17:13:56 -05:00
12 changed files with 449 additions and 210 deletions

View File

@@ -1,4 +1,5 @@
package gemini
// Package certificate provides utility functions for TLS certificates.
package certificate
import (
"crypto"
@@ -19,27 +20,23 @@ import (
"time"
)
// CertificateDir maps certificate scopes to certificates.
type CertificateStore map[string]tls.Certificate
// CertificateDir represents a certificate store optionally loaded from a directory.
// The zero value of CertificateDir is an empty store ready to use.
// Dir represents a directory of certificates.
// The zero value for Dir is an empty directory ready to use.
//
// CertificateDir is safe for concurrent use by multiple goroutines.
type CertificateDir struct {
CertificateStore
dir bool
path string
mu sync.RWMutex
// Dir is safe for concurrent use by multiple goroutines.
type Dir struct {
certs map[string]tls.Certificate
path *string
mu sync.RWMutex
}
// Add adds a certificate for the given scope to the store.
// Add adds a certificate for the given scope to the directory.
// It tries to parse the certificate if it is not already parsed.
func (c *CertificateDir) Add(scope string, cert tls.Certificate) {
c.mu.Lock()
defer c.mu.Unlock()
if c.CertificateStore == nil {
c.CertificateStore = CertificateStore{}
func (d *Dir) Add(scope string, cert tls.Certificate) error {
d.mu.Lock()
defer d.mu.Unlock()
if d.certs == nil {
d.certs = map[string]tls.Certificate{}
}
// Parse certificate if not already parsed
if cert.Leaf == nil {
@@ -48,40 +45,45 @@ func (c *CertificateDir) Add(scope string, cert tls.Certificate) {
cert.Leaf = parsed
}
}
c.CertificateStore[scope] = cert
}
// Write writes the provided certificate to the certificate directory.
func (c *CertificateDir) Write(scope string, cert tls.Certificate) error {
c.mu.RLock()
defer c.mu.RUnlock()
if c.dir {
if d.path != nil {
// Escape slash character
scope = strings.ReplaceAll(scope, "/", ":")
certPath := filepath.Join(c.path, scope+".crt")
keyPath := filepath.Join(c.path, scope+".key")
if err := WriteCertificate(cert, certPath, keyPath); err != nil {
certPath := filepath.Join(*d.path, scope+".crt")
keyPath := filepath.Join(*d.path, scope+".key")
if err := Write(cert, certPath, keyPath); err != nil {
return err
}
}
d.certs[scope] = cert
return nil
}
// Lookup returns the certificate for the given scope.
func (c *CertificateDir) Lookup(scope string) (tls.Certificate, bool) {
c.mu.RLock()
defer c.mu.RUnlock()
cert, ok := c.CertificateStore[scope]
// Lookup returns the certificate for the provided scope.
func (d *Dir) Lookup(scope string) (tls.Certificate, bool) {
d.mu.RLock()
defer d.mu.RUnlock()
cert, ok := d.certs[scope]
return cert, ok
}
// Load loads certificates from the given path.
// The path should lead to a directory containing certificates and private keys
// in the form scope.crt and scope.key.
// For example, the hostname "localhost" would have the corresponding files
// localhost.crt (certificate) and localhost.key (private key).
// New certificates will be written to this directory.
func (c *CertificateDir) Load(path string) error {
// Entries returns a map of hostnames to certificates.
func (d *Dir) Entries() map[string]tls.Certificate {
certs := map[string]tls.Certificate{}
for key := range d.certs {
certs[key] = d.certs[key]
}
return certs
}
// Load loads certificates from the provided path.
// Add will write certificates to this path.
//
// The directory should contain certificates and private keys
// named scope.crt and scope.key respectively, where scope is
// the scope of the certificate.
func (d *Dir) Load(path string) error {
matches, err := filepath.Glob(filepath.Join(path, "*.crt"))
if err != nil {
return err
@@ -95,31 +97,31 @@ func (c *CertificateDir) Load(path string) error {
scope := strings.TrimSuffix(filepath.Base(crtPath), ".crt")
// Unescape slash character
scope = strings.ReplaceAll(scope, ":", "/")
c.Add(scope, cert)
d.Add(scope, cert)
}
c.SetDir(path)
d.SetPath(path)
return nil
}
// SetDir sets the directory that new certificates will be written to.
func (c *CertificateDir) SetDir(path string) {
c.mu.Lock()
defer c.mu.Unlock()
c.dir = true
c.path = path
// SetPath sets the directory path.
// Add will write certificates to this path.
func (d *Dir) SetPath(path string) {
d.mu.Lock()
defer d.mu.Unlock()
d.path = &path
}
// CertificateOptions configures the creation of a certificate.
type CertificateOptions struct {
// Subject Alternate Name values.
// Should contain the IP addresses that the certificate is valid for.
IPAddresses []net.IP
// CreateOptions configures the creation of a TLS certificate.
type CreateOptions struct {
// Subject Alternate Name values.
// Should contain the DNS names that this certificate is valid for.
// E.g. example.com, *.example.com
DNSNames []string
// Subject Alternate Name values.
// Should contain the IP addresses that the certificate is valid for.
IPAddresses []net.IP
// Subject specifies the certificate Subject.
//
// Subject.CommonName can contain the DNS name that this certificate
@@ -136,8 +138,8 @@ type CertificateOptions struct {
Ed25519 bool
}
// CreateCertificate creates a new TLS certificate.
func CreateCertificate(options CertificateOptions) (tls.Certificate, error) {
// Create creates a new TLS certificate.
func Create(options CreateOptions) (tls.Certificate, error) {
crt, priv, err := newX509KeyPair(options)
if err != nil {
return tls.Certificate{}, err
@@ -150,7 +152,7 @@ func CreateCertificate(options CertificateOptions) (tls.Certificate, error) {
}
// newX509KeyPair creates and returns a new certificate and private key.
func newX509KeyPair(options CertificateOptions) (*x509.Certificate, crypto.PrivateKey, error) {
func newX509KeyPair(options CreateOptions) (*x509.Certificate, crypto.PrivateKey, error) {
var pub crypto.PublicKey
var priv crypto.PrivateKey
if options.Ed25519 {
@@ -206,9 +208,9 @@ func newX509KeyPair(options CertificateOptions) (*x509.Certificate, crypto.Priva
return cert, priv, nil
}
// WriteCertificate writes the provided certificate and private key
// Write writes the provided certificate and its private key
// to certPath and keyPath respectively.
func WriteCertificate(cert tls.Certificate, certPath, keyPath string) error {
func Write(cert tls.Certificate, certPath, keyPath string) error {
certOut, err := os.OpenFile(certPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
if err != nil {
return err

View File

@@ -19,6 +19,9 @@ type Client struct {
// If TrustCertificate is nil, the client will accept any certificate.
// If the returned error is not nil, the certificate will not be trusted
// and the request will be aborted.
//
// For a basic trust on first use implementation, see (*KnownHosts).TOFU
// in the tofu submodule.
TrustCertificate func(hostname string, cert *x509.Certificate) error
// Timeout specifies a time limit for requests made by this
@@ -68,24 +71,48 @@ func (c *Client) Do(req *Request) (*Response, error) {
if ctx == nil {
ctx = context.Background()
}
netConn, err := (&net.Dialer{}).DialContext(ctx, "tcp", req.Host)
start := time.Now()
dialer := net.Dialer{
Timeout: c.Timeout,
}
netConn, err := dialer.DialContext(ctx, "tcp", req.Host)
if err != nil {
return nil, err
}
conn := tls.Client(netConn, config)
// Set connection deadline
if c.Timeout != 0 {
err := conn.SetDeadline(time.Now().Add(c.Timeout))
err := conn.SetDeadline(start.Add(c.Timeout))
if err != nil {
return nil, fmt.Errorf(
"failed to set connection deadline: %w", err)
}
}
resp, err := c.do(conn, req)
if err != nil {
// If we fail to perform the request/response we have
// to take responsibility for closing the connection.
_ = conn.Close()
return nil, err
}
// Store connection state
resp.TLS = conn.ConnectionState()
return resp, nil
}
func (c *Client) do(conn *tls.Conn, req *Request) (*Response, error) {
// Write the request
w := bufio.NewWriter(conn)
err = req.Write(w)
err := req.Write(w)
if err != nil {
return nil, fmt.Errorf(
"failed to write request data: %w", err)
@@ -100,8 +127,6 @@ func (c *Client) Do(req *Request) (*Response, error) {
if err != nil {
return nil, err
}
// Store connection state
resp.TLS = conn.ConnectionState()
return resp, nil
}

5
doc.go
View File

@@ -8,10 +8,7 @@ Client is a Gemini client.
if err != nil {
// handle error
}
if resp.Body != nil {
defer resp.Body.Close()
// ...
}
defer resp.Body.Close()
// ...
Server is a Gemini server.

View File

@@ -12,6 +12,7 @@ import (
"time"
"git.sr.ht/~adnano/go-gemini"
"git.sr.ht/~adnano/go-gemini/certificate"
)
type User struct {
@@ -33,7 +34,7 @@ func main() {
log.Fatal(err)
}
server.CreateCertificate = func(hostname string) (tls.Certificate, error) {
return gemini.CreateCertificate(gemini.CertificateOptions{
return certificate.Create(certificate.CreateOptions{
Subject: pkix.Name{
CommonName: hostname,
},
@@ -41,7 +42,7 @@ func main() {
Duration: time.Hour,
})
}
server.Register("localhost", &mux)
server.Handle("localhost", &mux)
if err := server.ListenAndServe(); err != nil {
log.Fatal(err)
@@ -55,7 +56,7 @@ func fingerprint(cert *x509.Certificate) string {
func profile(w *gemini.ResponseWriter, r *gemini.Request) {
if r.Certificate == nil {
w.WriteStatus(gemini.StatusCertificateRequired)
w.Status(gemini.StatusCertificateRequired)
return
}
fingerprint := fingerprint(r.Certificate.Leaf)
@@ -70,13 +71,13 @@ func profile(w *gemini.ResponseWriter, r *gemini.Request) {
func changeUsername(w *gemini.ResponseWriter, r *gemini.Request) {
if r.Certificate == nil {
w.WriteStatus(gemini.StatusCertificateRequired)
w.Status(gemini.StatusCertificateRequired)
return
}
username, err := gemini.QueryUnescape(r.URL.RawQuery)
if err != nil || username == "" {
w.WriteHeader(gemini.StatusInput, "Username")
w.Header(gemini.StatusInput, "Username")
return
}
fingerprint := fingerprint(r.Certificate.Leaf)
@@ -86,5 +87,5 @@ func changeUsername(w *gemini.ResponseWriter, r *gemini.Request) {
users[fingerprint] = user
}
user.Name = username
w.WriteHeader(gemini.StatusRedirect, "/")
w.Header(gemini.StatusRedirect, "/")
}

View File

@@ -11,7 +11,7 @@ import (
"os"
"time"
"git.sr.ht/~adnano/go-gemini"
"git.sr.ht/~adnano/go-gemini/certificate"
)
func main() {
@@ -24,20 +24,20 @@ func main() {
if err != nil {
log.Fatal(err)
}
options := gemini.CertificateOptions{
options := certificate.CreateOptions{
Subject: pkix.Name{
CommonName: host,
},
DNSNames: []string{host},
Duration: duration,
}
cert, err := gemini.CreateCertificate(options)
cert, err := certificate.Create(options)
if err != nil {
log.Fatal(err)
}
certPath := host + ".crt"
keyPath := host + ".key"
if err := gemini.WriteCertificate(cert, certPath, keyPath); err != nil {
if err := certificate.Write(cert, certPath, keyPath); err != nil {
log.Fatal(err)
}
}

View File

@@ -6,6 +6,7 @@ package main
import (
"bufio"
"bytes"
"crypto/x509"
"errors"
"fmt"
@@ -22,8 +23,9 @@ import (
)
var (
hosts tofu.KnownHostsFile
scanner *bufio.Scanner
hosts tofu.KnownHosts
hostsfile *tofu.HostWriter
scanner *bufio.Scanner
)
func init() {
@@ -31,7 +33,12 @@ func init() {
path := filepath.Join(xdg.DataHome(), "gemini", "known_hosts")
err := hosts.Load(path)
if err != nil {
log.Println(err)
log.Fatal(err)
}
hostsfile, err = tofu.NewHostsFile(path)
if err != nil {
log.Fatal(err)
}
scanner = bufio.NewScanner(os.Stdin)
@@ -47,25 +54,26 @@ Otherwise, this should be safe to trust.
=> `
func trustCertificate(hostname string, cert *x509.Certificate) error {
fingerprint := tofu.NewFingerprint(cert.Raw, cert.NotAfter)
host := tofu.NewHost(hostname, cert.Raw, cert.NotAfter)
knownHost, ok := hosts.Lookup(hostname)
if ok && time.Now().Before(knownHost.Expires) {
// Check fingerprint
if knownHost.Hex == fingerprint.Hex {
if bytes.Equal(knownHost.Fingerprint, host.Fingerprint) {
return nil
}
return errors.New("error: fingerprint does not match!")
}
fmt.Printf(trustPrompt, hostname, fingerprint.Hex)
fmt.Printf(trustPrompt, hostname, host.Fingerprint)
scanner.Scan()
switch scanner.Text() {
case "t":
hosts.Add(hostname, fingerprint)
hosts.Write(hostname, fingerprint)
hosts.Add(host)
hostsfile.WriteHost(host)
return nil
case "o":
hosts.Add(hostname, fingerprint)
hosts.Add(host)
return nil
default:
return errors.New("certificate not trusted")
@@ -137,10 +145,10 @@ func main() {
fmt.Println(err)
os.Exit(1)
}
defer resp.Body.Close()
// Handle response
if resp.Status.Class() == gemini.StatusClassSuccess {
defer resp.Body.Close()
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
log.Fatal(err)

View File

@@ -11,6 +11,7 @@ import (
"time"
"git.sr.ht/~adnano/go-gemini"
"git.sr.ht/~adnano/go-gemini/certificate"
)
func main() {
@@ -21,7 +22,7 @@ func main() {
log.Fatal(err)
}
server.CreateCertificate = func(hostname string) (tls.Certificate, error) {
return gemini.CreateCertificate(gemini.CertificateOptions{
return certificate.Create(certificate.CreateOptions{
Subject: pkix.Name{
CommonName: hostname,
},
@@ -33,7 +34,7 @@ func main() {
var mux gemini.ServeMux
mux.Handle("/", gemini.FileServer(gemini.Dir("/var/www")))
server.Register("localhost", &mux)
server.Handle("localhost", &mux)
if err := server.ListenAndServe(); err != nil {
log.Fatal(err)
}

View File

@@ -13,6 +13,7 @@ import (
"time"
"git.sr.ht/~adnano/go-gemini"
"git.sr.ht/~adnano/go-gemini/certificate"
)
func main() {
@@ -21,7 +22,7 @@ func main() {
log.Fatal(err)
}
server.CreateCertificate = func(hostname string) (tls.Certificate, error) {
return gemini.CreateCertificate(gemini.CertificateOptions{
return certificate.Create(certificate.CreateOptions{
Subject: pkix.Name{
CommonName: hostname,
},
@@ -30,7 +31,7 @@ func main() {
})
}
server.RegisterFunc("localhost", stream)
server.HandleFunc("localhost", stream)
if err := server.ListenAndServe(); err != nil {
log.Fatal(err)
}

25
fs.go
View File

@@ -1,7 +1,6 @@
package gemini
import (
"fmt"
"io"
"mime"
"os"
@@ -10,17 +9,14 @@ import (
func init() {
// Add Gemini mime types
if err := mime.AddExtensionType(".gmi", "text/gemini"); err != nil {
panic(fmt.Errorf("failed to register .gmi extension mimetype: %w", err))
}
if err := mime.AddExtensionType(".gemini", "text/gemini"); err != nil {
panic(fmt.Errorf("failed to register .gemini extension mimetype: %w", err))
}
mime.AddExtensionType(".gmi", "text/gemini")
mime.AddExtensionType(".gemini", "text/gemini")
}
// FileServer takes a filesystem and returns a Responder which uses that filesystem.
// The returned Responder sanitizes paths before handling them.
// The returned Responder cleans paths before handling them.
//
// TODO: Use io/fs.FS when available.
func FileServer(fsys FS) Responder {
return fsHandler{fsys}
}
@@ -44,12 +40,16 @@ func (fsh fsHandler) Respond(w *ResponseWriter, r *Request) {
_, _ = io.Copy(w, f)
}
// TODO: replace with io/fs.FS when available
// FS represents a filesystem.
//
// TODO: Replace with io/fs.FS when available.
type FS interface {
Open(name string) (File, error)
}
// TODO: replace with io/fs.File when available
// File represents a file.
//
// TODO: Replace with io/fs.File when available.
type File interface {
Stat() (os.FileInfo, error)
Read([]byte) (int, error)
@@ -57,6 +57,8 @@ type File interface {
}
// Dir implements FS using the native filesystem restricted to a specific directory.
//
// TODO: replace with os.DirFS when available.
type Dir string
// Open tries to open the file with the given name.
@@ -68,6 +70,7 @@ func (d Dir) Open(name string) (File, error) {
// ServeFile responds to the request with the contents of the named file
// or directory.
//
// TODO: Use io/fs.FS when available.
func ServeFile(w *ResponseWriter, fs FS, name string) {
f, err := fs.Open(name)

View File

@@ -18,7 +18,10 @@ type Response struct {
// Meta should not be longer than 1024 bytes.
Meta string
// Body contains the response body for successful responses.
// Body represents the response body.
// Body is guaranteed to always be non-nil.
//
// The response body is streamed on demand as the Body field is read.
Body io.ReadCloser
// TLS contains information about the TLS connection on which the response
@@ -83,11 +86,22 @@ func ReadResponse(rc io.ReadCloser) (*Response, error) {
if resp.Status.Class() == StatusClassSuccess {
resp.Body = newReadCloserBody(br, rc)
} else {
resp.Body = nopReadCloser{}
rc.Close()
}
return resp, nil
}
type nopReadCloser struct{}
func (nopReadCloser) Read(p []byte) (int, error) {
return 0, io.EOF
}
func (nopReadCloser) Close() error {
return nil
}
type readCloserBody struct {
br *bufio.Reader // used until empty
io.ReadCloser

View File

@@ -7,6 +7,8 @@ import (
"net"
"strings"
"time"
"git.sr.ht/~adnano/go-gemini/certificate"
)
// Server is a Gemini server.
@@ -23,7 +25,7 @@ type Server struct {
WriteTimeout time.Duration
// Certificates contains the certificates used by the server.
Certificates CertificateDir
Certificates certificate.Dir
// CreateCertificate, if not nil, will be called to create a new certificate
// if the current one is expired or missing.
@@ -44,12 +46,12 @@ type responderKey struct {
hostname string
}
// Register registers a responder for the given pattern.
// Handle registers a responder for the given pattern.
//
// Patterns must be in the form of "hostname" or "scheme://hostname".
// The pattern must be in the form of "hostname" or "scheme://hostname".
// If no scheme is specified, a scheme of "gemini://" is implied.
// Wildcard patterns are supported (e.g. "*.example.com").
func (s *Server) Register(pattern string, responder Responder) {
func (s *Server) Handle(pattern string, responder Responder) {
if pattern == "" {
panic("gemini: invalid pattern")
}
@@ -78,9 +80,9 @@ func (s *Server) Register(pattern string, responder Responder) {
s.hosts[key.hostname] = true
}
// RegisterFunc registers a responder function for the given pattern.
func (s *Server) RegisterFunc(pattern string, responder func(*ResponseWriter, *Request)) {
s.Register(pattern, ResponderFunc(responder))
// HandleFunc registers a responder function for the given pattern.
func (s *Server) HandleFunc(pattern string, responder func(*ResponseWriter, *Request)) {
s.Handle(pattern, ResponderFunc(responder))
}
// ListenAndServe listens for requests at the server's configured address.
@@ -157,8 +159,7 @@ func (s *Server) getCertificateFor(hostname string) (*tls.Certificate, error) {
if s.CreateCertificate != nil {
cert, err := s.CreateCertificate(hostname)
if err == nil {
s.Certificates.Add(hostname, cert)
if err := s.Certificates.Write(hostname, cert); err != nil {
if err := s.Certificates.Add(hostname, cert); err != nil {
s.logf("gemini: Failed to write new certificate for %s: %s", hostname, err)
}
}

View File

@@ -3,156 +3,342 @@ package tofu
import (
"bufio"
"bytes"
"crypto/sha512"
"crypto/x509"
"errors"
"fmt"
"io"
"os"
"sort"
"strconv"
"strings"
"sync"
"time"
)
// KnownHosts maps hosts to fingerprints.
type KnownHosts map[string]Fingerprint
// KnownHostsFile represents a list of known hosts optionally loaded from a file.
// The zero value for KnownHostsFile represents an empty list ready to use.
// KnownHosts represents a list of known hosts.
// The zero value for KnownHosts represents an empty list ready to use.
//
// KnownHostsFile is safe for concurrent use by multiple goroutines.
type KnownHostsFile struct {
KnownHosts
out io.Writer
mu sync.RWMutex
// KnownHosts is safe for concurrent use by multiple goroutines.
type KnownHosts struct {
hosts map[string]Host
mu sync.RWMutex
}
// SetOutput sets the output to which new known hosts will be written to.
func (k *KnownHostsFile) SetOutput(w io.Writer) {
// Add adds a host to the list of known hosts.
func (k *KnownHosts) Add(h Host) error {
k.mu.Lock()
defer k.mu.Unlock()
k.out = w
}
// Add adds a known host to the list of known hosts.
func (k *KnownHostsFile) Add(hostname string, fingerprint Fingerprint) {
k.mu.Lock()
defer k.mu.Unlock()
if k.KnownHosts == nil {
k.KnownHosts = KnownHosts{}
if k.hosts == nil {
k.hosts = map[string]Host{}
}
k.KnownHosts[hostname] = fingerprint
k.hosts[h.Hostname] = h
return nil
}
// Lookup returns the fingerprint of the certificate corresponding to
// the given hostname.
func (k *KnownHostsFile) Lookup(hostname string) (Fingerprint, bool) {
// Lookup returns the known host entry corresponding to the given hostname.
func (k *KnownHosts) Lookup(hostname string) (Host, bool) {
k.mu.RLock()
defer k.mu.RUnlock()
c, ok := k.KnownHosts[hostname]
c, ok := k.hosts[hostname]
return c, ok
}
// Write writes a known hosts entry to the configured output.
func (k *KnownHostsFile) Write(hostname string, fingerprint Fingerprint) error {
// Entries returns the known host entries sorted by hostname.
func (k *KnownHosts) Entries() []Host {
keys := make([]string, 0, len(k.hosts))
for key := range k.hosts {
keys = append(keys, key)
}
sort.Strings(keys)
hosts := make([]Host, 0, len(k.hosts))
for _, key := range keys {
hosts = append(hosts, k.hosts[key])
}
return hosts
}
// WriteTo writes the list of known hosts to the provided io.Writer.
func (k *KnownHosts) WriteTo(w io.Writer) (int64, error) {
k.mu.RLock()
defer k.mu.RUnlock()
if k.out != nil {
_, err := k.writeKnownHost(k.out, hostname, fingerprint)
var written int
bw := bufio.NewWriter(w)
for _, h := range k.hosts {
n, err := bw.WriteString(h.String())
written += n
if err != nil {
return fmt.Errorf("failed to write to known host file: %w", err)
return int64(written), err
}
bw.WriteByte('\n')
written += 1
}
return nil
return int64(written), bw.Flush()
}
// WriteAll writes all of the known hosts to the provided io.Writer.
func (k *KnownHostsFile) WriteAll(w io.Writer) error {
k.mu.RLock()
defer k.mu.RUnlock()
for h, c := range k.KnownHosts {
if _, err := k.writeKnownHost(w, h, c); err != nil {
return err
}
}
return nil
}
// writeKnownHost writes a known host to the provided io.Writer.
func (k *KnownHostsFile) writeKnownHost(w io.Writer, hostname string, f Fingerprint) (int, error) {
return fmt.Fprintf(w, "%s %s %s %d\n", hostname, f.Algorithm, f.Hex, f.Expires.Unix())
}
// Load loads the known hosts from the provided path.
// It creates the file if it does not exist.
// New known hosts will be appended to the file.
func (k *KnownHostsFile) Load(path string) error {
f, err := os.OpenFile(path, os.O_CREATE|os.O_RDWR, 0644)
// Load loads the known hosts entries from the provided path.
func (k *KnownHosts) Load(path string) error {
f, err := os.Open(path)
if err != nil {
return err
}
k.Parse(f)
k.SetOutput(f)
return nil
defer f.Close()
return k.Parse(f)
}
// Parse parses the provided reader and adds the parsed known hosts to the list.
// Parse parses the provided io.Reader and adds the parsed hosts to the list.
// Invalid entries are ignored.
func (k *KnownHostsFile) Parse(r io.Reader) {
//
// For more control over errors encountered during parsing, use bufio.Scanner
// in combination with ParseHost. For example:
//
// var knownHosts tofu.KnownHosts
// scanner := bufio.NewScanner(r)
// for scanner.Scan() {
// host, err := tofu.ParseHost(scanner.Bytes())
// if err != nil {
// // handle error
// } else {
// knownHosts.Add(host)
// }
// }
// err := scanner.Err()
// if err != nil {
// // handle error
// }
//
func (k *KnownHosts) Parse(r io.Reader) error {
k.mu.Lock()
defer k.mu.Unlock()
if k.KnownHosts == nil {
k.KnownHosts = map[string]Fingerprint{}
if k.hosts == nil {
k.hosts = map[string]Host{}
}
scanner := bufio.NewScanner(r)
for scanner.Scan() {
text := scanner.Text()
parts := strings.Split(text, " ")
if len(parts) < 4 {
text := scanner.Bytes()
if len(text) == 0 {
continue
}
hostname := parts[0]
algorithm := parts[1]
if algorithm != "SHA-512" {
continue
}
hex := parts[2]
unix, err := strconv.ParseInt(parts[3], 10, 0)
h, err := ParseHost(text)
if err != nil {
continue
}
expires := time.Unix(unix, 0)
k.KnownHosts[hostname] = Fingerprint{
Algorithm: algorithm,
Hex: hex,
Expires: expires,
}
k.hosts[h.Hostname] = h
}
return scanner.Err()
}
// TOFU implements basic trust on first use.
//
// If the host is not on file, it is added to the list.
// If the host on file is expired, it is replaced with the provided host.
// If the fingerprint does not match the one on file, an error is returned.
func (k *KnownHosts) TOFU(hostname string, cert *x509.Certificate) error {
host := NewHost(hostname, cert.Raw, cert.NotAfter)
knownHost, ok := k.Lookup(hostname)
if !ok || time.Now().After(knownHost.Expires) {
k.Add(host)
return nil
}
// Check fingerprint
if !bytes.Equal(knownHost.Fingerprint, host.Fingerprint) {
return fmt.Errorf("fingerprint for %q does not match", hostname)
}
return nil
}
// HostWriter writes host entries to an io.WriteCloser.
//
// HostWriter is safe for concurrent use by multiple goroutines.
type HostWriter struct {
bw *bufio.Writer
cl io.Closer
mu sync.Mutex
}
// NewHostWriter returns a new host writer that writes to
// the provided io.WriteCloser.
func NewHostWriter(w io.WriteCloser) *HostWriter {
return &HostWriter{
bw: bufio.NewWriter(w),
cl: w,
}
}
// Fingerprint represents a fingerprint using a certain algorithm.
type Fingerprint struct {
Algorithm string // fingerprint algorithm e.g. SHA-512
Hex string // fingerprint in hexadecimal, with ':' between each octet
Expires time.Time // unix time of the fingerprint expiration date
// NewHostsFile returns a new host writer that appends to the file at the given path.
// The file is created if it does not exist.
func NewHostsFile(path string) (*HostWriter, error) {
f, err := os.OpenFile(path, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
if err != nil {
return nil, err
}
return NewHostWriter(f), nil
}
// NewFingerprint returns the SHA-512 fingerprint of the provided raw data.
func NewFingerprint(raw []byte, expires time.Time) Fingerprint {
sum512 := sha512.Sum512(raw)
// WriteHost writes the host to the underlying io.Writer.
func (h *HostWriter) WriteHost(host Host) error {
h.mu.Lock()
defer h.mu.Unlock()
h.bw.WriteString(host.String())
h.bw.WriteByte('\n')
if err := h.bw.Flush(); err != nil {
return fmt.Errorf("failed to write host: %w", err)
}
return nil
}
// Close closes the underlying io.Closer.
func (h *HostWriter) Close() error {
h.mu.Lock()
defer h.mu.Unlock()
return h.cl.Close()
}
// Host represents a host entry with a fingerprint using a certain algorithm.
type Host struct {
Hostname string // hostname
Algorithm string // fingerprint algorithm e.g. SHA-512
Fingerprint Fingerprint // fingerprint
Expires time.Time // unix time of the fingerprint expiration date
}
// NewHost returns a new host with a SHA-512 fingerprint of
// the provided raw data.
func NewHost(hostname string, raw []byte, expires time.Time) Host {
sum := sha512.Sum512(raw)
return Host{
Hostname: hostname,
Algorithm: "SHA-512",
Fingerprint: sum[:],
Expires: expires,
}
}
// ParseHost parses a host from the provided text.
func ParseHost(text []byte) (Host, error) {
var h Host
err := h.UnmarshalText(text)
return h, err
}
// String returns a string representation of the host.
func (h Host) String() string {
var b strings.Builder
for i, f := range sum512 {
if i > 0 {
b.WriteByte(':')
}
fmt.Fprintf(&b, "%02X", f)
}
return Fingerprint{
Algorithm: "SHA-512",
Hex: b.String(),
Expires: expires,
}
b.WriteString(h.Hostname)
b.WriteByte(' ')
b.WriteString(h.Algorithm)
b.WriteByte(' ')
b.WriteString(h.Fingerprint.String())
b.WriteByte(' ')
b.WriteString(strconv.FormatInt(h.Expires.Unix(), 10))
return b.String()
}
// UnmarshalText unmarshals the host from the provided text.
func (h *Host) UnmarshalText(text []byte) error {
const format = "hostname algorithm hex-fingerprint expiry-unix-ts"
parts := bytes.Split(text, []byte(" "))
if len(parts) != 4 {
return fmt.Errorf(
"expected the format %q", format)
}
if len(parts[0]) == 0 {
return errors.New("empty hostname")
}
h.Hostname = string(parts[0])
algorithm := string(parts[1])
if algorithm != "SHA-512" {
return fmt.Errorf(
"unsupported algorithm %q", algorithm)
}
h.Algorithm = algorithm
fingerprint := make([]byte, 0, sha512.Size)
scanner := bufio.NewScanner(bytes.NewReader(parts[2]))
scanner.Split(scanFingerprint)
for scanner.Scan() {
b, err := strconv.ParseUint(scanner.Text(), 16, 8)
if err != nil {
return fmt.Errorf("failed to parse fingerprint hash: %w", err)
}
fingerprint = append(fingerprint, byte(b))
}
if len(fingerprint) != sha512.Size {
return fmt.Errorf("invalid fingerprint size %d, expected %d",
len(fingerprint), sha512.Size)
}
h.Fingerprint = fingerprint
unix, err := strconv.ParseInt(string(parts[3]), 10, 0)
if err != nil {
return fmt.Errorf(
"invalid unix timestamp: %w", err)
}
h.Expires = time.Unix(unix, 0)
return nil
}
func scanFingerprint(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := bytes.IndexByte(data, ':'); i >= 0 {
// We have a full newline-terminated line.
return i + 1, data[0:i], nil
}
// If we're at EOF, we have a final, non-terminated hex byte
if atEOF {
return len(data), data, nil
}
// Request more data.
return 0, nil, nil
}
// Fingerprint represents a fingerprint.
type Fingerprint []byte
// String returns a string representation of the fingerprint.
func (f Fingerprint) String() string {
var sb strings.Builder
for i, b := range f {
if i > 0 {
sb.WriteByte(':')
}
fmt.Fprintf(&sb, "%02X", b)
}
return sb.String()
}