Skip to content

Commit

Permalink
Merge pull request #16 from aja114/aja114/allow_excluding_tables
Browse files Browse the repository at this point in the history
add include/exclude tables option when creating graph
  • Loading branch information
tedivm authored Jun 2, 2024
2 parents 62cb200 + dc899b6 commit 1a332cf
Show file tree
Hide file tree
Showing 8 changed files with 230 additions and 4 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ repos:
hooks:
- id: pytest
name: pytest
entry: make pytest_check
entry: make pytest
language: system
pass_filenames: false
- id: ruff
Expand Down
19 changes: 19 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,25 @@ This is equivalent to running this style of python import:
from example_app.models import *
```

### Include or Exclude tables

After importing the models, it is possible to select a subset of those models by using the `--exclude-tables` and `--include-tables` options.
These options are mutually exclusive: the user can provide either inclusions or exclusions, but not both:

```bash
paracelsus graph example_app.models.base:Base \
--import-module "example_app.models.*" \
--exclude-tables "comments"
```

This is equivalent to:

```bash
paracelsus graph example_app.models.base:Base \
--import-module "example_app.models.*" \
--include-tables "users" \
--include-tables "posts"
```

### Generate Mermaid Diagrams

Expand Down
20 changes: 20 additions & 0 deletions paracelsus/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,14 @@ def graph(
help="Module, typically an SQL Model, to import. Modules that end in :* will act as `from module import *`"
),
] = [],
exclude_tables: Annotated[
List[str],
typer.Option(help="List of tables that are excluded from the graph"),
] = [],
include_tables: Annotated[
List[str],
typer.Option(help="List of tables that are included in the graph"),
] = [],
python_dir: Annotated[
List[Path],
typer.Option(
Expand All @@ -66,6 +74,8 @@ def graph(
get_graph_string(
base_class_path=base_class,
import_module=import_module,
include_tables=set(include_tables),
exclude_tables=set(exclude_tables),
python_dir=python_dir,
format=format.value,
)
Expand Down Expand Up @@ -102,6 +112,14 @@ def inject(
help="Module, typically an SQL Model, to import. Modules that end in :* will act as `from module import *`"
),
] = [],
exclude_tables: Annotated[
List[str],
typer.Option(help="List of tables that are excluded from the graph"),
] = [],
include_tables: Annotated[
List[str],
typer.Option(help="List of tables that are included in the graph"),
] = [],
python_dir: Annotated[
List[Path],
typer.Option(
Expand All @@ -127,6 +145,8 @@ def inject(
graph = get_graph_string(
base_class_path=base_class_path,
import_module=import_module,
include_tables=set(include_tables),
exclude_tables=set(exclude_tables),
python_dir=python_dir,
format=format.value,
)
Expand Down
62 changes: 60 additions & 2 deletions paracelsus/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
import os
import sys
from pathlib import Path
from typing import List
from typing import List, Set

from sqlalchemy import MetaData

from .transformers.dot import Dot
from .transformers.mermaid import Mermaid
Expand All @@ -16,8 +18,11 @@


def get_graph_string(
*,
base_class_path: str,
import_module: List[str],
include_tables: Set[str],
exclude_tables: Set[str],
python_dir: List[Path],
format: str,
) -> str:
Expand Down Expand Up @@ -48,5 +53,58 @@ def get_graph_string(
raise ValueError(f"Unknown Format: {format}")
transformer = transformers[format]

# Keep only the tables which were included / not-excluded
include_tables = resolve_included_tables(
include_tables=include_tables, exclude_tables=exclude_tables, all_tables=set(metadata.tables.keys())
)
filtered_metadata = filter_metadata(metadata=metadata, include_tables=include_tables)

# Save the graph structure to string.
return str(transformer(metadata))
return str(transformer(filtered_metadata))


def resolve_included_tables(
    include_tables: Set[str],
    exclude_tables: Set[str],
    all_tables: Set[str],
) -> Set[str]:
    """Resolve the final set of tables to include in the graph.

    Given the requested inclusions, the requested exclusions and the set of
    all known tables, the outcome is:

    - Empty inclusion and empty exclusion -> all tables.
    - Empty inclusion and some exclusions -> all tables except the excluded ones.
    - Some inclusions and empty exclusion -> exactly the included tables, after
      checking that every one of them is present in ``all_tables``.
    - Some inclusions and some exclusions -> not resolvable, an error is raised.

    Args:
        include_tables: Names of the tables that must be kept.
        exclude_tables: Names of the tables that must be dropped.
        all_tables: Names of every table that was found.

    Returns:
        The set of table names to keep.

    Raises:
        ValueError: If both ``include_tables`` and ``exclude_tables`` contain
            values, or if ``include_tables`` names a table that is not in
            ``all_tables``.
    """
    # The two options are mutually exclusive; reject the ambiguous case first.
    if include_tables and exclude_tables:
        raise ValueError(
            f"Only one or none of include_tables ({include_tables}) or exclude_tables "
            f"({exclude_tables}) can contain values."
        )

    if exclude_tables:
        return all_tables - exclude_tables

    if not include_tables:
        return all_tables

    # Every explicitly requested table must actually exist.
    non_existent_tables = include_tables - all_tables
    if non_existent_tables:
        raise ValueError(
            f"Some tables to include ({non_existent_tables}) don't exist "
            f"within the found tables ({all_tables})."
        )
    return include_tables


def filter_metadata(
    metadata: MetaData,
    include_tables: Set[str],
) -> MetaData:
    """Create a subset of the metadata based on the tables to include.

    Args:
        metadata: The metadata whose tables are candidates for the subset.
        include_tables: Names of the tables to carry over.

    Returns:
        A new ``MetaData`` holding a copy of every table from ``metadata``
        whose name is in ``include_tables``.
    """
    filtered_metadata = MetaData()
    for tablename, table in metadata.tables.items():
        if tablename in include_tables:
            # `to_metadata` is the current name of the deprecated
            # `tometadata` (renamed in SQLAlchemy 1.4).
            table.to_metadata(filtered_metadata)
    return filtered_metadata
14 changes: 14 additions & 0 deletions paracelsus/transformers/dot.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
import pydot # type: ignore
import logging
from sqlalchemy.sql.schema import MetaData, Table

from . import utils


logger = logging.getLogger(__name__)


class Dot:
comment_format: str = "dot"
metadata: MetaData
Expand All @@ -24,6 +28,16 @@ def __init__(self, metaclass: MetaData) -> None:
key_parts = foreign_key.target_fullname.split(".")
left_table = key_parts[0]
left_column = key_parts[1]

# We don't add the connection to the fk table if the latter
# is not included in our graph.
if left_table not in self.metadata.tables:
logger.warning(
f"Table '{table}.{column.name}' is a foreign key to '{left_table}' "
"which is not included in the graph, skipping the connection."
)
continue

edge = pydot.Edge(left_table, table.name)
edge.set_label(column.name)
edge.set_dir("both")
Expand Down
13 changes: 13 additions & 0 deletions paracelsus/transformers/mermaid.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
import logging
from sqlalchemy.sql.schema import Column, MetaData, Table

from . import utils


logger = logging.getLogger(__name__)


class Mermaid:
comment_format: str = "mermaid"
metadata: MetaData
Expand Down Expand Up @@ -64,6 +68,15 @@ def _relationships(self, column: Column) -> str:
left_column = key_parts[1]
left_operand = ""

# We don't add the connection to the fk table if the latter
# is not included in our graph.
if left_table not in self.metadata.tables:
logger.warning(
f"Table '{right_table}.{column_name}' is a foreign key to '{left_table}' "
"which is not included in the graph, skipping the connection."
)
continue

lcolumn = self.metadata.tables[left_table].columns[left_column]
if lcolumn.unique or lcolumn.primary_key:
left_operand = "||"
Expand Down
38 changes: 38 additions & 0 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,44 @@ def test_graph(package_path):
mermaid_assert(result.stdout)


def test_graph_with_exclusion(package_path):
    """The `graph` command omits tables passed via `--exclude-tables`."""
    cli_args = [
        "graph",
        "example.base:Base",
        "--import-module",
        "example.models",
        "--python-dir",
        str(package_path),
        "--exclude-tables",
        "comments",
    ]
    result = runner.invoke(app, cli_args)

    assert result.exit_code == 0
    output = result.stdout
    assert "posts {" in output
    assert "comments {" not in output


def test_graph_with_inclusion(package_path):
    """The `graph` command keeps only tables passed via `--include-tables`."""
    cli_args = [
        "graph",
        "example.base:Base",
        "--import-module",
        "example.models",
        "--python-dir",
        str(package_path),
        "--include-tables",
        "comments",
    ]
    result = runner.invoke(app, cli_args)

    assert result.exit_code == 0
    output = result.stdout
    assert "posts {" not in output
    assert "comments {" in output


def test_inject_check(package_path):
result = runner.invoke(
app,
Expand Down
66 changes: 65 additions & 1 deletion tests/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,69 @@


def test_get_graph_string(package_path):
    """With no inclusions or exclusions the full mermaid graph is produced."""
    # `get_graph_string` only accepts keyword arguments, so every argument is
    # passed by name.
    graph_string = get_graph_string(
        base_class_path="example.base:Base",
        import_module=["example.models"],
        include_tables=set(),
        exclude_tables=set(),
        python_dir=[package_path],
        format="mermaid",
    )
    mermaid_assert(graph_string)


def test_get_graph_string_with_exclude(package_path):
    """Excluding tables removes them from the graph string."""
    graph_string = get_graph_string(
        base_class_path="example.base:Base",
        import_module=["example.models"],
        include_tables=set(),
        exclude_tables={"comments"},
        python_dir=[package_path],
        format="mermaid",
    )
    assert "comments {" not in graph_string
    assert "posts {" in graph_string
    assert "users {" in graph_string
    assert "users ||--o{ posts" in graph_string

    # Excluding a table to which another table holds a foreign key does not
    # raise: the referencing table is kept and only the connection is skipped.
    graph_string = get_graph_string(
        base_class_path="example.base:Base",
        import_module=["example.models"],
        include_tables=set(),
        exclude_tables={"users", "comments"},
        python_dir=[package_path],
        format="mermaid",
    )
    assert "posts {" in graph_string
    assert "users ||--o{ posts" not in graph_string


def test_get_graph_string_with_include(package_path):
    """Including tables keeps only those tables in the graph string."""
    graph_string = get_graph_string(
        base_class_path="example.base:Base",
        import_module=["example.models"],
        include_tables={"users", "posts"},
        exclude_tables=set(),
        python_dir=[package_path],
        format="mermaid",
    )
    assert "comments {" not in graph_string
    assert "posts {" in graph_string
    assert "users {" in graph_string
    assert "users ||--o{ posts" in graph_string

    # Including a table that holds a foreign key to a table left out of the
    # graph keeps the table but skips the connection.
    graph_string = get_graph_string(
        base_class_path="example.base:Base",
        import_module=["example.models"],
        include_tables={"posts"},
        exclude_tables=set(),
        python_dir=[package_path],
        format="mermaid",
    )
    assert "posts {" in graph_string
    assert "users ||--o{ posts" not in graph_string

0 comments on commit 1a332cf

Please sign in to comment.