From 118e019df01fdea54fabdb95e4e2799db768919e Mon Sep 17 00:00:00 2001 From: Adnan Maolood Date: Mon, 22 Feb 2021 20:06:05 -0500 Subject: [PATCH] server: Use channel to communicate shutdown --- server.go | 147 ++++++++++++++++++++++++++---------------------------- 1 file changed, 72 insertions(+), 75 deletions(-) diff --git a/server.go b/server.go index 908ba95..a9dbba5 100644 --- a/server.go +++ b/server.go @@ -7,7 +7,6 @@ import ( "log" "net" "sync" - "sync/atomic" "time" ) @@ -46,11 +45,11 @@ type Server struct { // If nil, logging is done via the log package's standard logger. ErrorLog *log.Logger - listeners map[*net.Listener]context.CancelFunc - conns map[*net.Conn]context.CancelFunc - doneChan chan struct{} - status int32 - mu sync.Mutex + listeners map[*net.Listener]struct{} + conns map[*net.Conn]struct{} + closedChan chan struct{} // closed when the server is closed + doneChan chan struct{} // closed when no more connections are open + mu sync.Mutex } const ( @@ -59,7 +58,22 @@ const ( serverClosed ) -// done returns a channel that's closed when the server has finished closing. +// closed returns a channel that's closed when the server is closed. +func (srv *Server) closed() chan struct{} { + srv.mu.Lock() + defer srv.mu.Unlock() + return srv.closedLocked() +} + +func (srv *Server) closedLocked() chan struct{} { + if srv.closedChan == nil { + srv.closedChan = make(chan struct{}) + } + return srv.closedChan +} + +// done returns a channel that's closed when the server is closed and +// all listeners and connections are closed. func (srv *Server) done() chan struct{} { srv.mu.Lock() defer srv.mu.Unlock() @@ -73,28 +87,22 @@ func (srv *Server) doneLocked() chan struct{} { return srv.doneChan } -// rejectingListeners reports whether the server is rejecting new listeners -// (e.g. after Shutdown or Close has been called). -func (srv *Server) rejectingListeners() bool { - return atomic.LoadInt32(&srv.status) != serverOk -} - -// rejectingConns reports whether the server is rejecting new connections -// (e.g. after Close has been called). -func (srv *Server) rejectingConns() bool { - return atomic.LoadInt32(&srv.status) == serverClosed -} - -// tryFinishShutdown closes srv.done() if there are no active listeners or requests. +// tryFinishShutdown closes srv.done() if the server is closed and +// there are no active listeners or connections. func (srv *Server) tryFinishShutdown() { srv.mu.Lock() defer srv.mu.Unlock() + select { + case <-srv.closedLocked(): + default: + return + } if len(srv.listeners) == 0 && len(srv.conns) == 0 { - done := srv.doneLocked() + ch := srv.doneLocked() select { - case <-done: + case <-ch: default: - close(done) + close(ch) } } } @@ -102,17 +110,20 @@ func (srv *Server) tryFinishShutdown() { // Close immediately closes all active net.Listeners and connections. // For a graceful shutdown, use Shutdown. func (srv *Server) Close() error { - if !atomic.CompareAndSwapInt32(&srv.status, serverOk, serverClosed) { - return ErrServerClosed + ch := srv.closed() + select { + case <-ch: + return nil + default: + close(ch) } - // Close active listeners and connections. + srv.tryFinishShutdown() + + // Force all active connections to close. srv.mu.Lock() - for _, cancel := range srv.listeners { - cancel() - } - for _, cancel := range srv.conns { - cancel() + for conn := range srv.conns { + (*conn).Close() } srv.mu.Unlock() @@ -129,24 +140,22 @@ func (srv *Server) Close() error { // If the provided context expires before the shutdown is complete, // Shutdown returns the context's error. // -// When Shutdown is called, Serve and ListenAndServer immediately +// When Shutdown is called, Serve and ListenAndServe immediately // return ErrServerClosed. Make sure the program doesn't exit and // waits instead for Shutdown to return. // // Once Shutdown has been called on a server, it may not be reused; // future calls to methods such as Serve will return ErrServerClosed. func (srv *Server) Shutdown(ctx context.Context) error { - if !atomic.CompareAndSwapInt32(&srv.status, serverOk, serverShutdown) { - return ErrServerClosed + ch := srv.closed() + select { + case <-ch: + return nil + default: + close(ch) } - defer atomic.StoreInt32(&srv.status, serverClosed) - // Close active listeners. - srv.mu.Lock() - for _, cancel := range srv.listeners { - cancel() - } - srv.mu.Unlock() + srv.tryFinishShutdown() // Wait for active connections to finish. select { @@ -166,8 +175,10 @@ func (srv *Server) Shutdown(ctx context.Context) error { // ListenAndServe always returns a non-nil error. After Shutdown or Close, the // returned error is ErrServerClosed. func (srv *Server) ListenAndServe(ctx context.Context) error { - if srv.rejectingListeners() { + select { + case <-srv.closed(): return ErrServerClosed + default: } addr := srv.Addr @@ -195,16 +206,18 @@ func (srv *Server) getCertificate(h *tls.ClientHelloInfo) (*tls.Certificate, err return srv.GetCertificate(h.ServerName) } -func (srv *Server) trackListener(l *net.Listener, cancel context.CancelFunc) bool { - if srv.rejectingListeners() { - return false - } +func (srv *Server) trackListener(l *net.Listener) bool { srv.mu.Lock() defer srv.mu.Unlock() - if srv.listeners == nil { - srv.listeners = make(map[*net.Listener]context.CancelFunc) + select { + case <-srv.closedLocked(): + return false + default: } - srv.listeners[l] = cancel + if srv.listeners == nil { + srv.listeners = make(map[*net.Listener]struct{}) + } + srv.listeners[l] = struct{}{} return true } @@ -223,10 +236,7 @@ func (srv *Server) deleteListener(l *net.Listener) { func (srv *Server) Serve(ctx context.Context, l net.Listener) error { defer l.Close() - lnctx, cancel := context.WithCancel(ctx) - defer cancel() - - if !srv.trackListener(&l, cancel) { + if !srv.trackListener(&l) { return ErrServerClosed } defer srv.tryFinishShutdown() @@ -238,13 +248,12 @@ func (srv *Server) Serve(ctx context.Context, l net.Listener) error { }() select { - case <-lnctx.Done(): - if srv.rejectingListeners() { - return ErrServerClosed - } - return lnctx.Err() + case <-ctx.Done(): + return ctx.Err() case err := <-errch: return err + case <-srv.closed(): + return ErrServerClosed } } @@ -274,17 +283,13 @@ func (srv *Server) serve(ctx context.Context, l net.Listener) error { } } -func (srv *Server) trackConn(conn *net.Conn, cancel context.CancelFunc) bool { - if srv.rejectingConns() { - return false - } +func (srv *Server) trackConn(conn *net.Conn) { srv.mu.Lock() defer srv.mu.Unlock() if srv.conns == nil { - srv.conns = make(map[*net.Conn]context.CancelFunc) + srv.conns = make(map[*net.Conn]struct{}) } - srv.conns[conn] = cancel - return true + srv.conns[conn] = struct{}{} } func (srv *Server) deleteConn(conn *net.Conn) { @@ -295,16 +300,11 @@ func (srv *Server) deleteConn(conn *net.Conn) { // ServeConn serves a Gemini response over the provided connection. // It closes the connection when the response has been completed. -// Note that ServeConn will succeed even if a Shutdown is in progress. +// ServeConn can be used even after Shutdown or Close have been called. func (srv *Server) ServeConn(ctx context.Context, conn net.Conn) error { defer conn.Close() - ctx, cancel := context.WithCancel(ctx) - defer cancel() - - if !srv.trackConn(&conn, cancel) { - return ErrServerClosed - } + srv.trackConn(&conn) defer srv.tryFinishShutdown() defer srv.deleteConn(&conn) @@ -322,9 +322,6 @@ func (srv *Server) ServeConn(ctx context.Context, conn net.Conn) error { select { case <-ctx.Done(): - if srv.rejectingConns() { - return ErrServerClosed - } return ctx.Err() case err := <-errch: return err