Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

solution to Issue 70 & new options #72

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 28 additions & 15 deletions n2v/internals/N2V_DataGenerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ class N2V_DataGenerator():
The 'N2V_DataGenerator' enables training and validation data generation for Noise2Void.
"""

def load_imgs(self, files, dims='YX'):
def load_imgs(self, files, to32bit, dims='YX'):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This parameter should have a default value of False. This will ensure backwards compatibility.

"""
Helper to read a list of files. The images are not required to have same size,
but have to be of same dimensionality.
Expand All @@ -21,7 +21,8 @@ def load_imgs(self, files, dims='YX'):
List of paths to tiff-files.
dims : String, optional(default='YX')
Dimensions of the images to read. Known dimensions are: 'TZYXC'

to32bit: makes conversion to 32 bit optional, if set to False the original datatype is used

Returns
-------
images : list(array(float))
Expand Down Expand Up @@ -65,7 +66,10 @@ def load_imgs(self, files, dims='YX'):
else:
_raise("Filetype '{}' is not supported.".format(f))

img = imread(f).astype(np.float32)
if to32bit:
img = imread(f).astype(np.float32)
else:
img = imread(f)
assert len(img.shape) == len(dims), "Number of image dimensions doesn't match 'dims'."

img = np.moveaxis(img, move_axis_from, move_axis_to)
Expand All @@ -80,7 +84,7 @@ def load_imgs(self, files, dims='YX'):

return imgs

def load_imgs_from_directory(self, directory, filter='*.tif', dims='YX'):
def load_imgs_from_directory(self, directory, filter='*.tif', dims='YX', names_back = False, to32bit = True):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

to32bit = False for backwards compatibility.

"""
Helper to read all files which match 'filter' from a directory. The images are not required to have same size,
but have to be of same dimensionality.
Expand All @@ -94,15 +98,23 @@ def load_imgs_from_directory(self, directory, filter='*.tif', dims='YX'):
dims : String, optional(default='YX')
Dimensions of the images to read. Known dimensions are: 'TZYXC'

names_back: if set to True, the function returns the names of the input files as list

to32bit: makes conversion to 32 bit optional, if set to False the original datatype is used

Returns
-------
images : list(array(float))
A list of the read tif-files. The images have dimensionality 'SZYXC' or 'SYXC'
(optional): files, names of the input files as list
"""

files = glob(join(directory, filter))
files.sort()
return self.load_imgs(files, dims=dims)
if names_back:
return files, self.load_imgs(files, to32bit, dims=dims)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure what the best practice would be here, since this changes the return value of this method.

I think the best approach would be to refactor the code so that we have two methods: the old one as it is, and a new one which also returns the file list. Both of these methods should probably wrap a new third method which takes a file list and returns the image list.

I would also put the files as the second return value, so that the user always gets the image list as the first return value.

else:
return self.load_imgs(files, to32bit, dims=dims)


def generate_patches_from_list(self, data, num_patches_per_img=None, shape=(256, 256), augment=True, shuffle=False):
Expand Down Expand Up @@ -142,7 +154,7 @@ def generate_patches_from_list(self, data, num_patches_per_img=None, shape=(256,

return patches

def generate_patches(self, data, num_patches=None, shape=(256, 256), augment=True):
def generate_patches(self, data, num_patches=None, shape=(256, 256), augment=True, shuffle_patches = True):
"""
Extracts patches from 'data'. The patches can be augmented, which means they get rotated three times
in XY-Plane and flipped along the X-Axis. Augmentation leads to an eight-fold increase in training data.
Expand Down Expand Up @@ -174,17 +186,18 @@ def generate_patches(self, data, num_patches=None, shape=(256, 256), augment=Tru
if augment:
print("XY-Plane is not square. Omit augmentation!")

np.random.shuffle(patches)
print('Generated patches:', patches.shape)
if shuffle_patches:
np.random.shuffle(patches)
#print('Generated patches:', patches.shape)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the reason to remove this print statement?

return patches

def __extract_patches__(self, data, num_patches=None, shape=(256, 256), n_dims=2):
if num_patches == None:
patches = []
if n_dims == 2:
if data.shape[1] > shape[0] and data.shape[2] > shape[1]:
for y in range(0, data.shape[1] - shape[0], shape[0]):
for x in range(0, data.shape[2] - shape[1], shape[1]):
if data.shape[1] >= shape[0] and data.shape[2] >= shape[1]:
for y in range(0, data.shape[1] - shape[0]+1, shape[0]):
for x in range(0, data.shape[2] - shape[1]+1, shape[1]):
patches.append(data[:, y:y + shape[0], x:x + shape[1]])

return np.concatenate(patches)
Expand All @@ -193,10 +206,10 @@ def __extract_patches__(self, data, num_patches=None, shape=(256, 256), n_dims=2
else:
print("'shape' is too big.")
elif n_dims == 3:
if data.shape[1] > shape[0] and data.shape[2] > shape[1] and data.shape[3] > shape[2]:
for z in range(0, data.shape[1] - shape[0], shape[0]):
for y in range(0, data.shape[2] - shape[1], shape[1]):
for x in range(0, data.shape[3] - shape[2], shape[2]):
if data.shape[1] >= shape[0] and data.shape[2] >= shape[1] and data.shape[3] >= shape[2]:
for z in range(0, data.shape[1] - shape[0]+1, shape[0]):
for y in range(0, data.shape[2] - shape[1]+1, shape[1]):
for x in range(0, data.shape[3] - shape[2]+1, shape[2]):
patches.append(data[:, z:z + shape[0], y:y + shape[1], x:x + shape[2]])

return np.concatenate(patches)
Expand Down