From 5d099a4fe1b846fcc987259b19cb5c8922ffd399 Mon Sep 17 00:00:00 2001 From: adnano Date: Sun, 27 Sep 2020 23:58:45 -0400 Subject: [PATCH] Only generate certificates after CertificateRequired --- client.go | 24 ++++++++++++++++++------ examples/client/client.go | 1 - 2 files changed, 18 insertions(+), 7 deletions(-) diff --git a/client.go b/client.go index 37b1b6d..94b4363 100644 --- a/client.go +++ b/client.go @@ -6,6 +6,7 @@ import ( "crypto/x509" "errors" "io/ioutil" + "log" "net" "net/url" "strconv" @@ -188,7 +189,7 @@ type Client struct { CertificateStore CertificateStore // GetCertificate, if not nil, will be called to determine which certificate - // (if any) should be used for a request. + // to use when the server responds with CertificateRequired. GetCertificate func(hostname string, store CertificateStore) *tls.Certificate // TrustCertificate, if not nil, will be called to determine whether the @@ -204,11 +205,6 @@ func (c *Client) Send(req *Request) (*Response, error) { InsecureSkipVerify: true, MinVersion: tls.VersionTLS12, GetClientCertificate: func(info *tls.CertificateRequestInfo) (*tls.Certificate, error) { - if c.GetCertificate != nil { - if cert := c.GetCertificate(req.Hostname(), c.CertificateStore); cert != nil { - return cert, nil - } - } if req.Certificate != nil { return req.Certificate, nil } @@ -261,6 +257,22 @@ func (c *Client) Send(req *Request) (*Response, error) { } // Store connection information resp.TLS = conn.ConnectionState() + + // Resend the request with a certificate if the server responded + // with CertificateRequired + if resp.Status == StatusCertificateRequired { + // Check to see if a certificate was already provided to prevent an infinite loop + if req.Certificate != nil { + return resp, nil + } + if c.GetCertificate != nil { + if cert := c.GetCertificate(req.Hostname(), c.CertificateStore); cert != nil { + req.Certificate = cert + return c.Send(req) + } + } + } + return resp, nil } diff --git a/examples/client/client.go b/examples/client/client.go index a64365f..e8a16be 100644 --- a/examples/client/client.go +++ b/examples/client/client.go @@ -89,7 +89,6 @@ func sendRequest(req *gmi.Request) error { } // Handle relative redirects red.URL = req.URL.ResolveReference(red.URL) - fmt.Println(red.URL, red.Host) return sendRequest(red) case gmi.StatusClassTemporaryFailure: return fmt.Errorf("Temporary failure: %s", resp.Meta)