diff --git a/.DS_Store b/.DS_Store new file mode 100644 index 0000000..dfd4cf4 Binary files /dev/null and b/.DS_Store differ diff --git a/.cruft.json b/.cruft.json index 0f39e87..37dce94 100644 --- a/.cruft.json +++ b/.cruft.json @@ -5,8 +5,8 @@ "context": { "cookiecutter": { "full_name": "Anmol Bhatia", - "email": "a.bhatia2@student.vu.nl", - "github_username": "glotaran", + "email": "anmolbhatia05@gmail.com", + "github_username": "anmolbhatia05", "project_name": "PyParamGUI", "project_slug": "pyparamgui", "project_slug_url": "pyparamgui", diff --git a/.gitignore b/.gitignore index 3a2eb3e..903bd80 100644 --- a/.gitignore +++ b/.gitignore @@ -104,3 +104,6 @@ ENV/ # IDE settings .vscode/ + +examples/ +!examples/*.ipynb diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7e09c42..d3def5e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -74,13 +74,13 @@ repos: alias: flake8-docs args: - "--select=DOC" - - "--extend-ignore=DOC502" + - "--extend-ignore=DOC502,DOC601,DOC603,DOC101,DOC103,DOC201" - "--color=always" - "--require-return-section-when-returning-nothing=False" - "--allow-init-docstring=True" - "--skip-checking-short-docstrings=False" name: "flake8 lint docstrings" - exclude: "^(docs/|tests?/)" + exclude: "^(docs/|tests?/|pyparamgui/generator.py)" additional_dependencies: [pydoclint==0.5.3] - repo: https://github.com/econchick/interrogate @@ -97,6 +97,7 @@ repos: types: [file] types_or: [python, pyi, markdown, rst, jupyter] args: [-L nnumber] + exclude: ^examples/ - repo: https://github.com/rhysd/actionlint rev: "v1.7.1" diff --git a/pyparamgui/.DS_Store b/pyparamgui/.DS_Store new file mode 100644 index 0000000..eecab7a Binary files /dev/null and b/pyparamgui/.DS_Store differ diff --git a/pyparamgui/__init__.py b/pyparamgui/__init__.py index 70f39a4..0391a6d 100644 --- a/pyparamgui/__init__.py +++ b/pyparamgui/__init__.py @@ -3,5 +3,22 @@ from __future__ import annotations __author__ = """Anmol Bhatia""" -__email__ = "a.bhatia2@student.vu.nl" +__email__ = "anmolbhatia05@gmail.com" __version__ = "0.0.1" + +from pyparamgui.widget import _widget +from pyparamgui.widget import setup_widget_observer + +__all__ = ["_widget", "setup_widget_observer"] +""" +Package Usage: + # considering that each command is run in a separate jupyter notebook cell + + %env ANYWIDGET_HMR=1 + + from pyparamgui import widget, setup_widget_observer + + widget + + setup_widget_observer() +""" diff --git a/pyparamgui/__main__.py b/pyparamgui/__main__.py deleted file mode 100644 index dd0b80e..0000000 --- a/pyparamgui/__main__.py +++ /dev/null @@ -1 +0,0 @@ -"""Main module.""" diff --git a/pyparamgui/generator.py b/pyparamgui/generator.py new file mode 100644 index 0000000..9d49b6b --- /dev/null +++ b/pyparamgui/generator.py @@ -0,0 +1,278 @@ +"""The glotaran generator module.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING +from typing import Any +from typing import TypedDict +from typing import cast + +from glotaran.builtin.io.yml.utils import write_dict +from glotaran.builtin.megacomplexes.decay import DecayParallelMegacomplex +from glotaran.builtin.megacomplexes.decay import DecaySequentialMegacomplex +from glotaran.builtin.megacomplexes.spectral import SpectralMegacomplex +from glotaran.model import Model + +if TYPE_CHECKING: + from collections.abc import Callable + + +def _generate_decay_model( + *, nr_compartments: int, irf: bool, spectral: bool, decay_type: str +) -> dict[str, Any]: + """Generate a decay model dictionary. + + Parameters + ---------- + nr_compartments : int + The number of compartments. + irf : bool + Whether to add a gaussian irf. + spectral : bool + Whether to add a spectral model. + decay_type : str + The type of the decay + + Returns + ------- + dict[str, Any] + The generated model dictionary. + """ + compartments = [f"species_{i+1}" for i in range(nr_compartments)] + rates = [f"rates.species_{i+1}" for i in range(nr_compartments)] + + model: dict[str, Any] = { + "megacomplex": { + f"megacomplex_{decay_type}_decay": { + "type": f"decay-{decay_type}", + "compartments": compartments, + "rates": rates, + }, + }, + "dataset": {"dataset_1": {"megacomplex": [f"megacomplex_{decay_type}_decay"]}}, + } + if spectral: + model["megacomplex"] |= { + "megacomplex_spectral": { + "type": "spectral", + "shape": { + compartment: f"shape_species_{i+1}" + for i, compartment in enumerate(compartments) + }, + } + } + model["shape"] = { + f"shape_species_{i+1}": { + "type": "skewed-gaussian", + "amplitude": f"shapes.species_{i+1}.amplitude", + "location": f"shapes.species_{i+1}.location", + "width": f"shapes.species_{i+1}.width", + "skewness": f"shapes.species_{i+1}.skewness", + } + for i in range(nr_compartments) + } + model["dataset"]["dataset_1"] |= { + "global_megacomplex": ["megacomplex_spectral"], + "spectral_axis_inverted": True, + "spectral_axis_scale": 1e7, + } + if irf: + model["dataset"]["dataset_1"] |= {"irf": "gaussian_irf"} + model["irf"] = { + "gaussian_irf": {"type": "gaussian", "center": "irf.center", "width": "irf.width"}, + } + return model + + +def generate_parallel_decay_model( + *, nr_compartments: int = 1, irf: bool = False +) -> dict[str, Any]: + """Generate a parallel decay model dictionary. + + Parameters + ---------- + nr_compartments : int + The number of compartments. + irf : bool + Whether to add a gaussian irf. + + Returns + ------- + dict[str, Any] + The generated model dictionary. + """ + return _generate_decay_model( + nr_compartments=nr_compartments, irf=irf, spectral=False, decay_type="parallel" + ) + + +def generate_parallel_spectral_decay_model( + *, nr_compartments: int = 1, irf: bool = False +) -> dict[str, Any]: + """Generate a parallel spectral decay model dictionary. + + Parameters + ---------- + nr_compartments : int + The number of compartments. + irf : bool + Whether to add a gaussian irf. + + Returns + ------- + dict[str, Any] + The generated model dictionary. + """ + return _generate_decay_model( + nr_compartments=nr_compartments, irf=irf, spectral=True, decay_type="parallel" + ) + + +def generate_sequential_decay_model( + *, nr_compartments: int = 1, irf: bool = False +) -> dict[str, Any]: + """Generate a sequential decay model dictionary. + + Parameters + ---------- + nr_compartments : int + The number of compartments. + irf : bool + Whether to add a gaussian irf. + + Returns + ------- + dict[str, Any] + The generated model dictionary. + """ + return _generate_decay_model( + nr_compartments=nr_compartments, irf=irf, spectral=False, decay_type="sequential" + ) + + +def generate_sequential_spectral_decay_model( + *, nr_compartments: int = 1, irf: bool = False +) -> dict[str, Any]: + """Generate a sequential spectral decay model dictionary. + + Parameters + ---------- + nr_compartments : int + The number of compartments. + irf : bool + Whether to add a gaussian irf. + + Returns + ------- + dict[str, Any] + The generated model dictionary. + """ + return _generate_decay_model( + nr_compartments=nr_compartments, irf=irf, spectral=True, decay_type="sequential" + ) + + +generators: dict[str, Callable] = { + "decay_parallel": generate_parallel_decay_model, + "spectral_decay_parallel": generate_parallel_spectral_decay_model, + "decay_sequential": generate_sequential_decay_model, + "spectral_decay_sequential": generate_sequential_spectral_decay_model, +} + +available_generators: list[str] = list(generators.keys()) + + +class GeneratorArguments(TypedDict, total=False): + """Arguments used by ``generate_model`` and ``generate_model``. + + Parameters + ---------- + nr_compartments : int + The number of compartments. + irf : bool + Whether to add a gaussian irf. + + See Also + -------- + generate_model + generate_model_yml + """ + + nr_compartments: int + irf: bool + + +def generate_model(*, generator_name: str, generator_arguments: GeneratorArguments) -> Model: + """Generate a model. + + Parameters + ---------- + generator_name : str + The generator to use. + generator_arguments : GeneratorArguments + Arguments for the generator. + + Returns + ------- + Model + The generated model + + See Also + -------- + generate_parallel_decay_model + generate_parallel_spectral_decay_model + generate_sequential_decay_model + generate_sequential_spectral_decay_model + + Raises + ------ + ValueError + Raised when an unknown generator is specified. + """ + if generator_name not in generators: + msg = ( + f"Unknown model generator '{generator_name}'. " + f"Known generators are: {list(generators.keys())}" + ) + raise ValueError(msg) + model = generators[generator_name](**generator_arguments) + return Model.create_class_from_megacomplexes( + [DecayParallelMegacomplex, DecaySequentialMegacomplex, SpectralMegacomplex] + )(**model) + + +def generate_model_yml(*, generator_name: str, generator_arguments: GeneratorArguments) -> str: + """Generate a model as yml string. + + Parameters + ---------- + generator_name : str + The generator to use. + generator_arguments : GeneratorArguments + Arguments for the generator. + + Returns + ------- + str + The generated model yml string. + + See Also + -------- + generate_parallel_decay_model + generate_parallel_spectral_decay_model + generate_sequential_decay_model + generate_sequential_spectral_decay_model + + Raises + ------ + ValueError + Raised when an unknown generator is specified. + """ + if generator_name not in generators: + msg = ( + f"Unknown model generator '{generator_name}'. " + f"Known generators are: {list(generators.keys())}" + ) + raise ValueError(msg) + model = generators[generator_name](**generator_arguments) + return cast(str, write_dict(model)) diff --git a/pyparamgui/schema.py b/pyparamgui/schema.py new file mode 100644 index 0000000..260177c --- /dev/null +++ b/pyparamgui/schema.py @@ -0,0 +1,150 @@ +"""Schema module representing attributes used for simulation. + +e.g. parameters, coordinates, and settings used in simulation. +""" + +from __future__ import annotations + +import numpy as np +from pydantic import BaseModel +from pydantic import ConfigDict + + +class KineticParameters(BaseModel): + """Kinetic parameters for the simulation. + + Attributes + ---------- + decay_rates (list[float]): List of decay rates. + """ + + decay_rates: list[float] + + +class SpectralParameters(BaseModel): + """Spectral parameters for the simulation. + + Attributes + ---------- + amplitude (list[float]): List of amplitudes. + location (list[float]): List of locations. + width (list[float]): List of widths. + skewness (list[float]): List of skewness values. + """ + + amplitude: list[float] + location: list[float] + width: list[float] + skewness: list[float] + + +class TimeCoordinates(BaseModel): + """Time coordinates for the simulation. + + Attributes + ---------- + timepoints_max (int): Maximum number of time points. + timepoints_stepsize (float): Step size between time points. + """ + + timepoints_max: int + timepoints_stepsize: float + + +class SpectralCoordinates(BaseModel): + """Spectral coordinates for the simulation. + + Attributes + ---------- + wavelength_min (int): Minimum wavelength. + wavelength_max (int): Maximum wavelength. + wavelength_stepsize (float): Step size between wavelengths. + """ + + wavelength_min: int + wavelength_max: int + wavelength_stepsize: float + + +def generate_simulation_coordinates( + time_coordinates: TimeCoordinates, spectral_coordinates: SpectralCoordinates +) -> dict[str, np.ndarray]: + """Generate simulation coordinates based on time and spectral coordinates. + + Args: + time_coordinates (TimeCoordinates): The time coordinates for the simulation. + spectral_coordinates (SpectralCoordinates): The spectral coordinates for the simulation. + + Returns + ------- + dict[str, np.ndarray] + A dictionary containing two keys: + - 'time': A numpy array representing the time axis. + - 'spectral': A numpy array representing the spectral axis. + """ + time_axis = np.arange( + 0, + time_coordinates.timepoints_max * time_coordinates.timepoints_stepsize, + time_coordinates.timepoints_stepsize, + ) + spectral_axis = np.arange( + spectral_coordinates.wavelength_min, + spectral_coordinates.wavelength_max, + spectral_coordinates.wavelength_stepsize, + ) + return {"time": time_axis, "spectral": spectral_axis} + + +class Settings(BaseModel): + """Other settings for the simulation. + + Attributes + ---------- + stdev_noise (float): Standard deviation of the noise to be added to the simulation data. + seed (int): Seed for the random number generator to ensure reproducibility. + add_gaussian_irf (bool): Whether to add a Gaussian IRF to the simulation. + Default is False. + use_sequential_scheme (bool): Whether to use a sequential scheme in the simulation. + Default is False. + """ + + stdev_noise: float + seed: int + add_gaussian_irf: bool = False + use_sequential_scheme: bool = False + + +class IRF(BaseModel): + """Instrument Response Function (IRF) settings for the simulation. + + Attributes + ---------- + center (float): The center position of the IRF. + width (float): The width of the IRF. + """ + + center: float + width: float + + +class SimulationConfig(BaseModel): + """Configuration for the simulation, combining various parameters and settings. + + Attributes + ---------- + kinetic_parameters (KineticParameters): Kinetic parameters for the simulation. + spectral_parameters (SpectralParameters): Spectral parameters for the simulation. + coordinates (Dict[str, np.ndarray]): Dictionary containing the time and spectral axes as + numpy arrays. + settings (Settings): Other settings for the simulation, including noise standard deviation, + random seed, and flags for adding Gaussian IRF and using a sequential scheme. + irf (IRF): Instrument Response Function (IRF) settings, e.g. center position and width. + """ + + kinetic_parameters: KineticParameters + spectral_parameters: SpectralParameters + coordinates: dict[str, np.ndarray] + settings: Settings + irf: IRF + + model_config = ConfigDict(arbitrary_types_allowed=True) diff --git a/pyparamgui/static/form.css b/pyparamgui/static/form.css new file mode 100644 index 0000000..4053a4c --- /dev/null +++ b/pyparamgui/static/form.css @@ -0,0 +1,26 @@ +form { + display: flex; + flex-direction: column; +} +.form-group { + display: flex; + align-items: center; + margin-bottom: 10px; +} +label { + margin-right: 10px; + color: black; + width: 150px; +} +input { + flex: 1; + margin: 5px 0; +} +hr { + margin: 20px 0; + border: none; + border-top: 1px solid #ccc; +} +button { + margin-top: 10px; +} diff --git a/pyparamgui/static/form.js b/pyparamgui/static/form.js new file mode 100644 index 0000000..33a8292 --- /dev/null +++ b/pyparamgui/static/form.js @@ -0,0 +1,384 @@ +function render({ model, el }) { + const form = document.createElement("form"); + + /** + * Creates a form group with a label and a text input field. + * + * @param {string} labelText - The text content for the label. + * @param {string} inputId - The id attribute for the input element. + * @param {string} inputName - The name attribute for the input element. + * @param {string} inputValue - The initial value for the input element. + * + * @returns {HTMLDivElement} The form group element containing the label and input. + */ + function createTextFormGroup(labelText, inputId, inputName, inputValue) { + const formGroup = document.createElement("div"); + formGroup.className = "form-group"; + + const label = document.createElement("label"); + label.setAttribute("for", inputId); + label.textContent = labelText; + + const input = document.createElement("input"); + input.setAttribute("type", "text"); + input.setAttribute("id", inputId); + input.setAttribute("name", inputName); + input.value = inputValue; + + formGroup.appendChild(label); + formGroup.appendChild(input); + + return formGroup; + } + + /** + * Creates a form group with a label and a checkbox input field. + * + * @param {string} labelText - The text content for the label. + * @param {string} inputId - The id attribute for the input element. + * @param {string} inputName - The name attribute for the input element. + * + * @returns {HTMLDivElement} The form group element containing the label and checkbox input. + */ + function createCheckboxFormGroup(labelText, inputId, inputName) { + const formGroup = document.createElement("div"); + formGroup.className = "form-group"; + + const label = document.createElement("label"); + label.setAttribute("for", inputId); + label.textContent = labelText; + + const input = document.createElement("input"); + input.setAttribute("type", "checkbox"); + input.setAttribute("id", inputId); + input.setAttribute("name", inputName); + + formGroup.appendChild(label); + formGroup.appendChild(input); + + return formGroup; + } + + form.appendChild( + createTextFormGroup( + "Decay rates:", + "decay_rates_input", + "decay_rates_input", + "0.055, 0.005", + ), + ); + form.appendChild(document.createElement("hr")); + form.appendChild( + createTextFormGroup( + "Amplitudes:", + "amplitude_input", + "amplitude_input", + "1., 1.", + ), + ); + form.appendChild( + createTextFormGroup( + "Location (mean) of spectra:", + "location_input", + "location_input", + "22000, 20000", + ), + ); + form.appendChild( + createTextFormGroup( + "Width of spectra:", + "width_input", + "width_input", + "4000, 3500", + ), + ); + form.appendChild( + createTextFormGroup( + "Skewness of spectra:", + "skewness_input", + "skewness_input", + "0.1, -0.1", + ), + ); + form.appendChild(document.createElement("hr")); + form.appendChild( + createTextFormGroup( + "Timepoints, max:", + "timepoints_max_input", + "timepoints_max_input", + "80", + ), + ); + form.appendChild( + createTextFormGroup( + "Stepsize:", + "timepoints_stepsize_input", + "timepoints_stepsize_input", + "1", + ), + ); + form.appendChild(document.createElement("hr")); + form.appendChild( + createTextFormGroup( + "Wavelength Min:", + "wavelength_min_input", + "wavelength_min_input", + "400", + ), + ); + form.appendChild( + createTextFormGroup( + "Wavelength Max:", + "wavelength_max_input", + "wavelength_max_input", + "600", + ), + ); + form.appendChild( + createTextFormGroup( + "Stepsize:", + "wavelength_stepsize_input", + "wavelength_stepsize_input", + "5", + ), + ); + form.appendChild(document.createElement("hr")); + form.appendChild( + createTextFormGroup( + "Std.dev. noise:", + "stdev_noise_input", + "stdev_noise_input", + "0.01", + ), + ); + form.appendChild( + createTextFormGroup("Seed:", "seed_input", "seed_input", "123"), + ); + form.appendChild(document.createElement("hr")); + form.appendChild( + createCheckboxFormGroup( + "Add Gaussian IRF:", + "add_gaussian_irf_input", + "add_gaussian_irf_input", + ), + ); + form.appendChild( + createTextFormGroup( + "IRF location:", + "irf_location_input", + "irf_location_input", + "3", + ), + ); + form.appendChild( + createTextFormGroup( + "IRF width:", + "irf_width_input", + "irf_width_input", + "1", + ), + ); + form.appendChild(document.createElement("hr")); + form.appendChild( + createCheckboxFormGroup( + "Use Sequential Scheme:", + "use_sequential_scheme_input", + "use_sequential_scheme_input", + ), + ); + form.appendChild(document.createElement("hr")); + form.appendChild( + createTextFormGroup( + "Model File Name:", + "model_file_name_input", + "model_file_name_input", + "model.yml", + ), + ); + form.appendChild( + createTextFormGroup( + "Parameter File Name:", + "parameter_file_name_input", + "parameter_file_name_input", + "parameters.csv", + ), + ); + form.appendChild( + createTextFormGroup( + "Data File Name:", + "data_file_name_input", + "data_file_name_input", + "dataset.nc", + ), + ); + + el.appendChild(form); + + /** + * Converts the input values from the form into their respective data types. + * + * @returns {Object|null} An object containing the converted input values, or null if an error occurs. + * + * @property {number[]} decay_rates - Array of decay rates as floats. + * @property {number[]} amplitude - Array of amplitudes as floats. + * @property {number[]} location - Array of locations as floats. + * @property {number[]} width - Array of widths as floats. + * @property {number[]} skewness - Array of skewness values as floats. + * @property {number} timepoints_max - Maximum number of timepoints as an integer. + * @property {number} timepoints_stepsize - Step size for timepoints as a float. + * @property {number} wavelength_min - Minimum wavelength value as a float. + * @property {number} wavelength_max - Maximum wavelength value as a float. + * @property {number} wavelength_stepsize - Step size for wavelength as a float. + * @property {number} stdev_noise - Standard deviation of noise as a float. + * @property {number} seed - Seed for random number generation as an integer. + * @property {number} irf_location - Location of the IRF center as a float. + * @property {number} irf_width - Width of the IRF as a float. + */ + function convertInputs() { + try { + const decay_rates = decay_rates_input.value.split(",").map(parseFloat); + const amplitude = amplitude_input.value.split(",").map(parseFloat); + const location = location_input.value.split(",").map(parseFloat); + const width = width_input.value.split(",").map(parseFloat); + const skewness = skewness_input.value.split(",").map(parseFloat); + const timepoints_max = parseInt(timepoints_max_input.value, 10); + const timepoints_stepsize = parseFloat(timepoints_stepsize_input.value); + const wavelength_min = parseFloat(wavelength_min_input.value); + const wavelength_max = parseFloat(wavelength_max_input.value); + const wavelength_stepsize = parseFloat(wavelength_stepsize_input.value); + const stdev_noise = parseFloat(stdev_noise_input.value); + const seed = parseInt(seed_input.value, 10); + const irf_location = parseFloat(irf_location_input.value); + const irf_width = parseFloat(irf_width_input.value); + + return { + decay_rates, + amplitude, + location, + width, + skewness, + timepoints_max, + timepoints_stepsize, + wavelength_min, + wavelength_max, + wavelength_stepsize, + stdev_noise, + seed, + irf_location, + irf_width, + }; + } catch (error) { + alert("Error converting inputs: " + error.message); + return null; + } + } + + /** + * Validates the input values for the simulation. + * + * @param {Object} inputs - The input values to validate. + * + * @param {number[]} inputs.decay_rates - Array of decay rates as floats. + * @param {number[]} inputs.amplitude - Array of amplitudes as floats. + * @param {number[]} inputs.location - Array of locations as floats. + * @param {number[]} inputs.width - Array of widths as floats. + * @param {number[]} inputs.skewness - Array of skewness values as floats. + * @param {number} inputs.wavelength_min - Minimum wavelength value as a float. + * @param {number} inputs.wavelength_max - Maximum wavelength value as a float. + * @param {number} inputs.timepoints_max - Maximum number of timepoints as an integer. + * + * @returns {boolean} True if all inputs are valid, otherwise false. + */ + function validateInputs(inputs) { + try { + const { decay_rates, amplitude, location, width, skewness } = inputs; + + if (decay_rates.some(isNaN)) { + alert("Invalid decay rates"); + return false; + } + if (amplitude.some(isNaN)) { + alert("Invalid amplitudes"); + return false; + } + if (location.some(isNaN)) { + alert("Invalid locations"); + return false; + } + if (width.some(isNaN)) { + alert("Invalid widths"); + return false; + } + if (skewness.some(isNaN)) { + alert("Invalid skewness values"); + return false; + } + + const lengths = [ + decay_rates.length, + amplitude.length, + location.length, + width.length, + skewness.length, + ]; + if (new Set(lengths).size !== 1) { + alert("All input lists must have the same length"); + return false; + } + + if ( + inputs.wavelength_min >= inputs.wavelength_max || + inputs.timepoints_max <= 0 + ) { + alert("Invalid timepoints or wavelength specification"); + return false; + } + + return true; + } catch (error) { + alert("Validation error: " + error.message); + return false; + } + } + + const btn = document.createElement("button"); + btn.textContent = "Simulate"; + btn.addEventListener("click", function (event) { + event.preventDefault(); + + const convertedInputs = convertInputs(); + if (!convertedInputs) return; + + const isValid = validateInputs(convertedInputs); + if (!isValid) return; + + model.set("decay_rates_input", convertedInputs.decay_rates); + model.set("amplitude_input", convertedInputs.amplitude); + model.set("location_input", convertedInputs.location); + model.set("width_input", convertedInputs.width); + model.set("skewness_input", convertedInputs.skewness); + model.set("timepoints_max_input", convertedInputs.timepoints_max); + model.set("timepoints_stepsize_input", convertedInputs.timepoints_stepsize); + model.set("wavelength_min_input", convertedInputs.wavelength_min); + model.set("wavelength_max_input", convertedInputs.wavelength_max); + model.set("wavelength_stepsize_input", convertedInputs.wavelength_stepsize); + model.set("stdev_noise_input", convertedInputs.stdev_noise); + model.set("seed_input", convertedInputs.seed); + model.set("add_gaussian_irf_input", add_gaussian_irf_input.checked); + model.set("irf_location_input", convertedInputs.irf_location); + model.set("irf_width_input", convertedInputs.irf_width); + model.set( + "use_sequential_scheme_input", + use_sequential_scheme_input.checked, + ); + model.set("model_file_name_input", model_file_name_input.value); + model.set("parameter_file_name_input", parameter_file_name_input.value); + model.set("data_file_name_input", data_file_name_input.value); + model.set("simulate", self.crypto.randomUUID()); + + model.save_changes(); + }); + el.appendChild(btn); +} + +export default { render }; diff --git a/pyparamgui/utils.py b/pyparamgui/utils.py new file mode 100644 index 0000000..6d92b3c --- /dev/null +++ b/pyparamgui/utils.py @@ -0,0 +1,256 @@ +"""Utility module for generating files, sanitizing yaml files, etc.""" + +from __future__ import annotations + +from pathlib import Path +from typing import TYPE_CHECKING +from typing import Any + +import yaml +from glotaran.builtin.io.yml.yml import save_model +from glotaran.plugin_system.data_io_registration import save_dataset +from glotaran.plugin_system.project_io_registration import save_parameters +from glotaran.simulation.simulation import simulate + +from pyparamgui.generator import generate_model + +if TYPE_CHECKING: + import numpy as np + from glotaran.model.model import Model + from glotaran.parameter.parameter import Parameter + from glotaran.parameter.parameters import Parameters + + from pyparamgui.schema import Settings + from pyparamgui.schema import SimulationConfig + + +def _generate_model_file( + simulation_config: SimulationConfig, nr_compartments: int, file_name: str +) -> Model: + """Generate and save a model file for the simulation. + + This function generates a model based on the provided simulation configuration + and number of compartments. + It saves the generated model to a temporary YAML file, sanitizes the file, + and then saves it to the specified file name. + + Args: + simulation_config (SimulationConfig): The configuration for the simulation. + nr_compartments (int): The number of compartments in the model. + file_name (str): The name of the file to save the sanitized model. + + Returns + ------- + Model + The generated model object, which can be used for further processing or simulation. + """ + generator_name = ( + "spectral_decay_sequential" + if simulation_config.settings.use_sequential_scheme + else "spectral_decay_parallel" + ) + model = generate_model( + generator_name=generator_name, + generator_arguments={ + "nr_compartments": nr_compartments, + "irf": simulation_config.settings.add_gaussian_irf, + }, + ) + save_model(model, "temp_model.yml", allow_overwrite=True) + _sanitize_yaml_file("temp_model.yml", file_name) + return model + + +def _update_parameter_values( + parameters: Parameters, simulation_config: SimulationConfig +) -> Parameters: + """Update parameter values based on the simulation configuration. + + This function iterates through all parameters and updates their values + based on the provided simulation configuration. It handles shape parameters, + rate parameters, and IRF (Instrument Response Function) parameters. + + Parameters + ---------- + parameters : Parameters + The set of parameters to be updated. + simulation_config : SimulationConfig + The configuration object containing the simulation settings and parameter values. + + Returns + ------- + `Parameters` + The updated set of parameters with new values based on the + simulation configuration. + """ + for param in parameters.all(): + label = param.label + if label.startswith("shapes.species_"): + _update_shape_parameter(param, label, simulation_config) + elif label.startswith("rates.species_"): + _update_rate_parameter(param, label, simulation_config) + elif label.startswith("irf") and simulation_config.settings.add_gaussian_irf: + _update_irf_parameter(param, label, simulation_config) + return parameters + + +def _update_shape_parameter(param: Parameter, label: str, simulation_config: SimulationConfig): + """Update shape parameters.""" + parts = label.split(".") + species_index = int(parts[1].split("_")[1]) - 1 + attribute = parts[2] + spectral_params = simulation_config.spectral_parameters + + if attribute == "amplitude": + param.value = spectral_params.amplitude[species_index] + elif attribute == "location": + param.value = spectral_params.location[species_index] + elif attribute == "width": + param.value = spectral_params.width[species_index] + elif attribute == "skewness": + param.value = spectral_params.skewness[species_index] + + +def _update_rate_parameter(param: Parameter, label: str, simulation_config: SimulationConfig): + """Update rate parameters.""" + species_index = int(label.split("_")[1]) - 1 + param.value = simulation_config.kinetic_parameters.decay_rates[species_index] + + +def _update_irf_parameter(param: Parameter, label: str, simulation_config: SimulationConfig): + """Update IRF parameters.""" + if "width" in label: + param.value = simulation_config.irf.width + elif "center" in label: + param.value = simulation_config.irf.center + + +def _generate_parameter_file( + simulation_config: SimulationConfig, model: Model, file_name: str +) -> Parameters: + """Generate and save the parameter file for the simulation. + + This function generates the parameters for the given model, + updates them based on the simulation configuration, + validates the updated parameters, and saves them to a file. + + Args: + simulation_config (SimulationConfig): The configuration for the simulation. + model (Model): The model for which parameters are to be generated. + file_name (str): The name of the file to save the parameters. + + Returns + ------- + `Parameters` + The updated and validated parameters. + """ + parameters = model.generate_parameters() + updated_parameters = _update_parameter_values(parameters, simulation_config) + model.validate(updated_parameters) + save_parameters(updated_parameters, file_name, allow_overwrite=True) + return updated_parameters + + +def _generate_data_file( + model: Model, + parameters: Parameters, + coordinates: dict[str, np.ndarray], + settings: Settings, + file_name: str, +): + """Generate and save the data file for the simulation. + + This function simulates the data based on the given model, parameters, coordinates, + and settings, and saves the simulated data to a file. + + Args: + model (Model): The model used for simulation. + parameters (Parameters): The parameters used for simulation. + coordinates (Dict[str, np.ndarray]): The coordinates for the simulation. + settings (Settings): The settings for the simulation. + file_name (str): The name of the file to save the simulated data. + """ + noise = settings.stdev_noise != 0 + data = simulate( + model, + "dataset_1", + parameters, + coordinates, + noise=noise, + noise_std_dev=settings.stdev_noise, + noise_seed=settings.seed, + ) + save_dataset(data, file_name, "nc", allow_overwrite=True) + + +def generate_model_parameter_and_data_files( + simulation_config: SimulationConfig, + model_file_name: str = "model.yml", + parameter_file_name: str = "parameters.csv", + data_file_name: str = "dataset.nc", +): + """Generate and save the model, parameter, and data files for the simulation. + + This function generates the model file, parameter file, and data file based on the given + simulation configuration. + + Args: + simulation_config (SimulationConfig): The configuration for the simulation. + model_file_name (str, optional): The name of the file to save the model. + Defaults to "model.yml". + parameter_file_name (str, optional): The name of the file to save the parameters. + Defaults to "parameters.csv". + data_file_name (str, optional): The name of the file to save the data. + Defaults to "dataset.nc". + """ + nr_compartments = len(simulation_config.kinetic_parameters.decay_rates) + model = _generate_model_file(simulation_config, nr_compartments, model_file_name) + parameters = _generate_parameter_file(simulation_config, model, parameter_file_name) + _generate_data_file( + model, + parameters, + simulation_config.coordinates, + simulation_config.settings, + data_file_name, + ) + + +def _sanitize_dict(d: dict[str, Any] | Any) -> dict[str, Any] | Any: + """Sanitize an input dictionary and produce a new sanitized dictionary. + + Recursively sanitize a dictionary by removing keys with values that are None, empty lists, + or empty dictionaries. + + Parameters + ---------- + d : dict[str, Any] | Any + The dictionary to sanitize or any other value. + + Returns + ------- + dict[str, Any] | Any + The sanitized dict or the original value if input is not a dict. + """ + if not isinstance(d, dict): + return d + return {k: _sanitize_dict(v) for k, v in d.items() if v not in (None, [], {})} + + +def _sanitize_yaml_file(input_file: str, output_file: str) -> None: + """Sanitize an input YAML file and produce a new sanitized YAML file. + + Sanitize by removing keys with values that are None, empty lists, or empty + dictionaries, and save the sanitized content to a new file. + + Args: + input_file (str): The path to the input YAML file. + output_file (str): The path to the output sanitized YAML file. + """ + with Path(input_file).open() as f: + data = yaml.safe_load(f) + + sanitized_data = _sanitize_dict(data) + + with Path(output_file).open("w") as f: + yaml.safe_dump(sanitized_data, f) + Path(input_file).unlink() diff --git a/pyparamgui/widget.py b/pyparamgui/widget.py new file mode 100644 index 0000000..eb2834c --- /dev/null +++ b/pyparamgui/widget.py @@ -0,0 +1,125 @@ +"""Simulation widget module.""" + +from __future__ import annotations + +import pathlib + +import anywidget +import traitlets + +from pyparamgui.schema import IRF +from pyparamgui.schema import KineticParameters +from pyparamgui.schema import Settings +from pyparamgui.schema import SimulationConfig +from pyparamgui.schema import SpectralCoordinates +from pyparamgui.schema import SpectralParameters +from pyparamgui.schema import TimeCoordinates +from pyparamgui.schema import generate_simulation_coordinates +from pyparamgui.utils import generate_model_parameter_and_data_files + + +class Widget(anywidget.AnyWidget): + """A widget class for handling simulation parameters, coordinates and settings. + + Attributes + ---------- + _esm (pathlib.Path): Path to the JavaScript file for the widget. + _css (pathlib.Path): Path to the CSS file for the widget. + decay_rates_input (traitlets.List): List of decay rates as floats. + amplitude_input (traitlets.List): List of amplitudes as floats. + location_input (traitlets.List): List of locations as floats. + width_input (traitlets.List): List of widths as floats. + skewness_input (traitlets.List): List of skewness values as floats. + timepoints_max_input (traitlets.Int): Maximum number of timepoints. + timepoints_stepsize_input (traitlets.Float): Step size for timepoints. + wavelength_min_input (traitlets.Float): Minimum wavelength value. + wavelength_max_input (traitlets.Float): Maximum wavelength value. + wavelength_stepsize_input (traitlets.Float): Step size for wavelength. + stdev_noise_input (traitlets.Float): Standard deviation of noise. + seed_input (traitlets.Int): Seed for random number generation. + add_gaussian_irf_input (traitlets.Bool): Flag to add Gaussian IRF. + irf_location_input (traitlets.Float): Location of the IRF center. + irf_width_input (traitlets.Float): Width of the IRF. + use_sequential_scheme_input (traitlets.Bool): Flag to use sequential scheme. + model_file_name_input (traitlets.Unicode): Name of the model file. + parameter_file_name_input (traitlets.Unicode): Name of the parameter file. + data_file_name_input (traitlets.Unicode): Name of the data file. + simulate (traitlets.Unicode): Trigger for simulation. + """ + + _esm: pathlib.Path = pathlib.Path(__file__).parent / "static" / "form.js" + _css: pathlib.Path = pathlib.Path(__file__).parent / "static" / "form.css" + decay_rates_input: traitlets.List = traitlets.List(trait=traitlets.Float()).tag(sync=True) + amplitude_input: traitlets.List = traitlets.List(trait=traitlets.Float()).tag(sync=True) + location_input: traitlets.List = traitlets.List(trait=traitlets.Float()).tag(sync=True) + width_input: traitlets.List = traitlets.List(trait=traitlets.Float()).tag(sync=True) + skewness_input: traitlets.List = traitlets.List(trait=traitlets.Float()).tag(sync=True) + timepoints_max_input: traitlets.Int = traitlets.Int().tag(sync=True) + timepoints_stepsize_input: traitlets.Float = traitlets.Float().tag(sync=True) + wavelength_min_input: traitlets.Float = traitlets.Float().tag(sync=True) + wavelength_max_input: traitlets.Float = traitlets.Float().tag(sync=True) + wavelength_stepsize_input: traitlets.Float = traitlets.Float().tag(sync=True) + stdev_noise_input: traitlets.Float = traitlets.Float().tag(sync=True) + seed_input: traitlets.Int = traitlets.Int().tag(sync=True) + add_gaussian_irf_input: traitlets.Bool = traitlets.Bool().tag(sync=True) + irf_location_input: traitlets.Float = traitlets.Float().tag(sync=True) + irf_width_input: traitlets.Float = traitlets.Float().tag(sync=True) + use_sequential_scheme_input: traitlets.Bool = traitlets.Bool().tag(sync=True) + model_file_name_input: traitlets.Unicode = traitlets.Unicode("").tag(sync=True) + parameter_file_name_input: traitlets.Unicode = traitlets.Unicode("").tag(sync=True) + data_file_name_input: traitlets.Unicode = traitlets.Unicode("").tag(sync=True) + simulate: traitlets.Unicode = traitlets.Unicode("").tag(sync=True) + + +_widget = Widget() + + +def _simulate(_) -> None: + """Generate simulation files based on (global) widget (`_widget`) inputs. + + This private callback function creates model, parameter, and data files + using the current widget (`_widget`) state. The 'change' parameter is unused but + required for traitlet observation. + """ + simulation_config = SimulationConfig( + kinetic_parameters=KineticParameters(decay_rates=_widget.decay_rates_input), + spectral_parameters=SpectralParameters( + amplitude=_widget.amplitude_input, + location=_widget.location_input, + width=_widget.width_input, + skewness=_widget.skewness_input, + ), + coordinates=generate_simulation_coordinates( + TimeCoordinates( + timepoints_max=_widget.timepoints_max_input, + timepoints_stepsize=_widget.timepoints_stepsize_input, + ), + SpectralCoordinates( + wavelength_min=_widget.wavelength_min_input, + wavelength_max=_widget.wavelength_max_input, + wavelength_stepsize=_widget.wavelength_stepsize_input, + ), + ), + settings=Settings( + stdev_noise=_widget.stdev_noise_input, + seed=_widget.seed_input, + add_gaussian_irf=_widget.add_gaussian_irf_input, + use_sequential_scheme=_widget.use_sequential_scheme_input, + ), + irf=IRF(center=_widget.irf_location_input, width=_widget.irf_width_input), + ) + generate_model_parameter_and_data_files( + simulation_config, + model_file_name=_widget.model_file_name_input, + parameter_file_name=_widget.parameter_file_name_input, + data_file_name=_widget.data_file_name_input, + ) + + +def setup_widget_observer() -> None: + """Set up an observer to trigger simulation when the widget state changes. + + This function sets up an observer on the 'simulate' traitlet. When triggered, it runs the + simulation process, generating the necessary model, parameter, and data files. + """ + _widget.observe(handler=_simulate, names=["simulate"]) diff --git a/pyproject.toml b/pyproject.toml index fd05725..1324986 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,17 +15,15 @@ keywords = [ license = { file = "LICENSE" } authors = [ - { name = "Anmol Bhatia", email = "a.bhatia2@student.vu.nl" }, + { name = "Anmol Bhatia", email = "anmolbhatia05@gmail.com " }, ] -requires-python = ">=3.8" +requires-python = ">=3.10" classifiers = [ "Development Status :: 2 - Pre-Alpha", "Intended Audience :: Developers", "License :: OSI Approved :: Apache Software License", "Natural Language :: English", "Programming Language :: Python :: 3 :: Only", - "Programming Language :: Python :: 3.8", - "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", @@ -35,6 +33,10 @@ dynamic = [ ] dependencies = [ + "anywidget==0.9.13", + "pydantic==2.8.2", + "pyglotaran==0.7.2", + "pyyaml==6.0.1", ] optional-dependencies.dev = [ "pyparamgui[docs,test]", diff --git a/requirements_pinned.txt b/requirements_pinned.txt index 38ff788..77b15c1 100644 --- a/requirements_pinned.txt +++ b/requirements_pinned.txt @@ -1,2 +1,7 @@ # runtime requirements # pinned so the bot can create PRs to test with new versions + +pydantic==2.8.2 +anywidget==0.9.13 +pyglotaran==0.7.2 +pyyaml==6.0.1 diff --git a/tests/test_simulate.py b/tests/test_simulate.py new file mode 100644 index 0000000..ae5f4ee --- /dev/null +++ b/tests/test_simulate.py @@ -0,0 +1,93 @@ +"""Test for widget _simulate functionality.""" + +from __future__ import annotations + +import tempfile +from contextlib import contextmanager +from pathlib import Path + +import pytest + +from pyparamgui.widget import Widget +from pyparamgui.widget import _simulate +from pyparamgui.widget import _widget as global_widget + + +def _create_mock_widget(): + widget = Widget() + widget.decay_rates_input = [0.1, 0.2] + widget.amplitude_input = [1.0, 2.0] + widget.location_input = [400, 500] + widget.width_input = [10, 20] + widget.skewness_input = [0, 0] + widget.timepoints_max_input = 100 + widget.timepoints_stepsize_input = 0.1 + widget.wavelength_min_input = 400 + widget.wavelength_max_input = 600 + widget.wavelength_stepsize_input = 1 + widget.stdev_noise_input = 0.01 + widget.seed_input = 42 + widget.add_gaussian_irf_input = True + widget.irf_location_input = 0 + widget.irf_width_input = 0.1 + widget.use_sequential_scheme_input = False + widget.model_file_name_input = "model.yml" + widget.parameter_file_name_input = "parameters.csv" + widget.data_file_name_input = "dataset.nc" + return widget + + +@pytest.fixture() +def mock_widget(): + """Return a mock Widget for testing.""" + return _create_mock_widget() + + +@contextmanager +def use_mock_widget(mock_widget): + """Mock the global Widget instance (`_widget`) during the test.""" + original_widget = global_widget + import pyparamgui.widget + + pyparamgui.widget._widget = mock_widget + try: + yield + finally: + pyparamgui.widget._widget = original_widget + + +@pytest.fixture() +def temp_dir(): + """Create a temporary directory for testing.""" + with tempfile.TemporaryDirectory() as tmpdirname: + yield tmpdirname + + +def test_simulate(mock_widget, temp_dir): + """Test the _simulate function to ensure it generates non-empty files.""" + model_file_name_input_path = Path(temp_dir) / "model.yml" + parameter_file_name_input_path = Path(temp_dir) / "parameters.csv" + data_file_name_input_path = Path(temp_dir) / "dataset.nc" + + with use_mock_widget(mock_widget): + mock_widget.model_file_name_input = str(model_file_name_input_path) + mock_widget.parameter_file_name_input = str(parameter_file_name_input_path) + mock_widget.data_file_name_input = str(data_file_name_input_path) + _simulate(None) + + # Check if files exist and are not empty + # Ideally you would also want to check their content here + assert model_file_name_input_path.exists() + assert model_file_name_input_path.stat().st_size > 0 + + assert parameter_file_name_input_path.exists() + assert parameter_file_name_input_path.stat().st_size > 0 + + assert data_file_name_input_path.exists() + assert data_file_name_input_path.stat().st_size > 0 + + +if __name__ == "__main__": + my_mock_widget = _create_mock_widget() + with tempfile.TemporaryDirectory() as temp_dir: + test_simulate(my_mock_widget, temp_dir)