From 66adcb2d81587fdc828f9ffac09932595e388f7e Mon Sep 17 00:00:00 2001 From: Greg Tucker Date: Tue, 21 Nov 2023 17:17:40 +0100 Subject: [PATCH] Fix component placement bug (#17) * [Refactor] version and registry file for editable instal Importing `__version__` from `mccode_antlr` only works if it is installed, otherwise the `__init__` file is not evaluated(?) when importing the module. It is also not possible to use the importlib.resources files method to find the registry files from an editable-installed module. In an attempt to support ediable installs, for easier debugging of dependent modules, the version method is defined in a new file, version.py, which is imported and used to populate the module __version__ property. This method is then used in place of the dunder version throughout mccode_antlr. The importlib.metadata module has a distribution method which can be used to identify editable installs of modules. This is used in a bit of a hack to locate registry files in case the importlib.resources method fails. Unfortunatley, the auto-generation of the grammar files does not work for editable installs; so the whole machinery still fails. These files *should* be static in the distributed module -- so that the dependency on antlr can be reduced to just the runtime. But since this is not entirely trivial to change, editable installations will continue to not work properly. * [Ref] minor changes in runtime identified while debugging * [Fix] C component positioning bug using rotation angles instead of position vector * [Add] 0D dat file, [Expand] compiled test case --- .gitignore | 3 +- mccode_antlr/__init__.py | 14 +----- mccode_antlr/commands.py | 4 +- mccode_antlr/io/hdf5.py | 10 ++-- mccode_antlr/libc-registry.txt | 4 +- mccode_antlr/loader/datfile.py | 59 +++++++++++++++++++++--- mccode_antlr/reader/registry.py | 32 +++++++++---- mccode_antlr/translators/c_initialise.py | 31 +++++++------ mccode_antlr/translators/c_macros.py | 2 +- mccode_antlr/version.py | 11 +++++ test/test_instr.py | 39 ++++++++++------ 11 files changed, 145 insertions(+), 64 deletions(-) create mode 100644 mccode_antlr/version.py diff --git a/.gitignore b/.gitignore index cdff89b..7e2978b 100644 --- a/.gitignore +++ b/.gitignore @@ -7,4 +7,5 @@ */*.egg-info/ .idea/ *build/ -*.egg-info/ \ No newline at end of file +*.egg-info/ +*.c diff --git a/mccode_antlr/__init__.py b/mccode_antlr/__init__.py index e27677e..d374079 100644 --- a/mccode_antlr/__init__.py +++ b/mccode_antlr/__init__.py @@ -3,19 +3,7 @@ __author__ = "Gregory Tucker" __affiliation__ = "European Spallation Source ERIC" - -def version(): - import sys - if sys.version_info[0] == 3 and sys.version_info[1] < 8: - import importlib_metadata - else: - import importlib.metadata as importlib_metadata - try: - return importlib_metadata.version("mccode_antlr") - except importlib_metadata.PackageNotFoundError: - return "dev" - - +from .version import version __version__ = version() __all__ = ["__author__", "__affiliation__", "__version__", "version"] diff --git a/mccode_antlr/commands.py b/mccode_antlr/commands.py index 04bd668..2bc1ae7 100644 --- a/mccode_antlr/commands.py +++ b/mccode_antlr/commands.py @@ -24,8 +24,8 @@ def resolvable(name: str): if args.version: from sys import exit - from mccode_antlr import __version__ - print(f'mccode_antlr code generator version {__version__}') + from mccode_antlr.version import version + print(f'mccode_antlr code generator version {version()}') print(' Copyright (c) European Spallation Source ERIC, 2023') print('Based on McStas/McXtrace version 3') print(' Copyright (c) DTU Physics and Risoe National Laboratory, 1997-2023') diff --git a/mccode_antlr/io/hdf5.py b/mccode_antlr/io/hdf5.py index 0ece2eb..0f7c68e 100644 --- a/mccode_antlr/io/hdf5.py +++ b/mccode_antlr/io/hdf5.py @@ -10,17 +10,17 @@ def _split_version_name(version_name): def _write_header(group, data_type): - from mccode_antlr import __version__ - group.attrs[VERSION_NAME_KEY] = f'{__version__}/{data_type.__name__}' + from mccode_antlr.version import version + group.attrs[VERSION_NAME_KEY] = f'{version()}/{data_type.__name__}' def _check_header(group, data_type): - from mccode_antlr import __version__ + from mccode_antlr.version import version as module_version if VERSION_NAME_KEY not in group.attrs: raise RuntimeError(f"File does not have type information") version, name = _split_version_name(group.attrs[VERSION_NAME_KEY]) - if version != __version__: - raise RuntimeError(f"File was created with mccode_antlr {version}, but asked to read version {__version__}") + if version != module_version(): + raise RuntimeError(f"File was created with mccode_antlr {version}, but asked to read {module_version()}") if name != data_type.__name__: raise RuntimeError(f"Group contains mccode_antlr type {name}, but asked to read type {data_type}") diff --git a/mccode_antlr/libc-registry.txt b/mccode_antlr/libc-registry.txt index e440b09..667c958 100644 --- a/mccode_antlr/libc-registry.txt +++ b/mccode_antlr/libc-registry.txt @@ -7,11 +7,11 @@ interoff-lib.h ce347e0fddffc679c653fc152753c034604b4b3e6b4a7f370e759c873fc77c2a interpolation-lib.c 2c7b84c35305f401873f33b49b7f0c21588dd37a2647d9e08ec6f3412d6eb039 interpolation-lib.h dcafc01739ee3b7bc24bf09536df9c2e34666c61a6f3eeb86f1b39c086bae45e mccode-r.c 8f2a0b28c7f3602501dd551592516ef9c4cacef85640ae648a7613e700de3aa9 -mccode-r.h 598a33467133d253b212ea36587a242acdf38374d81e5f132e68d95ddeae069e +mccode-r.h 827861922d2c9e18a7ab9d8a176b5acb0e950f31afe15bde46daffa359f8ecf3 mccode_main.c a5f6cfe4e5b10af434987eed499b053987b3e1997971afb7146b83c59260bed0 mcstas-d.h 58eabd41cbe02e9a7c2e631eefa2111db176c4a6c9e95dbeb6a14aab4d811591 mcxtrace-d.h 3fbfa9a8148a39457b30cc09c25e006695a7f94aa37152c3c7250c05412f1d8b -metadata-r.c 8f8123ef527bebfe8e9f08b82cd7b3898a46e2559db0eefeba91a3d18db076b2 +metadata-r.c b97cfc8ba36602173058306bae3bd996b23fb78d15a85369d8ae507862986b0b metadata-r.h d50ed95754a90dea9891955cc99c03395e662e2b6d9b106f98246a7fcc9e497d nlib/general.c 487d491147dc09d3fc32805f494584167a46cedc9f4f9e4b742071480dcf0f90 nlib/general.h 7b8564d217ec5aed28026d7b02efeab6d4cb0a3cb0602d87ee956ce46dd6f88e diff --git a/mccode_antlr/loader/datfile.py b/mccode_antlr/loader/datfile.py index ddc98e1..7fa3107 100644 --- a/mccode_antlr/loader/datfile.py +++ b/mccode_antlr/loader/datfile.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from dataclasses import dataclass, field from typing import Union from pathlib import Path @@ -13,7 +15,7 @@ class DatFileCommon: data: ndarray = field(default_factory=ndarray) @classmethod - def from_filename(cls, filename: str): + def from_filename(cls, filename: str | Path): from numpy import array source = Path(filename).resolve() if not source.exists(): @@ -94,6 +96,45 @@ def dim_metadata(length, label_unit, lower_limit, upper_limit) -> dict: return dict(lenght=length, label=label, unit=unit, bin_boundaries=boundaries) +@dataclass +class DatFile0D(DatFileCommon): + def __post_init__(self): + nv = len(self.variables) + if self.data.shape == (nv, ): + # shortcut in case we're already in the right shape (e.g., from __add__) + return + if self.data.size != nv: + raise RuntimeError(f'Unexpected data shape {self.data.shape} for metadata specifying {nv=}') + self.data = self.data.reshape((nv, )) + + def dim_metadata(self) -> list[dict]: + return [] + + def print_data(self, file): + print(' '.join(str(x) for x in self.data), file=file) + + @staticmethod + def parts(): + # Yes, Ncount shows up twice ... + first = ('Format', 'URL', 'Creator', 'Instrument', 'Ncount', 'Trace', 'Gravitation', 'Seed', 'Directory') + second = ('Date', 'type', 'Source', 'component', 'position', 'title', 'Ncount', 'filename', 'statistics', + 'signal', 'values', 'yvar', 'ylabel', 'xlimits', 'variables') + return first, second + + def safe_to_combine(self, other): + if not isinstance(other, DatFile1D): + return False + if self.variables != other.variables: + return False + if self.data.shape != other.data.shape: + return False + return True + + def __add__(self, other): + both = super().__add__(other) + return DatFile1D(both.source, both.metadata, both.parameters, both.variables, both.data) + + @dataclass class DatFile1D(DatFileCommon): def __post_init__(self): @@ -192,13 +233,19 @@ def __add__(self, other): return DatFile2D(both.source, both.metadata, both.parameters, both.variables, both.data) -def read_mccode_dat(filename: str): +def read_mccode_dat(filename: str | Path): common = DatFileCommon.from_filename(filename) - ndim = len(common.metadata['type'].split('(', 1)[1].strip(')').split(',')) - if ndim < 1 or ndim > 2: + array_type = common.metadata['type'] + ndim, data_type = -1, None + if array_type.startswith('array_0d'): + ndim, data_type = 0, DatFile0D + elif array_type.startswith('array_1d'): + ndim, data_type = 1, DatFile1D + elif array_type.startswith('array_2d'): + ndim, data_type = 2, DatFile2D + if ndim < 0 or ndim > 2: raise RuntimeError(f'Unexpected number of dimensions: {ndim}') - dat_type = DatFile1D if ndim == 1 else DatFile2D - return dat_type(common.source, common.metadata, common.parameters, common.variables, common.data) + return data_type(common.source, common.metadata, common.parameters, common.variables, common.data) def combine_scan_dicts(a: dict, b: dict): diff --git a/mccode_antlr/reader/registry.py b/mccode_antlr/reader/registry.py index 5ba5303..8eba24a 100644 --- a/mccode_antlr/reader/registry.py +++ b/mccode_antlr/reader/registry.py @@ -1,8 +1,8 @@ import pooch from pathlib import Path from re import Pattern -from importlib.resources import files, as_file -from mccode_antlr import __version__ +from mccode_antlr.version import version as mccode_antlr_version + def ensure_regex_pattern(pattern): import re @@ -67,24 +67,40 @@ def _name_plus_suffix(name: str, suffix: str = None): return path.as_posix() +def find_registry_file(name: str): + """Find a registry file in the mccode_antlr package""" + from importlib.resources import files, as_file + from importlib.metadata import distribution + from json import loads + if isinstance(name, Path): + name = name.as_posix() + if files('mccode_antlr').joinpath(name).is_file(): + return files('mccode_antlr').joinpath(name) + info = loads(distribution('mccode_antlr').read_text('direct_url.json')) + if 'dir_info' in info and 'editable' in info['dir_info'] and info['dir_info']['editable'] and 'url' in info: + path = Path(info['url'].split('file://')[1]).joinpath('mccode_antlr', name) + return path if path.is_file() else None + return None + + class RemoteRegistry(Registry): - def __init__(self, name: str, url: str, filename=None): + def __init__(self, name: str, url: str, filename=None, version=None): self.name = name self.filename = filename self.pooch = pooch.create( path=pooch.os_cache(f'mccode_antlr-{name}'), base_url=url, - version=__version__, + version=version or mccode_antlr_version(), version_dev="main", registry=None, ) if isinstance(filename, Path): self.pooch.load_registry(filename) - elif files('mccode_antlr').joinpath(filename).is_file(): - with as_file(files('mccode_antlr').joinpath(filename)) as path: - self.pooch.load_registry(path) else: - raise RuntimeError(f"The provided filename {filename} is not a path or file packaged with this module") + filepath = find_registry_file(filename) + if filepath is None: + raise RuntimeError(f"The provided filename {filename} is not a path or file packaged with this module") + self.pooch.load_registry(filepath) def to_file(self, output, wrapper): contents = '(' + ', '.join([ diff --git a/mccode_antlr/translators/c_initialise.py b/mccode_antlr/translators/c_initialise.py index dc92232..65bf734 100644 --- a/mccode_antlr/translators/c_initialise.py +++ b/mccode_antlr/translators/c_initialise.py @@ -1,7 +1,7 @@ from zenlog import log _GETDISTANCE_FCT = """ -double index_getdistance(int first_index, int second_index) +double index_getdistance(long first_index, long second_index) /* Calculate the distance two components from their indexes*/ { return coords_len(coords_sub(POS_A_COMP_INDEX(first_index), POS_A_COMP_INDEX(second_index))); @@ -10,16 +10,16 @@ double getdistance(char* first_component, char* second_component) /* Calculate the distance between two named components */ { - int first_index = _getcomp_index(first_component); - int second_index = _getcomp_index(second_component); + long first_index = _getcomp_index(first_component); + long second_index = _getcomp_index(second_component); return index_getdistance(first_index, second_index); } -double checked_setpos_getdistance(int current_index, char* first_component, char* second_component) +double checked_setpos_getdistance(long current_index, char* first_component, char* second_component) /* Calculate the distance between two named components at *_setpos() time, with component index checking */ { - int first_index = _getcomp_index(first_component); - int second_index = _getcomp_index(second_component); + long first_index = _getcomp_index(first_component); + long second_index = _getcomp_index(second_component); if (first_index >= current_index || second_index >= current_index) { printf(\"setpos_getdistance can only be used with the names of components before the current one!\\n\"); return 0; @@ -30,11 +30,17 @@ """ +def _split_xyz_ref(xyz_ref): + x, y, z = [f'{v:p}' for v in xyz_ref[0]] + return x, y, z, xyz_ref[1] + + def cogen_comp_init_position(index, comp, last, instr): ref = None if index == 0 else instr.components[last] var = f'_{comp.name}_var' lines = [ f' /* component {comp.name}={comp.type.name}() AT ROTATED */', + f' /* {comp} */', ' {', ' Coords tc1, tc2;', ' tc1 = coords_set(0,0,0);', @@ -43,8 +49,7 @@ def cogen_comp_init_position(index, comp, last, instr): ' rot_set_rotation(tr1,0,0,0);' ] # Rotation first - x, y, z = [f'{v:p}' for v in comp.rotate_relative[0]] - rel = comp.rotate_relative[1] + x, y, z, rel = _split_xyz_ref(comp.rotate_relative) if rel is None: # log.debug(f'{comp.name} has absolute orientation with rotation ({x}, {y}, {z})') lines.append( @@ -67,8 +72,7 @@ def cogen_comp_init_position(index, comp, last, instr): lines.append(f' {var}._rotation_is_identity = rot_test_identity({var}._rotation_relative);') # Then translation - x, y, z = [f'{v:p}' for v in comp.rotate_relative[0]] - rel = comp.at_relative[1] + x, y, z, rel = _split_xyz_ref(comp.at_relative) if rel is None: # log.debug(f'{comp.name} has absolute positioning with rotation ({x}, {y}, {z})') lines.append(f' {var}._position_absolute = coords_set({x}, {y}, {z});') @@ -119,9 +123,10 @@ def parameter_line(default): ' }', ]) elif p.value.has_value and p.value.value != '0' and p.value.value != 'NULL' and p.value.value != '""': - pl.append(f' stracpy({fullname}, {value}, {len(p.value.value)-2});') + # len(value)-1 to remove quotes, but copy null terminator + pl.append(f' stracpy({fullname}, {value}, {len(p.value.value)-1});') else: - pl.append(f" {fullname}[0] = '\\0';") + pl.append(f" {fullname}[0] = '\\0';") elif default.value.is_vector or p.value.is_vector: if p.value.vector_known: for i, v in enumerate(p.value.value): @@ -148,7 +153,7 @@ def parameter_line(default): f' SIG_MESSAGE("[_{comp.name}_setpos] component {comp.name}={comp.type.name}() SETTING [{f}:{n}]");', f' stracpy(_{comp.name}_var._name, "{comp.name}", {min(len(comp.name)+1, 16384)});', f' stracpy(_{comp.name}_var._type, "{comp.type.name}", {min(len(comp.type.name)+1, 16384)});', - f' int current_setpos_index = _{comp.name}_var._index = {1 + index};' + f' long current_setpos_index = _{comp.name}_var._index = {1 + index};' # _index is a long ] # <<< This is the first call to `cogen_comp_init_par`: `cogen_comp_init_par(comp, instr, "SETTING") diff --git a/mccode_antlr/translators/c_macros.py b/mccode_antlr/translators/c_macros.py index 6292aff..5b9838e 100644 --- a/mccode_antlr/translators/c_macros.py +++ b/mccode_antlr/translators/c_macros.py @@ -37,7 +37,7 @@ def cogen_getparticlevar_fct(uservars): def cogen_getcompindex_fct(instr): lines = [ - "int _getcomp_index(char* compname)", + "long _getcomp_index(char* compname)", "/* Enables retrieving the component position & rotation when the index is not known.", " * Component indexing into MACROS, e.g., POS_A_COMP_INDEX, are 1-based! */", "{", diff --git a/mccode_antlr/version.py b/mccode_antlr/version.py new file mode 100644 index 0000000..638632b --- /dev/null +++ b/mccode_antlr/version.py @@ -0,0 +1,11 @@ +def version(): + import sys + if sys.version_info[0] == 3 and sys.version_info[1] < 8: + import importlib_metadata + else: + import importlib.metadata as importlib_metadata + try: + return importlib_metadata.version("mccode_antlr") + except importlib_metadata.PackageNotFoundError: + return "dev" + diff --git a/test/test_instr.py b/test/test_instr.py index 04d3483..3f0de9a 100644 --- a/test/test_instr.py +++ b/test/test_instr.py @@ -415,6 +415,7 @@ class CompiledInstr(CompiledTest): def _compile_and_run(self, instr, parameters, run=True): from mccode_antlr.compiler.c import compile_instrument, CBinaryTarget, run_compiled_instrument from mccode_antlr.translators.target import MCSTAS_GENERATOR + from mccode_antlr.loader import read_mccode_dat from tempfile import TemporaryDirectory from os import R_OK, access from pathlib import Path @@ -435,22 +436,18 @@ def _compile_and_run(self, instr, parameters, run=True): self.assertTrue(access(binary, R_OK)) if run: run_compiled_instrument(binary, target, f"--dir {directory}/instr {parameters}") - sim_files = list(Path(directory).glob('**/*.sim')) - print(sim_files) + sim_files = list(Path(directory).glob('**/*.dat')) + dats = {file.stem: read_mccode_dat(file) for file in sim_files} + return dats + return None def test_one_axis(self): - from mccode_antlr.compiler.c import compile_instrument, run_compiled_instrument, CBinaryTarget - from mccode_antlr.translators.target import MCSTAS_GENERATOR - from tempfile import TemporaryDirectory - from os import R_OK, access - from pathlib import Path - from math import pi, asin, sqrt from mccode_antlr.loader import parse_mcstas_instr d_spacing = 3.355 # (002) for Highly-ordered Pyrolytic Graphite mean_energy = 5.0 - energy_width = 0.1 - mean_ki = sqrt(mean_energy / 2.7022) + energy_width = 0.5 + mean_ki = sqrt(mean_energy / 2.0722) instr = f""" DEFINE INSTRUMENT splitRunTest(a1=0, a2=0, virtual_source_x=0.05, virtual_source_y=0.1, string newname) TRACE @@ -458,11 +455,14 @@ def test_one_axis(self): COMPONENT source = Source_simple(yheight=2*virtual_source_y, xwidth=0.2, dist=1.5, focus_xw=0.06, focus_yh=0.12, E0={mean_energy}, dE={energy_width}) AT (0, 0, 0) RELATIVE origin + COMPONENT m0 = PSD_monitor(xwidth=0.1, yheight=0.15, nx=100, ny=160, restore_neutron=1) AT (0, 0, 0.01) RELATIVE PREVIOUS COMPONENT guide = Guide_gravity(w1 = 0.06, h1 = 0.12, w2 = 0.05, h2 = 0.1, l = 30, m = 4) AT (0, 0, 1.5) RELATIVE PREVIOUS COMPONENT guide_end = Arm() AT (0, 0, 30) RELATIVE PREVIOUS + COMPONENT m1 = PSD_monitor(xwidth=0.1, yheight=0.15, nx=100, ny=160, restore_neutron=1) AT (0, 0, 0.01) RELATIVE PREVIOUS COMPONENT aperture = Slit(xwidth=virtual_source_x, yheight=virtual_source_y) AT (0, 0, 0.01) RELATIVE PREVIOUS COMPONENT split_at = Arm() AT (0, 0, 0.0001) RELATIVE PREVIOUS + COMPONENT m3 = PSD_monitor(xwidth=0.1, yheight=0.15, nx=100, ny=160, restore_neutron=1) AT (0, 0, 0.01) RELATIVE PREVIOUS COMPONENT mono_point = Arm() AT (0, 0, 0.8) RELATIVE split_at METADATA "txt" "something" %{{ This is some unparsed metadata that will be included as a literal string in the instrument. @@ -470,15 +470,28 @@ def test_one_axis(self): COMPONENT mono = Monochromator_curved(zwidth = 0.02, yheight = 0.02, NH = 13, NV = 7, DM={d_spacing}) AT (0, 0, 0) RELATIVE mono_point ROTATED (0, a1, 0) RELATIVE mono_point COMPONENT sample_arm = Arm() AT (0, 0, 0) RELATIVE mono_point ROTATED (0, a2, 0) RELATIVE mono_point - COMPONENT detector = Monitor(xwidth=0.01, yheight=0.05) AT (0, 0, 0.8) RELATIVE sample_arm + COMPONENT detector = Monitor(xwidth=0.1, yheight=0.15, restore_neutron=1) AT (0, 0, 0.8) RELATIVE sample_arm COMPONENT lmon = L_monitor(filename=newname) AT (0, 0, 0.001) RELATIVE PREVIOUS END """ instr = parse_mcstas_instr(instr) a1 = asin(pi / d_spacing / mean_ki) * 180 / pi - parameters = f'a1={a1} a2={2 * a1}' - self._compile_and_run(instr, parameters) + self.assertAlmostEqual(a1, 37.0722, 4) + parameters = f'a1={a1} a2={2 * a1} -n 1000000' + + dats = self._compile_and_run(instr, parameters) + self.assertEqual(len(dats), 5) + self.assertEqual(dats['m0'].data.shape, (3, 160, 100)) + self.assertEqual(dats['m1'].data.shape, (3, 160, 100)) + self.assertEqual(dats['m3'].data.shape, (3, 160, 100)) + self.assertEqual(dats['detector'].data.shape, (3, )) + + # Moving farther from the source means less (but finite) intensity in equivalent monitors + self.assertTrue(sum(sum(dats['m0']['I'])) > sum(sum(dats['m1']['I'])) > sum(sum(dats['m3']['I'])) > 0) + # The detector has been positioned correctly to collect intensity + self.assertTrue(dats['detector']['I'] > 0) + def test_assembled_parameters(self): """Check that setting an instance parameter to a value that is an instrument parameter name works"""