diff --git a/.gitignore b/.gitignore index 61626e4..071f852 100644 --- a/.gitignore +++ b/.gitignore @@ -49,6 +49,7 @@ coverage.xml *.py,cover .hypothesis/ .pytest_cache/ +.ruff_cache/ cover/ # Translations diff --git a/README.md b/README.md index 4d1bf48..e1fd4ee 100644 --- a/README.md +++ b/README.md @@ -2,11 +2,24 @@ Paracelsus generates Entity Relationship Diagrams by reading your SQLAlchemy models. -* ERDs can be injected into documentation as [Mermaid Diagrams](https://mermaid.js.org/). -* Paracelsus can be run in CICD to check that databases are up to date. -* ERDs can be created as files in either [Dot](https://graphviz.org/doc/info/lang.html) or Mermaid format. -* DOT files can be used to generate SVG or PNG files, or edited in [GraphViz](https://graphviz.org/) or other editors. - +- [Paracelsus](#paracelsus) + - [Features](#features) + - [Usage](#usage) + - [Installation](#installation) + - [Basic CLI Usage](#basic-cli-usage) + - [Importing Models](#importing-models) + - [Generate Mermaid Diagrams](#generate-mermaid-diagrams) + - [Inject Mermaid Diagrams](#inject-mermaid-diagrams) + - [Creating Images](#creating-images) + - [pyproject.toml](#pyprojecttoml) + - [Sponsorship](#sponsorship) + +## Features + +- ERDs can be injected into documentation as [Mermaid Diagrams](https://mermaid.js.org/). +- Paracelsus can be run in CICD to check that databases are up to date. +- ERDs can be created as files in either [Dot](https://graphviz.org/doc/info/lang.html) or Mermaid format. +- DOT files can be used to generate SVG or PNG files, or edited in [GraphViz](https://graphviz.org/) or other editors. ## Usage @@ -29,9 +42,9 @@ paracelsus --help It has three commands: -* `version` outputs the version of the currently installed `paracelsus` cli. -* `graph` generates a graph and outputs it to `stdout`. -* `inject` inserts the graph into a markdown file. +- `version` outputs the version of the currently installed `paracelsus` cli. +- `graph` generates a graph and outputs it to `stdout`. +- `inject` inserts the graph into a markdown file. ### Importing Models @@ -161,6 +174,18 @@ To create a PNG file: ![Alt text](./docs/example.png "a title") +### pyproject.toml + +The settings for your project can be saved directly in the `pyprojects.toml` file of your project. + +```toml +[tool.paracelsus] +base = "example.base:Base" +imports = [ + "example.models" +] +``` + ## Sponsorship This project is developed by [Robert Hafner](https://blog.tedivm.com) If you find this project useful please consider sponsoring me using Github! diff --git a/paracelsus/cli.py b/paracelsus/cli.py index 115083a..0cd423e 100644 --- a/paracelsus/cli.py +++ b/paracelsus/cli.py @@ -1,26 +1,17 @@ -import importlib -import os import re import sys from enum import Enum from pathlib import Path -from typing import List +from typing import Any, Dict, List, Optional import typer from typing_extensions import Annotated -from .transformers.dot import Dot -from .transformers.mermaid import Mermaid +from .graph import get_graph_string, transformers +from .pyproject import get_pyproject_settings app = typer.Typer() -transformers = { - "mmd": Mermaid, - "mermaid": Mermaid, - "dot": Dot, - "gv": Dot, -} - class Formats(str, Enum): mermaid = "mermaid" @@ -29,49 +20,22 @@ class Formats(str, Enum): gv = "gv" -def get_graph_string( - base_class_path: str, - import_module: List[str], - python_dir: List[Path], - format: str, -) -> str: - # Update the PYTHON_PATH to allow more module imports. - sys.path.append(str(os.getcwd())) - for dir in python_dir: - sys.path.append(str(dir)) - - # Import the base class so the metadata class can be extracted from it. - # The metadata class is passed to the transformer. - module_path, class_name = base_class_path.split(":", 2) - base_module = importlib.import_module(module_path) - base_class = getattr(base_module, class_name) - metadata = base_class.metadata - - # The modules holding the model classes have to be imported to get put in the metaclass model registry. - # These modules aren't actually used in any way, so they are discarded. - # They are also imported in scope of this function to prevent namespace pollution. - for module in import_module: - if ":*" in module: - # Sure, execs are gross, but this is the only way to dynamically import wildcards. - exec(f"from {module[:-2]} import *") - else: - importlib.import_module(module) - - # Grab a transformer. - if format not in transformers: - raise ValueError(f"Unknown Format: {format}") - transformer = transformers[format] - - # Save the graph structure to string. - return str(transformer(metadata)) +def get_base_class(base_class_path: str | None, settings=Dict[str, Any] | None) -> str: + if base_class_path: + return base_class_path + if not settings: + raise ValueError("`base_class_path` argument must be passed if no pyproject.toml file is present.") + if "base" not in settings: + raise ValueError("`base_class_path` argument must be passed if not defined in pyproject.toml.") + return settings["base"] @app.command(help="Create the graph structure and print it to stdout.") def graph( base_class_path: Annotated[ - str, + Optional[str], typer.Argument(help="The SQLAlchemy base class used by the database to graph."), - ], + ] = None, import_module: Annotated[ List[str], typer.Option( @@ -92,9 +56,15 @@ def graph( Formats, typer.Option(help="The file format to output the generated graph to.") ] = Formats.mermaid, ): + settings = get_pyproject_settings() + base_class = get_base_class(base_class_path, settings) + + if settings and "imports" in settings: + import_module.extend(settings["imports"]) + typer.echo( get_graph_string( - base_class_path=base_class_path, + base_class_path=base_class, import_module=import_module, python_dir=python_dir, format=format.value, diff --git a/paracelsus/graph.py b/paracelsus/graph.py new file mode 100644 index 0000000..88cae36 --- /dev/null +++ b/paracelsus/graph.py @@ -0,0 +1,52 @@ +import importlib +import os +import sys +from pathlib import Path +from typing import List + +from .transformers.dot import Dot +from .transformers.mermaid import Mermaid + +transformers = { + "mmd": Mermaid, + "mermaid": Mermaid, + "dot": Dot, + "gv": Dot, +} + + +def get_graph_string( + base_class_path: str, + import_module: List[str], + python_dir: List[Path], + format: str, +) -> str: + # Update the PYTHON_PATH to allow more module imports. + sys.path.append(str(os.getcwd())) + for dir in python_dir: + sys.path.append(str(dir)) + + # Import the base class so the metadata class can be extracted from it. + # The metadata class is passed to the transformer. + module_path, class_name = base_class_path.split(":", 2) + base_module = importlib.import_module(module_path) + base_class = getattr(base_module, class_name) + metadata = base_class.metadata + + # The modules holding the model classes have to be imported to get put in the metaclass model registry. + # These modules aren't actually used in any way, so they are discarded. + # They are also imported in scope of this function to prevent namespace pollution. + for module in import_module: + if ":*" in module: + # Sure, execs are gross, but this is the only way to dynamically import wildcards. + exec(f"from {module[:-2]} import *") + else: + importlib.import_module(module) + + # Grab a transformer. + if format not in transformers: + raise ValueError(f"Unknown Format: {format}") + transformer = transformers[format] + + # Save the graph structure to string. + return str(transformer(metadata)) diff --git a/paracelsus/pyproject.py b/paracelsus/pyproject.py new file mode 100644 index 0000000..7bd81c6 --- /dev/null +++ b/paracelsus/pyproject.py @@ -0,0 +1,16 @@ +import os +import tomllib +from pathlib import Path +from typing import Any, Dict + + +def get_pyproject_settings(dir: Path = Path(os.getcwd())) -> Dict[str, Any] | None: + pyproject = dir / "pyproject.toml" + + if not pyproject.exists(): + return None + + with open(pyproject, "rb") as f: + data = tomllib.load(f) + + return data.get("tool", {}).get("paracelsus", None) diff --git a/tests/assets/README.md b/tests/assets/README.md new file mode 100644 index 0000000..ad2c459 --- /dev/null +++ b/tests/assets/README.md @@ -0,0 +1,8 @@ +# Test Directory + +Please ignore. + +## Schema + + + diff --git a/tests/assets/example/__init__.py b/tests/assets/example/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/assets/example/base.py b/tests/assets/example/base.py new file mode 100644 index 0000000..59be703 --- /dev/null +++ b/tests/assets/example/base.py @@ -0,0 +1,3 @@ +from sqlalchemy.orm import declarative_base + +Base = declarative_base() diff --git a/tests/assets/example/models.py b/tests/assets/example/models.py new file mode 100644 index 0000000..47c17d1 --- /dev/null +++ b/tests/assets/example/models.py @@ -0,0 +1,37 @@ +from datetime import UTC, datetime +from uuid import uuid4 + +from sqlalchemy import Boolean, DateTime, ForeignKey, String, Text, Uuid +from sqlalchemy.orm import mapped_column + +from .base import Base + + +class User(Base): + __tablename__ = "users" + + id = mapped_column(Uuid, primary_key=True, default=uuid4()) + display_name = mapped_column(String(100)) + created = mapped_column(DateTime, nullable=False, default=datetime.now(UTC)) + + +class Post(Base): + __tablename__ = "posts" + + id = mapped_column(Uuid, primary_key=True, default=uuid4()) + author = mapped_column(ForeignKey(User.id), nullable=False) + created = mapped_column(DateTime, nullable=False, default=datetime.now(UTC)) + live = mapped_column(Boolean, default=False) + content = mapped_column(Text, default="") + + +class Comment(Base): + __tablename__ = "comments" + + id = mapped_column(Uuid, primary_key=True, default=uuid4()) + post = mapped_column(Uuid, ForeignKey(Post.id), default=uuid4()) + author = mapped_column(ForeignKey(User.id), nullable=False) + created = mapped_column(DateTime, nullable=False, default=datetime.now(UTC)) + live = mapped_column(Boolean, default=False) + content = mapped_column(Text, default="") + content = mapped_column(Text, default="") diff --git a/tests/assets/pyproject.toml b/tests/assets/pyproject.toml new file mode 100644 index 0000000..53cd70b --- /dev/null +++ b/tests/assets/pyproject.toml @@ -0,0 +1,5 @@ +[tool.paracelsus] +base = "example.base:Base" +imports = [ + "example.models" +] diff --git a/tests/conftest.py b/tests/conftest.py index 5430302..216493e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,6 @@ -from datetime import datetime +import os +from datetime import UTC, datetime +from pathlib import Path from uuid import uuid4 import pytest @@ -15,14 +17,14 @@ class User(Base): id = mapped_column(Uuid, primary_key=True, default=uuid4()) display_name = mapped_column(String(100)) - created = mapped_column(DateTime, nullable=False, default=datetime.utcnow()) + created = mapped_column(DateTime, nullable=False, default=datetime.now(UTC)) class Post(Base): __tablename__ = "posts" id = mapped_column(Uuid, primary_key=True, default=uuid4()) author = mapped_column(ForeignKey(User.id), nullable=False) - created = mapped_column(DateTime, nullable=False, default=datetime.utcnow()) + created = mapped_column(DateTime, nullable=False, default=datetime.now(UTC)) live = mapped_column(Boolean, default=False, comment="True if post is published") content = mapped_column(Text, default="") @@ -32,8 +34,13 @@ class Comment(Base): id = mapped_column(Uuid, primary_key=True, default=uuid4()) post = mapped_column(Uuid, ForeignKey(Post.id), default=uuid4()) author = mapped_column(ForeignKey(User.id), nullable=False) - created = mapped_column(DateTime, nullable=False, default=datetime.utcnow()) + created = mapped_column(DateTime, nullable=False, default=datetime.now(UTC)) live = mapped_column(Boolean, default=False) content = mapped_column(Text, default="") return Base.metadata + + +@pytest.fixture +def package_path(): + return Path(os.path.dirname(os.path.realpath(__file__))) / "assets" diff --git a/tests/test_cli.py b/tests/test_cli.py new file mode 100644 index 0000000..f51ef0a --- /dev/null +++ b/tests/test_cli.py @@ -0,0 +1,49 @@ +from typer.testing import CliRunner + +from paracelsus.cli import app + +runner = CliRunner() + + +def test_graph(package_path): + result = runner.invoke( + app, ["graph", "example.base:Base", "--import-module", "example.models", "--python-dir", str(package_path)] + ) + + assert result.exit_code == 0 + + assert "users {" in result.stdout + assert "posts {" in result.stdout + assert "comments {" in result.stdout + + assert "users ||--o{ posts : author" in result.stdout + assert "posts ||--o{ comments : post" in result.stdout + assert "users ||--o{ comments : author" in result.stdout + + assert "CHAR(32) author FK" in result.stdout + assert 'CHAR(32) post FK "nullable"' in result.stdout + assert "DATETIME created" in result.stdout + + +def test_inject(package_path): + result = runner.invoke( + app, + [ + "inject", + str(package_path / "README.md"), + "example.base:Base", + "--import-module", + "example.models", + "--python-dir", + str(package_path), + "--check", + ], + ) + + assert result.exit_code == 1 + + +def test_version(): + result = runner.invoke(app, ["version"]) + + assert result.exit_code == 0 diff --git a/tests/test_graph.py b/tests/test_graph.py new file mode 100644 index 0000000..f588ad0 --- /dev/null +++ b/tests/test_graph.py @@ -0,0 +1,17 @@ +from paracelsus.graph import get_graph_string + + +def test_get_graph_string(package_path): + graph_string = get_graph_string("example.base:Base", ["example.models"], [package_path], "mermaid") + + assert "users {" in graph_string + assert "posts {" in graph_string + assert "comments {" in graph_string + + assert "users ||--o{ posts : author" in graph_string + assert "posts ||--o{ comments : post" in graph_string + assert "users ||--o{ comments : author" in graph_string + + assert "CHAR(32) author FK" in graph_string + assert 'CHAR(32) post FK "nullable"' in graph_string + assert "DATETIME created" in graph_string diff --git a/tests/test_pyproject.py b/tests/test_pyproject.py new file mode 100644 index 0000000..4a6a223 --- /dev/null +++ b/tests/test_pyproject.py @@ -0,0 +1,14 @@ +from paracelsus.pyproject import get_pyproject_settings + + +def test_pyproject(package_path): + settings = get_pyproject_settings(package_path) + assert "base" in settings + assert "imports" in settings + assert settings["base"] == "example.base:Base" + assert "example.models" in settings["imports"] + + +def test_pyproject_none(): + settings = get_pyproject_settings() + assert settings is None