Wrong total variation calculation #328

Open
Dobatymo opened this issue Jan 10, 2023 · 2 comments
Dobatymo commented Jan 10, 2023

The total variation (l2 version) is calculated here as sqrt(sum(d_w**2 + d_h**2)). Shouldn't it be sum(sqrt(d_w**2 + d_h**2)) instead? See

piq/piq/tv.py, lines 34 to 37 at 26d044e:

```python
elif norm_type == 'l2':
    w_variance = torch.sum(torch.pow(x[:, :, :, 1:] - x[:, :, :, :-1], 2), dim=[1, 2, 3])
    h_variance = torch.sum(torch.pow(x[:, :, 1:, :] - x[:, :, :-1, :], 2), dim=[1, 2, 3])
    score = torch.sqrt(h_variance + w_variance)
```

Now the problem is how to vectorize this correctly...
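To see why the order of the sum and the square root matters, here is a plain-number sketch (no PyTorch needed; the values are made up for illustration). Summing inside the square root collapses everything into a single global norm, while summing outside accumulates per-pixel gradient magnitudes:

```python
import math

# Hypothetical per-pixel (d_w, d_h) gradient pairs for a tiny image
grads = [(1.0, 2.0), (3.0, 4.0)]

# sqrt(sum(...)): one global norm over all squared differences (current behavior)
sqrt_of_sum = math.sqrt(sum(dw**2 + dh**2 for dw, dh in grads))

# sum(sqrt(...)): per-pixel gradient magnitudes, then summed (isotropic TV)
sum_of_sqrt = sum(math.sqrt(dw**2 + dh**2) for dw, dh in grads)

print(sqrt_of_sum)  # sqrt(1 + 4 + 9 + 16) = sqrt(30) ≈ 5.477
print(sum_of_sqrt)  # sqrt(5) + sqrt(25) ≈ 7.236
```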

zakajd commented Jan 10, 2023

Hi @Dobatymo
I believe both variants are equally common in the literature. The Wikipedia article has the summation outside, while other sources (see image below) put it inside the square root. The exact formula is included in the docs, so users can decide whether it fits their use case.

[image: TV formula from another source, with the summation inside the square root]

Feel free to close the issue if this answers your question!


Dobatymo commented Jan 10, 2023

Hi @zakajd, sorry, I missed the formula in the docs. However, both Wikipedia and the two references in the docs have the sum outside. I am not aware of any formulation that has the sum inside; I only know the isotropic and anisotropic formulations, and both have the sum outside (though the placement only matters for the isotropic version). They differ only in the per-pixel norm being summed.

EDIT: I would suggest

```python
# Crop both differences to the common (H-1, W-1) grid of pixels
d_w = x[:, :, :-1, 1:] - x[:, :, :-1, :-1]  # horizontal differences
d_h = x[:, :, 1:, :-1] - x[:, :, :-1, :-1]  # vertical differences
# Per-pixel gradient magnitude first, then the sum (isotropic TV)
score = torch.sum(torch.sqrt(torch.square(d_w) + torch.square(d_h)), dim=(1, 2, 3))
```

For l2_squared, it doesn't matter either.
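A minimal runnable sketch (function names are mine; assumes PyTorch is installed) comparing the current l2 reduction against the suggested isotropic formulation on a toy 1×1×2×2 image, showing that the two orderings produce different scores:

```python
import torch

def tv_l2_current(x):
    # Current piq behavior: sum all squared differences, then take a single sqrt
    w_var = torch.sum((x[:, :, :, 1:] - x[:, :, :, :-1]) ** 2, dim=[1, 2, 3])
    h_var = torch.sum((x[:, :, 1:, :] - x[:, :, :-1, :]) ** 2, dim=[1, 2, 3])
    return torch.sqrt(h_var + w_var)

def tv_isotropic(x):
    # Suggested fix: per-pixel gradient magnitude first, then the sum
    d_w = x[:, :, :-1, 1:] - x[:, :, :-1, :-1]
    d_h = x[:, :, 1:, :-1] - x[:, :, :-1, :-1]
    return torch.sum(torch.sqrt(d_w ** 2 + d_h ** 2), dim=(1, 2, 3))

x = torch.tensor([[[[0.0, 1.0],
                    [2.0, 3.0]]]])
print(tv_l2_current(x))  # sqrt(1 + 1 + 4 + 4) = sqrt(10) ≈ 3.162
print(tv_isotropic(x))   # sqrt(1 + 4) = sqrt(5) ≈ 2.236
```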

@zakajd zakajd self-assigned this Feb 8, 2023