From d17bc1f87736b6a7f058b2f246e651d34d648b47 Mon Sep 17 00:00:00 2001 From: uga-rosa Date: Mon, 8 Jan 2024 20:34:49 +0900 Subject: [PATCH] Use trie tree --- lua/cmp_dictionary/caches.lua | 118 ++++++++++++-------------------- lua/cmp_dictionary/db.lua | 25 +++---- lua/cmp_dictionary/lib/trie.lua | 66 ++++++++++++++++++ lua/cmp_dictionary/util.lua | 19 +---- 4 files changed, 122 insertions(+), 106 deletions(-) create mode 100644 lua/cmp_dictionary/lib/trie.lua diff --git a/lua/cmp_dictionary/caches.lua b/lua/cmp_dictionary/caches.lua index b8e7547..3bdcce9 100644 --- a/lua/cmp_dictionary/caches.lua +++ b/lua/cmp_dictionary/caches.lua @@ -1,14 +1,13 @@ local util = require("cmp_dictionary.util") local lfu = require("cmp_dictionary.lfu") local config = require("cmp_dictionary.config") -local utf8 = require("cmp_dictionary.lib.utf8") -local Async = require("cmp_dictionary.kit.Async") -local Worker = require("cmp_dictionary.kit.Thread.Worker") +local Trie = require("cmp_dictionary.lib.trie") ---@class DictionaryData ----@field items lsp.CompletionItem[] +---@field trie Trie ---@field mtime integer ---@field path string +---@field detail string local Caches = { ---@type DictionaryData[] @@ -23,71 +22,55 @@ local dictCache = lfu.init(config.get("capacity")) local function need_to_load() local dictionaries = util.get_dictionaries() local updated_or_new = {} - for _, dict in ipairs(dictionaries) do - local path = vim.fn.expand(dict) - if util.bool_fn.filereadable(path) then - local mtime = vim.fn.getftime(path) - local cache = dictCache:get(path) - if cache and cache.mtime == mtime then - table.insert(Caches.valid, cache) - else - table.insert(updated_or_new, { path = path, mtime = mtime }) - end + for _, path in ipairs(dictionaries) do + local mtime = vim.fn.getftime(path) + local cache = dictCache:get(path) + if cache and cache.mtime == mtime then + table.insert(Caches.valid, cache) + else + table.insert(updated_or_new, { path = path, mtime = mtime }) end end return updated_or_new end ----Create dictionary data from buffers ---@param path string ----@param name string ----@return lsp.CompletionItem[] items -local read_items = Worker.new(function(path, name) - local buffer = require("cmp_dictionary.util").read_file_sync(path) - - local items = {} - local detail = ("belong to `%s`"):format(name) +---@param mtime integer +local function cache_update(path, mtime) + local buffer = util.read_file_sync(path) + local trie = Trie.new() for w in vim.gsplit(buffer, "%s+") do if w ~= "" then - table.insert(items, { label = w, detail = detail }) + trie:insert(w) end end - table.sort(items, function(item1, item2) - return item1.label < item2.label - end) - - return items -end) ----@param path string ----@param mtime integer ----@return cmp_dictionary.kit.Async.AsyncTask -local function cache_update(path, mtime) local name = vim.fn.fnamemodify(path, ":t") - return read_items(path, name):next(function(items) - local cache = { - items = items, - mtime = mtime, - path = path, - } - dictCache:set(path, cache) - table.insert(Caches.valid, cache) - end) + local cache = { + trie = trie, + mtime = mtime, + path = path, + detail = ("belong to `%s`"):format(name), + } + + dictCache:set(path, cache) + table.insert(Caches.valid, cache) end +local update_on_going = false local function update() - local buftype = vim.api.nvim_buf_get_option(0, "buftype") - if buftype ~= "" then + local buftype = vim.api.nvim_get_option_value("buftype", { buf = 0 }) + if buftype ~= "" or update_on_going then return end + update_on_going = true Caches.valid = {} - - Async.all(vim.tbl_map(function(n) - return cache_update(n.path, n.mtime) - end, need_to_load())):next(function() - just_updated = true - end) + for _, n in ipairs(need_to_load()) do + cache_update(n.path, n.mtime) + end + just_updated = true + update_on_going = false end function Caches.update() @@ -102,37 +85,20 @@ function Caches.request(req, isIncomplete) local items = {} isIncomplete = isIncomplete or false - local ok, offset, codepoint - ok, offset = pcall(utf8.offset, req, -1) - if not ok then - return items, isIncomplete - end - ok, codepoint = pcall(utf8.codepoint, req, offset) - if not ok then - return items, isIncomplete - end - - local req_next = req:sub(1, offset - 1) .. utf8.char(codepoint + 1) - - local max_items = config.get("max_items") + local max_items = config.get("max_items") --[[@as integer]] for _, cache in pairs(Caches.valid) do - local start = util.binary_search(cache.items, req, function(vector, index, key) - return vector[index].label >= key - end) - local last = util.binary_search(cache.items, req_next, function(vector, index, key) - return vector[index].label >= key - end) - 1 - if start > 0 and last > 0 and start <= last then - if max_items > 0 and last >= start + max_items then - last = start + max_items + local words = cache.trie:search(req, max_items) + for i = 1, #words do + if max_items >= 0 and #items >= max_items then isIncomplete = true + goto done end - for i = start, last do - local item = cache.items[i] - table.insert(items, item) - end + local item = { label = words[i], detail = cache.detail } + table.insert(items, item) end end + ::done:: + return items, isIncomplete end diff --git a/lua/cmp_dictionary/db.lua b/lua/cmp_dictionary/db.lua index 3321d65..9840020 100644 --- a/lua/cmp_dictionary/db.lua +++ b/lua/cmp_dictionary/db.lua @@ -69,19 +69,16 @@ end local function need_to_load(db) local dictionaries = util.get_dictionaries() local updated_or_new = {} - for _, dictionary in ipairs(dictionaries) do - local path = vim.fn.expand(dictionary) - if util.bool_fn.filereadable(path) then - local mtime = vim.fn.getftime(path) - local mtime_cache = db:select("dictionary", { select = "mtime", where = { filepath = path } }) - if mtime_cache[1] and mtime_cache[1].mtime == mtime then - db:update("dictionary", { - set = { valid = 1 }, - where = { filepath = path }, - }) - else - table.insert(updated_or_new, { path = path, mtime = mtime }) - end + for _, path in ipairs(dictionaries) do + local mtime = vim.fn.getftime(path) + local mtime_cache = db:select("dictionary", { select = "mtime", where = { filepath = path } }) + if mtime_cache[1] and mtime_cache[1].mtime == mtime then + db:update("dictionary", { + set = { valid = 1 }, + where = { filepath = path }, + }) + else + table.insert(updated_or_new, { path = path, mtime = mtime }) end end return updated_or_new @@ -101,7 +98,7 @@ local read_items = Worker.new(function(path, name) end) local function update(db) - local buftype = vim.api.nvim_buf_get_option(0, "buftype") + local buftype = vim.api.nvim_get_option_value("buftype", { buf = 0 }) if buftype ~= "" then return end diff --git a/lua/cmp_dictionary/lib/trie.lua b/lua/cmp_dictionary/lib/trie.lua new file mode 100644 index 0000000..067be16 --- /dev/null +++ b/lua/cmp_dictionary/lib/trie.lua @@ -0,0 +1,66 @@ +---@class TrieNode +---@field children table +---@field end_of_word boolean +local TrieNode = {} + +---@return TrieNode +function TrieNode.new() + return { children = {}, end_of_word = false } +end + +---@class Trie +---@field root TrieNode +local Trie = {} + +---@return Trie +function Trie.new() + return setmetatable({ + root = TrieNode.new(), + }, { __index = Trie }) +end + +---@param word string +function Trie:insert(word) + local current = self.root + for char in vim.gsplit(word, "") do + local node = current.children[char] or TrieNode.new() + current.children[char] = node + current = node + end + current.end_of_word = true +end + +---@private +---@param node TrieNode +---@param prefix string +---@param word_list string[] +---@param limit integer +function Trie:search_prefix(node, prefix, word_list, limit) + if limit >= 0 and #word_list >= limit then + return + end + if node.end_of_word then + table.insert(word_list, prefix) + end + for char, child in pairs(node.children) do + self:search_prefix(child, prefix .. char, word_list, limit) + end +end + +---@param prefix string +---@param limit integer +---@return string[] +function Trie:search(prefix, limit) + local node = self.root + for char in vim.gsplit(prefix, "") do + node = node.children[char] + if node == nil then + return {} + end + end + local word_list = {} + self:search_prefix(node, prefix, word_list, limit) + return word_list +end + +return Trie diff --git a/lua/cmp_dictionary/util.lua b/lua/cmp_dictionary/util.lua index 91a21dd..99ea228 100644 --- a/lua/cmp_dictionary/util.lua +++ b/lua/cmp_dictionary/util.lua @@ -39,8 +39,9 @@ function M.get_dictionaries() local dict = {} for _, al in ipairs({ global, local_ }) do for _, d in ipairs(al) do - if vim.fn.filereadable(vim.fn.expand(d)) == 1 then - table.insert(dict, d) + local path = vim.fn.expand(d) --[[@as string]] + if vim.fn.filereadable(path) == 1 then + table.insert(dict, path) end end end @@ -101,18 +102,4 @@ function M.debounce(name, callback, timeout) ) end -M.bool_fn = setmetatable({}, { - __index = function(_, key) - return function(...) - local v = vim.fn[key](...) - if not v or v == 0 or v == "" then - return false - elseif type(v) == "table" and next(v) == nil then - return false - end - return true - end - end, -}) - return M