Commit 4357e8b
Have the YAML and Python equivalents as different assets, and read the metrics from S3 and log them as metadata in Dagster.
mjkanji committed Mar 4, 2024
1 parent 04baae3 commit 4357e8b
Showing 2 changed files with 79 additions and 12 deletions.
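
Note on the metrics hand-off: both assets pass the Dagster run ID into the SkyPilot task (via --env DAGSTER_RUN_ID or env_overrides), and the new get_metrics() helper later reads <bucket>/<run_id>/train_results.json back out of the bucket. The writer side lives in the fine-tuning script, which is not part of this diff; below is a minimal sketch of what it plausibly looks like, assuming the script receives BUCKET_NAME and DAGSTER_RUN_ID as environment variables (save_metrics and the metrics payload are hypothetical names, not code from this commit):

    import json
    import os

    from upath import UPath  # universal_pathlib + s3fs, added to setup.py below


    def save_metrics(metrics: dict) -> None:
        # Write to the exact key that get_metrics() in assets.py reads:
        # <bucket>/<run_id>/train_results.json
        bucket = os.environ["BUCKET_NAME"]
        run_id = os.environ["DAGSTER_RUN_ID"]
        with (UPath(bucket) / run_id / "train_results.json").open("w") as f:
            json.dump(metrics, f)


    save_metrics({"train_loss": 0.42})  # hypothetical metrics payload
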
87 changes: 76 additions & 11 deletions dagster_skypilot/assets.py
@@ -1,8 +1,11 @@
+import json
 import os
-from pathlib import Path
 
+import sky
+import yaml
 from dagster import AssetExecutionContext, asset
 from dagster_shell import execute_shell_command
+from upath import UPath
 
 from dagster_skypilot.consts import DEPLOYMENT_TYPE
 
@@ -15,8 +18,8 @@ def populate_keyfiles():
     This reads the credentials for AWS and Lambda Labs from env vars (set in the
     Dagster Cloud UI) and then populates the expected key files accordingly.
     """
-    lambda_key_file = Path.home() / ".lambda_cloud" / "lambda_keys"
-    aws_key_file = Path.home() / ".aws" / "credentials"
+    lambda_key_file = UPath.home() / ".lambda_cloud" / "lambda_keys"
+    aws_key_file = UPath.home() / ".aws" / "credentials"
 
     # Don't overwrite local keys, but always populate them dynamically in
     # Dagster Cloud
@@ -35,17 +38,79 @@ def populate_keyfiles():
 )
 
 
+def get_metrics(context: AssetExecutionContext, bucket):
+    with (UPath(bucket) / context.run_id / "train_results.json").open("r") as f:
+        return json.load(f)
+
+
 @asset(group_name="ai")
-def skypilot_model(context: AssetExecutionContext) -> None:
+def skypilot_yaml(context: AssetExecutionContext) -> None:
     # SkyPilot doesn't support reading credentials from environment variables.
     # So, we need to populate the required keyfiles.
     populate_keyfiles()
 
-    execute_shell_command(
-        "sky launch -c dnn dnn.yaml --yes -i 5 --down",
-        output_logging="STREAM",
-        log=context.log,
-        cwd=str(Path(__file__).parent.parent),
-        # Disable color and styling for rich
-        env=os.environ | {"TERM": "dumb"},
-    )
+    skypilot_bucket = os.getenv("SKYPILOT_BUCKET")
+
+    try:
+        execute_shell_command(
+            "sky launch -c gemma finetune.yaml --env HF_TOKEN --env DAGSTER_RUN_ID --yes",  # -i 5 --down
+            output_logging="STREAM",
+            log=context.log,
+            cwd=str(UPath(__file__).parent),
+            # Disable color and styling for rich
+            env={
+                **os.environ,
+                "TERM": "dumb",
+                "NO_COLOR": "1",
+                "HF_TOKEN": os.getenv("HF_TOKEN", ""),
+                "DAGSTER_RUN_ID": context.run_id,
+            },
+        )
+
+        context.add_output_metadata(get_metrics(context, skypilot_bucket))
+    finally:
+        execute_shell_command(
+            "sky down --all --yes",
+            output_logging="STREAM",
+            log=context.log,
+            cwd=str(UPath(__file__).parent),
+            # Disable color and styling for rich
+            env={
+                **os.environ,
+                "TERM": "dumb",
+                "NO_COLOR": "1",
+            },
+        )
+
+
+@asset(group_name="ai")
+def skypilot_python_api(context: AssetExecutionContext) -> None:
+    # SkyPilot doesn't support reading credentials from environment variables.
+    # So, we need to populate the required keyfiles.
+    populate_keyfiles()
+
+    skypilot_bucket = os.getenv("SKYPILOT_BUCKET")
+
+    # The parent of the current script
+    parent_dir = UPath(__file__).parent
+    yaml_file = parent_dir / "finetune.yaml"
+    with yaml_file.open("r", encoding="utf-8") as f:
+        task_config = yaml.safe_load(f)
+
+    task = sky.Task().from_yaml_config(
+        config=task_config,
+        env_overrides={  # type: ignore
+            "HF_TOKEN": os.getenv("HF_TOKEN", ""),
+            "DAGSTER_RUN_ID": context.run_id,
+            "SCRIPT_WORKING_DIR": str(parent_dir / "scripts"),
+            "BUCKET_NAME": skypilot_bucket,
+        },
+    )
+
+    try:
+        sky.launch(task, cluster_name="gemma", idle_minutes_to_autostop=5)  # type: ignore
+        context.add_output_metadata(get_metrics(context, skypilot_bucket))
+
+    finally:
+        sky.down("gemma")
...
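
Neither asset is wired into a Definitions object within this commit; assuming the repo's usual layout, exposing both for independent materialization could look like the sketch below (the definitions module itself is not shown in this diff):

    from dagster import Definitions

    from dagster_skypilot.assets import skypilot_python_api, skypilot_yaml

    # Both assets live in the "ai" group and can be materialized independently,
    # which makes the shell-based and Python-API launch paths easy to compare.
    defs = Definitions(assets=[skypilot_yaml, skypilot_python_api])
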
4 changes: 3 additions & 1 deletion setup.py
@@ -6,8 +6,10 @@
     install_requires=[
         "dagster>=1.6.0,<1.7.0",
         "dagster-cloud",
-        "skypilot[aws,azure,gcp]",
         "dagster-shell",
+        "universal_pathlib",
+        "s3fs",
+        "skypilot[aws,azure,gcp]",
     ],
     extras_require={
         "dev": [
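The two new dependencies in setup.py work as a pair: universal_pathlib provides the UPath class now used throughout assets.py, and s3fs is the fsspec backend that lets the same UPath code resolve s3:// URLs as well as local paths. A quick illustration of the pattern, with a hypothetical bucket and run ID:

    from upath import UPath

    # With s3fs installed, UPath handles s3:// URLs through the same
    # pathlib-style API that populate_keyfiles() uses for local files.
    results = UPath("s3://my-skypilot-bucket") / "some-run-id" / "train_results.json"
    with results.open("r") as f:
        print(f.read())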
