Skip to content

Commit

Permalink
check if gen_df is None (#216)
Browse files Browse the repository at this point in the history
  • Loading branch information
qew21 authored Aug 20, 2024
1 parent 538d4ef commit 48ff804
Showing 1 changed file with 30 additions and 6 deletions.
36 changes: 30 additions & 6 deletions rdagent/components/coder/factor_coder/CoSTEER/evaluators.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,11 @@ def evaluate(
gt_implementation: Workspace,
) -> Tuple[str, object]:
_, gen_df = self._get_df(gt_implementation, implementation)

if gen_df is None:
return (
"The source dataframe is None. Please check the implementation.",
False,
)
if len(gen_df.columns) == 1:
return "The source dataframe has only one column which is correct.", True
else:
Expand Down Expand Up @@ -241,7 +245,11 @@ def evaluate(
gt_implementation: Workspace,
) -> Tuple[str, object]:
gt_df, gen_df = self._get_df(gt_implementation, implementation)

if gen_df is None:
return (
"The source dataframe is None. Please check the implementation.",
False,
)
if gen_df.shape[0] == gt_df.shape[0]:
return "Both dataframes have the same rows count.", True
else:
Expand All @@ -258,7 +266,11 @@ def evaluate(
gt_implementation: Workspace,
) -> Tuple[str, object]:
gt_df, gen_df = self._get_df(gt_implementation, implementation)

if gen_df is None:
return (
"The source dataframe is None. Please check the implementation.",
False,
)
if gen_df.index.equals(gt_df.index):
return "Both dataframes have the same index.", True
else:
Expand All @@ -275,7 +287,11 @@ def evaluate(
gt_implementation: Workspace,
) -> Tuple[str, object]:
gt_df, gen_df = self._get_df(gt_implementation, implementation)

if gen_df is None:
return (
"The source dataframe is None. Please check the implementation.",
False,
)
if gen_df.isna().sum().sum() == gt_df.isna().sum().sum():
return "Both dataframes have the same missing values.", True
else:
Expand All @@ -292,7 +308,11 @@ def evaluate(
gt_implementation: Workspace,
) -> Tuple[str, object]:
gt_df, gen_df = self._get_df(gt_implementation, implementation)

if gen_df is None:
return (
"The source dataframe is None. Please check the implementation.",
-1,
)
try:
close_values = gen_df.sub(gt_df).abs().lt(1e-6)
result_int = close_values.astype(int)
Expand Down Expand Up @@ -323,7 +343,11 @@ def evaluate(
gt_implementation: Workspace,
) -> Tuple[str, object]:
gt_df, gen_df = self._get_df(gt_implementation, implementation)

if gen_df is None:
return (
"The source dataframe is None. Please check the implementation.",
False,
)
concat_df = pd.concat([gen_df, gt_df], axis=1)
concat_df.columns = ["source", "gt"]
ic = concat_df.groupby("datetime").apply(lambda df: df["source"].corr(df["gt"])).dropna().mean()
Expand Down

0 comments on commit 48ff804

Please sign in to comment.