Skip to content

Commit

Permalink
Fixed a minor issue with input dims. of perceptual losses.
Browse files Browse the repository at this point in the history
  • Loading branch information
yilmazdoga committed Nov 19, 2024
1 parent 6b78698 commit 57f5813
Showing 1 changed file with 9 additions and 0 deletions.
9 changes: 9 additions & 0 deletions odak/learn/perception/learned_perceptual_losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ def forward(self, predictions, targets, dim_order = 'CHW'):
The computed loss if successful, otherwise 0.0.
"""
try:
if len(predictions.shape) == 3:
predictions = predictions.unsqueeze(0)
targets = targets.unsqueeze(0)
l_ColorVideoVDP = self.cvvdp.loss(predictions, targets, dim_order = dim_order)
return l_ColorVideoVDP
except Exception as e:
Expand Down Expand Up @@ -82,6 +85,9 @@ def forward(self, predictions, targets, dim_order = 'CHW'):
The computed loss if successful, otherwise 0.0.
"""
try:
if len(predictions.shape) == 3:
predictions = predictions.unsqueeze(0)
targets = targets.unsqueeze(0)
l_FovVideoVDP = self.fvvdp.predict(predictions, targets, dim_order = dim_order)[0]
return l_FovVideoVDP
except Exception as e:
Expand Down Expand Up @@ -121,6 +127,9 @@ def forward(self, predictions, targets):
The computed loss if successful, otherwise 0.0.
"""
try:
if len(predictions.shape) == 3:
predictions = predictions.unsqueeze(0)
targets = targets.unsqueeze(0)
lpips_image = predictions
lpips_target = targets
if len(lpips_image.shape) == 3:
Expand Down

0 comments on commit 57f5813

Please sign in to comment.