All S3 all the time
adamjstewart committed Sep 24, 2024
1 parent 6de00fc commit 0103711
Showing 2 changed files with 11 additions and 15 deletions.
tests/datasets/test_satlas.py: 5 changes (4 additions & 1 deletion)
@@ -12,11 +12,14 @@
 from torch import Tensor

 from torchgeo.datasets import DatasetNotFoundError, SatlasPretrain
+from torchgeo.datasets.utils import Executable


 class TestSatlasPretrain:
     @pytest.fixture
-    def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> SatlasPretrain:
+    def dataset(
+        self, aws: Executable, monkeypatch: MonkeyPatch, tmp_path: Path
+    ) -> SatlasPretrain:
         root = os.path.join('tests', 'data', 'satlas')
         images = ('landsat', 'naip', 'sentinel1', 'sentinel2')
         products = (*images, 'static', 'metadata')
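
The fixture now requests an aws executable alongside monkeypatch and tmp_path, so the test never shells out to the real AWS CLI. A minimal sketch of the kind of fixture this implies is below; it is an assumption about the test wiring, not torchgeo's actual conftest.py, and the FakeAWS name is hypothetical. The stand-in only mimics the one call the dataset makes, aws('s3', 'cp', url, root), by copying a pre-built test tarball.

import os
import shutil

import pytest


class FakeAWS:
    # Hypothetical stand-in for torchgeo.datasets.utils.Executable: a callable
    # that receives CLI arguments instead of running the real `aws` binary.
    def __init__(self, fixture_dir: str) -> None:
        self.fixture_dir = fixture_dir

    def __call__(self, *args: str) -> None:
        # The dataset only ever calls aws('s3', 'cp', url, root); mimic that
        # by copying a local test tarball into the requested root.
        if args[:2] == ('s3', 'cp'):
            url, root = args[2], args[3]
            shutil.copy(os.path.join(self.fixture_dir, os.path.basename(url)), root)


@pytest.fixture
def aws() -> FakeAWS:
    # Assumed location of pre-built test tarballs; illustrative only.
    return FakeAWS(os.path.join('tests', 'data', 'satlas'))
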
torchgeo/datasets/satlas.py: 21 changes (7 additions & 14 deletions)
@@ -19,7 +19,7 @@

 from .errors import DatasetNotFoundError
 from .geo import NonGeoDataset
-from .utils import Path, check_integrity, download_url, extract_archive, which
+from .utils import Path, check_integrity, extract_archive, which


 class _Task(TypedDict, total=False):
@@ -477,9 +477,7 @@ class SatlasPretrain(NonGeoDataset):

     # https://github.com/allenai/satlas/blob/main/satlaspretrain_urls.txt
     urls: ClassVar[dict[str, tuple[str, ...]]] = {
-        'landsat': (
-            'https://pub-956f3eb0f5974f37b9228e0a62f449bf.r2.dev/satlaspretrain/satlas-dataset-v1-landsat.tar',
-        ),
+        'landsat': ('s3://ai2-public-datasets/satlas/satlas-dataset-v1-landsat.tar',),
         'naip': (
             's3://ai2-public-datasets/satlas/satlas-dataset-v1-naip-2011.tar',
             's3://ai2-public-datasets/satlas/satlas-dataset-v1-naip-2012.tar',
@@ -493,7 +491,7 @@
             's3://ai2-public-datasets/satlas/satlas-dataset-v1-naip-2020.tar',
         ),
         'sentinel1': (
-            'https://pub-956f3eb0f5974f37b9228e0a62f449bf.r2.dev/satlaspretrain/satlas-dataset-v1-sentinel1.tar',
+            's3://ai2-public-datasets/satlas/satlas-dataset-v1-sentinel1-new.tar',
         ),
         'sentinel2': (
             's3://ai2-public-datasets/satlas/satlas-dataset-v1-sentinel2-a.tar',
@@ -505,9 +503,7 @@
         'dynamic': (
             's3://ai2-public-datasets/satlas/satlas-dataset-v1-labels-dynamic.tar',
         ),
-        'metadata': (
-            's3://ai2-public-datasets/satlas/satlas-dataset-v1-metadata.tar',
-        ),
+        'metadata': ('s3://ai2-public-datasets/satlas/satlas-dataset-v1-metadata.tar',),
     }
     # TODO
     md5s: ClassVar[dict[str, tuple[str, ...]]] = {
@@ -524,7 +520,7 @@
             '55b110cc6f734bf88793306d49f1c415',
             '97fc8414334987c59593d574f112a77e',
         ),
-        'sentinel1': ('b0edc6b7af5995b04b8d780eec1246bf',),
+        'sentinel1': ('3d88a0a10df6ab0aa50db2ba4c475048',),
         'sentinel2': (
             '7e1c6a1e322807fb11df8c0c062545ca',
             '6636b8ecf2fff1d6723ecfef55a4876d',
@@ -687,11 +683,8 @@ def _verify(self) -> None:
                     raise DatasetNotFoundError(self)

                 # Download and extract the tarball
-                if url.startswith('s3://'):
-                    aws = which('aws')
-                    aws('s3', 'cp', url, self.root)
-                else:
-                    download_url(url, self.root)
+                aws = which('aws')
+                aws('s3', 'cp', url, self.root)
                 check_integrity(tarball, md5 if self.checksum else None)
                 extract_archive(tarball)

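Since every entry in urls is now an s3:// URL, _verify no longer needs the HTTPS branch or download_url: it always routes the download through the AWS CLI via which. A rough sketch of the resulting path for a single product is below; the root directory is illustrative, and the URL is one entry from SatlasPretrain.urls.

import os

from torchgeo.datasets.utils import check_integrity, extract_archive, which

url = 's3://ai2-public-datasets/satlas/satlas-dataset-v1-landsat.tar'
root = '/path/to/satlas'  # illustrative dataset root
tarball = os.path.join(root, os.path.basename(url))

aws = which('aws')          # look up the AWS CLI and wrap it as a callable
aws('s3', 'cp', url, root)  # download the tarball into the dataset root
check_integrity(tarball, None)  # pass the expected MD5 instead of None to verify
extract_archive(tarball)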
