diff --git a/src/mccode_antlr/assembler/assembler.py b/src/mccode_antlr/assembler/assembler.py index ff0bf2e..5ca543d 100644 --- a/src/mccode_antlr/assembler/assembler.py +++ b/src/mccode_antlr/assembler/assembler.py @@ -130,18 +130,19 @@ def user_vars(self, string, source=None, line=-1): def ensure_user_var(self, string, source=None, line=-1): # tying the Assembler to work with C might not be great from mccode_antlr.translators.c_listener import extract_c_declared_variables as parse - input = parse(string) - if len(input) == 0: + variables = parse(string) + if len(variables) == 0: raise ValueError(f'The provided input {string} does not specify a C parameter declaration.') - if len(input) != 1: - print(f'The provided input {string} specifies {len(input)} C parameter declarations, using only the first') - name = list(input.keys())[0] - dtype, _ = input[name] + if len(variables) != 1: + print(f'The provided input {string} specifies {len(variables)} C parameter declarations, using only the first') + decl = variables[0] + name = decl.name + dtype = decl.dtype for user_vars in self.instrument.user: - dec_type_init_dict = parse(user_vars.source) - if any(d == dtype and n == name for n, (d, _) in dec_type_init_dict.items()): + uv_variables = parse(user_vars.source) + if any(x.dtype == dtype and x.name == name for x in uv_variables): return - if any(n == name for n in dec_type_init_dict): + if any(x.name == name for x in uv_variables): print(f'A USERVARS variable with name {name} but type different than {dtype} has already been defined.') return return self.user_vars(string, source=source, line=line) diff --git a/src/mccode_antlr/common/expression.py b/src/mccode_antlr/common/expression.py index 1220790..ff2945e 100644 --- a/src/mccode_antlr/common/expression.py +++ b/src/mccode_antlr/common/expression.py @@ -875,7 +875,7 @@ def __mul__(self, other): if self.is_value(-1): return (-other).as_type(pdt) if other.is_value(1): - return (self).as_type(pdt) + return self.as_type(pdt) if other.is_value(-1): return (-self).as_type(pdt) if other.is_op or self.is_id or other.is_id: @@ -888,7 +888,7 @@ def __truediv__(self, other): if self.is_zero: return Value(0, DataType.int if pdt.is_str else pdt) if other.is_value(1): - return (self).as_type(pdt) + return self.as_type(pdt) if other.is_value(-1): return (-self).as_type(pdt) if other.is_zero: diff --git a/src/mccode_antlr/comp/comp.py b/src/mccode_antlr/comp/comp.py index 724d923..8bdb01d 100644 --- a/src/mccode_antlr/comp/comp.py +++ b/src/mccode_antlr/comp/comp.py @@ -13,21 +13,21 @@ class Comp: """ name: str = None # Component *type* name, e.g. {name}.comp category: str = None # Component type catagory -- nearly free-form - define: tuple[ComponentParameter] = field(default_factory=tuple) # C #define'ed parameters - setting: tuple[ComponentParameter] = field(default_factory=tuple) # Formal 'setting' parameters - output: tuple[ComponentParameter] = field(default_factory=tuple) # 'output' parameters - metadata: tuple[MetaData] = field(default_factory=tuple) # metadata for use by simulation consumers + define: tuple[ComponentParameter, ...] = field(default_factory=tuple) # C #define'ed parameters + setting: tuple[ComponentParameter, ...] = field(default_factory=tuple) # Formal 'setting' parameters + output: tuple[ComponentParameter, ...] = field(default_factory=tuple) # 'output' parameters + metadata: tuple[MetaData, ...] = field(default_factory=tuple) # metadata for use by simulation consumers dependency: str = None # compile-time dependency acc: bool = True # False if this component *can not* work under OpenACC # literal strings writen into C source files - share: tuple[RawC] = field(default_factory=tuple) # function(s) for all instances of this class - user: tuple[RawC] = field(default_factory=tuple) # struct members for _particle - declare: tuple[RawC] = field(default_factory=tuple) # global parameters used in component trace - initialize: tuple[RawC] = field(default_factory=tuple) # initialization of global declare parameters - trace: tuple[RawC] = field(default_factory=tuple) # per-particle interaction executed at TRACE time - save: tuple[RawC] = field(default_factory=tuple) # statements executed after TRACE to save results - final: tuple[RawC] = field(default_factory=tuple) # clean-up memory for global declare parameters - display: tuple[RawC] = field(default_factory=tuple) # draw this component + share: tuple[RawC, ...] = field(default_factory=tuple) # function(s) for all instances of this class + user: tuple[RawC, ...] = field(default_factory=tuple) # struct members for _particle + declare: tuple[RawC, ...] = field(default_factory=tuple) # global parameters used in component trace + initialize: tuple[RawC, ...] = field(default_factory=tuple) # initialization of global declare parameters + trace: tuple[RawC, ...] = field(default_factory=tuple) # per-particle interaction executed at TRACE time + save: tuple[RawC, ...] = field(default_factory=tuple) # statements executed after TRACE to save results + final: tuple[RawC, ...] = field(default_factory=tuple) # clean-up memory for global declare parameters + display: tuple[RawC, ...] = field(default_factory=tuple) # draw this component def __hash__(self): return hash(repr(self)) diff --git a/src/mccode_antlr/instr/instance.py b/src/mccode_antlr/instr/instance.py index 3584e56..356a0cc 100644 --- a/src/mccode_antlr/instr/instance.py +++ b/src/mccode_antlr/instr/instance.py @@ -149,7 +149,7 @@ def set_parameter(self, name: str, value, overwrite=False, allow_repeated=True): self.parameters += (ComponentParameter(p.name, value), ) - def verify_parameters(self, instrument_parameters: tuple[InstrumentParameter]): + def verify_parameters(self, instrument_parameters: tuple[InstrumentParameter, ...]): """Check for instance parameters which are identifiers that match instrument parameter names, and flag them as parameter objects""" instrument_parameter_names = [x.name for x in instrument_parameters] diff --git a/src/mccode_antlr/instr/instr.py b/src/mccode_antlr/instr/instr.py index 5fb40c4..c37dde0 100644 --- a/src/mccode_antlr/instr/instr.py +++ b/src/mccode_antlr/instr/instr.py @@ -20,18 +20,18 @@ class Instr: """ name: str = None # Instrument name, e.g. {name}.instr (typically) source: str = None # Instrument *file* name - parameters: tuple[InstrumentParameter] = field(default_factory=tuple) # runtime-set instrument parameters - metadata: tuple[MetaData] = field(default_factory=tuple) # metadata for use by simulation consumers - components: tuple[Instance] = field(default_factory=tuple) # - included: tuple[str] = field(default_factory=tuple) # names of included instr definition(s) - user: tuple[RawC] = field(default_factory=tuple) # struct members for _particle - declare: tuple[RawC] = field(default_factory=tuple) # global parameters used in component trace - initialize: tuple[RawC] = field(default_factory=tuple) # initialization of global declare parameters - save: tuple[RawC] = field(default_factory=tuple) # statements executed after TRACE to save results - final: tuple[RawC] = field(default_factory=tuple) # clean-up memory for global declare parameters + parameters: tuple[InstrumentParameter, ...] = field(default_factory=tuple) # runtime-set instrument parameters + metadata: tuple[MetaData, ...] = field(default_factory=tuple) # metadata for use by simulation consumers + components: tuple[Instance, ...] = field(default_factory=tuple) # + included: tuple[str, ...] = field(default_factory=tuple) # names of included instr definition(s) + user: tuple[RawC, ...] = field(default_factory=tuple) # struct members for _particle + declare: tuple[RawC, ...] = field(default_factory=tuple) # global parameters used in component trace + initialize: tuple[RawC, ...] = field(default_factory=tuple) # initialization of global declare parameters + save: tuple[RawC, ...] = field(default_factory=tuple) # statements executed after TRACE to save results + final: tuple[RawC, ...] = field(default_factory=tuple) # clean-up memory for global declare parameters groups: dict[str, Group] = field(default_factory=dict) - flags: tuple[str] = field(default_factory=tuple) # (C) flags needed for compilation of the (translated) instrument - registries: tuple[Registry] = field(default_factory=tuple) # the registries used by the reader to populate this + flags: tuple[str, ...] = field(default_factory=tuple) # (C) flags needed for compilation of the (translated) instrument + registries: tuple[Registry, ...] = field(default_factory=tuple) # the registries used by the reader to populate this def to_file(self, output=None, wrapper=None): if output is None: diff --git a/src/mccode_antlr/instr/orientation.py b/src/mccode_antlr/instr/orientation.py index fc269f2..1cf52fc 100644 --- a/src/mccode_antlr/instr/orientation.py +++ b/src/mccode_antlr/instr/orientation.py @@ -433,7 +433,7 @@ def axes_euler_angles(m: Rotation, degrees) -> Angles: @dataclass class Part: """The Seitz matrix part of any arbitrary projective affine transformation""" - _axes: Seitz = field(default_factory=Seitz) + _axes: Seitz = field(default_factory=lambda: Seitz()) def __post_init__(self): """If this is not defined, the subclass' __post_init__ may not be called""" @@ -449,7 +449,7 @@ def is_rotation(self): # The first condition _should_ always be true -- the second is only true if this is not the identity matrix if round((self._axes.inverse() * self._axes).trace(), 12) == Expr.float(3.): return round(self._axes.trace(), 12) != Expr.float(3.) - loging.info(f'Not a rotation matrix: {self._axes}') + logger.info(f'Not a rotation matrix: {self._axes}') return False @property @@ -474,11 +474,11 @@ def rotation_axis_angle(self) -> tuple[Vector, Expr, str]: # The eigenvalues, dd, are (1+0j, a+bj, a-bj) of which we want 1+0j. axis = vv[:, argmin(sqrt(real(conj(dd-1) * (dd-1))))] if sum(imag(axis)) != 0: - loging.warning(f'Imaginary rotation axis {real(axis)} + j {imag(axis)}') + logger.warning(f'Imaginary rotation axis {real(axis)} + j {imag(axis)}') axis = real(axis) cos_angle = (matrix[0][0] + matrix[1][1] + matrix[2][2] - 1) / 2 if abs(cos_angle) > 1: - loging.warning(f'Invalid cos(angle) {cos_angle} for {self}') + logger.warning(f'Invalid cos(angle) {cos_angle} for {self}') cos_angle = 1 if cos_angle > 0 else -1 angle = Expr.float(acos_degree(cos_angle)) axis = Vector(Expr.float(axis[0]), Expr.float(axis[1]), Expr.float(axis[2])) @@ -560,7 +560,7 @@ def __contains__(self, value): @dataclass class TranslationPart(Part): """A specialization to the translation-only part of a projective affine transformation""" - v: Vector = field(default_factory=Vector) + v: Vector = field(default_factory=lambda: Vector()) def __str__(self): return f'({self.v[0]}, {self.v[1]}, {self.v[2]}) [0, 0, 0]' @@ -733,7 +733,7 @@ def __repr__(self): def stack(self): return self._stack - def _copy(self, deep: bool = True) -> tuple[Part]: + def _copy(self, deep: bool = True) -> tuple[Part, ...]: if deep: from copy import deepcopy return tuple([deepcopy(x) for x in self._stack]) diff --git a/src/mccode_antlr/loader/loader.py b/src/mccode_antlr/loader/loader.py index 9516c1a..77a85d1 100644 --- a/src/mccode_antlr/loader/loader.py +++ b/src/mccode_antlr/loader/loader.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from pathlib import Path from typing import Union from mccode_antlr.instr import Instr diff --git a/src/mccode_antlr/translators/c.py b/src/mccode_antlr/translators/c.py index 04023b6..d6361df 100644 --- a/src/mccode_antlr/translators/c.py +++ b/src/mccode_antlr/translators/c.py @@ -1,36 +1,10 @@ """Translates a McComp instrument from its intermediate form to a C runtime source file.""" from loguru import logger -from collections import namedtuple from dataclasses import dataclass -from ..reader import Registry, LIBC_REGISTRY -from ..instr import Instr, Instance +from ..reader import LIBC_REGISTRY from .target import TargetVisitor from .c_listener import extract_c_declared_variables -# For use in keeping track of 'USERVAR' particle struct injections -CDeclaration = namedtuple("CDeclaration", "name type init is_pointer is_array orig") - - -def append_cdeclaration_name(decl: CDeclaration, suffix): - return CDeclaration(f'{decl.name}_{suffix}', decl.type, decl.init, decl.is_pointer, decl.is_array, decl.orig) - - -def extract_declaration(dec, c_type, init): - is_pointer = '*' in dec - is_array = '[' in dec and ']' in dec - if is_pointer: - name = dec.translate(str.maketrans('', '', '*')) - elif is_array: - # since dec could be 'name[x][y][z]...[a]' don't attempt to parse the size of the array - name = dec.split('[', 1)[0] - else: - name = dec - return CDeclaration(name, c_type, init, is_pointer, is_array, dec) - - -def extracted_declares(declares): - return [extract_declaration(dec, c_type, init) for dec, (c_type, init) in declares.items()] - @dataclass class CInclude: @@ -166,18 +140,18 @@ def _handle_raw_c_include(self, parent: str, raw_c: str): return includes, raw_c def _parse_libraries_for_typedefs(self): - from .c_listener import extract_c_declared_variables_and_defined_types as parse + from .c_listener import extract_c_defined_types as parse typedefs = set() for include in self.includes: # logger.debug(f'library {include.name}') # The files '%include'-d can themselves use the '%include' mechanism :/ - declares, defined_types = parse(include.content, user_types=list(typedefs)) + defined_types = parse(include.content, user_types=list(typedefs)) typedefs = typedefs.union(set(defined_types)) # logger.debug(f'{include.name} done') # types can also be defined in component 'SHARE' blocks: for block in [share for comp in self.source.component_types() if len(comp.share) for share in comp.share]: # TODO decide if this should be block.translated or block.source - declares, defined_types = parse(block.to_c(), user_types=list(typedefs)) + defined_types = parse(block.to_c(), user_types=list(typedefs)) typedefs = typedefs.union(set(defined_types)) self.typedefs = list(typedefs) @@ -199,7 +173,7 @@ def _determine_uservars(self): """ def extract_declares(name, raw_c_obj): # TODO decide if this should be raw_c_obj.translated or raw_c_obj.source - c_decs = extracted_declares(extract_c_declared_variables(raw_c_obj.to_c(), user_types=self.typedefs)) + c_decs = extract_c_declared_variables(raw_c_obj.to_c(), user_types=self.typedefs) if any(d.init is not None for d in c_decs): logger.critical(f'Warning USERVARS block from {name} contains assignment(s) (= sign).') logger.critical(' Move them to an EXTEND section. May fail at compile') @@ -236,8 +210,7 @@ def extract_declares(name, raw_c_obj): i_declares = [] for index, instance in enumerate(self.source.components): if instance.type.name == comp_name: - for dec in a_declares: - i_declares.append(append_cdeclaration_name(dec, index+1)) + i_declares.extend([d.copy(suffix=str(index+1)) for d in a_declares]) inst_declares[comp_name] = i_declares # Replace the component definition 'base' uservar declares with the 'real' ones: comp_declares = inst_declares @@ -284,13 +257,12 @@ def __post_init__(self): self._parse_libraries_for_typedefs() # pull together the per-component-type defined parameters into a dictionary... since this is required - # in multiple places :/ -- now using CDeclaration named tuples - sc = str.maketrans('', '', '*[] ') # for '* name' or '*name', or '**** name' or 'name[]' -> 'name' + # in multiple places :/ -- now using CDeclarator class objects for typ in inst.component_types(): dp = [] for block in typ.declare: # TODO decide if this should be block.translated or block.source - dp.extend(extracted_declares(extract_c_declared_variables(block.to_c(), user_types=self.typedefs))) + dp.extend(extract_c_declared_variables(block.to_c(), user_types=self.typedefs)) dp = list(dict.fromkeys(dp)) self.component_declared_parameters[typ.name] = dp self._determine_uservars() @@ -313,7 +285,7 @@ def visit_header(self): uuv = self._instrument_and_component_uservars() is_mcstas = self.is_mcstas - self.out(header_pre_runtime(is_mcstas, self.source, self.runtime, self.config, self.typedefs, uuv)) + self.out(header_pre_runtime(is_mcstas, self.source, self.runtime, self.config, uuv)) # runtime part if self.config.get('include_runtime'): self.out('#define MC_EMBEDDED_RUNTIME') diff --git a/src/mccode_antlr/translators/c_decls.py b/src/mccode_antlr/translators/c_decls.py index 6c3cb3c..fb2eb73 100644 --- a/src/mccode_antlr/translators/c_decls.py +++ b/src/mccode_antlr/translators/c_decls.py @@ -152,7 +152,6 @@ def component_type_declaration(comp, typedefs: list, declared_parameters: list): # The call tree for functions that access `comp->def->out_par` is such that the pointer is not used before it # is replaced, so at least there is no ambiguity between DECLARE-found parameters and OUTPUT PARAMETERS in cogen. - warnings = 0 lines = [ f"/* Parameter definition for component type '{comp.name}' */", f'struct _struct_{comp.name}_parameters {{', @@ -179,22 +178,8 @@ def component_type_declaration(comp, typedefs: list, declared_parameters: list): lines.append(f' {par.value.mccode_c_type} {par.name};') # This is the loop over the *replaced* `comp->def->out_par` e.g., found DECLARE parameters - lines.append(f"/* Component type '{comp.name}' private parameters */") - for x in declared_parameters: - # Switch these to use CDeclarations, then we have (.name, .type, .init, .is_pointer, .is_array, .orig) - # and the append would be f' {x.type} {x.orig}; /* {"Not initialized" if x.init is None else x.init} */' - # But of course, we need to do a bit more work to initialize any static array, so we instead would - # branch on x.is_array and then either count the number of initializer elements or punt to 16384 elements - # as McCode-3 does. - if x.is_array: - if x.init is None: - # hopefully handle all size-specified cases... - lines.append(f' {x.type} {x.orig}; /* Not initialized */') - else: - n_inits = 16384 if not isinstance(x.init, str) else min(len(x.init.split(',')), 16384) - lines.append(f' {x.type} {x.name}[{n_inits}]; /* {x.init} */') - else: - lines.append(f' {x.type} {x.orig}; /* {"Not initialized" if x.init is None else x.init} */') + lines.append(f" /* Component type '{comp.name}' private parameters */") + lines.extend([f' {x.as_struct_member()};' for x in declared_parameters]) if len(comp.setting) + len(declared_parameters) == 0: lines.append(f' char {comp.name}_has_no_parameters;') diff --git a/src/mccode_antlr/translators/c_defines.py b/src/mccode_antlr/translators/c_defines.py index 276fe69..d56f663 100644 --- a/src/mccode_antlr/translators/c_defines.py +++ b/src/mccode_antlr/translators/c_defines.py @@ -1,3 +1,4 @@ +from .c_listener import CDeclarator from ..comp import Comp from ..common import ComponentParameter @@ -10,13 +11,13 @@ def undef(a: ComponentParameter): return f'#undef {a.name}' -def cogen_parameter_define(comp: Comp, declares: list): +def cogen_parameter_define(comp: Comp, declares: list[CDeclarator]): # All parameters get defines? Not just DEFINE PARAMETERS lines = [define(par) for pars in (comp.define, comp.setting, comp.output, declares) for par in pars] return '\n'.join(lines) -def cogen_parameter_undef(comp: Comp, declares: list): +def cogen_parameter_undef(comp: Comp, declares: list[CDeclarator]): # The same parameters that were defined need to be undefined: lines = [undef(par) for pars in (comp.define, comp.setting, comp.output, declares) for par in pars] return '\n'.join(lines) diff --git a/src/mccode_antlr/translators/c_header.py b/src/mccode_antlr/translators/c_header.py index 742e857..e452dc6 100644 --- a/src/mccode_antlr/translators/c_header.py +++ b/src/mccode_antlr/translators/c_header.py @@ -1,4 +1,15 @@ -def header_pre_runtime(is_mcstas, source, runtime: dict, config: dict, typedefs: list, uservars: list): +from textwrap import dedent +from .c_listener import CDeclarator +from ..instr import Instr + + +def header_pre_runtime( + is_mcstas: bool, + source: Instr, + runtime: dict, + config: dict, + uservars: list[CDeclarator] +): from datetime import datetime from mccode_antlr import version @@ -23,7 +34,7 @@ def jump_line(instance, jump): # Also store these strings in the appropriate instrument list for later def/undef as state variables uservar_string = '//user variables and comp - injections:\n' if len(uservars) else '' - uservar_string += '\n'.join([f' {x.type} {x.name};' for x in uservars]) + uservar_string += '\n'.join([f' {x};' for x in uservars]) # Shouldn't we exclude array-valued names here? getvar = '\n'.join( @@ -35,12 +46,12 @@ def jump_line(instance, jump): # For non-array values we can safely use memcpy (probably) setvar_void = '\n'.join( - [f' if(!str_comp("{x.name}",name)){{memcpy(&(p->{x.name}), value, sizeof({x.type})); rval=0;}}' + [f' if(!str_comp("{x.name}",name)){{memcpy(&(p->{x.name}), value, sizeof({x.dtype})); rval=0;}}' for x in uservars if not x.is_array and not x.is_pointer]) # an array -- need to know how many elements we're copying setvar_void_array = '\n'.join( - [f' if(!str_comp("{x.name}",name)){{memcpy(&(p->{x.name}), value, elements * sizeof({x.type})); rval=0;}}' + [f' if(!str_comp("{x.name}",name)){{memcpy(&(p->{x.name}), value, elements * sizeof({x.dtype})); rval=0;}}' for x in uservars if x.is_array or x.is_pointer]) getuservar_byid = '\n'.join( @@ -48,220 +59,220 @@ def jump_line(instance, jump): uservar_init = '' for x in uservars: - if (x.type not in ('double', 'MCNUM', 'int') and x.init is None) or x.is_pointer or x.is_array: + if (x.dtype not in ('double', 'MCNUM', 'int') and x.init is None) or x.is_pointer or x.is_array: array_str = ' array' if x.is_pointer or x.is_array else '' - print(f'\nWARNING:\n --> USERVAR {x.name} is of type {x.type}{array_str}') + print(f'\nWARNING:\n --> USERVAR {x.name} is of type {x.dtype}{array_str}') print(' --> and may need specific per-particle initialization through an EXTEND block!\n') else: uservar_init += f'\np->{x.name}={0 if x.init is None else x.init};' - contents = f"""/* Automatically generated file. Do not edit. - * Format: ANSI C source code - * Creator: {runtime.get("fancy")} <{runtime.get("url")}> - * Generator: mccode-antlr {version()} - * Instrument: {source.source} ({source.name}) - * Date: {datetime.now()} - * File: {config.get('output')} - * CFLAGS={' '.join(set(source.flags))} - */ - -/* In case of cl.exe on Windows, suppress warnings about #pragma acc - Transferred from https://github.com/McStasMcXtrace/McCode/commit/0e2785a2d3fd742d46597139234dbc47e56344bb -*/ -#ifdef _MSC_EXTENSIONS -#pragma warning(disable: 4068) -#endif - -#define MCCODE_STRING "{runtime.get("fancy")}" -#define FLAVOR "{runtime.get("name", "none")}" -#define FLAVOR_UPPER "{runtime.get("name", "none").upper()}" -{'#define MC_USE_DEFAULT_MAIN' if config.get('default_main') else ''} -{'#define MC_TRACE_ENABLED' if config.get('enable_trace') else ''} -{'#define MC_PORTABLE' if config.get('portable') else ''} - -#include - -typedef double MCNUM; -typedef struct {{MCNUM x, y, z;}} Coords; -typedef MCNUM Rotation[3][3]; -#define MCCODE_BASE_TYPES - -#ifndef MC_NUSERVAR -#define MC_NUSERVAR 10 -#endif - -/* Particle JUMP control logic */ -struct particle_logic_struct {{ - int dummy;{jump_string} -}}; -struct _struct_particle {{ - double x,y,z; /* position [m] */ -{particle_struct} - /* Generic Temporaries: */ - /* May be used internally by components e.g. for special */ - /* return-values from functions used in trace, thusreturned via */ - /* particle struct. (Example: Wolter Conics from McStas, silicon slabs.) */ - double _mctmp_a; /* temp a */ - double _mctmp_b; /* temp b */ - double _mctmp_c; /* temp c */ - unsigned long randstate[7]; - double t, p; /* time, event weight */ - long long _uid; /* Unique event ID */ - long _index; /* component index where to send this event */ - long _absorbed; /* flag set to TRUE when this event is to be removed/ignored */ - long _scattered; /* flag set to TRUE when this event has interacted with the last component instance */ - long _restore; /* set to true if neutron event must be restored */ - long flag_nocoordschange; /* set to true if particle is jumping */ - struct particle_logic_struct _logic; - {uservar_string} -}}; -typedef struct _struct_particle _class_particle; - -_class_particle _particle_global_randnbuse_var; -_class_particle* _particle = &_particle_global_randnbuse_var; - -// Below lines relating to mcgenstate / setstate are in principle McStas - centric, we ought to generate -//this function based on "project" -#pragma acc routine -_class_particle mcgenstate(void); -#pragma acc routine -_class_particle mcsetstate(double x, double y, double z, double vx, double vy, double vz, - double t, double sx, double sy, double sz, double p, int mcgravitation, void *mcMagnet, int mcallowbackprop); - -extern int mcgravitation; /* flag to enable gravitation */ -#pragma acc declare create ( mcgravitation ) -int mcallowbackprop; -#pragma acc declare create ( mcallowbackprop ) - -_class_particle mcgenstate(void) {{ - _class_particle particle = mcsetstate(0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, mcgravitation, NULL, mcallowbackprop); - return(particle); -}} -/*Generated user variable handlers:*/ - -#pragma acc routine -double particle_getvar(_class_particle *p, char *name, int *suc); - -#ifdef OPENACC -#pragma acc routine -int str_comp(char *str1, char *str2); -#endif - -double particle_getvar(_class_particle *p, char *name, int *suc){{ -#ifndef OPENACC -#define str_comp strcmp -#endif - int s=1; - double rval=0; - if(!str_comp("x",name)){{rval=p->x;s=0;}} - if(!str_comp("y",name)){{rval=p->y;s=0;}} - if(!str_comp("z",name)){{rval=p->z;s=0;}} - if(!str_comp("vx",name)){{rval=p->vx;s=0;}} - if(!str_comp("vy",name)){{rval=p->vy;s=0;}} - if(!str_comp("vz",name)){{rval=p->vz;s=0;}} - if(!str_comp("sx",name)){{rval=p->sx;s=0;}} - if(!str_comp("sy",name)){{rval=p->sy;s=0;}} - if(!str_comp("sz",name)){{rval=p->sz;s=0;}} - if(!str_comp("t",name)){{rval=p->t;s=0;}} - if(!str_comp("p",name)){{rval=p->p;s=0;}} - if(!str_comp("_mctmp_a",name)){{rval=p->_mctmp_a;s=0;}} - if(!str_comp("_mctmp_b",name)){{rval=p->_mctmp_b;s=0;}} - if(!str_comp("_mctmp_c",name)){{rval=p->_mctmp_c;s=0;}} -{getvar} - if (suc!=0x0) {{*suc=s;}} - return rval; -}} - -#pragma acc routine -void* particle_getvar_void(_class_particle *p, char *name, int *suc);\n -#ifdef OPENACC -#pragma acc routine -int str_comp(char *str1, char *str2); -#endif - -void* particle_getvar_void(_class_particle *p, char *name, int *suc){{ -#ifndef OPENACC -#define str_comp strcmp -#endif - int s=1; - void* rval=0; - if(!str_comp("x",name)) {{rval=(void*)&(p->x); s=0;}} - if(!str_comp("y",name)) {{rval=(void*)&(p->y); s=0;}} - if(!str_comp("z",name)) {{rval=(void*)&(p->z); s=0;}} - if(!str_comp("vx",name)){{rval=(void*)&(p->vx);s=0;}} - if(!str_comp("vy",name)){{rval=(void*)&(p->vy);s=0;}} - if(!str_comp("vz",name)){{rval=(void*)&(p->vz);s=0;}} - if(!str_comp("sx",name)){{rval=(void*)&(p->sx);s=0;}} - if(!str_comp("sy",name)){{rval=(void*)&(p->sy);s=0;}} - if(!str_comp("sz",name)){{rval=(void*)&(p->sz);s=0;}} - if(!str_comp("t",name)) {{rval=(void*)&(p->t); s=0;}} - if(!str_comp("p",name)) {{rval=(void*)&(p->p); s=0;}} -{getvar_void} - if (suc!=0x0) {{*suc=s;}} - return rval; -}} - -#pragma acc routine -int particle_setvar_void(_class_particle *, char *, void*);\n -int particle_setvar_void(_class_particle *p, char *name, void* value){{ -#ifndef OPENACC -#define str_comp strcmp -#endif - int rval=1; - if(!str_comp("x",name)) {{memcpy(&(p->x), value, sizeof(double)); rval=0;}} - if(!str_comp("y",name)) {{memcpy(&(p->y), value, sizeof(double)); rval=0;}} - if(!str_comp("z",name)) {{memcpy(&(p->z), value, sizeof(double)); rval=0;}} - if(!str_comp("vx",name)){{memcpy(&(p->vx), value, sizeof(double)); rval=0;}} - if(!str_comp("vy",name)){{memcpy(&(p->vy), value, sizeof(double)); rval=0;}} - if(!str_comp("vz",name)){{memcpy(&(p->vz), value, sizeof(double)); rval=0;}} - if(!str_comp("sx",name)){{memcpy(&(p->sx), value, sizeof(double)); rval=0;}} - if(!str_comp("sy",name)){{memcpy(&(p->sy), value, sizeof(double)); rval=0;}} - if(!str_comp("sz",name)){{memcpy(&(p->sz), value, sizeof(double)); rval=0;}} - if(!str_comp("p",name)) {{memcpy(&(p->p), value, sizeof(double)); rval=0;}} - if(!str_comp("t",name)) {{memcpy(&(p->t), value, sizeof(double)); rval=0;}} -{setvar_void} - return rval; -}} - -#pragma acc routine -int particle_setvar_void_array(_class_particle *, char *, void*, int); - -int particle_setvar_void_array(_class_particle *p, char *name, void* value, int elements){{ -#ifndef OPENACC -#define str_comp strcmp -#endif - int rval=1; -{setvar_void_array} - return rval; -}} - -// Function to handle a particle restore of physical particle params -#pragma acc routine -void particle_restore(_class_particle *p, _class_particle *p0); -void particle_restore(_class_particle *p, _class_particle *p0) {{ - p->x = p0->x; p->y = p0->y; p->z = p0->z; - p->vx = p0->vx; p->vy = p0->vy; p->vz = p0->vz; - p->sx = p0->sx; p->sy = p0->sy; p->sz = p0->sz; - p->t = p0->t; p->p = p0->p; - p->_absorbed=0; p->_restore=0; -}} - -#pragma acc routine -double particle_getuservar_byid(_class_particle *p, int id, int *suc){{ - int s=1; - double rval=0; - switch(id){{ -{getuservar_byid} - }} - if (suc!=0x0) {{*suc=s;}} - return rval; -}} - -#pragma acc routine -void particle_uservar_init(_class_particle *p){{ -{uservar_init} -}} -""" + contents = dedent(f"""/* Automatically generated file. Do not edit. + * Format: ANSI C source code + * Creator: {runtime.get("fancy")} <{runtime.get("url")}> + * Generator: mccode-antlr {version()} + * Instrument: {source.source} ({source.name}) + * Date: {datetime.now()} + * File: {config.get('output')} + * CFLAGS={' '.join(set(source.flags))} + */ + + /* In case of cl.exe on Windows, suppress warnings about #pragma acc + Transferred from https://github.com/McStasMcXtrace/McCode/commit/0e2785a2d3fd742d46597139234dbc47e56344bb + */ + #ifdef _MSC_EXTENSIONS + #pragma warning(disable: 4068) + #endif + + #define MCCODE_STRING "{runtime.get("fancy")}" + #define FLAVOR "{runtime.get("name", "none")}" + #define FLAVOR_UPPER "{runtime.get("name", "none").upper()}" + {'#define MC_USE_DEFAULT_MAIN' if config.get('default_main') else ''} + {'#define MC_TRACE_ENABLED' if config.get('enable_trace') else ''} + {'#define MC_PORTABLE' if config.get('portable') else ''} + + #include + + typedef double MCNUM; + typedef struct {{MCNUM x, y, z;}} Coords; + typedef MCNUM Rotation[3][3]; + #define MCCODE_BASE_TYPES + + #ifndef MC_NUSERVAR + #define MC_NUSERVAR 10 + #endif + + /* Particle JUMP control logic */ + struct particle_logic_struct {{ + int dummy;{jump_string} + }}; + struct _struct_particle {{ + double x,y,z; /* position [m] */ + {particle_struct} + /* Generic Temporaries: */ + /* May be used internally by components e.g. for special */ + /* return-values from functions used in trace, thusreturned via */ + /* particle struct. (Example: Wolter Conics from McStas, silicon slabs.) */ + double _mctmp_a; /* temp a */ + double _mctmp_b; /* temp b */ + double _mctmp_c; /* temp c */ + unsigned long randstate[7]; + double t, p; /* time, event weight */ + long long _uid; /* Unique event ID */ + long _index; /* component index where to send this event */ + long _absorbed; /* flag set to TRUE when this event is to be removed/ignored */ + long _scattered; /* flag set to TRUE when this event has interacted with the last component instance */ + long _restore; /* set to true if neutron event must be restored */ + long flag_nocoordschange; /* set to true if particle is jumping */ + struct particle_logic_struct _logic; + {uservar_string} + }}; + typedef struct _struct_particle _class_particle; + + _class_particle _particle_global_randnbuse_var; + _class_particle* _particle = &_particle_global_randnbuse_var; + + // Below lines relating to mcgenstate / setstate are in principle McStas - centric, we ought to generate + //this function based on "project" + #pragma acc routine + _class_particle mcgenstate(void); + #pragma acc routine + _class_particle mcsetstate(double x, double y, double z, double vx, double vy, double vz, + double t, double sx, double sy, double sz, double p, int mcgravitation, void *mcMagnet, int mcallowbackprop); + + extern int mcgravitation; /* flag to enable gravitation */ + #pragma acc declare create ( mcgravitation ) + int mcallowbackprop; + #pragma acc declare create ( mcallowbackprop ) + + _class_particle mcgenstate(void) {{ + _class_particle particle = mcsetstate(0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, mcgravitation, NULL, mcallowbackprop); + return(particle); + }} + /*Generated user variable handlers:*/ + + #pragma acc routine + double particle_getvar(_class_particle *p, char *name, int *suc); + + #ifdef OPENACC + #pragma acc routine + int str_comp(char *str1, char *str2); + #endif + + double particle_getvar(_class_particle *p, char *name, int *suc){{ + #ifndef OPENACC + #define str_comp strcmp + #endif + int s=1; + double rval=0; + if(!str_comp("x",name)){{rval=p->x;s=0;}} + if(!str_comp("y",name)){{rval=p->y;s=0;}} + if(!str_comp("z",name)){{rval=p->z;s=0;}} + if(!str_comp("vx",name)){{rval=p->vx;s=0;}} + if(!str_comp("vy",name)){{rval=p->vy;s=0;}} + if(!str_comp("vz",name)){{rval=p->vz;s=0;}} + if(!str_comp("sx",name)){{rval=p->sx;s=0;}} + if(!str_comp("sy",name)){{rval=p->sy;s=0;}} + if(!str_comp("sz",name)){{rval=p->sz;s=0;}} + if(!str_comp("t",name)){{rval=p->t;s=0;}} + if(!str_comp("p",name)){{rval=p->p;s=0;}} + if(!str_comp("_mctmp_a",name)){{rval=p->_mctmp_a;s=0;}} + if(!str_comp("_mctmp_b",name)){{rval=p->_mctmp_b;s=0;}} + if(!str_comp("_mctmp_c",name)){{rval=p->_mctmp_c;s=0;}} + {getvar} + if (suc!=0x0) {{*suc=s;}} + return rval; + }} + + #pragma acc routine + void* particle_getvar_void(_class_particle *p, char *name, int *suc);\n + #ifdef OPENACC + #pragma acc routine + int str_comp(char *str1, char *str2); + #endif + + void* particle_getvar_void(_class_particle *p, char *name, int *suc){{ + #ifndef OPENACC + #define str_comp strcmp + #endif + int s=1; + void* rval=0; + if(!str_comp("x",name)) {{rval=(void*)&(p->x); s=0;}} + if(!str_comp("y",name)) {{rval=(void*)&(p->y); s=0;}} + if(!str_comp("z",name)) {{rval=(void*)&(p->z); s=0;}} + if(!str_comp("vx",name)){{rval=(void*)&(p->vx);s=0;}} + if(!str_comp("vy",name)){{rval=(void*)&(p->vy);s=0;}} + if(!str_comp("vz",name)){{rval=(void*)&(p->vz);s=0;}} + if(!str_comp("sx",name)){{rval=(void*)&(p->sx);s=0;}} + if(!str_comp("sy",name)){{rval=(void*)&(p->sy);s=0;}} + if(!str_comp("sz",name)){{rval=(void*)&(p->sz);s=0;}} + if(!str_comp("t",name)) {{rval=(void*)&(p->t); s=0;}} + if(!str_comp("p",name)) {{rval=(void*)&(p->p); s=0;}} + {getvar_void} + if (suc!=0x0) {{*suc=s;}} + return rval; + }} + + #pragma acc routine + int particle_setvar_void(_class_particle *, char *, void*);\n + int particle_setvar_void(_class_particle *p, char *name, void* value){{ + #ifndef OPENACC + #define str_comp strcmp + #endif + int rval=1; + if(!str_comp("x",name)) {{memcpy(&(p->x), value, sizeof(double)); rval=0;}} + if(!str_comp("y",name)) {{memcpy(&(p->y), value, sizeof(double)); rval=0;}} + if(!str_comp("z",name)) {{memcpy(&(p->z), value, sizeof(double)); rval=0;}} + if(!str_comp("vx",name)){{memcpy(&(p->vx), value, sizeof(double)); rval=0;}} + if(!str_comp("vy",name)){{memcpy(&(p->vy), value, sizeof(double)); rval=0;}} + if(!str_comp("vz",name)){{memcpy(&(p->vz), value, sizeof(double)); rval=0;}} + if(!str_comp("sx",name)){{memcpy(&(p->sx), value, sizeof(double)); rval=0;}} + if(!str_comp("sy",name)){{memcpy(&(p->sy), value, sizeof(double)); rval=0;}} + if(!str_comp("sz",name)){{memcpy(&(p->sz), value, sizeof(double)); rval=0;}} + if(!str_comp("p",name)) {{memcpy(&(p->p), value, sizeof(double)); rval=0;}} + if(!str_comp("t",name)) {{memcpy(&(p->t), value, sizeof(double)); rval=0;}} + {setvar_void} + return rval; + }} + + #pragma acc routine + int particle_setvar_void_array(_class_particle *, char *, void*, int); + + int particle_setvar_void_array(_class_particle *p, char *name, void* value, int elements){{ + #ifndef OPENACC + #define str_comp strcmp + #endif + int rval=1; + {setvar_void_array} + return rval; + }} + + // Function to handle a particle restore of physical particle params + #pragma acc routine + void particle_restore(_class_particle *p, _class_particle *p0); + void particle_restore(_class_particle *p, _class_particle *p0) {{ + p->x = p0->x; p->y = p0->y; p->z = p0->z; + p->vx = p0->vx; p->vy = p0->vy; p->vz = p0->vz; + p->sx = p0->sx; p->sy = p0->sy; p->sz = p0->sz; + p->t = p0->t; p->p = p0->p; + p->_absorbed=0; p->_restore=0; + }} + + #pragma acc routine + double particle_getuservar_byid(_class_particle *p, int id, int *suc){{ + int s=1; + double rval=0; + switch(id){{ + {getuservar_byid} + }} + if (suc!=0x0) {{*suc=s;}} + return rval; + }} + + #pragma acc routine + void particle_uservar_init(_class_particle *p){{ + {uservar_init} + }} + """) return contents @@ -285,22 +296,22 @@ def source_file_contents(): return message main_file_string = 'int main(int argc, char *argv[]){return mccode_main(argc, argv);}' - contents = f""" -/* ***************************************************************************** - * Start of instrument '{source.name}' generated code -***************************************************************************** */ - -#ifdef MC_TRACE_ENABLED -int traceenabled = 1; -#else -int traceenabled = 0; -#endif -#define {runtime.get("name", "none").upper()} "{escape_str_for_c(str(include_path))}" -int defaultmain = {1 if config.get("default_main") else 0}; -char instrument_name[] = "{source.name}"; -char instrument_source[] = "{escape_str_for_c(source.source)}"; -char *instrument_exe = NULL; /* will be set to argv[0] in main */ -char instrument_code[] = "{source_file_contents()}"; -{main_file_string if config.get('default_main') else ''} -""" + contents = dedent(f""" + /* ***************************************************************************** + * Start of instrument '{source.name}' generated code + ***************************************************************************** */ + + #ifdef MC_TRACE_ENABLED + int traceenabled = 1; + #else + int traceenabled = 0; + #endif + #define {runtime.get("name", "none").upper()} "{escape_str_for_c(str(include_path))}" + int defaultmain = {1 if config.get("default_main") else 0}; + char instrument_name[] = "{source.name}"; + char instrument_source[] = "{escape_str_for_c(source.source)}"; + char *instrument_exe = NULL; /* will be set to argv[0] in main */ + char instrument_code[] = "{source_file_contents()}"; + {main_file_string if config.get('default_main') else ''} + """) return contents diff --git a/src/mccode_antlr/translators/c_initialise.py b/src/mccode_antlr/translators/c_initialise.py index 170a357..9393403 100644 --- a/src/mccode_antlr/translators/c_initialise.py +++ b/src/mccode_antlr/translators/c_initialise.py @@ -171,8 +171,16 @@ def parameter_line(default): for c_dec in component_declared_parameters[comp.type.name]: # c_dec is a CDeclare named tuple with (name, type, init, is_pointer, is_array, orig) initialized_value = 'NULL' if (c_dec.is_pointer and c_dec.init is None) else c_dec.init - if initialized_value is not None: - lines.append(f' _{comp.name}_var._parameters.{c_dec.name} = {initialized_value};') + fullname = f'_{comp.name}_var._parameters.{c_dec.name}' + if c_dec.is_array and initialized_value is not None: + for i, val in enumerate(initialized_value.strip('{}').split(',')): + lines.append(f' {fullname}[{i}] = {val.strip()};') + # elif c_dec.is_struct and initialized_value is not None: + # # deal with a structured initializer ... + # # = {.x=#, .y=#, ...} + # pass + elif initialized_value is not None: + lines.append(f' {fullname} = {initialized_value};') # >>> End of second call to `cogen_comp_init_par` # position/rotation diff --git a/src/mccode_antlr/translators/c_listener.py b/src/mccode_antlr/translators/c_listener.py index 9d9afdb..fee8616 100644 --- a/src/mccode_antlr/translators/c_listener.py +++ b/src/mccode_antlr/translators/c_listener.py @@ -1,3 +1,7 @@ +from __future__ import annotations +from dataclasses import dataclass, field +from typing import TypeVar + from loguru import logger from ..grammar import CParser, McInstrParser, CVisitor from ..instr import InstrVisitor @@ -40,53 +44,142 @@ def syntaxError(self, recognizer, offendingSymbol, *args, **kwargs): return ErrorListener() +TCDeclarator = TypeVar('TCDeclarator', bound='CDeclarator') +TCFuncPointer = TypeVar('TCFuncPointer', bound='CFuncPointer') -class CDeclarator: - def __init__(self, pointer, declare, extensions): - self.pointer = pointer - self.declare = declare - self.extensions = extensions - self.dtype = None +def not_both_None_or_not_equal(a, b): + if (a is None and b is not None) or (a is not None and b is None): + return False + return a is not None and b is not None and a != b + +def same_type(a, b): + return type(a) == type(b) + +@dataclass +class CFuncPointer: + declare: TCDeclarator + modifiers: str | None = None + args: str | None = None + + @property + def name(self) -> str: + return self.declare.name + + def copy(self, suffix: str | None = None) -> TCFuncPointer: + return CFuncPointer( + declare=self.declare.copy(suffix=suffix), + modifiers=self.modifiers, + args=self.args, + ) + + def __eq__(self, other: TCFuncPointer) -> bool: + if not isinstance(other, CFuncPointer): + raise ValueError('Type mismatch') + if not_both_None_or_not_equal(self.args, other.args): + return False + if not_both_None_or_not_equal(self.modifiers, other.modifiers): + return False + return self.declare == other.declare + + def string(self, dec_str): + f = f'({self.modifiers} {dec_str})' if self.modifiers else f'({dec_str})' + return f'{f}({self.args})' if self.args else f'{f}()' def __str__(self): + return self.string(str(self.declare)) + + def __hash__(self): + return hash(str(self)) + + def as_struct_member(self, max_array_length): + dec = self.declare.as_struct_member(max_array_length=max_array_length) + return self.string(dec) + + +@dataclass +class CDeclarator: + declare: str | CFuncPointer + pointer: str | None = None + extensions: list[str] = field(default_factory=list) + elements: int | str | None = None + dtype: str | None = None + init: str | None = None + + @property + def is_pointer(self) -> bool: + return self.pointer is not None and len(self.pointer.strip(' ')) > 0 + + @property + def is_array(self) -> bool: + if self.elements is not None: + return True + # jump through the CFuncPointer to its CDeclarator + return isinstance(self.declare, CFuncPointer) and self.declare.declare.is_array + + @property + def name(self) -> str: + return self.declare if isinstance(self.declare, str) else self.declare.name + + def copy(self, suffix: str | None = None) -> TCDeclarator: + dec = self.declare + if suffix is not None and isinstance(self.declare, str): + dec = f'{self.declare}_{suffix}' + elif suffix is not None: + dec = self.declare.copy(suffix=suffix) + return CDeclarator( + pointer=self.pointer, + declare=dec, + extensions=[x for x in self.extensions], + elements=self.elements, + dtype=self.dtype, + init=self.init, + ) + + def __eq__(self, other: TCDeclarator): + if not isinstance(other, CDeclarator): + raise ValueError('Type mismatch') + if len(self.extensions) != len(other.extensions): + return False + for a, b in zip(self.extensions, other.extensions): + if a != b: + return False + for a, b in zip((self.pointer, self.elements, self.dtype, self.init), + (other.pointer, other.elements, other.dtype, other.init)): + if not_both_None_or_not_equal(a, b): + return False + return same_type(self.declare, other.declare) and self.declare == other.declare + + def string(self, dec_str): ext = " ".join(f'{x}' for x in self.extensions) - dec = f'{self.declare} {ext}' if len(ext) else f'{self.declare}' + dec = f'{dec_str} {ext}' if len(ext) else f'{dec_str}' if self.pointer: dec = f'{self.pointer} {dec}' if self.dtype: dec = f'{self.dtype} {dec}' return dec - def variable_key(self): - ext = " ".join(f'{x}' for x in self.extensions) - dec = f'{self.declare} {ext}' if len(ext) else f'{self.declare}' - if self.pointer: - dec = f'{self.pointer} {dec}' - return dec - - def __hash__(self): - return hash(str(self)) - - -class CFuncPointer: - def __init__(self, declare: CDeclarator, modifiers): - self.mods = modifiers - self.declare = declare - self.args = None - def __str__(self): - f = f'({self.mods} {self.declare})' if self.mods else f'({self.declare})' - return f'{f}({self.args})' if self.args else f'{f}()' + return self.string(str(self.declare)) def __hash__(self): return hash(str(self)) + def as_struct_member(self, max_array_length: int = 16384): + if self.init: + max_array_length = min(len(self.init.split(',')), max_array_length) + no = self.elements if self.elements else max_array_length + if isinstance(self.declare, CFuncPointer): + return self.string(self.declare.as_struct_member(max_array_length=no)) + elif self.elements is not None: + return f'{self}[{no}]' + return str(self) + class DeclaresCVisitor(CVisitor): def __init__(self, typedefs: list | None = None, verbose: bool = False): self.verbose = verbose self.typedefs = [x for x in typedefs] if typedefs else [] - self.declares = {} + self.declares = [] def debug(self, message): if self.verbose: @@ -121,11 +214,12 @@ def visitDeclaration(self, ctx:CParser.DeclarationContext): elif len(inits): for decl, init in inits: decl.dtype = ' '.join(specs) - self.declares[decl] = init + decl.init = init + self.declares.append(decl) elif len(specs) > 1: decl = CDeclarator(pointer=None, declare=specs[-1], extensions=[]) decl.dtype = ' '.join(specs[:-1]) - self.declares[decl] = None + self.declares.append(decl) # Five declaration specifiers def visitStorageClassSpecifier(self, ctx:CParser.StorageClassSpecifierContext): @@ -194,7 +288,22 @@ def visitDeclarator(self, ctx:CParser.DeclaratorContext): ptr = self.visit(ctx.pointer()) if ctx.pointer() else None dec = self.visit(ctx.directDeclarator()) extensions = [self.visit(x) for x in ctx.gccDeclaratorExtension()] - return CDeclarator(pointer=ptr, declare=dec, extensions=extensions) + elements = None + if isinstance(dec, CFuncPointer): + # elements = dec.declare.elements + # dec.declare.elements = None + pass + elif all(x in dec for x in ('[', ']')): + if dec.count('[') > 1 or dec.count(']') > 1: + raise RuntimeError('No idea how to handle multi-level arrays') + dec, num_post = dec.split('[', 1) + num, _ = num_post.split(']', 1) + try: + elements = int(num) if len(num) else 0 + except ValueError as er: + logger.info(f"Could not convert an integer from {num} due to {er}") + elements = num + return CDeclarator(pointer=ptr, declare=dec, extensions=extensions, elements=elements) def visitPointer(self, ctx:CParser.PointerContext): self.debug(f'pointer {literal_string(ctx)}') @@ -217,7 +326,9 @@ def visitDirectDeclarator(self, ctx:CParser.DirectDeclaratorContext): return dec + after_dd_str -def extract_c_declared_variables_and_defined_types(block: str, user_types: list = None, verbose=False): +def extract_c_declared_variables_and_defined_types( + block: str, user_types: list = None, verbose=False +) -> tuple[list[CDeclarator], list[str]]: from antlr4 import InputStream, CommonTokenStream from antlr4.error.ErrorListener import ErrorListener from ..grammar import CLexer @@ -229,17 +340,25 @@ def extract_c_declared_variables_and_defined_types(block: str, user_types: list tree = parser.compilationUnit() visitor = DeclaresCVisitor(user_types, verbose=verbose) visitor.visitCompilationUnit(tree) - # Consider _using_ the CDeclarator class instead of this conversion? - variables = {dec.variable_key(): (dec.dtype, init) for dec, init in visitor.declares.items()} - return variables, visitor.typedefs + return visitor.declares, visitor.typedefs -def extract_c_declared_variables(block: str, user_types: list = None, verbose=False): - variables, types = extract_c_declared_variables_and_defined_types(block, user_types, verbose=verbose) +def extract_c_declared_variables( + block: str, user_types: list = None, verbose=False +) -> list[CDeclarator]: + variables, _ = extract_c_declared_variables_and_defined_types(block, user_types, verbose=verbose) return variables +def extract_c_defined_types( + block: str, user_types: list = None, verbose=False +) -> list[str]: + _, types = extract_c_declared_variables_and_defined_types(block, user_types, verbose=verbose) + return types + -def extract_c_defined_then_declared_variables(defined_in_block: str, declared_in_block): +def extract_c_defined_then_declared_variables( + defined_in_block: str, declared_in_block +) -> list[CDeclarator]: _, defined_in_types = extract_c_declared_variables_and_defined_types(defined_in_block) return extract_c_declared_variables(declared_in_block, user_types=defined_in_types) @@ -271,8 +390,12 @@ def visitExpressionIdentifier(self, ctx: McInstrParser.ExpressionIdentifierConte return Expr.id(name) -def evaluate_c_defined_variables(variables: dict[str, str], initialized_in: str, known: dict[str, Expr] = None, - verbose=False): +def evaluate_c_defined_variables( + variables: dict[str, str], + initialized_in: str, + known: dict[str, Expr] = None, + verbose=False +): """Evaluate individual statements from C-like source in an attempt to find values for the provided variables""" from antlr4 import InputStream from ..grammar import McInstr_parse, McInstr_ErrorListener @@ -301,12 +424,16 @@ def _get_expr(type_name: str, initial_value: str) -> Expr: return expr -def extract_c_declared_expressions(block: str, user_types: list = None, verbose=False) -> dict[str, Expr]: +def extract_c_declared_expressions( + block: str, user_types: list = None, verbose=False +) -> dict[CDeclarator, Expr]: variables = extract_c_declared_variables(block, user_types, verbose=verbose) - return {name: _get_expr(dt, val) for name, (dt, val) in variables.items()} + return {d: _get_expr(d.dtype, d.init) for d in variables} -def evaluate_c_defined_expressions(variables: dict[str, Expr], initialized_in: str, verbose=False) -> dict[str, Expr]: +def evaluate_c_defined_expressions( + variables: dict[str, Expr], initialized_in: str, verbose=False +) -> dict[str, Expr]: """For defined identifiers, evaluate a (simple) block of C code to determine the end values of the identifiers""" names_types = {name: expr.data_type.name for name, expr in variables.items()} return evaluate_c_defined_variables(names_types, initialized_in, known=variables, verbose=verbose) diff --git a/src/mccode_antlr/translators/c_trace.py b/src/mccode_antlr/translators/c_trace.py index 434ddc1..fdcb66f 100644 --- a/src/mccode_antlr/translators/c_trace.py +++ b/src/mccode_antlr/translators/c_trace.py @@ -1,4 +1,8 @@ -def _runtime_parameters(is_mcstas): +from .c_listener import CDeclarator +from ..comp import Comp +from ..instr import Instr + +def _runtime_parameters(is_mcstas: bool): pars = ['x', 'y', 'z'] if is_mcstas: pars.extend(['vx', 'vy', 'vz', 't', 'sx', 'sy', 'sz', 'p', 'mcgravitation', 'mcMagnet', 'allow_backprop']) @@ -9,14 +13,14 @@ def _runtime_parameters(is_mcstas): return pars -def _runtime_kv_parameters(is_mcstas): +def _runtime_kv_parameters(is_mcstas: bool): pars = ['p', 't'] pars.extend(['vx', 'vy', 'vz'] if is_mcstas else ['kx', 'ky', 'kz']) pars.extend(['x', 'y', 'z']) return pars -def def_trace_section(is_mcstas): +def def_trace_section(is_mcstas: bool): lines = [ "/*******************************************************************************", "* components TRACE", @@ -49,7 +53,7 @@ def def_trace_section(is_mcstas): return '\n'.join(lines) -def undef_trace_section(is_mcstas): +def undef_trace_section(is_mcstas: bool): lines = [f'#undef {x}' for x in _runtime_parameters(is_mcstas)] lines.extend([ "#ifdef OPENACC", @@ -73,14 +77,33 @@ def undef_trace_section(is_mcstas): return '\n'.join(lines) -def cogen_trace_section(is_mcstas, source, declared_parameters, instrument_uservars, component_uservars): +def cogen_trace_section( + is_mcstas: bool, + source: Instr, + declared_parameters: dict[str, list[CDeclarator]], + instrument_uservars: list[CDeclarator], + component_uservars: dict[str, list[CDeclarator]], +) -> str: return '\n'.join([ - cogen_comp_trace_class(is_mcstas, c, source, declared_parameters[c.name], - instrument_uservars, component_uservars[c.name]) for c in source.component_types() + cogen_comp_trace_class( + is_mcstas, + component_type, + source, + declared_parameters[component_type.name], + instrument_uservars, + component_uservars[component_type.name] + ) for component_type in source.component_types() ]) -def cogen_comp_trace_class(is_mcstas, comp, source, declared_parameters, instr_uservars, comp_uservars): +def cogen_comp_trace_class( + is_mcstas: bool, + comp: Comp, + source: Instr, + declared_parameters: list[CDeclarator], + instr_uservars: list[CDeclarator], + comp_uservars: list[CDeclarator], +) -> str: from .c_defines import cogen_parameter_define, cogen_parameter_undef # count matching component type instances which define an EXTEND block: extended = [(n, i) for n, i in enumerate(source.components) if i.type.name == comp.name and len(i.extend)] @@ -103,7 +126,7 @@ def cogen_comp_trace_class(is_mcstas, comp, source, declared_parameters, instr_u # Check if there are any user-defined parameter types ... (something which wasn't set previously?) # This is the 'symbol' type - declared_types = [x.type for x in declared_parameters] + declared_types = [x.dtype for x in declared_parameters] # there must be a better way than this is_symbol = [t == 'symbol' for t in declared_types] # TODO FIXME This should be looping through setting parameters. It is probably wrong. @@ -116,8 +139,8 @@ def cogen_comp_trace_class(is_mcstas, comp, source, declared_parameters, instr_u # loop through the symbol types and set the component-instance values from ... somewhere for c_dec in declared_parameters: # Use the user-defined instance parameter if it exists, or attempt to use a default otherwise? - inst_param = [p for p in inst.parameters if p.name == c_dec.name] - v = inst_param[0].value if len(inst_param) else c_dec.init + v = next(iter(p for p in inst.parameters if p.name == c_dec.name), None) + v = v.value if v else c_dec.init if v is not None: lines.append(f' {c_dec.name} = {v};') lines.append(' }') @@ -143,7 +166,6 @@ def long(x): if len(extended): # combine the USERVARS from the instrument and this component type blocks: - # uvs = set().union(instr_uservars).union(comp_uservars) uvs = list(dict.fromkeys([*instr_uservars, *comp_uservars])) # So that the EXTEND block(s) can access them lines.extend([f' #define {x.name} (_particle->{x.name})' for x in uvs]) diff --git a/tests/instr/test_instance_parameters.py b/tests/instr/test_instance_parameters.py index 7b220a5..172d477 100644 --- a/tests/instr/test_instance_parameters.py +++ b/tests/instr/test_instance_parameters.py @@ -144,9 +144,13 @@ def declare_line(name, vec): # But we can attempt to parse the declarations and instantiation blocks from the instrument from mccode_antlr.translators.c_listener import extract_c_declared_expressions, evaluate_c_defined_expressions + # Here variables _is_ and _must be_ dict[str, Expr] variables = {x.name: x.value for x in assembler.instrument.parameters} for dec in assembler.instrument.declare: - variables.update(extract_c_declared_expressions(dec.source)) + # This produces dict[CDeclarator, Expr] + decs_expr = extract_c_declared_expressions(dec.source) + # So extract just the CDeclarator's name + variables.update({key.name: val for key, val in decs_expr.items()}) # defined as # TODO this does not work because the simple "C"-style expression parser doesn't know about pointers diff --git a/tests/runtime/test_examples.py b/tests/runtime/test_examples.py index b93b3d0..4c9221d 100644 --- a/tests/runtime/test_examples.py +++ b/tests/runtime/test_examples.py @@ -113,86 +113,65 @@ def test_function_pointer_declare_parameter(): @compiled def test_function_pointer_component_declare_parameter(): - """This doesn't work yet because the declared-parameter handling doesn't - know about function pointers, and inserts in the component struct - ``` - typedef struct { - ... - int (fun_ptr)(int, int); - int (fun_ptr_arr[])(int, int)[3]; - } _parameters; - ``` - and tries accessing them as, e.g., `_parameters.(fun_ptr)(int, int)` + from mccode_antlr.reader.registry import InMemoryRegistry + in_memory_registry = InMemoryRegistry('test_components') + comp_name = 'declares_function_pointer' + in_memory_registry.add_comp(comp_name, dedent(rf""" + DEFINE COMPONENT {comp_name} DEFINITION PARAMETERS () + SETTING PARAMETERS (int selector=0, int A=0, int B=0) + OUTPUT PARAMETERS () + SHARE %{{ + int add(int a, int b){{return a + b;}} + int sub(int a, int b){{return a - b;}} + int mul(int a, int b){{return a * b;}} + %}} + DECLARE %{{ + int (*fun_ptr)(int, int); + int (*fun_ptr_arr[])(int, int) = {{add, sub, mul}}; + %}} + INITIALIZE %{{ + switch (selector) {{ + case 0: fun_ptr = add; break; + case 1: fun_ptr = sub; break; + case 2: fun_ptr = mul; break; + default: + printf("%s: Invalid selector=%d, valid settings are {{0, 1, 2}}\n", NAME_CURRENT_COMP, selector); + exit(-1); + }} + %}} + TRACE %{{ + printf("%s=%d\n", NAME_CURRENT_COMP, fun_ptr(A, B)); + printf("%s=%d\n", NAME_CURRENT_COMP, fun_ptr_arr[selector](A, B)); + %}} + SAVE %{{ + %}} + FINALLY %{{ + %}} + END + """)) - When it should enter - ``` - typedef struct { - ... - int (* fun_ptr)(int, int); - int (* fun_ptr_arr[3])(int, int); - } _parameters; - ``` - and access them as `_parameters.fun_ptr` and `_parameters.fun_ptr_arr`. - """ - from loguru import logger - logger.critical('Update C defined-parameter handling') - # from mccode_antlr.reader.registry import InMemoryRegistry - # in_memory_registry = InMemoryRegistry('test_components') - # comp_name = 'declares_function_pointer' - # in_memory_registry.add_comp(comp_name, dedent(rf""" - # DEFINE COMPONENT {comp_name} DEFINITION PARAMETERS () - # SETTING PARAMETERS (int selector=0, int A=0, int B=0) - # OUTPUT PARAMETERS () - # SHARE %{{ - # int add(int a, int b){{return a + b;}} - # int sub(int a, int b){{return a - b;}} - # int mul(int a, int b){{return a * b;}} - # %}} - # DECLARE %{{ - # int (*fun_ptr)(int, int); - # int (*fun_ptr_arr[])(int, int) = {{add, sub, mul}}; - # %}} - # INITIALIZE %{{ - # switch (selector) {{ - # case 0: fun_ptr = add; break; - # case 1: fun_ptr = sub; break; - # case 2: fun_ptr = mul; break; - # default: - # printf("%s: Invalid selector=%d, valid settings are {{0, 1, 2}}\n", NAME_CURRENT_COMP, selector); - # exit(-1); - # }} - # %}} - # TRACE %{{ - # printf("%s=%d\n", NAME_CURRENT_COMP, fun_ptr(A, B)); - # printf("%s=%d\n", NAME_CURRENT_COMP, fun_ptr_arr[selector](A, B)); - # %}} - # SAVE %{{ - # %}} - # FINALLY %{{ - # %}} - # END - # """)) - # print(in_memory_registry.components) - # - # instr = parse_mcstas_instr(dedent(rf""" - # DEFINE INSTRUMENT with_component_function_pointers(int which=0) - # DECLARE %{{ %}} - # INITIALIZE %{{ %}} - # TRACE - # COMPONENT first = {comp_name}(selector=which, A=2, B=1) AT (0, 0, 0) ABSOLUTE - # COMPONENT second = {comp_name}(selector=which, A=1, B=3) AT (0, 0, 0) ABSOLUTE - # FINALLY %{{ %}} - # END - # """), registries=[in_memory_registry]) - # for which in (0, 1, 2): - # results, files = compile_and_run(instr, f'-n 1 which={which}') - # lines = results.decode('utf-8').splitlines() - # a, b = 0, 0 - # if which == 0: - # a, b = 2 + 1, 1 + 3 - # elif which == 1: - # a, b = 2 - 1, 1 - 3 - # elif which == 2: - # a, b = 2 * 1, 1 * 3 - # assert lines[0] == f"first={a}" - # assert lines[1] == f"second={b}" \ No newline at end of file + instr = parse_mcstas_instr(dedent(rf""" + DEFINE INSTRUMENT with_component_function_pointers(int which=0) + DECLARE %{{ %}} + INITIALIZE %{{ %}} + TRACE + COMPONENT first = {comp_name}(selector=which, A=2, B=1) AT (0, 0, 0) ABSOLUTE + COMPONENT second = {comp_name}(selector=which, A=1, B=3) AT (0, 0, 0) ABSOLUTE + FINALLY %{{ %}} + END + """), registries=[in_memory_registry]) + for which in (0, 1, 2): + results, files = compile_and_run(instr, f'-n 1 which={which}') + lines = results.decode('utf-8').splitlines() + print(lines) + a, b = 0, 0 + if which == 0: + a, b = 2 + 1, 1 + 3 + elif which == 1: + a, b = 2 - 1, 1 - 3 + elif which == 2: + a, b = 2 * 1, 1 * 3 + assert lines[0] == f"first={a}" + assert lines[1] == f"first={a}" + assert lines[2] == f"second={b}" + assert lines[3] == f"second={b}" \ No newline at end of file diff --git a/tests/test_c_type_declaration.py b/tests/test_c_type_declaration.py index 8dd8f75..4e2962f 100644 --- a/tests/test_c_type_declaration.py +++ b/tests/test_c_type_declaration.py @@ -1,5 +1,6 @@ from mccode_antlr.translators.c_listener import ( - extract_c_declared_variables_and_defined_types as extract + extract_c_declared_variables_and_defined_types as extract, + CDeclarator, CFuncPointer ) from textwrap import dedent @@ -80,12 +81,14 @@ def test_assignments(): variables, types = extract(block) assert len(variables) == 3 assert len(types) == 0 - assert 'blah' in variables - assert variables['blah'] == ('int', '1') - assert 'yarg' in variables - assert variables['yarg'] == ('double', None) - assert 'mmmm[11]' in variables - assert variables['mmmm[11]'] == ('char', '"0123456789"') + expected = [ + CDeclarator(dtype='int', declare='blah', init='1'), + CDeclarator(dtype='double', declare='yarg'), + CDeclarator(dtype='char', declare='mmmm', init='"0123456789"', elements=11), + ] + for x in expected: + assert x in variables + def test_struct_declaration(): block = dedent("""\ @@ -96,14 +99,12 @@ def test_struct_declaration(): variables, types = extract(block) assert len(variables) == 3 assert len(types) == 0 - expected = { - 'the_struct': ('struct my_struct_type', None), - 'a_struct_with_values': ('struct another_struct', '{0, 1.0, "two"}'), - '* ptr_to_third_struct': ('struct the_third_s', None), - } + expected = [ + CDeclarator(dtype='struct my_struct_type', declare='the_struct'), + CDeclarator(dtype='struct another_struct', declare='a_struct_with_values', init='{0, 1.0, "two"}'), + CDeclarator(dtype='struct the_third_s', declare='ptr_to_third_struct', pointer='*') + ] assert all(x in variables for x in expected) - for name, (dtype, value) in expected.items(): - assert(variables[name] == (dtype, value)) def test_typedef_declaration(): @@ -117,14 +118,12 @@ def test_typedef_declaration(): assert len(variables) == 3 assert len(types) == 1 assert types[0] == "blah" - expected = { - 'really_a_double': ('blah', '1.0f'), - '* double_ptr': ('blah', 'NULL'), - 'double_array[10]': ('blah', None), - } + expected = [ + CDeclarator(dtype='blah', declare='really_a_double', init='1.0f'), + CDeclarator(dtype='blah', declare='double_ptr', pointer='*', init='NULL'), + CDeclarator(dtype='blah', declare='double_array', elements=10) + ] assert all(x in variables for x in expected) - for name, (dtype, value) in expected.items(): - assert(variables[name] == (dtype, value)) def test_flatellipse_finite_mirror(): @@ -143,18 +142,20 @@ def test_flatellipse_finite_mirror(): variables, types = extract(block) assert len(types) == 0 assert len(variables) == 6 - expected = { - 's': ('Scene', None), - 'p1': ('Point', None), - 'traceNeutronConicWithTables(_class_particle* p, ConicSurf c)': ('void', None), - '* rfront_inner': ('double', None), - 'silicon': ('int', None), - 'rsTable': ('t_Table', None) - } - assert 'c' not in variables - for name, (dtype, value) in expected.items(): - assert name in variables, f"{name} not in {list(variables.items())}" - assert variables[name] == (dtype, value) + expected = [ + CDeclarator(dtype='Scene', declare='s'), + CDeclarator(dtype='Point', declare='p1'), + # the function pre-declaration _IS NOT_ a function pointer (and should be in DECLARE) + CDeclarator( + dtype='void', + declare='traceNeutronConicWithTables(_class_particle* p, ConicSurf c)', + ), + CDeclarator(dtype='double', pointer='*', declare='rfront_inner'), + CDeclarator(dtype='int', declare='silicon'), + CDeclarator(dtype='t_Table', declare='rsTable') + ] + for x, y in zip(expected, variables): + assert x == y def test_function_pointer_declaration(): @@ -166,13 +167,36 @@ def test_function_pointer_declaration(): variables, types = extract(block) assert len(types) == 0 assert len(variables) == 3 - print(variables) - - expected = { - '(* fun_ptr)(int, int)': ('int', None), - '(* fun_ptr_ar3[3])(int, int)': ('int', None), - '(* fun_ptr_arr[])(int, int)': ('int', '{add, sub, mul}'), - } - for name, (dtype, value) in expected.items(): - assert name in variables - assert variables[name] == (dtype, value) \ No newline at end of file + expected = [ + CDeclarator( + dtype='int', + declare=CFuncPointer( + declare=CDeclarator(pointer='*', declare='fun_ptr'), + args='int, int', + ), + ), + CDeclarator( + dtype='int', + declare=CFuncPointer( + declare=CDeclarator(pointer='*', declare='fun_ptr_ar3', elements=3), + args='int, int', + ), + ), + CDeclarator( + dtype='int', + declare=CFuncPointer( + declare=CDeclarator(pointer='*', declare='fun_ptr_arr', elements=0), + args='int, int', + ), + init='{add, sub, mul}', + ), + ] + for x, y in zip(expected, variables): + assert x == y + members = [ + 'int (* fun_ptr)(int, int)', + 'int (* fun_ptr_ar3[3])(int, int)', + 'int (* fun_ptr_arr[3])(int, int)', + ] + for x, y in zip(members, variables): + assert x == y.as_struct_member() \ No newline at end of file diff --git a/tests/test_extract.py b/tests/test_extract.py index bfadaf3..1ba61f0 100644 --- a/tests/test_extract.py +++ b/tests/test_extract.py @@ -36,7 +36,8 @@ def test_runtime_parameter(self): # parse the declare block(s) to find parameter declarations variables = {} for dec in instr.declare: - variables.update({k: t for k, (t, _) in extract_c_declared_variables(dec.source).items()}) + decs = extract_c_declared_variables(dec.source) + variables.update({d.name: d.dtype for d in decs}) # Then parse and evaluate initialize to set their value(s). values = {} for init in instr.initialize: