Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat[tool]: add all imported modules to -f annotated_ast output #4209

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions tests/unit/ast/test_ast_dict.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import copy
import json

from vyper import compiler
Expand Down Expand Up @@ -216,24 +217,27 @@ def foo():
input_bundle = make_input_bundle({"lib1.vy": lib1, "main.vy": main})

lib1_file = input_bundle.load_file("lib1.vy")
out = compiler.compile_from_file_input(
lib1_out = compiler.compile_from_file_input(
lib1_file, input_bundle=input_bundle, output_formats=["annotated_ast_dict"]
)
lib1_ast = out["annotated_ast_dict"]["ast"]

lib1_ast = copy.deepcopy(lib1_out["annotated_ast_dict"]["ast"])
lib1_sha256sum = lib1_ast.pop("source_sha256sum")
assert lib1_sha256sum == lib1_file.sha256sum
to_strip = NODE_SRC_ATTRIBUTES + ("resolved_path", "variable_reads", "variable_writes")
_strip_source_annotations(lib1_ast, to_strip=to_strip)

main_file = input_bundle.load_file("main.vy")
out = compiler.compile_from_file_input(
main_out = compiler.compile_from_file_input(
main_file, input_bundle=input_bundle, output_formats=["annotated_ast_dict"]
)
main_ast = out["annotated_ast_dict"]["ast"]
main_ast = main_out["annotated_ast_dict"]["ast"]
main_sha256sum = main_ast.pop("source_sha256sum")
assert main_sha256sum == main_file.sha256sum
_strip_source_annotations(main_ast, to_strip=to_strip)

assert main_out["annotated_ast_dict"]["imports"][0] == lib1_out["annotated_ast_dict"]["ast"]

# TODO: would be nice to refactor this into bunch of small test cases
assert main_ast == {
"ast_type": "Module",
Expand Down
6 changes: 6 additions & 0 deletions vyper/ast/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -944,6 +944,12 @@ def validate(self):
class Ellipsis(Constant):
__slots__ = ()

def to_dict(self):
ast_dict = super().to_dict()
# python ast ellipsis() is not json serializable; use a string
ast_dict["value"] = self.node_source_code
return ast_dict


class Dict(ExprNode):
__slots__ = ("keys", "values")
Expand Down
1 change: 1 addition & 0 deletions vyper/ast/nodes.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ class TopLevel(VyperNode):
class Module(TopLevel):
path: str = ...
resolved_path: str = ...
source_id: int = ...
def namespace(self) -> Any: ... # context manager

class FunctionDef(TopLevel):
Expand Down
27 changes: 26 additions & 1 deletion vyper/compiler/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,18 @@
from collections import deque
from pathlib import PurePath

from vyper.ast import ast_to_dict
import vyper.ast as vy_ast
Fixed Show fixed Hide fixed
from vyper.ast.utils import ast_to_dict
from vyper.codegen.ir_node import IRnode
from vyper.compiler.output_bundle import SolcJSONWriter, VyperArchiveWriter
from vyper.compiler.phases import CompilerData
from vyper.compiler.utils import build_gas_estimates
from vyper.evm import opcodes
from vyper.exceptions import VyperException
from vyper.ir import compile_ir
from vyper.semantics.analysis.base import ModuleInfo
from vyper.semantics.types.function import FunctionVisibility, StateMutability
from vyper.semantics.types.module import InterfaceT
from vyper.typing import StorageLayout
from vyper.utils import vyper_warn
from vyper.warnings import ContractSizeLimitWarning
Expand All @@ -26,9 +29,31 @@ def build_ast_dict(compiler_data: CompilerData) -> dict:


def build_annotated_ast_dict(compiler_data: CompilerData) -> dict:
module_t = compiler_data.annotated_vyper_module._metadata["type"]
imported_module_infos = module_t.reachable_imports
unique_modules: dict[str, vy_ast.Module] = {}
for info in imported_module_infos:
if isinstance(info.typ, InterfaceT):
ast = info.typ.decl_node
if ast is None: # json abi
continue
else:
assert isinstance(info.typ, ModuleInfo)
ast = info.typ.module_t._module

assert isinstance(ast, vy_ast.Module) # help mypy
# use resolved_path for uniqueness, since Module objects can actually
# come from multiple InputBundles (particularly builtin interfaces),
# so source_id is not guaranteed to be unique.
if ast.resolved_path in unique_modules:
# sanity check -- object equality
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is operator checks for identity, not equality

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yea, that's what I meant by object ("pointer") equality

assert unique_modules[ast.resolved_path] is ast
unique_modules[ast.resolved_path] = ast

annotated_ast_dict = {
"contract_name": str(compiler_data.contract_path),
"ast": ast_to_dict(compiler_data.annotated_vyper_module),
"imports": [ast_to_dict(ast) for ast in unique_modules.values()],
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it might be confusing to have all the imports flattened into main as this also includes indirect imports

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a way of normalization, it avoids duplicating trees in the output

}
return annotated_ast_dict

Expand Down
Loading