diff --git a/src/tests/integration/test_duckDB.py b/src/tests/integration/test_duckDB.py index 0e85d69be..c7105a7a2 100644 --- a/src/tests/integration/test_duckDB.py +++ b/src/tests/integration/test_duckDB.py @@ -7,6 +7,7 @@ from sql.connection import Connection from sql.run import ResultSet +from sql import connection @pytest.fixture @@ -85,3 +86,133 @@ def test_can_convert_to_bears_natively( getattr(mock, expected_native_method).assert_called_once_with() assert isinstance(out.result, expected_type) assert out.result.shape == (2, 2) + + +@pytest.mark.parametrize( + "config", + [ + "%config SqlMagic.autopandas = True", + "%config SqlMagic.autopandas = False", + ], + ids=[ + "autopandas_on", + "autopandas_off", + ], +) +@pytest.mark.parametrize( + "sql, tables", + [ + ["%sql SELECT * FROM weather; SELECT * FROM weather;", ["weather"]], + [ + "%sql CREATE TABLE names (name VARCHAR,); SELECT * FROM weather;", + ["weather", "names"], + ], + [ + ( + "%sql CREATE TABLE names (city VARCHAR,);" + "CREATE TABLE more_names (city VARCHAR,);" + "INSERT INTO names VALUES ('NYC');" + "SELECT * FROM names UNION ALL SELECT * FROM more_names;" + ), + ["weather", "names", "more_names"], + ], + ], + ids=[ + "multiple_selects", + "multiple_statements", + "multiple_tables_created", + ], +) +def test_multiple_statements(ip_empty, config, sql, tables): + ip_empty.run_cell("%sql duckdb://") + ip_empty.run_cell(config) + + ip_empty.run_cell("%sql CREATE TABLE weather (city VARCHAR,);") + ip_empty.run_cell("%sql INSERT INTO weather VALUES ('NYC');") + ip_empty.run_cell("%sql SELECT * FROM weather;") + + out = ip_empty.run_cell(sql) + out_tables = ip_empty.run_cell("%sqlcmd tables") + + assert out.error_in_exec is None + + if config == "%config SqlMagic.autopandas = True": + assert out.result.to_dict() == {"city": {0: "NYC"}} + else: + assert out.result.dict() == {"city": ("NYC",)} + + assert set(tables) == set(r[0] for r in out_tables.result._table.rows) + + +def test_dataframe_returned_only_if_last_statement_is_select(ip_empty): + ip_empty.run_cell("%sql duckdb://") + ip_empty.run_cell("%config SqlMagic.autopandas=True") + connection.Connection.connections["duckdb://"].engine.raw_connection = Mock( + side_effect=ValueError("some error") + ) + + out = ip_empty.run_cell( + "%sql CREATE TABLE a (c VARCHAR,); CREATE TABLE b (c VARCHAR,);" + ) + + assert out.error_in_exec is None + + +@pytest.mark.parametrize( + "sql", + [ + ( + "%sql CREATE TABLE a (x INT,); CREATE TABLE b (x INT,); " + "INSERT INTO a VALUES (1,); INSERT INTO b VALUES(2,); " + "SELECT * FROM a UNION ALL SELECT * FROM b;" + ), + """\ +%%sql +CREATE TABLE a (x INT,); +CREATE TABLE b (x INT,); +INSERT INTO a VALUES (1,); +INSERT INTO b VALUES(2,); +SELECT * FROM a UNION ALL SELECT * FROM b; +""", + ], +) +def test_commits_all_statements(ip_empty, sql): + ip_empty.run_cell("%sql duckdb://") + out = ip_empty.run_cell(sql) + assert out.error_in_exec is None + assert out.result.dict() == {"x": (1, 2)} + + +def test_resultset_uses_native_duckdb_df(ip_empty): + from sqlalchemy import create_engine + from sql.connection import Connection + + engine = create_engine("duckdb://") + + engine.execute("CREATE TABLE a (x INT,);") + engine.execute("INSERT INTO a(x) VALUES (10),(20),(30);") + + sql = "SELECT * FROM a" + + # this breaks if there's an open results set + engine.execute(sql).fetchall() + + results = engine.execute(sql) + + Connection.set(engine, displaycon=False) + + results.fetchmany = Mock(wraps=results.fetchmany) + + mock = Mock() + mock.displaylimit = 1 + + result_set = ResultSet(results, mock, statement=sql) + + result_set.fetch_results() + + df = result_set.DataFrame() + + assert isinstance(df, pd.DataFrame) + assert df.to_dict() == {"x": {0: 1, 1: 2, 2: 3}} + + results.fetchmany.assert_called_once_with(size=1) diff --git a/src/tests/test_resultset.py b/src/tests/test_resultset.py index 4bf6cb337..ec58e8d56 100644 --- a/src/tests/test_resultset.py +++ b/src/tests/test_resultset.py @@ -1,7 +1,8 @@ +from unittest.mock import Mock, call + from sqlalchemy import create_engine -from sql.connection import Connection from pathlib import Path -from unittest.mock import Mock + import pytest import pandas as pd @@ -9,7 +10,6 @@ import sqlalchemy from sql.run import ResultSet -from sql import run as run_module import re @@ -35,7 +35,7 @@ def result(): @pytest.fixture def result_set(result, config): - return ResultSet(result, config).fetch_results() + return ResultSet(result, config) def test_resultset_getitem(result_set): @@ -51,12 +51,6 @@ def test_resultset_dicts(result_set): assert list(result_set.dicts()) == [{"x": 0}, {"x": 1}, {"x": 2}] -def test_resultset_dataframe(result_set, monkeypatch): - monkeypatch.setattr(run_module.Connection, "current", Mock()) - - assert result_set.DataFrame().equals(pd.DataFrame({"x": range(3)})) - - def test_resultset_polars_dataframe(result_set, monkeypatch): assert result_set.PolarsDataFrame().frame_equal(pl.DataFrame({"x": range(3)})) @@ -71,215 +65,276 @@ def test_resultset_str(result_set): assert str(result_set) == "+---+\n| x |\n+---+\n| 0 |\n| 1 |\n| 2 |\n+---+" -def test_resultset_repr_html(result_set): - assert result_set._repr_html_() == ( - "\n \n \n " - "\n \n \n \n " - "\n \n \n \n " - "\n \n \n \n " - "\n \n
x
0
1
2
\n" - "" - "ResultSet : to convert to pandas, call " - ".DataFrame() or to polars, call " - ".PolarsDataFrame()
" - ) +def test_resultset_config_autolimit_dict(result, config): + config.autolimit = 1 + resultset = ResultSet(result, config) + assert resultset.dict() == {"x": (0,)} -@pytest.mark.parametrize( - "fname, parameters", - [ - ("head", -1), - ("tail", None), - ("value_counts", None), - ("not_df_function", None), - ], -) -def test_invalid_operation_error(result_set, fname, parameters): - with pytest.raises(AttributeError) as excinfo: - getattr(result_set, fname)(parameters) +def _get_number_of_rows_in_html_table(html): + """ + Returns the number of tags within the section + """ + pattern = r"(.*?)<\/tbody>" + tbody_content = re.findall(pattern, html, re.DOTALL)[0] + row_count = len(re.findall(r"", tbody_content)) - assert str(excinfo.value) == ( - f"'{fname}' is not a valid operation, you can convert this " - "into a pandas data frame by calling '.DataFrame()' or a " - "polars data frame by calling '.PolarsDataFrame()'" - ) + return row_count -def test_resultset_config_autolimit_dict(result, config): - config.autolimit = 1 - resultset = ResultSet(result, config).fetch_results() - assert resultset.dict() == {"x": (0,)} +# TODO: add dbapi tests -def test_resultset_with_non_sqlalchemy_results(config): - df = pd.DataFrame({"x": range(3)}) # noqa - conn = Connection(engine=create_engine("duckdb://")) - result = conn.execute("SELECT * FROM df") - assert ResultSet(result, config).fetch_results() == [(0,), (1,), (2,)] +@pytest.fixture +def results(ip_empty): + engine = create_engine("duckdb://") + engine.execute("CREATE TABLE a (x INT,);") -def test_none_pretty(config): - conn = Connection(engine=create_engine("sqlite://")) - result = conn.execute("create table some_table (name, age)") - result_set = ResultSet(result, config) - assert result_set.pretty is None - assert "" == str(result_set) + engine.execute("INSERT INTO a(x) VALUES (1),(2),(3),(4),(5);") + sql = "SELECT * FROM a" + results = engine.execute(sql) -def test_lazy_loading(result, config): - resultset = ResultSet(result, config) - assert len(resultset._results) == 0 - resultset.fetch_results() - assert len(resultset._results) == 3 + results.fetchmany = Mock(wraps=results.fetchmany) + results.fetchall = Mock(wraps=results.fetchall) + results.fetchone = Mock(wraps=results.fetchone) + # results.fetchone = Mock(side_effect=ValueError("fetchone called")) -@pytest.mark.parametrize( - "autolimit, expected", - [ - (None, 3), - (False, 3), - (0, 3), - (1, 1), - (2, 2), - (3, 3), - (4, 3), - ], -) -def test_lazy_loading_autolimit(result, config, autolimit, expected): - config.autolimit = autolimit - resultset = ResultSet(result, config) - assert len(resultset._results) == 0 - resultset.fetch_results() - assert len(resultset._results) == expected + yield results -@pytest.mark.parametrize( - "displaylimit, expected", - [ - (0, 3), - (1, 1), - (2, 2), - (3, 3), - (4, 3), - ], -) -def test_lazy_loading_displaylimit(result, config, displaylimit, expected): - config.displaylimit = displaylimit - result_set = ResultSet(result, config) +@pytest.mark.parametrize("uri", ["duckdb://", "sqlite://"]) +def test_convert_to_dataframe(ip_empty, uri): + engine = create_engine(uri) - assert len(result_set._results) == 0 - result_set.fetch_results() - html = result_set._repr_html_() - row_count = _get_number_of_rows_in_html_table(html) - assert row_count == expected + mock = Mock() + mock.displaylimit = 100 + mock.autolimit = 100000 + results = engine.execute("CREATE TABLE a (x INT);") + # results = engine.execute("CREATE TABLE test (n INT, name TEXT)") -@pytest.mark.parametrize( - "displaylimit, expected_display", - [ - (0, 3), - (1, 1), - (2, 2), - (3, 3), - (4, 3), - ], -) -def test_lazy_loading_displaylimit_fetch_all( - result, config, displaylimit, expected_display -): - max_results_count = 3 - config.autolimit = False - config.displaylimit = displaylimit - result_set = ResultSet(result, config) + rs = ResultSet(results, mock) + df = rs.DataFrame() - # Initialize result_set without fetching results - assert len(result_set._results) == 0 + assert df.to_dict() == {} - # Fetch the min number of rows (based on configuration) - result_set.fetch_results() - html = result_set._repr_html_() - row_count = _get_number_of_rows_in_html_table(html) - expected_results = ( - max_results_count - if expected_display + 1 >= max_results_count - else expected_display + 1 - ) +def test_convert_to_dataframe_2(ip_empty): + engine = create_engine("duckdb://") - assert len(result_set._results) == expected_results - assert row_count == expected_display + mock = Mock() + mock.displaylimit = 100 + mock.autolimit = 100000 - # Fetch the the rest results, but don't display them in the table - result_set.fetch_results(fetch_all=True) + engine.execute("CREATE TABLE a (x INT,);") + results = engine.execute("INSERT INTO a(x) VALUES (1),(2),(3),(4),(5);") + rs = ResultSet(results, mock) + df = rs.DataFrame() - html = result_set._repr_html_() - row_count = _get_number_of_rows_in_html_table(html) + assert df.to_dict() == {"Count": {0: 5}} - assert len(result_set._results) == max_results_count - assert row_count == expected_display +@pytest.mark.parametrize("uri", ["duckdb://", "sqlite://"]) +def test_convert_to_dataframe_3(ip_empty, uri): + engine = create_engine(uri) -@pytest.mark.parametrize( - "displaylimit, expected_display", - [ - (0, 3), - (1, 1), - (2, 2), - (3, 3), - (4, 3), - ], -) -def test_lazy_loading_list(result, config, displaylimit, expected_display): - max_results_count = 3 - config.autolimit = False - config.displaylimit = displaylimit - result_set = ResultSet(result, config) + mock = Mock() + mock.displaylimit = 100 + mock.autolimit = 100000 + + engine.execute("CREATE TABLE a (x INT);") + engine.execute("INSERT INTO a(x) VALUES (1),(2),(3),(4),(5);") + results = engine.execute("SELECT * FROM a") + # from ipdb import set_trace + + # set_trace() + rs = ResultSet(results, mock, statement="SELECT * FROM a") + df = rs.DataFrame() + + # TODO: check native duckdb was called if using duckb + assert df.to_dict() == {"x": {0: 1, 1: 2, 2: 3, 3: 4, 4: 5}} - # Initialize result_set without fetching results - assert len(result_set._results) == 0 - # Fetch the min number of rows (based on configuration) - result_set.fetch_results() +def test_done_fetching_if_reached_autolimit(results): + mock = Mock() + mock.autolimit = 2 + mock.displaylimit = 100 - expected_results = ( - max_results_count - if expected_display + 1 >= max_results_count - else expected_display + 1 - ) + rs = ResultSet(results, mock) - assert len(result_set._results) == expected_results - assert len(list(result_set)) == max_results_count + assert rs.did_finish_fetching() is True + + +def test_done_fetching_if_reached_autolimit_2(results): + mock = Mock() + mock.autolimit = 4 + mock.displaylimit = 100 + + rs = ResultSet(results, mock) + list(rs) + + assert rs.did_finish_fetching() is True + + +@pytest.mark.parametrize("method", ["__repr__", "_repr_html_"]) +@pytest.mark.parametrize("autolimit", [1000_000, 0]) +def test_no_displaylimit(results, method, autolimit): + mock = Mock() + mock.displaylimit = 0 + mock.autolimit = autolimit + + rs = ResultSet(results, mock) + getattr(rs, method)() + + assert rs._results == [(1,), (2,), (3,), (4,), (5,)] + assert rs.did_finish_fetching() is True + + +def test_no_fetching_if_one_result(): + engine = create_engine("duckdb://") + engine.execute("CREATE TABLE a (x INT,);") + engine.execute("INSERT INTO a(x) VALUES (1);") + + mock = Mock() + mock.displaylimit = 100 + mock.autolimit = 1000_000 + + results = engine.execute("SELECT * FROM a") + results.fetchmany = Mock(wraps=results.fetchmany) + results.fetchall = Mock(wraps=results.fetchall) + results.fetchone = Mock(wraps=results.fetchone) + + rs = ResultSet(results, mock) + + assert rs.did_finish_fetching() is True + results.fetchall.assert_not_called() + results.fetchmany.assert_called_once_with(size=2) + results.fetchone.assert_not_called() + + str(rs) + list(rs) + + results.fetchall.assert_not_called() + results.fetchmany.assert_called_once_with(size=2) + results.fetchone.assert_not_called() + + +# TODO: try with other values, and also change the display limit and re-run the repr +# TODO: try with __repr__ and __str__ +def test_resultset_fetches_required_rows(results): + mock = Mock() + mock.displaylimit = 3 + mock.autolimit = 1000_000 + + ResultSet(results, mock) + # rs.fetch_results() + # rs._repr_html_() + # str(rs) + + results.fetchall.assert_not_called() + results.fetchmany.assert_called_once_with(size=2) + results.fetchone.assert_not_called() + + +def test_fetches_remaining_rows(results): + mock = Mock() + mock.displaylimit = 1 + mock.autolimit = 1000_000 + + rs = ResultSet(results, mock) + + # this will trigger fetching two + str(rs) + + results.fetchall.assert_not_called() + results.fetchmany.assert_has_calls([call(size=2)]) + results.fetchone.assert_not_called() + + # this will trigger fetching the rest + assert list(rs) == [(1,), (2,), (3,), (4,), (5,)] + + results.fetchall.assert_called_once_with() + results.fetchmany.assert_has_calls([call(size=2)]) + results.fetchone.assert_not_called() + + rs.sqlaproxy.fetchmany = Mock(side_effect=ValueError("fetchmany called")) + rs.sqlaproxy.fetchall = Mock(side_effect=ValueError("fetchall called")) + rs.sqlaproxy.fetchone = Mock(side_effect=ValueError("fetchone called")) + + # this should not trigger any more fetching + assert list(rs) == [(1,), (2,), (3,), (4,), (5,)] @pytest.mark.parametrize( - "autolimit, expected_results", + "method, repr_", [ - (0, 3), - (1, 1), - (2, 2), - (3, 3), - (4, 3), + [ + "__repr__", + "+---+\n| x |\n+---+\n| 1 |\n| 2 |\n| 3 |\n+---+", + ], + [ + "_repr_html_", + "\n \n \n \n " + "\n \n \n \n " + "\n \n \n " + "\n \n \n " + "\n \n \n
x
1
2
3
", + ], ], ) -def test_lazy_loading_autolimit_list(result, config, autolimit, expected_results): - config.autolimit = autolimit - result_set = ResultSet(result, config) - assert len(result_set._results) == 0 +def test_resultset_fetches_required_rows_repr_html(results, method, repr_): + mock = Mock() + mock.displaylimit = 3 + mock.autolimit = 1000_000 - result_set.fetch_results() + rs = ResultSet(results, mock) + rs_repr = getattr(rs, method)() - assert len(result_set._results) == expected_results - assert len(list(result_set)) == expected_results + assert repr_ in rs_repr + assert rs.did_finish_fetching() is False + results.fetchall.assert_not_called() + results.fetchmany.assert_has_calls([call(size=2), call(size=1)]) + results.fetchone.assert_not_called() -def _get_number_of_rows_in_html_table(html): - """ - Returns the number of tags within the section - """ - pattern = r"(.*?)<\/tbody>" - tbody_content = re.findall(pattern, html, re.DOTALL)[0] - row_count = len(re.findall(r"", tbody_content)) +def test_resultset_fetches_no_rows(results): + mock = Mock() + mock.displaylimit = 1 + mock.autolimit = 1000_000 - return row_count + ResultSet(results, mock) + + results.fetchmany.assert_has_calls([call(size=2)]) + results.fetchone.assert_not_called() + results.fetchall.assert_not_called() + + +def test_resultset_autolimit_one(results): + mock = Mock() + mock.displaylimit = 10 + mock.autolimit = 1 + + rs = ResultSet(results, mock) + repr(rs) + str(rs) + rs._repr_html_() + list(rs) + + results.fetchmany.assert_has_calls([call(size=1)]) + results.fetchone.assert_not_called() + results.fetchall.assert_not_called() + + +# TODO: try with more values of displaylimit +# TODO: test some edge cases. e.g., displaylimit is set to 10 but we only have 5 rows +def test_displaylimit_message(results): + mock = Mock() + mock.displaylimit = 1 + mock.autolimit = 0 + + rs = ResultSet(results, mock) + + assert "Truncated to displaylimit of 1" in rs._repr_html_()