From b7424ac4814ad3222172136401b9549abdcff1fe Mon Sep 17 00:00:00 2001 From: Feda Curic Date: Mon, 18 Dec 2023 14:23:05 +0100 Subject: [PATCH] Add type hints to local_script_lib --- .../localisation/local_script_lib.py | 67 +++++++++++-------- 1 file changed, 40 insertions(+), 27 deletions(-) diff --git a/semeio/workflows/localisation/local_script_lib.py b/semeio/workflows/localisation/local_script_lib.py index 61bb4428..4a9cc97b 100644 --- a/semeio/workflows/localisation/local_script_lib.py +++ b/semeio/workflows/localisation/local_script_lib.py @@ -6,7 +6,7 @@ import math from collections import defaultdict from dataclasses import dataclass, field -from typing import Dict, List, Any +from typing import Dict, List, Any, Union, Tuple, Optional import cwrap import numpy as np @@ -98,7 +98,12 @@ class Decay: main_range: float perp_range: float azimuth: float - grid: object + grid: Grid + + def __call__(self, data_index): + # Default behavior of Decay when called as a function + # This is a placeholder; you can define a more meaningful default behavior + return 1.0 def __post_init__(self): angle = (90.0 - self.azimuth) * math.pi / 180.0 @@ -134,7 +139,7 @@ def norm_dist_square(self, data_index): class GaussianDecay(Decay): cutoff: bool - def __call__(self, data_index) -> float: + def __call__(self, data_index: List[int]) -> float: d2 = super().norm_dist_square(data_index) if self.cutoff and d2 > 1.0: return 0.0 @@ -260,8 +265,8 @@ def write_qc_parameter_surface( corr_name: str, surface_scale: bool, reference_surface_file: str, - param_for_surface: object, - log_level=LogLevel.OFF, + param_for_surface: ma.MaskedArray, + log_level: LogLevel = LogLevel.OFF, ) -> None: # pylint: disable=too-many-arguments @@ -367,7 +372,9 @@ def build_decay_object( tapering_range: float = 1.5, ) -> Decay: # pylint: disable=too-many-arguments - decay_obj = None + decay_obj: Union[ + GaussianDecay, ExponentialDecay, ConstGaussianDecay, ConstExponentialDecay + ] if method == "gaussian_decay": decay_obj = GaussianDecay( ref_pos, @@ -419,7 +426,9 @@ def build_decay_object( return decay_obj -def calculate_scaling_vector_fields(grid: object, decay_obj: object) -> ma.MaskedArray: +def calculate_scaling_vector_fields( + grid: object, decay_obj: Union[Decay, ConstantScalingFactor] +) -> ma.MaskedArray: assert isinstance(grid, Grid) nx = grid.getNX() ny = grid.getNY() @@ -439,7 +448,9 @@ def calculate_scaling_vector_fields(grid: object, decay_obj: object) -> ma.Maske return scaling_vector -def calculate_scaling_vector_surface(grid: object, decay_obj: object): +def calculate_scaling_vector_surface( + grid: object, decay_obj: Union[ConstantScalingFactor, Decay] +) -> ma.MaskedArray: assert isinstance(grid, Surface) nx = grid.getNX() ny = grid.getNY() @@ -457,8 +468,8 @@ def calculate_scaling_vector_surface(grid: object, decay_obj: object): def apply_decay( method: str, - row_scaling: object, - grid: object, + row_scaling: RowScaling, + grid: Grid, ref_pos: list, main_range: float, perp_range: float, @@ -466,7 +477,7 @@ def apply_decay( use_cutoff: bool = False, tapering_range: float = 1.5, calculate_qc_parameter: bool = False, -): +) -> Tuple[Optional[ma.MaskedArray], Optional[ma.MaskedArray]]: # pylint: disable=too-many-arguments,too-many-locals """ Calculates the scaling factor, assign it to ERT instance by row_scaling @@ -504,12 +515,12 @@ def apply_decay( def apply_constant( - row_scaling: object, - grid: object, + row_scaling: RowScaling, + grid: Grid, value: float, log_level: LogLevel, calculate_qc_parameter: bool = False, -): +) -> Tuple[Optional[ma.MaskedArray], Optional[ma.MaskedArray]]: # pylint: disable=too-many-arguments,too-many-locals """ Assign constant value to the scaling factor, @@ -535,13 +546,13 @@ def apply_constant( def apply_from_file( - row_scaling: object, - grid: object, + row_scaling: RowScaling, + grid: Grid, filename: str, param_name: str, log_level: LogLevel, calculate_qc_parameter: bool = False, -): +) -> Tuple[Optional[ma.MaskedArray], None]: # pylint: disable=too-many-arguments, too-many-locals debug_print( f"Read scaling factors as parameter {param_name}", LogLevel.LEVEL3, log_level @@ -642,8 +653,11 @@ def calculate_scaling_factors_in_regions( def smooth_parameter( - grid, smooth_range_list, scaling_values, active_region_values_used -): + grid: Grid, + smooth_range_list: List[int], + scaling_values: List[float], + active_region_values_used: ma.MaskedArray, +) -> ma.MaskedArray: """ Function taking as input a 3D parameter scaling_values and calculates a new 3D parameter scaling_values_smooth using local average within a rectangular window @@ -671,13 +685,12 @@ def smooth_parameter( if active_region_values_used[index0] is not ma.masked: sumv = 0.0 nval = 0 - ilow = max(0, i0 - di) - ihigh = min(i0 + di + 1, nx) - jlow = max(0, j0 - dj) - jhigh = min(j0 + dj + 1, ny) + ilow: int = max(0, i0 - di) + ihigh: int = min(i0 + di + 1, nx) + jlow: int = max(0, j0 - dj) + jhigh: int = min(j0 + dj + 1, ny) for i in range(ilow, ihigh): for j in range(jlow, jhigh): - # index = i + j * nx + k * nx * ny index = k + j * nz + i * nz * ny if active_region_values_used[index] is not ma.masked: # Only use values from grid cells that are active @@ -691,8 +704,8 @@ def smooth_parameter( def apply_segment( - row_scaling, - grid, + row_scaling: RowScaling, + grid: Grid, region_param_dict: Dict[str, Any], active_segment_list: List[int], scaling_factor_list: List[float], @@ -700,7 +713,7 @@ def apply_segment( corr_name: str, log_level: LogLevel = LogLevel.OFF, calculate_qc_parameter: bool = False, -): +) -> Tuple[Any, None]: # pylint: disable=too-many-arguments,too-many-locals """ Purpose: Use region numbers and list of scaling factors per region to