Skip to content

Commit bc7ebdf

Browse files
committed
add: segment sequential scan
1 parent fa392b0 commit bc7ebdf

File tree

8 files changed

+296
-104
lines changed

8 files changed

+296
-104
lines changed

src/GraphBLAS-sharp.Backend/Common/PrefixSum.fs

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ open FSharp.Quotations
55
open GraphBLAS.FSharp.Backend.Quotes
66
open GraphBLAS.FSharp.Backend.Objects.ArraysExtensions
77
open GraphBLAS.FSharp.Backend.Objects.ClCell
8-
open GraphBLAS.FSharp.Backend.Objects.ClContext
98

109
module PrefixSum =
1110
let private update (opAdd: Expr<'a -> 'a -> 'a>) (clContext: ClContext) workGroupSize =
@@ -326,7 +325,11 @@ module PrefixSum =
326325

327326
processor.Post(Msg.CreateRunMsg<_, _> kernel)
328327

329-
let sequentialSegments opWrite (clContext: ClContext) workGroupSize opAdd zero =
328+
/// Specialises `oneWorkGroup` with a three-argument hook quotation that does nothing, plus the given `zero`.
let oneWorkGroupExclude zero = zero |> oneWorkGroup <@ fun _ _ _ -> () @>
329+
330+
/// Specialises `oneWorkGroup` with a hook quotation that stores its third argument
/// into `localValues.[lid]`, plus the given `zero`.
let oneWorkGroupInclude zero = oneWorkGroup <@ (fun localValues lid zero -> localValues.[lid] <- zero) @> zero

/// Misspelled original name, kept as a forwarding alias so existing callers still compile.
/// Prefer `oneWorkGroupInclude`.
let onwWorkGroupInclude zero = oneWorkGroupInclude zero
331+
332+
let private sequentialSegments opWrite (clContext: ClContext) workGroupSize opAdd zero =
330333

331334
let kernel =
332335
<@ fun (ndRange: Range1D) lenght uniqueKeysCount (values: ClArray<'a>) (keys: ClArray<int>) (offsets: ClArray<int>) ->
@@ -336,20 +339,18 @@ module PrefixSum =
336339
let sourcePosition = offsets.[gid]
337340
let sourceKey = keys.[sourcePosition]
338341

339-
let mutable currentSum = values.[sourcePosition]
342+
let mutable currentSum = zero
340343
let mutable previousSum = zero
341344

342-
values.[gid] <- (%opWrite) previousSum currentSum
343-
344-
let mutable currentPosition = sourcePosition + 1
345+
let mutable currentPosition = sourcePosition
345346

346347
while currentPosition < lenght
347348
&& keys.[currentPosition] = sourceKey do
348349

349350
previousSum <- currentSum
350351
currentSum <- (%opAdd) currentSum values.[currentPosition]
351352

352-
values.[gid] <- (%opWrite) previousSum currentSum
353+
values.[currentPosition] <- (%opWrite) previousSum currentSum
353354

354355
currentPosition <- currentPosition + 1 @>
355356

@@ -375,3 +376,10 @@ module PrefixSum =
375376
)
376377

377378
processor.Post(Msg.CreateRunMsg<_, _> kernel)
379+
380+
381+
/// Sequential per-segment scan that writes the first of the two sums handed to the
/// write selector (`Map.fst`), i.e. the running sum before the current element is added.
let sequentialExclude clContext = clContext |> sequentialSegments (Map.fst ())
382+
383+
/// Sequential per-segment scan that writes the second of the two sums handed to the
/// write selector (`Map.snd`), i.e. the running sum after the current element is added.
let sequentialInclude clContext = clContext |> sequentialSegments (Map.snd ())
384+
385+

src/GraphBLAS-sharp.Backend/Quotes/Map.fs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,3 +21,7 @@ module Map =
2121
match (%map) item with
2222
| Some _ -> 1
2323
| None -> 0 @>
24+
25+
/// Quotation of a two-argument function that returns its first argument and ignores the second.
let fst () = <@ fun first _ -> first @>
26+
27+
/// Quotation of a two-argument function that ignores the first argument and returns the second.
let snd () = <@ fun _ second -> second @>

src/GraphBLAS-sharp.Backend/Quotes/SubSum.fs

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,10 @@ module SubSum =
4747
let j = i - (step >>> 1)
4848

4949
let tmp = localBuffer.[i]
50-
let buff = (%opAdd) tmp localBuffer.[j]
50+
51+
let operand = localBuffer.[j] // brahma error
52+
let buff = (%opAdd) tmp operand
53+
5154
localBuffer.[i] <- buff
5255
localBuffer.[j] <- tmp
5356

@@ -60,7 +63,7 @@ module SubSum =
6063
while step <= wgSize do
6164
let i = step * (lid + 1) - 1
6265

63-
let firstIndex = i - (step >>> 1) // TODO()
66+
let firstIndex = i - (step >>> 1) // TODO(work ?)
6467
let secondIndex = i
6568

6669
let firstKey = localKeys.[firstIndex]
@@ -95,7 +98,9 @@ module SubSum =
9598
&& rightKey = leftKey then
9699

97100
let tmp = localBuffer.[rightIndex]
98-
let buff = (%opAdd) tmp localBuffer.[leftIndex]
101+
102+
let rightOperand = localBuffer.[leftIndex] // Brahma error
103+
let buff = (%opAdd) tmp rightOperand
99104

100105
localBuffer.[rightIndex] <- buff
101106
localBuffer.[leftIndex] <- tmp
Lines changed: 174 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,174 @@
1+
module GraphBLAS.FSharp.Tests.Backend.Common.Scan.ByKey
2+
3+
open GraphBLAS.FSharp.Backend.Common
4+
open GraphBLAS.FSharp.Backend.Objects.ClContext
5+
open Expecto
6+
open GraphBLAS.FSharp.Tests
7+
open GraphBLAS.FSharp.Backend.Objects.ArraysExtensions
8+
9+
let context = Context.defaultContext.ClContext
10+
11+
let processor = Context.defaultContext.Queue
12+
13+
let scanByKey scan keysAndValues =
14+
// select keys
15+
Array.map fst keysAndValues
16+
// get unique keys
17+
|> Array.distinct
18+
|> Array.map (fun key ->
19+
// select with certain key
20+
Array.filter (fst >> ((=) key)) keysAndValues
21+
// get values
22+
|> Array.map snd
23+
// scan values and get only values without sum
24+
|> (fst << scan))
25+
|> Array.concat
26+
27+
let checkResult isEqual keysAndValues actual hostScan =
28+
29+
let expected = scanByKey hostScan keysAndValues
30+
31+
let keys, values = Array.unzip keysAndValues
32+
printfn "---------------"
33+
34+
printfn "keys: %A" keys
35+
printfn "values: %A" values
36+
printfn $"expected: %A{expected}"
37+
38+
printfn "-----------"
39+
40+
"Results must be the same"
41+
|> Utils.compareArrays isEqual actual expected
42+
43+
let makeTestSequentialSegments isEqual scanHost scanDevice (keysAndValues: (int * 'a) []) =
44+
if keysAndValues.Length > 0 then
45+
let keys, values =
46+
Array.sortBy fst keysAndValues
47+
|> Array.unzip
48+
49+
let offsets =
50+
HostPrimitives.getUniqueBitmapFirstOccurrence keys
51+
|> HostPrimitives.getBitPositions
52+
53+
let uniqueKeysCount = Array.distinct keys |> Array.length
54+
55+
let clKeys = context.CreateClArrayWithSpecificAllocationMode(HostInterop, keys)
56+
57+
let clValues = context.CreateClArrayWithSpecificAllocationMode(HostInterop, values)
58+
59+
let clOffsets = context.CreateClArrayWithSpecificAllocationMode(HostInterop, offsets)
60+
61+
scanDevice processor uniqueKeysCount clValues clKeys clOffsets
62+
63+
let actual = clValues.ToHostAndFree processor
64+
clKeys.Free processor
65+
clOffsets.Free processor
66+
67+
let keysAndValues = Array.zip keys values
68+
69+
checkResult isEqual keysAndValues actual scanHost
70+
71+
let createTest (zero: 'a) opAddQ opAdd isEqual deviceScan hostScan =
72+
73+
let hostScan = hostScan zero opAdd
74+
75+
let deviceScan =
76+
deviceScan context Utils.defaultWorkGroupSize opAddQ zero
77+
78+
makeTestSequentialSegments isEqual hostScan deviceScan
79+
|> testPropertyWithConfig Utils.defaultConfig $"test on {typeof<'a>}"
80+
81+
let sequentialSegmentsTests =
82+
let excludeTests =
83+
[ createTest 0 <@ (+) @> (+) (=) PrefixSum.ByKey.sequentialExclude HostPrimitives.prefixSumExclude
84+
85+
if Utils.isFloat64Available context.ClDevice then
86+
createTest 0.0 <@ (+) @> (+) Utils.floatIsEqual PrefixSum.ByKey.sequentialExclude HostPrimitives.prefixSumExclude
87+
88+
createTest 0.0f <@ (+) @> (+) Utils.float32IsEqual PrefixSum.ByKey.sequentialExclude HostPrimitives.prefixSumExclude
89+
90+
createTest false <@ (||) @> (||) (=) PrefixSum.ByKey.sequentialExclude HostPrimitives.prefixSumExclude
91+
createTest 0u <@ (+) @> (+) (=) PrefixSum.ByKey.sequentialExclude HostPrimitives.prefixSumExclude ]
92+
|> testList "exclude"
93+
94+
let includeTests =
95+
[ createTest 0 <@ (+) @> (+) (=) PrefixSum.ByKey.sequentialInclude HostPrimitives.prefixSumInclude
96+
97+
if Utils.isFloat64Available context.ClDevice then
98+
createTest 0.0 <@ (+) @> (+) Utils.floatIsEqual PrefixSum.ByKey.sequentialInclude HostPrimitives.prefixSumInclude
99+
100+
createTest 0.0f <@ (+) @> (+) Utils.float32IsEqual PrefixSum.ByKey.sequentialInclude HostPrimitives.prefixSumInclude
101+
102+
createTest false <@ (||) @> (||) (=) PrefixSum.ByKey.sequentialInclude HostPrimitives.prefixSumInclude
103+
createTest 0u <@ (+) @> (+) (=) PrefixSum.ByKey.sequentialInclude HostPrimitives.prefixSumInclude ]
104+
105+
|> testList "include"
106+
107+
testList "Sequential segments" [ excludeTests; includeTests ]
108+
109+
let makeTestOneWorkGroup isEqual scanHost scanDevice (keysAndValues: (int * 'a) []) =
110+
if keysAndValues.Length > 0 then
111+
let keys, values =
112+
Array.sortBy fst keysAndValues
113+
|> Array.unzip
114+
115+
let uniqueKeysCount = Array.distinct keys |> Array.length
116+
117+
let clKeys = context.CreateClArrayWithSpecificAllocationMode(HostInterop, keys)
118+
119+
let clValues = context.CreateClArrayWithSpecificAllocationMode(HostInterop, values)
120+
121+
scanDevice processor uniqueKeysCount clValues clKeys
122+
123+
let actual = clValues.ToHostAndFree processor
124+
clKeys.Free processor
125+
126+
let keysAndValues = Array.zip keys values
127+
128+
checkResult isEqual keysAndValues actual scanHost
129+
130+
let oneWorkGroupCreateTest (zero: 'a) opAddQ opAdd isEqual deviceScan hostScan =
131+
132+
let workGroupSize = 256
133+
134+
let hostScan = hostScan zero opAdd
135+
136+
let deviceScan =
137+
deviceScan context workGroupSize opAddQ zero
138+
139+
makeTestSequentialSegments isEqual hostScan deviceScan
140+
|> testPropertyWithConfig { Utils.defaultConfig with endSize = workGroupSize } $"test on {typeof<'a>}"
141+
142+
let oneWorkGroupTests =
143+
let excludeTests =
144+
[ createTest 0 <@ (+) @> (+) (=) PrefixSum.ByKey.sequentialExclude HostPrimitives.prefixSumExclude
145+
146+
if Utils.isFloat64Available context.ClDevice then
147+
createTest 0.0 <@ (+) @> (+) Utils.floatIsEqual PrefixSum.ByKey.sequentialExclude HostPrimitives.prefixSumExclude
148+
149+
createTest 0.0f <@ (+) @> (+) Utils.float32IsEqual PrefixSum.ByKey.sequentialExclude HostPrimitives.prefixSumExclude
150+
151+
createTest false <@ (||) @> (||) (=) PrefixSum.ByKey.sequentialExclude HostPrimitives.prefixSumExclude
152+
createTest 0u <@ (+) @> (+) (=) PrefixSum.ByKey.sequentialExclude HostPrimitives.prefixSumExclude ]
153+
|> testList "exclude"
154+
155+
let includeTests =
156+
[ createTest 0 <@ (+) @> (+) (=) PrefixSum.ByKey.sequentialInclude HostPrimitives.prefixSumInclude
157+
158+
if Utils.isFloat64Available context.ClDevice then
159+
createTest 0.0 <@ (+) @> (+) Utils.floatIsEqual PrefixSum.ByKey.sequentialInclude HostPrimitives.prefixSumInclude
160+
161+
createTest 0.0f <@ (+) @> (+) Utils.float32IsEqual PrefixSum.ByKey.sequentialInclude HostPrimitives.prefixSumInclude
162+
163+
createTest false <@ (||) @> (||) (=) PrefixSum.ByKey.sequentialInclude HostPrimitives.prefixSumInclude
164+
createTest 0u <@ (+) @> (+) (=) PrefixSum.ByKey.sequentialInclude HostPrimitives.prefixSumInclude ]
165+
166+
|> testList "include"
167+
168+
testList "Sequential segments" [ excludeTests; includeTests ]
169+
170+
171+
172+
173+
174+

tests/GraphBLAS-sharp.Tests/Common/ClArray/PrefixSum.fs renamed to tests/GraphBLAS-sharp.Tests/Common/Scan/PrefixSum.fs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
module GraphBLAS.FSharp.Tests.Backend.Common.ClArray.PrefixSum
1+
module GraphBLAS.FSharp.Tests.Backend.Common.Scan.PrefixSum
22

33
open Expecto
44
open Expecto.Logging
@@ -62,7 +62,7 @@ let makeTest plus zero isEqual scan (array: 'a []) =
6262
/// Builds the property test "Correctness on <name>": the in-place inclusive device scan
/// for the quoted operation `plusQ` is checked by `makeTest` using the host operation
/// `plus`, `zero` and `isEqual`.
let testFixtures plus plusQ zero isEqual name =
    let scan =
        PrefixSum.runIncludeInplace plusQ context wgSize

    let check = makeTest plus zero isEqual scan

    testPropertyWithConfig config $"Correctness on %s{name}" check
6666

6767
let tests =
6868
q.Error.Add(fun e -> failwithf "%A" e)

tests/GraphBLAS-sharp.Tests/GraphBLAS-sharp.Tests.fsproj

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,13 @@
2424
<Compile Include="Common/ClArray/RemoveDuplicates.fs" />
2525
<Compile Include="Common/ClArray/Copy.fs" />
2626
<Compile Include="Common/ClArray/Replicate.fs" />
27-
<Compile Include="Common/ClArray/PrefixSum.fs" />
2827
<Compile Include="Common/Sort/Bitonic.fs" />
2928
<Compile Include="Common/Sort/Radix.fs" />
3029
<Compile Include="Common/Reduce/Sum.fs" />
3130
<Compile Include="Common/Reduce/Reduce.fs" />
3231
<Compile Include="Common/Reduce/ReduceByKey.fs" />
32+
<Compile Include="Common\Scan\PrefixSum.fs" />
33+
<Compile Include="Common\Scan\ByKey.fs" />
3334
<!--Compile Include="MatrixOperationsTests/GetTuplesTests.fs" /-->
3435
<!--Compile Include="MatrixOperationsTests/MxvTests.fs" /-->
3536
<!--Compile Include="MatrixOperationsTests/VxmTests.fs" /-->

tests/GraphBLAS-sharp.Tests/Helpers.fs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -141,13 +141,13 @@ module Utils =
141141
result
142142

143143
module HostPrimitives =
144-
let prefixSumInclude array =
145-
Array.scan (+) 0 array
146-
|> fun scanned -> scanned.[1..]
144+
/// Host inclusive prefix scan.
/// Returns the running results of folding `add` over `array` starting from `zero`
/// (one per input element, the initial `zero` itself excluded) together with the
/// final accumulated value. An empty input gives `([||], zero)`.
let prefixSumInclude zero add array =
    let scanned = Array.scan add zero array
    // `scanned` always starts with `zero`, so it is never empty and dropping the head is safe.
    Array.tail scanned, Array.last scanned
147147

148-
let prefixSumExclude sourceArray =
149-
prefixSumInclude sourceArray
150-
|> Array.insertAt 0 0
148+
/// Host exclusive prefix scan.
/// Takes the inclusive scan, shifts it right by inserting `zero` at the front, and
/// returns the first `sourceArray.Length` elements together with the last element of
/// the shifted array (the total over the whole input).
let prefixSumExclude zero add sourceArray =
    let inclusive, _ = prefixSumInclude zero add sourceArray
    let shifted = Array.insertAt 0 zero inclusive

    Array.take sourceArray.Length shifted, Array.last shifted
152152

153153
let getUniqueBitmapLastOccurrence array =

0 commit comments

Comments
 (0)