Skip to content

Commit

Permalink
Add type hints to local_script_lib
Browse files Browse the repository at this point in the history
  • Loading branch information
dafeda committed Dec 18, 2023
1 parent f89ec3f commit b7424ac
Showing 1 changed file with 40 additions and 27 deletions.
67 changes: 40 additions & 27 deletions semeio/workflows/localisation/local_script_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import math
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Dict, List, Any
from typing import Dict, List, Any, Union, Tuple, Optional

import cwrap
import numpy as np
Expand Down Expand Up @@ -98,7 +98,12 @@ class Decay:
main_range: float
perp_range: float
azimuth: float
grid: object
grid: Grid

def __call__(self, data_index):
# Default behavior of Decay when called as a function
# This is a placeholder; you can define a more meaningful default behavior
return 1.0

def __post_init__(self):
angle = (90.0 - self.azimuth) * math.pi / 180.0
Expand Down Expand Up @@ -134,7 +139,7 @@ def norm_dist_square(self, data_index):
class GaussianDecay(Decay):
cutoff: bool

def __call__(self, data_index) -> float:
def __call__(self, data_index: List[int]) -> float:
d2 = super().norm_dist_square(data_index)
if self.cutoff and d2 > 1.0:
return 0.0
Expand Down Expand Up @@ -260,8 +265,8 @@ def write_qc_parameter_surface(
corr_name: str,
surface_scale: bool,
reference_surface_file: str,
param_for_surface: object,
log_level=LogLevel.OFF,
param_for_surface: ma.MaskedArray,
log_level: LogLevel = LogLevel.OFF,
) -> None:
# pylint: disable=too-many-arguments

Expand Down Expand Up @@ -367,7 +372,9 @@ def build_decay_object(
tapering_range: float = 1.5,
) -> Decay:
# pylint: disable=too-many-arguments
decay_obj = None
decay_obj: Union[
GaussianDecay, ExponentialDecay, ConstGaussianDecay, ConstExponentialDecay
]
if method == "gaussian_decay":
decay_obj = GaussianDecay(
ref_pos,
Expand Down Expand Up @@ -419,7 +426,9 @@ def build_decay_object(
return decay_obj


def calculate_scaling_vector_fields(grid: object, decay_obj: object) -> ma.MaskedArray:
def calculate_scaling_vector_fields(
grid: object, decay_obj: Union[Decay, ConstantScalingFactor]
) -> ma.MaskedArray:
assert isinstance(grid, Grid)
nx = grid.getNX()
ny = grid.getNY()
Expand All @@ -439,7 +448,9 @@ def calculate_scaling_vector_fields(grid: object, decay_obj: object) -> ma.Maske
return scaling_vector


def calculate_scaling_vector_surface(grid: object, decay_obj: object):
def calculate_scaling_vector_surface(
grid: object, decay_obj: Union[ConstantScalingFactor, Decay]
) -> ma.MaskedArray:
assert isinstance(grid, Surface)
nx = grid.getNX()
ny = grid.getNY()
Expand All @@ -457,16 +468,16 @@ def calculate_scaling_vector_surface(grid: object, decay_obj: object):

def apply_decay(
method: str,
row_scaling: object,
grid: object,
row_scaling: RowScaling,
grid: Grid,
ref_pos: list,
main_range: float,
perp_range: float,
azimuth: float,
use_cutoff: bool = False,
tapering_range: float = 1.5,
calculate_qc_parameter: bool = False,
):
) -> Tuple[Optional[ma.MaskedArray], Optional[ma.MaskedArray]]:
# pylint: disable=too-many-arguments,too-many-locals
"""
Calculates the scaling factor, assign it to ERT instance by row_scaling
Expand Down Expand Up @@ -504,12 +515,12 @@ def apply_decay(


def apply_constant(
row_scaling: object,
grid: object,
row_scaling: RowScaling,
grid: Grid,
value: float,
log_level: LogLevel,
calculate_qc_parameter: bool = False,
):
) -> Tuple[Optional[ma.MaskedArray], Optional[ma.MaskedArray]]:
# pylint: disable=too-many-arguments,too-many-locals
"""
Assign constant value to the scaling factor,
Expand All @@ -535,13 +546,13 @@ def apply_constant(


def apply_from_file(
row_scaling: object,
grid: object,
row_scaling: RowScaling,
grid: Grid,
filename: str,
param_name: str,
log_level: LogLevel,
calculate_qc_parameter: bool = False,
):
) -> Tuple[Optional[ma.MaskedArray], None]:
# pylint: disable=too-many-arguments, too-many-locals
debug_print(
f"Read scaling factors as parameter {param_name}", LogLevel.LEVEL3, log_level
Expand Down Expand Up @@ -642,8 +653,11 @@ def calculate_scaling_factors_in_regions(


def smooth_parameter(
grid, smooth_range_list, scaling_values, active_region_values_used
):
grid: Grid,
smooth_range_list: List[int],
scaling_values: List[float],
active_region_values_used: ma.MaskedArray,
) -> ma.MaskedArray:
"""
Function taking as input a 3D parameter scaling_values and calculates a new
3D parameter scaling_values_smooth using local average within a rectangular window
Expand Down Expand Up @@ -671,13 +685,12 @@ def smooth_parameter(
if active_region_values_used[index0] is not ma.masked:
sumv = 0.0
nval = 0
ilow = max(0, i0 - di)
ihigh = min(i0 + di + 1, nx)
jlow = max(0, j0 - dj)
jhigh = min(j0 + dj + 1, ny)
ilow: int = max(0, i0 - di)
ihigh: int = min(i0 + di + 1, nx)
jlow: int = max(0, j0 - dj)
jhigh: int = min(j0 + dj + 1, ny)
for i in range(ilow, ihigh):
for j in range(jlow, jhigh):
# index = i + j * nx + k * nx * ny
index = k + j * nz + i * nz * ny
if active_region_values_used[index] is not ma.masked:
# Only use values from grid cells that are active
Expand All @@ -691,16 +704,16 @@ def smooth_parameter(


def apply_segment(
row_scaling,
grid,
row_scaling: RowScaling,
grid: Grid,
region_param_dict: Dict[str, Any],
active_segment_list: List[int],
scaling_factor_list: List[float],
smooth_range_list: List[int],
corr_name: str,
log_level: LogLevel = LogLevel.OFF,
calculate_qc_parameter: bool = False,
):
) -> Tuple[Any, None]:
# pylint: disable=too-many-arguments,too-many-locals
"""
Purpose: Use region numbers and list of scaling factors per region to
Expand Down

0 comments on commit b7424ac

Please sign in to comment.