Skip to content

Commit

Permalink
Generate typed visitors on unions
Browse files Browse the repository at this point in the history
  • Loading branch information
kubukoz committed Nov 12, 2024
1 parent 88ce227 commit 9c24e99
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 19 deletions.
35 changes: 17 additions & 18 deletions modules/core/src/main/scala/playground/smithyql/RangeIndex.scala
Original file line number Diff line number Diff line change
Expand Up @@ -30,32 +30,31 @@ object RangeIndex {
}

def inputNodeRanges(node: playground.generated.nodes.InputNode, base: NodeContext)
: List[ContextRange] =
node match {
case playground.generated.nodes.String_(string) =>
ContextRange(string.range.shrink1, base.inQuotes) :: Nil

case playground.generated.nodes.List_(list) =>
ContextRange(list.range.shrink1, base.inCollectionEntry(None)) ::
list.list_fields.zipWithIndex.flatMap { (inputNode, i) =>
: List[ContextRange] = node.visit(
new playground.generated.nodes.InputNode.Visitor.Default[List[ContextRange]] {
def default: List[ContextRange] = Nil

override def onString(node: playground.generated.nodes.String_): List[ContextRange] =
ContextRange(node.range.shrink1, base.inQuotes) :: Nil

override def onList(node: playground.generated.nodes.List_): List[ContextRange] =
ContextRange(node.range.shrink1, base.inCollectionEntry(None)) ::
node.list_fields.zipWithIndex.flatMap { (inputNode, i) =>
ContextRange(inputNode.range, base.inCollectionEntry(Some(i))) ::
inputNodeRanges(inputNode, base.inCollectionEntry(Some(i)))
}

case playground.generated.nodes.Struct(struct) =>
ContextRange(struct.range, base) ::
ContextRange(struct.range.shrink1, base.inStructBody) ::
struct.bindings.toList.flatMap { binding =>
override def onStruct(node: playground.generated.nodes.Struct): List[ContextRange] =
ContextRange(node.range, base) ::
ContextRange(node.range.shrink1, base.inStructBody) ::
node.bindings.toList.flatMap { binding =>
(binding.key, binding.value).tupled.toList.flatMap { (key, value) =>
ContextRange(
value.range,
base.inStructBody.inStructValue(key.source),
) :: inputNodeRanges(value, base.inStructBody.inStructValue(key.source))
ContextRange(value.range, base.inStructBody.inStructValue(key.source)) ::
inputNodeRanges(value, base.inStructBody.inStructValue(key.source))
}
}

case _ => Nil
}
)

val queryRanges = parsed.statements.zipWithIndex.flatMap { (stat, statementIndex) =>
stat.run_query.toList.flatMap { runQuery =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@ import scala.meta.Dialect
extension (tn: TypeName) {
@targetName("renderTypeName")
def render: String = tn.value.dropWhile(_ == '_').fromSnakeCase.ident
def renderProjection: String = show"as${tn.value.dropWhile(_ == '_').fromSnakeCase}".ident
def renderProjection: String = show"as${tn.prettyName}".ident
def renderVisitorMethod: String = show"on${tn.prettyName}".ident
private def prettyName = tn.value.dropWhile(_ == '_').fromSnakeCase
def asChildName: FieldName = FieldName(tn.value)
}

Expand Down Expand Up @@ -53,6 +55,7 @@ private def renderUnion(u: Type.Union): String = {
val instanceMethods =
show"""extension (node: $name) {
|${projections.mkString_("\n").indentTrim(2)}
| def visit[A](visitor: Visitor[A]): A = visitor.visit(node)
|}""".stripMargin

val applyMethod = {
Expand All @@ -68,6 +71,39 @@ private def renderUnion(u: Type.Union): String = {

val typedApplyMethod = show"""def apply(node: $underlyingType): $name = node""".stripMargin

val visitor =
show"""
|trait Visitor[A] {
|${u
.subtypes
.map(sub => show"def ${sub.name.renderVisitorMethod}(node: ${sub.name.render}): A")
.mkString_("\n")
.indentTrim(2)}
|
| def visit(node: $name): A = (node: @nowarn("msg=match may not be exhaustive")) match {
|${u
.subtypes
.map(sub => show"case ${sub.name.render}(node) => ${sub.name.renderVisitorMethod}(node)")
.mkString_("\n")
.indentTrim(4)}
| }
|}
|
|object Visitor {
| abstract class Default[A] extends Visitor[A] {
| def default: A
|
|${u
.subtypes
.map(sub =>
show"def ${sub.name.renderVisitorMethod}(node: ${sub.name.render}): A = default"
)
.mkString_("\n")
.indentTrim(4)}
| }
|}
|""".stripMargin

val selectorMethods = u
.subtypes
.map { subtype =>
Expand All @@ -82,6 +118,7 @@ private def renderUnion(u: Type.Union): String = {
|
|import ${classOf[Node].getName()}
|import playground.treesitter4s.std.Selection
|import annotation.nowarn
|
|opaque type $name <: Node = $underlyingType
|
Expand All @@ -97,6 +134,8 @@ private def renderUnion(u: Type.Union): String = {
|
| def unapply(node: Node): Option[$name] = apply(node).toOption
|
|${visitor.indentTrim(2)}
|
| final case class Selector(path: List[$name]) extends Selection[$name] {
|${selectorMethods.indentTrim(4)}
|
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package playground.generated.nodes

import org.polyvariant.treesitter4s.Node
import playground.treesitter4s.std.Selection
import annotation.nowarn

opaque type InputNode <: Node = Boolean_ | List_ | Null_ | Number | String_ | Struct

Expand All @@ -15,6 +16,7 @@ object InputNode {
def asNumber: Option[Number] = Number.unapply(node)
def asString: Option[String_] = String_.unapply(node)
def asStruct: Option[Struct] = Struct.unapply(node)
def visit[A](visitor: Visitor[A]): A = visitor.visit(node)
}

def apply(node: Node): Either[String, InputNode] = node match {
Expand All @@ -33,6 +35,38 @@ object InputNode {

def unapply(node: Node): Option[InputNode] = apply(node).toOption


trait Visitor[A] {
def onBoolean(node: Boolean_): A
def onList(node: List_): A
def onNull(node: Null_): A
def onNumber(node: Number): A
def onString(node: String_): A
def onStruct(node: Struct): A

def visit(node: InputNode): A = (node: @nowarn("msg=match may not be exhaustive")) match {
case Boolean_(node) => onBoolean(node)
case List_(node) => onList(node)
case Null_(node) => onNull(node)
case Number(node) => onNumber(node)
case String_(node) => onString(node)
case Struct(node) => onStruct(node)
}
}

object Visitor {
abstract class Default[A] extends Visitor[A] {
def default: A

def onBoolean(node: Boolean_): A = default
def onList(node: List_): A = default
def onNull(node: Null_): A = default
def onNumber(node: Number): A = default
def onString(node: String_): A = default
def onStruct(node: Struct): A = default
}
}

final case class Selector(path: List[InputNode]) extends Selection[InputNode] {
def boolean : Boolean_.Selector = Boolean_.Selector(path.flatMap(_.asBoolean))
def list : List_.Selector = List_.Selector(path.flatMap(_.asList))
Expand Down

0 comments on commit 9c24e99

Please sign in to comment.