diff --git a/cli/app.go b/cli/app.go index 415446a0..209e3cbd 100644 --- a/cli/app.go +++ b/cli/app.go @@ -25,6 +25,7 @@ type App struct { OnExit OnExit ContextConfig func(Context, context.Context) context.Context ContextOptions []ContextOption + AllowRoot bool } type Option func(*App) @@ -33,9 +34,10 @@ type ContextOption func(*Context) func NewApp(opts ...Option) *App { app := &App{ - Stdout: os.Stdout, - Stderr: os.Stderr, - OnExit: newOnExit(), + Stdout: os.Stdout, + Stderr: os.Stderr, + OnExit: newOnExit(), + AllowRoot: false, } for _, opt := range opts { opt(app) diff --git a/cli/cliviper/cliviper_test.go b/cli/cliviper/cliviper_test.go index 136ab344..748d5c8f 100644 --- a/cli/cliviper/cliviper_test.go +++ b/cli/cliviper/cliviper_test.go @@ -22,6 +22,7 @@ func TestCLIViperApp(t *testing.T) { // set cliviper.App() option app := cli.NewApp(cliviper.App()) + app.AllowRoot = true app.Flags = []flag.Flag{ flag.StringFlag{Name: msgFlag}, } diff --git a/cli/debugapp_test.go b/cli/debugapp_test.go index 38db56ba..fdef5051 100644 --- a/cli/debugapp_test.go +++ b/cli/debugapp_test.go @@ -50,6 +50,7 @@ func TestNewDebugApp(t *testing.T) { {err: testExitCoder{error: fmt.Errorf("foo"), exitCode: 2}, wantExitCode: 2, errorStringer: testErrorStringer, wantDebugFalse: "^foo\n$", wantDebugTrue: "^error-stringer\n$"}, } { app := cli.NewApp(cli.DebugHandler(currCase.errorStringer)) + app.AllowRoot = true app.Action = func(ctx cli.Context) error { return currCase.err } diff --git a/cli/exitcode_test.go b/cli/exitcode_test.go index 3b7a59f9..bfb04029 100644 --- a/cli/exitcode_test.go +++ b/cli/exitcode_test.go @@ -28,6 +28,7 @@ import ( func main() { app := cli.NewApp() + app.AllowRoot = true app.Action = func(ctx cli.Context) error { %v } diff --git a/cli/flagvalue_test.go b/cli/flagvalue_test.go index eb178534..e5300ffc 100644 --- a/cli/flagvalue_test.go +++ b/cli/flagvalue_test.go @@ -52,6 +52,7 @@ func TestBindFlagValues(t *testing.T) { }, } { app := cli.NewApp() + app.AllowRoot = true app.Command = cli.Command{ Name: "foo", Flags: []flag.Flag{ @@ -111,6 +112,7 @@ func TestBindFlagValuesStringParam(t *testing.T) { }, } { app := cli.NewApp() + app.AllowRoot = true app.Command = cli.Command{ Name: "foo", Flags: []flag.Flag{ diff --git a/cli/parse_test.go b/cli/parse_test.go index 674167fe..ebd5afc5 100644 --- a/cli/parse_test.go +++ b/cli/parse_test.go @@ -294,6 +294,7 @@ func TestParseFlags(t *testing.T) { t.Run(currCase.name, func(t *testing.T) { app := cli.NewApp() app.Name = "test" + app.AllowRoot = true output := &bytes.Buffer{} app.Subcommands = []cli.Command{ diff --git a/cli/run.go b/cli/run.go index 1fbd0a7b..0bd954a2 100644 --- a/cli/run.go +++ b/cli/run.go @@ -37,6 +37,13 @@ func (app *App) Run(args []string) (exitStatus int) { return 0 } + if !app.AllowRoot { + if syscall.Getuid() == 0 { + ctx.Errorf("%v\n", "Root is not allowed to run this program.") + return 1 + } + } + baseContext := context.Background() if app.ContextConfig != nil { baseContext = app.ContextConfig(ctx, baseContext) diff --git a/cli/run_test.go b/cli/run_test.go index 632bf394..c8a44230 100644 --- a/cli/run_test.go +++ b/cli/run_test.go @@ -34,6 +34,7 @@ func TestRunErrorOutput(t *testing.T) { for i, currCase := range cases { app := cli.NewApp() + app.AllowRoot = true app.Action = func(ctx cli.Context) error { return currCase.err } @@ -79,6 +80,7 @@ func TestRunErrorHandler(t *testing.T) { for i, currCase := range cases { app := cli.NewApp() + app.AllowRoot = true app.ErrorHandler = currCase.handler app.Action = func(ctx cli.Context) error { return currCase.err @@ -118,6 +120,7 @@ func TestRunContext(t *testing.T) { name: "check that context is propagated to app action", check: func(t *testing.T) { app := cli.NewApp() + app.AllowRoot = true app.Command.Flags = []flag.Flag{ flag.StringFlag{ @@ -136,6 +139,7 @@ func TestRunContext(t *testing.T) { name: "check that context is propagated to app error handler", check: func(t *testing.T) { app := cli.NewApp() + app.AllowRoot = true app.Command.Flags = []flag.Flag{ flag.StringFlag{ @@ -162,6 +166,7 @@ func TestRunContext(t *testing.T) { name: "check that context is propagated to app subcommand", check: func(t *testing.T) { app := cli.NewApp() + app.AllowRoot = true app.Command.Flags = []flag.Flag{ flag.StringFlag{