Skip to content

Commit

Permalink
[feat] optimize plan batch_size for better sharding plan when use sam…
Browse files Browse the repository at this point in the history
…pler (#8)
  • Loading branch information
tiankongdeguiji authored Oct 11, 2024
1 parent 23eeadc commit 913cd2b
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 2 deletions.
8 changes: 8 additions & 0 deletions tzrec/datasets/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,14 @@ def _build_batch(self, input_data: Dict[str, pa.Array]) -> Batch:
batch = self._data_parser.to_batch(output_data)
return batch

@property
def sampled_batch_size(self) -> int:
"""Batch size with sampler."""
if self._sampler:
return self._batch_size + self._sampler.estimated_sample_num
else:
return self._batch_size


class BaseReader(metaclass=_reader_meta_cls):
"""Reader base class.
Expand Down
37 changes: 37 additions & 0 deletions tzrec/datasets/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,11 @@ def _parse_sparse_nodes(
# pyre-ignore [16]
return features, nodes.indices

@property
def estimated_sample_num(self) -> int:
"""Max number of sampled num examples."""
raise NotImplementedError


class NegativeSampler(BaseSampler):
"""Negative Sampler.
Expand Down Expand Up @@ -363,6 +368,11 @@ def get(self, input_data: Dict[str, pa.Array]) -> Dict[str, pa.Array]:
result_dict = dict(zip(self._attr_names, features))
return result_dict

@property
def estimated_sample_num(self) -> int:
"""Estimated number of sampled num examples."""
return self._num_sample


class NegativeSamplerV2(BaseSampler):
"""Negative Sampler V2.
Expand Down Expand Up @@ -453,6 +463,11 @@ def get(self, input_data: Dict[str, pa.Array]) -> Dict[str, pa.Array]:
result_dict = dict(zip(self._attr_names, features))
return result_dict

@property
def estimated_sample_num(self) -> int:
"""Estimated number of sampled num examples."""
return self._num_sample


class HardNegativeSampler(BaseSampler):
"""HardNegativeSampler.
Expand Down Expand Up @@ -548,6 +563,11 @@ def get(self, input_data: Dict[str, pa.Array]) -> Dict[str, pa.Array]:
result_dict["hard_neg_indices"] = pa.array(hard_neg_indices)
return result_dict

@property
def estimated_sample_num(self) -> int:
"""Estimated number of sampled num examples."""
return self._num_sample + min(self._num_hard_sample, 8) * self._batch_size


class HardNegativeSamplerV2(BaseSampler):
"""HardNegativeSampler.
Expand Down Expand Up @@ -649,6 +669,11 @@ def get(self, input_data: Dict[str, pa.Array]) -> Dict[str, pa.Array]:
result_dict["hard_neg_indices"] = pa.array(hard_neg_indices)
return result_dict

@property
def estimated_sample_num(self) -> int:
"""Estimated number of sampled num examples."""
return self._num_sample + min(self._num_hard_sample, 8) * self._batch_size


class TDMSampler(BaseSampler):
"""TDM training sampler.
Expand Down Expand Up @@ -840,6 +865,13 @@ def get(self, input_data: Dict[str, pa.Array]) -> Dict[str, pa.Array]:

return pos_result_dict, neg_result_dict

@property
def estimated_sample_num(self) -> int:
"""Estimated number of sampled num examples."""
return (
sum(self._layer_num_sample) + len(self._layer_num_sample) - 2
) * self._batch_size


class TDMPredictSampler(BaseSampler):
"""TDM predict sampler.
Expand Down Expand Up @@ -913,3 +945,8 @@ def get(self, input_ids: pa.Array) -> Dict[str, pa.Array]:
pos_result_dict = dict(zip(self._attr_names[1:], pos_fea_result))

return pos_result_dict

@property
def estimated_sample_num(self) -> int:
"""Estimated number of sampled num examples."""
return min((2 ** (self._max_level - 1)), 800) * self._batch_size
6 changes: 6 additions & 0 deletions tzrec/datasets/sampler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ def _sampler_worker(res):
],
batch_size=4,
)
assert sampler.estimated_sample_num == 8
sampler.init_cluster()
sampler.launch_server()
sampler.init()
Expand Down Expand Up @@ -285,6 +286,7 @@ def _sampler_worker(res):
],
batch_size=4,
)
assert sampler.estimated_sample_num == 8
sampler.init_cluster()
sampler.launch_server()
sampler.init()
Expand Down Expand Up @@ -332,6 +334,7 @@ def _sampler_worker(res):
],
batch_size=4,
)
assert sampler.estimated_sample_num == 40
sampler.init_cluster()
sampler.launch_server()
sampler.init()
Expand Down Expand Up @@ -381,6 +384,7 @@ def _sampler_worker(res):
],
batch_size=4,
)
assert sampler.estimated_sample_num == 40
sampler.init_cluster()
sampler.launch_server()
sampler.init()
Expand Down Expand Up @@ -426,6 +430,7 @@ def _sampler_worker(pos_res, neg_res):
],
batch_size=4,
)
assert sampler.estimated_sample_num == 76
sampler.init_cluster()
sampler.launch_server()
sampler.init()
Expand Down Expand Up @@ -487,6 +492,7 @@ def _sampler_worker(res):
],
batch_size=4,
)
assert sampler.estimated_sample_num == 128
sampler.init_cluster()
sampler.launch_server()
sampler.init()
Expand Down
6 changes: 4 additions & 2 deletions tzrec/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -544,7 +544,8 @@ def train_and_evaluate(

planner = create_planner(
device=device,
batch_size=data_config.batch_size,
# pyre-ignore [16]
batch_size=train_dataloader.dataset.sampled_batch_size,
)
plan = planner.collective_plan(
model, get_default_sharders(), dist.GroupMember.WORLD
Expand Down Expand Up @@ -662,7 +663,8 @@ def evaluate(

planner = create_planner(
device=device,
batch_size=data_config.batch_size,
# pyre-ignore [16]
batch_size=eval_dataloader.dataset.sampled_batch_size,
)
plan = planner.collective_plan(
model, get_default_sharders(), dist.GroupMember.WORLD
Expand Down

0 comments on commit 913cd2b

Please sign in to comment.