From 79350f33a5d5900d31ede2d0740312508e479a2c Mon Sep 17 00:00:00 2001 From: Aakash Thatte Date: Sun, 15 Dec 2024 22:19:00 +0530 Subject: [PATCH] add tests for sampling params --- tests/test_samplers.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/tests/test_samplers.py b/tests/test_samplers.py index 88cdb0fbc..10a7be26f 100644 --- a/tests/test_samplers.py +++ b/tests/test_samplers.py @@ -47,6 +47,13 @@ def test_greedy(): assert ancestors.equal(torch.tensor([0, 1])) assert weights.equal(torch.tensor([logprobs[0, 0], logprobs[1, 2]])) + params = sampler.sampling_params + assert params.sampler == "greedy" + assert params.num_samples == 1 + assert params.top_p is None + assert params.top_k is None + assert params.temperature == 0.0 + def test_multinomial(): rng = torch.Generator() @@ -72,6 +79,14 @@ def test_multinomial(): assert ancestors.equal(torch.tensor([0, 1])) assert weights.equal(torch.tensor([logprobs[0, 0], logprobs[1, 2]])) + sampler = MultinomialSampler(samples=5, top_k=10, top_p=0.9, temperature=0.8) + params = sampler.sampling_params + assert params.sampler == "multinomial" + assert params.num_samples == 5 + assert params.top_p == 0.9 + assert params.top_k == 10 + assert params.temperature == 0.8 + def test_multinomial_init(): sampler = MultinomialSampler() @@ -252,3 +267,11 @@ def test_beam_search(): ] ) ) + + sampler = BeamSearchSampler(beams=3) + params = sampler.sampling_params + assert params.sampler == "beam_search" + assert params.num_samples == 3 + assert params.top_p is None + assert params.top_k is None + assert params.temperature == 1.0