Skip to content

Commit

Permalink
implement fix for set inputs
Browse files Browse the repository at this point in the history
  • Loading branch information
aciddelgado committed Dec 16, 2024
1 parent 7735e10 commit 3fc4ba1
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 1 deletion.
16 changes: 15 additions & 1 deletion src/generators.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ Generator::Generator(const Model& model, const GeneratorParams& params) : model_

// Temporary solution for multimodal and whisper models
if (!params.aux_input_ids.empty() && params.aux_input_ids.data() != nullptr) {
AppendTokens(params.aux_input_ids);
AuxAppendTokens(params.aux_input_ids);
}
}

Expand All @@ -274,6 +274,20 @@ DeviceSpan<int32_t> Generator::AllocateInputIdsOnDevice(const cpu_span<int32_t>
return input_ids_device;
}

// TODO(aciddelgado): Remove this function once SetInputs is moved to generator
void Generator::AuxAppendTokens(const cpu_span<int32_t> input_ids) {
ThrowErrorIfSessionTerminated(state_->session_terminated_);
if (input_ids.size() == 0)
throw std::runtime_error("input_ids is empty");
if (search_->GetSequenceLength() != 0 && state_->params_->search.batch_size > 1)
throw std::runtime_error("AppendTokens can only be called once for batch_size > 1. To call AppendTokens again, use RewindToLength(0)");

auto input_ids_device = AllocateInputIdsOnDevice(input_ids);
search_->AppendTokens(input_ids_device);
computed_logits_ = false;
ComputeLogits(input_ids_device);
}

void Generator::AppendTokens(const cpu_span<int32_t> input_ids) {
ThrowErrorIfSessionTerminated(state_->session_terminated_);
if (input_ids.size() == 0)
Expand Down
1 change: 1 addition & 0 deletions src/generators.h
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ struct Generator : LeakChecked<Generator> {

private:
DeviceSpan<int32_t> AllocateInputIdsOnDevice(const cpu_span<int32_t> input_ids);
void AuxAppendTokens(const cpu_span<int32_t> input_ids);
void ComputeLogits(DeviceSpan<int32_t> next_tokens);
enum Action { standard, // Default, set in any other case
generated, // Set after GenerateNextToken
Expand Down

0 comments on commit 3fc4ba1

Please sign in to comment.