diff --git a/cmd/app.go b/cmd/app.go index 537b84e9e..fd6732b11 100644 --- a/cmd/app.go +++ b/cmd/app.go @@ -1130,8 +1130,8 @@ func processValueUpgrades(ctx *cli.Context, existingValues string) (string, erro } // processValues creates a map of the values file -func processValues(ctx *cli.Context, existingValues string) (map[string]string, error) { - answers := make(map[string]string) +func processValues(ctx *cli.Context, existingValues string) (map[string]interface{}, error) { + answers := make(map[string]interface{}) if existingValues != "" { //parse values into map to ensure previous values are considered on update valuesMap, err := createValuesMap([]byte(existingValues)) @@ -1173,7 +1173,7 @@ func processAnswerInstall( } if multicluster && !interactive { // add default values if answers missing from map - err = fillInDefaultAnswers(tv, answers) + err = fillInDefaultAnswersStringMap(tv, answers) if err != nil { return answers, err } @@ -1204,7 +1204,7 @@ func processAnswerUpdates(ctx *cli.Context, answers map[string]string) (map[stri } // parseMapToYamlString create yaml string from answers map -func parseMapToYamlString(answerMap map[string]string) (string, error) { +func parseMapToYamlString(answerMap map[string]interface{}) (string, error) { yamlFileString, err := yaml.Marshal(answerMap) if err != nil { return "", err @@ -1229,7 +1229,7 @@ func parseAnswersFile(location string, answers map[string]string) error { } // parseValuesFile reads a values file and parses it to answers in helm strvals format -func parseValuesFile(location string, answers map[string]string) error { +func parseValuesFile(location string, answers map[string]interface{}) error { values, err := parseFile(location) if err != nil { return err @@ -1261,13 +1261,13 @@ func createValuesMap(bytes []byte) (map[string]interface{}, error) { return values, nil } -func valuesToAnswers(values map[string]interface{}, answers map[string]string) { +func valuesToAnswers(values map[string]interface{}, answers map[string]interface{}) { for k, v := range values { traverseValuesToAnswers(k, v, answers) } } -func traverseValuesToAnswers(key string, obj interface{}, answers map[string]string) { +func traverseValuesToAnswers(key string, obj interface{}, answers map[string]interface{}) { if obj == nil { return } @@ -1287,6 +1287,8 @@ func traverseValuesToAnswers(key string, obj interface{}, answers map[string]str traverseValuesToAnswers(nextKey, v, answers) } } + case reflect.Bool: + answers[key] = obj default: answers[key] = fmt.Sprintf("%v", obj) } @@ -1301,13 +1303,13 @@ func askQuestions(tv *managementClient.TemplateVersion, answers map[string]strin for { attempts++ for _, question := range tv.Questions { - if _, ok := answers[question.Variable]; !ok && checkShowIf(question.ShowIf, answers) { + if _, ok := answers[question.Variable]; !ok && checkShowIfStringMap(question.ShowIf, answers) { asked = true answers[question.Variable] = askQuestion(question) - if checkShowSubquestionIf(question, answers) { + if checkShowSubquestionIfStringMap(question, answers) { for _, subQuestion := range question.Subquestions { // only ask the question if there is not an answer and it passes the ShowIf check - if _, ok := answers[subQuestion.Variable]; !ok && checkShowIf(subQuestion.ShowIf, answers) { + if _, ok := answers[subQuestion.Variable]; !ok && checkShowIfStringMap(subQuestion.ShowIf, answers) { answers[subQuestion.Variable] = askSubQuestion(subQuestion) } } @@ -1374,7 +1376,7 @@ func askSubQuestion(q managementClient.SubQuestion) string { } // fillInDefaultAnswers parses through questions and creates an answer map with default answers if missing from map -func fillInDefaultAnswers(tv *managementClient.TemplateVersion, answers map[string]string) error { +func fillInDefaultAnswers(tv *managementClient.TemplateVersion, answers map[string]interface{}) error { if tv == nil { return nil } @@ -1399,7 +1401,51 @@ func fillInDefaultAnswers(tv *managementClient.TemplateVersion, answers map[stri // checkShowIf uses the ShowIf field to determine if a question should be asked // this field comes in the format = where key is a question id and value is the answer -func checkShowIf(s string, answers map[string]string) bool { +func checkShowIf(s string, answers map[string]interface{}) bool { + // No ShowIf so always ask the question + if len(s) == 0 { + return true + } + + pieces := strings.Split(s, "=") + if len(pieces) != 2 { + return false + } + + //if the key exists and the val matches the expression ask the question + if val, ok := answers[pieces[0]]; ok && fmt.Sprintf("%v", val) == pieces[1] { + return true + } + return false +} + +// fillInDefaultAnswersStringMap parses through questions and creates an answer map with default answers if missing from map +func fillInDefaultAnswersStringMap(tv *managementClient.TemplateVersion, answers map[string]string) error { + if tv == nil { + return nil + } + for _, question := range tv.Questions { + if _, ok := answers[question.Variable]; !ok && checkShowIfStringMap(question.ShowIf, answers) { + answers[question.Variable] = question.Default + if checkShowSubquestionIfStringMap(question, answers) { + for _, subQuestion := range question.Subquestions { + // set the sub-question if the showIf check passes + if _, ok := answers[subQuestion.Variable]; !ok && checkShowIfStringMap(subQuestion.ShowIf, answers) { + answers[subQuestion.Variable] = subQuestion.Default + } + } + } + } + } + if answers == nil { + return errors.New("could not generate default answers") + } + return nil +} + +// checkShowIfStringMap uses the ShowIf field to determine if a question should be asked +// this field comes in the format = where key is a question id and value is the answer +func checkShowIfStringMap(s string, answers map[string]string) bool { // No ShowIf so always ask the question if len(s) == 0 { return true @@ -1417,7 +1463,16 @@ func checkShowIf(s string, answers map[string]string) bool { return false } -func checkShowSubquestionIf(q managementClient.Question, answers map[string]string) bool { +func checkShowSubquestionIf(q managementClient.Question, answers map[string]interface{}) bool { + if val, ok := answers[q.Variable]; ok { + if fmt.Sprintf("%v", val) == q.ShowSubquestionIf { + return true + } + } + return false +} + +func checkShowSubquestionIfStringMap(q managementClient.Question, answers map[string]string) bool { if val, ok := answers[q.Variable]; ok { if val == q.ShowSubquestionIf { return true diff --git a/cmd/app_test.go b/cmd/app_test.go index 220d128ca..55f75736a 100644 --- a/cmd/app_test.go +++ b/cmd/app_test.go @@ -59,7 +59,7 @@ configmap: |- func TestValuesToAnswers(t *testing.T) { assert := assert.New(t) - answers := map[string]string{} + answers := map[string]interface{}{} values := map[string]interface{}{} if err := yaml.Unmarshal([]byte(redisSample), &values); err != nil { t.Error(err) @@ -67,14 +67,14 @@ func TestValuesToAnswers(t *testing.T) { valuesToAnswers(values, answers) assert.Equal("docker.io", answers["image.registry"], "unexpected image.registry") - assert.Equal("true", answers["cluster.enabled"], "unexpected cluster.enabled") + assert.Equal(true, answers["cluster.enabled"], "unexpected cluster.enabled") assert.Equal("1", answers["cluster.slaveCount"], "unexpected cluster.slaveCount") - assert.Equal("", answers["rbac.role.rules"], "unexpected rbac.role.rules") - assert.Equal("", answers["persistence"], "unexpected persistence") + assert.Equal(nil, answers["rbac.role.rules"], "unexpected rbac.role.rules") + assert.Equal(nil, answers["persistence"], "unexpected persistence") assert.Equal("redis-server", answers["master.args[0]"], "unexpected master.args[0]") assert.Equal("--maxmemory-policy volatile-ttl", answers["master.args[1]"], "unexpected master.args[1]") assert.Equal("FLUSHDB,FLUSHALL", answers["master.disableCommands"], "unexpected master.disableCommands") - assert.Equal("", answers["master.service.loadBalancerIP"], "unexpected master.service.loadBalancerIP") + assert.Equal(nil, answers["master.service.loadBalancerIP"], "unexpected master.service.loadBalancerIP") assert.Equal("ReadWriteOnce", answers["master.persistence.accessModes[0]"], "unexpected master.persistence.accessModes[0]") assert.Equal("# Redis configuration file\nbind 127.0.0.1\nport 6379", answers["configmap"], "unexpected configmap") }