diff --git a/CHANGELOG.md b/CHANGELOG.md index 186a7c9cf..0772e960b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,7 @@ * [Doc] Add Redshift tutorial * [Feature] Adds Redshift support for `%sqlplot boxplot` * [Fix] Improves performance when converting DuckDB results to `pandas.DataFrame` +* [Fix] Added support for `profile` and `explore` commands with saved snippets (#658) ## 0.9.0 (2023-08-01) @@ -41,6 +42,7 @@ * [Doc] Document `--persist-replace` in API section ([#539](https://github.com/ploomber/jupysql/issues/539)) * [Doc] Re-organized sections. Adds section showing how to share notebooks via Ploomber Cloud + ## 0.7.9 (2023-06-19) * [Feature] Modified `histogram` command to support data with NULL values ([#176](https://github.com/ploomber/jupysql/issues/176)) diff --git a/doc/tutorials/duckdb-native-sqlalchemy.md b/doc/tutorials/duckdb-native-sqlalchemy.md index a296d5bfa..a6cc9f60f 100644 --- a/doc/tutorials/duckdb-native-sqlalchemy.md +++ b/doc/tutorials/duckdb-native-sqlalchemy.md @@ -41,6 +41,7 @@ df = pd.DataFrame(np.random.randn(num_rows, num_cols)) ```{code-cell} ipython3 import duckdb + conn = duckdb.connect() ``` diff --git a/src/sql/inspect.py b/src/sql/inspect.py index 4307ca51b..bfb414a5c 100644 --- a/src/sql/inspect.py +++ b/src/sql/inspect.py @@ -172,7 +172,7 @@ class Columns(DatabaseInspection): """ def __init__(self, name, schema, conn=None) -> None: - util.is_table_exists(name, schema) + util.is_saved_snippet_or_table_exists(name, schema) inspector = _get_inspector(conn) @@ -230,14 +230,19 @@ class TableDescription(DatabaseInspection): """ def __init__(self, table_name, schema=None) -> None: - util.is_table_exists(table_name, schema) + with_ = util.is_saved_snippet_or_table_exists( + table_name, schema, task="profile" + ) if schema: table_name = f"{schema}.{table_name}" conn = ConnectionManager.current - columns_query_result = conn.raw_execute(f"SELECT * FROM {table_name} WHERE 1=0") + columns_query_result = conn.execute( + f"SELECT * FROM {table_name} WHERE 1=0", with_=with_ + ) + if ConnectionManager.current.is_dbapi_connection: columns = [i[0] for i in columns_query_result.description] else: @@ -253,8 +258,8 @@ def __init__(self, table_name, schema=None) -> None: # check the datatype of a column try: - result = ConnectionManager.current.raw_execute( - f"""SELECT {column} FROM {table_name} LIMIT 1""" + result = conn.execute( + f"SELECT {column} FROM {table_name} LIMIT 1", with_=with_ ).fetchone() value = result[0] @@ -269,10 +274,11 @@ def __init__(self, table_name, schema=None) -> None: message_check = True # Note: index is reserved word in sqlite try: - result_col_freq_values = ConnectionManager.current.raw_execute( + result_col_freq_values = conn.execute( f"""SELECT DISTINCT {column} as top, COUNT({column}) as frequency FROM {table_name} GROUP BY top ORDER BY frequency Desc""", + with_=with_, ).fetchall() table_stats[column]["freq"] = result_col_freq_values[0][1] @@ -285,14 +291,13 @@ def __init__(self, table_name, schema=None) -> None: try: # get all non None values, min, max and avg. - result_value_values = ConnectionManager.current.raw_execute( - f""" - SELECT MIN({column}) AS min, + result_value_values = conn.execute( + f"""SELECT MIN({column}) AS min, MAX({column}) AS max, COUNT({column}) AS count FROM {table_name} - WHERE {column} IS NOT NULL - """, + WHERE {column} IS NOT NULL""", + with_=with_, ).fetchall() columns_to_include_in_report.update(["count", "min", "max"]) @@ -308,13 +313,12 @@ def __init__(self, table_name, schema=None) -> None: try: # get unique values - result_value_values = ConnectionManager.current.raw_execute( - f""" - SELECT + result_value_values = conn.execute( + f""" SELECT COUNT(DISTINCT {column}) AS unique_count FROM {table_name} - WHERE {column} IS NOT NULL - """, + WHERE {column} IS NOT NULL""", + with_=with_, ).fetchall() table_stats[column]["unique"] = result_value_values[0][0] columns_to_include_in_report.update(["unique"]) @@ -322,12 +326,11 @@ def __init__(self, table_name, schema=None) -> None: pass try: - results_avg = ConnectionManager.current.raw_execute( - f""" - SELECT AVG({column}) AS avg - FROM {table_name} - WHERE {column} IS NOT NULL - """, + results_avg = conn.execute( + f""" SELECT AVG({column}) AS avg + FROM {table_name} + WHERE {column} IS NOT NULL""", + with_=with_, ).fetchall() columns_to_include_in_report.update(["mean"]) @@ -341,9 +344,8 @@ def __init__(self, table_name, schema=None) -> None: try: # Note: stddev_pop and PERCENTILE_DISC will work only on DuckDB - result = ConnectionManager.current.raw_execute( - f""" - SELECT + result = conn.execute( + f""" SELECT stddev_pop({column}) as key_std, percentile_disc(0.25) WITHIN GROUP (ORDER BY {column}) as key_25, @@ -351,8 +353,8 @@ def __init__(self, table_name, schema=None) -> None: (ORDER BY {column}) as key_50, percentile_disc(0.75) WITHIN GROUP (ORDER BY {column}) as key_75 - FROM {table_name} - """, + FROM {table_name}""", + with_=with_, ).fetchall() columns_to_include_in_report.update(special_numeric_keys) diff --git a/src/sql/magic_plot.py b/src/sql/magic_plot.py index ab33fa3b6..4836b367c 100644 --- a/src/sql/magic_plot.py +++ b/src/sql/magic_plot.py @@ -92,7 +92,7 @@ def execute(self, line="", cell="", local_ns=None): if cmd.args.with_: with_ = cmd.args.with_ else: - with_ = self._check_table_exists(table) + with_ = util.is_saved_snippet_or_table_exists(table, task="plot") if cmd.args.line[0] in {"box", "boxplot"}: return plot.boxplot( @@ -127,12 +127,3 @@ def execute(self, line="", cell="", local_ns=None): show_num=cmd.args.show_numbers, conn=None, ) - - @staticmethod - def _check_table_exists(table): - with_ = None - if util.is_saved_snippet(table): - with_ = [table] - else: - util.is_table_exists(table) - return with_ diff --git a/src/sql/util.py b/src/sql/util.py index 424852ba8..5731588a7 100644 --- a/src/sql/util.py +++ b/src/sql/util.py @@ -187,8 +187,15 @@ def strip_multiple_chars(string: str, chars: str) -> str: def is_saved_snippet(table: str) -> bool: + """ + checks if table is part of a snippet + Parameters + ---------- + table: str + Table name + + """ if table in list(store): - display.message(f"Plotting using saved snippet : {table}") return True return False @@ -264,7 +271,7 @@ def support_only_sql_alchemy_connection(command): def fetch_sql_with_pagination( - table, offset, n_rows, sort_column=None, sort_order=None + table, offset, n_rows, sort_column=None, sort_order=None, with_=None ) -> tuple: """ Returns next n_rows and columns from table starting at the offset @@ -286,19 +293,22 @@ def fetch_sql_with_pagination( sort_order : 'DESC' or 'ASC', default None Order list + + with_ : str, default None + Name of the snippet used to generate the table """ - is_table_exists(table) + if with_ is None: + is_table_exists(table) order_by = "" if not sort_column else f"ORDER BY {sort_column} {sort_order}" + query = f""" SELECT * FROM {table} {order_by} + LIMIT {n_rows} OFFSET {offset}""" - query = f""" - SELECT * FROM {table} {order_by} - OFFSET {offset} ROWS FETCH NEXT {n_rows} ROWS ONLY""" - - rows = ConnectionManager.current.execute(query).fetchall() + rows = ConnectionManager.current.execute(query, with_=with_).fetchall() - columns = ConnectionManager.current.raw_execute( - f"SELECT * FROM {table} WHERE 1=0" + columns = ConnectionManager.current.execute( + f""" SELECT * FROM {table} WHERE 1=0""", + with_=with_, ).keys() return rows, columns @@ -374,6 +384,33 @@ def show_deprecation_warning(): ) +def is_saved_snippet_or_table_exists(table, schema=None, task=None): + """ + Check if the referenced table is a snippet + Parameters + ---------- + table : str, + name of table + + schema : str, default None + Name of the schema + + task: str, default None + Command calling this function + + Returns + ------- + with_ : str, default None + name of the snippet + """ + with_ = None + if is_saved_snippet(table): + with_ = [table] + else: + is_table_exists(table, schema) + return with_ + + def find_path_from_root(file_name): """ Recursively finds an absolute path to file_name starting diff --git a/src/sql/widgets/table_widget/table_widget.py b/src/sql/widgets/table_widget/table_widget.py index 4c8d955fe..9872da7d0 100644 --- a/src/sql/widgets/table_widget/table_widget.py +++ b/src/sql/widgets/table_widget/table_widget.py @@ -5,9 +5,8 @@ from sql.util import ( fetch_sql_with_pagination, parse_sql_results_to_json, - is_table_exists, + is_saved_snippet_or_table_exists, ) - from sql.widgets import utils from sql.telemetry import telemetry @@ -32,7 +31,7 @@ def __init__(self, table): self.html = "" - is_table_exists(table) + self.with_ = is_saved_snippet_or_table_exists(table, task="explore") # load css html_style = utils.load_css(f"{BASE_DIR}/css/tableWidget.css") @@ -57,11 +56,19 @@ def create_table(self, table): Creates an HTML table with default data """ rows_per_page = 10 - rows, columns = fetch_sql_with_pagination(table, 0, rows_per_page) + rows, columns = fetch_sql_with_pagination( + table, 0, rows_per_page, with_=self.with_ + ) rows = parse_sql_results_to_json(rows, columns) - query = f"SELECT count(*) FROM {table}" + query = str( + ConnectionManager.current._prepare_query( + f""" SELECT count(*) FROM {table}""", + with_=self.with_, + ) + ) n_total = ConnectionManager.current.raw_execute(query).fetchone()[0] + table_name = table.strip('"').strip("'") n_pages = math.ceil(n_total / rows_per_page) @@ -184,6 +191,7 @@ def _recv(msg): n_rows, sort_column=sort_column, sort_order=sort_order, + with_=self.with_, ) rows_json = parse_sql_results_to_json(rows, columns) diff --git a/src/tests/integration/test_generic_db_operations.py b/src/tests/integration/test_generic_db_operations.py index 49d982ecb..fa229f152 100644 --- a/src/tests/integration/test_generic_db_operations.py +++ b/src/tests/integration/test_generic_db_operations.py @@ -7,7 +7,7 @@ from sql.error_message import CTE_MSG from unittest.mock import ANY, Mock from IPython.core.error import UsageError - +from IPython.utils.io import capture_output import math ALL_DATABASES = [ @@ -1095,3 +1095,49 @@ def test_autocommit_create_table_multiple_cells( ).result assert len(result) == 3 + + +@pytest.mark.parametrize("ip_with_dynamic_db", ALL_DATABASES) +def test_explore_fetch_sql_with_pagination( + ip_with_dynamic_db, + request, + tmp_empty, +): + if ip_with_dynamic_db not in [ + "ip_with_postgreSQL", + "ip_with_duckDB", + "ip_with_SQLite", + ]: + pytest.xfail(reason="There may be an issue due to code refactoring") + + ip_with_dynamic_db = request.getfixturevalue(ip_with_dynamic_db) + + ip_with_dynamic_db.run_cell( + """ + %%sql + CREATE TABLE people (name varchar(50), number int, country varchar(50)); + INSERT INTO people VALUES ('joe', 82, 'usa'); + INSERT INTO people VALUES ('paula', 93, 'uk'); + INSERT INTO people VALUES ('sam', 77, 'canada'); + INSERT INTO people VALUES ('emily', 65, 'usa'); + INSERT INTO people VALUES ('michael', 89, 'usa'); + INSERT INTO people VALUES ('sarah', 81, 'uk'); + INSERT INTO people VALUES ('james', 76, 'usa'); + INSERT INTO people VALUES ('angela', 88, 'usa'); + INSERT INTO people VALUES ('robert', 82, 'usa'); + INSERT INTO people VALUES ('lisa', 92, 'uk'); + INSERT INTO people VALUES ('mark', 77, 'usa'); + INSERT INTO people VALUES ('jennifer', 68, 'australia'); + """ + ) + ip_with_dynamic_db.run_cell( + """ + %%sql --save citizen + select * from people where country = 'usa' + """ + ) + with capture_output() as captured: + ip_with_dynamic_db.run_cell("%sqlcmd explore -t citizen") + + assert "sql.widgets.table_widget.table_widget.TableWidget object" in captured.stdout + ip_with_dynamic_db.run_cell("""%sql DROP TABLE people""") diff --git a/src/tests/test_magic_cmd.py b/src/tests/test_magic_cmd.py index 965347d46..89dc5c54f 100644 --- a/src/tests/test_magic_cmd.py +++ b/src/tests/test_magic_cmd.py @@ -9,6 +9,7 @@ from sql.store import store from sql.inspect import _is_numeric from sql.display import Table, Message +from IPython.utils.io import capture_output VALID_COMMANDS_MESSAGE = ( @@ -386,6 +387,59 @@ def test_table_profile_store(ip, tmp_empty): assert report.is_file() +def test_table_profile_with_snippets(ip, tmp_empty): + ip.run_cell( + """ + %%sql + CREATE TABLE people (name varchar(50), number int, country varchar(50)); + INSERT INTO people VALUES ('joe', 82, 'usa'); + INSERT INTO people VALUES ('paula', 93, 'uk'); + INSERT INTO people VALUES ('sam', 77, 'canada'); + INSERT INTO people VALUES ('emily', 65, 'usa'); + INSERT INTO people VALUES ('michael', 89, 'usa'); + INSERT INTO people VALUES ('sarah', 81, 'uk'); + INSERT INTO people VALUES ('james', 76, 'usa'); + INSERT INTO people VALUES ('angela', 88, 'usa'); + INSERT INTO people VALUES ('robert', 82, 'usa'); + INSERT INTO people VALUES ('lisa', 92, 'uk'); + INSERT INTO people VALUES ('mark', 77, 'usa'); + INSERT INTO people VALUES ('jennifer', 68, 'australia'); + """ + ) + ip.run_cell( + """ + %%sql --save citizen + select * from people where country = 'usa' + """ + ) + out = ip.run_cell("%sqlcmd profile -t citizen").result + + expected = { + "count": [7, 7, 7], + "mean": [math.nan, "79.8571", math.nan], + "min": [math.nan, 65, math.nan], + "max": [math.nan, 89, math.nan], + "unique": [7, 6, 1], + "freq": [1, math.nan, 7], + "top": ["robert", math.nan, "usa"], + } + + stats_table = out._table + + assert len(stats_table.rows) == len(expected) + + for row in stats_table: + profile_metric = _get_row_string(row, " ") + name = _get_row_string(row, "name") + number = _get_row_string(row, "number") + country = _get_row_string(row, "country") + + assert profile_metric in expected + assert name == str(expected[profile_metric][0]) + assert number == str(expected[profile_metric][1]) + assert country == str(expected[profile_metric][2]) + + @pytest.mark.parametrize( "cell, error_message", [ @@ -605,3 +659,34 @@ def test_delete_invalid_snippet(arg, ip_snippets): assert excinfo.value.error_type == "UsageError" assert str(excinfo.value) == "No such saved snippet found : non_existent_snippet" + + +def test_table_explore_with_snippets(ip, tmp_empty): + ip.run_cell( + """ + %%sql + CREATE TABLE people (name varchar(50), number int, country varchar(50)); + INSERT INTO people VALUES ('joe', 82, 'usa'); + INSERT INTO people VALUES ('paula', 93, 'uk'); + INSERT INTO people VALUES ('sam', 77, 'canada'); + INSERT INTO people VALUES ('emily', 65, 'usa'); + INSERT INTO people VALUES ('michael', 89, 'usa'); + INSERT INTO people VALUES ('sarah', 81, 'uk'); + INSERT INTO people VALUES ('james', 76, 'usa'); + INSERT INTO people VALUES ('angela', 88, 'usa'); + INSERT INTO people VALUES ('robert', 82, 'usa'); + INSERT INTO people VALUES ('lisa', 92, 'uk'); + INSERT INTO people VALUES ('mark', 77, 'usa'); + INSERT INTO people VALUES ('jennifer', 68, 'australia'); + """ + ) + ip.run_cell( + """ + %%sql --save citizen + select * from people where country = 'usa' + """ + ) + with capture_output() as captured: + ip.run_cell("%sqlcmd explore -t citizen") + + assert "sql.widgets.table_widget.table_widget.TableWidget object" in captured.stdout