implement fix for set inputs (#1152)
### Fix SetInputs for Multi-Modal Models

#### Description
This PR addresses an issue where `SetInputs`, which is used with
multi-modal models like phi3v and whisper, was broken. The error
encountered was:

```
RuntimeError: Please use params.SetInputs for phi3v. AppendTokens is not supported for this model type.
```

#### Root Cause
`AppendTokens` is a public API intended only for models that support it,
and it throws for model types that do not. For multi-modal models,
however, the `Generator` constructor called `AppendTokens` internally to
process the inputs, so constructing the generator raised the runtime
error above.

#### Fix
This PR stops calling `AppendTokens` internally. The constructor now
calls a new private helper, `AuxAppendTokens`, which fixes `SetInputs`
on multi-modal models. It also fixes #1151.
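
For readers who want the pattern in isolation, here is a minimal sketch, not the actual implementation. `ToyGenerator`, `ConstructionThrows`, `CheckSketch`, the `use_aux_path` flag, and the string-based model-type check are all hypothetical names made up for illustration; only `AppendTokens` and `AuxAppendTokens` come from this PR. Unlike the real change, the toy `AppendTokens` delegates to the helper rather than duplicating its body.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for Generator. Only the method names AppendTokens and
// AuxAppendTokens mirror the PR; everything else is illustrative.
class ToyGenerator {
 public:
  ToyGenerator(std::string model_type, const std::vector<int32_t>& aux_input_ids,
               bool use_aux_path)
      : model_type_(std::move(model_type)) {
    if (!aux_input_ids.empty()) {
      if (use_aux_path)
        AuxAppendTokens(aux_input_ids);  // after the fix: unguarded private helper
      else
        AppendTokens(aux_input_ids);     // before the fix: guarded public API
    }
  }

  // Public API: rejects model types that must go through SetInputs.
  // (Assumption: the real guard is not a string comparison.)
  void AppendTokens(const std::vector<int32_t>& input_ids) {
    if (model_type_ == "phi3v" || model_type_ == "whisper")
      throw std::runtime_error("AppendTokens is not supported for this model type.");
    AuxAppendTokens(input_ids);
  }

  std::size_t SequenceLength() const { return sequence_.size(); }

 private:
  // Internal path: same token handling, no model-type guard.
  void AuxAppendTokens(const std::vector<int32_t>& input_ids) {
    if (input_ids.empty())
      throw std::runtime_error("input_ids is empty");
    sequence_.insert(sequence_.end(), input_ids.begin(), input_ids.end());
  }

  std::string model_type_;
  std::vector<int32_t> sequence_;
};

// Returns true if constructing a "phi3v" toy generator with inputs throws.
bool ConstructionThrows(bool use_aux_path) {
  try {
    ToyGenerator g("phi3v", {1, 2, 3}, use_aux_path);
    return false;
  } catch (const std::runtime_error&) {
    return true;
  }
}

// Documents the expected behaviour of the two constructor paths.
inline void CheckSketch() {
  assert(ConstructionThrows(false));  // old path hits the public guard
  assert(!ConstructionThrows(true));  // new path bypasses it
}
```

Calling `CheckSketch()` from a test or `main` exercises both paths. The guard stays on the public entry point, so external callers of `AppendTokens` on an unsupported model still get the error.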
aciddelgado authored Dec 18, 2024
1 parent c77e525 commit 12999f3
Showing 2 changed files with 16 additions and 1 deletion.
16 changes: 15 additions & 1 deletion src/generators.cpp
@@ -262,7 +262,7 @@ Generator::Generator(const Model& model, const GeneratorParams& params) : model_

   // Temporary solution for multimodal and whisper models
   if (!params.aux_input_ids.empty() && params.aux_input_ids.data() != nullptr) {
-    AppendTokens(params.aux_input_ids);
+    AuxAppendTokens(params.aux_input_ids);
   }
 }

@@ -274,6 +274,20 @@ DeviceSpan<int32_t> Generator::AllocateInputIdsOnDevice(const cpu_span<int32_t>
   return input_ids_device;
 }

+// TODO(aciddelgado): Remove this function once SetInputs is moved to generator
+void Generator::AuxAppendTokens(const cpu_span<int32_t> input_ids) {
+  ThrowErrorIfSessionTerminated(state_->session_terminated_);
+  if (input_ids.size() == 0)
+    throw std::runtime_error("input_ids is empty");
+  if (search_->GetSequenceLength() != 0 && state_->params_->search.batch_size > 1)
+    throw std::runtime_error("AppendTokens can only be called once for batch_size > 1. To call AppendTokens again, use RewindToLength(0)");
+
+  auto input_ids_device = AllocateInputIdsOnDevice(input_ids);
+  search_->AppendTokens(input_ids_device);
+  computed_logits_ = false;
+  ComputeLogits(input_ids_device);
+}
+
 void Generator::AppendTokens(const cpu_span<int32_t> input_ids) {
   ThrowErrorIfSessionTerminated(state_->session_terminated_);
   if (input_ids.size() == 0)
1 change: 1 addition & 0 deletions src/generators.h
@@ -128,6 +128,7 @@ struct Generator : LeakChecked<Generator> {

  private:
   DeviceSpan<int32_t> AllocateInputIdsOnDevice(const cpu_span<int32_t> input_ids);
+  void AuxAppendTokens(const cpu_span<int32_t> input_ids);
   void ComputeLogits(DeviceSpan<int32_t> next_tokens);
   enum Action { standard,   // Default, set in any other case
                 generated,  // Set after GenerateNextToken