From 6800db480030a4d5d65345f7e3505138873778a5 Mon Sep 17 00:00:00 2001 From: rsteube Date: Mon, 25 Sep 2023 18:47:22 +0200 Subject: [PATCH] tmp --- carapace.go | 4 ++-- compat.go | 6 +++++- defaultActions.go | 6 +++++- storage.go | 38 ++++++++++++++++++++++++++++++-------- 4 files changed, 42 insertions(+), 12 deletions(-) diff --git a/carapace.go b/carapace.go index 714139f69..9d03feac8 100644 --- a/carapace.go +++ b/carapace.go @@ -57,7 +57,7 @@ func (c Carapace) PositionalCompletion(action ...Action) { // PositionalAnyCompletion defines completion for any positional arguments not already defined. func (c Carapace) PositionalAnyCompletion(action Action) { - storage.get(c.cmd).positionalAny = action + storage.get(c.cmd).positionalAny = &action } // DashCompletion defines completion for positional arguments after dash (`--`) using a list of Actions. @@ -67,7 +67,7 @@ func (c Carapace) DashCompletion(action ...Action) { // DashAnyCompletion defines completion for any positional arguments after dash (`--`) not already defined. func (c Carapace) DashAnyCompletion(action Action) { - storage.get(c.cmd).dashAny = action + storage.get(c.cmd).dashAny = &action } // FlagCompletion defines completion for flags using a map consisting of name and Action. diff --git a/compat.go b/compat.go index 263d81c09..c83a2861c 100644 --- a/compat.go +++ b/compat.go @@ -11,7 +11,11 @@ import ( func registerValidArgsFunction(cmd *cobra.Command) { if cmd.ValidArgsFunction == nil { cmd.ValidArgsFunction = func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { - action := storage.getPositional(cmd, len(args)).Invoke(Context{Args: args, Value: toComplete}) + // TODO check storage.hasPositional to prevent loop + action := Action{}.Invoke(Context{Args: args, Value: toComplete}) // TODO just IvokedAction{} ok? + if storage.hasPositional(cmd, len(args)) { + action = storage.getPositional(cmd, len(args)).Invoke(Context{Args: args, Value: toComplete}) + } return cobraValuesFor(action), cobraDirectiveFor(action) } } diff --git a/defaultActions.go b/defaultActions.go index 75799e386..c17166878 100644 --- a/defaultActions.go +++ b/defaultActions.go @@ -473,7 +473,11 @@ func ActionPositional(cmd *cobra.Command) Action { c.Args = cmd.Flags().Args() entry := storage.get(cmd) - a := entry.positionalAny + var a Action + if entry.positionalAny != nil { + a = *entry.positionalAny + } + if index := len(c.Args); index < len(entry.positional) { a = entry.positional[len(c.Args)] } diff --git a/storage.go b/storage.go index ebebbb892..3b9687a92 100644 --- a/storage.go +++ b/storage.go @@ -17,9 +17,9 @@ type entry struct { flag ActionMap flagMutex sync.RWMutex positional []Action - positionalAny Action + positionalAny *Action dash []Action - dashAny Action + dashAny *Action preinvoke func(cmd *cobra.Command, flag *pflag.Flag, action Action) Action prerun func(cmd *cobra.Command, args []string) bridged bool @@ -60,8 +60,7 @@ func (s _storage) bridge(cmd *cobra.Command) { defer bridgeMutex.Unlock() if !entry.initialized { - // TODO only if completion is defined in carapace - // registerValidArgsFunction(cmd) + registerValidArgsFunction(cmd) registerFlagCompletion(cmd) entry.initialized = true } @@ -133,6 +132,24 @@ func (s _storage) preinvoke(cmd *cobra.Command, flag *pflag.Flag, action Action) return a } +func (s _storage) hasPositional(cmd *cobra.Command, index int) bool { + entry := s.get(cmd) + isDash := common.IsDash(cmd) + + // TODO fallback to cobra defined completion if exists + + switch { + case !isDash && len(entry.positional) > index: + return true + case !isDash: + return entry.positionalAny != nil + case len(entry.dash) > index: + return true + default: + return entry.dashAny != nil + } +} + func (s _storage) getPositional(cmd *cobra.Command, index int) Action { entry := s.get(cmd) isDash := common.IsDash(cmd) @@ -142,14 +159,19 @@ func (s _storage) getPositional(cmd *cobra.Command, index int) Action { var a Action switch { case !isDash && len(entry.positional) > index: - a = s.preinvoke(cmd, nil, entry.positional[index]) + a = entry.positional[index] case !isDash: - a = s.preinvoke(cmd, nil, entry.positionalAny) + if entry.positionalAny != nil { + a = *entry.positionalAny + } case len(entry.dash) > index: - a = s.preinvoke(cmd, nil, entry.dash[index]) + a = entry.dash[index] default: - a = s.preinvoke(cmd, nil, entry.dashAny) + if entry.dashAny != nil { + a = *entry.dashAny + } } + a = s.preinvoke(cmd, nil, a) return ActionCallback(func(c Context) Action { invoked := a.Invoke(c)