From 8fe1f86ea155b4ecf3ea7386b88b608d07ed5920 Mon Sep 17 00:00:00 2001 From: Yann Simon Date: Wed, 28 Sep 2022 19:09:56 +0200 Subject: [PATCH] API for setting custom directives Prototype for https://github.com/sangria-graphql/sangria/discussions/913 In this prototype, we check how we can apply custom directives much easier. Furthermore, we check if we can prove type safety, to that a custom directive that can be applied on some elements cannot be applied on others. Limitations: - a directive can be applied on a field definition for example. With that current approach, we cannot formulate that with the type system as a field definition lives in sangria.ast, and we only handle sangria.schema types. - we are introducing new types to mark on which elements a directive can be applied. Those types are kind of duplication of the current [sangria.schema.DirectiveLocation values](https://github.com/sangria-graphql/sangria/blob/f339b5df97bd89c2a24fcfc977a1f20191ffd7fc/modules/core/src/main/scala/sangria/schema/Schema.scala#L1136-L1158). --- build.sbt | 5 +- .../src/main/scala/sangria/ast/QueryAst.scala | 2 +- .../sangria/schema/AstSchemaBuilder.scala | 63 +++++--- .../schema/AstSchemaMaterializer.scala | 3 +- .../ResolverBasedAstSchemaBuilder.scala | 4 +- .../main/scala/sangria/schema/Schema.scala | 134 ++++++++++++++--- .../InputDocumentMaterializerSpec.scala | 3 +- .../sangria/schema/CustomDirectiveSpec.scala | 140 ++++++++++++++++++ 8 files changed, 305 insertions(+), 49 deletions(-) create mode 100644 modules/core/src/test/scala/sangria/schema/CustomDirectiveSpec.scala diff --git a/build.sbt b/build.sbt index a4257907..69cb1a2d 100644 --- a/build.sbt +++ b/build.sbt @@ -126,7 +126,10 @@ lazy val core = project ProblemFilters.exclude[MissingTypesProblem]("sangria.schema.Directive$"), ProblemFilters.exclude[MissingTypesProblem]("sangria.schema.MappedAbstractType"), ProblemFilters.exclude[IncompatibleMethTypeProblem]( - "sangria.execution.Resolver.resolveSimpleListValue") + "sangria.execution.Resolver.resolveSimpleListValue"), + ProblemFilters.exclude[DirectMissingMethodProblem]("sangria.schema.Argument.apply"), + ProblemFilters.exclude[DirectMissingMethodProblem]("sangria.schema.Field.subs"), + ProblemFilters.exclude[DirectMissingMethodProblem]("sangria.schema.Field.apply") ), Test / testOptions += Tests.Argument(TestFrameworks.ScalaTest, "-oF"), libraryDependencies ++= Seq( diff --git a/modules/ast/src/main/scala/sangria/ast/QueryAst.scala b/modules/ast/src/main/scala/sangria/ast/QueryAst.scala index 80b81563..0fc53333 100644 --- a/modules/ast/src/main/scala/sangria/ast/QueryAst.scala +++ b/modules/ast/src/main/scala/sangria/ast/QueryAst.scala @@ -325,7 +325,7 @@ sealed trait WithDirectives extends AstNode { case class Directive( name: String, - arguments: Vector[Argument], + arguments: Vector[Argument] = Vector.empty, comments: Vector[Comment] = Vector.empty, location: Option[AstLocation] = None) extends AstNode diff --git a/modules/core/src/main/scala/sangria/schema/AstSchemaBuilder.scala b/modules/core/src/main/scala/sangria/schema/AstSchemaBuilder.scala index 3527ddf1..d0b9fb64 100644 --- a/modules/core/src/main/scala/sangria/schema/AstSchemaBuilder.scala +++ b/modules/core/src/main/scala/sangria/schema/AstSchemaBuilder.scala @@ -324,8 +324,8 @@ class DefaultAstSchemaBuilder[Ctx] extends AstSchemaBuilder[Ctx] { description = definition.flatMap(_.description.map(_.value)), directives = directives, astDirectives = - definition.fold(Vector.empty[ast.Directive])(_.directives) ++ extensions.flatMap( - _.directives), + (definition.fold(Vector.empty[ast.Directive])(_.directives) ++ extensions.flatMap( + _.directives)).asInstanceOf[Vector[ast.Directive with OnSchema]], astNodes = Vector(mat.document) ++ extensions ++ definition.toVector ) @@ -346,7 +346,8 @@ class DefaultAstSchemaBuilder[Ctx] extends AstSchemaBuilder[Ctx] { directives = directives, description = originalSchema.description, validationRules = originalSchema.validationRules, - astDirectives = originalSchema.astDirectives ++ extensions.flatMap(_.directives), + astDirectives = (originalSchema.astDirectives ++ extensions.flatMap(_.directives)) + .asInstanceOf[Vector[ast.Directive with OnSchema]], astNodes = { val (docs, other) = originalSchema.astNodes.partition(_.isInstanceOf[ast.Document]) val newDoc = ast.Document.merge(docs.asInstanceOf[Vector[ast.Document]] :+ mat.document) @@ -374,7 +375,7 @@ class DefaultAstSchemaBuilder[Ctx] extends AstSchemaBuilder[Ctx] { interfaces = interfaces, instanceCheck = (value: Any, clazz: Class[_], _: ObjectType[Ctx, Any]) => fn(value, clazz), - astDirectives = directives, + astDirectives = directives.asInstanceOf[Vector[ast.Directive with OnObjectType]], astNodes = (definition +: extensions).toVector ) case None => @@ -384,7 +385,7 @@ class DefaultAstSchemaBuilder[Ctx] extends AstSchemaBuilder[Ctx] { fieldsFn = fields, interfaces = interfaces, instanceCheck = ObjectType.defaultInstanceCheck[Ctx, Any], - astDirectives = directives, + astDirectives = directives.asInstanceOf[Vector[ast.Directive with OnObjectType]], astNodes = (definition +: extensions).toVector ) } @@ -404,7 +405,8 @@ class DefaultAstSchemaBuilder[Ctx] extends AstSchemaBuilder[Ctx] { existing.copy( fieldsFn = fields, interfaces = interfaces, - astDirectives = existing.astDirectives ++ extensions.flatMap(_.directives), + astDirectives = existing.astDirectives ++ extensions.flatMap( + _.directives.asInstanceOf[Vector[ast.Directive with OnObjectType]]), astNodes = existing.astNodes ++ extensions, instanceCheck = (value: Any, clazz: Class[_], _: ObjectType[Ctx, Any]) => fn(value, clazz) )(ClassTag(existing.valClass)) @@ -412,7 +414,8 @@ class DefaultAstSchemaBuilder[Ctx] extends AstSchemaBuilder[Ctx] { existing.copy( fieldsFn = fields, interfaces = interfaces, - astDirectives = existing.astDirectives ++ extensions.flatMap(_.directives), + astDirectives = existing.astDirectives ++ extensions.flatMap( + _.directives.asInstanceOf[Vector[ast.Directive with OnObjectType]]), astNodes = existing.astNodes ++ extensions, instanceCheck = existing.instanceCheck.asInstanceOf[(Any, Class[_], ObjectType[Ctx, _]) => Boolean] @@ -430,7 +433,8 @@ class DefaultAstSchemaBuilder[Ctx] extends AstSchemaBuilder[Ctx] { name = typeName(definition), description = typeDescription(definition), fieldsFn = fields, - astDirectives = definition.directives ++ extensions.flatMap(_.directives), + astDirectives = (definition.directives ++ extensions.flatMap(_.directives)) + .asInstanceOf[Vector[ast.Directive with OnInputObjectType]], astNodes = definition +: extensions )) @@ -449,7 +453,7 @@ class DefaultAstSchemaBuilder[Ctx] extends AstSchemaBuilder[Ctx] { fieldsFn = fields, interfaces = Nil, manualPossibleTypes = () => Nil, - astDirectives = directives, + astDirectives = directives.asInstanceOf[Vector[ast.Directive with OnInterfaceType]], astNodes = (definition +: extensions).toVector )) } @@ -464,7 +468,9 @@ class DefaultAstSchemaBuilder[Ctx] extends AstSchemaBuilder[Ctx] { fieldsFn = fields, manualPossibleTypes = () => Nil, interfaces = Nil, - astDirectives = existing.astDirectives ++ extensions.flatMap(_.directives), + astDirectives = existing.astDirectives ++ extensions + .flatMap(_.directives) + .asInstanceOf[scala.collection.IterableOnce[ast.Directive with OnInterfaceType]], astNodes = existing.astNodes ++ extensions ) @@ -479,7 +485,8 @@ class DefaultAstSchemaBuilder[Ctx] extends AstSchemaBuilder[Ctx] { name = typeName(definition), description = typeDescription(definition), types = types, - astDirectives = definition.directives ++ extensions.flatMap(_.directives), + astDirectives = (definition.directives ++ extensions.flatMap(_.directives)) + .asInstanceOf[Vector[ast.Directive with OnUnionType]], astNodes = definition +: extensions )) @@ -491,8 +498,10 @@ class DefaultAstSchemaBuilder[Ctx] extends AstSchemaBuilder[Ctx] { mat: AstSchemaMaterializer[Ctx]): UnionType[Ctx] = existing.copy( typesFn = () => types, - astDirectives = existing.astDirectives ++ extensions.flatMap(_.directives), - astNodes = existing.astNodes ++ extensions) + astDirectives = (existing.astDirectives ++ extensions.flatMap(_.directives)) + .asInstanceOf[Vector[ast.Directive with OnUnionType]], + astNodes = existing.astNodes ++ extensions + ) def extendScalarAlias[T, ST]( origin: MatOrigin, @@ -516,7 +525,8 @@ class DefaultAstSchemaBuilder[Ctx] extends AstSchemaBuilder[Ctx] { coerceInput = scalarCoerceInput(definition), complexity = scalarComplexity(definition), scalarInfo = scalarValueInfo(definition), - astDirectives = definition.directives ++ extensions.flatMap(_.directives), + astDirectives = (definition.directives ++ extensions.flatMap(_.directives)) + .asInstanceOf[Vector[ast.Directive with OnScalarType]], astNodes = definition +: extensions )) @@ -531,7 +541,8 @@ class DefaultAstSchemaBuilder[Ctx] extends AstSchemaBuilder[Ctx] { name = typeName(definition), description = typeDescription(definition), values = values, - astDirectives = definition.directives ++ extensions.flatMap(_.directives), + astDirectives = (definition.directives ++ extensions.flatMap(_.directives)) + .asInstanceOf[Vector[ast.Directive with OnEnumType]], astNodes = definition +: extensions )) @@ -547,7 +558,7 @@ class DefaultAstSchemaBuilder[Ctx] extends AstSchemaBuilder[Ctx] { description = enumValueDescription(definition), value = enumValue(typeDefinition, definition), deprecationReason = enumValueDeprecationReason(definition), - astDirectives = definition.directives, + astDirectives = definition.directives.asInstanceOf[Vector[ast.Directive with OnEnumValue]], astNodes = Vector(definition) )) @@ -570,7 +581,7 @@ class DefaultAstSchemaBuilder[Ctx] extends AstSchemaBuilder[Ctx] { deprecationReason = fieldDeprecationReason(definition), complexity = fieldComplexity(typeDefinition, definition), manualPossibleTypes = () => Nil, - astDirectives = definition.directives, + astDirectives = definition.directives.asInstanceOf[Vector[ast.Directive with OnField]], astNodes = Vector(definition) )) @@ -655,7 +666,7 @@ class DefaultAstSchemaBuilder[Ctx] extends AstSchemaBuilder[Ctx] { description = inputFieldDescription(definition), fieldType = tpe, defaultValue = defaultValue, - astDirectives = definition.directives, + astDirectives = definition.directives.asInstanceOf[Vector[ast.Directive with OnInputField]], astNodes = Vector(definition) )) @@ -683,7 +694,7 @@ class DefaultAstSchemaBuilder[Ctx] extends AstSchemaBuilder[Ctx] { argumentType = tpe, defaultValue = defaultValue, fromInput = argumentFromInput(typeDefinition, fieldDefinition, definition), - astDirectives = definition.directives, + astDirectives = definition.directives.asInstanceOf[Vector[ast.Directive with OnArgument]], astNodes = Vector(definition) )) @@ -720,8 +731,10 @@ class DefaultAstSchemaBuilder[Ctx] extends AstSchemaBuilder[Ctx] { mat: AstSchemaMaterializer[Ctx]): InputObjectType[T] = existing.copy( fieldsFn = fields, - astDirectives = existing.astDirectives ++ extensions.flatMap(_.directives), - astNodes = existing.astNodes ++ extensions) + astDirectives = (existing.astDirectives ++ extensions.flatMap(_.directives)) + .asInstanceOf[Vector[ast.Directive with OnInputObjectType]], + astNodes = existing.astNodes ++ extensions + ) def transformEnumType[T]( origin: MatOrigin, @@ -731,7 +744,9 @@ class DefaultAstSchemaBuilder[Ctx] extends AstSchemaBuilder[Ctx] { val dirs = existing.astDirectives ++ extensions.flatMap(_.directives) if (dirs.nonEmpty) - existing.copy(astDirectives = dirs, astNodes = existing.astNodes ++ extensions) + existing.copy( + astDirectives = dirs.asInstanceOf[Vector[ast.Directive with OnEnumType]], + astNodes = existing.astNodes ++ extensions) else existing } @@ -743,7 +758,9 @@ class DefaultAstSchemaBuilder[Ctx] extends AstSchemaBuilder[Ctx] { val dirs = existing.astDirectives ++ extensions.flatMap(_.directives) if (dirs.nonEmpty) - existing.copy(astDirectives = dirs, astNodes = existing.astNodes ++ extensions) + existing.copy( + astDirectives = dirs.asInstanceOf[Vector[ast.Directive with OnScalarType]], + astNodes = existing.astNodes ++ extensions) else existing } diff --git a/modules/core/src/main/scala/sangria/schema/AstSchemaMaterializer.scala b/modules/core/src/main/scala/sangria/schema/AstSchemaMaterializer.scala index 0201662b..85ab7f69 100644 --- a/modules/core/src/main/scala/sangria/schema/AstSchemaMaterializer.scala +++ b/modules/core/src/main/scala/sangria/schema/AstSchemaMaterializer.scala @@ -703,7 +703,8 @@ class AstSchemaMaterializer[Ctx] private ( def extendEnumType(origin: MatOrigin, tpe: EnumType[_]) = { val extensions = findEnumExtensions(tpe.name) val extraValues = extensions.flatMap(_.values) - val extraDirs = extensions.flatMap(_.directives) + val extraDirs = + extensions.flatMap(_.directives).asInstanceOf[Vector[ast.Directive with OnEnumType]] val ev = extraValues.flatMap(buildEnumValue(origin, Right(tpe), _, extensions)) diff --git a/modules/core/src/main/scala/sangria/schema/ResolverBasedAstSchemaBuilder.scala b/modules/core/src/main/scala/sangria/schema/ResolverBasedAstSchemaBuilder.scala index d738f362..6c99c587 100644 --- a/modules/core/src/main/scala/sangria/schema/ResolverBasedAstSchemaBuilder.scala +++ b/modules/core/src/main/scala/sangria/schema/ResolverBasedAstSchemaBuilder.scala @@ -149,8 +149,8 @@ class ResolverBasedAstSchemaBuilder[Ctx](val resolvers: Seq[AstSchemaResolver[Ct description = definition.flatMap(_.description.map(_.value)), directives = directives, astDirectives = - definition.fold(Vector.empty[ast.Directive])(_.directives) ++ extensions.flatMap( - _.directives), + (definition.fold(Vector.empty[ast.Directive])(_.directives) ++ extensions.flatMap( + _.directives)).asInstanceOf[Vector[ast.Directive with OnSchema]], astNodes = Vector(mat.document) ++ extensions ++ definition.toVector, validationRules = SchemaValidationRule.default :+ new ResolvedDirectiveValidationRule( this.directives.filterNot(_.repeatable).map(_.name).toSet) diff --git a/modules/core/src/main/scala/sangria/schema/Schema.scala b/modules/core/src/main/scala/sangria/schema/Schema.scala index d7892a9d..5d849a03 100644 --- a/modules/core/src/main/scala/sangria/schema/Schema.scala +++ b/modules/core/src/main/scala/sangria/schema/Schema.scala @@ -105,6 +105,8 @@ object Named { def isValidName(name: String): Boolean = NameRegexp.pattern.matcher(name).matches() } +trait OnScalarType + /** Defines a GraphQL scalar value type. * * `coerceOutput` is allowed to return following scala values: @@ -135,7 +137,7 @@ case class ScalarType[T]( coerceInput: ast.Value => Either[Violation, T], complexity: Double = 0.0d, scalarInfo: Set[ScalarValueInfo] = Set.empty, - astDirectives: Vector[ast.Directive] = Vector.empty, + astDirectives: Vector[ast.Directive with OnScalarType] = Vector.empty, astNodes: Vector[ast.AstNode] = Vector.empty ) extends InputType[T @@ CoercedScalaResult] with OutputType[T] @@ -143,6 +145,13 @@ case class ScalarType[T]( with NullableType with UnmodifiedType with Named { + + def withDirective(directive: ast.Directive with OnScalarType): ScalarType[T] = + copy(astDirectives = astDirectives :+ directive) + + def withDirectives(directives: (ast.Directive with OnScalarType)*): ScalarType[T] = + copy(astDirectives = astDirectives ++ directives) + def rename(newName: String): this.type = copy(name = newName).asInstanceOf[this.type] def toAst: ast.TypeDefinition = SchemaRenderer.renderType(this) } @@ -214,6 +223,8 @@ sealed trait ObjectLikeType[Ctx, Val] def toAst: ast.TypeDefinition = SchemaRenderer.renderType(this) } +trait OnObjectType + /** GraphQL schema object description. * * Describes a type of object in a GraphQL schema that is presented by a Sangria server. Objects of @@ -235,10 +246,14 @@ case class ObjectType[Ctx, Val: ClassTag]( fieldsFn: () => List[Field[Ctx, Val]], interfaces: List[InterfaceType[Ctx, _]], instanceCheck: (Any, Class[_], ObjectType[Ctx, Val]) => Boolean, - astDirectives: Vector[ast.Directive], + astDirectives: Vector[ast.Directive with OnObjectType], astNodes: Vector[ast.AstNode] ) extends ObjectLikeType[Ctx, Val] { lazy val valClass: Class[_] = implicitly[ClassTag[Val]].runtimeClass + def withDirective(directive: ast.Directive with OnObjectType): ObjectType[Ctx, Val] = + copy(astDirectives = astDirectives :+ directive) + def withDirectives(directives: (ast.Directive with OnObjectType)*): ObjectType[Ctx, Val] = + copy(astDirectives = astDirectives ++ directives) def withInstanceCheck( fn: (Any, Class[_], ObjectType[Ctx, Val]) => Boolean): ObjectType[Ctx, Val] = @@ -367,6 +382,8 @@ object ObjectType { (value, valClass, tpe) => valClass.isAssignableFrom(value.getClass) } +trait OnInterfaceType + /** @param description * A description of this schema element that can be presented to clients of the GraphQL service. */ @@ -376,7 +393,7 @@ case class InterfaceType[Ctx, Val]( fieldsFn: () => List[Field[Ctx, Val]], interfaces: List[InterfaceType[Ctx, _]], manualPossibleTypes: () => List[ObjectType[_, _]], - astDirectives: Vector[ast.Directive], + astDirectives: Vector[ast.Directive with OnInterfaceType] = Vector.empty, astNodes: Vector[ast.AstNode] = Vector.empty ) extends ObjectLikeType[Ctx, Val] with AbstractType { @@ -385,6 +402,12 @@ case class InterfaceType[Ctx, Val]( def withPossibleTypes(possible: () => List[PossibleObject[Ctx, Val]]): InterfaceType[Ctx, Val] = copy(manualPossibleTypes = () => possible().map(_.objectType)) def rename(newName: String): this.type = copy(name = newName).asInstanceOf[this.type] + + def withDirective(directive: ast.Directive with OnInterfaceType): InterfaceType[Ctx, Val] = + copy(astDirectives = astDirectives :+ directive) + + def withDirectives(directives: (ast.Directive with OnInterfaceType)*): InterfaceType[Ctx, Val] = + copy(astDirectives = astDirectives ++ directives) } object InterfaceType { @@ -522,6 +545,8 @@ object PossibleType { create[Abstract, Concrete] } +trait OnUnionType + /** @param description * A description of this schema element that can be presented to clients of the GraphQL service. */ @@ -529,7 +554,7 @@ case class UnionType[Ctx]( name: String, description: Option[String] = None, typesFn: () => List[ObjectType[Ctx, _]], - astDirectives: Vector[ast.Directive] = Vector.empty, + astDirectives: Vector[ast.Directive with OnUnionType] = Vector.empty, astNodes: Vector[ast.AstNode] = Vector.empty) extends OutputType[Any] with CompositeType[Any] @@ -537,6 +562,13 @@ case class UnionType[Ctx]( with NullableType with UnmodifiedType with HasAstInfo { + + def withDirective(directive: ast.Directive with OnUnionType): UnionType[Ctx] = + copy(astDirectives = astDirectives :+ directive) + + def withDirectives(directives: (ast.Directive with OnUnionType)*): UnionType[Ctx] = + copy(astDirectives = astDirectives ++ directives) + def rename(newName: String): this.type = copy(name = newName).asInstanceOf[this.type] def toAst: ast.TypeDefinition = SchemaRenderer.renderType(this) @@ -566,18 +598,20 @@ object UnionType { name: String, description: Option[String], types: List[ObjectType[Ctx, _]], - astDirectives: Vector[ast.Directive]): UnionType[Ctx] = + astDirectives: Vector[ast.Directive with OnUnionType]): UnionType[Ctx] = UnionType(name, description, () => types, astDirectives) def apply[Ctx]( name: String, description: Option[String], types: List[ObjectType[Ctx, _]], - astDirectives: Vector[ast.Directive], + astDirectives: Vector[ast.Directive with OnUnionType], astNodes: Vector[ast.AstNode]): UnionType[Ctx] = UnionType[Ctx](name, description, () => types, astDirectives, astNodes) } +trait OnField + /** @param description * A description of this schema element that can be presented to clients of the GraphQL service. * @param resolve @@ -594,7 +628,7 @@ case class Field[Ctx, Val]( tags: List[FieldTag], complexity: Option[(Ctx, Args, Double) => Double], manualPossibleTypes: () => List[ObjectType[_, _]], - astDirectives: Vector[ast.Directive], + astDirectives: Vector[ast.Directive with OnField], astNodes: Vector[ast.AstNode]) extends Named with HasArguments @@ -604,6 +638,10 @@ case class Field[Ctx, Val]( copy(manualPossibleTypes = () => possible.toList.map(_.objectType)) def withPossibleTypes(possible: () => List[PossibleObject[Ctx, Val]]): Field[Ctx, Val] = copy(manualPossibleTypes = () => possible().map(_.objectType)) + def withDirective(directive: ast.Directive with OnField): Field[Ctx, Val] = + copy(astDirectives = astDirectives :+ directive) + def withDirectives(directives: (ast.Directive with OnField)*): Field[Ctx, Val] = + copy(astDirectives = astDirectives ++ directives) def rename(newName: String): this.type = copy(name = newName).asInstanceOf[this.type] def toAst: ast.FieldDefinition = SchemaRenderer.renderField(this) } @@ -618,8 +656,9 @@ object Field { possibleTypes: => List[PossibleObject[_, _]] = Nil, tags: List[FieldTag] = Nil, complexity: Option[(Ctx, Args, Double) => Double] = None, - deprecationReason: Option[String] = None)(implicit - ev: ValidOutType[Res, Out]): Field[Ctx, Val] = + deprecationReason: Option[String] = None, + astDirectives: Vector[ast.Directive with OnField] = Vector.empty + )(implicit ev: ValidOutType[Res, Out]): Field[Ctx, Val] = Field[Ctx, Val]( name, fieldType, @@ -630,7 +669,7 @@ object Field { tags, complexity, () => possibleTypes.map(_.objectType), - Vector.empty, + astDirectives, Vector.empty) def subs[Ctx, Val, StreamSource, Res, Out]( @@ -642,7 +681,8 @@ object Field { possibleTypes: => List[PossibleObject[_, _]] = Nil, tags: List[FieldTag] = Nil, complexity: Option[(Ctx, Args, Double) => Double] = None, - deprecationReason: Option[String] = None + deprecationReason: Option[String] = None, + astDirectives: Vector[ast.Directive with OnField] = Vector.empty )(implicit stream: SubscriptionStreamLike[StreamSource, Action, Ctx, Res, Out]): Field[Ctx, Val] = { val s = stream.subscriptionStream @@ -660,7 +700,7 @@ object Field { SubscriptionField[stream.StreamSource](s) +: tags, complexity, () => possibleTypes.map(_.objectType), - Vector.empty, + astDirectives, Vector.empty ) } @@ -687,6 +727,8 @@ trait InputValue[T] { def defaultValue: Option[(_, ToInput[_, _])] } +trait OnArgument + /** @param description * A description of this schema element that can be presented to clients of the GraphQL service. */ @@ -696,13 +738,17 @@ case class Argument[T]( description: Option[String], defaultValue: Option[(_, ToInput[_, _])], fromInput: FromInput[_], - astDirectives: Vector[ast.Directive], + astDirectives: Vector[ast.Directive with OnArgument], astNodes: Vector[ast.AstNode]) extends InputValue[T] with Named with HasAstInfo { override def inputValueType: InputType[_] = argumentType override def rename(newName: String): this.type = copy(name = newName).asInstanceOf[this.type] + def withDirective(directive: ast.Directive with OnArgument): Argument[T] = + copy(astDirectives = astDirectives :+ directive) + def withDirectives(directives: (ast.Directive with OnArgument)*): Argument[T] = + copy(astDirectives = astDirectives ++ directives) def toAst: ast.InputValueDefinition = SchemaRenderer.renderArg(this) } @@ -724,7 +770,12 @@ object Argument { Vector.empty, Vector.empty) - def apply[T, Default](name: String, argumentType: InputType[T], defaultValue: Default)(implicit + def apply[T, Default]( + name: String, + argumentType: InputType[T], + defaultValue: Default, + astDirectives: Vector[ast.Directive with OnArgument] = Vector.empty + )(implicit toInput: ToInput[Default, _], fromInput: FromInput[T], res: ArgumentType[T]): Argument[res.Res] = @@ -734,7 +785,7 @@ object Argument { None, Some(defaultValue -> toInput), fromInput, - Vector.empty, + astDirectives, Vector.empty) def apply[T](name: String, argumentType: InputType[T], description: String)(implicit @@ -921,6 +972,8 @@ trait ArgumentTypeLowestPrio { } } +trait OnEnumType + /** @param description * A description of this schema element that can be presented to clients of the GraphQL service. */ @@ -928,7 +981,7 @@ case class EnumType[T]( name: String, description: Option[String] = None, values: List[EnumValue[T]], - astDirectives: Vector[ast.Directive] = Vector.empty, + astDirectives: Vector[ast.Directive with OnEnumType] = Vector.empty, astNodes: Vector[ast.AstNode] = Vector.empty) extends InputType[T @@ CoercedScalaResult] with OutputType[T] @@ -942,6 +995,12 @@ case class EnumType[T]( lazy val byValue: Map[T, EnumValue[T]] = values.groupBy(_.value).map { case (k, v) => (k, v.head) } + def withDirective(directive: ast.Directive with OnEnumType): EnumType[T] = + copy(astDirectives = astDirectives :+ directive) + + def withDirectives(directives: (ast.Directive with OnEnumType)*): EnumType[T] = + copy(astDirectives = astDirectives ++ directives) + def coerceUserInput(value: Any): Either[Violation, (T, Boolean)] = value match { case valueName: String => byName @@ -968,6 +1027,8 @@ case class EnumType[T]( def toAst: ast.TypeDefinition = SchemaRenderer.renderType(this) } +trait OnEnumValue + /** @param description * A description of this schema element that can be presented to clients of the GraphQL service. */ @@ -976,15 +1037,24 @@ case class EnumValue[+T]( description: Option[String] = None, value: T, deprecationReason: Option[String] = None, - astDirectives: Vector[ast.Directive] = Vector.empty, + astDirectives: Vector[ast.Directive with OnEnumValue] = Vector.empty, astNodes: Vector[ast.AstNode] = Vector.empty) extends Named with HasDeprecation with HasAstInfo { + + def withDirective(directive: ast.Directive with OnEnumValue): EnumValue[T] = + copy(astDirectives = astDirectives :+ directive) + + def withDirectives(directives: (ast.Directive with OnEnumValue)*): EnumValue[T] = + copy(astDirectives = astDirectives ++ directives) + def rename(newName: String): this.type = copy(name = newName).asInstanceOf[this.type] def toAst: ast.EnumValueDefinition = SchemaRenderer.renderEnumValue(this) } +trait OnInputObjectType + /** @param description * A description of this schema element that can be presented to clients of the GraphQL service. */ @@ -992,7 +1062,7 @@ case class InputObjectType[T]( name: String, description: Option[String] = None, fieldsFn: () => List[InputField[_]], - astDirectives: Vector[ast.Directive], + astDirectives: Vector[ast.Directive with OnInputObjectType], astNodes: Vector[ast.AstNode] ) extends InputType[T @@ InputObjectResult] with NullableType @@ -1005,6 +1075,12 @@ case class InputObjectType[T]( lazy val fieldsByName: Map[String, InputField[_]] = fields.groupBy(_.name).map { case (k, v) => (k, v.head) }.toMap // required for 2.12 + def withDirective(directive: ast.Directive with OnInputObjectType): InputObjectType[T] = + copy(astDirectives = astDirectives :+ directive) + + def withDirectives(directives: (ast.Directive with OnInputObjectType)*): InputObjectType[T] = + copy(astDirectives = astDirectives ++ directives) + def rename(newName: String): this.type = copy(name = newName).asInstanceOf[this.type] def toAst: ast.TypeDefinition = SchemaRenderer.renderType(this) } @@ -1051,6 +1127,8 @@ trait InputObjectDefaultResultLowPrio { } } +trait OnInputField + /** @param description * A description of this schema element that can be presented to clients of the GraphQL service. */ @@ -1059,11 +1137,18 @@ case class InputField[T]( fieldType: InputType[T], description: Option[String], defaultValue: Option[(_, ToInput[_, _])], - astDirectives: Vector[ast.Directive], + astDirectives: Vector[ast.Directive with OnInputField], astNodes: Vector[ast.AstNode] ) extends InputValue[T] with Named with HasAstInfo { + + def withDirective(directive: ast.Directive with OnInputField): InputField[T] = + copy(astDirectives = astDirectives :+ directive) + + def withDirectives(directives: (ast.Directive with OnInputField)*): InputField[T] = + copy(astDirectives = astDirectives ++ directives) + def inputValueType: InputType[T] = fieldType def rename(newName: String): InputField.this.type = copy(name = newName).asInstanceOf[this.type] def toAst: ast.InputValueDefinition = SchemaRenderer.renderInputField(this) @@ -1220,6 +1305,8 @@ case class Directive( def toAst: ast.DirectiveDefinition = SchemaRenderer.renderDirective(this) } +trait OnSchema + /** GraphQL schema description. * * Describes the schema that is presented by a Sangria server. An instance of this type needs to be @@ -1247,10 +1334,17 @@ case class Schema[Ctx, Val]( override val description: Option[String] = None, directives: List[Directive] = BuiltinDirectives, validationRules: List[SchemaValidationRule] = SchemaValidationRule.default, - override val astDirectives: Vector[ast.Directive] = Vector.empty, + override val astDirectives: Vector[ast.Directive with OnSchema] = Vector.empty, override val astNodes: Vector[ast.AstNode] = Vector.empty) extends HasAstInfo with HasDescription { + + def withDirective(directive: ast.Directive with OnSchema): Schema[Ctx, Val] = + copy(astDirectives = astDirectives :+ directive) + + def withDirectives(directives: (ast.Directive with OnSchema)*): Schema[Ctx, Val] = + copy(astDirectives = astDirectives ++ directives) + def extend( document: ast.Document, builder: AstSchemaBuilder[Ctx] = AstSchemaBuilder.default[Ctx]): Schema[Ctx, Val] = diff --git a/modules/core/src/test/scala/sangria/execution/InputDocumentMaterializerSpec.scala b/modules/core/src/test/scala/sangria/execution/InputDocumentMaterializerSpec.scala index 30d7fc17..6a6e905c 100644 --- a/modules/core/src/test/scala/sangria/execution/InputDocumentMaterializerSpec.scala +++ b/modules/core/src/test/scala/sangria/execution/InputDocumentMaterializerSpec.scala @@ -67,7 +67,8 @@ class InputDocumentMaterializerSpec extends AnyWordSpec with Matchers with Strin coerceInput = v => Right(v), complexity = scalarComplexity(definition), scalarInfo = scalarValueInfo(definition), - astDirectives = definition.directives + astDirectives = + definition.directives.asInstanceOf[Vector[ast.Directive with OnScalarType]] )) else super.buildScalarType(origin, extensions, definition, mat) diff --git a/modules/core/src/test/scala/sangria/schema/CustomDirectiveSpec.scala b/modules/core/src/test/scala/sangria/schema/CustomDirectiveSpec.scala new file mode 100644 index 00000000..14848bd3 --- /dev/null +++ b/modules/core/src/test/scala/sangria/schema/CustomDirectiveSpec.scala @@ -0,0 +1,140 @@ +package sangria.schema + +import org.scalatest.matchers.should.Matchers +import org.scalatest.wordspec.AnyWordSpec +import sangria.ast +import sangria.renderer.QueryRenderer + +class CustomDirectiveSpec extends AnyWordSpec with Matchers { + + case class Domain(value: Int) + + private val AllDirective = new ast.Directive("field-directive") + with OnField + with OnArgument + with OnObjectType + with OnInterfaceType + + private val FieldDirective = new ast.Directive("field-directive") with OnField + + private val ArgumentDirective = new ast.Directive("arg-directive") with OnArgument + + private val ObjectDirective = new ast.Directive("object-directive") with OnObjectType + + private val InterfaceDirective = new ast.Directive("interface-directive") with OnInterfaceType + + private val CustomDirective = ast.Directive("custom-directive") + + private val resolve: Context[Unit, Domain] => Action[Unit, Int] = _.value.value + + "custom directive" when { + "in context of a Field" should { + "be applied if marked with OnField" in { + fields[Unit, Domain]( + Field("field", IntType, resolve = resolve, astDirectives = Vector(FieldDirective))) + + fields[Unit, Domain]( + Field("field", IntType, resolve = resolve, astDirectives = Vector(AllDirective))) + + fields[Unit, Domain]( + Field( + "field", + IntType, + resolve = resolve, + astDirectives = Vector(FieldDirective, AllDirective))) + + Field("field", IntType, resolve = resolve, astDirectives = Vector(FieldDirective)): Field[ + Unit, + Domain] + + val field = (Field("field", IntType, resolve = resolve): Field[Unit, Domain]) + .withDirective(FieldDirective) + .withDirectives(FieldDirective, AllDirective) + field.astDirectives should be(Vector(FieldDirective, FieldDirective, AllDirective)) + } + + "not be applied if not marked with OnField" in { + assertTypeError(""" + |fields[Unit, Domain]( + | Field("field", IntType, resolve = resolve, astDirectives = Vector(CustomDirective))) + |""".stripMargin) + + assertTypeError(""" + |val field: Field[Unit, Domain] = + | Field("field", IntType, resolve = resolve, astDirectives = Vector(CustomDirective)) + |""".stripMargin) + } + + "be combined with the @deprecated directive" in { + val field = (Field( + "field", + IntType, + resolve = resolve, + deprecationReason = Some("use field2")): Field[Unit, Domain]) + .withDirective(FieldDirective) + + field.astDirectives should be(Vector(FieldDirective)) + QueryRenderer.renderPretty(field.toAst) should equal( + """field: Int! @field-directive @deprecated(reason: "use field2")""") + } + } + } + + "in context of an Argument" should { + "be applied if marked with OnArgument" in { + Argument("name", IntType, 42, astDirectives = Vector(ArgumentDirective)) + Argument("name", IntType, 42, astDirectives = Vector(AllDirective)) + Argument("name", IntType, 42).withDirective(ArgumentDirective) + val arg = Argument("name", IntType, 42) + .withDirective(AllDirective) + .withDirectives(ArgumentDirective, ArgumentDirective) + arg.astDirectives should be(Vector(AllDirective, ArgumentDirective, ArgumentDirective)) + } + + "not be applied if not marked with OnArgument" in { + assertTypeError(""" + |Argument("name", IntType, 42, astDirectives = Vector(FieldDirective)) + |""".stripMargin) + assertTypeError(""" + |Argument("name", IntType, 42).withDirective(FieldDirective) + |""".stripMargin) + } + } + + "in context of an ObjectType" should { + "be applied if marked with OnObjectType" in { + val obj = ObjectType[Unit, Domain]("name", fields[Unit, Domain]()) + .withDirective(ObjectDirective) + .withDirectives(AllDirective, ObjectDirective) + obj.astDirectives should be(Vector(ObjectDirective, AllDirective, ObjectDirective)) + } + + "not be applied if not marked with OnObjectType" in { + assertTypeError(""" + |ObjectType[Unit, Domain]("name", fields[Unit, Domain]()).withDirective(CustomDirective) + |""".stripMargin) + assertTypeError(""" + |ObjectType[Unit, Domain]("name", fields[Unit, Domain]()).withDirective(FieldDirective) + |""".stripMargin) + } + } + + "in context of an InterfaceType" should { + "be applied if marked with OnInterfaceType" in { + val interface = InterfaceType[Unit, Domain]("name", fields[Unit, Domain]()) + .withDirective(InterfaceDirective) + .withDirectives(AllDirective, InterfaceDirective) + interface.astDirectives should be( + Vector(InterfaceDirective, AllDirective, InterfaceDirective)) + } + + "not be applied if not marked with OnObjectType" in { + assertTypeError(""" + |InterfaceType[Unit, Domain]("name", fields[Unit, Domain]()).withDirective(CustomDirective) + |""".stripMargin) + assertTypeError(""" + |InterfaceType[Unit, Domain]("name", fields[Unit, Domain]()).withDirective(FieldDirective) + |""".stripMargin) + } + } +}