Skip to content

Commit

Permalink
Additional SchemaGeneratorHook that allows modifying GraphQLTypes. (#69)
Browse files Browse the repository at this point in the history
Additional SchemaGeneratorHook that enables rewiring based on the directives. Closes #60
  • Loading branch information
d4rken authored and dariuszkuc committed Nov 19, 2018
1 parent 11f2e9a commit d79e79c
Show file tree
Hide file tree
Showing 6 changed files with 213 additions and 45 deletions.
2 changes: 1 addition & 1 deletion pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@

<properties>
<reflections.version>0.9.11</reflections.version>
<kotlin.version>1.2.71</kotlin.version>
<kotlin.version>1.3.10</kotlin.version>
<kotlin-ktlint.version>0.29.0</kotlin-ktlint.version>
<kotlin-detekt.version>1.0.0.RC8</kotlin-detekt.version>
<mockk.version>1.8.9.kotlin13</mockk.version>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import graphql.schema.GraphQLArgument
import graphql.schema.GraphQLDirective
import graphql.schema.GraphQLInputType
import kotlin.reflect.KAnnotatedElement
import kotlin.reflect.KClass
import kotlin.reflect.KParameter
import kotlin.reflect.full.findAnnotation
import com.expedia.graphql.annotations.GraphQLDirective as DirectiveAnnotation

Expand Down Expand Up @@ -64,11 +64,10 @@ internal fun KAnnotatedElement.isGraphQLIgnored() = this.findAnnotation<GraphQLI
internal fun KAnnotatedElement.isGraphQLID() = this.findAnnotation<GraphQLID>() != null

private fun Annotation.getDirectiveInfo(): DirectiveInfo? {
val directiveAnnotation = this.annotationClass.annotations.find { it is DirectiveAnnotation } as? DirectiveAnnotation
return when {
directiveAnnotation != null -> DirectiveInfo(this.annotationClass.simpleName ?: "", directiveAnnotation)
else -> null
}
return this.annotationClass.annotations
.filterIsInstance(DirectiveAnnotation::class.java)
.map { DirectiveInfo(this, it) }
.firstOrNull()
}

internal fun KAnnotatedElement.directives(hooks: SchemaGeneratorHooks) =
Expand All @@ -77,27 +76,32 @@ internal fun KAnnotatedElement.directives(hooks: SchemaGeneratorHooks) =
.map { it.getGraphQLDirective(hooks) }
.toList()

internal fun KParameter.directives(hooks: SchemaGeneratorHooks) =
this.annotations.asSequence()
.mapNotNull { it.getDirectiveInfo() }
.map { it.getGraphQLDirective(hooks) }
.toList()

@Throws(CouldNotGetNameOfAnnotationException::class)
private fun DirectiveInfo.getGraphQLDirective(hooks: SchemaGeneratorHooks): GraphQLDirective {
val kClass: KClass<out DirectiveAnnotation> = this.annotation.annotationClass
val builder = GraphQLDirective.newDirective()
val name: String = this.effectiveName ?: throw CouldNotGetNameOfAnnotationException(kClass)
val directiveClass = this.directive.annotationClass
val name: String = this.effectiveName ?: throw CouldNotGetNameOfAnnotationException(directiveClass)

@Suppress("Detekt.SpreadOperator")
val builder = GraphQLDirective.newDirective()
.name(name.normalizeDirectiveName())
.validLocations(*this.directiveAnnotation.locations)
.description(this.directiveAnnotation.description)

builder.name(name.normalizeDirectiveName())
.validLocations(*this.annotation.locations)
.description(this.annotation.description)
directiveClass.getValidProperties(hooks).forEach { prop ->
val propertyName = prop.name
val value = prop.call(this.directive)

kClass.getValidFunctions(hooks).forEach { kFunction ->
val propertyName = kFunction.name
val value = kFunction.call(kClass)
@Suppress("Detekt.UnsafeCast")
val type = defaultGraphQLScalars(kFunction.returnType) as GraphQLInputType
val type = defaultGraphQLScalars(prop.returnType) ?: hooks.willGenerateGraphQLType(prop.returnType)
val argument = GraphQLArgument.newArgument()
.name(propertyName)
.value(value)
.type(type)
.type(type as? GraphQLInputType)
.build()
builder.argument(argument)
}
Expand All @@ -107,10 +111,10 @@ private fun DirectiveInfo.getGraphQLDirective(hooks: SchemaGeneratorHooks): Grap

private fun String.normalizeDirectiveName() = CaseFormat.UPPER_CAMEL.to(CaseFormat.LOWER_CAMEL, this)

private data class DirectiveInfo(private val name: String, val annotation: DirectiveAnnotation) {
private data class DirectiveInfo(val directive: Annotation, val directiveAnnotation: DirectiveAnnotation) {
val effectiveName: String? = when {
annotation.name.isNotEmpty() -> annotation.name
name.isNotEmpty() -> name
directiveAnnotation.name.isNotEmpty() -> directiveAnnotation.name
directive.annotationClass.simpleName.isNullOrEmpty().not() -> directive.annotationClass.simpleName
else -> null
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,8 @@ internal class SchemaGenerator(

val monadType = config.hooks.willResolveMonad(fn.returnType)
builder.type(graphQLTypeOf(monadType) as GraphQLOutputType)
return builder.build()
val graphQLType = builder.build()
return config.hooks.onRewireGraphQLType(monadType, graphQLType) as GraphQLFieldDefinition
}

private fun property(prop: KProperty<*>): GraphQLFieldDefinition {
Expand All @@ -162,20 +163,33 @@ internal class SchemaGenerator(
.type(propertyType)
.deprecate(prop.getDeprecationReason())

return if (config.dataFetcherFactory != null && prop.isLateinit) {
prop.directives(config.hooks).forEach {
fieldBuilder.withDirective(it)
state.directives.add(it)
}

val field = if (config.dataFetcherFactory != null && prop.isLateinit) {
updatePropertyFieldBuilder(propertyType, fieldBuilder, config.dataFetcherFactory)
} else {
fieldBuilder
}.build()

return config.hooks.onRewireGraphQLType(prop.returnType, field) as GraphQLFieldDefinition
}

private fun argument(parameter: KParameter): GraphQLArgument {
parameter.throwIfUnathorizedInterface()
return GraphQLArgument.newArgument()
val builder = GraphQLArgument.newArgument()
.name(parameter.name)
.description(parameter.graphQLDescription() ?: parameter.type.graphQLDescription())
.type(graphQLTypeOf(parameter.type, true) as GraphQLInputType)
.build()

parameter.directives(config.hooks).forEach {
builder.withDirective(it)
state.directives.add(it)
}

return config.hooks.onRewireGraphQLType(parameter.type, builder.build()) as GraphQLArgument
}

private fun graphQLTypeOf(type: KType, inputType: Boolean = false, annotatedAsID: Boolean = false): GraphQLType {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
package com.expedia.graphql.schema.generator.directive

import graphql.Assert.assertNotNull
import graphql.schema.GraphQLDirectiveContainer
import graphql.schema.GraphQLFieldDefinition
import graphql.schema.GraphQLObjectType
import graphql.schema.GraphQLInterfaceType
import graphql.schema.GraphQLUnionType
import graphql.schema.GraphQLScalarType
import graphql.schema.GraphQLEnumType
import graphql.schema.GraphQLEnumValueDefinition
import graphql.schema.GraphQLArgument
import graphql.schema.GraphQLInputObjectField
import graphql.schema.GraphQLInputObjectType
import graphql.schema.GraphQLDirective
import graphql.schema.GraphQLType
import graphql.schema.idl.SchemaDirectiveWiring
import graphql.schema.idl.SchemaDirectiveWiringEnvironment
import graphql.schema.idl.SchemaDirectiveWiringEnvironmentImpl
import graphql.schema.idl.WiringFactory

/**
* Based on
* https://github.com/graphql-java/graphql-java/blob/master/src/main/java/graphql/schema/idl/SchemaGeneratorDirectiveHelper.java
*/
class DirectiveWiringHelper(private val wiringFactory: WiringFactory, private val manualWiring: Map<String, SchemaDirectiveWiring> = mutableMapOf()) {

@Suppress("UNCHECKED_CAST", "Detekt.ComplexMethod")
fun onWire(generatedType: GraphQLType): GraphQLType {
if (generatedType !is GraphQLDirectiveContainer) return generatedType

return wireDirectives(generatedType, getDirectives(generatedType),
{ outputElement, directive -> createWiringEnvironment(outputElement, directive) },
{ wiring, environment ->
when (environment.element) {
is GraphQLObjectType -> wiring.onObject(environment as SchemaDirectiveWiringEnvironment<GraphQLObjectType>)
is GraphQLFieldDefinition -> wiring.onField(environment as SchemaDirectiveWiringEnvironment<GraphQLFieldDefinition>)
is GraphQLInterfaceType -> wiring.onInterface(environment as SchemaDirectiveWiringEnvironment<GraphQLInterfaceType>)
is GraphQLUnionType -> wiring.onUnion(environment as SchemaDirectiveWiringEnvironment<GraphQLUnionType>)
is GraphQLScalarType -> wiring.onScalar(environment as SchemaDirectiveWiringEnvironment<GraphQLScalarType>)
is GraphQLEnumType -> wiring.onEnum(environment as SchemaDirectiveWiringEnvironment<GraphQLEnumType>)
is GraphQLEnumValueDefinition -> wiring.onEnumValue(environment as SchemaDirectiveWiringEnvironment<GraphQLEnumValueDefinition>)
is GraphQLArgument -> wiring.onArgument(environment as SchemaDirectiveWiringEnvironment<GraphQLArgument>)
is GraphQLInputObjectType -> wiring.onInputObjectType(environment as SchemaDirectiveWiringEnvironment<GraphQLInputObjectType>)
is GraphQLInputObjectField -> wiring.onInputObjectField(environment as SchemaDirectiveWiringEnvironment<GraphQLInputObjectField>)
else -> generatedType
}
}
)
}

private fun getDirectives(generatedType: GraphQLDirectiveContainer): MutableList<GraphQLDirective> {
// A function without directives may still be rewired if the arguments have directives
val directives = generatedType.directives
if (generatedType is GraphQLFieldDefinition) {
generatedType.arguments.forEach { directives.addAll(it.directives) }
}
return directives
}

private fun <T : GraphQLDirectiveContainer> createWiringEnvironment(element: T, directive: GraphQLDirective): SchemaDirectiveWiringEnvironment<T> =
SchemaDirectiveWiringEnvironmentImpl(element, directive, null, null, null)

private fun <T : GraphQLDirectiveContainer> wireDirectives(
element: T,
directives: List<GraphQLDirective>,
envBuilder: (T, GraphQLDirective) -> SchemaDirectiveWiringEnvironment<T>,
invoker: (SchemaDirectiveWiring, SchemaDirectiveWiringEnvironment<T>) -> T
): T {
var outputObject = element
for (directive in directives) {
val env = envBuilder.invoke(outputObject, directive)
val directiveWiring = discoverWiringProvider(directive.name, env)
if (directiveWiring != null) {
val newElement = invoker.invoke(directiveWiring, env)
assertNotNull(newElement, "The SchemaDirectiveWiring MUST return a non null return value for element '" + element.name + "'")
outputObject = newElement
}
}
return outputObject
}

private fun <T : GraphQLDirectiveContainer> discoverWiringProvider(directiveName: String, env: SchemaDirectiveWiringEnvironment<T>): SchemaDirectiveWiring? {
return if (wiringFactory.providesSchemaDirectiveWiring(env)) {
wiringFactory.getSchemaDirectiveWiring(env)
} else {
manualWiring[directiveName]
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,12 @@ interface SchemaGeneratorHooks {
@Suppress("Detekt.FunctionOnlyReturningConstant")
fun isValidFunction(function: KFunction<*>): Boolean = true

/**
* Called after `willGenerateGraphQLType` and before `didGenerateGraphQLType`.
* Enables you to change the wiring, e.g. directives to alter data fetchers.
*/
fun onRewireGraphQLType(type: KType, generatedType: GraphQLType): GraphQLType = generatedType

/**
* Called after wrapping the type based on nullity but before adding the generated type to the schema
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import com.expedia.graphql.TopLevelObjectDef
import com.expedia.graphql.annotations.GraphQLDirective
import com.expedia.graphql.schema.testSchemaConfig
import com.expedia.graphql.toSchema
import graphql.Scalars
import graphql.introspection.Introspection
import graphql.schema.GraphQLInputObjectType
import graphql.schema.GraphQLNonNull
Expand Down Expand Up @@ -59,53 +60,106 @@ class DirectiveTests {

@Test
@Suppress("Detekt.UnsafeCast")
fun `SchemaGenerator creates directives`() {
fun `Directive renaming`() {
val schema = toSchema(listOf(TopLevelObjectDef(QueryObject())), config = testSchemaConfig)

val geographyType = schema.getType("Geography") as? GraphQLObjectType
assertNotNull(geographyType?.getDirective("whatever"))
assertNotNull(geographyType?.getFieldDefinition("somethingCool")?.getDirective("directiveOnFunction"))
assertNotNull((schema.getType("Location") as? GraphQLObjectType)?.getDirective("renamedDirective"))
assertNotNull(schema.getDirective("whatever"))
assertNotNull(schema.getDirective("renamedDirective"))
val directiveOnFunction = schema.getDirective("directiveOnFunction")
assertNotNull(directiveOnFunction)
val renamedDirective = assertNotNull(
(schema.getType("Location") as? GraphQLObjectType)
?.getDirective("rightNameDirective")
)

assertEquals("arenaming", renamedDirective.arguments[0].value)
assertEquals("arg", renamedDirective.arguments[0].name)
assertEquals(Scalars.GraphQLString, renamedDirective.arguments[0].type)
}

@Test
@Suppress("Detekt.UnsafeCast")
fun `Directives on classes`() {
val schema = toSchema(listOf(TopLevelObjectDef(QueryObject())), config = testSchemaConfig)

val directive = assertNotNull(
(schema.getType("Geography") as? GraphQLObjectType)
?.getDirective("onClassDirective")
)

assertEquals("aclass", directive.arguments[0].value)
assertEquals("arg", directive.arguments[0].name)
assertEquals(Scalars.GraphQLString, directive.arguments[0].type)
}

@Test
@Suppress("Detekt.UnsafeCast")
fun `Directives on functions`() {
val schema = toSchema(listOf(TopLevelObjectDef(QueryObject())), config = testSchemaConfig)

val directive = assertNotNull(
(schema.getType("Geography") as? GraphQLObjectType)
?.getFieldDefinition("somethingCool")
?.getDirective("onFunctionDirective")
)

assertEquals("afunction", directive.arguments[0].value)
assertEquals("arg", directive.arguments[0].name)
assertEquals(Scalars.GraphQLString, directive.arguments[0].type)

assertNotNull(directive)
assertEquals(
directiveOnFunction.validLocations()?.toSet(),
directive.validLocations()?.toSet(),
setOf(Introspection.DirectiveLocation.FIELD_DEFINITION, Introspection.DirectiveLocation.FIELD)
)
}

@Test
@Suppress("Detekt.UnsafeCast")
fun `Directives on arguments`() {
val schema = toSchema(listOf(TopLevelObjectDef(QueryObject())), config = testSchemaConfig)

val directive = assertNotNull(
schema.queryType
.getFieldDefinition("query")
.getArgument("value")
.getDirective("onArgumentDirective")
)

assertEquals("anargument", directive.arguments[0].value)
assertEquals("arg", directive.arguments[0].name)
assertEquals(Scalars.GraphQLString, directive.arguments[0].type)
}
}

@GraphQLDirective(name = "RightNameDirective")
annotation class WrongNameDirective(val arg: String)

@GraphQLDirective
annotation class Whatever
annotation class OnClassDirective(val arg: String)

@GraphQLDirective(locations = [Introspection.DirectiveLocation.FIELD_DEFINITION, Introspection.DirectiveLocation.FIELD])
annotation class DirectiveOnFunction
@GraphQLDirective
annotation class OnArgumentDirective(val arg: String)

@GraphQLDirective(name = "RenamedDirective")
annotation class RenamedDirective(val x: Boolean)
@GraphQLDirective(locations = [Introspection.DirectiveLocation.FIELD_DEFINITION, Introspection.DirectiveLocation.FIELD])
annotation class OnFunctionDirective(val arg: String)

@Whatever
@OnClassDirective(arg = "aclass")
class Geography(
val id: Int?,
val type: GeoType,
val locations: List<Location>
) {
@Suppress("Detekt.FunctionOnlyReturningConstant")
@DirectiveOnFunction
@OnFunctionDirective(arg = "afunction")
fun somethingCool(): String = "Something cool"
}

enum class GeoType {
CITY, STATE
}

@RenamedDirective(x = false)
@WrongNameDirective(arg = "arenaming")
data class Location(val lat: Double, val lon: Double)

class QueryObject {
fun query(value: Int): Geography = Geography(value, GeoType.CITY, listOf())
fun query(@OnArgumentDirective(arg = "anargument") value: Int): Geography = Geography(value, GeoType.CITY, listOf())
}

class QueryWithDeprecatedFields {
Expand Down

0 comments on commit d79e79c

Please sign in to comment.