Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
edeno committed Sep 18, 2024
2 parents ddfd01f + c82f23e commit 094e002
Show file tree
Hide file tree
Showing 7 changed files with 203 additions and 29 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ dj.FreeTable(dj.conn(), "common_session.session_group").drop()
- Add docstrings to all public methods #1076
- Update DataJoint to 0.14.2 #1081
- Allow restriction based on parent keys in `Merge.fetch_nwb()` #1086
- Import `datajoint.dependencies.unite_master_parts` -> `topo_sort` #1116

### Pipelines

Expand All @@ -46,6 +47,7 @@ dj.FreeTable(dj.conn(), "common_session.session_group").drop()

- Fix bug in `get_group_by_shank` #1096
- Fix bug in `_compute_metric` #1099
- Fix bug in `insert_curation` returned key #1114

## [0.5.3] (August 27, 2024)

Expand Down
22 changes: 12 additions & 10 deletions docs/src/ForDevelopers/UsingNWB.md
Original file line number Diff line number Diff line change
Expand Up @@ -144,26 +144,28 @@ ndx_franklab_novela.CameraDevice </b>
| Spyglass Table | Key | NWBfile Location | Config option | Notes |
| :------------- | :---------------: | ----------------------------------------: | -----------------------------------------: | ----: |
| Probe | probe_type | nwbf.devices.\<\*Probe>.probe_type | config\["Probe"\]\[index\]\["probe_type"\] | str |
| Probe | probe_id | nwbf.devices.\<\*Probe>.probe_type | XXX | str |
| Probe | manufacturer | nwbf.devices.\<\*Probe>.manufacturer | XXX | str |
| Probe | probe_description | nwbf.devices.\<\*Probe>.probe_description | XXX | str |
| Probe | probe_id | nwbf.devices.\<\*Probe>.probe_type | config\["Probe"\]\[index\]\["probe_type"\] | str |
| Probe | manufacturer | nwbf.devices.\<\*Probe>.manufacturer | config\["Probe"\]\[index\]\["manufacturer"\] | str |
| Probe | probe_description | nwbf.devices.\<\*Probe>.probe_description | config\["Probe"\]\[index\]\["description"\] | str |
| Probe | num_shanks | nwbf.devices.\<\*Probe>.num_shanks | XXX | int |

<b> NWBfile Location: nwbf.devices.\<\*Probe>.\<\*Shank> <br/> Object type:
ndx_franklab_novela.Shank </b>

| Spyglass Table | Key | NWBfile Location | Config option | Notes |
| :------------- | :---------: | ---------------------------------------------: | ------------: | ----: |
| Probe.Shank | probe_shank | nwbf.devices.\<\*Probe>.\<\*Shank>.probe_shank | XXX | int |
| Probe.Shank | probe_shank | nwbf.devices.\<\*Probe>.\<\*Shank>.probe_shank | config\["Probe"\]\[Shank\]\ | int | In the config, a list of ints |

<b> NWBfile Location: nwbf.devices.\<\*Probe>.\<\*Shank>.\<\*Electrode> <br/>
Object type: ndx_franklab_novela.Electrode </b>

| Spyglass Table | Key | NWBfile Location | Config option | Notes |
| :-------------- | :----------: | -------------------------------------------------------------: | ------------: | ----: |
| Probe.Electrode | probe_shank | nwbf.devices.\<\*Probe>.\<\*Shank>.probe_shank | XXX | int |
| Probe.Electrode | contact_size | nwbf.devices.\<\*Probe>.\<\*Shank>.\<\*Electrode>.contact_size | XXX | float |
| Probe.Electrode | rel_x | nwbf.devices.\<\*Probe>.\<\*Shank>.\<\*Electrode>.rel_x | XXX | float |
| Probe.Electrode | probe_shank | nwbf.devices.\<\*Probe>.\<\*Shank>.probe_shank | config\["Probe"]\["Electrode"]\[index]\["probe_shank"] | int |
| Probe.Electrode | contact_size | nwbf.devices.\<\*Probe>.\<\*Shank>.\<\*Electrode>.contact_size | config\["Probe"]\["Electrode"]\[index]\["contact_size"] | float |
| Probe.Electrode | rel_x | nwbf.devices.\<\*Probe>.\<\*Shank>.\<\*Electrode>.rel_x | config\["Probe"]\["Electrode"]\[index]\["rel_x"] | float |
| Probe.Electrode | rel_y | nwbf.devices.\<\*Probe>.\<\*Shank>.\<\*Electrode>.rel_y | config\["Probe"]\["Electrode"]\[index]\["rel_y"] | float |
| Probe.Electrode | rel_z | nwbf.devices.\<\*Probe>.\<\*Shank>.\<\*Electrode>.rel_z | config\["Probe"]\["Electrode"]\[index]\["rel_z"] | float |

<b> NWBfile Location: nwbf.epochs <br/> Object type: pynwb.epoch.TimeIntervals
</b>
Expand Down Expand Up @@ -213,9 +215,9 @@ hdmf.common.table.DynamicTable </b>
| :------------- | :--------------: | -----------------------------------------------: | ------------: | ----: |
| Task | task_name | nwbf.processing.tasks.\[index\].name | | |
| Task | task_description | nwbf.processing.\[index\].tasks.description | | |
| TaskEpoch | task_name | nwbf.processing.\[index\].tasks.name | | |
| TaskEpoch | camera_names | nwbf.processing.\[index\].tasks.camera_id | | |
| TaskEpoch | task_environment | nwbf.processing.\[index\].tasks.task_environment | | |
| TaskEpoch | task_name | nwbf.processing.\[index\].tasks.name | config\["Tasks"\]\[index\]\["task_name"\]| |
| TaskEpoch | camera_names | nwbf.processing.\[index\].tasks.camera_id | config\["Tasks"\]\[index\]\["camera_id"\] | |
| TaskEpoch | task_environment | nwbf.processing.\[index\].tasks.task_environment | config\["Tasks"\]\[index\]\["task_environment"\] | |

<b> NWBfile Location: nwbf.units </br> Object type: pynwb.misc.Units </b>

Expand Down
101 changes: 98 additions & 3 deletions src/spyglass/common/common_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,9 @@ def insert_from_nwbfile(cls, nwbf, config=None):
List of probe device types found in the NWB file.
"""
config = config or dict()
all_probes_types, ndx_probes, _ = cls.get_all_probe_names(nwbf, config)
all_probes_types, ndx_probes, config_probes = cls.get_all_probe_names(
nwbf, config
)

for probe_type in all_probes_types:
new_probe_type_dict = dict()
Expand All @@ -397,6 +399,16 @@ def insert_from_nwbfile(cls, nwbf, config=None):
elect_dict,
)

elif probe_type in config_probes:
cls._read_config_probe_data(
config,
probe_type,
new_probe_type_dict,
new_probe_dict,
shank_dict,
elect_dict,
)

# check that number of shanks is consistent
num_shanks = new_probe_type_dict["num_shanks"]
assert num_shanks == 0 or num_shanks == len(
Expand All @@ -405,15 +417,38 @@ def insert_from_nwbfile(cls, nwbf, config=None):

# if probe id already exists, do not overwrite anything or create
# new Shanks and Electrodes
# TODO: test whether the Shanks and Electrodes in the NWB file match
# the ones in the database
query = Probe & {"probe_id": new_probe_dict["probe_id"]}
if len(query) > 0:
logger.info(
f"Probe ID '{new_probe_dict['probe_id']}' already exists in"
" the database. Spyglass will use that and not create a new"
" Probe, Shanks, or Electrodes."
)
# Test whether the Shanks and Electrodes in the NWB file match
# the existing database entries
existing_shanks = query * cls.Shank()
bad_shanks = [
shank
for shank in shank_dict.values()
if len(existing_shanks & shank) != 1
]
if bad_shanks:
raise ValueError(
"Mismatch between nwb file and existing database "
+ f"entry for shanks: {bad_shanks}"
)

existing_electrodes = query * cls.Electrode()
bad_electrodes = [
electrode
for electrode in elect_dict.values()
if len(existing_electrodes & electrode) != 1
]
if bad_electrodes:
raise ValueError(
f"Mismatch between nwb file and existing database "
f"entry for electrodes: {bad_electrodes}"
)
continue

cls.insert1(new_probe_dict, skip_duplicates=True)
Expand Down Expand Up @@ -523,6 +558,66 @@ def __read_ndx_probe_data(
"rel_z": electrode.rel_z,
}

@classmethod
def _read_config_probe_data(
cls,
config,
probe_type,
new_probe_type_dict,
new_probe_dict,
shank_dict,
elect_dict,
):

# get the list of shank keys for the probe
shank_list = config["Probe"][config_probes.index(probe_type)].get(
"Shank", []
)
for i in shank_list:
shank_dict[str(i)] = {"probe_id": probe_type, "probe_shank": int(i)}

# get the list of electrode keys for the probe
elect_dict_list = config["Probe"][config_probes.index(probe_type)].get(
"Electrode", []
)
for i, e in enumerate(elect_dict_list):
elect_dict[str(i)] = {
"probe_id": probe_type,
"probe_shank": e["probe_shank"],
"probe_electrode": e["probe_electrode"],
"contact_size": e.get("contact_size"),
"rel_x": e.get("rel_x"),
"rel_y": e.get("rel_y"),
"rel_z": e.get("rel_z"),
}

# make the probe type if not in database
new_probe_type_dict.update(
{
"manufacturer": config["Probe"][
config_probes.index(probe_type)
].get("manufacturer"),
"probe_type": probe_type,
"probe_description": config["Probe"][
config_probes.index(probe_type)
].get("probe_description"),
"num_shanks": len(shank_list),
}
)

cls._add_probe_type(new_probe_type_dict)

# make the probe dictionary
new_probe_dict.update(
{
"probe_type": probe_type,
"probe_id": probe_type,
"contact_side_numbering": config["Probe"][
config_probes.index(probe_type)
].get("contact_side_numbering"),
}
)

@classmethod
def _add_probe_type(cls, new_probe_type_dict):
"""Check the probe type value against the values in the database.
Expand Down
92 changes: 80 additions & 12 deletions src/spyglass/common/common_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from spyglass.common.common_nwbfile import Nwbfile
from spyglass.common.common_session import Session # noqa: F401
from spyglass.utils import SpyglassMixin, logger
from spyglass.utils.nwb_helper_fn import get_nwb_file
from spyglass.utils.nwb_helper_fn import get_config, get_nwb_file

schema = dj.schema("common_task")

Expand Down Expand Up @@ -106,6 +106,7 @@ def make(self, key):
nwb_file_name = key["nwb_file_name"]
nwb_file_abspath = Nwbfile().get_abs_path(nwb_file_name)
nwbf = get_nwb_file(nwb_file_abspath)
config = get_config(nwb_file_abspath, calling_table=self.camel_name)
camera_names = dict()

# the tasks refer to the camera_id which is unique for the NWB file but
Expand All @@ -117,13 +118,27 @@ def make(self, key):
# get the camera ID
camera_id = int(str.split(device.name)[1])
camera_names[camera_id] = device.camera_name
if device_list := config.get("CameraDevice"):
for device in device_list:
camera_names.update(
{
name: id
for name, id in zip(
device.get("camera_name"),
device.get("camera_id", -1),
)
}
)

# find the task modules and for each one, add the task to the Task
# schema if it isn't there and then add an entry for each epoch

tasks_mod = nwbf.processing.get("tasks")
if tasks_mod is None:
logger.warn(f"No tasks processing module found in {nwbf}\n")
config_tasks = config.get("Tasks")
if tasks_mod is None and config_tasks is None:
logger.warn(
f"No tasks processing module found in {nwbf} or config\n"
)
return

task_inserts = []
Expand Down Expand Up @@ -166,19 +181,72 @@ def make(self, key):
for epoch in task.task_epochs[0]:
# TODO in beans file, task_epochs[0] is 1x2 dset of ints,
# so epoch would be an int

key["epoch"] = epoch
target_interval = str(epoch).zfill(2)
for interval in session_intervals:
if (
target_interval in interval
): # TODO this is not true for the beans file
break
# TODO case when interval is not found is not handled
key["interval_list_name"] = interval
target_interval = self.get_epoch_interval_name(
epoch, session_intervals
)
if target_interval is None:
logger.warn("Skipping epoch.")
continue
key["interval_list_name"] = target_interval
task_inserts.append(key.copy())

# Add tasks from config
for task in config_tasks:
new_key = {
**key,
"task_name": task.get("task_name"),
"task_environment": task.get("task_environment", None),
}
# add cameras
camera_ids = task.get("camera_id", [])
valid_camera_ids = [
camera_id
for camera_id in camera_ids
if camera_id in camera_names.keys()
]
if valid_camera_ids:
new_key["camera_names"] = [
{"camera_name": camera_names[camera_id]}
for camera_id in valid_camera_ids
]
session_intervals = (
IntervalList() & {"nwb_file_name": nwb_file_name}
).fetch("interval_list_name")
for epoch in task.get("task_epochs", []):
new_key["epoch"] = epoch
target_interval = self.get_epoch_interval_name(
epoch, session_intervals
)
if target_interval is None:
logger.warn("Skipping epoch.")
continue
new_key["interval_list_name"] = target_interval
task_inserts.append(key.copy())

self.insert(task_inserts, allow_direct_insert=True)

@classmethod
def get_epoch_interval_name(cls, epoch, session_intervals):
"""Get the interval name for a given epoch based on matching number"""
target_interval = str(epoch).zfill(2)
possible_targets = [
interval
for interval in session_intervals
if target_interval in interval
]
if not possible_targets:
logger.warn(
f"Interval not found for epoch {epoch} in {nwb_file_name}."
)
elif len(possible_targets) > 1:
logger.warn(
f"Multiple intervals found for epoch {epoch} in {nwb_file_name}. "
+ f"matches are {possible_targets}."
)
else:
return possible_targets[0]

@classmethod
def update_entries(cls, restrict=True):
"""Update entries in the TaskEpoch table based on a restriction."""
Expand Down
5 changes: 3 additions & 2 deletions src/spyglass/spikesorting/v0/spikesorting_curation.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,8 +146,9 @@ def insert_curation(
Curation.insert1(sorting_key, skip_duplicates=True)

# get the primary key for this curation
c_key = (Curation & sorting_key).fetch1("KEY")
curation_key = {item: sorting_key[item] for item in c_key}
curation_key = {
item: sorting_key[item] for item in Curation.primary_key
}

return curation_key

Expand Down
8 changes: 6 additions & 2 deletions src/spyglass/utils/dj_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@

from datajoint import FreeTable, Table
from datajoint.condition import make_condition
from datajoint.dependencies import unite_master_parts
from datajoint.hash import key_hash
from datajoint.user_tables import TableMeta
from datajoint.utils import get_master, to_camel_case
Expand All @@ -35,6 +34,11 @@
unique_dicts,
)

try: # Datajoint 0.14.2+ uses topo_sort instead of unite_master_parts
from datajoint.dependencies import topo_sort as dj_topo_sort
except ImportError:
from datajoint.dependencies import unite_master_parts as dj_topo_sort


class Direction(Enum):
"""Cascade direction enum. Calling Up returns True. Inverting flips."""
Expand Down Expand Up @@ -474,7 +478,7 @@ def _topo_sort(
if not self._is_out(node, warn=False)
]
graph = self.graph.subgraph(nodes) if subgraph else self.graph
ordered = unite_master_parts(list(topological_sort(graph)))
ordered = dj_topo_sort(list(topological_sort(graph)))
if reverse:
ordered.reverse()
return [n for n in ordered if n in nodes]
Expand Down
2 changes: 2 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,7 @@ def frequent_imports():
from spyglass.lfp.analysis.v1 import LFPBandSelection
from spyglass.mua.v1.mua import MuaEventsV1
from spyglass.ripple.v1.ripple import RippleTimesV1
from spyglass.spikesorting.analysis.v1.unit_annotation import UnitAnnotation
from spyglass.spikesorting.v0.figurl_views import SpikeSortingRecordingView

return (
Expand All @@ -403,6 +404,7 @@ def frequent_imports():
RippleTimesV1,
SortedSpikesIndicatorSelection,
SpikeSortingRecordingView,
UnitAnnotation,
UnitMarksIndicatorSelection,
)

Expand Down

0 comments on commit 094e002

Please sign in to comment.