From fc72224ce9a604cc14d11c7fedfdedb465facbf5 Mon Sep 17 00:00:00 2001 From: Adnan Maolood Date: Tue, 27 Oct 2020 22:12:10 -0400 Subject: [PATCH] client: Follow redirects --- client.go | 36 ++++++++++++++++++++++++++++++++++ examples/client.go | 49 ++++++++++++++++++---------------------------- gemini.go | 1 + 3 files changed, 56 insertions(+), 30 deletions(-) diff --git a/client.go b/client.go index e48fe22..d8d88ae 100644 --- a/client.go +++ b/client.go @@ -5,6 +5,7 @@ import ( "crypto/tls" "crypto/x509" "net" + "net/url" ) // Client represents a Gemini client. @@ -17,6 +18,12 @@ type Client struct { // a certificate. CertificateStore CertificateStore + // CheckRedirect, if not nil, will be called to determine whether + // to follow a redirect. + // If CheckRedirect is nil, a default policy of no more than 5 consecutive + // redirects will be enforced. + CheckRedirect func(req *Request, via []*Request) error + // GetCertificate, if not nil, will be called when a server requests a certificate. // The returned certificate will be used when sending the request again. // If the certificate is nil, the request will not be sent again and @@ -40,6 +47,10 @@ func (c *Client) Get(url string) (*Response, error) { // Do performs a Gemini request and returns a Gemini response. func (c *Client) Do(req *Request) (*Response, error) { + return c.do(req, nil) +} + +func (c *Client) do(req *Request, via []*Request) (*Response, error) { // Connect to the host config := &tls.Config{ InsecureSkipVerify: true, @@ -105,6 +116,31 @@ func (c *Client) Do(req *Request) (*Response, error) { return c.Do(req) } } + } else if resp.Status.Class() == StatusClassRedirect { + if via == nil { + via = []*Request{} + } + via = append(via, req) + + target, err := url.Parse(resp.Meta) + if err != nil { + return resp, err + } + target = req.URL.ResolveReference(target) + redirect, err := NewRequestFromURL(target) + if err != nil { + return resp, err + } + + if c.CheckRedirect != nil { + if err := c.CheckRedirect(redirect, via); err != nil { + return resp, err + } + } else if len(via) > 5 { + // Default policy of no more than 5 redirects + return resp, ErrTooManyRedirects + } + return c.do(redirect, via) } return resp, nil } diff --git a/examples/client.go b/examples/client.go index 16608ca..5180f5c 100644 --- a/examples/client.go +++ b/examples/client.go @@ -12,26 +12,25 @@ import ( "os" "time" - gmi "git.sr.ht/~adnano/go-gemini" + "git.sr.ht/~adnano/go-gemini" ) var ( scanner = bufio.NewScanner(os.Stdin) - client = &gmi.Client{} + client = &gemini.Client{} ) func init() { - // Initialize the client - client.KnownHosts.LoadDefault() // Load known hosts - client.TrustCertificate = func(hostname string, cert *x509.Certificate, knownHosts *gmi.KnownHosts) error { + client.KnownHosts.LoadDefault() + client.TrustCertificate = func(hostname string, cert *x509.Certificate, knownHosts *gemini.KnownHosts) error { err := knownHosts.Lookup(hostname, cert) if err != nil { switch err { - case gmi.ErrCertificateNotTrusted: + case gemini.ErrCertificateNotTrusted: // Alert the user that the certificate is not trusted fmt.Printf("Warning: Certificate for %s is not trusted!\n", hostname) fmt.Println("This could indicate a Man-in-the-Middle attack.") - case gmi.ErrCertificateUnknown: + case gemini.ErrCertificateUnknown: // Prompt the user to trust the certificate trust := trustCertificate(cert) switch trust { @@ -48,7 +47,7 @@ func init() { } return err } - client.GetCertificate = func(hostname string, store *gmi.CertificateStore) *tls.Certificate { + client.GetCertificate = func(hostname string, store *gemini.CertificateStore) *tls.Certificate { // If the certificate is in the store, return it if cert, err := store.Lookup(hostname); err == nil { return cert @@ -56,7 +55,7 @@ func init() { // Otherwise, generate a certificate fmt.Println("Generating client certificate for", hostname) duration := time.Hour - cert, err := gmi.NewCertificate(hostname, duration) + cert, err := gemini.NewCertificate(hostname, duration) if err != nil { return nil } @@ -67,20 +66,19 @@ func init() { } // sendRequest sends a request to the given URL. -func sendRequest(req *gmi.Request) error { +func sendRequest(req *gemini.Request) error { resp, err := client.Do(req) if err != nil { return err } - // TODO: More fine-grained analysis of the status code. switch resp.Status.Class() { - case gmi.StatusClassInput: + case gemini.StatusClassInput: fmt.Printf("%s: ", resp.Meta) scanner.Scan() req.URL.RawQuery = url.QueryEscape(scanner.Text()) return sendRequest(req) - case gmi.StatusClassSuccess: + case gemini.StatusClassSuccess: defer resp.Body.Close() body, err := ioutil.ReadAll(resp.Body) if err != nil { @@ -88,23 +86,14 @@ func sendRequest(req *gmi.Request) error { } fmt.Print(string(body)) return nil - case gmi.StatusClassRedirect: - fmt.Println("Redirecting to", resp.Meta) - target, err := url.Parse(resp.Meta) - if err != nil { - return err - } - // TODO: Prompt the user if the redirect is to another domain. - redirect, err := gmi.NewRequestFromURL(req.URL.ResolveReference(target)) - if err != nil { - return err - } - return sendRequest(redirect) - case gmi.StatusClassTemporaryFailure: + case gemini.StatusClassRedirect: + // This should not happen unless CheckRedirect returns false. + return fmt.Errorf("Failed to redirect to %s", resp.Meta) + case gemini.StatusClassTemporaryFailure: return fmt.Errorf("Temporary failure: %s", resp.Meta) - case gmi.StatusClassPermanentFailure: + case gemini.StatusClassPermanentFailure: return fmt.Errorf("Permanent failure: %s", resp.Meta) - case gmi.StatusClassCertificateRequired: + case gemini.StatusClassCertificateRequired: // Note that this should not happen unless the server responds with // CertificateRequired even after we send a certificate. // CertificateNotAuthorized and CertificateNotValid are handled here. @@ -131,7 +120,7 @@ Otherwise, this should be safe to trust. => ` func trustCertificate(cert *x509.Certificate) trust { - fmt.Printf(trustPrompt, gmi.Fingerprint(cert)) + fmt.Printf(trustPrompt, gemini.Fingerprint(cert)) scanner.Scan() switch scanner.Text() { case "t": @@ -150,7 +139,7 @@ func main() { } url := os.Args[1] - req, err := gmi.NewRequest(url) + req, err := gemini.NewRequest(url) if err != nil { fmt.Println(err) os.Exit(1) diff --git a/gemini.go b/gemini.go index 173b137..f3e0c81 100644 --- a/gemini.go +++ b/gemini.go @@ -20,6 +20,7 @@ var ( ErrNotAFile = errors.New("gemini: not a file") ErrNotAGeminiURL = errors.New("gemini: not a Gemini URL") ErrBodyNotAllowed = errors.New("gemini: response status code does not allow for body") + ErrTooManyRedirects = errors.New("gemini: too many redirects") ) // DefaultClient is the default client. It is used by Get and Do.