training update changes #965

Closed
wants to merge 18 commits into from
Changes from 2 commits
96 changes: 96 additions & 0 deletions training/tests/loss_functions
@@ -0,0 +1,96 @@
import torch
import torch.nn as nn

from enum import Enum


class LossFunctions(Enum):
# Some common loss functions
L1LOSS = nn.L1Loss()
MSELOSS = nn.MSELoss()
BCELOSS = nn.BCELoss()
BCEWITHLOGITSLOSS = nn.BCEWithLogitsLoss()
CELOSS = nn.CrossEntropyLoss(reduction="mean")
    WCELOSS = nn.CrossEntropyLoss(
        reduction="mean"
    )  # Placeholder entry; compute_img_loss builds its own weighted criterion

def get_loss_obj(self):
return self.value
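
# A minimal lookup sketch (the tensor values below are illustrative
# assumptions, not taken from the training pipeline):
#
#   loss_obj = LossFunctions["MSELOSS"].get_loss_obj()
#   loss_obj(torch.tensor([1.0, 2.0, 3.0]), torch.tensor([1.5, 2.0, 2.5]))
#   # -> tensor(0.1667), the mean of the squared elementwise errors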


def compute_loss(loss_function_name, output, labels):
    """
    Compute the loss. How output and labels are postprocessed depends on the loss object used.

    Args:
        loss_function_name (str): valid name from the LossFunctions Enum
        output (torch.Tensor): model output
        labels (torch.Tensor): ground truth labels

    Returns:
        loss (torch.Tensor): computed loss
    """
postprocess_output = output.clone()
postprocess_label = labels.clone()
if loss_function_name in LossFunctions._member_names_:
        loss_obj = LossFunctions[loss_function_name].get_loss_obj()
        if loss_function_name.upper() in ("BCELOSS", "BCEWITHLOGITSLOSS"):
            # If the target is [20] but the output is [20, 1], unsqueeze the
            # target to [20, 1] so the dimensions match up
            return loss_obj(postprocess_output, postprocess_label.unsqueeze(1))
        elif loss_function_name.upper() in ("MSELOSS", "L1LOSS"):
            postprocess_output = torch.reshape(
                postprocess_output,
                (postprocess_output.shape[0], postprocess_output.shape[2]),
            )
            return loss_obj(postprocess_output, postprocess_label)
        else:
            postprocess_output = torch.reshape(
                postprocess_output,
                (postprocess_output.shape[0], postprocess_output.shape[2]),
            )
            # cross-entropy expects integer class labels
            postprocess_label = postprocess_label.squeeze_()
            return loss_obj(postprocess_output, postprocess_label.long())
    raise Exception(
        "Invalid loss function name provided. Please contact the admins to request support for it, including documentation of the loss function."
    )
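
# A minimal usage sketch of compute_loss (shapes here are illustrative
# assumptions): for L1LOSS, a [3, 1, 1] output is reshaped to [3, 1] before
# being compared against a [3, 1] label.
#
#   output = torch.zeros(3, 1, 1)
#   labels = torch.ones(3, 1)
#   compute_loss("L1LOSS", output, labels)  # -> tensor(1.), the mean |0 - 1|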


def compute_img_loss(criterion, pred, ground_truth, weights_counter):
    """
    Computes CE and WCE loss. pred and ground_truth are reshaped to the forms
    expected by the corresponding loss functions.
    """
    loss_obj = LossFunctions[criterion].get_loss_obj()

    if criterion == LossFunctions.CELOSS.name:
        return loss_obj(pred, ground_truth.squeeze())
    if criterion == LossFunctions.WCELOSS.name:
        # Weight each class by its inverse frequency in the dataset so that
        # the least represented class receives the largest weight
        weight_list = [0] * len(pred[0])
        for i in range(len(weight_list)):
            if weights_counter[i] != 0:
                weight_list[i] = 1 / weights_counter[i]

        loss = nn.CrossEntropyLoss(
            weight=torch.FloatTensor(weight_list), reduction="mean"
        )
        return loss(pred, ground_truth.squeeze())
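
Taken on its own, compute_img_loss can be exercised with a toy batch. A minimal sketch, assuming a 4-sample, 3-class batch where weights_counter maps each class index to its frequency (all values here are illustrative):

    from collections import Counter
    import torch

    pred = torch.randn(4, 3)                   # raw scores, shape [N, C]
    ground_truth = torch.tensor([[0], [0], [1], [2]])
    weights_counter = Counter([0, 0, 1, 2])    # class 0 appears twice

    # Classes 1 and 2 are the rarest, so each gets weight 1/1 vs. 1/2 for class 0
    loss = compute_img_loss("WCELOSS", pred, ground_truth, weights_counter)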
7 changes: 7 additions & 0 deletions training/tests/test_imports.py
@@ -0,0 +1,7 @@
def test_all_imports():
import training.core.dataset
import training.core.criterion
import training.core.optimizer

# import training.core.dl_model
import training.core.trainer
63 changes: 63 additions & 0 deletions training/tests/test_loss_function.py
@@ -0,0 +1,63 @@
import pytest
import torch
from training.core.criterion import getCriterionHandler


"""
Unit tests to check that the loss function is being computed correctly
"""


# initialize some tensors
zero_tensor = torch.zeros(10, 1, 1)
one_tensor = torch.ones(10, 1, 1)
tensor_vstack_one = torch.vstack(
[torch.tensor([[2.5], [56.245], [2342.68967]]), torch.tensor([[3], [4], [5]])]
).reshape((6, 1, 1))
tensor_vstack_two = torch.vstack(
[torch.tensor([[5646456], [634767], [37647346]]), torch.tensor([[6], [7], [8]])]
).reshape((6, 1, 1))


def compute_loss(loss_function_name, output, labels):
loss_function = getCriterionHandler(loss_function_name)
return loss_function.compute_loss(output, labels)


@pytest.mark.parametrize(
"output, labels, expected_number",
[
(zero_tensor, one_tensor, 1.0),
(tensor_vstack_one, tensor_vstack_two, 7321426),
],
)
def test_l1_loss_computation_correct(output, labels, expected_number):
loss_function_name = "L1LOSS"
computed_loss = compute_loss(loss_function_name, output, labels)
assert pytest.approx(expected_number) == computed_loss


@pytest.mark.parametrize(
"output, labels, expected_number",
[
(zero_tensor, one_tensor, 1.0),
(tensor_vstack_one, tensor_vstack_two, 15543340),
],
)
def test_mse_loss_computation_correct(output, labels, expected_number):
loss_function_name = "MSELOSS"
    computed_loss = compute_loss(loss_function_name, output, labels)
assert pytest.approx(expected_number) == torch.sqrt(computed_loss)


@pytest.mark.parametrize(
"output, labels, expected_number",
[(zero_tensor.reshape((10, 1)), one_tensor.reshape(10), 100)],
)
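# Expected value 100: BCELoss of an all-zero prediction against an all-one
# target is -log(0), whose log term PyTorch clamps to -100, so each element
# contributes a loss of 100.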
def test_bce_loss_computation_correct(output, labels, expected_number):
loss_function_name = "BCELOSS"
computed_loss = compute_loss(loss_function_name, output, labels)
assert pytest.approx(expected_number) == computed_loss
33 changes: 33 additions & 0 deletions training/tests/test_model.py
@@ -0,0 +1,33 @@
import pytest
import torch.nn as nn
from training.core.dl_model import DLModel


@pytest.mark.parametrize(
"input_list",
[
([nn.Linear(10, 5), nn.Linear(5, 3)]),
([nn.Linear(0, 0), nn.Linear(0, 0)]),
([nn.Linear(100, 50), nn.Linear(5, 3)]),
],
)
def test_dlmodel(input_list):
    my_model = DLModel(input_list)
    modules = [
        module
        for module in my_model.model.modules()
        if not isinstance(module, nn.Sequential)
    ]
    print(f"modules of {my_model}: {modules}")
    assert modules == input_list
1 change: 1 addition & 0 deletions training/training/core/dl_model.py
@@ -43,6 +43,7 @@ def fromLayerParamsList(cls, layer_params_list: list[LayerParams]):
def build_model(self, layer_list):
model = nn.Sequential()
ctr = 1
print("Hello World")
for layer in layer_list:
model.add_module(f"layer #{ctr}: {str(layer.__class__.__name__)}", layer)
ctr += 1