diff --git a/EventStream/transformer/lightning_modules/embedding.py b/EventStream/transformer/lightning_modules/embedding.py
index 6353fdb6..18aa9d23 100644
--- a/EventStream/transformer/lightning_modules/embedding.py
+++ b/EventStream/transformer/lightning_modules/embedding.py
@@ -6,6 +6,7 @@
 
 import lightning as L
 import torch
+from loguru import logger
 
 from ...data.pytorch_dataset import PytorchDataset
 from ..config import StructuredEventProcessingMode, StructuredTransformerConfig
@@ -153,8 +154,10 @@ def get_embeddings(cfg: FinetuneConfig):
 
     if os.environ.get("LOCAL_RANK", "0") == "0":
         if embeddings_fp.is_file() and not cfg.do_overwrite:
-            print(f"Embeddings already exist at {embeddings_fp}. To overwrite, set `do_overwrite=True`.")
+            logger.info(
+                f"Embeddings already exist at {embeddings_fp}. To overwrite, set `do_overwrite=True`."
+            )
         else:
-            print(f"Saving {sp} embeddings to {embeddings_fp}.")
+            logger.info(f"Saving {sp} embeddings to {embeddings_fp}.")
             embeddings_fp.parent.mkdir(exist_ok=True, parents=True)
             torch.save(embeddings, embeddings_fp)
diff --git a/EventStream/transformer/lightning_modules/fine_tuning.py b/EventStream/transformer/lightning_modules/fine_tuning.py
index bf1cae02..8bf31e65 100644
--- a/EventStream/transformer/lightning_modules/fine_tuning.py
+++ b/EventStream/transformer/lightning_modules/fine_tuning.py
@@ -14,6 +14,7 @@
 from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint
 from lightning.pytorch.callbacks.early_stopping import EarlyStopping
 from lightning.pytorch.loggers import WandbLogger
+from loguru import logger
 from omegaconf import OmegaConf
 from torchmetrics.classification import (
     BinaryAccuracy,
@@ -183,7 +184,7 @@ def _log_metric_dict(
                 metric(preds, labels.long())
                 self.log(f"{prefix}_{metric_name}", metric)
             except (ValueError, IndexError) as e:
-                print(
+                logger.error(
                     f"Failed to compute {metric_name} "
                     f"with preds ({str_summary(preds)}) and labels ({str_summary(labels)}): {e}."
                 )
@@ -396,13 +397,13 @@ def __post_init__(self):
             and self.data_config.get("train_subset_seed", None) is None
         ):
             self.data_config["train_subset_seed"] = int(random.randint(1, int(1e6)))
-            print(
-                f"WARNING: train_subset_size={self.data_config.train_subset_size} but "
+            logger.warning(
+                f"train_subset_size={self.data_config.train_subset_size} but "
                 f"seed is unset. Setting to {self.data_config['train_subset_seed']}"
            )
 
         data_config_fp = self.load_from_model_dir / "data_config.json"
-        print(f"Loading data_config from {data_config_fp}")
+        logger.info(f"Loading data_config from {data_config_fp}")
         reloaded_data_config = PytorchDatasetConfig.from_json_file(data_config_fp)
         reloaded_data_config.task_df_name = self.task_df_name
 
@@ -411,31 +412,33 @@ def __post_init__(self):
                 continue
             if param == "task_df_name":
                 if val != self.task_df_name:
-                    print(
-                        f"WARNING: task_df_name is set in data_config_overrides to {val}! "
+                    logger.warning(
+                        f"task_df_name is set in data_config_overrides to {val}! "
                         f"Original is {self.task_df_name}. Ignoring data_config..."
                     )
                 continue
-            print(f"Overwriting {param} in data_config from {getattr(reloaded_data_config, param)} to {val}")
+            logger.info(
+                f"Overwriting {param} in data_config from {getattr(reloaded_data_config, param)} to {val}"
+            )
             setattr(reloaded_data_config, param, val)
         self.data_config = reloaded_data_config
 
         config_fp = self.load_from_model_dir / "config.json"
-        print(f"Loading config from {config_fp}")
+        logger.info(f"Loading config from {config_fp}")
         reloaded_config = StructuredTransformerConfig.from_json_file(config_fp)
         for param, val in self.config.items():
             if val is None:
                 continue
-            print(f"Overwriting {param} in config from {getattr(reloaded_config, param)} to {val}")
+            logger.info(f"Overwriting {param} in config from {getattr(reloaded_config, param)} to {val}")
             setattr(reloaded_config, param, val)
         self.config = reloaded_config
 
         reloaded_pretrain_config = OmegaConf.load(self.load_from_model_dir / "pretrain_config.yaml")
         if self.wandb_logger_kwargs.get("project", None) is None:
-            print(f"Setting wandb project to {reloaded_pretrain_config.wandb_logger_kwargs.project}")
+            logger.info(f"Setting wandb project to {reloaded_pretrain_config.wandb_logger_kwargs.project}")
             self.wandb_logger_kwargs["project"] = reloaded_pretrain_config.wandb_logger_kwargs.project
 
@@ -464,12 +467,12 @@ def train(cfg: FinetuneConfig):
 
     if os.environ.get("LOCAL_RANK", "0") == "0":
         cfg.save_dir.mkdir(parents=True, exist_ok=True)
-        print("Saving config files...")
+        logger.info("Saving config files...")
         config_fp = cfg.save_dir / "config.json"
         if config_fp.exists() and not cfg.do_overwrite:
             raise FileExistsError(f"{config_fp} already exists!")
         else:
-            print(f"Writing to {config_fp}")
+            logger.info(f"Writing to {config_fp}")
             config.to_json_file(config_fp)
 
         data_config.to_json_file(cfg.save_dir / "data_config.json", do_overwrite=cfg.do_overwrite)
@@ -486,7 +489,7 @@ def train(cfg: FinetuneConfig):
 
     # TODO(mmd): Get this working!
     # if cfg.compile:
-    #     print("Compiling model!")
+    #     logger.info("Compiling model!")
     #     LM = torch.compile(LM)
 
     # Setting up torch dataloader
@@ -573,7 +576,7 @@ def train(cfg: FinetuneConfig):
     held_out_metrics = trainer.test(model=LM, dataloaders=held_out_dataloader, ckpt_path="best")
 
     if os.environ.get("LOCAL_RANK", "0") == "0":
-        print("Saving final metrics...")
+        logger.info("Saving final metrics...")
 
         with open(cfg.save_dir / "tuning_metrics.json", mode="w") as f:
             json.dump(tuning_metrics, f)
diff --git a/EventStream/transformer/lightning_modules/generative_modeling.py b/EventStream/transformer/lightning_modules/generative_modeling.py
index 4c82a8e7..3d9ff572 100644
--- a/EventStream/transformer/lightning_modules/generative_modeling.py
+++ b/EventStream/transformer/lightning_modules/generative_modeling.py
@@ -12,6 +12,7 @@
 from lightning.pytorch.callbacks import LearningRateMonitor
 from lightning.pytorch.callbacks.early_stopping import EarlyStopping
 from lightning.pytorch.loggers import WandbLogger
+from loguru import logger
 from torchmetrics.classification import (
     MulticlassAccuracy,
     MulticlassAUROC,
@@ -279,7 +280,7 @@ def _log_metric_dict(
                     sync_dist=True,
                 )
             except (ValueError, IndexError) as e:
-                print(
+                logger.error(
                     f"Failed to compute {metric_name} for {measurement} "
                     f"with preds ({str_summary(preds)}) and labels ({str_summary(labels)}): {e}."
                 )
@@ -590,12 +591,12 @@ def train(cfg: PretrainConfig):
 
     if os.environ.get("LOCAL_RANK", "0") == "0":
         cfg.save_dir.mkdir(parents=True, exist_ok=True)
-        print("Saving config files...")
+        logger.info("Saving config files...")
         config_fp = cfg.save_dir / "config.json"
         if config_fp.exists() and not cfg.do_overwrite:
             raise FileExistsError(f"{config_fp} already exists!")
         else:
-            print(f"Writing to {config_fp}")
+            logger.info(f"Writing to {config_fp}")
             config.to_json_file(config_fp)
 
         data_config.to_json_file(cfg.save_dir / "data_config.json", do_overwrite=cfg.do_overwrite)
@@ -618,7 +619,7 @@ def train(cfg: PretrainConfig):
 
     # TODO(mmd): Get this working!
    # if cfg.compile:
-    #     print("Compiling model!")
+    #     logger.info("Compiling model!")
     #     LM = torch.compile(LM)
 
     # Setting up torch dataloader
@@ -700,7 +701,7 @@ def train(cfg: PretrainConfig):
     held_out_metrics = trainer.test(model=LM, dataloaders=held_out_dataloader)
 
     if os.environ.get("LOCAL_RANK", "0") == "0":
-        print("Saving final metrics...")
+        logger.info("Saving final metrics...")
 
         with open(cfg.save_dir / "tuning_metrics.json", mode="w") as f:
             json.dump(tuning_metrics, f)
diff --git a/EventStream/transformer/lightning_modules/zero_shot_evaluator.py b/EventStream/transformer/lightning_modules/zero_shot_evaluator.py
index 8489ce4d..a51f4c4a 100644
--- a/EventStream/transformer/lightning_modules/zero_shot_evaluator.py
+++ b/EventStream/transformer/lightning_modules/zero_shot_evaluator.py
@@ -10,6 +10,7 @@
 import torch.multiprocessing
 import torchmetrics
 from lightning.pytorch.loggers import WandbLogger
+from loguru import logger
 from torchmetrics.classification import (
     BinaryAccuracy,
     BinaryAUROC,
@@ -168,7 +169,7 @@ def _log_metric_dict(
                 metric(preds, labels)
                 self.log(f"{prefix}_{metric_name}", metric)
             except (ValueError, IndexError) as e:
-                print(
+                logger.error(
                     f"Failed to compute {metric_name} "
                     f"with preds ({str_summary(preds)}) and labels ({str_summary(labels)}): {e}."
                 )
@@ -380,7 +381,7 @@ def zero_shot_evaluation(cfg: FinetuneConfig):
     held_out_metrics = trainer.test(model=LM, dataloaders=held_out_dataloader)
 
     if os.environ.get("LOCAL_RANK", "0") == "0":
-        print("Saving final metrics...")
+        logger.info("Saving final metrics...")
         cfg.save_dir.mkdir(parents=True, exist_ok=True)
 
         with open(cfg.save_dir / "zero_shot_tuning_metrics.json", mode="w") as f:
diff --git a/EventStream/utils.py b/EventStream/utils.py
index 65906278..d1fce545 100644
--- a/EventStream/utils.py
+++ b/EventStream/utils.py
@@ -17,6 +17,7 @@
 
 import hydra
 import polars as pl
+from loguru import logger
 
 PROPORTION = float
 COUNT_OR_PROPORTION = Union[int, PROPORTION]
@@ -380,8 +381,8 @@ def wrap(*args, **kwargs):
             # some hyperparameter combinations might be invalid or cause out-of-memory errors
             # so when using hparam search plugins like Optuna, you might want to disable
             # raising the below exception to avoid multirun failure
-            print(f"EXCEPTION: {ex}")
-            print(traceback.print_exc(), file=sys.stderr)
+            logger.error(f"EXCEPTION: {ex}")
+            logger.error(traceback.print_exc(), file=sys.stderr)
             raise ex
         finally:
             # always close wandb run (even if exception occurs so multirun won't fail)
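
A caveat on the final EventStream/utils.py hunk: `traceback.print_exc()` writes the traceback to stderr itself and returns `None`, so `logger.error(traceback.print_exc(), file=sys.stderr)` logs the message "None", and the `file` keyword does not redirect loguru output (destinations are configured via sinks). Below is a minimal sketch of the loguru idiom that captures the active traceback directly; the `run_task` name is illustrative and not the repository's actual wrapper.

    from loguru import logger


    def run_task() -> None:
        """Illustrative stand-in for the wrapped hydra task, not the repo's actual decorator."""
        try:
            raise ValueError("boom")  # stand-in for a failing training run
        except Exception as ex:
            # With exception=True, loguru attaches the active traceback to the log record,
            # so there is no need to log the return value of traceback.print_exc().
            logger.opt(exception=True).error(f"EXCEPTION: {ex}")
            raise


    if __name__ == "__main__":
        try:
            run_task()
        except ValueError:
            pass  # the traceback has already been logged above

`logger.exception(f"EXCEPTION: {ex}")` is equivalent shorthand, and `logger.error(traceback.format_exc())` also works when the formatted traceback string is preferred.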