From 8884e83cbd1f44a0ea1b3286e9bd1b4f08e7f06b Mon Sep 17 00:00:00 2001 From: Малахов Алексей Павлович Date: Thu, 12 Sep 2024 12:37:09 +0000 Subject: [PATCH] add clearml logging --- configs/exp/kto/kto.json | 2 +- .../train/classification/classification.json | 2 +- configs/exp/train/dpo/dpo.json | 2 +- configs/exp/train/multimodal/c_abs.json | 2 +- configs/exp/train/multimodal/mlp.json | 2 +- configs/exp/train/rag/end2end_rag.json | 2 +- configs/exp/train/rm/rm.json | 2 +- configs/exp/train/sft/sft.json | 2 +- .../configs/train/classification/base.json | 2 +- tests/fixtures/configs/train/ddpo/base.json | 2 +- tests/fixtures/configs/train/dpo/base.json | 2 +- tests/fixtures/configs/train/dpo/simpo.json | 2 +- tests/fixtures/configs/train/kto/base.json | 2 +- .../multimodal/llama_c_abs_clip_pickle.json | 2 +- .../multimodal/llama_llava_base_clip.json | 2 +- .../multimodal/llama_llava_clip_pickle.json | 2 +- tests/fixtures/configs/train/rag/base.json | 2 +- tests/fixtures/configs/train/rm/base.json | 2 +- tests/fixtures/configs/train/sft/base.json | 2 +- .../configs/train/sft/prompt_tuning.json | 2 +- .../train/sft/resume_from_checkpoint.json | 2 +- .../configs/train/sft/sft_with_rm_metric.json | 2 +- .../experiment_settings_config.json | 2 +- turbo_alignment/common/logging/clearml.py | 15 +++ .../common/logging/weights_and_biases.py | 2 +- .../common/tf/callbacks/__init__.py | 2 +- turbo_alignment/common/tf/callbacks/common.py | 32 +++-- .../common/tf/callbacks/logging.py | 118 ++++++++++++++++++ turbo_alignment/pipelines/mixin/__init__.py | 2 +- turbo_alignment/pipelines/mixin/logging.py | 40 +++++- turbo_alignment/pipelines/train/base.py | 18 ++- turbo_alignment/settings/logging/clearml.py | 12 ++ .../{ => logging}/weights_and_biases.py | 8 +- .../settings/pipelines/train/base.py | 5 +- turbo_alignment/trainers/ddpo.py | 4 +- turbo_alignment/trainers/dpo.py | 4 +- turbo_alignment/trainers/kto.py | 4 +- turbo_alignment/trainers/multigpu.py | 4 +- tutorials/dpo/dpo.json | 2 +- tutorials/kto/kto.json | 2 +- tutorials/multimodal/multimodal.json | 2 +- tutorials/rm/rm.json | 2 +- tutorials/sft/sft.json | 2 +- 43 files changed, 248 insertions(+), 78 deletions(-) create mode 100644 turbo_alignment/common/logging/clearml.py create mode 100755 turbo_alignment/common/tf/callbacks/logging.py create mode 100644 turbo_alignment/settings/logging/clearml.py rename turbo_alignment/settings/{ => logging}/weights_and_biases.py (68%) diff --git a/configs/exp/kto/kto.json b/configs/exp/kto/kto.json index 91ccffa..8213302 100644 --- a/configs/exp/kto/kto.json +++ b/configs/exp/kto/kto.json @@ -127,7 +127,7 @@ "use_ref_model": true, "deepspeed": "configs/exp/deepspeed/stage3.json" }, - "wandb_settings": { + "logging_settings": { "project_name": "alignment", "run_name": "sft", "entity": "turbo-alignment" diff --git a/configs/exp/train/classification/classification.json b/configs/exp/train/classification/classification.json index 5257230..8ab1f63 100755 --- a/configs/exp/train/classification/classification.json +++ b/configs/exp/train/classification/classification.json @@ -112,7 +112,7 @@ }, "save_total_limit": 10 }, - "wandb_settings": { + "logging_settings": { "project_name": "alignment", "run_name": "classification", "entity": "turbo-alignment" diff --git a/configs/exp/train/dpo/dpo.json b/configs/exp/train/dpo/dpo.json index 61bb787..ba6e9fe 100755 ---
a/configs/exp/train/dpo/dpo.json +++ b/configs/exp/train/dpo/dpo.json @@ -136,7 +136,7 @@ "use_ref_model": true, "deepspeed": "configs/exp/deepspeed/stage2.json" }, - "wandb_settings": { + "logging_settings": { "project_name": "alignment", "run_name": "sft", "entity": "turbo-alignment" diff --git a/configs/exp/train/multimodal/c_abs.json b/configs/exp/train/multimodal/c_abs.json index 46f00f7..54a9860 100644 --- a/configs/exp/train/multimodal/c_abs.json +++ b/configs/exp/train/multimodal/c_abs.json @@ -127,7 +127,7 @@ "load_best_model_at_end": false, "deepspeed": "configs/exp/deepspeed/ds_config_stage_2.json" }, - "wandb_settings": { + "logging_settings": { "project_name": "alignment", "run_name": "multimodal", "entity": "turbo-alignment" diff --git a/configs/exp/train/multimodal/mlp.json b/configs/exp/train/multimodal/mlp.json index 68a8925..a2fc824 100644 --- a/configs/exp/train/multimodal/mlp.json +++ b/configs/exp/train/multimodal/mlp.json @@ -127,7 +127,7 @@ "load_best_model_at_end": false, "deepspeed": "configs/exp/deepspeed/ds_config_stage_2.json" }, - "wandb_settings": null, + "logging_settings": null, "log_path": "train_output", "modality_encoder_settings_mapping": { "image": { diff --git a/configs/exp/train/rag/end2end_rag.json b/configs/exp/train/rag/end2end_rag.json index e0b5cad..4fa4641 100755 --- a/configs/exp/train/rag/end2end_rag.json +++ b/configs/exp/train/rag/end2end_rag.json @@ -139,7 +139,7 @@ "max_grad_norm": 0.11, "save_total_limit": 3 }, - "wandb_settings": { + "logging_settings": { "project_name": "alignment", "run_name": "rag", "entity": "biglm" diff --git a/configs/exp/train/rm/rm.json b/configs/exp/train/rm/rm.json index b9bd08c..cd6dbaf 100755 --- a/configs/exp/train/rm/rm.json +++ b/configs/exp/train/rm/rm.json @@ -112,7 +112,7 @@ "save_total_limit": 1, "deepspeed": "configs/exp/deepspeed/stage2.json" }, - "wandb_settings": { + "logging_settings": { "project_name": "alignment", "run_name": "rm", "entity": "turbo-alignment" diff --git a/configs/exp/train/sft/sft.json b/configs/exp/train/sft/sft.json index f746e83..be3c0e0 100755 --- a/configs/exp/train/sft/sft.json +++ b/configs/exp/train/sft/sft.json @@ -132,7 +132,7 @@ "weight_decay": 0.01, "max_grad_norm": 0.11 }, - "wandb_settings": { + "logging_settings": { "project_name": "alignment", "run_name": "sft", "entity": "turbo-alignment" diff --git a/tests/fixtures/configs/train/classification/base.json b/tests/fixtures/configs/train/classification/base.json index c1f4158..127fd26 100755 --- a/tests/fixtures/configs/train/classification/base.json +++ b/tests/fixtures/configs/train/classification/base.json @@ -114,7 +114,7 @@ }, "no_cuda": true }, - "wandb_settings": { + "logging_settings": { "project_name": "alignment", "run_name": "classification", "entity": "turbo-alignment" diff --git a/tests/fixtures/configs/train/ddpo/base.json b/tests/fixtures/configs/train/ddpo/base.json index 57e6b2d..82c48da 100755 --- a/tests/fixtures/configs/train/ddpo/base.json +++ b/tests/fixtures/configs/train/ddpo/base.json @@ -155,7 +155,7 @@ "no_cuda": true, "max_grad_norm": 1.0 }, - "wandb_settings": { + "logging_settings": { "project_name": "alignment", "run_name": "ddpo", "entity": "turbo-alignment" diff --git a/tests/fixtures/configs/train/dpo/base.json b/tests/fixtures/configs/train/dpo/base.json index 514d462..3140625 100755 --- a/tests/fixtures/configs/train/dpo/base.json +++ b/tests/fixtures/configs/train/dpo/base.json @@ -136,7 +136,7 @@ "use_sft_model": true, "no_cuda": true }, - "wandb_settings": { + 
"logging_settings": { "project_name": "alignment", "run_name": "dpo", "entity": "turbo-alignment" diff --git a/tests/fixtures/configs/train/dpo/simpo.json b/tests/fixtures/configs/train/dpo/simpo.json index 01233b8..293093a 100755 --- a/tests/fixtures/configs/train/dpo/simpo.json +++ b/tests/fixtures/configs/train/dpo/simpo.json @@ -130,7 +130,7 @@ "use_sft_model": true, "no_cuda": true }, - "wandb_settings": { + "logging_settings": { "project_name": "alignment", "run_name": "dpo", "entity": "turbo-alignment" diff --git a/tests/fixtures/configs/train/kto/base.json b/tests/fixtures/configs/train/kto/base.json index af53a14..a06d499 100755 --- a/tests/fixtures/configs/train/kto/base.json +++ b/tests/fixtures/configs/train/kto/base.json @@ -110,7 +110,7 @@ "save_total_limit": 1, "no_cuda": true }, - "wandb_settings": { + "logging_settings": { "project_name": "alignment", "run_name": "kto", "entity": "turbo-alignment" diff --git a/tests/fixtures/configs/train/multimodal/llama_c_abs_clip_pickle.json b/tests/fixtures/configs/train/multimodal/llama_c_abs_clip_pickle.json index 49c2b9c..38130b4 100644 --- a/tests/fixtures/configs/train/multimodal/llama_c_abs_clip_pickle.json +++ b/tests/fixtures/configs/train/multimodal/llama_c_abs_clip_pickle.json @@ -125,7 +125,7 @@ "save_total_limit": 1, "no_cuda": true }, - "wandb_settings": { + "logging_settings": { "project_name": "alignment", "run_name": "multimodal", "entity": "turbo-alignment" diff --git a/tests/fixtures/configs/train/multimodal/llama_llava_base_clip.json b/tests/fixtures/configs/train/multimodal/llama_llava_base_clip.json index 01b5dce..d95ee19 100644 --- a/tests/fixtures/configs/train/multimodal/llama_llava_base_clip.json +++ b/tests/fixtures/configs/train/multimodal/llama_llava_base_clip.json @@ -125,7 +125,7 @@ "save_total_limit": 1, "no_cuda": true }, - "wandb_settings": { + "logging_settings": { "project_name": "alignment", "run_name": "multimodal", "entity": "turbo-alignment" diff --git a/tests/fixtures/configs/train/multimodal/llama_llava_clip_pickle.json b/tests/fixtures/configs/train/multimodal/llama_llava_clip_pickle.json index a16a69c..71e04b0 100644 --- a/tests/fixtures/configs/train/multimodal/llama_llava_clip_pickle.json +++ b/tests/fixtures/configs/train/multimodal/llama_llava_clip_pickle.json @@ -125,7 +125,7 @@ "save_total_limit": 1, "no_cuda": true }, - "wandb_settings": { + "logging_settings": { "project_name": "alignment", "run_name": "multimodal", "entity": "turbo-alignment" diff --git a/tests/fixtures/configs/train/rag/base.json b/tests/fixtures/configs/train/rag/base.json index 5fc808e..d54b66f 100755 --- a/tests/fixtures/configs/train/rag/base.json +++ b/tests/fixtures/configs/train/rag/base.json @@ -145,7 +145,7 @@ "no_cuda": true, "save_total_limit": 1 }, - "wandb_settings": { + "logging_settings": { "project_name": "alignment", "run_name": "rag", "entity": "turbo-alignment" diff --git a/tests/fixtures/configs/train/rm/base.json b/tests/fixtures/configs/train/rm/base.json index e3fed58..e1d5e21 100755 --- a/tests/fixtures/configs/train/rm/base.json +++ b/tests/fixtures/configs/train/rm/base.json @@ -112,7 +112,7 @@ "save_total_limit": 1, "no_cuda": true }, - "wandb_settings": { + "logging_settings": { "project_name": "alignment", "run_name": "rm", "entity": "turbo-alignment" diff --git a/tests/fixtures/configs/train/sft/base.json b/tests/fixtures/configs/train/sft/base.json index 1a48d35..4cb4cdb 100755 --- a/tests/fixtures/configs/train/sft/base.json +++ b/tests/fixtures/configs/train/sft/base.json @@ -122,7 
+122,7 @@ "save_total_limit": 1, "no_cuda": true }, - "wandb_settings": { + "logging_settings": { "project_name": "alignment", "run_name": "sft", "entity": "turbo-alignment" diff --git a/tests/fixtures/configs/train/sft/prompt_tuning.json b/tests/fixtures/configs/train/sft/prompt_tuning.json index ca3eb88..4ef7df4 100755 --- a/tests/fixtures/configs/train/sft/prompt_tuning.json +++ b/tests/fixtures/configs/train/sft/prompt_tuning.json @@ -117,7 +117,7 @@ "save_total_limit": 1, "no_cuda": true }, - "wandb_settings": { + "logging_settings": { "project_name": "alignment", "run_name": "sft", "entity": "turbo-alignment" diff --git a/tests/fixtures/configs/train/sft/resume_from_checkpoint.json b/tests/fixtures/configs/train/sft/resume_from_checkpoint.json index f7e75a4..28c3f03 100755 --- a/tests/fixtures/configs/train/sft/resume_from_checkpoint.json +++ b/tests/fixtures/configs/train/sft/resume_from_checkpoint.json @@ -101,7 +101,7 @@ "save_total_limit": 1, "no_cuda": true }, - "wandb_settings": { + "logging_settings": { "project_name": "alignment", "run_name": "sft", "entity": "turbo-alignment" diff --git a/tests/fixtures/configs/train/sft/sft_with_rm_metric.json b/tests/fixtures/configs/train/sft/sft_with_rm_metric.json index ee22909..19bedde 100755 --- a/tests/fixtures/configs/train/sft/sft_with_rm_metric.json +++ b/tests/fixtures/configs/train/sft/sft_with_rm_metric.json @@ -144,7 +144,7 @@ "save_total_limit": 1, "no_cuda": true }, - "wandb_settings": { + "logging_settings": { "project_name": "alignment", "run_name": "sft", "entity": "turbo-alignment" diff --git a/tests/fixtures/models/llama2_tiny_fine_tuned_with_adapters/experiment_settings_config.json b/tests/fixtures/models/llama2_tiny_fine_tuned_with_adapters/experiment_settings_config.json index 8f6cee4..3425c8f 100755 --- a/tests/fixtures/models/llama2_tiny_fine_tuned_with_adapters/experiment_settings_config.json +++ b/tests/fixtures/models/llama2_tiny_fine_tuned_with_adapters/experiment_settings_config.json @@ -51,7 +51,7 @@ }, "log_path": "train_output", "seed": 42, - "wandb_settings": { + "logging_settings": { "project_name": "alignment", "run_name": "fine_tune_llama", "entity": "vladislavkruglikov", diff --git a/turbo_alignment/common/logging/clearml.py b/turbo_alignment/common/logging/clearml.py new file mode 100644 index 0000000..c429439 --- /dev/null +++ b/turbo_alignment/common/logging/clearml.py @@ -0,0 +1,15 @@ +from typing import Any + +from clearml import Task + +from turbo_alignment.settings.logging.clearml import ClearMLSettings + + +def create_clearml_task(parameters: ClearMLSettings, config: dict[str, Any] | None = None) -> Task: + clearml_task = Task.init( + task_name=parameters.task_name, project_name=parameters.project_name, continue_last_task=True # FIXME? 
+ ) + + clearml_task.connect_configuration(config, name='HyperParameters') + + return clearml_task diff --git a/turbo_alignment/common/logging/weights_and_biases.py b/turbo_alignment/common/logging/weights_and_biases.py index c62a205..112e585 100755 --- a/turbo_alignment/common/logging/weights_and_biases.py +++ b/turbo_alignment/common/logging/weights_and_biases.py @@ -4,7 +4,7 @@ from wandb.sdk.wandb_run import Run import wandb -from turbo_alignment.settings.weights_and_biases import WandbSettings +from turbo_alignment.settings.logging.weights_and_biases import WandbSettings def create_wandb_run(parameters: WandbSettings, config: dict[str, Any] | None = None) -> Run | RunDisabled: diff --git a/turbo_alignment/common/tf/callbacks/__init__.py b/turbo_alignment/common/tf/callbacks/__init__.py index 1e7dab5..6888c94 100755 --- a/turbo_alignment/common/tf/callbacks/__init__.py +++ b/turbo_alignment/common/tf/callbacks/__init__.py @@ -1,3 +1,3 @@ -from .common import EvaluateFirstStepCallback, WandbMetricsCallbackHandler +from .common import EvaluateFirstStepCallback, MetricsCallbackHandler from .s3 import CheckpointUploaderCallback from .weights_and_biases import BaseWandbCallback diff --git a/turbo_alignment/common/tf/callbacks/common.py b/turbo_alignment/common/tf/callbacks/common.py index 27504d5..108e251 100755 --- a/turbo_alignment/common/tf/callbacks/common.py +++ b/turbo_alignment/common/tf/callbacks/common.py @@ -11,7 +11,6 @@ ) from transformers.trainer_callback import CallbackHandler -import wandb from turbo_alignment.settings.metric import MetricResults @@ -23,7 +22,7 @@ def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: T return control -class WandbMetricsCallbackHandler(CallbackHandler): +class MetricsCallbackHandler(CallbackHandler): def __init__(self, *args, **kwargs) -> None: self.ref_model = kwargs.pop('ref_model', None) self.sft_model = kwargs.pop('sft_model', None) @@ -50,46 +49,43 @@ def on_evaluate( if isinstance(results, list): gathered_results: list[list[MetricResults]] = gather_object(results) - gathered_float_wandb_data: dict[str, list[Any]] = defaultdict(list) - gathered_table_wandb_data: dict[str, list[str] | list[list[str]]] = defaultdict(list) + gathered_float_data: dict[str, list[Any]] = defaultdict(list) + gathered_table_data: dict[str, list[str] | list[list[str]]] = defaultdict(list) average_functions: dict[str, Callable] = {} for single_process_results in gathered_results: for metric_result in single_process_results: for score in metric_result.element_wise_scores: if metric_result.need_average: - gathered_float_wandb_data[score.label].extend(score.values) + gathered_float_data[score.label].extend(score.values) average_functions[score.label] = score.average_function else: - gathered_table_wandb_data[score.label].extend(score.values) + gathered_table_data[score.label].extend(score.values) logs = { 'cherry_pick_' + k: average_functions[k](list(*zip(*v))) if isinstance(v[0], tuple) else average_functions[k](v) - for k, v in gathered_float_wandb_data.items() + for k, v in gathered_float_data.items() } self.call_event('on_log', args, state, control, logs=logs) - wandb_table_cols = list(gathered_table_wandb_data.keys()) - wandb_table_data = list(gathered_table_wandb_data.values()) + table_cols = list(gathered_table_data.keys()) + table_data = list(gathered_table_data.values()) - flattened_wandb_data = [ - sum(item, []) if isinstance(item, list) and isinstance(item[0], list) else item - for item in wandb_table_data + flattened_table_data = [ 
+ sum(item, []) if isinstance(item, list) and isinstance(item[0], list) else item for item in table_data ] # flatten list[lists] to display all outputs in wandb table - wandb_data = pd.DataFrame(columns=wandb_table_cols, data=list(zip(*flattened_wandb_data))) - dataset_prefixes = set(col.split('@@')[0] for col in wandb_data.columns) + cherrypicks_table_data = pd.DataFrame(columns=table_cols, data=list(zip(*flattened_table_data))) + dataset_prefixes = set(col.split('@@')[0] for col in cherrypicks_table_data.columns) for dataset_prefix in dataset_prefixes: - dataset_columns = [col for col in wandb_data.columns if col.startswith(dataset_prefix)] + dataset_columns = [col for col in cherrypicks_table_data.columns if col.startswith(dataset_prefix)] table = { - f'cherry_pick_table_{dataset_prefix}_{state.global_step}': wandb.Table( - dataframe=wandb_data[dataset_columns] - ) + f'cherry_pick_table_{dataset_prefix}_{state.global_step}': cherrypicks_table_data[dataset_columns] } self.call_event('on_log', args, state, control, logs=table) diff --git a/turbo_alignment/common/tf/callbacks/logging.py b/turbo_alignment/common/tf/callbacks/logging.py new file mode 100755 index 0000000..1ad484d --- /dev/null +++ b/turbo_alignment/common/tf/callbacks/logging.py @@ -0,0 +1,118 @@ +from abc import ABC, abstractmethod +from typing import Any + +import numpy as np +import pandas as pd +from clearml import Task +from transformers import ( + TrainerCallback, + TrainerControl, + TrainerState, + TrainingArguments, +) +from wandb.sdk.lib.disabled import RunDisabled +from wandb.sdk.wandb_run import Run + +import wandb +from turbo_alignment.common.logging import get_project_logger + +logger = get_project_logger() + + +class LoggingCallback(TrainerCallback, ABC): + def on_log( + self, + _args: TrainingArguments, + state: TrainerState, + _control: TrainerControl, + **kwargs, + ) -> None: + logs = kwargs.get('logs', {}) + self._log(logs=logs, state=state) + + @abstractmethod + def _log(self, logs: dict[str, Any], state: TrainerState) -> None: + ... 
+ + @staticmethod + def _rewrite_logs(logs: dict[str, Any]) -> dict[str, Any]: + # map known prefixes onto '<section>/' namespaces; unprefixed keys default to 'train/' + rewritten_logs = {} + eval_prefix = 'eval_' + eval_prefix_len = len(eval_prefix) + test_prefix = 'test_' + test_prefix_len = len(test_prefix) + train_prefix = 'train_' + train_prefix_len = len(train_prefix) + cherry_pick_prefix = 'cherry_pick_' + cherry_pick_prefix_len = len(cherry_pick_prefix) + for k, v in logs.items(): + if k.startswith(eval_prefix): + rewritten_logs['eval/' + k[eval_prefix_len:]] = v + elif k.startswith(test_prefix): + rewritten_logs['test/' + k[test_prefix_len:]] = v + elif k.startswith(train_prefix): + rewritten_logs['train/' + k[train_prefix_len:]] = v + elif k.startswith(cherry_pick_prefix): + rewritten_logs['cherry_pick/' + k[cherry_pick_prefix_len:]] = v + else: + rewritten_logs['train/' + k] = v + + return rewritten_logs + + +class WandbLoggingCallback(LoggingCallback): + def __init__(self, wandb_run: Run | RunDisabled) -> None: + super().__init__() + self._wandb_run = wandb_run + + def _log(self, logs: dict[str, Any], state: TrainerState) -> None: + rewritten_logs: dict[str, Any] = self._rewrite_logs(logs) + # tables are wandb-specific, so DataFrames are converted here rather than in the shared _rewrite_logs + rewritten_logs = { + k: wandb.Table(dataframe=v) if isinstance(v, pd.DataFrame) else v for k, v in rewritten_logs.items() + } + self._wandb_run.log({**rewritten_logs, 'train/global_step': state.global_step}, step=state.global_step) + + +class ClearMLLoggingCallback(LoggingCallback): + def __init__(self, clearml_task: Task) -> None: + super().__init__() + self._clearml_task = clearml_task + + def _log(self, logs: dict[str, Any], state: TrainerState) -> None: + # these keys arrive once at the end of training and are reported as run-level single values + single_value_scalars: list[str] = [ + 'train_runtime', + 'train_samples_per_second', + 'train_steps_per_second', + 'train_loss', + 'epoch', + ] + for k in single_value_scalars: + if k in logs: + self._clearml_task.get_logger().report_single_value(name=k, value=logs[k]) + + # match against the raw keys before _rewrite_logs renames them (e.g. 'train_loss' -> 'train/loss') + rewritten_logs: dict[str, Any] = self._rewrite_logs( + {k: v for k, v in logs.items() if k not in single_value_scalars} + ) + + for k, v in rewritten_logs.items(): + title, series = k.split('/')[0], '/'.join(k.split('/')[1:]) + + if isinstance(v, (int, float, np.floating, np.integer)): + self._clearml_task.get_logger().report_scalar( + title=title, + series=series, + value=v, + iteration=state.global_step, + ) + elif isinstance(v, pd.DataFrame): + self._clearml_task.get_logger().report_table( + title=title, + series=series, + table_plot=v, + iteration=state.global_step, + ) + else: + logger.warning( + 'Trainer is attempting to log a value of ' + f'"{v}" of type {type(v)} for key "{k}". ' + "This invocation of ClearML logger's function " + 'is incorrect so this attribute was dropped. 
' + ) diff --git a/turbo_alignment/pipelines/mixin/__init__.py b/turbo_alignment/pipelines/mixin/__init__.py index f3ec2fe..bf25cc6 100755 --- a/turbo_alignment/pipelines/mixin/__init__.py +++ b/turbo_alignment/pipelines/mixin/__init__.py @@ -1,2 +1,2 @@ -from .logging import LoggingMixin +from .logging import LoggingRegistry from .s3 import S3Mixin diff --git a/turbo_alignment/pipelines/mixin/logging.py b/turbo_alignment/pipelines/mixin/logging.py index e793374..86ec367 100755 --- a/turbo_alignment/pipelines/mixin/logging.py +++ b/turbo_alignment/pipelines/mixin/logging.py @@ -1,13 +1,43 @@ -from abc import ABC +from abc import ABC, abstractmethod -from wandb.sdk.lib import RunDisabled -from wandb.sdk.wandb_run import Run +from allenai_common import Registrable +from turbo_alignment.common.logging.clearml import create_clearml_task from turbo_alignment.common.logging.weights_and_biases import create_wandb_run +from turbo_alignment.common.tf.callbacks.logging import ( + ClearMLLoggingCallback, + LoggingCallback, + WandbLoggingCallback, +) +from turbo_alignment.settings.logging.clearml import ClearMLSettings +from turbo_alignment.settings.logging.weights_and_biases import WandbSettings from turbo_alignment.settings.pipelines.train.base import BaseTrainExperimentSettings +class LoggingRegistry(Registrable): + ... + + class LoggingMixin(ABC): @staticmethod - def _get_wandb_run(experiment_settings: BaseTrainExperimentSettings) -> Run | RunDisabled: - return create_wandb_run(parameters=experiment_settings.wandb_settings, config=experiment_settings.dict()) + @abstractmethod + def get_logging_callback(experiment_settings: BaseTrainExperimentSettings) -> LoggingCallback: + ... + + +@LoggingRegistry.register(WandbSettings.__name__) +class WandbLoggingMixin(LoggingMixin): + @staticmethod + def get_logging_callback(experiment_settings: BaseTrainExperimentSettings) -> WandbLoggingCallback: + logging_settings: WandbSettings = WandbSettings(**experiment_settings.logging_settings.dict()) + wandb_run = create_wandb_run(parameters=logging_settings, config=experiment_settings.dict()) + return WandbLoggingCallback(wandb_run=wandb_run) + + +@LoggingRegistry.register(ClearMLSettings.__name__) +class ClearMLLoggingMixin(LoggingMixin): + @staticmethod + def get_logging_callback(experiment_settings: BaseTrainExperimentSettings) -> ClearMLLoggingCallback: + logging_settings: ClearMLSettings = ClearMLSettings(**experiment_settings.logging_settings.dict()) + clearml_task = create_clearml_task(parameters=logging_settings, config=experiment_settings.dict()) + return ClearMLLoggingCallback(clearml_task=clearml_task) diff --git a/turbo_alignment/pipelines/train/base.py b/turbo_alignment/pipelines/train/base.py index 57e4a20..3f1f1be 100755 --- a/turbo_alignment/pipelines/train/base.py +++ b/turbo_alignment/pipelines/train/base.py @@ -11,19 +11,16 @@ Trainer, TrainingArguments, ) -from wandb.sdk.lib import RunDisabled -from wandb.sdk.wandb_run import Run from turbo_alignment.cherry_picks.base import CherryPickCallbackBase from turbo_alignment.common.data.io import write_json from turbo_alignment.common.logging import get_project_logger -from turbo_alignment.common.tf.callbacks import BaseWandbCallback from turbo_alignment.common.tf.loaders.model import load_model from turbo_alignment.common.tf.loaders.tokenizer import load_tokenizer from turbo_alignment.common.tf.special_tokens_setter import SpecialTokensSetter from turbo_alignment.dataset.loader import DatasetLoader from turbo_alignment.pipelines.base import BaseStrategy
-from turbo_alignment.pipelines.mixin import LoggingMixin, S3Mixin +from turbo_alignment.pipelines.mixin import LoggingRegistry, S3Mixin from turbo_alignment.settings.datasets.base import DatasetStrategy from turbo_alignment.settings.pipelines.train.base import BaseTrainExperimentSettings from turbo_alignment.settings.s3 import ExperimentMetadata, S3HandlerParameters @@ -34,15 +31,11 @@ ExperimentSettingsT = TypeVar('ExperimentSettingsT', bound=BaseTrainExperimentSettings) -class BaseTrainStrategy(S3Mixin, LoggingMixin, BaseStrategy, Generic[ExperimentSettingsT]): +class BaseTrainStrategy(S3Mixin, BaseStrategy, Generic[ExperimentSettingsT]): tokenizer: PreTrainedTokenizerBase model: PreTrainedModel trainer: Trainer - @staticmethod - def _get_wandb_callback(wandb_run: Run | RunDisabled) -> BaseWandbCallback: - return BaseWandbCallback(wandb_run=wandb_run) - @staticmethod @abstractmethod def _get_cherry_pick_callback( @@ -112,8 +105,12 @@ def _get_additional_special_tokens( def _add_trainer_callbacks(self, experiment_settings: ExperimentSettingsT, **kwargs) -> None: if self.trainer.accelerator.is_main_process: - if experiment_settings.wandb_settings: - self.trainer.add_callback(self._get_wandb_callback(wandb_run=self._get_wandb_run(experiment_settings))) + if experiment_settings.logging_settings:  # may be null, e.g. configs/exp/train/multimodal/mlp.json + self.trainer.add_callback( + LoggingRegistry.by_name(experiment_settings.logging_settings.__name__).get_logging_callback( + experiment_settings=experiment_settings + ) + ) cherry_pick_callback = self._get_cherry_pick_callback(experiment_settings, self.tokenizer, **kwargs) diff --git a/turbo_alignment/settings/logging/clearml.py b/turbo_alignment/settings/logging/clearml.py new file mode 100644 index 0000000..2f2ffc7 --- /dev/null +++ b/turbo_alignment/settings/logging/clearml.py @@ -0,0 +1,12 @@ +from turbo_alignment.settings.base import ExtraFieldsNotAllowedBaseModel + + +class ClearMLSettings(ExtraFieldsNotAllowedBaseModel): + project_name: str + task_name: str + tags: list[str] = [] + + __name__ = 'ClearMLSettings' + + class Config: + env_prefix: str = 'CLEARML_' diff --git a/turbo_alignment/settings/weights_and_biases.py b/turbo_alignment/settings/logging/weights_and_biases.py similarity index 68% rename from turbo_alignment/settings/weights_and_biases.py rename to turbo_alignment/settings/logging/weights_and_biases.py index d905832..94821dc 100755 --- a/turbo_alignment/settings/weights_and_biases.py +++ b/turbo_alignment/settings/logging/weights_and_biases.py @@ -1,6 +1,6 @@ from enum import Enum -from pydantic_settings import BaseSettings +from turbo_alignment.settings.base import ExtraFieldsNotAllowedBaseModel class WandbMode(str, Enum): @@ -9,13 +9,15 @@ class WandbMode(str, Enum): DISABLED: str = 'disabled' -class WandbSettings(BaseSettings): +class WandbSettings(ExtraFieldsNotAllowedBaseModel): project_name: str run_name: str entity: str - notes: str | None = None tags: list[str] = [] + notes: str | None = None mode: WandbMode = WandbMode.ONLINE + __name__ = 'WandbSettings' + class Config: env_prefix: str = 'WANDB_' diff --git a/turbo_alignment/settings/pipelines/train/base.py b/turbo_alignment/settings/pipelines/train/base.py index 7e2173e..d31242d 100755 --- a/turbo_alignment/settings/pipelines/train/base.py +++ b/turbo_alignment/settings/pipelines/train/base.py @@ -5,6 +5,8 @@ from turbo_alignment.common import set_random_seed from turbo_alignment.settings.cherry_pick import CherryPickSettings from turbo_alignment.settings.datasets.base import MultiDatasetSettings +from turbo_alignment.settings.logging.clearml import
ClearMLSettings +from turbo_alignment.settings.logging.weights_and_biases import WandbSettings from turbo_alignment.settings.model import ( ModelForPeftSettings, PreTrainedAdaptersModelSettings, @@ -14,7 +16,6 @@ from turbo_alignment.settings.tf.special_tokens_setter import SpecialTokensSettings from turbo_alignment.settings.tf.tokenizer import TokenizerSettings from turbo_alignment.settings.tf.trainer import TrainerSettings -from turbo_alignment.settings.weights_and_biases import WandbSettings class EarlyStoppingSettings(BaseSettings): @@ -36,7 +37,7 @@ class BaseTrainExperimentSettings(BaseSettings): train_dataset_settings: MultiDatasetSettings val_dataset_settings: MultiDatasetSettings - wandb_settings: WandbSettings + logging_settings: WandbSettings | ClearMLSettings | None = None checkpoint_uploader_callback_parameters: CheckpointUploaderCallbackParameters | None = None cherry_pick_settings: CherryPickSettings | None = None diff --git a/turbo_alignment/trainers/ddpo.py b/turbo_alignment/trainers/ddpo.py index 9d84294..b32b6a8 100755 --- a/turbo_alignment/trainers/ddpo.py +++ b/turbo_alignment/trainers/ddpo.py @@ -18,7 +18,7 @@ from transformers.integrations import get_reporting_integration_callbacks from turbo_alignment.common.logging import get_project_logger -from turbo_alignment.common.tf.callbacks.common import WandbMetricsCallbackHandler +from turbo_alignment.common.tf.callbacks.common import MetricsCallbackHandler from turbo_alignment.trainers.dpo import DPOTrainer from turbo_alignment.trainers.utils import concatenated_inputs, prepare_model @@ -74,7 +74,7 @@ def __init__( default_callbacks = [DefaultFlowCallback] + get_reporting_integration_callbacks(self.args.report_to) callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks - self.callback_handler = WandbMetricsCallbackHandler( + self.callback_handler = MetricsCallbackHandler( callbacks, model, tokenizer, None, None, ref_model=ref_model, accelerator=self.accelerator ) self.add_callback(PrinterCallback if self.args.disable_tqdm else ProgressCallback) diff --git a/turbo_alignment/trainers/dpo.py b/turbo_alignment/trainers/dpo.py index 0ce83db..47d6beb 100755 --- a/turbo_alignment/trainers/dpo.py +++ b/turbo_alignment/trainers/dpo.py @@ -20,7 +20,7 @@ from transformers.integrations import get_reporting_integration_callbacks from turbo_alignment.common.logging import get_project_logger -from turbo_alignment.common.tf.callbacks.common import WandbMetricsCallbackHandler +from turbo_alignment.common.tf.callbacks.common import MetricsCallbackHandler from turbo_alignment.common.tf.callbacks.sync_ref_model import SyncRefModelCallback from turbo_alignment.constants import DISABLE_LOSS_LABEL from turbo_alignment.settings.pipelines.train.dpo import ( @@ -416,7 +416,7 @@ def __init__( default_callbacks = [DefaultFlowCallback] + get_reporting_integration_callbacks(self.args.report_to) callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks - self.callback_handler = WandbMetricsCallbackHandler( + self.callback_handler = MetricsCallbackHandler( callbacks, model, tokenizer, diff --git a/turbo_alignment/trainers/kto.py b/turbo_alignment/trainers/kto.py index e3e7c6e..c6df560 100755 --- a/turbo_alignment/trainers/kto.py +++ b/turbo_alignment/trainers/kto.py @@ -19,7 +19,7 @@ from transformers.integrations import get_reporting_integration_callbacks from turbo_alignment.common.logging import get_project_logger -from turbo_alignment.common.tf.callbacks.common import WandbMetricsCallbackHandler +from
turbo_alignment.common.tf.callbacks.common import MetricsCallbackHandler from turbo_alignment.settings.pipelines.train.dpo import SyncRefModelSettings from turbo_alignment.trainers.dpo import DPOTrainer from turbo_alignment.trainers.utils import prepare_model @@ -84,7 +84,7 @@ def __init__( default_callbacks = [DefaultFlowCallback] + get_reporting_integration_callbacks(self.args.report_to) callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks - self.callback_handler = WandbMetricsCallbackHandler( + self.callback_handler = MetricsCallbackHandler( callbacks, model, tokenizer, diff --git a/turbo_alignment/trainers/multigpu.py b/turbo_alignment/trainers/multigpu.py index 7c84076..7611d40 100755 --- a/turbo_alignment/trainers/multigpu.py +++ b/turbo_alignment/trainers/multigpu.py @@ -15,7 +15,7 @@ ) from transformers.integrations import get_reporting_integration_callbacks -from turbo_alignment.common.tf.callbacks.common import WandbMetricsCallbackHandler +from turbo_alignment.common.tf.callbacks.common import MetricsCallbackHandler class MultiGPUCherryPicksTrainer(Trainer): @@ -46,7 +46,7 @@ def __init__( default_callbacks = [DefaultFlowCallback] + get_reporting_integration_callbacks(self.args.report_to) callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks - self.callback_handler = WandbMetricsCallbackHandler( + self.callback_handler = MetricsCallbackHandler( callbacks, model, tokenizer, diff --git a/tutorials/dpo/dpo.json b/tutorials/dpo/dpo.json index 7948c6e..72de2f7 100755 --- a/tutorials/dpo/dpo.json +++ b/tutorials/dpo/dpo.json @@ -132,7 +132,7 @@ "use_sft_model": false, "use_ref_model": true }, - "wandb_settings": { + "logging_settings": { "project_name": "alignment", "run_name": "dpo", "entity": "turbo-alignment" diff --git a/tutorials/kto/kto.json b/tutorials/kto/kto.json index 9457b17..73de14b 100644 --- a/tutorials/kto/kto.json +++ b/tutorials/kto/kto.json @@ -127,7 +127,7 @@ "undesirable_weight": 1.0, "use_ref_model": true }, - "wandb_settings": { + "logging_settings": { "project_name": "alignment", "run_name": "kto", "entity": "turbo-alignment" diff --git a/tutorials/multimodal/multimodal.json b/tutorials/multimodal/multimodal.json index 12037e7..54b7948 100644 --- a/tutorials/multimodal/multimodal.json +++ b/tutorials/multimodal/multimodal.json @@ -125,7 +125,7 @@ "optim": "adamw_torch", "save_total_limit": 1 }, - "wandb_settings": { + "logging_settings": { "project_name": "alignment", "run_name": "multimodal", "entity": "turbo-alignment" diff --git a/tutorials/rm/rm.json b/tutorials/rm/rm.json index d552358..aab07d7 100755 --- a/tutorials/rm/rm.json +++ b/tutorials/rm/rm.json @@ -110,7 +110,7 @@ "gradient_checkpointing": true, "save_only_model": true }, - "wandb_settings": { + "logging_settings": { "project_name": "alignment", "run_name": "rm", "entity": "turbo-alignment" diff --git a/tutorials/sft/sft.json b/tutorials/sft/sft.json index 674b658..fd99470 100755 --- a/tutorials/sft/sft.json +++ b/tutorials/sft/sft.json @@ -111,7 +111,7 @@ "weight_decay": 0.01, "max_grad_norm": 0.11 }, - "wandb_settings": { + "logging_settings": { "project_name": "alignment", "run_name": "sft", "entity": "turbo-alignment"
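
Usage note: every config in this patch shows the Wandb flavour of the renamed "logging_settings" block (project_name / run_name / entity). As a minimal sketch of the ClearML flavour, assuming only the fields declared on the ClearMLSettings model added above (the values here are illustrative, not taken from any real config):

    "logging_settings": {
        "project_name": "alignment",
        "task_name": "sft",
        "tags": []
    }

The backend is selected by this block alone: BaseTrainExperimentSettings validates it against the WandbSettings | ClearMLSettings union, and BaseTrainStrategy._add_trainer_callbacks then dispatches on the resulting settings class name via LoggingRegistry.by_name(...) to attach either WandbLoggingCallback or ClearMLLoggingCallback.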