Skip to content

Commit

Permalink
Update configurations for ruff (mckinsey#304)
Browse files Browse the repository at this point in the history
  • Loading branch information
huong-li-nguyen authored Feb 8, 2024
1 parent d467cda commit 5f09980
Show file tree
Hide file tree
Showing 97 changed files with 235 additions and 96 deletions.
31 changes: 18 additions & 13 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,15 @@
# THIS IS NOT DESCRIBING A PACKAGE, but the DEV environment of this mono-repo
# In order to install the packages of this mono-repo from source, refer to the pyproject.toml in the relevant folder

[lint.isort]
known-first-party = ["vizro"]

[lint.pydocstyle]
convention = "google"

[lint.pylint]
max-args = 6

[project]
authors = [
{name = "Vizro Team"}
Expand Down Expand Up @@ -48,12 +57,18 @@ warn_required_dynamic_aliases = true
warn_untyped_fields = true

[tool.ruff]
line-length = 120
target-version = "py38"

[tool.ruff.lint]
# see: https://beta.ruff.rs/docs/rules/
ignore = [
"D104" # undocumented-public-package
"D104", # undocumented-public-package
"D401", # first-line should be in imperative mood
# D407 needs to be ignored as it otherwise messes up the formatting in our API docs
"D407" # missing dashed underline after section
]
ignore-init-module-imports = true
line-length = 120
select = [
"E", # pycodestyle errors
"W", # pycodestyle warnings
Expand All @@ -65,12 +80,8 @@ select = [
"RUF", # Ruff-specific rules
"PL" # pylint
]
target-version = "py38"

[tool.ruff.isort]
known-first-party = ["vizro"]

[tool.ruff.per-file-ignores]
[tool.ruff.lint.per-file-ignores]
# Ignore current false positives for pydantic models subclassing across files
# See: https://github.com/astral-sh/ruff/issues/5243#issuecomment-1860776975
"**/models/**" = ["RUF012"]
Expand All @@ -79,9 +90,3 @@ known-first-party = ["vizro"]
"**/tests/**" = ["PLR2004", "S101", "TID252", "D100", "D101", "D102", "D103", "PLC1901", "RUF012"]
# Ignore import violations in all __init__.py files
"__init__.py" = ["E402", "F401"]

[tool.ruff.pydocstyle]
convention = "google"

[tool.ruff.pylint]
max-args = 6
3 changes: 2 additions & 1 deletion tools/check_for_datafiles.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,9 @@
def check_for_data_files():
"""Recursively finds all data files in non-whitelisted folders.
Raises:
Raises
AssertionError if data files are present in non-whitelisted folders.
"""
project_dir = str(Path(__file__).parent.parent)
whitelist_dir = {f"{project_dir}{dir}" for dir in whitelist_folders}
Expand Down
48 changes: 48 additions & 0 deletions vizro-ai/changelog.d/20240208_161900_huong_li_nguyen_tidy_ruff.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
<!--
A new scriv changelog fragment.
Uncomment the section that is right (remove the HTML comment wrapper).
-->

<!--
### Highlights ✨
- A bullet item for the Highlights ✨ category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
<!--
### Removed
- A bullet item for the Removed category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
<!--
### Added
- A bullet item for the Added category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
<!--
### Changed
- A bullet item for the Changed category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
<!--
### Deprecated
- A bullet item for the Deprecated category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
<!--
### Fixed
- A bullet item for the Fixed category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
<!--
### Security
- A bullet item for the Security category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
3 changes: 3 additions & 0 deletions vizro-ai/src/vizro_ai/_vizro_ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def __init__(self, model_name: str = "gpt-3.5-turbo-0613", temperature: int = 0)
Args:
model_name: Model name in string format.
temperature: Temperature parameter for LLM.
"""
self.model_name = model_name
self.temperature = temperature
Expand Down Expand Up @@ -97,6 +98,7 @@ def _get_chart_code(self, df: pd.DataFrame, user_input: str) -> str:
Args:
df: The dataframe to be analyzed
user_input: User questions or descriptions of the desired visual
"""
# TODO refine and update error handling
return self._run_plot_tasks(df, user_input, explain=False).get("code_string")
Expand All @@ -111,6 +113,7 @@ def plot(
user_input: User questions or descriptions of the desired visual.
explain: Flag to include explanation in response.
max_debug_retry: Maximum number of retries to debug errors. Defaults to `3`.
"""
output_dict = self._run_plot_tasks(df, user_input, explain=explain, max_debug_retry=max_debug_retry)
code_string = output_dict.get("code_string")
Expand Down
3 changes: 2 additions & 1 deletion vizro-ai/src/vizro_ai/chains/_llm_chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def execute_chain(self, input_str: str):
class FunctionCallChain(VizroBaseChain, ABC):
"""LLM Chain with Function Calling."""

def __init__(
def __init__( # noqa: PLR0913
self,
llm: LLM_MODELS,
raw_prompt: str,
Expand Down Expand Up @@ -107,6 +107,7 @@ def execute_chain(self, input_str: str) -> Dict[str, Any]:
Returns:
args as a dictionary
"""
raw_ans = self.chain.generate([{"input": input_str}])
args = self._custom_parse(raw_ans.generations[0])
Expand Down
4 changes: 3 additions & 1 deletion vizro-ai/src/vizro_ai/chains/_llm_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,11 @@
class LLM(BaseModel):
"""Represents a Language Learning Model (LLM).
Attributes:
Attributes
name (str): The name of the LLM.
max_tokens (int): The maximum number of tokens that the LLM can handle.
wrapper (callable): The langchain function used to instantiate the model.
"""

name: str
Expand Down Expand Up @@ -73,6 +74,7 @@ def get_llm_model(self, model_name: str, temperature: float = 0) -> LLM_MODELS:
Raises:
ValueError: If the model name is not found.
"""
model = self.models.get(model_name.lower())
if model:
Expand Down
4 changes: 3 additions & 1 deletion vizro-ai/src/vizro_ai/components/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
class VizroAiComponentBase(ABC):
"""Abstract Base Class that represents a blueprint for Vizro-AI components.
Attributes:
Attributes
prompt (str): Prompt for specific components.
Public Methods:
Expand All @@ -17,6 +17,7 @@ class VizroAiComponentBase(ABC):
Private Methods:
_pre_process: A helper method for LLMChain input vars preprocess.
_post_process: Another helper method for LLMChain output postprocess.
"""

prompt: str = "default prompt place holder"
Expand All @@ -26,6 +27,7 @@ def __init__(self, llm: LLM_MODELS):
Args:
llm: LLM model wrapped with Langchain wrapper.
"""
self.llm = llm

Expand Down
3 changes: 2 additions & 1 deletion vizro-ai/src/vizro_ai/components/chart_selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,9 @@ class ChartSelection(BaseModel):
class GetChartSelection(VizroAiComponentBase):
"""Get Chart Types.
Attributes:
Attributes
prompt (str): Prompt chart selection chains.
"""

prompt: str = chart_type_prompt
Expand Down
3 changes: 2 additions & 1 deletion vizro-ai/src/vizro_ai/components/code_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,9 @@ class CodeDebug(BaseModel):
class GetDebugger(VizroAiComponentBase):
"""Get Visual code.
Attributes:
Attributes
prompt (str): Prompt visual code.
"""

prompt: str = debugging_prompt
Expand Down
3 changes: 2 additions & 1 deletion vizro-ai/src/vizro_ai/components/custom_chart_wrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,9 @@ class GetCustomChart(VizroAiComponentBase):
# TODO Explore if it is possible to create CustomChart without LLM
"""Get custom chart code.
Attributes:
Attributes
prompt (str): Prompt custom chart code.
"""

prompt: str = custom_chart_prompt
Expand Down
3 changes: 2 additions & 1 deletion vizro-ai/src/vizro_ai/components/dataframe_craft.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,9 @@ class DataFrameCraft(BaseModel):
class GetDataFrameCraft(VizroAiComponentBase):
"""Get dataframe code.
Attributes:
Attributes
prompt (str): Prompt dataframe wrangling chain.
"""

prompt: str = dataframe_prompt
Expand Down
3 changes: 2 additions & 1 deletion vizro-ai/src/vizro_ai/components/explanation.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,9 @@ class CodeExplanation(BaseModel):
class GetCodeExplanation(VizroAiComponentBase):
"""Get Explanation of a code snippet.
Attributes:
Attributes
prompt (str): Prompt code explanation.
"""

prompt: str = code_explanation_prompt
Expand Down
3 changes: 2 additions & 1 deletion vizro-ai/src/vizro_ai/components/visual_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,9 @@ class VizroCode(BaseModel):
class GetVisualCode(VizroAiComponentBase):
"""Get Visual code.
Attributes:
Attributes
prompt (str): Prompt visual code.
"""

prompt: str = visual_code_prompt
Expand Down
1 change: 1 addition & 0 deletions vizro-ai/src/vizro_ai/schema_manager/schema_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ def register(self, obj: Union[Callable, BaseModel]):
Args:
obj (Union[Callable, BaseModel]): function or pydantic model to register
"""
if inspect.isfunction(obj):
annotations = obj.__annotations__
Expand Down
2 changes: 2 additions & 0 deletions vizro-ai/src/vizro_ai/task_pipeline/_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ def __init__(self, llm: LLM_MODELS):
Args:
llm: The LLM instance to be used by components in the pipeline.
"""
self.llm = llm
self.components = []
Expand All @@ -27,6 +28,7 @@ def add(self, component_class, input_keys: Optional[List[str]] = None, output_ke
These should match the output keys of previous components in the pipeline, if applicable.
output_key: The key or identifier for the output that this component will produce.
This can be used as an input key for subsequent components.
"""
self.components.append((component_class, input_keys, output_key))

Expand Down
1 change: 1 addition & 0 deletions vizro-ai/src/vizro_ai/task_pipeline/_pipeline_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ def __init__(self, llm: LLM_MODELS = None):
Args:
llm: Large language Model.
"""
self.llm = llm

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

@pytest.fixture
def fake_llm():
"""This is to simulate the response of LLM."""
"""Simulate the response of LLM."""
response = ['{{"fixed_code": "{}"}}'.format("print(df[['country', 'continent']])")]
return FakeListLLM(responses=response)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def some_chart_name(data_frame):

@pytest.fixture
def fake_llm(output_custom_chart_LLM_1):
"""This is to simulate the response of LLM."""
"""Simulate the response of LLM."""
response = ['{{"custom_chart_code": "{}"}}'.format(output_custom_chart_LLM_1)]
return FakeListLLM(responses=response)

Expand Down
11 changes: 9 additions & 2 deletions vizro-ai/tests/unit/vizro-ai/components/test_visual_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def expected_final_output_2():

@pytest.fixture
def fake_llm(output_visual_code_LLM_1):
"""This is to simulate the response of LLM."""
"""Simulate the response of LLM."""
response = ['{{"visual_code": "{}"}}'.format(output_visual_code_LLM_1)]
return FakeListLLM(responses=response)

Expand Down Expand Up @@ -106,7 +106,14 @@ def test_post_process(self, input, output, df_code, request):


class TestGetVisualCodeRun:
def test_fake_run(self, fake_llm, output_visual_code_LLM_1, expected_final_output_1, df_code_1, chart_types):
def test_fake_run( # noqa: PLR0913
self,
fake_llm,
output_visual_code_LLM_1,
expected_final_output_1,
df_code_1,
chart_types,
):
get_visual_code = GetVisualCode(fake_llm)
processed_code = get_visual_code.run(
chain_input=output_visual_code_LLM_1, df_code=df_code_1, chart_types=chart_types
Expand Down
1 change: 0 additions & 1 deletion vizro-ai/tests/unit/vizro-ai/conftest.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""Fixtures to be shared across several tests."""

import pytest

import vizro.plotly.express as px


Expand Down
Loading

0 comments on commit 5f09980

Please sign in to comment.