NWB ingestion fixes (#1074)
* check NWB probe geometry against existing probe entry on insertion

* allow insertion of new probe type from config

* fix shank and electrode dicts

* update docs for config probe insert

* add option to add tasks from config during ingestion

* avoid default use of last interval_list_name for task

* update docs for config TaskEpoch

* cleanup read of probe data from config

* improve query efficiency

* Apply suggestions from code review

Co-authored-by: Chris Broz <[email protected]>

* suggestions from code review

* spelling fix

---------

Co-authored-by: Chris Broz <[email protected]>
samuelbray32 and CBroz1 authored Sep 18, 2024
1 parent 8be2f56 commit c82f23e
Showing 3 changed files with 190 additions and 25 deletions.
22 changes: 12 additions & 10 deletions docs/src/ForDevelopers/UsingNWB.md
@@ -144,26 +144,28 @@ ndx_franklab_novela.CameraDevice </b>
| Spyglass Table | Key | NWBfile Location | Config option | Notes |
| :------------- | :---------------: | ----------------------------------------: | -----------------------------------------: | ----: |
| Probe | probe_type | nwbf.devices.\<\*Probe>.probe_type | config\["Probe"\]\[index\]\["probe_type"\] | str |
| Probe | probe_id | nwbf.devices.\<\*Probe>.probe_type | XXX | str |
| Probe | manufacturer | nwbf.devices.\<\*Probe>.manufacturer | XXX | str |
| Probe | probe_description | nwbf.devices.\<\*Probe>.probe_description | XXX | str |
| Probe | probe_id | nwbf.devices.\<\*Probe>.probe_type | config\["Probe"\]\[index\]\["probe_type"\] | str |
| Probe | manufacturer | nwbf.devices.\<\*Probe>.manufacturer | config\["Probe"\]\[index\]\["manufacturer"\] | str |
| Probe | probe_description | nwbf.devices.\<\*Probe>.probe_description | config\["Probe"\]\[index\]\["probe_description"\] | str |
| Probe | num_shanks | nwbf.devices.\<\*Probe>.num_shanks | XXX | int |

<b> NWBfile Location: nwbf.devices.\<\*Probe>.\<\*Shank> <br/> Object type:
ndx_franklab_novela.Shank </b>

| Spyglass Table | Key | NWBfile Location | Config option | Notes |
| :------------- | :---------: | ---------------------------------------------: | ------------: | ----: |
| Probe.Shank | probe_shank | nwbf.devices.\<\*Probe>.\<\*Shank>.probe_shank | XXX | int |
| Probe.Shank | probe_shank | nwbf.devices.\<\*Probe>.\<\*Shank>.probe_shank | config\["Probe"\]\[index\]\["Shank"\] | int; in the config, a list of ints |

<b> NWBfile Location: nwbf.devices.\<\*Probe>.\<\*Shank>.\<\*Electrode> <br/>
Object type: ndx_franklab_novela.Electrode </b>

| Spyglass Table | Key | NWBfile Location | Config option | Notes |
| :-------------- | :----------: | -------------------------------------------------------------: | ------------: | ----: |
| Probe.Electrode | probe_shank | nwbf.devices.\<\*Probe>.\<\*Shank>.probe_shank | XXX | int |
| Probe.Electrode | contact_size | nwbf.devices.\<\*Probe>.\<\*Shank>.\<\*Electrode>.contact_size | XXX | float |
| Probe.Electrode | rel_x | nwbf.devices.\<\*Probe>.\<\*Shank>.\<\*Electrode>.rel_x | XXX | float |
| Probe.Electrode | probe_shank | nwbf.devices.\<\*Probe>.\<\*Shank>.probe_shank | config\["Probe"\]\[index\]\["Electrode"\]\[index\]\["probe_shank"\] | int |
| Probe.Electrode | contact_size | nwbf.devices.\<\*Probe>.\<\*Shank>.\<\*Electrode>.contact_size | config\["Probe"\]\[index\]\["Electrode"\]\[index\]\["contact_size"\] | float |
| Probe.Electrode | rel_x | nwbf.devices.\<\*Probe>.\<\*Shank>.\<\*Electrode>.rel_x | config\["Probe"\]\[index\]\["Electrode"\]\[index\]\["rel_x"\] | float |
| Probe.Electrode | rel_y | nwbf.devices.\<\*Probe>.\<\*Shank>.\<\*Electrode>.rel_y | config\["Probe"\]\[index\]\["Electrode"\]\[index\]\["rel_y"\] | float |
| Probe.Electrode | rel_z | nwbf.devices.\<\*Probe>.\<\*Shank>.\<\*Electrode>.rel_z | config\["Probe"\]\[index\]\["Electrode"\]\[index\]\["rel_z"\] | float |
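
For orientation, a config entry covering the Probe, Probe.Shank, and Probe.Electrode rows might look like the following once loaded into Python. This is only a sketch: the key names are the ones read by `_read_config_probe_data` in this commit, while the probe name and all numeric values are made up for illustration.

```python
# Hypothetical config as a Python dict (for example after loading YAML).
# Key names follow this commit; every value is illustrative.
config = {
    "Probe": [
        {
            "probe_type": "example_probe",
            "manufacturer": "example_manufacturer",
            "probe_description": "illustrative two-shank probe",
            "Shank": [0, 1],  # a list of ints, one per shank
            "Electrode": [
                {
                    "probe_shank": 0,
                    "probe_electrode": 0,
                    "contact_size": 15.0,
                    "rel_x": 0.0,
                    "rel_y": 0.0,
                    "rel_z": 0.0,
                },
                {
                    "probe_shank": 1,
                    "probe_electrode": 1,
                    "contact_size": 15.0,
                    "rel_x": 40.0,
                    "rel_y": 0.0,
                    "rel_z": 0.0,
                },
            ],
        }
    ]
}

probe = config["Probe"][0]
# Every electrode should sit on a shank that the probe declares.
undeclared = {e["probe_shank"] for e in probe["Electrode"]} - set(probe["Shank"])
print("undeclared shanks:", undeclared)
```

Checking for undeclared shanks before ingestion is a cheap sanity test; it is not something the commit itself performs.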

<b> NWBfile Location: nwbf.epochs <br/> Object type: pynwb.epoch.TimeIntervals
</b>
@@ -213,9 +215,9 @@ hdmf.common.table.DynamicTable </b>
| :------------- | :--------------: | -----------------------------------------------: | ------------: | ----: |
| Task | task_name | nwbf.processing.tasks.\[index\].name | | |
| Task | task_description | nwbf.processing.\[index\].tasks.description | | |
| TaskEpoch | task_name | nwbf.processing.\[index\].tasks.name | | |
| TaskEpoch | camera_names | nwbf.processing.\[index\].tasks.camera_id | | |
| TaskEpoch | task_environment | nwbf.processing.\[index\].tasks.task_environment | | |
| TaskEpoch | task_name | nwbf.processing.\[index\].tasks.name | config\["Tasks"\]\[index\]\["task_name"\] | |
| TaskEpoch | camera_names | nwbf.processing.\[index\].tasks.camera_id | config\["Tasks"\]\[index\]\["camera_id"\] | |
| TaskEpoch | task_environment | nwbf.processing.\[index\].tasks.task_environment | config\["Tasks"\]\[index\]\["task_environment"\] | |
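
As a rough illustration of the TaskEpoch config options, the sketch below builds a config dict and resolves each task's cameras. The `task_epochs` and `CameraDevice` keys are taken from the `common_task.py` changes in this commit; the task, environment, and camera names are placeholders, and the assumption that each `CameraDevice` entry holds a single `camera_id` and `camera_name` is mine.

```python
# Hypothetical config; key names follow this commit, values are placeholders.
config = {
    "CameraDevice": [{"camera_id": 0, "camera_name": "example_camera"}],
    "Tasks": [
        {
            "task_name": "example_task",
            "task_environment": "example_environment",
            "camera_id": [0],       # list of camera ids used by the task
            "task_epochs": [1, 3],  # epochs in which the task was run
        }
    ],
}

# Map camera_id -> camera_name, then resolve the cameras of each task,
# silently dropping ids that no CameraDevice entry declares.
camera_names = {d["camera_id"]: d["camera_name"] for d in config["CameraDevice"]}
task_cameras = {
    t["task_name"]: [
        camera_names[c] for c in t.get("camera_id", []) if c in camera_names
    ]
    for t in config["Tasks"]
}
print(task_cameras)
```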

<b> NWBfile Location: nwbf.units </br> Object type: pynwb.misc.Units </b>

101 changes: 98 additions & 3 deletions src/spyglass/common/common_device.py
@@ -376,7 +376,9 @@ def insert_from_nwbfile(cls, nwbf, config=None):
List of probe device types found in the NWB file.
"""
config = config or dict()
all_probes_types, ndx_probes, _ = cls.get_all_probe_names(nwbf, config)
all_probes_types, ndx_probes, config_probes = cls.get_all_probe_names(
nwbf, config
)

for probe_type in all_probes_types:
new_probe_type_dict = dict()
@@ -397,6 +399,16 @@ def insert_from_nwbfile(cls, nwbf, config=None):
elect_dict,
)

elif probe_type in config_probes:
cls._read_config_probe_data(
config,
probe_type,
new_probe_type_dict,
new_probe_dict,
shank_dict,
elect_dict,
)

# check that number of shanks is consistent
num_shanks = new_probe_type_dict["num_shanks"]
assert num_shanks == 0 or num_shanks == len(
@@ -405,15 +417,38 @@ def insert_from_nwbfile(cls, nwbf, config=None):

# if probe id already exists, do not overwrite anything or create
# new Shanks and Electrodes
# TODO: test whether the Shanks and Electrodes in the NWB file match
# the ones in the database
query = Probe & {"probe_id": new_probe_dict["probe_id"]}
if len(query) > 0:
logger.info(
f"Probe ID '{new_probe_dict['probe_id']}' already exists in"
" the database. Spyglass will use that and not create a new"
" Probe, Shanks, or Electrodes."
)
# Test whether the Shanks and Electrodes in the NWB file match
# the existing database entries
existing_shanks = query * cls.Shank()
bad_shanks = [
shank
for shank in shank_dict.values()
if len(existing_shanks & shank) != 1
]
if bad_shanks:
raise ValueError(
"Mismatch between nwb file and existing database "
+ f"entry for shanks: {bad_shanks}"
)

existing_electrodes = query * cls.Electrode()
bad_electrodes = [
electrode
for electrode in elect_dict.values()
if len(existing_electrodes & electrode) != 1
]
if bad_electrodes:
raise ValueError(
f"Mismatch between nwb file and existing database "
f"entry for electrodes: {bad_electrodes}"
)
continue

cls.insert1(new_probe_dict, skip_duplicates=True)
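
The new consistency check can be pictured without a database. The sketch below is an approximation under stated assumptions: it replaces the DataJoint restriction `existing & row` with a plain subset match on dict items, and `find_mismatches` is a hypothetical helper name, not part of Spyglass.

```python
def find_mismatches(new_rows, existing_rows):
    """Return the new rows that do not match exactly one existing row.

    A new row matches an existing row when every key/value pair of the
    new row is present in the existing row; extra columns are ignored.
    This only approximates a DataJoint dict restriction.
    """

    def matches(new, old):
        return all(key in old and old[key] == value for key, value in new.items())

    return [
        new
        for new in new_rows
        if sum(matches(new, old) for old in existing_rows) != 1
    ]


# Stand-in for the shanks already stored for this probe_id.
existing = [
    {"probe_id": "example_probe", "probe_shank": 0},
    {"probe_id": "example_probe", "probe_shank": 1},
]
# Stand-in for shank_dict.values() read from the NWB file or config.
incoming = [
    {"probe_id": "example_probe", "probe_shank": 0},
    {"probe_id": "example_probe", "probe_shank": 2},
]

bad_shanks = find_mismatches(incoming, existing)
if bad_shanks:
    print(f"Mismatch for shanks: {bad_shanks}")
```

Requiring exactly one match, rather than at least one, also flags ambiguous rows, which mirrors the `!= 1` comparison in the diff.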
@@ -523,6 +558,66 @@ def __read_ndx_probe_data(
"rel_z": electrode.rel_z,
}

@classmethod
def _read_config_probe_data(
cls,
config,
probe_type,
new_probe_type_dict,
new_probe_dict,
shank_dict,
elect_dict,
):

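# `config_probes` is not defined in this scope; rebuild it from the config
# (assumes each config["Probe"] entry has a "probe_type" key)
config_probes = [p.get("probe_type") for p in config.get("Probe", [])]

# get the list of shank keys for the probe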
shank_list = config["Probe"][config_probes.index(probe_type)].get(
"Shank", []
)
for i in shank_list:
shank_dict[str(i)] = {"probe_id": probe_type, "probe_shank": int(i)}

# get the list of electrode keys for the probe
elect_dict_list = config["Probe"][config_probes.index(probe_type)].get(
"Electrode", []
)
for i, e in enumerate(elect_dict_list):
elect_dict[str(i)] = {
"probe_id": probe_type,
"probe_shank": e["probe_shank"],
"probe_electrode": e["probe_electrode"],
"contact_size": e.get("contact_size"),
"rel_x": e.get("rel_x"),
"rel_y": e.get("rel_y"),
"rel_z": e.get("rel_z"),
}

# make the probe type if not in database
new_probe_type_dict.update(
{
"manufacturer": config["Probe"][
config_probes.index(probe_type)
].get("manufacturer"),
"probe_type": probe_type,
"probe_description": config["Probe"][
config_probes.index(probe_type)
].get("probe_description"),
"num_shanks": len(shank_list),
}
)

cls._add_probe_type(new_probe_type_dict)

# make the probe dictionary
new_probe_dict.update(
{
"probe_type": probe_type,
"probe_id": probe_type,
"contact_side_numbering": config["Probe"][
config_probes.index(probe_type)
].get("contact_side_numbering"),
}
)

@classmethod
def _add_probe_type(cls, new_probe_type_dict):
"""Check the probe type value against the values in the database.
92 changes: 80 additions & 12 deletions src/spyglass/common/common_task.py
@@ -7,7 +7,7 @@
from spyglass.common.common_nwbfile import Nwbfile
from spyglass.common.common_session import Session # noqa: F401
from spyglass.utils import SpyglassMixin, logger
from spyglass.utils.nwb_helper_fn import get_nwb_file
from spyglass.utils.nwb_helper_fn import get_config, get_nwb_file

schema = dj.schema("common_task")

@@ -106,6 +106,7 @@ def make(self, key):
nwb_file_name = key["nwb_file_name"]
nwb_file_abspath = Nwbfile().get_abs_path(nwb_file_name)
nwbf = get_nwb_file(nwb_file_abspath)
config = get_config(nwb_file_abspath, calling_table=self.camel_name)
camera_names = dict()

# the tasks refer to the camera_id which is unique for the NWB file but
@@ -117,13 +118,27 @@ def make(self, key):
# get the camera ID
camera_id = int(str.split(device.name)[1])
camera_names[camera_id] = device.camera_name
if device_list := config.get("CameraDevice"):
for device in device_list:
# camera_names maps camera_id -> camera_name, as in the NWB loop above
# (assumes one camera_id and one camera_name per config entry)
camera_names[device.get("camera_id", -1)] = device.get("camera_name")

# find the task modules and for each one, add the task to the Task
# schema if it isn't there and then add an entry for each epoch

tasks_mod = nwbf.processing.get("tasks")
if tasks_mod is None:
logger.warn(f"No tasks processing module found in {nwbf}\n")
config_tasks = config.get("Tasks")
if tasks_mod is None and config_tasks is None:
logger.warn(
f"No tasks processing module found in {nwbf} or config\n"
)
return

task_inserts = []
@@ -166,19 +181,72 @@ def make(self, key):
for epoch in task.task_epochs[0]:
# TODO in beans file, task_epochs[0] is 1x2 dset of ints,
# so epoch would be an int

key["epoch"] = epoch
target_interval = str(epoch).zfill(2)
for interval in session_intervals:
if (
target_interval in interval
): # TODO this is not true for the beans file
break
# TODO case when interval is not found is not handled
key["interval_list_name"] = interval
target_interval = self.get_epoch_interval_name(
epoch, session_intervals
)
if target_interval is None:
logger.warn("Skipping epoch.")
continue
key["interval_list_name"] = target_interval
task_inserts.append(key.copy())

# Add tasks from config
# config_tasks is None when the config has no "Tasks" entry
for task in config_tasks or []:
new_key = {
**key,
"task_name": task.get("task_name"),
"task_environment": task.get("task_environment", None),
}
# add cameras
camera_ids = task.get("camera_id", [])
valid_camera_ids = [
camera_id
for camera_id in camera_ids
if camera_id in camera_names.keys()
]
if valid_camera_ids:
new_key["camera_names"] = [
{"camera_name": camera_names[camera_id]}
for camera_id in valid_camera_ids
]
session_intervals = (
IntervalList() & {"nwb_file_name": nwb_file_name}
).fetch("interval_list_name")
for epoch in task.get("task_epochs", []):
new_key["epoch"] = epoch
target_interval = self.get_epoch_interval_name(
epoch, session_intervals
)
if target_interval is None:
logger.warn("Skipping epoch.")
continue
new_key["interval_list_name"] = target_interval
task_inserts.append(new_key.copy())

self.insert(task_inserts, allow_direct_insert=True)
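
The config branch of `make` can be sketched as a pure function, which makes the key-building logic easier to test in isolation. This is a simplified stand-in under stated assumptions, not the Spyglass implementation: `task_keys_from_config` is a hypothetical name, the session intervals are passed in as a plain list instead of being fetched from `IntervalList`, and each insert copies the task-specific key.

```python
def task_keys_from_config(base_key, config_tasks, camera_names, session_intervals):
    """Build one insert dict per (task, epoch) pair found in the config."""
    inserts = []
    for task in config_tasks or []:  # tolerate a missing "Tasks" entry
        new_key = {
            **base_key,
            "task_name": task.get("task_name"),
            "task_environment": task.get("task_environment"),
        }
        # Keep only camera ids that are known for this session.
        valid_ids = [c for c in task.get("camera_id", []) if c in camera_names]
        if valid_ids:
            new_key["camera_names"] = [
                {"camera_name": camera_names[c]} for c in valid_ids
            ]
        for epoch in task.get("task_epochs", []):
            target = str(epoch).zfill(2)
            matches = [name for name in session_intervals if target in name]
            if len(matches) != 1:
                continue  # skip epochs with no match or an ambiguous match
            inserts.append(
                {**new_key, "epoch": epoch, "interval_list_name": matches[0]}
            )
    return inserts


keys = task_keys_from_config(
    base_key={"nwb_file_name": "example_.nwb"},
    config_tasks=[
        {"task_name": "example_task", "camera_id": [0, 7], "task_epochs": [1, 2, 9]}
    ],
    camera_names={0: "example_camera"},
    session_intervals=["01_s1", "02_r1"],
)
for k in keys:
    print(k["epoch"], k["interval_list_name"], k["camera_names"])
```

In this example, camera id 7 is dropped because it is unknown, and epoch 9 is skipped because no interval name contains "09".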

@classmethod
def get_epoch_interval_name(cls, epoch, session_intervals):
"""Get the interval name for a given epoch based on matching number"""
target_interval = str(epoch).zfill(2)
possible_targets = [
interval
for interval in session_intervals
if target_interval in interval
]
if not possible_targets:
# nwb_file_name is not available in this classmethod, so it is
# left out of the messages
logger.warn(f"Interval not found for epoch {epoch}.")
elif len(possible_targets) > 1:
logger.warn(
f"Multiple intervals found for epoch {epoch}. "
+ f"Matches are {possible_targets}."
)
else:
return possible_targets[0]
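
The matching rule is easy to try outside the table class. The standalone sketch below follows the same logic (zero-pad the epoch to two digits, then look for interval names containing that substring) and returns `None` unless exactly one name matches. The interval names are illustrative assumptions, and logging is omitted.

```python
def get_epoch_interval_name(epoch, session_intervals):
    """Return the single interval name containing the zero-padded epoch.

    Returns None when no interval, or more than one interval, matches.
    """
    target_interval = str(epoch).zfill(2)
    possible_targets = [
        interval for interval in session_intervals if target_interval in interval
    ]
    if len(possible_targets) == 1:
        return possible_targets[0]
    return None


# Illustrative interval names; real sessions may use other conventions.
session_intervals = ["01_s1", "02_r1", "10_r5", "pos 10 valid times"]

print(get_epoch_interval_name(1, session_intervals))   # unique match
print(get_epoch_interval_name(9, session_intervals))   # no match
print(get_epoch_interval_name(10, session_intervals))  # ambiguous match
```

Because the test is a substring match, epoch 10 hits both "10_r5" and "pos 10 valid times" here, which is the ambiguity the new warning reports instead of silently picking the last interval.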

@classmethod
def update_entries(cls, restrict=True):
"""Update entries in the TaskEpoch table based on a restriction."""
