Faster arraydist
with LazyArrays.jl
#231
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -70,6 +70,13 @@ include("zygote.jl") | |
end | ||
end | ||
|
||
@require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" begin | ||
using .Zygote: Zygote | ||
# HACK: Make Zygote (correctly) recognize that it should use `ForwardDiff` for broadcasting. | ||
# See `is_diff_safe` for more information. | ||
Zygote._dual_purefun(::Type{C}) where {C<:Closure} = is_diff_safe(C) | ||
end | ||
|
||
@require Tracker="9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" begin | ||
using DiffRules | ||
using SpecialFunctions | ||
|
@@ -80,45 +87,7 @@ include("zygote.jl") | |
end | ||
|
||
@require LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02" begin | ||
using .LazyArrays: BroadcastArray, BroadcastVector, LazyArray | ||
|
||
const LazyVectorOfUnivariate{ | ||
S<:ValueSupport, | ||
T<:UnivariateDistribution{S}, | ||
Tdists<:BroadcastVector{T}, | ||
} = VectorOfUnivariate{S,T,Tdists} | ||
|
||
function Distributions._logpdf( | ||
dist::LazyVectorOfUnivariate, | ||
x::AbstractVector{<:Real}, | ||
) | ||
return sum(copy(logpdf.(dist.v, x))) | ||
end | ||
|
||
function Distributions.logpdf( | ||
dist::LazyVectorOfUnivariate, | ||
x::AbstractMatrix{<:Real}, | ||
) | ||
size(x, 1) == length(dist) || | ||
throw(DimensionMismatch("Inconsistent array dimensions.")) | ||
return vec(sum(copy(logpdf.(dists, x)), dims = 1)) | ||
Comment: This even has a bug in it:
||
end | ||
|
||
const LazyMatrixOfUnivariate{ | ||
S<:ValueSupport, | ||
T<:UnivariateDistribution{S}, | ||
Tdists<:BroadcastArray{T,2}, | ||
} = MatrixOfUnivariate{S,T,Tdists} | ||
|
||
function Distributions._logpdf( | ||
dist::LazyMatrixOfUnivariate, | ||
x::AbstractMatrix{<:Real}, | ||
) | ||
return sum(copy(logpdf.(dist.dists, x))) | ||
end | ||
|
||
lazyarray(f, x...) = LazyArray(Base.broadcasted(f, x...)) | ||
export lazyarray | ||
include("lazyarrays.jl") | ||
end | ||
end | ||
|
||
|
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -48,3 +48,93 @@ parameterless_type(x) = parameterless_type(typeof(x)) | |||||||||||||||||||||||||||||||
parameterless_type(x::Type) = __parameterless_type(x) | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
@non_differentiable adapt_randn(::Any...) | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||
Closure{F,G} | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
A callable of the form `(x, args...) -> F(G(args...), x)`. | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
# Examples | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
This is particularly useful when one wants to avoid broadcasting over constructors | ||||||||||||||||||||||||||||||||
which can sometimes cause issues with type-inference, in particular when combined | ||||||||||||||||||||||||||||||||
with reverse-mode AD frameworks. | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
```juliarepl | ||||||||||||||||||||||||||||||||
julia> using DistributionsAD, Distributions, ReverseDiff, BenchmarkTools | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
julia> const data = randn(1000); | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
julia> x = randn(length(data)); | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
julia> f(x) = sum(logpdf.(Normal.(x), data)) | ||||||||||||||||||||||||||||||||
f (generic function with 2 methods) | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
julia> @btime ReverseDiff.gradient(\$f, \$x); | ||||||||||||||||||||||||||||||||
848.759 μs (14605 allocations: 521.84 KiB) | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
julia> # Much faster with ReverseDiff.jl. | ||||||||||||||||||||||||||||||||
g(x) = sum(DistributionsAD.Closure(logpdf, Normal).(data, x)) | ||||||||||||||||||||||||||||||||
g (generic function with 1 method) | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
julia> @btime ReverseDiff.gradient(\$g, \$x); | ||||||||||||||||||||||||||||||||
17.460 μs (17 allocations: 71.52 KiB) | ||||||||||||||||||||||||||||||||
``` | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
See https://github.com/TuringLang/Turing.jl/issues/1934 for further discussion.
""" | ||||||||||||||||||||||||||||||||
struct Closure{F,G} end | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
Closure(::F, ::G) where {F,G} = Closure{F,G}() | ||||||||||||||||||||||||||||||||
Closure(::F, ::Type{G}) where {F,G} = Closure{F,G}() | ||||||||||||||||||||||||||||||||
Closure(::Type{F}, ::G) where {F,G} = Closure{F,G}() | ||||||||||||||||||||||||||||||||
Closure(::Type{F}, ::Type{G}) where {F,G} = Closure{F,G}() | ||||||||||||||||||||||||||||||||
Comment on lines +87 to +92
Comment: This is not really what I had in mind. More something like
Suggested change
Generally just storing the type of e.g. However, with fields in the struct, the performance with ReverseDiff is bad, since we then hit https://github.com/JuliaDiff/ReverseDiff.jl/blob/d522508aa6fea16e9716607cdd27d63453bb61e6/src/derivatives/broadcast.jl#L27. This can be fixed by defining
ReverseDiff.mayhavetracked(c::Closure) = ReverseDiff.mayhavetracked(c.f) || ReverseDiff.mayhavetracked(c.g)
I wonder if we can just improve the heuristics in ReverseDiff to use a similar check for structs/types with multiple fields.
Comment: So I originally had fields, but yeah, this resulted in bad computation paths. Might be something that should be changed in the AD instead, true.
Comment: Worth pointing out that
|
||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||
is_diff_safe(f) | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
Return `true` if it's safe to ignore gradients wrt. `f` when computing `f`. | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
Useful for checking it's okay to take faster paths in pullbacks for certain AD backends. | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
# Examples | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
```jldoctest | ||||||||||||||||||||||||||||||||
julia> using Distributions | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
julia> using DistributionsAD: is_diff_safe, Closure | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
julia> is_diff_safe(typeof(logpdf)) | ||||||||||||||||||||||||||||||||
true | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
julia> is_diff_safe(typeof(x -> 2x)) | ||||||||||||||||||||||||||||||||
true | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
julia> # But it fails if we make a closure over a variable, which we might want to compute | ||||||||||||||||||||||||||||||||
# the gradient with respect to. | ||||||||||||||||||||||||||||||||
makef(x) = y -> x + y | ||||||||||||||||||||||||||||||||
makef (generic function with 1 method) | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
julia> is_diff_safe(typeof(makef([1.0]))) | ||||||||||||||||||||||||||||||||
false | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
julia> # Also works on `Closure`s from `DistributionsAD`. | ||||||||||||||||||||||||||||||||
is_diff_safe(typeof(Closure(logpdf, Normal))) | ||||||||||||||||||||||||||||||||
true | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
julia> is_diff_safe(typeof(Closure(logpdf, makef([1.0])))) | ||||||||||||||||||||||||||||||||
false
```
"""
@inline is_diff_safe(_) = false | ||||||||||||||||||||||||||||||||
Comment: It seems this serves the same purpose as
Comment: Yeah basically, though I was looking at Zygote's
||||||||||||||||||||||||||||||||
@inline is_diff_safe(::Type) = true | ||||||||||||||||||||||||||||||||
@inline is_diff_safe(::Type{F}) where {F<:Function} = Base.issingletontype(F) | ||||||||||||||||||||||||||||||||
@inline is_diff_safe(::Type{Closure{F,G}}) where {F,G} = is_diff_safe(F) && is_diff_safe(G) | ||||||||||||||||||||||||||||||||
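The singleton-type check behind the `Function` method can be tried in isolation. The following is a minimal sketch using only Base; `is_safe`, `plain`, `makef`, and `captured` are stand-in names for illustration, and the sketch mirrors the three non-`Closure` methods in the diff rather than replacing them.

```julia
# Stand-in for `is_diff_safe`, restricted to the three non-`Closure` methods.
is_safe(_) = false                      # values and anything unrecognized
is_safe(::Type) = true                  # plain types, e.g. constructors
is_safe(::Type{F}) where {F<:Function} = Base.issingletontype(F)

# A function that captures nothing has a singleton type.
plain = x -> 2x

# A closure over `x` stores `x` as a field, so its type is not a singleton.
makef(x) = y -> x + y
captured = makef([1.0])

results = (
    is_safe(typeof(sin)),
    is_safe(typeof(plain)),
    is_safe(Float64),
    is_safe(typeof(captured)),
)
```

Only the last entry is `false`: the captured array is exactly the kind of hidden state a pullback might need a gradient for.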
|
||||||||||||||||||||||||||||||||
@generated function (closure::Closure{F,G})(x, args...) where {F,G} | ||||||||||||||||||||||||||||||||
f = Base.issingletontype(F) ? F.instance : F | ||||||||||||||||||||||||||||||||
g = Base.issingletontype(G) ? G.instance : G | ||||||||||||||||||||||||||||||||
Comment on lines +135 to +136
Comment: I'm not 100% certain about this. Need to think when I've had some sleep.
||||||||||||||||||||||||||||||||
return :($f($g(args...), x)) | ||||||||||||||||||||||||||||||||
end | ||||||||||||||||||||||||||||||||
Comment on lines +134 to +138
Comment: I'm not a fan of If we want to keep
julia> struct Closure{F,G} end
julia> Closure(::F, ::G) where {F,G} = Closure{F,G}()
julia> Closure(::F, ::Type{G}) where {F,G} = Closure{F,Type{G}}()
julia> Closure(::Type{F}, ::G) where {F,G} = Closure{Type{F},G}()
julia> Closure(::Type{F}, ::Type{G}) where {F,G} = Closure{Type{F},Type{G}}()
julia> (::Closure{F,G})(x, args...) where {F,G} = F.instance(G.instance(args...), x)
julia> (::Closure{F,Type{G}})(x, args...) where {F,G} = F.instance(G(args...), x)
julia> (::Closure{Type{F},G})(x, args...) where {F,G} = F(G.instance(args...), x)
julia> (::Closure{Type{F},Type{G}})(x, args...) where {F,G} = F(G(args...), x)
But somehow this version and the one in the PR seem all a bit hacky...
Comment: As mentioned before, I 100% agree with you. But this performance issue is literally the cause of several Slack and Discourse threads of people going "why is Turing so slow for this simple model?", and so IMO we should just get this fixed despite its hackiness, and then make it less hacky as we go + maybe improve ReverseDiff and Zygote.
Comment: Oh and regarding the
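For reference, the dispatch-based variant suggested in this thread can be exercised without any AD package. This is a minimal sketch covering only two of the four cases (a singleton function `F` combined with either a singleton function or a type `G`), with Base functions standing in for `logpdf` and `Normal`; it is an illustration, not the implementation in the PR.

```julia
# Field-free callable: all information lives in the type parameters.
struct Closure{F,G} end

Closure(::F, ::G) where {F,G} = Closure{F,G}()
Closure(::F, ::Type{G}) where {F,G} = Closure{F,Type{G}}()

# `G` is a singleton function type: recover the function via `.instance`.
(::Closure{F,G})(x, args...) where {F,G} = F.instance(G.instance(args...), x)
# `G` is a type (a constructor): call it directly.
(::Closure{F,Type{G}})(x, args...) where {F,G} = F.instance(G(args...), x)

c1 = Closure(+, abs)        # behaves like (x, a) -> abs(a) + x
c2 = Closure(*, Complex)    # behaves like (x, re, im) -> Complex(re, im) * x

scalar = c1(1, -2)
viatype = c2(2, 1, 1)
broadcasted = c1.([1, 2], [-3, -4])
```

Because the struct has no fields, an instance carries no runtime data, which is the property the thread relies on to keep ReverseDiff on its fast broadcasting path.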
||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
|
Original file line number | Diff line number | Diff line change | ||||||||
---|---|---|---|---|---|---|---|---|---|---|
@@ -0,0 +1,99 @@ | ||||||||||
using .LazyArrays: BroadcastArray, BroadcastVector, LazyArray | ||||||||||
|
||||||||||
|
||||||||||
const LazyVectorOfUnivariate{ | ||||||||||
S<:ValueSupport, | ||||||||||
T<:UnivariateDistribution{S}, | ||||||||||
Tdists<:BroadcastVector{T}, | ||||||||||
} = VectorOfUnivariate{S,T,Tdists} | ||||||||||
|
||||||||||
_inner_constructor(::Type{<:BroadcastVector{<:Any,Type{D}}}) where {D} = D | ||||||||||
|
||||||||||
function Distributions._logpdf( | ||||||||||
dist::LazyVectorOfUnivariate, | ||||||||||
x::AbstractVector{<:Real}, | ||||||||||
) | ||||||||||
# TODO: Make use of `sum(Broadcast.instantiate(Broadcast.broadcasted(f, x, args...)))` once | ||||||||||
# we've addressed performance issues in ReverseDiff.jl. | ||||||||||
constructor = _inner_constructor(typeof(dist.v)) | ||||||||||
Comment: I'm sure this will be problematic in some cases and break. It's not guaranteed that
Comment: A simple example:
julia> struct A{X,Y}
           x::X
           y::Y
           A(x::X, y::Y) where {X,Y} = new{X,Y}(x, y)
       end
julia> _constructor(::Type{D}) where {D} = D
_constructor (generic function with 1 method)
julia> x, y = 1, 2.0
(1, 2.0)
julia> a = A(x, y)
A{Int64, Float64}(1, 2.0)
julia> _constructor(typeof(a))(x, y)
ERROR: MethodError: no method matching A{Int64, Float64}(::Int64, ::Float64)
Stacktrace:
 [1] top-level scope
   @ REPL[31]:1
Comment: So I was actually using ConstructionBase locally for this before :) But I removed it because I figured this will only be used for a very simple subset of constructors, so I was uncertain whether it's worth it. But I'll add it back again :)
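One way around the failure in the example above is to strip the type parameters before calling the type, which is the idea behind `ConstructionBase.constructorof`. The sketch below uses `Base.typename(T).wrapper` for this; that field is a Julia internal rather than documented API, and `strip_params` is a hypothetical helper name, so treat this as an illustration only.

```julia
struct A{X,Y}
    x::X
    y::Y
    # Only the parameter-free call `A(x, y)` is defined.
    A(x::X, y::Y) where {X,Y} = new{X,Y}(x, y)
end

# Hypothetical helper: maps A{Int,Float64} to the unparameterized A.
strip_params(::Type{T}) where {T} = Base.typename(T).wrapper

a = A(1, 2.0)

# Calling the fully parameterized type fails, as in the example above.
direct_call_fails = try
    typeof(a)(1, 2.0)
    false
catch err
    err isa MethodError
end

# Calling the unparameterized wrapper goes through the inner constructor.
b = strip_params(typeof(a))(1, 2.0)
```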
||||||||||
return sum(Closure(logpdf, constructor).(x, dist.v.args...)) | ||||||||||
end | ||||||||||
|
||||||||||
function Distributions.logpdf( | ||||||||||
dist::LazyVectorOfUnivariate, | ||||||||||
x::AbstractMatrix{<:Real}, | ||||||||||
) | ||||||||||
size(x, 1) == length(dist) || | ||||||||||
throw(DimensionMismatch("Inconsistent array dimensions.")) | ||||||||||
constructor = _inner_constructor(typeof(dist.v)) | ||||||||||
return vec(sum(Closure(logpdf, constructor).(x, dist.v.args...), dims = 1)) | ||||||||||
end | ||||||||||
|
||||||||||
const LazyMatrixOfUnivariate{ | ||||||||||
S<:ValueSupport, | ||||||||||
T<:UnivariateDistribution{S}, | ||||||||||
Tdists<:BroadcastArray{T,2}, | ||||||||||
} = MatrixOfUnivariate{S,T,Tdists} | ||||||||||
|
||||||||||
function Distributions._logpdf( | ||||||||||
dist::LazyMatrixOfUnivariate, | ||||||||||
x::AbstractMatrix{<:Real}, | ||||||||||
) | ||||||||||
|
||||||||||
    constructor = _inner_constructor(typeof(dist.dists))
    return sum(Closure(logpdf, constructor).(x, dist.dists.args...))
end | ||||||||||
|
||||||||||
lazyarray(f, x...) = BroadcastArray(f, x...) | ||||||||||
export lazyarray | ||||||||||
Comment on lines +45 to +47
Comment: It's not clear to me why this is needed. It doesn't seem much shorter and it makes it less clear that everything is based on LazyArrays.
Comment: Oh no, this was already in DistributionsAD.jl 🤷 Not something I put in here. I was also unaware of this method's existence.
Comment: Should we deprecate it?
Comment: Happy to!
||||||||||
|
||||||||||
# HACK: All of the below probably shouldn't be here. | ||||||||||
function ChainRulesCore.rrule(::Type{BroadcastArray}, f, args...) | ||||||||||
function BroadcastArray_pullback(Δ::ChainRulesCore.Tangent) | ||||||||||
return (ChainRulesCore.NoTangent(), Δ.f, Δ.args...) | ||||||||||
end | ||||||||||
return BroadcastArray(f, args...), BroadcastArray_pullback | ||||||||||
end | ||||||||||
|
||||||||||
ChainRulesCore.ProjectTo(ba::BroadcastArray) = ProjectTo{typeof(ba)}((f=ba.f,)) | ||||||||||
function (p::ChainRulesCore.ProjectTo{BA})(args...) where {BA<:BroadcastArray} | ||||||||||
return ChainRulesCore.Tangent{BA}(f=p.f, args=args) | ||||||||||
end | ||||||||||
Comment on lines +50 to +60
Comment: I'm surprised this is needed. Feels like that's the default for
Comment: Yeah, so we can also just close over the function
||||||||||
|
||||||||||
function ChainRulesCore.rrule( | ||||||||||
Comment: BTW the annoying thing about these kinds of general rules is that they might break (and it has happened to me multiple times) code that would have worked without the rule, had one just let the AD system perform its default differentiation. One can fix these issues though by using
Comment: We could also just define this using Unfortunately there's no way around this because we have to stop Zygote from trying to differentiate through the broadcasted constructor.
||||||||||
config::ChainRulesCore.RuleConfig{>:ChainRulesCore.HasReverseMode}, | ||||||||||
::typeof(logpdf), | ||||||||||
dist::LazyVectorOfUnivariate, | ||||||||||
x::AbstractVector{<:Real} | ||||||||||
) | ||||||||||
# Extract the constructor used in the `BroadcastArray`. | ||||||||||
constructor = DistributionsAD._inner_constructor(typeof(dist.v)) | ||||||||||
Comment: As mentioned above, it's not guaranteed that this actually is a constructor that can be called with the arguments.
Comment: BTW is the qualification really needed?
||||||||||
|
||||||||||
# If it's not safe to ignore the `constructor` in the pullback, then we fall back | ||||||||||
# to the default implementation. | ||||||||||
is_diff_safe(constructor) || return ChainRulesCore.rrule_via_ad(config, (d,x) -> sum(logpdf.(d.v, x)), dist, x) | ||||||||||
Comment: Could we just use
Suggested change
to avoid making any assumptions about how
Comment: Alternatively maybe just return
Suggested change
(https://juliadiff.org/ChainRulesCore.jl/stable/ad_author/opt_out.html)
Comment: But won't this hit the same
I don't completely understand this. So opting out means it will fall back to the underlying AD? I thought it meant "nothing to differentiate here"
||||||||||
|
||||||||||
# Otherwise, we use `Closure`. | ||||||||||
cl = DistributionsAD.Closure(logpdf, constructor) | ||||||||||
|
||||||||||
# Construct pullbacks manually to avoid the constructor of `BroadcastArray`. | ||||||||||
y, dy = ChainRulesCore.rrule_via_ad(config, broadcast, cl, x, dist.v.args...) | ||||||||||
z, dz = ChainRulesCore.rrule_via_ad(config, sum, y) | ||||||||||
|
||||||||||
project_broadcastarray = ChainRulesCore.ProjectTo(dist.v) | ||||||||||
function logpdf_adjoint(Δ...) | ||||||||||
# 1st argument is `sum` -> nothing. | ||||||||||
(_, sum_Δ...) = dz(Δ...) | ||||||||||
Comment: Generally you might have to deal with
Comment: But should the other pullbacks also deal with this?
Comment: Seems like we only need to deal with
||||||||||
# 1st argument is `broadcast` -> nothing. | ||||||||||
# 2nd argument is `cl` -> `nothing`. | ||||||||||
# 3rd argument is `x` -> something. | ||||||||||
# Rest is `dist` arguments -> something | ||||||||||
(_, _, x_Δ, args_Δ...) = dy(sum_Δ...) | ||||||||||
Comment: One thing I'm worried about: what if
Comment: Okay, so that should be addressed with the
Comment: The following example in ChainRules is related: https://github.com/JuliaDiff/ChainRules.jl/blob/9adf759bc63432dc518ccf499d6938fc5a217113/src/rulesets/Base/mapreduce.jl#L76
||||||||||
# Construct the structural tangents. | ||||||||||
ba_tangent = project_broadcastarray(args_Δ...) | ||||||||||
dist_tangent = ChainRulesCore.Tangent{typeof(dist)}(v=ba_tangent) | ||||||||||
|
||||||||||
return (ChainRulesCore.NoTangent(), dist_tangent, x_Δ) | ||||||||||
end | ||||||||||
|
||||||||||
return z, logpdf_adjoint | ||||||||||
end |
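The pullback composition in `logpdf_adjoint` follows the usual chain-rule pattern: run the forward passes in order, then apply the pullbacks in reverse. The sketch below shows the same pattern with hand-written (value, pullback) pairs in plain Julia; `square_all`, `sum_all`, and `sum_of_squares` are hypothetical stand-ins for the two `rrule_via_ad` calls and the rule itself, and no ChainRulesCore machinery (tangent types, projection, `NoTangent`) is modelled.

```julia
# Each helper returns the primal value together with its pullback.
square_all(x) = (x .^ 2, Δ -> 2 .* x .* Δ)
sum_all(y) = (sum(y), Δ -> fill(Δ, size(y)))

function sum_of_squares(x)
    # Forward passes, in order.
    y, dy = square_all(x)
    z, dz = sum_all(y)

    # Pullbacks, in reverse order: first through `sum`, then the broadcast.
    pullback(Δ) = dy(dz(Δ))

    return z, pullback
end

z, pb = sum_of_squares([1.0, 2.0, 3.0])
grad = pb(1.0)
```

Here `grad` is the gradient of the sum of squares, i.e. `2 .* x`, which is what composing the two pullbacks by hand should produce.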
Comment: Make into a weak dep?
Comment: Haha yeah, that should be done with basically everything in this repo. Or rather, most things should be moved to Distributions and added as a weak dep there to fix the type piracy issues.
Comment: Ideally `is_diff_safe` would be a function that ForwardDiff or one of its dependencies owns. That would avoid the need for a package extension or Requires block overriding an internal function like `_dual_purefun`.