Skip to content

Commit

Permalink
update config validators and fix some type hints
Browse files Browse the repository at this point in the history
- v1's validator and root_validator have been changed to field_validator and model_validator respectively
- move some type definitions to core.utils.types
- update some type hints to newer syntax
  • Loading branch information
AdeelH committed Jul 1, 2024
1 parent 07508f1 commit be455b9
Show file tree
Hide file tree
Showing 20 changed files with 233 additions and 242 deletions.
4 changes: 1 addition & 3 deletions rastervision_core/rastervision/core/box.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import (TYPE_CHECKING, Callable, Dict, List, Literal, Optional,
Sequence, Tuple, Union)
from pydantic import PositiveInt as PosInt, conint
from pydantic import NonNegativeInt as NonNegInt, PositiveInt as PosInt
import math
import random

Expand All @@ -11,8 +11,6 @@

from rastervision.pipeline.utils import repr_with_args

NonNegInt = conint(ge=0)

if TYPE_CHECKING:
from shapely.geometry import MultiPolygon

Expand Down
55 changes: 28 additions & 27 deletions rastervision_core/rastervision/core/data/class_config.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import List, Optional, Tuple, Union
from typing import List, Optional, Self, Tuple, Union

from rastervision.pipeline.config import (Config, register_config, ConfigError,
Field, validator)
Field, model_validator)
from rastervision.core.data.utils import color_to_triple, normalize_color

DEFAULT_NULL_CLASS_NAME = 'null'
Expand Down Expand Up @@ -32,42 +32,43 @@ class ClassConfig(Config):
'Config is part of a SemanticSegmentationConfig, a null class will be '
'added automatically.')

@validator('colors', always=True)
def validate_colors(cls, v: Optional[List[Union[str, Tuple]]],
values: dict) -> Optional[List[Union[str, Tuple]]]:
@model_validator(mode='after')
def validate_colors(self) -> Self:
"""Compare length w/ names. Also auto-generate if not specified."""
class_names = values['names']
class_colors = v
if class_colors is None:
class_colors = [color_to_triple() for _ in class_names]
elif len(class_names) != len(class_colors):
raise ConfigError(f'len(class_names) ({len(class_names)}) != '
f'len(class_colors) ({len(class_colors)})\n'
f'class_names: {class_names}\n'
f'class_colors: {class_colors}')
return class_colors

@validator('null_class', always=True)
def validate_null_class(cls, v: Optional[str],
values: dict) -> Optional[str]:
names = self.names
colors = self.colors
if colors is None:
self.colors = [color_to_triple() for _ in names]
elif len(names) != len(colors):
raise ConfigError(f'len(class_names) ({len(names)}) != '
f'len(class_colors) ({len(colors)})\n'
f'class_names: {names}\n'
f'class_colors: {colors}')
return self

@model_validator(mode='after')
def validate_null_class(self) -> Self:
"""Check if in names. If 'null' in names, use it as null class."""
names = values['names']
if v is None:
names = self.names
null_class = self.null_class
if null_class is None:
if DEFAULT_NULL_CLASS_NAME in names:
v = DEFAULT_NULL_CLASS_NAME
self.null_class = DEFAULT_NULL_CLASS_NAME
else:
if v not in names:
if null_class not in names:
raise ConfigError(
f'The null_class, "{v}", must be in list of class names.')
f'The null_class, "{null_class}", must be in list of '
'class names.')

# edge case
default_null_class_in_names = (DEFAULT_NULL_CLASS_NAME in names)
null_class_neq_default = (v != DEFAULT_NULL_CLASS_NAME)
null_class_neq_default = (null_class != DEFAULT_NULL_CLASS_NAME)
if default_null_class_in_names and null_class_neq_default:
raise ConfigError(
f'"{DEFAULT_NULL_CLASS_NAME}" is in names but the '
f'specified null_class is something else ("{v}").')
return v
'specified null_class is something else '
f'("{null_class}").')
return self

def get_class_id(self, name: str) -> int:
return self.names.index(name)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from typing import Optional
from typing import Optional, Self

from rastervision.core.data.vector_source import (VectorSourceConfig)
from rastervision.core.data.label_source import (LabelSourceConfig,
ChipClassificationLabelSource)
from rastervision.pipeline.config import (ConfigError, register_config, Field,
validator, root_validator)
field_validator, model_validator)
from rastervision.core.data.vector_transformer import (
ClassInferenceTransformerConfig, BufferTransformerConfig)

Expand Down Expand Up @@ -62,7 +62,8 @@ class ChipClassificationLabelSourceConfig(LabelSourceConfig):
description='If True, labels will not be populated automatically '
'during initialization of the label source.')

@validator('vector_source')
@field_validator('vector_source')
@classmethod
def ensure_required_transformers(
cls, v: VectorSourceConfig) -> VectorSourceConfig:
"""Add class-inference and buffer transformers if absent."""
Expand All @@ -84,14 +85,14 @@ def ensure_required_transformers(

return v

@root_validator(skip_on_failure=True)
def ensure_bg_class_id_if_inferring(cls, values: dict) -> dict:
infer_cells = values.get('infer_cells')
has_bg_class_id = values.get('background_class_id') is not None
@model_validator(mode='after')
def ensure_bg_class_id_if_inferring(self) -> Self:
infer_cells = self.infer_cells
has_bg_class_id = self.background_class_id is not None
if infer_cells and not has_bg_class_id:
raise ConfigError(
'background_class_id is required if infer_cells=True.')
return values
return self

def build(self, class_config, crs_transformer, bbox=None,
tmp_dir=None) -> ChipClassificationLabelSource:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from rastervision.core.data.vector_source import VectorSourceConfig
from rastervision.core.data.vector_transformer import (
ClassInferenceTransformerConfig, BufferTransformerConfig)
from rastervision.pipeline.config import register_config, validator
from rastervision.pipeline.config import register_config, field_validator


@register_config('object_detection_label_source')
Expand All @@ -12,7 +12,8 @@ class ObjectDetectionLabelSourceConfig(LabelSourceConfig):

vector_source: VectorSourceConfig

@validator('vector_source')
@field_validator('vector_source')
@classmethod
def ensure_required_transformers(
cls, v: VectorSourceConfig) -> VectorSourceConfig:
"""Add class-inference and buffer transformers if absent."""
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import TYPE_CHECKING, Optional, Sequence, Self, Tuple
from pydantic import conint

from pydantic import NonNegativeInt as NonNegInt
import numpy as np
from pystac import Item

Expand All @@ -18,29 +18,28 @@ class MultiRasterSource(RasterSource):

def __init__(self,
raster_sources: Sequence[RasterSource],
primary_source_idx: conint(ge=0) = 0,
primary_source_idx: NonNegInt = 0,
force_same_dtype: bool = False,
channel_order: Optional[Sequence[conint(ge=0)]] = None,
channel_order: Sequence[NonNegInt] | None = None,
raster_transformers: Sequence = [],
bbox: Optional[Box] = None):
bbox: Box | None = None):
"""Constructor.
Args:
raster_sources (Sequence[RasterSource]): Sequence of RasterSources.
raster_sources: Sequence of RasterSources.
primary_source_idx (0 <= int < len(raster_sources)): Index of the
raster source whose CRS, dtype, and other attributes will
override those of the other raster sources.
force_same_dtype (bool): If true, force all sub-chips to have the
force_same_dtype: If true, force all sub-chips to have the
same dtype as the primary_source_idx-th sub-chip. No careful
conversion is done, just a quick cast. Use with caution.
channel_order (Sequence[conint(ge=0)], optional): Channel ordering
that will be used by .get_chip(). Defaults to None.
raster_transformers (Sequence, optional): Sequence of transformers.
Defaults to [].
bbox (Optional[Box], optional): User-specified crop of the extent.
If given, the primary raster source's bbox is set to this.
If None, the full extent available in the source file of the
primary raster source is used.
channel_order: Channel ordering that will be used by
:meth:`MultiRasterSource.get_chip()`. Defaults to ``None``.
raster_transformers: List of transformers. Defaults to ``[]``.
bbox: User-specified crop of the extent. If specified, the primary
raster source's bbox is set to this. If ``None``, the full
extent available in the source file of the primary raster
source is used.
"""
num_channels_raw = sum(rs.num_channels for rs in raster_sources)
if not channel_order:
Expand Down Expand Up @@ -78,7 +77,7 @@ def from_stac(
cls,
item: Item,
assets: list[str] | None,
primary_source_idx: conint(ge=0) = 0,
primary_source_idx: NonNegInt = 0,
raster_transformers: list['RasterTransformer'] = [],
force_same_dtype: bool = False,
channel_order: Sequence[int] | None = None,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from typing import Optional
from pydantic import conint, conlist
from typing import List, Optional, Self

from rastervision.pipeline.config import (Field, register_config, validator)
from typing_extensions import Annotated
from pydantic import NonNegativeInt as NonNegInt

from rastervision.pipeline.config import (Field, register_config,
model_validator)
from rastervision.core.box import Box
from rastervision.core.data.raster_source import (RasterSourceConfig,
MultiRasterSource)
Expand All @@ -25,10 +28,10 @@ class MultiRasterSourceConfig(RasterSourceConfig):
Or :class:`.TemporalMultiRasterSource`, if ``temporal=True``.
"""

raster_sources: conlist(
RasterSourceConfig, min_items=1) = Field(
raster_sources: Annotated[List[
RasterSourceConfig], Field(min_length=1)] = Field(
..., description='List of RasterSourceConfig to combine.')
primary_source_idx: conint(ge=0) = Field(
primary_source_idx: NonNegInt = Field(
0,
description=
'Index of the raster source whose CRS, dtype, and other attributes '
Expand All @@ -42,21 +45,21 @@ class MultiRasterSourceConfig(RasterSourceConfig):
description='Stack images from sub raster sources into a time-series '
'of shape (T, H, W, C) instead of concatenating bands.')

@validator('primary_source_idx')
def validate_primary_source_idx(cls, v: int, values: dict):
raster_sources = values.get('raster_sources', [])
if not (0 <= v < len(raster_sources)):
@model_validator(mode='after')
def validate_primary_source_idx(self) -> Self:
primary_source_idx = self.primary_source_idx
raster_sources = self.raster_sources
if not (0 <= primary_source_idx < len(raster_sources)):
raise IndexError('primary_source_idx must be in range '
'[0, len(raster_sources)].')
return v
return self

@validator('temporal')
def validate_temporal(cls, v: int, values: dict):
channel_order = values.get('channel_order')
if v and channel_order is not None:
@model_validator(mode='after')
def validate_temporal(self) -> Self:
if self.temporal and self.channel_order is not None:
raise ValueError(
'Setting channel_order is not allowed if temporal=True.')
return v
return self

def build(self,
tmp_dir: Optional[str] = None,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from rastervision.core.data.vector_transformer import (
ClassInferenceTransformerConfig, BufferTransformerConfig)
from rastervision.pipeline.config import (register_config, Config, Field,
validator)
field_validator)

if TYPE_CHECKING:
from rastervision.core.box import Box
Expand Down Expand Up @@ -35,7 +35,8 @@ class RasterizedSourceConfig(Config):
vector_source: VectorSourceConfig
rasterizer_config: RasterizerConfig

@validator('vector_source')
@field_validator('vector_source')
@classmethod
def ensure_required_transformers(
cls, v: VectorSourceConfig) -> VectorSourceConfig:
"""Add class-inference and buffer transformers if absent."""
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Any, Optional, Sequence, Tuple
from pydantic import conint

from pydantic import NonNegativeInt as NonNegInt
import numpy as np

from rastervision.core.box import Box
Expand All @@ -14,7 +14,7 @@ class TemporalMultiRasterSource(MultiRasterSource):

def __init__(self,
raster_sources: Sequence[RasterSource],
primary_source_idx: conint(ge=0) = 0,
primary_source_idx: NonNegInt = 0,
force_same_dtype: bool = False,
raster_transformers: Sequence = [],
bbox: Optional[Box] = None):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Dict, Optional
from typing import TYPE_CHECKING

from rastervision.pipeline.config import register_config, Field
from rastervision.core.data.vector_transformer import (
Expand All @@ -12,13 +12,13 @@
class ClassInferenceTransformerConfig(VectorTransformerConfig):
"""Configure a :class:`.ClassInferenceTransformer`."""

default_class_id: Optional[int] = Field(
default_class_id: int | None = Field(
None,
description='The default ``class_id`` to use if class cannot be '
'inferred using other mechanisms. If a feature has an inferred '
'``class_id`` of ``None``, then it will be deleted. '
'Defaults to ``None``.')
class_id_to_filter: Optional[Dict[int, list]] = Field(
class_id_to_filter: dict[int, list] | None = Field(
None,
description='Map from ``class_id`` to JSON filter used to infer '
'missing class IDs. Each key should be a class ID, and its value '
Expand All @@ -28,15 +28,15 @@ class ClassInferenceTransformerConfig(VectorTransformerConfig):
'is that described by '
'https://docs.mapbox.com/mapbox-gl-js/style-spec/other/#other-filter. '
'Defaults to ``None``.')
class_name_mapping: dict[str, str] = Field(
class_name_mapping: dict[str, str] | None = Field(
None,
description='``old_name --> new_name`` mapping for values in the '
'``class_name`` or ``label`` property of the GeoJSON features. The '
'``new_name`` must be a valid class name in the ``ClassConfig``. This '
'can also be used to merge multiple classes into one e.g.: '
'``dict(car="vehicle", truck="vehicle")``. Defaults to None.')

def build(self, class_config: Optional['ClassConfig'] = None
def build(self, class_config: 'ClassConfig | None' = None
) -> ClassInferenceTransformer:
return ClassInferenceTransformer(
self.default_class_id,
Expand Down
Loading

0 comments on commit be455b9

Please sign in to comment.