diff --git a/build.yaml b/build.yaml
index b3b3eb68d1c..3d6645220cd 100644
--- a/build.yaml
+++ b/build.yaml
@@ -3720,6 +3720,7 @@ steps:
   - hail_run_image
   - build_debug_hail_test_jar
   - build_hail_test_artifacts
+  - upload_query_jar
   - deploy_batch
 - kind: runImage
   name: start_hail_benchmark
diff --git a/hail/python/hail/backend/py4j_backend.py b/hail/python/hail/backend/py4j_backend.py
index 49a7d9b8e14..2c1610b5614 100644
--- a/hail/python/hail/backend/py4j_backend.py
+++ b/hail/python/hail/backend/py4j_backend.py
@@ -187,7 +187,7 @@ def decode_bytearray(encoded):
         self._jbackend = jbackend
         self._jhc = jhc
 
-        self._backend_server = self._hail_package.backend.BackendServer.apply(self._jbackend)
+        self._backend_server = self._hail_package.backend.BackendServer(self._jbackend)
         self._backend_server_port: int = self._backend_server.port()
         self._backend_server.start()
         self._requests_session = requests.Session()
@@ -306,7 +306,8 @@ def _to_java_blockmatrix_ir(self, ir):
         return self._parse_blockmatrix_ir(self._render_ir(ir))
 
     def stop(self):
-        self._backend_server.stop()
+        self._backend_server.close()
+        self._jbackend.close()
         self._jhc.stop()
         self._jhc = None
         self._registered_ir_function_names = set()
diff --git a/hail/python/hail/context.py b/hail/python/hail/context.py
index 5258f27fbc1..77493ecd083 100644
--- a/hail/python/hail/context.py
+++ b/hail/python/hail/context.py
@@ -573,7 +573,7 @@ async def init_batch(
     log = _get_log(log)
     if tmpdir is None:
-        tmpdir = backend.remote_tmpdir + 'tmp/hail/' + secret_alnum_string()
+        tmpdir = os.path.join(backend.remote_tmpdir, 'tmp/hail', secret_alnum_string())
     local_tmpdir = _get_local_tmpdir(local_tmpdir)
 
     HailContext.create(log, quiet, append, tmpdir, local_tmpdir, default_reference, global_seed, backend)
diff --git a/hail/python/hailtop/config/user_config.py b/hail/python/hailtop/config/user_config.py
index 752031eb166..55114bba48b 100644
--- a/hail/python/hailtop/config/user_config.py
+++ b/hail/python/hailtop/config/user_config.py
@@ -144,6 +144,5 @@ def get_remote_tmpdir(
         raise ValueError(
             f'remote_tmpdir must be a storage uri path like gs://bucket/folder. Received: {remote_tmpdir}. Possible schemes include gs for GCP and https for Azure'
         )
-    if remote_tmpdir[-1] != '/':
-        remote_tmpdir += '/'
-    return remote_tmpdir
+
+    return remote_tmpdir.rstrip('/')
diff --git a/hail/python/hailtop/hailctl/batch/submit.py b/hail/python/hailtop/hailctl/batch/submit.py
index 21547f8e0b5..196b844ab23 100644
--- a/hail/python/hailtop/hailctl/batch/submit.py
+++ b/hail/python/hailtop/hailctl/batch/submit.py
@@ -29,7 +29,6 @@ async def submit(name, image_name, files, output, script, arguments):
     quiet = output != 'text'
 
     remote_tmpdir = get_remote_tmpdir('hailctl batch submit')
-    remote_tmpdir = remote_tmpdir.rstrip('/')
 
     tmpdir_path_prefix = secret_alnum_string()
diff --git a/hail/src/main/scala/is/hail/HailContext.scala b/hail/src/main/scala/is/hail/HailContext.scala
index 1d2fe403762..3de89ce13cd 100644
--- a/hail/src/main/scala/is/hail/HailContext.scala
+++ b/hail/src/main/scala/is/hail/HailContext.scala
@@ -136,7 +136,7 @@ object HailContext {
 
   def stop(): Unit = synchronized {
     IRFunctionRegistry.clearUserFunctions()
-    backend.stop()
+    backend.close()
 
     theContext = null
   }
diff --git a/hail/src/main/scala/is/hail/HailFeatureFlags.scala b/hail/src/main/scala/is/hail/HailFeatureFlags.scala
index b9246d03016..48bb22bb390 100644
--- a/hail/src/main/scala/is/hail/HailFeatureFlags.scala
+++ b/hail/src/main/scala/is/hail/HailFeatureFlags.scala
@@ -47,7 +47,7 @@ object HailFeatureFlags {
       ),
     )
 
-  def fromMap(m: Map[String, String]): HailFeatureFlags =
+  def fromEnv(m: Map[String, String] = sys.env): HailFeatureFlags =
     new HailFeatureFlags(
       mutable.Map(
         HailFeatureFlags.defaults.map {
diff --git a/hail/src/main/scala/is/hail/backend/Backend.scala b/hail/src/main/scala/is/hail/backend/Backend.scala
index 2b32a471025..329ee1a3e38 100644
--- a/hail/src/main/scala/is/hail/backend/Backend.scala
+++ b/hail/src/main/scala/is/hail/backend/Backend.scala
@@ -74,7 +74,7 @@ trait BackendContext {
   def executionCache: ExecutionCache
 }
 
-abstract class Backend {
+abstract class Backend extends Closeable {
   // From https://github.com/hail-is/hail/issues/14580 :
   //   IR can get quite big, especially as it can contain an arbitrary
   //   amount of encoded literals from the user's python session. This
@@ -123,8 +123,6 @@ abstract class Backend {
     f: (Array[Byte], HailTaskContext, HailClassLoader, FS) => Array[Byte]
   ): (Option[Throwable], IndexedSeq[(Array[Byte], Int)])
 
-  def stop(): Unit
-
   def asSpark(op: String): SparkBackend =
     fatal(s"${getClass.getSimpleName}: $op requires SparkBackend")
diff --git a/hail/src/main/scala/is/hail/backend/BackendServer.scala b/hail/src/main/scala/is/hail/backend/BackendServer.scala
index e23d4a3d1e3..7ce224548c9 100644
--- a/hail/src/main/scala/is/hail/backend/BackendServer.scala
+++ b/hail/src/main/scala/is/hail/backend/BackendServer.scala
@@ -5,6 +5,7 @@ import is.hail.utils._
 
 import scala.util.control.NonFatal
 
+import java.io.Closeable
 import java.net.InetSocketAddress
 import java.nio.charset.StandardCharsets
 import java.util.concurrent._
@@ -31,11 +32,7 @@ case class ParseVCFMetadataPayload(path: String)
 case class ImportFamPayload(path: String, quant_pheno: Boolean, delimiter: String, missing: String)
 case class ExecutePayload(ir: String, stream_codec: String, timed: Boolean)
 
-object BackendServer {
-  def apply(backend: Backend) = new BackendServer(backend)
-}
-
-class BackendServer(backend: Backend) {
+class BackendServer(backend: Backend) extends Closeable {
   // 0 => let the OS pick an available port
   private[this] val httpServer = HttpServer.create(new InetSocketAddress(0), 10)
   private[this] val handler = new BackendHttpHandler(backend)
@@ -77,7 +74,7 @@ class BackendServer(backend: Backend) {
 
   def start(): Unit = thread.start()
 
-  def stop(): Unit =
+  override def close(): Unit =
     httpServer.stop(10)
 }
diff --git a/hail/src/main/scala/is/hail/backend/local/LocalBackend.scala b/hail/src/main/scala/is/hail/backend/local/LocalBackend.scala
index 31fc698f97c..2b43df11419 100644
--- a/hail/src/main/scala/is/hail/backend/local/LocalBackend.scala
+++ b/hail/src/main/scala/is/hail/backend/local/LocalBackend.scala
@@ -66,7 +66,7 @@ object LocalBackend {
 
 class LocalBackend(val tmpdir: String) extends Backend with BackendWithCodeCache {
 
-  private[this] val flags = HailFeatureFlags.fromMap(sys.env)
+  private[this] val flags = HailFeatureFlags.fromEnv()
   private[this] val theHailClassLoader = new HailClassLoader(getClass().getClassLoader())
 
   def getFlag(name: String): String = flags.get(name)
@@ -78,7 +78,7 @@ class LocalBackend(val tmpdir: String) extends Backend with BackendWithCodeCache
 
   flags.available // flags can be set after construction from python
 
-  def fs: FS = FS.buildRoutes(None, Some(flags), sys.env)
+  def fs: FS = RouterFS.buildRoutes(CloudStorageFSConfig.fromFlagsAndEnv(None, flags))
 
   override def withExecuteContext[T](f: ExecuteContext => T)(implicit E: Enclosing): T =
     ExecutionTimer.logTime { timer =>
@@ -137,7 +137,7 @@ class LocalBackend(val tmpdir: String) extends Backend with BackendWithCodeCache
 
   def defaultParallelism: Int = 1
 
-  def stop(): Unit = LocalBackend.stop()
+  def close(): Unit = LocalBackend.stop()
 
   private[this] def _jvmLowerAndExecute(
     ctx: ExecuteContext,
diff --git a/hail/src/main/scala/is/hail/backend/service/Main.scala b/hail/src/main/scala/is/hail/backend/service/Main.scala
index 698f5ffa23c..e6c03bd6596 100644
--- a/hail/src/main/scala/is/hail/backend/service/Main.scala
+++ b/hail/src/main/scala/is/hail/backend/service/Main.scala
@@ -3,11 +3,13 @@ package is.hail.backend.service
 object Main {
   val WORKER = "worker"
   val DRIVER = "driver"
+  val TEST = "test"
 
   def main(argv: Array[String]): Unit =
     argv(3) match {
       case WORKER => Worker.main(argv)
       case DRIVER => ServiceBackendAPI.main(argv)
+      case TEST => ()
       case kind => throw new RuntimeException(s"unknown kind: $kind")
     }
 }
diff --git a/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala b/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala
index 07b33488c45..d9e0b40777f 100644
--- a/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala
+++ b/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala
@@ -1,6 +1,6 @@
 package is.hail.backend.service
 
-import is.hail.{CancellingExecutorService, HailContext, HailFeatureFlags}
+import is.hail.{HailContext, HailFeatureFlags}
 import is.hail.annotations._
 import is.hail.asm4s._
 import is.hail.backend._
@@ -14,8 +14,8 @@ import is.hail.expr.ir.functions.IRFunctionRegistry
 import is.hail.expr.ir.lowering._
 import is.hail.io.fs._
 import is.hail.linalg.BlockMatrix
-import is.hail.services._
-import is.hail.services.batch_client.BatchClient
+import is.hail.services.{BatchClient, JobGroupRequest, _}
+import is.hail.services.JobGroupStates.{Cancelled, Failure, Running, Success}
 import is.hail.types._
 import is.hail.types.physical._
 import is.hail.types.physical.stypes.PTypeReferenceSingleCodeType
@@ -28,6 +28,7 @@ import scala.reflect.ClassTag
 
 import java.io._
 import java.nio.charset.StandardCharsets
+import java.nio.file.Path
 import java.util.concurrent._
 
 import org.apache.log4j.Logger
@@ -56,16 +57,21 @@ object ServiceBackend {
     name: String,
     theHailClassLoader: HailClassLoader,
     batchClient: BatchClient,
-    batchId: Option[Long],
-    jobGroupId: Option[Long],
+    batchConfig: BatchConfig,
     scratchDir: String = sys.env.getOrElse("HAIL_WORKER_SCRATCH_DIR", ""),
     rpcConfig: ServiceBackendRPCPayload,
     env: Map[String, String],
   ): ServiceBackend = {
-    val flags = HailFeatureFlags.fromMap(rpcConfig.flags)
+    val flags = HailFeatureFlags.fromEnv(rpcConfig.flags)
     val shouldProfile = flags.get("profile") != null
-    val fs = FS.buildRoutes(Some(s"$scratchDir/secrets/gsa-key/key.json"), Some(flags), env)
+    val fs = RouterFS.buildRoutes(
+      CloudStorageFSConfig.fromFlagsAndEnv(
+        Some(Path.of(scratchDir, "secrets/gsa-key/key.json")),
+        flags,
+        env,
+      )
+    )
 
     val backendContext = new ServiceBackendContext(
       rpcConfig.billing_project,
@@ -80,12 +86,11 @@ object ServiceBackend {
     )
 
     val backend = new ServiceBackend(
-      jarLocation,
+      JarUrl(jarLocation),
       name,
       theHailClassLoader,
       batchClient,
-      batchId,
-      jobGroupId,
+      batchConfig,
       flags,
       rpcConfig.tmp_dir,
       fs,
@@ -109,17 +114,16 @@ object ServiceBackend {
 }
 
 class ServiceBackend(
-  val jarLocation: String,
+  val jarSpec: JarSpec,
   var name: String,
   val theHailClassLoader: HailClassLoader,
   val batchClient: BatchClient,
-  val curBatchId: Option[Long],
-  val curJobGroupId: Option[Long],
+  val batchConfig: BatchConfig,
   val flags: HailFeatureFlags,
   val tmpdir: String,
   val fs: FS,
   val serviceBackendContext: ServiceBackendContext,
-  val scratchDir: String = sys.env.get("HAIL_WORKER_SCRATCH_DIR").getOrElse(""),
+  val scratchDir: String,
 ) extends Backend with BackendWithNoCodeCache {
   import ServiceBackend.log
@@ -152,130 +156,62 @@ class ServiceBackend(
     new String(bytes, StandardCharsets.UTF_8)
   }
 
-  private[this] def submitAndWaitForBatch(
-    _backendContext: BackendContext,
-    fs: FS,
+  private[this] def submitJobGroupAndWait(
+    backendContext: ServiceBackendContext,
     collection: IndexedSeq[Array[Byte]],
+    token: String,
+    root: String,
     stageIdentifier: String,
-    f: (Array[Byte], HailTaskContext, HailClassLoader, FS) => Array[Byte],
-  ): (String, String, Int) = {
-    val backendContext = _backendContext.asInstanceOf[ServiceBackendContext]
-    val n = collection.length
-    val token = 
tokenUrlSafe - val root = s"${backendContext.remoteTmpDir}parallelizeAndComputeWithIndex/$token" - - log.info(s"parallelizeAndComputeWithIndex: $token: nPartitions $n") - log.info(s"parallelizeAndComputeWithIndex: $token: writing f and contexts") - - val uploadFunction = executor.submit[Unit](() => - retryTransientErrors { - fs.writePDOS(s"$root/f") { fos => - using(new ObjectOutputStream(fos))(oos => oos.writeObject(f)) - } - } - ) - - val uploadContexts = executor.submit[Unit](() => - retryTransientErrors { - fs.writePDOS(s"$root/contexts") { os => - var o = 12L * n - collection.foreach { context => - val len = context.length - os.writeLong(o) - os.writeInt(len) - o += len - } - collection.foreach(context => os.write(context)) - } - } - ) - - uploadFunction.get() - uploadContexts.get() - - val parentJobGroup = curJobGroupId.getOrElse(0L) - val jobGroupIdInUpdate = 1 // QoB creates an update for every new stage - val workerJobGroup = JObject( - "job_group_id" -> JInt(jobGroupIdInUpdate), - "absolute_parent_id" -> JInt(parentJobGroup), - "attributes" -> JObject("name" -> JString(stageIdentifier)), - ) - log.info(s"worker job group spec: $workerJobGroup") + ): JobGroupResponse = { + val defaultProcess = + JvmJob( + command = null, + spec = jarSpec, + profile = flags.get("profile") != null, + ) - val jobs = collection.zipWithIndex.map { case (_, i) => - var resources = JObject("preemptible" -> JBool(true)) - if (backendContext.workerCores != "None") { - resources = resources.merge(JObject("cpu" -> JString(backendContext.workerCores))) - } - if (backendContext.workerMemory != "None") { - resources = resources.merge(JObject("memory" -> JString(backendContext.workerMemory))) - } - if (backendContext.storageRequirement != "0Gi") { - resources = - resources.merge(JObject("storage" -> JString(backendContext.storageRequirement))) - } - JObject( - "always_run" -> JBool(false), - "job_id" -> JInt(i + 1), - "in_update_parent_ids" -> JArray(List()), - "in_update_job_group_id" -> JInt(jobGroupIdInUpdate), - "process" -> JObject( - "jar_spec" -> JObject( - "type" -> JString("jar_url"), - "value" -> JString(jarLocation), - ), - "command" -> JArray(List( - JString(Main.WORKER), - JString(root), - JString(s"$i"), - JString(s"$n"), - )), - "type" -> JString("jvm"), - "profile" -> JBool(backendContext.profile), - ), - "attributes" -> JObject( - "name" -> JString(s"${name}_stage${stageCount}_${stageIdentifier}_job$i") - ), - "resources" -> resources, - "regions" -> JArray(backendContext.regions.map(JString).toList), - "cloudfuse" -> JArray(backendContext.cloudfuseConfig.map { config => - JObject( - "bucket" -> JString(config.bucket), - "mount_path" -> JString(config.mount_path), - "read_only" -> JBool(config.read_only), + val defaultJob = + JobRequest( + always_run = false, + process = null, + resources = Some( + JobResources( + preemptible = true, + cpu = Some(backendContext.workerCores).filter(_ != "None"), + memory = Some(backendContext.workerMemory).filter(_ != "None"), + storage = Some(backendContext.storageRequirement).filter(_ != "0Gi"), ) - }.toList), + ), + regions = Some(backendContext.regions).filter(_.nonEmpty), + cloudfuse = Some(backendContext.cloudfuseConfig).filter(_.nonEmpty), ) - } - log.info(s"parallelizeAndComputeWithIndex: $token: running job") - - val (batchId, (updateId, jobGroupId)) = curBatchId match { - case Some(id) => - (id, batchClient.update(id, token, workerJobGroup, jobs)) - case None => - val batchId = batchClient.create( - JObject( - "billing_project" -> 
JString(backendContext.billingProject), - "n_jobs" -> JInt(n), - "token" -> JString(token), - "attributes" -> JObject("name" -> JString(name + "_" + stageCount)), + val jobs = + collection.indices.map { i => + defaultJob.copy( + attributes = Map("name" -> s"${name}_stage${stageCount}_${stageIdentifier}_job$i"), + process = defaultProcess.copy( + command = Array(Main.WORKER, root, s"$i", s"${collection.length}") ), - jobs, ) - (batchId, (1L, 1L)) - } + } - val batch = batchClient.waitForJobGroup(batchId, jobGroupId) + val jobGroupId = + batchClient.newJobGroup( + JobGroupRequest( + batch_id = batchConfig.batchId, + absolute_parent_id = batchConfig.jobGroupId, + token = token, + cancel_after_n_failures = Some(1), + attributes = Map("name" -> stageIdentifier), + jobs = jobs, + ) + ) stageCount += 1 - implicit val formats: Formats = DefaultFormats - val batchState = (batch \ "state").extract[String] - if (batchState == "failed") { - throw new HailBatchFailure(s"Update $updateId for batch $batchId failed") - } - (token, root, n) + Thread.sleep(600) // it is not possible for the batch to be finished in less than 600ms + batchClient.waitForJobGroup(batchConfig.batchId, jobGroupId) } private[this] def readResult(root: String, i: Int): Array[Byte] = { @@ -303,23 +239,75 @@ class ServiceBackend( f: (Array[Byte], HailTaskContext, HailClassLoader, FS) => Array[Byte] ): (Option[Throwable], IndexedSeq[(Array[Byte], Int)]) = { + val backendContext = _backendContext.asInstanceOf[ServiceBackendContext] + + val token = tokenUrlSafe + val root = s"${backendContext.remoteTmpDir}/parallelizeAndComputeWithIndex/$token" + + val uploadFunction = executor.submit[Unit](() => + retryTransientErrors { + fs.writePDOS(s"$root/f") { fos => + using(new ObjectOutputStream(fos))(oos => oos.writeObject(f)) + log.info(s"parallelizeAndComputeWithIndex: $token: uploaded f") + } + } + ) + val (partIdxs, parts) = partitions .map(ps => (ps, ps.map(contexts))) .getOrElse((contexts.indices, contexts)) - val (token, root, _) = - submitAndWaitForBatch(_backendContext, fs, parts, stageIdentifier, f) + val uploadContexts = executor.submit[Unit](() => + retryTransientErrors { + fs.writePDOS(s"$root/contexts") { os => + var o = 12L * parts.length // 12L = sizeof(Long) + sizeof(Int) + parts.foreach { context => + val len = context.length + os.writeLong(o) + os.writeInt(len) + o += len + } + parts.foreach(os.write) + log.info(s"parallelizeAndComputeWithIndex: $token: wrote ${parts.length} contexts") + } + } + ) + + uploadFunction.get() + uploadContexts.get() + + val jobGroup = submitJobGroupAndWait(backendContext, parts, token, root, stageIdentifier) log.info(s"parallelizeAndComputeWithIndex: $token: reading results") val startTime = System.nanoTime() - val r @ (error, results) = runAllKeepFirstError(new CancellingExecutorService(executor)) { + var r @ (err, results) = runAll[Option, Array[Byte]](executor) { + /* A missing file means the job was cancelled because another job failed. Assumes that if any + * job was cancelled, then at least one job failed. We want to ignore the missing file + * exceptions and return one of the actual failure exceptions. 
*/ + case (opt, _: FileNotFoundException) => opt + case (opt, e) => opt.orElse(Some(e)) + }(None) { (partIdxs, parts.indices).zipped.map { (partIdx, jobIndex) => (() => readResult(root, jobIndex), partIdx) } } + if (jobGroup.state != Success && err.isEmpty) { + assert(jobGroup.state != Running) + val error = + jobGroup.state match { + case Failure => + new HailBatchFailure( + s"Job group ${jobGroup.job_group_id} for batch ${batchConfig.batchId} failed with an unknown error" + ) + case Cancelled => + new CancellationException( + s"Job group ${jobGroup.job_group_id} for batch ${batchConfig.batchId} was cancelled" + ) + } - error.foreach(throw _) + r = (Some(error), results) + } val resultsReadingSeconds = (System.nanoTime() - startTime) / 1000000000.0 val rate = results.length / resultsReadingSeconds @@ -328,8 +316,10 @@ class ServiceBackend( r } - def stop(): Unit = + override def close(): Unit = { executor.shutdownNow() + batchClient.close() + } override def execute(ctx: ExecuteContext, ir: IR): Either[Unit, (PTuple, Long)] = ctx.time { @@ -438,23 +428,25 @@ object ServiceBackendAPI { val inputURL = argv(5) val outputURL = argv(6) - val fs = FS.buildRoutes(Some(s"$scratchDir/secrets/gsa-key/key.json"), None, sys.env) - val deployConfig = DeployConfig.fromConfigFile( - s"$scratchDir/secrets/deploy-config/deploy-config.json" + implicit val formats: Formats = DefaultFormats + + val fs = RouterFS.buildRoutes( + CloudStorageFSConfig.fromFlagsAndEnv( + Some(Path.of(scratchDir, "secrets/gsa-key/key.json")), + HailFeatureFlags.fromEnv(), + ) ) + val deployConfig = DeployConfig.fromConfigFile("/deploy-config/deploy-config.json") DeployConfig.set(deployConfig) sys.env.get("HAIL_SSL_CONFIG_DIR").foreach(tls.setSSLConfigFromDir) - val batchClient = new BatchClient(s"$scratchDir/secrets/gsa-key/key.json") + val batchClient = BatchClient(deployConfig, Path.of(scratchDir, "secrets/gsa-key/key.json")) log.info("BatchClient allocated.") - val batchConfig = BatchConfig.fromConfigFile(s"$scratchDir/batch-config/batch-config.json") - val batchId = batchConfig.map(_.batchId) - val jobGroupId = batchConfig.map(_.jobGroupId) + val batchConfig = + BatchConfig.fromConfigFile(Path.of(scratchDir, "batch-config/batch-config.json")) log.info("BatchConfig parsed.") - implicit val formats: Formats = DefaultFormats - val input = using(fs.openNoCompression(inputURL))(JsonMethods.parse(_)) val rpcConfig = (input \ "config").extract[ServiceBackendRPCPayload] @@ -464,8 +456,7 @@ object ServiceBackendAPI { name, new HailClassLoader(getClass().getClassLoader()), batchClient, - batchId, - jobGroupId, + batchConfig, scratchDir, rpcConfig, sys.env, @@ -518,8 +509,6 @@ private class HailSocketAPIOutputStream( } } -case class CloudfuseConfig(bucket: String, mount_path: String, read_only: Boolean) - case class SequenceConfig(fasta: String, index: String) case class ServiceBackendRPCPayload( diff --git a/hail/src/main/scala/is/hail/backend/service/Worker.scala b/hail/src/main/scala/is/hail/backend/service/Worker.scala index 95067e2d1fb..5722e3bc435 100644 --- a/hail/src/main/scala/is/hail/backend/service/Worker.scala +++ b/hail/src/main/scala/is/hail/backend/service/Worker.scala @@ -1,6 +1,6 @@ package is.hail.backend.service -import is.hail.{HAIL_REVISION, HailContext} +import is.hail.{HAIL_REVISION, HailContext, HailFeatureFlags} import is.hail.asm4s._ import is.hail.backend.HailTaskContext import is.hail.io.fs._ @@ -14,6 +14,7 @@ import scala.util.control.NonFatal import java.io._ import java.nio.charset._ +import 
java.nio.file.Path import java.util import java.util.{concurrent => javaConcurrent} @@ -113,9 +114,7 @@ object Worker { val n = argv(6).toInt val timer = new WorkerTimer() - val deployConfig = DeployConfig.fromConfigFile( - s"$scratchDir/secrets/deploy-config/deploy-config.json" - ) + val deployConfig = DeployConfig.fromConfigFile("/deploy-config/deploy-config.json") DeployConfig.set(deployConfig) sys.env.get("HAIL_SSL_CONFIG_DIR").foreach(tls.setSSLConfigFromDir) @@ -125,7 +124,12 @@ object Worker { timer.start(s"Job $i/$n") timer.start("readInputs") - val fs = FS.buildRoutes(Some(s"$scratchDir/secrets/gsa-key/key.json"), None, sys.env) + val fs = RouterFS.buildRoutes( + CloudStorageFSConfig.fromFlagsAndEnv( + Some(Path.of(scratchDir, "secrets/gsa-key/key.json")), + HailFeatureFlags.fromEnv(), + ) + ) def open(x: String): SeekableDataInputStream = fs.openNoCompression(x) @@ -162,19 +166,18 @@ object Worker { timer.end("readInputs") timer.start("executeFunction") - if (HailContext.isInitialized) { HailContext.get.backend = new ServiceBackend( null, null, new HailClassLoader(getClass().getClassLoader()), null, - None, - None, null, null, null, null, + null, + scratchDir, ) } else { HailContext( @@ -184,12 +187,12 @@ object Worker { null, new HailClassLoader(getClass().getClassLoader()), null, - None, - None, null, null, null, null, + null, + scratchDir, ) ) } diff --git a/hail/src/main/scala/is/hail/backend/spark/SparkBackend.scala b/hail/src/main/scala/is/hail/backend/spark/SparkBackend.scala index dfa71427bce..2b4720cd10e 100644 --- a/hail/src/main/scala/is/hail/backend/spark/SparkBackend.scala +++ b/hail/src/main/scala/is/hail/backend/spark/SparkBackend.scala @@ -332,7 +332,7 @@ class SparkBackend( val bmCache: SparkBlockMatrixCache = SparkBlockMatrixCache() - private[this] val flags = HailFeatureFlags.fromMap(sys.env) + private[this] val flags = HailFeatureFlags.fromEnv() def getFlag(name: String): String = flags.get(name) @@ -472,7 +472,10 @@ class SparkBackend( override def asSpark(op: String): SparkBackend = this - def stop(): Unit = SparkBackend.stop() + def close(): Unit = { + SparkBackend.stop() + longLifeTempFileManager.close() + } def startProgressBar(): Unit = ProgressBarBuilder.build(sc) @@ -761,9 +764,6 @@ class SparkBackend( RVDTableReader(RVD.unkeyed(rowPType, orderedCRDD), globalsLit, rt) } - def close(): Unit = - longLifeTempFileManager.close() - def tableToTableStage(ctx: ExecuteContext, inputIR: TableIR, analyses: LoweringAnalyses) : TableStage = { CanLowerEfficiently(ctx, inputIR) match { diff --git a/hail/src/main/scala/is/hail/io/fs/AzureStorageFS.scala b/hail/src/main/scala/is/hail/io/fs/AzureStorageFS.scala index 61444d1cdb7..2e7eb2b8672 100644 --- a/hail/src/main/scala/is/hail/io/fs/AzureStorageFS.scala +++ b/hail/src/main/scala/is/hail/io/fs/AzureStorageFS.scala @@ -1,12 +1,10 @@ package is.hail.io.fs import is.hail.io.fs.FSUtil.dropTrailingSlash +import is.hail.services.oauth2.AzureCloudCredentials import is.hail.services.retryTransientErrors import is.hail.shadedazure.com.azure.core.credential.AzureSasCredential import is.hail.shadedazure.com.azure.core.util.HttpClientOptions -import is.hail.shadedazure.com.azure.identity.{ - ClientSecretCredentialBuilder, DefaultAzureCredentialBuilder, -} import is.hail.shadedazure.com.azure.storage.blob.{ BlobClient, BlobContainerClient, BlobServiceClient, BlobServiceClientBuilder, } @@ -14,19 +12,16 @@ import is.hail.shadedazure.com.azure.storage.blob.models.{ BlobItem, BlobRange, BlobStorageException, ListBlobsOptions, } 
import is.hail.shadedazure.com.azure.storage.blob.specialized.BlockBlobClient -import is.hail.utils._ +import is.hail.utils.FastSeq import scala.collection.JavaConverters._ import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import java.io.{ByteArrayOutputStream, FileNotFoundException, OutputStream} -import java.nio.file.Paths +import java.nio.file.{Path, Paths} import java.time.Duration -import org.json4s.Formats -import org.json4s.jackson.JsonMethods - class AzureStorageFSURL( val account: String, val container: String, @@ -54,20 +49,19 @@ class AzureStorageFSURL( prefix + pathPart } - override def toString(): String = { + override def toString: String = { val sasTokenPart = sasToken.getOrElse("") this.base + sasTokenPart } } object AzureStorageFS { - object EnvVars { - val AzureApplicationCredentials = "AZURE_APPLICATION_CREDENTIALS" - } - private val AZURE_HTTPS_URI_REGEX = "^https:\\/\\/([a-z0-9_\\-\\.]+)\\.blob\\.core\\.windows\\.net\\/([a-z0-9_\\-\\.]+)(\\/.*)?".r + val RequiredOAuthScopes: IndexedSeq[String] = + FastSeq("https://storage.azure.com/.default") + def parseUrl(filename: String): AzureStorageFSURL = { AZURE_HTTPS_URI_REGEX .findFirstMatchIn(filename) @@ -120,48 +114,30 @@ object AzureStorageFileListEntry { new BlobStorageFileListEntry(url.toString, null, 0, true) } -class AzureStorageFS(val credentialsJSON: Option[String] = None) extends FS { +case class AzureStorageFSConfig(credentials_file: Option[Path]) + +class AzureStorageFS(val credential: AzureCloudCredentials) extends FS { type URL = AzureStorageFSURL private[this] lazy val clients = mutable.Map[(String, String, Option[String]), BlobServiceClient]() - private lazy val credential = credentialsJSON match { - case None => - new DefaultAzureCredentialBuilder().build() - case Some(keyData) => - implicit val formats: Formats = defaultJSONFormats - val kvs = JsonMethods.parse(keyData) - val appId = (kvs \ "appId").extract[String] - val password = (kvs \ "password").extract[String] - val tenant = (kvs \ "tenant").extract[String] - - new ClientSecretCredentialBuilder() - .clientId(appId) - .clientSecret(password) - .tenantId(tenant) - .build() - } - def getServiceClient(url: URL): BlobServiceClient = { val k = (url.account, url.container, url.sasToken) - - clients.get(k) match { - case Some(client) => client - case None => + clients.getOrElseUpdate( + k, { val clientBuilder = url.sasToken match { case Some(sasToken) => new BlobServiceClientBuilder().credential(new AzureSasCredential(sasToken)) - case None => new BlobServiceClientBuilder().credential(credential) + case None => new BlobServiceClientBuilder().credential(credential.value) } - val blobServiceClient = clientBuilder + clientBuilder .endpoint(s"https://${url.account}.blob.core.windows.net") .clientOptions(httpClientOptions) .buildClient() - clients += (k -> blobServiceClient) - blobServiceClient - } + }, + ) } def setPublicAccessServiceClient(url: AzureStorageFSURL): Unit = { diff --git a/hail/src/main/scala/is/hail/io/fs/FS.scala b/hail/src/main/scala/is/hail/io/fs/FS.scala index cd0489ef9e9..f7de8037b74 100644 --- a/hail/src/main/scala/is/hail/io/fs/FS.scala +++ b/hail/src/main/scala/is/hail/io/fs/FS.scala @@ -1,11 +1,9 @@ package is.hail.io.fs -import is.hail.{HailContext, HailFeatureFlags} +import is.hail.HailContext import is.hail.backend.BroadcastValue import is.hail.io.compress.{BGzipInputStream, BGzipOutputStream} -import is.hail.io.fs.AzureStorageFS.EnvVars.AzureApplicationCredentials import is.hail.io.fs.FSUtil.{containsWildcard, 
dropTrailingSlash} -import is.hail.io.fs.GoogleStorageFS.EnvVars.GoogleApplicationCredentials import is.hail.services._ import is.hail.utils._ @@ -14,14 +12,12 @@ import scala.io.Source import java.io._ import java.nio.ByteBuffer -import java.nio.charset._ import java.nio.file.FileSystems import java.util.zip.GZIPOutputStream import org.apache.commons.compress.compressors.gzip.GzipCompressorInputStream import org.apache.commons.io.IOUtils import org.apache.hadoop -import org.apache.log4j.Logger class WrappedSeekableDataInputStream(is: SeekableInputStream) extends DataInputStream(is) with Seekable { @@ -258,56 +254,9 @@ abstract class FSPositionedOutputStream(val capacity: Int) extends OutputStream def getPosition: Long = pos } -object FS { - def buildRoutes( - credentialsPath: Option[String], - flags: Option[HailFeatureFlags], - env: Map[String, String], - ): FS = - retryTransientErrors { - - def readString(path: String): String = - using(new FileInputStream(path))(is => IOUtils.toString(is, Charset.defaultCharset())) - - def gcs = new GoogleStorageFS( - credentialsPath.orElse(sys.env.get(GoogleApplicationCredentials)).map(readString), - flags.flatMap(RequesterPaysConfig.fromFlags), - ) - - def az = env.get("HAIL_TERRA") match { - case Some(_) => new TerraAzureStorageFS() - case None => new AzureStorageFS( - credentialsPath.orElse(sys.env.get(AzureApplicationCredentials)).map(readString) - ) - } - - val cloudSpecificFSs = env.get("HAIL_CLOUD") match { - case Some("gcp") => FastSeq(gcs) - case Some("azure") => FastSeq(az) - case Some(cloud) => - throw new IllegalArgumentException(s"Unknown cloud provider: '$cloud'.'") - case None => - if (credentialsPath.isEmpty) FastSeq(gcs, az) - else fatal( - "Don't know to which cloud credentials belong because 'HAIL_CLOUD' was not set." 
- ) - } - - new RouterFS( - cloudSpecificFSs :+ new HadoopFS( - new SerializableHadoopConfiguration(new hadoop.conf.Configuration()) - ) - ) - } - - private val log = Logger.getLogger(getClass.getName()) -} - -trait FS extends Serializable { +trait FS extends Serializable with Logging { type URL <: FSURL - import FS.log - def parseUrl(filename: String): URL def validUrl(filename: String): Boolean diff --git a/hail/src/main/scala/is/hail/io/fs/GoogleStorageFS.scala b/hail/src/main/scala/is/hail/io/fs/GoogleStorageFS.scala index a2aa439ebb7..dfcb81d30ec 100644 --- a/hail/src/main/scala/is/hail/io/fs/GoogleStorageFS.scala +++ b/hail/src/main/scala/is/hail/io/fs/GoogleStorageFS.scala @@ -2,24 +2,24 @@ package is.hail.io.fs import is.hail.HailFeatureFlags import is.hail.io.fs.FSUtil.dropTrailingSlash +import is.hail.io.fs.GoogleStorageFS.RequesterPaysFailure import is.hail.services.{isTransientError, retryTransientErrors} +import is.hail.services.oauth2.GoogleCloudCredentials import is.hail.utils._ import scala.jdk.CollectionConverters._ -import java.io.{ByteArrayInputStream, FileNotFoundException, IOException} +import java.io.{FileNotFoundException, IOException} import java.nio.ByteBuffer -import java.nio.file.Paths +import java.nio.file.{Path, Paths} import com.google.api.client.googleapis.json.GoogleJsonResponseException -import com.google.auth.oauth2.ServiceAccountCredentials import com.google.cloud.{ReadChannel, WriteChannel} import com.google.cloud.http.HttpTransportOptions -import com.google.cloud.storage.{Blob, BlobId, BlobInfo, Storage, StorageException, StorageOptions} +import com.google.cloud.storage.{Option => _, _} import com.google.cloud.storage.Storage.{ BlobGetOption, BlobListOption, BlobSourceOption, BlobWriteOption, } -import org.apache.log4j.Logger case class GoogleStorageFSURL(bucket: String, path: String) extends FSURL { def addPathComponent(c: String): GoogleStorageFSURL = @@ -41,13 +41,12 @@ case class GoogleStorageFSURL(bucket: String, path: String) extends FSURL { } object GoogleStorageFS { - object EnvVars { - val GoogleApplicationCredentials = "GOOGLE_APPLICATION_CREDENTIALS" - } - private val log = Logger.getLogger(getClass.getName()) private[this] val GCS_URI_REGEX = "^gs:\\/\\/([a-z0-9_\\-\\.]+)(\\/.*)?".r + val RequiredOAuthScopes: IndexedSeq[String] = + FastSeq("https://www.googleapis.com/auth/devstorage.read_write") + def parseUrl(filename: String): GoogleStorageFSURL = { val scheme = filename.split(":")(0) if (scheme == null || scheme != "gs") { @@ -65,6 +64,32 @@ object GoogleStorageFS { ) } } + + object RequesterPaysFailure { + def unapply(t: Throwable): Option[Throwable] = + Some(t).filter { + case e: IOException => + Option(e.getCause).exists { + case RequesterPaysFailure(_) => true + case _ => false + } + + case exc: StorageException => + Option(exc.getMessage).exists { message => + message == "userProjectMissing" || + (exc.getCode == 400 && message.contains("requester pays")) + } + + case exc: GoogleJsonResponseException => + Option(exc.getMessage).exists { message => + message == "userProjectMissing" || + (exc.getStatusCode == 400 && message.contains("requester pays")) + } + + case _ => + false + } + } } object GoogleStorageFileListEntry { @@ -86,6 +111,8 @@ object GoogleStorageFileListEntry { new BlobStorageFileListEntry(url.toString, null, 0, true) } +case class RequesterPaysConfig(project: String, buckets: Option[Set[String]]) + object RequesterPaysConfig { object Flags { val RequesterPaysProject = "gcs_requester_pays_project" @@ -107,17 +134,17 @@ 
object RequesterPaysConfig { } } -case class RequesterPaysConfig(project: String, buckets: Option[Set[String]] = None) - extends Serializable +case class GoogleStorageFSConfig( + credentials_file: Option[Path], + requester_pays_config: Option[RequesterPaysConfig], +) class GoogleStorageFS( - private[this] val serviceAccountKey: Option[String] = None, - private[this] var requesterPaysConfig: Option[RequesterPaysConfig] = None, + private[this] val credentials: GoogleCloudCredentials, + private[this] var requesterPaysConfig: Option[RequesterPaysConfig], ) extends FS { type URL = GoogleStorageFSURL - import GoogleStorageFS.log - override def parseUrl(filename: String): URL = GoogleStorageFS.parseUrl(filename) override def validUrl(filename: String): Boolean = @@ -140,32 +167,6 @@ class GoogleStorageFS( Seq() } - object RequesterPaysFailure { - def unapply(t: Throwable): Option[Throwable] = - Some(t).filter { - case e: IOException => - Option(e.getCause).exists { - case RequesterPaysFailure(_) => true - case _ => false - } - - case exc: StorageException => - Option(exc.getMessage).exists { message => - message == "userProjectMissing" || - (exc.getCode == 400 && message.contains("requester pays")) - } - - case exc: GoogleJsonResponseException => - Option(exc.getMessage).exists { message => - message == "userProjectMissing" || - (exc.getStatusCode == 400 && message.contains("requester pays")) - } - - case _ => - false - } - } - private[this] def handleRequesterPays[T, U]( makeRequest: Seq[U] => T, makeUserProjectOption: String => U, @@ -185,23 +186,12 @@ class GoogleStorageFS( .setConnectTimeout(5000) .setReadTimeout(5000) .build() - serviceAccountKey match { - case None => - log.info("Initializing google storage client from latent credentials") - StorageOptions.newBuilder() - .setTransportOptions(transportOptions) - .build() - .getService - case Some(keyData) => - log.info("Initializing google storage client from service account key") - StorageOptions.newBuilder() - .setCredentials( - ServiceAccountCredentials.fromStream(new ByteArrayInputStream(keyData.getBytes)) - ) - .setTransportOptions(transportOptions) - .build() - .getService - } + + StorageOptions.newBuilder() + .setTransportOptions(transportOptions) + .setCredentials(credentials.value) + .build() + .getService } def openNoCompression(url: URL): SeekableDataInputStream = retryTransientErrors { diff --git a/hail/src/main/scala/is/hail/io/fs/RouterFS.scala b/hail/src/main/scala/is/hail/io/fs/RouterFS.scala index 7d8e9df3c52..2a7ab03fe37 100644 --- a/hail/src/main/scala/is/hail/io/fs/RouterFS.scala +++ b/hail/src/main/scala/is/hail/io/fs/RouterFS.scala @@ -1,5 +1,14 @@ package is.hail.io.fs +import is.hail.HailFeatureFlags +import is.hail.services.oauth2.{AzureCloudCredentials, GoogleCloudCredentials} +import is.hail.utils.{FastSeq, SerializableHadoopConfiguration} + +import java.io.Serializable +import java.nio.file.Path + +import org.apache.hadoop.conf.Configuration + object RouterFSURL { def apply(fs: FS)(_url: fs.URL): RouterFSURL = RouterFSURL(_url, fs) } @@ -15,6 +24,52 @@ case class RouterFSURL private (_url: FSURL, val fs: FS) extends FSURL { override def toString(): String = url.toString } +case class CloudStorageFSConfig( + azure: Option[AzureStorageFSConfig] = None, + google: Option[GoogleStorageFSConfig] = None, +) extends Serializable + +object CloudStorageFSConfig { + def fromFlagsAndEnv( + credentialsFile: Option[Path], + flags: HailFeatureFlags, + env: Map[String, String] = sys.env, + ): CloudStorageFSConfig = { + 
env.get("HAIL_CLOUD") match { + case Some("azure") => + CloudStorageFSConfig(azure = Some(AzureStorageFSConfig(credentialsFile))) + case Some("gcp") | None => + val rpConf = RequesterPaysConfig.fromFlags(flags) + CloudStorageFSConfig(google = Some(GoogleStorageFSConfig(credentialsFile, rpConf))) + case _ => + CloudStorageFSConfig() + } + } +} + +object RouterFS { + + def buildRoutes(cloudConfig: CloudStorageFSConfig, env: Map[String, String] = sys.env): FS = + new RouterFS( + IndexedSeq.concat( + cloudConfig.google.map { case GoogleStorageFSConfig(path, mRPConfig) => + new GoogleStorageFS( + GoogleCloudCredentials(path, GoogleStorageFS.RequiredOAuthScopes, env), + mRPConfig, + ) + }, + cloudConfig.azure.map { case AzureStorageFSConfig(path) => + if (env.contains("HAIL_TERRA")) { + val creds = AzureCloudCredentials(path, TerraAzureStorageFS.RequiredOAuthScopes, env) + new TerraAzureStorageFS(creds) + } else + new AzureStorageFS(AzureCloudCredentials(path, AzureStorageFS.RequiredOAuthScopes, env)) + }, + FastSeq(new HadoopFS(new SerializableHadoopConfiguration(new Configuration()))), + ) + ) +} + class RouterFS(fss: IndexedSeq[FS]) extends FS { type URL = RouterFSURL diff --git a/hail/src/main/scala/is/hail/io/fs/TerraAzureStorageFS.scala b/hail/src/main/scala/is/hail/io/fs/TerraAzureStorageFS.scala index 4078bf30603..a073e9a0b38 100644 --- a/hail/src/main/scala/is/hail/io/fs/TerraAzureStorageFS.scala +++ b/hail/src/main/scala/is/hail/io/fs/TerraAzureStorageFS.scala @@ -1,9 +1,6 @@ package is.hail.io.fs -import is.hail.shadedazure.com.azure.core.credential.TokenRequestContext -import is.hail.shadedazure.com.azure.identity.{ - DefaultAzureCredential, DefaultAzureCredentialBuilder, -} +import is.hail.services.oauth2.AzureCloudCredentials import is.hail.shadedazure.com.azure.storage.blob.BlobServiceClient import is.hail.utils._ @@ -13,17 +10,18 @@ import org.apache.http.client.methods.HttpPost import org.apache.http.client.utils.URIBuilder import org.apache.http.impl.client.HttpClients import org.apache.http.util.EntityUtils -import org.apache.log4j.Logger import org.json4s.{DefaultFormats, Formats} import org.json4s.jackson.JsonMethods object TerraAzureStorageFS { - private val log = Logger.getLogger(getClass.getName) private val TEN_MINUTES_IN_MS = 10 * 60 * 1000 + + val RequiredOAuthScopes: IndexedSeq[String] = + FastSeq("https://management.azure.com/.default") } -class TerraAzureStorageFS extends AzureStorageFS() { - import TerraAzureStorageFS.{log, TEN_MINUTES_IN_MS} +class TerraAzureStorageFS(credential: AzureCloudCredentials) extends AzureStorageFS(credential) { + import TerraAzureStorageFS.TEN_MINUTES_IN_MS private[this] val httpClient = HttpClients.custom().build() private[this] val sasTokenCache = mutable.Map[String, (URL, Long)]() @@ -33,8 +31,6 @@ class TerraAzureStorageFS extends AzureStorageFS() { private[this] val containerResourceId = sys.env("WORKSPACE_STORAGE_CONTAINER_ID") private[this] val storageContainerUrl = parseUrl(sys.env("WORKSPACE_STORAGE_CONTAINER_URL")) - private[this] val credential: DefaultAzureCredential = new DefaultAzureCredentialBuilder().build() - override def getServiceClient(url: URL): BlobServiceClient = if (blobInWorkspaceStorageContainer(url)) { super.getServiceClient(getTerraSasToken(url)) @@ -59,14 +55,10 @@ class TerraAzureStorageFS extends AzureStorageFS() { private def createTerraSasToken(): (URL, Long) = { implicit val formats: Formats = DefaultFormats - val context = new TokenRequestContext() - 
context.addScopes("https://management.azure.com/.default")
-    val token = credential.getToken(context).block().getToken()
-
     val url =
       s"$workspaceManagerUrl/api/workspaces/v1/$workspaceId/resources/controlled/azure/storageContainer/$containerResourceId/getSasToken"
     val req = new HttpPost(url)
-    req.addHeader("Authorization", s"Bearer $token")
+    req.addHeader("Authorization", s"Bearer ${credential.accessToken}")
 
     val tenHoursInSeconds = 10 * 3600
     val expiration = System.currentTimeMillis() + tenHoursInSeconds * 1000
diff --git a/hail/src/main/scala/is/hail/services/BatchClient.scala b/hail/src/main/scala/is/hail/services/BatchClient.scala
new file mode 100644
index 00000000000..8b956b7dc6e
--- /dev/null
+++ b/hail/src/main/scala/is/hail/services/BatchClient.scala
@@ -0,0 +1,334 @@
+package is.hail.services
+
+import is.hail.expr.ir.ByteArrayBuilder
+import is.hail.services.BatchClient.BunchMaxSizeBytes
+import is.hail.services.oauth2.CloudCredentials
+import is.hail.services.requests.Requester
+import is.hail.utils._
+
+import scala.util.Random
+
+import java.net.URL
+import java.nio.charset.StandardCharsets.UTF_8
+import java.nio.file.Path
+
+import org.apache.http.entity.ByteArrayEntity
+import org.apache.http.entity.ContentType.APPLICATION_JSON
+import org.json4s.{
+  CustomSerializer, DefaultFormats, Extraction, Formats, JInt, JNull, JObject, JString,
+}
+import org.json4s.JsonAST.{JArray, JBool}
+import org.json4s.jackson.JsonMethods
+
+case class BatchRequest(
+  billing_project: String,
+  token: String,
+  n_jobs: Int,
+  attributes: Map[String, String] = Map.empty,
+)
+
+case class JobGroupRequest(
+  batch_id: Int,
+  absolute_parent_id: Int,
+  token: String,
+  cancel_after_n_failures: Option[Int] = None,
+  attributes: Map[String, String] = Map.empty,
+  jobs: IndexedSeq[JobRequest] = FastSeq(),
+)
+
+case class JobRequest(
+  always_run: Boolean,
+  process: JobProcess,
+  attributes: Map[String, String] = Map.empty,
+  cloudfuse: Option[Array[CloudfuseConfig]] = None,
+  resources: Option[JobResources] = None,
+  regions: Option[Array[String]] = None,
+)
+
+sealed trait JobProcess
+case class BashJob(image: String, command: Array[String]) extends JobProcess
+case class JvmJob(command: Array[String], spec: JarSpec, profile: Boolean) extends JobProcess
+
+sealed trait JarSpec
+case class GitRevision(sha: String) extends JarSpec
+case class JarUrl(url: String) extends JarSpec
+
+case class JobResources(
+  preemptible: Boolean,
+  cpu: Option[String] = None,
+  memory: Option[String] = None,
+  storage: Option[String] = None,
+)
+
+case class CloudfuseConfig(
+  bucket: String,
+  mount_path: String,
+  read_only: Boolean,
+)
+
+case class JobGroupResponse(
+  batch_id: Int,
+  job_group_id: Int,
+  state: JobGroupState,
+  complete: Boolean,
+  n_jobs: Int,
+  n_completed: Int,
+  n_succeeded: Int,
+  n_failed: Int,
+  n_cancelled: Int,
+)
+
+sealed trait JobGroupState extends Product with Serializable
+
+object JobGroupStates {
+  case object Failure extends JobGroupState
+  case object Cancelled extends JobGroupState
+  case object Success extends JobGroupState
+  case object Running extends JobGroupState
+}
+
+object BatchClient {
+
+  private[this] def BatchServiceScopes(env: Map[String, String]): Array[String] =
+    env.get("HAIL_CLOUD") match {
+      case Some("gcp") =>
+        Array(
+          "https://www.googleapis.com/auth/userinfo.profile",
+          "https://www.googleapis.com/auth/userinfo.email",
+          "openid",
+        )
+      case Some("azure") =>
+        env.get("HAIL_AZURE_OAUTH_SCOPE").toArray
+      case Some(cloud) =>
+        throw new IllegalArgumentException(s"Unknown cloud: '$cloud'.")
+      case None =>
+        throw new IllegalArgumentException(s"HAIL_CLOUD must be set.")
+    }
+
+  def apply(deployConfig: DeployConfig, credentialsFile: Path, env: Map[String, String] = sys.env)
+    : BatchClient =
+    new BatchClient(Requester(
+      new URL(deployConfig.baseUrl("batch")),
+      CloudCredentials(credentialsFile, BatchServiceScopes(env), env),
+    ))
+
+  private val BunchMaxSizeBytes: Int = 1024 * 1024
+}
+
+case class BatchClient private (req: Requester) extends Logging with AutoCloseable {
+
+  implicit private[this] val fmts: Formats =
+    DefaultFormats +
+      JobProcessRequestSerializer +
+      JobGroupStateDeserializer +
+      JobGroupResponseDeserializer +
+      JarSpecSerializer
+
+  def newBatch(createRequest: BatchRequest): Int = {
+    val response = req.post("/api/v1alpha/batches/create", Extraction.decompose(createRequest))
+    val batchId = (response \ "id").extract[Int]
+    log.info(s"Created batch $batchId")
+    batchId
+  }
+
+  def newJobGroup(req: JobGroupRequest): Int = {
+    val nJobs = req.jobs.length
+    val (updateId, startJobGroupId) = beginUpdate(req.batch_id, req.token, nJobs)
+    log.info(s"Began update '$updateId' for batch '${req.batch_id}'.")
+
+    createJobGroup(updateId, req)
+    log.info(s"Created job group $startJobGroupId for batch ${req.batch_id}")
+
+    createJobsIncremental(req.batch_id, updateId, req.jobs)
+    log.info(s"Submitted $nJobs in job group $startJobGroupId for batch ${req.batch_id}")
+
+    commitUpdate(req.batch_id, updateId)
+    log.info(s"Committed update $updateId for batch ${req.batch_id}.")
+
+    startJobGroupId
+  }
+
+  def getJobGroup(batchId: Int, jobGroupId: Int): JobGroupResponse =
+    req
+      .get(s"/api/v1alpha/batches/$batchId/job-groups/$jobGroupId")
+      .extract[JobGroupResponse]
+
+  def waitForJobGroup(batchId: Int, jobGroupId: Int): JobGroupResponse = {
+    val start = System.nanoTime()
+
+    while (true) {
+      val jobGroup = getJobGroup(batchId, jobGroupId)
+
+      if (jobGroup.complete)
+        return jobGroup
+
+      // wait 10% of duration so far
+      // at least, 50ms
+      // at most, 5s
+      val now = System.nanoTime()
+      val elapsed = now - start
+      val d = math.max(
+        math.min(
+          (0.1 * (0.8 + Random.nextFloat() * 0.4) * (elapsed / 1000.0 / 1000)).toInt,
+          5000,
+        ),
+        50,
+      )
+      Thread.sleep(d)
+    }
+
+    throw new AssertionError("unreachable")
+  }
+
+  override def close(): Unit =
+    req.close()
+
+  private[this] def createJobsIncremental(
+    batchId: Int,
+    updateId: Int,
+    jobs: IndexedSeq[JobRequest],
+  ): Unit = {
+    val buff = new ByteArrayBuilder(BunchMaxSizeBytes)
+    var sym = "["
+
+    def flush(): Unit = {
+      buff ++= "]".getBytes(UTF_8)
+      req.post(
+        s"/api/v1alpha/batches/$batchId/updates/$updateId/jobs/create",
+        new ByteArrayEntity(buff.result(), APPLICATION_JSON),
+      )
+      buff.clear()
+      sym = "["
+    }
+
+    for ((job, idx) <- jobs.zipWithIndex) {
+      val jobPayload = jobToJson(job, idx).getBytes(UTF_8)
+
+      if (buff.size + jobPayload.length > BunchMaxSizeBytes) {
+        flush()
+      }
+
+      buff ++= sym.getBytes(UTF_8)
+      buff ++= jobPayload
+      sym = ","
+    }
+
+    if (buff.size > 0) { flush() }
+  }
+
+  private[this] def jobToJson(j: JobRequest, jobIdx: Int): String =
+    JsonMethods.compact {
+      Extraction.decompose(j)
+        .asInstanceOf[JObject]
+        .merge(
+          JObject(
+            "job_id" -> JInt(jobIdx + 1),
+            "in_update_job_group_id" -> JInt(1),
+          )
+        )
+    }
+
+  private[this] def beginUpdate(batchId: Int, token: String, nJobs: Int): (Int, Int) =
+    req
+      .post(
+        s"/api/v1alpha/batches/$batchId/updates/create",
+        JObject(
+          "token" -> JString(token),
+          "n_jobs" -> JInt(nJobs),
+          "n_job_groups" -> JInt(1),
+        ),
+      )
+      .as { case obj: JObject =>
+        (
+          (obj \ "update_id").extract[Int],
+          (obj \ "start_job_group_id").extract[Int],
+        )
+      }
+
+  private[this] def commitUpdate(batchId: Int, updateId: Int): Unit =
+    req.patch(s"/api/v1alpha/batches/$batchId/updates/$updateId/commit")
+
+  private[this] def createJobGroup(updateId: Int, jobGroup: JobGroupRequest): Unit =
+    req.post(
+      s"/api/v1alpha/batches/${jobGroup.batch_id}/updates/$updateId/job-groups/create",
+      JArray(List(
+        JObject(
+          "job_group_id" -> JInt(1), // job group id relative to the update
+          "absolute_parent_id" -> JInt(jobGroup.absolute_parent_id),
+          "cancel_after_n_failures" -> jobGroup.cancel_after_n_failures.map(JInt(_)).getOrElse(
+            JNull
+          ),
+          "attributes" -> Extraction.decompose(jobGroup.attributes),
+        )
+      )),
+    )
+
+  private[this] object JobProcessRequestSerializer
+      extends CustomSerializer[JobProcess](_ =>
+        (
+          PartialFunction.empty,
+          {
+            case BashJob(image, command) =>
+              JObject(
+                "type" -> JString("docker"),
+                "image" -> JString(image),
+                "command" -> JArray(command.map(JString).toList),
+              )
+            case JvmJob(command, jarSpec, profile) =>
+              JObject(
+                "type" -> JString("jvm"),
+                "command" -> JArray(command.map(JString).toList),
+                "jar_spec" -> Extraction.decompose(jarSpec),
+                "profile" -> JBool(profile),
+              )
+          },
+        )
+      )
+
+  private[this] object JobGroupStateDeserializer
+      extends CustomSerializer[JobGroupState](_ =>
+        (
+          {
+            case JString("failure") => JobGroupStates.Failure
+            case JString("cancelled") => JobGroupStates.Cancelled
+            case JString("success") => JobGroupStates.Success
+            case JString("running") => JobGroupStates.Running
+          },
+          PartialFunction.empty,
+        )
+      )
+
+  private[this] object JobGroupResponseDeserializer
+      extends CustomSerializer[JobGroupResponse](implicit fmts =>
+        (
+          {
+            case o: JObject =>
+              JobGroupResponse(
+                batch_id = (o \ "batch_id").extract[Int],
+                job_group_id = (o \ "job_group_id").extract[Int],
+                state = (o \ "state").extract[JobGroupState],
+                complete = (o \ "complete").extract[Boolean],
+                n_jobs = (o \ "n_jobs").extract[Int],
+                n_completed = (o \ "n_completed").extract[Int],
+                n_succeeded = (o \ "n_succeeded").extract[Int],
+                n_failed = (o \ "n_failed").extract[Int],
+                n_cancelled = (o \ "n_failed").extract[Int],
+              )
+          },
+          PartialFunction.empty,
+        )
+      )
+
+  private[this] object JarSpecSerializer
+      extends CustomSerializer[JarSpec](_ =>
+        (
+          PartialFunction.empty,
+          {
+            case JarUrl(url) =>
+              JObject("type" -> JString("jar_url"), "value" -> JString(url))
+            case GitRevision(sha) =>
+              JObject("type" -> JString("git_revision"), "value" -> JString(sha))
+          },
+        )
+      )
+}
diff --git a/hail/src/main/scala/is/hail/services/BatchConfig.scala b/hail/src/main/scala/is/hail/services/BatchConfig.scala
index 9d2e0256ff5..6d5b08ed7b5 100644
--- a/hail/src/main/scala/is/hail/services/BatchConfig.scala
+++ b/hail/src/main/scala/is/hail/services/BatchConfig.scala
@@ -2,18 +2,14 @@ package is.hail.services
 
 import is.hail.utils._
 
-import java.io.{File, FileInputStream}
+import java.nio.file.{Files, Path}
 
 import org.json4s._
 import org.json4s.jackson.JsonMethods
 
 object BatchConfig {
-  def fromConfigFile(file: String): Option[BatchConfig] =
-    if (new File(file).exists()) {
-      using(new FileInputStream(file))(in => Some(fromConfig(JsonMethods.parse(in))))
-    } else {
-      None
-    }
+  def fromConfigFile(file: Path): BatchConfig =
+    using(Files.newInputStream(file))(in => fromConfig(JsonMethods.parse(in)))
 
   def fromConfig(config: JValue): BatchConfig = {
     implicit val formats: Formats = 
DefaultFormats @@ -21,4 +17,4 @@ object BatchConfig { } } -class BatchConfig(val batchId: Long, val jobGroupId: Long) +case class BatchConfig(batchId: Int, jobGroupId: Int) diff --git a/hail/src/main/scala/is/hail/services/Requester.scala b/hail/src/main/scala/is/hail/services/Requester.scala deleted file mode 100644 index fcfbd808ba8..00000000000 --- a/hail/src/main/scala/is/hail/services/Requester.scala +++ /dev/null @@ -1,168 +0,0 @@ -package is.hail.services - -import is.hail.shadedazure.com.azure.core.credential.TokenRequestContext -import is.hail.shadedazure.com.azure.identity.{ - ClientSecretCredential, ClientSecretCredentialBuilder, -} -import is.hail.utils._ - -import scala.collection.JavaConverters._ - -import java.io.{FileInputStream, InputStream} - -import com.google.auth.oauth2.ServiceAccountCredentials -import org.apache.commons.io.IOUtils -import org.apache.http.{HttpEntity, HttpEntityEnclosingRequest} -import org.apache.http.client.config.RequestConfig -import org.apache.http.client.methods.HttpUriRequest -import org.apache.http.impl.client.{CloseableHttpClient, HttpClients} -import org.apache.http.util.EntityUtils -import org.apache.log4j.{LogManager, Logger} -import org.json4s.{Formats, JValue} -import org.json4s.jackson.JsonMethods - -abstract class CloudCredentials { - def accessToken(): String -} - -class GoogleCloudCredentials(gsaKeyPath: String) extends CloudCredentials { - private[this] val credentials = using(new FileInputStream(gsaKeyPath)) { is => - ServiceAccountCredentials - .fromStream(is) - .createScoped("openid", "email", "profile") - } - - override def accessToken(): String = { - credentials.refreshIfExpired() - credentials.getAccessToken.getTokenValue - } -} - -class AzureCloudCredentials(credentialsPath: String) extends CloudCredentials { - private[this] val credentials: ClientSecretCredential = - using(new FileInputStream(credentialsPath)) { is => - implicit val formats: Formats = defaultJSONFormats - val kvs = JsonMethods.parse(is) - val appId = (kvs \ "appId").extract[String] - val password = (kvs \ "password").extract[String] - val tenant = (kvs \ "tenant").extract[String] - - new ClientSecretCredentialBuilder() - .clientId(appId) - .clientSecret(password) - .tenantId(tenant) - .build() - } - - override def accessToken(): String = { - val context = new TokenRequestContext() - context.setScopes(Array(System.getenv("HAIL_AZURE_OAUTH_SCOPE")).toList.asJava) - credentials.getToken(context).block.getToken - } -} - -class ClientResponseException( - val status: Int, - message: String, - cause: Throwable, -) extends Exception(message, cause) { - def this(statusCode: Int) = this(statusCode, null, null) - - def this(statusCode: Int, message: String) = this(statusCode, message, null) -} - -object Requester { - private val log: Logger = LogManager.getLogger("Requester") - private[this] val TIMEOUT_MS = 5 * 1000 - - val httpClient: CloseableHttpClient = { - log.info("creating HttpClient") - val requestConfig = RequestConfig.custom() - .setConnectTimeout(TIMEOUT_MS) - .setConnectionRequestTimeout(TIMEOUT_MS) - .setSocketTimeout(TIMEOUT_MS) - .build() - try { - HttpClients.custom() - .setSSLContext(tls.getSSLContext) - .setMaxConnPerRoute(20) - .setMaxConnTotal(100) - .setDefaultRequestConfig(requestConfig) - .build() - } catch { - case _: NoSSLConfigFound => - log.info("creating HttpClient with no SSL Context") - HttpClients.custom() - .setMaxConnPerRoute(20) - .setMaxConnTotal(100) - .setDefaultRequestConfig(requestConfig) - .build() - } - } - - def 
fromCredentialsFile(credentialsPath: String) = { - val credentials = sys.env.get("HAIL_CLOUD") match { - case Some("gcp") => new GoogleCloudCredentials(credentialsPath) - case Some("azure") => new AzureCloudCredentials(credentialsPath) - case Some(cloud) => - throw new IllegalArgumentException(s"Bad cloud: $cloud") - case None => - throw new IllegalArgumentException(s"HAIL_CLOUD must be set.") - } - new Requester(credentials) - } -} - -class Requester( - val credentials: CloudCredentials -) { - import Requester._ - - def requestWithHandler[T >: Null](req: HttpUriRequest, body: HttpEntity, f: InputStream => T) - : T = { - log.info(s"request ${req.getMethod} ${req.getURI}") - - if (body != null) - req.asInstanceOf[HttpEntityEnclosingRequest].setEntity(body) - - val token = credentials.accessToken() - req.addHeader("Authorization", s"Bearer $token") - - retryTransientErrors { - using(httpClient.execute(req)) { resp => - val statusCode = resp.getStatusLine.getStatusCode - log.info(s"request ${req.getMethod} ${req.getURI} response $statusCode") - if (statusCode < 200 || statusCode >= 300) { - val entity = resp.getEntity - val message = - if (entity != null) - EntityUtils.toString(entity) - else - null - throw new ClientResponseException(statusCode, message) - } - val entity: HttpEntity = resp.getEntity - if (entity != null) { - using(entity.getContent)(f) - } else - null - } - } - } - - def requestAsByteStream(req: HttpUriRequest, body: HttpEntity = null): Array[Byte] = - requestWithHandler(req, body, IOUtils.toByteArray) - - def request(req: HttpUriRequest, body: HttpEntity = null): JValue = - requestWithHandler( - req, - body, - { content => - val s = IOUtils.toByteArray(content) - if (s.isEmpty) - null - else - JsonMethods.parse(new String(s)) - }, - ) -} diff --git a/hail/src/main/scala/is/hail/services/batch_client/BatchClient.scala b/hail/src/main/scala/is/hail/services/batch_client/BatchClient.scala deleted file mode 100644 index e6ea6ea1aa2..00000000000 --- a/hail/src/main/scala/is/hail/services/batch_client/BatchClient.scala +++ /dev/null @@ -1,263 +0,0 @@ -package is.hail.services.batch_client - -import is.hail.expr.ir.ByteArrayBuilder -import is.hail.services._ -import is.hail.utils._ - -import scala.util.Random - -import java.nio.charset.StandardCharsets - -import org.apache.http.HttpEntity -import org.apache.http.client.methods.{HttpDelete, HttpGet, HttpPatch, HttpPost} -import org.apache.http.entity.{ByteArrayEntity, ContentType, StringEntity} -import org.apache.log4j.{LogManager, Logger} -import org.json4s.{DefaultFormats, Formats, JInt, JObject, JString, JValue} -import org.json4s.jackson.JsonMethods - -class NoBodyException(message: String, cause: Throwable) extends Exception(message, cause) { - def this() = this(null, null) - - def this(message: String) = this(message, null) -} - -object BatchClient { - lazy val log: Logger = LogManager.getLogger("BatchClient") -} - -class BatchClient( - deployConfig: DeployConfig, - requester: Requester, -) { - - def this(credentialsPath: String) = - this(DeployConfig.get, Requester.fromCredentialsFile(credentialsPath)) - - import BatchClient._ - import requester.request - - private[this] val baseUrl = deployConfig.baseUrl("batch") - - def get(path: String): JValue = - request(new HttpGet(s"$baseUrl$path")) - - def post(path: String, body: HttpEntity): JValue = - request(new HttpPost(s"$baseUrl$path"), body = body) - - def post(path: String, json: JValue = null): JValue = - post( - path, - if (json != null) - new StringEntity( - 
JsonMethods.compact(json), - ContentType.create("application/json"), - ) - else - null, - ) - - def patch(path: String): JValue = - request(new HttpPatch(s"$baseUrl$path")) - - def delete(path: String, token: String): JValue = - request(new HttpDelete(s"$baseUrl$path")) - - def update(batchID: Long, token: String, jobGroup: JObject, jobs: IndexedSeq[JObject]) - : (Long, Long) = { - implicit val formats: Formats = DefaultFormats - - val updateJson = - JObject("n_jobs" -> JInt(jobs.length), "n_job_groups" -> JInt(1), "token" -> JString(token)) - val jobGroupSpec = specBytes(jobGroup) - val jobBunches = createBunches(jobs) - val updateIDAndJobGroupId = - if (jobBunches.length == 1 && jobBunches(0).length + jobGroupSpec.length < 1024 * 1024) { - val b = new ByteArrayBuilder() - b ++= "{\"job_groups\":".getBytes(StandardCharsets.UTF_8) - addBunchBytes(b, Array(jobGroupSpec)) - b ++= ",\"bunch\":".getBytes(StandardCharsets.UTF_8) - addBunchBytes(b, jobBunches(0)) - b ++= ",\"update\":".getBytes(StandardCharsets.UTF_8) - b ++= JsonMethods.compact(updateJson).getBytes(StandardCharsets.UTF_8) - b += '}' - val data = b.result() - val resp = retryTransientErrors { - post( - s"/api/v1alpha/batches/$batchID/update-fast", - new ByteArrayEntity(data, ContentType.create("application/json")), - ) - } - b.clear() - ((resp \ "update_id").extract[Long], (resp \ "start_job_group_id").extract[Long]) - } else { - val resp = retryTransientErrors { - post(s"/api/v1alpha/batches/$batchID/updates/create", json = updateJson) - } - val updateID = (resp \ "update_id").extract[Long] - val startJobGroupId = (resp \ "start_job_group_id").extract[Long] - - val b = new ByteArrayBuilder() - b ++= "[".getBytes(StandardCharsets.UTF_8) - b ++= jobGroupSpec - b ++= "]".getBytes(StandardCharsets.UTF_8) - retryTransientErrors { - post( - s"/api/v1alpha/batches/$batchID/updates/$updateID/job-groups/create", - new ByteArrayEntity(b.result(), ContentType.create("application/json")), - ) - } - - b.clear() - var i = 0 - while (i < jobBunches.length) { - addBunchBytes(b, jobBunches(i)) - val data = b.result() - retryTransientErrors { - post( - s"/api/v1alpha/batches/$batchID/updates/$updateID/jobs/create", - new ByteArrayEntity( - data, - ContentType.create("application/json"), - ), - ) - } - b.clear() - i += 1 - } - - retryTransientErrors { - patch(s"/api/v1alpha/batches/$batchID/updates/$updateID/commit") - } - (updateID, startJobGroupId) - } - - log.info(s"run: created update $updateIDAndJobGroupId for batch $batchID") - updateIDAndJobGroupId - } - - def create(batchJson: JObject, jobs: IndexedSeq[JObject]): Long = { - implicit val formats: Formats = DefaultFormats - - val bunches = createBunches(jobs) - val batchID = if (bunches.length == 1) { - val bunch = bunches(0) - val b = new ByteArrayBuilder() - b ++= "{\"bunch\":".getBytes(StandardCharsets.UTF_8) - addBunchBytes(b, bunch) - b ++= ",\"batch\":".getBytes(StandardCharsets.UTF_8) - b ++= JsonMethods.compact(batchJson).getBytes(StandardCharsets.UTF_8) - b += '}' - val data = b.result() - val resp = retryTransientErrors { - post( - "/api/v1alpha/batches/create-fast", - new ByteArrayEntity(data, ContentType.create("application/json")), - ) - } - b.clear() - (resp \ "id").extract[Long] - } else { - val resp = retryTransientErrors(post("/api/v1alpha/batches/create", json = batchJson)) - val batchID = (resp \ "id").extract[Long] - - val b = new ByteArrayBuilder() - - var i = 0 - while (i < bunches.length) { - addBunchBytes(b, bunches(i)) - val data = b.result() - retryTransientErrors 
{ - post( - s"/api/v1alpha/batches/$batchID/jobs/create", - new ByteArrayEntity( - data, - ContentType.create("application/json"), - ), - ) - } - b.clear() - i += 1 - } - - retryTransientErrors(patch(s"/api/v1alpha/batches/$batchID/close")) - batchID - } - log.info(s"run: created batch $batchID") - batchID - } - - def run(batchJson: JObject, jobs: IndexedSeq[JObject]): JValue = { - val batchID = create(batchJson, jobs) - waitForJobGroup(batchID, 0L) - } - - def waitForJobGroup(batchID: Long, jobGroupId: Long): JValue = { - implicit val formats: Formats = DefaultFormats - - Thread.sleep(600) // it is not possible for the batch to be finished in less than 600ms - - val start = System.nanoTime() - - while (true) { - val jobGroup = - retryTransientErrors(get(s"/api/v1alpha/batches/$batchID/job-groups/$jobGroupId")) - if ((jobGroup \ "complete").extract[Boolean]) - return jobGroup - - // wait 10% of duration so far - // at least, 50ms - // at most, 5s - val now = System.nanoTime() - val elapsed = now - start - val d = math.max( - math.min( - (0.1 * (0.8 + Random.nextFloat() * 0.4) * (elapsed / 1000.0 / 1000)).toInt, - 5000, - ), - 50, - ) - Thread.sleep(d) - } - - throw new AssertionError("unreachable") - } - - private def createBunches(jobs: IndexedSeq[JObject]): BoxedArrayBuilder[Array[Array[Byte]]] = { - val bunches = new BoxedArrayBuilder[Array[Array[Byte]]]() - val bunchb = new BoxedArrayBuilder[Array[Byte]]() - - var i = 0 - var size = 0 - while (i < jobs.length) { - val jobBytes = specBytes(jobs(i)) - if (size + jobBytes.length > 1024 * 1024) { - bunches += bunchb.result() - bunchb.clear() - size = 0 - } - bunchb += jobBytes - size += jobBytes.length - i += 1 - } - assert(bunchb.size > 0) - - bunches += bunchb.result() - bunchb.clear() - bunches - } - - private def specBytes(obj: JObject): Array[Byte] = - JsonMethods.compact(obj).getBytes(StandardCharsets.UTF_8) - - private def addBunchBytes(b: ByteArrayBuilder, bunch: Array[Array[Byte]]): Unit = { - var j = 0 - b += '[' - while (j < bunch.length) { - if (j > 0) - b += ',' - b ++= bunch(j) - j += 1 - } - b += ']' - } -} diff --git a/hail/src/main/scala/is/hail/services/oauth2.scala b/hail/src/main/scala/is/hail/services/oauth2.scala new file mode 100644 index 00000000000..5063a51bb8e --- /dev/null +++ b/hail/src/main/scala/is/hail/services/oauth2.scala @@ -0,0 +1,126 @@ +package is.hail.services + +import is.hail.services.oauth2.AzureCloudCredentials.EnvVars.AzureApplicationCredentials +import is.hail.services.oauth2.GoogleCloudCredentials.EnvVars.GoogleApplicationCredentials +import is.hail.shadedazure.com.azure.core.credential.{ + AccessToken, TokenCredential, TokenRequestContext, +} +import is.hail.shadedazure.com.azure.identity.{ + ClientSecretCredentialBuilder, DefaultAzureCredentialBuilder, +} +import is.hail.utils.{defaultJSONFormats, using} + +import scala.collection.JavaConverters._ + +import java.io.Serializable +import java.nio.file.{Files, Path} +import java.time.OffsetDateTime + +import com.google.auth.oauth2.{GoogleCredentials, ServiceAccountCredentials} +import org.json4s.Formats +import org.json4s.jackson.JsonMethods + +object oauth2 { + + sealed trait CloudCredentials extends Product with Serializable { + def accessToken: String + } + + def CloudCredentials( + keyPath: Path, + scopes: IndexedSeq[String], + env: Map[String, String] = sys.env, + ): CloudCredentials = + env.get("HAIL_CLOUD") match { + case Some("gcp") => GoogleCloudCredentials(Some(keyPath), scopes, env) + case Some("azure") => 
AzureCloudCredentials(Some(keyPath), scopes, env) + case Some(cloud) => throw new IllegalArgumentException(s"Unknown cloud: '$cloud'") + case None => throw new IllegalArgumentException(s"HAIL_CLOUD must be set.") + } + + case class GoogleCloudCredentials(value: GoogleCredentials) extends CloudCredentials { + override def accessToken: String = { + value.refreshIfExpired() + value.getAccessToken.getTokenValue + } + } + + object GoogleCloudCredentials { + object EnvVars { + val GoogleApplicationCredentials = "GOOGLE_APPLICATION_CREDENTIALS" + } + + def apply(keyPath: Option[Path], scopes: IndexedSeq[String], env: Map[String, String] = sys.env) + : GoogleCloudCredentials = + GoogleCloudCredentials { + val creds: GoogleCredentials = + keyPath.orElse(env.get(GoogleApplicationCredentials).map(Path.of(_))) match { + case Some(path) => + using(Files.newInputStream(path))(ServiceAccountCredentials.fromStream) + case None => + GoogleCredentials.getApplicationDefault + } + + creds.createScoped(scopes: _*) + } + } + + sealed trait AzureCloudCredentials extends CloudCredentials { + + def value: TokenCredential + def scopes: IndexedSeq[String] + + @transient private[this] var token: AccessToken = _ + + override def accessToken: String = { + refreshIfRequired() + token.getToken + } + + private[this] def refreshIfRequired(): Unit = + if (!isExpired) token.getToken + else synchronized { + if (isExpired) { + token = value.getTokenSync(new TokenRequestContext().setScopes(scopes.asJava)) + } + + token.getToken + } + + private[this] def isExpired: Boolean = + token == null || OffsetDateTime.now.plusMinutes(5).isBefore(token.getExpiresAt) + } + + object AzureCloudCredentials { + object EnvVars { + val AzureApplicationCredentials = "AZURE_APPLICATION_CREDENTIALS" + } + + def apply(keyPath: Option[Path], scopes: IndexedSeq[String], env: Map[String, String] = sys.env) + : AzureCloudCredentials = + keyPath.orElse(env.get(AzureApplicationCredentials).map(Path.of(_))) match { + case Some(path) => AzureClientSecretCredentials(path, scopes) + case None => AzureDefaultCredentials(scopes) + } + } + + private case class AzureDefaultCredentials(scopes: IndexedSeq[String]) + extends AzureCloudCredentials { + @transient override lazy val value: TokenCredential = + new DefaultAzureCredentialBuilder().build() + } + + private case class AzureClientSecretCredentials(path: Path, scopes: IndexedSeq[String]) + extends AzureCloudCredentials { + @transient override lazy val value: TokenCredential = + using(Files.newInputStream(path)) { is => + implicit val fmts: Formats = defaultJSONFormats + val kvs = JsonMethods.parse(is) + new ClientSecretCredentialBuilder() + .clientId((kvs \ "appId").extract[String]) + .clientSecret((kvs \ "password").extract[String]) + .tenantId((kvs \ "tenant").extract[String]) + .build() + } + } +} diff --git a/hail/src/main/scala/is/hail/services/package.scala b/hail/src/main/scala/is/hail/services/package.scala index 161448ef102..e7887fea509 100644 --- a/hail/src/main/scala/is/hail/services/package.scala +++ b/hail/src/main/scala/is/hail/services/package.scala @@ -1,5 +1,6 @@ package is.hail +import is.hail.services.requests.ClientResponseException import is.hail.shadedazure.com.azure.storage.common.implementation.Constants import is.hail.utils._ diff --git a/hail/src/main/scala/is/hail/services/requests.scala b/hail/src/main/scala/is/hail/services/requests.scala new file mode 100644 index 00000000000..b6ec90a08b3 --- /dev/null +++ b/hail/src/main/scala/is/hail/services/requests.scala @@ -0,0 +1,94 @@ 
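The oauth2 module above replaces the credential classes that lived in the deleted Requester.scala. A minimal sketch of how a caller might obtain a bearer token through the CloudCredentials factory; the key path and scopes here are illustrative, not defaults of the API:

```scala
import is.hail.services.oauth2.CloudCredentials

import java.nio.file.Path

object CloudCredentialsExample {
  def main(args: Array[String]): Unit = {
    // Dispatches on HAIL_CLOUD: "gcp" builds GoogleCloudCredentials, "azure" builds
    // AzureCloudCredentials; any other value (or an unset variable) is an error.
    // The key path and scopes below are illustrative.
    val creds = CloudCredentials(
      keyPath = Path.of("/test-gsa-key/key.json"),
      scopes = IndexedSeq("openid", "email", "profile"),
    )

    // accessToken returns a bearer token, refreshing the underlying credential as needed.
    println(creds.accessToken.take(10) + "...")
  }
}
```

Callers that already know their cloud can construct GoogleCloudCredentials or AzureCloudCredentials directly, as the updated storage FS test suites later in this diff do with the provider-specific RequiredOAuthScopes.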
+package is.hail.services + +import is.hail.services.oauth2.CloudCredentials +import is.hail.utils.{log, _} + +import java.net.URL + +import org.apache.http.{HttpEntity, HttpEntityEnclosingRequest} +import org.apache.http.client.config.RequestConfig +import org.apache.http.client.methods.{HttpGet, HttpPatch, HttpPost, HttpUriRequest} +import org.apache.http.entity.ContentType.APPLICATION_JSON +import org.apache.http.entity.StringEntity +import org.apache.http.impl.client.{CloseableHttpClient, HttpClients} +import org.apache.http.util.EntityUtils +import org.json4s.JValue +import org.json4s.JsonAST.JNothing +import org.json4s.jackson.JsonMethods + +object requests { + + class ClientResponseException(val status: Int, message: String) extends Exception(message) + + trait Requester extends AutoCloseable { + def get(route: String): JValue + def post(route: String, body: JValue): JValue + def post(route: String, body: HttpEntity): JValue + def patch(route: String): JValue + } + + private[this] val TIMEOUT_MS = 5 * 1000 + + def Requester(baseUrl: URL, cred: CloudCredentials): Requester = { + + val httpClient: CloseableHttpClient = { + log.info("creating HttpClient") + val requestConfig = RequestConfig.custom() + .setConnectTimeout(TIMEOUT_MS) + .setConnectionRequestTimeout(TIMEOUT_MS) + .setSocketTimeout(TIMEOUT_MS) + .build() + try { + HttpClients.custom() + .setSSLContext(tls.getSSLContext) + .setMaxConnPerRoute(20) + .setMaxConnTotal(100) + .setDefaultRequestConfig(requestConfig) + .build() + } catch { + case _: NoSSLConfigFound => + log.info("creating HttpClient with no SSL Context") + HttpClients.custom() + .setMaxConnPerRoute(20) + .setMaxConnTotal(100) + .setDefaultRequestConfig(requestConfig) + .build() + } + } + + def request(req: HttpUriRequest, body: Option[HttpEntity] = None): JValue = { + req.addHeader("Authorization", s"Bearer ${cred.accessToken}") + body.foreach(entity => req.asInstanceOf[HttpEntityEnclosingRequest].setEntity(entity)) + retryTransientErrors { + using(httpClient.execute(req)) { resp => + val statusCode = resp.getStatusLine.getStatusCode + val message = Option(resp.getEntity).map(EntityUtils.toString).filter(_.nonEmpty) + if (statusCode < 200 || statusCode >= 300) { + log.warn(s"$statusCode ${req.getMethod} ${req.getURI}\n${message.orNull}") + throw new ClientResponseException(statusCode, message.orNull) + } + + log.info(s"$statusCode ${req.getMethod} ${req.getURI}") + message.map(JsonMethods.parse(_)).getOrElse(JNothing) + } + } + } + + new Requester with Logging { + override def get(route: String): JValue = + request(new HttpGet(s"$baseUrl$route")) + + override def post(route: String, body: JValue): JValue = + post(route, new StringEntity(JsonMethods.compact(body), APPLICATION_JSON)) + + override def post(route: String, body: HttpEntity): JValue = + request(new HttpPost(s"$baseUrl$route"), Some(body)) + + override def patch(route: String): JValue = + request(new HttpPatch(s"$baseUrl$route")) + + override def close(): Unit = + httpClient.close() + } + } +} diff --git a/hail/src/test/scala/is/hail/backend/ServiceBackendSuite.scala b/hail/src/test/scala/is/hail/backend/ServiceBackendSuite.scala index 14a9a1689d9..c66e1b0fcbd 100644 --- a/hail/src/test/scala/is/hail/backend/ServiceBackendSuite.scala +++ b/hail/src/test/scala/is/hail/backend/ServiceBackendSuite.scala @@ -1,184 +1,130 @@ package is.hail.backend +import is.hail.HailFeatureFlags import is.hail.asm4s.HailClassLoader -import is.hail.backend.service.{ServiceBackend, ServiceBackendRPCPayload} -import 
is.hail.services.batch_client.BatchClient -import is.hail.utils.tokenUrlSafe +import is.hail.backend.service.{ServiceBackend, ServiceBackendContext, ServiceBackendRPCPayload} +import is.hail.io.fs.{CloudStorageFSConfig, RouterFS} +import is.hail.services._ +import is.hail.services.JobGroupStates.Success +import is.hail.utils.{tokenUrlSafe, using} import scala.reflect.io.{Directory, Path} +import scala.util.Random + +import java.io.Closeable -import org.json4s.{JArray, JBool, JInt, JObject, JString} import org.mockito.ArgumentMatchersSugar.any import org.mockito.IdiomaticMockito import org.mockito.MockitoSugar.when +import org.scalatest.OptionValues import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper import org.scalatestplus.testng.TestNGSuite import org.testng.annotations.Test -class ServiceBackendSuite extends TestNGSuite with IdiomaticMockito { +class ServiceBackendSuite extends TestNGSuite with IdiomaticMockito with OptionValues { @Test def testCreateJobPayload(): Unit = - withMockDriverContext { rpcPayload => - val batchClient = mock[BatchClient] - - val backend = - ServiceBackend( - jarLocation = - classOf[ServiceBackend].getProtectionDomain.getCodeSource.getLocation.getPath, - name = "name", - theHailClassLoader = new HailClassLoader(getClass.getClassLoader), - batchClient, - batchId = None, - jobGroupId = None, - scratchDir = rpcPayload.remote_tmpdir, - rpcConfig = rpcPayload, - sys.env + ("HAIL_CLOUD" -> "gcp"), - ) - - val contexts = Array.tabulate(1)(_.toString.getBytes) - - // verify that the service backend - // - creates the batch with the correct billing project, and - // - the number of jobs matches the number of partitions, and - // - each job is created in the specified region, and - // - each job's resource configuration matches the rpc config - when(batchClient.create(any[JObject], any[IndexedSeq[JObject]])) thenAnswer { - (batch: JObject, jobs: IndexedSeq[JObject]) => - batch \ "billing_project" shouldBe JString(rpcPayload.billing_project) - batch \ "n_jobs" shouldBe JInt(contexts.length) - - jobs.length shouldEqual contexts.length - jobs.foreach { payload => - payload \ "regions" shouldBe JArray(rpcPayload.regions.map(JString).toList) - - payload \ "resources" shouldBe JObject( - "preemptible" -> JBool(true), - "cpu" -> JString(rpcPayload.worker_cores), - "memory" -> JString(rpcPayload.worker_memory), - "storage" -> JString(rpcPayload.storage), - ) - } - - 37L - } - - // the service backend expects that each job write its output to a well-known - // location when it finishes. 
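The new requests.Requester, introduced above, folds the HTTP client construction, bearer-token header, and transient-error retry from the deleted Requester.scala behind a small AutoCloseable interface. A sketch of direct use, assuming a placeholder endpoint, key path, and route (real callers derive the base URL from DeployConfig):

```scala
import is.hail.services.oauth2.CloudCredentials
import is.hail.services.requests.Requester
import is.hail.utils.using

import java.net.URL
import java.nio.file.Path

object RequesterExample {
  def main(args: Array[String]): Unit = {
    // Placeholder credentials and endpoint; the scopes mirror those used for Google
    // service accounts in the deleted Requester.scala.
    val creds = CloudCredentials(
      keyPath = Path.of("/test-gsa-key/key.json"),
      scopes = IndexedSeq("openid", "email", "profile"),
    )

    using(Requester(new URL("https://batch.example.org"), creds)) { requester =>
      // GET returns the parsed JSON body, or JNothing when the response is empty;
      // non-2xx responses surface as ClientResponseException. The ids are placeholders.
      val jobGroup = requester.get("/api/v1alpha/batches/123/job-groups/1")
      println(jobGroup)
    }
  }
}
```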
- when(batchClient.waitForJobGroup(any[Long], any[Long])) thenAnswer { - (batchId: Long, jobGroupId: Long) => - batchId shouldEqual 37L - jobGroupId shouldEqual 1L - - val resultsDir = - Path(backend.serviceBackendContext.remoteTmpDir) / - "parallelizeAndComputeWithIndex" / - tokenUrlSafe - - resultsDir.createDirectory() - for (i <- contexts.indices) (resultsDir / f"result.$i").toFile.writeAll("11") - JObject("state" -> JString("success")) - } - - val (failure, _) = - backend.parallelizeAndComputeWithIndex( - backend.serviceBackendContext, - backend.fs, - contexts, - "stage1", - )((bytes, _, _, _) => bytes) - - failure.foreach(throw _) - - batchClient.create(any[JObject], any[IndexedSeq[JObject]]) wasCalled once - } - - @Test def testUpdateJobPayload(): Unit = - withMockDriverContext { config => + withMockDriverContext { rpcConfig => val batchClient = mock[BatchClient] + using(ServiceBackend(batchClient, rpcConfig)) { backend => + val contexts = Array.tabulate(1)(_.toString.getBytes) + + // verify that + // - the number of jobs matches the number of partitions, and + // - each job is created in the specified region, and + // - each job's resource configuration matches the rpc config + + when(batchClient.newJobGroup(any[JobGroupRequest])) thenAnswer { + jobGroup: JobGroupRequest => + jobGroup.batch_id shouldBe backend.batchConfig.batchId + jobGroup.absolute_parent_id shouldBe backend.batchConfig.jobGroupId + val jobs = jobGroup.jobs + jobs.length shouldEqual contexts.length + jobs.foreach { payload => + payload.regions.value shouldBe rpcConfig.regions + payload.resources.value shouldBe JobResources( + preemptible = true, + cpu = Some(rpcConfig.worker_cores), + memory = Some(rpcConfig.worker_memory), + storage = Some(rpcConfig.storage), + ) + } + + backend.batchConfig.jobGroupId + 1 + } - val backend = - ServiceBackend( - jarLocation = - classOf[ServiceBackend].getProtectionDomain.getCodeSource.getLocation.getPath, - name = "name", - theHailClassLoader = new HailClassLoader(getClass.getClassLoader), - batchClient, - batchId = Some(23L), - jobGroupId = None, - scratchDir = config.remote_tmpdir, - rpcConfig = config, - sys.env + ("HAIL_CLOUD" -> "gcp"), - ) - - val contexts = Array.tabulate(1)(_.toString.getBytes) - - // verify that the service backend - // - updates the batch with the correct billing project, and - // - the number of jobs matches the number of partitions, and - // - each job is created in the specified region, and - // - each job's resource configuration matches the rpc config - when( - batchClient.update(any[Long], any[String], any[JObject], any[IndexedSeq[JObject]]) - ) thenAnswer { - (batchId: Long, _: String, _: JObject, jobs: IndexedSeq[JObject]) => - batchId shouldEqual 23L - - jobs.length shouldEqual contexts.length - jobs.foreach { payload => - payload \ "regions" shouldBe JArray(config.regions.map(JString).toList) - - payload \ "resources" shouldBe JObject( - "preemptible" -> JBool(true), - "cpu" -> JString(config.worker_cores), - "memory" -> JString(config.worker_memory), - "storage" -> JString(config.storage), + // the service backend expects that each job write its output to a well-known + // location when it finishes. 
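The mocked newJobGroup above hands ServiceBackend's submission to BatchClient as a typed JobGroupRequest rather than the hand-built JObjects the deleted client accepted. A sketch of that payload shape, using the case classes exercised elsewhere in this diff; the image, command, attributes, and storage value are illustrative:

```scala
import is.hail.services._
import is.hail.utils.FastSeq

object JobGroupPayloadExample {
  // Roughly the payload shape the mocked newJobGroup asserts on: one job per
  // partition, with resources taken from the RPC config.
  def example(batchId: Int, parentJobGroupId: Int, token: String): JobGroupRequest =
    JobGroupRequest(
      batch_id = batchId,
      absolute_parent_id = parentJobGroupId,
      token = token,
      attributes = Map("name" -> "stage1"), // illustrative attributes
      jobs = FastSeq(
        JobRequest(
          always_run = false,
          process = BashJob(
            image = "ubuntu:22.04", // illustrative image and command
            command = Array("/bin/bash", "-c", "echo 'partition 0'"),
          ),
          resources = Some(JobResources(
            preemptible = true,
            cpu = Some("128"),
            memory = Some("a lot."),
            storage = Some("10Gi"), // illustrative storage request
          )),
        )
      ),
    )
}
```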
+ when(batchClient.waitForJobGroup(any[Int], any[Int])) thenAnswer { + (id: Int, jobGroupId: Int) => + id shouldEqual backend.batchConfig.batchId + jobGroupId shouldEqual backend.batchConfig.jobGroupId + 1 + + val resultsDir = + Path(backend.serviceBackendContext.remoteTmpDir) / + "parallelizeAndComputeWithIndex" / + tokenUrlSafe + + resultsDir.createDirectory() + for (i <- contexts.indices) (resultsDir / f"result.$i").toFile.writeAll("11") + JobGroupResponse( + batch_id = id, + job_group_id = jobGroupId, + state = Success, + complete = true, + n_jobs = contexts.length, + n_completed = contexts.length, + n_succeeded = contexts.length, + n_failed = 0, + n_cancelled = 0, ) - } - - (2L, 3L) - } + } - when(batchClient.waitForJobGroup(any[Long], any[Long])) thenAnswer { - (batchId: Long, jobGroupId: Long) => - batchId shouldEqual 23L - jobGroupId shouldEqual 3L + val (failure, _) = + backend.parallelizeAndComputeWithIndex( + backend.serviceBackendContext, + backend.fs, + contexts, + "stage1", + )((bytes, _, _, _) => bytes) - val resultsDir = - Path(backend.serviceBackendContext.remoteTmpDir) / - "parallelizeAndComputeWithIndex" / - tokenUrlSafe + failure.foreach(throw _) - resultsDir.createDirectory() - for (i <- contexts.indices) (resultsDir / f"result.$i").toFile.writeAll("11") - JObject("state" -> JString("success")) + batchClient.newJobGroup(any) wasCalled once + batchClient.waitForJobGroup(any, any) wasCalled once } - - val (failure, _) = - backend.parallelizeAndComputeWithIndex( - backend.serviceBackendContext, - backend.fs, - contexts, - "stage1", - )((bytes, _, _, _) => bytes) - - failure.foreach(throw _) - - batchClient.create(any[JObject], any[IndexedSeq[JObject]]) wasNever called - batchClient.update( - any[Long], - any[String], - any[JObject], - any[IndexedSeq[JObject]], - ) wasCalled once } - def withMockDriverContext(test: ServiceBackendRPCPayload => Any): Any = - withNewLocalTmpFolder { tmp => - // The `ServiceBackend` assumes credentials are installed to a well known location - val gcsKeyDir = tmp / "secrets" / "gsa-key" - gcsKeyDir.createDirectory() - (gcsKeyDir / "key.json").toFile.writeAll("password1234") + def ServiceBackend(client: BatchClient, rpcConfig: ServiceBackendRPCPayload): ServiceBackend = { + val flags = HailFeatureFlags.fromEnv() + val fs = RouterFS.buildRoutes(CloudStorageFSConfig()) + new ServiceBackend( + jarSpec = GitRevision("123"), + name = "name", + theHailClassLoader = new HailClassLoader(getClass.getClassLoader), + batchClient = client, + batchConfig = BatchConfig(batchId = Random.nextInt(), jobGroupId = Random.nextInt()), + flags = flags, + tmpdir = rpcConfig.tmp_dir, + fs = fs, + serviceBackendContext = + new ServiceBackendContext( + rpcConfig.billing_project, + rpcConfig.remote_tmpdir, + rpcConfig.worker_cores, + rpcConfig.worker_memory, + rpcConfig.storage, + rpcConfig.regions, + rpcConfig.cloudfuse_configs, + profile = false, + ExecutionCache.fromFlags(flags, fs, rpcConfig.remote_tmpdir), + ), + scratchDir = rpcConfig.remote_tmpdir, + ) + } + def withMockDriverContext(test: ServiceBackendRPCPayload => Any): Any = + using(LocalTmpFolder) { tmp => withObjectSpied[is.hail.utils.UtilsType] { // not obvious how to pull out `tokenUrlSafe` and inject this directory // using a spy is a hack and i don't particularly like it. @@ -186,8 +132,8 @@ class ServiceBackendSuite extends TestNGSuite with IdiomaticMockito { test { ServiceBackendRPCPayload( - tmp_dir = "", - remote_tmpdir = tmp.path + "/", // because raw strings... 
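withMockDriverContext now wraps the test body in using(LocalTmpFolder) instead of the old try/finally helper, so the temporary directory is cleaned up like any other Closeable. The same fixture pattern in isolation (the directory prefix is illustrative):

```scala
import is.hail.utils.using

import java.io.Closeable

import scala.reflect.io.Directory

object TmpFolderFixtureExample {
  // A temp directory that deletes itself when closed, so it composes with `using`.
  def tmpFolder: Directory with Closeable =
    new Directory(Directory.makeTemp("example-tmp").jfile) with Closeable {
      override def close(): Unit = deleteRecursively()
    }

  def main(args: Array[String]): Unit =
    using(tmpFolder) { tmp =>
      println(s"scratch space at ${tmp.path}")
    } // the directory is removed here
}
```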
+ tmp_dir = tmp.path, + remote_tmpdir = tmp.path, billing_project = "fancy", worker_cores = "128", worker_memory = "a lot.", @@ -203,10 +149,8 @@ class ServiceBackendSuite extends TestNGSuite with IdiomaticMockito { } } - def withNewLocalTmpFolder[A](f: Directory => A): A = { - val tmp = Directory.makeTemp("hail-testing-tmp", "") - try f(tmp) - finally tmp.deleteRecursively() - } - + def LocalTmpFolder: Directory with Closeable = + new Directory(Directory.makeTemp("hail-testing-tmp").jfile) with Closeable { + override def close(): Unit = deleteRecursively() + } } diff --git a/hail/src/test/scala/is/hail/io/fs/AzureStorageFSSuite.scala b/hail/src/test/scala/is/hail/io/fs/AzureStorageFSSuite.scala index d6f3b8cbc4f..ce3759d0062 100644 --- a/hail/src/test/scala/is/hail/io/fs/AzureStorageFSSuite.scala +++ b/hail/src/test/scala/is/hail/io/fs/AzureStorageFSSuite.scala @@ -1,8 +1,7 @@ package is.hail.io.fs -import java.io.FileInputStream +import is.hail.services.oauth2.AzureCloudCredentials -import org.apache.commons.io.IOUtils import org.testng.SkipException import org.testng.annotations.{BeforeClass, Test} @@ -17,16 +16,8 @@ class AzureStorageFSSuite extends FSSuite { } } - lazy val fs = { - val aac = System.getenv("AZURE_APPLICATION_CREDENTIALS") - if (aac == null) { - new AzureStorageFS() - } else { - new AzureStorageFS( - Some(new String(IOUtils.toByteArray(new FileInputStream(aac)))) - ) - } - } + override lazy val fs: FS = + new AzureStorageFS(AzureCloudCredentials(None, AzureStorageFS.RequiredOAuthScopes)) @Test def testMakeQualified(): Unit = { val qualifiedFileName = "https://account.blob.core.windows.net/container/path" diff --git a/hail/src/test/scala/is/hail/io/fs/GoogleStorageFSSuite.scala b/hail/src/test/scala/is/hail/io/fs/GoogleStorageFSSuite.scala index 4f6e654b87c..5bcf8a7fc6e 100644 --- a/hail/src/test/scala/is/hail/io/fs/GoogleStorageFSSuite.scala +++ b/hail/src/test/scala/is/hail/io/fs/GoogleStorageFSSuite.scala @@ -1,8 +1,7 @@ package is.hail.io.fs -import java.io.FileInputStream +import is.hail.services.oauth2.GoogleCloudCredentials -import org.apache.commons.io.IOUtils import org.scalatestplus.testng.TestNGSuite import org.testng.SkipException import org.testng.annotations.{BeforeClass, Test} @@ -18,16 +17,8 @@ class GoogleStorageFSSuite extends TestNGSuite with FSSuite { } } - lazy val fs = { - val gac = System.getenv("GOOGLE_APPLICATION_CREDENTIALS") - if (gac == null) { - new GoogleStorageFS() - } else { - new GoogleStorageFS( - Some(new String(IOUtils.toByteArray(new FileInputStream(gac)))) - ) - } - } + override lazy val fs: FS = + new GoogleStorageFS(GoogleCloudCredentials(None, GoogleStorageFS.RequiredOAuthScopes), None) @Test def testMakeQualified(): Unit = { val qualifiedFileName = "gs://bucket/path" diff --git a/hail/src/test/scala/is/hail/services/BatchClientSuite.scala b/hail/src/test/scala/is/hail/services/BatchClientSuite.scala new file mode 100644 index 00000000000..079870efce2 --- /dev/null +++ b/hail/src/test/scala/is/hail/services/BatchClientSuite.scala @@ -0,0 +1,132 @@ +package is.hail.services + +import is.hail.HAIL_REVISION +import is.hail.backend.service.Main +import is.hail.services.JobGroupStates.Failure +import is.hail.utils._ + +import java.lang.reflect.Method +import java.nio.file.Path + +import org.scalatestplus.testng.TestNGSuite +import org.testng.annotations.{AfterClass, BeforeClass, BeforeMethod, Test} + +class BatchClientSuite extends TestNGSuite { + + private[this] var client: BatchClient = _ + private[this] var batchId: Int = _ + 
private[this] var parentJobGroupId: Int = _ + + @BeforeClass + def createClientAndBatch(): Unit = { + client = BatchClient(DeployConfig.get(), Path.of("/test-gsa-key/key.json")) + batchId = client.newBatch( + BatchRequest( + billing_project = "test", + n_jobs = 0, + token = tokenUrlSafe, + attributes = Map("name" -> s"${getClass.getName}"), + ) + ) + } + + @BeforeMethod + def createEmptyParentJobGroup(m: Method): Unit = { + parentJobGroupId = client.newJobGroup( + req = JobGroupRequest( + batch_id = batchId, + absolute_parent_id = 0, + token = tokenUrlSafe, + attributes = Map("name" -> m.getName), + jobs = FastSeq(), + ) + ) + } + + @AfterClass + def closeClient(): Unit = + client.close() + + @Test + def testCancelAfterNFailures(): Unit = { + val jobGroupId = client.newJobGroup( + req = JobGroupRequest( + batch_id = batchId, + absolute_parent_id = parentJobGroupId, + cancel_after_n_failures = Some(1), + token = tokenUrlSafe, + jobs = FastSeq( + JobRequest( + always_run = false, + process = BashJob( + image = "ubuntu:22.04", + command = Array("/bin/bash", "-c", "sleep 5m"), + ), + resources = Some(JobResources(preemptible = true)), + ), + JobRequest( + always_run = false, + process = BashJob( + image = "ubuntu:22.04", + command = Array("/bin/bash", "-c", "exit 1"), + ), + ), + ), + ) + ) + val result = client.waitForJobGroup(batchId, jobGroupId) + assert(result.state == Failure) + assert(result.n_cancelled == 1) + } + + @Test + def testNewJobGroup(): Unit = + // The query driver submits a job group per stage with one job per partition + for (i <- 1 to 2) { + val jobGroupId = client.newJobGroup( + req = JobGroupRequest( + batch_id = batchId, + absolute_parent_id = parentJobGroupId, + token = tokenUrlSafe, + attributes = Map("name" -> s"JobGroup$i"), + jobs = (1 to i).map { k => + JobRequest( + always_run = false, + process = BashJob( + image = "ubuntu:22.04", + command = Array("/bin/bash", "-c", s"echo 'job $k'"), + ), + ) + }, + ) + ) + + val result = client.getJobGroup(batchId, jobGroupId) + assert(result.n_jobs == i) + } + + @Test + def testJvmJob(): Unit = { + val jobGroupId = client.newJobGroup( + req = JobGroupRequest( + batch_id = batchId, + absolute_parent_id = parentJobGroupId, + token = tokenUrlSafe, + attributes = Map("name" -> "TableStage"), + jobs = FastSeq( + JobRequest( + always_run = false, + process = JvmJob( + command = Array(Main.TEST), + spec = GitRevision(HAIL_REVISION), + profile = false, + ), + ) + ), + ) + ) + + val result = client.getJobGroup(batchId, jobGroupId) + assert(result.n_jobs == 1) + } +} diff --git a/hail/src/test/scala/is/hail/services/batch_client/BatchClientSuite.scala b/hail/src/test/scala/is/hail/services/batch_client/BatchClientSuite.scala deleted file mode 100644 index 521d40046e4..00000000000 --- a/hail/src/test/scala/is/hail/services/batch_client/BatchClientSuite.scala +++ /dev/null @@ -1,40 +0,0 @@ -package is.hail.services.batch_client - -import is.hail.utils._ - -import org.json4s.{DefaultFormats, Formats} -import org.json4s.JsonAST._ -import org.scalatestplus.testng.TestNGSuite -import org.testng.annotations.Test - -class BatchClientSuite extends TestNGSuite { - @Test def testBasic(): Unit = { - val client = new BatchClient("/test-gsa-key/key.json") - val token = tokenUrlSafe - val batch = client.run( - JObject( - "billing_project" -> JString("test"), - "n_jobs" -> JInt(1), - "token" -> JString(token), - ), - FastSeq( - JObject( - "always_run" -> JBool(false), - "job_id" -> JInt(0), - "parent_ids" -> JArray(List()), - "process" -> JObject( - 
"image" -> JString("ubuntu:22.04"), - "command" -> JArray(List( - JString("/bin/bash"), - JString("-c"), - JString("echo 'Hello, world!'"), - )), - "type" -> JString("docker"), - ), - ) - ), - ) - implicit val formats: Formats = DefaultFormats - assert((batch \ "state").extract[String] == "success") - } -}