diff --git a/examples/board/main.go b/examples/board/main.go index 2384392..3f516bb 100644 --- a/examples/board/main.go +++ b/examples/board/main.go @@ -12,7 +12,7 @@ type Post struct { } type Board struct { - service.Service + *service.Service root string mux *http.ServeMux @@ -25,15 +25,11 @@ type Board struct { func main () { board := Board { root: "/board/" } board.mux = http.NewServeMux() - board.Service = service.Service { - &service.HTTP { - Mount: service.M ( - "Board", - "A board where you can post things.", - "@", board.root), - Handler: board.mux, - }, - } + board.Service = service.NewService ( + "Board", + "A board where you can post things.", + service.NewHTTP("@", board.root, board.mux)) + handle := func (pattern string, handler func (http.ResponseWriter, *http.Request)) { board.mux.HandleFunc(pattern, handler) } diff --git a/service/handle.go b/service/handle.go index ee935de..53bb1ca 100644 --- a/service/handle.go +++ b/service/handle.go @@ -10,7 +10,7 @@ func (mount *HTTP) handle (request protocol.MessageHTTPRequest) { bodyReader := &bodyReader { id: request.ID, reader: mount.requests.readerFor(request.ID), - send: mount.send, + send: mount.conn.Send, close: func () { mount.requests.end(request.ID) }, } defer mount.requests.remove(request.ID) @@ -28,10 +28,10 @@ func (mount *HTTP) handle (request protocol.MessageHTTPRequest) { handler.ServeHTTP(&responseWriter { id: request.ID, header: make(http.Header), - send: mount.send, + send: mount.conn.Send, }, httpRequest) - mount.send(protocol.MessageHTTPBodyEnd { ID: request.ID }) + mount.conn.Send(protocol.MessageHTTPBodyEnd { ID: request.ID }) } type bodyReader struct { diff --git a/service/http.go b/service/http.go index e214a1d..31b0ecd 100644 --- a/service/http.go +++ b/service/http.go @@ -1,17 +1,13 @@ package service import "log" -import "net" -import "sync" -import "bufio" import "errors" import "net/http" import "hnakra/protocol" // HTTP is an https:// mount. type HTTP struct { - // Mount specifies the mount config to use for connecting to the router. - Mount MountConfig + MountInfo // AllowInsecure allows this mount to respond to plain-text HTTP // requests. You can get a TLS cert for free nowadays so there are very @@ -22,82 +18,71 @@ type HTTP struct { // http.DefaultServeMux is used. Handler http.Handler - conn net.Conn - connLock sync.Mutex - connReadWriter *bufio.ReadWriter + conn *Conn requests requestManager - idFactory *protocol.IDFactory } // Close closes the mount abruptly, interrupting any active connections. -func (mount *HTTP) Close () error { - mount.connLock.Lock() - defer mount.connLock.Unlock() - return mount.conn.Close() +func (htmount *HTTP) Close () error { + return htmount.conn.Close() } // Shutdown gracefully shuts down the service without interrupting any active // connections. -func (mount *HTTP) Shutdown () error { +func (htmount *HTTP) Shutdown () error { // TODO - return mount.Close() + return htmount.Close() } // Run connects to the router, and blocks while fulfilling requests. This method // will only return when the connection to the router has been closed. -func (mount *HTTP) Run () (err error) { - if mount.AllowInsecure { - mount.Mount.Scheme = "http" +func (htmount *HTTP) Run (service ServiceInfo) (err error) { + if htmount.AllowInsecure { + htmount.MountInfo.Scheme = "http" } else { - mount.Mount.Scheme = "https" + htmount.MountInfo.Scheme = "https" } - mount.conn, mount.idFactory, err = mount.Mount.Connect() + htmount.conn, err = Dial(htmount.MountInfo, service) if err != nil { return } - mount.connReadWriter = bufio.NewReadWriter ( - bufio.NewReader(mount.conn), - bufio.NewWriter(mount.conn)) - mount.requests.init() + htmount.requests.init() for { - message, err := protocol.ReadMessage(mount.connReadWriter) + message, err := htmount.conn.Receive() if err != nil { return err } switch message.(type) { case protocol.MessageHTTPRequest: request := message.(protocol.MessageHTTPRequest) - mount.requests.add(request.ID) - go mount.handle(request) + htmount.requests.add(request.ID) + go htmount.handle(request) case protocol.MessageHTTPBodySegment: segment := message.(protocol.MessageHTTPBodySegment) - mount.requests.feed(segment.ID, segment.Data) + htmount.requests.feed(segment.ID, segment.Data) case protocol.MessageHTTPBodyEnd: end := message.(protocol.MessageHTTPBodyEnd) - mount.requests.end(end.ID) + htmount.requests.end(end.ID) case protocol.MessageStatus: status := message.(protocol.MessageStatus) log.Println("router says:", status.Status) default: - mount.Close() + htmount.Close() return errors.New("router sent unknown type code") } } } -// NewHTTP creates a very basic https:// mount with the specified name and -// description. -func NewHTTP (name, description, host, path string) *HTTP { - return &HTTP { Mount: M(name, description, host, path) } -} - -func (mount *HTTP) send (message protocol.Message) (err error) { - mount.connLock.Lock() - defer mount.connLock.Unlock() - err = message.Send(mount.connReadWriter) - if err != nil { return } - return mount.connReadWriter.Flush() +// NewHTTP creates a new HTTPS mount that uses the specified handler. +func NewHTTP (host, path string, handler http.Handler) *HTTP { + return &HTTP { + MountInfo: MountInfo { + Host: host, + Path: path, + }, + Handler: handler, + } } diff --git a/service/mount.go b/service/mount.go index 11a7bcb..061c4ea 100644 --- a/service/mount.go +++ b/service/mount.go @@ -4,39 +4,33 @@ import "os" import "log" import "fmt" import "net" +import "sync" +import "bufio" import "errors" import "strings" import "crypto/tls" import "encoding/base64" import "hnakra/protocol" -// M creates a very basic MountConfig with the specified name and description. -func M (name, description, host, path string) MountConfig { - return MountConfig { - Host: host, - Path: path, - Name: name, - Description: description, - } -} - // Mount is an interface satisfied by all mount types. type Mount interface { + Run (ServiceInfo) error Close () error Shutdown () error - Run () error } -// MountConfig contains generic information common to all mounts. -type MountConfig struct { +// MountInfo contains information about a mount point. +type MountInfo struct { // Host specifies the host to mount on. If the host is left empty, it // will default to @ (meaning default/any host). The port is entirely up - // to the router. Maximum length for host portion: 255 bytes + // to the router. + // Maximum length: 255 bytes Host string // Scheme specifies the protocol to mount on. This will be automatically // set by specialized mount types, so setting it manually shouldn't be // needed. + // Maximum length: 255 bytes Scheme string // Path specifies the path to mount on. If the path ends with a /, then @@ -45,7 +39,37 @@ type MountConfig struct { // path exactly (when normalized). // Maximum length: 2^16-1 bytes Path string +} +// String returns a string representation of the mount. +func (mount *MountInfo) String () string { + return mount.Scheme + "://" + mount.Host + mount.Path +} + +// FillDefault fills most empty fields with a hard-coded default value. +func (mount *MountInfo) FillDefault () { + if mount.Host == "" { mount.Host = "@" } + if mount.Path == "" { mount.Scheme = "/" } +} + +// Fits returns an error if any data is too big to send over the connection. +func (mount *MountInfo) Fits () error { + switch { + case len(mount.Host) > 255: + return errors.New("host cannot be longer than 255 bytes") + case len(mount.Scheme) > 255: + return errors.New("scheme cannot be longer than 255 bytes") + case len(mount.Path) > int(protocol.MaxIntOfSize(2)): + return errors.New(fmt.Sprint ( + "mount point path cannot be longer than ", + protocol.MaxIntOfSize(2), " bytes")) + default: return nil + } +} + +// ServiceInfo contains information about the service as a whole, such as a +// human readable description and login credentials. +type ServiceInfo struct { // Router specifies the host:port of the router to connect to. This // defaults to $HNAKRA_ROUTER_HOST:$HNAKRA_ROUTER_PORT if left empty. // The default value of these environment variables (if not set) is @@ -84,18 +108,14 @@ type MountConfig struct { // set InsecureSkipVerify to false, consider instead setting the // environment variables $SSL_CERT_FILE or $SSL_CERT_DIR to point toward // a custom root certificate. - TLSConfig *tls.Config + TLSConfig *tls.Config } -// Connect creates a new connection to the router specified in the MountConfig. -func (mount *MountConfig) Connect () ( - conn net.Conn, - idFactory *protocol.IDFactory, - err error, -) { - log.Println("(i) service", mount.Name) - idFactory = protocol.NewServiceIDFactory() - +// FillDefault fills most empty fields with values from environment variables. +// If an environment variable is blank, it uses a hard-coded default value +// instead. +func (service *ServiceInfo) FillDefault () (err error) { + // host defaultRouterHost := os.Getenv("HNAKRA_ROUTER_HOST") if defaultRouterHost == "" { defaultRouterHost = "localhost" @@ -104,111 +124,136 @@ func (mount *MountConfig) Connect () ( if defaultRouterPort == "" { defaultRouterPort = "2048" } - - // parse router host/port - routerHost, routerPort, _ := strings.Cut(mount.Router, ":") + routerHost, routerPort, _ := strings.Cut(service.Router, ":") if routerHost == "" { routerHost = defaultRouterHost } if routerPort == "" { routerPort = defaultRouterPort } + service.Router = routerHost + ":" + routerPort - // get mount point - scheme := mount.Scheme - host := mount.Host - if host == "" { - host = "@" - } - if len(host) > 255 { - return nil, nil, errors.New ( - "mount point host cannot be longer than 255 bytes") - } - path := mount.Path - if path == "" { - path = "/" - } - if len(path) > int(protocol.MaxIntOfSize(2)) { - return nil, nil, errors.New(fmt.Sprint ( - "mount point path cannot be longer than ", - protocol.MaxIntOfSize(2), " bytes")) + // user + if service.User == "" { + service.User = os.Getenv("HNAKRA_USER") } - // get user - user := mount.User - if user == "" { - user = os.Getenv("HNAKRA_USER") - } - if len(user) > 255 { - return nil, nil, errors.New ( - "user cannot be longer than 255 bytes") - } - - // get key - key := mount.Key - if key == nil { + // key + if service.Key == nil { base64Key := os.Getenv("HNAKRA_KEY") - key, err = base64.StdEncoding.DecodeString(base64Key) - if err != nil { return nil, nil, err } - } - if len(key) > 255 { - return nil, nil, errors.New ( - "key cannot be longer than 255 bytes") + service.Key, err = base64.StdEncoding.DecodeString(base64Key) + if err != nil { return } } - // ensure name/description aren't too big - if len(mount.Name) > 255 { - return nil, nil, errors.New ( - "service name cannot be longer than 255 bytes") + return +} + +// Fits returns an error if any data is too big to send over the connection. +func (service *ServiceInfo) Fits () (err error) { + switch { + case len(service.Name) > 255: + return errors.New("name cannot be longer than 255 bytes") + case len(service.Description) > 255: + return errors.New("description cannot be longer than 255 bytes") + case len(service.User) > 255: + return errors.New("user cannot be longer than 255 bytes") + case len(service.Key) > 255: + return errors.New("key cannot be longer than 255 bytes") + default: return nil } - if len(mount.Description) > 255 { - return nil, nil, errors.New ( - "service description cannot be longer than 255 bytes") +} + +// Conn represents a connection to a router. +type Conn struct { + IDFactory *protocol.IDFactory + + conn net.Conn + writeLock sync.Mutex + readWriter *bufio.ReadWriter +} + +// Dial connects to a router, returning the resulting connection. It handles +// performing the login sequence and sets ID(0) as active automatically. +func Dial (mount MountInfo, service ServiceInfo) (conn *Conn, err error) { + // fill in default values from env variables and such + mount.FillDefault() + err = service.FillDefault() + if err != nil { return nil, err } + + // sanity check + err = mount.Fits() + if err != nil { return nil, err } + err = service.Fits() + if err != nil { return nil, err } + + conn = &Conn { + IDFactory: protocol.NewServiceIDFactory(), } // connect to router - routerAddr := fmt.Sprint(routerHost, ":", routerPort) - log.Println("... dialing", routerAddr) - conn, err = tls.Dial("tcp", routerAddr, mount.TLSConfig) - if err != nil { return nil, nil, err } + log.Println("... dialing", service.Router) + conn.conn, err = tls.Dial("tcp", service.Router, service.TLSConfig) + if err != nil { return nil, err } + conn.readWriter = bufio.NewReadWriter ( + bufio.NewReader(conn.conn), + bufio.NewWriter(conn.conn)) // log in - log.Println ( - "... logging in as", user, - "on", scheme + "://" + host + path) - err = protocol.MessageLogin { - ID: idFactory.Next(), + log.Println("... logging in as", service.User, "on", mount) + err = conn.Send(protocol.MessageLogin { + ID: conn.IDFactory.Next(), Version: protocol.Version { Major: 0, Minor: 0 }, - User: user, - Key: key, - Name: mount.Name, - Description: mount.Description, - Scheme: scheme, - Host: host, - Path: path, - }.Send(conn) + User: service.User, + Key: service.Key, + Name: service.Name, + Description: service.Description, + Scheme: mount.Scheme, + Host: mount.Host, + Path: mount.Path, + }) if err != nil { conn.Close() - return nil, nil, err + return nil, err } // read status - message, err := protocol.ReadMessage(conn) + message, err := conn.Receive() if err != nil { conn.Close() - return nil, nil, err + return nil, err } status, ok := message.(protocol.MessageStatus) if !ok { conn.Close() - return nil, nil, errors.New(fmt.Sprint ( + return nil, errors.New(fmt.Sprint ( "router sent unknown type, expecting", protocol.TypeStatus)) } if status.Status != protocol.StatusOk { - return nil, nil, status + return nil, status } log.Println(".// logged in") - return conn, idFactory, nil + return conn, nil +} + +// Send sends a message along the connection, along with its type code. This +// method may be called concurrently. +func (conn *Conn) Send (message protocol.Message) (err error) { + conn.writeLock.Lock() + defer conn.writeLock.Unlock() + err = message.Send(conn.readWriter) + if err != nil { return } + return conn.readWriter.Flush() +} + +// Receive recieves a message from the connection. This method may not be called +// concurrently. +func (conn *Conn) Receive () (message protocol.Message, err error) { + return protocol.ReadMessage(conn.conn) +} + +// Close closes the connection. +func (conn *Conn) Close () error { + return conn.conn.Close() } diff --git a/service/service.go b/service/service.go index f246bef..ada29a2 100644 --- a/service/service.go +++ b/service/service.go @@ -9,12 +9,27 @@ import "hnakra/routines" // Service is capable of managing multiple mounts. It also sets up logging // automatically. -type Service []Mount +type Service struct { + ServiceInfo + Mounts []Mount +} + +// NewService provides a shorthand for creating a new service, leaving most +// values to their default. +func NewService (name, description string, mounts ...Mount) *Service { + return &Service { + ServiceInfo: ServiceInfo { + Name: name, + Description: description, + }, + Mounts: mounts, + } +} // Run runs the mounts within the service, and only exits when all of them have // exited. It will automatically start logging to the directory specified by // $HNAKRA_LOG_DIR. If that variable is unset, it will just log to stdout. -func (service Service) Run () error { +func (service *Service) Run () error { // set up logging logDir := os.Getenv("HNAKRA_LOG_DIR") if logDir != "" { @@ -22,12 +37,16 @@ func (service Service) Run () error { if err != nil { log.Fatal("cannot access log dir:", err) } log.SetOutput(logger) } + + log.Println("... starting service", service.Name) // set up routine manager manager := routines.Manager { RestartDeadline: time.Second * 8 } - manager.Routines = make([]routines.Routine, len(service)) - for index, mount := range service { - manager.Routines[index] = mount.Run + manager.Routines = make([]routines.Routine, len(service.Mounts)) + for index, mount := range service.Mounts { + manager.Routines[index] = func () error { + return mount.Run(service.ServiceInfo) + } } // send it @@ -38,8 +57,8 @@ func (service Service) Run () error { // Close abruptly closes all mounts in the service. This will cause Run() to // exit. -func (service Service) Close () (err error) { - for _, mount := range service { +func (service *Service) Close () (err error) { + for _, mount := range service.Mounts { singleErr := mount.Close() if singleErr != nil { err = singleErr @@ -50,8 +69,8 @@ func (service Service) Close () (err error) { // Shutdown gracefully shuts down each mount in the service. This will cause // Run() to exit. -func (service Service) Shutdown () (err error) { - for _, mount := range service { +func (service *Service) Shutdown () (err error) { + for _, mount := range service.Mounts { singleErr := mount.Shutdown() if singleErr != nil { err = singleErr