diff --git a/src/sql/run.py b/src/sql/run.py index 72ab565ab..09d050c4c 100644 --- a/src/sql/run.py +++ b/src/sql/run.py @@ -131,17 +131,17 @@ def __init__(self, sqlaproxy, config, statement=None, conn=None): # note that calling this will fetch the keys self.pretty_table = self._init_table() - self._done_fetching = False + self._mark_fetching_as_done = False if self.config.autolimit == 1: # if autolimit is 1, we only want to fetch one row self.fetchmany(size=1) - self.did_finish_fetching() + self.done_fetching() else: # in all other cases, 2 allows us to know if there are more rows # for example when creating a table, the results contains one row, in # such case, fetching 2 rows will tell us that there are no more rows - # and can set the _done_fetching flag to True + # and can set the _mark_fetching_as_done flag to True self.fetchmany(size=2) self._finished_init = True @@ -172,13 +172,13 @@ def extend_results(self, elements): self._results.extend(elements) self.pretty_table.add_rows(elements[:to_add]) - def done_fetching(self): - self._done_fetching = True + def mark_fetching_as_done(self): + self._mark_fetching_as_done = True # NOTE: don't close the connection here (self.sqlaproxy.close()), # because we need to keep it open for the next query - def did_finish_fetching(self): - return self._done_fetching + def done_fetching(self): + return self._mark_fetching_as_done @property def field_names(self): @@ -214,6 +214,7 @@ def _repr_html_(self): _cell_with_spaces_pattern = re.compile(r"()( {2,})") result = self.pretty_table.get_html_string() + HTML = ( "%s\n" "ResultSet : to convert to pandas, call ' "Truncated to displaylimit of %d" @@ -494,26 +496,26 @@ def csv(self, filename=None, **format_params): def fetchmany(self, size): """Fetch n results and add it to the results""" - if not self.did_finish_fetching(): + if not self.done_fetching(): try: returned = self.sqlaproxy.fetchmany(size=size) # sqlite raises this error when running a script that doesn't return rows # e.g, 'CREATE TABLE' but others don't (e.g., duckdb) except ResourceClosedError: - self.done_fetching() + self.mark_fetching_as_done() return self.extend_results(returned) if len(returned) < size: - self.done_fetching() + self.mark_fetching_as_done() if ( self.config.autolimit is not None and self.config.autolimit != 0 and len(self._results) >= self.config.autolimit ): - self.done_fetching() + self.mark_fetching_as_done() def fetch_for_repr_if_needed(self): if self.config.displaylimit == 0: @@ -525,9 +527,9 @@ def fetch_for_repr_if_needed(self): self.fetchmany(missing) def fetchall(self): - if not self.did_finish_fetching(): + if not self.done_fetching(): self.extend_results(self.sqlaproxy.fetchall()) - self.done_fetching() + self.mark_fetching_as_done() def _init_table(self): pretty = CustomPrettyTable(self.field_names) diff --git a/src/tests/test_resultset.py b/src/tests/test_resultset.py index 244681543..25f90c7af 100644 --- a/src/tests/test_resultset.py +++ b/src/tests/test_resultset.py @@ -10,7 +10,7 @@ import polars as pl import sqlalchemy -from sql.connection import DBAPIConnection +from sql.connection import DBAPIConnection, Connection from sql.run import ResultSet from sql import run as run_module @@ -87,8 +87,6 @@ def test_resultset_repr_html(result_set): ".PolarsDataFrame()
" ) in html_ - assert "Truncated to displaylimit of 5" in html_ - @pytest.mark.parametrize( "fname, parameters", @@ -116,9 +114,6 @@ def test_resultset_config_autolimit_dict(result, config): assert resultset.dict() == {"x": (0,)} -# TODO: add dbapi tests - - @pytest.fixture def results(ip_empty): engine = create_engine("duckdb://") @@ -140,56 +135,96 @@ def results(ip_empty): session.close() -@pytest.mark.parametrize("uri", ["duckdb://", "sqlite://"]) -def test_convert_to_dataframe(ip_empty, uri): - engine = create_engine(uri) - session = engine.connect() +@pytest.fixture +def duckdb_sqlalchemy(ip_empty): + conn = Connection(create_engine("duckdb://")) + yield conn + + +@pytest.fixture +def sqlite_sqlalchemy(ip_empty): + conn = Connection(create_engine("sqlite://")) + yield conn + + +@pytest.fixture +def duckdb_dbapi(): + conn_ = duckdb.connect(":memory:") + conn = DBAPIConnection(conn_) + yield conn + + +@pytest.mark.parametrize( + "session, expected_value", + [ + ("duckdb_sqlalchemy", {}), + ("duckdb_dbapi", {"Count": {}}), + ("sqlite_sqlalchemy", {}), + ], +) +def test_convert_to_dataframe_create_table(session, expected_value, request): + session = request.getfixturevalue(session) mock = Mock() mock.displaylimit = 100 mock.autolimit = 100000 - results = session.execute(text("CREATE TABLE a (x INT);")) + statement = "CREATE TABLE a (x INT);" + results = session.execute(statement) - rs = ResultSet(results, mock, statement=None, conn=Mock()) + rs = ResultSet(results, mock, statement=statement, conn=session) df = rs.DataFrame() - assert df.to_dict() == {} + assert df.to_dict() == expected_value -def test_convert_to_dataframe_2(ip_empty): - engine = create_engine("duckdb://") - session = engine.connect() +@pytest.mark.parametrize( + "session, expected_value", + [ + ("duckdb_sqlalchemy", {"Count": {0: 5}}), + ("duckdb_dbapi", {"Count": {0: 5}}), + ("sqlite_sqlalchemy", {}), + ], +) +def test_convert_to_dataframe_insert_into(session, expected_value, request): + session = request.getfixturevalue(session) mock = Mock() mock.displaylimit = 100 mock.autolimit = 100000 - session.execute(text("CREATE TABLE a (x INT,);")) - results = session.execute(text("INSERT INTO a(x) VALUES (1),(2),(3),(4),(5);")) - rs = ResultSet(results, mock, statement=None, conn=Mock()) + session.execute("CREATE TABLE a (x INT,);") + statement = "INSERT INTO a(x) VALUES (1),(2),(3),(4),(5);" + results = session.execute(statement) + rs = ResultSet(results, mock, statement=statement, conn=session) df = rs.DataFrame() - assert df.to_dict() == {"Count": {0: 5}} + assert df.to_dict() == expected_value -@pytest.mark.parametrize("uri", ["duckdb://", "sqlite://"]) -def test_convert_to_dataframe_3(ip_empty, uri): - engine = create_engine(uri) - session = engine.connect() +@pytest.mark.parametrize( + "session", + [ + "duckdb_sqlalchemy", + "duckdb_dbapi", + "sqlite_sqlalchemy", + ], +) +def test_convert_to_dataframe_select(session, request): + session = request.getfixturevalue(session) mock = Mock() mock.displaylimit = 100 mock.autolimit = 100000 - session.execute(text("CREATE TABLE a (x INT);")) - session.execute(text("INSERT INTO a(x) VALUES (1),(2),(3),(4),(5);")) - results = session.execute(text("SELECT * FROM a")) + session.execute("CREATE TABLE a (x INT);") + session.execute("INSERT INTO a(x) VALUES (1),(2),(3),(4),(5);") + statement = "SELECT * FROM a" + results = session.execute(statement) - rs = ResultSet(results, mock, statement="SELECT * FROM a", conn=Mock()) + rs = ResultSet(results, mock, statement=statement, conn=session) df = rs.DataFrame() - # TODO: check native duckdb was called if using duckb assert df.to_dict() == {"x": {0: 1, 1: 2, 2: 3, 3: 4, 4: 5}} @@ -243,7 +278,6 @@ def test_convert_to_dataframe_using_native_duckdb( if to_df_method == "PolarsDataFrame": d["x"] = list(d["x"]) - # TODO: check native duckdb was called if using duckb assert d == expected_value @@ -254,7 +288,7 @@ def test_done_fetching_if_reached_autolimit(results): rs = ResultSet(results, mock, statement=None, conn=Mock()) - assert rs.did_finish_fetching() is True + assert rs.done_fetching() is True def test_done_fetching_if_reached_autolimit_2(results): @@ -265,7 +299,7 @@ def test_done_fetching_if_reached_autolimit_2(results): rs = ResultSet(results, mock, statement=None, conn=Mock()) list(rs) - assert rs.did_finish_fetching() is True + assert rs.done_fetching() is True @pytest.mark.parametrize("method", ["__repr__", "_repr_html_"]) @@ -279,7 +313,7 @@ def test_no_displaylimit(results, method, autolimit): getattr(rs, method)() assert rs._results == [(1,), (2,), (3,), (4,), (5,)] - assert rs.did_finish_fetching() is True + assert rs.done_fetching() is True def test_no_fetching_if_one_result(): @@ -300,7 +334,7 @@ def test_no_fetching_if_one_result(): rs = ResultSet(results, mock, statement=None, conn=Mock()) - assert rs.did_finish_fetching() is True + assert rs.done_fetching() is True results.fetchall.assert_not_called() results.fetchmany.assert_called_once_with(size=2) results.fetchone.assert_not_called() @@ -313,23 +347,32 @@ def test_no_fetching_if_one_result(): results.fetchone.assert_not_called() -# TODO: try with other values, and also change the display limit and re-run the repr -# TODO: try with __repr__ and __str__ -def test_resultset_fetches_required_rows(results): +def test_resultset_fetches_minimum_number_of_rows(results): mock = Mock() mock.displaylimit = 3 mock.autolimit = 1000_000 ResultSet(results, mock, statement=None, conn=Mock()) - # rs.fetch_results() - # rs._repr_html_() - # str(rs) results.fetchall.assert_not_called() results.fetchmany.assert_called_once_with(size=2) results.fetchone.assert_not_called() +@pytest.mark.parametrize("method", ["__repr__", "_repr_html_"]) +def test_resultset_fetches_minimum_number_of_rows_for_repr(results, method): + mock = Mock() + mock.displaylimit = 3 + mock.autolimit = 1000_000 + + rs = ResultSet(results, mock, statement=None, conn=Mock()) + getattr(rs, method)() + + results.fetchall.assert_not_called() + assert results.fetchmany.call_args_list == [call(size=2), call(size=1)] + results.fetchone.assert_not_called() + + def test_fetches_remaining_rows(results): mock = Mock() mock.displaylimit = 1 @@ -375,7 +418,10 @@ def test_fetches_remaining_rows(results): "3\n \n \n", ], ], - ids=["repr", "repr_html"], + ids=[ + "repr", + "repr_html", + ], ) def test_resultset_fetches_required_rows_repr(results, method, repr_expected): mock = Mock() @@ -386,24 +432,12 @@ def test_resultset_fetches_required_rows_repr(results, method, repr_expected): repr_returned = getattr(rs, method)() assert repr_expected in repr_returned - assert rs.did_finish_fetching() is False + assert rs.done_fetching() is False results.fetchall.assert_not_called() results.fetchmany.assert_has_calls([call(size=2), call(size=1)]) results.fetchone.assert_not_called() -def test_resultset_fetches_no_rows(results): - mock = Mock() - mock.displaylimit = 1 - mock.autolimit = 1000_000 - - ResultSet(results, mock, statement=None, conn=Mock()) - - results.fetchmany.assert_has_calls([call(size=2)]) - results.fetchone.assert_not_called() - results.fetchall.assert_not_called() - - def test_resultset_autolimit_one(results): mock = Mock() mock.displaylimit = 10 @@ -438,25 +472,116 @@ def test_display_limit_respected_even_when_feched_all(results): ) -# TODO: try with more values of displaylimit -# TODO: test some edge cases. e.g., displaylimit is set to 10 but we only have 5 rows -def test_displaylimit_message(results): +@pytest.mark.parametrize( + "displaylimit, message", + [ + (1, "Truncated to displaylimit of 1"), + (2, "Truncated to displaylimit of 2"), + ], +) +def test_displaylimit_message(displaylimit, message, results): mock = Mock() - mock.displaylimit = 1 + mock.displaylimit = displaylimit mock.autolimit = 0 rs = ResultSet(results, mock, statement=None, conn=Mock()) - assert "Truncated to displaylimit of 1" in rs._repr_html_() + assert message in rs._repr_html_() + + +@pytest.mark.parametrize("displaylimit", [0, 1000]) +def test_no_displaylimit_message(results, displaylimit): + mock = Mock() + mock.displaylimit = displaylimit + mock.autolimit = 0 + + rs = ResultSet(results, mock, statement=None, conn=Mock()) + + assert "Truncated to displaylimit" not in rs._repr_html_() def test_refreshes_sqlaproxy_for_sqlalchemy_duckdb(): - pass + first = Connection(create_engine("duckdb://")) + first.execute("CREATE TABLE numbers (x INTEGER)") + first.execute("INSERT INTO numbers VALUES (1), (2), (3), (4), (5)") + first.execute("CREATE TABLE characters (c VARCHAR)") + first.execute("INSERT INTO characters VALUES ('a'), ('b'), ('c'), ('d'), ('e')") + + mock = Mock() + mock.displaylimit = 10 + mock.autolimit = 0 + + statement = text("SELECT * FROM numbers") + first_set = ResultSet( + first.session.execute(statement), mock, statement=statement, conn=first + ) + + original_id = id(first_set._sqlaproxy) + + # create a new resultset so the other one is no longer the latest one + statement = text("SELECT * FROM characters") + ResultSet(first.session.execute(statement), mock, statement=statement, conn=first) + + # force fetching data, this should trigger a refresh + list(first_set) + + assert id(first_set._sqlaproxy) != original_id def test_doesnt_refresh_sqlaproxy_for_if_not_sqlalchemy_and_duckdb(): - pass + first = DBAPIConnection(duckdb.connect(":memory:")) + first.execute("CREATE TABLE numbers (x INTEGER)") + first.execute("INSERT INTO numbers VALUES (1), (2), (3), (4), (5)") + first.execute("CREATE TABLE characters (c VARCHAR)") + first.execute("INSERT INTO characters VALUES ('a'), ('b'), ('c'), ('d'), ('e')") + + mock = Mock() + mock.displaylimit = 10 + mock.autolimit = 0 + + statement = "SELECT * FROM numbers" + first_set = ResultSet( + first.session.execute(statement), mock, statement=statement, conn=first + ) + + original_id = id(first_set._sqlaproxy) + + # create a new resultset so the other one is no longer the latest one + statement = "SELECT * FROM characters" + ResultSet(first.session.execute(statement), mock, statement=statement, conn=first) + + # force fetching data, this should not trigger a refresh + list(first_set) + + assert id(first_set._sqlaproxy) == original_id def test_doesnt_refresh_sqlaproxy_if_different_connection(): - pass + first = Connection(create_engine("duckdb://")) + first.execute("CREATE TABLE numbers (x INTEGER)") + first.execute("INSERT INTO numbers VALUES (1), (2), (3), (4), (5)") + + second = Connection(create_engine("duckdb://")) + second.execute("CREATE TABLE characters (c VARCHAR)") + second.execute("INSERT INTO characters VALUES ('a'), ('b'), ('c'), ('d'), ('e')") + + mock = Mock() + mock.displaylimit = 10 + mock.autolimit = 0 + + statement = "SELECT * FROM numbers" + first_set = ResultSet( + first.session.execute(text(statement)), mock, statement=statement, conn=first + ) + + original_id = id(first_set._sqlaproxy) + + statement = "SELECT * FROM characters" + ResultSet( + second.session.execute(text(statement)), mock, statement=statement, conn=second + ) + + # force fetching data + list(first_set) + + assert id(first_set._sqlaproxy) == original_id