From 04ee0703da50078a4b92c036fa94135162c21db2 Mon Sep 17 00:00:00 2001 From: Chris Brozdowski Date: Mon, 27 Nov 2023 06:38:43 -0800 Subject: [PATCH] Implement pre-commit, refactor input validation (#688) * WIP: database management doc * #686 * Spelling * Remove unused imports * Pre-commit hooks: yamllint, mdformat * Precommit hooks * Ruff PEP8 checks * Refactor dlc parameter validation * Add db backup instructions * Blackify, edit changelog * add mdformat-mkdocs * Minor edits * Fix Typo * Add black to precommit, remove commented variables * Ordered list increment; update pre-commit versions --- .git_archival.txt | 2 +- .gitattributes | 2 +- .pre-commit-config.yaml | 32 ++- .vscode/extensions.json | 2 +- .vscode/settings.json | 2 +- CHANGELOG.md | 60 ++-- MANIFEST.in | 2 +- README.md | 3 +- config/add_dj_collaborator.py | 0 config/add_dj_guest.py | 0 config/add_dj_module.py | 0 config/dj_config.py | 0 docs/README.md | 10 +- docs/build-docs.sh | 8 +- docs/overrides/nav.html | 2 - docs/src/api/index.md | 4 +- docs/src/contribute.md | 158 ++++++----- docs/src/images/Spyglass.svg | 12 +- docs/src/installation.md | 11 +- docs/src/misc/database_management.md | 19 +- docs/src/misc/insert_data.md | 45 +-- docs/src/misc/merge_tables.md | 34 +-- docs/src/misc/session_groups.md | 8 +- .../artifactdetectionparameters_default.yaml | 2 +- .../artifactdetectionparameters_none.yaml | 2 +- .../create_spike_sorting_recording.sh | 2 +- .../create_spike_sorting_recording_view.sh | 2 +- .../cli_examples/create_spike_sorting_view.sh | 2 +- examples/cli_examples/create_spyglass_view.sh | 2 +- .../insert_artifact_detection_parameters.sh | 2 +- examples/cli_examples/insert_lab_member.sh | 2 +- examples/cli_examples/insert_lab_team.sh | 2 +- .../cli_examples/insert_lab_team_member.sh | 2 +- examples/cli_examples/insert_session.sh | 2 +- .../insert_spike_sorter_parameters.sh | 2 +- ..._spike_sorting_preprocessing_parameters.sh | 2 +- examples/cli_examples/labmember.yaml | 2 +- examples/cli_examples/labteammember.yaml | 2 +- examples/cli_examples/parameters.yaml | 2 +- examples/cli_examples/readme.md | 3 +- examples/cli_examples/run_spike_sorting.sh | 2 +- .../spikesorterparameters_default.yaml | 2 +- .../spikesortingpreprocessingparameters.yaml | 2 +- ...0-g3_behavior+ecephys_spyglass_config.yaml | 2 +- franklab_scripts/sort.py | 18 +- notebooks/22_Position_DLC_2.ipynb | 1 - notebooks/30_Ripple_Detection.ipynb | 2 - notebooks/31_Extract_Mark_Indicators.ipynb | 1 - notebooks/README.md | 8 +- notebooks/py_scripts/22_Position_DLC_2.py | 1 - notebooks/py_scripts/30_Ripple_Detection.py | 2 - .../py_scripts/31_Extract_Mark_Indicators.py | 1 - pyproject.toml | 2 +- src/spyglass/cli/cli.py | 2 +- src/spyglass/common/__init__.py | 4 +- src/spyglass/common/common_behav.py | 2 +- src/spyglass/common/common_nwbfile.py | 7 +- src/spyglass/common/common_session.py | 2 - src/spyglass/decoding/clusterless.py | 53 ---- src/spyglass/lfp/lfp_merge.py | 8 +- src/spyglass/position/position_merge.py | 10 +- src/spyglass/position/v1/dlc_utils.py | 125 ++++++-- .../position/v1/position_dlc_centroid.py | 266 ++++++------------ .../position/v1/position_dlc_model.py | 160 +++++------ .../position/v1/position_dlc_position.py | 102 +++---- .../position/v1/position_dlc_project.py | 13 +- .../position/v1/position_dlc_selection.py | 17 +- .../position_linearization_merge.py | 6 +- src/spyglass/sharing/sharing_kachery.py | 11 +- .../spikesorting/sortingview_helper_fn.py | 7 +- .../spikesorting/spikesorting_curation.py | 21 +- 
.../spikesorting/spikesorting_recording.py | 2 +- src/spyglass/utils/database_settings.py | 0 tests/conftest.py | 13 +- 74 files changed, 617 insertions(+), 709 deletions(-) mode change 100644 => 100755 config/add_dj_collaborator.py mode change 100644 => 100755 config/add_dj_guest.py mode change 100644 => 100755 config/add_dj_module.py mode change 100644 => 100755 config/dj_config.py mode change 100644 => 100755 docs/build-docs.sh mode change 100644 => 100755 src/spyglass/utils/database_settings.py diff --git a/.git_archival.txt b/.git_archival.txt index b1a286bbb..8fb235d70 100644 --- a/.git_archival.txt +++ b/.git_archival.txt @@ -1,4 +1,4 @@ node: $Format:%H$ node-date: $Format:%cI$ describe-name: $Format:%(describe:tags=true,match=*[0-9]*)$ -ref-names: $Format:%D$ \ No newline at end of file +ref-names: $Format:%D$ diff --git a/.gitattributes b/.gitattributes index d62766021..f11b6d68c 100644 --- a/.gitattributes +++ b/.gitattributes @@ -1,4 +1,4 @@ # Auto detect text files and perform LF normalization * text=auto -.git_archival.txt export-subst \ No newline at end of file +.git_archival.txt export-subst diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f3f8793b0..7643f74c2 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -4,7 +4,7 @@ exclude: (^.github/|^docs/site/|^images/) repos: - repo: https://github.com/executablebooks/mdformat # Do this before other tools "fixing" the line endings - rev: 0.7.16 + rev: 0.7.17 hooks: - id: mdformat name: Format Markdown @@ -13,6 +13,7 @@ repos: types: [markdown] args: [--wrap, "80", --number] additional_dependencies: + - mdformat-mkdocs - mdformat-toc - mdformat-beautysh - mdformat-config @@ -21,13 +22,11 @@ repos: - mdformat-gfm - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.4.0 + rev: v4.5.0 hooks: - - id: check-json - id: check-toml - id: check-yaml args: [--unsafe] - - id: requirements-txt-fixer - id: end-of-file-fixer - id: mixed-line-ending args: ["--fix=lf"] @@ -35,7 +34,7 @@ repos: - id: trailing-whitespace - id: debug-statements - id: check-added-large-files # prevent giant files from being committed - - id: check-builtin-literals + # - id: check-builtin-literals - id: check-merge-conflict - id: check-executables-have-shebangs - id: check-shebang-scripts-are-executable @@ -44,28 +43,41 @@ repos: - id: fix-byte-order-marker - repo: https://github.com/adrienverge/yamllint.git - rev: v1.29.0 + rev: v1.33.0 hooks: - id: yamllint args: - --no-warnings - -d - - "{extends: relaxed, rules: {line-length: {max: 90}}}" + - "{extends: relaxed, rules: {line-length: {max: 80}}}" - repo: https://github.com/charliermarsh/ruff-pre-commit - rev: v0.0.254 + rev: v0.1.6 hooks: - id: ruff + args: [ + "--exclude", "./notebooks/py_scripts/", + "--ignore", "F401,E402,E501" + ] + # F401: Unused import - May flag tables in DataJoint foreign keys + # E402: Module level import not at top of file - May want lazyloading + # E501: Line too long - Allow longer lines in table definitions - repo: https://github.com/PyCQA/autoflake - rev: v2.0.1 + rev: v2.2.1 hooks: - id: autoflake - repo: https://github.com/codespell-project/codespell - rev: v2.2.2 + rev: v2.2.6 hooks: - id: codespell args: [--toml, pyproject.toml] additional_dependencies: - tomli + + - repo: https://github.com/ambv/black + rev: 23.11.0 + hooks: + - id: black + language_version: python3.9 diff --git a/.vscode/extensions.json b/.vscode/extensions.json index 36a8bb418..394322b27 100644 --- a/.vscode/extensions.json +++ b/.vscode/extensions.json @@ 
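The F401 exclusion in the ruff hook above reflects how DataJoint resolves foreign-key references: a table class imported only so it can be named inside a `definition` string looks unused to ruff and autoflake. A minimal sketch of the pattern — the schema and table names below are invented for illustration, and `Session` is assumed to be importable from `spyglass.common`:

```python
import datajoint as dj

# Flagged as unused (F401) by static analysis, but required: DataJoint looks up
# `Session` in this module's namespace when it parses the definition string.
from spyglass.common import Session

schema = dj.schema("example_f401_demo")  # schema name invented for this sketch


@schema
class ExampleTable(dj.Manual):
    definition = """
    -> Session        # resolved from the import above, not from Python code
    example_id: int
    ---
    note: varchar(64)
    """
```

Because the reference lives in a string, removing the "unused" import breaks table declaration at runtime, which is why the check is ignored repo-wide rather than fixed per line.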
-20,4 +20,4 @@ ], // List of extensions recommended by VS Code that should not be recommended for users of this workspace. "unwantedRecommendations": [] -} \ No newline at end of file +} diff --git a/.vscode/settings.json b/.vscode/settings.json index 54a7c2424..bc2c1fc8c 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -18,4 +18,4 @@ "--profile", "black" ], -} \ No newline at end of file +} diff --git a/CHANGELOG.md b/CHANGELOG.md index 437b61ac4..5073aa58c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,8 +1,10 @@ # Change Log -## [0.4.4] (November 7, 2023) +## [0.4.4] (Unreleased) - Additional documentation. #686 +- Refactor input validation in DLC pipeline. +- Clean up following pre-commit checks. ## [0.4.3] (November 7, 2023) @@ -24,17 +26,17 @@ ### Pipelines - Common: - - Added support multiple cameras per epoch. #557 - - Removed `common_backup` schema. #631 - - Added support for multiple position objects per NWB in `common_behav` via - PositionSource.SpatialSeries and RawPosition.PosObject #628, #616. - _Note:_ Existing functions have been made compatible, but column labels for - `RawPosition.fetch1_dataframe` may change. + - Added support multiple cameras per epoch. #557 + - Removed `common_backup` schema. #631 + - Added support for multiple position objects per NWB in `common_behav` via + PositionSource.SpatialSeries and RawPosition.PosObject #628, #616. _Note:_ + Existing functions have been made compatible, but column labels for + `RawPosition.fetch1_dataframe` may change. - Spike sorting: - - Added pipeline populator. #637, #646, #647 - - Fixed curation functionality for `nn_isolation`. #597, #598 + - Added pipeline populator. #637, #646, #647 + - Fixed curation functionality for `nn_isolation`. #597, #598 - Position: Added position interval/epoch mapping via PositionIntervalMap. #620, - #621, #627 + #621, #627 - LFP: Refactored pipeline. #594, #588, #605, #606, #607, #608, #615, #629 ## [0.4.1] (June 30, 2023) @@ -45,12 +47,12 @@ ## [0.4.0] (May 22, 2023) - Updated call to `spikeinterface.preprocessing.whiten` to use dtype np.float16. - #446, + #446, - Updated default spike sorting metric parameters. #447 - Updated whitening to be compatible with recent changes in spikeinterface when - using mountainsort. #449 + using mountainsort. #449 - Moved LFP pipeline to `src/spyglass/lfp/v1` and addressed related usability - issues. #468, #478, #482, #484, #504 + issues. #468, #478, #482, #484, #504 - Removed whiten parameter for clusterless thresholder. #454 - Added plot to plot all DIO events in a session. #457 - Added file sharing functionality through kachery_cloud. #458, #460 @@ -58,28 +60,28 @@ - Added scripts to add guests and collaborators as users. #463 - Cleaned up installation instructions in repo README. #467 - Added checks in decoding visualization to ensure time dimensions are the - correct length. + correct length. - Fixed artifact removed valid times. #472 - Added codespell workflow for spell checking and fixed typos. #471 - Updated LFP code to save LFP as `pynwb.ecephys.LFP` type. #475 - Added artifact detection to LFP pipeline. #473 - Replaced calls to `spikeinterface.sorters.get_default_params` with - `spikeinterface.sorters.get_default_sorter_params`. #486 + `spikeinterface.sorters.get_default_sorter_params`. #486 - Updated position pipeline and added functionality to handle pose estimation - through DeepLabCut. #367, #505 + through DeepLabCut. #367, #505 - Updated `environment_position.yml`. 
#502 - Renamed `FirFilter` class to `FirFilterParameters`. #512 ## [0.3.4] (March 30, 2023) - Fixed error in spike sorting pipeline referencing the "probe_type" column - which is no longer accessible from the `Electrode` table. #437 + which is no longer accessible from the `Electrode` table. #437 - Fixed error when inserting an NWB file that does not have a probe - manufacturer. #433, #436 + manufacturer. #433, #436 - Fixed error when adding a new `DataAcquisitionDevice` and a new `ProbeType`. - #436 + #436 - Fixed inconsistency between capitalized/uncapitalized versions of "Intan" for - DataAcquisitionAmplifier and DataAcquisitionDevice.adc_circuit. #430, #438 + DataAcquisitionAmplifier and DataAcquisitionDevice.adc_circuit. #430, #438 ## [0.3.3] (March 29, 2023) @@ -99,13 +101,13 @@ - Allow creation and linkage of device metadata from YAML #400 - Move helper functions to utils directory #386 -[0.4.4]: https://github.com/LorenFrankLab/spyglass/releases/tag/0.4.4 -[0.4.3]: https://github.com/LorenFrankLab/spyglass/releases/tag/0.4.3 -[0.4.2]: https://github.com/LorenFrankLab/spyglass/releases/tag/0.4.2 -[0.4.1]: https://github.com/LorenFrankLab/spyglass/releases/tag/0.4.1 -[0.4.0]: https://github.com/LorenFrankLab/spyglass/releases/tag/0.4.0 -[0.3.4]: https://github.com/LorenFrankLab/spyglass/releases/tag/0.3.4 -[0.3.3]: https://github.com/LorenFrankLab/spyglass/releases/tag/0.3.3 -[0.3.2]: https://github.com/LorenFrankLab/spyglass/releases/tag/0.3.2 -[0.3.1]: https://github.com/LorenFrankLab/spyglass/releases/tag/0.3.1 [0.3.0]: https://github.com/LorenFrankLab/spyglass/releases/tag/0.3.0 +[0.3.1]: https://github.com/LorenFrankLab/spyglass/releases/tag/0.3.1 +[0.3.2]: https://github.com/LorenFrankLab/spyglass/releases/tag/0.3.2 +[0.3.3]: https://github.com/LorenFrankLab/spyglass/releases/tag/0.3.3 +[0.3.4]: https://github.com/LorenFrankLab/spyglass/releases/tag/0.3.4 +[0.4.0]: https://github.com/LorenFrankLab/spyglass/releases/tag/0.4.0 +[0.4.1]: https://github.com/LorenFrankLab/spyglass/releases/tag/0.4.1 +[0.4.2]: https://github.com/LorenFrankLab/spyglass/releases/tag/0.4.2 +[0.4.3]: https://github.com/LorenFrankLab/spyglass/releases/tag/0.4.3 +[0.4.4]: https://github.com/LorenFrankLab/spyglass/releases/tag/0.4.4 diff --git a/MANIFEST.in b/MANIFEST.in index 80d188209..1e9e64a5c 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,2 +1,2 @@ graft tests -global-exclude *.py[cod] \ No newline at end of file +global-exclude *.py[cod] diff --git a/README.md b/README.md index 4b95fd508..821b7c4a3 100644 --- a/README.md +++ b/README.md @@ -24,7 +24,8 @@ directory. We strongly recommend opening them in the context of `jupyterlab`. 
## Contributing -See the [Developer's Note](https://lorenfranklab.github.io/spyglass/latest/contribute/) +See the +[Developer's Note](https://lorenfranklab.github.io/spyglass/latest/contribute/) for contributing instructions found at - [https://lorenfranklab.github.io/spyglass/latest/contribute/](https://lorenfranklab.github.io/spyglass/latest/contribute/) diff --git a/config/add_dj_collaborator.py b/config/add_dj_collaborator.py old mode 100644 new mode 100755 diff --git a/config/add_dj_guest.py b/config/add_dj_guest.py old mode 100644 new mode 100755 diff --git a/config/add_dj_module.py b/config/add_dj_module.py old mode 100644 new mode 100755 diff --git a/config/dj_config.py b/config/dj_config.py old mode 100644 new mode 100755 diff --git a/docs/README.md b/docs/README.md index fe480e486..80510daed 100644 --- a/docs/README.md +++ b/docs/README.md @@ -16,8 +16,8 @@ The remainder of `mkdocs.yml` specifies the site's ## GitHub Whenever a new tag is pushed, GitHub actions will run -`.github/workflows/publish-docs.yml`. Progress can be monitored in the -'Actions' tab within the repo. +`.github/workflows/publish-docs.yml`. Progress can be monitored in the 'Actions' +tab within the repo. Releases should be tagged with `X.Y.Z`. A tag to redeploy docs should use the current version, with an alpha release suffix, e.g. `X.Y.Za1`. @@ -45,9 +45,9 @@ Notably, this will make a copy of notebooks in `docs/src/notebooks`. Changes to the root notebooks directory may not be reflected when rebuilding. Use a browser to navigate to `localhost:8000/` to inspect the site. For -auto-reload of markdown files during development, use `mkdocs serve -f -./docs/mkdosc.yaml`. The `mike` package used in the build script manages -versioning, but does not support dynamic versioning. +auto-reload of markdown files during development, use +`mkdocs serve -f ./docs/mkdosc.yaml`. The `mike` package used in the build +script manages versioning, but does not support dynamic versioning. The following items can be commented out in `mkdocs.yml` to reduce build time: diff --git a/docs/build-docs.sh b/docs/build-docs.sh old mode 100644 new mode 100755 index d5b2e0fc7..03d28c07e --- a/docs/build-docs.sh +++ b/docs/build-docs.sh @@ -15,12 +15,12 @@ cp -r ./notebook-images ./docs/src/notebooks/ cp -r ./notebook-images ./docs/src/ # Get major version -FULL_VERSION=$(hatch version) # Most recent tag, may include periods -export MAJOR_VERSION="${FULL_VERSION:0:3}" # First 3 chars of tag +FULL_VERSION=$(hatch version) # Most recent tag, may include periods +export MAJOR_VERSION="${FULL_VERSION:0:3}" # First 3 chars of tag echo "$MAJOR_VERSION" -# Get ahead of errors -export JUPYTER_PLATFORM_DIRS=1 +# Get ahead of errors +export JUPYTER_PLATFORM_DIRS=1 # jupyter notebook --generate-config # Generate site docs diff --git a/docs/overrides/nav.html b/docs/overrides/nav.html index 216e139cf..b47ec9f90 100644 --- a/docs/overrides/nav.html +++ b/docs/overrides/nav.html @@ -5,5 +5,3 @@ {% if "toc.integrate" in features %} {% set class = class ~ " md-nav--integrated" %} {% endif %} - - diff --git a/docs/src/api/index.md b/docs/src/api/index.md index 0f5bf479d..d616c0757 100644 --- a/docs/src/api/index.md +++ b/docs/src/api/index.md @@ -4,8 +4,8 @@ The files in this directory are automatically generated from the docstrings in the source code. They include descriptions of each of the DataJoint tables and other classes/methods within Spyglass. -These docs are updated any time a new release is made or a tag is -pushed to the repository. 
+These docs are updated any time a new release is made or a tag is pushed to the +repository. ## Developer note The `py_scripts` directory contains the same notebook data in `.py` form to -facilitate GitHub PR reviews. To update them, run the following from the -root Spyglass directory +facilitate GitHub PR reviews. To update them, run the following from the root +Spyglass directory ```bash pip install jupytext diff --git a/notebooks/py_scripts/22_Position_DLC_2.py b/notebooks/py_scripts/22_Position_DLC_2.py index 1b0fe315d..b47e40469 100644 --- a/notebooks/py_scripts/22_Position_DLC_2.py +++ b/notebooks/py_scripts/22_Position_DLC_2.py @@ -40,7 +40,6 @@ # + import os import datajoint as dj -from pprint import pprint # change to the upper level folder to detect dj_local_conf.json if os.path.basename(os.getcwd()) == "notebooks": diff --git a/notebooks/py_scripts/30_Ripple_Detection.py b/notebooks/py_scripts/30_Ripple_Detection.py index 03494e913..dd1b14fc1 100644 --- a/notebooks/py_scripts/30_Ripple_Detection.py +++ b/notebooks/py_scripts/30_Ripple_Detection.py @@ -36,10 +36,8 @@ # + import os -import copy import datajoint as dj import numpy as np -import pandas as pd # change to the upper level folder to detect dj_local_conf.json if os.path.basename(os.getcwd()) == "notebooks": diff --git a/notebooks/py_scripts/31_Extract_Mark_Indicators.py b/notebooks/py_scripts/31_Extract_Mark_Indicators.py index f4d4ed3ed..73c9ae399 100644 --- a/notebooks/py_scripts/31_Extract_Mark_Indicators.py +++ b/notebooks/py_scripts/31_Extract_Mark_Indicators.py @@ -57,7 +57,6 @@ # + import os import datajoint as dj -from pprint import pprint # change to the upper level folder to detect dj_local_conf.json if os.path.basename(os.getcwd()) == "notebooks": diff --git a/pyproject.toml b/pyproject.toml index 0b1511d98..76acb9ca3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -78,7 +78,7 @@ test = [ "kachery-cloud", ] docs = [ - "hatch", # Get version from env + "hatch", # Get version from env "mike", # Docs versioning "mkdocs", # Docs core "mkdocs-exclude", # Docs exclude files diff --git a/src/spyglass/cli/cli.py b/src/spyglass/cli/cli.py index ba150c766..5ccc7aaee 100644 --- a/src/spyglass/cli/cli.py +++ b/src/spyglass/cli/cli.py @@ -1,4 +1,4 @@ -from typing import List, Union +from typing import Union import click import yaml diff --git a/src/spyglass/common/__init__.py b/src/spyglass/common/__init__.py index abbd7d96a..47c00beb0 100644 --- a/src/spyglass/common/__init__.py +++ b/src/spyglass/common/__init__.py @@ -1,5 +1,3 @@ -import os - import spyglass as sg from ..utils.nwb_helper_fn import ( @@ -12,11 +10,11 @@ get_valid_intervals, ) from .common_behav import ( + PositionIntervalMap, PositionSource, RawPosition, StateScriptFile, VideoFile, - PositionIntervalMap, convert_epoch_interval_name_to_position_interval_name, ) from .common_device import ( diff --git a/src/spyglass/common/common_behav.py b/src/spyglass/common/common_behav.py index ee0d48f31..4f35bcd9e 100644 --- a/src/spyglass/common/common_behav.py +++ b/src/spyglass/common/common_behav.py @@ -611,7 +611,7 @@ def get_interval_list_name_from_epoch(nwb_file_name: str, epoch: int) -> str: if len(interval_names) != 1: print( - f"Found {len(interval_name)} interval list names found for " + f"Found {len(interval_names)} interval list names found for " + f"{nwb_file_name} epoch {epoch}" ) return None diff --git a/src/spyglass/common/common_nwbfile.py b/src/spyglass/common/common_nwbfile.py index 8ee62c834..f6ee9c49b 100644 --- 
a/src/spyglass/common/common_nwbfile.py +++ b/src/spyglass/common/common_nwbfile.py @@ -1,5 +1,4 @@ import os -import pathlib import random import stat import string @@ -309,7 +308,7 @@ def get_abs_path(analysis_nwb_file_name): analysis_nwb_file_abspath : str The absolute path for the given file name. """ - base_dir = pathlib.Path(os.getenv("SPYGLASS_BASE_DIR", None)) + base_dir = Path(os.getenv("SPYGLASS_BASE_DIR", None)) assert ( base_dir is not None ), "You must set SPYGLASS_BASE_DIR environment variable." @@ -658,6 +657,8 @@ class NwbfileKachery(dj.Computed): """ def make(self, key): + import kachery_client as kc + print(f'Linking {key["nwb_file_name"]} and storing in kachery...') key["nwb_file_uri"] = kc.link_file( Nwbfile().get_abs_path(key["nwb_file_name"]) @@ -674,6 +675,8 @@ class AnalysisNwbfileKachery(dj.Computed): """ def make(self, key): + import kachery_client as kc + print(f'Linking {key["analysis_file_name"]} and storing in kachery...') key["analysis_file_uri"] = kc.link_file( AnalysisNwbfile().get_abs_path(key["analysis_file_name"]) diff --git a/src/spyglass/common/common_session.py b/src/spyglass/common/common_session.py index 7824fafdf..ac1950408 100644 --- a/src/spyglass/common/common_session.py +++ b/src/spyglass/common/common_session.py @@ -1,5 +1,3 @@ -import os - import datajoint as dj from ..settings import config, debug_mode diff --git a/src/spyglass/decoding/clusterless.py b/src/spyglass/decoding/clusterless.py index b530786c4..7cc149eda 100644 --- a/src/spyglass/decoding/clusterless.py +++ b/src/spyglass/decoding/clusterless.py @@ -5,7 +5,6 @@ ---------- [1] Denovellis, E. L. et al. Hippocampal replay of experience at real-world speeds. eLife 10, e64505 (2021). - """ import os @@ -667,58 +666,6 @@ def insert_default(self): ) -""" -NOTE: Table decommissioned. See #630, #664. Excessive key length. - -class MultiunitHighSynchronyEvents(dj.Computed): - "Finds times of high mulitunit activity during immobility." 
- - definition = " - -> MultiunitHighSynchronyEventsParameters - -> UnitMarksIndicator - -> IntervalPositionInfo - --- - -> AnalysisNwbfile - multiunit_hse_times_object_id: varchar(40) - " - - def make(self, key): - marks = (UnitMarksIndicator & key).fetch_xarray() - multiunit_spikes = (np.any(~np.isnan(marks.values), axis=1)).astype( - float - ) - position_info = (IntervalPositionInfo() & key).fetch1_dataframe() - - params = (MultiunitHighSynchronyEventsParameters & key).fetch1() - - multiunit_high_synchrony_times = multiunit_HSE_detector( - marks.time.values, - multiunit_spikes, - position_info.head_speed.values, - sampling_frequency=key["sampling_rate"], - **params, - ) - - # Insert into analysis nwb file - nwb_analysis_file = AnalysisNwbfile() - key["analysis_file_name"] = nwb_analysis_file.create( - key["nwb_file_name"] - ) - - key["multiunit_hse_times_object_id"] = nwb_analysis_file.add_nwb_object( - analysis_file_name=key["analysis_file_name"], - nwb_object=multiunit_high_synchrony_times.reset_index(), - ) - - nwb_analysis_file.add( - nwb_file_name=key["nwb_file_name"], - analysis_file_name=key["analysis_file_name"], - ) - - self.insert1(key) -""" - - def get_decoding_data_for_epoch( nwb_file_name: str, interval_list_name: str, diff --git a/src/spyglass/lfp/lfp_merge.py b/src/spyglass/lfp/lfp_merge.py index 36b0f696b..265319ca0 100644 --- a/src/spyglass/lfp/lfp_merge.py +++ b/src/spyglass/lfp/lfp_merge.py @@ -4,8 +4,8 @@ from spyglass.common.common_ephys import LFP as CommonLFP # noqa: F401 from spyglass.common.common_filter import FirFilterParameters # noqa: F401 from spyglass.common.common_interval import IntervalList # noqa: F401 -from spyglass.lfp.v1.lfp import LFPV1 # noqa: F401 from spyglass.lfp.lfp_imported import ImportedLFP # noqa: F401 +from spyglass.lfp.v1.lfp import LFPV1 # noqa: F401 from spyglass.utils.dj_merge_tables import _Merge schema = dj.schema("lfp_merge") @@ -19,21 +19,21 @@ class LFPOutput(_Merge): source: varchar(32) """ - class LFPV1(dj.Part): + class LFPV1(dj.Part): # noqa: F811 definition = """ -> master --- -> LFPV1 """ - class ImportedLFP(dj.Part): + class ImportedLFP(dj.Part): # noqa: F811 definition = """ -> master --- -> ImportedLFP """ - class CommonLFP(dj.Part): + class CommonLFP(dj.Part): # noqa: F811 """Table to pass-through legacy LFP""" definition = """ diff --git a/src/spyglass/position/position_merge.py b/src/spyglass/position/position_merge.py index 86de6194f..c87383581 100644 --- a/src/spyglass/position/position_merge.py +++ b/src/spyglass/position/position_merge.py @@ -6,7 +6,6 @@ import numpy as np import pandas as pd from datajoint.utils import to_camel_case -from tqdm import tqdm as tqdm from ..common.common_position import IntervalPositionInfo as CommonPos from ..utils.dj_merge_tables import _Merge @@ -207,6 +206,7 @@ def make(self, key): ) print("Loading video data...") + epoch = int("".join(filter(str.isdigit, key["interval_list_name"]))) + 1 ( video_path, @@ -214,13 +214,7 @@ def make(self, key): meters_per_pixel, video_time, ) = get_video_path( - { - "nwb_file_name": key["nwb_file_name"], - "epoch": int( - "".join(filter(str.isdigit, key["interval_list_name"])) - ) - + 1, - } + {"nwb_file_name": key["nwb_file_name"], "epoch": epoch} ) video_dir = os.path.dirname(video_path) + "/" video_frame_col_name = [ diff --git a/src/spyglass/position/v1/dlc_utils.py b/src/spyglass/position/v1/dlc_utils.py index 3abfb8f72..9894b4d98 100644 --- a/src/spyglass/position/v1/dlc_utils.py +++ b/src/spyglass/position/v1/dlc_utils.py @@ -21,6 +21,103 @@ 
 from ...settings import raw_dir


+def validate_option(
+    option=None,
+    options: list = None,
+    name="option",
+    types: tuple = None,
+    val_range: tuple = None,
+    permit_none=False,
+):
+    """Validate that option is in a list of options or a list of types.
+
+    Parameters
+    ----------
+    option : str, optional
+        If None and permit_none is True, runs no checks.
+    options : list, optional
+        If provided, option must be in options.
+    name : str, optional
+        If provided, name of option to use in error message.
+    types : tuple, optional
+        If provided, option must be an instance of one of the types in types.
+    val_range : tuple, optional
+        If provided, option must be in range (min, max).
+    permit_none : bool, optional
+        If True, permit option to be None. Default False.
+
+    Raises
+    ------
+    ValueError, KeyError, TypeError
+        If option fails the requested checks.
+    """
+    if option is None:
+        if not permit_none:
+            raise ValueError(f"{name} cannot be None")
+        return  # None is permitted; skip the remaining checks
+
+    if options and option not in options:
+        raise KeyError(
+            f"Unknown {name}: {option} " f"Available options: {options}"
+        )
+
+    if types and not isinstance(option, tuple(types)):
+        raise TypeError(f"{name} is {type(option)}. Available types {types}")
+
+    if val_range and not (val_range[0] <= option <= val_range[1]):
+        raise ValueError(f"{name} must be in range {val_range}")
+
+
+def validate_list(
+    required_items: list,
+    option_list: list = None,
+    name="List",
+    condition="",
+    permit_none=False,
+):
+    """Validate that option_list contains all items in required_items.
+
+    Parameters
+    ----------
+    required_items : list
+    option_list : list, optional
+        If provided, option_list must contain all items in required_items.
+    name : str, optional
+        If provided, name of option_list to use in error message.
+    condition : str, optional
+        If provided, condition in error message as 'when using X'.
+    permit_none : bool, optional
+        If True, permit option_list to be None. Default False.
+    """
+    if option_list is None:
+        if permit_none:
+            return
+        else:
+            raise ValueError(f"{name} cannot be None")
+    if condition:
+        condition = f" when using {condition}"
+    if any(x not in option_list for x in required_items):
+        raise KeyError(
+            f"{name} must contain all items in {required_items}{condition}."
+ ) + + +def validate_smooth_params(params): + """If params['smooth'], validate method is in list and duration type""" + if not params.get("smooth"): + return + smoothing_params = params.get("smoothing_params") + validate_option(smoother=smoothing_params, name="smoothing_params") + validate_option( + option=smoothing_params.get("smooth_method"), + name="smooth_method", + options=_key_to_smooth_func_dict, + ) + validate_option( + option=smoothing_params.get("smoothing_duration"), + name="smoothing_duration", + types=(int, float), + ) + + def _set_permissions(directory, mode, username: str, groupname: str = None): """ Use to recursively set ownership and permissions for @@ -704,7 +801,7 @@ def make_video( RGB_PINK = (234, 82, 111) RGB_YELLOW = (253, 231, 76) - RGB_WHITE = (255, 255, 255) + # RGB_WHITE = (255, 255, 255) RGB_BLUE = (30, 144, 255) RGB_ORANGE = (255, 127, 80) # "#29ff3e", @@ -751,10 +848,12 @@ def make_video( key: fill_nan( position_mean[key]["orientation"], video_time, position_time ) - for key in orientation_mean.keys() + for key in position_mean.keys() + # CBroz: Bug was here, using nonexistent orientation_mean dict } print( - f"frames start: {frames[0]}\nvideo_frames start: {video_frame_inds[0]}\ncv2 frame ind start: {int(video.get(1))}" + f"frames start: {frames[0]}\nvideo_frames start: " + + f"{video_frame_inds[0]}\ncv2 frame ind start: {int(video.get(1))}" ) for time_ind in tqdm( frames, desc="frames", disable=disable_progressbar @@ -923,12 +1022,9 @@ def make_video( position_mean = position_mean["DLC"] orientation_mean = orientation_mean["DLC"] - frame_offset = -1 - time_slice = [] video_slowdown = 1 - vmax = 0.07 # ? - # Set up formatting for the movie files + # Set up formatting for the movie files window_size = 501 if likelihoods: plot_likelihood = True @@ -1036,16 +1132,7 @@ def make_video( f"time = {time_delta:3.4f}s\n frame = {frame_ind}", fontsize=8, ) - fontprops = fm.FontProperties(size=12) - # scalebar = AnchoredSizeBar(axes[0].transData, - # 20, '20 cm', 'lower right', - # pad=0.1, - # color='white', - # frameon=False, - # size_vertical=1, - # fontproperties=fontprops) - - # axes[0].add_artist(scalebar) + _ = fm.FontProperties(size=12) axes[0].axis("off") if plot_likelihood: likelihood_objs = { @@ -1175,10 +1262,6 @@ def _update_plot(time_ind): centroid_position_dot, orientation_line, title, - # redC_likelihood, - # green_likelihood, - # redL_likelihood, - # redR_likelihood, ) movie = animation.FuncAnimation( diff --git a/src/spyglass/position/v1/position_dlc_centroid.py b/src/spyglass/position/v1/position_dlc_centroid.py index 351f577aa..d1e7e6dba 100644 --- a/src/spyglass/position/v1/position_dlc_centroid.py +++ b/src/spyglass/position/v1/position_dlc_centroid.py @@ -9,7 +9,14 @@ from ...common.common_behav import RawPosition from ...common.common_nwbfile import AnalysisNwbfile from ...utils.dj_helper_fn import fetch_nwb -from .dlc_utils import _key_to_smooth_func_dict, get_span_start_stop, interp_pos +from .dlc_utils import ( + _key_to_smooth_func_dict, + get_span_start_stop, + interp_pos, + validate_list, + validate_option, + validate_smooth_params, +) from .position_dlc_cohort import DLCSmoothInterpCohort from .position_dlc_position import DLCSmoothInterpParams @@ -30,14 +37,6 @@ class DLCCentroidParams(dj.Manual): params: longblob """ - _available_centroid_methods = [ - "four_led_centroid", - "two_pt_centroid", - "one_pt_centroid", - ] - _four_led_labels = ["greenLED", "redLED_L", "redLED_C", "redLED_R"] - _two_pt_labels = ["point1", "point2"] - 
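The helpers defined above (together with `validate_smooth_params`) replace the nested per-table checks this patch strips out of `DLCCentroidParams` and `DLCSmoothInterpParams`. As a hedged sketch of the resulting call pattern only — the `params` values and point labels below are invented for illustration, not taken from the patch:

```python
from spyglass.position.v1.dlc_utils import validate_list, validate_option

# An invented centroid params dict of the shape DLCCentroidParams expects.
params = {
    "centroid_method": "two_pt_centroid",
    "points": {"point1": "greenLED", "point2": "redLED_C"},  # invented labels
    "max_LED_separation": 12,
    "smooth": True,
    "smoothing_params": {"smoothing_duration": 0.05},
}

# Membership check: the method must be one the pipeline knows about.
validate_option(
    option=params["centroid_method"],
    options=["four_led_centroid", "two_pt_centroid", "one_pt_centroid"],
    name="centroid_method",
)

# Required-point check: each method dictates the point labels it needs
# (see the _key_to_points mapping added at the end of this file's diff).
validate_list(
    required_items=["point1", "point2"],
    option_list=params["points"],
    name="points",
    condition=params["centroid_method"],
)

# Type checks collapse to one-liners; None is tolerated where a parameter
# is optional.
validate_option(
    option=params.get("max_LED_separation"),
    name="max_LED_separation",
    types=(int, float),
    permit_none=True,
)
validate_option(
    option=params["smoothing_params"].get("smoothing_duration"),
    name="smoothing_duration",
    types=(int, float),
)
```

Each failed check raises a ValueError, KeyError, or TypeError naming the offending parameter, which is what the refactored `insert1` and `make` methods below rely on.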
@classmethod def insert_default(cls, **kwargs): """ @@ -79,75 +78,27 @@ def insert1(self, key, **kwargs): it contains all necessary items """ params = key["params"] - if "centroid_method" in params: - if params["centroid_method"] in self._available_centroid_methods: - if params["centroid_method"] == "four_led_centroid": - if any( - x not in self._four_led_labels for x in params["points"] - ): - raise KeyError( - f"Please make sure to specify all necessary labels: " - f"{self._four_led_labels} " - f"if using the 'four_led_centroid' method" - ) - elif params["centroid_method"] == "two_pt_centroid": - if any( - x not in self._two_pt_labels for x in params["points"] - ): - raise KeyError( - f"Please make sure to specify all necessary labels: " - f"{self._two_pt_labels} " - f"if using the 'two_pt_centroid' method" - ) - elif params["centroid_method"] == "one_pt_centroid": - if "point1" not in params["points"]: - raise KeyError( - "Please make sure to specify the necessary label: " - "'point1' " - "if using the 'one_pt_centroid' method" - ) - else: - raise Exception("This shouldn't happen lol oops") - else: - raise ValueError( - f"The given 'centroid_method': {params['centroid_method']} " - f"is not in the available methods: " - f"{self._available_centroid_methods}" - ) - else: - raise KeyError( - "'centroid_method' needs to be provided as a parameter" - ) + centroid_method = params.get("centroid_method") + validate_option( # Ensure centroid method is valid + option=centroid_method, + options=_key_to_points.keys(), + name="centroid_method", + ) + validate_list( # Ensure points are valid for centroid method + required_items=set(params["points"].keys()), + option_list=params["points"], + name="points", + condition=centroid_method, + ) - if "max_LED_separation" in params: - if not isinstance(params["max_LED_separation"], (int, float)): - raise TypeError( - f"parameter 'max_LED_separation' is type: " - f"{type(params['max_LED_separation'])}, " - f"it should be one of type (float, int)" - ) - if "smooth" in params: - if params["smooth"]: - if "smoothing_params" in params: - if "smooth_method" in params["smoothing_params"]: - smooth_method = params["smoothing_params"][ - "smooth_method" - ] - if smooth_method not in _key_to_smooth_func_dict: - raise KeyError( - f"smooth_method: {smooth_method} not an available method." 
- ) - if not "smoothing_duration" in params["smoothing_params"]: - raise KeyError( - "smoothing_duration must be passed as a smoothing_params within key['params']" - ) - else: - assert isinstance( - params["smoothing_params"]["smoothing_duration"], - (float, int), - ), "smoothing_duration must be a float or int" - else: - raise ValueError("smoothing_params not in key['params']") + validate_option( + option=params.get("max_LED_separation"), + name="max_LED_separation", + types=(int, float), + permit_none=True, + ) + + validate_smooth_params(params) super().insert1(key, **kwargs) @@ -192,91 +143,36 @@ def make(self, key): ) as logger: logger.logger.info("-----------------------") logger.logger.info("Centroid Calculation") + # Get labels to smooth from Parameters table cohort_entries = DLCSmoothInterpCohort.BodyPart & key params = (DLCCentroidParams() & key).fetch1("params") centroid_method = params.pop("centroid_method") bodyparts_avail = cohort_entries.fetch("bodypart") speed_smoothing_std_dev = params.pop("speed_smoothing_std_dev") - # TODO, generalize key naming - if centroid_method == "four_led_centroid": - centroid_func = _key_to_func_dict[centroid_method] - if "greenLED" in params["points"]: - assert ( - params["points"]["greenLED"] in bodyparts_avail - ), f'{params["points"]["greenLED"]} not a bodypart used in this model' - else: - raise ValueError( - "A green led needs to be specified for the 4 led centroid method" - ) - if "redLED_L" in params["points"]: - assert ( - params["points"]["redLED_L"] in bodyparts_avail - ), f'{params["points"]["redLED_L"]} not a bodypart used in this model' - else: - raise ValueError( - "A left red led needs to be specified for the 4 led centroid method" - ) - if "redLED_C" in params["points"]: - assert ( - params["points"]["redLED_C"] in bodyparts_avail - ), f'{params["points"]["redLED_C"]} not a bodypart used in this model' - else: - raise ValueError( - "A center red led needs to be specified for the 4 led centroid method" - ) - if "redLED_R" in params["points"]: - assert ( - params["points"]["redLED_R"] in bodyparts_avail - ), f'{params["points"]["redLED_R"]} not a bodypart used in this model' - else: - raise ValueError( - "A right red led needs to be specified for the 4 led centroid method" - ) - bodyparts_to_use = [ - params["points"]["greenLED"], - params["points"]["redLED_L"], - params["points"]["redLED_C"], - params["points"]["redLED_R"], - ] - elif centroid_method == "two_pt_centroid": - centroid_func = _key_to_func_dict[centroid_method] - if "point1" in params["points"]: - assert ( - params["points"]["point1"] in bodyparts_avail - ), f'{params["points"]["point1"]} not a bodypart used in this model' - else: - raise ValueError( - "point1 needs to be specified for the 2 pt centroid method" - ) - if "point2" in params["points"]: - assert ( - params["points"]["point2"] in bodyparts_avail - ), f'{params["points"]["point2"]} not a bodypart used in this model' - else: - raise ValueError( - "point2 needs to be specified for the 2 pt centroid method" - ) - bodyparts_to_use = [ - params["points"]["point1"], - params["points"]["point2"], - ] - - elif centroid_method == "one_pt_centroid": - centroid_func = _key_to_func_dict[centroid_method] - if "point1" in params["points"]: - assert ( - params["points"]["point1"] in bodyparts_avail - ), f'{params["points"]["point1"]} not a bodypart used in this model' - else: + if not centroid_method: + raise ValueError("Please specify a centroid method to use.") + validate_option(option=centroid_method, 
options=_key_to_func_dict) + + points = params.get("points") + required_points = _key_to_points.get(centroid_method) + validate_list( + required_items=required_points, + option_list=points, + name="params points", + condition=centroid_method, + ) + for point in required_points: + bodypart = points[point] + if bodypart not in bodyparts_avail: raise ValueError( - "point1 needs to be specified for the 1 pt centroid method" + "Bodypart in points not in model." + f"\tBodypart {bodypart}" + f"\tIn Model {bodyparts_avail}" ) - bodyparts_to_use = [params["points"]["point1"]] + bodyparts_to_use = [points[point] for point in required_points] - else: - raise ValueError("Please specify a centroid method to use.") pos_df = pd.concat( { bodypart: ( @@ -292,6 +188,7 @@ def make(self, key): logger.logger.info( "Calculating centroid with %s", str(centroid_method) ) + centroid_func = _key_to_func_dict.get(centroid_method) centroid = centroid_func(pos_df, **params) centroid_df = pd.DataFrame( centroid, @@ -317,30 +214,29 @@ def make(self, key): else: interp_df = centroid_df.copy() if params["smooth"]: - if "smoothing_duration" in params["smoothing_params"]: - smoothing_duration = params["smoothing_params"].pop( - "smoothing_duration" - ) - dt = np.median(np.diff(pos_df.index.to_numpy())) - sampling_rate = 1 / dt - logger.logger.info("smoothing position") - smooth_func = _key_to_smooth_func_dict[ - params["smoothing_params"]["smooth_method"] - ] - logger.logger.info( - "Smoothing using method: %s", - str(params["smoothing_params"]["smooth_method"]), - ) - final_df = smooth_func( - interp_df, - smoothing_duration=smoothing_duration, - sampling_rate=sampling_rate, - **params["smoothing_params"], - ) - else: + smoothing_duration = params["smoothing_params"].get( + "smoothing_duration" + ) + if not smoothing_duration: raise KeyError( "smoothing_duration needs to be passed within smoothing_params" ) + dt = np.median(np.diff(pos_df.index.to_numpy())) + sampling_rate = 1 / dt + logger.logger.info("smoothing position") + smooth_func = _key_to_smooth_func_dict[ + params["smoothing_params"]["smooth_method"] + ] + logger.logger.info( + "Smoothing using method: %s", + str(params["smoothing_params"]["smooth_method"]), + ) + final_df = smooth_func( + interp_df, + smoothing_duration=smoothing_duration, + sampling_rate=sampling_rate, + **params["smoothing_params"], + ) else: final_df = interp_df.copy() logger.logger.info("getting velocity") @@ -404,15 +300,18 @@ def make(self, key): comments="no comments", ) # Add to Analysis NWB file - key["analysis_file_name"] = AnalysisNwbfile().create( - key["nwb_file_name"] - ) + analysis_file_name = AnalysisNwbfile().create(key["nwb_file_name"]) nwb_analysis_file = AnalysisNwbfile() - key["dlc_position_object_id"] = nwb_analysis_file.add_nwb_object( - key["analysis_file_name"], position - ) - key["dlc_velocity_object_id"] = nwb_analysis_file.add_nwb_object( - key["analysis_file_name"], velocity + key.update( + { + "analysis_file_name": analysis_file_name, + "dlc_position_object_id": nwb_analysis_file.add_nwb_object( + analysis_file_name, position + ), + "dlc_velocity_object_id": nwb_analysis_file.add_nwb_object( + analysis_file_name, velocity + ), + } ) nwb_analysis_file.add( @@ -910,3 +809,8 @@ def one_pt_centroid(pos_df: pd.DataFrame, **params): "two_pt_centroid": two_pt_centroid, "one_pt_centroid": one_pt_centroid, } +_key_to_points = { + "four_led_centroid": ["greenLED", "redLED_L", "redLED_C", "redLED_R"], + "two_pt_centroid": ["point1", "point2"], + "one_pt_centroid": ["point1"], 
+} diff --git a/src/spyglass/position/v1/position_dlc_model.py b/src/spyglass/position/v1/position_dlc_model.py index 51cecb4ea..422e4f665 100644 --- a/src/spyglass/position/v1/position_dlc_model.py +++ b/src/spyglass/position/v1/position_dlc_model.py @@ -166,7 +166,7 @@ class DLCModel(dj.Computed): """ # project_path is the only item required downstream in the pose schema - class BodyPart(dj.Part): + class BodyPart(dj.Part): # noqa: F811 definition = """ -> DLCModel -> BodyPart @@ -175,93 +175,85 @@ class BodyPart(dj.Part): def make(self, key): from deeplabcut.utils.auxiliaryfunctions import GetScorerName - from .dlc_utils import OutputLogger - _, model_name, table_source = (DLCModelSource & key).fetch1().values() SourceTable = getattr(DLCModelSource, table_source) params = (DLCModelParams & key).fetch1("params") project_path = (SourceTable & key).fetch1("project_path") - with OutputLogger( - name="DLC_project_{project_name}_model_{model_name}", - path=f"{Path(project_path).as_posix()}/log.log", - ) as logger: - if not isinstance(project_path, PosixPath): - project_path = Path(project_path) - config_query = PurePath(project_path, Path("*config.y*ml")) - available_config = glob.glob(config_query.as_posix()) - dj_config = [path for path in available_config if "dj_dlc" in path] - if len(dj_config) > 0: - config_path = Path(dj_config[0]) - elif len(available_config) == 1: - config_path = Path(available_config[0]) - else: - config_path = PurePath(project_path, Path("config.yaml")) - if not config_path.exists(): - raise OSError(f"config_path {config_path} does not exist.") - if config_path.suffix in (".yml", ".yaml"): - with open(config_path, "rb") as f: - dlc_config = yaml.safe_load(f) - if isinstance(params["params"], dict): - dlc_config.update(params["params"]) - del params["params"] - # TODO: clean-up. 
this feels sloppy - shuffle = params.pop("shuffle", 1) - trainingsetindex = params.pop("trainingsetindex", None) - if not isinstance(trainingsetindex, int): - raise KeyError("no trainingsetindex specified in key") - model_prefix = params.pop("model_prefix", "") - model_description = params.pop("model_description", model_name) - paramset_name = params.pop("dlc_training_params_name", None) - - needed_attributes = [ - "Task", - "date", - "iteration", - "snapshotindex", - "TrainingFraction", - ] - for attribute in needed_attributes: - assert ( - attribute in dlc_config - ), f"Couldn't find {attribute} in config" - - scorer_legacy = str_to_bool(dlc_config.get("scorer_legacy", "f")) - - dlc_scorer = GetScorerName( - cfg=dlc_config, - shuffle=shuffle, - trainFraction=dlc_config["TrainingFraction"][ - int(trainingsetindex) - ], - modelprefix=model_prefix, - )[scorer_legacy] - if dlc_config["snapshotindex"] == -1: - dlc_scorer = "".join(dlc_scorer.split("_")[:-1]) - - # ---- Insert ---- - model_dict = { - "dlc_model_name": model_name, - "model_description": model_description, - "scorer": dlc_scorer, - "task": dlc_config["Task"], - "date": dlc_config["date"], - "iteration": dlc_config["iteration"], - "snapshotindex": dlc_config["snapshotindex"], - "shuffle": shuffle, - "trainingsetindex": int(trainingsetindex), - "project_path": project_path, - "config_template": dlc_config, - } - part_key = key.copy() - key.update(model_dict) - # ---- Save DJ-managed config ---- - _ = dlc_reader.save_yaml(project_path, dlc_config) - - # ____ Insert into table ---- - self.insert1(key) - self.BodyPart.insert( - {**part_key, "bodypart": bp} for bp in dlc_config["bodyparts"] - ) + if not isinstance(project_path, PosixPath): + project_path = Path(project_path) + config_query = PurePath(project_path, Path("*config.y*ml")) + available_config = glob.glob(config_query.as_posix()) + dj_config = [path for path in available_config if "dj_dlc" in path] + if len(dj_config) > 0: + config_path = Path(dj_config[0]) + elif len(available_config) == 1: + config_path = Path(available_config[0]) + else: + config_path = PurePath(project_path, Path("config.yaml")) + if not config_path.exists(): + raise OSError(f"config_path {config_path} does not exist.") + if config_path.suffix in (".yml", ".yaml"): + with open(config_path, "rb") as f: + dlc_config = yaml.safe_load(f) + if isinstance(params["params"], dict): + dlc_config.update(params["params"]) + del params["params"] + # TODO: clean-up. 
this feels sloppy + shuffle = params.pop("shuffle", 1) + trainingsetindex = params.pop("trainingsetindex", None) + if not isinstance(trainingsetindex, int): + raise KeyError("no trainingsetindex specified in key") + model_prefix = params.pop("model_prefix", "") + model_description = params.pop("model_description", model_name) + _ = params.pop("dlc_training_params_name", None) + + needed_attributes = [ + "Task", + "date", + "iteration", + "snapshotindex", + "TrainingFraction", + ] + for attribute in needed_attributes: + assert ( + attribute in dlc_config + ), f"Couldn't find {attribute} in config" + + scorer_legacy = str_to_bool(dlc_config.get("scorer_legacy", "f")) + + dlc_scorer = GetScorerName( + cfg=dlc_config, + shuffle=shuffle, + trainFraction=dlc_config["TrainingFraction"][int(trainingsetindex)], + modelprefix=model_prefix, + )[scorer_legacy] + if dlc_config["snapshotindex"] == -1: + dlc_scorer = "".join(dlc_scorer.split("_")[:-1]) + + # ---- Insert ---- + model_dict = { + "dlc_model_name": model_name, + "model_description": model_description, + "scorer": dlc_scorer, + "task": dlc_config["Task"], + "date": dlc_config["date"], + "iteration": dlc_config["iteration"], + "snapshotindex": dlc_config["snapshotindex"], + "shuffle": shuffle, + "trainingsetindex": int(trainingsetindex), + "project_path": project_path, + "config_template": dlc_config, + } + part_key = key.copy() + key.update(model_dict) + # ---- Save DJ-managed config ---- + _ = dlc_reader.save_yaml(project_path, dlc_config) + + # ____ Insert into table ---- + self.insert1(key) + self.BodyPart.insert( + {**part_key, "bodypart": bp} for bp in dlc_config["bodyparts"] + ) print( f"Finished inserting {model_name}, training iteration" f" {dlc_config['iteration']} into DLCModel" diff --git a/src/spyglass/position/v1/position_dlc_position.py b/src/spyglass/position/v1/position_dlc_position.py index 2a5060ab5..230779929 100644 --- a/src/spyglass/position/v1/position_dlc_position.py +++ b/src/spyglass/position/v1/position_dlc_position.py @@ -5,7 +5,13 @@ from ...common.common_nwbfile import AnalysisNwbfile from ...utils.dj_helper_fn import fetch_nwb -from .dlc_utils import _key_to_smooth_func_dict, get_span_start_stop, interp_pos +from .dlc_utils import ( + _key_to_smooth_func_dict, + get_span_start_stop, + interp_pos, + validate_option, + validate_smooth_params, +) from .position_dlc_pose_estimation import DLCPoseEstimation schema = dj.schema("position_v1_dlc_position") @@ -58,8 +64,8 @@ def insert_default(cls, **kwargs): "likelihood_thresh": 0.95, "interp_params": {"max_cm_to_interp": 15}, "max_cm_between_pts": 20, - # This is for use when finding "good spans" and is how many indices to bridge in between good spans - # see inds_to_span in get_good_spans + # This is for use when finding "good spans" and is how many indices + # to bridge in between good spans see inds_to_span in get_good_spans "num_inds_to_span": 20, } cls.insert1( @@ -105,56 +111,23 @@ def get_available_methods(): return _key_to_smooth_func_dict.keys() def insert1(self, key, **kwargs): - if "params" in key: - if not "max_cm_between_pts" in key["params"]: - raise KeyError("max_cm_between_pts is a required parameter") - if "smooth" in key["params"]: - if key["params"]["smooth"]: - if "smoothing_params" in key["params"]: - if "smooth_method" in key["params"]["smoothing_params"]: - smooth_method = key["params"]["smoothing_params"][ - "smooth_method" - ] - if smooth_method not in _key_to_smooth_func_dict: - raise KeyError( - f"smooth_method: {smooth_method} not an 
available method." - ) - if ( - not "smoothing_duration" - in key["params"]["smoothing_params"] - ): - raise KeyError( - "smoothing_duration must be passed as a smoothing_params within key['params']" - ) - else: - assert isinstance( - key["params"]["smoothing_params"][ - "smoothing_duration" - ], - (float, int), - ), "smoothing_duration must be a float or int" - else: - raise ValueError( - "smoothing_params not in key['params']" - ) - if "likelihood_thresh" in key["params"]: - assert isinstance( - key["params"]["likelihood_thresh"], - float, - ), "likelihood_thresh must be a float" - assert ( - 0 < key["params"]["likelihood_thresh"] < 1 - ), "likelihood_thresh must be between 0 and 1" - else: - raise ValueError( - "likelihood_thresh must be passed within key['params']" - ) - else: - raise KeyError("'params' must be in key") - super().insert1(key, **kwargs) + params = key.get("params") + if not isinstance(params, dict): + raise KeyError("'params' must be a dict in key") - # def delete(self, key, **kwargs): - # super().delete(key, **kwargs) + validate_option( + option=params.get("max_cm_between_pts"), name="max_cm_between_pts" + ) + validate_smooth_params(params) + + validate_option( + params.get("likelihood_thresh"), + name="likelihood_thresh", + types=(float), + val_range=(0, 1), + ) + + super().insert1(key, **kwargs) @schema @@ -343,16 +316,21 @@ def nan_inds( inds_to_span: int, ): idx = pd.IndexSlice - # Could either NaN sub-likelihood threshold inds here and then not consider in jumping... - # OR just keep in back pocket when checking jumps against last good point + + # Could either NaN sub-likelihood threshold inds here and then not consider + # in jumping... OR just keep in back pocket when checking jumps against + # last good point + subthresh_inds = get_subthresh_inds( dlc_df, likelihood_thresh=likelihood_thresh ) df_subthresh_indices = dlc_df.index[subthresh_inds] dlc_df.loc[idx[df_subthresh_indices], idx[("x", "y")]] = np.nan - # To further determine which indices are the original point and which are jump points - # There could be a more efficient method of doing this - # screen inds for jumps to baseline + + # To further determine which indices are the original point and which are + # jump points. There could be a more efficient method of doing this screen + # inds for jumps to baseline + subthresh_inds_mask = np.zeros(len(dlc_df), dtype=bool) subthresh_inds_mask[subthresh_inds] = True jump_inds_mask = np.zeros(len(dlc_df), dtype=bool) @@ -425,8 +403,9 @@ def nan_inds( def get_good_spans(bad_inds_mask, inds_to_span: int = 50): """ This function takes in a boolean mask of good and bad indices and - determines spans of consecutive good indices. It combines two neighboring spans - with a separation of less than inds_to_span and treats them as a single good span. + determines spans of consecutive good indices. It combines two neighboring + spans with a separation of less than inds_to_span and treats them as a + single good span. Parameters ---------- @@ -437,7 +416,8 @@ def get_good_spans(bad_inds_mask, inds_to_span: int = 50): be bridged to form a single good span. For instance if span A is (1500, 2350) and span B is (2370, 3700), then span A and span B would be combined into span A (1500, 3700) - since one would want to identify potential jumps in the space in between the original A and B. + since one would want to identify potential jumps in the space in between + the original A and B. 
Returns ------- @@ -490,6 +470,6 @@ def get_subthresh_inds(dlc_df: pd.DataFrame, likelihood_thresh: float): nand_inds = np.where(np.isnan(dlc_df["x"]))[0] all_nan_inds = list(set(sub_thresh_inds).union(set(nand_inds))) all_nan_inds.sort() - sub_thresh_percent = (len(sub_thresh_inds) / len(dlc_df)) * 100 # TODO: add option to return sub_thresh_percent + # sub_thresh_percent = (len(sub_thresh_inds) / len(dlc_df)) * 100 return all_nan_inds diff --git a/src/spyglass/position/v1/position_dlc_project.py b/src/spyglass/position/v1/position_dlc_project.py index 79b7df9b2..ee3094bfb 100644 --- a/src/spyglass/position/v1/position_dlc_project.py +++ b/src/spyglass/position/v1/position_dlc_project.py @@ -215,17 +215,20 @@ def insert_new_project( lab_team: str, frames_per_video: int, video_list: List, + groupname: str = None, project_directory: str = os.getenv("DLC_PROJECT_PATH"), output_path: str = os.getenv("DLC_VIDEO_PATH"), set_permissions=False, **kwargs, ): - """ - insert a new project into DLCProject table. + """Insert a new project into DLCProject table. + Parameters ---------- project_name : str user-friendly name of project + groupname : str, optional + Name for project group. If None, defaults to username bodyparts : list list of bodyparts to label. Should match bodyparts in BodyPart table lab_team : str @@ -236,9 +239,9 @@ def insert_new_project( frames_per_video : int number of frames to extract from each video video_list : list - list of dicts of form [{'nwb_file_name': nwb_file_name, 'epoch': epoch #},...] - to query VideoFile table for videos to train on. - Can also be list of absolute paths to import videos from + list of (a) dicts of to query VideoFile table for or (b) absolute + paths to videos to train on. If dict, use format: + [{'nwb_file_name': nwb_file_name, 'epoch': epoch #},...] 
output_path : str target path to output converted videos (Default is '/nimbus/deeplabcut/videos/') diff --git a/src/spyglass/position/v1/position_dlc_selection.py b/src/spyglass/position/v1/position_dlc_selection.py index fb4e2d97c..af4aa15ee 100644 --- a/src/spyglass/position/v1/position_dlc_selection.py +++ b/src/spyglass/position/v1/position_dlc_selection.py @@ -6,14 +6,13 @@ import pandas as pd import pynwb from datajoint.utils import to_camel_case -from tqdm import tqdm as tqdm -from ...common.common_nwbfile import AnalysisNwbfile from ...common.common_behav import ( convert_epoch_interval_name_to_position_interval_name, ) +from ...common.common_nwbfile import AnalysisNwbfile from ...utils.dj_helper_fn import fetch_nwb -from .dlc_utils import get_video_path, make_video +from .dlc_utils import make_video from .position_dlc_centroid import DLCCentroid from .position_dlc_cohort import DLCSmoothInterpCohort from .position_dlc_orient import DLCOrientation @@ -233,18 +232,6 @@ def evaluate_pose_estimation(cls, key): for bodypart in bodyparts if bodypart in pose_estimation_df.columns } - sub_thresh_ind_dict = { - bodypart: { - "inds": np.where( - ~np.isnan( - pose_estimation_df[bodypart]["likelihood"].where( - df_filter[bodypart] - ) - ) - )[0], - } - for bodypart in bodyparts - } sub_thresh_percent_dict = { bodypart: ( len( diff --git a/src/spyglass/position_linearization/position_linearization_merge.py b/src/spyglass/position_linearization/position_linearization_merge.py index 68f9d063a..1efd38afc 100644 --- a/src/spyglass/position_linearization/position_linearization_merge.py +++ b/src/spyglass/position_linearization/position_linearization_merge.py @@ -1,8 +1,8 @@ import datajoint as dj -from spyglass.position_linearization.v1.linearization import ( +from spyglass.position_linearization.v1.linearization import ( # noqa F401 LinearizedPositionV1, -) # noqa F401 +) from ..utils.dj_merge_tables import _Merge @@ -17,7 +17,7 @@ class LinearizedPositionOutput(_Merge): source: varchar(32) """ - class LinearizedPositionV1(dj.Part): + class LinearizedPositionV1(dj.Part): # noqa: F811 definition = """ -> LinearizedPositionOutput --- diff --git a/src/spyglass/sharing/sharing_kachery.py b/src/spyglass/sharing/sharing_kachery.py index f8109f56c..d3a6f1adf 100644 --- a/src/spyglass/sharing/sharing_kachery.py +++ b/src/spyglass/sharing/sharing_kachery.py @@ -2,6 +2,7 @@ import datajoint as dj import kachery_cloud as kcl +from datajoint.errors import DataJointError from ..common.common_lab import Lab # noqa from ..common.common_nwbfile import AnalysisNwbfile @@ -56,7 +57,7 @@ def set_zone(key: dict): kachery_zone_name, kachery_cloud_dir = (KacheryZone & key).fetch1( "kachery_zone_name", "kachery_cloud_dir" ) - except: + except DataJointError: raise Exception( f"{key} does not correspond to a single entry in KacheryZone." ) @@ -75,7 +76,8 @@ def reset_zone(): @staticmethod def set_resource_url(key: dict): - """Sets the KACHERY_RESOURCE_URL based on the key corresponding to a single Kachery Zone + """Sets the KACHERY_RESOURCE_URL based on the key corresponding to a + single Kachery Zone Parameters ---------- @@ -86,7 +88,7 @@ def set_resource_url(key: dict): kachery_zone_name, kachery_proxy = (KacheryZone & key).fetch1( "kachery_zone_name", "kachery_proxy" ) - except: + except DataJointError: raise Exception( f"{key} does not correspond to a single entry in KacheryZone." 
diff --git a/src/spyglass/spikesorting/sortingview_helper_fn.py b/src/spyglass/spikesorting/sortingview_helper_fn.py
index 05ba0e822..177c8a831 100644
--- a/src/spyglass/spikesorting/sortingview_helper_fn.py
+++ b/src/spyglass/spikesorting/sortingview_helper_fn.py
@@ -1,12 +1,11 @@
 "Sortingview helper functions"
 
-from typing import Dict, List, Union, Any, Tuple
-
-import spikeinterface as si
+from typing import Any, List, Tuple, Union
 
 import kachery_cloud as kcl
 import sortingview as sv
 import sortingview.views as vv
+import spikeinterface as si
 from sortingview.SpikeSortingView import SpikeSortingView
 
 from .merged_sorting_extractor import MergedSortingExtractor
@@ -56,8 +55,6 @@ def _create_spikesortingview_workspace(
             sorting_id=sorting_id, label=label, unit_ids=[int(unit_id)]
         )
 
-    unit_metrics = workspace.get_unit_metrics_for_sorting(sorting_id)
-
     return workspace.uri, recording_id, sorting_id
diff --git a/src/spyglass/spikesorting/spikesorting_curation.py b/src/spyglass/spikesorting/spikesorting_curation.py
index 5b6703271..6b7cfc315 100644
--- a/src/spyglass/spikesorting/spikesorting_curation.py
+++ b/src/spyglass/spikesorting/spikesorting_curation.py
@@ -84,7 +84,8 @@ def insert_curation(
         """
         if parent_curation_id == -1:
-            # check to see if this sorting with a parent of -1 has already been inserted and if so, warn the user
+            # check to see if this sorting with a parent of -1 has already been
+            # inserted and if so, warn the user
             inserted_curation = (Curation & sorting_key).fetch("KEY")
             if len(inserted_curation) > 0:
                 Warning(
@@ -546,7 +547,7 @@ def _compute_metric(self, waveform_extractor, metric_name, **metric_params):
             else:
                 raise Exception(
                     f"{peak_sign_metrics} metrics require peak_sign",
-                    f"to be defined in the metric parameters",
+                    "to be defined in the metric parameters",
                 )
         else:
             metric = {}
@@ -949,9 +950,6 @@ def make(self, key):
         recording = Curation.get_recording(key)
         # get the sort_interval and sorting interval list
-        sort_interval_name = (SpikeSortingRecording & key).fetch1(
-            "sort_interval_name"
-        )
         sort_interval = (SortInterval & key).fetch1("sort_interval")
         sort_interval_list_name = (SpikeSorting & key).fetch1(
             "artifact_removed_interval_list_name"
         )
@@ -1044,7 +1042,8 @@ def insert1(self, key, **kwargs):
     def get_included_units(
         self, curated_sorting_key, unit_inclusion_param_name
     ):
-        """given a reference to a set of curated sorting units and the name of a unit inclusion parameter list, returns
+        """Given a reference to a set of curated sorting units and the name of
+        a unit inclusion parameter list, returns
         unit key
 
         Parameters
         ----------
@@ -1054,7 +1053,7 @@ def get_included_units(
             name of a unit inclusion parameter entry
 
         Returns
-        ------unit key
+        -------
+        dict
             key to select all of the included units
         """
@@ -1067,8 +1066,6 @@ def get_included_units(
         units_key = (CuratedSpikeSorting().Unit() & curated_sortings).fetch(
             "KEY"
         )
-        # get a list of the metrics in the units table
-        metrics_list = CuratedSpikeSorting().metrics_fields()
         # get the list of labels to exclude if there is one
         if "exclude_labels" in inc_param_dict:
             exclude_labels = inc_param_dict["exclude_labels"]
@@ -1079,7 +1076,8 @@
         # create a list of the units to keep.
         keep = np.asarray([True] * len(units))
         for metric in inc_param_dict:
-            # for all units, go through each metric, compare it to the value specified, and update the list to be kept
+            # for all units, go through each metric, compare it to the value
+            # specified, and update the list to be kept
             keep = np.logical_and(
                 keep,
                 _comparison_to_function[inc_param_dict[metric][0]](
@@ -1089,14 +1087,13 @@
         # now exclude by label if it is specified
         if len(exclude_labels):
-            included_units = []
             for unit_ind in np.ravel(np.argwhere(keep)):
                 labels = units[unit_ind]["label"].split(",")
-                exclude = False
                 for label in labels:
                     if label in exclude_labels:
                         keep[unit_ind] = False
                         break
+
         # return units that passed all of the tests
         # TODO: Make this more efficient
         return {i: units_key[i] for i in np.ravel(np.argwhere(keep))}
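The unit-inclusion logic touched above filters a boolean keep mask metric by
metric and then drops excluded labels. A self-contained sketch of that flow,
using made-up metric values and a hypothetical comparison map (the real
_comparison_to_function mapping is defined elsewhere in this module):

import operator
import numpy as np

_comparison_to_function = {">": operator.gt, ">=": operator.ge,
                           "<": operator.lt, "<=": operator.le}

units = [
    {"snr": 6.2, "isi_violation": 0.002, "label": ""},
    {"snr": 1.1, "isi_violation": 0.040, "label": "noise,mua"},
]
inc_param_dict = {"snr": (">=", 2.0), "isi_violation": ("<", 0.01)}
exclude_labels = ["noise", "reject"]

# for all units, go through each metric, compare it to the value specified,
# and update the list to be kept
keep = np.asarray([True] * len(units))
for metric, (comparison, value) in inc_param_dict.items():
    keep = np.logical_and(
        keep,
        _comparison_to_function[comparison](
            np.asarray([unit[metric] for unit in units]), value
        ),
    )

# now exclude by label if it is specified
for unit_ind in np.ravel(np.argwhere(keep)):
    labels = units[unit_ind]["label"].split(",")
    if any(label in exclude_labels for label in labels):
        keep[unit_ind] = False

included = {i: units[i] for i in np.ravel(np.argwhere(keep))}  # unit 0 only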
diff --git a/src/spyglass/spikesorting/spikesorting_recording.py b/src/spyglass/spikesorting/spikesorting_recording.py
index f6abb72b5..398e3812a 100644
--- a/src/spyglass/spikesorting/spikesorting_recording.py
+++ b/src/spyglass/spikesorting/spikesorting_recording.py
@@ -321,7 +321,7 @@ def get_geometry(self, sort_group_id, nwb_file_name):
 class SortInterval(dj.Manual):
     definition = """
     -> Session
-    sort_interval_name: varchar(64) # name for this interval
+    sort_interval_name: varchar(64)  # name for this interval
     ---
     sort_interval: longblob # 1D numpy array with start and end time for a single interval to be used for spike sorting
     """
diff --git a/src/spyglass/utils/database_settings.py b/src/spyglass/utils/database_settings.py
old mode 100644
new mode 100755
diff --git a/tests/conftest.py b/tests/conftest.py
index eae26c2c2..ac1539abf 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,6 +1,5 @@
 # directory-specific hook implementations
 import os
-import pathlib
 import shutil
 import sys
 import tempfile
@@ -47,11 +46,13 @@ def pytest_configure(config):
     _set_env()
 
-    # note that in this configuration, every test will use the same datajoint server
-    # this may create conflicts and dependencies between tests
-    # it may be better but significantly slower to start a new server for every test
-    # but the server needs to be started before tests are collected because datajoint runs when the source
-    # files are loaded, not when the tests are run. one solution might be to restart the server after every test
+    # note that in this configuration, every test will use the same datajoint
+    # server this may create conflicts and dependencies between tests it may be
+    # better but significantly slower to start a new server for every test but
+    # the server needs to be started before tests are collected because
+    # datajoint runs when the source files are loaded, not when the tests are
+    # run. one solution might be to restart the server after every test
+
     global __PROCESS
     __PROCESS = run_datajoint_server()
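For the "restart the server after every test" idea floated in the comment
above, a sketch of what that could look like in conftest.py; it assumes a
kill_datajoint_server() teardown helper, which this patch does not provide:

import pytest

@pytest.fixture(autouse=True)
def _fresh_datajoint_server():
    yield  # run the test against the server started in pytest_configure
    global __PROCESS
    kill_datajoint_server(__PROCESS)  # hypothetical teardown helper
    __PROCESS = run_datajoint_server()  # same helper used in pytest_configure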