Add negative data training and testing #119

Draft: wants to merge 4 commits into main
33 changes: 27 additions & 6 deletions yoeo/test.py
@@ -10,21 +10,21 @@
from terminaltables import AsciiTable

import torch
from torch.utils.data import DataLoader
from torch.utils.data import DataLoader, ConcatDataset
from torch.autograd import Variable

from yoeo.models import load_model
from yoeo.utils.utils import ap_per_class, get_batch_statistics, non_max_suppression, to_cpu, xywh2xyxy, \
print_environment_info, seg_iou
from yoeo.utils.datasets import ListDataset
from yoeo.utils.datasets import ListDataset, NegativeDataset
from yoeo.utils.transforms import DEFAULT_TRANSFORMS
from yoeo.utils.dataclasses import ClassNames
from yoeo.utils.class_config import ClassConfig
from yoeo.utils.parse_config import parse_data_config
from yoeo.utils.metric import Metric


def evaluate_model_file(model_path, weights_path, img_path, class_config, batch_size=8, img_size=416,
def evaluate_model_file(model_path, weights_path, img_path, class_config, negative_img_dir="", negative_data_fraction=0.0, batch_size=8, img_size=416,
n_cpu=8, iou_thres=0.5, conf_thres=0.5, nms_thres=0.5, verbose=True):
"""Evaluate model on validation dataset.

@@ -36,6 +36,10 @@
:type img_path: str
:param class_config: Object containing all class name related settings
:type class_config: TrainConfig
:param negative_img_dir: Path to negative image folder, defaults to ""
:type negative_img_dir: str, optional
:param negative_data_fraction: Fraction of negative data relative to positive data, defaults to 0.0
:type negative_data_fraction: float, optional
:param batch_size: Size of each image batch, defaults to 8
:type batch_size: int, optional
:param img_size: Size of each image dimension for yolo, defaults to 416
@@ -53,7 +57,7 @@
:return: Returns precision, recall, AP, f1, ap_class
"""
dataloader = _create_validation_data_loader(
img_path, batch_size, img_size, n_cpu)
img_path, negative_img_dir, negative_data_fraction, batch_size, img_size, n_cpu)
model = load_model(model_path, weights_path)
metrics_output, seg_class_ious, secondary_metric = _evaluate(
model,
@@ -67,10 +71,10 @@
return metrics_output, seg_class_ious, secondary_metric


def print_eval_stats(metrics_output: Optional[Tuple[np.ndarray]],
                     seg_class_ious: List[np.float64],
                     secondary_metric: Optional[Metric],
                     class_config: ClassConfig,
                     verbose: bool
                     ):
# Print detection statistics
@@ -95,7 +99,7 @@
if verbose:
classes = class_config.get_group_class_names()
mbACC_per_class = [secondary_metric.bACC(i) for i in range(len(classes))]

sec_table = [["Index", "Class", "bACC"]]
for i, c in enumerate(classes):
sec_table += [[i, c, "%.5f" % mbACC_per_class[i]]]
@@ -157,7 +161,7 @@
# Extract labels
labels += bb_targets[:, 1].tolist()

# If a subset of the detection classes should be grouped into one class for non-maximum suppression and the
# subsequent AP-computation, we need to group those class labels here.
if class_config.classes_should_be_grouped():
labels = class_config.group(labels)
@@ -223,12 +227,16 @@
return yolo_metrics_output, seg_class_ious, secondary_metric


def _create_validation_data_loader(img_path, batch_size, img_size, n_cpu):
def _create_validation_data_loader(img_path, negative_img_dir, negative_data_fraction, batch_size, img_size, n_cpu):
"""
Creates a DataLoader for validation.

:param img_path: Path to file containing all paths to validation images.
:type img_path: str
:param negative_img_dir: Path to negative image folder
:type negative_img_dir: str
:param negative_data_fraction: Fraction of negative data relative to positive data
:type negative_data_fraction: float
:param batch_size: Size of each image batch
:type batch_size: int
:param img_size: Size of each image dimension for yolo
@@ -239,8 +247,19 @@
:rtype: DataLoader
"""
dataset = ListDataset(img_path, img_size=img_size, multiscale=False, transform=DEFAULT_TRANSFORMS)

dataset_len = len(dataset)
negative_dataset_len = int(negative_data_fraction*dataset_len)

negative_dataset = NegativeDataset(
negative_img_dir,
img_size=img_size,
transform=DEFAULT_TRANSFORMS,
negative_dataset_max_len=negative_dataset_len)

concat_dataset = ConcatDataset([dataset, negative_dataset])
dataloader = DataLoader(
dataset,
concat_dataset,
batch_size=batch_size,
shuffle=False,
num_workers=n_cpu,
@@ -257,6 +276,8 @@
parser.add_argument("-w", "--weights", type=str, default="weights/yoeo.pth",
help="Path to weights or checkpoint file (.weights or .pth)")
parser.add_argument("-d", "--data", type=str, default="config/torso.data", help="Path to data config file (.data)")
parser.add_argument("-n", "--negative_data_dir", default='', type=str, help="Path to negative data directory")
parser.add_argument("--negative_data_fraction", default=0, type=float, help="Fraction of negative data relative to positive data (default=0.0)")
parser.add_argument("-b", "--batch_size", type=int, default=8, help="Size of each image batch")
parser.add_argument("-v", "--verbose", action='store_true', help="Makes the validation more verbose")
parser.add_argument("--img_size", type=int, default=416, help="Size of each image dimension for yolo")
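Both evaluate_model_file and _create_validation_data_loader now accept negative_img_dir and negative_data_fraction, and the CLI gains -n/--negative_data_dir and --negative_data_fraction (how run() forwards the parsed flags is outside the visible hunks). A rough usage sketch of the extended validation loader; the paths and values below are placeholders, not part of this PR:

from yoeo.test import _create_validation_data_loader

# Mix background-only images into the validation set. With 400 positive images
# and a fraction of 0.25, at most int(0.25 * 400) = 100 negative images are appended.
val_loader = _create_validation_data_loader(
    "data/valid.txt",        # hypothetical list file of annotated validation images
    "data/negative_images",  # hypothetical folder of background-only images
    0.25,                    # negative_data_fraction
    batch_size=8,
    img_size=416,
    n_cpu=4,
)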
31 changes: 26 additions & 5 deletions yoeo/train.py
@@ -9,18 +9,18 @@
import numpy as np

import torch
from torch.utils.data import DataLoader
from torch.utils.data import DataLoader, ConcatDataset
import torch.optim as optim
from torch.autograd import Variable

from yoeo.models import load_model
from yoeo.utils.logger import Logger
from yoeo.utils.utils import to_cpu, print_environment_info, provide_determinism, worker_seed_set
from yoeo.utils.datasets import ListDataset
from yoeo.utils.datasets import ListDataset, NegativeDataset
from yoeo.utils.dataclasses import ClassNames
from yoeo.utils.class_config import ClassConfig
from yoeo.utils.augmentations import AUGMENTATION_TRANSFORMS
from yoeo.utils.transforms import DEFAULT_TRANSFORMS

from yoeo.utils.parse_config import parse_data_config
from yoeo.utils.loss import compute_loss
from yoeo.test import _evaluate, _create_validation_data_loader
@@ -30,11 +30,15 @@
from torchsummary import summary


def _create_data_loader(img_path, batch_size, img_size, n_cpu, multiscale_training=False):
def _create_data_loader(img_path, negative_img_dir, negative_data_fraction, batch_size, img_size, n_cpu, multiscale_training=False):
"""Creates a DataLoader for training.

:param img_path: Path to file containing all paths to training images.
:type img_path: str
:param negative_img_dir: Path to negative image folder
:type negative_img_dir: str
:param negative_data_fraction: Fraction of negative data relative to positive data
:type negative_data_fraction: float
:param batch_size: Size of each image batch
:type batch_size: int
:param img_size: Size of each image dimension for yolo
@@ -51,8 +55,20 @@
img_size=img_size,
multiscale=multiscale_training,
transform=AUGMENTATION_TRANSFORMS)

dataset_len = len(dataset)
negative_dataset_len = int(negative_data_fraction*dataset_len)

negative_dataset = NegativeDataset(
negative_img_dir,
img_size=img_size,
transform=AUGMENTATION_TRANSFORMS,
negative_dataset_max_len=negative_dataset_len)

concat_dataset = ConcatDataset([dataset, negative_dataset])

dataloader = DataLoader(
dataset,
concat_dataset,
batch_size=batch_size,
shuffle=True,
num_workers=n_cpu,
Expand All @@ -61,12 +77,13 @@
worker_init_fn=worker_seed_set)
return dataloader


def run():

print_environment_info()
parser = argparse.ArgumentParser(description="Trains the YOLO model.")
parser.add_argument("-m", "--model", type=str, default="config/yoeo.cfg", help="Path to model definition file (.cfg)")
parser.add_argument("-d", "--data", type=str, default="config/torso.data", help="Path to data config file (.data)")
parser.add_argument("-n", "--negative_data_dir", default='', type=str, help="Path to negative data directory")
parser.add_argument("--negative_data_fraction", default=0, type=float, help="Fraction of negative data relative to positive data (default=0.0)")
parser.add_argument("-e", "--epochs", type=int, default=300, help="Number of epochs")
parser.add_argument("-v", "--verbose", action='store_true', help="Makes the training more verbose")
parser.add_argument("--n_cpu", type=int, default=8, help="Number of cpu threads to use during batch generation")
@@ -122,6 +139,8 @@
# Load training dataloader
dataloader = _create_data_loader(
train_path,
args.negative_data_dir,
args.negative_data_fraction,
mini_batch_size,
model.hyperparams['height'],
args.n_cpu,
@@ -130,6 +149,8 @@
# Load validation dataloader
validation_dataloader = _create_validation_data_loader(
valid_path,
args.negative_data_dir,
args.negative_data_fraction,
mini_batch_size,
model.hyperparams['height'],
args.n_cpu)
@@ -267,7 +288,7 @@
("validation/mAP", AP.mean()),
("validation/f1", f1.mean()),
("validation/seg_iou", np.array(seg_class_ious).mean())]

if metrics_output[2] is not None:
evaluation_metrics.append(("validation/secondary_mbACC", metrics_output[2].mbACC()))

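In train.py the same sizing rule applies: the negative dataset is capped at int(negative_data_fraction * len(positive dataset)) and merged with torch's ConcatDataset, so with the default fraction of 0 (or an empty --negative_data_dir) the combined dataset yields exactly the positive samples. A minimal, self-contained sketch of that arithmetic; the dataset sizes are invented for illustration:

from torch.utils.data import ConcatDataset, Dataset


class _Dummy(Dataset):
    """Stand-in for ListDataset / NegativeDataset; only the length matters here."""

    def __init__(self, n):
        self.n = n

    def __len__(self):
        return self.n

    def __getitem__(self, index):
        return index


positive = _Dummy(1200)                      # e.g. 1200 annotated training images
negative = _Dummy(int(0.1 * len(positive)))  # fraction 0.1 -> capped at 120 negatives
combined = ConcatDataset([positive, negative])
assert len(combined) == 1320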
52 changes: 43 additions & 9 deletions yoeo/utils/datasets.py
@@ -54,10 +54,48 @@ def __getitem__(self, index):

def __len__(self):
return len(self.files)


class NegativeDataset(Dataset):
def __init__(self, folder_path, img_size=416, transform=None, negative_dataset_max_len=0):
self.img_size = img_size
self.transform = transform
self.negative_dataset_max_len = negative_dataset_max_len
if folder_path:
self.files = sorted(glob.glob("%s/*.*" % folder_path))[:self.negative_dataset_max_len]
else:
self.files = []

def __getitem__(self, index):
img_path = self.files[index % len(self.files)]
img = np.array(
Image.open(img_path).convert('RGB'),
dtype=np.uint8)

# Label Placeholder
bb_targets = np.zeros((1, 5))
mask_targets = np.zeros_like(img)

# -----------
# Transform
# -----------
if self.transform:
try:
img, bb_targets, mask_targets = self.transform(
(img, bb_targets, mask_targets)
)
except Exception as e:
print(f"Could not apply transform.")
raise e

return img_path, img, bb_targets, mask_targets

def __len__(self):
return len(self.files)


class ListDataset(Dataset):
def __init__(self, list_path, img_size=416, multiscale=True, transform=None):
def __init__(self, list_path, img_size: int = 416, multiscale=True, transform=None):
with open(list_path, "r") as file:
self.img_files = file.readlines()

@@ -81,7 +119,7 @@ def __init__(self, list_path, img_size=416, multiscale=True, transform=None):
mask_file = os.path.splitext(mask_file)[0] + '.png'
self.mask_files.append(mask_file)

self.img_size = img_size
self.img_size: int = img_size
self.max_objects = 100
self.multiscale = multiscale
self.min_size = self.img_size - 3 * 32
@@ -94,9 +132,8 @@ def __getitem__(self, index):
# ---------
# Image
# ---------
img_path = self.img_files[index % len(self.img_files)].rstrip()
try:
img_path = self.img_files[index % len(self.img_files)].rstrip()

img = np.array(Image.open(img_path).convert('RGB'), dtype=np.uint8)
except Exception:
print(f"Could not read image '{img_path}'.")
Expand All @@ -105,9 +142,8 @@ def __getitem__(self, index):
# ---------
# Label
# ---------
label_path = self.label_files[index % len(self.img_files)].rstrip()
try:
label_path = self.label_files[index % len(self.img_files)].rstrip()

# Ignore warning if file is empty
with warnings.catch_warnings():
warnings.simplefilter("ignore")
Expand All @@ -119,8 +155,8 @@ def __getitem__(self, index):
# ---------
# Segmentation Mask
# ---------
mask_path = self.mask_files[index % len(self.img_files)].rstrip()
try:
mask_path = self.mask_files[index % len(self.img_files)].rstrip()
# Load segmentation mask as numpy array
mask = np.array(Image.open(mask_path).convert('RGB'))
except FileNotFoundError as e:
@@ -138,7 +174,6 @@ def __getitem__(self, index):
except Exception as e:
print(f"Could not apply transform.")
raise e
return

return img_path, img, bb_targets, mask_targets

@@ -147,7 +182,6 @@ def collate_fn(self, batch):

# Drop invalid images
batch = [data for data in batch if data is not None]

paths, imgs, bb_targets, mask_targets = list(zip(*batch))

# Selects new image size every tenth batch
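NegativeDataset mirrors ListDataset's return signature (img_path, img, bb_targets, mask_targets) but emits placeholder labels: a single all-zero bounding-box row and an all-zero segmentation mask with the same shape as the image. A quick sanity-check sketch, assuming a folder of background-only images exists at the path below (both the path and the cap of 50 are illustrative, not taken from the PR):

from yoeo.utils.datasets import NegativeDataset

ds = NegativeDataset("data/negative_images", img_size=416,
                     transform=None, negative_dataset_max_len=50)
if len(ds) > 0:
    path, img, bb_targets, mask_targets = ds[0]
    print(bb_targets.shape)    # (1, 5)  -- single all-zero box placeholder
    print(mask_targets.shape)  # same H x W x 3 shape as the raw image, all zeros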