diff --git a/README.md b/README.md index 5547e21..bb59780 100644 --- a/README.md +++ b/README.md @@ -36,6 +36,7 @@ use { requires = { "nvim-lua/plenary.nvim", "hrsh7th/nvim-cmp", + "rohanorton/ws.nvim", }, config = function() require("codeium").setup({ @@ -52,6 +53,7 @@ use { dependencies = { "nvim-lua/plenary.nvim", "hrsh7th/nvim-cmp", + "rohanorton/ws.nvim", }, config = function() require("codeium").setup({ diff --git a/lua/codeium/api.lua b/lua/codeium/api.lua index 6217752..1232ee6 100644 --- a/lua/codeium/api.lua +++ b/lua/codeium/api.lua @@ -1,12 +1,29 @@ local versions = require("codeium.versions") +local chat = require("codeium.chat") local config = require("codeium.config") local io = require("codeium.io") local log = require("codeium.log") local update = require("codeium.update") local notify = require("codeium.notify") -local util = require("codeium.util") +local utils = require("codeium.util") +local wsclient = require('ws.websocket_client') local api_key = nil +local function noop(...) end + +---@return string +local function get_nonce() + local possible = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789" + local nonce = "" + + for _ = 1, 32 do + local randomIndex = math.random(1, #possible) + nonce = nonce .. string.sub(possible, randomIndex, randomIndex) + end + + return nonce +end + local function find_port(manager_dir, start_time) local files = io.readdir(manager_dir) @@ -36,7 +53,42 @@ local function get_request_metadata(request_id) } end -local Server = {} +local codeium_workspace_root_hints = { '.bzr', '.git', '.hg', '.svn', '_FOSSIL_', 'package.json' } +local function get_project_root() + local last_dir = '' + local dir = vim.fn.getcwd() + while dir ~= last_dir do + for root_hint in ipairs(codeium_workspace_root_hints) do + local hint = dir .. '/' .. root_hint + if vim.fn.isdirectory(hint) or vim.fn.filereadable(hint) then + return dir + end + end + last_dir = dir + dir = vim.fn.fnamemodify(dir, ':h') + end + return vim.fn.getcwd() +end + +---@class codeium.Server +---@field port? number +---@field job? plenary.Job +---@field chat_ports? table +---@field current_cookie? number +---@field workspaces table +---@field healthy boolean +---@field pending_request table +---@field ws? WebSocketClient +local Server = { + _port = nil, + job = nil, + chat_ports = { chatClientPort = nil, chatWebServerPort = nil }, + current_cookie = nil, + workspaces = {}, + healthy = false, + pending_request = { 0, noop }, + ws = nil, +} Server.__index = Server function Server.load_api_key() @@ -56,7 +108,7 @@ function Server.load_api_key() api_key = json.api_key end -function Server.save_api_key() +local function save_api_key() local _, result = io.write_json(config.options.config_path, { api_key = api_key, }) @@ -110,7 +162,7 @@ function Server.authenticate() end if json and json.api_key and json.api_key ~= "" then api_key = json.api_key - Server.save_api_key() + save_api_key() notify.info("api key saved") return end @@ -133,322 +185,444 @@ function Server.authenticate() prompt() end +---@return codeium.Server function Server:new() local m = {} setmetatable(m, self) - local o = {} - setmetatable(o, m) + m.__index = m + return m +end - local port = nil - local job = nil - local current_cookie = nil - local workspaces = {} - local healthy = false +function Server:start() + self:shutdown() - local function request(fn, payload, callback) - local url = "http://127.0.0.1:" .. port .. "/exa.language_server_pb.LanguageServerService/" .. fn - io.post(url, { - body = payload, - callback = callback, - }) - end + self.current_cookie = next_cookie() - local function do_heartbeat() - request("Heartbeat", { - metadata = get_request_metadata(), - }, function(_, err) - if err then - notify.warn("heartbeat failed", err) - else - healthy = true - end - end) + if not api_key then + io.timer(1000, 0, self.start) + return end - function m.is_healthy() - return healthy + local manager_dir = config.options.manager_path + if not manager_dir then + manager_dir = io.tempdir("codeium/manager") + vim.fn.mkdir(manager_dir, "p") end - function m.start() - m.shutdown() + local start_time = io.touch(manager_dir .. "/start") - current_cookie = next_cookie() - - if not api_key then - io.timer(1000, 0, m.start) + local function on_exit(_, err) + if not self.current_cookie then return end - local manager_dir = config.manager_path - if not manager_dir then - manager_dir = io.tempdir("codeium/manager") - vim.fn.mkdir(manager_dir, "p") + self.healthy = false + if err then + self.job = nil + self.current_cookie = nil + + notify.error("codeium server crashed", err) + io.timer(1000, 0, function() + log.debug("restarting server after crash") + self:start() + end) end + end - local start_time = io.touch(manager_dir .. "/start") + local function on_output(_, v, j) + log.debug(j.pid .. ": " .. v) + end - local function on_exit(_, err) - if not current_cookie then - return - end + local api_server_url = "https://" + .. config.options.api.host + .. ":" + .. config.options.api.port + .. (config.options.api.path and "/" .. config.options.api.path:gsub("^/", "") or "") + + local job_args = { + update.get_bin_info().bin, + "--api_server_url", + api_server_url, + "--manager_dir", + manager_dir, + enable_handlers = true, + enable_recording = false, + on_exit = on_exit, + on_stdout = on_output, + on_stderr = on_output, + } - healthy = false - if err then - job = nil - current_cookie = nil - - notify.error("codeium server crashed", err) - io.timer(1000, 0, function() - log.debug("restarting server after crash") - m.start() - end) - end - end + if config.options.enable_chat then + table.insert(job_args, "--enable_chat_web_server") + table.insert(job_args, "--enable_chat_client") + end - local function on_output(_, v, j) - log.debug(j.pid .. ": " .. v) - end + if config.options.enable_local_search then + table.insert(job_args, "--enable_local_search") + end - local api_server_url = "https://" - .. config.options.api.host - .. ":" - .. config.options.api.port - .. (config.options.api.path and "/" .. config.options.api.path:gsub("^/", "") or "") - - local job_args = { - update.get_bin_info().bin, - "--api_server_url", - api_server_url, - "--manager_dir", - manager_dir, - enable_handlers = true, - enable_recording = false, - on_exit = on_exit, - on_stdout = on_output, - on_stderr = on_output, - } - - if config.options.enable_chat then - table.insert(job_args, "--enable_chat_web_server") - table.insert(job_args, "--enable_chat_client") - end + if config.options.enable_index_service then + table.insert(job_args, "--enable_index_service") + table.insert(job_args, "--search_max_workspace_file_count") + table.insert(job_args, config.options.search_max_workspace_file_count) + end - if config.options.enable_local_search then - table.insert(job_args, "--enable_local_search") - end + if config.options.api.portal_url then + table.insert(job_args, "--portal_url") + table.insert(job_args, "https://" .. config.options.api.portal_url) + end - if config.options.enable_index_service then - table.insert(job_args, "--enable_index_service") - table.insert(job_args, "--search_max_workspace_file_count") - table.insert(job_args, config.options.search_max_workspace_file_count) - end + if config.options.enterprise_mode then + table.insert(job_args, "--enterprise_mode") + end - if config.options.api.portal_url then - table.insert(job_args, "--portal_url") - table.insert(job_args, "https://" .. config.options.api.portal_url) - end + if config.options.detect_proxy ~= nil then + table.insert(job_args, "--detect_proxy=" .. tostring(config.options.detect_proxy)) + end - if config.options.enterprise_mode then - table.insert(job_args, "--enterprise_mode") + local job = io.job(job_args) + job:start() + + local function start_heartbeat() + io.timer(100, 5000, function(cancel_heartbeat) + if not self.current_cookie then + cancel_heartbeat() + else + self:do_heartbeat() + end + end) + end + + io.timer(100, 500, function(cancel) + if not self.current_cookie then + cancel() + return end - if config.options.detect_proxy ~= nil then - table.insert(job_args, "--detect_proxy=" .. tostring(config.options.detect_proxy)) + self.port = find_port(manager_dir, start_time) + if self.port then + notify.info("Codeium server started on port " .. self.port) + cancel() + start_heartbeat() end + end) +end - local job = io.job(job_args) - job:start() +---@param fn string +---@param payload table +---@param callback function +function Server:request(fn, payload, callback) + if not self.port then + notify.info("Server not started yet") + return + end + local url = "http://127.0.0.1:" .. self.port .. "/exa.language_server_pb.LanguageServerService/" .. fn + io.post(url, { + body = payload, + callback = callback, + }) +end - local function start_heartbeat() - io.timer(100, 5000, function(cancel_heartbeat) - if not current_cookie then - cancel_heartbeat() - else - do_heartbeat() - end - end) +function Server:init_chat() + io.timer(200, 500, function(cancel) + if not self.port then + return end - - io.timer(100, 500, function(cancel) - if not current_cookie then + self:request("GetProcesses", { + metadata = get_request_metadata(), + }, function(body, err) + if err then + notify.error("failed to get chat ports", err) cancel() return end - - port = find_port(manager_dir, start_time) - if port then - cancel() - start_heartbeat() - end + self.chat_ports = vim.fn.json_decode(body) + notify.info("Codeium chat ready to use on server ports: client port " .. + self.chat_ports.chatClientPort .. " and server port " .. self.chat_ports.chatWebServerPort) + cancel() end) - end + end) +end - local function noop(...) end +function Server:request_completion(document, editor_options, other_documents, callback) + self.pending_request[2](true) - local pending_request = { 0, noop } - function m.request_completion(document, editor_options, other_documents, callback) - pending_request[2](true) + local metadata = get_request_metadata() + local this_pending_request - local metadata = get_request_metadata() - local this_pending_request + local complete + complete = function(...) + complete = noop + this_pending_request(false) + callback(...) + end - local complete - complete = function(...) - complete = noop - this_pending_request(false) - callback(...) + this_pending_request = function(is_complete) + if self.pending_request[1] == metadata.request_id then + self.pending_request = { 0, noop } end + this_pending_request = noop - this_pending_request = function(is_complete) - if pending_request[1] == metadata.request_id then - pending_request = { 0, noop } + self:request("CancelRequest", { + metadata = get_request_metadata(), + request_id = metadata.request_id, + }, function(_, err) + if err then + log.warn("failed to cancel in-flight request", err) end - this_pending_request = noop - - request("CancelRequest", { - metadata = get_request_metadata(), - request_id = metadata.request_id, - }, function(_, err) - if err then - log.warn("failed to cancel in-flight request", err) - end - end) + end) - if is_complete then - complete(false, nil) - end + if is_complete then + complete(false, nil) end - pending_request = { metadata.request_id, this_pending_request } + end + self.pending_request = { metadata.request_id, this_pending_request } + + self:request("GetCompletions", { + metadata = metadata, + editor_options = editor_options, + document = document, + other_documents = other_documents, + }, function(body, err) + if err then + if err.status == 503 or err.status == 408 then + -- Service Unavailable or Timeout error + return complete(false, nil) + end - request("GetCompletions", { - metadata = metadata, - editor_options = editor_options, - document = document, - other_documents = other_documents, - }, function(body, err) - if err then - if err.status == 503 or err.status == 408 then - -- Service Unavailable or Timeout error + local ok, json = pcall(vim.fn.json_decode, err.response.body) + if ok and json then + if json.state and json.state.state == "CODEIUM_STATE_INACTIVE" then + if json.state.message then + log.debug("completion request failed", json.state.message) + end return complete(false, nil) end - - local ok, json = pcall(vim.fn.json_decode, err.response.body) - if ok and json then - if json.state and json.state.state == "CODEIUM_STATE_INACTIVE" then - if json.state.message then - log.debug("completion request failed", json.state.message) - end - return complete(false, nil) - end - if json.code == "canceled" then - log.debug("completion request cancelled at the server", json.message) - return complete(false, nil) - end + if json.code == "canceled" then + log.debug("completion request cancelled at the server", json.message) + return complete(false, nil) end - - notify.error("completion request failed", err) - complete(false, nil) - return end - local ok, json = pcall(vim.fn.json_decode, body) - if not ok then - notify.error("completion request failed", "invalid JSON:", json) - return - end + notify.error("completion request failed", err) + complete(false, nil) + return + end - log.trace("completion: ", json) - complete(true, json) - end) + local ok, json = pcall(vim.fn.json_decode, body) + if not ok then + notify.error("completion request failed", "invalid JSON:", json) + return + end - return function() - this_pending_request(true) + log.trace("completion: ", json) + complete(true, json) + end) + + return function() + this_pending_request(true) + end +end + +function Server:accept_completion(completion_id) + self:request("AcceptCompletion", { + metadata = get_request_metadata(), + completion_id = completion_id, + }, noop) +end + +function Server:do_heartbeat() + self:request("Heartbeat", { + metadata = get_request_metadata(), + }, function(_, err) + if err then + notify.warn("heartbeat failed", err) + else + self.healthy = true end + end) +end + +function Server:is_healthy() + return self.healthy +end + +function Server:open_chat() + if self.chat_ports == nil then + notify.error("chat ports not found") + return + end + local url = "http://127.0.0.1:" + .. self.chat_ports.chatClientPort + .. "?api_key=" + .. api_key + .. "&has_enterprise_extension=" + .. (config.options.enterprise_mode and "true" or "false") + .. "&web_server_url=ws://127.0.0.1:" + .. self.chat_ports.chatWebServerPort + .. "&ide_name=neovim" + .. "&ide_version=" + .. versions.nvim + .. "&app_name=codeium.nvim" + .. "&extension_name=codeium.nvim" + .. "&extension_version=" + .. versions.extension + .. "&ide_telemetry_enabled=true" + .. "&has_index_service=" + .. (config.options.enable_index_service and "true" or "false") + .. "&locale=en_US" + + -- cross-platform solution to open the web app + local os_info = io.get_system_info() + if os_info.os == "linux" then + os.execute("xdg-open '" .. url .. "'") + elseif os_info.os == "macos" then + os.execute("open '" .. url .. "'") + elseif os_info.os == "windows" then + os.execute("start " .. url) + else + notify.error("Unsupported operating system") end +end - function m.accept_completion(completion_id) - request("AcceptCompletion", { - metadata = get_request_metadata(), - completion_id = completion_id, - }, noop) +function Server:add_workspace() + local project_root = get_project_root() + -- workspace already tracked by server + if self.workspaces[project_root] then + return end - function m.add_workspace() - local project_root = vim.fn.getcwd() - -- workspace already tracked by server - if workspaces[project_root] then + io.timer(300, 500, function(cancel) + if not self.port then return end - -- unable to track hidden path - for entry in project_root:gmatch("[^/]+") do - if entry:sub(1, 1) == "." then - return - end - end + self:request("AddTrackedWorkspace", { workspace = project_root, metadata = get_request_metadata() }, + function(_, err) + if err then + notify.error("failed to add workspace: " .. err.out) + return + end + self.workspaces[project_root] = true + notify.info("Workspace " .. project_root .. " added") + end) + cancel() + end) +end - request("AddTrackedWorkspace", { workspace = project_root, metadata = get_request_metadata() }, function(_, err) - if err then - notify.error("failed to add workspace: " .. err.out) - return - end - workspaces[project_root] = true - end) +function Server:shutdown() + self.current_cookie = nil + if self.ws then + self.ws.close() + end + if self.job then + self.job.on_exit = nil + self.job:shutdown() end +end - function m.get_chat_ports() - request("GetProcesses", { - metadata = get_request_metadata(), - }, function(body, err) +---@param payload table +function Server:chat_server_request(payload) + local body = { get_chat_message_request = { metadata = get_request_metadata(), chat_messages = {payload} } } + local input_string = vim.fn.json_encode(body) + print("request: " .. input_string) + + self.ws.send(input_string) + print("request sent") +end + +---@param intent table +function Server:request_chat_action(intent) + local current_timestamp = { + seconds = os.time(), + nanos = os.clock() * 1000000000 -- Assuming you want nanoseconds precision + } + local message_id = "user-" .. tostring(current_timestamp.nanos) + local chat_message = { + message_id = message_id, + source = 'CHAT_MESSAGE_SOURCE_USER', + timestamp = current_timestamp, + conversation_id = get_nonce(), + intent = intent, + in_progress = false + } + self:chat_server_request(chat_message) +end + +function Server:request_generate_code() + self:request_chat_action(chat.intent_generate_code()) +end + +function Server:request_explain_code() + self:request_function_action(chat.intent_function_explain) +end + +function Server:request_docstring() + self:request_function_action(chat.intent_function_docstring) +end + +function Server:request_refactor() + self:request_function_action(chat.intent_function_refactor) +end + +function Server:connect_ide() + local url = "ws://127.0.0.1:" .. self.chat_ports.chatWebServerPort .. "/connect/ide" + print("Connecting to " .. url) + local ws = wsclient(url) + + ws.on_close(function() + print("Websocket closed ") + end) + + ws.on_open(function() + print("Websocket open") + end) + + ws.on_error(function(err) + print("Websocket error " .. err) + end) + + ws.on_message(function(msg, is_binary) + if is_binary then + print("Binary message received") + else + local _msg = msg:to_string() + print("Message received " .. _msg) + end + end) + + -- Connect to server. + ws.connect() + + self.ws = ws +end + +function Server:close() + self.ws.close() +end + +---Request action for a function under cursor. +---@param intent function +function Server:request_function_action(intent) + local row, _ = unpack(vim.api.nvim_win_get_cursor(0)) + self:request("GetFunctions", { document = utils.buf_to_codeium(0) }, + function(body, err) if err then - notify.error("failed to get chat ports", err) + notify.error("failed to get functions: " .. err.out) return end - local ports = vim.fn.json_decode(body) - local url = "http://127.0.0.1:" - .. ports.chatClientPort - .. "?api_key=" - .. api_key - .. "&has_enterprise_extension=" - .. (config.options.enterprise_mode and "true" or "false") - .. "&web_server_url=ws://127.0.0.1:" - .. ports.chatWebServerPort - .. "&ide_name=neovim" - .. "&ide_version=" - .. versions.nvim - .. "&app_name=codeium.nvim" - .. "&extension_name=codeium.nvim" - .. "&extension_version=" - .. versions.extension - .. "&ide_telemetry_enabled=true" - .. "&has_index_service=" - .. (config.options.enable_index_service and "true" or "false") - .. "&locale=en_US" - - -- cross-platform solution to open the web app - local os_info = io.get_system_info() - if os_info.os == "linux" then - os.execute("xdg-open '" .. url .. "'") - elseif os_info.os == "macos" then - os.execute("open '" .. url .. "'") - elseif os_info.os == "windows" then - os.execute("start " .. url) - else - notify.error("Unsupported operating system") + + local ok, json = pcall(vim.fn.json_decode, body) + if ok and json then + for _, item in ipairs(json.functionCaptures) do + -- print("item: " .. item.nodeName) + if item.startLine <= row and item.endLine >= row then + self:request_chat_action(intent(item)) + return + end + end end end) - end - - function m.shutdown() - current_cookie = nil - if job then - job.on_exit = nil - job:shutdown() - end - end - - m.__index = m - return o end return Server diff --git a/lua/codeium/chat.lua b/lua/codeium/chat.lua new file mode 100644 index 0000000..fd301a2 --- /dev/null +++ b/lua/codeium/chat.lua @@ -0,0 +1,130 @@ +local enums = require("codeium.enums") +local util = require("codeium.util") +---@class codeium.Chat +local chat = {} + +---@return number +local function language() + local bufnr = vim.api.nvim_get_current_buf() + local filetype = enums.filetype_aliases[vim.bo[bufnr].filetype] or vim.bo[bufnr].filetype or "text" + return enums.languages[filetype] or enums.languages.unspecified +end + +-- string raw_source = 1; +-- +-- // Start position of the code block. +-- int32 start_line = 2; +-- int32 start_col = 3; +-- +-- // End position of the code block. +-- int32 end_line = 4; +-- int32 end_col = 5; +local function code_block_info() + -- Get the current buffer + local current_buffer = vim.api.nvim_get_current_buf() + + -- Get the start and end positions of the visual selection + local start_line, start_col = unpack(vim.api.nvim_buf_get_mark(current_buffer, "<")) + local end_line, end_col = unpack(vim.api.nvim_buf_get_mark(current_buffer, ">")) + + local lines = vim.api.nvim_buf_get_lines(current_buffer, start_line, end_line, true) + local line_ending = util.get_newline(current_buffer) + table.insert(lines, "") + local text = table.concat(lines, line_ending) + + return { raw_source = text, start_line = start_line, start_col = start_col, end_line = end_line, end_col = end_col } +end + + +-- codeium_common_pb.FunctionInfo function_info = 1; +-- codeium_common_pb.Language language = 2; +-- string file_path = 3; +function chat.intent_function_explain(func_info) + return { explain_function = { function_info = func_info, language = language(), file_path = vim.api.nvim_buf_get_name(0) } } +end + +-- codeium_common_pb.FunctionInfo function_info = 1; +-- codeium_common_pb.Language language = 2; +-- string file_path = 3; +-- string refactor_description = 4; +function chat.intent_function_refactor(func_info) + local file_path = vim.api.nvim_buf_get_name(0) + local prompt = vim.fn.input("Refactor description: ") + return { function_refactor = { function_info = func_info, language = language(), file_path = file_path, refactor_description = prompt } } +end + +-- codeium_common_pb.FunctionInfo function_info = 1; +-- codeium_common_pb.Language language = 2; +-- string file_path = 3; +-- +-- --Optional additional instructions to inform what tests to generate. +-- string instructions = 4; +function chat.intent_function_unit_tests(func_info) + local prompt = vim.fn.input("Unit test instructions: ") + local file_path = vim.api.nvim_buf_get_name(0) + return { function_unit_tests = { function_info = func_info, language = language(), file_path = file_path, instructions = prompt } } +end + +--Ask for a docstring for a function. +-- codeium_common_pb.FunctionInfo function_info = 1; +-- codeium_common_pb.Language language = 2; +-- string file_path = 3; +function chat.intent_function_docstring(func_info) + local file_path = vim.api.nvim_buf_get_name(0) + return { function_docstring = { function_info = func_info, language = language(), file_path = file_path } } +end + +--Ask to explain a generic piece of code. +-- CodeBlockInfo code_block_info = 1; +-- codeium_common_pb.Language language = 2; +-- string file_path = 3; +function chat.intent_code_block_explain() + local file_path = vim.api.nvim_buf_get_name(0) + return { code_block_explain = { code_block_info = code_block_info(), language = language(), file_path = file_path } } +end + +--Ask to refactor a generic piece of code. +-- CodeBlockInfo code_block_info = 1; +-- codeium_common_pb.Language language = 2; +-- string file_path = 3; +-- string refactor_description = 4; +function chat.intent_code_block_refactor() + local file_path = vim.api.nvim_buf_get_name(0) + return { code_block_refactor = { code_block_info = code_block_info(), language = language(), file_path = file_path, refactor_description = "" } } +end + +--Ask to explain a problem. +-- string diagnostic_function chat.= 1; +-- CodeBlockInfo problematic_code = 2; //entire code block with error +-- string surrounding_code_snippet = 3; +-- codeium_common_pb.Language language = 4; +-- string file_path = 5; +-- int32 line_number = 6; +function chat.intent_problem_explain() + local file_path = vim.api.nvim_buf_get_name(0) + return { + problem_explain = { + diagnostic_function = "", + problematic_code = code_block_info(), + surrounding_code_snippet = "", + language = language(), + file_path = file_path, + line_number = 0 + } + } +end + +--Ask to generate a piece of code. +-- string instruction = 1; +-- codeium_common_pb.Language language = 2; +-- string file_path = 3; +-- --Line to insert the generated code into. +-- int32 line_number = 4; +function chat.intent_generate_code() + local prompt = vim.fn.input("Please add prompt:") + local file_path = vim.api.nvim_buf_get_name(0) + local line_number = vim.api.nvim_win_get_cursor(0)[1] + return { generate_code = { instruction = prompt, language = language(), file_path = file_path, line_number = line_number } } +end + +return chat diff --git a/lua/codeium/config.lua b/lua/codeium/config.lua index 9bcf1de..221329b 100644 --- a/lua/codeium/config.lua +++ b/lua/codeium/config.lua @@ -1,7 +1,10 @@ local notify = require("codeium.notify") +---@class codeium.config +---@field options codeium.options local M = {} +---@return codeium.options function M.defaults() return { manager_path = nil, @@ -56,8 +59,23 @@ function M.apply_conditional_defaults(options) return options end +---@class codeium.options +---@field manager_path string +---@field bin_path string +---@field config_path string +---@field language_server_download_url string +---@field api table +---@field enterprise_mode boolean +---@field detect_proxy boolean +---@field tools table +---@field wrapper function +---@field enable_chat boolean +---@field enable_local_search boolean +---@field enable_index_service boolean +---@field search_max_workspace_file_count number M.options = {} +---@param options? codeium.options function M.setup(options) options = options or {} diff --git a/lua/codeium/init.lua b/lua/codeium/init.lua index f57046d..fb68e42 100644 --- a/lua/codeium/init.lua +++ b/lua/codeium/init.lua @@ -1,41 +1,90 @@ -local M = {} +local notify = require("codeium.notify") + +---@class codeium +---@field Server codeium.Server|nil +---@field Config codeium.options|nil +local M = { + Server = nil, + Config = nil +} function M.setup(options) - local Source = require("codeium.source") - local Server = require("codeium.api") + local source = require("codeium.source") + local server = require("codeium.api") local update = require("codeium.update") - require("codeium.config").setup(options) + local config = require("codeium.config") + config.setup(options) + M.Config = config.options - local s = Server:new() + M.Server = server:new() update.download(function(err) if not err then - Server.load_api_key() - s.start() + M.Server.load_api_key() + M.Server:start() + if config.options.enable_chat then + M.Server:init_chat() + end + M.Server:add_workspace() end end) vim.api.nvim_create_user_command("Codeium", function(opts) local args = opts.fargs if args[1] == "Auth" then - Server.authenticate() + M.Server.authenticate() end if args[1] == "Chat" then - s.get_chat_ports() - s.add_workspace() + M.Server:open_chat() + M.Server:add_workspace() end end, { nargs = 1, complete = function() local commands = {"Auth"} - if require("codeium.config").options.enable_chat then + if config.options.enable_chat then commands = vim.list_extend(commands, {"Chat"}) end return commands end, }) - local source = Source:new(s) - require("cmp").register_source("codeium", source) + require("cmp").register_source("codeium", source:new(M.Server)) +end + +function M.open_chat() + if not M.Config.enable_chat then + notify.info("Codeium Chat disabled") + return + end + M.Server:open_chat() +end + +function M.add_workspace() + M.Server:add_workspace() +end + +function M.generate_code() + M.Server:request_generate_code() +end + +function M.explain() + M.Server:request_explain_code() +end + +function M.refactor() + M.Server:request_refactor() +end + +function M.add_docstring() + M.Server:request_docstring() +end + +function M.connect_ide() + M.Server:connect_ide() +end + +function M.stop() + M.Server:close() end return M diff --git a/lua/codeium/io.lua b/lua/codeium/io.lua index 7cf009b..efe4265 100644 --- a/lua/codeium/io.lua +++ b/lua/codeium/io.lua @@ -6,6 +6,7 @@ local curl = require("plenary.curl") local config = require("codeium.config") local default_mod = 438 -- 666 +---@class codeium.io local M = {} local function check_job(job, status) diff --git a/lua/codeium/notify.lua b/lua/codeium/notify.lua index b22dda2..6eeb22f 100644 --- a/lua/codeium/notify.lua +++ b/lua/codeium/notify.lua @@ -1,5 +1,6 @@ local log = require("codeium.log") +---@class codeium.notify local M = {} local opts = { title = "Codeium", diff --git a/lua/codeium/source.lua b/lua/codeium/source.lua index 2c0eb4f..38ba142 100644 --- a/lua/codeium/source.lua +++ b/lua/codeium/source.lua @@ -72,33 +72,20 @@ local function codeium_to_cmp(comp, offset, right) } end -local function buf_to_codeium(bufnr) - local filetype = enums.filetype_aliases[vim.bo[bufnr].filetype] or vim.bo[bufnr].filetype or "text" - local language = enums.languages[filetype] or enums.languages.unspecified - local line_ending = util.get_newline(bufnr) - local lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, true) - table.insert(lines, "") - local text = table.concat(lines, line_ending) - return { - editor_language = filetype, - language = language, - text = text, - line_ending = line_ending, - absolute_path = vim.api.nvim_buf_get_name(bufnr), - } -end local function get_other_documents(bufnr) local other_documents = {} for _, buf in ipairs(vim.api.nvim_list_bufs()) do if vim.api.nvim_buf_is_loaded(buf) and vim.bo[buf].filetype ~= '' and buf ~= bufnr then - table.insert(other_documents, buf_to_codeium(buf)) + table.insert(other_documents, util.buf_to_codeium(buf)) end end return other_documents end +---@class codeium.Source +---@field server codeium.Server|nil local Source = { server = nil, } @@ -113,7 +100,7 @@ function Source:new(server) end function Source:is_available() - return self.server.is_healthy() + return self.server:is_healthy() end function Source:get_position_encoding_kind() @@ -130,7 +117,7 @@ require("cmp").event:on("confirm_done", function(event) and event.entry.source.source and event.entry.source.source.server then - event.entry.source.source.server.accept_completion(event.entry.completion_item.codeium_completion_id) + event.entry.source.source.server:accept_completion(event.entry.completion_item.codeium_completion_id) end end) @@ -180,7 +167,7 @@ function Source:complete(params, callback) local other_documents = get_other_documents(bufnr) - self.server.request_completion( + self.server:request_completion( { editor_language = filetype, language = language, diff --git a/lua/codeium/update.lua b/lua/codeium/update.lua index ea2e54b..e65ff67 100644 --- a/lua/codeium/update.lua +++ b/lua/codeium/update.lua @@ -2,6 +2,8 @@ local config = require("codeium.config") local versions = require("codeium.versions") local io = require("codeium.io") local notify = require("codeium.notify") + +---@class codeium.update local M = {} local cached = nil diff --git a/lua/codeium/util.lua b/lua/codeium/util.lua index e401799..57f7a53 100644 --- a/lua/codeium/util.lua +++ b/lua/codeium/util.lua @@ -1,4 +1,5 @@ local enums = require("codeium.enums") +---@class codeium.util local M = {} function M.fallback_call(calls, with_filter, fallback_value) @@ -34,4 +35,20 @@ function M.get_newline(bufnr) return enums.line_endings[vim.bo[bufnr].fileformat] or "\n" end +function M.buf_to_codeium(bufnr) + local filetype = enums.filetype_aliases[vim.bo[bufnr].filetype] or vim.bo[bufnr].filetype or "text" + local language = enums.languages[filetype] or enums.languages.unspecified + local line_ending = M.get_newline(bufnr) + local lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, true) + table.insert(lines, "") + local text = table.concat(lines, line_ending) + return { + editor_language = filetype, + language = language, + text = text, + line_ending = line_ending, + absolute_path = vim.api.nvim_buf_get_name(bufnr), + } +end + return M