diff --git a/indica/converters/transect.py b/indica/converters/transect.py index d59aef7c..0886324f 100644 --- a/indica/converters/transect.py +++ b/indica/converters/transect.py @@ -5,6 +5,7 @@ import xarray as xr from xarray import DataArray +from indica.utilities import format_coord from .abstractconverter import CoordinateTransform from ..numpy_typing import Coordinates from ..numpy_typing import LabeledArray @@ -61,7 +62,7 @@ def __init__( self.name: str = f"{name}_transect_transform" x1 = np.arange(len(x_positions)) - self.x1: DataArray = DataArray(x1, coords=[(self.x1_name, x1)]) + self.x1: DataArray = format_coord(x1, self.x1_name) self.x2: DataArray = DataArray(None) # TODO: add intersection with first walls to restrict possible coordinates diff --git a/indica/models/diode_filters.py b/indica/models/diode_filters.py index ad645b6c..66d5e140 100644 --- a/indica/models/diode_filters.py +++ b/indica/models/diode_filters.py @@ -13,6 +13,7 @@ import indica.physics as ph from indica.readers.available_quantities import AVAILABLE_QUANTITIES from indica.utilities import assign_datatype +from indica.utilities import format_coord class BremsstrahlungDiode(DiagnosticModel): @@ -72,7 +73,7 @@ def __init__( self.filter_wavelength - self.filter_fwhm * 2, self.filter_wavelength + self.filter_fwhm * 2, ) - self.wavelength = DataArray(wavelength, coords=[("wavelength", wavelength)]) + self.wavelength = format_coord(wavelength, "wavelength") # Transmission filter function transmission = ph.make_window( @@ -81,7 +82,7 @@ def __init__( self.filter_fwhm, window=self.filter_type, ) - self.transmission = DataArray(transmission, coords=[("wavelength", wavelength)]) + self.transmission = DataArray(transmission, coords={"wavelength": wavelength}) def integrate_spectra(self, spectra: DataArray, fit_background: bool = True): """ diff --git a/indica/models/helike_spectroscopy.py b/indica/models/helike_spectroscopy.py index 99b0ff7d..4b83983b 100644 --- a/indica/models/helike_spectroscopy.py +++ b/indica/models/helike_spectroscopy.py @@ -86,7 +86,7 @@ def __init__( mask[(window > mslice.start) & (window < mslice.stop)] = 1 else: mask[:] = 1 - self.window = DataArray(mask, coords=[("wavelength", window)]) + self.window = DataArray(mask, coords={"wavelength": window}) self._get_atomic_data(self.window) self.line_emission: dict diff --git a/indica/models/plasma.py b/indica/models/plasma.py index 1d7fd25d..56d7c80d 100644 --- a/indica/models/plasma.py +++ b/indica/models/plasma.py @@ -133,7 +133,7 @@ def initialize_variables(self, n_rad: int = 41, n_R: int = 100, n_z: int = 100): index = np.arange(n_R) R_midplane = np.linspace(self.R.min(), self.R.max(), n_R) z_midplane = np.full_like(R_midplane, 0.0) - coords_midplane = [("index", format_coord(index, "index"))] + coords_midplane = {"index": index} self.R_midplane = format_dataarray(R_midplane, "R_midplane", coords_midplane) self.z_midplane = format_dataarray(z_midplane, "z_midplane", coords_midplane) @@ -151,7 +151,7 @@ def initialize_variables(self, n_rad: int = 41, n_R: int = 100, n_z: int = 100): element_name.append(_name) element_symbol.append(_symbol) - coords_elem = [("element", list(self.elements))] + coords_elem = {"element": list(self.elements)} self.element_z = format_dataarray(element_z, "atomic_number", coords_elem) self.element_a = format_dataarray(element_a, "atomic_weight", coords_elem) self.element_name = format_dataarray(element_name, "element_name", coords_elem) @@ -170,19 +170,19 @@ def initialize_variables(self, n_rad: int = 41, n_R: int = 100, n_z: int = 100): data3d = np.zeros((nel, nt, nr)) data3d_imp = np.zeros((nimp, nt, nr)) - coords1d_time = [("t", self.t)] - coords2d = [("t", self.t), (self.rho_type, self.rho)] - coords2d_elem = [("element", list(self.elements)), ("t", self.t)] - coords3d = [ - ("element", list(self.elements)), - ("t", self.t), - (self.rho_type, self.rho), - ] - coords3d_imp = [ - ("element", list(self.impurities)), - ("t", self.t), - (self.rho_type, self.rho), - ] + coords1d_time = {"t": self.t} + coords2d = {"t": self.t, self.rho_type: self.rho} + coords2d_elem = {"element": list(self.elements), "t": self.t} + coords3d = { + "element": list(self.elements), + "t": self.t, + self.rho_type: self.rho, + } + coords3d_imp = { + "element": list(self.impurities), + "t": self.t, + self.rho_type: self.rho, + } # Independent plasma quantities self.electron_temperature = format_dataarray( @@ -286,11 +286,11 @@ def initialize_variables(self, n_rad: int = 41, n_R: int = 100, n_z: int = 100): for elem in self.elements: nz = self.element_z.sel(element=elem).values + 1 ion_charge = format_coord(np.arange(nz), "ion_charge") - coords3d_fract = [ - ("t", self.t), - ("rho_poloidal", self.rho), - ("ion_charge", ion_charge), - ] + coords3d_fract = { + "t": self.t, + "rho_poloidal": self.rho, + "ion_charge": ion_charge, + } data3d_fz = np.full((len(self.t), len(self.rho), nz), 0.0) _fz[elem] = format_dataarray( data3d_fz, "fractional_abundance", coords3d_fract, make_copy=True diff --git a/indica/profiles_gauss.py b/indica/profiles_gauss.py index e97460be..77ab3fd0 100644 --- a/indica/profiles_gauss.py +++ b/indica/profiles_gauss.py @@ -43,8 +43,7 @@ def __init__( self.x = np.linspace(0, 1, 15) ** 0.7 self.datatype = datatype if xspl is None: - xspl = np.linspace(0, 1.0, 30) - xspl = DataArray(xspl, coords=[(self.coord, xspl)]) + xspl = format_coord(np.linspace(0, 1.0, 30), self.coord) self.xspl = xspl self.profile_parameters: list = [ "y0", @@ -186,7 +185,7 @@ def gaussian(x, A, B, x_0, w): ) _yspl = self.cubicspline(self.xspl) - coords = [(self.coord, format_coord(self.xspl, self.coord))] + coords = {self.coord: self.xspl} yspl = format_dataarray(_yspl, self.datatype, coords=coords) self.yspl = yspl diff --git a/indica/readers/abstractreader.py b/indica/readers/abstractreader.py index 3f1aaea2..bbe41d61 100644 --- a/indica/readers/abstractreader.py +++ b/indica/readers/abstractreader.py @@ -18,7 +18,6 @@ from indica.numpy_typing import OnlyArray from indica.numpy_typing import RevisionLike from indica.readers.available_quantities import AVAILABLE_QUANTITIES -from indica.utilities import format_coord from indica.utilities import format_dataarray @@ -198,17 +197,18 @@ def get_ppts( "R" ] # necessary because of assign_dataarray... - coords = [ - ("t", database_results["t"]), - ("channel", database_results["channel"]), - ] + coords_chan = {"channel": database_results["channel"]} + coords = { + "t": database_results["t"], + "channel": database_results["channel"], + } rho_poloidal_coords = xr.DataArray( database_results["rho_poloidal_data"], coords=coords ) rho_poloidal_coords = rho_poloidal_coords.sel(t=slice(self._tstart, self._tend)) - rpos_coords = xr.DataArray(database_results["R_data"], coords=[coords[1]]) - zpos_coords = xr.DataArray(database_results["z_data"], coords=[coords[1]]) + rpos_coords = xr.DataArray(database_results["R_data"], coords=coords_chan) + zpos_coords = xr.DataArray(database_results["z_data"], coords=coords_chan) data = {} for quantity in quantities: @@ -821,7 +821,7 @@ def get_astra( # Reorganise coordinate system to match Indica default rho-poloidal t = database_results["t"] - t = DataArray(t, coords=[("t", t)], attrs={"long_name": "t", "units": "s"}) + t = DataArray(t, coords={"t": t}, attrs={"long_name": "t", "units": "s"}) psin = database_results["psin"] rhop_psin = np.sqrt(psin) rhop_interp = np.linspace(0, 1.0, 65) @@ -862,12 +862,12 @@ def get_astra( else: name_coords = [] - coords: list = [("t", t)] + coords: dict = {"t": t} if len(name_coords) > 0: for coord in name_coords: - coords.append((coord, radial_coords[coord])) + coords[coord] = radial_coords[coord] - if len(np.shape(database_results[quantity])) != len(coords): + if len(np.shape(database_results[quantity])) != len(coords.keys()): continue quant_data = self.assign_dataarray( @@ -952,9 +952,9 @@ def assign_dataarray( DataArray with assigned coordinates, transform, error, long_name and units """ # Build coordinate dictionary - coords = [] + coords = {} for dim in dims: - coords.append((dim, format_coord(database_results[dim], dim))) + coords[dim] = database_results[dim] # Build DataArray data with coordinates and long_name + units var_name = self.available_quantities(instrument)[quantity] diff --git a/indica/utilities.py b/indica/utilities.py index b80a4fca..0be02aff 100644 --- a/indica/utilities.py +++ b/indica/utilities.py @@ -320,14 +320,13 @@ def format_coord(data: LabeledArray, var_name: str): var_name Name of the variable to be assigned as coordinate """ - coords = [(var_name, DataArray(data, dims=var_name))] - return format_dataarray(data, var_name, coords) + return format_dataarray(data, var_name, {var_name: data}) def format_dataarray( data: LabeledArray, var_name: str, - coords: List[Tuple[str, Any]] = [], + coords: Dict[str, Any] = {}, make_copy: bool = False, ): """ @@ -344,7 +343,7 @@ def format_dataarray( Returns ------- - Formatted data array + Formatted data array, including attribute assignement (also to coords) """ @@ -353,17 +352,11 @@ def format_dataarray( else: _data = data - """ - old - if len(coords) != 0: - data_array = DataArray(_data, coords=coords) - """ - processed_coords = { - name: coord.data if isinstance(coord, DataArray) else coord - for name, coord in coords - } - if len(coords) != 0: + processed_coords = { + name: coord.data if isinstance(coord, DataArray) else coord + for name, coord in coords.items() + } data_array = DataArray(_data, coords=processed_coords, name=var_name) else: @@ -371,6 +364,8 @@ def format_dataarray( raise ValueError("data must be a DataArray if coordinates are not given") assign_datatype(data_array, var_name) + for dim in data_array.dims: + assign_datatype(data_array.coords[dim], dim) return data_array diff --git a/tests/unit/converters/test_time_dt.py b/tests/unit/converters/test_time_dt.py index 3adee9e5..2a1e61ab 100644 --- a/tests/unit/converters/test_time_dt.py +++ b/tests/unit/converters/test_time_dt.py @@ -14,7 +14,7 @@ class Test_time: nt = 50 time = np.linspace(0, 0.1, nt) values = np.sin(np.linspace(0, np.pi * 3, nt)) + np.random.random(nt) - 0.5 - data = DataArray(values, coords=[("t", time)]) + data = DataArray(values, coords={"t": time}) channels = np.array([0, 1, 2, 3], dtype=int) d = [] for c in channels: