Skip to content

Commit

Permalink
Maybe finished number set shim script
Browse files Browse the repository at this point in the history
Signed-off-by: Andrew Helwer <[email protected]>
  • Loading branch information
ahelwer committed Mar 27, 2024
1 parent 3546ce8 commit 42715a0
Showing 1 changed file with 45 additions and 8 deletions.
53 changes: 45 additions & 8 deletions .github/scripts/unicode_number_set_shim.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,44 @@ class NumberSetShim:
NumberSetShim('Real', 'ℝ', 'real')
]

def build_query(language):
def build_number_set_query(language):
"""
Builds number set shim query.
Builds query looking for use of number sets.
"""
return language.query(' '.join([f'({shim.capture}_number_set "{shim.unicode}") @{shim.capture}' for shim in shims]))

def get_required_defs(examples_root, parser, path, query):
tree, _, _ = tla_utils.parse_module(examples_root, parser, path)
return set([name for _, name in query.features.captures(tree.root_node)])
def build_insertion_point_query(language):
"""
Builds query to get insertion point for shim definitions.
"""
return language.query('((header_line) name: (identifier) (header_line)) @header (extends) @extends')

def get_required_defs(tree, query):
return set([name for _, name in query.captures(tree.root_node)])

def get_def_text(required_defs):
defs = '\n'
for shim in shims:
if shim.capture in required_defs:
defs += f'LOCAL {shim.unicode}{shim.ascii}\n'
return defs

def get_insertion_point(tree, query):
captures = query.captures(tree.root_node)
has_extends = any(name for (_, name) in captures if name == 'extends')
if has_extends:
extends_node = next(node for (node, name) in captures if name == 'extends')
return extends_node.byte_range[1]
else:
header = next(node for (node, name) in captures if name == 'header')
return header.byte_range[1]

def insert_defs(module_path, insertion_point, defs):
def_bytes = bytes(defs, 'utf-8')
with open(module_path, 'rb+', encoding='utf-8') as module:
module_bytes = bytearray(module.read())
module_bytes[insertion_point:insertion_point] = def_bytes
module.write(module_bytes)

if __name__ == '__main__':
parser = ArgumentParser(description='Adds ℕ/ℤ/ℝ Unicode number set shim definitions to modules as needed.')
Expand All @@ -53,7 +82,8 @@ def get_required_defs(examples_root, parser, path, query):
only_modules = [normpath(path) for path in args.only]

(TLAPLUS_LANGUAGE, parser) = tla_utils.build_ts_grammar(normpath(args.ts_path))
query = build_query(TLAPLUS_LANGUAGE)
number_set_query = build_number_set_query(TLAPLUS_LANGUAGE)
insertion_point_query = build_insertion_point_query(TLAPLUS_LANGUAGE)

modules = [
module['path']
Expand All @@ -65,6 +95,13 @@ def get_required_defs(examples_root, parser, path, query):

for module_path in modules:
logging.info(f'Processing {module_path}')
required_defs = get_required_defs(examples_root, parser, module_path, query)
logging.info(f'Require {required_defs}')
tree, _, _ = tla_utils.parse_module(examples_root, parser, module_path)
required_defs = get_required_defs(tree, number_set_query)
if not any(required_defs):
logging.info('No shim insertion necessary')
continue
logging.info(f'Inserting defs {required_defs}')
defs = get_def_text(required_defs)
insertion_point = get_insertion_point(tree, insertion_point_query)
insert_defs(module_path, insertion_point, defs)

0 comments on commit 42715a0

Please sign in to comment.