From 58fedf7380022b2bb3322727330c3204df490c89 Mon Sep 17 00:00:00 2001
From: Baiju Meswani <bmeswani@microsoft.com>
Date: Thu, 19 Dec 2024 14:18:00 -0800
Subject: [PATCH] Address a DML regression caused by the continuous decoding
 changes
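
The continuous decoding changes regressed the DML path. This patch:

- Dispatches on the DML device type explicitly when dumping tensors and in
  OgaGenerator_GetOutput. DML allocations report a device type of 4, which
  OrtMemoryInfoDeviceType does not expose.
- Allocates past_sequence_length on the CPU allocator, since its contents are
  written directly on the CPU.
- Defers the static-buffer lookup for input_ids and logits from the
  constructors to the first Update call.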

---
 src/models/debugging.cpp | 75 ++++++++++++++++++++--------------------
 src/models/input_ids.cpp | 46 ++++++++++++++++++------
 src/models/logits.cpp    | 21 ++++++-----
 src/ort_genai_c.cpp      | 19 ++++++----
 4 files changed, 97 insertions(+), 64 deletions(-)

diff --git a/src/models/debugging.cpp b/src/models/debugging.cpp
index 11056bfde..f986b8688 100644
--- a/src/models/debugging.cpp
+++ b/src/models/debugging.cpp
@@ -88,47 +88,48 @@ void DumpTensor(const Model& model, std::ostream& stream, OrtValue* value, bool
   stream << SGR::Fg_Green << " Location: " << SGR::Reset;
 
   const auto& memory_info = value->GetTensorMemoryInfo();
-  switch (memory_info.GetDeviceType()) {
-    case OrtMemoryInfoDeviceType_CPU:
-      stream << "CPU\r\n";
-      DumpValues(stream, type_info->GetElementType(), value->GetTensorRawData(), element_count);
-      break;
-    case OrtMemoryInfoDeviceType_GPU: {
-      stream << "GPU\r\n";
+  auto device_type = memory_info.GetDeviceType();
+  if (device_type == OrtMemoryInfoDeviceType_CPU) {
+    stream << "CPU\r\n";
+    DumpValues(stream, type_info->GetElementType(), value->GetTensorRawData(), element_count);
+  } else if (device_type == OrtMemoryInfoDeviceType_GPU) {
+    stream << "GPU\r\n";
 #if USE_CUDA
-      auto type = type_info->GetElementType();
-      size_t element_size = SizeOf(type);
-      auto cpu_copy = std::make_unique<uint8_t[]>(element_size * element_count);
-      CudaCheck() == cudaMemcpy(cpu_copy.get(), value->GetTensorRawData(), element_size * element_count, cudaMemcpyDeviceToHost);
-      DumpValues(stream, type, cpu_copy.get(), element_count);
-#elif USE_DML
-      auto type = type_info->GetElementType();
-      size_t element_size = SizeOf(type);
-      auto cpu_copy = std::make_unique<uint8_t[]>(element_size * element_count);
-
-      if (value->GetTensorMutableRawData()) {
-        ComPtr<ID3D12Resource> gpu_resource;
-        Ort::ThrowOnError(model.GetOrtDmlApi()->GetD3D12ResourceFromAllocation(
-            model.allocator_device_,
-            value->GetTensorMutableRawData(),
-            &gpu_resource));
-
-        model.GetDmlReadbackHeap()->ReadbackFromGpu(
-            std::span(cpu_copy.get(), element_size * element_count),
-            gpu_resource.Get(),
-            0,
-            D3D12_RESOURCE_STATE_UNORDERED_ACCESS);
-      }
-
-      DumpValues(stream, type, cpu_copy.get(), element_count);
+    auto type = type_info->GetElementType();
+    size_t element_size = SizeOf(type);
+    auto cpu_copy = std::make_unique<uint8_t[]>(element_size * element_count);
+    CudaCheck() == cudaMemcpy(cpu_copy.get(), value->GetTensorRawData(), element_size * element_count, cudaMemcpyDeviceToHost);
+    DumpValues(stream, type, cpu_copy.get(), element_count);
 #else
-      stream << "Unexpected, using GPU memory but not compiled with CUDA or DML?";
+    throw std::runtime_error("Unexpected error. Trying to access GPU memory but the project is not compiled with CUDA.");
 #endif
-      break;
+  } else if (static_cast<int>(device_type) == 4) {  // OrtMemoryInfoDeviceType has no DML enumerator; DML allocations report device type 4
+    stream << "DML\r\n";
+#if USE_DML
+    auto type = type_info->GetElementType();
+    size_t element_size = SizeOf(type);
+    auto cpu_copy = std::make_unique<uint8_t[]>(element_size * element_count);
+
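+    // An empty tensor has no backing allocation to read back; cpu_copy is
+    // value-initialized, so the dump below then prints zeros.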
+    if (value->GetTensorMutableRawData()) {
+      ComPtr<ID3D12Resource> gpu_resource;
+      Ort::ThrowOnError(model.GetOrtDmlApi()->GetD3D12ResourceFromAllocation(
+          model.allocator_device_,
+          value->GetTensorMutableRawData(),
+          &gpu_resource));
+
+      model.GetDmlReadbackHeap()->ReadbackFromGpu(
+          std::span(cpu_copy.get(), element_size * element_count),
+          gpu_resource.Get(),
+          0,
+          D3D12_RESOURCE_STATE_UNORDERED_ACCESS);
     }
-    default:
-      stream << "Unhandled device type";
-      break;
+
+    DumpValues(stream, type, cpu_copy.get(), element_count);
+#else
+    throw std::runtime_error("Unexpected error. Trying to access DML memory but the project is not compiled with DML.");
+#endif
+  } else {
+    stream << "Unhandled device type: " << static_cast<int>(device_type) << "\r\n";
   }
 }
 
diff --git a/src/models/input_ids.cpp b/src/models/input_ids.cpp
index f99907b59..b853d16c4 100644
--- a/src/models/input_ids.cpp
+++ b/src/models/input_ids.cpp
@@ -11,16 +11,6 @@ DefaultInputIDs::DefaultInputIDs(State& state)
   shape_ = {state_.params_->BatchBeamSize(), 0};
   type_ = model_.session_info_->GetInputDataType(name_);
 
-  if (state_.GetCapturedGraphInfo()) {
-    sb_input_ids_ = state_.GetCapturedGraphInfo()->sb_input_ids_.get();
-
-#if USE_DML
-    if (model_.device_type_ == DeviceType::DML) {
-      sb_input_ids_int32_ = state_.GetCapturedGraphInfo()->sb_input_ids_int32_.get();
-    }
-#endif
-  }
-
   if (model_.session_info_->HasInput(model_.config_->model.decoder.inputs.current_sequence_length) &&
       model_.session_info_->HasInput(model_.config_->model.decoder.inputs.past_sequence_length)) {
     if (state_.params_->BatchBeamSize() != 1) {
@@ -36,7 +26,7 @@ DefaultInputIDs::DefaultInputIDs(State& state)
     current_sequence_length_ = OrtValue::CreateTensor(model_.allocator_cpu_, current_sequence_length_shape, model_.session_info_->GetInputDataType(model_.config_->model.decoder.inputs.current_sequence_length));
     *current_sequence_length_->GetTensorMutableData<int32_t>() = 0;
 
-    past_sequence_length_ = OrtValue::CreateTensor(*model_.allocator_device_, past_sequence_length_shape, model_.session_info_->GetInputDataType(model_.config_->model.decoder.inputs.past_sequence_length));
+    past_sequence_length_ = OrtValue::CreateTensor(model_.allocator_cpu_, past_sequence_length_shape, model_.session_info_->GetInputDataType(model_.config_->model.decoder.inputs.past_sequence_length));
     *past_sequence_length_->GetTensorMutableData<int32_t>() = -1;
   }
 }
@@ -56,6 +46,40 @@ void DefaultInputIDs::Add() {
 }
 
 void DefaultInputIDs::Update(DeviceSpan<int32_t>& new_tokens) {
+  if (!value_) {
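+    // First Update call: materialize the input_ids tensor directly from the
+    // incoming tokens instead of reusing a previous buffer.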
+    shape_[1] = static_cast<int64_t>(new_tokens.size());
+
+    // If 64-bit, convert from 32-bit to 64-bit
+    auto input_ids = new_tokens.CopyDeviceToCpu();
+    if (type_ == Ort::TypeToTensorType<int64_t>) {
+      value_ = OrtValue::CreateTensor(model_.allocator_cpu_, shape_, type_);
+      auto* p_data = value_->GetTensorMutableData<int64_t>();
+      for (auto v : input_ids) {
+        *p_data++ = v;
+      }
+    } else {
+      if (type_ != Ort::TypeToTensorType<int32_t>)
+        throw std::runtime_error("InputIDs must be int64 or int32");
+      value_ = OrtValue::CreateTensor<int32_t>(model_.allocator_cpu_.GetInfo(), input_ids, shape_);
+    }
+
+    value_ = model_.ExpandInputs(value_, state_.params_->search.num_beams);
+    shape_[0] *= state_.params_->search.num_beams;
+
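+    // Deferred from the constructor: pick up the pre-allocated static buffers
+    // on the first Update.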
+    if (state_.GetCapturedGraphInfo()) {
+      sb_input_ids_ = state_.GetCapturedGraphInfo()->sb_input_ids_.get();
+
+#if USE_DML
+      if (model_.device_type_ == DeviceType::DML) {
+        sb_input_ids_int32_ = state_.GetCapturedGraphInfo()->sb_input_ids_int32_.get();
+      }
+#endif
+    }
+
+    state_.inputs_[input_index_] = value_.get();
+    return;
+  }
+
   const auto get_unpadded_sequence_length = [](std::span<const int32_t> input_ids,
                                                int32_t pad_token_id) {
     int32_t seq_length = 0;
diff --git a/src/models/logits.cpp b/src/models/logits.cpp
index 9e5be5b57..edaf95d1a 100644
--- a/src/models/logits.cpp
+++ b/src/models/logits.cpp
@@ -16,15 +16,6 @@ Logits::Logits(State& state)
       type_{model_.session_info_->GetOutputDataType(model_.config_->model.decoder.outputs.logits)} {
   output_raw_ = OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_);
 
-  if (state_.GetCapturedGraphInfo()) {
-    if (type_ == Ort::TypeToTensorType<float>) {
-      sb_logits32_ = state_.GetCapturedGraphInfo()->sb_logits32_.get();
-    }
-    if (type_ == Ort::TypeToTensorType<Ort::Float16_t>) {
-      sb_logits16_ = state_.GetCapturedGraphInfo()->sb_logits16_.get();
-    }
-  }
-
 #if USE_CUDA
   if (model_.device_type_ == DeviceType::CUDA && !model_.config_->model.eos_token_ids.empty()) {
     auto& cpu_ids = model_.config_->model.eos_token_ids;
@@ -215,6 +206,18 @@ void Logits::Update(const DeviceSpan<int32_t>& next_tokens, size_t new_kv_length
   StaticBuffer* sb_logits = type_ == Ort::TypeToTensorType<Ort::Float16_t> ? sb_logits16_ : sb_logits32_;
   output_raw_ = !sb_logits ? OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_)
                            : sb_logits->CreateTensorOnStaticBuffer(shape_, type_);
+
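+  // Deferred from the constructor: bind the static logits buffers once. sb_logits
+  // above is still null on the first call, so they take effect on the next Update.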
+  if (state_.GetCapturedGraphInfo()) {
+    if (!sb_logits16_ && !sb_logits32_) {
+      if (type_ == Ort::TypeToTensorType<float>) {
+        sb_logits32_ = state_.GetCapturedGraphInfo()->sb_logits32_.get();
+      }
+      if (type_ == Ort::TypeToTensorType<Ort::Float16_t>) {
+        sb_logits16_ = state_.GetCapturedGraphInfo()->sb_logits16_.get();
+      }
+    }
+  }
+
   state_.outputs_[output_index_] = output_raw_.get();
 }
 
diff --git a/src/ort_genai_c.cpp b/src/ort_genai_c.cpp
index 3f2c9d750..c4bcf3015 100644
--- a/src/ort_genai_c.cpp
+++ b/src/ort_genai_c.cpp
@@ -337,11 +337,18 @@ OgaResult* OGA_API_CALL OgaGenerator_GetOutput(const OgaGenerator* oga_generator
   // Copy data to ortvalue_clone
   auto element_size = Generators::SizeOf(type_info->GetElementType());
   auto data_size = type_info->GetElementCount() * element_size;
-  if (ortvalue_output->GetTensorMemoryInfo().GetDeviceType() == OrtMemoryInfoDeviceType_GPU && generator.model_->device_type_ == Generators::DeviceType::CUDA) {
+  const auto device_type = ortvalue_output->GetTensorMemoryInfo().GetDeviceType();
+  if (device_type == OrtMemoryInfoDeviceType_CPU) {
+    std::copy(static_cast<uint8_t*>(ortvalue_output->GetTensorMutableRawData()),
+              static_cast<uint8_t*>(ortvalue_output->GetTensorMutableRawData()) + data_size,
+              static_cast<uint8_t*>(ortvalue_clone->GetTensorMutableRawData()));
+  } else if (device_type == OrtMemoryInfoDeviceType_GPU) {
 #if USE_CUDA
     cudaMemcpy(ortvalue_clone->GetTensorMutableRawData(), ortvalue_output->GetTensorMutableRawData(), data_size, cudaMemcpyDeviceToHost);
+#else
+    throw std::runtime_error("Unexpected error. Trying to access GPU memory but the project is not compiled with CUDA.");
 #endif
-  } else if (ortvalue_output->GetTensorMemoryInfo().GetDeviceType() == OrtMemoryInfoDeviceType_GPU && generator.model_->device_type_ == Generators::DeviceType::DML) {
+  } else if (static_cast<int>(device_type) == 4) {  // DML device type (4) is not exposed through OrtMemoryInfoDeviceType
 #if USE_DML
     ComPtr<ID3D12Resource> gpu_resource;
     Ort::ThrowOnError(generator.model_->GetOrtDmlApi()->GetD3D12ResourceFromAllocation(
@@ -354,13 +361,11 @@ OgaResult* OGA_API_CALL OgaGenerator_GetOutput(const OgaGenerator* oga_generator
         gpu_resource.Get(),
         0,
         D3D12_RESOURCE_STATE_UNORDERED_ACCESS);
+#else
+    throw std::runtime_error("Unexpected error. Trying to access DML memory but the project is not compiled with DML.");
 #endif
-  } else if (ortvalue_output->GetTensorMemoryInfo().GetDeviceType() == OrtMemoryInfoDeviceType_CPU) {
-    std::copy(static_cast<uint8_t*>(ortvalue_output->GetTensorMutableRawData()),
-              static_cast<uint8_t*>(ortvalue_output->GetTensorMutableRawData()) + data_size,
-              static_cast<uint8_t*>(ortvalue_clone->GetTensorMutableRawData()));
   } else {
-    throw std::runtime_error("Unsupported Device type: " + std::to_string(ortvalue_output->GetTensorMemoryInfo().GetDeviceType()));
+    throw std::runtime_error("Unsupported device type: " + static_cast<int>(device_type));
   }
 
   auto tensor = std::make_shared<Generators::Tensor>(std::move(ortvalue_clone));