From 4cbfe2d395ceaa4834487618fa4404ff8783d888 Mon Sep 17 00:00:00 2001 From: Anmol Date: Mon, 15 Jul 2024 18:43:37 +0200 Subject: [PATCH 01/10] first commit --- pyparamgui/__init__.py | 2 +- pyparamgui/config.py | 46 +++++++ pyparamgui/generator.py | 265 ++++++++++++++++++++++++++++++++++++++++ pyparamgui/io.py | 51 ++++++++ pyparamgui/widget.py | 113 +++++++++++++++++ 5 files changed, 476 insertions(+), 1 deletion(-) create mode 100644 pyparamgui/config.py create mode 100644 pyparamgui/generator.py create mode 100644 pyparamgui/io.py create mode 100644 pyparamgui/widget.py diff --git a/pyparamgui/__init__.py b/pyparamgui/__init__.py index 70f39a4..89747ca 100644 --- a/pyparamgui/__init__.py +++ b/pyparamgui/__init__.py @@ -3,5 +3,5 @@ from __future__ import annotations __author__ = """Anmol Bhatia""" -__email__ = "a.bhatia2@student.vu.nl" +__email__ = "anmolbhatia05@gmail.com" __version__ = "0.0.1" diff --git a/pyparamgui/config.py b/pyparamgui/config.py new file mode 100644 index 0000000..685cdda --- /dev/null +++ b/pyparamgui/config.py @@ -0,0 +1,46 @@ +from typing import Dict + +import numpy as np +from pydantic import BaseModel, ConfigDict + +class KineticParameters(BaseModel): + decay_rates: list[float] + +class SpectralParameters(BaseModel): + amplitudes: list[float] + location_mean: list[float] + width: list[float] + skewness: list[float] + +class TimeCoordinates(BaseModel): + timepoints_max: int + timepoints_stepsize: float + +class SpectralCoordinates(BaseModel): + wavelength_min: int + wavelength_max: int + wavelength_stepsize: int + +def generate_simulation_coordinates(time_coordinates: TimeCoordinates, spectral_coordinates: SpectralCoordinates) -> Dict[str, np.ndarray]: + time_axis = np.arange(0, time_coordinates.timepoints_max * time_coordinates.timepoints_stepsize, time_coordinates.timepoints_stepsize) + spectral_axis = np.arange(spectral_coordinates.wavelength_min, spectral_coordinates.wavelength_max, spectral_coordinates.wavelength_stepsize) + return {"time": time_axis, "spectral": spectral_axis} + +class Settings(BaseModel): + stdev_noise: float + seed: int + add_gaussian_irf: bool = False + use_sequential_scheme: bool = False + +class IRF(BaseModel): + center: float = 0 + width: float = 0 + +class SimulationConfig(BaseModel): + kinetic_parameters: KineticParameters + spectral_parameters: SpectralParameters + coordinates: Dict[str, np.ndarray] + settings: Settings + irf: IRF + + model_config = ConfigDict(arbitrary_types_allowed=True) diff --git a/pyparamgui/generator.py b/pyparamgui/generator.py new file mode 100644 index 0000000..bc1a4fd --- /dev/null +++ b/pyparamgui/generator.py @@ -0,0 +1,265 @@ +"""The glotaran generator module.""" + +from __future__ import annotations + +from collections.abc import Callable +from typing import Any +from typing import TypedDict +from typing import cast + +from glotaran.builtin.io.yml.utils import write_dict +from glotaran.builtin.megacomplexes.decay import DecayParallelMegacomplex +from glotaran.builtin.megacomplexes.decay import DecaySequentialMegacomplex +from glotaran.builtin.megacomplexes.spectral import SpectralMegacomplex +from glotaran.model import Model + + +def _generate_decay_model( + *, nr_compartments: int, irf: bool, spectral: bool, decay_type: str +) -> dict[str, Any]: + """Generate a decay model dictionary. + + Parameters + ---------- + nr_compartments : int + The number of compartments. + irf : bool + Whether to add a gaussian irf. + spectral : bool + Whether to add a spectral model. + decay_type : str + The dype of the decay + + Returns + ------- + dict[str, Any] : + The generated model dictionary. + """ + compartments = [f"species_{i+1}" for i in range(nr_compartments)] + rates = [f"rates.species_{i+1}" for i in range(nr_compartments)] + model = { + "megacomplex": { + f"megacomplex_{decay_type}_decay": { + "type": f"decay-{decay_type}", + "compartments": compartments, + "rates": rates, + }, + }, + "dataset": {"dataset_1": {"megacomplex": [f"megacomplex_{decay_type}_decay"]}}, + } + if spectral: + model["megacomplex"]["megacomplex_spectral"] = { # type:ignore[index] + "type": "spectral", + "shape": { + compartment: f"shape_species_{i+1}" for i, compartment in enumerate(compartments) + }, + } + model["shape"] = { + f"shape_species_{i+1}": { + "type": "gaussian", + "amplitude": f"shapes.species_{i+1}.amplitude", + "location": f"shapes.species_{i+1}.location", + "width": f"shapes.species_{i+1}.width", + } + for i in range(nr_compartments) + } + model["dataset"]["dataset_1"]["global_megacomplex"] = [ # type:ignore[index] + "megacomplex_spectral" + ] + if irf: + model["dataset"]["dataset_1"]["irf"] = "gaussian_irf" # type:ignore[index] + model["irf"] = { + "gaussian_irf": {"type": "gaussian", "center": "irf.center", "width": "irf.width"}, + } + return model + + +def generate_parallel_decay_model( + *, nr_compartments: int = 1, irf: bool = False +) -> dict[str, Any]: + """Generate a parallel decay model dictionary. + + Parameters + ---------- + nr_compartments : int + The number of compartments. + irf : bool + Whether to add a gaussian irf. + + Returns + ------- + dict[str, Any] : + The generated model dictionary. + """ + return _generate_decay_model( + nr_compartments=nr_compartments, irf=irf, spectral=False, decay_type="parallel" + ) + + +def generate_parallel_spectral_decay_model( + *, nr_compartments: int = 1, irf: bool = False +) -> dict[str, Any]: + """Generate a parallel spectral decay model dictionary. + + Parameters + ---------- + nr_compartments : int + The number of compartments. + irf : bool + Whether to add a gaussian irf. + + Returns + ------- + dict[str, Any] : + The generated model dictionary. + """ + return _generate_decay_model( + nr_compartments=nr_compartments, irf=irf, spectral=True, decay_type="parallel" + ) + + +def generate_sequential_decay_model(nr_compartments: int = 1, irf: bool = False) -> dict[str, Any]: + """Generate a sequential decay model dictionary. + + Parameters + ---------- + nr_compartments : int + The number of compartments. + irf : bool + Whether to add a gaussian irf. + + Returns + ------- + dict[str, Any] : + The generated model dictionary. + """ + return _generate_decay_model( + nr_compartments=nr_compartments, irf=irf, spectral=False, decay_type="sequential" + ) + + +def generate_sequential_spectral_decay_model( + *, nr_compartments: int = 1, irf: bool = False +) -> dict[str, Any]: + """Generate a sequential spectral decay model dictionary. + + Parameters + ---------- + nr_compartments : int + The number of compartments. + irf : bool + Whether to add a gaussian irf. + + Returns + ------- + dict[str, Any] : + The generated model dictionary. + """ + return _generate_decay_model( + nr_compartments=nr_compartments, irf=irf, spectral=True, decay_type="sequential" + ) + + +generators: dict[str, Callable] = { + "decay_parallel": generate_parallel_decay_model, + "spectral_decay_parallel": generate_parallel_spectral_decay_model, + "decay_sequential": generate_sequential_decay_model, + "spectral_decay_sequential": generate_sequential_spectral_decay_model, +} + +available_generators: list[str] = list(generators.keys()) + + +class GeneratorArguments(TypedDict, total=False): + """Arguments used by ``generate_model`` and ``generate_model``. + + Parameters + ---------- + nr_compartments : int + The number of compartments. + irf : bool + Whether to add a gaussian irf. + + See Also + -------- + generate_model + generate_model_yml + """ + + nr_compartments: int + irf: bool + + +def generate_model(*, generator_name: str, generator_arguments: GeneratorArguments) -> Model: + """Generate a model. + + Parameters + ---------- + generator_name : str + The generator to use. + generator_arguments : GeneratorArguments + Arguments for the generator. + + Returns + ------- + Model + The generated model + + See Also + -------- + generate_parallel_decay_model + generate_parallel_spectral_decay_model + generate_sequential_decay_model + generate_sequential_spectral_decay_model + + Raises + ------ + ValueError + Raised when an unknown generator is specified. + """ + if generator_name not in generators: + raise ValueError( + f"Unknown model generator '{generator_name}'. " + f"Known generators are: {list(generators.keys())}" + ) + model = generators[generator_name](**generator_arguments) + print(model) + return Model.create_class_from_megacomplexes( + [DecayParallelMegacomplex, DecaySequentialMegacomplex, SpectralMegacomplex] + )(**model) + + +def generate_model_yml(*, generator_name: str, generator_arguments: GeneratorArguments) -> str: + """Generate a model as yml string. + + Parameters + ---------- + generator_name : str + The generator to use. + generator_arguments : GeneratorArguments + Arguments for the generator. + + Returns + ------- + str + The generated model yml string. + + See Also + -------- + generate_parallel_decay_model + generate_parallel_spectral_decay_model + generate_sequential_decay_model + generate_sequential_spectral_decay_model + + Raises + ------ + ValueError + Raised when an unknown generator is specified. + """ + if generator_name not in generators: + raise ValueError( + f"Unknown model generator '{generator_name}'. " + f"Known generators are: {list(generators.keys())}" + ) + model = generators[generator_name](**generator_arguments) + return cast(str, write_dict(model)) diff --git a/pyparamgui/io.py b/pyparamgui/io.py new file mode 100644 index 0000000..d26c7b3 --- /dev/null +++ b/pyparamgui/io.py @@ -0,0 +1,51 @@ +from typing import Any, Dict, Union +import yaml + +import numpy as np + +from pyparamgui.generator import generate_model +from pyparamgui.config import SimulationConfig, Settings +from glotaran.model.model import Model +from glotaran.builtin.io.yml.yml import save_model +from glotaran.parameter.parameters import Parameters +from glotaran.plugin_system.project_io_registration import save_parameters +from glotaran.plugin_system.data_io_registration import save_dataset +from glotaran.simulation.simulation import simulate + +def _generate_model_file(simulation_config: SimulationConfig, nr_compartments: int, file_name: str) -> Model: + generator_name = "spectral_decay_sequential" if simulation_config.settings.use_sequential_scheme else "spectral_decay_parallel" + model = generate_model(generator_name=generator_name, generator_arguments={"nr_compartments": nr_compartments, "irf": simulation_config.settings.add_gaussian_irf}) + save_model(model, "temp_model.yml", allow_overwrite=True) + _sanitize_yaml_file("temp_model.yml", file_name) + return model + +def _generate_parameter_file(model: Model, file_name: str) -> Parameters: + parameters = model.generate_parameters() + model.validate(parameters) + save_parameters(parameters, file_name, allow_overwrite=True) + return parameters + +def _generate_data_file(model: Model, parameters: Parameters, coordinates: Dict[str, np.ndarray], settings: Settings, file_name: str): + noise = False if settings.stdev_noise == 0 else True + data = simulate(model, "dataset_1", parameters, coordinates, noise=noise, noise_std_dev=settings.stdev_noise, noise_seed=settings.seed) + save_dataset(data, file_name, "nc", allow_overwrite=True) + +def generate_model_parameter_and_data_files(simulation_config: SimulationConfig, model_file_name: str = "model.yml", parameter_file_name: str = "parameters.csv", data_file_name: str = "dataset.nc"): + nr_compartments = len(simulation_config.kinetic_parameters.decay_rates) + model = _generate_model_file(simulation_config, nr_compartments, model_file_name) + parameters = _generate_parameter_file(model, parameter_file_name) + _generate_data_file(model, parameters, simulation_config.coordinates, simulation_config.settings, data_file_name) + +def _sanitize_dict(d: Union[Dict[str, Any], Any]) -> Union[Dict[str, Any], Any]: + if not isinstance(d, dict): + return d + return {k: _sanitize_dict(v) for k, v in d.items() if v not in (None, [], {})} + +def _sanitize_yaml_file(input_file: str, output_file: str) -> None: + with open(input_file, 'r') as f: + data = yaml.safe_load(f) + + sanitized_data = _sanitize_dict(data) + + with open(output_file, 'w') as f: + yaml.safe_dump(sanitized_data, f) \ No newline at end of file diff --git a/pyparamgui/widget.py b/pyparamgui/widget.py new file mode 100644 index 0000000..5f61361 --- /dev/null +++ b/pyparamgui/widget.py @@ -0,0 +1,113 @@ +import panel as pn +import numpy as np + +from pyparamgui.config import KineticParameters, SpectralParameters, TimeCoordinates, SpectralCoordinates, Settings, IRF, SimulationConfig, generate_simulation_coordinates +from pyparamgui.io import generate_model_parameter_and_data_files + +class SimulationWidget: + def __init__(self): + pn.extension() + self.simulation_config = None + self.decay_rates = pn.widgets.TextInput(name='Decay rates', value='0.055, 0.005') + self.amplitudes = pn.widgets.TextInput(name='Amplitudes', value='1., 1.') + self.location = pn.widgets.TextInput(name='Location (mean) of spectra', value='22000, 20000') + self.width = pn.widgets.TextInput(name='Width of spectra', value='4000, 3500') + self.skewness = pn.widgets.TextInput(name='Skewness of spectra', value='0.1, -0.1') + self.timepoints_max = pn.widgets.IntInput(name='Timepoints, max', value=80) + self.timepoints_stepsize = pn.widgets.FloatInput(name='Stepsize', value=1) + self.wavelength_min = pn.widgets.IntInput(name='Wavelength Min', value=400) + self.wavelength_max = pn.widgets.IntInput(name='Wavelength Max', value=600) + self.wavelength_stepsize = pn.widgets.IntInput(name='Stepsize', value=5) + self.stdev_noise = pn.widgets.FloatInput(name='Std.dev. noise', value=0.01) + self.seed = pn.widgets.IntInput(name='Seed', value=123) + self.add_gaussian_irf = pn.widgets.Checkbox(name='Add Gaussian IRF') + self.irf_location = pn.widgets.FloatInput(name='IRF location', value=0) + self.irf_width = pn.widgets.FloatInput(name='IRF width', value=0) + self.use_sequential_scheme = pn.widgets.Checkbox(name='Use a sequential scheme') + self.model_file_name = pn.widgets.TextInput(name='Model File Name', value='model.yml') + self.parameter_file_name = pn.widgets.TextInput(name='Parameter File Name', value='parameters.csv') + self.data_file_name = pn.widgets.TextInput(name='Data File Name', value='dataset.nc') + self.button = pn.widgets.Button(name='Simulate', button_type='primary') + self.output_pane = pn.pane.Markdown("") + + self.widget = pn.Column( + self.decay_rates, self.amplitudes, self.location, self.width, self.skewness, + pn.Row(self.timepoints_max, self.timepoints_stepsize), + pn.Row(self.wavelength_min, self.wavelength_max, self.wavelength_stepsize), + self.stdev_noise, self.seed, self.add_gaussian_irf, self.irf_location, self.irf_width, + self.use_sequential_scheme, + pn.Row(self.model_file_name, self.parameter_file_name, self.data_file_name), + self.button, self.output_pane + ) + + self.button.on_click(self.callback) + + def callback(self, event): + try: + decay_rates = np.fromstring(self.decay_rates.value, sep=',') + amplitudes = np.fromstring(self.amplitudes.value, sep=',') + location = np.fromstring(self.location.value, sep=',') + width = np.fromstring(self.width.value, sep=',') + skewness = np.fromstring(self.skewness.value, sep=',') + + valid_input = True + messages = [] + + if self.wavelength_min.value >= self.wavelength_max.value or self.timepoints_max.value <= 0: + valid_input = False + messages.append("Invalid timepoints or wavelength specification") + + lengths = {len(decay_rates), len(amplitudes), len(location), len(width), len(skewness)} + if len(lengths) > 1: + valid_input = False + messages.append("Parameter fields of unequal length") + + if not valid_input: + self.output_pane.object = pn.pane.Markdown('\n'.join(f"**Error:** {msg}" for msg in messages)) + else: + self.simulation_config = SimulationConfig( + kinetic_parameters=KineticParameters( + decay_rates=decay_rates.tolist() + ), + spectral_parameters=SpectralParameters( + amplitudes=amplitudes.tolist(), + location_mean=location.tolist(), + width=width.tolist(), + skewness=skewness.tolist() + ), + coordinates=generate_simulation_coordinates( + TimeCoordinates( + timepoints_max=self.timepoints_max.value, + timepoints_stepsize=self.timepoints_stepsize.value + ), + SpectralCoordinates( + wavelength_min=self.wavelength_min.value, + wavelength_max=self.wavelength_max.value, + wavelength_stepsize=self.wavelength_stepsize.value + ) + ), + settings=Settings( + stdev_noise=self.stdev_noise.value, + seed=self.seed.value, + add_gaussian_irf=self.add_gaussian_irf.value, + use_sequential_scheme=self.use_sequential_scheme.value + ), + irf=IRF( + center=self.irf_location.value, + width=self.irf_width.value + ) + ) + generate_model_parameter_and_data_files( + self.simulation_config, + model_file_name=self.model_file_name.value, + parameter_file_name=self.parameter_file_name.value, + data_file_name=self.data_file_name.value + ) + self.output_pane.object = "**Simulation successful!**\n\n**Files created!**" + + except Exception as e: + self.output_pane.object = f"**Error in simulation:** {str(e)}" + + +simulation_form = SimulationWidget() +simulation_form.widget.servable() \ No newline at end of file From e41531f334f57e3513782881440cb527ac8068e9 Mon Sep 17 00:00:00 2001 From: anmolbhatia05 Date: Fri, 26 Jul 2024 17:41:21 +0200 Subject: [PATCH 02/10] added comments, python docstrings and JSDoc strings --- .DS_Store | Bin 0 -> 6148 bytes .cruft.json | 4 +- pyparamgui/.DS_Store | Bin 0 -> 6148 bytes pyparamgui/__init__.py | 4 + pyparamgui/__main__.py | 1 - pyparamgui/config.py | 46 -------- pyparamgui/generator.py | 6 +- pyparamgui/io.py | 51 -------- pyparamgui/schema.py | 114 ++++++++++++++++++ pyparamgui/static/form.css | 26 +++++ pyparamgui/static/form.js | 231 +++++++++++++++++++++++++++++++++++++ pyparamgui/utils.py | 169 +++++++++++++++++++++++++++ pyparamgui/widget.py | 214 +++++++++++++++++----------------- pyproject.toml | 6 +- requirements_pinned.txt | 4 + 15 files changed, 670 insertions(+), 206 deletions(-) create mode 100644 .DS_Store create mode 100644 pyparamgui/.DS_Store delete mode 100644 pyparamgui/__main__.py delete mode 100644 pyparamgui/config.py delete mode 100644 pyparamgui/io.py create mode 100644 pyparamgui/schema.py create mode 100644 pyparamgui/static/form.css create mode 100644 pyparamgui/static/form.js create mode 100644 pyparamgui/utils.py diff --git a/.DS_Store b/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..dfd4cf418ab9094c2ac66732970b49704f3ecfed GIT binary patch literal 6148 zcmeHK%Z?I36ureT(#-fEM&oQuI&n)z291xnK^~e(6q9Mxg&J(=mT9NyMrcGsNLcGX z_zQ0R692`Oo?BH$Xqc@MQ&L}bAkvpY!E7`L;pSVq^J z0}9zkk9-O#rXKCDM4JkufKg!66yR_7EOqIaT#E4B{`r3UPi>oPimOSh> zT^cQqRD4dw<>(Q5tzo21a;TAt{*j8d=rPndffjElz<%a0$7gz}_#GP1N4QHvcpoZ` zDOl!6?Y=B%x0chHj0RfoZ=FFHjFM96yU1*1x6kCPoV9Dca*pGPleo!vRCR|hxYUZ{ zu=6|bcAf?Pq+PgFi=)I1B7Z0ef#)OS0gN8U(s9o1mvt$e$%H=ULbZWrxx z`Od6pPw$pXMf>KBTeDf-x_s^Wy`%O?*pK2*cz2k;QEUem>R`kKh9(%g>jqX|cC$MET%?pNiy_v>1uv2AiHkYj=akXohAKFbez|1$ck3 z;fxiHbA@v2Kq9XIz$~hzAoQCZzw`O9sS!foJ2*TNsR(Vfn^1@RI$M4f8+G~|1!yxi~>f1|4IRotvmH9UP+&= zD=)`qtqp$-XXCiJLPBb5PhQ*Bv4RV?k^zmAFKomYkz=>5<@~fW1yhB&cYw9Z$2~}2d#z9keRnT zJMV^j#qBNtnQfmgfeC;iT@mjZ`ljpZJv)nxqS!N5*x-)8DPDWk#}mpO8E1tT?DKaw z>*jH>ZWgT2zjsBSewuLvjPZ;mIZM3ZfqfgI_C4Uq#{tjC8B&2%AQeajQh|S~0QYRQ z;m9##Dv%1K0zV4q_o2`gYhdeWpAH5+zU%rxQH^7}OAt%UYhdfh6`DAe=v0X-hB%%1 z5_L7Ob#yw!I`cWvS>lEw)}8rc<&f%_F%?J!x(b~8bfWwJHT}f>ze~zlDv%2NDFtLO zzn#zcO0l>0UQYMgLcgbf8*8ncp>qh}M7QFLy}F`f*44n)(df(@otOs!)g>(z_zeZV E0AsKz&j0`b literal 0 HcmV?d00001 diff --git a/pyparamgui/__init__.py b/pyparamgui/__init__.py index 89747ca..13ebf4e 100644 --- a/pyparamgui/__init__.py +++ b/pyparamgui/__init__.py @@ -5,3 +5,7 @@ __author__ = """Anmol Bhatia""" __email__ = "anmolbhatia05@gmail.com" __version__ = "0.0.1" + +from .widget import widget, setup_widget_observer + +__all__ = ['widget', 'setup_widget_observer'] diff --git a/pyparamgui/__main__.py b/pyparamgui/__main__.py deleted file mode 100644 index dd0b80e..0000000 --- a/pyparamgui/__main__.py +++ /dev/null @@ -1 +0,0 @@ -"""Main module.""" diff --git a/pyparamgui/config.py b/pyparamgui/config.py deleted file mode 100644 index 685cdda..0000000 --- a/pyparamgui/config.py +++ /dev/null @@ -1,46 +0,0 @@ -from typing import Dict - -import numpy as np -from pydantic import BaseModel, ConfigDict - -class KineticParameters(BaseModel): - decay_rates: list[float] - -class SpectralParameters(BaseModel): - amplitudes: list[float] - location_mean: list[float] - width: list[float] - skewness: list[float] - -class TimeCoordinates(BaseModel): - timepoints_max: int - timepoints_stepsize: float - -class SpectralCoordinates(BaseModel): - wavelength_min: int - wavelength_max: int - wavelength_stepsize: int - -def generate_simulation_coordinates(time_coordinates: TimeCoordinates, spectral_coordinates: SpectralCoordinates) -> Dict[str, np.ndarray]: - time_axis = np.arange(0, time_coordinates.timepoints_max * time_coordinates.timepoints_stepsize, time_coordinates.timepoints_stepsize) - spectral_axis = np.arange(spectral_coordinates.wavelength_min, spectral_coordinates.wavelength_max, spectral_coordinates.wavelength_stepsize) - return {"time": time_axis, "spectral": spectral_axis} - -class Settings(BaseModel): - stdev_noise: float - seed: int - add_gaussian_irf: bool = False - use_sequential_scheme: bool = False - -class IRF(BaseModel): - center: float = 0 - width: float = 0 - -class SimulationConfig(BaseModel): - kinetic_parameters: KineticParameters - spectral_parameters: SpectralParameters - coordinates: Dict[str, np.ndarray] - settings: Settings - irf: IRF - - model_config = ConfigDict(arbitrary_types_allowed=True) diff --git a/pyparamgui/generator.py b/pyparamgui/generator.py index bc1a4fd..d278ae1 100644 --- a/pyparamgui/generator.py +++ b/pyparamgui/generator.py @@ -56,16 +56,19 @@ def _generate_decay_model( } model["shape"] = { f"shape_species_{i+1}": { - "type": "gaussian", + "type": "skewed-gaussian", "amplitude": f"shapes.species_{i+1}.amplitude", "location": f"shapes.species_{i+1}.location", "width": f"shapes.species_{i+1}.width", + "skewness": f"shapes.species_{i+1}.skewness" } for i in range(nr_compartments) } model["dataset"]["dataset_1"]["global_megacomplex"] = [ # type:ignore[index] "megacomplex_spectral" ] + model["dataset"]["dataset_1"]["spectral_axis_inverted"] = True + model["dataset"]["dataset_1"]["spectral_axis_scale"] = 1E7 if irf: model["dataset"]["dataset_1"]["irf"] = "gaussian_irf" # type:ignore[index] model["irf"] = { @@ -223,7 +226,6 @@ def generate_model(*, generator_name: str, generator_arguments: GeneratorArgumen f"Known generators are: {list(generators.keys())}" ) model = generators[generator_name](**generator_arguments) - print(model) return Model.create_class_from_megacomplexes( [DecayParallelMegacomplex, DecaySequentialMegacomplex, SpectralMegacomplex] )(**model) diff --git a/pyparamgui/io.py b/pyparamgui/io.py deleted file mode 100644 index d26c7b3..0000000 --- a/pyparamgui/io.py +++ /dev/null @@ -1,51 +0,0 @@ -from typing import Any, Dict, Union -import yaml - -import numpy as np - -from pyparamgui.generator import generate_model -from pyparamgui.config import SimulationConfig, Settings -from glotaran.model.model import Model -from glotaran.builtin.io.yml.yml import save_model -from glotaran.parameter.parameters import Parameters -from glotaran.plugin_system.project_io_registration import save_parameters -from glotaran.plugin_system.data_io_registration import save_dataset -from glotaran.simulation.simulation import simulate - -def _generate_model_file(simulation_config: SimulationConfig, nr_compartments: int, file_name: str) -> Model: - generator_name = "spectral_decay_sequential" if simulation_config.settings.use_sequential_scheme else "spectral_decay_parallel" - model = generate_model(generator_name=generator_name, generator_arguments={"nr_compartments": nr_compartments, "irf": simulation_config.settings.add_gaussian_irf}) - save_model(model, "temp_model.yml", allow_overwrite=True) - _sanitize_yaml_file("temp_model.yml", file_name) - return model - -def _generate_parameter_file(model: Model, file_name: str) -> Parameters: - parameters = model.generate_parameters() - model.validate(parameters) - save_parameters(parameters, file_name, allow_overwrite=True) - return parameters - -def _generate_data_file(model: Model, parameters: Parameters, coordinates: Dict[str, np.ndarray], settings: Settings, file_name: str): - noise = False if settings.stdev_noise == 0 else True - data = simulate(model, "dataset_1", parameters, coordinates, noise=noise, noise_std_dev=settings.stdev_noise, noise_seed=settings.seed) - save_dataset(data, file_name, "nc", allow_overwrite=True) - -def generate_model_parameter_and_data_files(simulation_config: SimulationConfig, model_file_name: str = "model.yml", parameter_file_name: str = "parameters.csv", data_file_name: str = "dataset.nc"): - nr_compartments = len(simulation_config.kinetic_parameters.decay_rates) - model = _generate_model_file(simulation_config, nr_compartments, model_file_name) - parameters = _generate_parameter_file(model, parameter_file_name) - _generate_data_file(model, parameters, simulation_config.coordinates, simulation_config.settings, data_file_name) - -def _sanitize_dict(d: Union[Dict[str, Any], Any]) -> Union[Dict[str, Any], Any]: - if not isinstance(d, dict): - return d - return {k: _sanitize_dict(v) for k, v in d.items() if v not in (None, [], {})} - -def _sanitize_yaml_file(input_file: str, output_file: str) -> None: - with open(input_file, 'r') as f: - data = yaml.safe_load(f) - - sanitized_data = _sanitize_dict(data) - - with open(output_file, 'w') as f: - yaml.safe_dump(sanitized_data, f) \ No newline at end of file diff --git a/pyparamgui/schema.py b/pyparamgui/schema.py new file mode 100644 index 0000000..f470fb7 --- /dev/null +++ b/pyparamgui/schema.py @@ -0,0 +1,114 @@ +"""This module has the different model classes representing different parameters, coordinates, and settings for simulation.""" + +from __future__ import annotations + +from typing import Dict + +import numpy as np +from pydantic import BaseModel, ConfigDict + +class KineticParameters(BaseModel): + """Kinetic parameters for the simulation. + + Attributes: + decay_rates (list[float]): List of decay rates. + """ + decay_rates: list[float] + +class SpectralParameters(BaseModel): + """Spectral parameters for the simulation. + + Attributes: + amplitude (list[float]): List of amplitudes. + location (list[float]): List of locations. + width (list[float]): List of widths. + skewness (list[float]): List of skewness values. + """ + amplitude: list[float] + location: list[float] + width: list[float] + skewness: list[float] + +class TimeCoordinates(BaseModel): + """ + Time coordinates for the simulation. + + Attributes: + timepoints_max (int): Maximum number of time points. + timepoints_stepsize (float): Step size between time points. + """ + timepoints_max: int + timepoints_stepsize: float + +class SpectralCoordinates(BaseModel): + """ + Spectral coordinates for the simulation. + + Attributes: + wavelength_min (int): Minimum wavelength. + wavelength_max (int): Maximum wavelength. + wavelength_stepsize (float): Step size between wavelengths. + """ + wavelength_min: int + wavelength_max: int + wavelength_stepsize: float + +def generate_simulation_coordinates(time_coordinates: TimeCoordinates, spectral_coordinates: SpectralCoordinates) -> Dict[str, np.ndarray]: + """ + Generate simulation coordinates based on time and spectral coordinates. + + Args: + time_coordinates (TimeCoordinates): The time coordinates for the simulation. + spectral_coordinates (SpectralCoordinates): The spectral coordinates for the simulation. + + Returns: + Dict[str, np.ndarray]: A dictionary containing the time and spectral axes as numpy arrays. + """ + time_axis = np.arange(0, time_coordinates.timepoints_max * time_coordinates.timepoints_stepsize, time_coordinates.timepoints_stepsize) + spectral_axis = np.arange(spectral_coordinates.wavelength_min, spectral_coordinates.wavelength_max, spectral_coordinates.wavelength_stepsize) + return {"time": time_axis, "spectral": spectral_axis} + +class Settings(BaseModel): + """ + Other settings for the simulation. + + Attributes: + stdev_noise (float): Standard deviation of the noise to be added to the simulation data. + seed (int): Seed for the random number generator to ensure reproducibility. + add_gaussian_irf (bool): Flag to indicate whether to add a Gaussian Instrument Response Function (IRF) to the simulation. Default is False. + use_sequential_scheme (bool): Flag to indicate whether to use a sequential scheme in the simulation. Default is False. + """ + stdev_noise: float + seed: int + add_gaussian_irf: bool = False + use_sequential_scheme: bool = False + +class IRF(BaseModel): + """ + Instrument Response Function (IRF) settings for the simulation. + + Attributes: + center (float): The center position of the IRF. + width (float): The width of the IRF. + """ + center: float + width: float + +class SimulationConfig(BaseModel): + """ + Configuration for the simulation, combining various parameters and settings. + + Attributes: + kinetic_parameters (KineticParameters): Kinetic parameters for the simulation. + spectral_parameters (SpectralParameters): Spectral parameters for the simulation. + coordinates (Dict[str, np.ndarray]): Dictionary containing the time and spectral axes as numpy arrays. + settings (Settings): Other settings for the simulation, including noise standard deviation, random seed, and flags for adding Gaussian IRF and using a sequential scheme. + irf (IRF): Instrument Response Function (IRF) settings, including center position and width. + """ + kinetic_parameters: KineticParameters + spectral_parameters: SpectralParameters + coordinates: Dict[str, np.ndarray] + settings: Settings + irf: IRF + + model_config = ConfigDict(arbitrary_types_allowed=True) diff --git a/pyparamgui/static/form.css b/pyparamgui/static/form.css new file mode 100644 index 0000000..4d6f0e2 --- /dev/null +++ b/pyparamgui/static/form.css @@ -0,0 +1,26 @@ +form { + display: flex; + flex-direction: column; +} +.form-group { + display: flex; + align-items: center; + margin-bottom: 10px; +} +label { + margin-right: 10px; + color: black; + width: 150px; +} +input { + flex: 1; + margin: 5px 0; +} +hr { + margin: 20px 0; + border: none; + border-top: 1px solid #ccc; +} +button { + margin-top: 10px; +} \ No newline at end of file diff --git a/pyparamgui/static/form.js b/pyparamgui/static/form.js new file mode 100644 index 0000000..0c5ca4a --- /dev/null +++ b/pyparamgui/static/form.js @@ -0,0 +1,231 @@ +function render({ model, el }) { + const form = document.createElement('form'); + + /** + * Creates a form group with a label and a text input field. + * + * @param {string} labelText - The text content for the label. + * @param {string} inputId - The id attribute for the input element. + * @param {string} inputName - The name attribute for the input element. + * @param {string} inputValue - The initial value for the input element. + * + * @returns {HTMLDivElement} The form group element containing the label and input. + */ + function createTextFormGroup(labelText, inputId, inputName, inputValue) { + const formGroup = document.createElement('div'); + formGroup.className = 'form-group'; + + const label = document.createElement('label'); + label.setAttribute('for', inputId); + label.textContent = labelText; + + const input = document.createElement('input'); + input.setAttribute('type', 'text'); + input.setAttribute('id', inputId); + input.setAttribute('name', inputName); + input.value = inputValue; + + formGroup.appendChild(label); + formGroup.appendChild(input); + + return formGroup; + } + + /** + * Creates a form group with a label and a checkbox input field. + * + * @param {string} labelText - The text content for the label. + * @param {string} inputId - The id attribute for the input element. + * @param {string} inputName - The name attribute for the input element. + * + * @returns {HTMLDivElement} The form group element containing the label and checkbox input. + */ + function createCheckboxFormGroup(labelText, inputId, inputName) { + const formGroup = document.createElement('div'); + formGroup.className = 'form-group'; + + const label = document.createElement('label'); + label.setAttribute('for', inputId); + label.textContent = labelText; + + const input = document.createElement('input'); + input.setAttribute('type', 'checkbox'); + input.setAttribute('id', inputId); + input.setAttribute('name', inputName); + + formGroup.appendChild(label); + formGroup.appendChild(input); + + return formGroup; + } + + form.appendChild(createTextFormGroup('Decay rates:', 'decay_rates_input', 'decay_rates_input', '0.055, 0.005')); + form.appendChild(document.createElement('hr')); + form.appendChild(createTextFormGroup('Amplitudes:', 'amplitude_input', 'amplitude_input', '1., 1.')); + form.appendChild(createTextFormGroup('Location (mean) of spectra:', 'location_input', 'location_input', '22000, 20000')); + form.appendChild(createTextFormGroup('Width of spectra:', 'width_input', 'width_input', '4000, 3500')); + form.appendChild(createTextFormGroup('Skewness of spectra:', 'skewness_input', 'skewness_input', '0.1, -0.1')); + form.appendChild(document.createElement('hr')); + form.appendChild(createTextFormGroup('Timepoints, max:', 'timepoints_max_input', 'timepoints_max_input', '80')); + form.appendChild(createTextFormGroup('Stepsize:', 'timepoints_stepsize_input', 'timepoints_stepsize_input', '1')); + form.appendChild(document.createElement('hr')); + form.appendChild(createTextFormGroup('Wavelength Min:', 'wavelength_min_input', 'wavelength_min_input', '400')); + form.appendChild(createTextFormGroup('Wavelength Max:', 'wavelength_max_input', 'wavelength_max_input', '600')); + form.appendChild(createTextFormGroup('Stepsize:', 'wavelength_stepsize_input', 'wavelength_stepsize_input', '5')); + form.appendChild(document.createElement('hr')); + form.appendChild(createTextFormGroup('Std.dev. noise:', 'stdev_noise_input', 'stdev_noise_input', '0.01')); + form.appendChild(createTextFormGroup('Seed:', 'seed_input', 'seed_input', '123')); + form.appendChild(document.createElement('hr')); + form.appendChild(createCheckboxFormGroup('Add Gaussian IRF:', 'add_gaussian_irf_input', 'add_gaussian_irf_input')); + form.appendChild(createTextFormGroup('IRF location:', 'irf_location_input', 'irf_location_input', '3')); + form.appendChild(createTextFormGroup('IRF width:', 'irf_width_input', 'irf_width_input', '1')); + form.appendChild(document.createElement('hr')); + form.appendChild(createCheckboxFormGroup('Use Sequential Scheme:', 'use_sequential_scheme_input', 'use_sequential_scheme_input')); + form.appendChild(document.createElement('hr')); + form.appendChild(createTextFormGroup('Model File Name:', 'model_file_name_input', 'model_file_name_input', 'model.yml')); + form.appendChild(createTextFormGroup('Parameter File Name:', 'parameter_file_name_input', 'parameter_file_name_input', 'parameters.csv')); + form.appendChild(createTextFormGroup('Data File Name:', 'data_file_name_input', 'data_file_name_input', 'dataset.nc')); + + el.appendChild(form); + + /** + * Converts the input values from the form into their respective data types. + * + * @returns {Object|null} An object containing the converted input values, or null if an error occurs. + * + * @property {number[]} decay_rates - Array of decay rates as floats. + * @property {number[]} amplitude - Array of amplitudes as floats. + * @property {number[]} location - Array of locations as floats. + * @property {number[]} width - Array of widths as floats. + * @property {number[]} skewness - Array of skewness values as floats. + * @property {number} timepoints_max - Maximum number of timepoints as an integer. + * @property {number} timepoints_stepsize - Step size for timepoints as a float. + * @property {number} wavelength_min - Minimum wavelength value as a float. + * @property {number} wavelength_max - Maximum wavelength value as a float. + * @property {number} wavelength_stepsize - Step size for wavelength as a float. + * @property {number} stdev_noise - Standard deviation of noise as a float. + * @property {number} seed - Seed for random number generation as an integer. + * @property {number} irf_location - Location of the IRF center as a float. + * @property {number} irf_width - Width of the IRF as a float. + */ + function convertInputs() { + try { + const decay_rates = decay_rates_input.value.split(',').map(parseFloat); + const amplitude = amplitude_input.value.split(',').map(parseFloat); + const location = location_input.value.split(',').map(parseFloat); + const width = width_input.value.split(',').map(parseFloat); + const skewness = skewness_input.value.split(',').map(parseFloat); + const timepoints_max = parseInt(timepoints_max_input.value, 10); + const timepoints_stepsize = parseFloat(timepoints_stepsize_input.value); + const wavelength_min = parseFloat(wavelength_min_input.value); + const wavelength_max = parseFloat(wavelength_max_input.value); + const wavelength_stepsize = parseFloat(wavelength_stepsize_input.value); + const stdev_noise = parseFloat(stdev_noise_input.value); + const seed = parseInt(seed_input.value, 10); + const irf_location = parseFloat(irf_location_input.value); + const irf_width = parseFloat(irf_width_input.value); + + return { decay_rates, amplitude, location, width, skewness, timepoints_max, timepoints_stepsize, wavelength_min, wavelength_max, wavelength_stepsize, stdev_noise, seed, irf_location, irf_width }; + } catch (error) { + alert('Error converting inputs: ' + error.message); + return null; + } + } + + /** + * Validates the input values for the simulation. + * + * @param {Object} inputs - The input values to validate. + * + * @param {number[]} inputs.decay_rates - Array of decay rates as floats. + * @param {number[]} inputs.amplitude - Array of amplitudes as floats. + * @param {number[]} inputs.location - Array of locations as floats. + * @param {number[]} inputs.width - Array of widths as floats. + * @param {number[]} inputs.skewness - Array of skewness values as floats. + * @param {number} inputs.wavelength_min - Minimum wavelength value as a float. + * @param {number} inputs.wavelength_max - Maximum wavelength value as a float. + * @param {number} inputs.timepoints_max - Maximum number of timepoints as an integer. + * + * @returns {boolean} True if all inputs are valid, otherwise false. + */ + function validateInputs(inputs) { + try { + const { decay_rates, amplitude, location, width, skewness } = inputs; + + if (decay_rates.some(isNaN)) { + alert('Invalid decay rates'); + return false; + } + if (amplitude.some(isNaN)) { + alert('Invalid amplitudes'); + return false; + } + if (location.some(isNaN)) { + alert('Invalid locations'); + return false; + } + if (width.some(isNaN)) { + alert('Invalid widths'); + return false; + } + if (skewness.some(isNaN)) { + alert('Invalid skewness values'); + return false; + } + + const lengths = [decay_rates.length, amplitude.length, location.length, width.length, skewness.length]; + if (new Set(lengths).size !== 1) { + alert('All input lists must have the same length'); + return false; + } + + if (inputs.wavelength_min >= inputs.wavelength_max || inputs.timepoints_max <= 0) { + alert('Invalid timepoints or wavelength specification'); + return false; + } + + return true; + } catch (error) { + alert('Validation error: ' + error.message); + return false; + } + } + + const btn = document.createElement("button"); + btn.textContent = 'Simulate'; + btn.addEventListener('click', function(event) { + event.preventDefault(); + + const convertedInputs = convertInputs(); + if (!convertedInputs) return; + + const isValid = validateInputs(convertedInputs); + if (!isValid) return; + + model.set("decay_rates_input", convertedInputs.decay_rates); + model.set("amplitude_input", convertedInputs.amplitude); + model.set("location_input", convertedInputs.location); + model.set("width_input", convertedInputs.width); + model.set("skewness_input", convertedInputs.skewness); + model.set("timepoints_max_input", convertedInputs.timepoints_max); + model.set("timepoints_stepsize_input", convertedInputs.timepoints_stepsize); + model.set("wavelength_min_input", convertedInputs.wavelength_min); + model.set("wavelength_max_input", convertedInputs.wavelength_max); + model.set("wavelength_stepsize_input", convertedInputs.wavelength_stepsize); + model.set("stdev_noise_input", convertedInputs.stdev_noise); + model.set("seed_input", convertedInputs.seed); + model.set("add_gaussian_irf_input", add_gaussian_irf_input.checked); + model.set("irf_location_input", convertedInputs.irf_location); + model.set("irf_width_input", convertedInputs.irf_width); + model.set("use_sequential_scheme_input", use_sequential_scheme_input.checked); + model.set("model_file_name_input", model_file_name_input.value); + model.set("parameter_file_name_input", parameter_file_name_input.value); + model.set("data_file_name_input", data_file_name_input.value); + model.set("simulate", self.crypto.randomUUID()); + + model.save_changes(); + }); + el.appendChild(btn); +} + +export default { render }; diff --git a/pyparamgui/utils.py b/pyparamgui/utils.py new file mode 100644 index 0000000..52aca49 --- /dev/null +++ b/pyparamgui/utils.py @@ -0,0 +1,169 @@ +"""This module has various utility functions related to generating files, sanitizing yaml files, etc.""" + +from __future__ import annotations + +from typing import Any, Dict, Union +import yaml +import os + +import numpy as np +from glotaran.model.model import Model +from glotaran.builtin.io.yml.yml import save_model +from glotaran.parameter.parameters import Parameters +from glotaran.plugin_system.project_io_registration import save_parameters +from glotaran.plugin_system.data_io_registration import save_dataset +from glotaran.simulation.simulation import simulate + +from pyparamgui.generator import generate_model +from pyparamgui.schema import SimulationConfig, Settings + +def _generate_model_file(simulation_config: SimulationConfig, nr_compartments: int, file_name: str) -> Model: + """ + Generate and save a model file for the simulation. + + This function generates a model based on the provided simulation configuration and number of compartments. + It saves the generated model to a temporary YAML file, sanitizes the file, and then saves it to the specified file name. + + Args: + simulation_config (SimulationConfig): The configuration for the simulation. + nr_compartments (int): The number of compartments in the model. + file_name (str): The name of the file to save the sanitized model. + + Returns: + Model: The generated model. + """ + generator_name = "spectral_decay_sequential" if simulation_config.settings.use_sequential_scheme else "spectral_decay_parallel" + model = generate_model(generator_name=generator_name, generator_arguments={"nr_compartments": nr_compartments, "irf": simulation_config.settings.add_gaussian_irf}) + save_model(model, "temp_model.yml", allow_overwrite=True) + _sanitize_yaml_file("temp_model.yml", file_name) + return model + +def _update_parameter_values(parameters: Parameters, simulation_config: SimulationConfig): + """ + Update parameter values based on the simulation configuration. + + This function iterates through all parameters and updates their values according to the + provided simulation configuration. It handles parameters related to spectral shapes, + kinetic rates, and IRF (Instrument Response Function). + + Args: + parameters (Parameters): The parameters to be updated. + simulation_config (SimulationConfig): The configuration containing the new values for the parameters. + + Returns: + Parameters: The updated parameters. + """ + for param in parameters.all(): + label = param.label + if label.startswith('shapes.species_'): + parts = label.split('.') + species_index = int(parts[1].split('_')[1]) - 1 + attribute = parts[2] + + if attribute == 'amplitude': + param.value = simulation_config.spectral_parameters.amplitude[species_index] + elif attribute == 'location': + param.value = simulation_config.spectral_parameters.location[species_index] + elif attribute == 'width': + param.value = simulation_config.spectral_parameters.width[species_index] + elif attribute == 'skewness': + param.value = simulation_config.spectral_parameters.skewness[species_index] + + elif label.startswith('rates.species_'): + species_index = int(label.split('_')[1]) - 1 + param.value = simulation_config.kinetic_parameters.decay_rates[species_index] + + elif label.startswith('irf') and simulation_config.settings.add_gaussian_irf: + if 'width' in label: + param.value = simulation_config.irf.width + if 'center' in label: + param.value = simulation_config.irf.center + return parameters + +def _generate_parameter_file(simulation_config: SimulationConfig, model: Model, file_name: str) -> Parameters: + """ + Generate and save the parameter file for the simulation. + + This function generates the parameters for the given model, updates them based on the simulation configuration, + validates the updated parameters, and saves them to a file. + + Args: + simulation_config (SimulationConfig): The configuration for the simulation. + model (Model): The model for which parameters are to be generated. + file_name (str): The name of the file to save the parameters. + + Returns: + Parameters: The updated and validated parameters. + """ + parameters = model.generate_parameters() + updated_parameters = _update_parameter_values(parameters, simulation_config) + model.validate(updated_parameters) + save_parameters(updated_parameters, file_name, allow_overwrite=True) + return updated_parameters + +def _generate_data_file(model: Model, parameters: Parameters, coordinates: Dict[str, np.ndarray], settings: Settings, file_name: str): + """ + Generate and save the data file for the simulation. + + This function simulates the data based on the given model, parameters, coordinates, and settings, + and saves the simulated data to a file. + + Args: + model (Model): The model used for simulation. + parameters (Parameters): The parameters used for simulation. + coordinates (Dict[str, np.ndarray]): The coordinates for the simulation. + settings (Settings): The settings for the simulation. + file_name (str): The name of the file to save the simulated data. + """ + noise = False if settings.stdev_noise == 0 else True + data = simulate(model, "dataset_1", parameters, coordinates, noise=noise, noise_std_dev=settings.stdev_noise, noise_seed=settings.seed) + save_dataset(data, file_name, "nc", allow_overwrite=True) + +def generate_model_parameter_and_data_files(simulation_config: SimulationConfig, model_file_name: str = "model.yml", parameter_file_name: str = "parameters.csv", data_file_name: str = "dataset.nc"): + """ + Generate and save the model, parameter, and data files for the simulation. + + This function generates the model file, parameter file, and data file based on the given simulation configuration. + + Args: + simulation_config (SimulationConfig): The configuration for the simulation. + model_file_name (str, optional): The name of the file to save the model. Defaults to "model.yml". + parameter_file_name (str, optional): The name of the file to save the parameters. Defaults to "parameters.csv". + data_file_name (str, optional): The name of the file to save the data. Defaults to "dataset.nc". + """ + nr_compartments = len(simulation_config.kinetic_parameters.decay_rates) + model = _generate_model_file(simulation_config, nr_compartments, model_file_name) + parameters = _generate_parameter_file(simulation_config, model, parameter_file_name) + _generate_data_file(model, parameters, simulation_config.coordinates, simulation_config.settings, data_file_name) + +def _sanitize_dict(d: Union[Dict[str, Any], Any]) -> Union[Dict[str, Any], Any]: + """ + Recursively sanitize a dictionary by removing keys with values that are None, empty lists, or empty dictionaries. + + Args: + d (Union[Dict[str, Any], Any]): The dictionary to sanitize or any other value. + + Returns: + Union[Dict[str, Any], Any]: The sanitized dictionary or the original value if it is not a dictionary. + """ + if not isinstance(d, dict): + return d + return {k: _sanitize_dict(v) for k, v in d.items() if v not in (None, [], {})} + +def _sanitize_yaml_file(input_file: str, output_file: str) -> None: + """ + Sanitize a YAML file by removing keys with values that are None, empty lists, or empty dictionaries, + and save the sanitized content to a new file. + + Args: + input_file (str): The path to the input YAML file. + output_file (str): The path to the output sanitized YAML file. + """ + with open(input_file, 'r') as f: + data = yaml.safe_load(f) + + sanitized_data = _sanitize_dict(data) + + with open(output_file, 'w') as f: + yaml.safe_dump(sanitized_data, f) + os.remove(input_file) diff --git a/pyparamgui/widget.py b/pyparamgui/widget.py index 5f61361..af8a021 100644 --- a/pyparamgui/widget.py +++ b/pyparamgui/widget.py @@ -1,113 +1,121 @@ -import panel as pn -import numpy as np +"""This module contains the simulation widget.""" -from pyparamgui.config import KineticParameters, SpectralParameters, TimeCoordinates, SpectralCoordinates, Settings, IRF, SimulationConfig, generate_simulation_coordinates -from pyparamgui.io import generate_model_parameter_and_data_files +from __future__ import annotations -class SimulationWidget: - def __init__(self): - pn.extension() - self.simulation_config = None - self.decay_rates = pn.widgets.TextInput(name='Decay rates', value='0.055, 0.005') - self.amplitudes = pn.widgets.TextInput(name='Amplitudes', value='1., 1.') - self.location = pn.widgets.TextInput(name='Location (mean) of spectra', value='22000, 20000') - self.width = pn.widgets.TextInput(name='Width of spectra', value='4000, 3500') - self.skewness = pn.widgets.TextInput(name='Skewness of spectra', value='0.1, -0.1') - self.timepoints_max = pn.widgets.IntInput(name='Timepoints, max', value=80) - self.timepoints_stepsize = pn.widgets.FloatInput(name='Stepsize', value=1) - self.wavelength_min = pn.widgets.IntInput(name='Wavelength Min', value=400) - self.wavelength_max = pn.widgets.IntInput(name='Wavelength Max', value=600) - self.wavelength_stepsize = pn.widgets.IntInput(name='Stepsize', value=5) - self.stdev_noise = pn.widgets.FloatInput(name='Std.dev. noise', value=0.01) - self.seed = pn.widgets.IntInput(name='Seed', value=123) - self.add_gaussian_irf = pn.widgets.Checkbox(name='Add Gaussian IRF') - self.irf_location = pn.widgets.FloatInput(name='IRF location', value=0) - self.irf_width = pn.widgets.FloatInput(name='IRF width', value=0) - self.use_sequential_scheme = pn.widgets.Checkbox(name='Use a sequential scheme') - self.model_file_name = pn.widgets.TextInput(name='Model File Name', value='model.yml') - self.parameter_file_name = pn.widgets.TextInput(name='Parameter File Name', value='parameters.csv') - self.data_file_name = pn.widgets.TextInput(name='Data File Name', value='dataset.nc') - self.button = pn.widgets.Button(name='Simulate', button_type='primary') - self.output_pane = pn.pane.Markdown("") +import pathlib - self.widget = pn.Column( - self.decay_rates, self.amplitudes, self.location, self.width, self.skewness, - pn.Row(self.timepoints_max, self.timepoints_stepsize), - pn.Row(self.wavelength_min, self.wavelength_max, self.wavelength_stepsize), - self.stdev_noise, self.seed, self.add_gaussian_irf, self.irf_location, self.irf_width, - self.use_sequential_scheme, - pn.Row(self.model_file_name, self.parameter_file_name, self.data_file_name), - self.button, self.output_pane - ) - - self.button.on_click(self.callback) +import traitlets +import anywidget - def callback(self, event): - try: - decay_rates = np.fromstring(self.decay_rates.value, sep=',') - amplitudes = np.fromstring(self.amplitudes.value, sep=',') - location = np.fromstring(self.location.value, sep=',') - width = np.fromstring(self.width.value, sep=',') - skewness = np.fromstring(self.skewness.value, sep=',') +from pyparamgui.schema import KineticParameters, SpectralParameters, TimeCoordinates, SpectralCoordinates, Settings, IRF, SimulationConfig, generate_simulation_coordinates +from pyparamgui.utils import generate_model_parameter_and_data_files - valid_input = True - messages = [] +class Widget(anywidget.AnyWidget): + """ + A widget class for handling simulation parameters, coordinates and settings. - if self.wavelength_min.value >= self.wavelength_max.value or self.timepoints_max.value <= 0: - valid_input = False - messages.append("Invalid timepoints or wavelength specification") + Attributes: + _esm (pathlib.Path): Path to the JavaScript file for the widget. + _css (pathlib.Path): Path to the CSS file for the widget. + decay_rates_input (traitlets.List): List of decay rates as floats. + amplitude_input (traitlets.List): List of amplitudes as floats. + location_input (traitlets.List): List of locations as floats. + width_input (traitlets.List): List of widths as floats. + skewness_input (traitlets.List): List of skewness values as floats. + timepoints_max_input (traitlets.Int): Maximum number of timepoints. + timepoints_stepsize_input (traitlets.Float): Step size for timepoints. + wavelength_min_input (traitlets.Float): Minimum wavelength value. + wavelength_max_input (traitlets.Float): Maximum wavelength value. + wavelength_stepsize_input (traitlets.Float): Step size for wavelength. + stdev_noise_input (traitlets.Float): Standard deviation of noise. + seed_input (traitlets.Int): Seed for random number generation. + add_gaussian_irf_input (traitlets.Bool): Flag to add Gaussian IRF. + irf_location_input (traitlets.Float): Location of the IRF center. + irf_width_input (traitlets.Float): Width of the IRF. + use_sequential_scheme_input (traitlets.Bool): Flag to use sequential scheme. + model_file_name_input (traitlets.Unicode): Name of the model file. + parameter_file_name_input (traitlets.Unicode): Name of the parameter file. + data_file_name_input (traitlets.Unicode): Name of the data file. + simulate (traitlets.Unicode): Trigger for simulation. + """ + _esm: pathlib.Path = pathlib.Path(__file__).parent / "static" / "form.js" + _css: pathlib.Path = pathlib.Path(__file__).parent / "static" / "form.css" + decay_rates_input: traitlets.List = traitlets.List(trait=traitlets.Float()).tag(sync=True) + amplitude_input: traitlets.List = traitlets.List(trait=traitlets.Float()).tag(sync=True) + location_input: traitlets.List = traitlets.List(trait=traitlets.Float()).tag(sync=True) + width_input: traitlets.List = traitlets.List(trait=traitlets.Float()).tag(sync=True) + skewness_input: traitlets.List = traitlets.List(trait=traitlets.Float()).tag(sync=True) + timepoints_max_input: traitlets.Int = traitlets.Int().tag(sync=True) + timepoints_stepsize_input: traitlets.Float = traitlets.Float().tag(sync=True) + wavelength_min_input: traitlets.Float = traitlets.Float().tag(sync=True) + wavelength_max_input: traitlets.Float = traitlets.Float().tag(sync=True) + wavelength_stepsize_input: traitlets.Float = traitlets.Float().tag(sync=True) + stdev_noise_input: traitlets.Float = traitlets.Float().tag(sync=True) + seed_input: traitlets.Int = traitlets.Int().tag(sync=True) + add_gaussian_irf_input: traitlets.Bool = traitlets.Bool().tag(sync=True) + irf_location_input: traitlets.Float = traitlets.Float().tag(sync=True) + irf_width_input: traitlets.Float = traitlets.Float().tag(sync=True) + use_sequential_scheme_input: traitlets.Bool = traitlets.Bool().tag(sync=True) + model_file_name_input: traitlets.Unicode = traitlets.Unicode("").tag(sync=True) + parameter_file_name_input: traitlets.Unicode = traitlets.Unicode("").tag(sync=True) + data_file_name_input: traitlets.Unicode = traitlets.Unicode("").tag(sync=True) + simulate: traitlets.Unicode = traitlets.Unicode("").tag(sync=True) - lengths = {len(decay_rates), len(amplitudes), len(location), len(width), len(skewness)} - if len(lengths) > 1: - valid_input = False - messages.append("Parameter fields of unequal length") +widget = Widget() - if not valid_input: - self.output_pane.object = pn.pane.Markdown('\n'.join(f"**Error:** {msg}" for msg in messages)) - else: - self.simulation_config = SimulationConfig( - kinetic_parameters=KineticParameters( - decay_rates=decay_rates.tolist() - ), - spectral_parameters=SpectralParameters( - amplitudes=amplitudes.tolist(), - location_mean=location.tolist(), - width=width.tolist(), - skewness=skewness.tolist() - ), - coordinates=generate_simulation_coordinates( - TimeCoordinates( - timepoints_max=self.timepoints_max.value, - timepoints_stepsize=self.timepoints_stepsize.value - ), - SpectralCoordinates( - wavelength_min=self.wavelength_min.value, - wavelength_max=self.wavelength_max.value, - wavelength_stepsize=self.wavelength_stepsize.value - ) - ), - settings=Settings( - stdev_noise=self.stdev_noise.value, - seed=self.seed.value, - add_gaussian_irf=self.add_gaussian_irf.value, - use_sequential_scheme=self.use_sequential_scheme.value - ), - irf=IRF( - center=self.irf_location.value, - width=self.irf_width.value - ) - ) - generate_model_parameter_and_data_files( - self.simulation_config, - model_file_name=self.model_file_name.value, - parameter_file_name=self.parameter_file_name.value, - data_file_name=self.data_file_name.value - ) - self.output_pane.object = "**Simulation successful!**\n\n**Files created!**" - - except Exception as e: - self.output_pane.object = f"**Error in simulation:** {str(e)}" +def _simulate(change) -> None: + """ + A Private callback function for simulating the data based on the parameters, coordinates, and other simulation settings. + + This function generates the model, parameter, and data files using the provided widget inputs. + The 'change' parameter is not used within this function, but it is required to be present + because it represents the state change of the traitlets. This is a common pattern when + using traitlets to observe changes in widget state. + """ + simulation_config = SimulationConfig( + kinetic_parameters=KineticParameters( + decay_rates=widget.decay_rates_input + ), + spectral_parameters=SpectralParameters( + amplitude=widget.amplitude_input, + location=widget.location_input, + width=widget.width_input, + skewness=widget.skewness_input + ), + coordinates=generate_simulation_coordinates( + TimeCoordinates( + timepoints_max=widget.timepoints_max_input, + timepoints_stepsize=widget.timepoints_stepsize_input + ), + SpectralCoordinates( + wavelength_min=widget.wavelength_min_input, + wavelength_max=widget.wavelength_max_input, + wavelength_stepsize=widget.wavelength_stepsize_input + ) + ), + settings=Settings( + stdev_noise=widget.stdev_noise_input, + seed=widget.seed_input, + add_gaussian_irf=widget.add_gaussian_irf_input, + use_sequential_scheme=widget.use_sequential_scheme_input + ), + irf=IRF( + center=widget.irf_location_input, + width=widget.irf_width_input + ) + ) + generate_model_parameter_and_data_files( + simulation_config, + model_file_name=widget.model_file_name_input, + parameter_file_name=widget.parameter_file_name_input, + data_file_name=widget.data_file_name_input + ) -simulation_form = SimulationWidget() -simulation_form.widget.servable() \ No newline at end of file +def setup_widget_observer() -> None: + """ + Sets up the observer pattern on the 'simulate' traitlet to synchronize the frontend widget + with the backend simulation code. This function ensures that any changes in the widget's state + trigger the simulation process, which generates the model, parameter, and data files. + """ + widget.observe(handler=_simulate, names=['simulate']) diff --git a/pyproject.toml b/pyproject.toml index fd05725..efe0ef6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,7 +15,7 @@ keywords = [ license = { file = "LICENSE" } authors = [ - { name = "Anmol Bhatia", email = "a.bhatia2@student.vu.nl" }, + { name = "Anmol Bhatia", email = "anmolbhatia05@gmail.com " }, ] requires-python = ">=3.8" classifiers = [ @@ -35,6 +35,10 @@ dynamic = [ ] dependencies = [ + "pydantic==2.8.2", + "anywidget==0.9.13", + "pyglotaran==0.7.2", + "pyyaml==6.0.1", ] optional-dependencies.dev = [ "pyparamgui[docs,test]", diff --git a/requirements_pinned.txt b/requirements_pinned.txt index 38ff788..14f376c 100644 --- a/requirements_pinned.txt +++ b/requirements_pinned.txt @@ -1,2 +1,6 @@ # runtime requirements # pinned so the bot can create PRs to test with new versions + +pydantic==2.8.2 +anywidget==0.9.13 +pyglotaran==0.7.2 From 7f25805264d3dc23533b39de185adac18b6c6c74 Mon Sep 17 00:00:00 2001 From: anmolbhatia05 Date: Fri, 26 Jul 2024 19:47:41 +0200 Subject: [PATCH 03/10] changes from pre-commit --- pyparamgui/__init__.py | 5 +- pyparamgui/generator.py | 8 +- pyparamgui/schema.py | 57 ++-- pyparamgui/static/form.css | 30 +- pyparamgui/static/form.js | 595 +++++++++++++++++++++++-------------- pyparamgui/utils.py | 136 ++++++--- pyparamgui/widget.py | 61 ++-- pyproject.toml | 2 +- requirements_pinned.txt | 1 + 9 files changed, 561 insertions(+), 334 deletions(-) diff --git a/pyparamgui/__init__.py b/pyparamgui/__init__.py index 13ebf4e..16d4f9f 100644 --- a/pyparamgui/__init__.py +++ b/pyparamgui/__init__.py @@ -6,6 +6,7 @@ __email__ = "anmolbhatia05@gmail.com" __version__ = "0.0.1" -from .widget import widget, setup_widget_observer +from pyparamgui.widget import setup_widget_observer +from pyparamgui.widget import widget -__all__ = ['widget', 'setup_widget_observer'] +__all__ = ["widget", "setup_widget_observer"] diff --git a/pyparamgui/generator.py b/pyparamgui/generator.py index d278ae1..5906897 100644 --- a/pyparamgui/generator.py +++ b/pyparamgui/generator.py @@ -1,4 +1,4 @@ -"""The glotaran generator module.""" +"""The glotaran generator module.""" from __future__ import annotations @@ -60,7 +60,7 @@ def _generate_decay_model( "amplitude": f"shapes.species_{i+1}.amplitude", "location": f"shapes.species_{i+1}.location", "width": f"shapes.species_{i+1}.width", - "skewness": f"shapes.species_{i+1}.skewness" + "skewness": f"shapes.species_{i+1}.skewness", } for i in range(nr_compartments) } @@ -68,7 +68,7 @@ def _generate_decay_model( "megacomplex_spectral" ] model["dataset"]["dataset_1"]["spectral_axis_inverted"] = True - model["dataset"]["dataset_1"]["spectral_axis_scale"] = 1E7 + model["dataset"]["dataset_1"]["spectral_axis_scale"] = 1e7 if irf: model["dataset"]["dataset_1"]["irf"] = "gaussian_irf" # type:ignore[index] model["irf"] = { @@ -193,7 +193,7 @@ class GeneratorArguments(TypedDict, total=False): irf: bool -def generate_model(*, generator_name: str, generator_arguments: GeneratorArguments) -> Model: +def generate_model(*, generator_name: str, generator_arguments: GeneratorArguments) -> Model: """Generate a model. Parameters diff --git a/pyparamgui/schema.py b/pyparamgui/schema.py index f470fb7..48ac321 100644 --- a/pyparamgui/schema.py +++ b/pyparamgui/schema.py @@ -1,11 +1,14 @@ -"""This module has the different model classes representing different parameters, coordinates, and settings for simulation.""" +"""This module has the different model classes representing different parameters, coordinates, and +settings for simulation.""" from __future__ import annotations from typing import Dict import numpy as np -from pydantic import BaseModel, ConfigDict +from pydantic import BaseModel +from pydantic import ConfigDict + class KineticParameters(BaseModel): """Kinetic parameters for the simulation. @@ -13,8 +16,10 @@ class KineticParameters(BaseModel): Attributes: decay_rates (list[float]): List of decay rates. """ + decay_rates: list[float] + class SpectralParameters(BaseModel): """Spectral parameters for the simulation. @@ -24,38 +29,43 @@ class SpectralParameters(BaseModel): width (list[float]): List of widths. skewness (list[float]): List of skewness values. """ + amplitude: list[float] location: list[float] width: list[float] skewness: list[float] + class TimeCoordinates(BaseModel): - """ - Time coordinates for the simulation. + """Time coordinates for the simulation. Attributes: timepoints_max (int): Maximum number of time points. timepoints_stepsize (float): Step size between time points. """ + timepoints_max: int timepoints_stepsize: float + class SpectralCoordinates(BaseModel): - """ - Spectral coordinates for the simulation. + """Spectral coordinates for the simulation. Attributes: wavelength_min (int): Minimum wavelength. wavelength_max (int): Maximum wavelength. wavelength_stepsize (float): Step size between wavelengths. """ + wavelength_min: int wavelength_max: int wavelength_stepsize: float - -def generate_simulation_coordinates(time_coordinates: TimeCoordinates, spectral_coordinates: SpectralCoordinates) -> Dict[str, np.ndarray]: - """ - Generate simulation coordinates based on time and spectral coordinates. + + +def generate_simulation_coordinates( + time_coordinates: TimeCoordinates, spectral_coordinates: SpectralCoordinates +) -> Dict[str, np.ndarray]: + """Generate simulation coordinates based on time and spectral coordinates. Args: time_coordinates (TimeCoordinates): The time coordinates for the simulation. @@ -64,13 +74,21 @@ def generate_simulation_coordinates(time_coordinates: TimeCoordinates, spectral_ Returns: Dict[str, np.ndarray]: A dictionary containing the time and spectral axes as numpy arrays. """ - time_axis = np.arange(0, time_coordinates.timepoints_max * time_coordinates.timepoints_stepsize, time_coordinates.timepoints_stepsize) - spectral_axis = np.arange(spectral_coordinates.wavelength_min, spectral_coordinates.wavelength_max, spectral_coordinates.wavelength_stepsize) + time_axis = np.arange( + 0, + time_coordinates.timepoints_max * time_coordinates.timepoints_stepsize, + time_coordinates.timepoints_stepsize, + ) + spectral_axis = np.arange( + spectral_coordinates.wavelength_min, + spectral_coordinates.wavelength_max, + spectral_coordinates.wavelength_stepsize, + ) return {"time": time_axis, "spectral": spectral_axis} + class Settings(BaseModel): - """ - Other settings for the simulation. + """Other settings for the simulation. Attributes: stdev_noise (float): Standard deviation of the noise to be added to the simulation data. @@ -78,25 +96,27 @@ class Settings(BaseModel): add_gaussian_irf (bool): Flag to indicate whether to add a Gaussian Instrument Response Function (IRF) to the simulation. Default is False. use_sequential_scheme (bool): Flag to indicate whether to use a sequential scheme in the simulation. Default is False. """ + stdev_noise: float seed: int add_gaussian_irf: bool = False use_sequential_scheme: bool = False + class IRF(BaseModel): - """ - Instrument Response Function (IRF) settings for the simulation. + """Instrument Response Function (IRF) settings for the simulation. Attributes: center (float): The center position of the IRF. width (float): The width of the IRF. """ + center: float width: float + class SimulationConfig(BaseModel): - """ - Configuration for the simulation, combining various parameters and settings. + """Configuration for the simulation, combining various parameters and settings. Attributes: kinetic_parameters (KineticParameters): Kinetic parameters for the simulation. @@ -105,6 +125,7 @@ class SimulationConfig(BaseModel): settings (Settings): Other settings for the simulation, including noise standard deviation, random seed, and flags for adding Gaussian IRF and using a sequential scheme. irf (IRF): Instrument Response Function (IRF) settings, including center position and width. """ + kinetic_parameters: KineticParameters spectral_parameters: SpectralParameters coordinates: Dict[str, np.ndarray] diff --git a/pyparamgui/static/form.css b/pyparamgui/static/form.css index 4d6f0e2..4053a4c 100644 --- a/pyparamgui/static/form.css +++ b/pyparamgui/static/form.css @@ -1,26 +1,26 @@ form { - display: flex; - flex-direction: column; + display: flex; + flex-direction: column; } .form-group { - display: flex; - align-items: center; - margin-bottom: 10px; + display: flex; + align-items: center; + margin-bottom: 10px; } label { - margin-right: 10px; - color: black; - width: 150px; + margin-right: 10px; + color: black; + width: 150px; } input { - flex: 1; - margin: 5px 0; + flex: 1; + margin: 5px 0; } hr { - margin: 20px 0; - border: none; - border-top: 1px solid #ccc; + margin: 20px 0; + border: none; + border-top: 1px solid #ccc; } button { - margin-top: 10px; -} \ No newline at end of file + margin-top: 10px; +} diff --git a/pyparamgui/static/form.js b/pyparamgui/static/form.js index 0c5ca4a..33a8292 100644 --- a/pyparamgui/static/form.js +++ b/pyparamgui/static/form.js @@ -1,231 +1,384 @@ function render({ model, el }) { - const form = document.createElement('form'); - - /** - * Creates a form group with a label and a text input field. - * - * @param {string} labelText - The text content for the label. - * @param {string} inputId - The id attribute for the input element. - * @param {string} inputName - The name attribute for the input element. - * @param {string} inputValue - The initial value for the input element. - * - * @returns {HTMLDivElement} The form group element containing the label and input. - */ - function createTextFormGroup(labelText, inputId, inputName, inputValue) { - const formGroup = document.createElement('div'); - formGroup.className = 'form-group'; - - const label = document.createElement('label'); - label.setAttribute('for', inputId); - label.textContent = labelText; - - const input = document.createElement('input'); - input.setAttribute('type', 'text'); - input.setAttribute('id', inputId); - input.setAttribute('name', inputName); - input.value = inputValue; - - formGroup.appendChild(label); - formGroup.appendChild(input); - - return formGroup; - } + const form = document.createElement("form"); - /** - * Creates a form group with a label and a checkbox input field. - * - * @param {string} labelText - The text content for the label. - * @param {string} inputId - The id attribute for the input element. - * @param {string} inputName - The name attribute for the input element. - * - * @returns {HTMLDivElement} The form group element containing the label and checkbox input. - */ - function createCheckboxFormGroup(labelText, inputId, inputName) { - const formGroup = document.createElement('div'); - formGroup.className = 'form-group'; - - const label = document.createElement('label'); - label.setAttribute('for', inputId); - label.textContent = labelText; - - const input = document.createElement('input'); - input.setAttribute('type', 'checkbox'); - input.setAttribute('id', inputId); - input.setAttribute('name', inputName); - - formGroup.appendChild(label); - formGroup.appendChild(input); - - return formGroup; - } + /** + * Creates a form group with a label and a text input field. + * + * @param {string} labelText - The text content for the label. + * @param {string} inputId - The id attribute for the input element. + * @param {string} inputName - The name attribute for the input element. + * @param {string} inputValue - The initial value for the input element. + * + * @returns {HTMLDivElement} The form group element containing the label and input. + */ + function createTextFormGroup(labelText, inputId, inputName, inputValue) { + const formGroup = document.createElement("div"); + formGroup.className = "form-group"; + + const label = document.createElement("label"); + label.setAttribute("for", inputId); + label.textContent = labelText; + + const input = document.createElement("input"); + input.setAttribute("type", "text"); + input.setAttribute("id", inputId); + input.setAttribute("name", inputName); + input.value = inputValue; + + formGroup.appendChild(label); + formGroup.appendChild(input); + + return formGroup; + } + + /** + * Creates a form group with a label and a checkbox input field. + * + * @param {string} labelText - The text content for the label. + * @param {string} inputId - The id attribute for the input element. + * @param {string} inputName - The name attribute for the input element. + * + * @returns {HTMLDivElement} The form group element containing the label and checkbox input. + */ + function createCheckboxFormGroup(labelText, inputId, inputName) { + const formGroup = document.createElement("div"); + formGroup.className = "form-group"; + + const label = document.createElement("label"); + label.setAttribute("for", inputId); + label.textContent = labelText; + + const input = document.createElement("input"); + input.setAttribute("type", "checkbox"); + input.setAttribute("id", inputId); + input.setAttribute("name", inputName); + + formGroup.appendChild(label); + formGroup.appendChild(input); + + return formGroup; + } - form.appendChild(createTextFormGroup('Decay rates:', 'decay_rates_input', 'decay_rates_input', '0.055, 0.005')); - form.appendChild(document.createElement('hr')); - form.appendChild(createTextFormGroup('Amplitudes:', 'amplitude_input', 'amplitude_input', '1., 1.')); - form.appendChild(createTextFormGroup('Location (mean) of spectra:', 'location_input', 'location_input', '22000, 20000')); - form.appendChild(createTextFormGroup('Width of spectra:', 'width_input', 'width_input', '4000, 3500')); - form.appendChild(createTextFormGroup('Skewness of spectra:', 'skewness_input', 'skewness_input', '0.1, -0.1')); - form.appendChild(document.createElement('hr')); - form.appendChild(createTextFormGroup('Timepoints, max:', 'timepoints_max_input', 'timepoints_max_input', '80')); - form.appendChild(createTextFormGroup('Stepsize:', 'timepoints_stepsize_input', 'timepoints_stepsize_input', '1')); - form.appendChild(document.createElement('hr')); - form.appendChild(createTextFormGroup('Wavelength Min:', 'wavelength_min_input', 'wavelength_min_input', '400')); - form.appendChild(createTextFormGroup('Wavelength Max:', 'wavelength_max_input', 'wavelength_max_input', '600')); - form.appendChild(createTextFormGroup('Stepsize:', 'wavelength_stepsize_input', 'wavelength_stepsize_input', '5')); - form.appendChild(document.createElement('hr')); - form.appendChild(createTextFormGroup('Std.dev. noise:', 'stdev_noise_input', 'stdev_noise_input', '0.01')); - form.appendChild(createTextFormGroup('Seed:', 'seed_input', 'seed_input', '123')); - form.appendChild(document.createElement('hr')); - form.appendChild(createCheckboxFormGroup('Add Gaussian IRF:', 'add_gaussian_irf_input', 'add_gaussian_irf_input')); - form.appendChild(createTextFormGroup('IRF location:', 'irf_location_input', 'irf_location_input', '3')); - form.appendChild(createTextFormGroup('IRF width:', 'irf_width_input', 'irf_width_input', '1')); - form.appendChild(document.createElement('hr')); - form.appendChild(createCheckboxFormGroup('Use Sequential Scheme:', 'use_sequential_scheme_input', 'use_sequential_scheme_input')); - form.appendChild(document.createElement('hr')); - form.appendChild(createTextFormGroup('Model File Name:', 'model_file_name_input', 'model_file_name_input', 'model.yml')); - form.appendChild(createTextFormGroup('Parameter File Name:', 'parameter_file_name_input', 'parameter_file_name_input', 'parameters.csv')); - form.appendChild(createTextFormGroup('Data File Name:', 'data_file_name_input', 'data_file_name_input', 'dataset.nc')); - - el.appendChild(form); - - /** - * Converts the input values from the form into their respective data types. - * - * @returns {Object|null} An object containing the converted input values, or null if an error occurs. - * - * @property {number[]} decay_rates - Array of decay rates as floats. - * @property {number[]} amplitude - Array of amplitudes as floats. - * @property {number[]} location - Array of locations as floats. - * @property {number[]} width - Array of widths as floats. - * @property {number[]} skewness - Array of skewness values as floats. - * @property {number} timepoints_max - Maximum number of timepoints as an integer. - * @property {number} timepoints_stepsize - Step size for timepoints as a float. - * @property {number} wavelength_min - Minimum wavelength value as a float. - * @property {number} wavelength_max - Maximum wavelength value as a float. - * @property {number} wavelength_stepsize - Step size for wavelength as a float. - * @property {number} stdev_noise - Standard deviation of noise as a float. - * @property {number} seed - Seed for random number generation as an integer. - * @property {number} irf_location - Location of the IRF center as a float. - * @property {number} irf_width - Width of the IRF as a float. - */ - function convertInputs() { - try { - const decay_rates = decay_rates_input.value.split(',').map(parseFloat); - const amplitude = amplitude_input.value.split(',').map(parseFloat); - const location = location_input.value.split(',').map(parseFloat); - const width = width_input.value.split(',').map(parseFloat); - const skewness = skewness_input.value.split(',').map(parseFloat); - const timepoints_max = parseInt(timepoints_max_input.value, 10); - const timepoints_stepsize = parseFloat(timepoints_stepsize_input.value); - const wavelength_min = parseFloat(wavelength_min_input.value); - const wavelength_max = parseFloat(wavelength_max_input.value); - const wavelength_stepsize = parseFloat(wavelength_stepsize_input.value); - const stdev_noise = parseFloat(stdev_noise_input.value); - const seed = parseInt(seed_input.value, 10); - const irf_location = parseFloat(irf_location_input.value); - const irf_width = parseFloat(irf_width_input.value); - - return { decay_rates, amplitude, location, width, skewness, timepoints_max, timepoints_stepsize, wavelength_min, wavelength_max, wavelength_stepsize, stdev_noise, seed, irf_location, irf_width }; - } catch (error) { - alert('Error converting inputs: ' + error.message); - return null; - } + form.appendChild( + createTextFormGroup( + "Decay rates:", + "decay_rates_input", + "decay_rates_input", + "0.055, 0.005", + ), + ); + form.appendChild(document.createElement("hr")); + form.appendChild( + createTextFormGroup( + "Amplitudes:", + "amplitude_input", + "amplitude_input", + "1., 1.", + ), + ); + form.appendChild( + createTextFormGroup( + "Location (mean) of spectra:", + "location_input", + "location_input", + "22000, 20000", + ), + ); + form.appendChild( + createTextFormGroup( + "Width of spectra:", + "width_input", + "width_input", + "4000, 3500", + ), + ); + form.appendChild( + createTextFormGroup( + "Skewness of spectra:", + "skewness_input", + "skewness_input", + "0.1, -0.1", + ), + ); + form.appendChild(document.createElement("hr")); + form.appendChild( + createTextFormGroup( + "Timepoints, max:", + "timepoints_max_input", + "timepoints_max_input", + "80", + ), + ); + form.appendChild( + createTextFormGroup( + "Stepsize:", + "timepoints_stepsize_input", + "timepoints_stepsize_input", + "1", + ), + ); + form.appendChild(document.createElement("hr")); + form.appendChild( + createTextFormGroup( + "Wavelength Min:", + "wavelength_min_input", + "wavelength_min_input", + "400", + ), + ); + form.appendChild( + createTextFormGroup( + "Wavelength Max:", + "wavelength_max_input", + "wavelength_max_input", + "600", + ), + ); + form.appendChild( + createTextFormGroup( + "Stepsize:", + "wavelength_stepsize_input", + "wavelength_stepsize_input", + "5", + ), + ); + form.appendChild(document.createElement("hr")); + form.appendChild( + createTextFormGroup( + "Std.dev. noise:", + "stdev_noise_input", + "stdev_noise_input", + "0.01", + ), + ); + form.appendChild( + createTextFormGroup("Seed:", "seed_input", "seed_input", "123"), + ); + form.appendChild(document.createElement("hr")); + form.appendChild( + createCheckboxFormGroup( + "Add Gaussian IRF:", + "add_gaussian_irf_input", + "add_gaussian_irf_input", + ), + ); + form.appendChild( + createTextFormGroup( + "IRF location:", + "irf_location_input", + "irf_location_input", + "3", + ), + ); + form.appendChild( + createTextFormGroup( + "IRF width:", + "irf_width_input", + "irf_width_input", + "1", + ), + ); + form.appendChild(document.createElement("hr")); + form.appendChild( + createCheckboxFormGroup( + "Use Sequential Scheme:", + "use_sequential_scheme_input", + "use_sequential_scheme_input", + ), + ); + form.appendChild(document.createElement("hr")); + form.appendChild( + createTextFormGroup( + "Model File Name:", + "model_file_name_input", + "model_file_name_input", + "model.yml", + ), + ); + form.appendChild( + createTextFormGroup( + "Parameter File Name:", + "parameter_file_name_input", + "parameter_file_name_input", + "parameters.csv", + ), + ); + form.appendChild( + createTextFormGroup( + "Data File Name:", + "data_file_name_input", + "data_file_name_input", + "dataset.nc", + ), + ); + + el.appendChild(form); + + /** + * Converts the input values from the form into their respective data types. + * + * @returns {Object|null} An object containing the converted input values, or null if an error occurs. + * + * @property {number[]} decay_rates - Array of decay rates as floats. + * @property {number[]} amplitude - Array of amplitudes as floats. + * @property {number[]} location - Array of locations as floats. + * @property {number[]} width - Array of widths as floats. + * @property {number[]} skewness - Array of skewness values as floats. + * @property {number} timepoints_max - Maximum number of timepoints as an integer. + * @property {number} timepoints_stepsize - Step size for timepoints as a float. + * @property {number} wavelength_min - Minimum wavelength value as a float. + * @property {number} wavelength_max - Maximum wavelength value as a float. + * @property {number} wavelength_stepsize - Step size for wavelength as a float. + * @property {number} stdev_noise - Standard deviation of noise as a float. + * @property {number} seed - Seed for random number generation as an integer. + * @property {number} irf_location - Location of the IRF center as a float. + * @property {number} irf_width - Width of the IRF as a float. + */ + function convertInputs() { + try { + const decay_rates = decay_rates_input.value.split(",").map(parseFloat); + const amplitude = amplitude_input.value.split(",").map(parseFloat); + const location = location_input.value.split(",").map(parseFloat); + const width = width_input.value.split(",").map(parseFloat); + const skewness = skewness_input.value.split(",").map(parseFloat); + const timepoints_max = parseInt(timepoints_max_input.value, 10); + const timepoints_stepsize = parseFloat(timepoints_stepsize_input.value); + const wavelength_min = parseFloat(wavelength_min_input.value); + const wavelength_max = parseFloat(wavelength_max_input.value); + const wavelength_stepsize = parseFloat(wavelength_stepsize_input.value); + const stdev_noise = parseFloat(stdev_noise_input.value); + const seed = parseInt(seed_input.value, 10); + const irf_location = parseFloat(irf_location_input.value); + const irf_width = parseFloat(irf_width_input.value); + + return { + decay_rates, + amplitude, + location, + width, + skewness, + timepoints_max, + timepoints_stepsize, + wavelength_min, + wavelength_max, + wavelength_stepsize, + stdev_noise, + seed, + irf_location, + irf_width, + }; + } catch (error) { + alert("Error converting inputs: " + error.message); + return null; } + } + + /** + * Validates the input values for the simulation. + * + * @param {Object} inputs - The input values to validate. + * + * @param {number[]} inputs.decay_rates - Array of decay rates as floats. + * @param {number[]} inputs.amplitude - Array of amplitudes as floats. + * @param {number[]} inputs.location - Array of locations as floats. + * @param {number[]} inputs.width - Array of widths as floats. + * @param {number[]} inputs.skewness - Array of skewness values as floats. + * @param {number} inputs.wavelength_min - Minimum wavelength value as a float. + * @param {number} inputs.wavelength_max - Maximum wavelength value as a float. + * @param {number} inputs.timepoints_max - Maximum number of timepoints as an integer. + * + * @returns {boolean} True if all inputs are valid, otherwise false. + */ + function validateInputs(inputs) { + try { + const { decay_rates, amplitude, location, width, skewness } = inputs; + + if (decay_rates.some(isNaN)) { + alert("Invalid decay rates"); + return false; + } + if (amplitude.some(isNaN)) { + alert("Invalid amplitudes"); + return false; + } + if (location.some(isNaN)) { + alert("Invalid locations"); + return false; + } + if (width.some(isNaN)) { + alert("Invalid widths"); + return false; + } + if (skewness.some(isNaN)) { + alert("Invalid skewness values"); + return false; + } + + const lengths = [ + decay_rates.length, + amplitude.length, + location.length, + width.length, + skewness.length, + ]; + if (new Set(lengths).size !== 1) { + alert("All input lists must have the same length"); + return false; + } - /** - * Validates the input values for the simulation. - * - * @param {Object} inputs - The input values to validate. - * - * @param {number[]} inputs.decay_rates - Array of decay rates as floats. - * @param {number[]} inputs.amplitude - Array of amplitudes as floats. - * @param {number[]} inputs.location - Array of locations as floats. - * @param {number[]} inputs.width - Array of widths as floats. - * @param {number[]} inputs.skewness - Array of skewness values as floats. - * @param {number} inputs.wavelength_min - Minimum wavelength value as a float. - * @param {number} inputs.wavelength_max - Maximum wavelength value as a float. - * @param {number} inputs.timepoints_max - Maximum number of timepoints as an integer. - * - * @returns {boolean} True if all inputs are valid, otherwise false. - */ - function validateInputs(inputs) { - try { - const { decay_rates, amplitude, location, width, skewness } = inputs; - - if (decay_rates.some(isNaN)) { - alert('Invalid decay rates'); - return false; - } - if (amplitude.some(isNaN)) { - alert('Invalid amplitudes'); - return false; - } - if (location.some(isNaN)) { - alert('Invalid locations'); - return false; - } - if (width.some(isNaN)) { - alert('Invalid widths'); - return false; - } - if (skewness.some(isNaN)) { - alert('Invalid skewness values'); - return false; - } - - const lengths = [decay_rates.length, amplitude.length, location.length, width.length, skewness.length]; - if (new Set(lengths).size !== 1) { - alert('All input lists must have the same length'); - return false; - } - - if (inputs.wavelength_min >= inputs.wavelength_max || inputs.timepoints_max <= 0) { - alert('Invalid timepoints or wavelength specification'); - return false; - } - - return true; - } catch (error) { - alert('Validation error: ' + error.message); - return false; - } + if ( + inputs.wavelength_min >= inputs.wavelength_max || + inputs.timepoints_max <= 0 + ) { + alert("Invalid timepoints or wavelength specification"); + return false; + } + + return true; + } catch (error) { + alert("Validation error: " + error.message); + return false; } + } + + const btn = document.createElement("button"); + btn.textContent = "Simulate"; + btn.addEventListener("click", function (event) { + event.preventDefault(); + + const convertedInputs = convertInputs(); + if (!convertedInputs) return; + + const isValid = validateInputs(convertedInputs); + if (!isValid) return; + + model.set("decay_rates_input", convertedInputs.decay_rates); + model.set("amplitude_input", convertedInputs.amplitude); + model.set("location_input", convertedInputs.location); + model.set("width_input", convertedInputs.width); + model.set("skewness_input", convertedInputs.skewness); + model.set("timepoints_max_input", convertedInputs.timepoints_max); + model.set("timepoints_stepsize_input", convertedInputs.timepoints_stepsize); + model.set("wavelength_min_input", convertedInputs.wavelength_min); + model.set("wavelength_max_input", convertedInputs.wavelength_max); + model.set("wavelength_stepsize_input", convertedInputs.wavelength_stepsize); + model.set("stdev_noise_input", convertedInputs.stdev_noise); + model.set("seed_input", convertedInputs.seed); + model.set("add_gaussian_irf_input", add_gaussian_irf_input.checked); + model.set("irf_location_input", convertedInputs.irf_location); + model.set("irf_width_input", convertedInputs.irf_width); + model.set( + "use_sequential_scheme_input", + use_sequential_scheme_input.checked, + ); + model.set("model_file_name_input", model_file_name_input.value); + model.set("parameter_file_name_input", parameter_file_name_input.value); + model.set("data_file_name_input", data_file_name_input.value); + model.set("simulate", self.crypto.randomUUID()); - const btn = document.createElement("button"); - btn.textContent = 'Simulate'; - btn.addEventListener('click', function(event) { - event.preventDefault(); - - const convertedInputs = convertInputs(); - if (!convertedInputs) return; - - const isValid = validateInputs(convertedInputs); - if (!isValid) return; - - model.set("decay_rates_input", convertedInputs.decay_rates); - model.set("amplitude_input", convertedInputs.amplitude); - model.set("location_input", convertedInputs.location); - model.set("width_input", convertedInputs.width); - model.set("skewness_input", convertedInputs.skewness); - model.set("timepoints_max_input", convertedInputs.timepoints_max); - model.set("timepoints_stepsize_input", convertedInputs.timepoints_stepsize); - model.set("wavelength_min_input", convertedInputs.wavelength_min); - model.set("wavelength_max_input", convertedInputs.wavelength_max); - model.set("wavelength_stepsize_input", convertedInputs.wavelength_stepsize); - model.set("stdev_noise_input", convertedInputs.stdev_noise); - model.set("seed_input", convertedInputs.seed); - model.set("add_gaussian_irf_input", add_gaussian_irf_input.checked); - model.set("irf_location_input", convertedInputs.irf_location); - model.set("irf_width_input", convertedInputs.irf_width); - model.set("use_sequential_scheme_input", use_sequential_scheme_input.checked); - model.set("model_file_name_input", model_file_name_input.value); - model.set("parameter_file_name_input", parameter_file_name_input.value); - model.set("data_file_name_input", data_file_name_input.value); - model.set("simulate", self.crypto.randomUUID()); - - model.save_changes(); - }); - el.appendChild(btn); + model.save_changes(); + }); + el.appendChild(btn); } export default { render }; diff --git a/pyparamgui/utils.py b/pyparamgui/utils.py index 52aca49..a50eb14 100644 --- a/pyparamgui/utils.py +++ b/pyparamgui/utils.py @@ -1,25 +1,31 @@ -"""This module has various utility functions related to generating files, sanitizing yaml files, etc.""" +"""This module has various utility functions related to generating files, sanitizing yaml files, +etc.""" from __future__ import annotations -from typing import Any, Dict, Union -import yaml import os +from typing import Any +from typing import Dict +from typing import Union import numpy as np -from glotaran.model.model import Model +import yaml from glotaran.builtin.io.yml.yml import save_model +from glotaran.model.model import Model from glotaran.parameter.parameters import Parameters -from glotaran.plugin_system.project_io_registration import save_parameters from glotaran.plugin_system.data_io_registration import save_dataset +from glotaran.plugin_system.project_io_registration import save_parameters from glotaran.simulation.simulation import simulate from pyparamgui.generator import generate_model -from pyparamgui.schema import SimulationConfig, Settings +from pyparamgui.schema import Settings +from pyparamgui.schema import SimulationConfig -def _generate_model_file(simulation_config: SimulationConfig, nr_compartments: int, file_name: str) -> Model: - """ - Generate and save a model file for the simulation. + +def _generate_model_file( + simulation_config: SimulationConfig, nr_compartments: int, file_name: str +) -> Model: + """Generate and save a model file for the simulation. This function generates a model based on the provided simulation configuration and number of compartments. It saves the generated model to a temporary YAML file, sanitizes the file, and then saves it to the specified file name. @@ -32,15 +38,25 @@ def _generate_model_file(simulation_config: SimulationConfig, nr_compartments: i Returns: Model: The generated model. """ - generator_name = "spectral_decay_sequential" if simulation_config.settings.use_sequential_scheme else "spectral_decay_parallel" - model = generate_model(generator_name=generator_name, generator_arguments={"nr_compartments": nr_compartments, "irf": simulation_config.settings.add_gaussian_irf}) + generator_name = ( + "spectral_decay_sequential" + if simulation_config.settings.use_sequential_scheme + else "spectral_decay_parallel" + ) + model = generate_model( + generator_name=generator_name, + generator_arguments={ + "nr_compartments": nr_compartments, + "irf": simulation_config.settings.add_gaussian_irf, + }, + ) save_model(model, "temp_model.yml", allow_overwrite=True) _sanitize_yaml_file("temp_model.yml", file_name) return model + def _update_parameter_values(parameters: Parameters, simulation_config: SimulationConfig): - """ - Update parameter values based on the simulation configuration. + """Update parameter values based on the simulation configuration. This function iterates through all parameters and updates their values according to the provided simulation configuration. It handles parameters related to spectral shapes, @@ -55,34 +71,36 @@ def _update_parameter_values(parameters: Parameters, simulation_config: Simulati """ for param in parameters.all(): label = param.label - if label.startswith('shapes.species_'): - parts = label.split('.') - species_index = int(parts[1].split('_')[1]) - 1 + if label.startswith("shapes.species_"): + parts = label.split(".") + species_index = int(parts[1].split("_")[1]) - 1 attribute = parts[2] - - if attribute == 'amplitude': + + if attribute == "amplitude": param.value = simulation_config.spectral_parameters.amplitude[species_index] - elif attribute == 'location': + elif attribute == "location": param.value = simulation_config.spectral_parameters.location[species_index] - elif attribute == 'width': + elif attribute == "width": param.value = simulation_config.spectral_parameters.width[species_index] - elif attribute == 'skewness': + elif attribute == "skewness": param.value = simulation_config.spectral_parameters.skewness[species_index] - elif label.startswith('rates.species_'): - species_index = int(label.split('_')[1]) - 1 + elif label.startswith("rates.species_"): + species_index = int(label.split("_")[1]) - 1 param.value = simulation_config.kinetic_parameters.decay_rates[species_index] - elif label.startswith('irf') and simulation_config.settings.add_gaussian_irf: - if 'width' in label: + elif label.startswith("irf") and simulation_config.settings.add_gaussian_irf: + if "width" in label: param.value = simulation_config.irf.width - if 'center' in label: + if "center" in label: param.value = simulation_config.irf.center return parameters -def _generate_parameter_file(simulation_config: SimulationConfig, model: Model, file_name: str) -> Parameters: - """ - Generate and save the parameter file for the simulation. + +def _generate_parameter_file( + simulation_config: SimulationConfig, model: Model, file_name: str +) -> Parameters: + """Generate and save the parameter file for the simulation. This function generates the parameters for the given model, updates them based on the simulation configuration, validates the updated parameters, and saves them to a file. @@ -101,9 +119,15 @@ def _generate_parameter_file(simulation_config: SimulationConfig, model: Model, save_parameters(updated_parameters, file_name, allow_overwrite=True) return updated_parameters -def _generate_data_file(model: Model, parameters: Parameters, coordinates: Dict[str, np.ndarray], settings: Settings, file_name: str): - """ - Generate and save the data file for the simulation. + +def _generate_data_file( + model: Model, + parameters: Parameters, + coordinates: Dict[str, np.ndarray], + settings: Settings, + file_name: str, +): + """Generate and save the data file for the simulation. This function simulates the data based on the given model, parameters, coordinates, and settings, and saves the simulated data to a file. @@ -116,12 +140,25 @@ def _generate_data_file(model: Model, parameters: Parameters, coordinates: Dict[ file_name (str): The name of the file to save the simulated data. """ noise = False if settings.stdev_noise == 0 else True - data = simulate(model, "dataset_1", parameters, coordinates, noise=noise, noise_std_dev=settings.stdev_noise, noise_seed=settings.seed) + data = simulate( + model, + "dataset_1", + parameters, + coordinates, + noise=noise, + noise_std_dev=settings.stdev_noise, + noise_seed=settings.seed, + ) save_dataset(data, file_name, "nc", allow_overwrite=True) -def generate_model_parameter_and_data_files(simulation_config: SimulationConfig, model_file_name: str = "model.yml", parameter_file_name: str = "parameters.csv", data_file_name: str = "dataset.nc"): - """ - Generate and save the model, parameter, and data files for the simulation. + +def generate_model_parameter_and_data_files( + simulation_config: SimulationConfig, + model_file_name: str = "model.yml", + parameter_file_name: str = "parameters.csv", + data_file_name: str = "dataset.nc", +): + """Generate and save the model, parameter, and data files for the simulation. This function generates the model file, parameter file, and data file based on the given simulation configuration. @@ -134,11 +171,18 @@ def generate_model_parameter_and_data_files(simulation_config: SimulationConfig, nr_compartments = len(simulation_config.kinetic_parameters.decay_rates) model = _generate_model_file(simulation_config, nr_compartments, model_file_name) parameters = _generate_parameter_file(simulation_config, model, parameter_file_name) - _generate_data_file(model, parameters, simulation_config.coordinates, simulation_config.settings, data_file_name) + _generate_data_file( + model, + parameters, + simulation_config.coordinates, + simulation_config.settings, + data_file_name, + ) + def _sanitize_dict(d: Union[Dict[str, Any], Any]) -> Union[Dict[str, Any], Any]: - """ - Recursively sanitize a dictionary by removing keys with values that are None, empty lists, or empty dictionaries. + """Recursively sanitize a dictionary by removing keys with values that are None, empty lists, + or empty dictionaries. Args: d (Union[Dict[str, Any], Any]): The dictionary to sanitize or any other value. @@ -150,20 +194,20 @@ def _sanitize_dict(d: Union[Dict[str, Any], Any]) -> Union[Dict[str, Any], Any]: return d return {k: _sanitize_dict(v) for k, v in d.items() if v not in (None, [], {})} + def _sanitize_yaml_file(input_file: str, output_file: str) -> None: - """ - Sanitize a YAML file by removing keys with values that are None, empty lists, or empty dictionaries, - and save the sanitized content to a new file. + """Sanitize a YAML file by removing keys with values that are None, empty lists, or empty + dictionaries, and save the sanitized content to a new file. Args: input_file (str): The path to the input YAML file. output_file (str): The path to the output sanitized YAML file. """ - with open(input_file, 'r') as f: + with open(input_file, "r") as f: data = yaml.safe_load(f) - + sanitized_data = _sanitize_dict(data) - - with open(output_file, 'w') as f: + + with open(output_file, "w") as f: yaml.safe_dump(sanitized_data, f) os.remove(input_file) diff --git a/pyparamgui/widget.py b/pyparamgui/widget.py index af8a021..2838eda 100644 --- a/pyparamgui/widget.py +++ b/pyparamgui/widget.py @@ -4,15 +4,22 @@ import pathlib -import traitlets import anywidget +import traitlets -from pyparamgui.schema import KineticParameters, SpectralParameters, TimeCoordinates, SpectralCoordinates, Settings, IRF, SimulationConfig, generate_simulation_coordinates +from pyparamgui.schema import IRF +from pyparamgui.schema import KineticParameters +from pyparamgui.schema import Settings +from pyparamgui.schema import SimulationConfig +from pyparamgui.schema import SpectralCoordinates +from pyparamgui.schema import SpectralParameters +from pyparamgui.schema import TimeCoordinates +from pyparamgui.schema import generate_simulation_coordinates from pyparamgui.utils import generate_model_parameter_and_data_files + class Widget(anywidget.AnyWidget): - """ - A widget class for handling simulation parameters, coordinates and settings. + """A widget class for handling simulation parameters, coordinates and settings. Attributes: _esm (pathlib.Path): Path to the JavaScript file for the widget. @@ -38,6 +45,7 @@ class Widget(anywidget.AnyWidget): data_file_name_input (traitlets.Unicode): Name of the data file. simulate (traitlets.Unicode): Trigger for simulation. """ + _esm: pathlib.Path = pathlib.Path(__file__).parent / "static" / "form.js" _css: pathlib.Path = pathlib.Path(__file__).parent / "static" / "form.css" decay_rates_input: traitlets.List = traitlets.List(trait=traitlets.Float()).tag(sync=True) @@ -61,61 +69,60 @@ class Widget(anywidget.AnyWidget): data_file_name_input: traitlets.Unicode = traitlets.Unicode("").tag(sync=True) simulate: traitlets.Unicode = traitlets.Unicode("").tag(sync=True) + widget = Widget() + def _simulate(change) -> None: - """ - A Private callback function for simulating the data based on the parameters, coordinates, and other simulation settings. - + """A Private callback function for simulating the data based on the parameters, coordinates, + and other simulation settings. + This function generates the model, parameter, and data files using the provided widget inputs. The 'change' parameter is not used within this function, but it is required to be present - because it represents the state change of the traitlets. This is a common pattern when - using traitlets to observe changes in widget state. + because it represents the state change of the traitlets. This is a common pattern when using + traitlets to observe changes in widget state. """ simulation_config = SimulationConfig( - kinetic_parameters=KineticParameters( - decay_rates=widget.decay_rates_input - ), + kinetic_parameters=KineticParameters(decay_rates=widget.decay_rates_input), spectral_parameters=SpectralParameters( amplitude=widget.amplitude_input, location=widget.location_input, width=widget.width_input, - skewness=widget.skewness_input + skewness=widget.skewness_input, ), coordinates=generate_simulation_coordinates( TimeCoordinates( timepoints_max=widget.timepoints_max_input, - timepoints_stepsize=widget.timepoints_stepsize_input + timepoints_stepsize=widget.timepoints_stepsize_input, ), SpectralCoordinates( wavelength_min=widget.wavelength_min_input, wavelength_max=widget.wavelength_max_input, - wavelength_stepsize=widget.wavelength_stepsize_input - ) + wavelength_stepsize=widget.wavelength_stepsize_input, + ), ), settings=Settings( stdev_noise=widget.stdev_noise_input, seed=widget.seed_input, add_gaussian_irf=widget.add_gaussian_irf_input, - use_sequential_scheme=widget.use_sequential_scheme_input + use_sequential_scheme=widget.use_sequential_scheme_input, ), - irf=IRF( - center=widget.irf_location_input, - width=widget.irf_width_input - ) + irf=IRF(center=widget.irf_location_input, width=widget.irf_width_input), ) generate_model_parameter_and_data_files( simulation_config, model_file_name=widget.model_file_name_input, parameter_file_name=widget.parameter_file_name_input, - data_file_name=widget.data_file_name_input + data_file_name=widget.data_file_name_input, ) + def setup_widget_observer() -> None: + """Sets up the observer pattern on the 'simulate' traitlet to synchronize the frontend widget + with the backend simulation code. + + This function ensures that any changes in the widget's state trigger the simulation process, + which generates the model, parameter, and data files. """ - Sets up the observer pattern on the 'simulate' traitlet to synchronize the frontend widget - with the backend simulation code. This function ensures that any changes in the widget's state - trigger the simulation process, which generates the model, parameter, and data files. - """ - widget.observe(handler=_simulate, names=['simulate']) + widget.observe(handler=_simulate, names=["simulate"]) diff --git a/pyproject.toml b/pyproject.toml index efe0ef6..34635cb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,8 +35,8 @@ dynamic = [ ] dependencies = [ - "pydantic==2.8.2", "anywidget==0.9.13", + "pydantic==2.8.2", "pyglotaran==0.7.2", "pyyaml==6.0.1", ] diff --git a/requirements_pinned.txt b/requirements_pinned.txt index 14f376c..77b15c1 100644 --- a/requirements_pinned.txt +++ b/requirements_pinned.txt @@ -4,3 +4,4 @@ pydantic==2.8.2 anywidget==0.9.13 pyglotaran==0.7.2 +pyyaml==6.0.1 From 410f32767ae784ee18f313a446819e089ec32605 Mon Sep 17 00:00:00 2001 From: anmolbhatia05 Date: Fri, 26 Jul 2024 19:55:18 +0200 Subject: [PATCH 04/10] ruff automatic fixes --- pyparamgui/schema.py | 27 ++++++++++++++++++--------- pyparamgui/utils.py | 17 +++++++++++------ pyparamgui/widget.py | 3 ++- 3 files changed, 31 insertions(+), 16 deletions(-) diff --git a/pyparamgui/schema.py b/pyparamgui/schema.py index 48ac321..0d3f71c 100644 --- a/pyparamgui/schema.py +++ b/pyparamgui/schema.py @@ -1,5 +1,6 @@ """This module has the different model classes representing different parameters, coordinates, and -settings for simulation.""" +settings for simulation. +""" from __future__ import annotations @@ -13,7 +14,8 @@ class KineticParameters(BaseModel): """Kinetic parameters for the simulation. - Attributes: + Attributes + ---------- decay_rates (list[float]): List of decay rates. """ @@ -23,7 +25,8 @@ class KineticParameters(BaseModel): class SpectralParameters(BaseModel): """Spectral parameters for the simulation. - Attributes: + Attributes + ---------- amplitude (list[float]): List of amplitudes. location (list[float]): List of locations. width (list[float]): List of widths. @@ -39,7 +42,8 @@ class SpectralParameters(BaseModel): class TimeCoordinates(BaseModel): """Time coordinates for the simulation. - Attributes: + Attributes + ---------- timepoints_max (int): Maximum number of time points. timepoints_stepsize (float): Step size between time points. """ @@ -51,7 +55,8 @@ class TimeCoordinates(BaseModel): class SpectralCoordinates(BaseModel): """Spectral coordinates for the simulation. - Attributes: + Attributes + ---------- wavelength_min (int): Minimum wavelength. wavelength_max (int): Maximum wavelength. wavelength_stepsize (float): Step size between wavelengths. @@ -71,7 +76,8 @@ def generate_simulation_coordinates( time_coordinates (TimeCoordinates): The time coordinates for the simulation. spectral_coordinates (SpectralCoordinates): The spectral coordinates for the simulation. - Returns: + Returns + ------- Dict[str, np.ndarray]: A dictionary containing the time and spectral axes as numpy arrays. """ time_axis = np.arange( @@ -90,7 +96,8 @@ def generate_simulation_coordinates( class Settings(BaseModel): """Other settings for the simulation. - Attributes: + Attributes + ---------- stdev_noise (float): Standard deviation of the noise to be added to the simulation data. seed (int): Seed for the random number generator to ensure reproducibility. add_gaussian_irf (bool): Flag to indicate whether to add a Gaussian Instrument Response Function (IRF) to the simulation. Default is False. @@ -106,7 +113,8 @@ class Settings(BaseModel): class IRF(BaseModel): """Instrument Response Function (IRF) settings for the simulation. - Attributes: + Attributes + ---------- center (float): The center position of the IRF. width (float): The width of the IRF. """ @@ -118,7 +126,8 @@ class IRF(BaseModel): class SimulationConfig(BaseModel): """Configuration for the simulation, combining various parameters and settings. - Attributes: + Attributes + ---------- kinetic_parameters (KineticParameters): Kinetic parameters for the simulation. spectral_parameters (SpectralParameters): Spectral parameters for the simulation. coordinates (Dict[str, np.ndarray]): Dictionary containing the time and spectral axes as numpy arrays. diff --git a/pyparamgui/utils.py b/pyparamgui/utils.py index a50eb14..e12c26a 100644 --- a/pyparamgui/utils.py +++ b/pyparamgui/utils.py @@ -1,5 +1,6 @@ """This module has various utility functions related to generating files, sanitizing yaml files, -etc.""" +etc. +""" from __future__ import annotations @@ -35,7 +36,8 @@ def _generate_model_file( nr_compartments (int): The number of compartments in the model. file_name (str): The name of the file to save the sanitized model. - Returns: + Returns + ------- Model: The generated model. """ generator_name = ( @@ -66,7 +68,8 @@ def _update_parameter_values(parameters: Parameters, simulation_config: Simulati parameters (Parameters): The parameters to be updated. simulation_config (SimulationConfig): The configuration containing the new values for the parameters. - Returns: + Returns + ------- Parameters: The updated parameters. """ for param in parameters.all(): @@ -110,7 +113,8 @@ def _generate_parameter_file( model (Model): The model for which parameters are to be generated. file_name (str): The name of the file to save the parameters. - Returns: + Returns + ------- Parameters: The updated and validated parameters. """ parameters = model.generate_parameters() @@ -187,7 +191,8 @@ def _sanitize_dict(d: Union[Dict[str, Any], Any]) -> Union[Dict[str, Any], Any]: Args: d (Union[Dict[str, Any], Any]): The dictionary to sanitize or any other value. - Returns: + Returns + ------- Union[Dict[str, Any], Any]: The sanitized dictionary or the original value if it is not a dictionary. """ if not isinstance(d, dict): @@ -203,7 +208,7 @@ def _sanitize_yaml_file(input_file: str, output_file: str) -> None: input_file (str): The path to the input YAML file. output_file (str): The path to the output sanitized YAML file. """ - with open(input_file, "r") as f: + with open(input_file) as f: data = yaml.safe_load(f) sanitized_data = _sanitize_dict(data) diff --git a/pyparamgui/widget.py b/pyparamgui/widget.py index 2838eda..852cfd4 100644 --- a/pyparamgui/widget.py +++ b/pyparamgui/widget.py @@ -21,7 +21,8 @@ class Widget(anywidget.AnyWidget): """A widget class for handling simulation parameters, coordinates and settings. - Attributes: + Attributes + ---------- _esm (pathlib.Path): Path to the JavaScript file for the widget. _css (pathlib.Path): Path to the CSS file for the widget. decay_rates_input (traitlets.List): List of decay rates as floats. From f6f62ff38dcd074bca3c4afe653520d57da091a1 Mon Sep 17 00:00:00 2001 From: anmolbhatia05 Date: Fri, 26 Jul 2024 19:59:16 +0200 Subject: [PATCH 05/10] accepted ruff unsafe fixes --- pyparamgui/generator.py | 15 ++++++++++++--- pyparamgui/schema.py | 6 ++---- pyparamgui/utils.py | 22 ++++++++++++---------- 3 files changed, 26 insertions(+), 17 deletions(-) diff --git a/pyparamgui/generator.py b/pyparamgui/generator.py index 5906897..e75c75a 100644 --- a/pyparamgui/generator.py +++ b/pyparamgui/generator.py @@ -2,7 +2,7 @@ from __future__ import annotations -from collections.abc import Callable +from typing import TYPE_CHECKING from typing import Any from typing import TypedDict from typing import cast @@ -13,6 +13,9 @@ from glotaran.builtin.megacomplexes.spectral import SpectralMegacomplex from glotaran.model import Model +if TYPE_CHECKING: + from collections.abc import Callable + def _generate_decay_model( *, nr_compartments: int, irf: bool, spectral: bool, decay_type: str @@ -221,10 +224,13 @@ def generate_model(*, generator_name: str, generator_arguments: GeneratorArgumen Raised when an unknown generator is specified. """ if generator_name not in generators: - raise ValueError( + msg = ( f"Unknown model generator '{generator_name}'. " f"Known generators are: {list(generators.keys())}" ) + raise ValueError( + msg + ) model = generators[generator_name](**generator_arguments) return Model.create_class_from_megacomplexes( [DecayParallelMegacomplex, DecaySequentialMegacomplex, SpectralMegacomplex] @@ -259,9 +265,12 @@ def generate_model_yml(*, generator_name: str, generator_arguments: GeneratorArg Raised when an unknown generator is specified. """ if generator_name not in generators: - raise ValueError( + msg = ( f"Unknown model generator '{generator_name}'. " f"Known generators are: {list(generators.keys())}" ) + raise ValueError( + msg + ) model = generators[generator_name](**generator_arguments) return cast(str, write_dict(model)) diff --git a/pyparamgui/schema.py b/pyparamgui/schema.py index 0d3f71c..f545868 100644 --- a/pyparamgui/schema.py +++ b/pyparamgui/schema.py @@ -4,8 +4,6 @@ from __future__ import annotations -from typing import Dict - import numpy as np from pydantic import BaseModel from pydantic import ConfigDict @@ -69,7 +67,7 @@ class SpectralCoordinates(BaseModel): def generate_simulation_coordinates( time_coordinates: TimeCoordinates, spectral_coordinates: SpectralCoordinates -) -> Dict[str, np.ndarray]: +) -> dict[str, np.ndarray]: """Generate simulation coordinates based on time and spectral coordinates. Args: @@ -137,7 +135,7 @@ class SimulationConfig(BaseModel): kinetic_parameters: KineticParameters spectral_parameters: SpectralParameters - coordinates: Dict[str, np.ndarray] + coordinates: dict[str, np.ndarray] settings: Settings irf: IRF diff --git a/pyparamgui/utils.py b/pyparamgui/utils.py index e12c26a..b4258df 100644 --- a/pyparamgui/utils.py +++ b/pyparamgui/utils.py @@ -5,22 +5,24 @@ from __future__ import annotations import os +from typing import TYPE_CHECKING from typing import Any -from typing import Dict -from typing import Union -import numpy as np import yaml from glotaran.builtin.io.yml.yml import save_model -from glotaran.model.model import Model -from glotaran.parameter.parameters import Parameters from glotaran.plugin_system.data_io_registration import save_dataset from glotaran.plugin_system.project_io_registration import save_parameters from glotaran.simulation.simulation import simulate from pyparamgui.generator import generate_model -from pyparamgui.schema import Settings -from pyparamgui.schema import SimulationConfig + +if TYPE_CHECKING: + import numpy as np + from glotaran.model.model import Model + from glotaran.parameter.parameters import Parameters + + from pyparamgui.schema import Settings + from pyparamgui.schema import SimulationConfig def _generate_model_file( @@ -127,7 +129,7 @@ def _generate_parameter_file( def _generate_data_file( model: Model, parameters: Parameters, - coordinates: Dict[str, np.ndarray], + coordinates: dict[str, np.ndarray], settings: Settings, file_name: str, ): @@ -143,7 +145,7 @@ def _generate_data_file( settings (Settings): The settings for the simulation. file_name (str): The name of the file to save the simulated data. """ - noise = False if settings.stdev_noise == 0 else True + noise = settings.stdev_noise != 0 data = simulate( model, "dataset_1", @@ -184,7 +186,7 @@ def generate_model_parameter_and_data_files( ) -def _sanitize_dict(d: Union[Dict[str, Any], Any]) -> Union[Dict[str, Any], Any]: +def _sanitize_dict(d: dict[str, Any] | Any) -> dict[str, Any] | Any: """Recursively sanitize a dictionary by removing keys with values that are None, empty lists, or empty dictionaries. From 64a525d2db9f444c22039e22a8b7aaa3de21b39d Mon Sep 17 00:00:00 2001 From: anmolbhatia05 Date: Fri, 26 Jul 2024 20:54:26 +0200 Subject: [PATCH 06/10] fixed some more ruff issues --- pyparamgui/__init__.py | 13 +++++++++++++ pyparamgui/generator.py | 8 ++------ 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/pyparamgui/__init__.py b/pyparamgui/__init__.py index 16d4f9f..05d0604 100644 --- a/pyparamgui/__init__.py +++ b/pyparamgui/__init__.py @@ -10,3 +10,16 @@ from pyparamgui.widget import widget __all__ = ["widget", "setup_widget_observer"] + +""" +Package Usage: + # considering that each command is run in a separate jupyter notebook cell + + %env ANYWIDGET_HMR=1 + + from pyparamgui import widget, setup_widget_observer + + widget + + setup_widget_observer() +""" diff --git a/pyparamgui/generator.py b/pyparamgui/generator.py index e75c75a..b7856b9 100644 --- a/pyparamgui/generator.py +++ b/pyparamgui/generator.py @@ -228,9 +228,7 @@ def generate_model(*, generator_name: str, generator_arguments: GeneratorArgumen f"Unknown model generator '{generator_name}'. " f"Known generators are: {list(generators.keys())}" ) - raise ValueError( - msg - ) + raise ValueError(msg) model = generators[generator_name](**generator_arguments) return Model.create_class_from_megacomplexes( [DecayParallelMegacomplex, DecaySequentialMegacomplex, SpectralMegacomplex] @@ -269,8 +267,6 @@ def generate_model_yml(*, generator_name: str, generator_arguments: GeneratorArg f"Unknown model generator '{generator_name}'. " f"Known generators are: {list(generators.keys())}" ) - raise ValueError( - msg - ) + raise ValueError(msg) model = generators[generator_name](**generator_arguments) return cast(str, write_dict(model)) From 769f17719eaf78f8eaf7d7df3a8cb73ce0d6f9e2 Mon Sep 17 00:00:00 2001 From: Joris Snellenburg Date: Sat, 27 Jul 2024 14:51:28 +0200 Subject: [PATCH 07/10] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20Refactored=20code=20?= =?UTF-8?q?for=20better=20testability=20and=20added=20tests=20=F0=9F=A7=AA?= =?UTF-8?q?=20(#7)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Refactored several functions to improve readability and maintainability - Replaced os module with pathlib for file operations - Improved docstrings for clarity and consistency and satify linters - Renamed 'widget' to '_widget' in widget.py to facilitate unit testing - Added a new test case for the _simulate function in widget.py - --- .gitignore | 3 + .pre-commit-config.yaml | 5 +- pyparamgui/__init__.py | 5 +- pyparamgui/generator.py | 44 +++++++------ pyparamgui/schema.py | 24 ++++--- pyparamgui/utils.py | 138 +++++++++++++++++++++++++--------------- pyparamgui/widget.py | 62 +++++++++--------- pyproject.toml | 4 +- tests/test_simulate.py | 93 +++++++++++++++++++++++++++ 9 files changed, 259 insertions(+), 119 deletions(-) create mode 100644 tests/test_simulate.py diff --git a/.gitignore b/.gitignore index 3a2eb3e..903bd80 100644 --- a/.gitignore +++ b/.gitignore @@ -104,3 +104,6 @@ ENV/ # IDE settings .vscode/ + +examples/ +!examples/*.ipynb diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7e09c42..d3def5e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -74,13 +74,13 @@ repos: alias: flake8-docs args: - "--select=DOC" - - "--extend-ignore=DOC502" + - "--extend-ignore=DOC502,DOC601,DOC603,DOC101,DOC103,DOC201" - "--color=always" - "--require-return-section-when-returning-nothing=False" - "--allow-init-docstring=True" - "--skip-checking-short-docstrings=False" name: "flake8 lint docstrings" - exclude: "^(docs/|tests?/)" + exclude: "^(docs/|tests?/|pyparamgui/generator.py)" additional_dependencies: [pydoclint==0.5.3] - repo: https://github.com/econchick/interrogate @@ -97,6 +97,7 @@ repos: types: [file] types_or: [python, pyi, markdown, rst, jupyter] args: [-L nnumber] + exclude: ^examples/ - repo: https://github.com/rhysd/actionlint rev: "v1.7.1" diff --git a/pyparamgui/__init__.py b/pyparamgui/__init__.py index 05d0604..0391a6d 100644 --- a/pyparamgui/__init__.py +++ b/pyparamgui/__init__.py @@ -6,11 +6,10 @@ __email__ = "anmolbhatia05@gmail.com" __version__ = "0.0.1" +from pyparamgui.widget import _widget from pyparamgui.widget import setup_widget_observer -from pyparamgui.widget import widget - -__all__ = ["widget", "setup_widget_observer"] +__all__ = ["_widget", "setup_widget_observer"] """ Package Usage: # considering that each command is run in a separate jupyter notebook cell diff --git a/pyparamgui/generator.py b/pyparamgui/generator.py index b7856b9..9d49b6b 100644 --- a/pyparamgui/generator.py +++ b/pyparamgui/generator.py @@ -31,16 +31,17 @@ def _generate_decay_model( spectral : bool Whether to add a spectral model. decay_type : str - The dype of the decay + The type of the decay Returns ------- - dict[str, Any] : + dict[str, Any] The generated model dictionary. """ compartments = [f"species_{i+1}" for i in range(nr_compartments)] rates = [f"rates.species_{i+1}" for i in range(nr_compartments)] - model = { + + model: dict[str, Any] = { "megacomplex": { f"megacomplex_{decay_type}_decay": { "type": f"decay-{decay_type}", @@ -51,11 +52,14 @@ def _generate_decay_model( "dataset": {"dataset_1": {"megacomplex": [f"megacomplex_{decay_type}_decay"]}}, } if spectral: - model["megacomplex"]["megacomplex_spectral"] = { # type:ignore[index] - "type": "spectral", - "shape": { - compartment: f"shape_species_{i+1}" for i, compartment in enumerate(compartments) - }, + model["megacomplex"] |= { + "megacomplex_spectral": { + "type": "spectral", + "shape": { + compartment: f"shape_species_{i+1}" + for i, compartment in enumerate(compartments) + }, + } } model["shape"] = { f"shape_species_{i+1}": { @@ -67,13 +71,13 @@ def _generate_decay_model( } for i in range(nr_compartments) } - model["dataset"]["dataset_1"]["global_megacomplex"] = [ # type:ignore[index] - "megacomplex_spectral" - ] - model["dataset"]["dataset_1"]["spectral_axis_inverted"] = True - model["dataset"]["dataset_1"]["spectral_axis_scale"] = 1e7 + model["dataset"]["dataset_1"] |= { + "global_megacomplex": ["megacomplex_spectral"], + "spectral_axis_inverted": True, + "spectral_axis_scale": 1e7, + } if irf: - model["dataset"]["dataset_1"]["irf"] = "gaussian_irf" # type:ignore[index] + model["dataset"]["dataset_1"] |= {"irf": "gaussian_irf"} model["irf"] = { "gaussian_irf": {"type": "gaussian", "center": "irf.center", "width": "irf.width"}, } @@ -94,7 +98,7 @@ def generate_parallel_decay_model( Returns ------- - dict[str, Any] : + dict[str, Any] The generated model dictionary. """ return _generate_decay_model( @@ -116,7 +120,7 @@ def generate_parallel_spectral_decay_model( Returns ------- - dict[str, Any] : + dict[str, Any] The generated model dictionary. """ return _generate_decay_model( @@ -124,7 +128,9 @@ def generate_parallel_spectral_decay_model( ) -def generate_sequential_decay_model(nr_compartments: int = 1, irf: bool = False) -> dict[str, Any]: +def generate_sequential_decay_model( + *, nr_compartments: int = 1, irf: bool = False +) -> dict[str, Any]: """Generate a sequential decay model dictionary. Parameters @@ -136,7 +142,7 @@ def generate_sequential_decay_model(nr_compartments: int = 1, irf: bool = False) Returns ------- - dict[str, Any] : + dict[str, Any] The generated model dictionary. """ return _generate_decay_model( @@ -158,7 +164,7 @@ def generate_sequential_spectral_decay_model( Returns ------- - dict[str, Any] : + dict[str, Any] The generated model dictionary. """ return _generate_decay_model( diff --git a/pyparamgui/schema.py b/pyparamgui/schema.py index f545868..260177c 100644 --- a/pyparamgui/schema.py +++ b/pyparamgui/schema.py @@ -1,5 +1,6 @@ -"""This module has the different model classes representing different parameters, coordinates, and -settings for simulation. +"""Schema module representing attributes used for simulation. + +e.g. parameters, coordinates, and settings used in simulation. """ from __future__ import annotations @@ -76,7 +77,10 @@ def generate_simulation_coordinates( Returns ------- - Dict[str, np.ndarray]: A dictionary containing the time and spectral axes as numpy arrays. + dict[str, np.ndarray] + A dictionary containing two keys: + - 'time': A numpy array representing the time axis. + - 'spectral': A numpy array representing the spectral axis. """ time_axis = np.arange( 0, @@ -98,8 +102,10 @@ class Settings(BaseModel): ---------- stdev_noise (float): Standard deviation of the noise to be added to the simulation data. seed (int): Seed for the random number generator to ensure reproducibility. - add_gaussian_irf (bool): Flag to indicate whether to add a Gaussian Instrument Response Function (IRF) to the simulation. Default is False. - use_sequential_scheme (bool): Flag to indicate whether to use a sequential scheme in the simulation. Default is False. + add_gaussian_irf (bool): Whether to add a Gaussian IRF to the simulation. + Default is False. + use_sequential_scheme (bool): Whether to use a sequential scheme in the simulation. + Default is False. """ stdev_noise: float @@ -128,9 +134,11 @@ class SimulationConfig(BaseModel): ---------- kinetic_parameters (KineticParameters): Kinetic parameters for the simulation. spectral_parameters (SpectralParameters): Spectral parameters for the simulation. - coordinates (Dict[str, np.ndarray]): Dictionary containing the time and spectral axes as numpy arrays. - settings (Settings): Other settings for the simulation, including noise standard deviation, random seed, and flags for adding Gaussian IRF and using a sequential scheme. - irf (IRF): Instrument Response Function (IRF) settings, including center position and width. + coordinates (Dict[str, np.ndarray]): Dictionary containing the time and spectral axes as + numpy arrays. + settings (Settings): Other settings for the simulation, including noise standard deviation, + random seed, and flags for adding Gaussian IRF and using a sequential scheme. + irf (IRF): Instrument Response Function (IRF) settings, e.g. center position and width. """ kinetic_parameters: KineticParameters diff --git a/pyparamgui/utils.py b/pyparamgui/utils.py index b4258df..6d92b3c 100644 --- a/pyparamgui/utils.py +++ b/pyparamgui/utils.py @@ -1,10 +1,8 @@ -"""This module has various utility functions related to generating files, sanitizing yaml files, -etc. -""" +"""Utility module for generating files, sanitizing yaml files, etc.""" from __future__ import annotations -import os +from pathlib import Path from typing import TYPE_CHECKING from typing import Any @@ -19,6 +17,7 @@ if TYPE_CHECKING: import numpy as np from glotaran.model.model import Model + from glotaran.parameter.parameter import Parameter from glotaran.parameter.parameters import Parameters from pyparamgui.schema import Settings @@ -30,8 +29,10 @@ def _generate_model_file( ) -> Model: """Generate and save a model file for the simulation. - This function generates a model based on the provided simulation configuration and number of compartments. - It saves the generated model to a temporary YAML file, sanitizes the file, and then saves it to the specified file name. + This function generates a model based on the provided simulation configuration + and number of compartments. + It saves the generated model to a temporary YAML file, sanitizes the file, + and then saves it to the specified file name. Args: simulation_config (SimulationConfig): The configuration for the simulation. @@ -40,7 +41,8 @@ def _generate_model_file( Returns ------- - Model: The generated model. + Model + The generated model object, which can be used for further processing or simulation. """ generator_name = ( "spectral_decay_sequential" @@ -59,55 +61,77 @@ def _generate_model_file( return model -def _update_parameter_values(parameters: Parameters, simulation_config: SimulationConfig): +def _update_parameter_values( + parameters: Parameters, simulation_config: SimulationConfig +) -> Parameters: """Update parameter values based on the simulation configuration. - This function iterates through all parameters and updates their values according to the - provided simulation configuration. It handles parameters related to spectral shapes, - kinetic rates, and IRF (Instrument Response Function). + This function iterates through all parameters and updates their values + based on the provided simulation configuration. It handles shape parameters, + rate parameters, and IRF (Instrument Response Function) parameters. - Args: - parameters (Parameters): The parameters to be updated. - simulation_config (SimulationConfig): The configuration containing the new values for the parameters. + Parameters + ---------- + parameters : Parameters + The set of parameters to be updated. + simulation_config : SimulationConfig + The configuration object containing the simulation settings and parameter values. Returns ------- - Parameters: The updated parameters. + `Parameters` + The updated set of parameters with new values based on the + simulation configuration. """ for param in parameters.all(): label = param.label if label.startswith("shapes.species_"): - parts = label.split(".") - species_index = int(parts[1].split("_")[1]) - 1 - attribute = parts[2] - - if attribute == "amplitude": - param.value = simulation_config.spectral_parameters.amplitude[species_index] - elif attribute == "location": - param.value = simulation_config.spectral_parameters.location[species_index] - elif attribute == "width": - param.value = simulation_config.spectral_parameters.width[species_index] - elif attribute == "skewness": - param.value = simulation_config.spectral_parameters.skewness[species_index] - + _update_shape_parameter(param, label, simulation_config) elif label.startswith("rates.species_"): - species_index = int(label.split("_")[1]) - 1 - param.value = simulation_config.kinetic_parameters.decay_rates[species_index] - + _update_rate_parameter(param, label, simulation_config) elif label.startswith("irf") and simulation_config.settings.add_gaussian_irf: - if "width" in label: - param.value = simulation_config.irf.width - if "center" in label: - param.value = simulation_config.irf.center + _update_irf_parameter(param, label, simulation_config) return parameters +def _update_shape_parameter(param: Parameter, label: str, simulation_config: SimulationConfig): + """Update shape parameters.""" + parts = label.split(".") + species_index = int(parts[1].split("_")[1]) - 1 + attribute = parts[2] + spectral_params = simulation_config.spectral_parameters + + if attribute == "amplitude": + param.value = spectral_params.amplitude[species_index] + elif attribute == "location": + param.value = spectral_params.location[species_index] + elif attribute == "width": + param.value = spectral_params.width[species_index] + elif attribute == "skewness": + param.value = spectral_params.skewness[species_index] + + +def _update_rate_parameter(param: Parameter, label: str, simulation_config: SimulationConfig): + """Update rate parameters.""" + species_index = int(label.split("_")[1]) - 1 + param.value = simulation_config.kinetic_parameters.decay_rates[species_index] + + +def _update_irf_parameter(param: Parameter, label: str, simulation_config: SimulationConfig): + """Update IRF parameters.""" + if "width" in label: + param.value = simulation_config.irf.width + elif "center" in label: + param.value = simulation_config.irf.center + + def _generate_parameter_file( simulation_config: SimulationConfig, model: Model, file_name: str ) -> Parameters: """Generate and save the parameter file for the simulation. - This function generates the parameters for the given model, updates them based on the simulation configuration, + This function generates the parameters for the given model, + updates them based on the simulation configuration, validates the updated parameters, and saves them to a file. Args: @@ -117,7 +141,8 @@ def _generate_parameter_file( Returns ------- - Parameters: The updated and validated parameters. + `Parameters` + The updated and validated parameters. """ parameters = model.generate_parameters() updated_parameters = _update_parameter_values(parameters, simulation_config) @@ -135,8 +160,8 @@ def _generate_data_file( ): """Generate and save the data file for the simulation. - This function simulates the data based on the given model, parameters, coordinates, and settings, - and saves the simulated data to a file. + This function simulates the data based on the given model, parameters, coordinates, + and settings, and saves the simulated data to a file. Args: model (Model): The model used for simulation. @@ -166,13 +191,17 @@ def generate_model_parameter_and_data_files( ): """Generate and save the model, parameter, and data files for the simulation. - This function generates the model file, parameter file, and data file based on the given simulation configuration. + This function generates the model file, parameter file, and data file based on the given + simulation configuration. Args: simulation_config (SimulationConfig): The configuration for the simulation. - model_file_name (str, optional): The name of the file to save the model. Defaults to "model.yml". - parameter_file_name (str, optional): The name of the file to save the parameters. Defaults to "parameters.csv". - data_file_name (str, optional): The name of the file to save the data. Defaults to "dataset.nc". + model_file_name (str, optional): The name of the file to save the model. + Defaults to "model.yml". + parameter_file_name (str, optional): The name of the file to save the parameters. + Defaults to "parameters.csv". + data_file_name (str, optional): The name of the file to save the data. + Defaults to "dataset.nc". """ nr_compartments = len(simulation_config.kinetic_parameters.decay_rates) model = _generate_model_file(simulation_config, nr_compartments, model_file_name) @@ -187,15 +216,20 @@ def generate_model_parameter_and_data_files( def _sanitize_dict(d: dict[str, Any] | Any) -> dict[str, Any] | Any: - """Recursively sanitize a dictionary by removing keys with values that are None, empty lists, + """Sanitize an input dictionary and produce a new sanitized dictionary. + + Recursively sanitize a dictionary by removing keys with values that are None, empty lists, or empty dictionaries. - Args: - d (Union[Dict[str, Any], Any]): The dictionary to sanitize or any other value. + Parameters + ---------- + d : dict[str, Any] | Any + The dictionary to sanitize or any other value. Returns ------- - Union[Dict[str, Any], Any]: The sanitized dictionary or the original value if it is not a dictionary. + dict[str, Any] | Any + The sanitized dict or the original value if input is not a dict. """ if not isinstance(d, dict): return d @@ -203,18 +237,20 @@ def _sanitize_dict(d: dict[str, Any] | Any) -> dict[str, Any] | Any: def _sanitize_yaml_file(input_file: str, output_file: str) -> None: - """Sanitize a YAML file by removing keys with values that are None, empty lists, or empty + """Sanitize an input YAML file and produce a new sanitized YAML file. + + Sanitize by removing keys with values that are None, empty lists, or empty dictionaries, and save the sanitized content to a new file. Args: input_file (str): The path to the input YAML file. output_file (str): The path to the output sanitized YAML file. """ - with open(input_file) as f: + with Path(input_file).open() as f: data = yaml.safe_load(f) sanitized_data = _sanitize_dict(data) - with open(output_file, "w") as f: + with Path(output_file).open("w") as f: yaml.safe_dump(sanitized_data, f) - os.remove(input_file) + Path(input_file).unlink() diff --git a/pyparamgui/widget.py b/pyparamgui/widget.py index 852cfd4..eb2834c 100644 --- a/pyparamgui/widget.py +++ b/pyparamgui/widget.py @@ -1,4 +1,4 @@ -"""This module contains the simulation widget.""" +"""Simulation widget module.""" from __future__ import annotations @@ -71,59 +71,55 @@ class Widget(anywidget.AnyWidget): simulate: traitlets.Unicode = traitlets.Unicode("").tag(sync=True) -widget = Widget() +_widget = Widget() -def _simulate(change) -> None: - """A Private callback function for simulating the data based on the parameters, coordinates, - and other simulation settings. +def _simulate(_) -> None: + """Generate simulation files based on (global) widget (`_widget`) inputs. - This function generates the model, parameter, and data files using the provided widget inputs. - - The 'change' parameter is not used within this function, but it is required to be present - because it represents the state change of the traitlets. This is a common pattern when using - traitlets to observe changes in widget state. + This private callback function creates model, parameter, and data files + using the current widget (`_widget`) state. The 'change' parameter is unused but + required for traitlet observation. """ simulation_config = SimulationConfig( - kinetic_parameters=KineticParameters(decay_rates=widget.decay_rates_input), + kinetic_parameters=KineticParameters(decay_rates=_widget.decay_rates_input), spectral_parameters=SpectralParameters( - amplitude=widget.amplitude_input, - location=widget.location_input, - width=widget.width_input, - skewness=widget.skewness_input, + amplitude=_widget.amplitude_input, + location=_widget.location_input, + width=_widget.width_input, + skewness=_widget.skewness_input, ), coordinates=generate_simulation_coordinates( TimeCoordinates( - timepoints_max=widget.timepoints_max_input, - timepoints_stepsize=widget.timepoints_stepsize_input, + timepoints_max=_widget.timepoints_max_input, + timepoints_stepsize=_widget.timepoints_stepsize_input, ), SpectralCoordinates( - wavelength_min=widget.wavelength_min_input, - wavelength_max=widget.wavelength_max_input, - wavelength_stepsize=widget.wavelength_stepsize_input, + wavelength_min=_widget.wavelength_min_input, + wavelength_max=_widget.wavelength_max_input, + wavelength_stepsize=_widget.wavelength_stepsize_input, ), ), settings=Settings( - stdev_noise=widget.stdev_noise_input, - seed=widget.seed_input, - add_gaussian_irf=widget.add_gaussian_irf_input, - use_sequential_scheme=widget.use_sequential_scheme_input, + stdev_noise=_widget.stdev_noise_input, + seed=_widget.seed_input, + add_gaussian_irf=_widget.add_gaussian_irf_input, + use_sequential_scheme=_widget.use_sequential_scheme_input, ), - irf=IRF(center=widget.irf_location_input, width=widget.irf_width_input), + irf=IRF(center=_widget.irf_location_input, width=_widget.irf_width_input), ) generate_model_parameter_and_data_files( simulation_config, - model_file_name=widget.model_file_name_input, - parameter_file_name=widget.parameter_file_name_input, - data_file_name=widget.data_file_name_input, + model_file_name=_widget.model_file_name_input, + parameter_file_name=_widget.parameter_file_name_input, + data_file_name=_widget.data_file_name_input, ) def setup_widget_observer() -> None: - """Sets up the observer pattern on the 'simulate' traitlet to synchronize the frontend widget - with the backend simulation code. + """Set up an observer to trigger simulation when the widget state changes. - This function ensures that any changes in the widget's state trigger the simulation process, - which generates the model, parameter, and data files. + This function sets up an observer on the 'simulate' traitlet. When triggered, it runs the + simulation process, generating the necessary model, parameter, and data files. """ - widget.observe(handler=_simulate, names=["simulate"]) + _widget.observe(handler=_simulate, names=["simulate"]) diff --git a/pyproject.toml b/pyproject.toml index 34635cb..1324986 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,15 +17,13 @@ license = { file = "LICENSE" } authors = [ { name = "Anmol Bhatia", email = "anmolbhatia05@gmail.com " }, ] -requires-python = ">=3.8" +requires-python = ">=3.10" classifiers = [ "Development Status :: 2 - Pre-Alpha", "Intended Audience :: Developers", "License :: OSI Approved :: Apache Software License", "Natural Language :: English", "Programming Language :: Python :: 3 :: Only", - "Programming Language :: Python :: 3.8", - "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", diff --git a/tests/test_simulate.py b/tests/test_simulate.py new file mode 100644 index 0000000..ae5f4ee --- /dev/null +++ b/tests/test_simulate.py @@ -0,0 +1,93 @@ +"""Test for widget _simulate functionality.""" + +from __future__ import annotations + +import tempfile +from contextlib import contextmanager +from pathlib import Path + +import pytest + +from pyparamgui.widget import Widget +from pyparamgui.widget import _simulate +from pyparamgui.widget import _widget as global_widget + + +def _create_mock_widget(): + widget = Widget() + widget.decay_rates_input = [0.1, 0.2] + widget.amplitude_input = [1.0, 2.0] + widget.location_input = [400, 500] + widget.width_input = [10, 20] + widget.skewness_input = [0, 0] + widget.timepoints_max_input = 100 + widget.timepoints_stepsize_input = 0.1 + widget.wavelength_min_input = 400 + widget.wavelength_max_input = 600 + widget.wavelength_stepsize_input = 1 + widget.stdev_noise_input = 0.01 + widget.seed_input = 42 + widget.add_gaussian_irf_input = True + widget.irf_location_input = 0 + widget.irf_width_input = 0.1 + widget.use_sequential_scheme_input = False + widget.model_file_name_input = "model.yml" + widget.parameter_file_name_input = "parameters.csv" + widget.data_file_name_input = "dataset.nc" + return widget + + +@pytest.fixture() +def mock_widget(): + """Return a mock Widget for testing.""" + return _create_mock_widget() + + +@contextmanager +def use_mock_widget(mock_widget): + """Mock the global Widget instance (`_widget`) during the test.""" + original_widget = global_widget + import pyparamgui.widget + + pyparamgui.widget._widget = mock_widget + try: + yield + finally: + pyparamgui.widget._widget = original_widget + + +@pytest.fixture() +def temp_dir(): + """Create a temporary directory for testing.""" + with tempfile.TemporaryDirectory() as tmpdirname: + yield tmpdirname + + +def test_simulate(mock_widget, temp_dir): + """Test the _simulate function to ensure it generates non-empty files.""" + model_file_name_input_path = Path(temp_dir) / "model.yml" + parameter_file_name_input_path = Path(temp_dir) / "parameters.csv" + data_file_name_input_path = Path(temp_dir) / "dataset.nc" + + with use_mock_widget(mock_widget): + mock_widget.model_file_name_input = str(model_file_name_input_path) + mock_widget.parameter_file_name_input = str(parameter_file_name_input_path) + mock_widget.data_file_name_input = str(data_file_name_input_path) + _simulate(None) + + # Check if files exist and are not empty + # Ideally you would also want to check their content here + assert model_file_name_input_path.exists() + assert model_file_name_input_path.stat().st_size > 0 + + assert parameter_file_name_input_path.exists() + assert parameter_file_name_input_path.stat().st_size > 0 + + assert data_file_name_input_path.exists() + assert data_file_name_input_path.stat().st_size > 0 + + +if __name__ == "__main__": + my_mock_widget = _create_mock_widget() + with tempfile.TemporaryDirectory() as temp_dir: + test_simulate(my_mock_widget, temp_dir) From 6e3a0cac76ea3579d8af7c4408a1adc99d9034d2 Mon Sep 17 00:00:00 2001 From: anmolbhatia05 Date: Sat, 27 Jul 2024 18:28:14 +0200 Subject: [PATCH 08/10] PR improvements --- README.md | 6 +- pyparamgui/__init__.py | 13 +- pyparamgui/static/form.css | 5 + pyparamgui/static/form.js | 408 +++++++++++++++++++++---------------- pyparamgui/widget.py | 66 +++--- pyproject.toml | 1 + requirements_pinned.txt | 1 + tests/test_dummy.py | 7 - tests/test_simulate.py | 44 ++-- 9 files changed, 295 insertions(+), 256 deletions(-) delete mode 100644 tests/test_dummy.py diff --git a/README.md b/README.md index 6b2c7a8..17b5d46 100644 --- a/README.md +++ b/README.md @@ -14,11 +14,7 @@ [![Ruff](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/ruff/main/assets/badge/v2.json)](https://github.com/astral-sh/ruff) -pyglotaran notebook widgets for teaching parameter estimation examples - -## Features - -- TODO +A pyglotaran based jupyter notebook widget for teaching parameter estimation examples. It can simulate data, visualize it and create related model.yml, parameters.csv and dataset.nc files. It is supposed to help students learn about the basics of the pyglotaran ecosystem. ## Contributors ✨ diff --git a/pyparamgui/__init__.py b/pyparamgui/__init__.py index 0391a6d..e087c69 100644 --- a/pyparamgui/__init__.py +++ b/pyparamgui/__init__.py @@ -6,19 +6,14 @@ __email__ = "anmolbhatia05@gmail.com" __version__ = "0.0.1" -from pyparamgui.widget import _widget -from pyparamgui.widget import setup_widget_observer +from pyparamgui.widget import Widget -__all__ = ["_widget", "setup_widget_observer"] +__all__ = ["Widget"] """ Package Usage: - # considering that each command is run in a separate jupyter notebook cell - %env ANYWIDGET_HMR=1 + from pyparamgui import Widget - from pyparamgui import widget, setup_widget_observer - + widget = Widget() widget - - setup_widget_observer() """ diff --git a/pyparamgui/static/form.css b/pyparamgui/static/form.css index 4053a4c..2e884ce 100644 --- a/pyparamgui/static/form.css +++ b/pyparamgui/static/form.css @@ -24,3 +24,8 @@ hr { button { margin-top: 10px; } +message { + transition: opacity 0.5s ease; + opacity: 0; + color: black; +} \ No newline at end of file diff --git a/pyparamgui/static/form.js b/pyparamgui/static/form.js index 33a8292..19abc1f 100644 --- a/pyparamgui/static/form.js +++ b/pyparamgui/static/form.js @@ -1,63 +1,215 @@ -function render({ model, el }) { - const form = document.createElement("form"); +/** + * Creates a form group with a label and a text input field. + * + * @param {string} labelText - The text content for the label. + * @param {string} inputId - The id attribute for the input element. + * @param {string} inputName - The name attribute for the input element. + * @param {string} inputValue - The initial value for the input element. + * + * @returns {HTMLDivElement} The form group element containing the label and input. + */ +function createTextFormGroup(labelText, inputId, inputName, inputValue) { + const formGroup = document.createElement("div"); + formGroup.className = "form-group"; + + const label = document.createElement("label"); + label.setAttribute("for", inputId); + label.textContent = labelText; + + const input = document.createElement("input"); + input.setAttribute("type", "text"); + input.setAttribute("id", inputId); + input.setAttribute("name", inputName); + input.value = inputValue; + + formGroup.appendChild(label); + formGroup.appendChild(input); + + return formGroup; +} - /** - * Creates a form group with a label and a text input field. - * - * @param {string} labelText - The text content for the label. - * @param {string} inputId - The id attribute for the input element. - * @param {string} inputName - The name attribute for the input element. - * @param {string} inputValue - The initial value for the input element. - * - * @returns {HTMLDivElement} The form group element containing the label and input. - */ - function createTextFormGroup(labelText, inputId, inputName, inputValue) { - const formGroup = document.createElement("div"); - formGroup.className = "form-group"; +/** + * Creates a form group with a label and a checkbox input field. + * + * @param {string} labelText - The text content for the label. + * @param {string} inputId - The id attribute for the input element. + * @param {string} inputName - The name attribute for the input element. + * @param {boolean} [inputChecked=false] - The initial checked state for the checkbox. + * + * @returns {HTMLDivElement} The form group element containing the label and checkbox input. + */ +function createCheckboxFormGroup(labelText, inputId, inputName, inputChecked = false) { + const formGroup = document.createElement("div"); + formGroup.className = "form-group"; - const label = document.createElement("label"); - label.setAttribute("for", inputId); - label.textContent = labelText; + const label = document.createElement("label"); + label.setAttribute("for", inputId); + label.textContent = labelText; - const input = document.createElement("input"); - input.setAttribute("type", "text"); - input.setAttribute("id", inputId); - input.setAttribute("name", inputName); - input.value = inputValue; + const input = document.createElement("input"); + input.setAttribute("type", "checkbox"); + input.setAttribute("id", inputId); + input.setAttribute("name", inputName); + input.checked = inputChecked; - formGroup.appendChild(label); - formGroup.appendChild(input); + formGroup.appendChild(label); + formGroup.appendChild(input); - return formGroup; + return formGroup; +} + +/** + * Converts the input values from the form into their respective data types. + * + * @param {Object} inputs - The input values to convert. + * + * @returns {Object|null} An object containing the converted input values, or null if an error occurs. + * + * @property {number[]} decay_rates - Array of decay rates as floats. + * @property {number[]} amplitude - Array of amplitudes as floats. + * @property {number[]} location - Array of locations as floats. + * @property {number[]} width - Array of widths as floats. + * @property {number[]} skewness - Array of skewness values as floats. + * @property {number} timepoints_max - Maximum number of timepoints as an integer. + * @property {number} timepoints_stepsize - Step size for timepoints as a float. + * @property {number} wavelength_min - Minimum wavelength value as a float. + * @property {number} wavelength_max - Maximum wavelength value as a float. + * @property {number} wavelength_stepsize - Step size for wavelength as a float. + * @property {number} stdev_noise - Standard deviation of noise as a float. + * @property {number} seed - Seed for random number generation as an integer. + * @property {number} irf_location - Location of the IRF center as a float. + * @property {number} irf_width - Width of the IRF as a float. + */ +function convertInputs(inputs) { + try { + const decay_rates = inputs.decay_rates.split(",").map(parseFloat); + const amplitude = inputs.amplitude.split(",").map(parseFloat); + const location = inputs.location.split(",").map(parseFloat); + const width = inputs.width.split(",").map(parseFloat); + const skewness = inputs.skewness.split(",").map(parseFloat); + const timepoints_max = parseInt(inputs.timepoints_max, 10); + const timepoints_stepsize = parseFloat(inputs.timepoints_stepsize); + const wavelength_min = parseFloat(inputs.wavelength_min); + const wavelength_max = parseFloat(inputs.wavelength_max); + const wavelength_stepsize = parseFloat(inputs.wavelength_stepsize); + const stdev_noise = parseFloat(inputs.stdev_noise); + const seed = parseInt(inputs.seed, 10); + const irf_location = parseFloat(inputs.irf_location); + const irf_width = parseFloat(inputs.irf_width); + + return { + decay_rates, + amplitude, + location, + width, + skewness, + timepoints_max, + timepoints_stepsize, + wavelength_min, + wavelength_max, + wavelength_stepsize, + stdev_noise, + seed, + irf_location, + irf_width, + }; + } catch (error) { + alert("Error converting inputs: " + error.message); + return null; } +} - /** - * Creates a form group with a label and a checkbox input field. - * - * @param {string} labelText - The text content for the label. - * @param {string} inputId - The id attribute for the input element. - * @param {string} inputName - The name attribute for the input element. - * - * @returns {HTMLDivElement} The form group element containing the label and checkbox input. - */ - function createCheckboxFormGroup(labelText, inputId, inputName) { - const formGroup = document.createElement("div"); - formGroup.className = "form-group"; +/** + * Validates the input values for the simulation. + * + * @param {Object} inputs - The input values to validate. + * + * @param {number[]} inputs.decay_rates - Array of decay rates as floats. + * @param {number[]} inputs.amplitude - Array of amplitudes as floats. + * @param {number[]} inputs.location - Array of locations as floats. + * @param {number[]} inputs.width - Array of widths as floats. + * @param {number[]} inputs.skewness - Array of skewness values as floats. + * @param {number} inputs.wavelength_min - Minimum wavelength value as a float. + * @param {number} inputs.wavelength_max - Maximum wavelength value as a float. + * @param {number} inputs.timepoints_max - Maximum number of timepoints as an integer. + * + * @returns {boolean} True if all inputs are valid, otherwise false. + */ +function validateInputs(inputs) { + try { + const { decay_rates, amplitude, location, width, skewness } = inputs; - const label = document.createElement("label"); - label.setAttribute("for", inputId); - label.textContent = labelText; + if (decay_rates.some(isNaN)) { + alert("Invalid decay rates"); + return false; + } + if (amplitude.some(isNaN)) { + alert("Invalid amplitudes"); + return false; + } + if (location.some(isNaN)) { + alert("Invalid locations"); + return false; + } + if (width.some(isNaN)) { + alert("Invalid widths"); + return false; + } + if (skewness.some(isNaN)) { + alert("Invalid skewness values"); + return false; + } - const input = document.createElement("input"); - input.setAttribute("type", "checkbox"); - input.setAttribute("id", inputId); - input.setAttribute("name", inputName); + const lengths = [ + decay_rates.length, + amplitude.length, + location.length, + width.length, + skewness.length, + ]; + if (new Set(lengths).size !== 1) { + alert("All input lists must have the same length"); + return false; + } - formGroup.appendChild(label); - formGroup.appendChild(input); + if ( + inputs.wavelength_min >= inputs.wavelength_max || + inputs.timepoints_max <= 0 + ) { + alert("Invalid timepoints or wavelength specification"); + return false; + } - return formGroup; + return true; + } catch (error) { + alert("Validation error: " + error.message); + return false; } +} + +/** + * Displays a temporary message indicating simulation completion. + * @param {HTMLElement} parentElement - The parent element to which the message will be appended. + */ +function displaySimulationMessage(parentElement) { + const message = document.createElement("p"); + message.textContent = "Simulated! Files created!"; + parentElement.appendChild(message); + + setTimeout(() => { + setTimeout(() => { + parentElement.removeChild(message); + }, 500); + }, 2000); +}; + +/** + * Entrypoint for frontend rendering called by Python backend based on anywidget. + * @param {Object} model - The model data passed from the backend. + * @param {HTMLElement} el - The HTML element where the form will be rendered. + */ +function render({ model, el }) { + const form = document.createElement("form"); form.appendChild( createTextFormGroup( @@ -211,142 +363,40 @@ function render({ model, el }) { "dataset.nc", ), ); + form.appendChild(document.createElement("hr")); + form.appendChild( + createCheckboxFormGroup( + "Visualize Data:", + "visualize_data_input", + "visualize_data_input", + true, + ), + ); el.appendChild(form); - /** - * Converts the input values from the form into their respective data types. - * - * @returns {Object|null} An object containing the converted input values, or null if an error occurs. - * - * @property {number[]} decay_rates - Array of decay rates as floats. - * @property {number[]} amplitude - Array of amplitudes as floats. - * @property {number[]} location - Array of locations as floats. - * @property {number[]} width - Array of widths as floats. - * @property {number[]} skewness - Array of skewness values as floats. - * @property {number} timepoints_max - Maximum number of timepoints as an integer. - * @property {number} timepoints_stepsize - Step size for timepoints as a float. - * @property {number} wavelength_min - Minimum wavelength value as a float. - * @property {number} wavelength_max - Maximum wavelength value as a float. - * @property {number} wavelength_stepsize - Step size for wavelength as a float. - * @property {number} stdev_noise - Standard deviation of noise as a float. - * @property {number} seed - Seed for random number generation as an integer. - * @property {number} irf_location - Location of the IRF center as a float. - * @property {number} irf_width - Width of the IRF as a float. - */ - function convertInputs() { - try { - const decay_rates = decay_rates_input.value.split(",").map(parseFloat); - const amplitude = amplitude_input.value.split(",").map(parseFloat); - const location = location_input.value.split(",").map(parseFloat); - const width = width_input.value.split(",").map(parseFloat); - const skewness = skewness_input.value.split(",").map(parseFloat); - const timepoints_max = parseInt(timepoints_max_input.value, 10); - const timepoints_stepsize = parseFloat(timepoints_stepsize_input.value); - const wavelength_min = parseFloat(wavelength_min_input.value); - const wavelength_max = parseFloat(wavelength_max_input.value); - const wavelength_stepsize = parseFloat(wavelength_stepsize_input.value); - const stdev_noise = parseFloat(stdev_noise_input.value); - const seed = parseInt(seed_input.value, 10); - const irf_location = parseFloat(irf_location_input.value); - const irf_width = parseFloat(irf_width_input.value); - - return { - decay_rates, - amplitude, - location, - width, - skewness, - timepoints_max, - timepoints_stepsize, - wavelength_min, - wavelength_max, - wavelength_stepsize, - stdev_noise, - seed, - irf_location, - irf_width, - }; - } catch (error) { - alert("Error converting inputs: " + error.message); - return null; - } - } - - /** - * Validates the input values for the simulation. - * - * @param {Object} inputs - The input values to validate. - * - * @param {number[]} inputs.decay_rates - Array of decay rates as floats. - * @param {number[]} inputs.amplitude - Array of amplitudes as floats. - * @param {number[]} inputs.location - Array of locations as floats. - * @param {number[]} inputs.width - Array of widths as floats. - * @param {number[]} inputs.skewness - Array of skewness values as floats. - * @param {number} inputs.wavelength_min - Minimum wavelength value as a float. - * @param {number} inputs.wavelength_max - Maximum wavelength value as a float. - * @param {number} inputs.timepoints_max - Maximum number of timepoints as an integer. - * - * @returns {boolean} True if all inputs are valid, otherwise false. - */ - function validateInputs(inputs) { - try { - const { decay_rates, amplitude, location, width, skewness } = inputs; - - if (decay_rates.some(isNaN)) { - alert("Invalid decay rates"); - return false; - } - if (amplitude.some(isNaN)) { - alert("Invalid amplitudes"); - return false; - } - if (location.some(isNaN)) { - alert("Invalid locations"); - return false; - } - if (width.some(isNaN)) { - alert("Invalid widths"); - return false; - } - if (skewness.some(isNaN)) { - alert("Invalid skewness values"); - return false; - } - - const lengths = [ - decay_rates.length, - amplitude.length, - location.length, - width.length, - skewness.length, - ]; - if (new Set(lengths).size !== 1) { - alert("All input lists must have the same length"); - return false; - } - - if ( - inputs.wavelength_min >= inputs.wavelength_max || - inputs.timepoints_max <= 0 - ) { - alert("Invalid timepoints or wavelength specification"); - return false; - } - - return true; - } catch (error) { - alert("Validation error: " + error.message); - return false; - } - } - const btn = document.createElement("button"); btn.textContent = "Simulate"; btn.addEventListener("click", function (event) { event.preventDefault(); - const convertedInputs = convertInputs(); + const inputs = { + decay_rates: decay_rates_input.value, + amplitude: amplitude_input.value, + location: location_input.value, + width: width_input.value, + skewness: skewness_input.value, + timepoints_max: timepoints_max_input.value, + timepoints_stepsize: timepoints_stepsize_input.value, + wavelength_min: wavelength_min_input.value, + wavelength_max: wavelength_max_input.value, + wavelength_stepsize: wavelength_stepsize_input.value, + stdev_noise: stdev_noise_input.value, + seed: seed_input.value, + irf_location: irf_location_input.value, + irf_width: irf_width_input.value, + }; + const convertedInputs = convertInputs(inputs); if (!convertedInputs) return; const isValid = validateInputs(convertedInputs); @@ -374,11 +424,15 @@ function render({ model, el }) { model.set("model_file_name_input", model_file_name_input.value); model.set("parameter_file_name_input", parameter_file_name_input.value); model.set("data_file_name_input", data_file_name_input.value); + model.set("visualize_data", visualize_data_input.checked); model.set("simulate", self.crypto.randomUUID()); model.save_changes(); + + displaySimulationMessage(el); }); + el.appendChild(btn); } -export default { render }; +export default { render }; \ No newline at end of file diff --git a/pyparamgui/widget.py b/pyparamgui/widget.py index eb2834c..0e844ac 100644 --- a/pyparamgui/widget.py +++ b/pyparamgui/widget.py @@ -6,6 +6,8 @@ import anywidget import traitlets +from glotaran.io import load_dataset +from pyglotaran_extras import plot_data_overview from pyparamgui.schema import IRF from pyparamgui.schema import KineticParameters @@ -45,6 +47,7 @@ class Widget(anywidget.AnyWidget): parameter_file_name_input (traitlets.Unicode): Name of the parameter file. data_file_name_input (traitlets.Unicode): Name of the data file. simulate (traitlets.Unicode): Trigger for simulation. + visualize_data (traitlets.Bool): Flag to visualize data. """ _esm: pathlib.Path = pathlib.Path(__file__).parent / "static" / "form.js" @@ -69,57 +72,60 @@ class Widget(anywidget.AnyWidget): parameter_file_name_input: traitlets.Unicode = traitlets.Unicode("").tag(sync=True) data_file_name_input: traitlets.Unicode = traitlets.Unicode("").tag(sync=True) simulate: traitlets.Unicode = traitlets.Unicode("").tag(sync=True) + visualize_data: traitlets.Bool = traitlets.Bool(default_value=True).tag(sync=True) + def __init__(self): + super().__init__() + self.observe(handler=_simulate, names=["simulate"]) -_widget = Widget() - -def _simulate(_) -> None: +def _simulate(change) -> None: """Generate simulation files based on (global) widget (`_widget`) inputs. This private callback function creates model, parameter, and data files using the current widget (`_widget`) state. The 'change' parameter is unused but required for traitlet observation. """ + widget_instance = change["owner"] simulation_config = SimulationConfig( - kinetic_parameters=KineticParameters(decay_rates=_widget.decay_rates_input), + kinetic_parameters=KineticParameters(decay_rates=widget_instance.decay_rates_input), spectral_parameters=SpectralParameters( - amplitude=_widget.amplitude_input, - location=_widget.location_input, - width=_widget.width_input, - skewness=_widget.skewness_input, + amplitude=widget_instance.amplitude_input, + location=widget_instance.location_input, + width=widget_instance.width_input, + skewness=widget_instance.skewness_input, ), coordinates=generate_simulation_coordinates( TimeCoordinates( - timepoints_max=_widget.timepoints_max_input, - timepoints_stepsize=_widget.timepoints_stepsize_input, + timepoints_max=widget_instance.timepoints_max_input, + timepoints_stepsize=widget_instance.timepoints_stepsize_input, ), SpectralCoordinates( - wavelength_min=_widget.wavelength_min_input, - wavelength_max=_widget.wavelength_max_input, - wavelength_stepsize=_widget.wavelength_stepsize_input, + wavelength_min=widget_instance.wavelength_min_input, + wavelength_max=widget_instance.wavelength_max_input, + wavelength_stepsize=widget_instance.wavelength_stepsize_input, ), ), settings=Settings( - stdev_noise=_widget.stdev_noise_input, - seed=_widget.seed_input, - add_gaussian_irf=_widget.add_gaussian_irf_input, - use_sequential_scheme=_widget.use_sequential_scheme_input, + stdev_noise=widget_instance.stdev_noise_input, + seed=widget_instance.seed_input, + add_gaussian_irf=widget_instance.add_gaussian_irf_input, + use_sequential_scheme=widget_instance.use_sequential_scheme_input, ), - irf=IRF(center=_widget.irf_location_input, width=_widget.irf_width_input), + irf=IRF(center=widget_instance.irf_location_input, width=widget_instance.irf_width_input), ) generate_model_parameter_and_data_files( simulation_config, - model_file_name=_widget.model_file_name_input, - parameter_file_name=_widget.parameter_file_name_input, - data_file_name=_widget.data_file_name_input, + model_file_name=widget_instance.model_file_name_input, + parameter_file_name=widget_instance.parameter_file_name_input, + data_file_name=widget_instance.data_file_name_input, ) - - -def setup_widget_observer() -> None: - """Set up an observer to trigger simulation when the widget state changes. - - This function sets up an observer on the 'simulate' traitlet. When triggered, it runs the - simulation process, generating the necessary model, parameter, and data files. - """ - _widget.observe(handler=_simulate, names=["simulate"]) + if widget_instance.visualize_data: + irf_location = ( + None if not simulation_config.settings.add_gaussian_irf + else simulation_config.irf.center + ) + plot_data_overview( + dataset=load_dataset(widget_instance.data_file_name_input), + irf_location=irf_location + ) diff --git a/pyproject.toml b/pyproject.toml index 1324986..e222216 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,6 +36,7 @@ dependencies = [ "anywidget==0.9.13", "pydantic==2.8.2", "pyglotaran==0.7.2", + "pyglotaran_extras==0.7.2", "pyyaml==6.0.1", ] optional-dependencies.dev = [ diff --git a/requirements_pinned.txt b/requirements_pinned.txt index 77b15c1..573b71a 100644 --- a/requirements_pinned.txt +++ b/requirements_pinned.txt @@ -4,4 +4,5 @@ pydantic==2.8.2 anywidget==0.9.13 pyglotaran==0.7.2 +pyglotaran_extras==0.7.2 pyyaml==6.0.1 diff --git a/tests/test_dummy.py b/tests/test_dummy.py deleted file mode 100644 index ca70ca2..0000000 --- a/tests/test_dummy.py +++ /dev/null @@ -1,7 +0,0 @@ -"""Just a dummy test so pytest does not fail due to missing tests if no CLI is selected.""" - -from __future__ import annotations - - -def test_dummy(): - """Create and actual test.""" diff --git a/tests/test_simulate.py b/tests/test_simulate.py index ae5f4ee..fbb44c6 100644 --- a/tests/test_simulate.py +++ b/tests/test_simulate.py @@ -3,14 +3,12 @@ from __future__ import annotations import tempfile -from contextlib import contextmanager from pathlib import Path import pytest from pyparamgui.widget import Widget from pyparamgui.widget import _simulate -from pyparamgui.widget import _widget as global_widget def _create_mock_widget(): @@ -34,6 +32,7 @@ def _create_mock_widget(): widget.model_file_name_input = "model.yml" widget.parameter_file_name_input = "parameters.csv" widget.data_file_name_input = "dataset.nc" + widget.visualize_data = True return widget @@ -43,19 +42,6 @@ def mock_widget(): return _create_mock_widget() -@contextmanager -def use_mock_widget(mock_widget): - """Mock the global Widget instance (`_widget`) during the test.""" - original_widget = global_widget - import pyparamgui.widget - - pyparamgui.widget._widget = mock_widget - try: - yield - finally: - pyparamgui.widget._widget = original_widget - - @pytest.fixture() def temp_dir(): """Create a temporary directory for testing.""" @@ -69,22 +55,24 @@ def test_simulate(mock_widget, temp_dir): parameter_file_name_input_path = Path(temp_dir) / "parameters.csv" data_file_name_input_path = Path(temp_dir) / "dataset.nc" - with use_mock_widget(mock_widget): - mock_widget.model_file_name_input = str(model_file_name_input_path) - mock_widget.parameter_file_name_input = str(parameter_file_name_input_path) - mock_widget.data_file_name_input = str(data_file_name_input_path) - _simulate(None) + mock_widget.model_file_name_input = str(model_file_name_input_path) + mock_widget.parameter_file_name_input = str(parameter_file_name_input_path) + mock_widget.data_file_name_input = str(data_file_name_input_path) + + mock_change = {} + mock_change["owner"] = mock_widget + _simulate(mock_change) - # Check if files exist and are not empty - # Ideally you would also want to check their content here - assert model_file_name_input_path.exists() - assert model_file_name_input_path.stat().st_size > 0 + # Check if files exist and are not empty + # Ideally you would also want to check their content here + assert model_file_name_input_path.exists() + assert model_file_name_input_path.stat().st_size > 0 - assert parameter_file_name_input_path.exists() - assert parameter_file_name_input_path.stat().st_size > 0 + assert parameter_file_name_input_path.exists() + assert parameter_file_name_input_path.stat().st_size > 0 - assert data_file_name_input_path.exists() - assert data_file_name_input_path.stat().st_size > 0 + assert data_file_name_input_path.exists() + assert data_file_name_input_path.stat().st_size > 0 if __name__ == "__main__": From b596469aa74c205c40eec497b6777bdf625fd74d Mon Sep 17 00:00:00 2001 From: anmolbhatia05 Date: Sat, 27 Jul 2024 18:34:26 +0200 Subject: [PATCH 09/10] pre-commit code checks --- pyparamgui/static/form.css | 2 +- pyparamgui/static/form.js | 11 ++++++++--- pyparamgui/widget.py | 8 ++++---- pyproject.toml | 2 +- 4 files changed, 14 insertions(+), 9 deletions(-) diff --git a/pyparamgui/static/form.css b/pyparamgui/static/form.css index 2e884ce..1e832cd 100644 --- a/pyparamgui/static/form.css +++ b/pyparamgui/static/form.css @@ -28,4 +28,4 @@ message { transition: opacity 0.5s ease; opacity: 0; color: black; -} \ No newline at end of file +} diff --git a/pyparamgui/static/form.js b/pyparamgui/static/form.js index 19abc1f..5b237e6 100644 --- a/pyparamgui/static/form.js +++ b/pyparamgui/static/form.js @@ -38,7 +38,12 @@ function createTextFormGroup(labelText, inputId, inputName, inputValue) { * * @returns {HTMLDivElement} The form group element containing the label and checkbox input. */ -function createCheckboxFormGroup(labelText, inputId, inputName, inputChecked = false) { +function createCheckboxFormGroup( + labelText, + inputId, + inputName, + inputChecked = false, +) { const formGroup = document.createElement("div"); formGroup.className = "form-group"; @@ -201,7 +206,7 @@ function displaySimulationMessage(parentElement) { parentElement.removeChild(message); }, 500); }, 2000); -}; +} /** * Entrypoint for frontend rendering called by Python backend based on anywidget. @@ -435,4 +440,4 @@ function render({ model, el }) { el.appendChild(btn); } -export default { render }; \ No newline at end of file +export default { render }; diff --git a/pyparamgui/widget.py b/pyparamgui/widget.py index 0e844ac..5335e5e 100644 --- a/pyparamgui/widget.py +++ b/pyparamgui/widget.py @@ -79,7 +79,7 @@ def __init__(self): self.observe(handler=_simulate, names=["simulate"]) -def _simulate(change) -> None: +def _simulate(change: dict) -> None: """Generate simulation files based on (global) widget (`_widget`) inputs. This private callback function creates model, parameter, and data files @@ -122,10 +122,10 @@ def _simulate(change) -> None: ) if widget_instance.visualize_data: irf_location = ( - None if not simulation_config.settings.add_gaussian_irf + None + if not simulation_config.settings.add_gaussian_irf else simulation_config.irf.center ) plot_data_overview( - dataset=load_dataset(widget_instance.data_file_name_input), - irf_location=irf_location + dataset=load_dataset(widget_instance.data_file_name_input), irf_location=irf_location ) diff --git a/pyproject.toml b/pyproject.toml index e222216..98082cb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,7 +36,7 @@ dependencies = [ "anywidget==0.9.13", "pydantic==2.8.2", "pyglotaran==0.7.2", - "pyglotaran_extras==0.7.2", + "pyglotaran-extras==0.7.2", "pyyaml==6.0.1", ] optional-dependencies.dev = [ From 19f3053e81ea1165a4454d51154e85d28f72e035 Mon Sep 17 00:00:00 2001 From: anmolbhatia05 Date: Sat, 27 Jul 2024 18:39:26 +0200 Subject: [PATCH 10/10] pre-commit checks --- pyparamgui/widget.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/pyparamgui/widget.py b/pyparamgui/widget.py index 5335e5e..c11476b 100644 --- a/pyparamgui/widget.py +++ b/pyparamgui/widget.py @@ -75,6 +75,16 @@ class Widget(anywidget.AnyWidget): visualize_data: traitlets.Bool = traitlets.Bool(default_value=True).tag(sync=True) def __init__(self): + """Initialize the Widget instance and set up traitlet observers. + + This constructor initializes the Widget instance by calling the parent + class's initializer and sets up an observer for the 'simulate' traitlet. + The observer triggers the `_simulate` function whenever the 'simulate' + traitlet changes. + + Observers: + - simulate: Calls the `_simulate` function when the 'simulate' traitlet changes. + """ super().__init__() self.observe(handler=_simulate, names=["simulate"])