Skip to content

Commit aa6870e

Browse files
committed
fix rebase conflicts
1 parent 66482df commit aa6870e

3 files changed

Lines changed: 13 additions & 21 deletions

File tree

ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ using LinearAlgebra: BlasFloat
1515

1616
include("yacusolver.jl")
1717

18-
MatrixAlgebraKit.default_householder_driver(::StridedCuMatrix{<:BlasFloat}) = CUSOLVER()
18+
MatrixAlgebraKit.default_householder_driver(::StridedCuVecOrMat{<:BlasFloat}) = CUSOLVER()
1919
function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {TT <: BlasFloat, T <: StridedCuVecOrMat{TT}}
2020
return CUSOLVER_QRIteration(; kwargs...)
2121
end

src/interface/lq.jl

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -72,18 +72,14 @@ default_lq_algorithm(A; kwargs...) = default_lq_algorithm(typeof(A); kwargs...)
7272

7373
default_lq_algorithm(T::Type; kwargs...) =
7474
throw(MethodError(default_lq_algorithm, (T,)))
75-
end
7675
default_lq_algorithm(::Type{T}; kwargs...) where {T <: AbstractMatrix} =
7776
Householder(; kwargs...)
78-
function default_lq_algorithm(::Type{T}; kwargs...) where {T <: Diagonal}
79-
return DiagonalAlgorithm(; kwargs...)
80-
end
81-
function default_lq_algorithm(::Type{<:Base.ReshapedArray{T, N, A}}) where {T, N, A}
82-
return default_lq_algorithm(A)
83-
end
84-
function default_lq_algorithm(::Type{SubArray{T, N, A}}) where {T, N, A}
85-
return default_lq_algorithm(A)
86-
end
77+
default_lq_algorithm(::Type{T}; kwargs...) where {T <: Diagonal} =
78+
DiagonalAlgorithm(; kwargs...)
79+
default_lq_algorithm(::Type{<:Base.ReshapedArray{T, N, A}}) where {T, N, A} =
80+
default_lq_algorithm(A)
81+
default_lq_algorithm(::Type{SubArray{T, N, A}}) where {T, N, A} =
82+
default_lq_algorithm(A)
8783

8884
for f in (:lq_full!, :lq_compact!, :lq_null!)
8985
@eval function default_algorithm(::typeof($f), ::Type{A}; kwargs...) where {A}

src/interface/qr.jl

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -72,18 +72,14 @@ default_qr_algorithm(A; kwargs...) = default_qr_algorithm(typeof(A); kwargs...)
7272

7373
default_qr_algorithm(T::Type; kwargs...) =
7474
throw(MethodError(default_qr_algorithm, (T,)))
75-
end
7675
default_qr_algorithm(::Type{T}; kwargs...) where {T <: AbstractMatrix} =
7776
Householder(; kwargs...)
78-
function default_qr_algorithm(::Type{T}; kwargs...) where {T <: Diagonal}
79-
return DiagonalAlgorithm(; kwargs...)
80-
end
81-
function default_qr_algorithm(::Type{<:Base.ReshapedArray{T, N, A}}) where {T, N, A}
82-
return default_qr_algorithm(A)
83-
end
84-
function default_qr_algorithm(::Type{SubArray{T, N, A}}) where {T, N, A}
85-
return default_qr_algorithm(A)
86-
end
77+
default_qr_algorithm(::Type{T}; kwargs...) where {T <: Diagonal} =
78+
DiagonalAlgorithm(; kwargs...)
79+
default_qr_algorithm(::Type{<:Base.ReshapedArray{T, N, A}}) where {T, N, A} =
80+
default_qr_algorithm(A)
81+
default_qr_algorithm(::Type{SubArray{T, N, A}}) where {T, N, A} =
82+
default_qr_algorithm(A)
8783

8884
for f in (:qr_full!, :qr_compact!, :qr_null!)
8985
@eval function default_algorithm(::typeof($f), ::Type{A}; kwargs...) where {A}

0 commit comments

Comments
 (0)