Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ragged dataset support #162

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
yolo_tiny_anchors, yolo_tiny_anchor_masks
)
from yolov3_tf2.utils import freeze_all
import yolov3_tf2.dataset as dataset

flags.DEFINE_string('dataset', '', 'path to dataset')
flags.DEFINE_string('val_dataset', '', 'path to validation dataset')
Expand All @@ -42,6 +41,7 @@
flags.DEFINE_integer('num_classes', 80, 'number of classes in the model')
flags.DEFINE_integer('weights_num_classes', None, 'specify num class for `weights` file if different, '
'useful in transfer learning with different number of classes')
flags.DEFINE_boolean('use_ragged_dataset', False, 'use tf.ragged api or not')


def main(_argv):
Expand All @@ -59,12 +59,17 @@ def main(_argv):
anchors = yolo_anchors
anchor_masks = yolo_anchor_masks

if FLAGS.use_ragged_dataset:
import yolov3_tf2.ragged_dataset as dataset
else:
import yolov3_tf2.dataset as dataset

train_dataset = dataset.load_fake_dataset()
if FLAGS.dataset:
train_dataset = dataset.load_tfrecord_dataset(
FLAGS.dataset, FLAGS.classes, FLAGS.size)
train_dataset = train_dataset.shuffle(buffer_size=512)
train_dataset = train_dataset.batch(FLAGS.batch_size)
train_dataset = train_dataset.batch(FLAGS.batch_size, drop_remainder=FLAGS.use_ragged_dataset)
train_dataset = train_dataset.map(lambda x, y: (
dataset.transform_images(x, FLAGS.size),
dataset.transform_targets(y, anchors, anchor_masks, FLAGS.size)))
Expand All @@ -75,7 +80,7 @@ def main(_argv):
if FLAGS.val_dataset:
val_dataset = dataset.load_tfrecord_dataset(
FLAGS.val_dataset, FLAGS.classes, FLAGS.size)
val_dataset = val_dataset.batch(FLAGS.batch_size)
val_dataset = val_dataset.batch(FLAGS.batch_size, drop_remainder=FLAGS.use_ragged_dataset)
val_dataset = val_dataset.map(lambda x, y: (
dataset.transform_images(x, FLAGS.size),
dataset.transform_targets(y, anchors, anchor_masks, FLAGS.size)))
Expand Down
149 changes: 149 additions & 0 deletions yolov3_tf2/ragged_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
import tensorflow as tf
from absl.flags import FLAGS

@tf.function
def transform_targets_for_output(y_true, grid_size, anchor_idxs):
# y_true: (batch_size, (nbboxes), (x1, y1, x2, y2, class, best_anchor))
N = y_true.nrows()

# y_true_out: (N, grid, grid, anchors, [x, y, w, h, obj, class])
y_true_out = tf.zeros(
(N, grid_size, grid_size, tf.shape(anchor_idxs)[0], 6))

anchor_idxs = tf.cast(anchor_idxs, tf.int32)

indexes = tf.TensorArray(tf.int32, 1, dynamic_size=True)
updates = tf.TensorArray(tf.float32, 1, dynamic_size=True)
idx = 0
for i in tf.range(N):
for j in tf.range(y_true[i].nrows()):
anchor_eq = tf.equal(
anchor_idxs, tf.cast(y_true[i][j][5], tf.int32))

if tf.reduce_any(anchor_eq):
box = y_true[i][j][0:4]
box_xy = (y_true[i][j][0:2] + y_true[i][j][2:4]) / 2

anchor_idx = tf.cast(tf.where(anchor_eq), tf.int32)
grid_xy = tf.cast(box_xy // (1/grid_size), tf.int32)

# grid[y][x][anchor] = (tx, ty, bw, bh, obj, class)
indexes = indexes.write(
idx, [i, grid_xy[1], grid_xy[0], anchor_idx[0][0]])
updates = updates.write(
idx, [box[0], box[1], box[2], box[3], 1, y_true[i][j][4]])
idx += 1

# tf.print(indexes.stack())
# tf.print(updates.stack())

return tf.tensor_scatter_nd_update(
y_true_out, indexes.stack(), updates.stack())


def transform_targets(y_train, anchors, anchor_masks, size):
y_outs = []
grid_size = size // 32

y_train = y_train.merge_dims(1, 2)
# calculate anchor index for true boxes
anchors = tf.cast(anchors, tf.float32)
anchor_area = anchors[..., 0] * anchors[..., 1]
box_wh = y_train[..., 2:4] - y_train[..., 0:2]
box_wh = tf.tile(tf.expand_dims(box_wh, -2),
(1, 1, tf.shape(anchors)[0], 1))
box_area = box_wh[..., 0] * box_wh[..., 1]
intersection = tf.minimum(box_wh[..., 0], anchors[..., 0]) * \
tf.minimum(box_wh[..., 1], anchors[..., 1])
iou = intersection / (box_area + anchor_area - intersection)

# tf.ragged.argmax is not ready, some dirty code
anchor_idx_list = []
for row_index in range(FLAGS.batch_size):
iou_single = iou[row_index]
anchor_idx = tf.cast(tf.argmax(iou_single, axis=-1), tf.float32)
anchor_idx = tf.expand_dims(anchor_idx, axis=-1)
anchor_idx_list.append(anchor_idx)
anchor_idx = tf.ragged.stack(anchor_idx_list)

y_train = tf.concat([y_train, anchor_idx], axis=-1)

for anchor_idxs in anchor_masks:
y_outs.append(transform_targets_for_output(
y_train, grid_size, anchor_idxs))
grid_size *= 2

return tuple(y_outs)


def transform_images(x_train, size):
x_train = tf.image.resize(x_train, (size, size))
x_train = x_train / 255
return x_train


# https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/using_your_own_dataset.md#conversion-script-outline-conversion-script-outline
# Commented out fields are not required in our project
IMAGE_FEATURE_MAP = {
# 'image/width': tf.io.FixedLenFeature([], tf.int64),
# 'image/height': tf.io.FixedLenFeature([], tf.int64),
# 'image/filename': tf.io.FixedLenFeature([], tf.string),
# 'image/source_id': tf.io.FixedLenFeature([], tf.string),
# 'image/key/sha256': tf.io.FixedLenFeature([], tf.string),
'image/encoded': tf.io.FixedLenFeature([], tf.string),
# 'image/format': tf.io.FixedLenFeature([], tf.string),
'image/object/bbox/xmin': tf.io.VarLenFeature(tf.float32),
'image/object/bbox/ymin': tf.io.VarLenFeature(tf.float32),
'image/object/bbox/xmax': tf.io.VarLenFeature(tf.float32),
'image/object/bbox/ymax': tf.io.VarLenFeature(tf.float32),
'image/object/class/text': tf.io.VarLenFeature(tf.string),
# 'image/object/class/label': tf.io.VarLenFeature(tf.int64),
# 'image/object/difficult': tf.io.VarLenFeature(tf.int64),
# 'image/object/truncated': tf.io.VarLenFeature(tf.int64),
# 'image/object/view': tf.io.VarLenFeature(tf.string),
}


def parse_tfrecord(tfrecord, class_table, size):
x = tf.io.parse_single_example(tfrecord, IMAGE_FEATURE_MAP)
x_train = tf.image.decode_jpeg(x['image/encoded'], channels=3)
x_train = tf.image.resize(x_train, (size, size))

class_text = tf.sparse.to_dense(
x['image/object/class/text'], default_value='')
labels = tf.cast(class_table.lookup(class_text), tf.float32)
y_train = tf.stack([tf.sparse.to_dense(x['image/object/bbox/xmin']),
tf.sparse.to_dense(x['image/object/bbox/ymin']),
tf.sparse.to_dense(x['image/object/bbox/xmax']),
tf.sparse.to_dense(x['image/object/bbox/ymax']),
labels], axis=1)

y_train = tf.RaggedTensor.from_row_splits(y_train, [0, tf.shape(y_train)[0]])

return x_train, y_train


def load_tfrecord_dataset(file_pattern, class_file, size=416):
LINE_NUMBER = -1 # TODO: use tf.lookup.TextFileIndex.LINE_NUMBER
class_table = tf.lookup.StaticHashTable(tf.lookup.TextFileInitializer(
class_file, tf.string, 0, tf.int64, LINE_NUMBER, delimiter="\n"), -1)

files = tf.data.Dataset.list_files(file_pattern)
dataset = files.flat_map(tf.data.TFRecordDataset)
return dataset.map(lambda x: parse_tfrecord(x, class_table, size))


def load_fake_dataset():
x_train = tf.image.decode_jpeg(
open('./data/girl.png', 'rb').read(), channels=3)
x_train = tf.expand_dims(x_train, axis=0)

labels = [
[0.18494931, 0.03049111, 0.9435849, 0.96302897, 0],
[0.01586703, 0.35938117, 0.17582396, 0.6069674, 56],
[0.09158827, 0.48252046, 0.26967454, 0.6403017, 67]
] + [[0, 0, 0, 0, 0]] * 5
y_train = tf.convert_to_tensor(labels, tf.float32)
y_train = tf.expand_dims(y_train, axis=0)

return tf.data.Dataset.from_tensor_slices((x_train, y_train))