client: Remove GetInput and CheckRedirect callbacks
This commit is contained in:
		
							parent
							
								
									a09cb5a23c
								
							
						
					
					
						commit
						e2c907a7f6
					
				
							
								
								
									
										53
									
								
								client.go
									
									
									
									
									
								
							
							
						
						
									
										53
									
								
								client.go
									
									
									
									
									
								
							| @ -6,7 +6,6 @@ import ( | |||||||
| 	"crypto/tls" | 	"crypto/tls" | ||||||
| 	"crypto/x509" | 	"crypto/x509" | ||||||
| 	"net" | 	"net" | ||||||
| 	"net/url" |  | ||||||
| 	"strings" | 	"strings" | ||||||
| 	"time" | 	"time" | ||||||
| ) | ) | ||||||
| @ -20,15 +19,6 @@ type Client struct { | |||||||
| 	// and the request will be aborted. | 	// and the request will be aborted. | ||||||
| 	TrustCertificate func(hostname string, cert *x509.Certificate) error | 	TrustCertificate func(hostname string, cert *x509.Certificate) error | ||||||
| 
 | 
 | ||||||
| 	// GetInput is called to retrieve input when the server requests it. |  | ||||||
| 	// If GetInput is nil or returns false, no input will be sent and |  | ||||||
| 	// the response will be returned. |  | ||||||
| 	GetInput func(prompt string, sensitive bool) (input string, ok bool) |  | ||||||
| 
 |  | ||||||
| 	// CheckRedirect determines whether to follow a redirect. |  | ||||||
| 	// If CheckRedirect is nil, redirects will not be followed. |  | ||||||
| 	CheckRedirect func(req *Request, via []*Request) error |  | ||||||
| 
 |  | ||||||
| 	// Timeout specifies a time limit for requests made by this | 	// Timeout specifies a time limit for requests made by this | ||||||
| 	// Client. The timeout includes connection time and reading | 	// Client. The timeout includes connection time and reading | ||||||
| 	// the response body. The timer remains running after | 	// the response body. The timer remains running after | ||||||
| @ -49,10 +39,6 @@ func (c *Client) Get(url string) (*Response, error) { | |||||||
| 
 | 
 | ||||||
| // Do performs a Gemini request and returns a Gemini response. | // Do performs a Gemini request and returns a Gemini response. | ||||||
| func (c *Client) Do(req *Request) (*Response, error) { | func (c *Client) Do(req *Request) (*Response, error) { | ||||||
| 	return c.do(req, nil) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (c *Client) do(req *Request, via []*Request) (*Response, error) { |  | ||||||
| 	// Extract hostname | 	// Extract hostname | ||||||
| 	colonPos := strings.LastIndex(req.Host, ":") | 	colonPos := strings.LastIndex(req.Host, ":") | ||||||
| 	if colonPos == -1 { | 	if colonPos == -1 { | ||||||
| @ -75,7 +61,7 @@ func (c *Client) do(req *Request, via []*Request) (*Response, error) { | |||||||
| 		}, | 		}, | ||||||
| 		ServerName: hostname, | 		ServerName: hostname, | ||||||
| 	} | 	} | ||||||
| 
 | 	// Set connection context | ||||||
| 	ctx := req.Context | 	ctx := req.Context | ||||||
| 	if ctx == nil { | 	if ctx == nil { | ||||||
| 		ctx = context.Background() | 		ctx = context.Background() | ||||||
| @ -106,43 +92,6 @@ func (c *Client) do(req *Request, via []*Request) (*Response, error) { | |||||||
| 	// Store connection state | 	// Store connection state | ||||||
| 	resp.TLS = conn.ConnectionState() | 	resp.TLS = conn.ConnectionState() | ||||||
| 
 | 
 | ||||||
| 	switch resp.Status.Class() { |  | ||||||
| 	case StatusClassInput: |  | ||||||
| 		if c.GetInput == nil { |  | ||||||
| 			break |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		input, ok := c.GetInput(resp.Meta, resp.Status == StatusSensitiveInput) |  | ||||||
| 		if ok { |  | ||||||
| 			req.URL.ForceQuery = true |  | ||||||
| 			req.URL.RawQuery = QueryEscape(input) |  | ||||||
| 			return c.do(req, via) |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 	case StatusClassRedirect: |  | ||||||
| 		if c.CheckRedirect == nil { |  | ||||||
| 			break |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		if via == nil { |  | ||||||
| 			via = []*Request{} |  | ||||||
| 		} |  | ||||||
| 		via = append(via, req) |  | ||||||
| 
 |  | ||||||
| 		target, err := url.Parse(resp.Meta) |  | ||||||
| 		if err != nil { |  | ||||||
| 			return resp, err |  | ||||||
| 		} |  | ||||||
| 		target = req.URL.ResolveReference(target) |  | ||||||
| 
 |  | ||||||
| 		redirect := NewRequestFromURL(target) |  | ||||||
| 		redirect.Context = req.Context |  | ||||||
| 		if err := c.CheckRedirect(redirect, via); err != nil { |  | ||||||
| 			return resp, err |  | ||||||
| 		} |  | ||||||
| 		return c.do(redirect, via) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	return resp, nil | 	return resp, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -9,6 +9,7 @@ import ( | |||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"io/ioutil" | 	"io/ioutil" | ||||||
| 	"log" | 	"log" | ||||||
|  | 	"net/url" | ||||||
| 	"os" | 	"os" | ||||||
| 	"path/filepath" | 	"path/filepath" | ||||||
| 	"time" | 	"time" | ||||||
| @ -17,6 +18,22 @@ import ( | |||||||
| 	"git.sr.ht/~adnano/go-xdg" | 	"git.sr.ht/~adnano/go-xdg" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | var ( | ||||||
|  | 	hosts   gemini.KnownHostsFile | ||||||
|  | 	scanner *bufio.Scanner | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | func init() { | ||||||
|  | 	// Load known hosts file | ||||||
|  | 	path := filepath.Join(xdg.DataHome(), "gemini", "known_hosts") | ||||||
|  | 	err := hosts.Load(path) | ||||||
|  | 	if err != nil { | ||||||
|  | 		log.Println(err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	scanner = bufio.NewScanner(os.Stdin) | ||||||
|  | } | ||||||
|  | 
 | ||||||
| const trustPrompt = `The certificate offered by %s is of unknown trust. Its fingerprint is: | const trustPrompt = `The certificate offered by %s is of unknown trust. Its fingerprint is: | ||||||
| %s | %s | ||||||
| 
 | 
 | ||||||
| @ -26,23 +43,8 @@ Otherwise, this should be safe to trust. | |||||||
| [t]rust always; trust [o]nce; [a]bort | [t]rust always; trust [o]nce; [a]bort | ||||||
| => ` | => ` | ||||||
| 
 | 
 | ||||||
| func main() { | func trustCertificate(hostname string, cert *x509.Certificate) error { | ||||||
| 	if len(os.Args) < 2 { | 	knownHost, ok := hosts.Lookup(hostname) | ||||||
| 		fmt.Printf("usage: %s <url> [host]", os.Args[0]) |  | ||||||
| 		os.Exit(1) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	// Load known hosts file |  | ||||||
| 	var knownHosts gemini.KnownHostsFile |  | ||||||
| 	if err := knownHosts.Load(filepath.Join(xdg.DataHome(), "gemini", "known_hosts")); err != nil { |  | ||||||
| 		log.Println(err) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	scanner := bufio.NewScanner(os.Stdin) |  | ||||||
| 
 |  | ||||||
| 	var client gemini.Client |  | ||||||
| 	client.TrustCertificate = func(hostname string, cert *x509.Certificate) error { |  | ||||||
| 		knownHost, ok := knownHosts.Lookup(hostname) |  | ||||||
| 	if ok && time.Now().Before(knownHost.Expires) { | 	if ok && time.Now().Before(knownHost.Expires) { | ||||||
| 		// Certificate is in known hosts file and is not expired | 		// Certificate is in known hosts file and is not expired | ||||||
| 		return nil | 		return nil | ||||||
| @ -53,22 +55,67 @@ func main() { | |||||||
| 	scanner.Scan() | 	scanner.Scan() | ||||||
| 	switch scanner.Text() { | 	switch scanner.Text() { | ||||||
| 	case "t": | 	case "t": | ||||||
| 			knownHosts.Add(hostname, fingerprint) | 		hosts.Add(hostname, fingerprint) | ||||||
| 			knownHosts.Write(hostname, fingerprint) | 		hosts.Write(hostname, fingerprint) | ||||||
| 		return nil | 		return nil | ||||||
| 	case "o": | 	case "o": | ||||||
| 			knownHosts.Add(hostname, fingerprint) | 		hosts.Add(hostname, fingerprint) | ||||||
| 		return nil | 		return nil | ||||||
| 	default: | 	default: | ||||||
| 		return errors.New("certificate not trusted") | 		return errors.New("certificate not trusted") | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 	client.GetInput = func(prompt string, sensitive bool) (string, bool) { | 
 | ||||||
|  | func getInput(prompt string, sensitive bool) (input string, ok bool) { | ||||||
| 	fmt.Printf("%s ", prompt) | 	fmt.Printf("%s ", prompt) | ||||||
| 	scanner.Scan() | 	scanner.Scan() | ||||||
| 	return scanner.Text(), true | 	return scanner.Text(), true | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | func do(req *gemini.Request, via []*gemini.Request) (*gemini.Response, error) { | ||||||
|  | 	client := gemini.Client{ | ||||||
|  | 		TrustCertificate: trustCertificate, | ||||||
|  | 	} | ||||||
|  | 	resp, err := client.Do(req) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return resp, err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	switch resp.Status.Class() { | ||||||
|  | 	case gemini.StatusClassInput: | ||||||
|  | 		input, ok := getInput(resp.Meta, resp.Status == gemini.StatusSensitiveInput) | ||||||
|  | 		if !ok { | ||||||
|  | 			break | ||||||
|  | 		} | ||||||
|  | 		req.URL.ForceQuery = true | ||||||
|  | 		req.URL.RawQuery = gemini.QueryEscape(input) | ||||||
|  | 		return do(req, via) | ||||||
|  | 
 | ||||||
|  | 	case gemini.StatusClassRedirect: | ||||||
|  | 		via = append(via, req) | ||||||
|  | 		if len(via) > 5 { | ||||||
|  | 			return resp, errors.New("too many redirects") | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		target, err := url.Parse(resp.Meta) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return resp, err | ||||||
|  | 		} | ||||||
|  | 		target = req.URL.ResolveReference(target) | ||||||
|  | 		redirect := *req | ||||||
|  | 		redirect.URL = target | ||||||
|  | 		return do(&redirect, via) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return resp, err | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func main() { | ||||||
|  | 	if len(os.Args) < 2 { | ||||||
|  | 		fmt.Printf("usage: %s <url> [host]\n", os.Args[0]) | ||||||
|  | 		os.Exit(1) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
| 	// Do the request | 	// Do the request | ||||||
| 	url := os.Args[1] | 	url := os.Args[1] | ||||||
| 	req, err := gemini.NewRequest(url) | 	req, err := gemini.NewRequest(url) | ||||||
| @ -79,7 +126,7 @@ func main() { | |||||||
| 	if len(os.Args) == 3 { | 	if len(os.Args) == 3 { | ||||||
| 		req.Host = os.Args[2] | 		req.Host = os.Args[2] | ||||||
| 	} | 	} | ||||||
| 	resp, err := client.Do(req) | 	resp, err := do(req, nil) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		fmt.Println(err) | 		fmt.Println(err) | ||||||
| 		os.Exit(1) | 		os.Exit(1) | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user