176 lines
4.0 KiB
Lua
176 lines
4.0 KiB
Lua
|
|
-- Module table: the public functions defined below are attached to it, and it
-- is returned at the end of the file.
local M = {}
|
||
|
|
|
||
|
|
-- Tree-sitter query for top-level test functions: a `function_declaration`
-- whose name matches `^Test.+$` and whose only parameter has a pointer type
-- matching `*testing.(T|M)`.
-- Captures: @testname (the function name), @type (the parameter type) and
-- @parent (the whole declaration).
local tests_query = [[
(function_declaration
  name: (identifier) @testname
  parameters: (parameter_list
    . (parameter_declaration
      type: (pointer_type) @type) .)
  (#match? @type "*testing.(T|M)")
  (#match? @testname "^Test.+$")) @parent
]]
|
||
|
|
|
||
|
|
-- Tree-sitter query for subtest calls: a call of the form `<ident>.Run(...)`
-- whose argument list contains an interpreted string literal followed by
-- either a function literal or an identifier.
-- Captures: @run (the selected field, required to equal "Run"), @testname
-- (the string literal, quotes included) and @parent (the whole call).
local subtests_query = [[
(call_expression
  function: (selector_expression
    operand: (identifier)
    field: (field_identifier) @run)
  arguments: (argument_list
    (interpreted_string_literal) @testname
    [
      (func_literal)
      (identifier)
    ])
  (#eq? @run "Run")) @parent
]]
|
||
|
|
|
||
|
|
--- Builds the slash-separated name of a test entry by prefixing the names of
--- its ancestors (e.g. `TestA/outer/inner`).
-- @param testcase  entry with a `name` and an optional `parent` (the parent's name)
-- @param test_tree list of entries searched for the parent by name
-- @return the entry's name, prefixed with its ancestors' names when they can be resolved
local function format_subtest(testcase, test_tree)
  if not testcase.parent then
    return testcase.name
  end

  for _, candidate in ipairs(test_tree) do
    -- Skip the entry itself: an entry whose parent name equals its own name
    -- must not resolve to itself, or the recursion would never terminate.
    if candidate ~= testcase and candidate.name == testcase.parent then
      return string.format("%s/%s", format_subtest(candidate, test_tree), testcase.name)
    end
  end

  -- The named parent is not in the tree; fall back to the bare name instead
  -- of recursing with a nil entry.
  return testcase.name
end
|
||
|
|
|
||
|
|
--- Picks the entry whose node starts on the greatest row and returns its
--- formatted name.
-- @param test_tree table of entries, each carrying a `node` with a `:range()` method
-- @return the result of `format_subtest` for the chosen entry, or nil when the table is empty
local function get_closest_above_cursor(test_tree)
  local latest
  for _, candidate in pairs(test_tree) do
    if not latest then
      latest = candidate
    elseif (candidate.node:range()) > (latest.node:range()) then
      -- Parentheses keep only the first value of range(): the start row.
      latest = candidate
    end
  end

  if not latest then
    return nil
  end
  return format_subtest(latest, test_tree)
end
|
||
|
|
|
||
|
|
--- Reports whether `dest` is a strict ancestor of `source`, following
--- `:parent()` links upward from `source`.
-- @param dest   candidate ancestor node
-- @param source node whose ancestor chain is walked
-- @return true when `dest` is found above `source`; false otherwise,
--         including when either argument is missing or both are equal
local function is_parent(dest, source)
  if not dest or not source or dest == source then
    return false
  end

  -- `source` itself was ruled out above, so start from its parent.
  local ancestor = source:parent()
  while ancestor ~= nil do
    if ancestor == dest then
      return true
    end
    ancestor = ancestor:parent()
  end
  return false
end
|
||
|
|
|
||
|
|
--- Collects one entry per match of `query` between row 0 and `stop_row`.
-- @param query     parsed tree-sitter query exposing `testname` and `parent` captures
-- @param root      root node to search
-- @param stop_row  stop row handed to `iter_matches`
-- @param normalize optional function applied to the captured name text
-- @return list of `{ name = <string>, node = <node> }` entries
local function collect_test_matches(query, root, stop_row, normalize)
  local entries = {}
  for _, match, _ in query:iter_matches(root, 0, 0, stop_row, { all = true }) do
    local entry = {}
    for id, nodes in pairs(match) do
      local capture = query.captures[id]
      for _, node in ipairs(nodes) do
        if capture == "testname" then
          local text = vim.treesitter.get_node_text(node, 0)
          if normalize then
            text = normalize(text)
          end
          entry.name = text
        elseif capture == "parent" then
          entry.node = node
        end
      end
    end
    entries[#entries + 1] = entry
  end
  return entries
end

--- Strict ordering of entries by node position: earlier start first and, for
--- equal starts, the node that ends later first. An enclosing node therefore
--- always sorts before the nodes nested inside it.
local function in_document_order(a, b)
  local a_row, a_col, a_end_row, a_end_col = a.node:range()
  local b_row, b_col, b_end_row, b_end_col = b.node:range()
  if a_row ~= b_row then
    return a_row < b_row
  end
  if a_col ~= b_col then
    return a_col < b_col
  end
  if a_end_row ~= b_end_row then
    return a_end_row > b_end_row
  end
  return a_end_col > b_end_col
end

--- Finds the test or subtest closest to the cursor in the current Go buffer.
-- Matches both queries from the top of the buffer down to the cursor row,
-- links every subtest to its nearest enclosing test, and returns the
-- formatted name of the match that starts last.
-- @return the formatted test name, or nil when nothing matches
-- Raises an assertion error when the buffer's filetype is not "go".
local function get_closest_test()
  -- The cursor row is used as the stop row of the search range.
  local stop_row = vim.api.nvim_win_get_cursor(0)[1]
  local ft = vim.bo.filetype
  assert(ft == "go", "can only find test in go files, not " .. ft)

  local parser = vim.treesitter.get_parser(0)
  local root = (parser:parse()[1]):root()

  local test_query = vim.treesitter.query.parse(ft, tests_query)
  assert(test_query, "could not parse test query")
  local test_tree = collect_test_matches(test_query, root, stop_row)

  local subtest_query = vim.treesitter.query.parse(ft, subtests_query)
  assert(subtest_query, "could not parse subtest query")
  local subtests = collect_test_matches(subtest_query, root, stop_row, function(text)
    -- Spaces become underscores and the surrounding quotes are dropped;
    -- the extra parentheses discard gsub's second return value (the count).
    return (string.gsub(string.gsub(text, " ", "_"), '"', ""))
  end)
  for _, entry in ipairs(subtests) do
    test_tree[#test_tree + 1] = entry
  end

  -- "is ancestor of" is only a partial order and must not be handed to
  -- table.sort; sorting by position gives a valid strict order in which
  -- every ancestor still precedes its descendants.
  table.sort(test_tree, in_document_order)

  -- Ancestors are visited outermost first, so the last assignment a child
  -- receives comes from its nearest enclosing test.
  for _, parent in ipairs(test_tree) do
    for _, child in ipairs(test_tree) do
      if is_parent(parent.node, child.node) then
        child.parent = parent.name
      end
    end
  end

  return get_closest_above_cursor(test_tree)
end
|
||
|
|
|
||
|
|
--- Returns the directory of the current buffer's file as a `./`-prefixed path.
-- @return string such as `./pkg/sub`
local function get_package_name()
  -- `%:.:h`: the current file name, reduced to be relative to the working
  -- directory where possible, then cut down to its directory part.
  -- No `:r` is applied: on a directory path it would strip a dotted suffix
  -- (e.g. `api.v1` -> `api`) and point at the wrong package.
  local test_dir = vim.fn.expand("%:.:h")
  return "./" .. test_dir
end
|
||
|
|
|
||
|
|
--- Describes what to run for the current cursor position.
-- @return table with `package` (from get_package_name), `name` (from
--         get_closest_test, possibly nil) and `scope`, which is "testcase"
--         when a test name was found and "package" otherwise
function M.closest_test()
  local pkg = get_package_name()
  local nearest = get_closest_test()
  return {
    package = pkg,
    name = nearest,
    -- "testcase" is always truthy, so the and/or form is a safe ternary here.
    scope = nearest and "testcase" or "package",
  }
end
|
||
|
|
|
||
|
|
--- Returns the root directory reported by an LSP client attached to the
--- current buffer.
-- @return the `config.root_dir` of the first attached client that defines one
-- Raises `{ error_msg = "lsp client not attached" }` when no client is
-- attached, and `{ error_msg = "lsp root_dir not defined" }` when no attached
-- client has a root_dir.
M.get_root_dir = function()
  local clients
  if vim.lsp.get_clients then
    clients = vim.lsp.get_clients({ bufnr = vim.api.nvim_get_current_buf() })
  else
    -- Older Neovim releases only provide the deprecated accessor.
    clients = vim.lsp.buf_get_clients()
  end

  if next(clients) == nil then
    error({ error_msg = "lsp client not attached" })
  end

  -- Check every attached client rather than an arbitrary first one, so a
  -- client without a root_dir does not hide one that has it. `pairs` handles
  -- both the list and the id-keyed table the two accessors return.
  for _, client in pairs(clients) do
    local root_dir = client.config.root_dir
    if root_dir then
      return root_dir
    end
  end

  error({ error_msg = "lsp root_dir not defined" })
end
|
||
|
|
|
||
|
|
-- Export the module table.
return M
|