-
-
Notifications
You must be signed in to change notification settings - Fork 56
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
update vital TODOs and comments in Kaggle scen
- Loading branch information
1 parent
d1e2d51
commit 3e0cf74
Showing
5 changed files
with
127 additions
and
13 deletions.
There are no files selected for viewing
97 changes: 90 additions & 7 deletions
97
rdagent/scenarios/feature_engineering/developer/data_runner.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,30 +1,113 @@ | ||
from typing import List | ||
import pandas as pd | ||
|
||
from rdagent.components.coder.model_coder.model import ModelExperiment, ModelFBWorkspace | ||
from rdagent.components.runner import CachedRunner | ||
from rdagent.components.runner.conf import RUNNER_SETTINGS | ||
from rdagent.core.exception import ModelEmptyError | ||
from rdagent.core.exception import FactorEmptyError | ||
from rdagent.core.conf import RD_AGENT_SETTINGS | ||
from rdagent.core.utils import multiprocessing_wrapper | ||
from rdagent.log import rdagent_logger as logger | ||
from rdagent.scenarios.feature_engineering.experiment.feature_experiment import FEFeatureExperiment | ||
|
||
|
||
class FEFeatureRunner(CachedRunner[FEFeatureExperiment]): | ||
def develop(self, exp: FEFeatureExperiment) -> FEFeatureExperiment: | ||
import pickle | ||
with open('/home/v-yuanteli/RD-Agent/git_ignore_folder/test_featexp_data.pkl', 'wb') as f: | ||
pickle.dump(exp, f) | ||
print("Feature Experiment object saved to test_featexp_data.pkl") | ||
|
||
|
||
# 这里考虑把每次实验都加一次原始数据 | ||
if exp.based_experiments and exp.based_experiments[-1].result is None: | ||
exp.based_experiments[-1] = self.develop(exp.based_experiments[-1]) | ||
|
||
if RUNNER_SETTINGS.cache_result: | ||
cache_hit, result = self.get_cache_result(exp) | ||
if cache_hit: | ||
exp.result = result | ||
return exp | ||
|
||
#TODO 这里对应于SOTA因子库的概念 | ||
if exp.based_experiments: | ||
SOTA_factor = None | ||
if len(exp.based_experiments) > 1: | ||
SOTA_factor = self.process_factor_data(exp.based_experiments) | ||
|
||
# Process the new factors data | ||
new_factors = self.process_factor_data(exp) | ||
|
||
if new_factors.empty: | ||
raise FactorEmptyError("No valid factor data found to merge.") | ||
|
||
# Combine the SOTA factor and new factors if SOTA factor exists | ||
if SOTA_factor is not None and not SOTA_factor.empty: | ||
new_factors = self.deduplicate_new_factors(SOTA_factor, new_factors) | ||
if new_factors.empty: | ||
raise FactorEmptyError("No valid factor data found to merge.") | ||
combined_factors = pd.concat([SOTA_factor, new_factors], axis=1).dropna() | ||
else: | ||
combined_factors = new_factors | ||
|
||
if exp.sub_workspace_list[0].code_dict.get("model.py") is None: | ||
raise ModelEmptyError("model.py is empty") | ||
# to replace & inject code | ||
exp.experiment_workspace.inject_code(**{"model.py": exp.sub_workspace_list[0].code_dict["model.py"]}) | ||
# Sort and nest the combined factors under 'feature' | ||
# TODO 这里是去重吧,针对feature的格式处理 kaggle应该不需要 | ||
combined_factors = combined_factors.sort_index() | ||
combined_factors = combined_factors.loc[:, ~combined_factors.columns.duplicated(keep="last")] | ||
new_columns = pd.MultiIndex.from_product([["feature"], combined_factors.columns]) | ||
combined_factors.columns = new_columns | ||
|
||
env_to_use = {"PYTHONPATH": "./"} | ||
# Save the combined factors to the workspace | ||
with open(exp.experiment_workspace.workspace_path / "combined_factors_df.pkl", "wb") as f: | ||
pickle.dump(combined_factors, f) | ||
|
||
result = exp.experiment_workspace.execute(run_env=env_to_use) | ||
# TODO 这里还是execute,应该是连kaggle的dockers | ||
result = exp.experiment_workspace.execute( | ||
qlib_config_name=f"conf.yaml" if len(exp.based_experiments) == 0 else "conf_combined.yaml" | ||
) | ||
|
||
exp.result = result | ||
if RUNNER_SETTINGS.cache_result: | ||
self.dump_cache_result(exp, result) | ||
|
||
return exp | ||
|
||
|
||
return exp | ||
|
||
def process_factor_data(self, exp_or_list: List[FEFeatureExperiment] | FEFeatureExperiment) -> pd.DataFrame: | ||
""" | ||
Process and combine factor data from experiment implementations. | ||
Args: | ||
exp (ASpecificExp): The experiment containing factor data. | ||
Returns: | ||
pd.DataFrame: Combined factor data without NaN values. | ||
""" | ||
#TODO 这里需要把task的代码执行一遍,得到一个dataframe | ||
if isinstance(exp_or_list, FEFeatureExperiment): | ||
exp_or_list = [exp_or_list] | ||
factor_dfs = [] | ||
|
||
# Collect all exp's dataframes | ||
for exp in exp_or_list: | ||
# Iterate over sub-implementations and execute them to get each factor data | ||
#TODO 这里应当使用feature_execute函数实现 | ||
message_and_df_list = multiprocessing_wrapper( | ||
[(implementation.feature_execute) for implementation in exp.sub_workspace_list], | ||
n=RD_AGENT_SETTINGS.multi_proc_n, | ||
) | ||
#TODO datatime这些 这里应该不需要了 | ||
for message, df in message_and_df_list: | ||
# Check if factor generation was successful | ||
if df is not None and "datetime" in df.index.names: | ||
time_diff = df.index.get_level_values("datetime").to_series().diff().dropna().unique() | ||
if pd.Timedelta(minutes=1) not in time_diff: | ||
factor_dfs.append(df) | ||
|
||
# Combine all successful factor data | ||
if factor_dfs: | ||
return pd.concat(factor_dfs, axis=1) | ||
else: | ||
raise FactorEmptyError("No valid factor data found to merge.") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
22 changes: 22 additions & 0 deletions
22
rdagent/scenarios/feature_engineering/experiment/feature_template/model.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
|
||
class HybridFeatureInteractionModel(nn.Module): | ||
def __init__(self, num_features): | ||
super(HybridFeatureInteractionModel, self).__init__() | ||
self.fc1 = nn.Linear(num_features, 128) | ||
self.bn1 = nn.BatchNorm1d(128, momentum=0.1) | ||
self.fc2 = nn.Linear(128, 64) | ||
self.bn2 = nn.BatchNorm1d(64, momentum=0.1) | ||
self.fc3 = nn.Linear(64, 1) | ||
self.dropout = nn.Dropout(0.3) | ||
|
||
def forward(self, x): | ||
x = F.relu(self.bn1(self.fc1(x))) | ||
x = F.relu(self.bn2(self.fc2(x))) | ||
x = self.dropout(x) | ||
x = torch.sigmoid(self.fc3(x)) | ||
return x | ||
|
||
model_cls = HybridFeatureInteractionModel |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters