Skip to content

Commit

Permalink
update stdit dtype (#392)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhengzangw authored May 11, 2024
1 parent 09e53db commit ea41df3
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions opensora/models/stdit/stdit.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,10 +267,10 @@ def forward(self, x, timestep, y, mask=None, x_mask=None):
Returns:
x (torch.Tensor): output latent representation; of shape [B, C, T, H, W]
"""

x = x.to(self.dtype)
timestep = timestep.to(self.dtype)
y = y.to(self.dtype)
dtype = self.x_embedder.proj.weight.dtype
x = x.to(dtype)
timestep = timestep.to(dtype)
y = y.to(dtype)

# embedding
x = self.x_embedder(x) # [B, N, C]
Expand Down

0 comments on commit ea41df3

Please sign in to comment.