Skip to content

Commit

Permalink
Merge branch 'jussihakosalo/indica311' of github.com:indica-mcf/Indic…
Browse files Browse the repository at this point in the history
…a into jussihakosalo/indica311
  • Loading branch information
Jussi Hakosalo committed Jun 24, 2024
2 parents 4224e74 + 2daa0f8 commit f840933
Show file tree
Hide file tree
Showing 8 changed files with 51 additions and 55 deletions.
3 changes: 2 additions & 1 deletion indica/converters/transect.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import xarray as xr
from xarray import DataArray

from indica.utilities import format_coord
from .abstractconverter import CoordinateTransform
from ..numpy_typing import Coordinates
from ..numpy_typing import LabeledArray
Expand Down Expand Up @@ -61,7 +62,7 @@ def __init__(
self.name: str = f"{name}_transect_transform"

x1 = np.arange(len(x_positions))
self.x1: DataArray = DataArray(x1, coords=[(self.x1_name, x1)])
self.x1: DataArray = format_coord(x1, self.x1_name)
self.x2: DataArray = DataArray(None)

# TODO: add intersection with first walls to restrict possible coordinates
Expand Down
5 changes: 3 additions & 2 deletions indica/models/diode_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import indica.physics as ph
from indica.readers.available_quantities import AVAILABLE_QUANTITIES
from indica.utilities import assign_datatype
from indica.utilities import format_coord


class BremsstrahlungDiode(DiagnosticModel):
Expand Down Expand Up @@ -72,7 +73,7 @@ def __init__(
self.filter_wavelength - self.filter_fwhm * 2,
self.filter_wavelength + self.filter_fwhm * 2,
)
self.wavelength = DataArray(wavelength, coords=[("wavelength", wavelength)])
self.wavelength = format_coord(wavelength, "wavelength")

# Transmission filter function
transmission = ph.make_window(
Expand All @@ -81,7 +82,7 @@ def __init__(
self.filter_fwhm,
window=self.filter_type,
)
self.transmission = DataArray(transmission, coords=[("wavelength", wavelength)])
self.transmission = DataArray(transmission, coords={"wavelength": wavelength})

def integrate_spectra(self, spectra: DataArray, fit_background: bool = True):
"""
Expand Down
2 changes: 1 addition & 1 deletion indica/models/helike_spectroscopy.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def __init__(
mask[(window > mslice.start) & (window < mslice.stop)] = 1
else:
mask[:] = 1
self.window = DataArray(mask, coords=[("wavelength", window)])
self.window = DataArray(mask, coords={"wavelength": window})
self._get_atomic_data(self.window)

self.line_emission: dict
Expand Down
40 changes: 20 additions & 20 deletions indica/models/plasma.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def initialize_variables(self, n_rad: int = 41, n_R: int = 100, n_z: int = 100):
index = np.arange(n_R)
R_midplane = np.linspace(self.R.min(), self.R.max(), n_R)
z_midplane = np.full_like(R_midplane, 0.0)
coords_midplane = [("index", format_coord(index, "index"))]
coords_midplane = {"index": index}
self.R_midplane = format_dataarray(R_midplane, "R_midplane", coords_midplane)
self.z_midplane = format_dataarray(z_midplane, "z_midplane", coords_midplane)

Expand All @@ -151,7 +151,7 @@ def initialize_variables(self, n_rad: int = 41, n_R: int = 100, n_z: int = 100):
element_name.append(_name)
element_symbol.append(_symbol)

coords_elem = [("element", list(self.elements))]
coords_elem = {"element": list(self.elements)}
self.element_z = format_dataarray(element_z, "atomic_number", coords_elem)
self.element_a = format_dataarray(element_a, "atomic_weight", coords_elem)
self.element_name = format_dataarray(element_name, "element_name", coords_elem)
Expand All @@ -170,19 +170,19 @@ def initialize_variables(self, n_rad: int = 41, n_R: int = 100, n_z: int = 100):
data3d = np.zeros((nel, nt, nr))
data3d_imp = np.zeros((nimp, nt, nr))

coords1d_time = [("t", self.t)]
coords2d = [("t", self.t), (self.rho_type, self.rho)]
coords2d_elem = [("element", list(self.elements)), ("t", self.t)]
coords3d = [
("element", list(self.elements)),
("t", self.t),
(self.rho_type, self.rho),
]
coords3d_imp = [
("element", list(self.impurities)),
("t", self.t),
(self.rho_type, self.rho),
]
coords1d_time = {"t": self.t}
coords2d = {"t": self.t, self.rho_type: self.rho}
coords2d_elem = {"element": list(self.elements), "t": self.t}
coords3d = {
"element": list(self.elements),
"t": self.t,
self.rho_type: self.rho,
}
coords3d_imp = {
"element": list(self.impurities),
"t": self.t,
self.rho_type: self.rho,
}

# Independent plasma quantities
self.electron_temperature = format_dataarray(
Expand Down Expand Up @@ -286,11 +286,11 @@ def initialize_variables(self, n_rad: int = 41, n_R: int = 100, n_z: int = 100):
for elem in self.elements:
nz = self.element_z.sel(element=elem).values + 1
ion_charge = format_coord(np.arange(nz), "ion_charge")
coords3d_fract = [
("t", self.t),
("rho_poloidal", self.rho),
("ion_charge", ion_charge),
]
coords3d_fract = {
"t": self.t,
"rho_poloidal": self.rho,
"ion_charge": ion_charge,
}
data3d_fz = np.full((len(self.t), len(self.rho), nz), 0.0)
_fz[elem] = format_dataarray(
data3d_fz, "fractional_abundance", coords3d_fract, make_copy=True
Expand Down
5 changes: 2 additions & 3 deletions indica/profiles_gauss.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,7 @@ def __init__(
self.x = np.linspace(0, 1, 15) ** 0.7
self.datatype = datatype
if xspl is None:
xspl = np.linspace(0, 1.0, 30)
xspl = DataArray(xspl, coords=[(self.coord, xspl)])
xspl = format_coord(np.linspace(0, 1.0, 30), self.coord)
self.xspl = xspl
self.profile_parameters: list = [
"y0",
Expand Down Expand Up @@ -186,7 +185,7 @@ def gaussian(x, A, B, x_0, w):
)
_yspl = self.cubicspline(self.xspl)

coords = [(self.coord, format_coord(self.xspl, self.coord))]
coords = {self.coord: self.xspl}
yspl = format_dataarray(_yspl, self.datatype, coords=coords)
self.yspl = yspl

Expand Down
26 changes: 13 additions & 13 deletions indica/readers/abstractreader.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from indica.numpy_typing import OnlyArray
from indica.numpy_typing import RevisionLike
from indica.readers.available_quantities import AVAILABLE_QUANTITIES
from indica.utilities import format_coord
from indica.utilities import format_dataarray


Expand Down Expand Up @@ -198,17 +197,18 @@ def get_ppts(
"R"
] # necessary because of assign_dataarray...

coords = [
("t", database_results["t"]),
("channel", database_results["channel"]),
]
coords_chan = {"channel": database_results["channel"]}
coords = {
"t": database_results["t"],
"channel": database_results["channel"],
}
rho_poloidal_coords = xr.DataArray(
database_results["rho_poloidal_data"], coords=coords
)
rho_poloidal_coords = rho_poloidal_coords.sel(t=slice(self._tstart, self._tend))

rpos_coords = xr.DataArray(database_results["R_data"], coords=[coords[1]])
zpos_coords = xr.DataArray(database_results["z_data"], coords=[coords[1]])
rpos_coords = xr.DataArray(database_results["R_data"], coords=coords_chan)
zpos_coords = xr.DataArray(database_results["z_data"], coords=coords_chan)

data = {}
for quantity in quantities:
Expand Down Expand Up @@ -821,7 +821,7 @@ def get_astra(

# Reorganise coordinate system to match Indica default rho-poloidal
t = database_results["t"]
t = DataArray(t, coords=[("t", t)], attrs={"long_name": "t", "units": "s"})
t = DataArray(t, coords={"t": t}, attrs={"long_name": "t", "units": "s"})
psin = database_results["psin"]
rhop_psin = np.sqrt(psin)
rhop_interp = np.linspace(0, 1.0, 65)
Expand Down Expand Up @@ -862,12 +862,12 @@ def get_astra(
else:
name_coords = []

coords: list = [("t", t)]
coords: dict = {"t": t}
if len(name_coords) > 0:
for coord in name_coords:
coords.append((coord, radial_coords[coord]))
coords[coord] = radial_coords[coord]

if len(np.shape(database_results[quantity])) != len(coords):
if len(np.shape(database_results[quantity])) != len(coords.keys()):
continue

quant_data = self.assign_dataarray(
Expand Down Expand Up @@ -952,9 +952,9 @@ def assign_dataarray(
DataArray with assigned coordinates, transform, error, long_name and units
"""
# Build coordinate dictionary
coords = []
coords = {}
for dim in dims:
coords.append((dim, format_coord(database_results[dim], dim)))
coords[dim] = database_results[dim]

# Build DataArray data with coordinates and long_name + units
var_name = self.available_quantities(instrument)[quantity]
Expand Down
23 changes: 9 additions & 14 deletions indica/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,14 +320,13 @@ def format_coord(data: LabeledArray, var_name: str):
var_name
Name of the variable to be assigned as coordinate
"""
coords = [(var_name, DataArray(data, dims=var_name))]
return format_dataarray(data, var_name, coords)
return format_dataarray(data, var_name, {var_name: data})


def format_dataarray(
data: LabeledArray,
var_name: str,
coords: List[Tuple[str, Any]] = [],
coords: Dict[str, Any] = {},
make_copy: bool = False,
):
"""
Expand All @@ -344,7 +343,7 @@ def format_dataarray(
Returns
-------
Formatted data array
Formatted data array, including attribute assignement (also to coords)
"""

Expand All @@ -353,24 +352,20 @@ def format_dataarray(
else:
_data = data

"""
old
if len(coords) != 0:
data_array = DataArray(_data, coords=coords)
"""
processed_coords = {
name: coord.data if isinstance(coord, DataArray) else coord
for name, coord in coords
}

if len(coords) != 0:
processed_coords = {
name: coord.data if isinstance(coord, DataArray) else coord
for name, coord in coords.items()
}
data_array = DataArray(_data, coords=processed_coords, name=var_name)

else:
if type(_data) != DataArray:
raise ValueError("data must be a DataArray if coordinates are not given")

assign_datatype(data_array, var_name)
for dim in data_array.dims:
assign_datatype(data_array.coords[dim], dim)

return data_array

Expand Down
2 changes: 1 addition & 1 deletion tests/unit/converters/test_time_dt.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class Test_time:
nt = 50
time = np.linspace(0, 0.1, nt)
values = np.sin(np.linspace(0, np.pi * 3, nt)) + np.random.random(nt) - 0.5
data = DataArray(values, coords=[("t", time)])
data = DataArray(values, coords={"t": time})
channels = np.array([0, 1, 2, 3], dtype=int)
d = []
for c in channels:
Expand Down

0 comments on commit f840933

Please sign in to comment.