Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add dataflow module #1028

Draft
wants to merge 11 commits into
base: master
Choose a base branch
from
4 changes: 4 additions & 0 deletions tensorlayer/dataflow/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from .base import Dataset
from .base import Transform
from .common import Dataloader
from .common import TFDataloader
73 changes: 73 additions & 0 deletions tensorlayer/dataflow/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
class Dataset(object):

def __getitem__(self, index):
raise NotImplementedError("A Dataset must implement __getitem__(index) method.")

def __len__(self):
raise NotImplementedError("A Dataset must implement __len__() method.")

def __iter__(self):
for i in range(self.__len__()):
yield self.__getitem__(i)

def __call__(self, *args, **kwargs):
return self.__iter__()


class DatasetWrapper(object):
def __init__(self, ds):
self.ds = ds
self.ds_len = len(ds)

def __len__(self):
return len(self.ds)

def __iter__(self):
for dp in self.ds:
yield dp

def __call__(self, *args, **kwargs):
return self.__iter__()


class IndexableDatasetWrapper(object):
def __init__(self, ds):
self.ds = ds
self.ds_len = len(ds)

def __getitem__(self, index):
return self.ds.__getitem__(index)

def __len__(self):
return len(self.ds)

def __call__(self, *args, **kwargs):
return self


class Transform(object):
def __call__(self, *args, **kwargs):
raise NotImplementedError("Transform must implement __call__() method.")


class _Transforms_for_tf_dataset(object):
"""
This class aggregate Transforms into one object in order to use tf.data.Dataset.map API
"""

def __init__(self, transforms):
self.transforms = transforms

def __call__(self, *args):
# assert len(args) == len(self.transforms)
# data_list = [None] * len(args)
# for i in range(len(args)):
# data = args[i]
# for transform in self.transforms[i]:
# data = transform(data)
# data_list[i] = data
# return data_list
data_list = list(args)
for transform in self.transforms:
data_list = transform(*data_list)
return data_list
Loading