diff --git a/modules.py b/modules.py index 32f2764..bf8ab8a 100644 --- a/modules.py +++ b/modules.py @@ -80,7 +80,7 @@ def forward(self, z_e_x): def straight_through(self, z_e_x): z_e_x_ = z_e_x.permute(0, 2, 3, 1).contiguous() - z_q_x_, indices = vq_st(z_e_x_, self.embedding.weight) + z_q_x_, indices = vq_st(z_e_x_, self.embedding.weight.detach()) z_q_x = z_q_x_.permute(0, 3, 1, 2).contiguous() z_q_x_bar_flatten = torch.index_select(self.embedding.weight,