Possible issue with compute_contra_memobank_loss #171

Open
lucasdavid opened this issue Jul 25, 2024 · 0 comments

@lucasdavid

Hello! Thank you for the great paper and code. It's been really helpful to me!
I believe there might be an issue with the function compute_contra_memobank_loss, and I'd appreciate it if you could clarify it for me.
The article states that:

"For i-th labeled image, a qualified negative sample for class c should be: (a) not belonging to class c; (b) difficult to distinguish between class c and its ground-truth category."

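Read literally, the quoted criterion is a per-pixel conjunction of two tests. The sketch below is a minimal illustration, not the repository's code. It assumes "difficult to distinguish" means that class c ranks among the top `low_rank` predicted classes (the same test `class_mask_l` applies in compute_contra_memobank_loss), and it takes an integer label map instead of the one-hot `label_l` the code uses. The function name `labeled_negative_mask` is hypothetical.

```python
import torch


def labeled_negative_mask(label_l: torch.Tensor, prob_l: torch.Tensor, c: int, low_rank: int) -> torch.Tensor:
    """Hypothetical helper: boolean (N, H, W) mask of labeled pixels that qualify as negatives for class c.

    label_l: (N, H, W) integer ground-truth labels (assumed layout, not one-hot).
    prob_l:  (N, C, H, W) predicted class probabilities.
    """
    # (a) the pixel does not belong to class c
    not_c = label_l != c
    # (b) class c is "hard" for this pixel: it is among the top `low_rank` predicted classes
    ranked = torch.argsort(prob_l, dim=1, descending=True)  # (N, C, H, W) class indices
    hard = (ranked[:, :low_rank] == c).any(dim=1)           # (N, H, W)
    return not_c & hard
```

In the excerpt's terms, (a) corresponds to `label_l[:, i] == 0` and (b) to `class_mask_l`; neither condition requires `label_l[:, i] == 1`.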
However, compute_contra_memobank_loss appears to do something different:

```python
# Excerpt from compute_contra_memobank_loss (surrounding lines omitted)
high_valid_pixel = torch.cat((label_l, label_u), dim=0) * high_mask

for i in range(num_segments):
    high_valid_pixel_seg = high_valid_pixel[:, i]
    rep_mask_high_entropy = (prob_seg < 1.0) * high_valid_pixel_seg.bool()

    # unlabeled pixels: class i ranked in [low_rank, high_rank)
    class_mask_u = torch.sum(
        prob_indices_u[:, :, :, low_rank:high_rank].eq(i), dim=3
    ).bool()
    # labeled pixels: class i ranked among the top low_rank predictions
    class_mask_l = torch.sum(prob_indices_l[:, :, :, :low_rank].eq(i), dim=3).bool()

    class_mask = torch.cat((class_mask_l * (label_l[:, i] == 0), class_mask_u), dim=0)
    negative_mask = rep_mask_high_entropy * class_mask
    keys = rep_teacher[negative_mask].detach()
    new_keys.append(dequeue_and_enqueue(keys=keys, ...))
```

For the labeled samples, negative_mask is formed by the conjunction label_l[:, i] == 0 (from class_mask) and label_l[:, i] == 1 (from high_valid_pixel), so it will always be False:

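```
negative_mask_l = rep_mask_high_entropy[:NL] * (class_mask_l * (label_l[:, i] == 0))
                = (label_l * high_mask[:NL])[:, i] * (label_l[:, i] == 0) * class_mask_l
                = 0
```

Here NL is the number of labeled images, and the (prob_seg < 1.0) factor of rep_mask_high_entropy is dropped because it cannot turn a zero product into a nonzero one. Since label_l is one-hot, (label_l * high_mask[:NL])[:, i] is nonzero only where label_l[:, i] == 1, which the factor (label_l[:, i] == 0) rules out everywhere.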
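To check this numerically, here is a toy reproduction of the labeled branch on random tensors. It copies the mask expressions from the excerpt; the tensor sizes, the single-channel all-ones `high_mask_l`, and building `prob_indices_l` as a descending argsort moved to the last dimension are assumptions for illustration, not taken from the repository.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Toy sizes (hypothetical): NL labeled images, C classes, H x W feature map.
NL, C, H, W = 2, 4, 8, 8
low_rank = 3

# One-hot ground truth, shaped (NL, C, H, W) like label_l in the excerpt.
gt = torch.randint(0, C, (NL, H, W))
label_l = F.one_hot(gt, C).permute(0, 3, 1, 2).float()

# Hypothetical single-channel mask; all ones is the most permissive case.
high_mask_l = torch.ones(NL, 1, H, W)
high_valid_pixel_l = label_l * high_mask_l  # labeled slice of high_valid_pixel

# Random predictions; prob_indices_l lists classes by descending probability (last dim).
prob_l = torch.softmax(torch.randn(NL, C, H, W), dim=1)
prob_indices_l = torch.argsort(prob_l, dim=1, descending=True).permute(0, 2, 3, 1)


def class_mask_l(i: int) -> torch.Tensor:
    # True where class i is among the top `low_rank` predictions.
    return torch.sum(prob_indices_l[:, :, :, :low_rank].eq(i), dim=3).bool()


def negative_mask_l(i: int) -> torch.Tensor:
    # Labeled half of negative_mask, following the excerpt line by line.
    rep_mask_high_entropy = (prob_l[:, i] < 1.0) * high_valid_pixel_l[:, i].bool()
    return rep_mask_high_entropy * (class_mask_l(i) * (label_l[:, i] == 0))


as_written = [int(negative_mask_l(i).sum()) for i in range(C)]
# Same rank and "not class i" conditions, without the high_valid_pixel factor.
without_high_valid = [int((class_mask_l(i) * (label_l[:, i] == 0)).sum()) for i in range(C)]

print(as_written)          # -> [0, 0, 0, 0] for any seed
print(without_high_valid)  # counts of candidate negatives per class
```

The first count is zero by construction for any seed, while dropping the high_valid_pixel factor leaves candidate negatives for every class, which points at that factor rather than at the rank test.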

For unlabeled samples, negative_mask is True where both high_valid_pixel and class_mask_u hold: the teacher says the pixel's label is i with high entropy (high_valid_pixel), and the student says it isn't, with high entropy (prob_indices_u[..., low_rank:high_rank].eq(i)).

Therefore, I believe all labeled representations are being discarded when collecting negative keys. Does that make sense?

Cheers,
