Skip to content

Commit

Permalink
store constants in searchable fields
Browse files Browse the repository at this point in the history
  • Loading branch information
ehigham committed Dec 12, 2024
1 parent e121222 commit 0a14104
Show file tree
Hide file tree
Showing 5 changed files with 102 additions and 87 deletions.
4 changes: 4 additions & 0 deletions hail/src/main/scala/is/hail/backend/service/Main.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@ object Main {
argv(3) match {
case WORKER => Worker.main(argv)
case DRIVER => ServiceBackendAPI.main(argv)

// Batch's "JvmJob" is a special kind of job that can only call `Main.main`.
// TEST is used for integration testing the `BatchClient` to verify that we
// can create JvmJobs without having to mock the payload to a `Worker` job.
case TEST => ()
case kind => throw new RuntimeException(s"unknown kind: $kind")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,7 @@ class ServiceBackend(

val token = tokenUrlSafe
val root = s"${backendContext.remoteTmpDir}/parallelizeAndComputeWithIndex/$token"
log.info(s"parallelizeAndComputeWithIndex: token='$token', nPartitions=${contexts.length}")

val uploadFunction = executor.submit[Unit](() =>
retryTransientErrors {
Expand Down
159 changes: 81 additions & 78 deletions hail/src/main/scala/is/hail/services/BatchClient.scala
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
package is.hail.services

import is.hail.expr.ir.ByteArrayBuilder
import is.hail.services.BatchClient.BunchMaxSizeBytes
import is.hail.services.BatchClient.{
BunchMaxSizeBytes, JarSpecSerializer, JobGroupResponseDeserializer, JobGroupStateDeserializer,
JobProcessRequestSerializer,
}
import is.hail.services.oauth2.CloudCredentials
import is.hail.services.requests.Requester
import is.hail.utils._
Expand Down Expand Up @@ -86,7 +89,16 @@ object JobGroupStates {

object BatchClient {

private[this] def BatchServiceScopes(env: Map[String, String]): Array[String] =
val BunchMaxSizeBytes: Int = 1024 * 1024

def apply(deployConfig: DeployConfig, credentialsFile: Path, env: Map[String, String] = sys.env)
: BatchClient =
new BatchClient(Requester(
new URL(deployConfig.baseUrl("batch")),
CloudCredentials(credentialsFile, BatchServiceScopes(env), env),
))

def BatchServiceScopes(env: Map[String, String]): Array[String] =
env.get("HAIL_CLOUD") match {
case Some("gcp") =>
Array(
Expand All @@ -102,14 +114,70 @@ object BatchClient {
throw new IllegalArgumentException(s"HAIL_CLOUD must be set.")
}

def apply(deployConfig: DeployConfig, credentialsFile: Path, env: Map[String, String] = sys.env)
: BatchClient =
new BatchClient(Requester(
new URL(deployConfig.baseUrl("batch")),
CloudCredentials(credentialsFile, BatchServiceScopes(env), env),
))
object JobProcessRequestSerializer extends CustomSerializer[JobProcess](implicit fmts =>
(
PartialFunction.empty,
{
case BashJob(image, command) =>
JObject(
"type" -> JString("docker"),
"image" -> JString(image),
"command" -> JArray(command.map(JString).toList),
)
case JvmJob(command, jarSpec, profile) =>
JObject(
"type" -> JString("jvm"),
"command" -> JArray(command.map(JString).toList),
"jar_spec" -> Extraction.decompose(jarSpec),
"profile" -> JBool(profile),
)
},
)
)

private val BunchMaxSizeBytes: Int = 1024 * 1024
object JobGroupStateDeserializer extends CustomSerializer[JobGroupState](_ =>
(
{
case JString("failure") => JobGroupStates.Failure
case JString("cancelled") => JobGroupStates.Cancelled
case JString("success") => JobGroupStates.Success
case JString("running") => JobGroupStates.Running
},
PartialFunction.empty,
)
)

object JobGroupResponseDeserializer extends CustomSerializer[JobGroupResponse](implicit fmts =>
(
{
case o: JObject =>
JobGroupResponse(
batch_id = (o \ "batch_id").extract[Int],
job_group_id = (o \ "job_group_id").extract[Int],
state = (o \ "state").extract[JobGroupState],
complete = (o \ "complete").extract[Boolean],
n_jobs = (o \ "n_jobs").extract[Int],
n_completed = (o \ "n_completed").extract[Int],
n_succeeded = (o \ "n_succeeded").extract[Int],
n_failed = (o \ "n_failed").extract[Int],
n_cancelled = (o \ "n_failed").extract[Int],
)
},
PartialFunction.empty,
)
)

object JarSpecSerializer extends CustomSerializer[JarSpec](_ =>
(
PartialFunction.empty,
{
case JarUrl(url) =>
JObject("type" -> JString("jar_url"), "value" -> JString(url))
case GitRevision(sha) =>
JObject("type" -> JString("git_revision"), "value" -> JString(sha))
},
)
)
}

case class BatchClient private (req: Requester) extends Logging with AutoCloseable {
Expand Down Expand Up @@ -220,6 +288,10 @@ case class BatchClient private (req: Requester) extends Logging with AutoCloseab
.merge(
JObject(
"job_id" -> JInt(jobIdx + 1),
// Batch allows you to create multiple job groups in an update.
// For Query, we only create one job group per stage and so this i
// hidden from the case class abstractions used by the ServiceBackend.
// Thus, this job belongs to the job group with job group id 1 in this update.
"in_update_job_group_id" -> JInt(1),
)
)
Expand Down Expand Up @@ -256,73 +328,4 @@ case class BatchClient private (req: Requester) extends Logging with AutoCloseab
)
)),
)

private[this] object JobProcessRequestSerializer
extends CustomSerializer[JobProcess](_ =>
(
PartialFunction.empty,
{
case BashJob(image, command) =>
JObject(
"type" -> JString("docker"),
"image" -> JString(image),
"command" -> JArray(command.map(JString).toList),
)
case JvmJob(command, jarSpec, profile) =>
JObject(
"type" -> JString("jvm"),
"command" -> JArray(command.map(JString).toList),
"jar_spec" -> Extraction.decompose(jarSpec),
"profile" -> JBool(profile),
)
},
)
)

private[this] object JobGroupStateDeserializer
extends CustomSerializer[JobGroupState](_ =>
(
{
case JString("failure") => JobGroupStates.Failure
case JString("cancelled") => JobGroupStates.Cancelled
case JString("success") => JobGroupStates.Success
case JString("running") => JobGroupStates.Running
},
PartialFunction.empty,
)
)

private[this] object JobGroupResponseDeserializer
extends CustomSerializer[JobGroupResponse](implicit fmts =>
(
{
case o: JObject =>
JobGroupResponse(
batch_id = (o \ "batch_id").extract[Int],
job_group_id = (o \ "job_group_id").extract[Int],
state = (o \ "state").extract[JobGroupState],
complete = (o \ "complete").extract[Boolean],
n_jobs = (o \ "n_jobs").extract[Int],
n_completed = (o \ "n_completed").extract[Int],
n_succeeded = (o \ "n_succeeded").extract[Int],
n_failed = (o \ "n_failed").extract[Int],
n_cancelled = (o \ "n_failed").extract[Int],
)
},
PartialFunction.empty,
)
)

private[this] object JarSpecSerializer
extends CustomSerializer[JarSpec](_ =>
(
PartialFunction.empty,
{
case JarUrl(url) =>
JObject("type" -> JString("jar_url"), "value" -> JString(url))
case GitRevision(sha) =>
JObject("type" -> JString("git_revision"), "value" -> JString(sha))
},
)
)
}
7 changes: 6 additions & 1 deletion hail/src/main/scala/is/hail/services/oauth2.scala
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package is.hail.services

import is.hail.services.oauth2.AzureCloudCredentials.AzureTokenRefreshMinutes
import is.hail.services.oauth2.AzureCloudCredentials.EnvVars.AzureApplicationCredentials
import is.hail.services.oauth2.GoogleCloudCredentials.EnvVars.GoogleApplicationCredentials
import is.hail.shadedazure.com.azure.core.credential.{
Expand Down Expand Up @@ -88,14 +89,18 @@ object oauth2 {
}

private[this] def isExpired: Boolean =
token == null || OffsetDateTime.now.plusMinutes(5).isBefore(token.getExpiresAt)
token == null || OffsetDateTime.now.plusMinutes(AzureTokenRefreshMinutes).isBefore(
token.getExpiresAt
)
}

object AzureCloudCredentials {
object EnvVars {
val AzureApplicationCredentials = "AZURE_APPLICATION_CREDENTIALS"
}

private[AzureCloudCredentials] val AzureTokenRefreshMinutes = 5

def apply(keyPath: Option[Path], scopes: IndexedSeq[String], env: Map[String, String] = sys.env)
: AzureCloudCredentials =
keyPath.orElse(env.get(AzureApplicationCredentials).map(Path.of(_))) match {
Expand Down
18 changes: 10 additions & 8 deletions hail/src/main/scala/is/hail/services/requests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -27,30 +27,32 @@ object requests {
def patch(route: String): JValue
}

private[this] val TIMEOUT_MS = 5 * 1000
private[this] val TimeoutMs = 5 * 1000
private[this] val MaxNumConnectionPerRoute = 20
private[this] val MaxNumConnections = 100

def Requester(baseUrl: URL, cred: CloudCredentials): Requester = {

val httpClient: CloseableHttpClient = {
log.info("creating HttpClient")
val requestConfig = RequestConfig.custom()
.setConnectTimeout(TIMEOUT_MS)
.setConnectionRequestTimeout(TIMEOUT_MS)
.setSocketTimeout(TIMEOUT_MS)
.setConnectTimeout(TimeoutMs)
.setConnectionRequestTimeout(TimeoutMs)
.setSocketTimeout(TimeoutMs)
.build()
try {
HttpClients.custom()
.setSSLContext(tls.getSSLContext)
.setMaxConnPerRoute(20)
.setMaxConnTotal(100)
.setMaxConnPerRoute(MaxNumConnectionPerRoute)
.setMaxConnTotal(MaxNumConnections)
.setDefaultRequestConfig(requestConfig)
.build()
} catch {
case _: NoSSLConfigFound =>
log.info("creating HttpClient with no SSL Context")
HttpClients.custom()
.setMaxConnPerRoute(20)
.setMaxConnTotal(100)
.setMaxConnPerRoute(MaxNumConnectionPerRoute)
.setMaxConnTotal(MaxNumConnections)
.setDefaultRequestConfig(requestConfig)
.build()
}
Expand Down

0 comments on commit 0a14104

Please sign in to comment.