Skip to content

Commit

Permalink
Add model freedom and ollama support
Browse files Browse the repository at this point in the history
  • Loading branch information
tharun571 committed Sep 27, 2024
1 parent 8dc85b4 commit 5205380
Show file tree
Hide file tree
Showing 5 changed files with 215 additions and 12 deletions.
Binary file modified docs/source/gemini.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
23 changes: 19 additions & 4 deletions docs/source/magics.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,29 @@ Here are the magics available in xeus-cpp.
%%xassist
========================

Leverage the large language models to assist in your development process. Currently supported models are Gemini - gemini-1.5-flash, OpenAI - gpt-3.5-turbo-16k.
Leverage the large language models to assist in your development process. Currently supported models are Gemini, OpenAI, Ollama.

- Save the api key
- Save the API key (for OpenAI and Gemini)

.. code::
%%xassist model --save-key
key
- Save the model

- Set the response url (for Ollama)

.. code::
%%xassist model --set-url
url
.. code::
%%xassist model --save-model
model_name
- Use the model

.. code::
Expand All @@ -33,9 +47,10 @@ Leverage the large language models to assist in your development process. Curren
.. code::
%%xassist model --refresh
- Example
- Examples

.. image:: gemini.png

A new prompt is sent to the model every time and the functionality to use previous context will be added soon.
.. image:: ollama.png
Binary file added docs/source/ollama.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
169 changes: 161 additions & 8 deletions src/xmagics/xassist.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,80 @@ namespace xcpp
}
};

/// Persists the per-backend model name (e.g. "gemini" -> "gemini-1.5-flash")
/// in a plain-text file named "<model>_model.txt" in the working directory.
class model_manager
{
public:

    /// Write \p model_name to "<model>_model.txt"; reports success on
    /// stdout and failure on stderr.
    static void save_model(const std::string& model, const std::string& model_name)
    {
        std::ofstream stream(model + "_model.txt");
        if (!stream)
        {
            std::cerr << "Failed to open file for writing model for model " << model << std::endl;
            return;
        }
        stream << model_name;
        stream.close();
        std::cout << "Model saved for model " << model << std::endl;
    }

    /// Read the first line of "<model>_model.txt".
    /// Returns an empty string (and logs to stderr) when the file is missing.
    static std::string load_model(const std::string& model)
    {
        std::ifstream stream(model + "_model.txt");
        if (!stream)
        {
            std::cerr << "Failed to open file for reading model for model " << model << std::endl;
            return "";
        }
        std::string stored_name;
        std::getline(stream, stored_name);
        return stored_name;
    }
};

/// Persists the endpoint URL for a backend (used by Ollama) in a
/// plain-text file named "<model>_url.txt" in the working directory.
class url_manager
{
public:

    /// Write \p url to "<model>_url.txt"; reports success on stdout and
    /// failure on stderr.
    static void save_url(const std::string& model, const std::string& url)
    {
        std::ofstream stream(model + "_url.txt");
        if (!stream)
        {
            std::cerr << "Failed to open file for writing URL for model " << model << std::endl;
            return;
        }
        stream << url;
        stream.close();
        std::cout << "URL saved for model " << model << std::endl;
    }

    /// Read the first line of "<model>_url.txt".
    /// Returns an empty string (and logs to stderr) when the file is missing.
    static std::string load_url(const std::string& model)
    {
        std::ifstream stream(model + "_url.txt");
        if (!stream)
        {
            std::cerr << "Failed to open file for reading URL for model " << model << std::endl;
            return "";
        }
        std::string stored_url;
        std::getline(stream, stored_url);
        return stored_url;
    }
};

class chat_history
{
public:
Expand Down Expand Up @@ -209,8 +283,16 @@ namespace xcpp
{
curl_helper curl_helper;
const std::string chat_message = xcpp::chat_history::chat("gemini", "user", cell);
const std::string url = "https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash:generateContent?key="
+ key;
const std::string model = xcpp::model_manager::load_model("gemini");

if (model.empty())
{
std::cerr << "Model not found." << std::endl;
return "";
}

const std::string url = "https://generativelanguage.googleapis.com/v1beta/models/" + model
+ ":generateContent?key=" + key;
const std::string post_data = R"({"contents": [ )" + chat_message + R"(]})";

std::string response = curl_helper.perform_request(url, post_data);
Expand All @@ -231,13 +313,64 @@ namespace xcpp
return j["candidates"][0]["content"]["parts"][0]["text"];
}

/// Send the current cell to a locally-served Ollama chat endpoint.
///
/// Reads the endpoint URL from "ollama_url.txt" and the model name from
/// "ollama_model.txt" (both set via `%%xassist ollama --set-url` /
/// `--save-model`). The user message is appended to the persisted chat
/// history before the request; the assistant reply is appended after.
///
/// @param cell  the cell contents to send as the user message
/// @return the assistant's reply, or "" on any error (logged to stderr)
std::string ollama(const std::string& cell)
{
    curl_helper curl_helper;
    const std::string url = xcpp::url_manager::load_url("ollama");
    const std::string chat_message = xcpp::chat_history::chat("ollama", "user", cell);
    const std::string model = xcpp::model_manager::load_model("ollama");

    if (model.empty())
    {
        std::cerr << "Model not found." << std::endl;
        return "";
    }

    if (url.empty())
    {
        std::cerr << "URL not found." << std::endl;
        return "";
    }

    const std::string post_data = R"({
                "model": ")" + model
                                  + R"(",
                "messages": [)" + chat_message
                                  + R"(],
                "stream": false
            })";

    std::string response = curl_helper.perform_request(url, post_data);

    json j = json::parse(response);

    if (j.find("error") != j.end())
    {
        std::cerr << "Error: " << j["error"]["message"] << std::endl;
        return "";
    }

    // Extract the reply once, record it in the history, and return it
    // (previously the history-append result was bound to an unused local
    // and the JSON path was parsed twice).
    const std::string content = j["message"]["content"];
    xcpp::chat_history::chat("ollama", "assistant", content);

    return content;
}

std::string openai(const std::string& cell, const std::string& key)
{
curl_helper curl_helper;
const std::string url = "https://api.openai.com/v1/chat/completions";
const std::string chat_message = xcpp::chat_history::chat("openai", "user", cell);
const std::string model = xcpp::model_manager::load_model("openai");

if (model.empty())
{
std::cerr << "Model not found." << std::endl;
return "";
}

const std::string post_data = R"({
"model": "gpt-3.5-turbo-16k",
"model": [)" + model
+ R"(],
"messages": [)" + chat_message
+ R"(],
"temperature": 0.7
Expand Down Expand Up @@ -273,7 +406,7 @@ namespace xcpp
std::istream_iterator<std::string>()
);

std::vector<std::string> models = {"gemini", "openai"};
std::vector<std::string> models = {"gemini", "openai", "ollama"};
std::string model = tokens[1];

if (std::find(models.begin(), models.end(), model) == models.end())
Expand All @@ -295,13 +428,29 @@ namespace xcpp
xcpp::chat_history::refresh(model);
return;
}

if (tokens[2] == "--save-model")
{
xcpp::model_manager::save_model(model, cell);
return;
}

if (tokens[2] == "--set-url" && model == "ollama")
{
xcpp::url_manager::save_url(model, cell);
return;
}
}

std::string key = xcpp::api_key_manager::load_api_key(model);
if (key.empty())
std::string key;
if (model != "ollama")
{
std::cerr << "API key for model " << model << " is not available." << std::endl;
return;
key = xcpp::api_key_manager::load_api_key(model);
if (key.empty())
{
std::cerr << "API key for model " << model << " is not available." << std::endl;
return;
}
}

std::string response;
Expand All @@ -313,6 +462,10 @@ namespace xcpp
{
response = openai(cell, key);
}
else if (model == "ollama")
{
response = ollama(cell);
}

std::cout << response;
}
Expand Down
35 changes: 35 additions & 0 deletions test/test_interpreter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -962,4 +962,39 @@ TEST_SUITE("xassist"){
std::remove("openai_api_key.txt");
}

TEST_CASE("ollama"){
    xcpp::xassist assist;
    std::string line = "%%xassist ollama --set-url";
    std::string cell = "1234";

    // --set-url should persist the URL to ollama_url.txt.
    assist(line, cell);

    std::ifstream infile("ollama_url.txt");
    std::string content;
    std::getline(infile, content);

    REQUIRE(content == "1234");
    infile.close();

    line = "%%xassist ollama --save-model";
    cell = "1234";

    // --save-model should persist the model name to ollama_model.txt.
    assist(line, cell);

    std::ifstream infile_model("ollama_model.txt");
    std::string content_model;
    std::getline(infile_model, content_model);

    REQUIRE(content_model == "1234");
    infile_model.close();

    StreamRedirectRAII redirect(std::cerr);

    // openai without a saved API key must emit an error on stderr.
    assist("%%xassist openai", "hello");

    REQUIRE(!redirect.getCaptured().empty());

    // Clean up every file this test created so stale state cannot leak
    // into other tests or later runs (ollama files were previously left
    // behind).
    std::remove("ollama_url.txt");
    std::remove("ollama_model.txt");
    std::remove("openai_api_key.txt");
}

}

0 comments on commit 5205380

Please sign in to comment.