-
Notifications
You must be signed in to change notification settings - Fork 0
/
generate_ttd_targets.py
65 lines (49 loc) · 2.15 KB
/
generate_ttd_targets.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from os.path import join
import yaml
import argparse
import logging
import numpy as np
import h5py
# Send INFO-and-above records to "instantiate.log" (a relative path, so it is
# created in the current working directory), each prefixed with a timestamp.
# The `encoding` keyword of basicConfig needs Python 3.9+.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(message)s",
    filename="instantiate.log",
    encoding="utf-8",
)
# Command-line interface. Both options are mandatory: the script cannot do
# anything without the dataset definition and the directory of per-shot files,
# so argparse should reject a missing option with a usage message instead of
# letting open(None) fail later.
parser = argparse.ArgumentParser(
    prog="generate_ttd_targets.py",  # was misspelled "generat_ttd_targets.py"
    description="Generate time-to-disruption target for each shot in the dataset")
parser.add_argument("--dataset_def", type=str, required=True,
                    help="YAML file that contains definition of the dataset")
# Directory holding one "<shotnr>.h5" file per shot (see the loop below).
parser.add_argument("--destination", type=str, required=True,
                    help="Destination for Dataset HDF5 files")
args = parser.parse_args()
# Load the dataset definition. This script reads the "shots" mapping
# (shot number -> per-shot settings containing "ttd") and the global "ttd_max".
with open(args.dataset_def, "r", encoding="utf-8") as fp:
    dataset_def = yaml.safe_load(fp)

# For every shot, add a "target_ttd" group (xdata = time grid, zdata =
# log10 time-to-disruption) to <destination>/<shotnr>.h5, unless one exists.
for shotnr, shot_def in dataset_def["shots"].items():
    with h5py.File(join(args.destination, f"{shotnr}.h5"), "a") as df:
        # Check for an existing target first, so finished shots are skipped
        # before any work is done (and before "tmax" is read from the file).
        existing = df.get("target_ttd")
        if isinstance(existing, h5py.Group):
            if "xdata" in existing and existing["xdata"].size > 0:
                logging.info(f"TTD for shot {shotnr} already exists.")
                continue
            # A group without non-empty "xdata" is left over from an earlier,
            # incomplete or empty run. Remove it, otherwise create_group()
            # below fails because the name already exists.
            del df["target_ttd"]

        if shot_def["ttd"] > 0.0:
            # Disruptive shot: the target counts down on a grid with step 1.0
            # starting at 0. It reaches 0 at the last grid sample (the largest
            # whole step below "ttd") and is capped at "ttd_max".
            # NOTE(review): `ttd - tb` would be the literal time remaining;
            # confirm that anchoring the countdown at tb.max() is intended.
            tb = np.arange(0.0, shot_def["ttd"], 1.0, dtype=np.float32)
            target = np.clip(tb.max() - tb, 0.0, dataset_def["ttd_max"])
        else:
            # Non-disruptive shot: constant target of "ttd_max" over the
            # shot's duration, which is taken from the file attribute "tmax".
            tb = np.arange(0.0, df.attrs["tmax"], 1.0)
            target = dataset_def["ttd_max"] * np.ones_like(tb)

        # The +0.1 offset keeps log10 finite where the countdown reaches 0.
        target = np.log10(target + 0.1)
        logging.info(f"target_ttd = {target.shape}")

        grp_t = df.create_group("target_ttd")
        grp_t.create_dataset("xdata", data=tb.astype(np.float32))
        grp_t.create_dataset("zdata", data=target.astype(np.float32))
# end of file generate_ttd_targets.py