added working 'get_foraging_bouts' function #410

Open
wants to merge 60 commits into base: datajoint_pipeline

60 commits
1a2ade9
Allow forwarding of load function kwargs to reader
glopesdev Sep 6, 2024
9637fb4
Apply all safe ruff fixes
glopesdev Sep 6, 2024
97389bd
Black formatting
glopesdev Sep 6, 2024
e60b766
Move top-level linter settings to lint section
glopesdev Sep 6, 2024
d2a5104
Ignore missing docs in __init__ and magic methods
glopesdev Sep 6, 2024
743324a
Apply ruff recommendations to low-level API
glopesdev Sep 6, 2024
257d9cd
Ignore missing docs for module, package and tests
glopesdev Sep 6, 2024
e4fe028
Ignore missing docs for schema classes and streams
glopesdev Sep 6, 2024
0927a9a
Apply more ruff recommendations to low-level API
glopesdev Sep 6, 2024
a5bc419
Merge pull request #401 from SainsburyWellcomeCentre/gl-load-kwargs
jkbhagatio Sep 11, 2024
00ead97
Add support for downsampling encoder data
glopesdev Sep 11, 2024
dfd2d02
Ensure low-level API tests run on raw data
glopesdev Sep 11, 2024
be5d3c1
Update pre-commit-config
lochhh Apr 26, 2024
38aeeba
Remove black dependency
lochhh Apr 26, 2024
53e88c2
Temporarily disable ruff and pyright in pre-commit
lochhh Sep 12, 2024
6798e07
Auto-fix mixed lined endings and trailing whitespace
lochhh Sep 12, 2024
cc9c924
Ruff autofix
lochhh Sep 12, 2024
f225406
Fix D103 Missing docstring in public function
lochhh Sep 12, 2024
c7e74f7
Fix D415 First line should end with a period, question mark, or excla…
lochhh Sep 12, 2024
d9d0287
Ignore deprecated PT004
lochhh Sep 12, 2024
c866ab9
Fix D417 Missing argument description in the docstring
lochhh Sep 12, 2024
75af9b8
Ignore E741 check for `h, l, s` assignment
lochhh Sep 12, 2024
12abfcb
Use redundant import alias as suggested in F401
lochhh Sep 12, 2024
a5d88a2
Re-enable ruff in pre-commit
lochhh Sep 12, 2024
7fe4837
Re-enable pyright in pre-commit
lochhh Sep 12, 2024
b0714ab
Configure ruff to ignore .ipynb files
lochhh Sep 12, 2024
68e4344
Remove ruff `--config` in build_env_run_tests workflow
lochhh Sep 12, 2024
038f118
Merge pull request #409 from SainsburyWellcomeCentre/lint-format
glopesdev Sep 13, 2024
3d88792
Add downsampling tests and remove metadata kwargs
glopesdev Sep 17, 2024
a45b16a
Assert expected data ranges and add test comments
glopesdev Sep 18, 2024
222f502
Assert downsampled data is monotonic increasing
glopesdev Sep 18, 2024
871f342
Merge pull request #407 from SainsburyWellcomeCentre/gl-dev
glopesdev Sep 18, 2024
9c9a88b
Merge remote-tracking branch 'origin/main' into gl-ruff-check
glopesdev Sep 18, 2024
df20e9f
Apply remaining ruff recommendations
glopesdev Sep 18, 2024
6e64c83
Exclude venv folder from pyright checks
glopesdev Sep 19, 2024
8d0c03f
Remove obsolete and unused qc module
glopesdev Sep 19, 2024
97bc21c
Apply pyright recommendations
glopesdev Sep 19, 2024
6bacc43
Disable useLibraryCodeForTypes
glopesdev Sep 19, 2024
d1180a8
Remove unused function call
glopesdev Sep 19, 2024
23c440f
Ensure all roots are Path objects
glopesdev Sep 19, 2024
5dfd4a4
Exclude dj_pipeline tests from online CI
glopesdev Sep 19, 2024
f557c48
Exclude dj_pipeline tests from coverage report
glopesdev Sep 19, 2024
81bbfa1
Fix macOS wheel build for `datajoint` (Issue #249) (#406)
MilagrosMarin Sep 20, 2024
a678b8d
Run CI checks using pip env and pyproject.toml
glopesdev Sep 20, 2024
2107691
Run code checks and tests on all platforms
glopesdev Sep 20, 2024
1de5c25
Activate venv for later steps and remove all conda dependencies (#413)
lochhh Sep 20, 2024
a889dba
Merge pull request #402 from SainsburyWellcomeCentre/gl-ruff-check
glopesdev Sep 20, 2024
8a65b49
feat(block_analysis): add `patch_threshold` at `BlockSubjectAnalysis`…
ttngu207 Jul 5, 2024
0c01790
fix(block_analysis): handles patch rate value being INF
ttngu207 Jul 8, 2024
4f9262e
feat: update SLEAP ingestion - no longer dependent on `config_file`
ttngu207 Jul 16, 2024
9162eb4
fix: update logic to associate true pellet times with each threshold …
ttngu207 Jul 18, 2024
a1b30e2
chore(block_analysis): throw meaningful error message
ttngu207 Jul 24, 2024
5a45977
fix(block_analysis): handle edge case where no pellet found after thr…
ttngu207 Jul 24, 2024
04e6faf
chore: tune worker sleep cycle
ttngu207 Aug 15, 2024
2ce605a
feat(social0.4): add devices_schema for `social0.4`
ttngu207 Aug 15, 2024
8597ec4
Add helper script to create new social experiments
ttngu207 Aug 20, 2024
cf815e6
chore: minor updates to automated worker
ttngu207 Aug 20, 2024
baa5336
added working 'get_foraging_bouts' function
jkbhagatio Sep 13, 2024
b0e122c
Re-add missing `patch_threshold`
lochhh Sep 20, 2024
a96231c
Simplify `get_foraging_bouts`
lochhh Sep 20, 2024
75 changes: 24 additions & 51 deletions .github/workflows/build_env_run_tests.yml
@@ -4,14 +4,14 @@
name: build_env_run_tests
on:
pull_request:
branches: [ main ]
branches: [main]
types: [opened, reopened, synchronize]
workflow_dispatch: # allows running manually from Github's 'Actions' tab
workflow_dispatch: # allows running manually from Github's 'Actions' tab

jobs:
build_env_pip_pyproject: # checks only for building env using pip and pyproject.toml
name: Build env using pip and pyproject.toml
runs-on: ubuntu-latest
build_env_run_tests: # checks for building env using pyproject.toml and runs codebase checks and tests
name: Build env using pip and pyproject.toml on ${{ matrix.os }}
runs-on: ${{ matrix.os }}
if: github.event.pull_request.draft == false
strategy:
matrix:
@@ -20,70 +20,43 @@ jobs:
fail-fast: false
defaults:
run:
shell: bash -l {0} # reset shell for each step
shell: ${{ matrix.os == 'windows-latest' && 'cmd' || 'bash' }} -l {0} # Adjust shell based on OS
steps:
- name: Checkout code
uses: actions/checkout@v2
uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v2
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- name: Create venv and install dependencies
run: |
python -m venv .venv
source .venv/bin/activate
.venv/Scripts/activate || source .venv/bin/activate
pip install -e .[dev]
pip list
.venv/bin/python -c "import aeon"

build_env_run_tests: # checks for building env using mamba and runs codebase checks and tests
name: Build env and run tests on ${{ matrix.os }}
runs-on: ${{ matrix.os }}
if: github.event.pull_request.draft == false
strategy:
matrix:
os: [ubuntu-latest, windows-latest, macos-latest]
python-version: [3.11]
fail-fast: false
defaults:
run:
shell: bash -l {0} # reset shell for each step
steps:
- name: checkout repo
uses: actions/checkout@v2
- name: set up conda env
uses: conda-incubator/setup-miniconda@v2
with:
use-mamba: true
miniforge-variant: Mambaforge
python-version: ${{ matrix.python-version }}
environment-file: ./env_config/env.yml
activate-environment: aeon
- name: Update conda env with dev reqs
run: mamba env update -f ./env_config/env_dev.yml

# Only run codebase checks and tests for ubuntu.
python -c "import aeon"
- name: Activate venv for later steps
run: |
echo "VIRTUAL_ENV=$(pwd)/.venv" >> $GITHUB_ENV
echo "$(pwd)/.venv/bin" >> $GITHUB_PATH # For Unix-like systems
echo "$(pwd)/.venv/Scripts" >> $GITHUB_PATH # For Windows
# Only run codebase checks and tests for Linux (ubuntu).
- name: ruff
if: matrix.os == 'ubuntu-latest'
run: python -m ruff check --config ./pyproject.toml .
run: ruff check .
- name: pyright
if: matrix.os == 'ubuntu-latest'
run: python -m pyright --level error --project ./pyproject.toml .
run: pyright --level error --project ./pyproject.toml .
- name: pytest
if: matrix.os == 'ubuntu-latest'
run: python -m pytest tests/

run: pytest tests/ --ignore=tests/dj_pipeline
- name: generate test coverage report
if: matrix.os == 'ubuntu-latest'
if: ${{ matrix.os == 'ubuntu-latest' }}
run: |
python -m pytest --cov=aeon ./tests/ --cov-report=xml:./tests/test_coverage/test_coverage_report.xml
#python -m pytest --cov=aeon ./tests/ --cov-report=html:./tests/test_coverage/test_coverage_report_html
python -m pytest --cov=aeon tests/ --ignore=tests/dj_pipeline --cov-report=xml:tests/test_coverage/test_coverage_report.xml
- name: upload test coverage report to codecov
if: matrix.os == 'ubuntu-latest'
uses: codecov/codecov-action@v2
if: ${{ matrix.os == 'ubuntu-latest' }}
uses: codecov/codecov-action@v4
with:
token: ${{ secrets.CODECOV_TOKEN }}
directory: ./tests/test_coverage/
directory: tests/test_coverage/
files: test_coverage_report.xml
fail_ci_if_error: true
verbose: true
21 changes: 7 additions & 14 deletions .pre-commit-config.yaml
@@ -1,16 +1,12 @@
# For info on running pre-commit manually, see `pre-commit run --help`

default_language_version:
python: python3.11

files: "^(test|aeon)\/.*$"
repos:
- repo: meta
hooks:
- id: identity

- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
rev: v4.6.0
hooks:
- id: check-json
- id: check-yaml
@@ -25,20 +21,17 @@ repos:
- id: trailing-whitespace
args: [--markdown-linebreak-ext=md]

- repo: https://github.com/psf/black
rev: 23.7.0
hooks:
- id: black
args: [--check, --config, ./pyproject.toml]

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.0.286
rev: v0.6.4
hooks:
# Run the linter with the `--fix` flag.
- id: ruff
args: [--config, ./pyproject.toml]
args: [ --fix ]
# Run the formatter.
- id: ruff-format

- repo: https://github.com/RobertCraigie/pyright-python
rev: v1.1.324
rev: v1.1.380
hooks:
- id: pyright
args: [--level, error, --project, ./pyproject.toml]
2 changes: 1 addition & 1 deletion aeon/README.md
@@ -1 +1 @@
#
#
4 changes: 2 additions & 2 deletions aeon/__init__.py
@@ -9,5 +9,5 @@
finally:
del version, PackageNotFoundError

# Set functions avaialable directly under the 'aeon' top-level namespace
from aeon.io.api import load
# Set functions available directly under the 'aeon' top-level namespace
from aeon.io.api import load as load # noqa: PLC0414
1 change: 0 additions & 1 deletion aeon/analysis/__init__.py
@@ -1 +0,0 @@
#
26 changes: 5 additions & 21 deletions aeon/analysis/block_plotting.py
@@ -1,17 +1,7 @@
import os
import pathlib
from colorsys import hls_to_rgb, rgb_to_hls
from contextlib import contextmanager
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly
import plotly.express as px
import plotly.graph_objs as go
import seaborn as sns
from numpy.lib.stride_tricks import as_strided

"""Standardize subject colors, patch colors, and markers."""

@@ -35,27 +25,21 @@
"star",
]
patch_markers_symbols = ["●", "⧓", "■", "⧗", "♦", "✖", "×", "▲", "★"]
patch_markers_dict = {
marker: symbol for marker, symbol in zip(patch_markers, patch_markers_symbols)
}
patch_markers_dict = dict(zip(patch_markers, patch_markers_symbols, strict=False))
patch_markers_linestyles = ["solid", "dash", "dot", "dashdot", "longdashdot"]


def gen_hex_grad(hex_col, vals, min_l=0.3):
"""Generates an array of hex color values based on a gradient defined by unit-normalized values."""
# Convert hex to rgb to hls
h, l, s = rgb_to_hls(
*[int(hex_col.lstrip("#")[i: i + 2], 16) / 255 for i in (0, 2, 4)]
)
h, l, s = rgb_to_hls(*[int(hex_col.lstrip("#")[i : i + 2], 16) / 255 for i in (0, 2, 4)]) # noqa: E741
grad = np.empty(shape=(len(vals),), dtype="<U10") # init grad
for i, val in enumerate(vals):
cur_l = (l * val) + (
min_l * (1 - val)
) # get cur lightness relative to `hex_col`
cur_l = (l * val) + (min_l * (1 - val)) # get cur lightness relative to `hex_col`
cur_l = max(min(cur_l, l), min_l) # set min, max bounds
cur_rgb_col = hls_to_rgb(h, cur_l, s) # convert to rgb
cur_hex_col = "#%02x%02x%02x" % tuple(
int(c * 255) for c in cur_rgb_col
cur_hex_col = "#{:02x}{:02x}{:02x}".format(
*tuple(int(c * 255) for c in cur_rgb_col)
) # convert to hex
grad[i] = cur_hex_col

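As context for the `gen_hex_grad` changes above: the helper converts a hex colour to HLS, interpolates the lightness toward `min_l` for each unit-normalized value, and converts back to hex. A minimal standalone sketch of the same idea (the function name and the printed example are illustrative, not part of the PR):

```python
from colorsys import hls_to_rgb, rgb_to_hls


def hex_gradient(hex_col, vals, min_l=0.3):
    """Scale the lightness of `hex_col` by each unit-normalized value
    in `vals`, clamped to the range [min_l, original lightness]."""
    # hex -> rgb in [0, 1] -> hls
    r, g, b = (int(hex_col.lstrip("#")[i:i + 2], 16) / 255 for i in (0, 2, 4))
    h, l, s = rgb_to_hls(r, g, b)
    grad = []
    for val in vals:
        cur_l = max(min(l * val + min_l * (1 - val), l), min_l)
        rgb = hls_to_rgb(h, cur_l, s)
        grad.append("#%02x%02x%02x" % tuple(int(c * 255) for c in rgb))
    return grad


print(hex_gradient("#ff0000", [0.0, 0.5, 1.0]))
```

A value of 1.0 reproduces the input colour; values near 0.0 bottom out at the `min_l` lightness floor, which is what keeps the gradient from fading to black.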
24 changes: 14 additions & 10 deletions aeon/analysis/movies.py
@@ -7,9 +7,8 @@
from aeon.io import video


def gridframes(frames, width, height, shape=None):
"""Arranges a set of frames into a grid layout with the specified
pixel dimensions and shape.
def gridframes(frames, width, height, shape: None | int | tuple[int, int] = None):
"""Arranges a set of frames into a grid layout with the specified pixel dimensions and shape.

:param list frames: A list of frames to include in the grid layout.
:param int width: The width of the output grid image, in pixels.
@@ -21,7 +20,7 @@ def gridframes(frames, width, height, shape=None):
"""
if shape is None:
shape = len(frames)
if type(shape) not in [list, tuple]:
if isinstance(shape, int):
shape = math.ceil(math.sqrt(shape))
shape = (shape, shape)

@@ -44,7 +43,7 @@ def gridframes(frames, width, height, shape=None):

def averageframes(frames):
"""Returns the average of the specified collection of frames."""
return cv2.convertScaleAbs(sum(np.multiply(1 / len(frames), frames)))
return cv2.convertScaleAbs(np.sum(np.multiply(1 / len(frames), frames)))


def groupframes(frames, n, fun):
@@ -65,7 +64,7 @@ def groupframes(frames, n, fun):
i = i + 1


def triggerclip(data, events, before=pd.Timedelta(0), after=pd.Timedelta(0)):
def triggerclip(data, events, before=None, after=None):
"""Split video data around the specified sequence of event timestamps.

:param DataFrame data:
@@ -76,10 +75,16 @@ def triggerclip(data, events, before=pd.Timedelta(0), after=pd.Timedelta(0)):
:return:
A pandas DataFrame containing the frames, clip and sequence numbers for each event timestamp.
"""
if before is not pd.Timedelta:
if before is None:
before = pd.Timedelta(0)
elif before is not pd.Timedelta:
before = pd.Timedelta(before)
if after is not pd.Timedelta:

if after is None:
after = pd.Timedelta(0)
elif after is not pd.Timedelta:
after = pd.Timedelta(after)

if events is not pd.Index:
events = events.index

@@ -107,8 +112,7 @@ def collatemovie(clipdata, fun):


def gridmovie(clipdata, width, height, shape=None):
"""Collates a set of video clips into a grid movie with the specified pixel dimensions
and grid layout.
"""Collates a set of video clips into a grid movie with the specified pixel dimensions and grid layout.

:param DataFrame clipdata:
A pandas DataFrame where each row specifies video path, frame number, clip and sequence number.
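One nuance worth noting in the `triggerclip` diff above: checks like `before is not pd.Timedelta` compare the value against the class object itself, so they hold for every instance as well; the conventional normalization idiom uses `isinstance`. A small sketch of that pattern (the helper name is illustrative, not part of the PR):

```python
import pandas as pd


def as_timedelta(value):
    """Coerce an optional offset to a pandas Timedelta, defaulting to
    zero -- the normalization `triggerclip` performs on `before`/`after`."""
    if value is None:
        return pd.Timedelta(0)
    if not isinstance(value, pd.Timedelta):
        return pd.Timedelta(value)  # accepts strings, numbers, timedeltas
    return value
```

With `isinstance`, values that are already `Timedelta` instances pass through untouched, while strings such as `"5s"` are converted exactly once.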
26 changes: 15 additions & 11 deletions aeon/analysis/plotting.py
@@ -6,7 +6,7 @@
from matplotlib import colors
from matplotlib.collections import LineCollection

from aeon.analysis.utils import *
from aeon.analysis.utils import rate, sessiontime


def heatmap(position, frequency, ax=None, **kwargs):
@@ -31,21 +31,20 @@ def heatmap(position, frequency, ax=None, **kwargs):
return mesh, cbar


def circle(x, y, radius, fmt=None, ax=None, **kwargs):
def circle(x, y, radius, *args, ax=None, **kwargs):
"""Plot a circle centered at the given x, y position with the specified radius.

:param number x: The x-component of the circle center.
:param number y: The y-component of the circle center.
:param number radius: The radius of the circle.
:param str, optional fmt: The format used to plot the circle line.
:param Axes, optional ax: The Axes on which to draw the circle.
"""
if ax is None:
ax = plt.gca()
points = pd.DataFrame(np.linspace(0, 2 * math.pi, 360), columns=["angle"])
points = pd.DataFrame({"angle": np.linspace(0, 2 * math.pi, 360)})
points["x"] = radius * np.cos(points.angle) + x
points["y"] = radius * np.sin(points.angle) + y
ax.plot(points.x, points.y, fmt, **kwargs)
ax.plot(points.x, points.y, *args, **kwargs)


def rateplot(
@@ -60,16 +59,17 @@
ax=None,
**kwargs,
):
"""Plot the continuous event rate and raster of a discrete event sequence, given the specified
window size and sampling frequency.
"""Plot the continuous event rate and raster of a discrete event sequence.

The window size and sampling frequency can be specified.

:param Series events: The discrete sequence of events.
:param offset window: The time period of each window used to compute the rate.
:param DateOffset, Timedelta or str frequency: The sampling frequency for the continuous rate.
:param number, optional weight: A weight used to scale the continuous rate of each window.
:param datetime, optional start: The left bound of the time range for the continuous rate.
:param datetime, optional end: The right bound of the time range for the continuous rate.
:param datetime, optional smooth: The size of the smoothing kernel applied to the continuous rate output.
:param datetime, optional smooth: The size of the smoothing kernel applied to the rate output.
:param DateOffset, Timedelta or str, optional smooth:
The size of the smoothing kernel applied to the continuous rate output.
:param bool, optional center: Specifies whether to center the convolution kernels.
@@ -108,8 +108,8 @@ def colorline(
x,
y,
z=None,
cmap=plt.get_cmap("copper"),
norm=plt.Normalize(0.0, 1.0),
cmap=None,
norm=None,
ax=None,
**kwargs,
):
@@ -128,9 +128,13 @@
ax = plt.gca()
if z is None:
z = np.linspace(0.0, 1.0, len(x))
if cmap is None:
cmap = plt.get_cmap("copper")
if norm is None:
norm = colors.Normalize(0.0, 1.0)
z = np.asarray(z)
points = np.array([x, y]).T.reshape(-1, 1, 2)
segments = np.concatenate([points[:-1], points[1:]], axis=1)
lines = LineCollection(segments, array=z, cmap=cmap, norm=norm, **kwargs)
lines = LineCollection(segments, array=z, cmap=cmap, norm=norm, **kwargs) # type: ignore
ax.add_collection(lines)
return lines
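The `cmap=None`/`norm=None` change in `colorline` above sidesteps a subtle Python behaviour: default argument expressions are evaluated once, when the `def` statement runs, so `cmap=plt.get_cmap("copper")` would invoke matplotlib at import time. A matplotlib-free sketch of the difference (function names are illustrative):

```python
calls = []


def expensive_default():
    """Stand-in for a call like plt.get_cmap("copper")."""
    calls.append("evaluated")
    return "copper"


# Evaluated once, when the def statement runs -- even if no caller
# ever relies on the default:
def eager(cmap=expensive_default()):
    return cmap


# Deferred to call time with the None-default pattern:
def lazy(cmap=None):
    if cmap is None:
        cmap = expensive_default()
    return cmap


print(calls)  # ["evaluated"] -- eager's default already ran at definition
```

The deferred form only pays the cost when the default is actually needed, and it also keeps heavyweight state (figures, colormaps) out of module import.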
2 changes: 1 addition & 1 deletion aeon/analysis/readme.md
@@ -1 +1 @@
#
#