-
Notifications
You must be signed in to change notification settings - Fork 11
/
Copy pathDemo.py
121 lines (109 loc) · 4.52 KB
/
Demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
# Example Script
from __future__ import division
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from numpy.random import randn, randint
from numpy import zeros, transpose, min, max, array, prod, percentile
from scipy.io import loadmat
from scipy.ndimage.filters import gaussian_filter
from sys import argv
from BlockGroupLasso import gaussian_group_lasso, GetCenters
from BlockLocalNMF import LocalNMF
# Select the demo dataset from the first command-line argument (default 1):
#   1 -> synthetic 2D movie, 2 -> experimental 2D data, 3 -> experimental 3D data
data_source = 1 if len(argv) == 1 else int(argv[1])
plt.close('all')

# Fetch Data
# Every branch defines the names used by the rest of the script:
#   data        -- the movie; axis 0 is iterated as frames further below
#   sig         -- per-spatial-axis neuron size passed to the detectors
#   TargetRange -- passed to gaussian_group_lasso as TargetAreaRatio
#   NonNegative -- non-negativity flag passed to both detectors
#   lam         -- regularization weight passed to gaussian_group_lasso
if data_source == 1:  # generate 2D model data
    T = 30  # duration of the simulation
    sz = (150, 100)  # size of image
    sig = (5, 5)  # neurons size
    noise = 0.1 * randn(*((T,) + sz))  # additive Gaussian background noise
    spikes = zeros((T,) + sz)
    N = 15  # number of neurons
    for i in range(N):
        # Put each neuron at a random pixel and give it a random
        # non-negative amplitude in every frame.
        ind = tuple(randint(x) for x in sz)
        for j in range(T):
            spikes[(j,) + ind] = abs(randn())
    # Blur spatially only (sigma 0 along axis 0) so point sources become blobs.
    data = noise + 10 * gaussian_filter(spikes, (0,) + sig)
    # Expected fraction of the image covered by neurons, allowed +/-20%.
    TargetArea = N * prod(2 * array(sig)) / prod(sz)
    TargetRange = [TargetArea * 0.8, TargetArea * 1.2]
    NonNegative = True
    lam = 1
elif data_source == 2:  # Use experimental 2D data
    mat = loadmat('Datasets/data_exp2D')
    # Move the stored array's last axis to the front (the frame axis).
    data = transpose(mat['data'], [2, 0, 1])
    sig = (6, 6)  # estimated neurons size
    N = 40  # estimated number of neurons
    TargetArea = N * prod(2 * array(sig)) / prod(data[0, :, :].shape)
    TargetRange = [TargetArea * 0.8, TargetArea * 1.2]
    NonNegative = True
    lam = 1
elif data_source == 3:  # Use experimental 3D data
    mat = loadmat('Datasets/data_exp3D')
    # Move the stored array's last axis to the front (the frame axis).
    data = transpose(mat['data'], [3, 0, 1, 2])
    sig = (2, 2, 2)  # neurons size
    TargetRange = [0.005, 0.015]
    NonNegative = True
    lam = 0.001
else:
    # Fail early instead of with a NameError on `data` further down.
    raise ValueError("data_source must be 1, 2 or 3, got %r" % data_source)
# Run source detection algorithms
# Group-lasso source detection on the full movie; its output x is summarized
# per pixel below by a 95th-percentile projection over axis 0.
x = gaussian_group_lasso(data, sig, lam,
                         NonNegative=NonNegative, TargetAreaRatio=TargetRange,
                         verbose=True, adaptBias=True)
# Optional variant: max-pool the movie over blocks of 5 frames first and scale
# lam down by the same factor. Floor division is required here: this file
# imports `division` from __future__, so `len(data) / 5 * 5` would be a float
# and is not a valid slice bound.
# x = gaussian_group_lasso(data[:len(data) // 5 * 5].reshape((-1, 5) + data.shape[1:]).max(1), sig, lam / 5.,
#                          NonNegative=NonNegative, TargetAreaRatio=TargetRange, verbose=True, adaptBias=True)
pic_x = percentile(x, 95, axis=0)
pic_data = percentile(data, 95, axis=0)
# centers extracted from fista output using RegionalMax
cent = GetCenters(pic_x)
# The last row of array(cent) is dropped before transposing, so LocalNMF only
# receives the leading rows -- presumably the per-axis coordinates, with the
# final row holding something else (TODO confirm against GetCenters).
MSE_array, shapes, activity, boxes = LocalNMF(
    data, (array(cent)[:-1]).T, sig,
    NonNegative=NonNegative, verbose=True, adaptBias=True)
L = len(cent[0])  # number of detected neurons
# Denoised movie: reconstruct from the first L components only.
denoised_data = activity[:L].T.dot(shapes[:L].reshape(L, -1)).reshape(data.shape)
pic_denoised = percentile(denoised_data, 95, axis=0)
# Residual: subtract the reconstruction from *all* returned components,
# including any beyond the first L.
residual = data - activity.T.dot(shapes.reshape(len(shapes), -1)).reshape(data.shape)
# Plot Results
# Three panels of per-pixel 95th-percentile projections, each overlaid with
# the detected centers. For 3D data (data_source 3) the projection is further
# reduced by a max over its last axis so it can be shown as an image.
plt.figure(figsize=(12, 4. * data.shape[1] / data.shape[2]))
ax = plt.subplot(131)
ax.scatter(cent[1], cent[0], s=7 * sig[1], marker='o', c='white')
# No plt.hold(True): it was removed in matplotlib 3.0, and overlaying further
# artists on the same axes is already the default behavior.
ax.set_title('Data + centers')
ax.imshow(pic_data if data_source != 3 else pic_data.max(-1))
ax2 = plt.subplot(132)
ax2.scatter(cent[1], cent[0], s=7 * sig[1], marker='o', c='white')
ax2.imshow(pic_x if data_source != 3 else pic_x.max(-1))
ax2.set_title('Inferred x')
ax3 = plt.subplot(133)
ax3.scatter(cent[1], cent[0], s=7 * sig[1], marker='o', c='white')
ax3.imshow(pic_denoised if data_source != 3 else pic_denoised.max(-1))
ax3.set_title('Denoised data')
plt.show()
# Convergence trace returned by LocalNMF.
plt.figure()
plt.plot(MSE_array)
plt.xlabel('Iteration')
plt.ylabel('MSE')
plt.show()
# Video Results
# Three side-by-side panels (data, residual, denoised) that share one colour
# scale taken from the raw data. Frame ii = 0 is drawn here; update() swaps in
# later frames. For 3D data (data_source 3) each frame is displayed as its
# maximum over the last axis.
fig = plt.figure(figsize=(12, 4. * data.shape[1] / data.shape[2]))
mi, ma = min(data), max(data)
ii = 0

ax = plt.subplot(1, 3, 1)
ax.set_title('Data + centers')
ax.scatter(cent[1], cent[0], s=7 * sig[1], marker='o', c='white')
im = ax.imshow(data[ii].max(-1) if data_source == 3 else data[ii],
               vmin=mi, vmax=ma)

ax2 = plt.subplot(1, 3, 2)
ax2.set_title('Residual')
ax2.scatter(cent[1], cent[0], s=7 * sig[1], marker='o', c='white')
im2 = ax2.imshow(residual[ii].max(-1) if data_source == 3 else residual[ii],
                 vmin=mi, vmax=ma)

ax3 = plt.subplot(1, 3, 3)
ax3.set_title('Denoised')
ax3.scatter(cent[1], cent[0], s=7 * sig[1], marker='o', c='white')
im3 = ax3.imshow(denoised_data[ii].max(-1) if data_source == 3 else denoised_data[ii],
                 vmin=mi, vmax=ma)
def update(ii):
    """Display frame ``ii`` of the data, residual and denoised movies.

    Replaces the image data of the module-level artists ``im``, ``im2`` and
    ``im3``. When ``data_source == 3`` each frame is reduced by a max over its
    last axis, matching how the first frame was drawn.

    Returns the three modified artists. FuncAnimation requires this when
    blitting is enabled and ignores it when ``blit=False``.
    """
    def frame(movie):
        # 2D data: the frame itself; 3D data: max projection over last axis.
        return movie[ii].max(-1) if data_source == 3 else movie[ii]

    im.set_data(frame(data))
    im2.set_data(frame(residual))
    im3.set_data(frame(denoised_data))
    return im, im2, im3
# Play the movie once, one frame per 30 ms, redrawing the full figure each
# step. The animation object is bound to a name on purpose: matplotlib's
# animation docs require keeping a reference, otherwise it may be
# garbage-collected and stop.
ani = animation.FuncAnimation(
    fig,
    update,
    frames=len(data),
    interval=30,
    blit=False,
    repeat=False
)
plt.show()