Commit

add end2end

Trustccc committed Dec 4, 2024
1 parent 6fbc2a9 commit e5140c9
Showing 3 changed files with 231 additions and 4 deletions.
227 changes: 227 additions & 0 deletions metagpt/ext/sela/end2end_demo.py
@@ -0,0 +1,227 @@
import argparse
import json
import os
from typing import Optional

import fire

from metagpt.ext.sela.runner.custom import CustomRunner
from metagpt.ext.sela.runner.mcts import MCTSRunner
from metagpt.ext.sela.runner.random_search import RandomSearchRunner
from metagpt.ext.sela.runner.runner import Runner
from metagpt.llm import LLM


class Sela:
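    # Default experiment configuration; any value the LLM extracts from the requirement
    # overrides the corresponding entry (see _parse_requirement below).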
DEFAULT_CONFIG = {
"name": "",
"reflection": True,
"no_reflection": False,
"exp_mode": "mcts",
"rollouts": 10,
"load_tree": False,
"role_timeout": 1000,
"use_fixed_insights": False,
"low_is_better": False,
"start_task_id": 2,
"from_scratch": False,
"eval_func": "sela",
"custom_dataset_dir": None,
"max_depth": 4,
"rs_mode": "single",
"is_multimodal": True,
"num_experiments": 1,
"external_eval": True,
"no_external_eval": False,
"special_instruction": None,
}

def __init__(self, use_llm: bool = True):
"""
初始化 Sela 类。
Args:
use_llm: 是否使用 LLM 来解析 requirement。
"""
self.llm = LLM() if use_llm else None

async def _parse_requirement(self, requirement: str) -> dict:
"""
使用 LLM 分析实验需求,提取实验配置和实验数据信息。
Args:
requirement: 用户输入的实验需求描述。
Returns:
dict: 包含实验配置和实验数据信息的字典。
"""
if not self.llm:
raise ValueError("LLM is not initialized. Cannot parse the requirement.")

        # Escape `requirement` so it can be embedded safely in the prompt
        sanitized_requirement = json.dumps(requirement)  # convert to a JSON-safe string

prompt = f"""
You are an assistant that helps configure machine learning experiments.
Given the following requirement:
{sanitized_requirement}
Your task:
1. Extract **experiment configurations** from the requirement if they are explicitly mentioned.
For example, "rollouts: 10", "exp_mode: mcts", or "max_depth: 4". These should override default values.
2. Extract **experiment data information** from the requirement. This includes:
- **dataset**: The name of the dataset being used (e.g., "04_titanic").
- **metric**: The evaluation metric or scoring method mentioned (e.g., "f1", "rmse", "f1 weighted").
    - **target_col**: The name of the target column (e.g., "Survived").
    - **user_requirement**: Any specific instructions for handling the dataset (e.g., "Your goal is to predict the target column `Survived`. Perform data analysis, data preprocessing, feature engineering, and modeling to predict the target. Report f1 on the eval data. Do not plot or make any visualizations.").
Output a JSON object containing two parts:
- "config": This is a dictionary containing the experiment configuration. Include only explicitly mentioned configurations. Use keys like:
- "task": str (e.g., "titanic")
- "exp_mode": str (e.g., "mcts", "rs", "base", "custom", "greedy", "autogluon", "random", "autosklearn")
- "rollouts": int
- "max_depth": int
- "rs_mode": str (e.g., "single", "set")
- "special_instruction": str (e.g., "text", "image")
- "data_info": This is a dictionary containing the experiment data information with keys:
- "dataset": str (e.g., "04_titanic")
- "metric": str (e.g., "f1", "rmse", "f1 weighted")
- "target_col": str (e.g., "Survived")
- "user_requirement": str (e.g., Your goal is to predict the target column `Survived`."
"Perform data analysis, data preprocessing, feature engineering, and modeling to predict the target. "
"Report f1 on the eval data. Do not plot or make any visualizations.")
Example output:
{{
"config": {{
"task": "titanic",
"exp_mode": "mcts",
"rollouts": 10
}},
"data_info": {{
"dataset": "04_titanic",
"metric": "f1",
"target_col": "Predict the target column Survived",
"user_requirement": Your goal is to predict the target column `Survived`. Perform data analysis, data preprocessing, feature engineering, and modeling to predict the target. "
"Report f1 on the eval data. Do not plot or make any visualizations."
}}
}}
Return only the JSON object. Do not include any comments or extra text.
"""
response = await self.llm.aask(prompt)
print(f"LLM Response: {response}")

parsed_response = self._parse_json(response)
config_from_user = parsed_response.get("config", {})
data_info = parsed_response.get("data_info", {})

        # Merge the default configuration with the user-provided values
config = {**self.DEFAULT_CONFIG, **config_from_user}
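        # Illustrative (hypothetical) result for the demo requirement used in main():
        #   config    -> DEFAULT_CONFIG updated with {"task": "titanic", "exp_mode": "mcts", "rollouts": 10}
        #   data_info -> {"dataset": "04_titanic", "metric": "f1", "target_col": "Survived",
        #                 "user_requirement": "Your goal is to predict the target column `Survived`. ..."}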
return {"config": config, "data_info": data_info}

@staticmethod
def _parse_json(json_string: str) -> dict:
"""
解析 JSON 字符串,去除可能的 Markdown 标记。
"""
json_string = json_string.strip()
if json_string.startswith("```json"):
json_string = json_string[7:].strip()
if json_string.endswith("```"):
json_string = json_string[:-3].strip()

try:
return json.loads(json_string)
except json.JSONDecodeError:
raise ValueError(f"Invalid JSON format: {json_string}")

def _select_runner(self, config: argparse.Namespace, data_config: dict):
"""
根据配置选择适当的实验执行器。
Args:
config: 从 LLM 解析出的实验配置。
Returns:
实验执行器实例。
"""
exp_mode = config.exp_mode
if exp_mode == "mcts":
return MCTSRunner(config, data_config)
elif exp_mode == "greedy":
return MCTSRunner(tree_mode="greedy")
elif exp_mode == "random":
return MCTSRunner(tree_mode="random")
elif exp_mode == "rs":
return RandomSearchRunner(config)
elif exp_mode == "base":
return Runner(config)
elif exp_mode == "custom":
return CustomRunner(config)
else:
raise ValueError(f"Invalid exp_mode: {exp_mode}")

async def run(self, requirement: str, data_dir: Optional[str] = None):
"""
Args:
requirement: 实验需求,描述目标任务。
data_dir: 数据目录。
"""
        if not data_dir or not os.path.exists(data_dir):
raise FileNotFoundError(f"Dataset directory not found: {data_dir}")

        # Parse the requirement with the LLM
config_all = await self._parse_requirement(requirement)
config_exp = config_all["config"]
data_info = config_all["data_info"]

        # Build the default data_config for this run
data_config = {
"datasets_dir": data_dir, # 用户输入的数据目录路径
"work_dir": "../../workspace", # 默认的工作目录
"role_dir": "storage/SELA", # 存储角色路径
"datasets": {config_exp.get("task"): data_info}, # 数据集信息
}
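        # This dict is passed to the runner in place of the default DATA_CONFIG,
        # keyed so that datasets[task] resolves to the parsed data_info above.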

        # Select the appropriate runner for this experiment
runner = self._select_runner(argparse.Namespace(**config_exp), data_config)

        # Run the experiment
await runner.run_experiment()


async def main():
"""
Main 函数作为入口,支持直接运行。
"""
# 示例需求和数据路径
    requirement = (
        "Optimize 04_titanic dataset using MCTS with 10 rollouts. "
        "Your goal is to predict the target column `Survived`. "
        "Perform data analysis, data preprocessing, feature engineering, and modeling to predict the target. "
        "Report f1 on the eval data. Do not plot or make any visualizations."
    )

data_dir = "/home/coder/project/chenxin/MetaGPT/metagpt/ext/sela/data/SELA-datasets"

    # Initialize Sela and run
sela = Sela()
await sela.run(requirement, data_dir)


if __name__ == "__main__":
fire.Fire(main)
4 changes: 2 additions & 2 deletions metagpt/ext/sela/runner/mcts.py
@@ -12,7 +12,7 @@
class MCTSRunner(Runner):
result_path: str = "results/mcts"

def __init__(self, args, tree_mode=None, **kwargs):
def __init__(self, args, data_config=None, tree_mode=None, **kwargs):
if args.special_instruction == "image":
self.start_task_id = 1 # start from datapreprocessing if it is image task
else:
@@ -23,7 +23,7 @@ def __init__(self, args, tree_mode=None, **kwargs):
elif args.eval_func == "mlebench":
self.eval_func = node_evaluate_score_mlebench

super().__init__(args, **kwargs)
        super().__init__(args, data_config=data_config, **kwargs)
self.tree_mode = tree_mode

async def run_experiment(self):
4 changes: 2 additions & 2 deletions metagpt/ext/sela/runner/runner.py
@@ -16,14 +16,14 @@ class Runner:
data_config = DATA_CONFIG
start_task_id = 1

def __init__(self, args, **kwargs):
    def __init__(self, args, data_config=None, **kwargs):
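        # `data_config` lets callers (e.g. the end2end demo) supply a per-run dataset
        # configuration; when omitted, the class-level data_config (DATA_CONFIG) is used.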
self.args = args
self.start_time_raw = datetime.datetime.now()
self.start_time = self.start_time_raw.strftime("%Y%m%d%H%M")
self.state = create_initial_state(
self.args.task,
start_task_id=self.start_task_id,
data_config=self.data_config,
data_config=data_config if data_config is not None else self.data_config,
args=self.args,
)

