Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

change some code for the new version of bindsnet. #8

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 14 additions & 10 deletions experiments/mnist/increasing_inhibition.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,15 @@
import numpy as np
import matplotlib.pyplot as plt

import torchvision.datasets.mnist as vision_mnist
from time import time as t
from sklearn.metrics import confusion_matrix
from scipy.spatial.distance import euclidean

from bindsnet.datasets import MNIST
from bindsnet.network import Network
from bindsnet.encoding import poisson
from bindsnet.network import load_network
from bindsnet.network import load
from bindsnet.learning import PostPre, NoOp
from bindsnet.network.monitors import Monitor
from bindsnet.network.topology import Connection
Expand Down Expand Up @@ -96,23 +97,25 @@ def main(seed=0, n_neurons=100, n_train=60000, n_test=10000, c_low=1, c_high=25,
network.add_connection(input_exc_conn, source='X', target='Y')
network.add_connection(recurrent_conn, source='Y', target='Y')
else:
network = load_network(os.path.join(params_path, model_name + '.pt'))
network = load(os.path.join(params_path, model_name + '.pt'))
network.connections['X', 'Y'].update_rule = NoOp(
connection=network.connections['X', 'Y'], nu=network.connections['X', 'Y'].nu
)
network.layers['Y'].theta_decay = 0
network.layers['Y'].theta_plus = 0

# Load MNIST data.
dataset = MNIST(data_path, download=True)
dataset = vision_mnist.MNIST(root=data_path, download=True)

if train:
images, labels = dataset.get_train()
images = dataset.train_data
labels = dataset.train_labels
else:
images, labels = dataset.get_test()
images = dataset.test_data
labels = dataset.test_labels

images = images.view(-1, 784)
images *= intensity
images = images * intensity

# Record spikes during the simulation.
spike_record = torch.zeros(update_interval, int(time / dt), n_neurons)
Expand Down Expand Up @@ -232,7 +235,7 @@ def main(seed=0, n_neurons=100, n_train=60000, n_test=10000, c_low=1, c_high=25,
inpts = {'X' : sample}

# Run the network on the input.
network.run(inpts=inpts, time=time)
network.run(inputs=inpts, time=time)

retries = 0
while spikes['Y'].get('s').sum() < 5 and retries < 3:
Expand All @@ -243,7 +246,7 @@ def main(seed=0, n_neurons=100, n_train=60000, n_test=10000, c_low=1, c_high=25,
network.run(inpts=inpts, time=time)

# Add to spikes recording.
spike_record[i % update_interval] = spikes['Y'].get('s').t()
spike_record[i % update_interval] = spikes['Y'].get('s').squeeze(1)

# Optionally plot various simulation information.
if plot:
Expand All @@ -261,7 +264,7 @@ def main(seed=0, n_neurons=100, n_train=60000, n_test=10000, c_low=1, c_high=25,

plt.pause(1e-8)

network.reset_() # Reset state variables.
network.reset_state_variables() # Reset state variables.

print(f'Progress: {n_examples} / {n_examples} ({t() - start:.4f} seconds)')

Expand Down Expand Up @@ -354,7 +357,7 @@ def main(seed=0, n_neurons=100, n_train=60000, n_test=10000, c_low=1, c_high=25,
# Parameters.
parser = argparse.ArgumentParser()
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--n_neurons', type=int, default=100)
parser.add_argument('--n_neurons', type=int, default=625)
parser.add_argument('--n_train', type=int, default=60000)
parser.add_argument('--n_test', type=int, default=10000)
parser.add_argument('--c_low', type=float, default=1)
Expand Down Expand Up @@ -389,6 +392,7 @@ def main(seed=0, n_neurons=100, n_train=60000, n_test=10000, c_low=1, c_high=25,
progress_interval = args.progress_interval
update_interval = args.update_interval
plot = args.plot
plot = True
train = args.train
gpu = args.gpu

Expand Down