From 87884b715b3687a875dc7184616dabff5d87af82 Mon Sep 17 00:00:00 2001 From: billsioros Date: Sun, 13 Feb 2022 14:08:48 +0200 Subject: [PATCH] feat(flow): add latest epoch checkpoint --- src/roughml/training/flow.py | 26 ++++++++++++++++++++++++-- 1 file changed, 24 insertions(+), 2 deletions(-) diff --git a/src/roughml/training/flow.py b/src/roughml/training/flow.py index 678c119..b58fd7d 100644 --- a/src/roughml/training/flow.py +++ b/src/roughml/training/flow.py @@ -5,6 +5,7 @@ import numpy as np import pandas as pd +import torch from roughml.content.loss import VectorSpaceContentLoss from roughml.plot import animate_epochs, as_3d_surface, as_grayscale_image, plot_against @@ -67,7 +68,9 @@ def __call__(self, get_generator, get_discriminator): with ExceptionLoggingHandler( logger, suppress_exceptions=self.suppress_exceptions ) as exception_logging_handler: - self.process_dataset(generator, discriminator, path, dataset) + dataset_output_dir = self.process_dataset( + generator, discriminator, path, dataset + ) end_time = time() @@ -81,6 +84,23 @@ def __call__(self, get_generator, get_discriminator): succeeded=exception_logging_handler.success, ) + checkpoint_dir = dataset_output_dir / "Checkpoint" + + generator_mt, discriminator_mt = ( + f"{generator.__class__.__name__}", + f"{discriminator.__class__.__name__}", + ) + + torch.save( + generator.state_dict(), + checkpoint_dir / f"{generator_mt}.pt", + ) + + torch.save( + discriminator.state_dict(), + checkpoint_dir / f"{discriminator_mt}.pt", + ) + def process_dataset(self, generator, discriminator, path, dataset): dataset_output_dir = ( self.output_dir @@ -192,7 +212,7 @@ def process_dataset(self, generator, discriminator, path, dataset): "Discriminator Loss", "Discriminator Output (Real)", "Discriminator Output (Fake)", - f"Content Loss ({self.content_loss.type.__name__ if self.content_loss.type else 'None'})", + f"Content Loss ({self.content_loss.type.__name__ if hasattr(self.content_loss, 'type') else 'None'})", "Content Loss (VectorSpaceContentLoss)", ], ).to_csv(str(checkpoint_dir / "per_epoch_data.csv")) @@ -284,3 +304,5 @@ def process_dataset(self, generator, discriminator, path, dataset): self.plot.save_directory / (self.plot.surface.save_path_fmt % ("fake", i)), ) + + return dataset_output_dir