Commit

autoencoder without variable lengths runs
lucidrains committed Dec 5, 2023
1 parent e667f83 commit bebb59c
Showing 2 changed files with 136 additions and 17 deletions.
46 changes: 46 additions & 0 deletions README.md
@@ -6,11 +6,57 @@ Implementation of <a href="https://arxiv.org/abs/2311.15475">MeshGPT</a>, SOTA Mesh generation using Attention, in Pytorch

Will also add text conditioning, for eventual text-to-3d assets

## Install

```bash
$ pip install meshgpt-pytorch
```

## Usage

```python
import torch
from meshgpt_pytorch import MeshAutoencoder

# autoencoder

autoencoder = MeshAutoencoder(
dim = 512
)

# mock inputs

vertices = torch.randn((2, 121, 3))
faces = torch.randint(0, 121, (2, 64, 3))
face_edges = torch.randint(0, 64, (2, 2, 96))

# forward in the faces

loss = autoencoder(
vertices = vertices,
faces = faces,
face_edges = face_edges
)

loss.backward()

# after much training...

face_vertex_codes = autoencoder.tokenize(
vertices = vertices,
faces = faces,
face_edges = face_edges
)

# now train your transformer to generate this sequence of codes
```
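At this commit the transformer that will consume these codes is not yet implemented. As a rough illustration of the next stage, here is a minimal, hypothetical sketch that continues the snippet above; `TinyCodeTransformer`, the flat code layout, and all hyperparameters are assumptions rather than this repo's API:

```python
import torch
from torch import nn
import torch.nn.functional as F

codebook_size = 16384  # matches the autoencoder default above

class TinyCodeTransformer(nn.Module):
    # a generic autoregressive transformer over the flattened code sequence
    def __init__(self, num_tokens, dim = 512, depth = 6, heads = 8):
        super().__init__()
        self.token_emb = nn.Embedding(num_tokens, dim)
        layer = nn.TransformerEncoderLayer(dim, heads, dim * 4, batch_first = True)
        self.transformer = nn.TransformerEncoder(layer, depth)
        self.to_logits = nn.Linear(dim, num_tokens)

    def forward(self, codes):
        codes = codes.reshape(codes.shape[0], -1)  # one flat sequence per mesh
        x = self.token_emb(codes)
        n = x.shape[1]
        causal_mask = torch.full((n, n), float('-inf')).triu(1)  # no peeking ahead
        x = self.transformer(x, mask = causal_mask)
        return self.to_logits(x)

model = TinyCodeTransformer(codebook_size)
logits = model(face_vertex_codes)

# shifted next-token prediction loss
targets = face_vertex_codes.reshape(face_vertex_codes.shape[0], -1)
loss = F.cross_entropy(logits[:, :-1].transpose(1, 2), targets[:, 1:])
loss.backward()
```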

## Todo

- [ ] autoencoder
    - [x] encoder sageconv with torch geometric
    - [x] proper scatter mean accounting for padding, for mean-pooling the vertices, and RVQ of the vertices before gathering back for the decoder
    - [x] complete decoder and reconstruction loss + commitment loss
    - [ ] xcit linear attention in both encoder / decoder
    - [ ] add option to use residual FSQ / LFQ, latest quantization development
    - [ ] handle variable-length faces last - use sink tokens when scattering
107 changes: 90 additions & 17 deletions meshgpt_pytorch/meshgpt_pytorch.py
@@ -29,6 +29,9 @@
def exists(v):
    return v is not None

def default(v, d):
    return v if exists(v) else d

# tensor helper functions

@beartype
@@ -63,6 +66,37 @@ def undiscretize_coors(
    t /= num_discrete
    return t * (hi - lo) + lo
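For reference, a sketch of the full coordinate round-trip these helpers implement, written to be consistent with the tail of `undiscretize_coors` visible above (exact signatures, defaults, and rounding details in the file are assumptions):

```python
import torch

def discretize_coors(t, *, continuous_range = (-1., 1.), num_discrete = 128):
    # map continuous coordinates in [lo, hi] to integer bins in [0, num_discrete)
    lo, hi = continuous_range
    t = (t - lo) / (hi - lo)
    t *= num_discrete
    return t.round().long().clamp(min = 0, max = num_discrete - 1)

def undiscretize_coors(t, *, continuous_range = (-1., 1.), num_discrete = 128):
    # inverse: bin indices back to (approximate) continuous coordinates
    lo, hi = continuous_range
    t = t.float()
    t /= num_discrete
    return t * (hi - lo) + lo
```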

# resnet block

class Block(Module):
    def __init__(self, dim, groups = 8):
        super().__init__()
        self.proj = nn.Conv1d(dim, dim, 3, padding = 1)
        self.norm = nn.GroupNorm(groups, dim)
        self.act = nn.SiLU()

    def forward(self, x):
        x = self.proj(x)
        x = self.norm(x)
        x = self.act(x)
        return x

class ResnetBlock(Module):
    def __init__(
        self,
        dim,
        *,
        groups = 8
    ):
        super().__init__()
        self.block1 = Block(dim, groups = groups)
        self.block2 = Block(dim, groups = groups)

    def forward(self, x):
        h = self.block1(x)
        h = self.block2(h)
        return h + x
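A quick shape sanity check for the new residual block, assuming the classes above are in scope. Note the channels-first layout, since `Block` wraps `nn.Conv1d`:

```python
import torch

block = ResnetBlock(512)           # dim must be divisible by groups (default 8)
x = torch.randn(2, 512, 64)        # (batch, dim, num_faces)
assert block(x).shape == x.shape   # the residual block preserves shape
```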

# main classes

class MeshAutoencoder(Module):
@@ -78,7 +112,8 @@ def __init__(
        dim_codebook = 192,
        num_quantizers = 2,       # or 'D' in the paper
        codebook_size = 16384,    # they use 16k, shared codebook between layers
-       rq_kwargs: dict = dict()
+       rq_kwargs: dict = dict(),
+       commit_loss_weight = 0.1,
    ):
        super().__init__()

@@ -105,17 +140,19 @@ def __init__(
            num_quantizers = num_quantizers,
            codebook_size = codebook_size,
            shared_codebook = True,
            commitment_weight = 1.,
            **rq_kwargs
        )

        self.commit_loss_weight = commit_loss_weight

        self.project_codebook_out = nn.Linear(dim_codebook * 3, dim)

        self.decoders = ModuleList([])

        for _ in range(decoder_depth):
-           sage_conv = SAGEConv(dim, dim)
-
-           self.decoders.append(sage_conv)
+           resnet_block = ResnetBlock(dim)
+           self.decoders.append(resnet_block)

        self.to_coor_logits = nn.Sequential(
            nn.Linear(dim, num_discrete_coors * 9),
@@ -128,7 +165,8 @@ def encode(
        *,
        vertices: TensorType['b', 'nv', 3, int],
        faces: TensorType['b', 'nf', 3, int],
-       face_edges: TensorType['b', 2, 'e', int]
+       face_edges: TensorType['b', 2, 'e', int],
+       return_face_coordinates = False
    ):
        """
        einops:
@@ -157,7 +195,7 @@
        batch_offset = batch_arange * num_faces
        batch_offset = rearrange(batch_offset, 'b -> b 1 1')

-       face_edges += batch_offset
+       face_edges = face_edges + batch_offset
        face_edges = rearrange(face_edges, 'b ij e -> ij (b e)')

        x = rearrange(face_embed, 'b nf d -> (b nf) d')
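What the batch offset above accomplishes, in isolation: edge indices are local to each mesh, so offsetting each batch element by `batch_index * num_faces` makes the indices unique in the flattened batch, and a single graph conv can then process the disjoint union of all meshes. A standalone toy demonstration (values illustrative):

```python
import torch
from einops import rearrange

# two meshes of 4 faces each; edge indices are local to each mesh
face_edges = torch.tensor([[[0, 1], [1, 2]],
                           [[0, 1], [1, 2]]])   # (b = 2, ij = 2, e = 2)

num_faces = 4
batch_offset = torch.arange(2) * num_faces       # tensor([0, 4])
batch_offset = rearrange(batch_offset, 'b -> b 1 1')

# offsetting makes indices unique across the flattened batch
face_edges = face_edges + batch_offset
face_edges = rearrange(face_edges, 'b ij e -> ij (b e)')

print(face_edges)
# tensor([[0, 1, 4, 5],
#         [1, 2, 5, 6]])
```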
@@ -167,7 +205,10 @@

        x = rearrange(x, '(b nf) d -> b nf d', b = batch)

-       return x
+       if not return_face_coordinates:
+           return x
+
+       return x, face_coords

    @beartype
    def quantize(
@@ -209,6 +250,8 @@ def quantize(
        face_embed_output = quantized.gather(-2, faces_with_dim)
        face_embed_output = rearrange(face_embed_output, 'b (nf nv) d -> b nf (nv d)', nv = 3)

        face_embed_output = self.project_codebook_out(face_embed_output)

        # vertex codes also need to be gathered to be organized by face sequence
        # for autoregressive learning
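A toy illustration of the gather described in this comment: per-vertex quantizer codes are reorganized into per-face order so the transformer can later predict them autoregressively. The `(nf nv) q` layout and shapes here are assumptions:

```python
import torch
from einops import rearrange, repeat

codes = torch.randint(0, 16384, (1, 5, 2))      # (batch, num_vertices, num_quantizers)
faces = torch.tensor([[[0, 1, 2], [2, 3, 4]]])  # (batch, num_faces, 3)

# index per-vertex codes by the vertex ids of each face
indices = repeat(faces, 'b nf nv -> b (nf nv) q', q = codes.shape[-1])
face_codes = codes.gather(1, indices)
face_codes = rearrange(face_codes, 'b (nf nv) q -> b nf (nv q)', nv = 3)

print(face_codes.shape)  # torch.Size([1, 2, 6]): 3 vertices x 2 quantizers per face
```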

@@ -220,51 +263,81 @@
    @beartype
    def decode(
        self,
-       codes
+       quantized: TensorType['b', 'n', 'd', float]
    ):
-       raise NotImplementedError
+       quantized = rearrange(quantized, 'b n d -> b d n')
+
+       x = quantized
+
+       for resnet_block in self.decoders:
+           x = resnet_block(x)
+
+       return rearrange(x, 'b d n -> b n d')

    @beartype
-   def decode_from_codes_to_vertices(
+   def decode_from_codes_to_faces(
        self,
        codes: Tensor
-   ) -> Tensor:
+   ):
        raise NotImplementedError

    def tokenize(self, *args, **kwargs):
        assert 'return_codes' not in kwargs
        return self.forward(*args, return_codes = True, **kwargs)

    @beartype
    def forward(
        self,
        *,
        vertices: Tensor,
        faces: Tensor,
        face_edges: Tensor,
-       return_quantized = False
+       return_codes = False,
+       return_loss_breakdown = False
    ):
        discretized_vertices = discretize_coors(
            vertices,
            num_discrete = self.num_discrete_coors,
            continuous_range = self.coor_continuous_range,
        )

-       encoded = self.encode(
+       encoded, face_coordinates = self.encode(
            vertices = discretized_vertices,
            faces = faces,
-           face_edges = face_edges
+           face_edges = face_edges,
+           return_face_coordinates = True
        )

        quantized, codes, commit_loss = self.quantize(
            face_embed = encoded,
            faces = faces
        )

-       if return_quantized:
-           return quantized
+       if return_codes:
+           return codes

        decode = self.decode(quantized)

        pred_coor_bins = self.to_coor_logits(decode)

-       return loss

        # reconstruction loss on discretized coordinates on each face

        recon_loss = F.cross_entropy(
            rearrange(pred_coor_bins, 'b ... c -> b c ...'),
            face_coordinates
        )

        # calculate total loss

        total_loss = recon_loss + \
            commit_loss.sum() * self.commit_loss_weight

        if not return_loss_breakdown:
            return total_loss

        loss_breakdown = (recon_loss, commit_loss)
        return total_loss, loss_breakdown
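Shape bookkeeping for the reconstruction loss above, as a standalone check. The shapes assume `to_coor_logits` yields a bin distribution for each of the 9 coordinates per face (3 vertices by x, y, z), based on the `num_discrete_coors * 9` projection earlier in the diff:

```python
import torch
import torch.nn.functional as F
from einops import rearrange

num_discrete = 128
pred_coor_bins = torch.randn(2, 64, 9, num_discrete)           # (b, nf, 9, bins)
face_coordinates = torch.randint(0, num_discrete, (2, 64, 9))  # discretized targets

# move the bin/class dim to position 1, as F.cross_entropy expects
loss = F.cross_entropy(rearrange(pred_coor_bins, 'b ... c -> b c ...'), face_coordinates)
```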

class MeshGPT(Module):
    @beartype
