Skip to content

Commit 9dab665

Browse files
committed
Specialize indexing triangular matrices with BandIndex
1 parent b6d2155 commit 9dab665

File tree

2 files changed

+15
-1
lines changed

2 files changed

+15
-1
lines changed

stdlib/LinearAlgebra/src/dense.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ norm2(x::Union{Array{T},StridedVector{T}}) where {T<:BlasFloat} =
110110
# Conservative assessment of types that have zero(T) defined for themselves
111111
haszero(::Type) = false
112112
haszero(::Type{T}) where {T<:Number} = isconcretetype(T)
113-
@propagate_inbounds _zero(M::AbstractArray{T}, i, j) where {T} = haszero(T) ? zero(T) : zero(M[i,j])
113+
@propagate_inbounds _zero(M::AbstractArray{T}, inds...) where {T} = haszero(T) ? zero(T) : zero(M[inds...])
114114

115115
"""
116116
triu!(M, k::Integer)

stdlib/LinearAlgebra/src/triangular.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,20 @@ Base.isstored(A::UpperTriangular, i::Int, j::Int) =
236236
@propagate_inbounds getindex(A::UpperTriangular, i::Int, j::Int) =
237237
i <= j ? A.data[i,j] : _zero(A.data,j,i)
238238

239+
# these specialized getindex methods enable constant-propagation of the band
240+
Base.@constprop :aggressive @propagate_inbounds function getindex(A::UnitLowerTriangular{T}, b::BandIndex) where {T}
241+
b.band < 0 ? A.data[b] : ifelse(b.band == 0, oneunit(T), zero(T))
242+
end
243+
Base.@constprop :aggressive @propagate_inbounds function getindex(A::LowerTriangular, b::BandIndex)
244+
b.band <= 0 ? A.data[b] : _zero(A.data, b)
245+
end
246+
Base.@constprop :aggressive @propagate_inbounds function getindex(A::UnitUpperTriangular{T}, b::BandIndex) where {T}
247+
b.band > 0 ? A.data[b] : ifelse(b.band == 0, oneunit(T), zero(T))
248+
end
249+
Base.@constprop :aggressive @propagate_inbounds function getindex(A::UpperTriangular, b::BandIndex)
250+
b.band >= 0 ? A.data[b] : _zero(A.data, b)
251+
end
252+
239253
_zero_triangular_half_str(::Type{<:UpperOrUnitUpperTriangular}) = "lower"
240254
_zero_triangular_half_str(::Type{<:LowerOrUnitLowerTriangular}) = "upper"
241255

0 commit comments

Comments
 (0)