Skip to content

Commit

Permalink
add tests for sampling params
Browse files Browse the repository at this point in the history
  • Loading branch information
sky-2002 committed Dec 15, 2024
1 parent b5823a6 commit 79350f3
Showing 1 changed file with 23 additions and 0 deletions.
23 changes: 23 additions & 0 deletions tests/test_samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,13 @@ def test_greedy():
assert ancestors.equal(torch.tensor([0, 1]))
assert weights.equal(torch.tensor([logprobs[0, 0], logprobs[1, 2]]))

params = sampler.sampling_params
assert params.sampler == "greedy"
assert params.num_samples == 1
assert params.top_p is None
assert params.top_k is None
assert params.temperature == 0.0


def test_multinomial():
rng = torch.Generator()
Expand All @@ -72,6 +79,14 @@ def test_multinomial():
assert ancestors.equal(torch.tensor([0, 1]))
assert weights.equal(torch.tensor([logprobs[0, 0], logprobs[1, 2]]))

sampler = MultinomialSampler(samples=5, top_k=10, top_p=0.9, temperature=0.8)
params = sampler.sampling_params
assert params.sampler == "multinomial"
assert params.num_samples == 5
assert params.top_p == 0.9
assert params.top_k == 10
assert params.temperature == 0.8


def test_multinomial_init():
sampler = MultinomialSampler()
Expand Down Expand Up @@ -252,3 +267,11 @@ def test_beam_search():
]
)
)

sampler = BeamSearchSampler(beams=3)
params = sampler.sampling_params
assert params.sampler == "beam_search"
assert params.num_samples == 3
assert params.top_p is None
assert params.top_k is None
assert params.temperature == 1.0

0 comments on commit 79350f3

Please sign in to comment.