diff --git a/src/models/debugging.cpp b/src/models/debugging.cpp index 11056bfde..f986b8688 100644 --- a/src/models/debugging.cpp +++ b/src/models/debugging.cpp @@ -88,47 +88,48 @@ void DumpTensor(const Model& model, std::ostream& stream, OrtValue* value, bool stream << SGR::Fg_Green << " Location: " << SGR::Reset; const auto& memory_info = value->GetTensorMemoryInfo(); - switch (memory_info.GetDeviceType()) { - case OrtMemoryInfoDeviceType_CPU: - stream << "CPU\r\n"; - DumpValues(stream, type_info->GetElementType(), value->GetTensorRawData(), element_count); - break; - case OrtMemoryInfoDeviceType_GPU: { - stream << "GPU\r\n"; + auto device_type = memory_info.GetDeviceType(); + if (device_type == OrtMemoryInfoDeviceType_CPU) { + stream << "CPU\r\n"; + DumpValues(stream, type_info->GetElementType(), value->GetTensorRawData(), element_count); + } else if (device_type == OrtMemoryInfoDeviceType_GPU) { + stream << "GPU\r\n"; #if USE_CUDA - auto type = type_info->GetElementType(); - size_t element_size = SizeOf(type); - auto cpu_copy = std::make_unique(element_size * element_count); - CudaCheck() == cudaMemcpy(cpu_copy.get(), value->GetTensorRawData(), element_size * element_count, cudaMemcpyDeviceToHost); - DumpValues(stream, type, cpu_copy.get(), element_count); -#elif USE_DML - auto type = type_info->GetElementType(); - size_t element_size = SizeOf(type); - auto cpu_copy = std::make_unique(element_size * element_count); - - if (value->GetTensorMutableRawData()) { - ComPtr gpu_resource; - Ort::ThrowOnError(model.GetOrtDmlApi()->GetD3D12ResourceFromAllocation( - model.allocator_device_, - value->GetTensorMutableRawData(), - &gpu_resource)); - - model.GetDmlReadbackHeap()->ReadbackFromGpu( - std::span(cpu_copy.get(), element_size * element_count), - gpu_resource.Get(), - 0, - D3D12_RESOURCE_STATE_UNORDERED_ACCESS); - } - - DumpValues(stream, type, cpu_copy.get(), element_count); + auto type = type_info->GetElementType(); + size_t element_size = SizeOf(type); + auto cpu_copy = std::make_unique(element_size * element_count); + CudaCheck() == cudaMemcpy(cpu_copy.get(), value->GetTensorRawData(), element_size * element_count, cudaMemcpyDeviceToHost); + DumpValues(stream, type, cpu_copy.get(), element_count); #else - stream << "Unexpected, using GPU memory but not compiled with CUDA or DML?"; + throw std::runtime_error("Unexpected error. Trying to access GPU memory but the project is not compiled with CUDA."); #endif - break; + } else if (static_cast(device_type) == 4) { + stream << "DML\r\n"; +#if USE_DML + auto type = type_info->GetElementType(); + size_t element_size = SizeOf(type); + auto cpu_copy = std::make_unique(element_size * element_count); + + if (value->GetTensorMutableRawData()) { + ComPtr gpu_resource; + Ort::ThrowOnError(model.GetOrtDmlApi()->GetD3D12ResourceFromAllocation( + model.allocator_device_, + value->GetTensorMutableRawData(), + &gpu_resource)); + + model.GetDmlReadbackHeap()->ReadbackFromGpu( + std::span(cpu_copy.get(), element_size * element_count), + gpu_resource.Get(), + 0, + D3D12_RESOURCE_STATE_UNORDERED_ACCESS); } - default: - stream << "Unhandled device type"; - break; + + DumpValues(stream, type, cpu_copy.get(), element_count); +#else + throw std::runtime_error("Unexpected error. Trying to access DML memory but the project is not compiled with DML."); +#endif + } else { + stream << "Unhandled device type: " << static_cast(device_type) << "\r\n"; } } diff --git a/src/models/input_ids.cpp b/src/models/input_ids.cpp index f99907b59..b853d16c4 100644 --- a/src/models/input_ids.cpp +++ b/src/models/input_ids.cpp @@ -11,16 +11,6 @@ DefaultInputIDs::DefaultInputIDs(State& state) shape_ = {state_.params_->BatchBeamSize(), 0}; type_ = model_.session_info_->GetInputDataType(name_); - if (state_.GetCapturedGraphInfo()) { - sb_input_ids_ = state_.GetCapturedGraphInfo()->sb_input_ids_.get(); - -#if USE_DML - if (model_.device_type_ == DeviceType::DML) { - sb_input_ids_int32_ = state_.GetCapturedGraphInfo()->sb_input_ids_int32_.get(); - } -#endif - } - if (model_.session_info_->HasInput(model_.config_->model.decoder.inputs.current_sequence_length) && model_.session_info_->HasInput(model_.config_->model.decoder.inputs.past_sequence_length)) { if (state_.params_->BatchBeamSize() != 1) { @@ -36,7 +26,7 @@ DefaultInputIDs::DefaultInputIDs(State& state) current_sequence_length_ = OrtValue::CreateTensor(model_.allocator_cpu_, current_sequence_length_shape, model_.session_info_->GetInputDataType(model_.config_->model.decoder.inputs.current_sequence_length)); *current_sequence_length_->GetTensorMutableData() = 0; - past_sequence_length_ = OrtValue::CreateTensor(*model_.allocator_device_, past_sequence_length_shape, model_.session_info_->GetInputDataType(model_.config_->model.decoder.inputs.past_sequence_length)); + past_sequence_length_ = OrtValue::CreateTensor(model_.allocator_cpu_, past_sequence_length_shape, model_.session_info_->GetInputDataType(model_.config_->model.decoder.inputs.past_sequence_length)); *past_sequence_length_->GetTensorMutableData() = -1; } } @@ -56,6 +46,40 @@ void DefaultInputIDs::Add() { } void DefaultInputIDs::Update(DeviceSpan& new_tokens) { + if (!value_) { + shape_[1] = static_cast(new_tokens.size()); + + // If 64-bit, convert from 32-bit to 64-bit + auto input_ids = new_tokens.CopyDeviceToCpu(); + if (type_ == Ort::TypeToTensorType) { + value_ = OrtValue::CreateTensor(model_.allocator_cpu_, shape_, type_); + auto* p_data = value_->GetTensorMutableData(); + for (auto v : input_ids) { + *p_data++ = v; + } + } else { + if (type_ != Ort::TypeToTensorType) + throw std::runtime_error("InputIDs must be int64 or int32"); + value_ = OrtValue::CreateTensor(model_.allocator_cpu_.GetInfo(), input_ids, shape_); + } + + value_ = model_.ExpandInputs(value_, state_.params_->search.num_beams); + shape_[0] *= state_.params_->search.num_beams; + + if (state_.GetCapturedGraphInfo()) { + sb_input_ids_ = state_.GetCapturedGraphInfo()->sb_input_ids_.get(); + +#if USE_DML + if (model_.device_type_ == DeviceType::DML) { + sb_input_ids_int32_ = state_.GetCapturedGraphInfo()->sb_input_ids_int32_.get(); + } +#endif + } + + state_.inputs_[input_index_] = value_.get(); + return; + } + const auto get_unpadded_sequence_length = [](std::span input_ids, int32_t pad_token_id) { int32_t seq_length = 0; diff --git a/src/models/logits.cpp b/src/models/logits.cpp index 9e5be5b57..edaf95d1a 100644 --- a/src/models/logits.cpp +++ b/src/models/logits.cpp @@ -16,15 +16,6 @@ Logits::Logits(State& state) type_{model_.session_info_->GetOutputDataType(model_.config_->model.decoder.outputs.logits)} { output_raw_ = OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_); - if (state_.GetCapturedGraphInfo()) { - if (type_ == Ort::TypeToTensorType) { - sb_logits32_ = state_.GetCapturedGraphInfo()->sb_logits32_.get(); - } - if (type_ == Ort::TypeToTensorType) { - sb_logits16_ = state_.GetCapturedGraphInfo()->sb_logits16_.get(); - } - } - #if USE_CUDA if (model_.device_type_ == DeviceType::CUDA && !model_.config_->model.eos_token_ids.empty()) { auto& cpu_ids = model_.config_->model.eos_token_ids; @@ -215,6 +206,18 @@ void Logits::Update(const DeviceSpan& next_tokens, size_t new_kv_length StaticBuffer* sb_logits = type_ == Ort::TypeToTensorType ? sb_logits16_ : sb_logits32_; output_raw_ = !sb_logits ? OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_) : sb_logits->CreateTensorOnStaticBuffer(shape_, type_); + + if (state_.GetCapturedGraphInfo()) { + if (!sb_logits16_ && !sb_logits32_) { + if (type_ == Ort::TypeToTensorType) { + sb_logits32_ = state_.GetCapturedGraphInfo()->sb_logits32_.get(); + } + if (type_ == Ort::TypeToTensorType) { + sb_logits16_ = state_.GetCapturedGraphInfo()->sb_logits16_.get(); + } + } + } + state_.outputs_[output_index_] = output_raw_.get(); } diff --git a/src/ort_genai_c.cpp b/src/ort_genai_c.cpp index 3f2c9d750..c4bcf3015 100644 --- a/src/ort_genai_c.cpp +++ b/src/ort_genai_c.cpp @@ -337,11 +337,18 @@ OgaResult* OGA_API_CALL OgaGenerator_GetOutput(const OgaGenerator* oga_generator // Copy data to ortvalue_clone auto element_size = Generators::SizeOf(type_info->GetElementType()); auto data_size = type_info->GetElementCount() * element_size; - if (ortvalue_output->GetTensorMemoryInfo().GetDeviceType() == OrtMemoryInfoDeviceType_GPU && generator.model_->device_type_ == Generators::DeviceType::CUDA) { + const auto device_type = ortvalue_output->GetTensorMemoryInfo().GetDeviceType(); + if (device_type == OrtMemoryInfoDeviceType_CPU) { + std::copy(static_cast(ortvalue_output->GetTensorMutableRawData()), + static_cast(ortvalue_output->GetTensorMutableRawData()) + data_size, + static_cast(ortvalue_clone->GetTensorMutableRawData())); + } else if (device_type == OrtMemoryInfoDeviceType_GPU) { #if USE_CUDA cudaMemcpy(ortvalue_clone->GetTensorMutableRawData(), ortvalue_output->GetTensorMutableRawData(), data_size, cudaMemcpyDeviceToHost); +#else + throw std::runtime_error("Unexpected error. Trying to access GPU memory but the project is not compiled with CUDA."); #endif - } else if (ortvalue_output->GetTensorMemoryInfo().GetDeviceType() == OrtMemoryInfoDeviceType_GPU && generator.model_->device_type_ == Generators::DeviceType::DML) { + } else if (static_cast(device_type) == 4) { #if USE_DML ComPtr gpu_resource; Ort::ThrowOnError(generator.model_->GetOrtDmlApi()->GetD3D12ResourceFromAllocation( @@ -354,13 +361,11 @@ OgaResult* OGA_API_CALL OgaGenerator_GetOutput(const OgaGenerator* oga_generator gpu_resource.Get(), 0, D3D12_RESOURCE_STATE_UNORDERED_ACCESS); +#else + throw std::runtime_error("Unexpected error. Trying to access DML memory but the project is not compiled with DML."); #endif - } else if (ortvalue_output->GetTensorMemoryInfo().GetDeviceType() == OrtMemoryInfoDeviceType_CPU) { - std::copy(static_cast(ortvalue_output->GetTensorMutableRawData()), - static_cast(ortvalue_output->GetTensorMutableRawData()) + data_size, - static_cast(ortvalue_clone->GetTensorMutableRawData())); } else { - throw std::runtime_error("Unsupported Device type: " + std::to_string(ortvalue_output->GetTensorMemoryInfo().GetDeviceType())); + throw std::runtime_error("Unsupported device type: " + static_cast(device_type)); } auto tensor = std::make_shared(std::move(ortvalue_clone));