Skip to content

Commit

Permalink
Updating two utility functions to support both 3xMxN and MxNx3: odak.…
Browse files Browse the repository at this point in the history
…learn.tools.zero_pad and odak.learn.tools.crop_center.
  • Loading branch information
kaanaksit committed Nov 30, 2023
1 parent db0f4bb commit f4e8deb
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 2 deletions.
16 changes: 14 additions & 2 deletions odak/learn/tools/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def zero_pad(field, size = None, method = 'center'):
Parameters
----------
field : ndarray
Input field MxN or KxJxMxN array.
Input field MxN or KxJxMxN or KxMxNxJ array.
size : list
Size to be zeropadded (e.g., [m, n], last two dimensions only).
method : str
Expand All @@ -48,6 +48,10 @@ def zero_pad(field, size = None, method = 'center'):
field = field.unsqueeze(0)
if len(field.shape) < 4:
field = field.unsqueeze(0)
permute_flag = False
if field.shape[-1] < 5:
permute_flag = True
field = field.permute(0, 3, 1, 2)
if type(size) == type(None):
resolution = [field.shape[0], field.shape[1], 2 * field.shape[-2], 2 * field.shape[-1]]
else:
Expand All @@ -69,6 +73,8 @@ def zero_pad(field, size = None, method = 'center'):
0: field.shape[-2],
0: field.shape[-1]
] = field
if permute_flag == True:
field = field.permute(0, 3, 1, 2)
if len(orig_resolution) == 2:
field_zero_padded = field_zero_padded.squeeze(0).squeeze(0)
if len(orig_resolution) == 3:
Expand All @@ -83,7 +89,7 @@ def crop_center(field, size = None):
Parameters
----------
field : ndarray
Input field 2M x 2N or K x L x 2M x 2N array.
Input field 2M x 2N or K x L x 2M x 2N or K x 2M x 2N x L array.
size : list
Dimensions to crop with respect to center of the image (e.g., M x N or 1 x 1 x M x N).
Expand All @@ -97,6 +103,10 @@ def crop_center(field, size = None):
field = field.unsqueeze(0)
if len(field.shape) < 4:
field = field.unsqueeze(0)
permute_flag = False
if field.shape[-1] < 5:
permute_flag = True
field = field.permute(0, 3, 1, 2)
if type(size) == type(None):
qx = int(field.shape[-2] // 4)
qy = int(field.shape[-1] // 4)
Expand All @@ -108,6 +118,8 @@ def crop_center(field, size = None):
hy = int(size[-1] // 2)
cropped_padded = field[:, :, cx-hx:cx+hx, cy-hy:cy+hy]
cropped = cropped_padded
if permute_flag == True:
field = field.permute(0, 3, 1, 2)
if len(orig_resolution) == 2:
cropped = cropped_padded.squeeze(0).squeeze(0)
if len(orig_resolution) == 3:
Expand Down
File renamed without changes.
27 changes: 27 additions & 0 deletions test/test_learn_tools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import sys
import odak
import torch


def test():
image0 = torch.zeros(1, 3, 50, 50)
image0[:, :, ::2, :: 2] = 1.
image0_zero_padded = odak.learn.tools.zero_pad(image0)
odak.learn.tools.save_image('image0_padded.png', image0_zero_padded, cmin = 0., cmax = 1.)

image1 = torch.zeros(1, 50, 50, 3)
image1[:, ::3, ::3, :] = 1.
image1_zero_padded = odak.learn.tools.zero_pad(image1)
odak.learn.tools.save_image('image1_padded.png', image1_zero_padded, cmin = 0., cmax = 1.)


image0_cropped = odak.learn.tools.crop_center(image0_zero_padded)
odak.learn.tools.save_image('image0_cropped.png', image0_cropped, cmin = 0., cmax = 1.)
image1_cropped = odak.learn.tools.crop_center(image1_zero_padded)
odak.learn.tools.save_image('image1_padded.png', image1_cropped, cmin = 0., cmax = 1.)

assert True == True


if __name__ == '__main__':
sys.exit(test())

0 comments on commit f4e8deb

Please sign in to comment.