Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Python] Function Keyword Argument and Function Call through Parameter PR #644

Merged
merged 20 commits into from
Nov 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
5593d05
in progress changes for fixing sidarthe and chime tests
titomeister Oct 26, 2023
6506c86
Fixing sidarthe generation for functions as parameters
titomeister Oct 27, 2023
bd1558e
saving progress on fixing sidarthe
titomeister Oct 30, 2023
ea181ef
Cleaned up and fixed argument function calls
titomeister Nov 1, 2023
50abf51
Working on fixing a missing opi port issue
titomeister Nov 13, 2023
91083b4
Fixed test case, fixed missing OPI issue, ready for PR
titomeister Nov 14, 2023
e428c6b
Miscellaneous cleanup and removal of prints
titomeister Nov 14, 2023
edb6e0f
Fixing test
titomeister Nov 14, 2023
d94dce4
Merge branch 'main' of github.com:ml4ai/skema into ferra/climlab_func…
titomeister Nov 15, 2023
9e4868f
Merge branch 'main' into ferra/climlab_func_def_issue
vincentraymond-ua Nov 15, 2023
373eedd
Fixed an issue with ellipsis literal value
titomeister Nov 15, 2023
422d592
Merge branch 'ferra/climlab_func_def_issue' of github.com:ml4ai/skema…
titomeister Nov 15, 2023
84f55bb
Temporarily disabling import test
titomeister Nov 15, 2023
d62f1ff
Merge branch 'main' into ferra/climlab_func_def_issue
titomeister Nov 15, 2023
fde5bb7
Merge branch 'main' into ferra/climlab_func_def_issue
titomeister Nov 16, 2023
98514bb
Fixed CHIME-penn file ingestion
titomeister Nov 17, 2023
1ece6dc
Merge branch 'main' into ferra/climlab_func_def_issue
myedibleenso Nov 17, 2023
bf29be7
Merge branch 'main' into ferra/climlab_func_def_issue
myedibleenso Nov 17, 2023
ee747c2
Merge branch 'main' into ferra/climlab_func_def_issue
vincentraymond-ua Nov 17, 2023
f258a5b
Merge branch 'main' into ferra/climlab_func_def_issue
myedibleenso Nov 18, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions skema/gromet/execution_engine/types/other.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,3 +129,16 @@ class Range:

def exec(input: int) -> range:
return range(input)


class Call:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We eventually will want to stop using the primitives in skema/gromet/execution_engine/types/ since they are artifacts left over from an outdated version of the execution engine. We have a YAML file which I believe should compile the same information in a more approachable way at skema/gromet/primitive_map.yaml

How large of a change would this require for the Gromet generation?

This is also something we may want to combine with python_builtins.yaml or create a home directory for these type of files.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh okay. I had forgotten about that. I think as long as we can use the same interface all that would be needed would be to update what's going on behind the scenes. That is, I just need to change the interface to pull from the correct primitive map file. I don't expect there to be any reason to explicitly change anything in the Gromet generation code itself.

source_language_name = {"CAST": "_call"}
inputs = [
Field("func_name", "string"),
Field("args","Any",variatic=True)
]
outputs = [Field("call_output", "Any")]
shorthand = "_call"
documentation = ""

# TODO: exec
104 changes: 85 additions & 19 deletions skema/program_analysis/CAST/pythonAST/py_ast_to_cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,11 +121,16 @@ def get_node_name(ast_node):
elif isinstance(ast_node, Attribute):
return [ast_node.attr.name]
elif isinstance(ast_node, Var):
return [ast_node.val.name]
if isinstance(ast_node.val, Call):
return get_node_name(ast_node.val.func)
else:
return [ast_node.val.name]
elif isinstance(ast_node, Assignment):
return get_node_name(ast_node.left)
elif isinstance(ast_node, ast.Subscript):
return get_node_name(ast_node.value)
elif isinstance(ast_node, ast.Call):
return get_node_name(ast_node.func)
elif (
isinstance(ast_node, LiteralValue)
and (ast_node.value_type == StructureType.LIST or ast_node.value_type == StructureType.TUPLE)
Expand Down Expand Up @@ -240,6 +245,8 @@ def __init__(self, file_name: str, legacy: bool = False):
self.dict_comp_count = 0
self.lambda_count = 0

self.curr_func_args = []

def insert_next_id(self, scope_dict: dict, dict_key: str):
"""Given a scope_dictionary and a variable name as a key,
we insert a new key_value pair for the scope dictionary
Expand Down Expand Up @@ -1521,17 +1528,62 @@ def visit_Call(
)
]
else:
return [
Call(
func=Name(
node.func.id,
id=curr_scope_id_dict[unique_name] if unique_name in curr_scope_id_dict else prev_scope_id_dict[unique_name], # NOTE: do this everywhere?
source_refs=ref,
),
arguments=args,
source_refs=ref,
if node.func.id in self.curr_func_args:
unique_name = construct_unique_name(
self.filenames[-1], "_call"
)
]
if unique_name not in prev_scope_id_dict.keys(): # and unique_name not in curr_scope_id_dict.keys():
# If a built-in is called, then it gets added to the global dictionary if
# it hasn't been called before. This is to maintain one consistent ID per built-in
# function
if unique_name not in self.global_identifier_dict.keys():
self.insert_next_id(
self.global_identifier_dict, unique_name
)

prev_scope_id_dict[unique_name] = self.global_identifier_dict[
unique_name
]
unique_name = construct_unique_name(
self.filenames[-1], node.func.id
)
if unique_name not in prev_scope_id_dict.keys(): # and unique_name not in curr_scope_id_dict.keys():
# If a built-in is called, then it gets added to the global dictionary if
# it hasn't been called before. This is to maintain one consistent ID per built-in
# function
if unique_name not in self.global_identifier_dict.keys():
self.insert_next_id(
self.global_identifier_dict, unique_name
)

prev_scope_id_dict[unique_name] = self.global_identifier_dict[
unique_name
]
func_name_arg = Name(name=node.func.id, id=prev_scope_id_dict[unique_name], source_refs=ref)

return [
Call(
func=Name(
"_call",
id=curr_scope_id_dict[unique_name] if unique_name in curr_scope_id_dict else prev_scope_id_dict[unique_name], # NOTE: do this everywhere?
source_refs=ref,
),
arguments=[func_name_arg]+args,
source_refs=ref,
)
]
else:
return [
Call(
func=Name(
node.func.id,
id=curr_scope_id_dict[unique_name] if unique_name in curr_scope_id_dict else prev_scope_id_dict[unique_name], # NOTE: do this everywhere?
source_refs=ref,
),
arguments=args,
source_refs=ref,
)
]

def collect_fields(
self, node: ast.FunctionDef, prev_scope_id_dict, curr_scope_id_dict
Expand Down Expand Up @@ -1864,7 +1916,7 @@ def visit_Constant(
elif node.value is None:
return [LiteralValue(None, None, source_code_data_type, ref)]
elif isinstance(node.value, type(...)):
return []
return [LiteralValue(ScalarType.ELLIPSIS, "...", source_code_data_type, ref)]
else:
raise TypeError(f"Type {str(type(node.value))} not supported")

Expand Down Expand Up @@ -2178,6 +2230,11 @@ def visit_FunctionDef(
# The idea for this is to prevent any weird overwritting issues that may arise from modifying
# dictionaries in place
prev_scope_id_dict_copy = copy.deepcopy(prev_scope_id_dict)


# Need to maintain the previous scope, so copy them over here
prev_func_args = copy.deepcopy(self.curr_func_args)
self.curr_func_args = []

body = []
args = []
Expand All @@ -2195,6 +2252,7 @@ def visit_FunctionDef(
self.insert_next_id(curr_scope_id_dict, f"{arg.arg}")
# self.insert_next_id(curr_scope_id_dict, unique_name)
arg_ref = SourceRef(self.filenames[-1], arg.col_offset, arg.end_col_offset, arg.lineno, arg.end_lineno)
self.curr_func_args.append(arg.arg)
args.append(
Var(
Name(
Expand All @@ -2218,6 +2276,7 @@ def visit_FunctionDef(
prev_scope_id_dict,
curr_scope_id_dict,
)[0]
self.curr_func_args.append(arg.arg)
args.append(
Var(
Name(
Expand Down Expand Up @@ -2254,6 +2313,7 @@ def visit_FunctionDef(
if arg_count == default_val_count:
break
self.insert_next_id(curr_scope_id_dict, arg.arg)
self.curr_func_args.append(arg.arg)
args.append(
Var(
Name(
Expand Down Expand Up @@ -2297,6 +2357,7 @@ def visit_FunctionDef(
curr_scope_id_dict,
)[0]
# self.insert_next_id(curr_scope_id_dict, unique_name)
self.curr_func_args.append(arg.arg)
args.append(
Var(
Name(
Expand Down Expand Up @@ -2335,6 +2396,7 @@ def visit_FunctionDef(
# unique_name = construct_unique_name(self.filenames[-1], arg.arg)
self.insert_next_id(curr_scope_id_dict, arg.arg)
# self.insert_next_id(curr_scope_id_dict, unique_name)
self.curr_func_args.append(arg.arg)
args.append(
Var(
Name(
Expand Down Expand Up @@ -2439,8 +2501,14 @@ def visit_FunctionDef(
for piece in node.body:
if isinstance(piece, ast.Assign):
names = get_node_name(piece)

for var_name in names:

# If something is overwritten in the curr_func_args then we
# remove it here, as it's no longer a function
if var_name in self.curr_func_args:
self.curr_func_args.remove(var_name)

# unique_name = construct_unique_name(
# self.filenames[-1], var_name
# )
Expand All @@ -2451,6 +2519,7 @@ def visit_FunctionDef(
for piece in node.body:

if isinstance(piece, ast.FunctionDef):
self.curr_func_args.append(piece.name)
unique_name = construct_unique_name(self.filenames[-1], piece.name)
self.insert_next_id(curr_scope_id_dict, unique_name)
prev_scope_id_dict[unique_name] = curr_scope_id_dict[unique_name]
Expand Down Expand Up @@ -2489,11 +2558,6 @@ def visit_FunctionDef(
# Merge keys from prev_scope not in cur_scope into cur_scope
# merge_dicts(prev_scope_id_dict, curr_scope_id_dict)

# Visit the deferred functions
#for piece in functions_to_visit:
# to_add = self.visit(piece, curr_scope_id_dict, {})
# body.extend(to_add)

# TODO: Decorators? Returns? Type_comment?
ref = [
SourceRef(
Expand All @@ -2510,6 +2574,9 @@ def visit_FunctionDef(
# TODO: this might need to be different, since Python variables can exist outside of a scope??
prev_scope_id_dict = copy.deepcopy(prev_scope_id_dict_copy)

prev_func_args = copy.deepcopy(self.curr_func_args)
self.curr_func_args = copy.deepcopy(prev_func_args)

# Global level (i.e. module level) functions have their module names appended to them, we make sure
# we have the correct name depending on whether or not we're visiting a global
# level function or a function enclosed within another function
Expand Down Expand Up @@ -4056,7 +4123,6 @@ def visit_Index(
AstNode: Depending on what the value of the Index node is,
different CAST nodes are returned.
"""

return self.visit(node.value, prev_scope_id_dict, curr_scope_id_dict)

@visit.register
Expand Down
Loading
Loading