Skip to content

Commit 3261d08

Browse files
committed
add: reduce by key strategies
1 parent 3fc1b44 commit 3261d08

File tree

5 files changed

+303
-111
lines changed

5 files changed

+303
-111
lines changed

src/GraphBLAS-sharp.Backend/Common/Sum.fs

Lines changed: 32 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -247,8 +247,8 @@ module Reduce =
247247
let gid = ndRange.GlobalID0
248248

249249
if gid = 0 then
250-
let mutable currentKey = keys.[gid]
251-
let mutable segmentResult = values.[gid]
250+
let mutable currentKey = keys.[0]
251+
let mutable segmentResult = values.[0]
252252
let mutable segmentCount = 0
253253

254254
for i in 1 .. length - 1 do
@@ -277,51 +277,39 @@ module Reduce =
277277

278278
let kernel = kernel.GetKernel()
279279

280-
processor.Post(Msg.MsgSetArguments(fun () -> kernel.KernelFunc ndRange resultLength keys values reducedValues reducedKeys))
280+
processor.Post(Msg.MsgSetArguments(fun () -> kernel.KernelFunc ndRange keys.Length keys values reducedValues reducedKeys))
281281

282282
processor.Post(Msg.CreateRunMsg<_, _>(kernel))
283283

284+
reducedKeys, reducedValues
285+
284286
let segmentSequential (clContext: ClContext) workGroupSize (reduceOp: Expr<'a -> 'a -> 'a>) =
285287

286288
let kernel =
287-
<@ fun (ndRange: Range1D) uniqueKeyCount (offsets: ClArray<int>) (keys: ClArray<int>) (values: ClArray<'a>) (reducedValues: ClArray<'a>) (reducedKeys: ClArray<int>) ->
289+
<@ fun (ndRange: Range1D) uniqueKeyCount keysLength (offsets: ClArray<int>) (keys: ClArray<int>) (values: ClArray<'a>) (reducedValues: ClArray<'a>) (reducedKeys: ClArray<int>) ->
288290

289291
let gid = ndRange.GlobalID0
290292

291293
if gid < uniqueKeyCount then
292294
let startPosition = offsets.[gid]
293-
let sourceKey = keys.[startPosition]
294295

295-
let mutable nextPosition = startPosition + 1 // TODO()
296-
let mutable nextKey = keys.[nextPosition]
296+
let sourceKey = keys.[startPosition]
297297
let mutable sum = values.[startPosition]
298298

299-
while nextKey = sourceKey do
300-
sum <- (%reduceOp) sum values.[nextPosition]
299+
let mutable currentPosition = startPosition + 1
300+
301+
while currentPosition < keysLength
302+
&& sourceKey = keys.[currentPosition] do
301303

302-
nextPosition <- nextPosition + 1
303-
nextKey <- keys.[nextPosition]
304+
sum <- (%reduceOp) sum values.[currentPosition]
305+
currentPosition <- currentPosition + 1
304306

305307
reducedValues.[gid] <- sum
306308
reducedKeys.[gid] <- sourceKey @>
307309

308310
let kernel = clContext.Compile kernel
309311

310-
let getUniqueBitmap = ClArray.getUniqueBitmap clContext workGroupSize
311-
312-
let prefixSum = PrefixSum.runExcludeInplace <@ (+) @> clContext workGroupSize
313-
314-
let removeDuplicates = ClArray.removeDuplications clContext workGroupSize
315-
316-
fun (processor: MailboxProcessor<_>) allocationMode (keys: ClArray<int>) (values: ClArray<'a>) ->
317-
318-
let bitmap = getUniqueBitmap processor DeviceOnly keys
319-
320-
let resultLength = (prefixSum processor bitmap 0).ToHostAndFree processor
321-
322-
let offsets = removeDuplicates processor bitmap
323-
324-
bitmap.Free processor
312+
fun (processor: MailboxProcessor<_>) allocationMode (resultLength: int) (offsets: ClArray<int>) (keys: ClArray<int>) (values: ClArray<'a>) ->
325313

326314
let reducedValues = clContext.CreateClArrayWithSpecificAllocationMode(allocationMode, resultLength)
327315

@@ -331,10 +319,12 @@ module Reduce =
331319

332320
let kernel = kernel.GetKernel()
333321

334-
processor.Post(Msg.MsgSetArguments(fun () -> kernel.KernelFunc ndRange resultLength offsets keys values reducedValues reducedKeys))
322+
processor.Post(Msg.MsgSetArguments(fun () -> kernel.KernelFunc ndRange resultLength keys.Length offsets keys values reducedValues reducedKeys))
335323

336324
processor.Post(Msg.CreateRunMsg<_, _>(kernel))
337325

326+
reducedKeys, reducedValues
327+
338328
let oneWorkGroupSegments (clContext: ClContext) workGroupSize (reduceOp: Expr<'a -> 'a -> 'a>) =
339329

340330
let kernel =
@@ -343,40 +333,39 @@ module Reduce =
343333
let lid = ndRange.GlobalID0
344334

345335
// load values to local memory (this step may be unnecessary)
346-
let localValues = localArray<'a> length
336+
let localValues = localArray<'a> workGroupSize
347337
if lid < length then localValues.[lid] <- values.[lid]
348338

349339
// load keys to local memory (this step may be unnecessary)
350-
let localKeys = localArray<int> length
340+
let localKeys = localArray<int> workGroupSize
351341
if lid < length then localKeys.[lid] <- keys.[lid]
352342

353343
// get unique keys bitmap
354-
let localBitmap = localArray<int> length
355-
(%PreparePositions.getUniqueBitmapLocal<int>) localKeys length lid localBitmap
344+
let localBitmap = localArray<int> workGroupSize
345+
localBitmap.[lid] <- 0
346+
(%PreparePositions.getUniqueBitmapLocal<int>) localKeys workGroupSize lid localBitmap
356347

357348
// get positions from bitmap by prefix sum
358349
// ??? get bitmap by prefix sum in another kernel ???
350+
// ??? we can restrict prefix sum for 0 .. length ???
359351
(%SubSum.localIntPrefixSum) lid workGroupSize localBitmap
360-
let localPositions = localBitmap
361352

362-
let uniqueKeysCount = localPositions.[length - 1]
353+
let uniqueKeysCount = localBitmap.[length - 1]
363354

364355
if lid < uniqueKeysCount then
365356
let itemKeyId = lid + 1
366-
// we can count start position by itemKeyId
367-
// but loose coalesced memory read pattern
368357

369358
let startKeyIndex =
370-
(%Search.Bin.lowerPosition) length itemKeyId localPositions
359+
(%Search.Bin.lowerPosition) length itemKeyId localBitmap
371360

372361
match startKeyIndex with
373362
| Some startPosition ->
374-
let sourcePosition = localPositions.[startPosition]
363+
let sourceKeyPosition = localBitmap.[startPosition]
375364
let mutable currentSum = localValues.[startPosition]
376365
let mutable currentIndex = startPosition + 1
377366

378367
while currentIndex < length
379-
&& localPositions.[currentIndex] = sourcePosition do
368+
&& localBitmap.[currentIndex] = sourceKeyPosition do
380369

381370
currentSum <- (%reduceOp) currentSum localValues.[currentIndex]
382371
currentIndex <- currentIndex + 1
@@ -388,6 +377,7 @@ module Reduce =
388377
let kernel = clContext.Compile kernel
389378

390379
fun (processor: MailboxProcessor<_>) allocationMode (resultLength: int) (keys: ClArray<int>) (values: ClArray<'a>) ->
380+
if keys.Length > workGroupSize then failwith "The length of the value should not exceed the size of the workgroup"
391381

392382
let reducedValues = clContext.CreateClArrayWithSpecificAllocationMode(allocationMode, resultLength)
393383

@@ -397,6 +387,9 @@ module Reduce =
397387

398388
let kernel = kernel.GetKernel()
399389

400-
processor.Post(Msg.MsgSetArguments(fun () -> kernel.KernelFunc ndRange resultLength keys values reducedValues reducedKeys))
390+
processor.Post(Msg.MsgSetArguments(fun () -> kernel.KernelFunc ndRange keys.Length keys values reducedValues reducedKeys))
401391

402392
processor.Post(Msg.CreateRunMsg<_, _>(kernel))
393+
394+
reducedKeys, reducedValues
395+

src/GraphBLAS-sharp.Backend/Quotes/PreparePositions.fs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,9 @@ module PreparePositions =
3232
<@ fun (array: 'a []) length lid (result: int []) ->
3333
if lid < length then
3434
let isFirst = lid = 0
35-
let isUnique = lid > 0 && array.[lid] <> array.[lid - 1]
35+
36+
let isNotEqualToPrev = array.[lid] <> array.[lid - 1]
37+
let isUnique = lid > 0 && isNotEqualToPrev
3638

3739
if isFirst || isUnique then result.[lid] <- 1 else result.[lid] <- 0 @>
3840

tests/GraphBLAS-sharp.Tests/Common/Reduce/ReduceByKey.fs

Lines changed: 166 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,29 @@
1-
module GraphBLAS.FSharp.Tests.Backend.Common.Reduce.ReduceByKey
1+
module GraphBLAS.FSharp.Tests.Backend.Common.ReduceByKey
22

3+
open Expecto
4+
open GraphBLAS.FSharp.Backend.Common
35
open GraphBLAS.FSharp.Tests
4-
open Brahma.FSharp
56
open GraphBLAS.FSharp.Backend.Objects.ClContext
7+
open Brahma.FSharp
8+
open GraphBLAS.FSharp.Backend.Objects.ArraysExtensions
69

710
let context = Context.defaultContext.ClContext
811

912
let processor = Context.defaultContext.Queue
1013

11-
let checkResult (arrayAndKeys: (int * 'a) []) =
12-
let keys, values =
13-
Array.sortBy fst arrayAndKeys
14-
|> Array.unzip
14+
let config = Utils.defaultConfig
15+
16+
/// Checks a reduce-by-key result against the host reference implementation.
///
/// The expected keys and values are computed with
/// `HostPrimitives.reduceByKey keys values reduceOp`.
/// Keys are compared with `(=)`; values are compared with the supplied
/// `isEqual`, so callers can pass a tolerant comparison for floating point.
let checkResult isEqual actualKeys actualValues keys values reduceOp =

    let expectedKeys, expectedValues =
        HostPrimitives.reduceByKey keys values reduceOp

    "Keys must be the same"
    |> Utils.compareArrays (=) actualKeys expectedKeys

    "Values must be the same"
    |> Utils.compareArrays isEqual actualValues expectedValues
25+
26+
let makeTest isEqual reduce reduceOp (arrayAndKeys: (int * 'a) []) =
2027
let keys, values =
2128
Array.sortBy fst arrayAndKeys
2229
|> Array.unzip
@@ -28,8 +35,157 @@ let makeTest reduce (arrayAndKeys: (int * 'a) []) =
2835
let clValues =
2936
context.CreateClArrayWithSpecificAllocationMode(DeviceOnly, values)
3037

38+
let resultLength = Array.length <| Array.distinct keys
39+
40+
let clActualKeys, clActualValues: ClArray<int> * ClArray<'a>
41+
= reduce processor HostInterop resultLength clKeys clValues
42+
43+
clValues.Free processor
44+
clKeys.Free processor
45+
46+
let actualValues = clActualValues.ToHostAndFree processor
47+
let actualKeys = clActualKeys.ToHostAndFree processor
48+
49+
checkResult isEqual actualKeys actualValues keys values reduceOp
50+
51+
/// Builds a property test for `Reduce.ByKey.sequential` over element type 'a.
/// `reduceOp` is the host operation used for the reference result,
/// `reduceOpQ` is the quoted operation compiled into the kernel.
let createTestSequential<'a> (isEqual: 'a -> 'a -> bool) reduceOp reduceOpQ =

    let reduce =
        Reduce.ByKey.sequential context Utils.defaultWorkGroupSize reduceOpQ

    let property = makeTest isEqual reduce reduceOp

    testPropertyWithConfig config $"test on {typeof<'a>}" property
58+
59+
/// Test tree for the sequential reduce-by-key strategy:
/// one group for additive operations and one for multiplicative operations.
let sequentialTest =
    // float64 cases are only included when the device reports support for them
    let addTests =
        [ yield createTestSequential<int> (=) (+) <@ (+) @>
          yield createTestSequential<byte> (=) (+) <@ (+) @>

          if Utils.isFloat64Available context.ClDevice then
              yield createTestSequential<float> Utils.floatIsEqual (+) <@ (+) @>

          yield createTestSequential<float32> Utils.float32IsEqual (+) <@ (+) @>
          yield createTestSequential<bool> (=) (||) <@ (||) @> ]
        |> testList "add tests"

    let mulTests =
        [ yield createTestSequential<int> (=) (*) <@ (*) @>
          yield createTestSequential<byte> (=) (*) <@ (*) @>

          if Utils.isFloat64Available context.ClDevice then
              yield createTestSequential<float> Utils.floatIsEqual (*) <@ (*) @>

          yield createTestSequential<float32> Utils.float32IsEqual (*) <@ (*) @>
          yield createTestSequential<bool> (=) (&&) <@ (&&) @> ]
        |> testList "mul tests"

    [ addTests; mulTests ] |> testList "Sequential"
85+
86+
/// Builds a property test for `Reduce.ByKey.oneWorkGroupSegments` over element type 'a.
/// `reduceOp` is the host operation used for the reference result,
/// `reduceOpQ` is the quoted operation compiled into the kernel.
let createTestOneWorkGroup<'a> (isEqual: 'a -> 'a -> bool) reduceOp reduceOpQ =

    // generator end size is tied to the work group size used to build the reducer
    let boundedConfig =
        { config with endSize = Utils.defaultWorkGroupSize }

    let reduce =
        Reduce.ByKey.oneWorkGroupSegments context Utils.defaultWorkGroupSize reduceOpQ

    let property = makeTest isEqual reduce reduceOp

    testPropertyWithConfig boundedConfig $"test on {typeof<'a>}" property
92+
93+
/// Test tree for the one-work-group reduce-by-key strategy:
/// one group for additive operations and one for multiplicative operations.
let oneWorkGroupTest =
    // float64 cases are only included when the device reports support for them
    let addTests =
        [ yield createTestOneWorkGroup<int> (=) (+) <@ (+) @>
          yield createTestOneWorkGroup<byte> (=) (+) <@ (+) @>

          if Utils.isFloat64Available context.ClDevice then
              yield createTestOneWorkGroup<float> Utils.floatIsEqual (+) <@ (+) @>

          yield createTestOneWorkGroup<float32> Utils.float32IsEqual (+) <@ (+) @>
          yield createTestOneWorkGroup<bool> (=) (||) <@ (||) @> ]
        |> testList "add tests"

    let mulTests =
        [ yield createTestOneWorkGroup<int> (=) (*) <@ (*) @>
          yield createTestOneWorkGroup<byte> (=) (*) <@ (*) @>

          if Utils.isFloat64Available context.ClDevice then
              yield createTestOneWorkGroup<float> Utils.floatIsEqual (*) <@ (*) @>

          yield createTestOneWorkGroup<float32> Utils.float32IsEqual (*) <@ (*) @>
          yield createTestOneWorkGroup<bool> (=) (&&) <@ (&&) @> ]
        |> testList "mul tests"

    [ addTests; mulTests ] |> testList "One work group"
119+
120+
/// Property body for reducers that take precomputed segment offsets.
///
/// `valuesAndKeys` holds (key, value) pairs: they are sorted by the first
/// component, which is then used as the key array. Segment offsets are derived
/// on the host from the first-occurrence bitmap of the sorted keys.
/// `reduce` is called as
/// `reduce processor DeviceOnly resultLength clOffsets clKeys clValues`
/// and must return the reduced keys and reduced values.
/// Empty input is skipped (nothing is allocated or checked).
let makeTestSequentialSegments isEqual reduce reduceOp (valuesAndKeys: (int * 'a) []) =

    let valuesAndKeys = Array.sortBy fst valuesAndKeys

    if valuesAndKeys.Length > 0 then
        let offsets =
            Array.map fst valuesAndKeys
            |> HostPrimitives.getUniqueBitmapFirstOccurrence
            |> HostPrimitives.getBitPositions

        // one offset per distinct key, so this is also the result size
        let resultLength = offsets.Length

        let keys, values = Array.unzip valuesAndKeys

        let clOffsets =
            context.CreateClArrayWithSpecificAllocationMode(HostInterop, offsets)

        let clKeys =
            context.CreateClArrayWithSpecificAllocationMode(DeviceOnly, keys)

        let clValues =
            context.CreateClArrayWithSpecificAllocationMode(DeviceOnly, values)

        let clReducedKeys, clReducedValues: ClArray<int> * ClArray<'a> =
            reduce processor DeviceOnly resultLength clOffsets clKeys clValues

        // Release the input buffers, following the same order of operations as
        // makeTest (free inputs after the reduce call, before reading results);
        // otherwise every property iteration leaks three device buffers.
        clOffsets.Free processor
        clKeys.Free processor
        clValues.Free processor

        let reducedKeys = clReducedKeys.ToHostAndFree processor
        let reducedValues = clReducedValues.ToHostAndFree processor

        checkResult isEqual reducedKeys reducedValues keys values reduceOp
149+
150+
151+
/// Builds a property test for `Reduce.ByKey.segmentSequential` over element type 'a.
/// `reduceOp` is the host operation used for the reference result,
/// `reduceOpQ` is the quoted operation compiled into the kernel.
let createTestSequentialSegments<'a> (isEqual: 'a -> 'a -> bool) reduceOp reduceOpQ =

    // this strategy is exercised with a generator start size of 1000
    let largeInputConfig = { config with startSize = 1000 }

    let reduce =
        Reduce.ByKey.segmentSequential context Utils.defaultWorkGroupSize reduceOpQ

    let property =
        makeTestSequentialSegments isEqual reduce reduceOp

    testPropertyWithConfig largeInputConfig $"test on {typeof<'a>}" property
157+
158+
/// Test tree for the segment-sequential reduce-by-key strategy:
/// one group for additive operations and one for multiplicative operations.
let sequentialSegmentTests =
    // float64 cases are only included when the device reports support for them
    let addTests =
        [ yield createTestSequentialSegments<int> (=) (+) <@ (+) @>
          yield createTestSequentialSegments<byte> (=) (+) <@ (+) @>

          if Utils.isFloat64Available context.ClDevice then
              yield createTestSequentialSegments<float> Utils.floatIsEqual (+) <@ (+) @>

          yield createTestSequentialSegments<float32> Utils.float32IsEqual (+) <@ (+) @>
          yield createTestSequentialSegments<bool> (=) (||) <@ (||) @> ]
        |> testList "add tests"

    let mulTests =
        [ yield createTestSequentialSegments<int> (=) (*) <@ (*) @>
          yield createTestSequentialSegments<byte> (=) (*) <@ (*) @>

          if Utils.isFloat64Available context.ClDevice then
              yield createTestSequentialSegments<float> Utils.floatIsEqual (*) <@ (*) @>

          yield createTestSequentialSegments<float32> Utils.float32IsEqual (*) <@ (*) @>
          yield createTestSequentialSegments<bool> (=) (&&) <@ (&&) @> ]
        |> testList "mul tests"

    [ addTests; mulTests ] |> testList "Sequential segments"
184+
185+
186+
187+
188+
31189

32-
reduce processor clKeys
33190

34191

35-
()

0 commit comments

Comments
 (0)