-
Notifications
You must be signed in to change notification settings - Fork 148
/
ObjectDetectionDataset.swift
102 lines (92 loc) · 3.25 KB
/
ObjectDetectionDataset.swift
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
// Copyright 2020 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import Foundation
import ModelSupport
import TensorFlow
/// A reference to an image that records its dimensions and source location,
/// deferring construction of the pixel tensor until `tensor()` is called.
public struct LazyImage {
  /// The image width, stored exactly as passed to the initializer.
  /// NOTE(review): presumably in pixels — confirm against the producer.
  public let width: Int
  /// The image height, stored exactly as passed to the initializer.
  /// NOTE(review): presumably in pixels — confirm against the producer.
  public let height: Int
  /// The location the image is loaded from, or `nil` if there is none.
  public let url: URL?

  /// Creates a lazy image description.
  ///
  /// The values are stored as given; nothing is loaded and the dimensions are
  /// not checked against the content at `url`.
  ///
  /// - Parameters:
  ///   - width: The image width.
  ///   - height: The image height.
  ///   - url: The location of the image, or `nil` when unavailable.
  public init(width: Int, height: Int, url: URL?) {
    self.width = width
    self.height = height
    self.url = url
  }

  /// Builds the image tensor by passing `url` to `Image(contentsOf:)`.
  ///
  /// - Returns: The `tensor` of the image created from `url`, or `nil` when
  ///   `url` is `nil`.
  public func tensor() -> Tensor<Float>? {
    // Bind the optional instead of force-unwrapping after a nil check.
    guard let url = url else { return nil }
    return Image(contentsOf: url).tensor
  }
}
/// One annotated object: a box described by its x/y extents, a class label,
/// and optional crowd-flag and mask data.
public struct LabeledObject {
  /// The lower x bound of the object's box.
  /// NOTE(review): units (pixels vs. normalized) are not established here — confirm.
  public let xMin: Float
  /// The upper x bound of the object's box.
  public let xMax: Float
  /// The lower y bound of the object's box.
  public let yMin: Float
  /// The upper y bound of the object's box.
  public let yMax: Float
  /// The name of the object's class.
  public let className: String
  /// The integer identifier of the object's class.
  public let classId: Int
  /// An optional crowd indicator.
  /// NOTE(review): presumably mirrors the COCO `iscrowd` field — confirm.
  public let isCrowd: Int?
  /// The area recorded for the object, stored as given.
  public let area: Float
  /// The object's mask as an `RLE` value, or `nil` when absent.
  public let maskRLE: RLE?

  /// Creates a labeled object; every argument is stored unchanged in the
  /// property that shares its label.
  public init(
    xMin lowX: Float, xMax highX: Float,
    yMin lowY: Float, yMax highY: Float,
    className name: String, classId identifier: Int,
    isCrowd crowdFlag: Int?, area regionArea: Float, maskRLE mask: RLE?
  ) {
    xMin = lowX
    xMax = highX
    yMin = lowY
    yMax = highY
    className = name
    classId = identifier
    isCrowd = crowdFlag
    area = regionArea
    maskRLE = mask
  }
}
/// A single dataset example: an image paired with the objects labeled in it.
public struct ObjectDetectionExample: KeyPathIterable {
  /// The example's image, held as a `LazyImage`.
  public let image: LazyImage
  /// The labeled objects that accompany `image`.
  public let objects: [LabeledObject]

  /// Creates an example from an image and its labeled objects.
  ///
  /// - Parameters:
  ///   - image: The image for this example.
  ///   - objects: The labeled objects for this example.
  public init(image lazyImage: LazyImage, objects labeledObjects: [LabeledObject]) {
    image = lazyImage
    objects = labeledObjects
  }
}
/// A type that provides an object detection dataset, exposing both training
/// and validation data as batches of `ObjectDetectionExample`.
public protocol ObjectDetectionData {
/// The type of the training data, represented as a sequence of epochs, each
/// of which is a collection of batches; a batch is an array of
/// `ObjectDetectionExample`.
associatedtype Training: Sequence
where Training.Element: Collection, Training.Element.Element == [ObjectDetectionExample]
/// The type of the validation data, represented as a collection of batches;
/// a batch is an array of `ObjectDetectionExample`.
associatedtype Validation: Collection where Validation.Element == [ObjectDetectionExample]
/// Creates an instance from training and validation sources.
///
/// NOTE(review): the parameter semantics below are inferred from names and
/// types only; confirm them against a conforming type.
///
/// - Parameters:
///   - training: The `COCO` value supplying the training examples.
///   - validation: The `COCO` value supplying the validation examples.
///   - includeMasks: Presumably whether mask data is attached to the labeled
///     objects — confirm.
///   - batchSize: Presumably the number of examples per batch — confirm.
///   - device: The `Device` associated with the dataset; how it is used is
///     left to the conforming type.
///   - transform: An escaping closure that maps one example to an array of
///     examples; when and how often it is applied is left to the conforming
///     type.
init(
training: COCO, validation: COCO, includeMasks: Bool, batchSize: Int, on device: Device,
transform: @escaping (ObjectDetectionExample) -> [ObjectDetectionExample])
/// The `training` epochs.
var training: Training { get }
/// The `validation` batches.
var validation: Validation { get }
// The following is probably going to be necessary since we can't extract that
// information from `Epochs` or `Batches`. The requirements are kept commented
// out below and are not part of the protocol yet.
//
// The number of samples in the `training` set.
//var trainingSampleCount: Int {get}
// The number of samples in the `validation` set.
//var validationSampleCount: Int {get}
}