Skip to content

Commit

Permalink
Allow some known extra inputs in the model (#1167)
Browse files Browse the repository at this point in the history
  • Loading branch information
baijumeswani authored Jan 9, 2025
1 parent 9143cfd commit 43fa6ab
Show file tree
Hide file tree
Showing 5 changed files with 138 additions and 2 deletions.
42 changes: 42 additions & 0 deletions src/models/extra_inputs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,46 @@

namespace Generators {

PresetExtraInputs::PresetExtraInputs(State& state) : state_{state} {
  // Factory for the "num_logits_to_keep" input: a one-element int64 CPU tensor
  // set to 0. Registered here so Add() can claim the input by name when the
  // model declares it but the state does not provide it.
  registry_.emplace("num_logits_to_keep", [this]() -> std::unique_ptr<OrtValue> {
    std::vector<int64_t> shape{1};
    auto tensor = OrtValue::CreateTensor<int64_t>(state_.model_.allocator_cpu_, shape);
    *tensor->GetTensorMutableData<int64_t>() = 0;
    return tensor;
  });
}

void PresetExtraInputs::Add() {
const auto input_names_vector = state_.model_.session_info_->GetInputNames();
const std::unordered_set<std::string> input_names(state_.input_names_.begin(), state_.input_names_.end());
std::vector<std::string> unclaimed_input_names;
// Add any model input for which we don't have a corresponding input in the state to the unclaimed_input_names
for (const auto& input_name : input_names_vector) {
if (input_names.find(input_name) == input_names.end()) {
unclaimed_input_names.push_back(input_name);
}
}

// Try to claim the unclaimed inputs from the registry
for (const auto& input_name : unclaimed_input_names) {
auto it = registry_.find(input_name);
if (it != registry_.end()) {
extra_input_names_.push_back(input_name);
extra_inputs_.push_back(it->second());
state_.input_names_.push_back(extra_input_names_.back().c_str());
state_.inputs_.push_back(extra_inputs_.back().get());
} else if (input_name.rfind("onnx::Neg_", 0) == 0) {
// The unclaimed input has a prefix of onnx::Neg_, which is a special case
// We treat this as an alias to num_logits_to_keep
extra_input_names_.push_back(input_name);
extra_inputs_.push_back(registry_.at("num_logits_to_keep")());
state_.input_names_.push_back(extra_input_names_.back().c_str());
state_.inputs_.push_back(extra_inputs_.back().get());
}
}
}

ExtraInputs::ExtraInputs(State& state)
: state_{state} {
extra_inputs_.reserve(state_.params_->extra_inputs.size());
Expand Down Expand Up @@ -78,6 +118,8 @@ void ExtraInputs::Add() {
throw std::runtime_error("Unsupported device for graph capture");
}
}

registrar_.Add();
}

#pragma warning(pop)
Expand Down
13 changes: 13 additions & 0 deletions src/models/extra_inputs.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,18 @@

namespace Generators {

// Supplies well-known "preset" model inputs (e.g. num_logits_to_keep) that the
// state does not provide explicitly, via a registry of name -> factory
// functions producing the backing OrtValue.
struct PresetExtraInputs {
  PresetExtraInputs(State& state);
  // Scans the model's declared inputs and registers any recognized input that
  // the state has not already claimed.
  void Add();

 private:
  using FuncType = std::function<std::unique_ptr<OrtValue>()>;
  // NOTE: state_ must be declared before registry_ — the constructor's
  // registry_ initializer captures state_.
  State& state_;
  std::unordered_map<std::string, FuncType> registry_;
  std::vector<std::unique_ptr<OrtValue>> extra_inputs_;  // Owns the created tensors handed to the state.
  std::vector<std::string> extra_input_names_;           // Owns the strings whose c_str() the state points at.
};

struct ExtraInputs {
ExtraInputs(State& state);
void Add();
Expand All @@ -14,6 +26,7 @@ struct ExtraInputs {
std::vector<OrtValue*> extra_inputs_;
std::vector<std::unique_ptr<OrtValue>> owned_extra_inputs_;
std::unordered_map<std::string, StaticBuffer*> sb_extra_inputs_;
PresetExtraInputs registrar_{state_};
};

} // namespace Generators
8 changes: 8 additions & 0 deletions src/models/model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,14 @@ ONNXTensorElementDataType SessionInfo::GetOutputDataType(const std::string& name
return result->second;
}

// Returns the names of all inputs known to this session, in the map's
// iteration order.
std::vector<std::string> SessionInfo::GetInputNames() const {
  std::vector<std::string> input_names;
  input_names.reserve(inputs_.size());
  for (const auto& [name, type] : inputs_)
    input_names.emplace_back(name);
  return input_names;
}

// Takes ownership of the parsed configuration and prepares the ORT session
// options used when the model's sessions are created.
Model::Model(std::unique_ptr<Config> config) : config_{std::move(config)} {
  CreateSessionOptions();
}
Expand Down
2 changes: 2 additions & 0 deletions src/models/model.h
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,8 @@ struct SessionInfo {
ONNXTensorElementDataType GetInputDataType(const std::string& name) const;
ONNXTensorElementDataType GetOutputDataType(const std::string& name) const;

std::vector<std::string> GetInputNames() const;

private:
std::unordered_map<std::string, ONNXTensorElementDataType> inputs_, outputs_;
};
Expand Down
75 changes: 73 additions & 2 deletions test/python/test_onnxruntime_genai_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -628,12 +628,83 @@ def _export_adapter(adapter, adapter_file_name):
params = og.GeneratorParams(model)
params.set_search_options(max_length=20, batch_size=len(prompts))

print(len(adapter_paths))

generator = og.Generator(model, params)
for i in range(len(adapter_paths)):
generator.set_active_adapter(adapters, f"adapter_{i}")

generator.append_tokens(tokenizer.encode_batch(prompts))
while not generator.is_done():
generator.generate_next_token()


@pytest.mark.parametrize("device", devices)
@pytest.mark.skipif(
    sysconfig.get_platform().endswith("arm64"),
    reason="ONNX is not available on ARM64",
)
@pytest.mark.parametrize("extra_inputs", [("num_logits_to_keep", True), ("onnx::Neg_67", True), ("abcde", False)])
def test_preset_extra_inputs(test_data_path, device, phi2_for, extra_inputs):
    input_name, is_known_input = extra_inputs

    def _build_model_with_extra_input(destination_root):
        """Copy the phi-2 model and graft `input_name` onto its graph as an extra input."""
        model_dir = os.fspath(Path(destination_root) / "preset_extra_inputs")
        shutil.copytree(phi2_for(device), model_dir, dirs_exist_ok=True)

        onnx_model = onnx.load(Path(model_dir) / "model.onnx")

        # Detach the original logits output so it can be re-derived below.
        for node in onnx_model.graph.node:
            if node.name == "/lm_head/Add":
                node.output[0] = "logits_0"
                break

        # Declare the extra scalar int64 input on the graph.
        onnx_model.graph.input.append(
            onnx.helper.make_tensor_value_info(input_name, onnx.TensorProto.INT64, [])
        )

        # logits = Cast(extra_input) + logits_0, so the extra input is required at run time.
        target_type = onnx.TensorProto.FLOAT if device == "cpu" else onnx.TensorProto.FLOAT16
        onnx_model.graph.node.extend(
            [
                onnx.helper.make_node("Cast", [input_name], [f"{input_name}_cast"], to=target_type),
                onnx.helper.make_node("Add", [f"{input_name}_cast", "logits_0"], ["logits"], name="add_to_logits"),
            ]
        )

        onnx.save(
            onnx_model,
            Path(model_dir) / "model.onnx",
            save_as_external_data=True,
            location="model.data",
        )
        return model_dir

    model = og.Model(_build_model_with_extra_input(test_data_path))
    tokenizer = og.Tokenizer(model)
    prompts = [
        "This is a test.",
        "Rats are awesome pets!",
        "The quick brown fox jumps over the lazy dog.",
    ]

    params = og.GeneratorParams(model)
    params.set_search_options(max_length=20, batch_size=len(prompts))

    generator = og.Generator(model, params)
    if is_known_input:
        # A preset (or aliased) input is claimed automatically; generation succeeds.
        generator.append_tokens(tokenizer.encode_batch(prompts))
        while not generator.is_done():
            generator.generate_next_token()
    else:
        # An unrecognized extra input is left unclaimed and ORT reports it missing.
        with pytest.raises(og.OrtException) as exc_info:
            generator.append_tokens(tokenizer.encode_batch(prompts))

        assert f"Missing Input: {input_name}" in str(exc_info.value)

0 comments on commit 43fa6ab

Please sign in to comment.