diff --git a/pytorch_forecasting/models/base_model.py b/pytorch_forecasting/models/base_model.py
index a82196cf..9acdd258 100644
--- a/pytorch_forecasting/models/base_model.py
+++ b/pytorch_forecasting/models/base_model.py
@@ -341,6 +341,18 @@ def result(self) -> Prediction:
         return None
 
 
+def fig2img(fig):
+    """Convert a matplotlib figure to a PIL Image and return it."""
+    import io
+
+    from PIL import Image
+
+    buf = io.BytesIO()
+    fig.savefig(buf, format="png")
+    buf.seek(0)
+    return Image.open(buf)
+
+
 class BaseModel(InitialParameterRepresenterMixIn, LightningModule, TupleOutputMixIn):
     """
     BaseModel from which new timeseries models should inherit from.
@@ -985,16 +997,16 @@ def log_prediction(
             tag += f" of item {idx} in batch {batch_idx}"
         if isinstance(fig, (list, tuple)):
             for idx, f in enumerate(fig):
-                self.logger.experiment.add_figure(
-                    f"{self.target_names[idx]} {tag}",
-                    f,
-                    global_step=self.global_step,
+                self.logger.experiment.log_image(
+                    run_id=self.logger.run_id,
+                    image=fig2img(f),
+                    artifact_file=f"{self.target_names[idx]}_{tag}_step_{self.global_step}.png",
                 )
         else:
-            self.logger.experiment.add_figure(
-                tag,
-                fig,
-                global_step=self.global_step,
+            self.logger.experiment.log_image(
+                run_id=self.logger.run_id,
+                image=fig2img(fig),
+                artifact_file=f"{tag}_step_{self.global_step}.png",
             )
 
     def plot_prediction(
@@ -1157,7 +1169,7 @@ def log_gradient_flow(self, named_parameters: Dict[str, torch.Tensor]) -> None:
         ax.set_ylabel("Average gradient")
         ax.set_yscale("log")
         ax.set_title("Gradient flow")
-        self.logger.experiment.add_figure("Gradient flow", fig, global_step=self.global_step)
+        self.logger.experiment.log_image(run_id=self.logger.run_id, image=fig2img(fig), artifact_file=f"gradient_flow_step_{self.global_step}.png")
 
     def on_after_backward(self):
         """
diff --git a/pytorch_forecasting/models/nhits/__init__.py b/pytorch_forecasting/models/nhits/__init__.py
index 68816f22..eb49fcca 100644
--- a/pytorch_forecasting/models/nhits/__init__.py
+++ b/pytorch_forecasting/models/nhits/__init__.py
@@ -19,6 +19,7 @@
 from pytorch_forecasting.utils._dependencies import _check_matplotlib
+from pytorch_forecasting.models.base_model import fig2img
 
 
 class NHiTS(BaseModelWithCovariates):
     def __init__(
         self,
@@ -552,17 +553,16 @@ def log_interpretation(self, x, out, batch_idx):
             name += f"step {self.global_step}"
         else:
             name += f"batch {batch_idx}"
-        self.logger.experiment.add_figure(name, fig, global_step=self.global_step)
         if isinstance(fig, (list, tuple)):
             for idx, f in enumerate(fig):
-                self.logger.experiment.add_figure(
-                    f"{self.target_names[idx]} {name}",
-                    f,
-                    global_step=self.global_step,
+                self.logger.experiment.log_image(
+                    run_id=self.logger.run_id,
+                    image=fig2img(f),
+                    artifact_file=f"{self.target_names[idx]}_{name}.png",
                 )
         else:
-            self.logger.experiment.add_figure(
-                name,
-                fig,
-                global_step=self.global_step,
+            self.logger.experiment.log_image(
+                run_id=self.logger.run_id,
+                image=fig2img(fig),
+                artifact_file=f"{name}.png",
             )
diff --git a/pytorch_forecasting/models/temporal_fusion_transformer/__init__.py b/pytorch_forecasting/models/temporal_fusion_transformer/__init__.py
index cc506612..371c3189 100644
--- a/pytorch_forecasting/models/temporal_fusion_transformer/__init__.py
+++ b/pytorch_forecasting/models/temporal_fusion_transformer/__init__.py
@@ -25,6 +25,7 @@
 from pytorch_forecasting.utils import create_mask, detach, integer_histogram, masked_op, padded_stack, to_list
 from pytorch_forecasting.utils._dependencies import _check_matplotlib
+from pytorch_forecasting.models.base_model import fig2img
 
 
 class TemporalFusionTransformer(BaseModelWithCovariates):
     def __init__(
@@ -827,8 +828,10 @@ def log_interpretation(self, outputs):
         label = self.current_stage
-        # log to tensorboard
+        # log to mlflow
         for name, fig in figs.items():
-            self.logger.experiment.add_figure(
-                f"{label.capitalize()} {name} importance", fig, global_step=self.global_step
+            self.logger.experiment.log_image(
+                run_id=self.logger.run_id,
+                image=fig2img(fig),
+                artifact_file=f"{label.capitalize()}_{name}_importance_step_{self.global_step}.png",
             )
 
         # log lengths of encoder/decoder
@@ -849,8 +852,10 @@
             ax.set_ylabel("Number of samples")
             ax.set_title(f"{type.capitalize()} length distribution in {label} epoch")
 
-            self.logger.experiment.add_figure(
-                f"{label.capitalize()} {type} length distribution", fig, global_step=self.global_step
+            self.logger.experiment.log_image(
+                run_id=self.logger.run_id,
+                image=fig2img(fig),
+                artifact_file=f"{label.capitalize()}_{type}_length_distribution_step_{self.global_step}.png",
            )
 
     def log_embeddings(self):
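
For reviewers: a minimal standalone sketch of the new logging path, separate from the diff. It is an illustration only (the run, figure, and file names below are made up), and it assumes mlflow, matplotlib, and Pillow are installed.

# Sketch of the fig2img -> MlflowClient.log_image path introduced above.
import io

import matplotlib

matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import mlflow
from PIL import Image


def fig2img(fig):
    """Convert a matplotlib figure to a PIL Image and return it."""
    buf = io.BytesIO()
    fig.savefig(buf, format="png")
    buf.seek(0)
    return Image.open(buf)


with mlflow.start_run() as run:
    fig, ax = plt.subplots()
    ax.plot([0, 1, 2], [1, 0, 1])  # placeholder data
    client = mlflow.tracking.MlflowClient()
    client.log_image(run_id=run.info.run_id, image=fig2img(fig), artifact_file="example.png")
    plt.close(fig)

MlflowClient.log_image accepts a PIL image directly, which is why the helper converts the figure in memory instead of saving it to disk first.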
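
Context for why every call now passes run_id: with Lightning's MLFlowLogger, self.logger.experiment is an mlflow.tracking.MlflowClient, which, unlike TensorBoard's SummaryWriter, is not bound to a run. A sketch of the wiring these calls assume (experiment name and tracking URI are illustrative):

# Sketch: routing the diff's log_image calls to MLflow via Lightning.
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import MLFlowLogger

mlf_logger = MLFlowLogger(experiment_name="pytorch_forecasting", tracking_uri="file:./mlruns")
trainer = Trainer(logger=mlf_logger, max_epochs=1)
# Inside the model, self.logger.experiment is then an MlflowClient and
# self.logger.run_id identifies the active run; trainer.fit(...) would
# store the figures as PNG artifacts under that run.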