From 52c876991fbc41468699ea527c343499d0753b43 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=A9my=20LAMBERT?= Date: Mon, 30 Sep 2019 14:12:27 +0200 Subject: [PATCH] Add full form access to rule functions --- README.md | 58 +++++++++++++++++++- doc/CUSTOM_RULE.md | 2 +- rules.go | 78 +++++++++++++-------------- rules_test.go | 123 ++++++++++++++++++++++++++++++++++++++++-- validate_file_test.go | 4 +- validator.go | 20 +++++-- 6 files changed, 233 insertions(+), 52 deletions(-) diff --git a/README.md b/README.md index 3527a48..19be844 100644 --- a/README.md +++ b/README.md @@ -179,7 +179,7 @@ Send request to the server using curl or postman: `curl GET "http://localhost:90 ```go func init() { // simple example - govalidator.AddCustomRule("must_john", func(field string, rule string, message string, value interface{}) error { + govalidator.AddCustomRule("must_john", func(field string, rule string, message string, value interface{}, form map[string]interface{}) error { val := value.(string) if val != "john" || val != "John" { return fmt.Errorf("The %s field must be John or john", field) @@ -189,7 +189,7 @@ func init() { // custom rules to take fixed length word. // e.g: word:5 will throw error if the field does not contain exact 5 word - govalidator.AddCustomRule("word", func(field string, rule string, message string, value interface{}) error { + govalidator.AddCustomRule("word", func(field string, rule string, message string, value interface{}, form map[string]interface{}) error { valSlice := strings.Fields(value.(string)) l, _ := strconv.Atoi(strings.TrimPrefix(rule, "word:")) //handle other error if len(valSlice) != l { @@ -202,6 +202,60 @@ func init() { ``` Note: Array, map, slice can be validated by adding custom rules. +You can use the `form` parameter to compare one field with another. +```go +govalidator.AddCustomRule("greater_than", (f string, rule string, message string, v interface{}, form map[string]interface{}) error { + if form == nil { // All comparison rules should check if the form is provided + panic(errors.New("No form provided for comparison rule")) + } + + compareToKey := strings.TrimPrefix(rule, "greater_than:") + err := fmt.Errorf("The %s field must be greater than the %s field", f, compareToKey) + compareToField := form[compareToKey] + + if compareToField == nil { // Field to compare to doesn't exist or is empty + return err + } + + // Get the value of the field we want to compare to + var compareToValue int + rv := reflect.ValueOf(compareToField) + switch rv.Kind() { + case reflect.String: + v, atoiErr := strconv.Atoi(compareToField.(string)) + if atoiErr != nil { + panic(errStringToInt) + } + compareToValue = v + case reflect.Int: + compareToValue = compareToField.(int) + //... + // Handle other types such as float + } + + // Do the comparison + rv2 := reflect.ValueOf(v) + switch rv2.Kind() { + case reflect.Int: + if v.(int) <= compareToValue { + return err + } + case reflect.String: + vInt, atoiErr := strconv.Atoi(v.(string)) + if atoiErr != nil { + panic(errStringToInt) + } + if vInt <= compareToValue { + return err + } + //... + // Handle other types + } + + return nil +}) +``` + ### Custom Message/ Localization If you need to translate validation message you can pass messages as options. diff --git a/doc/CUSTOM_RULE.md b/doc/CUSTOM_RULE.md index f8c4ab2..131fed2 100644 --- a/doc/CUSTOM_RULE.md +++ b/doc/CUSTOM_RULE.md @@ -20,7 +20,7 @@ import ( func init() { // custom rules to take fixed length word. // e.g: max_word:5 will throw error if the field contains more than 5 words - govalidator.AddCustomRule("max_word", func(field string, rule string, message string, value interface{}) error { + govalidator.AddCustomRule("max_word", func(field string, rule string, message string, value interface{}, form map[string]interface{}) error { valSlice := strings.Fields(value.(string)) l, _ := strconv.Atoi(strings.TrimPrefix(rule, "max_word:")) //handle other error if len(valSlice) > l { diff --git a/rules.go b/rules.go index 8034f4d..86e6faa 100644 --- a/rules.go +++ b/rules.go @@ -11,14 +11,14 @@ import ( "strings" ) -var rulesFuncMap = make(map[string]func(string, string, string, interface{}) error) +var rulesFuncMap = make(map[string]func(string, string, string, interface{}, map[string]interface{}) error) // AddCustomRule help to add custom rules for validator // First argument it takes the rule name and second arg a func // Second arg must have this signature below // fn func(name string, fn func(field string, rule string, message string, value interface{}) error // see example in readme: https://github.com/thedevsaddam/govalidator#add-custom-rules -func AddCustomRule(name string, fn func(field string, rule string, message string, value interface{}) error) { +func AddCustomRule(name string, fn func(field string, rule string, message string, value interface{}, form map[string]interface{}) error) { if isRuleExist(name) { panic(fmt.Errorf("govalidator: %s is already defined in rules", name)) } @@ -26,10 +26,10 @@ func AddCustomRule(name string, fn func(field string, rule string, message strin } // validateCustomRules validate custom rules -func validateCustomRules(field string, rule string, message string, value interface{}, errsBag url.Values) { +func validateCustomRules(field string, rule string, message string, value interface{}, form map[string]interface{}, errsBag url.Values) { for k, v := range rulesFuncMap { if k == rule || strings.HasPrefix(rule, k+":") { - err := v(field, rule, message, value) + err := v(field, rule, message, value, form) if err != nil { errsBag.Add(field, err.Error()) } @@ -41,7 +41,7 @@ func validateCustomRules(field string, rule string, message string, value interf func init() { // Required check the Required fields - AddCustomRule("required", func(field, rule, message string, value interface{}) error { + AddCustomRule("required", func(field, rule, message string, value interface{}, form map[string]interface{}) error { err := fmt.Errorf("The %s field is required", field) if message != "" { err = errors.New(message) @@ -156,7 +156,7 @@ func init() { // Regex check the custom Regex rules // Regex:^[a-zA-Z]+$ means this field can only contain alphabet (a-z and A-Z) - AddCustomRule("regex", func(field, rule, message string, value interface{}) error { + AddCustomRule("regex", func(field, rule, message string, value interface{}, form map[string]interface{}) error { str := toString(value) err := fmt.Errorf("The %s field format is invalid", field) if message != "" { @@ -170,7 +170,7 @@ func init() { }) // Alpha check if provided field contains valid letters - AddCustomRule("alpha", func(field string, vlaue string, message string, value interface{}) error { + AddCustomRule("alpha", func(field string, vlaue string, message string, value interface{}, form map[string]interface{}) error { str := toString(value) err := fmt.Errorf("The %s may only contain letters", field) if message != "" { @@ -183,7 +183,7 @@ func init() { }) // AlphaDash check if provided field contains valid letters, numbers, underscore and dash - AddCustomRule("alpha_dash", func(field string, vlaue string, message string, value interface{}) error { + AddCustomRule("alpha_dash", func(field string, vlaue string, message string, value interface{}, form map[string]interface{}) error { str := toString(value) err := fmt.Errorf("The %s may only contain letters, numbers, and dashes", field) if message != "" { @@ -196,7 +196,7 @@ func init() { }) // AlphaDash check if provided field contains valid letters, numbers, underscore and dash - AddCustomRule("alpha_space", func(field string, vlaue string, message string, value interface{}) error { + AddCustomRule("alpha_space", func(field string, vlaue string, message string, value interface{}, form map[string]interface{}) error { str := toString(value) err := fmt.Errorf("The %s may only contain letters, numbers, dashes, space", field) if message != "" { @@ -209,7 +209,7 @@ func init() { }) // AlphaNumeric check if provided field contains valid letters and numbers - AddCustomRule("alpha_num", func(field string, vlaue string, message string, value interface{}) error { + AddCustomRule("alpha_num", func(field string, vlaue string, message string, value interface{}, form map[string]interface{}) error { str := toString(value) err := fmt.Errorf("The %s may only contain letters and numbers", field) if message != "" { @@ -223,7 +223,7 @@ func init() { // Boolean check if provided field contains Boolean // in this case: "0", "1", 0, 1, "true", "false", true, false etc - AddCustomRule("bool", func(field string, vlaue string, message string, value interface{}) error { + AddCustomRule("bool", func(field string, vlaue string, message string, value interface{}, form map[string]interface{}) error { err := fmt.Errorf("The %s may only contain boolean value, string or int 0, 1", field) if message != "" { err = errors.New(message) @@ -286,7 +286,7 @@ func init() { // Between check the fields character length range // if the field is array, map, slice then the valdiation rule will be the length of the data // if the value is int or float then the valdiation rule will be the value comparison - AddCustomRule("between", func(field string, rule string, message string, value interface{}) error { + AddCustomRule("between", func(field string, rule string, message string, value interface{}, form map[string]interface{}) error { rng := strings.Split(strings.TrimPrefix(rule, "between:"), ",") if len(rng) != 2 { panic(errInvalidArgument) @@ -387,7 +387,7 @@ func init() { // CreditCard check if provided field contains valid credit card number // Accepted cards are Visa, MasterCard, American Express, Diners Club, Discover and JCB card - AddCustomRule("credit_card", func(field string, rule string, message string, value interface{}) error { + AddCustomRule("credit_card", func(field string, rule string, message string, value interface{}, form map[string]interface{}) error { str := toString(value) err := fmt.Errorf("The %s field must be a valid credit card number", field) if message != "" { @@ -400,7 +400,7 @@ func init() { }) // Coordinate check if provided field contains valid Coordinate - AddCustomRule("coordinate", func(field string, rule string, message string, value interface{}) error { + AddCustomRule("coordinate", func(field string, rule string, message string, value interface{}, form map[string]interface{}) error { str := toString(value) err := fmt.Errorf("The %s field must be a valid coordinate", field) if message != "" { @@ -413,7 +413,7 @@ func init() { }) // ValidateCSSColor check if provided field contains a valid CSS color code - AddCustomRule("css_color", func(field string, rule string, message string, value interface{}) error { + AddCustomRule("css_color", func(field string, rule string, message string, value interface{}, form map[string]interface{}) error { str := toString(value) err := fmt.Errorf("The %s field must be a valid CSS color code", field) if message != "" { @@ -428,7 +428,7 @@ func init() { // Digits check the exact matching length of digit (0,9) // Digits:5 means the field must have 5 digit of length. // e.g: 12345 or 98997 etc - AddCustomRule("digits", func(field string, rule string, message string, value interface{}) error { + AddCustomRule("digits", func(field string, rule string, message string, value interface{}, form map[string]interface{}) error { l, err := strconv.Atoi(strings.TrimPrefix(rule, "digits:")) if err != nil { panic(errStringToInt) @@ -460,7 +460,7 @@ func init() { // DigitsBetween check if the field contains only digit and length between provided range // e.g: digits_between:4,5 means the field can have value like: 8887 or 12345 etc - AddCustomRule("digits_between", func(field string, rule string, message string, value interface{}) error { + AddCustomRule("digits_between", func(field string, rule string, message string, value interface{}, form map[string]interface{}) error { rng := strings.Split(strings.TrimPrefix(rule, "digits_between:"), ",") if len(rng) != 2 { panic(errInvalidArgument) @@ -486,7 +486,7 @@ func init() { }) // Date check the provided field is valid Date - AddCustomRule("date", func(field string, rule string, message string, value interface{}) error { + AddCustomRule("date", func(field string, rule string, message string, value interface{}, form map[string]interface{}) error { str := toString(value) switch rule { @@ -510,7 +510,7 @@ func init() { }) // Email check the provided field is valid Email - AddCustomRule("email", func(field string, rule string, message string, value interface{}) error { + AddCustomRule("email", func(field string, rule string, message string, value interface{}, form map[string]interface{}) error { str := toString(value) err := fmt.Errorf("The %s field must be a valid email address", field) if message != "" { @@ -523,7 +523,7 @@ func init() { }) // validFloat check the provided field is valid float number - AddCustomRule("float", func(field string, rule string, message string, value interface{}) error { + AddCustomRule("float", func(field string, rule string, message string, value interface{}, form map[string]interface{}) error { str := toString(value) err := fmt.Errorf("The %s field must be a float number", field) if message != "" { @@ -536,7 +536,7 @@ func init() { }) // IP check if provided field is valid IP address - AddCustomRule("ip", func(field string, rule string, message string, value interface{}) error { + AddCustomRule("ip", func(field string, rule string, message string, value interface{}, form map[string]interface{}) error { str := toString(value) err := fmt.Errorf("The %s field must be a valid IP address", field) if message != "" { @@ -549,7 +549,7 @@ func init() { }) // IP check if provided field is valid IP v4 address - AddCustomRule("ip_v4", func(field string, rule string, message string, value interface{}) error { + AddCustomRule("ip_v4", func(field string, rule string, message string, value interface{}, form map[string]interface{}) error { str := toString(value) err := fmt.Errorf("The %s field must be a valid IPv4 address", field) if message != "" { @@ -562,7 +562,7 @@ func init() { }) // IP check if provided field is valid IP v6 address - AddCustomRule("ip_v6", func(field string, rule string, message string, value interface{}) error { + AddCustomRule("ip_v6", func(field string, rule string, message string, value interface{}, form map[string]interface{}) error { str := toString(value) err := fmt.Errorf("The %s field must be a valid IPv6 address", field) if message != "" { @@ -575,7 +575,7 @@ func init() { }) // ValidateJSON check if provided field contains valid json string - AddCustomRule("json", func(field string, rule string, message string, value interface{}) error { + AddCustomRule("json", func(field string, rule string, message string, value interface{}, form map[string]interface{}) error { str := toString(value) err := fmt.Errorf("The %s field must contain valid JSON string", field) if message != "" { @@ -588,7 +588,7 @@ func init() { }) /// Latitude check if provided field contains valid Latitude - AddCustomRule("lat", func(field string, rule string, message string, value interface{}) error { + AddCustomRule("lat", func(field string, rule string, message string, value interface{}, form map[string]interface{}) error { str := toString(value) err := fmt.Errorf("The %s field must contain valid latitude", field) if message != "" { @@ -601,7 +601,7 @@ func init() { }) // Longitude check if provided field contains valid Longitude - AddCustomRule("lon", func(field string, rule string, message string, value interface{}) error { + AddCustomRule("lon", func(field string, rule string, message string, value interface{}, form map[string]interface{}) error { str := toString(value) err := fmt.Errorf("The %s field must contain valid longitude", field) if message != "" { @@ -614,7 +614,7 @@ func init() { }) // Length check the field's character Length - AddCustomRule("len", func(field string, rule string, message string, value interface{}) error { + AddCustomRule("len", func(field string, rule string, message string, value interface{}, form map[string]interface{}) error { l, err := strconv.Atoi(strings.TrimPrefix(rule, "len:")) if err != nil { panic(errStringToInt) @@ -641,7 +641,7 @@ func init() { }) // Min check the field's minimum character length for string, value for int, float and size for array, map, slice - AddCustomRule("min", func(field string, rule string, message string, value interface{}) error { + AddCustomRule("min", func(field string, rule string, message string, value interface{}, form map[string]interface{}) error { mustLen := strings.TrimPrefix(rule, "min:") lenInt, err := strconv.Atoi(mustLen) if err != nil { @@ -749,7 +749,7 @@ func init() { }) // Max check the field's maximum character length for string, value for int, float and size for array, map, slice - AddCustomRule("max", func(field string, rule string, message string, value interface{}) error { + AddCustomRule("max", func(field string, rule string, message string, value interface{}, form map[string]interface{}) error { mustLen := strings.TrimPrefix(rule, "max:") lenInt, err := strconv.Atoi(mustLen) if err != nil { @@ -857,7 +857,7 @@ func init() { }) // Numeric check if the value of the field is Numeric - AddCustomRule("mac_address", func(field string, rule string, message string, value interface{}) error { + AddCustomRule("mac_address", func(field string, rule string, message string, value interface{}, form map[string]interface{}) error { str := toString(value) err := fmt.Errorf("The %s field must be a valid Mac Address", field) if message != "" { @@ -870,7 +870,7 @@ func init() { }) // Numeric check if the value of the field is Numeric - AddCustomRule("numeric", func(field string, rule string, message string, value interface{}) error { + AddCustomRule("numeric", func(field string, rule string, message string, value interface{}, form map[string]interface{}) error { str := toString(value) err := fmt.Errorf("The %s field must be numeric", field) if message != "" { @@ -885,7 +885,7 @@ func init() { // NumericBetween check if the value field numeric value range // e.g: numeric_between:18, 65 means number value must be in between a numeric value 18 & 65 // Both of the bounds can be omited turning it into a min only (`10,`) or a max only (`,10`) - AddCustomRule("numeric_between", func(field string, rule string, message string, value interface{}) error { + AddCustomRule("numeric_between", func(field string, rule string, message string, value interface{}, form map[string]interface{}) error { rng := strings.Split(strings.TrimPrefix(rule, "numeric_between:"), ",") if len(rng) != 2 { panic(errInvalidArgument) @@ -981,7 +981,7 @@ func init() { }) // ValidateURL check if provided field is valid URL - AddCustomRule("url", func(field string, rule string, message string, value interface{}) error { + AddCustomRule("url", func(field string, rule string, message string, value interface{}, form map[string]interface{}) error { str := toString(value) err := fmt.Errorf("The %s field format is invalid", field) if message != "" { @@ -994,7 +994,7 @@ func init() { }) // UUID check if provided field contains valid UUID - AddCustomRule("uuid", func(field string, rule string, message string, value interface{}) error { + AddCustomRule("uuid", func(field string, rule string, message string, value interface{}, form map[string]interface{}) error { str := toString(value) err := fmt.Errorf("The %s field must contain valid UUID", field) if message != "" { @@ -1007,7 +1007,7 @@ func init() { }) // UUID3 check if provided field contains valid UUID of version 3 - AddCustomRule("uuid_v3", func(field string, rule string, message string, value interface{}) error { + AddCustomRule("uuid_v3", func(field string, rule string, message string, value interface{}, form map[string]interface{}) error { str := toString(value) err := fmt.Errorf("The %s field must contain valid UUID V3", field) if message != "" { @@ -1020,7 +1020,7 @@ func init() { }) // UUID4 check if provided field contains valid UUID of version 4 - AddCustomRule("uuid_v4", func(field string, rule string, message string, value interface{}) error { + AddCustomRule("uuid_v4", func(field string, rule string, message string, value interface{}, form map[string]interface{}) error { str := toString(value) err := fmt.Errorf("The %s field must contain valid UUID V4", field) if message != "" { @@ -1033,7 +1033,7 @@ func init() { }) // UUID5 check if provided field contains valid UUID of version 5 - AddCustomRule("uuid_v5", func(field string, rule string, message string, value interface{}) error { + AddCustomRule("uuid_v5", func(field string, rule string, message string, value interface{}, form map[string]interface{}) error { str := toString(value) err := fmt.Errorf("The %s field must contain valid UUID V5", field) if message != "" { @@ -1046,7 +1046,7 @@ func init() { }) // In check if provided field equals one of the values specified in the rule - AddCustomRule("in", func(field string, rule string, message string, value interface{}) error { + AddCustomRule("in", func(field string, rule string, message string, value interface{}, form map[string]interface{}) error { rng := strings.Split(strings.TrimPrefix(rule, "in:"), ",") if len(rng) == 0 { panic(errInvalidArgument) @@ -1063,7 +1063,7 @@ func init() { }) // In check if provided field equals one of the values specified in the rule - AddCustomRule("not_in", func(field string, rule string, message string, value interface{}) error { + AddCustomRule("not_in", func(field string, rule string, message string, value interface{}, form map[string]interface{}) error { rng := strings.Split(strings.TrimPrefix(rule, "not_in:"), ",") if len(rng) == 0 { panic(errInvalidArgument) diff --git a/rules_test.go b/rules_test.go index 3e8d0ce..1aee1e6 100644 --- a/rules_test.go +++ b/rules_test.go @@ -3,14 +3,18 @@ package govalidator import ( "bytes" "encoding/json" + "errors" "fmt" "net/http" "net/url" + "reflect" + "strconv" + "strings" "testing" ) func Test_AddCustomRule(t *testing.T) { - AddCustomRule("__x__", func(f string, rule string, message string, v interface{}) error { + AddCustomRule("__x__", func(f string, rule string, message string, v interface{}, form map[string]interface{}) error { if v.(string) != "xyz" { return fmt.Errorf("The %s field must be xyz", f) } @@ -27,7 +31,7 @@ func Test_AddCustomRule_panic(t *testing.T) { t.Errorf("AddCustomRule failed to panic") } }() - AddCustomRule("__x__", func(f string, rule string, message string, v interface{}) error { + AddCustomRule("__x__", func(f string, rule string, message string, v interface{}, form map[string]interface{}) error { if v.(string) != "xyz" { return fmt.Errorf("The %s field must be xyz", f) } @@ -37,7 +41,7 @@ func Test_AddCustomRule_panic(t *testing.T) { func Test_validateExtraRules(t *testing.T) { errsBag := url.Values{} - validateCustomRules("f_field", "__x__", "a", "", errsBag) + validateCustomRules("f_field", "__x__", "a", "", nil, errsBag) if len(errsBag) != 1 { t.Error("validateExtraRules failed") } @@ -2191,3 +2195,116 @@ func Test_NotIn_string_valid(t *testing.T) { t.Error("not_in validation was triggered when valid!") } } + +func ruleGreaterThan(f string, rule string, message string, v interface{}, form map[string]interface{}) error { + if form == nil { // All comparison rules should check if the form is provided + panic(errors.New("No form provided for comparison rule")) + } + + compareToKey := strings.TrimPrefix(rule, "__gt__:") + err := fmt.Errorf("The %s field must be greater than the %s field", f, compareToKey) + compareToField := form[compareToKey] + + if compareToField == nil { // Field to compare to doesn't exist or is empty + return err + } + + var compareToValue int + rv := reflect.ValueOf(compareToField) + switch rv.Kind() { + case reflect.String: + v, atoiErr := strconv.Atoi(compareToField.(string)) + if atoiErr != nil { + panic(errStringToInt) + } + compareToValue = v + case reflect.Int: + compareToValue = compareToField.(int) + } + + rv2 := reflect.ValueOf(v) + switch rv2.Kind() { // Simple int-only case here since we are just testing that we have access to other fields + case reflect.Int: + if v.(int) <= compareToValue { + return err + } + case reflect.String: + vInt, atoiErr := strconv.Atoi(v.(string)) + if atoiErr != nil { + panic(errStringToInt) + } + if vInt <= compareToValue { + return err + } + } + + return nil +} + +func Test_Comparison_JSON(t *testing.T) { + if rulesFuncMap["__gt__"] != nil { + delete(rulesFuncMap, "__gt__") + } + AddCustomRule("__gt__", ruleGreaterThan) + + type numbers struct { + Number int `json:"number"` + OtherNumber int `json:"other_number"` + } + + postNumbers := numbers{Number: 8, OtherNumber: 10} + var numbersObj numbers + + body, _ := json.Marshal(postNumbers) + req, _ := http.NewRequest("POST", "http://www.example.com", bytes.NewReader(body)) + + rules := MapData{ + "number": []string{"__gt__:other_number"}, + } + + opts := Options{ + Request: req, + Data: &numbersObj, + Rules: rules, + } + + vd := New(opts) + validationErr := vd.ValidateJSON() + + if len(validationErr) != 1 { + t.Log(validationErr) + t.Error("comparison validation failed!") + } +} + +func Test_Comparison_URL_Encoded(t *testing.T) { + if rulesFuncMap["__gt__"] != nil { + delete(rulesFuncMap, "__gt__") + } + AddCustomRule("__gt__", ruleGreaterThan) + + var URL *url.URL + URL, _ = url.Parse("http://www.example.com") + params := url.Values{} + params.Add("number", "8") + params.Add("other_number", "10") + req, _ := http.NewRequest("POST", URL.String(), strings.NewReader(params.Encode())) + req.Header.Add("Content-Type", "application/x-www-form-urlencoded") + + rules := MapData{ + "number": []string{"__gt__:other_number"}, + } + + opts := Options{ + Request: req, + Rules: rules, + } + + vd := New(opts) + validationErr := vd.Validate() + + if len(validationErr) != 1 { + t.Log(validationErr) + t.Error("comparison validation failed!") + } +} diff --git a/validate_file_test.go b/validate_file_test.go index 8a6330e..b906ccd 100644 --- a/validate_file_test.go +++ b/validate_file_test.go @@ -97,7 +97,7 @@ func Test_validateFiles_CustomRule(t *testing.T) { customRule1WasExecuted := false isMultipartFile := false - AddCustomRule("customRule1", func(field string, rule string, message string, value interface{}) error { + AddCustomRule("customRule1", func(field string, rule string, message string, value interface{}, form map[string]interface{}) error { customRule1WasExecuted = true _, isMultipartFile = value.(multipart.File) return nil @@ -105,7 +105,7 @@ func Test_validateFiles_CustomRule(t *testing.T) { customRule2WasExecuted := false isValueNil := false - AddCustomRule("customRule2", func(field string, rule string, message string, value interface{}) error { + AddCustomRule("customRule2", func(field string, rule string, message string, value interface{}, form map[string]interface{}) error { customRule2WasExecuted = true isValueNil = value == nil return nil diff --git a/validator.go b/validator.go index c88e74c..f5b5136 100644 --- a/validator.go +++ b/validator.go @@ -70,6 +70,14 @@ func (v *Validator) SetTagIdentifier(identifier string) { v.Opts.TagIdentifier = identifier } +func generateFlatMap(values url.Values) map[string]interface{} { + var flatMap map[string]interface{} = make(map[string]interface{}) + for field, value := range values { + flatMap[field] = value[0] + } + return flatMap +} + // Validate validate request data like form-data, x-www-form-urlencoded and query params // see example in README.md file // ref: https://github.com/thedevsaddam/govalidator#example @@ -82,6 +90,7 @@ func (v *Validator) Validate() url.Values { // get non required rules nr := v.getNonRequiredFields() + flatMap := generateFlatMap(v.Opts.Request.Form) for field, rules := range v.Opts.Rules { if _, ok := nr[field]; ok { @@ -98,14 +107,14 @@ func (v *Validator) Validate() url.Values { file, fh, _ := v.Opts.Request.FormFile(fld) if file != nil && fh.Filename != "" { validateFiles(v.Opts.Request, fld, rule, msg, errsBag) - validateCustomRules(fld, rule, msg, file, errsBag) + validateCustomRules(fld, rule, msg, file, nil, errsBag) } else { - validateCustomRules(fld, rule, msg, nil, errsBag) + validateCustomRules(fld, rule, msg, nil, nil, errsBag) } } else { // validate if custom rules exist reqVal := strings.TrimSpace(v.Opts.Request.Form.Get(field)) - validateCustomRules(field, rule, msg, reqVal, errsBag) + validateCustomRules(field, rule, msg, reqVal, flatMap, errsBag) } } } @@ -188,7 +197,8 @@ func (v *Validator) internalValidateStruct() url.Values { r.start(v.Opts.Data) //clean if the key is not exist or value is empty or zero value - nr := v.getNonRequiredJSONFields(r.getFlatMap()) + flatMap := r.getFlatMap() + nr := v.getNonRequiredJSONFields(flatMap) for field, rules := range v.Opts.Rules { if _, ok := nr[field]; ok { @@ -200,7 +210,7 @@ func (v *Validator) internalValidateStruct() url.Values { panic(fmt.Errorf("govalidator: %s is not a valid rule", rule)) } msg := v.getCustomMessage(field, rule) - validateCustomRules(field, rule, msg, value, errsBag) + validateCustomRules(field, rule, msg, value, flatMap, errsBag) } }