diff --git a/server.go b/server.go index 8484eec..7574782 100644 --- a/server.go +++ b/server.go @@ -7,7 +7,6 @@ import ( "log" "net" "sync" - "sync/atomic" "time" ) @@ -49,7 +48,7 @@ type Server struct { listeners map[*net.Listener]context.CancelFunc conns map[*net.Conn]context.CancelFunc doneChan chan struct{} - closed int32 + closed bool mu sync.Mutex } @@ -67,6 +66,22 @@ func (srv *Server) doneLocked() chan struct{} { return srv.doneChan } +func (srv *Server) isClosed() bool { + srv.mu.Lock() + defer srv.mu.Unlock() + return srv.closed +} + +func (srv *Server) tryClose() bool { + srv.mu.Lock() + defer srv.mu.Unlock() + if srv.closed { + return false + } + srv.closed = true + return true +} + // tryFinishShutdown closes srv.done() if there are no active listeners or requests. func (srv *Server) tryFinishShutdown() { srv.mu.Lock() @@ -84,7 +99,7 @@ func (srv *Server) tryFinishShutdown() { // Close immediately closes all active net.Listeners and connections. // For a graceful shutdown, use Shutdown. func (srv *Server) Close() error { - if !atomic.CompareAndSwapInt32(&srv.closed, 0, 1) { + if !srv.tryClose() { return ErrServerClosed } @@ -118,7 +133,7 @@ func (srv *Server) Close() error { // Once Shutdown has been called on a server, it may not be reused; // future calls to methods such as Serve will return ErrServerClosed. func (srv *Server) Shutdown(ctx context.Context) error { - if !atomic.CompareAndSwapInt32(&srv.closed, 0, 1) { + if !srv.tryClose() { return ErrServerClosed } @@ -147,7 +162,7 @@ func (srv *Server) Shutdown(ctx context.Context) error { // ListenAndServe always returns a non-nil error. After Shutdown or Close, the // returned error is ErrServerClosed. func (srv *Server) ListenAndServe(ctx context.Context) error { - if atomic.LoadInt32(&srv.closed) == 1 { + if srv.isClosed() { return ErrServerClosed } @@ -176,13 +191,17 @@ func (srv *Server) getCertificate(h *tls.ClientHelloInfo) (*tls.Certificate, err return srv.GetCertificate(h.ServerName) } -func (srv *Server) trackListener(l *net.Listener, cancel context.CancelFunc) { +func (srv *Server) trackListener(l *net.Listener, cancel context.CancelFunc) bool { srv.mu.Lock() defer srv.mu.Unlock() + if srv.closed { + return false + } if srv.listeners == nil { srv.listeners = make(map[*net.Listener]context.CancelFunc) } srv.listeners[l] = cancel + return true } func (srv *Server) deleteListener(l *net.Listener) { @@ -200,14 +219,12 @@ func (srv *Server) deleteListener(l *net.Listener) { func (srv *Server) Serve(ctx context.Context, l net.Listener) error { defer l.Close() - if atomic.LoadInt32(&srv.closed) == 1 { - return ErrServerClosed - } - lnctx, cancel := context.WithCancel(ctx) defer cancel() - srv.trackListener(&l, cancel) + if !srv.trackListener(&l, cancel) { + return ErrServerClosed + } defer srv.tryFinishShutdown() defer srv.deleteListener(&l) @@ -218,7 +235,7 @@ func (srv *Server) Serve(ctx context.Context, l net.Listener) error { select { case <-lnctx.Done(): - if atomic.LoadInt32(&srv.closed) == 1 { + if srv.isClosed() { return ErrServerClosed } return lnctx.Err() @@ -228,21 +245,10 @@ func (srv *Server) Serve(ctx context.Context, l net.Listener) error { } func (srv *Server) serve(ctx context.Context, l net.Listener) error { - // how long to sleep on accept failure - var tempDelay time.Duration - + var tempDelay time.Duration // how long to sleep on accept failure for { rw, err := l.Accept() if err != nil { - select { - case <-ctx.Done(): - if atomic.LoadInt32(&srv.closed) == 1 { - return ErrServerClosed - } - return ctx.Err() - default: - } - // If this is a temporary error, sleep if ne, ok := err.(net.Error); ok && ne.Temporary() { if tempDelay == 0 { @@ -257,22 +263,24 @@ func (srv *Server) serve(ctx context.Context, l net.Listener) error { time.Sleep(tempDelay) continue } - return err } - tempDelay = 0 - go srv.serveConn(ctx, rw) + go srv.ServeConn(ctx, rw) } } -func (srv *Server) trackConn(conn *net.Conn, cancel context.CancelFunc) { +func (srv *Server) trackConn(conn *net.Conn, cancel context.CancelFunc) bool { srv.mu.Lock() defer srv.mu.Unlock() + if srv.closed { + return false + } if srv.conns == nil { srv.conns = make(map[*net.Conn]context.CancelFunc) } srv.conns[conn] = cancel + return true } func (srv *Server) deleteConn(conn *net.Conn) { @@ -281,19 +289,17 @@ func (srv *Server) deleteConn(conn *net.Conn) { delete(srv.conns, conn) } -// serveConn serves a Gemini response over the provided connection. +// ServeConn serves a Gemini response over the provided connection. // It closes the connection when the response has been completed. -func (srv *Server) serveConn(ctx context.Context, conn net.Conn) { +func (srv *Server) ServeConn(ctx context.Context, conn net.Conn) error { defer conn.Close() - if atomic.LoadInt32(&srv.closed) == 1 { - return - } - ctx, cancel := context.WithCancel(ctx) defer cancel() - srv.trackConn(&conn, cancel) + if !srv.trackConn(&conn, cancel) { + return ErrServerClosed + } defer srv.tryFinishShutdown() defer srv.deleteConn(&conn) @@ -304,26 +310,26 @@ func (srv *Server) serveConn(ctx context.Context, conn net.Conn) { conn.SetWriteDeadline(time.Now().Add(d)) } - done := make(chan struct{}) + errch := make(chan error, 1) go func() { - srv.respond(ctx, conn) - close(done) + errch <- srv.serveConn(ctx, conn) }() select { case <-ctx.Done(): - case <-done: + return ctx.Err() + case err := <-errch: + return err } } -func (srv *Server) respond(ctx context.Context, conn net.Conn) { +func (srv *Server) serveConn(ctx context.Context, conn net.Conn) error { w := newResponseWriter(conn) - defer w.Flush() req, err := ReadRequest(conn) if err != nil { w.WriteHeader(StatusBadRequest, "Bad request") - return + return w.Flush() } // Store the TLS connection state @@ -339,10 +345,11 @@ func (srv *Server) respond(ctx context.Context, conn net.Conn) { h := srv.Handler if h == nil { w.WriteHeader(StatusNotFound, "Not found") - return + return w.Flush() } h.ServeGemini(ctx, w, req) + return w.Flush() } func (srv *Server) logf(format string, args ...interface{}) {