Skip to content

Commit

Permalink
add clerml logging
Browse files Browse the repository at this point in the history
  • Loading branch information
Малахов Алексей Павлович committed Sep 12, 2024
1 parent 184e19c commit 8884e83
Show file tree
Hide file tree
Showing 43 changed files with 248 additions and 78 deletions.
2 changes: 1 addition & 1 deletion configs/exp/kto/kto.json
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@
"use_ref_model": true,
"deepspeed": "configs/exp/deepspeed/stage3.json"
},
"wandb_settings": {
"logging_settings": {
"project_name": "alignment",
"run_name": "sft",
"entity": "turbo-alignment"
Expand Down
2 changes: 1 addition & 1 deletion configs/exp/train/classification/classification.json
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@
},
"save_total_limit": 10
},
"wandb_settings": {
"logging_settings": {
"project_name": "alignment",
"run_name": "classification",
"entity": "turbo-alignment"
Expand Down
2 changes: 1 addition & 1 deletion configs/exp/train/dpo/dpo.json
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@
"use_ref_model": true,
"deepspeed": "configs/exp/deepspeed/stage2.json"
},
"wandb_settings": {
"logging_settings": {
"project_name": "alignment",
"run_name": "sft",
"entity": "turbo-alignment"
Expand Down
2 changes: 1 addition & 1 deletion configs/exp/train/multimodal/c_abs.json
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@
"load_best_model_at_end": false,
"deepspeed": "configs/exp/deepspeed/ds_config_stage_2.json"
},
"wandb_settings": {
"logging_settings": {
"project_name": "alignment",
"run_name": "multimodal",
"entity": "turbo-alignment"
Expand Down
2 changes: 1 addition & 1 deletion configs/exp/train/multimodal/mlp.json
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@
"load_best_model_at_end": false,
"deepspeed": "configs/exp/deepspeed/ds_config_stage_2.json"
},
"wandb_settings": null,
"logging_settings": null,
"log_path": "train_output",
"modality_encoder_settings_mapping": {
"image": {
Expand Down
2 changes: 1 addition & 1 deletion configs/exp/train/rag/end2end_rag.json
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@
"max_grad_norm": 0.11,
"save_total_limit": 3
},
"wandb_settings": {
"logging_settings": {
"project_name": "alignment",
"run_name": "rag",
"entity": "biglm"
Expand Down
2 changes: 1 addition & 1 deletion configs/exp/train/rm/rm.json
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@
"save_total_limit": 1,
"deepspeed": "configs/exp/deepspeed/stage2.json"
},
"wandb_settings": {
"logging_settings": {
"project_name": "alignment",
"run_name": "rm",
"entity": "turbo-alignment"
Expand Down
2 changes: 1 addition & 1 deletion configs/exp/train/sft/sft.json
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@
"weight_decay": 0.01,
"max_grad_norm": 0.11
},
"wandb_settings": {
"logging_settings": {
"project_name": "alignment",
"run_name": "sft",
"entity": "turbo-alignment"
Expand Down
2 changes: 1 addition & 1 deletion tests/fixtures/configs/train/classification/base.json
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@
},
"no_cuda": true
},
"wandb_settings": {
"logging_settings": {
"project_name": "alignment",
"run_name": "classification",
"entity": "turbo-alignment"
Expand Down
2 changes: 1 addition & 1 deletion tests/fixtures/configs/train/ddpo/base.json
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@
"no_cuda": true,
"max_grad_norm": 1.0
},
"wandb_settings": {
"logging_settings": {
"project_name": "alignment",
"run_name": "ddpo",
"entity": "turbo-alignment"
Expand Down
2 changes: 1 addition & 1 deletion tests/fixtures/configs/train/dpo/base.json
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@
"use_sft_model": true,
"no_cuda": true
},
"wandb_settings": {
"logging_settings": {
"project_name": "alignment",
"run_name": "dpo",
"entity": "turbo-alignment"
Expand Down
2 changes: 1 addition & 1 deletion tests/fixtures/configs/train/dpo/simpo.json
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@
"use_sft_model": true,
"no_cuda": true
},
"wandb_settings": {
"logging_settings": {
"project_name": "alignment",
"run_name": "dpo",
"entity": "turbo-alignment"
Expand Down
2 changes: 1 addition & 1 deletion tests/fixtures/configs/train/kto/base.json
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@
"save_total_limit": 1,
"no_cuda": true
},
"wandb_settings": {
"logging_settings": {
"project_name": "alignment",
"run_name": "kto",
"entity": "turbo-alignment"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@
"save_total_limit": 1,
"no_cuda": true
},
"wandb_settings": {
"logging_settings": {
"project_name": "alignment",
"run_name": "multimodal",
"entity": "turbo-alignment"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@
"save_total_limit": 1,
"no_cuda": true
},
"wandb_settings": {
"logging_settings": {
"project_name": "alignment",
"run_name": "multimodal",
"entity": "turbo-alignment"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@
"save_total_limit": 1,
"no_cuda": true
},
"wandb_settings": {
"logging_settings": {
"project_name": "alignment",
"run_name": "multimodal",
"entity": "turbo-alignment"
Expand Down
2 changes: 1 addition & 1 deletion tests/fixtures/configs/train/rag/base.json
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@
"no_cuda": true,
"save_total_limit": 1
},
"wandb_settings": {
"logging_settings": {
"project_name": "alignment",
"run_name": "rag",
"entity": "turbo-alignment"
Expand Down
2 changes: 1 addition & 1 deletion tests/fixtures/configs/train/rm/base.json
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@
"save_total_limit": 1,
"no_cuda": true
},
"wandb_settings": {
"logging_settings": {
"project_name": "alignment",
"run_name": "rm",
"entity": "turbo-alignment"
Expand Down
2 changes: 1 addition & 1 deletion tests/fixtures/configs/train/sft/base.json
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@
"save_total_limit": 1,
"no_cuda": true
},
"wandb_settings": {
"logging_settings": {
"project_name": "alignment",
"run_name": "sft",
"entity": "turbo-alignment"
Expand Down
2 changes: 1 addition & 1 deletion tests/fixtures/configs/train/sft/prompt_tuning.json
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@
"save_total_limit": 1,
"no_cuda": true
},
"wandb_settings": {
"logging_settings": {
"project_name": "alignment",
"run_name": "sft",
"entity": "turbo-alignment"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@
"save_total_limit": 1,
"no_cuda": true
},
"wandb_settings": {
"logging_settings": {
"project_name": "alignment",
"run_name": "sft",
"entity": "turbo-alignment"
Expand Down
2 changes: 1 addition & 1 deletion tests/fixtures/configs/train/sft/sft_with_rm_metric.json
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@
"save_total_limit": 1,
"no_cuda": true
},
"wandb_settings": {
"logging_settings": {
"project_name": "alignment",
"run_name": "sft",
"entity": "turbo-alignment"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
},
"log_path": "train_output",
"seed": 42,
"wandb_settings": {
"logging_settings": {
"project_name": "alignment",
"run_name": "fine_tune_llama",
"entity": "vladislavkruglikov",
Expand Down
15 changes: 15 additions & 0 deletions turbo_alignment/common/logging/clearml.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from typing import Any

from clearml import Task

from turbo_alignment.settings.logging.clearml import ClearMLSettings


def create_clearml_task(parameters: ClearMLSettings, config: dict[str, Any] | None = None) -> Task:
clearml_task = Task.init(
task_name=parameters.task_name, project_name=parameters.project_name, continue_last_task=True # FIXME?
)

clearml_task.connect_configuration(config, name='HyperParameters')

return clearml_task
2 changes: 1 addition & 1 deletion turbo_alignment/common/logging/weights_and_biases.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from wandb.sdk.wandb_run import Run

import wandb
from turbo_alignment.settings.weights_and_biases import WandbSettings
from turbo_alignment.settings.logging.weights_and_biases import WandbSettings


def create_wandb_run(parameters: WandbSettings, config: dict[str, Any] | None = None) -> Run | RunDisabled:
Expand Down
2 changes: 1 addition & 1 deletion turbo_alignment/common/tf/callbacks/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from .common import EvaluateFirstStepCallback, WandbMetricsCallbackHandler
from .common import EvaluateFirstStepCallback, MetricsCallbackHandler
from .s3 import CheckpointUploaderCallback
from .weights_and_biases import BaseWandbCallback
32 changes: 14 additions & 18 deletions turbo_alignment/common/tf/callbacks/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
)
from transformers.trainer_callback import CallbackHandler

import wandb
from turbo_alignment.settings.metric import MetricResults


Expand All @@ -23,7 +22,7 @@ def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: T
return control


class WandbMetricsCallbackHandler(CallbackHandler):
class MetricsCallbackHandler(CallbackHandler):
def __init__(self, *args, **kwargs) -> None:
self.ref_model = kwargs.pop('ref_model', None)
self.sft_model = kwargs.pop('sft_model', None)
Expand All @@ -50,46 +49,43 @@ def on_evaluate(
if isinstance(results, list):
gathered_results: list[list[MetricResults]] = gather_object(results)

gathered_float_wandb_data: dict[str, list[Any]] = defaultdict(list)
gathered_table_wandb_data: dict[str, list[str] | list[list[str]]] = defaultdict(list)
gathered_float_data: dict[str, list[Any]] = defaultdict(list)
gathered_table_data: dict[str, list[str] | list[list[str]]] = defaultdict(list)
average_functions: dict[str, Callable] = {}

for single_process_results in gathered_results:
for metric_result in single_process_results:
for score in metric_result.element_wise_scores:
if metric_result.need_average:
gathered_float_wandb_data[score.label].extend(score.values)
gathered_float_data[score.label].extend(score.values)
average_functions[score.label] = score.average_function
else:
gathered_table_wandb_data[score.label].extend(score.values)
gathered_table_data[score.label].extend(score.values)

logs = {
'cherry_pick_' + k: average_functions[k](list(*zip(*v)))
if isinstance(v[0], tuple)
else average_functions[k](v)
for k, v in gathered_float_wandb_data.items()
for k, v in gathered_float_data.items()
}

self.call_event('on_log', args, state, control, logs=logs)

wandb_table_cols = list(gathered_table_wandb_data.keys())
wandb_table_data = list(gathered_table_wandb_data.values())
table_cols = list(gathered_table_data.keys())
table_data = list(gathered_table_data.values())

flattened_wandb_data = [
sum(item, []) if isinstance(item, list) and isinstance(item[0], list) else item
for item in wandb_table_data
flattened_table_data = [
sum(item, []) if isinstance(item, list) and isinstance(item[0], list) else item for item in table_data
] # flatten list[lists] to display all outputs in wandb table

wandb_data = pd.DataFrame(columns=wandb_table_cols, data=list(zip(*flattened_wandb_data)))
dataset_prefixes = set(col.split('@@')[0] for col in wandb_data.columns)
cherrypicks_table_data = pd.DataFrame(columns=table_cols, data=list(zip(*flattened_table_data)))
dataset_prefixes = set(col.split('@@')[0] for col in cherrypicks_table_data.columns)

for dataset_prefix in dataset_prefixes:
dataset_columns = [col for col in wandb_data.columns if col.startswith(dataset_prefix)]
dataset_columns = [col for col in cherrypicks_table_data.columns if col.startswith(dataset_prefix)]

table = {
f'cherry_pick_table_{dataset_prefix}_{state.global_step}': wandb.Table(
dataframe=wandb_data[dataset_columns]
)
f'cherry_pick_table_{dataset_prefix}_{state.global_step}': cherrypicks_table_data[dataset_columns]
}

self.call_event('on_log', args, state, control, logs=table)
Expand Down
Loading

0 comments on commit 8884e83

Please sign in to comment.