client: Remove GetInput and CheckRedirect callbacks

This commit is contained in:
Adnan Maolood 2020-12-18 00:12:32 -05:00
parent a09cb5a23c
commit e2c907a7f6
2 changed files with 87 additions and 91 deletions

View File

@ -6,7 +6,6 @@ import (
"crypto/tls"
"crypto/x509"
"net"
"net/url"
"strings"
"time"
)
@ -20,15 +19,6 @@ type Client struct {
// and the request will be aborted.
TrustCertificate func(hostname string, cert *x509.Certificate) error
// GetInput is called to retrieve input when the server requests it.
// If GetInput is nil or returns false, no input will be sent and
// the response will be returned.
GetInput func(prompt string, sensitive bool) (input string, ok bool)
// CheckRedirect determines whether to follow a redirect.
// If CheckRedirect is nil, redirects will not be followed.
CheckRedirect func(req *Request, via []*Request) error
// Timeout specifies a time limit for requests made by this
// Client. The timeout includes connection time and reading
// the response body. The timer remains running after
@ -49,10 +39,6 @@ func (c *Client) Get(url string) (*Response, error) {
// Do performs a Gemini request and returns a Gemini response.
func (c *Client) Do(req *Request) (*Response, error) {
return c.do(req, nil)
}
func (c *Client) do(req *Request, via []*Request) (*Response, error) {
// Extract hostname
colonPos := strings.LastIndex(req.Host, ":")
if colonPos == -1 {
@ -75,7 +61,7 @@ func (c *Client) do(req *Request, via []*Request) (*Response, error) {
},
ServerName: hostname,
}
// Set connection context
ctx := req.Context
if ctx == nil {
ctx = context.Background()
@ -106,43 +92,6 @@ func (c *Client) do(req *Request, via []*Request) (*Response, error) {
// Store connection state
resp.TLS = conn.ConnectionState()
switch resp.Status.Class() {
case StatusClassInput:
if c.GetInput == nil {
break
}
input, ok := c.GetInput(resp.Meta, resp.Status == StatusSensitiveInput)
if ok {
req.URL.ForceQuery = true
req.URL.RawQuery = QueryEscape(input)
return c.do(req, via)
}
case StatusClassRedirect:
if c.CheckRedirect == nil {
break
}
if via == nil {
via = []*Request{}
}
via = append(via, req)
target, err := url.Parse(resp.Meta)
if err != nil {
return resp, err
}
target = req.URL.ResolveReference(target)
redirect := NewRequestFromURL(target)
redirect.Context = req.Context
if err := c.CheckRedirect(redirect, via); err != nil {
return resp, err
}
return c.do(redirect, via)
}
return resp, nil
}

View File

@ -9,6 +9,7 @@ import (
"fmt"
"io/ioutil"
"log"
"net/url"
"os"
"path/filepath"
"time"
@ -17,6 +18,22 @@ import (
"git.sr.ht/~adnano/go-xdg"
)
var (
hosts gemini.KnownHostsFile
scanner *bufio.Scanner
)
func init() {
// Load known hosts file
path := filepath.Join(xdg.DataHome(), "gemini", "known_hosts")
err := hosts.Load(path)
if err != nil {
log.Println(err)
}
scanner = bufio.NewScanner(os.Stdin)
}
const trustPrompt = `The certificate offered by %s is of unknown trust. Its fingerprint is:
%s
@ -26,49 +43,79 @@ Otherwise, this should be safe to trust.
[t]rust always; trust [o]nce; [a]bort
=> `
func trustCertificate(hostname string, cert *x509.Certificate) error {
knownHost, ok := hosts.Lookup(hostname)
if ok && time.Now().Before(knownHost.Expires) {
// Certificate is in known hosts file and is not expired
return nil
}
fingerprint := gemini.NewFingerprint(cert.Raw, cert.NotAfter)
fmt.Printf(trustPrompt, hostname, fingerprint.Hex)
scanner.Scan()
switch scanner.Text() {
case "t":
hosts.Add(hostname, fingerprint)
hosts.Write(hostname, fingerprint)
return nil
case "o":
hosts.Add(hostname, fingerprint)
return nil
default:
return errors.New("certificate not trusted")
}
}
func getInput(prompt string, sensitive bool) (input string, ok bool) {
fmt.Printf("%s ", prompt)
scanner.Scan()
return scanner.Text(), true
}
func do(req *gemini.Request, via []*gemini.Request) (*gemini.Response, error) {
client := gemini.Client{
TrustCertificate: trustCertificate,
}
resp, err := client.Do(req)
if err != nil {
return resp, err
}
switch resp.Status.Class() {
case gemini.StatusClassInput:
input, ok := getInput(resp.Meta, resp.Status == gemini.StatusSensitiveInput)
if !ok {
break
}
req.URL.ForceQuery = true
req.URL.RawQuery = gemini.QueryEscape(input)
return do(req, via)
case gemini.StatusClassRedirect:
via = append(via, req)
if len(via) > 5 {
return resp, errors.New("too many redirects")
}
target, err := url.Parse(resp.Meta)
if err != nil {
return resp, err
}
target = req.URL.ResolveReference(target)
redirect := *req
redirect.URL = target
return do(&redirect, via)
}
return resp, err
}
func main() {
if len(os.Args) < 2 {
fmt.Printf("usage: %s <url> [host]", os.Args[0])
fmt.Printf("usage: %s <url> [host]\n", os.Args[0])
os.Exit(1)
}
// Load known hosts file
var knownHosts gemini.KnownHostsFile
if err := knownHosts.Load(filepath.Join(xdg.DataHome(), "gemini", "known_hosts")); err != nil {
log.Println(err)
}
scanner := bufio.NewScanner(os.Stdin)
var client gemini.Client
client.TrustCertificate = func(hostname string, cert *x509.Certificate) error {
knownHost, ok := knownHosts.Lookup(hostname)
if ok && time.Now().Before(knownHost.Expires) {
// Certificate is in known hosts file and is not expired
return nil
}
fingerprint := gemini.NewFingerprint(cert.Raw, cert.NotAfter)
fmt.Printf(trustPrompt, hostname, fingerprint.Hex)
scanner.Scan()
switch scanner.Text() {
case "t":
knownHosts.Add(hostname, fingerprint)
knownHosts.Write(hostname, fingerprint)
return nil
case "o":
knownHosts.Add(hostname, fingerprint)
return nil
default:
return errors.New("certificate not trusted")
}
}
client.GetInput = func(prompt string, sensitive bool) (string, bool) {
fmt.Printf("%s ", prompt)
scanner.Scan()
return scanner.Text(), true
}
// Do the request
url := os.Args[1]
req, err := gemini.NewRequest(url)
@ -79,7 +126,7 @@ func main() {
if len(os.Args) == 3 {
req.Host = os.Args[2]
}
resp, err := client.Do(req)
resp, err := do(req, nil)
if err != nil {
fmt.Println(err)
os.Exit(1)