@@ -104,11 +104,10 @@ to_vec(x::AbstractArray) = vec(x)
104
104
105
105
# TODO : This may need to be defined in `sparsearraydok.jl`, after `SparseArrayDOK`
106
106
# is defined. And/or define `default_type(::SparseArrayStyle, T::Type) = SparseArrayDOK{T}`.
107
- @interface :: AbstractSparseArrayInterface function Base. similar (
108
- a:: AbstractArray , T:: Type , size:: Tuple{Vararg{Int}}
109
- )
110
- # TODO : Define `default_similartype` or something like that?
111
- return SparseArrayDOK {T} (size... )
107
+ @interface I:: AbstractSparseArrayInterface function Base. similar (
108
+ :: AbstractArray , :: Type{T} , ax
109
+ ) where {T}
110
+ return similar (I, T, ax)
112
111
end
113
112
114
113
using ArrayLayouts: ArrayLayouts, zero!
@@ -117,13 +116,11 @@ using ArrayLayouts: ArrayLayouts, zero!
117
116
# and is useful for sparse array logic, since it can be used to empty
118
117
# the sparse array storage.
119
118
# We use a single function definition to minimize method ambiguities.
120
- @interface interface:: AbstractSparseArrayInterface function ArrayLayouts . zero! (
121
- a :: AbstractArray
119
+ @interface interface:: AbstractSparseArrayInterface function DerivableInterfaces . zero! (
120
+ A :: AbstractArray
122
121
)
123
- # More generally, this codepath could be taking if `zero(eltype(a))`
124
- # is defined and the elements are immutable.
125
- f = eltype (a) <: Number ? Returns (zero (eltype (a))) : zero!
126
- return @interface interface map_stored! (f, a, a)
122
+ storedvalues (A) .= zero! (storedvalues (A))
123
+ return A
127
124
end
128
125
129
126
# `f::typeof(norm)`, `op::typeof(max)` used by `norm`.
150
147
return output
151
148
end
152
149
153
- abstract type AbstractSparseArrayStyle{N} <: Broadcast.AbstractArrayStyle{N} end
154
-
155
- @derive (T= AbstractSparseArrayStyle,) begin
156
- Base. similar (:: Broadcast.Broadcasted{<:T} , :: Type , :: Tuple )
157
- Base. copyto! (:: AbstractArray , :: Broadcast.Broadcasted{<:T} )
158
- end
159
-
160
- struct SparseArrayStyle{N} <: AbstractSparseArrayStyle{N} end
161
-
162
- SparseArrayStyle {M} (:: Val{N} ) where {M,N} = SparseArrayStyle {N} ()
163
-
164
- DerivableInterfaces. interface (:: Type{<:AbstractSparseArrayStyle} ) = SparseArrayInterface ()
165
-
166
- @interface :: AbstractSparseArrayInterface function Broadcast. BroadcastStyle (type:: Type )
167
- return SparseArrayStyle {ndims(type)} ()
168
- end
169
-
170
150
using ArrayLayouts: ArrayLayouts, MatMulMatAdd
171
151
172
152
abstract type AbstractSparseLayout <: ArrayLayouts.MemoryLayout end
@@ -190,19 +170,20 @@ using LinearAlgebra: LinearAlgebra, mul!
190
170
@interface :: AbstractSparseArrayInterface function LinearAlgebra. mul! (
191
171
C:: AbstractArray , A:: AbstractArray , B:: AbstractArray , α:: Number , β:: Number
192
172
)
193
- a_dest .*= β
173
+ C .*= β
194
174
β′ = one (Bool)
195
- for I1 in eachstoredindex (a1)
196
- for I2 in eachstoredindex (a2)
197
- I_dest = mul_indices (I1, I2)
198
- if ! isnothing (I_dest)
199
- a_dest[I_dest] = mul! (a_dest[I_dest], a1[I1], a2[I2], α, β′)
200
- end
175
+ for iA in eachstoredindex (A), iB in eachstoredindex (B)
176
+ iC = mul_indices (iA, iB)
177
+ if ! isnothing (iC)
178
+ C[iC] = mul!! (C[iC], A[iA], B[iB], α, β′)
201
179
end
202
180
end
203
- return a_dest
181
+ return C
204
182
end
205
183
184
+ mul!! (C, A, B, α, β) = mul! (C, A, B, α, β)
185
+ mul!! (C:: Number , A:: Number , B:: Number , α:: Number , β:: Number ) = β * C + α * A * B
186
+
206
187
function ArrayLayouts. materialize! (
207
188
m:: MatMulMatAdd{<:AbstractSparseLayout,<:AbstractSparseLayout,<:AbstractSparseLayout}
208
189
)
0 commit comments