From 3c6a0854a50d55d40ab77d216b9bf90768df6bce Mon Sep 17 00:00:00 2001 From: Jonathan Schuchart Date: Fri, 5 Jul 2024 14:01:30 +0300 Subject: [PATCH] Fix restriction on maps/collections nesting Signed-off-by: Jonathan Schuchart --- .../flytekitscala/SdkLiteralTypesTest.scala | 10 +- .../flyte/flytekitscala/SdkLiteralTypes.scala | 139 +++++------------- 2 files changed, 42 insertions(+), 107 deletions(-) diff --git a/flytekit-scala-tests/src/test/scala/org/flyte/flytekitscala/SdkLiteralTypesTest.scala b/flytekit-scala-tests/src/test/scala/org/flyte/flytekitscala/SdkLiteralTypesTest.scala index 2b70765c..b1c90419 100644 --- a/flytekit-scala-tests/src/test/scala/org/flyte/flytekitscala/SdkLiteralTypesTest.scala +++ b/flytekit-scala-tests/src/test/scala/org/flyte/flytekitscala/SdkLiteralTypesTest.scala @@ -129,6 +129,10 @@ class TestOfReturnsProperTypeProvider extends ArgumentsProvider { Arguments.of( collections(maps(durations())), of[List[Map[String, Duration]]]() + ), + Arguments.of( + collections(maps(collections(maps(collections(strings()))))), + of[List[Map[String, List[Map[String, List[String]]]]]]() ) ) } @@ -142,11 +146,7 @@ class testOfThrowExceptionsForUnsupportedTypesProvider Stream.of( Arguments .of("java type, must use java factory", () => of[java.lang.Long]()), - Arguments.of("not a supported type", () => of[Object]()), - Arguments.of( - "triple nesting not supported in of", - () => of[List[List[List[Long]]]]() - ) + Arguments.of("not a supported type", () => of[Object]()) ) } } diff --git a/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkLiteralTypes.scala b/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkLiteralTypes.scala index fb128dd5..de245a5b 100644 --- a/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkLiteralTypes.scala +++ b/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkLiteralTypes.scala @@ -34,7 +34,8 @@ import scala.reflect.runtime.universe.{ TypeTag, runtimeMirror, termNames, - typeOf + typeOf, + typeTag } import scala.tools.nsc.doc.model.Trait @@ -72,87 +73,17 @@ object SdkLiteralTypes { blobs(BlobType.DEFAULT).asInstanceOf[SdkLiteralType[T]] case t if t =:= typeOf[Binary] => binary().asInstanceOf[SdkLiteralType[T]] + + case t if t <:< typeOf[List[Any]] => + collections(of()(createTypeTag(typeTag[T].mirror, t.typeArgs.head))) + .asInstanceOf[SdkLiteralType[T]] + case t if t <:< typeOf[Map[String, Any]] => + maps(of()(createTypeTag(typeTag[T].mirror, t.typeArgs.last))) + .asInstanceOf[SdkLiteralType[T]] + case t if t <:< typeOf[Product] && !(t =:= typeOf[Option[_]]) => generics().asInstanceOf[SdkLiteralType[T]] - case t if t =:= typeOf[List[Long]] => - collections(integers()).asInstanceOf[SdkLiteralType[T]] - case t if t =:= typeOf[List[Double]] => - collections(floats()).asInstanceOf[SdkLiteralType[T]] - case t if t =:= typeOf[List[String]] => - collections(strings()).asInstanceOf[SdkLiteralType[T]] - case t if t =:= typeOf[List[Boolean]] => - collections(booleans()).asInstanceOf[SdkLiteralType[T]] - case t if t =:= typeOf[List[Instant]] => - collections(datetimes()).asInstanceOf[SdkLiteralType[T]] - case t if t =:= typeOf[List[Duration]] => - collections(durations()).asInstanceOf[SdkLiteralType[T]] - - case t if t =:= typeOf[Map[String, Long]] => - maps(integers()).asInstanceOf[SdkLiteralType[T]] - case t if t =:= typeOf[Map[String, Double]] => - maps(floats()).asInstanceOf[SdkLiteralType[T]] - case t if t =:= typeOf[Map[String, String]] => - maps(strings()).asInstanceOf[SdkLiteralType[T]] - case t if t =:= typeOf[Map[String, Boolean]] => - maps(booleans()).asInstanceOf[SdkLiteralType[T]] - case t if t =:= typeOf[Map[String, Instant]] => - maps(datetimes()).asInstanceOf[SdkLiteralType[T]] - case t if t =:= typeOf[Map[String, Duration]] => - maps(durations()).asInstanceOf[SdkLiteralType[T]] - - case t if t =:= typeOf[List[List[Long]]] => - collections(collections(integers())).asInstanceOf[SdkLiteralType[T]] - case t if t =:= typeOf[List[List[Double]]] => - collections(collections(floats())).asInstanceOf[SdkLiteralType[T]] - case t if t =:= typeOf[List[List[String]]] => - collections(collections(strings())).asInstanceOf[SdkLiteralType[T]] - case t if t =:= typeOf[List[List[Boolean]]] => - collections(collections(booleans())).asInstanceOf[SdkLiteralType[T]] - case t if t =:= typeOf[List[List[Instant]]] => - collections(collections(datetimes())).asInstanceOf[SdkLiteralType[T]] - case t if t =:= typeOf[List[List[Duration]]] => - collections(collections(durations())).asInstanceOf[SdkLiteralType[T]] - - case t if t =:= typeOf[List[Map[String, Long]]] => - collections(maps(integers())).asInstanceOf[SdkLiteralType[T]] - case t if t =:= typeOf[List[Map[String, Double]]] => - collections(maps(floats())).asInstanceOf[SdkLiteralType[T]] - case t if t =:= typeOf[List[Map[String, String]]] => - collections(maps(strings())).asInstanceOf[SdkLiteralType[T]] - case t if t =:= typeOf[List[Map[String, Boolean]]] => - collections(maps(booleans())).asInstanceOf[SdkLiteralType[T]] - case t if t =:= typeOf[List[Map[String, Instant]]] => - collections(maps(datetimes())).asInstanceOf[SdkLiteralType[T]] - case t if t =:= typeOf[List[Map[String, Duration]]] => - collections(maps(durations())).asInstanceOf[SdkLiteralType[T]] - - case t if t =:= typeOf[Map[String, Map[String, Long]]] => - maps(maps(integers())).asInstanceOf[SdkLiteralType[T]] - case t if t =:= typeOf[Map[String, Map[String, Double]]] => - maps(maps(floats())).asInstanceOf[SdkLiteralType[T]] - case t if t =:= typeOf[Map[String, Map[String, String]]] => - maps(maps(strings())).asInstanceOf[SdkLiteralType[T]] - case t if t =:= typeOf[Map[String, Map[String, Boolean]]] => - maps(maps(booleans())).asInstanceOf[SdkLiteralType[T]] - case t if t =:= typeOf[Map[String, Map[String, Instant]]] => - maps(maps(datetimes())).asInstanceOf[SdkLiteralType[T]] - case t if t =:= typeOf[Map[String, Map[String, Duration]]] => - maps(maps(durations())).asInstanceOf[SdkLiteralType[T]] - - case t if t =:= typeOf[Map[String, List[Long]]] => - maps(collections(integers())).asInstanceOf[SdkLiteralType[T]] - case t if t =:= typeOf[Map[String, List[Double]]] => - maps(collections(floats())).asInstanceOf[SdkLiteralType[T]] - case t if t =:= typeOf[Map[String, List[String]]] => - maps(collections(strings())).asInstanceOf[SdkLiteralType[T]] - case t if t =:= typeOf[Map[String, List[Boolean]]] => - maps(collections(booleans())).asInstanceOf[SdkLiteralType[T]] - case t if t =:= typeOf[Map[String, List[Instant]]] => - maps(collections(datetimes())).asInstanceOf[SdkLiteralType[T]] - case t if t =:= typeOf[Map[String, List[Duration]]] => - maps(collections(durations())).asInstanceOf[SdkLiteralType[T]] - case _ => throw new IllegalArgumentException(s"Unsupported type: ${typeOf[T]}") } @@ -341,7 +272,7 @@ object SdkLiteralTypes { ) } } else if (tpe <:< typeOf[Product]) { - val typeTag = createTypeTag(tpe) + val typeTag = createTypeTag(mirror, tpe) val classTag = ClassTag( typeTag.mirror.runtimeClass(tpe) ) @@ -356,6 +287,7 @@ object SdkLiteralTypes { // In this case, we use the __TYPE field to get the type of the product. case map: Map[String, Any] if map.contains(__TYPE) => val typeTag = createTypeTag( + mirror, mirror .staticClass(map(__TYPE).asInstanceOf[String]) .typeSignature @@ -369,28 +301,6 @@ object SdkLiteralTypes { } } - def createTypeTag[U <: Product](tpe: Type): TypeTag[U] = { - val typSym = mirror.staticClass(tpe.typeSymbol.fullName) - // note: this uses internal API, otherwise we will need to depend on scala-compiler at runtime - val typeRef = - universe.internal.typeRef(NoPrefix, typSym, List.empty) - - TypeTag( - mirror, - new TypeCreator { - override def apply[V <: Universe with Singleton]( - m: Mirror[V] - ): V#Type = { - assert( - m == mirror, - s"TypeTag[$typeRef] defined in $mirror cannot be migrated to $m." - ) - typeRef.asInstanceOf[V#Type] - } - } - ) - } - val clazz = typeOf[S].typeSymbol.asClass val classMirror = mirror.reflectClass(clazz) val constructor = typeOf[S].decl(termNames.CONSTRUCTOR).asMethod @@ -499,6 +409,31 @@ object SdkLiteralTypes { override def toString: String = s"map of [$valuesType]" } + + private def createTypeTag[U]( + mirror: universe.Mirror, + tpe: Type + ): TypeTag[U] = { + val typSym = mirror.staticClass(tpe.typeSymbol.fullName) + // note: this uses internal API, otherwise we will need to depend on scala-compiler at runtime + val typeRef = + universe.internal.typeRef(NoPrefix, typSym, tpe.typeArgs) + + TypeTag( + mirror, + new TypeCreator { + override def apply[V <: Universe with Singleton]( + m: Mirror[V] + ): V#Type = { + assert( + m == mirror, + s"TypeTag[$typeRef] defined in $mirror cannot be migrated to $m." + ) + typeRef.asInstanceOf[V#Type] + } + } + ) + } } private object ScalaLiteralType {