diff --git a/src/anomalib/models/patchcore/lightning_model.py b/src/anomalib/models/patchcore/lightning_model.py index 00bf50dff1..16d6faa65f 100644 --- a/src/anomalib/models/patchcore/lightning_model.py +++ b/src/anomalib/models/patchcore/lightning_model.py @@ -35,10 +35,16 @@ class Patchcore(AnomalyModule): num_neighbors (int, optional): Number of nearest neighbors. Defaults to 9. pretrained_weights (str, optional): Path to pretrained weights. Defaults to None. compress_memory_bank (bool): If true the memory bank features are projected to a lower dimensionality following - the Johnson-Lindenstrauss lemma. + the Johnson-Lindenstrauss lemma. coreset_sampler (str): Coreset sampler to use. Defaults to "anomalib". score_computation (str): Score computation to use. Defaults to "anomalib". If "amazon" is used, the anomaly - score is correctly computed as from the paper but it may require more time to compute. + score is correctly computed as from the paper but it may require more time to compute. + disable_score_weighting: If true, the model will not apply the weight factor to the anomaly score. Only works if + score_computation is set to anomalib. + weight_anomaly_map: If true, the model will apply the weight factor to the whole anomaly map, this might be + useful as the anomaly score is now the max of the anomaly map. Only enabled if disable_score_weighting is + False. Only works if score_computation is set to anomalib. + anomaly_score_from_max_heatmap: If true, the anomaly score will be the max of the anomaly map. """ def __init__( @@ -53,6 +59,9 @@ def __init__( compress_memory_bank: bool = False, coreset_sampler: str = "anomalib", score_computation: str = "anomalib", + disable_score_weighting: bool = False, + weight_anomaly_map: bool = False, + anomaly_score_from_max_heatmap: bool = False, ) -> None: super().__init__() @@ -65,6 +74,9 @@ def __init__( pretrained_weights=pretrained_weights, compress_memory_bank=compress_memory_bank, score_computation=score_computation, + disable_score_weighting=disable_score_weighting, + weight_anomaly_map=weight_anomaly_map, + anomaly_score_from_max_heatmap=anomaly_score_from_max_heatmap, ) self.coreset_sampling_ratio = coreset_sampling_ratio self.embeddings: list[Tensor] = [] @@ -168,6 +180,9 @@ def __init__(self, hparams: DictConfig | ListConfig, backbone: str | nn.Module | compress_memory_bank=getattr(hparams.model, "compress_memory_bank", False), coreset_sampler=getattr(hparams.model, "coreset_sampler", "anomalib"), score_computation=getattr(hparams.model, "score_computation", "anomalib"), + disable_score_weighting=getattr(hparams.model, "disable_score_weighting", False), + weight_anomaly_map=getattr(hparams.model, "weight_anomaly_map", False), + anomaly_score_from_max_heatmap=getattr(hparams.model, "anomaly_score_from_max_heatmap", False), ) self.hparams: DictConfig | ListConfig # type: ignore self.save_hyperparameters(hparams) diff --git a/src/anomalib/models/patchcore/torch_model.py b/src/anomalib/models/patchcore/torch_model.py index c54913251b..5ad9d4bd1e 100644 --- a/src/anomalib/models/patchcore/torch_model.py +++ b/src/anomalib/models/patchcore/torch_model.py @@ -47,7 +47,12 @@ class PatchcoreModel(DynamicBufferModule, nn.Module): compress_memory_bank: If true the memory bank features are projected to a lower dimensionality following the Johnson-Lindenstrauss lemma. score_computation: Method to use for anomaly score computation either amazon or anomalib. - + disable_score_weighting: If true, the model will not apply the weight factor to the anomaly score. Only works if + score_computation is set to anomalib. + weight_anomaly_map: If true, the model will apply the weight factor to the whole anomaly map, this might be + useful as the anomaly score is now the max of the anomaly map. Only enabled if disable_score_weighting is + False. Only works if score_computation is set to anomalib. + anomaly_score_from_max_heatmap: If true, the anomaly score will be the max of the anomaly map. """ def __init__( @@ -60,6 +65,9 @@ def __init__( pretrained_weights: Optional[str] = None, compress_memory_bank: bool = False, score_computation: str = "anomalib", + disable_score_weighting: bool = False, + weight_anomaly_map: bool = False, + anomaly_score_from_max_heatmap: bool = False, ) -> None: super().__init__() self.tiler: Optional[Tiler] = None @@ -82,6 +90,9 @@ def __init__( self.memory_bank: torch.Tensor self.projection_model = SparseRandomProjection(eps=0.9) self.compress_memory_bank = compress_memory_bank + self.disable_score_weighting = disable_score_weighting + self.weight_anomaly_map = weight_anomaly_map + self.anomaly_score_from_max_heatmap = anomaly_score_from_max_heatmap log.info(f"Using {self.score_computation} score computation method") @@ -125,12 +136,20 @@ def forward(self, input_tensor: Tensor) -> Union[Tensor, Tuple[Tensor, Tensor]]: patch_scores, _ = self.nearest_neighbors(embedding=embedding, n_neighbors=self.num_neighbors) # Reshape patch_scores to match (batch_size, feature_dim, n_neighbours) # TODO: Verify this is correct patch_scores = patch_scores.reshape(-1, width * height, patch_scores.shape[1]) - max_scores = torch.argmax(patch_scores[:, :, 0], dim=1) - confidence = compute_confidence_scores(patch_scores, max_scores) - weights = 1 - (torch.max(torch.exp(confidence), dim=1)[0] / torch.sum(torch.exp(confidence), dim=1)) - anomaly_score = weights * torch.max(patch_scores[:, :, 0], dim=1)[0] - patch_scores = patch_scores[:, :, 0] + if not self.disable_score_weighting: + max_scores = torch.argmax(patch_scores[:, :, 0], dim=1) + confidence = compute_confidence_scores(patch_scores, max_scores) + weights = 1 - (torch.max(torch.exp(confidence), dim=1)[0] / torch.sum(torch.exp(confidence), dim=1)) + patch_scores = patch_scores[:, :, 0] + if self.weight_anomaly_map: + patch_scores = patch_scores * weights.unsqueeze(1) + anomaly_score = torch.max(patch_scores, dim=1)[0] + else: + anomaly_score = weights * torch.max(patch_scores, dim=1)[0] + else: + patch_scores = patch_scores[:, :, 0] + anomaly_score = torch.max(patch_scores, dim=1)[0] else: # apply nearest neighbor search patch_scores, locations = self.nearest_neighbors(embedding=embedding, n_neighbors=1) @@ -144,6 +163,9 @@ def forward(self, input_tensor: Tensor) -> Union[Tensor, Tuple[Tensor, Tensor]]: patch_scores = patch_scores.reshape((-1, 1, width, height)) # get anomaly map anomaly_map = self.anomaly_map_generator(patch_scores) + if self.anomaly_score_from_max_heatmap: + anomaly_score = anomaly_map.reshape((anomaly_map.shape[0], -1)).max(1)[0] + output = (anomaly_map, anomaly_score) return output