Skip to content

Commit

Permalink
feat: Allow providing custom label weights.
Browse files Browse the repository at this point in the history
feat: Filter out class labels below threshold when stratified.
  • Loading branch information
apage224 committed Oct 10, 2024
1 parent 131d307 commit 2e57d0e
Show file tree
Hide file tree
Showing 6 changed files with 61 additions and 4 deletions.
23 changes: 23 additions & 0 deletions heartkit/datasets/lsad.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,7 @@ def split_train_test_patients(
test_size: float,
label_map: dict[int, int] | None = None,
label_type: str | None = None,
label_threshold: int | None = 2,
) -> list[list[int]]:
"""Perform train/test split on patients for given task.
NOTE: We only perform inter-patient splits and not intra-patient.
Expand All @@ -431,6 +432,7 @@ def split_train_test_patients(
test_size (float): Test size
label_map (dict[int, int], optional): Label map. Defaults to None.
label_type (str, optional): Label type. Defaults to None.
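label_threshold (int, optional): Minimum number of patients a class must have (by stratification label) to be kept; patients in classes below it are dropped. None disables the filter. Defaults to 2.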
Returns:
list[list[int]]: Train and test sets of patient ids
Expand All @@ -440,16 +442,37 @@ def split_train_test_patients(
patients_labels = self.get_patients_labels(patient_ids, label_map=label_map, label_type=label_type)
# Select random label for stratification or -1 if no labels
stratify = np.array([random.choice(x) if len(x) > 0 else -1 for x in patients_labels])

# Remove patients w/ label counts below threshold
for i, label in enumerate(sorted(set(label_map.values()))):
class_counts = np.sum(stratify == label)
if label_threshold is not None and class_counts < label_threshold:
stratify[stratify == label] = -1
logger.warning(f"Removed class {label} w/ only {class_counts} samples")
# END IF
# END FOR

# Remove patients w/o labels
neg_mask = stratify == -1
stratify = stratify[~neg_mask]
patient_ids = patient_ids[~neg_mask]

num_neg = neg_mask.sum()
if num_neg > 0:
logger.debug(f"Removed {num_neg} patients w/ no target class")
# END IF
# END IF

# Log the number of occurrences of each class used for stratification
if stratify is not None:
class_counts = np.zeros(len(label_map), dtype=np.int32)
logger.debug(f"[{self.name}] Stratify class counts:")
for i, label in enumerate(sorted(set(label_map.values()))):
class_counts = np.sum(stratify == label)
logger.debug(f"Class {label}: {class_counts}")
# END FOR
# END IF

return sklearn.model_selection.train_test_split(
patient_ids,
test_size=test_size,
Expand Down
22 changes: 22 additions & 0 deletions heartkit/datasets/ptbxl.py
Original file line number Diff line number Diff line change
Expand Up @@ -506,6 +506,7 @@ def split_train_test_patients(
test_size: float,
label_map: dict[int, int] | None = None,
label_type: str | None = None,
label_threshold: int | None = 2,
) -> list[list[int]]:
"""Perform train/test split on patients for given task.
NOTE: We only perform inter-patient splits and not intra-patient.
Expand All @@ -515,6 +516,7 @@ def split_train_test_patients(
test_size (float): Test size
label_map (dict[int, int], optional): Label map. Defaults to None.
label_type (str, optional): Label type. Defaults to None.
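label_threshold (int, optional): Minimum number of patients a class must have (by stratification label) to be kept; patients in classes below it are dropped. None disables the filter. Defaults to 2.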
Returns:
list[list[int]]: Train and test sets of patient ids
Expand All @@ -524,6 +526,16 @@ def split_train_test_patients(
patients_labels = self.get_patients_labels(patient_ids, label_map=label_map, label_type=label_type)
# Select random label for stratification or -1 if no labels
stratify = np.array([random.choice(x) if len(x) > 0 else -1 for x in patients_labels])

# Remove patients w/ label counts below threshold
for i, label in enumerate(sorted(set(label_map.values()))):
class_counts = np.sum(stratify == label)
if label_threshold is not None and class_counts < label_threshold:
stratify[stratify == label] = -1
logger.warning(f"Removed class {label} w/ only {class_counts} samples")
# END IF
# END FOR

# Remove patients w/o labels
neg_mask = stratify == -1
stratify = stratify[~neg_mask]
Expand All @@ -534,6 +546,16 @@ def split_train_test_patients(
# END IF
# END IF

# Log the number of occurrences of each class used for stratification
if stratify is not None:
class_counts = np.zeros(len(label_map), dtype=np.int32)
logger.debug(f"[{self.name}] Stratify class counts:")
for i, label in enumerate(sorted(set(label_map.values()))):
class_counts = np.sum(stratify == label)
logger.debug(f"Class {label}: {class_counts}")
# END FOR
# END IF

return sklearn.model_selection.train_test_split(
patient_ids,
test_size=test_size,
Expand Down
2 changes: 1 addition & 1 deletion heartkit/defines.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ class HKTaskParams(BaseModel, extra="allow"):
steps_per_epoch: int = Field(10, description="Number of steps per epoch")
val_steps_per_epoch: int = Field(10, description="Number of validation steps")
val_metric: Literal["loss", "acc", "f1"] = Field("loss", description="Performance metric")
class_weights: Literal["balanced", "fixed"] = Field("fixed", description="Class weights")
class_weights: list[float] | str = Field("fixed", description="Class weights")

# Evaluation arguments
threshold: float | None = Field(None, description="Model output threshold")
Expand Down
6 changes: 5 additions & 1 deletion heartkit/tasks/diagnostic/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,11 @@ def train(params: HKTaskParams):
val_ds.save(str(params.val_file))

class_weights = 0.25
if params.class_weights == "balanced":
if isinstance(params.class_weights, list):
class_weights = np.array(params.class_weights)
class_weights = class_weights / class_weights.sum()
class_weights = class_weights.tolist()
elif params.class_weights == "balanced":
n_samples = np.sum(y_true)
class_weights = n_samples / (params.num_classes * np.sum(y_true, axis=0))
# class_weights = (class_weights + class_weights.mean()) / 2 # Smooth out
Expand Down
6 changes: 5 additions & 1 deletion heartkit/tasks/rhythm/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,11 @@ def train(params: HKTaskParams):
val_ds.save(str(params.val_file))

class_weights = 0.25
if params.class_weights == "balanced":
if isinstance(params.class_weights, list):
class_weights = np.array(params.class_weights)
class_weights = class_weights / class_weights.sum()
class_weights = class_weights.tolist()
elif params.class_weights == "balanced":
class_weights = sklearn.utils.compute_class_weight("balanced", classes=np.array(classes), y=y_true)
class_weights = (class_weights + class_weights.mean()) / 2 # Smooth out
class_weights = class_weights.tolist()
Expand Down
6 changes: 5 additions & 1 deletion heartkit/tasks/segmentation/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,11 @@ def train(params: HKTaskParams):
val_ds.save(str(params.val_file))

class_weights = 0.25
if params.class_weights == "balanced":
if isinstance(params.class_weights, list):
class_weights = np.array(params.class_weights)
class_weights = class_weights / class_weights.sum()
class_weights = class_weights.tolist()
elif params.class_weights == "balanced":
class_weights = sklearn.utils.compute_class_weight("balanced", classes=np.array(classes), y=y_true)
class_weights = (class_weights + class_weights.mean()) / 2 # Smooth out
class_weights = class_weights.tolist()
Expand Down

0 comments on commit 2e57d0e

Please sign in to comment.