Skip to content

Commit be04ae0

Browse files
committed
some tweaks
1 parent 49eddc5 commit be04ae0

File tree

3 files changed

+30
-47
lines changed

3 files changed

+30
-47
lines changed

src/abstractsparsearray.jl

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,21 +18,19 @@ using LinearAlgebra: LinearAlgebra
1818
const MIMEtextplain = MIME"text/plain"
1919

2020
@derive (T=AnyAbstractSparseArray,) begin
21-
Base.getindex(::T, ::Any...)
2221
Base.getindex(::T, ::Int...)
23-
Base.setindex!(::T, ::Any, ::Any...)
2422
Base.setindex!(::T, ::Any, ::Int...)
2523
Base.similar(::T, ::Type, ::Tuple{Vararg{Int}})
2624
Base.similar(::T, ::Type, ::Tuple{Base.OneTo,Vararg{Base.OneTo}})
27-
Base.copy(::T)
28-
Base.copy!(::AbstractArray, ::T)
29-
Base.copyto!(::AbstractArray, ::T)
25+
# Base.copy(::T)
26+
# Base.copy!(::AbstractArray, ::T)
27+
# Base.copyto!(::AbstractArray, ::T)
3028
Base.map(::Any, ::T...)
3129
Base.map!(::Any, ::AbstractArray, ::T...)
32-
Base.mapreduce(::Any, ::Any, ::T...; kwargs...)
33-
Base.reduce(::Any, ::T...; kwargs...)
34-
Base.all(::Function, ::T)
35-
Base.all(::T)
30+
# Base.mapreduce(::Any, ::Any, ::T...; kwargs...)
31+
# Base.reduce(::Any, ::T...; kwargs...)
32+
# Base.all(::Function, ::T)
33+
# Base.all(::T)
3634
Base.iszero(::T)
3735
Base.real(::T)
3836
Base.fill!(::T, ::Any)
@@ -44,7 +42,7 @@ const MIMEtextplain = MIME"text/plain"
4442
Base.cat(::T...; kwargs...)
4543
ArrayLayouts.MemoryLayout(::Type{<:T})
4644
LinearAlgebra.mul!(::AbstractMatrix, ::T, ::T, ::Number, ::Number)
47-
Base.show(::IO, ::MIMEtextplain, ::T)
45+
# Base.show(::IO, ::MIMEtextplain, ::T)
4846
end
4947

5048
function Base.replace_in_print_matrix(

src/abstractsparsearrayinterface.jl

Lines changed: 17 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -104,11 +104,10 @@ to_vec(x::AbstractArray) = vec(x)
104104

105105
# TODO: This may need to be defined in `sparsearraydok.jl`, after `SparseArrayDOK`
106106
# is defined. And/or define `default_type(::SparseArrayStyle, T::Type) = SparseArrayDOK{T}`.
107-
@interface ::AbstractSparseArrayInterface function Base.similar(
108-
a::AbstractArray, T::Type, size::Tuple{Vararg{Int}}
109-
)
110-
# TODO: Define `default_similartype` or something like that?
111-
return SparseArrayDOK{T}(size...)
107+
@interface I::AbstractSparseArrayInterface function Base.similar(
108+
::AbstractArray, ::Type{T}, ax
109+
) where {T}
110+
return similar(I, T, ax)
112111
end
113112

114113
using ArrayLayouts: ArrayLayouts, zero!
@@ -117,13 +116,11 @@ using ArrayLayouts: ArrayLayouts, zero!
117116
# and is useful for sparse array logic, since it can be used to empty
118117
# the sparse array storage.
119118
# We use a single function definition to minimize method ambiguities.
120-
@interface interface::AbstractSparseArrayInterface function ArrayLayouts.zero!(
121-
a::AbstractArray
119+
@interface interface::AbstractSparseArrayInterface function DerivableInterfaces.zero!(
120+
A::AbstractArray
122121
)
123-
# More generally, this codepath could be taking if `zero(eltype(a))`
124-
# is defined and the elements are immutable.
125-
f = eltype(a) <: Number ? Returns(zero(eltype(a))) : zero!
126-
return @interface interface map_stored!(f, a, a)
122+
storedvalues(A) .= zero!(storedvalues(A))
123+
return A
127124
end
128125

129126
# `f::typeof(norm)`, `op::typeof(max)` used by `norm`.
@@ -150,23 +147,6 @@ end
150147
return output
151148
end
152149

153-
abstract type AbstractSparseArrayStyle{N} <: Broadcast.AbstractArrayStyle{N} end
154-
155-
@derive (T=AbstractSparseArrayStyle,) begin
156-
Base.similar(::Broadcast.Broadcasted{<:T}, ::Type, ::Tuple)
157-
Base.copyto!(::AbstractArray, ::Broadcast.Broadcasted{<:T})
158-
end
159-
160-
struct SparseArrayStyle{N} <: AbstractSparseArrayStyle{N} end
161-
162-
SparseArrayStyle{M}(::Val{N}) where {M,N} = SparseArrayStyle{N}()
163-
164-
DerivableInterfaces.interface(::Type{<:AbstractSparseArrayStyle}) = SparseArrayInterface()
165-
166-
@interface ::AbstractSparseArrayInterface function Broadcast.BroadcastStyle(type::Type)
167-
return SparseArrayStyle{ndims(type)}()
168-
end
169-
170150
using ArrayLayouts: ArrayLayouts, MatMulMatAdd
171151

172152
abstract type AbstractSparseLayout <: ArrayLayouts.MemoryLayout end
@@ -190,19 +170,20 @@ using LinearAlgebra: LinearAlgebra, mul!
190170
@interface ::AbstractSparseArrayInterface function LinearAlgebra.mul!(
191171
C::AbstractArray, A::AbstractArray, B::AbstractArray, α::Number, β::Number
192172
)
193-
a_dest .*= β
173+
C .*= β
194174
β′ = one(Bool)
195-
for I1 in eachstoredindex(a1)
196-
for I2 in eachstoredindex(a2)
197-
I_dest = mul_indices(I1, I2)
198-
if !isnothing(I_dest)
199-
a_dest[I_dest] = mul!(a_dest[I_dest], a1[I1], a2[I2], α, β′)
200-
end
175+
for iA in eachstoredindex(A), iB in eachstoredindex(B)
176+
iC = mul_indices(iA, iB)
177+
if !isnothing(iC)
178+
C[iC] = mul!!(C[iC], A[iA], B[iB], α, β′)
201179
end
202180
end
203-
return a_dest
181+
return C
204182
end
205183

184+
mul!!(C, A, B, α, β) = mul!(C, A, B, α, β)
185+
mul!!(C::Number, A::Number, B::Number, α::Number, β::Number) = β * C + α * A * B
186+
206187
function ArrayLayouts.materialize!(
207188
m::MatMulMatAdd{<:AbstractSparseLayout,<:AbstractSparseLayout,<:AbstractSparseLayout}
208189
)

src/sparsearraydok.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,10 @@ function SparseArrayDOK{T}(size::Int...) where {T}
3535
return SparseArrayDOK{T,length(size)}(size...)
3636
end
3737

38+
function SparseArrayDOK{T}(::UndefInitializer, axes::Tuple) where {T}
39+
return SparseArrayDOK{T}(undef, Base.to_shape(axes))
40+
end
41+
3842
using DerivableInterfaces: @array_aliases
3943
# Define `SparseMatrixDOK`, `AnySparseArrayDOK`, etc.
4044
@array_aliases SparseArrayDOK
@@ -46,7 +50,7 @@ storedvalues(a::SparseArrayDOK) = values(storage(a))
4650
function isstored(a::SparseArrayDOK, I::Int...)
4751
return CartesianIndex(I) in keys(storage(a))
4852
end
49-
function eachstoredindex(a::SparseArrayDOK)
53+
function eachstoredindex(::IndexCartesian, a::SparseArrayDOK)
5054
return keys(storage(a))
5155
end
5256
function getstoredindex(a::SparseArrayDOK, I::Int...)

0 commit comments

Comments
 (0)