From 06b63260b3e64ec4b47dbfcc7e40343b21443419 Mon Sep 17 00:00:00 2001 From: Alexander Courtis Date: Sat, 9 Nov 2024 11:49:40 +1100 Subject: [PATCH] Sorter type safety --- lua/nvim-tree/explorer/sorter.lua | 63 ++++++++++++++----------------- 1 file changed, 29 insertions(+), 34 deletions(-) diff --git a/lua/nvim-tree/explorer/sorter.lua b/lua/nvim-tree/explorer/sorter.lua index 2c112ba0..37b33fa7 100644 --- a/lua/nvim-tree/explorer/sorter.lua +++ b/lua/nvim-tree/explorer/sorter.lua @@ -2,13 +2,16 @@ local Class = require("nvim-tree.classic") local DirectoryNode = require("nvim-tree.node.directory") ---@alias SorterType "name" | "case_sensitive" | "modification_time" | "extension" | "suffix" | "filetype" ----@alias SorterComparator fun(a: Node, b: Node, cfg: SorterCfg): boolean + +---@alias SorterUser fun(nodes: Node[]): SorterType? + +---@alias SorterComparator fun(a: Node, b: Node, cfg: SorterCfg): boolean? ---@type table local C = {} ---@class (exact) SorterCfg ----@field sorter SorterType|fun(nodes: Node[]) +---@field sorter SorterType|SorterUser ---@field folders_first boolean ---@field files_first boolean @@ -32,15 +35,6 @@ function Sorter:new(args) } end ----Predefined comparator ----@param type SorterType ----@return fun(a: Node, b: Node): boolean -function Sorter:get_comparator(type) - return function(a, b) - return (C[type] or C.name)(a, b, self.cfg) - end -end - ---Create a shallow copy of a portion of a list. ---@param t table ---@param first integer First index, inclusive @@ -56,13 +50,10 @@ local function tbl_slice(t, first, last) end ---Evaluate `sort.folders_first` and `sort.files_first` ----@param a Node ----@param b Node ----@param cfg SorterCfg ----@return boolean|nil +---@type SorterComparator local function folders_or_files_first(a, b, cfg) if not (cfg.folders_first or cfg.files_first) then - return + return nil end if not a:is(DirectoryNode) and b:is(DirectoryNode) then @@ -72,14 +63,17 @@ local function folders_or_files_first(a, b, cfg) -- folder <> file return not cfg.files_first end + + return nil end ----@param t table +---@param t Node[] +---@param cfg SorterCfg ---@param first number ---@param mid number ---@param last number ----@param comparator fun(a: Node, b: Node): boolean -local function merge(t, first, mid, last, comparator) +---@param comparator SorterComparator +local function merge(t, cfg, first, mid, last, comparator) local n1 = mid - first + 1 local n2 = last - mid local ls = tbl_slice(t, first, mid) @@ -89,7 +83,7 @@ local function merge(t, first, mid, last, comparator) local k = first while i <= n1 and j <= n2 do - if comparator(ls[i], rs[j]) then + if comparator(ls[i], rs[j], cfg) then t[k] = ls[i] i = i + 1 else @@ -112,26 +106,29 @@ local function merge(t, first, mid, last, comparator) end end ----@param t table +---@param t Node[] +---@param cfg SorterCfg ---@param first number ---@param last number ----@param comparator fun(a: Node, b: Node): boolean -local function split_merge(t, first, last, comparator) +---@param comparator SorterComparator +local function split_merge(t, cfg, first, last, comparator) if (last - first) < 1 then return end local mid = math.floor((first + last) / 2) - split_merge(t, first, mid, comparator) - split_merge(t, mid + 1, last, comparator) - merge(t, first, mid, last, comparator) + split_merge(t, cfg, first, mid, comparator) + split_merge(t, cfg, mid + 1, last, comparator) + merge(t, cfg, first, mid, last, comparator) end ---Perform a merge sort using sorter option. ---@param t Node[] function Sorter:sort(t) - if type(self.cfg.sorter) == "function" then + if C[self.cfg.sorter] then + split_merge(t, self.cfg, 1, #t, C[self.cfg.sorter]) + elseif type(self.cfg.sorter) == "function" then local t_user = {} local origin_index = {} @@ -148,9 +145,10 @@ function Sorter:sort(t) table.insert(origin_index, n) end - local predefined = self.cfg.sorter(t_user) - if predefined then - split_merge(t, 1, #t, self:get_comparator(predefined)) + -- user may return a SorterType + local ret = self.cfg.sorter(t_user) + if C[ret] then + split_merge(t, self.cfg, 1, #t, C[ret]) return end @@ -173,10 +171,7 @@ function Sorter:sort(t) return (a_index or 0) <= (b_index or 0) end - split_merge(t, 1, #t, mini_comparator) -- sort by user order - elseif type(self.cfg.sorter) == "string" then - local sorter = self.cfg.sorter --[[@as string]] - split_merge(t, 1, #t, self:get_comparator(sorter)) + split_merge(t, self.cfg, 1, #t, mini_comparator) -- sort by user order end end