diff --git a/dagster_skypilot/assets.py b/dagster_skypilot/assets.py
index b64fe1e..d0f6644 100644
--- a/dagster_skypilot/assets.py
+++ b/dagster_skypilot/assets.py
@@ -1,8 +1,11 @@
+import json
 import os
-from pathlib import Path
+import sky
+import yaml
 from dagster import AssetExecutionContext, asset
 from dagster_shell import execute_shell_command
+from upath import UPath
 
 from dagster_skypilot.consts import DEPLOYMENT_TYPE
 
@@ -15,8 +18,8 @@ def populate_keyfiles():
     This reads the credentials for AWS and Lambda Labs from env vars (set in
     the Dagster Cloud UI) and then populates the expected key files accordingly.
     """
-    lambda_key_file = Path.home() / ".lambda_cloud" / "lambda_keys"
-    aws_key_file = Path.home() / ".aws" / "credentials"
+    lambda_key_file = UPath.home() / ".lambda_cloud" / "lambda_keys"
+    aws_key_file = UPath.home() / ".aws" / "credentials"
 
     # Don't overwrite local keys, but always populate them dynamically in
     # Dagster Cloud
@@ -35,17 +38,79 @@ def populate_keyfiles():
         )
 
 
+def get_metrics(context: AssetExecutionContext, bucket):
+    with (UPath(bucket) / context.run_id / "train_results.json").open("r") as f:
+        return json.load(f)
+
+
 @asset(group_name="ai")
-def skypilot_model(context: AssetExecutionContext) -> None:
+def skypilot_yaml(context: AssetExecutionContext) -> None:
     # SkyPilot doesn't support reading credentials from environment variables.
     # So, we need to populate the required keyfiles.
     populate_keyfiles()
 
-    execute_shell_command(
-        "sky launch -c dnn dnn.yaml --yes -i 5 --down",
-        output_logging="STREAM",
-        log=context.log,
-        cwd=str(Path(__file__).parent.parent),
-        # Disable color and styling for rich
-        env=os.environ | {"TERM": "dumb"},
+    skypilot_bucket = os.getenv("SKYPILOT_BUCKET")
+
+    try:
+        execute_shell_command(
+            "sky launch -c gemma finetune.yaml --env HF_TOKEN --env DAGSTER_RUN_ID --yes",  # -i 5 --down
+            output_logging="STREAM",
+            log=context.log,
+            cwd=str(UPath(__file__).parent),
+            # Disable color and styling for rich
+            env={
+                **os.environ,
+                "TERM": "dumb",
+                "NO_COLOR": "1",
+                "HF_TOKEN": os.getenv("HF_TOKEN", ""),
+                "DAGSTER_RUN_ID": context.run_id,
+            },
+        )
+
+        context.add_output_metadata(get_metrics(context, skypilot_bucket))
+    finally:
+        execute_shell_command(
+            "sky down --all --yes",
+            output_logging="STREAM",
+            log=context.log,
+            cwd=str(UPath(__file__).parent),
+            # Disable color and styling for rich
+            env={
+                **os.environ,
+                "TERM": "dumb",
+                "NO_COLOR": "1",
+            },
+        )
+
+
+@asset(group_name="ai")
+def skypilot_python_api(context: AssetExecutionContext) -> None:
+    # SkyPilot doesn't support reading credentials from environment variables.
+    # So, we need to populate the required keyfiles.
+    populate_keyfiles()
+
+    skypilot_bucket = os.getenv("SKYPILOT_BUCKET")
+
+    # The parent of the current script
+    parent_dir = UPath(__file__).parent
+    yaml_file = parent_dir / "finetune.yaml"
+    with yaml_file.open("r", encoding="utf-8") as f:
+        task_config = yaml.safe_load(f)
+
+    task = sky.Task().from_yaml_config(
+        config=task_config,
+        env_overrides={  # type: ignore
+            "HF_TOKEN": os.getenv("HF_TOKEN", ""),
+            "DAGSTER_RUN_ID": context.run_id,
+            "SCRIPT_WORKING_DIR": str(parent_dir / "scripts"),
+            "BUCKET_NAME": skypilot_bucket,
+        },
     )
+
+    try:
+        sky.launch(task, cluster_name="gemma", idle_minutes_to_autostop=5)  # type: ignore
+        context.add_output_metadata(get_metrics(context, skypilot_bucket))
+
+    finally:
+        sky.down("gemma")
...
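Note that `get_metrics` expects the training job launched from `finetune.yaml` to have written `train_results.json` under `<bucket>/<run_id>/`. The producer side of that contract is not visible in this diff; below is a minimal sketch of what it might look like, assuming the remote script can import `upath` and receives the `BUCKET_NAME` and `DAGSTER_RUN_ID` env vars (the Python-API asset passes both via `env_overrides`; the shell asset as written only forwards `DAGSTER_RUN_ID`). The `write_results` helper is hypothetical, not part of this change:

```python
# Hypothetical producer side of the get_metrics() contract: write
# train_results.json under $BUCKET_NAME/$DAGSTER_RUN_ID/ so the Dagster
# asset can pick it up. Assumes upath/s3fs are available on the remote node.
import json
import os

from upath import UPath


def write_results(metrics: dict) -> None:
    out_dir = UPath(os.environ["BUCKET_NAME"]) / os.environ["DAGSTER_RUN_ID"]
    with (out_dir / "train_results.json").open("w") as f:
        json.dump(metrics, f)


# Placeholder values, for illustration only
write_results({"train_loss": 1.23, "epoch": 3.0})
```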
diff --git a/setup.py b/setup.py
index 48b27bb..d92dd8e 100644
--- a/setup.py
+++ b/setup.py
@@ -6,8 +6,10 @@
     install_requires=[
         "dagster>=1.6.0,<1.7.0",
         "dagster-cloud",
-        "skypilot[aws,azure,gcp]",
         "dagster-shell",
+        "universal_pathlib",
+        "s3fs",
+        "skypilot[aws,azure,gcp]",
     ],
     extras_require={
         "dev": [
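The two new dependencies back the `UPath` usage in `assets.py`: `universal_pathlib` provides the `UPath` class, and `s3fs` is the fsspec backend that handles the `s3://` scheme, so `UPath(bucket)` works whether the bucket env var points at a local directory or an S3 URL. A quick sketch, with placeholder names for the bucket and run id:

```python
# UPath keeps pathlib semantics but dispatches to an fsspec filesystem based
# on the URL scheme, so the same code path reads local files and S3 objects.
from upath import UPath

bucket = UPath("s3://my-bucket")   # placeholder; SKYPILOT_BUCKET in the assets
run_id = "some-dagster-run-id"     # placeholder for context.run_id
results = bucket / run_id / "train_results.json"
print(results.read_text())         # needs s3fs installed and AWS credentials
```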