Skip to content

Commit

Permalink
feat(flow): add latest epoch checkpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
billsioros committed Feb 13, 2022
1 parent b4a4a94 commit 87884b7
Showing 1 changed file with 24 additions and 2 deletions.
26 changes: 24 additions & 2 deletions src/roughml/training/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import numpy as np
import pandas as pd
import torch

from roughml.content.loss import VectorSpaceContentLoss
from roughml.plot import animate_epochs, as_3d_surface, as_grayscale_image, plot_against
Expand Down Expand Up @@ -67,7 +68,9 @@ def __call__(self, get_generator, get_discriminator):
with ExceptionLoggingHandler(
logger, suppress_exceptions=self.suppress_exceptions
) as exception_logging_handler:
self.process_dataset(generator, discriminator, path, dataset)
dataset_output_dir = self.process_dataset(
generator, discriminator, path, dataset
)

end_time = time()

Expand All @@ -81,6 +84,23 @@ def __call__(self, get_generator, get_discriminator):
succeeded=exception_logging_handler.success,
)

checkpoint_dir = dataset_output_dir / "Checkpoint"

generator_mt, discriminator_mt = (
f"{generator.__class__.__name__}",
f"{discriminator.__class__.__name__}",
)

torch.save(
generator.state_dict(),
checkpoint_dir / f"{generator_mt}.pt",
)

torch.save(
discriminator.state_dict(),
checkpoint_dir / f"{discriminator_mt}.pt",
)

def process_dataset(self, generator, discriminator, path, dataset):
dataset_output_dir = (
self.output_dir
Expand Down Expand Up @@ -192,7 +212,7 @@ def process_dataset(self, generator, discriminator, path, dataset):
"Discriminator Loss",
"Discriminator Output (Real)",
"Discriminator Output (Fake)",
f"Content Loss ({self.content_loss.type.__name__ if self.content_loss.type else 'None'})",
f"Content Loss ({self.content_loss.type.__name__ if hasattr(self.content_loss, 'type') else 'None'})",
"Content Loss (VectorSpaceContentLoss)",
],
).to_csv(str(checkpoint_dir / "per_epoch_data.csv"))
Expand Down Expand Up @@ -284,3 +304,5 @@ def process_dataset(self, generator, discriminator, path, dataset):
self.plot.save_directory
/ (self.plot.surface.save_path_fmt % ("fake", i)),
)

return dataset_output_dir

0 comments on commit 87884b7

Please sign in to comment.