Skip to content

Commit

Permalink
Merge branch 'fix-non-adaptive-index-order' of github.com:oddvarlia/s…
Browse files Browse the repository at this point in the history
…emeio into fix-non-adaptive-index-order
  • Loading branch information
oddvarlia committed Dec 18, 2023
2 parents 45973b5 + 8eff858 commit 722094c
Showing 1 changed file with 47 additions and 38 deletions.
85 changes: 47 additions & 38 deletions semeio/workflows/localisation/local_script_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import math
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Dict, List, Any
from typing import Dict, List, Any, Union, Tuple, Optional

import cwrap
import numpy as np
Expand Down Expand Up @@ -98,7 +98,12 @@ class Decay:
main_range: float
perp_range: float
azimuth: float
grid: object
grid: Grid

def __call__(self, data_index):
# Default behavior of Decay when called as a function
# This is a placeholder; you can define a more meaningful default behavior
return 1.0

def __post_init__(self):
angle = (90.0 - self.azimuth) * math.pi / 180.0
Expand Down Expand Up @@ -134,7 +139,7 @@ def norm_dist_square(self, data_index):
class GaussianDecay(Decay):
cutoff: bool

def __call__(self, data_index):
def __call__(self, data_index: List[int]) -> float:
d2 = super().norm_dist_square(data_index)
if self.cutoff and d2 > 1.0:
return 0.0
Expand Down Expand Up @@ -260,8 +265,8 @@ def write_qc_parameter_surface(
corr_name: str,
surface_scale: bool,
reference_surface_file: str,
param_for_surface: object,
log_level=LogLevel.OFF,
param_for_surface: ma.MaskedArray,
log_level: LogLevel = LogLevel.OFF,
) -> None:
# pylint: disable=too-many-arguments

Expand Down Expand Up @@ -365,9 +370,11 @@ def build_decay_object(
grid: object,
use_cutoff: bool,
tapering_range: float = 1.5,
) -> Any:
) -> Decay:
# pylint: disable=too-many-arguments
decay_obj = None
decay_obj: Union[
GaussianDecay, ExponentialDecay, ConstGaussianDecay, ConstExponentialDecay
]
if method == "gaussian_decay":
decay_obj = GaussianDecay(
ref_pos,
Expand Down Expand Up @@ -419,7 +426,9 @@ def build_decay_object(
return decay_obj


def calculate_scaling_vector_fields(grid: object, decay_obj: object):
def calculate_scaling_vector_fields(
grid: object, decay_obj: Union[Decay, ConstantScalingFactor]
) -> ma.MaskedArray:
assert isinstance(grid, Grid)
nx = grid.getNX()
ny = grid.getNY()
Expand All @@ -439,7 +448,9 @@ def calculate_scaling_vector_fields(grid: object, decay_obj: object):
return scaling_vector


def calculate_scaling_vector_surface(grid: object, decay_obj: object):
def calculate_scaling_vector_surface(
grid: object, decay_obj: Union[ConstantScalingFactor, Decay]
) -> ma.MaskedArray:
assert isinstance(grid, Surface)
nx = grid.getNX()
ny = grid.getNY()
Expand All @@ -457,16 +468,16 @@ def calculate_scaling_vector_surface(grid: object, decay_obj: object):

def apply_decay(
method: str,
row_scaling: object,
grid: object,
row_scaling: RowScaling,
grid: Grid,
ref_pos: list,
main_range: float,
perp_range: float,
azimuth: float,
use_cutoff: bool = False,
tapering_range: float = 1.5,
calculate_qc_parameter: bool = False,
):
) -> Tuple[Optional[ma.MaskedArray], Optional[ma.MaskedArray]]:
# pylint: disable=too-many-arguments,too-many-locals
"""
Calculates the scaling factor, assign it to ERT instance by row_scaling
Expand Down Expand Up @@ -504,12 +515,12 @@ def apply_decay(


def apply_constant(
row_scaling: object,
grid: object,
row_scaling: RowScaling,
grid: Grid,
value: float,
log_level: LogLevel,
calculate_qc_parameter: bool = False,
):
) -> Tuple[Optional[ma.MaskedArray], Optional[ma.MaskedArray]]:
# pylint: disable=too-many-arguments,too-many-locals
"""
Assign constant value to the scaling factor,
Expand All @@ -535,13 +546,13 @@ def apply_constant(


def apply_from_file(
row_scaling: object,
grid: object,
row_scaling: RowScaling,
grid: Grid,
filename: str,
param_name: str,
log_level: LogLevel,
calculate_qc_parameter: bool = False,
):
) -> Tuple[Optional[ma.MaskedArray], None]:
# pylint: disable=too-many-arguments, too-many-locals
debug_print(
f"Read scaling factors as parameter {param_name}", LogLevel.LEVEL3, log_level
Expand Down Expand Up @@ -642,8 +653,11 @@ def calculate_scaling_factors_in_regions(


def smooth_parameter(
grid, smooth_range_list, scaling_values, active_region_values_used
):
grid: Grid,
smooth_range_list: List[int],
scaling_values: List[float],
active_region_values_used: ma.MaskedArray,
) -> ma.MaskedArray:
"""
Function taking as input a 3D parameter scaling_values and calculates a new
3D parameter scaling_values_smooth using local average within a rectangular window
Expand Down Expand Up @@ -671,13 +685,12 @@ def smooth_parameter(
if active_region_values_used[index0] is not ma.masked:
sumv = 0.0
nval = 0
ilow = max(0, i0 - di)
ihigh = min(i0 + di + 1, nx)
jlow = max(0, j0 - dj)
jhigh = min(j0 + dj + 1, ny)
ilow: int = max(0, i0 - di)
ihigh: int = min(i0 + di + 1, nx)
jlow: int = max(0, j0 - dj)
jhigh: int = min(j0 + dj + 1, ny)
for i in range(ilow, ihigh):
for j in range(jlow, jhigh):
# index = i + j * nx + k * nx * ny
index = k + j * nz + i * nz * ny
if active_region_values_used[index] is not ma.masked:
# Only use values from grid cells that are active
Expand All @@ -691,16 +704,16 @@ def smooth_parameter(


def apply_segment(
row_scaling,
grid,
region_param_dict,
active_segment_list,
scaling_factor_list,
smooth_range_list,
corr_name,
log_level=LogLevel.OFF,
row_scaling: RowScaling,
grid: Grid,
region_param_dict: Dict[str, Any],
active_segment_list: List[int],
scaling_factor_list: List[float],
smooth_range_list: List[int],
corr_name: str,
log_level: LogLevel = LogLevel.OFF,
calculate_qc_parameter: bool = False,
):
) -> Tuple[Any, None]:
# pylint: disable=too-many-arguments,too-many-locals
"""
Purpose: Use region numbers and list of scaling factors per region to
Expand Down Expand Up @@ -749,10 +762,6 @@ def apply_segment(
# Assign values to row_scaling object
row_scaling.assign_vector(scaling_values)

# for index in range(data_size):
# global_index = grid.global_index(active_index=index)
# row_scaling[index] = scaling_values[global_index]

not_defined_in_region_param = []
for n in active_segment_list:
if n not in regions_in_param:
Expand Down

0 comments on commit 722094c

Please sign in to comment.