Skip to content

Commit

Permalink
Refactored across the board to follow new convention.
Browse files Browse the repository at this point in the history
  • Loading branch information
marcosertoli committed Jul 24, 2024
1 parent 13c75cc commit 1a5e52c
Show file tree
Hide file tree
Showing 9 changed files with 29 additions and 139 deletions.
38 changes: 0 additions & 38 deletions indica/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
"""

from itertools import filterfalse
from numbers import Number
from typing import Any
from typing import Callable
Expand Down Expand Up @@ -715,43 +714,6 @@ def equilibrium(self):
if hasattr(self._obj.attrs["transform"], "equilibrium"):
del self._obj.attrs["transform"].equilibrium

@property
def with_ignored_data(self) -> xr.DataArray:
    """Return the full data, with channels dropped at read-in restored.

    If the ``dropped`` attribute is absent the wrapped object is returned
    unchanged; otherwise a copy is built with the dropped channels (and,
    when present, their errors) written back in place, and the ``dropped``
    attribute removed from the copy.
    """
    if "dropped" not in self._obj.attrs:
        return self._obj
    dim = self.drop_dim
    dropped = self._obj.attrs["dropped"]
    # Same selector is used for both the data and its error.
    selector = {dim: dropped.coords[dim]}
    restored = self._obj.copy()
    restored.loc[selector] = dropped
    if "error" in self._obj.attrs:
        # Copy before mutating so the original error array is untouched.
        err = restored.attrs["error"].copy()
        err.loc[selector] = dropped.attrs["error"]
        restored.attrs["error"] = err
    del restored.attrs["dropped"]
    return restored

@property
def drop_dim(self) -> Optional[str]:
    """The dimension, if any, which contains dropped channels.

    Found as the first dimension whose coordinates differ between the
    data and its ``dropped`` attribute; ``None`` when nothing was dropped.
    """
    if "dropped" not in self._obj.attrs:
        return None
    dropped = self._obj.attrs["dropped"]
    differing = (
        dim
        for dim in self._obj.dims
        if not self._obj.coords[dim].equals(dropped.coords[dim])
    )
    # next() with no default: like the previous filterfalse-based code,
    # this raises StopIteration if every dimension's coords match.
    return str(next(differing))


@xr.register_dataset_accessor("indica")
class InDiCADatasetAccessor:
Expand Down
6 changes: 4 additions & 2 deletions indica/models/diode_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,13 +132,15 @@ def integrate_spectra(self, spectra: DataArray, fit_background: bool = True):
spectra_to_integrate = _spectra
spectra_to_integrate_err = _spectra_err

spectra_to_integrate.attrs["error"] = spectra_to_integrate_err
spectra_to_integrate = spectra_to_integrate.assign_coords(
error=(spectra_to_integrate.dims, spectra_to_integrate_err.data)
)

integral = (spectra_to_integrate * transmission).sum("wavelength")
integral_err = (np.sqrt((spectra_to_integrate_err * transmission) ** 2)).sum(
"wavelength"
)
integral.attrs["error"] = integral_err
integral = integral.assign_coords(error=(integral.dims, integral_err.data))

return spectra_to_integrate, integral

Expand Down
8 changes: 6 additions & 2 deletions indica/plotters/plot_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -589,14 +589,18 @@ def compare_pulses(
_error = raw_data["cxff_pi"]["ti"].error.sel(channel=chan)
_error = xr.where(_error > 0, _error, np.nan)
raw_data["cxff_pi"]["ti"] = _data
raw_data["cxff_pi"]["ti"].attrs["error"] = _error
raw_data["cxff_pi"]["ti"] = raw_data["cxff_pi"]["ti"].assign_coords(
error=(raw_data["cxff_pi"]["ti"].dims, _error.data)
)

_data = raw_data["cxff_pi"]["vtor"].sel(channel=chan)
_data = xr.where(_data > 0, _data, np.nan)
_error = raw_data["cxff_pi"]["vtor"].error.sel(channel=chan)
_error = xr.where(_error > 0, _error, np.nan)
raw_data["cxff_pi"]["vtor"] = _data
raw_data["cxff_pi"]["vtor"].attrs["error"] = _error
raw_data["cxff_pi"]["vtor"] = raw_data["cxff_pi"]["vtor"].assign_coords(
error=(raw_data["cxff_pi"]["vtor"].dims, _error.data)
)

if "hnbi1" in raw_data.keys():
raw_data["nbi"] = {
Expand Down
4 changes: 2 additions & 2 deletions indica/plotters/plot_time_evolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,9 +120,9 @@ def plot_data(data, quantity: str, pulse: int, tplot: float, key="raw", color=No
_data = data[key][quantity][pulse]
tslice = slice(_data.t.min().values, _data.t.max().values)
if "error" not in _data.attrs:
_data.attrs["error"] = xr.full_like(_data, 0.0)
_data = _data.assign_coords(error=(_data.dims, xr.full_like(_data, 0.0).data))
if "stdev" not in _data.attrs:
_data.attrs["stdev"] = xr.full_like(_data, 0.0)
_data = _data.assign_coords(stdev=(_data.dims, xr.full_like(_data, 0.0).data))
_err = np.sqrt(_data.error**2 + _data.stdev**2)
_err = xr.where(_err / _data.values < 1.0, _err, 0.0)
if len(_data.dims) > 1:
Expand Down
9 changes: 0 additions & 9 deletions indica/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,16 +276,7 @@ def input_check(
) or isinstance(var_to_check, bool):
return

# Handles dropped channels, if present
sliced_var_to_check = deepcopy(var_to_check)
if (
isinstance(var_to_check, (DataArray, Dataset))
and "dropped" in var_to_check.attrs
):
dropped_coords = var_to_check.attrs["dropped"].coords
for icoord in dropped_coords.keys():
dropped_coord = dropped_coords[icoord]
sliced_var_to_check = var_to_check.drop_sel({icoord: dropped_coord})

if np.any(np.isnan(sliced_var_to_check)):
raise ValueError(f"{var_name} cannot contain any NaNs.")
Expand Down
8 changes: 6 additions & 2 deletions indica/workflows/load_modelling_plasma.py
Original file line number Diff line number Diff line change
Expand Up @@ -554,9 +554,13 @@ def plot_data_bckc_comparison(
tslice_binned = tslice

if "error" not in _binned.attrs:
_binned.attrs["error"] = xr.full_like(_binned, 0.0)
_binned = _binned.assign_coords(
error=(_binned.dims, xr.full_like(_binned, 0.0).data)
)
if "stdev" not in _binned.attrs:
_binned.attrs["stdev"] = xr.full_like(_binned, 0.0)
_binned = _binned.assign_coords(
stdev=(_binned.dims, xr.full_like(_binned, 0.0).data)
)
err = np.sqrt(_binned.error**2 + _binned.stdev**2)
err = xr.where(err / _binned.values < 1.0, err, 0.0)

Expand Down
4 changes: 3 additions & 1 deletion indica/workflows/run_tomo_1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,9 @@ def example_tomo(
tomo.emiss_err,
coords=[("t", tomo.tvec), ("rho_poloidal", tomo.rho_grid_centers)],
)
inverted_emissivity.attrs["error"] = inverted_error
inverted_emissivity = inverted_emissivity.assign_coords(
error=(inverted_emissivity.dims, inverted_error.data)
)

data_tomo = brightness
bckc_tomo = DataArray(tomo.backprojection, coords=data_tomo.coords)
Expand Down
11 changes: 7 additions & 4 deletions indica/workflows/zeff_workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,9 @@ def calculate_zeff(
spectra_to_integrate = None

if not hasattr(filter_data, "error"):
filter_data.attrs["error"] = filter_data * default_perc_err
filter_data = filter_data.assign_coords(
error=(filter_data.dims, (filter_data * default_perc_err).data)
)

print("Calculate LOS-averaged Zeff")
zeff_los_avrg = calculate_zeff_los_averaged(
Expand Down Expand Up @@ -325,7 +327,7 @@ def calculate_zeff_profile(
tomo_result["profile"]["sym_emissivity_err"],
coords=coords,
)
emissivity.attrs["error"] = _error
emissivity = emissivity.assign_coords(error=(emissivity.dims, _error.data))

wlnght = filter_wavelength
_te = te_fit.interp(rho_poloidal=emissivity.rho_poloidal)
Expand All @@ -351,8 +353,9 @@ def calculate_zeff_profile(
bremsstrahlung=emissivity + emissivity.error,
gaunt_approx="callahan",
)
zeff_profile.attrs["error"] = np.abs(zeff_up - zeff_lo)

zeff_profile = zeff_profile.assign_coords(
error=(zeff_profile.dims, np.abs(zeff_up - zeff_lo).data)
)
return zeff_profile, tomo


Expand Down
80 changes: 1 addition & 79 deletions tests/unit/converters/test_time_dt.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,7 @@ class Test_time:
data = xr.concat(d, "chan").assign_coords(chan=channels)
error = deepcopy(data)
error.values = np.sqrt(np.abs(data.values))
dropped = xr.full_like(data, np.nan)
provenance = {"none": None}
partial_provenance = {"none": None}
data.attrs = {
"error": error,
"provenance": provenance,
"partial_provenance": partial_provenance,
}
data = data.assign_coords(error=(data.dims, error.data))

dt_data = (data.t[1] - data.t[0]).values

Expand Down Expand Up @@ -140,77 +133,6 @@ def test_interpolation(self):
assert np.all(_data.error.t >= self.data.error.t.min())
assert _dt == approx(dt)

def test_binning_dropped(self):
    """Binning to a coarser time axis (3x the native step) also rebins the
    ``error`` and ``dropped`` arrays, keeping all three on the same grid.
    """
    dt = self.dt_data * 3.0

    tstart = (self.data.t[0] + 5 * self.dt_data).values
    tend = (self.data.t[-1] - 10 * self.dt_data).values

    # Mark one channel as dropped: stash its values in the "dropped"
    # attribute and blank it out in the main array.
    chan_to_drop = 1
    data = deepcopy(self.data)
    data.attrs["dropped"] = self.dropped
    data.dropped.loc[dict(chan=chan_to_drop)] = data.sel(chan=chan_to_drop)
    data.loc[dict(chan=chan_to_drop)] = np.full_like(
        data.sel(chan=chan_to_drop), np.nan
    )

    # Let any failure propagate directly to pytest (the previous
    # try/except that only re-raised was a no-op).
    _data = convert_in_time_dt(tstart, tend, dt, data)

    _dt = (_data.t[1] - _data.t[0]).values
    assert np.all(_data.t <= data.t.max())
    assert np.all(_data.t >= data.t.min())
    assert _dt == approx(dt)

    _dt = (_data.error.t[1] - _data.error.t[0]).values
    assert np.all(_data.error.t <= data.error.t.max())
    assert np.all(_data.error.t >= data.error.t.min())
    assert _dt == approx(dt)

    _dt = (_data.dropped.t[1] - _data.dropped.t[0]).values
    assert np.all(_data.dropped.t <= data.dropped.t.max())
    assert np.all(_data.dropped.t >= data.dropped.t.min())
    assert _dt == approx(dt)

def test_interpolation_dropped(self):
    """Interpolating to a finer time axis (1/3 of the native step) also
    interpolates the ``error`` and ``dropped`` arrays onto the same grid.
    """
    dt = self.dt_data / 3.0

    tstart = (self.data.t[0] + 5 * self.dt_data).values
    tend = (self.data.t[-1] - 10 * self.dt_data).values

    # Mark one channel as dropped: stash its values in the "dropped"
    # attribute and blank it out in the main array.
    chan_to_drop = 1
    data = deepcopy(self.data)
    data.attrs["dropped"] = self.dropped
    data.dropped.loc[dict(chan=chan_to_drop)] = data.sel(chan=chan_to_drop)
    data.loc[dict(chan=chan_to_drop)] = np.full_like(
        data.sel(chan=chan_to_drop), np.nan
    )

    # Let any failure propagate directly to pytest (the previous
    # try/except that only re-raised was a no-op).
    _data = convert_in_time_dt(tstart, tend, dt, data)

    _dt = (_data.t[1] - _data.t[0]).values
    assert np.all(_data.t <= data.t.max())
    assert np.all(_data.t >= data.t.min())
    assert _dt == approx(dt)

    _dt = (_data.error.t[1] - _data.error.t[0]).values
    assert np.all(_data.error.t <= data.error.t.max())
    assert np.all(_data.error.t >= data.error.t.min())
    assert _dt == approx(dt)

    _dt = (_data.dropped.t[1] - _data.dropped.t[0]).values
    assert np.all(_data.dropped.t <= data.dropped.t.max())
    assert np.all(_data.dropped.t >= data.dropped.t.min())
    assert _dt == approx(dt)

def test_wrong_start_time(self):
"""Checks start time wrongly set"""
dt = self.dt_data
Expand Down

0 comments on commit 1a5e52c

Please sign in to comment.