Skip to content

Commit fa392b0

Browse files
committed
refactor: prefix sum
1 parent 7885e8a commit fa392b0

File tree

2 files changed

+213
-40
lines changed

2 files changed

+213
-40
lines changed

src/GraphBLAS-sharp.Backend/Common/PrefixSum.fs

Lines changed: 142 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@ namespace GraphBLAS.FSharp.Backend.Common
33
open Brahma.FSharp
44
open FSharp.Quotations
55
open GraphBLAS.FSharp.Backend.Quotes
6+
open GraphBLAS.FSharp.Backend.Objects.ArraysExtensions
7+
open GraphBLAS.FSharp.Backend.Objects.ClCell
8+
open GraphBLAS.FSharp.Backend.Objects.ClContext
69

710
module PrefixSum =
811
let private update (opAdd: Expr<'a -> 'a -> 'a>) (clContext: ClContext) workGroupSize =
@@ -38,7 +41,7 @@ module PrefixSum =
3841
)
3942

4043
processor.Post(Msg.CreateRunMsg<_, _> kernel)
41-
processor.Post(Msg.CreateFreeMsg(mirror))
44+
mirror.Free processor
4245

4346
let private scanGeneral
4447
beforeLocalSumClear
@@ -48,10 +51,8 @@ module PrefixSum =
4851
workGroupSize
4952
=
5053

51-
let subSum = SubSum.treeSum opAdd
52-
5354
let scan =
54-
<@ fun (ndRange: Range1D) inputArrayLength verticesLength (resultBuffer: ClArray<'a>) (verticesBuffer: ClArray<'a>) (totalSumBuffer: ClCell<'a>) (zero: ClCell<'a>) (mirror: ClCell<bool>) ->
55+
<@ fun (ndRange: Range1D) inputArrayLength verticesLength (inputArray: ClArray<'a>) (verticesBuffer: ClArray<'a>) (totalSumBuffer: ClCell<'a>) (zero: ClCell<'a>) (mirror: ClCell<bool>) ->
5556

5657
let mirror = mirror.Value
5758

@@ -62,46 +63,34 @@ module PrefixSum =
6263
if mirror then
6364
i <- inputArrayLength - 1 - i
6465

65-
let localID = ndRange.LocalID0
66+
let lid = ndRange.LocalID0
6667

6768
let zero = zero.Value
6869

6970
if gid < inputArrayLength then
70-
resultLocalBuffer.[localID] <- resultBuffer.[i]
71+
resultLocalBuffer.[lid] <- inputArray.[i]
7172
else
72-
resultLocalBuffer.[localID] <- zero
73+
resultLocalBuffer.[lid] <- zero
7374

7475
barrierLocal ()
7576

76-
(%subSum) workGroupSize localID resultLocalBuffer
77-
78-
if localID = workGroupSize - 1 then
79-
if verticesLength <= 1 && localID = gid then
80-
totalSumBuffer.Value <- resultLocalBuffer.[localID]
81-
82-
verticesBuffer.[gid / workGroupSize] <- resultLocalBuffer.[localID]
83-
(%beforeLocalSumClear) resultBuffer resultLocalBuffer.[localID] inputArrayLength gid i
84-
resultLocalBuffer.[localID] <- zero
85-
86-
let mutable step = workGroupSize
87-
88-
while step > 1 do
89-
barrierLocal ()
77+
// Local tree reduce
78+
(%SubSum.upSweep opAdd) workGroupSize lid resultLocalBuffer
9079

91-
if localID < workGroupSize / step then
92-
let i = step * (localID + 1) - 1
93-
let j = i - (step >>> 1)
80+
if lid = workGroupSize - 1 then
81+
// if last iteration
82+
if verticesLength <= 1 && lid = gid then
83+
totalSumBuffer.Value <- resultLocalBuffer.[lid]
9484

95-
let tmp = resultLocalBuffer.[i]
96-
let buff = (%opAdd) tmp resultLocalBuffer.[j]
97-
resultLocalBuffer.[i] <- buff
98-
resultLocalBuffer.[j] <- tmp
85+
verticesBuffer.[gid / workGroupSize] <- resultLocalBuffer.[lid]
86+
(%beforeLocalSumClear) inputArray resultLocalBuffer.[lid] inputArrayLength gid i
87+
resultLocalBuffer.[lid] <- zero
9988

100-
step <- step >>> 1
89+
(%SubSum.downSweep opAdd) workGroupSize lid resultLocalBuffer
10190

10291
barrierLocal ()
10392

104-
(%writeData) resultBuffer resultLocalBuffer inputArrayLength workGroupSize gid i localID @>
93+
(%writeData) inputArray resultLocalBuffer inputArrayLength workGroupSize gid i lid @>
10594

10695
let program = clContext.Compile(scan)
10796

@@ -132,13 +121,14 @@ module PrefixSum =
132121
)
133122

134123
processor.Post(Msg.CreateRunMsg<_, _> kernel)
135-
processor.Post(Msg.CreateFreeMsg(zero))
136-
processor.Post(Msg.CreateFreeMsg(mirror))
124+
125+
zero.Free processor
126+
mirror.Free processor
137127

138128
let private scanExclusive<'a when 'a: struct> =
139129
scanGeneral
140130
<@ fun (_: ClArray<'a>) (_: 'a) (_: int) (_: int) (_: int) -> () @>
141-
<@ fun (resultBuffer: ClArray<'a>) (resultLocalBuffer: 'a []) (inputArrayLength: int) (smth: int) (gid: int) (i: int) (localID: int) ->
131+
<@ fun (resultBuffer: ClArray<'a>) (resultLocalBuffer: 'a []) (inputArrayLength: int) (_: int) (gid: int) (i: int) (localID: int) ->
142132

143133
if gid < inputArrayLength then
144134
resultBuffer.[i] <- resultLocalBuffer.[localID] @>
@@ -147,8 +137,7 @@ module PrefixSum =
147137
scanGeneral
148138
<@ fun (resultBuffer: ClArray<'a>) (value: 'a) (inputArrayLength: int) (gid: int) (i: int) ->
149139

150-
if gid < inputArrayLength then
151-
resultBuffer.[i] <- value @>
140+
if gid < inputArrayLength then resultBuffer.[i] <- value @>
152141
<@ fun (resultBuffer: ClArray<'a>) (resultLocalBuffer: 'a []) (inputArrayLength: int) (workGroupSize: int) (gid: int) (i: int) (localID: int) ->
153142

154143
if gid < inputArrayLength
@@ -206,8 +195,8 @@ module PrefixSum =
206195
verticesArrays <- swap verticesArrays
207196
verticesLength <- (verticesLength - 1) / workGroupSize + 1
208197

209-
processor.Post(Msg.CreateFreeMsg(firstVertices))
210-
processor.Post(Msg.CreateFreeMsg(secondVertices))
198+
firstVertices.Free processor
199+
secondVertices.Free processor
211200

212201
totalSum
213202

@@ -270,3 +259,119 @@ module PrefixSum =
270259
fun (processor: MailboxProcessor<_>) (inputArray: ClArray<int>) ->
271260

272261
scan processor inputArray 0
262+
263+
264+
module ByKey =
265+
/// Compiles a by-key scan kernel and returns a function that enqueues it.
///
/// Parameters:
///   writeZero     - quotation spliced as `(%writeZero) localValues lid zero` on the
///                   items selected as segment roots (see below).
///   zero          - value stored in the local value buffer for work items past `length`.
///   uniqueKey     - key stored in the local key buffer for work items past `length`;
///                   presumably a key that never occurs in the input — confirm with callers.
///   opAdd         - binary operation forwarded to the by-key up/down sweeps.
///   clContext     - context used to compile the kernel.
///   workGroupSize - size of the local buffers and of the work group.
///
/// Returns: `fun processor keys values -> unit`, which posts the argument-setting
/// and run messages; `values` is updated in place.
///
/// NOTE(review): the local buffers hold `workGroupSize` items and nothing combines
/// results across work groups, so this looks intended for inputs that fit into a
/// single work group — confirm against callers.
let private oneWorkGroup
    writeZero
    zero
    uniqueKey
    (opAdd: Expr<'a -> 'a -> 'a>)
    (clContext: ClContext)
    workGroupSize
    =

    let scan =
        <@ fun (ndRange: Range1D) length (values: ClArray<'a>) (keys: ClArray<int>) ->

            let localValues = localArray<'a> workGroupSize
            let localKeys = localArray<int> workGroupSize

            let gid = ndRange.GlobalID0
            let lid = ndRange.LocalID0

            // Stage the input in local memory; pad the tail of the work group.
            // With a single work group gid = lid, so values and keys are both
            // addressed by gid to keep the two loads consistent.
            if gid < length then
                localValues.[lid] <- values.[gid]
                localKeys.[lid] <- keys.[gid]
            else
                localValues.[lid] <- zero
                localKeys.[lid] <- uniqueKey

            barrierLocal ()

            // Local tree reduce
            (%SubSum.upSweepByKey opAdd) workGroupSize lid localValues localKeys

            // Root item: the last item of the work group, or an item whose right
            // neighbour carries a different key. The comparison is on keys (the
            // sweeps above/below are keyed the same way); `||` keeps the
            // `lid + 1` read from running for the last item.
            if lid = workGroupSize - 1
               || localKeys.[lid] <> localKeys.[lid + 1] then

                (%writeZero) localValues lid zero

            (%SubSum.downSweepByKey opAdd) workGroupSize lid localValues localKeys

            barrierLocal ()

            // The ND-range is rounded up to a multiple of the work group size,
            // so padded work items must not store past the end of `values`.
            if gid < length then
                values.[gid] <- localValues.[lid] @>

    let program = clContext.Compile(scan)

    fun (processor: MailboxProcessor<_>) (keys: ClArray<int>) (values: ClArray<'a>) ->

        let kernel = program.GetKernel()

        let ndRange =
            Range1D.CreateValid(values.Length, workGroupSize)

        processor.Post(
            Msg.MsgSetArguments
                (fun () ->
                    kernel.KernelFunc
                        ndRange
                        values.Length
                        values
                        keys)
        )

        processor.Post(Msg.CreateRunMsg<_, _> kernel)
328+
329+
/// Compiles a kernel that scans each run of equal keys sequentially, one work
/// item per run, and returns a function that enqueues it.
///
/// Parameters:
///   opWrite       - quotation called as `(%opWrite) previousSum currentSum`; its
///                   result is what gets stored for the current element, so it
///                   selects which of the two running sums is written.
///   clContext     - context used to compile the kernel.
///   workGroupSize - work group size used to build the ND-range.
///   opAdd         - binary operation used to accumulate the running sum.
///   zero          - initial value of `previousSum` for every run.
///
/// Returns: `fun processor uniqueKeysCount values keys offsets -> unit`;
/// `values` is updated in place.
///
/// `offsets.[gid]` is used as the position of the first element of the gid-th
/// run; the run extends while the key stays equal to the key at that position.
let sequentialSegments opWrite (clContext: ClContext) workGroupSize opAdd zero =

    let kernel =
        <@ fun (ndRange: Range1D) length uniqueKeysCount (values: ClArray<'a>) (keys: ClArray<int>) (offsets: ClArray<int>) ->
            let gid = ndRange.GlobalID0

            if gid < uniqueKeysCount then
                let sourcePosition = offsets.[gid]
                let sourceKey = keys.[sourcePosition]

                let mutable currentSum = values.[sourcePosition]
                let mutable previousSum = zero

                // Store at the element being scanned. Storing at `gid` (the run
                // index) would overwrite an unrelated element that another work
                // item may still need to read.
                values.[sourcePosition] <- (%opWrite) previousSum currentSum

                let mutable currentPosition = sourcePosition + 1

                // The bounds check comes first so the key read stays in range.
                while currentPosition < length
                      && keys.[currentPosition] = sourceKey do

                    previousSum <- currentSum
                    // Read the input value before it is overwritten below.
                    currentSum <- (%opAdd) currentSum values.[currentPosition]

                    values.[currentPosition] <- (%opWrite) previousSum currentSum

                    currentPosition <- currentPosition + 1 @>

    let kernel = clContext.Compile kernel

    fun (processor: MailboxProcessor<_>) uniqueKeysCount (values: ClArray<'a>) (keys: ClArray<int>) (offsets: ClArray<int>) ->

        let kernel = kernel.GetKernel()

        let ndRange =
            Range1D.CreateValid(values.Length, workGroupSize)

        processor.Post(
            Msg.MsgSetArguments
                (fun () ->
                    kernel.KernelFunc
                        ndRange
                        values.Length
                        uniqueKeysCount
                        values
                        keys
                        offsets)
        )

        processor.Post(Msg.CreateRunMsg<_, _> kernel)

src/GraphBLAS-sharp.Backend/Quotes/SubSum.fs

Lines changed: 71 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,76 @@ module SubSum =
3131

3232
barrierLocal () @>
3333

34-
let sequentialSum<'a> opAdd =
35-
sumGeneral<'a> <| sequentialAccess<'a> opAdd
34+
let sequentialSum<'a> = sumGeneral<'a> << sequentialAccess<'a>
3635

37-
let treeSum<'a> opAdd = sumGeneral<'a> <| treeAccess<'a> opAdd
36+
let upSweep<'a> = sumGeneral<'a> << treeAccess<'a>
37+
38+
/// Down-sweep pass over a work-group-local buffer.
///
/// Produces a quotation `fun wgSize lid localBuffer -> unit`. The stride starts
/// at `wgSize` and is halved until it reaches 1. On every level the work items
/// with `lid < wgSize / stride` take the element at the right end of their
/// stride-wide window, replace it with `opAdd` of itself and the element half a
/// stride to the left, and move the old right value into that left slot.
/// A local barrier precedes every level.
let downSweep opAdd =
    <@ fun wgSize lid (localBuffer: 'a []) ->
        let mutable stride = wgSize

        while stride > 1 do
            barrierLocal ()

            if lid < wgSize / stride then
                let right = stride * (lid + 1) - 1
                let left = right - (stride >>> 1)

                // Keep the old right value: it is both an operand of the
                // combination and the new content of the left slot.
                let carried = localBuffer.[right]

                localBuffer.[right] <- (%opAdd) carried localBuffer.[left]
                localBuffer.[left] <- carried

            stride <- stride >>> 1 @>
55+
56+
/// Up-sweep pass over a work-group-local buffer that only combines a pair of
/// elements when their keys are equal.
///
/// Produces a quotation `fun wgSize lid localBuffer localKeys -> unit`. The step
/// starts at 2 and doubles while it does not exceed `wgSize`. On every level the
/// work items with `lid < wgSize / step` look at the element at the right end of
/// their step-wide window (`secondIndex`) and the one half a step to its left
/// (`firstIndex`); if both carry the same key, the right element becomes
/// `opAdd first second`. A local barrier follows every level.
///
/// NOTE(review): the first-index computation carried an unexplained `TODO()`
/// marker in the original; its intent is not visible here.
let upSweepByKey opAdd =
    <@ fun wgSize lid (localBuffer: 'a []) (localKeys: 'b []) ->
        let mutable step = 2

        while step <= wgSize do
            // The guard comes before any indexing: for inactive work items
            // `step * (lid + 1) - 1` is >= wgSize and must not be used to read
            // the local arrays.
            if lid < wgSize / step then
                let secondIndex = step * (lid + 1) - 1
                let firstIndex = secondIndex - (step >>> 1)

                if localKeys.[firstIndex] = localKeys.[secondIndex] then
                    let firstValue = localBuffer.[firstIndex]
                    let secondValue = localBuffer.[secondIndex]

                    localBuffer.[secondIndex] <- (%opAdd) firstValue secondValue

            step <- step <<< 1

            // Reached by every work item on every level, active or not.
            barrierLocal () @>
80+
81+
/// Down-sweep pass over a work-group-local buffer that only propagates between
/// a pair of elements when their keys are equal.
///
/// Produces a quotation `fun wgSize lid localBuffer localKeys -> unit`. The step
/// starts at `wgSize` and is halved until it reaches 1. On every level the work
/// items with `lid < wgSize / step` look at the element at the right end of
/// their step-wide window (`rightIndex`) and the one half a step to its left
/// (`leftIndex`); if both carry the same key, the right element becomes
/// `opAdd right left` and the left slot receives the old right value.
/// A local barrier precedes every level.
let downSweepByKey opAdd =
    <@ fun wgSize lid (localBuffer: 'a []) (localKeys: int []) ->
        let mutable step = wgSize

        while step > 1 do
            // Reached by every work item on every level, active or not.
            barrierLocal ()

            // The guard comes before any indexing: for inactive work items
            // `step * (lid + 1) - 1` is >= wgSize and must not be used to read
            // the local arrays.
            if lid < wgSize / step then
                let rightIndex = step * (lid + 1) - 1
                let leftIndex = rightIndex - (step >>> 1)

                if localKeys.[rightIndex] = localKeys.[leftIndex] then
                    let tmp = localBuffer.[rightIndex]
                    let buff = (%opAdd) tmp localBuffer.[leftIndex]

                    localBuffer.[rightIndex] <- buff
                    localBuffer.[leftIndex] <- tmp

            step <- step >>> 1 @>
38104

39105
let localPrefixSum opAdd =
40106
<@ fun (lid: int) (workGroupSize: int) (array: 'a []) ->
@@ -52,4 +118,6 @@ module SubSum =
52118
barrierLocal ()
53119
array.[lid] <- value @>
54120

121+
122+
55123
let localIntPrefixSum = localPrefixSum <@ (+) @>

0 commit comments

Comments
 (0)