From c0eb49c38e0a5627b09082e36eecfeacb7a55a60 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aliaksandr=20Tru=C5=A1?= Date: Fri, 19 Apr 2024 10:35:15 +0200 Subject: [PATCH 1/9] chore: add type annotations --- lua/codeium/api.lua | 3 ++- lua/codeium/config.lua | 18 ++++++++++++++++++ lua/codeium/init.lua | 5 +++-- lua/codeium/io.lua | 1 + lua/codeium/notify.lua | 1 + lua/codeium/source.lua | 1 + lua/codeium/update.lua | 2 ++ lua/codeium/util.lua | 1 + 8 files changed, 29 insertions(+), 3 deletions(-) diff --git a/lua/codeium/api.lua b/lua/codeium/api.lua index 6217752..2d30160 100644 --- a/lua/codeium/api.lua +++ b/lua/codeium/api.lua @@ -4,7 +4,6 @@ local io = require("codeium.io") local log = require("codeium.log") local update = require("codeium.update") local notify = require("codeium.notify") -local util = require("codeium.util") local api_key = nil local function find_port(manager_dir, start_time) @@ -36,6 +35,7 @@ local function get_request_metadata(request_id) } end +---@class codeium.Server local Server = {} Server.__index = Server @@ -133,6 +133,7 @@ function Server.authenticate() prompt() end +---@return codeium.Server function Server:new() local m = {} setmetatable(m, self) diff --git a/lua/codeium/config.lua b/lua/codeium/config.lua index 9bcf1de..c8a8877 100644 --- a/lua/codeium/config.lua +++ b/lua/codeium/config.lua @@ -1,7 +1,10 @@ local notify = require("codeium.notify") +---@class codeium.config +---@field options codeium.options local M = {} +---@return codeium.options function M.defaults() return { manager_path = nil, @@ -56,8 +59,23 @@ function M.apply_conditional_defaults(options) return options end +---@class codeium.options +---@field manager_path string +---@field bin_path string +---@field config_path string +---@field language_server_download_url string +---@field api table +---@field enterprise_mode boolean +---@field detect_proxy boolean +---@field tools table +---@field wrapper function +---@field enable_chat boolean +---@field enable_local_search boolean +---@field enable_index_service boolean +---@field search_max_workspace_file_count number M.options = {} +---@param options codeium.options|nil function M.setup(options) options = options or {} diff --git a/lua/codeium/init.lua b/lua/codeium/init.lua index f57046d..769d5e1 100644 --- a/lua/codeium/init.lua +++ b/lua/codeium/init.lua @@ -4,7 +4,8 @@ function M.setup(options) local Source = require("codeium.source") local Server = require("codeium.api") local update = require("codeium.update") - require("codeium.config").setup(options) + local config = require("codeium.config") + config.setup(options) local s = Server:new() update.download(function(err) @@ -27,7 +28,7 @@ function M.setup(options) nargs = 1, complete = function() local commands = {"Auth"} - if require("codeium.config").options.enable_chat then + if config.options.enable_chat then commands = vim.list_extend(commands, {"Chat"}) end return commands diff --git a/lua/codeium/io.lua b/lua/codeium/io.lua index 7cf009b..efe4265 100644 --- a/lua/codeium/io.lua +++ b/lua/codeium/io.lua @@ -6,6 +6,7 @@ local curl = require("plenary.curl") local config = require("codeium.config") local default_mod = 438 -- 666 +---@class codeium.io local M = {} local function check_job(job, status) diff --git a/lua/codeium/notify.lua b/lua/codeium/notify.lua index b22dda2..6eeb22f 100644 --- a/lua/codeium/notify.lua +++ b/lua/codeium/notify.lua @@ -1,5 +1,6 @@ local log = require("codeium.log") +---@class codeium.notify local M = {} local opts = { title = "Codeium", diff --git a/lua/codeium/source.lua b/lua/codeium/source.lua index 2c0eb4f..bc0ce6a 100644 --- a/lua/codeium/source.lua +++ b/lua/codeium/source.lua @@ -99,6 +99,7 @@ local function get_other_documents(bufnr) return other_documents end +---@class codeium.Source local Source = { server = nil, } diff --git a/lua/codeium/update.lua b/lua/codeium/update.lua index ea2e54b..e65ff67 100644 --- a/lua/codeium/update.lua +++ b/lua/codeium/update.lua @@ -2,6 +2,8 @@ local config = require("codeium.config") local versions = require("codeium.versions") local io = require("codeium.io") local notify = require("codeium.notify") + +---@class codeium.update local M = {} local cached = nil diff --git a/lua/codeium/util.lua b/lua/codeium/util.lua index e401799..769e2b1 100644 --- a/lua/codeium/util.lua +++ b/lua/codeium/util.lua @@ -1,4 +1,5 @@ local enums = require("codeium.enums") +---@class codeium.util local M = {} function M.fallback_call(calls, with_filter, fallback_value) From d633dfd4eac4c133e16798b17c9570cd70d4d543 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aliaksandr=20Tru=C5=A1?= Date: Sun, 21 Apr 2024 09:57:37 +0200 Subject: [PATCH 2/9] Start chat functionality refactoring - separate functions for get_chat_ports and open_chat --- lua/codeium/api.lua | 113 ++++++++++++++++++++++++++++--------------- lua/codeium/init.lua | 38 +++++++++++---- 2 files changed, 103 insertions(+), 48 deletions(-) diff --git a/lua/codeium/api.lua b/lua/codeium/api.lua index 2d30160..9ade176 100644 --- a/lua/codeium/api.lua +++ b/lua/codeium/api.lua @@ -142,6 +142,7 @@ function Server:new() setmetatable(o, m) local port = nil + local chat_ports = nil local job = nil local current_cookie = nil local workspaces = {} @@ -155,6 +156,14 @@ function Server:new() }) end + local function chat_request(fn, payload, callback) + local url = "http://127.0.0.1:" .. chat_ports.chatClientPort .. "/exa.language_server_pb.LanguageServerService/" .. fn + io.post(url, { + body = payload, + callback = callback, + }) + end + local function do_heartbeat() request("Heartbeat", { metadata = get_request_metadata(), @@ -279,6 +288,7 @@ function Server:new() port = find_port(manager_dir, start_time) if port then + notify.info("Codeium server started") cancel() start_heartbeat() end @@ -397,49 +407,76 @@ function Server:new() end) end - function m.get_chat_ports() - request("GetProcesses", { - metadata = get_request_metadata(), - }, function(body, err) - if err then - notify.error("failed to get chat ports", err) + function m.init_chat() + io.timer(100, 500, function(cancel) + if not port then return end - local ports = vim.fn.json_decode(body) - local url = "http://127.0.0.1:" - .. ports.chatClientPort - .. "?api_key=" - .. api_key - .. "&has_enterprise_extension=" - .. (config.options.enterprise_mode and "true" or "false") - .. "&web_server_url=ws://127.0.0.1:" - .. ports.chatWebServerPort - .. "&ide_name=neovim" - .. "&ide_version=" - .. versions.nvim - .. "&app_name=codeium.nvim" - .. "&extension_name=codeium.nvim" - .. "&extension_version=" - .. versions.extension - .. "&ide_telemetry_enabled=true" - .. "&has_index_service=" - .. (config.options.enable_index_service and "true" or "false") - .. "&locale=en_US" - - -- cross-platform solution to open the web app - local os_info = io.get_system_info() - if os_info.os == "linux" then - os.execute("xdg-open '" .. url .. "'") - elseif os_info.os == "macos" then - os.execute("open '" .. url .. "'") - elseif os_info.os == "windows" then - os.execute("start " .. url) - else - notify.error("Unsupported operating system") - end + request("GetProcesses", { + metadata = get_request_metadata(), + }, function(body, err) + if err then + notify.error("failed to get chat ports", err) + cancel() + return + end + chat_ports = vim.fn.json_decode(body) + notify.info("Codeium chat ready to use") + cancel() + end) end) end + function m.open_chat() + if chat_ports == nil then + notify.error("chat ports not found") + return + end + local url = "http://127.0.0.1:" + .. chat_ports.chatClientPort + .. "?api_key=" + .. api_key + .. "&has_enterprise_extension=" + .. (config.options.enterprise_mode and "true" or "false") + .. "&web_server_url=ws://127.0.0.1:" + .. chat_ports.chatWebServerPort + .. "&ide_name=neovim" + .. "&ide_version=" + .. versions.nvim + .. "&app_name=codeium.nvim" + .. "&extension_name=codeium.nvim" + .. "&extension_version=" + .. versions.extension + .. "&ide_telemetry_enabled=true" + .. "&has_index_service=" + .. (config.options.enable_index_service and "true" or "false") + .. "&locale=en_US" + + -- cross-platform solution to open the web app + local os_info = io.get_system_info() + if os_info.os == "linux" then + os.execute("xdg-open '" .. url .. "'") + elseif os_info.os == "macos" then + os.execute("open '" .. url .. "'") + elseif os_info.os == "windows" then + os.execute("start " .. url) + else + notify.error("Unsupported operating system") + end + end + + function m.request_chat_action(document, editor_options, prompt, callback) + local body = { + message_id = 1, + source = 'User', + timestamp = timestamp(), + conversation_id = 1, + content = { indents = { generic = { text = prompt } } }, + in_progress = false + } + m.chat_request("GetAction", body, "prompt", callback) + end + function m.shutdown() current_cookie = nil if job then diff --git a/lua/codeium/init.lua b/lua/codeium/init.lua index 769d5e1..62ae040 100644 --- a/lua/codeium/init.lua +++ b/lua/codeium/init.lua @@ -1,28 +1,35 @@ -local M = {} +local M = { + Server = nil, + Config = nil +} function M.setup(options) - local Source = require("codeium.source") - local Server = require("codeium.api") + local source = require("codeium.source") + local server = require("codeium.api") local update = require("codeium.update") local config = require("codeium.config") config.setup(options) + Config = config.options - local s = Server:new() + Server = server:new() update.download(function(err) if not err then - Server.load_api_key() - s.start() + server.load_api_key() + Server.start() + if config.options.enable_chat then + Server.init_chat() + end end end) vim.api.nvim_create_user_command("Codeium", function(opts) local args = opts.fargs if args[1] == "Auth" then - Server.authenticate() + server.authenticate() end if args[1] == "Chat" then - s.get_chat_ports() - s.add_workspace() + Server.open_chat() + Server.add_workspace() end end, { nargs = 1, @@ -35,8 +42,19 @@ function M.setup(options) end, }) - local source = Source:new(s) + local source = source:new(Server) require("cmp").register_source("codeium", source) end +function M.open_chat() + if not Config.enable_chat then + return + end + Server.open_chat() +end + +function M.add_workspace() + Server.add_workspace() +end + return M From ffa60ecc1f415769bdfbe9e6e8a42d311c4bc1a1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aliaksandr=20Tru=C5=A1?= Date: Sun, 21 Apr 2024 22:03:58 +0200 Subject: [PATCH 3/9] continue work on chat functionality --- lua/codeium/api.lua | 125 ++++++++++++++++++++++++++++++------- lua/codeium/chat.lua | 145 +++++++++++++++++++++++++++++++++++++++++++ lua/codeium/init.lua | 33 ++++++---- 3 files changed, 268 insertions(+), 35 deletions(-) create mode 100644 lua/codeium/chat.lua diff --git a/lua/codeium/api.lua b/lua/codeium/api.lua index 9ade176..045f885 100644 --- a/lua/codeium/api.lua +++ b/lua/codeium/api.lua @@ -1,4 +1,5 @@ local versions = require("codeium.versions") +local chat = require("codeium.chat") local config = require("codeium.config") local io = require("codeium.io") local log = require("codeium.log") @@ -149,6 +150,10 @@ function Server:new() local healthy = false local function request(fn, payload, callback) + if not port then + notify.info("Server not started yet") + return + end local url = "http://127.0.0.1:" .. port .. "/exa.language_server_pb.LanguageServerService/" .. fn io.post(url, { body = payload, @@ -156,10 +161,12 @@ function Server:new() }) end - local function chat_request(fn, payload, callback) - local url = "http://127.0.0.1:" .. chat_ports.chatClientPort .. "/exa.language_server_pb.LanguageServerService/" .. fn + local function chat_server_request(fn, payload, callback) + local url = "http://127.0.0.1:" .. + chat_ports.chatWebServerPort .. "/exa.language_server_pb.LanguageServerService/" .. fn + local body = { metadata = get_request_metadata(), chat_message = payload } io.post(url, { - body = payload, + body = body, callback = callback, }) end @@ -287,8 +294,9 @@ function Server:new() end port = find_port(manager_dir, start_time) + -- port = 42100 if port then - notify.info("Codeium server started") + notify.info("Codeium server started on port " .. port) cancel() start_heartbeat() end @@ -385,30 +393,49 @@ function Server:new() }, noop) end + local codeium_workspace_root_hints = { '.bzr', '.git', '.hg', '.svn', '_FOSSIL_', 'package.json' } + function GetProjectRoot() + local last_dir = '' + local dir = vim.fn.getcwd() + while dir ~= last_dir do + for root_hint in ipairs(codeium_workspace_root_hints) do + local hint = dir .. '/' .. root_hint + if vim.fn.isdirectory(hint) or vim.fn.filereadable(hint) then + return dir + end + end + last_dir = dir + dir = vim.fn.fnamemodify(dir, ':h') + end + return vim.fn.getcwd() + end + function m.add_workspace() - local project_root = vim.fn.getcwd() + local project_root = GetProjectRoot() -- workspace already tracked by server if workspaces[project_root] then return end - -- unable to track hidden path - for entry in project_root:gmatch("[^/]+") do - if entry:sub(1, 1) == "." then - return - end - end - request("AddTrackedWorkspace", { workspace = project_root, metadata = get_request_metadata() }, function(_, err) - if err then - notify.error("failed to add workspace: " .. err.out) + io.timer(300, 500, function(cancel) + if not port then return end - workspaces[project_root] = true + request("AddTrackedWorkspace", { workspace = project_root, metadata = get_request_metadata() }, + function(_, err) + if err then + notify.error("failed to add workspace: " .. err.out) + return + end + workspaces[project_root] = true + notify.info("Workspace " .. project_root .. " added") + end) + cancel() end) end function m.init_chat() - io.timer(100, 500, function(cancel) + io.timer(200, 500, function(cancel) if not port then return end @@ -421,7 +448,7 @@ function Server:new() return end chat_ports = vim.fn.json_decode(body) - notify.info("Codeium chat ready to use") + notify.info("Codeium chat ready to use on server ports: client port " .. chat_ports.chatClientPort .. " and server port" .. chat_ports.chatWebServerPort) cancel() end) end) @@ -465,16 +492,66 @@ function Server:new() end end - function m.request_chat_action(document, editor_options, prompt, callback) + local function getNonce() + local possible = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789" + local nonce = "" + + for _ = 1, 32 do + local randomIndex = math.random(1, #possible) + nonce = nonce .. string.sub(possible, randomIndex, randomIndex) + end + + return nonce + end + + ---@param indent table + ---@param callback function + local function request_chat_action(indent, callback) + local current_timestamp = os.time() + local message_id = "user-" .. tostring(current_timestamp) local body = { - message_id = 1, - source = 'User', - timestamp = timestamp(), - conversation_id = 1, - content = { indents = { generic = { text = prompt } } }, + message_id = message_id, + source = 'CHAT_MESSAGE_SOURCE_USER', + timestamp = current_timestamp, + conversation_id = getNonce(), + content = { indent = indent }, in_progress = false } - m.chat_request("GetAction", body, "prompt", callback) + chat_server_request("GetChatMessage", body, callback) + end + + function m.request_generate_code() + request_chat_action(chat.intent_generate_code(), function(body, err) + if err then + notify.error("Error code: " .. err.code) + notify.error("Error message: " .. err.out) + notify.error("Error status: " .. err.status) + notify.error("Error response: " .. err.response) + else + notify.info("Explain: " .. body) + end + end) + end + + function m.open_connection() + io.timer(200, 500, function(cancel) + if not port then + return + end + local url = "http://127.0.0.1:" .. chat_ports.chatClientPort .. "/api/chat_enabled" + callback = function(body, err) + if err then + notify.error("chat response", err) + cancel() + return + end + notify.info("Response: " .. body) + end + io.post(url, { + body = { metadata = get_request_metadata() }, + callback = callback, + }) + end) end function m.shutdown() diff --git a/lua/codeium/chat.lua b/lua/codeium/chat.lua new file mode 100644 index 0000000..d1ac2e6 --- /dev/null +++ b/lua/codeium/chat.lua @@ -0,0 +1,145 @@ +local enums = require("codeium.enums") +---@class codeium.Chat +local chat = {} + +function chat.intent_generic(text) + return { text = text } +end + +-- string raw_source = 1; +-- string clean_function = 2; +-- string docstring = 3; +-- string node_name = 4; +-- string params = 5; +-- int32 definition_line = 6; +-- int32 start_line = 7; +-- int32 end_line = 8; +-- int32 start_col = 9; +-- int32 end_col = 10; +-- string leading_whitespace = 11; +-- Language language = 12; +local function function_info() + local bufnr = vim.api.nvim_get_current_buf() + local filetype = enums.filetype_aliases[vim.bo[bufnr].filetype] or vim.bo[bufnr].filetype or "text" + local language = enums.languages[filetype] or enums.languages.unspecified + return { + raw_source = "", + clean_function = "", + docstring = "", + node_name = "", + params = "", + definition_line = 6, + start_line = 7, + end_line = 8, + start_col = 9, + end_col = 10, + leading_whitespace = "", + language = language + } +end + +---@return number +local function language() + local bufnr = vim.api.nvim_get_current_buf() + local filetype = enums.filetype_aliases[vim.bo[bufnr].filetype] or vim.bo[bufnr].filetype or "text" + return enums.languages[filetype] or enums.languages.unspecified +end + +-- string raw_source = 1; +-- +-- // Start position of the code block. +-- int32 start_line = 2; +-- int32 start_col = 3; +-- +-- // End position of the code block. +-- int32 end_line = 4; +-- int32 end_col = 5; +local function code_block_info() + return { raw_source = "", start_line = 0, start_col = 0, end_line = 1, end_col = 1 } +end + + +-- codeium_common_pb.FunctionInfo function_info = 1; +-- codeium_common_pb.Language language = 2; +-- string file_path = 3; +function chat.intent_function_explain() + return { explain_function = { function_info = function_info(), language = language(), file_path = vim.api.nvim_buf_get_name(0) } } +end + +-- codeium_common_pb.FunctionInfo function_info = 1; +-- codeium_common_pb.Language language = 2; +-- string file_path = 3; +-- string refactor_description = 4; +function chat.intent_function_refactor() + return { function_refactor = { function_info = function_info(), language = language(), file_path = "", refactor_description = "" } } +end + +-- codeium_common_pb.FunctionInfo function_info = 1; +-- codeium_common_pb.Language language = 2; +-- string file_path = 3; +-- +-- --Optional additional instructions to inform what tests to generate. +-- string instructions = 4; +function chat.intent_function_unit_tests() + return { function_unit_tests = { function_info = function_info(), language = language(), file_path = "", instructions = "" } } +end + +--Ask for a docstring for a function. +-- codeium_common_pb.FunctionInfo function_info = 1; +-- codeium_common_pb.Language language = 2; +-- string file_path = 3; +function chat.intent_function_docstring() + return { function_docstring = { function_info = function_info(), language = language(), file_path = "" } } +end + +--Ask to explain a generic piece of code. +-- CodeBlockInfo code_block_info = 1; +-- codeium_common_pb.Language language = 2; +-- string file_path = 3; +function chat.intent_code_block_explain() + return { code_block_explain = { code_block_info = code_block_info(), language = language(), file_path = "" } } +end + +--Ask to refactor a generic piece of code. +-- CodeBlockInfo code_block_info = 1; +-- codeium_common_pb.Language language = 2; +-- string file_path = 3; +-- string refactor_description = 4; +function chat.intent_code_block_refactor() + return { code_block_refactor = { code_block_info = code_block_info(), language = language(), file_path = "", refactor_description = "" } } +end + +--Ask to explain a problem. +-- string diagnostic_function chat.= 1; +-- CodeBlockInfo problematic_code = 2; //entire code block with error +-- string surrounding_code_snippet = 3; +-- codeium_common_pb.Language language = 4; +-- string file_path = 5; +-- int32 line_number = 6; +function chat.intent_problem_explain() + return { + problem_explain = { + diagnostic_function = "", + problematic_code = code_block_info(), + surrounding_code_snippet = "", + language = language(), + file_path = "", + line_number = 0 + } + } +end + +--Ask to generate a piece of code. +-- string instruction = 1; +-- codeium_common_pb.Language language = 2; +-- string file_path = 3; +-- --Line to insert the generated code into. +-- int32 line_number = 4; +function chat.intent_generate_code() + local prompt = vim.fn.input("Please add prompt:") + local file_path = vim.api.nvim_buf_get_name(0) + local line_number = vim.api.nvim_win_get_cursor(0)[1] + return { generate_code = { instruction = prompt, language = language(), file_path = file_path, line_number = line_number } } +end + +return chat diff --git a/lua/codeium/init.lua b/lua/codeium/init.lua index 62ae040..8ca1621 100644 --- a/lua/codeium/init.lua +++ b/lua/codeium/init.lua @@ -1,3 +1,8 @@ +local notify = require("codeium.notify") + +---@class codeium +---@field Server codeium.Server|nil +---@field Config codeium.options|nil local M = { Server = nil, Config = nil @@ -9,16 +14,17 @@ function M.setup(options) local update = require("codeium.update") local config = require("codeium.config") config.setup(options) - Config = config.options + M.Config = config.options - Server = server:new() + M.Server = server:new() update.download(function(err) if not err then server.load_api_key() - Server.start() + M.Server.start() if config.options.enable_chat then - Server.init_chat() + M.Server.init_chat() end + M.Server.add_workspace() end end) @@ -28,8 +34,8 @@ function M.setup(options) server.authenticate() end if args[1] == "Chat" then - Server.open_chat() - Server.add_workspace() + M.Server.open_chat() + M.Server.add_workspace() end end, { nargs = 1, @@ -42,19 +48,24 @@ function M.setup(options) end, }) - local source = source:new(Server) - require("cmp").register_source("codeium", source) + require("cmp").register_source("codeium", source:new(M.Server)) end function M.open_chat() - if not Config.enable_chat then + if not M.Config.enable_chat then + notify.info("Codeium Chat disabled") return end - Server.open_chat() + M.Server.open_chat() end function M.add_workspace() - Server.add_workspace() + M.Server.add_workspace() +end + +function M.generate_code() + M.Server.open_connection() + M.Server.request_generate_code() end return M From 156337e7d25f9db0f981b82106eced190a7deb98 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aliaksandr=20Tru=C5=A1?= Date: Mon, 29 Apr 2024 09:53:36 +0200 Subject: [PATCH 4/9] impl code_block_info --- lua/codeium/api.lua | 26 ++++++++++++++++++++++++++ lua/codeium/chat.lua | 34 +++++++++++++++++++++++++++------- lua/codeium/init.lua | 10 ++++++++++ 3 files changed, 63 insertions(+), 7 deletions(-) diff --git a/lua/codeium/api.lua b/lua/codeium/api.lua index 045f885..d49222c 100644 --- a/lua/codeium/api.lua +++ b/lua/codeium/api.lua @@ -533,6 +533,32 @@ function Server:new() end) end + function m.request_explain_code() + request_chat_action(chat.intent_code_block_explain(), function(body, err) + if err then + notify.error("Error code: " .. err.code) + notify.error("Error message: " .. err.out) + notify.error("Error status: " .. err.status) + notify.error("Error response: " .. err.response) + else + notify.info("Explain: " .. body) + end + end) + end + + function m.request_docstring() + request_chat_action(chat.intent_function_docstring(), function(body, err) + if err then + notify.error("Error code: " .. err.code) + notify.error("Error message: " .. err.out) + notify.error("Error status: " .. err.status) + notify.error("Error response: " .. err.response) + else + notify.info("Explain: " .. body) + end + end) + end + function m.open_connection() io.timer(200, 500, function(cancel) if not port then diff --git a/lua/codeium/chat.lua b/lua/codeium/chat.lua index d1ac2e6..e3b7c28 100644 --- a/lua/codeium/chat.lua +++ b/lua/codeium/chat.lua @@ -1,4 +1,5 @@ local enums = require("codeium.enums") +local util = require("codeium.util") ---@class codeium.Chat local chat = {} @@ -55,7 +56,19 @@ end -- int32 end_line = 4; -- int32 end_col = 5; local function code_block_info() - return { raw_source = "", start_line = 0, start_col = 0, end_line = 1, end_col = 1 } + -- Get the current buffer + local current_buffer = vim.api.nvim_get_current_buf() + + -- Get the start and end positions of the visual selection + local start_line, start_col = unpack(vim.api.nvim_buf_get_mark(current_buffer, "<")) + local end_line, end_col = unpack(vim.api.nvim_buf_get_mark(current_buffer, ">")) + + local lines = vim.api.nvim_buf_get_lines(current_buffer, start_line, end_line, true) + local line_ending = util.get_newline(current_buffer) + table.insert(lines, "") + local text = table.concat(lines, line_ending) + + return { raw_source = text, start_line = start_line, start_col = start_col, end_line = end_line, end_col = end_col } end @@ -71,7 +84,8 @@ end -- string file_path = 3; -- string refactor_description = 4; function chat.intent_function_refactor() - return { function_refactor = { function_info = function_info(), language = language(), file_path = "", refactor_description = "" } } + local file_path = vim.api.nvim_buf_get_name(0) + return { function_refactor = { function_info = function_info(), language = language(), file_path = file_path, refactor_description = "" } } end -- codeium_common_pb.FunctionInfo function_info = 1; @@ -81,7 +95,9 @@ end -- --Optional additional instructions to inform what tests to generate. -- string instructions = 4; function chat.intent_function_unit_tests() - return { function_unit_tests = { function_info = function_info(), language = language(), file_path = "", instructions = "" } } + local prompt = vim.fn.input("Unit test instructions: ") + local file_path = vim.api.nvim_buf_get_name(0) + return { function_unit_tests = { function_info = function_info(), language = language(), file_path = file_path, instructions = prompt } } end --Ask for a docstring for a function. @@ -89,7 +105,8 @@ end -- codeium_common_pb.Language language = 2; -- string file_path = 3; function chat.intent_function_docstring() - return { function_docstring = { function_info = function_info(), language = language(), file_path = "" } } + local file_path = vim.api.nvim_buf_get_name(0) + return { function_docstring = { function_info = function_info(), language = language(), file_path = file_path } } end --Ask to explain a generic piece of code. @@ -97,7 +114,8 @@ end -- codeium_common_pb.Language language = 2; -- string file_path = 3; function chat.intent_code_block_explain() - return { code_block_explain = { code_block_info = code_block_info(), language = language(), file_path = "" } } + local file_path = vim.api.nvim_buf_get_name(0) + return { code_block_explain = { code_block_info = code_block_info(), language = language(), file_path = file_path } } end --Ask to refactor a generic piece of code. @@ -106,7 +124,8 @@ end -- string file_path = 3; -- string refactor_description = 4; function chat.intent_code_block_refactor() - return { code_block_refactor = { code_block_info = code_block_info(), language = language(), file_path = "", refactor_description = "" } } + local file_path = vim.api.nvim_buf_get_name(0) + return { code_block_refactor = { code_block_info = code_block_info(), language = language(), file_path = file_path, refactor_description = "" } } end --Ask to explain a problem. @@ -117,13 +136,14 @@ end -- string file_path = 5; -- int32 line_number = 6; function chat.intent_problem_explain() + local file_path = vim.api.nvim_buf_get_name(0) return { problem_explain = { diagnostic_function = "", problematic_code = code_block_info(), surrounding_code_snippet = "", language = language(), - file_path = "", + file_path = file_path, line_number = 0 } } diff --git a/lua/codeium/init.lua b/lua/codeium/init.lua index 8ca1621..a5f3c89 100644 --- a/lua/codeium/init.lua +++ b/lua/codeium/init.lua @@ -68,4 +68,14 @@ function M.generate_code() M.Server.request_generate_code() end +function M.explain() + M.Server.open_connection() + M.Server.request_explain_code() +end + +function M.add_docstring() + M.Server.open_connection() + M.Server.request_docstring() +end + return M From 21748bf8c8ba737d3e471651fb04704ee0b0cd12 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aliaksandr=20Tru=C5=A1?= Date: Mon, 29 Apr 2024 15:02:53 +0200 Subject: [PATCH 5/9] Refactor api module to be more OOP --- lua/codeium/api.lua | 750 ++++++++++++++++++++--------------------- lua/codeium/init.lua | 27 +- lua/codeium/source.lua | 7 +- 3 files changed, 383 insertions(+), 401 deletions(-) diff --git a/lua/codeium/api.lua b/lua/codeium/api.lua index d49222c..2cc9d60 100644 --- a/lua/codeium/api.lua +++ b/lua/codeium/api.lua @@ -7,6 +7,20 @@ local update = require("codeium.update") local notify = require("codeium.notify") local api_key = nil +local function noop(...) end + +local function get_nonce() + local possible = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789" + local nonce = "" + + for _ = 1, 32 do + local randomIndex = math.random(1, #possible) + nonce = nonce .. string.sub(possible, randomIndex, randomIndex) + end + + return nonce +end + local function find_port(manager_dir, start_time) local files = io.readdir(manager_dir) @@ -36,8 +50,40 @@ local function get_request_metadata(request_id) } end +local codeium_workspace_root_hints = { '.bzr', '.git', '.hg', '.svn', '_FOSSIL_', 'package.json' } +local function get_project_root() + local last_dir = '' + local dir = vim.fn.getcwd() + while dir ~= last_dir do + for root_hint in ipairs(codeium_workspace_root_hints) do + local hint = dir .. '/' .. root_hint + if vim.fn.isdirectory(hint) or vim.fn.filereadable(hint) then + return dir + end + end + last_dir = dir + dir = vim.fn.fnamemodify(dir, ':h') + end + return vim.fn.getcwd() +end + ---@class codeium.Server -local Server = {} +---@field port? number +---@field job? plenary.Job +---@field chat_ports? table +---@field current_cookie? number +---@field workspaces table +---@field healthy boolean +---@field pending_request table +local Server = { + _port = nil, + job = nil, + chat_ports = { chatClientPort = nil, chatWebServerPort = nil }, + current_cookie = nil, + workspaces = {}, + healthy = false, + pending_request = { 0, noop }, +} Server.__index = Server function Server.load_api_key() @@ -57,7 +103,7 @@ function Server.load_api_key() api_key = json.api_key end -function Server.save_api_key() +local function save_api_key() local _, result = io.write_json(config.options.config_path, { api_key = api_key, }) @@ -111,7 +157,7 @@ function Server.authenticate() end if json and json.api_key and json.api_key ~= "" then api_key = json.api_key - Server.save_api_key() + save_api_key() notify.info("api key saved") return end @@ -139,457 +185,395 @@ function Server:new() local m = {} setmetatable(m, self) - local o = {} - setmetatable(o, m) - - local port = nil - local chat_ports = nil - local job = nil - local current_cookie = nil - local workspaces = {} - local healthy = false + m.__index = m + return m +end - local function request(fn, payload, callback) - if not port then - notify.info("Server not started yet") - return - end - local url = "http://127.0.0.1:" .. port .. "/exa.language_server_pb.LanguageServerService/" .. fn - io.post(url, { - body = payload, - callback = callback, - }) - end +function Server:start() + self:shutdown() - local function chat_server_request(fn, payload, callback) - local url = "http://127.0.0.1:" .. - chat_ports.chatWebServerPort .. "/exa.language_server_pb.LanguageServerService/" .. fn - local body = { metadata = get_request_metadata(), chat_message = payload } - io.post(url, { - body = body, - callback = callback, - }) - end + self.current_cookie = next_cookie() - local function do_heartbeat() - request("Heartbeat", { - metadata = get_request_metadata(), - }, function(_, err) - if err then - notify.warn("heartbeat failed", err) - else - healthy = true - end - end) + if not api_key then + io.timer(1000, 0, self.start) + return end - function m.is_healthy() - return healthy + local manager_dir = config.options.manager_path + if not manager_dir then + manager_dir = io.tempdir("codeium/manager") + vim.fn.mkdir(manager_dir, "p") end - function m.start() - m.shutdown() - - current_cookie = next_cookie() + local start_time = io.touch(manager_dir .. "/start") - if not api_key then - io.timer(1000, 0, m.start) + local function on_exit(_, err) + if not self.current_cookie then return end - local manager_dir = config.manager_path - if not manager_dir then - manager_dir = io.tempdir("codeium/manager") - vim.fn.mkdir(manager_dir, "p") - end - - local start_time = io.touch(manager_dir .. "/start") - - local function on_exit(_, err) - if not current_cookie then - return - end + self.healthy = false + if err then + self.job = nil + self.current_cookie = nil - healthy = false - if err then - job = nil - current_cookie = nil - - notify.error("codeium server crashed", err) - io.timer(1000, 0, function() - log.debug("restarting server after crash") - m.start() - end) - end - end - - local function on_output(_, v, j) - log.debug(j.pid .. ": " .. v) + notify.error("codeium server crashed", err) + io.timer(1000, 0, function() + log.debug("restarting server after crash") + self:start() + end) end + end - local api_server_url = "https://" - .. config.options.api.host - .. ":" - .. config.options.api.port - .. (config.options.api.path and "/" .. config.options.api.path:gsub("^/", "") or "") - - local job_args = { - update.get_bin_info().bin, - "--api_server_url", - api_server_url, - "--manager_dir", - manager_dir, - enable_handlers = true, - enable_recording = false, - on_exit = on_exit, - on_stdout = on_output, - on_stderr = on_output, - } - - if config.options.enable_chat then - table.insert(job_args, "--enable_chat_web_server") - table.insert(job_args, "--enable_chat_client") - end + local function on_output(_, v, j) + log.debug(j.pid .. ": " .. v) + end - if config.options.enable_local_search then - table.insert(job_args, "--enable_local_search") - end + local api_server_url = "https://" + .. config.options.api.host + .. ":" + .. config.options.api.port + .. (config.options.api.path and "/" .. config.options.api.path:gsub("^/", "") or "") + + local job_args = { + update.get_bin_info().bin, + "--api_server_url", + api_server_url, + "--manager_dir", + manager_dir, + enable_handlers = true, + enable_recording = false, + on_exit = on_exit, + on_stdout = on_output, + on_stderr = on_output, + } - if config.options.enable_index_service then - table.insert(job_args, "--enable_index_service") - table.insert(job_args, "--search_max_workspace_file_count") - table.insert(job_args, config.options.search_max_workspace_file_count) - end + if config.options.enable_chat then + table.insert(job_args, "--enable_chat_web_server") + table.insert(job_args, "--enable_chat_client") + end - if config.options.api.portal_url then - table.insert(job_args, "--portal_url") - table.insert(job_args, "https://" .. config.options.api.portal_url) - end + if config.options.enable_local_search then + table.insert(job_args, "--enable_local_search") + end - if config.options.enterprise_mode then - table.insert(job_args, "--enterprise_mode") - end + if config.options.enable_index_service then + table.insert(job_args, "--enable_index_service") + table.insert(job_args, "--search_max_workspace_file_count") + table.insert(job_args, config.options.search_max_workspace_file_count) + end - if config.options.detect_proxy ~= nil then - table.insert(job_args, "--detect_proxy=" .. tostring(config.options.detect_proxy)) - end + if config.options.api.portal_url then + table.insert(job_args, "--portal_url") + table.insert(job_args, "https://" .. config.options.api.portal_url) + end - local job = io.job(job_args) - job:start() + if config.options.enterprise_mode then + table.insert(job_args, "--enterprise_mode") + end - local function start_heartbeat() - io.timer(100, 5000, function(cancel_heartbeat) - if not current_cookie then - cancel_heartbeat() - else - do_heartbeat() - end - end) - end + if config.options.detect_proxy ~= nil then + table.insert(job_args, "--detect_proxy=" .. tostring(config.options.detect_proxy)) + end - io.timer(100, 500, function(cancel) - if not current_cookie then - cancel() - return - end + local job = io.job(job_args) + job:start() - port = find_port(manager_dir, start_time) - -- port = 42100 - if port then - notify.info("Codeium server started on port " .. port) - cancel() - start_heartbeat() + local function start_heartbeat() + io.timer(100, 5000, function(cancel_heartbeat) + if not self.current_cookie then + cancel_heartbeat() + else + self:do_heartbeat() end end) end - local function noop(...) end - - local pending_request = { 0, noop } - function m.request_completion(document, editor_options, other_documents, callback) - pending_request[2](true) - - local metadata = get_request_metadata() - local this_pending_request - - local complete - complete = function(...) - complete = noop - this_pending_request(false) - callback(...) + io.timer(100, 500, function(cancel) + if not self.current_cookie then + cancel() + return end - this_pending_request = function(is_complete) - if pending_request[1] == metadata.request_id then - pending_request = { 0, noop } - end - this_pending_request = noop + self.port = find_port(manager_dir, start_time) + -- port = 42100 + if self.port then + notify.info("Codeium server started on port " .. self.port) + cancel() + start_heartbeat() + end + end) +end - request("CancelRequest", { - metadata = get_request_metadata(), - request_id = metadata.request_id, - }, function(_, err) - if err then - log.warn("failed to cancel in-flight request", err) - end - end) +function Server:request(fn, payload, callback) + if not self.port then + notify.info("Server not started yet") + return + end + local url = "http://127.0.0.1:" .. self.port .. "/exa.language_server_pb.LanguageServerService/" .. fn + io.post(url, { + body = payload, + callback = callback, + }) +end - if is_complete then - complete(false, nil) - end +function Server:init_chat() + io.timer(200, 500, function(cancel) + if not self.port then + return end - pending_request = { metadata.request_id, this_pending_request } - - request("GetCompletions", { - metadata = metadata, - editor_options = editor_options, - document = document, - other_documents = other_documents, + self:request("GetProcesses", { + metadata = get_request_metadata(), }, function(body, err) if err then - if err.status == 503 or err.status == 408 then - -- Service Unavailable or Timeout error - return complete(false, nil) - end - - local ok, json = pcall(vim.fn.json_decode, err.response.body) - if ok and json then - if json.state and json.state.state == "CODEIUM_STATE_INACTIVE" then - if json.state.message then - log.debug("completion request failed", json.state.message) - end - return complete(false, nil) - end - if json.code == "canceled" then - log.debug("completion request cancelled at the server", json.message) - return complete(false, nil) - end - end - - notify.error("completion request failed", err) - complete(false, nil) - return - end - - local ok, json = pcall(vim.fn.json_decode, body) - if not ok then - notify.error("completion request failed", "invalid JSON:", json) + notify.error("failed to get chat ports", err) + cancel() return end - - log.trace("completion: ", json) - complete(true, json) + self.chat_ports = vim.fn.json_decode(body) + notify.info("Codeium chat ready to use on server ports: client port " .. + self.chat_ports.chatClientPort .. " and server port " .. self.chat_ports.chatWebServerPort) + cancel() end) + end) +end - return function() - this_pending_request(true) - end - end +function Server:request_completion(document, editor_options, other_documents, callback) + self.pending_request[2](true) - function m.accept_completion(completion_id) - request("AcceptCompletion", { - metadata = get_request_metadata(), - completion_id = completion_id, - }, noop) - end + local metadata = get_request_metadata() + local this_pending_request - local codeium_workspace_root_hints = { '.bzr', '.git', '.hg', '.svn', '_FOSSIL_', 'package.json' } - function GetProjectRoot() - local last_dir = '' - local dir = vim.fn.getcwd() - while dir ~= last_dir do - for root_hint in ipairs(codeium_workspace_root_hints) do - local hint = dir .. '/' .. root_hint - if vim.fn.isdirectory(hint) or vim.fn.filereadable(hint) then - return dir - end - end - last_dir = dir - dir = vim.fn.fnamemodify(dir, ':h') - end - return vim.fn.getcwd() + local complete + complete = function(...) + complete = noop + this_pending_request(false) + callback(...) end - function m.add_workspace() - local project_root = GetProjectRoot() - -- workspace already tracked by server - if workspaces[project_root] then - return + this_pending_request = function(is_complete) + if self.pending_request[1] == metadata.request_id then + self.pending_request = { 0, noop } end + this_pending_request = noop - io.timer(300, 500, function(cancel) - if not port then - return + self:request("CancelRequest", { + metadata = get_request_metadata(), + request_id = metadata.request_id, + }, function(_, err) + if err then + log.warn("failed to cancel in-flight request", err) end - request("AddTrackedWorkspace", { workspace = project_root, metadata = get_request_metadata() }, - function(_, err) - if err then - notify.error("failed to add workspace: " .. err.out) - return - end - workspaces[project_root] = true - notify.info("Workspace " .. project_root .. " added") - end) - cancel() end) - end - function m.init_chat() - io.timer(200, 500, function(cancel) - if not port then - return + if is_complete then + complete(false, nil) + end + end + self.pending_request = { metadata.request_id, this_pending_request } + + self:request("GetCompletions", { + metadata = metadata, + editor_options = editor_options, + document = document, + other_documents = other_documents, + }, function(body, err) + if err then + if err.status == 503 or err.status == 408 then + -- Service Unavailable or Timeout error + return complete(false, nil) end - request("GetProcesses", { - metadata = get_request_metadata(), - }, function(body, err) - if err then - notify.error("failed to get chat ports", err) - cancel() - return + + local ok, json = pcall(vim.fn.json_decode, err.response.body) + if ok and json then + if json.state and json.state.state == "CODEIUM_STATE_INACTIVE" then + if json.state.message then + log.debug("completion request failed", json.state.message) + end + return complete(false, nil) end - chat_ports = vim.fn.json_decode(body) - notify.info("Codeium chat ready to use on server ports: client port " .. chat_ports.chatClientPort .. " and server port" .. chat_ports.chatWebServerPort) - cancel() - end) - end) - end + if json.code == "canceled" then + log.debug("completion request cancelled at the server", json.message) + return complete(false, nil) + end + end - function m.open_chat() - if chat_ports == nil then - notify.error("chat ports not found") + notify.error("completion request failed", err) + complete(false, nil) return end - local url = "http://127.0.0.1:" - .. chat_ports.chatClientPort - .. "?api_key=" - .. api_key - .. "&has_enterprise_extension=" - .. (config.options.enterprise_mode and "true" or "false") - .. "&web_server_url=ws://127.0.0.1:" - .. chat_ports.chatWebServerPort - .. "&ide_name=neovim" - .. "&ide_version=" - .. versions.nvim - .. "&app_name=codeium.nvim" - .. "&extension_name=codeium.nvim" - .. "&extension_version=" - .. versions.extension - .. "&ide_telemetry_enabled=true" - .. "&has_index_service=" - .. (config.options.enable_index_service and "true" or "false") - .. "&locale=en_US" - - -- cross-platform solution to open the web app - local os_info = io.get_system_info() - if os_info.os == "linux" then - os.execute("xdg-open '" .. url .. "'") - elseif os_info.os == "macos" then - os.execute("open '" .. url .. "'") - elseif os_info.os == "windows" then - os.execute("start " .. url) - else - notify.error("Unsupported operating system") + + local ok, json = pcall(vim.fn.json_decode, body) + if not ok then + notify.error("completion request failed", "invalid JSON:", json) + return end + + log.trace("completion: ", json) + complete(true, json) + end) + + return function() + this_pending_request(true) end +end - local function getNonce() - local possible = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789" - local nonce = "" +function Server:accept_completion(completion_id) + self:request("AcceptCompletion", { + metadata = get_request_metadata(), + completion_id = completion_id, + }, noop) +end - for _ = 1, 32 do - local randomIndex = math.random(1, #possible) - nonce = nonce .. string.sub(possible, randomIndex, randomIndex) +function Server:do_heartbeat() + self:request("Heartbeat", { + metadata = get_request_metadata(), + }, function(_, err) + if err then + notify.warn("heartbeat failed", err) + else + self.healthy = true end + end) +end - return nonce - end - - ---@param indent table - ---@param callback function - local function request_chat_action(indent, callback) - local current_timestamp = os.time() - local message_id = "user-" .. tostring(current_timestamp) - local body = { - message_id = message_id, - source = 'CHAT_MESSAGE_SOURCE_USER', - timestamp = current_timestamp, - conversation_id = getNonce(), - content = { indent = indent }, - in_progress = false - } - chat_server_request("GetChatMessage", body, callback) - end +function Server:is_healthy() + return self.healthy +end - function m.request_generate_code() - request_chat_action(chat.intent_generate_code(), function(body, err) - if err then - notify.error("Error code: " .. err.code) - notify.error("Error message: " .. err.out) - notify.error("Error status: " .. err.status) - notify.error("Error response: " .. err.response) - else - notify.info("Explain: " .. body) - end - end) +function Server:open_chat() + if self.chat_ports == nil then + notify.error("chat ports not found") + return end - - function m.request_explain_code() - request_chat_action(chat.intent_code_block_explain(), function(body, err) - if err then - notify.error("Error code: " .. err.code) - notify.error("Error message: " .. err.out) - notify.error("Error status: " .. err.status) - notify.error("Error response: " .. err.response) - else - notify.info("Explain: " .. body) - end - end) + local url = "http://127.0.0.1:" + .. self.chat_ports.chatClientPort + .. "?api_key=" + .. api_key + .. "&has_enterprise_extension=" + .. (config.options.enterprise_mode and "true" or "false") + .. "&web_server_url=ws://127.0.0.1:" + .. self.chat_ports.chatWebServerPort + .. "&ide_name=neovim" + .. "&ide_version=" + .. versions.nvim + .. "&app_name=codeium.nvim" + .. "&extension_name=codeium.nvim" + .. "&extension_version=" + .. versions.extension + .. "&ide_telemetry_enabled=true" + .. "&has_index_service=" + .. (config.options.enable_index_service and "true" or "false") + .. "&locale=en_US" + + -- cross-platform solution to open the web app + local os_info = io.get_system_info() + if os_info.os == "linux" then + os.execute("xdg-open '" .. url .. "'") + elseif os_info.os == "macos" then + os.execute("open '" .. url .. "'") + elseif os_info.os == "windows" then + os.execute("start " .. url) + else + notify.error("Unsupported operating system") end +end - function m.request_docstring() - request_chat_action(chat.intent_function_docstring(), function(body, err) - if err then - notify.error("Error code: " .. err.code) - notify.error("Error message: " .. err.out) - notify.error("Error status: " .. err.status) - notify.error("Error response: " .. err.response) - else - notify.info("Explain: " .. body) - end - end) +function Server:add_workspace() + local project_root = get_project_root() + -- workspace already tracked by server + if self.workspaces[project_root] then + return end - function m.open_connection() - io.timer(200, 500, function(cancel) - if not port then - return - end - local url = "http://127.0.0.1:" .. chat_ports.chatClientPort .. "/api/chat_enabled" - callback = function(body, err) + io.timer(300, 500, function(cancel) + if not self.port then + return + end + self:request("AddTrackedWorkspace", { workspace = project_root, metadata = get_request_metadata() }, + function(_, err) if err then - notify.error("chat response", err) - cancel() + notify.error("failed to add workspace: " .. err.out) return end - notify.info("Response: " .. body) - end - io.post(url, { - body = { metadata = get_request_metadata() }, - callback = callback, - }) - end) + self.workspaces[project_root] = true + notify.info("Workspace " .. project_root .. " added") + end) + cancel() + end) +end + +function Server:shutdown() + self.current_cookie = nil + if self.job then + self.job.on_exit = nil + self.job:shutdown() end +end + +function Server:chat_server_request(fn, payload, callback) + local url = "http://127.0.0.1:" .. + self.chat_ports.chatWebServerPort .. "/exa.language_server_pb.LanguageServerService/" .. fn + local body = { metadata = get_request_metadata(), chat_message = payload } + io.post(url, { + body = body, + callback = callback, + }) +end + +---@param indent table +---@param callback function +function Server:request_chat_action(indent, callback) + local current_timestamp = os.time() + local message_id = "user-" .. tostring(current_timestamp) + local body = { + message_id = message_id, + source = 'CHAT_MESSAGE_SOURCE_USER', + timestamp = current_timestamp, + conversation_id = get_nonce(), + content = { indent = indent }, + in_progress = false + } + self:chat_server_request("GetChatMessage", body, callback) +end - function m.shutdown() - current_cookie = nil - if job then - job.on_exit = nil - job:shutdown() +function Server:request_generate_code() + self:request_chat_action(chat.intent_generate_code(), function(body, err) + if err then + notify.error("Error code: " .. err.code) + notify.error("Error message: " .. err.out) + notify.error("Error status: " .. err.status) + notify.error("Error response: " .. err.response) + else + notify.info("Explain: " .. body) end - end + end) +end - m.__index = m - return o +function Server:request_explain_code() + self:request_chat_action(chat.intent_code_block_explain(), function(body, err) + if err then + notify.error("Error code: " .. err.code) + notify.error("Error message: " .. err.out) + notify.error("Error status: " .. err.status) + notify.error("Error response: " .. err.response) + else + notify.info("Explain: " .. body) + end + end) +end + +function Server:request_docstring() + self:request_chat_action(chat.intent_function_docstring(), function(body, err) + if err then + notify.error("Error code: " .. err.code) + notify.error("Error message: " .. err.out) + notify.error("Error status: " .. err.status) + notify.error("Error response: " .. err.response) + else + notify.info("Explain: " .. body) + end + end) end return Server diff --git a/lua/codeium/init.lua b/lua/codeium/init.lua index a5f3c89..80ae065 100644 --- a/lua/codeium/init.lua +++ b/lua/codeium/init.lua @@ -19,23 +19,23 @@ function M.setup(options) M.Server = server:new() update.download(function(err) if not err then - server.load_api_key() - M.Server.start() + M.Server.load_api_key() + M.Server:start() if config.options.enable_chat then - M.Server.init_chat() + M.Server:init_chat() end - M.Server.add_workspace() + M.Server:add_workspace() end end) vim.api.nvim_create_user_command("Codeium", function(opts) local args = opts.fargs if args[1] == "Auth" then - server.authenticate() + M.Server.authenticate() end if args[1] == "Chat" then - M.Server.open_chat() - M.Server.add_workspace() + M.Server:open_chat() + M.Server:add_workspace() end end, { nargs = 1, @@ -56,26 +56,23 @@ function M.open_chat() notify.info("Codeium Chat disabled") return end - M.Server.open_chat() + M.Server:open_chat() end function M.add_workspace() - M.Server.add_workspace() + M.Server:add_workspace() end function M.generate_code() - M.Server.open_connection() - M.Server.request_generate_code() + M.Server:request_generate_code() end function M.explain() - M.Server.open_connection() - M.Server.request_explain_code() + M.Server:request_explain_code() end function M.add_docstring() - M.Server.open_connection() - M.Server.request_docstring() + M.Server:request_docstring() end return M diff --git a/lua/codeium/source.lua b/lua/codeium/source.lua index bc0ce6a..b91bcf2 100644 --- a/lua/codeium/source.lua +++ b/lua/codeium/source.lua @@ -100,6 +100,7 @@ local function get_other_documents(bufnr) end ---@class codeium.Source +---@field server codeium.Server|nil local Source = { server = nil, } @@ -114,7 +115,7 @@ function Source:new(server) end function Source:is_available() - return self.server.is_healthy() + return self.server:is_healthy() end function Source:get_position_encoding_kind() @@ -131,7 +132,7 @@ require("cmp").event:on("confirm_done", function(event) and event.entry.source.source and event.entry.source.source.server then - event.entry.source.source.server.accept_completion(event.entry.completion_item.codeium_completion_id) + event.entry.source.source.server:accept_completion(event.entry.completion_item.codeium_completion_id) end end) @@ -181,7 +182,7 @@ function Source:complete(params, callback) local other_documents = get_other_documents(bufnr) - self.server.request_completion( + self.server:request_completion( { editor_language = filetype, language = language, From 0ff4c31fd9ef66585c1c119fd89bc4ee120b406d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aliaksandr=20Tru=C5=A1?= Date: Tue, 30 Apr 2024 16:04:19 +0200 Subject: [PATCH 6/9] feat: add posibility to call actions for function under cursor --- lua/codeium/api.lua | 42 +++++++++++++++++++++++++++++++++++++++++- lua/codeium/chat.lua | 9 +++++---- lua/codeium/config.lua | 2 +- lua/codeium/init.lua | 4 ++++ lua/codeium/source.lua | 17 +---------------- lua/codeium/util.lua | 16 ++++++++++++++++ 6 files changed, 68 insertions(+), 22 deletions(-) diff --git a/lua/codeium/api.lua b/lua/codeium/api.lua index 2cc9d60..efdf211 100644 --- a/lua/codeium/api.lua +++ b/lua/codeium/api.lua @@ -5,6 +5,7 @@ local io = require("codeium.io") local log = require("codeium.log") local update = require("codeium.update") local notify = require("codeium.notify") +local utils = require("codeium.util") local api_key = nil local function noop(...) end @@ -564,7 +565,7 @@ function Server:request_explain_code() end function Server:request_docstring() - self:request_chat_action(chat.intent_function_docstring(), function(body, err) + self:request_function_action(chat.intent_function_docstring, function(body, err) if err then notify.error("Error code: " .. err.code) notify.error("Error message: " .. err.out) @@ -576,4 +577,43 @@ function Server:request_docstring() end) end + +function Server:request_refactor() + self:request_function_action(chat.intent_function_refactor, function(body, err) + if err then + notify.error("Error code: " .. err.code) + notify.error("Error message: " .. err.out) + notify.error("Error status: " .. err.status) + notify.error("Error response: " .. err.response) + else + notify.info("Explain: " .. body) + end + end) +end + +---Request action for a function under cursor. +---@param indent function +---@param callback function +function Server:request_function_action(indent, callback) + local row, _ = unpack(vim.api.nvim_win_get_cursor(0)) + self:request("GetFunctions", { document = utils.buf_to_codeium(0) }, + function(body, err) + if err then + notify.error("failed to get functions: " .. err.out) + return + end + + local ok, json = pcall(vim.fn.json_decode, body) + if ok and json then + for _, item in ipairs(json.functionCaptures) do + print("item: " .. item.nodeName) + if item.startLine <= row and item.endLine >= row then + self:request_chat_action(indent(item), callback) + return + end + end + end + end) +end + return Server diff --git a/lua/codeium/chat.lua b/lua/codeium/chat.lua index e3b7c28..57a7244 100644 --- a/lua/codeium/chat.lua +++ b/lua/codeium/chat.lua @@ -83,9 +83,10 @@ end -- codeium_common_pb.Language language = 2; -- string file_path = 3; -- string refactor_description = 4; -function chat.intent_function_refactor() +function chat.intent_function_refactor(func_info) local file_path = vim.api.nvim_buf_get_name(0) - return { function_refactor = { function_info = function_info(), language = language(), file_path = file_path, refactor_description = "" } } + local prompt = vim.fn.input("Refactor description: ") + return { function_refactor = { function_info = func_info, language = language(), file_path = file_path, refactor_description = prompt } } end -- codeium_common_pb.FunctionInfo function_info = 1; @@ -104,9 +105,9 @@ end -- codeium_common_pb.FunctionInfo function_info = 1; -- codeium_common_pb.Language language = 2; -- string file_path = 3; -function chat.intent_function_docstring() +function chat.intent_function_docstring(func_info) local file_path = vim.api.nvim_buf_get_name(0) - return { function_docstring = { function_info = function_info(), language = language(), file_path = file_path } } + return { function_docstring = { function_info = func_info, language = language(), file_path = file_path } } end --Ask to explain a generic piece of code. diff --git a/lua/codeium/config.lua b/lua/codeium/config.lua index c8a8877..221329b 100644 --- a/lua/codeium/config.lua +++ b/lua/codeium/config.lua @@ -75,7 +75,7 @@ end ---@field search_max_workspace_file_count number M.options = {} ----@param options codeium.options|nil +---@param options? codeium.options function M.setup(options) options = options or {} diff --git a/lua/codeium/init.lua b/lua/codeium/init.lua index 80ae065..dc1f54e 100644 --- a/lua/codeium/init.lua +++ b/lua/codeium/init.lua @@ -71,6 +71,10 @@ function M.explain() M.Server:request_explain_code() end +function M.refactor() + M.Server:request_refactor() +end + function M.add_docstring() M.Server:request_docstring() end diff --git a/lua/codeium/source.lua b/lua/codeium/source.lua index b91bcf2..38ba142 100644 --- a/lua/codeium/source.lua +++ b/lua/codeium/source.lua @@ -72,28 +72,13 @@ local function codeium_to_cmp(comp, offset, right) } end -local function buf_to_codeium(bufnr) - local filetype = enums.filetype_aliases[vim.bo[bufnr].filetype] or vim.bo[bufnr].filetype or "text" - local language = enums.languages[filetype] or enums.languages.unspecified - local line_ending = util.get_newline(bufnr) - local lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, true) - table.insert(lines, "") - local text = table.concat(lines, line_ending) - return { - editor_language = filetype, - language = language, - text = text, - line_ending = line_ending, - absolute_path = vim.api.nvim_buf_get_name(bufnr), - } -end local function get_other_documents(bufnr) local other_documents = {} for _, buf in ipairs(vim.api.nvim_list_bufs()) do if vim.api.nvim_buf_is_loaded(buf) and vim.bo[buf].filetype ~= '' and buf ~= bufnr then - table.insert(other_documents, buf_to_codeium(buf)) + table.insert(other_documents, util.buf_to_codeium(buf)) end end return other_documents diff --git a/lua/codeium/util.lua b/lua/codeium/util.lua index 769e2b1..57f7a53 100644 --- a/lua/codeium/util.lua +++ b/lua/codeium/util.lua @@ -35,4 +35,20 @@ function M.get_newline(bufnr) return enums.line_endings[vim.bo[bufnr].fileformat] or "\n" end +function M.buf_to_codeium(bufnr) + local filetype = enums.filetype_aliases[vim.bo[bufnr].filetype] or vim.bo[bufnr].filetype or "text" + local language = enums.languages[filetype] or enums.languages.unspecified + local line_ending = M.get_newline(bufnr) + local lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, true) + table.insert(lines, "") + local text = table.concat(lines, line_ending) + return { + editor_language = filetype, + language = language, + text = text, + line_ending = line_ending, + absolute_path = vim.api.nvim_buf_get_name(bufnr), + } +end + return M From 17e9cb7ebde317c1d049836258bcf1f160236d5e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aliaksandr=20Tru=C5=A1?= Date: Sat, 4 May 2024 16:40:03 +0200 Subject: [PATCH 7/9] send GetChatMessage via websockets --- README.md | 2 + lua/codeium/api.lua | 117 +++++++++++++++++++++++-------------------- lua/codeium/chat.lua | 44 ++-------------- lua/codeium/init.lua | 8 +++ 4 files changed, 78 insertions(+), 93 deletions(-) diff --git a/README.md b/README.md index 5547e21..bb59780 100644 --- a/README.md +++ b/README.md @@ -36,6 +36,7 @@ use { requires = { "nvim-lua/plenary.nvim", "hrsh7th/nvim-cmp", + "rohanorton/ws.nvim", }, config = function() require("codeium").setup({ @@ -52,6 +53,7 @@ use { dependencies = { "nvim-lua/plenary.nvim", "hrsh7th/nvim-cmp", + "rohanorton/ws.nvim", }, config = function() require("codeium").setup({ diff --git a/lua/codeium/api.lua b/lua/codeium/api.lua index efdf211..3d958a3 100644 --- a/lua/codeium/api.lua +++ b/lua/codeium/api.lua @@ -6,10 +6,12 @@ local log = require("codeium.log") local update = require("codeium.update") local notify = require("codeium.notify") local utils = require("codeium.util") +local wsclient = require('ws.websocket_client') local api_key = nil local function noop(...) end +---@return string local function get_nonce() local possible = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789" local nonce = "" @@ -76,6 +78,7 @@ end ---@field workspaces table ---@field healthy boolean ---@field pending_request table +---@field ws? WebSocketClient local Server = { _port = nil, job = nil, @@ -84,6 +87,7 @@ local Server = { workspaces = {}, healthy = false, pending_request = { 0, noop }, + ws = nil, } Server.__index = Server @@ -306,6 +310,9 @@ function Server:start() end) end +---@param fn string +---@param payload table +---@param callback function function Server:request(fn, payload, callback) if not self.port then notify.info("Server not started yet") @@ -506,95 +513,99 @@ end function Server:shutdown() self.current_cookie = nil + if self.ws then + self.ws.close() + end if self.job then self.job.on_exit = nil self.job:shutdown() end end -function Server:chat_server_request(fn, payload, callback) - local url = "http://127.0.0.1:" .. - self.chat_ports.chatWebServerPort .. "/exa.language_server_pb.LanguageServerService/" .. fn - local body = { metadata = get_request_metadata(), chat_message = payload } - io.post(url, { - body = body, - callback = callback, - }) +---@param payload table +function Server:chat_server_request(payload) + local body = { get_chat_message_request = { metadata = get_request_metadata(), chat_message = payload } } + local input_string = vim.fn.json_encode(body) + print("request: " .. input_string) + + -- self.ws.send(input_string, { is_binary = true }) + self.ws.send(input_string) + print("request sent") end ---@param indent table ---@param callback function -function Server:request_chat_action(indent, callback) +function Server:request_chat_action(indent) local current_timestamp = os.time() local message_id = "user-" .. tostring(current_timestamp) local body = { message_id = message_id, - source = 'CHAT_MESSAGE_SOURCE_USER', + -- source = 'CHAT_MESSAGE_SOURCE_USER', + source = 1, timestamp = current_timestamp, conversation_id = get_nonce(), content = { indent = indent }, in_progress = false } - self:chat_server_request("GetChatMessage", body, callback) + self:chat_server_request(body) end function Server:request_generate_code() - self:request_chat_action(chat.intent_generate_code(), function(body, err) - if err then - notify.error("Error code: " .. err.code) - notify.error("Error message: " .. err.out) - notify.error("Error status: " .. err.status) - notify.error("Error response: " .. err.response) - else - notify.info("Explain: " .. body) - end - end) + self:request_chat_action(chat.intent_generate_code()) end function Server:request_explain_code() - self:request_chat_action(chat.intent_code_block_explain(), function(body, err) - if err then - notify.error("Error code: " .. err.code) - notify.error("Error message: " .. err.out) - notify.error("Error status: " .. err.status) - notify.error("Error response: " .. err.response) - else - notify.info("Explain: " .. body) - end - end) + self:request_chat_action(chat.intent_code_block_explain()) end function Server:request_docstring() - self:request_function_action(chat.intent_function_docstring, function(body, err) - if err then - notify.error("Error code: " .. err.code) - notify.error("Error message: " .. err.out) - notify.error("Error status: " .. err.status) - notify.error("Error response: " .. err.response) - else - notify.info("Explain: " .. body) - end - end) + self:request_function_action(chat.intent_function_docstring) end - function Server:request_refactor() - self:request_function_action(chat.intent_function_refactor, function(body, err) - if err then - notify.error("Error code: " .. err.code) - notify.error("Error message: " .. err.out) - notify.error("Error status: " .. err.status) - notify.error("Error response: " .. err.response) + self:request_function_action(chat.intent_function_refactor) +end + +function Server:connect_ide() + local url = "ws://127.0.0.1:" .. self.chat_ports.chatWebServerPort .. "/connect/ide" + -- local url = "ws://echo.websocket.in/" + print("Connecting to " .. url) + local ws = wsclient(url) + + ws.on_close(function() + print("Websocket closed ") + end) + + ws.on_open(function() + print("Websocket open") + end) + + ws.on_error(function(err) + print("Websocket error " .. err) + end) + + ws.on_message(function(msg, is_binary) + if is_binary then + print("Binary message received") else - notify.info("Explain: " .. body) + local _msg = msg:to_string() + print("Message received " .. _msg) end end) + + -- Connect to server. + ws.connect() + + self.ws = ws +end + +function Server:close() + self.ws.close() end ---Request action for a function under cursor. ---@param indent function ----@param callback function -function Server:request_function_action(indent, callback) +function Server:request_function_action(indent) local row, _ = unpack(vim.api.nvim_win_get_cursor(0)) self:request("GetFunctions", { document = utils.buf_to_codeium(0) }, function(body, err) @@ -606,9 +617,9 @@ function Server:request_function_action(indent, callback) local ok, json = pcall(vim.fn.json_decode, body) if ok and json then for _, item in ipairs(json.functionCaptures) do - print("item: " .. item.nodeName) + -- print("item: " .. item.nodeName) if item.startLine <= row and item.endLine >= row then - self:request_chat_action(indent(item), callback) + self:request_chat_action(indent(item)) return end end diff --git a/lua/codeium/chat.lua b/lua/codeium/chat.lua index 57a7244..fd301a2 100644 --- a/lua/codeium/chat.lua +++ b/lua/codeium/chat.lua @@ -3,42 +3,6 @@ local util = require("codeium.util") ---@class codeium.Chat local chat = {} -function chat.intent_generic(text) - return { text = text } -end - --- string raw_source = 1; --- string clean_function = 2; --- string docstring = 3; --- string node_name = 4; --- string params = 5; --- int32 definition_line = 6; --- int32 start_line = 7; --- int32 end_line = 8; --- int32 start_col = 9; --- int32 end_col = 10; --- string leading_whitespace = 11; --- Language language = 12; -local function function_info() - local bufnr = vim.api.nvim_get_current_buf() - local filetype = enums.filetype_aliases[vim.bo[bufnr].filetype] or vim.bo[bufnr].filetype or "text" - local language = enums.languages[filetype] or enums.languages.unspecified - return { - raw_source = "", - clean_function = "", - docstring = "", - node_name = "", - params = "", - definition_line = 6, - start_line = 7, - end_line = 8, - start_col = 9, - end_col = 10, - leading_whitespace = "", - language = language - } -end - ---@return number local function language() local bufnr = vim.api.nvim_get_current_buf() @@ -75,8 +39,8 @@ end -- codeium_common_pb.FunctionInfo function_info = 1; -- codeium_common_pb.Language language = 2; -- string file_path = 3; -function chat.intent_function_explain() - return { explain_function = { function_info = function_info(), language = language(), file_path = vim.api.nvim_buf_get_name(0) } } +function chat.intent_function_explain(func_info) + return { explain_function = { function_info = func_info, language = language(), file_path = vim.api.nvim_buf_get_name(0) } } end -- codeium_common_pb.FunctionInfo function_info = 1; @@ -95,10 +59,10 @@ end -- -- --Optional additional instructions to inform what tests to generate. -- string instructions = 4; -function chat.intent_function_unit_tests() +function chat.intent_function_unit_tests(func_info) local prompt = vim.fn.input("Unit test instructions: ") local file_path = vim.api.nvim_buf_get_name(0) - return { function_unit_tests = { function_info = function_info(), language = language(), file_path = file_path, instructions = prompt } } + return { function_unit_tests = { function_info = func_info, language = language(), file_path = file_path, instructions = prompt } } end --Ask for a docstring for a function. diff --git a/lua/codeium/init.lua b/lua/codeium/init.lua index dc1f54e..fb68e42 100644 --- a/lua/codeium/init.lua +++ b/lua/codeium/init.lua @@ -79,4 +79,12 @@ function M.add_docstring() M.Server:request_docstring() end +function M.connect_ide() + M.Server:connect_ide() +end + +function M.stop() + M.Server:close() +end + return M From 40acd216f6042f427352f558b6d5cfd6435118f1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aliaksandr=20Tru=C5=A1?= Date: Tue, 7 May 2024 22:23:44 +0200 Subject: [PATCH 8/9] minor updates to sending chat request --- lua/codeium/api.lua | 32 +++++++++++++++----------------- 1 file changed, 15 insertions(+), 17 deletions(-) diff --git a/lua/codeium/api.lua b/lua/codeium/api.lua index 3d958a3..ef6780f 100644 --- a/lua/codeium/api.lua +++ b/lua/codeium/api.lua @@ -301,7 +301,6 @@ function Server:start() end self.port = find_port(manager_dir, start_time) - -- port = 42100 if self.port then notify.info("Codeium server started on port " .. self.port) cancel() @@ -528,26 +527,26 @@ function Server:chat_server_request(payload) local input_string = vim.fn.json_encode(body) print("request: " .. input_string) - -- self.ws.send(input_string, { is_binary = true }) self.ws.send(input_string) print("request sent") end ----@param indent table ----@param callback function -function Server:request_chat_action(indent) - local current_timestamp = os.time() - local message_id = "user-" .. tostring(current_timestamp) - local body = { +---@param intent table +function Server:request_chat_action(intent) + local current_timestamp = { + seconds = os.time(), + nanos = os.clock() * 1000000000 -- Assuming you want nanoseconds precision + } + local message_id = "user-" .. tostring(current_timestamp.nanos) + local chat_message = { message_id = message_id, - -- source = 'CHAT_MESSAGE_SOURCE_USER', - source = 1, + source = 'CHAT_MESSAGE_SOURCE_USER', timestamp = current_timestamp, conversation_id = get_nonce(), - content = { indent = indent }, + intent = intent, in_progress = false } - self:chat_server_request(body) + self:chat_server_request(chat_message) end function Server:request_generate_code() @@ -555,7 +554,7 @@ function Server:request_generate_code() end function Server:request_explain_code() - self:request_chat_action(chat.intent_code_block_explain()) + self:request_function_action(chat.intent_function_explain) end function Server:request_docstring() @@ -568,7 +567,6 @@ end function Server:connect_ide() local url = "ws://127.0.0.1:" .. self.chat_ports.chatWebServerPort .. "/connect/ide" - -- local url = "ws://echo.websocket.in/" print("Connecting to " .. url) local ws = wsclient(url) @@ -604,8 +602,8 @@ function Server:close() end ---Request action for a function under cursor. ----@param indent function -function Server:request_function_action(indent) +---@param intent function +function Server:request_function_action(intent) local row, _ = unpack(vim.api.nvim_win_get_cursor(0)) self:request("GetFunctions", { document = utils.buf_to_codeium(0) }, function(body, err) @@ -619,7 +617,7 @@ function Server:request_function_action(indent) for _, item in ipairs(json.functionCaptures) do -- print("item: " .. item.nodeName) if item.startLine <= row and item.endLine >= row then - self:request_chat_action(indent(item)) + self:request_chat_action(intent(item)) return end end From 1c752057436a4787ccff1c6ff5e18a4f01d291be Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aliaksandr=20Tru=C5=A1?= Date: Fri, 10 May 2024 11:27:00 +0200 Subject: [PATCH 9/9] fix chat_messages request attribute --- lua/codeium/api.lua | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lua/codeium/api.lua b/lua/codeium/api.lua index ef6780f..1232ee6 100644 --- a/lua/codeium/api.lua +++ b/lua/codeium/api.lua @@ -523,7 +523,7 @@ end ---@param payload table function Server:chat_server_request(payload) - local body = { get_chat_message_request = { metadata = get_request_metadata(), chat_message = payload } } + local body = { get_chat_message_request = { metadata = get_request_metadata(), chat_messages = {payload} } } local input_string = vim.fn.json_encode(body) print("request: " .. input_string)