@@ -290,3 +290,218 @@ module SpMSpV =
290290 Indices = resultIndices
291291 Values = create queue DeviceOnly resultIndices.Length true
292292 Size = matrix.ColumnCount })
293+
294+ module Masked =
295+
296+ let private count ( clContext : ClContext ) workGroupSize =
297+
298+ let count =
299+ <@ fun ( ndRange : Range1D ) vectorLength ( vectorIndices : ClArray < int >) ( vectorMask : ClArray < 'd option >) ( matrixRowPointers : ClArray < int >) ( matrixColumns : ClArray < int >) ( result : ClCell < int >) ->
300+ let gid = ndRange.GlobalID0
301+ let step = ndRange.GlobalWorkSize
302+
303+ let mutable idx = gid
304+
305+ while idx < vectorLength do
306+ let vectorIndex = vectorIndices.[ idx]
307+
308+ let rowStart = matrixRowPointers.[ vectorIndex]
309+ let rowEnd = matrixRowPointers.[ vectorIndex + 1 ]
310+
311+ let mutable count = 0
312+
313+ for i in rowStart .. rowEnd - 1 do
314+ match vectorMask.[ matrixColumns.[ i]] with
315+ | None -> count <- count + 1
316+ | Some _ -> ()
317+
318+ atomic (+) result.Value count |> ignore
319+
320+ idx <- idx + step @>
321+
322+ let count = clContext.Compile count
323+
324+ fun ( queue : RawCommandQueue ) ( matrix : ClMatrix.CSR < 'a >) ( vector : ClVector.Sparse < 'b >) ( vectorMask : ClArray < 'd option >) ->
325+
326+ let length = vector.NNZ
327+
328+ let numberOfGroups =
329+ Utils.divUpClamp length workGroupSize 1 1024
330+
331+ let result = clContext.CreateClCell( 0 )
332+
333+ let ndRange =
334+ Range1D.CreateValid( numberOfGroups * workGroupSize, workGroupSize)
335+
336+ let count = count.GetKernel()
337+
338+ count.KernelFunc ndRange length vector.Indices vectorMask matrix.RowPointers matrix.Columns result
339+
340+ queue.RunKernel count
341+
342+ result.ToHostAndFree queue
343+
344+ let private multiplyValues
345+ ( clContext : ClContext )
346+ ( mul : Expr < 'a option -> 'b option -> 'c option >)
347+ workGroupSize
348+ =
349+
350+ let multiply =
351+ <@ fun ( ndRange : Range1D ) resultLength ( vectorIndices : ClArray < int >) ( vectorValues : ClArray < 'b >) ( vectorMask : ClArray < 'd option >) ( matrixRowPointers : ClArray < int >) ( matrixColumns : ClArray < int >) ( matrixValues : ClArray < 'a >) ( resultOffset : ClCell < int >) ( resultIndices : ClArray < int >) ( resultValues : ClArray < 'c option >) ->
352+ let gid = ndRange.GlobalID0
353+ let step = ndRange.GlobalWorkSize
354+
355+ let mutable i = gid
356+
357+ while i < resultLength do
358+ let vectorIndex = vectorIndices.[ i]
359+ let vectorValue = vectorValues.[ i]
360+
361+ let rowStart = matrixRowPointers.[ vectorIndex]
362+ let rowEnd = matrixRowPointers.[ vectorIndex + 1 ]
363+
364+ let mutable count = 0
365+
366+ for i in rowStart .. rowEnd - 1 do
367+ match vectorMask.[ matrixColumns.[ i]] with
368+ | None -> count <- count + 1
369+ | Some _ -> ()
370+
371+ let mutable offset = atomic (+) resultOffset.Value count
372+
373+ for i in rowStart .. rowEnd - 1 do
374+ let columnIndex = matrixColumns.[ i]
375+
376+ // TODO: Pass mask operation
377+ match vectorMask.[ columnIndex] with
378+ | None ->
379+ resultIndices.[ offset] <- columnIndex
380+ resultValues.[ offset] <- (% mul) ( Some matrixValues.[ i]) ( Some vectorValue)
381+ offset <- offset + 1
382+ | Some _ -> ()
383+
384+ i <- i + step @>
385+
386+ let kernel = clContext.Compile multiply
387+
388+ fun ( queue : RawCommandQueue ) ( matrix : ClMatrix.CSR < 'a >) ( vector : ClVector.Sparse < 'b >) ( vectorMask : ClArray < 'd option >) ( resultSize : int ) ->
389+
390+ let multipliedIndices =
391+ clContext.CreateClArrayWithSpecificAllocationMode< int>( DeviceOnly, resultSize)
392+
393+ let multipliedValues =
394+ clContext.CreateClArrayWithSpecificAllocationMode< 'c option>( DeviceOnly, resultSize)
395+
396+ let offset = clContext.CreateClCell 0
397+
398+ let numberOfGroups =
399+ Utils.divUpClamp vector.NNZ workGroupSize 1 1024
400+
401+ let ndRange =
402+ Range1D.CreateValid( numberOfGroups * workGroupSize, workGroupSize)
403+
404+ let kernel = kernel.GetKernel()
405+
406+ kernel.KernelFunc
407+ ndRange
408+ vector.NNZ
409+ vector.Indices
410+ vector.Values
411+ vectorMask
412+ matrix.RowPointers
413+ matrix.Columns
414+ matrix.Values
415+ offset
416+ multipliedIndices
417+ multipliedValues
418+
419+ queue.RunKernel kernel
420+
421+ offset.Free()
422+
423+ multipliedIndices, multipliedValues
424+
425+ let runMasked
426+ ( add : Expr < 'c option -> 'c option -> 'c option >)
427+ ( mul : Expr < 'a option -> 'b option -> 'c option >)
428+ ( clContext : ClContext )
429+ workGroupSize
430+ =
431+
432+ let count = count clContext workGroupSize
433+
434+ let multiplyValues =
435+ multiplyValues clContext mul workGroupSize
436+
437+ let sort =
438+ Sort.Bitonic.sortKeyValuesInplace clContext workGroupSize
439+
440+ let segReduce =
441+ Reduce.ByKey.Option.segmentSequential add clContext workGroupSize
442+
443+ fun ( queue : RawCommandQueue ) ( matrix : ClMatrix.CSR < 'a >) ( vector : ClVector.Sparse < 'b >) ( mask : ClArray < 'd option >) ->
444+
445+ match count queue matrix vector mask with
446+ | 0 -> None
447+ | resultSize ->
448+ let multipliedIndices , multipliedValues =
449+ multiplyValues queue matrix vector mask resultSize
450+
451+ sort queue multipliedIndices multipliedValues
452+
453+ let result =
454+ segReduce queue DeviceOnly multipliedIndices multipliedValues
455+ |> Option.map
456+ ( fun ( reducedValues , reducedKeys ) ->
457+ { Context = clContext
458+ Indices = reducedKeys
459+ Values = reducedValues
460+ Size = matrix.ColumnCount })
461+
462+ multipliedIndices.Free()
463+ multipliedValues.Free()
464+
465+ result
466+
467+ let runMaskedBoolStandard
468+ ( add : Expr < 'c option -> 'c option -> 'c option >)
469+ ( mul : Expr < 'a option -> 'b option -> 'c option >)
470+ ( clContext : ClContext )
471+ workGroupSize
472+ =
473+
474+ let count = count clContext workGroupSize
475+
476+ let multiplyValues =
477+ multiplyValues clContext mul workGroupSize
478+
479+ let sort =
480+ Sort.Bitonic.sortKeyValuesInplace clContext workGroupSize
481+
482+ let removeDuplicates =
483+ GraphBLAS.FSharp.ClArray.removeDuplications clContext workGroupSize
484+
485+ let create =
486+ GraphBLAS.FSharp.ClArray.create clContext workGroupSize
487+
488+ fun ( queue : RawCommandQueue ) ( matrix : ClMatrix.CSR < 'a >) ( vector : ClVector.Sparse < 'b >) ( mask : ClArray < 'd option >) ->
489+
490+ match count queue matrix vector mask with
491+ | 0 -> None
492+ | resultSize ->
493+ let multipliedIndices , multipliedValues =
494+ multiplyValues queue matrix vector mask resultSize
495+
496+ sort queue multipliedIndices multipliedValues
497+
498+ let resultIndices = removeDuplicates queue multipliedIndices
499+
500+ multipliedIndices.Free()
501+ multipliedValues.Free()
502+
503+ Some
504+ <| { Context = clContext
505+ Indices = resultIndices
506+ Values = create queue DeviceOnly resultIndices.Length true
507+ Size = matrix.ColumnCount }
0 commit comments