Skip to content

Commit a4da73d

Browse files
committed
Count atomic
1 parent e5ad7ac commit a4da73d

File tree

2 files changed

+31
-9
lines changed

2 files changed

+31
-9
lines changed

src/GraphBLAS-sharp.Backend/Common/ClArray.fs

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -902,22 +902,40 @@ module ClArray =
902902

903903
let count<'a> (predicate: Expr<'a -> bool>) (clContext: ClContext) workGroupSize =
904904

905-
let sum =
906-
Reduce.reduce <@ (+) @> clContext workGroupSize
905+
let count =
906+
<@ fun (ndRange: Range1D) (length: int) (array: ClArray<'a>) (count: ClCell<int>) ->
907+
let gid = ndRange.GlobalID0
908+
let mutable countLocal = 0
909+
let gSize = ndRange.GlobalWorkSize
907910

908-
let getBitmap =
909-
Map.map<'a, int> (Map.predicateBitmap predicate) clContext workGroupSize
911+
let mutable i = gid
912+
913+
while i < length do
914+
let res = (%predicate) array.[i]
915+
if res then countLocal <- countLocal + 1
916+
i <- i + gSize
917+
918+
atomic (+) count.Value countLocal |> ignore @>
919+
920+
let count = clContext.Compile count
910921

911922
fun (processor: RawCommandQueue) (array: ClArray<'a>) ->
912923

913-
let bitmap = getBitmap processor DeviceOnly array
924+
let result = clContext.CreateClCell<int>(0)
914925

915-
let result =
916-
(sum processor bitmap).ToHostAndFree processor
926+
let numberOfGroups =
927+
Utils.divUpClamp array.Length workGroupSize 1 1024
917928

918-
bitmap.Free()
929+
let ndRange =
930+
Range1D.CreateValid(workGroupSize * numberOfGroups, workGroupSize)
919931

920-
result
932+
let kernel = count.GetKernel()
933+
934+
kernel.KernelFunc ndRange array.Length array result
935+
936+
processor.RunKernel kernel
937+
938+
result.ToHostAndFree processor
921939

922940
/// <summary>
923941
/// Builds a new array whose elements are the results of applying the given function

src/GraphBLAS-sharp.Backend/Common/Utils.fs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,10 @@ module internal Utils =
1919
>> fun x -> x ||| (x >>> 16)
2020
>> fun x -> x + 1
2121

22+
let divUp x y = x / y + (if x % y = 0 then 0 else 1)
23+
24+
let divUpClamp x y left right = min (max (divUp x y) left) right
25+
2226
let floorToMultiple multiple x = x / multiple * multiple
2327

2428
let ceilToMultiple multiple x = ((x - 1) / multiple + 1) * multiple

0 commit comments

Comments
 (0)