From 15b75988703e2ae765fea2045d62d47cbdc73af1 Mon Sep 17 00:00:00 2001 From: Mike Shultz Date: Fri, 22 Mar 2024 11:25:59 -0600 Subject: [PATCH] feat: contract flattener (#107) * feat: contract flattener * fix: add missing dep vyper * fix: try lowering vyper min version for Python 3.8 support * fix(test): install necessary vyper versions in test * fix: compiler version handling in compiler.compile_code() * style(lint): unused import * fix(test): explicitly install compiler in test_pc_map * feat: adds `ape vyper flatten` command * test: add CLI test for flattener * chore: bump minimum eth-ape version to 0.7.12 * docs: adds Contract Flattening section to README * chore: bump eth-ape minimum ersion to 0.7.13 * fix(docs): fix warning directive in README * docs: update return value docstring Co-authored-by: antazoey * docs: speeling Co-authored-by: antazoey * refactor: splitlines() * refactor: limit `vyper flatten` command to vyper only Co-authored-by: antazoey * fix(docs): bad bug report link Co-authored-by: antazoey * style(docs): period Co-authored-by: antazoey * fix(docs): comment spelling Co-authored-by: antazoey * fix(docs): comment spelling Co-authored-by: antazoey * refactor: splitlines Co-authored-by: antazoey * fix: missing type hint Co-authored-by: antazoey * docs: how to format returns Co-authored-by: antazoey * refactor: not is None Co-authored-by: antazoey * refactor: check with installed versions when compiling before installing during no-pragma fallback --------- Co-authored-by: antazoey --- README.md | 13 ++ ape_vyper/_cli.py | 25 +++ ape_vyper/ast.py | 101 ++++++++++++ ape_vyper/compiler.py | 150 +++++++++++++++++- ape_vyper/interface.py | 140 ++++++++++++++++ setup.py | 8 +- tests/conftest.py | 6 + .../contracts/passing_contracts/flatten_me.vy | 34 ++++ tests/test_cli.py | 31 ++++ tests/test_compiler.py | 32 +++- 10 files changed, 530 insertions(+), 10 deletions(-) create mode 100644 ape_vyper/_cli.py create mode 100644 ape_vyper/ast.py create mode 100644 ape_vyper/interface.py create mode 100644 tests/contracts/passing_contracts/flatten_me.vy create mode 100644 tests/test_cli.py diff --git a/README.md b/README.md index d4585653..3d728800 100644 --- a/README.md +++ b/README.md @@ -38,6 +38,19 @@ ape compile The `.vy` files in your project will compile into `ContractTypes` that you can deploy and interact with in Ape. +### Contract Flattening + +For ease of publishing, validation, and some other cases it's sometimes useful to "flatten" your contract into a single file. +This combines your contract and any imported interfaces together in a way the compiler can understand. +You can do so with a command like this: + +```bash +ape vyper flatten contracts/MyContract.vy build/MyContractFlattened.vy +``` + +> \[!WARNING\] +> This feature is experimental. Please [report any bugs](https://github.com/ApeWorX/ape-solidity/issues/new?assignees=&labels=bug&projects=&template=bug.md) you find when trying it out. + ### Compiler Version By default, the `ape-vyper` plugin uses version pragma for version specification. diff --git a/ape_vyper/_cli.py b/ape_vyper/_cli.py new file mode 100644 index 00000000..665bfff1 --- /dev/null +++ b/ape_vyper/_cli.py @@ -0,0 +1,25 @@ +from pathlib import Path + +import ape +import click +from ape.cli import ape_cli_context + + +@click.group +def cli(): + """`vyper` command group""" + + +@cli.command(short_help="Flatten select contract source files") +@ape_cli_context() +@click.argument("CONTRACT", type=click.Path(exists=True, resolve_path=True)) +@click.argument("OUTFILE", type=click.Path(exists=False, resolve_path=True, writable=True)) +def flatten(cli_ctx, contract: Path, outfile: Path): + """ + Flatten a contract into a single file + """ + with Path(outfile).open("w") as fout: + content = ape.compilers.vyper.flatten_contract( + Path(contract), base_path=ape.project.contracts_folder + ) + fout.write(str(content)) diff --git a/ape_vyper/ast.py b/ape_vyper/ast.py new file mode 100644 index 00000000..234ca492 --- /dev/null +++ b/ape_vyper/ast.py @@ -0,0 +1,101 @@ +"""Utilities for dealing with Vyper AST""" + +from typing import List + +from ethpm_types import ABI, MethodABI +from ethpm_types.abi import ABIType +from vyper.ast import parse_to_ast # type: ignore +from vyper.ast.nodes import FunctionDef, Module, Name, Subscript # type: ignore + +DEFAULT_VYPER_MUTABILITY = "nonpayable" +DECORATOR_MUTABILITY = { + "pure", # Function does not read contract state or environment variables + "view", # Function does not alter contract state + "payable", # Function is able to receive Ether and may alter state + "nonpayable", # Function may alter sate +} + + +def funcdef_decorators(funcdef: FunctionDef) -> List[str]: + return [d.id for d in funcdef.get("decorator_list") or []] + + +def funcdef_inputs(funcdef: FunctionDef) -> List[ABIType]: + """Get a FunctionDef's defined input args""" + args = funcdef.get("args") + # TODO: Does Vyper allow complex input types, like structs and arrays? + return ( + [ABIType.model_validate({"name": arg.arg, "type": arg.annotation.id}) for arg in args.args] + if args + else [] + ) + + +def funcdef_outputs(funcdef: FunctionDef) -> List[ABIType]: + """Get a FunctionDef's outputs, or return values""" + returns = funcdef.get("returns") + + if not returns: + return [] + + if isinstance(returns, Name): + # TODO: Structs fall in here. I think they're supposed to be a tuple of types in the ABI. + # Need to dig into that more. + return [ABIType.model_validate({"type": returns.id})] + + elif isinstance(returns, Subscript): + # An array type + length = returns.slice.value.value + array_type = returns.value.id + # TOOD: Is this an acurrate way to define a fixed length array for ABI? + return [ABIType.model_validate({"type": f"{array_type}[{length}]"})] + + raise NotImplementedError(f"Unhandled return type {type(returns)}") + + +def funcdef_state_mutability(funcdef: FunctionDef) -> str: + """Get a FunctionDef's declared state mutability""" + for decorator in funcdef_decorators(funcdef): + if decorator in DECORATOR_MUTABILITY: + return decorator + return DEFAULT_VYPER_MUTABILITY + + +def funcdef_is_external(funcdef: FunctionDef) -> bool: + """Check if a FunctionDef is declared external""" + for decorator in funcdef_decorators(funcdef): + if decorator == "external": + return True + return False + + +def funcdef_to_abi(func: FunctionDef) -> ABI: + """Return a MethodABI instance for a Vyper FunctionDef""" + return MethodABI.model_validate( + { + "name": func.get("name"), + "inputs": funcdef_inputs(func), + "outputs": funcdef_outputs(func), + "stateMutability": funcdef_state_mutability(func), + } + ) + + +def module_to_abi(module: Module) -> List[ABI]: + """ + Create a list of MethodABIs from a Vyper AST Module instance. + """ + abi = [] + for child in module.get_children(): + if isinstance(child, FunctionDef): + abi.append(funcdef_to_abi(child)) + return abi + + +def source_to_abi(source: str) -> List[ABI]: + """ + Given Vyper source code, return a list of Ape ABI elements needed for an external interface. + This currently does not include complex types or events. + """ + module = parse_to_ast(source) + return module_to_abi(module) diff --git a/ape_vyper/compiler.py b/ape_vyper/compiler.py index f7a8559c..e7691fd0 100644 --- a/ape_vyper/compiler.py +++ b/ape_vyper/compiler.py @@ -6,7 +6,7 @@ from fnmatch import fnmatch from importlib import import_module from pathlib import Path -from typing import Any, Dict, Iterator, List, Optional, Sequence, Set, Tuple, Union, cast +from typing import Any, Dict, Iterable, Iterator, List, Optional, Sequence, Set, Tuple, Union, cast import vvm # type: ignore from ape.api import PluginConfig @@ -20,7 +20,7 @@ from ethpm_types import ASTNode, PackageManifest, PCMap, SourceMapItem from ethpm_types.ast import ASTClassification from ethpm_types.contract_type import SourceMap -from ethpm_types.source import Compiler, ContractSource, Function, SourceLocation +from ethpm_types.source import Compiler, Content, ContractSource, Function, SourceLocation from evm_trace.enums import CALL_OPCODES from packaging.specifiers import InvalidSpecifier, SpecifierSet from packaging.version import Version @@ -28,6 +28,7 @@ from vvm import compile_standard as vvm_compile_standard from vvm.exceptions import VyperError # type: ignore +from ape_vyper.ast import source_to_abi from ape_vyper.exceptions import ( RUNTIME_ERROR_MAP, IntegerBoundsCheck, @@ -35,6 +36,13 @@ VyperCompileError, VyperInstallError, ) +from ape_vyper.interface import ( + extract_import_aliases, + extract_imports, + extract_meta, + generate_interface, + iface_name_from_file, +) DEV_MSG_PATTERN = re.compile(r".*\s*#\s*(dev:.+)") _RETURN_OPCODES = ("RETURN", "REVERT", "STOP") @@ -333,12 +341,11 @@ def config_version_pragma(self) -> Optional[SpecifierSet]: return None @property - def import_remapping(self) -> Dict[str, Dict]: + def remapped_manifests(self) -> Dict[str, PackageManifest]: """ - Configured interface imports from dependencies. + Interface import manifests. """ - interfaces = {} dependencies: Dict[str, PackageManifest] = {} for remapping in self.settings.import_remapping: @@ -366,7 +373,19 @@ def import_remapping(self) -> Dict[str, Dict]: dependency = dependency_versions[version].compile() dependencies[remapping] = dependency - for name, ct in (dependency.contract_types or {}).items(): + return dependencies + + @property + def import_remapping(self) -> Dict[str, Dict]: + """ + Configured interface imports from dependencies. + """ + + interfaces = {} + + for remapping in self.settings.import_remapping: + key, _ = remapping.split("=") + for name, ct in (self.remapped_manifests[remapping].contract_types or {}).items(): interfaces[f"{key}/{name}.json"] = { "abi": [x.model_dump(mode="json", by_alias=True) for x in ct.abi] } @@ -518,8 +537,14 @@ def compile( def compile_code(self, code: str, base_path: Optional[Path] = None, **kwargs) -> ContractType: base_path = base_path or self.project_manager.contracts_folder + + # Figure out what compiler version we need for this contract... + version = self._source_vyper_version(code) + # ...and install it if necessary + _install_vyper(version) + try: - result = vvm.compile_source(code, base_path=base_path) + result = vvm.compile_source(code, base_path=base_path, vyper_version=version) except Exception as err: raise VyperCompileError(str(err)) from err @@ -531,6 +556,117 @@ def compile_code(self, code: str, base_path: Optional[Path] = None, **kwargs) -> **kwargs, ) + def _source_vyper_version(self, code: str) -> Version: + """Given source code, figure out which Vyper version to use""" + version_spec = get_version_pragma_spec(code) + + def first_full_release(versions: Iterable[Version]) -> Optional[Version]: + for vers in versions: + if not vers.is_devrelease and not vers.is_postrelease and not vers.is_prerelease: + return vers + return None + + if version_spec is None: + if version := first_full_release(self.installed_versions + self.available_versions): + return version + raise VyperInstallError("No available version.") + + return next(version_spec.filter(self.available_versions)) + + def _flatten_source( + self, path: Path, base_path: Optional[Path] = None, raw_import_name: Optional[str] = None + ) -> str: + base_path = base_path or self.config_manager.contracts_folder + + # Get the non stdlib import paths for our contracts + imports = list( + filter( + lambda x: not x.startswith("vyper/"), + [y for x in self.get_imports([path], base_path).values() for y in x], + ) + ) + + dependencies: Dict[str, PackageManifest] = {} + for key, manifest in self.remapped_manifests.items(): + package = key.split("=")[0] + + if manifest.sources is None: + continue + + for source_id in manifest.sources.keys(): + import_match = f"{package}/{source_id}" + dependencies[import_match] = manifest + + flattened_source = "" + interfaces_source = "" + og_source = (base_path / path).read_text() + + # Get info about imports and source meta + aliases = extract_import_aliases(og_source) + pragma, source_without_meta = extract_meta(og_source) + stdlib_imports, _, source_without_imports = extract_imports(source_without_meta) + + for import_path in sorted(imports): + import_file = base_path / import_path + + # Vyper imported interface names come from their file names + file_name = iface_name_from_file(import_file) + # If we have a known alias, ("import X as Y"), use the alias as interface name + iface_name = aliases[file_name] if file_name in aliases else file_name + + # We need to compare without extensions because sometimes they're made up for some + # reason. TODO: Cleaner way to deal with this? + def _match_source(import_path: str) -> Optional[PackageManifest]: + import_path_name = ".".join(import_path.split(".")[:-1]) + for source_path in dependencies.keys(): + if source_path.startswith(import_path_name): + return dependencies[source_path] + return None + + if matched_source := _match_source(import_path): + if not matched_source.contract_types: + continue + + abis = [ + el + for k in matched_source.contract_types.keys() + for el in matched_source.contract_types[k].abi + ] + interfaces_source += generate_interface(abis, iface_name) + continue + + # Vyper imported interface names come from their file names + file_name = iface_name_from_file(import_file) + # Generate an ABI from the source code + abis = source_to_abi(import_file.read_text()) + interfaces_source += generate_interface(abis, iface_name) + + def no_nones(it: Iterable[Optional[str]]) -> Iterable[str]: + # Type guard like generator to remove Nones and make mypy happy + for el in it: + if el is not None: + yield el + + # Join all the OG and generated parts back together + flattened_source = "\n\n".join( + no_nones((pragma, stdlib_imports, interfaces_source, source_without_imports)) + ) + + # TODO: Replace this nonsense with a real code formatter + def format_source(source: str) -> str: + while "\n\n\n\n" in source: + source = source.replace("\n\n\n\n", "\n\n\n") + return source + + return format_source(flattened_source) + + def flatten_contract(self, path: Path, base_path: Optional[Path] = None) -> Content: + """ + Returns the flattened contract suitable for compilation or verification as a single file + """ + source = self._flatten_source(path, base_path, path.name) + return Content({i: ln for i, ln in enumerate(source.splitlines())}) + def get_version_map( self, contract_filepaths: Sequence[Path], base_path: Optional[Path] = None ) -> Dict[Version, Set[Path]]: diff --git a/ape_vyper/interface.py b/ape_vyper/interface.py new file mode 100644 index 00000000..1b5acfcd --- /dev/null +++ b/ape_vyper/interface.py @@ -0,0 +1,140 @@ +""" +Tools for working with ABI specs and Vyper interface source code +""" + +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple, Union + +from ethpm_types import ABI, MethodABI +from ethpm_types.abi import ABIType + +INDENT_SPACES = 4 +INDENT = " " * INDENT_SPACES + + +def indent_line(line: str, level=1) -> str: + """Indent a source line of code""" + return f"{INDENT * level}{line}" + + +def iface_name_from_file(fpath: Path) -> str: + """Get Interface name from file path""" + return fpath.name.split(".")[0] + + +def generate_inputs(inputs: List[ABIType]) -> str: + """Generate the source code input args from ABI inputs""" + return ", ".join(f"{i.name}: {i.type}" for i in inputs) + + +def generate_method(abi: MethodABI) -> str: + """Generate Vyper interface method definition""" + inputs = generate_inputs(abi.inputs) + return_maybe = f" -> {abi.outputs[0].type}" if abi.outputs else "" + return f"def {abi.name}({inputs}){return_maybe}: {abi.stateMutability}\n" + + +def abi_to_type(iface: Dict[str, Any]) -> Optional[ABI]: + """Convert a dict JSON-like interface to an ethpm-types ABI type""" + if iface["type"] == "function": + return MethodABI.model_validate(iface) + return None + + +def generate_interface(abi: Union[List[Dict[str, Any]], List[ABI]], iface_name: str) -> str: + """ + Generate a Vyper interface source code from an ABI spec + + Args: + abi (List[Union[Dict[str, Any], ABI]]): An ABI spec for a contract + iface_name (str): The name of the interface + + Returns: + ``str`` Vyper source code for the interface + """ + source = f"interface {iface_name}:\n" + + for iface in abi: + if isinstance(iface, dict): + _iface = abi_to_type(iface) + + if _iface is None: + continue + + # Re-assignment after None check because mypy + iface = _iface + + if isinstance(iface, MethodABI): + source += indent_line(generate_method(iface)) + + return f"{source}\n" + + +def extract_meta(source_code: str) -> Tuple[Optional[str], str]: + """Extract version pragma, and returne cleaned source""" + version_pragma: Optional[str] = None + cleaned_source_lines: List[str] = [] + + """ + Pragma format changed a bit. + + >= 3.10: #pragma version ^0.3.0 + < 3.10: # @version ^0.3.0 + + Both are valid until 0.4 where the latter may be deprecated + """ + for line in source_code.splitlines(): + if line.startswith("#") and ( + ("pragma version" in line or "@version" in line) and version_pragma is None + ): + version_pragma = line + else: + cleaned_source_lines.append(line) + + return (version_pragma, "\n".join(cleaned_source_lines)) + + +def extract_imports(source: str) -> Tuple[str, str, str]: + """ + Extract import lines from the source, return them and the source without imports + + Returns: + Tuple[str, str, str]: (stdlib_import_lines, interface_import_lines, cleaned_source) + """ + interface_import_lines = [] + stdlib_import_lines = [] + cleaned_source_lines = [] + + for line in source.splitlines(): + if line.startswith("import ") or (line.startswith("from ") and " import " in line): + if "vyper.interfaces" in line: + stdlib_import_lines.append(line) + else: + interface_import_lines.append(line) + else: + cleaned_source_lines.append(line) + + return ( + "\n".join(stdlib_import_lines), + "\n".join(interface_import_lines), + "\n".join(cleaned_source_lines), + ) + + +def extract_import_aliases(source: str) -> Dict[str, str]: + """ + Extract import aliases from import lines + + Returns: + Dict[str, str]: {import: alias} + """ + aliases = {} + for line in source.splitlines(): + if ( + line.startswith("import ") or (line.startswith("from ") and " import " in line) + ) and " as " in line: + subject_parts = line.split("import ")[1] + alias_parts = subject_parts.split(" as ") + iface_path_name = alias_parts[0].split(".")[-1] # Remove path parts from import + aliases[iface_path_name] = alias_parts[1] + return aliases diff --git a/setup.py b/setup.py index f7e0046d..678406a1 100644 --- a/setup.py +++ b/setup.py @@ -58,14 +58,20 @@ url="https://github.com/ApeWorX/ape-vyper", include_package_data=True, install_requires=[ - "eth-ape>=0.7,<0.8", + "eth-ape>=0.7.13,<0.8", "ethpm-types", # Use same version as eth-ape "tqdm", # Use same version as eth-ape "vvm>=0.2.0,<0.3", + "vyper~=0.3.7", ], python_requires=">=3.8,<4", extras_require=extras_require, py_modules=["ape_vyper"], + entry_points={ + "ape_cli_subcommands": [ + "ape_vyper=ape_vyper._cli:cli", + ], + }, license="Apache-2.0", zip_safe=False, keywords="ethereum", diff --git a/tests/conftest.py b/tests/conftest.py index 1e83701c..54b79248 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -9,6 +9,7 @@ import ape import pytest import vvm # type: ignore +from click.testing import CliRunner # NOTE: Ensure that we don't use local paths for these DATA_FOLDER = Path(mkdtemp()).resolve() @@ -211,6 +212,11 @@ def all_versions(): return ALL_VERSIONS +@pytest.fixture +def cli_runner(): + return CliRunner() + + def _get_tb_contract(version: str, project, account): registry_type = project.get_contract(f"registry_{version}") registry = account.deploy(registry_type) diff --git a/tests/contracts/passing_contracts/flatten_me.vy b/tests/contracts/passing_contracts/flatten_me.vy new file mode 100644 index 00000000..456e2457 --- /dev/null +++ b/tests/contracts/passing_contracts/flatten_me.vy @@ -0,0 +1,34 @@ +# pragma version 0.3.10 + +from vyper.interfaces import ERC20 + +from interfaces import IFace2 as IFaceTwo +import interfaces.IFace as IFace +import exampledep.Dependency as Dep + + +@external +@view +def read_contract(some_address: address) -> uint256: + myContract: IFace = IFace(some_address) + return myContract.read_stuff() + + +@external +@view +def read_another_contract(some_address: address) -> uint256: + two: IFaceTwo = IFaceTwo(some_address) + return two.read_stuff_3() + + +@external +@view +def read_from_dep(some_address: address) -> uint256: + dep: Dep = Dep(some_address) + return dep.read_stuff_2() + + +@external +def send_me(token_address: address, amount: uint256) -> bool: + token: ERC20 = ERC20(token_address) + return token.transferFrom(msg.sender, self, amount, default_return_value=True) diff --git a/tests/test_cli.py b/tests/test_cli.py new file mode 100644 index 00000000..4be8b562 --- /dev/null +++ b/tests/test_cli.py @@ -0,0 +1,31 @@ +from pathlib import Path +from tempfile import NamedTemporaryFile + +import pytest + +from ape_vyper._cli import cli + + +@pytest.mark.parametrize( + "contract_name,expected", + [ + # This first one has most known edge cases + ( + "flatten_me.vy", + [ + "from vyper.interfaces import ERC20", + "interface Dep:", + "interface IFace:", + "interface IFaceTwo:", + ], + ), + ], +) +def test_cli_flatten(project, contract_name, expected, cli_runner): + path = project.contracts_folder / contract_name + with NamedTemporaryFile() as tmpfile: + result = cli_runner.invoke(cli, ["flatten", str(path), tmpfile.name]) + assert result.exit_code == 0, result.stderr_bytes + output = Path(tmpfile.name).read_text() + for expect in expected: + assert expect in output diff --git a/tests/test_compiler.py b/tests/test_compiler.py index 5d42a37b..98f00972 100644 --- a/tests/test_compiler.py +++ b/tests/test_compiler.py @@ -1,10 +1,10 @@ import re import pytest +import vvm # type: ignore from ape.exceptions import ContractLogicError from ethpm_types import ContractType from packaging.version import Version -from vvm import compile_source # type: ignore from vvm.exceptions import VyperError # type: ignore from ape_vyper.compiler import RuntimeErrorType @@ -105,6 +105,7 @@ def test_get_version_map(project, compiler, all_versions): "contract_no_pragma.vy", # no pragma should compile with latest version "empty.vy", # empty file still compiles with latest version "pragma_with_space.vy", + "flatten_me.vy", ] # Add the 0.3.10 contracts. @@ -225,7 +226,8 @@ def test_pc_map(compiler, project, src, vers): result = compiler.compile([path], base_path=project.contracts_folder)[0] actual = result.pcmap.root code = path.read_text() - compile_result = compile_source(code, vyper_version=vers, evm_version=compiler.evm_version)[ + vvm.install_vyper(vers) + compile_result = vvm.compile_source(code, vyper_version=vers, evm_version=compiler.evm_version)[ "" ] src_map = compile_result["source_map"] @@ -503,3 +505,29 @@ def test_compile_with_version_set_in_settings_dict(config, compiler_manager, pro ) with pytest.raises(VyperCompileError, match=expected): compiler_manager.compile([contract], settings={"version": "0.3.3"}) + + +@pytest.mark.parametrize( + "contract_name", + [ + # This first one has most known edge cases + "flatten_me.vy", + # Test on the below for general compatibility. + "contract_with_dev_messages.vy", + "erc20.vy", + "use_iface.vy", + "optimize_codesize.vy", + "evm_pragma.vy", + "use_iface2.vy", + "contract_no_pragma.vy", # no pragma should compile with latest version + "empty.vy", # empty file still compiles with latest version + "pragma_with_space.vy", + ], +) +def test_flatten_contract(all_versions, project, contract_name, compiler): + path = project.contracts_folder / contract_name + source = compiler.flatten_contract(path) + source_code = str(source) + version = compiler._source_vyper_version(source_code) + vvm.install_vyper(str(version)) + vvm.compile_source(source_code, base_path=project.contracts_folder, vyper_version=version)