diff --git a/pom.xml b/pom.xml index e1bd6a6a21..57b6da71cd 100644 --- a/pom.xml +++ b/pom.xml @@ -71,7 +71,7 @@ 0.9.11 - 1.2.71 + 1.3.10 0.29.0 1.0.0.RC8 1.8.9.kotlin13 diff --git a/src/main/kotlin/com/expedia/graphql/schema/extensions/annotationExtensions.kt b/src/main/kotlin/com/expedia/graphql/schema/extensions/annotationExtensions.kt index 0418504b0d..ebad71b28b 100644 --- a/src/main/kotlin/com/expedia/graphql/schema/extensions/annotationExtensions.kt +++ b/src/main/kotlin/com/expedia/graphql/schema/extensions/annotationExtensions.kt @@ -11,7 +11,7 @@ import graphql.schema.GraphQLArgument import graphql.schema.GraphQLDirective import graphql.schema.GraphQLInputType import kotlin.reflect.KAnnotatedElement -import kotlin.reflect.KClass +import kotlin.reflect.KParameter import kotlin.reflect.full.findAnnotation import com.expedia.graphql.annotations.GraphQLDirective as DirectiveAnnotation @@ -64,11 +64,10 @@ internal fun KAnnotatedElement.isGraphQLIgnored() = this.findAnnotation() != null private fun Annotation.getDirectiveInfo(): DirectiveInfo? { - val directiveAnnotation = this.annotationClass.annotations.find { it is DirectiveAnnotation } as? DirectiveAnnotation - return when { - directiveAnnotation != null -> DirectiveInfo(this.annotationClass.simpleName ?: "", directiveAnnotation) - else -> null - } + return this.annotationClass.annotations + .filterIsInstance(DirectiveAnnotation::class.java) + .map { DirectiveInfo(this, it) } + .firstOrNull() } internal fun KAnnotatedElement.directives(hooks: SchemaGeneratorHooks) = @@ -77,27 +76,32 @@ internal fun KAnnotatedElement.directives(hooks: SchemaGeneratorHooks) = .map { it.getGraphQLDirective(hooks) } .toList() +internal fun KParameter.directives(hooks: SchemaGeneratorHooks) = + this.annotations.asSequence() + .mapNotNull { it.getDirectiveInfo() } + .map { it.getGraphQLDirective(hooks) } + .toList() + @Throws(CouldNotGetNameOfAnnotationException::class) private fun DirectiveInfo.getGraphQLDirective(hooks: SchemaGeneratorHooks): GraphQLDirective { - val kClass: KClass = this.annotation.annotationClass - val builder = GraphQLDirective.newDirective() - val name: String = this.effectiveName ?: throw CouldNotGetNameOfAnnotationException(kClass) + val directiveClass = this.directive.annotationClass + val name: String = this.effectiveName ?: throw CouldNotGetNameOfAnnotationException(directiveClass) @Suppress("Detekt.SpreadOperator") + val builder = GraphQLDirective.newDirective() + .name(name.normalizeDirectiveName()) + .validLocations(*this.directiveAnnotation.locations) + .description(this.directiveAnnotation.description) - builder.name(name.normalizeDirectiveName()) - .validLocations(*this.annotation.locations) - .description(this.annotation.description) + directiveClass.getValidProperties(hooks).forEach { prop -> + val propertyName = prop.name + val value = prop.call(this.directive) - kClass.getValidFunctions(hooks).forEach { kFunction -> - val propertyName = kFunction.name - val value = kFunction.call(kClass) - @Suppress("Detekt.UnsafeCast") - val type = defaultGraphQLScalars(kFunction.returnType) as GraphQLInputType + val type = defaultGraphQLScalars(prop.returnType) ?: hooks.willGenerateGraphQLType(prop.returnType) val argument = GraphQLArgument.newArgument() .name(propertyName) .value(value) - .type(type) + .type(type as? GraphQLInputType) .build() builder.argument(argument) } @@ -107,10 +111,10 @@ private fun DirectiveInfo.getGraphQLDirective(hooks: SchemaGeneratorHooks): Grap private fun String.normalizeDirectiveName() = CaseFormat.UPPER_CAMEL.to(CaseFormat.LOWER_CAMEL, this) -private data class DirectiveInfo(private val name: String, val annotation: DirectiveAnnotation) { +private data class DirectiveInfo(val directive: Annotation, val directiveAnnotation: DirectiveAnnotation) { val effectiveName: String? = when { - annotation.name.isNotEmpty() -> annotation.name - name.isNotEmpty() -> name + directiveAnnotation.name.isNotEmpty() -> directiveAnnotation.name + directive.annotationClass.simpleName.isNullOrEmpty().not() -> directive.annotationClass.simpleName else -> null } } diff --git a/src/main/kotlin/com/expedia/graphql/schema/generator/SchemaGenerator.kt b/src/main/kotlin/com/expedia/graphql/schema/generator/SchemaGenerator.kt index 1a785d2454..6d7c4cfc3a 100644 --- a/src/main/kotlin/com/expedia/graphql/schema/generator/SchemaGenerator.kt +++ b/src/main/kotlin/com/expedia/graphql/schema/generator/SchemaGenerator.kt @@ -150,7 +150,8 @@ internal class SchemaGenerator( val monadType = config.hooks.willResolveMonad(fn.returnType) builder.type(graphQLTypeOf(monadType) as GraphQLOutputType) - return builder.build() + val graphQLType = builder.build() + return config.hooks.onRewireGraphQLType(monadType, graphQLType) as GraphQLFieldDefinition } private fun property(prop: KProperty<*>): GraphQLFieldDefinition { @@ -162,20 +163,33 @@ internal class SchemaGenerator( .type(propertyType) .deprecate(prop.getDeprecationReason()) - return if (config.dataFetcherFactory != null && prop.isLateinit) { + prop.directives(config.hooks).forEach { + fieldBuilder.withDirective(it) + state.directives.add(it) + } + + val field = if (config.dataFetcherFactory != null && prop.isLateinit) { updatePropertyFieldBuilder(propertyType, fieldBuilder, config.dataFetcherFactory) } else { fieldBuilder }.build() + + return config.hooks.onRewireGraphQLType(prop.returnType, field) as GraphQLFieldDefinition } private fun argument(parameter: KParameter): GraphQLArgument { parameter.throwIfUnathorizedInterface() - return GraphQLArgument.newArgument() + val builder = GraphQLArgument.newArgument() .name(parameter.name) .description(parameter.graphQLDescription() ?: parameter.type.graphQLDescription()) .type(graphQLTypeOf(parameter.type, true) as GraphQLInputType) - .build() + + parameter.directives(config.hooks).forEach { + builder.withDirective(it) + state.directives.add(it) + } + + return config.hooks.onRewireGraphQLType(parameter.type, builder.build()) as GraphQLArgument } private fun graphQLTypeOf(type: KType, inputType: Boolean = false, annotatedAsID: Boolean = false): GraphQLType { diff --git a/src/main/kotlin/com/expedia/graphql/schema/generator/directive/DirectiveWiringHelper.kt b/src/main/kotlin/com/expedia/graphql/schema/generator/directive/DirectiveWiringHelper.kt new file mode 100644 index 0000000000..01a6227856 --- /dev/null +++ b/src/main/kotlin/com/expedia/graphql/schema/generator/directive/DirectiveWiringHelper.kt @@ -0,0 +1,90 @@ +package com.expedia.graphql.schema.generator.directive + +import graphql.Assert.assertNotNull +import graphql.schema.GraphQLDirectiveContainer +import graphql.schema.GraphQLFieldDefinition +import graphql.schema.GraphQLObjectType +import graphql.schema.GraphQLInterfaceType +import graphql.schema.GraphQLUnionType +import graphql.schema.GraphQLScalarType +import graphql.schema.GraphQLEnumType +import graphql.schema.GraphQLEnumValueDefinition +import graphql.schema.GraphQLArgument +import graphql.schema.GraphQLInputObjectField +import graphql.schema.GraphQLInputObjectType +import graphql.schema.GraphQLDirective +import graphql.schema.GraphQLType +import graphql.schema.idl.SchemaDirectiveWiring +import graphql.schema.idl.SchemaDirectiveWiringEnvironment +import graphql.schema.idl.SchemaDirectiveWiringEnvironmentImpl +import graphql.schema.idl.WiringFactory + +/** + * Based on + * https://github.com/graphql-java/graphql-java/blob/master/src/main/java/graphql/schema/idl/SchemaGeneratorDirectiveHelper.java + */ +class DirectiveWiringHelper(private val wiringFactory: WiringFactory, private val manualWiring: Map = mutableMapOf()) { + + @Suppress("UNCHECKED_CAST", "Detekt.ComplexMethod") + fun onWire(generatedType: GraphQLType): GraphQLType { + if (generatedType !is GraphQLDirectiveContainer) return generatedType + + return wireDirectives(generatedType, getDirectives(generatedType), + { outputElement, directive -> createWiringEnvironment(outputElement, directive) }, + { wiring, environment -> + when (environment.element) { + is GraphQLObjectType -> wiring.onObject(environment as SchemaDirectiveWiringEnvironment) + is GraphQLFieldDefinition -> wiring.onField(environment as SchemaDirectiveWiringEnvironment) + is GraphQLInterfaceType -> wiring.onInterface(environment as SchemaDirectiveWiringEnvironment) + is GraphQLUnionType -> wiring.onUnion(environment as SchemaDirectiveWiringEnvironment) + is GraphQLScalarType -> wiring.onScalar(environment as SchemaDirectiveWiringEnvironment) + is GraphQLEnumType -> wiring.onEnum(environment as SchemaDirectiveWiringEnvironment) + is GraphQLEnumValueDefinition -> wiring.onEnumValue(environment as SchemaDirectiveWiringEnvironment) + is GraphQLArgument -> wiring.onArgument(environment as SchemaDirectiveWiringEnvironment) + is GraphQLInputObjectType -> wiring.onInputObjectType(environment as SchemaDirectiveWiringEnvironment) + is GraphQLInputObjectField -> wiring.onInputObjectField(environment as SchemaDirectiveWiringEnvironment) + else -> generatedType + } + } + ) + } + + private fun getDirectives(generatedType: GraphQLDirectiveContainer): MutableList { + // A function without directives may still be rewired if the arguments have directives + val directives = generatedType.directives + if (generatedType is GraphQLFieldDefinition) { + generatedType.arguments.forEach { directives.addAll(it.directives) } + } + return directives + } + + private fun createWiringEnvironment(element: T, directive: GraphQLDirective): SchemaDirectiveWiringEnvironment = + SchemaDirectiveWiringEnvironmentImpl(element, directive, null, null, null) + + private fun wireDirectives( + element: T, + directives: List, + envBuilder: (T, GraphQLDirective) -> SchemaDirectiveWiringEnvironment, + invoker: (SchemaDirectiveWiring, SchemaDirectiveWiringEnvironment) -> T + ): T { + var outputObject = element + for (directive in directives) { + val env = envBuilder.invoke(outputObject, directive) + val directiveWiring = discoverWiringProvider(directive.name, env) + if (directiveWiring != null) { + val newElement = invoker.invoke(directiveWiring, env) + assertNotNull(newElement, "The SchemaDirectiveWiring MUST return a non null return value for element '" + element.name + "'") + outputObject = newElement + } + } + return outputObject + } + + private fun discoverWiringProvider(directiveName: String, env: SchemaDirectiveWiringEnvironment): SchemaDirectiveWiring? { + return if (wiringFactory.providesSchemaDirectiveWiring(env)) { + wiringFactory.getSchemaDirectiveWiring(env) + } else { + manualWiring[directiveName] + } + } +} diff --git a/src/main/kotlin/com/expedia/graphql/schema/hooks/SchemaGeneratorHooks.kt b/src/main/kotlin/com/expedia/graphql/schema/hooks/SchemaGeneratorHooks.kt index 18fe0c2bc9..a17f2ebd77 100644 --- a/src/main/kotlin/com/expedia/graphql/schema/hooks/SchemaGeneratorHooks.kt +++ b/src/main/kotlin/com/expedia/graphql/schema/hooks/SchemaGeneratorHooks.kt @@ -53,6 +53,12 @@ interface SchemaGeneratorHooks { @Suppress("Detekt.FunctionOnlyReturningConstant") fun isValidFunction(function: KFunction<*>): Boolean = true + /** + * Called after `willGenerateGraphQLType` and before `didGenerateGraphQLType`. + * Enables you to change the wiring, e.g. directives to alter data fetchers. + */ + fun onRewireGraphQLType(type: KType, generatedType: GraphQLType): GraphQLType = generatedType + /** * Called after wrapping the type based on nullity but before adding the generated type to the schema */ diff --git a/src/test/kotlin/com/expedia/graphql/schema/generator/DirectiveTests.kt b/src/test/kotlin/com/expedia/graphql/schema/generator/DirectiveTests.kt index dd78c8d8a9..6f68921ab0 100644 --- a/src/test/kotlin/com/expedia/graphql/schema/generator/DirectiveTests.kt +++ b/src/test/kotlin/com/expedia/graphql/schema/generator/DirectiveTests.kt @@ -4,6 +4,7 @@ import com.expedia.graphql.TopLevelObjectDef import com.expedia.graphql.annotations.GraphQLDirective import com.expedia.graphql.schema.testSchemaConfig import com.expedia.graphql.toSchema +import graphql.Scalars import graphql.introspection.Introspection import graphql.schema.GraphQLInputObjectType import graphql.schema.GraphQLNonNull @@ -59,41 +60,94 @@ class DirectiveTests { @Test @Suppress("Detekt.UnsafeCast") - fun `SchemaGenerator creates directives`() { + fun `Directive renaming`() { val schema = toSchema(listOf(TopLevelObjectDef(QueryObject())), config = testSchemaConfig) - val geographyType = schema.getType("Geography") as? GraphQLObjectType - assertNotNull(geographyType?.getDirective("whatever")) - assertNotNull(geographyType?.getFieldDefinition("somethingCool")?.getDirective("directiveOnFunction")) - assertNotNull((schema.getType("Location") as? GraphQLObjectType)?.getDirective("renamedDirective")) - assertNotNull(schema.getDirective("whatever")) - assertNotNull(schema.getDirective("renamedDirective")) - val directiveOnFunction = schema.getDirective("directiveOnFunction") - assertNotNull(directiveOnFunction) + val renamedDirective = assertNotNull( + (schema.getType("Location") as? GraphQLObjectType) + ?.getDirective("rightNameDirective") + ) + + assertEquals("arenaming", renamedDirective.arguments[0].value) + assertEquals("arg", renamedDirective.arguments[0].name) + assertEquals(Scalars.GraphQLString, renamedDirective.arguments[0].type) + } + + @Test + @Suppress("Detekt.UnsafeCast") + fun `Directives on classes`() { + val schema = toSchema(listOf(TopLevelObjectDef(QueryObject())), config = testSchemaConfig) + + val directive = assertNotNull( + (schema.getType("Geography") as? GraphQLObjectType) + ?.getDirective("onClassDirective") + ) + + assertEquals("aclass", directive.arguments[0].value) + assertEquals("arg", directive.arguments[0].name) + assertEquals(Scalars.GraphQLString, directive.arguments[0].type) + } + + @Test + @Suppress("Detekt.UnsafeCast") + fun `Directives on functions`() { + val schema = toSchema(listOf(TopLevelObjectDef(QueryObject())), config = testSchemaConfig) + + val directive = assertNotNull( + (schema.getType("Geography") as? GraphQLObjectType) + ?.getFieldDefinition("somethingCool") + ?.getDirective("onFunctionDirective") + ) + + assertEquals("afunction", directive.arguments[0].value) + assertEquals("arg", directive.arguments[0].name) + assertEquals(Scalars.GraphQLString, directive.arguments[0].type) + + assertNotNull(directive) assertEquals( - directiveOnFunction.validLocations()?.toSet(), + directive.validLocations()?.toSet(), setOf(Introspection.DirectiveLocation.FIELD_DEFINITION, Introspection.DirectiveLocation.FIELD) ) } + + @Test + @Suppress("Detekt.UnsafeCast") + fun `Directives on arguments`() { + val schema = toSchema(listOf(TopLevelObjectDef(QueryObject())), config = testSchemaConfig) + + val directive = assertNotNull( + schema.queryType + .getFieldDefinition("query") + .getArgument("value") + .getDirective("onArgumentDirective") + ) + + assertEquals("anargument", directive.arguments[0].value) + assertEquals("arg", directive.arguments[0].name) + assertEquals(Scalars.GraphQLString, directive.arguments[0].type) + } } +@GraphQLDirective(name = "RightNameDirective") +annotation class WrongNameDirective(val arg: String) + @GraphQLDirective -annotation class Whatever +annotation class OnClassDirective(val arg: String) -@GraphQLDirective(locations = [Introspection.DirectiveLocation.FIELD_DEFINITION, Introspection.DirectiveLocation.FIELD]) -annotation class DirectiveOnFunction +@GraphQLDirective +annotation class OnArgumentDirective(val arg: String) -@GraphQLDirective(name = "RenamedDirective") -annotation class RenamedDirective(val x: Boolean) +@GraphQLDirective(locations = [Introspection.DirectiveLocation.FIELD_DEFINITION, Introspection.DirectiveLocation.FIELD]) +annotation class OnFunctionDirective(val arg: String) -@Whatever +@OnClassDirective(arg = "aclass") class Geography( val id: Int?, val type: GeoType, val locations: List ) { @Suppress("Detekt.FunctionOnlyReturningConstant") - @DirectiveOnFunction + @OnFunctionDirective(arg = "afunction") fun somethingCool(): String = "Something cool" } @@ -101,11 +155,11 @@ enum class GeoType { CITY, STATE } -@RenamedDirective(x = false) +@WrongNameDirective(arg = "arenaming") data class Location(val lat: Double, val lon: Double) class QueryObject { - fun query(value: Int): Geography = Geography(value, GeoType.CITY, listOf()) + fun query(@OnArgumentDirective(arg = "anargument") value: Int): Geography = Geography(value, GeoType.CITY, listOf()) } class QueryWithDeprecatedFields {