formats.sqldump: one_table is a generator by default
deeenes committed Oct 17, 2024
1 parent 548584f commit 2c676f2
Showing 1 changed file with 52 additions and 33 deletions.
pypath/formats/sqldump.py: 85 changes (52 additions & 33 deletions)
@@ -23,7 +23,7 @@
 Processing MySQL dump files.
 """

-from typing import IO, Iterable
+from typing import Generator, IO, Iterable
 import os
 import collections
 import re
@@ -49,7 +49,7 @@ def tables(
         con_param: dict | None = None,
         source_id: str | tuple | None = None,
         **kwargs
-) -> tuple[dict, dict] | dict[str, pd.DataFrame] | sqlite3.Connection:
+) -> dict[str, list[tuple] | pd.DataFrame] | sqlite3.Connection:
     """
     From a SQL dump, retrieve tables as data frames or as lists of tuples.
@@ -91,25 +91,28 @@ def tables(
     else:

         contents = {}
-        headers = {}

         for table in tables:

-            header, content = table_from_cache(table, source_id)
+            content = table_from_cache(table, source_id)

             if not content:

                 seek(sqldump, table)
-                header, content = one_table(sqldump)
-                table_to_cache(header, content, table, source_id)
+                content = list(one_table(sqldump))
+                table_to_cache(content, table, source_id)

             if not content:

                 _log(f'Table `{table}` not found.')

             if return_df or return_sqlite:

-                content = pd.DataFrame.from_records(content, columns = header)
+                content = pd.DataFrame.from_records(
+                    content,
+                    columns = content[0]._fields,
+                )

                 _log(
                     f'Data frame of table `{table}`: '
                     f'{content.shape[0]} records, {common.df_memory_usage(content)}.'
@@ -130,21 +133,24 @@ def tables(
            else:

                contents[table] = content
-               headers[table] = None if return_df else header

            tables_done.add(table)

            if not tables - tables_done:

                break

-    return contents if return_df or return_sqlite else (headers, contents)
+    if not return_sqlite and len(contents) == 1:
+
+        contents = common.first(contents.values())
+
+    return contents


 def table_from_cache(
         table: str,
         source_id: tuple,
-) -> tuple[list, list]:
+) -> list[tuple]:
     """
     Retrieve table contents from cache.
     """
@@ -161,14 +167,18 @@ def table_from_cache(
                f'Cache path: `{cache_path}`.'
            )

-           return pickle.load(fp)
+           name, fields, content = pickle.load(fp)

-    return [], []
+           record = collections.namedtuple(name, fields)
+           content = [record(*r) for r in content]
+
+           return content
+
+    return []


 def table_to_cache(
-        header: dict,
-        content: dict,
+        content: list[tuple],
         table: str,
         source_id: tuple,
 ) -> None:
@@ -182,30 +192,37 @@
     with open(cache_path, 'wb') as fp:

         _log(f'Saving table `{table}` to cache. Cache path: `{cache_path}`.')
-        pickle.dump((header, content), fp)
+        pickle.dump(
+            (
+                content[0].__class__.__name__,
+                content[0]._fields,
+                [tuple(r) for r in content]
+            ),
+            fp,
+        )


-def one_table(sqldump: IO) -> tuple[list[str], list[tuple]]:
+def one_table(sqldump: IO) -> Generator[tuple]:
     """
     Retrieve one table from a MySQL dump.
     """

     parser = sqlparse.parsestream(sqldump)
-    header = []
-    content = []
-    inserts = 0
+    n_inserts = 0
+    n_records = 0
+    record = None
     table_name = the_table_name = None

     for statement in parser:

         if statement.get_type() == 'INSERT':

-            inserts += 1
+            n_inserts += 1
             sublists = statement.get_sublists()
             table_info = next(sublists)
             table_name = table_info.get_name()

-            if inserts == 1:
+            if n_inserts == 1:

                 the_table_name = table_name
                 _log(f'Processing table: `{table_name}`')
@@ -214,28 +231,30 @@ def one_table(sqldump: IO) -> tuple[list[str], list[tuple]]:

                 break

-            header = [
-                col.get_name()
-                for col in table_info.get_parameters()
-            ]
+            if record is None:

-            content.extend(
-                tuple(
+                record = collections.namedtuple(
+                    f'Ramp{table_name.capitalize()}',
+                    [col.get_name() for col in table_info.get_parameters()],
+                )
+                globals()[record.__name__] = record
+
+            for rec in next(sublists).get_sublists():
+
+                n_records += 1
+
+                yield record(*(
                     s.value.strip('"\'')
                     for s in next(rec.get_sublists()).get_identifiers()
-                )
-                for rec in next(sublists).get_sublists()
-            )
+                ))

     if table_name:

         _log(
-            f'Processed {len(content)} records in {inserts} INSERT statements '
+            f'Processed {n_records} records in {n_inserts} INSERT statements '
             f'from table `{the_table_name}`.'
         )

-    return header, content


 def get_source_id(
         sqldump: str | IO,
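Usage note: with this change, `one_table` yields records lazily as namedtuples instead of returning a `(header, content)` pair, and `tables` materializes the stream with `list(one_table(sqldump))` where a full table is needed. A minimal consumption sketch, with the dump file name and table name invented for illustration, and assuming the module's own `seek` helper (used by `tables` above) positions the stream at the table's INSERT statements:

    from pypath.formats import sqldump

    with open('ramp_dump.sql') as fp:         # hypothetical dump file
        sqldump.seek(fp, 'analyte')           # jump to the `analyte` table
        for rec in sqldump.one_table(fp):     # records stream one at a time
            print(rec)                        # a namedtuple, e.g. RampAnalyte(...)
            break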

0 comments on commit 2c676f2
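Cache format note: namedtuple classes are created at runtime, so the records themselves are not pickled; `table_to_cache` stores the class name, the field names and plain tuples, and `table_from_cache` rebuilds the record type on load. A self-contained sketch of that round trip (table and field names invented for illustration):

    import collections
    import pickle

    RampAnalyte = collections.namedtuple('RampAnalyte', ('rampId', 'name'))
    content = [RampAnalyte('RAMP_C_000001', 'glucose')]

    # what table_to_cache writes:
    blob = pickle.dumps((
        content[0].__class__.__name__,    # 'RampAnalyte'
        content[0]._fields,               # ('rampId', 'name')
        [tuple(r) for r in content],      # plain tuples, safe to pickle
    ))

    # what table_from_cache does on load:
    name, fields, rows = pickle.loads(blob)
    record = collections.namedtuple(name, fields)
    restored = [record(*r) for r in rows]

    assert restored == content            # tuple equality holds across record types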
