Skip to content

Commit

Permalink
update based on the first comments
Browse files Browse the repository at this point in the history
  • Loading branch information
bbeat2782 committed Jul 14, 2023
1 parent d234872 commit 02f5aa2
Show file tree
Hide file tree
Showing 7 changed files with 320 additions and 94 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
* [Feature] Added a line under `ResultSet` to distinguish it from data frame and error message when invalid operations are performed (#468)
* [Feature] Moved `%sqlrender` feature to `%sqlcmd snippets` (#647)
* [Feature] Added tables listing stored snippets when `%sqlcmd snippets` is called (#648)
* [Feature] Modified `%load_ext sql` to load saved configurations from `pyproject.toml` file (#689)
* [Feature] Allow loading configuration values from a `pyproject.toml` file upon magic initialization (#689)
* [Doc] Modified integrations content to ensure they're all consistent (#523)
* [Doc] Document --persist-replace in API section (#539)
* [Fix] Fixed CI issue by updating `invalid_connection_string_duckdb` in `test_magic.py` (#631)
Expand Down
2 changes: 1 addition & 1 deletion doc/api/configuration.md
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ print(res)

## Loading configuration settings

You can define configurations in a `pyproject.toml` file and automatically load the configurations when you run `%load_ext sql`. It will try to find the file recursively from the current to root directory. If the file is not found, default values will be used. A sample `pyproject.toml` could look like this:
You can define configurations in a `pyproject.toml` file and automatically load the configurations when you run `%load_ext sql`. If the file is not found in the current or parent directories, default values will be used. A sample `pyproject.toml` could look like this:

```
[tool.jupysql.SqlMagic]
Expand Down
5 changes: 5 additions & 0 deletions src/sql/display.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,3 +90,8 @@ def message(message):
def message_success(message):
"""Display a success message"""
display(Message(message, style="color: green"))


def table(headers, rows):
"""Display a table"""
display(Table(headers, rows))
95 changes: 29 additions & 66 deletions src/sql/magic.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
import json
import re
from pathlib import Path
import toml
import difflib

try:
from ipywidgets import interact
Expand Down Expand Up @@ -33,7 +30,7 @@
from sql.magic_plot import SqlPlotMagic
from sql.magic_cmd import SqlCmdMagic
from sql._patch import patch_ipython_usage_error
from sql import query_util
from sql import query_util, util
from sql.util import get_suggestions_message, pretty_print
from ploomber_core.dependencies import check_installed

Expand Down Expand Up @@ -622,71 +619,37 @@ def _persist_dataframe(
display.message_success(f"Success! Persisted {table_name} to the database.")


def find_toml_path():
"""
Recursively finds an absolute path to pyproject.toml starting
from current directory
"""
current = Path().resolve()
while not (current / "pyproject.toml").exists():
if current == current.parent:
return None

current = current.parent
display.message(f"Found pyproject.toml from '{current}'")

return str(current) + "/pyproject.toml"


def get_configs(file_path):
"""Returns saved configs from given file_path"""
with open(file_path, "r") as file:
content = file.read()
data = toml.loads(content)
try:
return data["tool"]["jupysql"]["SqlMagic"]
except KeyError:
return {}

def set_configs(ip, file_path):
"""Set user defined SqlMagic configuration settings"""
sql = ip.find_cell_magic("sql").__self__
user_configs = util.get_user_configs(file_path, ["tool", "jupysql", "SqlMagic"])
default_configs = util.get_default_configs(sql)
table_rows = []
for config, value in user_configs.items():
if config in default_configs.keys():
default_type = type(default_configs[config])
if isinstance(value, default_type):
setattr(sql, config, value)
table_rows.append([config, value])
else:
display.message(
f"'{value}' is an invalid value for '{config}'. "
f"Please use {default_type} value instead."
)
else:
util.find_close_match_config(config, default_configs.keys())

def get_config_variables(obj):
"""Returns a list of configs users can set"""
config_vars = list(vars(obj)["_all_trait_default_generators"].keys())
config_vars.remove("config")
config_vars.remove("parent")
return config_vars
return table_rows


def load_configs(ip):
"""Loads saved configs in pyproject.toml"""
file_path = find_toml_path()
def load_SqlMagic_configs(ip):
"""Loads saved SqlMagic configs in pyproject.toml"""
file_path = util.find_path_from_root("pyproject.toml")
if file_path:
configs = get_configs(file_path)
config_vars = get_config_variables(SqlMagic)
for config, value in configs.items():
if config in config_vars:
default_value = getattr(SqlMagic, config).default_value
if isinstance(value, type(default_value)):
ip.run_cell(f"%config SqlMagic.{config} = {value}")
display.message(f"Setting '{config}' into '{value}'")
else:
display.message(
f"'{value}' is an invalid value for '{config}'. "
f"Please use {type(default_value)} value instead."
)
else:
closest_match = difflib.get_close_matches(config, config_vars, n=1)
if not closest_match:
display.message(
f"'{config}' is an invalid configuration. Please review our "
"configuration guideline: "
"https://jupysql.ploomber.io/en/latest/api/configuration.html#options" # noqa breaks the check-for-broken-links
)
else:
display.message(
f"'{config}' is an invalid configuration. "
f"Did you mean '{closest_match[0]}'?"
)
table_rows = set_configs(ip, file_path)
if table_rows:
display.message("Setting changed: ")
display.table(["Config", "value"], table_rows)


def load_ipython_extension(ip):
Expand All @@ -702,6 +665,6 @@ def load_ipython_extension(ip):
ip.register_magics(SqlPlotMagic)
ip.register_magics(SqlCmdMagic)

load_configs(ip)
load_SqlMagic_configs(ip)

patch_ipython_usage_error(ip)
121 changes: 120 additions & 1 deletion src/sql/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@
import difflib
from sql.connection import Connection
from sql.store import store, _get_dependents_for_key
from sql import exceptions
from sql import exceptions, display
import json
from pathlib import Path
import toml

SINGLE_QUOTE = "'"
DOUBLE_QUOTE = '"'
Expand Down Expand Up @@ -363,3 +365,120 @@ def show_deprecation_warning():
"raise an exception in the next major release so please remove it.",
FutureWarning,
)


def find_path_from_root(file_name):
"""
Recursively finds an absolute path to file_name starting
from current to root directory
"""
current = Path().resolve()
while not (current / file_name).exists():
if current == current.parent:
return None

current = current.parent
display.message(f"Found {file_name} from '{current}'")

return str(Path(current, file_name))


def find_close_match_config(word, possibilities, n=3):
"""Finds closest matching configurations and displays message"""
closest_matches = difflib.get_close_matches(word, possibilities, n=n)
if not closest_matches:
display.message(
f"'{word}' is an invalid configuration. Please review our "
"configuration guideline: "
"https://jupysql.ploomber.io/en/latest/api/configuration.html#options"
)
else:
display.message(
f"'{word}' is an invalid configuration. Did you mean "
f"{pretty_print(closest_matches, last_delimiter='or')}?"
)


def get_line_content_from_toml(file_path, line_number):
"""
Locates a line that error occurs when loading a toml file
and returns the line, key, and value
"""
with open(file_path, "r") as file:
lines = file.readlines()
eline = lines[line_number - 1].strip()
ekey, evalue = None, None
if "=" in eline:
ekey, evalue = map(str.strip, eline.split("="))
return eline, ekey, evalue


def load_toml(file_path):
"""
Returns toml file content in a dictionary format
and raises error if it fails to load the toml file
"""
try:
with open(file_path, "r") as file:
content = file.read()
return toml.loads(content)
except toml.TomlDecodeError as e:
eline, ekey, evalue = get_line_content_from_toml(file_path, e.lineno)
if "Duplicate keys!" in str(e):
raise ValueError(f"Duplicate key found : '{ekey}'") from None
elif "Only all lowercase booleans" in str(e):
raise ValueError(
f"Invalid value '{evalue}' in '{eline}'. "
"Valid boolean values: true, false"
) from None
elif "invalid literal for int()" in str(e):
raise ValueError(
f"Invalid value '{evalue}' in '{eline}'. "
"To use 'str' value, enclose it with ' or \"."
) from None
else:
raise e from None


def get_user_configs(file_path, section_names):
"""
Returns saved configuration settings in a toml file from given file_path
Parameters
----------
file_path : str
file path to a toml file
section_names : list
section names that contains the configuration settings
Returns
-------
dict
saved configuration settings
"""
data = load_toml(file_path)
while section_names:
actual, given = section_names.pop(0), data.keys()
if actual not in given:
close_match = difflib.get_close_matches(actual, given)
if not close_match:
raise KeyError(
f"'{actual}' not found from pyproject.toml. "
"Please check the following section names: "
f"{pretty_print(given)}."
)
else:
raise KeyError(
f"{pretty_print(close_match)} is an invalid section name. "
f"Did you mean '{actual}'?"
)
data = data[actual]
return data


def get_default_configs(sql):
"""Returns a list of SqlMagic config settings users can set"""
default_configs = sql.trait_defaults()
del default_configs["parent"]
del default_configs["config"]
return default_configs
5 changes: 0 additions & 5 deletions src/tests/pyproject.toml

This file was deleted.

Loading

0 comments on commit 02f5aa2

Please sign in to comment.