@@ -3,6 +3,9 @@ namespace GraphBLAS.FSharp.Backend.Common
33open Brahma.FSharp
44open FSharp.Quotations
55open GraphBLAS.FSharp .Backend .Quotes
6+ open GraphBLAS.FSharp .Backend .Objects .ArraysExtensions
7+ open GraphBLAS.FSharp .Backend .Objects .ClCell
8+ open GraphBLAS.FSharp .Backend .Objects .ClContext
69
710module PrefixSum =
811 let private update ( opAdd : Expr < 'a -> 'a -> 'a >) ( clContext : ClContext ) workGroupSize =
@@ -38,7 +41,7 @@ module PrefixSum =
3841 )
3942
4043 processor.Post( Msg.CreateRunMsg<_, _> kernel)
41- processor.Post ( Msg.CreateFreeMsg ( mirror))
44+ mirror.Free processor
4245
4346 let private scanGeneral
4447 beforeLocalSumClear
@@ -48,10 +51,8 @@ module PrefixSum =
4851 workGroupSize
4952 =
5053
51- let subSum = SubSum.treeSum opAdd
52-
5354 let scan =
54- <@ fun ( ndRange : Range1D ) inputArrayLength verticesLength ( resultBuffer : ClArray < 'a >) ( verticesBuffer : ClArray < 'a >) ( totalSumBuffer : ClCell < 'a >) ( zero : ClCell < 'a >) ( mirror : ClCell < bool >) ->
55+ <@ fun ( ndRange : Range1D ) inputArrayLength verticesLength ( inputArray : ClArray < 'a >) ( verticesBuffer : ClArray < 'a >) ( totalSumBuffer : ClCell < 'a >) ( zero : ClCell < 'a >) ( mirror : ClCell < bool >) ->
5556
5657 let mirror = mirror.Value
5758
@@ -62,46 +63,34 @@ module PrefixSum =
6263 if mirror then
6364 i <- inputArrayLength - 1 - i
6465
65- let localID = ndRange.LocalID0
66+ let lid = ndRange.LocalID0
6667
6768 let zero = zero.Value
6869
6970 if gid < inputArrayLength then
70- resultLocalBuffer.[ localID ] <- resultBuffer .[ i]
71+ resultLocalBuffer.[ lid ] <- inputArray .[ i]
7172 else
72- resultLocalBuffer.[ localID ] <- zero
73+ resultLocalBuffer.[ lid ] <- zero
7374
7475 barrierLocal ()
7576
76- (% subSum) workGroupSize localID resultLocalBuffer
77-
78- if localID = workGroupSize - 1 then
79- if verticesLength <= 1 && localID = gid then
80- totalSumBuffer.Value <- resultLocalBuffer.[ localID]
81-
82- verticesBuffer.[ gid / workGroupSize] <- resultLocalBuffer.[ localID]
83- (% beforeLocalSumClear) resultBuffer resultLocalBuffer.[ localID] inputArrayLength gid i
84- resultLocalBuffer.[ localID] <- zero
85-
86- let mutable step = workGroupSize
87-
88- while step > 1 do
89- barrierLocal ()
77+ // Local tree reduce
78+ (% SubSum.upSweep opAdd) workGroupSize lid resultLocalBuffer
9079
91- if localID < workGroupSize / step then
92- let i = step * ( localID + 1 ) - 1
93- let j = i - ( step >>> 1 )
80+ if lid = workGroupSize - 1 then
81+ // if last iteration
82+ if verticesLength <= 1 && lid = gid then
83+ totalSumBuffer.Value <- resultLocalBuffer.[ lid]
9484
95- let tmp = resultLocalBuffer.[ i]
96- let buff = (% opAdd) tmp resultLocalBuffer.[ j]
97- resultLocalBuffer.[ i] <- buff
98- resultLocalBuffer.[ j] <- tmp
85+ verticesBuffer.[ gid / workGroupSize] <- resultLocalBuffer.[ lid]
86+ (% beforeLocalSumClear) inputArray resultLocalBuffer.[ lid] inputArrayLength gid i
87+ resultLocalBuffer.[ lid] <- zero
9988
100- step <- step >>> 1
89+ (% SubSum.downSweep opAdd ) workGroupSize lid resultLocalBuffer
10190
10291 barrierLocal ()
10392
104- (% writeData) resultBuffer resultLocalBuffer inputArrayLength workGroupSize gid i localID @>
93+ (% writeData) inputArray resultLocalBuffer inputArrayLength workGroupSize gid i lid @>
10594
10695 let program = clContext.Compile( scan)
10796
@@ -132,13 +121,14 @@ module PrefixSum =
132121 )
133122
134123 processor.Post( Msg.CreateRunMsg<_, _> kernel)
135- processor.Post( Msg.CreateFreeMsg( zero))
136- processor.Post( Msg.CreateFreeMsg( mirror))
124+
125+ zero.Free processor
126+ mirror.Free processor
137127
138128 let private scanExclusive < 'a when 'a : struct > =
139129 scanGeneral
140130 <@ fun ( _ : ClArray < 'a >) ( _ : 'a ) ( _ : int ) ( _ : int ) ( _ : int ) -> () @>
141- <@ fun ( resultBuffer : ClArray < 'a >) ( resultLocalBuffer : 'a []) ( inputArrayLength : int ) ( smth : int ) ( gid : int ) ( i : int ) ( localID : int ) ->
131+ <@ fun ( resultBuffer : ClArray < 'a >) ( resultLocalBuffer : 'a []) ( inputArrayLength : int ) ( _ : int ) ( gid : int ) ( i : int ) ( localID : int ) ->
142132
143133 if gid < inputArrayLength then
144134 resultBuffer.[ i] <- resultLocalBuffer.[ localID] @>
@@ -147,8 +137,7 @@ module PrefixSum =
147137 scanGeneral
148138 <@ fun ( resultBuffer : ClArray < 'a >) ( value : 'a ) ( inputArrayLength : int ) ( gid : int ) ( i : int ) ->
149139
150- if gid < inputArrayLength then
151- resultBuffer.[ i] <- value @>
140+ if gid < inputArrayLength then resultBuffer.[ i] <- value @>
152141 <@ fun ( resultBuffer : ClArray < 'a >) ( resultLocalBuffer : 'a []) ( inputArrayLength : int ) ( workGroupSize : int ) ( gid : int ) ( i : int ) ( localID : int ) ->
153142
154143 if gid < inputArrayLength
@@ -206,8 +195,8 @@ module PrefixSum =
206195 verticesArrays <- swap verticesArrays
207196 verticesLength <- ( verticesLength - 1 ) / workGroupSize + 1
208197
209- processor.Post ( Msg.CreateFreeMsg ( firstVertices))
210- processor.Post ( Msg.CreateFreeMsg ( secondVertices))
198+ firstVertices.Free processor
199+ secondVertices.Free processor
211200
212201 totalSum
213202
@@ -270,3 +259,119 @@ module PrefixSum =
270259 fun ( processor : MailboxProcessor < _ >) ( inputArray : ClArray < int >) ->
271260
272261 scan processor inputArray 0
262+
263+
module ByKey =
    /// Compiles a by-key (segmented) scan kernel meant for input that fits into
    /// a single work group, and returns a function that enqueues it.
    ///
    /// writeZero     - quotation invoked as `writeZero localValues lid zero` for
    ///                 every work item recognised as the root (last item) of its segment.
    /// zero          - value used to pad `localValues` past the end of the input.
    /// uniqueKey     - key used to pad `localKeys` past the end of the input;
    ///                 presumably it must differ from every real key so the last
    ///                 real item is seen as a segment root (TODO confirm with callers).
    /// opAdd         - associative operation handed to the SubSum sweeps.
    /// clContext     - context used to compile the kernel.
    /// workGroupSize - size of the local buffers and of the work group.
    ///
    /// The returned function takes the processor, the keys and the values;
    /// `values` is overwritten in place.
    let private oneWorkGroup
        writeZero
        zero
        uniqueKey
        (opAdd: Expr<'a -> 'a -> 'a>)
        (clContext: ClContext)
        workGroupSize
        =

        let scan =
            <@ fun (ndRange: Range1D) length (values: ClArray<'a>) (keys: ClArray<int>) ->

                let localValues = localArray<'a> workGroupSize
                let localKeys = localArray<int> workGroupSize

                let gid = ndRange.GlobalID0
                let lid = ndRange.LocalID0

                if gid < length then
                    // Only one work group is expected here, so lid and gid coincide.
                    localValues.[lid] <- values.[lid]
                    localKeys.[lid] <- keys.[gid]
                else
                    // Padding items: neutral value and the sentinel key.
                    localValues.[lid] <- zero
                    localKeys.[lid] <- uniqueKey

                barrierLocal ()

                // Local tree reduce
                (%SubSum.upSweepByKey opAdd) workGroupSize lid localValues localKeys

                // Root item of a segment: last item of the group, or the next key differs.
                // Segment boundaries are defined by keys; the values already hold
                // partial sums at this point and must not be compared.
                // NOTE(review): assumes upSweepByKey leaves localKeys unchanged - confirm.
                if lid = workGroupSize - 1
                   || localKeys.[lid] <> localKeys.[lid + 1] then

                    (%writeZero) localValues lid zero

                (%SubSum.downSweepByKey opAdd) workGroupSize lid localValues localKeys

                barrierLocal ()

                // Padding work items have no element in `values` to write back to.
                if gid < length then
                    values.[lid] <- localValues.[lid] @>

        let program = clContext.Compile(scan)

        fun (processor: MailboxProcessor<_>) (keys: ClArray<int>) (values: ClArray<'a>) ->

            let kernel = program.GetKernel()

            let ndRange =
                Range1D.CreateValid(values.Length, workGroupSize)

            processor.Post(
                Msg.MsgSetArguments
                    (fun () ->
                        kernel.KernelFunc
                            ndRange
                            values.Length
                            values
                            keys)
            )

            processor.Post(Msg.CreateRunMsg<_, _> kernel)

    /// Compiles a by-key scan kernel in which each work item walks one whole
    /// segment sequentially, and returns a function that enqueues it.
    ///
    /// opWrite       - quotation invoked as `opWrite previousSum currentSum`; its
    ///                 result is what gets stored for the current element.
    /// clContext     - context used to compile the kernel.
    /// workGroupSize - work group size used to build the ND-range.
    /// opAdd         - operation used to accumulate the running sum.
    /// zero          - initial value of `previousSum` for every segment.
    ///
    /// The returned function takes the processor, the number of unique keys,
    /// the values, the keys and the offsets. `offsets.[k]` is read as the position
    /// of the first element of segment `k`; a segment extends while the key stays
    /// equal to the key at that position. `values` is overwritten in place.
    let sequentialSegments opWrite (clContext: ClContext) workGroupSize opAdd zero =

        let kernel =
            <@ fun (ndRange: Range1D) length uniqueKeysCount (values: ClArray<'a>) (keys: ClArray<int>) (offsets: ClArray<int>) ->
                let gid = ndRange.GlobalID0

                // One work item per segment.
                if gid < uniqueKeysCount then
                    let sourcePosition = offsets.[gid]
                    let sourceKey = keys.[sourcePosition]

                    let mutable currentSum = values.[sourcePosition]
                    let mutable previousSum = zero

                    // Store the result at the element's own position. Writing to
                    // values.[gid] would funnel the whole segment into one slot that
                    // may belong to a segment another work item is still reading.
                    values.[sourcePosition] <- (%opWrite) previousSum currentSum

                    let mutable currentPosition = sourcePosition + 1

                    while currentPosition < length
                          && keys.[currentPosition] = sourceKey do

                        previousSum <- currentSum
                        // The input value is read before its slot is overwritten below.
                        currentSum <- (%opAdd) currentSum values.[currentPosition]

                        values.[currentPosition] <- (%opWrite) previousSum currentSum

                        currentPosition <- currentPosition + 1 @>

        let program = clContext.Compile kernel

        fun (processor: MailboxProcessor<_>) uniqueKeysCount (values: ClArray<'a>) (keys: ClArray<int>) (offsets: ClArray<int>) ->

            let kernel = program.GetKernel()

            let ndRange =
                Range1D.CreateValid(values.Length, workGroupSize)

            processor.Post(
                Msg.MsgSetArguments
                    (fun () ->
                        kernel.KernelFunc
                            ndRange
                            values.Length
                            uniqueKeysCount
                            values
                            keys
                            offsets)
            )

            processor.Post(Msg.CreateRunMsg<_, _> kernel)
0 commit comments