Skip to content

Commit

Permalink
Fixture issue in a unit test.
Browse files Browse the repository at this point in the history
  • Loading branch information
kaanaksit committed Nov 28, 2023
1 parent 6602f7e commit 2d67179
Showing 1 changed file with 11 additions and 16 deletions.
27 changes: 11 additions & 16 deletions test/test_learn_models_multi_layer_perceptron.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,21 +32,16 @@ def test():
model.load_state_dict(torch.load(weights_filename))
model.eval()
print('Model weights loaded: {}'.format(weights_filename))
try:
for epoch_id in epochs:
test_loss, estimation = test_pass(image, batches, loss_function, model)
train_loss = train(image, batches, optimizer, loss_function, model)
description = 'train loss: {:.5f}, test loss:{:.5f}'.format(train_loss, test_loss)
epochs.set_description(description)
if epoch_id % save_at_every == 0:
odak.learn.tools.save_image(test_filename, estimation, cmin = 0., cmax = 1.)
torch.save(model.state_dict(), weights_filename)
print('Model weights save: {}'.format(weights_filename))
odak.learn.tools.save_image(test_filename, estimation, cmin = 0., cmax = 1.)
except:
torch.save(model.state_dict(), weights_filename)
print('Model weights save: {}'.format(weights_filename))
odak.learn.tools.save_image(test_filename, estimation, cmin = 0., cmax = 1.)
for epoch_id in epochs:
test_loss, estimation = trial(image, batches, loss_function, model)
train_loss = train(image, batches, optimizer, loss_function, model)
description = 'train loss: {:.5f}, test loss:{:.5f}'.format(train_loss, test_loss)
epochs.set_description(description)
if epoch_id % save_at_every == 0:
odak.learn.tools.save_image(test_filename, estimation, cmin = 0., cmax = 1.)
torch.save(model.state_dict(), weights_filename)
print('Model weights save: {}'.format(weights_filename))
odak.learn.tools.save_image(test_filename, estimation, cmin = 0., cmax = 1.)
assert True == True


Expand Down Expand Up @@ -76,7 +71,7 @@ def train(output_values, input_values, optimizer, loss_function, model):
return total_loss


def test_pass(output_values, input_values, loss_function, model):
def trial(output_values, input_values, loss_function, model):
estimated_image = torch.zeros_like(output_values)
for input_value in input_values:
torch.no_grad()
Expand Down

0 comments on commit 2d67179

Please sign in to comment.