Skip to content

Commit 3a8cda8

Browse files
authored
Merge pull request #56 from yuehhua/scatter
Improve CPU scatter performance
2 parents 2733443 + 62d47c9 commit 3a8cda8

File tree

7 files changed

+155
-91
lines changed

7 files changed

+155
-91
lines changed

benchmark/pics/cpu_scatter.png

1.42 KB
Loading

benchmark/pics/cpu_scatter.svg

Lines changed: 74 additions & 74 deletions
Loading

benchmark/plot.jl

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
using CSV, DataFrames
2+
using Gadfly
3+
import Cairo
4+
5+
julia_bmk_file = joinpath("benchmark", "scatter_julia.tsv")
6+
python_bmk_file = joinpath("benchmark", "scatter_pytorch.tsv")
7+
8+
bmk_jl = CSV.read(julia_bmk_file; delim='\t')
9+
bmk_py = CSV.read(python_bmk_file; delim='\t')
10+
11+
bmk_jl[!, :framework] .= "geometricflux"
12+
bmk_py[!, :framework] .= "pytorch-scatter"
13+
14+
bmk = vcat(bmk_jl, bmk_py)
15+
16+
bmk[!, :min_time] .= bmk[!, :min_time]/1000
17+
bmk[!, :mean_time] .= bmk[!, :mean_time]/1000
18+
bmk[!, :max_time] .= bmk[!, :max_time]/1000
19+
20+
function plot_benchmark(device)
21+
DEVICE = uppercase(device)
22+
p = plot(bmk[bmk[!,:device] .== device, :], x="sample", y="mean_time", color="framework",
23+
Geom.point, Geom.line, Scale.x_log2, Scale.y_log10,
24+
Guide.title("Scatter add performance on $(DEVICE)"),
25+
Guide.xlabel("Matrix Size"), Guide.ylabel("Time (μs)"),
26+
Coord.cartesian(xmin=4, xmax=21, ymin=1, ymax=7))
27+
28+
draw(SVG(joinpath("benchmark", "pics", "$(device)_scatter.svg"), 9inch, 6inch), p)
29+
draw(PNG(joinpath("benchmark", "pics", "$(device)_scatter.png"), 9inch, 6inch), p)
30+
end
31+
32+
plot_benchmark("cpu")
33+
plot_benchmark("gpu")

benchmark/profile.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
using CUDAdrv
2+
using GeometricFlux
3+
using Profile
4+
using ProfileView
5+
6+
ENV["JULIA_NUM_THREADS"] = 1
7+
8+
d = 50
9+
nbins = 20
10+
l = 2^20
11+
12+
hist = zeros(Float32, d, nbins)
13+
δ = rand(Float32, d, l)
14+
idx = rand(1:nbins, l)
15+
scatter_add!(hist, δ, idx)
16+
17+
@profile scatter_add!(hist, δ, idx)
18+
Profile.print()
19+
20+
@profview scatter_add!(hist, δ, idx)
21+
22+
# sudo nvprof --profile-from-start off julia benchmark/scatter.jl
23+
# sudo nvprof --profile-from-start off --print-gpu-trace julia --proj benchmark/scatter.jl
24+
# sudo chown $USER -R $HOME/.julia/
25+
26+
# @profview scatter_add!(hist, δ, idx)
27+
# CUDAdrv.@profile scatter_add!(hist_gpu, δ_gpu, idx_gpu)

benchmark/scatter.jl

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@ using DataFrames
66
using CSV
77
using BenchmarkTools
88
using BenchmarkTools: Trial, TrialEstimate, median, mean
9-
# using ProfileView
9+
10+
ENV["JULIA_NUM_THREADS"] = 1
1011

1112
d = 50
1213
nbins = 20
@@ -53,11 +54,3 @@ CSV.write("benchmark/scatter_julia.tsv", data; delim="\t")
5354
## Benchmark
5455
# @benchmark scatter_add!($hist, $δ, $idx)
5556
# CuArrays.@time scatter_add!(hist_gpu, δ_gpu, idx_gpu)
56-
57-
## Profiling
58-
# sudo nvprof --profile-from-start off julia benchmark/scatter.jl
59-
# sudo nvprof --profile-from-start off --print-gpu-trace julia --proj benchmark/scatter.jl
60-
# sudo chown yuehhua -R /home/yuehhua/.julia/
61-
62-
# @profview scatter_add!(hist, δ, idx)
63-
# CUDAdrv.@profile scatter_add!(hist_gpu, δ_gpu, idx_gpu)

benchmark/scatter_py.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ using BenchmarkTools: Trial, TrialEstimate, median, mean
77
py"""
88
import torch
99
import torch_scatter as sc
10-
torch.set_num_threads(12)
1110
cuda = torch.device("cuda:0")
1211
d = 50
1312
nbins = 20

src/operations/scatter.jl

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@ for op = [:add, :sub, :mul, :div]
88
@eval function $fn(ys::Array{T}, us::Array{T}, xs::Array{<:IntOrTuple}) where {T<:Real}
99
@simd for k = 1:length(xs)
1010
k = CartesianIndices(xs)[k]
11-
@inbounds ys[:, xs[k]...] .= $(name2op[op]).(view(ys, :, xs[k]...), view(us, :, k))
11+
ys_v = view(ys, :, xs[k]...)
12+
us_v = view(us, :, k)
13+
@inbounds ys_v .= $(name2op[op]).(ys_v, us_v)
1214
end
1315
ys
1416
end
@@ -17,15 +19,19 @@ end
1719
function scatter_max!(ys::Array{T}, us::Array{T}, xs::Array{<:IntOrTuple}) where {T<:Real}
1820
@simd for k = 1:length(xs)
1921
k = CartesianIndices(xs)[k]
20-
@inbounds ys[:, xs[k]...] .= max.(view(ys, :, xs[k]...), view(us, :, k))
22+
ys_v = view(ys, :, xs[k]...)
23+
us_v = view(us, :, k)
24+
@inbounds view(ys, :, xs[k]...) .= max.(ys_v, us_v)
2125
end
2226
ys
2327
end
2428

2529
function scatter_min!(ys::Array{T}, us::Array{T}, xs::Array{<:IntOrTuple}) where {T<:Real}
2630
@simd for k = 1:length(xs)
2731
k = CartesianIndices(xs)[k]
28-
@inbounds ys[:, xs[k]...] .= min.(view(ys, :, xs[k]...), view(us, :, k))
32+
ys_v = view(ys, :, xs[k]...)
33+
us_v = view(us, :, k)
34+
@inbounds ys_v .= min.(ys_v, us_v)
2935
end
3036
ys
3137
end
@@ -48,7 +54,9 @@ for op = [:add, :sub, :mul, :div]
4854
xs::StaticArray{<:Tuple,<:IntOrTuple}) where {T<:Real}
4955
@simd for k = 1:length(xs)
5056
k = CartesianIndices(xs)[k]
51-
@inbounds ys[:, xs[k]...] .= $(name2op[op]).(view(ys, :, xs[k]...), view(us, :, k))
57+
ys_v = view(ys, :, xs[k]...)
58+
us_v = view(us, :, k)
59+
@inbounds ys_v .= $(name2op[op]).(ys_v, us_v)
5260
end
5361
ys
5462
end
@@ -58,7 +66,9 @@ function scatter_max!(ys::StaticArray{<:Tuple,T}, us::StaticArray{<:Tuple,T},
5866
xs::StaticArray{<:Tuple,<:IntOrTuple}) where {T<:Real}
5967
@simd for k = 1:length(xs)
6068
k = CartesianIndices(xs)[k]
61-
@inbounds ys[:, xs[k]...] .= max.(view(ys, :, xs[k]...), view(us, :, k))
69+
ys_v = view(ys, :, xs[k]...)
70+
us_v = view(us, :, k)
71+
@inbounds ys_v .= max.(ys_v, us_v)
6272
end
6373
ys
6474
end
@@ -67,7 +77,9 @@ function scatter_min!(ys::StaticArray{<:Tuple,T}, us::StaticArray{<:Tuple,T},
6777
xs::StaticArray{<:Tuple,<:IntOrTuple}) where {T<:Real}
6878
@simd for k = 1:length(xs)
6979
k = CartesianIndices(xs)[k]
70-
@inbounds ys[:, xs[k]...] .= min.(view(ys, :, xs[k]...), view(us, :, k))
80+
ys_v = view(ys, :, xs[k]...)
81+
us_v = view(us, :, k)
82+
@inbounds ys_v .= min.(ys_v, us_v)
7183
end
7284
ys
7385
end
@@ -163,7 +175,7 @@ end
163175
counts += sum(xs.==i) * (xs.==i)
164176
end
165177
@inbounds for ind = CartesianIndices(counts)
166-
Δu[:, ind] ./= counts[ind]
178+
view(Δu, :, ind) ./= counts[ind]
167179
end
168180
(Δ, Δu, nothing)
169181
end

0 commit comments

Comments
 (0)