Skip to content

Commit

Permalink
Fix assemble
Browse files Browse the repository at this point in the history
  • Loading branch information
rchowell committed Jul 16, 2024
1 parent 2b4b440 commit a04b43b
Show file tree
Hide file tree
Showing 11 changed files with 37 additions and 60 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import org.partiql.ast.Statement
import org.partiql.errors.Problem
import org.partiql.errors.ProblemCallback
import org.partiql.plan.PartiQLPlan
import org.partiql.planner.catalog.Catalogs
import org.partiql.planner.catalog.Session
import org.partiql.planner.internal.PartiQLPlannerDefault
import org.partiql.planner.internal.PlannerFlag
Expand Down Expand Up @@ -36,7 +37,7 @@ public interface PartiQLPlanner {
public companion object {

@JvmStatic
public fun builder(): PartiQLPlannerBuilder = PartiQLPlannerBuilder()
public fun builder(): Builder = Builder()

@JvmStatic
public fun default(): PartiQLPlanner = Builder().build()
Expand All @@ -45,13 +46,22 @@ public interface PartiQLPlanner {
public class Builder {

private val flags: MutableSet<PlannerFlag> = mutableSetOf()
private var catalogs: Catalogs? = null

/**
* Build the builder, return an implementation of a [PartiQLPlanner].
*
* @return
*/
public fun build(): PartiQLPlanner = PartiQLPlannerDefault(flags)
public fun build(): PartiQLPlanner {
assert(catalogs != null) { "The `catalogs` field cannot be null, set with .catalgos(...)"}
return PartiQLPlannerDefault(catalogs!!, flags)
}

/**
* Adds a catalog provider to this planner builder.
*/
public fun catalogs(catalogs: Catalogs): Builder = this.apply { this.catalogs = catalogs }

/**
* Java style method for setting the planner to signal mode
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -230,10 +230,10 @@ public class Identifier private constructor(
}

@JvmStatic
public fun of(vararg parts: String): Identifier = of(parts.toList())
public fun delimited(vararg parts: String): Identifier = delimited(parts.toList())

@JvmStatic
public fun of(parts: Collection<String>): Identifier {
public fun delimited(parts: Collection<String>): Identifier {
if (parts.isEmpty()) {
error("Cannot create an identifier with no parts")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ public class Name(
val parts = mutableListOf<String>()
parts.addAll(namespace.getLevels())
parts.add(name)
return Identifier.of(parts).toString()
return Identifier.delimited(parts).toString()
}

public companion object {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ public class Namespace private constructor(
* Return the SQL identifier representation of this namespace.
*/
public override fun toString(): String {
return Identifier.of(*levels).toString()
return Identifier.delimited(*levels).toString()
}

public companion object {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ public sealed interface Routine {
/**
* The function return type. Required.
*/
public fun getReturnType(): PType.Kind
public fun getReturnType(): PType

/**
* Represents an SQL row-value expression call.
Expand Down Expand Up @@ -79,12 +79,12 @@ public sealed interface Routine {
public fun scalar(
name: String,
parameters: Collection<Parameter>,
returnType: PType.Kind,
returnType: PType,
properties: Properties = DEFAULT_PROPERTIES,
): Scalar = object : Scalar {
override fun getName(): String = name
override fun getParameters(): Array<Parameter> = parameters.toTypedArray()
override fun getReturnType(): PType.Kind = returnType
override fun getReturnType(): PType = returnType
override fun getProperties(): Properties = properties
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package org.partiql.planner.internal

import org.partiql.planner.catalog.Catalog
import org.partiql.planner.catalog.Catalogs
import org.partiql.planner.catalog.Identifier
import org.partiql.planner.catalog.Session
import org.partiql.planner.internal.casts.CastTable
Expand All @@ -19,7 +20,7 @@ import org.partiql.planner.internal.typer.CompilerType
* @property session
*/
internal class Env(
private val catalog: Catalog,
private val catalogs: Catalogs,
private val session: Session,
) {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import org.partiql.ast.Statement
import org.partiql.ast.normalize.normalize
import org.partiql.errors.ProblemCallback
import org.partiql.planner.PartiQLPlanner
import org.partiql.planner.catalog.Catalog
import org.partiql.planner.catalog.Catalogs
import org.partiql.planner.catalog.Session
import org.partiql.planner.internal.transforms.AstToPlan
import org.partiql.planner.internal.transforms.PlanTransform
Expand All @@ -14,7 +14,7 @@ import org.partiql.planner.internal.typer.PlanTyper
* Default PartiQL logical query planner.
*/
internal class PartiQLPlannerDefault(
private val catalog: Catalog,
private val catalogs: Catalogs,
private val flags: Set<PlannerFlag>
) : PartiQLPlanner {

Expand All @@ -25,7 +25,7 @@ internal class PartiQLPlannerDefault(
): PartiQLPlanner.Result {

// 0. Initialize the planning environment
val env = Env(catalog, session)
val env = Env(catalogs, session)

// 1. Normalize
val ast = statement.normalize()
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -1418,23 +1418,6 @@ internal data class Rel(
internal fun builder(): RelOpExcludeTypeStructSymbolBuilder =
RelOpExcludeTypeStructSymbolBuilder()
}

// Explicitly override `equals` and `hashcode` for case-insensitivity
override fun equals(other: Any?): Boolean {
if (this === other) return true
if (javaClass != other?.javaClass) return false

other as StructSymbol

if (!symbol.equals(other.symbol, ignoreCase = true)) return false
if (children != other.children) return false

return true
}

override fun hashCode(): Int {
return symbol.lowercase().hashCode()
}
}

internal data class StructKey(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import org.partiql.planner.internal.ir.Ref
import org.partiql.planner.internal.ir.Rel
import org.partiql.planner.internal.ir.Rex
import org.partiql.planner.internal.ir.Statement
import org.partiql.planner.internal.ir.ref
import org.partiql.planner.internal.ir.visitor.PlanBaseVisitor
import org.partiql.value.PartiQLValueExperimental

Expand Down Expand Up @@ -143,7 +144,8 @@ internal class PlanTransform(
}

override fun visitRexOpCallStatic(node: Rex.Op.Call.Static, ctx: Unit): org.partiql.plan.Rex.Op {
val fn = visitRef(node.fn, ctx)
val ref = ref(node.fn.getName())
val fn = visitRef(ref, ctx)
val args = node.args.map { visitRex(it, ctx) }
return org.partiql.plan.rexOpCallStatic(fn, args)
}
Expand All @@ -161,7 +163,8 @@ internal class PlanTransform(
}

override fun visitRexOpCallDynamicCandidate(node: Rex.Op.Call.Dynamic.Candidate, ctx: Unit): PlanNode {
val fn = visitRef(node.fn, ctx)
val ref = ref(node.fn.getName())
val fn = visitRef(ref, ctx)
val coercions = node.coercions.map { it?.let { visitRefCast(it, ctx) } }
return org.partiql.plan.Rex.Op.Call.Dynamic.Candidate(fn, coercions)
}
Expand Down Expand Up @@ -364,7 +367,8 @@ internal class PlanTransform(
}

override fun visitRelOpAggregateCallResolved(node: Rel.Op.Aggregate.Call.Resolved, ctx: Unit): PlanNode {
val agg = visitRef(node.agg, ctx)
val ref = ref(node.agg.getName())
val agg = visitRef(ref, ctx)
val args = node.args.map { visitRex(it, ctx) }
val setQuantifier = when (node.setQuantifier) {
Rel.Op.Aggregate.SetQuantifier.ALL -> org.partiql.plan.Rel.Op.Aggregate.Call.SetQuantifier.ALL
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

package org.partiql.planner.internal.typer

import org.partiql.planner.catalog.Identifier
import org.partiql.planner.internal.Env
import org.partiql.planner.internal.ProblemGenerator
import org.partiql.planner.internal.exclude.ExcludeRepr
Expand Down Expand Up @@ -796,17 +795,17 @@ internal class PlanTyper(private val env: Env) {
}

// Check if any arg is always missing
val argIsAlwaysMissing = args.any { it.type.isMissingValue }
if (node.fn.signature.isMissingCall && argIsAlwaysMissing) {
val hasMissingArg = args.any { it.type.isMissingValue }
if (hasMissingArg) {
return ProblemGenerator.missingRex(
node,
ProblemGenerator.expressionAlwaysReturnsMissing("Static function always receives MISSING arguments."),
CompilerType(node.fn.signature.returns, isMissingValue = true)
CompilerType(node.fn.getReturnType(), isMissingValue = true)
)
}

// Infer fn return type
return rex(CompilerType(node.fn.signature.returns), Rex.Op.Call.Static(node.fn, args))
return rex(CompilerType(node.fn.getReturnType()), Rex.Op.Call.Static(node.fn, args))
}

/**
Expand All @@ -817,9 +816,7 @@ internal class PlanTyper(private val env: Env) {
* @return
*/
override fun visitRexOpCallDynamic(node: Rex.Op.Call.Dynamic, ctx: CompilerType?): Rex {
val types = node.candidates.map { candidate ->
val kind = candidate.fn.getReturnType()
}.toMutableSet()
val types = node.candidates.map { candidate -> candidate.fn.getReturnType().toCType() }.toMutableSet()
// TODO: Should this always be DYNAMIC?
return Rex(type = CompilerType.anyOf(types), op = node)
}
Expand Down Expand Up @@ -1159,7 +1156,7 @@ internal class PlanTyper(private val env: Env) {
if (firstBranchCondition !is Rex.Op.Call.Static) {
return null
}
if (!firstBranchCondition.fn.signature.name.equals("is_struct", ignoreCase = true)) {
if (!firstBranchCondition.fn.getName().equals("is_struct", ignoreCase = true)) {
return null
}
val firstBranchResultType = firstBranch.rex.type
Expand Down

0 comments on commit a04b43b

Please sign in to comment.