Sorter type safety
This commit is contained in:
parent
f06bd90d20
commit
06b63260b3
@ -2,13 +2,16 @@ local Class = require("nvim-tree.classic")
|
||||
local DirectoryNode = require("nvim-tree.node.directory")
|
||||
|
||||
---@alias SorterType "name" | "case_sensitive" | "modification_time" | "extension" | "suffix" | "filetype"
|
||||
---@alias SorterComparator fun(a: Node, b: Node, cfg: SorterCfg): boolean
|
||||
|
||||
---@alias SorterUser fun(nodes: Node[]): SorterType?
|
||||
|
||||
---@alias SorterComparator fun(a: Node, b: Node, cfg: SorterCfg): boolean?
|
||||
|
||||
---@type table<SorterType, SorterComparator>
|
||||
local C = {}
|
||||
|
||||
---@class (exact) SorterCfg
|
||||
---@field sorter SorterType|fun(nodes: Node[])
|
||||
---@field sorter SorterType|SorterUser
|
||||
---@field folders_first boolean
|
||||
---@field files_first boolean
|
||||
|
||||
@ -32,15 +35,6 @@ function Sorter:new(args)
|
||||
}
|
||||
end
|
||||
|
||||
---Predefined comparator
|
||||
---@param type SorterType
|
||||
---@return fun(a: Node, b: Node): boolean
|
||||
function Sorter:get_comparator(type)
|
||||
return function(a, b)
|
||||
return (C[type] or C.name)(a, b, self.cfg)
|
||||
end
|
||||
end
|
||||
|
||||
---Create a shallow copy of a portion of a list.
|
||||
---@param t table
|
||||
---@param first integer First index, inclusive
|
||||
@ -56,13 +50,10 @@ local function tbl_slice(t, first, last)
|
||||
end
|
||||
|
||||
---Evaluate `sort.folders_first` and `sort.files_first`
|
||||
---@param a Node
|
||||
---@param b Node
|
||||
---@param cfg SorterCfg
|
||||
---@return boolean|nil
|
||||
---@type SorterComparator
|
||||
local function folders_or_files_first(a, b, cfg)
|
||||
if not (cfg.folders_first or cfg.files_first) then
|
||||
return
|
||||
return nil
|
||||
end
|
||||
|
||||
if not a:is(DirectoryNode) and b:is(DirectoryNode) then
|
||||
@ -72,14 +63,17 @@ local function folders_or_files_first(a, b, cfg)
|
||||
-- folder <> file
|
||||
return not cfg.files_first
|
||||
end
|
||||
|
||||
return nil
|
||||
end
|
||||
|
||||
---@param t table
|
||||
---@param t Node[]
|
||||
---@param cfg SorterCfg
|
||||
---@param first number
|
||||
---@param mid number
|
||||
---@param last number
|
||||
---@param comparator fun(a: Node, b: Node): boolean
|
||||
local function merge(t, first, mid, last, comparator)
|
||||
---@param comparator SorterComparator
|
||||
local function merge(t, cfg, first, mid, last, comparator)
|
||||
local n1 = mid - first + 1
|
||||
local n2 = last - mid
|
||||
local ls = tbl_slice(t, first, mid)
|
||||
@ -89,7 +83,7 @@ local function merge(t, first, mid, last, comparator)
|
||||
local k = first
|
||||
|
||||
while i <= n1 and j <= n2 do
|
||||
if comparator(ls[i], rs[j]) then
|
||||
if comparator(ls[i], rs[j], cfg) then
|
||||
t[k] = ls[i]
|
||||
i = i + 1
|
||||
else
|
||||
@ -112,26 +106,29 @@ local function merge(t, first, mid, last, comparator)
|
||||
end
|
||||
end
|
||||
|
||||
---@param t table
|
||||
---@param t Node[]
|
||||
---@param cfg SorterCfg
|
||||
---@param first number
|
||||
---@param last number
|
||||
---@param comparator fun(a: Node, b: Node): boolean
|
||||
local function split_merge(t, first, last, comparator)
|
||||
---@param comparator SorterComparator
|
||||
local function split_merge(t, cfg, first, last, comparator)
|
||||
if (last - first) < 1 then
|
||||
return
|
||||
end
|
||||
|
||||
local mid = math.floor((first + last) / 2)
|
||||
|
||||
split_merge(t, first, mid, comparator)
|
||||
split_merge(t, mid + 1, last, comparator)
|
||||
merge(t, first, mid, last, comparator)
|
||||
split_merge(t, cfg, first, mid, comparator)
|
||||
split_merge(t, cfg, mid + 1, last, comparator)
|
||||
merge(t, cfg, first, mid, last, comparator)
|
||||
end
|
||||
|
||||
---Perform a merge sort using sorter option.
|
||||
---@param t Node[]
|
||||
function Sorter:sort(t)
|
||||
if type(self.cfg.sorter) == "function" then
|
||||
if C[self.cfg.sorter] then
|
||||
split_merge(t, self.cfg, 1, #t, C[self.cfg.sorter])
|
||||
elseif type(self.cfg.sorter) == "function" then
|
||||
local t_user = {}
|
||||
local origin_index = {}
|
||||
|
||||
@ -148,9 +145,10 @@ function Sorter:sort(t)
|
||||
table.insert(origin_index, n)
|
||||
end
|
||||
|
||||
local predefined = self.cfg.sorter(t_user)
|
||||
if predefined then
|
||||
split_merge(t, 1, #t, self:get_comparator(predefined))
|
||||
-- user may return a SorterType
|
||||
local ret = self.cfg.sorter(t_user)
|
||||
if C[ret] then
|
||||
split_merge(t, self.cfg, 1, #t, C[ret])
|
||||
return
|
||||
end
|
||||
|
||||
@ -173,10 +171,7 @@ function Sorter:sort(t)
|
||||
return (a_index or 0) <= (b_index or 0)
|
||||
end
|
||||
|
||||
split_merge(t, 1, #t, mini_comparator) -- sort by user order
|
||||
elseif type(self.cfg.sorter) == "string" then
|
||||
local sorter = self.cfg.sorter --[[@as string]]
|
||||
split_merge(t, 1, #t, self:get_comparator(sorter))
|
||||
split_merge(t, self.cfg, 1, #t, mini_comparator) -- sort by user order
|
||||
end
|
||||
end
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user