Skip to content

Commit

Permalink
✅ : pytests for utils, position, linearization
Browse files Browse the repository at this point in the history
  • Loading branch information
CBroz1 committed Feb 29, 2024
1 parent 221f0db commit e1b4c1c
Show file tree
Hide file tree
Showing 23 changed files with 621 additions and 155 deletions.
2 changes: 1 addition & 1 deletion config/add_dj_user.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@
"This script is deprecated. "
+ "Use spyglass.utils.database_settings.DatabaseSettings instead."
)
DatabaseSettings(user_name=sys.argv[1]).add_dj_user()
DatabaseSettings(user_name=sys.argv[1]).add_user(check_exists=True)
16 changes: 8 additions & 8 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -121,11 +121,11 @@ ignore-words-list = 'nevers'
minversion = "7.0"
addopts = [
"-sv",
"--sw", # stepwise: resume with next test after failure
"--pdb", # drop into debugger on failure
# "--sw", # stepwise: resume with next test after failure
# "--pdb", # drop into debugger on failure
"-p no:warnings",
"--no-teardown", # don't teardown the database after tests
"--quiet-spy", # don't show logging from spyglass
# "--no-teardown", # don't teardown the database after tests
# "--quiet-spy", # don't show logging from spyglass
"--show-capture=no",
"--pdbcls=IPython.terminal.debugger:TerminalPdb", # use ipython debugger
"--cov=spyglass",
Expand All @@ -141,19 +141,19 @@ omit = [ # which submodules have no tests
"*/__init__.py",
"*/_version.py",
"*/cli/*",
"*/common/*",
# "*/common/*",
"*/data_import/*",
"*/decoding/*",
"*/figurl_views/*",
# "*/lfp/*",
"*/linearization/*",
# "*/linearization/*",
"*/lock/*",
"*/mua/*",
"*/position/*",
# "*/position/*",
"*/ripple/*",
"*/sharing/*",
"*/spikesorting/*",
"*/utils/*",
# "*/utils/*",
"settings.py",
]

Expand Down
2 changes: 0 additions & 2 deletions src/spyglass/position/v1/position_dlc_project.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
import copy
import getpass
import glob
import os
import shutil
import stat
from itertools import combinations
from pathlib import Path, PosixPath
from typing import Dict, List, Union
Expand Down
18 changes: 9 additions & 9 deletions src/spyglass/position/v1/position_trodes_position.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@
from spyglass.common.common_nwbfile import AnalysisNwbfile
from spyglass.common.common_position import IntervalPositionInfo
from spyglass.position.v1.dlc_utils import check_videofile, get_video_path
from spyglass.utils import logger
from spyglass.utils.dj_mixin import SpyglassMixin
from spyglass.utils import SpyglassMixin, logger

schema = dj.schema("position_v1_trodes_position")

Expand Down Expand Up @@ -158,7 +157,7 @@ class TrodesPosV1(SpyglassMixin, dj.Computed):
"""

def make(self, key):
print(f"Computing position for: {key}")
logger.info(f"Computing position for: {key}")
orig_key = copy.deepcopy(key)

analysis_file_name = AnalysisNwbfile().create(key["nwb_file_name"])
Expand Down Expand Up @@ -220,8 +219,9 @@ def fetch1_dataframe(self, add_frame_ind=True):
TrodesPosParams & {"trodes_pos_params_name": pos_params}
).fetch1("params")["is_upsampled"]
):
logger.warn(
"Upsampled position data, frame indices are invalid. Setting add_frame_ind=False"
logger.warning(
"Upsampled position data, frame indices are invalid. "
+ "Setting add_frame_ind=False"
)
add_frame_ind = False
return IntervalPositionInfo._data_to_df(
Expand All @@ -245,7 +245,7 @@ class TrodesPosVideo(SpyglassMixin, dj.Computed):
def make(self, key):
M_TO_CM = 100

print("Loading position data...")
logger.info("Loading position data...")
raw_position_df = (
RawPosition.PosObject
& {
Expand All @@ -255,7 +255,7 @@ def make(self, key):
).fetch1_dataframe()
position_info_df = (TrodesPosV1() & key).fetch1_dataframe()

print("Loading video data...")
logger.info("Loading video data...")
epoch = (
int(
key["interval_list_name"]
Expand Down Expand Up @@ -299,7 +299,7 @@ def make(self, key):
position_time = np.asarray(position_info_df.index)
cm_per_pixel = meters_per_pixel * M_TO_CM

print("Making video...")
logger.info("Making video...")
self.make_video(
video_path,
centroids,
Expand Down Expand Up @@ -367,7 +367,7 @@ def make_video(
frame_size = (int(video.get(3)), int(video.get(4)))
frame_rate = video.get(5)
n_frames = int(orientation_mean.shape[0])
print(f"video filepath: {output_video_filename}")
logger.info(f"video filepath: {output_video_filename}")
out = cv2.VideoWriter(
output_video_filename, fourcc, frame_rate, frame_size, True
)
Expand Down
10 changes: 8 additions & 2 deletions src/spyglass/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,11 @@ def __init__(self, base_dir: str = None, **kwargs):
}

def load_config(
self, base_dir=None, force_reload=False, on_startup: bool = False
self,
base_dir=None,
force_reload=False,
on_startup: bool = False,
**kwargs,
):
"""
Loads the configuration settings for the object.
Expand Down Expand Up @@ -134,7 +138,9 @@ def load_config(
dj_dlc = dj_custom.get("dlc_dirs", {})

self._debug_mode = dj_custom.get("debug_mode", False)
self._test_mode = dj_custom.get("test_mode", False)
self._test_mode = kwargs.get("test_mode") or dj_custom.get(
"test_mode", False
)

resolved_base = (
base_dir
Expand Down
114 changes: 75 additions & 39 deletions src/spyglass/utils/database_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ def __init__(
host_name=None,
debug=False,
target_database=None,
exec_user=None,
exec_pass=None,
):
"""Class to manage common database settings
Expand All @@ -59,6 +61,10 @@ def __init__(
Default False. If True, pprint sql instead of running
target_database : str, optional
Default is mysql. Can also be docker container id
exec_user : str, optional
User for executing commands. If None, use dj.config
exec_pass : str, optional
Password for executing commands. If None, use dj.config
"""
self.shared_modules = [f"{m}{ESC}" for m in SHARED_MODULES]
self.user = user_name or dj.config["database.user"]
Expand All @@ -67,30 +73,37 @@ def __init__(
)
self.debug = debug
self.target_database = target_database or "mysql"
self.exec_user = exec_user or dj.config["database.user"]
self.exec_pass = exec_pass or dj.config["database.password"]

@property
def _create_roles_sql(self):
guest_role = [
f"{CREATE_ROLE}'dj_guest';\n",
f"{GRANT_SEL}`%`.* TO 'dj_guest';\n",
]
collab_role = [
f"{CREATE_ROLE}'dj_collab';\n",
f"{GRANT_SEL}`%`.* TO 'dj_collab';\n",
] # also gets own prefix below
user_role = [
f"{CREATE_ROLE}'dj_user';\n",
f"{GRANT_SEL}`%`.* TO 'dj_user';\n",
] + [
f"{GRANT_ALL}`{module}`.* TO 'dj_user';\n"
for module in self.shared_modules
] # also gets own prefix below
admin_role = [
f"{CREATE_ROLE}'dj_admin';\n",
f"{GRANT_ALL}`%`.* TO 'dj_admin';\n",
]
def _create_roles_dict(self):
return dict(
guest=[
f"{CREATE_ROLE}`dj_guest`;\n",
f"{GRANT_SEL}`%`.* TO `dj_guest`;\n",
],
collab=[
f"{CREATE_ROLE}`dj_collab`;\n",
f"{GRANT_SEL}`%`.* TO `dj_collab`;\n",
], # also gets own prefix below
user=[
f"{CREATE_ROLE}`dj_user`;\n",
f"{GRANT_SEL}`%`.* TO `dj_user`;\n",
]
+ [
f"{GRANT_ALL}`{module}`.* TO `dj_user`;\n"
for module in self.shared_modules
], # also gets own prefix below
admin=[
f"{CREATE_ROLE}`dj_admin`;\n",
f"{GRANT_ALL}`%`.* TO `dj_admin`;\n",
],
)

return guest_role + collab_role + user_role + admin_role
@property
def _create_roles_sql(self):
return sum(self._create_roles_dict.values(), [])

def _create_user_sql(self, role):
"""Create user and grant role"""
Expand Down Expand Up @@ -118,17 +131,33 @@ def _add_collab_sql(self):
def _add_user_sql(self):
return self._create_user_sql("dj_user") + self._user_prefix_sql

@property
def _add_admin_sql(self):
return self._create_user_sql("dj_admin") + self._user_prefix_sql

def _add_module_sql(self, module_name):
return [f"{GRANT_ALL}`{module_name}{ESC}`.* TO dj_user;\n"]

def add_collab(self):
def add_guest(self, *args, **kwargs):
"""Add guest user with select permissions to shared modules"""
file = self.write_temp_file(self._add_guest_sql)
self.exec(file)

def add_collab(self, *args, **kwargs):
"""Add collaborator user with full permissions to shared modules"""
file = self.write_temp_file(self._add_collab_sql)
self.exec(file)

def add_guest(self):
"""Add guest user with select permissions to shared modules"""
file = self.write_temp_file(self._add_guest_sql)
def add_user(self, check_exists=False, *args, **kwargs):
"""Add user to database with permissions to shared modules"""
if check_exists:
self.check_user_exists()
file = self.write_temp_file(self._add_user_sql)
self.exec(file)

def add_admin(self, *args, **kwargs):
"""Add admin user with full permissions to all modules"""
file = self.write_temp_file(self._add_admin_sql)
self.exec(file)

def add_module(self, module_name):
Expand All @@ -137,19 +166,26 @@ def add_module(self, module_name):
file = self.write_temp_file(self._add_module_sql(module_name))
self.exec(file)

def add_dj_user(self, check_exists=True):
def check_user_exists(self):
"""Add user to database with permissions to shared modules"""
if check_exists:
user_home = Path.home().parent / self.user
if user_home.exists():
logger.info("Creating database user ", self.user)
else:
sys.exit(
f"Error: couldn't find {self.user} in home dir: {user_home}"
)

file = self.write_temp_file(self._add_user_sql)
self.exec(file)
user_home = Path.home().parent / self.user
if user_home.exists():
logger.info("Creating database user ", self.user)
else:
sys.exit(
f"Error: couldn't find {self.user} in home dir: {user_home}"
)

def add_user_by_role(self, role, check_exists=False):
add_func = {
"guest": self.add_guest,
"user": self.add_user,
"collab": self.add_collab,
"admin": self.add_admin,
}
if role not in add_func:
raise ValueError(f"Role {role} not recognized")
add_func[role]()

def add_roles(self):
"""Add roles to database"""
Expand Down Expand Up @@ -180,7 +216,7 @@ def exec(self, file):
cmd = (
f"mysql -p -h {self.host} < {file.name}"
if self.target_database == "mysql"
else f"docker exec -i {self.target_database} mysql -u {self.user} "
+ f"--password=tutorial < {file.name}"
else f"docker exec -i {self.target_database} mysql -u "
+ f"{self.exec_user} --password={self.exec_pass} < {file.name}"
)
os.system(cmd)
9 changes: 6 additions & 3 deletions src/spyglass/utils/dj_merge_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,13 +228,17 @@ def _merge_restrict_parents(
return part_parents

@classmethod
def _merge_repr(cls, restriction: str = True) -> dj.expression.Union:
def _merge_repr(
cls, restriction: str = True, include_empties=False
) -> dj.expression.Union:
"""Merged view, including null entries for columns unique to one part.
Parameters
---------
restriction: str, optional
Restriction to apply to the merged view
include_empties: bool, optional
Default False. Add columns for empty parts.
Returns
------
Expand All @@ -246,7 +250,7 @@ def _merge_repr(cls, restriction: str = True) -> dj.expression.Union:
for p in cls._merge_restrict_parts(
restriction=restriction,
add_invalid_restrict=False,
return_empties=False, # motivated by SpikeSortingOutput.Import
return_empties=include_empties,
)
]
if not parts:
Expand Down Expand Up @@ -635,7 +639,6 @@ def merge_get_parent(
)

if not multi_source and len(part_parents) != 1:
__import__("pdb").set_trace()
raise ValueError(
f"Found {len(part_parents)} potential parents: {part_parents}"
+ "\n\tTry adding a string restriction when invoking "
Expand Down
1 change: 1 addition & 0 deletions src/spyglass/utils/nwb_helper_fn.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,6 +513,7 @@ def get_nwb_copy_filename(nwb_file_name):
def change_group_permissions(
subject_ids, set_group_name, analysis_dir="/stelmo/nwb/analysis"
):
"""CB NOTE: Unused. Remove?"""
# Change to directory with analysis nwb files
os.chdir(analysis_dir)
# Get nwb file directories with specified subject ids
Expand Down
2 changes: 1 addition & 1 deletion tests/common/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def mini_beh_events(mini_behavior):


@pytest.fixture(scope="session")
def mini_pos_interval_dict(common):
def mini_pos_interval_dict(mini_insert, common):
yield {"interval_list_name": common.PositionSource.get_pos_interval_name(0)}


Expand Down
Loading

0 comments on commit e1b4c1c

Please sign in to comment.