Skip to content

Commit

Permalink
[query] Remove lookupOrCompileCachedFunction from Backend interface
Browse files Browse the repository at this point in the history
  • Loading branch information
ehigham committed Jan 13, 2025
1 parent 96e508a commit b1732cc
Show file tree
Hide file tree
Showing 17 changed files with 144 additions and 162 deletions.
26 changes: 1 addition & 25 deletions hail/src/main/scala/is/hail/backend/Backend.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@ import is.hail.asm4s._
import is.hail.backend.Backend.jsonToBytes
import is.hail.backend.spark.SparkBackend
import is.hail.expr.ir.{
BaseIR, CodeCacheKey, CompiledFunction, IR, IRParser, IRParserEnvironment, LoweringAnalyses,
SortField, TableIR, TableReader,
BaseIR, IR, IRParser, IRParserEnvironment, LoweringAnalyses, SortField, TableIR, TableReader,
}
import is.hail.expr.ir.lowering.{TableStage, TableStageDependency}
import is.hail.io.{BufferSpec, TypedCodecSpec}
Expand Down Expand Up @@ -92,9 +91,6 @@ abstract class Backend extends Closeable {

def shouldCacheQueryInfo: Boolean = true

def lookupOrCompileCachedFunction[T](k: CodeCacheKey)(f: => CompiledFunction[T])
: CompiledFunction[T]

def lowerDistributedSort(
ctx: ExecuteContext,
stage: TableStage,
Expand Down Expand Up @@ -193,23 +189,3 @@ abstract class Backend extends Closeable {

def execute(ctx: ExecuteContext, ir: IR): Either[Unit, (PTuple, Long)]
}

trait BackendWithCodeCache {
private[this] val codeCache: Cache[CodeCacheKey, CompiledFunction[_]] = new Cache(50)

def lookupOrCompileCachedFunction[T](k: CodeCacheKey)(f: => CompiledFunction[T])
: CompiledFunction[T] = {
codeCache.get(k) match {
case Some(v) => v.asInstanceOf[CompiledFunction[T]]
case None =>
val compiledFunction = f
codeCache += ((k, compiledFunction))
compiledFunction
}
}
}

trait BackendWithNoCodeCache {
def lookupOrCompileCachedFunction[T](k: CodeCacheKey)(f: => CompiledFunction[T])
: CompiledFunction[T] = f
}
6 changes: 6 additions & 0 deletions hail/src/main/scala/is/hail/backend/ExecuteContext.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import is.hail.{HailContext, HailFeatureFlags}
import is.hail.annotations.{Region, RegionPool}
import is.hail.asm4s.HailClassLoader
import is.hail.backend.local.LocalTaskContext
import is.hail.expr.ir.{CodeCacheKey, CompiledFunction}
import is.hail.expr.ir.lowering.IrMetadata
import is.hail.io.fs.FS
import is.hail.linalg.BlockMatrix
Expand Down Expand Up @@ -73,6 +74,7 @@ object ExecuteContext {
backendContext: BackendContext,
irMetadata: IrMetadata,
blockMatrixCache: mutable.Map[String, BlockMatrix],
codeCache: mutable.Map[CodeCacheKey, CompiledFunction[_]],
)(
f: ExecuteContext => T
): T = {
Expand All @@ -92,6 +94,7 @@ object ExecuteContext {
backendContext,
irMetadata,
blockMatrixCache,
codeCache,
))(f(_))
}
}
Expand Down Expand Up @@ -122,6 +125,7 @@ class ExecuteContext(
val backendContext: BackendContext,
val irMetadata: IrMetadata,
val BlockMatrixCache: mutable.Map[String, BlockMatrix],
val CodeCache: mutable.Map[CodeCacheKey, CompiledFunction[_]],
) extends Closeable {

val rngNonce: Long =
Expand Down Expand Up @@ -191,6 +195,7 @@ class ExecuteContext(
backendContext: BackendContext = this.backendContext,
irMetadata: IrMetadata = this.irMetadata,
blockMatrixCache: mutable.Map[String, BlockMatrix] = this.BlockMatrixCache,
codeCache: mutable.Map[CodeCacheKey, CompiledFunction[_]] = this.CodeCache,
)(
f: ExecuteContext => A
): A =
Expand All @@ -208,5 +213,6 @@ class ExecuteContext(
backendContext,
irMetadata,
blockMatrixCache,
codeCache,
))(f)
}
7 changes: 5 additions & 2 deletions hail/src/main/scala/is/hail/backend/local/LocalBackend.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import is.hail.backend.py4j.Py4JBackendExtensions
import is.hail.expr.Validate
import is.hail.expr.ir._
import is.hail.expr.ir.analyses.SemanticHash
import is.hail.expr.ir.compile.Compile
import is.hail.expr.ir.lowering._
import is.hail.io.fs._
import is.hail.types._
Expand Down Expand Up @@ -83,13 +84,14 @@ object LocalBackend {
class LocalBackend(
val tmpdir: String,
override val references: mutable.Map[String, ReferenceGenome],
) extends Backend with BackendWithCodeCache with Py4JBackendExtensions {
) extends Backend with Py4JBackendExtensions {

override def backend: Backend = this
override val flags: HailFeatureFlags = HailFeatureFlags.fromEnv()
override def longLifeTempFileManager: TempFileManager = null

private[this] val theHailClassLoader = new HailClassLoader(getClass().getClassLoader())
private[this] val theHailClassLoader = new HailClassLoader(getClass.getClassLoader)
private[this] val codeCache = new Cache[CodeCacheKey, CompiledFunction[_]](50)

// flags can be set after construction from python
def fs: FS = RouterFS.buildRoutes(CloudStorageFSConfig.fromFlagsAndEnv(None, flags))
Expand All @@ -113,6 +115,7 @@ class LocalBackend(
},
new IrMetadata(),
ImmutableMap.empty,
codeCache,
)(f)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@ import is.hail.asm4s._
import is.hail.backend._
import is.hail.expr.Validate
import is.hail.expr.ir.{
Compile, IR, IRParser, IRSize, LoweringAnalyses, MakeTuple, SortField, TableIR, TableReader,
TypeCheck,
IR, IRParser, IRSize, LoweringAnalyses, MakeTuple, SortField, TableIR, TableReader, TypeCheck,
}
import is.hail.expr.ir.analyses.SemanticHash
import is.hail.expr.ir.compile.Compile
import is.hail.expr.ir.functions.IRFunctionRegistry
import is.hail.expr.ir.lowering._
import is.hail.io.fs._
Expand Down Expand Up @@ -51,7 +51,6 @@ class ServiceBackendContext(
) extends BackendContext with Serializable {}

object ServiceBackend {
private val log = Logger.getLogger(getClass.getName())

def apply(
jarLocation: String,
Expand Down Expand Up @@ -130,8 +129,7 @@ class ServiceBackend(
val fs: FS,
val serviceBackendContext: ServiceBackendContext,
val scratchDir: String,
) extends Backend with BackendWithNoCodeCache {
import ServiceBackend.log
) extends Backend with Logging {

private[this] var stageCount = 0
private[this] val MAX_AVAILABLE_GCS_CONNECTIONS = 1000
Expand Down Expand Up @@ -393,6 +391,7 @@ class ServiceBackend(
serviceBackendContext,
new IrMetadata(),
ImmutableMap.empty,
mutable.Map.empty,
)(f)
}

Expand Down
9 changes: 6 additions & 3 deletions hail/src/main/scala/is/hail/backend/spark/SparkBackend.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import is.hail.backend.py4j.Py4JBackendExtensions
import is.hail.expr.Validate
import is.hail.expr.ir._
import is.hail.expr.ir.analyses.SemanticHash
import is.hail.expr.ir.compile.Compile
import is.hail.expr.ir.lowering._
import is.hail.io.{BufferSpec, TypedCodecSpec}
import is.hail.io.fs._
Expand Down Expand Up @@ -320,7 +321,7 @@ class SparkBackend(
override val references: mutable.Map[String, ReferenceGenome],
gcsRequesterPaysProject: String,
gcsRequesterPaysBuckets: String,
) extends Backend with BackendWithCodeCache with Py4JBackendExtensions {
) extends Backend with Py4JBackendExtensions {

assert(gcsRequesterPaysProject != null || gcsRequesterPaysBuckets == null)
lazy val sparkSession: SparkSession = SparkSession.builder().config(sc.getConf).getOrCreate()
Expand Down Expand Up @@ -351,8 +352,8 @@ class SparkBackend(
override val longLifeTempFileManager: TempFileManager =
new OwningTempFileManager(fs)

private[this] val bmCache: BlockMatrixCache =
new BlockMatrixCache()
private[this] val bmCache = new BlockMatrixCache()
private[this] val codeCache = new Cache[CodeCacheKey, CompiledFunction[_]](50)

def createExecuteContextForTests(
timer: ExecutionTimer,
Expand All @@ -376,6 +377,7 @@ class SparkBackend(
},
new IrMetadata(),
ImmutableMap.empty,
mutable.Map.empty,
)

override def withExecuteContext[T](f: ExecuteContext => T)(implicit E: Enclosing): T =
Expand All @@ -396,6 +398,7 @@ class SparkBackend(
},
new IrMetadata(),
bmCache,
codeCache,
)(f)
}

Expand Down
Loading

0 comments on commit b1732cc

Please sign in to comment.