diff --git a/megatron/text_generation_utils.py b/megatron/text_generation_utils.py index 7b7a390ab..bec96284a 100644 --- a/megatron/text_generation_utils.py +++ b/megatron/text_generation_utils.py @@ -82,6 +82,8 @@ def filter_logits(logits, top_k=0, top_p=0.0, filter_value=-float("Inf")): This function has been mostly taken from huggingface conversational ai code at https://medium.com/huggingface/how-to-build-a-state-of-the-art-conversational-ai-with-transfer-learning-2d818ac26313 + When both top_k and top_p are specified, tokens are first filtered according to top_k, renormalized, and then filtered according to top_p. + logits: torch.Tensor -> logits of megatron model. top_k: integer -> integer between 0 and the models vocab size. Filters out any logits with a probability less than that of the top_kth token. top_p: float -> Top-p (nucleus) sampling chooses from the smallest possible set of tokens whose cumulative probability exceeds the probability top_p.