Skip to content

Commit

Permalink
Add module to generate a graph from a model
Browse files Browse the repository at this point in the history
  • Loading branch information
Patrick Valsecchi committed Mar 16, 2017
1 parent 22ba5ff commit 1407587
Show file tree
Hide file tree
Showing 4 changed files with 98 additions and 2 deletions.
5 changes: 4 additions & 1 deletion acceptance_tests/app/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ ENV DEVELOPMENT 0

# Step #2 copy the rest of the files (watch for the .dockerignore)
COPY . /app
RUN python /app/setup.py install

RUN python ./setup.py install && \
./models_graph.py > models.dot && \
./models_graph.py Hello > models_hello.dot

CMD ["c2cwsgiutils_run"]
11 changes: 11 additions & 0 deletions acceptance_tests/app/models_graph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
#!/usr/bin/env python3
from c2cwsgiutils.models_graph import generate_model_graph

from c2cwsgiutils_app import models


def main():
generate_model_graph(models)


main()
82 changes: 82 additions & 0 deletions c2cwsgiutils/models_graph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import inspect
import sqlalchemy as sa
import sys


def generate_model_graph(module):
if len(sys.argv) == 1:
base_name = 'Base'
elif len(sys.argv) == 2:
base_name = sys.argv[1]
else:
print("Invalid parameters\nUsage: %s [base_class]" % sys.argv[0])
exit(1)

_generate_model_graph(module, getattr(module, base_name))


def _generate_model_graph(module, base):
print("""
digraph {
rankdir=BT;
""")

interesting = {
getattr(module, symbol_name)
for symbol_name in dir(module)
if _is_interesting(getattr(module, symbol_name), base)
}

for symbol in list(interesting):
symbol = getattr(module, symbol.__name__)
if _is_interesting(symbol, base):
_print_node(symbol, interesting)

print("}")


def _print_node(symbol, interesting):
print('%s [label="%s", shape=box];' % (symbol.__name__, _get_table_desc(symbol)))
for parent in symbol.__bases__:
if parent != object:
if parent not in interesting:
_print_node(parent, interesting)
interesting.add(parent)
print("%s -> %s;" % (symbol.__name__, parent.__name__))


def _is_interesting(what, base):
return inspect.isclass(what) and issubclass(what, base)


def _get_table_desc(symbol):
cols = [symbol.__name__, ""] + _get_local_cols(symbol)

return "\\n".join(cols)


def _get_all_cols(symbol):
cols = []

for member_name in symbol.__dict__:
member = getattr(symbol, member_name)
if member_name in ('__table__', 'metadata'):
pass
elif isinstance(member, sa.sql.schema.SchemaItem):
cols.append(member_name + ('[null]' if member.nullable else ''))
elif isinstance(member, sa.orm.attributes.InstrumentedAttribute):
nullable = member.property.columns[0].nullable \
if isinstance(member.property, sa.orm.ColumnProperty) \
else False
link = not isinstance(member.property, sa.orm.ColumnProperty)
cols.append(member_name + (' [null]' if nullable else '') + (' ->' if link else ''))

return cols


def _get_local_cols(symbol):
result = set(_get_all_cols(symbol))
for parent in symbol.__bases__:
result -= set(_get_all_cols(parent))

return sorted(list(result))
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from setuptools import setup, find_packages


version = '0.4.0'
version = '0.5.0'

setup(
name='c2cwsgiutils',
Expand Down

0 comments on commit 1407587

Please sign in to comment.