diff --git a/metagpt/ext/sela/README.md b/metagpt/ext/sela/README.md index 6fb47b42c..e74183e7c 100644 --- a/metagpt/ext/sela/README.md +++ b/metagpt/ext/sela/README.md @@ -19,11 +19,6 @@ You can either download the datasets from the link or prepare the datasets from ## 2. Configurations -### Data Config - -- **`datasets.yaml`:** Provide base prompts, metrics, and target columns for respective datasets. -- **`data.yaml`:** Modify `datasets_dir` to the base directory of all prepared datasets. - ### LLM Config ```yaml @@ -34,13 +29,7 @@ llm: api_key: sk-xxx temperature: 0.5 ``` - - -## 3. SELA - -### Run SELA - -#### Setup +### Setup ```bash pip install -e . @@ -50,13 +39,43 @@ cd metagpt/ext/sela pip install -r requirements.txt ``` -#### Running Experiments +## 3. Quick Start + +### Example : Running SELA on the House Price Prediction Task + +- **To run the project, simply execute the following command** + ```bash + python run_sela.py + ``` -- **Examples:** +- **Explanation of `run_sela.py`** ```bash + requirement = (''' + Optimize dataset using MCTS with 10 rollouts. + This is a 05_house-prices-advanced-regression-techniques dataset. + Your goal is to predict the target column `SalePrice`. + Perform data analysis, data preprocessing, feature engineering, and modeling to predict the target. + Report rmse on the eval data. Do not plot or make any visualizations.''') + data_dir = "Path/to/dataset" + + sela = SELA() + await sela.run(requirement, data_dir) + ``` + +## 4. SELA Reproduction Details + +### Data Config +- **`datasets.yaml`:** Provide base prompts, metrics, and target columns for respective datasets. +- **`data.yaml`:** Modify `datasets_dir` to the base directory of all prepared datasets. + +### Run SELA + +#### Examples + +```bash python run_experiment.py --exp_mode mcts --task titanic --rollouts 10 python run_experiment.py --exp_mode mcts --task house-prices --rollouts 10 --low_is_better - ``` + ``` #### Parameters @@ -78,7 +97,7 @@ pip install -r requirements.txt ### Ablation Study -**RandomSearch** +#### RandomSearch - **Use a single insight:** ```bash @@ -90,7 +109,7 @@ pip install -r requirements.txt python run_experiment.py --exp_mode rs --task titanic --rs_mode set ``` -## 4. Citation +## 5. Citation Please cite our paper if you use SELA or find it cool or useful! ```bibtex diff --git a/metagpt/ext/sela/data/dataset.py b/metagpt/ext/sela/data/dataset.py index ef4179011..d80b78d1a 100644 --- a/metagpt/ext/sela/data/dataset.py +++ b/metagpt/ext/sela/data/dataset.py @@ -113,7 +113,11 @@ def get_split_dataset_path(dataset_name, config): datasets_dir = config["datasets_dir"] if dataset_name in config["datasets"]: dataset = config["datasets"][dataset_name] - data_path = os.path.join(datasets_dir, dataset["dataset"]) + # Check whether `dataset["dataset"]` is already the suffix of `datasets_dir`. If it isn't, perform path concatenation. 
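+        # e.g. datasets_dir "path/to/04_titanic" already ends with the dataset folder name and is used as-is,
+        # while datasets_dir "path/to/datasets" gets the dataset folder (e.g. "04_titanic") joined onto it.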
+ if datasets_dir.rpartition("/")[-1] == dataset["dataset"]: + data_path = datasets_dir + else: + data_path = Path(datasets_dir) / dataset["dataset"] split_datasets = { "train": os.path.join(data_path, "split_train.csv"), "dev": os.path.join(data_path, "split_dev.csv"), diff --git a/metagpt/ext/sela/insights/instruction_generator.py b/metagpt/ext/sela/insights/instruction_generator.py index d5d24c74d..5600efe8d 100644 --- a/metagpt/ext/sela/insights/instruction_generator.py +++ b/metagpt/ext/sela/insights/instruction_generator.py @@ -34,9 +34,8 @@ class InstructionGenerator: - data_config = DATA_CONFIG - - def __init__(self, state, use_fixed_insights, from_scratch): + def __init__(self, state, use_fixed_insights, from_scratch, data_config=None): + self.data_config = data_config if data_config is not None else DATA_CONFIG self.state = state self.file_path = state["exp_pool_path"] if state["custom_dataset_dir"]: @@ -44,8 +43,11 @@ def __init__(self, state, use_fixed_insights, from_scratch): self.dataset_info = file.read() else: dataset_info_path = ( - f"{self.data_config['datasets_dir']}/{state['dataset_config']['dataset']}/dataset_info.json" + f"{self.data_config['datasets_dir']}/dataset_info.json" + if self.data_config["datasets_dir"].rpartition("/")[-1] == state["dataset_config"]["dataset"] + else f"{self.data_config['datasets_dir']}/{state['dataset_config']['dataset']}/dataset_info.json" ) + with open(dataset_info_path, "r") as file: self.dataset_info = json.load(file) self.use_fixed_insights = use_fixed_insights diff --git a/metagpt/ext/sela/run_sela.py b/metagpt/ext/sela/run_sela.py new file mode 100644 index 000000000..404336259 --- /dev/null +++ b/metagpt/ext/sela/run_sela.py @@ -0,0 +1,24 @@ +import fire +from runner.sela import SELA + +requirement = """ +Implement MCTS with a rollout count of 10 to improve my dataset. Focus on forecasting the RS column. +Carry out data analysis, data preprocessing, feature engineering, and modeling for the forecast. +Report the rmse on the evaluation dataset, omitting any visual or graphical outputs. +""" + + +async def main(): + """ + The main function serves as an entry point and supports direct running. 
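+    Run it with `python run_sela.py`; replace the placeholder `data_dir` below with the path to your prepared dataset.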
+ """ + # Example requirement and data path + data_dir = "Path/to/dataset" + + # Initialize Sela and run + sela = SELA() + await sela.run(requirement, data_dir) + + +if __name__ == "__main__": + fire.Fire(main) diff --git a/metagpt/ext/sela/runner/mcts.py b/metagpt/ext/sela/runner/mcts.py index 8b6c14100..90ef1c57b 100644 --- a/metagpt/ext/sela/runner/mcts.py +++ b/metagpt/ext/sela/runner/mcts.py @@ -12,7 +12,7 @@ class MCTSRunner(Runner): result_path: str = "results/mcts" - def __init__(self, args, tree_mode=None, **kwargs): + def __init__(self, args, data_config=None, tree_mode=None, **kwargs): if args.special_instruction == "image": self.start_task_id = 1 # start from datapreprocessing if it is image task else: @@ -23,7 +23,7 @@ def __init__(self, args, tree_mode=None, **kwargs): elif args.eval_func == "mlebench": self.eval_func = node_evaluate_score_mlebench - super().__init__(args, **kwargs) + super().__init__(args, data_config=data_config, **kwargs) self.tree_mode = tree_mode async def run_experiment(self): @@ -35,7 +35,7 @@ async def run_experiment(self): mcts = Random(root_node=None, max_depth=depth, use_fixed_insights=use_fixed_insights) else: mcts = MCTS(root_node=None, max_depth=depth, use_fixed_insights=use_fixed_insights) - best_nodes = await mcts.search(state=self.state, args=self.args) + best_nodes = await mcts.search(state=self.state, args=self.args, data_config=self.data_config) best_node = best_nodes["global_best"] dev_best_node = best_nodes["dev_best"] score_dict = best_nodes["scores"] diff --git a/metagpt/ext/sela/runner/runner.py b/metagpt/ext/sela/runner/runner.py index 4b5504e09..26d6d7620 100644 --- a/metagpt/ext/sela/runner/runner.py +++ b/metagpt/ext/sela/runner/runner.py @@ -13,13 +13,13 @@ class Runner: result_path: str = "results/base" - data_config = DATA_CONFIG start_task_id = 1 - def __init__(self, args, **kwargs): + def __init__(self, args, data_config=None, **kwargs): self.args = args self.start_time_raw = datetime.datetime.now() self.start_time = self.start_time_raw.strftime("%Y%m%d%H%M") + self.data_config = data_config if data_config is not None else DATA_CONFIG self.state = create_initial_state( self.args.task, start_task_id=self.start_task_id, diff --git a/metagpt/ext/sela/runner/sela.py b/metagpt/ext/sela/runner/sela.py new file mode 100644 index 000000000..0a02f2fa3 --- /dev/null +++ b/metagpt/ext/sela/runner/sela.py @@ -0,0 +1,160 @@ +import argparse +import json +import os +from typing import Optional + +from metagpt.ext.sela.runner.custom import CustomRunner +from metagpt.ext.sela.runner.mcts import MCTSRunner +from metagpt.ext.sela.runner.random_search import RandomSearchRunner +from metagpt.ext.sela.runner.runner import Runner +from metagpt.llm import LLM +from metagpt.utils.common import CodeParser + +SELA_INSTRUCTION = """ +You are an assistant for configuring machine learning experiments. + +Given the requirement and data directory: +{requirement} +{data_dir} + +Your task: +1. Extract **experiment configurations** from the requirement if explicitly mentioned, such as: + - "rollouts: 10" + - "exp_mode: mcts" + - "max_depth: 4" + +2. 
Extract **experiment data information**, including: + - **dataset**: Dataset name (if explicitly mentioned in the requirement, use that; otherwise, use the last folder name in the data directory path) + - **metric**: Evaluation metric + - **target_col**: Target column + - **user_requirement**: Specific instructions or dataset handling requirements + +Output a JSON object with two parts: +- "config": A dictionary of explicitly mentioned configurations, using keys: + - "task": str (a noun based on the dataset name, customizable, e.g., "titanic") + - "exp_mode": str (e.g., "mcts", "rs", "base", "custom", "greedy", "autogluon") + - "rollouts": int + - "max_depth": int + - "rs_mode": str (e.g., "single", "set") + - "special_instruction": str (e.g., "text", "image") +- "data_info": A dictionary of experiment data information, with keys: + - "dataset": str (e.g., "04_titanic") + - "metric": str (e.g., "f1", "rmse") + - "target_col": str (e.g., "Survived") + - "user_requirement": str + +Example output: +```json +{{ + "config": {{ + "task": "titanic", + "exp_mode": "mcts", + "rollouts": 10 + }}, + "data_info": {{ + "dataset": "04_titanic", + "metric": "f1", + "target_col": "Survived", + "user_requirement": "Predict the target column `Survived`. Perform data analysis, preprocessing, feature engineering, and modeling. Report f1 on eval data. Do not include visualizations." + }} +}} +``` + +Return only the JSON object. +""" +DEFAULT_CONFIG = { + "name": "", + "reflection": True, + "no_reflection": False, + "exp_mode": "mcts", + "rollouts": 10, + "load_tree": False, + "role_timeout": 1000, + "use_fixed_insights": False, + "low_is_better": False, + "start_task_id": 2, + "from_scratch": True, + "eval_func": "sela", + "custom_dataset_dir": None, + "max_depth": 4, + "rs_mode": "single", + "is_multimodal": True, + "num_experiments": 1, + "external_eval": True, + "no_external_eval": False, + "special_instruction": None, +} + + +class SELA: + def __init__(self, use_llm: bool = True): + """ + Initialize the SELA class. + Args: + use_llm: Whether to use LLM (Language Model) to parse the requirement. + """ + self.llm = LLM() if use_llm else None + + async def _parse_requirement(self, requirement: str, data_dir: str) -> dict: + """ + Use LLM to analyze the experiment requirement and extract configurations. + """ + if not self.llm: + raise ValueError("LLM is not initialized. Cannot parse the requirement.") + response = await self.llm.aask( + SELA_INSTRUCTION.format(requirement=json.dumps(requirement), data_dir=json.dumps(data_dir)) + ) + print(f"LLM Response: {response}") + parsed_response = self._parse_json(response) + return { + "config": {**DEFAULT_CONFIG, **parsed_response.get("config", {})}, + "data_info": parsed_response.get("data_info", {}), + } + + @staticmethod + def _parse_json(json_string: str) -> dict: + """ + Extract and parse JSON content from the given string using CodeParser. + """ + try: + json_code = CodeParser.parse_code("", json_string, "json") + import json + + return json.loads(json_code) + except ValueError: + raise ValueError(f"Invalid JSON format: {json_string}") + + def _select_runner(self, config: argparse.Namespace, data_config: dict): + """ + Select the appropriate experiment runner based on the experiment mode. 
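+        Supported exp_mode values: "mcts", "greedy", "random", "rs", "base", and "custom".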
+ """ + runners = { + "mcts": lambda: MCTSRunner(config, data_config), + "greedy": lambda: MCTSRunner(tree_mode="greedy"), + "random": lambda: MCTSRunner(tree_mode="random"), + "rs": lambda: RandomSearchRunner(config), + "base": lambda: Runner(config), + "custom": lambda: CustomRunner(config), + } + if config.exp_mode not in runners: + raise ValueError(f"Invalid exp_mode: {config.exp_mode}") + return runners[config.exp_mode]() + + async def run(self, requirement: str, data_dir: Optional[str] = None): + """ + Run the experiment with the given requirement and data directory. + """ + if not os.path.exists(data_dir): + raise FileNotFoundError(f"Dataset directory not found: {data_dir}") + + config_all = await self._parse_requirement(requirement, data_dir) + config_exp, data_info = config_all["config"], config_all["data_info"] + + data_config = { + "datasets_dir": data_dir, + "work_dir": "../../workspace", + "role_dir": "storage/SELA", + "datasets": {config_exp.get("task"): data_info}, + } + + await self._select_runner(argparse.Namespace(**config_exp), data_config).run_experiment() diff --git a/metagpt/ext/sela/search/tree_search.py b/metagpt/ext/sela/search/tree_search.py index eac26c86c..d1d4283b3 100644 --- a/metagpt/ext/sela/search/tree_search.py +++ b/metagpt/ext/sela/search/tree_search.py @@ -410,7 +410,7 @@ def get_score_order_dict(self): scores["test_raw"].append(node.raw_reward["test_score"]) return scores - async def search(self, state: dict, args): + async def search(self, state: dict, args, data_config): reflection = args.reflection load_tree = args.load_tree rollouts = args.rollouts @@ -418,7 +418,7 @@ async def search(self, state: dict, args): role, root = initialize_di_root_node(state, reflection=reflection) self.root_node = root self.instruction_generator = InstructionGenerator( - state=state, use_fixed_insights=self.use_fixed_insights, from_scratch=from_scratch + state=state, use_fixed_insights=self.use_fixed_insights, from_scratch=from_scratch, data_config=data_config ) await self.instruction_generator.initialize() diff --git a/metagpt/rag/benchmark/__init__.py b/metagpt/rag/benchmark/__init__.py deleted file mode 100644 index 7f143b9f2..000000000 --- a/metagpt/rag/benchmark/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from metagpt.rag.benchmark.base import RAGBenchmark - -__all__ = ["RAGBenchmark"] diff --git a/metagpt/rag/benchmark/base.py b/metagpt/rag/benchmark/base.py deleted file mode 100644 index b5d265b35..000000000 --- a/metagpt/rag/benchmark/base.py +++ /dev/null @@ -1,201 +0,0 @@ -import asyncio -from typing import List, Tuple, Union - -import evaluate -import jieba -from llama_index.core.embeddings import BaseEmbedding -from llama_index.core.evaluation import SemanticSimilarityEvaluator -from llama_index.core.schema import NodeWithScore -from pydantic import BaseModel - -from metagpt.const import EXAMPLE_BENCHMARK_PATH -from metagpt.logs import logger -from metagpt.rag.factories import get_rag_embedding -from metagpt.utils.common import read_json_file - - -class DatasetInfo(BaseModel): - name: str - document_files: List[str] - gt_info: List[dict] - - -class DatasetConfig(BaseModel): - datasets: List[DatasetInfo] - - -class RAGBenchmark: - def __init__( - self, - embed_model: BaseEmbedding = None, - ): - self.evaluator = SemanticSimilarityEvaluator( - embed_model=embed_model or get_rag_embedding(), - ) - - def set_metrics( - self, - bleu_avg: float = 0.0, - bleu_1: float = 0.0, - bleu_2: float = 0.0, - bleu_3: float = 0.0, - bleu_4: float = 0.0, - rouge_l: float = 0.0, 
- semantic_similarity: float = 0.0, - recall: float = 0.0, - hit_rate: float = 0.0, - mrr: float = 0.0, - length: float = 0.0, - generated_text: str = None, - ground_truth_text: str = None, - question: str = None, - ): - metrics = { - "bleu-avg": bleu_avg, - "bleu-1": bleu_1, - "bleu-2": bleu_2, - "bleu-3": bleu_3, - "bleu-4": bleu_4, - "rouge-L": rouge_l, - "semantic similarity": semantic_similarity, - "recall": recall, - "hit_rate": hit_rate, - "mrr": mrr, - "length": length, - } - - log = { - "generated_text": generated_text, - "ground_truth_text": ground_truth_text, - "question": question, - } - - return {"metrics": metrics, "log": log} - - def bleu_score(self, response: str, reference: str, with_penalty=False) -> Union[float, Tuple[float]]: - f = lambda text: list(jieba.cut(text)) - bleu = evaluate.load(path="bleu") - results = bleu.compute(predictions=[response], references=[[reference]], tokenizer=f) - - bleu_avg = results["bleu"] - bleu1 = results["precisions"][0] - bleu2 = results["precisions"][1] - bleu3 = results["precisions"][2] - bleu4 = results["precisions"][3] - brevity_penalty = results["brevity_penalty"] - - if with_penalty: - return bleu_avg, bleu1, bleu2, bleu3, bleu4 - else: - return 0.0 if brevity_penalty == 0 else bleu_avg / brevity_penalty, bleu1, bleu2, bleu3, bleu4 - - def rougel_score(self, response: str, reference: str) -> float: - # pip install rouge_score - f = lambda text: list(jieba.cut(text)) - rouge = evaluate.load(path="rouge") - - results = rouge.compute(predictions=[response], references=[[reference]], tokenizer=f, rouge_types=["rougeL"]) - score = results["rougeL"] - return score - - def recall(self, nodes: list[NodeWithScore], reference_docs: list[str]) -> float: - if nodes: - total_recall = sum(any(node.text in doc for node in nodes) for doc in reference_docs) - return total_recall / len(reference_docs) - else: - return 0.0 - - def hit_rate(self, nodes: list[NodeWithScore], reference_docs: list[str]) -> float: - if nodes: - return 1.0 if any(node.text in doc for doc in reference_docs for node in nodes) else 0.0 - else: - return 0.0 - - def mean_reciprocal_rank(self, nodes: list[NodeWithScore], reference_docs: list[str]) -> float: - mrr_sum = 0.0 - - for i, node in enumerate(nodes, start=1): - for doc in reference_docs: - if text in doc: - mrr_sum += 1.0 / i - return mrr_sum - - return mrr_sum - - async def semantic_similarity(self, response: str, reference: str) -> float: - result = await self.evaluator.aevaluate( - response=response, - reference=reference, - ) - - return result.score - - async def compute_metric( - self, - response: str = None, - reference: str = None, - nodes: list[NodeWithScore] = None, - reference_doc: list[str] = None, - question: str = None, - ): - recall = self.recall(nodes, reference_doc) - bleu_avg, bleu1, bleu2, bleu3, bleu4 = self.bleu_score(response, reference) - rouge_l = self.rougel_score(response, reference) - hit_rate = self.hit_rate(nodes, reference_doc) - mrr = self.mean_reciprocal_rank(nodes, reference_doc) - - similarity = await self.semantic_similarity(response, reference) - - result = self.set_metrics( - bleu_avg, - bleu1, - bleu2, - bleu3, - bleu4, - rouge_l, - similarity, - recall, - hit_rate, - mrr, - len(response), - response, - reference, - question, - ) - - return result - - @staticmethod - def load_dataset(ds_names: list[str] = ["all"]): - infos = read_json_file((EXAMPLE_BENCHMARK_PATH / "dataset_info.json").as_posix()) - dataset_config = DatasetConfig( - datasets=[ - DatasetInfo( - name=name, - 
document_files=[ - (EXAMPLE_BENCHMARK_PATH / name / file).as_posix() for file in info["document_file"] - ], - gt_info=read_json_file((EXAMPLE_BENCHMARK_PATH / name / info["gt_file"]).as_posix()), - ) - for dataset_info in infos - for name, info in dataset_info.items() - if name in ds_names or "all" in ds_names - ] - ) - - return dataset_config - - -if __name__ == "__main__": - benchmark = RAGBenchmark() - answer = "是的,根据提供的信息,2023年7月20日,应急管理部和财政部确实联合发布了《因灾倒塌、损坏住房恢复重建救助工作规范》的通知。这份《规范》旨在进一步规范因灾倒塌、损坏住房的恢复重建救助相关工作。它明确了地方各级政府负责实施救助工作,应急管理部和财政部则负责统筹指导。地方财政应安排足够的资金,中央财政也会提供适当的补助。救助资金将通过专账管理,并采取特定的管理方式。救助对象是那些因自然灾害导致住房倒塌或损坏,并向政府提出申请且符合条件的受灾家庭。相关部门将组织调查统计救助对象信息,并建立档案。此外,《规范》还强调了资金发放的具体方式和公开透明的要求。" - ground_truth = "“启明行动”是为了防控儿童青少年的近视问题,并发布了《防控儿童青少年近视核心知识十条》。" - bleu_avg, bleu1, bleu2, bleu3, bleu4 = benchmark.bleu_score(answer, ground_truth) - rougeL_score = benchmark.rougel_score(answer, ground_truth) - similarity = asyncio.run(benchmark.SemanticSimilarity(answer, ground_truth)) - - logger.info( - f"BLEU Scores: bleu_avg = {bleu_avg}, bleu1 = {bleu1}, bleu2 = {bleu2}, bleu3 = {bleu3}, bleu4 = {bleu4}, " - f"RougeL Score: {rougeL_score}, " - f"Semantic Similarity: {similarity}" - ) diff --git a/metagpt/rag/parsers/__init__.py b/metagpt/rag/parsers/__init__.py deleted file mode 100644 index 03ac0de3a..000000000 --- a/metagpt/rag/parsers/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from metagpt.rag.parsers.omniparse import OmniParse - -__all__ = ["OmniParse"] diff --git a/metagpt/rag/parsers/omniparse.py b/metagpt/rag/parsers/omniparse.py deleted file mode 100644 index ec08e38f1..000000000 --- a/metagpt/rag/parsers/omniparse.py +++ /dev/null @@ -1,139 +0,0 @@ -import asyncio -from fileinput import FileInput -from pathlib import Path -from typing import List, Optional, Union - -from llama_index.core import Document -from llama_index.core.async_utils import run_jobs -from llama_index.core.readers.base import BaseReader - -from metagpt.logs import logger -from metagpt.rag.schema import OmniParseOptions, OmniParseType, ParseResultType -from metagpt.utils.async_helper import NestAsyncio -from metagpt.utils.omniparse_client import OmniParseClient - - -class OmniParse(BaseReader): - """OmniParse""" - - def __init__( - self, api_key: str = None, base_url: str = "http://localhost:8000", parse_options: OmniParseOptions = None - ): - """ - Args: - api_key: Default None, can be used for authentication later. - base_url: OmniParse Base URL for the API. - parse_options: Optional settings for OmniParse. Default is OmniParseOptions with default values. - """ - self.parse_options = parse_options or OmniParseOptions() - self.omniparse_client = OmniParseClient(api_key, base_url, max_timeout=self.parse_options.max_timeout) - - @property - def parse_type(self): - return self.parse_options.parse_type - - @property - def result_type(self): - return self.parse_options.result_type - - @parse_type.setter - def parse_type(self, parse_type: Union[str, OmniParseType]): - if isinstance(parse_type, str): - parse_type = OmniParseType(parse_type) - self.parse_options.parse_type = parse_type - - @result_type.setter - def result_type(self, result_type: Union[str, ParseResultType]): - if isinstance(result_type, str): - result_type = ParseResultType(result_type) - self.parse_options.result_type = result_type - - async def _aload_data( - self, - file_path: Union[str, bytes, Path], - extra_info: Optional[dict] = None, - ) -> List[Document]: - """ - Load data from the input file_path. - - Args: - file_path: File path or file byte data. 
- extra_info: Optional dictionary containing additional information. - - Returns: - List[Document] - """ - try: - if self.parse_type == OmniParseType.PDF: - # pdf parse - parsed_result = await self.omniparse_client.parse_pdf(file_path) - else: - # other parse use omniparse_client.parse_document - # For compatible byte data, additional filename is required - extra_info = extra_info or {} - filename = extra_info.get("filename") - parsed_result = await self.omniparse_client.parse_document(file_path, bytes_filename=filename) - - # Get the specified structured data based on result_type - content = getattr(parsed_result, self.result_type) - docs = [ - Document( - text=content, - metadata=extra_info or {}, - ) - ] - except Exception as e: - logger.error(f"OMNI Parse Error: {e}") - docs = [] - - return docs - - async def aload_data( - self, - file_path: Union[List[FileInput], FileInput], - extra_info: Optional[dict] = None, - ) -> List[Document]: - """ - Load data from the input file_path. - - Args: - file_path: File path or file byte data. - extra_info: Optional dictionary containing additional information. - - Notes: - This method ultimately calls _aload_data for processing. - - Returns: - List[Document] - """ - docs = [] - if isinstance(file_path, (str, bytes, Path)): - # Processing single file - docs = await self._aload_data(file_path, extra_info) - elif isinstance(file_path, list): - # Concurrently process multiple files - parse_jobs = [self._aload_data(file_item, extra_info) for file_item in file_path] - doc_ret_list = await run_jobs(jobs=parse_jobs, workers=self.parse_options.num_workers) - docs = [doc for docs in doc_ret_list for doc in docs] - return docs - - def load_data( - self, - file_path: Union[List[FileInput], FileInput], - extra_info: Optional[dict] = None, - ) -> List[Document]: - """ - Load data from the input file_path. - - Args: - file_path: File path or file byte data. - extra_info: Optional dictionary containing additional information. - - Notes: - This method ultimately calls aload_data for processing. - - Returns: - List[Document] - """ - NestAsyncio.apply_once() # Ensure compatibility with nested async calls - return asyncio.run(self.aload_data(file_path, extra_info)) diff --git a/metagpt/rag/retrievers/milvus_retriever.py b/metagpt/rag/retrievers/milvus_retriever.py deleted file mode 100644 index bcc66330b..000000000 --- a/metagpt/rag/retrievers/milvus_retriever.py +++ /dev/null @@ -1,17 +0,0 @@ -"""Milvus retriever.""" - -from llama_index.core.retrievers import VectorIndexRetriever -from llama_index.core.schema import BaseNode - - -class MilvusRetriever(VectorIndexRetriever): - """Milvus retriever.""" - - def add_nodes(self, nodes: list[BaseNode], **kwargs) -> None: - """Support add nodes.""" - self._index.insert_nodes(nodes, **kwargs) - - def persist(self, persist_dir: str, **kwargs) -> None: - """Support persist. - - Milvus automatically saves, so there is no need to implement."""