Skip to content

Commit

Permalink
TPU Metrics PodMonitoring (#761)
Browse files Browse the repository at this point in the history
* first commit

* terraform fmt

* remove default

* more descriptive metric_scrape_interval comment

* more descriptive comments

* object for metrics config

* update readme

* update tutorial readme
  • Loading branch information
Bslabe123 authored Aug 5, 2024
1 parent c166bfa commit 54531da
Show file tree
Hide file tree
Showing 8 changed files with 88 additions and 20 deletions.
20 changes: 16 additions & 4 deletions modules/jetstream-maxtext-deployment/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ Ensure the following environment variables are set:
- MODEL_NAME: The name of your LLM (as of the writing of this README valid options are "gemma-7b", "llama2-7b", "llama2-13b")
- PARAMETERS_PATH: Where to find the parameters for your LLM (if using the checkpoint-converter it will be "gs:\/\/$BUCKET_NAME\/final\/unscanned\/gemma_7b-it\/0\/checkpoints\/0\/items" where $BUCKET_NAME is the same one used in the checkpoint-converter)
- (optional) METRICS_PORT: Port to emit custom metrics on
- (optional) SERVER_METRICS_SCRAPE_INTERVAL: How often, in seconds, to scrape Jetstream server metrics (only takes effect when METRICS_PORT is also set)
- (optional) SYSTEM_METRICS_SCRAPE_INTERVAL: How often, in seconds, to scrape TPU system metrics
- (optional) TPU_TOPOLOGY: Topology of TPU chips used by jetstream (default: "2x4")
- (optional) TPU_TYPE: Type of TPUs used (default: "tpu-v5-lite-podslice")
- (optional) TPU_CHIP_COUNT: Number of TPU chips requested; obtained by multiplying out the dimensions of TPU_TOPOLOGY (for example, a "2x4" topology has 2 x 4 = 8 chips)
Expand Down Expand Up @@ -53,11 +55,21 @@ cat ./templates/deployment.yaml.tftpl >> "$JETSTREAM_MANIFEST"
PODMONITORING_MANIFEST=$(mktemp)
cat ./templates/podmonitoring.yaml.tftpl >> "$PODMONITORING_MANIFEST"
if [ "$METRICS_PORT" != "" ]; then
cat $PODMONITORING_MANIFEST | sed "s/\${metrics_port}/$METRICS_PORT/g" >> "$PODMONITORING_MANIFEST"
cat $JETSTREAM_MANIFEST | sed "s/\${metrics_port_arg}/prometheus_port=$METRICS_PORT/g" >> "$JETSTREAM_MANIFEST"
PODMONITORING_TPU_MANIFEST=$(mktemp)
cat ./templates/podmonitoring-tpu.yaml.tftpl >> "$PODMONITORING_TPU_MANIFEST"
if [ "$SYSTEM_METRICS_SCRAPE_INTERVAL" != "" ]; then
cat $PODMONITORING_TPU_MANIFEST \
| sed "s/\${metrics_scrape_interval}/$SYSTEM_METRICS_SCRAPE_INTERVAL/g" >> "$PODMONITORING_TPU_MANIFEST"
cat $PODMONITORING_TPU_MANIFEST | kubectl apply -f -
fi
if [ "$METRICS_PORT" != "" ] && [ "$SERVER_METRICS_SCRAPE_INTERVAL" != "" ]; then
cat $PODMONITORING_MANIFEST \
| sed "s/\${metrics_port}/$METRICS_PORT/g" \
| sed "s/\${metrics_scrape_interval}/$SERVER_METRICS_SCRAPE_INTERVAL/g" >> "$PODMONITORING_MANIFEST"
cat $PODMONITORING_MANIFEST | kubectl apply -f -
cat $JETSTREAM_MANIFEST | sed "s/\${metrics_port_arg}/prometheus_port=$METRICS_PORT/g" >> "$JETSTREAM_MANIFEST"
else
cat $JETSTREAM_MANIFEST | sed "s/\${metrics_port_arg}//g" >> "$JETSTREAM_MANIFEST"
fi
Expand Down
16 changes: 12 additions & 4 deletions modules/jetstream-maxtext-deployment/main.tf
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ locals {
deployment_template = "${path.module}/templates/deployment.yaml.tftpl"
service_template = "${path.module}/templates/service.yaml.tftpl"
podmonitoring_template = "${path.module}/templates/podmonitoring.yaml.tftpl"
podmonitoring_tpu_template = "${path.module}/templates/podmonitoring-tpu.yaml.tftpl"
cmsa_jetstream_hpa_template = "${path.module}/templates/custom-metrics-stackdriver-adapter/hpa.jetstream.yaml.tftpl"
prometheus_jetstream_hpa_template = "${path.module}/templates/prometheus-adapter/hpa.jetstream.yaml.tftpl"
}
Expand All @@ -30,7 +31,7 @@ resource "kubernetes_manifest" "jetstream-deployment" {
model_name = var.maxengine_deployment_settings.model_name
tokenizer = strcontains(var.maxengine_deployment_settings.model_name, "gemma") ? "assets/tokenizer.gemma" : (strcontains(var.maxengine_deployment_settings.model_name, "llama") ? "assets/tokenizer.llama2" : "")
load_parameters_path_arg = var.maxengine_deployment_settings.parameters_path
metrics_port_arg = var.maxengine_deployment_settings.metrics_port != null ? format("prometheus_port=%d", var.maxengine_deployment_settings.metrics_port) : "",
metrics_port_arg = try(format("prometheus_port=%d", var.maxengine_deployment_settings.metrics.server.port), ""),
tpu-topology = var.maxengine_deployment_settings.accelerator_selectors.topology
tpu-type = var.maxengine_deployment_settings.accelerator_selectors.accelerator
tpu-chip-count = var.maxengine_deployment_settings.accelerator_selectors.chip_count
Expand All @@ -43,10 +44,17 @@ resource "kubernetes_manifest" "jetstream-service" {
}

resource "kubernetes_manifest" "jetstream-podmonitoring" {
count = var.maxengine_deployment_settings.metrics_port != null ? 1 : 0
count = try(var.maxengine_deployment_settings.metrics.server != null ? 1 : 0, 0)
manifest = yamldecode(templatefile(local.podmonitoring_template, {
metrics_port = var.maxengine_deployment_settings.metrics_port != null ? var.maxengine_deployment_settings.metrics_port : "",
metrics_scrape_interval = var.maxengine_deployment_settings.metrics_scrape_interval
metrics_port = try(var.maxengine_deployment_settings.metrics.server.port, 0),
metrics_scrape_interval = try(var.maxengine_deployment_settings.metrics.server.scrape_interval, 0),
}))
}

# Applies the manifest rendered from local.podmonitoring_tpu_template.
# One instance is created when maxengine_deployment_settings.metrics.system is
# non-null; zero are created when it is null or when the lookup itself fails
# (the outer try() falls back to 0).
resource "kubernetes_manifest" "jetstream-podmonitoring-tpu" {
count = try(var.maxengine_deployment_settings.metrics.system != null ? 1 : 0, 0)
manifest = yamldecode(templatefile(local.podmonitoring_tpu_template, {
# Falls back to 0 if the scrape interval cannot be read, so the template still renders.
metrics_scrape_interval = try(var.maxengine_deployment_settings.metrics.system.scrape_interval, 0)
}))
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ spec:
- load_parameters_path=${load_parameters_path_arg}
- ${metrics_port_arg}
ports:
- containerPort: 8431 # Port to export TPU runtime metrics, if supported.
- containerPort: 9000
resources:
requests:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# PodMonitoring resource named "tpu-metrics-exporter" in the kube-system namespace.
# It selects pods labeled k8s-app: tpu-device-plugin and scrapes them on port 2112.
apiVersion: monitoring.googleapis.com/v1
kind: PodMonitoring
metadata:
name: tpu-metrics-exporter
namespace: kube-system
labels:
k8s-app: tpu-device-plugin
spec:
endpoints:
- port: 2112
# The templated interval value is rendered with an "s" suffix (seconds).
interval: ${metrics_scrape_interval}s
selector:
matchLabels:
k8s-app: tpu-device-plugin
26 changes: 22 additions & 4 deletions modules/jetstream-maxtext-deployment/variables.tf
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,18 @@ variable "maxengine_deployment_settings" {
maxengine_server_image = optional(string, "us-docker.pkg.dev/cloud-tpu-images/inference/maxengine-server:v0.2.2")
jetstream_http_server_image = optional(string, "us-docker.pkg.dev/cloud-tpu-images/inference/jetstream-http:v0.2.2")

model_name = string // Name of your LLM (for example: "gemma-7b")
parameters_path = string // Path to the paramters for your model
metrics_port = optional(number) // Emit Jetstream metrics on this port of each container
metrics_scrape_interval = optional(number) // Interval for scraping metrics (default: 10s)
model_name = string // Name of your LLM (for example: "gemma-7b")
parameters_path = string // Path to the parameters for your model

metrics = optional(object({ // Settings for metrics server
server = optional(object({ // Settings for Jetstream server metrics
port = number
scrape_interval = number
}))
system = optional(object({ // Settings for TPU metrics
scrape_interval = number
}))
}))

accelerator_selectors = object({
topology = string
Expand All @@ -45,6 +53,16 @@ variable "maxengine_deployment_settings" {
condition = contains(["gemma-7b", "llama2-7b", "llama2-13b"], var.maxengine_deployment_settings.model_name)
error_message = "model_name must be one of \"gemma-7b\", \"llama2-7b\", or \"llama2-13b\""
}

validation {
condition = try(var.maxengine_deployment_settings.metrics.server.scrape_interval >= 5, true)
error_message = "Server metrics scrape interval may not be shorter than 5s"
}

validation {
condition = try(var.maxengine_deployment_settings.metrics.system.scrape_interval >= 15, true)
error_message = "TPU system metrics scrape interval may not be shorter than 15s"
}
}

variable "hpa_config" {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,10 @@ For deploying autoscaling components via terraform, a few more variables to be s

```
maxengine_deployment_settings = {
metrics_port = <same as above>
metrics_scrape_interval
metrics = {
server = {
port = <same as above> # port to scrape Jetstream server metrics from
scrape_interval = 5 # how often to scrape, in seconds (must be at least 5)
}
}
}
hpa_config = {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
maxengine_deployment_settings = {
metrics_port = 9100
metrics_scrape_interval = 10
metrics = {
server = {
port = 9100
scrape_interval : 10
}
}

accelerator_selectors = {
topology = "2x4"
accelerator = "tpu-v5-lite-podslice"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,18 @@ variable "maxengine_deployment_settings" {
maxengine_server_image = optional(string)
jetstream_http_server_image = optional(string)

model_name = string // Name of your LLM (for example: "gemma-7b")
parameters_path = string // Path to the parameters for your model
metrics_port = optional(number) // Emit Jetstream metrics on this port of each container
metrics_scrape_interval = optional(number) // Interval for scraping metrics (default: 10s)
model_name = string // Name of your LLM (for example: "gemma-7b")
parameters_path = string // Path to the parameters for your model

metrics = optional(object({ // Settings for metrics server
server = optional(object({ // Settings for Jetstream server metrics
port = number
scrape_interval = number
}))
system = optional(object({ // Settings for TPU metrics
scrape_interval = number
}))
}))

accelerator_selectors = object({
topology = string
Expand Down

0 comments on commit 54531da

Please sign in to comment.