Skip to content

Commit 35e5e29

Browse files
committed
add: ClArray.set, ClArray.item
1 parent 2ddbe37 commit 35e5e29

File tree

6 files changed

+266
-1
lines changed

6 files changed

+266
-1
lines changed

src/GraphBLAS-sharp.Backend/Common/ClArray.fs

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -732,3 +732,59 @@ module ClArray =
732732

733733
let upperBound<'a when 'a: comparison> clContext =
734734
bound<'a, int> Search.Bin.lowerBound clContext
735+
736+
let item<'a> (clContext: ClContext) workGroupSize =
737+
738+
let kernel =
739+
<@ fun (ndRange: Range1D) index (array: ClArray<'a>) (result: ClCell<'a>) ->
740+
741+
let gid = ndRange.GlobalID0
742+
743+
if gid = 0 then
744+
result.Value <- array.[index] @>
745+
746+
let program = clContext.Compile kernel
747+
748+
fun (processor: MailboxProcessor<_>) (index: int) (array: ClArray<'a>) ->
749+
750+
if index < 0 || index >= array.Length then
751+
failwith "Index out of range"
752+
753+
let result =
754+
clContext.CreateClCell Unchecked.defaultof<'a>
755+
756+
let kernel = program.GetKernel()
757+
758+
let ndRange = Range1D.CreateValid(1, workGroupSize)
759+
760+
processor.Post(Msg.MsgSetArguments(fun () -> kernel.KernelFunc ndRange index array result))
761+
processor.Post(Msg.CreateRunMsg<_, _> kernel)
762+
763+
result
764+
765+
let set<'a> (clContext: ClContext) workGroupSize =
766+
767+
let kernel =
768+
<@ fun (ndRange: Range1D) index (array: ClArray<'a>) (value: ClCell<'a>) ->
769+
770+
let gid = ndRange.GlobalID0
771+
772+
if gid = 0 then
773+
array.[index] <- value.Value @>
774+
775+
let program = clContext.Compile kernel
776+
777+
fun (processor: MailboxProcessor<_>) (array: ClArray<'a>) (index: int) (value: 'a) ->
778+
779+
if index < 0 || index >= array.Length then
780+
failwith "Index out of range"
781+
782+
let value =
783+
clContext.CreateClCell value
784+
785+
let kernel = program.GetKernel()
786+
787+
let ndRange = Range1D.CreateValid(1, workGroupSize)
788+
789+
processor.Post(Msg.MsgSetArguments(fun () -> kernel.KernelFunc ndRange index array value))
790+
processor.Post(Msg.CreateRunMsg<_, _> kernel)
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
module GraphBLAS.FSharp.Tests.Backend.Common.ClArray.Item
2+
3+
open Expecto
4+
open Brahma.FSharp
5+
open GraphBLAS.FSharp.Backend.Common
6+
open GraphBLAS.FSharp.Test
7+
open GraphBLAS.FSharp.Tests
8+
open GraphBLAS.FSharp.Backend.Objects.ArraysExtensions
9+
open GraphBLAS.FSharp.Backend.Objects.ClCell
10+
11+
let context = Context.defaultContext.ClContext
12+
13+
let processor = Context.defaultContext.Queue
14+
15+
let config = { Utils.defaultConfig with arbitrary = [ typeof<Generators.ClArray.Item> ] }
16+
17+
let makeTest<'a when 'a: equality> testFun (array: 'a [], position) =
18+
19+
if array.Length > 0 then
20+
21+
let clArray = context.CreateClArray array
22+
23+
let result: ClCell<'a> = testFun processor position clArray
24+
25+
clArray.Free processor
26+
let actual = result.ToHost processor
27+
28+
let expected = Array.item position array
29+
30+
"Results must be the same"
31+
|> Expect.equal actual expected
32+
33+
let createTest<'a when 'a: equality> =
34+
ClArray.item context Utils.defaultWorkGroupSize
35+
|> makeTest<'a>
36+
|> testPropertyWithConfig config $"test on %A{typeof<'a>}"
37+
38+
let tests =
39+
[ createTest<int>
40+
41+
if Utils.isFloat64Available context.ClDevice then
42+
createTest<float>
43+
44+
createTest<float32>
45+
createTest<bool> ]
46+
|> testList "Item"
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
module GraphBLAS.FSharp.Tests.Backend.Common.ClArray.Set
2+
3+
open Expecto
4+
open Brahma.FSharp
5+
open GraphBLAS.FSharp.Backend.Common
6+
open GraphBLAS.FSharp.Test
7+
open GraphBLAS.FSharp.Tests
8+
open GraphBLAS.FSharp.Backend.Objects.ArraysExtensions
9+
10+
let context = Context.defaultContext.ClContext
11+
12+
let processor = Context.defaultContext.Queue
13+
14+
let config = { Utils.defaultConfig with arbitrary = [typeof<Generators.ClArray.Set>]}
15+
16+
let makeTest<'a when 'a : equality> testFun (array: 'a [], position, value: 'a) =
17+
18+
if array.Length > 0 then
19+
20+
let clArray = context.CreateClArray array
21+
22+
testFun processor clArray position value
23+
24+
let actual = clArray.ToHostAndFree processor
25+
Array.set array position value
26+
27+
"Results must be the same"
28+
|> Utils.compareArrays (=) actual array
29+
30+
let createTest<'a when 'a : equality> =
31+
ClArray.set context Utils.defaultWorkGroupSize
32+
|> makeTest<'a>
33+
|> testPropertyWithConfig config $"test on %A{typeof<'a>}"
34+
35+
let tests =
36+
[ createTest<int>
37+
38+
if Utils.isFloat64Available context.ClDevice then
39+
createTest<float>
40+
41+
createTest<float32>
42+
createTest<bool> ]
43+
|> testList "Set"
44+

tests/GraphBLAS-sharp.Tests/Generators.fs

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1278,6 +1278,121 @@ module Generators =
12781278
arrayAndChunkPosition <| Arb.generate<bool>
12791279
|> Arb.fromGen
12801280

1281+
module ClArray =
1282+
type Set() =
1283+
static let arrayAndChunkPosition (valuesGenerator: Gen<'a>) =
1284+
gen {
1285+
let! size = Gen.sized <| fun size -> Gen.choose (1, size + 1)
1286+
1287+
let! array = Gen.arrayOfLength size valuesGenerator
1288+
1289+
let! position = Gen.choose (0, array.Length - 1)
1290+
1291+
let! value = valuesGenerator
1292+
1293+
return (array, position, value)
1294+
}
1295+
1296+
static member IntType() =
1297+
arrayAndChunkPosition <| Arb.generate<int>
1298+
|> Arb.fromGen
1299+
1300+
static member FloatType() =
1301+
arrayAndChunkPosition
1302+
<| (Arb.Default.NormalFloat()
1303+
|> Arb.toGen
1304+
|> Gen.map float)
1305+
|> Arb.fromGen
1306+
1307+
static member Float32Type() =
1308+
arrayAndChunkPosition
1309+
<| (normalFloat32Generator <| System.Random())
1310+
|> Arb.fromGen
1311+
1312+
static member SByteType() =
1313+
arrayAndChunkPosition <| Arb.generate<sbyte>
1314+
|> Arb.fromGen
1315+
1316+
static member ByteType() =
1317+
arrayAndChunkPosition <| Arb.generate<byte>
1318+
|> Arb.fromGen
1319+
1320+
static member Int16Type() =
1321+
arrayAndChunkPosition <| Arb.generate<int16>
1322+
|> Arb.fromGen
1323+
1324+
static member UInt16Type() =
1325+
arrayAndChunkPosition <| Arb.generate<uint16>
1326+
|> Arb.fromGen
1327+
1328+
static member Int32Type() =
1329+
arrayAndChunkPosition <| Arb.generate<int32>
1330+
|> Arb.fromGen
1331+
1332+
static member UInt32Type() =
1333+
arrayAndChunkPosition <| Arb.generate<uint32>
1334+
|> Arb.fromGen
1335+
1336+
static member BoolType() =
1337+
arrayAndChunkPosition <| Arb.generate<bool>
1338+
|> Arb.fromGen
1339+
1340+
type Item() =
1341+
static let arrayAndChunkPosition (valuesGenerator: Gen<'a>) =
1342+
gen {
1343+
let! size = Gen.sized <| fun size -> Gen.choose (1, size + 1)
1344+
1345+
let! array = Gen.arrayOfLength size valuesGenerator
1346+
1347+
let! position = Gen.choose (0, array.Length - 1)
1348+
1349+
return (array, position)
1350+
}
1351+
1352+
static member IntType() =
1353+
arrayAndChunkPosition <| Arb.generate<int>
1354+
|> Arb.fromGen
1355+
1356+
static member FloatType() =
1357+
arrayAndChunkPosition
1358+
<| (Arb.Default.NormalFloat()
1359+
|> Arb.toGen
1360+
|> Gen.map float)
1361+
|> Arb.fromGen
1362+
1363+
static member Float32Type() =
1364+
arrayAndChunkPosition
1365+
<| (normalFloat32Generator <| System.Random())
1366+
|> Arb.fromGen
1367+
1368+
static member SByteType() =
1369+
arrayAndChunkPosition <| Arb.generate<sbyte>
1370+
|> Arb.fromGen
1371+
1372+
static member ByteType() =
1373+
arrayAndChunkPosition <| Arb.generate<byte>
1374+
|> Arb.fromGen
1375+
1376+
static member Int16Type() =
1377+
arrayAndChunkPosition <| Arb.generate<int16>
1378+
|> Arb.fromGen
1379+
1380+
static member UInt16Type() =
1381+
arrayAndChunkPosition <| Arb.generate<uint16>
1382+
|> Arb.fromGen
1383+
1384+
static member Int32Type() =
1385+
arrayAndChunkPosition <| Arb.generate<int32>
1386+
|> Arb.fromGen
1387+
1388+
static member UInt32Type() =
1389+
arrayAndChunkPosition <| Arb.generate<uint32>
1390+
|> Arb.fromGen
1391+
1392+
static member BoolType() =
1393+
arrayAndChunkPosition <| Arb.generate<bool>
1394+
|> Arb.fromGen
1395+
12811396
module Matrix =
12821397
type Sub() =
12831398
static let arrayAndChunkPosition (valuesGenerator: Gen<'a>) =

tests/GraphBLAS-sharp.Tests/GraphBLAS-sharp.Tests.fsproj

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@
3030
<Compile Include="Backend/Common/ClArray/RemoveDuplicates.fs" />
3131
<Compile Include="Backend/Common/ClArray/Replicate.fs" />
3232
<Compile Include="Backend/Common/ClArray/UpperBound.fs" />
33+
<Compile Include="Backend/Common/ClArray/Set.fs" />
34+
<Compile Include="Backend/Common/ClArray/Item.fs" />
3335
<Compile Include="Backend/Common/Gather.fs" />
3436
<Compile Include="Backend/Common/Reduce/Reduce.fs" />
3537
<Compile Include="Backend/Common/Reduce/ReduceByKey.fs" />

tests/GraphBLAS-sharp.Tests/Program.fs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,9 @@ let commonTests =
4949
Common.ClArray.Concat.tests
5050
Common.ClArray.Fill.tests
5151
Common.ClArray.Pairwise.tests
52-
Common.ClArray.UpperBound.tests ]
52+
Common.ClArray.UpperBound.tests
53+
Common.ClArray.Set.tests
54+
Common.ClArray.Item.tests ]
5355

5456
let sortTests =
5557
testList

0 commit comments

Comments
 (0)