@@ -3,45 +3,62 @@ namespace GraphBLAS.FSharp.Backend.Algorithms
33open Brahma.FSharp
44open FSharp.Quotations
55open GraphBLAS.FSharp
6- open GraphBLAS.FSharp .Backend .Quotes
76open GraphBLAS.FSharp .Objects
7+ open GraphBLAS.FSharp .Common
88open GraphBLAS.FSharp .Objects .ClMatrix
99open GraphBLAS.FSharp .Objects .ArraysExtensions
1010open GraphBLAS.FSharp .Objects .ClContextExtensions
1111open GraphBLAS.FSharp .Objects .ClCellExtensions
12+ open GraphBLAS.FSharp .Backend .Quotes
1213open GraphBLAS.FSharp .Backend .Matrix .LIL
1314open GraphBLAS.FSharp .Backend .Matrix .COO
1415
1516module internal MSBFS =
1617 let private frontExclude ( clContext : ClContext ) workGroupSize =
1718
18- let excludeValues =
19- ClArray.excludeElements clContext workGroupSize
19+ let invert =
20+ ClArray.mapInPlace ArithmeticOperations.intNotQ clContext workGroupSize
21+
22+ let prefixSum =
23+ PrefixSum.standardExcludeInPlace clContext workGroupSize
24+
25+ let scatterIndices =
26+ Scatter.lastOccurrence clContext workGroupSize
2027
21- let excludeIndices =
22- ClArray.excludeElements clContext workGroupSize
28+ let scatterValues =
29+ Scatter.lastOccurrence clContext workGroupSize
2330
2431 fun ( queue : MailboxProcessor < _ >) allocationMode ( front : ClMatrix.COO < _ >) ( intersection : ClArray < int >) ->
2532
26- let newRows =
27- excludeIndices queue allocationMode intersection front.Rows
33+ invert queue intersection
34+
35+ let length =
36+ ( prefixSum queue intersection)
37+ .ToHostAndFree queue
38+
39+ if length = 0 then
40+ None
41+ else
42+ let rows =
43+ clContext.CreateClArrayWithSpecificAllocationMode( allocationMode, length)
44+
45+ let columns =
46+ clContext.CreateClArrayWithSpecificAllocationMode( allocationMode, length)
2847
29- let newColumns =
30- excludeIndices queue allocationMode intersection front.Columns
48+ let values =
49+ clContext.CreateClArrayWithSpecificAllocationMode ( allocationMode , length )
3150
32- let newValues =
33- excludeValues queue allocationMode intersection front.Values
51+ scatterIndices queue intersection front.Rows rows
52+ scatterIndices queue intersection front.Columns columns
53+ scatterValues queue intersection front.Values values
3454
35- match newRows, newColumns, newValues with
36- | Some rows, Some columns, Some values ->
3755 { Context = clContext
3856 Rows = rows
3957 Columns = columns
4058 Values = values
4159 RowCount = front.RowCount
4260 ColumnCount = front.ColumnCount }
4361 |> Some
44- | _ -> None
4562
4663 module Levels =
4764 let private updateFrontAndLevels ( clContext : ClContext ) workGroupSize =
@@ -70,13 +87,14 @@ module internal MSBFS =
7087
7188 match newFront with
7289 | Some f ->
73- // Update levels
7490 let levelClCell = clContext.CreateClCell level
7591
92+ // Set current level value to all remaining front positions
7693 setLevel queue levelClCell 0 f.Values.Length f.Values
7794
7895 levelClCell.Free queue
7996
97+ // Update levels
8098 let newLevels = mergeDisjoint queue levels f
8199
82100 newLevels, newFront
@@ -110,7 +128,7 @@ module internal MSBFS =
110128
111129 let mutable front = copy queue DeviceOnly levels
112130
113- let mutable level = 0
131+ let mutable level = 1
114132 let mutable stop = false
115133
116134 while not stop do
@@ -121,15 +139,21 @@ module internal MSBFS =
121139 | None ->
122140 front.Dispose queue
123141 stop <- true
142+
124143 | Some newFrontier ->
125144 front.Dispose queue
145+
126146 //Filtering visited vertices
127147 match updateFrontAndLevels queue DeviceOnly level newFrontier levels with
128148 | l, Some f ->
129149 front <- f
150+
130151 levels.Dispose queue
152+
131153 levels <- l
154+
132155 newFrontier.Dispose queue
156+
133157 | _, None ->
134158 stop <- true
135159 newFrontier.Dispose queue
@@ -151,8 +175,6 @@ module internal MSBFS =
151175
152176 module Parents =
153177 let private updateFrontAndParents ( clContext : ClContext ) workGroupSize =
154- // update parents same as levels
155- // every front value should be equal to its column number
156178 let frontExclude = frontExclude clContext workGroupSize
157179
158180 let mergeDisjoint =
@@ -175,10 +197,15 @@ module internal MSBFS =
175197
176198 match newFront with
177199 | Some f ->
178- // Update levels
179200 let resultFront = { f with Values = f.Columns }
180- let newLevels = mergeDisjoint queue parents f
181- newLevels, Some resultFront
201+
202+ // Update parents
203+ let newParents = mergeDisjoint queue parents f
204+
205+ f.Values.Free queue
206+
207+ newParents, Some resultFront
208+
182209 | _ -> parents, None
183210
184211 let run < 'a when 'a : struct > ( clContext : ClContext ) workGroupSize =
@@ -190,7 +217,7 @@ module internal MSBFS =
190217 clContext
191218 workGroupSize
192219
193- let updateFrontAndLevels =
220+ let updateFrontAndParents =
194221 updateFrontAndParents clContext workGroupSize
195222
196223 fun ( queue : MailboxProcessor < Msg >) ( inputMatrix : ClMatrix < 'a >) ( source : int list ) ->
@@ -227,15 +254,20 @@ module internal MSBFS =
227254 | None ->
228255 front.Dispose queue
229256 stop <- true
257+
230258 | Some newFrontier ->
231259 front.Dispose queue
260+
232261 //Filtering visited vertices
233- match updateFrontAndLevels queue DeviceOnly newFrontier parents with
234- | l , Some f ->
262+ match updateFrontAndParents queue DeviceOnly newFrontier parents with
263+ | p , Some f ->
235264 front <- f
265+
236266 parents.Dispose queue
237- parents <- l
267+ parents <- p
268+
238269 newFrontier.Dispose queue
270+
239271 | _, None ->
240272 stop <- true
241273 newFrontier.Dispose queue
0 commit comments