Skip to content

Commit

Permalink
Merge pull request #691 from BindsNET/hananel
Browse files Browse the repository at this point in the history
monitors periodic update
  • Loading branch information
Hananel-Hazan authored Aug 16, 2024
2 parents 38c4306 + 89cc96e commit a8c4d64
Show file tree
Hide file tree
Showing 3 changed files with 1,263 additions and 1,250 deletions.
13 changes: 10 additions & 3 deletions bindsnet/network/monitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ def __init__(
if self.time is None:
self.device = "cpu"

self.clean = True

self.recording = []
self.reset_state_variables()

Expand All @@ -68,16 +70,20 @@ def get(self, var: str) -> torch.Tensor:
Note, if time == `None`, get return the logs and empty the monitor variable
"""
return_logs = torch.cat(self.recording[var], 0)
if self.time is None:
self.recording[var] = []
if self.clean:
return_logs = torch.empty(0, device=self.device)
else:
return_logs = torch.cat(self.recording[var], 0)
if self.time is None:
self.recording[var] = []
return return_logs

def record(self) -> None:
# language=rst
"""
Appends the current value of the recorded state variables to the recording.
"""
self.clean = False
for v in self.state_vars:
data = getattr(self.obj, v).unsqueeze(0)
# self.recording[v].append(data.detach().clone().to(self.device))
Expand All @@ -101,6 +107,7 @@ def reset_state_variables(self) -> None:
self.recording = {
v: [[] for i in range(self.time)] for v in self.state_vars
}
self.clean = True


class NetworkMonitor(AbstractMonitor):
Expand Down
Loading

0 comments on commit a8c4d64

Please sign in to comment.