-
Notifications
You must be signed in to change notification settings - Fork 13
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
See https://github.com/quic/ai-hub-apps/releases/v0.3.0 for changelog. Signed-off-by: QAIHM Team <[email protected]>
- Loading branch information
Showing
80 changed files
with
97,673 additions
and
96 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,173 @@ | ||
// --------------------------------------------------------------------- | ||
// Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved. | ||
// SPDX-License-Identifier: BSD-3-Clause | ||
// --------------------------------------------------------------------- | ||
#include "ChatApp.hpp" | ||
#include "PromptHandler.hpp" | ||
#include <fstream> | ||
#include <iostream> | ||
#include <regex> | ||
|
||
using namespace App; | ||
|
||
namespace | ||
{ | ||
|
||
constexpr const int c_chat_separater_length = 80; | ||
|
||
// | ||
// ChatSplit - Line to split during Chat for UX | ||
// Adds split line to separate out sections in output. | ||
// | ||
void ChatSplit(bool end_line = true) | ||
{ | ||
std::string split_line(c_chat_separater_length, '-'); | ||
std::cout << "\n" << split_line; | ||
if (end_line) | ||
{ | ||
std::cout << "\n"; | ||
} | ||
} | ||
|
||
// | ||
// GenieCallBack - Callback to handle response from Genie | ||
// - Captures response from Genie into user_data | ||
// - Print response to stdout | ||
// - Add ChatSplit upon sentence completion | ||
// | ||
void GenieCallBack(const char* response_back, const GenieDialog_SentenceCode_t sentence_code, const void* user_data) | ||
{ | ||
std::string* user_data_str = static_cast<std::string*>(const_cast<void*>(user_data)); | ||
user_data_str->append(response_back); | ||
|
||
// Write user response to output. | ||
std::cout << response_back; | ||
if (sentence_code == GenieDialog_SentenceCode_t::GENIE_DIALOG_SENTENCE_END) | ||
{ | ||
ChatSplit(false); | ||
} | ||
} | ||
|
||
// | ||
// LoadModelConfig - Loads model config file | ||
// - Loads config file in memory | ||
// - Replaces place-holders with user provided values | ||
// | ||
std::string LoadModelConfig(const std::string& model_config_path, | ||
const std::string& models_path, | ||
const std::string& htp_model_config_path, | ||
const std::string& tokenizer_path) | ||
{ | ||
std::string config; | ||
// Read config file into memory | ||
std::getline(std::ifstream(model_config_path), config, '\0'); | ||
|
||
// Replace place-holders in config file with user provided paths | ||
config = std::regex_replace(config, std::regex("<models_path>"), models_path); | ||
config = std::regex_replace(config, std::regex("<htp_backend_ext_path>"), htp_model_config_path); | ||
config = std::regex_replace(config, std::regex("<tokenizer_path>"), tokenizer_path); | ||
return config; | ||
} | ||
|
||
} // namespace | ||
|
||
ChatApp::ChatApp(const std::string& model_config_path, | ||
const std::string& models_path, | ||
const std::string& htp_config_path, | ||
const std::string& tokenizer_path) | ||
{ | ||
|
||
// Load model config in-memory | ||
std::string config = LoadModelConfig(model_config_path, models_path, htp_config_path, tokenizer_path); | ||
|
||
// Create Genie config | ||
if (GENIE_STATUS_SUCCESS != GenieDialogConfig_createFromJson(config.c_str(), &m_config_handle)) | ||
{ | ||
throw std::runtime_error("Failed to create the Genie Dialog config. Please check config file."); | ||
} | ||
|
||
// Create Genie dialog handle | ||
if (GENIE_STATUS_SUCCESS != GenieDialog_create(m_config_handle, &m_dialog_handle)) | ||
{ | ||
throw std::runtime_error("Failed to create the Genie Dialog."); | ||
} | ||
} | ||
|
||
ChatApp::~ChatApp() | ||
{ | ||
if (m_config_handle != nullptr) | ||
{ | ||
if (GENIE_STATUS_SUCCESS != GenieDialogConfig_free(m_config_handle)) | ||
{ | ||
std::cerr << "Failed to free the Genie Dialog config."; | ||
} | ||
} | ||
|
||
if (m_dialog_handle != nullptr) | ||
{ | ||
if (GENIE_STATUS_SUCCESS != GenieDialog_free(m_dialog_handle)) | ||
{ | ||
std::cerr << "Failed to free the Genie Dialog."; | ||
} | ||
} | ||
} | ||
|
||
void ChatApp::ChatWithUser(const std::string& user_name) | ||
{ | ||
AppUtils::Llama2PromptHandler prompt_handler; | ||
|
||
// Initiate Chat with infinite loop. | ||
// User to provide `exit` as a prompt to exit. | ||
while (true) | ||
{ | ||
std::string user_prompt; | ||
std::string model_response; | ||
|
||
// Input user prompt | ||
ChatSplit(); | ||
std::cout << user_name << ": "; | ||
std::getline(std::cin, user_prompt); | ||
|
||
// Exit prompt | ||
if (user_prompt.compare(c_exit_prompt) == 0) | ||
{ | ||
std::cout << "Exiting chat as per " << user_name << "'s request."; | ||
return; | ||
} | ||
// User provides an empty prompt | ||
if (user_prompt.empty()) | ||
{ | ||
std::cout << "\nPlease enter prompt.\n"; | ||
continue; | ||
} | ||
|
||
std::string tagged_prompt = prompt_handler.GetPromptWithTag(user_prompt); | ||
|
||
// Bot's response | ||
std::cout << c_bot_name << ":"; | ||
if (GENIE_STATUS_SUCCESS != GenieDialog_query(m_dialog_handle, tagged_prompt.c_str(), | ||
GenieDialog_SentenceCode_t::GENIE_DIALOG_SENTENCE_COMPLETE, | ||
GenieCallBack, &model_response)) | ||
{ | ||
throw std::runtime_error("Failed to get response from GenieDialog. Please restart the ChatApp."); | ||
} | ||
|
||
if (model_response.empty()) | ||
{ | ||
// If model response is empty, reset dialog to re-initiate dialog. | ||
// During local testing, we found that in certain cases, | ||
// model response bails out after few iterations during chat. | ||
// If that happens, just reset Dialog handle to continue the chat. | ||
if (GENIE_STATUS_SUCCESS != GenieDialog_reset(m_dialog_handle)) | ||
{ | ||
throw std::runtime_error("Failed to reset Genie Dialog."); | ||
} | ||
if (GENIE_STATUS_SUCCESS != GenieDialog_query(m_dialog_handle, tagged_prompt.c_str(), | ||
GenieDialog_SentenceCode_t::GENIE_DIALOG_SENTENCE_COMPLETE, | ||
GenieCallBack, &model_response)) | ||
{ | ||
throw std::runtime_error("Failed to get response from GenieDialog. Please restart the ChatApp."); | ||
} | ||
} | ||
} | ||
} |
Oops, something went wrong.