diff --git a/docs/source/03_quickstart.rst b/docs/source/03_quickstart.rst index 5cce5145..560ce11a 100644 --- a/docs/source/03_quickstart.rst +++ b/docs/source/03_quickstart.rst @@ -46,7 +46,7 @@ as well as other parameters such as the directory on your local machine where to save the GPM dataset of interest. To facilitate the creation of the configuration file, you can adapt and run the following script in Python. -The configuration file will be created in the user's home directory under the name ``.config_gpm.yaml``. +The configuration file will be created in the user's home directory under the name ``.config_gpm_api.yaml``. .. code-block:: python diff --git a/gpm/_config.py b/gpm/_config.py index 33c2736a..54f828d0 100644 --- a/gpm/_config.py +++ b/gpm/_config.py @@ -35,7 +35,7 @@ def _get_default_configs(): - """Retrieve the default GPM-API settings from the ``.config_gpm.yaml`` file.""" + """Retrieve the default GPM-API settings from the ``.config_gpm_api.yaml`` file.""" try: config_dict = read_configs() config_dict = {key: value for key, value in config_dict.items() if value is not None} diff --git a/gpm/accessor/methods.py b/gpm/accessor/methods.py index c144f5aa..c1cccffe 100644 --- a/gpm/accessor/methods.py +++ b/gpm/accessor/methods.py @@ -91,9 +91,9 @@ def sel(self, indexers=None, drop=False, **indexers_kwargs): @auto_wrap_docstring def extent(self, padding=0, size=None): - from gpm.utils.geospatial import get_extent + from gpm.utils.geospatial import get_geographic_extent_from_xarray - return get_extent(self._obj, padding=padding, size=size) + return get_geographic_extent_from_xarray(self._obj, padding=padding, size=size) @auto_wrap_docstring def crop(self, extent): diff --git a/gpm/bucket/analysis.py b/gpm/bucket/analysis.py index 13a86ada..bf0b3372 100644 --- a/gpm/bucket/analysis.py +++ b/gpm/bucket/analysis.py @@ -30,6 +30,118 @@ import pandas as pd import polars as pl +from gpm.utils.geospatial import _check_size + +# in processing.py --> replace, assign_spatial_partitions, get_bin_partition +# assign_spatial_partitions +# get_bin_partition + + +def get_bin_partition(values, bin_size): + """Compute the bins partitioning values. + + Parameters + ---------- + values : float or array-like + Values. + bin_size : float + Bin size. + + Returns + ------- + Bin value : float or array-like + DESCRIPTION. + + """ + return bin_size * np.floor(values / bin_size) + + +# bin_size = 10 +# values = np.array([-180,-176,-175, -174, -171, 170, 166]) +# get_bin_partition(values, bin_size) + + +def assign_spatial_partitions( + df, + xbin_name, + ybin_name, + xbin_size, + ybin_size, + x_column="lat", + y_column="lon", +): + """Add partitioning bin columns to dataframe. + + Works for both `dask.dataframe.DataFrame` and `pandas.DataFrame`. + """ + # Remove invalid coordinates + df = df[~df[x_column].isna()] + df = df[~df[y_column].isna()] + + # Add spatial partitions columns to dataframe + partition_columns = { + xbin_name: get_bin_partition(df[x_column], bin_size=xbin_size), + ybin_name: get_bin_partition(df[y_column], bin_size=ybin_size), + } + return df.assign(**partition_columns) + + +def _get_bin_edges(vmin, vmax, size): + """Get bin edges.""" + return np.arange(vmin, vmax + 1e-10, step=size) + + +def _get_bin_midpoints(vmin, vmax, size): + """Get bin midpoints.""" + edges = _get_bin_edges(vmin=vmin, vmax=vmax, size=size) + return edges[:-1] + np.diff(edges) / 2 + + +def create_spatial_bin_empty_df( + xbin_size=1, + ybin_size=1, + xlim=(-180, 180), + ylim=(-90, 90), + xbin_name="xbin", + ybin_name="ybin", +): + """Create empty spatial bin DataFrame.""" + # Get midpoints + x_midpoints = _get_bin_midpoints(vmin=xlim[0], vmax=xlim[1], size=xbin_size) + y_midpoints = _get_bin_midpoints(vmin=ylim[0], vmax=ylim[1], size=ybin_size) + + # Create the MultiIndex from the combination of x and y bins + multi_index = pd.MultiIndex.from_product( + [x_midpoints, y_midpoints], + names=[xbin_name, ybin_name], + ) + + # Create an empty DataFrame with the MultiIndex + return pd.DataFrame(index=multi_index) + + +def add_bin_column(df, column, bin_size, vmin, vmax, bin_name, add_midpoint=True): + # Keep rows within values + valid_rows = df[column].between(left=vmin, right=vmax, inclusive="both") + df = df.loc[valid_rows, :] + + # Get bin edges and midpoints + bin_edges = _get_bin_edges(vmin=vmin, vmax=vmax, size=bin_size) + bin_midpoints = _get_bin_midpoints(vmin=vmin, vmax=vmax, size=bin_size) + + # Get bin index + # - 0 is outside to the left of the bins + # - -1 is outside to the right + # --> Subtract 1 + bin_idx = np.digitize(df[column], bins=bin_edges, right=False) - 1 + + # Add bin index/midpoint values + if add_midpoint: + df[bin_name] = bin_midpoints[bin_idx] + else: + df[bin_name] = bin_idx + return df + def get_n_decimals(number): number_str = str(number) @@ -95,6 +207,40 @@ def get_cut_lon_breaks_labels(bin_spacing): return cut_lon_breaks, cut_lon_labels +def add_spatial_bins( + df, + x="x", + y="y", + xbin_size=1, + ybin_size=1, + xlim=(-180, 180), + ylim=(-90, 90), + xbin_name="xbin", + ybin_name="ybin", + add_bin_midpoint=True, +): + # Define x bins + df = add_bin_column( + df=df, + column=x, + bin_size=xbin_size, + vmin=xlim[0], + vmax=xlim[1], + bin_name=xbin_name, + add_midpoint=add_bin_midpoint, + ) + # Define y bins + return add_bin_column( + df=df, + column=y, + bin_size=ybin_size, + vmin=ylim[0], + vmax=ylim[1], + bin_name=ybin_name, + add_midpoint=add_bin_midpoint, + ) + + def pl_add_geographic_bins( df, xbin_column, @@ -113,6 +259,49 @@ def pl_add_geographic_bins( # df.filter(pl.col(xbin_column) == "outside_right") +def add_geographic_bins( + df, + x, + y, + xbin, + ybin, + size, + extent, + add_bin_midpoint=True, +): + size = _check_size(size) + if isinstance(df, pd.DataFrame): + from gpm.bucket.analysis import add_spatial_bins + + df = add_spatial_bins( + df=df, + x=x, + y=y, + xbin_name=xbin, + ybin_name=ybin, + xbin_size=size[0], + ybin_size=size[1], + xlim=extent[0:2], + ylim=extent[0:2], + add_bin_midpoint=add_bin_midpoint, + ) + else: + # TODO: no extent ! + df = pl_add_geographic_bins( + df=df, + xbin_column=xbin, + ybin_column=ybin, + bin_spacing=size, + x_column=x, + y_column=y, + ) + return df + + +####----------------------------------------------------------------. +#### Conversion to xarray Dataset + + def pl_df_to_xarray(df, xbin_column, ybin_column, bin_spacing): df_stats_pd = df.to_pandas() @@ -138,3 +327,37 @@ def pl_df_to_xarray(df, xbin_column, ybin_column, bin_spacing): # Reshape to xarray ds = df_stats_pd.to_xarray() return ds.rename({xbin_column: "longitude", ybin_column: "latitude"}) + + +def pd_df_to_xarray(df, xbin, ybin, size): + size = _check_size(size) + if set(df.index.names) != {xbin, ybin}: + df[xbin] = df[xbin].astype(float) + df[ybin] = df[ybin].astype(float) + df = df.set_index([xbin, ybin]) + + # Create an empty DataFrame with the MultiIndex + lon_labels = get_lon_labels(size[0]) + lat_labels = get_lat_labels(size[1]) + multi_index = pd.MultiIndex.from_product( + [lon_labels, lat_labels], + names=[xbin, ybin], + ) + empty_df = pd.DataFrame(index=multi_index) + + # Create final dataframe + df_full = empty_df.join(df, how="left") + + # Reshape to xarray + ds = df_full.to_xarray() + return ds + + +def df_to_dataset(df, xbin, ybin, size, extent): + if isinstance(df, pl.DataFrame): + df = df.to_pandas() + if not isinstance(df, pd.DataFrame): + raise TypeError("Expecting a pandas or polars DataFrame.") + size = _check_size(size) + ds = pd_df_to_xarray(df, xbin=xbin, ybin=ybin, size=size, extent=extent) + return ds diff --git a/gpm/bucket/partitioning.py b/gpm/bucket/partitioning.py new file mode 100644 index 00000000..dc0503a4 --- /dev/null +++ b/gpm/bucket/partitioning.py @@ -0,0 +1,604 @@ +from functools import wraps +import numpy as np +import pandas as pd +import polars as pl +from gpm.utils.geospatial import ( + _check_size, + check_extent, + Extent, + get_geographic_extent_around_point, + get_country_extent, + get_continent_extent, + get_extent_around_point, +) +####--------------------------------------------------------. +#### Grids +## Unstructured +# - Planar +# - Spherical +# --> Regular Geometry +# --> Irregular Geometry + +## Structured +# 2D on the sphere +# - SwathDefinition +# - CRS: geographic + +# 2D on planar projection +# - AreaDefinition: define cell size (resolution) based on shape and extent ! +# - CRS: geographic / projected + +####--------------------------------------------------------. +#### Partitioning +# 2D Partitioning: +# --> XYPartitioning: size +# --> GeographicPartitioning +# --> TilePartitioning +# - Last partition not granted to have same size as the other +# - If size = resolution, share same property as an AreaDefinition +# - Directory: 1 level (TilePartitioning) or 2 levels (XYPartitioning/GeographicPartitioning) +# - CRS: geographic / projected +# - Query by extent: OK +# - Query by geometry: First subset by extent, then intersect quadmesh? + +# KdTree/RTree Partitioning +# --> https://rtree.readthedocs.io/en/latest/ + +# SphericalPartitioning +# --> H3Partitioning (Uber’s Hexagonal Hierarchical) +# --> S2Partitioning (Google Geometry) (Hierarchical, discrete, global grid system) (Square cells) +# --> HealpixPartitioning +# --> GaussianGrid, CubedSphere, ... (other hierarchical, discrete, and global grid system) +# --> Hexbin, the process of taking coordinates and binning them into hexagonal cells in analytics or mapping software. +# --> Geohash: a system for encoding locations using a string of characters, creating a hierarchical, square grid system (a quadtree). + +# - Directory: 1 level +# - Query by extent: Define polygon and query by geometry ??? +# - Query by geometry: First subset by extent, then intersect quadmesh? + +# GeometryPartitioning +# - Can include all partitioning + +# GeometryBucket --> GeometryIntersection, save geometries in GeoParquet + + +#### CentroidBucket + +#### GeometryBucket +# - Exploit kdtree and dask +# - Borrow/Replace from pyresample bucket +# - Output to xvec/datatree? +# - Intermediate for remapping (weighted fraction, fraction_overlap, conservative) + +####--------------------------------------------------------. +#### Bucket Improvements +# Readers +# - gpm.bucket.read_around_point(bucket_dir, lon, lat, distance, size, **polars_kwargs) +# --> compute distance on subset and select below threshold +# --> https://stackoverflow.com/questions/76262681/i-need-to-create-a-column-with-the-distance-between-two-coordinates-in-polars +# - gpm.bucket.read_within_extent(bucket_dir, extent, **polars_kwargs) +# - gpm.bucket.read_within_country(bucket_dir, country, **polars_kwargs) +# - gpm.bucket.read_within_continent(bucket_dir, continent, **polars_kwargs) + +# Core: +# - list_partitions_within_extent +# - list_filepaths_within_extent +# - read_within_extent ! + +# Routines +# - Routine to repartition in smaller partitions (disaggregate bucket) +# - Routine to repartition in larger partitions (aggregate bucket) + +# Analysis +# - Group by overpass +# - Reformat to dataset / Generator + +# Writers +# - RENAME: write_partitioned_dataset into write_arrow_dataset + +####--------------------------------------------------------. +#### Query directories +# Format +# - hive: xbin=xx/ybin=yy or tile_id=id +# - xx/yy or id +## Option1 (faster if small query and lot of directories) +# - Create directories paths +# - Check if they exists +## Option2 (might be faster for large queries) +## - List directories +## - Filter by lon/lat +# Core: +# - list_partitions_within_extent +# - list_filepaths_within_extent +# - read_within_extent ! + +# Notes +# - https://h3geo.org/docs/comparisons/s2 +# - H3: https://github.com/uber/h3-py # https://pypi.org/project/h3/ +# --> h3.latlng_to_cell(lat, lng, resolution) + +####--------------------------------------------------------------------------------------------- +# ProjectionPartitioning +# Enable distance in meters and compute planar or on sphere +# def get_partitions_around_point(self, x, y, distance=None, size=None): +# extent = get_extent_around_point(x, y, distance=distance, size=size) +# return self.get_partitions_by_extent(extent=extent) + +# Requires pyproj CRS. Backproject x,y to lon/lat +# def get_partitions_by_country(self, name, padding=None): +# extent = get_country_extent(name, padding=padding) +# return self.get_partitions_by_extent(extent=extent) + +# Requires pyproj CRS. Backproject x,y to lon/lat +# def get_partitions_by_continent(self, name, padding=None): +# extent = get_continent_extent(name, padding=padding) +# return self.get_partitions_by_extent(extent=extent) + +####--------------------------------------------------------------------------------------------- + + +def check_valid_dataframe(func): + """ + Decorator to check if the first argument or the keyword argument 'df' + is a `pandas.DataFrame` or a `polars.DataFrame`. + """ + @wraps(func) + def wrapper(*args, **kwargs): + # Check if 'df' is in kwargs, otherwise assume it is the first positional argument + df = kwargs.get('df', args[1] if len(args) == 2 else None) + # Validate the DataFrame + if not isinstance(df, (pd.DataFrame, pl.DataFrame)): + raise TypeError("The 'df' argument must be either a pandas.DataFrame or a polars.DataFrame.") + return func(*args, **kwargs) + return wrapper + + +def check_valid_x_y(df, x, y): + """Check if the x and y columns are in the dataframe.""" + if y not in df: + raise ValueError(f"y='{y}' is not a column of the dataframe.") + if x not in df: + raise ValueError(f"x='{x}' is not a column of the dataframe.") + + +def ensure_xy_without_nan_values(df, x, y, remove_invalid_rows=True): + """Ensure valid coordinates in the dataframe.""" + # Remove NaN vales + if remove_invalid_rows: + if isinstance(df, pd.DataFrame): + return df.dropna(subset=[x,y]) + else: + return df.filter(~pl.col(x).is_null() | ~pl.col(y).is_null()) + + # Check no NaN values + if isinstance(df, pd.DataFrame): + indices = df[[x, y]].isna().any(axis=1) + else: + indices = (df[x].is_null() | df[x].is_null()) + if indices.any(): + rows_indices = np.where(indices)[0].tolist() + raise ValueError(f"Null values present in columns {x} and {y} at rows: {rows_indices}") + return df + + +def ensure_valid_partitions(df, xbin, ybin, remove_invalid_rows=True): + """Ensure valid partitions labels in the dataframe.""" + # Remove NaN values + if remove_invalid_rows: + if isinstance(df, pd.DataFrame): + return df.dropna(subset=[xbin,ybin]) + else: + df = df.filter(~pl.col(xbin).is_in(["outside_right", "outside_left"])) + df = df.filter(~pl.col(ybin).is_in(["outside_right", "outside_left"])) + df = df.filter(~pl.col(xbin).is_null() | ~pl.col(ybin).is_null()) + return df + + # Check no invalid partitions (NaN or polars outside_right/outside_left) + if isinstance(df, pd.DataFrame): + indices = df[[xbin, ybin]].isna().any(axis=1) + else: + indices = (df[xbin].is_in(["outside_right", "outside_left"]) | + df[ybin].is_in(["outside_right", "outside_left"]) + ) + if indices.any(): + rows_indices = np.where(indices)[0].tolist() + raise ValueError(f"Out of extent x,y coordinates at rows: {rows_indices}") + + +def get_array_combinations(x,y): + """Return all the combinations between the two array.""" + # Create the mesh grid + grid1, grid2 = np.meshgrid(x, y) + # Stack and reshape the grid arrays to get combinations + combinations = np.vstack([grid1.ravel(), grid2.ravel()]).T + return combinations + + +def get_n_decimals(number): + """Get the number of decimals of a number.""" + number_str = str(number) + decimal_index = number_str.find(".") + + if decimal_index == -1: + return 0 # No decimal point found + + # Count the number of characters after the decimal point + return len(number_str) - decimal_index - 1 + + +def get_breaks(size, vmin, vmax): + """Define partitions edges.""" + breaks = np.arange(vmin, vmax, size) + if breaks[-1] != vmax: + breaks = np.append(breaks, np.array([vmax])) + return breaks + + +def get_midpoints(size, vmin, vmax): + """Define partitions midpoints.""" + breaks = get_breaks(size, vmin=vmin, vmax=vmax) + midpoints = breaks[0:-1] + size / 2 + return midpoints + + +def get_labels(size, vmin, vmax): + """Define partitions labels (rounding partitions midpoints).""" + n_decimals = get_n_decimals(size) + midpoints = get_midpoints(size, vmin, vmax) + return midpoints.round(n_decimals + 1).astype(str) + + +def get_breaks_and_midpoints(size, vmin, vmax): + """Return the partitions edges and partitions midpoints.""" + breaks = get_breaks(size, vmin=vmin, vmax=vmax) + midpoints = get_midpoints(size, vmin=vmin, vmax=vmax) + return breaks, midpoints + + +def get_breaks_and_labels(size, vmin, vmax): + """Return the partitions edges and partitions labels.""" + breaks = get_breaks(size, vmin=vmin, vmax=vmax) + labels = get_labels(size, vmin=vmin, vmax=vmax) + return breaks, labels + + +def query_labels(values, breaks, labels): + """Return the partition labels for the specified coordinates.""" + # TODO: flag to raise error for NaN, None? + values = np.atleast_1d(np.asanyarray(values)).astype(float) + return pd.cut(values, bins=breaks, labels=labels, include_lowest=True, right=True) + + +def query_midpoints(values, breaks, midpoints): + """Return the partition midpoints for the specified coordinates.""" + values = np.atleast_1d(np.asanyarray(values)).astype(float) + return pd.cut(values, bins=breaks, labels=midpoints, include_lowest=True, right=True).astype(float) + + +def add_pandas_xy_partitions(df, size, extent, x, y, xbin, ybin, remove_invalid_rows=True): + """Add partitions labels to a pandas DataFrame based on x, y coordinates.""" + # Check x,y names + check_valid_x_y(df, x, y) + # Check/remove rows with NaN x,y columns + df = ensure_xy_without_nan_values(df, x=x, y=y, remove_invalid_rows=remove_invalid_rows) + # Retrieve breaks and labels (N and N+1) + cut_x_breaks, cut_x_labels = get_breaks_and_labels(size[0], vmin=extent[0], vmax=extent[1]) + cut_y_breaks, cut_y_labels = get_breaks_and_labels(size[1], vmin=extent[2], vmax=extent[3]) + # Add partitions labels columns + df[xbin] = query_labels(df[x].to_numpy(), breaks=cut_x_breaks, labels=cut_x_labels) + df[ybin] = query_labels(df[y].to_numpy(), breaks=cut_y_breaks, labels=cut_y_labels) + # Check/remove rows with invalid partitions (NaN) + df = ensure_valid_partitions(df, xbin=xbin, ybin=ybin, remove_invalid_rows=remove_invalid_rows) + return df + + +def add_polars_xy_partitions(df, x, y, size, extent, xbin, ybin, remove_invalid_rows=True): + """Add partitions to a polars DataFrame based on x, y coordinates.""" + # Check x,y names + check_valid_x_y(df, x, y) + # Check/remove rows with null x,y columns + df = ensure_xy_without_nan_values(df, x=x, y=y, remove_invalid_rows=remove_invalid_rows) + # Retrieve breaks and labels (N and N+1) + cut_x_breaks, cut_x_labels = get_breaks_and_labels(size[0], vmin=extent[0], vmax=extent[1]) + cut_y_breaks, cut_y_labels = get_breaks_and_labels(size[1], vmin=extent[2], vmax=extent[3]) + # Add outside labels for polars cut function + cut_x_labels = ["outside_left", *cut_x_labels, "outside_right"] + cut_y_labels = ["outside_left", *cut_y_labels, "outside_right"] + # Deal with left inclusion + cut_x_breaks[0] = cut_x_breaks[0] - 1e-8 + cut_y_breaks[0] = cut_y_breaks[0] - 1e-8 + # Add partitions columns + df = df.with_columns( + pl.col(x).cut(cut_x_breaks, labels=cut_x_labels, left_closed=False).alias(xbin), + pl.col(y).cut(cut_y_breaks, labels=cut_y_labels, left_closed=False).alias(ybin), + ) + # Check/remove rows with invalid partitions (out of extent or Null) + df = ensure_valid_partitions(df, xbin=xbin, ybin=ybin, remove_invalid_rows=remove_invalid_rows) + return df + + + +def add_polars_xy_tile_partitions(df, size, extent, x,y, tile_id): + check_valid_x_y(df, x, y) + raise NotImplementedError() + + +def add_pandas_xy_tile_partitions(df, size, extent, x,y, tile_id): + check_valid_x_y(df, x, y) + raise NotImplementedError() + + +def df_to_xarray(df, xbin, ybin, size, extent, new_x=None, new_y=None): + """Convert dataframe to xarray Dataset based on specified partitions midpoints. + + The partitioning cells not present in the dataframe are set to NaN. + """ + if isinstance(df, pl.DataFrame): + df = df.to_pandas() + if set(df.index.names) == set([xbin, ybin]): + df = df.reset_index() + + # Ensure index is float or string + df[xbin] = df[xbin].astype(float) + df[ybin] = df[ybin].astype(float) + df = df.set_index([xbin, ybin]) + + # Create an empty DataFrame with the MultiIndex + x_midpoints = get_midpoints(size[0], vmin=extent[0], vmax=extent[1]) + y_midpoints = get_midpoints(size[1], vmin=extent[2], vmax=extent[3]) + multi_index = pd.MultiIndex.from_product( + [x_midpoints, y_midpoints], + names=[xbin, ybin], + ) + empty_df = pd.DataFrame(index=multi_index) + + # Create final dataframe + df_full = empty_df.join(df, how="left") + + # Reshape to xarray + ds = df_full.to_xarray() + ds[xbin] = ds[xbin].astype(float) + + # Rename dictionary + rename_dict = {} + if new_x is not None: + rename_dict[xbin] = new_x + if new_y is not None: + rename_dict[ybin] = new_y + ds = ds.rename(rename_dict) + return ds + + +class XYPartitioning: + """ + Handales partitioning of data into x and y bins. + + Parameters: + ---------- + xbin : float + Identifier for the x bin. + ybin : float + Identifier for the y bin. + size : int, float, tuple, list + The size value(s) of the bins. + The function interprets the input as follows: + - int or float: The same size is enforced in both x and y directions. + - tuple or list: The bin size for the x and y directions. + extent : list + The extent for the partitioning specified as [xmin, xmax, ymin, ymax]. + + """ + def __init__(self, xbin, ybin, size, extent): + # Define extent + self.extent = check_extent(extent) + # Define bin size + self.size = _check_size(size) + # Define bin names + self.xbin = xbin + self.ybin = ybin + # Define breaks, midpoints and labels + self.x_breaks = get_breaks(size=self.size[0], vmin=self.extent.xmin, vmax=self.extent.xmax) + self.y_breaks = get_breaks(size=self.size[1], vmin=self.extent.ymin, vmax=self.extent.ymax) + self.x_midpoints = get_midpoints(size=self.size[0], vmin=self.extent.xmin, vmax=self.extent.xmax) + self.y_midpoints = get_midpoints(size=self.size[1], vmin=self.extent.ymin, vmax=self.extent.ymax) + self.x_labels = get_labels(size=self.size[0], vmin=self.extent.xmin, vmax=self.extent.xmax) + self.y_labels = get_labels(size=self.size[1], vmin=self.extent.ymin, vmax=self.extent.ymax) + # Define info + self.shape = (len(self.x_labels), len(self.y_labels)) + self.n_partitions = len(self.x_labels) * len(self.y_labels) + self.n_x = self.shape[0] + self.n_y = self.shape[1] + + @property + def partitions(self): + return [self.xbin, self.ybin] + + @check_valid_dataframe + def add_partitions(self, df, x, y, remove_invalid_rows=True): + if isinstance(df, pd.DataFrame): + return add_pandas_xy_partitions(df=df, x=x, y=y, + size=self.size, + extent=self.extent, + xbin=self.xbin, + ybin=self.ybin, + remove_invalid_rows=remove_invalid_rows) + return add_polars_xy_partitions(df=df, x=x, y=y, + size=self.size, + extent=self.extent, + xbin=self.xbin, + ybin=self.ybin, + remove_invalid_rows=remove_invalid_rows) + + @check_valid_dataframe + def to_xarray(self, df, new_x=None, new_y=None): + return df_to_xarray(df, + xbin=self.xbin, + ybin=self.ybin, + size=self.size, extent=self.extent, + new_x=new_x, + new_y=new_y) + + def to_dict(self): + dictionary = {"name": self.__class__.__name__, + "extent": list(self.extent), + "size": self.size, + "xbin": self.xbin, + "ybin": self.ybin} + return dictionary + + def query_x_labels(self, x): + """Return the x partition labels for the specified x coordinates.""" + return query_labels(x, breaks=self.x_breaks, labels=self.x_labels).astype(str) + + def query_y_labels(self, y): + """Return the y partition labels for the specified y coordinates.""" + return query_labels(y, breaks=self.y_breaks, labels=self.y_labels).astype(str) + + def query_labels(self, x, y): + """Return the partition labels for the specified x,y coordinates.""" + return self.query_x_labels(x), self.query_y_labels(y) + + def query_x_midpoints(self, x): + """Return the x partition midpoints for the specified x coordinates.""" + return query_midpoints(x, breaks=self.x_breaks, midpoints=self.x_midpoints) + + def query_y_midpoints(self, y): + """Return the y partition midpoints for the specified y coordinates.""" + return query_midpoints(y, breaks=self.y_breaks, midpoints=self.y_midpoints) + + def query_midpoints(self, x, y): + """Return the partition midpoints for the specified x,y coordinates.""" + return self.query_x_midpoints(x), self.query_y_midpoints(y) + + def get_partitions_by_extent(self, extent): + """Return the partitions labels containing data within the extent.""" + extent = check_extent(extent) + # Define valid query extent (to be aligned with partitioning extent) + query_extent = [max(extent.xmin, self.extent.xmin), min(extent.xmax, self.extent.xmax), + max(extent.ymin, self.extent.ymin), min(extent.ymax, self.extent.ymax)] + query_extent = Extent(*query_extent) + # Retrieve midpoints + xmin, xmax = self.query_x_midpoints([query_extent.xmin, query_extent.xmax]) + ymin, ymax = self.query_y_midpoints([query_extent.ymin, query_extent.ymax]) + # Retrieve univariate x and y labels within the extent + x_labels = self.x_labels[np.logical_and(self.x_midpoints >= xmin, self.x_midpoints <= xmax)] + y_labels = self.y_labels[np.logical_and(self.y_midpoints >= ymin, self.y_midpoints <= ymax)] + # Retrieve combination of all (x,y) labels within the extent + combinations = get_array_combinations(x_labels, y_labels) + return combinations + + def get_partitions_around_point(self, x, y, distance=None, size=None): + extent = get_extent_around_point(x, y, distance=distance, size=size) + return self.get_partitions_by_extent(extent=extent) + + @property + def quadmesh(self): + """Return the quadrilateral mesh. + + A quadrilateral mesh is a grid of M by N adjacent quadrilaterals that are defined via a (M+1, N+1) + grid of vertices. + + The quadrilateral mesh is accepted by `matplotlib.pyplot.pcolormesh`, `matplotlib.collections.QuadMesh` + `matplotlib.collections.PolyQuadMesh`. + + np.naddary + Quadmesh array of shape (M+1, N+1, 2) + """ + x_corners, y_corners = np.meshgrid(self.x_breaks, self.y_breaks) + return np.stack((x_corners, y_corners), axis=2) + + # to_yaml + # to_shapely + # to_spherically (geographic) + # to_geopandas [lat_bin, lon_bin, geometry] + + + +class GeographicPartitioning(XYPartitioning): + """ + Handles geographic partitioning of data based on longitude and latitude bin sizes within a defined extent. + + The last bin size (in lon and lat direction) might not be of size ``size` ! + + Parameters: + ---------- + size : float + The uniform size for longitude and latitude binning. + xbin : str, optional + Name of the longitude bin, default is 'lon_bin'. + ybin : str, optional + Name of the latitude bin, default is 'lat_bin'. + extent : list, optional + The geographical extent for the partitioning specified as [xmin, xmax, ymin, ymax]. + Default is the whole earth: [-180, 180, -90, 90]. + + Inherits: + ---------- + XYPartitioning + """ + def __init__(self, size, xbin="lon_bin", ybin="lat_bin", extent=[-180, 180, -90, 90]): + super().__init__(xbin=xbin, ybin=ybin, size=size, extent=extent) + + def get_partitions_around_point(self, lon, lat, distance=None, size=None): + extent = get_geographic_extent_around_point(lon=lon, lat=lat, + distance=distance, + size=size, + distance_type="geographic") + return self.get_partitions_by_extent(extent=extent) + + def get_partitions_by_country(self, name, padding=None): + extent = get_country_extent(name, padding=padding) + return self.get_partitions_by_extent(extent=extent) + + def get_partitions_by_continent(self, name, padding=None): + extent = get_continent_extent(name, padding=padding) + return self.get_partitions_by_extent(extent=extent) + + @check_valid_dataframe + def to_xarray(self, df, new_x="lon", new_y="lat"): + return df_to_xarray(df, + xbin=self.xbin, + ybin=self.ybin, + size=self.size, + extent=self.extent, + new_x=new_x, + new_y=new_y) + + + +class TilePartitioning: + """ + Handles partitioning of data into tiles within a specified extent. + + Parameters: + ---------- + size : float + The size of the tiles. + extent : list + The extent for the partitioning specified as [xmin, xmax, ymin, ymax]. + tile_id : str, optional + Identifier for the tile bin. The default is ``'tile_id'``. + + """ + def __init__(self, size, extent, tile_id="tile_id"): + self.size = _check_size(size) + self.extent = check_extent(extent) + self.tile_id = tile_id + + @property + def bins(self): + return [self.tile_id] + + @check_valid_dataframe + def add_partitions(self, df, x, y): + if isinstance(df, pd.DataFrame): + return add_pandas_xy_tile_partitions(df, x=x, y=x, + size=self.size, + extent=self.extent, + tile_id=self.tile_id, + ) + return add_polars_xy_tile_partitions(df, x=x, y=x, + size=self.size, + extent=self.extent, + tile_id=self.tile_id, + ) \ No newline at end of file diff --git a/gpm/bucket/utils.py b/gpm/bucket/utils.py deleted file mode 100644 index 49d9ddfc..00000000 --- a/gpm/bucket/utils.py +++ /dev/null @@ -1,120 +0,0 @@ -# -----------------------------------------------------------------------------. -# MIT License - -# Copyright (c) 2024 GPM-API developers -# -# This file is part of GPM-API. - -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. - -# -----------------------------------------------------------------------------. -"""This module provide utilities to manipulate GPM Geographic Buckets.""" -import numpy as np -import pandas as pd - - -def _get_bin_edges(vmin, vmax, size): - """Get bin edges.""" - return np.arange(vmin, vmax + 1e-10, step=size) - - -def _get_bin_midpoints(vmin, vmax, size): - """Get bin midpoints.""" - edges = _get_bin_edges(vmin=vmin, vmax=vmax, size=size) - return edges[:-1] + np.diff(edges) / 2 - - -def create_spatial_bin_empty_df( - xbin_size=1, - ybin_size=1, - xlim=(-180, 180), - ylim=(-90, 90), - xbin_name="xbin", - ybin_name="ybin", -): - """Create empty spatial bin DataFrame.""" - # Get midpoints - x_midpoints = _get_bin_midpoints(vmin=xlim[0], vmax=xlim[1], size=xbin_size) - y_midpoints = _get_bin_midpoints(vmin=ylim[0], vmax=ylim[1], size=ybin_size) - - # Create the MultiIndex from the combination of x and y bins - multi_index = pd.MultiIndex.from_product( - [x_midpoints, y_midpoints], - names=[xbin_name, ybin_name], - ) - - # Create an empty DataFrame with the MultiIndex - return pd.DataFrame(index=multi_index) - - -def add_bin_column(df, column, bin_size, vmin, vmax, bin_name, add_midpoint=True): - # Keep rows within values - valid_rows = df[column].between(left=vmin, right=vmax, inclusive="both") - df = df.loc[valid_rows, :] - - # Get bin edges and midpoints - bin_edges = _get_bin_edges(vmin=vmin, vmax=vmax, size=bin_size) - bin_midpoints = _get_bin_midpoints(vmin=vmin, vmax=vmax, size=bin_size) - - # Get bin index - # - 0 is outside to the left of the bins - # - -1 is outside to the right - # --> Subtract 1 - bin_idx = np.digitize(df[column], bins=bin_edges, right=False) - 1 - - # Add bin index/midpoint values - if add_midpoint: - df[bin_name] = bin_midpoints[bin_idx] - else: - df[bin_name] = bin_idx - return df - - -def add_spatial_bins( - df, - x="x", - y="y", - xbin_size=1, - ybin_size=1, - xlim=(-180, 180), - ylim=(-90, 90), - xbin_name="xbin", - ybin_name="ybin", - add_bin_midpoint=True, -): - # Define x bins - df = add_bin_column( - df=df, - column=x, - bin_size=xbin_size, - vmin=xlim[0], - vmax=xlim[1], - bin_name=xbin_name, - add_midpoint=add_bin_midpoint, - ) - # Define y bins - return add_bin_column( - df=df, - column=y, - bin_size=ybin_size, - vmin=ylim[0], - vmax=ylim[1], - bin_name=ybin_name, - add_midpoint=add_bin_midpoint, - ) diff --git a/gpm/configs.py b/gpm/configs.py index e6c6dad5..9c85b192 100644 --- a/gpm/configs.py +++ b/gpm/configs.py @@ -76,8 +76,8 @@ def _define_config_filepath(): """Define the config YAML file path.""" # Retrieve user home directory home_directory = os.path.expanduser("~") - # Define path where .config_gpm.yaml file should be located - return os.path.join(home_directory, ".config_gpm.yaml") + # Define path where .config_gpm_api.yaml file should be located + return os.path.join(home_directory, ".config_gpm_api.yaml") def define_configs( @@ -104,12 +104,12 @@ def define_configs( Notes ----- - This function writes a YAML file to the user's home directory at ~/.config_gpm.yaml + This function writes a YAML file to the user's home directory at ~/.config_gpm_api.yaml with the given GPM-API credentials and base directory. The configuration file can be used for authentication when making GPM-API requests. """ - # Define path to .config_gpm.yaml file + # Define path to .config_gpm_api.yaml file filepath = _define_config_filepath() # If the config exists, read it and update it ;) @@ -159,11 +159,11 @@ def read_configs() -> dict[str, str]: Notes ----- - This function reads the YAML configuration file located at ~/.config_gpm.yaml, which + This function reads the YAML configuration file located at ~/.config_gpm_api.yaml, which should contain the GPM-API credentials and base directory specified by `gpm.define_configs()`. """ - # Define path to .config_gpm.yaml file + # Define path to .config_gpm_api.yaml file filepath = _define_config_filepath() # Check it exists if not os.path.exists(filepath): diff --git a/gpm/dataset/datatree.py b/gpm/dataset/datatree.py index f27a9dc1..3bc43d71 100644 --- a/gpm/dataset/datatree.py +++ b/gpm/dataset/datatree.py @@ -79,6 +79,8 @@ def check_valid_granule(filepath): with xr.open_dataset(filepath, engine="netcdf4", group="") as ds: check_non_empty_granule(ds, filepath) except Exception as e: + if "an EMPTY granule" in str(e): + raise e _identify_error(e, filepath) diff --git a/gpm/io/download.py b/gpm/io/download.py index 961ded28..e63741b7 100644 --- a/gpm/io/download.py +++ b/gpm/io/download.py @@ -919,7 +919,7 @@ def download_archive( storage="PPS", n_threads=4, transfer_tool="CURL", - progress_bar=False, + progress_bar=True, force_download=False, check_integrity=True, remove_corrupted=True, diff --git a/gpm/tests/test_bucket/test_partitioning.py b/gpm/tests/test_bucket/test_partitioning.py new file mode 100644 index 00000000..34c18afd --- /dev/null +++ b/gpm/tests/test_bucket/test_partitioning.py @@ -0,0 +1,333 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +Created on Thu May 23 12:19:17 2024 + +@author: ghiggi +""" +import pytest +import pandas as pd +import numpy as np +import xarray as xr +import polars as pl +from gpm.bucket.partitioning import XYPartitioning +from gpm.bucket.partitioning import ( + get_n_decimals, + get_breaks, + get_labels, + get_midpoints, + get_breaks_and_labels, + get_array_combinations, +) + + +def test_get_n_decimals(): + """Ensure decimal count is accurate.""" + assert get_n_decimals(123.456) == 3 + assert get_n_decimals(100) == 0 + assert get_n_decimals(123.0001) == 4 + + +def test_get_breaks(): + """Verify the correct calculation of breaks.""" + breaks = get_breaks(0.5, 0, 2) + assert np.array_equal(breaks, np.array([0, 0.5, 1.0, 1.5, 2])) + + +def test_get_labels(): + """Verify correct label generation.""" + labels = get_labels(0.5, 0, 2) + expected_labels = ['0.25', '0.75', '1.25', '1.75'] + assert labels.tolist() == expected_labels + + labels = get_labels(0.999, 0, 2) + expected_labels =['0.4995', '1.4985', '2.4975'] + assert labels.tolist() == expected_labels + + +def test_get_midpoints(): + """Verify correct midpoint generation.""" + midpoints = get_midpoints(0.5, 0, 2) + expected_midpoints = [0.25, 0.75, 1.25, 1.75] + np.testing.assert_allclose(midpoints, expected_midpoints) + + midpoints = get_midpoints(0.999, 0, 2) + expected_midpoints = [0.4995, 1.4985, 2.4975] + np.testing.assert_allclose(midpoints, expected_midpoints) + + +def test_get_breaks_and_labels(): + """Ensure both breaks and labels are returned and accurate.""" + breaks, labels = get_breaks_and_labels(0.5, 0, 2) + assert np.array_equal(breaks, np.array([0, 0.5, 1.0, 1.5, 2])) + assert labels.tolist() == ['0.25', '0.75', '1.25', '1.75'] + + +def test_get_array_combinations(): + x = np.array([1, 2, 3]) + y = np.array([4, 5]) + expected_result = np.array( + [[1, 4], + [2, 4], + [3, 4], + [1, 5], + [2, 5], + [3, 5]]) + np.testing.assert_allclose(get_array_combinations(x,y), expected_result) + +class TestXYPartitioning: + """Tests for the XYPartitioning class.""" + + def test_initialization(self): + """Test proper initialization of XYPartitioning objects.""" + partitioning = XYPartitioning(xbin="xbin", ybin="ybin", size=(1, 2), extent=[0, 10, 0, 10]) + assert partitioning.size == (1, 2) + assert partitioning.partitions == ["xbin", "ybin"] + assert list(partitioning.extent) == [0, 10, 0, 10] + assert partitioning.shape == (10, 5) + assert partitioning.n_partitions == 50 + assert partitioning.n_x == 10 + assert partitioning.n_y == 5 + np.testing.assert_allclose(partitioning.x_breaks, [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) + np.testing.assert_allclose(partitioning.y_breaks, [ 0, 2, 4, 6, 8, 10]) + np.testing.assert_allclose(partitioning.x_midpoints, [0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5]) + np.testing.assert_allclose(partitioning.y_midpoints, [1., 3., 5., 7., 9.]) + assert partitioning.x_labels.tolist() == ['0.5', '1.5', '2.5', '3.5', '4.5', '5.5', '6.5', '7.5', '8.5', '9.5'] + assert partitioning.y_labels.tolist() == ['1.0', '3.0', '5.0', '7.0', '9.0'] + + def test_invalid_initialization(self): + """Test initialization with invalid extent and size.""" + with pytest.raises(ValueError): + XYPartitioning(xbin="xbin", ybin="ybin", size=(0.1, 0.2), extent=[10, 0, 0, 10]) + + with pytest.raises(TypeError): + XYPartitioning(xbin="xbin", ybin="ybin", size="invalid", extent=[0, 10, 0, 10]) + + def test_add_partitions_pandas(self): + """Test valid partitions are added to a pandas dataframe.""" + # Create test dataframe + df = pd.DataFrame({ + 'x': [-0.001, -0.0, 0, 0.5, 1.0, 1.5, 2.0, 2.1, np.nan], + 'y': [-0.001, -0.0, 0, 0.5, 1.0, 1.5, 2.0, 2.1, np.nan], + }) + # Create partitioning + size=(0.5, 0.25) + extent=[0, 2, 0, 2] + partitioning = XYPartitioning(xbin="my_xbin", ybin="my_ybin", + size=size, extent=extent) + # Add partitions + df_out = partitioning.add_partitions(df, x="x", y="y", remove_invalid_rows=True) + + # Test results + expected_xbin = [0.25, 0.25, 0.25, 0.75, 1.25, 1.75] + expected_ybin = [0.125, 0.125, 0.375, 0.875, 1.375, 1.875] + assert df_out["my_xbin"].dtype.name == "category", "X bin are not of categorical type." + assert df_out["my_ybin"].dtype.name == "category", "Y bin are not of categorical type." + assert df_out['my_xbin'].astype(float).tolist() == expected_xbin, "X bin are incorrect." + assert df_out['my_ybin'].astype(float).tolist() == expected_ybin, "Y bin are incorrect." + + def test_add_partitions_polars(self): + """Test valid partitions are added to a polars dataframe.""" + # Create test dataframe + df = pl.DataFrame(pd.DataFrame({ + 'x': [-0.001, -0.0, 0, 0.5, 1.0, 1.5, 2.0, 2.1, np.nan], + 'y': [-0.001, -0.0, 0, 0.5, 1.0, 1.5, 2.0, 2.1, np.nan], + })) + # Create partitioning + size=(0.5, 0.25) + extent=[0, 2, 0, 2] + partitioning = XYPartitioning(xbin="my_xbin", ybin="my_ybin", + size=size, extent=extent) + # Add partitions + df_out = partitioning.add_partitions(df, x="x", y="y", remove_invalid_rows=True) + + # Test results + expected_xbin = [0.25, 0.25, 0.25, 0.75, 1.25, 1.75] + expected_ybin = [0.125, 0.125, 0.375, 0.875, 1.375, 1.875] + assert df_out['my_xbin'].dtype == pl.datatypes.Categorical, "X bin are not of categorical type." + assert df_out['my_ybin'].dtype == pl.datatypes.Categorical, "X bin are not of categorical type." + assert df_out['my_xbin'].cast(float).to_list() == expected_xbin, "X bin are incorrect." + assert df_out['my_ybin'].cast(float).to_list() == expected_ybin, "Y bin are incorrect." + + def test_to_xarray(self): + """Test valid partitions are added to a pandas dataframe.""" + # Create test dataframe + df = pd.DataFrame({ + 'x': [-0.001, -0.0, 0, 0.5, 1.0, 1.5, 2.0, 2.1, np.nan], + 'y': [-0.001, -0.0, 0, 0.5, 1.0, 1.5, 2.0, 2.1, np.nan], + }) + # Create partitioning + size=(0.5, 0.25) + extent=[0, 2, 0, 2] + partitioning = XYPartitioning(xbin="my_xbin", ybin="my_ybin", + size=size, extent=extent) + # Add partitions + df = partitioning.add_partitions(df, x="x", y="y", remove_invalid_rows=True) + + # Aggregate by partitions + df_grouped = df.groupby(partitioning.partitions, observed=True).median() + df_grouped["dummy_var"] = 2 + + # Convert to Dataset + ds = partitioning.to_xarray(df_grouped, new_x="lon", new_y="lat") + + # Test results + expected_xbin = [0.25, 0.75, 1.25, 1.75] + expected_ybin = [0.125, 0.375, 0.625, 0.875, 1.125, 1.375, 1.625, 1.875] + assert isinstance(ds, xr.Dataset), "Not a xr.Dataset" + assert ds["lon"].data.dtype.name != 'object', "xr.Dataset coordinates should not be a string." + assert ds["lat"].data.dtype.name != 'object', "xr.Dataset coordinates should not be a string." + assert ds["lon"].data.dtype.name == 'float64', "xr.Dataset coordinates are not float64." + assert ds["lat"].data.dtype.name == 'float64', "xr.Dataset coordinates are not float64." + assert "dummy_var" in ds, "The x columns has not become a xr.Dataset variable" + + def test_query_labels(self): + """Test valid labels queries.""" + # Create partitioning + size=(0.5, 0.25) + extent=[0, 2, 0, 2] + partitioning = XYPartitioning(xbin="my_xbin", ybin="my_ybin", + size=size, extent=extent) + # Test results + assert partitioning.query_x_labels(1).tolist() == ['0.75'] + assert partitioning.query_y_labels(1).tolist() == ['0.875'] + assert partitioning.query_x_labels(np.array(1)).tolist() == ['0.75'] + assert partitioning.query_x_labels(np.array([1])).tolist() == ['0.75'] + assert partitioning.query_x_labels(np.array([1, 1])).tolist() == ['0.75', '0.75'] + assert partitioning.query_x_labels([1, 1]).tolist() == ['0.75', '0.75'] + + x_labels, y_labels = partitioning.query_labels([1,2], [0,1]) + assert x_labels.tolist() == ['0.75', '1.75'] + assert y_labels.tolist() == ['0.125', '0.875'] + + # Test out of extent + assert partitioning.query_x_labels([-1, 1]).tolist() == ['nan', '0.75'] + + # Test with input nan + assert partitioning.query_x_labels(np.nan).tolist() == ['nan'] + assert partitioning.query_x_labels(None).tolist() == ['nan'] + + # Test with input string + with pytest.raises(ValueError): + partitioning.query_x_labels("dummy") + + def test_query_midpoints(self): + """Test valid midpoint queries.""" + + # Create partitioning + size=(0.5, 0.25) + extent=[0, 2, 0, 2] + partitioning = XYPartitioning(xbin="my_xbin", ybin="my_ybin", + size=size, extent=extent) + # Test results + np.testing.assert_allclose(partitioning.query_x_midpoints(1), [0.75]) + np.testing.assert_allclose(partitioning.query_y_midpoints(1).tolist(), [0.875]) + np.testing.assert_allclose( partitioning.query_x_midpoints(np.array(1)), [0.75]) + np.testing.assert_allclose( partitioning.query_x_midpoints(np.array([1])), [0.75]) + np.testing.assert_allclose( partitioning.query_x_midpoints(np.array([1, 1])), [0.75, 0.75]) + np.testing.assert_allclose( partitioning.query_x_midpoints([1, 1]), [0.75, 0.75]) + + x_midpoints, y_midpoints = partitioning.query_midpoints([1,2], [0,1]) + np.testing.assert_allclose(x_midpoints.tolist(), [0.75, 1.75]) + np.testing.assert_allclose(y_midpoints.tolist(), [0.125, 0.875]) + + # Test out of extent + np.testing.assert_allclose(partitioning.query_x_midpoints([-1, 1]), [np.nan, 0.75]) + + # Test with input nan or None + np.testing.assert_allclose( partitioning.query_x_midpoints(np.nan).tolist(), [np.nan]) + np.testing.assert_allclose( partitioning.query_x_midpoints(None).tolist(), [np.nan]) + + # Test with input string + with pytest.raises(ValueError): + partitioning.query_x_midpoints("dummy") + + def test_get_partitions_by_extent(self): + """Test get_partitions_by_extent.""" + # Create partitioning + size=(0.5, 0.25) + extent=[0, 2, 0, 2] + partitioning = XYPartitioning(xbin="my_xbin", ybin="my_ybin", + size=size, extent=extent) + # Test results with extent within + new_extent = [0, 0.5, 0, 0.5] + labels = partitioning.get_partitions_by_extent(new_extent) + expected_labels = np.array([['0.25', '0.125'], + ['0.25', '0.375']]) + assert expected_labels.tolist() == labels.tolist() + + # Test results with extent outside + new_extent = [3, 4, 3, 4] + labels = partitioning.get_partitions_by_extent(new_extent) + assert labels.size == 0 + + # Test results with extent partially overlapping + new_extent = [1.5, 4, 1.75, 4] # BUG + labels = partitioning.get_partitions_by_extent(new_extent) + expected_labels = np.array([['1.25', '1.625'], + ['1.75', '1.625'], + ['1.25', '1.875'], + ['1.75', '1.875']]) + assert expected_labels.tolist() == labels.tolist() + + def test_get_partitions_around_point(self): + """Test get_partitions_around_point.""" + # Create partitioning + size=(0.5, 0.25) + extent=[0, 2, 0, 2] + partitioning = XYPartitioning(xbin="my_xbin", ybin="my_ybin", + size=size, extent=extent) + # Test results with point within + labels = partitioning.get_partitions_around_point(x=1, y=1, distance=0) + assert labels.tolist() == [['0.75', '0.875']] + + # Test results with point outside + labels = partitioning.get_partitions_around_point(x=3, y=3, distance=0) + assert labels.size == 0 + + # Test results with point outside but area within + labels = partitioning.get_partitions_around_point(x=3, y=3, distance=1) + assert labels.tolist() == [['1.75', '1.875']] + + def test_quadmesh(self): + """Test quadmesh.""" + size=(1, 1) + extent=[0, 2, 1, 3] + partitioning = XYPartitioning(xbin="my_xbin", ybin="my_ybin", + size=size, extent=extent) + # Test results + assert partitioning.quadmesh.shape == (3, 3, 2) + x_mesh = np.array([[0, 1, 2], + [0, 1, 2], + [0, 1, 2]]) + y_mesh = np.array([[1, 1, 1], + [2, 2, 2], + [3, 3, 3]]) + np.testing.assert_allclose(partitioning.quadmesh[:,:, 0], x_mesh) + np.testing.assert_allclose(partitioning.quadmesh[:,:, 1], y_mesh) + + # TODO: origin: y: bottom or upper (RIGHT NOW UPPER !) # BUG: increase by descinding + + + def test_to_dict(self): + # Create partitioning + size=(0.5, 0.25) + extent=[0, 2, 0, 2] + xbin = "my_xbin" + ybin = "my_ybin" + partitioning = XYPartitioning(xbin=xbin, ybin=ybin, + size=size, extent=extent) + # Test results + expected_dict = {"name": "XYPartitioning", + "extent": extent, + "size": size, + "xbin": xbin, + "ybin": ybin} + assert partitioning.to_dict() == expected_dict + + + + + + + \ No newline at end of file diff --git a/gpm/tests/test_config.py b/gpm/tests/test_config.py index 31453a8c..23427f1e 100644 --- a/gpm/tests/test_config.py +++ b/gpm/tests/test_config.py @@ -43,12 +43,12 @@ def test_define_configs(tmp_path, mocker): import gpm # Mock to save config YAML at custom location - config_filepath = str(tmp_path / ".config_gpm.yaml") + config_filepath = str(tmp_path / ".config_gpm_api.yaml") mocker.patch("gpm.configs._define_config_filepath", return_value=config_filepath) # Define config YAML gpm.configs.define_configs(**CONFIGS_TEST_KWARGS) - assert os.path.exists(tmp_path / ".config_gpm.yaml") + assert os.path.exists(tmp_path / ".config_gpm_api.yaml") def test_read_configs(tmp_path, mocker): @@ -56,12 +56,12 @@ def test_read_configs(tmp_path, mocker): from gpm.configs import define_configs, read_configs # Mock to save config YAML at custom location - config_filepath = str(tmp_path / ".config_gpm.yaml") + config_filepath = str(tmp_path / ".config_gpm_api.yaml") mocker.patch("gpm.configs._define_config_filepath", return_value=config_filepath) # Define config YAML define_configs(**CONFIGS_TEST_KWARGS) - assert os.path.exists(tmp_path / ".config_gpm.yaml") + assert os.path.exists(tmp_path / ".config_gpm_api.yaml") # Read config YAML config_dict = read_configs() @@ -76,7 +76,7 @@ def test_update_gpm_configs(tmp_path, mocker): from gpm.utils.yaml import read_yaml # Mock to save config YAML at custom location - config_filepath = str(tmp_path / ".config_gpm.yaml") + config_filepath = str(tmp_path / ".config_gpm_api.yaml") mocker.patch("gpm.configs._define_config_filepath", return_value=config_filepath) # Initialize diff --git a/gpm/tests/test_donfig_config.py b/gpm/tests/test_donfig_config.py index ee28d203..787a4ed5 100644 --- a/gpm/tests/test_donfig_config.py +++ b/gpm/tests/test_donfig_config.py @@ -50,7 +50,7 @@ def test_donfig_takes_config_yaml_file(tmp_path, mocker): import gpm # Mock to save config YAML at custom location - config_fpath = str(tmp_path / ".config_gpm.yaml") + config_fpath = str(tmp_path / ".config_gpm_api.yaml") mocker.patch("gpm.configs._define_config_filepath", return_value=config_fpath) # Initialize config YAML diff --git a/gpm/tests/test_utils/test_geospatial.py b/gpm/tests/test_utils/test_geospatial.py index a54247da..aa5879e6 100644 --- a/gpm/tests/test_utils/test_geospatial.py +++ b/gpm/tests/test_utils/test_geospatial.py @@ -34,6 +34,7 @@ from gpm.tests.utils.fake_datasets import get_grid_dataarray, get_orbit_dataarray from gpm.utils.geospatial import ( adjust_geographic_extent, + check_extent, crop, crop_around_point, crop_by_continent, @@ -46,8 +47,8 @@ get_crop_slices_by_continent, get_crop_slices_by_country, get_crop_slices_by_extent, - get_extent, - get_extent_around_point, + get_geographic_extent_around_point, + get_geographic_extent_from_xarray, unwrap_longitude_degree, ) @@ -116,6 +117,40 @@ def grid_dataarray() -> xr.DataArray: # Tests ######################################################################## +class TestCheckExtent: + """Tests for the check_extent function.""" + + def test_valid_extent(self): + """Test that a valid extent passes without error.""" + assert list(check_extent([-180, 180, -90, 90])) == [-180, 180, -90, 90] + assert list(check_extent([-360, 180, -98, 90])) == [-360, 180, -98, 90] # should not assume lat/lon coords + + def test_invalid_extent_length(self): + """Test that an error is raised when the extent does not contain exactly four elements.""" + with pytest.raises(ValueError) as excinfo: + check_extent([-180, 180, -90]) + assert "four elements" in str(excinfo.value) + + def test_invalid_xmin_greater_than_xmax(self): + """Test that an error is raised when xmin is not less than xmax.""" + with pytest.raises(ValueError) as excinfo: + check_extent([180, -180, -90, 90]) + assert "xmin must be less than xmax" in str(excinfo.value) + + def test_invalid_ymin_greater_than_ymax(self): + """Test that an error is raised when ymin is not less than ymax.""" + with pytest.raises(ValueError) as excinfo: + check_extent([-180, 180, 90, -90]) + assert "ymin must be less than ymax" in str(excinfo.value) + + def test_invalid_numerical_values(self): + """Test that the function handles non-numerical inputs.""" + with pytest.raises(ValueError): + check_extent(["west", "east", "south", "north"]) + with pytest.raises(ValueError): + check_extent([None, None, None, None]) + + class TestAdjustGeographicExtent: """Tests for adjust_geographic_extent function.""" @@ -209,13 +244,13 @@ def test_extend_extent_boundary_limits(self): assert extend_geographic_extent(extent, padding) == expected -class TestGetExtentAroundPoint: - """Class to test get_extent_around_point function.""" +class TestGetGeographicExtentAroundPoint: + """Class to test get_geographic_extent_around_point function.""" def test_with_valid_distance(self): """Test function with a valid distance and no size.""" lon, lat, distance = -123.1207, 49.2827, 10000 - result = get_extent_around_point(lon, lat, distance=distance) + result = get_geographic_extent_around_point(lon, lat, distance=distance) assert isinstance(result, tuple), "Should return a tuple" assert len(result) == 4, "Tuple should have four elements" np.testing.assert_almost_equal(result, [-123.258144, -122.983255, 49.1927835, 49.372615], decimal=6) @@ -224,7 +259,7 @@ def test_with_valid_size(self): """Test function with a valid size and no distance.""" lon, lat = -123.1207, 49.2827 size = (0.1, 0.1) - result = get_extent_around_point(lon, lat, size=size) + result = get_geographic_extent_around_point(lon, lat, size=size) assert isinstance(result, tuple), "Should return a tuple" assert len(result) == 4, "Tuple should have four elements" np.testing.assert_almost_equal(result, [-123.1707, -123.0707, 49.2327, 49.332699], decimal=6) @@ -233,19 +268,19 @@ def test_with_both_distance_and_size(self): """Test function raises ValueError when both distance and size are provided.""" lon, lat, distance, size = -123.1207, 49.2827, 10000, (0.1, 0.1) with pytest.raises(ValueError): - get_extent_around_point(lon, lat, distance=distance, size=size) + get_geographic_extent_around_point(lon, lat, distance=distance, size=size) def test_with_neither_distance_nor_size(self): """Test function raises ValueError when neither distance nor size is provided.""" lon, lat = -123.1207, 49.2827 with pytest.raises(ValueError): - get_extent_around_point(lon, lat) + get_geographic_extent_around_point(lon, lat) def test_with_invalid_size_type(self): """Test function raises TypeError when size is of an invalid type.""" lon, lat = -123.1207, 49.2827 with pytest.raises(TypeError): - get_extent_around_point(lon, lat, size="invalid_size_type") + get_geographic_extent_around_point(lon, lat, size="invalid_size_type") def test_get_country_extent( @@ -306,8 +341,8 @@ def test_get_continent_extent( get_continent_extent(continent) -def test_get_extent() -> None: - """Test get_extent.""" +def test_get_geographic_extent_from_xarray() -> None: + """Test get_geographic_extent_from_xarray.""" ds = xr.Dataset( { "lon": [-10, 0, 20], @@ -317,44 +352,44 @@ def test_get_extent() -> None: # Test without padding expected_extent = (-10, 20, -30, 40) - returned_extent = get_extent(ds) + returned_extent = get_geographic_extent_from_xarray(ds) assert returned_extent == expected_extent # Test with float padding padding = 0.1 expected_extent = (-10.1, 20.1, -30.1, 40.1) - returned_extent = get_extent(ds, padding=padding) + returned_extent = get_geographic_extent_from_xarray(ds, padding=padding) assert returned_extent == expected_extent # Test with size expected_extent = (0, 10, 0, 10) - returned_extent = get_extent(ds, size=10) + returned_extent = get_geographic_extent_from_xarray(ds, size=10) assert returned_extent == expected_extent # Test with padding exceeding bounds padding = 180 expected_extent = (-180, 180, -90, 90) - returned_extent = get_extent(ds, padding=padding) + returned_extent = get_geographic_extent_from_xarray(ds, padding=padding) assert returned_extent == expected_extent # Test with tuple padding padding = (0.1, 0.2) expected_extent = (-10.1, 20.1, -30.2, 40.2) - returned_extent = get_extent(ds, padding=padding) + returned_extent = get_geographic_extent_from_xarray(ds, padding=padding) assert returned_extent == expected_extent padding = (0.1, 0.1, 0.2, 0.2) expected_extent = (-10.1, 20.1, -30.2, 40.2) - returned_extent = get_extent(ds, padding=padding) + returned_extent = get_geographic_extent_from_xarray(ds, padding=padding) assert returned_extent == expected_extent # Test with invalid padding with pytest.raises(TypeError): - get_extent(ds, padding="invalid") + get_geographic_extent_from_xarray(ds, padding="invalid") with pytest.raises(ValueError): - get_extent(ds, padding=(0.1,)) + get_geographic_extent_from_xarray(ds, padding=(0.1,)) with pytest.raises(ValueError): - get_extent(ds, padding=(0.1, 0.2, 0.3)) + get_geographic_extent_from_xarray(ds, padding=(0.1, 0.2, 0.3)) # Test with object crossing dateline ds = xr.Dataset( @@ -364,7 +399,7 @@ def test_get_extent() -> None: }, ) with pytest.raises(NotImplementedError): - get_extent(ds) + get_geographic_extent_from_xarray(ds) def test_get_circle_coordinates_around_point(): diff --git a/gpm/utils/geospatial.py b/gpm/utils/geospatial.py index 1521edf4..1eab344c 100644 --- a/gpm/utils/geospatial.py +++ b/gpm/utils/geospatial.py @@ -133,6 +133,150 @@ def _check_padding(padding: Union[int, float, tuple, list] = 0): return padding +def check_extent(extent): + """ + Validates the extent to ensure it has the correct format and logical consistency. + + Note: this function does not check for the realism of extent values ! + + Parameters + ---------- + extent : list or tuple + The extent specified as [xmin, xmax, ymin, ymax]. + + Returns + ------- + extent: tuple + + """ + if len(extent) != 4: + raise ValueError("Extent must contain exactly four elements: [xmin, xmax, ymin, ymax].") + for v in extent: + if not isinstance(v, (int, float, np.floating, np.integer)): + raise ValueError("The extent must be composed by numeric values.") + if not (extent[0] <= extent[1]): + raise ValueError("xmin must be less than xmax.") + if not (extent[2] <= extent[3]): + raise ValueError("ymin must be less than ymax.") + return Extent(*extent) + + +####------------------------------------------------------------------------------------. +#### Planar Extent + + +def get_extent_around_point(x, y, distance=None, size=None): + """ + Get the extent around a point. + + Either specify ``distance`` or the wished extent ``size`` (in the unit of the extent). + + Parameters + ---------- + x : float + X coordinate of the point. + y : float + Y coordinate of the point. + distance: float + Distance from the point in each direction. + size : int, float, tuple, list + The size of the extent in each direction. + If ``size`` is a single number, the same size is ensured in all directions. + If ``size`` is a tuple or list, it must of size 2 and specifying + the desired size of the extent in the x direction + and the y direction. + + Returns + ------- + tuple + The adjusted extent. + + """ + if distance is not None and size is not None: + raise ValueError("Either provide the 'distance' or the 'size' of the extent.") + if distance is None and size is None: + raise ValueError("Please provide the 'distance' or the 'size' of the extent.") + if size is not None: + return adjust_extent(extent=[x, x, y, y], size=size) + # Calculate new points in the four cardinal directions by the specified distance + extent = [x - distance, x + distance, y - distance, y + distance] + return extend_extent(extent, padding=0) + + +def adjust_extent(extent, size): + """ + Adjust the extent to have the desired size. + + Parameters + ---------- + extent : tuple + A tuple of four values representing the extent. + The extent format must be ``[xmin, xmax, ymin, ymax]``. + size : int, float, tuple, list + The size in degrees of the extent in each direction. + If ``size`` is a single number, the same size is ensured in all directions. + If ``size`` is a tuple or list, it must of size 2 and specifying + the desired size of the extent in the x direction and the y direction. + + Returns + ------- + tuple + The adjusted extent. + + """ + # Retrieve desired size + x_size, y_size = _check_size(size) + + # Retrieve current extent + extent = Extent(*extent) + + # Center of the current extent + x_center = (extent.xmax + extent.xmin) / 2 + y_center = (extent.ymax + extent.ymin) / 2 + + # Define new min and max xgitudes and yitudes + xmin = x_center - x_size / 2 + xmax = x_center + x_size / 2 + ymin = y_center - y_size / 2 + ymax = y_center + y_size / 2 + + return Extent(xmin, xmax, ymin, ymax) + + +def extend_extent(extent, padding: Union[int, float, tuple, list] = 0): + """Extend the extent by padding in every direction. + + Parameters + ---------- + extent : tuple + A tuple of four values representing the extent. + The extent format must be ``[xmin, xmax, ymin, ymax]``. + padding : int, float, tuple, list + The number of degrees to extend the extent in each direction. + If ``padding`` is a single number, the same padding is applied in all directions. + If ``padding`` is a tuple or list, it must contain 2 or 4 elements. + If two values are provided (x, y), they are interpreted as x and y padding, respectively. + If four values are provided, they directly correspond to padding for each side ``(left, right, top, bottom)``. + + Returns + ------- + tuple + The extended extent. + + """ + padding = _check_padding(padding) + xmin, xmax, ymin, ymax = extent + xmin = xmin - padding[0] + xmax = xmax + padding[1] + ymin = ymin - padding[2] + ymax = ymax + padding[3] + return Extent(xmin, xmax, ymin, ymax) + + +####------------------------------------------------------------------------------------. +#### Geographic Extent + + def extend_geographic_extent(extent, padding: Union[int, float, tuple, list] = 0): """Extend the lat/lon extent by x degrees in every direction. @@ -154,12 +298,11 @@ def extend_geographic_extent(extent, padding: Union[int, float, tuple, list] = 0 The extended extent. """ - padding = _check_padding(padding) - xmin, xmax, ymin, ymax = extent - xmin = max(xmin - padding[0], -180) - xmax = min(xmax + padding[1], 180) - ymin = max(ymin - padding[2], -90) - ymax = min(ymax + padding[3], 90) + extent = extend_extent(extent, padding) + xmin = max(extent.xmin, -180) + xmax = min(extent.xmax, 180) + ymin = max(extent.ymin, -90) + ymax = min(extent.ymax, 90) return Extent(xmin, xmax, ymin, ymax) @@ -185,22 +328,7 @@ def adjust_geographic_extent(extent, size): The adjusted extent. """ - # Retrieve current extent - lon_min, lon_max, lat_min, lat_max = extent - - # Retrieve desired size - x_size, y_size = _check_size(size) - - # Center of the current extent - lon_center = (lon_max + lon_min) / 2 - lat_center = (lat_max + lat_min) / 2 - - # Define new min and max longitudes and latitudes - new_lon_min = lon_center - x_size / 2 - new_lon_max = lon_center + x_size / 2 - new_lat_min = lat_center - y_size / 2 - new_lat_max = lat_center + y_size / 2 - + new_lon_min, new_lon_max, new_lat_min, new_lat_max = adjust_extent(extent, size) # Ensure within [-180, 180] longitude extent and of desired size if new_lon_min < -180: new_lon_max = new_lon_max + (new_lon_min + 180) @@ -208,7 +336,6 @@ def adjust_geographic_extent(extent, size): if new_lon_max > 180: new_lon_min = new_lon_min - (new_lon_max - 180) new_lon_max = 180 - # Ensure within [-90, 90] latitude extent and of desired size if new_lat_min < -90: new_lat_max = new_lat_min + (new_lat_min + 90) @@ -219,9 +346,65 @@ def adjust_geographic_extent(extent, size): return extend_geographic_extent([new_lon_min, new_lon_max, new_lat_min, new_lat_max], padding=0) -def get_extent_around_point(lon, lat, distance=None, size=None): +def _is_crossing_dateline(lon: Union[list, np.ndarray]): + """Check if the longitude array is crossing the dateline.""" + lon = np.asarray(lon) + diff = np.diff(lon) + return np.any(np.abs(diff) > 180) + + +def get_geographic_extent_from_xarray( + xr_obj, + padding: Union[int, float, tuple, list] = 0, + size: Optional[Union[int, float, tuple, list]] = None, +): + """Get the geographic extent from an xarray object. + + Parameters + ---------- + xr_obj : `xarray.DataArray` or `xarray.Dataset` + xarray object. + padding : int, float, tuple, list + The number of degrees to extend the extent in each direction. + If ``padding`` is a single number, the same padding is applied in all directions. + If ``padding`` is a tuple or list, it must contain 2 or 4 elements. + If two values are provided (x, y), they are interpreted as longitude and latitude padding, respectively. + If four values are provided, they directly correspond to padding for each side ``(left, right, top, bottom)``. + The default is ``0``. + size : int, float, tuple, list + The desired size in degrees of the extent in each direction. + If ``size`` is a single number, the same size is enforced in all directions. + If ``size`` is a tuple or list, it must of size 2 and specify the desired size of + the extent in the x direction (longitude) and the y direction (latitude). + The default is ``None``. + + Returns + ------- + extent : tuple + A tuple containing the longitude and latitude extent of the xarray object. + The extent follows the matplotlib/cartopy format ``(xmin, xmax, ymin, ymax)``. + """ - Get the extent around a point. + # TODO: should compute the corners and return based on the sides + padding = _check_padding(padding=padding) + + lon = xr_obj["lon"].to_numpy() + lat = xr_obj["lat"].to_numpy() + + if _is_crossing_dateline(lon): + raise NotImplementedError( + "The object cross the dateline. The extent can't be currently defined.", + ) + extent = Extent(np.nanmin(lon), np.nanmax(lon), np.nanmin(lat), np.nanmax(lat)) + extent = extend_geographic_extent(extent, padding=padding) + if size is not None: + extent = adjust_geographic_extent(extent, size=size) + return extent + + +def get_geographic_extent_around_point(lon, lat, distance=None, size=None): + """ + Get the geographic extent around a point. Either specify ``distance`` (in meters) or the wished extent ``size`` (in degrees). @@ -264,61 +447,42 @@ def get_extent_around_point(lon, lat, distance=None, size=None): return extend_geographic_extent(extent, padding=0) -def get_circle_coordinates_around_point(lon, lat, radius, num_vertices=360): - """Get the coordinates of a circle with custom radius around a point. - - Parameters - ---------- - lon : float - Longitude of the point. - lat : float - Latitude of the point. - radius : float - Radius (in meters) around the point. - num_vertices : int, optional - Number of circle coordinates to return. The default is 360. +def read_countries_extent_dictionary(): + """Reads a YAML file containing countries extent information and returns it as a dictionary. Returns ------- - lons : `numpy.ndarray` - Longitude vertices of the circle around the point. - lats : `numpy.ndarray` - Latitude vertices of the circle around the point. + dict + A dictionary containing countries extent information. """ - geod = pyproj.Geod(ellps="WGS84") - - # Angle between each point in degrees - angles = np.linspace(0, 360, num_vertices, endpoint=False) - - # Compute the coordinates of the circle's vertices - lons, lats, _ = geod.fwd( - np.ones(angles.shape) * lon, - np.ones(angles.shape) * lat, - angles, - np.ones(angles.shape) * radius, - radians=False, + countries_extent_filepath = os.path.join( + _root_path, + "gpm", + "etc", + "geospatial", + "country_extent.yaml", ) - return lons, lats + return read_yaml(countries_extent_filepath) -def read_countries_extent_dictionary(): - """Reads a YAML file containing countries extent information and returns it as a dictionary. +def read_continents_extent_dictionary(): + """Read and return a dictionary containing the extents of continents. Returns ------- dict - A dictionary containing countries extent information. + A dictionary containing the extents of continents. """ - countries_extent_filepath = os.path.join( + continents_extent_filepath = os.path.join( _root_path, "gpm", "etc", "geospatial", - "country_extent.yaml", + "continent_extent.yaml", ) - return read_yaml(countries_extent_filepath) + return read_yaml(continents_extent_filepath) def get_country_extent(name, padding=0.2): @@ -377,25 +541,6 @@ def get_country_extent(name, padding=0.2): raise ValueError(f"No matching country. Maybe are you looking for '{possible_match}'?") -def read_continents_extent_dictionary(): - """Read and return a dictionary containing the extents of continents. - - Returns - ------- - dict - A dictionary containing the extents of continents. - - """ - continents_extent_filepath = os.path.join( - _root_path, - "gpm", - "etc", - "geospatial", - "continent_extent.yaml", - ) - return read_yaml(continents_extent_filepath) - - def get_continent_extent(name: str, padding: Union[int, float, tuple, list] = 0): """Retrieves the extent of a continent. @@ -446,67 +591,8 @@ def get_continent_extent(name: str, padding: Union[int, float, tuple, list] = 0) raise ValueError(f"No matching continent. Maybe are you looking for '{possible_match}'?") -def unwrap_longitude_degree(x, period=360): - """Unwrap longitude array.""" - x = np.asarray(x) - mod = period / 2 - return (x + mod) % (2 * mod) - mod - - -def _is_crossing_dateline(lon: Union[list, np.ndarray]): - """Check if the longitude array is crossing the dateline.""" - lon = np.asarray(lon) - diff = np.diff(lon) - return np.any(np.abs(diff) > 180) - - -def get_extent( - xr_obj, - padding: Union[int, float, tuple, list] = 0, - size: Optional[Union[int, float, tuple, list]] = None, -): - """Get the geographic extent from an xarray object. - - Parameters - ---------- - xr_obj : `xarray.DataArray` or `xarray.Dataset` - xarray object. - padding : int, float, tuple, list - The number of degrees to extend the extent in each direction. - If ``padding`` is a single number, the same padding is applied in all directions. - If ``padding`` is a tuple or list, it must contain 2 or 4 elements. - If two values are provided (x, y), they are interpreted as longitude and latitude padding, respectively. - If four values are provided, they directly correspond to padding for each side ``(left, right, top, bottom)``. - The default is ``0``. - size : int, float, tuple, list - The desired size in degrees of the extent in each direction. - If ``size`` is a single number, the same size is enforced in all directions. - If ``size`` is a tuple or list, it must of size 2 and specify the desired size of - the extent in the x direction (longitude) and the y direction (latitude). - The default is ``None``. - - Returns - ------- - extent : tuple - A tuple containing the longitude and latitude extent of the xarray object. - The extent follows the matplotlib/cartopy format ``(xmin, xmax, ymin, ymax)``. - - """ - # TODO: should compute the corners and return based on the sides - padding = _check_padding(padding=padding) - - lon = xr_obj["lon"].to_numpy() - lat = xr_obj["lat"].to_numpy() - - if _is_crossing_dateline(lon): - raise NotImplementedError( - "The object cross the dateline. The extent can't be currently defined.", - ) - extent = Extent(np.nanmin(lon), np.nanmax(lon), np.nanmin(lat), np.nanmax(lat)) - extent = extend_geographic_extent(extent, padding=padding) - if size is not None: - extent = adjust_geographic_extent(extent, size=size) - return extent +####------------------------------------------------------------------------------------. +#### Geographic crop def crop(xr_obj, extent): @@ -610,7 +696,7 @@ def crop_around_point(xr_obj, lon: float, lat: float, distance=None, size=None): Cropped xarray object. """ - extent = get_extent_around_point(lon=lon, lat=lat, distance=distance, size=size) + extent = get_geographic_extent_around_point(lon=lon, lat=lat, distance=distance, size=size) return crop(xr_obj=xr_obj, extent=extent) @@ -722,5 +808,54 @@ def get_crop_slices_around_point(xr_obj, lon: float, lat: float, distance=None, Cropped xarray object. """ - extent = get_extent_around_point(lon=lon, lat=lat, distance=distance, size=size) + extent = get_geographic_extent_around_point(lon=lon, lat=lat, distance=distance, size=size) return get_crop_slices_by_extent(xr_obj=xr_obj, extent=extent) + + +####------------------------------------------------------------------------------------. +#### Miscellaneous + + +def unwrap_longitude_degree(x, period=360): + """Unwrap longitude array.""" + x = np.asarray(x) + mod = period / 2 + return (x + mod) % (2 * mod) - mod + + +def get_circle_coordinates_around_point(lon, lat, radius, num_vertices=360): + """Get the coordinates of a circle with custom radius around a point. + + Parameters + ---------- + lon : float + Longitude of the point. + lat : float + Latitude of the point. + radius : float + Radius (in meters) around the point. + num_vertices : int, optional + Number of circle coordinates to return. The default is 360. + + Returns + ------- + lons : `numpy.ndarray` + Longitude vertices of the circle around the point. + lats : `numpy.ndarray` + Latitude vertices of the circle around the point. + + """ + geod = pyproj.Geod(ellps="WGS84") + + # Angle between each point in degrees + angles = np.linspace(0, 360, num_vertices, endpoint=False) + + # Compute the coordinates of the circle's vertices + lons, lats, _ = geod.fwd( + np.ones(angles.shape) * lon, + np.ones(angles.shape) * lat, + angles, + np.ones(angles.shape) * radius, + radians=False, + ) + return lons, lats