diff --git a/hail/python/hail/backend/py4j_backend.py b/hail/python/hail/backend/py4j_backend.py index 2c1610b5614..b9f986e5834 100644 --- a/hail/python/hail/backend/py4j_backend.py +++ b/hail/python/hail/backend/py4j_backend.py @@ -237,17 +237,17 @@ def _rpc(self, action, payload) -> Tuple[bytes, Optional[dict]]: def persist_expression(self, expr): t = expr.dtype - return construct_expr(JavaIR(t, self._jbackend.executeLiteral(self._render_ir(expr._ir))), t) + return construct_expr(JavaIR(t, self._jbackend.pyExecuteLiteral(self._render_ir(expr._ir))), t) def _is_registered_ir_function_name(self, name: str) -> bool: return name in self._registered_ir_function_names def set_flags(self, **flags: Mapping[str, str]): - available = self._jbackend.availableFlags() + available = self._jbackend.pyAvailableFlags() invalid = [] for flag, value in flags.items(): if flag in available: - self._jbackend.setFlag(flag, value) + self._jbackend.pySetFlag(flag, value) else: invalid.append(flag) if len(invalid) != 0: @@ -256,7 +256,7 @@ def set_flags(self, **flags: Mapping[str, str]): ) def get_flags(self, *flags) -> Mapping[str, str]: - return {flag: self._jbackend.getFlag(flag) for flag in flags} + return {flag: self._jbackend.pyGetFlag(flag) for flag in flags} def _add_reference_to_scala_backend(self, rg): self._jbackend.pyAddReference(orjson.dumps(rg._config).decode('utf-8')) diff --git a/hail/python/hail/ir/ir.py b/hail/python/hail/ir/ir.py index b024a36304b..2bef587fc1d 100644 --- a/hail/python/hail/ir/ir.py +++ b/hail/python/hail/ir/ir.py @@ -3880,7 +3880,7 @@ def __del__(self): if Env._hc: backend = Env.backend() assert isinstance(backend, Py4JBackend) - backend._jbackend.removeJavaIR(self._id) + backend._jbackend.pyRemoveJavaIR(self._id) class JavaIR(IR): diff --git a/hail/python/hail/ir/table_ir.py b/hail/python/hail/ir/table_ir.py index 8184401c126..eb96deee863 100644 --- a/hail/python/hail/ir/table_ir.py +++ b/hail/python/hail/ir/table_ir.py @@ -1215,4 +1215,4 @@ def __del__(self): if Env._hc: backend = Env.backend() assert isinstance(backend, Py4JBackend) - backend._jbackend.removeJavaIR(self._id) + backend._jbackend.pyRemoveJavaIR(self._id) diff --git a/hail/python/test/hail/genetics/test_reference_genome.py b/hail/python/test/hail/genetics/test_reference_genome.py index 88c2940bf9e..a49c2185712 100644 --- a/hail/python/test/hail/genetics/test_reference_genome.py +++ b/hail/python/test/hail/genetics/test_reference_genome.py @@ -194,7 +194,7 @@ def assert_rg_loaded_correctly(name): # loading different reference genome with same name should fail # (different `test_rg_o` definition) with pytest.raises(FatalError): - hl.read_matrix_table(resource('custom_references_2.t')).count() + hl.read_table(resource('custom_references_2.t')).count() assert hl.read_matrix_table(resource('custom_references.mt')).count_rows() == 14 assert_rg_loaded_correctly('test_rg_1') diff --git a/hail/src/main/scala/is/hail/backend/Backend.scala b/hail/src/main/scala/is/hail/backend/Backend.scala index 329ee1a3e38..c9e97038b7b 100644 --- a/hail/src/main/scala/is/hail/backend/Backend.scala +++ b/hail/src/main/scala/is/hail/backend/Backend.scala @@ -1,12 +1,12 @@ package is.hail.backend import is.hail.asm4s._ +import is.hail.backend.Backend.jsonToBytes import is.hail.backend.spark.SparkBackend import is.hail.expr.ir.{ BaseIR, CodeCacheKey, CompiledFunction, IR, IRParser, IRParserEnvironment, LoweringAnalyses, SortField, TableIR, TableReader, } -import is.hail.expr.ir.functions.IRFunctionRegistry import 
is.hail.expr.ir.lowering.{TableStage, TableStageDependency} import is.hail.io.{BufferSpec, TypedCodecSpec} import is.hail.io.fs._ @@ -20,16 +20,14 @@ import is.hail.types.virtual.{BlockMatrixType, TFloat64} import is.hail.utils._ import is.hail.variant.ReferenceGenome -import scala.collection.JavaConverters._ import scala.collection.mutable import scala.reflect.ClassTag import java.io._ import java.nio.charset.StandardCharsets -import com.fasterxml.jackson.core.StreamReadConstraints import org.json4s._ -import org.json4s.jackson.{JsonMethods, Serialization} +import org.json4s.jackson.JsonMethods import sourcecode.Enclosing object Backend { @@ -41,13 +39,6 @@ object Backend { s"hail_query_$id" } - private var irID: Int = 0 - - def nextIRID(): Int = { - irID += 1 - irID - } - def encodeToOutputStream( ctx: ExecuteContext, t: PTuple, @@ -66,6 +57,9 @@ object Backend { assert(t.isFieldDefined(off, 0)) codec.encode(ctx, elementType, t.loadField(off, 0), os) } + + def jsonToBytes(f: => JValue): Array[Byte] = + JsonMethods.compact(f).getBytes(StandardCharsets.UTF_8) } abstract class BroadcastValue[T] { def value: T } @@ -75,28 +69,8 @@ trait BackendContext { } abstract class Backend extends Closeable { - // From https://github.com/hail-is/hail/issues/14580 : - // IR can get quite big, especially as it can contain an arbitrary - // amount of encoded literals from the user's python session. This - // was a (controversial) restriction imposed by Jackson and should be lifted. - // - // We remove this restriction for all backends, and we do so here, in the - // constructor since constructing a backend is one of the first things that - // happens and this constraint should be overrided as early as possible. - StreamReadConstraints.overrideDefaultStreamReadConstraints( - StreamReadConstraints.builder().maxStringLength(Integer.MAX_VALUE).build() - ) - val persistedIR: mutable.Map[Int, BaseIR] = mutable.Map() - protected[this] def addJavaIR(ir: BaseIR): Int = { - val id = Backend.nextIRID() - persistedIR += (id -> ir) - id - } - - def removeJavaIR(id: Int): Unit = persistedIR.remove(id) - def defaultParallelism: Int def canExecuteParallelTasksOnDriver: Boolean = true @@ -131,30 +105,7 @@ abstract class Backend extends Closeable { def lookupOrCompileCachedFunction[T](k: CodeCacheKey)(f: => CompiledFunction[T]) : CompiledFunction[T] - var references: Map[String, ReferenceGenome] = Map.empty - - def addDefaultReferences(): Unit = - references = ReferenceGenome.builtinReferences() - - def addReference(rg: ReferenceGenome): Unit = { - references.get(rg.name) match { - case Some(rg2) => - if (rg != rg2) { - fatal( - s"Cannot add reference genome '${rg.name}', a different reference with that name already exists. 
Choose a reference name NOT in the following list:\n " + - s"@1", - references.keys.truncatable("\n "), - ) - } - case None => - references += (rg.name -> rg) - } - } - - def hasReference(name: String) = references.contains(name) - - def removeReference(name: String): Unit = - references -= name + def references: mutable.Map[String, ReferenceGenome] def lowerDistributedSort( ctx: ExecuteContext, @@ -189,9 +140,6 @@ abstract class Backend extends Closeable { def withExecuteContext[T](f: ExecuteContext => T)(implicit E: Enclosing): T - private[this] def jsonToBytes(f: => JValue): Array[Byte] = - JsonMethods.compact(f).getBytes(StandardCharsets.UTF_8) - final def valueType(s: String): Array[Byte] = jsonToBytes { withExecuteContext { ctx => @@ -220,15 +168,7 @@ abstract class Backend extends Closeable { } } - def loadReferencesFromDataset(path: String): Array[Byte] = { - withExecuteContext { ctx => - val rgs = ReferenceGenome.fromHailDataset(ctx.fs, path) - rgs.foreach(addReference) - - implicit val formats: Formats = defaultJSONFormats - Serialization.write(rgs.map(_.toJSON).toFastSeq).getBytes(StandardCharsets.UTF_8) - } - } + def loadReferencesFromDataset(path: String): Array[Byte] def fromFASTAFile( name: String, @@ -240,18 +180,22 @@ abstract class Backend extends Closeable { parInput: Array[String], ): Array[Byte] = withExecuteContext { ctx => - val rg = ReferenceGenome.fromFASTAFile(ctx, name, fastaFile, indexFile, - xContigs, yContigs, mtContigs, parInput) - rg.toJSONString.getBytes(StandardCharsets.UTF_8) + jsonToBytes { + Extraction.decompose { + ReferenceGenome.fromFASTAFile(ctx, name, fastaFile, indexFile, + xContigs, yContigs, mtContigs, parInput).toJSON + }(defaultJSONFormats) + } } - def parseVCFMetadata(path: String): Array[Byte] = jsonToBytes { + def parseVCFMetadata(path: String): Array[Byte] = withExecuteContext { ctx => - val metadata = LoadVCF.parseHeaderMetadata(ctx.fs, Set.empty, TFloat64, path) - implicit val formats = defaultJSONFormats - Extraction.decompose(metadata) + jsonToBytes { + Extraction.decompose { + LoadVCF.parseHeaderMetadata(ctx.fs, Set.empty, TFloat64, path) + }(defaultJSONFormats) + } } - } def importFam(path: String, isQuantPheno: Boolean, delimiter: String, missingValue: String) : Array[Byte] = @@ -261,27 +205,6 @@ abstract class Backend extends Closeable { ) } - def pyRegisterIR( - name: String, - typeParamStrs: java.util.ArrayList[String], - argNameStrs: java.util.ArrayList[String], - argTypeStrs: java.util.ArrayList[String], - returnType: String, - bodyStr: String, - ): Unit = { - withExecuteContext { ctx => - IRFunctionRegistry.registerIR( - ctx, - name, - typeParamStrs.asScala.toArray, - argNameStrs.asScala.toArray, - argTypeStrs.asScala.toArray, - returnType, - bodyStr, - ) - } - } - def execute(ctx: ExecuteContext, ir: IR): Either[Unit, (PTuple, Long)] } diff --git a/hail/src/main/scala/is/hail/backend/BackendServer.scala b/hail/src/main/scala/is/hail/backend/BackendServer.scala index 7ce224548c9..04ab2200641 100644 --- a/hail/src/main/scala/is/hail/backend/BackendServer.scala +++ b/hail/src/main/scala/is/hail/backend/BackendServer.scala @@ -112,6 +112,7 @@ class BackendHttpHandler(backend: Backend) extends HttpHandler { } return } + val response: Array[Byte] = exchange.getRequestURI.getPath match { case "/value/type" => backend.valueType(body.extract[IRTypePayload].ir) case "/table/type" => backend.tableType(body.extract[IRTypePayload].ir) diff --git a/hail/src/main/scala/is/hail/backend/ExecuteContext.scala 
b/hail/src/main/scala/is/hail/backend/ExecuteContext.scala index 1ddc132265f..ce477a3232f 100644 --- a/hail/src/main/scala/is/hail/backend/ExecuteContext.scala +++ b/hail/src/main/scala/is/hail/backend/ExecuteContext.scala @@ -128,7 +128,7 @@ class ExecuteContext( ) } - val stateManager = HailStateManager(backend.references) + def stateManager = HailStateManager(backend.references.toMap) val tempFileManager: TempFileManager = if (_tempFileManager != null) _tempFileManager else new OwningTempFileManager(fs) diff --git a/hail/src/main/scala/is/hail/backend/local/LocalBackend.scala b/hail/src/main/scala/is/hail/backend/local/LocalBackend.scala index 2b43df11419..37adcd0e2c1 100644 --- a/hail/src/main/scala/is/hail/backend/local/LocalBackend.scala +++ b/hail/src/main/scala/is/hail/backend/local/LocalBackend.scala @@ -4,8 +4,9 @@ import is.hail.{CancellingExecutorService, HailContext, HailFeatureFlags} import is.hail.annotations.Region import is.hail.asm4s._ import is.hail.backend._ +import is.hail.backend.py4j.Py4JBackendExtensions import is.hail.expr.Validate -import is.hail.expr.ir.{IRParser, _} +import is.hail.expr.ir._ import is.hail.expr.ir.analyses.SemanticHash import is.hail.expr.ir.lowering._ import is.hail.io.fs._ @@ -17,11 +18,12 @@ import is.hail.types.virtual.{BlockMatrixType, TVoid} import is.hail.utils._ import is.hail.variant.ReferenceGenome -import scala.collection.JavaConverters._ +import scala.collection.mutable import scala.reflect.ClassTag import java.io.PrintWriter +import com.fasterxml.jackson.core.StreamReadConstraints import com.google.common.util.concurrent.MoreExecutors import org.apache.hadoop import sourcecode.Enclosing @@ -35,6 +37,18 @@ class LocalTaskContext(val partitionId: Int, val stageId: Int) extends HailTaskC object LocalBackend { private var theLocalBackend: LocalBackend = _ + // From https://github.com/hail-is/hail/issues/14580 : + // IR can get quite big, especially as it can contain an arbitrary + // amount of encoded literals from the user's python session. This + // was a (controversial) restriction imposed by Jackson and should be lifted. + // + // We remove this restriction at the earliest point possible for each backend. + // This can't be unified since each backend has its own entry-point from python + // and its own specific initialisation code. 
+ StreamReadConstraints.overrideDefaultStreamReadConstraints( + StreamReadConstraints.builder().maxStringLength(Integer.MAX_VALUE).build() + ) + def apply( tmpdir: String, logFile: String = "hail.log", @@ -47,8 +61,11 @@ object LocalBackend { if (!skipLoggingConfiguration) HailContext.configureLogging(logFile, quiet, append) - theLocalBackend = new LocalBackend(tmpdir) - theLocalBackend.addDefaultReferences() + theLocalBackend = new LocalBackend( + tmpdir, + mutable.Map(ReferenceGenome.builtinReferences().toSeq: _*), + ) + theLocalBackend } @@ -64,18 +81,15 @@ object LocalBackend { } } -class LocalBackend(val tmpdir: String) extends Backend with BackendWithCodeCache { - - private[this] val flags = HailFeatureFlags.fromEnv() - private[this] val theHailClassLoader = new HailClassLoader(getClass().getClassLoader()) - - def getFlag(name: String): String = flags.get(name) +class LocalBackend( + val tmpdir: String, + override val references: mutable.Map[String, ReferenceGenome], +) extends Backend with BackendWithCodeCache with Py4JBackendExtensions { - def setFlag(name: String, value: String) = flags.set(name, value) + override val flags: HailFeatureFlags = HailFeatureFlags.fromEnv() + override def longLifeTempFileManager: TempFileManager = null - // called from python - val availableFlags: java.util.ArrayList[String] = - flags.available + private[this] val theHailClassLoader = new HailClassLoader(getClass().getClassLoader()) // flags can be set after construction from python def fs: FS = RouterFS.buildRoutes(CloudStorageFSConfig.fromFlagsAndEnv(None, flags)) @@ -190,81 +204,6 @@ class LocalBackend(val tmpdir: String) extends Backend with BackendWithCodeCache res } - def executeLiteral(irStr: String): Int = - withExecuteContext { ctx => - val ir = IRParser.parse_value_ir(irStr, IRParserEnvironment(ctx, persistedIR.toMap)) - assert(ir.typ.isRealizable) - execute(ctx, ir) match { - case Left(_) => throw new HailException("Can't create literal") - case Right((pt, addr)) => - val field = GetFieldByIdx(EncodedLiteral.fromPTypeAndAddress(pt, addr, ctx), 0) - addJavaIR(field) - } - } - - def pyAddReference(jsonConfig: String): Unit = addReference(ReferenceGenome.fromJSON(jsonConfig)) - def pyRemoveReference(name: String): Unit = removeReference(name) - - def pyAddLiftover(name: String, chainFile: String, destRGName: String): Unit = - withExecuteContext(ctx => references(name).addLiftover(ctx, chainFile, destRGName)) - - def pyRemoveLiftover(name: String, destRGName: String) = - references(name).removeLiftover(destRGName) - - def pyFromFASTAFile( - name: String, - fastaFile: String, - indexFile: String, - xContigs: java.util.List[String], - yContigs: java.util.List[String], - mtContigs: java.util.List[String], - parInput: java.util.List[String], - ): String = - withExecuteContext { ctx => - val rg = ReferenceGenome.fromFASTAFile( - ctx, - name, - fastaFile, - indexFile, - xContigs.asScala.toArray, - yContigs.asScala.toArray, - mtContigs.asScala.toArray, - parInput.asScala.toArray, - ) - rg.toJSONString - } - - def pyAddSequence(name: String, fastaFile: String, indexFile: String): Unit = - withExecuteContext(ctx => references(name).addSequence(ctx, fastaFile, indexFile)) - - def pyRemoveSequence(name: String) = references(name).removeSequence() - - def parse_value_ir(s: String, refMap: java.util.Map[String, String]): IR = - withExecuteContext { ctx => - IRParser.parse_value_ir( - s, - IRParserEnvironment(ctx, persistedIR.toMap), - BindingEnv.eval(refMap.asScala.toMap.map { case (n, t) => - Name(n) -> 
IRParser.parseType(t) - }.toSeq: _*), - ) - } - - def parse_table_ir(s: String): TableIR = - withExecuteContext { ctx => - IRParser.parse_table_ir(s, IRParserEnvironment(ctx, irMap = persistedIR.toMap)) - } - - def parse_matrix_ir(s: String): MatrixIR = - withExecuteContext { ctx => - IRParser.parse_matrix_ir(s, IRParserEnvironment(ctx, irMap = persistedIR.toMap)) - } - - def parse_blockmatrix_ir(s: String): BlockMatrixIR = - withExecuteContext { ctx => - IRParser.parse_blockmatrix_ir(s, IRParserEnvironment(ctx, irMap = persistedIR.toMap)) - } - override def lowerDistributedSort( ctx: ExecuteContext, stage: TableStage, diff --git a/hail/src/main/scala/is/hail/backend/py4j/Py4JBackendExtensions.scala b/hail/src/main/scala/is/hail/backend/py4j/Py4JBackendExtensions.scala new file mode 100644 index 00000000000..7bb95f7ae52 --- /dev/null +++ b/hail/src/main/scala/is/hail/backend/py4j/Py4JBackendExtensions.scala @@ -0,0 +1,255 @@ +package is.hail.backend.py4j + +import is.hail.HailFeatureFlags +import is.hail.backend.{Backend, ExecuteContext, NonOwningTempFileManager, TempFileManager} +import is.hail.expr.{JSONAnnotationImpex, SparkAnnotationImpex} +import is.hail.expr.ir.{ + BaseIR, BindingEnv, BlockMatrixIR, EncodedLiteral, GetFieldByIdx, IR, IRParser, + IRParserEnvironment, Interpret, MatrixIR, MatrixNativeReader, MatrixRead, Name, + NativeReaderOptions, TableIR, TableLiteral, TableValue, +} +import is.hail.expr.ir.IRParser.parseType +import is.hail.expr.ir.functions.IRFunctionRegistry +import is.hail.linalg.RowMatrix +import is.hail.types.physical.PStruct +import is.hail.types.virtual.{TArray, TInterval} +import is.hail.utils.{defaultJSONFormats, log, toRichIterable, FastSeq, HailException, Interval} +import is.hail.variant.ReferenceGenome + +import scala.collection.mutable +import scala.jdk.CollectionConverters.{ + asScalaBufferConverter, mapAsScalaMapConverter, seqAsJavaListConverter, +} + +import java.nio.charset.StandardCharsets +import java.util + +import org.apache.spark.sql.DataFrame +import org.json4s +import org.json4s.Formats +import org.json4s.jackson.{JsonMethods, Serialization} +import sourcecode.Enclosing + +trait Py4JBackendExtensions { this: Backend => + def persistedIR: mutable.Map[Int, BaseIR] + def flags: HailFeatureFlags + def longLifeTempFileManager: TempFileManager + + def pyGetFlag(name: String): String = + flags.get(name) + + def pySetFlag(name: String, value: String): Unit = + flags.set(name, value) + + def pyAvailableFlags: java.util.ArrayList[String] = + flags.available + + private[this] var irID: Int = 0 + + private[this] def nextIRID(): Int = { + irID += 1 + irID + } + + private[this] def addJavaIR(ir: BaseIR): Int = { + val id = nextIRID() + persistedIR += (id -> ir) + id + } + + def pyRemoveJavaIR(id: Int): Unit = + persistedIR.remove(id) + + def pyAddSequence(name: String, fastaFile: String, indexFile: String): Unit = + withExecuteContext(ctx => references(name).addSequence(ctx, fastaFile, indexFile)) + + def pyRemoveSequence(name: String): Unit = + references(name).removeSequence() + + def pyExportBlockMatrix( + pathIn: String, + pathOut: String, + delimiter: String, + header: String, + addIndex: Boolean, + exportType: String, + partitionSize: java.lang.Integer, + entries: String, + ): Unit = + withExecuteContext { ctx => + val rm = RowMatrix.readBlockMatrix(ctx.fs, pathIn, partitionSize) + entries match { + case "full" => + rm.export(ctx, pathOut, delimiter, Option(header), addIndex, exportType) + case "lower" => + rm.exportLowerTriangle(ctx, pathOut, 
delimiter, Option(header), addIndex, exportType) + case "strict_lower" => + rm.exportStrictLowerTriangle( + ctx, + pathOut, + delimiter, + Option(header), + addIndex, + exportType, + ) + case "upper" => + rm.exportUpperTriangle(ctx, pathOut, delimiter, Option(header), addIndex, exportType) + case "strict_upper" => + rm.exportStrictUpperTriangle( + ctx, + pathOut, + delimiter, + Option(header), + addIndex, + exportType, + ) + } + } + + def pyRegisterIR( + name: String, + typeParamStrs: java.util.ArrayList[String], + argNameStrs: java.util.ArrayList[String], + argTypeStrs: java.util.ArrayList[String], + returnType: String, + bodyStr: String, + ): Unit = { + withExecuteContext { ctx => + IRFunctionRegistry.registerIR( + ctx, + name, + typeParamStrs.asScala.toArray, + argNameStrs.asScala.toArray, + argTypeStrs.asScala.toArray, + returnType, + bodyStr, + ) + } + } + + def pyExecuteLiteral(irStr: String): Int = + withExecuteContext { ctx => + val ir = IRParser.parse_value_ir(irStr, IRParserEnvironment(ctx, persistedIR.toMap)) + assert(ir.typ.isRealizable) + execute(ctx, ir) match { + case Left(_) => throw new HailException("Can't create literal") + case Right((pt, addr)) => + val field = GetFieldByIdx(EncodedLiteral.fromPTypeAndAddress(pt, addr, ctx), 0) + addJavaIR(field) + } + } + + def pyFromDF(df: DataFrame, jKey: java.util.List[String]): (Int, String) = { + val key = jKey.asScala.toArray.toFastSeq + val signature = + SparkAnnotationImpex.importType(df.schema).setRequired(true).asInstanceOf[PStruct] + withExecuteContext(selfContainedExecution = false) { ctx => + val tir = TableLiteral( + TableValue( + ctx, + signature.virtualType, + key, + df.rdd, + Some(signature), + ), + ctx.theHailClassLoader, + ) + val id = addJavaIR(tir) + (id, JsonMethods.compact(tir.typ.toJSON)) + } + } + + def pyToDF(s: String): DataFrame = + withExecuteContext { ctx => + val tir = IRParser.parse_table_ir(s, IRParserEnvironment(ctx, irMap = persistedIR.toMap)) + Interpret(tir, ctx).toDF() + } + + def pyReadMultipleMatrixTables(jsonQuery: String): util.List[MatrixIR] = + withExecuteContext { ctx => + log.info("pyReadMultipleMatrixTables: got query") + val kvs = JsonMethods.parse(jsonQuery) match { + case json4s.JObject(values) => values.toMap + } + + val paths = kvs("paths").asInstanceOf[json4s.JArray].arr.toArray.map { + case json4s.JString(s) => s + } + + val intervalPointType = parseType(kvs("intervalPointType").asInstanceOf[json4s.JString].s) + val intervalObjects = + JSONAnnotationImpex.importAnnotation(kvs("intervals"), TArray(TInterval(intervalPointType))) + .asInstanceOf[IndexedSeq[Interval]] + + val opts = NativeReaderOptions(intervalObjects, intervalPointType) + val matrixReaders: IndexedSeq[MatrixIR] = paths.map { p => + log.info(s"creating MatrixRead node for $p") + val mnr = MatrixNativeReader(ctx.fs, p, Some(opts)) + MatrixRead(mnr.fullMatrixTypeWithoutUIDs, false, false, mnr): MatrixIR + } + log.info("pyReadMultipleMatrixTables: returning N matrix tables") + matrixReaders.asJava + } + + def pyAddReference(jsonConfig: String): Unit = + addReference(ReferenceGenome.fromJSON(jsonConfig)) + + def pyRemoveReference(name: String): Unit = + references.remove(name) + + def pyAddLiftover(name: String, chainFile: String, destRGName: String): Unit = + withExecuteContext(ctx => references(name).addLiftover(ctx, chainFile, destRGName)) + + def pyRemoveLiftover(name: String, destRGName: String): Unit = + references(name).removeLiftover(destRGName) + + private[this] def addReference(rg: ReferenceGenome): Unit = + 
ReferenceGenome.addFatalOnCollision(references, FastSeq(rg)) + + def parse_value_ir(s: String, refMap: java.util.Map[String, String]): IR = + withExecuteContext { ctx => + IRParser.parse_value_ir( + s, + IRParserEnvironment(ctx, irMap = persistedIR.toMap), + BindingEnv.eval(refMap.asScala.toMap.map { case (n, t) => + Name(n) -> IRParser.parseType(t) + }.toSeq: _*), + ) + } + + def parse_table_ir(s: String): TableIR = + withExecuteContext(selfContainedExecution = false) { ctx => + IRParser.parse_table_ir(s, IRParserEnvironment(ctx, irMap = persistedIR.toMap)) + } + + def parse_matrix_ir(s: String): MatrixIR = + withExecuteContext(selfContainedExecution = false) { ctx => + IRParser.parse_matrix_ir(s, IRParserEnvironment(ctx, irMap = persistedIR.toMap)) + } + + def parse_blockmatrix_ir(s: String): BlockMatrixIR = + withExecuteContext(selfContainedExecution = false) { ctx => + IRParser.parse_blockmatrix_ir(s, IRParserEnvironment(ctx, irMap = persistedIR.toMap)) + } + + def loadReferencesFromDataset(path: String): Array[Byte] = + withExecuteContext { ctx => + val rgs = ReferenceGenome.fromHailDataset(ctx.fs, path) + ReferenceGenome.addFatalOnCollision(references, rgs) + + implicit val formats: Formats = defaultJSONFormats + Serialization.write(rgs.map(_.toJSON).toFastSeq).getBytes(StandardCharsets.UTF_8) + } + + def withExecuteContext[T]( + selfContainedExecution: Boolean = true + )( + f: ExecuteContext => T + )(implicit E: Enclosing + ): T = + withExecuteContext { ctx => + val tempFileManager = longLifeTempFileManager + if (selfContainedExecution && tempFileManager != null) f(ctx) + else ctx.local(tempFileManager = NonOwningTempFileManager(tempFileManager))(f) + } +} diff --git a/hail/src/main/scala/is/hail/backend/service/Main.scala b/hail/src/main/scala/is/hail/backend/service/Main.scala index f35b683988d..a31f99e7946 100644 --- a/hail/src/main/scala/is/hail/backend/service/Main.scala +++ b/hail/src/main/scala/is/hail/backend/service/Main.scala @@ -10,7 +10,7 @@ object Main { /* This constraint should be overridden as early as possible. 
See: * - https://github.com/hail-is/hail/issues/14580 * - https://github.com/hail-is/hail/issues/14749 - * - The note on this setting in is.hail.backend.Backend */ + * - The note on this setting in is.hail.backend.Backend companion objects */ StreamReadConstraints.overrideDefaultStreamReadConstraints( StreamReadConstraints.builder().maxStringLength(Integer.MAX_VALUE).build() ) diff --git a/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala b/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala index 4908f2429ae..c65276ab518 100644 --- a/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala +++ b/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala @@ -6,8 +6,8 @@ import is.hail.asm4s._ import is.hail.backend._ import is.hail.expr.Validate import is.hail.expr.ir.{ - Compile, IR, IRParser, IRParserEnvironment, IRSize, LoweringAnalyses, MakeTuple, SortField, - TableIR, TableReader, TypeCheck, + Compile, IR, IRParser, IRSize, LoweringAnalyses, MakeTuple, SortField, TableIR, TableReader, + TypeCheck, } import is.hail.expr.ir.analyses.SemanticHash import is.hail.expr.ir.functions.IRFunctionRegistry @@ -24,6 +24,7 @@ import is.hail.utils._ import is.hail.variant.ReferenceGenome import scala.annotation.switch +import scala.collection.mutable import scala.reflect.ClassTag import java.io._ @@ -34,7 +35,7 @@ import java.util.concurrent._ import org.apache.log4j.Logger import org.json4s.{DefaultFormats, Formats} import org.json4s.JsonAST._ -import org.json4s.jackson.JsonMethods +import org.json4s.jackson.{JsonMethods, Serialization} import sourcecode.Enclosing class ServiceBackendContext( @@ -85,10 +86,18 @@ object ServiceBackend { ExecutionCache.fromFlags(flags, fs, rpcConfig.remote_tmpdir), ) + val references = mutable.Map.empty[String, ReferenceGenome] + references ++= ReferenceGenome.builtinReferences() + ReferenceGenome.addFatalOnCollision( + references, + rpcConfig.custom_references.map(ReferenceGenome.fromJSON), + ) + val backend = new ServiceBackend( JarUrl(jarLocation), name, theHailClassLoader, + references, batchClient, batchConfig, flags, @@ -97,16 +106,16 @@ object ServiceBackend { backendContext, scratchDir, ) - backend.addDefaultReferences() - rpcConfig.custom_references.foreach(s => backend.addReference(ReferenceGenome.fromJSON(s))) - rpcConfig.liftovers.foreach { case (sourceGenome, liftoversForSource) => - liftoversForSource.foreach { case (destGenome, chainFile) => - backend.addLiftover(sourceGenome, chainFile, destGenome) + backend.withExecuteContext { ctx => + rpcConfig.liftovers.foreach { case (sourceGenome, liftoversForSource) => + liftoversForSource.foreach { case (destGenome, chainFile) => + references(sourceGenome).addLiftover(ctx, chainFile, destGenome) + } + } + rpcConfig.sequences.foreach { case (rg, seq) => + references(rg).addSequence(ctx, seq.fasta, seq.index) } - } - rpcConfig.sequences.foreach { case (rg, seq) => - backend.addSequence(rg, seq.fasta, seq.index) } backend @@ -117,6 +126,7 @@ class ServiceBackend( val jarSpec: JarSpec, var name: String, val theHailClassLoader: HailClassLoader, + override val references: mutable.Map[String, ReferenceGenome], val batchClient: BatchClient, val batchConfig: BatchConfig, val flags: HailFeatureFlags, @@ -397,12 +407,13 @@ class ServiceBackend( )(f) } - def addLiftover(name: String, chainFile: String, destRGName: String): Unit = - withExecuteContext(ctx => references(name).addLiftover(ctx, chainFile, destRGName)) - - def addSequence(name: String, fastaFile: String, indexFile: 
String): Unit = - withExecuteContext(ctx => references(name).addSequence(ctx, fastaFile, indexFile)) - + override def loadReferencesFromDataset(path: String): Array[Byte] = + withExecuteContext { ctx => + val rgs = ReferenceGenome.fromHailDataset(ctx.fs, path) + ReferenceGenome.addFatalOnCollision(references, rgs) + implicit val formats: Formats = defaultJSONFormats + Serialization.write(rgs.map(_.toJSON).toFastSeq).getBytes(StandardCharsets.UTF_8) + } } class EndOfInputException extends RuntimeException @@ -577,8 +588,7 @@ class ServiceBackendAPI( val code = qobExecutePayload.payload.ir backend.withExecuteContext { ctx => withIRFunctionsReadFromInput(qobExecutePayload.functions, ctx) { () => - val ir = - IRParser.parse_value_ir(code, IRParserEnvironment(ctx, backend.persistedIR.toMap)) + val ir = IRParser.parse_value_ir(ctx, code) backend.execute(ctx, ir) match { case Left(()) => Array() diff --git a/hail/src/main/scala/is/hail/backend/service/Worker.scala b/hail/src/main/scala/is/hail/backend/service/Worker.scala index 5722e3bc435..4f596de3b4a 100644 --- a/hail/src/main/scala/is/hail/backend/service/Worker.scala +++ b/hail/src/main/scala/is/hail/backend/service/Worker.scala @@ -177,6 +177,7 @@ object Worker { null, null, null, + null, scratchDir, ) } else { @@ -192,6 +193,7 @@ object Worker { null, null, null, + null, scratchDir, ) ) diff --git a/hail/src/main/scala/is/hail/backend/spark/SparkBackend.scala b/hail/src/main/scala/is/hail/backend/spark/SparkBackend.scala index 2b4720cd10e..6a9abf1a245 100644 --- a/hail/src/main/scala/is/hail/backend/spark/SparkBackend.scala +++ b/hail/src/main/scala/is/hail/backend/spark/SparkBackend.scala @@ -4,16 +4,15 @@ import is.hail.{HailContext, HailFeatureFlags} import is.hail.annotations._ import is.hail.asm4s._ import is.hail.backend._ -import is.hail.expr.{JSONAnnotationImpex, SparkAnnotationImpex, Validate} -import is.hail.expr.ir.{IRParser, _} -import is.hail.expr.ir.IRParser.parseType +import is.hail.backend.py4j.Py4JBackendExtensions +import is.hail.expr.Validate +import is.hail.expr.ir._ import is.hail.expr.ir.analyses.SemanticHash import is.hail.expr.ir.lowering._ import is.hail.io.{BufferSpec, TypedCodecSpec} import is.hail.io.fs._ -import is.hail.linalg.{BlockMatrix, RowMatrix} +import is.hail.linalg.BlockMatrix import is.hail.rvd.RVD -import is.hail.stats.LinearMixedModel import is.hail.types._ import is.hail.types.physical.{PStruct, PTuple} import is.hail.types.physical.stypes.PTypeReferenceSingleCodeType @@ -21,23 +20,21 @@ import is.hail.types.virtual._ import is.hail.utils._ import is.hail.variant.ReferenceGenome -import scala.collection.JavaConverters._ import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import scala.concurrent.ExecutionException import scala.reflect.ClassTag import scala.util.control.NonFatal -import java.io.{Closeable, PrintWriter} +import java.io.PrintWriter +import com.fasterxml.jackson.core.StreamReadConstraints import org.apache.hadoop import org.apache.hadoop.conf.Configuration import org.apache.spark._ import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, SparkSession} -import org.json4s -import org.json4s.jackson.JsonMethods +import org.apache.spark.sql.SparkSession import sourcecode.Enclosing class SparkBroadcastValue[T](bc: Broadcast[T]) extends BroadcastValue[T] with Serializable { @@ -77,6 +74,18 @@ object SparkBackend { val MaxStageParallelism = "spark_max_stage_parallelism" } + // From 
https://github.com/hail-is/hail/issues/14580 : + // IR can get quite big, especially as it can contain an arbitrary + // amount of encoded literals from the user's python session. This + // was a (controversial) restriction imposed by Jackson and should be lifted. + // + // We remove this restriction at the earliest point possible for each backend. + // This can't be unified since each backend has its own entry-point from python + // and its own specific initialisation code. + StreamReadConstraints.overrideDefaultStreamReadConstraints( + StreamReadConstraints.builder().maxStringLength(Integer.MAX_VALUE).build() + ) + private var theSparkBackend: SparkBackend = _ def sparkContext(op: String): SparkContext = HailContext.sparkBackend(op).sc @@ -274,8 +283,14 @@ object SparkBackend { sc1.uiWebUrl.foreach(ui => info(s"SparkUI: $ui")) theSparkBackend = - new SparkBackend(tmpdir, localTmpdir, sc1, gcsRequesterPaysProject, gcsRequesterPaysBuckets) - theSparkBackend.addDefaultReferences() + new SparkBackend( + tmpdir, + localTmpdir, + sc1, + mutable.Map(ReferenceGenome.builtinReferences().toSeq: _*), + gcsRequesterPaysProject, + gcsRequesterPaysBuckets, + ) theSparkBackend } @@ -302,9 +317,11 @@ class SparkBackend( val tmpdir: String, val localTmpdir: String, val sc: SparkContext, + override val references: mutable.Map[String, ReferenceGenome], gcsRequesterPaysProject: String, gcsRequesterPaysBuckets: String, -) extends Backend with Closeable with BackendWithCodeCache { +) extends Backend with BackendWithCodeCache with Py4JBackendExtensions { + assert(gcsRequesterPaysProject != null || gcsRequesterPaysBuckets == null) lazy val sparkSession: SparkSession = SparkSession.builder().config(sc.getConf).getOrCreate() @@ -328,17 +345,12 @@ class SparkBackend( new HadoopFS(new SerializableHadoopConfiguration(conf)) } - private[this] val longLifeTempFileManager: TempFileManager = new OwningTempFileManager(fs) + override val flags: HailFeatureFlags = HailFeatureFlags.fromEnv() - val bmCache: SparkBlockMatrixCache = SparkBlockMatrixCache() + override val longLifeTempFileManager: TempFileManager = + new OwningTempFileManager(fs) - private[this] val flags = HailFeatureFlags.fromEnv() - - def getFlag(name: String): String = flags.get(name) - - def setFlag(name: String, value: String) = flags.set(name, value) - - val availableFlags: java.util.ArrayList[String] = flags.available + val bmCache: SparkBlockMatrixCache = SparkBlockMatrixCache() def persist(backendContext: BackendContext, id: String, value: BlockMatrix, storageLevel: String) : Unit = bmCache.persistBlockMatrix(id, value, storageLevel) @@ -375,25 +387,7 @@ class SparkBackend( new IrMetadata(), ) - def withExecuteContext[T]( - selfContainedExecution: Boolean = true - )( - f: ExecuteContext => T - )(implicit E: Enclosing - ): T = - withExecuteContext( - if (selfContainedExecution) null else NonOwningTempFileManager(longLifeTempFileManager) - )(f) - override def withExecuteContext[T](f: ExecuteContext => T)(implicit E: Enclosing): T = - withExecuteContext(null.asInstanceOf[TempFileManager])(f) - - def withExecuteContext[T]( - tmpFileManager: TempFileManager - )( - f: ExecuteContext => T - )(implicit E: Enclosing - ): T = ExecutionTimer.logTime { timer => ExecuteContext.scoped( tmpdir, @@ -401,7 +395,7 @@ class SparkBackend( this, fs, timer, - tmpFileManager, + null, theHailClassLoader, flags, new BackendContext { @@ -446,7 +440,7 @@ class SparkBackend( } } - val chunkSize = getFlag(SparkBackend.Flags.MaxStageParallelism).toInt + val chunkSize = 
flags.get(SparkBackend.Flags.MaxStageParallelism).toInt val partsToRun = partitions.getOrElse(contexts.indices) val buffer = new ArrayBuffer[(Array[Byte], Int)](partsToRun.length) var failure: Option[Throwable] = None @@ -547,185 +541,16 @@ class SparkBackend( Validate(ir) ctx.irMetadata.semhash = SemanticHash(ctx)(ir) try { - val lowerTable = getFlag("lower") != null - val lowerBM = getFlag("lower_bm") != null + val lowerTable = flags.get("lower") != null + val lowerBM = flags.get("lower_bm") != null _jvmLowerAndExecute(ctx, ir, optimize = true, lowerTable, lowerBM) } catch { - case e: LowererUnsupportedOperation if getFlag("lower_only") != null => throw e + case e: LowererUnsupportedOperation if flags.get("lower_only") != null => throw e case _: LowererUnsupportedOperation => CompileAndEvaluate._apply(ctx, ir, optimize = true) } } - def executeLiteral(irStr: String): Int = - withExecuteContext { ctx => - val ir = IRParser.parse_value_ir(irStr, IRParserEnvironment(ctx, persistedIR.toMap)) - assert(ir.typ.isRealizable) - execute(ctx, ir) match { - case Left(_) => throw new HailException("Can't create literal") - case Right((pt, addr)) => - val field = GetFieldByIdx(EncodedLiteral.fromPTypeAndAddress(pt, addr, ctx), 0) - addJavaIR(field) - } - } - - def pyFromDF(df: DataFrame, jKey: java.util.List[String]): (Int, String) = { - val key = jKey.asScala.toArray.toFastSeq - val signature = - SparkAnnotationImpex.importType(df.schema).setRequired(true).asInstanceOf[PStruct] - withExecuteContext(selfContainedExecution = false) { ctx => - val tir = TableLiteral( - TableValue( - ctx, - signature.virtualType.asInstanceOf[TStruct], - key, - df.rdd, - Some(signature), - ), - ctx.theHailClassLoader, - ) - val id = addJavaIR(tir) - (id, JsonMethods.compact(tir.typ.toJSON)) - } - } - - def pyToDF(s: String): DataFrame = - withExecuteContext(selfContainedExecution = false) { ctx => - val tir = IRParser.parse_table_ir(s, IRParserEnvironment(ctx, irMap = persistedIR.toMap)) - Interpret(tir, ctx).toDF() - } - - def pyReadMultipleMatrixTables(jsonQuery: String): java.util.List[MatrixIR] = { - log.info("pyReadMultipleMatrixTables: got query") - val kvs = JsonMethods.parse(jsonQuery) match { - case json4s.JObject(values) => values.toMap - } - - val paths = kvs("paths").asInstanceOf[json4s.JArray].arr.toArray.map { case json4s.JString(s) => - s - } - - val intervalPointType = parseType(kvs("intervalPointType").asInstanceOf[json4s.JString].s) - val intervalObjects = - JSONAnnotationImpex.importAnnotation(kvs("intervals"), TArray(TInterval(intervalPointType))) - .asInstanceOf[IndexedSeq[Interval]] - - val opts = NativeReaderOptions(intervalObjects, intervalPointType, filterIntervals = false) - val matrixReaders: IndexedSeq[MatrixIR] = paths.map { p => - log.info(s"creating MatrixRead node for $p") - val mnr = MatrixNativeReader(fs, p, Some(opts)) - MatrixRead(mnr.fullMatrixTypeWithoutUIDs, false, false, mnr): MatrixIR - } - log.info("pyReadMultipleMatrixTables: returning N matrix tables") - matrixReaders.asJava - } - - def pyAddReference(jsonConfig: String): Unit = addReference(ReferenceGenome.fromJSON(jsonConfig)) - def pyRemoveReference(name: String): Unit = removeReference(name) - - def pyAddLiftover(name: String, chainFile: String, destRGName: String): Unit = - withExecuteContext(ctx => references(name).addLiftover(ctx, chainFile, destRGName)) - - def pyRemoveLiftover(name: String, destRGName: String) = - references(name).removeLiftover(destRGName) - - def pyFromFASTAFile( - name: String, - fastaFile: String, - 
indexFile: String, - xContigs: java.util.List[String], - yContigs: java.util.List[String], - mtContigs: java.util.List[String], - parInput: java.util.List[String], - ): String = - withExecuteContext { ctx => - val rg = ReferenceGenome.fromFASTAFile( - ctx, - name, - fastaFile, - indexFile, - xContigs.asScala.toArray, - yContigs.asScala.toArray, - mtContigs.asScala.toArray, - parInput.asScala.toArray, - ) - rg.toJSONString - } - - def pyAddSequence(name: String, fastaFile: String, indexFile: String): Unit = - withExecuteContext(ctx => references(name).addSequence(ctx, fastaFile, indexFile)) - - def pyRemoveSequence(name: String) = references(name).removeSequence() - - def pyExportBlockMatrix( - pathIn: String, - pathOut: String, - delimiter: String, - header: String, - addIndex: Boolean, - exportType: String, - partitionSize: java.lang.Integer, - entries: String, - ): Unit = - withExecuteContext { ctx => - val rm = RowMatrix.readBlockMatrix(fs, pathIn, partitionSize) - entries match { - case "full" => - rm.export(ctx, pathOut, delimiter, Option(header), addIndex, exportType) - case "lower" => - rm.exportLowerTriangle(ctx, pathOut, delimiter, Option(header), addIndex, exportType) - case "strict_lower" => - rm.exportStrictLowerTriangle( - ctx, - pathOut, - delimiter, - Option(header), - addIndex, - exportType, - ) - case "upper" => - rm.exportUpperTriangle(ctx, pathOut, delimiter, Option(header), addIndex, exportType) - case "strict_upper" => - rm.exportStrictUpperTriangle( - ctx, - pathOut, - delimiter, - Option(header), - addIndex, - exportType, - ) - } - } - - def pyFitLinearMixedModel(lmm: LinearMixedModel, pa_t: RowMatrix, a_t: RowMatrix): TableIR = - withExecuteContext(selfContainedExecution = false)(ctx => lmm.fit(ctx, pa_t, Option(a_t))) - - def parse_value_ir(s: String, refMap: java.util.Map[String, String]): IR = - withExecuteContext { ctx => - IRParser.parse_value_ir( - s, - IRParserEnvironment(ctx, irMap = persistedIR.toMap), - BindingEnv.eval(refMap.asScala.toMap.map { case (n, t) => - Name(n) -> IRParser.parseType(t) - }.toSeq: _*), - ) - } - - def parse_table_ir(s: String): TableIR = - withExecuteContext(selfContainedExecution = false) { ctx => - IRParser.parse_table_ir(s, IRParserEnvironment(ctx, irMap = persistedIR.toMap)) - } - - def parse_matrix_ir(s: String): MatrixIR = - withExecuteContext(selfContainedExecution = false) { ctx => - IRParser.parse_matrix_ir(s, IRParserEnvironment(ctx, irMap = persistedIR.toMap)) - } - - def parse_blockmatrix_ir(s: String): BlockMatrixIR = - withExecuteContext(selfContainedExecution = false) { ctx => - IRParser.parse_blockmatrix_ir(s, IRParserEnvironment(ctx, irMap = persistedIR.toMap)) - } - override def lowerDistributedSort( ctx: ExecuteContext, stage: TableStage, @@ -733,7 +558,7 @@ class SparkBackend( rt: RTable, nPartitions: Option[Int], ): TableReader = { - if (getFlag("use_new_shuffle") != null) + if (flags.get("use_new_shuffle") != null) return LowerDistributedSort.distributedSort(ctx, stage, sortFields, rt) val (globals, rvd) = TableStageToRVD(ctx, stage) diff --git a/hail/src/main/scala/is/hail/variant/ReferenceGenome.scala b/hail/src/main/scala/is/hail/variant/ReferenceGenome.scala index 7412b32ff9b..e17f49c0a64 100644 --- a/hail/src/main/scala/is/hail/variant/ReferenceGenome.scala +++ b/hail/src/main/scala/is/hail/variant/ReferenceGenome.scala @@ -795,4 +795,19 @@ object ReferenceGenome { def getMapFromArray(arr: Array[ReferenceGenome]): Map[String, ReferenceGenome] = arr.map(rg => (rg.name, rg)).toMap + + def 
addFatalOnCollision( + existing: mutable.Map[String, ReferenceGenome], + newReferences: IndexedSeq[ReferenceGenome], + ): Unit = + for (rg <- newReferences) { + if (existing.get(rg.name).exists(_ != rg)) + fatal( + s"Cannot add reference genome '${rg.name}', a different reference with that name already exists. " ++ + "Choose a reference name NOT in the following list:" ++ + existing.keys.toFastSeq.sorted.mkString(start = "\n ", sep = "\n ", end = "") + ) + + existing += (rg.name -> rg) + } } diff --git a/hail/src/test/scala/is/hail/HailSuite.scala b/hail/src/test/scala/is/hail/HailSuite.scala index bccfee74c3f..e2422740d96 100644 --- a/hail/src/test/scala/is/hail/HailSuite.scala +++ b/hail/src/test/scala/is/hail/HailSuite.scala @@ -44,7 +44,7 @@ object HailSuite { lazy val hc: HailContext = { val hc = withSparkBackend() - hc.sparkBackend("HailSuite.hc").setFlag("lower", "1") + hc.sparkBackend("HailSuite.hc").flags.set("lower", "1") hc.checkRVDKeys = true hc } diff --git a/hail/src/test/scala/is/hail/backend/ServiceBackendSuite.scala b/hail/src/test/scala/is/hail/backend/ServiceBackendSuite.scala index c66e1b0fcbd..461eed812f4 100644 --- a/hail/src/test/scala/is/hail/backend/ServiceBackendSuite.scala +++ b/hail/src/test/scala/is/hail/backend/ServiceBackendSuite.scala @@ -8,6 +8,7 @@ import is.hail.services._ import is.hail.services.JobGroupStates.Success import is.hail.utils.{tokenUrlSafe, using} +import scala.collection.mutable import scala.reflect.io.{Directory, Path} import scala.util.Random @@ -102,6 +103,7 @@ class ServiceBackendSuite extends TestNGSuite with IdiomaticMockito with OptionV jarSpec = GitRevision("123"), name = "name", theHailClassLoader = new HailClassLoader(getClass.getClassLoader), + references = mutable.Map.empty, batchClient = client, batchConfig = BatchConfig(batchId = Random.nextInt(), jobGroupId = Random.nextInt()), flags = flags, diff --git a/hail/src/test/scala/is/hail/expr/ir/StagedMinHeapSuite.scala b/hail/src/test/scala/is/hail/expr/ir/StagedMinHeapSuite.scala index eedfbb748b9..be15fd59677 100644 --- a/hail/src/test/scala/is/hail/expr/ir/StagedMinHeapSuite.scala +++ b/hail/src/test/scala/is/hail/expr/ir/StagedMinHeapSuite.scala @@ -99,9 +99,9 @@ class StagedMinHeapSuite extends HailSuite with StagedCoercionInstances { }.check() def withReferenceGenome[A](rg: ReferenceGenome)(f: => A): A = { - ctx.backend.addReference(rg) + ctx.backend.references += (rg.name -> rg) try f - finally ctx.backend.removeReference(rg.name) + finally ctx.backend.references.remove(rg.name) } def sort(xs: IndexedSeq[Int]): IndexedSeq[Int] = diff --git a/hail/src/test/scala/is/hail/expr/ir/lowering/LowerDistributedSortSuite.scala b/hail/src/test/scala/is/hail/expr/ir/lowering/LowerDistributedSortSuite.scala index fb296ad605b..7c5dd3eb33d 100644 --- a/hail/src/test/scala/is/hail/expr/ir/lowering/LowerDistributedSortSuite.scala +++ b/hail/src/test/scala/is/hail/expr/ir/lowering/LowerDistributedSortSuite.scala @@ -61,9 +61,9 @@ class LowerDistributedSortSuite extends HailSuite { // Only does ascending for now def testDistributedSortHelper(myTable: TableIR, sortFields: IndexedSeq[SortField]): Unit = { - val originalShuffleCutoff = backend.getFlag("shuffle_cutoff_to_local_sort") + val originalShuffleCutoff = backend.flags.get("shuffle_cutoff_to_local_sort") try { - backend.setFlag("shuffle_cutoff_to_local_sort", "40") + backend.flags.set("shuffle_cutoff_to_local_sort", "40") val analyses: LoweringAnalyses = LoweringAnalyses.apply(myTable, ctx) val rt = 
analyses.requirednessAnalysis.lookup(myTable).asInstanceOf[RTable] val stage = LowerTableIR.applyTable(myTable, DArrayLowering.All, ctx, analyses) @@ -104,7 +104,7 @@ class LowerDistributedSortSuite extends HailSuite { } assert(res == scalaSorted) } finally - backend.setFlag("shuffle_cutoff_to_local_sort", originalShuffleCutoff) + backend.flags.set("shuffle_cutoff_to_local_sort", originalShuffleCutoff) } @Test def testDistributedSort(): Unit = {
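
Design note on the refactor above: the Python-facing (py4j) entry points that were previously duplicated on LocalBackend and SparkBackend (pyExecuteLiteral, pyAddReference, pyRemoveJavaIR, the parse_*_ir helpers, and friends) now live once in the new Py4JBackendExtensions trait. Its `this: Backend =>` self-type lets it call withExecuteContext and read `references`, while each concrete backend supplies the members the trait and Backend leave abstract: `flags`, `longLifeTempFileManager`, and the mutable `references` map passed through the constructor. The following is a minimal, self-contained sketch of that mixin pattern only; CoreBackend, PyExtensions, and ToyLocalBackend are invented stand-ins with string-typed members, not Hail's actual classes.

import scala.collection.mutable

// Simplified stand-in for is.hail.backend.Backend: it owns the mutable
// reference map and the execute-context plumbing, with no py4j entry points.
abstract class CoreBackend {
  def references: mutable.Map[String, String]
  def withExecuteContext[T](f: String => T): T
}

// Simplified stand-in for Py4JBackendExtensions: the self-type means the
// trait can only be mixed into a CoreBackend, so it may call
// withExecuteContext and read `references` without re-implementing them.
trait PyExtensions { this: CoreBackend =>
  private[this] var irID: Int = 0
  val persistedIR: mutable.Map[Int, String] = mutable.Map()

  private[this] def addJavaIR(ir: String): Int = {
    irID += 1
    persistedIR += (irID -> ir)
    irID
  }

  // Analogous to pyExecuteLiteral: evaluate inside a context, persist the
  // result, and hand an integer handle back to the Python side.
  def pyExecuteLiteral(irStr: String): Int =
    withExecuteContext(ctx => addJavaIR(s"[$ctx] $irStr"))

  // Analogous to pyRemoveJavaIR: Python releases a persisted IR by handle.
  def pyRemoveJavaIR(id: Int): Unit = persistedIR.remove(id)

  // Analogous to pyAddReference plus ReferenceGenome.addFatalOnCollision:
  // re-adding an identical reference is a no-op, a conflicting one fails.
  def pyAddReference(name: String, config: String): Unit =
    references.get(name) match {
      case Some(existing) if existing != config =>
        sys.error(s"reference '$name' already exists with a different definition")
      case _ => references.update(name, config)
    }
}

// Each concrete backend supplies its own references map and mixes the trait
// in, mirroring LocalBackend and SparkBackend in the patch above.
class ToyLocalBackend(val references: mutable.Map[String, String])
    extends CoreBackend with PyExtensions {
  def withExecuteContext[T](f: String => T): T = f("local-ctx")
}

object PyExtensionsDemo extends App {
  val backend = new ToyLocalBackend(mutable.Map("GRCh38" -> "builtin"))
  val handle = backend.pyExecuteLiteral("(Literal ...)")
  backend.pyAddReference("my_rg", "custom")
  println((handle, backend.references.keySet)) // handle 1 plus both reference names
  backend.pyRemoveJavaIR(handle)
}

This single shared surface is also why the Python wrapper in py4j_backend.py now calls pyExecuteLiteral, pySetFlag, pyGetFlag, and pyRemoveJavaIR: the trait exposes one py-prefixed API to py4j instead of two per-backend copies.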