diff --git a/llm-lora-finetuning/pipelines/train_accelerated.py b/llm-lora-finetuning/pipelines/train_accelerated.py
index eecd228b..ff3ca92a 100644
--- a/llm-lora-finetuning/pipelines/train_accelerated.py
+++ b/llm-lora-finetuning/pipelines/train_accelerated.py
@@ -18,13 +18,12 @@
 
 from steps import (
     evaluate_model,
-    finetune,
+    finetune_accelerated,
     prepare_data,
     promote,
     log_metadata_from_step_artifact,
 )
 from zenml import pipeline
-from zenml.integrations.huggingface.steps import run_with_accelerate
 
 
 @pipeline
@@ -74,7 +73,7 @@ def llm_peft_full_finetune(
         id="log_metadata_evaluation_base"
     )
 
-    ft_model_dir = run_with_accelerate(finetune)(
+    ft_model_dir = finetune_accelerated(
         base_model_id=base_model_id,
         dataset_dir=datasets_dir,
         use_fast=use_fast,
diff --git a/llm-lora-finetuning/steps/__init__.py b/llm-lora-finetuning/steps/__init__.py
index 317b6b4c..60912212 100644
--- a/llm-lora-finetuning/steps/__init__.py
+++ b/llm-lora-finetuning/steps/__init__.py
@@ -16,7 +16,7 @@
 #
 
 from .evaluate_model import evaluate_model
-from .finetune import finetune
+from .finetune import finetune, finetune_accelerated
 from .prepare_datasets import prepare_data
 from .promote import promote
 from .log_metadata import log_metadata_from_step_artifact
diff --git a/llm-lora-finetuning/steps/finetune.py b/llm-lora-finetuning/steps/finetune.py
index ece4b1b0..2e362103 100644
--- a/llm-lora-finetuning/steps/finetune.py
+++ b/llm-lora-finetuning/steps/finetune.py
@@ -32,6 +32,7 @@
 from zenml.materializers import BuiltInMaterializer
 from zenml.utils.cuda_utils import cleanup_gpu_memory
 from zenml.client import Client
+from zenml.integrations.huggingface.steps import run_with_accelerate
 
 logger = get_logger(__name__)
 
@@ -184,3 +185,6 @@
         )
 
     return ft_model_dir
+
+
+finetune_accelerated = run_with_accelerate(finetune, num_processes=2, multi_gpu=True, mixed_precision="bf16")
\ No newline at end of file
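
The core change above moves the run_with_accelerate wrapping from the pipeline
call site (run_with_accelerate(finetune)(...)) to module level in
steps/finetune.py, so the accelerated step is constructed once at import time
with fixed launch settings and the pipeline simply calls finetune_accelerated.
A minimal sketch of the resulting pattern, assuming an illustrative step body
and parameter values (the real training logic and full argument list live in
steps/finetune.py; only the names and launch kwargs below come from the diff):

    from zenml import pipeline, step
    from zenml.integrations.huggingface.steps import run_with_accelerate


    @step
    def finetune(base_model_id: str, dataset_dir: str) -> str:
        # Illustrative body only. Under Accelerate, this step body runs
        # once per spawned process (num_processes=2 below).
        return "path/to/ft_model"


    # Module-level wrapping, as in the last hunk above. The kwargs mirror
    # `accelerate launch` flags: 2 processes, multi-GPU, bf16 precision.
    finetune_accelerated = run_with_accelerate(
        finetune, num_processes=2, multi_gpu=True, mixed_precision="bf16"
    )


    @pipeline
    def llm_peft_full_finetune(base_model_id: str, datasets_dir: str):
        # The pipeline calls the pre-wrapped step directly, replacing the
        # previous call-site form `run_with_accelerate(finetune)(...)`.
        finetune_accelerated(
            base_model_id=base_model_id, dataset_dir=datasets_dir
        )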