cmd/stepd: Add privelege dropping

This commit is contained in:
Sasha Koshka 2024-12-11 02:00:24 -05:00
parent d24cf1cdc1
commit f8bbf7f3fc

View File

@ -6,7 +6,11 @@ import "log"
import "time"
import "slices"
import "errors"
import "syscall"
import "os/user"
import "context"
import "strings"
import "strconv"
import "net/http"
import "unicode/utf8"
import "path/filepath"
@ -27,6 +31,10 @@ func main () {
'p', "pid-file",
"Write the PID to the specified file",
"", cli.ValString)
flagUser := cli.NewInputFlag (
'u', "user",
"The user:group to run as",
"", cli.ValString)
flagLogDirectory := cli.NewInputFlag (
'l', "log-directory",
"Write logs to the specified directory",
@ -52,6 +60,7 @@ func main () {
cmd := cli.New (
"Run an HTTP server that automaticaly executes STEP files",
flagPidFile,
flagUser,
flagLogDirectory,
flagHTTPAddress,
flagHTTPErrorDocument,
@ -98,11 +107,15 @@ func main () {
if err != nil { log.Fatalln("XXX", err) }
pidFile := daemon.PidFile(pidFileAbs)
err = pidFile.Start()
if err != nil { log.Println("!!! could not write pid:", err) }
defer func () {
err := pidFile.Close()
if err != nil { log.Println("!!! could not delete pidfile:", err) }
} ()
if err != nil { log.Fatalln("XXX could not write pid:", err) }
}
// drop privelege
if flagUser.Value != "" {
log.Println("... dropping privelege to", flagUser.Value)
user, group, _ := strings.Cut(flagUser.Value, ":")
err := dropPrivelege(user, group)
if err != nil { log.Fatalln("XXX could not drop privelege:", err) }
}
// the single argument is for the directory to serve. we actually cd
@ -259,3 +272,23 @@ func logProviders (providers []step.Provider) {
}
line()
}
func dropPrivelege (usr, group string) error {
if group != "" {
groupInfo, err := user.LookupGroup(group)
if err != nil { return err }
gid, err := strconv.Atoi(groupInfo.Gid)
if err != nil { return err }
err = syscall.Setgid(gid)
if err != nil { return err }
}
if usr != "" {
usrInfo, err := user.Lookup(usr)
if err != nil { return err }
uid, err := strconv.Atoi(usrInfo.Uid)
if err != nil { return err }
err = syscall.Setuid(uid)
if err != nil { return err }
}
return nil
}