Skip to content

Commit

Permalink
Re-adding fastmri/models/feature_varnet.py to ensure compatibility wi…
Browse files Browse the repository at this point in the history
…th mypy 1.1.1
  • Loading branch information
GiannakopoulosIlias authored and mmuckley committed Jul 23, 2024
1 parent 280a7a8 commit 7bf5464
Showing 1 changed file with 82 additions and 42 deletions.
124 changes: 82 additions & 42 deletions fastmri/models/feature_varnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import torch.nn as nn
from torch import Tensor

torch.set_float32_matmul_precision("high")
import torch.nn.functional as F
import torch.distributed as dist
import numpy as np
Expand Down Expand Up @@ -53,9 +52,18 @@ def image_uncrop(image: Tensor, original_image: Tensor) -> Tensor:
pad_height_left, pad_width = _calc_uncrop(image.shape[-1], in_shape[-1])

try:
original_image[
..., pad_height_top:pad_height, pad_height_left:pad_width
] = image[...]
if len(in_shape) == 2: # Assuming 2D images
original_image[pad_height_top:pad_height, pad_height_left:pad_width] = image
elif len(in_shape) == 3: # Assuming 3D images with channels
original_image[
:, pad_height_top:pad_height, pad_height_left:pad_width
] = image
elif len(in_shape) == 4: # Assuming 4D images with batch size
original_image[
:, :, pad_height_top:pad_height, pad_height_left:pad_width
] = image
else:
raise RuntimeError(f"Unsupported tensor shape: {in_shape}")
except RuntimeError:
print(f"in_shape: {in_shape}, image shape: {image.shape}")
raise
Expand Down Expand Up @@ -120,6 +128,7 @@ def forward(self, data: Tensor) -> Tuple[Tensor, Tensor]:
return mean, variance


"""
class RunningChannelStats(nn.Module):
def __init__(self, chans: int, eps: float = 1e-14, freeze_step: int = 20000):
super().__init__()
Expand Down Expand Up @@ -159,16 +168,17 @@ def forward(self, image: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
run_var = self.vars.clone().view(1, -1, 1, 1) + self.eps
return run_mean, run_var
"""


class FeatureImage(NamedTuple):
features: Tensor
sens_maps: Tensor = None
sens_maps: Optional[Tensor] = None
crop_size: Optional[Tuple[int, int]] = None
means: Tensor = None
variances: Tensor = None
mask: Tensor = None
ref_kspace: Tensor = None
means: Optional[Tensor] = None
variances: Optional[Tensor] = None
mask: Optional[Tensor] = None
ref_kspace: Optional[Tensor] = None
beta: Optional[Tensor] = None
gamma: Optional[Tensor] = None

Expand Down Expand Up @@ -279,7 +289,7 @@ def forward(self, x: Tensor, accel: int) -> Tensor:
h_ = x
h_ = self.norm(h_)

pos_enc = self.get_positional_encodings(x.shape[2], x.shape[3], h_.device)
pos_enc = self.get_positional_encodings(x.shape[2], x.shape[3], h_.device.type)

h_ = h_ + pos_enc

Expand Down Expand Up @@ -434,13 +444,15 @@ def __init__(
)

if output_bias:
self.final_conv = nn.Conv2d(
in_channels=chans,
out_channels=out_chans,
kernel_size=1,
stride=1,
padding=0,
bias=True,
self.final_conv = nn.Sequential(
nn.Conv2d(
in_channels=chans,
out_channels=out_chans,
kernel_size=1,
stride=1,
padding=0,
bias=True,
)
)
else:
self.final_conv = nn.Sequential(
Expand Down Expand Up @@ -491,15 +503,24 @@ def __init__(

if child is not None:
self.downsample = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
self.upsample = TransposeConvBlock(
in_chans=child.out_planes, out_chans=out_planes
)
if isinstance(child, UnetLevel): # Ensure child is an instance of UnetLevel
self.upsample = TransposeConvBlock(
in_chans=child.out_planes, out_chans=out_planes
)
else:
raise TypeError("Child must be an instance of UnetLevel")

self.right_block = ConvBlock(
in_chans=2 * out_planes, out_chans=out_planes, drop_prob=drop_prob
)

def down_up(self, image: Tensor) -> Tensor:
return self.upsample(self.child(self.downsample(image)))
if self.child is None:
raise ValueError("self.child is None, cannot call down_up.")
downsampled = self.downsample(image)
child_output = self.child(downsampled)
upsampled = self.upsample(child_output)
return upsampled

def forward(self, image: Tensor) -> Tensor:
image = self.left_block(image)
Expand Down Expand Up @@ -879,7 +900,7 @@ def __init__(
self.cascades = nn.Sequential(*cascades)
self.norm_fn = NormStats()

def _decode_output(self, feature_image: FeatureImage) -> Tuple[Tensor, Tensor]:
def _decode_output(self, feature_image: FeatureImage) -> Tensor:
image = self.decoder(
self.decode_norm(feature_image.features),
means=feature_image.means,
Expand All @@ -897,7 +918,7 @@ def _encode_input(
sens_maps = self.sens_net(masked_kspace, mask, num_low_frequencies)
image = sens_reduce(masked_kspace, sens_maps)
# detect FLAIR 203
if image.shape[-1] < crop_size[1]:
if crop_size is not None and image.shape[-1] < crop_size[1]:
crop_size = (image.shape[-1], image.shape[-1])
means, variances = self.norm_fn(image)
features = self.encoder(image, means=means, variances=variances)
Expand Down Expand Up @@ -937,8 +958,12 @@ def forward(
kspace_pred, feature_image.ref_kspace, mask, feature_image.sens_maps
)
# Divide with k-space factor and Return Final Image
kspace_pred = kspace_pred / self.kspace_mult_factor
return rss(complex_abs(ifft2c(kspace_pred)), dim=1)
kspace_pred = (
kspace_pred / self.kspace_mult_factor
) # Ensure kspace_pred is a Tensor
return rss(
complex_abs(ifft2c(kspace_pred)), dim=1
) # Ensure kspace_pred is a Tensor


class IFVarNet(nn.Module):
Expand Down Expand Up @@ -991,7 +1016,7 @@ def __init__(
self.cascades = nn.Sequential(*cascades)
self.norm_fn = NormStats()

def _decode_output(self, feature_image: FeatureImage) -> Tuple[Tensor, Tensor]:
def _decode_output(self, feature_image: FeatureImage) -> Tensor:
image = self.decoder(
self.decode_norm(feature_image.features),
means=feature_image.means,
Expand All @@ -1009,7 +1034,7 @@ def _encode_input(
) -> FeatureImage:
image = sens_reduce(masked_kspace, sens_maps)
# detect FLAIR 203
if image.shape[-1] < crop_size[1]:
if crop_size is not None and image.shape[-1] < crop_size[1]:
crop_size = (image.shape[-1], image.shape[-1])
means, variances = self.norm_fn(image)
features = self.encoder(image, means=means, variances=variances)
Expand Down Expand Up @@ -1049,9 +1074,12 @@ def forward(
)
feature_image = self.cascades(feature_image)
kspace_pred = self._decode_output(feature_image)
kspace_pred = kspace_pred / self.kspace_mult_factor

return rss(complex_abs(ifft2c(kspace_pred)), dim=1)
kspace_pred = (
kspace_pred / self.kspace_mult_factor
) # Ensure kspace_pred is a Tensor
return rss(
complex_abs(ifft2c(kspace_pred)), dim=1
) # Ensure kspace_pred is a Tensor


class FeatureVarNet_sh_w(nn.Module):
Expand Down Expand Up @@ -1097,7 +1125,7 @@ def __init__(
self.cascades = nn.Sequential(*cascades)
self.norm_fn = NormStats()

def _decode_output(self, feature_image: FeatureImage) -> Tuple[Tensor, Tensor]:
def _decode_output(self, feature_image: FeatureImage) -> Tensor:
image = self.decoder(
self.decode_norm(feature_image.features),
means=feature_image.means,
Expand All @@ -1115,7 +1143,7 @@ def _encode_input(
sens_maps = self.sens_net(masked_kspace, mask, num_low_frequencies)
image = sens_reduce(masked_kspace, sens_maps)
# detect FLAIR 203
if image.shape[-1] < crop_size[1]:
if crop_size is not None and image.shape[-1] < crop_size[1]:
crop_size = (image.shape[-1], image.shape[-1])
means, variances = self.norm_fn(image)
features = self.encoder(image, means=means, variances=variances)
Expand Down Expand Up @@ -1150,8 +1178,12 @@ def forward(
# Find last k-space
kspace_pred = self._decode_output(feature_image)
# Return Final Image
kspace_pred = kspace_pred / self.kspace_mult_factor
return rss(complex_abs(ifft2c(kspace_pred)), dim=1)
kspace_pred = (
kspace_pred / self.kspace_mult_factor
) # Ensure kspace_pred is a Tensor
return rss(
complex_abs(ifft2c(kspace_pred)), dim=1
) # Ensure kspace_pred is a Tensor


class FeatureVarNet_n_sh_w(nn.Module):
Expand Down Expand Up @@ -1197,7 +1229,7 @@ def __init__(
self.cascades = nn.Sequential(*cascades)
self.norm_fn = NormStats()

def _decode_output(self, feature_image: FeatureImage) -> Tuple[Tensor, Tensor]:
def _decode_output(self, feature_image: FeatureImage) -> Tensor:
image = self.decoder(
self.decode_norm(feature_image.features),
means=feature_image.means,
Expand All @@ -1215,7 +1247,7 @@ def _encode_input(
sens_maps = self.sens_net(masked_kspace, mask, num_low_frequencies)
image = sens_reduce(masked_kspace, sens_maps)
# detect FLAIR 203
if image.shape[-1] < crop_size[1]:
if crop_size is not None and image.shape[-1] < crop_size[1]:
crop_size = (image.shape[-1], image.shape[-1])
means, variances = self.norm_fn(image)
features = self.encoder(image, means=means, variances=variances)
Expand Down Expand Up @@ -1250,8 +1282,12 @@ def forward(
# Find last k-space
kspace_pred = self._decode_output(feature_image)
# Return Final Image
kspace_pred = kspace_pred / self.kspace_mult_factor
return rss(complex_abs(ifft2c(kspace_pred)), dim=1)
kspace_pred = (
kspace_pred / self.kspace_mult_factor
) # Ensure kspace_pred is a Tensor
return rss(
complex_abs(ifft2c(kspace_pred)), dim=1
) # Ensure kspace_pred is a Tensor


class AttentionFeatureVarNet_n_sh_w(nn.Module):
Expand Down Expand Up @@ -1300,7 +1336,7 @@ def __init__(
self.cascades = nn.Sequential(*cascades)
self.norm_fn = NormStats()

def _decode_output(self, feature_image: FeatureImage) -> Tuple[Tensor, Tensor]:
def _decode_output(self, feature_image: FeatureImage) -> Tensor:
image = self.decoder(
self.decode_norm(feature_image.features),
means=feature_image.means,
Expand All @@ -1318,7 +1354,7 @@ def _encode_input(
sens_maps = self.sens_net(masked_kspace, mask, num_low_frequencies)
image = sens_reduce(masked_kspace, sens_maps)
# detect FLAIR 203
if image.shape[-1] < crop_size[1]:
if crop_size is not None and image.shape[-1] < crop_size[1]:
crop_size = (image.shape[-1], image.shape[-1])
means, variances = self.norm_fn(image)
features = self.encoder(image, means=means, variances=variances)
Expand Down Expand Up @@ -1353,8 +1389,12 @@ def forward(
# Find last k-space
kspace_pred = self._decode_output(feature_image)
# Return Final Image
kspace_pred = kspace_pred / self.kspace_mult_factor
return rss(complex_abs(ifft2c(kspace_pred)), dim=1)
kspace_pred = (
kspace_pred / self.kspace_mult_factor
) # Ensure kspace_pred is a Tensor
return rss(
complex_abs(ifft2c(kspace_pred)), dim=1
) # Ensure kspace_pred is a Tensor


class E2EVarNet(nn.Module):
Expand Down

0 comments on commit 7bf5464

Please sign in to comment.