From b2fe0ffec64315b2ab2f48377c9325d152e20d59 Mon Sep 17 00:00:00 2001 From: Shude Li Date: Tue, 14 May 2024 21:27:59 +0800 Subject: [PATCH] feat(client): overwrite client context instead of setting new one (#20356) --- client/cmd.go | 13 ++++++++----- client/cmd_test.go | 26 +++++++++++++++++++++++--- 2 files changed, 31 insertions(+), 8 deletions(-) diff --git a/client/cmd.go b/client/cmd.go index e817649d24dc..d4f0693d7e4a 100644 --- a/client/cmd.go +++ b/client/cmd.go @@ -359,14 +359,17 @@ func GetClientContextFromCmd(cmd *cobra.Command) Context { // SetCmdClientContext sets a command's Context value to the provided argument. // If the context has not been set, set the given context as the default. func SetCmdClientContext(cmd *cobra.Command, clientCtx Context) error { - var cmdCtx context.Context - - if cmd.Context() == nil { + cmdCtx := cmd.Context() + if cmdCtx == nil { cmdCtx = context.Background() + } + + v := cmd.Context().Value(ClientContextKey) + if clientCtxPtr, ok := v.(*Context); ok { + *clientCtxPtr = clientCtx } else { - cmdCtx = cmd.Context() + cmd.SetContext(context.WithValue(cmdCtx, ClientContextKey, &clientCtx)) } - cmd.SetContext(context.WithValue(cmdCtx, ClientContextKey, &clientCtx)) return nil } diff --git a/client/cmd_test.go b/client/cmd_test.go index 81d5719ccfb6..559d31d39f40 100644 --- a/client/cmd_test.go +++ b/client/cmd_test.go @@ -79,11 +79,13 @@ func TestSetCmdClientContextHandler(t *testing.T) { name string expectedContext client.Context args []string + ctx context.Context }{ { "no flags set", initClientCtx, []string{}, + context.WithValue(context.Background(), client.ClientContextKey, &client.Context{}), }, { "flags set", @@ -91,6 +93,7 @@ func TestSetCmdClientContextHandler(t *testing.T) { []string{ fmt.Sprintf("--%s=new-chain-id", flags.FlagChainID), }, + context.WithValue(context.Background(), client.ClientContextKey, &client.Context{}), }, { "flags set with space", @@ -99,6 +102,25 @@ func TestSetCmdClientContextHandler(t *testing.T) { fmt.Sprintf("--%s", flags.FlagHome), "/tmp/dir", }, + context.Background(), + }, + { + "no context provided", + initClientCtx.WithHomeDir("/tmp/noctx"), + []string{ + fmt.Sprintf("--%s", flags.FlagHome), + "/tmp/noctx", + }, + nil, + }, + { + "with invalid client value in the context", + initClientCtx.WithHomeDir("/tmp/invalid"), + []string{ + fmt.Sprintf("--%s", flags.FlagHome), + "/tmp/invalid", + }, + context.WithValue(context.Background(), client.ClientContextKey, "invalid"), }, } @@ -106,13 +128,11 @@ func TestSetCmdClientContextHandler(t *testing.T) { tc := tc t.Run(tc.name, func(t *testing.T) { - ctx := context.WithValue(context.Background(), client.ClientContextKey, &client.Context{}) - cmd := newCmd() _ = testutil.ApplyMockIODiscardOutErr(cmd) cmd.SetArgs(tc.args) - require.NoError(t, cmd.ExecuteContext(ctx)) + require.NoError(t, cmd.ExecuteContext(tc.ctx)) clientCtx := client.GetClientContextFromCmd(cmd) require.Equal(t, tc.expectedContext, clientCtx)