Commit

Return timestamp
adamjstewart committed Sep 27, 2024
1 parent 5e4e715 commit 082b853
Showing 10 changed files with 24 additions and 2 deletions.
9 changes: 9 additions & 0 deletions tests/data/satlas/data.py
@@ -36,6 +36,12 @@
[7149, 3246, 'S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235'],
[1234, 5678, 'S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235'],
]
times = {
'2022-03': '2022-03-01T00:00:00+00:00',
'm_3808245_se_17_1_20110801': '2011-08-01T12:00:00+00:00',
'2022-01': '2022-01-01T00:00:00+00:00',
'S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235': '2022-03-09T06:02:35+00:00',
}

FILENAME_HIERARCHY = dict[str, 'FILENAME_HIERARCHY'] | list[str]
filenames: FILENAME_HIERARCHY = {
@@ -97,6 +103,9 @@ def create_directory(directory: str, hierarchy: FILENAME_HIERARCHY) -> None:
with open(os.path.join('metadata', 'good_images_lowres_all.json'), 'w') as f:
json.dump(good_images, f)

with open(os.path.join('metadata', 'image_times.json'), 'w') as f:
json.dump(times, f)

for path in os.listdir('.'):
if os.path.isdir(path):
shutil.make_archive(path, 'tar', '.', path)
Binary file modified tests/data/satlas/landsat.tar
Binary file modified tests/data/satlas/metadata.tar
1 change: 1 addition & 0 deletions tests/data/satlas/metadata/image_times.json
@@ -0,0 +1 @@
{"2022-03": "2022-03-01T00:00:00+00:00", "m_3808245_se_17_1_20110801": "2011-08-01T12:00:00+00:00", "2022-01": "2022-01-01T00:00:00+00:00", "S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235": "2022-03-09T06:02:35+00:00"}
Binary file modified tests/data/satlas/naip.tar
Binary file modified tests/data/satlas/sentinel1.tar
Binary file modified tests/data/satlas/sentinel2.tar
Binary file modified tests/data/satlas/static.tar
1 change: 1 addition & 0 deletions tests/datasets/test_satlas.py
@@ -37,6 +37,7 @@ def test_getitem(self, dataset: SatlasPretrain, index: int) -> None:
assert isinstance(x, dict)
for image in dataset.images:
assert isinstance(x[f'image_{image}'], Tensor)
assert isinstance(x[f'time_{image}'], Tensor)
for label in dataset.labels:
assert isinstance(x[f'mask_{label}'], Tensor)

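
A hedged usage sketch of what a returned sample now contains; the root path and product selection are illustrative assumptions, not taken from the commit:

```python
from torchgeo.datasets import SatlasPretrain

# Illustrative configuration; 'data' is assumed to contain the extracted
# SatlasPretrain archives and metadata. Products are narrowed for brevity.
ds = SatlasPretrain(root='data', images=('sentinel2',), labels=('land_cover',))

sample = ds[0]
print(sample['image_sentinel2'].shape)  # image tensor, e.g. (C, H, W)
print(sample['time_sentinel2'])         # scalar tensor of POSIX seconds (new)
print(sample['mask_land_cover'].shape)  # label mask tensor
```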
15 changes: 13 additions & 2 deletions torchgeo/datasets/satlas.py
@@ -538,6 +538,7 @@ def __init__(
root: Path = 'data',
split: str = 'train_lowres',
good_images: str = 'good_images_lowres_all',
image_times: str = 'image_times',
images: Iterable[str] = ('sentinel1', 'sentinel2', 'landsat'),
labels: Iterable[str] = ('land_cover',),
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
@@ -550,6 +551,7 @@ def __init__(
root: Root directory where dataset can be found.
split: Metadata split to load.
good_images: Metadata mapping between col/row and directory.
image_times: Metadata mapping between directory and ISO time.
images: List of image products.
labels: List of label products.
transforms: A function/transform that takes input sample and its target as
@@ -572,10 +574,17 @@ def __init__(

         self._verify()

-        self.split = pd.read_json(os.path.join(root, 'metadata', f'{split}.json'))
+        # Read metadata files
+        self.split = pd.read_json(
+            os.path.join(root, 'metadata', f'{split}.json'), typ='frame'
+        )
         self.good_images = pd.read_json(
-            os.path.join(root, 'metadata', f'{good_images}.json')
+            os.path.join(root, 'metadata', f'{good_images}.json'), typ='frame'
         )
+        self.image_times = pd.read_json(
+            os.path.join(root, 'metadata', f'{image_times}.json'), typ='series'
+        )
+
         self.split.columns = ['col', 'row']
         self.good_images.columns = ['col', 'row', 'directory']
         self.good_images = self.good_images.groupby(['col', 'row'])
@@ -646,6 +655,8 @@ def _load_image(
# Choose a random timestamp
idx = torch.randint(len(good_directories), (1,))
directory = good_directories[idx]
time = self.image_times[directory].timestamp()
sample[f'time_{image}'] = torch.tensor(time)

# Load all bands
channels = []
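
Since the stored value is POSIX seconds, downstream code can recover a readable datetime. A small sketch with an assumed example value; float64 is used here so second-level precision survives at epoch scale:

```python
from datetime import datetime, timezone

import torch

# Assumed example value: POSIX seconds for 2022-03-09T06:02:35+00:00.
time_tensor = torch.tensor(1646805755.0, dtype=torch.float64)
dt = datetime.fromtimestamp(time_tensor.item(), tz=timezone.utc)
print(dt.isoformat())  # 2022-03-09T06:02:35+00:00
```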
