Skip to content

Commit

Permalink
Merge pull request #11 from kubukoz/dotty
Browse files Browse the repository at this point in the history
  • Loading branch information
kubukoz authored Mar 7, 2021
2 parents c44d335 + 4cc6617 commit 431e8df
Show file tree
Hide file tree
Showing 10 changed files with 318 additions and 97 deletions.
24 changes: 23 additions & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ jobs:
- 2.13.3
- 2.13.4
- 2.13.5
- 3.0.0-M3
- 3.0.0-RC1
java: [[email protected]]
runs-on: ${{ matrix.os }}
steps:
Expand Down Expand Up @@ -84,7 +86,7 @@ jobs:
strategy:
matrix:
os: [ubuntu-latest]
scala: [2.12.10]
scala: [3.0.0-RC1]
java: [[email protected]]
runs-on: ${{ matrix.os }}
steps:
Expand Down Expand Up @@ -200,6 +202,26 @@ jobs:
tar xf targets.tar
rm targets.tar
- name: Download target directories (3.0.0-M3)
uses: actions/download-artifact@v2
with:
name: target-${{ matrix.os }}-3.0.0-M3-${{ matrix.java }}

- name: Inflate target directories (3.0.0-M3)
run: |
tar xf targets.tar
rm targets.tar
- name: Download target directories (3.0.0-RC1)
uses: actions/download-artifact@v2
with:
name: target-${{ matrix.os }}-3.0.0-RC1-${{ matrix.java }}

- name: Inflate target directories (3.0.0-RC1)
run: |
tar xf targets.tar
rm targets.tar
- uses: olafurpg/setup-gpg@v3

- run: sbt ++${{ matrix.scala }} ci-release
20 changes: 16 additions & 4 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ inThisBuild(

val GraalVM11 = "[email protected]"

ThisBuild / scalaVersion := "2.12.10"
ThisBuild / scalaVersion := "3.0.0-RC1"
ThisBuild / crossScalaVersions := Seq(
"2.12.10",
"2.12.11",
Expand All @@ -29,7 +29,10 @@ ThisBuild / crossScalaVersions := Seq(
"2.13.2",
"2.13.3",
"2.13.4",
"2.13.5"
"2.13.5",
//
"3.0.0-M3",
"3.0.0-RC1"
)

ThisBuild / githubWorkflowJavaVersions := Seq(GraalVM11)
Expand Down Expand Up @@ -57,13 +60,22 @@ val commonSettings = Seq(
scalacOptions -= "-Xfatal-warnings"
)

def scalatestVersion(scalaVersion: String) = scalaVersion match {
case "3.0.0-M3" => "3.2.3"
case _ => "3.2.5"
}

val plugin = project.settings(
name := "better-tostring",
commonSettings,
crossTarget := target.value / s"scala-${scalaVersion.value}", // workaround for https://github.com/sbt/sbt/issues/5097
crossVersion := CrossVersion.full,
libraryDependencies ++= Seq(
scalaOrganization.value % "scala-compiler" % scalaVersion.value
scalaOrganization.value % (
if (isDotty.value)
s"scala3-compiler_${scalaVersion.value}"
else "scala-compiler"
) % scalaVersion.value
)
)

Expand All @@ -79,7 +91,7 @@ val tests = project.settings(
) //borrowed from bm4
},
libraryDependencies ++= Seq(
"org.scalatest" %% "scalatest" % "3.2.5" % Test
"org.scalatest" %% "scalatest" % scalatestVersion(scalaVersion.value) % Test
)
)

Expand Down
1 change: 1 addition & 0 deletions plugin/src/main/resources/plugin.properties
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
pluginClass=com.kubukoz.BetterToStringPlugin
50 changes: 50 additions & 0 deletions plugin/src/main/scala-2/BetterToStringPlugin.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package com.kubukoz

import scala.tools.nsc.Global
import scala.tools.nsc.Phase
import scala.tools.nsc.plugins.Plugin
import scala.tools.nsc.plugins.PluginComponent
import scala.tools.nsc.transform.TypingTransformers

final class BetterToStringPlugin(override val global: Global) extends Plugin {
override val name: String = "better-tostring"
override val description: String =
"scala compiler plugin for better default toString implementations"
override val components: List[PluginComponent] = List(
new BetterToStringPluginComponent(global)
)
}

final class BetterToStringPluginComponent(val global: Global)
extends PluginComponent
with TypingTransformers {
import global._
override val phaseName: String = "better-tostring-phase"
override val runsAfter: List[String] = List("parser")

private val impl: BetterToStringImpl[Scala2CompilerApi[global.type]] =
BetterToStringImpl.instance(Scala2CompilerApi.instance(global))

private def modifyClasses(tree: Tree): Tree =
tree match {
case p: PackageDef => p.copy(stats = p.stats.map(modifyClasses))
case m: ModuleDef =>
m.copy(impl = m.impl.copy(body = m.impl.body.map(modifyClasses)))
case clazz: ClassDef =>
impl.transformClass(
clazz,
// If it was nested, we wouldn't be in this branch.
// Scala 2.x compiler API limitation (classes can't tell what the owner is).
// This should be more optimal as we don't traverse every template, but it hasn't been benchmarked.
isNested = false
)
case other => other
}

override def newPhase(prev: Phase): Phase = new StdPhase(prev) {
override def apply(unit: CompilationUnit): Unit =
new Transformer {
override def transform(tree: Tree): Tree = modifyClasses(tree)
}.transformUnit(unit)
}
}
50 changes: 50 additions & 0 deletions plugin/src/main/scala-2/Scala2CompilerApi.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package com.kubukoz

import scala.reflect.internal.Flags
import scala.tools.nsc.Global

trait Scala2CompilerApi[G <: Global] extends CompilerApi {
val theGlobal: G
import theGlobal._
type Tree = theGlobal.Tree
type Clazz = ClassDef
type Param = ValDef
type ParamName = TermName
type Method = DefDef
}

object Scala2CompilerApi {
def instance(global: Global): Scala2CompilerApi[global.type] =
new Scala2CompilerApi[global.type] {
val theGlobal: global.type = global
import global._

def params(clazz: Clazz): List[Param] = clazz.impl.body.collect {
case v: ValDef if v.mods.hasFlag(Flags.CASEACCESSOR) => v
}

def className(clazz: Clazz): String = clazz.name.toString
def literalConstant(value: String): Tree = Literal(Constant(value))
def paramName(param: Param): ParamName = param.name
def selectInThis(clazz: Clazz, name: ParamName): Tree = q"this.$name"
def concat(l: Tree, r: Tree): Tree = q"$l + $r"

def createToString(clazz: Clazz, body: Tree): Method = DefDef(
Modifiers(Flags.OVERRIDE),
TermName("toString"),
Nil,
List(List()),
Ident(TypeName("String")),
body
)

def addMethod(clazz: Clazz, method: Method): Clazz =
clazz.copy(impl = clazz.impl.copy(body = clazz.impl.body :+ method))

def methodNames(clazz: Clazz): List[String] = clazz.impl.body.collect {
case d: DefDef => d.name.toString
}

def isCaseClass(clazz: Clazz): Boolean = clazz.mods.hasFlag(Flags.CASE)
}
}
28 changes: 28 additions & 0 deletions plugin/src/main/scala-3/BetterToStringPlugin.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package com.kubukoz

import dotty.tools.dotc.ast.tpd
import dotty.tools.dotc.core.Contexts.Context
import dotty.tools.dotc.core.Flags.Module
import dotty.tools.dotc.plugins.{PluginPhase, StandardPlugin}
import dotty.tools.dotc.typer.FrontEnd
import tpd._

final class BetterToStringPlugin extends StandardPlugin:
override val name: String = "better-tostring"
override val description: String = "scala compiler plugin for better default toString implementations"
override def init(options: List[String]): List[PluginPhase] = List(new BetterToStringPluginPhase)

final class BetterToStringPluginPhase extends PluginPhase:

override val phaseName: String = "better-tostring-phase"
override val runsAfter: Set[String] = Set(FrontEnd.name)

override def transformTemplate(t: Template)(using ctx: Context): Tree =
val clazz = ctx.owner.asClass

val isNested = !(ctx.owner.owner.isPackageObject || ctx.owner.owner.is(Module))

BetterToStringImpl
.instance(Scala3CompilerApi.instance)
.transformClass(Scala3CompilerApi.ClassContext(t, clazz), isNested)
.t
64 changes: 64 additions & 0 deletions plugin/src/main/scala-3/Scala3CompilerApi.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
package com.kubukoz

import dotty.tools.dotc.ast.tpd
import dotty.tools.dotc.core.Contexts.Context
import dotty.tools.dotc.core.Symbols
import dotty.tools.dotc.core.Flags.CaseAccessor
import dotty.tools.dotc.core.Flags.CaseClass
import dotty.tools.dotc.core.Flags.Override
import dotty.tools.dotc.core.Types
import dotty.tools.dotc.core.Names
import dotty.tools.dotc.core.Constants.Constant
import dotty.tools.dotc.core.Symbols.ClassSymbol
import dotty.tools.dotc.core.Decorators._
import dotty.tools.dotc.ast.Trees
import tpd._

trait Scala3CompilerApi extends CompilerApi:
type Tree = Trees.Tree[Types.Type]
type Clazz = Scala3CompilerApi.ClassContext
type Param = ValDef
type ParamName = Names.TermName
type Method = DefDef

object Scala3CompilerApi:
final case class ClassContext(t: Template, clazz: ClassSymbol):
def mapTemplate(f: Template => Template): ClassContext = copy(t = f(t))

def instance(using Context): Scala3CompilerApi = new Scala3CompilerApi:
def params(clazz: Clazz): List[Param] =
clazz.t.body.collect {
case v: ValDef if v.mods.is(CaseAccessor) => v
}

def className(clazz: Clazz): String = clazz.clazz.name.toString
def literalConstant(value: String): Tree = Literal(Constant(value))
def paramName(param: Param): ParamName = param.name
def selectInThis(clazz: Clazz, name: ParamName): Tree = This(clazz.clazz).select(name)
def concat(l: Tree, r: Tree): Tree = l.select("+".toTermName).appliedTo(r)

def createToString(owner: Clazz, body: Tree): Method = {
val clazz = owner.clazz
// this was adapted from dotty.tools.dotc.transform.SyntheticMembers (line 115)
val sym = Symbols.defn.Any_toString

val toStringSymbol = sym.copy(
owner = clazz,
flags = sym.flags | Override,
info = clazz.thisType.memberInfo(sym),
coord = clazz.coord
).entered.asTerm

DefDef(toStringSymbol, body)
}

def addMethod(clazz: Clazz, method: Method): Clazz =
clazz.mapTemplate(
t => cpy.Template(t)(body = t.body :+ method)
)

def methodNames(clazz: Clazz): List[String] = clazz.t.body.collect {
case d: DefDef => d.name.toString
}

def isCaseClass(clazz: Clazz): Boolean = clazz.clazz.flags.is(CaseClass)
85 changes: 85 additions & 0 deletions plugin/src/main/scala/BetterToStringImpl.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
package com.kubukoz

// Source-compatible core between 2.x and 3.x implementations

trait CompilerApi {
type Tree
type Clazz
type Param
type ParamName
type Method

def className(clazz: Clazz): String
def params(clazz: Clazz): List[Param]
def literalConstant(value: String): Tree

def paramName(param: Param): ParamName
def selectInThis(clazz: Clazz, name: ParamName): Tree
def concat(l: Tree, r: Tree): Tree

def createToString(clazz: Clazz, body: Tree): Method
def addMethod(clazz: Clazz, method: Method): Clazz
def methodNames(clazz: Clazz): List[String]
def isCaseClass(clazz: Clazz): Boolean
}

trait BetterToStringImpl[+C <: CompilerApi] {
val compilerApi: C

def transformClass(
clazz: compilerApi.Clazz,
isNested: Boolean
): compilerApi.Clazz
}

object BetterToStringImpl {
def instance(
api: CompilerApi
): BetterToStringImpl[api.type] =
new BetterToStringImpl[api.type] {
val compilerApi: api.type = api

import api._

def transformClass(
clazz: Clazz,
isNested: Boolean
): Clazz = {
val hasToString: Boolean = methodNames(clazz).contains("toString")

val shouldModify =
isCaseClass(clazz) && !isNested && !hasToString

if (shouldModify) overrideToString(clazz)
else clazz
}

private def overrideToString(clazz: Clazz): Clazz =
addMethod(clazz, createToString(clazz, toStringImpl(clazz)))

private def toStringImpl(clazz: Clazz): Tree = {
val className = api.className(clazz)

val paramListParts: List[Tree] = params(clazz).zipWithIndex.flatMap {
case (v, index) =>
val commaPrefix = if (index > 0) ", " else ""

val name = paramName(v)

List(
literalConstant(commaPrefix ++ name.toString ++ " = "),
selectInThis(clazz, name)
)
}

val parts =
List(
List(literalConstant(className ++ "(")),
paramListParts,
List(literalConstant(")"))
).flatten

parts.reduceLeft(concat(_, _))
}
}
}
Loading

0 comments on commit 431e8df

Please sign in to comment.