diff --git a/README.md b/README.md index c066b5a8..24f9f6ff 100644 --- a/README.md +++ b/README.md @@ -20,7 +20,11 @@ $ pip install meshgpt-pytorch ```python import torch -from meshgpt_pytorch import MeshAutoencoder + +from meshgpt_pytorch import ( + MeshAutoencoder, + MeshTransformer +) # autoencoder @@ -37,12 +41,17 @@ vertices = torch.randn((2, 121, 3)) faces = torch.randint(0, 121, (2, 64, 3)) face_edges = torch.randint(0, 64, (2, 2, 96)) +face_len = torch.randint(1, 64, (2,)) +face_edges_len = torch.randint(1, 96, (2,)) + # forward in the faces loss = autoencoder( vertices = vertices, faces = faces, - face_edges = face_edges + face_edges = face_edges, + face_len = face_len, + face_edges_len = face_edges_len ) loss.backward() @@ -52,14 +61,17 @@ loss.backward() face_vertex_codes = autoencoder.tokenize( vertices = vertices, faces = faces, - face_edges = face_edges + face_edges = face_edges, + face_len = face_len, + face_edges_len = face_edges_len ) # now train your transformer to generate this sequence of codes transformer = MeshTransformer( autoencoder, - dim = 512 + dim = 512, + max_seq_len = 768 ) loss = transformer(face_vertex_codes) @@ -71,7 +83,6 @@ faces_coordinates = transformer.generate() # (batch, num faces, vertices (3), coordinates (3)) # now post process for the generated 3d asset - ``` ## Todo diff --git a/meshgpt_pytorch/meshgpt_pytorch.py b/meshgpt_pytorch/meshgpt_pytorch.py index 0b8422a2..a5567e00 100644 --- a/meshgpt_pytorch/meshgpt_pytorch.py +++ b/meshgpt_pytorch/meshgpt_pytorch.py @@ -31,6 +31,8 @@ from torch_geometric.nn.conv import SAGEConv +from tqdm import tqdm + # helper functions def exists(v): @@ -308,7 +310,7 @@ def quantize( vertices = torch.zeros((batch, num_vertices, vertex_dim), device = device) - # create pad vertex, due to variable lenghted faces + # create pad vertex, due to variable lengthed faces pad_vertex_id = num_vertices vertices = F.pad(vertices, (0, 0, 0, 1), value = 0.) @@ -579,7 +581,7 @@ def generate( curr_length = codes.shape[-1] - for i in range(curr_length, self.max_seq_len): + for i in tqdm(range(curr_length, self.max_seq_len)): can_eos = divisible_by(i + 1, self.num_quantizers * 3) # only allow for eos to be decoded at the end of each face, defined as 3 vertices with D residusl VQ codes logits = self.forward(codes, return_loss = False, append_eos = False) diff --git a/setup.py b/setup.py index 91b4afca..28ae270b 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name = 'meshgpt-pytorch', packages = find_packages(exclude=[]), - version = '0.0.12', + version = '0.0.14', license='MIT', description = 'MeshGPT Pytorch', author = 'Phil Wang', @@ -27,7 +27,8 @@ 'torch_geometric', 'torchtyping', 'vector-quantize-pytorch>=1.11.8', - 'x-transformers>=1.26.0' + 'x-transformers>=1.26.0', + 'tqdm' ], classifiers=[ 'Development Status :: 4 - Beta',