Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Time-based split in evaluation #2094

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 63 additions & 0 deletions recbole/data/dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1729,6 +1729,62 @@ def leave_one_out(self, group_by, leave_one_mode):
next_ds = [self.copy(_) for _ in next_df]
return next_ds

def time_based_split(self, ratios, group_by):
"""Split interaction records by time-based strategy that combines global temporal and leave-one-out constraints.

Args:
ratios (list): List of split ratios.
group_by (str): Field name that interaction records should be grouped by before splitting.

Returns:
list: List of :class:`~Dataset`, whose interaction features have been split.
"""
self.logger.debug(f"time based split, group_by=[{group_by}]")
if group_by is None:
raise ValueError("Time-based split strategy requires a group field")

if self.time_field not in self.inter_feat:
raise ValueError(f"Field [{self.time_field}] is not in inter_feat.")

self.logger.debug(f"time-based split with ratios [{ratios}], group_by=[{group_by}]")
tot_ratio = sum(ratios)
ratios = [_ / tot_ratio for _ in ratios]

# Determine the global temporal boundary (e.g., 90th percentile)
all_times = self.inter_feat[self.time_field].numpy()
global_temporal_boundary = np.percentile(all_times, 100 * (1 - ratios[-1]))

train_index, valid_index, test_index = [], [], []
grouped_inter_feat_index = self._grouped_index(self.inter_feat[group_by].numpy())

for grouped_index in grouped_inter_feat_index:
grouped_index = np.array(grouped_index)
grouped_inter_feat = self.inter_feat[grouped_index]
grouped_inter_feat.sort(by=self.time_field)

# Split into training/validation and test sets based on the global temporal boundary
times = grouped_inter_feat[self.time_field].numpy()
train_valid_mask = times <= global_temporal_boundary
test_mask = ~train_valid_mask

train_valid_index = grouped_index[train_valid_mask]
test_user_indices = grouped_index[test_mask]

split_point = int(len(train_valid_index) * (ratios[0] / (1 - ratios[2])))
train_index.extend(train_valid_index[:split_point])
valid_index.extend(train_valid_index[split_point:])

test_index.extend(test_user_indices)

self._drop_unused_col()
next_df = [
self.inter_feat[train_index],
self.inter_feat[valid_index],
self.inter_feat[test_index],
]
next_ds = [self.copy(_) for _ in next_df]
return next_ds

def shuffle(self):
"""Shuffle the interaction records inplace."""
self.inter_feat.shuffle()
Expand Down Expand Up @@ -1799,6 +1855,13 @@ def build(self):
datasets = self.leave_one_out(
group_by=self.uid_field, leave_one_mode=split_args["LS"]
)
elif split_mode == "TS":
if not isinstance(split_args["TS"], list):
raise ValueError(f'The value of "TS" [{split_args}] should be a list.')
datasets = self.time_based_split(
ratios=split_args["TS"],
group_by=self.uid_field
)
else:
raise NotImplementedError(
f"The splitting_method [{split_mode}] has not been implemented."
Expand Down