From f4e8deb2bc719fc4f26bae5e406bc0265022d04e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kaan=20Ak=C5=9Fit?= Date: Thu, 30 Nov 2023 11:45:04 +0000 Subject: [PATCH] Updating two utility functions to support both 3xMxN and MxNx3: odak.learn.tools.zero_pad and odak.learn.tools.crop_center. --- odak/learn/tools/matrix.py | 16 +++++++++-- ...learn_lenses.py => test_learn_gratings.py} | 0 test/test_learn_tools.py | 27 +++++++++++++++++++ 3 files changed, 41 insertions(+), 2 deletions(-) rename test/{test_learn_lenses.py => test_learn_gratings.py} (100%) create mode 100644 test/test_learn_tools.py diff --git a/odak/learn/tools/matrix.py b/odak/learn/tools/matrix.py index 872167a3..80b5e4da 100644 --- a/odak/learn/tools/matrix.py +++ b/odak/learn/tools/matrix.py @@ -32,7 +32,7 @@ def zero_pad(field, size = None, method = 'center'): Parameters ---------- field : ndarray - Input field MxN or KxJxMxN array. + Input field MxN or KxJxMxN or KxMxNxJ array. size : list Size to be zeropadded (e.g., [m, n], last two dimensions only). method : str @@ -48,6 +48,10 @@ def zero_pad(field, size = None, method = 'center'): field = field.unsqueeze(0) if len(field.shape) < 4: field = field.unsqueeze(0) + permute_flag = False + if field.shape[-1] < 5: + permute_flag = True + field = field.permute(0, 3, 1, 2) if type(size) == type(None): resolution = [field.shape[0], field.shape[1], 2 * field.shape[-2], 2 * field.shape[-1]] else: @@ -69,6 +73,8 @@ def zero_pad(field, size = None, method = 'center'): 0: field.shape[-2], 0: field.shape[-1] ] = field + if permute_flag == True: + field = field.permute(0, 3, 1, 2) if len(orig_resolution) == 2: field_zero_padded = field_zero_padded.squeeze(0).squeeze(0) if len(orig_resolution) == 3: @@ -83,7 +89,7 @@ def crop_center(field, size = None): Parameters ---------- field : ndarray - Input field 2M x 2N or K x L x 2M x 2N array. + Input field 2M x 2N or K x L x 2M x 2N or K x 2M x 2N x L array. size : list Dimensions to crop with respect to center of the image (e.g., M x N or 1 x 1 x M x N). @@ -97,6 +103,10 @@ def crop_center(field, size = None): field = field.unsqueeze(0) if len(field.shape) < 4: field = field.unsqueeze(0) + permute_flag = False + if field.shape[-1] < 5: + permute_flag = True + field = field.permute(0, 3, 1, 2) if type(size) == type(None): qx = int(field.shape[-2] // 4) qy = int(field.shape[-1] // 4) @@ -108,6 +118,8 @@ def crop_center(field, size = None): hy = int(size[-1] // 2) cropped_padded = field[:, :, cx-hx:cx+hx, cy-hy:cy+hy] cropped = cropped_padded + if permute_flag == True: + field = field.permute(0, 3, 1, 2) if len(orig_resolution) == 2: cropped = cropped_padded.squeeze(0).squeeze(0) if len(orig_resolution) == 3: diff --git a/test/test_learn_lenses.py b/test/test_learn_gratings.py similarity index 100% rename from test/test_learn_lenses.py rename to test/test_learn_gratings.py diff --git a/test/test_learn_tools.py b/test/test_learn_tools.py new file mode 100644 index 00000000..2f97cd0a --- /dev/null +++ b/test/test_learn_tools.py @@ -0,0 +1,27 @@ +import sys +import odak +import torch + + +def test(): + image0 = torch.zeros(1, 3, 50, 50) + image0[:, :, ::2, :: 2] = 1. + image0_zero_padded = odak.learn.tools.zero_pad(image0) + odak.learn.tools.save_image('image0_padded.png', image0_zero_padded, cmin = 0., cmax = 1.) + + image1 = torch.zeros(1, 50, 50, 3) + image1[:, ::3, ::3, :] = 1. + image1_zero_padded = odak.learn.tools.zero_pad(image1) + odak.learn.tools.save_image('image1_padded.png', image1_zero_padded, cmin = 0., cmax = 1.) + + + image0_cropped = odak.learn.tools.crop_center(image0_zero_padded) + odak.learn.tools.save_image('image0_cropped.png', image0_cropped, cmin = 0., cmax = 1.) + image1_cropped = odak.learn.tools.crop_center(image1_zero_padded) + odak.learn.tools.save_image('image1_padded.png', image1_cropped, cmin = 0., cmax = 1.) + + assert True == True + + +if __name__ == '__main__': + sys.exit(test())