Refactor client certificates
This commit is contained in:
parent
768664e0c5
commit
fbd97a62de
75
cert.go
75
cert.go
@ -20,9 +20,9 @@ type CertificateStore struct {
|
||||
store map[string]tls.Certificate
|
||||
}
|
||||
|
||||
// Add adds a certificate for the given hostname to the store.
|
||||
// Add adds a certificate for the given scope to the store.
|
||||
// It tries to parse the certificate if it is not already parsed.
|
||||
func (c *CertificateStore) Add(hostname string, cert tls.Certificate) {
|
||||
func (c *CertificateStore) Add(scope string, cert tls.Certificate) {
|
||||
if c.store == nil {
|
||||
c.store = map[string]tls.Certificate{}
|
||||
}
|
||||
@ -33,7 +33,7 @@ func (c *CertificateStore) Add(hostname string, cert tls.Certificate) {
|
||||
cert.Leaf = parsed
|
||||
}
|
||||
}
|
||||
c.store[hostname] = cert
|
||||
c.store[scope] = cert
|
||||
}
|
||||
|
||||
// Lookup returns the certificate for the given hostname.
|
||||
@ -49,6 +49,22 @@ func (c *CertificateStore) Lookup(hostname string) (*tls.Certificate, error) {
|
||||
return &cert, nil
|
||||
}
|
||||
|
||||
// lookup returns the certificate for the given hostname + path.
|
||||
func (c *CertificateStore) lookup(scope string) (*tls.Certificate, error) {
|
||||
for {
|
||||
cert, err := c.Lookup(scope)
|
||||
switch err {
|
||||
case ErrCertificateExpired, nil:
|
||||
return cert, err
|
||||
}
|
||||
scope = path.Dir(scope)
|
||||
if scope == "." {
|
||||
break
|
||||
}
|
||||
}
|
||||
return nil, ErrCertificateUnknown
|
||||
}
|
||||
|
||||
// Load loads certificates from the given path.
|
||||
// The path should lead to a directory containing certificates and private keys
|
||||
// in the form hostname.crt and hostname.key.
|
||||
@ -71,36 +87,16 @@ func (c *CertificateStore) Load(path string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type ClientCertificateStore struct {
|
||||
CertificateStore
|
||||
// CertificateOptions configures how a certificate is created.
|
||||
type CertificateOptions struct {
|
||||
IPAddresses []net.IP
|
||||
DNSNames []string
|
||||
Duration time.Duration
|
||||
}
|
||||
|
||||
func (c *ClientCertificateStore) Lookup(hostname, urlPath string) (*tls.Certificate, error) {
|
||||
urlPath = path.Clean(urlPath)
|
||||
if urlPath == "." {
|
||||
urlPath = "/"
|
||||
}
|
||||
if urlPath[0] != '/' {
|
||||
urlPath = "/" + urlPath
|
||||
}
|
||||
for {
|
||||
cert, err := c.CertificateStore.Lookup(hostname + urlPath)
|
||||
switch err {
|
||||
case ErrCertificateExpired, nil:
|
||||
return cert, err
|
||||
}
|
||||
slash := urlPath == "/"
|
||||
urlPath = path.Dir(urlPath)
|
||||
if slash && urlPath == "/" {
|
||||
break
|
||||
}
|
||||
}
|
||||
return nil, ErrCertificateUnknown
|
||||
}
|
||||
|
||||
// NewCertificate creates and returns a new parsed certificate.
|
||||
func NewCertificate(host string, duration time.Duration) (tls.Certificate, error) {
|
||||
crt, priv, err := newX509KeyPair(host, duration)
|
||||
// CreateCertificate creates a new TLS certificate.
|
||||
func CreateCertificate(options CertificateOptions) (tls.Certificate, error) {
|
||||
crt, priv, err := newX509KeyPair(options)
|
||||
if err != nil {
|
||||
return tls.Certificate{}, err
|
||||
}
|
||||
@ -112,7 +108,7 @@ func NewCertificate(host string, duration time.Duration) (tls.Certificate, error
|
||||
}
|
||||
|
||||
// newX509KeyPair creates and returns a new certificate and private key.
|
||||
func newX509KeyPair(host string, duration time.Duration) (*x509.Certificate, crypto.PrivateKey, error) {
|
||||
func newX509KeyPair(options CertificateOptions) (*x509.Certificate, crypto.PrivateKey, error) {
|
||||
// Generate an ED25519 private key
|
||||
_, priv, err := ed25519.GenerateKey(rand.Reader)
|
||||
if err != nil {
|
||||
@ -131,7 +127,7 @@ func newX509KeyPair(host string, duration time.Duration) (*x509.Certificate, cry
|
||||
}
|
||||
|
||||
notBefore := time.Now()
|
||||
notAfter := notBefore.Add(duration)
|
||||
notAfter := notBefore.Add(options.Duration)
|
||||
|
||||
template := x509.Certificate{
|
||||
SerialNumber: serialNumber,
|
||||
@ -140,17 +136,8 @@ func newX509KeyPair(host string, duration time.Duration) (*x509.Certificate, cry
|
||||
KeyUsage: keyUsage,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||||
BasicConstraintsValid: true,
|
||||
}
|
||||
|
||||
if host != "" {
|
||||
hosts := strings.Split(host, ",")
|
||||
for _, h := range hosts {
|
||||
if ip := net.ParseIP(h); ip != nil {
|
||||
template.IPAddresses = append(template.IPAddresses, ip)
|
||||
} else {
|
||||
template.DNSNames = append(template.DNSNames, h)
|
||||
}
|
||||
}
|
||||
IPAddresses: options.IPAddresses,
|
||||
DNSNames: options.DNSNames,
|
||||
}
|
||||
|
||||
crt, err := x509.CreateCertificate(rand.Reader, &template, &template, public, priv)
|
||||
|
151
client.go
151
client.go
@ -6,37 +6,38 @@ import (
|
||||
"crypto/x509"
|
||||
"net"
|
||||
"net/url"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Client represents a Gemini client.
|
||||
// Client is a Gemini client.
|
||||
type Client struct {
|
||||
// KnownHosts is a list of known hosts that the client trusts.
|
||||
// KnownHosts is a list of known hosts.
|
||||
KnownHosts KnownHosts
|
||||
|
||||
// CertificateStore maps hostnames to certificates.
|
||||
// It is used to determine which certificate to use when the server requests
|
||||
// a certificate.
|
||||
CertificateStore ClientCertificateStore
|
||||
// Certificates stores client-side certificates.
|
||||
Certificates CertificateStore
|
||||
|
||||
// CheckRedirect, if not nil, will be called to determine whether
|
||||
// to follow a redirect.
|
||||
// GetInput is called to retrieve input when the server requests it.
|
||||
// If GetInput is nil or returns false, no input will be sent and
|
||||
// the response will be returned.
|
||||
GetInput func(prompt string, sensitive bool) (input string, ok bool)
|
||||
|
||||
// CheckRedirect determines whether to follow a redirect.
|
||||
// If CheckRedirect is nil, a default policy of no more than 5 consecutive
|
||||
// redirects will be enforced.
|
||||
CheckRedirect func(req *Request, via []*Request) error
|
||||
|
||||
// GetInput, if not nil, will be called to retrieve input when the server
|
||||
// requests it.
|
||||
GetInput func(prompt string, sensitive bool) (string, bool)
|
||||
// CreateCertificate is called to generate a certificate upon
|
||||
// the request of a server.
|
||||
// If CreateCertificate is nil or the returned error is not nil,
|
||||
// the request will not be sent again and the response will be returned.
|
||||
CreateCertificate func(hostname, path string) (tls.Certificate, error)
|
||||
|
||||
// GetCertificate, if not nil, will be called when a server requests a certificate.
|
||||
// The returned certificate will be used when sending the request again.
|
||||
// If the certificate is nil, the request will not be sent again and
|
||||
// the response will be returned.
|
||||
GetCertificate func(req *Request, store *ClientCertificateStore) *tls.Certificate
|
||||
|
||||
// TrustCertificate, if not nil, will be called to determine whether the
|
||||
// client should trust the given certificate.
|
||||
// If error is not nil, the connection will be aborted.
|
||||
// TrustCertificate determines whether the client should trust
|
||||
// the provided certificate.
|
||||
// If the returned error is not nil, the connection will be aborted.
|
||||
// If TrustCertificate is nil, the client will check KnownHosts
|
||||
// for the certificate.
|
||||
TrustCertificate func(hostname string, cert *x509.Certificate, knownHosts *KnownHosts) error
|
||||
}
|
||||
|
||||
@ -59,38 +60,18 @@ func (c *Client) do(req *Request, via []*Request) (*Response, error) {
|
||||
config := &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
MinVersion: tls.VersionTLS12,
|
||||
GetClientCertificate: func(info *tls.CertificateRequestInfo) (*tls.Certificate, error) {
|
||||
// Request certificates take precedence over client certificates
|
||||
if req.Certificate != nil {
|
||||
return req.Certificate, nil
|
||||
}
|
||||
// If we have already stored the certificate, return it
|
||||
if cert, err := c.CertificateStore.Lookup(hostname(req.Host), req.URL.Path); err == nil {
|
||||
return cert, nil
|
||||
}
|
||||
return &tls.Certificate{}, nil
|
||||
GetClientCertificate: func(_ *tls.CertificateRequestInfo) (*tls.Certificate, error) {
|
||||
return c.getClientCertificate(req)
|
||||
},
|
||||
VerifyConnection: func(cs tls.ConnectionState) error {
|
||||
cert := cs.PeerCertificates[0]
|
||||
// Verify the hostname
|
||||
if err := verifyHostname(cert, hostname(req.Host)); err != nil {
|
||||
return err
|
||||
}
|
||||
// Check that the client trusts the certificate
|
||||
if c.TrustCertificate == nil {
|
||||
if err := c.KnownHosts.Lookup(hostname(req.Host), cert); err != nil {
|
||||
return err
|
||||
}
|
||||
} else if err := c.TrustCertificate(hostname(req.Host), cert, &c.KnownHosts); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
return c.verifyConnection(req, cs)
|
||||
},
|
||||
}
|
||||
conn, err := tls.Dial("tcp", req.Host, config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// TODO: Set connection deadline
|
||||
|
||||
// Write the request
|
||||
w := bufio.NewWriter(conn)
|
||||
@ -104,24 +85,39 @@ func (c *Client) do(req *Request, via []*Request) (*Response, error) {
|
||||
if err := resp.read(conn); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Store connection information
|
||||
// Store connection state
|
||||
resp.TLS = conn.ConnectionState()
|
||||
|
||||
// Resend the request with a certificate if the server responded
|
||||
// with CertificateRequired
|
||||
if resp.Status == StatusCertificateRequired {
|
||||
switch {
|
||||
case resp.Status == StatusCertificateRequired:
|
||||
// Check to see if a certificate was already provided to prevent an infinite loop
|
||||
if req.Certificate != nil {
|
||||
return resp, nil
|
||||
}
|
||||
if c.GetCertificate != nil {
|
||||
if cert := c.GetCertificate(req, &c.CertificateStore); cert != nil {
|
||||
req.Certificate = cert
|
||||
return c.Do(req)
|
||||
|
||||
hostname, path := req.URL.Hostname(), strings.TrimSuffix(req.URL.Path, "/")
|
||||
if c.CreateCertificate != nil {
|
||||
cert, err := c.CreateCertificate(hostname, path)
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
c.Certificates.Add(hostname+path, cert)
|
||||
return c.do(req, via)
|
||||
}
|
||||
return resp, ErrCertificateRequired
|
||||
} else if resp.Status.Class() == StatusClassRedirect {
|
||||
|
||||
case resp.Status.Class() == StatusClassInput:
|
||||
if c.GetInput != nil {
|
||||
input, ok := c.GetInput(resp.Meta, resp.Status == StatusSensitiveInput)
|
||||
if ok {
|
||||
req.URL.ForceQuery = true
|
||||
req.URL.RawQuery = url.QueryEscape(input)
|
||||
return c.do(req, via)
|
||||
}
|
||||
}
|
||||
return resp, ErrInputRequired
|
||||
|
||||
case resp.Status.Class() == StatusClassRedirect:
|
||||
if via == nil {
|
||||
via = []*Request{}
|
||||
}
|
||||
@ -146,27 +142,44 @@ func (c *Client) do(req *Request, via []*Request) (*Response, error) {
|
||||
return resp, ErrTooManyRedirects
|
||||
}
|
||||
return c.do(redirect, via)
|
||||
} else if resp.Status.Class() == StatusClassInput {
|
||||
if c.GetInput != nil {
|
||||
input, ok := c.GetInput(resp.Meta, resp.Status == StatusSensitiveInput)
|
||||
if ok {
|
||||
req.URL.ForceQuery = true
|
||||
req.URL.RawQuery = url.QueryEscape(input)
|
||||
return c.do(req, via)
|
||||
}
|
||||
}
|
||||
return resp, ErrInputRequired
|
||||
}
|
||||
|
||||
resp.Request = req
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// hostname returns the host without the port.
|
||||
func hostname(host string) string {
|
||||
hostname, _, err := net.SplitHostPort(host)
|
||||
if err != nil {
|
||||
return host
|
||||
func (c *Client) getClientCertificate(req *Request) (*tls.Certificate, error) {
|
||||
// Request certificates have the highest precedence
|
||||
if req.Certificate != nil {
|
||||
return req.Certificate, nil
|
||||
}
|
||||
return hostname
|
||||
hostname, path := req.URL.Hostname(), strings.TrimSuffix(req.URL.Path, "/")
|
||||
if cert, err := c.Certificates.lookup(hostname + path); err == nil {
|
||||
// Remember the certificate used
|
||||
req.Certificate = cert
|
||||
return cert, nil
|
||||
}
|
||||
return &tls.Certificate{}, nil
|
||||
}
|
||||
|
||||
func (c *Client) verifyConnection(req *Request, cs tls.ConnectionState) error {
|
||||
// Verify the hostname
|
||||
var hostname string
|
||||
if host, _, err := net.SplitHostPort(req.Host); err == nil {
|
||||
hostname = host
|
||||
} else {
|
||||
hostname = req.Host
|
||||
}
|
||||
cert := cs.PeerCertificates[0]
|
||||
if err := verifyHostname(cert, hostname); err != nil {
|
||||
return err
|
||||
}
|
||||
// Check that the client trusts the certificate
|
||||
var err error
|
||||
if c.TrustCertificate != nil {
|
||||
return c.TrustCertificate(hostname, cert, &c.KnownHosts)
|
||||
} else {
|
||||
err = c.KnownHosts.Lookup(hostname, cert)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
@ -46,21 +46,11 @@ func init() {
|
||||
}
|
||||
return err
|
||||
}
|
||||
client.GetCertificate = func(req *gemini.Request, store *gemini.ClientCertificateStore) *tls.Certificate {
|
||||
// If the certificate is in the store, return it
|
||||
if cert, err := store.Lookup(req.URL.Hostname(), req.URL.Path); err == nil {
|
||||
return cert
|
||||
}
|
||||
// Otherwise, generate a certificate
|
||||
fmt.Println("Generating client certificate for", req.URL.Hostname()+req.URL.Path)
|
||||
duration := time.Hour
|
||||
cert, err := gemini.NewCertificate("", duration)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
// Store and return the certificate
|
||||
store.Add(req.URL.Hostname()+req.URL.Path, cert)
|
||||
return &cert
|
||||
client.CreateCertificate = func(hostname, path string) (tls.Certificate, error) {
|
||||
fmt.Println("Generating client certificate for", hostname, path)
|
||||
return gemini.CreateCertificate(gemini.CertificateOptions{
|
||||
Duration: time.Hour,
|
||||
})
|
||||
}
|
||||
client.GetInput = func(prompt string, sensitive bool) (string, bool) {
|
||||
fmt.Printf("%s: ", prompt)
|
||||
@ -69,8 +59,7 @@ func init() {
|
||||
}
|
||||
}
|
||||
|
||||
// sendRequest sends a request to the given URL.
|
||||
func sendRequest(req *gemini.Request) error {
|
||||
func doRequest(req *gemini.Request) error {
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
@ -131,7 +120,7 @@ func main() {
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if err := sendRequest(req); err != nil {
|
||||
if err := doRequest(req); err != nil {
|
||||
fmt.Println(err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
@ -29,7 +29,10 @@ func main() {
|
||||
fallthrough
|
||||
case gmi.ErrCertificateUnknown:
|
||||
// Generate a certificate if one does not exist.
|
||||
cert, err := gmi.NewCertificate(hostname, time.Minute)
|
||||
cert, err := gmi.CreateCertificate(gmi.CertificateOptions{
|
||||
DNSNames: []string{hostname},
|
||||
Duration: time.Hour,
|
||||
})
|
||||
if err != nil {
|
||||
// Failed to generate new certificate, abort
|
||||
return nil
|
||||
|
15
gemini.go
15
gemini.go
@ -51,16 +51,9 @@ func init() {
|
||||
defaultClientOnce.Do(func() { knownHosts.LoadDefault() })
|
||||
return knownHosts.Lookup(hostname, cert)
|
||||
}
|
||||
DefaultClient.GetCertificate = func(req *Request, store *ClientCertificateStore) *tls.Certificate {
|
||||
if cert, err := store.Lookup(req.URL.Hostname(), req.URL.Path); err == nil {
|
||||
return cert
|
||||
}
|
||||
duration := time.Hour
|
||||
cert, err := NewCertificate("", duration)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
store.Add(req.URL.Hostname()+req.URL.Path, cert)
|
||||
return &cert
|
||||
DefaultClient.CreateCertificate = func(hostname, path string) (tls.Certificate, error) {
|
||||
return CreateCertificate(CertificateOptions{
|
||||
Duration: time.Hour,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
24
status.go
24
status.go
@ -24,6 +24,18 @@ const (
|
||||
StatusCertificateNotValid Status = 62
|
||||
)
|
||||
|
||||
// Status code categories.
|
||||
type StatusClass int
|
||||
|
||||
const (
|
||||
StatusClassInput StatusClass = 1
|
||||
StatusClassSuccess StatusClass = 2
|
||||
StatusClassRedirect StatusClass = 3
|
||||
StatusClassTemporaryFailure StatusClass = 4
|
||||
StatusClassPermanentFailure StatusClass = 5
|
||||
StatusClassCertificateRequired StatusClass = 6
|
||||
)
|
||||
|
||||
// Class returns the status class for this status code.
|
||||
func (s Status) Class() StatusClass {
|
||||
return StatusClass(s / 10)
|
||||
@ -71,15 +83,3 @@ func (s Status) Message() string {
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// Status code categories.
|
||||
type StatusClass int
|
||||
|
||||
const (
|
||||
StatusClassInput StatusClass = 1
|
||||
StatusClassSuccess StatusClass = 2
|
||||
StatusClassRedirect StatusClass = 3
|
||||
StatusClassTemporaryFailure StatusClass = 4
|
||||
StatusClassPermanentFailure StatusClass = 5
|
||||
StatusClassCertificateRequired StatusClass = 6
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user