Skip to content

Commit

Permalink
fix: reference first element for first dropout
Browse files Browse the repository at this point in the history
Signed-off-by: Saurav Maheshkar <[email protected]>
  • Loading branch information
SauravMaheshkar committed Jul 25, 2023
1 parent 29bd67e commit a183292
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions monai/networks/nets/vnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def __init__(
out_channels: int,
nconvs: int,
act: tuple[str, dict] | str,
dropout_prob: tuple[float, float] | None = [None, 0.5],
dropout_prob: tuple[float | None, float] = [None, 0.5], # noqa: B006
dropout_dim: int = 3,
):
super().__init__()
Expand All @@ -144,14 +144,14 @@ def __init__(

self.up_conv = conv_trans_type(in_channels, out_channels // 2, kernel_size=2, stride=2)
self.bn1 = norm_type(out_channels // 2)
self.dropout = dropout_type(dropout_prob[0]) if dropout_prob is not None else None
self.dropout = dropout_type(dropout_prob[0]) if dropout_prob[0] is not None else None
self.dropout2 = dropout_type(dropout_prob[1])
self.act_function1 = get_acti_layer(act, out_channels // 2)
self.act_function2 = get_acti_layer(act, out_channels)
self.ops = _make_nconv(spatial_dims, out_channels, nconvs, act)

def forward(self, x, skipx):
if self.dropout is not None:
if self.dropout[0] is not None:
out = self.dropout(x)
else:
out = x
Expand Down Expand Up @@ -225,7 +225,7 @@ def __init__(
out_channels: int = 1,
act: tuple[str, dict] | str = ("elu", {"inplace": True}),
dropout_prob_down: float = 0.5,
dropout_prob_up: tuple[float, float] = [0.5, 0.5],
dropout_prob_up: tuple[float, float] = [0.5, 0.5], # noqa: B006
dropout_dim: int = 3,
bias: bool = False,
):
Expand Down

0 comments on commit a183292

Please sign in to comment.