diff --git a/plugin/src/main/scala-2/Scala2CompilerApi.scala b/plugin/src/main/scala-2/Scala2CompilerApi.scala index f13a19c..95aaa6f 100644 --- a/plugin/src/main/scala-2/Scala2CompilerApi.scala +++ b/plugin/src/main/scala-2/Scala2CompilerApi.scala @@ -106,7 +106,14 @@ object Scala2CompilerApi { case d: ValDef => d.name.toString } - def isCaseClass(clazz: Clazz): Boolean = clazz.merge.mods.isCase + // No enums in scala 2 + def isEnum(clazz: Clazz): Boolean = false + + def isViableDefinition(clazz: Clazz): Boolean = { + val isCaseClassOrObject = clazz.merge.mods.isCase + + isCaseClassOrObject + } // Always return true for ModuleDef - apparently ModuleDef doesn't have the module flag... def isObject(clazz: Clazz): Boolean = clazz.fold( @@ -114,6 +121,9 @@ object Scala2CompilerApi { obj = _ => true ) + def productPrefixParam: Nothing = sys.error("invalid state: this shouldn't be called in Scala 2") + + def report(str: String): Unit = reporter.echo(str) } } diff --git a/plugin/src/main/scala-3/Scala3CompilerApi.scala b/plugin/src/main/scala-3/Scala3CompilerApi.scala index 86d18a2..4d6b487 100644 --- a/plugin/src/main/scala-3/Scala3CompilerApi.scala +++ b/plugin/src/main/scala-3/Scala3CompilerApi.scala @@ -94,11 +94,22 @@ object Scala3CompilerApi: d.name.toString } - def isCaseClass(clazz: Clazz): Boolean = - // for some reason, this is true for case objects too - clazz.clazz.flags.is(CaseClass) + def isEnum(clazz: Clazz): Boolean = + // the class is an enum if its direct supertypes include scala.reflect.Enum + clazz.clazz.parentTypes.map(_.typeSymbol).contains(Symbols.defn.EnumClass) + + def isViableDefinition(clazz: Clazz): Boolean = { + val isCaseClassOrObject = clazz.clazz.flags.is(CaseClass) + + isCaseClassOrObject || isEnum(clazz) + } def isObject(clazz: Clazz): Boolean = clazz.clazz.flags.is(Module) + def report(str: String): Unit = + dotty.tools.dotc.report.echo(str) + + def productPrefixParam: ParamName = "productPrefix".toTermName + end Scala3CompilerApi diff --git a/plugin/src/main/scala/BetterToStringImpl.scala b/plugin/src/main/scala/BetterToStringImpl.scala index fa60381..95f59ea 100644 --- a/plugin/src/main/scala/BetterToStringImpl.scala +++ b/plugin/src/main/scala/BetterToStringImpl.scala @@ -39,9 +39,13 @@ trait CompilerApi { def createToString(clazz: Clazz, body: Tree): Method def addMethod(clazz: Clazz, method: Method): Clazz def methodNames(clazz: Clazz): List[String] - // better name: "is case class or object" - def isCaseClass(clazz: Clazz): Boolean + def isViableDefinition(clazz: Clazz): Boolean + def isEnum(clazz: Clazz): Boolean def isObject(clazz: Clazz): Boolean + def productPrefixParam: ParamName + + // debugging + def report(str: String): Unit } trait BetterToStringImpl[+C <: CompilerApi] { @@ -73,7 +77,7 @@ object BetterToStringImpl { // technically, the method found by this can be even something like "def toString(s: String): Unit", but we're ignoring that val hasToString: Boolean = methodNames(clazz).contains("toString") - val shouldModify = isCaseClass(clazz) && !isNested && !hasToString + val shouldModify = isViableDefinition(clazz) && !isNested && !hasToString if (shouldModify) overrideToString(clazz, enclosingObject) else clazz @@ -101,6 +105,7 @@ object BetterToStringImpl { val paramParts = if (api.isObject(clazz)) Nil + else if (api.isEnum(clazz)) List(literalConstant("."), selectInThis(clazz, productPrefixParam)) else List( List(literalConstant("(")), diff --git a/tests/src/test/scala-3/Scala3Tests.scala b/tests/src/test/scala-3/Scala3Tests.scala index bf5ddbb..3e717b9 100644 --- a/tests/src/test/scala-3/Scala3Tests.scala +++ b/tests/src/test/scala-3/Scala3Tests.scala @@ -20,15 +20,11 @@ class Scala3Tests extends FunSuite: test("an enum made of constants should have a normal toString") { assertEquals( ScalaVersion.Scala2.toString, - // https://github.com/polyvariant/better-tostring/issues/60 - // should be "ScalaVersion.Scala2" - "Scala2" + "ScalaVersion.Scala2" ) assertEquals( ScalaVersion.Scala3.toString, - // https://github.com/polyvariant/better-tostring/issues/60 - // should be "ScalaVersion.Scala3" - "Scala3" + "ScalaVersion.Scala3" ) } @@ -39,9 +35,7 @@ class Scala3Tests extends FunSuite: ) assertEquals( User.Unauthorized.toString, - // https://github.com/polyvariant/better-tostring/issues/60 - // should be "User.Unauthorized" - "Unauthorized" + "User.Unauthorized" ) }