Skip to content

Commit

Permalink
Let directory live
Browse files Browse the repository at this point in the history
  • Loading branch information
baijumeswani committed Jan 9, 2025
1 parent 5993ded commit d0ec019
Showing 1 changed file with 25 additions and 23 deletions.
48 changes: 25 additions & 23 deletions test/python/test_onnxruntime_genai_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -644,9 +644,12 @@ def _export_adapter(adapter, adapter_file_name):
reason="ONNX is not available on ARM64",
)
@pytest.mark.parametrize("extra_inputs", [("num_logits_to_keep", True), ("onnx::Neg_67", True), ("abcde", False)])
def test_preset_extra_inputs(device, phi2_for, extra_inputs):
def _prepare_model(extra_inputs_model_path):
def test_preset_extra_inputs(test_data_path, device, phi2_for, extra_inputs):
def _prepare_model(test_data_path):
phi2_model_path = phi2_for(device)
relative_model_path = "preset_extra_inputs"
extra_inputs_model_path = os.fspath(Path(test_data_path) / relative_model_path)

shutil.copytree(phi2_model_path, extra_inputs_model_path, dirs_exist_ok=True)

# Create the model with the extra inputs
Expand Down Expand Up @@ -681,29 +684,28 @@ def _prepare_model(extra_inputs_model_path):
location="model.data",
)

return valid

with tempfile.TemporaryDirectory() as model_path:
valid_model = _prepare_model(model_path)
model = og.Model(model_path)
tokenizer = og.Tokenizer(model)
prompts = [
"This is a test.",
"Rats are awesome pets!",
"The quick brown fox jumps over the lazy dog.",
]
return extra_inputs_model_path, valid

params = og.GeneratorParams(model)
params.set_search_options(max_length=20, batch_size=len(prompts))
model_path, valid_model = _prepare_model(test_data_path)
model = og.Model(model_path)
tokenizer = og.Tokenizer(model)
prompts = [
"This is a test.",
"Rats are awesome pets!",
"The quick brown fox jumps over the lazy dog.",
]

generator = og.Generator(model, params)
if not valid_model:
with pytest.raises(og.OrtException) as exc_info:
generator.append_tokens(tokenizer.encode_batch(prompts))
params = og.GeneratorParams(model)
params.set_search_options(max_length=20, batch_size=len(prompts))

assert f"Missing Input: {extra_inputs[0]}" in str(exc_info.value)
else:
generator = og.Generator(model, params)
if not valid_model:
with pytest.raises(og.OrtException) as exc_info:
generator.append_tokens(tokenizer.encode_batch(prompts))

while not generator.is_done():
generator.generate_next_token()
assert f"Missing Input: {extra_inputs[0]}" in str(exc_info.value)
else:
generator.append_tokens(tokenizer.encode_batch(prompts))

while not generator.is_done():
generator.generate_next_token()

0 comments on commit d0ec019

Please sign in to comment.