Skip to content

Commit

Permalink
[query] Unify CloudCredentials and Simplify BatchClient (#14684)
Browse files Browse the repository at this point in the history
This change combines cloud auth logic that was previously duplicated
between the various `FS` implementations and the `BatchClient`. 

The main refactoring is to make the interface between the
`ServiceBackend` more
high-level and leave json serialisation to the `BatchClient`. To do
this, I've
added a bunch of case classes that resemble the python objects the batch
service
expects (or a subset of the data). To simplify the interface, I've split
batch
creation from job submission (update). For QoB, the python client
creates the
batch before handing control to the query driver; batch creation is
necessary
for testing only.

This change has low security impact as there are minor changes to the
creation
and scoping of service account credentials. Note that for each `FS`,
credentials
are scoped to the default storage oauth2 scopes for each service.
  • Loading branch information
ehigham authored Dec 16, 2024
1 parent 0f8d35b commit 48f2925
Show file tree
Hide file tree
Showing 31 changed files with 1,072 additions and 1,012 deletions.
1 change: 1 addition & 0 deletions build.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3723,6 +3723,7 @@ steps:
- hail_run_image
- build_debug_hail_test_jar
- build_hail_test_artifacts
- upload_query_jar
- deploy_batch
- kind: runImage
name: start_hail_benchmark
Expand Down
5 changes: 3 additions & 2 deletions hail/python/hail/backend/py4j_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ def decode_bytearray(encoded):
self._jbackend = jbackend
self._jhc = jhc

self._backend_server = self._hail_package.backend.BackendServer.apply(self._jbackend)
self._backend_server = self._hail_package.backend.BackendServer(self._jbackend)
self._backend_server_port: int = self._backend_server.port()
self._backend_server.start()
self._requests_session = requests.Session()
Expand Down Expand Up @@ -306,7 +306,8 @@ def _to_java_blockmatrix_ir(self, ir):
return self._parse_blockmatrix_ir(self._render_ir(ir))

def stop(self):
self._backend_server.stop()
self._backend_server.close()
self._jbackend.close()
self._jhc.stop()
self._jhc = None
self._registered_ir_function_names = set()
Expand Down
2 changes: 1 addition & 1 deletion hail/python/hail/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -573,7 +573,7 @@ async def init_batch(

log = _get_log(log)
if tmpdir is None:
tmpdir = backend.remote_tmpdir + 'tmp/hail/' + secret_alnum_string()
tmpdir = os.path.join(backend.remote_tmpdir, 'tmp/hail', secret_alnum_string())
local_tmpdir = _get_local_tmpdir(local_tmpdir)

HailContext.create(log, quiet, append, tmpdir, local_tmpdir, default_reference, global_seed, backend)
Expand Down
5 changes: 2 additions & 3 deletions hail/python/hailtop/config/user_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,5 @@ def get_remote_tmpdir(
raise ValueError(
f'remote_tmpdir must be a storage uri path like gs://bucket/folder. Received: {remote_tmpdir}. Possible schemes include gs for GCP and https for Azure'
)
if remote_tmpdir[-1] != '/':
remote_tmpdir += '/'
return remote_tmpdir

return remote_tmpdir.rstrip('/')
1 change: 0 additions & 1 deletion hail/python/hailtop/hailctl/batch/submit.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ async def submit(name, image_name, files, output, script, arguments):
quiet = output != 'text'

remote_tmpdir = get_remote_tmpdir('hailctl batch submit')
remote_tmpdir = remote_tmpdir.rstrip('/')

tmpdir_path_prefix = secret_alnum_string()

Expand Down
2 changes: 1 addition & 1 deletion hail/src/main/scala/is/hail/HailContext.scala
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ object HailContext {

def stop(): Unit = synchronized {
IRFunctionRegistry.clearUserFunctions()
backend.stop()
backend.close()

theContext = null
}
Expand Down
2 changes: 1 addition & 1 deletion hail/src/main/scala/is/hail/HailFeatureFlags.scala
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ object HailFeatureFlags {
),
)

def fromMap(m: Map[String, String]): HailFeatureFlags =
def fromEnv(m: Map[String, String] = sys.env): HailFeatureFlags =
new HailFeatureFlags(
mutable.Map(
HailFeatureFlags.defaults.map {
Expand Down
4 changes: 1 addition & 3 deletions hail/src/main/scala/is/hail/backend/Backend.scala
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ trait BackendContext {
def executionCache: ExecutionCache
}

abstract class Backend {
abstract class Backend extends Closeable {
// From https://github.com/hail-is/hail/issues/14580 :
// IR can get quite big, especially as it can contain an arbitrary
// amount of encoded literals from the user's python session. This
Expand Down Expand Up @@ -123,8 +123,6 @@ abstract class Backend {
f: (Array[Byte], HailTaskContext, HailClassLoader, FS) => Array[Byte]
): (Option[Throwable], IndexedSeq[(Array[Byte], Int)])

def stop(): Unit

def asSpark(op: String): SparkBackend =
fatal(s"${getClass.getSimpleName}: $op requires SparkBackend")

Expand Down
9 changes: 3 additions & 6 deletions hail/src/main/scala/is/hail/backend/BackendServer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import is.hail.utils._

import scala.util.control.NonFatal

import java.io.Closeable
import java.net.InetSocketAddress
import java.nio.charset.StandardCharsets
import java.util.concurrent._
Expand All @@ -31,11 +32,7 @@ case class ParseVCFMetadataPayload(path: String)
case class ImportFamPayload(path: String, quant_pheno: Boolean, delimiter: String, missing: String)
case class ExecutePayload(ir: String, stream_codec: String, timed: Boolean)

object BackendServer {
def apply(backend: Backend) = new BackendServer(backend)
}

class BackendServer(backend: Backend) {
class BackendServer(backend: Backend) extends Closeable {
// 0 => let the OS pick an available port
private[this] val httpServer = HttpServer.create(new InetSocketAddress(0), 10)
private[this] val handler = new BackendHttpHandler(backend)
Expand Down Expand Up @@ -77,7 +74,7 @@ class BackendServer(backend: Backend) {
def start(): Unit =
thread.start()

def stop(): Unit =
override def close(): Unit =
httpServer.stop(10)
}

Expand Down
6 changes: 3 additions & 3 deletions hail/src/main/scala/is/hail/backend/local/LocalBackend.scala
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ object LocalBackend {

class LocalBackend(val tmpdir: String) extends Backend with BackendWithCodeCache {

private[this] val flags = HailFeatureFlags.fromMap(sys.env)
private[this] val flags = HailFeatureFlags.fromEnv()
private[this] val theHailClassLoader = new HailClassLoader(getClass().getClassLoader())

def getFlag(name: String): String = flags.get(name)
Expand All @@ -78,7 +78,7 @@ class LocalBackend(val tmpdir: String) extends Backend with BackendWithCodeCache
flags.available

// flags can be set after construction from python
def fs: FS = FS.buildRoutes(None, Some(flags), sys.env)
def fs: FS = RouterFS.buildRoutes(CloudStorageFSConfig.fromFlagsAndEnv(None, flags))

override def withExecuteContext[T](f: ExecuteContext => T)(implicit E: Enclosing): T =
ExecutionTimer.logTime { timer =>
Expand Down Expand Up @@ -137,7 +137,7 @@ class LocalBackend(val tmpdir: String) extends Backend with BackendWithCodeCache

def defaultParallelism: Int = 1

def stop(): Unit = LocalBackend.stop()
def close(): Unit = LocalBackend.stop()

private[this] def _jvmLowerAndExecute(
ctx: ExecuteContext,
Expand Down
6 changes: 6 additions & 0 deletions hail/src/main/scala/is/hail/backend/service/Main.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import com.fasterxml.jackson.core.StreamReadConstraints
object Main {
val WORKER = "worker"
val DRIVER = "driver"
val TEST = "test"

/* This constraint should be overridden as early as possible. See:
* - https://github.com/hail-is/hail/issues/14580
Expand All @@ -18,6 +19,11 @@ object Main {
argv(3) match {
case WORKER => Worker.main(argv)
case DRIVER => ServiceBackendAPI.main(argv)

// Batch's "JvmJob" is a special kind of job that can only call `Main.main`.
// TEST is used for integration testing the `BatchClient` to verify that we
// can create JvmJobs without having to mock the payload to a `Worker` job.
case TEST => ()
case kind => throw new RuntimeException(s"unknown kind: $kind")
}
}
Loading

0 comments on commit 48f2925

Please sign in to comment.