-
Notifications
You must be signed in to change notification settings - Fork 5.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
SELA end2end #1621
Open
Trustccc
wants to merge
7
commits into
geekan:main
Choose a base branch
from
Trustccc:dev
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
SELA end2end #1621
Changes from all commits
Commits
Show all changes
7 commits
Select commit
Hold shift + click to select a range
e5140c9
add end2end
Trustccc 54c70d4
Refactored the end-to-end code with the entry point being `run_sela.p…
Trustccc c831303
Refactored the end-to-end code with the entry point being `run_sela.p…
Trustccc 535a2be
Merge remote-tracking branch 'origin/dev' into dev
Trustccc 3a47b4e
change dataset.py, delete end2end_demo.py
Trustccc db00fc4
Update sela README
Trustccc 7895adb
change the sela.py、run_sela.py
Trustccc File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
import fire | ||
from runner.sela import SELA | ||
|
||
requirement = """ | ||
Implement MCTS with a rollout count of 10 to improve my dataset. Focus on forecasting the RS column. | ||
Carry out data analysis, data preprocessing, feature engineering, and modeling for the forecast. | ||
Report the rmse on the evaluation dataset, omitting any visual or graphical outputs. | ||
""" | ||
|
||
|
||
async def main(): | ||
""" | ||
The main function serves as an entry point and supports direct running. | ||
garylin2099 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
""" | ||
# Example requirement and data path | ||
data_dir = "Path/to/dataset" | ||
|
||
# Initialize Sela and run | ||
sela = SELA() | ||
await sela.run(requirement, data_dir) | ||
|
||
|
||
if __name__ == "__main__": | ||
fire.Fire(main) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,160 @@ | ||
import argparse | ||
import json | ||
import os | ||
from typing import Optional | ||
|
||
from metagpt.ext.sela.runner.custom import CustomRunner | ||
from metagpt.ext.sela.runner.mcts import MCTSRunner | ||
from metagpt.ext.sela.runner.random_search import RandomSearchRunner | ||
from metagpt.ext.sela.runner.runner import Runner | ||
from metagpt.llm import LLM | ||
from metagpt.utils.common import CodeParser | ||
|
||
SELA_INSTRUCTION = """ | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. REQ_PARSING_PROMPT |
||
You are an assistant for configuring machine learning experiments. | ||
|
||
Given the requirement and data directory: | ||
{requirement} | ||
{data_dir} | ||
|
||
Your task: | ||
1. Extract **experiment configurations** from the requirement if explicitly mentioned, such as: | ||
- "rollouts: 10" | ||
- "exp_mode: mcts" | ||
- "max_depth: 4" | ||
|
||
2. Extract **experiment data information**, including: | ||
- **dataset**: Dataset name (if explicitly mentioned in the requirement, use that; otherwise, use the last folder name in the data directory path) | ||
- **metric**: Evaluation metric | ||
- **target_col**: Target column | ||
- **user_requirement**: Specific instructions or dataset handling requirements | ||
|
||
Output a JSON object with two parts: | ||
- "config": A dictionary of explicitly mentioned configurations, using keys: | ||
- "task": str (a noun based on the dataset name, customizable, e.g., "titanic") | ||
- "exp_mode": str (e.g., "mcts", "rs", "base", "custom", "greedy", "autogluon") | ||
- "rollouts": int | ||
- "max_depth": int | ||
- "rs_mode": str (e.g., "single", "set") | ||
- "special_instruction": str (e.g., "text", "image") | ||
- "data_info": A dictionary of experiment data information, with keys: | ||
- "dataset": str (e.g., "04_titanic") | ||
- "metric": str (e.g., "f1", "rmse") | ||
- "target_col": str (e.g., "Survived") | ||
- "user_requirement": str | ||
|
||
Example output: | ||
```json | ||
{{ | ||
"config": {{ | ||
"task": "titanic", | ||
"exp_mode": "mcts", | ||
"rollouts": 10 | ||
}}, | ||
"data_info": {{ | ||
"dataset": "04_titanic", | ||
"metric": "f1", | ||
"target_col": "Survived", | ||
"user_requirement": "Predict the target column `Survived`. Perform data analysis, preprocessing, feature engineering, and modeling. Report f1 on eval data. Do not include visualizations." | ||
}} | ||
}} | ||
``` | ||
|
||
Return only the JSON object. | ||
""" | ||
DEFAULT_CONFIG = { | ||
"name": "", | ||
"reflection": True, | ||
"no_reflection": False, | ||
"exp_mode": "mcts", | ||
"rollouts": 10, | ||
"load_tree": False, | ||
"role_timeout": 1000, | ||
"use_fixed_insights": False, | ||
"low_is_better": False, | ||
"start_task_id": 2, | ||
"from_scratch": True, | ||
"eval_func": "sela", | ||
"custom_dataset_dir": None, | ||
"max_depth": 4, | ||
"rs_mode": "single", | ||
"is_multimodal": True, | ||
"num_experiments": 1, | ||
"external_eval": True, | ||
"no_external_eval": False, | ||
"special_instruction": None, | ||
} | ||
|
||
|
||
class SELA: | ||
def __init__(self, use_llm: bool = True): | ||
""" | ||
Initialize the SELA class. | ||
Args: | ||
use_llm: Whether to use LLM (Language Model) to parse the requirement. | ||
""" | ||
self.llm = LLM() if use_llm else None | ||
|
||
async def _parse_requirement(self, requirement: str, data_dir: str) -> dict: | ||
""" | ||
Use LLM to analyze the experiment requirement and extract configurations. | ||
""" | ||
if not self.llm: | ||
raise ValueError("LLM is not initialized. Cannot parse the requirement.") | ||
response = await self.llm.aask( | ||
SELA_INSTRUCTION.format(requirement=json.dumps(requirement), data_dir=json.dumps(data_dir)) | ||
) | ||
print(f"LLM Response: {response}") | ||
parsed_response = self._parse_json(response) | ||
return { | ||
"config": {**DEFAULT_CONFIG, **parsed_response.get("config", {})}, | ||
"data_info": parsed_response.get("data_info", {}), | ||
} | ||
|
||
@staticmethod | ||
def _parse_json(json_string: str) -> dict: | ||
""" | ||
Extract and parse JSON content from the given string using CodeParser. | ||
""" | ||
try: | ||
json_code = CodeParser.parse_code("", json_string, "json") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. you dont need try except, let the error be raise up |
||
import json | ||
|
||
return json.loads(json_code) | ||
except ValueError: | ||
raise ValueError(f"Invalid JSON format: {json_string}") | ||
|
||
def _select_runner(self, config: argparse.Namespace, data_config: dict): | ||
""" | ||
Select the appropriate experiment runner based on the experiment mode. | ||
""" | ||
runners = { | ||
"mcts": lambda: MCTSRunner(config, data_config), | ||
"greedy": lambda: MCTSRunner(tree_mode="greedy"), | ||
"random": lambda: MCTSRunner(tree_mode="random"), | ||
"rs": lambda: RandomSearchRunner(config), | ||
"base": lambda: Runner(config), | ||
"custom": lambda: CustomRunner(config), | ||
} | ||
if config.exp_mode not in runners: | ||
raise ValueError(f"Invalid exp_mode: {config.exp_mode}") | ||
return runners[config.exp_mode]() | ||
|
||
async def run(self, requirement: str, data_dir: Optional[str] = None): | ||
""" | ||
Run the experiment with the given requirement and data directory. | ||
""" | ||
if not os.path.exists(data_dir): | ||
raise FileNotFoundError(f"Dataset directory not found: {data_dir}") | ||
|
||
config_all = await self._parse_requirement(requirement, data_dir) | ||
config_exp, data_info = config_all["config"], config_all["data_info"] | ||
|
||
data_config = { | ||
"datasets_dir": data_dir, | ||
"work_dir": "../../workspace", | ||
"role_dir": "storage/SELA", | ||
"datasets": {config_exp.get("task"): data_info}, | ||
} | ||
|
||
await self._select_runner(argparse.Namespace(**config_exp), data_config).run_experiment() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
maybe use Path from pathlib instead of os.path