diff --git a/multiviewdata/torchdatasets/mnist.py b/multiviewdata/torchdatasets/mnist.py index 88a977d..9e7513b 100644 --- a/multiviewdata/torchdatasets/mnist.py +++ b/multiviewdata/torchdatasets/mnist.py @@ -35,8 +35,8 @@ def __len__(self): def __getitem__(self, idx): x_a, label = self.dataset[idx] - x_b = x_a[:, :, 14:]/255. - x_a = x_a[:, :, :14]/255. + x_b = x_a[:, :, 14:] + x_a = x_a[:, :, :14] if self.flatten: x_a = torch.flatten(x_a) x_b = torch.flatten(x_b)