Skip to content

Commit

Permalink
able to decode codes back to continuous coordinates (9) for each face…
Browse files Browse the repository at this point in the history
…, then post-processed into a 3D asset
  • Loading branch information
lucidrains committed Dec 5, 2023
1 parent 27e045a commit b16923d
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 7 deletions.
12 changes: 11 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,10 @@ from meshgpt_pytorch import MeshAutoencoder
# autoencoder

autoencoder = MeshAutoencoder(
dim = 512
dim = 512,
encoder_depth = 6,
decoder_depth = 6,
num_discrete_coors = 128
)

# mock inputs
Expand Down Expand Up @@ -49,6 +52,13 @@ face_vertex_codes = autoencoder.tokenize(
)

# now train your transformer to generate this sequence of codes

# to decode back to continuous coordinates for each face (3 vertices, 9 coordinates)

# (batch, number of faces, vertex (3), coord (3))

face_seq_coords = autoencoder.decode_from_codes_to_faces(face_vertex_codes)

```

## Todo
Expand Down
34 changes: 29 additions & 5 deletions meshgpt_pytorch/meshgpt_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,11 +275,35 @@ def decode(
return rearrange(x, 'b d n -> b n d')

@beartype
@torch.no_grad()
def decode_from_codes_to_faces(
    self,
    codes: Tensor,
    return_discrete_codes = False
):
    """Decode quantized codebook indices back to continuous face coordinates.

    Inverse of `tokenize`: takes the residual-VQ indices for a mesh and
    reconstructs per-face vertex coordinates, for post-processing into a 3d asset.

    Args:
        codes: Tensor of codebook indices, flattened over (faces * vertices).
            assumes the layout produced by the autoencoder's quantizer —
            TODO confirm against `tokenize`.
        return_discrete_codes: if True, also return the discrete coordinate
            bin indices alongside the continuous coordinates.

    Returns:
        continuous_coors with shape (batch, num faces, vertex (3), coord (3)),
        or a tuple (continuous_coors, pred_face_coords) when
        `return_discrete_codes` is True.
    """
    # look up the codebook embeddings for the indices, then group the
    # 3 vertex embeddings belonging to each face into a single face vector
    quantized = self.quantizer.get_output_from_indices(codes)
    quantized = rearrange(quantized, 'b (nf nv) d -> b nf (nv d)', nv = 3)

    # project out of codebook space and run the decoder
    face_embed_output = self.project_codebook_out(quantized)
    decoded = self.decode(face_embed_output)

    # logits over discrete coordinate bins -> most likely bin per coordinate
    pred_face_coords = self.to_coor_logits(decoded)
    pred_face_coords = pred_face_coords.argmax(dim = -1)

    # separate the 9 coordinates per face into (vertex (3), coord (3))
    pred_face_coords = rearrange(pred_face_coords, '... (v c) -> ... v c', v = 3)

    # back to continuous space

    continuous_coors = undiscretize_coors(
        pred_face_coords,
        num_discrete = self.num_discrete_coors,
        continuous_range = self.coor_continuous_range
    )

    if not return_discrete_codes:
        return continuous_coors

    return continuous_coors, pred_face_coords

def tokenize(self, *args, **kwargs):
assert 'return_codes' not in kwargs
Expand Down Expand Up @@ -318,12 +342,12 @@ def forward(

decode = self.decode(quantized)

pred_coor_bins = self.to_coor_logits(decode)
pred_face_coords = self.to_coor_logits(decode)

# reconstruction loss on discretized coordinates on each face

recon_loss = F.cross_entropy(
rearrange(pred_coor_bins, 'b ... c -> b c ...'),
rearrange(pred_face_coords, 'b ... c -> b c ...'),
face_coordinates
)

Expand Down Expand Up @@ -384,7 +408,7 @@ def generate(self):

def forward(
self,
c
codes
):
seq_len, device = x.shape[-2], device
assert divisible_by(seq_len, self.num_quantizers) == 0
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'meshgpt-pytorch',
packages = find_packages(exclude=[]),
version = '0.0.1',
version = '0.0.2',
license='MIT',
description = 'MeshGPT Pytorch',
author = 'Phil Wang',
Expand Down

0 comments on commit b16923d

Please sign in to comment.