diff --git a/cmd/stepd/main.go b/cmd/stepd/main.go index b824b70..6e1b32d 100644 --- a/cmd/stepd/main.go +++ b/cmd/stepd/main.go @@ -6,7 +6,11 @@ import "log" import "time" import "slices" import "errors" +import "syscall" +import "os/user" import "context" +import "strings" +import "strconv" import "net/http" import "unicode/utf8" import "path/filepath" @@ -27,6 +31,10 @@ func main () { 'p', "pid-file", "Write the PID to the specified file", "", cli.ValString) + flagUser := cli.NewInputFlag ( + 'u', "user", + "The user:group to run as", + "", cli.ValString) flagLogDirectory := cli.NewInputFlag ( 'l', "log-directory", "Write logs to the specified directory", @@ -52,6 +60,7 @@ func main () { cmd := cli.New ( "Run an HTTP server that automaticaly executes STEP files", flagPidFile, + flagUser, flagLogDirectory, flagHTTPAddress, flagHTTPErrorDocument, @@ -98,11 +107,15 @@ func main () { if err != nil { log.Fatalln("XXX", err) } pidFile := daemon.PidFile(pidFileAbs) err = pidFile.Start() - if err != nil { log.Println("!!! could not write pid:", err) } - defer func () { - err := pidFile.Close() - if err != nil { log.Println("!!! could not delete pidfile:", err) } - } () + if err != nil { log.Fatalln("XXX could not write pid:", err) } + } + + // drop privelege + if flagUser.Value != "" { + log.Println("... dropping privelege to", flagUser.Value) + user, group, _ := strings.Cut(flagUser.Value, ":") + err := dropPrivelege(user, group) + if err != nil { log.Fatalln("XXX could not drop privelege:", err) } } // the single argument is for the directory to serve. we actually cd @@ -259,3 +272,23 @@ func logProviders (providers []step.Provider) { } line() } + +func dropPrivelege (usr, group string) error { + if group != "" { + groupInfo, err := user.LookupGroup(group) + if err != nil { return err } + gid, err := strconv.Atoi(groupInfo.Gid) + if err != nil { return err } + err = syscall.Setgid(gid) + if err != nil { return err } + } + if usr != "" { + usrInfo, err := user.Lookup(usr) + if err != nil { return err } + uid, err := strconv.Atoi(usrInfo.Uid) + if err != nil { return err } + err = syscall.Setuid(uid) + if err != nil { return err } + } + return nil +}