forked from zhangxd0530/MS-DSA-NET
-
Notifications
You must be signed in to change notification settings - Fork 0
/
get_model.py
38 lines (30 loc) · 1.36 KB
/
get_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import os
import torch
from thop import profile, clever_format
from networks2 import MS_DSA_NET
from monai.networks.layers.factories import Norm,Act
def get_model(params):
if params['patch_size'][0] == params['patch_size'][1] and params['patch_size'][0] == params['patch_size'][2]:
str_ps = 'ps{}'.format(params['patch_size'][0])
else:
str_ps = 'ps{}x{}x{}'.format(params['patch_size'][0], params['patch_size'][1], params['patch_size'][2])
if 'MS_DSA_NET' == params['model_type']: #dual-attention
model = MS_DSA_NET(
spatial_dims=3,
in_channels=params['chans_in'],
out_channels=params['chans_out'],
img_size=params['patch_size'],
feature_size=params['feature_size'],
pos_embed=True,
project_size= params['project_size'],
sa_type=params['sa_type'],
norm_name= 'instance', #'batch', #
act_name= ("leakyrelu", {"inplace": True, "negative_slope": 0.01}), #'relu',
res_block=True,
bias= False, #False,
dropout_rate=0.1,
)
sub_dir = '{}_{}_fs{}'.format(model.name, str_ps, params['feature_size'])
save_dir = os.path.join(params['base_dir'] , sub_dir)
params['save_dir'] = save_dir
return model, params