Skip to content

Commit

Permalink
recon loss must account for variable lengthed faces
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Dec 6, 2023
1 parent ddb3a51 commit dcf46a9
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 2 deletions.
8 changes: 7 additions & 1 deletion meshgpt_pytorch/meshgpt_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,14 +401,17 @@ def decode_from_codes_to_faces(
quantized = self.quantizer.get_output_from_indices(codes)
quantized = rearrange(quantized, 'b (nf nv) d -> b nf (nv d)', nv = 3)

quantized = quantized.masked_fill(~face_mask[..., None], 0.)
face_embed_output = self.project_codebook_out(quantized)

decoded = self.decode(
face_embed_output,
face_mask = face_mask
)

decoded = decoded.masked_fill(~face_mask[..., None], 0.)
pred_face_coords = self.to_coor_logits(decoded)

pred_face_coords = pred_face_coords.argmax(dim = -1)

pred_face_coords = rearrange(pred_face_coords, '... (v c) -> ... v c', v = 3)
Expand Down Expand Up @@ -497,7 +500,10 @@ def forward(

# cross entropy with localized smoothing

recon_loss = (-target_one_hot * pred_log_prob).sum(dim = 1).mean()
recon_losses = (-target_one_hot * pred_log_prob).sum(dim = 1)

face_mask = repeat(face_mask, 'b nf -> b (nf r)', r = 9)
recon_loss = recon_losses[face_mask].mean()

# calculate total loss

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'meshgpt-pytorch',
packages = find_packages(exclude=[]),
version = '0.0.17',
version = '0.0.18',
license='MIT',
description = 'MeshGPT Pytorch',
author = 'Phil Wang',
Expand Down

0 comments on commit dcf46a9

Please sign in to comment.