diff --git a/cmd/pcert/create_test.go b/cmd/pcert/create_test.go index 24f2a78..4ff3b09 100644 --- a/cmd/pcert/create_test.go +++ b/cmd/pcert/create_test.go @@ -13,28 +13,32 @@ import ( "github.com/dvob/pcert" ) -func runCmd(args []string, env map[string]string) (io.WriteCloser, *bytes.Buffer, *bytes.Buffer, error) { +func runCmd(args []string, stdin io.Reader, env map[string]string) (*bytes.Buffer, *bytes.Buffer, error) { + stdout := &bytes.Buffer{} stderr := &bytes.Buffer{} - stdinReader, stdinWriter := io.Pipe() - cmd := newRootCmd() - cmd.SetArgs(args) - cmd.SetIn(stdinReader) - cmd.SetOut(stdout) - cmd.SetErr(stderr) - cmd = WithEnv(cmd, args, func(name string) (string, bool) { + + if stdin == nil { + stdin = bytes.NewReader(nil) + } + + code := run(args, stdin, stdout, stderr, func(key string) (string, bool) { if env == nil { return "", false } - val, ok := env[name] + val, ok := env[key] return val, ok }) - return stdinWriter, stdout, stderr, cmd.Execute() + if code != 0 { + return stdout, stderr, fmt.Errorf("execution failed. stderr='%s'", stderr.String()) + } + + return stdout, stderr, nil } func runAndLoad(args []string, env map[string]string) (*x509.Certificate, error) { - _, stdout, stderr, err := runCmd(args, env) + stdout, stderr, err := runCmd(args, nil, env) if err != nil { return nil, err } @@ -45,7 +49,7 @@ func runAndLoad(args []string, env map[string]string) (*x509.Certificate, error) cert, err := pcert.Parse(stdout.Bytes()) if err != nil { - return nil, fmt.Errorf("could not read certificate from standard output: %s", err) + return nil, fmt.Errorf("could not read certificate from standard output: %s. stdout='%s'", err, stdout.String()) } return cert, err @@ -215,10 +219,10 @@ func Test_create_not_before_with_expiry(t *testing.T) { func Test_create_output_parameter(t *testing.T) { defer os.Remove("tls.crt") defer os.Remove("tls.key") - _, _, _, err := runCmd([]string{ + _, _, err := runCmd([]string{ "create", "tls.crt", - }, nil) + }, nil, nil) if err != nil { t.Fatal(err) return diff --git a/cmd/pcert/main.go b/cmd/pcert/main.go index d1be64b..0294a28 100644 --- a/cmd/pcert/main.go +++ b/cmd/pcert/main.go @@ -2,6 +2,7 @@ package main import ( "fmt" + "io" "os" "github.com/spf13/cobra" @@ -20,11 +21,27 @@ var ( ) func main() { - err := WithEnv(newRootCmd(), os.Args[1:], os.LookupEnv).Execute() + code := run(os.Args[1:], os.Stdin, os.Stdout, os.Stderr, os.LookupEnv) + os.Exit(code) +} + +func run(args []string, stdin io.Reader, stdout io.Writer, stderr io.Writer, getEnv func(string) (string, bool)) int { + rootCmd := newRootCmd() + + rootCmd.SetOut(stdout) + rootCmd.SetErr(stderr) + rootCmd.SetIn(stdin) + + rootCmd = WithEnv(rootCmd, args, getEnv) + rootCmd.SetArgs(args) + + err := rootCmd.Execute() if err != nil { - fmt.Fprintln(os.Stderr, err) - os.Exit(1) + fmt.Fprintln(stderr, err) + return 1 } + + return 0 } func newRootCmd() *cobra.Command { diff --git a/cmd/pcert/request_test.go b/cmd/pcert/request_test.go index 58575a4..4d8215d 100644 --- a/cmd/pcert/request_test.go +++ b/cmd/pcert/request_test.go @@ -8,11 +8,11 @@ import ( func Test_request(t *testing.T) { name := "foo" - _, stdout, stderr, err := runCmd([]string{ + stdout, stderr, err := runCmd([]string{ "request", "--subject", "/CN=" + name, - }, nil) + }, nil, nil) if err != nil { t.Fatal(err) return