Skip to content

Commit d839957

Browse files
committed
Refactor Enzyme testsuite
1 parent a349aef commit d839957

21 files changed

Lines changed: 843 additions & 584 deletions

test/enzyme.jl

Lines changed: 0 additions & 30 deletions
This file was deleted.

test/enzyme/eig.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
using MatrixAlgebraKit
2+
using Test
3+
using LinearAlgebra: Diagonal
4+
using CUDA, AMDGPU
5+
6+
BLASFloats = (Float32, ComplexF64) # full suite is too expensive on CI
7+
GenericFloats = ()
8+
@isdefined(TestSuite) || include("../testsuite/TestSuite.jl")
9+
using .TestSuite
10+
11+
is_buildkite = get(ENV, "BUILDKITE", "false") == "true"
12+
13+
m = 19
14+
for T in (BLASFloats..., GenericFloats...)
15+
TestSuite.seed_rng!(123)
16+
if !is_buildkite
17+
TestSuite.test_enzyme_eig(T, (m, m); atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T))
18+
end
19+
end

test/enzyme/eigh.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
using MatrixAlgebraKit
2+
using Test
3+
using LinearAlgebra: Diagonal
4+
using CUDA, AMDGPU
5+
6+
BLASFloats = (Float32, ComplexF64) # full suite is too expensive on CI
7+
GenericFloats = ()
8+
@isdefined(TestSuite) || include("../testsuite/TestSuite.jl")
9+
using .TestSuite
10+
11+
is_buildkite = get(ENV, "BUILDKITE", "false") == "true"
12+
13+
m = 19
14+
for T in (BLASFloats..., GenericFloats...)
15+
TestSuite.seed_rng!(123)
16+
if !is_buildkite
17+
TestSuite.test_enzyme_eigh(T, (m, m); atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T))
18+
end
19+
end

test/enzyme/lq.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
using MatrixAlgebraKit
2+
using Test
3+
using LinearAlgebra: Diagonal
4+
using CUDA, AMDGPU
5+
6+
BLASFloats = (Float32, ComplexF64) # full suite is too expensive on CI
7+
GenericFloats = ()
8+
@isdefined(TestSuite) || include("../testsuite/TestSuite.jl")
9+
using .TestSuite
10+
11+
is_buildkite = get(ENV, "BUILDKITE", "false") == "true"
12+
13+
m = 19
14+
for T in (BLASFloats..., GenericFloats...), n in (17, m, 23)
15+
TestSuite.seed_rng!(123)
16+
if !is_buildkite
17+
TestSuite.test_enzyme_lq(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T))
18+
end
19+
end

test/enzyme/orthnull.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
using MatrixAlgebraKit
2+
using Test
3+
using LinearAlgebra: Diagonal
4+
using CUDA, AMDGPU
5+
6+
BLASFloats = (Float32, ComplexF64) # full suite is too expensive on CI
7+
GenericFloats = ()
8+
@isdefined(TestSuite) || include("../testsuite/TestSuite.jl")
9+
using .TestSuite
10+
11+
is_buildkite = get(ENV, "BUILDKITE", "false") == "true"
12+
13+
m = 19
14+
for T in (BLASFloats..., GenericFloats...), n in (17, m, 23)
15+
TestSuite.seed_rng!(123)
16+
if !is_buildkite
17+
TestSuite.test_enzyme_orthnull(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T))
18+
end
19+
end

test/enzyme/polar.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
using MatrixAlgebraKit
2+
using Test
3+
using LinearAlgebra: Diagonal
4+
using CUDA, AMDGPU
5+
6+
BLASFloats = (Float32, ComplexF64) # full suite is too expensive on CI
7+
GenericFloats = ()
8+
@isdefined(TestSuite) || include("../testsuite/TestSuite.jl")
9+
using .TestSuite
10+
11+
is_buildkite = get(ENV, "BUILDKITE", "false") == "true"
12+
13+
m = 19
14+
for T in (BLASFloats..., GenericFloats...), n in (17, m, 23)
15+
TestSuite.seed_rng!(123)
16+
if !is_buildkite
17+
TestSuite.test_enzyme_polar(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T))
18+
end
19+
end

test/enzyme/qr.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
using MatrixAlgebraKit
2+
using Test
3+
using LinearAlgebra: Diagonal
4+
using CUDA, AMDGPU
5+
6+
BLASFloats = (Float32, ComplexF64) # full suite is too expensive on CI
7+
GenericFloats = ()
8+
@isdefined(TestSuite) || include("../testsuite/TestSuite.jl")
9+
using .TestSuite
10+
11+
is_buildkite = get(ENV, "BUILDKITE", "false") == "true"
12+
13+
m = 19
14+
for T in (BLASFloats..., GenericFloats...), n in (17, m, 23)
15+
TestSuite.seed_rng!(123)
16+
if !is_buildkite
17+
TestSuite.test_enzyme_qr(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T))
18+
end
19+
end

test/enzyme/svd.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
using MatrixAlgebraKit
2+
using Test
3+
using LinearAlgebra: Diagonal
4+
using CUDA, AMDGPU
5+
6+
BLASFloats = (Float32, ComplexF64,) # full suite is too expensive on CI
7+
GenericFloats = ()
8+
@isdefined(TestSuite) || include("../testsuite/TestSuite.jl")
9+
using .TestSuite
10+
11+
is_buildkite = get(ENV, "BUILDKITE", "false") == "true"
12+
13+
m = 19
14+
for T in (BLASFloats..., GenericFloats...), n in (17, m, 23)
15+
TestSuite.seed_rng!(1234)
16+
if !is_buildkite
17+
TestSuite.test_enzyme_svd(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T))
18+
end
19+
end

test/runtests.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,10 @@ if filter_tests!(testsuite, args)
2626
else
2727
is_apple_ci = Sys.isapple() && get(ENV, "CI", "false") == "true"
2828
if is_apple_ci
29-
delete!(testsuite, "enzyme")
3029
delete!(testsuite, "mooncake")
3130
delete!(testsuite, "chainrules")
3231
end
33-
Sys.iswindows() && delete!(testsuite, "enzyme")
32+
(Sys.iswindows() || is_apple_ci) && filter!(p -> !startswith(first(p), "enzyme/"), testsuite)
3433
end
3534
end
3635

test/testsuite/TestSuite.jl

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ using MatrixAlgebraKit: diagview
1414
using LinearAlgebra: Diagonal, norm, istriu, istril, I
1515
using Random, StableRNGs
1616
using AMDGPU, CUDA
17+
using Enzyme, EnzymeTestUtils
1718

1819
const tests = Dict()
1920

@@ -95,8 +96,15 @@ include("eig.jl")
9596
include("eigh.jl")
9697
include("orthnull.jl")
9798
include("svd.jl")
98-
include("mooncake.jl")
99-
include("enzyme.jl")
100-
include("chainrules.jl")
99+
#include("mooncake.jl")
100+
#include("chainrules.jl")
101+
102+
include("enzyme/eig.jl")
103+
include("enzyme/eigh.jl")
104+
include("enzyme/qr.jl")
105+
include("enzyme/lq.jl")
106+
include("enzyme/svd.jl")
107+
include("enzyme/polar.jl")
108+
include("enzyme/orthnull.jl")
101109

102110
end

0 commit comments

Comments
 (0)