diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index fb1ca56a..28881d1f 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -10,7 +10,7 @@ repos: - id: isort - repo: https://github.com/psf/black - rev: 23.10.1 + rev: 23.12.0 hooks: - id: black name: black @@ -21,7 +21,7 @@ repos: - id: flake8 - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.6.1 + rev: v1.7.1 hooks: - id: mypy additional_dependencies: [types-setuptools, pydantic==1.10.4] diff --git a/ape_vyper/compiler.py b/ape_vyper/compiler.py index 6ec3bfe0..f7a8559c 100644 --- a/ape_vyper/compiler.py +++ b/ape_vyper/compiler.py @@ -6,24 +6,25 @@ from fnmatch import fnmatch from importlib import import_module from pathlib import Path -from typing import Any, Dict, Iterator, List, Optional, Set, Tuple, Union, cast +from typing import Any, Dict, Iterator, List, Optional, Sequence, Set, Tuple, Union, cast import vvm # type: ignore -from ape._pydantic_compat import validator from ape.api import PluginConfig from ape.api.compiler import CompilerAPI from ape.exceptions import ContractLogicError from ape.logging import logger from ape.types import ContractSourceCoverage, ContractType, SourceTraceback, TraceFrame -from ape.utils import GithubClient, cached_property, get_relative_path +from ape.utils import GithubClient, cached_property, get_relative_path, pragma_str_to_specifier_set +from eth_pydantic_types import HexBytes from eth_utils import is_0x_prefixed -from ethpm_types import ASTNode, HexBytes, PackageManifest, PCMap, SourceMapItem +from ethpm_types import ASTNode, PackageManifest, PCMap, SourceMapItem from ethpm_types.ast import ASTClassification from ethpm_types.contract_type import SourceMap -from ethpm_types.source import ContractSource, Function, SourceLocation +from ethpm_types.source import Compiler, ContractSource, Function, SourceLocation from evm_trace.enums import CALL_OPCODES from packaging.specifiers import InvalidSpecifier, SpecifierSet from packaging.version import Version +from pydantic import field_serializer, field_validator from vvm import compile_standard as vvm_compile_standard from vvm.exceptions import VyperError # type: ignore @@ -41,6 +42,7 @@ _FUNCTION_AST_TYPES = (_FUNCTION_DEF, "Name", "arguments") _EMPTY_REVERT_OFFSET = 18 _NON_PAYABLE_STR = f"dev: {RuntimeErrorType.NONPAYABLE_CHECK.value}" +Optimization = Union[str, bool] class VyperConfig(PluginConfig): @@ -69,9 +71,16 @@ class VyperConfig(PluginConfig): """ - @validator("version", pre=True) + @field_validator("version", mode="before") def validate_version(cls, value): - return SpecifierSet(_version_to_specifier(value)) if isinstance(value, str) else value + return pragma_str_to_specifier_set(value) if isinstance(value, str) else value + + @field_serializer("version") + def serialize_version(self, value: Optional[SpecifierSet], _info) -> Optional[str]: + if version := value: + return str(version) + + return None def _install_vyper(version: Version): @@ -93,13 +102,13 @@ def get_version_pragma_spec(source: Union[str, Path]) -> Optional[SpecifierSet]: Returns: ``packaging.specifiers.SpecifierSet``, or None if no valid pragma is found. """ - version_pragma_patterns = [ + _version_pragma_patterns: Tuple[str, str] = ( r"(?:\n|^)\s*#\s*@version\s*([^\n]*)", r"(?:\n|^)\s*#\s*pragma\s+version\s*([^\n]*)", - ] + ) source_str = source if isinstance(source, str) else source.read_text() - for pattern in version_pragma_patterns: + for pattern in _version_pragma_patterns: for match in re.finditer(pattern, source_str): raw_pragma = match.groups()[0] pragma_str = " ".join(raw_pragma.split()).replace("^", "~=") @@ -125,24 +134,59 @@ def get_optimization_pragma(source: Union[str, Path]) -> Optional[str]: ``str``, or None if no valid pragma is found. """ source_str = source if isinstance(source, str) else source.read_text() - pragma_match = next(re.finditer(r"(?:\n|^)\s*#pragma\s+optimize\s+([^\n]*)", source_str), None) - if pragma_match is None: - return None + if pragma_match := next( + re.finditer(r"(?:\n|^)\s*#pragma\s+optimize\s+([^\n]*)", source_str), None + ): + return pragma_match.groups()[0] + + return None + + +def get_evmversion_pragma(source: Union[str, Path]) -> Optional[str]: + """ + Extracts evm version pragma information from Vyper source code. + + Args: + source (Union[str, Path]): Vyper source code + + Returns: + ``str``, or None if no valid pragma is found. + """ + source_str = source if isinstance(source, str) else source.read_text() + if pragma_match := next( + re.finditer(r"(?:\n|^)\s*#pragma\s+evm-version\s+([^\n]*)", source_str), None + ): + return pragma_match.groups()[0] - return pragma_match.groups()[0] + return None def get_optimization_pragma_map( - contract_filepaths: List[Path], -) -> Dict[Union[str, bool], Set[Path]]: - optimization_pragma_map: Dict[Union[str, bool], Set[Path]] = {} + contract_filepaths: Sequence[Path], base_path: Path +) -> Dict[str, Optimization]: + pragma_map: Dict[str, Optimization] = {} + for path in contract_filepaths: pragma = get_optimization_pragma(path) or True - if pragma not in optimization_pragma_map: - optimization_pragma_map[pragma] = set() - optimization_pragma_map[pragma].add(path) + source_id = str(get_relative_path(path.absolute(), base_path.absolute())) + pragma_map[source_id] = pragma + + return pragma_map - return optimization_pragma_map + +def get_evm_version_pragma_map( + contract_filepaths: Sequence[Path], base_path: Path +) -> Dict[str, str]: + pragmas: Dict[str, str] = {} + for path in contract_filepaths: + pragma = get_evmversion_pragma(path) + if not pragma: + continue + + source_id = str(get_relative_path(path.absolute(), base_path.absolute())) + pragmas[source_id] = pragma + + return pragmas class VyperCompiler(CompilerAPI): @@ -159,7 +203,7 @@ def evm_version(self) -> Optional[str]: return self.settings.evm_version def get_imports( - self, contract_filepaths: List[Path], base_path: Optional[Path] = None + self, contract_filepaths: Sequence[Path], base_path: Optional[Path] = None ) -> Dict[str, List[str]]: base_path = (base_path or self.project_manager.contracts_folder).absolute() import_map = {} @@ -192,7 +236,7 @@ def get_imports( return import_map - def get_versions(self, all_paths: List[Path]) -> Set[str]: + def get_versions(self, all_paths: Sequence[Path]) -> Set[str]: versions = set() for path in all_paths: if version_spec := get_version_pragma_spec(path): @@ -323,7 +367,9 @@ def import_remapping(self) -> Dict[str, Dict]: dependencies[remapping] = dependency for name, ct in (dependency.contract_types or {}).items(): - interfaces[f"{key}/{name}.json"] = {"abi": [x.dict() for x in ct.abi]} + interfaces[f"{key}/{name}.json"] = { + "abi": [x.model_dump(mode="json", by_alias=True) for x in ct.abi] + } return interfaces @@ -335,7 +381,7 @@ def classify_ast(self, _node: ASTNode): self.classify_ast(child) def compile( - self, contract_filepaths: List[Path], base_path: Optional[Path] = None + self, contract_filepaths: Sequence[Path], base_path: Optional[Path] = None ) -> List[ContractType]: contract_types = [] base_path = base_path or self.config_manager.contracts_folder @@ -343,22 +389,18 @@ def compile( version_map = self.get_version_map(sources) compiler_data = self._get_compiler_arguments(version_map, base_path) all_settings = self.get_compiler_settings(sources, base_path=base_path) + contract_versions: Dict[str, Tuple[Version, str]] = {} - for vyper_version, source_paths in version_map.items(): - version_settings = all_settings.get(vyper_version, {}) - optimizations_map = get_optimization_pragma_map(list(source_paths)) - - for optimization, source_paths in optimizations_map.items(): - settings: Dict[str, Any] = version_settings.copy() - settings["optimize"] = optimization or True - path_args = { - str(get_relative_path(p.absolute(), base_path)): p for p in source_paths - } - settings["outputSelection"] = {s: ["*"] for s in path_args} + for vyper_version, version_settings in all_settings.items(): + for settings_key, settings in version_settings.items(): + source_ids = settings["outputSelection"] + optimization_paths = {p: base_path / p for p in source_ids} input_json = { "language": "Vyper", "settings": settings, - "sources": {s: {"content": p.read_text()} for s, p in path_args.items()}, + "sources": { + s: {"content": p.read_text()} for s, p in optimization_paths.items() + }, } if interfaces := self.import_remapping: @@ -382,7 +424,7 @@ def compile( } for name, output in output_items.items(): # De-compress source map to get PC POS map. - ast = ASTNode.parse_obj(result["sources"][source_id]["ast"]) + ast = ASTNode.model_validate(result["sources"][source_id]["ast"]) self.classify_ast(ast) # Track function offsets. @@ -399,7 +441,7 @@ def compile( evm = output["evm"] bytecode = evm["deployedBytecode"] opcodes = bytecode["opcodes"].split(" ") - compressed_src_map = SourceMap(__root__=bytecode["sourceMap"]) + compressed_src_map = SourceMap(root=bytecode["sourceMap"]) src_map = list(compressed_src_map.parse())[1:] pcmap = ( @@ -428,6 +470,49 @@ def compile( dev_messages=dev_messages, ) contract_types.append(contract_type) + contract_versions[name] = (vyper_version, settings_key) + + # Output compiler data used. + compilers_used: Dict[Version, Dict[str, Compiler]] = {} + for ct in contract_types: + if not ct.name: + # Won't happen, but just for mypy. + continue + + ct_version, ct_settings_key = contract_versions[ct.name] + settings = all_settings[ct_version][ct_settings_key] + + if ct_version not in compilers_used: + compilers_used[ct_version] = {} + + if ct_settings_key in compilers_used[ct_version] and ct.name not in ( + compilers_used[ct_version][ct_settings_key].contractTypes or [] + ): + # Add contractType to already-tracked compiler. + compilers_used[ct_version][ct_settings_key].contractTypes = [ + *(compilers_used[ct_version][ct_settings_key].contractTypes or []), + ct.name, + ] + + elif ct_settings_key not in compilers_used[ct_version]: + # Add optimization-compiler for the first time. + compilers_used[ct_version][ct_settings_key] = Compiler( + name=self.name.lower(), + version=f"{ct_version}", + contractTypes=[ct.name], + settings=settings, + ) + + # Output compiler data to the cached project manifest. + compilers_ls = [ + compiler + for optimization_settings in compilers_used.values() + for compiler in optimization_settings.values() + ] + + # NOTE: This method handles merging contractTypes and filtered out + # no longer used Compilers. + self.project_manager.local_project.add_compiler_data(compilers_ls) return contract_types @@ -446,20 +531,8 @@ def compile_code(self, code: str, base_path: Optional[Path] = None, **kwargs) -> **kwargs, ) - def get_optimization_pragma_map( - self, contract_filepaths: List[Path] - ) -> Dict[Union[str, bool], Set[Path]]: - optimization_pragma_map: Dict[Union[str, bool], Set[Path]] = {} - for path in contract_filepaths: - pragma = get_optimization_pragma(path) or True - if pragma not in optimization_pragma_map: - optimization_pragma_map[pragma] = set() - optimization_pragma_map[pragma].add(path) - - return optimization_pragma_map - def get_version_map( - self, contract_filepaths: List[Path], base_path: Optional[Path] = None + self, contract_filepaths: Sequence[Path], base_path: Optional[Path] = None ) -> Dict[Version, Set[Path]]: version_map: Dict[Version, Set[Path]] = {} source_path_by_version_spec: Dict[SpecifierSet, Set[Path]] = {} @@ -521,23 +594,49 @@ def get_version_map( return version_map def get_compiler_settings( - self, contract_filepaths: List[Path], base_path: Optional[Path] = None + self, contract_filepaths: Sequence[Path], base_path: Optional[Path] = None ) -> Dict[Version, Dict]: + valid_paths = [p for p in contract_filepaths if p.suffix == ".vy"] contracts_path = base_path or self.config_manager.contracts_folder - files_by_vyper_version = self.get_version_map(contract_filepaths, base_path=contracts_path) + files_by_vyper_version = self.get_version_map(valid_paths, base_path=contracts_path) if not files_by_vyper_version: return {} compiler_data = self._get_compiler_arguments(files_by_vyper_version, contracts_path) settings = {} for version, data in compiler_data.items(): - source_paths = files_by_vyper_version.get(version) + source_paths = list(files_by_vyper_version.get(version, [])) if not source_paths: continue - version_settings: Dict = {"optimize": True} - if evm_version := data.get("evm_version"): - version_settings["evmVersion"] = evm_version + output_selection: Dict[str, Set[str]] = {} + optimizations_map = get_optimization_pragma_map(source_paths, contracts_path) + evm_version_map = get_evm_version_pragma_map(source_paths, contracts_path) + default_evm_version = data.get("evm_version", data.get("evmVersion")) + for source_path in source_paths: + source_id = str(get_relative_path(source_path.absolute(), contracts_path)) + optimization = optimizations_map.get(source_id, True) + evm_version = evm_version_map.get(source_id, default_evm_version) + settings_key = f"{optimization}%{evm_version}".lower() + if settings_key not in output_selection: + output_selection[settings_key] = {source_id} + else: + output_selection[settings_key].add(source_id) + + version_settings: Dict[str, Dict] = {} + for settings_key, selection in output_selection.items(): + optimization, evm_version = settings_key.split("%") + if optimization == "true": + optimization = True + elif optimization == "false": + optimization = False + + version_settings[settings_key] = { + "optimize": optimization, + "outputSelection": {s: ["*"] for s in selection}, + } + if evm_version and evm_version not in ("none", "null"): + version_settings[settings_key]["evmVersion"] = f"{evm_version}" settings[version] = version_settings @@ -592,7 +691,7 @@ def _profile(_name: str, _full_name: str): # function_name -> (pc, location) pending_statements: Dict[str, List[Tuple[int, SourceLocation]]] = {} - for pc, item in contract_source.pcmap.__root__.items(): + for pc, item in contract_source.pcmap.root.items(): pc_int = int(pc) if pc_int < 0: continue @@ -787,7 +886,7 @@ def trace_source( if source_contract_type := self.project_manager._create_contract_source(contract_type): return self._get_traceback(source_contract_type, trace, calldata) - return SourceTraceback.parse_obj([]) + return SourceTraceback.model_validate([]) def _get_traceback( self, @@ -796,10 +895,10 @@ def _get_traceback( calldata: HexBytes, previous_depth: Optional[int] = None, ) -> SourceTraceback: - traceback = SourceTraceback.parse_obj([]) + traceback = SourceTraceback.model_validate([]) method_id = HexBytes(calldata[:4]) completed = False - pcmap = PCMap.parse_obj({}) + pcmap = PCMap.model_validate({}) for frame in trace: if frame.op in CALL_OPCODES: @@ -858,7 +957,7 @@ def _get_traceback( is_non_payable_hit = False if next_frame and next_frame.op == "SSTORE": push_location = tuple(loc["location"]) # type: ignore - pcmap = PCMap.parse_obj({next_frame.pc: {"location": push_location}}) + pcmap = PCMap.model_validate({next_frame.pc: {"location": push_location}}) elif next_frame and next_frame.op in _RETURN_OPCODES: completed = True @@ -945,7 +1044,6 @@ def _get_traceback( traceback.add_builtin_jump( name, f"dev: {dev}", - self.name, full_name=full_name, pcs=pcs, source_path=contract_src.source_path, @@ -1022,7 +1120,7 @@ def _get_pcmap(bytecode: Dict) -> PCMap: src_info = bytecode["sourceMapFull"] pc_data = {pc: {"location": ln} for pc, ln in src_info["pc_pos_map"].items()} if not pc_data: - return PCMap.parse_obj({}) + return PCMap.model_validate({}) # Apply other errors. errors = src_info["error_map"] @@ -1079,7 +1177,7 @@ def _get_pcmap(bytecode: Dict) -> PCMap: else: pc_data[err_pc] = {"dev": f"dev: {error_str}", "location": location} - return PCMap.parse_obj(pc_data) + return PCMap.model_validate(pc_data) def _get_legacy_pcmap(ast: ASTNode, src_map: List[SourceMapItem], opcodes: List[str]): @@ -1169,7 +1267,8 @@ def _get_legacy_pcmap(ast: ASTNode, src_map: List[SourceMapItem], opcodes: List[ item["dev"] = f"dev: {RuntimeErrorType.USER_ASSERT.value}" break - return PCMap.parse_obj(dict(pc_map_list)) + pcmap_data = dict(pc_map_list) + return PCMap.model_validate(pcmap_data) def _find_non_payable_check(src_map: List[SourceMapItem], opcodes: List[str]) -> Optional[int]: @@ -1231,7 +1330,7 @@ def _extend_return(function: Function, traceback: SourceTraceback, last_pc: int, location = return_ast.line_numbers last_lineno = max(0, location[2] - 1) - for frameset in traceback.__root__[::-1]: + for frameset in traceback.root[::-1]: if frameset.end_lineno is not None: last_lineno = frameset.end_lineno break @@ -1255,9 +1354,9 @@ def _is_fallback_check(opcodes: List[str], op: str) -> bool: ) -def _version_to_specifier(version: str) -> str: - pragma_str = " ".join(version.split()).replace("^", "~=") - if pragma_str and pragma_str[0].isnumeric(): - return f"=={pragma_str}" - - return pragma_str +# def _version_to_specifier(version: str) -> str: +# pragma_str = " ".join(version.split()).replace("^", "~=") +# if pragma_str and pragma_str[0].isnumeric(): +# return f"=={pragma_str}" +# +# return pragma_str diff --git a/setup.py b/setup.py index 64e04fdf..f7e0046d 100644 --- a/setup.py +++ b/setup.py @@ -11,10 +11,9 @@ "hypothesis>=6.2.0,<7.0", # Strategy-based fuzzer ], "lint": [ - "black>=23.10.1,<24", # Auto-formatter and linter - "mypy>=1.6.1", # Static type analyzer + "black>=23.12.0,<24", # Auto-formatter and linter + "mypy>=1.7.1", # Static type analyzer "types-setuptools", # Needed due to mypy typeshed - "pydantic<2.0", # Needed for successful type check. TODO: Remove after full v2 support. "flake8>=6.1.0,<7", # Style linter "isort>=5.10.1", # Import sorting linter "mdformat>=0.7.17", # Auto-formatter for markdown @@ -59,7 +58,7 @@ url="https://github.com/ApeWorX/ape-vyper", include_package_data=True, install_requires=[ - "eth-ape>=0.6.23,<0.7", + "eth-ape>=0.7,<0.8", "ethpm-types", # Use same version as eth-ape "tqdm", # Use same version as eth-ape "vvm>=0.2.0,<0.3", diff --git a/tests/contracts/passing_contracts/evm_pragma.vy b/tests/contracts/passing_contracts/evm_pragma.vy new file mode 100644 index 00000000..ae1bfd63 --- /dev/null +++ b/tests/contracts/passing_contracts/evm_pragma.vy @@ -0,0 +1,8 @@ +#pragma version 0.3.10 +#pragma evm-version paris + +x: uint256 + +@external +def __init__(): + self.x = 0 diff --git a/tests/test_compiler.py b/tests/test_compiler.py index 994d8b7d..5d42a37b 100644 --- a/tests/test_compiler.py +++ b/tests/test_compiler.py @@ -100,6 +100,7 @@ def test_get_version_map(project, compiler, all_versions): "erc20.vy", "use_iface.vy", "optimize_codesize.vy", + "evm_pragma.vy", "use_iface2.vy", "contract_no_pragma.vy", # no pragma should compile with latest version "empty.vy", # empty file still compiles with latest version @@ -132,25 +133,45 @@ def test_get_version_map(project, compiler, all_versions): def test_compiler_data_in_manifest(project): - _ = project.contracts - manifest = project.extract_manifest() - assert len(manifest.compilers) >= 3, manifest.compilers - - vyper_latest = [c for c in manifest.compilers if str(c.version) == str(VERSION_FROM_PRAGMA)][0] - vyper_028 = [c for c in manifest.compilers if str(c.version) == str(OLDER_VERSION_FROM_PRAGMA)][ - 0 - ] - - for compiler in (vyper_028, vyper_latest): - assert compiler.name == "vyper" - - assert len(vyper_latest.contractTypes) >= 9 - assert len(vyper_028.contractTypes) >= 1 - assert "contract_0310" in vyper_latest.contractTypes - assert "older_version" in vyper_028.contractTypes - for compiler in (vyper_latest, vyper_028): - assert compiler.settings["evmVersion"] == "istanbul" - assert compiler.settings["optimize"] is True + def run_test(manifest): + assert len(manifest.compilers) >= 3, manifest.compilers + + all_latest = [c for c in manifest.compilers if str(c.version) == str(VERSION_FROM_PRAGMA)] + codesize_latest = [c for c in all_latest if c.settings["optimize"] == "codesize"][0] + evm_latest = [c for c in all_latest if c.settings["evmVersion"] == "paris"][0] + true_latest = [ + c + for c in all_latest + if c.settings["optimize"] is True and c.settings["evmVersion"] != "paris" + ][0] + vyper_028 = [ + c for c in manifest.compilers if str(c.version) == str(OLDER_VERSION_FROM_PRAGMA) + ][0] + + for compiler in (vyper_028, codesize_latest, true_latest): + assert compiler.name == "vyper" + assert compiler.settings["evmVersion"] == "istanbul" + + # There is only one contract with codesize pragma. + assert codesize_latest.contractTypes == ["optimize_codesize"] + assert codesize_latest.settings["optimize"] == "codesize" + + # There is only one contract with evm-version pragma. + assert evm_latest.contractTypes == ["evm_pragma"] + assert evm_latest.settings["evmVersion"] == "paris" + + assert len(true_latest.contractTypes) >= 9 + assert len(vyper_028.contractTypes) >= 1 + assert "contract_0310" in true_latest.contractTypes + assert "older_version" in vyper_028.contractTypes + for compiler in (true_latest, vyper_028): + assert compiler.settings["optimize"] is True + + project.local_project.update_manifest(compilers=[]) + project.load_contracts(use_cache=False) + run_test(project.local_project.manifest) + man = project.extract_manifest() + run_test(man) def test_compile_parse_dev_messages(compiler, dev_revert_source, project): @@ -202,7 +223,7 @@ def test_pc_map(compiler, project, src, vers): path = project.contracts_folder / f"{src}.vy" result = compiler.compile([path], base_path=project.contracts_folder)[0] - actual = result.pcmap.__root__ + actual = result.pcmap.root code = path.read_text() compile_result = compile_source(code, vyper_version=vers, evm_version=compiler.evm_version)[ ""