Skip to content

Commit

Permalink
Merge pull request #80 from odegym/v0.3.1
Browse files Browse the repository at this point in the history
V0.3.1
  • Loading branch information
shuheng-liu authored Feb 8, 2021
2 parents 10f9ab7 + cbda7e1 commit 48d65d5
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 1 deletion.
3 changes: 3 additions & 0 deletions neurodiffeq/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import dill
from datetime import datetime
import logging
from .utils import safe_mkdir as _safe_mkdir


class MonitorCallback:
Expand All @@ -20,6 +21,8 @@ class MonitorCallback:
def __init__(self, monitor, fig_dir=None, check_against='local', repaint_last=True):
self.monitor = monitor
self.fig_dir = fig_dir
if fig_dir:
_safe_mkdir(fig_dir)
self.repaint_last = repaint_last
if check_against not in ['local', 'global']:
raise ValueError(f'unknown check_against type = {check_against}')
Expand Down
8 changes: 7 additions & 1 deletion neurodiffeq/monitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,16 @@ class BaseMonitor(ABC):
It blocks the training / validation process, so don't call the ``check()`` method too often.
"""

def __init__(self):
self.check_every = ...

@abstractmethod
def check(self, nets, conditions, history):
pass


class MonitorSpherical:
# noinspection PyMissingConstructor
class MonitorSpherical(BaseMonitor):
r"""A monitor for checking the status of the neural network during training.
:param r_min:
Expand Down Expand Up @@ -470,6 +474,7 @@ def max_degree(self):
return ret


# noinspection PyMissingConstructor
class Monitor1D(BaseMonitor):
"""A monitor for checking the status of the neural network during training.
Expand Down Expand Up @@ -555,6 +560,7 @@ def check(self, nets, conditions, history):
plt.pause(0.05)


# noinspection PyMissingConstructor
class Monitor2D(BaseMonitor):
r"""A monitor for checking the status of the neural network during training.
Expand Down
5 changes: 5 additions & 0 deletions neurodiffeq/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from pathlib import Path
import torch


Expand Down Expand Up @@ -32,3 +33,7 @@ def set_tensor_type(device=None, float_bits=32):
raise ValueError(f"Unknown device '{device}'; device must be either 'cuda' or 'cpu'")

torch.set_default_tensor_type(type_string)


def safe_mkdir(path):
Path(path).mkdir(parents=True, exist_ok=True)

0 comments on commit 48d65d5

Please sign in to comment.