diff --git a/cmd/eksctl-anywhere/cmd/upgradecluster.go b/cmd/eksctl-anywhere/cmd/upgradecluster.go index 12e2098ab953..cd34eff59207 100644 --- a/cmd/eksctl-anywhere/cmd/upgradecluster.go +++ b/cmd/eksctl-anywhere/cmd/upgradecluster.go @@ -39,13 +39,14 @@ var upgradeClusterCmd = &cobra.Command{ Long: "This command is used to upgrade workload clusters", PreRunE: bindFlagsToViper, SilenceUsage: true, + Args: cobra.MaximumNArgs(1), RunE: func(cmd *cobra.Command, args []string) error { if uc.forceClean { logger.MarkFail(forceCleanupDeprecationMessageForUpgrade) return errors.New("please remove the --force-cleanup flag") } - if err := uc.upgradeCluster(cmd); err != nil { + if err := uc.upgradeCluster(cmd, args); err != nil { return fmt.Errorf("failed to upgrade cluster: %v", err) } return nil @@ -65,7 +66,8 @@ func init() { flags.MarkRequired(createClusterCmd.Flags(), flags.ClusterConfig.Name) } -func (uc *upgradeClusterOptions) upgradeCluster(cmd *cobra.Command) error { +// nolint:gocyclo +func (uc *upgradeClusterOptions) upgradeCluster(cmd *cobra.Command, args []string) error { ctx := cmd.Context() clusterConfigFileExist := validations.FileExists(uc.fileName) @@ -91,6 +93,10 @@ func (uc *upgradeClusterOptions) upgradeCluster(cmd *cobra.Command) error { if _, err := uc.commonValidations(ctx); err != nil { return fmt.Errorf("common validations failed due to: %v", err) } + + if err := validations.ValidateClusterNameFromCommandAndConfig(args, clusterConfig.Name); err != nil { + return err + } clusterSpec, err := newClusterSpec(uc.clusterOptions) if err != nil { return err diff --git a/pkg/validations/input.go b/pkg/validations/input.go index 7536c4c6c5f6..071ae5a7f0c2 100644 --- a/pkg/validations/input.go +++ b/pkg/validations/input.go @@ -2,6 +2,7 @@ package validations import ( "errors" + "fmt" "os" "github.com/aws/eks-anywhere/pkg/api/v1alpha1" @@ -31,3 +32,17 @@ func FileExistsAndIsNotEmpty(filename string) bool { info, err := os.Stat(filename) return err == nil && info.Size() > 0 } + +// ValidateClusterNameFromCommandAndConfig validates if cluster name provided in command matches with cluster name in config file. +func ValidateClusterNameFromCommandAndConfig(args []string, clusterNameConfig string) error { + if len(args) != 0 { + clusterNameCli, err := ValidateClusterNameArg(args) + if err != nil { + return fmt.Errorf("please provide a valid ") + } + if clusterNameCli != clusterNameConfig { + return fmt.Errorf("please make sure cluster name provided in command matches with cluster name in config file") + } + } + return nil +} diff --git a/pkg/validations/input_test.go b/pkg/validations/input_test.go index 340677d93c7a..38fc103afb25 100644 --- a/pkg/validations/input_test.go +++ b/pkg/validations/input_test.go @@ -111,3 +111,46 @@ func TestValidateClusterNameArg(t *testing.T) { }) } } + +func TestValidateClusterNameFromCommandAndConfig(t *testing.T) { + tests := []struct { + name string + args []string + clusterNameConfig string + expectedError error + }{ + { + name: "Success cluster name match", + args: []string{"test-cluster"}, + clusterNameConfig: "test-cluster", + expectedError: nil, + }, + { + name: "Success empty Arguments", + args: []string{}, + clusterNameConfig: "test-cluster", + expectedError: nil, + }, + { + name: "Failure invalid cluster name", + args: []string{"123test-Cluster"}, + clusterNameConfig: "test-cluster", + expectedError: errors.New("please provide a valid "), + }, + { + name: "Failure cluster name not match", + args: []string{"test-cluster-1"}, + clusterNameConfig: "test-cluster", + expectedError: errors.New("please make sure cluster name provided in command matches with cluster name in config file"), + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(tt *testing.T) { + gotError := validations.ValidateClusterNameFromCommandAndConfig(tc.args, tc.clusterNameConfig) + if !reflect.DeepEqual(tc.expectedError, gotError) { + t.Errorf("\n%v got Error = %v, want Error %v", tc.name, gotError, tc.expectedError) + } + }) + } +}