diff --git a/fastmri/models/feature_varnet.py b/fastmri/models/feature_varnet.py index acc0f527..fee94505 100644 --- a/fastmri/models/feature_varnet.py +++ b/fastmri/models/feature_varnet.py @@ -127,50 +127,6 @@ def forward(self, data: Tensor) -> Tuple[Tensor, Tensor]: return mean, variance - -""" -class RunningChannelStats(nn.Module): - def __init__(self, chans: int, eps: float = 1e-14, freeze_step: int = 20000): - super().__init__() - - self.means: Tensor - self.vars: Tensor - self.current_step: Tensor - self.eps = eps - self.chans = chans - self.freeze_step = freeze_step - - self.register_buffer("current_step", torch.zeros(1, dtype=torch.int)) - self.register_buffer("means", torch.zeros(chans)) - self.register_buffer("vars", torch.zeros(chans)) - - def forward(self, image: Tensor) -> Tuple[Tensor, Tensor, Tensor]: - if image.shape[1] != self.chans: - raise ValueError("Invalid channel number.") - - if self.current_step < self.freeze_step and self.training: - stats = image.permute(1, 0, 2, 3).reshape(self.chans, -1) - mean = stats.mean(1) - var = stats.var(1, unbiased=True) - - var = var / dist.get_world_size() - self.means.copy_(self.means + (mean - self.means) / (self.current_step + 1)) - self.vars.copy_(self.vars + (var - self.vars) / (self.current_step + 1)) - - self.current_step += 1 - - if self.current_step == 0 and not self.training: - stats = image.permute(1, 0, 2, 3).reshape(self.chans, -1) - run_mean = stats.mean(1).view(1, -1, 1, 1) - run_var = (stats.var(1, unbiased=True) + self.eps).view(1, -1, 1, 1) - else: - run_mean = self.means.clone().view(1, -1, 1, 1) - run_var = self.vars.clone().view(1, -1, 1, 1) + self.eps - - return run_mean, run_var -""" - - class FeatureImage(NamedTuple): features: Tensor sens_maps: Optional[Tensor] = None