diff --git a/docs/usage/tutorials/sampling_training_data.ipynb b/docs/usage/tutorials/sampling_training_data.ipynb index 10cbd5cd0..f4d6de342 100644 --- a/docs/usage/tutorials/sampling_training_data.ipynb +++ b/docs/usage/tutorials/sampling_training_data.ipynb @@ -517,7 +517,8 @@ " # resize chips to 256x256 before returning\n", " out_size=256,\n", " # allow windows to overflow the extent by 100 pixels\n", - " padding=100\n", + " padding=100,\n", + " max_windows=10\n", ")\n", "\n", "img_full = ds.scene.raster_source[:, :]\n", diff --git a/docs/usage/tutorials/temporal.ipynb b/docs/usage/tutorials/temporal.ipynb index 6834926cb..2ac4f89b7 100644 --- a/docs/usage/tutorials/temporal.ipynb +++ b/docs/usage/tutorials/temporal.ipynb @@ -433,7 +433,7 @@ "source": [ "scene = Scene(id='test_scene', raster_source=raster_source)\n", "ds = SemanticSegmentationRandomWindowGeoDataset(\n", - " scene=scene, size_lims=(256, 256 + 1), out_size=256, return_window=True)" + " scene=scene, size_lims=(256, 256 + 1), out_size=256, max_windows=10, return_window=True)" ] }, { diff --git a/rastervision_pytorch_learner/rastervision/pytorch_learner/dataset/dataset.py b/rastervision_pytorch_learner/rastervision/pytorch_learner/dataset/dataset.py index a94ebb3f7..1f5b4c4e0 100644 --- a/rastervision_pytorch_learner/rastervision/pytorch_learner/dataset/dataset.py +++ b/rastervision_pytorch_learner/rastervision/pytorch_learner/dataset/dataset.py @@ -283,12 +283,13 @@ class RandomWindowGeoDataset(GeoDataset): def __init__( self, scene: Scene, + *, out_size: PosInt | tuple[PosInt, PosInt] | None, size_lims: tuple[PosInt, PosInt] | None = None, h_lims: tuple[PosInt, PosInt] | None = None, w_lims: tuple[PosInt, PosInt] | None = None, padding: NonNegInt | tuple[NonNegInt, NonNegInt] | None = None, - max_windows: NonNegInt | None = None, + max_windows: NonNegInt, max_sample_attempts: PosInt = 100, efficient_aoi_sampling: bool = True, within_aoi: bool = True, @@ -316,8 +317,7 @@ def __init__( sides of the raster source. If ``None``, ``padding = size``. Defaults to ``None``. max_windows: Max allowed reads. Will raise ``StopIteration`` on - further read attempts. If None, will be set to ``np.inf``. - Defaults to ``None``. + further read attempts. transform: Albumentations transform to apply to the windows. Defaults to ``None``. Each transform in Albumentations takes images of type uint8, and @@ -384,9 +384,6 @@ def __init__( padding = (max_h // 2, max_w // 2) padding: tuple[NonNegInt, NonNegInt] = ensure_tuple(padding) - if max_windows is None: - max_windows = np.iinfo('int').max - self.size_lims = size_lims self.h_lims = h_lims self.w_lims = w_lims diff --git a/tests/pytorch_learner/dataset/test_dataset.py b/tests/pytorch_learner/dataset/test_dataset.py index a582e1c4c..bf4136660 100644 --- a/tests/pytorch_learner/dataset/test_dataset.py +++ b/tests/pytorch_learner/dataset/test_dataset.py @@ -213,8 +213,9 @@ def test_sample_window_within_aoi(self): ds = RandomWindowGeoDataset( scene, - 10, - (5, 6), + out_size=10, + size_lims=(5, 6), + max_windows=10, within_aoi=True, transform_type=TransformType.noop, ) @@ -222,8 +223,9 @@ def test_sample_window_within_aoi(self): ds = RandomWindowGeoDataset( scene, - 10, - (12, 13), + out_size=10, + size_lims=(12, 13), + max_windows=10, within_aoi=True, transform_type=TransformType.noop, ) @@ -231,8 +233,9 @@ def test_sample_window_within_aoi(self): ds = RandomWindowGeoDataset( scene, - 10, - (12, 13), + out_size=10, + size_lims=(12, 13), + max_windows=10, within_aoi=False, transform_type=TransformType.noop, ) @@ -245,6 +248,7 @@ def test_init_validation(self): args = dict( scene=scene, out_size=10, + max_windows=10, transform_type=TransformType.noop, ) self.assertRaises(ValueError, lambda: RandomWindowGeoDataset(**args)) @@ -255,6 +259,7 @@ def test_init_validation(self): out_size=10, size_lims=(10, 11), h_lims=(10, 11), + max_windows=10, transform_type=TransformType.noop, ) self.assertRaises(ValueError, lambda: RandomWindowGeoDataset(**args)) @@ -266,6 +271,7 @@ def test_init_validation(self): size_lims=(10, 11), h_lims=(10, 11), w_lims=(10, 11), + max_windows=10, transform_type=TransformType.noop, ) self.assertRaises(ValueError, lambda: RandomWindowGeoDataset(**args)) @@ -275,6 +281,7 @@ def test_init_validation(self): scene=scene, out_size=10, w_lims=(10, 11), + max_windows=10, transform_type=TransformType.noop, ) self.assertRaises(ValueError, lambda: RandomWindowGeoDataset(**args)) @@ -284,6 +291,7 @@ def test_init_validation(self): scene, out_size=None, size_lims=(12, 13), + max_windows=10, transform_type=TransformType.noop, ) self.assertFalse(ds.normalize) @@ -295,6 +303,7 @@ def test_init_validation(self): out_size=None, h_lims=(10, 11), w_lims=(10, 11), + max_windows=10, transform_type=TransformType.noop, ) self.assertTupleEqual(ds.padding, (5, 5)) @@ -305,6 +314,7 @@ def test_min_max_size(self): scene, out_size=None, size_lims=(10, 15), + max_windows=10, transform_type=TransformType.noop, ) self.assertTupleEqual(ds.min_size, (10, 10)) @@ -315,6 +325,7 @@ def test_min_max_size(self): out_size=None, h_lims=(10, 15), w_lims=(8, 12), + max_windows=10, transform_type=TransformType.noop, ) self.assertTupleEqual(ds.min_size, (10, 8)) @@ -326,6 +337,7 @@ def test_sample_window_size(self): scene, out_size=None, size_lims=(10, 15), + max_windows=10, transform_type=TransformType.noop, ) sampled_h, sampled_w = ds.sample_window_size() @@ -337,6 +349,7 @@ def test_sample_window_size(self): out_size=None, h_lims=(10, 15), w_lims=(8, 12), + max_windows=10, transform_type=TransformType.noop, ) sampled_h, sampled_w = ds.sample_window_size() @@ -360,6 +373,7 @@ def test_return_window(self): scene, out_size=10, size_lims=(5, 6), + max_windows=10, transform_type=TransformType.noop, return_window=True, ) @@ -376,6 +390,7 @@ def test_triangle_missing(self): scene=scene, out_size=10, size_lims=(5, 6), + max_windows=10, transform_type=TransformType.noop, ) self.assertNoError(lambda: RandomWindowGeoDataset(**args))