Skip to content

Commit eb22285

Browse files
committed
possibly fix implementation
1 parent 3e73f99 commit eb22285

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -793,11 +793,15 @@ for (f!, f, adj) in (
793793
arg = $f!(A, arg, Mooncake.primal(alg_dalg))
794794

795795
function $adj(::NoRData)
796-
dA .+= $f(darg)
797-
dA === darg || zero!(darg)
796+
$f!(darg)
797+
if dA !== darg
798+
dA .+= darg
799+
zero!(darg)
800+
end
798801
copy!(arg, argc)
799802
return ntuple(Returns(NoRData()), 4)
800803
end
804+
801805
return arg_darg, $adj
802806
end
803807

0 commit comments

Comments
 (0)