-
Notifications
You must be signed in to change notification settings - Fork 30
Faster arraydist
with LazyArrays.jl
#231
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
0aad616
cd9c845
3cfa86e
896bce3
fcf9bf4
6d0855b
74d4e38
ae52b81
fcdd588
59423de
bcfdecf
90d3bbc
9a1b201
80b3d51
a792e8b
c9324a3
b3b2786
a43d6e5
8604902
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||
---|---|---|---|---|---|---|---|---|---|---|
|
@@ -48,3 +48,112 @@ parameterless_type(x) = parameterless_type(typeof(x)) | |||||||||
parameterless_type(x::Type) = __parameterless_type(x) | ||||||||||
|
||||||||||
@non_differentiable adapt_randn(::Any...) | ||||||||||
|
||||||||||
""" | ||||||||||
make_closure(f, g) | ||||||||||
|
||||||||||
Return a closure of the form `(x, args...) -> f(g(args...), x)`. | ||||||||||
|
||||||||||
# Examples | ||||||||||
|
||||||||||
This is particularly useful when one wants to avoid broadcasting over constructors | ||||||||||
which can sometimes cause issues with type-inference, in particular when combined | ||||||||||
with reverse-mode AD frameworks. | ||||||||||
|
||||||||||
```juliarepl | ||||||||||
julia> using DistributionsAD, Distributions, ReverseDiff, BenchmarkTools | ||||||||||
|
||||||||||
julia> const data = randn(1000); | ||||||||||
|
||||||||||
julia> x = randn(length(data)); | ||||||||||
|
||||||||||
julia> f(x) = sum(logpdf.(Normal.(x), data)) | ||||||||||
f (generic function with 2 methods) | ||||||||||
|
||||||||||
julia> @btime ReverseDiff.gradient(\$f, \$x); | ||||||||||
848.759 μs (14605 allocations: 521.84 KiB) | ||||||||||
|
||||||||||
julia> # Much faster with ReverseDiff.jl. | ||||||||||
g(x) = let g_inner = DistributionsAD.make_closure(logpdf, Normal) | ||||||||||
sum(g_inner.(data, x)) | ||||||||||
end | ||||||||||
g (generic function with 1 method) | ||||||||||
|
||||||||||
julia> @btime ReverseDiff.gradient(\$g, \$x); | ||||||||||
17.460 μs (17 allocations: 71.52 KiB) | ||||||||||
``` | ||||||||||
|
||||||||||
See https://github.com/TuringLang/Turing.jl/issues/1934 more further discussion. | ||||||||||
|
||||||||||
# Notes | ||||||||||
To really go "vrooom!\" one needs to specialize on the arguments, e.g. if one | ||||||||||
has a function `myfunc` then we need to define | ||||||||||
|
||||||||||
```julia | ||||||||||
make_closure(::typeof(myfunc), ::Type{D}) where {D} = myfunc(D(args...), x) | ||||||||||
``` | ||||||||||
|
||||||||||
This can also be done using `DistributionsAD.@specialize_make_closure`: | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe just missing type parameters in the definition below? |
||||||||||
|
||||||||||
```julia | ||||||||||
julia> mylogpdf(d, x) = logpdf(d, x) | ||||||||||
mylogpdf (generic function with 1 method) | ||||||||||
|
||||||||||
julia> h(x) = let inner = DistributionsAD.make_closure(mylogpdf, Normal) | ||||||||||
sum(inner.(data, x)) | ||||||||||
end | ||||||||||
h (generic function with 1 method) | ||||||||||
|
||||||||||
julia> @btime ReverseDiff.gradient(\$h, \$x); | ||||||||||
1.220 ms (37011 allocations: 1.42 MiB) | ||||||||||
|
||||||||||
julia> DistributionsAD.@specialize_make_closure mylogpdf | ||||||||||
|
||||||||||
julia> @btime ReverseDiff.gradient(\$h, \$x); | ||||||||||
17.038 μs (17 allocations: 71.52 KiB) | ||||||||||
``` | ||||||||||
""" | ||||||||||
make_closure(f, g) = (x, args...) -> f(g(args...), x) | ||||||||||
make_closure(f, ::Type{D}) where {D} = (x, args...) -> f(D(args...), x) | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I wonder if there are any possible performance/compiler benefits by not closing over the variables but to use (more-Julian) callable structs that capture
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I unfortuantely tried that but AFAIK it ends up with this issue of closing over a I might have not done it correctly though. But you suggestion I have tried, and it unfortunately doesn't have an affect. If you just look at the returned closures, they're all the same one 😕 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe something like struct Closure{F,G} end
Closure(::F, ::G) where {F,G} = Closure{F,G}()
Closure(::F, ::Type{G}) where {F,G} = Closure{F,G}()
Closure(::Type{F}, ::G) where {F,G} = Closure{F,G}()
Closure(::Type{F}, ::Type{G}) where {F,G} = Closure{F,G}()
for f in [pdf, logpdf, cdf, logcdf]
@eval (::$(Closure){typeof($f),G})(x, args...) where {G} = $f(G(args...), x)
end ? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Closure(::F, ::Type{G}) where {F,G} = Closure{F,G}() and others avoid the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Have a look at what I've done now 👍 |
||||||||||
|
||||||||||
|
||||||||||
""" | ||||||||||
has_specialized_make_closure(f, g) | ||||||||||
|
||||||||||
Return `true` if there exists a specialized `make_closure(f, g)` implementation. | ||||||||||
""" | ||||||||||
has_specialized_make_closure(f, g) = false | ||||||||||
|
||||||||||
# To go vroooom we need to specialize on the first argument, thus ensuring that | ||||||||||
# a different closure is constructed for each method. | ||||||||||
""" | ||||||||||
@specialize_make_closure(f) | ||||||||||
|
||||||||||
Define `make_closure` and `has_specialized_make_closure` for first first argument being `f` | ||||||||||
and second argument being a type. | ||||||||||
""" | ||||||||||
macro specialize_make_closure(f) | ||||||||||
return quote | ||||||||||
$(DistributionsAD).make_closure(::typeof($(esc(f))), ::Type{D}) where {D} = (x, args...) -> $(esc(f))(D(args...), x) | ||||||||||
$(DistributionsAD).has_specialized_make_closure(::typeof($(esc(f))), ::Type{D}) where {D} = true | ||||||||||
end | ||||||||||
end | ||||||||||
|
||||||||||
""" | ||||||||||
@specialize_make_closure(f, g) | ||||||||||
|
||||||||||
Define `make_closure` and `has_specialized_make_closure` for first first argument being `f` | ||||||||||
and second argument being `g`. | ||||||||||
""" | ||||||||||
macro specialize_make_closure(f, g) | ||||||||||
return quote | ||||||||||
$(DistributionsAD).make_closure(::typeof($(esc(f))), ::typeof($(esc(g)))) = (x, args...) -> $(esc(f))($(esc(g))(args...), x) | ||||||||||
$(DistributionsAD).has_specialized_make_closure(::typeof($(esc(f))), ::typeof{$(esc(g))}) = true | ||||||||||
end | ||||||||||
end | ||||||||||
|
||||||||||
@specialize_make_closure Distributions.pdf | ||||||||||
@specialize_make_closure Distributions.logpdf | ||||||||||
@specialize_make_closure Distributions.loglikelihood | ||||||||||
@specialize_make_closure Distributions.cdf | ||||||||||
@specialize_make_closure Distributions.logcdf | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think it should be possible to remove all of this code. Maybe type parameters are already sufficient. Or using a callable struct might help. |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
using .LazyArrays: BroadcastArray, BroadcastVector, LazyArray | ||
torfjelde marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
const LazyVectorOfUnivariate{ | ||
S<:ValueSupport, | ||
T<:UnivariateDistribution{S}, | ||
Tdists<:BroadcastVector{T}, | ||
} = VectorOfUnivariate{S,T,Tdists} | ||
|
||
_inner_constructor(::Type{<:BroadcastVector{<:Any,Type{D}}}) where {D} = D | ||
|
||
function Distributions._logpdf( | ||
dist::LazyVectorOfUnivariate, | ||
x::AbstractVector{<:Real}, | ||
) | ||
# TODO: Implement chain rule for `LazyArray` constructor to support Zygote. | ||
f = make_closure(logpdf, _inner_constructor(typeof(dist.v))) | ||
# TODO: Make use of `sum(Broadcast.instantiate(Broadcast.broadcasted(f, x, args...)))` once | ||
# we've addressed performance issues in ReverseDiff.jl. | ||
return sum(f.(x, dist.v.args...)) | ||
end | ||
|
||
function Distributions.logpdf( | ||
dist::LazyVectorOfUnivariate, | ||
x::AbstractMatrix{<:Real}, | ||
) | ||
size(x, 1) == length(dist) || | ||
throw(DimensionMismatch("Inconsistent array dimensions.")) | ||
f = make_closure(logpdf, _inner_constructor(typeof(dist.v))) | ||
return vec(sum(f.(x, dist.v.args...), dims = 1)) | ||
end | ||
|
||
const LazyMatrixOfUnivariate{ | ||
S<:ValueSupport, | ||
T<:UnivariateDistribution{S}, | ||
Tdists<:BroadcastArray{T,2}, | ||
} = MatrixOfUnivariate{S,T,Tdists} | ||
|
||
function Distributions._logpdf( | ||
dist::LazyMatrixOfUnivariate, | ||
x::AbstractMatrix{<:Real}, | ||
) | ||
f = make_closure(logpdf, _inner_constructor(typeof(dist.v))) | ||
|
||
return sum(f.(x, dist.v.args)) | ||
end | ||
|
||
lazyarray(f, x...) = BroadcastArray(f, x...) | ||
export lazyarray | ||
Comment on lines
+45
to
+47
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's not clear to me why this is needed. It doesn't seem much shorter and it makes it less clear that everything is based on LazyArrays. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Oh no this was already in DistributionsAD.jl 🤷 Not something I put in here. I was also unaware of this methods existence. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should we deprecate it? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Happy to! |
||
|
||
# Necessary to make `BroadcastArray` work nicely with Zygote. | ||
function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::Type{LazyArrays.BroadcastArray}, f, args...) | ||
return ChainRulesCore.rrule_via_ad(config, Broadcast.broadcasted, f, args...) | ||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This even has a bug in it:
dists
isn't defined..