server: Export ServeConn method

This commit is contained in:
Adnan Maolood 2021-02-21 11:53:15 -05:00
parent bb444fb364
commit 49dac34aff

View File

@ -7,7 +7,6 @@ import (
"log" "log"
"net" "net"
"sync" "sync"
"sync/atomic"
"time" "time"
) )
@ -49,7 +48,7 @@ type Server struct {
listeners map[*net.Listener]context.CancelFunc listeners map[*net.Listener]context.CancelFunc
conns map[*net.Conn]context.CancelFunc conns map[*net.Conn]context.CancelFunc
doneChan chan struct{} doneChan chan struct{}
closed int32 closed bool
mu sync.Mutex mu sync.Mutex
} }
@ -67,6 +66,22 @@ func (srv *Server) doneLocked() chan struct{} {
return srv.doneChan return srv.doneChan
} }
func (srv *Server) isClosed() bool {
srv.mu.Lock()
defer srv.mu.Unlock()
return srv.closed
}
func (srv *Server) tryClose() bool {
srv.mu.Lock()
defer srv.mu.Unlock()
if srv.closed {
return false
}
srv.closed = true
return true
}
// tryFinishShutdown closes srv.done() if there are no active listeners or requests. // tryFinishShutdown closes srv.done() if there are no active listeners or requests.
func (srv *Server) tryFinishShutdown() { func (srv *Server) tryFinishShutdown() {
srv.mu.Lock() srv.mu.Lock()
@ -84,7 +99,7 @@ func (srv *Server) tryFinishShutdown() {
// Close immediately closes all active net.Listeners and connections. // Close immediately closes all active net.Listeners and connections.
// For a graceful shutdown, use Shutdown. // For a graceful shutdown, use Shutdown.
func (srv *Server) Close() error { func (srv *Server) Close() error {
if !atomic.CompareAndSwapInt32(&srv.closed, 0, 1) { if !srv.tryClose() {
return ErrServerClosed return ErrServerClosed
} }
@ -118,7 +133,7 @@ func (srv *Server) Close() error {
// Once Shutdown has been called on a server, it may not be reused; // Once Shutdown has been called on a server, it may not be reused;
// future calls to methods such as Serve will return ErrServerClosed. // future calls to methods such as Serve will return ErrServerClosed.
func (srv *Server) Shutdown(ctx context.Context) error { func (srv *Server) Shutdown(ctx context.Context) error {
if !atomic.CompareAndSwapInt32(&srv.closed, 0, 1) { if !srv.tryClose() {
return ErrServerClosed return ErrServerClosed
} }
@ -147,7 +162,7 @@ func (srv *Server) Shutdown(ctx context.Context) error {
// ListenAndServe always returns a non-nil error. After Shutdown or Close, the // ListenAndServe always returns a non-nil error. After Shutdown or Close, the
// returned error is ErrServerClosed. // returned error is ErrServerClosed.
func (srv *Server) ListenAndServe(ctx context.Context) error { func (srv *Server) ListenAndServe(ctx context.Context) error {
if atomic.LoadInt32(&srv.closed) == 1 { if srv.isClosed() {
return ErrServerClosed return ErrServerClosed
} }
@ -176,13 +191,17 @@ func (srv *Server) getCertificate(h *tls.ClientHelloInfo) (*tls.Certificate, err
return srv.GetCertificate(h.ServerName) return srv.GetCertificate(h.ServerName)
} }
func (srv *Server) trackListener(l *net.Listener, cancel context.CancelFunc) { func (srv *Server) trackListener(l *net.Listener, cancel context.CancelFunc) bool {
srv.mu.Lock() srv.mu.Lock()
defer srv.mu.Unlock() defer srv.mu.Unlock()
if srv.closed {
return false
}
if srv.listeners == nil { if srv.listeners == nil {
srv.listeners = make(map[*net.Listener]context.CancelFunc) srv.listeners = make(map[*net.Listener]context.CancelFunc)
} }
srv.listeners[l] = cancel srv.listeners[l] = cancel
return true
} }
func (srv *Server) deleteListener(l *net.Listener) { func (srv *Server) deleteListener(l *net.Listener) {
@ -200,14 +219,12 @@ func (srv *Server) deleteListener(l *net.Listener) {
func (srv *Server) Serve(ctx context.Context, l net.Listener) error { func (srv *Server) Serve(ctx context.Context, l net.Listener) error {
defer l.Close() defer l.Close()
if atomic.LoadInt32(&srv.closed) == 1 {
return ErrServerClosed
}
lnctx, cancel := context.WithCancel(ctx) lnctx, cancel := context.WithCancel(ctx)
defer cancel() defer cancel()
srv.trackListener(&l, cancel) if !srv.trackListener(&l, cancel) {
return ErrServerClosed
}
defer srv.tryFinishShutdown() defer srv.tryFinishShutdown()
defer srv.deleteListener(&l) defer srv.deleteListener(&l)
@ -218,7 +235,7 @@ func (srv *Server) Serve(ctx context.Context, l net.Listener) error {
select { select {
case <-lnctx.Done(): case <-lnctx.Done():
if atomic.LoadInt32(&srv.closed) == 1 { if srv.isClosed() {
return ErrServerClosed return ErrServerClosed
} }
return lnctx.Err() return lnctx.Err()
@ -228,21 +245,10 @@ func (srv *Server) Serve(ctx context.Context, l net.Listener) error {
} }
func (srv *Server) serve(ctx context.Context, l net.Listener) error { func (srv *Server) serve(ctx context.Context, l net.Listener) error {
// how long to sleep on accept failure var tempDelay time.Duration // how long to sleep on accept failure
var tempDelay time.Duration
for { for {
rw, err := l.Accept() rw, err := l.Accept()
if err != nil { if err != nil {
select {
case <-ctx.Done():
if atomic.LoadInt32(&srv.closed) == 1 {
return ErrServerClosed
}
return ctx.Err()
default:
}
// If this is a temporary error, sleep // If this is a temporary error, sleep
if ne, ok := err.(net.Error); ok && ne.Temporary() { if ne, ok := err.(net.Error); ok && ne.Temporary() {
if tempDelay == 0 { if tempDelay == 0 {
@ -257,22 +263,24 @@ func (srv *Server) serve(ctx context.Context, l net.Listener) error {
time.Sleep(tempDelay) time.Sleep(tempDelay)
continue continue
} }
return err return err
} }
tempDelay = 0 tempDelay = 0
go srv.serveConn(ctx, rw) go srv.ServeConn(ctx, rw)
} }
} }
func (srv *Server) trackConn(conn *net.Conn, cancel context.CancelFunc) { func (srv *Server) trackConn(conn *net.Conn, cancel context.CancelFunc) bool {
srv.mu.Lock() srv.mu.Lock()
defer srv.mu.Unlock() defer srv.mu.Unlock()
if srv.closed {
return false
}
if srv.conns == nil { if srv.conns == nil {
srv.conns = make(map[*net.Conn]context.CancelFunc) srv.conns = make(map[*net.Conn]context.CancelFunc)
} }
srv.conns[conn] = cancel srv.conns[conn] = cancel
return true
} }
func (srv *Server) deleteConn(conn *net.Conn) { func (srv *Server) deleteConn(conn *net.Conn) {
@ -281,19 +289,17 @@ func (srv *Server) deleteConn(conn *net.Conn) {
delete(srv.conns, conn) delete(srv.conns, conn)
} }
// serveConn serves a Gemini response over the provided connection. // ServeConn serves a Gemini response over the provided connection.
// It closes the connection when the response has been completed. // It closes the connection when the response has been completed.
func (srv *Server) serveConn(ctx context.Context, conn net.Conn) { func (srv *Server) ServeConn(ctx context.Context, conn net.Conn) error {
defer conn.Close() defer conn.Close()
if atomic.LoadInt32(&srv.closed) == 1 {
return
}
ctx, cancel := context.WithCancel(ctx) ctx, cancel := context.WithCancel(ctx)
defer cancel() defer cancel()
srv.trackConn(&conn, cancel) if !srv.trackConn(&conn, cancel) {
return ErrServerClosed
}
defer srv.tryFinishShutdown() defer srv.tryFinishShutdown()
defer srv.deleteConn(&conn) defer srv.deleteConn(&conn)
@ -304,26 +310,26 @@ func (srv *Server) serveConn(ctx context.Context, conn net.Conn) {
conn.SetWriteDeadline(time.Now().Add(d)) conn.SetWriteDeadline(time.Now().Add(d))
} }
done := make(chan struct{}) errch := make(chan error, 1)
go func() { go func() {
srv.respond(ctx, conn) errch <- srv.serveConn(ctx, conn)
close(done)
}() }()
select { select {
case <-ctx.Done(): case <-ctx.Done():
case <-done: return ctx.Err()
case err := <-errch:
return err
} }
} }
func (srv *Server) respond(ctx context.Context, conn net.Conn) { func (srv *Server) serveConn(ctx context.Context, conn net.Conn) error {
w := newResponseWriter(conn) w := newResponseWriter(conn)
defer w.Flush()
req, err := ReadRequest(conn) req, err := ReadRequest(conn)
if err != nil { if err != nil {
w.WriteHeader(StatusBadRequest, "Bad request") w.WriteHeader(StatusBadRequest, "Bad request")
return return w.Flush()
} }
// Store the TLS connection state // Store the TLS connection state
@ -339,10 +345,11 @@ func (srv *Server) respond(ctx context.Context, conn net.Conn) {
h := srv.Handler h := srv.Handler
if h == nil { if h == nil {
w.WriteHeader(StatusNotFound, "Not found") w.WriteHeader(StatusNotFound, "Not found")
return return w.Flush()
} }
h.ServeGemini(ctx, w, req) h.ServeGemini(ctx, w, req)
return w.Flush()
} }
func (srv *Server) logf(format string, args ...interface{}) { func (srv *Server) logf(format string, args ...interface{}) {