diff --git a/src/sql/run.py b/src/sql/run.py
index 72ab565ab..09d050c4c 100644
--- a/src/sql/run.py
+++ b/src/sql/run.py
@@ -131,17 +131,17 @@ def __init__(self, sqlaproxy, config, statement=None, conn=None):
# note that calling this will fetch the keys
self.pretty_table = self._init_table()
- self._done_fetching = False
+ self._mark_fetching_as_done = False
if self.config.autolimit == 1:
# if autolimit is 1, we only want to fetch one row
self.fetchmany(size=1)
- self.did_finish_fetching()
+ self.mark_fetching_as_done()
else:
# in all other cases, 2 allows us to know if there are more rows
# for example when creating a table, the results contains one row, in
# such case, fetching 2 rows will tell us that there are no more rows
- # and can set the _done_fetching flag to True
+ # and can set the _mark_fetching_as_done flag to True
self.fetchmany(size=2)
self._finished_init = True
@@ -172,13 +172,13 @@ def extend_results(self, elements):
self._results.extend(elements)
self.pretty_table.add_rows(elements[:to_add])
- def done_fetching(self):
- self._done_fetching = True
+ def mark_fetching_as_done(self):
+ self._mark_fetching_as_done = True
# NOTE: don't close the connection here (self.sqlaproxy.close()),
# because we need to keep it open for the next query
- def did_finish_fetching(self):
- return self._done_fetching
+ def done_fetching(self):
+ return self._mark_fetching_as_done
@property
def field_names(self):
@@ -214,6 +214,7 @@ def _repr_html_(self):
_cell_with_spaces_pattern = re.compile(r"(
)( {2,})")
result = self.pretty_table.get_html_string()
+
HTML = (
"%s\n"
"ResultSet : to convert to pandas, call '
"Truncated to displaylimit of %d"
@@ -494,26 +496,26 @@ def csv(self, filename=None, **format_params):
def fetchmany(self, size):
"""Fetch n results and add it to the results"""
- if not self.did_finish_fetching():
+ if not self.done_fetching():
try:
returned = self.sqlaproxy.fetchmany(size=size)
# sqlite raises this error when running a script that doesn't return rows
# e.g, 'CREATE TABLE' but others don't (e.g., duckdb)
except ResourceClosedError:
- self.done_fetching()
+ self.mark_fetching_as_done()
return
self.extend_results(returned)
if len(returned) < size:
- self.done_fetching()
+ self.mark_fetching_as_done()
if (
self.config.autolimit is not None
and self.config.autolimit != 0
and len(self._results) >= self.config.autolimit
):
- self.done_fetching()
+ self.mark_fetching_as_done()
def fetch_for_repr_if_needed(self):
if self.config.displaylimit == 0:
@@ -525,9 +527,9 @@ def fetch_for_repr_if_needed(self):
self.fetchmany(missing)
def fetchall(self):
- if not self.did_finish_fetching():
+ if not self.done_fetching():
self.extend_results(self.sqlaproxy.fetchall())
- self.done_fetching()
+ self.mark_fetching_as_done()
def _init_table(self):
pretty = CustomPrettyTable(self.field_names)
diff --git a/src/tests/test_resultset.py b/src/tests/test_resultset.py
index 244681543..25f90c7af 100644
--- a/src/tests/test_resultset.py
+++ b/src/tests/test_resultset.py
@@ -10,7 +10,7 @@
import polars as pl
import sqlalchemy
-from sql.connection import DBAPIConnection
+from sql.connection import DBAPIConnection, Connection
from sql.run import ResultSet
from sql import run as run_module
@@ -87,8 +87,6 @@ def test_resultset_repr_html(result_set):
".PolarsDataFrame() "
) in html_
- assert "Truncated to displaylimit of 5" in html_
-
@pytest.mark.parametrize(
"fname, parameters",
@@ -116,9 +114,6 @@ def test_resultset_config_autolimit_dict(result, config):
assert resultset.dict() == {"x": (0,)}
-# TODO: add dbapi tests
-
-
@pytest.fixture
def results(ip_empty):
engine = create_engine("duckdb://")
@@ -140,56 +135,96 @@ def results(ip_empty):
session.close()
-@pytest.mark.parametrize("uri", ["duckdb://", "sqlite://"])
-def test_convert_to_dataframe(ip_empty, uri):
- engine = create_engine(uri)
- session = engine.connect()
+@pytest.fixture
+def duckdb_sqlalchemy(ip_empty):
+ conn = Connection(create_engine("duckdb://"))
+ yield conn
+
+
+@pytest.fixture
+def sqlite_sqlalchemy(ip_empty):
+ conn = Connection(create_engine("sqlite://"))
+ yield conn
+
+
+@pytest.fixture
+def duckdb_dbapi():
+ conn_ = duckdb.connect(":memory:")
+ conn = DBAPIConnection(conn_)
+ yield conn
+
+
+@pytest.mark.parametrize(
+ "session, expected_value",
+ [
+ ("duckdb_sqlalchemy", {}),
+ ("duckdb_dbapi", {"Count": {}}),
+ ("sqlite_sqlalchemy", {}),
+ ],
+)
+def test_convert_to_dataframe_create_table(session, expected_value, request):
+ session = request.getfixturevalue(session)
mock = Mock()
mock.displaylimit = 100
mock.autolimit = 100000
- results = session.execute(text("CREATE TABLE a (x INT);"))
+ statement = "CREATE TABLE a (x INT);"
+ results = session.execute(statement)
- rs = ResultSet(results, mock, statement=None, conn=Mock())
+ rs = ResultSet(results, mock, statement=statement, conn=session)
df = rs.DataFrame()
- assert df.to_dict() == {}
+ assert df.to_dict() == expected_value
-def test_convert_to_dataframe_2(ip_empty):
- engine = create_engine("duckdb://")
- session = engine.connect()
+@pytest.mark.parametrize(
+ "session, expected_value",
+ [
+ ("duckdb_sqlalchemy", {"Count": {0: 5}}),
+ ("duckdb_dbapi", {"Count": {0: 5}}),
+ ("sqlite_sqlalchemy", {}),
+ ],
+)
+def test_convert_to_dataframe_insert_into(session, expected_value, request):
+ session = request.getfixturevalue(session)
mock = Mock()
mock.displaylimit = 100
mock.autolimit = 100000
- session.execute(text("CREATE TABLE a (x INT,);"))
- results = session.execute(text("INSERT INTO a(x) VALUES (1),(2),(3),(4),(5);"))
- rs = ResultSet(results, mock, statement=None, conn=Mock())
+ session.execute("CREATE TABLE a (x INT);")
+ statement = "INSERT INTO a(x) VALUES (1),(2),(3),(4),(5);"
+ results = session.execute(statement)
+ rs = ResultSet(results, mock, statement=statement, conn=session)
df = rs.DataFrame()
- assert df.to_dict() == {"Count": {0: 5}}
+ assert df.to_dict() == expected_value
-@pytest.mark.parametrize("uri", ["duckdb://", "sqlite://"])
-def test_convert_to_dataframe_3(ip_empty, uri):
- engine = create_engine(uri)
- session = engine.connect()
+@pytest.mark.parametrize(
+ "session",
+ [
+ "duckdb_sqlalchemy",
+ "duckdb_dbapi",
+ "sqlite_sqlalchemy",
+ ],
+)
+def test_convert_to_dataframe_select(session, request):
+ session = request.getfixturevalue(session)
mock = Mock()
mock.displaylimit = 100
mock.autolimit = 100000
- session.execute(text("CREATE TABLE a (x INT);"))
- session.execute(text("INSERT INTO a(x) VALUES (1),(2),(3),(4),(5);"))
- results = session.execute(text("SELECT * FROM a"))
+ session.execute("CREATE TABLE a (x INT);")
+ session.execute("INSERT INTO a(x) VALUES (1),(2),(3),(4),(5);")
+ statement = "SELECT * FROM a"
+ results = session.execute(statement)
- rs = ResultSet(results, mock, statement="SELECT * FROM a", conn=Mock())
+ rs = ResultSet(results, mock, statement=statement, conn=session)
df = rs.DataFrame()
- # TODO: check native duckdb was called if using duckb
assert df.to_dict() == {"x": {0: 1, 1: 2, 2: 3, 3: 4, 4: 5}}
@@ -243,7 +278,6 @@ def test_convert_to_dataframe_using_native_duckdb(
if to_df_method == "PolarsDataFrame":
d["x"] = list(d["x"])
- # TODO: check native duckdb was called if using duckb
assert d == expected_value
@@ -254,7 +288,7 @@ def test_done_fetching_if_reached_autolimit(results):
rs = ResultSet(results, mock, statement=None, conn=Mock())
- assert rs.did_finish_fetching() is True
+ assert rs.done_fetching() is True
def test_done_fetching_if_reached_autolimit_2(results):
@@ -265,7 +299,7 @@ def test_done_fetching_if_reached_autolimit_2(results):
rs = ResultSet(results, mock, statement=None, conn=Mock())
list(rs)
- assert rs.did_finish_fetching() is True
+ assert rs.done_fetching() is True
@pytest.mark.parametrize("method", ["__repr__", "_repr_html_"])
@@ -279,7 +313,7 @@ def test_no_displaylimit(results, method, autolimit):
getattr(rs, method)()
assert rs._results == [(1,), (2,), (3,), (4,), (5,)]
- assert rs.did_finish_fetching() is True
+ assert rs.done_fetching() is True
def test_no_fetching_if_one_result():
@@ -300,7 +334,7 @@ def test_no_fetching_if_one_result():
rs = ResultSet(results, mock, statement=None, conn=Mock())
- assert rs.did_finish_fetching() is True
+ assert rs.done_fetching() is True
results.fetchall.assert_not_called()
results.fetchmany.assert_called_once_with(size=2)
results.fetchone.assert_not_called()
@@ -313,23 +347,32 @@ def test_no_fetching_if_one_result():
results.fetchone.assert_not_called()
-# TODO: try with other values, and also change the display limit and re-run the repr
-# TODO: try with __repr__ and __str__
-def test_resultset_fetches_required_rows(results):
+def test_resultset_fetches_minimum_number_of_rows(results):
mock = Mock()
mock.displaylimit = 3
mock.autolimit = 1000_000
ResultSet(results, mock, statement=None, conn=Mock())
- # rs.fetch_results()
- # rs._repr_html_()
- # str(rs)
results.fetchall.assert_not_called()
results.fetchmany.assert_called_once_with(size=2)
results.fetchone.assert_not_called()
+@pytest.mark.parametrize("method", ["__repr__", "_repr_html_"])
+def test_resultset_fetches_minimum_number_of_rows_for_repr(results, method):
+ mock = Mock()
+ mock.displaylimit = 3
+ mock.autolimit = 1000_000
+
+ rs = ResultSet(results, mock, statement=None, conn=Mock())
+ getattr(rs, method)()
+
+ results.fetchall.assert_not_called()
+ assert results.fetchmany.call_args_list == [call(size=2), call(size=1)]
+ results.fetchone.assert_not_called()
+
+
def test_fetches_remaining_rows(results):
mock = Mock()
mock.displaylimit = 1
@@ -375,7 +418,10 @@ def test_fetches_remaining_rows(results):
" | 3 | \n \n \n",
],
],
- ids=["repr", "repr_html"],
+ ids=[
+ "repr",
+ "repr_html",
+ ],
)
def test_resultset_fetches_required_rows_repr(results, method, repr_expected):
mock = Mock()
@@ -386,24 +432,12 @@ def test_resultset_fetches_required_rows_repr(results, method, repr_expected):
repr_returned = getattr(rs, method)()
assert repr_expected in repr_returned
- assert rs.did_finish_fetching() is False
+ assert rs.done_fetching() is False
results.fetchall.assert_not_called()
results.fetchmany.assert_has_calls([call(size=2), call(size=1)])
results.fetchone.assert_not_called()
-def test_resultset_fetches_no_rows(results):
- mock = Mock()
- mock.displaylimit = 1
- mock.autolimit = 1000_000
-
- ResultSet(results, mock, statement=None, conn=Mock())
-
- results.fetchmany.assert_has_calls([call(size=2)])
- results.fetchone.assert_not_called()
- results.fetchall.assert_not_called()
-
-
def test_resultset_autolimit_one(results):
mock = Mock()
mock.displaylimit = 10
@@ -438,25 +472,116 @@ def test_display_limit_respected_even_when_feched_all(results):
)
-# TODO: try with more values of displaylimit
-# TODO: test some edge cases. e.g., displaylimit is set to 10 but we only have 5 rows
-def test_displaylimit_message(results):
+@pytest.mark.parametrize(
+ "displaylimit, message",
+ [
+ (1, "Truncated to displaylimit of 1"),
+ (2, "Truncated to displaylimit of 2"),
+ ],
+)
+def test_displaylimit_message(displaylimit, message, results):
mock = Mock()
- mock.displaylimit = 1
+ mock.displaylimit = displaylimit
mock.autolimit = 0
rs = ResultSet(results, mock, statement=None, conn=Mock())
- assert "Truncated to displaylimit of 1" in rs._repr_html_()
+ assert message in rs._repr_html_()
+
+
+@pytest.mark.parametrize("displaylimit", [0, 1000])
+def test_no_displaylimit_message(results, displaylimit):
+ mock = Mock()
+ mock.displaylimit = displaylimit
+ mock.autolimit = 0
+
+ rs = ResultSet(results, mock, statement=None, conn=Mock())
+
+ assert "Truncated to displaylimit" not in rs._repr_html_()
def test_refreshes_sqlaproxy_for_sqlalchemy_duckdb():
- pass
+ first = Connection(create_engine("duckdb://"))
+ first.execute("CREATE TABLE numbers (x INTEGER)")
+ first.execute("INSERT INTO numbers VALUES (1), (2), (3), (4), (5)")
+ first.execute("CREATE TABLE characters (c VARCHAR)")
+ first.execute("INSERT INTO characters VALUES ('a'), ('b'), ('c'), ('d'), ('e')")
+
+ mock = Mock()
+ mock.displaylimit = 10
+ mock.autolimit = 0
+
+ statement = text("SELECT * FROM numbers")
+ first_set = ResultSet(
+ first.session.execute(statement), mock, statement=statement, conn=first
+ )
+
+ original_id = id(first_set._sqlaproxy)
+
+ # create a new resultset so the other one is no longer the latest one
+ statement = text("SELECT * FROM characters")
+ ResultSet(first.session.execute(statement), mock, statement=statement, conn=first)
+
+ # force fetching data, this should trigger a refresh
+ list(first_set)
+
+ assert id(first_set._sqlaproxy) != original_id
def test_doesnt_refresh_sqlaproxy_for_if_not_sqlalchemy_and_duckdb():
- pass
+ first = DBAPIConnection(duckdb.connect(":memory:"))
+ first.execute("CREATE TABLE numbers (x INTEGER)")
+ first.execute("INSERT INTO numbers VALUES (1), (2), (3), (4), (5)")
+ first.execute("CREATE TABLE characters (c VARCHAR)")
+ first.execute("INSERT INTO characters VALUES ('a'), ('b'), ('c'), ('d'), ('e')")
+
+ mock = Mock()
+ mock.displaylimit = 10
+ mock.autolimit = 0
+
+ statement = "SELECT * FROM numbers"
+ first_set = ResultSet(
+ first.session.execute(statement), mock, statement=statement, conn=first
+ )
+
+ original_id = id(first_set._sqlaproxy)
+
+ # create a new resultset so the other one is no longer the latest one
+ statement = "SELECT * FROM characters"
+ ResultSet(first.session.execute(statement), mock, statement=statement, conn=first)
+
+ # force fetching data, this should not trigger a refresh
+ list(first_set)
+
+ assert id(first_set._sqlaproxy) == original_id
def test_doesnt_refresh_sqlaproxy_if_different_connection():
- pass
+ first = Connection(create_engine("duckdb://"))
+ first.execute("CREATE TABLE numbers (x INTEGER)")
+ first.execute("INSERT INTO numbers VALUES (1), (2), (3), (4), (5)")
+
+ second = Connection(create_engine("duckdb://"))
+ second.execute("CREATE TABLE characters (c VARCHAR)")
+ second.execute("INSERT INTO characters VALUES ('a'), ('b'), ('c'), ('d'), ('e')")
+
+ mock = Mock()
+ mock.displaylimit = 10
+ mock.autolimit = 0
+
+ statement = "SELECT * FROM numbers"
+ first_set = ResultSet(
+ first.session.execute(text(statement)), mock, statement=statement, conn=first
+ )
+
+ original_id = id(first_set._sqlaproxy)
+
+ statement = "SELECT * FROM characters"
+ ResultSet(
+ second.session.execute(text(statement)), mock, statement=statement, conn=second
+ )
+
+ # force fetching data
+ list(first_set)
+
+ assert id(first_set._sqlaproxy) == original_id