Skip to content

Commit

Permalink
test fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
edublancas committed Jul 14, 2023
1 parent 96686f7 commit ba89b71
Show file tree
Hide file tree
Showing 2 changed files with 203 additions and 76 deletions.
30 changes: 16 additions & 14 deletions src/sql/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,17 +131,17 @@ def __init__(self, sqlaproxy, config, statement=None, conn=None):
# note that calling this will fetch the keys
self.pretty_table = self._init_table()

self._done_fetching = False
self._mark_fetching_as_done = False

if self.config.autolimit == 1:
# if autolimit is 1, we only want to fetch one row
self.fetchmany(size=1)
self.did_finish_fetching()
self.done_fetching()
else:
# in all other cases, 2 allows us to know if there are more rows
# for example when creating a table, the results contains one row, in
# such case, fetching 2 rows will tell us that there are no more rows
# and can set the _done_fetching flag to True
# and can set the _mark_fetching_as_done flag to True
self.fetchmany(size=2)

self._finished_init = True
Expand Down Expand Up @@ -172,13 +172,13 @@ def extend_results(self, elements):
self._results.extend(elements)
self.pretty_table.add_rows(elements[:to_add])

def done_fetching(self):
self._done_fetching = True
def mark_fetching_as_done(self):
self._mark_fetching_as_done = True
# NOTE: don't close the connection here (self.sqlaproxy.close()),
# because we need to keep it open for the next query

def did_finish_fetching(self):
return self._done_fetching
def done_fetching(self):
return self._mark_fetching_as_done

@property
def field_names(self):
Expand Down Expand Up @@ -214,6 +214,7 @@ def _repr_html_(self):
_cell_with_spaces_pattern = re.compile(r"(<td>)( {2,})")

result = self.pretty_table.get_html_string()

HTML = (
"%s\n<span style='font-style:italic;font-size:11px'>"
"<code>ResultSet</code> : to convert to pandas, call <a href="
Expand All @@ -227,7 +228,8 @@ def _repr_html_(self):
# to create clickable links
result = html.unescape(result)
result = _cell_with_spaces_pattern.sub(_nonbreaking_spaces, result)
if self.config.displaylimit != 0:

if self.config.displaylimit != 0 and not self.done_fetching():
HTML = (
'%s\n<span style="font-style:italic;text-align:center;">'
"Truncated to displaylimit of %d</span>"
Expand Down Expand Up @@ -494,26 +496,26 @@ def csv(self, filename=None, **format_params):

def fetchmany(self, size):
"""Fetch n results and add it to the results"""
if not self.did_finish_fetching():
if not self.done_fetching():
try:
returned = self.sqlaproxy.fetchmany(size=size)
# sqlite raises this error when running a script that doesn't return rows
# e.g, 'CREATE TABLE' but others don't (e.g., duckdb)
except ResourceClosedError:
self.done_fetching()
self.mark_fetching_as_done()
return

self.extend_results(returned)

if len(returned) < size:
self.done_fetching()
self.mark_fetching_as_done()

if (
self.config.autolimit is not None
and self.config.autolimit != 0
and len(self._results) >= self.config.autolimit
):
self.done_fetching()
self.mark_fetching_as_done()

def fetch_for_repr_if_needed(self):
if self.config.displaylimit == 0:
Expand All @@ -525,9 +527,9 @@ def fetch_for_repr_if_needed(self):
self.fetchmany(missing)

def fetchall(self):
if not self.did_finish_fetching():
if not self.done_fetching():
self.extend_results(self.sqlaproxy.fetchall())
self.done_fetching()
self.mark_fetching_as_done()

def _init_table(self):
pretty = CustomPrettyTable(self.field_names)
Expand Down
Loading

0 comments on commit ba89b71

Please sign in to comment.