Skip to content

Commit b6e13fd

Browse files
committed
SpMSpVMasked
1 parent a4da73d commit b6e13fd

File tree

5 files changed

+274
-81
lines changed

5 files changed

+274
-81
lines changed

src/GraphBLAS-sharp.Backend/Algorithms/BFS.fs

Lines changed: 12 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ module internal BFS =
135135
Operations.SpMVInPlace add mul clContext workGroupSize
136136

137137
let spMSpV =
138-
Operations.SpMSpVBool add mul clContext workGroupSize
138+
Operations.SpMSpVMaskedBool add mul clContext workGroupSize
139139

140140
let zeroCreate =
141141
Vector.zeroCreate clContext workGroupSize
@@ -145,9 +145,6 @@ module internal BFS =
145145
let maskComplementedInPlace =
146146
Vector.map2InPlace Mask.complementedOp clContext workGroupSize
147147

148-
let maskComplemented =
149-
Vector.map2Sparse Mask.complementedOp clContext workGroupSize
150-
151148
let fillSubVectorInPlace =
152149
Vector.assignByMaskInPlace (Mask.assign) clContext workGroupSize
153150

@@ -190,28 +187,21 @@ module internal BFS =
190187
match frontier with
191188
| ClVector.Sparse _ ->
192189
//Getting new frontier
193-
match spMSpV queue matrix frontier with
190+
match spMSpV queue matrix frontier levels with
194191
| None ->
195192
frontier.Dispose()
196193
stop <- true
197-
| Some newFrontier ->
194+
| Some newMaskedFrontier ->
198195
frontier.Dispose()
199-
//Filtering visited vertices
200-
match maskComplemented queue DeviceOnly newFrontier levels with
201-
| None ->
202-
stop <- true
203-
newFrontier.Dispose()
204-
| Some newMaskedFrontier ->
205-
newFrontier.Dispose()
206-
207-
//Push/pull
208-
let NNZ = getNNZ queue newMaskedFrontier
209-
210-
if (push NNZ newMaskedFrontier.Size) then
211-
frontier <- newMaskedFrontier
212-
else
213-
frontier <- toDense queue DeviceOnly newMaskedFrontier
214-
newMaskedFrontier.Dispose()
196+
197+
//Push/pull
198+
let NNZ = getNNZ queue newMaskedFrontier
199+
200+
if (push NNZ newMaskedFrontier.Size) then
201+
frontier <- newMaskedFrontier
202+
else
203+
frontier <- toDense queue DeviceOnly newMaskedFrontier
204+
newMaskedFrontier.Dispose()
215205
| ClVector.Dense oldFrontier ->
216206
//Getting new frontier
217207
spMVInPlace queue matrix frontier frontier

src/GraphBLAS-sharp.Backend/Common/ClArray.fs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -906,14 +906,14 @@ module ClArray =
906906
<@ fun (ndRange: Range1D) (length: int) (array: ClArray<'a>) (count: ClCell<int>) ->
907907
let gid = ndRange.GlobalID0
908908
let mutable countLocal = 0
909-
let gSize = ndRange.GlobalWorkSize
909+
let step = ndRange.GlobalWorkSize
910910

911911
let mutable i = gid
912912

913913
while i < length do
914914
let res = (%predicate) array.[i]
915915
if res then countLocal <- countLocal + 1
916-
i <- i + gSize
916+
i <- i + step
917917

918918
atomic (+) count.Value countLocal |> ignore @>
919919

src/GraphBLAS-sharp.Backend/Operations/Operations.fs

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -331,7 +331,7 @@ module Operations =
331331
| _ -> failwith "Not implemented yet"
332332

333333
/// <summary>
334-
/// CSR Matrix - sparse vector multiplication. Optimized for bool OR and AND operations.
334+
/// CSR Matrix - sparse vector multiplication. Optimized for bool OR and AND operations by skipping reduction stage.
335335
/// </summary>
336336
/// <param name="add">Type of binary function to reduce entries.</param>
337337
/// <param name="mul">Type of binary function to combine entries.</param>
@@ -352,6 +352,50 @@ module Operations =
352352
| ClMatrix.CSR m, ClVector.Sparse v -> Option.map ClVector.Sparse (run queue m v)
353353
| _ -> failwith "Not implemented yet"
354354

355+
/// <summary>
356+
/// CSR Matrix - sparse vector multiplication with mask. Mask is complemented.
357+
/// </summary>
358+
/// <param name="add">Type of binary function to reduce entries.</param>
359+
/// <param name="mul">Type of binary function to combine entries.</param>
360+
/// <param name="clContext">OpenCL context.</param>
361+
/// <param name="workGroupSize">Should be a power of 2 and greater than 1.</param>
362+
let SpMSpVMasked
363+
(add: Expr<'c option -> 'c option -> 'c option>)
364+
(mul: Expr<'a option -> 'b option -> 'c option>)
365+
(clContext: ClContext)
366+
workGroupSize
367+
=
368+
369+
let run =
370+
SpMSpV.Masked.runMasked add mul clContext workGroupSize
371+
372+
fun (queue: RawCommandQueue) (matrix: ClMatrix<'a>) (vector: ClVector<'b>) (mask: ClVector<'d>) ->
373+
match matrix, vector, mask with
374+
| ClMatrix.CSR m, ClVector.Sparse v, ClVector.Dense mask -> Option.map ClVector.Sparse (run queue m v mask)
375+
| _ -> failwith "Not implemented yet"
376+
377+
/// <summary>
378+
/// CSR Matrix - sparse vector multiplication with mask. Mask is complemented. Optimized for bool OR and AND operations by skipping reduction stage.
379+
/// </summary>
380+
/// <param name="add">Type of binary function to reduce entries.</param>
381+
/// <param name="mul">Type of binary function to combine entries.</param>
382+
/// <param name="clContext">OpenCL context.</param>
383+
/// <param name="workGroupSize">Should be a power of 2 and greater than 1.</param>
384+
let SpMSpVMaskedBool
385+
(add: Expr<bool option -> bool option -> bool option>)
386+
(mul: Expr<bool option -> bool option -> bool option>)
387+
(clContext: ClContext)
388+
workGroupSize
389+
=
390+
391+
let run =
392+
SpMSpV.Masked.runMaskedBoolStandard add mul clContext workGroupSize
393+
394+
fun (queue: RawCommandQueue) (matrix: ClMatrix<'a>) (vector: ClVector<'b>) (mask: ClVector<'d>) ->
395+
match matrix, vector, mask with
396+
| ClMatrix.CSR m, ClVector.Sparse v, ClVector.Dense mask -> Option.map ClVector.Sparse (run queue m v mask)
397+
| _ -> failwith "Not implemented yet"
398+
355399
/// <summary>
356400
/// CSR Matrix - sparse vector multiplication.
357401
/// </summary>

src/GraphBLAS-sharp.Backend/Operations/SpMSpV.fs

Lines changed: 215 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -290,3 +290,218 @@ module SpMSpV =
290290
Indices = resultIndices
291291
Values = create queue DeviceOnly resultIndices.Length true
292292
Size = matrix.ColumnCount })
293+
294+
module Masked =
295+
296+
let private count (clContext: ClContext) workGroupSize =
297+
298+
let count =
299+
<@ fun (ndRange: Range1D) vectorLength (vectorIndices: ClArray<int>) (vectorMask: ClArray<'d option>) (matrixRowPointers: ClArray<int>) (matrixColumns: ClArray<int>) (result: ClCell<int>) ->
300+
let gid = ndRange.GlobalID0
301+
let step = ndRange.GlobalWorkSize
302+
303+
let mutable idx = gid
304+
305+
while idx < vectorLength do
306+
let vectorIndex = vectorIndices.[idx]
307+
308+
let rowStart = matrixRowPointers.[vectorIndex]
309+
let rowEnd = matrixRowPointers.[vectorIndex + 1]
310+
311+
let mutable count = 0
312+
313+
for i in rowStart .. rowEnd - 1 do
314+
match vectorMask.[matrixColumns.[i]] with
315+
| None -> count <- count + 1
316+
| Some _ -> ()
317+
318+
atomic (+) result.Value count |> ignore
319+
320+
idx <- idx + step @>
321+
322+
let count = clContext.Compile count
323+
324+
fun (queue: RawCommandQueue) (matrix: ClMatrix.CSR<'a>) (vector: ClVector.Sparse<'b>) (vectorMask: ClArray<'d option>) ->
325+
326+
let length = vector.NNZ
327+
328+
let numberOfGroups =
329+
Utils.divUpClamp length workGroupSize 1 1024
330+
331+
let result = clContext.CreateClCell(0)
332+
333+
let ndRange =
334+
Range1D.CreateValid(numberOfGroups * workGroupSize, workGroupSize)
335+
336+
let count = count.GetKernel()
337+
338+
count.KernelFunc ndRange length vector.Indices vectorMask matrix.RowPointers matrix.Columns result
339+
340+
queue.RunKernel count
341+
342+
result.ToHostAndFree queue
343+
344+
let private multiplyValues
345+
(clContext: ClContext)
346+
(mul: Expr<'a option -> 'b option -> 'c option>)
347+
workGroupSize
348+
=
349+
350+
let multiply =
351+
<@ fun (ndRange: Range1D) resultLength (vectorIndices: ClArray<int>) (vectorValues: ClArray<'b>) (vectorMask: ClArray<'d option>) (matrixRowPointers: ClArray<int>) (matrixColumns: ClArray<int>) (matrixValues: ClArray<'a>) (resultOffset: ClCell<int>) (resultIndices: ClArray<int>) (resultValues: ClArray<'c option>) ->
352+
let gid = ndRange.GlobalID0
353+
let step = ndRange.GlobalWorkSize
354+
355+
let mutable i = gid
356+
357+
while i < resultLength do
358+
let vectorIndex = vectorIndices.[i]
359+
let vectorValue = vectorValues.[i]
360+
361+
let rowStart = matrixRowPointers.[vectorIndex]
362+
let rowEnd = matrixRowPointers.[vectorIndex + 1]
363+
364+
let mutable count = 0
365+
366+
for i in rowStart .. rowEnd - 1 do
367+
match vectorMask.[matrixColumns.[i]] with
368+
| None -> count <- count + 1
369+
| Some _ -> ()
370+
371+
let mutable offset = atomic (+) resultOffset.Value count
372+
373+
for i in rowStart .. rowEnd - 1 do
374+
let columnIndex = matrixColumns.[i]
375+
376+
// TODO: Pass mask operation
377+
match vectorMask.[columnIndex] with
378+
| None ->
379+
resultIndices.[offset] <- columnIndex
380+
resultValues.[offset] <- (%mul) (Some matrixValues.[i]) (Some vectorValue)
381+
offset <- offset + 1
382+
| Some _ -> ()
383+
384+
i <- i + step @>
385+
386+
let kernel = clContext.Compile multiply
387+
388+
fun (queue: RawCommandQueue) (matrix: ClMatrix.CSR<'a>) (vector: ClVector.Sparse<'b>) (vectorMask: ClArray<'d option>) (resultSize: int) ->
389+
390+
let multipliedIndices =
391+
clContext.CreateClArrayWithSpecificAllocationMode<int>(DeviceOnly, resultSize)
392+
393+
let multipliedValues =
394+
clContext.CreateClArrayWithSpecificAllocationMode<'c option>(DeviceOnly, resultSize)
395+
396+
let offset = clContext.CreateClCell 0
397+
398+
let numberOfGroups =
399+
Utils.divUpClamp vector.NNZ workGroupSize 1 1024
400+
401+
let ndRange =
402+
Range1D.CreateValid(numberOfGroups * workGroupSize, workGroupSize)
403+
404+
let kernel = kernel.GetKernel()
405+
406+
kernel.KernelFunc
407+
ndRange
408+
vector.NNZ
409+
vector.Indices
410+
vector.Values
411+
vectorMask
412+
matrix.RowPointers
413+
matrix.Columns
414+
matrix.Values
415+
offset
416+
multipliedIndices
417+
multipliedValues
418+
419+
queue.RunKernel kernel
420+
421+
offset.Free()
422+
423+
multipliedIndices, multipliedValues
424+
425+
let runMasked
426+
(add: Expr<'c option -> 'c option -> 'c option>)
427+
(mul: Expr<'a option -> 'b option -> 'c option>)
428+
(clContext: ClContext)
429+
workGroupSize
430+
=
431+
432+
let count = count clContext workGroupSize
433+
434+
let multiplyValues =
435+
multiplyValues clContext mul workGroupSize
436+
437+
let sort =
438+
Sort.Bitonic.sortKeyValuesInplace clContext workGroupSize
439+
440+
let segReduce =
441+
Reduce.ByKey.Option.segmentSequential add clContext workGroupSize
442+
443+
fun (queue: RawCommandQueue) (matrix: ClMatrix.CSR<'a>) (vector: ClVector.Sparse<'b>) (mask: ClArray<'d option>) ->
444+
445+
match count queue matrix vector mask with
446+
| 0 -> None
447+
| resultSize ->
448+
let multipliedIndices, multipliedValues =
449+
multiplyValues queue matrix vector mask resultSize
450+
451+
sort queue multipliedIndices multipliedValues
452+
453+
let result =
454+
segReduce queue DeviceOnly multipliedIndices multipliedValues
455+
|> Option.map
456+
(fun (reducedValues, reducedKeys) ->
457+
{ Context = clContext
458+
Indices = reducedKeys
459+
Values = reducedValues
460+
Size = matrix.ColumnCount })
461+
462+
multipliedIndices.Free()
463+
multipliedValues.Free()
464+
465+
result
466+
467+
let runMaskedBoolStandard
468+
(add: Expr<'c option -> 'c option -> 'c option>)
469+
(mul: Expr<'a option -> 'b option -> 'c option>)
470+
(clContext: ClContext)
471+
workGroupSize
472+
=
473+
474+
let count = count clContext workGroupSize
475+
476+
let multiplyValues =
477+
multiplyValues clContext mul workGroupSize
478+
479+
let sort =
480+
Sort.Bitonic.sortKeyValuesInplace clContext workGroupSize
481+
482+
let removeDuplicates =
483+
GraphBLAS.FSharp.ClArray.removeDuplications clContext workGroupSize
484+
485+
let create =
486+
GraphBLAS.FSharp.ClArray.create clContext workGroupSize
487+
488+
fun (queue: RawCommandQueue) (matrix: ClMatrix.CSR<'a>) (vector: ClVector.Sparse<'b>) (mask: ClArray<'d option>) ->
489+
490+
match count queue matrix vector mask with
491+
| 0 -> None
492+
| resultSize ->
493+
let multipliedIndices, multipliedValues =
494+
multiplyValues queue matrix vector mask resultSize
495+
496+
sort queue multipliedIndices multipliedValues
497+
498+
let resultIndices = removeDuplicates queue multipliedIndices
499+
500+
multipliedIndices.Free()
501+
multipliedValues.Free()
502+
503+
Some
504+
<| { Context = clContext
505+
Indices = resultIndices
506+
Values = create queue DeviceOnly resultIndices.Length true
507+
Size = matrix.ColumnCount }

0 commit comments

Comments
 (0)