diff --git a/pysteps/nowcasts/steps.py b/pysteps/nowcasts/steps.py index ed1cd43c..dedbb726 100644 --- a/pysteps/nowcasts/steps.py +++ b/pysteps/nowcasts/steps.py @@ -26,7 +26,7 @@ from pysteps.nowcasts.utils import compute_percentile_mask, nowcast_main_loop from dataclasses import dataclass, field -from typing import Optional, Dict, Any, Callable, List +from typing import Any, Callable try: import dask @@ -40,30 +40,30 @@ class StepsNowcasterConfig: n_ens_members: int = 24 n_cascade_levels: int = 6 - precip_threshold: Optional[float] = None - kmperpixel: Optional[float] = None - timestep: Optional[float] = None + precip_threshold: float | None = None + kmperpixel: float | None = None + timestep: float | None = None extrapolation_method: str = "semilagrangian" decomposition_method: str = "fft" bandpass_filter_method: str = "gaussian" - noise_method: Optional[str] = "nonparametric" - noise_stddev_adj: Optional[str] = None + noise_method: str | None = "nonparametric" + noise_stddev_adj: str | None = None ar_order: int = 2 - velocity_perturbation_method: Optional[str] = "bps" + velocity_perturbation_method: str | None = "bps" conditional: bool = False - probmatching_method: Optional[str] = "cdf" - mask_method: Optional[str] = "incremental" - seed: Optional[int] = None + probmatching_method: str | None = "cdf" + mask_method: str | None = "incremental" + seed: int | None = None num_workers: int = 1 fft_method: str = "numpy" domain: str = "spatial" - extrapolation_kwargs: Dict[str, Any] = field(default_factory=dict) - filter_kwargs: Dict[str, Any] = field(default_factory=dict) - noise_kwargs: Dict[str, Any] = field(default_factory=dict) - velocity_perturbation_kwargs: Dict[str, Any] = field(default_factory=dict) - mask_kwargs: Dict[str, Any] = field(default_factory=dict) + extrapolation_kwargs: dict[str, Any] = field(default_factory=dict) + filter_kwargs: dict[str, Any] = field(default_factory=dict) + noise_kwargs: dict[str, Any] = field(default_factory=dict) + velocity_perturbation_kwargs: dict[str, Any] = field(default_factory=dict) + mask_kwargs: dict[str, Any] = field(default_factory=dict) measure_time: bool = False - callback: Optional[Callable[[Any], None]] = None + callback: Callable[[Any], None] | None = None return_output: bool = True @@ -74,218 +74,228 @@ class StepsNowcasterParams: extrapolation_method: Any = None decomposition_method: Any = None recomposition_method: Any = None - noise_generator: Optional[callable] = None - perturbation_generator: Optional[callable] = None - noise_std_coefficients: Optional[np.ndarray] = None - ar_model_coefficients: Optional[np.ndarray] = None # Corresponds to phi - autocorrelation_coefficients: Optional[np.ndarray] = None # Corresponds to gamma - domain_mask: Optional[np.ndarray] = None - structuring_element: Optional[np.ndarray] = None - precipitation_mean: Optional[float] = None - wet_area_ratio: Optional[float] = None - mask_rim: Optional[int] = None + noise_generator: Callable | None = None + perturbation_generator: Callable | None = None + noise_std_coefficients: np.ndarray | None = None + ar_model_coefficients: np.ndarray | None = None # Corresponds to phi + autocorrelation_coefficients: np.ndarray | None = None # Corresponds to gamma + domain_mask: np.ndarray | None = None + structuring_element: np.ndarray | None = None + precipitation_mean: float | None = None + wet_area_ratio: float | None = None + mask_rim: int | None = None num_ensemble_workers: int = 1 - xy_coordinates: Optional[np.ndarray] = None - velocity_perturbation_parallel: Optional[List[float]] = None - velocity_perturbation_perpendicular: Optional[List[float]] = None + xy_coordinates: np.ndarray | None = None + velocity_perturbation_parallel: list[float] | None = None + velocity_perturbation_perpendicular: list[float] | None = None @dataclass class StepsNowcasterState: - precip_forecast: Optional[List[Any]] = field(default_factory=list) - precip_cascades: Optional[List[List[np.ndarray]]] = field(default_factory=list) - precip_decomposed: Optional[List[Dict[str, Any]]] = field(default_factory=list) + precip_forecast: list[Any] | None = field(default_factory=list) + precip_cascades: list[list[np.ndarray]] | None = field(default_factory=list) + precip_decomposed: list[dict[str, Any]] | None = field(default_factory=list) # The observation mask (where the radar can observe the precipitation) - precip_mask: Optional[List[Any]] = field(default_factory=list) - precip_mask_decomposed: Optional[Dict[str, Any]] = field(default_factory=dict) + precip_mask: list[Any] | None = field(default_factory=list) + precip_mask_decomposed: dict[str, Any] | None = field(default_factory=dict) # The mask around the precipitation fields (to get only non-zero values) - mask_precip: Optional[np.ndarray] = None - mask_threshold: Optional[np.ndarray] = None - random_generator_precip: Optional[List[np.random.RandomState]] = field( + mask_precip: np.ndarray | None = None + mask_threshold: np.ndarray | None = None + random_generator_precip: list[np.random.RandomState] | None = field( default_factory=list ) - random_generator_motion: Optional[List[np.random.RandomState]] = field( + random_generator_motion: list[np.random.RandomState] | None = field( default_factory=list ) - velocity_perturbations: Optional[List[callable]] = field(default_factory=list) - fft_objects: Optional[List[Any]] = field(default_factory=list) + velocity_perturbations: list[Callable] | None = field(default_factory=list) + fft_objects: list[Any] | None = field(default_factory=list) class StepsNowcaster: - def __init__(self, precip, velocity, time_steps, steps_config: StepsNowcasterConfig): + def __init__( + self, precip, velocity, time_steps, steps_config: StepsNowcasterConfig + ): # Store inputs and optional parameters - self.precip = precip - self.velocity = velocity - self.time_steps = time_steps + self.__precip = precip + self.__velocity = velocity + self.__time_steps = time_steps # Store the config data: - self.config = steps_config + self.__config = steps_config # Store the state and params data: - self.state = StepsNowcasterState() - self.params = StepsNowcasterParams() + self.__state = StepsNowcasterState() + self.__params = StepsNowcasterParams() # Additional variables for time measurement - self.start_time_init = None - self.init_time = None - self.mainloop_time = None + self.__start_time_init = None + self.__init_time = None + self.__mainloop_time = None def compute_forecast(self): """ Main loop for nowcast ensemble generation. This handles extrapolation, noise application, autoregressive modeling, and recomposition of cascades. """ - self._check_inputs() - self._print_forecast_info() + self.__check_inputs() + self.__print_forecast_info() # Measure time for initialization - if self.config.measure_time: - self.start_time_init = time.time() + if self.__config.measure_time: + self.__start_time_init = time.time() - self._initialize_nowcast_components() + self.__initialize_nowcast_components() # Slice the precipitation field to only use the last ar_order + 1 fields - self.precip = self.precip[-(self.config.ar_order + 1) :, :, :].copy() + self.__precip = self.__precip[-(self.__config.ar_order + 1) :, :, :].copy() - self._perform_extrapolation() - self._apply_noise_and_ar_model() - self._initialize_velocity_perturbations() - self._initialize_precipitation_mask() - self._initialize_fft_objects() + self.__perform_extrapolation() + self.__apply_noise_and_ar_model() + self.__initialize_velocity_perturbations() + self.__initialize_precipitation_mask() + self.__initialize_fft_objects() # Measure and print initialization time - if self.config.measure_time: - self._measure_time("Initialization", self.start_time_init) + if self.__config.measure_time: + self.__measure_time("Initialization", self.__start_time_init) # Run the main nowcast loop - self._nowcast_main() + self.__nowcast_main() - if self.config.measure_time: - self.state.precip_forecast, self.mainloop_time = self.state.precip_forecast + if self.__config.measure_time: + self.__state.precip_forecast, self.__mainloop_time = ( + self.__state.precip_forecast + ) # Stack and return the forecast output - if self.config.return_output: - self.state.precip_forecast = np.stack( + if self.__config.return_output: + self.__state.precip_forecast = np.stack( [ - np.stack(self.state.precip_forecast[j]) - for j in range(self.config.n_ens_members) + np.stack(self.__state.precip_forecast[j]) + for j in range(self.__config.n_ens_members) ] ) - if self.config.measure_time: - return self.state.precip_forecast, self.init_time, self.mainloop_time + if self.__config.measure_time: + return ( + self.__state.precip_forecast, + self.__init_time, + self.__mainloop_time, + ) else: - return self.state.precip_forecast + return self.__state.precip_forecast else: return None - def _nowcast_main(self): + def __nowcast_main(self): """ Main nowcast loop that iterates through the ensemble members and time steps to generate forecasts. """ # Isolate the last time slice of precipitation - precip = self.precip[-1, :, :] # Extract the last available precipitation field + precip = self.__precip[ + -1, :, : + ] # Extract the last available precipitation field # Prepare state and params dictionaries, these need to be formatted a specific way for the nowcast_main_loop - state = self._initialize_state() - params = self._initialize_params(precip) + state = self.__initialize_state() + params = self.__initialize_params(precip) print("Starting nowcast computation.") # Run the nowcast main loop - self.state.precip_forecast = nowcast_main_loop( + self.__state.precip_forecast = nowcast_main_loop( precip, - self.velocity, + self.__velocity, state, - self.time_steps, - self.config.extrapolation_method, - self._update_state, # Reference to the update function - extrap_kwargs=self.config.extrapolation_kwargs, - velocity_pert_gen=self.state.velocity_perturbations, + self.__time_steps, + self.__config.extrapolation_method, + self.__update_state, # Reference to the update function + extrap_kwargs=self.__config.extrapolation_kwargs, + velocity_pert_gen=self.__state.velocity_perturbations, params=params, ensemble=True, - num_ensemble_members=self.config.n_ens_members, - callback=self.config.callback, - return_output=self.config.return_output, - num_workers=self.params.num_ensemble_workers, - measure_time=self.config.measure_time, + num_ensemble_members=self.__config.n_ens_members, + callback=self.__config.callback, + return_output=self.__config.return_output, + num_workers=self.__params.num_ensemble_workers, + measure_time=self.__config.measure_time, ) - def _check_inputs(self): + def __check_inputs(self): """ Validate the inputs to ensure consistency and correct shapes. """ - if self.precip.ndim != 3: + if self.__precip.ndim != 3: raise ValueError("precip must be a three-dimensional array") - if self.precip.shape[0] < self.config.ar_order + 1: + if self.__precip.shape[0] < self.__config.ar_order + 1: raise ValueError( f"precip.shape[0] must be at least ar_order+1, " - f"but found {self.precip.shape[0]}" + f"but found {self.__precip.shape[0]}" ) - if self.velocity.ndim != 3: + if self.__velocity.ndim != 3: raise ValueError("velocity must be a three-dimensional array") - if self.precip.shape[1:3] != self.velocity.shape[1:3]: + if self.__precip.shape[1:3] != self.__velocity.shape[1:3]: raise ValueError( f"Dimension mismatch between precip and velocity: " - f"shape(precip)={self.precip.shape}, shape(velocity)={self.velocity.shape}" + f"shape(precip)={self.__precip.shape}, shape(velocity)={self.__velocity.shape}" ) if ( - isinstance(self.time_steps, list) - and not sorted(self.time_steps) == self.time_steps + isinstance(self.__time_steps, list) + and not sorted(self.__time_steps) == self.__time_steps ): raise ValueError("timesteps must be in ascending order") - if np.any(~np.isfinite(self.velocity)): + if np.any(~np.isfinite(self.__velocity)): raise ValueError("velocity contains non-finite values") - if self.config.mask_method not in ["obs", "sprog", "incremental", None]: + if self.__config.mask_method not in ["obs", "sprog", "incremental", None]: raise ValueError( - f"Unknown mask method '{self.config.mask_method}'. " + f"Unknown mask method '{self.__config.mask_method}'. " "Must be 'obs', 'sprog', 'incremental', or None." ) - if self.config.precip_threshold is None: - if self.config.conditional: + if self.__config.precip_threshold is None: + if self.__config.conditional: raise ValueError("conditional=True but precip_thr is not specified.") - if self.config.mask_method is not None: + if self.__config.mask_method is not None: raise ValueError("mask_method is set but precip_thr is not specified.") - if self.config.probmatching_method == "mean": + if self.__config.probmatching_method == "mean": raise ValueError( "probmatching_method='mean' but precip_thr is not specified." ) if ( - self.config.noise_method is not None - and self.config.noise_stddev_adj == "auto" + self.__config.noise_method is not None + and self.__config.noise_stddev_adj == "auto" ): raise ValueError( "noise_stddev_adj='auto' but precip_thr is not specified." ) - if self.config.noise_stddev_adj not in ["auto", "fixed", None]: + if self.__config.noise_stddev_adj not in ["auto", "fixed", None]: raise ValueError( - f"Unknown noise_stddev_adj method '{self.config.noise_stddev_adj}'. " + f"Unknown noise_stddev_adj method '{self.__config.noise_stddev_adj}'. " "Must be 'auto', 'fixed', or None." ) - if self.config.kmperpixel is None: - if self.config.velocity_perturbation_method is not None: + if self.__config.kmperpixel is None: + if self.__config.velocity_perturbation_method is not None: raise ValueError("vel_pert_method is set but kmperpixel=None") - if self.config.mask_method == "incremental": + if self.__config.mask_method == "incremental": raise ValueError("mask_method='incremental' but kmperpixel=None") - if self.config.timestep is None: - if self.config.velocity_perturbation_method is not None: + if self.__config.timestep is None: + if self.__config.velocity_perturbation_method is not None: raise ValueError("vel_pert_method is set but timestep=None") - if self.config.mask_method == "incremental": + if self.__config.mask_method == "incremental": raise ValueError("mask_method='incremental' but timestep=None") # Handle None values for various kwargs - if self.config.extrapolation_kwargs is None: - self.config.extrapolation_kwargs = {} - if self.config.filter_kwargs is None: - self.config.filter_kwargs = {} - if self.config.noise_kwargs is None: - self.config.noise_kwargs = {} - if self.config.velocity_perturbation_kwargs is None: - self.config.velocity_perturbation_kwargs = {} - if self.config.mask_kwargs is None: - self.config.mask_kwargs = {} + if self.__config.extrapolation_kwargs is None: + self.__config.extrapolation_kwargs = {} + if self.__config.filter_kwargs is None: + self.__config.filter_kwargs = {} + if self.__config.noise_kwargs is None: + self.__config.noise_kwargs = {} + if self.__config.velocity_perturbation_kwargs is None: + self.__config.velocity_perturbation_kwargs = {} + if self.__config.mask_kwargs is None: + self.__config.mask_kwargs = {} print("Inputs validated and initialized successfully.") - def _print_forecast_info(self): + def __print_forecast_info(self): """ Print information about the forecast setup, including inputs, methods, and parameters. """ @@ -295,499 +305,505 @@ def _print_forecast_info(self): print("Inputs") print("------") - print(f"input dimensions: {self.precip.shape[1]}x{self.precip.shape[2]}") - if self.config.kmperpixel is not None: - print(f"km/pixel: {self.config.kmperpixel}") - if self.config.timestep is not None: - print(f"time step: {self.config.timestep} minutes") + print(f"input dimensions: {self.__precip.shape[1]}x{self.__precip.shape[2]}") + if self.__config.kmperpixel is not None: + print(f"km/pixel: {self.__config.kmperpixel}") + if self.__config.timestep is not None: + print(f"time step: {self.__config.timestep} minutes") print("") print("Methods") print("-------") - print(f"extrapolation: {self.config.extrapolation_method}") - print(f"bandpass filter: {self.config.bandpass_filter_method}") - print(f"decomposition: {self.config.decomposition_method}") - print(f"noise generator: {self.config.noise_method}") + print(f"extrapolation: {self.__config.extrapolation_method}") + print(f"bandpass filter: {self.__config.bandpass_filter_method}") + print(f"decomposition: {self.__config.decomposition_method}") + print(f"noise generator: {self.__config.noise_method}") print( "noise adjustment: {}".format( - ("yes" if self.config.noise_stddev_adj else "no") + ("yes" if self.__config.noise_stddev_adj else "no") ) ) - print(f"velocity perturbator: {self.config.velocity_perturbation_method}") + print(f"velocity perturbator: {self.__config.velocity_perturbation_method}") print( "conditional statistics: {}".format( - ("yes" if self.config.conditional else "no") + ("yes" if self.__config.conditional else "no") ) ) - print(f"precip. mask method: {self.config.mask_method}") - print(f"probability matching: {self.config.probmatching_method}") - print(f"FFT method: {self.config.fft_method}") - print(f"domain: {self.config.domain}") + print(f"precip. mask method: {self.__config.mask_method}") + print(f"probability matching: {self.__config.probmatching_method}") + print(f"FFT method: {self.__config.fft_method}") + print(f"domain: {self.__config.domain}") print("") print("Parameters") print("----------") - if isinstance(self.time_steps, int): - print(f"number of time steps: {self.time_steps}") + if isinstance(self.__time_steps, int): + print(f"number of time steps: {self.__time_steps}") else: - print(f"time steps: {self.time_steps}") - print(f"ensemble size: {self.config.n_ens_members}") - print(f"parallel threads: {self.config.num_workers}") - print(f"number of cascade levels: {self.config.n_cascade_levels}") - print(f"order of the AR(p) model: {self.config.ar_order}") - - if self.config.velocity_perturbation_method == "bps": - self.params.velocity_perturbation_parallel = ( - self.config.velocity_perturbation_kwargs.get( + print(f"time steps: {self.__time_steps}") + print(f"ensemble size: {self.__config.n_ens_members}") + print(f"parallel threads: {self.__config.num_workers}") + print(f"number of cascade levels: {self.__config.n_cascade_levels}") + print(f"order of the AR(p) model: {self.__config.ar_order}") + + if self.__config.velocity_perturbation_method == "bps": + self.__params.velocity_perturbation_parallel = ( + self.__config.velocity_perturbation_kwargs.get( "p_par", noise.motion.get_default_params_bps_par() ) ) - self.params.velocity_perturbation_perpendicular = ( - self.config.velocity_perturbation_kwargs.get( + self.__params.velocity_perturbation_perpendicular = ( + self.__config.velocity_perturbation_kwargs.get( "p_perp", noise.motion.get_default_params_bps_perp() ) ) print( - f"velocity perturbations, parallel: {self.params.velocity_perturbation_parallel[0]},{self.params.velocity_perturbation_parallel[1]},{self.params.velocity_perturbation_parallel[2]}" + f"velocity perturbations, parallel: {self.__params.velocity_perturbation_parallel[0]},{self.__params.velocity_perturbation_parallel[1]},{self.__params.velocity_perturbation_parallel[2]}" ) print( - f"velocity perturbations, perpendicular: {self.params.velocity_perturbation_perpendicular[0]},{self.params.velocity_perturbation_perpendicular[1]},{self.params.velocity_perturbation_perpendicular[2]}" + f"velocity perturbations, perpendicular: {self.__params.velocity_perturbation_perpendicular[0]},{self.__params.velocity_perturbation_perpendicular[1]},{self.__params.velocity_perturbation_perpendicular[2]}" ) - if self.config.precip_threshold is not None: - print(f"precip. intensity threshold: {self.config.precip_threshold}") + if self.__config.precip_threshold is not None: + print(f"precip. intensity threshold: {self.__config.precip_threshold}") - def _initialize_nowcast_components(self): + def __initialize_nowcast_components(self): """ Initialize the FFT, bandpass filters, decomposition methods, and extrapolation method. """ # Initialize number of ensemble workers - self.params.num_ensemble_workers = min( - self.config.n_ens_members, self.config.num_workers + self.__params.num_ensemble_workers = min( + self.__config.n_ens_members, self.__config.num_workers ) - M, N = self.precip.shape[1:] # Extract the spatial dimensions (height, width) + M, N = self.__precip.shape[1:] # Extract the spatial dimensions (height, width) # Initialize FFT method - self.params.fft = utils.get_method( - self.config.fft_method, shape=(M, N), n_threads=self.config.num_workers + self.__params.fft = utils.get_method( + self.__config.fft_method, shape=(M, N), n_threads=self.__config.num_workers ) # Initialize the band-pass filter for the cascade decomposition - filter_method = cascade.get_method(self.config.bandpass_filter_method) - self.params.bandpass_filter = filter_method( - (M, N), self.config.n_cascade_levels, **(self.config.filter_kwargs or {}) + filter_method = cascade.get_method(self.__config.bandpass_filter_method) + self.__params.bandpass_filter = filter_method( + (M, N), + self.__config.n_cascade_levels, + **(self.__config.filter_kwargs or {}), ) # Get the decomposition method (e.g., FFT) - self.params.decomposition_method, self.params.recomposition_method = ( - cascade.get_method(self.config.decomposition_method) + self.__params.decomposition_method, self.__params.recomposition_method = ( + cascade.get_method(self.__config.decomposition_method) ) # Get the extrapolation method (e.g., semilagrangian) - self.params.extrapolation_method = extrapolation.get_method( - self.config.extrapolation_method + self.__params.extrapolation_method = extrapolation.get_method( + self.__config.extrapolation_method ) # Generate the mesh grid for spatial coordinates x_values, y_values = np.meshgrid(np.arange(N), np.arange(M)) - self.params.xy_coordinates = np.stack([x_values, y_values]) + self.__params.xy_coordinates = np.stack([x_values, y_values]) # Determine the domain mask from non-finite values in the precipitation data - self.params.domain_mask = np.logical_or.reduce( - [~np.isfinite(self.precip[i, :]) for i in range(self.precip.shape[0])] + self.__params.domain_mask = np.logical_or.reduce( + [~np.isfinite(self.__precip[i, :]) for i in range(self.__precip.shape[0])] ) print("Nowcast components initialized successfully.") - def _perform_extrapolation(self): + def __perform_extrapolation(self): """ Extrapolate (advect) precipitation fields based on the velocity field to align them in time. This prepares the precipitation fields for autoregressive modeling. """ # Determine the precipitation threshold mask if conditional is set - if self.config.conditional: - self.state.mask_threshold = np.logical_and.reduce( + if self.__config.conditional: + self.__state.mask_threshold = np.logical_and.reduce( [ - self.precip[i, :, :] >= self.config.precip_threshold - for i in range(self.precip.shape[0]) + self.__precip[i, :, :] >= self.__config.precip_threshold + for i in range(self.__precip.shape[0]) ] ) else: - self.state.mask_threshold = None + self.__state.mask_threshold = None - extrap_kwargs = self.config.extrapolation_kwargs.copy() - extrap_kwargs["xy_coords"] = self.params.xy_coordinates + extrap_kwargs = self.__config.extrapolation_kwargs.copy() + extrap_kwargs["xy_coords"] = self.__params.xy_coordinates extrap_kwargs["allow_nonfinite_values"] = ( - True if np.any(~np.isfinite(self.precip)) else False + True if np.any(~np.isfinite(self.__precip)) else False ) res = [] - def _extrapolate_single_field(precip, i): + def __extrapolate_single_field(precip, i): # Extrapolate a single precipitation field using the velocity field - return self.params.extrapolation_method( + return self.__params.extrapolation_method( precip[i, :, :], - self.velocity, - self.config.ar_order - i, + self.__velocity, + self.__config.ar_order - i, "min", **extrap_kwargs, )[-1] - for i in range(self.config.ar_order): + for i in range(self.__config.ar_order): if ( not DASK_IMPORTED ): # If Dask is not available, perform sequential extrapolation - self.precip[i, :, :] = _extrapolate_single_field(self.precip, i) + self.__precip[i, :, :] = __extrapolate_single_field(self.__precip, i) else: # If Dask is available, accumulate delayed computations for parallel execution - res.append(dask.delayed(_extrapolate_single_field)(self.precip, i)) + res.append(dask.delayed(__extrapolate_single_field)(self.__precip, i)) # If Dask is available, perform the parallel computation if DASK_IMPORTED and res: - num_workers_ = min(self.params.num_ensemble_workers, len(res)) - self.precip = np.stack( + num_workers_ = min(self.__params.num_ensemble_workers, len(res)) + self.__precip = np.stack( list(dask.compute(*res, num_workers=num_workers_)) - + [self.precip[-1, :, :]] + + [self.__precip[-1, :, :]] ) print("Extrapolation complete and precipitation fields aligned.") - def _apply_noise_and_ar_model(self): + def __apply_noise_and_ar_model(self): """ Apply noise and autoregressive (AR) models to precipitation cascades. This method applies the AR model to the decomposed precipitation cascades and adds noise perturbations if necessary. """ # Make a copy of the precipitation data and replace non-finite values - precip = self.precip.copy() - for i in range(self.precip.shape[0]): + precip = self.__precip.copy() + for i in range(self.__precip.shape[0]): # Replace non-finite values with the minimum finite value of the precipitation field precip[i, ~np.isfinite(precip[i, :])] = np.nanmin(precip[i, :]) # Store the precipitation data back in the object - self.precip = precip + self.__precip = precip # Initialize the noise generator if the noise_method is provided - if self.config.noise_method is not None: - np.random.seed(self.config.seed) # Set the random seed for reproducibility - init_noise, generate_noise = noise.get_method(self.config.noise_method) - self.params.noise_generator = generate_noise - - self.params.perturbation_generator = init_noise( - self.precip, fft_method=self.params.fft, **self.config.noise_kwargs + if self.__config.noise_method is not None: + np.random.seed( + self.__config.seed + ) # Set the random seed for reproducibility + init_noise, generate_noise = noise.get_method(self.__config.noise_method) + self.__params.noise_generator = generate_noise + + self.__params.perturbation_generator = init_noise( + self.__precip, + fft_method=self.__params.fft, + **self.__config.noise_kwargs, ) # Handle noise standard deviation adjustments if necessary - if self.config.noise_stddev_adj == "auto": + if self.__config.noise_stddev_adj == "auto": print("Computing noise adjustment coefficients... ", end="", flush=True) - if self.config.measure_time: + if self.__config.measure_time: starttime = time.time() # Compute noise adjustment coefficients - self.params.noise_std_coefficients = ( + self.__params.noise_std_coefficients = ( noise.utils.compute_noise_stddev_adjs( - self.precip[-1, :, :], - self.config.precip_threshold, - np.min(self.precip), - self.params.bandpass_filter, - self.params.decomposition_method, - self.params.perturbation_generator, - self.params.noise_generator, + self.__precip[-1, :, :], + self.__config.precip_threshold, + np.min(self.__precip), + self.__params.bandpass_filter, + self.__params.decomposition_method, + self.__params.perturbation_generator, + self.__params.noise_generator, 20, - conditional=self.config.conditional, - num_workers=self.config.num_workers, - seed=self.config.seed, + conditional=self.__config.conditional, + num_workers=self.__config.num_workers, + seed=self.__config.seed, ) ) # Measure and print time taken - if self.config.measure_time: - self._measure_time( + if self.__config.measure_time: + self.__measure_time( "Noise adjustment coefficient computation", starttime ) else: print("done.") - elif self.config.noise_stddev_adj == "fixed": + elif self.__config.noise_stddev_adj == "fixed": # Set fixed noise adjustment coefficients func = lambda k: 1.0 / (0.75 + 0.09 * k) - self.params.noise_std_coefficients = [ - func(k) for k in range(1, self.config.n_cascade_levels + 1) + self.__params.noise_std_coefficients = [ + func(k) for k in range(1, self.__config.n_cascade_levels + 1) ] else: # Default to no adjustment - self.params.noise_std_coefficients = np.ones( - self.config.n_cascade_levels + self.__params.noise_std_coefficients = np.ones( + self.__config.n_cascade_levels ) - if self.config.noise_stddev_adj is not None: + if self.__config.noise_stddev_adj is not None: # Print noise std deviation coefficients if adjustments were made print( - f"noise std. dev. coeffs: {str(self.params.noise_std_coefficients)}" + f"noise std. dev. coeffs: {str(self.__params.noise_std_coefficients)}" ) else: # No noise, so set perturbation generator and noise_std_coefficients to None - self.params.perturbation_generator = None - self.params.noise_std_coefficients = np.ones( - self.config.n_cascade_levels + self.__params.perturbation_generator = None + self.__params.noise_std_coefficients = np.ones( + self.__config.n_cascade_levels ) # Keep default as 1.0 to avoid breaking AR model # Decompose the input precipitation fields - self.state.precip_decomposed = [] - for i in range(self.config.ar_order + 1): - precip_ = self.params.decomposition_method( - self.precip[i, :, :], - self.params.bandpass_filter, - mask=self.state.mask_threshold, - fft_method=self.params.fft, - output_domain=self.config.domain, + self.__state.precip_decomposed = [] + for i in range(self.__config.ar_order + 1): + precip_ = self.__params.decomposition_method( + self.__precip[i, :, :], + self.__params.bandpass_filter, + mask=self.__state.mask_threshold, + fft_method=self.__params.fft, + output_domain=self.__config.domain, normalize=True, compute_stats=True, compact_output=True, ) - self.state.precip_decomposed.append(precip_) + self.__state.precip_decomposed.append(precip_) # Normalize the cascades and rearrange them into a 4D array - self.state.precip_cascades = nowcast_utils.stack_cascades( - self.state.precip_decomposed, self.config.n_cascade_levels + self.__state.precip_cascades = nowcast_utils.stack_cascades( + self.__state.precip_decomposed, self.__config.n_cascade_levels ) - self.state.precip_decomposed = self.state.precip_decomposed[-1] - self.state.precip_decomposed = [ - self.state.precip_decomposed.copy() - for _ in range(self.config.n_ens_members) + self.__state.precip_decomposed = self.__state.precip_decomposed[-1] + self.__state.precip_decomposed = [ + self.__state.precip_decomposed.copy() + for _ in range(self.__config.n_ens_members) ] # Compute temporal autocorrelation coefficients for each cascade level - self.params.autocorrelation_coefficients = np.empty( - (self.config.n_cascade_levels, self.config.ar_order) + self.__params.autocorrelation_coefficients = np.empty( + (self.__config.n_cascade_levels, self.__config.ar_order) ) - for i in range(self.config.n_cascade_levels): - self.params.autocorrelation_coefficients[i, :] = ( + for i in range(self.__config.n_cascade_levels): + self.__params.autocorrelation_coefficients[i, :] = ( correlation.temporal_autocorrelation( - self.state.precip_cascades[i], mask=self.state.mask_threshold + self.__state.precip_cascades[i], mask=self.__state.mask_threshold ) ) - nowcast_utils.print_corrcoefs(self.params.autocorrelation_coefficients) + nowcast_utils.print_corrcoefs(self.__params.autocorrelation_coefficients) # Adjust the lag-2 correlation coefficient if AR(2) model is used - if self.config.ar_order == 2: - for i in range(self.config.n_cascade_levels): - self.params.autocorrelation_coefficients[i, 1] = ( + if self.__config.ar_order == 2: + for i in range(self.__config.n_cascade_levels): + self.__params.autocorrelation_coefficients[i, 1] = ( autoregression.adjust_lag2_corrcoef2( - self.params.autocorrelation_coefficients[i, 0], - self.params.autocorrelation_coefficients[i, 1], + self.__params.autocorrelation_coefficients[i, 0], + self.__params.autocorrelation_coefficients[i, 1], ) ) # Estimate the parameters of the AR model using auto-correlation coefficients - self.params.ar_model_coefficients = np.empty( - (self.config.n_cascade_levels, self.config.ar_order + 1) + self.__params.ar_model_coefficients = np.empty( + (self.__config.n_cascade_levels, self.__config.ar_order + 1) ) - for i in range(self.config.n_cascade_levels): - self.params.ar_model_coefficients[i, :] = ( + for i in range(self.__config.n_cascade_levels): + self.__params.ar_model_coefficients[i, :] = ( autoregression.estimate_ar_params_yw( - self.params.autocorrelation_coefficients[i, :] + self.__params.autocorrelation_coefficients[i, :] ) ) - nowcast_utils.print_ar_params(self.params.ar_model_coefficients) + nowcast_utils.print_ar_params(self.__params.ar_model_coefficients) # Discard all except the last ar_order cascades for AR model - self.state.precip_cascades = [ - self.state.precip_cascades[i][-self.config.ar_order :] - for i in range(self.config.n_cascade_levels) + self.__state.precip_cascades = [ + self.__state.precip_cascades[i][-self.__config.ar_order :] + for i in range(self.__config.n_cascade_levels) ] # Stack the cascades into a list containing all ensemble members - self.state.precip_cascades = [ + self.__state.precip_cascades = [ [ - self.state.precip_cascades[j].copy() - for j in range(self.config.n_cascade_levels) + self.__state.precip_cascades[j].copy() + for j in range(self.__config.n_cascade_levels) ] - for _ in range(self.config.n_ens_members) + for _ in range(self.__config.n_ens_members) ] # Initialize random generators if noise_method is provided - if self.config.noise_method is not None: - self.state.random_generator_precip = [] - self.state.random_generator_motion = [] + if self.__config.noise_method is not None: + self.__state.random_generator_precip = [] + self.__state.random_generator_motion = [] - for _ in range(self.config.n_ens_members): + for _ in range(self.__config.n_ens_members): # Create random state for precipitation noise generator - rs = np.random.RandomState(self.config.seed) - self.state.random_generator_precip.append(rs) - self.config.seed = rs.randint( + rs = np.random.RandomState(self.__config.seed) + self.__state.random_generator_precip.append(rs) + self.__config.seed = rs.randint( 0, high=int(1e9) ) # Update seed after generating # Create random state for motion perturbations generator - rs = np.random.RandomState(self.config.seed) - self.state.random_generator_motion.append(rs) - self.config.seed = rs.randint( + rs = np.random.RandomState(self.__config.seed) + self.__state.random_generator_motion.append(rs) + self.__config.seed = rs.randint( 0, high=int(1e9) ) # Update seed after generating else: - self.state.random_generator_precip = None - self.state.random_generator_motion = None + self.__state.random_generator_precip = None + self.__state.random_generator_motion = None print("AR model and noise applied to precipitation cascades.") - def _initialize_velocity_perturbations(self): + def __initialize_velocity_perturbations(self): """ Initialize the velocity perturbators for each ensemble member if the velocity perturbation method is specified. """ - if self.config.velocity_perturbation_method is not None: + if self.__config.velocity_perturbation_method is not None: init_vel_noise, generate_vel_noise = noise.get_method( - self.config.velocity_perturbation_method + self.__config.velocity_perturbation_method ) - self.state.velocity_perturbations = [] - for j in range(self.config.n_ens_members): + self.__state.velocity_perturbations = [] + for j in range(self.__config.n_ens_members): kwargs = { - "randstate": self.state.random_generator_motion[j], - "p_par": self.config.velocity_perturbation_kwargs.get( - "p_par", self.params.velocity_perturbation_parallel + "randstate": self.__state.random_generator_motion[j], + "p_par": self.__config.velocity_perturbation_kwargs.get( + "p_par", self.__params.velocity_perturbation_parallel ), - "p_perp": self.config.velocity_perturbation_kwargs.get( - "p_perp", self.params.velocity_perturbation_perpendicular + "p_perp": self.__config.velocity_perturbation_kwargs.get( + "p_perp", self.__params.velocity_perturbation_perpendicular ), } vp = init_vel_noise( - self.velocity, - 1.0 / self.config.kmperpixel, - self.config.timestep, + self.__velocity, + 1.0 / self.__config.kmperpixel, + self.__config.timestep, **kwargs, ) - self.state.velocity_perturbations.append( - lambda t, vp=vp: generate_vel_noise(vp, t * self.config.timestep) + self.__state.velocity_perturbations.append( + lambda t, vp=vp: generate_vel_noise(vp, t * self.__config.timestep) ) else: - self.state.velocity_perturbations = None + self.__state.velocity_perturbations = None print("Velocity perturbations initialized successfully.") - def _initialize_precipitation_mask(self): + def __initialize_precipitation_mask(self): """ Initialize the precipitation mask and handle different mask methods (sprog, incremental). """ - self.state.precip_forecast = [[] for _ in range(self.config.n_ens_members)] + self.__state.precip_forecast = [[] for _ in range(self.__config.n_ens_members)] - if self.config.probmatching_method == "mean": - self.params.precipitation_mean = np.mean( - self.precip[-1, :, :][ - self.precip[-1, :, :] >= self.config.precip_threshold + if self.__config.probmatching_method == "mean": + self.__params.precipitation_mean = np.mean( + self.__precip[-1, :, :][ + self.__precip[-1, :, :] >= self.__config.precip_threshold ] ) else: - self.params.precipitation_mean = None + self.__params.precipitation_mean = None - if self.config.mask_method is not None: - self.state.mask_precip = ( - self.precip[-1, :, :] >= self.config.precip_threshold + if self.__config.mask_method is not None: + self.__state.mask_precip = ( + self.__precip[-1, :, :] >= self.__config.precip_threshold ) - if self.config.mask_method == "sprog": + if self.__config.mask_method == "sprog": # Compute the wet area ratio and the precipitation mask - self.params.wet_area_ratio = np.sum(self.state.mask_precip) / ( - self.precip.shape[1] * self.precip.shape[2] + self.__params.wet_area_ratio = np.sum(self.__state.mask_precip) / ( + self.__precip.shape[1] * self.__precip.shape[2] ) - self.state.precip_mask = [ - self.state.precip_cascades[0][i].copy() - for i in range(self.config.n_cascade_levels) + self.__state.precip_mask = [ + self.__state.precip_cascades[0][i].copy() + for i in range(self.__config.n_cascade_levels) ] - self.state.precip_mask_decomposed = self.state.precip_decomposed[ + self.__state.precip_mask_decomposed = self.__state.precip_decomposed[ 0 ].copy() - elif self.config.mask_method == "incremental": + elif self.__config.mask_method == "incremental": # Get mask parameters - self.params.mask_rim = self.config.mask_kwargs.get("mask_rim", 10) - mask_f = self.config.mask_kwargs.get("mask_f", 1.0) + self.__params.mask_rim = self.__config.mask_kwargs.get("mask_rim", 10) + mask_f = self.__config.mask_kwargs.get("mask_f", 1.0) # Initialize the structuring element - self.params.structuring_element = generate_binary_structure(2, 1) + self.__params.structuring_element = generate_binary_structure(2, 1) # Expand the structuring element based on mask factor and timestep - n = mask_f * self.config.timestep / self.config.kmperpixel - self.params.structuring_element = iterate_structure( - self.params.structuring_element, int((n - 1) / 2.0) + n = mask_f * self.__config.timestep / self.__config.kmperpixel + self.__params.structuring_element = iterate_structure( + self.__params.structuring_element, int((n - 1) / 2.0) ) # Compute and apply the dilated mask for each ensemble member - self.state.mask_precip = nowcast_utils.compute_dilated_mask( - self.state.mask_precip, - self.params.structuring_element, - self.params.mask_rim, + self.__state.mask_precip = nowcast_utils.compute_dilated_mask( + self.__state.mask_precip, + self.__params.structuring_element, + self.__params.mask_rim, ) - self.state.mask_precip = [ - self.state.mask_precip.copy() - for _ in range(self.config.n_ens_members) + self.__state.mask_precip = [ + self.__state.mask_precip.copy() + for _ in range(self.__config.n_ens_members) ] else: - self.state.mask_precip = None + self.__state.mask_precip = None - if self.config.noise_method is None and self.state.precip_mask is None: - self.state.precip_mask = [ - self.state.precip_cascades[0][i].copy() - for i in range(self.config.n_cascade_levels) + if self.__config.noise_method is None and self.__state.precip_mask is None: + self.__state.precip_mask = [ + self.__state.precip_cascades[0][i].copy() + for i in range(self.__config.n_cascade_levels) ] print("Precipitation mask initialized successfully.") - def _initialize_fft_objects(self): + def __initialize_fft_objects(self): """ Initialize FFT objects for each ensemble member. """ - self.state.fft_objs = [] - for _ in range(self.config.n_ens_members): + self.__state.fft_objs = [] + for _ in range(self.__config.n_ens_members): fft_obj = utils.get_method( - self.config.fft_method, shape=self.precip.shape[1:] + self.__config.fft_method, shape=self.__precip.shape[1:] ) - self.state.fft_objs.append(fft_obj) + self.__state.fft_objs.append(fft_obj) print("FFT objects initialized successfully.") - def _initialize_state(self): + def __initialize_state(self): """ Initialize the state dictionary used during the nowcast iteration. """ return { - "fft_objs": self.state.fft_objs, - "mask_prec": self.state.mask_precip, - "precip_cascades": self.state.precip_cascades, - "precip_decomp": self.state.precip_decomposed, - "precip_m": self.state.precip_mask, - "precip_m_d": self.state.precip_mask_decomposed, - "randgen_prec": self.state.random_generator_precip, + "fft_objs": self.__state.fft_objs, + "mask_prec": self.__state.mask_precip, + "precip_cascades": self.__state.precip_cascades, + "precip_decomp": self.__state.precip_decomposed, + "precip_m": self.__state.precip_mask, + "precip_m_d": self.__state.precip_mask_decomposed, + "randgen_prec": self.__state.random_generator_precip, } - def _initialize_params(self, precip): + def __initialize_params(self, precip): """ Initialize the params dictionary used during the nowcast iteration. """ return { - "decomp_method": self.params.decomposition_method, - "domain": self.config.domain, - "domain_mask": self.params.domain_mask, - "filter": self.params.bandpass_filter, - "fft": self.params.fft, - "generate_noise": self.params.noise_generator, - "mask_method": self.config.mask_method, - "mask_rim": self.params.mask_rim, - "mu_0": self.params.precipitation_mean, - "n_cascade_levels": self.config.n_cascade_levels, - "n_ens_members": self.config.n_ens_members, - "noise_method": self.config.noise_method, - "noise_std_coeffs": self.params.noise_std_coefficients, - "num_ensemble_workers": self.params.num_ensemble_workers, - "phi": self.params.ar_model_coefficients, - "pert_gen": self.params.perturbation_generator, - "probmatching_method": self.config.probmatching_method, + "decomp_method": self.__params.decomposition_method, + "domain": self.__config.domain, + "domain_mask": self.__params.domain_mask, + "filter": self.__params.bandpass_filter, + "fft": self.__params.fft, + "generate_noise": self.__params.noise_generator, + "mask_method": self.__config.mask_method, + "mask_rim": self.__params.mask_rim, + "mu_0": self.__params.precipitation_mean, + "n_cascade_levels": self.__config.n_cascade_levels, + "n_ens_members": self.__config.n_ens_members, + "noise_method": self.__config.noise_method, + "noise_std_coeffs": self.__params.noise_std_coefficients, + "num_ensemble_workers": self.__params.num_ensemble_workers, + "phi": self.__params.ar_model_coefficients, + "pert_gen": self.__params.perturbation_generator, + "probmatching_method": self.__config.probmatching_method, "precip": precip, - "precip_thr": self.config.precip_threshold, - "recomp_method": self.params.recomposition_method, - "struct": self.params.structuring_element, - "war": self.params.wet_area_ratio, + "precip_thr": self.__config.precip_threshold, + "recomp_method": self.__params.recomposition_method, + "struct": self.__params.structuring_element, + "war": self.__params.wet_area_ratio, } - def _update_state(self, state, params): + def __update_state(self, state, params): """ Update the state during the nowcasting loop. This function handles the AR model iteration, noise generation, recomposition, and mask application for each ensemble member. @@ -796,12 +812,12 @@ def _update_state(self, state, params): # Update the deterministic AR(p) model if noise or sprog mask is used if params["noise_method"] is None or params["mask_method"] == "sprog": - self._update_deterministic_ar_model(state, params) + self.__update_deterministic_ar_model(state, params) # Worker function for each ensemble member def worker(j): - self._apply_ar_model_to_cascades(j, state, params) - precip_forecast_out[j] = self._recompose_and_apply_mask(j, state, params) + self.__apply_ar_model_to_cascades(j, state, params) + precip_forecast_out[j] = self.__recompose_and_apply_mask(j, state, params) # Use Dask for parallel execution if available if ( @@ -819,7 +835,7 @@ def worker(j): return np.stack(precip_forecast_out), state - def _update_deterministic_ar_model(self, state, params): + def __update_deterministic_ar_model(self, state, params): """ Update the deterministic AR(p) model for each cascade level if noise is disabled or if the sprog mask is used. @@ -846,14 +862,14 @@ def _update_deterministic_ar_model(self, state, params): if params["mask_method"] == "sprog": state["mask_prec"] = compute_percentile_mask(precip_m_, params["war"]) - def _apply_ar_model_to_cascades(self, j, state, params): + def __apply_ar_model_to_cascades(self, j, state, params): """ Apply the AR(p) model to the cascades for each ensemble member, including noise generation and normalization. """ # Generate noise if enabled if params["noise_method"] is not None: - eps = self._generate_and_decompose_noise(j, state, params) + eps = self.__generate_and_decompose_noise(j, state, params) else: eps = None @@ -878,7 +894,7 @@ def _apply_ar_model_to_cascades(self, j, state, params): eps = None eps_ = None - def _generate_and_decompose_noise(self, j, state, params): + def __generate_and_decompose_noise(self, j, state, params): """ Generate and decompose the noise field into cascades for a given ensemble member. """ @@ -902,7 +918,7 @@ def _generate_and_decompose_noise(self, j, state, params): return eps - def _recompose_and_apply_mask(self, j, state, params): + def __recompose_and_apply_mask(self, j, state, params): """ Recompose the precipitation field from cascades and apply the precipitation mask. """ @@ -923,7 +939,7 @@ def _recompose_and_apply_mask(self, j, state, params): # Apply the precipitation mask if params["mask_method"] is not None: - precip_forecast = self._apply_precipitation_mask( + precip_forecast = self.__apply_precipitation_mask( precip_forecast, j, state, params ) @@ -951,7 +967,7 @@ def _recompose_and_apply_mask(self, j, state, params): return precip_forecast - def _apply_precipitation_mask(self, precip_forecast, j, state, params): + def __apply_precipitation_mask(self, precip_forecast, j, state, params): """ Apply the precipitation mask to prevent new precipitation from generating in areas where it was not observed. @@ -972,7 +988,7 @@ def _apply_precipitation_mask(self, precip_forecast, j, state, params): return precip_forecast - def _measure_time(self, label, start_time): + def __measure_time(self, label, start_time): """ Measure and print the time taken for a specific part of the process. @@ -980,7 +996,7 @@ def _measure_time(self, label, start_time): - label: A description of the part of the process being measured. - start_time: The timestamp when the process started (from time.time()). """ - if self.config.measure_time: + if self.__config.measure_time: elapsed_time = time.time() - start_time print(f"{label} took {elapsed_time:.2f} seconds.") @@ -991,13 +1007,13 @@ def reset_states_and_params(self): the inputs like precip, velocity, time_steps, or config. """ # Re-initialize the state and parameters - self.state = StepsNowcasterState() - self.params = StepsNowcasterParams() + self.__state = StepsNowcasterState() + self.__params = StepsNowcasterParams() # Reset time measurement variables - self.start_time_init = None - self.init_time = None - self.mainloop_time = None + self.__start_time_init = None + self.__init_time = None + self.__mainloop_time = None # Wrapper function to preserve backward compatibility