Skip to content

Commit

Permalink
Use trie tree
Browse files Browse the repository at this point in the history
  • Loading branch information
uga-rosa committed Jan 8, 2024
1 parent 363ce91 commit d17bc1f
Show file tree
Hide file tree
Showing 4 changed files with 122 additions and 106 deletions.
118 changes: 42 additions & 76 deletions lua/cmp_dictionary/caches.lua
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
local util = require("cmp_dictionary.util")
local lfu = require("cmp_dictionary.lfu")
local config = require("cmp_dictionary.config")
local utf8 = require("cmp_dictionary.lib.utf8")
local Async = require("cmp_dictionary.kit.Async")
local Worker = require("cmp_dictionary.kit.Thread.Worker")
local Trie = require("cmp_dictionary.lib.trie")

---@class DictionaryData
---@field items lsp.CompletionItem[]
---@field trie Trie
---@field mtime integer
---@field path string
---@field detail string

local Caches = {
---@type DictionaryData[]
Expand All @@ -23,71 +22,55 @@ local dictCache = lfu.init(config.get("capacity"))
local function need_to_load()
local dictionaries = util.get_dictionaries()
local updated_or_new = {}
for _, dict in ipairs(dictionaries) do
local path = vim.fn.expand(dict)
if util.bool_fn.filereadable(path) then
local mtime = vim.fn.getftime(path)
local cache = dictCache:get(path)
if cache and cache.mtime == mtime then
table.insert(Caches.valid, cache)
else
table.insert(updated_or_new, { path = path, mtime = mtime })
end
for _, path in ipairs(dictionaries) do
local mtime = vim.fn.getftime(path)
local cache = dictCache:get(path)
if cache and cache.mtime == mtime then
table.insert(Caches.valid, cache)
else
table.insert(updated_or_new, { path = path, mtime = mtime })
end
end
return updated_or_new
end

---Create dictionary data from buffers
---@param path string
---@param name string
---@return lsp.CompletionItem[] items
local read_items = Worker.new(function(path, name)
local buffer = require("cmp_dictionary.util").read_file_sync(path)

local items = {}
local detail = ("belong to `%s`"):format(name)
---@param mtime integer
local function cache_update(path, mtime)
local buffer = util.read_file_sync(path)
local trie = Trie.new()
for w in vim.gsplit(buffer, "%s+") do
if w ~= "" then
table.insert(items, { label = w, detail = detail })
trie:insert(w)
end
end
table.sort(items, function(item1, item2)
return item1.label < item2.label
end)

return items
end)

---@param path string
---@param mtime integer
---@return cmp_dictionary.kit.Async.AsyncTask
local function cache_update(path, mtime)
local name = vim.fn.fnamemodify(path, ":t")
return read_items(path, name):next(function(items)
local cache = {
items = items,
mtime = mtime,
path = path,
}
dictCache:set(path, cache)
table.insert(Caches.valid, cache)
end)
local cache = {
trie = trie,
mtime = mtime,
path = path,
detail = ("belong to `%s`"):format(name),
}

dictCache:set(path, cache)
table.insert(Caches.valid, cache)
end

local update_on_going = false
local function update()
local buftype = vim.api.nvim_buf_get_option(0, "buftype")
if buftype ~= "" then
local buftype = vim.api.nvim_get_option_value("buftype", { buf = 0 })
if buftype ~= "" or update_on_going then
return
end
update_on_going = true

Caches.valid = {}

Async.all(vim.tbl_map(function(n)
return cache_update(n.path, n.mtime)
end, need_to_load())):next(function()
just_updated = true
end)
for _, n in ipairs(need_to_load()) do
cache_update(n.path, n.mtime)
end
just_updated = true
update_on_going = false
end

function Caches.update()
Expand All @@ -102,37 +85,20 @@ function Caches.request(req, isIncomplete)
local items = {}
isIncomplete = isIncomplete or false

local ok, offset, codepoint
ok, offset = pcall(utf8.offset, req, -1)
if not ok then
return items, isIncomplete
end
ok, codepoint = pcall(utf8.codepoint, req, offset)
if not ok then
return items, isIncomplete
end

local req_next = req:sub(1, offset - 1) .. utf8.char(codepoint + 1)

local max_items = config.get("max_items")
local max_items = config.get("max_items") --[[@as integer]]
for _, cache in pairs(Caches.valid) do
local start = util.binary_search(cache.items, req, function(vector, index, key)
return vector[index].label >= key
end)
local last = util.binary_search(cache.items, req_next, function(vector, index, key)
return vector[index].label >= key
end) - 1
if start > 0 and last > 0 and start <= last then
if max_items > 0 and last >= start + max_items then
last = start + max_items
local words = cache.trie:search(req, max_items)
for i = 1, #words do
if max_items >= 0 and #items >= max_items then
isIncomplete = true
goto done
end
for i = start, last do
local item = cache.items[i]
table.insert(items, item)
end
local item = { label = words[i], detail = cache.detail }
table.insert(items, item)
end
end
::done::

return items, isIncomplete
end

Expand Down
25 changes: 11 additions & 14 deletions lua/cmp_dictionary/db.lua
Original file line number Diff line number Diff line change
Expand Up @@ -69,19 +69,16 @@ end
local function need_to_load(db)
local dictionaries = util.get_dictionaries()
local updated_or_new = {}
for _, dictionary in ipairs(dictionaries) do
local path = vim.fn.expand(dictionary)
if util.bool_fn.filereadable(path) then
local mtime = vim.fn.getftime(path)
local mtime_cache = db:select("dictionary", { select = "mtime", where = { filepath = path } })
if mtime_cache[1] and mtime_cache[1].mtime == mtime then
db:update("dictionary", {
set = { valid = 1 },
where = { filepath = path },
})
else
table.insert(updated_or_new, { path = path, mtime = mtime })
end
for _, path in ipairs(dictionaries) do
local mtime = vim.fn.getftime(path)
local mtime_cache = db:select("dictionary", { select = "mtime", where = { filepath = path } })
if mtime_cache[1] and mtime_cache[1].mtime == mtime then
db:update("dictionary", {
set = { valid = 1 },
where = { filepath = path },
})
else
table.insert(updated_or_new, { path = path, mtime = mtime })
end
end
return updated_or_new
Expand All @@ -101,7 +98,7 @@ local read_items = Worker.new(function(path, name)
end)

local function update(db)
local buftype = vim.api.nvim_buf_get_option(0, "buftype")
local buftype = vim.api.nvim_get_option_value("buftype", { buf = 0 })
if buftype ~= "" then
return
end
Expand Down
66 changes: 66 additions & 0 deletions lua/cmp_dictionary/lib/trie.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
---@class TrieNode
---@field children table<string, TrieNode>
---@field end_of_word boolean
local TrieNode = {}

---@return TrieNode
function TrieNode.new()
return { children = {}, end_of_word = false }
end

---@class Trie
---@field root TrieNode
local Trie = {}

---@return Trie
function Trie.new()
return setmetatable({
root = TrieNode.new(),
}, { __index = Trie })
end

---@param word string
function Trie:insert(word)
local current = self.root
for char in vim.gsplit(word, "") do
local node = current.children[char] or TrieNode.new()
current.children[char] = node
current = node
end
current.end_of_word = true
end

---@private
---@param node TrieNode
---@param prefix string
---@param word_list string[]
---@param limit integer
function Trie:search_prefix(node, prefix, word_list, limit)
if limit >= 0 and #word_list >= limit then
return
end
if node.end_of_word then
table.insert(word_list, prefix)
end
for char, child in pairs(node.children) do
self:search_prefix(child, prefix .. char, word_list, limit)
end
end

---@param prefix string
---@param limit integer
---@return string[]
function Trie:search(prefix, limit)
local node = self.root
for char in vim.gsplit(prefix, "") do
node = node.children[char]
if node == nil then
return {}
end
end
local word_list = {}
self:search_prefix(node, prefix, word_list, limit)
return word_list
end

return Trie
19 changes: 3 additions & 16 deletions lua/cmp_dictionary/util.lua
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,9 @@ function M.get_dictionaries()
local dict = {}
for _, al in ipairs({ global, local_ }) do
for _, d in ipairs(al) do
if vim.fn.filereadable(vim.fn.expand(d)) == 1 then
table.insert(dict, d)
local path = vim.fn.expand(d) --[[@as string]]
if vim.fn.filereadable(path) == 1 then
table.insert(dict, path)
end
end
end
Expand Down Expand Up @@ -101,18 +102,4 @@ function M.debounce(name, callback, timeout)
)
end

M.bool_fn = setmetatable({}, {
__index = function(_, key)
return function(...)
local v = vim.fn[key](...)
if not v or v == 0 or v == "" then
return false
elseif type(v) == "table" and next(v) == nil then
return false
end
return true
end
end,
})

return M

0 comments on commit d17bc1f

Please sign in to comment.