-
Notifications
You must be signed in to change notification settings - Fork 5
Mooncake testsuite refactor #175
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
c69b57b
734f390
023cfed
159da4d
deda605
98c9fc0
64a4246
8b209c2
f8d7cc6
552559f
9161bde
6550d8b
b3329de
61e8061
0de37f0
2238bf8
7dd64c8
81d9d42
bb4076d
c7dadcd
a140eda
4afc992
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -239,7 +239,7 @@ for f in (:eig, :eigh) | |
| _warn_pullback_truncerror(dϵ) | ||
|
|
||
| # compute pullbacks | ||
| $f_pullback!(dA, Ac, DVc, dDVtrunc, ind) | ||
| $f_pullback!(dA, Ac, DV, dDVtrunc, ind) | ||
| zero!.(dDVtrunc) # since this is allocated in this function this is probably not required | ||
|
|
||
| # restore state | ||
|
|
@@ -351,8 +351,8 @@ for f in (:eig, :eigh) | |
| dDVtrunc = last.(arrayify.(DVtrunc, Mooncake.tangent(DVtrunc_dDVtrunc))) | ||
| function $f_adjoint!(::NoRData) | ||
| # compute pullbacks | ||
| $f_pullback!(dA, Ac, DVc, dDVtrunc, ind) | ||
| zero!.(dDVtrunc) # since this is allocated in this function this is probably not required | ||
| $f_pullback!(dA, Ac, DV, dDVtrunc, ind) | ||
| zero!.(dDV) | ||
|
|
||
| # restore state | ||
| copy!(A, Ac) | ||
|
|
@@ -425,7 +425,7 @@ for (f!, f) in ( | |
| S, dS = arrayify(USVᴴ[2], dUSVᴴ[2]) | ||
| Vᴴ, dVᴴ = arrayify(USVᴴ[3], dUSVᴴ[3]) | ||
| USVᴴc = copy.(USVᴴ) | ||
| output = $f!(A, Mooncake.primal(alg_dalg)) | ||
| output = $f!(A, USVᴴ, Mooncake.primal(alg_dalg)) | ||
| function svd_adjoint(::NoRData) | ||
| copy!(A, Ac) | ||
| if $(f! == svd_compact!) | ||
|
|
@@ -590,7 +590,7 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc!)}, A_dA::CoDual, USVᴴ_dUS | |
| _warn_pullback_truncerror(dϵ) | ||
|
|
||
| # compute pullbacks | ||
| svd_pullback!(dA, Ac, USVᴴc, dUSVᴴtrunc, ind) | ||
| svd_pullback!(dA, Ac, USVᴴ, dUSVᴴtrunc, ind) | ||
| zero!.(dUSVᴴtrunc) # since this is allocated in this function this is probably not required | ||
| zero!.(dUSVᴴ) | ||
|
|
||
|
|
@@ -717,8 +717,7 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc_no_error!)}, A_dA::CoDual, U | |
| dUSVᴴtrunc = last.(arrayify.(USVᴴtrunc, Mooncake.tangent(USVᴴtrunc_dUSVᴴtrunc))) | ||
| function svd_trunc_adjoint(::NoRData) | ||
| # compute pullbacks | ||
| svd_pullback!(dA, Ac, USVᴴc, dUSVᴴtrunc, ind) | ||
| zero!.(dUSVᴴtrunc) # since this is allocated in this function this is probably not required | ||
| svd_pullback!(dA, Ac, USVᴴ, dUSVᴴtrunc, ind) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same question here; was this not tested before? This could not have worked.
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same reply here, because the in-place functions weren't really tested very well, this slipped through the cracks |
||
| zero!.(dUSVᴴ) | ||
|
|
||
| # restore state | ||
|
|
||
This file was deleted.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,19 @@ | ||
| using MatrixAlgebraKit | ||
| using Test | ||
| using LinearAlgebra: Diagonal | ||
| using CUDA, AMDGPU | ||
|
|
||
| BLASFloats = (Float32, ComplexF64) # full suite is too expensive on CI | ||
| GenericFloats = () | ||
| @isdefined(TestSuite) || include("../testsuite/TestSuite.jl") | ||
| using .TestSuite | ||
|
|
||
| is_buildkite = get(ENV, "BUILDKITE", "false") == "true" | ||
|
|
||
| m = 19 | ||
| for T in (BLASFloats..., GenericFloats...) | ||
| TestSuite.seed_rng!(123) | ||
| if !is_buildkite | ||
| TestSuite.test_mooncake_eig(T, (m, m); atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T)) | ||
| end | ||
| end |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,19 @@ | ||
| using MatrixAlgebraKit | ||
| using Test | ||
| using LinearAlgebra: Diagonal | ||
| using CUDA, AMDGPU | ||
|
|
||
| BLASFloats = (Float32, ComplexF64) # full suite is too expensive on CI | ||
| GenericFloats = () | ||
| @isdefined(TestSuite) || include("../testsuite/TestSuite.jl") | ||
| using .TestSuite | ||
|
|
||
| is_buildkite = get(ENV, "BUILDKITE", "false") == "true" | ||
|
|
||
| m = 19 | ||
| for T in (BLASFloats..., GenericFloats...) | ||
| TestSuite.seed_rng!(123) | ||
| if !is_buildkite | ||
| TestSuite.test_mooncake_eigh(T, (m, m); atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T)) | ||
| end | ||
| end |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,19 @@ | ||
| using MatrixAlgebraKit | ||
| using Test | ||
| using LinearAlgebra: Diagonal | ||
| using CUDA, AMDGPU | ||
|
|
||
| BLASFloats = (Float32, ComplexF64) # full suite is too expensive on CI | ||
| GenericFloats = () | ||
| @isdefined(TestSuite) || include("../testsuite/TestSuite.jl") | ||
| using .TestSuite | ||
|
|
||
| is_buildkite = get(ENV, "BUILDKITE", "false") == "true" | ||
|
|
||
| m = 19 | ||
| for T in (BLASFloats..., GenericFloats...), n in (17, m, 23) | ||
| TestSuite.seed_rng!(123) | ||
| if !is_buildkite | ||
| TestSuite.test_mooncake_lq(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) | ||
| end | ||
| end |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,19 @@ | ||
| using MatrixAlgebraKit | ||
| using Test | ||
| using LinearAlgebra: Diagonal | ||
| using CUDA, AMDGPU | ||
|
|
||
| BLASFloats = (Float32, ComplexF64) # full suite is too expensive on CI | ||
| GenericFloats = () | ||
| @isdefined(TestSuite) || include("../testsuite/TestSuite.jl") | ||
| using .TestSuite | ||
|
|
||
| is_buildkite = get(ENV, "BUILDKITE", "false") == "true" | ||
|
|
||
| m = 19 | ||
| for T in (BLASFloats..., GenericFloats...), n in (17, m, 23) | ||
| TestSuite.seed_rng!(123) | ||
| if !is_buildkite | ||
| TestSuite.test_mooncake_orthnull(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) | ||
| end | ||
| end |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,21 @@ | ||
| using MatrixAlgebraKit | ||
| using Test | ||
| using LinearAlgebra: Diagonal | ||
| using CUDA, AMDGPU | ||
|
|
||
| BLASFloats = (Float32, ComplexF64) # full suite is too expensive on CI | ||
| GenericFloats = () | ||
| @isdefined(TestSuite) || include("../testsuite/TestSuite.jl") | ||
| using .TestSuite | ||
|
|
||
| is_buildkite = get(ENV, "BUILDKITE", "false") == "true" | ||
|
|
||
| m = 19 | ||
| for T in (BLASFloats..., GenericFloats...), n in (17, m, 23) | ||
| TestSuite.seed_rng!(123) | ||
| if !is_buildkite | ||
| atol = rtol = m * n * TestSuite.precision(T) | ||
| m >= n && TestSuite.test_mooncake_left_polar(T, (m, n); atol, rtol) | ||
| n >= m && TestSuite.test_mooncake_right_polar(T, (m, n); atol, rtol) | ||
| end | ||
| end |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,19 @@ | ||
| using MatrixAlgebraKit | ||
| using Test | ||
| using LinearAlgebra: Diagonal | ||
| using CUDA, AMDGPU | ||
|
|
||
| BLASFloats = (Float32, ComplexF64) # full suite is too expensive on CI | ||
| GenericFloats = () | ||
| @isdefined(TestSuite) || include("../testsuite/TestSuite.jl") | ||
| using .TestSuite | ||
|
|
||
| is_buildkite = get(ENV, "BUILDKITE", "false") == "true" | ||
|
|
||
| m = 19 | ||
| for T in (BLASFloats..., GenericFloats...), n in (17, m, 23) | ||
| TestSuite.seed_rng!(123) | ||
| if !is_buildkite | ||
| TestSuite.test_mooncake_qr(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) | ||
| end | ||
| end |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,19 @@ | ||
| using MatrixAlgebraKit | ||
| using Test | ||
| using LinearAlgebra: Diagonal | ||
| using CUDA, AMDGPU | ||
|
|
||
| BLASFloats = (Float32, ComplexF64) # full suite is too expensive on CI | ||
| GenericFloats = () | ||
| @isdefined(TestSuite) || include("../testsuite/TestSuite.jl") | ||
| using .TestSuite | ||
|
|
||
| is_buildkite = get(ENV, "BUILDKITE", "false") == "true" | ||
|
|
||
| m = 19 | ||
| for T in (BLASFloats..., GenericFloats...), n in (17, m, 23) | ||
| TestSuite.seed_rng!(123) | ||
| if !is_buildkite | ||
| TestSuite.test_mooncake_svd(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) | ||
| end | ||
| end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How did this escape previous testing?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am not entirely sure, but it seems like the
test_pullback_matchfunctions had some holes in the testing functionality