
Fix batch size issues
baijumeswani committed Dec 20, 2024
1 parent 58fedf7 commit b0624a2
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/models/input_ids.cpp
@@ -8,7 +8,7 @@ namespace Generators {
 DefaultInputIDs::DefaultInputIDs(State& state)
     : state_{state} {
   name_ = model_.config_->model.decoder.inputs.input_ids.c_str();
-  shape_ = {state_.params_->BatchBeamSize(), 0};
+  shape_ = {state_.params_->search.batch_size, 0};
   type_ = model_.session_info_->GetInputDataType(name_);
 
   if (model_.session_info_->HasInput(model_.config_->model.decoder.inputs.current_sequence_length) &&
@@ -47,7 +47,7 @@ void DefaultInputIDs::Add() {
 
 void DefaultInputIDs::Update(DeviceSpan<int32_t>& new_tokens) {
   if (!value_) {
-    shape_[1] = static_cast<int64_t>(new_tokens.size());
+    shape_[1] = static_cast<int64_t>(new_tokens.size()) / shape_[0];
 
     // If 64-bit, convert from 32-bit to 64-bit
     auto input_ids = new_tokens.CopyDeviceToCpu();
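
For context on the fix: with shape_[0] now set to search.batch_size rather than BatchBeamSize() (typically batch_size * num_beams), the update step treats new_tokens as a flat buffer of batch_size * sequence_length token IDs and divides its length by the batch size to recover the per-sequence token count, instead of using the raw buffer size as the sequence length. Below is a minimal standalone sketch of that shape arithmetic, not the Generators API; the helper name ComputeInputIdsShape and the example values are hypothetical, and it assumes the flat buffer length is an exact multiple of the batch size.

// Minimal standalone sketch (illustration only, not the Generators code path):
// derive the 2-D input_ids shape {batch_size, sequence_length} from a flat
// per-batch token buffer, using the same arithmetic as the fixed Update() above.
#include <array>
#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

std::array<int64_t, 2> ComputeInputIdsShape(const std::vector<int32_t>& new_tokens,
                                            int64_t batch_size) {
  // The flat buffer is assumed to hold batch_size * sequence_length IDs,
  // so dividing its length by the batch size recovers the per-sequence count.
  assert(batch_size > 0 && new_tokens.size() % static_cast<size_t>(batch_size) == 0);
  const int64_t sequence_length = static_cast<int64_t>(new_tokens.size()) / batch_size;
  return {batch_size, sequence_length};
}

int main() {
  // Two sequences of three token IDs each, flattened into one buffer.
  const std::vector<int32_t> new_tokens = {101, 7592, 102, 101, 2088, 102};
  const auto shape = ComputeInputIdsShape(new_tokens, /*batch_size=*/2);
  std::cout << "shape = {" << shape[0] << ", " << shape[1] << "}\n";  // prints shape = {2, 3}
  return 0;
}

The division is only meaningful when every sequence in the batch contributes the same number of new tokens, which is the packing the flat token buffer is assumed to use here.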
