Skip to content

Commit f525dc5

Browse files
authored
Merge pull request #21 from yuehhua/develop
Scatters support almost all Real numbers and add test coverage
2 parents 97ac0bc + 4cbec22 commit f525dc5

File tree

9 files changed

+112
-54
lines changed

9 files changed

+112
-54
lines changed

src/GeometricFlux.jl

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
module GeometricFlux
22

3-
using Core.Intrinsics: llvmcall
4-
using Base.Threads
53
using Statistics: mean
64
using SparseArrays: SparseMatrixCSC
75
using LinearAlgebra: I, issymmetric, diagm, eigmax
@@ -13,13 +11,6 @@ using Flux: glorot_uniform, leakyrelu, GRUCell
1311
using Flux: @functor
1412
using ZygoteRules
1513

16-
import Base: identity
17-
import Base.Threads: atomictypes, llvmtypes, inttype, ArithmeticTypes, FloatTypes,
18-
atomic_cas!, atomic_xchg!,
19-
atomic_add!, atomic_sub!, atomic_max!, atomic_min!,
20-
atomic_and!, atomic_nand!, atomic_or!, atomic_xor!
21-
import Base.Sys: ARCH, WORD_SIZE
22-
2314
export
2415

2516
# layers/meta

src/atomic.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,11 @@
11
# from https://github.com/chengchingwen/Transformers.jl/tree/master/src/fix
2+
using Core.Intrinsics: llvmcall
3+
import Base: identity
4+
import Base.Threads: atomictypes, llvmtypes, inttype, ArithmeticTypes, FloatTypes,
5+
atomic_cas!, atomic_xchg!,
6+
atomic_add!, atomic_sub!, atomic_max!, atomic_min!,
7+
atomic_and!, atomic_nand!, atomic_or!, atomic_xor!
8+
import Base.Sys: ARCH, WORD_SIZE
29

310
for typ atomictypes
411
lt = llvmtypes[typ]

src/cuda/scatter.jl

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -151,12 +151,6 @@ end
151151
end
152152
end
153153

154-
function gather_indices(X::CuArray{T}) where T
155-
Y = gather_indices(Array(X))
156-
cuY = Dict{T,CuVector}(k => cu(Tuple.(v)) for (k, v) in Y)
157-
cuY
158-
end
159-
160154
@adjoint function scatter_max!(ys::CuArray{T}, us::CuArray{T}, xs::CuArray) where {T<:AbstractFloat}
161155
max = copy(ys)
162156
scatter_max!(max, us, xs)
@@ -180,9 +174,3 @@ end
180174
(Δy, Δu, nothing)
181175
end
182176
end
183-
184-
function numerical_cmp(X::CuArray{T}, Y::CuArray) where T
185-
Z = map((x,y) -> sign(x - y)^2, X, Y)
186-
Z = map(x -> (one(T) - x)^2, Z)
187-
Z
188-
end

src/cuda/utils.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
## Inverse operation of scatter
2+
13
gather(input::AbstractArray{T}, index::CuArray{Int}) where T = gather(cu(input), index)
24

35
function gather(input::CuMatrix{T}, index::CuArray{Int}) where T
@@ -7,3 +9,15 @@ function gather(input::CuMatrix{T}, index::CuArray{Int}) where T
79
end
810
return out
911
end
12+
13+
function gather_indices(X::CuArray{T}) where T
14+
Y = gather_indices(Array(X))
15+
cuY = Dict{T,CuVector}(k => cu(Tuple.(v)) for (k, v) in Y)
16+
cuY
17+
end
18+
19+
function numerical_cmp(X::CuArray{T}, Y::CuArray) where T
20+
Z = map((x,y) -> sign(x - y)^2, X, Y)
21+
Z = map(x -> (one(T) - x)^2, Z)
22+
Z
23+
end

src/layers/msgpass.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
using Base.Threads
2+
13
abstract type MessagePassing <: Meta end
24

35
adjlist(m::T) where {T<:MessagePassing} = m.adjlist

src/layers/pool.jl

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,3 @@
1-
floattype(::Type{T}) where {T<:AbstractFloat} = T
2-
floattype(::Type{Int8}) = Float16
3-
floattype(::Type{UInt8}) = Float16
4-
floattype(::Type{Int16}) = Float16
5-
floattype(::Type{UInt16}) = Float16
6-
floattype(::Type{Int32}) = Float32
7-
floattype(::Type{UInt32}) = Float32
8-
floattype(::Type{Int64}) = Float64
9-
floattype(::Type{UInt64}) = Float64
10-
111
struct GlobalPool{A}
122
aggr::Symbol
133
cluster::A

src/scatter.jl

Lines changed: 37 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
1+
## Scatter operations
2+
3+
const ops = [:add, :sub, :mul, :div, :max, :min, :mean]
14
const name2op = Dict(:add => :+, :sub => :-, :mul => :*, :div => :/)
25

36
for op = [:add, :sub, :mul, :div]
4-
@eval function $(Symbol("scatter_", op, "!"))(ys::Array{T}, us::Array{T}, xs::Array{<:IntOrTuple}) where T
7+
fn = Symbol("scatter_$(op)!")
8+
@eval function $fn(ys::Array{T}, us::Array{T}, xs::Array{<:IntOrTuple}) where {T<:Real}
59
@simd for k = 1:length(xs)
610
k = CartesianIndices(xs)[k]
711
@inbounds ys[:, xs[k]...] .= $(name2op[op]).(ys[:, xs[k]...], us[:, k])
@@ -10,23 +14,23 @@ for op = [:add, :sub, :mul, :div]
1014
end
1115
end
1216

13-
function scatter_max!(ys::Array{T}, us::Array{T}, xs::Array{<:IntOrTuple}) where T
17+
function scatter_max!(ys::Array{T}, us::Array{T}, xs::Array{<:IntOrTuple}) where {T<:Real}
1418
@simd for k = 1:length(xs)
1519
k = CartesianIndices(xs)[k]
1620
@inbounds ys[:, xs[k]...] .= max.(ys[:, xs[k]...], us[:, k])
1721
end
1822
ys
1923
end
2024

21-
function scatter_min!(ys::Array{T}, us::Array{T}, xs::Array{<:IntOrTuple}) where T
25+
function scatter_min!(ys::Array{T}, us::Array{T}, xs::Array{<:IntOrTuple}) where {T<:Real}
2226
@simd for k = 1:length(xs)
2327
k = CartesianIndices(xs)[k]
2428
@inbounds ys[:, xs[k]...] .= min.(ys[:, xs[k]...], us[:, k])
2529
end
2630
ys
2731
end
2832

29-
function scatter_mean!(ys::Array{T}, us::Array{T}, xs::Array{<:IntOrTuple}) where T
33+
function scatter_mean!(ys::Array{T}, us::Array{T}, xs::Array{<:IntOrTuple}) where {T<:Real}
3034
Ns = zero(ys)
3135
ys_ = zero(ys)
3236
scatter_add!(Ns, one.(us), xs)
@@ -35,6 +39,10 @@ function scatter_mean!(ys::Array{T}, us::Array{T}, xs::Array{<:IntOrTuple}) wher
3539
return ys
3640
end
3741

42+
43+
44+
## Derivatives of scatter operations
45+
3846
@adjoint function scatter_add!(ys::AbstractArray, us::AbstractArray, xs::AbstractArray)
3947
ys_ = copy(ys)
4048
scatter_add!(ys_, us, xs)
@@ -47,7 +55,7 @@ end
4755
ys_, Δ -> (Δ, -gather(zero(Δ)+Δ, xs), nothing)
4856
end
4957

50-
@adjoint function scatter_mul!(ys::Array{T}, us::Array{T}, xs::Array{<:IntOrTuple}) where T
58+
@adjoint function scatter_mul!(ys::Array{T}, us::Array{T}, xs::Array{<:IntOrTuple}) where {T<:Real}
5159
ys_ = copy(ys)
5260
scatter_mul!(ys_, us, xs)
5361
ys_, function (Δ)
@@ -65,7 +73,7 @@ end
6573
end
6674
end
6775

68-
@adjoint function scatter_div!(ys::Array{T}, us::Array{T}, xs::Array{<:IntOrTuple}) where T
76+
@adjoint function scatter_div!(ys::Array{T}, us::Array{T}, xs::Array{<:IntOrTuple}) where {T<:Real}
6977
ys_ = copy(ys)
7078
scatter_div!(ys_, us, xs)
7179
ys_, function (Δ)
@@ -83,15 +91,7 @@ end
8391
end
8492
end
8593

86-
function gather_indices(X::Array{T}) where T
87-
Y = DefaultDict{T,Vector{CartesianIndex}}(CartesianIndex[])
88-
@inbounds for (ind, val) = pairs(X)
89-
push!(Y[val], ind)
90-
end
91-
Y
92-
end
93-
94-
@adjoint function scatter_max!(ys::Array{T}, us::Array{T}, xs::Array{<:IntOrTuple}) where T
94+
@adjoint function scatter_max!(ys::Array{T}, us::Array{T}, xs::Array{<:IntOrTuple}) where {T<:Real}
9595
max = copy(ys)
9696
scatter_max!(max, us, xs)
9797
max, function (Δ)
@@ -101,7 +101,7 @@ end
101101
end
102102
end
103103

104-
@adjoint function scatter_min!(ys::Array{T}, us::Array{T}, xs::Array{<:IntOrTuple}) where T
104+
@adjoint function scatter_min!(ys::Array{T}, us::Array{T}, xs::Array{<:IntOrTuple}) where {T<:Real}
105105
min = copy(ys)
106106
scatter_min!(min, us, xs)
107107
min, function (Δ)
@@ -127,6 +127,18 @@ end
127127
end
128128
end
129129

130+
131+
132+
## Bool
133+
134+
function scatter_add!(ys::Array{Bool}, us::Array{Bool}, xs::Array{<:IntOrTuple})
135+
scatter_add!(Int8.(ys), Int8.(us), xs)
136+
end
137+
138+
139+
140+
## API
141+
130142
function scatter!(op::Symbol, ys::AbstractArray, us::AbstractArray, xs::AbstractArray)
131143
if op == :add
132144
return scatter_add!(ys, us, xs)
@@ -144,3 +156,12 @@ function scatter!(op::Symbol, ys::AbstractArray, us::AbstractArray, xs::Abstract
144156
return scatter_mean!(ys, us, xs)
145157
end
146158
end
159+
160+
# Support different types of array
161+
for op = ops
162+
fn = Symbol("scatter_$(op)!")
163+
@eval function $fn(ys::Array{T}, us::Array{S}, xs::Array{<:IntOrTuple}) where {T<:Real,S<:Real}
164+
PT = promote_type(T, S)
165+
$fn(PT.(ys), PT.(us), xs)
166+
end
167+
end

src/utils.jl

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,19 @@
1+
## Type transformation
2+
3+
floattype(::Type{T}) where {T<:AbstractFloat} = T
4+
floattype(::Type{Int8}) = Float16
5+
floattype(::Type{UInt8}) = Float16
6+
floattype(::Type{Int16}) = Float16
7+
floattype(::Type{UInt16}) = Float16
8+
floattype(::Type{Int32}) = Float32
9+
floattype(::Type{UInt32}) = Float32
10+
floattype(::Type{Int64}) = Float64
11+
floattype(::Type{UInt64}) = Float64
12+
13+
14+
15+
## Inverse operation of scatter
16+
117
function gather(input::AbstractArray{T,N}, index::AbstractArray{<:Integer,N}, dims::Integer;
218
out::AbstractArray{T,N}=similar(index, T)) where {T,N}
319
@assert dims <= N "Specified dimensions must lower or equal to the rank of input matrix."
@@ -10,7 +26,6 @@ function gather(input::AbstractArray{T,N}, index::AbstractArray{<:Integer,N}, di
1026
return out
1127
end
1228

13-
1429
function gather(input::Matrix{T}, index::Array{Int}) where T
1530
out = Array{T}(undef, size(input,1), size(index)...)
1631
@inbounds for ind = CartesianIndices(index)
@@ -19,8 +34,20 @@ function gather(input::Matrix{T}, index::Array{Int}) where T
1934
return out
2035
end
2136

37+
function gather_indices(X::Array{T}) where T
38+
Y = DefaultDict{T,Vector{CartesianIndex}}(CartesianIndex[])
39+
@inbounds for (ind, val) = pairs(X)
40+
push!(Y[val], ind)
41+
end
42+
Y
43+
end
44+
2245
identity(; kwargs...) = kwargs.data
2346

47+
48+
49+
## Graph related utility functions
50+
2451
struct GraphInfo{A,T<:Integer}
2552
adj::AbstractVector{A}
2653
edge_idx::A
@@ -42,6 +69,8 @@ function edge_index_table(adj::AbstractVector{<:AbstractVector{<:Integer}},
4269
y
4370
end
4471

72+
73+
4574
## Indexing
4675

4776
function range_indecies(idx::Tuple)

test/scatter.jl

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,64 +4,80 @@ us = ones(Int, 2, 3, 4)
44
xs = [1 2 3 4;
55
4 2 1 3;
66
3 5 5 3]
7-
types = [UInt8, UInt16, UInt32, UInt64,
8-
Int8, Int16, Int32, Int64, Int128,
9-
Float16, Float32, Float64]
7+
types = [UInt8, UInt16, UInt32, UInt64, UInt128,
8+
Int8, Int16, Int32, Int64, Int128, BigInt,
9+
Float16, Float32, Float64, BigFloat, Rational]
1010

1111
@testset "scatter" begin
1212
for T = types
1313
@testset "$T" begin
14+
PT = promote_type(T, Int)
1415
@testset "scatter_add!" begin
1516
ys_ = [5 5 8 6 7;
1617
7 7 10 8 9]
1718
@test scatter_add!(T.(copy(ys)), T.(us), xs) == T.(ys_)
19+
@test scatter_add!(T.(copy(ys)), us, xs) == PT.(ys_)
20+
@test scatter_add!(copy(ys), T.(us), xs) == PT.(ys_)
1821
@test scatter!(:add, T.(copy(ys)), T.(us), xs) == T.(ys_)
1922
end
2023

2124
@testset "scatter_sub!" begin
2225
ys_ = [1 1 0 2 3;
2326
3 3 2 4 5]
2427
@test scatter_sub!(T.(copy(ys)), T.(us), xs) == T.(ys_)
28+
@test scatter_sub!(T.(copy(ys)), us, xs) == PT.(ys_)
29+
@test scatter_sub!(copy(ys), T.(us), xs) == PT.(ys_)
2530
@test scatter!(:sub, T.(copy(ys)), T.(us), xs) == T.(ys_)
2631
end
2732

2833
@testset "scatter_max!" begin
2934
ys_ = [3 3 4 4 5;
3035
5 5 6 6 7]
3136
@test scatter_max!(T.(copy(ys)), T.(us), xs) == T.(ys_)
37+
@test scatter_max!(T.(copy(ys)), us, xs) == PT.(ys_)
38+
@test scatter_max!(copy(ys), T.(us), xs) == PT.(ys_)
3239
@test scatter!(:max, T.(copy(ys)), T.(us), xs) == T.(ys_)
3340
end
3441

3542
@testset "scatter_min!" begin
3643
ys_ = [1 1 1 1 1;
3744
1 1 1 1 1]
3845
@test scatter_min!(T.(copy(ys)), T.(us), xs) == T.(ys_)
46+
@test scatter_min!(T.(copy(ys)), us, xs) == PT.(ys_)
47+
@test scatter_min!(copy(ys), T.(us), xs) == PT.(ys_)
3948
@test scatter!(:min, T.(copy(ys)), T.(us), xs) == T.(ys_)
4049
end
4150

4251
@testset "scatter_mul!" begin
4352
ys_ = [3 3 4 4 5;
4453
5 5 6 6 7]
4554
@test scatter_mul!(T.(copy(ys)), T.(us), xs) == T.(ys_)
55+
@test scatter_mul!(T.(copy(ys)), us, xs) == PT.(ys_)
56+
@test scatter_mul!(copy(ys), T.(us), xs) == PT.(ys_)
4657
@test scatter!(:mul, T.(copy(ys)), T.(us), xs) == T.(ys_)
4758
end
4859
end
4960
end
5061

51-
for T = [Float16, Float32, Float64]
62+
for T = [Float16, Float32, Float64, BigFloat, Rational]
5263
@testset "$T" begin
64+
PT = promote_type(T, Float64)
5365
@testset "scatter_div!" begin
5466
us_div = us .* 2
5567
ys_ = [0.75 0.75 0.25 1. 1.25;
5668
1.25 1.25 0.375 1.5 1.75]
5769
@test scatter_div!(T.(copy(ys)), T.(us_div), xs) == T.(ys_)
70+
@test scatter_div!(T.(copy(ys)), us_div, xs) == PT.(ys_)
71+
@test scatter_div!(copy(ys), T.(us_div), xs) == PT.(ys_)
5872
@test scatter!(:div, T.(copy(ys)), T.(us_div), xs) == T.(ys_)
5973
end
6074

6175
@testset "scatter_mean!" begin
62-
ys_ = [4 4 5 5 6;
63-
6 6 7 7 8]
76+
ys_ = [4. 4. 5. 5. 6.;
77+
6. 6. 7. 7. 8.]
6478
@test scatter_mean!(T.(copy(ys)), T.(us), xs) == T.(ys_)
79+
@test scatter_mean!(T.(copy(ys)), us, xs) == PT.(ys_)
80+
@test scatter_mean!(copy(ys), T.(us), xs) == PT.(ys_)
6581
@test scatter!(:mean, T.(copy(ys)), T.(us), xs) == T.(ys_)
6682
end
6783
end

0 commit comments

Comments
 (0)