Skip to content

Commit

Permalink
Added frozen dataclass to nowcast
Browse files Browse the repository at this point in the history
  • Loading branch information
sidekock committed Dec 19, 2024
1 parent c72d953 commit 00f057b
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 27 deletions.
1 change: 0 additions & 1 deletion pysteps/blending/steps.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,6 @@ def __blended_nowcast_main_loop(self):

if self.__config.measure_time:
starttime_mainloop = time.time()
# self.__state.extrapolation_kwargs = deepcopy(self.__config.extrapolation_kwargs)
self.__state.extrapolation_kwargs["return_displacement"] = True

self.__state.precip_cascades_prev_subtimestep = deepcopy(
Expand Down
71 changes: 45 additions & 26 deletions pysteps/nowcasts/steps.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import numpy as np
from scipy.ndimage import generate_binary_structure, iterate_structure
import time
from copy import deepcopy

from pysteps import cascade
from pysteps import extrapolation
Expand All @@ -35,7 +36,7 @@
DASK_IMPORTED = False


@dataclass
@dataclass(frozen=True)
class StepsNowcasterConfig:
"""
Parameters
Expand Down Expand Up @@ -247,6 +248,10 @@ class StepsNowcasterParams:
xy_coordinates: np.ndarray | None = None
velocity_perturbation_parallel: list[float] | None = None
velocity_perturbation_perpendicular: list[float] | None = None
filter_kwargs: dict | None = None
noise_kwargs: dict | None = None
velocity_perturbation_kwargs: dict | None = None
mask_kwargs: dict | None = None


@dataclass
Expand All @@ -268,6 +273,7 @@ class StepsNowcasterState:
)
velocity_perturbations: list[Callable] | None = field(default_factory=list)
fft_objects: list[Any] | None = field(default_factory=list)
extrapolation_kwargs: dict[str, Any] | None = field(default_factory=dict)


class StepsNowcaster:
Expand Down Expand Up @@ -408,7 +414,7 @@ def __nowcast_main(self):
self.__time_steps,
self.__config.extrapolation_method,
self.__update_state, # Reference to the update function
extrap_kwargs=self.__config.extrapolation_kwargs,
extrap_kwargs=self.__state.extrapolation_kwargs,
velocity_pert_gen=self.__state.velocity_perturbations,
params=params,
ensemble=True,
Expand Down Expand Up @@ -483,15 +489,33 @@ def __check_inputs(self):

# Handle None values for various kwargs
if self.__config.extrapolation_kwargs is None:
self.__config.extrapolation_kwargs = {}
self.__state.extrapolation_kwargs = dict()
else:
self.__state.extrapolation_kwargs = deepcopy(

Check warning on line 494 in pysteps/nowcasts/steps.py

View check run for this annotation

Codecov / codecov/patch

pysteps/nowcasts/steps.py#L494

Added line #L494 was not covered by tests
self.__config.extrapolation_kwargs
)

if self.__config.filter_kwargs is None:
self.__config.filter_kwargs = {}
self.__params.filter_kwargs = dict()
else:
self.__params.filter_kwargs = deepcopy(self.__config.filter_kwargs)

Check warning on line 501 in pysteps/nowcasts/steps.py

View check run for this annotation

Codecov / codecov/patch

pysteps/nowcasts/steps.py#L501

Added line #L501 was not covered by tests

if self.__config.noise_kwargs is None:
self.__config.noise_kwargs = {}
self.__params.noise_kwargs = dict()
else:
self.__params.noise_kwargs = deepcopy(self.__config.noise_kwargs)

Check warning on line 506 in pysteps/nowcasts/steps.py

View check run for this annotation

Codecov / codecov/patch

pysteps/nowcasts/steps.py#L506

Added line #L506 was not covered by tests

if self.__config.velocity_perturbation_kwargs is None:
self.__config.velocity_perturbation_kwargs = {}
self.__params.velocity_perturbation_kwargs = dict()
else:
self.__params.velocity_perturbation_kwargs = deepcopy(

Check warning on line 511 in pysteps/nowcasts/steps.py

View check run for this annotation

Codecov / codecov/patch

pysteps/nowcasts/steps.py#L511

Added line #L511 was not covered by tests
self.__config.velocity_perturbation_kwargs
)

if self.__config.mask_kwargs is None:
self.__config.mask_kwargs = {}
self.__params.mask_kwargs = dict()
else:
self.__params.mask_kwargs = deepcopy(self.__config.mask_kwargs)

Check warning on line 518 in pysteps/nowcasts/steps.py

View check run for this annotation

Codecov / codecov/patch

pysteps/nowcasts/steps.py#L518

Added line #L518 was not covered by tests

print("Inputs validated and initialized successfully.")

Expand Down Expand Up @@ -548,12 +572,12 @@ def __print_forecast_info(self):

if self.__config.velocity_perturbation_method == "bps":
self.__params.velocity_perturbation_parallel = (
self.__config.velocity_perturbation_kwargs.get(
self.__params.velocity_perturbation_kwargs.get(
"p_par", noise.motion.get_default_params_bps_par()
)
)
self.__params.velocity_perturbation_perpendicular = (
self.__config.velocity_perturbation_kwargs.get(
self.__params.velocity_perturbation_kwargs.get(
"p_perp", noise.motion.get_default_params_bps_perp()
)
)
Expand Down Expand Up @@ -588,7 +612,7 @@ def __initialize_nowcast_components(self):
self.__params.bandpass_filter = filter_method(
(M, N),
self.__config.n_cascade_levels,
**(self.__config.filter_kwargs or {}),
**(self.__params.filter_kwargs or {}),
)

# Get the decomposition method (e.g., FFT)
Expand Down Expand Up @@ -629,7 +653,7 @@ def __perform_extrapolation(self):
else:
self.__state.mask_threshold = None

extrap_kwargs = self.__config.extrapolation_kwargs.copy()
extrap_kwargs = self.__state.extrapolation_kwargs.copy()
extrap_kwargs["xy_coords"] = self.__params.xy_coordinates
extrap_kwargs["allow_nonfinite_values"] = (
True if np.any(~np.isfinite(self.__precip)) else False
Expand Down Expand Up @@ -691,7 +715,7 @@ def __apply_noise_and_ar_model(self):
self.__params.perturbation_generator = init_noise(
self.__precip,
fft_method=self.__params.fft,
**self.__config.noise_kwargs,
**self.__params.noise_kwargs,
)

# Handle noise standard deviation adjustments if necessary
Expand Down Expand Up @@ -831,21 +855,16 @@ def __apply_noise_and_ar_model(self):
if self.__config.noise_method is not None:
self.__state.random_generator_precip = []
self.__state.random_generator_motion = []

seed = self.__config.seed
for _ in range(self.__config.n_ens_members):
# Create random state for precipitation noise generator
rs = np.random.RandomState(self.__config.seed)
rs = np.random.RandomState(seed)
self.__state.random_generator_precip.append(rs)
self.__config.seed = rs.randint(
0, high=int(1e9)
) # Update seed after generating

seed = rs.randint(0, high=int(1e9))
# Create random state for motion perturbations generator
rs = np.random.RandomState(self.__config.seed)
rs = np.random.RandomState(seed)
self.__state.random_generator_motion.append(rs)
self.__config.seed = rs.randint(
0, high=int(1e9)
) # Update seed after generating
seed = rs.randint(0, high=int(1e9))
else:
self.__state.random_generator_precip = None
self.__state.random_generator_motion = None
Expand All @@ -865,10 +884,10 @@ def __initialize_velocity_perturbations(self):
for j in range(self.__config.n_ens_members):
kwargs = {
"randstate": self.__state.random_generator_motion[j],
"p_par": self.__config.velocity_perturbation_kwargs.get(
"p_par": self.__params.velocity_perturbation_kwargs.get(
"p_par", self.__params.velocity_perturbation_parallel
),
"p_perp": self.__config.velocity_perturbation_kwargs.get(
"p_perp": self.__params.velocity_perturbation_kwargs.get(
"p_perp", self.__params.velocity_perturbation_perpendicular
),
}
Expand Down Expand Up @@ -920,8 +939,8 @@ def __initialize_precipitation_mask(self):

elif self.__config.mask_method == "incremental":
# Get mask parameters
self.__params.mask_rim = self.__config.mask_kwargs.get("mask_rim", 10)
mask_f = self.__config.mask_kwargs.get("mask_f", 1.0)
self.__params.mask_rim = self.__params.mask_kwargs.get("mask_rim", 10)
mask_f = self.__params.mask_kwargs.get("mask_f", 1.0)
# Initialize the structuring element
self.__params.structuring_element = generate_binary_structure(2, 1)
# Expand the structuring element based on mask factor and timestep
Expand Down

0 comments on commit 00f057b

Please sign in to comment.