From 35182440d3b2db4c0aa823324cbe49cc47626c42 Mon Sep 17 00:00:00 2001 From: Edmund Higham Date: Mon, 16 Sep 2024 20:42:19 -0400 Subject: [PATCH] extract references from `Backend` implementations --- .../main/scala/is/hail/backend/Backend.scala | 100 +------ .../scala/is/hail/backend/BackendServer.scala | 1 + .../is/hail/backend/local/LocalBackend.scala | 106 ++----- .../backend/py4j/Py4JBackendExtensions.scala | 277 ++++++++++++++++++ .../hail/backend/service/ServiceBackend.scala | 53 ++-- .../is/hail/backend/service/Worker.scala | 2 + .../is/hail/backend/spark/SparkBackend.scala | 238 ++------------- .../scala/is/hail/io/reference/LiftOver.scala | 69 ----- .../scala/is/hail/io/reference/package.scala | 164 +++++++++++ .../main/scala/is/hail/io/vcf/LoadVCF.scala | 2 +- .../is/hail/variant/ReferenceGenome.scala | 202 +++---------- hail/src/test/scala/is/hail/HailSuite.scala | 4 - .../is/hail/annotations/UnsafeSuite.scala | 2 +- .../is/hail/backend/ServiceBackendSuite.scala | 1 + .../test/scala/is/hail/expr/ir/IRSuite.scala | 6 +- .../is/hail/expr/ir/RequirednessSuite.scala | 3 +- .../hail/variant/ReferenceGenomeSuite.scala | 10 +- 17 files changed, 585 insertions(+), 655 deletions(-) create mode 100644 hail/src/main/scala/is/hail/backend/py4j/Py4JBackendExtensions.scala delete mode 100644 hail/src/main/scala/is/hail/io/reference/LiftOver.scala create mode 100644 hail/src/main/scala/is/hail/io/reference/package.scala diff --git a/hail/src/main/scala/is/hail/backend/Backend.scala b/hail/src/main/scala/is/hail/backend/Backend.scala index e0444e371e27..d9be76c21806 100644 --- a/hail/src/main/scala/is/hail/backend/Backend.scala +++ b/hail/src/main/scala/is/hail/backend/Backend.scala @@ -1,12 +1,12 @@ package is.hail.backend import is.hail.asm4s._ +import is.hail.backend.Backend.jsonToBytes import is.hail.backend.spark.SparkBackend import is.hail.expr.ir.{ BaseIR, CodeCacheKey, CompiledFunction, IR, IRParser, IRParserEnvironment, LoweringAnalyses, SortField, TableIR, TableReader, } -import is.hail.expr.ir.functions.IRFunctionRegistry import is.hail.expr.ir.lowering.{TableStage, TableStageDependency} import is.hail.io.{BufferSpec, TypedCodecSpec} import is.hail.io.fs._ @@ -20,7 +20,6 @@ import is.hail.types.virtual.{BlockMatrixType, TFloat64} import is.hail.utils._ import is.hail.variant.ReferenceGenome -import scala.collection.JavaConverters._ import scala.collection.mutable import scala.reflect.ClassTag @@ -29,7 +28,7 @@ import java.nio.charset.StandardCharsets import com.fasterxml.jackson.core.StreamReadConstraints import org.json4s._ -import org.json4s.jackson.{JsonMethods, Serialization} +import org.json4s.jackson.JsonMethods import sourcecode.Enclosing object Backend { @@ -41,13 +40,6 @@ object Backend { s"hail_query_$id" } - private var irID: Int = 0 - - def nextIRID(): Int = { - irID += 1 - irID - } - def encodeToOutputStream( ctx: ExecuteContext, t: PTuple, @@ -66,6 +58,9 @@ object Backend { assert(t.isFieldDefined(off, 0)) codec.encode(ctx, elementType, t.loadField(off, 0), os) } + + def jsonToBytes(f: => JValue): Array[Byte] = + JsonMethods.compact(f).getBytes(StandardCharsets.UTF_8) } abstract class BroadcastValue[T] { def value: T } @@ -89,14 +84,6 @@ abstract class Backend extends Closeable { val persistedIR: mutable.Map[Int, BaseIR] = mutable.Map() - protected[this] def addJavaIR(ir: BaseIR): Int = { - val id = Backend.nextIRID() - persistedIR += (id -> ir) - id - } - - def removeJavaIR(id: Int): Unit = persistedIR.remove(id) - def defaultParallelism: Int def 
canExecuteParallelTasksOnDriver: Boolean = true @@ -133,31 +120,6 @@ abstract class Backend extends Closeable { def lookupOrCompileCachedFunction[T](k: CodeCacheKey)(f: => CompiledFunction[T]) : CompiledFunction[T] - var references: Map[String, ReferenceGenome] = Map.empty - - def addDefaultReferences(): Unit = - references = ReferenceGenome.builtinReferences() - - def addReference(rg: ReferenceGenome): Unit = { - references.get(rg.name) match { - case Some(rg2) => - if (rg != rg2) { - fatal( - s"Cannot add reference genome '${rg.name}', a different reference with that name already exists. Choose a reference name NOT in the following list:\n " + - s"@1", - references.keys.truncatable("\n "), - ) - } - case None => - references += (rg.name -> rg) - } - } - - def hasReference(name: String) = references.contains(name) - - def removeReference(name: String): Unit = - references -= name - def lowerDistributedSort( ctx: ExecuteContext, stage: TableStage, @@ -191,9 +153,6 @@ abstract class Backend extends Closeable { def withExecuteContext[T](f: ExecuteContext => T)(implicit E: Enclosing): T - private[this] def jsonToBytes(f: => JValue): Array[Byte] = - JsonMethods.compact(f).getBytes(StandardCharsets.UTF_8) - final def valueType(s: String): Array[Byte] = jsonToBytes { withExecuteContext { ctx => @@ -222,15 +181,7 @@ abstract class Backend extends Closeable { } } - def loadReferencesFromDataset(path: String): Array[Byte] = { - withExecuteContext { ctx => - val rgs = ReferenceGenome.fromHailDataset(ctx.fs, path) - rgs.foreach(addReference) - - implicit val formats: Formats = defaultJSONFormats - Serialization.write(rgs.map(_.toJSON).toFastSeq).getBytes(StandardCharsets.UTF_8) - } - } + def loadReferencesFromDataset(path: String): Array[Byte] def fromFASTAFile( name: String, @@ -242,18 +193,20 @@ abstract class Backend extends Closeable { parInput: Array[String], ): Array[Byte] = withExecuteContext { ctx => - val rg = ReferenceGenome.fromFASTAFile(ctx, name, fastaFile, indexFile, - xContigs, yContigs, mtContigs, parInput) - rg.toJSONString.getBytes(StandardCharsets.UTF_8) + jsonToBytes { + ReferenceGenome.fromFASTAFile(ctx, name, fastaFile, indexFile, + xContigs, yContigs, mtContigs, parInput).toJSON + } } - def parseVCFMetadata(path: String): Array[Byte] = jsonToBytes { + def parseVCFMetadata(path: String): Array[Byte] = withExecuteContext { ctx => - val metadata = LoadVCF.parseHeaderMetadata(ctx.fs, Set.empty, TFloat64, path) - implicit val formats = defaultJSONFormats - Extraction.decompose(metadata) + jsonToBytes { + val metadata = LoadVCF.parseHeaderMetadata(ctx.fs, Set.empty, TFloat64, path) + implicit val formats = defaultJSONFormats + Extraction.decompose(metadata) + } } - } def importFam(path: String, isQuantPheno: Boolean, delimiter: String, missingValue: String) : Array[Byte] = @@ -263,27 +216,6 @@ abstract class Backend extends Closeable { ) } - def pyRegisterIR( - name: String, - typeParamStrs: java.util.ArrayList[String], - argNameStrs: java.util.ArrayList[String], - argTypeStrs: java.util.ArrayList[String], - returnType: String, - bodyStr: String, - ): Unit = { - withExecuteContext { ctx => - IRFunctionRegistry.registerIR( - ctx, - name, - typeParamStrs.asScala.toArray, - argNameStrs.asScala.toArray, - argTypeStrs.asScala.toArray, - returnType, - bodyStr, - ) - } - } - def execute(ctx: ExecuteContext, ir: IR): Either[Unit, (PTuple, Long)] } diff --git a/hail/src/main/scala/is/hail/backend/BackendServer.scala b/hail/src/main/scala/is/hail/backend/BackendServer.scala index 
7ce224548c98..04ab2200641e 100644 --- a/hail/src/main/scala/is/hail/backend/BackendServer.scala +++ b/hail/src/main/scala/is/hail/backend/BackendServer.scala @@ -112,6 +112,7 @@ class BackendHttpHandler(backend: Backend) extends HttpHandler { } return } + val response: Array[Byte] = exchange.getRequestURI.getPath match { case "/value/type" => backend.valueType(body.extract[IRTypePayload].ir) case "/table/type" => backend.tableType(body.extract[IRTypePayload].ir) diff --git a/hail/src/main/scala/is/hail/backend/local/LocalBackend.scala b/hail/src/main/scala/is/hail/backend/local/LocalBackend.scala index ffd1c7f75791..aa995fdc677a 100644 --- a/hail/src/main/scala/is/hail/backend/local/LocalBackend.scala +++ b/hail/src/main/scala/is/hail/backend/local/LocalBackend.scala @@ -4,8 +4,9 @@ import is.hail.{CancellingExecutorService, HailContext, HailFeatureFlags} import is.hail.annotations.Region import is.hail.asm4s._ import is.hail.backend._ +import is.hail.backend.py4j.Py4JBackendExtensions import is.hail.expr.Validate -import is.hail.expr.ir.{IRParser, _} +import is.hail.expr.ir._ import is.hail.expr.ir.analyses.SemanticHash import is.hail.expr.ir.lowering._ import is.hail.io.fs._ @@ -17,7 +18,7 @@ import is.hail.types.virtual.{BlockMatrixType, TVoid} import is.hail.utils._ import is.hail.variant.ReferenceGenome -import scala.collection.JavaConverters._ +import scala.collection.mutable import scala.reflect.ClassTag import java.io.PrintWriter @@ -47,8 +48,11 @@ object LocalBackend { if (!skipLoggingConfiguration) HailContext.configureLogging(logFile, quiet, append) - theLocalBackend = new LocalBackend(tmpdir) - theLocalBackend.addDefaultReferences() + theLocalBackend = new LocalBackend( + tmpdir, + mutable.Map(ReferenceGenome.builtinReferences().toSeq: _*), + ) + theLocalBackend } @@ -64,18 +68,16 @@ object LocalBackend { } } -class LocalBackend(val tmpdir: String) extends Backend with BackendWithCodeCache { - - private[this] val flags = HailFeatureFlags.fromEnv() - private[this] val theHailClassLoader = new HailClassLoader(getClass().getClassLoader()) +class LocalBackend( + val tmpdir: String, + override val references: mutable.Map[String, ReferenceGenome], +) extends Backend with BackendWithCodeCache with Py4JBackendExtensions { - def getFlag(name: String): String = flags.get(name) + override def backend: Backend = this + override val flags: HailFeatureFlags = HailFeatureFlags.fromEnv() + override def longLifeTempFileManager: TempFileManager = null - def setFlag(name: String, value: String) = flags.set(name, value) - - // called from python - val availableFlags: java.util.ArrayList[String] = - flags.available + private[this] val theHailClassLoader = new HailClassLoader(getClass().getClassLoader()) // flags can be set after construction from python def fs: FS = RouterFS.buildRoutes(CloudStorageFSConfig.fromFlagsAndEnv(None, flags)) @@ -87,7 +89,7 @@ class LocalBackend(val tmpdir: String) extends Backend with BackendWithCodeCache tmpdir, tmpdir, this, - references, + references.toMap, fs, timer, null, @@ -191,80 +193,6 @@ class LocalBackend(val tmpdir: String) extends Backend with BackendWithCodeCache res } - def executeLiteral(irStr: String): Int = - withExecuteContext { ctx => - val ir = IRParser.parse_value_ir(irStr, IRParserEnvironment(ctx, persistedIR.toMap)) - execute(ctx, ir) match { - case Left(_) => throw new HailException("Can't create literal") - case Right((pt, addr)) => - val field = GetFieldByIdx(EncodedLiteral.fromPTypeAndAddress(pt, addr, ctx), 0) - addJavaIR(field) - } - } - - 
def pyAddReference(jsonConfig: String): Unit = addReference(ReferenceGenome.fromJSON(jsonConfig)) - def pyRemoveReference(name: String): Unit = removeReference(name) - - def pyAddLiftover(name: String, chainFile: String, destRGName: String): Unit = - withExecuteContext(ctx => references(name).addLiftover(ctx, chainFile, destRGName)) - - def pyRemoveLiftover(name: String, destRGName: String) = - references(name).removeLiftover(destRGName) - - def pyFromFASTAFile( - name: String, - fastaFile: String, - indexFile: String, - xContigs: java.util.List[String], - yContigs: java.util.List[String], - mtContigs: java.util.List[String], - parInput: java.util.List[String], - ): String = - withExecuteContext { ctx => - val rg = ReferenceGenome.fromFASTAFile( - ctx, - name, - fastaFile, - indexFile, - xContigs.asScala.toArray, - yContigs.asScala.toArray, - mtContigs.asScala.toArray, - parInput.asScala.toArray, - ) - rg.toJSONString - } - - def pyAddSequence(name: String, fastaFile: String, indexFile: String): Unit = - withExecuteContext(ctx => references(name).addSequence(ctx, fastaFile, indexFile)) - - def pyRemoveSequence(name: String) = references(name).removeSequence() - - def parse_value_ir(s: String, refMap: java.util.Map[String, String]): IR = - withExecuteContext { ctx => - IRParser.parse_value_ir( - s, - IRParserEnvironment(ctx, persistedIR.toMap), - BindingEnv.eval(refMap.asScala.toMap.map { case (n, t) => - Name(n) -> IRParser.parseType(t) - }.toSeq: _*), - ) - } - - def parse_table_ir(s: String): TableIR = - withExecuteContext { ctx => - IRParser.parse_table_ir(s, IRParserEnvironment(ctx, irMap = persistedIR.toMap)) - } - - def parse_matrix_ir(s: String): MatrixIR = - withExecuteContext { ctx => - IRParser.parse_matrix_ir(s, IRParserEnvironment(ctx, irMap = persistedIR.toMap)) - } - - def parse_blockmatrix_ir(s: String): BlockMatrixIR = - withExecuteContext { ctx => - IRParser.parse_blockmatrix_ir(s, IRParserEnvironment(ctx, irMap = persistedIR.toMap)) - } - override def lowerDistributedSort( ctx: ExecuteContext, stage: TableStage, diff --git a/hail/src/main/scala/is/hail/backend/py4j/Py4JBackendExtensions.scala b/hail/src/main/scala/is/hail/backend/py4j/Py4JBackendExtensions.scala new file mode 100644 index 000000000000..b431456b8227 --- /dev/null +++ b/hail/src/main/scala/is/hail/backend/py4j/Py4JBackendExtensions.scala @@ -0,0 +1,277 @@ +package is.hail.backend.py4j + +import is.hail.HailFeatureFlags +import is.hail.backend.{Backend, ExecuteContext, NonOwningTempFileManager, TempFileManager} +import is.hail.expr.{JSONAnnotationImpex, SparkAnnotationImpex} +import is.hail.expr.ir.{ + BaseIR, BindingEnv, BlockMatrixIR, EncodedLiteral, GetFieldByIdx, IR, IRParser, + IRParserEnvironment, Interpret, MatrixIR, MatrixNativeReader, MatrixRead, Name, + NativeReaderOptions, TableIR, TableLiteral, TableValue, +} +import is.hail.expr.ir.IRParser.parseType +import is.hail.expr.ir.functions.IRFunctionRegistry +import is.hail.io.reference.{IndexedFastaSequenceFile, LiftOver} +import is.hail.linalg.RowMatrix +import is.hail.types.physical.PStruct +import is.hail.types.virtual.{TArray, TInterval, TStruct} +import is.hail.utils.{defaultJSONFormats, fatal, log, toRichIterable, HailException, Interval} +import is.hail.variant.ReferenceGenome + +import scala.collection.mutable +import scala.jdk.CollectionConverters.{ + asScalaBufferConverter, mapAsScalaMapConverter, seqAsJavaListConverter, +} + +import java.nio.charset.StandardCharsets +import java.util + +import org.apache.spark.sql.DataFrame +import 
org.json4s +import org.json4s.Formats +import org.json4s.jackson.{JsonMethods, Serialization} +import sourcecode.Enclosing + +trait Py4JBackendExtensions { + def backend: Backend + def references: mutable.Map[String, ReferenceGenome] + def persistedIR: mutable.Map[Int, BaseIR] + def flags: HailFeatureFlags + def longLifeTempFileManager: TempFileManager + + def getFlag(name: String): String = + flags.get(name) + + def setFlag(name: String, value: String): Unit = + flags.set(name, value) + + val availableFlags: java.util.ArrayList[String] = + flags.available + + private[this] var irID: Int = 0 + + def nextIRID(): Int = + synchronized { + irID += 1 + irID + } + + protected[this] def addJavaIR(ir: BaseIR): Int = { + val id = nextIRID() + persistedIR += (id -> ir) + id + } + + def removeJavaIR(id: Int): Unit = + persistedIR.remove(id) + + def pyAddSequence(name: String, fastaFile: String, indexFile: String): Unit = + backend.withExecuteContext { ctx => + references(name).addSequence(IndexedFastaSequenceFile(ctx.fs, fastaFile, indexFile)) + } + + def pyRemoveSequence(name: String): Unit = + references(name).removeSequence() + + def pyExportBlockMatrix( + pathIn: String, + pathOut: String, + delimiter: String, + header: String, + addIndex: Boolean, + exportType: String, + partitionSize: java.lang.Integer, + entries: String, + ): Unit = + backend.withExecuteContext { ctx => + val rm = RowMatrix.readBlockMatrix(ctx.fs, pathIn, partitionSize) + entries match { + case "full" => + rm.export(ctx, pathOut, delimiter, Option(header), addIndex, exportType) + case "lower" => + rm.exportLowerTriangle(ctx, pathOut, delimiter, Option(header), addIndex, exportType) + case "strict_lower" => + rm.exportStrictLowerTriangle( + ctx, + pathOut, + delimiter, + Option(header), + addIndex, + exportType, + ) + case "upper" => + rm.exportUpperTriangle(ctx, pathOut, delimiter, Option(header), addIndex, exportType) + case "strict_upper" => + rm.exportStrictUpperTriangle( + ctx, + pathOut, + delimiter, + Option(header), + addIndex, + exportType, + ) + } + } + + def pyRegisterIR( + name: String, + typeParamStrs: java.util.ArrayList[String], + argNameStrs: java.util.ArrayList[String], + argTypeStrs: java.util.ArrayList[String], + returnType: String, + bodyStr: String, + ): Unit = { + backend.withExecuteContext { ctx => + IRFunctionRegistry.registerIR( + ctx, + name, + typeParamStrs.asScala.toArray, + argNameStrs.asScala.toArray, + argTypeStrs.asScala.toArray, + returnType, + bodyStr, + ) + } + } + + def executeLiteral(irStr: String): Int = + backend.withExecuteContext { ctx => + val ir = IRParser.parse_value_ir(irStr, IRParserEnvironment(ctx, persistedIR.toMap)) + backend.execute(ctx, ir) match { + case Left(_) => throw new HailException("Can't create literal") + case Right((pt, addr)) => + val field = GetFieldByIdx(EncodedLiteral.fromPTypeAndAddress(pt, addr, ctx), 0) + addJavaIR(field) + } + } + + def pyFromDF(df: DataFrame, jKey: java.util.List[String]): (Int, String) = { + val key = jKey.asScala.toArray.toFastSeq + val signature = + SparkAnnotationImpex.importType(df.schema).setRequired(true).asInstanceOf[PStruct] + withExecuteContext(selfContainedExecution = false) { ctx => + val tir = TableLiteral( + TableValue( + ctx, + signature.virtualType.asInstanceOf[TStruct], + key, + df.rdd, + Some(signature), + ), + ctx.theHailClassLoader, + ) + val id = addJavaIR(tir) + (id, JsonMethods.compact(tir.typ.toJSON)) + } + } + + def pyToDF(s: String): DataFrame = + backend.withExecuteContext { ctx => + val tir = 
IRParser.parse_table_ir(s, IRParserEnvironment(ctx, irMap = persistedIR.toMap)) + Interpret(tir, ctx).toDF() + } + + def pyReadMultipleMatrixTables(jsonQuery: String): util.List[MatrixIR] = + backend.withExecuteContext { ctx => + log.info("pyReadMultipleMatrixTables: got query") + val kvs = JsonMethods.parse(jsonQuery) match { + case json4s.JObject(values) => values.toMap + } + + val paths = + kvs("paths").asInstanceOf[json4s.JArray].arr.toArray.map { case json4s.JString(s) => + s + } + + val intervalPointType = parseType(kvs("intervalPointType").asInstanceOf[json4s.JString].s) + val intervalObjects = + JSONAnnotationImpex.importAnnotation(kvs("intervals"), TArray(TInterval(intervalPointType))) + .asInstanceOf[IndexedSeq[Interval]] + + val opts = NativeReaderOptions(intervalObjects, intervalPointType, filterIntervals = false) + val matrixReaders: IndexedSeq[MatrixIR] = paths.map { p => + log.info(s"creating MatrixRead node for $p") + val mnr = MatrixNativeReader(ctx.fs, p, Some(opts)) + MatrixRead(mnr.fullMatrixTypeWithoutUIDs, false, false, mnr): MatrixIR + } + log.info("pyReadMultipleMatrixTables: returning N matrix tables") + matrixReaders.asJava + } + + def pyAddReference(jsonConfig: String): Unit = + addReference(ReferenceGenome.fromJSON(jsonConfig)) + + def pyRemoveReference(name: String): Unit = + removeReference(name) + + def pyAddLiftover(name: String, chainFile: String, destRGName: String): Unit = + backend.withExecuteContext { ctx => + references(name).addLiftover(references(destRGName), LiftOver(ctx.fs, chainFile)) + } + + def pyRemoveLiftover(name: String, destRGName: String) = + references(name).removeLiftover(destRGName) + + def addReference(rg: ReferenceGenome): Unit = { + references.get(rg.name) match { + case Some(rg2) => + if (rg != rg2) { + fatal( + s"Cannot add reference genome '${rg.name}', a different reference with that name already exists. 
Choose a reference name NOT in the following list:\n " + + s"@1", + references.keys.truncatable("\n "), + ) + } + case None => + references += (rg.name -> rg) + } + } + + def removeReference(name: String): Unit = + references -= name + + def parse_value_ir(s: String, refMap: java.util.Map[String, String]): IR = + backend.withExecuteContext { ctx => + IRParser.parse_value_ir( + s, + IRParserEnvironment(ctx, irMap = persistedIR.toMap), + BindingEnv.eval(refMap.asScala.toMap.map { case (n, t) => + Name(n) -> IRParser.parseType(t) + }.toSeq: _*), + ) + } + + def parse_table_ir(s: String): TableIR = + withExecuteContext(selfContainedExecution = false) { ctx => + IRParser.parse_table_ir(s, IRParserEnvironment(ctx, irMap = persistedIR.toMap)) + } + + def parse_matrix_ir(s: String): MatrixIR = + withExecuteContext(selfContainedExecution = false) { ctx => + IRParser.parse_matrix_ir(s, IRParserEnvironment(ctx, irMap = persistedIR.toMap)) + } + + def parse_blockmatrix_ir(s: String): BlockMatrixIR = + withExecuteContext(selfContainedExecution = false) { ctx => + IRParser.parse_blockmatrix_ir(s, IRParserEnvironment(ctx, irMap = persistedIR.toMap)) + } + + def loadReferencesFromDataset(path: String): Array[Byte] = + backend.withExecuteContext { ctx => + val rgs = ReferenceGenome.fromHailDataset(ctx.fs, path) + rgs.foreach(addReference) + + implicit val formats: Formats = defaultJSONFormats + Serialization.write(rgs.map(_.toJSON).toFastSeq).getBytes(StandardCharsets.UTF_8) + } + + def withExecuteContext[T]( + selfContainedExecution: Boolean = true + )( + f: ExecuteContext => T + )(implicit E: Enclosing + ): T = + backend.withExecuteContext { ctx => + if (selfContainedExecution && longLifeTempFileManager != null) f(ctx) + else ctx.local(tempFileManager = NonOwningTempFileManager(longLifeTempFileManager))(f) + } +} diff --git a/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala b/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala index 95118f28dc91..ba57d7599883 100644 --- a/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala +++ b/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala @@ -6,13 +6,14 @@ import is.hail.asm4s._ import is.hail.backend._ import is.hail.expr.Validate import is.hail.expr.ir.{ - Compile, IR, IRParser, IRParserEnvironment, IRSize, LoweringAnalyses, MakeTuple, SortField, + Compile, IR, IRParser, IRSize, LoweringAnalyses, MakeTuple, SortField, TableIR, TableReader, TypeCheck, } import is.hail.expr.ir.analyses.SemanticHash import is.hail.expr.ir.functions.IRFunctionRegistry import is.hail.expr.ir.lowering._ import is.hail.io.fs._ +import is.hail.io.reference.{IndexedFastaSequenceFile, LiftOver} import is.hail.linalg.BlockMatrix import is.hail.services.{BatchClient, _} import is.hail.services.JobGroupStates.Failure @@ -24,6 +25,7 @@ import is.hail.utils._ import is.hail.variant.ReferenceGenome import scala.annotation.switch +import scala.collection.mutable import scala.reflect.ClassTag import java.io._ @@ -34,7 +36,7 @@ import java.util.concurrent._ import org.apache.log4j.Logger import org.json4s.{DefaultFormats, Formats} import org.json4s.JsonAST._ -import org.json4s.jackson.JsonMethods +import org.json4s.jackson.{JsonMethods, Serialization} import sourcecode.Enclosing class ServiceBackendContext( @@ -86,10 +88,26 @@ object ServiceBackend { ExecutionCache.fromFlags(flags, fs, rpcConfig.remote_tmpdir), ) - val backend = new ServiceBackend( + val references = mutable.Map.empty[String, ReferenceGenome] + references ++= 
ReferenceGenome.builtinReferences() + rpcConfig.custom_references.map(ReferenceGenome.fromJSON).foreach { r => + references += (r.name -> r) + } + + rpcConfig.liftovers.foreach { case (sourceGenome, liftoversForSource) => + liftoversForSource.foreach { case (destGenome, chainFile) => + references(sourceGenome).addLiftover(references(destGenome), LiftOver(fs, chainFile)) + } + } + rpcConfig.sequences.foreach { case (rg, seq) => + references(rg).addSequence(IndexedFastaSequenceFile(fs, seq.fasta, seq.index)) + } + + new ServiceBackend( jarLocation, name, theHailClassLoader, + references.toMap, batchClient, batchId, jobGroupId, @@ -99,19 +117,6 @@ object ServiceBackend { backendContext, scratchDir, ) - backend.addDefaultReferences() - - rpcConfig.custom_references.foreach(s => backend.addReference(ReferenceGenome.fromJSON(s))) - rpcConfig.liftovers.foreach { case (sourceGenome, liftoversForSource) => - liftoversForSource.foreach { case (destGenome, chainFile) => - backend.addLiftover(sourceGenome, chainFile, destGenome) - } - } - rpcConfig.sequences.foreach { case (rg, seq) => - backend.addSequence(rg, seq.fasta, seq.index) - } - - backend } } @@ -119,6 +124,7 @@ class ServiceBackend( val jarLocation: String, var name: String, val theHailClassLoader: HailClassLoader, + val references: Map[String, ReferenceGenome], val batchClient: BatchClient, val curBatchId: Option[Int], val curJobGroupId: Option[Int], @@ -394,12 +400,12 @@ class ServiceBackend( )(f) } - def addLiftover(name: String, chainFile: String, destRGName: String): Unit = - withExecuteContext(ctx => references(name).addLiftover(ctx, chainFile, destRGName)) - - def addSequence(name: String, fastaFile: String, indexFile: String): Unit = - withExecuteContext(ctx => references(name).addSequence(ctx, fastaFile, indexFile)) - + override def loadReferencesFromDataset(path: String): Array[Byte] = + withExecuteContext { ctx => + val rgs = ReferenceGenome.fromHailDataset(ctx.fs, path) + implicit val formats: Formats = defaultJSONFormats + Serialization.write(rgs.map(_.toJSON).toFastSeq).getBytes(StandardCharsets.UTF_8) + } } class EndOfInputException extends RuntimeException @@ -577,8 +583,7 @@ class ServiceBackendAPI( val code = qobExecutePayload.payload.ir backend.withExecuteContext { ctx => withIRFunctionsReadFromInput(qobExecutePayload.functions, ctx) { () => - val ir = - IRParser.parse_value_ir(code, IRParserEnvironment(ctx, backend.persistedIR.toMap)) + val ir = IRParser.parse_value_ir(ctx, code) backend.execute(ctx, ir) match { case Left(()) => Array() diff --git a/hail/src/main/scala/is/hail/backend/service/Worker.scala b/hail/src/main/scala/is/hail/backend/service/Worker.scala index de608b265151..e40364c8fdaa 100644 --- a/hail/src/main/scala/is/hail/backend/service/Worker.scala +++ b/hail/src/main/scala/is/hail/backend/service/Worker.scala @@ -172,6 +172,7 @@ object Worker { null, new HailClassLoader(getClass().getClassLoader()), null, + null, None, None, null, @@ -188,6 +189,7 @@ object Worker { null, new HailClassLoader(getClass().getClassLoader()), null, + null, None, None, null, diff --git a/hail/src/main/scala/is/hail/backend/spark/SparkBackend.scala b/hail/src/main/scala/is/hail/backend/spark/SparkBackend.scala index 9dc72e7b7a7f..469825e75afb 100644 --- a/hail/src/main/scala/is/hail/backend/spark/SparkBackend.scala +++ b/hail/src/main/scala/is/hail/backend/spark/SparkBackend.scala @@ -4,16 +4,15 @@ import is.hail.{HailContext, HailFeatureFlags} import is.hail.annotations._ import is.hail.asm4s._ import is.hail.backend._ -import 
is.hail.expr.{JSONAnnotationImpex, SparkAnnotationImpex, Validate} -import is.hail.expr.ir.{IRParser, _} -import is.hail.expr.ir.IRParser.parseType +import is.hail.backend.py4j.Py4JBackendExtensions +import is.hail.expr.Validate +import is.hail.expr.ir._ import is.hail.expr.ir.analyses.SemanticHash import is.hail.expr.ir.lowering._ import is.hail.io.{BufferSpec, TypedCodecSpec} import is.hail.io.fs._ -import is.hail.linalg.{BlockMatrix, RowMatrix} +import is.hail.linalg.BlockMatrix import is.hail.rvd.RVD -import is.hail.stats.LinearMixedModel import is.hail.types._ import is.hail.types.physical.{PStruct, PTuple} import is.hail.types.physical.stypes.PTypeReferenceSingleCodeType @@ -21,23 +20,20 @@ import is.hail.types.virtual._ import is.hail.utils._ import is.hail.variant.ReferenceGenome -import scala.collection.JavaConverters._ import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import scala.concurrent.ExecutionException import scala.reflect.ClassTag import scala.util.control.NonFatal -import java.io.{Closeable, PrintWriter} +import java.io.PrintWriter import org.apache.hadoop import org.apache.hadoop.conf.Configuration import org.apache.spark._ import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, SparkSession} -import org.json4s -import org.json4s.jackson.JsonMethods +import org.apache.spark.sql.SparkSession import sourcecode.Enclosing class SparkBroadcastValue[T](bc: Broadcast[T]) extends BroadcastValue[T] with Serializable { @@ -274,8 +270,14 @@ object SparkBackend { sc1.uiWebUrl.foreach(ui => info(s"SparkUI: $ui")) theSparkBackend = - new SparkBackend(tmpdir, localTmpdir, sc1, gcsRequesterPaysProject, gcsRequesterPaysBuckets) - theSparkBackend.addDefaultReferences() + new SparkBackend( + tmpdir, + localTmpdir, + sc1, + mutable.Map(ReferenceGenome.builtinReferences().toSeq: _*), + gcsRequesterPaysProject, + gcsRequesterPaysBuckets, + ) theSparkBackend } @@ -302,9 +304,11 @@ class SparkBackend( val tmpdir: String, val localTmpdir: String, val sc: SparkContext, + override val references: mutable.Map[String, ReferenceGenome], gcsRequesterPaysProject: String, gcsRequesterPaysBuckets: String, -) extends Backend with Closeable with BackendWithCodeCache { +) extends Backend with BackendWithCodeCache with Py4JBackendExtensions { + assert(gcsRequesterPaysProject != null || gcsRequesterPaysBuckets == null) lazy val sparkSession: SparkSession = SparkSession.builder().config(sc.getConf).getOrCreate() @@ -328,17 +332,13 @@ class SparkBackend( new HadoopFS(new SerializableHadoopConfiguration(conf)) } - private[this] val longLifeTempFileManager: TempFileManager = new OwningTempFileManager(fs) - - val bmCache: SparkBlockMatrixCache = SparkBlockMatrixCache() - - private[this] val flags = HailFeatureFlags.fromEnv() + override def backend: Backend = this + override val flags: HailFeatureFlags = HailFeatureFlags.fromEnv() - def getFlag(name: String): String = flags.get(name) + override val longLifeTempFileManager: TempFileManager = + new OwningTempFileManager(fs) - def setFlag(name: String, value: String) = flags.set(name, value) - - val availableFlags: java.util.ArrayList[String] = flags.available + val bmCache: SparkBlockMatrixCache = SparkBlockMatrixCache() def persist(backendContext: BackendContext, id: String, value: BlockMatrix, storageLevel: String) : Unit = bmCache.persistBlockMatrix(id, value, storageLevel) @@ -362,7 +362,7 @@ class SparkBackend( tmpdir, localTmpdir, this, - references, + references.toMap, fs, 
region, timer, @@ -376,34 +376,16 @@ class SparkBackend( IrMetadata(None), ) - def withExecuteContext[T]( - selfContainedExecution: Boolean = true - )( - f: ExecuteContext => T - )(implicit E: Enclosing - ): T = - withExecuteContext( - if (selfContainedExecution) null else NonOwningTempFileManager(longLifeTempFileManager) - )(f) - override def withExecuteContext[T](f: ExecuteContext => T)(implicit E: Enclosing): T = - withExecuteContext(null.asInstanceOf[TempFileManager])(f) - - def withExecuteContext[T]( - tmpFileManager: TempFileManager - )( - f: ExecuteContext => T - )(implicit E: Enclosing - ): T = ExecutionTimer.logTime { timer => ExecuteContext.scoped( tmpdir, localTmpdir, this, - references, + references.toMap, fs, timer, - tmpFileManager, + null, theHailClassLoader, flags, new BackendContext { @@ -448,7 +430,7 @@ class SparkBackend( } } - val chunkSize = getFlag(SparkBackend.Flags.MaxStageParallelism).toInt + val chunkSize = flags.get(SparkBackend.Flags.MaxStageParallelism).toInt val partsToRun = partitions.getOrElse(contexts.indices) val buffer = new ArrayBuffer[(Array[Byte], Int)](partsToRun.length) var failure: Option[Throwable] = None @@ -559,174 +541,6 @@ class SparkBackend( } } - def executeLiteral(irStr: String): Int = - withExecuteContext { ctx => - val ir = IRParser.parse_value_ir(irStr, IRParserEnvironment(ctx, persistedIR.toMap)) - execute(ctx, ir) match { - case Left(_) => throw new HailException("Can't create literal") - case Right((pt, addr)) => - val field = GetFieldByIdx(EncodedLiteral.fromPTypeAndAddress(pt, addr, ctx), 0) - addJavaIR(field) - } - } - - def pyFromDF(df: DataFrame, jKey: java.util.List[String]): (Int, String) = { - val key = jKey.asScala.toArray.toFastSeq - val signature = - SparkAnnotationImpex.importType(df.schema).setRequired(true).asInstanceOf[PStruct] - withExecuteContext(selfContainedExecution = false) { ctx => - val tir = TableLiteral( - TableValue( - ctx, - signature.virtualType.asInstanceOf[TStruct], - key, - df.rdd, - Some(signature), - ), - ctx.theHailClassLoader, - ) - val id = addJavaIR(tir) - (id, JsonMethods.compact(tir.typ.toJSON)) - } - } - - def pyToDF(s: String): DataFrame = - withExecuteContext(selfContainedExecution = false) { ctx => - val tir = IRParser.parse_table_ir(s, IRParserEnvironment(ctx, irMap = persistedIR.toMap)) - Interpret(tir, ctx).toDF() - } - - def pyReadMultipleMatrixTables(jsonQuery: String): java.util.List[MatrixIR] = { - log.info("pyReadMultipleMatrixTables: got query") - val kvs = JsonMethods.parse(jsonQuery) match { - case json4s.JObject(values) => values.toMap - } - - val paths = kvs("paths").asInstanceOf[json4s.JArray].arr.toArray.map { case json4s.JString(s) => - s - } - - val intervalPointType = parseType(kvs("intervalPointType").asInstanceOf[json4s.JString].s) - val intervalObjects = - JSONAnnotationImpex.importAnnotation(kvs("intervals"), TArray(TInterval(intervalPointType))) - .asInstanceOf[IndexedSeq[Interval]] - - val opts = NativeReaderOptions(intervalObjects, intervalPointType, filterIntervals = false) - val matrixReaders: IndexedSeq[MatrixIR] = paths.map { p => - log.info(s"creating MatrixRead node for $p") - val mnr = MatrixNativeReader(fs, p, Some(opts)) - MatrixRead(mnr.fullMatrixTypeWithoutUIDs, false, false, mnr): MatrixIR - } - log.info("pyReadMultipleMatrixTables: returning N matrix tables") - matrixReaders.asJava - } - - def pyAddReference(jsonConfig: String): Unit = addReference(ReferenceGenome.fromJSON(jsonConfig)) - def pyRemoveReference(name: String): Unit = removeReference(name) - - 
def pyAddLiftover(name: String, chainFile: String, destRGName: String): Unit = - withExecuteContext(ctx => references(name).addLiftover(ctx, chainFile, destRGName)) - - def pyRemoveLiftover(name: String, destRGName: String) = - references(name).removeLiftover(destRGName) - - def pyFromFASTAFile( - name: String, - fastaFile: String, - indexFile: String, - xContigs: java.util.List[String], - yContigs: java.util.List[String], - mtContigs: java.util.List[String], - parInput: java.util.List[String], - ): String = - withExecuteContext { ctx => - val rg = ReferenceGenome.fromFASTAFile( - ctx, - name, - fastaFile, - indexFile, - xContigs.asScala.toArray, - yContigs.asScala.toArray, - mtContigs.asScala.toArray, - parInput.asScala.toArray, - ) - rg.toJSONString - } - - def pyAddSequence(name: String, fastaFile: String, indexFile: String): Unit = - withExecuteContext(ctx => references(name).addSequence(ctx, fastaFile, indexFile)) - - def pyRemoveSequence(name: String) = references(name).removeSequence() - - def pyExportBlockMatrix( - pathIn: String, - pathOut: String, - delimiter: String, - header: String, - addIndex: Boolean, - exportType: String, - partitionSize: java.lang.Integer, - entries: String, - ): Unit = - withExecuteContext { ctx => - val rm = RowMatrix.readBlockMatrix(fs, pathIn, partitionSize) - entries match { - case "full" => - rm.export(ctx, pathOut, delimiter, Option(header), addIndex, exportType) - case "lower" => - rm.exportLowerTriangle(ctx, pathOut, delimiter, Option(header), addIndex, exportType) - case "strict_lower" => - rm.exportStrictLowerTriangle( - ctx, - pathOut, - delimiter, - Option(header), - addIndex, - exportType, - ) - case "upper" => - rm.exportUpperTriangle(ctx, pathOut, delimiter, Option(header), addIndex, exportType) - case "strict_upper" => - rm.exportStrictUpperTriangle( - ctx, - pathOut, - delimiter, - Option(header), - addIndex, - exportType, - ) - } - } - - def pyFitLinearMixedModel(lmm: LinearMixedModel, pa_t: RowMatrix, a_t: RowMatrix): TableIR = - withExecuteContext(selfContainedExecution = false)(ctx => lmm.fit(ctx, pa_t, Option(a_t))) - - def parse_value_ir(s: String, refMap: java.util.Map[String, String]): IR = - withExecuteContext { ctx => - IRParser.parse_value_ir( - s, - IRParserEnvironment(ctx, irMap = persistedIR.toMap), - BindingEnv.eval(refMap.asScala.toMap.map { case (n, t) => - Name(n) -> IRParser.parseType(t) - }.toSeq: _*), - ) - } - - def parse_table_ir(s: String): TableIR = - withExecuteContext(selfContainedExecution = false) { ctx => - IRParser.parse_table_ir(s, IRParserEnvironment(ctx, irMap = persistedIR.toMap)) - } - - def parse_matrix_ir(s: String): MatrixIR = - withExecuteContext(selfContainedExecution = false) { ctx => - IRParser.parse_matrix_ir(s, IRParserEnvironment(ctx, irMap = persistedIR.toMap)) - } - - def parse_blockmatrix_ir(s: String): BlockMatrixIR = - withExecuteContext(selfContainedExecution = false) { ctx => - IRParser.parse_blockmatrix_ir(s, IRParserEnvironment(ctx, irMap = persistedIR.toMap)) - } - override def lowerDistributedSort( ctx: ExecuteContext, stage: TableStage, diff --git a/hail/src/main/scala/is/hail/io/reference/LiftOver.scala b/hail/src/main/scala/is/hail/io/reference/LiftOver.scala deleted file mode 100644 index c24f81f6d5be..000000000000 --- a/hail/src/main/scala/is/hail/io/reference/LiftOver.scala +++ /dev/null @@ -1,69 +0,0 @@ -package is.hail.io.reference - -import is.hail.io.fs.FS -import is.hail.utils._ -import is.hail.variant.{Locus, ReferenceGenome} - -import scala.collection.JavaConverters._ - 
-object LiftOver { - def apply(fs: FS, chainFile: String): LiftOver = new LiftOver(fs, chainFile) -} - -class LiftOver(fs: FS, val chainFile: String) { - val lo = using(fs.open(chainFile))(new htsjdk.samtools.liftover.LiftOver(_, chainFile)) - - def queryInterval( - interval: is.hail.utils.Interval, - minMatch: Double = htsjdk.samtools.liftover.LiftOver.DEFAULT_LIFTOVER_MINMATCH, - ): (is.hail.utils.Interval, Boolean) = { - val start = interval.start.asInstanceOf[Locus] - val end = interval.end.asInstanceOf[Locus] - - if (start.contig != end.contig) - fatal(s"'start' and 'end' contigs must be identical. Found '$interval'.") - - val contig = start.contig - val startPos = if (interval.includesStart) start.position else start.position + 1 - val endPos = if (interval.includesEnd) end.position else end.position - 1 - - if (startPos == endPos) - fatal( - s"Cannot liftover a 0-length interval: ${interval.toString}.\nDid you mean to use 'liftover_locus'?" - ) - - val result = lo.liftOver(new htsjdk.samtools.util.Interval(contig, startPos, endPos), minMatch) - if (result != null) - ( - Interval( - Locus(result.getContig, result.getStart), - Locus(result.getContig, result.getEnd), - includesStart = true, - includesEnd = true, - ), - result.isNegativeStrand, - ) - else - null - } - - def queryLocus( - l: Locus, - minMatch: Double = htsjdk.samtools.liftover.LiftOver.DEFAULT_LIFTOVER_MINMATCH, - ): (Locus, Boolean) = { - val result = - lo.liftOver(new htsjdk.samtools.util.Interval(l.contig, l.position, l.position), minMatch) - if (result != null) - (Locus(result.getContig, result.getStart), result.isNegativeStrand) - else - null - } - - def checkChainFile(srcRG: ReferenceGenome, destRG: ReferenceGenome): Unit = { - val cMap = lo.getContigMap.asScala - cMap.foreach { case (srcContig, destContigs) => - srcRG.checkContig(srcContig) - destContigs.asScala.foreach(destRG.checkContig) - } - } -} diff --git a/hail/src/main/scala/is/hail/io/reference/package.scala b/hail/src/main/scala/is/hail/io/reference/package.scala new file mode 100644 index 000000000000..819863f2d019 --- /dev/null +++ b/hail/src/main/scala/is/hail/io/reference/package.scala @@ -0,0 +1,164 @@ +package is.hail.io + +import is.hail.io.fs.FS +import is.hail.io.reference.LiftOver.MinMatchDefault +import is.hail.utils.{fatal, toRichIterable, using, Interval} +import is.hail.variant.{Locus, ReferenceGenome} + +import scala.collection.convert.ImplicitConversions.{`collection AsScalaIterable`, `map AsScala`} +import scala.jdk.CollectionConverters.iterableAsScalaIterableConverter + +import htsjdk.samtools.reference.FastaSequenceIndexEntry + +package reference { + + /* ASSUMPTION: The following will not move or change for the entire duration of a hail pipeline: + * - chainFile + * - fastaFile + * - indexFile */ + + object LiftOver { + def apply(fs: FS, chainFile: String): LiftOver = { + val lo = new LiftOver(chainFile) + lo.restore(fs) + lo + } + + val MinMatchDefault: Double = htsjdk.samtools.liftover.LiftOver.DEFAULT_LIFTOVER_MINMATCH + } + + class LiftOver private (chainFile: String) extends Serializable { + + @transient var asJava: htsjdk.samtools.liftover.LiftOver = _ + + def queryInterval(interval: Interval, minMatch: Double = MinMatchDefault) + : (Interval, Boolean) = { + val start = interval.start.asInstanceOf[Locus] + val end = interval.end.asInstanceOf[Locus] + + if (start.contig != end.contig) + fatal(s"'start' and 'end' contigs must be identical. 
Found '$interval'.") + + val contig = start.contig + val startPos = if (interval.includesStart) start.position else start.position + 1 + val endPos = if (interval.includesEnd) end.position else end.position - 1 + + if (startPos == endPos) + fatal( + s"Cannot liftover a 0-length interval: ${interval.toString}.\nDid you mean to use 'liftover_locus'?" + ) + + val result = asJava.liftOver( + new htsjdk.samtools.util.Interval(contig, startPos, endPos), + minMatch, + ) + if (result != null) + ( + Interval( + Locus(result.getContig, result.getStart), + Locus(result.getContig, result.getEnd), + includesStart = true, + includesEnd = true, + ), + result.isNegativeStrand, + ) + else + null + } + + def queryLocus(l: Locus, minMatch: Double = MinMatchDefault): (Locus, Boolean) = { + val result = asJava.liftOver( + new htsjdk.samtools.util.Interval(l.contig, l.position, l.position), + minMatch, + ) + if (result != null) (Locus(result.getContig, result.getStart), result.isNegativeStrand) + else null + } + + def checkChainFile(srcRG: ReferenceGenome, destRG: ReferenceGenome): Unit = + asJava.getContigMap.foreach { case (srcContig, destContigs) => + srcRG.checkContig(srcContig) + destContigs.foreach(destRG.checkContig) + } + + def restore(fs: FS): Unit = { + if (!fs.isFile(chainFile)) + fatal(s"Chain file '$chainFile' does not exist, is not a file, or you do not have access.") + + using(fs.open(chainFile)) { is => + asJava = new htsjdk.samtools.liftover.LiftOver(is, chainFile) + } + } + } + + object IndexedFastaSequenceFile { + + def apply(fs: FS, fastaFile: String, indexFile: String): IndexedFastaSequenceFile = { + if (!fs.isFile(fastaFile)) + fatal(s"FASTA file '$fastaFile' does not exist, is not a file, or you do not have access.") + + new IndexedFastaSequenceFile(fastaFile, FastaSequenceIndex(fs, indexFile)) + } + + } + + class IndexedFastaSequenceFile private (val path: String, val index: FastaSequenceIndex) + extends Serializable { + + def raiseIfIncompatible(rg: ReferenceGenome): Unit = { + val jindex = index.asJava + + val missingContigs = rg.contigs.filterNot(jindex.hasIndexEntry) + if (missingContigs.nonEmpty) + fatal( + s"Contigs missing in FASTA '$path' that are present in reference genome '${rg.name}':\n " + + s"@1", + missingContigs.truncatable("\n "), + ) + + val invalidLengths = + for { + (contig, length) <- rg.lengths + fastaLength = jindex.getIndexEntry(contig).getSize + if fastaLength != length + } yield (contig, length, fastaLength) + + if (invalidLengths.nonEmpty) + fatal( + s"Contig sizes in FASTA '$path' do not match expected sizes for reference genome '${rg.name}':\n " + + s"@1", + invalidLengths.map { case (c, e, f) => s"$c\texpected:$e\tfound:$f" }.truncatable("\n "), + ) + } + + def restore(fs: FS): Unit = + index.restore(fs) + } + + object FastaSequenceIndex { + def apply(fs: FS, indexFile: String): FastaSequenceIndex = { + val index = new FastaSequenceIndex(indexFile) + index.restore(fs) + index + } + } + + class FastaSequenceIndex private (val path: String) + extends Iterable[FastaSequenceIndexEntry] with Serializable { + + @transient var asJava: htsjdk.samtools.reference.FastaSequenceIndex = _ + + def restore(fs: FS): Unit = { + if (!fs.isFile(path)) + fatal( + s"FASTA index file '$path' does not exist, is not a file, or you do not have access." 
+ ) + + using(fs.open(path))(is => asJava = new htsjdk.samtools.reference.FastaSequenceIndex(is)) + } + + override def iterator: Iterator[FastaSequenceIndexEntry] = + asJava.asScala.iterator + } + +} diff --git a/hail/src/main/scala/is/hail/io/vcf/LoadVCF.scala b/hail/src/main/scala/is/hail/io/vcf/LoadVCF.scala index 1887fb233b13..d8193581821d 100644 --- a/hail/src/main/scala/is/hail/io/vcf/LoadVCF.scala +++ b/hail/src/main/scala/is/hail/io/vcf/LoadVCF.scala @@ -2011,7 +2011,7 @@ class MatrixVCFReader( val fs = ctx.fs val sm = ctx.stateManager - val rgBc = referenceGenome.map(_.broadcast) + val rgBc = referenceGenome.map(ctx.backend.broadcast) val localArrayElementsRequired = params.arrayElementsRequired val localContigRecoding = params.contigRecoding val localSkipInvalidLoci = params.skipInvalidLoci diff --git a/hail/src/main/scala/is/hail/variant/ReferenceGenome.scala b/hail/src/main/scala/is/hail/variant/ReferenceGenome.scala index 7412b32ff9b4..594631a9a2be 100644 --- a/hail/src/main/scala/is/hail/variant/ReferenceGenome.scala +++ b/hail/src/main/scala/is/hail/variant/ReferenceGenome.scala @@ -1,15 +1,14 @@ package is.hail.variant -import is.hail.HailContext import is.hail.annotations.ExtendedOrdering -import is.hail.backend.{BroadcastValue, ExecuteContext} +import is.hail.backend.ExecuteContext import is.hail.check.Gen import is.hail.expr.{ JSONExtractContig, JSONExtractIntervalLocus, JSONExtractReferenceGenome, Parser, } import is.hail.expr.ir.RelationalSpec import is.hail.io.fs.FS -import is.hail.io.reference.{FASTAReader, FASTAReaderConfig, LiftOver} +import is.hail.io.reference.{FASTAReader, FASTAReaderConfig, FastaSequenceIndex, IndexedFastaSequenceFile, LiftOver} import is.hail.types._ import is.hail.types.virtual.{TLocus, Type} import is.hail.utils._ @@ -18,31 +17,10 @@ import scala.collection.JavaConverters._ import scala.collection.mutable import java.io.{FileNotFoundException, InputStream} -import java.lang.ThreadLocal -import htsjdk.samtools.reference.FastaSequenceIndex -import org.apache.spark.TaskContext -import org.json4s._ +import org.json4s.{DefaultFormats, Extraction, Formats, JValue} import org.json4s.jackson.{JsonMethods, Serialization} -class BroadcastRG(rgParam: ReferenceGenome) extends Serializable { - @transient private[this] val rg: ReferenceGenome = rgParam - - private[this] val rgBc: BroadcastValue[ReferenceGenome] = - if (TaskContext.get != null) - null - else - rg.broadcast - - def value: ReferenceGenome = { - val t = if (rg != null) - rg - else - rgBc.value - t - } -} - case class ReferenceGenome( name: String, contigs: Array[String], @@ -53,7 +31,6 @@ case class ReferenceGenome( parInput: Array[(Locus, Locus)] = Array.empty[(Locus, Locus)], ) extends Serializable { - @transient lazy val broadcastRG: BroadcastRG = new BroadcastRG(this) val nContigs = contigs.length if (nContigs <= 0) @@ -165,9 +142,8 @@ case class ReferenceGenome( Interval(start, end, includesStart = true, includesEnd = false) } - private var fastaFilePath: String = _ - private var fastaIndexPath: String = _ - @transient private var fastaReaderCfg: FASTAReaderConfig = _ + private[this] var fastaFile: IndexedFastaSequenceFile = _ + @transient private[this] var fastaReaderCfg: FASTAReaderConfig = _ @transient lazy val contigParser = Parser.oneOfLiteral(contigs) @@ -365,50 +341,14 @@ case class ReferenceGenome( ) } - def hasSequence: Boolean = fastaFilePath != null + private def hasSequence: Boolean = fastaFile != null - def addSequence(ctx: ExecuteContext, fastaFile: String, indexFile: 
String): Unit = { + def addSequence(fasta: IndexedFastaSequenceFile): Unit = { if (hasSequence) fatal(s"FASTA sequence has already been loaded for reference genome '$name'.") - val tmpdir = ctx.localTmpdir - val fs = ctx.fs - if (!fs.isFile(fastaFile)) - fatal(s"FASTA file '$fastaFile' does not exist, is not a file, or you do not have access.") - if (!fs.isFile(indexFile)) - fatal( - s"FASTA index file '$indexFile' does not exist, is not a file, or you do not have access." - ) - fastaFilePath = fastaFile - fastaIndexPath = indexFile - - /* assumption, fastaFile and indexFile will not move or change for the entire duration of a hail - * pipeline */ - val index = using(fs.open(indexFile))(new FastaSequenceIndex(_)) - - val missingContigs = contigs.filterNot(index.hasIndexEntry) - if (missingContigs.nonEmpty) - fatal( - s"Contigs missing in FASTA '$fastaFile' that are present in reference genome '$name':\n " + - s"@1", - missingContigs.truncatable("\n "), - ) - - val invalidLengths = lengths.flatMap { case (c, l) => - val fastaLength = index.getIndexEntry(c).getSize - if (fastaLength != l) - Some((c, l, fastaLength)) - else - None - }.map { case (c, e, f) => s"$c\texpected:$e\tfound:$f" } - - if (invalidLengths.nonEmpty) - fatal( - s"Contig sizes in FASTA '$fastaFile' do not match expected sizes for reference genome '$name':\n " + - s"@1", - invalidLengths.truncatable("\n "), - ) - heal(tmpdir, fs) + fasta.raiseIfIncompatible(this) + fastaFile = fasta } @transient private lazy val realFastaReader: ThreadLocal[FASTAReader] = @@ -417,6 +357,7 @@ case class ReferenceGenome( private def fastaReader(): FASTAReader = { if (!hasSequence) fatal(s"FASTA file has not been loaded for reference genome '$name'.") + if (realFastaReader.get() == null) realFastaReader.set(fastaReaderCfg.reader) if (realFastaReader.get().cfg != fastaReaderCfg) @@ -436,93 +377,46 @@ case class ReferenceGenome( def removeSequence(): Unit = { if (!hasSequence) fatal(s"Reference genome '$name' does not have sequence loaded.") - fastaFilePath = null - fastaIndexPath = null + fastaFile = null fastaReaderCfg = null } - private var chainFiles: Map[String, String] = Map.empty - @transient private[this] lazy val liftoverMap: mutable.Map[String, LiftOver] = mutable.Map.empty - - def hasLiftover(destRGName: String): Boolean = chainFiles.contains(destRGName) + private[this] val liftovers: mutable.Map[String, LiftOver] = + mutable.Map.empty - def addLiftover(ctx: ExecuteContext, chainFile: String, destRGName: String): Unit = { - if (name == destRGName) + def addLiftover(destRef: ReferenceGenome, liftOver: LiftOver): Unit = { + if (name == destRef.name) fatal(s"Destination reference genome cannot have the same name as this reference '$name'") - if (hasLiftover(destRGName)) + if (liftovers.contains(destRef.name)) fatal( - s"Chain file already exists for source reference '$name' and destination reference '$destRGName'." + s"LiftOver already exists for source reference '$name' and destination reference '${destRef.name}'." 
) - val tmpdir = ctx.localTmpdir - val fs = ctx.fs - - if (!fs.isFile(chainFile)) - fatal(s"Chain file '$chainFile' does not exist, is not a file, or you do not have access.") - - val chainFilePath = fs.parseUrl(chainFile).toString - val lo = LiftOver(fs, chainFilePath) - val destRG = ctx.getReference(destRGName) - lo.checkChainFile(this, destRG) - - chainFiles += destRGName -> chainFile - heal(tmpdir, fs) + liftOver.checkChainFile(this, destRef) + liftovers += destRef.name -> liftOver } - def getLiftover(destRGName: String): LiftOver = { - if (!hasLiftover(destRGName)) - fatal( - s"Chain file has not been loaded for source reference '$name' and destination reference '$destRGName'." - ) - liftoverMap(destRGName) - } + def removeLiftover(destRGName: String): Unit = + liftovers -= destRGName - def removeLiftover(destRGName: String): Unit = { - if (!hasLiftover(destRGName)) - fatal(s"liftover does not exist from reference genome '$name' to '$destRGName'.") - chainFiles -= destRGName - liftoverMap -= destRGName - } - - def liftoverLocus(destRGName: String, l: Locus, minMatch: Double): (Locus, Boolean) = { - val lo = getLiftover(destRGName) - lo.queryLocus(l, minMatch) - } + def liftoverLocus(destRGName: String, l: Locus, minMatch: Double): (Locus, Boolean) = + liftovers(destRGName).queryLocus(l, minMatch) def liftoverLocusInterval(destRGName: String, interval: Interval, minMatch: Double) - : (Interval, Boolean) = { - val lo = getLiftover(destRGName) - lo.queryInterval(interval, minMatch) - } + : (Interval, Boolean) = + liftovers(destRGName).queryInterval(interval, minMatch) def heal(tmpdir: String, fs: FS): Unit = synchronized { - // Add liftovers - // NOTE: it shouldn't be possible for the liftover map to have more elements than the chain file - // since removeLiftover updates both maps, so we don't check to see if liftoverMap has - // keys that are not in chainFiles - for ((destRGName, chainFile) <- chainFiles) { - val chainFilePath = fs.parseUrl(chainFile).toString - liftoverMap.get(destRGName) match { - case Some(lo) if lo.chainFile == chainFilePath => // do nothing - case _ => liftoverMap += destRGName -> LiftOver(fs, chainFilePath) - } - } - + liftovers.values.foreach(_.restore(fs)) // add sequence - if (fastaFilePath != null) { - val fastaPath = fs.parseUrl(fastaFilePath).toString - val indexPath = fs.parseUrl(fastaIndexPath).toString - if ( - fastaReaderCfg == null || fastaReaderCfg.fastaFile != fastaPath || fastaReaderCfg.indexFile != indexPath - ) { - fastaReaderCfg = FASTAReaderConfig(tmpdir, fs, this, fastaPath, indexPath) - } + if (fastaFile != null) { + fastaFile.restore(fs) + val fastaPath = fs.parseUrl(fastaFile.path).toString + val indexPath = fs.parseUrl(fastaFile.index.path).toString + fastaReaderCfg = FASTAReaderConfig(tmpdir, fs, this, fastaPath, indexPath) } } - @transient lazy val broadcast: BroadcastValue[ReferenceGenome] = - HailContext.backend.broadcast(this) - override def hashCode: Int = { import org.apache.commons.lang3.builder.HashCodeBuilder @@ -555,9 +449,14 @@ case class ReferenceGenome( override def toString: String = name + implicit private[this] val fmts: Formats = DefaultFormats + def write(fs: is.hail.io.fs.FS, file: String): Unit = - using(fs.create(file)) { out => - val jrg = JSONExtractReferenceGenome( + using(fs.create(file))(out => Serialization.write(toJSON, out)) + + def toJSON: JValue = + Extraction.decompose( + JSONExtractReferenceGenome( name, contigs.map(contig => JSONExtractContig(contig, contigLength(contig))), xContigs, @@ -567,23 +466,8 @@ 
case class ReferenceGenome( JSONExtractIntervalLocus(i.start.asInstanceOf[Locus], i.end.asInstanceOf[Locus]) ), ) - implicit val formats: Formats = defaultJSONFormats - Serialization.write(jrg, out) - } + ) - def toJSON: JSONExtractReferenceGenome = JSONExtractReferenceGenome( - name, - contigs.map(contig => JSONExtractContig(contig, contigLength(contig))), - xContigs, - yContigs, - mtContigs, - par.map(i => JSONExtractIntervalLocus(i.start.asInstanceOf[Locus], i.end.asInstanceOf[Locus])), - ) - - def toJSONString: String = { - implicit val formats: Formats = defaultJSONFormats - Serialization.write(toJSON) - } } object ReferenceGenome { @@ -644,17 +528,11 @@ object ReferenceGenome { if (!fs.isFile(fastaFile)) fatal(s"FASTA file '$fastaFile' does not exist, is not a file, or you do not have access.") - if (!fs.isFile(indexFile)) - fatal( - s"FASTA index file '$indexFile' does not exist, is not a file, or you do not have access." - ) - - val index = using(fs.open(indexFile))(new FastaSequenceIndex(_)) val contigs = new BoxedArrayBuilder[String] val lengths = new BoxedArrayBuilder[(String, Int)] - index.iterator().asScala.foreach { entry => + FastaSequenceIndex(fs, indexFile).foreach { entry => val contig = entry.getContig val length = entry.getSize contigs += contig diff --git a/hail/src/test/scala/is/hail/HailSuite.scala b/hail/src/test/scala/is/hail/HailSuite.scala index bccfee74c3fd..50b25f5b4041 100644 --- a/hail/src/test/scala/is/hail/HailSuite.scala +++ b/hail/src/test/scala/is/hail/HailSuite.scala @@ -18,7 +18,6 @@ import org.apache.spark.sql.Row import org.scalatestplus.testng.TestNGSuite import org.testng.ITestContext import org.testng.annotations.{AfterMethod, BeforeClass, BeforeMethod} -import sourcecode.Enclosing object HailSuite { val theHailClassLoader = TestUtils.theHailClassLoader @@ -92,9 +91,6 @@ class HailSuite extends TestNGSuite { throw new RuntimeException(s"method stopped spark context!") } - def withExecuteContext[T]()(f: ExecuteContext => T)(implicit E: Enclosing): T = - hc.sparkBackend("HailSuite.withExecuteContext").withExecuteContext(f) - def assertEvalsTo( x: IR, env: Env[(Any, Type)], diff --git a/hail/src/test/scala/is/hail/annotations/UnsafeSuite.scala b/hail/src/test/scala/is/hail/annotations/UnsafeSuite.scala index d954822f0baf..81a2a82e3a7b 100644 --- a/hail/src/test/scala/is/hail/annotations/UnsafeSuite.scala +++ b/hail/src/test/scala/is/hail/annotations/UnsafeSuite.scala @@ -55,7 +55,7 @@ class UnsafeSuite extends HailSuite { @DataProvider(name = "codecs") def codecs(): Array[Array[Any]] = - withExecuteContext()(ctx => codecs(ctx)) + ExecuteContext.scoped(ctx => codecs(ctx)) def codecs(ctx: ExecuteContext): Array[Array[Any]] = (BufferSpec.specs ++ Array(TypedCodecSpec( diff --git a/hail/src/test/scala/is/hail/backend/ServiceBackendSuite.scala b/hail/src/test/scala/is/hail/backend/ServiceBackendSuite.scala index 57b17a1c59c8..83a5fa9d5937 100644 --- a/hail/src/test/scala/is/hail/backend/ServiceBackendSuite.scala +++ b/hail/src/test/scala/is/hail/backend/ServiceBackendSuite.scala @@ -117,6 +117,7 @@ class ServiceBackendSuite extends TestNGSuite with IdiomaticMockito with OptionV jarLocation = "us-docker.pkg.dev/hail-vdc/hail/hailgenetics/hail@sha256:fake", name = "name", theHailClassLoader = new HailClassLoader(getClass.getClassLoader), + references = Map.empty, batchClient = client, curBatchId = None, curJobGroupId = None, diff --git a/hail/src/test/scala/is/hail/expr/ir/IRSuite.scala b/hail/src/test/scala/is/hail/expr/ir/IRSuite.scala index 
b97fbd1893fa..75a64588c487 100644 --- a/hail/src/test/scala/is/hail/expr/ir/IRSuite.scala +++ b/hail/src/test/scala/is/hail/expr/ir/IRSuite.scala @@ -3250,7 +3250,7 @@ class IRSuite extends HailSuite { @DataProvider(name = "valueIRs") def valueIRs(): Array[Array[Object]] = - withExecuteContext()(ctx => valueIRs(ctx)) + ExecuteContext.scoped(ctx => valueIRs(ctx)) def valueIRs(ctx: ExecuteContext): Array[Array[Object]] = { val fs = ctx.fs @@ -3595,7 +3595,7 @@ class IRSuite extends HailSuite { @DataProvider(name = "tableIRs") def tableIRs(): Array[Array[TableIR]] = - withExecuteContext()(ctx => tableIRs(ctx)) + ExecuteContext.scoped(ctx => tableIRs(ctx)) def tableIRs(ctx: ExecuteContext): Array[Array[TableIR]] = { try { @@ -3704,7 +3704,7 @@ class IRSuite extends HailSuite { @DataProvider(name = "matrixIRs") def matrixIRs(): Array[Array[MatrixIR]] = - withExecuteContext()(ctx => matrixIRs(ctx)) + ExecuteContext.scoped(ctx => matrixIRs(ctx)) def matrixIRs(ctx: ExecuteContext): Array[Array[MatrixIR]] = { try { diff --git a/hail/src/test/scala/is/hail/expr/ir/RequirednessSuite.scala b/hail/src/test/scala/is/hail/expr/ir/RequirednessSuite.scala index fd95b492a6a6..7c3597d6c4ab 100644 --- a/hail/src/test/scala/is/hail/expr/ir/RequirednessSuite.scala +++ b/hail/src/test/scala/is/hail/expr/ir/RequirednessSuite.scala @@ -1,6 +1,7 @@ package is.hail.expr.ir import is.hail.HailSuite +import is.hail.backend.ExecuteContext import is.hail.expr.Nat import is.hail.expr.ir.agg.CallStatsState import is.hail.io.{BufferSpec, TypedCodecSpec} @@ -102,7 +103,7 @@ class RequirednessSuite extends HailSuite { def pinterval(point: PType, r: Boolean): PInterval = PCanonicalInterval(point, r) @DataProvider(name = "valueIR") - def valueIR(): Array[Array[Any]] = withExecuteContext() { ctx => + def valueIR(): Array[Array[Any]] = ExecuteContext.scoped { ctx => val nodes = new BoxedArrayBuilder[Array[Any]](50) val allRequired = Array( diff --git a/hail/src/test/scala/is/hail/variant/ReferenceGenomeSuite.scala b/hail/src/test/scala/is/hail/variant/ReferenceGenomeSuite.scala index 1f7c361f914f..41ac9aa65e32 100644 --- a/hail/src/test/scala/is/hail/variant/ReferenceGenomeSuite.scala +++ b/hail/src/test/scala/is/hail/variant/ReferenceGenomeSuite.scala @@ -1,11 +1,11 @@ package is.hail.variant import is.hail.{HailSuite, TestUtils} -import is.hail.backend.HailStateManager +import is.hail.backend.{ExecuteContext, HailStateManager} import is.hail.check.Prop._ import is.hail.check.Properties import is.hail.expr.ir.EmitFunctionBuilder -import is.hail.io.reference.{FASTAReader, FASTAReaderConfig} +import is.hail.io.reference.{FASTAReader, FASTAReaderConfig, LiftOver} import is.hail.types.virtual.TLocus import is.hail.utils._ @@ -222,7 +222,7 @@ class ReferenceGenomeSuite extends HailSuite { } @Test def testSerializeOnFB(): Unit = { - withExecuteContext() { ctx => + ExecuteContext.scoped { ctx => val grch38 = ctx.getReference(ReferenceGenome.GRCh38) val fb = EmitFunctionBuilder[String, Boolean](ctx, "serialize_rg") val rgfield = fb.getReferenceGenome(grch38.name) @@ -234,11 +234,11 @@ class ReferenceGenomeSuite extends HailSuite { } @Test def testSerializeWithLiftoverOnFB(): Unit = { - withExecuteContext() { ctx => + ExecuteContext.scoped { ctx => val grch37 = ctx.getReference(ReferenceGenome.GRCh37) val liftoverFile = "src/test/resources/grch37_to_grch38_chr20.over.chain.gz" - grch37.addLiftover(ctx, liftoverFile, "GRCh38") + grch37.addLiftover(ctx.references("GRCh38"), LiftOver(ctx.fs, liftoverFile)) val fb = 
EmitFunctionBuilder[String, Locus, Double, (Locus, Boolean)](ctx, "serialize_with_liftover")
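
For reference, a minimal sketch (not part of the patch) of how the reference-genome plumbing fits together after this change, following the wiring in ServiceBackend.apply and the LocalBackend/SparkBackend constructors above. The GRCh37/GRCh38 pairing and all file paths are placeholders, not values taken from the patch:

    import is.hail.io.fs.FS
    import is.hail.io.reference.{IndexedFastaSequenceFile, LiftOver}
    import is.hail.variant.ReferenceGenome

    import scala.collection.mutable

    object ReferenceWiringSketch {
      def wireReferences(fs: FS, chainFile: String, fastaFile: String, indexFile: String)
          : mutable.Map[String, ReferenceGenome] = {
        // Backends now own a mutable map of references, seeded with the builtin genomes
        // instead of calling addDefaultReferences() after construction.
        val references = mutable.Map(ReferenceGenome.builtinReferences().toSeq: _*)

        // Liftovers are attached to the source genome, given the destination genome and a
        // LiftOver loaded through the FS-aware factory introduced in is.hail.io.reference.
        references("GRCh37").addLiftover(references("GRCh38"), LiftOver(fs, chainFile))

        // Sequences are attached via IndexedFastaSequenceFile rather than raw file paths.
        references("GRCh38").addSequence(IndexedFastaSequenceFile(fs, fastaFile, indexFile))

        references
      }
    }

The constructed map is what a backend would pass as its `references` constructor argument (and expose as a `mutable.Map` to Py4JBackendExtensions), with `references.toMap` handed to each ExecuteContext.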