server: Use explicit context arguments
Replace the Server.Context field with explicit context.Context arguments to most Server functions.
This commit is contained in:
parent
0c8c945eba
commit
f6505ae4c4
307
server.go
307
server.go
@ -47,16 +47,98 @@ type Server struct {
|
|||||||
// If nil, logging is done via the log package's standard logger.
|
// If nil, logging is done via the log package's standard logger.
|
||||||
ErrorLog *log.Logger
|
ErrorLog *log.Logger
|
||||||
|
|
||||||
// Context is the base context to use.
|
listeners map[*net.Listener]context.CancelFunc
|
||||||
// If nil, context.Background is used.
|
conns map[*net.Conn]context.CancelFunc
|
||||||
Context context.Context
|
doneChan chan struct{}
|
||||||
|
closed int32
|
||||||
listeners map[*net.Listener]struct{}
|
|
||||||
conns map[*net.Conn]struct{}
|
|
||||||
done int32
|
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// done returns a channel that's closed when the server has finished closing.
|
||||||
|
func (srv *Server) done() chan struct{} {
|
||||||
|
srv.mu.Lock()
|
||||||
|
defer srv.mu.Unlock()
|
||||||
|
return srv.doneLocked()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (srv *Server) doneLocked() chan struct{} {
|
||||||
|
if srv.doneChan == nil {
|
||||||
|
srv.doneChan = make(chan struct{})
|
||||||
|
}
|
||||||
|
return srv.doneChan
|
||||||
|
}
|
||||||
|
|
||||||
|
// tryFinishShutdown closes srv.done() if there are no active listeners or requests.
|
||||||
|
func (srv *Server) tryFinishShutdown() {
|
||||||
|
srv.mu.Lock()
|
||||||
|
defer srv.mu.Unlock()
|
||||||
|
if len(srv.listeners) == 0 && len(srv.conns) == 0 {
|
||||||
|
done := srv.doneLocked()
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
default:
|
||||||
|
close(done)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close immediately closes all active net.Listeners and connections.
|
||||||
|
// For a graceful shutdown, use Shutdown.
|
||||||
|
func (srv *Server) Close() error {
|
||||||
|
if !atomic.CompareAndSwapInt32(&srv.closed, 0, 1) {
|
||||||
|
return ErrServerClosed
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close active listeners and connections.
|
||||||
|
srv.mu.Lock()
|
||||||
|
for _, cancel := range srv.listeners {
|
||||||
|
cancel()
|
||||||
|
}
|
||||||
|
for _, cancel := range srv.conns {
|
||||||
|
cancel()
|
||||||
|
}
|
||||||
|
srv.mu.Unlock()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-srv.done():
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Shutdown gracefully shuts down the server without interrupting any
|
||||||
|
// active connections. Shutdown works by first closing all open
|
||||||
|
// listeners and then waiting indefinitely for connections
|
||||||
|
// to close and then shut down.
|
||||||
|
// If the provided context expires before the shutdown is complete,
|
||||||
|
// Shutdown returns the context's error.
|
||||||
|
//
|
||||||
|
// When Shutdown is called, Serve and ListenAndServer immediately
|
||||||
|
// return ErrServerClosed. Make sure the program doesn't exit and
|
||||||
|
// waits instead for Shutdown to return.
|
||||||
|
//
|
||||||
|
// Once Shutdown has been called on a server, it may not be reused;
|
||||||
|
// future calls to methods such as Serve will return ErrServerClosed.
|
||||||
|
func (srv *Server) Shutdown(ctx context.Context) error {
|
||||||
|
if !atomic.CompareAndSwapInt32(&srv.closed, 0, 1) {
|
||||||
|
return ErrServerClosed
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close active listeners.
|
||||||
|
srv.mu.Lock()
|
||||||
|
for _, cancel := range srv.listeners {
|
||||||
|
cancel()
|
||||||
|
}
|
||||||
|
srv.mu.Unlock()
|
||||||
|
|
||||||
|
// Wait for active connections to finish.
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
case <-srv.done():
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// ListenAndServe listens for requests at the server's configured address.
|
// ListenAndServe listens for requests at the server's configured address.
|
||||||
// ListenAndServe listens on the TCP network address srv.Addr and then calls
|
// ListenAndServe listens on the TCP network address srv.Addr and then calls
|
||||||
// Serve to handle requests on incoming connections.
|
// Serve to handle requests on incoming connections.
|
||||||
@ -65,8 +147,8 @@ type Server struct {
|
|||||||
//
|
//
|
||||||
// ListenAndServe always returns a non-nil error. After Shutdown or Close, the
|
// ListenAndServe always returns a non-nil error. After Shutdown or Close, the
|
||||||
// returned error is ErrServerClosed.
|
// returned error is ErrServerClosed.
|
||||||
func (srv *Server) ListenAndServe() error {
|
func (srv *Server) ListenAndServe(ctx context.Context) error {
|
||||||
if atomic.LoadInt32(&srv.done) == 1 {
|
if atomic.LoadInt32(&srv.closed) == 1 {
|
||||||
return ErrServerClosed
|
return ErrServerClosed
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -75,26 +157,33 @@ func (srv *Server) ListenAndServe() error {
|
|||||||
addr = ":1965"
|
addr = ":1965"
|
||||||
}
|
}
|
||||||
|
|
||||||
ln, err := net.Listen("tcp", addr)
|
l, err := net.Listen("tcp", addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer ln.Close()
|
|
||||||
|
|
||||||
return srv.Serve(tls.NewListener(ln, &tls.Config{
|
l = tls.NewListener(l, &tls.Config{
|
||||||
ClientAuth: tls.RequestClientCert,
|
ClientAuth: tls.RequestClientCert,
|
||||||
MinVersion: tls.VersionTLS12,
|
MinVersion: tls.VersionTLS12,
|
||||||
GetCertificate: srv.getCertificate,
|
GetCertificate: srv.getCertificate,
|
||||||
}))
|
})
|
||||||
|
return srv.Serve(ctx, l)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (srv *Server) trackListener(l *net.Listener) {
|
func (srv *Server) getCertificate(h *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||||
|
if srv.GetCertificate == nil {
|
||||||
|
return nil, errors.New("gemini: GetCertificate is nil")
|
||||||
|
}
|
||||||
|
return srv.GetCertificate(h.ServerName)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (srv *Server) trackListener(l *net.Listener, cancel context.CancelFunc) {
|
||||||
srv.mu.Lock()
|
srv.mu.Lock()
|
||||||
defer srv.mu.Unlock()
|
defer srv.mu.Unlock()
|
||||||
if srv.listeners == nil {
|
if srv.listeners == nil {
|
||||||
srv.listeners = make(map[*net.Listener]struct{})
|
srv.listeners = make(map[*net.Listener]context.CancelFunc)
|
||||||
}
|
}
|
||||||
srv.listeners[l] = struct{}{}
|
srv.listeners[l] = cancel
|
||||||
}
|
}
|
||||||
|
|
||||||
func (srv *Server) deleteListener(l *net.Listener) {
|
func (srv *Server) deleteListener(l *net.Listener) {
|
||||||
@ -109,24 +198,46 @@ func (srv *Server) deleteListener(l *net.Listener) {
|
|||||||
//
|
//
|
||||||
// Serve always returns a non-nil error and closes l. After Shutdown or Close,
|
// Serve always returns a non-nil error and closes l. After Shutdown or Close,
|
||||||
// the returned error is ErrServerClosed.
|
// the returned error is ErrServerClosed.
|
||||||
func (srv *Server) Serve(l net.Listener) error {
|
func (srv *Server) Serve(ctx context.Context, l net.Listener) error {
|
||||||
defer l.Close()
|
defer l.Close()
|
||||||
|
|
||||||
srv.trackListener(&l)
|
if atomic.LoadInt32(&srv.closed) == 1 {
|
||||||
defer srv.deleteListener(&l)
|
|
||||||
|
|
||||||
if atomic.LoadInt32(&srv.done) == 1 {
|
|
||||||
return ErrServerClosed
|
return ErrServerClosed
|
||||||
}
|
}
|
||||||
|
|
||||||
var tempDelay time.Duration // how long to sleep on accept failure
|
ctx, cancel := context.WithCancel(ctx)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
srv.trackListener(&l, cancel)
|
||||||
|
defer srv.tryFinishShutdown()
|
||||||
|
defer srv.deleteListener(&l)
|
||||||
|
|
||||||
|
errch := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
errch <- srv.serve(ctx, l)
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
case err := <-errch:
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (srv *Server) serve(ctx context.Context, l net.Listener) error {
|
||||||
|
// how long to sleep on accept failure
|
||||||
|
var tempDelay time.Duration
|
||||||
|
|
||||||
for {
|
for {
|
||||||
rw, err := l.Accept()
|
rw, err := l.Accept()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if atomic.LoadInt32(&srv.done) == 1 {
|
select {
|
||||||
return ErrServerClosed
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
default:
|
||||||
}
|
}
|
||||||
|
|
||||||
// If this is a temporary error, sleep
|
// If this is a temporary error, sleep
|
||||||
if ne, ok := err.(net.Error); ok && ne.Temporary() {
|
if ne, ok := err.(net.Error); ok && ne.Temporary() {
|
||||||
if tempDelay == 0 {
|
if tempDelay == 0 {
|
||||||
@ -142,115 +253,21 @@ func (srv *Server) Serve(l net.Listener) error {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// Otherwise, return the error
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
tempDelay = 0
|
tempDelay = 0
|
||||||
go srv.respond(rw)
|
go srv.serveConn(ctx, rw)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (srv *Server) closeListenersLocked() error {
|
func (srv *Server) trackConn(conn *net.Conn, cancel context.CancelFunc) {
|
||||||
var err error
|
|
||||||
for ln := range srv.listeners {
|
|
||||||
if cerr := (*ln).Close(); cerr != nil && err == nil {
|
|
||||||
err = cerr
|
|
||||||
}
|
|
||||||
delete(srv.listeners, ln)
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close immediately closes all active net.Listeners and connections.
|
|
||||||
// For a graceful shutdown, use Shutdown.
|
|
||||||
//
|
|
||||||
// Close returns any error returned from closing the Server's
|
|
||||||
// underlying Listener(s).
|
|
||||||
func (srv *Server) Close() error {
|
|
||||||
srv.mu.Lock()
|
|
||||||
defer srv.mu.Unlock()
|
|
||||||
if !atomic.CompareAndSwapInt32(&srv.done, 0, 1) {
|
|
||||||
return ErrServerClosed
|
|
||||||
}
|
|
||||||
err := srv.closeListenersLocked()
|
|
||||||
|
|
||||||
// Close active connections
|
|
||||||
for conn := range srv.conns {
|
|
||||||
(*conn).Close()
|
|
||||||
delete(srv.conns, conn)
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (srv *Server) numConns() int {
|
|
||||||
srv.mu.Lock()
|
|
||||||
defer srv.mu.Unlock()
|
|
||||||
return len(srv.conns)
|
|
||||||
}
|
|
||||||
|
|
||||||
// shutdownPollInterval is how often we poll for quiescence
|
|
||||||
// during Server.Shutdown. This is lower during tests, to
|
|
||||||
// speed up tests.
|
|
||||||
// Ideally we could find a solution that doesn't involve polling,
|
|
||||||
// but which also doesn't have a high runtime cost (and doesn't
|
|
||||||
// involve any contentious mutexes), but that is left as an
|
|
||||||
// exercise for the reader.
|
|
||||||
var shutdownPollInterval = 500 * time.Millisecond
|
|
||||||
|
|
||||||
// Shutdown gracefully shuts down the server without interrupting any
|
|
||||||
// active connections. Shutdown works by first closing all open
|
|
||||||
// listeners and then waiting indefinitely for connections
|
|
||||||
// to close and then shut down.
|
|
||||||
// If the provided context expires before the shutdown is complete,
|
|
||||||
// Shutdown returns the context's error, otherwise it returns any
|
|
||||||
// error returned from closing the Server's underlying Listener(s).
|
|
||||||
//
|
|
||||||
// When Shutdown is called, Serve, ListenAndServe, and
|
|
||||||
// ListenAndServeTLS immediately return ErrServerClosed. Make sure the
|
|
||||||
// program doesn't exit and waits instead for Shutdown to return.
|
|
||||||
//
|
|
||||||
// Once Shutdown has been called on a server, it may not be reused;
|
|
||||||
// future calls to methods such as Serve will return ErrServerClosed.
|
|
||||||
func (srv *Server) Shutdown(ctx context.Context) error {
|
|
||||||
if !atomic.CompareAndSwapInt32(&srv.done, 0, 1) {
|
|
||||||
return ErrServerClosed
|
|
||||||
}
|
|
||||||
|
|
||||||
srv.mu.Lock()
|
|
||||||
err := srv.closeListenersLocked()
|
|
||||||
srv.mu.Unlock()
|
|
||||||
|
|
||||||
// Wait for active connections to close
|
|
||||||
ticker := time.NewTicker(shutdownPollInterval)
|
|
||||||
defer ticker.Stop()
|
|
||||||
for {
|
|
||||||
if srv.numConns() == 0 {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
return ctx.Err()
|
|
||||||
case <-ticker.C:
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (srv *Server) getCertificate(h *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
|
||||||
if srv.GetCertificate == nil {
|
|
||||||
return nil, errors.New("gemini: GetCertificate is nil")
|
|
||||||
}
|
|
||||||
return srv.GetCertificate(h.ServerName)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (srv *Server) trackConn(conn *net.Conn) {
|
|
||||||
srv.mu.Lock()
|
srv.mu.Lock()
|
||||||
defer srv.mu.Unlock()
|
defer srv.mu.Unlock()
|
||||||
if srv.conns == nil {
|
if srv.conns == nil {
|
||||||
srv.conns = make(map[*net.Conn]struct{})
|
srv.conns = make(map[*net.Conn]context.CancelFunc)
|
||||||
}
|
}
|
||||||
srv.conns[conn] = struct{}{}
|
srv.conns[conn] = cancel
|
||||||
}
|
}
|
||||||
|
|
||||||
func (srv *Server) deleteConn(conn *net.Conn) {
|
func (srv *Server) deleteConn(conn *net.Conn) {
|
||||||
@ -259,10 +276,22 @@ func (srv *Server) deleteConn(conn *net.Conn) {
|
|||||||
delete(srv.conns, conn)
|
delete(srv.conns, conn)
|
||||||
}
|
}
|
||||||
|
|
||||||
// respond responds to a connection.
|
// serveConn serves a Gemini response over the provided connection.
|
||||||
func (srv *Server) respond(conn net.Conn) {
|
// It closes the connection when the response has been completed.
|
||||||
|
func (srv *Server) serveConn(ctx context.Context, conn net.Conn) {
|
||||||
defer conn.Close()
|
defer conn.Close()
|
||||||
|
|
||||||
|
if atomic.LoadInt32(&srv.closed) == 1 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(ctx)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
srv.trackConn(&conn, cancel)
|
||||||
|
defer srv.tryFinishShutdown()
|
||||||
|
defer srv.deleteConn(&conn)
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
if err := recover(); err != nil && err != ErrAbortHandler {
|
if err := recover(); err != nil && err != ErrAbortHandler {
|
||||||
const size = 64 << 10
|
const size = 64 << 10
|
||||||
@ -272,9 +301,6 @@ func (srv *Server) respond(conn net.Conn) {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
srv.trackConn(&conn)
|
|
||||||
defer srv.deleteConn(&conn)
|
|
||||||
|
|
||||||
if d := srv.ReadTimeout; d != 0 {
|
if d := srv.ReadTimeout; d != 0 {
|
||||||
conn.SetReadDeadline(time.Now().Add(d))
|
conn.SetReadDeadline(time.Now().Add(d))
|
||||||
}
|
}
|
||||||
@ -282,16 +308,29 @@ func (srv *Server) respond(conn net.Conn) {
|
|||||||
conn.SetWriteDeadline(time.Now().Add(d))
|
conn.SetWriteDeadline(time.Now().Add(d))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
srv.respond(ctx, conn)
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
case <-done:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (srv *Server) respond(ctx context.Context, conn net.Conn) {
|
||||||
w := newResponseWriter(conn)
|
w := newResponseWriter(conn)
|
||||||
|
defer w.Flush()
|
||||||
|
|
||||||
req, err := ReadRequest(conn)
|
req, err := ReadRequest(conn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
w.WriteHeader(StatusBadRequest, "Bad request")
|
w.WriteHeader(StatusBadRequest, "Bad request")
|
||||||
w.Flush()
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Store information about the TLS connection
|
// Store the TLS connection state
|
||||||
if tlsConn, ok := conn.(*tls.Conn); ok {
|
if tlsConn, ok := conn.(*tls.Conn); ok {
|
||||||
state := tlsConn.ConnectionState()
|
state := tlsConn.ConnectionState()
|
||||||
req.TLS = &state
|
req.TLS = &state
|
||||||
@ -304,20 +343,10 @@ func (srv *Server) respond(conn net.Conn) {
|
|||||||
h := srv.Handler
|
h := srv.Handler
|
||||||
if h == nil {
|
if h == nil {
|
||||||
w.WriteHeader(StatusNotFound, "Not found")
|
w.WriteHeader(StatusNotFound, "Not found")
|
||||||
w.Flush()
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx := srv.context()
|
|
||||||
h.ServeGemini(ctx, w, req)
|
h.ServeGemini(ctx, w, req)
|
||||||
w.Flush()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (srv *Server) context() context.Context {
|
|
||||||
if srv.Context != nil {
|
|
||||||
return srv.Context
|
|
||||||
}
|
|
||||||
return context.Background()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (srv *Server) logf(format string, args ...interface{}) {
|
func (srv *Server) logf(format string, args ...interface{}) {
|
||||||
|
Loading…
Reference in New Issue
Block a user