diff --git a/server.go b/server.go index 779c003..491f516 100644 --- a/server.go +++ b/server.go @@ -47,16 +47,98 @@ type Server struct { // If nil, logging is done via the log package's standard logger. ErrorLog *log.Logger - // Context is the base context to use. - // If nil, context.Background is used. - Context context.Context - - listeners map[*net.Listener]struct{} - conns map[*net.Conn]struct{} - done int32 + listeners map[*net.Listener]context.CancelFunc + conns map[*net.Conn]context.CancelFunc + doneChan chan struct{} + closed int32 mu sync.Mutex } +// done returns a channel that's closed when the server has finished closing. +func (srv *Server) done() chan struct{} { + srv.mu.Lock() + defer srv.mu.Unlock() + return srv.doneLocked() +} + +func (srv *Server) doneLocked() chan struct{} { + if srv.doneChan == nil { + srv.doneChan = make(chan struct{}) + } + return srv.doneChan +} + +// tryFinishShutdown closes srv.done() if there are no active listeners or requests. +func (srv *Server) tryFinishShutdown() { + srv.mu.Lock() + defer srv.mu.Unlock() + if len(srv.listeners) == 0 && len(srv.conns) == 0 { + done := srv.doneLocked() + select { + case <-done: + default: + close(done) + } + } +} + +// Close immediately closes all active net.Listeners and connections. +// For a graceful shutdown, use Shutdown. +func (srv *Server) Close() error { + if !atomic.CompareAndSwapInt32(&srv.closed, 0, 1) { + return ErrServerClosed + } + + // Close active listeners and connections. + srv.mu.Lock() + for _, cancel := range srv.listeners { + cancel() + } + for _, cancel := range srv.conns { + cancel() + } + srv.mu.Unlock() + + select { + case <-srv.done(): + return nil + } +} + +// Shutdown gracefully shuts down the server without interrupting any +// active connections. Shutdown works by first closing all open +// listeners and then waiting indefinitely for connections +// to close and then shut down. +// If the provided context expires before the shutdown is complete, +// Shutdown returns the context's error. +// +// When Shutdown is called, Serve and ListenAndServer immediately +// return ErrServerClosed. Make sure the program doesn't exit and +// waits instead for Shutdown to return. +// +// Once Shutdown has been called on a server, it may not be reused; +// future calls to methods such as Serve will return ErrServerClosed. +func (srv *Server) Shutdown(ctx context.Context) error { + if !atomic.CompareAndSwapInt32(&srv.closed, 0, 1) { + return ErrServerClosed + } + + // Close active listeners. + srv.mu.Lock() + for _, cancel := range srv.listeners { + cancel() + } + srv.mu.Unlock() + + // Wait for active connections to finish. + select { + case <-ctx.Done(): + return ctx.Err() + case <-srv.done(): + return nil + } +} + // ListenAndServe listens for requests at the server's configured address. // ListenAndServe listens on the TCP network address srv.Addr and then calls // Serve to handle requests on incoming connections. @@ -65,8 +147,8 @@ type Server struct { // // ListenAndServe always returns a non-nil error. After Shutdown or Close, the // returned error is ErrServerClosed. -func (srv *Server) ListenAndServe() error { - if atomic.LoadInt32(&srv.done) == 1 { +func (srv *Server) ListenAndServe(ctx context.Context) error { + if atomic.LoadInt32(&srv.closed) == 1 { return ErrServerClosed } @@ -75,26 +157,33 @@ func (srv *Server) ListenAndServe() error { addr = ":1965" } - ln, err := net.Listen("tcp", addr) + l, err := net.Listen("tcp", addr) if err != nil { return err } - defer ln.Close() - return srv.Serve(tls.NewListener(ln, &tls.Config{ + l = tls.NewListener(l, &tls.Config{ ClientAuth: tls.RequestClientCert, MinVersion: tls.VersionTLS12, GetCertificate: srv.getCertificate, - })) + }) + return srv.Serve(ctx, l) } -func (srv *Server) trackListener(l *net.Listener) { +func (srv *Server) getCertificate(h *tls.ClientHelloInfo) (*tls.Certificate, error) { + if srv.GetCertificate == nil { + return nil, errors.New("gemini: GetCertificate is nil") + } + return srv.GetCertificate(h.ServerName) +} + +func (srv *Server) trackListener(l *net.Listener, cancel context.CancelFunc) { srv.mu.Lock() defer srv.mu.Unlock() if srv.listeners == nil { - srv.listeners = make(map[*net.Listener]struct{}) + srv.listeners = make(map[*net.Listener]context.CancelFunc) } - srv.listeners[l] = struct{}{} + srv.listeners[l] = cancel } func (srv *Server) deleteListener(l *net.Listener) { @@ -109,24 +198,46 @@ func (srv *Server) deleteListener(l *net.Listener) { // // Serve always returns a non-nil error and closes l. After Shutdown or Close, // the returned error is ErrServerClosed. -func (srv *Server) Serve(l net.Listener) error { +func (srv *Server) Serve(ctx context.Context, l net.Listener) error { defer l.Close() - srv.trackListener(&l) - defer srv.deleteListener(&l) - - if atomic.LoadInt32(&srv.done) == 1 { + if atomic.LoadInt32(&srv.closed) == 1 { return ErrServerClosed } - var tempDelay time.Duration // how long to sleep on accept failure + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + srv.trackListener(&l, cancel) + defer srv.tryFinishShutdown() + defer srv.deleteListener(&l) + + errch := make(chan error, 1) + go func() { + errch <- srv.serve(ctx, l) + }() + + select { + case <-ctx.Done(): + return ctx.Err() + case err := <-errch: + return err + } +} + +func (srv *Server) serve(ctx context.Context, l net.Listener) error { + // how long to sleep on accept failure + var tempDelay time.Duration for { rw, err := l.Accept() if err != nil { - if atomic.LoadInt32(&srv.done) == 1 { - return ErrServerClosed + select { + case <-ctx.Done(): + return ctx.Err() + default: } + // If this is a temporary error, sleep if ne, ok := err.(net.Error); ok && ne.Temporary() { if tempDelay == 0 { @@ -142,115 +253,21 @@ func (srv *Server) Serve(l net.Listener) error { continue } - // Otherwise, return the error return err } tempDelay = 0 - go srv.respond(rw) + go srv.serveConn(ctx, rw) } } -func (srv *Server) closeListenersLocked() error { - var err error - for ln := range srv.listeners { - if cerr := (*ln).Close(); cerr != nil && err == nil { - err = cerr - } - delete(srv.listeners, ln) - } - return err -} - -// Close immediately closes all active net.Listeners and connections. -// For a graceful shutdown, use Shutdown. -// -// Close returns any error returned from closing the Server's -// underlying Listener(s). -func (srv *Server) Close() error { - srv.mu.Lock() - defer srv.mu.Unlock() - if !atomic.CompareAndSwapInt32(&srv.done, 0, 1) { - return ErrServerClosed - } - err := srv.closeListenersLocked() - - // Close active connections - for conn := range srv.conns { - (*conn).Close() - delete(srv.conns, conn) - } - return err -} - -func (srv *Server) numConns() int { - srv.mu.Lock() - defer srv.mu.Unlock() - return len(srv.conns) -} - -// shutdownPollInterval is how often we poll for quiescence -// during Server.Shutdown. This is lower during tests, to -// speed up tests. -// Ideally we could find a solution that doesn't involve polling, -// but which also doesn't have a high runtime cost (and doesn't -// involve any contentious mutexes), but that is left as an -// exercise for the reader. -var shutdownPollInterval = 500 * time.Millisecond - -// Shutdown gracefully shuts down the server without interrupting any -// active connections. Shutdown works by first closing all open -// listeners and then waiting indefinitely for connections -// to close and then shut down. -// If the provided context expires before the shutdown is complete, -// Shutdown returns the context's error, otherwise it returns any -// error returned from closing the Server's underlying Listener(s). -// -// When Shutdown is called, Serve, ListenAndServe, and -// ListenAndServeTLS immediately return ErrServerClosed. Make sure the -// program doesn't exit and waits instead for Shutdown to return. -// -// Once Shutdown has been called on a server, it may not be reused; -// future calls to methods such as Serve will return ErrServerClosed. -func (srv *Server) Shutdown(ctx context.Context) error { - if !atomic.CompareAndSwapInt32(&srv.done, 0, 1) { - return ErrServerClosed - } - - srv.mu.Lock() - err := srv.closeListenersLocked() - srv.mu.Unlock() - - // Wait for active connections to close - ticker := time.NewTicker(shutdownPollInterval) - defer ticker.Stop() - for { - if srv.numConns() == 0 { - return err - } - - select { - case <-ctx.Done(): - return ctx.Err() - case <-ticker.C: - } - } -} - -func (srv *Server) getCertificate(h *tls.ClientHelloInfo) (*tls.Certificate, error) { - if srv.GetCertificate == nil { - return nil, errors.New("gemini: GetCertificate is nil") - } - return srv.GetCertificate(h.ServerName) -} - -func (srv *Server) trackConn(conn *net.Conn) { +func (srv *Server) trackConn(conn *net.Conn, cancel context.CancelFunc) { srv.mu.Lock() defer srv.mu.Unlock() if srv.conns == nil { - srv.conns = make(map[*net.Conn]struct{}) + srv.conns = make(map[*net.Conn]context.CancelFunc) } - srv.conns[conn] = struct{}{} + srv.conns[conn] = cancel } func (srv *Server) deleteConn(conn *net.Conn) { @@ -259,10 +276,22 @@ func (srv *Server) deleteConn(conn *net.Conn) { delete(srv.conns, conn) } -// respond responds to a connection. -func (srv *Server) respond(conn net.Conn) { +// serveConn serves a Gemini response over the provided connection. +// It closes the connection when the response has been completed. +func (srv *Server) serveConn(ctx context.Context, conn net.Conn) { defer conn.Close() + if atomic.LoadInt32(&srv.closed) == 1 { + return + } + + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + srv.trackConn(&conn, cancel) + defer srv.tryFinishShutdown() + defer srv.deleteConn(&conn) + defer func() { if err := recover(); err != nil && err != ErrAbortHandler { const size = 64 << 10 @@ -272,9 +301,6 @@ func (srv *Server) respond(conn net.Conn) { } }() - srv.trackConn(&conn) - defer srv.deleteConn(&conn) - if d := srv.ReadTimeout; d != 0 { conn.SetReadDeadline(time.Now().Add(d)) } @@ -282,16 +308,29 @@ func (srv *Server) respond(conn net.Conn) { conn.SetWriteDeadline(time.Now().Add(d)) } + done := make(chan struct{}) + go func() { + srv.respond(ctx, conn) + close(done) + }() + + select { + case <-ctx.Done(): + case <-done: + } +} + +func (srv *Server) respond(ctx context.Context, conn net.Conn) { w := newResponseWriter(conn) + defer w.Flush() req, err := ReadRequest(conn) if err != nil { w.WriteHeader(StatusBadRequest, "Bad request") - w.Flush() return } - // Store information about the TLS connection + // Store the TLS connection state if tlsConn, ok := conn.(*tls.Conn); ok { state := tlsConn.ConnectionState() req.TLS = &state @@ -304,20 +343,10 @@ func (srv *Server) respond(conn net.Conn) { h := srv.Handler if h == nil { w.WriteHeader(StatusNotFound, "Not found") - w.Flush() return } - ctx := srv.context() h.ServeGemini(ctx, w, req) - w.Flush() -} - -func (srv *Server) context() context.Context { - if srv.Context != nil { - return srv.Context - } - return context.Background() } func (srv *Server) logf(format string, args ...interface{}) {