diff options
Diffstat (limited to 'start/cmp/lua/cmp/utils')
24 files changed, 2233 insertions, 0 deletions
diff --git a/start/cmp/lua/cmp/utils/api.lua b/start/cmp/lua/cmp/utils/api.lua new file mode 100644 index 0000000..d053409 --- /dev/null +++ b/start/cmp/lua/cmp/utils/api.lua @@ -0,0 +1,69 @@ +local api = {} + +local CTRL_V = vim.api.nvim_replace_termcodes('<C-v>', true, true, true) +local CTRL_S = vim.api.nvim_replace_termcodes('<C-s>', true, true, true) + +api.get_mode = function() + local mode = vim.api.nvim_get_mode().mode:sub(1, 1) + if mode == 'i' then + return 'i' -- insert + elseif mode == 'v' or mode == 'V' or mode == CTRL_V then + return 'x' -- visual + elseif mode == 's' or mode == 'S' or mode == CTRL_S then + return 's' -- select + elseif mode == 'c' and vim.fn.getcmdtype() ~= '=' then + return 'c' -- cmdline + end +end + +api.is_insert_mode = function() + return api.get_mode() == 'i' +end + +api.is_cmdline_mode = function() + return api.get_mode() == 'c' +end + +api.is_select_mode = function() + return api.get_mode() == 's' +end + +api.is_visual_mode = function() + return api.get_mode() == 'x' +end + +api.is_suitable_mode = function() + local mode = api.get_mode() + return mode == 'i' or mode == 'c' +end + +api.get_current_line = function() + if api.is_cmdline_mode() then + return vim.fn.getcmdline() + end + return vim.api.nvim_get_current_line() +end + +api.get_cursor = function() + if api.is_cmdline_mode() then + return { vim.o.lines - (vim.api.nvim_get_option('cmdheight') or 1) + 1, vim.fn.getcmdpos() - 1 } + end + return vim.api.nvim_win_get_cursor(0) +end + +api.get_screen_cursor = function() + if api.is_cmdline_mode() then + local cursor = api.get_cursor() + return { cursor[1], cursor[2] + 1 } + end + local cursor = api.get_cursor() + local pos = vim.fn.screenpos(0, cursor[1], cursor[2] + 1) + return { pos.row, pos.col - 1 } +end + +api.get_cursor_before_line = function() + local cursor = api.get_cursor() + return string.sub(api.get_current_line(), 1, cursor[2]) +end + +return api diff --git a/start/cmp/lua/cmp/utils/api_spec.lua b/start/cmp/lua/cmp/utils/api_spec.lua new file mode 100644 index 0000000..5363b48 --- /dev/null +++ b/start/cmp/lua/cmp/utils/api_spec.lua @@ -0,0 +1,46 @@ +local spec = require('cmp.utils.spec') +local keymap = require('cmp.utils.keymap') +local feedkeys = require('cmp.utils.feedkeys') +local api = require('cmp.utils.api') + +describe('api', function() + describe('get_cursor', function() + before_each(spec.before) + it('insert-mode', function() + local cursor + feedkeys.call(keymap.t('i\t1234567890'), 'nx', function() + cursor = api.get_cursor() + end) + assert.are.equal(cursor[2], 11) + end) + it('cmdline-mode', function() + local cursor + keymap.set_map(0, 'c', '<Plug>(cmp-spec-spy)', function() + cursor = api.get_cursor() + end, { expr = true, noremap = true }) + feedkeys.call(keymap.t(':\t1234567890'), 'n') + feedkeys.call(keymap.t('<Plug>(cmp-spec-spy)'), 'x') + assert.are.equal(cursor[2], 11) + end) + end) + + describe('get_cursor_before_line', function() + before_each(spec.before) + it('insert-mode', function() + local cursor_before_line + feedkeys.call(keymap.t('i\t1234567890<Left><Left>'), 'nx', function() + cursor_before_line = api.get_cursor_before_line() + end) + assert.are.same(cursor_before_line, '\t12345678') + end) + it('cmdline-mode', function() + local cursor_before_line + keymap.set_map(0, 'c', '<Plug>(cmp-spec-spy)', function() + cursor_before_line = api.get_cursor_before_line() + end, { expr = true, noremap = true }) + feedkeys.call(keymap.t(':\t1234567890<Left><Left>'), 'n') + feedkeys.call(keymap.t('<Plug>(cmp-spec-spy)'), 'x') + assert.are.same(cursor_before_line, '\t12345678') + end) + end) +end) diff --git a/start/cmp/lua/cmp/utils/async.lua b/start/cmp/lua/cmp/utils/async.lua new file mode 100644 index 0000000..13f126b --- /dev/null +++ b/start/cmp/lua/cmp/utils/async.lua @@ -0,0 +1,127 @@ +local async = {} + +---@class cmp.AsyncThrottle +---@field public running boolean +---@field public timeout number +---@field public sync function(self: cmp.AsyncThrottle, timeout: number|nil) +---@field public stop function +---@field public __call function + +---@param fn function +---@param timeout number +---@return cmp.AsyncThrottle +async.throttle = function(fn, timeout) + local time = nil + local timer = vim.loop.new_timer() + return setmetatable({ + running = false, + timeout = timeout, + sync = function(self, timeout_) + vim.wait(timeout_ or 1000, function() + return not self.running + end) + end, + stop = function() + time = nil + timer:stop() + end, + }, { + __call = function(self, ...) + local args = { ... } + + if time == nil then + time = vim.loop.now() + end + + self.running = true + timer:stop() + timer:start(math.max(1, self.timeout - (vim.loop.now() - time)), 0, function() + vim.schedule(function() + time = nil + fn(unpack(args)) + self.running = false + end) + end) + end, + }) +end + +---Control async tasks. +async.step = function(...) + local tasks = { ... } + local next + next = function(...) + if #tasks > 0 then + table.remove(tasks, 1)(next, ...) + end + end + table.remove(tasks, 1)(next) +end + +---Timeout callback function +---@param fn function +---@param timeout number +---@return function +async.timeout = function(fn, timeout) + local timer + local done = false + local callback = function(...) + if not done then + done = true + timer:stop() + timer:close() + fn(...) + end + end + timer = vim.loop.new_timer() + timer:start(timeout, 0, function() + callback() + end) + return callback +end + +---@alias cmp.AsyncDedup fun(callback: function): function + +---Create deduplicated callback +---@return function +async.dedup = function() + local id = 0 + return function(callback) + id = id + 1 + + local current = id + return function(...) + if current == id then + callback(...) + end + end + end +end + +---Convert async process as sync +async.sync = function(runner, timeout) + local done = false + runner(function() + done = true + end) + vim.wait(timeout, function() + return done + end, 10, false) +end + +---Wait and callback for next safe state. +async.debounce_next_tick = function(callback) + local running = false + return function() + if running then + return + end + running = true + vim.schedule(function() + running = false + callback() + end) + end +end + +return async diff --git a/start/cmp/lua/cmp/utils/async_spec.lua b/start/cmp/lua/cmp/utils/async_spec.lua new file mode 100644 index 0000000..62f5379 --- /dev/null +++ b/start/cmp/lua/cmp/utils/async_spec.lua @@ -0,0 +1,69 @@ +local async = require('cmp.utils.async') + +describe('utils.async', function() + it('throttle', function() + local count = 0 + local now + local f = async.throttle(function() + count = count + 1 + end, 100) + + -- 1. delay for 100ms + now = vim.loop.now() + f.timeout = 100 + f() + vim.wait(1000, function() + return count == 1 + end) + assert.is.truthy(math.abs(f.timeout - (vim.loop.now() - now)) < 10) + + -- 2. delay for 500ms + now = vim.loop.now() + f.timeout = 500 + f() + vim.wait(1000, function() + return count == 2 + end) + assert.is.truthy(math.abs(f.timeout - (vim.loop.now() - now)) < 10) + + -- 4. delay for 500ms and wait 100ms (remain 400ms) + f.timeout = 500 + f() + vim.wait(100) -- remain 400ms + + -- 5. call immediately (100ms already elapsed from No.4) + now = vim.loop.now() + f.timeout = 100 + f() + vim.wait(1000, function() + return count == 3 + end) + assert.is.truthy(math.abs(vim.loop.now() - now) < 10) + end) + it('step', function() + local done = false + local step = {} + async.step(function(next) + vim.defer_fn(function() + table.insert(step, 1) + next() + end, 10) + end, function(next) + vim.defer_fn(function() + table.insert(step, 2) + next() + end, 10) + end, function(next) + vim.defer_fn(function() + table.insert(step, 3) + next() + end, 10) + end, function() + done = true + end) + vim.wait(1000, function() + return done + end) + assert.are.same(step, { 1, 2, 3 }) + end) +end) diff --git a/start/cmp/lua/cmp/utils/autocmd.lua b/start/cmp/lua/cmp/utils/autocmd.lua new file mode 100644 index 0000000..438e231 --- /dev/null +++ b/start/cmp/lua/cmp/utils/autocmd.lua @@ -0,0 +1,53 @@ +local debug = require('cmp.utils.debug') + +local autocmd = {} + +autocmd.group = vim.api.nvim_create_augroup('___cmp___', { clear = true }) + +autocmd.events = {} + +---Subscribe autocmd +---@param events string|string[] +---@param callback function +---@return function +autocmd.subscribe = function(events, callback) + events = type(events) == 'string' and { events } or events + + for _, event in ipairs(events) do + if not autocmd.events[event] then + autocmd.events[event] = {} + vim.api.nvim_create_autocmd(event, { + desc = ('nvim-cmp: autocmd: %s'):format(event), + group = autocmd.group, + callback = function() + autocmd.emit(event) + end, + }) + end + table.insert(autocmd.events[event], callback) + end + + return function() + for _, event in ipairs(events) do + for i, callback_ in ipairs(autocmd.events[event]) do + if callback_ == callback then + table.remove(autocmd.events[event], i) + break + end + end + end + end +end + +---Emit autocmd +---@param event string +autocmd.emit = function(event) + debug.log(' ') + debug.log(string.format('>>> %s', event)) + autocmd.events[event] = autocmd.events[event] or {} + for _, callback in ipairs(autocmd.events[event]) do + callback() + end +end + +return autocmd diff --git a/start/cmp/lua/cmp/utils/binary.lua b/start/cmp/lua/cmp/utils/binary.lua new file mode 100644 index 0000000..c6a7088 --- /dev/null +++ b/start/cmp/lua/cmp/utils/binary.lua @@ -0,0 +1,33 @@ +local binary = {} + +---Insert item to list to ordered index +---@param list any[] +---@param item any +---@param func fun(a: any, b: any): 1|-1|0 +binary.insort = function(list, item, func) + table.insert(list, binary.search(list, item, func), item) +end + +---Search suitable index from list +---@param list any[] +---@param item any +---@param func fun(a: any, b: any): 1|-1|0 +---@return number +binary.search = function(list, item, func) + local s = 1 + local e = #list + while s <= e do + local idx = math.floor((e + s) / 2) + local diff = func(item, list[idx]) + if diff > 0 then + s = idx + 1 + elseif diff < 0 then + e = idx - 1 + else + return idx + 1 + end + end + return s +end + +return binary diff --git a/start/cmp/lua/cmp/utils/binary_spec.lua b/start/cmp/lua/cmp/utils/binary_spec.lua new file mode 100644 index 0000000..92fe129 --- /dev/null +++ b/start/cmp/lua/cmp/utils/binary_spec.lua @@ -0,0 +1,28 @@ +local binary = require('cmp.utils.binary') + +describe('utils.binary', function() + it('insort', function() + local func = function(a, b) + return a.score - b.score + end + local list = {} + binary.insort(list, { id = 'a', score = 1 }, func) + binary.insort(list, { id = 'b', score = 5 }, func) + binary.insort(list, { id = 'c', score = 2.5 }, func) + binary.insort(list, { id = 'd', score = 2 }, func) + binary.insort(list, { id = 'e', score = 8 }, func) + binary.insort(list, { id = 'g', score = 8 }, func) + binary.insort(list, { id = 'h', score = 7 }, func) + binary.insort(list, { id = 'i', score = 6 }, func) + binary.insort(list, { id = 'j', score = 4 }, func) + assert.are.equal(list[1].id, 'a') + assert.are.equal(list[2].id, 'd') + assert.are.equal(list[3].id, 'c') + assert.are.equal(list[4].id, 'j') + assert.are.equal(list[5].id, 'b') + assert.are.equal(list[6].id, 'i') + assert.are.equal(list[7].id, 'h') + assert.are.equal(list[8].id, 'e') + assert.are.equal(list[9].id, 'g') + end) +end) diff --git a/start/cmp/lua/cmp/utils/buffer.lua b/start/cmp/lua/cmp/utils/buffer.lua new file mode 100644 index 0000000..63171c9 --- /dev/null +++ b/start/cmp/lua/cmp/utils/buffer.lua @@ -0,0 +1,28 @@ +local buffer = {} + +buffer.cache = {} + +---@return number buf +buffer.get = function(name) + local buf = buffer.cache[name] + if buf and vim.api.nvim_buf_is_valid(buf) then + return buf + else + return nil + end +end + +---@return number buf +---@return boolean created_new +buffer.ensure = function(name) + local created_new = false + local buf = buffer.get(name) + if not buf then + created_new = true + buf = vim.api.nvim_create_buf(false, true) + buffer.cache[name] = buf + end + return buf, created_new +end + +return buffer diff --git a/start/cmp/lua/cmp/utils/cache.lua b/start/cmp/lua/cmp/utils/cache.lua new file mode 100644 index 0000000..8607b2a --- /dev/null +++ b/start/cmp/lua/cmp/utils/cache.lua @@ -0,0 +1,58 @@ +---@class cmp.Cache +---@field public entries any +local cache = {} + +cache.new = function() + local self = setmetatable({}, { __index = cache }) + self.entries = {} + return self +end + +---Get cache value +---@param key string +---@return any|nil +cache.get = function(self, key) + key = self:key(key) + if self.entries[key] ~= nil then + return self.entries[key] + end + return nil +end + +---Set cache value explicitly +---@param key string +---@vararg any +cache.set = function(self, key, value) + key = self:key(key) + self.entries[key] = value +end + +---Ensure value by callback +---@param key string +---@param callback fun(): any +cache.ensure = function(self, key, callback) + local value = self:get(key) + if value == nil then + local v = callback() + self:set(key, v) + return v + end + return value +end + +---Clear all cache entries +cache.clear = function(self) + self.entries = {} +end + +---Create key +---@param key string|table +---@return string +cache.key = function(_, key) + if type(key) == 'table' then + return table.concat(key, ':') + end + return key +end + +return cache diff --git a/start/cmp/lua/cmp/utils/char.lua b/start/cmp/lua/cmp/utils/char.lua new file mode 100644 index 0000000..6e18994 --- /dev/null +++ b/start/cmp/lua/cmp/utils/char.lua @@ -0,0 +1,117 @@ +local _ + +local alpha = {} +_ = string.gsub('abcdefghijklmnopqrstuvwxyz', '.', function(char) + alpha[string.byte(char)] = true +end) + +local ALPHA = {} +_ = string.gsub('ABCDEFGHIJKLMNOPQRSTUVWXYZ', '.', function(char) + ALPHA[string.byte(char)] = true +end) + +local digit = {} +_ = string.gsub('1234567890', '.', function(char) + digit[string.byte(char)] = true +end) + +local white = {} +_ = string.gsub(' \t\n', '.', function(char) + white[string.byte(char)] = true +end) + +local char = {} + +---@param byte number +---@return boolean +char.is_upper = function(byte) + return ALPHA[byte] +end + +---@param byte number +---@return boolean +char.is_alpha = function(byte) + return alpha[byte] or ALPHA[byte] +end + +---@param byte number +---@return boolean +char.is_digit = function(byte) + return digit[byte] +end + +---@param byte number +---@return boolean +char.is_white = function(byte) + return white[byte] +end + +---@param byte number +---@return boolean +char.is_symbol = function(byte) + return not (char.is_alnum(byte) or char.is_white(byte)) +end + +---@param byte number +---@return boolean +char.is_printable = function(byte) + return string.match(string.char(byte), '^%c$') == nil +end + +---@param byte number +---@return boolean +char.is_alnum = function(byte) + return char.is_alpha(byte) or char.is_digit(byte) +end + +---@param text string +---@param index number +---@return boolean +char.is_semantic_index = function(text, index) + if index <= 1 then + return true + end + + local prev = string.byte(text, index - 1) + local curr = string.byte(text, index) + + if not char.is_upper(prev) and char.is_upper(curr) then + return true + end + if char.is_symbol(curr) or char.is_white(curr) then + return true + end + if not char.is_alpha(prev) and char.is_alpha(curr) then + return true + end + if not char.is_digit(prev) and char.is_digit(curr) then + return true + end + return false +end + +---@param text string +---@param current_index number +---@return boolean +char.get_next_semantic_index = function(text, current_index) + for i = current_index + 1, #text do + if char.is_semantic_index(text, i) then + return i + end + end + return #text + 1 +end + +---Ignore case match +---@param byte1 number +---@param byte2 number +---@return boolean +char.match = function(byte1, byte2) + if not char.is_alpha(byte1) or not char.is_alpha(byte2) then + return byte1 == byte2 + end + local diff = byte1 - byte2 + return diff == 0 or diff == 32 or diff == -32 +end + +return char diff --git a/start/cmp/lua/cmp/utils/debug.lua b/start/cmp/lua/cmp/utils/debug.lua new file mode 100644 index 0000000..c8b0dba --- /dev/null +++ b/start/cmp/lua/cmp/utils/debug.lua @@ -0,0 +1,20 @@ +local debug = {} + +debug.flag = false + +---Print log +---@vararg any +debug.log = function(...) + if debug.flag then + local data = {} + for _, v in ipairs({ ... }) do + if not vim.tbl_contains({ 'string', 'number', 'boolean' }, type(v)) then + v = vim.inspect(v) + end + table.insert(data, v) + end + print(table.concat(data, '\t')) + end +end + +return debug diff --git a/start/cmp/lua/cmp/utils/event.lua b/start/cmp/lua/cmp/utils/event.lua new file mode 100644 index 0000000..662d573 --- /dev/null +++ b/start/cmp/lua/cmp/utils/event.lua @@ -0,0 +1,51 @@ +---@class cmp.Event +---@field private events table<string, function[]> +local event = {} + +---Create vents +event.new = function() + local self = setmetatable({}, { __index = event }) + self.events = {} + return self +end + +---Add event listener +---@param name string +---@param callback function +---@return function +event.on = function(self, name, callback) + if not self.events[name] then + self.events[name] = {} + end + table.insert(self.events[name], callback) + return function() + self:off(name, callback) + end +end + +---Remove event listener +---@param name string +---@param callback function +event.off = function(self, name, callback) + for i, callback_ in ipairs(self.events[name] or {}) do + if callback_ == callback then + table.remove(self.events[name], i) + break + end + end +end + +---Remove all events +event.clear = function(self) + self.events = {} +end + +---Emit event +---@param name string +event.emit = function(self, name, ...) + for _, callback in ipairs(self.events[name] or {}) do + callback(...) + end +end + +return event diff --git a/start/cmp/lua/cmp/utils/feedkeys.lua b/start/cmp/lua/cmp/utils/feedkeys.lua new file mode 100644 index 0000000..cd20f60 --- /dev/null +++ b/start/cmp/lua/cmp/utils/feedkeys.lua @@ -0,0 +1,53 @@ +local keymap = require('cmp.utils.keymap') +local misc = require('cmp.utils.misc') + +local feedkeys = {} + +feedkeys.call = setmetatable({ + callbacks = {}, +}, { + __call = function(self, keys, mode, callback) + local is_insert = string.match(mode, 'i') ~= nil + local is_immediate = string.match(mode, 'x') ~= nil + + local queue = {} + if #keys > 0 then + table.insert(queue, { keymap.t('<Cmd>setlocal lazyredraw<CR>'), 'n' }) + table.insert(queue, { keymap.t('<Cmd>setlocal textwidth=0<CR>'), 'n' }) + table.insert(queue, { keymap.t('<Cmd>setlocal backspace=2<CR>'), 'n' }) + table.insert(queue, { keys, string.gsub(mode, '[itx]', ''), true }) + table.insert(queue, { keymap.t('<Cmd>setlocal %slazyredraw<CR>'):format(vim.o.lazyredraw and '' or 'no'), 'n' }) + table.insert(queue, { keymap.t('<Cmd>setlocal textwidth=%s<CR>'):format(vim.bo.textwidth or 0), 'n' }) + table.insert(queue, { keymap.t('<Cmd>setlocal backspace=%s<CR>'):format(vim.go.backspace or 2), 'n' }) + end + + if callback then + local id = misc.id('cmp.utils.feedkeys.call') + self.callbacks[id] = callback + table.insert(queue, { keymap.t('<Cmd>call v:lua.cmp.utils.feedkeys.call.run(%s)<CR>'):format(id), 'n', true }) + end + + if is_insert then + for i = #queue, 1, -1 do + vim.api.nvim_feedkeys(queue[i][1], queue[i][2] .. 'i', queue[i][3]) + end + else + for i = 1, #queue do + vim.api.nvim_feedkeys(queue[i][1], queue[i][2], queue[i][3]) + end + end + + if is_immediate then + vim.api.nvim_feedkeys('', 'x', true) + end + end, +}) +misc.set(_G, { 'cmp', 'utils', 'feedkeys', 'call', 'run' }, function(id) + if feedkeys.call.callbacks[id] then + feedkeys.call.callbacks[id]() + feedkeys.call.callbacks[id] = nil + end + return '' +end) + +return feedkeys diff --git a/start/cmp/lua/cmp/utils/feedkeys_spec.lua b/start/cmp/lua/cmp/utils/feedkeys_spec.lua new file mode 100644 index 0000000..24fba71 --- /dev/null +++ b/start/cmp/lua/cmp/utils/feedkeys_spec.lua @@ -0,0 +1,56 @@ +local spec = require('cmp.utils.spec') +local keymap = require('cmp.utils.keymap') + +local feedkeys = require('cmp.utils.feedkeys') + +describe('feedkeys', function() + before_each(spec.before) + + it('dot-repeat', function() + local reg + feedkeys.call(keymap.t('iaiueo<Esc>'), 'nx', function() + reg = vim.fn.getreg('.') + end) + assert.are.equal(reg, keymap.t('aiueo')) + end) + + it('textwidth', function() + vim.cmd([[setlocal textwidth=6]]) + feedkeys.call(keymap.t('iaiueo '), 'nx') + feedkeys.call(keymap.t('aaiueoaiueo'), 'nx') + assert.are.same(vim.api.nvim_buf_get_lines(0, 0, -1, false), { + 'aiueo aiueoaiueo', + }) + end) + + it('bacckspace', function() + vim.cmd([[setlocal backspace=0]]) + feedkeys.call(keymap.t('iaiueo'), 'nx') + feedkeys.call(keymap.t('a<BS><BS>'), 'nx') + assert.are.same(vim.api.nvim_buf_get_lines(0, 0, -1, false), { + 'aiu', + }) + end) + + it('testability', function() + feedkeys.call('i', 'n', function() + feedkeys.call('', 'n', function() + feedkeys.call('aiueo', 'in') + end) + feedkeys.call('', 'n', function() + feedkeys.call(keymap.t('<BS><BS><BS><BS><BS>'), 'in') + end) + feedkeys.call('', 'n', function() + feedkeys.call(keymap.t('abcde'), 'in') + end) + feedkeys.call('', 'n', function() + feedkeys.call(keymap.t('<BS><BS><BS><BS><BS>'), 'in') + end) + feedkeys.call('', 'n', function() + feedkeys.call(keymap.t('12345'), 'in') + end) + end) + feedkeys.call('', 'x') + assert.are.same(vim.api.nvim_buf_get_lines(0, 0, -1, false), { '12345' }) + end) +end) diff --git a/start/cmp/lua/cmp/utils/highlight.lua b/start/cmp/lua/cmp/utils/highlight.lua new file mode 100644 index 0000000..867632a --- /dev/null +++ b/start/cmp/lua/cmp/utils/highlight.lua @@ -0,0 +1,31 @@ +local highlight = {} + +highlight.keys = { + 'fg', + 'bg', + 'bold', + 'italic', + 'reverse', + 'standout', + 'underline', + 'undercurl', + 'strikethrough', +} + +highlight.inherit = function(name, source, settings) + for _, key in ipairs(highlight.keys) do + if not settings[key] then + local v = vim.fn.synIDattr(vim.fn.hlID(source), key) + if key == 'fg' or key == 'bg' then + local n = tonumber(v, 10) + v = type(n) == 'number' and n or v + else + v = v == 1 + end + settings[key] = v == '' and 'NONE' or v + end + end + vim.api.nvim_set_hl(0, name, settings) +end + +return highlight diff --git a/start/cmp/lua/cmp/utils/keymap.lua b/start/cmp/lua/cmp/utils/keymap.lua new file mode 100644 index 0000000..aea5c1d --- /dev/null +++ b/start/cmp/lua/cmp/utils/keymap.lua @@ -0,0 +1,251 @@ +local misc = require('cmp.utils.misc') +local buffer = require('cmp.utils.buffer') +local api = require('cmp.utils.api') + +local keymap = {} + +---Shortcut for nvim_replace_termcodes +---@param keys string +---@return string +keymap.t = function(keys) + return (string.gsub(keys, '(<[A-Za-z0-9\\%-%[%]%^@]->)', function(match) + return vim.api.nvim_eval(string.format([["\%s"]], match)) + end)) +end + +---Normalize key sequence. +---@param keys string +---@return string +keymap.normalize = function(keys) + local normalize_buf = buffer.ensure('cmp.util.keymap.normalize') + vim.api.nvim_buf_set_keymap(normalize_buf, 't', keys, '<Plug>(cmp.utils.keymap.normalize)', {}) + for _, map in ipairs(vim.api.nvim_buf_get_keymap(normalize_buf, 't')) do + if keymap.equals(map.rhs, '<Plug>(cmp.utils.keymap.normalize)') then + vim.api.nvim_buf_del_keymap(normalize_buf, 't', keys) + return map.lhs + end + end + vim.api.nvim_buf_del_keymap(normalize_buf, 't', keys) + return keys +end + +---Return vim notation keymapping (simple conversion). +---@param s string +---@return string +keymap.to_keymap = setmetatable({ + ['<CR>'] = { '\n', '\r', '\r\n' }, + ['<Tab>'] = { '\t' }, + ['<BSlash>'] = { '\\' }, + ['<Bar>'] = { '|' }, + ['<Space>'] = { ' ' }, +}, { + __call = function(self, s) + return string.gsub(s, '.', function(c) + for key, chars in pairs(self) do + if vim.tbl_contains(chars, c) then + return key + end + end + return c + end) + end, +}) + +---Mode safe break undo +keymap.undobreak = function() + if not api.is_insert_mode() then + return '' + end + return keymap.t('<C-g>u') +end + +---Mode safe join undo +keymap.undojoin = function() + if not api.is_insert_mode() then + return '' + end + return keymap.t('<C-g>U') +end + +---Create backspace keys. +---@param count number +---@return string +keymap.backspace = function(count) + if type(count) == 'string' then + count = vim.fn.strchars(count, true) + end + if count <= 0 then + return '' + end + local keys = {} + table.insert(keys, keymap.t(string.rep('<BS>', count))) + return table.concat(keys, '') +end + +---Update indentkeys. +---@param expr string +---@return string +keymap.indentkeys = function(expr) + return string.format(keymap.t('<Cmd>set indentkeys=%s<CR>'), expr and vim.fn.escape(expr, '| \t\\') or '') +end + +---Return two key sequence are equal or not. +---@param a string +---@param b string +---@return boolean +keymap.equals = function(a, b) + return keymap.t(a) == keymap.t(b) +end + +---Register keypress handler. +keymap.listen = function(mode, lhs, callback) + lhs = keymap.normalize(keymap.to_keymap(lhs)) + + local existing = keymap.get_map(mode, lhs) + local id = string.match(existing.rhs, 'v:lua%.cmp%.utils%.keymap%.set_map%((%d+)%)') + if id and keymap.set_map.callbacks[tonumber(id, 10)] then + return + end + + local bufnr = existing.buffer and vim.api.nvim_get_current_buf() or -1 + local fallback = keymap.fallback(bufnr, mode, existing) + keymap.set_map(bufnr, mode, lhs, function() + local ignore = false + ignore = ignore or (mode == 'c' and vim.fn.getcmdtype() == '=') + if ignore then + fallback() + else + callback(lhs, misc.once(fallback)) + end + end, { + expr = false, + noremap = true, + silent = true, + }) +end + +---Fallback +keymap.fallback = function(bufnr, mode, map) + return function() + if map.expr then + local fallback_expr = string.format('<Plug>(cmp.u.k.fallback_expr:%s)', map.lhs) + keymap.set_map(bufnr, mode, fallback_expr, function() + return keymap.solve(bufnr, mode, map).keys + end, { + expr = true, + noremap = map.noremap, + script = map.script, + nowait = map.nowait, + silent = map.silent and mode ~= 'c', + }) + vim.api.nvim_feedkeys(keymap.t(fallback_expr), 'im', true) + elseif not map.callback then + local solved = keymap.solve(bufnr, mode, map) + vim.api.nvim_feedkeys(solved.keys, solved.mode, true) + else + map.callback() + end + end +end + +---Solve +keymap.solve = function(bufnr, mode, map) + local lhs = keymap.t(map.lhs) + local rhs = map.expr and (map.callback and map.callback() or vim.api.nvim_eval(keymap.t(map.rhs))) or keymap.t(map.rhs) + + if map.noremap then + return { keys = rhs, mode = 'in' } + end + + if string.find(rhs, lhs, 1, true) == 1 then + local recursive = string.format('<SNR>0_(cmp.u.k.recursive:%s)', lhs) + keymap.set_map(bufnr, mode, recursive, lhs, { + noremap = true, + script = map.script, + nowait = map.nowait, + silent = map.silent and mode ~= 'c', + }) + return { keys = keymap.t(recursive) .. string.gsub(rhs, '^' .. vim.pesc(lhs), ''), mode = 'im' } + end + return { keys = rhs, mode = 'im' } +end + +---Get map +---@param mode string +---@param lhs string +---@return table +keymap.get_map = function(mode, lhs) + lhs = keymap.normalize(lhs) + + for _, map in ipairs(vim.api.nvim_buf_get_keymap(0, mode)) do + if keymap.equals(map.lhs, lhs) then + return { + lhs = map.lhs, + rhs = map.rhs or '', + expr = map.expr == 1, + callback = map.callback, + noremap = map.noremap == 1, + script = map.script == 1, + silent = map.silent == 1, + nowait = map.nowait == 1, + buffer = true, + } + end + end + + for _, map in ipairs(vim.api.nvim_get_keymap(mode)) do + if keymap.equals(map.lhs, lhs) then + return { + lhs = map.lhs, + rhs = map.rhs or '', + expr = map.expr == 1, + callback = map.callback, + noremap = map.noremap == 1, + script = map.script == 1, + silent = map.silent == 1, + nowait = map.nowait == 1, + buffer = false, + } + end + end + + return { + lhs = lhs, + rhs = lhs, + expr = false, + callback = nil, + noremap = true, + script = false, + silent = true, + nowait = false, + buffer = false, + } +end + +---Set keymapping +keymap.set_map = setmetatable({ + callbacks = {}, +}, { + __call = function(self, bufnr, mode, lhs, rhs, opts) + if type(rhs) == 'function' then + local id = misc.id('cmp.utils.keymap.set_map') + self.callbacks[id] = rhs + if opts.expr then + rhs = ('v:lua.cmp.utils.keymap.set_map(%s)'):format(id) + else + rhs = ('<Cmd>call v:lua.cmp.utils.keymap.set_map(%s)<CR>'):format(id) + end + end + + if bufnr == -1 then + vim.api.nvim_set_keymap(mode, lhs, rhs, opts) + else + vim.api.nvim_buf_set_keymap(bufnr, mode, lhs, rhs, opts) + end + end, +}) +misc.set(_G, { 'cmp', 'utils', 'keymap', 'set_map' }, function(id) + return keymap.set_map.callbacks[id]() or '' +end) + +return keymap diff --git a/start/cmp/lua/cmp/utils/keymap_spec.lua b/start/cmp/lua/cmp/utils/keymap_spec.lua new file mode 100644 index 0000000..959783f --- /dev/null +++ b/start/cmp/lua/cmp/utils/keymap_spec.lua @@ -0,0 +1,187 @@ +local spec = require('cmp.utils.spec') +local api = require('cmp.utils.api') +local feedkeys = require('cmp.utils.feedkeys') + +local keymap = require('cmp.utils.keymap') + +describe('keymap', function() + before_each(spec.before) + + it('t', function() + for _, key in ipairs({ + '<F1>', + '<C-a>', + '<C-]>', + '<C-[>', + '<C-^>', + '<C-@>', + '<C-\\>', + '<Tab>', + '<S-Tab>', + '<Plug>(example)', + '<C-r>="abc"<CR>', + '<Cmd>normal! ==<CR>', + }) do + assert.are.equal(keymap.t(key), vim.api.nvim_replace_termcodes(key, true, true, true)) + assert.are.equal(keymap.t(key .. key), vim.api.nvim_replace_termcodes(key .. key, true, true, true)) + assert.are.equal(keymap.t(key .. key .. key), vim.api.nvim_replace_termcodes(key .. key .. key, true, true, true)) + end + end) + + it('to_keymap', function() + assert.are.equal(keymap.to_keymap('\n'), '<CR>') + assert.are.equal(keymap.to_keymap('<CR>'), '<CR>') + assert.are.equal(keymap.to_keymap('|'), '<Bar>') + end) + + describe('fallback', function() + before_each(spec.before) + + local run_fallback = function(keys, fallback) + local state = {} + feedkeys.call(keys, '', function() + fallback() + end) + feedkeys.call('', '', function() + if api.is_cmdline_mode() then + state.buffer = { api.get_current_line() } + else + state.buffer = vim.api.nvim_buf_get_lines(0, 0, -1, false) + end + state.cursor = api.get_cursor() + end) + feedkeys.call('', 'x') + return state + end + + describe('basic', function() + it('<Plug>', function() + vim.api.nvim_buf_set_keymap(0, 'i', '<Plug>(pairs)', '()<Left>', { noremap = true }) + vim.api.nvim_buf_set_keymap(0, 'i', '(', '<Plug>(pairs)', { noremap = false }) + local fallback = keymap.fallback(0, 'i', keymap.get_map('i', '(')) + local state = run_fallback('i', fallback) + assert.are.same({ '()' }, state.buffer) + assert.are.same({ 1, 1 }, state.cursor) + end) + + it('<C-r>=', function() + vim.api.nvim_buf_set_keymap(0, 'i', '(', '<C-r>="()"<CR><Left>', {}) + local fallback = keymap.fallback(0, 'i', keymap.get_map('i', '(')) + local state = run_fallback('i', fallback) + assert.are.same({ '()' }, state.buffer) + assert.are.same({ 1, 1 }, state.cursor) + end) + + it('callback', function() + vim.api.nvim_buf_set_keymap(0, 'i', '(', '', { + callback = function() + vim.api.nvim_feedkeys('()' .. keymap.t('<Left>'), 'int', true) + end, + }) + local fallback = keymap.fallback(0, 'i', keymap.get_map('i', '(')) + local state = run_fallback('i', fallback) + assert.are.same({ '()' }, state.buffer) + assert.are.same({ 1, 1 }, state.cursor) + end) + + it('expr-callback', function() + vim.api.nvim_buf_set_keymap(0, 'i', '(', '', { + expr = true, + noremap = false, + silent = true, + callback = function() + return '()' .. keymap.t('<Left>') + end, + }) + local fallback = keymap.fallback(0, 'i', keymap.get_map('i', '(')) + local state = run_fallback('i', fallback) + assert.are.same({ '()' }, state.buffer) + assert.are.same({ 1, 1 }, state.cursor) + end) + + -- it('cmdline default <Tab>', function() + -- local fallback = keymap.fallback(0, 'c', keymap.get_map('c', '<Tab>')) + -- local state = run_fallback(':', fallback) + -- assert.are.same({ '' }, state.buffer) + -- assert.are.same({ 1, 0 }, state.cursor) + -- end) + end) + + describe('recursive', function() + it('non-expr', function() + vim.api.nvim_buf_set_keymap(0, 'i', '(', '()<Left>', { + expr = false, + noremap = false, + silent = true, + }) + local fallback = keymap.fallback(0, 'i', keymap.get_map('i', '(')) + local state = run_fallback('i', fallback) + assert.are.same({ '()' }, state.buffer) + assert.are.same({ 1, 1 }, state.cursor) + end) + + it('expr', function() + vim.api.nvim_buf_set_keymap(0, 'i', '(', '"()<Left>"', { + expr = true, + noremap = false, + silent = true, + }) + local fallback = keymap.fallback(0, 'i', keymap.get_map('i', '(')) + local state = run_fallback('i', fallback) + assert.are.same({ '()' }, state.buffer) + assert.are.same({ 1, 1 }, state.cursor) + end) + + it('expr-callback', function() + pcall(function() + vim.api.nvim_buf_set_keymap(0, 'i', '(', '', { + expr = true, + noremap = false, + silent = true, + callback = function() + return keymap.t('()<Left>') + end, + }) + local fallback = keymap.fallback(0, 'i', keymap.get_map('i', '(')) + local state = run_fallback('i', fallback) + assert.are.same({ '()' }, state.buffer) + assert.are.same({ 1, 1 }, state.cursor) + end) + end) + end) + end) + + describe('realworld', function() + before_each(spec.before) + + it('#226', function() + keymap.listen('i', '<c-n>', function(_, fallback) + fallback() + end) + vim.api.nvim_feedkeys(keymap.t('iaiueo<CR>a<C-n><C-n>'), 'tx', true) + assert.are.same({ 'aiueo', 'aiueo' }, vim.api.nvim_buf_get_lines(0, 0, -1, true)) + end) + + it('#414', function() + keymap.listen('i', '<M-j>', function() + vim.api.nvim_feedkeys(keymap.t('<C-n>'), 'int', true) + end) + vim.api.nvim_feedkeys(keymap.t('iaiueo<CR>a<M-j><M-j>'), 'tx', true) + assert.are.same({ 'aiueo', 'aiueo' }, vim.api.nvim_buf_get_lines(0, 0, -1, true)) + end) + + it('#744', function() + vim.api.nvim_buf_set_keymap(0, 'i', '<C-r>', 'recursive', { + noremap = true, + }) + vim.api.nvim_buf_set_keymap(0, 'i', '<CR>', '<CR>recursive', { + noremap = false, + }) + keymap.listen('i', '<CR>', function(_, fallback) + fallback() + end) + feedkeys.call(keymap.t('i<CR>'), 'tx') + assert.are.same({ '', 'recursive' }, vim.api.nvim_buf_get_lines(0, 0, -1, true)) + end) + end) +end) diff --git a/start/cmp/lua/cmp/utils/misc.lua b/start/cmp/lua/cmp/utils/misc.lua new file mode 100644 index 0000000..7c6d0e7 --- /dev/null +++ b/start/cmp/lua/cmp/utils/misc.lua @@ -0,0 +1,253 @@ +local misc = {} + +---Create once callback +---@param callback function +---@return function +misc.once = function(callback) + local done = false + return function(...) + if done then + return + end + done = true + callback(...) + end +end + +---Return concatenated list +---@param list1 any[] +---@param list2 any[] +---@return any[] +misc.concat = function(list1, list2) + local new_list = {} + for _, v in ipairs(list1) do + table.insert(new_list, v) + end + for _, v in ipairs(list2) do + table.insert(new_list, v) + end + return new_list +end + +---Repeat values +---@generic T +---@param str_or_tbl T +---@param count number +---@return T +misc.rep = function(str_or_tbl, count) + if type(str_or_tbl) == 'string' then + return string.rep(str_or_tbl, count) + end + local rep = {} + for _ = 1, count do + for _, v in ipairs(str_or_tbl) do + table.insert(rep, v) + end + end + return rep +end + +---Return the valu is empty or not. +---@param v any +---@return boolean +misc.empty = function(v) + if not v then + return true + end + if v == vim.NIL then + return true + end + if type(v) == 'string' and v == '' then + return true + end + if type(v) == 'table' and vim.tbl_isempty(v) then + return true + end + if type(v) == 'number' and v == 0 then + return true + end + return false +end + +---The symbol to remove key in misc.merge. +misc.none = vim.NIL + +---Merge two tables recursively +---@generic T +---@param v1 T +---@param v2 T +---@return T +misc.merge = function(v1, v2) + local merge1 = type(v1) == 'table' and (not vim.tbl_islist(v1) or vim.tbl_isempty(v1)) + local merge2 = type(v2) == 'table' and (not vim.tbl_islist(v2) or vim.tbl_isempty(v2)) + if merge1 and merge2 then + local new_tbl = {} + for k, v in pairs(v2) do + new_tbl[k] = misc.merge(v1[k], v) + end + for k, v in pairs(v1) do + if v2[k] == nil and v ~= misc.none then + new_tbl[k] = v + end + end + return new_tbl + end + if v1 == misc.none then + return nil + end + if v1 == nil then + if v2 == misc.none then + return nil + else + return v2 + end + end + if v1 == true then + if merge2 then + return v2 + end + return {} + end + + return v1 +end + +---Generate id for group name +misc.id = setmetatable({ + group = {}, +}, { + __call = function(_, group) + misc.id.group[group] = misc.id.group[group] or 0 + misc.id.group[group] = misc.id.group[group] + 1 + return misc.id.group[group] + end, +}) + +---Check the value is nil or not. +---@param v boolean +---@return boolean +misc.safe = function(v) + if v == nil or v == vim.NIL then + return nil + end + return v +end + +---Treat 1/0 as bool value +---@param v boolean|1|0 +---@param def boolean +---@return boolean +misc.bool = function(v, def) + if misc.safe(v) == nil then + return def + end + return v == true or v == 1 +end + +---Set value to deep object +---@param t table +---@param keys string[] +---@param v any +misc.set = function(t, keys, v) + local c = t + for i = 1, #keys - 1 do + local key = keys[i] + c[key] = misc.safe(c[key]) or {} + c = c[key] + end + c[keys[#keys]] = v +end + +---Copy table +---@generic T +---@param tbl T +---@return T +misc.copy = function(tbl) + if type(tbl) ~= 'table' then + return tbl + end + + if vim.tbl_islist(tbl) then + local copy = {} + for i, value in ipairs(tbl) do + copy[i] = misc.copy(value) + end + return copy + end + + local copy = {} + for key, value in pairs(tbl) do + copy[key] = misc.copy(value) + end + return copy +end + +---Safe version of vim.str_utfindex +---@param text string +---@param vimindex number|nil +---@return number +misc.to_utfindex = function(text, vimindex) + vimindex = vimindex or #text + 1 + return vim.str_utfindex(text, math.max(0, math.min(vimindex - 1, #text))) +end + +---Safe version of vim.str_byteindex +---@param text string +---@param utfindex number +---@return number +misc.to_vimindex = function(text, utfindex) + utfindex = utfindex or #text + for i = utfindex, 1, -1 do + local s, v = pcall(function() + return vim.str_byteindex(text, i) + 1 + end) + if s then + return v + end + end + return utfindex + 1 +end + +---Mark the function as deprecated +misc.deprecated = function(fn, msg) + local printed = false + return function(...) + if not printed then + print(msg) + printed = true + end + return fn(...) + end +end + +--Redraw +misc.redraw = setmetatable({ + doing = false, + force = false, + termcode = vim.api.nvim_replace_termcodes('<C-r><Esc>', true, true, true), +}, { + __call = function(self, force) + if vim.tbl_contains({ '/', '?' }, vim.fn.getcmdtype()) then + if vim.o.incsearch then + return vim.api.nvim_feedkeys(self.termcode, 'in', true) + end + end + + if self.doing then + return + end + self.doing = true + self.force = not not force + vim.schedule(function() + if self.force then + vim.cmd([[redraw!]]) + else + vim.cmd([[redraw]]) + end + self.doing = false + self.force = false + end) + end, +}) + +return misc diff --git a/start/cmp/lua/cmp/utils/misc_spec.lua b/start/cmp/lua/cmp/utils/misc_spec.lua new file mode 100644 index 0000000..f687155 --- /dev/null +++ b/start/cmp/lua/cmp/utils/misc_spec.lua @@ -0,0 +1,63 @@ +local spec = require('cmp.utils.spec') + +local misc = require('cmp.utils.misc') + +describe('misc', function() + before_each(spec.before) + + it('merge', function() + local merged + merged = misc.merge({ + a = {}, + }, { + a = { + b = 1, + }, + }) + assert.are.equal(merged.a.b, 1) + + merged = misc.merge({ + a = { + i = 1, + }, + }, { + a = { + c = 2, + }, + }) + assert.are.equal(merged.a.i, 1) + assert.are.equal(merged.a.c, 2) + + merged = misc.merge({ + a = false, + }, { + a = { + b = 1, + }, + }) + assert.are.equal(merged.a, false) + + merged = misc.merge({ + a = misc.none, + }, { + a = { + b = 1, + }, + }) + assert.are.equal(merged.a, nil) + + merged = misc.merge({ + a = misc.none, + }, { + a = nil, + }) + assert.are.equal(merged.a, nil) + + merged = misc.merge({ + a = nil, + }, { + a = misc.none, + }) + assert.are.equal(merged.a, nil) + end) +end) diff --git a/start/cmp/lua/cmp/utils/pattern.lua b/start/cmp/lua/cmp/utils/pattern.lua new file mode 100644 index 0000000..1481e84 --- /dev/null +++ b/start/cmp/lua/cmp/utils/pattern.lua @@ -0,0 +1,28 @@ +local pattern = {} + +pattern._regexes = {} + +pattern.regex = function(p) + if not pattern._regexes[p] then + pattern._regexes[p] = vim.regex(p) + end + return pattern._regexes[p] +end + +pattern.offset = function(p, text) + local s, e = pattern.regex(p):match_str(text) + if s then + return s + 1, e + 1 + end + return nil, nil +end + +pattern.matchstr = function(p, text) + local s, e = pattern.offset(p, text) + if s then + return string.sub(text, s, e) + end + return nil +end + +return pattern diff --git a/start/cmp/lua/cmp/utils/spec.lua b/start/cmp/lua/cmp/utils/spec.lua new file mode 100644 index 0000000..a4b2c83 --- /dev/null +++ b/start/cmp/lua/cmp/utils/spec.lua @@ -0,0 +1,92 @@ +local context = require('cmp.context') +local source = require('cmp.source') +local types = require('cmp.types') +local config = require('cmp.config') + +local spec = {} + +spec.before = function() + vim.cmd([[ + bdelete! + enew! + imapclear + imapclear <buffer> + cmapclear + cmapclear <buffer> + smapclear + smapclear <buffer> + xmapclear + xmapclear <buffer> + tmapclear + tmapclear <buffer> + setlocal noswapfile + setlocal virtualedit=all + setlocal completeopt=menu,menuone,noselect + ]]) + config.set_global({ + sources = { + { name = 'spec' }, + }, + snippet = { + expand = function(args) + local ctx = context.new() + vim.api.nvim_buf_set_text(ctx.bufnr, ctx.cursor.row - 1, ctx.cursor.col - 1, ctx.cursor.row - 1, ctx.cursor.col - 1, vim.split(string.gsub(args.body, '%$0', ''), '\n')) + for i, t in ipairs(vim.split(args.body, '\n')) do + local s = string.find(t, '$0', 1, true) + if s then + if i == 1 then + vim.api.nvim_win_set_cursor(0, { ctx.cursor.row, ctx.cursor.col + s - 2 }) + else + vim.api.nvim_win_set_cursor(0, { ctx.cursor.row + i - 1, s - 1 }) + end + break + end + end + end, + }, + }) + config.set_cmdline({ + sources = { + { name = 'spec' }, + }, + }, ':') +end + +spec.state = function(text, row, col) + vim.fn.setline(1, text) + vim.fn.cursor(row, col) + local ctx = context.empty() + local s = source.new('spec', { + complete = function() end, + }) + return { + context = function() + return ctx + end, + source = function() + return s + end, + backspace = function() + vim.fn.feedkeys('x', 'nx') + vim.fn.feedkeys('h', 'nx') + ctx = context.new(ctx, { reason = types.cmp.ContextReason.Auto }) + s:complete(ctx, function() end) + return ctx + end, + input = function(char) + vim.fn.feedkeys(('i%s'):format(char), 'nx') + vim.fn.feedkeys(string.rep('l', #char), 'nx') + ctx.prev_context = nil + ctx = context.new(ctx, { reason = types.cmp.ContextReason.Auto }) + s:complete(ctx, function() end) + return ctx + end, + manual = function() + ctx = context.new(ctx, { reason = types.cmp.ContextReason.Manual }) + s:complete(ctx, function() end) + return ctx + end, + } +end + +return spec diff --git a/start/cmp/lua/cmp/utils/str.lua b/start/cmp/lua/cmp/utils/str.lua new file mode 100644 index 0000000..bca210c --- /dev/null +++ b/start/cmp/lua/cmp/utils/str.lua @@ -0,0 +1,178 @@ +local char = require('cmp.utils.char') + +local str = {} + +local INVALIDS = {} +INVALIDS[string.byte("'")] = true +INVALIDS[string.byte('"')] = true +INVALIDS[string.byte('=')] = true +INVALIDS[string.byte('$')] = true +INVALIDS[string.byte('(')] = true +INVALIDS[string.byte('[')] = true +INVALIDS[string.byte('<')] = true +INVALIDS[string.byte('{')] = true +INVALIDS[string.byte(' ')] = true +INVALIDS[string.byte('\t')] = true +INVALIDS[string.byte('\n')] = true +INVALIDS[string.byte('\r')] = true + +local NR_BYTE = string.byte('\n') + +local PAIRS = {} +PAIRS[string.byte('<')] = string.byte('>') +PAIRS[string.byte('[')] = string.byte(']') +PAIRS[string.byte('(')] = string.byte(')') +PAIRS[string.byte('{')] = string.byte('}') +PAIRS[string.byte('"')] = string.byte('"') +PAIRS[string.byte("'")] = string.byte("'") + +---Return if specified text has prefix or not +---@param text string +---@param prefix string +---@return boolean +str.has_prefix = function(text, prefix) + if #text < #prefix then + return false + end + for i = 1, #prefix do + if not char.match(string.byte(text, i), string.byte(prefix, i)) then + return false + end + end + return true +end + +---get_common_string +str.get_common_string = function(text1, text2) + local min = math.min(#text1, #text2) + for i = 1, min do + if not char.match(string.byte(text1, i), string.byte(text2, i)) then + return string.sub(text1, 1, i - 1) + end + end + return string.sub(text1, 1, min) +end + +---Remove suffix +---@param text string +---@param suffix string +---@return string +str.remove_suffix = function(text, suffix) + if #text < #suffix then + return text + end + + local i = 0 + while i < #suffix do + if string.byte(text, #text - i) ~= string.byte(suffix, #suffix - i) then + return text + end + i = i + 1 + end + return string.sub(text, 1, -#suffix - 1) +end + +---trim +---@param text string +---@return string +str.trim = function(text) + local s = 1 + for i = 1, #text do + if not char.is_white(string.byte(text, i)) then + s = i + break + end + end + + local e = #text + for i = #text, 1, -1 do + if not char.is_white(string.byte(text, i)) then + e = i + break + end + end + if s == 1 and e == #text then + return text + end + return string.sub(text, s, e) +end + +---get_word +---@param text string +---@param stop_char number +---@param min_length number +---@return string +str.get_word = function(text, stop_char, min_length) + min_length = min_length or 0 + + local has_alnum = false + local stack = {} + local word = {} + local add = function(c) + table.insert(word, string.char(c)) + if stack[#stack] == c then + table.remove(stack, #stack) + else + if PAIRS[c] then + table.insert(stack, c) + end + end + end + for i = 1, #text do + local c = string.byte(text, i, i) + if #word < min_length then + table.insert(word, string.char(c)) + elseif not INVALIDS[c] then + add(c) + has_alnum = has_alnum or char.is_alnum(c) + elseif not has_alnum then + add(c) + elseif #stack ~= 0 then + add(c) + if has_alnum and #stack == 0 then + break + end + else + break + end + end + if stop_char and word[#word] == string.char(stop_char) then + table.remove(word, #word) + end + return table.concat(word, '') +end + +---Oneline +---@param text string +---@return string +str.oneline = function(text) + for i = 1, #text do + if string.byte(text, i) == NR_BYTE then + return string.sub(text, 1, i - 1) + end + end + return text +end + +---Escape special chars +---@param text string +---@param chars string[] +---@return string +str.escape = function(text, chars) + table.insert(chars, '\\') + local escaped = {} + local i = 1 + while i <= #text do + local c = string.sub(text, i, i) + if vim.tbl_contains(chars, c) then + table.insert(escaped, '\\') + table.insert(escaped, c) + else + table.insert(escaped, c) + end + i = i + 1 + end + return table.concat(escaped, '') +end + +return str diff --git a/start/cmp/lua/cmp/utils/str_spec.lua b/start/cmp/lua/cmp/utils/str_spec.lua new file mode 100644 index 0000000..1a21855 --- /dev/null +++ b/start/cmp/lua/cmp/utils/str_spec.lua @@ -0,0 +1,29 @@ +local str = require('cmp.utils.str') + +describe('utils.str', function() + it('get_word', function() + assert.are.equal(str.get_word('print'), 'print') + assert.are.equal(str.get_word('$variable'), '$variable') + assert.are.equal(str.get_word('print()'), 'print') + assert.are.equal(str.get_word('["cmp#confirm"]'), '["cmp#confirm"]') + assert.are.equal(str.get_word('"devDependencies":', string.byte('"')), '"devDependencies') + assert.are.equal(str.get_word('"devDependencies": ${1},', string.byte('"')), '"devDependencies') + assert.are.equal(str.get_word('#[cfg(test)]'), '#[cfg(test)]') + assert.are.equal(str.get_word('import { GetStaticProps$1 } from "next";', nil, 9), 'import { GetStaticProps') + end) + + it('remove_suffix', function() + assert.are.equal(str.remove_suffix('log()', '$0'), 'log()') + assert.are.equal(str.remove_suffix('log()$0', '$0'), 'log()') + assert.are.equal(str.remove_suffix('log()${0}', '${0}'), 'log()') + assert.are.equal(str.remove_suffix('log()${0:placeholder}', '${0}'), 'log()${0:placeholder}') + end) + + it('escape', function() + assert.are.equal(str.escape('plain', {}), 'plain') + assert.are.equal(str.escape('plain\\', {}), 'plain\\\\') + assert.are.equal(str.escape('plain\\"', {}), 'plain\\\\"') + assert.are.equal(str.escape('pla"in', { '"' }), 'pla\\"in') + assert.are.equal(str.escape('call("")', { '"' }), 'call(\\"\\")') + end) +end) diff --git a/start/cmp/lua/cmp/utils/window.lua b/start/cmp/lua/cmp/utils/window.lua new file mode 100644 index 0000000..a8a271e --- /dev/null +++ b/start/cmp/lua/cmp/utils/window.lua @@ -0,0 +1,313 @@ +local cache = require('cmp.utils.cache') +local misc = require('cmp.utils.misc') +local buffer = require('cmp.utils.buffer') +local api = require('cmp.utils.api') + +---@class cmp.WindowStyle +---@field public relative string +---@field public row number +---@field public col number +---@field public width number +---@field public height number +---@field public border string|string[]|nil +---@field public zindex number|nil + +---@class cmp.Window +---@field public name string +---@field public win number|nil +---@field public thumb_win number|nil +---@field public sbar_win number|nil +---@field public style cmp.WindowStyle +---@field public opt table<string, any> +---@field public buffer_opt table<string, any> +---@field public cache cmp.Cache +local window = {} + +---new +---@return cmp.Window +window.new = function() + local self = setmetatable({}, { __index = window }) + self.name = misc.id('cmp.utils.window.new') + self.win = nil + self.sbar_win = nil + self.thumb_win = nil + self.style = {} + self.cache = cache.new() + self.opt = {} + self.buffer_opt = {} + return self +end + +---Set window option. +---NOTE: If the window already visible, immediately applied to it. +---@param key string +---@param value any +window.option = function(self, key, value) + if vim.fn.exists('+' .. key) == 0 then + return + end + + if value == nil then + return self.opt[key] + end + + self.opt[key] = value + if self:visible() then + vim.api.nvim_win_set_option(self.win, key, value) + end +end + +---Set buffer option. +---NOTE: If the buffer already visible, immediately applied to it. +---@param key string +---@param value any +window.buffer_option = function(self, key, value) + if vim.fn.exists('+' .. key) == 0 then + return + end + + if value == nil then + return self.buffer_opt[key] + end + + self.buffer_opt[key] = value + local existing_buf = buffer.get(self.name) + if existing_buf then + vim.api.nvim_buf_set_option(existing_buf, key, value) + end +end + +---Set style. +---@param style cmp.WindowStyle +window.set_style = function(self, style) + self.style = style + local info = self:info() + + if vim.o.lines and vim.o.lines <= info.row + info.height + 1 then + self.style.height = vim.o.lines - info.row - info.border_info.vert - 1 + end + + self.style.zindex = self.style.zindex or 1 +end + +---Return buffer id. +---@return number +window.get_buffer = function(self) + local buf, created_new = buffer.ensure(self.name) + if created_new then + for k, v in pairs(self.buffer_opt) do + vim.api.nvim_buf_set_option(buf, k, v) + end + end + return buf +end + +---Open window +---@param style cmp.WindowStyle +window.open = function(self, style) + if style then + self:set_style(style) + end + + if self.style.width < 1 or self.style.height < 1 then + return + end + + if self.win and vim.api.nvim_win_is_valid(self.win) then + vim.api.nvim_win_set_config(self.win, self.style) + else + local s = misc.copy(self.style) + s.noautocmd = true + self.win = vim.api.nvim_open_win(self:get_buffer(), false, s) + for k, v in pairs(self.opt) do + vim.api.nvim_win_set_option(self.win, k, v) + end + end + self:update() +end + +---Update +window.update = function(self) + local info = self:info() + if info.scrollable then + -- Draw the background of the scrollbar + + if not info.border_info.visible then + local style = { + relative = 'editor', + style = 'minimal', + width = 1, + height = self.style.height, + row = info.row, + col = info.col + info.width - info.scrollbar_offset, -- info.col was already contained the scrollbar offset. + zindex = (self.style.zindex and (self.style.zindex + 1) or 1), + } + if self.sbar_win and vim.api.nvim_win_is_valid(self.sbar_win) then + vim.api.nvim_win_set_config(self.sbar_win, style) + else + style.noautocmd = true + self.sbar_win = vim.api.nvim_open_win(buffer.ensure(self.name .. 'sbar_buf'), false, style) + vim.api.nvim_win_set_option(self.sbar_win, 'winhighlight', 'EndOfBuffer:PmenuSbar,NormalFloat:PmenuSbar') + end + end + + -- Draw the scrollbar thumb + local thumb_height = math.floor(info.inner_height * (info.inner_height / self:get_content_height()) + 0.5) + local thumb_offset = math.floor(info.inner_height * (vim.fn.getwininfo(self.win)[1].topline / self:get_content_height())) + + local style = { + relative = 'editor', + style = 'minimal', + width = 1, + height = math.max(1, thumb_height), + row = info.row + thumb_offset + (info.border_info.visible and info.border_info.top or 0), + col = info.col + info.width - 1, -- info.col was already added scrollbar offset. + zindex = (self.style.zindex and (self.style.zindex + 2) or 2), + } + if self.thumb_win and vim.api.nvim_win_is_valid(self.thumb_win) then + vim.api.nvim_win_set_config(self.thumb_win, style) + else + style.noautocmd = true + self.thumb_win = vim.api.nvim_open_win(buffer.ensure(self.name .. 'thumb_buf'), false, style) + vim.api.nvim_win_set_option(self.thumb_win, 'winhighlight', 'EndOfBuffer:PmenuThumb,NormalFloat:PmenuThumb') + end + else + if self.sbar_win and vim.api.nvim_win_is_valid(self.sbar_win) then + vim.api.nvim_win_hide(self.sbar_win) + self.sbar_win = nil + end + if self.thumb_win and vim.api.nvim_win_is_valid(self.thumb_win) then + vim.api.nvim_win_hide(self.thumb_win) + self.thumb_win = nil + end + end + + -- In cmdline, vim does not redraw automatically. + if api.is_cmdline_mode() then + vim.api.nvim_win_call(self.win, function() + misc.redraw() + end) + end +end + +---Close window +window.close = function(self) + if self.win and vim.api.nvim_win_is_valid(self.win) then + if self.win and vim.api.nvim_win_is_valid(self.win) then + vim.api.nvim_win_hide(self.win) + self.win = nil + end + if self.sbar_win and vim.api.nvim_win_is_valid(self.sbar_win) then + vim.api.nvim_win_hide(self.sbar_win) + self.sbar_win = nil + end + if self.thumb_win and vim.api.nvim_win_is_valid(self.thumb_win) then + vim.api.nvim_win_hide(self.thumb_win) + self.thumb_win = nil + end + end +end + +---Return the window is visible or not. +window.visible = function(self) + return self.win and vim.api.nvim_win_is_valid(self.win) +end + +---Return win info. +window.info = function(self) + local border_info = self:get_border_info() + local info = { + row = self.style.row, + col = self.style.col, + width = self.style.width + border_info.left + border_info.right, + height = self.style.height + border_info.top + border_info.bottom, + inner_width = self.style.width, + inner_height = self.style.height, + border_info = border_info, + scrollable = false, + scrollbar_offset = 0, + } + + if self:get_content_height() > info.inner_height then + info.scrollable = true + if not border_info.visible then + info.scrollbar_offset = 1 + info.width = info.width + 1 + end + end + + return info +end + +---Return border information. +---@return { top: number, left: number, right: number, bottom: number, vert: number, horiz: number, visible: boolean } +window.get_border_info = function(self) + local border = self.style.border + if not border or border == 'none' then + return { + top = 0, + left = 0, + right = 0, + bottom = 0, + vert = 0, + horiz = 0, + visible = false, + } + end + if type(border) == 'string' then + if border == 'shadow' then + return { + top = 0, + left = 0, + right = 1, + bottom = 1, + vert = 1, + horiz = 1, + visible = false, + } + end + return { + top = 1, + left = 1, + right = 1, + bottom = 1, + vert = 2, + horiz = 2, + visible = true, + } + end + + local new_border = {} + while #new_border <= 8 do + for _, b in ipairs(border) do + table.insert(new_border, type(b) == 'string' and b or b[1]) + end + end + local info = {} + info.top = new_border[2] == '' and 0 or 1 + info.right = new_border[4] == '' and 0 or 1 + info.bottom = new_border[6] == '' and 0 or 1 + info.left = new_border[8] == '' and 0 or 1 + info.vert = info.top + info.bottom + info.horiz = info.left + info.right + info.visible = not (vim.tbl_contains({ '', ' ' }, new_border[2]) and vim.tbl_contains({ '', ' ' }, new_border[4]) and vim.tbl_contains({ '', ' ' }, new_border[6]) and vim.tbl_contains({ '', ' ' }, new_border[8])) + return info +end + +---Get scroll height. +---NOTE: The result of vim.fn.strdisplaywidth depends on the buffer it was called in (see comment in cmp.Entry.get_view). +---@return number +window.get_content_height = function(self) + if not self:option('wrap') then + return vim.api.nvim_buf_line_count(self:get_buffer()) + end + local height = 0 + vim.api.nvim_buf_call(self:get_buffer(), function() + for _, text in ipairs(vim.api.nvim_buf_get_lines(self:get_buffer(), 0, -1, false)) do + height = height + math.max(1, math.ceil(vim.fn.strdisplaywidth(text) / self.style.width)) + end + end) + return height +end + +return window |