def time_based_split(self, ratios, group_by):
    """Split interaction records with a global time boundary plus a per-group cut.

    A single global boundary is taken as the ``100 * (1 - test_ratio)``-th
    percentile of all values in ``self.time_field``. Within every group,
    records later than the boundary go to the test set; the remaining
    records are ordered by time and cut by position into train and valid.

    Args:
        ratios (list): Three non-negative weights ``[train, valid, test]``.
            They are normalized by their sum before use.
        group_by (str): Field name that interaction records should be grouped
            by before splitting.

    Returns:
        list: List of :class:`~Dataset` ``[train, valid, test]``, whose
        interaction features have been split.

    Raises:
        ValueError: If ``group_by`` is ``None``, if ``self.time_field`` is not
            in ``inter_feat``, if ``ratios`` does not hold exactly three
            values, or if the ratios do not sum to a positive number.
    """
    self.logger.debug(
        f"time-based split with ratios [{ratios}], group_by=[{group_by}]"
    )
    if group_by is None:
        raise ValueError("Time-based split strategy requires a group field")

    if self.time_field not in self.inter_feat:
        raise ValueError(f"Field [{self.time_field}] is not in inter_feat.")

    # Exactly three datasets are returned, so exactly three ratios are needed.
    if len(ratios) != 3:
        raise ValueError(
            f"Time-based split requires exactly 3 ratios, but got [{ratios}]."
        )
    tot_ratio = sum(ratios)
    if tot_ratio <= 0:
        raise ValueError(f"The sum of ratios [{ratios}] should be positive.")
    ratios = [_ / tot_ratio for _ in ratios]

    # Global temporal boundary (e.g. the 90th percentile for a 10% test share).
    all_times = self.inter_feat[self.time_field].numpy()
    global_temporal_boundary = np.percentile(all_times, 100 * (1 - ratios[-1]))

    # Share of the pre-boundary records that goes to train; guarded so that a
    # test share of 1 does not divide by zero.
    train_valid_share = 1 - ratios[-1]
    train_share = ratios[0] / train_valid_share if train_valid_share > 0 else 0.0

    train_index, valid_index, test_index = [], [], []
    grouped_inter_feat_index = self._grouped_index(
        self.inter_feat[group_by].numpy()
    )

    for grouped_index in grouped_inter_feat_index:
        grouped_index = np.asarray(grouped_index, dtype=np.int64)

        # Reorder the row indices themselves by time (stable, so ties keep
        # their storage order). The mask and the positional cut below must
        # refer to the same ordering as the timestamps they were derived from.
        group_times = all_times[grouped_index]
        time_order = np.argsort(group_times, kind="stable")
        grouped_index = grouped_index[time_order]
        group_times = group_times[time_order]

        train_valid_mask = group_times <= global_temporal_boundary
        train_valid_index = grouped_index[train_valid_mask]
        test_group_index = grouped_index[~train_valid_mask]

        split_point = int(len(train_valid_index) * train_share)
        train_index.extend(train_valid_index[:split_point].tolist())
        valid_index.extend(train_valid_index[split_point:].tolist())
        test_index.extend(test_group_index.tolist())

    self._drop_unused_col()
    next_df = [
        self.inter_feat[train_index],
        self.inter_feat[valid_index],
        self.inter_feat[test_index],
    ]
    next_ds = [self.copy(_) for _ in next_df]
    return next_ds