diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 31f3be2..5c0b800 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -37,6 +37,8 @@ jobs: - 2.13.3 - 2.13.4 - 2.13.5 + - 3.0.0-M3 + - 3.0.0-RC1 java: [graalvm-ce-java11@20.3.0] runs-on: ${{ matrix.os }} steps: @@ -84,7 +86,7 @@ jobs: strategy: matrix: os: [ubuntu-latest] - scala: [2.12.10] + scala: [3.0.0-RC1] java: [graalvm-ce-java11@20.3.0] runs-on: ${{ matrix.os }} steps: @@ -200,6 +202,26 @@ jobs: tar xf targets.tar rm targets.tar + - name: Download target directories (3.0.0-M3) + uses: actions/download-artifact@v2 + with: + name: target-${{ matrix.os }}-3.0.0-M3-${{ matrix.java }} + + - name: Inflate target directories (3.0.0-M3) + run: | + tar xf targets.tar + rm targets.tar + + - name: Download target directories (3.0.0-RC1) + uses: actions/download-artifact@v2 + with: + name: target-${{ matrix.os }}-3.0.0-RC1-${{ matrix.java }} + + - name: Inflate target directories (3.0.0-RC1) + run: | + tar xf targets.tar + rm targets.tar + - uses: olafurpg/setup-gpg@v3 - run: sbt ++${{ matrix.scala }} ci-release \ No newline at end of file diff --git a/build.sbt b/build.sbt index 8397451..e246703 100644 --- a/build.sbt +++ b/build.sbt @@ -18,7 +18,7 @@ inThisBuild( val GraalVM11 = "graalvm-ce-java11@20.3.0" -ThisBuild / scalaVersion := "2.12.10" +ThisBuild / scalaVersion := "3.0.0-RC1" ThisBuild / crossScalaVersions := Seq( "2.12.10", "2.12.11", @@ -29,7 +29,10 @@ ThisBuild / crossScalaVersions := Seq( "2.13.2", "2.13.3", "2.13.4", - "2.13.5" + "2.13.5", + // + "3.0.0-M3", + "3.0.0-RC1" ) ThisBuild / githubWorkflowJavaVersions := Seq(GraalVM11) @@ -57,13 +60,22 @@ val commonSettings = Seq( scalacOptions -= "-Xfatal-warnings" ) +def scalatestVersion(scalaVersion: String) = scalaVersion match { + case "3.0.0-M3" => "3.2.3" + case _ => "3.2.5" +} + val plugin = project.settings( name := "better-tostring", commonSettings, crossTarget := target.value / s"scala-${scalaVersion.value}", // workaround for https://github.com/sbt/sbt/issues/5097 crossVersion := CrossVersion.full, libraryDependencies ++= Seq( - scalaOrganization.value % "scala-compiler" % scalaVersion.value + scalaOrganization.value % ( + if (isDotty.value) + s"scala3-compiler_${scalaVersion.value}" + else "scala-compiler" + ) % scalaVersion.value ) ) @@ -79,7 +91,7 @@ val tests = project.settings( ) //borrowed from bm4 }, libraryDependencies ++= Seq( - "org.scalatest" %% "scalatest" % "3.2.5" % Test + "org.scalatest" %% "scalatest" % scalatestVersion(scalaVersion.value) % Test ) ) diff --git a/plugin/src/main/resources/plugin.properties b/plugin/src/main/resources/plugin.properties new file mode 100644 index 0000000..cc0ddac --- /dev/null +++ b/plugin/src/main/resources/plugin.properties @@ -0,0 +1 @@ +pluginClass=com.kubukoz.BetterToStringPlugin diff --git a/plugin/src/main/scala-2/BetterToStringPlugin.scala b/plugin/src/main/scala-2/BetterToStringPlugin.scala new file mode 100644 index 0000000..954b04c --- /dev/null +++ b/plugin/src/main/scala-2/BetterToStringPlugin.scala @@ -0,0 +1,50 @@ +package com.kubukoz + +import scala.tools.nsc.Global +import scala.tools.nsc.Phase +import scala.tools.nsc.plugins.Plugin +import scala.tools.nsc.plugins.PluginComponent +import scala.tools.nsc.transform.TypingTransformers + +final class BetterToStringPlugin(override val global: Global) extends Plugin { + override val name: String = "better-tostring" + override val description: String = + "scala compiler plugin for better default toString implementations" + override val components: List[PluginComponent] = List( + new BetterToStringPluginComponent(global) + ) +} + +final class BetterToStringPluginComponent(val global: Global) + extends PluginComponent + with TypingTransformers { + import global._ + override val phaseName: String = "better-tostring-phase" + override val runsAfter: List[String] = List("parser") + + private val impl: BetterToStringImpl[Scala2CompilerApi[global.type]] = + BetterToStringImpl.instance(Scala2CompilerApi.instance(global)) + + private def modifyClasses(tree: Tree): Tree = + tree match { + case p: PackageDef => p.copy(stats = p.stats.map(modifyClasses)) + case m: ModuleDef => + m.copy(impl = m.impl.copy(body = m.impl.body.map(modifyClasses))) + case clazz: ClassDef => + impl.transformClass( + clazz, + // If it was nested, we wouldn't be in this branch. + // Scala 2.x compiler API limitation (classes can't tell what the owner is). + // This should be more optimal as we don't traverse every template, but it hasn't been benchmarked. + isNested = false + ) + case other => other + } + + override def newPhase(prev: Phase): Phase = new StdPhase(prev) { + override def apply(unit: CompilationUnit): Unit = + new Transformer { + override def transform(tree: Tree): Tree = modifyClasses(tree) + }.transformUnit(unit) + } +} diff --git a/plugin/src/main/scala-2/Scala2CompilerApi.scala b/plugin/src/main/scala-2/Scala2CompilerApi.scala new file mode 100644 index 0000000..e1bce89 --- /dev/null +++ b/plugin/src/main/scala-2/Scala2CompilerApi.scala @@ -0,0 +1,50 @@ +package com.kubukoz + +import scala.reflect.internal.Flags +import scala.tools.nsc.Global + +trait Scala2CompilerApi[G <: Global] extends CompilerApi { + val theGlobal: G + import theGlobal._ + type Tree = theGlobal.Tree + type Clazz = ClassDef + type Param = ValDef + type ParamName = TermName + type Method = DefDef +} + +object Scala2CompilerApi { + def instance(global: Global): Scala2CompilerApi[global.type] = + new Scala2CompilerApi[global.type] { + val theGlobal: global.type = global + import global._ + + def params(clazz: Clazz): List[Param] = clazz.impl.body.collect { + case v: ValDef if v.mods.hasFlag(Flags.CASEACCESSOR) => v + } + + def className(clazz: Clazz): String = clazz.name.toString + def literalConstant(value: String): Tree = Literal(Constant(value)) + def paramName(param: Param): ParamName = param.name + def selectInThis(clazz: Clazz, name: ParamName): Tree = q"this.$name" + def concat(l: Tree, r: Tree): Tree = q"$l + $r" + + def createToString(clazz: Clazz, body: Tree): Method = DefDef( + Modifiers(Flags.OVERRIDE), + TermName("toString"), + Nil, + List(List()), + Ident(TypeName("String")), + body + ) + + def addMethod(clazz: Clazz, method: Method): Clazz = + clazz.copy(impl = clazz.impl.copy(body = clazz.impl.body :+ method)) + + def methodNames(clazz: Clazz): List[String] = clazz.impl.body.collect { + case d: DefDef => d.name.toString + } + + def isCaseClass(clazz: Clazz): Boolean = clazz.mods.hasFlag(Flags.CASE) + } +} diff --git a/plugin/src/main/scala-3/BetterToStringPlugin.scala b/plugin/src/main/scala-3/BetterToStringPlugin.scala new file mode 100644 index 0000000..6a36fbe --- /dev/null +++ b/plugin/src/main/scala-3/BetterToStringPlugin.scala @@ -0,0 +1,28 @@ +package com.kubukoz + +import dotty.tools.dotc.ast.tpd +import dotty.tools.dotc.core.Contexts.Context +import dotty.tools.dotc.core.Flags.Module +import dotty.tools.dotc.plugins.{PluginPhase, StandardPlugin} +import dotty.tools.dotc.typer.FrontEnd +import tpd._ + +final class BetterToStringPlugin extends StandardPlugin: + override val name: String = "better-tostring" + override val description: String = "scala compiler plugin for better default toString implementations" + override def init(options: List[String]): List[PluginPhase] = List(new BetterToStringPluginPhase) + +final class BetterToStringPluginPhase extends PluginPhase: + + override val phaseName: String = "better-tostring-phase" + override val runsAfter: Set[String] = Set(FrontEnd.name) + + override def transformTemplate(t: Template)(using ctx: Context): Tree = + val clazz = ctx.owner.asClass + + val isNested = !(ctx.owner.owner.isPackageObject || ctx.owner.owner.is(Module)) + + BetterToStringImpl + .instance(Scala3CompilerApi.instance) + .transformClass(Scala3CompilerApi.ClassContext(t, clazz), isNested) + .t diff --git a/plugin/src/main/scala-3/Scala3CompilerApi.scala b/plugin/src/main/scala-3/Scala3CompilerApi.scala new file mode 100644 index 0000000..0f93c29 --- /dev/null +++ b/plugin/src/main/scala-3/Scala3CompilerApi.scala @@ -0,0 +1,64 @@ +package com.kubukoz + +import dotty.tools.dotc.ast.tpd +import dotty.tools.dotc.core.Contexts.Context +import dotty.tools.dotc.core.Symbols +import dotty.tools.dotc.core.Flags.CaseAccessor +import dotty.tools.dotc.core.Flags.CaseClass +import dotty.tools.dotc.core.Flags.Override +import dotty.tools.dotc.core.Types +import dotty.tools.dotc.core.Names +import dotty.tools.dotc.core.Constants.Constant +import dotty.tools.dotc.core.Symbols.ClassSymbol +import dotty.tools.dotc.core.Decorators._ +import dotty.tools.dotc.ast.Trees +import tpd._ + +trait Scala3CompilerApi extends CompilerApi: + type Tree = Trees.Tree[Types.Type] + type Clazz = Scala3CompilerApi.ClassContext + type Param = ValDef + type ParamName = Names.TermName + type Method = DefDef + +object Scala3CompilerApi: + final case class ClassContext(t: Template, clazz: ClassSymbol): + def mapTemplate(f: Template => Template): ClassContext = copy(t = f(t)) + + def instance(using Context): Scala3CompilerApi = new Scala3CompilerApi: + def params(clazz: Clazz): List[Param] = + clazz.t.body.collect { + case v: ValDef if v.mods.is(CaseAccessor) => v + } + + def className(clazz: Clazz): String = clazz.clazz.name.toString + def literalConstant(value: String): Tree = Literal(Constant(value)) + def paramName(param: Param): ParamName = param.name + def selectInThis(clazz: Clazz, name: ParamName): Tree = This(clazz.clazz).select(name) + def concat(l: Tree, r: Tree): Tree = l.select("+".toTermName).appliedTo(r) + + def createToString(owner: Clazz, body: Tree): Method = { + val clazz = owner.clazz + // this was adapted from dotty.tools.dotc.transform.SyntheticMembers (line 115) + val sym = Symbols.defn.Any_toString + + val toStringSymbol = sym.copy( + owner = clazz, + flags = sym.flags | Override, + info = clazz.thisType.memberInfo(sym), + coord = clazz.coord + ).entered.asTerm + + DefDef(toStringSymbol, body) + } + + def addMethod(clazz: Clazz, method: Method): Clazz = + clazz.mapTemplate( + t => cpy.Template(t)(body = t.body :+ method) + ) + + def methodNames(clazz: Clazz): List[String] = clazz.t.body.collect { + case d: DefDef => d.name.toString + } + + def isCaseClass(clazz: Clazz): Boolean = clazz.clazz.flags.is(CaseClass) diff --git a/plugin/src/main/scala/BetterToStringImpl.scala b/plugin/src/main/scala/BetterToStringImpl.scala new file mode 100644 index 0000000..41dc2e3 --- /dev/null +++ b/plugin/src/main/scala/BetterToStringImpl.scala @@ -0,0 +1,85 @@ +package com.kubukoz + +// Source-compatible core between 2.x and 3.x implementations + +trait CompilerApi { + type Tree + type Clazz + type Param + type ParamName + type Method + + def className(clazz: Clazz): String + def params(clazz: Clazz): List[Param] + def literalConstant(value: String): Tree + + def paramName(param: Param): ParamName + def selectInThis(clazz: Clazz, name: ParamName): Tree + def concat(l: Tree, r: Tree): Tree + + def createToString(clazz: Clazz, body: Tree): Method + def addMethod(clazz: Clazz, method: Method): Clazz + def methodNames(clazz: Clazz): List[String] + def isCaseClass(clazz: Clazz): Boolean +} + +trait BetterToStringImpl[+C <: CompilerApi] { + val compilerApi: C + + def transformClass( + clazz: compilerApi.Clazz, + isNested: Boolean + ): compilerApi.Clazz +} + +object BetterToStringImpl { + def instance( + api: CompilerApi + ): BetterToStringImpl[api.type] = + new BetterToStringImpl[api.type] { + val compilerApi: api.type = api + + import api._ + + def transformClass( + clazz: Clazz, + isNested: Boolean + ): Clazz = { + val hasToString: Boolean = methodNames(clazz).contains("toString") + + val shouldModify = + isCaseClass(clazz) && !isNested && !hasToString + + if (shouldModify) overrideToString(clazz) + else clazz + } + + private def overrideToString(clazz: Clazz): Clazz = + addMethod(clazz, createToString(clazz, toStringImpl(clazz))) + + private def toStringImpl(clazz: Clazz): Tree = { + val className = api.className(clazz) + + val paramListParts: List[Tree] = params(clazz).zipWithIndex.flatMap { + case (v, index) => + val commaPrefix = if (index > 0) ", " else "" + + val name = paramName(v) + + List( + literalConstant(commaPrefix ++ name.toString ++ " = "), + selectInThis(clazz, name) + ) + } + + val parts = + List( + List(literalConstant(className ++ "(")), + paramListParts, + List(literalConstant(")")) + ).flatten + + parts.reduceLeft(concat(_, _)) + } + } +} diff --git a/plugin/src/main/scala/BetterToStringPlugin.scala b/plugin/src/main/scala/BetterToStringPlugin.scala deleted file mode 100644 index 63bb4aa..0000000 --- a/plugin/src/main/scala/BetterToStringPlugin.scala +++ /dev/null @@ -1,92 +0,0 @@ -package com.kubukoz - -import scala.tools.nsc.plugins.{Plugin, PluginComponent} -import scala.tools.nsc.{Global, Phase} -import scala.tools.nsc.transform.TypingTransformers -import scala.reflect.internal.Flags - -class BetterToStringPlugin(override val global: Global) extends Plugin { - override val name: String = "better-tostring" - override val description: String = "scala compiler plugin for better default toString implementations" - override val components: List[PluginComponent] = List(new BetterToStringPluginComponent(global)) -} - -class BetterToStringPluginComponent(val global: Global) extends PluginComponent with TypingTransformers { - import global._ - override val phaseName: String = "better-tostring-phase" - override val runsAfter: List[String] = List("parser") - - private def addToString(clazz: ClassDef): ClassDef = { - val params = clazz.impl.body.collect { - case v: ValDef if v.mods.hasFlag(Flags.CASEACCESSOR) => v - } - - val toStringImpl: Tree = { - val className = clazz.name.toString() - - val paramListParts: List[Tree] = params.zipWithIndex.flatMap { - case (v, index) => - val commaPrefix = if (index > 0) ", " else "" - - List( - Literal(Constant(commaPrefix ++ v.name.toString ++ " = ")), - q"this.${v.name}" - ) - } - - val parts = - List( - List(Literal(Constant(className ++ "("))), - paramListParts, - List(Literal(Constant(")"))) - ).flatten - - parts.reduceLeft((a, b) => q"$a + $b") - } - - val methodBody = DefDef( - Modifiers(Flags.OVERRIDE), - TermName("toString"), - Nil, - List(List()), - Ident(TypeName("String")), - toStringImpl - ) - - clazz.copy(impl = clazz.impl.copy(body = clazz.impl.body :+ methodBody)) - } - - private def transformClass(clazz: ClassDef): ClassDef = { - val hasCustomToString: Boolean = clazz.impl.body.exists { - - case fun: DefDef => - //so meta - fun.name.toString == "toString" - case _ => false - } - - val shouldModify = !hasCustomToString - - if (shouldModify) addToString(clazz) - else clazz - } - - private def modifyClasses(f: ClassDef => ClassDef)(tree: Tree): Tree = tree match { - case p: PackageDef => p.copy(stats = p.stats.map(modifyClasses(f))) - case m: ModuleDef => m.copy(impl = m.impl.copy(body = m.impl.body.map(modifyClasses(f)))) - //Only case classes - case clazz: ClassDef if clazz.mods.hasFlag(Flags.CASE) => f(clazz) - case other => other - } - - override def newPhase(prev: Phase): Phase = new StdPhase(prev) { - - override def apply(unit: CompilationUnit): Unit = { - val trans = new Transformer { - override def transform(tree: Tree): Tree = modifyClasses(transformClass)(tree) - } - - trans.transformUnit(unit) - } - } -} diff --git a/project/plugins.sbt b/project/plugins.sbt index 7c59508..82e1c27 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -1,3 +1,4 @@ addSbtPlugin("com.geirsson" % "sbt-ci-release" % "1.5.5") addSbtPlugin("io.github.davidgregory084" % "sbt-tpolecat" % "0.1.16") addSbtPlugin("com.codecommit" % "sbt-github-actions" % "0.10.1") +addSbtPlugin("ch.epfl.lamp" % "sbt-dotty" % "0.5.3")