diff --git a/src/Spaces/PolynomialSpace.jl b/src/Spaces/PolynomialSpace.jl index e632fbc..d2d8f9b 100644 --- a/src/Spaces/PolynomialSpace.jl +++ b/src/Spaces/PolynomialSpace.jl @@ -151,20 +151,33 @@ See https://github.com/JuliaLinearAlgebra/BandedMatrices.jl/blob/master/LICENSE end _view(::Any, A, b) = view(A, b) -_view(::Val{true}, A::BandedMatrix, b) = dataview(view(A, b)) +function _view(::Val{true}, A::BandedMatrix, b::Band) + l, u = bandwidths(A) + -l <= b.i <= u || throw(ArgumentError("invalid band $b for bandwidths $((-l,u))")) + dataview(view(A, b)) +end -function _get_bands(B, C, bmk, f, ValBC) +function _get_bands(B, C, bmk, f, valB) Cbmk = _view(Val(true), C, band(bmk*f)) Bm = _view(Val(true), B, band(flipsign(bmk-1, f))) B0 = _view(Val(true), B, band(flipsign(bmk, f))) - Bp = _view(ValBC, B, band(flipsign(bmk+1, f))) + Bp = _view(valB, B, band(flipsign(bmk+1, f))) Cbmk, Bm, B0, Bp end -function _jac_gbmm!(α, J, B, β, C, b, (Cn, Cm), n, ValJ, ValBC) - Jp = _view(ValJ, J, band(1)) - J0 = _view(ValJ, J, band(0)) - Jm = _view(ValJ, J, band(-1)) +# Fast implementation of C[:,:] = α*J*B+β*C where the bandediwth of B is +# specified by b, not by the parameters in B +function jac_gbmm!(α, J, B, β, C, b, valB) + if β ≠ 1 + lmul!(β,C) + end + + n = size(J,1) + Cn, Cm = size(C) + + Jp = _view(Val(true), J, band(1)) + J0 = _view(Val(true), J, band(0)) + Jm = _view(Val(true), J, band(-1)) kr = intersect(-1:b-1, b-Cm+1:b-1+Cn) @@ -172,7 +185,7 @@ function _jac_gbmm!(α, J, B, β, C, b, (Cn, Cm), n, ValJ, ValBC) # this might also help with cache localization k = -1 if k in kr - Cbmk, Bm, B0, Bp = _get_bands(B, C, b-k, 1, ValBC) + Cbmk, Bm, B0, Bp = _get_bands(B, C, b-k, 1, valB) for i in 1:n-b+k Cbmk[i] += α * Bm[i+1] * Jp[i] end @@ -180,14 +193,14 @@ function _jac_gbmm!(α, J, B, β, C, b, (Cn, Cm), n, ValJ, ValBC) k = 0 if k in kr - Cbmk, Bm, B0, Bp = _get_bands(B, C, b-k, 1, Val(true)) + Cbmk, Bm, B0, Bp = _get_bands(B, C, b-k, 1, valB) for i in 1:n-b+k Cbmk[i] += α * (Bm[i+1] * Jp[i] + B0[i] * J0[i]) end end for k in max(1, first(kr)):last(kr) - Cbmk, Bm, B0, Bp = _get_bands(B, C, b-k, 1, Val(true)) + Cbmk, Bm, B0, Bp = _get_bands(B, C, b-k, 1, valB) Cbmk[1] += α * (Bm[2] * Jp[1] + B0[1] * J0[1]) for i in 2:n-b+k Cbmk[i] += α * (Bm[i+1] * Jp[i] + B0[i] * J0[i] + Bp[i-1] * Jm[i-1]) @@ -198,7 +211,7 @@ function _jac_gbmm!(α, J, B, β, C, b, (Cn, Cm), n, ValJ, ValBC) k = -1 if k in kr - Ckmb, Bp, B0, Bm = _get_bands(B, C, b-k, -1, ValBC) + Ckmb, Bp, B0, Bm = _get_bands(B, C, b-k, -1, valB) for (i, Ji) in enumerate(b-k:n-1) Ckmb[i] += α * Bp[i] * Jm[Ji] end @@ -206,7 +219,7 @@ function _jac_gbmm!(α, J, B, β, C, b, (Cn, Cm), n, ValJ, ValBC) k = 0 if k in kr - Ckmb, Bp, B0, Bm = _get_bands(B, C, b-k, -1, Val(true)) + Ckmb, Bp, B0, Bm = _get_bands(B, C, b-k, -1, valB) Ckmb[1] += α * Bp[1] * Jm[b-k] for (i, Ji) in enumerate(b-k+1:n-1) Ckmb[i] += α * B0[i] * J0[Ji] @@ -238,21 +251,6 @@ function _jac_gbmm!(α, J, B, β, C, b, (Cn, Cm), n, ValJ, ValBC) return C end -# Fast implementation of C[:,:] = α*J*B+β*C where the bandediwth of B is -# specified by b, not by the parameters in B -function jac_gbmm!(α, J, B, β, C, b, valJ, valBC) - if β ≠ 1 - lmul!(β,C) - end - - n = size(J,1) - Cn, Cm = size(C) - - _jac_gbmm!(α, J, B, β, C, b, (Cn, Cm), n, valJ, valBC) - - C -end - function BandedMatrix(S::SubOperator{T,ConcreteMultiplication{C,PS,T}, NTuple{2,UnitRange{Int}}}) where {PS<:PolynomialSpace,T,C<:PolynomialSpace} M=parent(S) @@ -285,7 +283,6 @@ function BandedMatrix(S::SubOperator{T,ConcreteMultiplication{C,PS,T}, #Multiplication is transpose J=Operator{T}(Recurrence(M.space))[jkr,jkr] - valJ = all(>=(1), bandwidths(J)) ? Val(true) : Val(false) B=n-1 # final bandwidth @@ -293,15 +290,16 @@ function BandedMatrix(S::SubOperator{T,ConcreteMultiplication{C,PS,T}, Bk2 = BandedMatrix(Zeros{T}(size(J)), (B,B)) dataview(view(Bk2, band(0))) .= a[n]/recβ(T,sp,n-1) α,β = recα(T,sp,n-1),recβ(T,sp,n-2) - Bk1 = (-α/β)*Bk2 + Bk1 = lmul!(-α/β, copy(Bk2)) dataview(view(Bk1, band(0))) .+= a[n-1]/β - jac_gbmm!(one(T)/β,J,Bk2,one(T),Bk1,0,valJ, Val(true)) + jac_gbmm!(one(T)/β,J,Bk2,one(T),Bk1,0, Val(true)) b=1 # we keep track of bandwidths manually to reuse memory for k=n-2:-1:2 + # b goes from 1: α,β,γ=recα(T,sp,k),recβ(T,sp,k-1),recγ(T,sp,k+1) lmul!(-γ/β,Bk2) dataview(view(Bk2, band(0))) .+= a[k]/β - jac_gbmm!(1/β,J,Bk1,one(T),Bk2,b,valJ,Val(true)) + jac_gbmm!(1/β,J,Bk1,one(T),Bk2,b,Val(true)) LinearAlgebra.axpy!(-α/β,Bk1,Bk2) Bk2,Bk1=Bk1,Bk2 b+=1 @@ -309,7 +307,7 @@ function BandedMatrix(S::SubOperator{T,ConcreteMultiplication{C,PS,T}, α,γ=recα(T,sp,1),recγ(T,sp,2) lmul!(-γ,Bk2) dataview(view(Bk2, band(0))) .+= a[1] - jac_gbmm!(one(T),J,Bk1,one(T),Bk2,b,valJ,Val(false)) + jac_gbmm!(one(T),J,Bk1,one(T),Bk2,b,Val(false)) LinearAlgebra.axpy!(-α,Bk1,Bk2) # relationship between jkr and kr, jr