Skip to content

Commit

Permalink
removes unused GoogleNet
Browse files Browse the repository at this point in the history
  • Loading branch information
Frank Guibert committed Jul 4, 2024
1 parent 7b80aee commit 43576ca
Showing 1 changed file with 0 additions and 33 deletions.
33 changes: 0 additions & 33 deletions mfai/torch/models/encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,36 +221,3 @@ def forward(self, x: torch.Tensor):
y_hat = self.fc(y_hat)
return y_hat


@dataclass_json
@dataclass(slots=True)
class GoogleNetSettings:
encoder_weights: bool = None


class GoogleNet(torch.nn.Module):
settings_kls = GoogleNetSettings

def __init__(
self,
num_channels: int = 3,
num_classes: int = 1000,
input_shape: Union[None, Tuple[int, int]] = None,
settings: GoogleNetSettings = GoogleNetSettings(),
):
super().__init__()
self.googlenet = models.googlenet(weights=settings.encoder_weights)

# Modification to allow inputs with more than 3 channels
self.googlenet.conv1 = nn.Conv2d(
num_channels, 64, kernel_size=7, stride=2, padding=3, bias=False
)

in_features = self.googlenet.fc.in_features
self.googlenet.fc = nn.Linear(in_features, num_classes)

def forward(self, x: torch.Tensor):
y_hat = self.googlenet(x)
if isinstance(y_hat, models.GoogLeNetOutputs):
y_hat = y_hat.logits
return y_hat

0 comments on commit 43576ca

Please sign in to comment.