Skip to content

Commit

Permalink
Do not load the optimizer state when running in inference mode
Browse files Browse the repository at this point in the history
  • Loading branch information
sergey.vilov committed Apr 10, 2022
1 parent da2239a commit 9ec74d2
Showing 1 changed file with 8 additions and 7 deletions.
15 changes: 8 additions & 7 deletions NNC/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,24 +266,25 @@ def collate_fn(data):

optimizer = torch.optim.AdamW(model_params, lr=input_params.learning_rate, weight_decay=input_params.weight_decay) #define optimizer

last_epoch = -1

if input_params.load_weights:

if torch.cuda.is_available():
#load on gpu
model.load_state_dict(torch.load(input_params.config_start_base + '_model'))
optimizer.load_state_dict(torch.load(input_params.config_start_base + '_optimizer'))
if not input_params.inference_mode:
optimizer.load_state_dict(torch.load(input_params.config_start_base + '_optimizer'))
else:
#load on cpu
model.load_state_dict(torch.load(input_params.config_start_base + '_model', map_location=torch.device('cpu')))
optimizer.load_state_dict(torch.load(input_params.config_start_base + '_optimizer', map_location=torch.device('cpu')))
if not input_params.inference_mode:
optimizer.load_state_dict(torch.load(input_params.config_start_base + '_optimizer', map_location=torch.device('cpu')))

last_epoch = int(input_params.config_start_base.split('_')[-2]) #infer previous epoch from input_params.config_start_base

else:

last_epoch = -1

lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
if not input_params.inference_mode:
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
milestones=[input_params.lr_sch_milestones],
gamma=input_params.lr_sch_gamma,
last_epoch=last_epoch, verbose=False) #define learning rate scheduler
Expand Down

0 comments on commit 9ec74d2

Please sign in to comment.