Skip to content

Commit

Permalink
make sure readme runs
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Dec 5, 2023
1 parent 9b07af0 commit 8697673
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 9 deletions.
21 changes: 16 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,11 @@ $ pip install meshgpt-pytorch

```python
import torch
from meshgpt_pytorch import MeshAutoencoder

from meshgpt_pytorch import (
MeshAutoencoder,
MeshTransformer
)

# autoencoder

Expand All @@ -37,12 +41,17 @@ vertices = torch.randn((2, 121, 3))
faces = torch.randint(0, 121, (2, 64, 3))
face_edges = torch.randint(0, 64, (2, 2, 96))

face_len = torch.randint(1, 64, (2,))
face_edges_len = torch.randint(1, 96, (2,))

# forward in the faces

loss = autoencoder(
vertices = vertices,
faces = faces,
face_edges = face_edges
face_edges = face_edges,
face_len = face_len,
face_edges_len = face_edges_len
)

loss.backward()
Expand All @@ -52,14 +61,17 @@ loss.backward()
face_vertex_codes = autoencoder.tokenize(
vertices = vertices,
faces = faces,
face_edges = face_edges
face_edges = face_edges,
face_len = face_len,
face_edges_len = face_edges_len
)

# now train your transformer to generate this sequence of codes

transformer = MeshTransformer(
autoencoder,
dim = 512
dim = 512,
max_seq_len = 768
)

loss = transformer(face_vertex_codes)
Expand All @@ -71,7 +83,6 @@ faces_coordinates = transformer.generate()

# (batch, num faces, vertices (3), coordinates (3))
# now post process for the generated 3d asset

```

## Todo
Expand Down
6 changes: 4 additions & 2 deletions meshgpt_pytorch/meshgpt_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@

from torch_geometric.nn.conv import SAGEConv

from tqdm import tqdm

# helper functions

def exists(v):
Expand Down Expand Up @@ -308,7 +310,7 @@ def quantize(

vertices = torch.zeros((batch, num_vertices, vertex_dim), device = device)

# create pad vertex, due to variable lenghted faces
# create pad vertex, due to variable lengthed faces

pad_vertex_id = num_vertices
vertices = F.pad(vertices, (0, 0, 0, 1), value = 0.)
Expand Down Expand Up @@ -579,7 +581,7 @@ def generate(

curr_length = codes.shape[-1]

for i in range(curr_length, self.max_seq_len):
for i in tqdm(range(curr_length, self.max_seq_len)):
can_eos = divisible_by(i + 1, self.num_quantizers * 3) # only allow for eos to be decoded at the end of each face, defined as 3 vertices with D residusl VQ codes

logits = self.forward(codes, return_loss = False, append_eos = False)
Expand Down
5 changes: 3 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'meshgpt-pytorch',
packages = find_packages(exclude=[]),
version = '0.0.12',
version = '0.0.14',
license='MIT',
description = 'MeshGPT Pytorch',
author = 'Phil Wang',
Expand All @@ -27,7 +27,8 @@
'torch_geometric',
'torchtyping',
'vector-quantize-pytorch>=1.11.8',
'x-transformers>=1.26.0'
'x-transformers>=1.26.0',
'tqdm'
],
classifiers=[
'Development Status :: 4 - Beta',
Expand Down

0 comments on commit 8697673

Please sign in to comment.