diff --git a/cli/cli.go b/cli/cli.go index 43ef3f7..9574237 100644 --- a/cli/cli.go +++ b/cli/cli.go @@ -24,3 +24,14 @@ func Printf (format string, values ...any) { func ServiceUser (service string) string { return "hn-" + strings.ToLower(service) } + +// NeedRoot halts the program and displays an error if it is not being run as +// root. This should be called whenever an operation takes place that requires +// root privelages. +func NeedRoot() { + uid := os.Getuid() + if uid != 0 { + Sayf("this utility must be run as root") + os.Exit(1) + } +} diff --git a/cmd/hnctl/main.go b/cmd/hnctl/main.go new file mode 100644 index 0000000..5cda221 --- /dev/null +++ b/cmd/hnctl/main.go @@ -0,0 +1,151 @@ +package main + +import "os" +import "fmt" +import "time" +import "flag" +import "os/exec" +import "hnakra/cli" +import "path/filepath" +import "hnakra/cmd/hnctl/spawn" + +func main () { + flag.Usage = func () { + out := flag.CommandLine.Output() + fmt.Fprintf(out, "Usage of %s:\n", os.Args[0]) + fmt.Fprintf(out, " start\n") + fmt.Fprintf(out, " Start a service\n") + fmt.Fprintf(out, " stop\n") + fmt.Fprintf(out, " Stop a service\n") + fmt.Fprintf(out, " restart\n") + fmt.Fprintf(out, " Start and then stop a service\n") + os.Exit(1) + } + + // define commands + startCommand := flag.NewFlagSet("start", flag.ExitOnError) + startService := startCommand.String("s", "router", "Service to start") + + stopCommand := flag.NewFlagSet("stop", flag.ExitOnError) + stopService := stopCommand.String("s", "router", "Service to stop") + + restartCommand := flag.NewFlagSet("restart", flag.ExitOnError) + restartService := stopCommand.String("s", "router", "Service to restart") + + flag.Parse() + + // execute correct command + if len(os.Args) < 2 { + flag.Usage() + os.Exit(1) + } + subCommandArgs := os.Args[2:] + switch os.Args[1] { + case "start": + startCommand.Parse(subCommandArgs) + execStart(*startService) + case "stop": + stopCommand.Parse(subCommandArgs) + execStop(*stopService) + case "restart": + restartCommand.Parse(subCommandArgs) + execStop(*restartService) + execStart(*restartService) + } +} + +func execStart (service string) { + fullName := cli.ServiceUser(service) + cli.NeedRoot() + + pid, err := spawn.PidOf(fullName) + if err == nil && spawn.Running(pid) { + cli.Sayf("service is already running") + return + } + + uid, gid, err := spawn.LookupUID(fullName) + if err != nil { + cli.Sayf("cannot start service: %v", err) + os.Exit(1) + } + + path, err := exec.LookPath(fullName) + if err != nil { + cli.Sayf("cannot start service: %v", err) + os.Exit(1) + } + + logDir := filepath.Join("/var/log/", fullName) + env := append(os.Environ(), "HNAKRA_LOG_DIR=" + logDir) + err = ensureLogDir(logDir, int(uid), int(gid)) + if err != nil { + cli.Sayf("cannot start service: %v", err) + os.Exit(1) + } + + // prepare pidfile. the service will be responsible for actually writing + // to it + err = ensurePidFile(spawn.PidFile(fullName), int(uid), int(gid)) + if err != nil { + cli.Sayf("cannot start service: %v", err) + os.Exit(1) + } + + // spawn the service + pid, err = spawn.Spawn(path, uid, gid, env) + if err != nil { + cli.Sayf("cannot start service: %v", err) + os.Exit(1) + } + + fmt.Println(pid) +} + +func execStop (service string) { + fullName := cli.ServiceUser(service) + cli.NeedRoot() + + pid, err := spawn.PidOf(fullName) + if err != nil || !spawn.Running(pid) { + cli.Sayf("service is not running") + return + } + + process, err := os.FindProcess(pid) + if err != nil { + cli.Sayf("service is not running") + return + } + + err = spawn.KillAndWait(process, 16 * time.Second) + if err != nil { + cli.Sayf("could not stop service: %v", err) + os.Exit(1) + } +} + +func ensureLogDir (directory string, uid, gid int) error { + err := os.MkdirAll(directory, 0755) + if err != nil { return err } + err = os.Chmod(directory, 0770) + if err != nil { return err } + err = os.Chown(directory, uid, gid) + if err != nil { return err } + + return nil +} + +func ensurePidFile (file string, uid, gid int) error { + pidFile, err := os.Create(file) + if err != nil { return err } + err = pidFile.Close() + if err != nil { return err } + + err = os.Chmod(file, 0660) + if err != nil { return err } + err = os.Chown(file, uid, gid) + if err != nil { return err } + + return nil +} diff --git a/cmd/hnctl/spawn/spawn.go b/cmd/hnctl/spawn/spawn.go new file mode 100644 index 0000000..558351e --- /dev/null +++ b/cmd/hnctl/spawn/spawn.go @@ -0,0 +1,132 @@ +// Package spawn provides utilities for daemonizing services. +package spawn + +import "os" +import "fmt" +import "time" +import "errors" +import "syscall" +import "os/user" +import "strconv" +import "path/filepath" + +// Spawn spawns a process in the background and returns its PID. +func Spawn (path string, uid, gid uint32, env []string, args ...string) (pid int, err error) { + cred := &syscall.Credential{ + Uid: uid, + Gid: gid, + Groups: []uint32{}, + NoSetGroups: false, + } + + // the Noctty flag is used to detach the process from parent tty + sysproc := &syscall.SysProcAttr{ + Credential: cred, + Noctty: true, + } + attr := os.ProcAttr{ + Dir: ".", + Env: os.Environ(), + Files: []*os.File{ + os.Stdin, + nil, + nil, + }, + Sys: sysproc, + } + + process, err := os.StartProcess(path, args, &attr) + if err != nil { + return 0, err + } + + // Release() is what actually detatches the process and places it under + // init + return process.Pid, process.Release() +} + +// LookupUID returns the uid and gid of the given username, if it exists. +func LookupUID (name string) (uid, gid uint32, err error) { + user, err := user.Lookup(name) + if err != nil { + return 0, 0, err + } + + puid, err := strconv.Atoi(user.Uid) + if err != nil { + return 0, 0, err + } + pgid, err := strconv.Atoi(user.Gid) + if err != nil { + return 0, 0, err + } + return uint32(puid), uint32(pgid), nil +} + +// PidFile returns the path of a pidfile under the specified name. More +// specifically, it returns `/run/.pid`. +func PidFile (name string) string { + return filepath.Join("/run/", name + ".pid") +} + +// PidOf returns the PID stored in the pidfile of the given name as defined by +// PidFile. +func PidOf (name string) (pid int, err error) { + content, err := os.ReadFile(PidFile(name)) + if err != nil { + return 0, err + } + pid, err = strconv.Atoi(string(content)) + if err != nil { + return 0, err + } + return pid, nil +} + +// Running returns whether or not a process with the given PID is running. +func Running (pid int) bool { + directoryInfo, err := os.Stat("/proc/") + if os.IsNotExist(err) || !directoryInfo.IsDir() { + // if /proc/ does not exist, fallback to sending a signal + process, err := os.FindProcess(pid) + if err != nil { + return false + } + err = process.Signal(syscall.Signal(0)) + if err != nil { + return false + } + } else { + // if /proc/ exists, see if the process's directory exists there + _, err = os.Stat("/proc/" + strconv.Itoa(pid)) + if err != nil { + return false + } + } + + return true +} + +// KillAndWait kills a process and waits for it to finish, with a timeout. If +// the timeout is zero, it will wait indefinetly. This function will poll every +// 100 milliseconds to see if the process has finished. +func KillAndWait (process *os.Process, timeout time.Duration) error { + pid := process.Pid + err := process.Kill() + if err != nil { + return err + } + + // wait for the process to exit, with a timeout + timeoutPoint := time.Now() + for timeout == 0 || time.Since(timeoutPoint) < 16 * time.Second { + if !Running(pid) { + return nil + } + + time.Sleep(100 * time.Millisecond) + } + + return errors.New(fmt.Sprintf ( + "timeout exceeded while waiting for process %d to finish", pid)) +}