Skip to content

Commit

Permalink
updating figure scripts
Browse files Browse the repository at this point in the history
  • Loading branch information
carsen-stringer committed Apr 30, 2023
1 parent d767135 commit b40a61e
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 113 deletions.
7 changes: 4 additions & 3 deletions paper/fig3.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ def panels_scaling(data_path, dbs, grid, trans, il):

vis = np.array([db["visual"] for db in dbs])
ls = ['-','--']
lstr = ['--', '-- ']
lstr = [' -', '--']
mstr = ['network', 'linear']
for k in range(2):
for j, inds in enumerate([vis, ~vis]):
Expand Down Expand Up @@ -278,7 +278,7 @@ def panels_scaling(data_path, dbs, grid, trans, il):
)
y = 0.33 - i * 0.14
ax.text(
x+0.08, y, mstr[i], transform=ax.transAxes,
x+0.07, y, mstr[i], transform=ax.transAxes,
)

if j == 0:
Expand Down Expand Up @@ -343,7 +343,8 @@ def panels_cum_varexp(data_path, dbs, axs):
ax.fill_between(
np.arange(1, 129), vem + ves, vem - ves, color=colors[j//2], alpha=0.25
)
print(f"{lbls[j//2]}, {mstr[j%2]}, 1st pc= {vem[0]}")
print(f"{k}, {lbls[j//2]}, {mstr[j%2]}, 1st pc= {vem[0]}")
print(f"{k}, {lbls[j//2]}, {mstr[j%2]}, 64th pc= {vem[63]}")
xt = 2 ** np.arange(0, 10, 2)
ax.set_xticks(xt)
ax.set_xticklabels([str(x) for x in xt])
Expand Down
119 changes: 9 additions & 110 deletions paper/suppfigs.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,8 @@ def varexp_ranks(data_path, dbs, evals=None, save_fig=False):
)
ax.set_xlabel("ranks")
ax.set_xscale("log")
ax.set_xticks([1,4,16,64,128])
ax.set_xticklabels(["1", "4", "16", "32", "128"])
ax.set_xticks([1,4,16,64])
ax.set_xticklabels(["1", "4", "16", "64"])
ax.set_xlim([1,128])
ax.set_ylim([0, 38])

Expand Down Expand Up @@ -227,51 +227,17 @@ def model_complexity_AP(data_path, dbs, save_fig=False):
ve_latents[iexp] = (d["varexps_latents_neurons"].mean(axis=0) / ve_expl) * 100
ve_filts[iexp] = (d["varexps_filts_neurons"].mean(axis=0) / ve_expl) * 100

mstrs = [f"{db['mname']}_{db['datexp']}_{db['blk']}" for db in dbs]
nbins = 5
improvement = np.zeros((len(dbs), nbins))
ve_all = np.zeros((len(dbs), nbins))
xposs = []
yposs = []
ccol = []
for iexp, mstr in enumerate(mstrs):
dat = np.load(f"{data_path}/neural_data/spont_{mstr}.npz")
inds = dat["xpos"].argsort()
nneus = np.linspace(0, len(inds), nbins + 1).astype(int)
ve_net = np.load(f"{data_path}/proc/neuralpred/{mstr}_net_pred_test.npz")[
"varexp_neurons"
][:, 1]
ve_lin = np.load(f"{data_path}/proc/neuralpred/{mstr}_rrr_pred_test.npz")[
"varexp_neurons"
][1]

# ve_svd = np.load(f"{data_path}/proc/neuralpred/{mstr}_svd_pred_test.npz")[
# "varexp_neurons"
# ]
for i in range(nbins):
ineu = inds[nneus[i] : nneus[i + 1]]
ve0 = ve_net[ineu].mean()
ve1 = ve_lin[ineu].mean()
improvement[iexp, i] = ((ve0 - ve1) / ve1) * 100

if iexp == 2 or iexp == 10:
cc = ((ve_net - ve_lin) / (ve_lin)) * 100
igood = ve_lin > 1e-2
xposs.append(dat["xpos"][igood])
yposs.append(dat["ypos"][igood])
ccol.append(cc[igood])

fig = plt.figure(figsize=(12, 7))
fig = plt.figure(figsize=(12, 4))
yratio = 12 / 4
trans = mtransforms.ScaledTranslation(-40 / 72, 20 / 72, fig.dpi_scale_trans)
grid = plt.GridSpec(
2,
1,
6,
figure=fig,
left=0.08,
right=0.98,
top=0.9,
bottom=0.13,
right=0.9,
top=0.8,
bottom=0.35,
wspace=0.6,
hspace=1.5,
)
Expand Down Expand Up @@ -371,78 +337,11 @@ def model_complexity_AP(data_path, dbs, save_fig=False):
ve_latents[iexp] = (d["varexps_latents_neurons"].mean(axis=0) / ve_expl) * 100
ve_filts[iexp] = (d["varexps_filts_neurons"].mean(axis=0) / ve_expl) * 100


for i in range(2):
xpos, ypos, c = xposs[i], -1 * yposs[i], ccol[i]
ax = plt.subplot(grid[1, i])
if i==0:
il = plot_label(ltr, il, ax, trans, fs_title)
if 1:
pos = ax.get_position()
ax.axis("off")
pos = [pos.x0, pos.y0, pos.width, pos.height]
ax = fig.add_axes(
[pos[0] - 0.04 + 0.02*i, pos[1] - 0.02-0.02*i, pos[2] + 0.03+0.03*i, pos[3] + 0.03+0.03*i]
)
if i==0:
ax.set_title(" net prediction improvement over linear\n ", fontsize="medium")
add_apml(ax, xpos, ypos)

im = ax.scatter(
ypos,
xpos,
c=c,
vmin=-300,
vmax=300,
cmap="bwr",
s=1,
alpha=1,
rasterized=True,
)
ax.axis("square")
ax.axis("off")
if i == 1:
cx = pos[0]+pos[2]*0.65
cy = pos[1]+pos[3]*0.55
cbar = plt.colorbar(im, label="",
cax=grid.figure.add_axes([cx, cy, 0.005, 0.05]),
)
cbar.set_label(label='% improvement', size='small')
cbar.ax.tick_params(labelsize=10)

colors = [viscol, smcol]
vis = np.array([db["visual"] for db in dbs])
trans = mtransforms.ScaledTranslation(-50 / 72, 20 / 72, fig.dpi_scale_trans)
for i, inds in enumerate([vis, ~vis]):
ax = plt.subplot(grid[1, 3+i])
pos = ax.get_position()
ax.axis("off")
pos = [pos.x0, pos.y0, pos.width, pos.height]
ax = fig.add_axes([pos[0] - 0.07, pos[1], pos[2], pos[3]])
impr = improvement[inds]
plt.plot(impr.T, color=colors[i], alpha=0.5)
plt.errorbar(
np.arange(0, impr.shape[1]),
impr.mean(axis=0),
impr.std(axis=0) / impr.shape[0] ** 0.5,
color=colors[i],
lw=3,
)
plt.ylim([0, 300])
ax.set_xticks([0, 4])
ax.set_xticklabels([" posterior", "anterior "])
if i == 0:
il = plot_label(ltr, il, ax, trans, fs_title)
ax.set_ylabel("% improvement")
ax.set_title("visual", fontsize="medium")
else:
ax.set_title("sensorimotor", fontsize="medium")

ax = plt.subplot(grid[1, 5])
ax = plt.subplot(grid[0, 5])
pos = ax.get_position()
ax.axis("off")
pos = [pos.x0, pos.y0, pos.width, pos.height]
ax = fig.add_axes([pos[0] - 0.03, pos[1], pos[2], pos[3]])
ax = fig.add_axes([pos[0] + 0.03, pos[1], pos[2], pos[3]])
mstrs = [f"{db['mname']}_{db['datexp']}_{db['blk']}" for db in dbs]
ve_all = []
for j, mstr in enumerate(mstrs):
Expand Down

0 comments on commit b40a61e

Please sign in to comment.