From 57f58130d654c279d88a8ca1e3f95572e79dd3db Mon Sep 17 00:00:00 2001 From: Doga Yilmaz Date: Tue, 19 Nov 2024 10:28:03 +0000 Subject: [PATCH] Fixed a minor issue with input dims. of perceptual losses. --- odak/learn/perception/learned_perceptual_losses.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/odak/learn/perception/learned_perceptual_losses.py b/odak/learn/perception/learned_perceptual_losses.py index 08def133..219897ba 100644 --- a/odak/learn/perception/learned_perceptual_losses.py +++ b/odak/learn/perception/learned_perceptual_losses.py @@ -39,6 +39,9 @@ def forward(self, predictions, targets, dim_order = 'CHW'): The computed loss if successful, otherwise 0.0. """ try: + if len(predictions.shape) == 3: + predictions = predictions.unsqueeze(0) + targets = targets.unsqueeze(0) l_ColorVideoVDP = self.cvvdp.loss(predictions, targets, dim_order = dim_order) return l_ColorVideoVDP except Exception as e: @@ -82,6 +85,9 @@ def forward(self, predictions, targets, dim_order = 'CHW'): The computed loss if successful, otherwise 0.0. """ try: + if len(predictions.shape) == 3: + predictions = predictions.unsqueeze(0) + targets = targets.unsqueeze(0) l_FovVideoVDP = self.fvvdp.predict(predictions, targets, dim_order = dim_order)[0] return l_FovVideoVDP except Exception as e: @@ -121,6 +127,9 @@ def forward(self, predictions, targets): The computed loss if successful, otherwise 0.0. """ try: + if len(predictions.shape) == 3: + predictions = predictions.unsqueeze(0) + targets = targets.unsqueeze(0) lpips_image = predictions lpips_target = targets if len(lpips_image.shape) == 3: