Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Marcosertoli/dev #292

Merged
merged 11 commits into from
Oct 12, 2023
2 changes: 1 addition & 1 deletion indica/converters/time.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,7 @@ def bin_in_time_dt(tstart: float, tend: float, dt: float, data: DataArray) -> Da
-------
:
Array like the input, but binned along the time axis.

TODO: add possibility of doing 50% overlap of time bins!
"""
check_bounds_bin(tstart, tend, dt, data)
tlabels = get_tlabels_dt(tstart, tend, dt)
Expand Down
95 changes: 59 additions & 36 deletions indica/operators/spline_fit_easy.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,63 +99,86 @@ def residuals(yknots):


def spline_fit_ts(
pulse: int,
tstart: float = 0.0,
tend: float = 0.2,
plot: bool = False,
pulse: int = 11314,
tstart: float = 0.03,
tend: float = 0.1,
dt: float = 0.01,
quantity: str = "te",
R_shift: float = 0.0,
knots: list = None,
plot: bool = True,
):
st40 = ReadST40(pulse, tstart=tstart, tend=tend, dt=dt)
st40(["ts"], R_shift=R_shift)

if quantity == "te" and knots is None:
knots = [0, 0.3, 0.6, 0.8, 1.1]
if quantity == "ne" and knots is None:
knots = [0, 0.3, 0.6, 0.8, 0.95, 1.1]
data_all = st40.raw_data["ts"][quantity]
t = data_all.t
transform = data_all.transform
transform.convert_to_rho_theta(t=data_all.t)

ST40 = ReadST40(pulse, tstart=tstart, tend=tend)
ST40(["ts"])

Te_data_all = ST40.binned_data["ts"]["te"]
t = ST40.binned_data["ts"]["te"].t
transform = ST40.binned_data["ts"]["te"].transform
R = transform.R
Rmag = transform.equilibrium.rmag.interp(t=t)

# Fit all available TS data
ind = np.full_like(Te_data_all, True)
ind = np.full_like(data_all, True)
rho = xr.where(ind, transform.rho, np.nan)
Te_data = xr.where(ind, Te_data_all, np.nan)
Te_err = xr.where(ind, Te_data_all.error, np.nan)
Te_fit = fit_profile(
rho, Te_data, Te_err, knots=[0, 0.3, 0.5, 0.75, 0.95, 1.05], virtual_knots=False
)
data = xr.where(ind, data_all, np.nan)
err = xr.where(ind, data_all.error, np.nan)
fit = fit_profile(rho, data, err, knots=knots, virtual_knots=False)

# Use only HFS channels
ind = R <= Rmag
rho = xr.where(ind, transform.rho, np.nan)
Te_data = xr.where(ind, Te_data_all, np.nan)
Te_err = xr.where(ind, Te_data_all.error, np.nan)
Te_fit_hfs = fit_profile(rho, Te_data, Te_err, virtual_knots=True)
rho_hfs = xr.where(ind, transform.rho, np.nan)
data_hfs = xr.where(ind, data_all, np.nan)
err_hfs = xr.where(ind, data_all.error, np.nan)
fit_hfs = fit_profile(rho_hfs, data_hfs, err_hfs, knots=knots, virtual_knots=True)

# Use only LFS channels
ind = R >= Rmag
rho_lfs = xr.where(ind, transform.rho, np.nan)
data_lfs = xr.where(ind, data_all, np.nan)
err_lfs = xr.where(ind, data_all.error, np.nan)
fit_lfs = fit_profile(rho_lfs, data_lfs, err_lfs, knots=knots, virtual_knots=True)

if plot:
for t in Te_data_all.t:
for t in data_all.t:
plt.ioff()
plt.errorbar(
Te_data_all.transform.rho.sel(t=t),
Te_data_all.sel(t=t),
Te_data_all.error.sel(t=t),
rho_hfs.sel(t=t),
data_hfs.sel(t=t),
err_hfs.sel(t=t),
marker="o",
label="data HFS",
color="blue",
)
plt.errorbar(
rho_lfs.sel(t=t),
data_lfs.sel(t=t),
err_lfs.sel(t=t),
marker="o",
label="data",
label="data LFS",
color="red",
)
Te_fit.sel(t=t).plot(
linewidth=5, alpha=0.5, color="orange", label="spline fit all"
fit.sel(t=t).plot(
linewidth=5, alpha=0.5, color="black", label="spline fit all"
)
Te_fit_hfs.sel(t=t).plot(
linewidth=5, alpha=0.5, color="red", label="spline fit HFS"
fit_lfs.sel(t=t).plot(
linewidth=5, alpha=0.5, color="red", label="spline fit LFS"
)
fit_hfs.sel(t=t).plot(
linewidth=5, alpha=0.5, color="blue", label="spline fit HFS"
)
plt.legend()
plt.show()

return Te_data_all, Te_fit
return data_all, fit


if __name__ == "__main__":
spline_fit_ts(
10619,
tstart=0.0,
tend=0.2,
plot=True,
)
plt.ioff()
spline_fit_ts(11089, quantity="ne")
plt.show()
1 change: 1 addition & 0 deletions indica/readers/available_quantities.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""
Quantities that can be read with the current abstract reader implementation
TODO: change the tuple to DataArray (long_name, units) - see examples in abstractreader
"""

from typing import Dict
Expand Down
10 changes: 7 additions & 3 deletions indica/readers/read_st40.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
"sxrc_xy1": {"brightness": (0, np.inf)},
"sxrc_xy2": {"brightness": (0, np.inf)},
"blom_xy1": {"brightness": (0, np.inf)},
"ts": {"te": (0, 10.0e3), "ne": (0, 1.0e21)},
"ts": {"te": (0, np.inf), "ne": (0, np.inf)},
"pi": {"spectra": (0, np.inf)},
"tws_c": {"spectra": (0, np.inf)},
}
Expand Down Expand Up @@ -91,8 +91,12 @@ def __init__(
self.tend = tend
self.dt = dt

self.reader = ST40Reader(pulse, tstart - 0.05, tend + 0.05, tree=tree)
self.reader_equil = ST40Reader(pulse, tstart - 0.1, tend + 0.1, tree=tree)
_tend = tend + dt * 2
_tstart = tstart - dt * 2
if _tstart < 0:
_tstart = 0.0
self.reader = ST40Reader(pulse, _tstart, _tend, tree=tree)
self.reader_equil = ST40Reader(pulse, _tstart, _tend, tree=tree)

self.equilibrium: Equilibrium
self.raw_data: dict = {}
Expand Down
27 changes: 14 additions & 13 deletions indica/readers/st40reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

"""

from copy import deepcopy
from typing import Any
from typing import Dict
from typing import List
Expand Down Expand Up @@ -96,6 +97,7 @@ class ST40Reader(DataReader):
"tws_c": "get_spectrometer",
"ts": "get_thomson_scattering",
}
# TODO: this will not be necessary once the MDS+ standardisation is complete
UIDS_MDS = {
"xrcs": "sxr",
"princeton": "spectrom",
Expand Down Expand Up @@ -296,7 +298,8 @@ class ST40Reader(DataReader):
},
}

_IMPLEMENTATION_QUANTITIES = { # TODO: these will be different diagnostics!!!!!!!
# TODO: this can be deleted once MDS+ standardisation is complete
_IMPLEMENTATION_QUANTITIES = {
"diode_arrays": { # GETTING THE DATA OF THE SXR CAMERA
"sxr_camera_1": ("brightness", "total"),
"sxr_camera_2": ("brightness", "50_Al_filtered"),
Expand All @@ -305,6 +308,7 @@ class ST40Reader(DataReader):
},
}

# TODO: this can be deleted once MDS+ standardisation is complete
_RADIATION_RANGES = {
"sxr_camera_1": (1, 20),
"sxr_camera_2": (21, 40),
Expand Down Expand Up @@ -768,6 +772,7 @@ def _get_charge_exchange(
z, z_path = self._get_signal(uid, instrument, ":z", revision)
R, R_path = self._get_signal(uid, instrument, ":R", revision)

# TODO: temporary fix until geometry sorted (especially pulse if statement..)
try:
location, location_path = self._get_signal(
uid, instrument, ".geometry:location", revision
Expand All @@ -779,15 +784,13 @@ def _get_charge_exchange(
location = np.array([location])
direction = np.array([direction])

# TODO: temporary fix until geometry sorted
if location.shape[0] != x.shape[0]:
if self.pulse > 10200:
index = np.arange(18, 36)
else:
index = np.arange(21, 36)
location = location[index]
direction = direction[index]

except TreeNNF:
location = None
direction = None
Expand All @@ -812,10 +815,8 @@ def _get_charge_exchange(
)
except TreeNNF:
qval_err = np.full_like(qval, 0.0)
# q_path_err = ""

dimensions, _ = self._get_signal_dims(q_path, len(qval.shape))

results[q + "_records"] = q_path
results[q] = qval
results[f"{q}_error"] = qval_err
Expand Down Expand Up @@ -868,13 +869,6 @@ def _get_spectrometer(
location = np.array([location])
direction = np.array([direction])

# if self.pulse > 10200:
# index = np.arange(18, 36)
# else:
# index = np.arange(21, 36)
# location = location[index]
# direction = direction[index]

for q in quantities:
qval, q_path = self._get_signal(
uid,
Expand All @@ -892,7 +886,6 @@ def _get_spectrometer(
)
except TreeNNF:
qval_err = np.full_like(qval, 0.0)
# q_path_err = ""

dimensions, _ = self._get_signal_dims(q_path, len(qval.shape))

Expand Down Expand Up @@ -953,6 +946,8 @@ def _get_diode_filters(
_labels, _ = self._get_signal(uid, instrument, ":label", revision)
if type(_labels[0]) == np.bytes_:
labels = np.array([label.decode("UTF-8") for label in _labels])
else:
labels = _labels

results["times"] = times
results["labels"] = labels
Expand Down Expand Up @@ -1085,6 +1080,8 @@ def _get_thomson_scattering(
revision = results["revision"]

times, times_path = self._get_signal(uid, instrument, ":time", revision)
# TODO: hardcoded correction to TS coordinates to be fixed in MDS+
print("\n Hardcoded correction to TS coordinates to be fixed in MDS+ \n")
# location, location_path = self._get_signal(
# uid, instrument, ".geometry:location", revision
# )
Expand All @@ -1095,6 +1092,9 @@ def _get_thomson_scattering(
y, y_path = self._get_signal(uid, instrument, ":y", revision)
z, z_path = self._get_signal(uid, instrument, ":z", revision)
R, R_path = self._get_signal(uid, instrument, ":R", revision)
z = R * 0.0
x = deepcopy(R)
y = 0

for q in quantities:
qval, q_path = self._get_signal(
Expand Down Expand Up @@ -1170,6 +1170,7 @@ def get_revision_name(self, revision) -> str:
"""Return string defining RUN## or BEST if revision = 0"""

if type(revision) == int:
rev_str = ""
if revision < 0:
rev_str = ""
elif revision == 0:
Expand Down
Loading
Loading