Skip to content

Scatters support almost all Real numbers and add test coverage #21

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Feb 14, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 0 additions & 9 deletions src/GeometricFlux.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
module GeometricFlux

using Core.Intrinsics: llvmcall
using Base.Threads
using Statistics: mean
using SparseArrays: SparseMatrixCSC
using LinearAlgebra: I, issymmetric, diagm, eigmax
Expand All @@ -13,13 +11,6 @@ using Flux: glorot_uniform, leakyrelu, GRUCell
using Flux: @functor
using ZygoteRules

import Base: identity
import Base.Threads: atomictypes, llvmtypes, inttype, ArithmeticTypes, FloatTypes,
atomic_cas!, atomic_xchg!,
atomic_add!, atomic_sub!, atomic_max!, atomic_min!,
atomic_and!, atomic_nand!, atomic_or!, atomic_xor!
import Base.Sys: ARCH, WORD_SIZE

export

# layers/meta
Expand Down
7 changes: 7 additions & 0 deletions src/atomic.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,11 @@
# from https://github.com/chengchingwen/Transformers.jl/tree/master/src/fix
using Core.Intrinsics: llvmcall
import Base: identity
import Base.Threads: atomictypes, llvmtypes, inttype, ArithmeticTypes, FloatTypes,
atomic_cas!, atomic_xchg!,
atomic_add!, atomic_sub!, atomic_max!, atomic_min!,
atomic_and!, atomic_nand!, atomic_or!, atomic_xor!
import Base.Sys: ARCH, WORD_SIZE

for typ ∈ atomictypes
lt = llvmtypes[typ]
Expand Down
12 changes: 0 additions & 12 deletions src/cuda/scatter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -151,12 +151,6 @@ end
end
end

"""
    gather_indices(X::CuArray{T})

Group the positions of the device array `X` by the value stored at each position.

`X` is first converted with `Array(X)` and grouped by the `Array` method of
`gather_indices`. Each resulting index list is then turned into tuples and passed
through `cu`. The result is a `Dict{T,CuVector}` keyed by the values found in `X`.
"""
function gather_indices(X::CuArray{T}) where T
    host_groups = gather_indices(Array(X))
    return Dict{T,CuVector}(val => cu(Tuple.(inds)) for (val, inds) in host_groups)
end

@adjoint function scatter_max!(ys::CuArray{T}, us::CuArray{T}, xs::CuArray) where {T<:AbstractFloat}
max = copy(ys)
scatter_max!(max, us, xs)
Expand All @@ -180,9 +174,3 @@ end
(Δy, Δu, nothing)
end
end

"""
    numerical_cmp(X::CuArray{T}, Y::CuArray)

Arithmetic elementwise comparison of `X` and `Y`.

For each pair `(x, y)` the result is `(one(T) - sign(x - y)^2)^2`, computed with two
successive `map` calls. For ordinary (non-NaN) values this is one where `x == y` and
zero where they differ.
"""
function numerical_cmp(X::CuArray{T}, Y::CuArray) where T
    # First pass: squared sign of the difference (zero exactly where the inputs agree).
    differs = map((a, b) -> sign(a - b)^2, X, Y)
    # Second pass: flip the indicator so that agreement maps to one.
    return map(d -> (one(T) - d)^2, differs)
end
14 changes: 14 additions & 0 deletions src/cuda/utils.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
## Inverse operation of scatter

# Fallback for an arbitrary `AbstractArray` input paired with a `CuArray{Int}` index:
# pass `input` through `cu` and dispatch `gather` again on the converted array.
function gather(input::AbstractArray, index::CuArray{Int})
    return gather(cu(input), index)
end

function gather(input::CuMatrix{T}, index::CuArray{Int}) where T
Expand All @@ -7,3 +9,15 @@ function gather(input::CuMatrix{T}, index::CuArray{Int}) where T
end
return out
end

"""
    gather_indices(X::CuArray{T})

Group the positions of the device array `X` by the value stored at each position.

`X` is first converted with `Array(X)` and grouped by the `Array` method of
`gather_indices`. Each resulting index list is then turned into tuples and passed
through `cu`. The result is a `Dict{T,CuVector}` keyed by the values found in `X`.
"""
function gather_indices(X::CuArray{T}) where T
    host_groups = gather_indices(Array(X))
    return Dict{T,CuVector}(val => cu(Tuple.(inds)) for (val, inds) in host_groups)
end

"""
    numerical_cmp(X::CuArray{T}, Y::CuArray)

Arithmetic elementwise comparison of `X` and `Y`.

For each pair `(x, y)` the result is `(one(T) - sign(x - y)^2)^2`, computed with two
successive `map` calls. For ordinary (non-NaN) values this is one where `x == y` and
zero where they differ.
"""
function numerical_cmp(X::CuArray{T}, Y::CuArray) where T
    # First pass: squared sign of the difference (zero exactly where the inputs agree).
    differs = map((a, b) -> sign(a - b)^2, X, Y)
    # Second pass: flip the indicator so that agreement maps to one.
    return map(d -> (one(T) - d)^2, differs)
end
2 changes: 2 additions & 0 deletions src/layers/msgpass.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
using Base.Threads

abstract type MessagePassing <: Meta end

# Accessor returning the `adjlist` field of any value whose type subtypes `MessagePassing`.
function adjlist(m::MessagePassing)
    return m.adjlist
end
Expand Down
10 changes: 0 additions & 10 deletions src/layers/pool.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,3 @@
# Map a numeric element type to the floating-point type used in its place:
# floating-point types map to themselves, 8- and 16-bit integers to `Float16`,
# 32-bit integers to `Float32`, and 64-bit integers to `Float64`.
floattype(::Type{T}) where {T<:AbstractFloat} = T
floattype(::Union{Type{Int8},Type{UInt8},Type{Int16},Type{UInt16}}) = Float16
floattype(::Union{Type{Int32},Type{UInt32}}) = Float32
floattype(::Union{Type{Int64},Type{UInt64}}) = Float64

struct GlobalPool{A}
aggr::Symbol
cluster::A
Expand Down
53 changes: 37 additions & 16 deletions src/scatter.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
## Scatter operations

# Suffixes of the scatter operations; each one names a `scatter_<op>!` function.
const ops = Symbol[:add, :sub, :mul, :div, :max, :min, :mean]

# Binary operator applied by each of the four arithmetic scatter operations.
const name2op = Dict{Symbol,Symbol}(
    :add => :+,
    :sub => :-,
    :mul => :*,
    :div => :/,
)

for op = [:add, :sub, :mul, :div]
@eval function $(Symbol("scatter_", op, "!"))(ys::Array{T}, us::Array{T}, xs::Array{<:IntOrTuple}) where T
fn = Symbol("scatter_$(op)!")
@eval function $fn(ys::Array{T}, us::Array{T}, xs::Array{<:IntOrTuple}) where {T<:Real}
@simd for k = 1:length(xs)
k = CartesianIndices(xs)[k]
@inbounds ys[:, xs[k]...] .= $(name2op[op]).(ys[:, xs[k]...], us[:, k])
Expand All @@ -10,23 +14,23 @@ for op = [:add, :sub, :mul, :div]
end
end

function scatter_max!(ys::Array{T}, us::Array{T}, xs::Array{<:IntOrTuple}) where T
"""
    scatter_max!(ys, us, xs)

For every position `k` of `xs`, overwrite the slice `ys[:, xs[k]...]` with the
elementwise maximum of that slice and `us[:, k]`, then return `ys`.

`ys` is updated in place. `k` ranges over `CartesianIndices(xs)`; the entry `xs[k]` is
splatted into the trailing indices of the destination slice of `ys`, and `k` itself
selects the source slice of `us`. A `BoundsError` is thrown if an entry of `xs` does
not address a valid slice of `ys`.
"""
function scatter_max!(ys::Array{T}, us::Array{T}, xs::Array{<:IntOrTuple}) where {T<:Real}
    # The destination indices are read from the contents of `xs`, so bounds checks stay
    # enabled: an invalid index must raise instead of writing outside `ys`.
    for k in CartesianIndices(xs)
        ys[:, xs[k]...] .= max.(ys[:, xs[k]...], us[:, k])
    end
    return ys
end

function scatter_min!(ys::Array{T}, us::Array{T}, xs::Array{<:IntOrTuple}) where T
"""
    scatter_min!(ys, us, xs)

For every position `k` of `xs`, overwrite the slice `ys[:, xs[k]...]` with the
elementwise minimum of that slice and `us[:, k]`, then return `ys`.

`ys` is updated in place. `k` ranges over `CartesianIndices(xs)`; the entry `xs[k]` is
splatted into the trailing indices of the destination slice of `ys`, and `k` itself
selects the source slice of `us`. A `BoundsError` is thrown if an entry of `xs` does
not address a valid slice of `ys`.
"""
function scatter_min!(ys::Array{T}, us::Array{T}, xs::Array{<:IntOrTuple}) where {T<:Real}
    # The destination indices are read from the contents of `xs`, so bounds checks stay
    # enabled: an invalid index must raise instead of writing outside `ys`.
    for k in CartesianIndices(xs)
        ys[:, xs[k]...] .= min.(ys[:, xs[k]...], us[:, k])
    end
    return ys
end

function scatter_mean!(ys::Array{T}, us::Array{T}, xs::Array{<:IntOrTuple}) where T
function scatter_mean!(ys::Array{T}, us::Array{T}, xs::Array{<:IntOrTuple}) where {T<:Real}
Ns = zero(ys)
ys_ = zero(ys)
scatter_add!(Ns, one.(us), xs)
Expand All @@ -35,6 +39,10 @@ function scatter_mean!(ys::Array{T}, us::Array{T}, xs::Array{<:IntOrTuple}) wher
return ys
end



## Derivatives of scatter operations

@adjoint function scatter_add!(ys::AbstractArray, us::AbstractArray, xs::AbstractArray)
ys_ = copy(ys)
scatter_add!(ys_, us, xs)
Expand All @@ -47,7 +55,7 @@ end
ys_, Δ -> (Δ, -gather(zero(Δ)+Δ, xs), nothing)
end

@adjoint function scatter_mul!(ys::Array{T}, us::Array{T}, xs::Array{<:IntOrTuple}) where T
@adjoint function scatter_mul!(ys::Array{T}, us::Array{T}, xs::Array{<:IntOrTuple}) where {T<:Real}
ys_ = copy(ys)
scatter_mul!(ys_, us, xs)
ys_, function (Δ)
Expand All @@ -65,7 +73,7 @@ end
end
end

@adjoint function scatter_div!(ys::Array{T}, us::Array{T}, xs::Array{<:IntOrTuple}) where T
@adjoint function scatter_div!(ys::Array{T}, us::Array{T}, xs::Array{<:IntOrTuple}) where {T<:Real}
ys_ = copy(ys)
scatter_div!(ys_, us, xs)
ys_, function (Δ)
Expand All @@ -83,15 +91,7 @@ end
end
end

"""
    gather_indices(X::Array{T})

Group the positions of `X` by the value stored at each position.

Returns a `DefaultDict{T,Vector{CartesianIndex}}` that maps every distinct value of `X`
to the `CartesianIndex` positions holding it, in the iteration order of
`CartesianIndices(X)`. Looking up a value that does not occur in `X` inserts and
returns a new empty vector.
"""
function gather_indices(X::Array{T}) where T
    # The default must be a factory: a plain `CartesianIndex[]` value would be the one
    # vector object handed out for every key, so all groups would alias each other.
    groups = DefaultDict{T,Vector{CartesianIndex}}(() -> CartesianIndex[])
    # Iterate `CartesianIndices` explicitly; `pairs` on a `Vector` yields `Int` keys,
    # which cannot be pushed into a `Vector{CartesianIndex}`.
    for ind in CartesianIndices(X)
        push!(groups[X[ind]], ind)
    end
    return groups
end

@adjoint function scatter_max!(ys::Array{T}, us::Array{T}, xs::Array{<:IntOrTuple}) where T
@adjoint function scatter_max!(ys::Array{T}, us::Array{T}, xs::Array{<:IntOrTuple}) where {T<:Real}
max = copy(ys)
scatter_max!(max, us, xs)
max, function (Δ)
Expand All @@ -101,7 +101,7 @@ end
end
end

@adjoint function scatter_min!(ys::Array{T}, us::Array{T}, xs::Array{<:IntOrTuple}) where T
@adjoint function scatter_min!(ys::Array{T}, us::Array{T}, xs::Array{<:IntOrTuple}) where {T<:Real}
min = copy(ys)
scatter_min!(min, us, xs)
min, function (Δ)
Expand All @@ -127,6 +127,18 @@ end
end
end



## Bool

# `Bool` inputs are converted elementwise to `Int8` and the sum is accumulated on those
# converted copies. The `Int8` result is returned; the `Bool` array `ys` passed in is
# left unmodified.
function scatter_add!(ys::Array{Bool}, us::Array{Bool}, xs::Array{<:IntOrTuple})
    acc = Int8.(ys)
    updates = Int8.(us)
    return scatter_add!(acc, updates, xs)
end



## API

function scatter!(op::Symbol, ys::AbstractArray, us::AbstractArray, xs::AbstractArray)
if op == :add
return scatter_add!(ys, us, xs)
Expand All @@ -144,3 +156,12 @@ function scatter!(op::Symbol, ys::AbstractArray, us::AbstractArray, xs::Abstract
return scatter_mean!(ys, us, xs)
end
end

# Support different types of array
for op in ops
    fname = Symbol("scatter_", op, "!")
    # Method for `ys` and `us` with (possibly) different `Real` element types: both are
    # converted elementwise to their promoted type and the call is re-dispatched on the
    # converted copies, whose result is returned.
    @eval function $fname(ys::Array{T}, us::Array{S}, xs::Array{<:IntOrTuple}) where {T<:Real,S<:Real}
        P = promote_type(T, S)
        return $fname(P.(ys), P.(us), xs)
    end
end
31 changes: 30 additions & 1 deletion src/utils.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,19 @@
## Type transformation

# Map a numeric element type to the floating-point type used in its place:
# floating-point types map to themselves, 8- and 16-bit integers to `Float16`,
# 32-bit integers to `Float32`, and 64-bit integers to `Float64`.
floattype(::Type{T}) where {T<:AbstractFloat} = T
floattype(::Union{Type{Int8},Type{UInt8},Type{Int16},Type{UInt16}}) = Float16
floattype(::Union{Type{Int32},Type{UInt32}}) = Float32
floattype(::Union{Type{Int64},Type{UInt64}}) = Float64



## Inverse operation of scatter

function gather(input::AbstractArray{T,N}, index::AbstractArray{<:Integer,N}, dims::Integer;
out::AbstractArray{T,N}=similar(index, T)) where {T,N}
@assert dims <= N "Specified dimensions must lower or equal to the rank of input matrix."
Expand All @@ -10,7 +26,6 @@ function gather(input::AbstractArray{T,N}, index::AbstractArray{<:Integer,N}, di
return out
end


function gather(input::Matrix{T}, index::Array{Int}) where T
out = Array{T}(undef, size(input,1), size(index)...)
@inbounds for ind = CartesianIndices(index)
Expand All @@ -19,8 +34,20 @@ function gather(input::Matrix{T}, index::Array{Int}) where T
return out
end

"""
    gather_indices(X::Array{T})

Group the positions of `X` by the value stored at each position.

Returns a `DefaultDict{T,Vector{CartesianIndex}}` that maps every distinct value of `X`
to the `CartesianIndex` positions holding it, in the iteration order of
`CartesianIndices(X)`. Looking up a value that does not occur in `X` inserts and
returns a new empty vector.
"""
function gather_indices(X::Array{T}) where T
    # The default must be a factory: a plain `CartesianIndex[]` value would be the one
    # vector object handed out for every key, so all groups would alias each other.
    groups = DefaultDict{T,Vector{CartesianIndex}}(() -> CartesianIndex[])
    # Iterate `CartesianIndices` explicitly; `pairs` on a `Vector` yields `Int` keys,
    # which cannot be pushed into a `Vector{CartesianIndex}`.
    for ind in CartesianIndices(X)
        push!(groups[X[ind]], ind)
    end
    return groups
end

# Keyword-only method of `identity`: returns the keyword arguments it was called with
# as a `NamedTuple`. `values(kwargs)` is the public accessor for the named tuple behind
# the keyword-argument container, used here instead of reading its internal `data` field.
identity(; kwargs...) = values(kwargs)



## Graph related utility functions

struct GraphInfo{A,T<:Integer}
adj::AbstractVector{A}
edge_idx::A
Expand All @@ -42,6 +69,8 @@ function edge_index_table(adj::AbstractVector{<:AbstractVector{<:Integer}},
y
end



## Indexing

function range_indecies(idx::Tuple)
Expand Down
28 changes: 22 additions & 6 deletions test/scatter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,64 +4,80 @@ us = ones(Int, 2, 3, 4)
xs = [1 2 3 4;
4 2 1 3;
3 5 5 3]
types = [UInt8, UInt16, UInt32, UInt64,
Int8, Int16, Int32, Int64, Int128,
Float16, Float32, Float64]
types = [UInt8, UInt16, UInt32, UInt64, UInt128,
Int8, Int16, Int32, Int64, Int128, BigInt,
Float16, Float32, Float64, BigFloat, Rational]

@testset "scatter" begin
for T = types
@testset "$T" begin
PT = promote_type(T, Int)
@testset "scatter_add!" begin
ys_ = [5 5 8 6 7;
7 7 10 8 9]
@test scatter_add!(T.(copy(ys)), T.(us), xs) == T.(ys_)
@test scatter_add!(T.(copy(ys)), us, xs) == PT.(ys_)
@test scatter_add!(copy(ys), T.(us), xs) == PT.(ys_)
@test scatter!(:add, T.(copy(ys)), T.(us), xs) == T.(ys_)
end

@testset "scatter_sub!" begin
ys_ = [1 1 0 2 3;
3 3 2 4 5]
@test scatter_sub!(T.(copy(ys)), T.(us), xs) == T.(ys_)
@test scatter_sub!(T.(copy(ys)), us, xs) == PT.(ys_)
@test scatter_sub!(copy(ys), T.(us), xs) == PT.(ys_)
@test scatter!(:sub, T.(copy(ys)), T.(us), xs) == T.(ys_)
end

@testset "scatter_max!" begin
ys_ = [3 3 4 4 5;
5 5 6 6 7]
@test scatter_max!(T.(copy(ys)), T.(us), xs) == T.(ys_)
@test scatter_max!(T.(copy(ys)), us, xs) == PT.(ys_)
@test scatter_max!(copy(ys), T.(us), xs) == PT.(ys_)
@test scatter!(:max, T.(copy(ys)), T.(us), xs) == T.(ys_)
end

@testset "scatter_min!" begin
ys_ = [1 1 1 1 1;
1 1 1 1 1]
@test scatter_min!(T.(copy(ys)), T.(us), xs) == T.(ys_)
@test scatter_min!(T.(copy(ys)), us, xs) == PT.(ys_)
@test scatter_min!(copy(ys), T.(us), xs) == PT.(ys_)
@test scatter!(:min, T.(copy(ys)), T.(us), xs) == T.(ys_)
end

@testset "scatter_mul!" begin
ys_ = [3 3 4 4 5;
5 5 6 6 7]
@test scatter_mul!(T.(copy(ys)), T.(us), xs) == T.(ys_)
@test scatter_mul!(T.(copy(ys)), us, xs) == PT.(ys_)
@test scatter_mul!(copy(ys), T.(us), xs) == PT.(ys_)
@test scatter!(:mul, T.(copy(ys)), T.(us), xs) == T.(ys_)
end
end
end

for T = [Float16, Float32, Float64]
for T = [Float16, Float32, Float64, BigFloat, Rational]
@testset "$T" begin
PT = promote_type(T, Float64)
@testset "scatter_div!" begin
us_div = us .* 2
ys_ = [0.75 0.75 0.25 1. 1.25;
1.25 1.25 0.375 1.5 1.75]
@test scatter_div!(T.(copy(ys)), T.(us_div), xs) == T.(ys_)
@test scatter_div!(T.(copy(ys)), us_div, xs) == PT.(ys_)
@test scatter_div!(copy(ys), T.(us_div), xs) == PT.(ys_)
@test scatter!(:div, T.(copy(ys)), T.(us_div), xs) == T.(ys_)
end

@testset "scatter_mean!" begin
ys_ = [4 4 5 5 6;
6 6 7 7 8]
ys_ = [4. 4. 5. 5. 6.;
6. 6. 7. 7. 8.]
@test scatter_mean!(T.(copy(ys)), T.(us), xs) == T.(ys_)
@test scatter_mean!(T.(copy(ys)), us, xs) == PT.(ys_)
@test scatter_mean!(copy(ys), T.(us), xs) == PT.(ys_)
@test scatter!(:mean, T.(copy(ys)), T.(us), xs) == T.(ys_)
end
end
Expand Down