From 81bb5f6f2b602bdeb89cc6cf7ef2ff2bb58aee21 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Mon, 3 Feb 2025 15:18:18 +0000 Subject: [PATCH 01/16] Refactor dot_tilde, work in progress --- src/compiler.jl | 109 ++++++++++++++++++++++++++---------------------- 1 file changed, 60 insertions(+), 49 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 8743641af..8da9c43ad 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -172,7 +172,7 @@ Check if the right-hand side `x` of a `~` is a `Distribution` or an array of function check_tilde_rhs(@nospecialize(x)) return throw( ArgumentError( - "the right-hand side of a `~` must be a `Distribution` or an array of `Distribution`s", + "the right-hand side of a `~` must be a `Distribution`, an array of `Distribution`s, or a submodel", ), ) end @@ -184,6 +184,26 @@ function check_tilde_rhs(x::Sampleable{<:Any,AutoPrefix}) where {AutoPrefix} return Sampleable{typeof(model),AutoPrefix}(model) end +""" + check_dot_tilde_rhs(x) + +Check if the right-hand side `x` of a `.~` is a `Distribution` or an array of +univariate `Distributions`, then return `x`. 
+""" +function check_dot_tilde_rhs(@nospecialize(x)) + return throw( + ArgumentError( + "the right-hand side of a `.~` must be a `Distribution` or an array of univariate `Distribution`s", + ), + ) +end +check_dot_tilde_rhs(x::Distribution) = x +check_dot_tilde_rhs(x::AbstractArray{<:UnivariateDistribution}) = x +function check_dot_tilde_rhs(x::Sampleable{<:Any,AutoPrefix}) where {AutoPrefix} + model = check_dot_tilde_rhs(x.model) + return Sampleable{typeof(model),AutoPrefix}(model) +end + """ unwrap_right_vn(right, vn) @@ -356,11 +376,8 @@ function generate_mainbody!(mod, found, expr::Expr, warn) args_dottilde = getargs_dottilde(expr) if args_dottilde !== nothing L, R = args_dottilde - return Base.remove_linenums!( - generate_dot_tilde( - generate_mainbody!(mod, found, L, warn), - generate_mainbody!(mod, found, R, warn), - ), + return generate_mainbody!( + mod, found, Base.remove_linenums!(generate_dot_tilde(L, R)), warn ) end @@ -487,56 +504,50 @@ end Generate the expression that replaces `left .~ right` in the model body. """ function generate_dot_tilde(left, right) - isliteral(left) && return generate_tilde_literal(left, right) - - # Otherwise it is determined by the model or its value, - # if the LHS represents an observation - @gensym vn isassumption value + @gensym dist left_axes num_dist_dims colons left_axes_to_loop lhs_idx lhs_indexed local_dist rhs_idx + tilde_statement = if isliteral(left) + # If the LHS is a literal, we need to first index into it to get another literal + # value, and then ~ on that to get a tilde_observe call. + quote + $lhs_indexed = $left[$lhs_idx...] + $lhs_indexed ~ $local_dist + end + else + # If the LHS is not a literal, we can index into it in the tilde statement. + quote + $left[$lhs_idx...] ~ $local_dist + end + end + # TODO(mhauru) Add informative error messages if dimensions don't match. 
return quote - $vn = $(DynamicPPL.resolve_varnames)( - $(AbstractPPL.drop_escape(varname(left, need_concretize(left)))), $right - ) - $isassumption = $(DynamicPPL.isassumption(left, vn)) - if $(DynamicPPL.isfixed(left, vn)) - $left .= $(DynamicPPL.getfixed_nested)(__context__, $vn) - elseif $isassumption - $(generate_dot_tilde_assume(left, right, vn)) + $dist = DynamicPPL.check_dot_tilde_rhs($right) + $left_axes = axes($left) + # The two that we support for the RHS, it being a Distribution or an array of + # univariate Distributions, need to be treated quite differently. For a Distribution + # the LHS needs to be indexed with colons, and the RHS can remain as-is. For an + # array we need to loop over the whole LHS iterable, and pick the right values from + # the RHS in each loop iteration. + if $dist isa Distributions.Distribution + $num_dist_dims = length(Distributions.size($dist)) + $colons = fill(:, $num_dist_dims) + $left_axes_to_loop = $left_axes[(begin + $num_dist_dims):end] + $local_dist = $dist + for idx in Iterators.product($left_axes_to_loop...) + $lhs_idx = ($colons..., idx...) + $tilde_statement + end else - # If `vn` is not in `argnames`, we need to make sure that the variable is defined. - if !$(DynamicPPL.inargnames)($vn, __model__) - $left .= $(DynamicPPL.getconditioned_nested)(__context__, $vn) + $num_dist_dims = length(size($dist)) + for idx in Iterators.product($left_axes...) + $lhs_idx = idx + $rhs_idx = idx[1:($num_dist_dims)] + $local_dist = $dist[$rhs_idx...] + $tilde_statement end - - $value, __varinfo__ = $(DynamicPPL.dot_tilde_observe!!)( - __context__, - $(DynamicPPL.check_tilde_rhs)($right), - $(maybe_view(left)), - $vn, - __varinfo__, - ) - $value end end end -function generate_dot_tilde_assume(left, right, vn) - # We don't need to use `Setfield.@set` here since - # `.=` is always going to be inplace + needs `left` to - # be something that supports `.=`. 
- @gensym value - return quote - $value, __varinfo__ = $(DynamicPPL.dot_tilde_assume!!)( - __context__, - $(DynamicPPL.unwrap_right_left_vns)( - $(DynamicPPL.check_tilde_rhs)($right), $(maybe_view(left)), $vn - )..., - __varinfo__, - ) - $left .= $value - $value - end -end - # Note that we cannot use `MacroTools.isdef` because # of https://github.com/FluxML/MacroTools.jl/issues/154. """ From 15a55c1bef66cfc605bd0989e3311653af2b71c7 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 4 Feb 2025 17:35:36 +0000 Subject: [PATCH 02/16] Restrict dot_tilde to univariate dists on the RHS --- src/compiler.jl | 61 +++++++++++++------------------------------------ 1 file changed, 16 insertions(+), 45 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 8da9c43ad..4b9bd226d 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -161,7 +161,16 @@ Return `true` if `expr` is a literal, e.g. `1.0` or `[1.0, ]`, and `false` other """ isliteral(e) = false isliteral(::Number) = true -isliteral(e::Expr) = !isempty(e.args) && all(isliteral, e.args) +function isliteral(e::Expr) + # In the special case that the expression is of the form `abc[blahblah]`, we consider it + # to be a literal if `abc` is a literal. This is necessary for cases like + # [1.0, 2.0][1] ~ Normal() + # which are generate when turning `.~` expressions into loops over `~` expressions. + if e.head == :ref + return isliteral(e.args[1]) + end + return !isempty(e.args) && all(isliteral, e.args) +end """ check_tilde_rhs(x) @@ -187,18 +196,14 @@ end """ check_dot_tilde_rhs(x) -Check if the right-hand side `x` of a `.~` is a `Distribution` or an array of -univariate `Distributions`, then return `x`. +Check if the right-hand side `x` of a `.~` is a `UnivariateDistribution`, then return `x`. 
""" function check_dot_tilde_rhs(@nospecialize(x)) return throw( - ArgumentError( - "the right-hand side of a `.~` must be a `Distribution` or an array of univariate `Distribution`s", - ), + ArgumentError("the right-hand side of a `.~` must be a `UnivariateDistribution`") ) end -check_dot_tilde_rhs(x::Distribution) = x -check_dot_tilde_rhs(x::AbstractArray{<:UnivariateDistribution}) = x +check_dot_tilde_rhs(x::UnivariateDistribution) = x function check_dot_tilde_rhs(x::Sampleable{<:Any,AutoPrefix}) where {AutoPrefix} model = check_dot_tilde_rhs(x.model) return Sampleable{typeof(model),AutoPrefix}(model) @@ -504,46 +509,12 @@ end Generate the expression that replaces `left .~ right` in the model body. """ function generate_dot_tilde(left, right) - @gensym dist left_axes num_dist_dims colons left_axes_to_loop lhs_idx lhs_indexed local_dist rhs_idx - tilde_statement = if isliteral(left) - # If the LHS is a literal, we need to first index into it to get another literal - # value, and then ~ on that to get a tilde_observe call. - quote - $lhs_indexed = $left[$lhs_idx...] - $lhs_indexed ~ $local_dist - end - else - # If the LHS is not a literal, we can index into it in the tilde statement. - quote - $left[$lhs_idx...] ~ $local_dist - end - end - # TODO(mhauru) Add informative error messages if dimensions don't match. + @gensym dist left_axes return quote $dist = DynamicPPL.check_dot_tilde_rhs($right) $left_axes = axes($left) - # The two that we support for the RHS, it being a Distribution or an array of - # univariate Distributions, need to be treated quite differently. For a Distribution - # the LHS needs to be indexed with colons, and the RHS can remain as-is. For an - # array we need to loop over the whole LHS iterable, and pick the right values from - # the RHS in each loop iteration. 
- if $dist isa Distributions.Distribution - $num_dist_dims = length(Distributions.size($dist)) - $colons = fill(:, $num_dist_dims) - $left_axes_to_loop = $left_axes[(begin + $num_dist_dims):end] - $local_dist = $dist - for idx in Iterators.product($left_axes_to_loop...) - $lhs_idx = ($colons..., idx...) - $tilde_statement - end - else - $num_dist_dims = length(size($dist)) - for idx in Iterators.product($left_axes...) - $lhs_idx = idx - $rhs_idx = idx[1:($num_dist_dims)] - $local_dist = $dist[$rhs_idx...] - $tilde_statement - end + for idx in Iterators.product($left_axes...) + $left[idx...] ~ $dist end end end From 35c01e376d611acdca9e995fb6ea11e83c4d6092 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Fri, 7 Feb 2025 17:16:08 +0000 Subject: [PATCH 03/16] Remove tests with multivariates or arrays as RHS of .~ --- src/test_utils/models.jl | 121 ++++++++++---------------------- test/context_implementations.jl | 53 ++++---------- 2 files changed, 54 insertions(+), 120 deletions(-) diff --git a/src/test_utils/models.jl b/src/test_utils/models.jl index c506e1ba3..e29614982 100644 --- a/src/test_utils/models.jl +++ b/src/test_utils/models.jl @@ -186,31 +186,29 @@ function _demo_logprior_true_with_logabsdet_jacobian(model, s, m) return (s=s_unconstrained, m=m), logprior_true(model, s, m) - Δlogp end -@model function demo_dot_assume_dot_observe( - x=[1.5, 2.0], ::Type{TV}=Vector{Float64} -) where {TV} +@model function demo_dot_assume_observe(x=[1.5, 2.0], ::Type{TV}=Vector{Float64}) where {TV} # `dot_assume` and `observe` s = TV(undef, length(x)) m = TV(undef, length(x)) s .~ InverseGamma(2, 3) - m .~ Normal.(0, sqrt.(s)) + m ~ product_distribution(Normal.(0, sqrt.(s))) x ~ MvNormal(m, Diagonal(s)) return (; s=s, m=m, x=x, logp=getlogp(__varinfo__)) end -function logprior_true(model::Model{typeof(demo_dot_assume_dot_observe)}, s, m) +function logprior_true(model::Model{typeof(demo_dot_assume_observe)}, s, m) return loglikelihood(InverseGamma(2, 3), s) + 
sum(logpdf.(Normal.(0, sqrt.(s)), m)) end -function loglikelihood_true(model::Model{typeof(demo_dot_assume_dot_observe)}, s, m) +function loglikelihood_true(model::Model{typeof(demo_dot_assume_observe)}, s, m) return loglikelihood(MvNormal(m, Diagonal(s)), model.args.x) end function logprior_true_with_logabsdet_jacobian( - model::Model{typeof(demo_dot_assume_dot_observe)}, s, m + model::Model{typeof(demo_dot_assume_observe)}, s, m ) return _demo_logprior_true_with_logabsdet_jacobian(model, s, m) end -function varnames(model::Model{typeof(demo_dot_assume_dot_observe)}) - return [@varname(s[1]), @varname(s[2]), @varname(m[1]), @varname(m[2])] +function varnames(model::Model{typeof(demo_dot_assume_observe)}) + return [@varname(s[1]), @varname(s[2]), @varname(m)] end @model function demo_assume_index_observe( @@ -276,7 +274,7 @@ end s = TV(undef, length(x)) s .~ InverseGamma(2, 3) m = TV(undef, length(x)) - m .~ Normal.(0, sqrt.(s)) + m ~ product_distribution(Normal.(0, sqrt.(s))) for i in eachindex(x) x[i] ~ Normal(m[i], sqrt(s[i])) end @@ -295,7 +293,7 @@ function logprior_true_with_logabsdet_jacobian( return _demo_logprior_true_with_logabsdet_jacobian(model, s, m) end function varnames(model::Model{typeof(demo_dot_assume_observe_index)}) - return [@varname(s[1]), @varname(s[2]), @varname(m[1]), @varname(m[2])] + return [@varname(s[1]), @varname(s[2]), @varname(m)] end # Using vector of `length` 1 here so the posterior of `m` is the same @@ -355,7 +353,7 @@ end s = TV(undef, 2) m = TV(undef, 2) s .~ InverseGamma(2, 3) - m .~ Normal.(0, sqrt.(s)) + m ~ product_distribution(Normal.(0, sqrt.(s))) 1.5 ~ Normal(m[1], sqrt(s[1])) 2.0 ~ Normal(m[2], sqrt(s[2])) @@ -376,7 +374,7 @@ function logprior_true_with_logabsdet_jacobian( return _demo_logprior_true_with_logabsdet_jacobian(model, s, m) end function varnames(model::Model{typeof(demo_dot_assume_observe_index_literal)}) - return [@varname(s[1]), @varname(s[2]), @varname(m[1]), @varname(m[2])] + return [@varname(s[1]), 
@varname(s[2]), @varname(m)] end @model function demo_assume_observe_literal() @@ -431,7 +429,7 @@ end s = TV(undef, 2) s .~ InverseGamma(2, 3) m = TV(undef, 2) - m .~ Normal.(0, sqrt.(s)) + m ~ product_distribution(Normal.(0, sqrt.(s))) return s, m end @@ -460,7 +458,7 @@ function logprior_true_with_logabsdet_jacobian( return _demo_logprior_true_with_logabsdet_jacobian(model, s, m) end function varnames(model::Model{typeof(demo_assume_submodel_observe_index_literal)}) - return [@varname(s[1]), @varname(s[2]), @varname(m[1]), @varname(m[2])] + return [@varname(s[1]), @varname(s[2]), @varname(m)] end @model function _likelihood_multivariate_observe(s, m, x) @@ -473,7 +471,7 @@ end s = TV(undef, length(x)) s .~ InverseGamma(2, 3) m = TV(undef, length(x)) - m .~ Normal.(0, sqrt.(s)) + m ~ product_distribution(Normal.(0, sqrt.(s))) # Submodel likelihood # With to_submodel, we have to have a left-hand side variable to @@ -494,76 +492,39 @@ function logprior_true_with_logabsdet_jacobian( return _demo_logprior_true_with_logabsdet_jacobian(model, s, m) end function varnames(model::Model{typeof(demo_dot_assume_observe_submodel)}) - return [@varname(s[1]), @varname(s[2]), @varname(m[1]), @varname(m[2])] + return [@varname(s[1]), @varname(s[2]), @varname(m)] end -@model function demo_dot_assume_dot_observe_matrix( +@model function demo_dot_assume_observe_matrix_index( x=transpose([1.5 2.0;]), ::Type{TV}=Vector{Float64} ) where {TV} s = TV(undef, length(x)) s .~ InverseGamma(2, 3) m = TV(undef, length(x)) - m .~ Normal.(0, sqrt.(s)) + m ~ product_distribution(Normal.(0, sqrt.(s))) - # Dotted observe for `Matrix`. 
- x .~ MvNormal(m, Diagonal(s)) + x[:, 1] ~ MvNormal(m, Diagonal(s)) return (; s=s, m=m, x=x, logp=getlogp(__varinfo__)) end -function logprior_true(model::Model{typeof(demo_dot_assume_dot_observe_matrix)}, s, m) +function logprior_true(model::Model{typeof(demo_dot_assume_observe_matrix_index)}, s, m) return loglikelihood(InverseGamma(2, 3), s) + sum(logpdf.(Normal.(0, sqrt.(s)), m)) end -function loglikelihood_true(model::Model{typeof(demo_dot_assume_dot_observe_matrix)}, s, m) - return sum(logpdf.(Normal.(m, sqrt.(s)), model.args.x)) -end -function logprior_true_with_logabsdet_jacobian( - model::Model{typeof(demo_dot_assume_dot_observe_matrix)}, s, m -) - return _demo_logprior_true_with_logabsdet_jacobian(model, s, m) -end -function varnames(model::Model{typeof(demo_dot_assume_dot_observe_matrix)}) - return [@varname(s[1]), @varname(s[2]), @varname(m[1]), @varname(m[2])] -end - -@model function demo_dot_assume_matrix_dot_observe_matrix( - x=transpose([1.5 2.0;]), ::Type{TV}=Array{Float64} -) where {TV} - n = length(x) - d = length(x) ÷ 2 - s = TV(undef, d, 2) - s .~ product_distribution([InverseGamma(2, 3) for _ in 1:d]) - s_vec = vec(s) - m ~ MvNormal(zeros(n), Diagonal(s_vec)) - - # Dotted observe for `Matrix`. 
- x .~ MvNormal(m, Diagonal(s_vec)) - - return (; s=s, m=m, x=x, logp=getlogp(__varinfo__)) -end -function logprior_true( - model::Model{typeof(demo_dot_assume_matrix_dot_observe_matrix)}, s, m -) - n = length(model.args.x) - s_vec = vec(s) - return loglikelihood(InverseGamma(2, 3), s_vec) + - logpdf(MvNormal(zeros(n), Diagonal(s_vec)), m) -end function loglikelihood_true( - model::Model{typeof(demo_dot_assume_matrix_dot_observe_matrix)}, s, m + model::Model{typeof(demo_dot_assume_observe_matrix_index)}, s, m ) - return loglikelihood(MvNormal(m, Diagonal(vec(s))), model.args.x) + return sum(logpdf.(Normal.(m, sqrt.(s)), model.args.x)) end function logprior_true_with_logabsdet_jacobian( - model::Model{typeof(demo_dot_assume_matrix_dot_observe_matrix)}, s, m + model::Model{typeof(demo_dot_assume_observe_matrix_index)}, s, m ) return _demo_logprior_true_with_logabsdet_jacobian(model, s, m) end -function varnames(model::Model{typeof(demo_dot_assume_matrix_dot_observe_matrix)}) - s = zeros(1, 2) # used for varname concretization only - return [@varname(s[:, 1], true), @varname(s[:, 2], true), @varname(m)] +function varnames(model::Model{typeof(demo_dot_assume_observe_matrix_index)}) + return [@varname(s[1]), @varname(s[2]), @varname(m)] end -@model function demo_assume_matrix_dot_observe_matrix( +@model function demo_assume_matrix_observe_matrix_index( x=transpose([1.5 2.0;]), ::Type{TV}=Array{Float64} ) where {TV} n = length(x) @@ -572,33 +533,32 @@ end s_vec = vec(s) m ~ MvNormal(zeros(n), Diagonal(s_vec)) - # Dotted observe for `Matrix`. 
- x .~ MvNormal(m, Diagonal(s_vec)) + x[:, 1] ~ MvNormal(m, Diagonal(s_vec)) return (; s=s, m=m, x=x, logp=getlogp(__varinfo__)) end -function logprior_true(model::Model{typeof(demo_assume_matrix_dot_observe_matrix)}, s, m) +function logprior_true(model::Model{typeof(demo_assume_matrix_observe_matrix_index)}, s, m) n = length(model.args.x) s_vec = vec(s) return loglikelihood(InverseGamma(2, 3), s_vec) + logpdf(MvNormal(zeros(n), Diagonal(s_vec)), m) end function loglikelihood_true( - model::Model{typeof(demo_assume_matrix_dot_observe_matrix)}, s, m + model::Model{typeof(demo_assume_matrix_observe_matrix_index)}, s, m ) return loglikelihood(MvNormal(m, Diagonal(vec(s))), model.args.x) end function logprior_true_with_logabsdet_jacobian( - model::Model{typeof(demo_assume_matrix_dot_observe_matrix)}, s, m + model::Model{typeof(demo_assume_matrix_observe_matrix_index)}, s, m ) return _demo_logprior_true_with_logabsdet_jacobian(model, s, m) end -function varnames(model::Model{typeof(demo_assume_matrix_dot_observe_matrix)}) +function varnames(model::Model{typeof(demo_assume_matrix_observe_matrix_index)}) return [@varname(s), @varname(m)] end const DemoModels = Union{ - Model{typeof(demo_dot_assume_dot_observe)}, + Model{typeof(demo_dot_assume_observe)}, Model{typeof(demo_assume_index_observe)}, Model{typeof(demo_assume_multivariate_observe)}, Model{typeof(demo_dot_assume_observe_index)}, @@ -609,9 +569,8 @@ const DemoModels = Union{ Model{typeof(demo_dot_assume_observe_index_literal)}, Model{typeof(demo_assume_submodel_observe_index_literal)}, Model{typeof(demo_dot_assume_observe_submodel)}, - Model{typeof(demo_dot_assume_dot_observe_matrix)}, - Model{typeof(demo_dot_assume_matrix_dot_observe_matrix)}, - Model{typeof(demo_assume_matrix_dot_observe_matrix)}, + Model{typeof(demo_dot_assume_observe_matrix_index)}, + Model{typeof(demo_assume_matrix_observe_matrix_index)}, } const UnivariateAssumeDemoModels = Union{ @@ -637,7 +596,7 @@ function 
rand_prior_true(rng::Random.AbstractRNG, model::UnivariateAssumeDemoMod end const MultivariateAssumeDemoModels = Union{ - Model{typeof(demo_dot_assume_dot_observe)}, + Model{typeof(demo_dot_assume_observe)}, Model{typeof(demo_assume_index_observe)}, Model{typeof(demo_assume_multivariate_observe)}, Model{typeof(demo_dot_assume_observe_index)}, @@ -645,8 +604,7 @@ const MultivariateAssumeDemoModels = Union{ Model{typeof(demo_dot_assume_observe_index_literal)}, Model{typeof(demo_assume_submodel_observe_index_literal)}, Model{typeof(demo_dot_assume_observe_submodel)}, - Model{typeof(demo_dot_assume_dot_observe_matrix)}, - Model{typeof(demo_dot_assume_matrix_dot_observe_matrix)}, + Model{typeof(demo_dot_assume_observe_matrix_index)}, } function posterior_mean(model::MultivariateAssumeDemoModels) # Get some containers to fill. @@ -699,7 +657,7 @@ function rand_prior_true(rng::Random.AbstractRNG, model::MultivariateAssumeDemoM end const MatrixvariateAssumeDemoModels = Union{ - Model{typeof(demo_assume_matrix_dot_observe_matrix)} + Model{typeof(demo_assume_matrix_observe_matrix_index)} } function posterior_mean(model::MatrixvariateAssumeDemoModels) # Get some containers to fill. 
@@ -786,7 +744,7 @@ And for the multivariate one (the latter one): """ const DEMO_MODELS = ( - demo_dot_assume_dot_observe(), + demo_dot_assume_observe(), demo_assume_index_observe(), demo_assume_multivariate_observe(), demo_dot_assume_observe_index(), @@ -797,7 +755,6 @@ const DEMO_MODELS = ( demo_assume_observe_literal(), demo_assume_submodel_observe_index_literal(), demo_dot_assume_observe_submodel(), - demo_dot_assume_dot_observe_matrix(), - demo_dot_assume_matrix_dot_observe_matrix(), - demo_assume_matrix_dot_observe_matrix(), + demo_dot_assume_observe_matrix_index(), + demo_assume_matrix_observe_matrix_index(), ) diff --git a/test/context_implementations.jl b/test/context_implementations.jl index 8a795320d..0ec88c07c 100644 --- a/test/context_implementations.jl +++ b/test/context_implementations.jl @@ -4,7 +4,7 @@ @model function test(x) μ ~ MvNormal(zeros(2), 4 * I) z = Vector{Int}(undef, length(x)) - z .~ Categorical.(fill([0.5, 0.5], length(x))) + z ~ product_distribution(Categorical.(fill([0.5, 0.5], length(x)))) for i in 1:length(x) x[i] ~ Normal(μ[z[i]], 0.1) end @@ -13,59 +13,36 @@ test([1, 1, -1])(VarInfo(), SampleFromPrior(), LikelihoodContext()) end - # https://github.com/TuringLang/DynamicPPL.jl/issues/28#issuecomment-829223577 - @testset "dot tilde: arrays of distributions" begin + @testset "dot tilde with varying sizes" begin @testset "assume" begin @model function test(x, size) y = Array{Float64,length(size)}(undef, size...) 
- y .~ Normal.(x) + y .~ Normal(x) return y, getlogp(__varinfo__) end for ysize in ((2,), (2, 3), (2, 3, 4)) - for x in ( - # scalar - randn(), - # drop trailing dimensions - ntuple(i -> randn(ysize[1:i]), length(ysize))..., - # singleton dimensions - ntuple( - i -> randn(ysize[1:(i - 1)]..., 1, ysize[(i + 1):end]...), - length(ysize), - )..., - ) - model = test(x, ysize) - y, lp = model() - @test lp ≈ sum(logpdf.(Normal.(x), y)) + x = randn() + model = test(x, ysize) + y, lp = model() + @test lp ≈ sum(logpdf.(Normal.(x), y)) - ys = [first(model()) for _ in 1:10_000] - @test norm(mean(ys) .- x, Inf) < 0.1 - @test norm(std(ys) .- 1, Inf) < 0.1 - end + ys = [first(model()) for _ in 1:10_000] + @test norm(mean(ys) .- x, Inf) < 0.1 + @test norm(std(ys) .- 1, Inf) < 0.1 end end @testset "observe" begin @model function test(x, y) - return y .~ Normal.(x) + return y .~ Normal(x) end for ysize in ((2,), (2, 3), (2, 3, 4)) - for x in ( - # scalar - randn(), - # drop trailing dimensions - ntuple(i -> randn(ysize[1:i]), length(ysize))..., - # singleton dimensions - ntuple( - i -> randn(ysize[1:(i - 1)]..., 1, ysize[(i + 1):end]...), - length(ysize), - )..., - ) - y = randn(ysize) - z = logjoint(test(x, y), VarInfo()) - @test z ≈ sum(logpdf.(Normal.(x), y)) - end + x = randn() + y = randn(ysize) + z = logjoint(test(x, y), VarInfo()) + @test z ≈ sum(logpdf.(Normal.(x), y)) end end end From 332131f4e82adc9af47988044a76cc27b152b359 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Fri, 7 Feb 2025 17:32:45 +0000 Subject: [PATCH 04/16] emove dot_tilde pipeline --- Project.toml | 4 - docs/src/api.md | 2 - ext/DynamicPPLZygoteRulesExt.jl | 25 --- src/DynamicPPL.jl | 4 - src/context_implementations.jl | 381 -------------------------------- src/debug_utils.jl | 145 ------------ src/extract_priors.jl | 5 - src/pointwise_logdensities.jl | 86 +------ src/simple_varinfo.jl | 51 ----- src/test_utils/contexts.jl | 12 - src/transforming.jl | 61 ----- src/values_as_in_model.jl | 23 -- 
test/compat/ad.jl | 28 --- 13 files changed, 1 insertion(+), 826 deletions(-) delete mode 100644 ext/DynamicPPLZygoteRulesExt.jl diff --git a/Project.toml b/Project.toml index 38382f98f..bb67cfb07 100644 --- a/Project.toml +++ b/Project.toml @@ -26,7 +26,6 @@ OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Requires = "ae029012-a4dd-5104-9daa-d747884805df" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" -ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" [weakdeps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" @@ -35,7 +34,6 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" -ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" [extensions] DynamicPPLChainRulesCoreExt = ["ChainRulesCore"] @@ -44,7 +42,6 @@ DynamicPPLForwardDiffExt = ["ForwardDiff"] DynamicPPLJETExt = ["JET"] DynamicPPLMCMCChainsExt = ["MCMCChains"] DynamicPPLMooncakeExt = ["Mooncake"] -DynamicPPLZygoteRulesExt = ["ZygoteRules"] [compat] ADTypes = "1" @@ -74,5 +71,4 @@ OrderedCollections = "1" Random = "1.6" Requires = "1" Test = "1.6" -ZygoteRules = "0.2" julia = "1.10" diff --git a/docs/src/api.md b/docs/src/api.md index 6c58264fe..f463c50ef 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -447,10 +447,8 @@ DynamicPPL.Experimental.is_suitable_varinfo ```@docs tilde_assume -dot_tilde_assume ``` ```@docs tilde_observe -dot_tilde_observe ``` diff --git a/ext/DynamicPPLZygoteRulesExt.jl b/ext/DynamicPPLZygoteRulesExt.jl deleted file mode 100644 index 78831fdc4..000000000 --- a/ext/DynamicPPLZygoteRulesExt.jl +++ /dev/null @@ -1,25 +0,0 @@ -module DynamicPPLZygoteRulesExt - -if isdefined(Base, :get_extension) - using DynamicPPL: DynamicPPL, Distributions - using ZygoteRules: ZygoteRules -else - using ..DynamicPPL: DynamicPPL, Distributions - using ..ZygoteRules: ZygoteRules -end - 
-# https://github.com/TuringLang/Turing.jl/issues/1595 -ZygoteRules.@adjoint function DynamicPPL.dot_observe( - spl::Union{DynamicPPL.SampleFromPrior,DynamicPPL.SampleFromUniform}, - dists::AbstractArray{<:Distributions.Distribution}, - value::AbstractArray, - vi, -) - function dot_observe_fallback(spl, dists, value, vi) - DynamicPPL.increment_num_produce!(vi) - return sum(map(Distributions.loglikelihood, dists, value)), vi - end - return ZygoteRules.pullback(dot_observe_fallback, __context__, spl, dists, value, vi) -end - -end # module diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 55e1f7e88..0559da3ef 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -101,13 +101,9 @@ export AbstractVarInfo, PrefixContext, ConditionContext, assume, - dot_assume, observe, - dot_observe, tilde_assume, tilde_observe, - dot_tilde_assume, - dot_tilde_observe, # Pseudo distributions NamedDist, NoDist, diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 462012676..662f4bf48 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -258,384 +258,3 @@ function observe(right::Distribution, left, vi) increment_num_produce!(vi) return Distributions.loglikelihood(right, left), vi end - -# .~ functions - -# assume -""" - dot_tilde_assume(context::SamplingContext, right, left, vn, vi) - -Handle broadcasted assumed variables, e.g., `x .~ MvNormal()` (where `x` does not occur in the -model inputs), accumulate the log probability, and return the sampled value for a context -associated with a sampler. - -Falls back to -```julia -dot_tilde_assume(context.rng, context.context, context.sampler, right, left, vn, vi) -``` -""" -function dot_tilde_assume(context::SamplingContext, right, left, vn, vi) - return dot_tilde_assume( - context.rng, context.context, context.sampler, right, left, vn, vi - ) -end - -# `DefaultContext` -function dot_tilde_assume(context::AbstractContext, args...) 
- return dot_tilde_assume(NodeTrait(dot_tilde_assume, context), context, args...) -end -function dot_tilde_assume(rng::Random.AbstractRNG, context::AbstractContext, args...) - return dot_tilde_assume(NodeTrait(dot_tilde_assume, context), rng, context, args...) -end - -function dot_tilde_assume(::IsLeaf, ::AbstractContext, right, left, vns, vi) - return dot_assume(right, left, vns, vi) -end -function dot_tilde_assume(::IsLeaf, rng, ::AbstractContext, sampler, right, left, vns, vi) - return dot_assume(rng, sampler, right, vns, left, vi) -end - -function dot_tilde_assume(::IsParent, context::AbstractContext, args...) - return dot_tilde_assume(childcontext(context), args...) -end -function dot_tilde_assume(::IsParent, rng, context::AbstractContext, args...) - return dot_tilde_assume(rng, childcontext(context), args...) -end - -function dot_tilde_assume( - rng::Random.AbstractRNG, ::DefaultContext, sampler, right, left, vns, vi -) - return dot_assume(rng, sampler, right, vns, left, vi) -end - -# `LikelihoodContext` -function dot_tilde_assume(context::LikelihoodContext, right, left, vn, vi) - return dot_assume(nodist(right), left, vn, vi) -end -function dot_tilde_assume( - rng::Random.AbstractRNG, context::LikelihoodContext, sampler, right, left, vn, vi -) - return dot_assume(rng, sampler, nodist(right), vn, left, vi) -end - -# `PrefixContext` -function dot_tilde_assume(context::PrefixContext, right, left, vn, vi) - return dot_tilde_assume(context.context, right, left, prefix.(Ref(context), vn), vi) -end - -function dot_tilde_assume( - rng::Random.AbstractRNG, context::PrefixContext, sampler, right, left, vn, vi -) - return dot_tilde_assume( - rng, context.context, sampler, right, left, prefix.(Ref(context), vn), vi - ) -end - -""" - dot_tilde_assume!!(context, right, left, vn, vi) - -Handle broadcasted assumed variables, e.g., `x .~ MvNormal()` (where `x` does not occur in the -model inputs), accumulate the log probability, and return the sampled value and updated `vi`. 
- -Falls back to `dot_tilde_assume(context, right, left, vn, vi)`. -""" -function dot_tilde_assume!!(context, right, left, vn, vi) - is_rhs_model(right) && throw( - ArgumentError( - "`.~` with a model on the right-hand side is not supported; please use `~`" - ), - ) - value, logp, vi = dot_tilde_assume(context, right, left, vn, vi) - return value, acclogp_assume!!(context, vi, logp) -end - -# `dot_assume` -function dot_assume( - dist::MultivariateDistribution, - var::AbstractMatrix, - vns::AbstractVector{<:VarName}, - vi::AbstractVarInfo, -) - @assert length(dist) == size(var, 1) "dimensionality of `var` ($(size(var, 1))) is incompatible with dimensionality of `dist` $(length(dist))" - # NOTE: We cannot work with `var` here because we might have a model of the form - # - # m = Vector{Float64}(undef, n) - # m .~ Normal() - # - # in which case `var` will have `undef` elements, even if `m` is present in `vi`. - r = vi[vns, dist] - lp = sum(zip(vns, eachcol(r))) do (vn, ri) - return Bijectors.logpdf_with_trans(dist, ri, istrans(vi, vn)) - end - return r, lp, vi -end - -function dot_assume( - rng, - spl::Union{SampleFromPrior,SampleFromUniform}, - dist::MultivariateDistribution, - vns::AbstractVector{<:VarName}, - var::AbstractMatrix, - vi::AbstractVarInfo, -) - @assert length(dist) == size(var, 1) - r = get_and_set_val!(rng, vi, vns, dist, spl) - lp = sum(Bijectors.logpdf_with_trans(dist, r, istrans(vi, vns[1]))) - return r, lp, vi -end - -function dot_assume( - dist::Distribution, var::AbstractArray, vns::AbstractArray{<:VarName}, vi -) - r = getindex.((vi,), vns, (dist,)) - lp = sum(Bijectors.logpdf_with_trans.((dist,), r, istrans.((vi,), vns))) - return r, lp, vi -end - -function dot_assume( - dists::AbstractArray{<:Distribution}, - var::AbstractArray, - vns::AbstractArray{<:VarName}, - vi, -) - r = getindex.((vi,), vns, dists) - lp = sum(Bijectors.logpdf_with_trans.(dists, r, istrans.((vi,), vns))) - return r, lp, vi -end - -function dot_assume( - rng, - 
spl::Union{SampleFromPrior,SampleFromUniform}, - dists::Union{Distribution,AbstractArray{<:Distribution}}, - vns::AbstractArray{<:VarName}, - var::AbstractArray, - vi::AbstractVarInfo, -) - r = get_and_set_val!(rng, vi, vns, dists, spl) - # Make sure `r` is not a matrix for multivariate distributions - lp = sum(Bijectors.logpdf_with_trans.(dists, r, istrans.((vi,), vns))) - return r, lp, vi -end -function dot_assume(rng, spl::Sampler, ::Any, ::AbstractArray{<:VarName}, ::Any, ::Any) - return error( - "[DynamicPPL] $(alg_str(spl)) doesn't support vectorizing assume statement" - ) -end - -# HACK: These methods are only used in the `get_and_set_val!` methods below. -# FIXME: Remove these. -function _link_broadcast_new(vi, vn, dist, r) - b = to_linked_internal_transform(vi, vn, dist) - return b(r) -end - -function _maybe_invlink_broadcast(vi, vn, dist) - xvec = getindex_internal(vi, vn) - b = from_maybe_linked_internal_transform(vi, vn, dist) - return b(xvec) -end - -function get_and_set_val!( - rng, - vi::VarInfoOrThreadSafeVarInfo, - vns::AbstractVector{<:VarName}, - dist::MultivariateDistribution, - spl::Union{SampleFromPrior,SampleFromUniform}, -) - n = length(vns) - if haskey(vi, vns[1]) - # Always overwrite the parameters with new ones for `SampleFromUniform`. - if spl isa SampleFromUniform || is_flagged(vi, vns[1], "del") - # TODO(mhauru) Is it important to unset the flag here? The `true` allows us - # to ignore the fact that for VarNamedVector this does nothing, but I'm unsure if - # that's okay. 
- unset_flag!(vi, vns[1], "del", true) - r = init(rng, dist, spl, n) - for i in 1:n - vn = vns[i] - f_link_maybe = to_maybe_linked_internal_transform(vi, vn, dist) - setindex!!(vi, f_link_maybe(r[:, i]), vn) - setorder!(vi, vn, get_num_produce(vi)) - end - else - r = vi[vns, dist] - end - else - r = init(rng, dist, spl, n) - for i in 1:n - vn = vns[i] - if istrans(vi) - ri_linked = _link_broadcast_new(vi, vn, dist, r[:, i]) - push!!(vi, vn, ri_linked, dist, spl) - # `push!!` sets the trans-flag to `false` by default. - settrans!!(vi, true, vn) - else - push!!(vi, vn, r[:, i], dist, spl) - end - end - end - return r -end - -function get_and_set_val!( - rng, - vi::VarInfoOrThreadSafeVarInfo, - vns::AbstractArray{<:VarName}, - dists::Union{Distribution,AbstractArray{<:Distribution}}, - spl::Union{SampleFromPrior,SampleFromUniform}, -) - if haskey(vi, vns[1]) - # Always overwrite the parameters with new ones for `SampleFromUniform`. - if spl isa SampleFromUniform || is_flagged(vi, vns[1], "del") - # TODO(mhauru) Is it important to unset the flag here? The `true` allows us - # to ignore the fact that for VarNamedVector this does nothing, but I'm unsure if - # that's okay. - unset_flag!(vi, vns[1], "del", true) - f = (vn, dist) -> init(rng, dist, spl) - r = f.(vns, dists) - for i in eachindex(vns) - vn = vns[i] - dist = dists isa AbstractArray ? dists[i] : dists - f_link_maybe = to_maybe_linked_internal_transform(vi, vn, dist) - setindex!!(vi, f_link_maybe(r[i]), vn) - setorder!(vi, vn, get_num_produce(vi)) - end - else - rs = _maybe_invlink_broadcast.((vi,), vns, dists) - r = reshape(rs, size(vns)) - end - else - f = (vn, dist) -> init(rng, dist, spl) - r = f.(vns, dists) - # TODO: This will inefficient since it will allocate an entire vector. - # We could either: - # 1. Figure out the broadcast size and use a `foreach`. - # 2. Define an anonymous function which returns `nothing`, which - # we then broadcast. This will allocate a vector of `nothing` though. 
- if istrans(vi) - push!!.((vi,), vns, _link_broadcast_new.((vi,), vns, dists, r), dists, (spl,)) - # NOTE: Need to add the correction. - # FIXME: This is not great. - acclogp!!(vi, sum(logabsdetjac.(link_transform.(dists), r))) - # `push!!` sets the trans-flag to `false` by default. - settrans!!.((vi,), true, vns) - else - push!!.((vi,), vns, r, dists, (spl,)) - end - end - return r -end - -function set_val!( - vi::VarInfoOrThreadSafeVarInfo, - vns::AbstractVector{<:VarName}, - dist::MultivariateDistribution, - val::AbstractMatrix, -) - @assert size(val, 2) == length(vns) - foreach(enumerate(vns)) do (i, vn) - setindex!!(vi, val[:, i], vn) - end - return val -end -function set_val!( - vi::VarInfoOrThreadSafeVarInfo, - vns::AbstractArray{<:VarName}, - dists::Union{Distribution,AbstractArray{<:Distribution}}, - val::AbstractArray, -) - @assert size(val) == size(vns) - foreach(CartesianIndices(val)) do ind - setindex!!(vi, tovec(val[ind]), vns[ind]) - end - return val -end - -# observe -""" - dot_tilde_observe(context::SamplingContext, right, left, vi) - -Handle broadcasted observed constants, e.g., `[1.0] .~ MvNormal()`, accumulate the log -probability, and return the observed value for a context associated with a sampler. - -Falls back to `dot_tilde_observe(context.context, context.sampler, right, left, vi)`. -""" -function dot_tilde_observe(context::SamplingContext, right, left, vi) - return dot_tilde_observe(context.context, context.sampler, right, left, vi) -end - -# Leaf contexts -function dot_tilde_observe(context::AbstractContext, args...) - return dot_tilde_observe(NodeTrait(tilde_observe, context), context, args...) -end -dot_tilde_observe(::IsLeaf, ::AbstractContext, args...) = dot_observe(args...) -function dot_tilde_observe(::IsParent, context::AbstractContext, args...) - return dot_tilde_observe(childcontext(context), args...) 
-end - -dot_tilde_observe(::PriorContext, right, left, vi) = 0, vi -dot_tilde_observe(::PriorContext, sampler, right, left, vi) = 0, vi - -# `MiniBatchContext` -function dot_tilde_observe(context::MiniBatchContext, right, left, vi) - logp, vi = dot_tilde_observe(context.context, right, left, vi) - return context.loglike_scalar * logp, vi -end - -# `PrefixContext` -function dot_tilde_observe(context::PrefixContext, right, left, vi) - return dot_tilde_observe(context.context, right, left, vi) -end - -""" - dot_tilde_observe!!(context, right, left, vname, vi) - -Handle broadcasted observed values, e.g., `x .~ MvNormal()` (where `x` does occur in the model inputs), -accumulate the log probability, and return the observed value and updated `vi`. - -Falls back to `dot_tilde_observe!!(context, right, left, vi)` ignoring the information about variable -name and indices; if needed, these can be accessed through this function, though. -""" -function dot_tilde_observe!!(context, right, left, vn, vi) - is_rhs_model(right) && throw( - ArgumentError( - "`~` with a model on the right-hand side of an observe statement is not supported", - ), - ) - return dot_tilde_observe!!(context, right, left, vi) -end - -""" - dot_tilde_observe!!(context, right, left, vi) - -Handle broadcasted observed constants, e.g., `[1.0] .~ MvNormal()`, accumulate the log -probability, and return the observed value and updated `vi`. - -Falls back to `dot_tilde_observe(context, right, left, vi)`. -""" -function dot_tilde_observe!!(context, right, left, vi) - is_rhs_model(right) && throw( - ArgumentError( - "`~` with a model on the right-hand side of an observe statement is not supported", - ), - ) - logp, vi = dot_tilde_observe(context, right, left, vi) - return left, acclogp_observe!!(context, vi, logp) -end - -# Falls back to non-sampler definition. 
-function dot_observe(::AbstractSampler, dist, value, vi) - return dot_observe(dist, value, vi) -end -function dot_observe(dist::MultivariateDistribution, value::AbstractMatrix, vi) - increment_num_produce!(vi) - return Distributions.loglikelihood(dist, value), vi -end -function dot_observe(dists::Distribution, value::AbstractArray, vi) - increment_num_produce!(vi) - return Distributions.loglikelihood(dists, value), vi -end -function dot_observe(dists::AbstractArray{<:Distribution}, value::AbstractArray, vi) - increment_num_produce!(vi) - return sum(Distributions.loglikelihood.(dists, value)), vi -end diff --git a/src/debug_utils.jl b/src/debug_utils.jl index 43b5054d5..328fe6983 100644 --- a/src/debug_utils.jl +++ b/src/debug_utils.jl @@ -113,52 +113,6 @@ function Base.show(io::IO, stmt::ObserveStmt) return print(io, ")") end -Base.@kwdef struct DotAssumeStmt <: Stmt - varname - left - right - value - logp - varinfo = nothing -end - -function Base.show(io::IO, stmt::DotAssumeStmt) - io = add_io_context(io) - print(io, " assume: ") - show_varname(io, stmt.varname) - print(io, " = ") - print(io, stmt.left) - print(io, " .~ ") - show_right(io, stmt.right) - print(io, " ") - print(io, RESULT_SYMBOL) - print(io, " ") - print(io, stmt.value) - print(io, " (logprob = ") - print(io, stmt.logp) - return print(io, ")") -end - -Base.@kwdef struct DotObserveStmt <: Stmt - left - right - logp - varinfo = nothing -end - -function Base.show(io::IO, stmt::DotObserveStmt) - io = add_io_context(io) - print(io, "observe: ") - print(io, stmt.left) - print(io, " .~ ") - show_right(io, stmt.right) - print(io, " ") - print(io, RESULT_SYMBOL) - print(io, " (logprob = ") - print(io, stmt.logp) - return print(io, ")") -end - # Some utility methods for extracting information from a trace. """ varnames_in_trace(trace) @@ -168,24 +122,14 @@ Return all the varnames present in the trace. 
varnames_in_trace(trace::AbstractVector) = mapreduce(varnames_in_stmt, vcat, trace) varnames_in_stmt(stmt::AssumeStmt) = [stmt.varname] -function varnames_in_stmt(stmt::DotAssumeStmt) - return stmt.varname isa VarName ? [stmt.varname] : stmt.varname -end varnames_in_stmt(::ObserveStmt) = [] -varnames_in_stmt(::DotObserveStmt) = [] function distributions_in_trace(trace::AbstractVector) return mapreduce(distributions_in_stmt, vcat, trace) end distributions_in_stmt(stmt::AssumeStmt) = [stmt.right] -function distributions_in_stmt(stmt::DotAssumeStmt) - return stmt.right isa AbstractArray ? vec(stmt.right) : [stmt.right] -end distributions_in_stmt(stmt::ObserveStmt) = [stmt.right] -function distributions_in_stmt(stmt::DotObserveStmt) - return stmt.right isa AbstractArray ? vec(stmt.right) : [stmt.right] -end """ DebugContext <: AbstractContext @@ -382,95 +326,6 @@ function DynamicPPL.tilde_observe(context::DebugContext, sampler, right, left, v return logp, vi end -# dot-assume -function record_pre_dot_tilde_assume!(context::DebugContext, vn, left, right, varinfo) - # Check for `missing`s; these should not end up here. - if _has_missings(left) - error( - "Variable $(vn) has missing has missing value(s)!\n" * - "Usage of `missing` is not supported for dotted syntax, such as " * - "`@. x ~ dist` or `x .~ dist`", - ) - end - - # TODO: Can we do without the memory allocation here? - record_varname!.(broadcast_safe(context), vn, broadcast_safe(right)) - - # Check that `left` does not contain any `` - return nothing -end - -function record_post_dot_tilde_assume!( - context::DebugContext, vns, left, right, value, logp, varinfo -) - stmt = DotAssumeStmt(; - varname=vns, - left=left, - right=right, - value=value, - logp=logp, - varinfo=context.record_varinfo ? 
deepcopy(varinfo) : nothing, - ) - if context.record_statements - push!(context.statements, stmt) - end - - return nothing -end - -function DynamicPPL.dot_tilde_assume(context::DebugContext, right, left, vn, vi) - record_pre_dot_tilde_assume!(context, vn, left, right, vi) - value, logp, vi = DynamicPPL.dot_tilde_assume( - childcontext(context), right, left, vn, vi - ) - record_post_dot_tilde_assume!(context, vn, left, right, value, logp, vi) - return value, logp, vi -end - -function DynamicPPL.dot_tilde_assume( - rng::Random.AbstractRNG, context::DebugContext, sampler, right, left, vn, vi -) - record_pre_dot_tilde_assume!(context, vn, left, right, vi) - value, logp, vi = DynamicPPL.dot_tilde_assume( - rng, childcontext(context), sampler, right, left, vn, vi - ) - record_post_dot_tilde_assume!(context, vn, left, right, value, logp, vi) - return value, logp, vi -end - -# dot-observe -function record_pre_dot_tilde_observe!(context::DebugContext, left, right, vi) - # Check for `missing`s; these should not end up here. - if _has_missings(left) - # TODO: Once `observe` statements receive `vn`, refer to this in the - # error message. - error( - "Encountered missing value(s) in observe!\n" * - "Usage of `missing` is not supported for dotted syntax, such as " * - "`@. x ~ dist` or `x .~ dist`", - ) - end -end - -function record_post_dot_tilde_observe!(context::DebugContext, left, right, logp, vi) - stmt = DotObserveStmt(; - left=left, - right=right, - logp=logp, - varinfo=context.record_varinfo ? 
deepcopy(vi) : nothing, - ) - if context.record_statements - push!(context.statements, stmt) - end - return nothing -end -function DynamicPPL.dot_tilde_observe(context::DebugContext, right, left, vi) - record_pre_dot_tilde_observe!(context, left, right, vi) - logp, vi = DynamicPPL.dot_tilde_observe(childcontext(context), right, left, vi) - record_post_dot_tilde_observe!(context, left, right, logp, vi) - return logp, vi -end - _conditioned_varnames(d::AbstractDict) = keys(d) _conditioned_varnames(d) = map(sym -> VarName{sym}(), keys(d)) function conditioned_varnames(context) diff --git a/src/extract_priors.jl b/src/extract_priors.jl index dd5aeeb04..0f312fa2c 100644 --- a/src/extract_priors.jl +++ b/src/extract_priors.jl @@ -39,11 +39,6 @@ function DynamicPPL.tilde_assume(context::PriorExtractorContext, right, vn, vi) return DynamicPPL.tilde_assume(childcontext(context), right, vn, vi) end -function DynamicPPL.dot_tilde_assume(context::PriorExtractorContext, right, left, vn, vi) - setprior!(context, vn, right) - return DynamicPPL.dot_tilde_assume(childcontext(context), right, left, vn, vi) -end - """ extract_priors([rng::Random.AbstractRNG, ]model::Model) diff --git a/src/pointwise_logdensities.jl b/src/pointwise_logdensities.jl index 8c18163e3..cb9ea4894 100644 --- a/src/pointwise_logdensities.jl +++ b/src/pointwise_logdensities.jl @@ -100,52 +100,6 @@ function tilde_observe!!(context::PointwiseLogdensityContext, right, left, vn, v return left, acclogp!!(vi, logp) end -function dot_tilde_observe!!(context::PointwiseLogdensityContext, right, left, vi) - # Defer literal `observe` to child-context. - return dot_tilde_observe!!(context.context, right, left, vi) -end -function dot_tilde_observe!!(context::PointwiseLogdensityContext, right, left, vn, vi) - # Completely defer to child context if we are not tracking likelihoods. 
- if !(_include_likelihood(context)) - return dot_tilde_observe!!(context.context, right, left, vn, vi) - end - - # Need the `logp` value, so we cannot defer `acclogp!` to child-context, i.e. - # we have to intercept the call to `dot_tilde_observe!`. - - # We want to treat `.~` as a collection of independent observations, - # hence we need the `logp` for each of them. Broadcasting the univariate - # `tilde_observe` does exactly this. - logps = _pointwise_tilde_observe(context.context, right, left, vi) - - # Need to unwrap the `vn`, i.e. get one `VarName` for each entry in `left`. - _, _, vns = unwrap_right_left_vns(right, left, vn) - for (vn, logp) in zip(vns, logps) - # Track loglikelihood value. - push!(context, vn, logp) - end - - return left, acclogp!!(vi, sum(logps)) -end - -# FIXME: This is really not a good approach since it needs to stay in sync with -# the `dot_assume` implementations, but as things are _right now_ this is the best we can do. -function _pointwise_tilde_observe(context, right, left, vi) - # We need to drop the `vi` returned. - return broadcast(right, left) do r, l - return first(tilde_observe(context, r, l, vi)) - end -end - -function _pointwise_tilde_observe( - context, right::MultivariateDistribution, left::AbstractMatrix, vi::AbstractVarInfo -) - # We need to drop the `vi` returned. - return map(eachcol(left)) do l - return first(tilde_observe(context, right, l, vi)) - end -end - # Note on submodels (penelopeysm) # # We don't need to overload tilde_observe!! for Sampleables (yet), because it @@ -174,44 +128,6 @@ function tilde_assume!!(context::PointwiseLogdensityContext, right, vn, vi) return value, acclogp!!(vi, logp) end -function dot_tilde_assume!!(context::PointwiseLogdensityContext, right, left, vns, vi) - !_include_prior(context) && - return (dot_tilde_assume!!(context.context, right, left, vns, vi)) - value, logps = _pointwise_tilde_assume(context, right, left, vns, vi) - # Track loglikelihood values. 
- for (vn, logp) in zip(vns, logps) - push!(context, vn, logp) - end - return value, acclogp!!(vi, sum(logps)) -end - -function _pointwise_tilde_assume(context, right, left, vns, vi) - # We need to drop the `vi` returned. - values_and_logps = broadcast(right, left, vns) do r, l, vn - # HACK(torfjelde): This drops the `vi` returned, which means the `vi` is not updated - # in case of immutable varinfos. But a) atm we're only using mutable varinfos for this, - # and b) even if the variables aren't stored in the vi correctly, we're not going to use - # this vi for anything downstream anyways, i.e. I don't see a case where this would matter - # for this particular use case. - val, logp, _ = tilde_assume(context, r, vn, vi) - return val, logp - end - return map(first, values_and_logps), map(last, values_and_logps) -end -function _pointwise_tilde_assume( - context, right::MultivariateDistribution, left::AbstractMatrix, vns, vi -) - # We need to drop the `vi` returned. - values_and_logps = map(eachcol(left), vns) do l, vn - val, logp, _ = tilde_assume(context, right, vn, vi) - return val, logp - end - # HACK(torfjelde): Due to the way we handle `.~`, we should use `recombine` to stay consistent. - # But this also means that we need to first flatten the entire `values` component before recombining. - values = recombine(right, mapreduce(vec ∘ first, vcat, values_and_logps), length(vns)) - return values, map(last, values_and_logps) -end - """ pointwise_logdensities(model::Model, chain::Chains, keytype = String) @@ -357,7 +273,7 @@ end """ pointwise_loglikelihoods(model, chain[, keytype, context]) - + Compute the pointwise log-likelihoods of the model given the chain. This is the same as `pointwise_logdensities(model, chain, context)`, but only including the likelihood terms. 
diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 07296c3f7..324390394 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -487,57 +487,6 @@ function assume( return value, Bijectors.logpdf_with_trans(dist, value, istrans(vi, vn)), vi end -function dot_assume( - rng, - spl::Union{SampleFromPrior,SampleFromUniform}, - dists::Union{Distribution,AbstractArray{<:Distribution}}, - vns::AbstractArray{<:VarName}, - var::AbstractArray, - vi::SimpleOrThreadSafeSimple, -) - f = (vn, dist) -> init(rng, dist, spl) - value = f.(vns, dists) - - # Transform if we're working in transformed space. - value_raw = if dists isa Distribution - to_maybe_linked_internal.((vi,), vns, (dists,), value) - else - to_maybe_linked_internal.((vi,), vns, dists, value) - end - - # Update `vi` - vi = BangBang.setindex!!(vi, value_raw, vns) - - # Compute logp. - lp = sum(Bijectors.logpdf_with_trans.(dists, value, istrans.((vi,), vns))) - return value, lp, vi -end - -function dot_assume( - rng, - spl::Union{SampleFromPrior,SampleFromUniform}, - dist::MultivariateDistribution, - vns::AbstractVector{<:VarName}, - var::AbstractMatrix, - vi::SimpleOrThreadSafeSimple, -) - @assert length(dist) == size(var, 1) "dimensionality of `var` ($(size(var, 1))) is incompatible with dimensionality of `dist` $(length(dist))" - - # r = get_and_set_val!(rng, vi, vns, dist, spl) - n = length(vns) - value = init(rng, dist, spl, n) - - # Update `vi`. - for (vn, val) in zip(vns, eachcol(value)) - val_linked = to_maybe_linked_internal(vi, vn, dist, val) - vi = BangBang.setindex!!(vi, val_linked, vn) - end - - # Compute logp. - lp = sum(Bijectors.logpdf_with_trans(dist, value, istrans(vi))) - return value, lp, vi -end - # NOTE: We don't implement `settrans!!(vi, trans, vn)`. function settrans!!(vi::SimpleVarInfo, trans) return settrans!!(vi, trans ? 
DynamicTransformation() : NoTransformation()) diff --git a/src/test_utils/contexts.jl b/src/test_utils/contexts.jl index 93bb02d3b..5150be64b 100644 --- a/src/test_utils/contexts.jl +++ b/src/test_utils/contexts.jl @@ -26,22 +26,10 @@ function DynamicPPL.tilde_assume(context::TestLogModifyingChildContext, right, v value, logp, vi = DynamicPPL.tilde_assume(context.context, right, vn, vi) return value, logp * context.mod, vi end -function DynamicPPL.dot_tilde_assume( - context::TestLogModifyingChildContext, right, left, vn, vi -) - value, logp, vi = DynamicPPL.dot_tilde_assume(context.context, right, left, vn, vi) - return value, logp * context.mod, vi -end function DynamicPPL.tilde_observe(context::TestLogModifyingChildContext, right, left, vi) logp, vi = DynamicPPL.tilde_observe(context.context, right, left, vi) return logp * context.mod, vi end -function DynamicPPL.dot_tilde_observe( - context::TestLogModifyingChildContext, right, left, vi -) - logp, vi = DynamicPPL.dot_tilde_observe(context.context, right, left, vi) - return logp * context.mod, vi -end # Dummy context to test nested behaviors. 
struct TestParentContext{C<:DynamicPPL.AbstractContext} <: DynamicPPL.AbstractContext diff --git a/src/transforming.jl b/src/transforming.jl index 1a26d212f..0239725ae 100644 --- a/src/transforming.jl +++ b/src/transforming.jl @@ -30,67 +30,6 @@ function tilde_assume( return r, lp, setindex!!(vi, r_transformed, vn) end -function dot_tilde_assume( - ::DynamicTransformationContext{isinverse}, - dist::Distribution, - var::AbstractArray, - vns::AbstractArray{<:VarName}, - vi, -) where {isinverse} - r = getindex.((vi,), vns, (dist,)) - b = link_transform(dist) - - is_trans_uniques = unique(istrans.((vi,), vns)) - @assert length(is_trans_uniques) == 1 "DynamicTransformationContext only supports transforming all variables" - is_trans = first(is_trans_uniques) - if is_trans - @assert isinverse "Trying to link already transformed variables" - else - @assert !isinverse "Trying to invlink non-transformed variables" - end - - # Only transform if `!isinverse` since `vi[vn, right]` - # already performs the inverse transformation if it's transformed. - r_transformed = isinverse ? r : b.(r) - lp = sum(Bijectors.logpdf_with_trans.((dist,), r, (!isinverse,))) - return r, lp, setindex!!(vi, r_transformed, vns) -end - -function dot_tilde_assume( - ::DynamicTransformationContext{isinverse}, - dist::MultivariateDistribution, - var::AbstractMatrix, - vns::AbstractVector{<:VarName}, - vi::AbstractVarInfo, -) where {isinverse} - @assert length(dist) == size(var, 1) "dimensionality of `var` ($(size(var, 1))) is incompatible with dimensionality of `dist` $(length(dist))" - r = vi[vns, dist] - - # Compute `logpdf` with logabsdet-jacobian correction. - lp = sum(zip(vns, eachcol(r))) do (vn, ri) - return Bijectors.logpdf_with_trans(dist, ri, !isinverse) - end - - # Transform _all_ values. 
- is_trans_uniques = unique(istrans.((vi,), vns)) - @assert length(is_trans_uniques) == 1 "DynamicTransformationContext only supports transforming all variables" - is_trans = first(is_trans_uniques) - if is_trans - @assert isinverse "Trying to link already transformed variables" - else - @assert !isinverse "Trying to invlink non-transformed variables" - end - - b = link_transform(dist) - for (vn, ri) in zip(vns, eachcol(r)) - # Only transform if `!isinverse` since `vi[vn, right]` - # already performs the inverse transformation if it's transformed. - vi = setindex!!(vi, isinverse ? ri : b(ri), vn) - end - - return r, lp, vi -end - function link!!(t::DynamicTransformation, vi::AbstractVarInfo, model::Model) return settrans!!(last(evaluate!!(model, vi, DynamicTransformationContext{false}())), t) end diff --git a/src/values_as_in_model.jl b/src/values_as_in_model.jl index 4cef5fa4e..d3bfd697a 100644 --- a/src/values_as_in_model.jl +++ b/src/values_as_in_model.jl @@ -90,29 +90,6 @@ function tilde_assume( return value, logp, vi end -# `dot_tilde_assume` -function dot_tilde_assume(context::ValuesAsInModelContext, right, left, vn, vi) - value, logp, vi = dot_tilde_assume(childcontext(context), right, left, vn, vi) - - # Save the value. - _right, _left, _vns = unwrap_right_left_vns(right, var, vn) - broadcast_push!(context, _vns, value) - - return value, logp, vi -end -function dot_tilde_assume( - rng::Random.AbstractRNG, context::ValuesAsInModelContext, sampler, right, left, vn, vi -) - value, logp, vi = dot_tilde_assume( - rng, childcontext(context), sampler, right, left, vn, vi - ) - # Save the value. 
- _right, _left, _vns = unwrap_right_left_vns(right, left, vn) - broadcast_push!(context, _vns, value) - - return value, logp, vi -end - """ values_as_in_model(model::Model, include_colon_eq::Bool, varinfo::AbstractVarInfo[, context::AbstractContext]) diff --git a/test/compat/ad.jl b/test/compat/ad.jl index f76ce6f6e..e6b23f379 100644 --- a/test/compat/ad.jl +++ b/test/compat/ad.jl @@ -26,32 +26,4 @@ test_model_ad(wishart_ad(), logp_wishart_ad) end - - # https://github.com/TuringLang/Turing.jl/issues/1595 - @testset "dot_observe" begin - function f_dot_observe(x) - logp, _ = DynamicPPL.dot_observe( - SampleFromPrior(), [Normal(), Normal(-1.0, 2.0)], x, VarInfo() - ) - return logp - end - function f_dot_observe_manual(x) - return logpdf(Normal(), x[1]) + logpdf(Normal(-1.0, 2.0), x[2]) - end - - # Manual computation of the gradient. - x = randn(2) - val = f_dot_observe_manual(x) - grad = ForwardDiff.gradient(f_dot_observe_manual, x) - - @test ForwardDiff.gradient(f_dot_observe, x) ≈ grad - - y, back = Tracker.forward(f_dot_observe, x) - @test Tracker.data(y) ≈ val - @test Tracker.data(back(1)[1]) ≈ grad - - y, back = Zygote.pullback(f_dot_observe, x) - @test y ≈ val - @test back(1)[1] ≈ grad - end end From e90ea0d4a67cf49647edbc48ec3e8ffc00a50a2d Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Fri, 7 Feb 2025 17:48:16 +0000 Subject: [PATCH 05/16] Fix a .~ bug --- src/compiler.jl | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 4b9bd226d..f34f37eb9 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -164,8 +164,8 @@ isliteral(::Number) = true function isliteral(e::Expr) # In the special case that the expression is of the form `abc[blahblah]`, we consider it # to be a literal if `abc` is a literal. This is necessary for cases like - # [1.0, 2.0][1] ~ Normal() - # which are generate when turning `.~` expressions into loops over `~` expressions. + # [1.0, 2.0][idx...] 
~ Normal() + # which are generated when turning `.~` expressions into loops over `~` expressions. if e.head == :ref return isliteral(e.args[1]) end @@ -509,12 +509,12 @@ end Generate the expression that replaces `left .~ right` in the model body. """ function generate_dot_tilde(left, right) - @gensym dist left_axes + @gensym dist left_axes idx return quote $dist = DynamicPPL.check_dot_tilde_rhs($right) $left_axes = axes($left) - for idx in Iterators.product($left_axes...) - $left[idx...] ~ $dist + for $idx in Iterators.product($left_axes...) + $left[$idx...] ~ $dist end end end From 90a6a9d1babd90e57e965709050b5c41a2c9b4ea Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Fri, 7 Feb 2025 18:06:41 +0000 Subject: [PATCH 06/16] Update HISTORY.md --- HISTORY.md | 57 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 57 insertions(+) diff --git a/HISTORY.md b/HISTORY.md index 6b7247c8d..9d96a9336 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -4,6 +4,63 @@ **Breaking** +### `.~` right hand side must be a univariate distribution + +Previously we allowed statements like + +```julia +x .~ [Normal(), Gamma()] +``` + +where the right hand side of a `.~` was an array of distributions, and ones like + +```julia +x .~ MvNormal(fill(0.0, 2), I) +``` + +where the right hand was a multivariate distribution. + +These are no longer allowed. The only things allowed on the right hand side of a `.~` statement are univariate distributions, such as + +```julia +x = Array{Float64,3}(undef, 2, 3, 4) +x .~ Normal() +``` + +The reasons for this are internal code simplification and the fact that broadcasting where both sides are multidimensional but of different dimensions is typically confusing to read. + +Cases where the dimension of the multivariate distribution or the array of distribution is the same as the dimension of the left hand side variable can be replaced with `product_distribution`. 
For example, instead of + +```julia +x .~ [Normal(), Gamma()] +``` + +do + +```julia +x ~ product_distribution([Normal(), Gamma()]) +``` + +This is often more performant as well. Note that using a product distribution will change how a `VarInfo` views the variable: Instead of viewing each `x[i]` as a distinct univariate variable like with `.~`, with `x ~ product_distribution(...)` `x` will be viewed as a single multivariate variable. This was already the case before this release. If, for some reason, you _do_ want each `x[i]` independently in your `VarInfo`, you can always turn the `.~` statement into a loop. + +Cases where the right hand side is of a different dimension than the left hand side, and neither is a scalar, must be replaced with a loop. For example, + +```julia +x = Array{Float64,3}(undef, 2, 3, 4) +x .~ MvNormal(fill(0, 2), I) +``` + +should be replaced with something like + +```julia +x = Array{Float64,3}(undef, 2, 3, 4) +for i in 1:3, j in 1:4 + x[:, i, j] ~ MvNormal(fill(0, 2), I) +end +``` + +This release also completely rewrites the internal implementation of `.~`, where from now on all `.~` statements are turned into loops over `~` statements at macro time. However, the only breaking aspect of this change is the above change to what's allowed on the right hand side. + ### Remove indexing by samplers This release removes the feature of `VarInfo` where it kept track of which variable was associated with which sampler. This means removing all user-facing methods where `VarInfo`s where being indexed with samplers.
In particular, From 5278e7203331643eef7cf33194a0377bcbdecbd2 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Fri, 7 Feb 2025 18:45:32 +0000 Subject: [PATCH 07/16] Fix a tiny test bug --- test/pointwise_logdensities.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/pointwise_logdensities.jl b/test/pointwise_logdensities.jl index 5c0b2e090..2c0207469 100644 --- a/test/pointwise_logdensities.jl +++ b/test/pointwise_logdensities.jl @@ -49,7 +49,7 @@ end # We'll just test one, since `pointwise_logdensities(::Model, ::AbstractVarInfo)` is tested extensively, # and this is what is used to implement `pointwise_logdensities(::Model, ::Chains)`. This test suite is just # to ensure that we don't accidentally break the the version on `Chains`. - model = DynamicPPL.TestUtils.demo_dot_assume_dot_observe() + model = DynamicPPL.TestUtils.demo_dot_assume_observe() # FIXME(torfjelde): Make use of `varname_and_value_leaves` once we've introduced # an impl of this for containers. # NOTE(torfjelde): This only returns the varnames of the _random_ variables, i.e. excl. observed. From e302dbae2653d8275f50834c94980e31151f962b Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Fri, 7 Feb 2025 18:45:50 +0000 Subject: [PATCH 08/16] Re-enable some SimpleVarInfo tests --- test/simple_varinfo.jl | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/test/simple_varinfo.jl b/test/simple_varinfo.jl index 137c791c2..e67b5656a 100644 --- a/test/simple_varinfo.jl +++ b/test/simple_varinfo.jl @@ -139,8 +139,6 @@ @testset "SimpleVarInfo on $(nameof(model))" for model in DynamicPPL.TestUtils.DEMO_MODELS - model = DynamicPPL.TestUtils.demo_dot_assume_matrix_dot_observe_matrix() - # We might need to pre-allocate for the variable `m`, so we need # to see whether this is the case. 
svi_nt = SimpleVarInfo(DynamicPPL.TestUtils.rand_prior_true(model)) @@ -155,9 +153,10 @@ svi_nt, svi_dict, svi_vnv, - DynamicPPL.settrans!!(deepcopy(svi_nt), true), - DynamicPPL.settrans!!(deepcopy(svi_dict), true), - DynamicPPL.settrans!!(deepcopy(svi_vnv), true), + # TODO(mhauru) Fix linked SimpleVarInfos to work with our test models. + # DynamicPPL.settrans!!(deepcopy(svi_nt), true), + # DynamicPPL.settrans!!(deepcopy(svi_dict), true), + # DynamicPPL.settrans!!(deepcopy(svi_vnv), true), ) # RandOM seed is set in each `@testset`, so we need to sample # a new realization for `m` here. From 4c690104a5343a88dd81293d3ce174499885e4f5 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Wed, 12 Feb 2025 17:04:18 +0000 Subject: [PATCH 09/16] Improve changelog entry --- HISTORY.md | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/HISTORY.md b/HISTORY.md index 9d96a9336..adc00e494 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -18,7 +18,7 @@ where the right hand side of a `.~` was an array of distributions, and ones like x .~ MvNormal(fill(0.0, 2), I) ``` -where the right hand was a multivariate distribution. +where the right hand side was a multivariate distribution. These are no longer allowed. The only things allowed on the right hand side of a `.~` statement are univariate distributions, such as @@ -29,19 +29,30 @@ x .~ Normal() The reasons for this are internal code simplification and the fact that broadcasting where both sides are multidimensional but of different dimensions is typically confusing to read. -Cases where the dimension of the multivariate distribution or the array of distribution is the same as the dimension of the left hand side variable can be replaced with `product_distribution`. For example, instead of +If the right hand side and the left hand side have the same dimension, one can simply use `~`. Arrays of distributions can be replaced with `product_distribution`. 
So instead of ```julia x .~ [Normal(), Gamma()] +x .~ Normal.(y) +x .~ MvNormal(fill(0.0, 2), I) ``` do ```julia x ~ product_distribution([Normal(), Gamma()]) +x ~ product_distribution(Normal.(y)) +x ~ MvNormal(fill(0.0, 2), I) ``` -This is often more performant as well. Note that using a product distribution will change how a `VarInfo` views the variable: Instead of viewing each `x[i]` as a distinct univariate variable like with `.~`, with `x ~ product_distribution(...)` `x` will be viewed as a single multivariate variable. This was already the case before this release. If, for some reason, you _do_ want each `x[i]` independently in your `VarInfo`, you can always turn the `.~` statement into a loop. +This is often more performant as well. Note that using `~` rather than `.~` does change the internal storage format a bit: With `.~` `x[i]` are stored as separate variables, with `~` as a single multivariate variable `x`. In most cases this does not change anything for the user, but if it does cause issues, e.g. if you are dealing with `VarInfo` objects directly and need to keep the old behavior, you can always expand into a loop, such as + +```julia +dists = Normal.(y) +for i in 1:length(dists) + x[i] ~ dists[i] +end +``` Cases where the right hand side is of a different dimension than the left hand side, and neither is a scalar, must be replaced with a loop. 
For example, From f794acc977739ce19ef2b545d42a6810cd80551e Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Wed, 12 Feb 2025 17:04:49 +0000 Subject: [PATCH 10/16] Improve error message --- src/compiler.jl | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/compiler.jl b/src/compiler.jl index f34f37eb9..338a9e6c0 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -203,6 +203,15 @@ function check_dot_tilde_rhs(@nospecialize(x)) ArgumentError("the right-hand side of a `.~` must be a `UnivariateDistribution`") ) end +function check_dot_tilde_rhs(x::AbstactArray{<:Distribution}) + msg = """ + As of v0.35, DynamicPPL does not allow arrays of distributions in `.~`. \ + Please use `product_distribution` instead, or write a loop if necessary. \ + See https://github.com/TuringLang/DynamicPPL.jl/releases/tag/v0.35.0 for more \ + details.\ + """ + return throw(ArgumentError(msg)) +end check_dot_tilde_rhs(x::UnivariateDistribution) = x function check_dot_tilde_rhs(x::Sampleable{<:Any,AutoPrefix}) where {AutoPrefix} model = check_dot_tilde_rhs(x.model) From f74258ed3bcdff08d29b740a32fbfae7fd391a70 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Wed, 12 Feb 2025 17:09:54 +0000 Subject: [PATCH 11/16] Fix trivial typos --- HISTORY.md | 2 +- src/compiler.jl | 2 +- src/model.jl | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/HISTORY.md b/HISTORY.md index adc00e494..59030e600 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -82,7 +82,7 @@ This release removes the feature of `VarInfo` where it kept track of which varia - `unflatten` no longer accepts a sampler as an argument - `eltype(::VarInfo)` no longer accepts a sampler as an argument - `keys(::VarInfo)` no longer accepts a sampler as an argument - - `VarInfo(::VarInfo, ::Sampler, ::AbstactVector)` no longer accepts the sampler argument. + - `VarInfo(::VarInfo, ::Sampler, ::AbstractVector)` no longer accepts the sampler argument. 
### Reverse prefixing order diff --git a/src/compiler.jl b/src/compiler.jl index 338a9e6c0..92f3522a7 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -203,7 +203,7 @@ function check_dot_tilde_rhs(@nospecialize(x)) ArgumentError("the right-hand side of a `.~` must be a `UnivariateDistribution`") ) end -function check_dot_tilde_rhs(x::AbstactArray{<:Distribution}) +function check_dot_tilde_rhs(::AbstractArray{<:Distribution}) msg = """ As of v0.35, DynamicPPL does not allow arrays of distributions in `.~`. \ Please use `product_distribution` instead, or write a loop if necessary. \ diff --git a/src/model.jl b/src/model.jl index 3601d77fd..0fb18f463 100644 --- a/src/model.jl +++ b/src/model.jl @@ -1,5 +1,5 @@ """ - struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx<:AbstactContext} + struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx<:AbstractContext} f::F args::NamedTuple{argnames,Targs} defaults::NamedTuple{defaultnames,Tdefaults} From 8f3e7b9b0c2f2e9622ce19d11eedb70572148857 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Wed, 12 Feb 2025 17:36:21 +0000 Subject: [PATCH 12/16] Fix pointwise_logdensity test --- test/pointwise_logdensities.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/pointwise_logdensities.jl b/test/pointwise_logdensities.jl index 2c0207469..61c842638 100644 --- a/test/pointwise_logdensities.jl +++ b/test/pointwise_logdensities.jl @@ -48,8 +48,8 @@ end @testset "pointwise_logdensities chain" begin # We'll just test one, since `pointwise_logdensities(::Model, ::AbstractVarInfo)` is tested extensively, # and this is what is used to implement `pointwise_logdensities(::Model, ::Chains)`. This test suite is just - # to ensure that we don't accidentally break the the version on `Chains`. - model = DynamicPPL.TestUtils.demo_dot_assume_observe() + # to ensure that we don't accidentally break the version on `Chains`. 
+ model = DynamicPPL.TestUtils.demo_assume_index_observe() # FIXME(torfjelde): Make use of `varname_and_value_leaves` once we've introduced # an impl of this for containers. # NOTE(torfjelde): This only returns the varnames of the _random_ variables, i.e. excl. observed. From af8b01707d174a1c35521c3c9f19efab1d2b5a3c Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 13 Feb 2025 14:26:51 +0000 Subject: [PATCH 13/16] Remove pointless check_dot_tilde_rhs method --- src/compiler.jl | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 92f3522a7..8bde5e784 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -213,10 +213,6 @@ function check_dot_tilde_rhs(::AbstractArray{<:Distribution}) return throw(ArgumentError(msg)) end check_dot_tilde_rhs(x::UnivariateDistribution) = x -function check_dot_tilde_rhs(x::Sampleable{<:Any,AutoPrefix}) where {AutoPrefix} - model = check_dot_tilde_rhs(x.model) - return Sampleable{typeof(model),AutoPrefix}(model) -end """ unwrap_right_vn(right, vn) From 67622b954e82f052cbaf6e01e190ee383a783ea4 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 13 Feb 2025 14:36:32 +0000 Subject: [PATCH 14/16] Add tests for old .~ syntax --- test/compiler.jl | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/test/compiler.jl b/test/compiler.jl index 051eba618..8d81c530a 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -288,6 +288,33 @@ module Issue537 end x = vdemo()() @test all((isassigned(x, i) for i in eachindex(x))) end + + # A couple of uses of .~ that are no longer valid as of v0.35. 
+ @testset "old .~ syntax" begin + @model function multivariate_dot_tilde() + x = Vector{Float64}(undef, 2) + x .~ MvNormal(zeros(2), I) + return x + end + expected_error = ArgumentError( + "the right-hand side of a `.~` must be a `UnivariateDistribution`" + ) + @test_throws expected_error (multivariate_dot_tilde()(); true) + + @model function vector_dot_tilde() + x = Vector{Float64}(undef, 2) + x .~ [Normal(), Normal()] + return x + end + expected_error = ArgumentError(""" + As of v0.35, DynamicPPL does not allow arrays of distributions in `.~`. \ + Please use `product_distribution` instead, or write a loop if necessary. \ + See https://github.com/TuringLang/DynamicPPL.jl/releases/tag/v0.35.0 for more \ + details.\ + """) + @test_throws expected_error (vector_dot_tilde()(); true) + end + @testset "nested model" begin function makemodel(p) @model function testmodel(x) From 15dc8a27e598eef515aca5fa1bcef7f739c38d97 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Mon, 17 Feb 2025 18:21:52 +0000 Subject: [PATCH 15/16] Bump Mooncake patch version to v0.4.90 --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index bb67cfb07..21d1744bd 100644 --- a/Project.toml +++ b/Project.toml @@ -66,7 +66,7 @@ LogDensityProblems = "2" LogDensityProblemsAD = "1.7.0" MCMCChains = "6" MacroTools = "0.5.6" -Mooncake = "0.4.59" +Mooncake = "0.4.90" OrderedCollections = "1" Random = "1.6" Requires = "1" From 30c241fddc0bea815ceddd57a813cfb2ace1f510 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Mon, 17 Feb 2025 20:13:36 +0000 Subject: [PATCH 16/16] Bump Mooncake to 0.4.95 --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index c736ac118..7cd47fdbb 100644 --- a/Project.toml +++ b/Project.toml @@ -62,7 +62,7 @@ LogDensityProblems = "2" LogDensityProblemsAD = "1.7.0" MCMCChains = "6" MacroTools = "0.5.6" -Mooncake = "0.4.90" +Mooncake = "0.4.95" OrderedCollections = 
"1" Random = "1.6" Requires = "1"