Skip to content

Commit

Permalink
feat: Add PreRunFunc and PostRunFunc
Browse files Browse the repository at this point in the history
  • Loading branch information
ginokent committed Aug 10, 2024
1 parent 8a56ed3 commit 72b5233
Show file tree
Hide file tree
Showing 7 changed files with 192 additions and 41 deletions.
4 changes: 2 additions & 2 deletions .golangci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,7 @@ linters:
- deadcode # deprecated (since v1.49.0) due to: The owner seems to have abandoned the linter. Replaced by unused.
- depguard # unnecessary
- dupl # too many unnecessary detections
- exhaustruct # https://github.com/GaijinEntertainment/go-exhaustruct
- exhaustivestruct # https://github.com/mbilski/exhaustivestruct
- exhaustruct # unnecessary
- gci # unnecessary
- goconst # unnecessary
- godot # unnecessary
Expand Down Expand Up @@ -61,6 +60,7 @@ issues:
- maintidx
- noctx
- revive
- staticcheck
- testpackage
- varnamelen
- wrapcheck
79 changes: 52 additions & 27 deletions exp/cli/command.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ type (
Command struct {
// Name is the name of the command.
Name string
// Short is the short name of the command.
Short string
// Aliases is the alias names of the command.
Aliases []string
// Usage is the usage of the command.
//
// If you want to use the default usage, remain empty.
Expand All @@ -46,8 +46,12 @@ type (
Description string
// Options is the options of the command.
Options []Option
// PreRunFunc is the function to be executed before RunFunc.
PreRunFunc func(ctx context.Context, cmd *Command, remainingArgs []string) error
// RunFunc is the function to be executed when (*Command).Run is executed.
RunFunc func(ctx context.Context, remainingArgs []string) error
RunFunc func(ctx context.Context, cmd *Command, remainingArgs []string) error
// PostRunFunc is the function to be executed after RunFunc.
PostRunFunc func(ctx context.Context, cmd *Command, remainingArgs []string) error
// SubCommands is the subcommands of the command.
SubCommands []*Command

Expand Down Expand Up @@ -91,8 +95,10 @@ func (cmd *Command) IsCommand(cmdName string) bool {
if cmd.Name == cmdName {
return true
}
if cmd.Short == cmdName {
return true
for _, alias := range cmd.Aliases {
if alias == cmdName {
return true
}
}
return false
}
Expand Down Expand Up @@ -153,15 +159,17 @@ func (cmd *Command) getSubcommand(arg string) (subcmd *Command) {
return nil
}

func equalOptionArg(o Option, arg string) bool {
// --long or -s
func argIsHyphenOption(o Option, arg string) bool {
return longOptionPrefix+o.GetName() == arg || shortOptionPrefix+o.GetShort() == arg
}

func hasPrefixOptionEqualArg(o Option, arg string) bool {
// --long=value or -s=value
func argIsHyphenOptionEqual(o Option, arg string) bool {
return strings.HasPrefix(arg, longOptionPrefix+o.GetName()+"=") || strings.HasPrefix(arg, shortOptionPrefix+o.GetShort()+"=")
}

func extractValueOptionEqualArg(arg string) string {
func extractValueFromHyphenOptionEqual(arg string) string {
return strings.Join(strings.Split(arg, "=")[1:], "=")
}

Expand Down Expand Up @@ -189,7 +197,7 @@ argsLoop:
switch o := opt.(type) {
case *StringOption:
switch {
case equalOptionArg(o, arg):
case argIsHyphenOption(o, arg):
DebugLog.Printf("%s: option: %s: %s", cmd.GetName(), o.Name, arg)
if hasOptionValue(args, i) {
return nil, errorz.Errorf("%s: %w", arg, ErrMissingOptionValue)
Expand All @@ -198,22 +206,22 @@ argsLoop:
i++
TraceLog.Printf("%s: parsed option: %s: %v", cmd.GetName(), o.Name, *o.value)
continue argsLoop
case hasPrefixOptionEqualArg(o, arg):
case argIsHyphenOptionEqual(o, arg):
DebugLog.Printf("%s: option: %s: %s", cmd.GetName(), o.Name, arg)
o.value = ptr(extractValueOptionEqualArg(arg))
o.value = ptr(extractValueFromHyphenOptionEqual(arg))
TraceLog.Printf("%s: parsed option: %s: %v", cmd.GetName(), o.Name, *o.value)
continue argsLoop
}
case *BoolOption:
switch {
case equalOptionArg(o, arg):
case argIsHyphenOption(o, arg):
DebugLog.Printf("%s: option: %s: %s", cmd.GetName(), o.Name, arg)
o.value = ptr(true)
TraceLog.Printf("%s: parsed option: %s: %v", cmd.GetName(), o.Name, *o.value)
continue argsLoop
case hasPrefixOptionEqualArg(o, arg):
case argIsHyphenOptionEqual(o, arg):
DebugLog.Printf("%s: option: %s: %s", cmd.GetName(), o.Name, arg)
optVal, err := strconv.ParseBool(extractValueOptionEqualArg(arg))
optVal, err := strconv.ParseBool(extractValueFromHyphenOptionEqual(arg))
if err != nil {
return nil, errorz.Errorf("%s: %w", arg, err)
}
Expand All @@ -223,7 +231,7 @@ argsLoop:
}
case *IntOption:
switch {
case equalOptionArg(o, arg):
case argIsHyphenOption(o, arg):
DebugLog.Printf("%s: option: %s: %s", cmd.GetName(), o.Name, arg)
if hasOptionValue(args, i) {
return nil, errorz.Errorf("%s: %w", arg, ErrMissingOptionValue)
Expand All @@ -236,9 +244,9 @@ argsLoop:
i++
TraceLog.Printf("%s: parsed option: %s: %v", cmd.GetName(), o.Name, *o.value)
continue argsLoop
case hasPrefixOptionEqualArg(o, arg):
case argIsHyphenOptionEqual(o, arg):
DebugLog.Printf("%s: option: %s: %s", cmd.GetName(), o.Name, arg)
optVal, err := strconv.Atoi(extractValueOptionEqualArg(arg))
optVal, err := strconv.Atoi(extractValueFromHyphenOptionEqual(arg))
if err != nil {
return nil, errorz.Errorf("%s: %w", arg, err)
}
Expand All @@ -248,7 +256,7 @@ argsLoop:
}
case *Float64Option:
switch {
case equalOptionArg(o, arg):
case argIsHyphenOption(o, arg):
DebugLog.Printf("%s: option: %s: %s", cmd.GetName(), o.Name, arg)
if hasOptionValue(args, i) {
return nil, errorz.Errorf("%s: %w", arg, ErrMissingOptionValue)
Expand All @@ -261,9 +269,9 @@ argsLoop:
i++
TraceLog.Printf("%s: parsed option: %s: %v", cmd.GetName(), o.Name, *o.value)
continue argsLoop
case hasPrefixOptionEqualArg(o, arg):
case argIsHyphenOptionEqual(o, arg):
DebugLog.Printf("%s: option: %s: %s", cmd.GetName(), o.Name, arg)
optVal, err := strconv.ParseFloat(extractValueOptionEqualArg(arg), 64)
optVal, err := strconv.ParseFloat(extractValueFromHyphenOptionEqual(arg), 64)
if err != nil {
return nil, errorz.Errorf("%s: %w", arg, err)
}
Expand Down Expand Up @@ -316,9 +324,10 @@ func (cmd *Command) initCommand() {
// This function is idempotent. If the conditions are the same, the same result will be returned no matter how many times it is called.
//
//nolint:cyclop
func (cmd *Command) Parse(args []string) (remainingArgs []string, err error) {
if len(args) > 0 && (args[0] == os.Args[0] || cmd.IsCommand(args[0])) {
args = args[1:]
func (cmd *Command) Parse(osArgs []string) (remainingArgs []string, err error) {
cmdArgs := osArgs
if len(osArgs) > 0 && (osArgs[0] == os.Args[0] || cmd.IsCommand(osArgs[0])) {
cmdArgs = osArgs[1:]
}

cmd.initCommand()
Expand All @@ -340,7 +349,7 @@ func (cmd *Command) Parse(args []string) (remainingArgs []string, err error) {
return nil, errorz.Errorf("failed to load environment: %w", err)
}

remaining, err := cmd.parseArgs(args)
remaining, err := cmd.parseArgs(cmdArgs)
if err != nil {
return nil, errorz.Errorf("failed to parse commands and options: %w", err)
}
Expand All @@ -359,8 +368,8 @@ func (cmd *Command) Parse(args []string) (remainingArgs []string, err error) {
// Run executes (*Command).RunFunc of the specified command or subcommand.
//
// If you only want to parse the options, use Parse instead of this.
func (cmd *Command) Run(ctx context.Context, args []string) error {
remainingArgs, err := cmd.Parse(args)
func (cmd *Command) Run(ctx context.Context, osArgs []string) error {
remainingArgs, err := cmd.Parse(osArgs)
if err != nil {
return errorz.Errorf("%s: %w", cmd.GetName(), err)
}
Expand All @@ -370,9 +379,25 @@ func (cmd *Command) Run(ctx context.Context, args []string) error {
execCmd = execCmd.Next()
}

if execCmd.PreRunFunc != nil {
if err := execCmd.PreRunFunc(ctx, execCmd, remainingArgs); err != nil {
return errorz.Errorf("%s: PreRunFunc: %w", strings.Join(execCmd.calledCommands, " "), err)
}
}

if execCmd.RunFunc == nil {
return errorz.Errorf("%s: %w", strings.Join(execCmd.calledCommands, " "), ErrCommandFuncNotSet)
}

return execCmd.RunFunc(WithContext(ctx, cmd), remainingArgs)
if err := execCmd.RunFunc(ctx, execCmd, remainingArgs); err != nil {
return errorz.Errorf("%s: RunFunc: %w", strings.Join(execCmd.calledCommands, " "), err)
}

if execCmd.PostRunFunc != nil {
if err := execCmd.PostRunFunc(ctx, execCmd, remainingArgs); err != nil {
return errorz.Errorf("%s: PostRunFunc: %w", strings.Join(execCmd.calledCommands, " "), err)
}
}

return nil
}
82 changes: 71 additions & 11 deletions exp/cli/command_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bytes"
"context"
"errors"
"io"
"reflect"
"strings"
"testing"
Expand Down Expand Up @@ -78,11 +79,7 @@ func TestCommand(t *testing.T) {
SubCommands: []*Command{
{
Name: "sub-sub-cmd",
RunFunc: func(ctx context.Context, remainingArgs []string) error {
cmd, err := FromContext(ctx)
if err != nil {
return errorz.Errorf("FromContext: %w", err)
}
RunFunc: func(ctx context.Context, cmd *Command, remainingArgs []string) error {
called := cmd.GetCalledCommands()
if !reflect.DeepEqual(called, []string{"my-cli", "sub-cmd", "sub-sub-cmd"}) {
return errorz.Errorf("unexpected command name: %v", called)
Expand Down Expand Up @@ -585,8 +582,8 @@ func TestCommand_IsCommand(t *testing.T) {
t.Parallel()
alias := "short"
cmd := &Command{
Name: "my-cli",
Short: "short",
Name: "my-cli",
Aliases: []string{"short"},
}
expected := true
actual := cmd.IsCommand(alias)
Expand Down Expand Up @@ -655,10 +652,10 @@ func TestCommand_GetCalledCommands(t *testing.T) {
})
}

func TestCommand_Exec(t *testing.T) {
func TestCommand_Run(t *testing.T) {
t.Parallel()

t.Run("success,Exec,ErrHelp", func(t *testing.T) {
t.Run("success,Run,ErrHelp", func(t *testing.T) {
t.Parallel()
args := []string{"my-cli", "--help"}
c := &Command{}
Expand All @@ -667,11 +664,17 @@ func TestCommand_Exec(t *testing.T) {
}
})

t.Run("success,Exec,", func(t *testing.T) {
t.Run("success,Run,", func(t *testing.T) {
t.Parallel()
c := &Command{
Name: "my-cli",
RunFunc: func(ctx context.Context, remainingArgs []string) error {
PreRunFunc: func(ctx context.Context, cmd *Command, remainingArgs []string) error {
return nil
},
RunFunc: func(ctx context.Context, cmd *Command, remainingArgs []string) error {
return nil
},
PostRunFunc: func(ctx context.Context, cmd *Command, remainingArgs []string) error {
return nil
},
SubCommands: []*Command{
Expand All @@ -687,4 +690,61 @@ func TestCommand_Exec(t *testing.T) {
t.Errorf("❌: err != nil: %v != %+v", nil, err)
}
})

t.Run("failure,Run,PreRunFunc", func(t *testing.T) {
t.Parallel()
c := &Command{
Name: "my-cli",
PreRunFunc: func(ctx context.Context, cmd *Command, remainingArgs []string) error {
return io.ErrUnexpectedEOF
},
RunFunc: func(ctx context.Context, cmd *Command, remainingArgs []string) error {
return nil
},
PostRunFunc: func(ctx context.Context, cmd *Command, remainingArgs []string) error {
return nil
},
}
if err := c.Run(context.Background(), []string{"my-cli"}[1:]); !errors.Is(err, io.ErrUnexpectedEOF) {
t.Errorf("❌: expect != actual: %v != %+v", io.ErrUnexpectedEOF, err)
}
})

t.Run("failure,Run,RunFunc", func(t *testing.T) {
t.Parallel()
c := &Command{
Name: "my-cli",
PreRunFunc: func(ctx context.Context, cmd *Command, remainingArgs []string) error {
return nil
},
RunFunc: func(ctx context.Context, cmd *Command, remainingArgs []string) error {
return io.ErrUnexpectedEOF
},
PostRunFunc: func(ctx context.Context, cmd *Command, remainingArgs []string) error {
return nil
},
}
if err := c.Run(context.Background(), []string{"my-cli"}[1:]); !errors.Is(err, io.ErrUnexpectedEOF) {
t.Errorf("❌: expect != actual: %v != %+v", io.ErrUnexpectedEOF, err)
}
})

t.Run("failure,Run,PostRunFunc", func(t *testing.T) {
t.Parallel()
c := &Command{
Name: "my-cli",
PreRunFunc: func(ctx context.Context, cmd *Command, remainingArgs []string) error {
return nil
},
RunFunc: func(ctx context.Context, cmd *Command, remainingArgs []string) error {
return nil
},
PostRunFunc: func(ctx context.Context, cmd *Command, remainingArgs []string) error {
return io.ErrUnexpectedEOF
},
}
if err := c.Run(context.Background(), []string{"my-cli"}[1:]); !errors.Is(err, io.ErrUnexpectedEOF) {
t.Errorf("❌: expect != actual: %v != %+v", io.ErrUnexpectedEOF, err)
}
})
}
19 changes: 18 additions & 1 deletion exp/cli/context.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package cliz

import "context"
import (
"context"
"fmt"
)

type contextKeyCommand struct{}

Expand All @@ -9,10 +12,24 @@ func WithContext(ctx context.Context, cmd *Command) context.Context {
}

func FromContext(ctx context.Context) (*Command, error) {
if ctx == nil {
return nil, ErrNilContext
}

c, ok := ctx.Value(contextKeyCommand{}).(*Command)
if !ok {
return nil, ErrCommandNotSetInContext
}

return c, nil
}

func MustFromContext(ctx context.Context) *Command {
c, err := FromContext(ctx)
if err != nil {
err = fmt.Errorf("FromContext: %w", err)
panic(err)
}

return c
}
Loading

0 comments on commit 72b5233

Please sign in to comment.