Skip to content

Commit

Permalink
Merge pull request #642 from BindsNET/hananel
Browse files Browse the repository at this point in the history
Read the doc issue
  • Loading branch information
Hananel-Hazan authored Aug 18, 2023
2 parents 9a2cc55 + d5b1e0e commit 2c4286e
Show file tree
Hide file tree
Showing 6 changed files with 953 additions and 1,039 deletions.
1 change: 1 addition & 0 deletions readthedocs.yml → .readthedocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,5 @@ python:
install:
- method: pip
path: .
- requirements: docs/requirements.txt
system_packages: False
17 changes: 17 additions & 0 deletions bindsnet/analysis/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,7 @@ def plot_weights(
figsize: Tuple[int, int] = (5, 5),
cmap: str = "hot_r",
save: Optional[str] = None,
title: Optional[str] = None,
) -> AxesImage:
# language=rst
"""
Expand All @@ -198,6 +199,7 @@ def plot_weights(
:param figsize: Horizontal, vertical figure size in inches.
:param cmap: Matplotlib colormap.
:param save: file name to save fig, if None = not saving fig.
:param title: Title of the plot.
:return: ``AxesImage`` for re-drawing the weights plot.
"""
local_weights = weights.detach().clone().cpu().numpy()
Expand All @@ -213,6 +215,8 @@ def plot_weights(
ax.set_xticks(())
ax.set_yticks(())
ax.set_aspect("auto")
if title != None:
ax.set_title(title + " Weights")

plt.colorbar(im, cax=cax)
fig.tight_layout()
Expand Down Expand Up @@ -241,6 +245,8 @@ def plot_weights(
ax.set_xticks(())
ax.set_yticks(())
ax.set_aspect("auto")
if title != None:
ax.set_title(title + " Weights")

plt.colorbar(im, cax=cax)
fig.tight_layout()
Expand All @@ -257,6 +263,7 @@ def plot_conv2d_weights(
im: Optional[AxesImage] = None,
figsize: Tuple[int, int] = (5, 5),
cmap: str = "hot_r",
title: Optional[str] = None,
) -> AxesImage:
# language=rst
"""
Expand All @@ -268,6 +275,7 @@ def plot_conv2d_weights(
:param im: Used for re-drawing the weights plot.
:param figsize: Horizontal, vertical figure size in inches.
:param cmap: Matplotlib colormap.
:param title: Title of the plot.
:return: Used for re-drawing the weights plot.
"""

Expand Down Expand Up @@ -295,6 +303,8 @@ def plot_conv2d_weights(
ax.set_xticks(())
ax.set_yticks(())
ax.set_aspect("auto")
if title != None:
ax.set_title(title + " Weights")

plt.colorbar(im, cax=cax)
fig.tight_layout()
Expand All @@ -317,6 +327,7 @@ def plot_locally_connected_weights(
lines: bool = True,
figsize: Tuple[int, int] = (5, 5),
cmap: str = "hot_r",
title: Optional[str] = None,
) -> AxesImage:
# language=rst
"""
Expand All @@ -337,6 +348,7 @@ def plot_locally_connected_weights(
regions.
:param figsize: Horizontal, vertical figure size in inches.
:param cmap: Matplotlib colormap.
:param title: Title of the plot.
:return: Used for re-drawing the weights plot.
"""
kernel_size = _pair(kernel_size)
Expand Down Expand Up @@ -373,6 +385,8 @@ def plot_locally_connected_weights(
ax.set_xticks(())
ax.set_yticks(())
ax.set_aspect("auto")
if title != None:
ax.set_title(title + " Weights")

plt.colorbar(im, cax=cax)
fig.tight_layout()
Expand All @@ -391,6 +405,7 @@ def plot_local_connection_2d_weights(
figsize: Tuple[int, int] = (5, 5),
cmap: str = "hot_r",
color: str = "r",
title: Optional[str] = None,
) -> AxesImage:
# language=rst
"""
Expand Down Expand Up @@ -451,6 +466,8 @@ def plot_local_connection_2d_weights(
ax.set_xticks(())
ax.set_yticks(())
ax.set_aspect("auto")
if title != None:
ax.set_title(title + " Weights")

plt.colorbar(im, cax=cax)
fig.tight_layout()
Expand Down
8 changes: 7 additions & 1 deletion bindsnet/network/topology.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def __init__(

from ..learning import NoOp

self.update_rule = kwargs.get("update_rule", NoOp)
self.update_rule = kwargs.get("update_rule", None)

# Float32 necessary for comparisons with +/-inf
self.wmin = Parameter(
Expand Down Expand Up @@ -775,6 +775,12 @@ def __init__(

self.register_buffer("firing_rates", torch.zeros(source.s.shape))

from ..learning import NoOp

# Initialize learning rule
if self.update_rule is not None and (self.update_rule == NoOp):
self.update_rule = self.update_rule()

def compute(self, s: torch.Tensor) -> torch.Tensor:
# language=rst
"""
Expand Down
6 changes: 6 additions & 0 deletions docs/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Defining the exact version will make sure things don't break
sphinx==6.2.1
sphinx_rtd_theme==1.2.2
readthedocs-sphinx-search==0.1.1
imagecodecs == 2023.3.16
Jinja2 == 3.1.2
Loading

0 comments on commit 2c4286e

Please sign in to comment.