diff --git a/Project.toml b/Project.toml index d20a13e..d61d384 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "SparseArraysBase" uuid = "0d5efcca-f356-4864-8770-e1ed8d78f208" authors = ["ITensor developers and contributors"] -version = "0.5.4" +version = "0.5.5" [deps] Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" diff --git a/README.md b/README.md index 956ac7b..dc22266 100644 --- a/README.md +++ b/README.md @@ -89,7 +89,7 @@ using Dictionaries: IndexError @test isstored(a, 1, 2) @test setstoredindex!(copy(a), 21, 1, 2) == [0 21; 0 0] @test_throws IndexError setstoredindex!(copy(a), 21, 2, 1) -@test setunstoredindex!(copy(a), 21, 1, 2) == [0 21; 0 0] +@test_throws IndexError setunstoredindex!(copy(a), 21, 1, 2) == [0 21; 0 0] @test storedlength(a) == 1 @test issetequal(storedpairs(a), [CartesianIndex(1, 2) => 12]) @test issetequal(storedvalues(a), [12]) diff --git a/examples/README.jl b/examples/README.jl index 4198c08..bea4830 100644 --- a/examples/README.jl +++ b/examples/README.jl @@ -86,7 +86,7 @@ using Dictionaries: IndexError @test isstored(a, 1, 2) @test setstoredindex!(copy(a), 21, 1, 2) == [0 21; 0 0] @test_throws IndexError setstoredindex!(copy(a), 21, 2, 1) -@test setunstoredindex!(copy(a), 21, 1, 2) == [0 21; 0 0] +@test_throws IndexError setunstoredindex!(copy(a), 21, 1, 2) == [0 21; 0 0] @test storedlength(a) == 1 @test issetequal(storedpairs(a), [CartesianIndex(1, 2) => 12]) @test issetequal(storedvalues(a), [12]) diff --git a/src/SparseArraysBase.jl b/src/SparseArraysBase.jl index 9d0ea77..0e0db90 100644 --- a/src/SparseArraysBase.jl +++ b/src/SparseArraysBase.jl @@ -19,6 +19,7 @@ export SparseArrayDOK, include("abstractsparsearrayinterface.jl") include("sparsearrayinterface.jl") +include("indexing.jl") include("wrappers.jl") include("abstractsparsearray.jl") include("sparsearraydok.jl") diff --git a/src/abstractsparsearrayinterface.jl b/src/abstractsparsearrayinterface.jl index c28a3e9..5ad5d0b 100644 --- a/src/abstractsparsearrayinterface.jl +++ b/src/abstractsparsearrayinterface.jl @@ -32,87 +32,6 @@ end # Minimal interface for `SparseArrayInterface`. # Fallbacks for dense/non-sparse arrays. -function isstored(a::AbstractArray{<:Any,N}, I::Vararg{Int,N}) where {N} - @_propagate_inbounds_meta - @boundscheck checkbounds(a, I...) - return true -end -function isstored(a::AbstractArray, I::Int) - @_propagate_inbounds_meta - return isstored(a, Tuple(CartesianIndices(a)[I])...) -end -function isstored(a::AbstractArray, I::Int...) - @_propagate_inbounds_meta - @boundscheck checkbounds(a, I...) - I′ = ntuple(i -> I[i], ndims(a)) - return isstored(a, I′...) -end - -@interface ::AbstractArrayInterface eachstoredindex(a::AbstractArray) = eachindex(a) -@interface ::AbstractArrayInterface getstoredindex(a::AbstractArray, I::Int...) = - getindex(a, I...) -@interface ::AbstractArrayInterface function setstoredindex!( - a::AbstractArray, value, I::Int... -) - setindex!(a, value, I...) - return a -end -# TODO: Should this error by default if the value at the index -# is stored? It could be disabled with something analogous -# to `checkbounds`, like `checkstored`/`checkunstored`. -@interface ::AbstractArrayInterface function setunstoredindex!( - a::AbstractArray, value, I::Int... -) - # TODO: Make this a `MethodError`? - return error("Not implemented.") -end - -# TODO: Use `Base.to_indices`? -isstored(a::AbstractArray, I::CartesianIndex) = isstored(a, Tuple(I)...) -# TODO: Use `Base.to_indices`? -getstoredindex(a::AbstractArray, I::CartesianIndex) = getstoredindex(a, Tuple(I)...) -# TODO: Use `Base.to_indices`? -getunstoredindex(a::AbstractArray, I::CartesianIndex) = getunstoredindex(a, Tuple(I)...) -# TODO: Use `Base.to_indices`? -function setstoredindex!(a::AbstractArray, value, I::CartesianIndex) - return setstoredindex!(a, value, Tuple(I)...) -end -# TODO: Use `Base.to_indices`? -function setunstoredindex!(a::AbstractArray, value, I::CartesianIndex) - return setunstoredindex!(a, value, Tuple(I)...) -end - -# Interface defaults. -# TODO: Have a fallback that handles element types -# that don't define `zero(::Type)`. -@interface ::AbstractArrayInterface getunstoredindex(a::AbstractArray, I::Int...) = - zero(eltype(a)) - -# DerivableInterfacesd interface. -@interface ::AbstractArrayInterface storedlength(a::AbstractArray) = length(storedvalues(a)) -@interface ::AbstractArrayInterface storedpairs(a::AbstractArray) = - map(I -> I => getstoredindex(a, I), eachstoredindex(a)) - -@interface ::AbstractArrayInterface function eachstoredindex(as::AbstractArray...) - return eachindex(as...) -end - -@interface ::AbstractArrayInterface storedvalues(a::AbstractArray) = a - -# Automatically derive the interface for all `AbstractArray` subtypes. -# TODO: Define `SparseArrayInterfaceOps` derivable trait and rewrite this -# as `@derive AbstractArray SparseArrayInterfaceOps`. -@derive (T=AbstractArray,) begin - SparseArraysBase.eachstoredindex(::T) - SparseArraysBase.eachstoredindex(::T...) - SparseArraysBase.getstoredindex(::T, ::Int...) - SparseArraysBase.getunstoredindex(::T, ::Int...) - SparseArraysBase.setstoredindex!(::T, ::Any, ::Int...) - SparseArraysBase.setunstoredindex!(::T, ::Any, ::Int...) - SparseArraysBase.storedlength(::T) - SparseArraysBase.storedpairs(::T) - SparseArraysBase.storedvalues(::T) -end # TODO: Add `ndims` type parameter, like `Base.Broadcast.AbstractArrayStyle`. # TODO: This isn't used to define interface functions right now. @@ -160,56 +79,9 @@ struct StoredValues{T,A<:AbstractArray{T},I} <: AbstractVector{T} end StoredValues(a::AbstractArray) = StoredValues(a, to_vec(eachstoredindex(a))) Base.size(a::StoredValues) = size(a.storedindices) -Base.getindex(a::StoredValues, I::Int) = getstoredindex(a.array, a.storedindices[I]) -function Base.setindex!(a::StoredValues, value, I::Int) - return setstoredindex!(a.array, value, a.storedindices[I]) -end - -@interface ::AbstractSparseArrayInterface function isstored( - a::AbstractArray{<:Any,N}, I::Vararg{Int,N} -) where {N} - return CartesianIndex(I) in eachstoredindex(a) -end - -@interface ::AbstractSparseArrayInterface storedvalues(a::AbstractArray) = StoredValues(a) - -@interface ::AbstractSparseArrayInterface function eachstoredindex( - a1::AbstractArray, a2::AbstractArray, a_rest::AbstractArray... -) - # TODO: Make this more customizable, say with a function - # `combine/promote_storedindices(a1, a2)`. - return union(eachstoredindex.((a1, a2, a_rest...))...) -end - -@interface ::AbstractSparseArrayInterface function eachstoredindex(a::AbstractArray) - # TODO: Use `MethodError`? - return error("Not implemented.") -end - -# We restrict to `I::Vararg{Int,N}` to allow more general functions to handle trailing -# indices and linear indices. -@interface ::AbstractSparseArrayInterface function Base.getindex( - a::AbstractArray{<:Any,N}, I::Vararg{Int,N} -) where {N} - !isstored(a, I...) && return getunstoredindex(a, I...) - return getstoredindex(a, I...) -end - -# We restrict to `I::Vararg{Int,N}` to allow more general functions to handle trailing -# indices and linear indices. -@interface ::AbstractSparseArrayInterface function Base.setindex!( - a::AbstractArray{<:Any,N}, value, I::Vararg{Int,N} -) where {N} - if !isstored(a, I...) - # Don't set the value if it is zero, but only check - # if it is zero if the elements are numbers since otherwise - # it may be nontrivial to check. - eltype(a) <: Number && iszero(value) && return a - setunstoredindex!(a, value, I...) - return a - end - setstoredindex!(a, value, I...) - return a +@inline Base.getindex(a::StoredValues, I::Int) = getindex(a.array, a.storedindices[I]) +@inline function Base.setindex!(a::StoredValues, value, I::Int) + return setindex!(a.array, value, a.storedindices[I]) end # TODO: This may need to be defined in `sparsearraydok.jl`, after `SparseArrayDOK` diff --git a/src/indexing.jl b/src/indexing.jl new file mode 100644 index 0000000..06a6392 --- /dev/null +++ b/src/indexing.jl @@ -0,0 +1,408 @@ +using Base: @_propagate_inbounds_meta + +# Indexing interface +# ------------------ +# these definitions are not using @derive since we need the @inline annotation +# to correctly deal with boundschecks and @inbounds + +""" + getstoredindex(A::AbstractArray, I...) -> eltype(A) + +Obtain `getindex(A, I...)` with the guarantee that there is a stored entry at that location. + +Similar to `Base.getindex`, new definitions should be in line with `IndexStyle(A)`. +""" +@inline getstoredindex(A::AbstractArray, I...) = + @interface interface(A) getstoredindex(A, I...) + +""" + getunstoredindex(A::AbstractArray, I...) -> eltype(A) + +Obtain the value that would be returned by `getindex(A, I...)` when there is no stored entry +at that location. +By default, this takes an explicit copy of the `getindex` implementation to mimick a newly +instantiated object. + +Similar to `Base.getindex`, new definitions should be in line with `IndexStyle(A)`. +""" +@inline getunstoredindex(A::AbstractArray, I...) = + @interface interface(A) getunstoredindex(A, I...) + +""" + isstored(A::AbstractArray, I...) -> Bool + +Check if the array `A` has a stored entry at the location specified by indices `I...`. +For generic array types this defaults to `true` whenever the indices are inbounds, but +sparse array types might overload this function when appropriate. + +Similar to `Base.getindex`, new definitions should be in line with `IndexStyle(A)`. +""" +@inline isstored(A::AbstractArray, I...) = @interface interface(A) isstored(A, I...) + +""" + setstoredindex!(A::AbstractArray, v, I...) -> A + +`setindex!(A, v, I...)` with the guarantee that there is a stored entry at the given location. + +Similar to `Base.setindex!`, new definitions should be in line with `IndexStyle(A)`. +""" +@inline setstoredindex!(A::AbstractArray, v, I...) = + @interface interface(A) setstoredindex!(A, v, I...) + +""" + setunstoredindex!(A::AbstractArray, v, I...) -> A + +`setindex!(A, v, I...)` with the guarantee that there is no stored entry at the given location. + +Similar to `Base.setindex!`, new definitions should be in line with `IndexStyle(A)`. +""" +@inline setunstoredindex!(A::AbstractArray, v, I...) = + @interface interface(A) setunstoredindex!(A, v, I...) + +# Indices interface +# ----------------- +""" + eachstoredindex(A::AbstractArray...) + eachstoredindex(style::IndexStyle, A::AbstractArray...) + +An iterable over all indices of the stored values. +For multiple arrays, the iterable contains all indices where at least one input has a stored value. +The type of indices can be controlled through `style`, which will default to a compatible style for all +inputs. + +The order of the iterable is not guaranteed to be fixed or sorted, and should not be assumed +to be the same as [`storedvalues`](@ref). + +See also [`storedvalues`](@ref), [`storedpairs`](@ref) and [`storedlength`](@ref). +""" +function eachstoredindex end + +""" + storedlength(A::AbstractArray) -> Int + +The number of values that are currently being stored. +""" +function storedlength end + +""" + storedpairs(A::AbstractArray) -> (k, v)... + +An iterable over all stored indices and their corresponding stored values. +The indices are compatible with `IndexStyle(A)`. + +The order of the iterable is not guaranteed to be fixed or sorted. +See also [`eachstoredindex`](@ref) and [`storedvalues`](@ref). +""" +function storedpairs end + +""" + storedvalues(A::AbstractArray) -> v... + +An iterable over all stored values. + +The order of the iterable is not guaranteed to be fixed or sorted, and should not be assumed +to be the same as [`eachstoredindex`](@ref). +""" +function storedvalues end + +@derive (T=AbstractArray,) begin + SparseArraysBase.eachstoredindex(::T...) + SparseArraysBase.eachstoredindex(::IndexStyle, ::T...) + SparseArraysBase.storedlength(::T) + SparseArraysBase.storedpairs(::T) + SparseArraysBase.storedvalues(::T) +end + +# canonical indexing +# ------------------ +# ensure functions only have to be defined in terms of a single canonical f: +# f(::AbstractArray, I::Int) if IndexLinear +# f(::AbstractArray{<:Any,N}, I::Vararg{Int,N}) if IndexCartesian + +for f in (:isstored, :getunstoredindex, :getstoredindex) + _f = Symbol(:_, f) + error_if_canonical = Symbol(:error_if_canonical_, f) + @eval begin + @interface ::AbstractArrayInterface function $f(A::AbstractArray, I...) + @_propagate_inbounds_meta + style = IndexStyle(A) + $error_if_canonical(style, A, I...) + return $_f(style, A, Base.to_indices(A, I)...) + end + + # linear indexing + @inline $_f(::IndexLinear, A::AbstractVector, i::Int) = $f(A, i) + @inline $_f(::IndexLinear, A::AbstractArray, i::Int) = $f(A, i) + @inline function $_f(::IndexLinear, A::AbstractArray, I::Vararg{Int,M}) where {M} + @boundscheck checkbounds(A, I...) + return @inbounds $f(A, Base._to_linear_index(A, I...)) + end + + # cartesian indexing + @inline function $_f(::IndexCartesian, A::AbstractArray, I::Vararg{Int,M}) where {M} + @boundscheck checkbounds(A, I...) + return @inbounds $f(A, Base._to_subscript_indices(A, I...)...) + end + @inline function $_f( + ::IndexCartesian, A::AbstractArray{<:Any,N}, I::Vararg{Int,N} + ) where {N} + return $f(A, I...) + end + + # errors + $_f(::IndexStyle, A::AbstractArray, I...) = + error("`$f` for $("$(typeof(A))") with types $("$(typeof(I))") is not supported") + + $error_if_canonical(::IndexLinear, A::AbstractArray, ::Int) = + throw(Base.CanonicalIndexError("$($f)", typeof(A))) + $error_if_canonical( + ::IndexCartesian, A::AbstractArray{<:Any,N}, ::Vararg{Int,N} + ) where {N} = throw(Base.CanonicalIndexError("$($f)", typeof(A))) + $error_if_canonical(::IndexStyle, A::AbstractArray, ::Any...) = nothing + end +end + +for f! in (:setunstoredindex!, :setstoredindex!) + _f! = Symbol(:_, f!) + error_if_canonical = Symbol(:error_if_canonical_, f!) + @eval begin + @interface ::AbstractArrayInterface function $f!(A::AbstractArray, v, I...) + @_propagate_inbounds_meta + style = IndexStyle(A) + $error_if_canonical(style, A, I...) + return $_f!(style, A, v, Base.to_indices(A, I)...) + end + + # linear indexing + @inline $_f!(::IndexLinear, A::AbstractVector, v, i::Int) = $f!(A, v, i) + @inline $_f!(::IndexLinear, A::AbstractArray, v, i::Int) = $f!(A, v, i) + @inline function $_f!(::IndexLinear, A::AbstractArray, v, I::Vararg{Int,M}) where {M} + @boundscheck checkbounds(A, I...) + return @inbounds $f!(A, v, Base._to_linear_index(A, I...)) + end + + # cartesian indexing + @inline function $_f!(::IndexCartesian, A::AbstractArray, v, I::Vararg{Int,M}) where {M} + @boundscheck checkbounds(A, I...) + return @inbounds $f!(A, v, Base._to_subscript_indices(A, I...)...) + end + @inline function $_f!( + ::IndexCartesian, A::AbstractArray{<:Any,N}, v, I::Vararg{Int,N} + ) where {N} + return $f!(A, v, I...) + end + + # errors + $_f!(::IndexStyle, A::AbstractArray, I...) = + error("`$f!` for $("$(typeof(A))") with types $("$(typeof(I))") is not supported") + + $error_if_canonical(::IndexLinear, A::AbstractArray, ::Int) = + throw(Base.CanonicalIndexError("$($(string(f!)))", typeof(A))) + $error_if_canonical( + ::IndexCartesian, A::AbstractArray{<:Any,N}, ::Vararg{Int,N} + ) where {N} = throw(Base.CanonicalIndexError("$($f!)", typeof(A))) + $error_if_canonical(::IndexStyle, A::AbstractArray, ::Any...) = nothing + end +end + +# AbstractArrayInterface fallback definitions +# ------------------------------------------- +@interface ::AbstractArrayInterface function isstored(A::AbstractArray, i::Int, I::Int...) + @inline + @boundscheck checkbounds(A, i, I...) + return true +end + +@interface ::AbstractArrayInterface function getunstoredindex(A::AbstractArray, I::Int...) + @inline + @boundscheck checkbounds(A, I...) + return zero(eltype(A)) +end +@interface ::AbstractArrayInterface function getstoredindex(A::AbstractArray, I::Int...) + @inline + return getindex(A, I...) +end + +@interface ::AbstractArrayInterface function setstoredindex!(A::AbstractArray, v, I::Int...) + @inline + return setindex!(A, v, I...) +end +@interface ::AbstractArrayInterface setunstoredindex!(A::AbstractArray, v, I::Int...) = + error("setunstoredindex! for $(typeof(A)) is not supported") + +@interface ::AbstractArrayInterface eachstoredindex(A::AbstractArray, B::AbstractArray...) = + eachstoredindex(IndexStyle(A, B...), A, B...) +@interface ::AbstractArrayInterface eachstoredindex( + style::IndexStyle, A::AbstractArray, B::AbstractArray... +) = eachindex(style, A, B...) + +@interface ::AbstractArrayInterface storedvalues(A::AbstractArray) = values(A) +@interface ::AbstractArrayInterface storedpairs(A::AbstractArray) = pairs(A) +@interface ::AbstractArrayInterface storedlength(A::AbstractArray) = length(storedvalues(A)) + +# SparseArrayInterface implementations +# ------------------------------------ +# canonical errors are moved to `isstored`, `getstoredindex` and `getunstoredindex` +# so no errors at this level by defining both IndexLinear and IndexCartesian +@interface ::AbstractSparseArrayInterface function Base.getindex( + A::AbstractArray{<:Any,N}, I::Vararg{Int,N} +) where {N} + @_propagate_inbounds_meta + @boundscheck checkbounds(A, I...) # generally isstored requires bounds checking + return @inbounds isstored(A, I...) ? getstoredindex(A, I...) : getunstoredindex(A, I...) +end +@interface ::AbstractSparseArrayInterface function Base.getindex(A::AbstractArray, I::Int) + @_propagate_inbounds_meta + @boundscheck checkbounds(A, I) + return @inbounds isstored(A, I) ? getstoredindex(A, I) : getunstoredindex(A, I) +end +# disambiguate vectors +@interface ::AbstractSparseArrayInterface function Base.getindex(A::AbstractVector, I::Int) + @_propagate_inbounds_meta + @boundscheck checkbounds(A, I) + return @inbounds isstored(A, I) ? getstoredindex(A, I) : getunstoredindex(A, I) +end + +@interface ::AbstractSparseArrayInterface function Base.setindex!( + A::AbstractArray{<:Any,N}, v, I::Vararg{Int,N} +) where {N} + @_propagate_inbounds_meta + @boundscheck checkbounds(A, I...) + return @inbounds if isstored(A, I...) + setstoredindex!(A, v, I...) + else + setunstoredindex!(A, v, I...) + end +end +@interface ::AbstractSparseArrayInterface function Base.setindex!( + A::AbstractArray, v, I::Int +) + @_propagate_inbounds_meta + @boundscheck checkbounds(A, I) + return @inbounds if isstored(A, I) + setstoredindex!(A, v, I) + else + setunstoredindex!(A, v, I) + end +end +# disambiguate vectors +@interface ::AbstractSparseArrayInterface function Base.setindex!( + A::AbstractVector, v, I::Int +) + @_propagate_inbounds_meta + @boundscheck checkbounds(A, I) + return @inbounds if isstored(A, I) + setstoredindex!(A, v, I) + else + setunstoredindex!(A, v, I) + end +end + +# required: +@interface ::AbstractSparseArrayInterface eachstoredindex( + style::IndexStyle, A::AbstractArray +) = throw(MethodError(eachstoredindex, Tuple{typeof(style),typeof(A)})) + +# derived but may be specialized: +@interface ::AbstractSparseArrayInterface function eachstoredindex( + style::IndexStyle, A::AbstractArray, B::AbstractArray... +) + return union(map(Base.Fix1(eachstoredindex, style), (A, B...))...) +end + +@interface ::AbstractSparseArrayInterface storedvalues(A::AbstractArray) = StoredValues(A) + +# default implementation is a bit tricky here: we don't know if this is the "canonical" +# implementation, so we check this and otherwise map back to `_isstored` to canonicalize the +# indices +@interface ::AbstractSparseArrayInterface function isstored(A::AbstractArray, I::Int...) + @_propagate_inbounds_meta + style = IndexStyle(A) + # canonical linear indexing + if style == IndexLinear() && length(I) == 1 + @boundscheck checkbounds(A, I...) + return only(I) in eachstoredindex(style, A) + end + + # canonical cartesian indexing + if style == IndexCartesian() && length(I) == ndims(A) + @boundscheck checkbounds(A, I...) + return CartesianIndex(I...) in eachstoredindex(style, A) + end + + # non-canonical indexing + return _isstored(style, A, Base.to_indices(A, I)...) +end + +@interface ::AbstractSparseArrayInterface function getunstoredindex( + A::AbstractArray, I::Int... +) + @_propagate_inbounds_meta + style = IndexStyle(A) + + # canonical linear indexing + if style == IndexLinear() && length(I) == 1 + @boundscheck checkbounds(A, I...) + return zero(eltype(A)) + end + + # canonical cartesian indexing + if style == IndexCartesian() && length(I) == ndims(A) + @boundscheck checkbounds(A, I...) + return zero(eltype(A)) + end + + # non-canonical indexing + return _getunstoredindex(style, A, Base.to_indices(A, I)...) +end + +# make sure we don't call AbstractArrayInterface defaults +@interface ::AbstractSparseArrayInterface function getstoredindex( + A::AbstractArray, I::Int... +) + @_propagate_inbounds_meta + style = IndexStyle(A) + error_if_canonical_getstoredindex(style, A, I...) + return _getstoredindex(style, A, Base.to_indices(A, I)...) +end + +for f! in (:setstoredindex!, :setunstoredindex!) + _f! = Symbol(:_, f!) + error_if_canonical_setstoredindex = Symbol(:error_if_canonical_, f!) + @eval begin + @interface ::AbstractSparseArrayInterface function $f!(A::AbstractArray, v, I::Int...) + @_propagate_inbounds_meta + style = IndexStyle(A) + $error_if_canonical_setstoredindex(style, A, I...) + return $_f!(style, A, v, Base.to_indices(A, I)...) + end + end +end + +@interface ::AbstractSparseArrayInterface storedlength(A::AbstractArray) = + length(storedvalues(A)) +@interface ::AbstractSparseArrayInterface function storedpairs(A::AbstractArray) + return Iterators.map(I -> (I => A[I]), eachstoredindex(A)) +end + +#= +All sparse array interfaces are mapped through layout_getindex. (is this too opinionated?) + +using ArrayLayouts getindex: this is a bit cumbersome because there already is a way to make +that work focused on types but here we want to focus on interfaces. +eg: ArrayLayouts.@layoutgetindex ArrayType +TODO: decide if we need the interface approach at all here +=# +for (Tr, Tc) in Iterators.product( + Iterators.repeated((:Colon, :AbstractUnitRange, :AbstractVector, :Integer), 2)... +) + Tr === Tc === :Integer && continue + @eval begin + @interface ::AbstractSparseArrayInterface function Base.getindex( + A::AbstractMatrix, kr::$Tr, jr::$Tc + ) + Base.@inline # needed to make boundschecks work + return ArrayLayouts.layout_getindex(A, kr, jr) + end + end +end diff --git a/src/sparsearraydok.jl b/src/sparsearraydok.jl index 983949f..7365127 100644 --- a/src/sparsearraydok.jl +++ b/src/sparsearraydok.jl @@ -74,28 +74,36 @@ storage(a::SparseArrayDOK) = a.storage Base.size(a::SparseArrayDOK) = a.size storedvalues(a::SparseArrayDOK) = values(storage(a)) -function isstored(a::SparseArrayDOK{<:Any,N}, I::Vararg{Int,N}) where {N} - return @interface interface(a) isstored(a, I...) +@inline function isstored(a::SparseArrayDOK{<:Any,N}, I::Vararg{Int,N}) where {N} + @boundscheck checkbounds(a, I...) + return haskey(storage(a), CartesianIndex(I)) end -function eachstoredindex(a::SparseArrayDOK) +function eachstoredindex(::IndexCartesian, a::SparseArrayDOK) return keys(storage(a)) end -function getstoredindex(a::SparseArrayDOK, I::Int...) +@inline function getstoredindex(a::SparseArrayDOK{<:Any,N}, I::Vararg{Int,N}) where {N} + @boundscheck checkbounds(a, I...) return storage(a)[CartesianIndex(I)] end -function getunstoredindex(a::SparseArrayDOK, I::Int...) +@inline function getunstoredindex(a::SparseArrayDOK{<:Any,N}, I::Vararg{Int,N}) where {N} + @boundscheck checkbounds(a, I...) return a.getunstored(a, I...) end -function setstoredindex!(a::SparseArrayDOK, value, I::Int...) - # TODO: Have a way to disable this check, analogous to `checkbounds`, - # since this is already checked in `setindex!`. - isstored(a, I...) || throw(IndexError("key $(CartesianIndex(I)) not found")) +@inline function setstoredindex!( + a::SparseArrayDOK{<:Any,N}, value, I::Vararg{Int,N} +) where {N} + # `isstored` includes a boundscheck as well + @boundscheck isstored(a, I...) || + throw(IndexError(lazy"key $(CartesianIndex(I...)) not found")) # TODO: If `iszero(value)`, unstore the index. storage(a)[CartesianIndex(I)] = value return a end -function setunstoredindex!(a::SparseArrayDOK, value, I::Int...) - set!(storage(a), CartesianIndex(I), value) +@inline function setunstoredindex!( + a::SparseArrayDOK{<:Any,N}, value, I::Vararg{Int,N} +) where {N} + @boundscheck checkbounds(a, I...) + insert!(storage(a), CartesianIndex(I), value) return a end diff --git a/src/wrappers.jl b/src/wrappers.jl index 23b4286..bd9e481 100644 --- a/src/wrappers.jl +++ b/src/wrappers.jl @@ -1,6 +1,9 @@ parentvalue_to_value(a::AbstractArray, value) = value value_to_parentvalue(a::AbstractArray, value) = value eachstoredparentindex(a::AbstractArray) = eachstoredindex(parent(a)) +function eachstoredparentindex(style::IndexStyle, a::AbstractArray) + return eachstoredindex(style, parent(a)) +end storedparentvalues(a::AbstractArray) = storedvalues(parent(a)) function parentindex_to_index(a::AbstractArray{<:Any,N}, I::CartesianIndex{N}) where {N} @@ -22,7 +25,7 @@ function index_to_parentindex(a::AbstractArray{<:Any,N}, I::Vararg{Int,N}) where end # Handle linear indexing. function index_to_parentindex(a::AbstractArray, I::Int) - return index_to_parentindex(a, CartesianIndices(a)[I]) + return LinearIndices(parent(a))[index_to_parentindex(a, CartesianIndices(a)[I])] end function cartesianindex_reverse(I::CartesianIndex) @@ -76,6 +79,11 @@ function eachstoredparentindex(a::SubArray) return all(d -> I[d] ∈ parentindices(a)[d], 1:ndims(parent(a))) end end +function eachstoredparentindex(style::IndexStyle, a::SubArray) + return filter(eachstoredindex(style, parent(a))) do I + return all(d -> I[d] ∈ parentindices(a)[d], 1:ndims(parent(a))) + end +end # Don't constrain the number of dimensions of the array # and index since the parent array can have a different # number of dimensions than the `SubArray`. @@ -145,10 +153,13 @@ for type in (:Adjoint, :PermutedDimsArray, :ReshapedArray, :SubArray, :Transpose @eval begin @interface ::AbstractSparseArrayInterface storedvalues(a::$type) = storedparentvalues(a) @interface ::AbstractSparseArrayInterface function eachstoredindex(a::$type) + return map(Base.Fix1(parentindex_to_index, a), eachstoredparentindex(a)) + end + @interface ::AbstractSparseArrayInterface function eachstoredindex( + style::IndexStyle, a::$type + ) # TODO: Make lazy with `Iterators.map`. - return map(collect(eachstoredparentindex(a))) do I - return parentindex_to_index(a, I) - end + return map(Base.Fix1(parentindex_to_index, a), eachstoredparentindex(style, a)) end @interface ::AbstractSparseArrayInterface function getstoredindex(a::$type, I::Int...) return parentvalue_to_value( @@ -193,7 +204,7 @@ end @interface ::AbstractArrayInterface eachstoredindex(D::Diagonal) = _diagind(D, IndexCartesian()) -function isstored(D::Diagonal, i::Int, j::Int) +@interface ::AbstractArrayInterface function isstored(D::Diagonal, i::Int, j::Int) return i == j && checkbounds(Bool, D, i, j) end @interface ::AbstractArrayInterface function getstoredindex(D::Diagonal, i::Int, j::Int) diff --git a/test/test_basics.jl b/test/test_basics.jl index 3aed922..d9f7661 100644 --- a/test/test_basics.jl +++ b/test/test_basics.jl @@ -36,7 +36,7 @@ arrayts = (Array, JLArray) # probably we can have a trait for that. It could be based # on the `ArrayLayouts.MemoryLayout`. @allowscalar @test storedvalues(a) == a - @allowscalar @test storedpairs(a) == collect(pairs(vec(a))) + @allowscalar @test storedpairs(a) == pairs(a) @allowscalar for I in eachindex(a) @test getstoredindex(a, I) == a[I] @test iszero(getunstoredindex(a, I)) diff --git a/test/test_oneelementarray.jl b/test/test_oneelementarray.jl index 9fc5854..1cb9599 100644 --- a/test/test_oneelementarray.jl +++ b/test/test_oneelementarray.jl @@ -35,8 +35,8 @@ elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) @test a[2, 2] === zero(Bool) @test storedlength(a) == 1 @test collect(eachstoredindex(a)) == [CartesianIndex(1, 2)] - @test storedpairs(a) == [CartesianIndex(1, 2) => 1] - @test storedvalues(a) == [1] + @test collect(storedpairs(a)) == [CartesianIndex(1, 2) => 1] + @test collect(storedvalues(a)) == [1] end for a in (OneElementArray(1, 2), OneElementVector(1, 2)) @@ -47,8 +47,8 @@ elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) @test a[2] === zero(Bool) @test storedlength(a) == 1 @test collect(eachstoredindex(a)) == [CartesianIndex(1)] - @test storedpairs(a) == [CartesianIndex(1) => 1] - @test storedvalues(a) == [1] + @test collect(storedpairs(a)) == [CartesianIndex(1) => 1] + @test collect(storedvalues(a)) == [1] end a = OneElementArray() diff --git a/test/test_sparsearraydok.jl b/test/test_sparsearraydok.jl index 6616a4c..91f76d2 100644 --- a/test/test_sparsearraydok.jl +++ b/test/test_sparsearraydok.jl @@ -33,14 +33,18 @@ arrayts = (Array,) a[1, 2] = 12 @test a isa SparseArrayDOK{elt,2} @test size(a) == (2, 2) + @test a[1] == 0 @test a[1, 1] == 0 @test a[1, 1, 1] == 0 + @test a[3] == 12 @test a[1, 2] == 12 @test a[1, 2, 1] == 12 @test storedlength(a) == 1 + @test_throws BoundsError a[5] + @test_throws BoundsError a[1, 3] a = SparseArrayDOK{elt}(undef, 2, 2) - a[1, 2] = 12 + a[3] = 12 for b in (similar(a, Float32, (3, 3)), similar(a, Float32, Base.OneTo.((3, 3)))) @test b isa SparseArrayDOK{Float32,2} @test b == zeros(Float32, 3, 3) @@ -59,13 +63,15 @@ arrayts = (Array,) # isstored a = SparseArrayDOK{elt}(undef, 4, 4) a[2, 3] = 23 - for I in CartesianIndices(a) + for (I, i) in zip(CartesianIndices(a), LinearIndices(a)) if I == CartesianIndex(2, 3) @test isstored(a, I) @test isstored(a, Tuple(I)...) + @test isstored(a, i) else @test !isstored(a, I) @test !isstored(a, Tuple(I)...) + @test !isstored(a, i) end end @@ -83,12 +89,21 @@ arrayts = (Array,) end end + # vector + a = SparseArrayDOK{elt}(undef, 2) + a[2] = 12 + @test b[1] == 0 + @test a[2] == 12 + @test storedlength(a) == 1 + a = SparseArrayDOK{elt}(undef, 3, 3, 3) a[1, 2, 3] = 123 b = permutedims(a, (2, 3, 1)) @test b isa SparseArrayDOK{elt,3} @test b[2, 3, 1] == 123 @test storedlength(b) == 1 + @test b[1] == 0 + @test b[LinearIndices(b)[2, 3, 1]] == 123 a = SparseArrayDOK{elt}(undef, 2, 2) a[1, 2] = 12