Skip to content

Faster arraydist with LazyArrays.jl #231

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 19 commits into from
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 1 addition & 39 deletions src/DistributionsAD.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,45 +80,7 @@ include("zygote.jl")
end

@require LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02" begin
using .LazyArrays: BroadcastArray, BroadcastVector, LazyArray

const LazyVectorOfUnivariate{
S<:ValueSupport,
T<:UnivariateDistribution{S},
Tdists<:BroadcastVector{T},
} = VectorOfUnivariate{S,T,Tdists}

function Distributions._logpdf(
dist::LazyVectorOfUnivariate,
x::AbstractVector{<:Real},
)
return sum(copy(logpdf.(dist.v, x)))
end

function Distributions.logpdf(
dist::LazyVectorOfUnivariate,
x::AbstractMatrix{<:Real},
)
size(x, 1) == length(dist) ||
throw(DimensionMismatch("Inconsistent array dimensions."))
return vec(sum(copy(logpdf.(dists, x)), dims = 1))
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This even has a bug in it: dists isn't defined..

end

const LazyMatrixOfUnivariate{
S<:ValueSupport,
T<:UnivariateDistribution{S},
Tdists<:BroadcastArray{T,2},
} = MatrixOfUnivariate{S,T,Tdists}

function Distributions._logpdf(
dist::LazyMatrixOfUnivariate,
x::AbstractMatrix{<:Real},
)
return sum(copy(logpdf.(dist.dists, x)))
end

lazyarray(f, x...) = LazyArray(Base.broadcasted(f, x...))
export lazyarray
include("lazyarrays.jl")
end
end

Expand Down
109 changes: 109 additions & 0 deletions src/common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,112 @@ parameterless_type(x) = parameterless_type(typeof(x))
parameterless_type(x::Type) = __parameterless_type(x)

@non_differentiable adapt_randn(::Any...)

"""
make_closure(f, g)

Return a closure of the form `(x, args...) -> f(g(args...), x)`.

# Examples

This is particularly useful when one wants to avoid broadcasting over constructors
which can sometimes cause issues with type-inference, in particular when combined
with reverse-mode AD frameworks.

```juliarepl
julia> using DistributionsAD, Distributions, ReverseDiff, BenchmarkTools

julia> const data = randn(1000);

julia> x = randn(length(data));

julia> f(x) = sum(logpdf.(Normal.(x), data))
f (generic function with 2 methods)

julia> @btime ReverseDiff.gradient(\$f, \$x);
848.759 μs (14605 allocations: 521.84 KiB)

julia> # Much faster with ReverseDiff.jl.
g(x) = let g_inner = DistributionsAD.make_closure(logpdf, Normal)
sum(g_inner.(data, x))
end
g (generic function with 1 method)

julia> @btime ReverseDiff.gradient(\$g, \$x);
17.460 μs (17 allocations: 71.52 KiB)
```

See https://github.com/TuringLang/Turing.jl/issues/1934 more further discussion.

# Notes
To really go "vrooom!\" one needs to specialize on the arguments, e.g. if one
has a function `myfunc` then we need to define

```julia
make_closure(::typeof(myfunc), ::Type{D}) where {D} = myfunc(D(args...), x)
```

This can also be done using `DistributionsAD.@specialize_make_closure`:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe just missing type parameters in the definition below?


```julia
julia> mylogpdf(d, x) = logpdf(d, x)
mylogpdf (generic function with 1 method)

julia> h(x) = let inner = DistributionsAD.make_closure(mylogpdf, Normal)
sum(inner.(data, x))
end
h (generic function with 1 method)

julia> @btime ReverseDiff.gradient(\$h, \$x);
1.220 ms (37011 allocations: 1.42 MiB)

julia> DistributionsAD.@specialize_make_closure mylogpdf

julia> @btime ReverseDiff.gradient(\$h, \$x);
17.038 μs (17 allocations: 71.52 KiB)
```
"""
make_closure(f, g) = (x, args...) -> f(g(args...), x)
make_closure(f, ::Type{D}) where {D} = (x, args...) -> f(D(args...), x)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if there are any possible performance/compiler benefits by not closing over the variables but to use (more-Julian) callable structs that capture f and g. In any case, I think you want

Suggested change
make_closure(f, g) = (x, args...) -> f(g(args...), x)
make_closure(f, ::Type{D}) where {D} = (x, args...) -> f(D(args...), x)
make_closure(f::F, g::G) where {F,G} = (x, args...) -> f(g(args...), x)
make_closure(f::F, ::Type{D}) where {F,D} = (x, args...) -> f(D(args...), x)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I unfortuantely tried that but AFAIK it ends up with this issue of closing over a UnionAll type again, which is exactly what we're trying to avoid (because of the issues it's causing with some AD backends) 😕

I might have not done it correctly though.

But you suggestion I have tried, and it unfortunately doesn't have an affect. If you just look at the returned closures, they're all the same one 😕

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe something like

struct Closure{F,G} end

Closure(::F, ::G) where {F,G} = Closure{F,G}()
Closure(::F, ::Type{G}) where {F,G} = Closure{F,G}()
Closure(::Type{F}, ::G) where {F,G} = Closure{F,G}()
Closure(::Type{F}, ::Type{G}) where {F,G} = Closure{F,G}()

for f in [pdf, logpdf, cdf, logcdf]
    @eval (::$(Closure){typeof($f),G})(x, args...) where {G} = $f(G(args...), x)
end

?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Closure(::F, ::Type{G}) where {F,G} = Closure{F,G}()

and others avoid the UnionAll issue.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Have a look at what I've done now 👍



"""
has_specialized_make_closure(f, g)

Return `true` if there exists a specialized `make_closure(f, g)` implementation.
"""
has_specialized_make_closure(f, g) = false

# To go vroooom we need to specialize on the first argument, thus ensuring that
# a different closure is constructed for each method.
"""
@specialize_make_closure(f)

Define `make_closure` and `has_specialized_make_closure` for first first argument being `f`
and second argument being a type.
"""
macro specialize_make_closure(f)
return quote
$(DistributionsAD).make_closure(::typeof($(esc(f))), ::Type{D}) where {D} = (x, args...) -> $(esc(f))(D(args...), x)
$(DistributionsAD).has_specialized_make_closure(::typeof($(esc(f))), ::Type{D}) where {D} = true
end
end

"""
@specialize_make_closure(f, g)

Define `make_closure` and `has_specialized_make_closure` for first first argument being `f`
and second argument being `g`.
"""
macro specialize_make_closure(f, g)
return quote
$(DistributionsAD).make_closure(::typeof($(esc(f))), ::typeof($(esc(g)))) = (x, args...) -> $(esc(f))($(esc(g))(args...), x)
$(DistributionsAD).has_specialized_make_closure(::typeof($(esc(f))), ::typeof{$(esc(g))}) = true
end
end

@specialize_make_closure Distributions.pdf
@specialize_make_closure Distributions.logpdf
@specialize_make_closure Distributions.loglikelihood
@specialize_make_closure Distributions.cdf
@specialize_make_closure Distributions.logcdf
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it should be possible to remove all of this code. Maybe type parameters are already sufficient. Or using a callable struct might help.

53 changes: 53 additions & 0 deletions src/lazyarrays.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
using .LazyArrays: BroadcastArray, BroadcastVector, LazyArray

const LazyVectorOfUnivariate{
S<:ValueSupport,
T<:UnivariateDistribution{S},
Tdists<:BroadcastVector{T},
} = VectorOfUnivariate{S,T,Tdists}

_inner_constructor(::Type{<:BroadcastVector{<:Any,Type{D}}}) where {D} = D

function Distributions._logpdf(
dist::LazyVectorOfUnivariate,
x::AbstractVector{<:Real},
)
# TODO: Implement chain rule for `LazyArray` constructor to support Zygote.
f = make_closure(logpdf, _inner_constructor(typeof(dist.v)))
# TODO: Make use of `sum(Broadcast.instantiate(Broadcast.broadcasted(f, x, args...)))` once
# we've addressed performance issues in ReverseDiff.jl.
return sum(f.(x, dist.v.args...))
end

function Distributions.logpdf(
dist::LazyVectorOfUnivariate,
x::AbstractMatrix{<:Real},
)
size(x, 1) == length(dist) ||
throw(DimensionMismatch("Inconsistent array dimensions."))
f = make_closure(logpdf, _inner_constructor(typeof(dist.v)))
return vec(sum(f.(x, dist.v.args...), dims = 1))
end

const LazyMatrixOfUnivariate{
S<:ValueSupport,
T<:UnivariateDistribution{S},
Tdists<:BroadcastArray{T,2},
} = MatrixOfUnivariate{S,T,Tdists}

function Distributions._logpdf(
dist::LazyMatrixOfUnivariate,
x::AbstractMatrix{<:Real},
)
f = make_closure(logpdf, _inner_constructor(typeof(dist.v)))

return sum(f.(x, dist.v.args))
end

lazyarray(f, x...) = BroadcastArray(f, x...)
export lazyarray
Comment on lines +45 to +47
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not clear to me why this is needed. It doesn't seem much shorter and it makes it less clear that everything is based on LazyArrays.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh no this was already in DistributionsAD.jl 🤷 Not something I put in here. I was also unaware of this methods existence.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we deprecate it?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Happy to!


# Necessary to make `BroadcastArray` work nicely with Zygote.
function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::Type{LazyArrays.BroadcastArray}, f, args...)
return ChainRulesCore.rrule_via_ad(config, Broadcast.broadcasted, f, args...)
end