diff --git a/mccode_antlr/common/metadata.py b/mccode_antlr/common/metadata.py index a098a11..81789bd 100644 --- a/mccode_antlr/common/metadata.py +++ b/mccode_antlr/common/metadata.py @@ -55,15 +55,18 @@ def partial_from_tokens(source: DataSource, mimetype: str, name: str, value: str @staticmethod def from_component_tokens(source: str, mimetype: str, name: str, value: str): - return MetaData.partial_from_tokens(DataSource(DataSource.Type.Component, source), name, mimetype, value) + return MetaData.partial_from_tokens(source=DataSource(DataSource.Type.Component, source), + mimetype=mimetype, name=name, value=value) @staticmethod def from_instance_tokens(source: str, mimetype: str, name: str, value: str): - return MetaData.partial_from_tokens(DataSource(DataSource.Type.Instance, source), name, mimetype, value) + return MetaData.partial_from_tokens(source=DataSource(DataSource.Type.Instance, source), + mimetype=mimetype, name=name, value=value) @staticmethod def from_instrument_tokens(source: str, mimetype: str, name: str, value: str): - return MetaData.partial_from_tokens(DataSource(DataSource.Type.Instrument, source), name, mimetype, value) + return MetaData.partial_from_tokens(source=DataSource(DataSource.Type.Instrument, source), + mimetype=mimetype, name=name, value=value) # output to metadata_table_struct initializer list def to_table_row(self): diff --git a/mccode_antlr/comp/visitor.py b/mccode_antlr/comp/visitor.py index 2793eaa..4f040b2 100644 --- a/mccode_antlr/comp/visitor.py +++ b/mccode_antlr/comp/visitor.py @@ -225,7 +225,10 @@ def visitDisplayBlockCopy(self, ctx: Parser.DisplayBlockCopyContext): def visitMetadata(self, ctx: Parser.MetadataContext): filename, line_number, metadata = self.visit(ctx.unparsed_block()) - self.state.add_metadata(MetaData.from_component_tokens(self.state.name, str(ctx.mime), str(ctx.name), metadata)) + mime = ctx.mime.text if ctx.mime.type == Parser.Identifier else ctx.mime.text[1:-1] + name = ctx.name.text if ctx.name.type == Parser.Identifier else ctx.name.text[1:-1] + metadata = MetaData.from_component_tokens(source=self.state.name, mimetype=mime, name=name, value=metadata) + self.state.add_metadata(metadata) def visitUnparsed_block(self, ctx: Parser.Unparsed_blockContext): # We want to extract the source-file line number (and filename) for use in the C-preprocessor diff --git a/mccode_antlr/instr/visitor.py b/mccode_antlr/instr/visitor.py index c10a2a5..10f5f3d 100644 --- a/mccode_antlr/instr/visitor.py +++ b/mccode_antlr/instr/visitor.py @@ -41,7 +41,8 @@ def getInstrument_parameter(self, ctx: McInstrParser.Instrument_parameterContext def visitInstrument_metadata(self, ctx: McInstrParser.Instrument_metadataContext): for metadata_context in ctx.metadata(): mime, name, metadata = self.visit(metadata_context) - self.state.add_metadata(MetaData.from_instrument_tokens(self.state.name, mime, name, metadata)) + metadata = MetaData.from_instrument_tokens(source=self.state.name, mimetype=mime, name=name, value=metadata) + self.state.add_metadata(metadata) def visitInstrumentParameterDouble(self, ctx: McInstrParser.InstrumentParameterDoubleContext): name = str(ctx.Identifier()) @@ -138,7 +139,8 @@ def visitComponent_instance(self, ctx: McInstrParser.Component_instanceContext): # deal with definition vs instance metadata here? for metadata_context in ctx.metadata(): mime, name, metadata = self.visit(metadata_context) - instance.add_metadata(MetaData.from_component_tokens(name, mime, name, metadata)) + metadata = MetaData.from_instance_tokens(source=instance.name, mimetype=mime, name=name, value=metadata) + instance.add_metadata(metadata) # Include this instantiated component instance in the instrument components list if self.destination is None or not instance.removable: # if this _is_ an included instrument, any REMOVABLE component instances should not be added diff --git a/test/test_metadata.py b/test/test_metadata.py new file mode 100644 index 0000000..9de8f23 --- /dev/null +++ b/test/test_metadata.py @@ -0,0 +1,54 @@ +import unittest + + +class MetadataTestCase(unittest.TestCase): + def test_direct_creation(self): + from mccode_antlr.common import MetaData + md = MetaData.from_instance_tokens('instance_source', 'mimetype', 'metadata_name', 'metadata_value') + self.assertEqual(md.source.name, 'instance_source') + self.assertEqual(md.source.type_name, 'Instance') + self.assertEqual(md.name, 'metadata_name') + self.assertEqual(md.mimetype, 'mimetype') + self.assertEqual(md.value, 'metadata_value') + + def test_parsed(self): + from mccode_antlr.loader import parse_mcstas_instr + from json import loads + instr = """ + DEFINE INSTRUMENT splitRunTest(a1=0, a2=0, virtual_source_x=0.05, virtual_source_y=0.1, string newname) + TRACE + COMPONENT origin = Arm() AT (0, 0, 0) ABSOLUTE + COMPONENT point = Arm() AT (0, 0, 0.8) RELATIVE origin + METADATA "application/text" "names with spaces keep their quotes" %{ + This is some unparsed metadata that will be included as a literal string in the instrument. + %} + COMPONENT line = Arm() AT (0, 0, 1) RELATIVE PREVIOUS METADATA "application/json" identifier_name %{ + {"key": "value", "array": [1, 2, 3]} + %} + END + """ + instr = parse_mcstas_instr(instr) + self.assertEqual(len(instr.components), 3) + self.assertEqual(instr.components[1].name, 'point') + self.assertEqual(len(instr.components[1].metadata), 1) + md = instr.components[1].metadata[0] + self.assertEqual(md.source.name, 'point') + self.assertEqual(md.source.type_name, 'Instance') + self.assertEqual(md.name, '"names with spaces keep their quotes"') + self.assertEqual(md.mimetype, 'application/text') + self.assertEqual(md.value, """ + This is some unparsed metadata that will be included as a literal string in the instrument. + """) + self.assertEqual(instr.components[2].name, 'line') + self.assertEqual(len(instr.components[2].metadata), 1) + md = instr.components[2].metadata[0] + self.assertEqual(md.source.name, 'line') + self.assertEqual(md.source.type_name, 'Instance') + self.assertEqual(md.name, 'identifier_name') + self.assertEqual(md.mimetype, 'application/json') + self.assertEqual(loads(md.value), {'key': 'value', 'array': [1, 2, 3]}) + + + +if __name__ == '__main__': + unittest.main()