diff --git a/lua/codeium/api.lua b/lua/codeium/api.lua index 233d519..b0e9d75 100644 --- a/lua/codeium/api.lua +++ b/lua/codeium/api.lua @@ -12,6 +12,8 @@ local status = { api_key_error = nil, } +local function noop(...) end + local function find_port(manager_dir, start_time) local files = io.readdir(manager_dir) @@ -41,7 +43,30 @@ local function get_request_metadata(request_id) } end -local Server = {} +--- +--- Codeium Server API +--- @class Server +--- @field port? number +--- @field job? plenary.job +--- @field current_cookie? number +--- @field workspaces table +--- @field healthy boolean +--- @field last_heartbeat? number +--- @field last_heartbeat_error? string +--- @field enabled boolean +--- @field pending_request table +local Server = { + port = nil, + job = nil, + current_cookie = nil, + workspaces = {}, + healthy = false, + last_heartbeat = nil, + last_heartbeat_error = nil, + enabled = true, + pending_request = { 0, noop }, +} + Server.__index = Server function Server.check_status() @@ -152,407 +177,394 @@ function Server.authenticate() prompt() end -function Server:new() - local m = {} - setmetatable(m, self) - - local o = {} - setmetatable(o, m) - - local port = nil - local job = nil - local current_cookie = nil - local workspaces = {} - local healthy = false - local last_heartbeat = nil - local last_heartbeat_error = nil - local enabled = true - - local function request(fn, payload, callback) - local url = "http://127.0.0.1:" .. port .. "/exa.language_server_pb.LanguageServerService/" .. fn - io.post(url, { - body = payload, - callback = callback, - }) - end +---@return Server +function Server.new() + local m = setmetatable({}, Server) + m.__index = m + return m +end - local function do_heartbeat() - request("Heartbeat", { - metadata = get_request_metadata(), - }, function(_, err) - last_heartbeat = os.time() - last_heartbeat_error = nil - if err then - notify.warn("heartbeat failed", err) - last_heartbeat_error = err - else - healthy = true - end - end) +---@param fn string +---@param payload table +---@param callback function +function Server:request(fn, payload, callback) + local url = "http://127.0.0.1:" .. self.port .. "/exa.language_server_pb.LanguageServerService/" .. fn + io.post(url, { + body = payload, + callback = callback, + }) +end + +function Server:start() + self:shutdown() + + self.current_cookie = next_cookie() + + if not api_key then + io.timer(1000, 0, self:start()) + return end - function m.is_healthy() - return healthy + local manager_dir = config.manager_path + if not manager_dir then + manager_dir = io.tempdir("codeium/manager") + vim.fn.mkdir(manager_dir, "p") end - function m.checkhealth(logger) - logger.info("Checking server status") - if m.is_healthy() then - logger.ok("Server is healthy on port: " .. port) - else - logger.warn("Server is unhealthy") + local start_time = io.touch(manager_dir .. "/start") + + local function on_exit(_, err) + if not self.current_cookie then + return end - logger.info("Language Server binary: " .. update.get_bin_info().bin) + self.healthy = false + if err then + self.job = nil + self.current_cookie = nil - if last_heartbeat == nil then - logger.warn("No heartbeat executed") - else - logger.info("Last heartbeat: " .. os.date("%D %H:%M:%S", last_heartbeat)) - if last_heartbeat_error ~= nil then - logger.error(last_heartbeat_error) - else - logger.ok("Heartbeat ok") - end + notify.error("codeium server crashed", err) + io.timer(1000, 0, function() + log.debug("restarting server after crash") + self:start() + end) end end - function m.start() - m.shutdown() + local function on_output(_, v, j) + log.debug(j.pid .. ": " .. v) + end - current_cookie = next_cookie() + local api_server_url = "https://" + .. config.options.api.host + .. ":" + .. config.options.api.port + .. (config.options.api.path and "/" .. config.options.api.path:gsub("^/", "") or "") + + local job_args = { + update.get_bin_info().bin, + "--api_server_url", + api_server_url, + "--manager_dir", + manager_dir, + "--file_watch_max_dir_count", + config.options.file_watch_max_dir_count, + enable_handlers = true, + enable_recording = false, + on_exit = on_exit, + on_stdout = on_output, + on_stderr = on_output, + } - if not api_key then - io.timer(1000, 0, m.start) - return - end + if config.options.enable_chat then + table.insert(job_args, "--enable_chat_web_server") + table.insert(job_args, "--enable_chat_client") + end - local manager_dir = config.manager_path - if not manager_dir then - manager_dir = io.tempdir("codeium/manager") - vim.fn.mkdir(manager_dir, "p") - end + if config.options.enable_local_search then + table.insert(job_args, "--enable_local_search") + end - local start_time = io.touch(manager_dir .. "/start") + if config.options.enable_index_service then + table.insert(job_args, "--enable_index_service") + table.insert(job_args, "--search_max_workspace_file_count") + table.insert(job_args, config.options.search_max_workspace_file_count) + end - local function on_exit(_, err) - if not current_cookie then - return - end + if config.options.api.portal_url then + table.insert(job_args, "--portal_url") + table.insert(job_args, "https://" .. config.options.api.portal_url) + end - healthy = false - if err then - job = nil - current_cookie = nil - - notify.error("codeium server crashed", err) - io.timer(1000, 0, function() - log.debug("restarting server after crash") - m.start() - end) - end - end + if config.options.enterprise_mode then + table.insert(job_args, "--enterprise_mode") + end - local function on_output(_, v, j) - log.debug(j.pid .. ": " .. v) - end + if config.options.detect_proxy ~= nil then + table.insert(job_args, "--detect_proxy=" .. tostring(config.options.detect_proxy)) + end - local api_server_url = "https://" - .. config.options.api.host - .. ":" - .. config.options.api.port - .. (config.options.api.path and "/" .. config.options.api.path:gsub("^/", "") or "") - - local job_args = { - update.get_bin_info().bin, - "--api_server_url", - api_server_url, - "--manager_dir", - manager_dir, - "--file_watch_max_dir_count", - config.options.file_watch_max_dir_count, - enable_handlers = true, - enable_recording = false, - on_exit = on_exit, - on_stdout = on_output, - on_stderr = on_output, - } - - if config.options.enable_chat then - table.insert(job_args, "--enable_chat_web_server") - table.insert(job_args, "--enable_chat_client") - end + self.job = io.job(job_args) + self.job:start() - if config.options.enable_local_search then - table.insert(job_args, "--enable_local_search") - end + local function start_heartbeat() + io.timer(100, 5000, function(cancel_heartbeat) + if not self.current_cookie then + cancel_heartbeat() + else + self:do_heartbeat() + end + end) + end - if config.options.enable_index_service then - table.insert(job_args, "--enable_index_service") - table.insert(job_args, "--search_max_workspace_file_count") - table.insert(job_args, config.options.search_max_workspace_file_count) + io.timer(100, 500, function(cancel) + if not self.current_cookie then + cancel() + return end - if config.options.api.portal_url then - table.insert(job_args, "--portal_url") - table.insert(job_args, "https://" .. config.options.api.portal_url) + self.port = find_port(manager_dir, start_time) + if self.port then + cancel() + start_heartbeat() end + end) +end - if config.options.enterprise_mode then - table.insert(job_args, "--enterprise_mode") - end +function Server:is_healthy() + return self.healthy +end - if config.options.detect_proxy ~= nil then - table.insert(job_args, "--detect_proxy=" .. tostring(config.options.detect_proxy)) - end +function Server:checkhealth(logger) + logger.info("Checking server status") + if self:is_healthy() then + logger.ok("Server is healthy on port: " .. self.port) + else + logger.warn("Server is unhealthy") + end - local job = io.job(job_args) - job:start() + logger.info("Language Server binary: " .. update.get_bin_info().bin) - local function start_heartbeat() - io.timer(100, 5000, function(cancel_heartbeat) - if not current_cookie then - cancel_heartbeat() - else - do_heartbeat() - end - end) + if self.last_heartbeat == nil then + logger.warn("No heartbeat executed") + else + logger.info("Last heartbeat: " .. os.date("%D %H:%M:%S", self.last_heartbeat)) + if self.last_heartbeat_error ~= nil then + logger.error(self.last_heartbeat_error) + else + logger.ok("Heartbeat ok") end + end +end - io.timer(100, 500, function(cancel) - if not current_cookie then - cancel() - return - end +function Server:do_heartbeat() + self:request("Heartbeat", { + metadata = get_request_metadata(), + }, function(_, err) + self.last_heartbeat = os.time() + self.last_heartbeat_error = nil + if err then + notify.warn("heartbeat failed", err) + self.last_heartbeat_error = err + else + self.healthy = true + end + end) +end - port = find_port(manager_dir, start_time) - if port then - cancel() - start_heartbeat() - end - end) +function Server:request_completion(document, editor_options, other_documents, callback) + if not self.enabled then + return end + self.pending_request[2](true) - local function noop(...) end - - local pending_request = { 0, noop } - function m.request_completion(document, editor_options, other_documents, callback) - if enabled == false then - return - end - pending_request[2](true) + local metadata = get_request_metadata() + local this_pending_request - local metadata = get_request_metadata() - local this_pending_request + local complete + complete = function(...) + complete = noop + this_pending_request(false) + callback(...) + end - local complete - complete = function(...) - complete = noop - this_pending_request(false) - callback(...) + this_pending_request = function(is_complete) + if self.pending_request[1] == metadata.request_id then + self.pending_request = { 0, noop } end + this_pending_request = noop - this_pending_request = function(is_complete) - if pending_request[1] == metadata.request_id then - pending_request = { 0, noop } + self:request("CancelRequest", { + metadata = get_request_metadata(), + request_id = metadata.request_id, + }, function(_, err) + if err then + log.warn("failed to cancel in-flight request", err) end - this_pending_request = noop - - request("CancelRequest", { - metadata = get_request_metadata(), - request_id = metadata.request_id, - }, function(_, err) - if err then - log.warn("failed to cancel in-flight request", err) - end - end) + end) - if is_complete then - complete(false, nil) - end + if is_complete then + complete(false, nil) end - pending_request = { metadata.request_id, this_pending_request } - - request("GetCompletions", { - metadata = metadata, - editor_options = editor_options, - document = document, - other_documents = other_documents, - }, function(body, err) - if err then - if err.status == 503 or err.status == 408 then - -- Service Unavailable or Timeout error - return complete(false, nil) - end + end + self.pending_request = { metadata.request_id, this_pending_request } + + self:request("GetCompletions", { + metadata = metadata, + editor_options = editor_options, + document = document, + other_documents = other_documents, + }, function(body, err) + if err then + if err.status == 503 or err.status == 408 then + -- Service Unavailable or Timeout error + return complete(false, nil) + end - local ok, json = pcall(vim.fn.json_decode, err.response.body) - if ok and json then - if json.state and json.state.state == "CODEIUM_STATE_INACTIVE" then - if json.state.message then - log.debug("completion request failed", json.state.message) - end - return complete(false, nil) - end - if json.code == "canceled" then - log.debug("completion request cancelled at the server", json.message) - return complete(false, nil) + local ok, json = pcall(vim.fn.json_decode, err.response.body) + if ok and json then + if json.state and json.state.state == "CODEIUM_STATE_INACTIVE" then + if json.state.message then + log.debug("completion request failed", json.state.message) end + return complete(false, nil) + end + if json.code == "canceled" then + log.debug("completion request cancelled at the server", json.message) + return complete(false, nil) end - - notify.error("completion request failed", err) - complete(false, nil) - return - end - - local ok, json = pcall(vim.fn.json_decode, body) - if not ok then - notify.error("completion request failed", "invalid JSON:", json) - return end - log.trace("completion: ", json) - complete(true, json) - end) + notify.error("completion request failed", err) + complete(false, nil) + return + end - return function() - this_pending_request(true) + local ok, json = pcall(vim.fn.json_decode, body) + if not ok then + notify.error("completion request failed", "invalid JSON:", json) + return end - end - function m.accept_completion(completion_id) - request("AcceptCompletion", { - metadata = get_request_metadata(), - completion_id = completion_id, - }, noop) + log.trace("completion: ", json) + complete(true, json) + end) + + return function() + this_pending_request(true) end +end - function m.refresh_context() - -- bufnr for current buffer is 0 - local bufnr = 0 +function Server:accept_completion(completion_id) + self:request("AcceptCompletion", { + metadata = get_request_metadata(), + completion_id = completion_id, + }, noop) +end - local line_ending = util.get_newline(bufnr) - local lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, true) +function Server:refresh_context() + -- bufnr for current buffer is 0 + local bufnr = 0 - -- Ensure that there is always a newline at the end of the file - table.insert(lines, "") - local text = table.concat(lines, line_ending) + local line_ending = util.get_newline(bufnr) + local lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, true) - local filetype = vim.bo.filetype - local language = enums.languages[filetype] or enums.languages.unspecified + -- Ensure that there is always a newline at the end of the file + table.insert(lines, "") + local text = table.concat(lines, line_ending) - local doc = { - editor_language = filetype, - language = language, - cursor_offset = 0, - text = text, - line_ending = line_ending, - absolute_uri = util.get_uri(vim.api.nvim_buf_get_name(bufnr)), - workspace_uri = util.get_uri(util.get_project_root()), - } + local filetype = vim.bo.filetype + local language = enums.languages[filetype] or enums.languages.unspecified - request("RefreshContextForIdeAction", { - active_document = doc, - }, function(_, err) - if err then - notify.error("failed refresh context: " .. err.out) - return - end - end) - end + local doc = { + editor_language = filetype, + language = language, + cursor_offset = 0, + text = text, + line_ending = line_ending, + absolute_uri = util.get_uri(vim.api.nvim_buf_get_name(bufnr)), + workspace_uri = util.get_uri(util.get_project_root()), + } - function m.add_workspace() - local project_root = util.get_project_root() - -- workspace already tracked by server - if workspaces[project_root] then + self:request("RefreshContextForIdeAction", { + active_document = doc, + }, function(_, err) + if err then + notify.error("failed refresh context: " .. err.out) return end - -- unable to track hidden path - for entry in project_root:gmatch("[^/]+") do - if entry:sub(1, 1) == "." then - return - end - end - - request("AddTrackedWorkspace", { workspace = project_root }, function(_, err) - if err then - notify.error("failed to add workspace: " .. err.out) - return - end - workspaces[project_root] = true - end) - end + end) +end - function m.get_chat_ports() - request("GetProcesses", { - metadata = get_request_metadata(), - }, function(body, err) - if err then - notify.error("failed to get chat ports", err) - return - end - local ports = vim.fn.json_decode(body) - local url = "http://127.0.0.1:" - .. ports.chatClientPort - .. "?api_key=" - .. api_key - .. "&has_enterprise_extension=" - .. (config.options.enterprise_mode and "true" or "false") - .. "&web_server_url=ws://127.0.0.1:" - .. ports.chatWebServerPort - .. "&ide_name=neovim" - .. "&ide_version=" - .. versions.nvim - .. "&app_name=codeium.nvim" - .. "&extension_name=codeium.nvim" - .. "&extension_version=" - .. versions.extension - .. "&ide_telemetry_enabled=true" - .. "&has_index_service=" - .. (config.options.enable_index_service and "true" or "false") - .. "&locale=en_US" - - -- cross-platform solution to open the web app - local os_info = io.get_system_info() - if os_info.os == "linux" then - os.execute("xdg-open '" .. url .. "'") - elseif os_info.os == "macos" then - os.execute("open '" .. url .. "'") - elseif os_info.os == "windows" then - os.execute(string.format('start "" "%s"', url)) - else - notify.error("Unsupported operating system") - end - end) +function Server:add_workspace() + local project_root = util.get_project_root() + -- workspace already tracked by server + if self.workspaces[project_root] then + return end - - function m.shutdown() - current_cookie = nil - if job then - job.on_exit = nil - job:shutdown() + -- unable to track hidden path + for entry in project_root:gmatch("[^/]+") do + if entry:sub(1, 1) == "." then + return end end - function m.enable() - enabled = true - notify.info("Codeium enabled") - end - - function m.disable() - enabled = false - notify.info("Codeium disabled") - end + self:request("AddTrackedWorkspace", { workspace = project_root }, function(_, err) + if err then + notify.error("failed to add workspace: " .. err.out) + return + end + self.workspaces[project_root] = true + end) +end - function m.toggle() - if enabled then - m.disable() +function Server:get_chat_ports() + self:request("GetProcesses", { + metadata = get_request_metadata(), + }, function(body, err) + if err then + notify.error("failed to get chat ports", err) + return + end + local ports = vim.fn.json_decode(body) + local url = "http://127.0.0.1:" + .. ports.chatClientPort + .. "?api_key=" + .. api_key + .. "&has_enterprise_extension=" + .. (config.options.enterprise_mode and "true" or "false") + .. "&web_server_url=ws://127.0.0.1:" + .. ports.chatWebServerPort + .. "&ide_name=neovim" + .. "&ide_version=" + .. versions.nvim + .. "&app_name=codeium.nvim" + .. "&extension_name=codeium.nvim" + .. "&extension_version=" + .. versions.extension + .. "&ide_telemetry_enabled=true" + .. "&has_index_service=" + .. (config.options.enable_index_service and "true" or "false") + .. "&locale=en_US" + + -- cross-platform solution to open the web app + local os_info = io.get_system_info() + if os_info.os == "linux" then + os.execute("xdg-open '" .. url .. "'") + elseif os_info.os == "macos" then + os.execute("open '" .. url .. "'") + elseif os_info.os == "windows" then + os.execute(string.format('start "" "%s"', url)) else - m.enable() + notify.error("Unsupported operating system") end + end) +end + +function Server:shutdown() + self.current_cookie = nil + if self.job then + self.job.on_exit = nil + self.job:shutdown() end +end - m.__index = m - return o +function Server:enable() + self.enabled = true + notify.info("Codeium enabled") +end + +function Server:disable() + self.enabled = false + notify.info("Codeium disabled") +end + +function Server:toggle() + if self.enabled then + self:disable() + else + self:enable() + end end return Server diff --git a/lua/codeium/health.lua b/lua/codeium/health.lua index 2025cb4..711f925 100644 --- a/lua/codeium/health.lua +++ b/lua/codeium/health.lua @@ -13,7 +13,7 @@ local error = vim.health.error or vim.health.report_error local info = vim.health.info or vim.health.report_info local health_logger = { ok = ok, info = info, warn = warn, error = error } -local checkhealth = nil +local instance = nil function M.check() start("Codeium: checking Codeium server status") @@ -24,15 +24,16 @@ function M.check() ok("API key properly loaded") end - if checkhealth == nil then + if instance == nil then warn("Codeium: checkhealth is not set") return end - checkhealth(health_logger) + instance:checkhealth(health_logger) end -function M.register(callback) - checkhealth = callback +---@param server Server +function M.register(server) + instance = server end return M diff --git a/lua/codeium/init.lua b/lua/codeium/init.lua index b2f5ab0..775c7dc 100644 --- a/lua/codeium/init.lua +++ b/lua/codeium/init.lua @@ -7,14 +7,14 @@ function M.setup(options) local health = require("codeium.health") require("codeium.config").setup(options) - M.s = Server:new() + M.s = Server.new() update.download(function(err) if not err then Server.load_api_key() - M.s.start() + M.s:start() end end) - health.register(M.s.checkhealth) + health.register(M.s) vim.api.nvim_create_user_command("Codeium", function(opts) local args = opts.fargs @@ -22,12 +22,10 @@ function M.setup(options) Server.authenticate() end if args[1] == "Chat" then - M.s.refresh_context() - M.s.get_chat_ports() - M.s.add_workspace() + M.chat() end if args[1] == "Toggle" then - M.s.toggle() + M.toggle() end end, { nargs = 1, @@ -40,32 +38,32 @@ function M.setup(options) end, }) - local source = Source:new(s) + local source = Source:new(M.s) if require("codeium.config").options.enable_cmp_source then require("cmp").register_source("codeium", source) end - require("codeium.virtual_text").setup(s) + require("codeium.virtual_text").setup(M.s) end --- Open Codeium Chat function M.chat() - M.s.refresh_context() - M.s.get_chat_ports() - M.s.add_workspace() + M.s:refresh_context() + M.s:get_chat_ports() + M.s:add_workspace() end --- Toggle the Codeium plugin function M.toggle() - M.s.toggle() + M.s:toggle() end function M.enable() - M.s.enable() + M.s:enable() end function M.disable() - M.s.disable() + M.s:disable() end return M diff --git a/lua/codeium/source.lua b/lua/codeium/source.lua index dcbee60..9b42d4e 100644 --- a/lua/codeium/source.lua +++ b/lua/codeium/source.lua @@ -86,7 +86,7 @@ function Source:new(server) end function Source:is_available() - return self.server.is_healthy() + return self.server:is_healthy() end function Source:get_position_encoding_kind() @@ -106,7 +106,7 @@ if imported_cmp then and event.entry.source.source and event.entry.source.source.server then - event.entry.source.source.server.accept_completion(event.entry.completion_item.codeium_completion_id) + event.entry.source.source.server:accept_completion(event.entry.completion_item.codeium_completion_id) end end) end @@ -157,7 +157,7 @@ function Source:complete(params, callback) local other_documents = util.get_other_documents(bufnr) - self.server.request_completion( + self.server:request_completion( { text = text, editor_language = filetype,