Skip to content

Commit

Permalink
Backport creation of Isos for value classes (#3021)
Browse files Browse the repository at this point in the history
  • Loading branch information
serras authored Apr 3, 2023
1 parent d766fa9 commit 246ec98
Show file tree
Hide file tree
Showing 9 changed files with 115 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,6 @@ val KSClassDeclaration.isSealed

val KSClassDeclaration.isData
get() = modifiers.contains(Modifier.DATA)

val KSClassDeclaration.isValue
get() = modifiers.contains(Modifier.VALUE)
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ class OpticsProcessor(private val codegen: CodeGenerator, private val logger: KS

private fun processClass(klass: KSClassDeclaration) {
// check that it is sealed or data
if (!klass.isSealed && !klass.isData) {
if (!klass.isSealed && !klass.isData && !klass.isValue) {
logger.error(klass.qualifiedNameOrSimpleName.otherClassTypeErrorMessage, klass)
return
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ typealias SealedClassDsl = Target.SealedClassDsl

typealias DataClassDsl = Target.DataClassDsl

typealias ValueClassDsl = Target.ValueClassDsl

sealed class Target {
abstract val foci: List<Focus>

Expand All @@ -59,6 +61,9 @@ sealed class Target {
data class Optional(override val foci: List<Focus>) : Target()
data class SealedClassDsl(override val foci: List<Focus>) : Target()
data class DataClassDsl(override val foci: List<Focus>) : Target()
data class ValueClassDsl(val focus: Focus) : Target() {
override val foci = listOf(focus)
}
}

typealias NonNullFocus = Focus.NonNull
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,16 @@ fun generatePrismDsl(ele: ADT, isoOptic: SealedClassDsl): Snippet {
)
}

fun generateIsoDsl(ele: ADT, isoOptic: ValueClassDsl): Snippet {
val (className, import) = resolveClassName(ele)
return Snippet(
`package` = ele.packageName,
name = ele.simpleName,
content = processIsoSyntax(ele, isoOptic, className),
imports = setOf(import)
)
}

private fun processLensSyntax(ele: ADT, foci: List<Focus>, className: String): String {
return if (ele.typeParameters.isEmpty()) {
foci.joinToString(separator = "\n") { focus ->
Expand Down Expand Up @@ -137,6 +147,40 @@ private fun processPrismSyntax(ele: ADT, dsl: SealedClassDsl, className: String)
}
}

private fun processIsoSyntax(ele: ADT, dsl: ValueClassDsl, className: String): String =
if (ele.typeParameters.isEmpty()) {
dsl.foci.joinToString(separator = "\n\n") { focus ->
"""
|${ele.visibilityModifierName} inline val <S> $Iso<S, ${ele.sourceClassName}>.${focus.paramName}: $Iso<S, ${focus.className}> inline get() = this + ${className}.${focus.paramName}
|${ele.visibilityModifierName} inline val <S> $Lens<S, ${ele.sourceClassName}>.${focus.paramName}: $Lens<S, ${focus.className}> inline get() = this + ${className}.${focus.paramName}
|${ele.visibilityModifierName} inline val <S> $Optional<S, ${ele.sourceClassName}>.${focus.paramName}: $Optional<S, ${focus.className}> inline get() = this + ${className}.${focus.paramName}
|${ele.visibilityModifierName} inline val <S> $Prism<S, ${ele.sourceClassName}>.${focus.paramName}: $Prism<S, ${focus.className}> inline get() = this + ${className}.${focus.paramName}
|${ele.visibilityModifierName} inline val <S> $Setter<S, ${ele.sourceClassName}>.${focus.paramName}: $Setter<S, ${focus.className}> inline get() = this + ${className}.${focus.paramName}
|${ele.visibilityModifierName} inline val <S> $Traversal<S, ${ele.sourceClassName}>.${focus.paramName}: $Traversal<S, ${focus.className}> inline get() = this + ${className}.${focus.paramName}
|${ele.visibilityModifierName} inline val <S> $Fold<S, ${ele.sourceClassName}>.${focus.paramName}: $Fold<S, ${focus.className}> inline get() = this + ${className}.${focus.paramName}
|${ele.visibilityModifierName} inline val <S> $Every<S, ${ele.sourceClassName}>.${focus.paramName}: $Every<S, ${focus.className}> inline get() = this + ${className}.${focus.paramName}
|""".trimMargin()
}
} else {
dsl.foci.joinToString(separator = "\n\n") { focus ->
val sourceClassNameWithParams = focus.refinedType?.qualifiedString() ?: "${ele.sourceClassName}${ele.angledTypeParameters}"
val joinedTypeParams = when {
focus.refinedArguments.isEmpty() -> ""
else -> focus.refinedArguments.joinToString(separator=",")
}
"""
|${ele.visibilityModifierName} inline fun <S,$joinedTypeParams> $Iso<S, $sourceClassNameWithParams>.${focus.paramName}(): $Iso<S, ${focus.className}> = this + ${className}.${focus.paramName}()
|${ele.visibilityModifierName} inline fun <S,$joinedTypeParams> $Lens<S, $sourceClassNameWithParams>.${focus.paramName}(): $Lens<S, ${focus.className}> = this + ${className}.${focus.paramName}()
|${ele.visibilityModifierName} inline fun <S,$joinedTypeParams> $Optional<S, $sourceClassNameWithParams>.${focus.paramName}(): $Optional<S, ${focus.className}> = this + ${className}.${focus.paramName}()
|${ele.visibilityModifierName} inline fun <S,$joinedTypeParams> $Prism<S, $sourceClassNameWithParams>.${focus.paramName}(): $Prism<S, ${focus.className}> = this + ${className}.${focus.paramName}()
|${ele.visibilityModifierName} inline fun <S,$joinedTypeParams> $Setter<S, $sourceClassNameWithParams>.${focus.paramName}(): $Setter<S, ${focus.className}> = this + ${className}.${focus.paramName}()
|${ele.visibilityModifierName} inline fun <S,$joinedTypeParams> $Traversal<S, $sourceClassNameWithParams>.${focus.paramName}(): $Traversal<S, ${focus.className}> = this + ${className}.${focus.paramName}()
|${ele.visibilityModifierName} inline fun <S,$joinedTypeParams> $Fold<S, $sourceClassNameWithParams>.${focus.paramName}(): $Fold<S, ${focus.className}> = this + ${className}.${focus.paramName}()
|${ele.visibilityModifierName} inline fun <S,$joinedTypeParams> $Every<S, $sourceClassNameWithParams>.${focus.paramName}(): $Every<S, ${focus.className}> = this + ${className}.${focus.paramName}()
|""".trimMargin()
}
}

private fun resolveClassName(ele: ADT): Pair<String, String> = if (hasPackageCollisions(ele)) {
val classNameAlias = ele.sourceClassName.replace(".", "")
val aliasImport = "import ${ele.sourceClassName} as $classNameAlias"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ val String.otherClassTypeErrorMessage
"""
|$this cannot be annotated with @optics
| ^
|Only data and sealed classes can be annotated with @optics""".trimMargin()
|Only data, sealed, and value classes can be annotated with @optics""".trimMargin()

val String.typeParametersErrorMessage
get() =
Expand Down Expand Up @@ -47,7 +47,7 @@ val String.isoErrorMessage
|Cannot generate arrow.optics.Iso for $this
| ^
|arrow.optics.OpticsTarget.ISO is an invalid @optics argument for $this.
|It is only valid for data classes.
|It is only valid for data and value classes.
""".trimMargin()

val String.isoTooBigErrorMessage
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package arrow.optics.plugin.internals

import arrow.optics.plugin.isValue

internal fun generateIsos(ele: ADT, target: IsoTarget) =
Snippet(`package` = ele.packageName, name = ele.simpleName, content = processElement(ele, target))

Expand Down Expand Up @@ -63,12 +65,13 @@ private fun processElement(iso: ADT, target: Target): String {
"tuple: ${focusType()} -> ${(foci.indices).joinToString(prefix = "${iso.sourceClassName}(", postfix = ")", transform = { "tuple.${letters[it]}" })}"
}

val isoName = if (iso.declaration.isValue) target.foci.first().paramName else "iso"
val sourceClassNameWithParams = "${iso.sourceClassName}${iso.angledTypeParameters}"
val firstLine = when {
iso.typeParameters.isEmpty() ->
"${iso.visibilityModifierName} inline val ${iso.sourceClassName}.Companion.iso: $Iso<${iso.sourceClassName}, ${focusType()}> inline get()"
"${iso.visibilityModifierName} inline val ${iso.sourceClassName}.Companion.$isoName: $Iso<${iso.sourceClassName}, ${focusType()}> inline get()"
else ->
"${iso.visibilityModifierName} inline fun ${iso.angledTypeParameters} ${iso.sourceClassName}.Companion.iso(): $Iso<$sourceClassNameWithParams, ${focusType()}>"
"${iso.visibilityModifierName} inline fun ${iso.angledTypeParameters} ${iso.sourceClassName}.Companion.$isoName(): $Iso<$sourceClassNameWithParams, ${focusType()}>"
}

return """
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package arrow.optics.plugin.internals

import arrow.optics.plugin.isData
import arrow.optics.plugin.isSealed
import arrow.optics.plugin.isValue
import com.google.devtools.ksp.processing.KSPLogger
import com.google.devtools.ksp.symbol.*
import java.util.Locale
Expand All @@ -18,8 +19,12 @@ internal fun adt(c: KSClassDeclaration, logger: KSPLogger): ADT =
OpticsTarget.OPTIONAL ->
evalAnnotatedDataClass(c, c.qualifiedNameOrSimpleName.optionalErrorMessage, logger)
.let(::OptionalTarget)
OpticsTarget.ISO -> evalAnnotatedIsoElement(c, logger).let(::IsoTarget)
OpticsTarget.PRISM -> evalAnnotatedPrismElement(c, logger).let(::PrismTarget)
OpticsTarget.ISO ->
evalAnnotatedIsoElement(c, c.qualifiedNameOrSimpleName.isoErrorMessage, logger)
.let(::IsoTarget)
OpticsTarget.PRISM ->
evalAnnotatedPrismElement(c, c.qualifiedNameOrSimpleName.prismErrorMessage, logger)
.let(::PrismTarget)
OpticsTarget.DSL -> evalAnnotatedDslElement(c, logger)
}
}
Expand All @@ -32,6 +37,9 @@ internal fun KSClassDeclaration.targets(): List<OpticsTarget> =
if (targets.isEmpty())
listOf(OpticsTarget.PRISM, OpticsTarget.DSL)
else targets.filter { it == OpticsTarget.PRISM || it == OpticsTarget.DSL }
isValue ->
listOf(OpticsTarget.ISO, OpticsTarget.DSL)
.filter { targets.isEmpty() || it in targets }
else ->
if (targets.isEmpty())
listOf(OpticsTarget.ISO, OpticsTarget.LENS, OpticsTarget.OPTIONAL, OpticsTarget.DSL)
Expand Down Expand Up @@ -62,6 +70,7 @@ internal fun KSClassDeclaration.targetsFromOpticsAnnotation(): List<OpticsTarget

internal fun evalAnnotatedPrismElement(
element: KSClassDeclaration,
errorMessage: String,
logger: KSPLogger
): List<Focus> =
when {
Expand All @@ -74,7 +83,7 @@ internal fun evalAnnotatedPrismElement(
)
}.toList()
else -> {
logger.error(element.qualifiedNameOrSimpleName.prismErrorMessage, element)
logger.error(errorMessage, element)
emptyList()
}
}
Expand Down Expand Up @@ -109,11 +118,20 @@ internal fun evalAnnotatedDslElement(element: KSClassDeclaration, logger: KSPLog
.getConstructorTypesNames()
.zip(element.getConstructorParamNames(), Focus.Companion::invoke)
)
element.isSealed -> SealedClassDsl(evalAnnotatedPrismElement(element, logger))
else -> throw IllegalStateException("should only be sealed or data by now")
element.isValue ->
ValueClassDsl(
Focus(element.getConstructorTypesNames().first(), element.getConstructorParamNames().first())
)
element.isSealed ->
SealedClassDsl(evalAnnotatedPrismElement(element, element.qualifiedNameOrSimpleName.prismErrorMessage, logger))
else -> throw IllegalStateException("should only be sealed, data, or value by now")
}

internal fun evalAnnotatedIsoElement(element: KSClassDeclaration, logger: KSPLogger): List<Focus> =
internal fun evalAnnotatedIsoElement(
element: KSClassDeclaration,
errorMessage: String,
logger: KSPLogger
): List<Focus> =
when {
element.isData ->
element
Expand All @@ -124,8 +142,10 @@ internal fun evalAnnotatedIsoElement(element: KSClassDeclaration, logger: KSPLog
logger.error(element.qualifiedNameOrSimpleName.isoTooBigErrorMessage, element)
emptyList()
}
element.isValue ->
listOf(Focus(element.getConstructorTypesNames().first(), element.getConstructorParamNames().first()))
else -> {
logger.error(element.qualifiedNameOrSimpleName.isoErrorMessage, element)
logger.error(errorMessage, element)
emptyList()
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ internal fun ADT.snippets(): List<Snippet> =
is OptionalTarget -> generateOptionals(this, it)
is SealedClassDsl -> generatePrismDsl(this, it)
is DataClassDsl -> generateOptionalDsl(this, it) + generateLensDsl(this, it)
is ValueClassDsl -> generateIsoDsl(this, it)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,4 +99,31 @@ class IsoTests {
|}
""".failsWith { it.contains("${`package`.removePrefix("package ")}.IsoXXL".isoTooBigErrorMessage) }
}

@Test
fun `Isos will be generated for value class`() {
"""
|$`package`
|$imports
|@optics @JvmInline
|value class IsoData(
| val field1: String
|) { companion object }
|
|val i: Iso<IsoData, String> = IsoData.field1
|val r = i != null
""".evals("r" to true)
}

@Test
fun `Iso generation requires companion object declaration, value class`() {
"""
|$`package`
|$imports
|@optics @JvmInline
|value class IsoNoCompanion(
| val field1: String
|)
""".failsWith { it.contains("IsoNoCompanion".noCompanion) }
}
}

0 comments on commit 246ec98

Please sign in to comment.