From e2c907a7f65a98a1737ffe50f55d6c7dfa9abb9e Mon Sep 17 00:00:00 2001 From: Adnan Maolood Date: Fri, 18 Dec 2020 00:12:32 -0500 Subject: [PATCH] client: Remove GetInput and CheckRedirect callbacks --- client.go | 53 +------------------ examples/client.go | 125 +++++++++++++++++++++++++++++++-------------- 2 files changed, 87 insertions(+), 91 deletions(-) diff --git a/client.go b/client.go index 7eea9d8..2dc2ccc 100644 --- a/client.go +++ b/client.go @@ -6,7 +6,6 @@ import ( "crypto/tls" "crypto/x509" "net" - "net/url" "strings" "time" ) @@ -20,15 +19,6 @@ type Client struct { // and the request will be aborted. TrustCertificate func(hostname string, cert *x509.Certificate) error - // GetInput is called to retrieve input when the server requests it. - // If GetInput is nil or returns false, no input will be sent and - // the response will be returned. - GetInput func(prompt string, sensitive bool) (input string, ok bool) - - // CheckRedirect determines whether to follow a redirect. - // If CheckRedirect is nil, redirects will not be followed. - CheckRedirect func(req *Request, via []*Request) error - // Timeout specifies a time limit for requests made by this // Client. The timeout includes connection time and reading // the response body. The timer remains running after @@ -49,10 +39,6 @@ func (c *Client) Get(url string) (*Response, error) { // Do performs a Gemini request and returns a Gemini response. func (c *Client) Do(req *Request) (*Response, error) { - return c.do(req, nil) -} - -func (c *Client) do(req *Request, via []*Request) (*Response, error) { // Extract hostname colonPos := strings.LastIndex(req.Host, ":") if colonPos == -1 { @@ -75,7 +61,7 @@ func (c *Client) do(req *Request, via []*Request) (*Response, error) { }, ServerName: hostname, } - + // Set connection context ctx := req.Context if ctx == nil { ctx = context.Background() @@ -106,43 +92,6 @@ func (c *Client) do(req *Request, via []*Request) (*Response, error) { // Store connection state resp.TLS = conn.ConnectionState() - switch resp.Status.Class() { - case StatusClassInput: - if c.GetInput == nil { - break - } - - input, ok := c.GetInput(resp.Meta, resp.Status == StatusSensitiveInput) - if ok { - req.URL.ForceQuery = true - req.URL.RawQuery = QueryEscape(input) - return c.do(req, via) - } - - case StatusClassRedirect: - if c.CheckRedirect == nil { - break - } - - if via == nil { - via = []*Request{} - } - via = append(via, req) - - target, err := url.Parse(resp.Meta) - if err != nil { - return resp, err - } - target = req.URL.ResolveReference(target) - - redirect := NewRequestFromURL(target) - redirect.Context = req.Context - if err := c.CheckRedirect(redirect, via); err != nil { - return resp, err - } - return c.do(redirect, via) - } - return resp, nil } diff --git a/examples/client.go b/examples/client.go index 1c98bf5..4975158 100644 --- a/examples/client.go +++ b/examples/client.go @@ -9,6 +9,7 @@ import ( "fmt" "io/ioutil" "log" + "net/url" "os" "path/filepath" "time" @@ -17,6 +18,22 @@ import ( "git.sr.ht/~adnano/go-xdg" ) +var ( + hosts gemini.KnownHostsFile + scanner *bufio.Scanner +) + +func init() { + // Load known hosts file + path := filepath.Join(xdg.DataHome(), "gemini", "known_hosts") + err := hosts.Load(path) + if err != nil { + log.Println(err) + } + + scanner = bufio.NewScanner(os.Stdin) +} + const trustPrompt = `The certificate offered by %s is of unknown trust. Its fingerprint is: %s @@ -26,49 +43,79 @@ Otherwise, this should be safe to trust. [t]rust always; trust [o]nce; [a]bort => ` +func trustCertificate(hostname string, cert *x509.Certificate) error { + knownHost, ok := hosts.Lookup(hostname) + if ok && time.Now().Before(knownHost.Expires) { + // Certificate is in known hosts file and is not expired + return nil + } + + fingerprint := gemini.NewFingerprint(cert.Raw, cert.NotAfter) + fmt.Printf(trustPrompt, hostname, fingerprint.Hex) + scanner.Scan() + switch scanner.Text() { + case "t": + hosts.Add(hostname, fingerprint) + hosts.Write(hostname, fingerprint) + return nil + case "o": + hosts.Add(hostname, fingerprint) + return nil + default: + return errors.New("certificate not trusted") + } +} + +func getInput(prompt string, sensitive bool) (input string, ok bool) { + fmt.Printf("%s ", prompt) + scanner.Scan() + return scanner.Text(), true +} + +func do(req *gemini.Request, via []*gemini.Request) (*gemini.Response, error) { + client := gemini.Client{ + TrustCertificate: trustCertificate, + } + resp, err := client.Do(req) + if err != nil { + return resp, err + } + + switch resp.Status.Class() { + case gemini.StatusClassInput: + input, ok := getInput(resp.Meta, resp.Status == gemini.StatusSensitiveInput) + if !ok { + break + } + req.URL.ForceQuery = true + req.URL.RawQuery = gemini.QueryEscape(input) + return do(req, via) + + case gemini.StatusClassRedirect: + via = append(via, req) + if len(via) > 5 { + return resp, errors.New("too many redirects") + } + + target, err := url.Parse(resp.Meta) + if err != nil { + return resp, err + } + target = req.URL.ResolveReference(target) + redirect := *req + redirect.URL = target + return do(&redirect, via) + } + + return resp, err +} + func main() { if len(os.Args) < 2 { - fmt.Printf("usage: %s [host]", os.Args[0]) + fmt.Printf("usage: %s [host]\n", os.Args[0]) os.Exit(1) } - // Load known hosts file - var knownHosts gemini.KnownHostsFile - if err := knownHosts.Load(filepath.Join(xdg.DataHome(), "gemini", "known_hosts")); err != nil { - log.Println(err) - } - - scanner := bufio.NewScanner(os.Stdin) - - var client gemini.Client - client.TrustCertificate = func(hostname string, cert *x509.Certificate) error { - knownHost, ok := knownHosts.Lookup(hostname) - if ok && time.Now().Before(knownHost.Expires) { - // Certificate is in known hosts file and is not expired - return nil - } - - fingerprint := gemini.NewFingerprint(cert.Raw, cert.NotAfter) - fmt.Printf(trustPrompt, hostname, fingerprint.Hex) - scanner.Scan() - switch scanner.Text() { - case "t": - knownHosts.Add(hostname, fingerprint) - knownHosts.Write(hostname, fingerprint) - return nil - case "o": - knownHosts.Add(hostname, fingerprint) - return nil - default: - return errors.New("certificate not trusted") - } - } - client.GetInput = func(prompt string, sensitive bool) (string, bool) { - fmt.Printf("%s ", prompt) - scanner.Scan() - return scanner.Text(), true - } - // Do the request url := os.Args[1] req, err := gemini.NewRequest(url) @@ -79,7 +126,7 @@ func main() { if len(os.Args) == 3 { req.Host = os.Args[2] } - resp, err := client.Do(req) + resp, err := do(req, nil) if err != nil { fmt.Println(err) os.Exit(1)