Skip to content

Commit

Permalink
cuda
Browse files Browse the repository at this point in the history
  • Loading branch information
sronilsson committed Aug 26, 2024
1 parent cf361a9 commit 621f5d3
Show file tree
Hide file tree
Showing 62 changed files with 2,285 additions and 442 deletions.
Binary file added docs/_static/img/BlobLocationComputer.webm
Binary file not shown.
Binary file added docs/_static/img/count_values_in_ranges_cuda.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/_static/img/create_shap_log_cuda.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/_static/img/direction_from_two_bps_cuda.webp
Binary file not shown.
Binary file added docs/_static/img/get_3pt_angle_cuda.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/_static/img/get_convex_hull_cuda.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/_static/img/is_inside_polygon_cuda.webp
Binary file not shown.
Binary file added docs/_static/img/sliding_mean_cuda.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/_static/img/sliding_min_cuda.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
10 changes: 9 additions & 1 deletion docs/simba.data_processors.rst
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
Data processors
==============================

Methods for manipulating and transforming classification and pose-estimation data
Methods for manipulating animal detection and transforming classification and pose-estimation data

Aggregate classifier statistics calculator
--------------------------------------------------
Expand Down Expand Up @@ -148,3 +148,11 @@ Spontaneous alternation calculator
.. automodule:: simba.data_processors.spontaneous_alternation_calculator
:members:
:show-inheritance:


"Blob" location detector
------------------------------------------------------------

.. automodule:: simba.data_processors.blob_location_computer
:members:
:show-inheritance:
7 changes: 7 additions & 0 deletions docs/simba.plotting.rst
Original file line number Diff line number Diff line change
Expand Up @@ -234,3 +234,10 @@ Spontaneous alternation plotter
:members:
:show-inheritance:

"Blob" plotter
---------------------------------------------------------------

.. automodule:: simba.plotting.blob_plotter
:members:
:show-inheritance:

2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
# Setup configuration
setuptools.setup(
name="Simba-UW-tf-dev",
version="2.0.4",
version="2.0.7",
author="Simon Nilsson, Jia Jie Choong, Sophia Hwang",
author_email="[email protected]",
description="Toolkit for computer classification and analysis of behaviors in experimental animals",
Expand Down
20 changes: 6 additions & 14 deletions simba/SimBA.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@
from simba.utils.lookups import (get_bp_config_code_class_pairs, get_emojis,
get_icons_paths, load_simba_fonts)
from simba.utils.printing import stdout_success, stdout_warning
from simba.utils.read_write import get_video_meta_data
from simba.utils.read_write import get_video_meta_data, find_core_cnt
from simba.utils.warnings import FFMpegNotFoundWarning, PythonVersionWarning
from simba.video_processors.video_processing import \
extract_frames_from_all_videos_in_directory
Expand Down Expand Up @@ -211,7 +211,7 @@ def __init__(self, config_path: str):
simongui.wm_title("LOAD PROJECT")
simongui.columnconfigure(0, weight=1)
simongui.rowconfigure(0, weight=1)

self.core_cnt = find_core_cnt()[0]
self.btn_icons = get_icons_paths()

for k in self.btn_icons.keys():
Expand Down Expand Up @@ -640,7 +640,10 @@ def import_ethovision(self):
def import_deepethogram(self):
ann_folder = askdirectory()
deepethogram_importer = DeepEthogramImporter(config_path=self.config_path, data_dir=ann_folder)
threading.Thread(target=deepethogram_importer.run).start()
if self.core_cnt > Defaults.THREADSAFE_CORE_COUNT.value:
deepethogram_importer.run()
else:
threading.Thread(target=deepethogram_importer.run).start()

def import_noldus_observer(self):
directory = askdirectory()
Expand All @@ -652,17 +655,6 @@ def importMARS(self):
bento_appender = BentoAppender(config_path=self.config_path, data_dir=bento_dir)
threading.Thread(target=bento_appender.run).start()

def choose_animal_bps(self):
if hasattr(self, "path_plot_animal_bp_frm"):
self.path_plot_animal_bp_frm.destroy()
self.path_plot_animal_bp_frm = LabelFrame( self.path_plot_frm, text="CHOOSE ANIMAL BODY-PARTS", font=("Helvetica", 10, "bold"), pady=5, padx=5)
self.path_plot_bp_dict = {}
for animal_cnt in range(int(self.path_animal_cnt_dropdown.getChoices())):
self.path_plot_bp_dict[animal_cnt] = DropDownMenu( self.path_plot_animal_bp_frm, "Animal {} bodypart".format(str(animal_cnt + 1)), self.bp_set, "15")
self.path_plot_bp_dict[animal_cnt].setChoices(self.bp_set[animal_cnt])
self.path_plot_bp_dict[animal_cnt].grid(row=animal_cnt, sticky=NW)
self.path_plot_animal_bp_frm.grid(row=2, column=0, sticky=NW)

def launch_interactive_plot(self):
interactive_grapher = InteractiveProbabilityGrapher(config_path=self.config_path,file_path=self.csvfile.file_path,model_path=self.modelfile.file_path)
interactive_grapher.run()
Expand Down
114 changes: 114 additions & 0 deletions simba/data_processors/blob_location_computer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
import os
from typing import Union, Optional
import numpy as np
from copy import deepcopy
try:
from typing import Literal
except:
from typing_extensions import Literal
from simba.utils.read_write import find_all_videos_in_directory, get_fn_ext, remove_files, get_video_meta_data
from simba.video_processors.video_processing import video_bg_subtraction, video_bg_substraction_mp
from simba.mixins.image_mixin import ImageMixin
from simba.utils.checks import check_if_dir_exists, check_valid_boolean, check_int, check_nvidea_gpu_available, check_str
from simba.utils.enums import Formats, Options, Methods
from simba.utils.printing import stdout_success, SimbaTimer
from simba.utils.read_write import write_df
from simba.utils.errors import InvalidInputError, FFMPEGCodecGPUError
from simba.utils.data import savgol_smoother, df_smoother
import pandas as pd

class BlobLocationComputer(object):
def __init__(self,
data_path: Union[str, os.PathLike],
verbose: Optional[bool] = True,
gpu: Optional[bool] = True,
batch_size: Optional[int] = 2500,
save_dir: Optional[Union[str, os.PathLike]] = None,
smoothing: Optional[str] = None,
multiprocessing: Optional[bool] = False):

"""
Detecting and saving blob locations from video files.
.. video:: _static/img/BlobLocationComputer.webm
:width: 800
:autoplay:
:loop:
:param Union[str, os.PathLike] data_path: Path to a video file or a directory containing video files. The videos will be processed for blob detection.
:param Optional[bool] verbose: If True, prints progress and success messages to the console. Default is True.
:param Optional[bool] gpu: If True, GPU acceleration will be used for blob detection. Default is True.
:param Optional[int] batch_size: The number of frames to process in each batch for blob detection. Default is 2500.
:param Optional[Union[str, os.PathLike]] save_dir: Directory where the blob location data will be saved as CSV files. If None, the results will not be saved. Default is None.
:param Optional[bool] multiprocessing: If True, video background subtraction will be done using multiprocessing. Default is False.
:example:
>>> x = BlobLocationComputer(data_path=r"C:\troubleshooting\RAT_NOR\project_folder\videos\2022-06-20_NOB_DOT_4_downsampled_bg_subtracted.mp4", multiprocessing=True, gpu=True, batch_size=2000, save_dir=r"C:\troubleshooting\RAT_NOR\project_folder\csv\blob_positions")
>>> x.run()
"""


if os.path.isdir(data_path):
self.data_paths = find_all_videos_in_directory(directory=data_path, as_dict=True, raise_error=True).values()
elif os.path.isfile(data_path):
self.data_paths = [data_path]
else:
raise InvalidInputError(msg=f'{data_path} is not a valid directory or video file path or directory path.')
self.video_meta_data = {}
for i in self.data_paths:
self.video_meta_data[get_fn_ext(filepath=i)[1]] = get_video_meta_data(video_path=i)
if save_dir is not None:
check_if_dir_exists(in_dir=os.path.dirname(save_dir), source=self.__class__.__name__)
check_valid_boolean(value=[verbose, gpu, multiprocessing])
if gpu and not check_nvidea_gpu_available():
raise FFMPEGCodecGPUError(msg='No GPU detected.', source=self.__class__.__name__)
if smoothing is not None:
check_str(name=f'{self.__class__.__name__} smoothing', value=smoothing, options=Options.SMOOTHING_OPTIONS.value)
check_int(name=f'{self.__class__.__name__} batch_size', value=batch_size, min_value=1)
self.multiprocessing = multiprocessing
self.verbose = verbose
self.gpu = gpu
self.batch_size = batch_size
self.save_dir = save_dir
self.smoothing = smoothing

def run(self):
timer = SimbaTimer(start=True)
self.location_data = {}
for file_cnt, video_path in enumerate(self.data_paths):
video_timer = SimbaTimer(start=True)
_, video_name, ext = get_fn_ext(filepath=video_path)
temp_video_path = os.path.join(os.path.dirname(video_path), video_name + '_temp.mp4')
if not self.multiprocessing:
_ = video_bg_subtraction(video_path=video_path, verbose=self.verbose, bg_color=(0, 0, 0), fg_color=(255, 255, 255), save_path=temp_video_path)
else:
_ = video_bg_substraction_mp(video_path=video_path, verbose=self.verbose, bg_color=(0, 0, 0), fg_color=(255, 255, 255), save_path=temp_video_path)
self.location_data[video_name] = ImageMixin.get_blob_locations(video_path=temp_video_path, gpu=self.gpu, verbose=self.verbose, batch_size=self.batch_size).astype(np.int32)
remove_files(file_paths=[temp_video_path])
video_timer.stop_timer()
print(f'Blob detection for video {video_name} ({file_cnt+1}/{len(self.data_paths)}) complete (elapsed time: {video_timer.elapsed_time_str}s)...')
timer.stop_timer()
if self.smoothing is not None:
print('Smoothing data...')
smoothened_data = {}
if self.smoothing == Methods.SAVITZKY_GOLAY.value:
for video_name, video_data in self.location_data.items():
smoothened_data[video_name] = savgol_smoother(data=video_data, fps=self.video_meta_data[video_name]['fps'], time_window=2000, source=video_name)
if self.smoothing == Methods.GAUSSIAN.value:
for video_name, video_data in self.location_data.items():
smoothened_data[video_name] = df_smoother(data=pd.DataFrame(video_data, columns=['X', 'Y']), fps=self.video_meta_data[video_name]['fps'], time_window=2000, source=video_name, method='gaussian')
self.location_data = deepcopy(smoothened_data)
del smoothened_data
if self.save_dir is not None:
for video_name, video_data in self.location_data.items():
save_path = os.path.join(self.save_dir, f'{video_name}.csv')
df = pd.DataFrame(video_data, columns=['X', 'Y'])
write_df(df=df, file_type=Formats.CSV.value, save_path=save_path)
if self.verbose:
stdout_success(f'Video blob detection complete for {len(self.data_paths)} videos, data saved at {self.save_dir}', elapsed_time=timer.elapsed_time_str)
else:
if self.verbose:
stdout_success(f'Video blob detection complete for {len(self.data_paths)} video', elapsed_time=timer.elapsed_time_str)

# x = BlobLocationComputer(data_path=r"C:\troubleshooting\RAT_NOR\project_folder\videos\2022-06-20_NOB_DOT_4_downsampled.mp4", multiprocessing=True, gpu=True, batch_size=2000, save_dir=r"C:\troubleshooting\RAT_NOR\project_folder\csv\blob_positions", smoothing='Savitzky Golay')
# x.run()
26 changes: 26 additions & 0 deletions simba/data_processors/cuda/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from .angle_3pt import get_3pt_angle
from .convex_hull import get_convex_hull
from .count_values_in_range import count_values_in_ranges
from .create_shap_log import create_shap_log
from .euclidan_distance import get_euclidean_distance
from .is_inside_circle import is_inside_circle
from .is_inside_polygon import is_inside_polygon
from .is_inside_rectangle import is_inside_rectangle
from .sliding_mean import sliding_mean
from .sliding_std import sliding_std
from .sliding_min import sliding_min
from .sliding_sum import sliding_sum


__all__ = ['get_3pt_angle',
'get_convex_hull',
'count_values_in_ranges',
'create_shap_log',
'get_euclidean_distance',
'is_inside_circle',
'is_inside_polygon',
'is_inside_rectangle',
'sliding_mean',
'sliding_std',
'sliding_min',
'sliding_sum']
60 changes: 60 additions & 0 deletions simba/data_processors/cuda/angle_3pt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
from numba import cuda
import numpy as np
from simba.utils.read_write import read_df
import math

THREADS_PER_BLOCK = 256
@cuda.jit
def _get_3pt_angle_kernel(x_dev, y_dev, z_dev, results):
i = cuda.grid(1)

if i >= x_dev.shape[0]:
return

if i < x_dev.shape[0]:
x_x, x_y = x_dev[i][0], x_dev[i][1]
y_x, y_y = y_dev[i][0], y_dev[i][1]
z_x, z_y = z_dev[i][0], z_dev[i][1]
D = math.degrees(math.atan2(z_y - y_y, z_x - y_x) - math.atan2(x_y - y_y, x_x - y_x))
if D < 0:
D += 360
results[i] = D

def get_3pt_angle(x: np.ndarray, y: np.ndarray, z: np.ndarray) -> np.ndarray:
"""
Computes the angle formed by three points in 2D space for each corresponding row in the input arrays using
GPU. The points x, y, and z represent the coordinates of three points in space, and the angle is calculated
at point `y` between the line segments `xy` and `yz`.
.. image:: _static/img/get_3pt_angle_cuda.png
:width: 500
:align: center
:param x: A numpy array of shape (n, 2) representing the first point (e.g., nose) coordinates.
:param y: A numpy array of shape (n, 2) representing the second point (e.g., center) coordinates, where the angle is computed.
:param z: A numpy array of shape (n, 2) representing the second point (e.g., center) coordinates, where the angle is computed.
:return: A numpy array of shape (n, 1) containing the calculated angles (in degrees) for each row.
:example:
>>> video_path = r"/mnt/c/troubleshooting/mitra/project_folder/videos/501_MA142_Gi_CNO_0514.mp4"
>>> data_path = r"/mnt/c/troubleshooting/mitra/project_folder/csv/outlier_corrected_movement_location/501_MA142_Gi_CNO_0514 - test.csv"
>>> df = read_df(file_path=data_path, file_type='csv')
>>> y = df[['Center_x', 'Center_y']].values
>>> x = df[['Nose_x', 'Nose_y']].values
>>> z = df[['Tail_base_x', 'Tail_base_y']].values
>>> angle_x = get_3pt_angle(x=x, y=y, z=z)
"""


x = np.ascontiguousarray(x).astype(np.float32)
y = np.ascontiguousarray(y).astype(np.float32)
n, m = x.shape
x_dev = cuda.to_device(x)
y_dev = cuda.to_device(y)
z_dev = cuda.to_device(z)
results = cuda.device_array((n, m), dtype=np.int32)
bpg = (n + (THREADS_PER_BLOCK - 1)) // THREADS_PER_BLOCK
_get_3pt_angle_kernel[bpg, THREADS_PER_BLOCK](x_dev, y_dev, z_dev, results)
results = results.copy_to_host()
cuda.current_context().memory_manager.deallocations.clear()
return results
Loading

0 comments on commit 621f5d3

Please sign in to comment.