diff --git a/CHANGELOG.md b/CHANGELOG.md index ac15c181c..cdd98f1c2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,7 +7,7 @@ * [Feature] Added a line under `ResultSet` to distinguish it from data frame and error message when invalid operations are performed (#468) * [Feature] Moved `%sqlrender` feature to `%sqlcmd snippets` (#647) * [Feature] Added tables listing stored snippets when `%sqlcmd snippets` is called (#648) -* [Feature] Modified `%load_ext sql` to load saved configurations from `pyproject.toml` file (#689) +* [Feature] Allow loading configuration values from a `pyproject.toml` file upon magic initialization (#689) * [Doc] Modified integrations content to ensure they're all consistent (#523) * [Doc] Document --persist-replace in API section (#539) * [Fix] Fixed CI issue by updating `invalid_connection_string_duckdb` in `test_magic.py` (#631) diff --git a/doc/api/configuration.md b/doc/api/configuration.md index 2d98a01c2..ad992808f 100644 --- a/doc/api/configuration.md +++ b/doc/api/configuration.md @@ -244,7 +244,7 @@ print(res) ## Loading configuration settings -You can define configurations in a `pyproject.toml` file and automatically load the configurations when you run `%load_ext sql`. It will try to find the file recursively from the current to root directory. If the file is not found, default values will be used. A sample `pyproject.toml` could look like this: +You can define configurations in a `pyproject.toml` file and automatically load the configurations when you run `%load_ext sql`. If the file is not found in the current or parent directories, default values will be used. A sample `pyproject.toml` could look like this: ``` [tool.jupysql.SqlMagic] diff --git a/src/sql/display.py b/src/sql/display.py index c0c2e7f6c..44a682363 100644 --- a/src/sql/display.py +++ b/src/sql/display.py @@ -90,3 +90,8 @@ def message(message): def message_success(message): """Display a success message""" display(Message(message, style="color: green")) + + +def table(headers, rows): + """Display a table""" + display(Table(headers, rows)) diff --git a/src/sql/magic.py b/src/sql/magic.py index cb92356c3..ff2340712 100644 --- a/src/sql/magic.py +++ b/src/sql/magic.py @@ -1,8 +1,5 @@ import json import re -from pathlib import Path -import toml -import difflib try: from ipywidgets import interact @@ -33,7 +30,7 @@ from sql.magic_plot import SqlPlotMagic from sql.magic_cmd import SqlCmdMagic from sql._patch import patch_ipython_usage_error -from sql import query_util +from sql import query_util, util from sql.util import get_suggestions_message, pretty_print from ploomber_core.dependencies import check_installed @@ -622,71 +619,37 @@ def _persist_dataframe( display.message_success(f"Success! Persisted {table_name} to the database.") -def find_toml_path(): - """ - Recursively finds an absolute path to pyproject.toml starting - from current directory - """ - current = Path().resolve() - while not (current / "pyproject.toml").exists(): - if current == current.parent: - return None - - current = current.parent - display.message(f"Found pyproject.toml from '{current}'") - - return str(current) + "/pyproject.toml" - - -def get_configs(file_path): - """Returns saved configs from given file_path""" - with open(file_path, "r") as file: - content = file.read() - data = toml.loads(content) - try: - return data["tool"]["jupysql"]["SqlMagic"] - except KeyError: - return {} - +def set_configs(ip, file_path): + """Set user defined SqlMagic configuration settings""" + sql = ip.find_cell_magic("sql").__self__ + user_configs = util.get_user_configs(file_path, ["tool", "jupysql", "SqlMagic"]) + default_configs = util.get_default_configs(sql) + table_rows = [] + for config, value in user_configs.items(): + if config in default_configs.keys(): + default_type = type(default_configs[config]) + if isinstance(value, default_type): + setattr(sql, config, value) + table_rows.append([config, value]) + else: + display.message( + f"'{value}' is an invalid value for '{config}'. " + f"Please use {default_type} value instead." + ) + else: + util.find_close_match_config(config, default_configs.keys()) -def get_config_variables(obj): - """Returns a list of configs users can set""" - config_vars = list(vars(obj)["_all_trait_default_generators"].keys()) - config_vars.remove("config") - config_vars.remove("parent") - return config_vars + return table_rows -def load_configs(ip): - """Loads saved configs in pyproject.toml""" - file_path = find_toml_path() +def load_SqlMagic_configs(ip): + """Loads saved SqlMagic configs in pyproject.toml""" + file_path = util.find_path_from_root("pyproject.toml") if file_path: - configs = get_configs(file_path) - config_vars = get_config_variables(SqlMagic) - for config, value in configs.items(): - if config in config_vars: - default_value = getattr(SqlMagic, config).default_value - if isinstance(value, type(default_value)): - ip.run_cell(f"%config SqlMagic.{config} = {value}") - display.message(f"Setting '{config}' into '{value}'") - else: - display.message( - f"'{value}' is an invalid value for '{config}'. " - f"Please use {type(default_value)} value instead." - ) - else: - closest_match = difflib.get_close_matches(config, config_vars, n=1) - if not closest_match: - display.message( - f"'{config}' is an invalid configuration. Please review our " - "configuration guideline: " - "https://jupysql.ploomber.io/en/latest/api/configuration.html#options" # noqa breaks the check-for-broken-links - ) - else: - display.message( - f"'{config}' is an invalid configuration. " - f"Did you mean '{closest_match[0]}'?" - ) + table_rows = set_configs(ip, file_path) + if table_rows: + display.message("Setting changed: ") + display.table(["Config", "value"], table_rows) def load_ipython_extension(ip): @@ -702,6 +665,6 @@ def load_ipython_extension(ip): ip.register_magics(SqlPlotMagic) ip.register_magics(SqlCmdMagic) - load_configs(ip) + load_SqlMagic_configs(ip) patch_ipython_usage_error(ip) diff --git a/src/sql/util.py b/src/sql/util.py index 3df5122cf..f3cfd4f65 100644 --- a/src/sql/util.py +++ b/src/sql/util.py @@ -4,8 +4,10 @@ import difflib from sql.connection import Connection from sql.store import store, _get_dependents_for_key -from sql import exceptions +from sql import exceptions, display import json +from pathlib import Path +import toml SINGLE_QUOTE = "'" DOUBLE_QUOTE = '"' @@ -363,3 +365,120 @@ def show_deprecation_warning(): "raise an exception in the next major release so please remove it.", FutureWarning, ) + + +def find_path_from_root(file_name): + """ + Recursively finds an absolute path to file_name starting + from current to root directory + """ + current = Path().resolve() + while not (current / file_name).exists(): + if current == current.parent: + return None + + current = current.parent + display.message(f"Found {file_name} from '{current}'") + + return str(Path(current, file_name)) + + +def find_close_match_config(word, possibilities, n=3): + """Finds closest matching configurations and displays message""" + closest_matches = difflib.get_close_matches(word, possibilities, n=n) + if not closest_matches: + display.message( + f"'{word}' is an invalid configuration. Please review our " + "configuration guideline: " + "https://jupysql.ploomber.io/en/latest/api/configuration.html#options" + ) + else: + display.message( + f"'{word}' is an invalid configuration. Did you mean " + f"{pretty_print(closest_matches, last_delimiter='or')}?" + ) + + +def get_line_content_from_toml(file_path, line_number): + """ + Locates a line that error occurs when loading a toml file + and returns the line, key, and value + """ + with open(file_path, "r") as file: + lines = file.readlines() + eline = lines[line_number - 1].strip() + ekey, evalue = None, None + if "=" in eline: + ekey, evalue = map(str.strip, eline.split("=")) + return eline, ekey, evalue + + +def load_toml(file_path): + """ + Returns toml file content in a dictionary format + and raises error if it fails to load the toml file + """ + try: + with open(file_path, "r") as file: + content = file.read() + return toml.loads(content) + except toml.TomlDecodeError as e: + eline, ekey, evalue = get_line_content_from_toml(file_path, e.lineno) + if "Duplicate keys!" in str(e): + raise ValueError(f"Duplicate key found : '{ekey}'") from None + elif "Only all lowercase booleans" in str(e): + raise ValueError( + f"Invalid value '{evalue}' in '{eline}'. " + "Valid boolean values: true, false" + ) from None + elif "invalid literal for int()" in str(e): + raise ValueError( + f"Invalid value '{evalue}' in '{eline}'. " + "To use 'str' value, enclose it with ' or \"." + ) from None + else: + raise e from None + + +def get_user_configs(file_path, section_names): + """ + Returns saved configuration settings in a toml file from given file_path + + Parameters + ---------- + file_path : str + file path to a toml file + section_names : list + section names that contains the configuration settings + + Returns + ------- + dict + saved configuration settings + """ + data = load_toml(file_path) + while section_names: + actual, given = section_names.pop(0), data.keys() + if actual not in given: + close_match = difflib.get_close_matches(actual, given) + if not close_match: + raise KeyError( + f"'{actual}' not found from pyproject.toml. " + "Please check the following section names: " + f"{pretty_print(given)}." + ) + else: + raise KeyError( + f"{pretty_print(close_match)} is an invalid section name. " + f"Did you mean '{actual}'?" + ) + data = data[actual] + return data + + +def get_default_configs(sql): + """Returns a list of SqlMagic config settings users can set""" + default_configs = sql.trait_defaults() + del default_configs["parent"] + del default_configs["config"] + return default_configs diff --git a/src/tests/pyproject.toml b/src/tests/pyproject.toml deleted file mode 100644 index ad944d370..000000000 --- a/src/tests/pyproject.toml +++ /dev/null @@ -1,5 +0,0 @@ -[tool.jupysql.SqlMagic] -autocom = true -autolimit = "text" -invalid = false -displaycon = false \ No newline at end of file diff --git a/src/tests/test_magic.py b/src/tests/test_magic.py index 5df1d9531..d803217b2 100644 --- a/src/tests/test_magic.py +++ b/src/tests/test_magic.py @@ -6,6 +6,7 @@ import re import sys import tempfile +from toml import TomlDecodeError from textwrap import dedent from unittest.mock import patch @@ -1298,28 +1299,171 @@ def test_interact_and_missing_ipywidgets_installed(ip): assert isinstance(out.error_in_exec, ModuleNotFoundError) -def test_loading_config(request, monkeypatch, ip, capsys): - monkeypatch.chdir(request.fspath.dirname) +@pytest.mark.parametrize( + "file_content, expect, revert", + [ + ( + """ +[tool.jupysql.SqlMagic] +autocommit = false +autolimit = 1 +style = "RANDOM" +""", + [ + "Found pyproject.toml", + r"autocommit\s*\|\s*False", + r"autolimit\s*\|\s*1", + r"style\s*\|\s*RANDOM", + ], + {"autocommit": True, "autolimit": 0, "style": "DEFAULT"} + ), + ( + """ +[tool.jupysql.SqlMagic] +""", + ["Found pyproject.toml"], + {} + ), + ( + "", + [""], + {} + ) + ], +) +def test_valid_loading_toml(tmp_empty, ip, capsys, file_content, expect, revert): + Path("pyproject.toml").write_text(file_content) + os.makedirs("sub") + os.chdir("sub") + ip.run_cell("%load_ext sql").result out, _ = capsys.readouterr() - # checking displayed messages - assert "Found pyproject.toml from " in out - assert "'autocom' is an invalid configuration. Did you mean 'autocommit'?" in out - assert ( - "'text' is an invalid value for 'autolimit'. Please use " - " value instead." in out - ) - assert ( - "'invalid' is an invalid configuration. Please review our configuration " - "guideline: https://jupysql.ploomber.io/en/latest/" - "api/configuration.html#options" in out - ) - assert "Setting 'displaycon' into 'False'" in out + assert all(re.search(substring, out) for substring in expect) + + sql = ip.find_cell_magic("sql").__self__ + [setattr(sql, config, value) for config, value in revert.items()] + + +@pytest.mark.parametrize( + "file_content, error_type, error_msg", + [ + ( + """ +[tool.jupySql.SqlMagic] +autocommit = true +""", + KeyError, + "'jupySql' is an invalid section name. Did you mean 'jupysql'?", + ), + ( + """ +[tool.invalid_section.SqlMagic] +autocommit = true +""", + KeyError, + ( + "'jupysql' not found from pyproject.toml. Please check the following " + "section names: 'invalid_section'." + ) + ), + ( + """ +[tool.jupysql.SqlMagic] +autocommit = true +autocommit = true +""", + ValueError, + "Duplicate key found : 'autocommit'", + ), + ( + """ +[tool.jupysql.SqlMagic] +autocommit = True +""", + ValueError, + ( + "Invalid value 'True' in 'autocommit = True'. " + "Valid boolean values: true, false" + ), + ), + ( + """ +[tool.jupysql.SqlMagic] +autocommit = invalid +""", + ValueError, + ( + "Invalid value 'invalid' in 'autocommit = invalid'. " + "To use 'str' value, enclose it with ' or \"." + ), + ), + ( + """ +[tool.jupysql.SqlMagic] +autocommit = 1. +""", + TomlDecodeError, + "" + ), + ], +) +def test_invalid_loading_toml(tmp_empty, ip, file_content, error_type, error_msg): + Path("pyproject.toml").write_text(file_content) + os.makedirs("sub") + os.chdir("sub") + + out = ip.run_cell("%load_ext sql") + + assert isinstance(out.error_in_exec, error_type) + assert error_msg in str(out.error_in_exec) + + +@pytest.mark.parametrize( + "file_content, expect, confirm, revert", + [ + ( + """ +[tool.jupysql.SqlMagic] +autocom = true +autop = false +autolimit = "text" +invalid = false +displaycon = false +""", + [ + "Found pyproject.toml", + "'autocom' is an invalid configuration. Did you mean 'autocommit'?", + ( + "'autop' is an invalid configuration. " + "Did you mean 'autopandas', or 'autopolars'?" + ), + ( + "'text' is an invalid value for 'autolimit'. " + "Please use value instead." + ), + ( + "'invalid' is an invalid configuration. Please review our " + "configuration guideline: https://jupysql.ploomber.io/en/latest/api/configuration.html#options" # noqa + ), + r"displaycon\s*\|\s*False", + ], + {"displaycon": False, "autolimit": 0}, + {"displaycon": True} + ), + ], +) +def test_mixed_loading_toml(tmp_empty, ip, capsys, + file_content, expect, confirm, revert): + Path("pyproject.toml").write_text(file_content) + os.makedirs("sub") + os.chdir("sub") + + ip.run_cell("%load_ext sql").result + out, _ = capsys.readouterr() - # checking configurations - assert ip.run_line_magic("config", "SqlMagic.autocommit") is True - assert ip.run_line_magic("config", "SqlMagic.autolimit") == 0 - assert ip.run_line_magic("config", "SqlMagic.displaycon") is False + assert all(re.search(substring, out) for substring in expect) - ip.run_line_magic("config", "SqlMagic.displaycon=True") + sql = ip.find_cell_magic("sql").__self__ + assert all([getattr(sql, config) == value for config, value in confirm.items()]) + [setattr(sql, config, value) for config, value in revert.items()]