Skip to content

Commit

Permalink
PR respose, minor code refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
andped10 committed Apr 5, 2024
1 parent 4d82937 commit 3bf2204
Show file tree
Hide file tree
Showing 13 changed files with 91 additions and 41 deletions.
2 changes: 1 addition & 1 deletion EasyReflectometry/calculators/bornagain/calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np
from easyCore.Objects.Inferface import ItemContainer

from EasyReflectometry.experiment.model import Model
from EasyReflectometry.experiment import Model
from EasyReflectometry.sample import Layer
from EasyReflectometry.sample import Material
from EasyReflectometry.sample import MaterialMixture
Expand Down
2 changes: 1 addition & 1 deletion EasyReflectometry/calculators/calculator_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from easyCore.Objects.core import ComponentSerializer
from easyCore.Objects.Inferface import ItemContainer

from EasyReflectometry.experiment.model import Model
from EasyReflectometry.experiment import Model
from EasyReflectometry.sample import BaseAssembly
from EasyReflectometry.sample import Layer
from EasyReflectometry.sample import Material
Expand Down
9 changes: 6 additions & 3 deletions EasyReflectometry/calculators/refl1d/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@

from ..wrapper_base import WrapperBase

PADDING_RANGE = 3.5
UPSCALE_FACTOR = 21


class Refl1dWrapper(WrapperBase):
def create_material(self, name: str):
Expand Down Expand Up @@ -168,9 +171,9 @@ def calculate(self, q_array: np.ndarray, model_name: str) -> np.ndarray:
background=self.storage['model'][model_name]['bkg'],
)
q.calc_Qo = np.linspace(
q_array[argmin] - 3.5 * dq_vector_normalized_to_refnx[argmin],
q_array[argmax] + 3.5 * dq_vector_normalized_to_refnx[argmax],
21 * len(q_array),
q_array[argmin] - PADDING_RANGE * dq_vector_normalized_to_refnx[argmin],
q_array[argmax] + PADDING_RANGE * dq_vector_normalized_to_refnx[argmax],
UPSCALE_FACTOR * len(q_array),
)
R = names.Experiment(probe=q, sample=structure).reflectivity()[1]
return R
Expand Down
14 changes: 2 additions & 12 deletions EasyReflectometry/calculators/wrapper_base.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import sys
from abc import abstractmethod
from typing import Callable

import numpy as np

from EasyReflectometry.experiment import constant_resolution_function


class WrapperBase:
def __init__(self):
Expand Down Expand Up @@ -209,14 +210,3 @@ def set_resolution_function(self, resolution_function: Callable[[np.array], floa
:param resolution_function: The resolution function
"""
self._resolution_function = resolution_function


def constant_resolution_function(constant: float) -> Callable[[np.array], float]:
return linear_spline_resolution_function([sys.float_info.min, sys.float_info.max], [constant, constant])


def linear_spline_resolution_function(q_data_points: np.array, resolution_points: np.array) -> Callable[[np.array], float]:
def resolution_function(q: np.array) -> np.array:
return np.interp(q, q_data_points, resolution_points)

return resolution_function
11 changes: 11 additions & 0 deletions EasyReflectometry/experiment/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from .model import Model
from .models import Models
from .resolution_functions import constant_resolution_function
from .resolution_functions import linear_spline_resolution_function

__all__ = (
constant_resolution_function,
linear_spline_resolution_function,
Model,
Models,
)
8 changes: 4 additions & 4 deletions EasyReflectometry/experiment/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,13 @@
from easyCore.Objects.ObjectClasses import BaseObj
from easyCore.Objects.ObjectClasses import Parameter

from EasyReflectometry.calculators.wrapper_base import constant_resolution_function
from EasyReflectometry.sample import BaseAssembly
from EasyReflectometry.sample import Layer
from EasyReflectometry.sample import LayerCollection
from EasyReflectometry.sample import Sample

from .resolution_functions import constant_resolution_function

MODEL_DETAILS = {
'scale': {
'description': 'Scaling of the reflectomety profile',
Expand Down Expand Up @@ -191,9 +192,8 @@ def uid(self) -> int:
@property
def _dict_repr(self) -> dict[str, dict[str, str]]:
"""A simplified dict representation."""
resolution_values = self._resolution_function([0.1, 0.2])
if resolution_values[0] == resolution_values[1]:
resolution = f'{resolution_values[0]} %'
if self._resolution_function.__qualname__.split('.')[0] == 'constant_resolution_function':
resolution = f'{self._resolution_function([0])[0]} %'
else:
resolution = 'function of Q'

Expand Down
14 changes: 3 additions & 11 deletions EasyReflectometry/experiment/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,11 @@
import yaml
from easyCore.Objects.Groups import BaseCollection

from EasyReflectometry.experiment.model import Model
from .model import Model


class Models(BaseCollection):

def __init__(self,
*args: List[Model],
name: str = 'EasyModels',
interface=None,
**kwargs):
def __init__(self, *args: List[Model], name: str = 'EasyModels', interface=None, **kwargs):
super().__init__(name, *args, **kwargs)
self.interface = interface

Expand All @@ -31,10 +26,7 @@ def default(cls, interface=None) -> 'Models':
return cls(model1, model2, interface=interface)

@classmethod
def from_pars(cls,
*args: List[Model],
name: str = 'EasyModels',
interface=None) -> 'Models':
def from_pars(cls, *args: List[Model], name: str = 'EasyModels', interface=None) -> 'Models':
"""
Constructor for the models where models are being given.
Expand Down
36 changes: 36 additions & 0 deletions EasyReflectometry/experiment/resolution_functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from typing import Callable

import numpy as np


def constant_resolution_function(constant: float) -> Callable[[np.array], np.array]:
"""Create a resolution function that is constant across the q range.
:param constant: The constant resolution value.
"""

def _constant(q: np.array) -> np.array:
"""Function that calculates the resolution at a given q value.
The function uses the data points from the encapsulating function and produces a linearly interpolated between them.
"""
return np.ones(len(q)) * constant

return _constant


def linear_spline_resolution_function(q_data_points: np.array, resolution_points: np.array) -> Callable[[np.array], np.array]:
"""Create a resolution function that is linearly interpolated between given data points.
:param q_data_points: The q values at which the resolution is defined.
:param resolution_points: The resolution values at the given q values.
"""

def _linear(q: np.array) -> np.array:
"""Function that calculates the resolution at a given q value.
The function uses the data points from the encapsulating function and produces a linearly interpolated between them.
"""
return np.interp(q, q_data_points, resolution_points)

return _linear
2 changes: 1 addition & 1 deletion EasyReflectometry/fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import scipp as sc
from easyCore.Fitting.Fitting import MultiFitter as easyFitter

from EasyReflectometry.experiment.model import Model
from EasyReflectometry.experiment import Model


class Fitter:
Expand Down
Empty file removed tests/experiment/__init__.py
Empty file.
14 changes: 7 additions & 7 deletions tests/experiment/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@
from numpy.testing import assert_equal

from EasyReflectometry.calculators import CalculatorFactory
from EasyReflectometry.calculators.wrapper_base import constant_resolution_function
from EasyReflectometry.calculators.wrapper_base import linear_spline_resolution_function
from EasyReflectometry.experiment.model import Model
from EasyReflectometry.experiment import Model
from EasyReflectometry.experiment import constant_resolution_function
from EasyReflectometry.experiment import linear_spline_resolution_function
from EasyReflectometry.sample import Layer
from EasyReflectometry.sample import LayerCollection
from EasyReflectometry.sample import Material
Expand Down Expand Up @@ -44,8 +44,8 @@ def test_default(self):
assert_equal(p.background.min, 0.0)
assert_equal(p.background.max, np.Inf)
assert_equal(p.background.fixed, True)
assert p._resolution_function(1) == 5.0
assert p._resolution_function(100) == 5.0
assert p._resolution_function([1]) == 5.0
assert p._resolution_function([100]) == 5.0

def test_from_pars(self):
m1 = Material.from_pars(6.908, -0.278, 'Boron')
Expand Down Expand Up @@ -74,8 +74,8 @@ def test_from_pars(self):
assert_equal(mod.background.min, 0.0)
assert_equal(mod.background.max, np.Inf)
assert_equal(mod.background.fixed, True)
assert mod._resolution_function(1) == 2.0
assert mod._resolution_function(100) == 2.0
assert mod._resolution_function([1]) == 2.0
assert mod._resolution_function([100]) == 2.0

def test_add_item(self):
m1 = Material.from_pars(6.908, -0.278, 'Boron')
Expand Down
18 changes: 18 additions & 0 deletions tests/experiment/test_resolution_functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import numpy as np

from EasyReflectometry.experiment.resolution_functions import constant_resolution_function
from EasyReflectometry.experiment.resolution_functions import linear_spline_resolution_function


def test_constant_resolution_function():
resolution_function = constant_resolution_function(5)
assert np.all(resolution_function([0, 2.5]) == [5, 5])
assert resolution_function([-100]) == 5
assert resolution_function([100]) == 5


def test_linear_spline_resolution_function():
resolution_function = linear_spline_resolution_function([0, 10], [5, 10])
assert np.all(resolution_function([0, 2.5]) == [5, 6.25])
assert resolution_function([-100]) == 5
assert resolution_function([100]) == 10
2 changes: 1 addition & 1 deletion tests/test_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import EasyReflectometry
from EasyReflectometry.calculators import CalculatorFactory
from EasyReflectometry.data import load
from EasyReflectometry.experiment.model import Model
from EasyReflectometry.experiment import Model
from EasyReflectometry.fitting import Fitter
from EasyReflectometry.sample import Layer
from EasyReflectometry.sample import Material
Expand Down

0 comments on commit 3bf2204

Please sign in to comment.