You signed in with another tab or window. Reload to refresh your session. You signed out in another tab or window. Reload to refresh your session. You switched accounts on another tab or window. Reload to refresh your session. Dismiss alert
import torch
import torch.nn as nn
from torchvision.models.utils import load_state_dict_from_url
from torchvision.models.resnet import ResNet, BasicBlock, Bottleneck
self.mask1 = self._masking(64, 64, depth=4)
self.mask2 = self._masking(128, 128, depth=3)
self.mask3 = self._masking(256, 256, depth=2)
self.mask4 = self._masking(512, 512, depth=1)
def _masking(self, in_channels, out_channels, depth):
    """
    Build the convolutional branch that produces a feature mask.

    The branch is a chain of Conv2d(3x3, padding=1) -> BatchNorm2d -> ReLU
    stages. The first stage maps ``in_channels`` to ``out_channels`` and is
    stored flat in the returned container; each of the ``depth - 1``
    following stages keeps ``out_channels`` and is wrapped in its own nested
    ``nn.Sequential``. This layout (and therefore the parameter names in the
    state dict) is the same as in the original definition.

    Parameters:
    -----------
    in_channels : int
        number of channels of the incoming feature map
    out_channels : int
        number of channels produced by every stage
    depth : int
        total number of conv/bn/relu stages; for ``depth <= 1`` only the
        first stage is created because ``range(depth - 1)`` is empty

    Returns:
    --------
    nn.Sequential
        the assembled mask branch
    """
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True),
        *[
            nn.Sequential(
                nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            )
            for _ in range(depth - 1)
        ],
    )
def forward(self, x):
    """
    Run the backbone, re-weighting the features after each residual stage.

    After every ``layerK`` the matching ``maskK`` branch is evaluated on the
    stage output and the features are scaled by ``1 + mask``. The result is
    then pooled, flattened from dimension 1 onwards and passed to ``fc``.

    Parameters:
    -----------
    x : torch.Tensor
        input batch, passed directly to ``self.conv1``

    Returns:
    --------
    torch.Tensor
        output of ``self.fc``
    """
    # stem
    x = self.conv1(x)
    x = self.bn1(x)
    x = self.relu(x)
    x = self.maxpool(x)

    # The four stages share one pattern, so iterate instead of repeating it.
    stages = (
        (self.layer1, self.mask1),
        (self.layer2, self.mask2),
        (self.layer3, self.mask3),
        (self.layer4, self.mask4),
    )
    for layer, mask in stages:
        x = layer(x)
        m = mask(x)
        x = x * (1 + m)

    # head
    x = self.avgpool(x)
    x = torch.flatten(x, 1)
    x = self.fc(x)
    return x
self.mask1 = self._masking(256, 256, depth=4)
self.mask2 = self._masking(512, 512, depth=3)
self.mask3 = self._masking(1024, 1024, depth=2)
self.mask4 = self._masking(2048, 2048, depth=1)
def _masking(self, in_channels, out_channels, depth):
    """
    Build the convolutional branch that produces a feature mask.

    The branch is a chain of Conv2d(3x3, padding=1) -> BatchNorm2d -> ReLU
    stages. The first stage maps ``in_channels`` to ``out_channels`` and is
    stored flat in the returned container; each of the ``depth - 1``
    following stages keeps ``out_channels`` and is wrapped in its own nested
    ``nn.Sequential``. This layout (and therefore the parameter names in the
    state dict) is the same as in the original definition.

    Parameters:
    -----------
    in_channels : int
        number of channels of the incoming feature map
    out_channels : int
        number of channels produced by every stage
    depth : int
        total number of conv/bn/relu stages; for ``depth <= 1`` only the
        first stage is created because ``range(depth - 1)`` is empty

    Returns:
    --------
    nn.Sequential
        the assembled mask branch
    """
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True),
        *[
            nn.Sequential(
                nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            )
            for _ in range(depth - 1)
        ],
    )
def forward(self, x):
    """
    Run the backbone, re-weighting the features after each residual stage.

    After every ``layerK`` the matching ``maskK`` branch is evaluated on the
    stage output and the features are scaled by ``1 + mask``. The result is
    then pooled, flattened from dimension 1 onwards and passed to ``fc``.

    Parameters:
    -----------
    x : torch.Tensor
        input batch, passed directly to ``self.conv1``

    Returns:
    --------
    torch.Tensor
        output of ``self.fc``
    """
    # stem
    x = self.conv1(x)
    x = self.bn1(x)
    x = self.relu(x)
    x = self.maxpool(x)

    # The four stages share one pattern, so iterate instead of repeating it.
    stages = (
        (self.layer1, self.mask1),
        (self.layer2, self.mask2),
        (self.layer3, self.mask3),
        (self.layer4, self.mask4),
    )
    for layer, mask in stages:
        x = layer(x)
        m = mask(x)
        x = x * (1 + m)

    # head
    x = self.avgpool(x)
    x = torch.flatten(x, 1)
    x = self.fc(x)
    return x
TypeError: ResMasking.forward() got an unexpected keyword argument 'in_channels'
main.py
def main(config_path):
    """
    This is the main function that sets up and runs the training.

    Parameters:
    -----------
    config_path : str
        path to the JSON config file
    """
    # Load configs; the context manager closes the file handle right away
    # instead of leaving it to the garbage collector.
    with open(config_path) as config_file:
        configs = json.load(config_file)
    configs["cwd"] = os.getcwd()

    # load model and datasets
    model = get_model(configs)
    train_set, val_set, test_set = get_dataset(configs)

    # Init trainer and run the training. FER2013Trainer must be importable
    # at module level (the original paste referenced
    # trainers.fer2013_trainer and trainers.centerloss_trainer).
    trainer = FER2013Trainer(model, train_set, val_set, test_set, configs)

    if configs["distributed"] == 1:
        # one worker process per visible CUDA device
        ngpus = torch.cuda.device_count()
        mp.spawn(trainer.train, nprocs=ngpus, args=())
    else:
        trainer.train()
def get_model(configs):
    """
    Build the network selected by ``configs["arch"]``.

    Only the "resmasking_dropout3" architecture is supported.

    Parameters:
    -----------
    configs : dict
        must contain "arch"; "num_classes" is read when the arch is supported

    Returns:
    --------
    the object returned by ``resmasking_dropout3(num_classes=...)``

    Raises:
    -------
    ValueError
        if ``configs["arch"]`` is not "resmasking_dropout3"
    """
    # NOTE(review): the reported
    # "ResMasking.forward() got an unexpected keyword argument 'in_channels'"
    # is what an already-built module raises when it is called with
    # in_channels=...; if the trainer expects a model factory rather than an
    # instance, return `resmasking_dropout3` itself here -- confirm against
    # the trainer before changing.
    if configs["arch"] != "resmasking_dropout3":
        raise ValueError(f"Model architecture {configs['arch']} is not supported.")
    return resmasking_dropout3(num_classes=configs["num_classes"])
def get_dataset(configs):
"""
This function get raw dataset
"""
import torch
import torch.nn as nn
from torchvision.models.utils import load_state_dict_from_url
from torchvision.models.resnet import ResNet, BasicBlock, Bottleneck
# Download URLs of ResNet checkpoint files (.pth), keyed by architecture name.
# The model classes below look up "resnet18" and "resnet50" when no local
# weight path is supplied.
model_urls = {
"resnet18": "https://download.pytorch.org/models/resnet18-5c106cde.pth",
"resnet34": "https://download.pytorch.org/models/resnet34-333f7ec4.pth",
"resnet50": "https://download.pytorch.org/models/resnet50-19c8e357.pth",
}
class ResMasking(ResNet):
    """
    ResNet built from BasicBlock with layers [2, 2, 2, 2] and a 7-way head.

    Parameters:
    -----------
    weight_path : str
        path to a checkpoint readable by ``torch.load``; when empty, the
        "resnet18" checkpoint listed in ``model_urls`` is downloaded instead
    """

    # The constructor must be named __init__ (the pasted source had `init`,
    # which Python never calls on instantiation).
    def __init__(self, weight_path=""):
        super().__init__(block=BasicBlock, layers=[2, 2, 2, 2])

        if weight_path:
            # map_location="cpu" lets a checkpoint saved on a GPU be read on
            # a CPU-only machine; load_state_dict copies the values into the
            # existing parameters regardless of where they were loaded.
            state_dict = torch.load(weight_path, map_location="cpu")
        else:
            # torch.hub exposes the same helper that older torchvision
            # releases re-exported from torchvision.models.utils.
            state_dict = torch.hub.load_state_dict_from_url(
                model_urls["resnet18"], progress=True
            )

        # strict=False: tolerate missing/unexpected keys in the checkpoint
        self.load_state_dict(state_dict, strict=False)

        # replace the classifier after loading so it starts freshly initialised
        self.fc = nn.Linear(512, 7)
class ResMasking50(ResNet):
    """
    ResNet built from Bottleneck with layers [3, 4, 6, 3] and a 7-way head.

    Parameters:
    -----------
    weight_path : str
        path to a checkpoint readable by ``torch.load``; when empty, the
        "resnet50" checkpoint listed in ``model_urls`` is downloaded instead
    """

    # The constructor must be named __init__ (the pasted source had `init`,
    # which Python never calls on instantiation).
    def __init__(self, weight_path=""):
        super().__init__(block=Bottleneck, layers=[3, 4, 6, 3])

        if weight_path:
            # map_location="cpu" lets a checkpoint saved on a GPU be read on
            # a CPU-only machine; load_state_dict copies the values into the
            # existing parameters regardless of where they were loaded.
            state_dict = torch.load(weight_path, map_location="cpu")
        else:
            # torch.hub exposes the same helper that older torchvision
            # releases re-exported from torchvision.models.utils.
            state_dict = torch.hub.load_state_dict_from_url(
                model_urls["resnet50"], progress=True
            )

        # strict=False: tolerate missing/unexpected keys in the checkpoint
        self.load_state_dict(state_dict, strict=False)

        # replace the classifier after loading so it starts freshly initialised
        self.fc = nn.Linear(2048, 7)
def resmasking(in_channels=3, num_classes=7, weight_path=""):
    """
    Create a ``ResMasking`` model.

    Parameters:
    -----------
    in_channels : int
        accepted for signature compatibility with the other factories;
        not used by this function
    num_classes : int
        size of the output layer; the model's built-in 7-way head is kept
        when this is 7 and replaced by a fresh linear layer otherwise
    weight_path : str
        forwarded to ``ResMasking``

    Returns:
    --------
    ResMasking
    """
    model = ResMasking(weight_path)
    # Previously num_classes was silently ignored; honour it like the
    # *_dropout factories do, while leaving the default case untouched.
    if num_classes != 7:
        model.fc = nn.Linear(512, num_classes)
    return model
def resmasking50_dropout1(in_channels=3, num_classes=7, weight_path=""):
    """
    Create a ``ResMasking50`` model with a Dropout(0.4) -> Linear head.

    Parameters:
    -----------
    in_channels : int
        accepted for signature compatibility; not used by this function
    num_classes : int
        output size of the final linear layer
    weight_path : str
        forwarded to ``ResMasking50``

    Returns:
    --------
    ResMasking50
        the model with its ``fc`` attribute replaced
    """
    model = ResMasking50(weight_path)
    model.fc = nn.Sequential(
        nn.Dropout(0.4),
        nn.Linear(2048, num_classes),
    )
    return model
def resmasking_dropout1(in_channels=3, num_classes=7, weight_path=""):
    """
    Create a ``ResMasking`` model with a Dropout(0.4) -> Linear head.

    Parameters:
    -----------
    in_channels : int
        accepted for signature compatibility; not used by this function
    num_classes : int
        output size of the final linear layer
    weight_path : str
        forwarded to ``ResMasking``

    Returns:
    --------
    ResMasking
        the model with its ``fc`` attribute replaced
    """
    model = ResMasking(weight_path)
    model.fc = nn.Sequential(
        nn.Dropout(0.4),
        nn.Linear(512, num_classes),
    )
    return model
def resmasking_dropout2(in_channels=3, num_classes=7, weight_path=""):
    """
    Create a ``ResMasking`` model with a two-layer head.

    The head is Linear(512, 128) -> ReLU -> Dropout(p=0.5) ->
    Linear(128, num_classes).

    Parameters:
    -----------
    in_channels : int
        accepted for signature compatibility; not used by this function
    num_classes : int
        output size of the final linear layer
    weight_path : str
        forwarded to ``ResMasking``

    Returns:
    --------
    ResMasking
        the model with its ``fc`` attribute replaced
    """
    model = ResMasking(weight_path)
    model.fc = nn.Sequential(
        nn.Linear(512, 128),
        nn.ReLU(),
        nn.Dropout(p=0.5),
        nn.Linear(128, num_classes),
    )
    return model
def resmasking_dropout3(in_channels=3, num_classes=7, weight_path=""):
    """
    Create a ``ResMasking`` model with a three-layer head.

    The head is Linear(512, 512) -> ReLU -> Dropout -> Linear(512, 128) ->
    ReLU -> Dropout -> Linear(128, num_classes); both Dropout layers use the
    library default probability and both ReLUs operate in place.

    Parameters:
    -----------
    in_channels : int
        accepted for signature compatibility; not used by this function
    num_classes : int
        output size of the final linear layer
    weight_path : str
        forwarded to ``ResMasking``

    Returns:
    --------
    ResMasking
        the model with its ``fc`` attribute replaced
    """
    model = ResMasking(weight_path)
    model.fc = nn.Sequential(
        nn.Linear(512, 512),
        nn.ReLU(True),
        nn.Dropout(),
        nn.Linear(512, 128),
        nn.ReLU(True),
        nn.Dropout(),
        nn.Linear(128, num_classes),
    )
    return model
TypeError: ResMasking.forward() got an unexpected keyword argument 'in_channels'
main.py
def main(config_path):
"""
This is the main function to make the training up
def get_model(configs):
# Assuming 'arch' in configs matches 'vgg19_bn_mask_pretrain'
if configs["arch"] == "resmasking_dropout3":
# Directly return the imported model architecture
model = resmasking_dropout3(
def get_dataset(configs):
"""
This function get raw dataset
"""
# Script entry point. The dunder names are required: a bare `name` is an
# undefined variable and would raise NameError at import time.
if __name__ == "__main__":
    main("/content/drive/MyDrive/Resnet/fer2013_config.json")
The text was updated successfully, but these errors were encountered: