Skip to content

Commit

Permalink
Address comments and fix test script
Browse files Browse the repository at this point in the history
  • Loading branch information
carolinafernandezp committed Aug 8, 2023
1 parent 3d18c4c commit 3828c1b
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 12 deletions.
18 changes: 16 additions & 2 deletions hnn_core/opt_toy_example/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,14 @@ def __init__(self, net, constraints, solver, obj_fun):
>>>>>>> 1a7e98b (Address comments for more generalized routine)
=======
tstop, scale_factor=1., smooth_window_len=None):
<<<<<<< HEAD
>>>>>>> 46f1268 (Add tests and address comments)
=======
if net.external_drives:
raise ValueError("The current Network instance has external " +
"drives, provide a Network object with no " +
"drives.")
>>>>>>> 51112fc (Address comments and fix test script)
self.net = net
self.constraints = constraints
self._set_params = set_params
Expand All @@ -73,9 +80,16 @@ def __init__(self, net, constraints, solver, obj_fun):
self._assemble_constraints = _assemble_constraints_cobyla
self._run_opt = _run_opt_cobyla
<<<<<<< HEAD
<<<<<<< HEAD
=======
else:
raise ValueError("solver must be 'bayesian' or 'cobyla'")
>>>>>>> 51112fc (Address comments and fix test script)
# Response to be optimized
if obj_fun == 'evoked':
self.obj_fun = _rmse_evoked
else:
raise ValueError("obj_fun must be 'evoked'")
self.scale_factor = scale_factor
self.smooth_window_len = smooth_window_len
self.tstop = tstop
Expand Down Expand Up @@ -440,8 +454,8 @@ def _get_initial_params(constraints):

initial_params = dict()
for cons_key in constraints:
initial_params.update({cons_key: (constraints[cons_key][0] +
constraints[cons_key][1])/2})
initial_params.update({cons_key: ((constraints[cons_key][0] +
constraints[cons_key][1]))/2})

return initial_params

Expand Down
5 changes: 3 additions & 2 deletions hnn_core/opt_toy_example/optimize_evoked.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@
of the model simulation to match an experimental dipole waveform.
"""

# Authors: Blake Caldwell <[email protected]>
# Authors: Carolina Fernandez <[email protected]>
# Nick Tolley <[email protected]>
# Ryan Thorpe <[email protected]>
# Mainak Jas <[email protected]>
# Carolina Fernandez <[email protected]>

import os.path as op

Expand Down
18 changes: 10 additions & 8 deletions hnn_core/opt_toy_example/test_optimize_evoked.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# Authors: Mainak Jas <[email protected]>
# Carolina Fernandez <[email protected]>
# Authors: Carolina Fernandez <[email protected]>

from hnn_core import jones_2009_model, simulate_dipole
from general import Optimizer # change path***
Expand All @@ -8,12 +7,14 @@
def _optimize_evoked(solver):
"""Test running the full routine in a reduced network."""

tstop = 5.
tstop = 10.
n_trials = 1

# simulate a dipole to establish ground-truth drive parameters
net_orig = jones_2009_model()
mu_orig = 6.
net_orig._N_pyr_x = 3
net_orig._N_pyr_y = 3
mu_orig = 2.
weights_ampa = {'L2_basket': 0.5,
'L2_pyramidal': 0.5,
'L5_basket': 0.5,
Expand All @@ -31,6 +32,8 @@ def _optimize_evoked(solver):

# define set_params function and constraints
net_offset = jones_2009_model()
net_offset._N_pyr_x = 3
net_offset._N_pyr_y = 3

def set_params(net_offset, param_dict):
weights_ampa = {'L2_basket': 0.5,
Expand All @@ -48,8 +51,7 @@ def set_params(net_offset, param_dict):
synaptic_delays=synaptic_delays)

# define constraints
mu_offset = 4. # initial time-shifted drive
mu_range = (2, 8)
mu_range = (1, 6)
constraints = dict()
constraints.update({'mu_offset': mu_range})

Expand All @@ -58,9 +60,9 @@ def set_params(net_offset, param_dict):
obj_fun='evoked', tstop=tstop)
optim.fit(dpl_orig.data['agg'])

opt_param = optim.opt_params
opt_param = optim.opt_params[0]
# the optimized parameter is in the range
assert opt_param[0] in range(mu_range[0], mu_range[1]), "Optimized parameter is not in user-defined range"
assert mu_range[0] <= opt_param <= mu_range[1], "Optimized parameter is not in user-defined range"

obj = optim.obj
# the number of returned rmse values should be the same as max_iter
Expand Down

0 comments on commit 3828c1b

Please sign in to comment.