diff --git a/train.py b/train.py
index ae2298fc..198922da 100644
--- a/train.py
+++ b/train.py
@@ -16,7 +16,6 @@
     yolo_tiny_anchors, yolo_tiny_anchor_masks
 )
 from yolov3_tf2.utils import freeze_all
-import yolov3_tf2.dataset as dataset
 
 flags.DEFINE_string('dataset', '', 'path to dataset')
 flags.DEFINE_string('val_dataset', '', 'path to validation dataset')
@@ -42,6 +41,7 @@
 flags.DEFINE_integer('num_classes', 80, 'number of classes in the model')
 flags.DEFINE_integer('weights_num_classes', None, 'specify num class for `weights` file if different, '
                      'useful in transfer learning with different number of classes')
+flags.DEFINE_boolean('use_ragged_dataset', False, 'whether to use the tf.ragged dataset pipeline')
 
 
 def main(_argv):
@@ -59,12 +59,17 @@ def main(_argv):
         anchors = yolo_anchors
         anchor_masks = yolo_anchor_masks
 
+    if FLAGS.use_ragged_dataset:
+        import yolov3_tf2.ragged_dataset as dataset
+    else:
+        import yolov3_tf2.dataset as dataset
+
     train_dataset = dataset.load_fake_dataset()
     if FLAGS.dataset:
         train_dataset = dataset.load_tfrecord_dataset(
             FLAGS.dataset, FLAGS.classes, FLAGS.size)
     train_dataset = train_dataset.shuffle(buffer_size=512)
-    train_dataset = train_dataset.batch(FLAGS.batch_size)
+    train_dataset = train_dataset.batch(FLAGS.batch_size, drop_remainder=FLAGS.use_ragged_dataset)
     train_dataset = train_dataset.map(lambda x, y: (
         dataset.transform_images(x, FLAGS.size),
         dataset.transform_targets(y, anchors, anchor_masks, FLAGS.size)))
@@ -75,7 +80,7 @@ def main(_argv):
     if FLAGS.val_dataset:
         val_dataset = dataset.load_tfrecord_dataset(
             FLAGS.val_dataset, FLAGS.classes, FLAGS.size)
-    val_dataset = val_dataset.batch(FLAGS.batch_size)
+    val_dataset = val_dataset.batch(FLAGS.batch_size, drop_remainder=FLAGS.use_ragged_dataset)
     val_dataset = val_dataset.map(lambda x, y: (
         dataset.transform_images(x, FLAGS.size),
         dataset.transform_targets(y, anchors, anchor_masks, FLAGS.size)))
diff --git a/yolov3_tf2/ragged_dataset.py b/yolov3_tf2/ragged_dataset.py
new file mode 100644
index 00000000..95bf6919
--- /dev/null
+++ b/yolov3_tf2/ragged_dataset.py
@@ -0,0 +1,149 @@
+import tensorflow as tf
+from absl.flags import FLAGS
+
+
+@tf.function
+def transform_targets_for_output(y_true, grid_size, anchor_idxs):
+    # y_true: (batch_size, (nbboxes), (x1, y1, x2, y2, class, best_anchor))
+    N = y_true.nrows()
+
+    # y_true_out: (N, grid, grid, anchors, [x, y, w, h, obj, class])
+    y_true_out = tf.zeros(
+        (N, grid_size, grid_size, tf.shape(anchor_idxs)[0], 6))
+
+    anchor_idxs = tf.cast(anchor_idxs, tf.int32)
+
+    indexes = tf.TensorArray(tf.int32, 1, dynamic_size=True)
+    updates = tf.TensorArray(tf.float32, 1, dynamic_size=True)
+    idx = 0
+    for i in tf.range(N):
+        for j in tf.range(y_true[i].nrows()):
+            anchor_eq = tf.equal(
+                anchor_idxs, tf.cast(y_true[i][j][5], tf.int32))
+
+            if tf.reduce_any(anchor_eq):
+                box = y_true[i][j][0:4]
+                box_xy = (y_true[i][j][0:2] + y_true[i][j][2:4]) / 2
+
+                anchor_idx = tf.cast(tf.where(anchor_eq), tf.int32)
+                grid_xy = tf.cast(box_xy // (1/grid_size), tf.int32)
+
+                # grid[y][x][anchor] = (tx, ty, bw, bh, obj, class)
+                indexes = indexes.write(
+                    idx, [i, grid_xy[1], grid_xy[0], anchor_idx[0][0]])
+                updates = updates.write(
+                    idx, [box[0], box[1], box[2], box[3], 1, y_true[i][j][4]])
+                idx += 1
+
+    # tf.print(indexes.stack())
+    # tf.print(updates.stack())
+
+    return tf.tensor_scatter_nd_update(
+        y_true_out, indexes.stack(), updates.stack())
+
+
+def transform_targets(y_train, anchors, anchor_masks, size):
+    y_outs = []
+    grid_size = size // 32
+
+    y_train = y_train.merge_dims(1, 2)
+
+    # calculate anchor index for true boxes
+    anchors = tf.cast(anchors, tf.float32)
+    anchor_area = anchors[..., 0] * anchors[..., 1]
+    box_wh = y_train[..., 2:4] - y_train[..., 0:2]
+    box_wh = tf.tile(tf.expand_dims(box_wh, -2),
+                     (1, 1, tf.shape(anchors)[0], 1))
+    box_area = box_wh[..., 0] * box_wh[..., 1]
+    intersection = tf.minimum(box_wh[..., 0], anchors[..., 0]) * \
+        tf.minimum(box_wh[..., 1], anchors[..., 1])
+    iou = intersection / (box_area + anchor_area - intersection)
+
+    # tf.argmax does not support ragged tensors yet, so compute the best
+    # anchor per batch row; this needs a static batch size, which is why
+    # train.py batches with drop_remainder when this pipeline is enabled
+    anchor_idx_list = []
+    for row_index in range(FLAGS.batch_size):
+        iou_single = iou[row_index]
+        anchor_idx = tf.cast(tf.argmax(iou_single, axis=-1), tf.float32)
+        anchor_idx = tf.expand_dims(anchor_idx, axis=-1)
+        anchor_idx_list.append(anchor_idx)
+    anchor_idx = tf.ragged.stack(anchor_idx_list)
+
+    y_train = tf.concat([y_train, anchor_idx], axis=-1)
+
+    for anchor_idxs in anchor_masks:
+        y_outs.append(transform_targets_for_output(
+            y_train, grid_size, anchor_idxs))
+        grid_size *= 2
+
+    return tuple(y_outs)
+
+
+def transform_images(x_train, size):
+    x_train = tf.image.resize(x_train, (size, size))
+    x_train = x_train / 255
+    return x_train
+
+
+# https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/using_your_own_dataset.md#conversion-script-outline-conversion-script-outline
+# Commented-out fields are not required in our project
+IMAGE_FEATURE_MAP = {
+    # 'image/width': tf.io.FixedLenFeature([], tf.int64),
+    # 'image/height': tf.io.FixedLenFeature([], tf.int64),
+    # 'image/filename': tf.io.FixedLenFeature([], tf.string),
+    # 'image/source_id': tf.io.FixedLenFeature([], tf.string),
+    # 'image/key/sha256': tf.io.FixedLenFeature([], tf.string),
+    'image/encoded': tf.io.FixedLenFeature([], tf.string),
+    # 'image/format': tf.io.FixedLenFeature([], tf.string),
+    'image/object/bbox/xmin': tf.io.VarLenFeature(tf.float32),
+    'image/object/bbox/ymin': tf.io.VarLenFeature(tf.float32),
+    'image/object/bbox/xmax': tf.io.VarLenFeature(tf.float32),
+    'image/object/bbox/ymax': tf.io.VarLenFeature(tf.float32),
+    'image/object/class/text': tf.io.VarLenFeature(tf.string),
+    # 'image/object/class/label': tf.io.VarLenFeature(tf.int64),
+    # 'image/object/difficult': tf.io.VarLenFeature(tf.int64),
+    # 'image/object/truncated': tf.io.VarLenFeature(tf.int64),
+    # 'image/object/view': tf.io.VarLenFeature(tf.string),
+}
+
+
+def parse_tfrecord(tfrecord, class_table, size):
+    x = tf.io.parse_single_example(tfrecord, IMAGE_FEATURE_MAP)
+    x_train = tf.image.decode_jpeg(x['image/encoded'], channels=3)
+    x_train = tf.image.resize(x_train, (size, size))
+
+    class_text = tf.sparse.to_dense(
+        x['image/object/class/text'], default_value='')
+    labels = tf.cast(class_table.lookup(class_text), tf.float32)
+    y_train = tf.stack([tf.sparse.to_dense(x['image/object/bbox/xmin']),
+                        tf.sparse.to_dense(x['image/object/bbox/ymin']),
+                        tf.sparse.to_dense(x['image/object/bbox/xmax']),
+                        tf.sparse.to_dense(x['image/object/bbox/ymax']),
+                        labels], axis=1)
+
+    # wrap the labels in a single-row ragged tensor so that batching
+    # preserves the per-image number of boxes
+    y_train = tf.RaggedTensor.from_row_splits(
+        y_train, [0, tf.shape(y_train)[0]])
+
+    return x_train, y_train
+
+
+def load_tfrecord_dataset(file_pattern, class_file, size=416):
+    LINE_NUMBER = -1  # TODO: use tf.lookup.TextFileIndex.LINE_NUMBER
+    class_table = tf.lookup.StaticHashTable(tf.lookup.TextFileInitializer(
+        class_file, tf.string, 0, tf.int64, LINE_NUMBER, delimiter="\n"), -1)
+
+    files = tf.data.Dataset.list_files(file_pattern)
+    dataset = files.flat_map(tf.data.TFRecordDataset)
+    return dataset.map(lambda x: parse_tfrecord(x, class_table, size))
+
+
+def load_fake_dataset():
+    x_train = tf.image.decode_jpeg(
+        open('./data/girl.png', 'rb').read(), channels=3)
+    x_train = tf.expand_dims(x_train, axis=0)
+
+    labels = [
+        [0.18494931, 0.03049111, 0.9435849, 0.96302897, 0],
+        [0.01586703, 0.35938117, 0.17582396, 0.6069674, 56],
+        [0.09158827, 0.48252046, 0.26967454, 0.6403017, 67]
+    ] + [[0, 0, 0, 0, 0]] * 5
+    y_train = tf.convert_to_tensor(labels, tf.float32)
+    y_train = tf.expand_dims(y_train, axis=0)
+
+    return tf.data.Dataset.from_tensor_slices((x_train, y_train))
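
A note on the label shapes this patch relies on: `parse_tfrecord` wraps each image's `(nboxes, 5)` label matrix in a single-row `RaggedTensor`, so after batching the labels arrive as a rank-4 ragged batch that `transform_targets` flattens with `merge_dims(1, 2)`. A minimal sketch of those shapes, not part of the patch, using `tf.ragged.stack` as a stand-in for `dataset.batch`:

```python
import tensorflow as tf

# Labels for two images with different box counts, each pre-wrapped in a
# single-row ragged tensor the way parse_tfrecord does it.
a = tf.RaggedTensor.from_row_splits(tf.zeros((3, 5)), [0, 3])  # (1, None, 5)
b = tf.RaggedTensor.from_row_splits(tf.zeros((1, 5)), [0, 1])  # (1, None, 5)

y = tf.ragged.stack([a, b])    # (2, None, None, 5), like y after .batch(2)
y_train = y.merge_dims(1, 2)   # (2, None, 5): one ragged box list per image
print(y_train.row_lengths())   # [3, 1]
```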
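The per-row loop in `transform_targets` exists because `tf.argmax` cannot consume a `RaggedTensor` directly; it is also why `train.py` now batches with `drop_remainder=FLAGS.use_ragged_dataset`, since `range(FLAGS.batch_size)` assumes every batch is full. A toy reproduction of the workaround, with made-up IoU values:

```python
import tensorflow as tf

# IoU of each box against 3 anchors; image 0 has one box, image 1 has two.
iou = tf.ragged.constant(
    [[[0.1, 0.7, 0.2]],
     [[0.3, 0.1, 0.9], [0.8, 0.2, 0.1]]],
    ragged_rank=1, inner_shape=(3,))

rows = []
for i in range(2):  # static batch size, guaranteed by drop_remainder=True
    best = tf.cast(tf.argmax(iou[i], axis=-1), tf.float32)
    rows.append(tf.expand_dims(best, axis=-1))

anchor_idx = tf.ragged.stack(rows)
print(anchor_idx)  # [[[1.0]], [[2.0], [0.0]]]
```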
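Finally, `transform_targets_for_output` collects `(batch, grid_y, grid_x, anchor)` indices in a `TensorArray` and writes every box into the dense target grid with a single scatter. The same mechanism in isolation, on a deliberately tiny grid with made-up numbers:

```python
import tensorflow as tf

# One box assigned to batch 0, grid cell (y=1, x=0), anchor slot 2.
grid = tf.zeros((1, 2, 2, 3, 6))                         # (N, gy, gx, anchors, 6)
indexes = tf.constant([[0, 1, 0, 2]])                    # one (i, gy, gx, anchor)
updates = tf.constant([[0.2, 0.3, 0.6, 0.7, 1.0, 5.0]])  # x1, y1, x2, y2, obj, class

out = tf.tensor_scatter_nd_update(grid, indexes, updates)
print(out[0, 1, 0, 2])  # [0.2 0.3 0.6 0.7 1.  5. ]
```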