Skip to content

Commit

Permalink
[query] Extract Backend Methods called from Python into `Py4JBackendE…
Browse files Browse the repository at this point in the history
…xtensions`
  • Loading branch information
ehigham committed Dec 17, 2024
1 parent 9b8fb6d commit 79ecea8
Show file tree
Hide file tree
Showing 18 changed files with 406 additions and 434 deletions.
8 changes: 4 additions & 4 deletions hail/python/hail/backend/py4j_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,17 +237,17 @@ def _rpc(self, action, payload) -> Tuple[bytes, Optional[dict]]:

def persist_expression(self, expr):
t = expr.dtype
return construct_expr(JavaIR(t, self._jbackend.executeLiteral(self._render_ir(expr._ir))), t)
return construct_expr(JavaIR(t, self._jbackend.pyExecuteLiteral(self._render_ir(expr._ir))), t)

def _is_registered_ir_function_name(self, name: str) -> bool:
return name in self._registered_ir_function_names

def set_flags(self, **flags: Mapping[str, str]):
available = self._jbackend.availableFlags()
available = self._jbackend.pyAvailableFlags()
invalid = []
for flag, value in flags.items():
if flag in available:
self._jbackend.setFlag(flag, value)
self._jbackend.pySetFlag(flag, value)
else:
invalid.append(flag)
if len(invalid) != 0:
Expand All @@ -256,7 +256,7 @@ def set_flags(self, **flags: Mapping[str, str]):
)

def get_flags(self, *flags) -> Mapping[str, str]:
return {flag: self._jbackend.getFlag(flag) for flag in flags}
return {flag: self._jbackend.pyGetFlag(flag) for flag in flags}

def _add_reference_to_scala_backend(self, rg):
self._jbackend.pyAddReference(orjson.dumps(rg._config).decode('utf-8'))
Expand Down
2 changes: 1 addition & 1 deletion hail/python/hail/ir/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -3880,7 +3880,7 @@ def __del__(self):
if Env._hc:
backend = Env.backend()
assert isinstance(backend, Py4JBackend)
backend._jbackend.removeJavaIR(self._id)
backend._jbackend.pyRemoveJavaIR(self._id)


class JavaIR(IR):
Expand Down
2 changes: 1 addition & 1 deletion hail/python/hail/ir/table_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -1215,4 +1215,4 @@ def __del__(self):
if Env._hc:
backend = Env.backend()
assert isinstance(backend, Py4JBackend)
backend._jbackend.removeJavaIR(self._id)
backend._jbackend.pyRemoveJavaIR(self._id)
2 changes: 1 addition & 1 deletion hail/python/test/hail/genetics/test_reference_genome.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ def assert_rg_loaded_correctly(name):
# loading different reference genome with same name should fail
# (different `test_rg_o` definition)
with pytest.raises(FatalError):
hl.read_matrix_table(resource('custom_references_2.t')).count()
hl.read_table(resource('custom_references_2.t')).count()

assert hl.read_matrix_table(resource('custom_references.mt')).count_rows() == 14
assert_rg_loaded_correctly('test_rg_1')
Expand Down
115 changes: 19 additions & 96 deletions hail/src/main/scala/is/hail/backend/Backend.scala
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
package is.hail.backend

import is.hail.asm4s._
import is.hail.backend.Backend.jsonToBytes
import is.hail.backend.spark.SparkBackend
import is.hail.expr.ir.{
BaseIR, CodeCacheKey, CompiledFunction, IR, IRParser, IRParserEnvironment, LoweringAnalyses,
SortField, TableIR, TableReader,
}
import is.hail.expr.ir.functions.IRFunctionRegistry
import is.hail.expr.ir.lowering.{TableStage, TableStageDependency}
import is.hail.io.{BufferSpec, TypedCodecSpec}
import is.hail.io.fs._
Expand All @@ -20,16 +20,14 @@ import is.hail.types.virtual.{BlockMatrixType, TFloat64}
import is.hail.utils._
import is.hail.variant.ReferenceGenome

import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.reflect.ClassTag

import java.io._
import java.nio.charset.StandardCharsets

import com.fasterxml.jackson.core.StreamReadConstraints
import org.json4s._
import org.json4s.jackson.{JsonMethods, Serialization}
import org.json4s.jackson.JsonMethods
import sourcecode.Enclosing

object Backend {
Expand All @@ -41,13 +39,6 @@ object Backend {
s"hail_query_$id"
}

private var irID: Int = 0

def nextIRID(): Int = {
irID += 1
irID
}

def encodeToOutputStream(
ctx: ExecuteContext,
t: PTuple,
Expand All @@ -66,6 +57,9 @@ object Backend {
assert(t.isFieldDefined(off, 0))
codec.encode(ctx, elementType, t.loadField(off, 0), os)
}

def jsonToBytes(f: => JValue): Array[Byte] =
JsonMethods.compact(f).getBytes(StandardCharsets.UTF_8)
}

abstract class BroadcastValue[T] { def value: T }
Expand All @@ -75,28 +69,8 @@ trait BackendContext {
}

abstract class Backend extends Closeable {
// From https://github.com/hail-is/hail/issues/14580 :
// IR can get quite big, especially as it can contain an arbitrary
// amount of encoded literals from the user's python session. This
// was a (controversial) restriction imposed by Jackson and should be lifted.
//
// We remove this restriction for all backends, and we do so here, in the
// constructor since constructing a backend is one of the first things that
// happens and this constraint should be overrided as early as possible.
StreamReadConstraints.overrideDefaultStreamReadConstraints(
StreamReadConstraints.builder().maxStringLength(Integer.MAX_VALUE).build()
)

val persistedIR: mutable.Map[Int, BaseIR] = mutable.Map()

protected[this] def addJavaIR(ir: BaseIR): Int = {
val id = Backend.nextIRID()
persistedIR += (id -> ir)
id
}

def removeJavaIR(id: Int): Unit = persistedIR.remove(id)

def defaultParallelism: Int

def canExecuteParallelTasksOnDriver: Boolean = true
Expand Down Expand Up @@ -131,30 +105,7 @@ abstract class Backend extends Closeable {
def lookupOrCompileCachedFunction[T](k: CodeCacheKey)(f: => CompiledFunction[T])
: CompiledFunction[T]

var references: Map[String, ReferenceGenome] = Map.empty

def addDefaultReferences(): Unit =
references = ReferenceGenome.builtinReferences()

def addReference(rg: ReferenceGenome): Unit = {
references.get(rg.name) match {
case Some(rg2) =>
if (rg != rg2) {
fatal(
s"Cannot add reference genome '${rg.name}', a different reference with that name already exists. Choose a reference name NOT in the following list:\n " +
s"@1",
references.keys.truncatable("\n "),
)
}
case None =>
references += (rg.name -> rg)
}
}

def hasReference(name: String) = references.contains(name)

def removeReference(name: String): Unit =
references -= name
def references: mutable.Map[String, ReferenceGenome]

def lowerDistributedSort(
ctx: ExecuteContext,
Expand Down Expand Up @@ -189,9 +140,6 @@ abstract class Backend extends Closeable {

def withExecuteContext[T](f: ExecuteContext => T)(implicit E: Enclosing): T

private[this] def jsonToBytes(f: => JValue): Array[Byte] =
JsonMethods.compact(f).getBytes(StandardCharsets.UTF_8)

final def valueType(s: String): Array[Byte] =
jsonToBytes {
withExecuteContext { ctx =>
Expand Down Expand Up @@ -220,15 +168,7 @@ abstract class Backend extends Closeable {
}
}

def loadReferencesFromDataset(path: String): Array[Byte] = {
withExecuteContext { ctx =>
val rgs = ReferenceGenome.fromHailDataset(ctx.fs, path)
rgs.foreach(addReference)

implicit val formats: Formats = defaultJSONFormats
Serialization.write(rgs.map(_.toJSON).toFastSeq).getBytes(StandardCharsets.UTF_8)
}
}
def loadReferencesFromDataset(path: String): Array[Byte]

def fromFASTAFile(
name: String,
Expand All @@ -240,18 +180,22 @@ abstract class Backend extends Closeable {
parInput: Array[String],
): Array[Byte] =
withExecuteContext { ctx =>
val rg = ReferenceGenome.fromFASTAFile(ctx, name, fastaFile, indexFile,
xContigs, yContigs, mtContigs, parInput)
rg.toJSONString.getBytes(StandardCharsets.UTF_8)
jsonToBytes {
Extraction.decompose {
ReferenceGenome.fromFASTAFile(ctx, name, fastaFile, indexFile,
xContigs, yContigs, mtContigs, parInput).toJSON
}(defaultJSONFormats)
}
}

def parseVCFMetadata(path: String): Array[Byte] = jsonToBytes {
def parseVCFMetadata(path: String): Array[Byte] =
withExecuteContext { ctx =>
val metadata = LoadVCF.parseHeaderMetadata(ctx.fs, Set.empty, TFloat64, path)
implicit val formats = defaultJSONFormats
Extraction.decompose(metadata)
jsonToBytes {
Extraction.decompose {
LoadVCF.parseHeaderMetadata(ctx.fs, Set.empty, TFloat64, path)
}(defaultJSONFormats)
}
}
}

def importFam(path: String, isQuantPheno: Boolean, delimiter: String, missingValue: String)
: Array[Byte] =
Expand All @@ -261,27 +205,6 @@ abstract class Backend extends Closeable {
)
}

def pyRegisterIR(
name: String,
typeParamStrs: java.util.ArrayList[String],
argNameStrs: java.util.ArrayList[String],
argTypeStrs: java.util.ArrayList[String],
returnType: String,
bodyStr: String,
): Unit = {
withExecuteContext { ctx =>
IRFunctionRegistry.registerIR(
ctx,
name,
typeParamStrs.asScala.toArray,
argNameStrs.asScala.toArray,
argTypeStrs.asScala.toArray,
returnType,
bodyStr,
)
}
}

def execute(ctx: ExecuteContext, ir: IR): Either[Unit, (PTuple, Long)]
}

Expand Down
1 change: 1 addition & 0 deletions hail/src/main/scala/is/hail/backend/BackendServer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ class BackendHttpHandler(backend: Backend) extends HttpHandler {
}
return
}

val response: Array[Byte] = exchange.getRequestURI.getPath match {
case "/value/type" => backend.valueType(body.extract[IRTypePayload].ir)
case "/table/type" => backend.tableType(body.extract[IRTypePayload].ir)
Expand Down
2 changes: 1 addition & 1 deletion hail/src/main/scala/is/hail/backend/ExecuteContext.scala
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ class ExecuteContext(
)
}

val stateManager = HailStateManager(backend.references)
def stateManager = HailStateManager(backend.references.toMap)

val tempFileManager: TempFileManager =
if (_tempFileManager != null) _tempFileManager else new OwningTempFileManager(fs)
Expand Down
Loading

0 comments on commit 79ecea8

Please sign in to comment.