diff --git a/graphql_test.go b/graphql_test.go index f370cfc7..9cc178d9 100644 --- a/graphql_test.go +++ b/graphql_test.go @@ -327,3 +327,66 @@ func TestQueryWithCustomRule(t *testing.T) { } } } + +// TestCustomRuleWithArgs tests graphql.GetArgumentValues() be able to access +// field's argument values from custom validation rule. +func TestCustomRuleWithArgs(t *testing.T) { + fieldDef, ok := testutil.StarWarsSchema.QueryType().Fields()["human"] + if !ok { + t.Fatal("can't retrieve \"human\" field definition") + } + + // a custom validation rule to extract argument values of "human" field. + var actual map[string]interface{} + enter := func(p visitor.VisitFuncParams) (string, interface{}) { + // only interested in "human" field. + fieldNode, ok := p.Node.(*ast.Field) + if !ok || fieldNode.Name == nil || fieldNode.Name.Value != "human" { + return visitor.ActionNoChange, nil + } + // extract argument values by graphql.GetArgumentValues(). + actual = graphql.GetArgumentValues(fieldDef.Args, fieldNode.Arguments, nil) + return visitor.ActionNoChange, nil + } + checkHumanArgs := func(context *graphql.ValidationContext) *graphql.ValidationRuleInstance { + return &graphql.ValidationRuleInstance{ + VisitorOpts: &visitor.VisitorOptions{ + KindFuncMap: map[string]visitor.NamedVisitFuncs{ + kinds.Field: {Enter: enter}, + }, + }, + } + } + + for _, tc := range []struct { + query string + expected map[string]interface{} + }{ + { + `query { human(id: "1000") { name } }`, + map[string]interface{}{"id": "1000"}, + }, + { + `query { human(id: "1002") { name } }`, + map[string]interface{}{"id": "1002"}, + }, + { + `query { human(id: "9999") { name } }`, + map[string]interface{}{"id": "9999"}, + }, + } { + actual = nil + params := graphql.Params{ + Schema: testutil.StarWarsSchema, + RequestString: tc.query, + ValidationRules: append(graphql.SpecifiedRules, checkHumanArgs), + } + result := graphql.Do(params) + if len(result.Errors) > 0 { + t.Fatalf("wrong result, unexpected errors: %v", result.Errors) + } + if !reflect.DeepEqual(actual, tc.expected) { + t.Fatalf("unexpected result: want=%+v got=%+v", tc.expected, actual) + } + } +} diff --git a/values.go b/values.go index 06c08af6..8dc5210f 100644 --- a/values.go +++ b/values.go @@ -67,6 +67,17 @@ func getArgumentValues( return results } +// GetArgumentValues prepares an object map of argument values given a list of +// argument definitions and list of argument AST nodes. +// +// This is an exported version of getArgumentValues(), to ease writing custom +// validation rules. +func GetArgumentValues( + argDefs []*Argument, argASTs []*ast.Argument, + variableValues map[string]interface{}) map[string]interface{} { + return getArgumentValues(argDefs, argASTs, variableValues) +} + // Given a variable definition, and any value of input, return a value which // adheres to the variable definition, or throw an error. func getVariableValue(schema Schema, definitionAST *ast.VariableDefinition, input interface{}) (interface{}, error) {