Skip to content

Commit

Permalink
do join id mapping
Browse files Browse the repository at this point in the history
  • Loading branch information
cthorrez committed Oct 28, 2024
1 parent b3eb278 commit 98952c2
Showing 1 changed file with 34 additions and 14 deletions.
48 changes: 34 additions & 14 deletions riix/utils/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def __init__(
elif df.schema[datetime_col] == pl.Date:
datetime = df[datetime_col].cast(pl.Datetime)
elif df.schema[datetime_col] == pl.Utf8:
datetime = df[datetime_col].str.strptime(pl.Datetime, '%Y-%m-%d')
datetime = df[datetime_col].str.strptime(pl.Datetime, '%Y-%m-%dT%H:%M:%S%.f')
else:
raise ValueError('datetime_col must be one of Date, Datetime, or Utf8')
seconds_since_epoch = (datetime.dt.timestamp() // 1_000_000).to_numpy()
Expand All @@ -44,20 +44,40 @@ def __init__(
self.time_steps = self.time_steps.astype(np.int32)
self.process_time_steps()

self.num_matchups = len(df)
str_competitors = pl.concat([
df[competitor_cols[0]].cast(pl.Utf8).alias('competitor'),
df[competitor_cols[1]].cast(pl.Utf8).alias('competitor')
])
self.competitors = str_competitors.unique().sort().to_list()
# Create a single competitors reference dataframe
competitors_df = pl.DataFrame(
{'competitor': pl.concat([
df[competitor_cols[0]].cast(pl.Utf8),
df[competitor_cols[1]].cast(pl.Utf8)
]).unique().sort()
}).lazy().select(
pl.all(),
pl.int_range(pl.len(), dtype=pl.Int32).alias('index')
)
self.competitors = sorted(competitors_df.collect()['competitor'].to_list())
self.num_competitors = len(self.competitors)
self.competitor_to_idx = {comp: idx for idx, comp in enumerate(self.competitors)}
comp_idxs_1 = df[competitor_cols[0]].cast(pl.Utf8).map_elements(lambda x: self.competitor_to_idx[x], return_dtype=pl.Int32)
comp_idxs_2 = df[competitor_cols[1]].cast(pl.Utf8).map_elements(lambda x: self.competitor_to_idx[x], return_dtype=pl.Int32)
self.matchups = np.hstack([
comp_idxs_1.to_numpy()[:,None],
comp_idxs_2.to_numpy()[:,None],
])
self.competitor_to_idx = dict(zip(self.competitors, range(self.num_competitors)))
matchups_df = (df.lazy()
.select([
pl.col(competitor_cols[0]).cast(pl.Utf8).alias('comp1'),
pl.col(competitor_cols[1]).cast(pl.Utf8).alias('comp2')
])
.join(
competitors_df,
left_on='comp1',
right_on='competitor'
).rename({'index': 'index1'})
.join(
competitors_df,
left_on='comp2',
right_on='competitor'
).rename({'index': 'index2'})
.select(['index1', 'index2'])
)
self.matchups = np.ascontiguousarray(matchups_df.collect().to_numpy())



self.outcomes = df[outcome_col].to_numpy()

if verbose:
Expand Down

0 comments on commit 98952c2

Please sign in to comment.