Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pytests round 2 #851

Merged
merged 6 commits into from
Mar 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion config/add_dj_user.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@
"This script is deprecated. "
+ "Use spyglass.utils.database_settings.DatabaseSettings instead."
)
DatabaseSettings(user_name=sys.argv[1]).add_dj_user()
DatabaseSettings(user_name=sys.argv[1]).add_user(check_exists=True)
8 changes: 5 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ addopts = [
# "--pdb", # drop into debugger on failure
"-p no:warnings",
# "--no-teardown", # don't teardown the database after tests
"--quiet-spy", # don't show logging from spyglass
# "--quiet-spy", # don't show logging from spyglass
"--show-capture=no",
"--pdbcls=IPython.terminal.debugger:TerminalPdb", # use ipython debugger
"--cov=spyglass",
Expand All @@ -146,13 +146,15 @@ omit = [ # which submodules have no tests
"*/decoding/*",
"*/figurl_views/*",
# "*/lfp/*",
"*/linearization/*",
# "*/linearization/*",
"*/lock/*",
"*/position/*",
"*/mua/*",
# "*/position/*",
"*/ripple/*",
"*/sharing/*",
"*/spikesorting/*",
# "*/utils/*",
"settings.py",
]

[tool.ruff] # CB: Propose replacing flake8 with ruff to delete setup.cfg
Expand Down
18 changes: 10 additions & 8 deletions src/spyglass/common/common_behav.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def populate(self, keys=None):
"""
if not isinstance(keys, list):
keys = [keys]
if isinstance(keys[0], dj.Table):
if isinstance(keys[0], (dj.Table, dj.expression.QueryExpression)):
keys = [k for tbl in keys for k in tbl.fetch("KEY", as_dict=True)]
for key in keys:
nwb_file_name = key.get("nwb_file_name")
Expand All @@ -60,10 +60,10 @@ def populate(self, keys=None):
"PositionSource.populate is an alias for a non-computed table "
+ "and must be passed a key with nwb_file_name"
)
self.insert_from_nwbfile(nwb_file_name)
self.insert_from_nwbfile(nwb_file_name, skip_duplicates=True)

@classmethod
def insert_from_nwbfile(cls, nwb_file_name):
def insert_from_nwbfile(cls, nwb_file_name, skip_duplicates=False) -> None:
"""Add intervals to ItervalList and PositionSource.

Given an NWB file name, get the spatial series and interval lists from
Expand Down Expand Up @@ -111,9 +111,11 @@ def insert_from_nwbfile(cls, nwb_file_name):
)

with cls.connection.transaction:
IntervalList.insert(intervals)
cls.insert(sources)
cls.SpatialSeries.insert(spat_series)
IntervalList.insert(intervals, skip_duplicates=skip_duplicates)
cls.insert(sources, skip_duplicates=skip_duplicates)
cls.SpatialSeries.insert(
spat_series, skip_duplicates=skip_duplicates
)

# make map from epoch intervals to position intervals
populate_position_interval_map_session(nwb_file_name)
Expand Down Expand Up @@ -305,7 +307,7 @@ def make(self, key):
"Unable to import StateScriptFile: no processing module named "
+ '"associated_files" found in {nwb_file_name}.'
)
return
return # See #849

for associated_file_obj in associated_files.data_interfaces.values():
if not isinstance(
Expand Down Expand Up @@ -545,7 +547,7 @@ def _no_transaction_make(self, key):
# Check that each pos interval was matched to only one epoch
if len(matching_pos_intervals) != 1:
# TODO: Now that populate_all accept errors, raise here?
logger.error(
logger.warning(
f"Found {len(matching_pos_intervals)} pos intervals for {key}; "
+ f"{no_pop_msg}\n{matching_pos_intervals}"
)
Expand Down
2 changes: 1 addition & 1 deletion src/spyglass/common/common_dio.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def make(self, key):
"No conforming behavioral events data interface found in "
+ f"{nwb_file_name}\n"
)
return
return # See #849

# Times for these events correspond to the valid times for the raw data
key["interval_list_name"] = (
Expand Down
4 changes: 2 additions & 2 deletions src/spyglass/common/common_ephys.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ def create_from_config(cls, nwb_file_name: str):
nwbf = get_nwb_file(nwb_file_abspath)
config = get_config(nwb_file_abspath)
if "Electrode" not in config:
return
return # See #849

# map electrode id to dict of electrode information from config YAML
electrode_dicts = {
Expand Down Expand Up @@ -341,7 +341,7 @@ def make(self, key):
"Unable to import SampleCount: no data interface named "
+ f'"sample_count" found in {nwb_file_name}.'
)
return
return # see #849
key["sample_count_object_id"] = sample_count.object_id
self.insert1(key)

Expand Down
24 changes: 16 additions & 8 deletions src/spyglass/common/signal_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,22 @@


def hilbert_decomp(lfp_band_object, sampling_rate=1):
"""generates the analytical decomposition of the signals in the lfp_band_object

:param lfp_band_object: bandpass filtered LFP
:type lfp_band_object: pynwb electrical series
:param sampling_rate: bandpass filtered LFP sampling rate (defaults to 1; only used for instantaneous frequency)
:type sampling_rate: int
:return: envelope, phase, frequency
:rtype: pynwb electrical series objects
"""Generates analytical decomposition of signals in the lfp_band_object

NOTE: This function is not currently used in the pipeline.

Parameters
----------
lfp_band_object : pynwb.ecephys.ElectricalSeries
bandpass filtered LFP
sampling_rate : int, optional
bandpass filtered LFP sampling rate
(defaults to 1; only used for instantaneous frequency)

Returns
-------
envelope : pynwb.ecephys.ElectricalSeries
envelope of the signal
"""
analytical_signal = signal.hilbert(lfp_band_object.data, axis=0)

Expand Down
44 changes: 21 additions & 23 deletions src/spyglass/lfp/analysis/v1/lfp_band.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,26 +80,24 @@ def set_lfp_band_electrodes(
available_electrodes = query.fetch("electrode_id")
if not np.all(np.isin(electrode_list, available_electrodes)):
raise ValueError(
"All elements in electrode_list must be valid electrode_ids in the LFPElectodeGroup table"
"All elements in electrode_list must be valid electrode_ids in"
+ " the LFPElectodeGroup table: "
+ f"{electrode_list} not in {available_electrodes}"
)
# sampling rate
lfp_sampling_rate = LFPOutput.merge_get_parent(lfp_key).fetch1(
"lfp_sampling_rate"
)
decimation = lfp_sampling_rate // lfp_band_sampling_rate
if lfp_sampling_rate // decimation != lfp_band_sampling_rate:
raise ValueError(
f"lfp_band_sampling rate {lfp_band_sampling_rate} is not an integer divisor of lfp "
f"samping rate {lfp_sampling_rate}"
)
# filter
filter_query = FirFilterParameters() & {
"filter_name": filter_name,
"filter_sampling_rate": lfp_sampling_rate,
}
if not filter_query:
raise ValueError(
f"filter {filter_name}, sampling rate {lfp_sampling_rate} is not in the FirFilterParameters table"
f"Filter {filter_name}, sampling rate {lfp_sampling_rate} is "
+ "not in the FirFilterParameters table"
)
# interval_list
interval_query = IntervalList() & {
Expand All @@ -108,22 +106,23 @@ def set_lfp_band_electrodes(
}
if not interval_query:
raise ValueError(
f"interval list {interval_list_name} is not in the IntervalList table; the list must be "
"added before this function is called"
f"interval list {interval_list_name} is not in the IntervalList"
" table; the list must be added before this function is called"
)
# reference_electrode_list
if len(reference_electrode_list) != 1 and len(
reference_electrode_list
) != len(electrode_list):
raise ValueError(
"reference_electrode_list must contain either 1 or len(electrode_list) elements"
"reference_electrode_list must contain either 1 or "
+ "len(electrode_list) elements"
)
# add a -1 element to the list to allow for the no reference option
available_electrodes = np.append(available_electrodes, [-1])
if not np.all(np.isin(reference_electrode_list, available_electrodes)):
raise ValueError(
"All elements in reference_electrode_list must be valid electrode_ids in the LFPSelection "
"table"
"All elements in reference_electrode_list must be valid "
"electrode_ids in the LFPSelection table"
)

# make a list of all the references
Expand Down Expand Up @@ -204,8 +203,8 @@ def make(self, key):
"interval_list_name": interval_list_name,
}
).fetch1("valid_times")
# the valid_times for this interval may be slightly beyond the valid times for the lfp itself,
# so we have to intersect the two lists
# the valid_times for this interval may be slightly beyond the valid
# times for the lfp itself, so we have to intersect the two lists
lfp_valid_times = (
IntervalList()
& {
Expand All @@ -228,7 +227,8 @@ def make(self, key):

# load in the timestamps
timestamps = np.asarray(lfp_object.timestamps)
# get the indices of the first timestamp and the last timestamp that are within the valid times
# get the indices of the first timestamp and the last timestamp that
# are within the valid times
included_indices = interval_list_contains_ind(
lfp_band_valid_times, timestamps
)
Expand Down Expand Up @@ -267,11 +267,6 @@ def make(self, key):
& {"filter_name": filter_name}
& {"filter_sampling_rate": filter_sampling_rate}
).fetch(as_dict=True)
if len(filter) == 0:
raise ValueError(
f"Filter {filter_name} and sampling_rate {lfp_band_sampling_rate} does not exit in the "
"FirFilterParameters table"
)

filter_coeff = filter[0]["filter_coeff"]
if len(filter_coeff) == 0:
Expand Down Expand Up @@ -378,7 +373,9 @@ def fetch1_dataframe(self, *attrs, **kwargs):
)

def compute_analytic_signal(self, electrode_list: list[int], **kwargs):
"""Computes the hilbert transform of a given LFPBand signal using scipy.signal.hilbert
CBroz1 marked this conversation as resolved.
Show resolved Hide resolved
"""Computes the hilbert transform of a given LFPBand signal

Uses scipy.signal.hilbert to compute the hilbert transform of the signal

Parameters
----------
Expand All @@ -393,7 +390,7 @@ def compute_analytic_signal(self, electrode_list: list[int], **kwargs):
Raises
------
ValueError
If any electrodes passed to electrode_list are invalid for the dataset
If items in electrode_list are invalid for the dataset
"""

filtered_band = self.fetch_nwb()[0]["lfp_band"]
Expand All @@ -402,7 +399,8 @@ def compute_analytic_signal(self, electrode_list: list[int], **kwargs):
)
if len(electrode_list) != np.sum(electrode_index):
raise ValueError(
"Some of the electrodes specified in electrode_list are missing in the current LFPBand table."
"Some of the electrodes specified in electrode_list are missing"
+ " in the current LFPBand table."
)
analytic_signal_df = pd.DataFrame(
hilbert(filtered_band.data[:, electrode_index], axis=0),
Expand Down
8 changes: 4 additions & 4 deletions src/spyglass/lfp/lfp_imported.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import datajoint as dj

from spyglass.common.common_interval import IntervalList
from spyglass.common.common_nwbfile import AnalysisNwbfile
from spyglass.common.common_session import Session
from spyglass.lfp.lfp_electrode import LFPElectrodeGroup
from spyglass.common.common_interval import IntervalList # noqa: F401
from spyglass.common.common_nwbfile import AnalysisNwbfile # noqa: F401
from spyglass.common.common_session import Session # noqa: F401
from spyglass.lfp.lfp_electrode import LFPElectrodeGroup # noqa: F401
from spyglass.utils.dj_mixin import SpyglassMixin

schema = dj.schema("lfp_imported")
Expand Down
7 changes: 4 additions & 3 deletions src/spyglass/lfp/v1/lfp.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,8 @@ def make(self, key):
"target_sampling_rate"
)

# to get the list of valid times, we need to combine those from the user with those from the
# raw data
# to get the list of valid times, we need to combine those from the
# user with those from the raw data
orig_key = copy.deepcopy(key)
orig_key["interval_list_name"] = key["target_interval_list_name"]
user_valid_times = (IntervalList() & orig_key).fetch1("valid_times")
Expand Down Expand Up @@ -120,7 +120,8 @@ def make(self, key):
"LFP: no filter found with data sampling rate of "
+ f"{sampling_rate}"
)
return None
return None # See #849

# get the list of selected LFP Channels from LFPElectrode
electrode_keys = (LFPElectrodeGroup.LFPElectrode & key).fetch("KEY")
electrode_id_list = list(k["electrode_id"] for k in electrode_keys)
Expand Down
4 changes: 3 additions & 1 deletion src/spyglass/lfp/v1/lfp_artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,13 @@
class LFPArtifactDetectionParameters(SpyglassMixin, dj.Manual):
definition = """
# Parameters for detecting LFP artifact times within a LFP group.
artifact_params_name: varchar(200)
artifact_params_name: varchar(64)
---
artifact_params: blob # dictionary of parameters
"""

# See #630, #664. Excessive key length.

def insert_default(self):
"""Insert the default artifact parameters."""
diff_params = [
Expand Down
18 changes: 9 additions & 9 deletions src/spyglass/position/v1/position_trodes_position.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@
from spyglass.common.common_nwbfile import AnalysisNwbfile
from spyglass.common.common_position import IntervalPositionInfo
from spyglass.position.v1.dlc_utils import check_videofile, get_video_path
from spyglass.utils import logger
from spyglass.utils.dj_mixin import SpyglassMixin
from spyglass.utils import SpyglassMixin, logger

schema = dj.schema("position_v1_trodes_position")

Expand Down Expand Up @@ -158,7 +157,7 @@ class TrodesPosV1(SpyglassMixin, dj.Computed):
"""

def make(self, key):
print(f"Computing position for: {key}")
logger.info(f"Computing position for: {key}")
orig_key = copy.deepcopy(key)

analysis_file_name = AnalysisNwbfile().create(key["nwb_file_name"])
Expand Down Expand Up @@ -220,8 +219,9 @@ def fetch1_dataframe(self, add_frame_ind=True):
TrodesPosParams & {"trodes_pos_params_name": pos_params}
).fetch1("params")["is_upsampled"]
):
logger.warn(
"Upsampled position data, frame indices are invalid. Setting add_frame_ind=False"
logger.warning(
"Upsampled position data, frame indices are invalid. "
+ "Setting add_frame_ind=False"
)
add_frame_ind = False
return IntervalPositionInfo._data_to_df(
Expand All @@ -245,7 +245,7 @@ class TrodesPosVideo(SpyglassMixin, dj.Computed):
def make(self, key):
M_TO_CM = 100

print("Loading position data...")
logger.info("Loading position data...")
raw_position_df = (
RawPosition.PosObject
& {
Expand All @@ -255,7 +255,7 @@ def make(self, key):
).fetch1_dataframe()
position_info_df = (TrodesPosV1() & key).fetch1_dataframe()

print("Loading video data...")
logger.info("Loading video data...")
epoch = (
int(
key["interval_list_name"]
Expand Down Expand Up @@ -299,7 +299,7 @@ def make(self, key):
position_time = np.asarray(position_info_df.index)
cm_per_pixel = meters_per_pixel * M_TO_CM

print("Making video...")
logger.info("Making video...")
self.make_video(
video_path,
centroids,
Expand Down Expand Up @@ -367,7 +367,7 @@ def make_video(
frame_size = (int(video.get(3)), int(video.get(4)))
frame_rate = video.get(5)
n_frames = int(orientation_mean.shape[0])
print(f"video filepath: {output_video_filename}")
logger.info(f"video filepath: {output_video_filename}")
out = cv2.VideoWriter(
output_video_filename, fourcc, frame_rate, frame_size, True
)
Expand Down
Loading
Loading