server: Use channel to communicate shutdown

This commit is contained in:
Adnan Maolood 2021-02-22 20:06:05 -05:00
parent 2c64db3863
commit 118e019df0

147
server.go
View File

@ -7,7 +7,6 @@ import (
"log" "log"
"net" "net"
"sync" "sync"
"sync/atomic"
"time" "time"
) )
@ -46,11 +45,11 @@ type Server struct {
// If nil, logging is done via the log package's standard logger. // If nil, logging is done via the log package's standard logger.
ErrorLog *log.Logger ErrorLog *log.Logger
listeners map[*net.Listener]context.CancelFunc listeners map[*net.Listener]struct{}
conns map[*net.Conn]context.CancelFunc conns map[*net.Conn]struct{}
doneChan chan struct{} closedChan chan struct{} // closed when the server is closed
status int32 doneChan chan struct{} // closed when no more connections are open
mu sync.Mutex mu sync.Mutex
} }
const ( const (
@ -59,7 +58,22 @@ const (
serverClosed serverClosed
) )
// done returns a channel that's closed when the server has finished closing. // closed returns a channel that's closed when the server is closed.
func (srv *Server) closed() chan struct{} {
srv.mu.Lock()
defer srv.mu.Unlock()
return srv.closedLocked()
}
func (srv *Server) closedLocked() chan struct{} {
if srv.closedChan == nil {
srv.closedChan = make(chan struct{})
}
return srv.closedChan
}
// done returns a channel that's closed when the server is closed and
// all listeners and connections are closed.
func (srv *Server) done() chan struct{} { func (srv *Server) done() chan struct{} {
srv.mu.Lock() srv.mu.Lock()
defer srv.mu.Unlock() defer srv.mu.Unlock()
@ -73,28 +87,22 @@ func (srv *Server) doneLocked() chan struct{} {
return srv.doneChan return srv.doneChan
} }
// rejectingListeners reports whether the server is rejecting new listeners // tryFinishShutdown closes srv.done() if the server is closed and
// (e.g. after Shutdown or Close has been called). // there are no active listeners or connections.
func (srv *Server) rejectingListeners() bool {
return atomic.LoadInt32(&srv.status) != serverOk
}
// rejectingConns reports whether the server is rejecting new connections
// (e.g. after Close has been called).
func (srv *Server) rejectingConns() bool {
return atomic.LoadInt32(&srv.status) == serverClosed
}
// tryFinishShutdown closes srv.done() if there are no active listeners or requests.
func (srv *Server) tryFinishShutdown() { func (srv *Server) tryFinishShutdown() {
srv.mu.Lock() srv.mu.Lock()
defer srv.mu.Unlock() defer srv.mu.Unlock()
select {
case <-srv.closedLocked():
default:
return
}
if len(srv.listeners) == 0 && len(srv.conns) == 0 { if len(srv.listeners) == 0 && len(srv.conns) == 0 {
done := srv.doneLocked() ch := srv.doneLocked()
select { select {
case <-done: case <-ch:
default: default:
close(done) close(ch)
} }
} }
} }
@ -102,17 +110,20 @@ func (srv *Server) tryFinishShutdown() {
// Close immediately closes all active net.Listeners and connections. // Close immediately closes all active net.Listeners and connections.
// For a graceful shutdown, use Shutdown. // For a graceful shutdown, use Shutdown.
func (srv *Server) Close() error { func (srv *Server) Close() error {
if !atomic.CompareAndSwapInt32(&srv.status, serverOk, serverClosed) { ch := srv.closed()
return ErrServerClosed select {
case <-ch:
return nil
default:
close(ch)
} }
// Close active listeners and connections. srv.tryFinishShutdown()
// Force all active connections to close.
srv.mu.Lock() srv.mu.Lock()
for _, cancel := range srv.listeners { for conn := range srv.conns {
cancel() (*conn).Close()
}
for _, cancel := range srv.conns {
cancel()
} }
srv.mu.Unlock() srv.mu.Unlock()
@ -129,24 +140,22 @@ func (srv *Server) Close() error {
// If the provided context expires before the shutdown is complete, // If the provided context expires before the shutdown is complete,
// Shutdown returns the context's error. // Shutdown returns the context's error.
// //
// When Shutdown is called, Serve and ListenAndServer immediately // When Shutdown is called, Serve and ListenAndServe immediately
// return ErrServerClosed. Make sure the program doesn't exit and // return ErrServerClosed. Make sure the program doesn't exit and
// waits instead for Shutdown to return. // waits instead for Shutdown to return.
// //
// Once Shutdown has been called on a server, it may not be reused; // Once Shutdown has been called on a server, it may not be reused;
// future calls to methods such as Serve will return ErrServerClosed. // future calls to methods such as Serve will return ErrServerClosed.
func (srv *Server) Shutdown(ctx context.Context) error { func (srv *Server) Shutdown(ctx context.Context) error {
if !atomic.CompareAndSwapInt32(&srv.status, serverOk, serverShutdown) { ch := srv.closed()
return ErrServerClosed select {
case <-ch:
return nil
default:
close(ch)
} }
defer atomic.StoreInt32(&srv.status, serverClosed)
// Close active listeners. srv.tryFinishShutdown()
srv.mu.Lock()
for _, cancel := range srv.listeners {
cancel()
}
srv.mu.Unlock()
// Wait for active connections to finish. // Wait for active connections to finish.
select { select {
@ -166,8 +175,10 @@ func (srv *Server) Shutdown(ctx context.Context) error {
// ListenAndServe always returns a non-nil error. After Shutdown or Close, the // ListenAndServe always returns a non-nil error. After Shutdown or Close, the
// returned error is ErrServerClosed. // returned error is ErrServerClosed.
func (srv *Server) ListenAndServe(ctx context.Context) error { func (srv *Server) ListenAndServe(ctx context.Context) error {
if srv.rejectingListeners() { select {
case <-srv.closed():
return ErrServerClosed return ErrServerClosed
default:
} }
addr := srv.Addr addr := srv.Addr
@ -195,16 +206,18 @@ func (srv *Server) getCertificate(h *tls.ClientHelloInfo) (*tls.Certificate, err
return srv.GetCertificate(h.ServerName) return srv.GetCertificate(h.ServerName)
} }
func (srv *Server) trackListener(l *net.Listener, cancel context.CancelFunc) bool { func (srv *Server) trackListener(l *net.Listener) bool {
if srv.rejectingListeners() {
return false
}
srv.mu.Lock() srv.mu.Lock()
defer srv.mu.Unlock() defer srv.mu.Unlock()
if srv.listeners == nil { select {
srv.listeners = make(map[*net.Listener]context.CancelFunc) case <-srv.closedLocked():
return false
default:
} }
srv.listeners[l] = cancel if srv.listeners == nil {
srv.listeners = make(map[*net.Listener]struct{})
}
srv.listeners[l] = struct{}{}
return true return true
} }
@ -223,10 +236,7 @@ func (srv *Server) deleteListener(l *net.Listener) {
func (srv *Server) Serve(ctx context.Context, l net.Listener) error { func (srv *Server) Serve(ctx context.Context, l net.Listener) error {
defer l.Close() defer l.Close()
lnctx, cancel := context.WithCancel(ctx) if !srv.trackListener(&l) {
defer cancel()
if !srv.trackListener(&l, cancel) {
return ErrServerClosed return ErrServerClosed
} }
defer srv.tryFinishShutdown() defer srv.tryFinishShutdown()
@ -238,13 +248,12 @@ func (srv *Server) Serve(ctx context.Context, l net.Listener) error {
}() }()
select { select {
case <-lnctx.Done(): case <-ctx.Done():
if srv.rejectingListeners() { return ctx.Err()
return ErrServerClosed
}
return lnctx.Err()
case err := <-errch: case err := <-errch:
return err return err
case <-srv.closed():
return ErrServerClosed
} }
} }
@ -274,17 +283,13 @@ func (srv *Server) serve(ctx context.Context, l net.Listener) error {
} }
} }
func (srv *Server) trackConn(conn *net.Conn, cancel context.CancelFunc) bool { func (srv *Server) trackConn(conn *net.Conn) {
if srv.rejectingConns() {
return false
}
srv.mu.Lock() srv.mu.Lock()
defer srv.mu.Unlock() defer srv.mu.Unlock()
if srv.conns == nil { if srv.conns == nil {
srv.conns = make(map[*net.Conn]context.CancelFunc) srv.conns = make(map[*net.Conn]struct{})
} }
srv.conns[conn] = cancel srv.conns[conn] = struct{}{}
return true
} }
func (srv *Server) deleteConn(conn *net.Conn) { func (srv *Server) deleteConn(conn *net.Conn) {
@ -295,16 +300,11 @@ func (srv *Server) deleteConn(conn *net.Conn) {
// ServeConn serves a Gemini response over the provided connection. // ServeConn serves a Gemini response over the provided connection.
// It closes the connection when the response has been completed. // It closes the connection when the response has been completed.
// Note that ServeConn will succeed even if a Shutdown is in progress. // ServeConn can be used even after Shutdown or Close have been called.
func (srv *Server) ServeConn(ctx context.Context, conn net.Conn) error { func (srv *Server) ServeConn(ctx context.Context, conn net.Conn) error {
defer conn.Close() defer conn.Close()
ctx, cancel := context.WithCancel(ctx) srv.trackConn(&conn)
defer cancel()
if !srv.trackConn(&conn, cancel) {
return ErrServerClosed
}
defer srv.tryFinishShutdown() defer srv.tryFinishShutdown()
defer srv.deleteConn(&conn) defer srv.deleteConn(&conn)
@ -322,9 +322,6 @@ func (srv *Server) ServeConn(ctx context.Context, conn net.Conn) error {
select { select {
case <-ctx.Done(): case <-ctx.Done():
if srv.rejectingConns() {
return ErrServerClosed
}
return ctx.Err() return ctx.Err()
case err := <-errch: case err := <-errch:
return err return err