diff --git a/src/models/input_ids.cpp b/src/models/input_ids.cpp
index b853d16c4..4cba33fe1 100644
--- a/src/models/input_ids.cpp
+++ b/src/models/input_ids.cpp
@@ -8,7 +8,7 @@ namespace Generators {
 DefaultInputIDs::DefaultInputIDs(State& state)
     : state_{state} {
   name_ = model_.config_->model.decoder.inputs.input_ids.c_str();
-  shape_ = {state_.params_->BatchBeamSize(), 0};
+  shape_ = {state_.params_->search.batch_size, 0};
   type_ = model_.session_info_->GetInputDataType(name_);
 
   if (model_.session_info_->HasInput(model_.config_->model.decoder.inputs.current_sequence_length) &&
@@ -47,7 +47,7 @@ void DefaultInputIDs::Add() {
 
 void DefaultInputIDs::Update(DeviceSpan<int32_t>& new_tokens) {
   if (!value_) {
-    shape_[1] = static_cast<int64_t>(new_tokens.size());
+    shape_[1] = static_cast<int64_t>(new_tokens.size()) / shape_[0];
 
     // If 64-bit, convert from 32-bit to 64-bit
     auto input_ids = new_tokens.CopyDeviceToCpu();