diff --git a/research/rxrx1/data/dataset.py b/research/rxrx1/data/dataset.py index ca3f73baa..0d17521ad 100644 --- a/research/rxrx1/data/dataset.py +++ b/research/rxrx1/data/dataset.py @@ -44,8 +44,7 @@ def __getitem__(self, idx: int) -> tuple[torch.Tensor, int, int]: images.append(image) concatenated_image = torch.cat(images, dim=0) - print(concatenated_image.shape) - return concatenated_image, label, row["sirna_id"] + return concatenated_image, label def load_image(self, path: str) -> torch.Tensor: if not Path(path).exists():