diff --git a/hail/src/main/scala/is/hail/backend/Backend.scala b/hail/src/main/scala/is/hail/backend/Backend.scala
index 34cb0668fcc9..7cd155489642 100644
--- a/hail/src/main/scala/is/hail/backend/Backend.scala
+++ b/hail/src/main/scala/is/hail/backend/Backend.scala
@@ -3,9 +3,7 @@ package is.hail.backend
 import is.hail.asm4s._
 import is.hail.backend.Backend.jsonToBytes
 import is.hail.backend.spark.SparkBackend
-import is.hail.expr.ir.{
-  BaseIR, IR, IRParser, IRParserEnvironment, LoweringAnalyses, SortField, TableIR, TableReader,
-}
+import is.hail.expr.ir.{IR, IRParser, LoweringAnalyses, SortField, TableIR, TableReader}
 import is.hail.expr.ir.lowering.{TableStage, TableStageDependency}
 import is.hail.io.{BufferSpec, TypedCodecSpec}
 import is.hail.io.fs._
@@ -18,7 +16,6 @@ import is.hail.types.virtual.TFloat64
 import is.hail.utils._
 import is.hail.variant.ReferenceGenome
 
-import scala.collection.mutable
 import scala.reflect.ClassTag
 
 import java.io._
@@ -80,8 +77,6 @@ abstract class Backend extends Closeable {
     StreamReadConstraints.builder().maxStringLength(Integer.MAX_VALUE).build()
   )
 
-  val persistedIR: mutable.Map[Int, BaseIR] = mutable.Map()
-
   def defaultParallelism: Int
 
   def canExecuteParallelTasksOnDriver: Boolean = true
@@ -138,30 +133,30 @@ abstract class Backend extends Closeable {
   def withExecuteContext[T](f: ExecuteContext => T)(implicit E: Enclosing): T
 
   final def valueType(s: String): Array[Byte] =
-    jsonToBytes {
-      withExecuteContext { ctx =>
-        IRParser.parse_value_ir(s, IRParserEnvironment(ctx, persistedIR.toMap)).typ.toJSON
+    withExecuteContext { ctx =>
+      jsonToBytes {
+        IRParser.parse_value_ir(ctx, s).typ.toJSON
       }
     }
 
   final def tableType(s: String): Array[Byte] =
-    jsonToBytes {
-      withExecuteContext { ctx =>
-        IRParser.parse_table_ir(s, IRParserEnvironment(ctx, persistedIR.toMap)).typ.toJSON
+    withExecuteContext { ctx =>
+      jsonToBytes {
+        IRParser.parse_table_ir(ctx, s).typ.toJSON
       }
     }
 
   final def matrixTableType(s: String): Array[Byte] =
-    jsonToBytes {
-      withExecuteContext { ctx =>
-        IRParser.parse_matrix_ir(s, IRParserEnvironment(ctx, persistedIR.toMap)).typ.toJSON
+    withExecuteContext { ctx =>
+      jsonToBytes {
+        IRParser.parse_matrix_ir(ctx, s).typ.toJSON
       }
     }
 
   final def blockMatrixType(s: String): Array[Byte] =
-    jsonToBytes {
-      withExecuteContext { ctx =>
-        IRParser.parse_blockmatrix_ir(s, IRParserEnvironment(ctx, persistedIR.toMap)).typ.toJSON
+    withExecuteContext { ctx =>
+      jsonToBytes {
+        IRParser.parse_blockmatrix_ir(ctx, s).typ.toJSON
       }
     }
 
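Note: the file above is the template for the whole refactor. Every parser entry point drops the
IRParserEnvironment wrapper and takes the ExecuteContext directly, and jsonToBytes moves inside
withExecuteContext so serialization happens while the context is still open. A minimal sketch of
a caller against the new signatures (typeOfValueIR is a hypothetical helper, not part of this
diff; it simply mirrors valueType above):

    // Before, every call site had to snapshot the backend-wide IR map:
    //   IRParser.parse_value_ir(s, IRParserEnvironment(ctx, persistedIR.toMap))
    // After, the context carries everything the parser needs, cached IRs included:
    def typeOfValueIR(backend: Backend, s: String): Array[Byte] =
      backend.withExecuteContext { ctx =>
        jsonToBytes {
          IRParser.parse_value_ir(ctx, s).typ.toJSON
        }
      }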
diff --git a/hail/src/main/scala/is/hail/backend/BackendServer.scala b/hail/src/main/scala/is/hail/backend/BackendServer.scala
index 04ab2200641e..24e7f7b98755 100644
--- a/hail/src/main/scala/is/hail/backend/BackendServer.scala
+++ b/hail/src/main/scala/is/hail/backend/BackendServer.scala
@@ -1,6 +1,6 @@
 package is.hail.backend
 
-import is.hail.expr.ir.{IRParser, IRParserEnvironment}
+import is.hail.expr.ir.IRParser
 import is.hail.utils._
 
 import scala.util.control.NonFatal
@@ -89,10 +89,7 @@ class BackendHttpHandler(backend: Backend) extends HttpHandler {
       backend.withExecuteContext { ctx =>
         val (res, timings) = ExecutionTimer.time { timer =>
           ctx.local(timer = timer) { ctx =>
-            val irData = IRParser.parse_value_ir(
-              irStr,
-              IRParserEnvironment(ctx, irMap = backend.persistedIR.toMap),
-            )
+            val irData = IRParser.parse_value_ir(ctx, irStr)
             backend.execute(ctx, irData)
           }
         }
diff --git a/hail/src/main/scala/is/hail/backend/ExecuteContext.scala b/hail/src/main/scala/is/hail/backend/ExecuteContext.scala
index 42112522ffa0..a572bed01906 100644
--- a/hail/src/main/scala/is/hail/backend/ExecuteContext.scala
+++ b/hail/src/main/scala/is/hail/backend/ExecuteContext.scala
@@ -4,7 +4,7 @@ import is.hail.{HailContext, HailFeatureFlags}
 import is.hail.annotations.{Region, RegionPool}
 import is.hail.asm4s.HailClassLoader
 import is.hail.backend.local.LocalTaskContext
-import is.hail.expr.ir.{CodeCacheKey, CompiledFunction}
+import is.hail.expr.ir.{BaseIR, CodeCacheKey, CompiledFunction}
 import is.hail.expr.ir.lowering.IrMetadata
 import is.hail.io.fs.FS
 import is.hail.linalg.BlockMatrix
@@ -75,6 +75,7 @@
     irMetadata: IrMetadata,
     blockMatrixCache: mutable.Map[String, BlockMatrix],
     codeCache: mutable.Map[CodeCacheKey, CompiledFunction[_]],
+    irCache: mutable.Map[Int, BaseIR],
   )(
     f: ExecuteContext => T
   ): T = {
@@ -95,6 +96,7 @@
       irMetadata,
       blockMatrixCache,
       codeCache,
+      irCache,
     ))(f(_))
   }
 }
@@ -126,6 +128,7 @@ class ExecuteContext(
   val irMetadata: IrMetadata,
   val BlockMatrixCache: mutable.Map[String, BlockMatrix],
   val CodeCache: mutable.Map[CodeCacheKey, CompiledFunction[_]],
+  val IrCache: mutable.Map[Int, BaseIR],
 ) extends Closeable {
 
   val rngNonce: Long =
@@ -196,6 +199,7 @@ class ExecuteContext(
     irMetadata: IrMetadata = this.irMetadata,
     blockMatrixCache: mutable.Map[String, BlockMatrix] = this.BlockMatrixCache,
     codeCache: mutable.Map[CodeCacheKey, CompiledFunction[_]] = this.CodeCache,
+    irCache: mutable.Map[Int, BaseIR] = this.IrCache,
   )(
     f: ExecuteContext => A
   ): A =
@@ -214,5 +218,6 @@ class ExecuteContext(
       irMetadata,
       blockMatrixCache,
       codeCache,
+      irCache,
     ))(f)
 }
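Note: ExecuteContext gains the cache as a plain constructor field alongside BlockMatrixCache and
CodeCache, so per-context code no longer reaches back into the Backend singleton. The map
instance is supplied by the backend and shared across contexts, which is what keeps a pinned id
resolvable from one query to the next. Roughly (the id 42 and the ir value are illustrative):

    // pin an IR under an id inside one context...
    ctx.IrCache += (42 -> ir)
    // ...and resolve it later from another context built over the same backend-owned map;
    // this is exactly the lookup the parser performs for (JavaIR 42) / (JavaTable 42)
    val pinned = ctx.IrCache(42).asInstanceOf[IR]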
diff --git a/hail/src/main/scala/is/hail/backend/local/LocalBackend.scala b/hail/src/main/scala/is/hail/backend/local/LocalBackend.scala
index 10498f6debf4..d89e995c1287 100644
--- a/hail/src/main/scala/is/hail/backend/local/LocalBackend.scala
+++ b/hail/src/main/scala/is/hail/backend/local/LocalBackend.scala
@@ -79,6 +79,7 @@ class LocalBackend(
   private[this] val theHailClassLoader = new HailClassLoader(getClass.getClassLoader)
   private[this] val codeCache = new Cache[CodeCacheKey, CompiledFunction[_]](50)
+  private[this] val persistedIR: mutable.Map[Int, BaseIR] = mutable.Map()
 
   // flags can be set after construction from python
   def fs: FS = RouterFS.buildRoutes(CloudStorageFSConfig.fromFlagsAndEnv(None, flags))
 
@@ -103,6 +104,7 @@ class LocalBackend(
       new IrMetadata(),
       ImmutableMap.empty,
       codeCache,
+      persistedIR,
     )(f)
   }
 
diff --git a/hail/src/main/scala/is/hail/backend/py4j/Py4JBackendExtensions.scala b/hail/src/main/scala/is/hail/backend/py4j/Py4JBackendExtensions.scala
index 85f3eee5c302..48e3e7f1c27a 100644
--- a/hail/src/main/scala/is/hail/backend/py4j/Py4JBackendExtensions.scala
+++ b/hail/src/main/scala/is/hail/backend/py4j/Py4JBackendExtensions.scala
@@ -4,9 +4,9 @@ import is.hail.HailFeatureFlags
 import is.hail.backend.{Backend, ExecuteContext, NonOwningTempFileManager, TempFileManager}
 import is.hail.expr.{JSONAnnotationImpex, SparkAnnotationImpex}
 import is.hail.expr.ir.{
-  BaseIR, BindingEnv, BlockMatrixIR, EncodedLiteral, GetFieldByIdx, IR, IRParser,
-  IRParserEnvironment, Interpret, MatrixIR, MatrixNativeReader, MatrixRead, Name,
-  NativeReaderOptions, TableIR, TableLiteral, TableValue,
+  BaseIR, BindingEnv, BlockMatrixIR, EncodedLiteral, GetFieldByIdx, IR, IRParser, Interpret,
+  MatrixIR, MatrixNativeReader, MatrixRead, Name, NativeReaderOptions, TableIR, TableLiteral,
+  TableValue,
 }
 import is.hail.expr.ir.IRParser.parseType
 import is.hail.expr.ir.functions.IRFunctionRegistry
@@ -34,7 +34,6 @@ import sourcecode.Enclosing
 trait Py4JBackendExtensions {
   def backend: Backend
   def references: mutable.Map[String, ReferenceGenome]
-  def persistedIR: mutable.Map[Int, BaseIR]
   def flags: HailFeatureFlags
   def longLifeTempFileManager: TempFileManager
@@ -54,14 +53,14 @@ trait Py4JBackendExtensions {
       irID
   }
 
-  private[this] def addJavaIR(ir: BaseIR): Int = {
+  private[this] def addJavaIR(ctx: ExecuteContext, ir: BaseIR): Int = {
     val id = nextIRID()
-    persistedIR += (id -> ir)
+    ctx.IrCache += (id -> ir)
     id
   }
 
   def pyRemoveJavaIR(id: Int): Unit =
-    persistedIR.remove(id)
+    backend.withExecuteContext(_.IrCache.remove(id))
 
   def pyAddSequence(name: String, fastaFile: String, indexFile: String): Unit =
     backend.withExecuteContext { ctx =>
@@ -118,7 +117,7 @@ trait Py4JBackendExtensions {
     argTypeStrs: java.util.ArrayList[String],
     returnType: String,
     bodyStr: String,
-  ): Unit = {
+  ): Unit =
     backend.withExecuteContext { ctx =>
       IRFunctionRegistry.registerIR(
         ctx,
@@ -130,17 +129,16 @@ trait Py4JBackendExtensions {
         bodyStr,
       )
     }
-  }
 
   def pyExecuteLiteral(irStr: String): Int =
     backend.withExecuteContext { ctx =>
-      val ir = IRParser.parse_value_ir(irStr, IRParserEnvironment(ctx, persistedIR.toMap))
+      val ir = IRParser.parse_value_ir(ctx, irStr)
       assert(ir.typ.isRealizable)
       backend.execute(ctx, ir) match {
        case Left(_) => throw new HailException("Can't create literal")
        case Right((pt, addr)) =>
          val field = GetFieldByIdx(EncodedLiteral.fromPTypeAndAddress(pt, addr, ctx), 0)
-          addJavaIR(field)
+          addJavaIR(ctx, field)
      }
    }
@@ -159,14 +157,14 @@ trait Py4JBackendExtensions {
         ),
         ctx.theHailClassLoader,
       )
-      val id = addJavaIR(tir)
+      val id = addJavaIR(ctx, tir)
       (id, JsonMethods.compact(tir.typ.toJSON))
     }
   }
 
   def pyToDF(s: String): DataFrame =
     backend.withExecuteContext { ctx =>
-      val tir = IRParser.parse_table_ir(s, IRParserEnvironment(ctx, irMap = persistedIR.toMap))
+      val tir = IRParser.parse_table_ir(ctx, s)
       Interpret(tir, ctx).toDF()
     }
 
@@ -231,8 +229,8 @@ trait Py4JBackendExtensions {
   def parse_value_ir(s: String, refMap: java.util.Map[String, String]): IR =
     backend.withExecuteContext { ctx =>
       IRParser.parse_value_ir(
+        ctx,
         s,
-        IRParserEnvironment(ctx, irMap = persistedIR.toMap),
         BindingEnv.eval(refMap.asScala.toMap.map { case (n, t) =>
           Name(n) -> IRParser.parseType(t)
         }.toSeq: _*),
@@ -240,18 +238,14 @@ trait Py4JBackendExtensions {
     }
 
   def parse_table_ir(s: String): TableIR =
-    withExecuteContext(selfContainedExecution = false) { ctx =>
-      IRParser.parse_table_ir(s, IRParserEnvironment(ctx, irMap = persistedIR.toMap))
-    }
+    withExecuteContext(selfContainedExecution = false)(ctx => IRParser.parse_table_ir(ctx, s))
 
   def parse_matrix_ir(s: String): MatrixIR =
-    withExecuteContext(selfContainedExecution = false) { ctx =>
-      IRParser.parse_matrix_ir(s, IRParserEnvironment(ctx, irMap = persistedIR.toMap))
-    }
+    withExecuteContext(selfContainedExecution = false)(ctx => IRParser.parse_matrix_ir(ctx, s))
 
   def parse_blockmatrix_ir(s: String): BlockMatrixIR =
     withExecuteContext(selfContainedExecution = false) { ctx =>
-      IRParser.parse_blockmatrix_ir(s, IRParserEnvironment(ctx, irMap = persistedIR.toMap))
+      IRParser.parse_blockmatrix_ir(ctx, s)
     }
 
   def loadReferencesFromDataset(path: String): Array[Byte] =
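Note: the Python-facing contract is unchanged. pyExecuteLiteral still answers with an integer id
that later IR text references as (JavaIR <id>), and pyRemoveJavaIR still unpins it; only the
storage moved from the Py4J trait into the context. pyRemoveJavaIR now opens a context just to
touch the cache, which is presumably cheap since withExecuteContext reuses the backend's
long-lived caches. The round trip, schematically:

    val id = pyExecuteLiteral("(I32 5)") // parse, execute, pin the resulting literal
    // ... later Python-side queries may embed (JavaIR <id>); the parser resolves it ...
    pyRemoveJavaIR(id)                   // unpin once the Python handle is released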
diff --git a/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala b/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala
index 1ff50db0e803..ca9d63fdb4e1 100644
--- a/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala
+++ b/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala
@@ -390,6 +390,7 @@ class ServiceBackend(
       new IrMetadata(),
       ImmutableMap.empty,
       mutable.Map.empty,
+      ImmutableMap.empty,
     )(f)
   }
 
diff --git a/hail/src/main/scala/is/hail/backend/spark/SparkBackend.scala b/hail/src/main/scala/is/hail/backend/spark/SparkBackend.scala
index baad79e9b1e2..9b7ea5e7ad7e 100644
--- a/hail/src/main/scala/is/hail/backend/spark/SparkBackend.scala
+++ b/hail/src/main/scala/is/hail/backend/spark/SparkBackend.scala
@@ -341,6 +341,7 @@ class SparkBackend(
 
   private[this] val bmCache = new BlockMatrixCache()
   private[this] val codeCache = new Cache[CodeCacheKey, CompiledFunction[_]](50)
+  private[this] val persistedIr = mutable.Map.empty[Int, BaseIR]
 
   def createExecuteContextForTests(
     timer: ExecutionTimer,
@@ -365,6 +366,7 @@ class SparkBackend(
       new IrMetadata(),
       ImmutableMap.empty,
       mutable.Map.empty,
+      ImmutableMap.empty,
     )
 
   override def withExecuteContext[T](f: ExecuteContext => T)(implicit E: Enclosing): T =
@@ -386,6 +388,7 @@ class SparkBackend(
       new IrMetadata(),
       bmCache,
       codeCache,
+      persistedIr,
     )(f)
   }
 
diff --git a/hail/src/main/scala/is/hail/expr/ir/MatrixIR.scala b/hail/src/main/scala/is/hail/expr/ir/MatrixIR.scala
index 1b8f03ef08db..d760850a6964 100644
--- a/hail/src/main/scala/is/hail/expr/ir/MatrixIR.scala
+++ b/hail/src/main/scala/is/hail/expr/ir/MatrixIR.scala
@@ -113,14 +113,14 @@ case class MatrixLiteral(typ: MatrixType, tl: TableLiteral) extends MatrixIR {
 }
 
 object MatrixReader {
-  def fromJson(env: IRParserEnvironment, jv: JValue): MatrixReader = {
+  def fromJson(ctx: ExecuteContext, jv: JValue): MatrixReader = {
     implicit val formats: Formats = DefaultFormats
     (jv \ "name").extract[String] match {
-      case "MatrixRangeReader" => MatrixRangeReader.fromJValue(env.ctx, jv)
-      case "MatrixNativeReader" => MatrixNativeReader.fromJValue(env.ctx.fs, jv)
-      case "MatrixBGENReader" => MatrixBGENReader.fromJValue(env, jv)
-      case "MatrixPLINKReader" => MatrixPLINKReader.fromJValue(env.ctx, jv)
-      case "MatrixVCFReader" => MatrixVCFReader.fromJValue(env.ctx, jv)
+      case "MatrixRangeReader" => MatrixRangeReader.fromJValue(ctx, jv)
+      case "MatrixNativeReader" => MatrixNativeReader.fromJValue(ctx.fs, jv)
+      case "MatrixBGENReader" => MatrixBGENReader.fromJValue(ctx, jv)
+      case "MatrixPLINKReader" => MatrixPLINKReader.fromJValue(ctx, jv)
+      case "MatrixVCFReader" => MatrixVCFReader.fromJValue(ctx, jv)
     }
   }
 
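Note: each backend now chooses the storage behind the context's IR cache. SparkBackend and
LocalBackend pass a long-lived mutable.Map (persistedIr / persistedIR above), so pinned IRs
survive across queries; ServiceBackend and the Spark test-context constructor pass
ImmutableMap.empty, which is assumed here to discard writes, effectively disabling IR pinning
where no driver ever emits (JavaIR ...) references. Schematically, from the hunks above:

    // SparkBackend / LocalBackend: pinned IRs must outlive a single query,
    // so the backend owns the map and threads it into every context it creates
    private[this] val persistedIr = mutable.Map.empty[Int, BaseIR]

    // ServiceBackend: a one-shot batch driver with nothing to pin passes
    // ImmutableMap.empty as the irCache argument instead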
diff --git a/hail/src/main/scala/is/hail/expr/ir/Parser.scala b/hail/src/main/scala/is/hail/expr/ir/Parser.scala
index f7698a025d9b..7cb771a30923 100644
--- a/hail/src/main/scala/is/hail/expr/ir/Parser.scala
+++ b/hail/src/main/scala/is/hail/expr/ir/Parser.scala
@@ -1,6 +1,5 @@
 package is.hail.expr.ir
 
-import is.hail.HailContext
 import is.hail.backend.ExecuteContext
 import is.hail.expr.{JSONAnnotationImpex, Nat, ParserUtils}
 import is.hail.expr.ir.agg._
@@ -127,12 +126,8 @@ object IRLexer extends JavaTokenParsers {
   }
 }
 
-case class IRParserEnvironment(
-  ctx: ExecuteContext,
-  irMap: Map[Int, BaseIR] = Map.empty,
-)
-
 object IRParser {
+
   def error(t: Token, msg: String): Nothing = ParserUtils.error(t.pos, msg)
 
   def deserialize[T](str: String)(implicit formats: Formats, mf: Manifest[T]): T = {
@@ -241,13 +236,13 @@ object IRParser {
     case x: Token => error(x, s"Expected string but found ${x.getName} '${x.value}'.")
   }
 
-  def partitioner_literal(env: IRParserEnvironment)(it: TokenIterator): RVDPartitioner = {
+  def partitioner_literal(ctx: ExecuteContext)(it: TokenIterator): RVDPartitioner = {
     identifier(it, "Partitioner")
     val keyType = type_expr(it).asInstanceOf[TStruct]
     val vJSON = JsonMethods.parse(string_literal(it))
     val rangeBounds = JSONAnnotationImpex.importAnnotation(vJSON, TArray(TInterval(keyType)))
     new RVDPartitioner(
-      env.ctx.stateManager,
+      ctx.stateManager,
       keyType,
       rangeBounds.asInstanceOf[mutable.IndexedSeq[Interval]],
     )
@@ -667,7 +662,7 @@ object IRParser {
 
   def agg_op(it: TokenIterator): AggOp = AggOp.fromString(identifier(it))
 
-  def agg_state_signature(env: IRParserEnvironment)(it: TokenIterator): AggStateSig = {
+  def agg_state_signature(ctx: ExecuteContext)(it: TokenIterator): AggStateSig = {
     punctuation(it, "(")
     val sig = identifier(it) match {
       case "TypedStateSig" =>
@@ -697,46 +692,46 @@ object IRParser {
         CollectAsSetStateSig(pt)
       case "CallStatsStateSig" => CallStatsStateSig()
       case "ArrayAggStateSig" =>
-        val nested = agg_state_signatures(env)(it)
+        val nested = agg_state_signatures(ctx)(it)
         ArrayAggStateSig(nested)
       case "GroupedStateSig" =>
         val kt = vtwr_expr(it)
-        val nested = agg_state_signatures(env)(it)
+        val nested = agg_state_signatures(ctx)(it)
         GroupedStateSig(kt, nested)
       case "ApproxCDFStateSig" => ApproxCDFStateSig()
       case "FoldStateSig" =>
         val vtwr = vtwr_expr(it)
         val accumName = name(it)
         val otherAccumName = name(it)
-        val combIR = ir_value_expr(env)(it).run()
+        val combIR = ir_value_expr(ctx)(it).run()
         FoldStateSig(vtwr.canonicalEmitType, accumName, otherAccumName, combIR)
     }
     punctuation(it, ")")
     sig
   }
 
-  def agg_state_signatures(env: IRParserEnvironment)(it: TokenIterator): Array[AggStateSig] =
-    base_seq_parser(agg_state_signature(env))(it)
+  def agg_state_signatures(ctx: ExecuteContext)(it: TokenIterator): Array[AggStateSig] =
+    base_seq_parser(agg_state_signature(ctx))(it)
 
-  def p_agg_sigs(env: IRParserEnvironment)(it: TokenIterator): Array[PhysicalAggSig] =
-    base_seq_parser(p_agg_sig(env))(it)
+  def p_agg_sigs(ctx: ExecuteContext)(it: TokenIterator): Array[PhysicalAggSig] =
+    base_seq_parser(p_agg_sig(ctx))(it)
 
-  def p_agg_sig(env: IRParserEnvironment)(it: TokenIterator): PhysicalAggSig = {
+  def p_agg_sig(ctx: ExecuteContext)(it: TokenIterator): PhysicalAggSig = {
     punctuation(it, "(")
     val sig = identifier(it) match {
       case "Grouped" =>
         val pt = vtwr_expr(it)
-        val nested = p_agg_sigs(env)(it)
+        val nested = p_agg_sigs(ctx)(it)
         GroupedAggSig(pt, nested)
       case "ArrayLen" =>
         val knownLength = boolean_literal(it)
-        val nested = p_agg_sigs(env)(it)
+        val nested = p_agg_sigs(ctx)(it)
         ArrayLenAggSig(knownLength, nested)
       case "AggElements" =>
-        val nested = p_agg_sigs(env)(it)
+        val nested = p_agg_sigs(ctx)(it)
         AggElementsAggSig(nested)
       case op =>
-        val state = agg_state_signature(env)(it)
+        val state = agg_state_signature(ctx)(it)
         PhysicalAggSig(AggOp.fromString(op), state)
     }
     punctuation(it, ")")
@@ -763,40 +758,39 @@ object IRParser {
     (typ, v)
   }
 
-  def named_value_irs(env: IRParserEnvironment)(it: TokenIterator)
-    : StackFrame[Array[(String, IR)]] =
-    repUntil(it, named_value_ir(env), PunctuationToken(")"))
+  def named_value_irs(ctx: ExecuteContext)(it: TokenIterator): StackFrame[Array[(String, IR)]] =
+    repUntil(it, named_value_ir(ctx), PunctuationToken(")"))
 
-  def named_value_ir(env: IRParserEnvironment)(it: TokenIterator): StackFrame[(String, IR)] = {
+  def named_value_ir(ctx: ExecuteContext)(it: TokenIterator): StackFrame[(String, IR)] = {
     punctuation(it, "(")
     val name = identifier(it)
-    ir_value_expr(env)(it).map { value =>
+    ir_value_expr(ctx)(it).map { value =>
       punctuation(it, ")")
       (name, value)
     }
   }
 
-  def ir_value_exprs(env: IRParserEnvironment)(it: TokenIterator): StackFrame[Array[IR]] = {
+  def ir_value_exprs(ctx: ExecuteContext)(it: TokenIterator): StackFrame[Array[IR]] = {
     punctuation(it, "(")
     for {
-      irs <- ir_value_children(env)(it)
+      irs <- ir_value_children(ctx)(it)
       _ = punctuation(it, ")")
     } yield irs
   }
 
-  def ir_value_children(env: IRParserEnvironment)(it: TokenIterator): StackFrame[Array[IR]] =
-    repUntil(it, ir_value_expr(env), PunctuationToken(")"))
+  def ir_value_children(ctx: ExecuteContext)(it: TokenIterator): StackFrame[Array[IR]] =
+    repUntil(it, ir_value_expr(ctx), PunctuationToken(")"))
 
-  def ir_value_expr(env: IRParserEnvironment)(it: TokenIterator): StackFrame[IR] = {
+  def ir_value_expr(ctx: ExecuteContext)(it: TokenIterator): StackFrame[IR] = {
     punctuation(it, "(")
     for {
-      ir <- call(ir_value_expr_1(env)(it))
+      ir <- call(ir_value_expr_1(ctx)(it))
       _ = punctuation(it, ")")
     } yield ir
   }
 
   def apply_like(
-    env: IRParserEnvironment,
+    ctx: ExecuteContext,
     cons: (String, Seq[Type], Seq[IR], Type, Int) => IR,
   )(
     it: TokenIterator
@@ -805,10 +799,10 @@ object IRParser {
     val function = identifier(it)
     val typeArgs = type_exprs(it)
     val rt = type_expr(it)
-    ir_value_children(env)(it).map(args => cons(function, typeArgs, args, rt, errorID))
+    ir_value_children(ctx)(it).map(args => cons(function, typeArgs, args, rt, errorID))
   }
 
-  def ir_value_expr_1(env: IRParserEnvironment)(it: TokenIterator): StackFrame[IR] = {
+  def ir_value_expr_1(ctx: ExecuteContext)(it: TokenIterator): StackFrame[IR] = {
     identifier(it) match {
       case "I32" => done(I32(int32_literal(it)))
       case "I64" => done(I64(int64_literal(it)))
@@ -833,28 +827,28 @@ object IRParser {
       case "Void" => done(Void())
       case "Cast" =>
         val typ = type_expr(it)
-        ir_value_expr(env)(it).map(Cast(_, typ))
+        ir_value_expr(ctx)(it).map(Cast(_, typ))
       case "CastRename" =>
         val typ = type_expr(it)
-        ir_value_expr(env)(it).map(CastRename(_, typ))
+        ir_value_expr(ctx)(it).map(CastRename(_, typ))
       case "NA" => done(NA(type_expr(it)))
-      case "IsNA" => ir_value_expr(env)(it).map(IsNA)
+      case "IsNA" => ir_value_expr(ctx)(it).map(IsNA)
       case "Coalesce" =>
         for {
-          children <- ir_value_children(env)(it)
+          children <- ir_value_children(ctx)(it)
           _ = require(children.nonEmpty)
         } yield Coalesce(children)
       case "If" =>
         for {
-          cond <- ir_value_expr(env)(it)
-          consq <- ir_value_expr(env)(it)
-          altr <- ir_value_expr(env)(it)
+          cond <- ir_value_expr(ctx)(it)
+          consq <- ir_value_expr(ctx)(it)
+          altr <- ir_value_expr(ctx)(it)
         } yield If(cond, consq, altr)
       case "Switch" =>
         for {
-          x <- ir_value_expr(env)(it)
-          default <- ir_value_expr(env)(it)
-          cases <- ir_value_children(env)(it)
+          x <- ir_value_expr(ctx)(it)
+          default <- ir_value_expr(ctx)(it)
+          cases <- ir_value_children(ctx)(it)
         } yield Switch(x, default, cases)
       case "Let" | "Block" =>
         val names =
@@ -864,10 +858,10 @@ object IRParser {
           _ <- names.indices.foldLeft(done(())) { case (update, i) =>
             for {
               _ <- update
-              value <- ir_value_expr(env)(it)
+              value <- ir_value_expr(ctx)(it)
             } yield values.update(i, value)
           }
-          body <- ir_value_expr(env)(it)
+          body <- ir_value_expr(ctx)(it)
         } yield {
           val bindings = (names, values).zipped.map { case ((bindType, name), value) =>
             val scope = bindType match {
@@ -883,21 +877,21 @@ object IRParser {
         val n = name(it)
         val isScan = boolean_literal(it)
         for {
-          value <- ir_value_expr(env)(it)
-          body <- ir_value_expr(env)(it)
+          value <- ir_value_expr(ctx)(it)
+          body <- ir_value_expr(ctx)(it)
         } yield AggLet(n, value, body, isScan)
       case "TailLoop" =>
         val n = name(it)
         val paramNames = names(it)
         val resultType = type_expr(it)
         for {
-          paramIRs <- fillArray(paramNames.length)(ir_value_expr(env)(it))
+          paramIRs <- fillArray(paramNames.length)(ir_value_expr(ctx)(it))
           params = paramNames.zip(paramIRs)
-          body <- ir_value_expr(env)(it)
+          body <- ir_value_expr(ctx)(it)
         } yield TailLoop(n, params, resultType, body)
       case "Recur" =>
         val n = name(it)
-        ir_value_children(env)(it).map(args => Recur(n, args, null))
+        ir_value_children(ctx)(it).map(args => Recur(n, args, null))
       case "Ref" =>
         val id = name(it)
         done(Ref(id, null))
@@ -908,42 +902,42 @@ object IRParser {
       case "RelationalLet" =>
         val n = name(it)
         for {
-          value <- ir_value_expr(env)(it)
-          body <- ir_value_expr(env)(it)
+          value <- ir_value_expr(ctx)(it)
+          body <- ir_value_expr(ctx)(it)
         } yield RelationalLet(n, value, body)
       case "ApplyBinaryPrimOp" =>
         val op = BinaryOp.fromString(identifier(it))
         for {
-          l <- ir_value_expr(env)(it)
-          r <- ir_value_expr(env)(it)
+          l <- ir_value_expr(ctx)(it)
+          r <- ir_value_expr(ctx)(it)
         } yield ApplyBinaryPrimOp(op, l, r)
       case "ApplyUnaryPrimOp" =>
         val op = UnaryOp.fromString(identifier(it))
-        ir_value_expr(env)(it).map(ApplyUnaryPrimOp(op, _))
+        ir_value_expr(ctx)(it).map(ApplyUnaryPrimOp(op, _))
       case "ApplyComparisonOp" =>
         val opName = identifier(it)
         for {
-          l <- ir_value_expr(env)(it)
-          r <- ir_value_expr(env)(it)
+          l <- ir_value_expr(ctx)(it)
+          r <- ir_value_expr(ctx)(it)
         } yield ApplyComparisonOp(ComparisonOp.fromString(opName), l, r)
       case "MakeArray" =>
         val typ = opt(it, type_expr).map(_.asInstanceOf[TArray]).orNull
-        ir_value_children(env)(it).map(args => MakeArray(args, typ))
+        ir_value_children(ctx)(it).map(args => MakeArray(args, typ))
       case "MakeStream" =>
         val typ = opt(it, type_expr).map(_.asInstanceOf[TStream]).orNull
         val requiresMemoryManagementPerElement = boolean_literal(it)
-        ir_value_children(env)(it).map { args =>
+        ir_value_children(ctx)(it).map { args =>
           MakeStream(args, typ, requiresMemoryManagementPerElement)
         }
       case "ArrayRef" =>
         val errorID = int32_literal(it)
         for {
-          a <- ir_value_expr(env)(it)
-          i <- ir_value_expr(env)(it)
+          a <- ir_value_expr(ctx)(it)
+          i <- ir_value_expr(ctx)(it)
         } yield ArrayRef(a, i, errorID)
       case "ArraySlice" =>
         val errorID = int32_literal(it)
-        ir_value_children(env)(it).map {
+        ir_value_children(ctx)(it).map {
           case Array(a, start, step) => ArraySlice(a, start, None, step, errorID)
           case Array(a, start, stop, step) => ArraySlice(a, start, Some(stop), step, errorID)
         }
       case "RNGStateLiteral" =>
         done(RNGStateLiteral())
       case "RNGSplit" =>
         for {
-          state <- ir_value_expr(env)(it)
-          dynBitstring <- ir_value_expr(env)(it)
+          state <- ir_value_expr(ctx)(it)
+          dynBitstring <- ir_value_expr(ctx)(it)
         } yield RNGSplit(state, dynBitstring)
-      case "ArrayLen" => ir_value_expr(env)(it).map(ArrayLen)
-      case "StreamLen" => ir_value_expr(env)(it).map(StreamLen)
+      case "ArrayLen" => ir_value_expr(ctx)(it).map(ArrayLen)
+      case "StreamLen" => ir_value_expr(ctx)(it).map(StreamLen)
       case "StreamIota" =>
         val requiresMemoryManagementPerElement = boolean_literal(it)
         for {
-          start <- ir_value_expr(env)(it)
-          step <- ir_value_expr(env)(it)
+          start <- ir_value_expr(ctx)(it)
+          step <- ir_value_expr(ctx)(it)
         } yield StreamIota(start, step, requiresMemoryManagementPerElement)
       case "StreamRange" =>
         val errorID = int32_literal(it)
         val requiresMemoryManagementPerElement = boolean_literal(it)
         for {
-          start <- ir_value_expr(env)(it)
-          stop <- ir_value_expr(env)(it)
-          step <- ir_value_expr(env)(it)
+          start <- ir_value_expr(ctx)(it)
+          stop <- ir_value_expr(ctx)(it)
+          step <- ir_value_expr(ctx)(it)
         } yield StreamRange(start, stop, step, requiresMemoryManagementPerElement, errorID)
       case "StreamGrouped" =>
         for {
-          s <- ir_value_expr(env)(it)
-          groupSize <- ir_value_expr(env)(it)
+          s <- ir_value_expr(ctx)(it)
+          groupSize <- ir_value_expr(ctx)(it)
         } yield StreamGrouped(s, groupSize)
-      case "ArrayZeros" => ir_value_expr(env)(it).map(ArrayZeros)
+      case "ArrayZeros" => ir_value_expr(ctx)(it).map(ArrayZeros)
       case "ArraySort" =>
         val l = name(it)
         val r = name(it)
         for {
-          a <- ir_value_expr(env)(it)
-          lessThan <- ir_value_expr(env)(it)
+          a <- ir_value_expr(ctx)(it)
+          lessThan <- ir_value_expr(ctx)(it)
         } yield ArraySort(a, l, r, lessThan)
       case "ArrayMaximalIndependentSet" =>
         val hasTieBreaker = boolean_literal(it)
         val bindings = if (hasTieBreaker) Some(name(it) -> name(it)) else None
         for {
-          edges <- ir_value_expr(env)(it)
+          edges <- ir_value_expr(ctx)(it)
           tieBreaker <- if (hasTieBreaker) {
             val Some((left, right)) = bindings
-            ir_value_expr(env)(it).map(tbf => Some((left, right, tbf)))
+            ir_value_expr(ctx)(it).map(tbf => Some((left, right, tbf)))
           } else {
             done(None)
           }
@@ -998,113 +992,113 @@ object IRParser {
       case "MakeNDArray" =>
         val errorID = int32_literal(it)
         for {
-          data <- ir_value_expr(env)(it)
-          shape <- ir_value_expr(env)(it)
-          rowMajor <- ir_value_expr(env)(it)
+          data <- ir_value_expr(ctx)(it)
+          shape <- ir_value_expr(ctx)(it)
+          rowMajor <- ir_value_expr(ctx)(it)
         } yield MakeNDArray(data, shape, rowMajor, errorID)
-      case "NDArrayShape" => ir_value_expr(env)(it).map(NDArrayShape)
+      case "NDArrayShape" => ir_value_expr(ctx)(it).map(NDArrayShape)
       case "NDArrayReshape" =>
         val errorID = int32_literal(it)
         for {
-          nd <- ir_value_expr(env)(it)
-          shape <- ir_value_expr(env)(it)
+          nd <- ir_value_expr(ctx)(it)
+          shape <- ir_value_expr(ctx)(it)
         } yield NDArrayReshape(nd, shape, errorID)
       case "NDArrayConcat" =>
         val axis = int32_literal(it)
-        ir_value_expr(env)(it).map(nds => NDArrayConcat(nds, axis))
+        ir_value_expr(ctx)(it).map(nds => NDArrayConcat(nds, axis))
       case "NDArrayMap" =>
         val n = name(it)
         for {
-          nd <- ir_value_expr(env)(it)
-          body <- ir_value_expr(env)(it)
+          nd <- ir_value_expr(ctx)(it)
+          body <- ir_value_expr(ctx)(it)
         } yield NDArrayMap(nd, n, body)
       case "NDArrayMap2" =>
         val errorID = int32_literal(it)
         val lName = name(it)
         val rName = name(it)
         for {
-          l <- ir_value_expr(env)(it)
-          r <- ir_value_expr(env)(it)
-          body <- ir_value_expr(env)(it)
+          l <- ir_value_expr(ctx)(it)
+          r <- ir_value_expr(ctx)(it)
+          body <- ir_value_expr(ctx)(it)
         } yield NDArrayMap2(l, r, lName, rName, body, errorID)
       case "NDArrayReindex" =>
         val indexExpr = int32_literals(it)
-        ir_value_expr(env)(it).map(nd => NDArrayReindex(nd, indexExpr))
+        ir_value_expr(ctx)(it).map(nd => NDArrayReindex(nd, indexExpr))
       case "NDArrayAgg" =>
         val axes = int32_literals(it)
-        ir_value_expr(env)(it).map(nd => NDArrayAgg(nd, axes))
+        ir_value_expr(ctx)(it).map(nd => NDArrayAgg(nd, axes))
       case "NDArrayRef" =>
         val errorID = int32_literal(it)
         for {
-          nd <- ir_value_expr(env)(it)
-          idxs <- ir_value_children(env)(it)
+          nd <- ir_value_expr(ctx)(it)
+          idxs <- ir_value_children(ctx)(it)
         } yield NDArrayRef(nd, idxs, errorID)
       case "NDArraySlice" =>
         for {
-          nd <- ir_value_expr(env)(it)
-          slices <- ir_value_expr(env)(it)
+          nd <- ir_value_expr(ctx)(it)
+          slices <- ir_value_expr(ctx)(it)
         } yield NDArraySlice(nd, slices)
       case "NDArrayFilter" =>
         for {
-          nd <- ir_value_expr(env)(it)
-          filters <- repUntil(it, ir_value_expr(env), PunctuationToken(")"))
+          nd <- ir_value_expr(ctx)(it)
+          filters <- repUntil(it, ir_value_expr(ctx), PunctuationToken(")"))
         } yield NDArrayFilter(nd, filters.toFastSeq)
       case "NDArrayMatMul" =>
         val errorID = int32_literal(it)
         for {
-          l <- ir_value_expr(env)(it)
-          r <- ir_value_expr(env)(it)
+          l <- ir_value_expr(ctx)(it)
+          r <- ir_value_expr(ctx)(it)
         } yield NDArrayMatMul(l, r, errorID)
       case "NDArrayWrite" =>
         for {
-          nd <- ir_value_expr(env)(it)
-          path <- ir_value_expr(env)(it)
+          nd <- ir_value_expr(ctx)(it)
+          path <- ir_value_expr(ctx)(it)
         } yield NDArrayWrite(nd, path)
       case "NDArrayQR" =>
         val errorID = int32_literal(it)
         val mode = string_literal(it)
-        ir_value_expr(env)(it).map(nd => NDArrayQR(nd, mode, errorID))
+        ir_value_expr(ctx)(it).map(nd => NDArrayQR(nd, mode, errorID))
       case "NDArraySVD" =>
         val errorID = int32_literal(it)
         val fullMatrices = boolean_literal(it)
         val computeUV = boolean_literal(it)
-        ir_value_expr(env)(it).map(nd => NDArraySVD(nd, fullMatrices, computeUV, errorID))
+        ir_value_expr(ctx)(it).map(nd => NDArraySVD(nd, fullMatrices, computeUV, errorID))
       case "NDArrayEigh" =>
         val errorID = int32_literal(it)
         val eigvalsOnly = boolean_literal(it)
-        ir_value_expr(env)(it).map(nd => NDArrayEigh(nd, eigvalsOnly, errorID))
+        ir_value_expr(ctx)(it).map(nd => NDArrayEigh(nd, eigvalsOnly, errorID))
       case "NDArrayInv" =>
         val errorID = int32_literal(it)
-        ir_value_expr(env)(it).map(nd => NDArrayInv(nd, errorID))
-      case "ToSet" => ir_value_expr(env)(it).map(ToSet)
-      case "ToDict" => ir_value_expr(env)(it).map(ToDict)
-      case "ToArray" => ir_value_expr(env)(it).map(ToArray)
-      case "CastToArray" => ir_value_expr(env)(it).map(CastToArray)
+        ir_value_expr(ctx)(it).map(nd => NDArrayInv(nd, errorID))
+      case "ToSet" => ir_value_expr(ctx)(it).map(ToSet)
+      case "ToDict" => ir_value_expr(ctx)(it).map(ToDict)
+      case "ToArray" => ir_value_expr(ctx)(it).map(ToArray)
+      case "CastToArray" => ir_value_expr(ctx)(it).map(CastToArray)
       case "ToStream" =>
         val requiresMemoryManagementPerElement = boolean_literal(it)
-        ir_value_expr(env)(it).map(a => ToStream(a, requiresMemoryManagementPerElement))
+        ir_value_expr(ctx)(it).map(a => ToStream(a, requiresMemoryManagementPerElement))
       case "LowerBoundOnOrderedCollection" =>
         val onKey = boolean_literal(it)
         for {
-          col <- ir_value_expr(env)(it)
-          elem <- ir_value_expr(env)(it)
+          col <- ir_value_expr(ctx)(it)
+          elem <- ir_value_expr(ctx)(it)
         } yield LowerBoundOnOrderedCollection(col, elem, onKey)
-      case "GroupByKey" => ir_value_expr(env)(it).map(GroupByKey)
+      case "GroupByKey" => ir_value_expr(ctx)(it).map(GroupByKey)
       case "StreamMap" =>
         val n = name(it)
         for {
-          a <- ir_value_expr(env)(it)
-          body <- ir_value_expr(env)(it)
+          a <- ir_value_expr(ctx)(it)
+          body <- ir_value_expr(ctx)(it)
         } yield StreamMap(a, n, body)
       case "StreamTake" =>
         for {
-          a <- ir_value_expr(env)(it)
-          num <- ir_value_expr(env)(it)
+          a <- ir_value_expr(ctx)(it)
+          num <- ir_value_expr(ctx)(it)
         } yield StreamTake(a, num)
       case "StreamDrop" =>
         for {
-          a <- ir_value_expr(env)(it)
-          num <- ir_value_expr(env)(it)
+          a <- ir_value_expr(ctx)(it)
+          num <- ir_value_expr(ctx)(it)
         } yield StreamDrop(a, num)
       case "StreamZip" =>
         val errorID = int32_literal(it)
@@ -1116,8 +1110,8 @@ object IRParser {
         }
         val ns = names(it)
         for {
-          as <- ns.mapRecur(_ => ir_value_expr(env)(it))
-          body <- ir_value_expr(env)(it)
+          as <- ns.mapRecur(_ => ir_value_expr(ctx)(it))
+          body <- ir_value_expr(ctx)(it)
         } yield StreamZip(as, ns, body, behavior, errorID)
       case "StreamZipJoinProducers" =>
         val key = identifiers(it)
@@ -1125,10 +1119,10 @@ object IRParser {
         val curKey = name(it)
         val curVals = name(it)
         for {
-          ctxs <- ir_value_expr(env)(it)
-          makeProducer <- ir_value_expr(env)(it)
+          ctxs <- ir_value_expr(ctx)(it)
+          makeProducer <- ir_value_expr(ctx)(it)
           body <-
-            ir_value_expr(env)(it)
+            ir_value_expr(ctx)(it)
         } yield StreamZipJoinProducers(ctxs, ctxName, makeProducer, key, curKey, curVals, body)
       case "StreamZipJoin" =>
         val nStreams = int32_literal(it)
@@ -1136,64 +1130,64 @@ object IRParser {
         val curKey = name(it)
         val curVals = name(it)
         for {
-          streams <- (0 until nStreams).mapRecur(_ => ir_value_expr(env)(it))
+          streams <- (0 until nStreams).mapRecur(_ => ir_value_expr(ctx)(it))
           body <-
-            ir_value_expr(env)(it)
+            ir_value_expr(ctx)(it)
         } yield StreamZipJoin(streams, key, curKey, curVals, body)
       case "StreamMultiMerge" =>
         val key = identifiers(it)
         for {
-          streams <- ir_value_exprs(env)(it)
+          streams <- ir_value_exprs(ctx)(it)
         } yield StreamMultiMerge(streams, key)
       case "StreamFilter" =>
         val n = name(it)
         for {
-          a <- ir_value_expr(env)(it)
-          body <- ir_value_expr(env)(it)
+          a <- ir_value_expr(ctx)(it)
+          body <- ir_value_expr(ctx)(it)
         } yield StreamFilter(a, n, body)
       case "StreamTakeWhile" =>
         val n = name(it)
         for {
-          a <- ir_value_expr(env)(it)
-          body <- ir_value_expr(env)(it)
+          a <- ir_value_expr(ctx)(it)
+          body <- ir_value_expr(ctx)(it)
         } yield StreamTakeWhile(a, n, body)
       case "StreamDropWhile" =>
         val n = name(it)
         for {
-          a <- ir_value_expr(env)(it)
-          body <- ir_value_expr(env)(it)
+          a <- ir_value_expr(ctx)(it)
+          body <- ir_value_expr(ctx)(it)
         } yield StreamDropWhile(a, n, body)
       case "StreamFlatMap" =>
         val n = name(it)
         for {
-          a <- ir_value_expr(env)(it)
-          body <- ir_value_expr(env)(it)
+          a <- ir_value_expr(ctx)(it)
+          body <- ir_value_expr(ctx)(it)
         } yield StreamFlatMap(a, n, body)
       case "StreamFold" =>
         val accumName = name(it)
         val valueName = name(it)
         for {
-          a <- ir_value_expr(env)(it)
-          zero <- ir_value_expr(env)(it)
-          body <- ir_value_expr(env)(it)
+          a <- ir_value_expr(ctx)(it)
+          zero <- ir_value_expr(ctx)(it)
+          body <- ir_value_expr(ctx)(it)
         } yield StreamFold(a, zero, accumName, valueName, body)
       case "StreamFold2" =>
         val accumNames = names(it)
         val valueName = name(it)
         for {
-          a <- ir_value_expr(env)(it)
-          accIRs <- fillArray(accumNames.length)(ir_value_expr(env)(it))
+          a <- ir_value_expr(ctx)(it)
+          accIRs <- fillArray(accumNames.length)(ir_value_expr(ctx)(it))
           accs = accumNames.zip(accIRs)
-          seqs <- fillArray(accs.length)(ir_value_expr(env)(it))
-          res <- ir_value_expr(env)(it)
+          seqs <- fillArray(accs.length)(ir_value_expr(ctx)(it))
+          res <- ir_value_expr(ctx)(it)
         } yield StreamFold2(a, accs, valueName, seqs, res)
       case "StreamScan" =>
         val accumName = name(it)
         val valueName = name(it)
         for {
-          a <- ir_value_expr(env)(it)
-          zero <- ir_value_expr(env)(it)
-          body <- ir_value_expr(env)(it)
+          a <- ir_value_expr(ctx)(it)
+          zero <- ir_value_expr(ctx)(it)
+          body <- ir_value_expr(ctx)(it)
         } yield StreamScan(a, zero, accumName, valueName, body)
       case "StreamWhiten" =>
         val newChunk = identifier(it)
@@ -1204,7 +1198,7 @@ object IRParser {
         val blockSize = int32_literal(it)
         val normalizeAfterWhitening = boolean_literal(it)
         for {
-          stream <- ir_value_expr(env)(it)
+          stream <- ir_value_expr(ctx)(it)
         } yield StreamWhiten(stream, newChunk, prevWindow, vecSize, windowSize, chunkSize,
           blockSize, normalizeAfterWhitening)
       case "StreamJoinRightDistinct" =>
@@ -1214,9 +1208,9 @@ object IRParser {
         val r = name(it)
         val joinType = identifier(it)
         for {
-          left <- ir_value_expr(env)(it)
-          right <- ir_value_expr(env)(it)
-          join <- ir_value_expr(env)(it)
+          left <- ir_value_expr(ctx)(it)
+          right <- ir_value_expr(ctx)(it)
+          join <- ir_value_expr(ctx)(it)
         } yield StreamJoinRightDistinct(left, right, lKey, rKey, l, r, join, joinType)
       case "StreamLeftIntervalJoin" =>
         val lKeyFieldName = identifier(it)
@@ -1224,63 +1218,63 @@ object IRParser {
         val lname = name(it)
         val rname = name(it)
         for {
-          left <- ir_value_expr(env)(it)
-          right <- ir_value_expr(env)(it)
-          body <- ir_value_expr(env)(it)
+          left <- ir_value_expr(ctx)(it)
+          right <- ir_value_expr(ctx)(it)
+          body <- ir_value_expr(ctx)(it)
         } yield
           StreamLeftIntervalJoin(left, right, lKeyFieldName, rIntervalName, lname, rname, body)
       case "StreamFor" =>
         val n = name(it)
         for {
-          a <- ir_value_expr(env)(it)
-          body <- ir_value_expr(env)(it)
+          a <- ir_value_expr(ctx)(it)
+          body <- ir_value_expr(ctx)(it)
         } yield StreamFor(a, n, body)
       case "StreamAgg" =>
         val n = name(it)
         for {
-          a <- ir_value_expr(env)(it)
-          query <- ir_value_expr(env)(it)
+          a <- ir_value_expr(ctx)(it)
+          query <- ir_value_expr(ctx)(it)
         } yield StreamAgg(a, n, query)
       case "StreamAggScan" =>
         val n = name(it)
         for {
-          a <- ir_value_expr(env)(it)
-          query <- ir_value_expr(env)(it)
+          a <- ir_value_expr(ctx)(it)
+          query <- ir_value_expr(ctx)(it)
         } yield StreamAggScan(a, n, query)
       case "RunAgg" =>
-        val signatures = agg_state_signatures(env)(it)
+        val signatures = agg_state_signatures(ctx)(it)
         for {
-          body <- ir_value_expr(env)(it)
-          result <- ir_value_expr(env)(it)
+          body <- ir_value_expr(ctx)(it)
+          result <- ir_value_expr(ctx)(it)
         } yield RunAgg(body, result, signatures)
       case "RunAggScan" =>
         val n = name(it)
-        val signatures = agg_state_signatures(env)(it)
+        val signatures = agg_state_signatures(ctx)(it)
         for {
-          array <- ir_value_expr(env)(it)
-          init <- ir_value_expr(env)(it)
-          seq <- ir_value_expr(env)(it)
-          result <- ir_value_expr(env)(it)
+          array <- ir_value_expr(ctx)(it)
+          init <- ir_value_expr(ctx)(it)
+          seq <- ir_value_expr(ctx)(it)
+          result <- ir_value_expr(ctx)(it)
         } yield RunAggScan(array, n, init, seq, result, signatures)
       case "AggFilter" =>
         val isScan = boolean_literal(it)
         for {
-          cond <- ir_value_expr(env)(it)
-          aggIR <- ir_value_expr(env)(it)
+          cond <- ir_value_expr(ctx)(it)
+          aggIR <- ir_value_expr(ctx)(it)
         } yield AggFilter(cond, aggIR, isScan)
       case "AggExplode" =>
         val n = name(it)
         val isScan = boolean_literal(it)
         for {
-          a <- ir_value_expr(env)(it)
-          aggBody <- ir_value_expr(env)(it)
+          a <- ir_value_expr(ctx)(it)
+          aggBody <- ir_value_expr(ctx)(it)
         } yield AggExplode(a, n, aggBody, isScan)
       case "AggGroupBy" =>
         val isScan = boolean_literal(it)
         for {
-          key <- ir_value_expr(env)(it)
-          aggIR <- ir_value_expr(env)(it)
+          key <- ir_value_expr(ctx)(it)
+          aggIR <- ir_value_expr(ctx)(it)
         } yield AggGroupBy(key, aggIR, isScan)
       case "AggArrayPerElement" =>
         val elementName = name(it)
@@ -1288,22 +1282,22 @@ object IRParser {
         val isScan = boolean_literal(it)
         val hasKnownLength = boolean_literal(it)
         for {
-          a <- ir_value_expr(env)(it)
-          aggBody <- ir_value_expr(env)(it)
-          knownLength <- if (hasKnownLength) ir_value_expr(env)(it).map(Some(_)) else done(None)
+          a <- ir_value_expr(ctx)(it)
+          aggBody <- ir_value_expr(ctx)(it)
+          knownLength <- if (hasKnownLength) ir_value_expr(ctx)(it).map(Some(_)) else done(None)
         } yield AggArrayPerElement(a, elementName, indexName, aggBody, knownLength, isScan)
       case "ApplyAggOp" =>
         val aggOp = agg_op(it)
         for {
-          initOpArgs <- ir_value_exprs(env)(it)
-          seqOpArgs <- ir_value_exprs(env)(it)
+          initOpArgs <- ir_value_exprs(ctx)(it)
+          seqOpArgs <- ir_value_exprs(ctx)(it)
           aggSig = AggSignature(aggOp, null, null)
         } yield ApplyAggOp(initOpArgs, seqOpArgs, aggSig)
       case "ApplyScanOp" =>
         val aggOp = agg_op(it)
         for {
-          initOpArgs <- ir_value_exprs(env)(it)
-          seqOpArgs <- ir_value_exprs(env)(it)
+          initOpArgs <- ir_value_exprs(ctx)(it)
+          seqOpArgs <- ir_value_exprs(ctx)(it)
           aggSig = AggSignature(aggOp, null, null)
         } yield ApplyScanOp(initOpArgs, seqOpArgs, aggSig)
       case "AggFold" =>
@@ -1311,164 +1305,164 @@ object IRParser {
         val otherAccumName = name(it)
         val isScan = boolean_literal(it)
         for {
-          zero <- ir_value_expr(env)(it)
-          seqOp <- ir_value_expr(env)(it)
-          combOp <- ir_value_expr(env)(it)
+          zero <- ir_value_expr(ctx)(it)
+          seqOp <- ir_value_expr(ctx)(it)
+          combOp <- ir_value_expr(ctx)(it)
         } yield AggFold(zero, seqOp, combOp, accumName, otherAccumName, isScan)
       case "InitOp" =>
         val i = int32_literal(it)
-        val aggSig = p_agg_sig(env)(it)
-        ir_value_exprs(env)(it).map(args => InitOp(i, args, aggSig))
+        val aggSig = p_agg_sig(ctx)(it)
+        ir_value_exprs(ctx)(it).map(args => InitOp(i, args, aggSig))
       case "SeqOp" =>
         val i = int32_literal(it)
-        val aggSig = p_agg_sig(env)(it)
-        ir_value_exprs(env)(it).map(args => SeqOp(i, args, aggSig))
+        val aggSig = p_agg_sig(ctx)(it)
+        ir_value_exprs(ctx)(it).map(args => SeqOp(i, args, aggSig))
       case "CombOp" =>
         val i1 = int32_literal(it)
         val i2 = int32_literal(it)
-        val aggSig = p_agg_sig(env)(it)
+        val aggSig = p_agg_sig(ctx)(it)
         done(CombOp(i1, i2, aggSig))
       case "ResultOp" =>
         val i = int32_literal(it)
-        val aggSig = p_agg_sig(env)(it)
+        val aggSig = p_agg_sig(ctx)(it)
         done(ResultOp(i, aggSig))
       case "AggStateValue" =>
         val i = int32_literal(it)
-        val sig = agg_state_signature(env)(it)
+        val sig = agg_state_signature(ctx)(it)
         done(AggStateValue(i, sig))
       case "InitFromSerializedValue" =>
         val i = int32_literal(it)
-        val sig = agg_state_signature(env)(it)
-        ir_value_expr(env)(it).map(value => InitFromSerializedValue(i, value, sig))
+        val sig = agg_state_signature(ctx)(it)
+        ir_value_expr(ctx)(it).map(value => InitFromSerializedValue(i, value, sig))
       case "CombOpValue" =>
         val i = int32_literal(it)
-        val sig = p_agg_sig(env)(it)
-        ir_value_expr(env)(it).map(value => CombOpValue(i, value, sig))
+        val sig = p_agg_sig(ctx)(it)
+        ir_value_expr(ctx)(it).map(value => CombOpValue(i, value, sig))
       case "SerializeAggs" =>
         val i = int32_literal(it)
         val i2 = int32_literal(it)
         val spec = BufferSpec.parse(string_literal(it))
-        val aggSigs = agg_state_signatures(env)(it)
+        val aggSigs = agg_state_signatures(ctx)(it)
         done(SerializeAggs(i, i2, spec, aggSigs))
       case "DeserializeAggs" =>
         val i = int32_literal(it)
         val i2 = int32_literal(it)
         val spec = BufferSpec.parse(string_literal(it))
-        val aggSigs = agg_state_signatures(env)(it)
+        val aggSigs = agg_state_signatures(ctx)(it)
         done(DeserializeAggs(i, i2, spec, aggSigs))
-      case "Begin" => ir_value_children(env)(it).map(Begin(_))
-      case "MakeStruct" => named_value_irs(env)(it).map(MakeStruct(_))
+      case "Begin" => ir_value_children(ctx)(it).map(Begin(_))
+      case "MakeStruct" => named_value_irs(ctx)(it).map(MakeStruct(_))
       case "SelectFields" =>
         val fields = identifiers(it)
-        ir_value_expr(env)(it).map(old => SelectFields(old, fields))
+        ir_value_expr(ctx)(it).map(old => SelectFields(old, fields))
       case "InsertFields" =>
         for {
-          old <- ir_value_expr(env)(it)
+          old <- ir_value_expr(ctx)(it)
           fieldOrder = opt(it, string_literals)
-          fields <- named_value_irs(env)(it)
+          fields <- named_value_irs(ctx)(it)
         } yield InsertFields(old, fields, fieldOrder.map(_.toFastSeq))
       case "GetField" =>
         val name = identifier(it)
-        ir_value_expr(env)(it).map(s => GetField(s, name))
+        ir_value_expr(ctx)(it).map(s => GetField(s, name))
       case "MakeTuple" =>
         val indices = int32_literals(it)
-        ir_value_children(env)(it).map(args => MakeTuple(indices.zip(args)))
+        ir_value_children(ctx)(it).map(args => MakeTuple(indices.zip(args)))
       case "GetTupleElement" =>
         val idx = int32_literal(it)
-        ir_value_expr(env)(it).map(tuple => GetTupleElement(tuple, idx))
+        ir_value_expr(ctx)(it).map(tuple => GetTupleElement(tuple, idx))
       case "Die" =>
         val typ = type_expr(it)
         val errorID = int32_literal(it)
-        ir_value_expr(env)(it).map(msg => Die(msg, typ, errorID))
+        ir_value_expr(ctx)(it).map(msg => Die(msg, typ, errorID))
       case "Trap" =>
-        ir_value_expr(env)(it).map(child => Trap(child))
+        ir_value_expr(ctx)(it).map(child => Trap(child))
       case "ConsoleLog" =>
         for {
-          msg <- ir_value_expr(env)(it)
-          result <- ir_value_expr(env)(it)
+          msg <- ir_value_expr(ctx)(it)
+          result <- ir_value_expr(ctx)(it)
         } yield ConsoleLog(msg, result)
       case "ApplySeeded" =>
         val function = identifier(it)
         val staticUID = int64_literal(it)
         val rt = type_expr(it)
         for {
-          rngState <- ir_value_expr(env)(it)
-          args <- ir_value_children(env)(it)
+          rngState <- ir_value_expr(ctx)(it)
+          args <- ir_value_children(ctx)(it)
         } yield ApplySeeded(function, args, rngState, staticUID, rt)
       case "ApplyIR" =>
-        apply_like(env, ApplyIR)(it)
+        apply_like(ctx, ApplyIR)(it)
       case "ApplySpecial" =>
-        apply_like(env, ApplySpecial)(it)
+        apply_like(ctx, ApplySpecial)(it)
       case "Apply" =>
-        apply_like(env, Apply)(it)
+        apply_like(ctx, Apply)(it)
      case "MatrixCount" =>
-        matrix_ir(env)(it).map(MatrixCount)
+        matrix_ir(ctx)(it).map(MatrixCount)
      case "TableCount" =>
-        table_ir(env)(it).map(TableCount)
+        table_ir(ctx)(it).map(TableCount)
      case "TableGetGlobals" =>
-        table_ir(env)(it).map(TableGetGlobals)
+        table_ir(ctx)(it).map(TableGetGlobals)
      case "TableCollect" =>
-        table_ir(env)(it).map(TableCollect)
+        table_ir(ctx)(it).map(TableCollect)
      case "TableAggregate" =>
        for {
-          child <- table_ir(env)(it)
-          query <- ir_value_expr(env)(it)
+          child <- table_ir(ctx)(it)
+          query <- ir_value_expr(ctx)(it)
        } yield TableAggregate(child, query)
      case "TableToValueApply" =>
        val config = string_literal(it)
-        table_ir(env)(it).map { child =>
-          TableToValueApply(child, RelationalFunctions.lookupTableToValue(env.ctx, config))
+        table_ir(ctx)(it).map { child =>
+          TableToValueApply(child, RelationalFunctions.lookupTableToValue(ctx, config))
        }
      case "MatrixToValueApply" =>
        val config = string_literal(it)
-        matrix_ir(env)(it).map { child =>
-          MatrixToValueApply(child, RelationalFunctions.lookupMatrixToValue(env.ctx, config))
+        matrix_ir(ctx)(it).map { child =>
+          MatrixToValueApply(child, RelationalFunctions.lookupMatrixToValue(ctx, config))
        }
      case "BlockMatrixToValueApply" =>
        val config = string_literal(it)
-        blockmatrix_ir(env)(it).map { child =>
+        blockmatrix_ir(ctx)(it).map { child =>
          BlockMatrixToValueApply(
            child,
-            RelationalFunctions.lookupBlockMatrixToValue(env.ctx, config),
+            RelationalFunctions.lookupBlockMatrixToValue(ctx, config),
          )
        }
      case "BlockMatrixCollect" =>
-        blockmatrix_ir(env)(it).map(BlockMatrixCollect)
+        blockmatrix_ir(ctx)(it).map(BlockMatrixCollect)
      case "TableWrite" =>
        implicit val formats = TableWriter.formats
        val writerStr = string_literal(it)
-        table_ir(env)(it).map(child => TableWrite(child, deserialize[TableWriter](writerStr)))
+        table_ir(ctx)(it).map(child => TableWrite(child, deserialize[TableWriter](writerStr)))
      case "TableMultiWrite" =>
        implicit val formats = WrappedMatrixNativeMultiWriter.formats
        val writerStr = string_literal(it)
-        table_ir_children(env)(it).map { children =>
+        table_ir_children(ctx)(it).map { children =>
          TableMultiWrite(children, deserialize[WrappedMatrixNativeMultiWriter](writerStr))
        }
      case "MatrixAggregate" =>
        for {
-          child <- matrix_ir(env)(it)
-          query <- ir_value_expr(env)(it)
+          child <- matrix_ir(ctx)(it)
+          query <- ir_value_expr(ctx)(it)
        } yield MatrixAggregate(child, query)
      case "MatrixWrite" =>
        val writerStr = string_literal(it)
        implicit val formats: Formats = MatrixWriter.formats
        val writer = deserialize[MatrixWriter](writerStr)
-        matrix_ir(env)(it).map(child => MatrixWrite(child, writer))
+        matrix_ir(ctx)(it).map(child => MatrixWrite(child, writer))
      case "MatrixMultiWrite" =>
        val writerStr = string_literal(it)
        implicit val formats = MatrixNativeMultiWriter.formats
        val writer = deserialize[MatrixNativeMultiWriter](writerStr)
-        matrix_ir_children(env)(it).map(children => MatrixMultiWrite(children, writer))
+        matrix_ir_children(ctx)(it).map(children => MatrixMultiWrite(children, writer))
      case "BlockMatrixWrite" =>
        val writerStr = string_literal(it)
        implicit val formats: Formats = BlockMatrixWriter.formats
        val writer = deserialize[BlockMatrixWriter](writerStr)
-        blockmatrix_ir(env)(it).map(child => BlockMatrixWrite(child, writer))
+        blockmatrix_ir(ctx)(it).map(child => BlockMatrixWrite(child, writer))
      case "BlockMatrixMultiWrite" =>
        val writerStr = string_literal(it)
        implicit val formats: Formats = BlockMatrixWriter.formats
        val writer = deserialize[BlockMatrixMultiWriter](writerStr)
-        repUntil(it, blockmatrix_ir(env), PunctuationToken(")")).map { blockMatrices =>
+        repUntil(it, blockmatrix_ir(ctx), PunctuationToken(")")).map { blockMatrices =>
          BlockMatrixMultiWrite(blockMatrices.toFastSeq, writer)
        }
      case "CollectDistributedArray" =>
@@ -1476,14 +1470,14 @@ object IRParser {
        val cname = name(it)
        val gname = name(it)
        for {
-          ctxs <- ir_value_expr(env)(it)
-          globals <- ir_value_expr(env)(it)
-          body <- ir_value_expr(env)(it)
-          dynamicID <- ir_value_expr(env)(it)
+          ctxs <- ir_value_expr(ctx)(it)
+          globals <- ir_value_expr(ctx)(it)
+          body <- ir_value_expr(ctx)(it)
+          dynamicID <- ir_value_expr(ctx)(it)
        } yield CollectDistributedArray(ctxs, globals, cname, gname, body, dynamicID, staticID)
      case "JavaIR" =>
        val id = int32_literal(it)
-        done(env.irMap(id).asInstanceOf[IR])
+        done(ctx.IrCache(id).asInstanceOf[IR])
      case "ReadPartition" =>
        val requestedTypeRaw = it.head match {
          case x: IdentifierToken if x.value == "None" || x.value == "DropRowUIDs" =>
@@ -1492,8 +1486,8 @@ object IRParser {
          case _ =>
            Right(type_expr(it))
        }
-        val reader = PartitionReader.extract(env.ctx, JsonMethods.parse(string_literal(it)))
-        ir_value_expr(env)(it).map { context =>
+        val reader = PartitionReader.extract(ctx, JsonMethods.parse(string_literal(it)))
+        ir_value_expr(ctx)(it).map { context =>
          ReadPartition(
            context,
            requestedTypeRaw match {
@@ -1508,64 +1502,64 @@ object IRParser {
        import PartitionWriter.formats
        val writer = JsonMethods.parse(string_literal(it)).extract[PartitionWriter]
        for {
-          stream <- ir_value_expr(env)(it)
-          ctx <- ir_value_expr(env)(it)
+          stream <- ir_value_expr(ctx)(it)
+          ctx <- ir_value_expr(ctx)(it)
        } yield WritePartition(stream, ctx, writer)
      case "WriteMetadata" =>
        import MetadataWriter.formats
        val writer = JsonMethods.parse(string_literal(it)).extract[MetadataWriter]
-        ir_value_expr(env)(it).map(ctx => WriteMetadata(ctx, writer))
+        ir_value_expr(ctx)(it).map(ctx => WriteMetadata(ctx, writer))
      case "ReadValue" =>
        import ValueReader.formats
        val reader = JsonMethods.parse(string_literal(it)).extract[ValueReader]
        val typ = type_expr(it)
-        ir_value_expr(env)(it).map(path => ReadValue(path, reader, typ))
+        ir_value_expr(ctx)(it).map(path => ReadValue(path, reader, typ))
      case "WriteValue" =>
        import ValueWriter.formats
        val writer = JsonMethods.parse(string_literal(it)).extract[ValueWriter]
-        ir_value_children(env)(it).map {
+        ir_value_children(ctx)(it).map {
          case Array(value, path) => WriteValue(value, path, writer)
          case Array(value, path, stagingFile) =>
            WriteValue(value, path, writer, Some(stagingFile))
        }
-      case "LiftMeOut" => ir_value_expr(env)(it).map(LiftMeOut)
+      case "LiftMeOut" => ir_value_expr(ctx)(it).map(LiftMeOut)
      case "ReadPartition" =>
        val rowType = tcoerce[TStruct](type_expr(it))
        import PartitionReader.formats
        val reader = JsonMethods.parse(string_literal(it)).extract[PartitionReader]
-        ir_value_expr(env)(it).map(context => ReadPartition(context, rowType, reader))
+        ir_value_expr(ctx)(it).map(context => ReadPartition(context, rowType, reader))
    }
  }
 
-  def table_irs(env: IRParserEnvironment)(it: TokenIterator): StackFrame[Array[TableIR]] = {
+  def table_irs(ctx: ExecuteContext)(it: TokenIterator): StackFrame[Array[TableIR]] = {
    punctuation(it, "(")
    for {
-      tirs <- table_ir_children(env)(it)
+      tirs <- table_ir_children(ctx)(it)
      _ = punctuation(it, ")")
    } yield tirs
  }
 
-  def table_ir_children(env: IRParserEnvironment)(it: TokenIterator): StackFrame[Array[TableIR]] =
-    repUntil(it, table_ir(env), PunctuationToken(")"))
+  def table_ir_children(ctx: ExecuteContext)(it: TokenIterator): StackFrame[Array[TableIR]] =
+    repUntil(it, table_ir(ctx), PunctuationToken(")"))
 
-  def table_ir(env: IRParserEnvironment)(it: TokenIterator): StackFrame[TableIR] = {
+  def table_ir(ctx: ExecuteContext)(it: TokenIterator): StackFrame[TableIR] = {
    punctuation(it, "(")
    for {
-      ir <- call(table_ir_1(env)(it))
+      ir <- call(table_ir_1(ctx)(it))
      _ = punctuation(it, ")")
    } yield ir
  }
 
-  def table_ir_1(env: IRParserEnvironment)(it: TokenIterator): StackFrame[TableIR] = {
+  def table_ir_1(ctx: ExecuteContext)(it: TokenIterator): StackFrame[TableIR] = {
    identifier(it) match {
      case "TableKeyBy" =>
        val keys = identifiers(it)
        val isSorted = boolean_literal(it)
-        table_ir(env)(it).map(child => TableKeyBy(child, keys, isSorted))
-      case "TableDistinct" => table_ir(env)(it).map(TableDistinct)
+        table_ir(ctx)(it).map(child => TableKeyBy(child, keys, isSorted))
+      case "TableDistinct" => table_ir(ctx)(it).map(TableDistinct)
      case "TableFilter" =>
        for {
-          child <- table_ir(env)(it)
-          pred <- ir_value_expr(env)(it)
+          child <- table_ir(ctx)(it)
+          pred <- ir_value_expr(ctx)(it)
        } yield TableFilter(child, pred)
      case "TableRead" =>
        val requestedTypeRaw = it.head match {
@@ -1578,7 +1572,7 @@ object IRParser {
        val dropRows = boolean_literal(it)
        val readerStr = string_literal(it)
        val reader =
-          TableReader.fromJValue(env.ctx.fs, JsonMethods.parse(readerStr).asInstanceOf[JObject])
+          TableReader.fromJValue(ctx.fs, JsonMethods.parse(readerStr).asInstanceOf[JObject])
        val requestedType = requestedTypeRaw match {
          case Left("None") => reader.fullType
          case Left("DropRowUIDs") =>
@@ -1586,113 +1580,113 @@ object IRParser {
          case Right(t) => t
        }
        done(TableRead(requestedType, dropRows, reader))
-      case "MatrixColsTable" => matrix_ir(env)(it).map(MatrixColsTable)
-      case "MatrixRowsTable" => matrix_ir(env)(it).map(MatrixRowsTable)
-      case "MatrixEntriesTable" => matrix_ir(env)(it).map(MatrixEntriesTable)
+      case "MatrixColsTable" => matrix_ir(ctx)(it).map(MatrixColsTable)
+      case "MatrixRowsTable" => matrix_ir(ctx)(it).map(MatrixRowsTable)
+      case "MatrixEntriesTable" => matrix_ir(ctx)(it).map(MatrixEntriesTable)
      case "TableAggregateByKey" =>
        for {
-          child <- table_ir(env)(it)
-          expr <- ir_value_expr(env)(it)
+          child <- table_ir(ctx)(it)
+          expr <- ir_value_expr(ctx)(it)
        } yield TableAggregateByKey(child, expr)
      case "TableKeyByAndAggregate" =>
        val nPartitions = opt(it, int32_literal)
        val bufferSize = int32_literal(it)
        for {
-          child <- table_ir(env)(it)
-          expr <- ir_value_expr(env)(it)
-          newKey <- ir_value_expr(env)(it)
+          child <- table_ir(ctx)(it)
+          expr <- ir_value_expr(ctx)(it)
+          newKey <- ir_value_expr(ctx)(it)
        } yield TableKeyByAndAggregate(child, expr, newKey, nPartitions, bufferSize)
      case "TableRepartition" =>
        val n = int32_literal(it)
        val strategy = int32_literal(it)
-        table_ir(env)(it).map(child => TableRepartition(child, n, strategy))
+        table_ir(ctx)(it).map(child => TableRepartition(child, n, strategy))
      case "TableHead" =>
        val n = int64_literal(it)
-        table_ir(env)(it).map(child => TableHead(child, n))
+        table_ir(ctx)(it).map(child => TableHead(child, n))
      case "TableTail" =>
        val n = int64_literal(it)
-        table_ir(env)(it).map(child => TableTail(child, n))
+        table_ir(ctx)(it).map(child => TableTail(child, n))
      case "TableJoin" =>
        val joinType = identifier(it)
        val joinKey = int32_literal(it)
        for {
-          left <- table_ir(env)(it)
-          right <- table_ir(env)(it)
+          left <- table_ir(ctx)(it)
+          right <- table_ir(ctx)(it)
        } yield TableJoin(left, right, joinType, joinKey)
      case "TableLeftJoinRightDistinct" =>
        val root = identifier(it)
        for {
-          left <- table_ir(env)(it)
-          right <- table_ir(env)(it)
+          left <- table_ir(ctx)(it)
+          right <- table_ir(ctx)(it)
        } yield TableLeftJoinRightDistinct(left, right, root)
      case "TableIntervalJoin" =>
        val root = identifier(it)
        val product = boolean_literal(it)
        for {
-          left <- table_ir(env)(it)
-          right <- table_ir(env)(it)
+          left <- table_ir(ctx)(it)
+          right <- table_ir(ctx)(it)
        } yield TableIntervalJoin(left, right, root, product)
      case "TableMultiWayZipJoin" =>
        val dataName = string_literal(it)
        val globalsName = string_literal(it)
-        table_ir_children(env)(it).map { children =>
+        table_ir_children(ctx)(it).map { children =>
          TableMultiWayZipJoin(children, dataName, globalsName)
        }
      case "TableParallelize" =>
        val nPartitions = opt(it, int32_literal)
-        ir_value_expr(env)(it).map(rowsAndGlobal => TableParallelize(rowsAndGlobal, nPartitions))
+        ir_value_expr(ctx)(it).map(rowsAndGlobal => TableParallelize(rowsAndGlobal, nPartitions))
      case "TableMapRows" =>
        for {
-          child <- table_ir(env)(it)
-          newRow <- ir_value_expr(env)(it)
+          child <- table_ir(ctx)(it)
+          newRow <- ir_value_expr(ctx)(it)
        } yield TableMapRows(child, newRow)
      case "TableMapGlobals" =>
        for {
-          child <- table_ir(env)(it)
-          newRow <- ir_value_expr(env)(it)
+          child <- table_ir(ctx)(it)
+          newRow <- ir_value_expr(ctx)(it)
        } yield TableMapGlobals(child, newRow)
      case "TableRange" =>
        val n = int32_literal(it)
        val nPartitions = opt(it, int32_literal)
-        done(TableRange(n, nPartitions.getOrElse(HailContext.backend.defaultParallelism)))
-      case "TableUnion" => table_ir_children(env)(it).map(TableUnion(_))
+        done(TableRange(n, nPartitions.getOrElse(ctx.backend.defaultParallelism)))
+      case "TableUnion" => table_ir_children(ctx)(it).map(TableUnion(_))
      case "TableOrderBy" =>
        val sortFields = sort_fields(it)
-        table_ir(env)(it).map(child => TableOrderBy(child, sortFields))
+        table_ir(ctx)(it).map(child => TableOrderBy(child, sortFields))
      case "TableExplode" =>
        val path = string_literals(it)
-        table_ir(env)(it).map(child => TableExplode(child, path))
+        table_ir(ctx)(it).map(child => TableExplode(child, path))
      case "CastMatrixToTable" =>
        val entriesField = string_literal(it)
        val colsField = string_literal(it)
-        matrix_ir(env)(it).map(child => CastMatrixToTable(child, entriesField, colsField))
+        matrix_ir(ctx)(it).map(child => CastMatrixToTable(child, entriesField, colsField))
      case "MatrixToTableApply" =>
        val config = string_literal(it)
-        matrix_ir(env)(it).map { child =>
-          MatrixToTableApply(child, RelationalFunctions.lookupMatrixToTable(env.ctx, config))
+        matrix_ir(ctx)(it).map { child =>
+          MatrixToTableApply(child, RelationalFunctions.lookupMatrixToTable(ctx, config))
        }
      case "TableToTableApply" =>
        val config = string_literal(it)
-        table_ir(env)(it).map { child =>
-          TableToTableApply(child, RelationalFunctions.lookupTableToTable(env.ctx, config))
+        table_ir(ctx)(it).map { child =>
+          TableToTableApply(child, RelationalFunctions.lookupTableToTable(ctx, config))
        }
      case "BlockMatrixToTableApply" =>
        val config = string_literal(it)
        for {
-          bm <- blockmatrix_ir(env)(it)
-          aux <- ir_value_expr(env)(it)
+          bm <- blockmatrix_ir(ctx)(it)
+          aux <- ir_value_expr(ctx)(it)
        } yield BlockMatrixToTableApply(
          bm,
          aux,
-          RelationalFunctions.lookupBlockMatrixToTable(env.ctx, config),
+          RelationalFunctions.lookupBlockMatrixToTable(ctx, config),
        )
-      case "BlockMatrixToTable" => blockmatrix_ir(env)(it).map(BlockMatrixToTable)
+      case "BlockMatrixToTable" => blockmatrix_ir(ctx)(it).map(BlockMatrixToTable)
      case "TableRename" =>
        val rowK = string_literals(it)
        val rowV = string_literals(it)
        val globalK = string_literals(it)
        val globalV = string_literals(it)
-        table_ir(env)(it).map { child =>
+        table_ir(ctx)(it).map { child =>
          TableRename(child, rowK.zip(rowV).toMap, globalK.zip(globalV).toMap)
        }
@@ -1700,19 +1694,19 @@ object IRParser {
        val cname = name(it)
        val gname = name(it)
        val partitioner =
-          between(punctuation(_, "("), punctuation(_, ")"), partitioner_literal(env))(it)
+          between(punctuation(_, "("), punctuation(_, ")"), partitioner_literal(ctx))(it)
        val errorId = int32_literal(it)
        for {
-          contexts <- ir_value_expr(env)(it)
-          globals <- ir_value_expr(env)(it)
-          body <- ir_value_expr(env)(it)
+          contexts <- ir_value_expr(ctx)(it)
+          globals <- ir_value_expr(ctx)(it)
+          body <- ir_value_expr(ctx)(it)
        } yield TableGen(contexts, globals, cname, gname, body, partitioner, errorId)
      case "TableFilterIntervals" =>
        val keyType = type_expr(it)
        val intervals = string_literal(it)
        val keep = boolean_literal(it)
-        table_ir(env)(it).map { child =>
+        table_ir(ctx)(it).map { child =>
          TableFilterIntervals(
            child,
            JSONAnnotationImpex.importAnnotation(
@@ -1729,92 +1723,92 @@ object IRParser {
        val requestedKey = int32_literal(it)
        val allowedOverlap = int32_literal(it)
        for {
-          child <- table_ir(env)(it)
-          body <- ir_value_expr(env)(it)
+          child <- table_ir(ctx)(it)
+          body <- ir_value_expr(ctx)(it)
        } yield TableMapPartitions(child, globalsName, partitionStreamName, body, requestedKey,
          allowedOverlap)
      case "RelationalLetTable" =>
        val n = name(it)
        for {
-          value <- ir_value_expr(env)(it)
-          body <- table_ir(env)(it)
+          value <- ir_value_expr(ctx)(it)
+          body <- table_ir(ctx)(it)
        } yield RelationalLetTable(n, value, body)
      case "JavaTable" =>
        val id = int32_literal(it)
-        done(env.irMap(id).asInstanceOf[TableIR])
+        done(ctx.IrCache(id).asInstanceOf[TableIR])
    }
  }
 
-  def matrix_ir_children(env: IRParserEnvironment)(it: TokenIterator): StackFrame[Array[MatrixIR]] =
-    repUntil(it, matrix_ir(env), PunctuationToken(")"))
+  def matrix_ir_children(ctx: ExecuteContext)(it: TokenIterator): StackFrame[Array[MatrixIR]] =
+    repUntil(it, matrix_ir(ctx), PunctuationToken(")"))
 
-  def matrix_ir(env: IRParserEnvironment)(it: TokenIterator): StackFrame[MatrixIR] = {
+  def matrix_ir(ctx: ExecuteContext)(it: TokenIterator): StackFrame[MatrixIR] = {
    punctuation(it, "(")
    for {
-      ir <- call(matrix_ir_1(env)(it))
+      ir <- call(matrix_ir_1(ctx)(it))
      _ = punctuation(it, ")")
    } yield ir
  }
 
-  def matrix_ir_1(env: IRParserEnvironment)(it: TokenIterator): StackFrame[MatrixIR] = {
+  def matrix_ir_1(ctx: ExecuteContext)(it: TokenIterator): StackFrame[MatrixIR] = {
    identifier(it) match {
      case "MatrixFilterCols" =>
        for {
-          child <- matrix_ir(env)(it)
-          pred <- ir_value_expr(env)(it)
+          child <- matrix_ir(ctx)(it)
+          pred <- ir_value_expr(ctx)(it)
        } yield MatrixFilterCols(child, pred)
      case "MatrixFilterRows" =>
        for {
-          child <- matrix_ir(env)(it)
-          pred <- ir_value_expr(env)(it)
+          child <- matrix_ir(ctx)(it)
+          pred <- ir_value_expr(ctx)(it)
        } yield MatrixFilterRows(child, pred)
      case "MatrixFilterEntries" =>
        for {
-          child <- matrix_ir(env)(it)
-          pred <- ir_value_expr(env)(it)
+          child <- matrix_ir(ctx)(it)
+          pred <- ir_value_expr(ctx)(it)
        } yield MatrixFilterEntries(child, pred)
      case "MatrixMapCols" =>
        val newKey = opt(it, string_literals)
        for {
-          child <- matrix_ir(env)(it)
-          newCol <- ir_value_expr(env)(it)
+          child <- matrix_ir(ctx)(it)
+          newCol <- ir_value_expr(ctx)(it)
        } yield MatrixMapCols(child, newCol, newKey.map(_.toFastSeq))
      case "MatrixKeyRowsBy" =>
        val key = identifiers(it)
        val isSorted = boolean_literal(it)
-        matrix_ir(env)(it).map(child => MatrixKeyRowsBy(child, key, isSorted))
+        matrix_ir(ctx)(it).map(child => MatrixKeyRowsBy(child, key, isSorted))
      case "MatrixMapRows" =>
        for {
-          child <- matrix_ir(env)(it)
-          newRow <- ir_value_expr(env)(it)
+          child <- matrix_ir(ctx)(it)
+          newRow <- ir_value_expr(ctx)(it)
        } yield MatrixMapRows(child, newRow)
      case "MatrixMapEntries" =>
        for {
-          child <- matrix_ir(env)(it)
-          newEntry <- ir_value_expr(env)(it)
+          child <- matrix_ir(ctx)(it)
+          newEntry <- ir_value_expr(ctx)(it)
        } yield MatrixMapEntries(child, newEntry)
      case "MatrixUnionCols" =>
        val joinType = identifier(it)
        for {
-          left <- matrix_ir(env)(it)
-          right <- matrix_ir(env)(it)
+          left <- matrix_ir(ctx)(it)
+          right <- matrix_ir(ctx)(it)
        } yield MatrixUnionCols(left, right, joinType)
      case "MatrixMapGlobals" =>
        for {
-          child <- matrix_ir(env)(it)
-          newGlobals <- ir_value_expr(env)(it)
+          child <- matrix_ir(ctx)(it)
+          newGlobals <- ir_value_expr(ctx)(it)
        } yield MatrixMapGlobals(child, newGlobals)
      case "MatrixAggregateColsByKey" =>
        for {
-          child <- matrix_ir(env)(it)
-          entryExpr <- ir_value_expr(env)(it)
-          colExpr <- ir_value_expr(env)(it)
+          child <- matrix_ir(ctx)(it)
+          entryExpr <- ir_value_expr(ctx)(it)
+          colExpr <- ir_value_expr(ctx)(it)
        } yield MatrixAggregateColsByKey(child, entryExpr, colExpr)
      case "MatrixAggregateRowsByKey" =>
        for {
-          child <- matrix_ir(env)(it)
-          entryExpr <- ir_value_expr(env)(it)
-          rowExpr <- ir_value_expr(env)(it)
+          child <- matrix_ir(ctx)(it)
+          entryExpr <- ir_value_expr(ctx)(it)
+          rowExpr <- ir_value_expr(ctx)(it)
        } yield MatrixAggregateRowsByKey(child, entryExpr, rowExpr)
      case "MatrixRead" =>
        val requestedTypeRaw = it.head match {
@@ -1828,7 +1822,7 @@ object IRParser {
        val dropCols = boolean_literal(it)
        val dropRows = boolean_literal(it)
        val readerStr = string_literal(it)
-        val reader = MatrixReader.fromJson(env, JsonMethods.parse(readerStr).asInstanceOf[JObject])
+        val reader = MatrixReader.fromJson(ctx, JsonMethods.parse(readerStr).asInstanceOf[JObject])
        val fullType = reader.fullMatrixType
        val requestedType = requestedTypeRaw match {
          case Left("None") => fullType
@@ -1849,53 +1843,53 @@ object IRParser {
        val root = string_literal(it)
        val product = boolean_literal(it)
        for {
-          child <- matrix_ir(env)(it)
-          table <- table_ir(env)(it)
+          child <- matrix_ir(ctx)(it)
+          table <- table_ir(ctx)(it)
        } yield MatrixAnnotateRowsTable(child, table, root, product)
      case "MatrixAnnotateColsTable" =>
        val root = string_literal(it)
        for {
-          child <- matrix_ir(env)(it)
-          table <- table_ir(env)(it)
+          child <- matrix_ir(ctx)(it)
+          table <- table_ir(ctx)(it)
        } yield MatrixAnnotateColsTable(child, table, root)
      case "MatrixExplodeRows" =>
        val path = identifiers(it)
-        matrix_ir(env)(it).map(child => MatrixExplodeRows(child, path))
+        matrix_ir(ctx)(it).map(child => MatrixExplodeRows(child, path))
      case "MatrixExplodeCols" =>
        val path = identifiers(it)
-        matrix_ir(env)(it).map(child => MatrixExplodeCols(child, path))
+        matrix_ir(ctx)(it).map(child => MatrixExplodeCols(child, path))
      case "MatrixChooseCols" =>
        val oldIndices = int32_literals(it)
-        matrix_ir(env)(it).map(child => MatrixChooseCols(child, oldIndices))
+        matrix_ir(ctx)(it).map(child => MatrixChooseCols(child, oldIndices))
      case "MatrixCollectColsByKey" =>
-        matrix_ir(env)(it).map(MatrixCollectColsByKey)
+        matrix_ir(ctx)(it).map(MatrixCollectColsByKey)
      case "MatrixRepartition" =>
        val n = int32_literal(it)
        val strategy = int32_literal(it)
-        matrix_ir(env)(it).map(child => MatrixRepartition(child, n, strategy))
-      case "MatrixUnionRows" => matrix_ir_children(env)(it).map(MatrixUnionRows(_))
-      case "MatrixDistinctByRow" => matrix_ir(env)(it).map(MatrixDistinctByRow)
+        matrix_ir(ctx)(it).map(child => MatrixRepartition(child, n, strategy))
+      case "MatrixUnionRows" => matrix_ir_children(ctx)(it).map(MatrixUnionRows(_))
+      case "MatrixDistinctByRow" => matrix_ir(ctx)(it).map(MatrixDistinctByRow)
      case "MatrixRowsHead" =>
        val n = int64_literal(it)
-        matrix_ir(env)(it).map(child => MatrixRowsHead(child, n))
+        matrix_ir(ctx)(it).map(child => MatrixRowsHead(child, n))
      case "MatrixColsHead" =>
        val n = int32_literal(it)
-        matrix_ir(env)(it).map(child => MatrixColsHead(child, n))
+        matrix_ir(ctx)(it).map(child => MatrixColsHead(child, n))
      case "MatrixRowsTail" =>
        val n = int64_literal(it)
-        matrix_ir(env)(it).map(child => MatrixRowsTail(child, n))
+        matrix_ir(ctx)(it).map(child => MatrixRowsTail(child, n))
      case "MatrixColsTail" =>
        val n = int32_literal(it)
-        matrix_ir(env)(it).map(child => MatrixColsTail(child, n))
+        matrix_ir(ctx)(it).map(child => MatrixColsTail(child, n))
      case "CastTableToMatrix" =>
        val entriesField = identifier(it)
        val colsField = identifier(it)
        val colKey = identifiers(it)
-        table_ir(env)(it).map(child => CastTableToMatrix(child, entriesField, colsField, colKey))
+        table_ir(ctx)(it).map(child => CastTableToMatrix(child, entriesField, colsField, colKey))
      case "MatrixToMatrixApply" =>
        val config = string_literal(it)
-        matrix_ir(env)(it).map { child =>
-          MatrixToMatrixApply(child, RelationalFunctions.lookupMatrixToMatrix(env.ctx, config))
+        matrix_ir(ctx)(it).map { child =>
+          MatrixToMatrixApply(child, RelationalFunctions.lookupMatrixToMatrix(ctx, config))
        }
      case "MatrixRename" =>
        val globalK = string_literals(it)
@@ -1906,7 +1900,7 @@ object IRParser {
        val rowV = string_literals(it)
        val entryK = string_literals(it)
        val entryV = string_literals(it)
-        matrix_ir(env)(it).map { child =>
+        matrix_ir(ctx)(it).map { child =>
          MatrixRename(
            child,
            globalK.zip(globalV).toMap,
@@ -1919,7 +1913,7 @@ object IRParser {
        val keyType = type_expr(it)
        val intervals = string_literal(it)
        val keep = boolean_literal(it)
-        matrix_ir(env)(it).map { child =>
+        matrix_ir(ctx)(it).map { child =>
          MatrixFilterIntervals(
            child,
JSONAnnotationImpex.importAnnotation( @@ -1933,47 +1927,45 @@ object IRParser { case "RelationalLetMatrixTable" => val n = name(it) for { - value <- ir_value_expr(env)(it) - body <- matrix_ir(env)(it) + value <- ir_value_expr(ctx)(it) + body <- matrix_ir(ctx)(it) } yield RelationalLetMatrixTable(n, value, body) } } - def blockmatrix_sparsifier(env: IRParserEnvironment)(it: TokenIterator) + def blockmatrix_sparsifier(ctx: ExecuteContext)(it: TokenIterator) : StackFrame[BlockMatrixSparsifier] = { punctuation(it, "(") identifier(it) match { case "PyRowIntervalSparsifier" => val blocksOnly = boolean_literal(it) punctuation(it, ")") - ir_value_expr(env)(it).map { ir_ => - val ir = annotateTypes(env.ctx, ir_, BindingEnv.empty).asInstanceOf[IR] + ir_value_expr(ctx)(it).map { ir_ => + val ir = annotateTypes(ctx, ir_, BindingEnv.empty).asInstanceOf[IR] val Row(starts: IndexedSeq[Long @unchecked], stops: IndexedSeq[Long @unchecked]) = - CompileAndEvaluate[Row](env.ctx, ir) + CompileAndEvaluate(ctx, ir) RowIntervalSparsifier(blocksOnly, starts, stops) } case "PyBandSparsifier" => val blocksOnly = boolean_literal(it) punctuation(it, ")") - ir_value_expr(env)(it).map { ir_ => - val ir = annotateTypes(env.ctx, ir_, BindingEnv.empty).asInstanceOf[IR] - val Row(l: Long, u: Long) = CompileAndEvaluate[Row](env.ctx, ir) + ir_value_expr(ctx)(it).map { ir_ => + val ir = annotateTypes(ctx, ir_, BindingEnv.empty).asInstanceOf[IR] + val Row(l: Long, u: Long) = CompileAndEvaluate(ctx, ir) BandSparsifier(blocksOnly, l, u) } case "PyPerBlockSparsifier" => punctuation(it, ")") - ir_value_expr(env)(it).map { ir_ => - val ir = annotateTypes(env.ctx, ir_, BindingEnv.empty).asInstanceOf[IR] - val indices: IndexedSeq[Int] = - CompileAndEvaluate[IndexedSeq[Int]](env.ctx, ir) + ir_value_expr(ctx)(it).map { ir_ => + val ir = annotateTypes(ctx, ir_, BindingEnv.empty).asInstanceOf[IR] + val indices: IndexedSeq[Int] = CompileAndEvaluate(ctx, ir) PerBlockSparsifier(indices) } case "PyRectangleSparsifier" => punctuation(it, ")") - ir_value_expr(env)(it).map { ir_ => - val ir = annotateTypes(env.ctx, ir_, BindingEnv.empty).asInstanceOf[IR] - val rectangles: IndexedSeq[Long] = - CompileAndEvaluate[IndexedSeq[Long]](env.ctx, ir) + ir_value_expr(ctx)(it).map { ir_ => + val ir = annotateTypes(ctx, ir_, BindingEnv.empty).asInstanceOf[IR] + val rectangles: IndexedSeq[Long] = CompileAndEvaluate(ctx, ir) RectangleSparsifier(rectangles.grouped(4).toIndexedSeq) } case "RowIntervalSparsifier" => @@ -1995,70 +1987,70 @@ object IRParser { } } - def blockmatrix_ir(env: IRParserEnvironment)(it: TokenIterator): StackFrame[BlockMatrixIR] = { + def blockmatrix_ir(ctx: ExecuteContext)(it: TokenIterator): StackFrame[BlockMatrixIR] = { punctuation(it, "(") for { - ir <- call(blockmatrix_ir1(env)(it)) + ir <- call(blockmatrix_ir1(ctx)(it)) _ = punctuation(it, ")") } yield ir } - def blockmatrix_ir1(env: IRParserEnvironment)(it: TokenIterator): StackFrame[BlockMatrixIR] = { + def blockmatrix_ir1(ctx: ExecuteContext)(it: TokenIterator): StackFrame[BlockMatrixIR] = { identifier(it) match { case "BlockMatrixRead" => val readerStr = string_literal(it) - val reader = BlockMatrixReader.fromJValue(env.ctx, JsonMethods.parse(readerStr)) + val reader = BlockMatrixReader.fromJValue(ctx, JsonMethods.parse(readerStr)) done(BlockMatrixRead(reader)) case "BlockMatrixMap" => val n = name(it) val needs_dense = boolean_literal(it) for { - child <- blockmatrix_ir(env)(it) - f <- ir_value_expr(env)(it) + child <- blockmatrix_ir(ctx)(it) + f <- ir_value_expr(ctx)(it) } 
yield BlockMatrixMap(child, n, f, needs_dense) case "BlockMatrixMap2" => val lName = name(it) val rName = name(it) val sparsityStrategy = SparsityStrategy.fromString(identifier(it)) for { - left <- blockmatrix_ir(env)(it) - right <- blockmatrix_ir(env)(it) - f <- ir_value_expr(env)(it) + left <- blockmatrix_ir(ctx)(it) + right <- blockmatrix_ir(ctx)(it) + f <- ir_value_expr(ctx)(it) } yield BlockMatrixMap2(left, right, lName, rName, f, sparsityStrategy) case "BlockMatrixDot" => for { - left <- blockmatrix_ir(env)(it) - right <- blockmatrix_ir(env)(it) + left <- blockmatrix_ir(ctx)(it) + right <- blockmatrix_ir(ctx)(it) } yield BlockMatrixDot(left, right) case "BlockMatrixBroadcast" => val inIndexExpr = int32_literals(it) val shape = int64_literals(it) val blockSize = int32_literal(it) - blockmatrix_ir(env)(it).map { child => + blockmatrix_ir(ctx)(it).map { child => BlockMatrixBroadcast(child, inIndexExpr, shape, blockSize) } case "BlockMatrixAgg" => val outIndexExpr = int32_literals(it) - blockmatrix_ir(env)(it).map(child => BlockMatrixAgg(child, outIndexExpr)) + blockmatrix_ir(ctx)(it).map(child => BlockMatrixAgg(child, outIndexExpr)) case "BlockMatrixFilter" => val indices = literals(literals(int64_literal))(it) - blockmatrix_ir(env)(it).map(child => BlockMatrixFilter(child, indices)) + blockmatrix_ir(ctx)(it).map(child => BlockMatrixFilter(child, indices)) case "BlockMatrixDensify" => - blockmatrix_ir(env)(it).map(BlockMatrixDensify) + blockmatrix_ir(ctx)(it).map(BlockMatrixDensify) case "BlockMatrixSparsify" => for { - sparsifier <- blockmatrix_sparsifier(env)(it) - child <- blockmatrix_ir(env)(it) + sparsifier <- blockmatrix_sparsifier(ctx)(it) + child <- blockmatrix_ir(ctx)(it) } yield BlockMatrixSparsify(child, sparsifier) case "BlockMatrixSlice" => val slices = literals(literals(int64_literal))(it) - blockmatrix_ir(env)(it).map { child => + blockmatrix_ir(ctx)(it).map { child => BlockMatrixSlice(child, slices.map(_.toFastSeq).toFastSeq) } case "ValueToBlockMatrix" => val shape = int64_literals(it) val blockSize = int32_literal(it) - ir_value_expr(env)(it).map(child => ValueToBlockMatrix(child, shape, blockSize)) + ir_value_expr(ctx)(it).map(child => ValueToBlockMatrix(child, shape, blockSize)) case "BlockMatrixRandom" => val staticUID = int64_literal(it) val gaussian = boolean_literal(it) @@ -2068,8 +2060,8 @@ object IRParser { case "RelationalLetBlockMatrix" => val n = name(it) for { - value <- ir_value_expr(env)(it) - body <- blockmatrix_ir(env)(it) + value <- ir_value_expr(ctx)(it) + body <- blockmatrix_ir(ctx)(it) } yield RelationalLetBlockMatrix(n, value, body) } } @@ -2152,53 +2144,38 @@ object IRParser { f(it) } - def parse_value_ir( - s: String, - env: IRParserEnvironment, - typeEnv: BindingEnv[Type] = BindingEnv.empty, - ): IR = - env.ctx.time { - var ir = parse(s, ir_value_expr(env)(_).run()) - ir = annotateTypes(env.ctx, ir, typeEnv).asInstanceOf[IR] - TypeCheck(env.ctx, ir, typeEnv) + def parse_value_ir(ctx: ExecuteContext, s: String, typeEnv: BindingEnv[Type] = BindingEnv.empty) + : IR = + ctx.time { + var ir = parse(s, ir_value_expr(ctx)(_).run()) + ir = annotateTypes(ctx, ir, typeEnv).asInstanceOf[IR] + TypeCheck(ctx, ir, typeEnv) ir } - def parse_value_ir(ctx: ExecuteContext, s: String): IR = - parse_value_ir(s, IRParserEnvironment(ctx)) - def parse_table_ir(ctx: ExecuteContext, s: String): TableIR = - parse_table_ir(s, IRParserEnvironment(ctx)) - - def parse_table_ir(s: String, env: IRParserEnvironment): TableIR = - env.ctx.time { - var ir = parse(s, 
table_ir(env)(_).run()) - ir = annotateTypes(env.ctx, ir, BindingEnv.empty).asInstanceOf[TableIR] - TypeCheck(env.ctx, ir) - ir - } - - def parse_matrix_ir(s: String, env: IRParserEnvironment): MatrixIR = - env.ctx.time { - var ir = parse(s, matrix_ir(env)(_).run()) - ir = annotateTypes(env.ctx, ir, BindingEnv.empty).asInstanceOf[MatrixIR] - TypeCheck(env.ctx, ir) + ctx.time { + var ir = parse(s, table_ir(ctx)(_).run()) + ir = annotateTypes(ctx, ir, BindingEnv.empty).asInstanceOf[TableIR] + TypeCheck(ctx, ir) ir } def parse_matrix_ir(ctx: ExecuteContext, s: String): MatrixIR = - parse_matrix_ir(s, IRParserEnvironment(ctx)) - - def parse_blockmatrix_ir(s: String, env: IRParserEnvironment): BlockMatrixIR = - env.ctx.time { - var ir = parse(s, blockmatrix_ir(env)(_).run()) - ir = annotateTypes(env.ctx, ir, BindingEnv.empty).asInstanceOf[BlockMatrixIR] - TypeCheck(env.ctx, ir) + ctx.time { + var ir = parse(s, matrix_ir(ctx)(_).run()) + ir = annotateTypes(ctx, ir, BindingEnv.empty).asInstanceOf[MatrixIR] + TypeCheck(ctx, ir) ir } def parse_blockmatrix_ir(ctx: ExecuteContext, s: String): BlockMatrixIR = - parse_blockmatrix_ir(s, IRParserEnvironment(ctx)) + ctx.time { + var ir = parse(s, blockmatrix_ir(ctx)(_).run()) + ir = annotateTypes(ctx, ir, BindingEnv.empty).asInstanceOf[BlockMatrixIR] + TypeCheck(ctx, ir) + ir + } def parseType(code: String): Type = parse(code, type_expr) diff --git a/hail/src/main/scala/is/hail/expr/ir/functions/Functions.scala b/hail/src/main/scala/is/hail/expr/ir/functions/Functions.scala index 2d3cf1f6e996..be5f0e3e3b2d 100644 --- a/hail/src/main/scala/is/hail/expr/ir/functions/Functions.scala +++ b/hail/src/main/scala/is/hail/expr/ir/functions/Functions.scala @@ -77,11 +77,7 @@ object IRFunctionRegistry { val typeParameters = typeParamStrs.map(IRParser.parseType).toFastSeq val valueParameterTypes = argTypeStrs.map(IRParser.parseType).toFastSeq val refMap = BindingEnv.eval(argNames.zip(valueParameterTypes): _*) - val body = IRParser.parse_value_ir( - bodyStr, - IRParserEnvironment(ctx, Map()), - refMap, - ) + val body = IRParser.parse_value_ir(ctx, bodyStr, refMap) userAddedFunctions += ((name, (body.typ, typeParameters, valueParameterTypes))) addIR( diff --git a/hail/src/main/scala/is/hail/io/bgen/LoadBgen.scala b/hail/src/main/scala/is/hail/io/bgen/LoadBgen.scala index 7f756fa5aac7..2ba6dcd99757 100644 --- a/hail/src/main/scala/is/hail/io/bgen/LoadBgen.scala +++ b/hail/src/main/scala/is/hail/io/bgen/LoadBgen.scala @@ -4,10 +4,9 @@ import is.hail.annotations.Region import is.hail.asm4s._ import is.hail.backend.ExecuteContext import is.hail.expr.ir.{ - EmitCode, EmitCodeBuilder, EmitMethodBuilder, EmitSettable, EmitValue, IEmitCode, IR, - IRParserEnvironment, Literal, LowerMatrixIR, MakeStruct, MatrixHybridReader, MatrixReader, - PartitionNativeIntervalReader, PartitionReader, ReadPartition, Ref, TableNativeReader, - TableReader, ToStream, + EmitCode, EmitCodeBuilder, EmitMethodBuilder, EmitSettable, EmitValue, IEmitCode, IR, Literal, + LowerMatrixIR, MakeStruct, MatrixHybridReader, MatrixReader, PartitionNativeIntervalReader, + PartitionReader, ReadPartition, Ref, TableNativeReader, TableReader, ToStream, } import is.hail.expr.ir.lowering.{TableStage, TableStageDependency} import is.hail.expr.ir.streams.StreamProducer @@ -361,8 +360,8 @@ object MatrixBGENReader { ) } - def fromJValue(env: IRParserEnvironment, jv: JValue): MatrixBGENReader = - MatrixBGENReader(env.ctx, MatrixBGENReaderParameters.fromJValue(jv)) + def fromJValue(ctx: ExecuteContext, jv: JValue): 
MatrixBGENReader = + MatrixBGENReader(ctx, MatrixBGENReaderParameters.fromJValue(jv)) def apply( ctx: ExecuteContext, diff --git a/hail/src/test/scala/is/hail/expr/ir/IRSuite.scala b/hail/src/test/scala/is/hail/expr/ir/IRSuite.scala index e7855248d146..bcaf5a93d571 100644 --- a/hail/src/test/scala/is/hail/expr/ir/IRSuite.scala +++ b/hail/src/test/scala/is/hail/expr/ir/IRSuite.scala @@ -23,6 +23,7 @@ import is.hail.types.virtual.TIterable.elementType import is.hail.utils.{FastSeq, _} import is.hail.variant.{Call2, Locus} +import scala.collection.mutable import scala.language.implicitConversions import org.apache.spark.sql.Row @@ -3876,12 +3877,8 @@ class IRSuite extends HailSuite { @Test(dataProvider = "valueIRs") def testValueIRParser(x: IR, refMap: BindingEnv[Type]): Unit = { - val env = IRParserEnvironment(ctx) - val s = Pretty.sexprStyle(x, elideLiterals = false) - - val x2 = IRParser.parse_value_ir(s, env, refMap) - + val x2 = IRParser.parse_value_ir(ctx, s, refMap) assert(x2 == x) } @@ -3931,7 +3928,7 @@ class IRSuite extends HailSuite { val cached = Literal(TSet(TInt32), Set(1)) val s = s"(JavaIR 1)" val x2 = ExecuteContext.scoped { ctx => - IRParser.parse_value_ir(s, IRParserEnvironment(ctx, irMap = Map(1 -> cached))) + ctx.local(irCache = mutable.Map(1 -> cached))(ctx => IRParser.parse_value_ir(ctx, s)) } assert(x2 eq cached) } @@ -3940,7 +3937,7 @@ class IRSuite extends HailSuite { val cached = TableRange(1, 1) val s = s"(JavaTable 1)" val x2 = ExecuteContext.scoped { ctx => - IRParser.parse_table_ir(s, IRParserEnvironment(ctx, irMap = Map(1 -> cached))) + ctx.local(irCache = mutable.Map(1 -> cached))(ctx => IRParser.parse_table_ir(ctx, s)) } assert(x2 eq cached) }
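Reviewer note (not part of the diff): the net effect of this change is that IRParserEnvironment disappears, every parser entry point takes the ExecuteContext directly, and the persisted-IR map moves from a Backend-global field into the context-scoped IrCache. A minimal usage sketch, mirroring the updated IRSuite tests above; the cached literal and the "(JavaIR 1)" string are illustrative values taken from those tests, not new API:

    import scala.collection.mutable

    import is.hail.backend.ExecuteContext
    import is.hail.expr.ir.{IRParser, Literal}
    import is.hail.types.virtual.{TInt32, TSet}

    val cached = Literal(TSet(TInt32), Set(1)) // illustrative cached IR node
    val parsed = ExecuteContext.scoped { ctx =>
      // (JavaIR 1) now resolves id 1 against the context-scoped cache installed
      // via ctx.local, rather than the old Backend.persistedIR global map.
      ctx.local(irCache = mutable.Map(1 -> cached)) { ctx =>
        IRParser.parse_value_ir(ctx, "(JavaIR 1)")
      }
    }
    assert(parsed eq cached)

Because the cache now hangs off ExecuteContext, its lifetime is bounded by the scope that installed it, which is what lets LocalBackend keep persistedIR as a private field instead of Backend exposing a mutable map to every caller.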