Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support more models in Kaggle scenario #218

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 17 additions & 3 deletions rdagent/components/coder/model_coder/CoSTEER/evaluators.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,24 @@

def shape_evaluator(prediction: torch.Tensor, target_shape: Tuple = None) -> Tuple[str, bool]:
if target_shape is None or prediction is None:
return "No output generated from the model. No shape evaluation conducted.", False
return (
"No output generated from the model. No shape evaluation conducted.",
False,
)
pre_shape = prediction.shape

return (
"The shape of the output is correct.",
True,
) # now test xgboost so no need for shape evaluator

if pre_shape == target_shape:
return "The shape of the output is correct.", True
else:
return f"The shape of the output is incorrect. Expected {target_shape}, but got {pre_shape}.", False
return (
f"The shape of the output is incorrect. Expected {target_shape}, but got {pre_shape}.",
False,
)


def reshape_tensor(original_tensor, target_shape):
Expand All @@ -50,7 +61,10 @@ def value_evaluator(
if prediction is None:
return "No output generated from the model. Skip value evaluation", False
elif target is None:
return "No ground truth output provided. Value evaluation not impractical", False
return (
"No ground truth output provided. Value evaluation not impractical",
False,
)
else:
# Calculate the mean absolute difference
diff = torch.mean(torch.abs(target - prediction)).item()
Expand Down
46 changes: 32 additions & 14 deletions rdagent/components/coder/model_coder/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
from pathlib import Path
from typing import Any, Dict, Optional

import torch
import numpy as np
import xgboost as xgb

from rdagent.components.coder.model_coder.conf import MODEL_IMPL_SETTINGS
from rdagent.core.exception import CodeFormatError
Expand Down Expand Up @@ -95,9 +96,17 @@ def execute(
if cache_file_path.exists():
return pickle.load(open(cache_file_path, "rb"))
mod = get_module_by_module_path(str(self.workspace_path / "model.py"))
model_cls = mod.model_cls

if self.target_task.model_type == "Tabular":
if self.target_task.model_type != "XGBoost":
model_cls = mod.model_cls

if self.target_task.model_type == "XGBoost":
X_simulated = np.random.rand(100, num_features) # 100 samples, `num_features` features each
y_simulated = np.random.randint(0, 2, 100) # Binary target for example
params = mod.get_params()
num_round = mod.get_num_round()
dtrain = xgb.DMatrix(X_simulated, label=y_simulated)
elif self.target_task.model_type == "Tabular":
input_shape = (batch_size, num_features)
m = model_cls(num_features=input_shape[1])
data = torch.full(input_shape, input_value)
Expand All @@ -113,21 +122,30 @@ def execute(
else:
raise ValueError(f"Unsupported model type: {self.target_task.model_type}")

# Initialize all parameters of `m` to `param_init_value`
for _, param in m.named_parameters():
param.data.fill_(param_init_value)

# Execute the model
if self.target_task.model_type == "Graph":
out = m(*data)
if self.target_task.model_type == "XGBoost":
bst = xgb.train(params, dtrain, num_round)
y_pred = bst.predict(dtrain)
execution_model_output = y_pred
execution_feedback_str = "Execution successful, model trained and predictions made."
else:
out = m(data)
# Initialize all parameters of `m` to `param_init_value`
for _, param in m.named_parameters():
param.data.fill_(param_init_value)

execution_model_output = out.cpu().detach()
execution_feedback_str = f"Execution successful, output tensor shape: {execution_model_output.shape}"
# Execute the model
if self.target_task.model_type == "Graph":
out = m(*data)
else:
out = m(data)

execution_model_output = out.cpu().detach()
execution_feedback_str = f"Execution successful, output tensor shape: {execution_model_output.shape}"

if MODEL_IMPL_SETTINGS.enable_execution_cache:
pickle.dump((execution_feedback_str, execution_model_output), open(cache_file_path, "wb"))
pickle.dump(
(execution_feedback_str, execution_model_output),
open(cache_file_path, "wb"),
)

except Exception as e:
execution_feedback_str = f"Execution error: {e}\nTraceback: {traceback.format_exc()}"
Expand Down
2 changes: 1 addition & 1 deletion rdagent/scenarios/kaggle/developer/feedback.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from rdagent.oai.llm_utils import APIBackend
from rdagent.utils import convert2bool

feedback_prompts = Prompts(file_path=Path(__file__).parent.parent.parent / "qlib" / "prompts.yaml")
feedback_prompts = Prompts(file_path=Path(__file__).parent.parent / "prompts.yaml")
DIRNAME = Path(__file__).absolute().resolve().parent


Expand Down
72 changes: 20 additions & 52 deletions rdagent/scenarios/kaggle/experiment/model_template/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,8 @@

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from model import model_cls
import xgboost as xgb
from model import get_num_round, get_params
from sklearn.compose import ColumnTransformer
from sklearn.impute import SimpleImputer
from sklearn.metrics import accuracy_score
Expand All @@ -15,7 +14,6 @@
# Set random seed for reproducibility
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)
np.random.seed(SEED)


Expand All @@ -27,51 +25,22 @@ def compute_metrics_for_classification(y_true, y_pred):

def train_model(X_train, y_train, X_valid, y_valid):
"""Define and train the model."""
X_train_dense = X_train.toarray() if hasattr(X_train, "toarray") else X_train
X_valid_dense = X_valid.toarray() if hasattr(X_valid, "toarray") else X_valid
dtrain = xgb.DMatrix(X_train, label=y_train)
dvalid = xgb.DMatrix(X_valid, label=y_valid)

X_train_tensor = torch.tensor(X_train_dense, dtype=torch.float32)
y_train_tensor = torch.tensor(y_train, dtype=torch.float32).unsqueeze(1)
X_valid_tensor = torch.tensor(X_valid_dense, dtype=torch.float32)
y_valid_tensor = torch.tensor(y_valid, dtype=torch.float32).unsqueeze(1)
params = get_params()
num_round = get_num_round()

# Define the model
model = model_cls(num_features=X_train.shape[1])
evallist = [(dtrain, "train"), (dvalid, "eval")]
bst = xgb.train(params, dtrain, num_round, evallist)

# Define loss function and optimizer
criterion = nn.BCELoss() # Binary cross entropy loss
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Train the model
num_epochs = 150 # Number of epochs
for epoch in range(num_epochs):
model.train()
optimizer.zero_grad()
y_train_pred = model(X_train_tensor)
loss = criterion(y_train_pred, y_train_tensor)
loss.backward()
optimizer.step()

# Evaluate model on validation set after each epoch
model.eval()
with torch.no_grad():
y_valid_pred = model(X_valid_tensor)
valid_loss = criterion(y_valid_pred, y_valid_tensor)
print(f"Epoch {epoch+1}/{num_epochs}, Loss: {loss.item()}, Validation Loss: {valid_loss.item()}")

return model
return bst


def predict(model, X):
"""Make predictions using the trained model."""
X_dense = X.toarray() if hasattr(X, "toarray") else X
X_tensor = torch.tensor(X_dense, dtype=torch.float32)
model.eval()

with torch.no_grad():
y_pred = model(X_tensor)
y_pred = y_pred.numpy().flatten()
return y_pred > 0.5 # Apply threshold to get boolean predictions
dtest = xgb.DMatrix(X)
y_pred_prob = model.predict(dtest)
return y_pred_prob > 0.5 # Apply threshold to get boolean predictions


if __name__ == "__main__":
Expand All @@ -88,7 +57,10 @@ def predict(model, X):

# Define preprocessors for numerical and categorical features
categorical_transformer = Pipeline(
steps=[("imputer", SimpleImputer(strategy="most_frequent")), ("onehot", OneHotEncoder(handle_unknown="ignore"))]
steps=[
("imputer", SimpleImputer(strategy="most_frequent")),
("onehot", OneHotEncoder(handle_unknown="ignore")),
]
)

numerical_transformer = Pipeline(steps=[("imputer", SimpleImputer(strategy="mean"))])
Expand Down Expand Up @@ -118,20 +90,16 @@ def predict(model, X):
print("Final Accuracy on validation set: ", accuracy)

# Save the validation accuracy
pd.Series(data=[accuracy], index=["ACC"]).to_csv("./submission.csv")
pd.Series(data=[accuracy], index=["ACC"]).to_csv("./submission_score.csv")

# Load and preprocess the test set
submission_df = pd.read_csv("/root/.data/test.csv")
passenger_ids = submission_df["PassengerId"]
submission_df = submission_df.drop(["PassengerId", "Name"], axis=1)
X_test = preprocessor.transform(submission_df)

# Make predictions on the test set and save them
y_test_pred = predict(model, X_test)
pd.Series(y_test_pred).to_csv("./submission_update.csv", index=False)

submission_result = pd.DataFrame({"PassengerId": passenger_ids, "Transported": y_test_pred})
# submit predictions for the test set
submission_df = pd.read_csv("/root/.data/test.csv")
submission_df = submission_df.drop(["PassengerId", "Name"], axis=1)
X_test = preprocessor.transform(submission_df)
y_test_pred = predict(model, X_test)
y_test_pred.to_csv("./submission_update.csv")
submission_result.to_csv("./submission.csv", index=False)
40 changes: 20 additions & 20 deletions rdagent/scenarios/kaggle/experiment/prompts.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,11 @@ kg_model_background: |-
You are solving this data science task of {{ competition_type }}:
{{competition_description}}

We provide an overall pipeline in train.py. Now fill in the provided train.py script to train a {{ competition_type }} model to get a good performance on this task.
We provide an overall pipeline in train.py. Now fill in the provided train.py script to train a {{ competition_type }} model to get a good performance on this task.

The model is a machine learning or deep learning structure designed to predict {{ target_description }}.
The data is extracted from the competition dataset, focusing on passenger attributes like {{ competition_features }}.
The data is extracted from the competition dataset, focusing on relevant attributes like {{ competition_features }}.
ModelType: The type of the model, "Tabular" for tabular model, "TimeSeries" for time series model and "XGBoost" for XGBoost model.

The model is defined in the following parts:
- Name: The name of the model.
Expand All @@ -31,31 +32,32 @@ kg_model_interface: |-
Your python code should follow the interface to better interact with the user's system.
Your code should contain several parts:
1. The import part: import the necessary libraries.
2. A class that is a subclass of pytorch.nn.Module. This class should have an __init__ function and a forward function, which inputs a tensor and outputs a tensor.
2. A class that is a subclass of xgboost. This class should have an __init__ function and a forward function, which inputs a tensor and outputs a tensor.
3. Set a variable called "model_cls" to the class you defined.

The user will save your code into a python file called "model.py". Then the user imports model_cls in file "model.py" after setting the cwd into the directory:
```python
from model import model_cls
from model import get_params, get_num_round
```
So your python code should follow the pattern:
```python
class XXXModel(torch.nn.Module):
def get_params():
params = {
...
model_cls = XXXModel
}
return params
def get_num_round():
return xxx
```

The model has one types, "Tabular" for tabular model. The input shape to a tabular model is (batch_size, num_features).
The output shape of the model should be (batch_size, 1).
The "batch_size" is a dynamic value which is determined by the input of forward function.
The "num_features" is static which will be provided to the model through init function.
The model has one type: "XGBoost" for an XGBoost model.
The XGBoost Model leverages two critical hyperparameters: "params" and "num_round".
"params": This hyperparameter encapsulates various settings that dictate the model's behavior and learning process.
"num_round": This hyperparameter specifies the number of training iterations the model will undergo.
User will initialize the tabular model with the following code:
```python
model = model_cls(num_features=num_features)
```
User will initialize the tabular model with the following code:
```python
model = model_cls(num_features=num_features)
params = get_params()
num_round = get_num_round()
```
No other parameters will be passed to the model so give other parameters a default value or just make them static.

Expand All @@ -65,10 +67,8 @@ kg_model_interface: |-


kg_model_output_format: |-
Your output should be a tensor with shape (batch_size, 1).
The output tensor should be saved in a file named "output.pth" in the same directory as your python file.
The user will evaluate the shape of the output tensor so the tensor read from "output.pth" should be 8 numbers.

Your output should be an array with the appropriate number of predictions, each prediction being a single value. The output should be a 2D array with dimensions corresponding to the number of predictions and 1 column (e.g., (8, 1) if there are 8 predictions).

kg_model_simulator: |-
The models will be trained on the Spaceship Titanic dataset and evaluated on their ability to predict whether passengers were transported using metrics like accuracy and AUC-ROC.
The models will be trained on the competition dataset and evaluated on their ability to predict whether passengers were transported using metrics like accuracy and AUC-ROC.
Model performance will be iteratively improved based on feedback from evaluation results.
Loading