diff --git a/gemini.go b/gemini.go index 86b31ae..0462dfb 100644 --- a/gemini.go +++ b/gemini.go @@ -11,5 +11,12 @@ var ( ErrInvalidURL = errors.New("gemini: invalid URL") ErrInvalidRequest = errors.New("gemini: invalid request") ErrInvalidResponse = errors.New("gemini: invalid response") - ErrBodyNotAllowed = errors.New("gemini: response body not allowed") + + // ErrBodyNotAllowed is returned by ResponseWriter.Write calls + // when the response status code does not permit a body. + ErrBodyNotAllowed = errors.New("gemini: response status code does not allow body") + + // ErrServerClosed is returned by the Server's Serve and ListenAndServe + // methods after a call to Shutdown or Close. + ErrServerClosed = errors.New("gemini: server closed") ) diff --git a/server.go b/server.go index 806ede6..daeb097 100644 --- a/server.go +++ b/server.go @@ -1,11 +1,14 @@ package gemini import ( + "context" "crypto/tls" "errors" "log" "net" "strings" + "sync" + "sync/atomic" "time" "git.sr.ht/~adnano/go-gemini/certificate" @@ -47,6 +50,11 @@ type Server struct { // registered handlers handlers map[handlerKey]Handler hosts map[string]bool + + listeners map[*net.Listener]struct{} + conns map[*net.Conn]struct{} + done int32 + mu sync.Mutex } type handlerKey struct { @@ -62,6 +70,9 @@ type handlerKey struct { // Wildcard patterns are supported (e.g. "*.example.com"). // To handle any hostname, use the wildcard pattern "*". func (srv *Server) Handle(pattern string, handler Handler) { + srv.mu.Lock() + defer srv.mu.Unlock() + if pattern == "" { panic("gemini: invalid pattern") } @@ -101,7 +112,6 @@ func (srv *Server) HandleFunc(pattern string, handler func(ResponseWriter, *Requ // // If srv.Addr is blank, ":1965" is used. // -// TODO: // ListenAndServe always returns a non-nil error. After Shutdown or Close, the // returned error is ErrServerClosed. func (srv *Server) ListenAndServe() error { @@ -123,19 +133,45 @@ func (srv *Server) ListenAndServe() error { })) } +func (srv *Server) trackListener(l *net.Listener) { + srv.mu.Lock() + defer srv.mu.Unlock() + if srv.listeners == nil { + srv.listeners = make(map[*net.Listener]struct{}) + } + srv.listeners[l] = struct{}{} +} + +func (srv *Server) deleteListener(l *net.Listener) { + srv.mu.Lock() + defer srv.mu.Unlock() + delete(srv.listeners, l) +} + // Serve accepts incoming connections on the Listener l, creating a new // service goroutine for each. The service goroutines read requests and // then calls the appropriate Handler to reply to them. // -// TODO: // Serve always returns a non-nil error and closes l. After Shutdown or Close, // the returned error is ErrServerClosed. func (srv *Server) Serve(l net.Listener) error { + defer l.Close() + + if atomic.LoadInt32(&srv.done) == 1 { + return ErrServerClosed + } + + srv.trackListener(&l) + defer srv.deleteListener(&l) + var tempDelay time.Duration // how long to sleep on accept failure for { rw, err := l.Accept() if err != nil { + if atomic.LoadInt32(&srv.done) == 1 { + return ErrServerClosed + } // If this is a temporary error, sleep if ne, ok := err.(net.Error); ok && ne.Temporary() { if tempDelay == 0 { @@ -160,6 +196,92 @@ func (srv *Server) Serve(l net.Listener) error { } } +func (srv *Server) closeListenersLocked() error { + var err error + for ln := range srv.listeners { + if cerr := (*ln).Close(); cerr != nil && err == nil { + err = cerr + } + delete(srv.listeners, ln) + } + return err +} + +// Close immediately closes all active net.Listeners and connections. +// For a graceful shutdown, use Shutdown. +// +// Close returns any error returned from closing the Server's +// underlying Listener(s). +func (srv *Server) Close() error { + srv.mu.Lock() + defer srv.mu.Unlock() + if !atomic.CompareAndSwapInt32(&srv.done, 0, 1) { + return ErrServerClosed + } + err := srv.closeListenersLocked() + + // Close active connections + for conn := range srv.conns { + (*conn).Close() + delete(srv.conns, conn) + } + return err +} + +func (srv *Server) numConns() int { + srv.mu.Lock() + defer srv.mu.Unlock() + return len(srv.conns) +} + +// shutdownPollInterval is how often we poll for quiescence +// during Server.Shutdown. This is lower during tests, to +// speed up tests. +// Ideally we could find a solution that doesn't involve polling, +// but which also doesn't have a high runtime cost (and doesn't +// involve any contentious mutexes), but that is left as an +// exercise for the reader. +var shutdownPollInterval = 500 * time.Millisecond + +// Shutdown gracefully shuts down the server without interrupting any +// active connections. Shutdown works by first closing all open +// listeners and then waiting indefinitely for connections +// to close and then shut down. +// If the provided context expires before the shutdown is complete, +// Shutdown returns the context's error, otherwise it returns any +// error returned from closing the Server's underlying Listener(s). +// +// When Shutdown is called, Serve, ListenAndServe, and +// ListenAndServeTLS immediately return ErrServerClosed. Make sure the +// program doesn't exit and waits instead for Shutdown to return. +// +// Once Shutdown has been called on a server, it may not be reused; +// future calls to methods such as Serve will return ErrServerClosed. +func (srv *Server) Shutdown(ctx context.Context) error { + if !atomic.CompareAndSwapInt32(&srv.done, 0, 1) { + return ErrServerClosed + } + + srv.mu.Lock() + err := srv.closeListenersLocked() + srv.mu.Unlock() + + // Wait for active connections to close + ticker := time.NewTicker(shutdownPollInterval) + defer ticker.Stop() + for { + if srv.numConns() == 0 { + return err + } + + select { + case <-ctx.Done(): + return ctx.Err() + case <-ticker.C: + } + } +} + // getCertificate retrieves a certificate for the given client hello. func (srv *Server) getCertificate(h *tls.ClientHelloInfo) (*tls.Certificate, error) { cert, err := srv.lookupCertificate(h.ServerName, h.ServerName) @@ -207,9 +329,28 @@ func (srv *Server) lookupCertificate(pattern, hostname string) (*tls.Certificate return &cert, nil } +func (srv *Server) trackConn(conn *net.Conn) { + srv.mu.Lock() + defer srv.mu.Unlock() + if srv.conns == nil { + srv.conns = make(map[*net.Conn]struct{}) + } + srv.conns[conn] = struct{}{} +} + +func (srv *Server) deleteConn(conn *net.Conn) { + srv.mu.Lock() + defer srv.mu.Unlock() + delete(srv.conns, conn) +} + // respond responds to a connection. func (srv *Server) respond(conn net.Conn) { defer conn.Close() + + srv.trackConn(&conn) + defer srv.deleteConn(&conn) + if d := srv.ReadTimeout; d != 0 { conn.SetReadDeadline(time.Now().Add(d)) }