Skip to content

Commit

Permalink
Merge pull request #687 from BindsNET/hananel
Browse files Browse the repository at this point in the history
CUDA update and deprecation python 3.8 welcome 3.11
  • Loading branch information
Hananel-Hazan authored Jun 7, 2024
2 parents 65fd024 + d8000ee commit bf329b9
Show file tree
Hide file tree
Showing 5 changed files with 1,232 additions and 1,040 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/pythonpackage.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ jobs:
strategy:
max-parallel: 4
matrix:
python-version: ["3.8", "3.9", "3.10"]
python-version: ["3.9", "3.10", "3.11"]

steps:
- uses: actions/checkout@v3
Expand Down
4 changes: 2 additions & 2 deletions bindsnet/datasets/torchvision_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Dict, Optional

import torch
import torchvision
from torchvision import datasets as torchDB

from bindsnet.encoding import Encoder, NullEncoder

Expand All @@ -13,7 +13,7 @@ def create_torchvision_dataset_wrapper(ds_type):
``__getitem__``. This applies to all of the datasets inside of ``torchvision``.
"""
if type(ds_type) == str:
ds_type = getattr(torchvision.datasets, ds_type)
ds_type = getattr(torchDB, ds_type)

class TorchvisionDatasetWrapper(ds_type):
__doc__ = (
Expand Down
14 changes: 14 additions & 0 deletions bindsnet/network/topology.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ def __init__(
:param Union[float, torch.Tensor] wmax: Maximum allowed value(s) on the connection weights. Single value, or
tensor of same size as w
:param float norm: Total weight per target neuron normalization.
:param Union[bool, torch.Tensor] Dales_rule: Whether to enforce Dale's rule. input is boolean tensor in weight shape
where True means force zero or positive values and False means force zero or negative values.
"""
super().__init__()

Expand Down Expand Up @@ -88,6 +90,12 @@ def __init__(
**kwargs,
)

self.Dales_rule = kwargs.get("Dales_rule", None)
if self.Dales_rule is not None:
self.Dales_rule = Parameter(
torch.as_tensor(self.Dales_rule, dtype=torch.bool), requires_grad=False
)

@abstractmethod
def compute(self, s: torch.Tensor) -> None:
# language=rst
Expand Down Expand Up @@ -117,6 +125,12 @@ def update(self, **kwargs) -> None:
if mask is not None:
self.w.masked_fill_(mask, 0)

if self.Dales_rule is not None:
# weight that are negative and should be positive are set to 0
self.w[self.w < 0 * self.Dales_rule.to(torch.float)] = 0
# weight that are positive and should be negative are set to 0
self.w[self.w > 0 * 1 - self.Dales_rule.to(torch.float)] = 0

@abstractmethod
def reset_state_variables(self) -> None:
# language=rst
Expand Down
Loading

0 comments on commit bf329b9

Please sign in to comment.