Sorter type safety

This commit is contained in:
Alexander Courtis 2024-11-09 11:03:36 +11:00
parent b70bb6792e
commit f06bd90d20

View File

@ -1,17 +1,19 @@
local Class = require("nvim-tree.classic")
local DirectoryNode = require("nvim-tree.node.directory")
---@alias SorterType "name" | "case_sensitive" | "modification_time" | "extension" | "suffix" | "filetype"
---@alias SorterComparator fun(a: Node, b: Node, cfg: SorterCfg): boolean
---@type table<SorterType, SorterComparator>
local C = {}
---@alias SorterType "name" | "case_sensitive" | "modification_time" | "extension" | "suffix" | "filetype"
---@class (exact) SorterState
---@class (exact) SorterCfg
---@field sorter SorterType|fun(nodes: Node[])
---@field folders_first boolean
---@field files_first boolean
---@class (exact) Sorter: Class
---@field private state SorterState
---@field private cfg SorterCfg
local Sorter = Class:extend()
---@class Sorter
@ -23,7 +25,7 @@ local Sorter = Class:extend()
---@protected
---@param args SorterArgs
function Sorter:new(args)
self.state = {
self.cfg = {
sorter = args.explorer.opts.sort.sorter,
folders_first = args.explorer.opts.sort.folders_first,
files_first = args.explorer.opts.sort.files_first,
@ -35,7 +37,7 @@ end
---@return fun(a: Node, b: Node): boolean
function Sorter:get_comparator(type)
return function(a, b)
return (C[type] or C.name)(a, b, self.state)
return (C[type] or C.name)(a, b, self.cfg)
end
end
@ -56,7 +58,7 @@ end
---Evaluate `sort.folders_first` and `sort.files_first`
---@param a Node
---@param b Node
---@param cfg SorterState
---@param cfg SorterCfg
---@return boolean|nil
local function folders_or_files_first(a, b, cfg)
if not (cfg.folders_first or cfg.files_first) then
@ -129,7 +131,7 @@ end
---Perform a merge sort using sorter option.
---@param t Node[]
function Sorter:sort(t)
if type(self.state.sorter) == "function" then
if type(self.cfg.sorter) == "function" then
local t_user = {}
local origin_index = {}
@ -146,7 +148,7 @@ function Sorter:sort(t)
table.insert(origin_index, n)
end
local predefined = self.state.sorter(t_user)
local predefined = self.cfg.sorter(t_user)
if predefined then
split_merge(t, 1, #t, self:get_comparator(predefined))
return
@ -172,15 +174,16 @@ function Sorter:sort(t)
end
split_merge(t, 1, #t, mini_comparator) -- sort by user order
elseif type(self.state.sorter) == "string" then
split_merge(t, 1, #t, self:get_comparator(self.pre))
elseif type(self.cfg.sorter) == "string" then
local sorter = self.cfg.sorter --[[@as string]]
split_merge(t, 1, #t, self:get_comparator(sorter))
end
end
---@param a Node
---@param b Node
---@param ignorecase boolean|nil
---@param cfg SorterState
---@param cfg SorterCfg
---@return boolean
local function node_comparator_name_ignorecase_or_not(a, b, ignorecase, cfg)
if not (a and b) then
@ -199,14 +202,17 @@ local function node_comparator_name_ignorecase_or_not(a, b, ignorecase, cfg)
end
end
---@type SorterComparator
function C.case_sensitive(a, b, cfg)
return node_comparator_name_ignorecase_or_not(a, b, false, cfg)
end
---@type SorterComparator
function C.name(a, b, cfg)
return node_comparator_name_ignorecase_or_not(a, b, true, cfg)
end
---@type SorterComparator
function C.modification_time(a, b, cfg)
if not (a and b) then
return true
@ -231,6 +237,7 @@ function C.modification_time(a, b, cfg)
return last_modified_b <= last_modified_a
end
---@type SorterComparator
function C.suffix(a, b, cfg)
if not (a and b) then
return true
@ -280,6 +287,7 @@ function C.suffix(a, b, cfg)
return a_suffix:lower() < b_suffix:lower()
end
---@type SorterComparator
function C.extension(a, b, cfg)
if not (a and b) then
return true
@ -305,6 +313,7 @@ function C.extension(a, b, cfg)
return a_ext < b_ext
end
---@type SorterComparator
function C.filetype(a, b, cfg)
local a_ft = vim.filetype.match({ filename = a.name })
local b_ft = vim.filetype.match({ filename = b.name })