From 246ec9819c967d3b778e00b973d527846ba4a250 Mon Sep 17 00:00:00 2001 From: Alejandro Serrano Date: Mon, 3 Apr 2023 11:47:44 +0200 Subject: [PATCH] Backport creation of Isos for value classes (#3021) --- .../arrow/optics/plugin/DeclarationUtils.kt | 3 ++ .../arrow/optics/plugin/OpticsProcessor.kt | 2 +- .../arrow/optics/plugin/internals/domain.kt | 5 +++ .../arrow/optics/plugin/internals/dsl.kt | 44 +++++++++++++++++++ .../arrow/optics/plugin/internals/errors.kt | 4 +- .../arrow/optics/plugin/internals/isos.kt | 7 ++- .../optics/plugin/internals/processor.kt | 34 +++++++++++--- .../arrow/optics/plugin/internals/snippets.kt | 1 + .../kotlin/arrow/optics/plugin/IsoTests.kt | 27 ++++++++++++ 9 files changed, 115 insertions(+), 12 deletions(-) diff --git a/arrow-libs/optics/arrow-optics-ksp-plugin/src/main/kotlin/arrow/optics/plugin/DeclarationUtils.kt b/arrow-libs/optics/arrow-optics-ksp-plugin/src/main/kotlin/arrow/optics/plugin/DeclarationUtils.kt index a126ce33857..bf2e2946650 100644 --- a/arrow-libs/optics/arrow-optics-ksp-plugin/src/main/kotlin/arrow/optics/plugin/DeclarationUtils.kt +++ b/arrow-libs/optics/arrow-optics-ksp-plugin/src/main/kotlin/arrow/optics/plugin/DeclarationUtils.kt @@ -11,3 +11,6 @@ val KSClassDeclaration.isSealed val KSClassDeclaration.isData get() = modifiers.contains(Modifier.DATA) + +val KSClassDeclaration.isValue + get() = modifiers.contains(Modifier.VALUE) diff --git a/arrow-libs/optics/arrow-optics-ksp-plugin/src/main/kotlin/arrow/optics/plugin/OpticsProcessor.kt b/arrow-libs/optics/arrow-optics-ksp-plugin/src/main/kotlin/arrow/optics/plugin/OpticsProcessor.kt index 8571b26df52..a19014bcafb 100644 --- a/arrow-libs/optics/arrow-optics-ksp-plugin/src/main/kotlin/arrow/optics/plugin/OpticsProcessor.kt +++ b/arrow-libs/optics/arrow-optics-ksp-plugin/src/main/kotlin/arrow/optics/plugin/OpticsProcessor.kt @@ -32,7 +32,7 @@ class OpticsProcessor(private val codegen: CodeGenerator, private val logger: KS private fun processClass(klass: KSClassDeclaration) { // check that it is sealed or data - if (!klass.isSealed && !klass.isData) { + if (!klass.isSealed && !klass.isData && !klass.isValue) { logger.error(klass.qualifiedNameOrSimpleName.otherClassTypeErrorMessage, klass) return } diff --git a/arrow-libs/optics/arrow-optics-ksp-plugin/src/main/kotlin/arrow/optics/plugin/internals/domain.kt b/arrow-libs/optics/arrow-optics-ksp-plugin/src/main/kotlin/arrow/optics/plugin/internals/domain.kt index 69d4f052206..159358715d1 100644 --- a/arrow-libs/optics/arrow-optics-ksp-plugin/src/main/kotlin/arrow/optics/plugin/internals/domain.kt +++ b/arrow-libs/optics/arrow-optics-ksp-plugin/src/main/kotlin/arrow/optics/plugin/internals/domain.kt @@ -50,6 +50,8 @@ typealias SealedClassDsl = Target.SealedClassDsl typealias DataClassDsl = Target.DataClassDsl +typealias ValueClassDsl = Target.ValueClassDsl + sealed class Target { abstract val foci: List @@ -59,6 +61,9 @@ sealed class Target { data class Optional(override val foci: List) : Target() data class SealedClassDsl(override val foci: List) : Target() data class DataClassDsl(override val foci: List) : Target() + data class ValueClassDsl(val focus: Focus) : Target() { + override val foci = listOf(focus) + } } typealias NonNullFocus = Focus.NonNull diff --git a/arrow-libs/optics/arrow-optics-ksp-plugin/src/main/kotlin/arrow/optics/plugin/internals/dsl.kt b/arrow-libs/optics/arrow-optics-ksp-plugin/src/main/kotlin/arrow/optics/plugin/internals/dsl.kt index 35aca64ed08..4ff9a1dd189 100644 --- a/arrow-libs/optics/arrow-optics-ksp-plugin/src/main/kotlin/arrow/optics/plugin/internals/dsl.kt +++ b/arrow-libs/optics/arrow-optics-ksp-plugin/src/main/kotlin/arrow/optics/plugin/internals/dsl.kt @@ -32,6 +32,16 @@ fun generatePrismDsl(ele: ADT, isoOptic: SealedClassDsl): Snippet { ) } +fun generateIsoDsl(ele: ADT, isoOptic: ValueClassDsl): Snippet { + val (className, import) = resolveClassName(ele) + return Snippet( + `package` = ele.packageName, + name = ele.simpleName, + content = processIsoSyntax(ele, isoOptic, className), + imports = setOf(import) + ) +} + private fun processLensSyntax(ele: ADT, foci: List, className: String): String { return if (ele.typeParameters.isEmpty()) { foci.joinToString(separator = "\n") { focus -> @@ -137,6 +147,40 @@ private fun processPrismSyntax(ele: ADT, dsl: SealedClassDsl, className: String) } } +private fun processIsoSyntax(ele: ADT, dsl: ValueClassDsl, className: String): String = + if (ele.typeParameters.isEmpty()) { + dsl.foci.joinToString(separator = "\n\n") { focus -> + """ + |${ele.visibilityModifierName} inline val $Iso.${focus.paramName}: $Iso inline get() = this + ${className}.${focus.paramName} + |${ele.visibilityModifierName} inline val $Lens.${focus.paramName}: $Lens inline get() = this + ${className}.${focus.paramName} + |${ele.visibilityModifierName} inline val $Optional.${focus.paramName}: $Optional inline get() = this + ${className}.${focus.paramName} + |${ele.visibilityModifierName} inline val $Prism.${focus.paramName}: $Prism inline get() = this + ${className}.${focus.paramName} + |${ele.visibilityModifierName} inline val $Setter.${focus.paramName}: $Setter inline get() = this + ${className}.${focus.paramName} + |${ele.visibilityModifierName} inline val $Traversal.${focus.paramName}: $Traversal inline get() = this + ${className}.${focus.paramName} + |${ele.visibilityModifierName} inline val $Fold.${focus.paramName}: $Fold inline get() = this + ${className}.${focus.paramName} + |${ele.visibilityModifierName} inline val $Every.${focus.paramName}: $Every inline get() = this + ${className}.${focus.paramName} + |""".trimMargin() + } + } else { + dsl.foci.joinToString(separator = "\n\n") { focus -> + val sourceClassNameWithParams = focus.refinedType?.qualifiedString() ?: "${ele.sourceClassName}${ele.angledTypeParameters}" + val joinedTypeParams = when { + focus.refinedArguments.isEmpty() -> "" + else -> focus.refinedArguments.joinToString(separator=",") + } + """ + |${ele.visibilityModifierName} inline fun $Iso.${focus.paramName}(): $Iso = this + ${className}.${focus.paramName}() + |${ele.visibilityModifierName} inline fun $Lens.${focus.paramName}(): $Lens = this + ${className}.${focus.paramName}() + |${ele.visibilityModifierName} inline fun $Optional.${focus.paramName}(): $Optional = this + ${className}.${focus.paramName}() + |${ele.visibilityModifierName} inline fun $Prism.${focus.paramName}(): $Prism = this + ${className}.${focus.paramName}() + |${ele.visibilityModifierName} inline fun $Setter.${focus.paramName}(): $Setter = this + ${className}.${focus.paramName}() + |${ele.visibilityModifierName} inline fun $Traversal.${focus.paramName}(): $Traversal = this + ${className}.${focus.paramName}() + |${ele.visibilityModifierName} inline fun $Fold.${focus.paramName}(): $Fold = this + ${className}.${focus.paramName}() + |${ele.visibilityModifierName} inline fun $Every.${focus.paramName}(): $Every = this + ${className}.${focus.paramName}() + |""".trimMargin() + } + } + private fun resolveClassName(ele: ADT): Pair = if (hasPackageCollisions(ele)) { val classNameAlias = ele.sourceClassName.replace(".", "") val aliasImport = "import ${ele.sourceClassName} as $classNameAlias" diff --git a/arrow-libs/optics/arrow-optics-ksp-plugin/src/main/kotlin/arrow/optics/plugin/internals/errors.kt b/arrow-libs/optics/arrow-optics-ksp-plugin/src/main/kotlin/arrow/optics/plugin/internals/errors.kt index 1c9b65fd979..7d634ac6e21 100644 --- a/arrow-libs/optics/arrow-optics-ksp-plugin/src/main/kotlin/arrow/optics/plugin/internals/errors.kt +++ b/arrow-libs/optics/arrow-optics-ksp-plugin/src/main/kotlin/arrow/optics/plugin/internals/errors.kt @@ -5,7 +5,7 @@ val String.otherClassTypeErrorMessage """ |$this cannot be annotated with @optics | ^ - |Only data and sealed classes can be annotated with @optics""".trimMargin() + |Only data, sealed, and value classes can be annotated with @optics""".trimMargin() val String.typeParametersErrorMessage get() = @@ -47,7 +47,7 @@ val String.isoErrorMessage |Cannot generate arrow.optics.Iso for $this | ^ |arrow.optics.OpticsTarget.ISO is an invalid @optics argument for $this. - |It is only valid for data classes. + |It is only valid for data and value classes. """.trimMargin() val String.isoTooBigErrorMessage diff --git a/arrow-libs/optics/arrow-optics-ksp-plugin/src/main/kotlin/arrow/optics/plugin/internals/isos.kt b/arrow-libs/optics/arrow-optics-ksp-plugin/src/main/kotlin/arrow/optics/plugin/internals/isos.kt index dce33e5b3e0..64a89d6a746 100644 --- a/arrow-libs/optics/arrow-optics-ksp-plugin/src/main/kotlin/arrow/optics/plugin/internals/isos.kt +++ b/arrow-libs/optics/arrow-optics-ksp-plugin/src/main/kotlin/arrow/optics/plugin/internals/isos.kt @@ -1,5 +1,7 @@ package arrow.optics.plugin.internals +import arrow.optics.plugin.isValue + internal fun generateIsos(ele: ADT, target: IsoTarget) = Snippet(`package` = ele.packageName, name = ele.simpleName, content = processElement(ele, target)) @@ -63,12 +65,13 @@ private fun processElement(iso: ADT, target: Target): String { "tuple: ${focusType()} -> ${(foci.indices).joinToString(prefix = "${iso.sourceClassName}(", postfix = ")", transform = { "tuple.${letters[it]}" })}" } + val isoName = if (iso.declaration.isValue) target.foci.first().paramName else "iso" val sourceClassNameWithParams = "${iso.sourceClassName}${iso.angledTypeParameters}" val firstLine = when { iso.typeParameters.isEmpty() -> - "${iso.visibilityModifierName} inline val ${iso.sourceClassName}.Companion.iso: $Iso<${iso.sourceClassName}, ${focusType()}> inline get()" + "${iso.visibilityModifierName} inline val ${iso.sourceClassName}.Companion.$isoName: $Iso<${iso.sourceClassName}, ${focusType()}> inline get()" else -> - "${iso.visibilityModifierName} inline fun ${iso.angledTypeParameters} ${iso.sourceClassName}.Companion.iso(): $Iso<$sourceClassNameWithParams, ${focusType()}>" + "${iso.visibilityModifierName} inline fun ${iso.angledTypeParameters} ${iso.sourceClassName}.Companion.$isoName(): $Iso<$sourceClassNameWithParams, ${focusType()}>" } return """ diff --git a/arrow-libs/optics/arrow-optics-ksp-plugin/src/main/kotlin/arrow/optics/plugin/internals/processor.kt b/arrow-libs/optics/arrow-optics-ksp-plugin/src/main/kotlin/arrow/optics/plugin/internals/processor.kt index 4ef66b4dce9..9ef7f4f2809 100644 --- a/arrow-libs/optics/arrow-optics-ksp-plugin/src/main/kotlin/arrow/optics/plugin/internals/processor.kt +++ b/arrow-libs/optics/arrow-optics-ksp-plugin/src/main/kotlin/arrow/optics/plugin/internals/processor.kt @@ -2,6 +2,7 @@ package arrow.optics.plugin.internals import arrow.optics.plugin.isData import arrow.optics.plugin.isSealed +import arrow.optics.plugin.isValue import com.google.devtools.ksp.processing.KSPLogger import com.google.devtools.ksp.symbol.* import java.util.Locale @@ -18,8 +19,12 @@ internal fun adt(c: KSClassDeclaration, logger: KSPLogger): ADT = OpticsTarget.OPTIONAL -> evalAnnotatedDataClass(c, c.qualifiedNameOrSimpleName.optionalErrorMessage, logger) .let(::OptionalTarget) - OpticsTarget.ISO -> evalAnnotatedIsoElement(c, logger).let(::IsoTarget) - OpticsTarget.PRISM -> evalAnnotatedPrismElement(c, logger).let(::PrismTarget) + OpticsTarget.ISO -> + evalAnnotatedIsoElement(c, c.qualifiedNameOrSimpleName.isoErrorMessage, logger) + .let(::IsoTarget) + OpticsTarget.PRISM -> + evalAnnotatedPrismElement(c, c.qualifiedNameOrSimpleName.prismErrorMessage, logger) + .let(::PrismTarget) OpticsTarget.DSL -> evalAnnotatedDslElement(c, logger) } } @@ -32,6 +37,9 @@ internal fun KSClassDeclaration.targets(): List = if (targets.isEmpty()) listOf(OpticsTarget.PRISM, OpticsTarget.DSL) else targets.filter { it == OpticsTarget.PRISM || it == OpticsTarget.DSL } + isValue -> + listOf(OpticsTarget.ISO, OpticsTarget.DSL) + .filter { targets.isEmpty() || it in targets } else -> if (targets.isEmpty()) listOf(OpticsTarget.ISO, OpticsTarget.LENS, OpticsTarget.OPTIONAL, OpticsTarget.DSL) @@ -62,6 +70,7 @@ internal fun KSClassDeclaration.targetsFromOpticsAnnotation(): List = when { @@ -74,7 +83,7 @@ internal fun evalAnnotatedPrismElement( ) }.toList() else -> { - logger.error(element.qualifiedNameOrSimpleName.prismErrorMessage, element) + logger.error(errorMessage, element) emptyList() } } @@ -109,11 +118,20 @@ internal fun evalAnnotatedDslElement(element: KSClassDeclaration, logger: KSPLog .getConstructorTypesNames() .zip(element.getConstructorParamNames(), Focus.Companion::invoke) ) - element.isSealed -> SealedClassDsl(evalAnnotatedPrismElement(element, logger)) - else -> throw IllegalStateException("should only be sealed or data by now") + element.isValue -> + ValueClassDsl( + Focus(element.getConstructorTypesNames().first(), element.getConstructorParamNames().first()) + ) + element.isSealed -> + SealedClassDsl(evalAnnotatedPrismElement(element, element.qualifiedNameOrSimpleName.prismErrorMessage, logger)) + else -> throw IllegalStateException("should only be sealed, data, or value by now") } -internal fun evalAnnotatedIsoElement(element: KSClassDeclaration, logger: KSPLogger): List = +internal fun evalAnnotatedIsoElement( + element: KSClassDeclaration, + errorMessage: String, + logger: KSPLogger +): List = when { element.isData -> element @@ -124,8 +142,10 @@ internal fun evalAnnotatedIsoElement(element: KSClassDeclaration, logger: KSPLog logger.error(element.qualifiedNameOrSimpleName.isoTooBigErrorMessage, element) emptyList() } + element.isValue -> + listOf(Focus(element.getConstructorTypesNames().first(), element.getConstructorParamNames().first())) else -> { - logger.error(element.qualifiedNameOrSimpleName.isoErrorMessage, element) + logger.error(errorMessage, element) emptyList() } } diff --git a/arrow-libs/optics/arrow-optics-ksp-plugin/src/main/kotlin/arrow/optics/plugin/internals/snippets.kt b/arrow-libs/optics/arrow-optics-ksp-plugin/src/main/kotlin/arrow/optics/plugin/internals/snippets.kt index 6252cb46990..ebb6924bf75 100644 --- a/arrow-libs/optics/arrow-optics-ksp-plugin/src/main/kotlin/arrow/optics/plugin/internals/snippets.kt +++ b/arrow-libs/optics/arrow-optics-ksp-plugin/src/main/kotlin/arrow/optics/plugin/internals/snippets.kt @@ -9,6 +9,7 @@ internal fun ADT.snippets(): List = is OptionalTarget -> generateOptionals(this, it) is SealedClassDsl -> generatePrismDsl(this, it) is DataClassDsl -> generateOptionalDsl(this, it) + generateLensDsl(this, it) + is ValueClassDsl -> generateIsoDsl(this, it) } } diff --git a/arrow-libs/optics/arrow-optics-ksp-plugin/src/test/kotlin/arrow/optics/plugin/IsoTests.kt b/arrow-libs/optics/arrow-optics-ksp-plugin/src/test/kotlin/arrow/optics/plugin/IsoTests.kt index 5ed4b3d6556..3c4885a69b6 100755 --- a/arrow-libs/optics/arrow-optics-ksp-plugin/src/test/kotlin/arrow/optics/plugin/IsoTests.kt +++ b/arrow-libs/optics/arrow-optics-ksp-plugin/src/test/kotlin/arrow/optics/plugin/IsoTests.kt @@ -99,4 +99,31 @@ class IsoTests { |} """.failsWith { it.contains("${`package`.removePrefix("package ")}.IsoXXL".isoTooBigErrorMessage) } } + + @Test + fun `Isos will be generated for value class`() { + """ + |$`package` + |$imports + |@optics @JvmInline + |value class IsoData( + | val field1: String + |) { companion object } + | + |val i: Iso = IsoData.field1 + |val r = i != null + """.evals("r" to true) + } + + @Test + fun `Iso generation requires companion object declaration, value class`() { + """ + |$`package` + |$imports + |@optics @JvmInline + |value class IsoNoCompanion( + | val field1: String + |) + """.failsWith { it.contains("IsoNoCompanion".noCompanion) } + } }