diff --git a/src/tests/integration/test_duckDB.py b/src/tests/integration/test_duckDB.py
index 0e85d69be..c7105a7a2 100644
--- a/src/tests/integration/test_duckDB.py
+++ b/src/tests/integration/test_duckDB.py
@@ -7,6 +7,7 @@
from sql.connection import Connection
from sql.run import ResultSet
+from sql import connection
@pytest.fixture
@@ -85,3 +86,133 @@ def test_can_convert_to_bears_natively(
getattr(mock, expected_native_method).assert_called_once_with()
assert isinstance(out.result, expected_type)
assert out.result.shape == (2, 2)
+
+
+@pytest.mark.parametrize(
+ "config",
+ [
+ "%config SqlMagic.autopandas = True",
+ "%config SqlMagic.autopandas = False",
+ ],
+ ids=[
+ "autopandas_on",
+ "autopandas_off",
+ ],
+)
+@pytest.mark.parametrize(
+ "sql, tables",
+ [
+ ["%sql SELECT * FROM weather; SELECT * FROM weather;", ["weather"]],
+ [
+ "%sql CREATE TABLE names (name VARCHAR,); SELECT * FROM weather;",
+ ["weather", "names"],
+ ],
+ [
+ (
+ "%sql CREATE TABLE names (city VARCHAR,);"
+ "CREATE TABLE more_names (city VARCHAR,);"
+ "INSERT INTO names VALUES ('NYC');"
+ "SELECT * FROM names UNION ALL SELECT * FROM more_names;"
+ ),
+ ["weather", "names", "more_names"],
+ ],
+ ],
+ ids=[
+ "multiple_selects",
+ "multiple_statements",
+ "multiple_tables_created",
+ ],
+)
+def test_multiple_statements(ip_empty, config, sql, tables):
+ ip_empty.run_cell("%sql duckdb://")
+ ip_empty.run_cell(config)
+
+ ip_empty.run_cell("%sql CREATE TABLE weather (city VARCHAR,);")
+ ip_empty.run_cell("%sql INSERT INTO weather VALUES ('NYC');")
+ ip_empty.run_cell("%sql SELECT * FROM weather;")
+
+ out = ip_empty.run_cell(sql)
+ out_tables = ip_empty.run_cell("%sqlcmd tables")
+
+ assert out.error_in_exec is None
+
+ if config == "%config SqlMagic.autopandas = True":
+ assert out.result.to_dict() == {"city": {0: "NYC"}}
+ else:
+ assert out.result.dict() == {"city": ("NYC",)}
+
+ assert set(tables) == set(r[0] for r in out_tables.result._table.rows)
+
+
+def test_dataframe_returned_only_if_last_statement_is_select(ip_empty):
+ ip_empty.run_cell("%sql duckdb://")
+ ip_empty.run_cell("%config SqlMagic.autopandas=True")
+ connection.Connection.connections["duckdb://"].engine.raw_connection = Mock(
+ side_effect=ValueError("some error")
+ )
+
+ out = ip_empty.run_cell(
+ "%sql CREATE TABLE a (c VARCHAR,); CREATE TABLE b (c VARCHAR,);"
+ )
+
+ assert out.error_in_exec is None
+
+
+@pytest.mark.parametrize(
+ "sql",
+ [
+ (
+ "%sql CREATE TABLE a (x INT,); CREATE TABLE b (x INT,); "
+ "INSERT INTO a VALUES (1,); INSERT INTO b VALUES(2,); "
+ "SELECT * FROM a UNION ALL SELECT * FROM b;"
+ ),
+ """\
+%%sql
+CREATE TABLE a (x INT,);
+CREATE TABLE b (x INT,);
+INSERT INTO a VALUES (1,);
+INSERT INTO b VALUES(2,);
+SELECT * FROM a UNION ALL SELECT * FROM b;
+""",
+ ],
+)
+def test_commits_all_statements(ip_empty, sql):
+ ip_empty.run_cell("%sql duckdb://")
+ out = ip_empty.run_cell(sql)
+ assert out.error_in_exec is None
+ assert out.result.dict() == {"x": (1, 2)}
+
+
+def test_resultset_uses_native_duckdb_df(ip_empty):
+ from sqlalchemy import create_engine
+ from sql.connection import Connection
+
+ engine = create_engine("duckdb://")
+
+ engine.execute("CREATE TABLE a (x INT,);")
+ engine.execute("INSERT INTO a(x) VALUES (10),(20),(30);")
+
+ sql = "SELECT * FROM a"
+
+ # this breaks if there's an open results set
+ engine.execute(sql).fetchall()
+
+ results = engine.execute(sql)
+
+ Connection.set(engine, displaycon=False)
+
+ results.fetchmany = Mock(wraps=results.fetchmany)
+
+ mock = Mock()
+ mock.displaylimit = 1
+
+ result_set = ResultSet(results, mock, statement=sql)
+
+ result_set.fetch_results()
+
+ df = result_set.DataFrame()
+
+ assert isinstance(df, pd.DataFrame)
+ assert df.to_dict() == {"x": {0: 1, 1: 2, 2: 3}}
+
+ results.fetchmany.assert_called_once_with(size=1)
diff --git a/src/tests/test_resultset.py b/src/tests/test_resultset.py
index 4bf6cb337..ec58e8d56 100644
--- a/src/tests/test_resultset.py
+++ b/src/tests/test_resultset.py
@@ -1,7 +1,8 @@
+from unittest.mock import Mock, call
+
from sqlalchemy import create_engine
-from sql.connection import Connection
from pathlib import Path
-from unittest.mock import Mock
+
import pytest
import pandas as pd
@@ -9,7 +10,6 @@
import sqlalchemy
from sql.run import ResultSet
-from sql import run as run_module
import re
@@ -35,7 +35,7 @@ def result():
@pytest.fixture
def result_set(result, config):
- return ResultSet(result, config).fetch_results()
+ return ResultSet(result, config)
def test_resultset_getitem(result_set):
@@ -51,12 +51,6 @@ def test_resultset_dicts(result_set):
assert list(result_set.dicts()) == [{"x": 0}, {"x": 1}, {"x": 2}]
-def test_resultset_dataframe(result_set, monkeypatch):
- monkeypatch.setattr(run_module.Connection, "current", Mock())
-
- assert result_set.DataFrame().equals(pd.DataFrame({"x": range(3)}))
-
-
def test_resultset_polars_dataframe(result_set, monkeypatch):
assert result_set.PolarsDataFrame().frame_equal(pl.DataFrame({"x": range(3)}))
@@ -71,215 +65,276 @@ def test_resultset_str(result_set):
assert str(result_set) == "+---+\n| x |\n+---+\n| 0 |\n| 1 |\n| 2 |\n+---+"
-def test_resultset_repr_html(result_set):
- assert result_set._repr_html_() == (
- "
\n \n \n "
- "x | \n
\n \n \n "
- "\n 0 | \n
\n \n "
- "1 | \n
\n \n 2 | \n "
- "
\n \n
\n"
- ""
- "ResultSet
: to convert to pandas, call "
- ".DataFrame()
or to polars, call "
- ".PolarsDataFrame()
"
- )
+def test_resultset_config_autolimit_dict(result, config):
+ config.autolimit = 1
+ resultset = ResultSet(result, config)
+ assert resultset.dict() == {"x": (0,)}
-@pytest.mark.parametrize(
- "fname, parameters",
- [
- ("head", -1),
- ("tail", None),
- ("value_counts", None),
- ("not_df_function", None),
- ],
-)
-def test_invalid_operation_error(result_set, fname, parameters):
- with pytest.raises(AttributeError) as excinfo:
- getattr(result_set, fname)(parameters)
+def _get_number_of_rows_in_html_table(html):
+ """
+ Returns the number of tags within the
section
+ """
+ pattern = r"(.*?)<\/tbody>"
+ tbody_content = re.findall(pattern, html, re.DOTALL)[0]
+ row_count = len(re.findall(r"", tbody_content))
- assert str(excinfo.value) == (
- f"'{fname}' is not a valid operation, you can convert this "
- "into a pandas data frame by calling '.DataFrame()' or a "
- "polars data frame by calling '.PolarsDataFrame()'"
- )
+ return row_count
-def test_resultset_config_autolimit_dict(result, config):
- config.autolimit = 1
- resultset = ResultSet(result, config).fetch_results()
- assert resultset.dict() == {"x": (0,)}
+# TODO: add dbapi tests
-def test_resultset_with_non_sqlalchemy_results(config):
- df = pd.DataFrame({"x": range(3)}) # noqa
- conn = Connection(engine=create_engine("duckdb://"))
- result = conn.execute("SELECT * FROM df")
- assert ResultSet(result, config).fetch_results() == [(0,), (1,), (2,)]
+@pytest.fixture
+def results(ip_empty):
+ engine = create_engine("duckdb://")
+ engine.execute("CREATE TABLE a (x INT,);")
-def test_none_pretty(config):
- conn = Connection(engine=create_engine("sqlite://"))
- result = conn.execute("create table some_table (name, age)")
- result_set = ResultSet(result, config)
- assert result_set.pretty is None
- assert "" == str(result_set)
+ engine.execute("INSERT INTO a(x) VALUES (1),(2),(3),(4),(5);")
+ sql = "SELECT * FROM a"
+ results = engine.execute(sql)
-def test_lazy_loading(result, config):
- resultset = ResultSet(result, config)
- assert len(resultset._results) == 0
- resultset.fetch_results()
- assert len(resultset._results) == 3
+ results.fetchmany = Mock(wraps=results.fetchmany)
+ results.fetchall = Mock(wraps=results.fetchall)
+ results.fetchone = Mock(wraps=results.fetchone)
+ # results.fetchone = Mock(side_effect=ValueError("fetchone called"))
-@pytest.mark.parametrize(
- "autolimit, expected",
- [
- (None, 3),
- (False, 3),
- (0, 3),
- (1, 1),
- (2, 2),
- (3, 3),
- (4, 3),
- ],
-)
-def test_lazy_loading_autolimit(result, config, autolimit, expected):
- config.autolimit = autolimit
- resultset = ResultSet(result, config)
- assert len(resultset._results) == 0
- resultset.fetch_results()
- assert len(resultset._results) == expected
+ yield results
-@pytest.mark.parametrize(
- "displaylimit, expected",
- [
- (0, 3),
- (1, 1),
- (2, 2),
- (3, 3),
- (4, 3),
- ],
-)
-def test_lazy_loading_displaylimit(result, config, displaylimit, expected):
- config.displaylimit = displaylimit
- result_set = ResultSet(result, config)
+@pytest.mark.parametrize("uri", ["duckdb://", "sqlite://"])
+def test_convert_to_dataframe(ip_empty, uri):
+ engine = create_engine(uri)
- assert len(result_set._results) == 0
- result_set.fetch_results()
- html = result_set._repr_html_()
- row_count = _get_number_of_rows_in_html_table(html)
- assert row_count == expected
+ mock = Mock()
+ mock.displaylimit = 100
+ mock.autolimit = 100000
+ results = engine.execute("CREATE TABLE a (x INT);")
+ # results = engine.execute("CREATE TABLE test (n INT, name TEXT)")
-@pytest.mark.parametrize(
- "displaylimit, expected_display",
- [
- (0, 3),
- (1, 1),
- (2, 2),
- (3, 3),
- (4, 3),
- ],
-)
-def test_lazy_loading_displaylimit_fetch_all(
- result, config, displaylimit, expected_display
-):
- max_results_count = 3
- config.autolimit = False
- config.displaylimit = displaylimit
- result_set = ResultSet(result, config)
+ rs = ResultSet(results, mock)
+ df = rs.DataFrame()
- # Initialize result_set without fetching results
- assert len(result_set._results) == 0
+ assert df.to_dict() == {}
- # Fetch the min number of rows (based on configuration)
- result_set.fetch_results()
- html = result_set._repr_html_()
- row_count = _get_number_of_rows_in_html_table(html)
- expected_results = (
- max_results_count
- if expected_display + 1 >= max_results_count
- else expected_display + 1
- )
+def test_convert_to_dataframe_2(ip_empty):
+ engine = create_engine("duckdb://")
- assert len(result_set._results) == expected_results
- assert row_count == expected_display
+ mock = Mock()
+ mock.displaylimit = 100
+ mock.autolimit = 100000
- # Fetch the the rest results, but don't display them in the table
- result_set.fetch_results(fetch_all=True)
+ engine.execute("CREATE TABLE a (x INT,);")
+ results = engine.execute("INSERT INTO a(x) VALUES (1),(2),(3),(4),(5);")
+ rs = ResultSet(results, mock)
+ df = rs.DataFrame()
- html = result_set._repr_html_()
- row_count = _get_number_of_rows_in_html_table(html)
+ assert df.to_dict() == {"Count": {0: 5}}
- assert len(result_set._results) == max_results_count
- assert row_count == expected_display
+@pytest.mark.parametrize("uri", ["duckdb://", "sqlite://"])
+def test_convert_to_dataframe_3(ip_empty, uri):
+ engine = create_engine(uri)
-@pytest.mark.parametrize(
- "displaylimit, expected_display",
- [
- (0, 3),
- (1, 1),
- (2, 2),
- (3, 3),
- (4, 3),
- ],
-)
-def test_lazy_loading_list(result, config, displaylimit, expected_display):
- max_results_count = 3
- config.autolimit = False
- config.displaylimit = displaylimit
- result_set = ResultSet(result, config)
+ mock = Mock()
+ mock.displaylimit = 100
+ mock.autolimit = 100000
+
+ engine.execute("CREATE TABLE a (x INT);")
+ engine.execute("INSERT INTO a(x) VALUES (1),(2),(3),(4),(5);")
+ results = engine.execute("SELECT * FROM a")
+ # from ipdb import set_trace
+
+ # set_trace()
+ rs = ResultSet(results, mock, statement="SELECT * FROM a")
+ df = rs.DataFrame()
+
+ # TODO: check native duckdb was called if using duckb
+ assert df.to_dict() == {"x": {0: 1, 1: 2, 2: 3, 3: 4, 4: 5}}
- # Initialize result_set without fetching results
- assert len(result_set._results) == 0
- # Fetch the min number of rows (based on configuration)
- result_set.fetch_results()
+def test_done_fetching_if_reached_autolimit(results):
+ mock = Mock()
+ mock.autolimit = 2
+ mock.displaylimit = 100
- expected_results = (
- max_results_count
- if expected_display + 1 >= max_results_count
- else expected_display + 1
- )
+ rs = ResultSet(results, mock)
- assert len(result_set._results) == expected_results
- assert len(list(result_set)) == max_results_count
+ assert rs.did_finish_fetching() is True
+
+
+def test_done_fetching_if_reached_autolimit_2(results):
+ mock = Mock()
+ mock.autolimit = 4
+ mock.displaylimit = 100
+
+ rs = ResultSet(results, mock)
+ list(rs)
+
+ assert rs.did_finish_fetching() is True
+
+
+@pytest.mark.parametrize("method", ["__repr__", "_repr_html_"])
+@pytest.mark.parametrize("autolimit", [1000_000, 0])
+def test_no_displaylimit(results, method, autolimit):
+ mock = Mock()
+ mock.displaylimit = 0
+ mock.autolimit = autolimit
+
+ rs = ResultSet(results, mock)
+ getattr(rs, method)()
+
+ assert rs._results == [(1,), (2,), (3,), (4,), (5,)]
+ assert rs.did_finish_fetching() is True
+
+
+def test_no_fetching_if_one_result():
+ engine = create_engine("duckdb://")
+ engine.execute("CREATE TABLE a (x INT,);")
+ engine.execute("INSERT INTO a(x) VALUES (1);")
+
+ mock = Mock()
+ mock.displaylimit = 100
+ mock.autolimit = 1000_000
+
+ results = engine.execute("SELECT * FROM a")
+ results.fetchmany = Mock(wraps=results.fetchmany)
+ results.fetchall = Mock(wraps=results.fetchall)
+ results.fetchone = Mock(wraps=results.fetchone)
+
+ rs = ResultSet(results, mock)
+
+ assert rs.did_finish_fetching() is True
+ results.fetchall.assert_not_called()
+ results.fetchmany.assert_called_once_with(size=2)
+ results.fetchone.assert_not_called()
+
+ str(rs)
+ list(rs)
+
+ results.fetchall.assert_not_called()
+ results.fetchmany.assert_called_once_with(size=2)
+ results.fetchone.assert_not_called()
+
+
+# TODO: try with other values, and also change the display limit and re-run the repr
+# TODO: try with __repr__ and __str__
+def test_resultset_fetches_required_rows(results):
+ mock = Mock()
+ mock.displaylimit = 3
+ mock.autolimit = 1000_000
+
+ ResultSet(results, mock)
+ # rs.fetch_results()
+ # rs._repr_html_()
+ # str(rs)
+
+ results.fetchall.assert_not_called()
+ results.fetchmany.assert_called_once_with(size=2)
+ results.fetchone.assert_not_called()
+
+
+def test_fetches_remaining_rows(results):
+ mock = Mock()
+ mock.displaylimit = 1
+ mock.autolimit = 1000_000
+
+ rs = ResultSet(results, mock)
+
+ # this will trigger fetching two
+ str(rs)
+
+ results.fetchall.assert_not_called()
+ results.fetchmany.assert_has_calls([call(size=2)])
+ results.fetchone.assert_not_called()
+
+ # this will trigger fetching the rest
+ assert list(rs) == [(1,), (2,), (3,), (4,), (5,)]
+
+ results.fetchall.assert_called_once_with()
+ results.fetchmany.assert_has_calls([call(size=2)])
+ results.fetchone.assert_not_called()
+
+ rs.sqlaproxy.fetchmany = Mock(side_effect=ValueError("fetchmany called"))
+ rs.sqlaproxy.fetchall = Mock(side_effect=ValueError("fetchall called"))
+ rs.sqlaproxy.fetchone = Mock(side_effect=ValueError("fetchone called"))
+
+ # this should not trigger any more fetching
+ assert list(rs) == [(1,), (2,), (3,), (4,), (5,)]
@pytest.mark.parametrize(
- "autolimit, expected_results",
+ "method, repr_",
[
- (0, 3),
- (1, 1),
- (2, 2),
- (3, 3),
- (4, 3),
+ [
+ "__repr__",
+ "+---+\n| x |\n+---+\n| 1 |\n| 2 |\n| 3 |\n+---+",
+ ],
+ [
+ "_repr_html_",
+ "\n \n \n x | \n "
+ "
\n \n \n \n "
+ "1 | \n
\n \n "
+ "2 | \n
\n \n "
+ "3 | \n
\n \n
",
+ ],
],
)
-def test_lazy_loading_autolimit_list(result, config, autolimit, expected_results):
- config.autolimit = autolimit
- result_set = ResultSet(result, config)
- assert len(result_set._results) == 0
+def test_resultset_fetches_required_rows_repr_html(results, method, repr_):
+ mock = Mock()
+ mock.displaylimit = 3
+ mock.autolimit = 1000_000
- result_set.fetch_results()
+ rs = ResultSet(results, mock)
+ rs_repr = getattr(rs, method)()
- assert len(result_set._results) == expected_results
- assert len(list(result_set)) == expected_results
+ assert repr_ in rs_repr
+ assert rs.did_finish_fetching() is False
+ results.fetchall.assert_not_called()
+ results.fetchmany.assert_has_calls([call(size=2), call(size=1)])
+ results.fetchone.assert_not_called()
-def _get_number_of_rows_in_html_table(html):
- """
- Returns the number of
tags within the
section
- """
- pattern = r"(.*?)<\/tbody>"
- tbody_content = re.findall(pattern, html, re.DOTALL)[0]
- row_count = len(re.findall(r"", tbody_content))
+def test_resultset_fetches_no_rows(results):
+ mock = Mock()
+ mock.displaylimit = 1
+ mock.autolimit = 1000_000
- return row_count
+ ResultSet(results, mock)
+
+ results.fetchmany.assert_has_calls([call(size=2)])
+ results.fetchone.assert_not_called()
+ results.fetchall.assert_not_called()
+
+
+def test_resultset_autolimit_one(results):
+ mock = Mock()
+ mock.displaylimit = 10
+ mock.autolimit = 1
+
+ rs = ResultSet(results, mock)
+ repr(rs)
+ str(rs)
+ rs._repr_html_()
+ list(rs)
+
+ results.fetchmany.assert_has_calls([call(size=1)])
+ results.fetchone.assert_not_called()
+ results.fetchall.assert_not_called()
+
+
+# TODO: try with more values of displaylimit
+# TODO: test some edge cases. e.g., displaylimit is set to 10 but we only have 5 rows
+def test_displaylimit_message(results):
+ mock = Mock()
+ mock.displaylimit = 1
+ mock.autolimit = 0
+
+ rs = ResultSet(results, mock)
+
+ assert "Truncated to displaylimit of 1" in rs._repr_html_()