Add Client.GetInput field

This commit is contained in:
Adnan Maolood 2020-10-27 23:35:22 -04:00
parent d1dcf070ff
commit 4c5167f590
6 changed files with 61 additions and 60 deletions

View File

@ -24,6 +24,10 @@ type Client struct {
// redirects will be enforced. // redirects will be enforced.
CheckRedirect func(req *Request, via []*Request) error CheckRedirect func(req *Request, via []*Request) error
// GetInput, if not nil, will be called to retrieve input when the server
// requests it.
GetInput func(prompt string, sensitive bool) (string, bool)
// GetCertificate, if not nil, will be called when a server requests a certificate. // GetCertificate, if not nil, will be called when a server requests a certificate.
// The returned certificate will be used when sending the request again. // The returned certificate will be used when sending the request again.
// If the certificate is nil, the request will not be sent again and // If the certificate is nil, the request will not be sent again and
@ -141,7 +145,17 @@ func (c *Client) do(req *Request, via []*Request) (*Response, error) {
return resp, ErrTooManyRedirects return resp, ErrTooManyRedirects
} }
return c.do(redirect, via) return c.do(redirect, via)
} else if resp.Status.Class() == StatusClassInput {
if c.GetInput != nil {
input, ok := c.GetInput(resp.Meta, resp.Status == StatusSensitiveInput)
if ok {
req.URL.ForceQuery = true
req.URL.RawQuery = url.QueryEscape(input)
return c.do(req, via)
}
}
} }
return resp, nil return resp, nil
} }

View File

@ -7,7 +7,7 @@ import (
"fmt" "fmt"
"log" "log"
gmi "git.sr.ht/~adnano/go-gemini" "git.sr.ht/~adnano/go-gemini"
) )
type user struct { type user struct {
@ -33,15 +33,14 @@ var (
) )
func main() { func main() {
var mux gmi.ServeMux var mux gemini.ServeMux
mux.HandleFunc("/", welcome) mux.HandleFunc("/", login)
mux.HandleFunc("/login", login) mux.HandleFunc("/password", loginPassword)
mux.HandleFunc("/login/password", loginPassword)
mux.HandleFunc("/profile", profile) mux.HandleFunc("/profile", profile)
mux.HandleFunc("/admin", admin) mux.HandleFunc("/admin", admin)
mux.HandleFunc("/logout", logout) mux.HandleFunc("/logout", logout)
var server gmi.Server var server gemini.Server
if err := server.CertificateStore.Load("/var/lib/gemini/certs"); err != nil { if err := server.CertificateStore.Load("/var/lib/gemini/certs"); err != nil {
log.Fatal(err) log.Fatal(err)
} }
@ -53,74 +52,69 @@ func main() {
} }
func getSession(crt *x509.Certificate) (*session, bool) { func getSession(crt *x509.Certificate) (*session, bool) {
fingerprint := gmi.Fingerprint(crt) fingerprint := gemini.Fingerprint(crt)
session, ok := sessions[fingerprint] session, ok := sessions[fingerprint]
return session, ok return session, ok
} }
func welcome(w *gmi.ResponseWriter, r *gmi.Request) { func login(w *gemini.ResponseWriter, r *gemini.Request) {
fmt.Fprintln(w, "Welcome to this example.") cert, ok := gemini.Certificate(w, r)
fmt.Fprintln(w, "=> /login Login")
}
func login(w *gmi.ResponseWriter, r *gmi.Request) {
cert, ok := gmi.Certificate(w, r)
if !ok { if !ok {
return return
} }
username, ok := gmi.Input(w, r, "Username") username, ok := gemini.Input(w, r, "Username")
if !ok { if !ok {
return return
} }
fingerprint := gmi.Fingerprint(cert) fingerprint := gemini.Fingerprint(cert)
sessions[fingerprint] = &session{ sessions[fingerprint] = &session{
username: username, username: username,
} }
gmi.Redirect(w, "/login/password") gemini.Redirect(w, "/password")
} }
func loginPassword(w *gmi.ResponseWriter, r *gmi.Request) { func loginPassword(w *gemini.ResponseWriter, r *gemini.Request) {
cert, ok := gmi.Certificate(w, r) cert, ok := gemini.Certificate(w, r)
if !ok { if !ok {
return return
} }
session, ok := getSession(cert) session, ok := getSession(cert)
if !ok { if !ok {
w.WriteStatus(gmi.StatusCertificateNotAuthorized) w.WriteStatus(gemini.StatusCertificateNotAuthorized)
return return
} }
password, ok := gmi.SensitiveInput(w, r, "Password") password, ok := gemini.SensitiveInput(w, r, "Password")
if !ok { if !ok {
return return
} }
expected := logins[session.username].password expected := logins[session.username].password
if password == expected { if password == expected {
session.authorized = true session.authorized = true
gmi.Redirect(w, "/profile") gemini.Redirect(w, "/profile")
} else { } else {
gmi.SensitiveInput(w, r, "Wrong password. Try again") gemini.SensitiveInput(w, r, "Wrong password. Try again")
} }
} }
func logout(w *gmi.ResponseWriter, r *gmi.Request) { func logout(w *gemini.ResponseWriter, r *gemini.Request) {
cert, ok := gmi.Certificate(w, r) cert, ok := gemini.Certificate(w, r)
if !ok { if !ok {
return return
} }
fingerprint := gmi.Fingerprint(cert) fingerprint := gemini.Fingerprint(cert)
delete(sessions, fingerprint) delete(sessions, fingerprint)
fmt.Fprintln(w, "Successfully logged out.") fmt.Fprintln(w, "Successfully logged out.")
} }
func profile(w *gmi.ResponseWriter, r *gmi.Request) { func profile(w *gemini.ResponseWriter, r *gemini.Request) {
cert, ok := gmi.Certificate(w, r) cert, ok := gemini.Certificate(w, r)
if !ok { if !ok {
return return
} }
session, ok := getSession(cert) session, ok := getSession(cert)
if !ok { if !ok {
w.WriteStatus(gmi.StatusCertificateNotAuthorized) w.WriteStatus(gemini.StatusCertificateNotAuthorized)
return return
} }
user := logins[session.username] user := logins[session.username]
@ -129,19 +123,19 @@ func profile(w *gmi.ResponseWriter, r *gmi.Request) {
fmt.Fprintln(w, "=> /logout Logout") fmt.Fprintln(w, "=> /logout Logout")
} }
func admin(w *gmi.ResponseWriter, r *gmi.Request) { func admin(w *gemini.ResponseWriter, r *gemini.Request) {
cert, ok := gmi.Certificate(w, r) cert, ok := gemini.Certificate(w, r)
if !ok { if !ok {
return return
} }
session, ok := getSession(cert) session, ok := getSession(cert)
if !ok { if !ok {
w.WriteStatus(gmi.StatusCertificateNotAuthorized) w.WriteStatus(gemini.StatusCertificateNotAuthorized)
return return
} }
user := logins[session.username] user := logins[session.username]
if !user.admin { if !user.admin {
w.WriteStatus(gmi.StatusCertificateNotAuthorized) w.WriteStatus(gemini.StatusCertificateNotAuthorized)
return return
} }
fmt.Fprintln(w, "Welcome to the admin portal.") fmt.Fprintln(w, "Welcome to the admin portal.")

View File

@ -11,13 +11,13 @@ import (
"os" "os"
"time" "time"
gmi "git.sr.ht/~adnano/go-gemini" "git.sr.ht/~adnano/go-gemini"
) )
func main() { func main() {
host := "localhost" host := "localhost"
duration := 365 * 24 * time.Hour duration := 365 * 24 * time.Hour
cert, err := gmi.NewCertificate(host, duration) cert, err := gemini.NewCertificate(host, duration)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }

View File

@ -76,13 +76,7 @@ func sendRequest(req *gemini.Request) error {
return err return err
} }
switch resp.Status.Class() { if resp.Status.Class() == gemini.StatusClassSuccess {
case gemini.StatusClassInput:
fmt.Printf("%s: ", resp.Meta)
scanner.Scan()
req.URL.RawQuery = url.QueryEscape(scanner.Text())
return sendRequest(req)
case gemini.StatusClassSuccess:
defer resp.Body.Close() defer resp.Body.Close()
body, err := ioutil.ReadAll(resp.Body) body, err := ioutil.ReadAll(resp.Body)
if err != nil { if err != nil {
@ -90,20 +84,8 @@ func sendRequest(req *gemini.Request) error {
} }
fmt.Print(string(body)) fmt.Print(string(body))
return nil return nil
case gemini.StatusClassRedirect:
// This should not happen unless CheckRedirect returns false.
return fmt.Errorf("Failed to redirect to %s", resp.Meta)
case gemini.StatusClassTemporaryFailure:
return fmt.Errorf("Temporary failure: %s", resp.Meta)
case gemini.StatusClassPermanentFailure:
return fmt.Errorf("Permanent failure: %s", resp.Meta)
case gemini.StatusClassCertificateRequired:
// Note that this should not happen unless the server responds with
// CertificateRequired even after we send a certificate.
// CertificateNotAuthorized and CertificateNotValid are handled here.
return fmt.Errorf("Certificate required: %s", resp.Meta)
} }
panic("unreachable") return fmt.Errorf("request failed: %d %s: %s", resp.Status, resp.Status.Message(), resp.Meta)
} }
type trust int type trust int

View File

@ -273,7 +273,8 @@ type Responder interface {
// If no input is provided, it responds with StatusInput. // If no input is provided, it responds with StatusInput.
func Input(w *ResponseWriter, r *Request, prompt string) (string, bool) { func Input(w *ResponseWriter, r *Request, prompt string) (string, bool) {
if r.URL.ForceQuery || r.URL.RawQuery != "" { if r.URL.ForceQuery || r.URL.RawQuery != "" {
return r.URL.RawQuery, true query, err := url.QueryUnescape(r.URL.RawQuery)
return query, err == nil
} }
w.WriteHeader(StatusInput, prompt) w.WriteHeader(StatusInput, prompt)
return "", false return "", false
@ -283,7 +284,8 @@ func Input(w *ResponseWriter, r *Request, prompt string) (string, bool) {
// If no input is provided, it responds with StatusSensitiveInput. // If no input is provided, it responds with StatusSensitiveInput.
func SensitiveInput(w *ResponseWriter, r *Request, prompt string) (string, bool) { func SensitiveInput(w *ResponseWriter, r *Request, prompt string) (string, bool) {
if r.URL.ForceQuery || r.URL.RawQuery != "" { if r.URL.ForceQuery || r.URL.RawQuery != "" {
return r.URL.RawQuery, true query, err := url.QueryUnescape(r.URL.RawQuery)
return query, err == nil
} }
w.WriteHeader(StatusSensitiveInput, prompt) w.WriteHeader(StatusSensitiveInput, prompt)
return "", false return "", false

View File

@ -30,11 +30,20 @@ func (s Status) Class() StatusClass {
} }
// Message returns a status message corresponding to this status code. // Message returns a status message corresponding to this status code.
// It returns an empty string for input, successs, and redirect status codes.
func (s Status) Message() string { func (s Status) Message() string {
switch s { switch s {
case StatusInput:
return "Input"
case StatusSensitiveInput:
return "Sensitive input"
case StatusSuccess:
return "Success"
case StatusRedirect:
return "Redirect"
case StatusRedirectPermanent:
return "Permanent redirect"
case StatusTemporaryFailure: case StatusTemporaryFailure:
return "TemporaryFailure" return "Temporary failure"
case StatusServerUnavailable: case StatusServerUnavailable:
return "Server unavailable" return "Server unavailable"
case StatusCGIError: case StatusCGIError:
@ -44,7 +53,7 @@ func (s Status) Message() string {
case StatusSlowDown: case StatusSlowDown:
return "Slow down" return "Slow down"
case StatusPermanentFailure: case StatusPermanentFailure:
return "PermanentFailure" return "Permanent failure"
case StatusNotFound: case StatusNotFound:
return "Not found" return "Not found"
case StatusGone: case StatusGone: