diff --git a/src/sisl_toolbox/cli/__init__.py b/src/sisl_toolbox/cli/__init__.py index bc84a60fd8..55680b8d75 100644 --- a/src/sisl_toolbox/cli/__init__.py +++ b/src/sisl_toolbox/cli/__init__.py @@ -7,73 +7,21 @@ accessible. """ +import typer +from ._typer_wrappers import annotate_typer -class SToolBoxCLI: - """ Run the CLI `stoolbox` """ +from sisl_toolbox.siesta.atom._atom import atom_plot +from sisl_toolbox.transiesta.poisson.fftpoisson_fix import fftpoisson_fix - def __init__(self): - self._cmds = [] +app = typer.Typer( + name="Sisl toolbox", + help="Specific toolboxes to aid sisl users", + rich_markup_mode="markdown", + add_completion=False +) - def register(self, setup): - """ Register a setup callback function which creates the subparser +app.command()(annotate_typer(atom_plot)) +app.command("ts-fft")(annotate_typer(fftpoisson_fix)) - The ``setup(..)`` command must accept a sub-parser from `argparse` as its - first argument. +stoolbox_cli = app - The only requirements to create a sub-command is to fullfill these requirements: - - 1. Create a new parser using ``subp.add_parser``. - 2. Ensure a runner is attached to the subparser through ``.set_defaults(runner=)`` - - A minimal example would be: - - >>> def setup(subp): - ... p = subp.add_parser("test-sub") - ... def test_sub_method(args): - ... print(args) - ... p.set_defaults(runner=test_sub_method) - """ - self._cmds.append(setup) - - def __call__(self, argv=None): - import argparse - import sys - from pathlib import Path - - # Create command-line - cmd = Path(sys.argv[0]) - p = argparse.ArgumentParser(f"{cmd.name}", - description="Specific toolboxes to aid sisl users") - - info = { - "title": "Toolboxes", - "metavar": "TOOL", - } - - # Check which Python version we have - version = sys.version_info - if version.major >= 3 and version.minor >= 7: - info["required"] = True - - # Create the sub-parser - subp = p.add_subparsers(**info) - - for cmd in self._cmds: - cmd(subp) - - args = p.parse_args(argv) - args.runner(args) - - -# Populate the commands - -# First create the class to hold and dynamically create the commands -stoolbox_cli = SToolBoxCLI() - -from sisl_toolbox.transiesta.poisson.fftpoisson_fix import fftpoisson_fix_cli - -stoolbox_cli.register(fftpoisson_fix_cli) - -from sisl_toolbox.siesta.atom._atom import atom_plot_cli - -stoolbox_cli.register(atom_plot_cli) diff --git a/src/sisl_toolbox/cli/_cli_arguments.py b/src/sisl_toolbox/cli/_cli_arguments.py new file mode 100644 index 0000000000..f0328b735e --- /dev/null +++ b/src/sisl_toolbox/cli/_cli_arguments.py @@ -0,0 +1,54 @@ +# Classes that hold information regarding how a given parameter should behave in a CLI +# They are meant to be used as metadata for the type annotations. That is, passing them +# to Annotated. E.g.: Annotated[int, CLIArgument(option="some_option")]. Even if they +# are empty, they indicate whether to treat the parameter as an argument or an option. +class CLIArgument: + def __init__(self, **kwargs): + self.kwargs = kwargs + +class CLIOption: + def __init__(self, *param_decls: str, **kwargs): + if len(param_decls) > 0: + kwargs["param_decls"] = param_decls + self.kwargs = kwargs + +def get_params_help(func) -> dict: + """Gets the text help of parameters from the docstring""" + params_help = {} + + in_parameters = False + read_key = None + arg_content = "" + + for line in func.__doc__.split("\n"): + if "Parameters" in line: + in_parameters = True + space = line.find("Parameters") + continue + + if in_parameters: + if len(line) < space + 1: + continue + if len(line) > 1 and line[0] != " ": + break + + if line[space] not in (" ", "-"): + if read_key is not None: + params_help[read_key] = arg_content + + read_key = line.split(":")[0].strip() + arg_content = "" + else: + if arg_content == "": + arg_content = line.strip() + arg_content = arg_content[0].upper() + arg_content[1:] + else: + arg_content += " " + line.strip() + + if line.startswith("------"): + break + + if read_key is not None: + params_help[read_key] = arg_content + + return params_help \ No newline at end of file diff --git a/src/sisl_toolbox/cli/_typer_wrappers.py b/src/sisl_toolbox/cli/_typer_wrappers.py new file mode 100644 index 0000000000..625d2376c3 --- /dev/null +++ b/src/sisl_toolbox/cli/_typer_wrappers.py @@ -0,0 +1,112 @@ +import typing +from typing_extensions import Annotated + +from enum import Enum + +import inspect +from copy import copy +import yaml + +import typer + +from ._cli_arguments import CLIArgument, CLIOption, get_params_help + +def get_dict_param_kwargs(dict_annotation_args): + + def yaml_dict(d: str): + + if isinstance(d, dict): + return d + + return yaml.safe_load(d) + + argument_kwargs = {"parser": yaml_dict} + + if len(dict_annotation_args) == 2: + try: + argument_kwargs["metavar"] = f"YAML_DICT[{dict_annotation_args[0].__name__}: {dict_annotation_args[1].__name__}]" + except: + argument_kwargs["metavar"] = f"YAML_DICT[{dict_annotation_args[0]}: {dict_annotation_args[1]}]" + + return argument_kwargs + +# This dictionary keeps the kwargs that should be passed to typer arguments/options +# for a given type. This is for example to be used for types that typer does not +# have built in support for. +_CUSTOM_TYPE_KWARGS = { + dict: get_dict_param_kwargs, +} + +def _get_custom_type_kwargs(type_): + + if hasattr(type_, "__metadata__"): + type_ = type_.__origin__ + + if typing.get_origin(type_) is not None: + args = typing.get_args(type_) + type_ = typing.get_origin(type_) + else: + args = () + + try: + argument_kwargs = _CUSTOM_TYPE_KWARGS.get(type_, {}) + if callable(argument_kwargs): + argument_kwargs = argument_kwargs(args) + except: + argument_kwargs = {} + + return argument_kwargs + + +def annotate_typer(func): + """Annotates a function for a typer app. + + It returns a new function, the original function is not modified. + """ + # Get the help message for all parameters found at the docstring + params_help = get_params_help(func) + + # Get the original signature of the function + sig = inspect.signature(func) + + # Loop over parameters in the signature, modifying them to include the + # typer info. + new_parameters = [] + for param in sig.parameters.values(): + + argument_kwargs = _get_custom_type_kwargs(param.annotation) + + default = param.default + if isinstance(param.default, Enum): + default = default.value + + typer_arg_cls = typer.Argument if param.default == inspect.Parameter.empty else typer.Option + if hasattr(param.annotation, "__metadata__"): + for meta in param.annotation.__metadata__: + if isinstance(meta, CLIArgument): + typer_arg_cls = typer.Argument + argument_kwargs.update(meta.kwargs) + elif isinstance(meta, CLIOption): + typer_arg_cls = typer.Option + argument_kwargs.update(meta.kwargs) + + if "param_decls" in argument_kwargs: + argument_args = argument_kwargs.pop("param_decls") + else: + argument_args = [] + + new_parameters.append( + param.replace( + default=default, + annotation=Annotated[param.annotation, typer_arg_cls(*argument_args, help=params_help.get(param.name), **argument_kwargs)] + ) + ) + + # Create a copy of the function and update it with the modified signature. + # Also remove parameters documentation from the docstring. + annotated_func = copy(func) + + annotated_func.__signature__ = sig.replace(parameters=new_parameters) + annotated_func.__doc__ = func.__doc__[:func.__doc__.find("Parameters\n")] + + return annotated_func \ No newline at end of file diff --git a/src/sisl_toolbox/siesta/atom/_atom.py b/src/sisl_toolbox/siesta/atom/_atom.py index 8f740d5f45..b380f5f540 100644 --- a/src/sisl_toolbox/siesta/atom/_atom.py +++ b/src/sisl_toolbox/siesta/atom/_atom.py @@ -22,8 +22,12 @@ which will show 4 plots for different sections. A command-line tool is also made available through the `stoolbox`. """ +from typing import Optional, List +from typing_extensions import Annotated + import sys from collections.abc import Iterable +from enum import Enum from functools import reduce from pathlib import Path @@ -33,7 +37,9 @@ import sisl as si from sisl.utils import NotNonePropertyDict, PropertyDict -__all__ = ["AtomInput", "atom_plot_cli"] +from sisl_toolbox.cli._cli_arguments import CLIArgument, CLIOption + +__all__ = ["AtomInput", "atom_plot"] _script = Path(sys.argv[0]).name @@ -737,63 +743,58 @@ def next_rc(ir, ic, nrows, ncols): return fig, axs -def atom_plot_cli(subp=None): - """ Run plotting command for the output of atom """ - - is_sub = not subp is None - - title = "Plotting facility for atom output (run in the atom output directory)" - if is_sub: - global _script - _script = f"{_script} atom-plot" - p = subp.add_parser("atom-plot", description=title, help=title) - else: - import argparse - p = argparse.ArgumentParser(title) - - p.add_argument("--plot", '-P', action='append', type=str, - choices=('wavefunction', 'charge', 'log', 'potential'), - help="""Determine what to plot""") - - p.add_argument("-l", default='spdf', type=str, - help="""Which l shells to plot""") - - p.add_argument("--save", "-S", default=None, - help="""Save output plots to file.""") - - p.add_argument("--show", default=False, action='store_true', - help="""Force showing the plot (only if --save is specified)""") +class AtomPlotOption(Enum): + """Plotting options for atom""" + wavefunction = 'wavefunction' + charge = 'charge' + log = 'log' + potential = 'potential' - p.add_argument("input", type=str, default="INP", - help="""Input file name (default INP)""") +def atom_plot( + plot: Annotated[Optional[List[AtomPlotOption]], CLIArgument()] = None, + input: Path = Path("INP"), + l: str = 'spdf', + save: Annotated[Optional[str], CLIOption("-S", "--save")] = None, + show: bool = False +): + """Plotting facility for atom output (run in the atom output directory) - if is_sub: - p.set_defaults(runner=atom_plot) - else: - atom_plot(p.parse_args()) - - -def atom_plot(args): + Parameters + ---------- + plot: + Determine what to plot. If None is given, it plots everything. + input: + Input file name. + l: + Which l shells to plot. + save: + Save output plots to file. + show: + Force showing the plot. + """ import matplotlib.pyplot as plt - input = Path(args.input) - atom = AtomInput.from_input(input) + input_path = Path(input) + atom = AtomInput.from_input(input_path) # If the specified input is a file, use the parent # Otherwise use the input *as is*. - if input.is_file(): - path = input.parent + if input_path.is_file(): + path = input_path.parent else: - path = input + path = input_path # if users have not specified what to plot, we plot everything - if args.plot is None: - args.plot = ('wavefunction', 'charge', 'log', 'potential') - fig = atom.plot(path, plot=args.plot, l=args.l, show=False)[0] + if plot is None: + plots = [p.value for p in AtomPlotOption] + else: + plots = [p.value if isinstance(p, AtomPlotOption) else p for p in plot ] - if args.save is None: + fig = atom.plot(path, plot=plots, l=l, show=False)[0] + + if save is None: plt.show() else: - fig.savefig(args.save) - if args.show: + fig.savefig(save) + if show: plt.show() diff --git a/src/sisl_toolbox/transiesta/poisson/fftpoisson_fix.py b/src/sisl_toolbox/transiesta/poisson/fftpoisson_fix.py index d280f73b04..da461f51d0 100644 --- a/src/sisl_toolbox/transiesta/poisson/fftpoisson_fix.py +++ b/src/sisl_toolbox/transiesta/poisson/fftpoisson_fix.py @@ -47,7 +47,11 @@ - It may not always converge which requires some fine-tuning of the tolerances, secondly it may converge too fast so the solution is not really good. """ +from typing import Tuple, List, Optional, Dict +from typing_extensions import Annotated + import argparse as argp +from enum import Enum import os import sys from pathlib import Path @@ -56,7 +60,9 @@ import sisl as si -__all__ = ['pyamg_solve', 'solve_poisson', 'fftpoisson_fix_cli', 'fftpoisson_fix_run'] +from sisl_toolbox.cli._cli_arguments import CLIOption + +__all__ = ['pyamg_solve', 'solve_poisson', 'fftpoisson_fix'] # Base-script name @@ -278,122 +284,162 @@ def sl2idx(grid, sl): return grid - -def fftpoisson_fix_cli(subp=None): - is_sub = not subp is None - - title = "FFT Poisson corrections for TranSiesta calculations for arbitrary number of electrodes." - if is_sub: - global _script - _script = f"{_script} ts-fft" - p = subp.add_parser("ts-fft", description=title, help=title) - else: - p = argp.ArgumentParser(title) - - tuning = p.add_argument_group("tuning", "Tuning fine details of the Poisson calculation.") - - p.add_argument("--geometry", "-G", default="siesta.TBT.nc", metavar="FILE", - help="siesta.TBT.nc file which contains the geometry and electrode information, currently we cannot read that from fdf-files.") - - p.add_argument("--shape", "-s", nargs=3, type=int, required=True, metavar=("A", "B", "C"), - help="Grid shape, this *has* to be conforming to the TranSiesta calculation, read from output: 'InitMesh: MESH = A x B x C'") - - n = {"a": "first", "b": "second", "c": "third"} - for d in "abc": - p.add_argument(f"--boundary-condition-{d}", f"-bc-{d}", nargs=2, type=str, default=["p", "p"], - metavar=("BOTTOM", "TOP"), - help=("Boundary condition along the {} lattice vector [periodic/p, neumann/n, dirichlet/d]. " - "Specify separate BC at the start and end of the lattice vector, respectively.".format(n[d]))) - - p.add_argument("--elec-V", "-V", action="append", nargs=2, metavar=("NAME", "V"), default=[], - help="Specify chemical potential on electrode") - - p.add_argument("--pyamg-shape", "-ps", nargs=3, type=int, metavar=("A", "B", "C"), default=None, - help="Grid used to solve the Poisson equation, if shape is different the Grid will be interpolated (order=2) after.") - - p.add_argument("--device", "-D", type=float, default=None, metavar="VAL", - help="Fix the value of all device atoms to a value. In some cases this turns out to yield a better box boundary. The default is to *not* fix the potential on the device atoms.") - - tuning.add_argument("--radius", "-R", type=float, default=3., metavar="R", - help=("Radius of atoms when figuring out the electrode sizes, this corresponds to the extend of " - "each electrode where boundary conditions are fixed. Should be tuned according to the atomic species [3 Ang]")) - - tuning.add_argument("--dtype", "-d", choices=["d", "f64", "f", "f32"], default="d", - help="Precision of data (d/f64==double, f/f32==single)") - - tuning.add_argument("--tolerance", "-T", type=float, default=1e-7, metavar="EPS", - help="Precision required for the pyamg solver. NOTE when using single precision arrays this should probably be on the order of 1e-5") - - tuning.add_argument("--acceleration", "-A", dest="accel", default="cg", metavar="METHOD", - help="""Acceleration method for pyamg. May be useful if it fails to converge - -Try one of: cg, gmres, fgmres, cr, cgnr, cgne, bicgstab, steepest_descent, minimal_residual""") - - test = p.add_argument_group("testing", "Options used for testing output. None of these options should be used for production runs!") - test.add_argument("--box", dest="box", action="store_true", default=False, - help="Only store the initial box solution (i.e. do not run PyAMG)") - - test.add_argument("--no-boundary-fft", action="store_false", dest="boundary_fft", default=True, - help="Once the electrode boundary conditions are solved we perform a second solution with boundaries fixed. Using this flag disables this second solution.") - - if _DEBUG: - test.add_argument("--plot", dest="plot", default=None, type=int, - help="Plot grid by averaging over the axis given as argument") - - test.add_argument("--plot-boundary", dest="plot_boundary", action="store_true", - help="Plot all 6 edges of the box with their fixed values (just before 2nd pyamg solve step)") - - p.add_argument("--out", "-o", action="append", default=None, - help="Output file to store the resulting Poisson solution. It *has* to have TSV.nc file ending to make the file conforming with TranSiesta.") - - if is_sub: - p.set_defaults(runner=fftpoisson_fix_run) - else: - fftpoisson_fix_run(p.parse_args()) - - -def fftpoisson_fix_run(args): - if args.out is None: +class DtypeOption(Enum): + """Data types""" + d = "d" + f = "f" + f64 = "f64" + f32 = "f32" + +class AccelMethod(Enum): + """Acceleration methods for pyamg""" + cg = "cg" + gmres = "gmres" + fgmres = "fgmres" + cr = "cr" + cgnr = "cgnr" + cgne = "cgne" + bicgstab = "bicgstab" + steepest_descent = "steepest_descent" + minimal_residual = "minimal_residual" + +def fftpoisson_fix( + shape: Annotated[Tuple[int, int, int], CLIOption("-S", "--shape")], + geometry: Path = Path("siesta.TBT.nc"), + boundary_condition_a: Annotated[Tuple[str, str], CLIOption("-bc-a", "--boundary-condition-a")] = ("p", "p"), + boundary_condition_b: Annotated[Tuple[str, str], CLIOption("-bc-b", "--boundary-condition-b")] = ("p", "p"), + boundary_condition_c: Annotated[Tuple[str, str], CLIOption("-bc-c", "--boundary-condition-c")] = ("p", "p"), + elec_V: Annotated[Dict[str, float], CLIOption("-V", "--elec-V", )] = {}, + pyamg_shape: Annotated[Tuple[int, int, int], CLIOption("-ps", "--pyamg-shape")] = (-1, -1, -1), + device: Annotated[Optional[float], CLIOption("-D", "--device")] = None, + radius: Annotated[float, CLIOption("-R", "--radius")] = 3., + dtype: Annotated[DtypeOption, CLIOption("-d", "--dtype")] = DtypeOption.d, + tolerance: Annotated[float, CLIOption("-T", "--tolerance")] = 1e-7, + accel: Annotated[AccelMethod, CLIOption("-A", "--acceleration")] = AccelMethod.cg, + out: Annotated[List[str], CLIOption("-o", "--out", metavar="PATH")] = [], + box: bool = False, + boundary_fft: bool = True, + plot: Optional[int] = None, + plot_boundary: bool = False, +): + """FFT Poisson corrections for TranSiesta calculations for arbitrary number of electrodes. + + Parameters + ---------- + geometry: + siesta.TBT.nc file which contains the geometry and electrode information, + currently we cannot read that from fdf-files. + shape: + Grid shape, this *has* to be conforming to the TranSiesta calculation, + read from output: 'InitMesh: MESH = A x B x C' + boundary_condition_a: + Boundary condition along the first lattice vector [periodic/p, neumann/n, dirichlet/d]. + Specify separate BC at the start and end of the lattice vector, respectively. + boundary_condition_b: + Boundary condition along the second lattice vector [periodic/p, neumann/n, dirichlet/d]. + Specify separate BC at the start and end of the lattice vector, respectively. + boundary_condition_c: + Boundary condition along the third lattice vector [periodic/p, neumann/n, dirichlet/d]. + Specify separate BC at the start and end of the lattice vector, respectively. + elec_V: + Specify chemical potential on electrode. + pyamg_shape: + Grid used to solve the Poisson equation, if shape is different + the Grid will be interpolated (order=2) after. + (-1, -1, -1) means use the same as shape. + device: + Fix the value of all device atoms to a value. + In some cases this turns out to yield a better box boundary. + The default is to *not* fix the potential on the device atoms. + radius: + Radius of atoms when figuring out the electrode sizes, + this corresponds to the extend of each electrode where boundary conditions are fixed. + Should be tuned according to the atomic species [3 Ang] + dtype: + Precision of data (d/f64==double, f/f32==single) + tolerance: + Precision required for the pyamg solver. + NOTE when using single precision arrays this should probably be on the order of 1e-5 + accel: + Acceleration method for pyamg. + May be useful if it fails to converge + out: + Output file to store the resulting Poisson solution. + It *has* to have TSV.nc file ending to make the file conforming with TranSiesta. + box: + Only store the initial box solution (i.e. do not run PyAMG) + boundary_fft: + Once the electrode boundary conditions are solved we perform a second solution with boundaries fixed. + Using this flag disables this second solution. + plot: + Plot grid by averaging over the axis given as argument + plot_boundary: + Plot all 6 edges of the box with their fixed values (just before 2nd pyamg solve step) + + """ + + print(dict( + geometry=geometry, + shape=shape, + boundary_condition_a=boundary_condition_a, + boundary_condition_b=boundary_condition_b, + boundary_condition_c=boundary_condition_c, + elec_V=elec_V, + pyamg_shape=pyamg_shape, + device=device, + radius=radius, + dtype=dtype, + tolerance=tolerance, + accel=accel, + out=out, + box=box, + boundary_fft=boundary_fft, + plot=plot, + plot_boundary=plot_boundary, + )) + + return + if len(out) == 0: print(f">\n>\n>{_script}: No out-files has been specified, work will be carried out but not saved!\n>\n>\n") # Read in geometry - geometry = si.get_sile(args.geometry).read_geometry() + geometry = si.get_sile(geometry).read_geometry() # Figure out the electrodes elecs_V = {} - if len(args.elec_V) == 0: + if len(elec_V) == 0: print(geometry.names) raise ValueError(f"{_script}: Please specify all electrode potentials using --elec-V") - for name, V in args.elec_V: - elecs_V[name] = float(V) + for name, V in elec_V.items(): + elecs_V[name] = V - if args.dtype.lower() in ("f", "f32"): + if dtype.value.lower() in ("f", "f32"): dtype = np.float32 - elif args.dtype.lower() in ("d", "f64"): + elif dtype.value.lower() in ("d", "f64"): dtype = np.float64 # Now we can solve Poisson - if args.pyamg_shape is None: - shape = args.shape + if pyamg_shape[0] == -1: + shape = shape else: - shape = args.pyamg_shape + shape = pyamg_shape # Create the boundary conditions boundary = [] - boundary.append(args.boundary_condition_a) - boundary.append(args.boundary_condition_b) - boundary.append(args.boundary_condition_c) - - V = solve_poisson(geometry, shape, radius=args.radius, boundary=boundary, - dtype=dtype, tolerance=args.tolerance, box=args.box, - accel=args.accel, boundary_fft=args.boundary_fft, - device_val=args.device, plot_boundary=args.plot_boundary, + boundary.append(boundary_condition_a) + boundary.append(boundary_condition_b) + boundary.append(boundary_condition_c) + + V = solve_poisson(geometry, shape, radius=radius, boundary=boundary, + dtype=dtype, tolerance=tolerance, box=box, + accel=accel.value, boundary_fft=boundary_fft, + device_val=device, plot_boundary=plot_boundary, **elecs_V) if _DEBUG: - if not args.plot is None: - dat = V.average(args.plot) + if not plot is None: + dat = V.average(plot) import matplotlib.pyplot as plt axs = [ np.linspace(0, V.lattice.length[ax], shape, endpoint=False) for ax, shape in enumerate(V.shape) @@ -401,29 +447,24 @@ def fftpoisson_fix_run(args): idx = list(range(3)) # Now plot data - del axs[args.plot] - del idx[args.plot] + del axs[plot] + del idx[plot] X, Y = np.meshgrid(*axs) plt.contourf(X, Y, np.squeeze(dat.grid).T) plt.colorbar() - plt.title(f"Averaged over {'ABC'[args.plot]} axis") + plt.title(f"Averaged over {'ABC'[plot]} axis") plt.xlabel(f"Distance along {'ABC'[idx[0]]} [Ang]") plt.ylabel(f"Distance along {'ABC'[idx[1]]} [Ang]") plt.show() - if np.any(np.array(args.shape) != np.array(V.shape)): + if np.any(np.array(shape) != np.array(V.shape)): print("\nInterpolating the solution...") - V = V.interp(args.shape, 2) + V = V.interp(shape, 2) print("Done interpolating!") print("") # Write solution to the output - if not args.out is None: - for out in args.out: - print(f"Writing to file: {out}...") - V.write(out) - - -if __name__ == "__main__": - fftpoisson_fix_cli() + for out_file in out: + print(f"Writing to file: {out_file}...") + V.write(out_file)