Skip to content

Commit

Permalink
fix: fix a bug in competition metric evaluation (#407)
Browse files Browse the repository at this point in the history
* fix a bug in competition metric evaluation

* fix a bug

* fix a bug in rag loading
  • Loading branch information
WinstonLiyt authored Sep 30, 2024
1 parent 14f7d97 commit 94c47d6
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 12 deletions.
10 changes: 6 additions & 4 deletions rdagent/core/knowledge_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,15 @@ def __init__(self, path: str | Path | None = None) -> None:
def load(self) -> None:
if self.path is not None and self.path.exists():
with self.path.open("rb") as f:
self.__dict__.update(
pickle.load(f).__dict__,
) # TODO: because we need to align with init function, we need a less hacky way to do this
loaded = pickle.load(f)
if isinstance(loaded, dict):
self.__dict__.update(loaded)
else:
self.__dict__.update(loaded.__dict__)

def dump(self) -> None:
if self.path is not None:
self.path.parent.mkdir(parents=True, exist_ok=True)
pickle.dump(self, self.path.open("wb"))
pickle.dump(self.__dict__, self.path.open("wb"))
else:
logger.warning("KnowledgeBase path is not set, dump failed.")
5 changes: 3 additions & 2 deletions rdagent/scenarios/kaggle/experiment/prompts.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@ kg_description_template:
"Competition Features": "Two-line description of the overall features involved within the competition as background."
"Submission Specifications": "The submission specification & sample submission csv descriptions for the model to output."
"Submission channel number to each sample": "The number of channels in the output for each sample, e.g., 1 for regression, N for N class classification with probabilities, etc. A Integer. If not specified, it is 1."
"Evaluation Description": "A brief description for what metrics are used in evaluation. An explanation of whether a higher score is better or lower is better in terms of performance."
"Evaluation Boolean": "True" or "False" (True means the higher score the better (like accuracy); False means the lower value the better (like loss).)
"Evaluation Description": "A brief description of the metrics used in the evaluation. Please note that if `evaluation_metric_direction` is True, it indicates that higher values are better; if False, lower values are preferred."
}
Since these might be very similar column names in data like one_hot_encoded columns, you can use some regex to group them together.
Expand All @@ -22,6 +21,8 @@ kg_description_template:
{{ competition_descriptions }}
The raw data information:
{{ raw_data_information }}
Evaluation_metric_direction:
{{ evaluation_metric_direction }}
kg_background: |-
You are solving a data science tasks and the type of the competition is {{ competition_type }}.
Expand Down
9 changes: 3 additions & 6 deletions rdagent/scenarios/kaggle/experiment/scenario.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ def __init__(self, competition: str) -> None:
self.submission_specifications = None
self.model_output_channel = None
self.evaluation_desc = None
self.evaluation_metric_direction = None
self.leaderboard = leaderboard_scores(competition)
self.evaluation_metric_direction = float(self.leaderboard[0]) > float(self.leaderboard[-1])
self.vector_base = None
self.mini_case = KAGGLE_IMPLEMENT_SETTING.mini_case
self._analysis_competition_description()
Expand All @@ -75,8 +76,6 @@ def __init__(self, competition: str) -> None:
self.confidence_parameter = 1.0
self.initial_performance = 0.0

self.leaderboard = leaderboard_scores(competition)

def _analysis_competition_description(self):
sys_prompt = (
Environment(undefined=StrictUndefined)
Expand All @@ -90,6 +89,7 @@ def _analysis_competition_description(self):
.render(
competition_descriptions=self.competition_descriptions,
raw_data_information=self._source_data,
evaluation_metric_direction=self.evaluation_metric_direction,
)
)

Expand All @@ -111,9 +111,6 @@ def _analysis_competition_description(self):
self.evaluation_desc = response_json_analysis.get(
"Evaluation Description", "No evaluation specification provided."
)
self.evaluation_metric_direction = response_json_analysis.get(
"Evaluation Boolean", "No evaluation specification provided."
)

def get_competition_full_desc(self) -> str:
evaluation_direction = "higher the better" if self.evaluation_metric_direction else "lower the better"
Expand Down

0 comments on commit 94c47d6

Please sign in to comment.