Skip to content

Commit 95941af

Browse files
committed
wip: tests and msbfs bug fix
1 parent e78e50e commit 95941af

File tree

13 files changed

+491
-135
lines changed

13 files changed

+491
-135
lines changed

src/GraphBLAS-sharp.Backend/Algorithms/MSBFS.fs

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ open GraphBLAS.FSharp.Objects
88
open GraphBLAS.FSharp.Objects.ClMatrix
99
open GraphBLAS.FSharp.Objects.ArraysExtensions
1010
open GraphBLAS.FSharp.Objects.ClContextExtensions
11+
open GraphBLAS.FSharp.Objects.ClCellExtensions
1112
open GraphBLAS.FSharp.Backend.Matrix.LIL
1213
open GraphBLAS.FSharp.Backend.Matrix.COO
1314

@@ -50,10 +51,12 @@ module internal MSBFS =
5051
let mergeDisjoint =
5152
Matrix.mergeDisjoint clContext workGroupSize
5253

54+
let setLevel = ClArray.fill clContext workGroupSize
55+
5356
let findIntersection =
5457
Intersect.findKeysIntersection clContext workGroupSize
5558

56-
fun (queue: MailboxProcessor<_>) allocationMode (front: ClMatrix.COO<_>) (levels: ClMatrix.COO<_>) ->
59+
fun (queue: MailboxProcessor<_>) allocationMode (level: int) (front: ClMatrix.COO<_>) (levels: ClMatrix.COO<_>) ->
5760

5861
// Find intersection of levels and front indices.
5962
let intersection =
@@ -68,7 +71,14 @@ module internal MSBFS =
6871
match newFront with
6972
| Some f ->
7073
// Update levels
74+
let levelClCell = clContext.CreateClCell level
75+
76+
setLevel queue levelClCell 0 f.Values.Length f.Values
77+
78+
levelClCell.Free queue
79+
7180
let newLevels = mergeDisjoint queue levels f
81+
7282
newLevels, newFront
7383
| _ -> levels, None
7484

@@ -114,7 +124,7 @@ module internal MSBFS =
114124
| Some newFrontier ->
115125
front.Dispose queue
116126
//Filtering visited vertices
117-
match updateFrontAndLevels queue DeviceOnly newFrontier levels with
127+
match updateFrontAndLevels queue DeviceOnly level newFrontier levels with
118128
| l, Some f ->
119129
front <- f
120130
levels.Dispose queue

src/GraphBLAS-sharp.Backend/Common/Map.fs

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
open Brahma.FSharp
44
open Microsoft.FSharp.Quotations
55
open GraphBLAS.FSharp.Objects.ClContextExtensions
6+
open GraphBLAS.FSharp.Objects.ClCellExtensions
67

78
module Map =
89
/// <summary>
@@ -15,11 +16,11 @@ module Map =
1516
let map<'a, 'b> (op: Expr<'a -> 'b>) (clContext: ClContext) workGroupSize =
1617

1718
let map =
18-
<@ fun (ndRange: Range1D) lenght (inputArray: ClArray<'a>) (result: ClArray<'b>) ->
19+
<@ fun (ndRange: Range1D) length (inputArray: ClArray<'a>) (result: ClArray<'b>) ->
1920

2021
let gid = ndRange.GlobalID0
2122

22-
if gid < lenght then
23+
if gid < length then
2324
result.[gid] <- (%op) inputArray.[gid] @>
2425

2526
let kernel = clContext.Compile map
@@ -50,11 +51,11 @@ module Map =
5051
let mapInPlace<'a> (op: Expr<'a -> 'a>) (clContext: ClContext) workGroupSize =
5152

5253
let map =
53-
<@ fun (ndRange: Range1D) lenght (inputArray: ClArray<'a>) ->
54+
<@ fun (ndRange: Range1D) length (inputArray: ClArray<'a>) ->
5455

5556
let gid = ndRange.GlobalID0
5657

57-
if gid < lenght then
58+
if gid < length then
5859
inputArray.[gid] <- (%op) inputArray.[gid] @>
5960

6061
let kernel = clContext.Compile map
@@ -81,11 +82,11 @@ module Map =
8182
let mapWithValue<'a, 'b, 'c> (clContext: ClContext) workGroupSize (op: Expr<'a -> 'b -> 'c>) =
8283

8384
let map =
84-
<@ fun (ndRange: Range1D) lenght (value: ClCell<'a>) (inputArray: ClArray<'b>) (result: ClArray<'c>) ->
85+
<@ fun (ndRange: Range1D) length (value: ClCell<'a>) (inputArray: ClArray<'b>) (result: ClArray<'c>) ->
8586

8687
let gid = ndRange.GlobalID0
8788

88-
if gid < lenght then
89+
if gid < length then
8990
result.[gid] <- (%op) value.Value inputArray.[gid] @>
9091

9192
let kernel = clContext.Compile map
@@ -108,6 +109,8 @@ module Map =
108109

109110
processor.Post(Msg.CreateRunMsg<_, _>(kernel))
110111

112+
valueClCell.Free processor
113+
111114
result
112115

113116
/// <summary>

src/GraphBLAS-sharp.Backend/Matrix/COO/Intersect.fs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ module internal Intersect =
1212
<@ fun (ndRange: Range1D) (leftNNZ: int) (rightNNZ: int) (leftRows: ClArray<int>) (leftColumns: ClArray<int>) (rightRows: ClArray<int>) (rightColumns: ClArray<int>) (bitmap: ClArray<int>) ->
1313

1414
let gid = ndRange.GlobalID0
15-
let bitmapSize = min leftNNZ rightNNZ
15+
let bitmapSize = leftNNZ
1616

1717
if gid < bitmapSize then
1818

@@ -21,7 +21,7 @@ module internal Intersect =
2121
||| (uint64 leftColumns.[gid])
2222

2323
let intersect =
24-
(%Search.Bin.existsByKey2D) bitmapSize index rightRows rightColumns
24+
(%Search.Bin.existsByKey2D) rightNNZ index rightRows rightColumns
2525

2626
if intersect then
2727
bitmap.[gid] <- 1

src/GraphBLAS-sharp.Backend/Matrix/COO/Matrix.fs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,7 @@ module Matrix =
243243
/// </summary>
244244
/// <param name="clContext">OpenCL context.</param>
245245
/// <param name="workGroupSize">Should be a power of 2 and greater than 1.</param>
246-
let findIntersectionByKeys (clContext: ClContext) workGroupSize =
246+
let findKeysIntersection (clContext: ClContext) workGroupSize =
247247
Intersect.findKeysIntersection clContext workGroupSize
248248

249249
/// <summary>

src/GraphBLAS-sharp.Backend/Matrix/CSR/Matrix.fs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -393,8 +393,7 @@ module Matrix =
393393
{ Context = clContext
394394
RowCount = matrix.RowCount
395395
ColumnCount = matrix.ColumnCount
396-
Rows = rows
397-
NNZ = matrix.NNZ }
396+
Rows = rows }
398397

399398
/// <summary>
400399
/// Gets the number of non-zero elements in each row.

src/GraphBLAS-sharp.Backend/Quotes/Search.fs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ module Search =
9898

9999
/// <summary>
100100
/// Searches value in array by two keys.
101-
/// In case there is a value at the given keys position, it is returned.
101+
/// In case there is a value at the given keys position, it returns true.
102102
/// </summary>
103103
let existsByKey2D<'a> =
104104
<@ fun length sourceIndex (rowIndices: ClArray<int>) (columnIndices: ClArray<int>) ->

tests/GraphBLAS-sharp.Tests/Backend/Algorithms/MSBFS.fs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ let makeLevelsTest context queue bfs (matrix: int [,]) =
3232
let matrixDevice = matrixHost.ToDevice context
3333

3434
let expectedArray2D: int [,] =
35-
Array2D.zeroCreate sourceVertexCount (Array2D.length2 matrix)
35+
Array2D.zeroCreate sourceVertexCount (Array2D.length1 matrix)
3636

3737
source
3838
|> Seq.iteri
@@ -109,7 +109,7 @@ let makeParentsTest context queue bfs (matrix: int [,]) =
109109

110110
let createParentsTest context queue testFun =
111111
testFun
112-
|> makeLevelsTest context queue
112+
|> makeParentsTest context queue
113113
|> testPropertyWithConfig config $"test on %A{typeof<'a>}"
114114

115115
let parentsTestFixtures (testContext: TestContext) =
@@ -119,7 +119,7 @@ let parentsTestFixtures (testContext: TestContext) =
119119
let bfsLevels =
120120
Algorithms.MSBFS.runParents context workGroupSize
121121

122-
createLevelsTest context queue bfsLevels ]
122+
createParentsTest context queue bfsLevels ]
123123

124124
let parentsTests =
125-
TestCases.gpuTests "MSBFS Levels tests" parentsTestFixtures
125+
TestCases.gpuTests "MSBFS Parents tests" parentsTestFixtures
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
module GraphBLAS.FSharp.Tests.Backend.Common.ClArray.ExcludeElements
2+
3+
open Expecto
4+
open Brahma.FSharp
5+
open GraphBLAS.FSharp
6+
open GraphBLAS.FSharp.Test
7+
open GraphBLAS.FSharp.Tests
8+
open GraphBLAS.FSharp.Objects.ArraysExtensions
9+
open GraphBLAS.FSharp.Objects.ClContextExtensions
10+
11+
let context = Context.defaultContext.ClContext
12+
13+
let processor = Context.defaultContext.Queue
14+
15+
let config =
16+
{ Utils.defaultConfig with
17+
arbitrary = [ typeof<Generators.ClArray.ExcludeElements> ] }
18+
19+
let makeTest<'a> isEqual (zero: 'a) testFun ((array, bitmap): 'a array * int array) =
20+
if array.Length > 0 && (Array.exists ((=) 1) bitmap) then
21+
22+
let arrayCl = context.CreateClArray array
23+
let bitmapCl = context.CreateClArray bitmap
24+
25+
let actual: ClArray<'a> option = testFun processor HostInterop bitmapCl arrayCl
26+
let actual =
27+
actual
28+
|> Option.map (fun a -> a.ToHostAndFree processor)
29+
30+
arrayCl.Free processor
31+
bitmapCl.Free processor
32+
33+
let expected =
34+
(bitmap, array)
35+
||> Array.zip
36+
|> Array.filter (fun (bit, _) -> bit <> 1)
37+
|> Array.unzip
38+
|> snd
39+
40+
match actual with
41+
| Some actual ->
42+
"Results must be the same"
43+
|> Utils.compareArrays isEqual actual expected
44+
| None ->
45+
"Expected should be empty"
46+
|> Expect.isEmpty expected
47+
48+
let createTest<'a> (zero: 'a) isEqual =
49+
ClArray.excludeElements context Utils.defaultWorkGroupSize
50+
|> makeTest<'a> isEqual zero
51+
|> testPropertyWithConfig config $"test on %A{typeof<'a>}"
52+
53+
let tests =
54+
[ createTest<int> 0 (=)
55+
56+
if Utils.isFloat64Available context.ClDevice then
57+
createTest<float> 0.0 (=)
58+
59+
createTest<float32> 0.0f (=)
60+
createTest<bool> false (=) ]
61+
|> testList "ExcludeElements tests"
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
module GraphBLAS.FSharp.Tests.Backend.Matrix.Intersect
2+
3+
open Expecto
4+
open Brahma.FSharp
5+
open GraphBLAS.FSharp
6+
open GraphBLAS.FSharp.Backend
7+
open GraphBLAS.FSharp.Test
8+
open GraphBLAS.FSharp.Tests
9+
open GraphBLAS.FSharp.Tests.Context
10+
open GraphBLAS.FSharp.Objects
11+
open GraphBLAS.FSharp.Objects.ArraysExtensions
12+
13+
let config =
14+
{ Utils.defaultConfig with
15+
arbitrary = [ typeof<Generators.PairOfSparseMatrices> ] }
16+
17+
let workGroupSize = Utils.defaultWorkGroupSize
18+
19+
let context = Context.defaultContext.ClContext
20+
let processor = Context.defaultContext.Queue
21+
22+
let makeTest isZero testFun (leftMatrix: 'a [,], rightMatrix: 'a [,]) =
23+
24+
let m1 = Matrix.COO.FromArray2D(leftMatrix, isZero)
25+
let m2 = Matrix.COO.FromArray2D(rightMatrix, isZero)
26+
27+
if m1.NNZ > 0 && m2.NNZ > 0 then
28+
let expected =
29+
let leftIndices =
30+
(m1.Rows, m1.Columns)
31+
||> Array.zip
32+
33+
let rightIndices =
34+
(m2.Rows, m2.Columns)
35+
||> Array.zip
36+
37+
Array.init
38+
<| m1.NNZ
39+
<| fun i ->
40+
let index = leftIndices.[i]
41+
if Array.exists ((=) index) rightIndices then 1 else 0
42+
43+
let m1 = m1.ToDevice context
44+
let m2 = m2.ToDevice context
45+
46+
let actual: ClArray<int> =
47+
testFun processor ClContextExtensions.HostInterop m1 m2
48+
49+
let actual = actual.ToHostAndFree processor
50+
51+
m1.Dispose processor
52+
m2.Dispose processor
53+
54+
// Check result
55+
"Matrices should be equal"
56+
|> Expect.equal actual expected
57+
58+
let createTest isZero =
59+
Matrix.COO.Matrix.findKeysIntersection context workGroupSize
60+
|> makeTest isZero
61+
|> testPropertyWithConfig config $"test on %A{typeof<'a>}"
62+
63+
let tests =
64+
[ createTest ((=) false)
65+
createTest ((=) 0)
66+
createTest ((=) 0uy)
67+
createTest (Utils.float32IsEqual 0.0f)
68+
69+
if Utils.isFloat64Available context.ClDevice then
70+
createTest (Utils.floatIsEqual 0.0) ]
71+
|> testList "Intersect tests"

tests/GraphBLAS-sharp.Tests/Backend/Matrix/Merge.fs

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,14 @@ module GraphBLAS.FSharp.Tests.Backend.Matrix.Merge
22

33
open Brahma.FSharp
44
open Expecto
5+
open GraphBLAS.FSharp.Test
56
open Microsoft.FSharp.Collections
67
open GraphBLAS.FSharp.Backend
78
open GraphBLAS.FSharp.Tests
89
open GraphBLAS.FSharp.Tests.Backend
910
open GraphBLAS.FSharp.Objects
1011
open GraphBLAS.FSharp.Objects.ArraysExtensions
12+
open GraphBLAS.FSharp.Objects.MatrixExtensions
1113

1214
let context = Context.defaultContext.ClContext
1315

@@ -108,6 +110,52 @@ let testsCOO =
108110
createTestCOO (=) false ]
109111
|> testList "COO"
110112

113+
let makeTestCOODisjoint isEqual zero testFun (leftArray: 'a [,], rightArray: 'a [,]) =
114+
115+
let leftMatrix =
116+
Matrix.COO.FromArray2D(leftArray, isEqual zero)
117+
118+
let rightMatrix =
119+
Matrix.COO.FromArray2D(rightArray, isEqual zero)
120+
121+
if leftMatrix.NNZ > 0 && rightMatrix.NNZ > 0 then
122+
123+
let clLeftMatrix = leftMatrix.ToDevice context
124+
let clRightMatrix = rightMatrix.ToDevice context
125+
126+
let actual: ClMatrix.COO<'a> = testFun processor clLeftMatrix clRightMatrix
127+
let actual = actual.ToHostAndFree processor
128+
129+
clLeftMatrix.Dispose processor
130+
clRightMatrix.Dispose processor
131+
132+
rightArray
133+
|> Array2D.iteri (fun row column value -> leftArray.[row, column] <- value)
134+
135+
let expected = Matrix.COO.FromArray2D(leftArray, isEqual zero)
136+
137+
Utils.compareCOOMatrix isEqual actual expected
138+
139+
let createTestCOODisjoint isEqual (zero: 'a) =
140+
let configDisjoint =
141+
{Utils.defaultConfig with
142+
endSize = 10
143+
arbitrary = [ typeof<Generators.PairOfDisjointMatricesOfTheSameSize> ]}
144+
145+
Matrix.COO.Merge.runDisjoint context Utils.defaultWorkGroupSize
146+
|> makeTestCOODisjoint isEqual zero
147+
|> testPropertyWithConfig configDisjoint $"test on {typeof<'a>}"
148+
149+
let testsCOODisjoint =
150+
[ createTestCOODisjoint (=) 0
151+
152+
if Utils.isFloat64Available context.ClDevice then
153+
createTestCOODisjoint (=) 0.0
154+
155+
createTestCOODisjoint (=) 0.0f
156+
createTestCOODisjoint (=) false ]
157+
|> testList "COO Disjoint"
158+
111159
let makeTestCSR isEqual zero testFun (leftArray: 'a [,], rightArray: 'a [,]) =
112160
let leftMatrix =
113161
Matrix.CSR.FromArray2D(leftArray, isEqual zero)
@@ -173,4 +221,4 @@ let testsCSR =
173221
|> testList "CSR"
174222

175223
let allTests =
176-
[ testsCSR; testsCOO ] |> testList "Merge"
224+
[ testsCSR; testsCOO; testsCOODisjoint ] |> testList "Merge"

0 commit comments

Comments
 (0)