diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 7b766ee..c96db1d 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -45,6 +45,9 @@ jobs: - name: Linter run: | pycodestyle --filename=protobuf-uml-diagram --exclude=.git,__pycache__,.tox,venv,protobuf_uml_diagram.egg-info,.pytest_cache --max-line-length=120 + - name: MyPy + run: | + mypy protobuf_uml_diagram.py - name: Unit tests with coverage run: | coverage run -p setup.py test diff --git a/protobuf_uml_diagram.py b/protobuf_uml_diagram.py index b591d77..d669f6c 100644 --- a/protobuf_uml_diagram.py +++ b/protobuf_uml_diagram.py @@ -19,38 +19,38 @@ import logging from importlib import import_module from io import StringIO +from os import PathLike from pathlib import Path from string import Template from types import ModuleType -from typing import List, Tuple, Union +from typing import cast, List, Optional, Tuple, Union import click from google.protobuf.descriptor import Descriptor, FieldDescriptor from google.protobuf.descriptor_pb2 import FieldDescriptorProto -from graphviz import Source +from graphviz import Source # type: ignore # TODO: https://github.com/xflr6/graphviz/issues/203 logger = logging.getLogger(__name__) logging.basicConfig(level=logging.INFO) -Text = Union[str, bytes] - # https://github.com/pallets/click/issues/405#issuecomment-470812067 class PathPath(click.Path): """A Click path argument that returns a pathlib Path, not a string""" - def convert(self, value: Text, param: Text, ctx) -> Path: + def convert(self, value: Union[str, PathLike[str]], param: Optional[click.Parameter], ctx: Union[click.Context, None]) -> Path: """Convert a text parameter into a ``Path`` object. :param value: parameter value - :type value: Text + :type value: str :param param: parameter name - :type param: Text + :type param: str :param ctx: context - :type ctx: object + :type ctx: click.Context :return: a ``Path`` object :rtype: Path """ - return Path(super().convert(value, param, ctx)) + p = super().convert(value, param, ctx) + return Path(cast(Path, p)) # -- UML diagram @@ -73,15 +73,15 @@ def _process_module(proto_module: ModuleType) -> Tuple[List[str], List[str]]: :return: list of descriptors :rtype: List[Descriptor] """ - classes = [] - relationships = [] + classes: List[str] = [] + relationships: List[str] = [] for type_name, type_descriptor in proto_module.DESCRIPTOR.message_types_by_name.items(): _process_descriptor(type_descriptor, classes, relationships) return classes, relationships -def _process_descriptor(descriptor: Descriptor, classes: list, - relationships: list) -> None: +def _process_descriptor(descriptor: Descriptor, classes: List[str], + relationships: List[str]) -> None: """ :param descriptor: a Protobuf descriptor :type descriptor: Descriptor @@ -173,8 +173,8 @@ def _module(proto: str) -> ModuleType: class Diagram: """A diagram builder.""" - _proto_module: ModuleType = None - _rendered_filename: str = None + _proto_module: Union[ModuleType, None] = None + _rendered_filename: Union[str, None] = None _file_format = "png" def from_file(self, proto_file: str): @@ -191,6 +191,8 @@ def from_file(self, proto_file: str): def to_file(self, output: Path): if not output: raise ValueError("Missing output location!") + if not self._proto_module or not self._proto_module.__file__: + raise ValueError("Missing protobuf module!") uml_file = Path(self._proto_module.__file__).stem self._rendered_filename = str(output.joinpath(uml_file)) return self diff --git a/setup.py b/setup.py index 76612fe..43f67b4 100644 --- a/setup.py +++ b/setup.py @@ -15,7 +15,7 @@ import codecs from os.path import join, dirname, abspath -from setuptools import setup +from setuptools import setup # type: ignore here = abspath(dirname(__file__)) @@ -35,6 +35,11 @@ def read(*parts): 'pytest-runner>=4.1,<6.1' ] +mypy_requires = [ + 'mypy==1.7.*', + 'types-protobuf==4.24.*' +] + tests_require = [ 'codecov==2.1.*', 'coverage>=5.3,<7.4', @@ -46,7 +51,7 @@ def read(*parts): extras_require = { 'tests': tests_require, - 'all': install_requires + tests_require + 'all': install_requires + tests_require + mypy_requires } setup( diff --git a/tests.py b/tests.py index 1da1734..b63ecb3 100644 --- a/tests.py +++ b/tests.py @@ -37,12 +37,12 @@ def test_from_file_raises(self): def test_to_file_raises(self): with pytest.raises(ValueError) as e: - Diagram().to_file(None) + Diagram().to_file(None) # type: ignore assert 'Missing output location' in str(e.value) def test_with_format_raises(self): with pytest.raises(ValueError) as e: - Diagram().with_format(None) + Diagram().with_format(None) # type: ignore assert 'Missing file' in str(e.value) def test_build_raises(self):