diff --git a/src/Data/Traversable.lua b/src/Data/Traversable.lua index a846178..2592d76 100644 --- a/src/Data/Traversable.lua +++ b/src/Data/Traversable.lua @@ -1,31 +1,32 @@ local array1 = function(a) return {a} end local array2 = function(a) return function(b) return {a, b} end end local array3 = function(a) return function(b) return function(c) return {a, b, c} end end end -local concat2 = function(xs) return function(ys) return table.concat(xs, ys) end end return { traverseArrayImpl = (function(apply) return function(map) return function(pure) - return function(f) - return function(array) - local function go(bot, top) - if top - bot == 0 then - return pure({}) - elseif top - bot == 1 then - return map(array1)(f(array[bot])) - elseif top - bot == 2 then - return apply(map(array2)(f(array[bot])))(f(array[bot + 1])) - elseif top - bot == 3 then - return apply(apply(map(array3)(f(array[bot])))(f(array[bot + 1])))(f(array[bot + 2])) - else - -- This slightly tricky pivot selection aims to produce two - -- even-length partitions where possible. - local pivot = bot + math.floor((top - bot) / 4) * 2 - return apply(map(concat2)(go(bot, pivot)))(go(pivot, top)) + return function (appendArrays) + return function(f) + return function(array) + local function go(bot, top) + if top - bot == 0 then + return pure({}) + elseif top - bot == 1 then + return map(array1)(f(array[bot + 1])) + elseif top - bot == 2 then + return apply(map(array2)(f(array[bot + 1])))(f(array[bot + 2])) + elseif top - bot == 3 then + return apply(apply(map(array3)(f(array[bot + 1])))(f(array[bot + 2])))(f(array[bot + 3])) + else + -- This slightly tricky pivot selection aims to produce two + -- even-length partitions where possible. + local pivot = bot + math.floor((top - bot) / 4) * 2 + return apply(map(appendArrays)(go(bot, pivot)))(go(pivot, top)) + end end - end - return go(0, #array) + return go(0, #array) + end end end end diff --git a/src/Data/Traversable.purs b/src/Data/Traversable.purs index 180bf2f..c8342b0 100644 --- a/src/Data/Traversable.purs +++ b/src/Data/Traversable.purs @@ -100,7 +100,7 @@ sequenceDefault sequenceDefault = traverse identity instance traversableArray :: Traversable Array where - traverse = traverseArrayImpl apply map pure + traverse = traverseArrayImpl apply map pure (<>) sequence = sequenceDefault foreign import traverseArrayImpl @@ -108,6 +108,7 @@ foreign import traverseArrayImpl . (forall x y. m (x -> y) -> m x -> m y) -> (forall x y. (x -> y) -> m x -> m y) -> (forall x. x -> m x) + -> (forall x. Array x -> Array x -> Array x) -> (a -> m b) -> Array a -> m (Array b)