diff --git a/experiments/mnist/increasing_inhibition.py b/experiments/mnist/increasing_inhibition.py index 75a9b82..47ad03c 100644 --- a/experiments/mnist/increasing_inhibition.py +++ b/experiments/mnist/increasing_inhibition.py @@ -4,6 +4,7 @@ import numpy as np import matplotlib.pyplot as plt +import torchvision.datasets.mnist as vision_mnist from time import time as t from sklearn.metrics import confusion_matrix from scipy.spatial.distance import euclidean @@ -11,7 +12,7 @@ from bindsnet.datasets import MNIST from bindsnet.network import Network from bindsnet.encoding import poisson -from bindsnet.network import load_network +from bindsnet.network import load from bindsnet.learning import PostPre, NoOp from bindsnet.network.monitors import Monitor from bindsnet.network.topology import Connection @@ -96,7 +97,7 @@ def main(seed=0, n_neurons=100, n_train=60000, n_test=10000, c_low=1, c_high=25, network.add_connection(input_exc_conn, source='X', target='Y') network.add_connection(recurrent_conn, source='Y', target='Y') else: - network = load_network(os.path.join(params_path, model_name + '.pt')) + network = load(os.path.join(params_path, model_name + '.pt')) network.connections['X', 'Y'].update_rule = NoOp( connection=network.connections['X', 'Y'], nu=network.connections['X', 'Y'].nu ) @@ -104,15 +105,17 @@ def main(seed=0, n_neurons=100, n_train=60000, n_test=10000, c_low=1, c_high=25, network.layers['Y'].theta_plus = 0 # Load MNIST data. - dataset = MNIST(data_path, download=True) + dataset = vision_mnist.MNIST(root=data_path, download=True) if train: - images, labels = dataset.get_train() + images = dataset.train_data + labels = dataset.train_labels else: - images, labels = dataset.get_test() + images = dataset.test_data + labels = dataset.test_labels images = images.view(-1, 784) - images *= intensity + images = images * intensity # Record spikes during the simulation. spike_record = torch.zeros(update_interval, int(time / dt), n_neurons) @@ -232,7 +235,7 @@ def main(seed=0, n_neurons=100, n_train=60000, n_test=10000, c_low=1, c_high=25, inpts = {'X' : sample} # Run the network on the input. - network.run(inpts=inpts, time=time) + network.run(inputs=inpts, time=time) retries = 0 while spikes['Y'].get('s').sum() < 5 and retries < 3: @@ -243,7 +246,7 @@ def main(seed=0, n_neurons=100, n_train=60000, n_test=10000, c_low=1, c_high=25, network.run(inpts=inpts, time=time) # Add to spikes recording. - spike_record[i % update_interval] = spikes['Y'].get('s').t() + spike_record[i % update_interval] = spikes['Y'].get('s').squeeze(1) # Optionally plot various simulation information. if plot: @@ -261,7 +264,7 @@ def main(seed=0, n_neurons=100, n_train=60000, n_test=10000, c_low=1, c_high=25, plt.pause(1e-8) - network.reset_() # Reset state variables. + network.reset_state_variables() # Reset state variables. print(f'Progress: {n_examples} / {n_examples} ({t() - start:.4f} seconds)') @@ -354,7 +357,7 @@ def main(seed=0, n_neurons=100, n_train=60000, n_test=10000, c_low=1, c_high=25, # Parameters. parser = argparse.ArgumentParser() parser.add_argument('--seed', type=int, default=0) - parser.add_argument('--n_neurons', type=int, default=100) + parser.add_argument('--n_neurons', type=int, default=625) parser.add_argument('--n_train', type=int, default=60000) parser.add_argument('--n_test', type=int, default=10000) parser.add_argument('--c_low', type=float, default=1) @@ -389,6 +392,7 @@ def main(seed=0, n_neurons=100, n_train=60000, n_test=10000, c_low=1, c_high=25, progress_interval = args.progress_interval update_interval = args.update_interval plot = args.plot + plot = True train = args.train gpu = args.gpu