Skip to content

Commit

Permalink
refactors resulset
Browse files Browse the repository at this point in the history
  • Loading branch information
edublancas committed Jul 12, 2023
1 parent 88fff1a commit c0a7548
Showing 1 changed file with 138 additions and 130 deletions.
268 changes: 138 additions & 130 deletions src/sql/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import sqlalchemy
import sqlparse
from sql.connection import Connection
from sqlalchemy.exc import ResourceClosedError
from sql import exceptions, display
from .column_guesser import ColumnGuesserMixin

Expand Down Expand Up @@ -109,24 +110,83 @@ class ResultSet(ColumnGuesserMixin):
Can access rows listwise, or by string value of leftmost column.
"""

def __init__(self, sqlaproxy, config, statement):
def __init__(self, sqlaproxy, config, statement=None):
self.config = config
self.keys = {}
self._results = []
self.truncated = False
self.sqlaproxy = sqlaproxy
self.statement = statement

self._keys = None
self._field_names = None
self._results = []

# https://peps.python.org/pep-0249/#description
self.is_dbapi_results = hasattr(sqlaproxy, "description")

self.pretty = None
# NOTE: this will trigger key fetching
self.pretty_table = self._init_table()

self._done_fetching = False

if self.config.autolimit == 1:
self.fetchmany(size=1)
self.did_finish_fetching()
# EXPLAIN WHY WE NEED TWO
else:
self.fetchmany(size=2)

def extend_results(self, elements):
self._results.extend(elements)

# NOTE: we shouold use add_rows but there is a subclass that behaves weird
# so using add_row for now
for e in elements:
self.pretty_table.add_row(e)

def done_fetching(self):
self._done_fetching = True
# self.sqlaproxy.close()

def did_finish_fetching(self):
return self._done_fetching

# NOTE: this triggers key fetching
@property
def field_names(self):
if self._field_names is None:
self._field_names = unduplicate_field_names(self.keys)

return self._field_names

# NOTE: this triggers key fetching
@property
def keys(self):
if self._keys is not None:
return self._keys

if not self.is_dbapi_results:
try:
self._keys = self.sqlaproxy.keys()
# sqlite raises this error when running a script that doesn't return rows
# e.g, 'CREATE TABLE' but others don't (e.g., duckdb)
except ResourceClosedError:
self._keys = []
return self._keys

elif isinstance(self.sqlaproxy.description, Iterable):
self._keys = [i[0] for i in self.sqlaproxy.description]
else:
self._keys = []

return self._keys

def _repr_html_(self):
self.fetch_for_repr_if_needed()

_cell_with_spaces_pattern = re.compile(r"(<td>)( {2,})")
if self.pretty:
self.pretty.add_rows(self)
result = self.pretty.get_html_string()
if self.pretty_table:
self.pretty_table.add_rows(self)
result = self.pretty_table.get_html_string()
HTML = (
"%s\n<span style='font-style:italic;font-size:11px'>"
"<code>ResultSet</code> : to convert to pandas, call <a href="
Expand All @@ -140,7 +200,7 @@ def _repr_html_(self):
# to create clickable links
result = html.unescape(result)
result = _cell_with_spaces_pattern.sub(_nonbreaking_spaces, result)
if self.truncated:
if self.config.displaylimit != 0:
HTML = (
'%s\n<span style="font-style:italic;text-align:center;">'
"Truncated to displaylimit of %d</span>"
Expand All @@ -150,7 +210,7 @@ def _repr_html_(self):
'<a href="https://jupysql.ploomber.io/en/latest/api/configuration.html#displaylimit">displaylimit</a>' # noqa: E501
" configuration</span>"
)
result = HTML % (result, self.pretty.row_count)
result = HTML % (result, self.config.displaylimit)
return result
else:
return None
Expand All @@ -159,18 +219,18 @@ def __len__(self):
return len(self._results)

def __iter__(self):
results = self._fetch_query_results(fetch_all=True)
self.fetchall()

for result in results:
for result in self._results:
yield result

def __str__(self, *arg, **kwarg):
if self.pretty:
self.pretty.add_rows(self)
return str(self.pretty or "")
def __str__(self):
self.fetch_for_repr_if_needed()
return str(self.pretty_table)

def __repr__(self) -> str:
return str(self)
self.fetch_for_repr_if_needed()
return str(self.pretty_table)

def __eq__(self, another: object) -> bool:
return self._results == another
Expand All @@ -190,17 +250,17 @@ def __getitem__(self, key):
raise KeyError('%d results for "%s"' % (len(result), key))
return result[0]

def __getattribute__(self, attr):
"Raises AttributeError when invalid operation is performed."
try:
return object.__getattribute__(self, attr)
except AttributeError:
err_msg = (
f"'{attr}' is not a valid operation, you can convert this "
"into a pandas data frame by calling '.DataFrame()' or a "
"polars data frame by calling '.PolarsDataFrame()'"
)
raise AttributeError(err_msg) from None
# def __getattribute__(self, attr):
# "Raises AttributeError when invalid operation is performed."
# try:
# return object.__getattribute__(self, attr)
# except AttributeError:
# err_msg = (
# f"'{attr}' is not a valid operation, you can convert this "
# "into a pandas data frame by calling '.DataFrame()' or a "
# "polars data frame by calling '.PolarsDataFrame()'"
# )
# raise AttributeError(err_msg) from None

def dict(self):
"""Returns a single dict built from the result set
Expand All @@ -213,12 +273,13 @@ def dicts(self):
for row in self:
yield dict(zip(self.keys, row))

@telemetry.log_call("data-frame", payload=True)
def DataFrame(self, payload):
# @telemetry.log_call("data-frame", payload=True)
# payload
def DataFrame(self):
"Returns a Pandas DataFrame instance built from the result set."
payload[
"connection_info"
] = Connection.current._get_curr_sqlalchemy_connection_info()
# payload[
# "connection_info"
# ] = Connection.current._get_curr_sqlalchemy_connection_info()

# native duckdb connection
if hasattr(self.sqlaproxy, "df"):
Expand Down Expand Up @@ -361,9 +422,11 @@ def bar(self, key_word_sep=" ", title=None, **kwargs):
def csv(self, filename=None, **format_params):
"""Generate results in comma-separated form. Write to ``filename`` if given.
Any other parameters will be passed on to csv.writer."""
if not self.pretty:
if not self.pretty_table:
return None # no results
self.pretty.add_rows(self)

self.pretty_table.add_rows(self)

if filename:
encoding = format_params.get("encoding", "utf-8")
outfile = open(filename, "w", newline="", encoding=encoding)
Expand All @@ -380,104 +443,51 @@ def csv(self, filename=None, **format_params):
else:
return outfile.getvalue()

def fetch_results(self, fetch_all=False):
"""
Returns a limited representation of the query results.
def fetchmany(self, size):
"""Fetch n results and add it to the results"""
if not self.did_finish_fetching():
try:
returned = self.sqlaproxy.fetchmany(size=size)
# sqlite raises this error when running a script that doesn't return rows
# e.g, 'CREATE TABLE' but others don't (e.g., duckdb)
except ResourceClosedError:
self.done_fetching()
return

Parameters
----------
fetch_all : bool default False
Return all query rows
"""
is_dbapi_results = self.is_dbapi_results
sqlaproxy = self.sqlaproxy
config = self.config
self.extend_results(returned)

if is_dbapi_results:
should_try_fetch_results = True
else:
should_try_fetch_results = sqlaproxy.returns_rows
if len(returned) < size:
self.done_fetching()

if should_try_fetch_results:
# sql alchemy results
if not is_dbapi_results:
self.keys = sqlaproxy.keys()
elif isinstance(sqlaproxy.description, Iterable):
self.keys = [i[0] for i in sqlaproxy.description]
else:
self.keys = []
if (
self.config.autolimit is not None
and self.config.autolimit != 0
and len(self._results) >= self.config.autolimit
):
self.done_fetching()

if len(self.keys) > 0:
self._results = self._fetch_query_results(fetch_all=fetch_all)
def fetch_for_repr_if_needed(self):
if self.config.displaylimit == 0:
self.fetchall()

self.field_names = unduplicate_field_names(self.keys)
missing = self.config.displaylimit - len(self._results)

_style = None
if missing > 0:
self.fetchmany(missing)

self.pretty = PrettyTable(self.field_names)
def fetchall(self):
if not self.did_finish_fetching():
self.extend_results(self.sqlaproxy.fetchall())
self.done_fetching()

if isinstance(config.style, str):
_style = prettytable.__dict__[config.style.upper()]
self.pretty.set_style(_style)
def _init_table(self):
pretty = PrettyTable(self.field_names)

return self
if isinstance(self.config.style, str):
_style = prettytable.__dict__[self.config.style.upper()]
pretty.set_style(_style)

def _fetch_query_results(self, fetch_all=False):
"""
Returns rows of a query result as a list of tuples.
Parameters
----------
fetch_all : bool default False
Return all query rows
"""
sqlaproxy = self.sqlaproxy
config = self.config
_should_try_lazy_fetch = hasattr(sqlaproxy, "_soft_closed")

_should_fetch_all = (
(config.displaylimit == 0 or not config.displaylimit)
or fetch_all
or not _should_try_lazy_fetch
)

is_autolimit = isinstance(config.autolimit, int) and config.autolimit > 0
is_connection_closed = (
sqlaproxy._soft_closed if _should_try_lazy_fetch else False
)

should_return_results = is_connection_closed or (
len(self._results) > 0 and is_autolimit
)

if should_return_results:
# this means we already loaded all
# the results to self._results or we use
# autolimit and shouldn't fetch more
results = self._results
else:
if is_autolimit:
results = sqlaproxy.fetchmany(size=config.autolimit)
else:
if _should_fetch_all:
all_results = sqlaproxy.fetchall()
results = self._results + all_results
self._results = results
else:
results = sqlaproxy.fetchmany(size=config.displaylimit)

if _should_try_lazy_fetch:
# Try to fetch an extra row to find out
# if there are more results to fetch
row = sqlaproxy.fetchone()
if row is not None:
results += [row]

# Check if we have more rows to show
if config.displaylimit > 0:
self.truncated = len(results) > config.displaylimit

return results
return pretty


def display_affected_rowcount(rowcount):
Expand Down Expand Up @@ -647,17 +657,15 @@ def run(conn, sql, config):
# if regular sqlalchemy, pass a text object
if not is_custom_connection:
statement = sqlalchemy.sql.text(statement)
else:
result = conn.session.execute(statement)
_commit(conn=conn, config=config, manual_commit=manual_commit)

if result and config.feedback:
if hasattr(result, "rowcount"):
display_affected_rowcount(result.rowcount)
result = conn.session.execute(statement)
_commit(conn=conn, config=config, manual_commit=manual_commit)

if result and config.feedback:
if hasattr(result, "rowcount"):
display_affected_rowcount(result.rowcount)

resultset = ResultSet(result, config, statement)
# lazy load
resultset.fetch_results()
return select_df_type(resultset, config)


Expand Down

0 comments on commit c0a7548

Please sign in to comment.