From 42715a03c7bfb820395b73f968545c91ce1d466c Mon Sep 17 00:00:00 2001 From: Andrew Helwer <2n8rn1w1f@mozmail.com> Date: Wed, 27 Mar 2024 15:49:34 -0400 Subject: [PATCH] Maybe finished number set shim script Signed-off-by: Andrew Helwer <2n8rn1w1f@mozmail.com> --- .github/scripts/unicode_number_set_shim.py | 53 ++++++++++++++++++---- 1 file changed, 45 insertions(+), 8 deletions(-) diff --git a/.github/scripts/unicode_number_set_shim.py b/.github/scripts/unicode_number_set_shim.py index cd1b862c..282c2b32 100644 --- a/.github/scripts/unicode_number_set_shim.py +++ b/.github/scripts/unicode_number_set_shim.py @@ -28,15 +28,44 @@ class NumberSetShim: NumberSetShim('Real', 'ℝ', 'real') ] -def build_query(language): +def build_number_set_query(language): """ - Builds number set shim query. + Builds query looking for use of number sets. """ return language.query(' '.join([f'({shim.capture}_number_set "{shim.unicode}") @{shim.capture}' for shim in shims])) -def get_required_defs(examples_root, parser, path, query): - tree, _, _ = tla_utils.parse_module(examples_root, parser, path) - return set([name for _, name in query.features.captures(tree.root_node)]) +def build_insertion_point_query(language): + """ + Builds query to get insertion point for shim definitions. + """ + return language.query('((header_line) name: (identifier) (header_line)) @header (extends) @extends') + +def get_required_defs(tree, query): + return set([name for _, name in query.captures(tree.root_node)]) + +def get_def_text(required_defs): + defs = '\n' + for shim in shims: + if shim.capture in required_defs: + defs += f'LOCAL {shim.unicode} ≜ {shim.ascii}\n' + return defs + +def get_insertion_point(tree, query): + captures = query.captures(tree.root_node) + has_extends = any(name for (_, name) in captures if name == 'extends') + if has_extends: + extends_node = next(node for (node, name) in captures if name == 'extends') + return extends_node.byte_range[1] + else: + header = next(node for (node, name) in captures if name == 'header') + return header.byte_range[1] + +def insert_defs(module_path, insertion_point, defs): + def_bytes = bytes(defs, 'utf-8') + with open(module_path, 'rb+', encoding='utf-8') as module: + module_bytes = bytearray(module.read()) + module_bytes[insertion_point:insertion_point] = def_bytes + module.write(module_bytes) if __name__ == '__main__': parser = ArgumentParser(description='Adds ℕ/ℤ/ℝ Unicode number set shim definitions to modules as needed.') @@ -53,7 +82,8 @@ def get_required_defs(examples_root, parser, path, query): only_modules = [normpath(path) for path in args.only] (TLAPLUS_LANGUAGE, parser) = tla_utils.build_ts_grammar(normpath(args.ts_path)) - query = build_query(TLAPLUS_LANGUAGE) + number_set_query = build_number_set_query(TLAPLUS_LANGUAGE) + insertion_point_query = build_insertion_point_query(TLAPLUS_LANGUAGE) modules = [ module['path'] @@ -65,6 +95,13 @@ def get_required_defs(examples_root, parser, path, query): for module_path in modules: logging.info(f'Processing {module_path}') - required_defs = get_required_defs(examples_root, parser, module_path, query) - logging.info(f'Require {required_defs}') + tree, _, _ = tla_utils.parse_module(examples_root, parser, module_path) + required_defs = get_required_defs(tree, number_set_query) + if not any(required_defs): + logging.info('No shim insertion necessary') + continue + logging.info(f'Inserting defs {required_defs}') + defs = get_def_text(required_defs) + insertion_point = get_insertion_point(tree, insertion_point_query) + insert_defs(module_path, insertion_point, defs)