diff --git a/lua/orgmode/utils/promise.lua b/lua/orgmode/utils/promise.lua index 06227b70d..9d3a659b9 100644 --- a/lua/orgmode/utils/promise.lua +++ b/lua/orgmode/utils/promise.lua @@ -1,471 +1,572 @@ ----@diagnostic disable: undefined-field --- Taken from https://github.com/notomo/promise.nvim - -local vim = vim +---@alias OrgPromiseState 'pending' | 'fulfilled' | 'rejected' + +---@class OrgPromisePackedValues +---@field n integer +---@field [integer] any + +---@class OrgPromise +---@field _state OrgPromiseState +---@field _value OrgPromisePackedValues|nil +---@field _handled boolean +---@field _unhandled_check_scheduled boolean +---@field _subscribers fun()[] +local Promise = {} +Promise.__index = Promise -local PackedValue = {} -PackedValue.__index = PackedValue +---@type OrgPromiseState +local PENDING = 'pending' +---@type OrgPromiseState +local FULFILLED = 'fulfilled' +---@type OrgPromiseState +local REJECTED = 'rejected' +local AWAIT = {} +local managed_threads = setmetatable({}, { __mode = 'k' }) + +---@param value any +---@return boolean +local function is_promise(value) + return type(value) == 'table' and getmetatable(value) == Promise +end -function PackedValue.new(...) - local values = vim.F.pack_len(...) - local tbl = { _values = values } - return setmetatable(tbl, PackedValue) +---@return OrgPromisePackedValues +local function pack_values(...) + return vim.F.pack_len(...) end -function PackedValue.pcall(self, f) - local ok_and_value = function(ok, ...) - return ok, PackedValue.new(...) +---@param values OrgPromisePackedValues +---@param from? integer +local function unpack_values(values, from) + if (from or 1) == 1 then + return vim.F.unpack_len(values) end - return ok_and_value(pcall(f, self:unpack())) + return unpack(values, from, values.n) end -function PackedValue.unpack(self) - return vim.F.unpack_len(self._values) +---@param values OrgPromisePackedValues +---@return any +local function first_value(values) + return values[1] end -function PackedValue.first(self) - local first = self:unpack() - return first +---@param fn fun(...: any) +---@param ... any +local function schedule(fn, ...) + local args = pack_values(...) + + vim.schedule(function() + fn(unpack_values(args)) + end) end ---- @generic T : any ---- @generic V : any ---- @class OrgPromise: { next: (fun(self: OrgPromise, resolve:fun(result:T):V):OrgPromise), catch: (fun(self: OrgPromise, reject:fun(err:any)):OrgPromise), finally: (fun(self: OrgPromise, reject:fun(err:any)):OrgPromise), wait: fun(self: OrgPromise, timeout?: number):V } -local Promise = {} -Promise.__index = Promise +---@generic T +---@param fn fun(...: any): T|OrgPromise +---@param on_success fun(...: any) +---@param on_error fun(...: any) +---@param ... any +local function run_coroutine(fn, on_success, on_error, ...) + local thread = coroutine.create(fn) + managed_threads[thread] = true + + local function step(...) + local result = pack_values(coroutine.resume(thread, ...)) + if not result[1] then + managed_threads[thread] = nil + on_error(unpack_values(result, 2)) + return + end + + if coroutine.status(thread) == 'dead' then + managed_threads[thread] = nil + on_success(unpack_values(result, 2)) + return + end + + if result[2] == AWAIT then + local promise = result[3] + Promise.resolve(promise):next(function(...) + step(true, pack_values(...)) + end, function(...) + step(false, pack_values(...)) + end) + return + end -local PromiseStatus = { Pending = 'pending', Fulfilled = 'fulfilled', Rejected = 'rejected' } + Promise.resolve(unpack_values(result, 2)):next(step, on_error) + end -local is_promise = function(v) - return getmetatable(v) == Promise + step(...) end -local new_empty_userdata = function() - return newproxy(true) +---@generic T +---@param promise OrgPromise +local function flush(promise) + if promise._state == PENDING then + return + end + + local subscribers = promise._subscribers + promise._subscribers = {} + + for _, subscriber in ipairs(subscribers) do + subscriber() + end end -local new_pending = function(on_fullfilled, on_rejected) - vim.validate('on_fullfilled', on_fullfilled, 'function', true) - vim.validate('on_rejected', on_rejected, 'function', true) - local tbl = { - _status = PromiseStatus.Pending, - _queued = {}, - _value = nil, - _on_fullfilled = on_fullfilled, - _on_rejected = on_rejected, - _handled = false, - } - local self = setmetatable(tbl, Promise) +---@generic T +---@param promise OrgPromise +local function schedule_unhandled_check(promise) + if promise._unhandled_check_scheduled then + return + end + + promise._unhandled_check_scheduled = true + + vim.schedule(function() + promise._unhandled_check_scheduled = false - local userdata = new_empty_userdata() - self._unhandled_detector = setmetatable({ [self] = userdata }, { __mode = 'k' }) - getmetatable(userdata).__gc = function() - if self._status ~= PromiseStatus.Rejected or self._handled then + if promise._state ~= REJECTED or promise._handled then return end - self._handled = true - vim.schedule(function() - local value = self._value:unpack() - -- Do not report keyboard interrupt errors as unhandled. - -- There is no way to handle pressed "" while waiting for a promise. - if value == 'Keyboard interrupt' then - return - end - local values = vim.inspect({ value }, { newline = '', indent = '' }) - error('unhandled promise rejection: ' .. values, 0) + + local reason = first_value(promise._value) + if reason == 'Keyboard interrupt' then + return + end + + error('unhandled promise rejection: ' .. vim.inspect({ reason }, { newline = '', indent = '' }), 0) + end) +end + +---@generic T +---@param promise OrgPromise +---@param ... any +local function reject_promise(promise, ...) + if promise._state ~= PENDING then + return + end + + promise._state = REJECTED + promise._value = pack_values(...) + promise._handled = promise._handled or #promise._subscribers > 0 + schedule_unhandled_check(promise) + flush(promise) +end + +---@generic T +---@param promise OrgPromise +---@param ... T +local function fulfill_promise(promise, ...) + if promise._state ~= PENDING then + return + end + + promise._state = FULFILLED + promise._value = pack_values(...) + flush(promise) +end + +---@generic T +---@param promise OrgPromise +---@param ... T|OrgPromise +local function resolve_promise(promise, ...) + local argc = select('#', ...) + local first = ... + + if argc == 1 and promise == first then + reject_promise(promise, 'Cannot resolve a promise with itself') + return + end + + if argc == 1 and is_promise(first) then + first:next(function(...) + fulfill_promise(promise, ...) + end, function(...) + reject_promise(promise, ...) end) + return end - return self + fulfill_promise(promise, ...) end ---- Equivalents to JavaScript's Promise.new. ---- @param executor fun(resolve:fun(...:any),reject:fun(...:any)) ---- @return OrgPromise +---@generic T +---@param executor fun(resolve: fun(...: T|OrgPromise), reject: fun(...: any)) +---@return OrgPromise function Promise.new(executor) - vim.validate('executor', executor, 'function') + assert(type(executor) == 'function', 'Promise.new expects an executor function') + + ---@type OrgPromise + local promise = setmetatable({ + _state = PENDING, + _value = nil, + _handled = false, + _unhandled_check_scheduled = false, + _subscribers = {}, + }, Promise) - local self = new_pending() + local settled = false - local resolve = function(...) - local first = ... - if is_promise(first) then - first - :next(function(...) - self:_resolve(...) - end) - :catch(function(...) - self:_reject(...) - end) + local function resolve(...) + if settled then return end - self:_resolve(...) + + settled = true + resolve_promise(promise, ...) end - local reject = function(...) - self:_reject(...) + + local function reject(...) + if settled then + return + end + + settled = true + reject_promise(promise, ...) end - executor(resolve, reject) - return self + run_coroutine(function() + executor(resolve, reject) + end, function() end, reject) + + return promise end ---- Returns a fulfilled promise. ---- But if the first argument is promise, returns the promise. ---- @param ... any: one promise or non-promises ---- @return OrgPromise +---@generic T +---@param ... T|OrgPromise +---@return OrgPromise function Promise.resolve(...) + local argc = select('#', ...) local first = ... - if is_promise(first) then + + if argc == 1 and is_promise(first) then return first end - local value = PackedValue.new(...) - return Promise.new(function(resolve, _) - resolve(value:unpack()) - end) + + ---@type OrgPromise + local promise = setmetatable({ + _state = PENDING, + _value = nil, + _handled = false, + _unhandled_check_scheduled = false, + _subscribers = {}, + }, Promise) + + resolve_promise(promise, ...) + + return promise end ---- Returns a rejected promise. ---- But if the first argument is promise, returns the promise. ---- @param ... any: one promise or non-promises ---- @return OrgPromise +---@param ... any +---@return OrgPromise function Promise.reject(...) + local argc = select('#', ...) local first = ... - if is_promise(first) then + + if argc == 1 and is_promise(first) then return first end - local value = PackedValue.new(...) - return Promise.new(function(_, reject) - reject(value:unpack()) - end) -end -function Promise._resolve(self, ...) - if self._status == PromiseStatus.Rejected then - return - end - self._status = PromiseStatus.Fulfilled - self._value = PackedValue.new(...) - for _ = 1, #self._queued do - local promise = table.remove(self._queued, 1) - promise:_start_resolve(self._value) - end -end + ---@type OrgPromise + local promise = setmetatable({ + _state = PENDING, + _value = nil, + _handled = false, + _unhandled_check_scheduled = false, + _subscribers = {}, + }, Promise) -function Promise._start_resolve(self, value) - if not self._on_fullfilled then - return vim.schedule(function() - self:_resolve(value:unpack()) - end) - end - local ok, result = value:pcall(self._on_fullfilled) - if not ok then - return vim.schedule(function() - self:_reject(result:unpack()) - end) - end - local first = result:first() - if not is_promise(first) then - return vim.schedule(function() - self:_resolve(result:unpack()) - end) - end - first - :next(function(...) - self:_resolve(...) - end) - :catch(function(...) - self:_reject(...) - end) -end + reject_promise(promise, ...) -function Promise._reject(self, ...) - if self._status == PromiseStatus.Fulfilled then - return - end - self._status = PromiseStatus.Rejected - self._value = PackedValue.new(...) - self._handled = self._handled or #self._queued > 0 - for _ = 1, #self._queued do - local promise = table.remove(self._queued, 1) - promise:_start_reject(self._value) - end + return promise end -function Promise._start_reject(self, value) - if not self._on_rejected then - return vim.schedule(function() - self:_reject(value:unpack()) - end) - end - local ok, result = value:pcall(self._on_rejected) - local first = result:first() - if ok and not is_promise(first) then - return vim.schedule(function() - self:_resolve(result:unpack()) +---@generic T, U +---@param self OrgPromise +---@param on_fulfilled? fun(...: any): U|OrgPromise +---@param on_rejected? fun(...: any): U|OrgPromise +---@return OrgPromise +function Promise:next(on_fulfilled, on_rejected) + ---@type OrgPromise + local child = setmetatable({ + _state = PENDING, + _value = nil, + _handled = false, + _unhandled_check_scheduled = false, + _subscribers = {}, + }, Promise) + + local function run_callback() + local callback = nil + if self._state == FULFILLED then + callback = on_fulfilled + else + callback = on_rejected + end + + if callback == nil then + if self._state == FULFILLED then + fulfill_promise(child, unpack_values(self._value)) + else + reject_promise(child, unpack_values(self._value)) + end + return + end + + run_coroutine(function() + return callback(unpack_values(self._value)) + end, function(...) + resolve_promise(child, ...) + end, function(...) + reject_promise(child, ...) end) end - if not is_promise(first) then - return vim.schedule(function() - self:_reject(result:unpack()) - end) + + self._handled = true + table.insert(self._subscribers, run_callback) + + if self._state ~= PENDING then + schedule(run_callback) end - first - :next(function(...) - self:_resolve(...) - end) - :catch(function(...) - self:_reject(...) - end) -end ---- Equivalents to JavaScript's Promise.then. ---- @param on_fullfilled (fun(...:any):any)?: A callback on fullfilled. ---- @param on_rejected (fun(...:any):any)?: A callback on rejected. ---- @return OrgPromise -function Promise.next(self, on_fullfilled, on_rejected) - vim.validate('on_fullfilled', on_fullfilled, 'function', true) - vim.validate('on_rejected', on_rejected, 'function', true) - local promise = new_pending(on_fullfilled, on_rejected) - table.insert(self._queued, promise) - vim.schedule(function() - if self._status == PromiseStatus.Fulfilled then - return self:_resolve(self._value:unpack()) - end - if self._status == PromiseStatus.Rejected then - return self:_reject(self._value:unpack()) - end - end) - return promise + return child end ---- Equivalents to JavaScript's Promise.catch. ---- @param on_rejected (fun(...:any):any)?: A callback on rejected. ---- @return OrgPromise -function Promise.catch(self, on_rejected) +---@generic T, U +---@param self OrgPromise +---@param on_rejected fun(reason: any): U|OrgPromise +---@return OrgPromise +function Promise:catch(on_rejected) return self:next(nil, on_rejected) end ---- Equivalents to JavaScript's Promise.finally. ---- @param on_finally fun() ---- @return OrgPromise -function Promise.finally(self, on_finally) - vim.validate('on_finally', on_finally, 'function', true) - return self - :next(function(...) - on_finally() - return ... - end) - :catch(function(...) - on_finally() - return Promise.reject(...) - end) +---@generic T +---@param self OrgPromise +---@param on_finally fun() +---@return OrgPromise +function Promise:finally(on_finally) + assert(type(on_finally) == 'function', 'Promise.finally expects a callback function') + + return self:next(function(...) + on_finally() + return ... + end, function(...) + on_finally() + return Promise.reject(...) + end) end ---- Equivalents to JavaScript's Promise.then. ---- @param timeout? number ---- @return any -function Promise.wait(self, timeout) - timeout = timeout or 5000 - local is_done = false - local has_error = false - local result = nil - - self - :next(function(...) - result = PackedValue.new(...) - is_done = true - end) - :catch(function(...) - has_error = true - result = PackedValue.new(...) - is_done = true - end) +---@generic T +---@param self OrgPromise +---@param timeout? integer +---@param interval? integer +---@return T +function Promise:wait(timeout, interval) + self._handled = true + + if self._state == PENDING then + local timeout_ms = timeout or 30000 + local ok = vim.wait(timeout_ms, function() + return self._state ~= PENDING + end, interval or 10) + + if not ok then + error(('Promise timed out after %dms'):format(timeout_ms), 0) + end + end + + if self._state == REJECTED then + error(first_value(self._value), 0) + end - local success, code = vim.wait(timeout, function() - return is_done - end, 1) + return unpack_values(self._value) +end - local value = result and result:unpack() +---Run a function in a managed async coroutine context. +---@generic T +---@param callback fun(...: any): T|OrgPromise +---@param ... any +---@return OrgPromise +function Promise.async(callback, ...) + assert(type(callback) == 'function', 'Promise.async expects a callback function') + local args = pack_values(...) - if has_error then - return error(value) + return Promise.new(function(resolve, reject) + run_coroutine(callback, resolve, reject, unpack_values(args)) + end) +end + +---Await promise resolution without blocking the event loop. +---Must be called from a yieldable managed async coroutine context. +---@generic T +---@return T +function Promise:await() + local thread, is_main = coroutine.running() + if not thread or is_main then + error('promise:await() must be called from a yieldable async context', 0) end - if success then - return value + if self._state == FULFILLED then + return unpack_values(self._value) end - if code == -1 then - return error('promise timeout of ' .. tostring(timeout) .. 'ms reached') - elseif code == -2 then - return error('promise interrupted') + if self._state == REJECTED then + error(first_value(self._value), 0) end - return error('promise failed with unknown reason') -end + if not managed_threads[thread] then + self:next(function(...) + coroutine.resume(thread, true, pack_values(...)) + end, function(...) + coroutine.resume(thread, false, pack_values(...)) + end) ---- Equivalents to JavaScript's Promise.all. ---- Even if multiple value are resolved, results include only the first value. ---- @param list any[]: promise or non-promise values ---- @return OrgPromise -function Promise.all(list) - vim.validate('list', list, 'table') - return Promise.new(function(resolve, reject) - local remain = #list - if remain == 0 then - return resolve({}) + local ok, values = coroutine.yield() + if not ok then + error(first_value(values), 0) end - local results = {} - for i, e in ipairs(list) do - Promise.resolve(e) - :next(function(...) - local first = ... - results[i] = first - if remain == 1 then - return resolve(results) - end - remain = remain - 1 - end) - :catch(function(...) - reject(...) - end) - end - end) + return unpack_values(values) + end + + local ok, values = coroutine.yield(AWAIT, self) + if not ok then + error(first_value(values), 0) + end + + return unpack_values(values) end ---- Equivalents to JavaScript's Promise.map with concurrency limit. ---- @param callback fun(value: any, index: number, array: any[]): any ---- @param list any[]: promise or non-promise values ---- @param concurrency? number: limit number of concurrent items processing ---- @return OrgPromise -function Promise.map(callback, list, concurrency) - vim.validate('list', list, 'table') - vim.validate('callback', callback, 'function') - vim.validate('concurrency', concurrency, 'number', true) - - local results = {} - local processing = 0 - local index = 1 - concurrency = concurrency or #list +---@generic T +---@param items (T|OrgPromise)[] +---@return OrgPromise +function Promise.all(items) + assert(type(items) == 'table', 'Promise.all expects a list-like table') + + local total = #items + + if total == 0 then + return Promise.resolve({}) + end return Promise.new(function(resolve, reject) - local function processNext() - if index > #list then - if processing == 0 then - resolve(results) + ---@type T[] + local results = {} + local completed = 0 + local settled = false + + for index, item in ipairs(items) do + Promise.resolve(item):next(function(value) + if settled then + return end - return - end - local i = index - index = index + 1 - processing = processing + 1 + results[index] = value + completed = completed + 1 - Promise.resolve(callback(list[i], i, list)) - :next(function(...) - results[i] = ... - processing = processing - 1 - processNext() - end) - :catch(function(...) - reject(...) - end) - end + if completed == total then + settled = true + resolve(results) + end + end, function(reason) + if settled then + return + end - for _ = 1, concurrency do - processNext() + settled = true + reject(reason) + end) end end) end ---- Equivalents to JavaScript's Promise.mapSeries ---- @param callback fun(value: any, index: number): any ---- @param list any[]: promise or non-promise values ---- @return OrgPromise -function Promise.mapSeries(callback, list) - return Promise.map(callback, list, 1) -end +---@generic T, U +---@param mapper fun(item: T, index: integer, items: T[]): U|OrgPromise +---@param items T[] +---@param concurrency? integer +---@return OrgPromise +function Promise.map(mapper, items, concurrency) + assert(type(items) == 'table', 'Promise.map expects a list-like table') + assert(type(mapper) == 'function', 'Promise.map expects a mapper function') ---- Equivalents to JavaScript's Promise.race. ---- @param list any[]: promise or non-promise values ---- @return OrgPromise -function Promise.race(list) - vim.validate('list', list, 'table') - return Promise.new(function(resolve, reject) - for _, e in ipairs(list) do - Promise.resolve(e) - :next(function(...) - resolve(...) - end) - :catch(function(...) - reject(...) - end) - end - end) -end + local total = #items + concurrency = concurrency or total + + if total == 0 then + return Promise.resolve({}) + end + + assert(type(concurrency) == 'number' and concurrency >= 1, 'Promise.map concurrency must be >= 1') + concurrency = math.floor(concurrency) + concurrency = math.min(concurrency, total) ---- Equivalents to JavaScript's Promise.any. ---- Even if multiple value are rejected, errors include only the first value. ---- @param list any[]: promise or non-promise values ---- @return OrgPromise -function Promise.any(list) - vim.validate('list', list, 'table') return Promise.new(function(resolve, reject) - local remain = #list - if remain == 0 then - return reject({}) - end + ---@type U[] + local results = {} + local next_index = 1 + local active = 0 + local completed = 0 + local settled = false - local errs = {} - for i, e in ipairs(list) do - Promise.resolve(e) - :next(function(...) - resolve(...) - end) - :catch(function(...) - local first = ... - errs[i] = first - if remain == 1 then - return reject(errs) + local function pump() + if settled then + return + end + + while active < concurrency and next_index <= total do + local index = next_index + next_index = next_index + 1 + active = active + 1 + + run_coroutine(function() + return mapper(items[index], index, items) + end, function(value) + Promise.resolve(value):next(function(resolved) + if settled then + return + end + + results[index] = resolved + active = active - 1 + completed = completed + 1 + + if completed == total then + settled = true + resolve(results) + return + end + + pump() + end, function(reason) + if settled then + return + end + + settled = true + active = active - 1 + reject(reason) + end) + end, function(reason) + if settled then + return end - remain = remain - 1 + + settled = true + active = active - 1 + reject(reason) end) + end end + + pump() end) end ---- Equivalents to JavaScript's Promise.allSettled. ---- Even if multiple value are resolved/rejected, value/reason is only the first value. ---- @param list any[]: promise or non-promise values ---- @return OrgPromise -function Promise.all_settled(list) - vim.validate('list', list, 'table') - return Promise.new(function(resolve) - local remain = #list - if remain == 0 then - return resolve({}) - end - - local results = {} - for i, e in ipairs(list) do - Promise.resolve(e) - :next(function(...) - local first = ... - results[i] = { status = PromiseStatus.Fulfilled, value = first } - end) - :catch(function(...) - local first = ... - results[i] = { status = PromiseStatus.Rejected, reason = first } - end) - :finally(function() - if remain == 1 then - return resolve(results) - end - remain = remain - 1 - end) - end - end) +---@generic T, U +---@param mapper fun(item: T, index: integer, items: T[]): U|OrgPromise +---@param items T[] +---@return OrgPromise +function Promise.mapSeries(mapper, items) + return Promise.map(mapper, items, 1) end return Promise