diff --git a/service/http.go b/service/http.go index 0b3d608..0d66b3a 100644 --- a/service/http.go +++ b/service/http.go @@ -10,7 +10,8 @@ import "hnakra/protocol" // HTTP is an https:// mount. type HTTP struct { - Mount + // Mount specifies the mount config to use for connecting to the router. + Mount MountConfig // AllowInsecure allows this mount to respond to plain-text HTTP // requests. You can get a TLS cert for free nowadays so there are very @@ -25,6 +26,7 @@ type HTTP struct { connLock sync.Mutex connReadWriter *bufio.ReadWriter requests requestManager + idFactory *protocol.IDFactory } // Close closes the mount abruptly, interrupting any active connections. @@ -45,10 +47,11 @@ func (mount *HTTP) Shutdown () error { // will only return when the connection to the router has been closed. func (mount *HTTP) Run () (err error) { if mount.AllowInsecure { - mount.conn, err = mount.connect("http") + mount.Mount.Scheme = "http" } else { - mount.conn, err = mount.connect("https") + mount.Mount.Scheme = "https" } + mount.conn, mount.idFactory, err = mount.Mount.Connect() if err != nil { return } mount.connReadWriter = bufio.NewReadWriter ( diff --git a/service/mount.go b/service/mount.go index e2bcef0..2548854 100644 --- a/service/mount.go +++ b/service/mount.go @@ -10,21 +10,33 @@ import "crypto/tls" import "encoding/base64" import "hnakra/protocol" -// M creates a very basic mount with the specified name and description. -func M (name, description string) Mount { - return Mount { +// M creates a very basic MountConfig with the specified name and description. +func M (name, description string) MountConfig { + return MountConfig { Name: name, Description: description, } } -// Mount contains generic information common to all mounts. -type Mount struct { +// Mount is an interface satisfied by all mount types. +type Mount interface { + Close () error + Shutdown () error + Run () error +} + +// MountConfig contains generic information common to all mounts. +type MountConfig struct { // Host specifies the host to mount on. If the host is left empty, it // will default to @ (meaning default/any host). The port is entirely up // to the router. Maximum length for host portion: 255 bytes Host string + // Scheme specifies the protocol to mount on. This will be automatically + // set by specialized mount types, so setting it manually shouldn't be + // needed. + Scheme string + // Path specifies the path to mount on. If the path ends with a /, then // all requests under the path will be sent to this service. If there is // no trailing /, this service will only recieve requests that match the @@ -68,13 +80,16 @@ type Mount struct { // TLSConfig is an optional TLS configuration. TLSConfig *tls.Config - - idFactory *protocol.IDFactory } -func (mount *Mount) connect (scheme string) (conn net.Conn, err error) { +// Connect creates a new connection to the router specified in the MountConfig. +func (mount *MountConfig) Connect () ( + conn net.Conn, + idFactory *protocol.IDFactory, + err error, +) { log.Println("(i) service", mount.Name) - mount.idFactory = protocol.NewServiceIDFactory() + idFactory = protocol.NewServiceIDFactory() defaultRouterHost := os.Getenv("HNAKRA_ROUTER_HOST") if defaultRouterHost == "" { @@ -95,19 +110,21 @@ func (mount *Mount) connect (scheme string) (conn net.Conn, err error) { } // get mount point - host := mount.Host + scheme := mount.Scheme + host := mount.Host if host == "" { host = "@" } if len(host) > 255 { - return nil, errors.New("mount point host cannot be longer than 255 bytes") + return nil, nil, errors.New ( + "mount point host cannot be longer than 255 bytes") } path := mount.Path if path == "" { path = "/" } if len(path) > int(protocol.MaxIntOfSize(2)) { - return nil, errors.New(fmt.Sprint ( + return nil, nil, errors.New(fmt.Sprint ( "mount point path cannot be longer than ", protocol.MaxIntOfSize(2), " bytes")) } @@ -118,39 +135,44 @@ func (mount *Mount) connect (scheme string) (conn net.Conn, err error) { user = os.Getenv("HNAKRA_USER") } if len(user) > 255 { - return nil, errors.New("user cannot be longer than 255 bytes") + return nil, nil, errors.New ( + "user cannot be longer than 255 bytes") } // get key key := mount.Key if key == nil { - key, err = base64.StdEncoding.DecodeString(os.Getenv("HNAKRA_KEY")) - if err != nil { return nil, err } + base64Key := os.Getenv("HNAKRA_KEY") + key, err = base64.StdEncoding.DecodeString(base64Key) + if err != nil { return nil, nil, err } } if len(key) > 255 { - return nil, errors.New("key cannot be longer than 255 bytes") + return nil, nil, errors.New ( + "key cannot be longer than 255 bytes") } // ensure name/description aren't too big if len(mount.Name) > 255 { - return nil, errors.New("service name cannot be longer than 255 bytes") + return nil, nil, errors.New ( + "service name cannot be longer than 255 bytes") } if len(mount.Description) > 255 { - return nil, errors.New("service description cannot be longer than 255 bytes") + return nil, nil, errors.New ( + "service description cannot be longer than 255 bytes") } // connect to router routerAddr := fmt.Sprint(routerHost, ":", routerPort) log.Println("... dialing", routerAddr) conn, err = tls.Dial("tcp", routerAddr, mount.TLSConfig) - if err != nil { return nil, err } + if err != nil { return nil, nil, err } // log in log.Println ( "... logging in as", user, "on", scheme + "://" + host + path) err = protocol.MessageLogin { - ID: mount.idFactory.Next(), + ID: idFactory.Next(), Version: protocol.Version { Major: 0, Minor: 0 }, User: user, Key: key, @@ -162,24 +184,26 @@ func (mount *Mount) connect (scheme string) (conn net.Conn, err error) { }.Send(conn) if err != nil { conn.Close() - return nil, err + return nil, nil, err } // read status message, err := protocol.ReadMessage(conn) if err != nil { conn.Close() - return nil, err + return nil, nil, err } status, ok := message.(protocol.MessageStatus) if !ok { conn.Close() - return nil, errors.New(fmt.Sprint("router sent unknown type, expecting", protocol.TypeStatus)) + return nil, nil, errors.New(fmt.Sprint ( + "router sent unknown type, expecting", + protocol.TypeStatus)) } if status.Status != protocol.StatusOk { - return nil, status + return nil, nil, status } log.Println(".// logged in") - return conn, nil + return conn, idFactory, nil }