diff --git a/pypath/formats/sqldump.py b/pypath/formats/sqldump.py
index de64a135e..6b8342116 100644
--- a/pypath/formats/sqldump.py
+++ b/pypath/formats/sqldump.py
@@ -23,7 +23,7 @@
 Processing MySQL dump files.
 """
 
-from typing import IO, Iterable
+from typing import Generator, IO, Iterable
 import os
 import collections
 import re
@@ -49,7 +49,7 @@ def tables(
         con_param: dict | None = None,
         source_id: str | tuple | None = None,
         **kwargs
-    ) -> tuple[dict, dict] | dict[str, pd.DataFrame] | sqlite3.Connection:
+    ) -> dict[str, list[tuple] | pd.DataFrame] | sqlite3.Connection:
     """
     From a SQL dump, retrieve tables as data frames or as lists of tuples.
 
@@ -91,17 +91,16 @@ def tables(
     else:
 
         contents = {}
-        headers = {}
 
         for table in tables:
 
-            header, content = table_from_cache(table, source_id)
+            content = table_from_cache(table, source_id)
 
             if not content:
 
                 seek(sqldump, table)
-                header, content = one_table(sqldump)
-                table_to_cache(header, content, table, source_id)
+                content = list(one_table(sqldump))
+                table_to_cache(content, table, source_id)
 
             if not content:
 
@@ -109,7 +108,11 @@ def tables(
 
             if return_df or return_sqlite:
 
-                content = pd.DataFrame.from_records(content, columns = header)
+                content = pd.DataFrame.from_records(
+                    content,
+                    columns = content[0]._fields,
+                )
+
                 _log(
                     f'Data frame of table `{table}`: '
                     f'{content.shape[0]} records, {common.df_memory_usage(content)}.'
@@ -130,7 +133,6 @@ def tables(
             else:
 
                 contents[table] = content
-                headers[table] = None if return_df else header
 
             tables_done.add(table)
 
@@ -138,13 +140,17 @@ def tables(
 
                 break
 
-    return contents if return_df or return_sqlite else (headers, contents)
+    if not return_sqlite and len(contents) == 1:
+
+        contents = common.first(contents.values())
+
+    return contents
 
 
 def table_from_cache(
         table: str,
         source_id: tuple,
-    ) -> tuple[list, list]:
+    ) -> list[tuple]:
     """
     Retrieve table contents from cache.
     """
@@ -161,14 +167,18 @@ def table_from_cache(
                 f'Cache path: `{cache_path}`.'
            )
 
-            return pickle.load(fp)
+            name, fields, content = pickle.load(fp)
 
-    return [], []
+            record = collections.namedtuple(name, fields)
+            content = [record(*r) for r in content]
+
+            return content
+
+    return []
 
 
 def table_to_cache(
-        header: dict,
-        content: dict,
+        content: list[tuple],
         table: str,
         source_id: tuple,
     ) -> None:
@@ -182,30 +192,37 @@ def table_to_cache(
     with open(cache_path, 'wb') as fp:
 
         _log(f'Saving table `{table}` to cache. Cache path: `{cache_path}`.')
-        pickle.dump((header, content), fp)
+        pickle.dump(
+            (
+                content[0].__class__.__name__,
+                content[0]._fields,
+                [tuple(r) for r in content]
+            ),
+            fp,
+        )
 
 
-def one_table(sqldump: IO) -> tuple[list[str], list[tuple]]:
+def one_table(sqldump: IO) -> Generator[tuple]:
     """
     Retrieve one table from a MySQL dump.
""" parser = sqlparse.parsestream(sqldump) - header = [] - content = [] - inserts = 0 + n_inserts = 0 + n_records = 0 + record = None table_name = the_table_name = None for statement in parser: if statement.get_type() == 'INSERT': - inserts += 1 + n_inserts += 1 sublists = statement.get_sublists() table_info = next(sublists) table_name = table_info.get_name() - if inserts == 1: + if n_inserts == 1: the_table_name = table_name _log(f'Processing table: `{table_name}`') @@ -214,28 +231,30 @@ def one_table(sqldump: IO) -> tuple[list[str], list[tuple]]: break - header = [ - col.get_name() - for col in table_info.get_parameters() - ] + if record is None: - content.extend( - tuple( + record = collections.namedtuple( + f'Ramp{table_name.capitalize()}', + [col.get_name() for col in table_info.get_parameters()], + ) + globals()[record.__class__.__name__] = record + + for rec in next(sublists).get_sublists(): + + n_records += 1 + + yield record(*( s.value.strip('"\'') for s in next(rec.get_sublists()).get_identifiers() - ) - for rec in next(sublists).get_sublists() - ) + )) if table_name: _log( - f'Processed {len(content)} records in {inserts} INSERT statements ' + f'Processed {n_records} records in {n_inserts} INSERT statements ' f'from table `{the_table_name}`.' ) - return header, content - def get_source_id( sqldump: str | IO,