diff --git a/neurodiffeq/callbacks.py b/neurodiffeq/callbacks.py index 8361721..f5985fa 100644 --- a/neurodiffeq/callbacks.py +++ b/neurodiffeq/callbacks.py @@ -2,6 +2,7 @@ import dill from datetime import datetime import logging +from .utils import safe_mkdir as _safe_mkdir class MonitorCallback: @@ -20,6 +21,8 @@ class MonitorCallback: def __init__(self, monitor, fig_dir=None, check_against='local', repaint_last=True): self.monitor = monitor self.fig_dir = fig_dir + if fig_dir: + _safe_mkdir(fig_dir) self.repaint_last = repaint_last if check_against not in ['local', 'global']: raise ValueError(f'unknown check_against type = {check_against}') diff --git a/neurodiffeq/monitors.py b/neurodiffeq/monitors.py index 515d7ad..a1cb4f0 100644 --- a/neurodiffeq/monitors.py +++ b/neurodiffeq/monitors.py @@ -28,12 +28,16 @@ class BaseMonitor(ABC): It blocks the training / validation process, so don't call the ``check()`` method too often. """ + def __init__(self): + self.check_every = ... + @abstractmethod def check(self, nets, conditions, history): pass -class MonitorSpherical: +# noinspection PyMissingConstructor +class MonitorSpherical(BaseMonitor): r"""A monitor for checking the status of the neural network during training. :param r_min: @@ -470,6 +474,7 @@ def max_degree(self): return ret +# noinspection PyMissingConstructor class Monitor1D(BaseMonitor): """A monitor for checking the status of the neural network during training. @@ -555,6 +560,7 @@ def check(self, nets, conditions, history): plt.pause(0.05) +# noinspection PyMissingConstructor class Monitor2D(BaseMonitor): r"""A monitor for checking the status of the neural network during training. diff --git a/neurodiffeq/utils.py b/neurodiffeq/utils.py index 91b55e7..2ca3dc0 100644 --- a/neurodiffeq/utils.py +++ b/neurodiffeq/utils.py @@ -1,3 +1,4 @@ +from pathlib import Path import torch @@ -32,3 +33,7 @@ def set_tensor_type(device=None, float_bits=32): raise ValueError(f"Unknown device '{device}'; device must be either 'cuda' or 'cpu'") torch.set_default_tensor_type(type_string) + + +def safe_mkdir(path): + Path(path).mkdir(parents=True, exist_ok=True)