
Commit

Merge pull request #83 from mmcdermott/more_logging
Added logging to other aspects of ESGPT.
mmcdermott authored Dec 3, 2023
2 parents 273e09a + b34b3fe commit ccf9da9
Showing 5 changed files with 34 additions and 25 deletions.
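Every file in this commit follows the same pattern: `print` calls are replaced with loguru's global `logger`, using a level that matches the message (`info` for progress, `warning` for recoverable configuration issues, `error` for caught exceptions), plus a `from loguru import logger` import at the top of each changed module. A minimal, runnable sketch of that pattern follows; the function and messages are illustrative only, not taken from ESGPT.

from pathlib import Path

from loguru import logger  # the import added at the top of each changed module


def save_config(save_dir: Path, do_overwrite: bool = False) -> None:
    # Illustrative only: shows the print -> logger.{info,warning} substitution style.
    config_fp = save_dir / "config.json"
    if config_fp.exists() and not do_overwrite:
        # before: print(f"{config_fp} already exists ...")
        logger.warning(f"{config_fp} already exists. To overwrite, set `do_overwrite=True`.")
        return
    save_dir.mkdir(parents=True, exist_ok=True)
    # before: print(f"Writing to {config_fp}")
    logger.info(f"Writing to {config_fp}")
    config_fp.write_text("{}")
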
7 changes: 5 additions & 2 deletions EventStream/transformer/lightning_modules/embedding.py
@@ -6,6 +6,7 @@

import lightning as L
import torch
from loguru import logger

from ...data.pytorch_dataset import PytorchDataset
from ..config import StructuredEventProcessingMode, StructuredTransformerConfig
@@ -153,8 +154,10 @@ def get_embeddings(cfg: FinetuneConfig):

if os.environ.get("LOCAL_RANK", "0") == "0":
if embeddings_fp.is_file() and not cfg.do_overwrite:
print(f"Embeddings already exist at {embeddings_fp}. To overwrite, set `do_overwrite=True`.")
logger.info(
f"Embeddings already exist at {embeddings_fp}. To overwrite, set `do_overwrite=True`."
)
else:
print(f"Saving {sp} embeddings to {embeddings_fp}.")
logger.info(f"Saving {sp} embeddings to {embeddings_fp}.")
embeddings_fp.parent.mkdir(exist_ok=True, parents=True)
torch.save(embeddings, embeddings_fp)
31 changes: 17 additions & 14 deletions EventStream/transformer/lightning_modules/fine_tuning.py
@@ -14,6 +14,7 @@
from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.loggers import WandbLogger
from loguru import logger
from omegaconf import OmegaConf
from torchmetrics.classification import (
BinaryAccuracy,
@@ -183,7 +184,7 @@ def _log_metric_dict(
metric(preds, labels.long())
self.log(f"{prefix}_{metric_name}", metric)
except (ValueError, IndexError) as e:
print(
logger.error(
f"Failed to compute {metric_name} "
f"with preds ({str_summary(preds)}) and labels ({str_summary(labels)}): {e}."
)
@@ -396,13 +397,13 @@ def __post_init__(self):
and self.data_config.get("train_subset_seed", None) is None
):
self.data_config["train_subset_seed"] = int(random.randint(1, int(1e6)))
print(
f"WARNING: train_subset_size={self.data_config.train_subset_size} but "
logger.warning(
f"train_subset_size={self.data_config.train_subset_size} but "
f"seed is unset. Setting to {self.data_config['train_subset_seed']}"
)

data_config_fp = self.load_from_model_dir / "data_config.json"
print(f"Loading data_config from {data_config_fp}")
logger.info(f"Loading data_config from {data_config_fp}")
reloaded_data_config = PytorchDatasetConfig.from_json_file(data_config_fp)
reloaded_data_config.task_df_name = self.task_df_name

@@ -411,31 +412,33 @@ def __post_init__(self):
continue
if param == "task_df_name":
if val != self.task_df_name:
print(
f"WARNING: task_df_name is set in data_config_overrides to {val}! "
logger.warning(
f"task_df_name is set in data_config_overrides to {val}! "
f"Original is {self.task_df_name}. Ignoring data_config..."
)
continue
print(f"Overwriting {param} in data_config from {getattr(reloaded_data_config, param)} to {val}")
logger.info(
f"Overwriting {param} in data_config from {getattr(reloaded_data_config, param)} to {val}"
)
setattr(reloaded_data_config, param, val)

self.data_config = reloaded_data_config

config_fp = self.load_from_model_dir / "config.json"
print(f"Loading config from {config_fp}")
logger.info(f"Loading config from {config_fp}")
reloaded_config = StructuredTransformerConfig.from_json_file(config_fp)

for param, val in self.config.items():
if val is None:
continue
print(f"Overwriting {param} in config from {getattr(reloaded_config, param)} to {val}")
logger.info(f"Overwriting {param} in config from {getattr(reloaded_config, param)} to {val}")
setattr(reloaded_config, param, val)

self.config = reloaded_config

reloaded_pretrain_config = OmegaConf.load(self.load_from_model_dir / "pretrain_config.yaml")
if self.wandb_logger_kwargs.get("project", None) is None:
print(f"Setting wandb project to {reloaded_pretrain_config.wandb_logger_kwargs.project}")
logger.info(f"Setting wandb project to {reloaded_pretrain_config.wandb_logger_kwargs.project}")
self.wandb_logger_kwargs["project"] = reloaded_pretrain_config.wandb_logger_kwargs.project


@@ -464,12 +467,12 @@ def train(cfg: FinetuneConfig):

if os.environ.get("LOCAL_RANK", "0") == "0":
cfg.save_dir.mkdir(parents=True, exist_ok=True)
print("Saving config files...")
logger.info("Saving config files...")
config_fp = cfg.save_dir / "config.json"
if config_fp.exists() and not cfg.do_overwrite:
raise FileExistsError(f"{config_fp} already exists!")
else:
print(f"Writing to {config_fp}")
logger.info(f"Writing to {config_fp}")
config.to_json_file(config_fp)

data_config.to_json_file(cfg.save_dir / "data_config.json", do_overwrite=cfg.do_overwrite)
@@ -486,7 +489,7 @@ def train(cfg: FinetuneConfig):

# TODO(mmd): Get this working!
# if cfg.compile:
# print("Compiling model!")
# logger.info("Compiling model!")
# LM = torch.compile(LM)

# Setting up torch dataloader
@@ -573,7 +576,7 @@ def train(cfg: FinetuneConfig):
held_out_metrics = trainer.test(model=LM, dataloaders=held_out_dataloader, ckpt_path="best")

if os.environ.get("LOCAL_RANK", "0") == "0":
print("Saving final metrics...")
logger.info("Saving final metrics...")

with open(cfg.save_dir / "tuning_metrics.json", mode="w") as f:
json.dump(tuning_metrics, f)
11 changes: 6 additions & 5 deletions EventStream/transformer/lightning_modules/generative_modeling.py
@@ -12,6 +12,7 @@
from lightning.pytorch.callbacks import LearningRateMonitor
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.loggers import WandbLogger
from loguru import logger
from torchmetrics.classification import (
MulticlassAccuracy,
MulticlassAUROC,
@@ -279,7 +280,7 @@ def _log_metric_dict(
sync_dist=True,
)
except (ValueError, IndexError) as e:
print(
logger.error(
f"Failed to compute {metric_name} for {measurement} "
f"with preds ({str_summary(preds)}) and labels ({str_summary(labels)}): {e}."
)
@@ -590,12 +591,12 @@ def train(cfg: PretrainConfig):
if os.environ.get("LOCAL_RANK", "0") == "0":
cfg.save_dir.mkdir(parents=True, exist_ok=True)

print("Saving config files...")
logger.info("Saving config files...")
config_fp = cfg.save_dir / "config.json"
if config_fp.exists() and not cfg.do_overwrite:
raise FileExistsError(f"{config_fp} already exists!")
else:
print(f"Writing to {config_fp}")
logger.info(f"Writing to {config_fp}")
config.to_json_file(config_fp)

data_config.to_json_file(cfg.save_dir / "data_config.json", do_overwrite=cfg.do_overwrite)
@@ -618,7 +619,7 @@

# TODO(mmd): Get this working!
# if cfg.compile:
# print("Compiling model!")
# logger.info("Compiling model!")
# LM = torch.compile(LM)

# Setting up torch dataloader
@@ -700,7 +701,7 @@ def train(cfg: PretrainConfig):
held_out_metrics = trainer.test(model=LM, dataloaders=held_out_dataloader)

if os.environ.get("LOCAL_RANK", "0") == "0":
print("Saving final metrics...")
logger.info("Saving final metrics...")

with open(cfg.save_dir / "tuning_metrics.json", mode="w") as f:
json.dump(tuning_metrics, f)
5 changes: 3 additions & 2 deletions EventStream/transformer/lightning_modules/zero_shot_evaluator.py
@@ -10,6 +10,7 @@
import torch.multiprocessing
import torchmetrics
from lightning.pytorch.loggers import WandbLogger
from loguru import logger
from torchmetrics.classification import (
BinaryAccuracy,
BinaryAUROC,
@@ -168,7 +169,7 @@ def _log_metric_dict(
metric(preds, labels)
self.log(f"{prefix}_{metric_name}", metric)
except (ValueError, IndexError) as e:
print(
logger.error(
f"Failed to compute {metric_name} "
f"with preds ({str_summary(preds)}) and labels ({str_summary(labels)}): {e}."
)
@@ -380,7 +381,7 @@ def zero_shot_evaluation(cfg: FinetuneConfig):
held_out_metrics = trainer.test(model=LM, dataloaders=held_out_dataloader)

if os.environ.get("LOCAL_RANK", "0") == "0":
print("Saving final metrics...")
logger.info("Saving final metrics...")
cfg.save_dir.mkdir(parents=True, exist_ok=True)

with open(cfg.save_dir / "zero_shot_tuning_metrics.json", mode="w") as f:
5 changes: 3 additions & 2 deletions EventStream/utils.py
@@ -17,6 +17,7 @@

import hydra
import polars as pl
from loguru import logger

PROPORTION = float
COUNT_OR_PROPORTION = Union[int, PROPORTION]
@@ -380,8 +381,8 @@ def wrap(*args, **kwargs):
# some hyperparameter combinations might be invalid or cause out-of-memory errors
# so when using hparam search plugins like Optuna, you might want to disable
# raising the below exception to avoid multirun failure
print(f"EXCEPTION: {ex}")
print(traceback.print_exc(), file=sys.stderr)
logger.error(f"EXCEPTION: {ex}")
logger.error(traceback.print_exc(), file=sys.stderr)
raise ex
finally:
# always close wandb run (even if exception occurs so multirun won't fail)
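As a point of comparison for the hunk above, here is a minimal sketch (assumed names, not from the repository) of an exception-logging wrapper built on loguru. It uses `traceback.format_exc()`, which returns the formatted traceback as a string suitable for a log message, whereas `traceback.print_exc()` writes to stderr and returns None; inside an `except` block, loguru's `logger.exception(...)` would also attach the traceback automatically.

import functools
import traceback

from loguru import logger


def log_exceptions(fn):
    # Re-raise after logging so hparam-search plugins (e.g. Optuna) still see the failure.
    @functools.wraps(fn)
    def wrap(*args, **kwargs):
        try:
            return fn(*args, **kwargs)
        except Exception as ex:
            logger.error(f"EXCEPTION: {ex}")
            # format_exc() returns the traceback as a string; print_exc() would return None.
            logger.error(traceback.format_exc())
            raise
    return wrap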
