Skip to content

Commit

Permalink
Merge branch 'st40' into marcosertoli/neutral_beam
Browse files Browse the repository at this point in the history
  • Loading branch information
marcosertoli committed Oct 9, 2023
2 parents 4a7e8e9 + 683004d commit f8dbde8
Show file tree
Hide file tree
Showing 4 changed files with 144 additions and 90 deletions.
216 changes: 132 additions & 84 deletions indica/plotters/plot_time_evolution.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from copy import deepcopy
import getpass

import matplotlib.pyplot as plt
import numpy as np
import xarray as xr

from indica.numpy_typing import ArrayLike
from indica.readers.read_st40 import ReadST40
from indica.utilities import save_figure
from indica.utilities import set_axis_sci
from indica.utilities import set_plot_colors
from indica.utilities import set_plot_rcparams
Expand All @@ -14,99 +15,146 @@
CMAP, COLORS = set_plot_colors()
LINESTYLES = ["solid", "dashed", "dotted"]
set_plot_rcparams("profiles")
QUANTITIES: list = [
"smmh:ne",
"xrcs:ti_w",
"xrcs:te_n3w",
"xrcs:te_kw",
"xrcs:spectra",
"sxrc_xy2:brightness",
"efit:ipla",
"efit:rmag",
"efit:zmag",
"sxr_spd:brightness",
# "cxff_pi:ti",
# "cxff_pi:vtor",
# "cxff_tws_c:ti",
# "cxff_tws_c:vtor",
# "ts:te",
# "ts:ne",
# "lines",
]

Y0 = {}
Y0["nirh1"] = True
Y0["smmh1"] = True
Y0["xrcs"] = True
Y0["sxr_diode_1"] = True
Y0["brems"] = True
Y0["efit"] = True
Y0["cxff_tws_c"] = True
Y0["cxff_pi"] = True

def plot_sxrc(
pulse: int = 10607, # 10605
what_to_read: list = ["sxrc_xy2:brightness"],
tplot: ArrayLike = [0.067, 0.069], # [0.07, 0.09]
save_fig=False,
plot_raw: bool = False,
xvar: str = "t",
yvar: str = "channel",
tstart=0.02,
tend=0.1,
dt: float = 0.001,
data_key="binned_data",

def plot_st40_data(
pulses: list,
tstart: float = 0.02,
tend: float = 0.1,
dt: float = 0.005,
quantities: list = [],
tplot: float = None,
save_fig: bool = False,
fig_path: str = None,
fig_style: str = "profiles",
plot_binned: bool = True,
):

instruments = []
quantities = []
linestyles = {}
for i, identifier in enumerate(what_to_read):
instr, quant = identifier.split(":")
instruments.append(instr)
quantities.append(quant)
linestyles[instr] = LINESTYLES[i]

st40 = ReadST40(pulse, tstart, tend, dt=dt)
st40(instruments=instruments)

data: dict = {}
for instr, quant in zip(instruments, quantities):
if instr not in data.keys():
data[instr] = {}
data[instr][quant] = getattr(st40, data_key)[instr][quant].transpose(yvar, xvar)
_rho, _theta = data[instr][quant].transform.convert_to_rho_theta(
t=data[instr][quant].t
)
if tplot is None:
tplot = np.mean([tstart, tend])
if len(quantities) == 0:
quantities = QUANTITIES
instruments = list(np.unique([quant.split(":")[0] for quant in quantities]))
if fig_path is None:
fig_path = f"{FIG_PATH}"

for instr, quant in zip(instruments, quantities):
plt.figure()
plot = data[instr][quant].plot(label=f"{instr}:{quant}")
set_axis_sci(plot_object=plot)
plt.title(f"{instr.upper()} for pulse {pulse}")
set_plot_rcparams(fig_style)
xr.set_options(keep_attrs=True)
colors = CMAP(np.linspace(0.75, 0.1, len(pulses), dtype=float))

if tplot is not None:
cols = CMAP(np.linspace(0.75, 0.1, np.size(tplot), dtype=float))
raw: dict = {quant: {} for quant in quantities}
binned = deepcopy(raw)
data = {"raw": raw, "binned": binned}
for i, pulse in enumerate(pulses):
fig_path += f"_{pulse}"
st40 = ReadST40(pulse, tstart=tstart, tend=tend, dt=dt)
st40(instruments)
for quantity in quantities:
instr, quant = quantity.split(":")
data["binned"][quantity][pulse] = st40.binned_data[instr][quant]
data["raw"][quantity][pulse] = st40.raw_data[instr][quant]

for quantity in quantities:
print(quantity)
instr, _ = quantity.split(":")
plt.figure()
for icol, t in enumerate(tplot):
for instr, quant in zip(instruments, quantities):
_data = data[instr][quant].sel(t=t, method="nearest")
if "error" in _data.attrs:
_err = (_data.error + _data.stdev).sel(t=t, method="nearest")
else:
_err = xr.full_like(_data, 0.0)
_R_diff = (
_data.transform.impact_parameter.value
- _data.transform.equilibrium.rmag
).sel(t=t, method="nearest")
plt.fill_between(
_R_diff.values,
_data.values + _err.values,
_data.values - _err.values,
color=cols[icol],
alpha=0.5,
)
plt.plot(
_R_diff.values,
_data.values,
label=f"t={t:.3f}",
linestyle=linestyles[instr],
color=cols[icol],
for i, pulse in enumerate(pulses):
color = colors[i]
if len(data["raw"][quantity].keys()) == 0:
continue
plot_data(data, quantity, pulse, tplot, key="raw", color=color)
if plot_binned:
plot_data(data, quantity, pulse, tplot, key="binned", color=color)

set_axis_sci()
if instr in Y0.keys():
plt.ylim(
0,
)
plt.xlabel("Impact R - R$_{mag}$")
plt.ylabel(f"{_data.long_name} [{_data.units}]")
plt.title(f"{instr.upper()} for pulse {pulse}")
plt.legend()
set_axis_sci()

return st40


# def plot_time_surface(
# pulse: int = 10605,
# instruments: list = ["sxrc_xy1"],
# quantity: str = "brightness",
# tplot: ArrayLike = None,
# save_fig=False,
# plot_raw: bool = False,
# xvar: str = "time",
# yvar: str = "channel",
# ):
plt.legend()
plt.autoscale()
save_figure(fig_path, f"{quantity}", save_fig=save_fig)


def plot_data(data, quantity: str, pulse: int, tplot: float, key="raw", color=None):
str_to_add = ""
instr, quant = quantity.split(":")
if key == "raw":
marker = None
else:
marker = "o"

_data = data[key][quantity][pulse]
tslice = slice(_data.t.min().values, _data.t.max().values)
if "error" not in _data.attrs:
_data.attrs["error"] = xr.full_like(_data, 0.0)
if "stdev" not in _data.attrs:
_data.attrs["stdev"] = xr.full_like(_data, 0.0)
_err = np.sqrt(_data.error**2 + _data.stdev**2)
_err = xr.where(_err / _data.values < 1.0, _err, 0.0)
if len(_data.dims) > 1:
str_to_add = f" @ {tplot:.3f} s"
tslice = _data.t.sel(t=tplot, method="nearest")

_data = _data.sel(t=tslice)
_err = _err.sel(t=tslice)
if instr in "xrcs" and quant == "spectra":
bgnd = _data.sel(wavelength=slice(0.393, 0.388)).mean("wavelength")
_data -= bgnd
label = None
if key == "raw":
label = str(pulse)
alpha = 0.5
plt.fill_between(
_data.coords[_data.dims[0]].values,
_data.values - _err.values,
_data.values + _err.values,
color=color,
alpha=alpha,
)
if key == "binned":
alpha = 0.8
plt.errorbar(
_data.coords[_data.dims[0]].values,
_data.values,
_err.values,
color=color,
alpha=alpha,
)
_data.plot(label=label, color=color, alpha=alpha, marker=marker)
plt.title(f"{instr.upper()} {quant}" + str_to_add)


if __name__ == "__main__":
plt.ioff()
plot_sxrc()
plot_st40_data([11226, 11225])
plt.show()
14 changes: 10 additions & 4 deletions indica/readers/read_st40.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,12 @@ def __call__(
instruments = INSTRUMENTS
if revisions is None:
revisions = {instrument: 0 for instrument in instruments}
for instr in instruments:
if instr not in revisions.keys():
revisions[instr] = 0
if "efit" not in revisions:
revisions["efit"] = 0

if tstart is None:
tstart = self.tstart
if tend is None:
Expand All @@ -326,14 +332,14 @@ def __call__(
dt = self.dt

self.reset_data()
self.get_equilibrium(R_shift=R_shift)
for i, instrument in enumerate(instruments):
self.get_equilibrium(R_shift=R_shift, revision=revisions["efit"])
for instrument in instruments:
print(f"Reading {instrument}")
if debug:
self.get_raw_data("", instrument, revisions[i])
self.get_raw_data("", instrument, revisions[instrument])
else:
try:
self.get_raw_data("", instrument, revisions[i])
self.get_raw_data("", instrument, revisions[instrument])
except Exception as e:
print(f"Error reading {instrument}: {e}")

Expand Down
2 changes: 1 addition & 1 deletion indica/readers/st40reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,7 @@ def __init__(
pulse: int,
tstart: float,
tend: float,
server: str = "192.168.1.7", # 192.168.1.7 10.0.40.13
server: str = "smaug",
tree: str = "ST40",
default_error: float = 0.05,
max_freq: float = 1e6,
Expand Down
2 changes: 1 addition & 1 deletion indica/writers/bda_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ def check_analysis_run(
which_run,
):
# Checker function to see if data already exists in a run
IP_address_smaug = "192.168.1.7:8000"
IP_address_smaug = "smaug"
conn = Connection(IP_address_smaug)
conn.openTree("BDA", pulseNo)

Expand Down

0 comments on commit f8dbde8

Please sign in to comment.