Skip to content

Commit 2251dd4

Browse files
committed
Add map initial implementation
1 parent 5ffa937 commit 2251dd4

File tree

2 files changed

+127
-1
lines changed

2 files changed

+127
-1
lines changed

src/SparseArraysBase.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,13 @@ export SparseArrayDOK,
1515

1616
using DerivableInterfaces
1717

18-
include("indexing.jl")
18+
# sparsearrayinterface
1919
include("abstractsparsearrayinterface.jl")
20+
include("indexing.jl")
21+
include("map.jl")
2022
include("sparsearrayinterface.jl")
23+
24+
# types
2125
include("wrappers.jl")
2226
include("abstractsparsearray.jl")
2327
include("sparsearraydok.jl")

src/map.jl

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
# zero-preserving Traits
2+
# ----------------------
3+
"""
4+
abstract type ZeroPreserving end
5+
6+
Holy Trait to indicate how a function interacts with abstract zero values:
7+
8+
- `StrongPreserving` : output is guaranteed to be zero if **any** input is.
9+
- `WeakPreserving` : output is guaranteed to be zero if **all** inputs are.
10+
- `NonPreserving` : no guarantees on output.
11+
12+
To attempt to automatically determine this, either `ZeroPreserving(f, A::AbstractArray...)` or
13+
`ZeroPreserving(f, T::Type...)` can be used/overloaded.
14+
15+
!!! warning
16+
incorrectly registering a function to be zero-preserving will lead to silently wrong results.
17+
"""
18+
abstract type ZeroPreserving end
19+
struct StrongPreserving <: ZeroPreserving end
20+
struct WeakPreserving <: ZeroPreserving end
21+
struct NonPreserving <: ZeroPreserving end
22+
23+
# warning: cannot automatically detect WeakPreserving since this would mean checking all values
24+
function ZeroPreserving(f, A::AbstractArray, Bs::AbstractArray...)
25+
return ZeroPreserving(f, eltype(A), eltype.(Bs)...)
26+
end
27+
# TODO: the following might not properly specialize on the types
28+
# TODO: non-concrete element types
29+
function ZeroPreserving(f, T::Type, Ts::Type...)
30+
return iszero(f(zero(T), zero.(Ts)...)) ? WeakPreserving() : NonPreserving()
31+
end
32+
33+
const _WEAK_FUNCTIONS = (:+, :-)
34+
for f in _WEAK_FUNCTIONS
35+
@eval begin
36+
ZeroPreserving(::typeof($f), ::AbstractArray, ::AbstractArray...) = WeakPreserving()
37+
ZeroPreserving(::typeof($f), ::Type, ::Type...) = WeakPreserving()
38+
end
39+
end
40+
41+
const _STRONG_FUNCTIONS = (:*,)
42+
for f in _STRONG_FUNCTIONS
43+
@eval begin
44+
ZeroPreserving(::typeof($f), ::AbstractArray, ::AbstractArray...) = StrongPreserving()
45+
ZeroPreserving(::typeof($f), ::Type, ::Type...) = StrongPreserving()
46+
end
47+
end
48+
49+
# map(!)
50+
# ------
51+
@interface I::AbstractSparseArrayInterface function Base.map(
52+
f, A::AbstractArray, Bs::AbstractArray...
53+
)
54+
T = Base.Broadcast.combine_eltypes(f, (A, Bs...))
55+
C = similar(I, T, size(A))
56+
return @interface I map!(f, C, A, Bs...)
57+
end
58+
59+
@interface ::AbstractSparseArrayInterface function Base.map!(
60+
f, C::AbstractArray, A::AbstractArray, Bs::AbstractArray...
61+
)
62+
return _map!(f, ZeroPreserving(f, A, Bs...), C, A, Bs...)
63+
end
64+
65+
function _map!(
66+
f, ::StrongPreserving, C::AbstractArray, A::AbstractArray, Bs::AbstractArray...
67+
)
68+
checkshape(C, A, Bs...)
69+
zero!(C)
70+
style = IndexStyle(C, A, Bs...)
71+
unaliased = map(Base.Fix1(Base.unalias, C), (A, Bs...))
72+
for I in intersect(eachstoredindex.(Ref(style), unaliased)...)
73+
@inbounds C[I] = f(ith_all(I, unaliased)...)
74+
end
75+
return C
76+
end
77+
function _map!(
78+
f, ::WeakPreserving, C::AbstractArray, A::AbstractArray, Bs::AbstractArray...
79+
)
80+
checkshape(C, A, Bs...)
81+
zero!(C)
82+
style = IndexStyle(C, A, Bs...)
83+
unaliased = map(Base.Fix1(Base.unalias, C), (A, Bs...))
84+
for I in union(eachstoredindex.(Ref(style), (A, Bs...))...)
85+
@inbounds C[I] = f(ith_all(I, unaliased)...)
86+
end
87+
return C
88+
end
89+
function _map!(f, ::NonPreserving, C::AbstractArray, A::AbstractArray, Bs::AbstractArray...)
90+
checkshape(C, A, Bs...)
91+
unaliased = map(Base.Fix1(Base.unalias, C), (A, Bs...))
92+
for I in eachindex(C, A, Bs...)
93+
@inbounds C[I] = f(ith_all(I, unaliased)...)
94+
end
95+
return C
96+
end
97+
98+
# Derived functions
99+
# -----------------
100+
@interface I::AbstractSparseArrayInterface Base.copyto!(
101+
C::AbstractArray, A::AbstractArray
102+
) = @interface I map!(identity, C, A)
103+
104+
# Utility functions
105+
# -----------------
106+
# shape check similar to checkbounds
107+
checkshape(::Type{Bool}, A::AbstractArray) = true
108+
checkshape(::Type{Bool}, A::AbstractArray, B::AbstractArray) = size(A) == size(B)
109+
function checkshape(::Type{Bool}, A::AbstractArray, Bs::AbstractArray...)
110+
return allequal(size, (A, Bs...))
111+
end
112+
113+
function checkshape(A::AbstractArray, Bs::AbstractArray...)
114+
return checkshape(Bool, A, Bs...) ||
115+
throw(DimensionMismatch("argument shapes must match"))
116+
end
117+
118+
@inline ith_all(i, ::Tuple{}) = ()
119+
function ith_all(i, as)
120+
@_propagate_inbounds_meta
121+
return (as[1][i], ith_all(i, Base.tail(as))...)
122+
end

0 commit comments

Comments
 (0)