diff --git a/.github/workflows/nemo_inspector_tests.yaml b/.github/workflows/nemo_inspector_tests.yaml deleted file mode 100644 index 4fa8ee7c0..000000000 --- a/.github/workflows/nemo_inspector_tests.yaml +++ /dev/null @@ -1,45 +0,0 @@ -name: NeMo Inspector Tool Tests - -on: - pull_request: - branches: [ "main" ] - workflow_dispatch: - -permissions: - contents: read - -concurrency: - group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} - cancel-in-progress: true - -jobs: - inspector-tests: - runs-on: ubuntu-latest - - steps: - - uses: actions/checkout@v3 - - - name: Set up Python 3.10 - uses: actions/setup-python@v3 - with: - python-version: "3.10" - - - name: Install dependencies - run: | - python -m pip install --upgrade pip - pip install . - pip install -r requirements/inspector.txt - pip install -r requirements/inspector-tests.txt - pip install -r requirements/common-tests.txt - - - name: Set up Chrome - uses: browser-actions/setup-chrome@latest - - - name: Set up ChromeDriver - uses: nanasess/setup-chromedriver@master - - - name: Run NeMo inspector tests - run: | - export DISPLAY=:99 - sudo Xvfb :99 -ac & - python -m pytest nemo_inspector/tests/ --reruns 3 --reruns-delay 1 --junitxml=results.xml --cov-report=term-missing:skip-covered --cov=visualization --durations=30 -vv diff --git a/README.md b/README.md index 6f49d3820..73bf9298b 100644 --- a/README.md +++ b/README.md @@ -89,10 +89,7 @@ See our [paper](https://arxiv.org/abs/2410.01560) for ablations studies and more ## Nemo Inspector -We also provide a convenient [tool](/nemo_inspector/Readme.md) for visualizing inference and data analysis -| Overview | Inference Page | Analyze Page | -| :-------------------------------------------------------------------------------------------------: | :---------------------------------------------------------------------------------------------------------------------: | :-----------------------------------------------------------------------------------------------------------------: | -| [![Demo of the tool](/nemo_inspector/images/demo.png)](https://www.youtube.com/watch?v=EmBFEl7ydqE) | [![Demo of the inference page](/nemo_inspector/images/inference_page.png)](https://www.youtube.com/watch?v=6utSkPCdNks) | [![Demo of the analyze page](/nemo_inspector/images/analyze_page.png)](https://www.youtube.com/watch?v=cnPyDlDmQXg) | +We also provide a convenient [tool](https://github.com/NVIDIA/NeMo-Inspector) for visualizing inference and data analysis. ## Papers diff --git a/nemo_inspector/Readme.md b/nemo_inspector/Readme.md deleted file mode 100644 index 467208306..000000000 --- a/nemo_inspector/Readme.md +++ /dev/null @@ -1,140 +0,0 @@ -# NeMo Inspector tool - -## Demo -This is a tool for data analysis, consisting of two pages: "Inference" and "Analyze". 
- -### Overview of the tool -[![Demo of the tool](/nemo_inspector/images/demo.png)](https://www.youtube.com/watch?v=EmBFEl7ydqE) - -### Demo of the Inference Page -[![Demo of the inference page](/nemo_inspector/images/inference_page.png)](https://www.youtube.com/watch?v=6utSkPCdNks) - -### Demo of the Analyze Page -[![Demo of the analyze page](/nemo_inspector/images/analyze_page.png)](https://www.youtube.com/watch?v=cnPyDlDmQXg) - -## Getting Started -Before using this tool, follow the instructions in [prerequisites.md](/docs/prerequisites.md) and install the requirements: -```shell -pip install -r requirements/inspector.txt -``` -You can adjust parameters in the [inspector_config.yaml](/nemo_inspector/settings/inspector_config.yaml) file or via the command line. Use the following command to launch the program (all parameters are optional): -```shell -python nemo_inspector/nemo_inspector.py \ -++server.host= -``` -For the "Inference" page, launch the server with the model (see [inference.md](/docs/inference.md)) and specify `host` and, if necessary, `ssh_key` and `ssh_server`. - -## Inference page -This page enables the analysis of model answers based on different parameters. It offers two modes: "Chat" and "Run one sample". - -- **Chat** mode facilitates a conversation with the model and requires minimal parameter setup. -- **Run one sample** mode allows you to send a single question to the model. It can be a question from the dataset (with the parameters `input_file` or `dataset` and `split`) or a custom question. The answer is validated by comparing it with the `expected_answer` field. - -## Analyze page -To use the Analyze page, specify paths to the generations you want to use (if they were not obtained through the "Inference" page). You can pass parameters via the command line with `++inspector_params.model_prediction.generation1='/some_path/generation1/output.jsonl'` or add them in an additional config file. - -```yaml -inspector_params: - model_prediction: - generation1: /some_path/generation1/output.jsonl - generation2: /some_path/generation2/output-rs*.jsonl -``` - -The tool also supports comparison of multiple generations (e.g. `generation2` in the config above). All files satisfying the given pattern will be considered for analysis. - -On this page, you can sort, filter, and compare generations. You can also add labels to the data and save your modified, filtered, and sorted generations by specifying `save_generations_path`. - -### Filtering -You can create custom functions to filter data. There are two modes: Filter Files mode and Filter Questions mode. - -#### Filter Files mode -In this mode, the function filters each sample from the different files. It should take a dictionary whose keys are generation names and whose values are the JSON data from your generation. - -Custom filtering functions should return a Boolean value. For instance: - -```python -def custom_filtering_function(error_message: str) -> bool: - # Your code here - return result - -custom_filtering_function(data['generation1']['error_message']) # This line will be used for filtering -``` -The last line in the custom filtering function will be used for data filtering; all preceding code within the function is executed but does not directly impact the filtering process. - -To apply filters for different generations, separate expressions with '&&' symbols.
- ```python - data['generation1']['is_correct'] && not data['generation2']['is_correct'] - ``` - Do not write expressions for different generations without separators in this mode. - -#### Filter Questions mode -In this mode, the function filters each question; files will not be filtered. It should take a dictionary whose keys are generation names and whose values are lists of JSON data from your generation, one entry per file. - -In this mode you should not use the && separator. For instance, an example from the previous mode can be written like this: - ```python - data['generation1'][0]['is_correct'] and not data['generation2'][0]['is_correct'] - # Filter questions where the first file of the first generation contains a correct solution and the first file from the second generation contains a wrong solution - ``` - or like this: - ```python - data['generation1'][0]['correct_responses'] == 1 and data['generation2'][0]['correct_responses'] == 0 - # Custom Statistics are duplicated in all JSONs. So here, the 'correct_responses' value will be the same for all files of a specific generation and question - ``` - In this mode you can also compare fields of different generations: - ```python - data['generation1'][0]['is_correct'] != data['generation2'][0]['is_correct'] - ``` - These examples cannot be used in the Filter Files mode. - -### Sorting -Sorting functions operate similarly to filtering functions, with a few distinctions: - -1. Sorting functions operate on individual data entries rather than on dictionaries containing generation name keys. -2. Sorting functions cannot be applied across different generations simultaneously. - -Here is an example of a correct sorting function: - -```python -def custom_sorting_function(generation: str): - return len(generation) - -custom_sorting_function(data['generation']) -``` - -### Statistics -There are two types of statistics: "Custom Statistics" and "General Custom Statistics". Custom Statistics apply to the different samples of a single question. There are some default custom statistics: "correct_responses", "wrong_responses", and "no_responses". General Custom Statistics apply to each sample across all questions. The default general custom statistics are "dataset size", "overall number of samples", and "generations per sample". - -![stats](/nemo_inspector/images/stats.png) - -You can define your own Custom and General Custom Statistics functions. For Custom Statistics, the function should take an array of JSONs from each file. For General Custom Statistics, the function should take a list of lists of dictionaries, where the first dimension corresponds to the question index and the second dimension to the file index.
- -Here are examples of correct functions for both statistics types: - -```python -# Custom Statistic function -def unique_error_counter(datas): - unique_errors = set() - for data in datas: - unique_errors.add(data.get('error_message')) - return len(unique_errors) - -def number_of_runs(datas): - return len(datas) - -# Mapping function names to functions -{'unique_errors': unique_error_counter, "number_of_runs": number_of_runs} -``` -```python -# General Custom Statistic function -def overall_unique_error_counter(datas): - unique_errors = set() - for question_data in datas: - for file_data in question_data: - unique_errors.add(file_data.get('error_message')) - return len(unique_errors) - -# Mapping function names to functions -{'unique_errors': overall_unique_error_counter} -``` -Note that the last line in both statistic sections should be a dictionary where each key is the function's name and the corresponding value is the function itself. diff --git a/nemo_inspector/__init__.py b/nemo_inspector/__init__.py deleted file mode 100644 index d9155f923..000000000 --- a/nemo_inspector/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/nemo_inspector/callbacks/__init__.py b/nemo_inspector/callbacks/__init__.py deleted file mode 100644 index ec565ec96..000000000 --- a/nemo_inspector/callbacks/__init__.py +++ /dev/null @@ -1,169 +0,0 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import glob -import os -from dataclasses import asdict -from pathlib import Path -from typing import Dict, List - -import dash_bootstrap_components as dbc -import hydra -from dash import Dash -from flask import Flask -from omegaconf import OmegaConf -from settings.constants import ( - CODE_BEGIN, - CODE_END, - CODE_OUTPUT_BEGIN, - CODE_OUTPUT_END, - CODE_SEPARATORS, - CONFIGS_FOLDER, - RETRIEVAL, - RETRIEVAL_FIELDS, - TEMPLATES_FOLDER, - UNDEFINED, -) -from settings.inspector_config import InspectorConfig -from utils.common import initialize_default - -from nemo_skills.prompt.few_shot_examples import examples_map -from nemo_skills.prompt.utils import PromptConfig, get_prompt, load_config -from nemo_skills.utils import setup_logging - -setup_logging() -config_path = os.path.join(os.path.abspath(Path(__file__).parents[1]), "settings") - -config = {} - -generation_config_dir = Path(__file__).resolve().parents[2].joinpath("nemo_skills", "inference").resolve() -os.environ["config_dir"] = str(generation_config_dir) - - -def list_yaml_files(folder_path): - yaml_files = glob.glob(os.path.join(folder_path, '**', '*.yaml'), recursive=True) - - yaml_files_relative = [os.path.relpath(file, folder_path) for file in yaml_files] - - return yaml_files_relative - - -def get_specific_fields(dict_cfg: Dict, fields: List[Dict]) -> Dict: - retrieved_values = {} - for key, value in dict_cfg.items(): - if key in fields: - retrieved_values[key] = value - if isinstance(value, Dict): - retrieved_values = { - **retrieved_values, - **get_specific_fields(value, fields), - } - return retrieved_values - - -def update_nested_dict(dict1, dict2): - for key, value in dict2.items(): - # If the value is a dictionary, call the function recursively - if isinstance(value, dict) and key in dict1 and isinstance(dict1[key], dict): - update_nested_dict(dict1[key], value) - else: - # Otherwise, directly update the value - dict1[key] = value - - -@hydra.main(version_base=None, config_path=config_path, config_name="inspector_config") -def set_config(cfg: InspectorConfig) -> None: - global config - if not cfg.input_file and not cfg.dataset and not cfg.split: - cfg.dataset = UNDEFINED - cfg.split = UNDEFINED - - cfg.output_file = UNDEFINED - - examples_types = list(examples_map.keys()) - - if "server_type" not in cfg.server: - cfg.server = OmegaConf.create({"server_type": UNDEFINED}) - - if cfg.server.server_type != 'openai' and cfg.prompt_template is None: - cfg.prompt_template = UNDEFINED - - config['nemo_inspector'] = asdict(OmegaConf.to_object(cfg)) - - config['nemo_inspector']['types'] = { - "prompt_config": [UNDEFINED] + list_yaml_files(os.path.join(CONFIGS_FOLDER)), - "prompt_template": [UNDEFINED] + list_yaml_files(os.path.join(TEMPLATES_FOLDER)), - "examples_type": [UNDEFINED, RETRIEVAL] + examples_types, - "retrieval_field": [""], - } - conf_path = ( - config['nemo_inspector']['prompt_config'] - if os.path.isfile(str(config['nemo_inspector']['prompt_config'])) - else os.path.join(CONFIGS_FOLDER, f"{config['nemo_inspector']['prompt_config']}.yaml") - ) - template_path = ( - config['nemo_inspector']['prompt_template'] - if os.path.isfile(str(config['nemo_inspector']['prompt_template'])) - else os.path.join(TEMPLATES_FOLDER, f"{config['nemo_inspector']['prompt_template']}.yaml") - ) - - if not os.path.isfile(conf_path) and not os.path.isfile(template_path): - prompt_config = initialize_default(PromptConfig) - elif not os.path.isfile(conf_path): - prompt_config = initialize_default(PromptConfig, load_config(template_path)) - 
elif not os.path.isfile(template_path): - prompt_config = initialize_default(PromptConfig, asdict(get_prompt(config_path).config)) - else: - prompt_config = initialize_default(PromptConfig, asdict(get_prompt(config_path, template_path).config)) - - config['nemo_inspector']['prompt'] = asdict(prompt_config) - update_nested_dict(config['nemo_inspector']['prompt'], OmegaConf.to_container(cfg).get('prompt', {})) - - for separator_type, separator in CODE_SEPARATORS.items(): - if not config['nemo_inspector']['prompt']['template'][separator_type]: - config['nemo_inspector']['prompt']['template'][separator_type] = separator - - config['nemo_inspector']['inspector_params']['code_separators'] = ( - config['nemo_inspector']['prompt']['template'][CODE_BEGIN], - config['nemo_inspector']['prompt']['template'][CODE_END], - ) - config['nemo_inspector']['inspector_params']['code_output_separators'] = ( - config['nemo_inspector']['prompt']['template'][CODE_OUTPUT_BEGIN], - config['nemo_inspector']['prompt']['template'][CODE_OUTPUT_END], - ) - - config['nemo_inspector']['retrieval_fields'] = get_specific_fields(config['nemo_inspector'], RETRIEVAL_FIELDS) - - config['nemo_inspector']['input_file'] = str(config['nemo_inspector']['input_file']) - for name in ['offset', 'max_samples', 'batch_size', 'skip_filled', 'dry_run']: - config['nemo_inspector'].pop(name) - - -set_config() -server = Flask(__name__) -server.config.update(config) - -assets_path = os.path.join(os.path.dirname(__file__), 'assets') - -app = Dash( - __name__, - suppress_callback_exceptions=True, - external_stylesheets=[dbc.themes.BOOTSTRAP], - server=server, - assets_folder=assets_path, -) - -from callbacks.analyze_callbacks import choose_base_model -from callbacks.base_callback import nav_click -from callbacks.run_prompt_callbacks import preview diff --git a/nemo_inspector/callbacks/analyze_callbacks.py b/nemo_inspector/callbacks/analyze_callbacks.py deleted file mode 100644 index 2fed35799..000000000 --- a/nemo_inspector/callbacks/analyze_callbacks.py +++ /dev/null @@ -1,1228 +0,0 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import json -import logging -import os -from typing import Dict, List, Tuple - -from callbacks import app -from dash import ALL, callback_context, html, no_update -from dash.dependencies import Input, Output, State -from dash.exceptions import PreventUpdate -from layouts import ( - get_detailed_answer_column, - get_filter_answers_layout, - get_filter_text, - get_labels, - get_model_answers_table_layout, - get_models_selector_table_cell, - get_row_detailed_inner_data, - get_sorting_answers_layout, - get_stats_input, - get_table_data, - get_table_detailed_inner_data, - get_update_dataset_layout, -) -from settings.constants import ( - CHOOSE_LABEL, - CHOOSE_MODEL, - DELETE, - EDIT_ICON_PATH, - ERROR_MESSAGE_TEMPLATE, - EXTRA_FIELDS, - FILE_NAME, - FILES_FILTERING, - GENERAL_STATS, - INLINE_STATS, - LABEL, - LABEL_SELECTOR_ID, - MODEL_SELECTOR_ID, - QUESTIONS_FILTERING, - SAVE_ICON_PATH, -) -from utils.common import ( - calculate_metrics_for_whole_data, - get_available_models, - get_compared_rows, - get_custom_stats, - get_deleted_stats, - get_editable_rows, - get_excluded_row, - get_filtered_files, - get_general_custom_stats, - get_stats_raw, -) - - -@app.callback( - [ - Output("compare_models_rows", "children", allow_duplicate=True), - Output("loading_container", "children", allow_duplicate=True), - ], - Input("base_model_answers_selector", "value"), - State("loading_container", "children"), - prevent_initial_call=True, -) -def choose_base_model( - base_model: str, - loading_container: str, -) -> Tuple[List, bool]: - if base_model == CHOOSE_MODEL: - return no_update, no_update - get_excluded_row().clear() - return ( - get_model_answers_table_layout( - base_model=base_model, - ), - loading_container + " ", - ) - - -@app.callback( - [ - Output("update_dataset_modal", "is_open", allow_duplicate=True), - Output("js_container", "children", allow_duplicate=True), - Output("js_trigger", "children", allow_duplicate=True), - ], - [Input("update_dataset_button", "n_clicks"), Input("apply_update_dataset_button", "n_clicks")], - [State("update_dataset_modal", "is_open"), State("js_trigger", "children")], - prevent_initial_call=True, -) -def open_update_dataset_modal(n1: int, n2: int, is_open: bool, js_trigger: str) -> bool: - if n1 or n2: - is_open = not is_open - return is_open, "", js_trigger + " " - return is_open, "", js_trigger + " " - - -@app.callback( - [ - Output("compare_models_rows", "children", allow_duplicate=True), - Output("loading_container", "children", allow_duplicate=True), - ], - Input("apply_update_dataset_button", "n_clicks"), - [ - State("update_dataset_input", "value"), - State({"type": "model_selector", "id": ALL}, "value"), - State("base_model_answers_selector", "value"), - State("loading_container", "children"), - ], - prevent_initial_call=True, -) -def update_dataset( - n_ckicks: int, - update_function: str, - models: List[str], - base_model: str, - loading_container: str, -) -> Tuple[List[html.Tr], bool]: - if base_model == CHOOSE_MODEL or not update_function: - return no_update, no_update - return ( - get_update_dataset_layout(base_model=base_model, update_function=update_function, models=models), - loading_container + " ", - ) - - -@app.callback( - Output("save_dataset_modal", "is_open", allow_duplicate=True), - Input("save_dataset", "n_clicks"), - prevent_initial_call=True, -) -def open_save_dataset_modal(n1: int) -> bool: - ctx = callback_context - if not ctx.triggered: - return no_update - - return True - - -@app.callback( - [Output("save_dataset_modal", "is_open", 
allow_duplicate=True), Output('error_message', 'children')], - Input("save_dataset_button", "n_clicks"), - [ - State("base_model_answers_selector", "value"), - State("save_path", "value"), - ], - prevent_initial_call=True, -) -def save_dataset(n_click: int, base_model: str, save_path: str) -> Tuple[List, bool]: - if not n_click or not save_path or not base_model: - return no_update, no_update - - if not os.path.exists(save_path): - try: - os.mkdir(save_path) - except: - return True, html.Pre(f'could not save generations by path {save_path}') - - new_data = {} - - for data in get_table_data(): - for file_data in data[base_model]: - file_name = file_data[FILE_NAME] - if file_name not in new_data: - new_data[file_name] = [] - new_data[file_name].append({key: value for key, value in file_data.items() if key not in EXTRA_FIELDS}) - - for file_name, data in new_data.items(): - with open(os.path.join(save_path, file_name + '.jsonl'), 'w') as file: - file.write("\n".join([json.dumps(line) for line in data])) - - return False, '' - - -@app.callback( - [ - Output({"type": "filter", "id": ALL}, "is_open"), - Output("js_container", "children", allow_duplicate=True), - Output("js_trigger", "children", allow_duplicate=True), - ], - [ - Input({"type": "set_filter_button", "id": ALL}, "n_clicks"), - Input({"type": "apply_filter_button", "id": ALL}, "n_clicks"), - ], - [ - State({"type": "filter", "id": ALL}, "is_open"), - State("js_trigger", "children"), - ], - prevent_initial_call=True, -) -def toggle_modal_filter(n1: int, n2: int, is_open: bool, js_trigger: str) -> bool: - ctx = callback_context - if not ctx.triggered: - return [no_update] * len(is_open), no_update, no_update - button_id = json.loads(ctx.triggered[-1]['prop_id'].split('.')[0])['id'] + 1 - - if not ctx.triggered[0]['value']: - return [no_update] * len(is_open), "", js_trigger + " " - - if n1[button_id] or n2[button_id]: - is_open[button_id] = not is_open[button_id] - return is_open, "", js_trigger + " " - return is_open, "", js_trigger + " " - - -@app.callback( - [ - Output({"type": "sorting", "id": ALL}, "is_open"), - Output("js_container", "children", allow_duplicate=True), - Output("js_trigger", "children", allow_duplicate=True), - ], - [ - Input({"type": "set_sorting_button", "id": ALL}, "n_clicks"), - Input({"type": "apply_sorting_button", "id": ALL}, "n_clicks"), - ], - [ - State({"type": "sorting", "id": ALL}, "is_open"), - State("js_trigger", "children"), - ], - prevent_initial_call=True, -) -def toggle_modal_sorting(n1: int, n2: int, is_open: bool, js_trigger: str) -> bool: - ctx = callback_context - if not ctx.triggered: - return [no_update] * len(is_open), no_update, no_update - - button_id = json.loads(ctx.triggered[-1]['prop_id'].split('.')[0])['id'] + 1 - - if not ctx.triggered[0]['value']: - return [no_update] * len(is_open), no_update, no_update - - if n1[button_id] or n2[button_id]: - is_open[button_id] = not is_open[button_id] - return is_open, "", js_trigger + " " - return is_open, "", js_trigger + " " - - -@app.callback( - Output({"type": "label", "id": ALL}, "is_open"), - [ - Input({"type": "set_file_label_button", "id": ALL}, "n_clicks"), - Input({"type": "apply_label_button", "id": ALL}, "n_clicks"), - Input({"type": "delete_label_button", "id": ALL}, "n_clicks"), - ], - [State({"type": "label", "id": ALL}, "is_open")], -) -def toggle_modal_label(n1: int, n2: int, n3: int, is_open: bool) -> bool: - ctx = callback_context - if not ctx.triggered: - return [no_update] * len(is_open) - - button_id = 
json.loads(ctx.triggered[-1]['prop_id'].split('.')[0])['id'] + 1 - if not ctx.triggered[0]['value']: - return [no_update] * len(is_open) - - if n1[button_id] or n2[button_id] or n3[button_id]: - is_open[button_id] = not is_open[button_id] - return is_open - return is_open - - -@app.callback( - [ - Output("new_stats", "is_open"), - Output("stats_input_container", "children", allow_duplicate=True), - Output("js_container", "children", allow_duplicate=True), - Output("js_trigger", "children", allow_duplicate=True), - ], - [ - Input("set_new_stats_button", "n_clicks"), - Input("apply_new_stats", "n_clicks"), - ], - [ - State("new_stats", "is_open"), - State("stats_modes", "value"), - State("js_trigger", "children"), - ], - prevent_initial_call=True, -) -def toggle_modal_stats(n1: int, n2: int, is_open: bool, modes: List[str], js_trigger: str) -> bool: - if not n1 and not n2: - return no_update, no_update, no_update, no_update - - if n1 or n2: - is_open = not is_open - return is_open, get_stats_input(modes), "", js_trigger + " " - return is_open, get_stats_input(modes), "", js_trigger + " " - - -@app.callback( - Output("compare_models_rows", "children", allow_duplicate=True), - Input("apply_new_stats", "n_clicks"), - [ - State("stats_input", "value"), - State("base_model_answers_selector", "value"), - State("stats_modes", "value"), - ], - prevent_initial_call=True, -) -def apply_new_stat( - n_click: int, - code_raw: str, - base_model: str, - stats_modes: List[str], -) -> List: - if not n_click or code_raw == "": - return no_update - code_raw_lines = code_raw.strip().split('\n') - if not stats_modes or DELETE not in stats_modes: - code = '\n'.join(code_raw_lines[:-1]) + '\nnew_stats = ' + code_raw_lines[-1] - else: - code = "delete_stats = " + f"'{code_raw_lines[-1]}'" - namespace = {} - try: - exec(code, namespace) - except Exception as e: - logging.error(ERROR_MESSAGE_TEMPLATE.format(code, str(e))) - return no_update - if stats_modes and GENERAL_STATS in stats_modes: - if DELETE in stats_modes: - get_general_custom_stats().pop(namespace['delete_stats'], None) - else: - get_general_custom_stats().update(namespace['new_stats']) - get_stats_raw()[GENERAL_STATS][' '.join(namespace['new_stats'].keys())] = code_raw - else: - if stats_modes and DELETE in stats_modes: - get_custom_stats().pop(namespace['delete_stats'], None) - get_deleted_stats().update(namespace['delete_stats']) - else: - get_custom_stats().update(namespace['new_stats']) - get_stats_raw()[INLINE_STATS][' '.join(namespace['new_stats'].keys())] = code_raw - if base_model == CHOOSE_MODEL: - return [] - calculate_metrics_for_whole_data(get_table_data(), base_model) - return get_model_answers_table_layout(base_model=base_model, use_current=True) - - -@app.callback( - [ - Output("stats_input", "value", allow_duplicate=True), - Output("js_container", "children", allow_duplicate=True), - Output("js_trigger", "children", allow_duplicate=True), - ], - Input("stats_extractor", "value"), - [ - State("stats_modes", "value"), - State("js_trigger", "children"), - ], - prevent_initial_call=True, -) -def apply_new_stat(stat: str, stats_modes: List[str], js_trigger: str) -> List: - mode = GENERAL_STATS if GENERAL_STATS in stats_modes else INLINE_STATS - return get_stats_raw()[mode][stat], " ", js_trigger + " " - - -@app.callback( - [ - Output("compare_models_rows", "children", allow_duplicate=True), - Output("filtering_container", "children"), - Output("loading_container", "children", allow_duplicate=True), - ], - [ - Input({"type": 
"apply_filter_button", "id": -1}, "n_clicks"), - ], - [ - State({"type": "filter_function_input", "id": -1}, "value"), - State({"type": "apply_on_filtered_data", "id": -1}, "value"), - State({"type": "filter_mode", "id": -1}, "value"), - State({"type": "sorting_function_input", "id": -1}, "value"), - State({"type": "model_selector", "id": ALL}, "value"), - State("base_model_answers_selector", "value"), - State("filtering_container", "children"), - State("loading_container", "children"), - ], - prevent_initial_call=True, -) -def filter_data( - n_ckicks: int, - filter_function: str, - apply_on_filtered_data: int, - filter_mode: List[str], - sorting_function: str, - models: List[str], - base_model: str, - filtering_functions: str, - loading_container: str, -) -> Tuple[List[html.Tr], bool]: - if not n_ckicks: - return no_update, no_update, no_update - if apply_on_filtered_data and filtering_functions: - filtering_functions['props']['children'] += f"\n{filter_function}" - if base_model == CHOOSE_MODEL: - return [], no_update, no_update - if len(get_table_data()) == 0: # TODO fix - models = [models[0]] - get_filter_answers_layout( - base_model=base_model, - filtering_function=filter_function, - filter_mode=(FILES_FILTERING if filter_mode and len(filter_mode) else QUESTIONS_FILTERING), - apply_on_filtered_data=(apply_on_filtered_data if apply_on_filtered_data else 0), - models=models, - ) - return ( - get_sorting_answers_layout( - base_model=base_model, - sorting_function=sorting_function, - models=models, - ), - ( - html.Pre(f"Filtering function:\n{filter_function}") - if not apply_on_filtered_data or not filtering_functions - else filtering_functions - ), - loading_container + " ", - ) - - -@app.callback( - [ - Output("compare_models_rows", "children", allow_duplicate=True), - Output("sorting_container", "children"), - Output("loading_container", "children", allow_duplicate=True), - ], - Input({"type": "apply_sorting_button", "id": -1}, "n_clicks"), - [ - State({"type": "sorting_function_input", "id": -1}, "value"), - State({"type": "model_selector", "id": ALL}, "value"), - State("base_model_answers_selector", "value"), - State("loading_container", "children"), - ], - prevent_initial_call=True, -) -def sorting_data( - n_ckicks: int, - sorting_function: str, - models: List[str], - base_model: str, - loading_container: str, -) -> Tuple[List[html.Tr], bool]: - if base_model == CHOOSE_MODEL or not sorting_function: - return no_update, no_update, no_update - return ( - get_sorting_answers_layout( - base_model=base_model, - sorting_function=sorting_function, - models=models, - ), - html.Pre(f'Sorting function:\n{sorting_function}'), - loading_container + " ", - ) - - -@app.callback( - [ - Output( - "dummy_output", - 'children', - allow_duplicate=True, - ), - Output({"type": "del_row", "id": ALL}, "children"), - ], - Input({"type": "del_row", "id": ALL}, "n_clicks"), - [ - State({"type": "row_name", "id": ALL}, "children"), - State({"type": "del_row", "id": ALL}, "id"), - State({"type": "del_row", "id": ALL}, "children"), - State( - "dummy_output", - 'children', - ), - ], - prevent_initial_call=True, -) -def del_row( - n_clicks: List[int], - rows: List[str], - button_ids: List[Dict], - del_row_labels: List[str], - dummy_data: str, -) -> Tuple[str, List[str]]: - ctx = callback_context - if not ctx.triggered or not n_clicks: - return no_update, [no_update] * len(button_ids) - button_id = json.loads(ctx.triggered[0]['prop_id'].split('.')[0])['id'] - row_index = 0 - for i, current_button_id in 
enumerate(button_ids): - if current_button_id['id'] == button_id: - row_index = i - break - if not n_clicks[row_index]: - return no_update, [no_update] * len(button_ids) - if rows[row_index] in get_excluded_row(): - get_excluded_row().remove(rows[row_index]) - del_row_labels[row_index] = "-" - else: - get_excluded_row().add(rows[row_index]) - del_row_labels[row_index] = "+" - - return dummy_data + '1', del_row_labels - - -@app.callback( - [ - Output("js_container", "children", allow_duplicate=True), - Output("js_trigger", "children", allow_duplicate=True), - ], - Input({"type": "editable_row", "id": ALL, "model_name": ALL}, 'value'), - [ - State({"type": "model_selector", "id": ALL}, "value"), - State('datatable', 'selected_rows'), - State('datatable', "page_current"), - State('datatable', "page_size"), - State({"type": "editable_row", "id": ALL, "model_name": ALL}, 'id'), - State({"type": 'file_selector', "id": ALL}, 'value'), - State("js_trigger", "children"), - ], - prevent_initial_call=True, -) -def update_data_table( - new_rows_values: List[str], - models: List[str], - idx: List[int], - current_page: int, - page_size: int, - new_rows_ids: List[str], - file_names: List[str], - js_trigger: str, -) -> Tuple[str, str]: - ctx = callback_context - if not ctx.triggered or not idx: - return no_update, no_update - - file_ids = {} - question_id = current_page * page_size + idx[0] - for model_id, name in enumerate(file_names): - for file_id, file in enumerate( - get_table_data()[question_id][models[model_id]] if len(get_table_data()) else [] - ): - if file[FILE_NAME] == name: - file_ids[models[model_id]] = file_id - - for new_rows_id, new_rows_value in zip(new_rows_ids, new_rows_values): - updated_field = new_rows_id['id'] - updated_model = new_rows_id['model_name'] - get_table_data()[question_id][updated_model][file_ids[updated_model]][updated_field] = new_rows_value - - return '', js_trigger + ' ' - - -@app.callback( - Output( - "dummy_output", - 'children', - allow_duplicate=True, - ), - Input({"type": "compare_texts_button", "id": ALL}, "n_clicks"), - [ - State("dummy_output", 'children'), - State({"type": "row_name", "id": ALL}, "children"), - State({"type": "compare_texts_button", "id": ALL}, "n_clicks"), - ], - prevent_initial_call=True, -) -def compare(n_clicks: List[int], dummy_data: str, row_names: str, button_ids: List[str]): - ctx = callback_context - if not ctx.triggered or not n_clicks: - return no_update - button_id = json.loads(ctx.triggered[0]['prop_id'].split('.')[0])['id'] - if row_names[button_id] not in get_compared_rows(): - get_compared_rows().add(row_names[button_id]) - else: - get_compared_rows().remove(row_names[button_id]) - return dummy_data + '1' - - -@app.callback( - [ - Output( - "dummy_output", - 'children', - allow_duplicate=True, - ), - Output({"type": "edit_row_image", "id": ALL}, "src"), - ], - Input({"type": "edit_row_button", "id": ALL}, "n_clicks"), - [ - State({"type": "row_name", "id": ALL}, "children"), - State({"type": "edit_row_image", "id": ALL}, "id"), - State({"type": "edit_row_image", "id": ALL}, "src"), - State({"type": "model_selector", "id": ALL}, "value"), - State('datatable', 'selected_rows'), - State('datatable', "page_current"), - State('datatable', "page_size"), - State({"type": 'file_selector', "id": ALL}, 'value'), - State( - "dummy_output", - 'children', - ), - ], - prevent_initial_call=True, -) -def edit_row( - n_clicks: List[int], - rows: List[str], - button_ids: List[Dict], - edit_row_labels: List[str], - models: List[str], - idx: 
List[int], - current_page: int, - page_size: int, - file_names: List[str], - dummy_data: str, -) -> Tuple[str, List[str]]: - ctx = callback_context - if not ctx.triggered or not n_clicks or not idx: - return no_update, [no_update] * len(button_ids) - button_id = json.loads(ctx.triggered[0]['prop_id'].split('.')[0])['id'] - row_index = 0 - for i, current_button_id in enumerate(button_ids): - if current_button_id['id'] == button_id: - row_index = i - break - file_ids = [0] * len(models) - question_id = current_page * page_size + idx[0] - for model_id, name in enumerate(file_names): - for file_id, file in enumerate( - get_table_data()[question_id][models[model_id]] if len(get_table_data()) else [] - ): - if file[FILE_NAME] == name: - file_ids[model_id] = file_id - - if not n_clicks[row_index]: - return no_update, [no_update] * len(button_ids) - - if rows[row_index] in get_editable_rows(): - edit_row_labels[row_index] = EDIT_ICON_PATH - get_editable_rows().remove(rows[row_index]) - else: - get_editable_rows().add(rows[row_index]) - edit_row_labels[row_index] = SAVE_ICON_PATH - - return dummy_data + '1', edit_row_labels - - -@app.callback( - Output('datatable', 'data'), - [ - Input('datatable', "page_current"), - Input('datatable', "page_size"), - ], - State("base_model_answers_selector", "value"), -) -def change_page(page_current: int, page_size: int, base_model: str) -> List[Dict]: - if not get_table_data(): - return no_update - return [ - data[base_model][0] - for data in get_table_data()[page_current * page_size : (page_current + 1) * page_size] - if base_model in data.keys() - ] - - -@app.callback( - [ - Output( - {'type': 'detailed_models_answers', 'id': ALL}, - 'children', - allow_duplicate=True, - ), - Output( - {"type": "filter_function_input", "id": ALL}, - "value", - allow_duplicate=True, - ), - Output( - {"type": "sorting_function_input", "id": ALL}, - "value", - allow_duplicate=True, - ), - ], - [ - Input('datatable', 'selected_rows'), - Input( - "dummy_output", - 'children', - ), - ], - [ - State({"type": "model_selector", "id": ALL}, "value"), - State({"type": "sorting_function_input", "id": ALL}, "value"), - State({"type": "filter_function_input", "id": ALL}, "value"), - State({"type": "row_name", "id": ALL}, "children"), - State('datatable', "page_current"), - State('datatable', "page_size"), - State({"type": 'file_selector', "id": ALL}, 'value'), - State({"type": 'text_modes', "id": ALL}, 'value'), - ], - prevent_initial_call=True, -) -def show_item( - idx: List[int], - dummmy_trigger: str, - models: List[str], - sorting_functions: List[str], - filter_functions: List[str], - rows_names: List[str], - current_page: int, - page_size: int, - file_names: List[str], - text_modes: List[List[str]], -) -> List[str]: - if not idx: - raise PreventUpdate - ctx = callback_context - if not ctx.triggered: - return [no_update, no_update, no_update] - elif ctx.triggered[0]['prop_id'] == 'datatable.selected_rows': - filter_functions = [filter_functions[0]] + [None] * (len(filter_functions) - 1) - sorting_functions = [sorting_functions[0]] + [None] * (len(sorting_functions) - 1) - question_id = current_page * page_size + idx[0] - file_ids = [0] * len(models) - for model_id, name in enumerate(file_names): - for file_id, file in enumerate( - get_table_data()[question_id][models[model_id]] if len(get_table_data()) else [] - ): - if file[FILE_NAME] == name: - file_ids[model_id] = file_id - return [ - get_table_detailed_inner_data( - question_id=question_id, - rows_names=rows_names, - models=models, 
- files_id=file_ids, - filter_functions=filter_functions[1:], - sorting_functions=sorting_functions[1:], - text_modes=text_modes, - ), - filter_functions, - sorting_functions, - ] - - -@app.callback( - [ - Output("stats_input_container", "children", allow_duplicate=True), - Output("js_container", "children", allow_duplicate=True), - Output("js_trigger", "children", allow_duplicate=True), - ], - Input("stats_modes", "value"), - State("js_trigger", "children"), - prevent_initial_call=True, -) -def change_stats_mode(modes: List[str], js_trigger: str) -> str: - if modes is None: - return no_update, no_update, no_update - return get_stats_input(modes), "", js_trigger + " " - - -@app.callback( - [ - Output( - {"type": "filter_text", "id": -1}, - "children", - allow_duplicate=True, - ), - Output("js_container", "children", allow_duplicate=True), - Output("js_trigger", "children", allow_duplicate=True), - ], - Input({"type": "filter_mode", "id": -1}, "value"), - State("js_trigger", "children"), - prevent_initial_call=True, -) -def change_filter_mode(modes: List[str], js_trigger: str) -> str: - if modes is None: - return no_update, no_update, no_update - mode = FILES_FILTERING if modes and len(modes) else QUESTIONS_FILTERING - text = get_filter_text(mode=mode) - return ( - text, - "", - js_trigger + " ", - ) - - -@app.callback( - Output( - "dummy_output", - 'children', - allow_duplicate=True, - ), - [ - Input({"type": "apply_label_button", "id": ALL}, "n_clicks"), - Input({"type": "delete_label_button", "id": ALL}, "n_clicks"), - ], - [ - State( - {"type": "aplly_for_all_files", "id": ALL}, - "value", - ), - State({"type": "label_selector", "id": ALL}, 'value'), - State({"type": "label_selector", "id": ALL}, "id"), - State('datatable', "page_current"), - State('datatable', "page_size"), - State('datatable', 'selected_rows'), - State({"type": "model_selector", "id": ALL}, "value"), - State("base_model_answers_selector", "value"), - State({"type": 'file_selector', "id": ALL}, 'value'), - State({"type": 'file_selector', "id": ALL}, 'options'), - State( - "dummy_output", - 'children', - ), - ], - prevent_initial_call=True, -) -def change_label( - n_click_apply: List[int], - n_click_del: List[int], - apply_for_all: List[bool], - labels: List[str], - label_ids: List[int], - current_page: int, - page_size: int, - idx: List[int], - models: List[str], - base_model: str, - file_names: List[str], - file_options: List[str], - dummy_data: str, -) -> List[List[str]]: - ctx = callback_context - if not ctx.triggered: - return no_update - - button_id = label_ids.index( - json.loads(LABEL_SELECTOR_ID.format(json.loads(ctx.triggered[-1]['prop_id'].split('.')[0])['id'])) - ) - is_apply = json.loads(ctx.triggered[-1]['prop_id'].split('.')[0])['type'] == "apply_label_button" - if not ctx.triggered[0]['value'] or labels[button_id] == CHOOSE_LABEL: - return no_update - - ALL_FILES = "ALL_FILES" - if button_id == 0: - files = [ALL_FILES] - file = [ALL_FILES] - models_to_process = [(base_model, files, file)] - apply_for_all = [[True] * len(models)] - question_ids = list(range(len(get_table_data()))) - else: - if not idx: - return no_update - models_to_process = [ - ( - models[button_id - 1], - file_options[button_id - 1], - file_names[button_id - 1], - ) - ] - question_ids = [current_page * page_size + idx[0]] - - apply_for_all_files = bool(len(apply_for_all[button_id - 1])) - for question_id in question_ids: - for model, current_file_options, current_file in models_to_process: - options = ( - current_file_options - if 
button_id != 0 - else [{'value': file[FILE_NAME]} for file in get_table_data()[question_id][model]] - ) - for file in options: - if not apply_for_all_files and not file['value'] == current_file: - continue - - file_id = 0 - for i, model_file in enumerate(get_table_data()[question_id][model]): - if model_file[FILE_NAME] == file['value']: - file_id = i - break - - if labels[button_id] not in get_table_data()[question_id][model][file_id][LABEL]: - if is_apply: - get_table_data()[question_id][model][file_id][LABEL].append(labels[button_id]) - - elif not is_apply: - get_table_data()[question_id][model][file_id][LABEL].remove(labels[button_id]) - - return dummy_data + "1" - - -@app.callback( - [ - Output( - {'type': 'detailed_models_answers', 'id': ALL}, - 'children', - allow_duplicate=True, - ), - Output( - "dummy_output", - 'children', - allow_duplicate=True, - ), - ], - [ - Input({"type": 'file_selector', "id": ALL}, 'value'), - Input({"type": 'text_modes', "id": ALL}, 'value'), - ], - [ - State('datatable', 'selected_rows'), - State({"type": 'file_selector', "id": ALL}, 'options'), - State({"type": "model_selector", "id": ALL}, "value"), - State({"type": "model_selector", "id": ALL}, "id"), - State({"type": "row_name", "id": ALL}, "children"), - State('datatable', "page_current"), - State('datatable', "page_size"), - State( - {'type': 'detailed_models_answers', 'id': ALL}, - 'children', - ), - State("dummy_output", "children"), - ], - prevent_initial_call=True, -) -def change_file( - file_names: List[str], - text_modes: List[List[str]], - idx: List[int], - file_options: List[str], - models: List[str], - model_ids: List[int], - rows_names: List[str], - current_page: int, - page_size: int, - table_data: List[str], - dummy_data: str, -) -> List[str]: - if not idx: - raise PreventUpdate - - ctx = callback_context - if not ctx.triggered: - return [no_update] * len(table_data), no_update - - question_id = page_size * current_page + idx[0] - for trigger in ctx.triggered: - try: - button_id = model_ids.index( - json.loads(MODEL_SELECTOR_ID.format(json.loads(trigger['prop_id'].split('.')[0])['id'])) - ) - except ValueError: - continue - - model = models[button_id] - - def get_file_id(name_id: str): - file_id = 0 - file_name = file_names[name_id]['value'] if isinstance(file_names[name_id], Dict) else file_names[name_id] - for i, file_data in enumerate(get_table_data()[question_id][model]): - if file_data[FILE_NAME] == file_name: - file_id = i - break - return file_id - - file_id = get_file_id(button_id) - base_file_id = get_file_id(0) - - question_id = current_page * page_size + idx[0] - table_data[button_id * len(rows_names) : (button_id + 1) * len(rows_names)] = get_row_detailed_inner_data( - question_id=question_id, - model=model, - file_id=file_id, - rows_names=rows_names, - files_names=[option['value'] for option in file_options[button_id]], - col_id=button_id, - text_modes=text_modes[button_id], - compare_to=get_table_data()[question_id][models[0]][base_file_id], - ) - return table_data, dummy_data + '1' if button_id == 0 else dummy_data - - -@app.callback( - [ - Output({"type": "new_label_input", "id": ALL}, "value"), - Output({"type": "label_selector", "id": ALL}, "options"), - Output({"type": "label_selector", "id": ALL}, 'value'), - ], - Input({"type": "add_new_label_button", "id": ALL}, "n_clicks"), - [ - State({"type": "new_label_input", "id": ALL}, "value"), - State({"type": "label_selector", "id": ALL}, "options"), - State({"type": "label_selector", "id": ALL}, 'value'), - State({"type": 
"label_selector", "id": ALL}, "id"), - ], -) -def add_new_label( - n_click: int, - new_labels: List[str], - options: List[List[str]], - values: List[str], - label_ids: List[int], -) -> Tuple[List[List[str]], List[str]]: - ctx = callback_context - no_updates = [no_update] * len(new_labels) - if not ctx.triggered: - return no_updates, no_updates, no_updates - - button_id = label_ids.index( - json.loads(LABEL_SELECTOR_ID.format(json.loads(ctx.triggered[-1]['prop_id'].split('.')[0])['id'])) - ) - - if not ctx.triggered[0]['value']: - return no_updates, no_updates, no_updates - - if new_labels[button_id] and new_labels[button_id] not in options[button_id]: - for i in range(len(options)): - new_label = {'label': new_labels[button_id], 'value': new_labels[button_id]} - if new_label not in options[i]: - options[i].append({'label': new_labels[button_id], 'value': new_labels[button_id]}) - values[button_id] = new_labels[button_id] - else: - return no_updates, no_updates, no_updates - - get_labels().append(new_labels[button_id]) - new_labels[button_id] = "" - - return new_labels, options, values - - -@app.callback( - Output({"type": "chosen_label", "id": ALL}, "children"), - Input({"type": "label_selector", "id": ALL}, "value"), - [ - State({"type": "label_selector", "id": ALL}, "id"), - State({"type": "chosen_label", "id": ALL}, "children"), - ], -) -def choose_label( - label: List[str], label_ids: List[int], chosen_labels: List[str] -) -> Tuple[List[List[str]], List[str]]: - ctx = callback_context - if not ctx.triggered: - return [no_update] * len(chosen_labels) - - for trigger in ctx.triggered: - button_id = label_ids.index( - json.loads(LABEL_SELECTOR_ID.format(json.loads(trigger['prop_id'].split('.')[0])['id'])) - ) - - if not ctx.triggered[0]['value'] or label[button_id] == CHOOSE_LABEL: - chosen_labels[button_id] = "" - else: - chosen_labels[button_id] = f"chosen label: {label[button_id]}" - - return chosen_labels - - -@app.callback( - [ - Output( - "detailed_answers_header", - "children", - allow_duplicate=True, - ), - Output( - {"type": "detailed_answers_row", "id": ALL}, - "children", - allow_duplicate=True, - ), - ], - Input("add_model", "n_clicks"), - [ - State("detailed_answers_header", "children"), - State({"type": "detailed_answers_row", "id": ALL}, "children"), - State({"type": "model_selector", "id": ALL}, "id"), - State("datatable", "selected_rows"), - ], - prevent_initial_call=True, -) -def add_model( - n_clicks: int, - header: List, - rows: List, - selectors_ids: List[int], - idx: List[int], -) -> Tuple[List, List]: - if not n_clicks: - return no_update, [no_update] * len(rows) - available_models = list(get_available_models().keys()) - last_header_id = selectors_ids[-1]['id'] if selectors_ids != [] else -1 - header.append(get_models_selector_table_cell(available_models, available_models[0], last_header_id + 1, True)) - last_cell_id = rows[-1][-1]["props"]["children"]["props"]['id']['id'] - for i, row in enumerate(rows): - row.append( - get_detailed_answer_column( - last_cell_id + i + 1, - file_id=last_header_id + 1 if i == 0 and idx else None, - ) - ) - - return header, rows - - -@app.callback( - [ - Output("detailed_answers_header", "children"), - Output({"type": "detailed_answers_row", "id": ALL}, "children"), - ], - Input({"type": "del_model", "id": ALL}, "n_clicks"), - [ - State("detailed_answers_header", "children"), - State({"type": "detailed_answers_row", "id": ALL}, "children"), - State({"type": "del_model", "id": ALL}, "id"), - ], - prevent_initial_call=True, -) -def 
del_model( - n_clicks: List[int], - header: List, - rows: List, - id_del: List[int], -) -> Tuple[List, List]: - ctx = callback_context - if not ctx.triggered: - return no_update, [no_update] * len(rows) - - button_id = json.loads(ctx.triggered[0]['prop_id'].split('.')[0])['id'] - - if not ctx.triggered[0]['value']: - return no_update, [no_update] * len(rows) - - for i, id in enumerate(id_del): - if id['id'] == button_id: - index = i + 2 - - header.pop(index) - for i, row in enumerate(rows): - row.pop(index) - - return header, rows - - -@app.callback( - [ - Output({"type": 'file_selector', "id": ALL}, 'options'), - Output({"type": 'file_selector', "id": ALL}, 'value'), - ], - [ - Input({"type": "apply_filter_button", "id": ALL}, "n_clicks"), - Input({"type": "apply_sorting_button", "id": ALL}, "n_clicks"), - Input({"type": "model_selector", "id": ALL}, "value"), - ], - [ - State({"type": "model_selector", "id": ALL}, "id"), - State({"type": "sorting_function_input", "id": ALL}, "value"), - State({"type": "filter_function_input", "id": ALL}, "value"), - State({"type": "apply_on_filtered_data", "id": ALL}, "value"), - State('datatable', "page_current"), - State('datatable', "page_size"), - State('datatable', 'selected_rows'), - State({"type": 'file_selector', "id": ALL}, 'options'), - State({"type": 'file_selector', "id": ALL}, 'value'), - ], - prevent_initial_call=True, -) -def change_files_order( - filter_n_click: int, - sorting_n_click: int, - models: List[str], - model_ids: List[int], - sorting_functions: List[str], - filter_functions: List[str], - apply_on_filtered_data: List[int], - current_page: int, - page_size: int, - idx: List[int], - file_selector_options: List[str], - file_selector_values: List[str], -) -> Tuple[List[List[str]], List[str]]: - no_updates = [no_update] * len(file_selector_options) - if not filter_n_click and not sorting_n_click: - return no_updates, no_updates - if not idx: - raise PreventUpdate - ctx = callback_context - if not ctx.triggered: - return no_updates, no_updates - try: - button_id = model_ids.index( - json.loads(MODEL_SELECTOR_ID.format(json.loads(ctx.triggered[-1]['prop_id'].split('.')[0])['id'])) - ) - except ValueError: - return no_updates, no_updates - - if not ctx.triggered[0]['value'] or button_id == -1: - return no_updates, no_updates - model = models[button_id] - question_id = current_page * page_size + idx[0] - array_to_filter = ( - get_table_data()[question_id][model] - if not apply_on_filtered_data or not apply_on_filtered_data[button_id] - else list( - filter( - lambda data: data[FILE_NAME] in [file_name['label'] for file_name in file_selector_options], - get_table_data()[question_id][model], - ) - ) - ) - file_selector_options[button_id] = [ - {'label': data[FILE_NAME], 'value': data[FILE_NAME]} - for data in get_filtered_files( - filter_functions[button_id + 1], - sorting_functions[button_id + 1], - array_to_filter, - ) - ] - file_selector_values[button_id] = file_selector_options[button_id][0] - - return file_selector_options, file_selector_values diff --git a/nemo_inspector/callbacks/assets/ansi_styles.css b/nemo_inspector/callbacks/assets/ansi_styles.css deleted file mode 100644 index 84c5e700d..000000000 --- a/nemo_inspector/callbacks/assets/ansi_styles.css +++ /dev/null @@ -1,44 +0,0 @@ -/* Foreground Colors */ -.ansi30 { color: black; } /* Black */ -.ansi31 { color: red ; } /* Red */ -.ansi32 { color: green; } /* Green */ -.ansi33 { color: yellow; } /* Yellow */ -.ansi34 { color: blue; } /* Blue */ -.ansi35 { color: magenta; } /* 
Magenta */ -.ansi36 { color: cyan; } /* Cyan */ -.ansi37 { color: white; } /* White */ -.ansi90 { color: grey; } /* Bright Black (grey) */ -.ansi91 { color: #FFCCCB; } /* Bright Red */ -.ansi92 { color: lightgreen; } /* Bright Green */ -.ansi93 { color: lightyellow; } /* Bright Yellow */ -.ansi94 { color: lightblue; } /* Bright Blue */ -.ansi95 { color: #ff80ff;} /* Bright Magenta */ -.ansi96 { color: lightcyan; } /* Bright Cyan */ -.ansi97 { color: #FFFFF7; } /* Bright White */ - -/* Background Colors */ -.ansi40 { background-color: black; } /* Black */ -.ansi41 { background-color: red; } /* Red */ -.ansi42 { background-color: green; } /* Green */ -.ansi43 { background-color: yellow; } /* Yellow */ -.ansi44 { background-color: blue; } /* Blue */ -.ansi45 { background-color: magenta; } /* Magenta */ -.ansi46 { background-color: cyan; } /* Cyan */ -.ansi47 { background-color: white; } /* White */ -.ansi100 { background-color: grey; } /* Bright Black (grey) */ -.ansi101 { background-color: #FFCCCB; } /* Bright Red */ -.ansi102 { background-color: lightgreen; } /* Bright Green */ -.ansi103 { background-color: lightyellow;} /* Bright Yellow */ -.ansi104 { background-color: lightblue; } /* Bright Blue */ -.ansi105 { background-color: #ff80ff;}/* Bright Magenta */ -.ansi106 { background-color: lightcyan; } /* Bright Cyan */ -.ansi107 { background-color: #FFFFF7; } /* Bright White */ - -/* Styles */ -.ansi1 { font-weight: bold; } /* Bold */ -.ansi3 { font-style: italic; } /* Italic */ -.ansi4 { text-decoration: underline; } /* Underline */ -.ansi9 { text-decoration: line-through; } /* Strikethrough */ -.ansi24 { text-decoration: none; } /* No underline */ -.ansi39 { color: initial; } /* Default foreground color */ -.ansi49 { background-color: initial; } /* Default background color */ diff --git a/nemo_inspector/callbacks/assets/change_element_height.js b/nemo_inspector/callbacks/assets/change_element_height.js deleted file mode 100644 index b789a4f2a..000000000 --- a/nemo_inspector/callbacks/assets/change_element_height.js +++ /dev/null @@ -1,22 +0,0 @@ -// Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -window.addEventListener('message', function(event) { - if (event.data && event.data.frameHeight && event.data.frameId) { - var iframe = document.getElementById(event.data.frameId); - if (iframe) { - iframe.style.height = event.data.frameHeight + 'px'; - } - } -}, false); diff --git a/nemo_inspector/callbacks/assets/images/compare_icon.png b/nemo_inspector/callbacks/assets/images/compare_icon.png deleted file mode 100644 index 582c048f9..000000000 Binary files a/nemo_inspector/callbacks/assets/images/compare_icon.png and /dev/null differ diff --git a/nemo_inspector/callbacks/assets/images/edit_icon.png b/nemo_inspector/callbacks/assets/images/edit_icon.png deleted file mode 100644 index 3b3a25e37..000000000 Binary files a/nemo_inspector/callbacks/assets/images/edit_icon.png and /dev/null differ diff --git a/nemo_inspector/callbacks/assets/images/save_icon.png b/nemo_inspector/callbacks/assets/images/save_icon.png deleted file mode 100644 index c81377782..000000000 Binary files a/nemo_inspector/callbacks/assets/images/save_icon.png and /dev/null differ diff --git a/nemo_inspector/callbacks/assets/register_textarea.js b/nemo_inspector/callbacks/assets/register_textarea.js deleted file mode 100644 index c124e8b24..000000000 --- a/nemo_inspector/callbacks/assets/register_textarea.js +++ /dev/null @@ -1,31 +0,0 @@ -// Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -function registerTextarea() { - var textareas = document.querySelectorAll("textarea"); - textareas.forEach(function(textarea) { - function updateHeight() { - textarea.style.height = 0 + 'px'; - - var height = Math.max(textarea.scrollHeight, textarea.offsetHeight, - textarea.clientHeight); - - textarea.style.height = height + 'px'; - }; - textarea.onload = updateHeight; - textarea.onresize = updateHeight; - textarea.addEventListener('input', updateHeight); - updateHeight() - }); -}; diff --git a/nemo_inspector/callbacks/assets/styles.css b/nemo_inspector/callbacks/assets/styles.css deleted file mode 100644 index 452ea05c8..000000000 --- a/nemo_inspector/callbacks/assets/styles.css +++ /dev/null @@ -1,9 +0,0 @@ -.button-class { - line-height: 20px; - font-size: 14px; - height: 40px; - text-overflow: ellipsis; - white-space: nowrap; - overflow: hidden; - margin-left: 2px; -} \ No newline at end of file diff --git a/nemo_inspector/callbacks/base_callback.py b/nemo_inspector/callbacks/base_callback.py deleted file mode 100644 index 0507653c7..000000000 --- a/nemo_inspector/callbacks/base_callback.py +++ /dev/null @@ -1,61 +0,0 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from datetime import datetime -from typing import Tuple - -from callbacks import app -from dash import html -from dash.dependencies import Input, Output -from flask import current_app -from layouts import get_compare_test_layout, get_run_test_layout -from settings.constants import CODE_BEGIN, CODE_END, CODE_OUTPUT_BEGIN, CODE_OUTPUT_END -from utils.common import get_available_models, get_data_from_files, get_height_adjustment - - -@app.callback( - [ - Output("page_content", "children"), - Output("run_mode_link", "active"), - Output("analyze_link", "active"), - ], - Input("url", "pathname"), -) -def nav_click(url: str) -> Tuple[html.Div, bool, bool]: - if url == "/": - return get_run_test_layout(), True, False - elif url == "/analyze": - config = current_app.config['nemo_inspector'] - config['inspector_params']['code_separators'] = ( - config['prompt']['template'][CODE_BEGIN], - config['prompt']['template'][CODE_END], - ) - config['inspector_params']['code_output_separators'] = ( - config['prompt']['template'][CODE_OUTPUT_BEGIN], - config['prompt']['template'][CODE_OUTPUT_END], - ) - get_data_from_files(datetime.now()) - get_available_models(datetime.now()) - return get_compare_test_layout(), False, True - - -@app.callback( - Output("js_container", "children", allow_duplicate=True), - [ - Input("page_content", "children"), - Input("js_trigger", "children"), - ], - prevent_initial_call=True, -) -def adjust_text_area_height(content: html.Div, trigger: str) -> html.Iframe: - return get_height_adjustment() diff --git a/nemo_inspector/callbacks/run_prompt_callbacks.py b/nemo_inspector/callbacks/run_prompt_callbacks.py deleted file mode 100644 index 3f81e7503..000000000 --- a/nemo_inspector/callbacks/run_prompt_callbacks.py +++ /dev/null @@ -1,487 +0,0 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import json -import os -from dataclasses import asdict -from typing import Dict, List, Tuple, Union - -import dash_bootstrap_components as dbc -from callbacks import app -from dash import ALL, html, no_update -from dash._callback import NoUpdate -from dash.dependencies import Input, Output, State -from flask import current_app -from layouts import ( - get_few_shots_by_id_layout, - get_query_params_layout, - get_results_content_layout, - get_single_prompt_output_layout, - get_utils_field_representation, -) -from settings.constants import ( - CODE_BEGIN, - CODE_END, - CODE_OUTPUT_BEGIN, - CODE_OUTPUT_END, - CONFIGS_FOLDER, - FEW_SHOTS_INPUT, - QUERY_INPUT_TYPE, - RETRIEVAL, - RETRIEVAL_FIELDS, - SEPARATOR_DISPLAY, - SEPARATOR_ID, - TEMPLATES_FOLDER, - UNDEFINED, -) -from utils.common import ( - extract_query_params, - get_test_data, - get_utils_dict, - get_utils_from_config, - get_values_from_input_group, - initialize_default, -) -from utils.strategies.strategy_maker import RunPromptStrategyMaker - -from nemo_skills.prompt.few_shot_examples import examples_map -from nemo_skills.prompt.utils import PromptConfig, get_prompt, load_config - - -@app.callback( - [ - Output("js_container", "children", allow_duplicate=True), - Output("js_trigger", "children", allow_duplicate=True), - ], - Input("prompt_params_input", "active_item"), - State("js_trigger", "children"), - prevent_initial_call=True, -) -def trigger_js(active_item: str, js_trigger: str) -> Tuple[str, str]: - return "", js_trigger + " " - - -@app.callback( - [ - Output("utils_group", "children", allow_duplicate=True), - Output("few_shots_div", "children"), - Output("js_container", "children", allow_duplicate=True), - Output("js_trigger", "children", allow_duplicate=True), - ], - [ - Input("examples_type", "value"), - Input('input_file', 'value'), - Input({"type": RETRIEVAL, "id": ALL}, "value"), - ], - [ - State("js_trigger", "children"), - State('utils_group', 'children'), - ], - prevent_initial_call=True, -) -def update_examples_type( - examples_type: str, - input_file: str, - retrieval_fields: List, - js_trigger: str, - raw_utils: List[Dict], -) -> Union[NoUpdate, dbc.AccordionItem]: - if not examples_type: - examples_type = "" - input_file_index = 0 - retrieval_field_index = -1 - - for retrieval_index, util in enumerate(raw_utils): - name = util['props']['children'][0]['props']['children'] - if name == 'input_file': - input_file_index = retrieval_index - if name == 'retrieval_field': - retrieval_field_index = retrieval_index - - if examples_type == RETRIEVAL: - utils = {key.split(SEPARATOR_ID)[-1]: value for key, value in get_values_from_input_group(raw_utils).items()} - utils.pop('examples_type', None) - - if ( - 'retrieval_file' in utils - and utils['retrieval_file'] - and os.path.isfile(utils['retrieval_file']) - and os.path.isfile(input_file) - ): - with open(utils['retrieval_file'], 'r') as retrieval_file, open(input_file, 'r') as input_file: - types = current_app.config['nemo_inspector']['types'] - sample = { - key: value - for key, value in json.loads(retrieval_file.readline()).items() - if key in json.loads(input_file.readline()) - } - types['retrieval_field'] = list(filter(lambda key: isinstance(sample[key], str), sample.keys())) - if retrieval_field_index != -1: - retrieval_field = raw_utils[retrieval_field_index]['props']['children'][1]['props'] - retrieval_field_value = raw_utils[retrieval_field_index]['props']['children'][1]['props']['value'] - retrieval_field['options'] = types['retrieval_field'] - if 
retrieval_field_value in types['retrieval_field']: - retrieval_field['value'] = retrieval_field_value - else: - retrieval_field['value'] = types['retrieval_field'][0] - utils["retrieval_field"] = retrieval_field['value'] - - if raw_utils[input_file_index + 1]['props']['children'][0]['props']['children'] not in RETRIEVAL_FIELDS: - for retrieval_field in RETRIEVAL_FIELDS: - raw_utils.insert( - input_file_index + 1, - get_utils_dict( - retrieval_field, - current_app.config['nemo_inspector']['retrieval_fields'][retrieval_field], - {"type": RETRIEVAL, "id": retrieval_field}, - ), - ) - - else: - while ( - input_file_index + 1 < len(raw_utils) - and raw_utils[input_file_index + 1]['props']['children'][0]['props']['children'] in RETRIEVAL_FIELDS - ): - raw_utils.pop(input_file_index + 1) - - size = len(examples_map.get(examples_type, [])) - return ( - raw_utils, - RunPromptStrategyMaker().get_strategy().get_few_shots_div_layout(size), - "", - js_trigger + " ", - ) - - -@app.callback( - [ - Output("few_shots_pagination_content", "children"), - Output("js_container", "children", allow_duplicate=True), - Output("js_trigger", "children", allow_duplicate=True), - ], - [ - Input("few_shots_pagination", "active_page"), - Input( - { - "type": "text_modes", - "id": FEW_SHOTS_INPUT, - }, - "value", - ), - Input("dummy_output", "children"), - ], - State('examples_type', "value"), - State("js_trigger", "children"), - prevent_initial_call=True, -) -def change_examples_page( - page: int, - text_modes: List[str], - dummy_output: str, - examples_type: str, - js_trigger: str, -) -> Tuple[Tuple[html.Div], int]: - if not examples_type: - examples_type = "" - return ( - get_few_shots_by_id_layout(page, examples_type, text_modes), - '', - js_trigger + '', - ) - - -@app.callback( - [ - Output( - SEPARATOR_ID.join(field.split(SEPARATOR_DISPLAY)), - "value", - allow_duplicate=True, - ) - for field in get_utils_from_config({"prompt": asdict(initialize_default(PromptConfig))}).keys() - ] - + [ - Output("js_container", "children", allow_duplicate=True), - Output("js_trigger", "children", allow_duplicate=True), - ], - [ - Input("prompt_config", "value"), - Input("prompt_template", "value"), - ], - State("js_trigger", "children"), - prevent_initial_call=True, -) -def update_prompt_type( - prompt_config: str, prompt_template: str, js_trigger: str -) -> Union[NoUpdate, dbc.AccordionItem]: - config_path = os.path.join(CONFIGS_FOLDER, prompt_config) - template_path = os.path.join(TEMPLATES_FOLDER, prompt_template) - if ( - "used_prompt" in current_app.config['nemo_inspector']['prompt'] - and (config_path, template_path) == current_app.config['nemo_inspector']['prompt']['used_prompt'] - ): - output_len = len(get_utils_from_config(asdict(initialize_default(PromptConfig))).keys()) - return [no_update] * (output_len + 2) - - current_app.config['nemo_inspector']['prompt']['used_prompt'] = (config_path, template_path) - if not os.path.isfile(config_path) and not os.path.isfile(template_path): - output_len = len(get_utils_from_config(asdict(initialize_default(PromptConfig))).keys()) - return [no_update] * (output_len + 2) - elif not os.path.isfile(config_path): - prompt_config = initialize_default(PromptConfig, load_config(template_path)) - elif not os.path.isfile(template_path): - prompt_config = initialize_default(PromptConfig, asdict(get_prompt(config_path).config)) - else: - prompt_config = initialize_default(PromptConfig, asdict(get_prompt(config_path, template_path).config)) - - 
current_app.config['nemo_inspector']['prompt']['stop_phrases'] = prompt_config.template.stop_phrases - - return [ - get_utils_field_representation(value, key) - for key, value in get_utils_from_config(asdict(prompt_config)).items() - ] + ['', js_trigger + " "] - - -@app.callback( - Output("dummy_output", "children", allow_duplicate=True), - [ - Input(CODE_BEGIN, "value"), - Input(CODE_END, "value"), - Input(CODE_OUTPUT_BEGIN, "value"), - Input(CODE_OUTPUT_END, "value"), - ], - State( - "dummy_output", - 'children', - ), - prevent_initial_call=True, -) -def update_code_separators( - code_begin: str, code_end: str, code_output_begin: str, code_output_end: str, dummy_data: str -) -> str: - current_app.config['nemo_inspector']['inspector_params']['code_separators'] = (code_begin, code_end) - current_app.config['nemo_inspector']['inspector_params']['code_output_separators'] = ( - code_output_begin, - code_output_end, - ) - - return dummy_data + "1" - - -@app.callback( - [ - Output("results_content", "children"), - Output("loading_container", "children", allow_duplicate=True), - ], - Input("run_button", "n_clicks"), - [ - State("utils_group", "children"), - State("run_mode_options", "value"), - State({"type": QUERY_INPUT_TYPE, "id": ALL}, "value"), - State({"type": QUERY_INPUT_TYPE, "id": ALL}, "id"), - State({"type": "query_store", "id": ALL}, "data"), - State("loading_container", "children"), - ], - prevent_initial_call=True, -) -def get_run_test_results( - n_clicks: int, - utils: List[Dict], - run_mode: str, - query_params: List[str], - query_params_ids: List[Dict], - query_store: List[Dict[str, str]], - loading_container: str, -) -> Union[Tuple[html.Div, str], Tuple[NoUpdate, NoUpdate]]: - if n_clicks is None: - return no_update, no_update - - utils = get_values_from_input_group(utils) - if "examples_type" in utils and utils["examples_type"] is None: - utils["examples_type"] = "" - - if None not in query_params: - query_store = [extract_query_params(query_params_ids, query_params)] - - return ( - RunPromptStrategyMaker(run_mode) - .get_strategy() - .run( - utils, - query_store[0], - ), - loading_container + " ", - ) - - -@app.callback( - [ - Output("prompt_params_input", "children", allow_duplicate=True), - Output("js_container", "children", allow_duplicate=True), - Output("js_trigger", "children", allow_duplicate=True), - Output("results_content", "children", allow_duplicate=True), - ], - Input("run_mode_options", "value"), - [ - State("utils_group", "children"), - State("js_trigger", "children"), - ], - prevent_initial_call=True, -) -def change_mode(run_mode: str, utils: List[Dict], js_trigger: str) -> Tuple[List[dbc.AccordionItem], None]: - utils = get_values_from_input_group(utils) - return ( - get_query_params_layout(run_mode, utils.get('input_file', UNDEFINED)), - "", - js_trigger + ' ', - None, - ) - - -@app.callback( - [ - Output("query_input_children", "children", allow_duplicate=True), - Output({"type": "query_store", "id": ALL}, "data", allow_duplicate=True), - Output("js_container", "children", allow_duplicate=True), - Output("js_trigger", "children", allow_duplicate=True), - ], - [ - Input("query_search_button", "n_clicks"), - Input("input_file", "value"), - Input("run_mode_options", "value"), - ], - [ - State("query_search_input", "value"), - State( - { - "type": "text_modes", - "id": QUERY_INPUT_TYPE, - }, - "value", - ), - State("js_trigger", "children"), - ], - prevent_initial_call=True, -) -def prompt_search( - n_clicks: int, - input_file: str, - run_mode: str, - index: 
int, - text_modes: List[str], - js_trigger: str, -) -> Tuple[Union[List[str], NoUpdate]]: - query_data = get_test_data(index, input_file)[0] - return ( - RunPromptStrategyMaker() - .get_strategy() - .get_query_input_children_layout( - query_data, - text_modes, - ), - [query_data], - "", - js_trigger + " ", - ) - - -@app.callback( - [ - Output("query_input_children", "children", allow_duplicate=True), - Output({"type": "query_store", "id": ALL}, "data", allow_duplicate=True), - Output("js_container", "children", allow_duplicate=True), - Output("js_trigger", "children", allow_duplicate=True), - ], - [ - Input( - { - "type": "text_modes", - "id": QUERY_INPUT_TYPE, - }, - "value", - ), - ], - [ - State({"type": "query_store", "id": ALL}, "data"), - State({"type": QUERY_INPUT_TYPE, "id": ALL}, "value"), - State({"type": QUERY_INPUT_TYPE, "id": ALL}, "id"), - State("js_trigger", "children"), - ], - prevent_initial_call=True, -) -def change_prompt_search_mode( - text_modes: List[str], - query_store: List[Dict[str, str]], - query_params: List[str], - query_params_ids: List[int], - js_trigger: str, -) -> Tuple[Union[List[str], NoUpdate]]: - if None not in query_params: - query_store = [extract_query_params(query_params_ids, query_params)] - - return ( - RunPromptStrategyMaker() - .get_strategy() - .get_query_input_children_layout( - query_store[0], - text_modes, - ), - query_store, - "", - js_trigger + " ", - ) - - -@app.callback( - Output("results_content", "children", allow_duplicate=True), - Input("preview_button", "n_clicks"), - [ - State("run_mode_options", "value"), - State("utils_group", "children"), - State({"type": QUERY_INPUT_TYPE, "id": ALL}, "value"), - State({"type": QUERY_INPUT_TYPE, "id": ALL}, "id"), - State({"type": "query_store", "id": ALL}, "data"), - ], - prevent_initial_call=True, -) -def preview( - n_clicks: int, - run_mode: str, - utils: List[Dict], - query_params: List[str], - query_params_ids: List[int], - query_store: List[Dict[str, str]], -) -> html.Pre: - if None not in query_params: - query_store = [extract_query_params(query_params_ids, query_params)] - - utils = get_values_from_input_group(utils) - - prompt = RunPromptStrategyMaker(run_mode).get_strategy().get_prompt(utils, query_store[0]) - return get_results_content_layout(str(prompt)) - - -@app.callback( - Output("results_content_text", "children", allow_duplicate=True), - Input( - { - "type": "text_modes", - "id": "results_content", - }, - "value", - ), - State("text_store", "data"), - prevent_initial_call=True, -) -def change_results_content_mode(text_modes: List[str], text: str) -> html.Pre: - return get_single_prompt_output_layout(text, text_modes) if text_modes and len(text_modes) else text diff --git a/nemo_inspector/images/analyze_page.png b/nemo_inspector/images/analyze_page.png deleted file mode 100644 index fc3e18fe3..000000000 Binary files a/nemo_inspector/images/analyze_page.png and /dev/null differ diff --git a/nemo_inspector/images/demo.png b/nemo_inspector/images/demo.png deleted file mode 100644 index 82cd0f458..000000000 Binary files a/nemo_inspector/images/demo.png and /dev/null differ diff --git a/nemo_inspector/images/inference_page.png b/nemo_inspector/images/inference_page.png deleted file mode 100644 index 351628a72..000000000 Binary files a/nemo_inspector/images/inference_page.png and /dev/null differ diff --git a/nemo_inspector/images/stats.png b/nemo_inspector/images/stats.png deleted file mode 100644 index da9c1fea0..000000000 Binary files a/nemo_inspector/images/stats.png and /dev/null 
differ diff --git a/nemo_inspector/layouts/__init__.py b/nemo_inspector/layouts/__init__.py deleted file mode 100644 index e419ede71..000000000 --- a/nemo_inspector/layouts/__init__.py +++ /dev/null @@ -1,56 +0,0 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from layouts.analyze_page_layouts import ( - get_compare_test_layout, - get_few_shots_layout, - get_models_options_layout, - get_stats_input, - get_stats_text, - get_utils_layout, -) -from layouts.base_layouts import ( - get_input_group_layout, - get_main_page_layout, - get_results_content_layout, - get_selector_layout, - get_single_prompt_output_layout, - get_switch_layout, - get_text_area_layout, - get_text_modes_layout, - get_utils_field_representation, -) -from layouts.run_prompt_page_layouts import ( - get_few_shots_by_id_layout, - get_query_params_layout, - get_run_mode_layout, - get_run_test_layout, -) -from layouts.table_layouts import ( - get_detailed_answer_column, - get_filter_answers_layout, - get_filter_layout, - get_filter_text, - get_labels, - get_model_answers_table_layout, - get_models_selector_table_cell, - get_row_detailed_inner_data, - get_single_prompt_output_layout, - get_sorting_answers_layout, - get_sorting_layout, - get_stats_layout, - get_table_data, - get_table_detailed_inner_data, - get_update_dataset_layout, -) diff --git a/nemo_inspector/layouts/analyze_page_layouts.py b/nemo_inspector/layouts/analyze_page_layouts.py deleted file mode 100644 index 39dee421a..000000000 --- a/nemo_inspector/layouts/analyze_page_layouts.py +++ /dev/null @@ -1,361 +0,0 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import os -from typing import Dict, List - -import dash_bootstrap_components as dbc -from dash import dcc, html -from flask import current_app -from layouts.base_layouts import get_code_text_area_layout, get_selector_layout, get_switch_layout -from layouts.table_layouts import get_change_label_layout, get_filter_layout, get_sorting_layout -from settings.constants import CHOOSE_MODEL, CUSTOM, DELETE, GENERAL_STATS, INLINE_STATS -from utils.common import get_available_models, get_custom_stats, get_general_custom_stats, get_stats_raw - - -def get_models_options_layout() -> dbc.Accordion: - runs_storage = get_available_models() - items = [ - dbc.AccordionItem( - dbc.Accordion( - [ - get_utils_layout(values["utils"]), - get_few_shots_layout(values["examples"]), - ], - start_collapsed=True, - always_open=True, - ), - title=model, - ) - for model, values in runs_storage.items() - ] - models_options = dbc.Accordion( - items, - start_collapsed=True, - always_open=True, - ) - return dbc.Accordion( - dbc.AccordionItem( - models_options, - title="Generation parameters", - ), - start_collapsed=True, - always_open=True, - ) - - -def get_utils_layout(utils: Dict) -> dbc.AccordionItem: - input_groups = [ - dbc.InputGroup( - [ - html.Pre(f"{name}: ", className="mr-2"), - html.Pre( - (value if value == "" or str(value).strip() != "" else repr(value)[1:-1]), - className="mr-2", - style={"overflow-x": "scroll"}, - ), - ], - className="mb-3", - ) - for name, value in utils.items() - ] - return dbc.AccordionItem( - html.Div(input_groups), - title="Utils", - ) - - -def get_few_shots_layout(examples: List[Dict]) -> dbc.AccordionItem: - example_layout = lambda example: [ - html.Div( - [ - dcc.Markdown(f'**{name}**'), - html.Pre(value), - ] - ) - for name, value in example.items() - ] - examples_layout = [ - dbc.Accordion( - dbc.AccordionItem( - example_layout(example), - title=f"example {id}", - ), - start_collapsed=True, - always_open=True, - ) - for id, example in enumerate(examples) - ] - return dbc.AccordionItem( - html.Div(examples_layout), - title="Few shots", - ) - - -def get_update_dataset_modal_layout() -> html.Div: - text = ( - "Write an expression to modify the data\n\n" - "For example: {**data, 'generation': data['generation'].strip()}\n\n" - "The function has to return a new dict" - ) - header = dbc.ModalHeader( - dbc.ModalTitle("Update Dataset"), - close_button=True, - ) - body = dbc.ModalBody( - html.Div( - [ - html.Pre(text), - get_code_text_area_layout( - id="update_dataset_input", - ), - ], - ) - ) - footer = dbc.ModalFooter( - dbc.Button( - "Apply", - id="apply_update_dataset_button", - className="ms-auto", - n_clicks=0, - ) - ) - return html.Div( - [ - dbc.Button( - "Update dataset", - id="update_dataset_button", - class_name='button-class', - ), - dbc.Modal( - [ - header, - body, - footer, - ], - size="lg", - id="update_dataset_modal", - centered=True, - is_open=False, - ), - ], - style={'display': 'inline-block'}, - ) - - -def get_save_dataset_layout() -> html.Div: - return html.Div( - [ - dbc.Button("Save dataset", id="save_dataset", class_name='button-class'), - dbc.Modal( - [ - dbc.ModalBody( - [ - dbc.InputGroup( - [ - dbc.InputGroupText('save_path'), - dbc.Input( - value=os.path.join( - current_app.config['nemo_inspector']['inspector_params'][ - 'save_generations_path' - ], - 'default_name', - ), - id='save_path', - type='text', - ), - ], - className="mb-3", - ), - dbc.Container(id="error_message"), - ] - ), - dbc.ModalFooter( - dbc.Button( - "Save", - id="save_dataset_button", - 
className="ms-auto", - n_clicks=0, - ) - ), - ], - id="save_dataset_modal", - is_open=False, - style={ - 'text-align': 'center', - "margin-top": "10px", - "margin-bottom": "10px", - }, - ), - ], - ) - - -def get_compare_test_layout() -> html.Div: - return html.Div( - [ - get_models_options_layout(), - dbc.InputGroup( - [ - get_sorting_layout(), - get_filter_layout(), - get_add_stats_layout(), - get_change_label_layout(apply_for_all_files=False), - get_update_dataset_modal_layout(), - get_save_dataset_layout(), - dbc.Button( - "+", - id="add_model", - outline=True, - color="primary", - className="me-1", - class_name='button-class', - style={'margin-left': '1px'}, - ), - get_selector_layout( - get_available_models().keys(), - 'base_model_answers_selector', - value=CHOOSE_MODEL, - ), - ] - ), - html.Pre(id="filtering_container"), - html.Pre(id="sorting_container"), - dcc.Loading( - children=dbc.Container(id="loading_container", style={'display': 'none'}, children=""), - type='circle', - style={'margin-top': '50px'}, - ), - html.Div( - children=[], - id="compare_models_rows", - ), - ], - ) - - -def get_stats_text(general_stats: bool = False, delete: bool = False): - if delete: - return "Choose the name of the statistic you want to delete" - else: - if general_stats: - return ( - "Creating General Custom Statistics:\n\n" - "To introduce new general custom statistics:\n" - "1. Create a dictionary where keys are the names of your custom stats.\n" - "2. Assign functions as values. These functions should accept arrays where first dimension\n" - "is a question index and second is a file number (both sorted and filtered).\n\n" - "Example:\n\n" - "Define a custom function to integrate into your stats:\n\n" - "def my_func(datas):\n" - " correct_responses = 0\n" - " for question_data in datas:\n" - " for file_data in question_data:\n" - " correct_responses += file_data['is_correct']\n" - " return correct_responses\n" - "{'correct_responses': my_func}" - ) - else: - return ( - "Creating Custom Statistics:\n\n" - "To introduce new custom statistics:\n" - "1. Create a dictionary where keys are the names of your custom stats.\n" - "2. Assign functions as values. 
These functions should accept arrays containing data\n" - "from all relevant files.\n\n" - "Note: Do not use names that already exist in the current stats or JSON fields\n" - "to avoid conflicts.\n\n" - "Example:\n\n" - "Define a custom function to integrate into your stats:\n\n" - "def unique_error_counter(datas):\n" - " unique_errors = set()\n" - " for data in datas:\n" - " unique_errors.add(data.get('error_message'))\n" - " return len(unique_errors)\n\n" - "{'unique_error_count': unique_error_counter}" - ) - - -def get_stats_input(modes: List[str] = []) -> List: - body = [] - if DELETE in modes: - delete_options = list( - get_general_custom_stats().keys() if GENERAL_STATS in modes else get_custom_stats().keys() - ) - body += [ - get_selector_layout( - delete_options, - "stats_input", - delete_options[0] if delete_options else "", - ) - ] - else: - mode = GENERAL_STATS if GENERAL_STATS in modes else INLINE_STATS - extractor_options = list(get_stats_raw()[mode].keys()) - body += [ - get_selector_layout(extractor_options, "stats_extractor", CUSTOM), - get_code_text_area_layout(id="stats_input"), - ] - return [ - html.Pre(get_stats_text(GENERAL_STATS in modes, DELETE in modes), id="stats_text"), - ] + body - - -def get_add_stats_layout() -> html.Div: - modal_header = dbc.ModalHeader( - [ - dbc.ModalTitle("Set Up Your Stats"), - get_switch_layout( - id="stats_modes", - labels=["general stats", "delete mode"], - values=[GENERAL_STATS, DELETE], - additional_params={"inline": True, "style": {"margin-left": "10px"}}, - ), - ], - close_button=True, - ) - modal_body = dbc.ModalBody( - html.Div( - get_stats_input(), - id="stats_input_container", - ) - ) - modal_footer = dbc.ModalFooter( - dbc.Button( - "Apply", - id="apply_new_stats", - className="ms-auto", - n_clicks=0, - ) - ) - return html.Div( - [ - dbc.Button( - "Stats", - id="set_new_stats_button", - class_name='button-class', - ), - dbc.Modal( - [ - modal_header, - modal_body, - modal_footer, - ], - size="lg", - id="new_stats", - centered=True, - is_open=False, - ), - ], - style={'display': 'inline-block'}, - ) diff --git a/nemo_inspector/layouts/base_layouts.py b/nemo_inspector/layouts/base_layouts.py deleted file mode 100644 index bc1ad99e3..000000000 --- a/nemo_inspector/layouts/base_layouts.py +++ /dev/null @@ -1,284 +0,0 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import itertools -from typing import Dict, Iterable, List, Optional, Union - -import dash_ace -import dash_bootstrap_components as dbc -from dash import dcc, html -from flask import current_app -from settings.constants import ANSI, CODE, COMPARE, LATEX, MARKDOWN, SEPARATOR_DISPLAY, SEPARATOR_ID, UNDEFINED -from utils.common import parse_model_answer -from utils.decoration import color_text_diff, design_text_output, highlight_code - - -def get_main_page_layout() -> html.Div: - nav_items = [ - dbc.NavItem( - dbc.NavLink( - "Inference", - id="run_mode_link", - href="/", - active=True, - ) - ), - dbc.NavItem(dbc.NavLink("Analyze", id="analyze_link", href="/analyze")), - ] - return html.Div( - [ - dcc.Location(id="url", refresh=False), - dbc.NavbarSimple( - children=nav_items, - brand="NeMo Inspector", - sticky="top", - color="blue", - dark=True, - class_name="mb-2", - ), - dbc.Container(id="page_content"), - dbc.Container(id="js_trigger", style={'display': 'none'}, children=""), - dbc.Container(id="js_container"), - dbc.Container(id='dummy_output', style={'display': 'none'}, children=""), - ] - ) - - -def get_switch_layout( - id: Union[Dict, str], - labels: List[str], - values: Optional[List[str]] = None, - disabled: List[bool] = [False], - is_active: bool = False, - chosen_values: Optional[List[str]] = None, - additional_params: Dict = {}, -) -> dbc.Checklist: - if values is None: - values = labels - return dbc.Checklist( - id=id, - options=[ - { - "label": label, - "value": value, - "disabled": is_disabled, - } - for label, value, is_disabled in itertools.zip_longest(labels, values, disabled, fillvalue=False) - ], - value=(chosen_values if chosen_values else [values[0]] if is_active else []), - **additional_params, - ) - - -def get_selector_layout(options: Iterable, id: str, value: str = "") -> dbc.Select: - if value not in options: - options = [value] + list(options) - return dbc.Select( - id=id, - options=[ - { - "label": str(value), - "value": value, - } - for value in options - ], - value=str(value), - ) - - -def get_text_area_layout( - id: str, value: str, text_modes: List[str] = [], editable: bool = False -) -> Union[dbc.Textarea, html.Pre]: - if editable and text_modes == []: - component = dbc.Textarea - children = {"value": value} - else: - component = html.Pre - children = {"children": get_single_prompt_output_layout(value, text_modes)} - return component( - **children, - id=id, - style={ - 'width': '100%', - 'border': "1px solid #dee2e6", - }, - ) - - -def get_single_prompt_output_layout( - answer: str, text_modes: List[str] = [CODE, LATEX, ANSI], compare_to: str = "" -) -> List[html.Div]: - parsed_answers = ( - parse_model_answer(answer) if CODE in text_modes else [{"explanation": answer, "code": None, "output": None}] - ) - parsed_compared_answers = ( - ( - parse_model_answer(compare_to) - if CODE in text_modes - else [{"explanation": compare_to, "code": None, "output": None}] - ) - if COMPARE in text_modes - else parsed_answers - ) - - items = [] - styles = { - "explanation": {'default': {}, 'wrong': {}}, - "code": {'default': {}, 'wrong': {}}, - "output": { - "default": { - "border": "1px solid black", - "background-color": "#cdd4f1c8", - "marginBottom": "10px", - "marginTop": "-6px", - }, - 'wrong': { - "border": "1px solid red", - "marginBottom": "10px", - "marginTop": "-6px", - }, - }, - } - - functions = {"explanation": design_text_output, "code": highlight_code, "output": design_text_output} - - def check_existence(array: List[Dict[str, str]], i: int, key: str): - 
return i < len(array) and key in array[i] and array[i][key] - - for i in range(max(len(parsed_answers), len(parsed_compared_answers))): - for key in ["explanation", "code", "output"]: - if check_existence(parsed_answers, i, key) or check_existence(parsed_compared_answers, i, key): - diff = color_text_diff( - parsed_answers[i][key] if check_existence(parsed_answers, i, key) else "", - parsed_compared_answers[i][key] if check_existence(parsed_compared_answers, i, key) else "", - ) - style_type = ( - 'default' - if not check_existence(parsed_answers, i, key) or 'wrong_code_block' not in parsed_answers[i][key] - else 'wrong' - ) - style = styles[key][style_type] - item = functions[key](diff, style=style, text_modes=text_modes) - items.append(item) - return items - - -def get_text_modes_layout(id: str, is_formatted: bool = True): - return get_switch_layout( - id={ - "type": "text_modes", - "id": id, - }, - labels=[CODE, LATEX, MARKDOWN, ANSI], - chosen_values=[CODE, LATEX, ANSI] if is_formatted else [], - additional_params={ - "style": { - "display": "inline-flex", - "flex-wrap": "wrap", - }, - "inputStyle": {'margin-left': '-10px'}, - "labelStyle": {'margin-left': '3px'}, - }, - ) - - -def get_results_content_layout( - text: str, content: str = None, style: Dict = {}, is_formatted: bool = False -) -> html.Div: - return html.Div( - [ - get_text_modes_layout("results_content", is_formatted), - html.Pre( - content if content else text, - id="results_content_text", - style={'margin-bottom': '10px'}, - ), - dcc.Store(data=text, id="text_store"), - ], - style=style, - ) - - -def validation_parameters(name: str, value: Union[str, int, float]) -> Dict[str, str]: - parameters = {"type": "text"} - if str(value).replace(".", "", 1).replace("-", "", 1).isdigit(): - parameters["type"] = "number" - - if str(value).isdigit(): - parameters["min"] = 0 - - if "." 
in str(value) and str(value).replace(".", "", 1).isdigit(): - parameters["min"] = 0 - parameters["max"] = 1 if name != "temperature" else 100 - parameters["step"] = 0.1 - - return parameters - - -def get_input_group_layout(name: str, value: Union[str, int, float, bool]) -> dbc.InputGroup: - input_function = dbc.Textarea - additional_params = { - "style": { - 'width': '100%', - }, - "debounce": True, - } - if name.split(SEPARATOR_DISPLAY)[-1] in current_app.config['nemo_inspector']['types'].keys(): - input_function = get_selector_layout - additional_params = { - "options": current_app.config['nemo_inspector']['types'][name.split(SEPARATOR_DISPLAY)[-1]], - } - if value is None: - value = UNDEFINED - elif isinstance(value, bool): - input_function = get_selector_layout - additional_params = {"options": [True, False]} - elif isinstance(value, (float, int)): - input_function = dbc.Input - additional_params = validation_parameters(name, value) - additional_params["debounce"] = True - - return dbc.InputGroup( - [ - dbc.InputGroupText(name), - input_function( - value=get_utils_field_representation(value), - id=name.replace(SEPARATOR_DISPLAY, SEPARATOR_ID), - **additional_params, - ), - ], - className="mb-3", - ) - - -def get_utils_field_representation(value: Union[str, int, float, bool], key: str = "") -> str: - return ( - UNDEFINED - if value is None and key.split(SEPARATOR_ID)[-1] in current_app.config['nemo_inspector']['types'] - else value if value == "" or str(value).strip() != "" else repr(value)[1:-1] - ) - - -def get_code_text_area_layout(id): - return dash_ace.DashAceEditor( - id=id, - theme='tomorrow_night', - mode='python', - tabSize=4, - value="", - enableBasicAutocompletion=True, - enableLiveAutocompletion=True, - placeholder='Write your code here...', - style={'width': '100%', 'height': '300px'}, - ) diff --git a/nemo_inspector/layouts/run_prompt_page_layouts.py b/nemo_inspector/layouts/run_prompt_page_layouts.py deleted file mode 100644 index 718602c85..000000000 --- a/nemo_inspector/layouts/run_prompt_page_layouts.py +++ /dev/null @@ -1,107 +0,0 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from typing import List, Tuple - -import dash_bootstrap_components as dbc -from dash import dcc, html -from flask import current_app -from layouts.base_layouts import get_text_area_layout -from settings.constants import CHAT_MODE, FEW_SHOTS_INPUT, ONE_SAMPLE_MODE -from utils.strategies.strategy_maker import RunPromptStrategyMaker - -from nemo_skills.prompt.few_shot_examples import examples_map - - -def get_few_shots_by_id_layout(page: int, examples_type: str, text_modes: List[str]) -> Tuple[html.Div]: - examples_list = examples_map.get(examples_type, [{}]) - if not page or len(examples_list) < page: - return html.Div() - return ( - html.Div( - [ - dbc.InputGroup( - [ - dbc.InputGroupText(key), - get_text_area_layout({"type": FEW_SHOTS_INPUT, "id": key}, str(value), text_modes), - ], - className="mb-3", - ) - for key, value in (examples_list[page - 1].items()) - ], - ), - ) - - -def get_query_params_layout(mode: str = ONE_SAMPLE_MODE, dataset: str = None) -> List[dbc.AccordionItem]: - strategy = RunPromptStrategyMaker(mode).get_strategy() - return ( - strategy.get_utils_input_layout() - + strategy.get_few_shots_input_layout() - + strategy.get_query_input_layout(dataset) - ) - - -def get_run_mode_layout() -> html.Div: - return html.Div( - [ - dbc.RadioItems( - id="run_mode_options", - className="btn-group", - inputClassName="btn-check", - labelClassName="btn btn-outline-primary", - labelCheckedClassName="active", - options=[ - {"label": "Chat", "value": CHAT_MODE}, - {"label": "Run one sample", "value": ONE_SAMPLE_MODE}, - ], - value=ONE_SAMPLE_MODE, - ), - ], - className="radio-group", - ) - - -def get_run_test_layout() -> html.Div: - return html.Div( - [ - get_run_mode_layout(), - dbc.Accordion( - get_query_params_layout(dataset=current_app.config['nemo_inspector']['input_file']), - start_collapsed=True, - always_open=True, - id="prompt_params_input", - ), - dbc.Button( - "preview", - id="preview_button", - outline=True, - color="primary", - className="me-1 mb-2", - ), - dbc.Button( - "run", - id="run_button", - outline=True, - color="primary", - className="me-1 mb-2", - ), - dcc.Loading( - children=dbc.Container(id="loading_container", style={'display': 'none'}, children=""), - type='circle', - style={'margin-top': '50px'}, - ), - dbc.Container(id="results_content"), - ] - ) diff --git a/nemo_inspector/layouts/table_layouts.py b/nemo_inspector/layouts/table_layouts.py deleted file mode 100644 index 74b0b8efa..000000000 --- a/nemo_inspector/layouts/table_layouts.py +++ /dev/null @@ -1,931 +0,0 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import json -import logging -import math -from typing import Dict, List - -import dash_bootstrap_components as dbc -from dash import dash_table, html -from layouts.base_layouts import ( - get_code_text_area_layout, - get_selector_layout, - get_single_prompt_output_layout, - get_switch_layout, - get_text_modes_layout, -) -from settings.constants import ( - ANSI, - CODE, - COMPARE, - COMPARE_ICON_PATH, - DATA_PAGE_SIZE, - EDIT_ICON_PATH, - ERROR_MESSAGE_TEMPLATE, - FILE_NAME, - FILES_FILTERING, - FILES_ONLY, - LABEL, - LATEX, - MODEL_SELECTOR_ID, - NAME_FOR_BASE_MODEL, - QUESTIONS_FILTERING, - STATS_KEYS, -) -from utils.common import ( - catch_eval_exception, - custom_deepcopy, - get_available_models, - get_compared_rows, - get_data_from_files, - get_editable_rows, - get_eval_function, - get_excluded_row, - get_filtered_files, - get_general_custom_stats, - get_metrics, - is_detailed_answers_rows_key, -) - -table_data = [] -labels = [] - - -def get_table_data() -> List: - return table_data - - -def get_labels() -> List: - return labels - - -def get_filter_text(available_filters: List[str] = [], mode: str = FILES_FILTERING) -> str: - available_filters = list( - get_table_data()[0][list(get_table_data()[0].keys())[0]][0].keys() - if len(get_table_data()) and not available_filters - else STATS_KEYS + list(get_metrics([]).keys()) + ["+ all fields in json"] - ) - if mode == FILES_ONLY: - return ( - "Write an expression to filter the data\n\n" - + "For example:\ndata['is_correct'] and not data['error_message']\n\n" - + "The expression has to return bool.\n\n" - + "Available parameters to filter data:\n" - + '\n'.join( - [', '.join(available_filters[start : start + 5]) for start in range(0, len(available_filters), 5)] - ), - ) - elif mode == FILES_FILTERING: - return ( - "Write an expression to filter the data\n" - + "Separate expressions for different generations with &&\n" - + "You can use base_generation variable to access data from the current generation\n\n" - + "For example:\ndata['generation1']['correct_responses'] > 0.5 && data[base_generation]['no_response'] < 0.2\n\n" - + "The expression has to return bool.\n\n" - + "Available parameters to filter data:\n" - + '\n'.join( - [', '.join(available_filters[start : start + 5]) for start in range(0, len(available_filters), 5)] - ), - ) - elif mode == QUESTIONS_FILTERING: - return ( - "Write an expression to filter the data\n" - + "You can operate with a dictionary containing keys representing generation names\n" - + "and a list of values as JSON data from your generation from each file.\n" - + "You can use base_generation variable to access data from the current generation\n\n" - + "For example:\ndata['generation1'][0]['is_correct'] != data[base_generation][0]['is_correct']\n\n" - + "The expression has to return bool.\n\n" - + "Available parameters to filter data:\n" - + '\n'.join( - [', '.join(available_filters[start : start + 5]) for start in range(0, len(available_filters), 5)] - ), - ) - - -def get_filter_layout(id: int = -1, available_filters: List[str] = [], mode: str = FILES_FILTERING) -> html.Div: - text = get_filter_text(available_filters, mode) - - filter_mode = ( - [ - get_switch_layout( - id={"type": "filter_mode", "id": id}, - labels=["filter files"], - is_active=True, - additional_params={ - "inline": True, - "style": {"margin-left": "10px"}, - }, - ) - ] - if mode != FILES_ONLY - else [] - ) - - header = dbc.ModalHeader( - ( - [ - dbc.ModalTitle( - "Set Up Your Filter", - ), - ] - + filter_mode - ), - close_button=True, - ) - 
body = dbc.ModalBody( - html.Div( - [ - html.Pre(text, id={"type": "filter_text", "id": id}), - get_code_text_area_layout( - id={ - "type": "filter_function_input", - "id": id, - }, - ), - ] - ) - ) - switch = get_switch_layout( - { - "type": "apply_on_filtered_data", - "id": id, - }, - ["Apply for filtered data"], - additional_params={"style": {"margin-left": "10px"}}, - ) - footer = dbc.ModalFooter( - dbc.Button( - "Apply", - id={"type": "apply_filter_button", "id": id}, - className="ms-auto", - n_clicks=0, - ) - ) - return html.Div( - [ - dbc.Button( - "Filters", - id={"type": "set_filter_button", "id": id}, - class_name='button-class', - ), - dbc.Modal( - [ - header, - body, - switch, - footer, - ], - size="lg", - id={"type": "filter", "id": id}, - centered=True, - is_open=False, - ), - ], - style={'display': 'inline-block'}, - ) - - -def get_sorting_layout(id: int = -1, available_params: List[str] = []) -> html.Div: - available_params = list( - get_table_data()[0][list(get_table_data()[0].keys())[0]][0].keys() - if len(get_table_data()) and not available_params - else STATS_KEYS + list(get_metrics([]).keys()) + ["+ all fields in json"] - ) - text = ( - "Write an expression to sort the data\n\n" - "For example: len(data['question'])\n\n" - "The function has to return sortable type\n\n" - "Available parameters to sort data:\n" - + '\n'.join([', '.join(available_params[start : start + 5]) for start in range(0, len(available_params), 5)]) - ) - header = dbc.ModalHeader( - dbc.ModalTitle("Set Up Your Sorting Parameters"), - close_button=True, - ) - body = dbc.ModalBody( - html.Div( - [ - html.Pre(text), - get_code_text_area_layout( - id={ - "type": "sorting_function_input", - "id": id, - }, - ), - ], - ) - ) - footer = dbc.ModalFooter( - dbc.Button( - "Apply", - id={"type": "apply_sorting_button", "id": id}, - className="ms-auto", - n_clicks=0, - ) - ) - return html.Div( - [ - dbc.Button( - "Sort", - id={"type": "set_sorting_button", "id": id}, - class_name='button-class', - ), - dbc.Modal( - [ - header, - body, - footer, - ], - size="lg", - id={"type": "sorting", "id": id}, - centered=True, - is_open=False, - ), - ], - style={'display': 'inline-block'}, - ) - - -def get_change_label_layout(id: int = -1, apply_for_all_files: bool = True) -> html.Div: - header = dbc.ModalHeader( - dbc.ModalTitle("Manage labels"), - close_button=True, - ) - switch_layout = ( - [ - get_switch_layout( - { - "type": "aplly_for_all_files", - "id": id, - }, - ["Apply for all files"], - additional_params={"style": {"margin-left": "10px"}}, - ) - ] - if apply_for_all_files - else [] - ) - body = dbc.ModalBody( - html.Div( - [ - get_selector_layout( - options=labels, - id={"type": "label_selector", "id": id}, - value="choose label", - ), - dbc.InputGroup( - [ - dbc.Input( - id={ - "type": "new_label_input", - "id": id, - }, - placeholder="Enter new label", - type="text", - ), - dbc.Button( - "Add", - id={ - "type": "add_new_label_button", - "id": id, - }, - ), - ] - ), - *switch_layout, - html.Pre("", id={"type": "chosen_label", "id": id}), - ], - ) - ) - footer = dbc.ModalFooter( - html.Div( - [ - dbc.Button( - children="Delete", - id={ - "type": "delete_label_button", - "id": id, - }, - className="ms-auto", - n_clicks=0, - ), - html.Pre( - " ", - style={'display': 'inline-block', 'font-size': '5px'}, - ), - dbc.Button( - children="Apply", - id={"type": "apply_label_button", "id": id}, - className="ms-auto", - n_clicks=0, - ), - ], - ), - style={'display': 'inline-block'}, - ) - return html.Div( - [ - dbc.Button( - 
"Labels", - id={"type": "set_file_label_button", "id": id}, - class_name='button-class', - ), - dbc.Modal( - [header, body, footer], - size="lg", - id={"type": "label", "id": id}, - centered=True, - is_open=False, - ), - ], - style={'display': 'inline-block'}, - ) - - -def get_stats_layout() -> List[dbc.Row]: - return [ - dbc.Row( - dbc.Col( - dash_table.DataTable( - id='datatable', - columns=[ - { - 'name': name, - 'id': name, - 'hideable': True, - } - for name in STATS_KEYS + list(get_metrics([]).keys()) - ], - row_selectable='single', - cell_selectable=False, - page_action='custom', - page_current=0, - page_size=DATA_PAGE_SIZE, - page_count=math.ceil(len(table_data) / DATA_PAGE_SIZE), - style_cell={ - 'overflow': 'hidden', - 'textOverflow': 'ellipsis', - 'maxWidth': 0, - 'textAlign': 'center', - }, - style_header={ - 'color': 'text-primary', - 'text_align': 'center', - 'height': 'auto', - 'whiteSpace': 'normal', - }, - css=[ - { - 'selector': '.dash-spreadsheet-menu', - 'rule': 'position:absolute; bottom: 8px', - }, - { - 'selector': '.dash-filter--case', - 'rule': 'display: none', - }, - { - 'selector': '.column-header--hide', - 'rule': 'display: none', - }, - ], - ), - ) - ), - ] - - -def get_models_selector_table_cell(models: List[str], name: str, id: int, add_del_button: bool = False) -> dbc.Col: - del_model_layout = ( - [ - dbc.Button( - "-", - id={"type": "del_model", "id": id}, - outline=True, - color="primary", - className="me-1", - style={"height": "40px"}, - ), - ] - if add_del_button - else [] - ) - return dbc.Col( - html.Div( - [ - html.Div( - get_selector_layout( - models, - json.loads(MODEL_SELECTOR_ID.format(id)), - name, - ), - ), - get_sorting_layout(id), - get_filter_layout(id, mode=FILES_ONLY), - get_change_label_layout(id), - ] - + del_model_layout - + [get_text_modes_layout(id)], - style={'display': 'inline-flex'}, - ), - class_name='mt-1 bg-light font-monospace text-break small rounded border', - id={"type": "column_header", "id": id}, - ) - - -def get_models_selector_table_header(models: List[str]) -> List[dbc.Row]: - return [ - dbc.Row( - [ - dbc.Col( - html.Div( - "", - ), - width=2, - class_name='mt-1 bg-light font-monospace text-break small rounded border', - id='first_column', - ) - ] - + [ - get_models_selector_table_cell(get_available_models(), name, i, i != 0) - for i, name in enumerate(models) - ], - id='detailed_answers_header', - ) - ] - - -def get_detailed_answer_column(id: int, file_id=None) -> dbc.Col: - return dbc.Col( - html.Div( - children=( - get_selector_layout([], {"type": "file_selector", "id": file_id}, "") if file_id is not None else "" - ), - id={ - 'type': 'detailed_models_answers', - 'id': id, - }, - ), - class_name='mt-1 bg-light font-monospace text-break small rounded border', - ) - - -def get_detailed_answers_rows(keys: List[str], colums_number: int) -> List[dbc.Row]: - return [ - dbc.Row( - [ - dbc.Col( - html.Div( - html.Div( - [ - html.Div( - key, - id={"type": "row_name", "id": i}, - style={"display": "inline-block"}, - ), - dbc.Button( - html.Img( - src=EDIT_ICON_PATH, - id={"type": "edit_row_image", "id": i}, - style={ - "height": "15px", - "display": "inline-block", - }, - ), - id={"type": "edit_row_button", "id": i}, - outline=True, - color="primary", - className="me-1", - style={ - "border": "none", - "line-height": "1.2", - "display": "inline-block", - "margin-left": "1px", - "display": "none" if key in (FILE_NAME, LABEL) else "inline-block", - }, - ), - dbc.Button( - html.Img( - src=COMPARE_ICON_PATH, - id={"type": 
"compare_texts", "id": i}, - style={ - "height": "15px", - "display": "inline-block", - }, - ), - id={"type": "compare_texts_button", "id": i}, - outline=True, - color="primary", - className="me-1", - style={ - "border": "none", - "line-height": "1.2", - "display": "inline-block", - "margin-left": "-10px" if key != LABEL else "1px", - "display": "none" if key == FILE_NAME else "inline-block", - }, - ), - dbc.Button( - "-", - id={"type": "del_row", "id": i}, - outline=True, - color="primary", - className="me-1", - style={ - "border": "none", - "display": "inline-block", - "margin-left": "-9px" if key != FILE_NAME else "1px", - }, - ), - ], - style={"display": "inline-block"}, - ), - ), - width=2, - class_name='mt-1 bg-light font-monospace text-break small rounded border', - ) - ] - + [get_detailed_answer_column(j * len(keys) + i) for j in range(colums_number)], - id={"type": "detailed_answers_row", "id": i}, - ) - for i, key in enumerate(keys) - ] - - -def get_table_answers_detailed_data_layout( - models: List[str], - keys: List[str], -) -> List[dbc.Row]: - return get_models_selector_table_header(models) + get_detailed_answers_rows(keys, len(models)) - - -def get_row_detailed_inner_data( - question_id: int, - model: str, - rows_names: List[str], - files_names: List[str], - file_id: int, - col_id: int, - compare_to: Dict = {}, - text_modes: List[str] = [CODE, LATEX, ANSI], -) -> List: - table_data = get_table_data()[question_id].get(model, []) - row_data = [] - empty_list = False - if table_data[file_id].get(FILE_NAME, None) not in files_names: - empty_list = True - for key in filter( - lambda key: is_detailed_answers_rows_key(key), - rows_names, - ): - if file_id < 0 or len(table_data) <= file_id or key in get_excluded_row(): - value = "" - elif key == FILE_NAME: - value = get_selector_layout( - files_names, - {"type": "file_selector", "id": col_id}, - (table_data[file_id].get(key, None) if not empty_list else ""), - ) - elif empty_list: - value = "" - elif key in get_editable_rows(): - value = str(table_data[file_id].get(key, None)) - else: - value = get_single_prompt_output_layout( - str(table_data[file_id].get(key, None)), - text_modes + ([COMPARE] if key in get_compared_rows() else []), - str(compare_to.get(key, "")), - ) - row_data.append( - value - if key not in get_editable_rows() - else dbc.Textarea(id={"type": "editable_row", "id": key, "model_name": model}, value=value) - ) - return row_data - - -def get_table_detailed_inner_data( - question_id: int, - rows_names: List[str], - models: List[str], - files_id: List[int], - filter_functions: List[str], - sorting_functions: List[str], - text_modes: List[List[str]], -) -> List: - table_data = [] - for col_id, (model, file_id, filter_function, sorting_function, modes) in enumerate( - zip(models, files_id, filter_functions, sorting_functions, text_modes) - ): - row_data = get_row_detailed_inner_data( - question_id=question_id, - model=model, - rows_names=rows_names, - files_names=[ - file[FILE_NAME] - for file in get_filtered_files( - filter_function, - sorting_function, - get_table_data()[question_id][model] if len(get_table_data()) else [], - ) - ], - file_id=file_id, - col_id=col_id, - text_modes=modes, - compare_to=get_table_data()[question_id][models[0]][files_id[0]], - ) - table_data.extend(row_data) - return table_data - - -def get_general_stats_layout( - base_model: str, -) -> html.Div: - data_for_base_model = [data.get(base_model, []) for data in get_table_data()] - custom_stats = {} - for name, func in 
get_general_custom_stats().items(): - errors_dict = {} - custom_stats[name] = catch_eval_exception( - [], - func, - data_for_base_model, - "Got error when applying function", - errors_dict, - ) - if len(errors_dict): - logging.error(ERROR_MESSAGE_TEMPLATE.format(name, errors_dict)) - - overall_samples = sum(len(question_data) for question_data in data_for_base_model) - dataset_size = len(list(filter(lambda x: bool(x), data_for_base_model))) - stats = { - "dataset size": dataset_size, - "overall number of samples": overall_samples, - "generations per sample": (overall_samples / dataset_size if dataset_size else 0), - **custom_stats, - } - return [html.Div([html.Pre(f'{name}: {value}') for name, value in stats.items()])] - - -def get_update_dataset_layout(base_model: str, update_function: str, models: List[str]) -> List[html.Tr]: - errors_dict = {} - global table_data - if update_function: - update_eval_function = get_eval_function(update_function.strip()) - available_models = { - model_name: model_info["file_paths"] for model_name, model_info in get_available_models().items() - } - - for question_id in range(len(table_data)): - new_dicts = list( - map( - lambda data: catch_eval_exception( - available_models, - update_eval_function, - data, - data, - errors_dict, - ), - table_data[question_id][base_model], - ) - ) - for i, new_dict in enumerate(new_dicts): - for key, value in new_dict.items(): - table_data[question_id][base_model][i][key] = value - - keys = list(table_data[question_id][base_model][i].keys()) - for key in keys: - if key not in new_dict: - table_data[question_id][base_model][i].pop(key) - - if len(errors_dict): - logging.error(ERROR_MESSAGE_TEMPLATE.format("update_dataset", errors_dict)) - - return ( - get_stats_layout() - + get_general_stats_layout(base_model) - + get_table_answers_detailed_data_layout( - models, - list( - filter( - is_detailed_answers_rows_key, - ( - table_data[0][base_model][0].keys() - if len(table_data) and len(table_data[0][base_model]) - else [] - ), - ) - ), - ) - ) - - -def get_sorting_answers_layout(base_model: str, sorting_function: str, models: List[str]) -> List[html.Tr]: - errors_dict = {} - global table_data - if sorting_function: - sortting_eval_function = get_eval_function(sorting_function.strip()) - available_models = { - model_name: model_info["file_paths"] for model_name, model_info in get_available_models().items() - } - - for question_id in range(len(table_data)): - for model in table_data[question_id].keys(): - table_data[question_id][model].sort( - key=lambda data: catch_eval_exception( - available_models, - sortting_eval_function, - data, - 0, - errors_dict, - ) - ) - - table_data.sort( - key=lambda single_question_data: tuple( - map( - lambda data: catch_eval_exception( - available_models, - sortting_eval_function, - data, - 0, - errors_dict, - ), - single_question_data[base_model], - ) - ) - ) - if len(errors_dict): - logging.error(ERROR_MESSAGE_TEMPLATE.format("sorting", errors_dict)) - - return ( - get_stats_layout() - + get_general_stats_layout(base_model) - + get_table_answers_detailed_data_layout( - models, - list( - filter( - is_detailed_answers_rows_key, - ( - table_data[0][base_model][0].keys() - if len(table_data) and len(table_data[0][base_model]) - else [] - ), - ) - ), - ) - ) - - -def get_filter_answers_layout( - base_model: str, - filtering_function: str, - apply_on_filtered_data: bool, - models: List[str], - filter_mode: str, -) -> List[html.Tr]: - global table_data - clean_table_data = [] - if not 
apply_on_filtered_data: - table_data = custom_deepcopy(get_data_from_files()) - for question_id in range(len(table_data)): - for model_id, files_data in table_data[question_id].items(): - stats = get_metrics(files_data) - table_data[question_id][model_id] = list( - map( - lambda data: {**data, **stats}, - table_data[question_id][model_id], - ) - ) - - errors_dict = {} - if filtering_function: - available_models = { - model_name: model_info["file_paths"] for model_name, model_info in get_available_models().items() - } - filter_lines = filtering_function.strip().split('\n') - common_expressions, splitted_filters = ( - "\n".join(filter_lines[:-1]), - filter_lines[-1], - ) - full_splitted_filters = [ - common_expressions + "\n" + single_filter for single_filter in splitted_filters.split('&&') - ] - filtering_functions = ( - list( - [ - get_eval_function(f"{NAME_FOR_BASE_MODEL} = '{base_model}'\n" + func) - for func in full_splitted_filters - ] - ) - if filtering_function - else [] - ) - - if filter_mode == FILES_FILTERING: - for question_id in range(len(table_data)): - good_data = True - for model_id in table_data[question_id].keys(): - - def filtering_key_function(file_dict): - data = {model_id: file_dict} - return all( - [ - catch_eval_exception( - available_models, - filter_function, - data, - True, - errors_dict, - ) - for filter_function in filtering_functions - ], - ) - - table_data[question_id][model_id] = list( - filter( - filtering_key_function, - table_data[question_id][model_id], - ) - ) - stats = get_metrics(table_data[question_id][model_id]) - table_data[question_id][model_id] = list( - map( - lambda data: {**data, **stats}, - table_data[question_id][model_id], - ) - ) - - if table_data[question_id][model_id] == []: - good_data = False - if good_data: - clean_table_data.append(table_data[question_id]) - else: - func = get_eval_function(f"{NAME_FOR_BASE_MODEL} = '{base_model}'\n" + filtering_function.strip()) - clean_table_data = list( - filter( - lambda data: catch_eval_exception( - available_models=[], - eval_func=func, - data=data, - default_answer=True, - errors_dict=errors_dict, - ), - table_data, - ) - ) - table_data = clean_table_data - if len(errors_dict): - logging.error(ERROR_MESSAGE_TEMPLATE.format("filtering", errors_dict)) - - return ( - get_stats_layout() - + get_general_stats_layout(base_model) - + get_table_answers_detailed_data_layout( - models, - list( - filter( - is_detailed_answers_rows_key, - ( - table_data[0][base_model][0].keys() - if len(table_data) and len(table_data[0][base_model]) - else [] - ), - ) - ), - ) - ) - - -def get_model_answers_table_layout(base_model: str, use_current: bool = False) -> List: - global table_data - if not use_current: - table_data = custom_deepcopy(get_data_from_files()) - - return ( - get_stats_layout() - + get_general_stats_layout(base_model) - + get_table_answers_detailed_data_layout( - [base_model], - list( - filter( - is_detailed_answers_rows_key, - ( - table_data[0][base_model][0].keys() - if len(table_data) and len(table_data[0][base_model]) - else [] - ), - ) - ), - ) - ) diff --git a/nemo_inspector/nemo_inspector.py b/nemo_inspector/nemo_inspector.py deleted file mode 100644 index 4efa8a6b9..000000000 --- a/nemo_inspector/nemo_inspector.py +++ /dev/null @@ -1,34 +0,0 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import sys -from pathlib import Path - -sys.path.append(str(Path(__file__).parents[1])) - -import signal - -from callbacks import app -from layouts import get_main_page_layout - -signal.signal(signal.SIGALRM, signal.SIG_IGN) - - -if __name__ == "__main__": - app.title = "NeMo Inspector" - app.layout = get_main_page_layout() - app.run( - host='localhost', - port='8080', - ) diff --git a/nemo_inspector/settings/__init__.py b/nemo_inspector/settings/__init__.py deleted file mode 100644 index d9155f923..000000000 --- a/nemo_inspector/settings/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/nemo_inspector/settings/constants.py b/nemo_inspector/settings/constants.py deleted file mode 100644 index 716909c8f..000000000 --- a/nemo_inspector/settings/constants.py +++ /dev/null @@ -1,102 +0,0 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-import os -from pathlib import Path - -ANSWER_FIELD = "expected_answer" -ANSI = "ansi" -NAME_FOR_BASE_MODEL = "base_generation" -EXTRA_FIELDS = ["page_index", "file_name"] -INLINE_STATS = "inline_stats" -GENERAL_STATS = "general_stats" -CHAT_MODE = "chat_mode" -CHOOSE_MODEL = "choose generation" -CHOOSE_LABEL = "choose label" -COMPARE = 'compare' -CODE = "code" -COMPARE_ICON_PATH = "assets/images/compare_icon.png" -CODE_SEPARATORS = { - "code_begin": '', - "code_end": '', - "code_output_begin": '', - "code_output_end": '', -} -CUSTOM = 'custom' -DATA_PAGE_SIZE = 10 -DELETE = "delete" -EDIT_ICON_PATH = "assets/images/edit_icon.png" -ERROR_MESSAGE_TEMPLATE = "When applying {} function\ngot errors\n{}" -FEW_SHOTS_INPUT = "few_shots_input" -FILE_NAME = 'file_name' -FILES_ONLY = "files_only" -FILES_FILTERING = "add_files_filtering" -GENERAL_STATS = "general_stats" -CODE_BEGIN = 'code_begin' -CODE_END = 'code_end' -CODE_OUTPUT_BEGIN = 'code_output_begin' -CODE_OUTPUT_END = 'code_output_end' -CONFIGS_FOLDER = os.path.join(Path(__file__).parents[2].absolute(), 'nemo_skills/prompt/config') -GREEDY = "greedy" -IGNORE_FIELDS = ['stop_phrases', 'used_prompt', 'server_type'] -QUESTIONS_FILTERING = "questions_filtering" -QUERY_INPUT_TYPE = "query_input" -QUERY_INPUT_ID = '{{"type": "{}", "id": "{}"}}' -QUESTION_FIELD = "problem" -ONE_SAMPLE_MODE = "one_sample" -METRICS = "metrics" -OUTPUT = "output" -OUTPUT_PATH = "{}-{}.jsonl" -PARAMS_TO_REMOVE = [ - 'output_file', - 'dataset', - 'split', - 'example_dicts', - 'retriever', - '_context_template', - 'save_generations_path', -] -PARAMETERS_FILE_NAME = "nemo_inspector/results/parameters.json" -TEMPLATES_FOLDER = os.path.join(Path(__file__).parents[2].absolute(), 'nemo_skills/prompt/template') -RETRIEVAL = 'retrieval' -RETRIEVAL_FIELDS = [ - 'max_retrieved_chars_field', - 'retrieved_entries', - 'retrieval_file', - 'retrieval_field', - 'retrieved_few_shots', - 'max_retrieved_chars', - 'randomize_retrieved_entries', -] -SAVE_ICON_PATH = "assets/images/save_icon.png" -STATS_KEYS = [ - 'question_index', - 'problem', -] -SEPARATOR_DISPLAY = '.' -SEPARATOR_ID = '->' -SETTING_PARAMS = [ - 'server', - 'sandbox', - 'output_file', - 'inspector_params', - 'types', - 'stop_phrases', -] -STATISTICS_FOR_WHOLE_DATASET = ["correct_answer", "wrong_answer", "no_answer"] -UNDEFINED = "undefined" -MARKDOWN = "markdown" -MODEL_SELECTOR_ID = '{{"type": "model_selector", "id": {}}}' -LABEL_SELECTOR_ID = '{{"type": "label_selector", "id": {}}}' -LABEL = "labels" -LATEX = "latex" diff --git a/nemo_inspector/settings/inspector_config.py b/nemo_inspector/settings/inspector_config.py deleted file mode 100644 index e5648b71d..000000000 --- a/nemo_inspector/settings/inspector_config.py +++ /dev/null @@ -1,42 +0,0 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from dataclasses import field -from typing import Dict, Tuple - -import hydra - -from nemo_skills.inference.generate import GenerateSolutionsConfig -from nemo_skills.utils import nested_dataclass, unroll_files - - -@nested_dataclass(kw_only=True) -class BaseInspectorConfig: - model_prediction: Dict[str, str] = field(default_factory=dict) - save_generations_path: str = "nemo_inspector/results/saved_generations" - - def __post_init__(self): - self.model_prediction = { - model_name: list(unroll_files(file_path.split(" "))) - for model_name, file_path in self.model_prediction.items() - } - - -@nested_dataclass(kw_only=True) -class InspectorConfig(GenerateSolutionsConfig): - inspector_params: BaseInspectorConfig = field(default_factory=BaseInspectorConfig) - - -cs = hydra.core.config_store.ConfigStore.instance() -cs.store(name="base_inspector_config", node=InspectorConfig) diff --git a/nemo_inspector/settings/inspector_config.yaml b/nemo_inspector/settings/inspector_config.yaml deleted file mode 100644 index 525cb24e8..000000000 --- a/nemo_inspector/settings/inspector_config.yaml +++ /dev/null @@ -1,4 +0,0 @@ -defaults: - # add custom config to set up model_prediction - - base_inspector_config - - _self_ diff --git a/nemo_inspector/settings/templates.py b/nemo_inspector/settings/templates.py deleted file mode 100644 index 1b07a208b..000000000 --- a/nemo_inspector/settings/templates.py +++ /dev/null @@ -1,21 +0,0 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import sys - -summarize_results_template = ( - sys.executable - + """ pipeline/summarize_results.py {results_path} \\ - --benchmarks {benchmarks}""" -) diff --git a/nemo_inspector/tests/README.md b/nemo_inspector/tests/README.md deleted file mode 100644 index 8c01933c1..000000000 --- a/nemo_inspector/tests/README.md +++ /dev/null @@ -1,11 +0,0 @@ -To launch the tests, first install all the requirements -``` -pip install -r requirements/main.txt -pip install -r requirements/inspector.txt -pip install -r requirements/common-tests.txt -pip install -r requirements/inspector-tests.txt -``` -Now you can launch the tests -``` -pytest nemo_inspector/tests -``` \ No newline at end of file diff --git a/nemo_inspector/tests/test_ping.py b/nemo_inspector/tests/test_ping.py deleted file mode 100644 index 37fd7b937..000000000 --- a/nemo_inspector/tests/test_ping.py +++ /dev/null @@ -1,80 +0,0 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - -import os -import subprocess -import sys -from pathlib import Path - -import pytest -from selenium import webdriver -from selenium.webdriver.chrome.options import Options -from selenium.webdriver.chrome.service import Service -from selenium.webdriver.common.by import By -from selenium.webdriver.support import expected_conditions as EC -from selenium.webdriver.support.ui import WebDriverWait -from webdriver_manager.chrome import ChromeDriverManager -from webdriver_manager.core.os_manager import ChromeType - -project_root = str(Path(__file__).parents[2]) -sys.path.remove(str(Path(__file__).parents[0])) - - -@pytest.fixture(scope="module") -def nemo_inspector_process(): - # Start the NeMo Inspector as a subprocess - - process = subprocess.Popen( - ["python", "nemo_inspector/nemo_inspector.py"], - cwd=project_root, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - text=True, - ) - - yield process - - # Terminate the process after the tests - process.terminate() - process.wait() - - -@pytest.fixture -def chrome_driver(): - chrome_driver_path = ChromeDriverManager(chrome_type=ChromeType.GOOGLE).install() - options = Options() - options.page_load_strategy = 'normal' - options.add_argument("--headless") - options.add_argument("--disable-gpu") - options.add_argument("--no-sandbox") - options.add_argument("--disable-dev-shm-usage") - - service = Service(chrome_driver_path) - driver = webdriver.Chrome(service=service, options=options) - os.environ['PATH'] += os.pathsep + '/'.join(chrome_driver_path.split("/")[:-1]) - yield driver - driver.quit() - - -@pytest.mark.parametrize( - ("element_id", "url"), - [('run_button', "/"), ('add_model', "/analyze")], -) -def test_dash_app_launch(chrome_driver, nemo_inspector_process, element_id, url): - full_url = f"http://localhost:8080{url}" - - chrome_driver.get(full_url) - - element = WebDriverWait(chrome_driver, 10).until(EC.presence_of_element_located((By.ID, element_id))) - assert element.is_displayed() diff --git a/nemo_inspector/utils/__init__.py b/nemo_inspector/utils/__init__.py deleted file mode 100644 index d9155f923..000000000 --- a/nemo_inspector/utils/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/nemo_inspector/utils/common.py b/nemo_inspector/utils/common.py deleted file mode 100644 index 5b275890b..000000000 --- a/nemo_inspector/utils/common.py +++ /dev/null @@ -1,622 +0,0 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import datetime -import functools -import json -import logging -import os -import re -import subprocess -from collections import defaultdict -from dataclasses import fields, is_dataclass -from types import NoneType -from typing import Callable, Dict, Iterable, List, Optional, Set, Tuple, Union, get_args, get_origin, get_type_hints - -from dash import html -from flask import current_app -from joblib import Parallel, delayed -from settings.constants import ( - ANSWER_FIELD, - CUSTOM, - ERROR_MESSAGE_TEMPLATE, - FILE_NAME, - GENERAL_STATS, - IGNORE_FIELDS, - INLINE_STATS, - OUTPUT, - PARAMETERS_FILE_NAME, - PARAMS_TO_REMOVE, - QUESTION_FIELD, - RETRIEVAL_FIELDS, - SEPARATOR_DISPLAY, - SETTING_PARAMS, - STATS_KEYS, - UNDEFINED, -) - -from nemo_skills.inference.generate import GenerateSolutionsConfig, InferenceConfig -from nemo_skills.prompt.utils import FewShotExamplesConfig, PromptConfig, PromptTemplate -from nemo_skills.utils import unroll_files - -custom_stats = {} -general_custom_stats = {} -deleted_stats = set() -excluded_rows = set() -editable_rows = set() -compared_rows = set() -stats_raw = {INLINE_STATS: {CUSTOM: ""}, GENERAL_STATS: {CUSTOM: ""}} - - -def get_editable_rows() -> Set: - return editable_rows - - -def get_excluded_row() -> Set: - return excluded_rows - - -def get_deleted_stats() -> Set: - return deleted_stats - - -def get_custom_stats() -> Dict: - return custom_stats - - -def get_compared_rows() -> Dict: - return compared_rows - - -def get_general_custom_stats() -> Dict: - return general_custom_stats - - -def get_stats_raw() -> Dict: - return stats_raw - - -def parse_model_answer(answer: str) -> List[Dict]: - """ - Parses a model answer and extracts code blocks, explanations, and outputs preserving their sequence. - - Args: - answer (str): The model answer to parse. - - Returns: - List[Dict]: A list of dictionaries containing the parsed results. Each dictionary - contains the following keys: - - 'explanation': The explanation text before the code block. - - 'code': The code block. - - 'output': The output of the code block. 
- - """ - config = current_app.config["nemo_inspector"] - code_start, code_end = map( - re.escape, - config["inspector_params"]["code_separators"], - ) - output_start, output_end = map( - re.escape, - config["inspector_params"]["code_output_separators"], - ) - code_pattern = re.compile(rf"{code_start}(.*?){code_end}", re.DOTALL) - code_output_pattern = re.compile( - rf"{code_start}(.*?){code_end}\s*{output_start}(.*?){output_end}", - re.DOTALL, - ) - code_matches = list(code_pattern.finditer(answer)) - code_output_matches = list(code_output_pattern.finditer(answer)) - parsed_results = [] - last_index = 0 - for code_match in code_matches: - explanation = answer[last_index : code_match.start()].strip() - code_text = code_match.group(1).strip() - output_text = None - if code_output_matches and code_output_matches[0].start() == code_match.start(): - output_match = code_output_matches.pop(0) - output_text = output_match.group(2).strip() - parsed_results.append( - { - "explanation": explanation, - "code": code_text, - "output": output_text, - } - ) - last_index = code_match.end() - if output_text is not None: - last_index = output_match.end() - if last_index < len(answer): - trailing_text = answer[last_index:].strip() - if code_start.replace("\\", "") in trailing_text: - code_start_index = trailing_text.find(code_start.replace("\\", "")) - parsed_results.append( - { - "explanation": trailing_text[0:code_start_index].strip(), - "code": trailing_text[code_start_index + len(code_start.replace("\\", "")) :], - "output": "code_block was not finished", - "wrong_code_block": True, - } - ) - trailing_text = None - if trailing_text: - parsed_results.append({"explanation": trailing_text, "code": None, "output": None}) - return parsed_results - - -def get_height_adjustment() -> html.Iframe: - return html.Iframe( - id="query_params_iframe", - srcDoc=""" - - - - - - - - - """, - style={"visibility": "hidden"}, - ) - - -@functools.lru_cache() -def get_test_data(index: int, dataset: str) -> Tuple[Dict, int]: - if not dataset or dataset == UNDEFINED or os.path.isfile(dataset) is False: - return {QUESTION_FIELD: "", ANSWER_FIELD: ""}, 0 - with open(dataset) as file: - tests = file.readlines() - index = max(min(len(tests), index), 1) - test = json.loads(tests[index - 1]) - return test, index - - -def get_values_from_input_group(children: Iterable) -> Dict: - values = {} - for child in children: - for input_group_child in child["props"]["children"]: - if "id" in input_group_child["props"].keys() and "value" in input_group_child["props"].keys(): - type_function = str - value = input_group_child["props"]["value"] - id = ( - input_group_child["props"]["id"]["id"] - if isinstance(input_group_child["props"]["id"], Dict) - else input_group_child["props"]["id"] - ) - if value is None or value == UNDEFINED: - values[id] = None - continue - if str(value).isdigit() or str(value).replace("-", "", 1).isdigit(): - type_function = int - elif str(value).replace(".", "", 1).replace("-", "", 1).isdigit(): - type_function = float - - values[id] = type_function(str(value).replace('\\n', '\n')) - - return values - - -def extract_query_params(query_params_ids: List[Dict], query_params: List[Dict]) -> Dict: - default_answer = {QUESTION_FIELD: "", "expected_answer": ""} - try: - query_params_extracted = {param_id['id']: param for param_id, param in zip(query_params_ids, query_params)} - except ValueError: - query_params_extracted = default_answer - - return query_params_extracted or default_answer - - -def 
get_utils_from_config_helper(cfg: Dict, display_path: bool = True) -> Dict: - config = {} - for key, value in sorted(cfg.items()): - if key in PARAMS_TO_REMOVE or key in SETTING_PARAMS: - continue - elif isinstance(value, Dict): - config = { - **config, - **{ - (key + SEPARATOR_DISPLAY if display_path and 'template' in inner_key else "") + inner_key: value - for inner_key, value in get_utils_from_config_helper(value).items() - }, - } - elif not isinstance(value, List): - config[key] = value - return config - - -def get_utils_from_config(cfg: Dict, display_path: bool = True) -> Dict: - return { - SEPARATOR_DISPLAY.join(key.split(SEPARATOR_DISPLAY)[1:]) or key: value - for key, value in get_utils_from_config_helper(cfg, display_path).items() - if key not in RETRIEVAL_FIELDS + IGNORE_FIELDS - } - - -def get_stats(all_files_data: List[Dict]) -> Tuple[float, float, float]: - """Returns the percentage of correct, wrong, and no response answers in the given data. - - If not data is provided, returns -1 for all values. - """ - correct = 0 - wrong = 0 - no_response = 0 - for data in all_files_data: - if data.get("predicted_answer") is None: - no_response += 1 - elif data.get("is_correct", False): - correct += 1 - else: - wrong += 1 - - if len(all_files_data): - return ( - correct / len(all_files_data), - wrong / len(all_files_data), - no_response / len(all_files_data), - ) - return -1, -1, -1 - - -def get_metrics(all_files_data: List[Dict], errors_dict: Dict = {}) -> Dict: - correct_responses, wrong_responses, no_response = get_stats(all_files_data) - custom_stats = {} - for name, func in get_custom_stats().items(): - if name not in errors_dict: - errors_dict[name] = {} - custom_stats[name] = catch_eval_exception( - [], - func, - all_files_data, - "Got error when applying function", - errors_dict[name], - ) - - stats = { - 'correct_responses': round(correct_responses, 2), - "wrong_responses": round(wrong_responses, 2), - "no_response": round(no_response, 2), - **custom_stats, - } - return stats - - -def get_eval_function(text): - template = """ -def eval_function(data): -{} - return {} -""" - code_lines = [''] + text.strip().split('\n') - code = template.format( - '\n '.join(code_lines[:-1]), - code_lines[-1:][0], - ) - namespace = {} - exec(code, namespace) - return namespace['eval_function'] - - -def calculate_metrics_for_whole_data(table_data: List, model_id: str) -> Dict: - errors_dict = {} - for question_id in range(len(table_data)): - stats = get_metrics(table_data[question_id][model_id], errors_dict) - table_data[question_id][model_id] = list( - map( - lambda data: {**data, **stats}, - table_data[question_id][model_id], - ) - ) - if len(errors_dict): - for name, error_dict in errors_dict.items(): - logging.error(ERROR_MESSAGE_TEMPLATE.format(name, error_dict)) - - -def catch_eval_exception( - available_models: List[str], - eval_func: Callable[[Dict], bool], - data: Dict, - default_answer: Union[bool, str], - errors_dict: Optional[Dict] = {}, -) -> bool: - try: - if eval_func is None: - return default_answer - return eval_func(data) - except Exception as e: - if str(e).split(" ")[-1].replace("'", "") not in available_models: - if str(e) not in errors_dict: - errors_dict[str(e)] = 0 - errors_dict[str(e)] += 1 - return default_answer - - -def custom_deepcopy(data) -> List: - new_data = [] - for item in data: - new_item = {} - for key, value_list in item.items(): - new_item[key] = value_list - new_data.append(new_item) - return new_data - - -@functools.lru_cache(maxsize=1) -def 
get_data_from_files(cache_indicator=None) -> List: - if cache_indicator is not None: - return [] - base_config = current_app.config['nemo_inspector'] - dataset = None - if os.path.isfile(base_config['input_file']): - with open(base_config['input_file']) as f: - dataset = [json.loads(line) for line in f] - - available_models = { - model_name: model_info["file_paths"] for model_name, model_info in get_available_models().items() - } - - all_models_data_array = [] - - def process_model_files(model_id, results_files, dataset): - model_data = defaultdict(list) - file_names = {} - for file_id, path in enumerate(results_files): - file_name = path.split('/')[-1].split('.')[0] - if file_name in file_names: - file_names[file_name] += 1 - file_name += f"_{file_names[file_name]}" - else: - file_names[file_name] = 1 - with open(path) as f: - answers = map(json.loads, f) - for question_index, answer in enumerate(answers): - result = { - FILE_NAME: file_name, - **(dataset[question_index] if dataset and len(dataset) > question_index else {}), - "question_index": question_index + 1, - "page_index": file_id, - "labels": [], - **answer, - } - model_data[question_index].append(result) - return model_id, model_data - - num_cores = -1 - model_data_list = Parallel(n_jobs=num_cores)( - delayed(process_model_files)(model_id, results_files, dataset) - for model_id, results_files in available_models.items() - ) - - for model_id, model_data in model_data_list: - for question_index, results in model_data.items(): - if len(all_models_data_array) <= question_index: - all_models_data_array.append({}) - all_models_data_array[question_index][model_id] = results - stats = get_metrics(all_models_data_array[question_index][model_id]) - all_models_data_array[question_index][model_id] = list( - map( - lambda data: {**data, **stats}, - all_models_data_array[question_index][model_id], - ) - ) - - return all_models_data_array - - -def get_filtered_files( - filter_function: str, - sorting_function: str, - array_to_filter: List, -) -> List: - filter_lambda_functions = [ - get_eval_function(func.strip()) for func in (filter_function if filter_function else "True").split('&&') - ] - available_models = get_available_models() - filtered_data = [ - list( - filter( - lambda data: catch_eval_exception(available_models, function, data, False), - array_to_filter, - ) - ) - for function in filter_lambda_functions - ] - - filtered_data = list(filter(lambda data: data != [], filtered_data)) - filtered_data = filtered_data[0] if len(filtered_data) > 0 else [{FILE_NAME: ""}] - if sorting_function and filtered_data != [{FILE_NAME: ""}]: - sorting_lambda_function = get_eval_function(sorting_function.strip()) - filtered_data.sort(key=lambda data: catch_eval_exception(available_models, sorting_lambda_function, data, 0)) - - return filtered_data - - -def is_detailed_answers_rows_key(key: str) -> bool: - return ( - key not in get_deleted_stats() - and 'index' not in key - and key not in STATS_KEYS + list(get_metrics([]).keys()) - or key == QUESTION_FIELD - ) - - -@functools.lru_cache(maxsize=1) -def get_available_models(cache_indicator=None) -> Dict: - if cache_indicator is not None: - return {} - try: - with open(PARAMETERS_FILE_NAME) as f: - runs_storage = json.load(f) - except FileNotFoundError: - runs_storage = {} - models = list(runs_storage.keys()) - config = current_app.config["nemo_inspector"]["inspector_params"] - for model_name in models: - runs_storage[model_name]["file_paths"] = list( - unroll_files([os.path.join(config["results_path"], 
model_name, f"{OUTPUT}*.jsonl")]) - ) - for model_name, files in config["model_prediction"].items(): - runs_storage[model_name] = { - "utils": {}, - "examples": {}, - "file_paths": files, - } - - return runs_storage - - -def run_subprocess(command: str) -> Tuple[str, bool]: - result = subprocess.run(command, shell=True, capture_output=True, text=True) - success = True - - delta = datetime.timedelta(minutes=1) - start_time = datetime.datetime.now() - while result.returncode != 0 and datetime.datetime.now() - start_time <= delta: - result = subprocess.run(command, shell=True, capture_output=True, text=True) - - if result.returncode != 0: - logging.info(f"Error while running command: {command}") - logging.info(f"Return code: {result.returncode}") - logging.info(f"Output (stderr): {result.stderr.strip()}") - success = False - - return result.stdout.strip(), result.stderr.strip(), success - - -def get_config( - config_class: Union[GenerateSolutionsConfig, PromptConfig, InferenceConfig, FewShotExamplesConfig], - utils: Dict[str, str], - config: Dict, -) -> Union[GenerateSolutionsConfig, PromptConfig, InferenceConfig, FewShotExamplesConfig]: - return config_class( - **{ - key: value - for key, value in { - **config, - **utils, - }.items() - if key in {field.name for field in fields(config_class)} - }, - ) - - -@functools.lru_cache(maxsize=1) -def get_settings(): - def get_settings_helper(config: Dict): - settings = {} - for key, value in config.items(): - if key in SETTING_PARAMS: - settings[key] = value - if isinstance(value, dict): - settings = {**settings, **get_settings_helper(value)} - return settings - - return get_settings_helper(current_app.config['nemo_inspector']) - - -def get_utils_dict(name: Union[str, Dict], value: Union[str, int], id: Union[str, Dict] = None): - if id is None: - id = name - if name in current_app.config['nemo_inspector']['types'].keys(): - template = { - 'props': { - 'id': id, - 'options': [ - {"label": value, "value": value} for value in current_app.config['nemo_inspector']['types'][name] - ], - 'value': current_app.config['nemo_inspector']['types'][name][0], - }, - 'type': 'Select', - 'namespace': 'dash_bootstrap_components', - } - elif isinstance(value, (int, float)): - float_params = {"step": 0.1} if isinstance(value, float) else {} - template = { - 'props': { - 'id': id, - 'debounce': True, - 'min': 0, - 'type': 'number', - 'value': value, - **float_params, - }, - 'type': 'Input', - 'namespace': 'dash_bootstrap_components', - } - else: - template = { - 'props': { - 'id': id, - 'debounce': True, - 'style': {'width': '100%'}, - 'value': value, - }, - 'type': 'Textarea', - 'namespace': 'dash_bootstrap_components', - } - return { - 'props': { - 'children': [ - { - 'props': {'children': name}, - 'type': 'InputGroupText', - 'namespace': 'dash_bootstrap_components', - }, - template, - ], - 'className': 'mb-3', - }, - 'type': 'InputGroup', - 'namespace': 'dash_bootstrap_components', - } - - -def initialize_default( - cls: Union[PromptTemplate, PromptConfig], specification: Dict = {} -) -> Union[PromptTemplate, PromptConfig]: - if not specification: - specification = {} - - def get_default(field, specification: Dict = None): - if not specification: - specification = {} - _type = get_type_hints(cls)[field.name] - if is_dataclass(_type): - return initialize_default( - _type, - { - **specification, - **( - specification.get(field.name, {}) - if isinstance(specification.get(field.name, {}), Dict) - else {} - ), - }, - ) - if isinstance(specification, Dict) and field.name in 
specification: - return specification[field.name] - else: - args = get_args(_type) - if len(args): - if NoneType in args: - return None - else: - return args[0]() - return (get_origin(_type) or _type)() - - return cls(**{field.name: get_default(field, specification) for field in fields(cls)}) diff --git a/nemo_inspector/utils/decoration.py b/nemo_inspector/utils/decoration.py deleted file mode 100644 index 09c2056ed..000000000 --- a/nemo_inspector/utils/decoration.py +++ /dev/null @@ -1,362 +0,0 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import random -import re -import string -from difflib import SequenceMatcher -from html import escape -from io import StringIO -from typing import Callable, Dict, List, Optional, Tuple, Union - -from ansi2html import Ansi2HTMLConverter -from dash import dcc, html -from pygments.formatters import HtmlFormatter -from pygments.lexers import PythonLexer -from settings.constants import ANSI, COMPARE, LATEX, MARKDOWN - - -def color_text_diff(text1: str, text2: str) -> str: - if text1 == text2: - return [(text1, {})] - matcher = SequenceMatcher(None, text1, text2) - result = [] - for tag, i1, i2, j1, j2 in matcher.get_opcodes(): - if tag == 'equal': - result.append((text2[j1:j2], {})) - elif tag == 'replace': - result.append((text1[i1:i2], {"background-color": "#c8e6c9"})) - result.append((text2[j1:j2], {"background-color": "#ffcdd2", "text-decoration": "line-through"})) - elif tag == 'insert': - result.append((text2[j1:j2], {"background-color": "#ffcdd2", "text-decoration": "line-through"})) - elif tag == 'delete': - result.append((text1[i1:i2], {"background-color": "#c8e6c9"})) - return result - - -def get_starts_with_tag_function(tag: str, default_index_move: int) -> Callable[[str, int], Tuple[bool, int]]: - def starts_with_tag_func_templ(text: str, index: int): - is_starts_with_tag = text.startswith(tag, index) - if not is_starts_with_tag: - returning_index = index + default_index_move - elif '{' not in tag: - returning_index = index + len(tag) - else: - returning_index = text.find('}', index) % (len(text) + 1) - - return is_starts_with_tag, returning_index - - return starts_with_tag_func_templ - - -def proccess_tag( - text: str, - start_index: int, - detect_start_token: Callable[[str, int], Tuple[bool, int]], - detect_end_token: Callable[[str, int], Tuple[bool, int]], - end_sign: Optional[str], - last_block_only: bool = False, -) -> int: - count = 0 - index = start_index - while index < len(text): - if end_sign and text[index] == end_sign: - return start_index, start_index + 1 - is_start_token, new_index = detect_start_token(text, index) - count += is_start_token - if last_block_only and is_start_token: - start_index = index - count = min(1, count) - index = new_index - is_end_token, index = detect_end_token(text, index) - count -= is_end_token - if count == 0: - break - return start_index, index + 1 - - -def get_single_dollar_functions(direction: int, default_index_move: int) -> 
Callable[[str, int], Tuple[bool, int]]: - return lambda text, index: ( - text[index] == '$' and not text[index + direction].isspace(), - index + default_index_move, - ) - - -def get_detection_functions(text, index) -> tuple[ - Callable[[str, int], Tuple[bool, int]], - Callable[[str, int], Tuple[bool, int]], - Optional[str], - bool, - bool, -]: - multiline_tags = [('\\begin{', '\\end{', True), ('$$', '$$', False)] - for start_tag, end_tag, add_dollars in multiline_tags: - if text.startswith(start_tag, index): - return ( - get_starts_with_tag_function(start_tag, 1), - get_starts_with_tag_function(end_tag, 0), - None, - add_dollars, - False, - ) - - starts_with_dollar_func = get_single_dollar_functions(1, 1) - ends_with_dollar_func = get_single_dollar_functions(-1, 0) - if starts_with_dollar_func(text, index)[0]: - return starts_with_dollar_func, ends_with_dollar_func, '\n', False, True - - return None, None, None, None, None - - -def proccess_plain_text(text: str) -> str: - special_chars = r'*_{}[]()#+-.!`' - for character in special_chars: - text = text.replace(character, '\\' + character) - return text - - -def preprocess_latex(text: str, escape: bool = True) -> str: - text = '\n' + text.replace('\\[', '\n$$\n').replace('\\]', '\n$$\n').replace('\\(', ' $').replace('\\)', '$ ') - - right_side_operations = ['-', '=', '+', '*', '/'] - left_side_operations = ['=', '+', '*', '/'] - for op in right_side_operations: - text = text.replace(op + '$', op + ' $') - - for op in left_side_operations: - text = text.replace('$' + op, '$ ' + op) - - text += '\n' - index = 1 - texts = [] - start_plain_text_index = -1 - while index < len(text) - 1: - ( - detect_start_token, - detect_end_token, - end_sign, - add_dollars, - use_last_block_only, - ) = get_detection_functions(text, index) - if detect_start_token is not None: - if start_plain_text_index != -1: - texts.append( - proccess_plain_text(text[start_plain_text_index:index]) - if escape - else text[start_plain_text_index:index] - ) - start_plain_text_index = -1 - - start_index, new_index = proccess_tag( - text, - index, - detect_start_token, - detect_end_token, - end_sign, - use_last_block_only, - ) - texts.append(proccess_plain_text(text[index:start_index]) if escape else text[index:start_index]) - if add_dollars: - texts.append('\n$$\n') - texts.append(text[start_index:new_index].strip()) - texts.append('\n$$\n') - else: - texts.append(text[start_index:new_index]) - index = new_index - elif start_plain_text_index == -1: - start_plain_text_index = index - index += 1 - else: - index += 1 - if start_plain_text_index != -1: - texts.append(proccess_plain_text(text[start_plain_text_index:]) if escape else text[start_plain_text_index:]) - return ''.join(texts).replace('\n', '\n\n').strip() - - -def design_text_output(texts: List[Union[str, str]], style={}, text_modes: List[str] = [LATEX, ANSI]) -> html.Div: - conv = Ansi2HTMLConverter() - ansi_escape = re.compile(r'\x1b\[[0-9;]*m') - full_text = ''.join(map(lambda x: x[0], texts)) - if ANSI in text_modes: - if bool(ansi_escape.search(full_text)) or 'ipython-input' in full_text or 'Traceback' in full_text: - if bool(ansi_escape.search(full_text)): - full_text = conv.convert(full_text, full=False) - else: - full_text = conv.convert(full_text.replace('[', '\u001b['), full=False) - return html.Div( - iframe_template( - '', - f'
{full_text}
', - ), - style=style, - ) - return html.Div( - ( - dcc.Markdown( - preprocess_latex(full_text, escape=MARKDOWN not in text_modes), - mathjax=True, - dangerously_allow_html=True, - ) - if LATEX in text_modes and COMPARE not in text_modes - else ( - dcc.Markdown(full_text) - if MARKDOWN in text_modes and COMPARE not in text_modes - else [html.Span(text, style={**inner_style, "whiteSpace": "pre-wrap"}) for text, inner_style in texts] - ) - ), - style=style, - ) - - -def update_height_js(iframe_id: str) -> str: - return f""" - function updateHeight() {{ - var body = document.body, - html = document.documentElement; - - var height = Math.max(body.scrollHeight, body.offsetHeight, - html.clientHeight, html.scrollHeight, html.offsetHeight); - - parent.postMessage({{ frameHeight: height, frameId: '{iframe_id}' }}, '*'); - }} - window.onload = updateHeight; - window.onresize = updateHeight; - """ - - -def iframe_template(header: str, content: str, style: Dict = {}, iframe_id: str = None) -> html.Iframe: - if not iframe_id: - iframe_id = get_random_id() - - iframe_style = { - "width": "100%", - "border": "none", - "overflow": "hidden", - } - - iframe_style.update(style) - - return html.Iframe( - id=iframe_id, - srcDoc=f""" - - - - {header} - - - {content} - - - """, - style=iframe_style, - ) - - -def get_random_id() -> str: - return ''.join(random.choices(string.ascii_letters + string.digits, k=20)) - - -def highlight_code(codes: List[Tuple[str, Dict[str, str]]], **kwargs) -> html.Iframe: - - full_code = ''.join([code for code, style in codes]) - - # Track positions and styles - positions = [] - current_pos = 0 - for code, style in codes: - start_pos = current_pos - end_pos = current_pos + len(code) - if style: - positions.append((start_pos, end_pos, style)) - current_pos = end_pos - - # Custom formatter to apply styles at correct positions - class CustomHtmlFormatter(HtmlFormatter): - def __init__(self, positions, **options): - super().__init__(**options) - self.positions = positions - self.current_pos = 0 - - def format(self, tokensource, outfile): - style_starts = {start: style for start, _, style in self.positions} - style_ends = {end: style for _, end, style in self.positions} - active_styles = [] - - for ttype, value in tokensource: - token_length = len(value) - token_start = self.current_pos - - # Apply styles character by character - result = '' - for i, char in enumerate(value): - char_pos = token_start + i - - # Check if a style starts or ends here - if char_pos in style_starts: - style = style_starts[char_pos] - active_styles.append(style) - if char_pos in style_ends: - style = style_ends[char_pos] - if style in active_styles: - active_styles.remove(style) - - # Get CSS class for syntax highlighting - css_class = self._get_css_class(ttype) - char_html = escape(char) - if css_class: - char_html = f'{char_html}' - - # Apply active styles - if active_styles: - combined_style = {} - for style_dict in active_styles: - combined_style.update(style_dict) - style_str = '; '.join(f'{k}: {v}' for k, v in combined_style.items()) - char_html = f'{char_html}' - - result += char_html - - outfile.write(result) - self.current_pos += token_length - - # Use the custom formatter to highlight the code - lexer = PythonLexer() - formatter = CustomHtmlFormatter(positions, nowrap=True) - style_defs = formatter.get_style_defs('.highlight') - style_defs += """ -.highlight { - font-family: monospace; -} -""" - - output = StringIO() - formatter.format(lexer.get_tokens(full_code), output) - highlighted_code = 
output.getvalue() - - # Build the iframe content - iframe_id = get_random_id() - content = f""" -
{highlighted_code}
- -""" - - return html.Div( - iframe_template( - header=f"", - content=content, - iframe_id=iframe_id, - style={"border": "black 1px solid", "background-color": "#ebecf0d8"}, - ) - ) diff --git a/nemo_inspector/utils/strategies/base_strategy.py b/nemo_inspector/utils/strategies/base_strategy.py deleted file mode 100644 index eebe90b3c..000000000 --- a/nemo_inspector/utils/strategies/base_strategy.py +++ /dev/null @@ -1,281 +0,0 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import inspect -import logging -from typing import Callable, Dict, List, Union - -import dash_bootstrap_components as dbc -import requests -from dash import dcc, html -from flask import current_app -from layouts import ( - get_input_group_layout, - get_results_content_layout, - get_single_prompt_output_layout, - get_switch_layout, - get_text_area_layout, - get_text_modes_layout, -) -from settings.constants import ( - FEW_SHOTS_INPUT, - QUERY_INPUT_TYPE, - RETRIEVAL, - RETRIEVAL_FIELDS, - SEPARATOR_DISPLAY, - SEPARATOR_ID, -) -from utils.common import get_config, get_settings, get_utils_from_config, initialize_default - -from nemo_skills.code_execution.math_grader import extract_answer -from nemo_skills.code_execution.sandbox import get_sandbox -from nemo_skills.inference.server.code_execution_model import get_code_execution_model -from nemo_skills.inference.server.model import get_model -from nemo_skills.prompt.few_shot_examples import examples_map -from nemo_skills.prompt.utils import Prompt, PromptConfig - - -class ModeStrategies: - def __init__(self): - self.sandbox = None - - def sandbox_init(self): - if self.sandbox is None and 'sandbox' in current_app.config['nemo_inspector']: - self.sandbox = get_sandbox( - **current_app.config['nemo_inspector']['sandbox'], - ) - - def get_utils_input_layout( - self, - condition: Callable[[str, Union[str, int, float, bool]], bool] = lambda key, value: True, - disabled: bool = False, - ) -> List[dbc.AccordionItem]: - utils = get_utils_from_config(current_app.config['nemo_inspector']).items() - input_group_layout = html.Div( - ( - [ - get_input_group_layout( - name, - value, - ) - for name, value in sorted( - utils, - key=lambda item: ( - 1 - if item[0].split(SEPARATOR_DISPLAY)[-1] in current_app.config['nemo_inspector']['types'] - else 0 if not isinstance(item[1], str) else 2 - ), - ) - if condition(name, value) - ] - ), - id="utils_group", - ) - utils_group_layout = [ - dbc.AccordionItem( - html.Div( - [ - input_group_layout, - ] - ), - title="Utils", - ) - ] - return utils_group_layout - - def get_few_shots_input_layout(self) -> List[dbc.AccordionItem]: - config = current_app.config['nemo_inspector'] - size = len(examples_map.get(config["examples_type"], [])) - return [ - dbc.AccordionItem( - self.get_few_shots_div_layout(size), - title="Few shots", - id="few_shots_group", - ) - ] - - def get_query_input_layout( - self, query_data: Dict[str, str], is_prompt_search: bool = True - ) -> List[dbc.AccordionItem]: - 
switch_layout = [ - get_text_modes_layout( - QUERY_INPUT_TYPE, - False, - ) - ] - search_layout = [self._get_search_prompt_layout()] if is_prompt_search else [] - query_input = [ - html.Div( - self.get_query_input_children_layout(query_data), - id="query_input_children", - ) - ] - query_store = [dcc.Store(id={"type": "query_store", "id": 1}, data=query_data)] - return [ - dbc.AccordionItem( - html.Div( - switch_layout + search_layout + query_input + query_store, - ), - title="Input", - id="query_input_content", - ) - ] - - def get_query_input_children_layout( - self, query_data: Dict[str, str], text_modes: List[str] = [] - ) -> List[dbc.InputGroup]: - return [ - dbc.InputGroup( - [ - dbc.InputGroupText(key), - get_text_area_layout( - id={ - "type": QUERY_INPUT_TYPE, - "id": key, - }, - value=str(value), - text_modes=text_modes, - editable=True, - ), - ], - className="mb-3", - ) - for key, value in query_data.items() - ] - - def get_few_shots_div_layout(self, size: int) -> html.Div: - return html.Div( - [ - html.Div( - [ - dbc.Pagination( - id="few_shots_pagination", - max_value=size, - active_page=1, - ), - get_text_modes_layout(FEW_SHOTS_INPUT, True), - ] - ), - dbc.Container(id="few_shots_pagination_content"), - ], - id="few_shots_div", - ) - - def run(self, utils: Dict, params: Dict) -> html.Div: - utils = {key.split(SEPARATOR_ID)[-1]: value for key, value in utils.items()} - if utils['code_execution'] and str(utils['code_execution']) == 'True': - self.sandbox_init() - llm = get_code_execution_model( - **current_app.config['nemo_inspector']['server'], - sandbox=self.sandbox, - ) - else: - llm = get_model(**current_app.config['nemo_inspector']['server']) - - generate_params = { - key: value for key, value in utils.items() if key in inspect.signature(llm.generate).parameters - } - logging.info(f"query to process: {params['prompts'][0]}") - - try: - outputs = llm.generate( - prompts=params['prompts'], - stop_phrases=current_app.config['nemo_inspector']['prompt']['stop_phrases'], - **generate_params, - ) - except requests.exceptions.ConnectionError as e: - return self._get_connection_error_message() - except Exception as e: - logging.error(f"error during run prompt: {e}") - logging.error(f"error type: {type(e)}") - return html.Div(f"Got error\n{e}") - - logging.info(f"query's answer: {outputs[0]}") - - try: - predicted_answer = extract_answer(outputs[0]['generation']) - color, background, is_correct = ( - ('#d4edda', '#d4edda', "correct") - if self.sandbox.is_output_correct(predicted_answer, params["expected_answer"]) - else ("#fecccb", "#fecccb", "incorrect") - ) - except Exception as e: - color, background, is_correct = 'black', 'white', "unknown" - return html.Div( - [ - get_results_content_layout( - outputs[0]['generation'], - get_single_prompt_output_layout( - outputs[0]['generation'], - ), - style={"border": f"2px solid {color}"}, - is_formatted=True, - ), - html.Div( - ( - f"Answer {predicted_answer} is {is_correct}" - if is_correct != "unknown" - else "Could not evaluate the answer" - ), - style={"background-color": background}, - ), - ] - ) - - def get_prompt(self, utils: Dict, input_dict: Dict[str, str]) -> str: - utils = { - key.split(SEPARATOR_ID)[-1]: value - for key, value in utils.items() - if key != RETRIEVAL and key not in RETRIEVAL_FIELDS - } - prompt_config = initialize_default(PromptConfig, {**utils}) - prompt = Prompt(config=prompt_config) - return prompt.fill(input_dict) - - def _get_search_prompt_layout(self) -> dbc.InputGroup: - return dbc.InputGroup( - [ - 
dbc.InputGroupText("Index of test"), - dbc.Input( - value=1, - id="query_search_input", - type="number", - size="sm", - ), - dbc.Button( - "Search", - id="query_search_button", - outline=True, - size="sm", - color="primary", - className="me-1", - ), - ], - className="mb-3", - ) - - def _get_connection_error_message(self): - return html.Div( - html.P( - [ - "Could not connect to the server. Please check that the server is running (look at ", - html.A( - "inference.md", - href="https://github.com/NVIDIA/NeMo-Skills/blob/main/docs/inference.md", - ), - " for more information). ", - "Also check that you have provided correct host, ssh_key_path and ssh_server parameters", - ] - ) - ) diff --git a/nemo_inspector/utils/strategies/chat_mode.py b/nemo_inspector/utils/strategies/chat_mode.py deleted file mode 100644 index 2bfbb5e74..000000000 --- a/nemo_inspector/utils/strategies/chat_mode.py +++ /dev/null @@ -1,60 +0,0 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Dict, List - -import dash_bootstrap_components as dbc -from flask import current_app -from settings.constants import ANSWER_FIELD, ONE_SAMPLE_MODE, QUESTION_FIELD, SEPARATOR_ID -from utils.strategies.base_strategy import ModeStrategies - - -class ChatModeStrategy(ModeStrategies): - mode = ONE_SAMPLE_MODE - - def __init__(self): - super().__init__() - - def get_utils_input_layout(self) -> List[dbc.AccordionItem]: - config = current_app.config['nemo_inspector'] - return super().get_utils_input_layout( - lambda key, value: key in config['inference'].keys(), - True, - ) - - def get_few_shots_input_layout(self) -> List[dbc.AccordionItem]: - return [] - - def get_query_input_layout(self, dataset) -> List[dbc.AccordionItem]: - return super().get_query_input_layout( - { - QUESTION_FIELD: "", - ANSWER_FIELD: "", - }, - False, - ) - - def run(self, utils: Dict, params: Dict): - utils = {key.split(SEPARATOR_ID)[-1]: value for key, value in utils.items()} - params['prompts'] = [self.get_prompt(utils, params)] - return super().run(utils, params) - - def get_prompt(self, utils: Dict, params: Dict) -> str: - utils = {key.split(SEPARATOR_ID)[-1]: value for key, value in utils.items()} - utils['user'] = '{question}' - utils['prompt_template'] = '{user}\n{generation}' - return super().get_prompt( - utils, - params, - ) diff --git a/nemo_inspector/utils/strategies/one_sample_mode.py b/nemo_inspector/utils/strategies/one_sample_mode.py deleted file mode 100644 index ce274122c..000000000 --- a/nemo_inspector/utils/strategies/one_sample_mode.py +++ /dev/null @@ -1,39 +0,0 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Dict, List - -import dash_bootstrap_components as dbc -from dash import html -from settings.constants import ONE_SAMPLE_MODE, SEPARATOR_ID -from utils.common import get_test_data -from utils.strategies.base_strategy import ModeStrategies - - -class OneTestModeStrategy(ModeStrategies): - mode = ONE_SAMPLE_MODE - - def __init__(self): - super().__init__() - - def get_utils_input_layout(self) -> List[dbc.AccordionItem]: - return super().get_utils_input_layout(disabled=True) - - def get_query_input_layout(self, dataset: str) -> List[dbc.AccordionItem]: - return super().get_query_input_layout(get_test_data(0, dataset)[0]) - - def run(self, utils: Dict, params: Dict) -> html.Div: - utils = {key.split(SEPARATOR_ID)[-1]: value for key, value in utils.items()} - params['prompts'] = [self.get_prompt(utils, params)] - return super().run(utils, params) diff --git a/nemo_inspector/utils/strategies/strategy_maker.py b/nemo_inspector/utils/strategies/strategy_maker.py deleted file mode 100644 index 459f40e58..000000000 --- a/nemo_inspector/utils/strategies/strategy_maker.py +++ /dev/null @@ -1,33 +0,0 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Optional - -from settings.constants import CHAT_MODE, ONE_SAMPLE_MODE -from utils.strategies.base_strategy import ModeStrategies -from utils.strategies.chat_mode import ChatModeStrategy -from utils.strategies.one_sample_mode import OneTestModeStrategy - - -class RunPromptStrategyMaker: - strategies = { - ONE_SAMPLE_MODE: OneTestModeStrategy, - CHAT_MODE: ChatModeStrategy, - } - - def __init__(self, mode: Optional[str] = None): - self.mode = mode - - def get_strategy(self) -> ModeStrategies: - return self.strategies.get(self.mode, ModeStrategies)() diff --git a/requirements/inspector-tests.txt b/requirements/inspector-tests.txt deleted file mode 100644 index 131d94982..000000000 --- a/requirements/inspector-tests.txt +++ /dev/null @@ -1,17 +0,0 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - -dash[testing] -pytest-rerunfailures -webdriver-manager==4.0.2 diff --git a/requirements/inspector.txt b/requirements/inspector.txt deleted file mode 100644 index 7b5ac4b2a..000000000 --- a/requirements/inspector.txt +++ /dev/null @@ -1,23 +0,0 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -ansi2html -dash -dash-ace -dash_bootstrap_components -hydra-core -joblib -pandas -pygments -sshtunnel_requests