Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 6 additions & 7 deletions ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ for f in (:eig, :eigh)
_warn_pullback_truncerror(dϵ)

# compute pullbacks
$f_pullback!(dA, Ac, DVc, dDVtrunc, ind)
$f_pullback!(dA, Ac, DV, dDVtrunc, ind)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How did this escape previous testing?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not entirely sure, but it seems the `test_pullback_match` functions had some holes in their test coverage.

zero!.(dDVtrunc) # since this is allocated in this function this is probably not required

# restore state
Expand Down Expand Up @@ -351,8 +351,8 @@ for f in (:eig, :eigh)
dDVtrunc = last.(arrayify.(DVtrunc, Mooncake.tangent(DVtrunc_dDVtrunc)))
function $f_adjoint!(::NoRData)
# compute pullbacks
$f_pullback!(dA, Ac, DVc, dDVtrunc, ind)
zero!.(dDVtrunc) # since this is allocated in this function this is probably not required
$f_pullback!(dA, Ac, DV, dDVtrunc, ind)
zero!.(dDV)

# restore state
copy!(A, Ac)
Expand Down Expand Up @@ -425,7 +425,7 @@ for (f!, f) in (
S, dS = arrayify(USVᴴ[2], dUSVᴴ[2])
Vᴴ, dVᴴ = arrayify(USVᴴ[3], dUSVᴴ[3])
USVᴴc = copy.(USVᴴ)
output = $f!(A, Mooncake.primal(alg_dalg))
output = $f!(A, USVᴴ, Mooncake.primal(alg_dalg))
function svd_adjoint(::NoRData)
copy!(A, Ac)
if $(f! == svd_compact!)
Expand Down Expand Up @@ -590,7 +590,7 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc!)}, A_dA::CoDual, USVᴴ_dUS
_warn_pullback_truncerror(dϵ)

# compute pullbacks
svd_pullback!(dA, Ac, USVᴴc, dUSVᴴtrunc, ind)
svd_pullback!(dA, Ac, USVᴴ, dUSVᴴtrunc, ind)
zero!.(dUSVᴴtrunc) # since this is allocated in this function this is probably not required
zero!.(dUSVᴴ)

Expand Down Expand Up @@ -717,8 +717,7 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc_no_error!)}, A_dA::CoDual, U
dUSVᴴtrunc = last.(arrayify.(USVᴴtrunc, Mooncake.tangent(USVᴴtrunc_dUSVᴴtrunc)))
function svd_trunc_adjoint(::NoRData)
# compute pullbacks
svd_pullback!(dA, Ac, USVᴴc, dUSVᴴtrunc, ind)
zero!.(dUSVᴴtrunc) # since this is allocated in this function this is probably not required
svd_pullback!(dA, Ac, USVᴴ, dUSVᴴtrunc, ind)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same question here; was this not tested before? This could not have worked.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same reply here: because the in-place functions weren't tested very well, this slipped through the cracks.

zero!.(dUSVᴴ)

# restore state
Expand Down
23 changes: 13 additions & 10 deletions src/pullbacks/lq.jl
Original file line number Diff line number Diff line change
@@ -1,28 +1,33 @@
"""
    lq_rank(L; kwargs...)

Forward `L` and all keyword arguments to `qr_rank` and return its result.
"""
function lq_rank(L; kwargs...)
    return qr_rank(L; kwargs...)
end

function check_lq_cotangents(
L, Q, ΔL, ΔQ, minmn::Int, p::Int;
L, Q, ΔL, ΔQ, p::Int;
gauge_atol::Real = default_pullback_gauge_atol(ΔQ)
)
minmn = min(size(L, 1), size(Q, 2))
if minmn > p # case where A is rank-deficient
Δgauge = abs(zero(eltype(Q)))
if !iszerotangent(ΔQ)
# in this case the number of Householder reflections will
# change upon small variations, and all of the remaining
# columns of ΔQ should be zero for a gauge-invariant
# rows of ΔQ should be zero for a gauge-invariant
# cost function
ΔQ2 = view(ΔQ, (p + 1):size(Q, 1), :)
Δgauge = max(Δgauge, norm(ΔQ2))
Δgauge_Q = norm(ΔQ2, Inf)
Δgauge = max(Δgauge, Δgauge_Q)
end
if !iszerotangent(ΔL)
ΔL22 = view(ΔL, (p + 1):size(L, 1), (p + 1):minmn)
Δgauge = max(Δgauge, norm(ΔL22))
Δgauge_L = norm(ΔL22, Inf)
Δgauge = max(Δgauge, Δgauge_L)
end
Δgauge ≤ gauge_atol ||
@warn "`lq` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
end
return
end

function check_lq_full_cotangents(Q1, ΔQ2, ΔQ2Q1ᴴ; gauge_atol::Real = default_pullback_gauge_atol(Q1))
function check_lq_full_cotangents(Q1, ΔQ2, ΔQ2Q1ᴴ; gauge_atol::Real = default_pullback_gauge_atol(ΔQ2))
# in the case where A is full rank, but there are more columns in Q than in A
# (the case of `lq_full`), there is gauge-invariant information in the
# projection of ΔQ2 onto the column space of Q1, by virtue of Q being a unitary
Expand All @@ -32,7 +37,7 @@ function check_lq_full_cotangents(Q1, ΔQ2, ΔQ2Q1ᴴ; gauge_atol::Real = defaul
# Q2' * ΔQ2 as a gauge dependent quantity.
Δgauge = norm(mul!(copy(ΔQ2), ΔQ2Q1ᴴ, Q1, -1, 1), Inf)
Δgauge ≤ gauge_atol ||
@warn "`lq` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
@warn "`lq_full` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
return
end

Expand Down Expand Up @@ -62,9 +67,7 @@ function lq_pullback!(
L, Q = LQ
m = size(L, 1)
n = size(Q, 2)
minmn = min(m, n)
Ld = diagview(L)
p = @something findlast(>=(rank_atol) ∘ abs, Ld) 0
p = lq_rank(L; rank_atol)

ΔL, ΔQ = ΔLQ

Expand All @@ -74,7 +77,7 @@ function lq_pullback!(
ΔA1 = view(ΔA, 1:p, :)
ΔA2 = view(ΔA, (p + 1):m, :)

check_lq_cotangents(L, Q, ΔL, ΔQ, minmn, p; gauge_atol)
check_lq_cotangents(L, Q, ΔL, ΔQ, p; gauge_atol)

ΔQ̃ = zero!(similar(Q, (p, n)))
if !iszerotangent(ΔQ)
Expand Down
23 changes: 15 additions & 8 deletions src/pullbacks/qr.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,11 @@
function check_qr_cotangents(Q, R, ΔQ, ΔR, minmn::Int, p::Int; gauge_atol::Real = default_pullback_gauge_atol(ΔQ))
"""
    qr_rank(R; rank_atol = default_pullback_rank_atol(R))

Return the index of the last entry of `diagview(R)` whose absolute value is at
least `rank_atol`, or `0` when no entry qualifies.
"""
function qr_rank(R; rank_atol = default_pullback_rank_atol(R))
    last_significant = findlast(x -> abs(x) >= rank_atol, diagview(R))
    return isnothing(last_significant) ? 0 : last_significant
end

function check_qr_cotangents(
Q, R, ΔQ, ΔR, p::Int;
gauge_atol::Real = default_pullback_gauge_atol(ΔQ)
)
minmn = min(size(Q, 1), size(R, 2))
if minmn > p # case where A is rank-deficient
Δgauge = abs(zero(eltype(Q)))
if !iszerotangent(ΔQ)
Expand All @@ -7,11 +14,13 @@ function check_qr_cotangents(Q, R, ΔQ, ΔR, minmn::Int, p::Int; gauge_atol::Rea
# columns of ΔQ should be zero for a gauge-invariant
# cost function
ΔQ2 = view(ΔQ, :, (p + 1):size(Q, 2))
Δgauge = max(Δgauge, norm(ΔQ2, Inf))
Δgauge_Q = norm(ΔQ2, Inf)
Δgauge = max(Δgauge, Δgauge_Q)
end
if !iszerotangent(ΔR)
ΔR22 = view(ΔR, (p + 1):minmn, (p + 1):size(R, 2))
Δgauge = max(Δgauge, norm(ΔR22, Inf))
Δgauge_R = norm(ΔR22, Inf)
Δgauge = max(Δgauge, Δgauge_R)
end
Δgauge ≤ gauge_atol ||
@warn "`qr` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
Expand All @@ -29,7 +38,7 @@ function check_qr_full_cotangents(Q1, ΔQ2, Q1dΔQ2; gauge_atol::Real = default_
# Q2' * ΔQ2 as a gauge dependent quantity.
Δgauge = norm(mul!(copy(ΔQ2), Q1, Q1dΔQ2, -1, 1), Inf)
Δgauge ≤ gauge_atol ||
@warn "`qr` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
@warn "`qr_full` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
return
end

Expand Down Expand Up @@ -60,19 +69,17 @@ function qr_pullback!(
Q, R = QR
m = size(Q, 1)
n = size(R, 2)
minmn = min(m, n)
Rd = diagview(R)
p = @something findlast(>=(rank_atol) ∘ abs, Rd) 0
p = qr_rank(R; rank_atol)

ΔQ, ΔR = ΔQR

Q1 = view(Q, :, 1:p)
Q2 = view(Q, :, (p + 1):size(Q, 2))
R11 = view(R, 1:p, 1:p)
ΔA1 = view(ΔA, :, 1:p)
ΔA2 = view(ΔA, :, (p + 1):n)

check_qr_cotangents(Q, R, ΔQ, ΔR, minmn, p; gauge_atol)
check_qr_cotangents(Q, R, ΔQ, ΔR, p; gauge_atol)

ΔQ̃ = zero!(similar(Q, (m, p)))
if !iszerotangent(ΔQ)
Expand Down
29 changes: 0 additions & 29 deletions test/mooncake.jl

This file was deleted.

19 changes: 19 additions & 0 deletions test/mooncake/eig.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
using MatrixAlgebraKit
using Test
using LinearAlgebra: Diagonal
using CUDA, AMDGPU

# Element types exercised by this script (full suite is too expensive on CI).
BLASFloats = (Float32, ComplexF64)
GenericFloats = ()

# Load the shared test suite only if it is not already defined.
@isdefined(TestSuite) || include("../testsuite/TestSuite.jl")
using .TestSuite

# True only when the BUILDKITE environment variable is exactly "true".
is_buildkite = get(ENV, "BUILDKITE", "false") == "true"

# Size of the square test matrices.
m = 19
for T in (BLASFloats..., GenericFloats...)
    # The RNG is reseeded for every element type, even when the test is skipped.
    TestSuite.seed_rng!(123)
    is_buildkite && continue
    tol = m * m * TestSuite.precision(T)
    TestSuite.test_mooncake_eig(T, (m, m); atol = tol, rtol = tol)
end
19 changes: 19 additions & 0 deletions test/mooncake/eigh.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
using MatrixAlgebraKit
using Test
using LinearAlgebra: Diagonal
using CUDA, AMDGPU

# Element types exercised by this script (full suite is too expensive on CI).
BLASFloats = (Float32, ComplexF64)
GenericFloats = ()

# Load the shared test suite only if it is not already defined.
@isdefined(TestSuite) || include("../testsuite/TestSuite.jl")
using .TestSuite

# True only when the BUILDKITE environment variable is exactly "true".
is_buildkite = get(ENV, "BUILDKITE", "false") == "true"

# Size of the square test matrices.
m = 19
for T in (BLASFloats..., GenericFloats...)
    # The RNG is reseeded for every element type, even when the test is skipped.
    TestSuite.seed_rng!(123)
    is_buildkite && continue
    tol = m * m * TestSuite.precision(T)
    TestSuite.test_mooncake_eigh(T, (m, m); atol = tol, rtol = tol)
end
19 changes: 19 additions & 0 deletions test/mooncake/lq.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
using MatrixAlgebraKit
using Test
using LinearAlgebra: Diagonal
using CUDA, AMDGPU

# Element types exercised by this script (full suite is too expensive on CI).
BLASFloats = (Float32, ComplexF64)
GenericFloats = ()

# Load the shared test suite only if it is not already defined.
@isdefined(TestSuite) || include("../testsuite/TestSuite.jl")
using .TestSuite

# True only when the BUILDKITE environment variable is exactly "true".
is_buildkite = get(ENV, "BUILDKITE", "false") == "true"

# Fixed first dimension; the second dimension takes a value below, equal to,
# and above it.
m = 19
for T in (BLASFloats..., GenericFloats...), n in (17, m, 23)
    # The RNG is reseeded for every (T, n) pair, even when the test is skipped.
    TestSuite.seed_rng!(123)
    is_buildkite && continue
    tol = m * n * TestSuite.precision(T)
    TestSuite.test_mooncake_lq(T, (m, n); atol = tol, rtol = tol)
end
19 changes: 19 additions & 0 deletions test/mooncake/orthnull.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
using MatrixAlgebraKit
using Test
using LinearAlgebra: Diagonal
using CUDA, AMDGPU

# Element types exercised by this script (full suite is too expensive on CI).
BLASFloats = (Float32, ComplexF64)
GenericFloats = ()

# Load the shared test suite only if it is not already defined.
@isdefined(TestSuite) || include("../testsuite/TestSuite.jl")
using .TestSuite

# True only when the BUILDKITE environment variable is exactly "true".
is_buildkite = get(ENV, "BUILDKITE", "false") == "true"

# Fixed first dimension; the second dimension takes a value below, equal to,
# and above it.
m = 19
for T in (BLASFloats..., GenericFloats...), n in (17, m, 23)
    # The RNG is reseeded for every (T, n) pair, even when the test is skipped.
    TestSuite.seed_rng!(123)
    is_buildkite && continue
    tol = m * n * TestSuite.precision(T)
    TestSuite.test_mooncake_orthnull(T, (m, n); atol = tol, rtol = tol)
end
21 changes: 21 additions & 0 deletions test/mooncake/polar.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
using MatrixAlgebraKit
using Test
using LinearAlgebra: Diagonal
using CUDA, AMDGPU

# Element types exercised by this script (full suite is too expensive on CI).
BLASFloats = (Float32, ComplexF64)
GenericFloats = ()

# Load the shared test suite only if it is not already defined.
@isdefined(TestSuite) || include("../testsuite/TestSuite.jl")
using .TestSuite

# True only when the BUILDKITE environment variable is exactly "true".
is_buildkite = get(ENV, "BUILDKITE", "false") == "true"

# Fixed first dimension; the second dimension takes a value below, equal to,
# and above it.
m = 19
for T in (BLASFloats..., GenericFloats...), n in (17, m, 23)
    # The RNG is reseeded for every (T, n) pair, even when the tests are skipped.
    TestSuite.seed_rng!(123)
    is_buildkite && continue
    tol = m * n * TestSuite.precision(T)
    # The left polar test runs only when m >= n, the right one only when
    # n >= m; the square case runs both.
    if m >= n
        TestSuite.test_mooncake_left_polar(T, (m, n); atol = tol, rtol = tol)
    end
    if n >= m
        TestSuite.test_mooncake_right_polar(T, (m, n); atol = tol, rtol = tol)
    end
end
19 changes: 19 additions & 0 deletions test/mooncake/qr.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
using MatrixAlgebraKit
using Test
using LinearAlgebra: Diagonal
using CUDA, AMDGPU

# Element types exercised by this script (full suite is too expensive on CI).
BLASFloats = (Float32, ComplexF64)
GenericFloats = ()

# Load the shared test suite only if it is not already defined.
@isdefined(TestSuite) || include("../testsuite/TestSuite.jl")
using .TestSuite

# True only when the BUILDKITE environment variable is exactly "true".
is_buildkite = get(ENV, "BUILDKITE", "false") == "true"

# Fixed first dimension; the second dimension takes a value below, equal to,
# and above it.
m = 19
for T in (BLASFloats..., GenericFloats...), n in (17, m, 23)
    # The RNG is reseeded for every (T, n) pair, even when the test is skipped.
    TestSuite.seed_rng!(123)
    is_buildkite && continue
    tol = m * n * TestSuite.precision(T)
    TestSuite.test_mooncake_qr(T, (m, n); atol = tol, rtol = tol)
end
19 changes: 19 additions & 0 deletions test/mooncake/svd.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
using MatrixAlgebraKit
using Test
using LinearAlgebra: Diagonal
using CUDA, AMDGPU

# Element types exercised by this script (full suite is too expensive on CI).
BLASFloats = (Float32, ComplexF64)
GenericFloats = ()

# Load the shared test suite only if it is not already defined.
@isdefined(TestSuite) || include("../testsuite/TestSuite.jl")
using .TestSuite

# True only when the BUILDKITE environment variable is exactly "true".
is_buildkite = get(ENV, "BUILDKITE", "false") == "true"

# Fixed first dimension; the second dimension takes a value below, equal to,
# and above it.
m = 19
for T in (BLASFloats..., GenericFloats...), n in (17, m, 23)
    # The RNG is reseeded for every (T, n) pair, even when the test is skipped.
    TestSuite.seed_rng!(123)
    is_buildkite && continue
    tol = m * n * TestSuite.precision(T)
    TestSuite.test_mooncake_svd(T, (m, n); atol = tol, rtol = tol)
end
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ if filter_tests!(testsuite, args)
is_apple_ci = Sys.isapple() && get(ENV, "CI", "false") == "true"
if is_apple_ci
delete!(testsuite, "enzyme")
delete!(testsuite, "mooncake")
filter!(p -> !startswith(first(p), "mooncake/"), testsuite)
delete!(testsuite, "chainrules")
end
Sys.iswindows() && delete!(testsuite, "enzyme")
Expand Down
Loading