Faster arraydist with LazyArrays.jl #231

Closed
wants to merge 19 commits into from
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,6 +1,6 @@
name = "DistributionsAD"
uuid = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
-version = "0.6.43"
+version = "0.6.44"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
47 changes: 8 additions & 39 deletions src/DistributionsAD.jl
@@ -70,6 +70,13 @@ include("zygote.jl")
end
end

@require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" begin
using .Zygote: Zygote
# HACK: Make Zygote (correctly) recognize that it should use `ForwardDiff` for broadcasting.
# See `is_diff_safe` for more information.
Zygote._dual_purefun(::Type{C}) where {C<:Closure} = is_diff_safe(C)
end
Comment on lines +73 to +78

Make into a weak dep?

Member

Haha yeah that should be done with basically everything in this repo. Or rather, most things should be moved to Distributions and added as a weak dep there to fix the type piracy issues.

@ToucheSir, Feb 7, 2023

Ideally is_diff_safe would be a function ForwardDiff or one of its dependencies owns. That would avoid the need for a package extension or Requires block overriding an internal function like _dual_purefun.


@require Tracker="9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" begin
using DiffRules
using SpecialFunctions
@@ -80,45 +87,7 @@ include("zygote.jl")
end

@require LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02" begin
using .LazyArrays: BroadcastArray, BroadcastVector, LazyArray

const LazyVectorOfUnivariate{
S<:ValueSupport,
T<:UnivariateDistribution{S},
Tdists<:BroadcastVector{T},
} = VectorOfUnivariate{S,T,Tdists}

function Distributions._logpdf(
dist::LazyVectorOfUnivariate,
x::AbstractVector{<:Real},
)
return sum(copy(logpdf.(dist.v, x)))
end

function Distributions.logpdf(
dist::LazyVectorOfUnivariate,
x::AbstractMatrix{<:Real},
)
size(x, 1) == length(dist) ||
throw(DimensionMismatch("Inconsistent array dimensions."))
return vec(sum(copy(logpdf.(dists, x)), dims = 1))
Member Author

This even has a bug in it: dists isn't defined.

end

const LazyMatrixOfUnivariate{
S<:ValueSupport,
T<:UnivariateDistribution{S},
Tdists<:BroadcastArray{T,2},
} = MatrixOfUnivariate{S,T,Tdists}

function Distributions._logpdf(
dist::LazyMatrixOfUnivariate,
x::AbstractMatrix{<:Real},
)
return sum(copy(logpdf.(dist.dists, x)))
end

lazyarray(f, x...) = LazyArray(Base.broadcasted(f, x...))
export lazyarray
include("lazyarrays.jl")
end
end

90 changes: 90 additions & 0 deletions src/common.jl
@@ -48,3 +48,93 @@ parameterless_type(x) = parameterless_type(typeof(x))
parameterless_type(x::Type) = __parameterless_type(x)

@non_differentiable adapt_randn(::Any...)


"""
Closure{F,G}

A callable of the form `(x, args...) -> F(G(args...), x)`.

# Examples

This is particularly useful when one wants to avoid broadcasting over constructors
which can sometimes cause issues with type-inference, in particular when combined
with reverse-mode AD frameworks.

```juliarepl
julia> using DistributionsAD, Distributions, ReverseDiff, BenchmarkTools

julia> const data = randn(1000);

julia> x = randn(length(data));

julia> f(x) = sum(logpdf.(Normal.(x), data))
f (generic function with 2 methods)

julia> @btime ReverseDiff.gradient(\$f, \$x);
848.759 μs (14605 allocations: 521.84 KiB)

julia> # Much faster with ReverseDiff.jl.
g(x) = sum(DistributionsAD.Closure(logpdf, Normal).(data, x))
g (generic function with 1 method)

julia> @btime ReverseDiff.gradient(\$g, \$x);
17.460 μs (17 allocations: 71.52 KiB)
```

See https://github.com/TuringLang/Turing.jl/issues/1934 for further discussion.
"""
struct Closure{F,G} end

Closure(::F, ::G) where {F,G} = Closure{F,G}()
Closure(::F, ::Type{G}) where {F,G} = Closure{F,G}()
Closure(::Type{F}, ::G) where {F,G} = Closure{F,G}()
Closure(::Type{F}, ::Type{G}) where {F,G} = Closure{F,G}()
Comment on lines +87 to +92
Member

This is not really what I had in mind. More something like Base.Fix1:

Suggested change
struct Closure{F,G} end
Closure(::F, ::G) where {F,G} = Closure{F,G}()
Closure(::F, ::Type{G}) where {F,G} = Closure{F,G}()
Closure(::Type{F}, ::G) where {F,G} = Closure{F,G}()
Closure(::Type{F}, ::Type{G}) where {F,G} = Closure{F,G}()
struct Closure{F,G}
f::F
g::G
end
Closure(f::F, g::G) where {F,G} = Closure{F,G}(f, g)
Closure(f::F, ::Type{G}) where {F,G} = Closure{F,Type{G}}(f, G)
Closure(::Type{F}, g::G) where {F,G} = Closure{Type{F},G}(F, g)
Closure(::Type{F}, ::Type{G}) where {F,G} = Closure{Type{F},Type{G}}(F, G)

Generally, just storing the type of f is not sufficient: if f is, e.g., a callable struct with fields, the type F alone does not provide enough information.

However, with fields in the struct, the performance with ReverseDiff is bad, since we then hit https://github.com/JuliaDiff/ReverseDiff.jl/blob/d522508aa6fea16e9716607cdd27d63453bb61e6/src/derivatives/broadcast.jl#L27. This can be fixed by defining

ReverseDiff.mayhavetracked(c::Closure) = ReverseDiff.mayhavetracked(c.f) || ReverseDiff.mayhavetracked(c.g)

I wonder if we can just improve the heuristics in ReverseDiff to use a similar check for structs/types with multiple fields.
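
To make the trade-off concrete, here is a minimal, package-free sketch of a field-carrying closure in the spirit of Base.Fix1. The names FieldClosure, Shift, and score are hypothetical stand-ins (for Closure, a distribution constructor, and logpdf respectively); none of them is part of DistributionsAD, and the sketch says nothing about ReverseDiff performance.

```julia
# Hypothetical stand-in for a distribution type and its constructor.
struct Shift
    μ::Float64
end

# Hypothetical stand-in for `logpdf(dist, x)`.
score(d::Shift, x::Real) = -(x - d.μ)^2

# Field-carrying closure: stores the callables themselves, like `Base.Fix1` does.
struct FieldClosure{F,G}
    f::F
    g::G
end

# Keep the concrete type in the parameter when `g` is a type (e.g. a constructor),
# so that the type parameter is `Type{Shift}` rather than `DataType`.
FieldClosure(f::F, ::Type{G}) where {F,G} = FieldClosure{F,Type{G}}(f, G)

# Same calling convention as in the PR: `(x, args...) -> f(g(args...), x)`.
(c::FieldClosure)(x, args...) = c.f(c.g(args...), x)

cl = FieldClosure(score, Shift)
xs = [1.0, 3.0]
μs = [0.0, 1.0]
vals = cl.(xs, μs)  # broadcasts the callable struct; no array of `Shift`s is stored
```

Because the callables live in fields, a callable struct with its own state would be preserved here, which is exactly the information a type-only Closure{F,G} cannot carry.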

Member Author

So I originally had fields but yeah this resulted in bad computation paths. Might be something that should be changed in the AD instead, true.

Member Author

Worth pointing out that

  1. Closure is not something that's meant to be used heavily for arbitrary callables (e.g. the checks in the adjoint explicitly exclude the scenario where we have fields).
  2. Closure should not be used by the end-user.


"""
is_diff_safe(f)

Return `true` if it's safe to ignore gradients wrt. `f` when computing `f`.

Useful for checking it's okay to take faster paths in pullbacks for certain AD backends.

# Examples

```jldoctest
julia> using Distributions

julia> using DistributionsAD: is_diff_safe, Closure

julia> is_diff_safe(typeof(logpdf))
true

julia> is_diff_safe(typeof(x -> 2x))
true

julia> # But it fails if we make a closure over a variable, which we might want to compute
# the gradient with respect to.
makef(x) = y -> x + y
makef (generic function with 1 method)

julia> is_diff_safe(typeof(makef([1.0])))
false

julia> # Also works on `Closure`s from `DistributionsAD`.
is_diff_safe(typeof(Closure(logpdf, Normal)))
true

julia> is_diff_safe(typeof(Closure(logpdf, makef([1.0]))))
false
```
"""
@inline is_diff_safe(_) = false
Member

Member Author

Yeah basically, though I was looking at Zygote's _dual_purefun because ReverseDiff already hits the broadcast_forward, even without the custom adjoint I defined (ReverseDiff also doesn't support calling back into AD).

@inline is_diff_safe(::Type) = true
@inline is_diff_safe(::Type{F}) where {F<:Function} = Base.issingletontype(F)
@inline is_diff_safe(::Type{Closure{F,G}}) where {F,G} = is_diff_safe(F) && is_diff_safe(G)

@generated function (closure::Closure{F,G})(x, args...) where {F,G}
f = Base.issingletontype(F) ? F.instance : F
g = Base.issingletontype(G) ? G.instance : G
Comment on lines +135 to +136
Member Author

I'm not 100% certain about this. Need to think when I've had some sleep.

return :($f($g(args...), x))
end
Comment on lines +134 to +138
Member

I'm not a fan of @generated functions if there is a different way to get good performance; they have too many limitations IMO.

If we want to keep Closure without fields (not sure, maybe it would be better to change the heuristics in ReverseDiff), then the following seems to work at least in the example above:

julia> struct Closure{F,G} end

julia> Closure(::F, ::G) where {F,G} = Closure{F,G}()

julia> Closure(::F, ::Type{G}) where {F,G} = Closure{F,Type{G}}()

julia> Closure(::Type{F}, ::G) where {F,G} = Closure{Type{F},G}()

julia> Closure(::Type{F}, ::Type{G}) where {F,G} = Closure{Type{F},Type{G}}()

julia> (::Closure{F,G})(x, args...) where {F,G} = F.instance(G.instance(args...), x)

julia> (::Closure{F,Type{G}})(x, args...) where {F,G} = F.instance(G(args...), x)

julia> (::Closure{Type{F},G})(x, args...) where {F,G} = F(G.instance(args...), x)

julia> (::Closure{Type{F},Type{G}})(x, args...) where {F,G} = F(G(args...), x)

But somehow this version and the one in the PR seem all a bit hacky...

Member Author

> But somehow this version and the one in the PR seem all a bit hacky...

As mentioned before, I 100% agree with you. But this performance issue is literally the cause of several Slack and Discourse threads of people going "why is Turing so slow for this simple model?", and so IMO we should just get this fixed despite its hackiness and then we make it less hacky as we go + maybe improve ReverseDiff and Zygote.

Member Author

Oh and regarding the @generated, I'm happy to do away with it. I'll try your suggestion 👍



99 changes: 99 additions & 0 deletions src/lazyarrays.jl
@@ -0,0 +1,99 @@
using .LazyArrays: BroadcastArray, BroadcastVector, LazyArray

const LazyVectorOfUnivariate{
S<:ValueSupport,
T<:UnivariateDistribution{S},
Tdists<:BroadcastVector{T},
} = VectorOfUnivariate{S,T,Tdists}

_inner_constructor(::Type{<:BroadcastVector{<:Any,Type{D}}}) where {D} = D

function Distributions._logpdf(
dist::LazyVectorOfUnivariate,
x::AbstractVector{<:Real},
)
# TODO: Make use of `sum(Broadcast.instantiate(Broadcast.broadcasted(f, x, args...)))` once
# we've addressed performance issues in ReverseDiff.jl.
constructor = _inner_constructor(typeof(dist.v))
Member

I'm sure this will be problematic in some cases and break. It's not guaranteed that _inner_constructor returns a proper constructor. Something safer would be https://github.com/JuliaObjects/ConstructionBase.jl I assume.

Member

A simple example:

julia> struct A{X,Y}
           x::X
           y::Y
           A(x::X, y::Y) where {X,Y} = new{X,Y}(x, y)
       end

julia> _constructor(::Type{D}) where {D} = D
_constructor (generic function with 1 method)

julia> x, y = 1, 2.0
(1, 2.0)

julia> a = A(x, y)
A{Int64, Float64}(1, 2.0)

julia> _constructor(typeof(a))(x, y)
ERROR: MethodError: no method matching A{Int64, Float64}(::Int64, ::Float64)
Stacktrace:
 [1] top-level scope
   @ REPL[31]:1

Member Author

So I was actually using ConstructionBase locally for this before:) But I removed it because I figured this will only be used for a very simple subset of constructors, so uncertain if it's worth it. But I'll add it back again:)
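
For reference, the failure in the example above comes from calling the fully parameterized type. A minimal sketch of the workaround is to strip the type parameters first, so that the call dispatches to A(x, y). The helper name strip_parameters is hypothetical; it is similar in spirit to what ConstructionBase.jl is meant to provide, but treat that equivalence as an assumption rather than a description of its implementation.

```julia
# Same struct as in the review example: only the call `A(x, y)` is defined,
# so `A{Int64,Float64}(x, y)` has no matching method.
struct A{X,Y}
    x::X
    y::Y
    A(x::X, y::Y) where {X,Y} = new{X,Y}(x, y)
end

# Hypothetical helper: drop the type parameters and return the `UnionAll` wrapper,
# so that calling the result goes through `A(x, y)`.
strip_parameters(::Type{T}) where {T} = Base.typename(T).wrapper

a = A(1, 2.0)
rebuilt = strip_parameters(typeof(a))(1, 2.0)
```

This only helps when the parameter-free type is itself a valid constructor for the given arguments, which is still an assumption about the wrapped type.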

return sum(Closure(logpdf, constructor).(x, dist.v.args...))
end

function Distributions.logpdf(
dist::LazyVectorOfUnivariate,
x::AbstractMatrix{<:Real},
)
size(x, 1) == length(dist) ||
throw(DimensionMismatch("Inconsistent array dimensions."))
constructor = _inner_constructor(typeof(dist.v))
return vec(sum(Closure(logpdf, constructor).(x, dist.v.args...), dims = 1))
end

const LazyMatrixOfUnivariate{
S<:ValueSupport,
T<:UnivariateDistribution{S},
Tdists<:BroadcastArray{T,2},
} = MatrixOfUnivariate{S,T,Tdists}

function Distributions._logpdf(
dist::LazyMatrixOfUnivariate,
x::AbstractMatrix{<:Real},
)

constructor = _inner_constructor(typeof(dist.v))
return sum(Closure(logpdf, constructor).(x, dist.v.args))
end

lazyarray(f, x...) = BroadcastArray(f, x...)
export lazyarray
Comment on lines +45 to +47
Member

It's not clear to me why this is needed. It doesn't seem much shorter and it makes it less clear that everything is based on LazyArrays.

Member Author

Oh no, this was already in DistributionsAD.jl 🤷 Not something I put in here. I was also unaware of this method's existence.

Member

Should we deprecate it?

Member Author

Happy to!


# HACK: All of the below probably shouldn't be here.
function ChainRulesCore.rrule(::Type{BroadcastArray}, f, args...)
function BroadcastArray_pullback(Δ::ChainRulesCore.Tangent)
return (ChainRulesCore.NoTangent(), Δ.f, Δ.args...)
end
return BroadcastArray(f, args...), BroadcastArray_pullback
end

ChainRulesCore.ProjectTo(ba::BroadcastArray) = ProjectTo{typeof(ba)}((f=ba.f,))
function (p::ChainRulesCore.ProjectTo{BA})(args...) where {BA<:BroadcastArray}
return ChainRulesCore.Tangent{BA}(f=p.f, args=args)
end
Comment on lines +50 to +60
Member

I'm surprised this is needed. Feels like that's the default for Tangents anyway?

Member Author

Yeah, so we can also just close over the function f and construct the Tangent directly in the adjoint (in fact, this is what I did originally), but I thought maybe ProjectTo was the more "proper" way to do it. I can go back to directly constructing the Tangent though:)


function ChainRulesCore.rrule(
Member

BTW, the annoying thing about these kinds of general rules is that they might break code that would have worked without the rule, had one just let the AD system perform its default differentiation (this has happened to me multiple times). One can fix these issues by using @opt_out (but that also has some problems...).

Member Author

We could also just define this using ZygoteRules.@adjoint if that helps?

Unfortunately there's no way around this because we have to stop Zygote from trying to differentiate through the broadcasted constructor.

config::ChainRulesCore.RuleConfig{>:ChainRulesCore.HasReverseMode},
::typeof(logpdf),
dist::LazyVectorOfUnivariate,
x::AbstractVector{<:Real}
)
# Extract the constructor used in the `BroadcastArray`.
constructor = DistributionsAD._inner_constructor(typeof(dist.v))
Member

As mentioned above, it's not guaranteed that this actually is a constructor that can be called with the arguments.

Member

BTW is the qualification really needed?


# If it's not safe to ignore the `constructor` in the pullback, then we fall back
# to the default implementation.
is_diff_safe(constructor) || return ChainRulesCore.rrule_via_ad(config, (d,x) -> sum(logpdf.(d.v, x)), dist, x)
Member

Could we just use

Suggested change
is_diff_safe(constructor) || return ChainRulesCore.rrule_via_ad(config, (d,x) -> sum(logpdf.(d.v, x)), dist, x)
is_diff_safe(constructor) || return ChainRulesCore.rrule_via_ad(config, (d,x) -> logpdf(d, x), dist, x)

to avoid making any assumptions about how logpdf(dist, x) is implemented?

Member

Alternatively maybe just return nothing?

Suggested change
is_diff_safe(constructor) || return ChainRulesCore.rrule_via_ad(config, (d,x) -> sum(logpdf.(d.v, x)), dist, x)
is_diff_safe(constructor) || return nothing

(https://juliadiff.org/ChainRulesCore.jl/stable/ad_author/opt_out.html)

Member Author

> to avoid making any assumptions about how logpdf(dist, x) is implemented?

But won't this hit the same rrule once you get to logpdf?

> Alternatively maybe just return nothing?

I don't completely understand this. So opting out means it will fall back to the underlying AD? I thought it meant "nothing to differentiate here"
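
As far as the linked ChainRulesCore opt-out page describes it, a rule that returns nothing signals "no rule applies here", so the AD system differentiates the function itself instead of treating it as non-differentiable. The following package-free sketch only mimics that convention; toy_rule and toy_gradient are hypothetical names, and the finite-difference fallback is a stand-in for "let the AD do its default thing", not how any real AD backend works.

```julia
# Hypothetical rule lookup: the generic fallback returns `nothing`, meaning "no rule here".
toy_rule(f, x) = nothing

# A hand-written rule for `sin`: returns the value and a pullback.
toy_rule(::typeof(sin), x::Real) = (sin(x), Δ -> Δ * cos(x))

# Hypothetical "AD" driver: use the rule if there is one; otherwise fall back to
# differentiating `f` itself (here crudely, via central finite differences).
function toy_gradient(f, x::Real; h=1e-6)
    rule = toy_rule(f, x)
    if rule === nothing
        return (f(x + h) - f(x - h)) / (2h)
    end
    _, pullback = rule
    return pullback(1.0)
end

d_sin = toy_gradient(sin, 0.0)           # uses the hand-written rule
d_square = toy_gradient(x -> x^2, 3.0)   # no rule, so the fallback path is taken
```

Under this reading, returning nothing from the fast-path rule when the constructor is not safe to ignore would hand the whole call back to the AD backend rather than re-entering the same rule.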


# Otherwise, we use `Closure`.
cl = DistributionsAD.Closure(logpdf, constructor)

# Construct pullbacks manually to avoid the constructor of `BroadcastArray`.
y, dy = ChainRulesCore.rrule_via_ad(config, broadcast, cl, x, dist.v.args...)
z, dz = ChainRulesCore.rrule_via_ad(config, sum, y)

project_broadcastarray = ChainRulesCore.ProjectTo(dist.v)
function logpdf_adjoint(Δ...)
# 1st argument is `sum` -> nothing.
(_, sum_Δ...) = dz(Δ...)
Member

Generally you might have to deal with unthunk I assume.

Member Author

But should the other pullbacks also deal with this?

Member Author

Seems like we only need to deal with unthunk in the pullback for the BroadcastArray constructor, no?

# 1st argument is `broadcast` -> nothing.
# 2nd argument is `cl` -> `nothing`.
# 3rd argument is `x` -> something.
# Rest is `dist` arguments -> something
(_, _, x_Δ, args_Δ...) = dy(sum_Δ...)
Member Author

One thing I'm worried about: what if f is a closure containing variables we want to differentiate with respect to? 😬

Member Author

Okay, so that should be addressed with the is_diff_safe above. We will now only hit this faster path if we're certain we don't need to take derivatives wrt. anything "in" the constructor itself.
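
The distinction that check relies on can be seen in isolation with Base.issingletontype, which the PR's is_diff_safe uses for function types. This is a minimal sketch with illustrative variable names (plain, anon, capturing), not part of the PR.

```julia
# A closure factory: the returned function captures `x`.
makef(x) = y -> x + y

plain = typeof(sin)               # ordinary function: no captured state
anon = typeof(x -> 2x)            # anonymous function without captures
capturing = typeof(makef([1.0]))  # closure carrying `x` as a field

# A singleton function type has no fields, so there is nothing inside it
# that a gradient could be taken with respect to.
flags = map(Base.issingletontype, (plain, anon, capturing))
```

Only the last type carries state (the captured array), which is the case where skipping the gradient wrt. the callable would be wrong.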

# Construct the structural tangents.
ba_tangent = project_broadcastarray(args_Δ...)
dist_tangent = ChainRulesCore.Tangent{typeof(dist)}(v=ba_tangent)

return (ChainRulesCore.NoTangent(), dist_tangent, x_Δ)
end

return z, logpdf_adjoint
end