place_cell_rnn.py
import argparse
import json
import os
import time

import numpy as np
import torch

import utils  # provides save_options_to_json, used when writing configs.json below
from place_cells import PlaceCells
from trajectory_generator import TrajectoryGenerator
from model import RNN
from trainer import Trainer
from visualize import *
from scores import border_score
# Training hyperparameters to fully reproduce Sorscher et al. 2023
parser = argparse.ArgumentParser(fromfile_prefix_chars="@")
parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu', help='Device to use for training')
parser.add_argument('--oned', type=lambda x: (str(x).lower() == 'true'), default=False, help='Whether to use one-dimensional place cells')
parser.add_argument('--Np', type=int, default=512, help='Number of place cells')
parser.add_argument('--Ng', type=int, default=2048, help='Number of grid cells')
parser.add_argument('--Nv', type=int, default=2, help='Number of velocity inputs')
parser.add_argument('--DoG', type=lambda x: (str(x).lower() == 'true'), default=True, help='Whether to use Difference of Gaussians for place cell RFs')
parser.add_argument('--box_width', type=float, default=1.4, help='Width of the environment box')
parser.add_argument('--box_height', type=float, default=1.4, help='Height of the environment box')
parser.add_argument('--place_cell_rf', type=float, default=0.12, help='Place cell receptive field size')
parser.add_argument('--surround_scale', type=int, default=2, help='Scale factor for the surround inhibition')
parser.add_argument('--periodic', type=lambda x: (str(x).lower() == 'true'), default=False, help='Whether the environment is periodic')
parser.add_argument('--sequence_length', type=int, default=10, help='Length of the trajectory sequence')
parser.add_argument('--dt', type=float, default=0.02, help='Time step size')
parser.add_argument('--batch_size', type=int, default=500, help='Batch size for training')
parser.add_argument('--n_epochs', type=int, default=200, help='Number of training epochs')
parser.add_argument('--n_steps', type=int, default=100, help='Number of steps per epoch')
parser.add_argument('--learning_rate', type=float, default=1e-4, help='Learning rate for training')
parser.add_argument('--weight_decay', type=float, default=1e-4, help='Weight decay for training')
parser.add_argument('--decay_step_size', type=int, default=10, help='Step size for learning rate decay')
parser.add_argument('--decay_rate', type=float, default=0.9, help='Decay rate for learning rate decay')
parser.add_argument('--rec_activation', type=str, default='relu', help='Recurrent activation function for the RNN')
parser.add_argument('--out_activation', type=str, default='softmax', help='Output activation function for the RNN')
parser.add_argument('--restore', type=str, default=None, help='Timestamp of the saved model to restore')
parser.add_argument('--preloaded_data', type=lambda x: (str(x).lower() == 'true'), default=False, help='Whether to use preloaded data')
parser.add_argument('--save', type=lambda x: (str(x).lower() == 'true'), default=True, help='Whether to save the model')
parser.add_argument('--save_every', type=int, default=100, help='Save the model every n epochs')
parser.add_argument('--loss', type=str, default='CE', help='Loss function for training')
parser.add_argument('--is_wandb', type=lambda x: (str(x).lower() == 'true'), default=False, help='Whether to use wandb for logging')
parser.add_argument('--mode', type=str, default='train', help='Mode for running the model; input run folder name for model inspection')
parser.add_argument('--normalize_pc', type=str, default='softmax', help='Transformation applied to place cells in generation')
parser.add_argument('--truncating', type=int, default=0, help='Truncating steps for BPTT')
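# Because the parser is created with fromfile_prefix_chars="@", the flags above can
# also be read from a plain-text file with one argument token per line (standard
# argparse behaviour), e.g.
#     python place_cell_rnn.py @my_args.txt
# where my_args.txt (an illustrative name) contains lines such as "--Np" and "512".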
options = parser.parse_args()
if options.mode == 'train':
    # save directory
    now = time.strftime('%b-%d-%Y-%H-%M-%S', time.gmtime(time.time()))
    if options.restore is not None:
        now = options.restore
    options.save_dir = os.path.join('./results/rnn', now)
    if not os.path.exists(options.save_dir):
        os.makedirs(options.save_dir)
    print('Saving to:', options.save_dir)
    utils.save_options_to_json(options, os.path.join(options.save_dir, 'configs.json'))

    # define place cells, trajectory generator, model, and trainer
    place_cell = PlaceCells(options)
    generator = TrajectoryGenerator(options, place_cell)
    model = RNN(options, place_cell).to(options.device)
    trainer = Trainer(options, model, generator, place_cell, restore=options.restore)
    trainer.train(preloaded_data=options.preloaded_data, save=options.save)

    plot_place_cells(place_cell, options, res=30)
    plot_2d_performance(place_cell, generator, options, trainer)
    rate_map = compute_ratemaps(model, trainer, generator, options, res=20, n_avg=200, Ng=options.Ng)
    plot_2d_ratemaps(rate_map, options, n_col=4)
    plot_loss_err(trainer, options)
else:
    now = options.mode
    save_dir = os.path.join('./results/rnn', now)

    # load the configuration file to args
    t_args = argparse.Namespace()
    d = json.load(open(os.path.join(save_dir, 'configs.json')))
    for k in list(d.keys()):
        if k == '_get_args' or k == '_get_kwargs':
            del d[k]
    t_args.__dict__.update(d)
    options = parser.parse_args(namespace=t_args)
    print(options.__dict__)

    # load the model
    ckpt = torch.load(os.path.join(save_dir, 'models', 'most_recent_model.pth'))
    options.save_dir = save_dir
    place_cell = PlaceCells(options)
    model = RNN(options, place_cell).to(options.device)
    model.load_state_dict(ckpt)

    print('Plotting weights...')
    Wr = model.RNN.weight_hh_l0.detach().cpu().numpy()
    plot_weights(Wr, options)

    generator = TrajectoryGenerator(options, place_cell)
    trainer = Trainer(options, model, generator, place_cell, restore=False)

    print('Generating rate maps...')
    rate_map = compute_ratemaps(
        model, trainer, generator, options, res=30, n_avg=200, Ng=options.Ng
    )

    # calculate grid scores
    print('Generating low resolution rate maps...')
    lo_res = 20
    rate_map_lo_res = compute_ratemaps(
        model, trainer, generator, options, res=lo_res, n_avg=200, Ng=options.Ng
    )

    # scores are already sorted in descending order
    print('Calculating grid scores...')
    idx, scores = compute_grid_scores(lo_res, rate_map_lo_res, options)  # descending order

    # select the top grid cells
    plot_all_ratemaps(rate_map[idx], options, scores)

    # save scores
    np.save(os.path.join(save_dir, 'grid_scores.npy'), scores)
    # save top 64 grid cells
    np.save(os.path.join(save_dir, 'top64_grid_cells.npy'), rate_map[idx[:64]])

    # border score
    print('Calculating border scores...')
    idx_border, scores_border = compute_border_scores(lo_res, rate_map_lo_res, options)
    plot_all_ratemaps(rate_map[idx_border], options, scores_border, dir='all_maps_border')
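
# Example invocations (illustrative timestamps; the companion modules imported above,
# plus utils.py, are assumed to be importable from the working directory):
#     python place_cell_rnn.py --n_epochs 200 --batch_size 500      # train a new model
#     python place_cell_rnn.py --mode May-01-2025-12-00-00          # inspect a saved run in ./results/rnn/<timestamp>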