Sorter type safety

This commit is contained in:
Alexander Courtis 2024-11-09 11:49:40 +11:00
parent f06bd90d20
commit 06b63260b3

View File

@ -2,13 +2,16 @@ local Class = require("nvim-tree.classic")
local DirectoryNode = require("nvim-tree.node.directory")
---@alias SorterType "name" | "case_sensitive" | "modification_time" | "extension" | "suffix" | "filetype"
---@alias SorterComparator fun(a: Node, b: Node, cfg: SorterCfg): boolean
---@alias SorterUser fun(nodes: Node[]): SorterType?
---@alias SorterComparator fun(a: Node, b: Node, cfg: SorterCfg): boolean?
---@type table<SorterType, SorterComparator>
local C = {}
---@class (exact) SorterCfg
---@field sorter SorterType|fun(nodes: Node[])
---@field sorter SorterType|SorterUser
---@field folders_first boolean
---@field files_first boolean
@ -32,15 +35,6 @@ function Sorter:new(args)
}
end
---Predefined comparator
---@param type SorterType
---@return fun(a: Node, b: Node): boolean
function Sorter:get_comparator(type)
return function(a, b)
return (C[type] or C.name)(a, b, self.cfg)
end
end
---Create a shallow copy of a portion of a list.
---@param t table
---@param first integer First index, inclusive
@ -56,13 +50,10 @@ local function tbl_slice(t, first, last)
end
---Evaluate `sort.folders_first` and `sort.files_first`
---@param a Node
---@param b Node
---@param cfg SorterCfg
---@return boolean|nil
---@type SorterComparator
local function folders_or_files_first(a, b, cfg)
if not (cfg.folders_first or cfg.files_first) then
return
return nil
end
if not a:is(DirectoryNode) and b:is(DirectoryNode) then
@ -72,14 +63,17 @@ local function folders_or_files_first(a, b, cfg)
-- folder <> file
return not cfg.files_first
end
return nil
end
---@param t table
---@param t Node[]
---@param cfg SorterCfg
---@param first number
---@param mid number
---@param last number
---@param comparator fun(a: Node, b: Node): boolean
local function merge(t, first, mid, last, comparator)
---@param comparator SorterComparator
local function merge(t, cfg, first, mid, last, comparator)
local n1 = mid - first + 1
local n2 = last - mid
local ls = tbl_slice(t, first, mid)
@ -89,7 +83,7 @@ local function merge(t, first, mid, last, comparator)
local k = first
while i <= n1 and j <= n2 do
if comparator(ls[i], rs[j]) then
if comparator(ls[i], rs[j], cfg) then
t[k] = ls[i]
i = i + 1
else
@ -112,26 +106,29 @@ local function merge(t, first, mid, last, comparator)
end
end
---@param t table
---@param t Node[]
---@param cfg SorterCfg
---@param first number
---@param last number
---@param comparator fun(a: Node, b: Node): boolean
local function split_merge(t, first, last, comparator)
---@param comparator SorterComparator
local function split_merge(t, cfg, first, last, comparator)
if (last - first) < 1 then
return
end
local mid = math.floor((first + last) / 2)
split_merge(t, first, mid, comparator)
split_merge(t, mid + 1, last, comparator)
merge(t, first, mid, last, comparator)
split_merge(t, cfg, first, mid, comparator)
split_merge(t, cfg, mid + 1, last, comparator)
merge(t, cfg, first, mid, last, comparator)
end
---Perform a merge sort using sorter option.
---@param t Node[]
function Sorter:sort(t)
if type(self.cfg.sorter) == "function" then
if C[self.cfg.sorter] then
split_merge(t, self.cfg, 1, #t, C[self.cfg.sorter])
elseif type(self.cfg.sorter) == "function" then
local t_user = {}
local origin_index = {}
@ -148,9 +145,10 @@ function Sorter:sort(t)
table.insert(origin_index, n)
end
local predefined = self.cfg.sorter(t_user)
if predefined then
split_merge(t, 1, #t, self:get_comparator(predefined))
-- user may return a SorterType
local ret = self.cfg.sorter(t_user)
if C[ret] then
split_merge(t, self.cfg, 1, #t, C[ret])
return
end
@ -173,10 +171,7 @@ function Sorter:sort(t)
return (a_index or 0) <= (b_index or 0)
end
split_merge(t, 1, #t, mini_comparator) -- sort by user order
elseif type(self.cfg.sorter) == "string" then
local sorter = self.cfg.sorter --[[@as string]]
split_merge(t, 1, #t, self:get_comparator(sorter))
split_merge(t, self.cfg, 1, #t, mini_comparator) -- sort by user order
end
end