Skip to content

Commit

Permalink
Merge pull request #9 from Olivier-tl/stim-parameters-loading
Browse files Browse the repository at this point in the history
Added loading of stim settings
  • Loading branch information
Olivier-tl authored Apr 3, 2024
2 parents e512b61 + da98cb4 commit a9a9807
Show file tree
Hide file tree
Showing 7 changed files with 154 additions and 36 deletions.
14 changes: 14 additions & 0 deletions scripts/load_and_plot_coarse_time_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@

import argparse
import logging
from dataclasses import asdict
from pathlib import Path

import plotly.express as px
from utils import parse_args

from reveal_data_client import RevealDataClient
from reveal_data_client.constants import VnsStatus

LOG = logging.getLogger(__name__)

Expand Down Expand Up @@ -53,6 +55,18 @@ def main(dataset_path: Path) -> None:
f"Elapsed time: {elapsed}, sampling rate: {sampling_rate} Hz.\n"
f"Plot saved to {fig_path}\n"
)
if vns_status == VnsStatus.ON:
try:
stim = client.coarse_time_series.get_stim_setting(
participant_id, visit_id, ans_period
)
LOG.info(
f"Stim setting for participant {participant_id}, visit {visit_id}, ANS period {ans_period}: {asdict(stim)}\n"
)
except ValueError:
LOG.error(
f"Stim setting not found for participant {participant_id}, visit {visit_id}, ANS period {ans_period}\n"
)


if __name__ == "__main__":
Expand Down
7 changes: 4 additions & 3 deletions src/reveal_data_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from reveal_data_client.constants import AnsPeriod, VisitID, VnsStatus
from reveal_data_client.time_series.coarse.client import CoarseTimeSeriesClient
from reveal_data_client.types import ParticipantId

PRIMARY_DIR = Path("primary")

Expand All @@ -30,18 +31,18 @@ def __init__(self, coarse_ts_client: CoarseTimeSeriesClient) -> None:
"""
self._coarse_ts_client = coarse_ts_client

def get_participant_ids(self) -> Sequence[str]:
def get_participant_ids(self) -> Sequence[ParticipantId]:
"""Get the participant IDs in the dataset."""
# TODO: Refactor this in a separate MetadataClient class?
return self._coarse_ts_client.get_participant_ids()

def get_visit_ids(self, participant_id: str) -> Sequence[VisitID]:
def get_visit_ids(self, participant_id: ParticipantId) -> Sequence[VisitID]:
"""Get the visit IDs for a participant."""
# TODO: Refactor this in a separate MetadataClient class?
return self._coarse_ts_client.get_visit_ids(participant_id)

def get_ans_periods_and_vns_status(
self, participant_id: str, visit_id: VisitID
self, participant_id: ParticipantId, visit_id: VisitID
) -> Sequence[tuple[AnsPeriod, VnsStatus]]:
"""Get the ANS periods and VNS status for a participant and visit."""
# TODO: Refactor this in a separate MetadataClient class?
Expand Down
10 changes: 7 additions & 3 deletions src/reveal_data_client/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,13 @@ class AnsPeriod(str, Enum):
"""Third acute stim of the day."""

IHG = "IHG"
"""IHG_w/VNS on or off"""
"""Handgrip exercise with VNS ON or OFF"""

PECO = "PECO"
"""PECO_w/VNS on or off"""
"""Post-exercise circulatory occlusion with VNS ON or OFF"""

HUT = "HUT"
"""HUT w/ VNS on or off"""
"""Head-up tilt with VNS ON or OFF"""

BASELINE_IHG = "BASEIHG"

Expand Down Expand Up @@ -60,3 +60,7 @@ class VnsStatus(str, Enum):

OFF = "OFF"
"""VNS is off"""


CSV_DELIMITER = "|"
"""The delimiter used in the CSV file."""
116 changes: 95 additions & 21 deletions src/reveal_data_client/stim_setting.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,54 @@
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from typing import Dict, Sequence
from typing import Dict, Mapping

import pandas as pd

from reveal_data_client.constants import AnsPeriod
from reveal_data_client.constants import CSV_DELIMITER, AnsPeriod, VisitID
from reveal_data_client.types import ParticipantId

STIM_SETTING_FILE_PATH = "primary/stim_settings.csv"
STIM_OPTION_MAPPING_FILE_PATH = "docs/rc_rand_vns_stim_parameters.csv"
PARTICIPANT_MAPPING_FILE_PATH = "docs/rc_rand_ans_participant_stim_settings.csv"


class StimSettingCsvColumn(str, Enum):
class StimOption(str, Enum):
    """
    Closed set of the six stimulation options, labelled ``A`` through ``F``.

    Every member's value is its own single-letter label. Because the enum also
    derives from ``str``, members compare equal to those plain strings.
    """

    A = "A"
    B = "B"
    C = "C"
    D = "D"
    E = "E"
    F = "F"


class StimOptionMappingCsvColumn(str, Enum):
"""
Enum to represent the columns in the CSV file containing the stimulation settings.
"""

STIM_OPTION = "stim_option"
PULSE_WIDTH = "pulsewidth_ms"
PULSE_WIDTH = "pulse_width_ms"
CURRENT = "level_ma"
FREQUENCY = "freq_hz"
DUTY_CYCLE_OFF = "duty_minoff"
DUTY_CYCLE_OFF = "duty_cycle_off_min"


class ParticipantMappingCsvColumn(str, Enum):
    """
    Column headers of the participant mapping CSV file.

    Apart from the participant ID column, each member names the column for one
    visit / stim period combination (``<visit>_<stim period>``).
    """

    # Identifies the participant a row belongs to.
    PARTICIPANT_ID = "Participant_ID"

    # Visit SV1, stim periods 1 to 3.
    SV1_STIM1 = "SV1_STIM1"
    SV1_STIM2 = "SV1_STIM2"
    SV1_STIM3 = "SV1_STIM3"

    # Visit SV2, stim periods 1 to 3.
    SV2_STIM1 = "SV2_STIM1"
    SV2_STIM2 = "SV2_STIM2"
    SV2_STIM3 = "SV2_STIM3"


@dataclass(frozen=True)
Expand All @@ -30,9 +59,6 @@ class StimSetting:
Data class to represent a stimulation setting.
"""

stim_option: int
"""The stimulation option. A value between 1 and 6."""

pulse_width: float
"""The pulse width in milliseconds."""

Expand All @@ -53,21 +79,69 @@ class StimSetting:
"""A mapping from ANS period to stimulation settings."""


def get_stim_settings(dataset_path: Path) -> Sequence[StimSetting]:
def get_ans_stim_mapping(
    dataset_path: Path,
) -> Mapping[tuple[ParticipantId, VisitID, AnsPeriod], StimSetting]:
    """
    Resolves the stim option assigned to each participant into its stimulation setting.

    :param dataset_path: The path to the root directory of the dataset.
    :return: A mapping from participant ID, visit ID, and ANS period to stimulation settings.
    """
    settings_by_option = _get_stim_option_mapping(dataset_path)
    option_by_key = _get_participant_mapping(dataset_path)

    resolved = {}
    for (participant_id, visit_id, ans_period), stim_option in option_by_key.items():
        # An option that is absent from the option table surfaces as a KeyError here.
        resolved[(participant_id, visit_id, ans_period)] = settings_by_option[stim_option]
    return resolved


def _get_stim_option_mapping(dataset_path: Path) -> Mapping[StimOption, StimSetting]:
"""
Get the stimulation settings from the CSV file.
Gets the mapping from stimulation option to stimulation settings.
:param dataset_path: The path to the root directory of the dataset.
:return: A sequence of stimulation settings.
:return: A mapping from stimulation option to stimulation settings.
"""
stim_settings = pd.read_csv(dataset_path / STIM_SETTING_FILE_PATH)
return [
StimSetting(
stim_option=row[StimSettingCsvColumn.STIM_OPTION],
pulse_width=row[StimSettingCsvColumn.PULSE_WIDTH],
current=row[StimSettingCsvColumn.CURRENT],
frequency=row[StimSettingCsvColumn.FREQUENCY],
duty_cycle_off=row[StimSettingCsvColumn.DUTY_CYCLE_OFF],
stim_settings = pd.read_csv(
dataset_path / STIM_OPTION_MAPPING_FILE_PATH, delimiter=CSV_DELIMITER
)
return {
StimOption(row[StimOptionMappingCsvColumn.STIM_OPTION]): StimSetting(
pulse_width=row[StimOptionMappingCsvColumn.PULSE_WIDTH],
current=row[StimOptionMappingCsvColumn.CURRENT],
frequency=row[StimOptionMappingCsvColumn.FREQUENCY],
duty_cycle_off=row[StimOptionMappingCsvColumn.DUTY_CYCLE_OFF],
)
for _, row in stim_settings.iterrows()
]
}


def _get_participant_mapping(
    dataset_path: Path,
) -> Mapping[tuple[ParticipantId, VisitID, AnsPeriod], StimOption]:
    """
    Gets the mapping from participant ID, visit ID, and ANS period to stim option.

    Each row of the participant mapping CSV carries a participant ID plus one
    column per (visit, stim period) combination naming the stim option used.

    :param dataset_path: The path to the root directory of the dataset.
    :return: A mapping from participant ID, visit ID, and ANS period to stim option.
    :raises ValueError: If a cell does not hold a valid stim option.
    :raises KeyError: If an expected column is missing from the CSV file.
    """
    col = ParticipantMappingCsvColumn

    # Which CSV column holds the stim option for each (visit, ANS period) pair.
    option_columns = (
        (VisitID.SV1, AnsPeriod.STIM1, col.SV1_STIM1),
        (VisitID.SV1, AnsPeriod.STIM2, col.SV1_STIM2),
        (VisitID.SV1, AnsPeriod.STIM3, col.SV1_STIM3),
        (VisitID.SV2, AnsPeriod.STIM1, col.SV2_STIM1),
        (VisitID.SV2, AnsPeriod.STIM2, col.SV2_STIM2),
        (VisitID.SV2, AnsPeriod.STIM3, col.SV2_STIM3),
    )

    participant_mapping = pd.read_csv(
        dataset_path / PARTICIPANT_MAPPING_FILE_PATH,
        delimiter=CSV_DELIMITER,
        # Keep IDs as text: lookups use string participant IDs, and numeric-looking
        # IDs would otherwise be parsed as numbers (dropping any leading zeros).
        dtype={col.PARTICIPANT_ID.value: str},
    )

    mapping: dict[tuple[ParticipantId, VisitID, AnsPeriod], StimOption] = {}
    for _, row in participant_mapping.iterrows():
        participant_id = ParticipantId(row[col.PARTICIPANT_ID])
        for visit_id, ans_period, column in option_columns:
            mapping[(participant_id, visit_id, ans_period)] = StimOption(row[column])

    return mapping
32 changes: 25 additions & 7 deletions src/reveal_data_client/time_series/coarse/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@

import pandas as pd

from reveal_data_client.constants import AnsPeriod, VisitID, VnsStatus
from reveal_data_client.constants import CSV_DELIMITER, AnsPeriod, VisitID, VnsStatus
from reveal_data_client.stim_setting import StimSetting, get_ans_stim_mapping
from reveal_data_client.time_series.api import TimeSeriesClient
from reveal_data_client.time_series.coarse.constants import CsvColumn
from reveal_data_client.time_series.coarse.utils import extract_participant_id
from reveal_data_client.types import ParticipantId

PRIMARY_DIR = Path("primary")
ANS_DIR = Path("ANS")
Expand All @@ -33,8 +35,9 @@ def __init__(self, dataset_path: Path) -> None:
:param dataset_path: The path to the root directory of the dataset.
"""
self._participants_path = dataset_path / PRIMARY_DIR
self._stim_mapping = dict(get_ans_stim_mapping(dataset_path))

def get_participant_ids(self) -> Sequence[str]:
def get_participant_ids(self) -> Sequence[ParticipantId]:
participant_folders = [f for f in self._participants_path.iterdir() if f.is_dir()]
ids = []
for participant_folder in participant_folders:
Expand All @@ -48,15 +51,15 @@ def get_participant_ids(self) -> Sequence[str]:
)
return ids

def get_visit_ids(self, participant_id: str) -> Sequence[VisitID]:
def get_visit_ids(self, participant_id: ParticipantId) -> Sequence[VisitID]:
# TODO: Get the visit IDs from the files in the "By Period" folder once available.
# For now, load the visit IDs from the data file.
data = self._load_data(participant_id)

return [VisitID(id) for id in data[CsvColumn.VISIT_ID].unique()]

def get_ans_periods_and_vns_status(
self, participant_id: str, visit_id: VisitID
self, participant_id: ParticipantId, visit_id: VisitID
) -> Sequence[tuple[AnsPeriod, VnsStatus]]:
"""Gets unique pairs of ANS periods and VNS status for a participant and visit."""
data = self._load_data(participant_id)
Expand All @@ -69,7 +72,11 @@ def get_ans_periods_and_vns_status(
]

def get_data_for_ans_period(
self, participant_id: str, visit_id: VisitID, ans_period: AnsPeriod, vns_status: VnsStatus
self,
participant_id: ParticipantId,
visit_id: VisitID,
ans_period: AnsPeriod,
vns_status: VnsStatus,
) -> pd.DataFrame:
data = self._load_data(participant_id)
return data[
Expand All @@ -79,8 +86,19 @@ def get_data_for_ans_period(
& (data[CsvColumn.ANS_STATUS] == vns_status)
]

def get_stim_setting(
    self, participant_id: ParticipantId, visit_id: VisitID, ans_period: AnsPeriod
) -> StimSetting:
    """
    Look up the stimulation setting used for a participant, visit and ANS period.

    :param participant_id: The participant to look up.
    :param visit_id: The visit to look up.
    :param ans_period: The ANS period to look up.
    :return: The stimulation setting stored for that combination.
    :raises ValueError: If no setting is stored for that combination.
    """
    lookup_key = (participant_id, visit_id, ans_period)
    found = self._stim_mapping.get(lookup_key)
    if found is not None:
        return found

    raise ValueError(
        f"No stimulation setting found for participant {participant_id}, visit {visit_id}, "
        f"and ANS period {ans_period}."
    )

@lru_cache(maxsize=MAX_CACHE_SIZE)
def _load_data(self, participant_id: str) -> pd.DataFrame:
def _load_data(self, participant_id: ParticipantId) -> pd.DataFrame:

file_path = (
self._participants_path
Expand All @@ -93,7 +111,7 @@ def _load_data(self, participant_id: str) -> pd.DataFrame:
# we should use a more memory-efficient approach. The current sample file
# is 200MB, which is fine, but we should validate that the full dataset
# is not too large to fit into memory.
df = pd.read_csv(file_path, delimiter="|")
df = pd.read_csv(file_path, delimiter=CSV_DELIMITER)

# Convert the time in seconds to a timedelta and set it as the index
df[CsvColumn.index_col()] = pd.to_timedelta(df[CsvColumn.index_col()], unit="s")
Expand Down
6 changes: 4 additions & 2 deletions src/reveal_data_client/time_series/coarse/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@

import re

from reveal_data_client.types import ParticipantId

def extract_participant_id(folder_name: str) -> str:

def extract_participant_id(folder_name: str) -> ParticipantId:
"""
Extracts the participant ID from a folder name. e.g. sub-Sample-001 -> Sample-001
Expand All @@ -13,6 +15,6 @@ def extract_participant_id(folder_name: str) -> str:
"""
match = re.search(r"sub-(\S+)", folder_name)
if match:
return match.group(1)
return ParticipantId(match.group(1))
else:
raise ValueError("No ID found")
5 changes: 5 additions & 0 deletions src/reveal_data_client/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
"""This module defines types specific to the Reveal data client."""

from typing import NewType

# Distinct static type for participant identifiers, so type checkers can tell them
# apart from arbitrary strings; at runtime a ParticipantId is a plain ``str``.
ParticipantId = NewType("ParticipantId", str)

0 comments on commit a9a9807

Please sign in to comment.