From 2e57d0e22a743af8f1ab6e79df7df104299ae952 Mon Sep 17 00:00:00 2001
From: Adam Page
Date: Thu, 10 Oct 2024 21:46:33 +0000
Subject: [PATCH] feat: Allow providing custom label weights. feat: Filter out
 class labels below threshold when stratified.

---
 heartkit/datasets/lsad.py            | 23 +++++++++++++++++++++++
 heartkit/datasets/ptbxl.py           | 22 ++++++++++++++++++++++
 heartkit/defines.py                  |  2 +-
 heartkit/tasks/diagnostic/train.py   |  6 +++++-
 heartkit/tasks/rhythm/train.py       |  6 +++++-
 heartkit/tasks/segmentation/train.py |  6 +++++-
 6 files changed, 61 insertions(+), 4 deletions(-)

diff --git a/heartkit/datasets/lsad.py b/heartkit/datasets/lsad.py
index 67518cd..e91e83e 100644
--- a/heartkit/datasets/lsad.py
+++ b/heartkit/datasets/lsad.py
@@ -422,6 +422,7 @@ def split_train_test_patients(
         test_size: float,
         label_map: dict[int, int] | None = None,
         label_type: str | None = None,
+        label_threshold: int | None = 2,
     ) -> list[list[int]]:
         """Perform train/test split on patients for given task.
         NOTE: We only perform inter-patient splits and not intra-patient.
@@ -431,6 +432,7 @@ def split_train_test_patients(
             test_size (float): Test size
             label_map (dict[int, int], optional): Label map. Defaults to None.
             label_type (str, optional): Label type. Defaults to None.
+            label_threshold (int, optional): Minimum patients per class required to keep the class for stratification. Defaults to 2.
 
         Returns:
             list[list[int]]: Train and test sets of patient ids
@@ -440,16 +442,37 @@ def split_train_test_patients(
             patients_labels = self.get_patients_labels(patient_ids, label_map=label_map, label_type=label_type)
             # Select random label for stratification or -1 if no labels
             stratify = np.array([random.choice(x) if len(x) > 0 else -1 for x in patients_labels])
+
+            # Remove patients w/ label counts below threshold
+            for i, label in enumerate(sorted(set(label_map.values()))):
+                class_counts = np.sum(stratify == label)
+                if label_threshold is not None and class_counts < label_threshold:
+                    stratify[stratify == label] = -1
+                    logger.warning(f"Removed class {label} w/ only {class_counts} samples")
+                # END IF
+            # END FOR
+
             # Remove patients w/o labels
             neg_mask = stratify == -1
             stratify = stratify[~neg_mask]
             patient_ids = patient_ids[~neg_mask]
+
             num_neg = neg_mask.sum()
             if num_neg > 0:
                 logger.debug(f"Removed {num_neg} patients w/ no target class")
             # END IF
         # END IF
 
+        # Get occurrence of each class along with class index
+        if stratify is not None:
+            class_counts = np.zeros(len(label_map), dtype=np.int32)
+            logger.debug(f"[{self.name}] Stratify class counts:")
+            for i, label in enumerate(sorted(set(label_map.values()))):
+                class_counts[i] = np.sum(stratify == label)
+                logger.debug(f"Class {label}: {class_counts[i]}")
+            # END FOR
+        # END IF
+
         return sklearn.model_selection.train_test_split(
             patient_ids,
             test_size=test_size,
diff --git a/heartkit/datasets/ptbxl.py b/heartkit/datasets/ptbxl.py
index 68d6764..50e9079 100644
--- a/heartkit/datasets/ptbxl.py
+++ b/heartkit/datasets/ptbxl.py
@@ -506,6 +506,7 @@ def split_train_test_patients(
         test_size: float,
         label_map: dict[int, int] | None = None,
         label_type: str | None = None,
+        label_threshold: int | None = 2,
     ) -> list[list[int]]:
         """Perform train/test split on patients for given task.
         NOTE: We only perform inter-patient splits and not intra-patient.
@@ -515,6 +516,7 @@ def split_train_test_patients(
             test_size (float): Test size
             label_map (dict[int, int], optional): Label map. Defaults to None.
             label_type (str, optional): Label type. Defaults to None.
+            label_threshold (int, optional): Minimum patients per class required to keep the class for stratification. Defaults to 2.
 
         Returns:
             list[list[int]]: Train and test sets of patient ids
@@ -524,6 +526,16 @@ def split_train_test_patients(
             patients_labels = self.get_patients_labels(patient_ids, label_map=label_map, label_type=label_type)
             # Select random label for stratification or -1 if no labels
             stratify = np.array([random.choice(x) if len(x) > 0 else -1 for x in patients_labels])
+
+            # Remove patients w/ label counts below threshold
+            for i, label in enumerate(sorted(set(label_map.values()))):
+                class_counts = np.sum(stratify == label)
+                if label_threshold is not None and class_counts < label_threshold:
+                    stratify[stratify == label] = -1
+                    logger.warning(f"Removed class {label} w/ only {class_counts} samples")
+                # END IF
+            # END FOR
+
             # Remove patients w/o labels
             neg_mask = stratify == -1
             stratify = stratify[~neg_mask]
@@ -534,6 +546,16 @@ def split_train_test_patients(
             # END IF
         # END IF
 
+        # Get occurrence of each class along with class index
+        if stratify is not None:
+            class_counts = np.zeros(len(label_map), dtype=np.int32)
+            logger.debug(f"[{self.name}] Stratify class counts:")
+            for i, label in enumerate(sorted(set(label_map.values()))):
+                class_counts[i] = np.sum(stratify == label)
+                logger.debug(f"Class {label}: {class_counts[i]}")
+            # END FOR
+        # END IF
+
         return sklearn.model_selection.train_test_split(
             patient_ids,
             test_size=test_size,
diff --git a/heartkit/defines.py b/heartkit/defines.py
index e3fde95..ce0db77 100644
--- a/heartkit/defines.py
+++ b/heartkit/defines.py
@@ -118,7 +118,7 @@ class HKTaskParams(BaseModel, extra="allow"):
     steps_per_epoch: int = Field(10, description="Number of steps per epoch")
     val_steps_per_epoch: int = Field(10, description="Number of validation steps")
     val_metric: Literal["loss", "acc", "f1"] = Field("loss", description="Performance metric")
-    class_weights: Literal["balanced", "fixed"] = Field("fixed", description="Class weights")
+    class_weights: list[float] | str = Field("fixed", description="Class weights ('fixed', 'balanced', or per-class weight list)")
 
     # Evaluation arguments
     threshold: float | None = Field(None, description="Model output threshold")
diff --git a/heartkit/tasks/diagnostic/train.py b/heartkit/tasks/diagnostic/train.py
index 28f9c0b..6e3f456 100644
--- a/heartkit/tasks/diagnostic/train.py
+++ b/heartkit/tasks/diagnostic/train.py
@@ -58,7 +58,11 @@ def train(params: HKTaskParams):
     val_ds.save(str(params.val_file))
 
     class_weights = 0.25
-    if params.class_weights == "balanced":
+    if isinstance(params.class_weights, list):
+        class_weights = np.array(params.class_weights)
+        class_weights = class_weights / class_weights.sum()
+        class_weights = class_weights.tolist()
+    elif params.class_weights == "balanced":
         n_samples = np.sum(y_true)
         class_weights = n_samples / (params.num_classes * np.sum(y_true, axis=0))
         # class_weights = (class_weights + class_weights.mean()) / 2 # Smooth out
diff --git a/heartkit/tasks/rhythm/train.py b/heartkit/tasks/rhythm/train.py
index a63d0bf..35489ba 100644
--- a/heartkit/tasks/rhythm/train.py
+++ b/heartkit/tasks/rhythm/train.py
@@ -59,7 +59,11 @@ def train(params: HKTaskParams):
     val_ds.save(str(params.val_file))
 
     class_weights = 0.25
-    if params.class_weights == "balanced":
+    if isinstance(params.class_weights, list):
+        class_weights = np.array(params.class_weights)
+        class_weights = class_weights / class_weights.sum()
+        class_weights = class_weights.tolist()
+    elif params.class_weights == "balanced":
         class_weights = sklearn.utils.compute_class_weight("balanced", classes=np.array(classes), y=y_true)
         class_weights = (class_weights + class_weights.mean()) / 2 # Smooth out
         class_weights = class_weights.tolist()
diff --git a/heartkit/tasks/segmentation/train.py b/heartkit/tasks/segmentation/train.py
index b3c50ae..ea69759 100644
--- a/heartkit/tasks/segmentation/train.py
+++ b/heartkit/tasks/segmentation/train.py
@@ -61,7 +61,11 @@ def train(params: HKTaskParams):
     val_ds.save(str(params.val_file))
 
     class_weights = 0.25
-    if params.class_weights == "balanced":
+    if isinstance(params.class_weights, list):
+        class_weights = np.array(params.class_weights)
+        class_weights = class_weights / class_weights.sum()
+        class_weights = class_weights.tolist()
+    elif params.class_weights == "balanced":
         class_weights = sklearn.utils.compute_class_weight("balanced", classes=np.array(classes), y=y_true)
         class_weights = (class_weights + class_weights.mean()) / 2 # Smooth out
         class_weights = class_weights.tolist()
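
Reviewer note: below is a minimal, standalone sketch of the two behaviors this patch introduces, so they can be sanity-checked in isolation. It is not part of the patch or the heartkit API; the helper names are illustrative only. The first helper mirrors the train.py handling of a list-valued class_weights (normalized to sum to 1), and the second mirrors the split_train_test_patients filtering (classes with fewer than label_threshold samples are remapped to -1 and thus dropped from stratification).

import numpy as np

def normalize_class_weights(class_weights):
    # A list is normalized to sum to 1; "fixed"/"balanced" strings pass through
    # for downstream handling, matching the elif branches in train.py.
    if isinstance(class_weights, list):
        weights = np.asarray(class_weights, dtype=np.float64)
        return (weights / weights.sum()).tolist()
    return class_weights

def filter_rare_classes(stratify, label_threshold=2):
    # Classes with fewer than label_threshold samples are remapped to -1
    # (i.e. excluded from stratification); None disables the filtering,
    # matching the `label_threshold is not None` guard in the patch.
    stratify = stratify.copy()
    if label_threshold is None:
        return stratify
    for label in np.unique(stratify):
        if label == -1:
            continue
        if np.sum(stratify == label) < label_threshold:
            stratify[stratify == label] = -1
    return stratify

print(normalize_class_weights([1.0, 2.0, 1.0]))           # [0.25, 0.5, 0.25]
print(filter_rare_classes(np.array([0, 0, 0, 1, 2, 2])))  # [ 0  0  0 -1  2  2]

With label_threshold=None the filtering is skipped entirely, which is the same behavior the new keyword provides in the dataset split methods.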