diff --git a/internal/cmd/beta/session_cmd.go b/internal/cmd/beta/session_cmd.go index fe6e2ec69..9a3802a2b 100644 --- a/internal/cmd/beta/session_cmd.go +++ b/internal/cmd/beta/session_cmd.go @@ -1,12 +1,15 @@ package beta import ( + "context" + "io" "os" "os/exec" "strconv" "github.com/pkg/errors" "github.com/spf13/cobra" + "go.uber.org/multierr" "go.uber.org/zap" "github.com/stateful/runme/v3/internal/command" @@ -30,49 +33,25 @@ All exported variables during the session will be available to the subsequent co ) error { defer logger.Sync() - envCollector, err := command.NewEnvCollectorFactory().Build() - if err != nil { - return errors.WithStack(err) - } - - cfg := &command.ProgramConfig{ - ProgramName: defaultShell(), - Mode: runnerv2.CommandMode_COMMAND_MODE_CLI, - Env: append([]string{"RUNME_SESSION=1"}, envCollector.ExtraEnv()...), - } - options := command.CommandOptions{ - NoShell: true, - Stdin: cmd.InOrStdin(), - Stdout: cmd.OutOrStdout(), - Stderr: cmd.ErrOrStderr(), - } - - program, err := cmdFactory.Build(cfg, options) - if err != nil { - return err - } - - err = program.Start(cmd.Context()) - if err != nil { - return err - } - - err = program.Wait() + envs, err := executeDefaultShellProgram( + cmd.Context(), + cmdFactory, + cmd.InOrStdin(), + cmd.OutOrStdout(), + cmd.ErrOrStderr(), + nil, + ) if err != nil { return err } - changed, _, err := envCollector.Diff() - if err != nil { - return errors.WithStack(err) - } - // TODO(adamb): currently, the collected env are printed out, // but they could be put in a session. if _, err := cmd.ErrOrStderr().Write([]byte("Collected env during the session:\n")); err != nil { return errors.WithStack(err) } - for _, env := range changed { + + for _, env := range envs { _, err := cmd.OutOrStdout().Write([]byte(env + "\n")) if err != nil { return errors.WithStack(err) @@ -90,6 +69,63 @@ All exported variables during the session will be available to the subsequent co return &cmd } +func executeDefaultShellProgram( + ctx context.Context, + commandFactory command.Factory, + stdin io.Reader, + stdout io.Writer, + stderr io.Writer, + additionalEnv []string, +) ([]string, error) { + envCollector, err := command.NewEnvCollectorFactory().Build() + if err != nil { + return nil, errors.WithStack(err) + } + + cfg := &command.ProgramConfig{ + ProgramName: defaultShell(), + Mode: runnerv2.CommandMode_COMMAND_MODE_CLI, + Env: append( + []string{command.CreateEnv(command.EnvCollectorSessionEnvName, "1")}, + append(envCollector.ExtraEnv(), additionalEnv...)..., + ), + } + options := command.CommandOptions{ + NoShell: true, + Stdin: stdin, + Stdout: stdout, + Stderr: stderr, + } + program, err := commandFactory.Build(cfg, options) + if err != nil { + return nil, err + } + + err = program.Start(ctx) + if err != nil { + return nil, err + } + + err = program.Wait() + if err != nil { + return nil, err + } + + changed, _, err := envCollector.Diff() + return changed, err +} + +func defaultShell() string { + shell := os.Getenv("SHELL") + if shell == "" { + shell, _ = exec.LookPath("bash") + } + if shell == "" { + shell = "/bin/sh" + } + return shell +} + func sessionSetupCmd() *cobra.Command { var debug bool @@ -104,9 +140,21 @@ func sessionSetupCmd() *cobra.Command { ) error { defer logger.Sync() - if val, err := strconv.ParseBool(os.Getenv(command.EnvCollectorSessionEnvName)); err != nil || !val { + out := cmd.OutOrStdout() + + if err := requireEnvs( + command.EnvCollectorSessionEnvName, + command.EnvCollectorSessionPrePathEnvName, + command.EnvCollectorSessionPostPathEnvName, + ); err != nil { + logger.Info("session setup is skipped because the environment variable is not set", zap.Error(err)) + return writeNoopShellCommand(out) + } + + sessionSetupEnabled := os.Getenv(command.EnvCollectorSessionEnvName) + if val, err := strconv.ParseBool(sessionSetupEnabled); err != nil || !val { logger.Debug("session setup is skipped", zap.Error(err), zap.Bool("value", val)) - return nil + return writeNoopShellCommand(out) } envSetter := command.NewScriptEnvSetter( @@ -114,9 +162,15 @@ func sessionSetupCmd() *cobra.Command { os.Getenv(command.EnvCollectorSessionPostPathEnvName), debug, ) + if err := envSetter.SetOnShell(out); err != nil { + return err + } + + if _, err := cmd.ErrOrStderr().Write([]byte("Runme session active. When you're done, execute \"exit\".\n")); err != nil { + return errors.WithStack(err) + } - err := envSetter.SetOnShell(cmd.OutOrStdout()) - return errors.WithStack(err) + return nil }, ) }, @@ -127,13 +181,17 @@ func sessionSetupCmd() *cobra.Command { return &cmd } -func defaultShell() string { - shell := os.Getenv("SHELL") - if shell == "" { - shell, _ = exec.LookPath("bash") - } - if shell == "" { - shell = "/bin/sh" +func requireEnvs(names ...string) error { + var err error + for _, name := range names { + if os.Getenv(name) == "" { + err = multierr.Append(err, errors.Errorf("environment variable %q is required", name)) + } } - return shell + return err +} + +func writeNoopShellCommand(w io.Writer) error { + _, err := w.Write([]byte(":")) + return errors.WithStack(err) } diff --git a/internal/command/command_test.go b/internal/command/command_test.go index ff26b535d..d95e3e273 100644 --- a/internal/command/command_test.go +++ b/internal/command/command_test.go @@ -12,13 +12,7 @@ import ( ) func init() { - // Switch from "runme env" to "env -0" for the tests. - // This is because the "runme" program is not available - // in the test environment. - // - // TODO(adamb): this can be changed. runme must be built - // in the test environment and put into the PATH. - SetEnvDumpCommand("env -0") + SetEnvDumpCommandForTesting() } func testExecuteCommand( diff --git a/internal/command/env_collector.go b/internal/command/env_collector.go index 733ae7eab..f637cabd8 100644 --- a/internal/command/env_collector.go +++ b/internal/command/env_collector.go @@ -31,10 +31,12 @@ var envDumpCommand = func() string { return strings.Join([]string{path, "env", "dump", "--insecure"}, " ") }() -// SetEnvDumpCommand overrides the default command that dumps the environment variables. +// SetEnvDumpCommandForTesting overrides the default command that dumps the environment variables. // It is and should be used only for testing purposes. -func SetEnvDumpCommand(cmd string) { - envDumpCommand = cmd +// TODO(adamb): this can be made obsolete. runme must be built +// in the test environment and put into the PATH. +func SetEnvDumpCommandForTesting() { + envDumpCommand = "env -0" // When overriding [envDumpCommand], we disable the encryption. // There is no reliable way at the moment to have encryption and // not control the dump command. diff --git a/internal/command/env_shell.go b/internal/command/env_shell.go index 9560713d9..5ffee8d87 100644 --- a/internal/command/env_shell.go +++ b/internal/command/env_shell.go @@ -2,6 +2,9 @@ package command import ( "io" + + "github.com/pkg/errors" + "go.uber.org/multierr" ) const StoreStdoutEnvName = "__" @@ -23,8 +26,8 @@ type ScriptEnvSetter struct { postPath string } -func NewScriptEnvSetter(prePath, postPath string, debug bool) *ScriptEnvSetter { - return &ScriptEnvSetter{ +func NewScriptEnvSetter(prePath, postPath string, debug bool) ScriptEnvSetter { + return ScriptEnvSetter{ debug: debug, dumpCommand: envDumpCommand, prePath: prePath, @@ -32,10 +35,23 @@ func NewScriptEnvSetter(prePath, postPath string, debug bool) *ScriptEnvSetter { } } -func (s *ScriptEnvSetter) SetOnShell(shell io.Writer) error { +func (s ScriptEnvSetter) SetOnShell(shell io.Writer) error { + if err := s.validate(); err != nil { + return err + } return setOnShell(shell, s.dumpCommand, false, true, s.debug, s.prePath, s.postPath) } +func (s ScriptEnvSetter) validate() (err error) { + if s.prePath == "" { + err = multierr.Append(err, errors.New("pre-path is required")) + } + if s.postPath == "" { + err = multierr.Append(err, errors.New("post-path is required")) + } + return +} + func setOnShell( shell io.Writer, dumpCommand string, diff --git a/internal/command/factory.go b/internal/command/factory.go index 672d4b8ec..002473150 100644 --- a/internal/command/factory.go +++ b/internal/command/factory.go @@ -82,6 +82,9 @@ func NewFactory(opts ...FactoryOption) Factory { for _, opt := range opts { opt(f) } + if f.logger == nil { + f.logger = zap.NewNop() + } return f } diff --git a/internal/runnerv2service/service_execute_test.go b/internal/runnerv2service/service_execute_test.go index e8fca0a4c..db24bc47a 100644 --- a/internal/runnerv2service/service_execute_test.go +++ b/internal/runnerv2service/service_execute_test.go @@ -29,7 +29,7 @@ import ( ) func init() { - command.SetEnvDumpCommand("env -0") + command.SetEnvDumpCommandForTesting() // Server uses autoconfig to get necessary dependencies. // One of them, implicit, is [config.Config]. With the default