Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Continuation of #858 for #839 Cifar10 revision #933

Closed
wants to merge 34 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
c9ecaae
remove torchvision dependency and write tests for cifar10
BaruchG Aug 8, 2022
46f224f
insert torchvision dependency and write tests for cifar10
BaruchG Aug 8, 2022
01d6f1e
removed print
BaruchG Aug 8, 2022
45c9a9d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 8, 2022
be1fa15
Merge remote-tracking branch 'upstream/master'
BaruchG Aug 19, 2022
86482ce
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 19, 2022
adc57b9
cleanup failed merge
Aug 23, 2022
a2082e3
Merge branch 'master' into BaruchG/master
Aug 23, 2022
f8ba0e6
merge master
Aug 25, 2022
7230608
Merge remote-tracking branch 'upstream/master'
BaruchG Sep 6, 2022
36d57f8
Merge remote-tracking branch 'upstream/master'
BaruchG Sep 9, 2022
88eb870
Merge remote-tracking branch 'upstream/master'
BaruchG Sep 16, 2022
b1995a7
Revert "insert torchvision dependency and write tests for cifar10"
BaruchG Sep 16, 2022
2533023
Merge remote-tracking branch 'upstream/master'
BaruchG Sep 20, 2022
57e99e5
revert cifar10
BaruchG Sep 20, 2022
22019af
revert cifar10
BaruchG Sep 20, 2022
486d9f8
Merge remote-tracking branch 'upstream/master'
BaruchG Sep 30, 2022
8b3133a
Merge remote-tracking branch 'upstream/master'
BaruchG Oct 27, 2022
8ab4343
Merge branch 'Lightning-AI:master' into master
BaruchG Oct 27, 2022
030b17c
Merge branch 'master' of https://github.com/BaruchG/lightning-bolts
BaruchG Oct 27, 2022
31eaed9
Merge branch 'master' of https://github.com/BaruchG/lightning-bolts
BaruchG Nov 2, 2022
55a05c3
fixed cifar10 import
BaruchG Nov 2, 2022
4539532
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 2, 2022
bdc4b6d
torchvision requirement
BaruchG Nov 2, 2022
9eaf93b
Merge remote-tracking branch 'upstream/master'
BaruchG Nov 3, 2022
42356c4
Merge branch 'master' into cifar
BaruchG Nov 3, 2022
bc7a314
changed arguments for cifar10 in datamodule tests to reflect torchvision
BaruchG Nov 3, 2022
c42472c
Merge branch 'cifar' of https://github.com/BaruchG/lightning-bolts in…
BaruchG Nov 3, 2022
cf66fc0
reformat docstring
BaruchG Nov 3, 2022
cb43285
Merge branch 'master' into cifar
BaruchG Nov 3, 2022
42e3a26
Merge branch 'master' into cifar
BaruchG Dec 14, 2022
97cc4f4
Update __init__.py
BaruchG Dec 14, 2022
300b139
Merge branch 'master' into HEAD
Borda Dec 15, 2022
a6fcde0
noqa fix
BaruchG Dec 15, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions =1.7.0
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
Collecting package metadata (current_repodata.json): ...working... done
Solving environment: ...working... failed with initial frozen solve. Retrying with flexible solve.
Collecting package metadata (repodata.json): ...working... done
Solving environment: ...working... failed with initial frozen solve. Retrying with flexible solve.
3 changes: 2 additions & 1 deletion pl_bolts/datasets/cifar10_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import torch
from torch import Tensor
from torchvision.datasets import CIFAR10

from pl_bolts.datasets import LightDataset
from pl_bolts.datasets.utils import safe_extract_tarfile
Expand All @@ -19,7 +20,7 @@


@under_review()
class CIFAR10(LightDataset):
class IndependentCIFAR10(LightDataset):
"""Customized `CIFAR10 <http://www.cs.toronto.edu/~kriz/cifar.html>`_ dataset for testing Pytorch Lightning
without the torchvision dependency.

Expand Down
2 changes: 1 addition & 1 deletion tests/datamodules/test_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@


def test_async_dataloader(datadir):
ds = CIFAR10(data_dir=datadir)
ds = CIFAR10(root=datadir)

if torch.cuda.device_count() > 0: # Can only run this test with a GPU
device = torch.device("cuda", 0)
Expand Down
2 changes: 1 addition & 1 deletion tests/datamodules/test_datamodules.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@


def test_dev_datasets(datadir):
ds = CIFAR10(data_dir=datadir)
ds = CIFAR10(root=datadir)
for _ in ds:
pass

Expand Down
20 changes: 20 additions & 0 deletions tests/datasets/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
RandomDictDataset,
RandomDictStringDataset,
)
from pl_bolts.datasets.cifar10_dataset import CIFAR10
from pl_bolts.datasets.dummy_dataset import DummyDetectionDataset
from pl_bolts.datasets.sr_mnist_dataset import SRMNIST
from pl_bolts.utils import _PIL_AVAILABLE
Expand Down Expand Up @@ -147,6 +148,25 @@ def test_sr_datasets(datadir, scale_factor):
assert torch.allclose(lr_image.max(), torch.tensor(1.0), atol=atol)


def test_cifar10_datasets(datadir):
transform = transform_lib.Compose(
[transform_lib.ToTensor(), transform_lib.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
)
dl = DataLoader(CIFAR10(root=datadir, download=True, transform=transform))
hr_image, lr_image = next(iter(dl))
print("==============================", lr_image.size())

hr_image_size = 32
assert hr_image.size() == torch.Size([1, 3, hr_image_size, hr_image_size])
assert lr_image.size() == torch.Size([1])

atol = 0.3
assert torch.allclose(hr_image.min(), torch.tensor(-1.0), atol=atol)
assert torch.allclose(hr_image.max(), torch.tensor(1.0), atol=atol)
assert torch.greater_equal(lr_image.min(), torch.tensor(0))
assert torch.less_equal(lr_image.max(), torch.tensor(9))


def test_binary_mnist_dataset(datadir):
"""Check BinaryMNIST image and target dimensions and value range."""
dl = DataLoader(BinaryMNIST(root=datadir, download=True, transform=transform_lib.ToTensor()))
Expand Down