diff --git a/monai/transforms/post/array.py b/monai/transforms/post/array.py index 2e733c4f6c..b860c62e9e 100644 --- a/monai/transforms/post/array.py +++ b/monai/transforms/post/array.py @@ -679,18 +679,22 @@ def __init__(self, weights: Sequence[float] | NdarrayOrTensor | None = None) -> self.weights = torch.as_tensor(weights, dtype=torch.float) if weights is not None else None def __call__(self, img: Sequence[NdarrayOrTensor] | NdarrayOrTensor) -> NdarrayOrTensor: - img_ = self.get_stacked_torch(img) - if self.weights is not None: - self.weights = self.weights.to(img_.device) - shape = tuple(self.weights.shape) - for _ in range(img_.ndimension() - self.weights.ndimension()): - shape += (1,) - weights = self.weights.reshape(*shape) + out_pt = None + total_weight = 0.0 - img_ = img_ * weights / weights.mean(dim=0, keepdim=True) + for i, pred in enumerate(img): + pred = torch.as_tensor(pred) + if out_pt is None: + out_pt = torch.zeros_like(pred) - out_pt = torch.mean(img_, dim=0) - return self.post_convert(out_pt, img) + if self.weights is not None: + weight = self.weights[i].to(pred.device) + + out_pt += pred * weight + total_weight += weight + + out_pt /= total_weight + return post_convert(out_pt, img) class VoteEnsemble(Ensemble, Transform):