Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Resmasking Forward Function TypeError #53

Open
tanzim10 opened this issue May 17, 2024 · 0 comments
Open

Resmasking Forward Function TypeError #53

tanzim10 opened this issue May 17, 2024 · 0 comments

Comments

@tanzim10
Copy link

import torch
import torch.nn as nn
from torchvision.models.utils import load_state_dict_from_url
from torchvision.models.resnet import ResNet, BasicBlock, Bottleneck

# Download locations of the torchvision ImageNet-pretrained ResNet checkpoints,
# keyed by architecture name.
model_urls = dict(
    resnet18="https://download.pytorch.org/models/resnet18-5c106cde.pth",
    resnet34="https://download.pytorch.org/models/resnet34-333f7ec4.pth",
    resnet50="https://download.pytorch.org/models/resnet50-19c8e357.pth",
)

class ResMasking(ResNet):
    """ResNet-18 backbone with a learned multiplicative mask after each stage.

    After every residual stage the feature map ``x`` is rescaled as
    ``x * (1 + mask(x))``, where each mask is a conv/BN/ReLU stack that keeps
    the stage's channel count and spatial size (3x3 convs, padding 1).

    Args:
        weight_path: optional path to a local state dict for the backbone.
            When empty, torchvision's ImageNet ResNet-18 weights are downloaded.
        num_classes: size of the final linear layer (default 7, as before).
    """

    def __init__(self, weight_path="", num_classes=7):
        # Must be the dunder ``__init__``: a plain ``init`` method is never called
        # by Python, so ResNet.__init__ would receive ``weight_path`` as ``block``.
        super().__init__(block=BasicBlock, layers=[2, 2, 2, 2])
        if weight_path:
            # map_location lets a checkpoint saved from GPU load on a CPU-only host.
            # torch.load unpickles the file -- only load trusted checkpoints.
            state_dict = torch.load(weight_path, map_location="cpu")
        else:
            state_dict = load_state_dict_from_url(model_urls["resnet18"], progress=True)
        # Loaded while ``fc`` still has ResNet's default 1000 outputs so an ImageNet
        # checkpoint matches; strict=False only tolerates missing/unexpected keys.
        # NOTE(review): the masks are created below, after loading, so mask weights
        # in ``weight_path`` (if any) are ignored -- confirm weight_path is meant to
        # be a plain ResNet-18 state dict.
        self.load_state_dict(state_dict, strict=False)
        self.fc = nn.Linear(512, num_classes)

        # One mask per stage; shallower stages get deeper masks.
        self.mask1 = self._masking(64, 64, depth=4)
        self.mask2 = self._masking(128, 128, depth=3)
        self.mask3 = self._masking(256, 256, depth=2)
        self.mask4 = self._masking(512, 512, depth=1)

    def _masking(self, in_channels, out_channels, depth):
        """Return a stack of ``depth`` (Conv3x3 -> BatchNorm -> ReLU) units.

        The first unit maps ``in_channels`` to ``out_channels``; the remaining
        ``depth - 1`` units keep ``out_channels``.
        """
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            *[
                nn.Sequential(
                    nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
                    nn.BatchNorm2d(out_channels),
                    nn.ReLU(inplace=True),
                )
                for _ in range(depth - 1)
            ],
        )

    def forward(self, x):
        """Run the stem, the four masked residual stages, pooling and the head."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        # NOTE(review): each mask ends in ReLU, so 1 + m >= 1 and the mask can only
        # amplify features -- confirm this is intended (a sigmoid would bound it).
        for stage, mask in (
            (self.layer1, self.mask1),
            (self.layer2, self.mask2),
            (self.layer3, self.mask3),
            (self.layer4, self.mask4),
        ):
            x = stage(x)
            x = x * (1 + mask(x))

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        return self.fc(x)

class ResMasking50(ResNet):
    """ResNet-50 backbone with a learned multiplicative mask after each stage.

    After every residual stage the feature map ``x`` is rescaled as
    ``x * (1 + mask(x))``, where each mask is a conv/BN/ReLU stack that keeps
    the stage's channel count and spatial size (3x3 convs, padding 1).

    Args:
        weight_path: optional path to a local state dict for the backbone.
            When empty, torchvision's ImageNet ResNet-50 weights are downloaded.
        num_classes: size of the final linear layer (default 7, as before).
    """

    def __init__(self, weight_path="", num_classes=7):
        # Must be the dunder ``__init__``: a plain ``init`` method is never called
        # by Python, so ResNet.__init__ would receive ``weight_path`` as ``block``.
        super().__init__(block=Bottleneck, layers=[3, 4, 6, 3])
        if weight_path:
            # map_location lets a checkpoint saved from GPU load on a CPU-only host.
            # torch.load unpickles the file -- only load trusted checkpoints.
            state_dict = torch.load(weight_path, map_location="cpu")
        else:
            state_dict = load_state_dict_from_url(model_urls["resnet50"], progress=True)
        # Loaded while ``fc`` still has ResNet's default 1000 outputs so an ImageNet
        # checkpoint matches; strict=False only tolerates missing/unexpected keys.
        # NOTE(review): the masks are created below, after loading, so mask weights
        # in ``weight_path`` (if any) are ignored -- confirm weight_path is meant to
        # be a plain ResNet-50 state dict.
        self.load_state_dict(state_dict, strict=False)
        self.fc = nn.Linear(2048, num_classes)

        # One mask per stage; shallower stages get deeper masks.
        self.mask1 = self._masking(256, 256, depth=4)
        self.mask2 = self._masking(512, 512, depth=3)
        self.mask3 = self._masking(1024, 1024, depth=2)
        self.mask4 = self._masking(2048, 2048, depth=1)

    def _masking(self, in_channels, out_channels, depth):
        """Return a stack of ``depth`` (Conv3x3 -> BatchNorm -> ReLU) units.

        The first unit maps ``in_channels`` to ``out_channels``; the remaining
        ``depth - 1`` units keep ``out_channels``.
        """
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            *[
                nn.Sequential(
                    nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
                    nn.BatchNorm2d(out_channels),
                    nn.ReLU(inplace=True),
                )
                for _ in range(depth - 1)
            ],
        )

    def forward(self, x):
        """Run the stem, the four masked residual stages, pooling and the head."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        # NOTE(review): each mask ends in ReLU, so 1 + m >= 1 and the mask can only
        # amplify features -- confirm this is intended (a sigmoid would bound it).
        for stage, mask in (
            (self.layer1, self.mask1),
            (self.layer2, self.mask2),
            (self.layer3, self.mask3),
            (self.layer4, self.mask4),
        ):
            x = stage(x)
            x = x * (1 + mask(x))

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        return self.fc(x)

def resmasking(in_channels=3, num_classes=7, weight_path=""):
    """Build a ResMasking (ResNet-18) model with a single linear head.

    Args:
        in_channels: accepted for signature compatibility with the other
            factories; it is not used (the backbone stem is left unchanged).
        num_classes: number of output logits.
        weight_path: optional local backbone state dict, passed to ResMasking.

    Returns:
        A ResMasking whose ``fc`` maps the 512 pooled features to ``num_classes``.
    """
    model = ResMasking(weight_path)
    # ResMasking's own head is 7-way; rebuild it so ``num_classes`` is honoured,
    # as the *_dropout* factories already do.
    model.fc = nn.Linear(512, num_classes)
    return model

def resmasking50_dropout1(in_channels=3, num_classes=7, weight_path=""):
    """Build a ResMasking50 (ResNet-50) model with a dropout + linear head.

    The head is Dropout(p=0.4) followed by Linear(2048 -> num_classes).
    ``in_channels`` is accepted for interface compatibility and is not used.
    """
    net = ResMasking50(weight_path)
    # Layers are constructed after the backbone, as before, so RNG use is unchanged.
    dropout = nn.Dropout(0.4)
    classifier = nn.Linear(2048, num_classes)
    net.fc = nn.Sequential(dropout, classifier)
    return net

def resmasking_dropout1(in_channels=3, num_classes=7, weight_path=""):
    """Build a ResMasking (ResNet-18) model with a dropout + linear head.

    The head is Dropout(p=0.4) followed by Linear(512 -> num_classes).
    ``in_channels`` is accepted for interface compatibility and is not used.
    """
    net = ResMasking(weight_path)
    dropout = nn.Dropout(0.4)
    classifier = nn.Linear(512, num_classes)
    net.fc = nn.Sequential(dropout, classifier)
    return net

def resmasking_dropout2(in_channels=3, num_classes=7, weight_path=""):
    """Build a ResMasking (ResNet-18) model with a two-layer MLP head.

    Head: Linear(512 -> 128) -> ReLU -> Dropout(p=0.5) -> Linear(128 -> num_classes).
    ``in_channels`` is accepted for interface compatibility and is not used.
    """
    net = ResMasking(weight_path)
    head_layers = [
        nn.Linear(512, 128),
        nn.ReLU(),
        nn.Dropout(p=0.5),
        nn.Linear(128, num_classes),
    ]
    net.fc = nn.Sequential(*head_layers)
    return net

def resmasking_dropout3(in_channels=3, num_classes=7, weight_path=""):
    """Build a ResMasking (ResNet-18) model with a three-layer MLP head.

    Head: two (Linear -> in-place ReLU -> Dropout(p=0.5)) units with widths
    512 and 128, then Linear(128 -> num_classes).
    ``in_channels`` is accepted for interface compatibility and is not used.
    """
    net = ResMasking(weight_path)
    head_layers = []
    width = 512  # pooled ResNet-18 feature size
    for hidden in (512, 128):
        head_layers += [
            nn.Linear(width, hidden),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
        ]
        width = hidden
    head_layers.append(nn.Linear(width, num_classes))
    net.fc = nn.Sequential(*head_layers)
    return net

TypeError: ResMasking.forward() got an unexpected keyword argument 'in_channels'

main.py
def main(config_path):
    """Load the JSON config, build the model factory and datasets, and train.

    Parameters
    ----------
    config_path : str
        Path to the JSON config file.
    """
    # NOTE(review): json, os, torch, mp and FER2013Trainer must be imported at
    # module level; those imports are not shown in this snippet.
    # Use a context manager so the config file handle is closed.
    with open(config_path, encoding="utf-8") as config_file:
        configs = json.load(config_file)
    configs["cwd"] = os.getcwd()

    # Model (factory) and datasets.
    model = get_model(configs)
    train_set, val_set, test_set = get_dataset(configs)

    trainer = FER2013Trainer(model, train_set, val_set, test_set, configs)

    # A missing "distributed" key means single-process training.
    if configs.get("distributed", 0) == 1:
        ngpus = torch.cuda.device_count()
        mp.spawn(trainer.train, nprocs=ngpus, args=())
    else:
        trainer.train()

def get_model(configs):
    """Return the model *factory* selected by ``configs["arch"]``.

    The factory function itself is returned, not a built model. Judging from the
    reported ``TypeError: ResMasking.forward() got an unexpected keyword argument
    'in_channels'``, FER2013Trainer calls the object it receives with keyword
    arguments such as ``in_channels=...`` to build the network. Handing it an
    already-built model routed that call into ``forward`` and raised the error.

    Raises
    ------
    ValueError
        If ``configs["arch"]`` is not a supported architecture.
    """
    arch = configs["arch"]
    if arch == "resmasking_dropout3":
        # Return the callable, not resmasking_dropout3(...): the trainer
        # instantiates it itself (with in_channels / num_classes).
        return resmasking_dropout3
    raise ValueError(f"Model architecture {arch} is not supported.")

def get_dataset(configs):
    """Build the FER2013 train, validation and test splits from ``configs``.

    The test split is created with ``tta=True, tta_size=10`` (presumably
    test-time augmentation with 10 variants; confirm against ``fer2013``).

    Returns
    -------
    tuple
        ``(train_set, val_set, test_set)``.
    """
    # TODO: add transform
    fit_splits = tuple(fer2013(split, configs) for split in ("train", "val"))
    test_split = fer2013("test", configs, tta=True, tta_size=10)
    return (*fit_splits, test_split)

# Script entry point. Must compare the dunder names: a bare ``name`` is
# undefined and raises NameError when the module runs.
if __name__ == "__main__":
    main("/content/drive/MyDrive/Resnet/fer2013_config.json")

@tanzim10 tanzim10 changed the title Resmasking Forward Function Resmasking Forward Function TypeError May 17, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant