diff --git a/relbench/tasks/avito.py b/relbench/tasks/avito.py index 873fb5d4..7f03a622 100644 --- a/relbench/tasks/avito.py +++ b/relbench/tasks/avito.py @@ -60,7 +60,7 @@ def make_table(self, db: Database, timestamps: "pd.Series[pd.Timestamp]") -> Tab return Table( df=df, - fkey_col_to_pkey_table={"AdID": "entity_table"}, + fkey_col_to_pkey_table={self.entity_col: self.entity_table}, pkey_col=None, time_col="timestamp", ) @@ -108,7 +108,7 @@ def make_table(self, db: Database, timestamps: "pd.Series[pd.Timestamp]") -> Tab return Table( df=df, - fkey_col_to_pkey_table={"UserID": "entity_table"}, + fkey_col_to_pkey_table={self.entity_col: self.entity_table}, pkey_col=None, time_col="timestamp", ) @@ -165,7 +165,7 @@ def make_table(self, db: Database, timestamps: "pd.Series[pd.Timestamp]") -> Tab return Table( df=df, - fkey_col_to_pkey_table={"UserID": "entity_table"}, + fkey_col_to_pkey_table={self.entity_col: self.entity_table}, pkey_col=None, time_col="timestamp", )