
Missing marginal rules for the multiplication node #505

@bartvanerp

Description


For the following model:

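```julia
using RxInfer

@model function online_mmodel(x_prev_mean, x_prev_prec, ν_obs, W_obs, α_proc, β_proc, data_vec, y)
    x_prev ~ Normal(mean = x_prev_mean, precision = x_prev_prec)
    # process-noise precision (uses the α_proc / β_proc arguments)
    gamma ~ Gamma(shape = α_proc, rate = β_proc)
    x ~ Normal(mean = x_prev, precision = gamma)
    # scalar x times vector data_vec: this is the product that needs the missing rule
    z := x * data_vec
    # observation-noise precision (uses the ν_obs / W_obs arguments)
    Λ ~ Wishart(ν_obs, W_obs)
    y ~ MvNormalMeanPrecision(z, Λ)
end

constraints = @constraints begin
    q(x_prev, x, z, Λ, gamma) = q(x_prev, x, z)q(Λ)q(gamma)
end
```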

With `x` univariate and `data_vec` and `y` both vector-valued observations, the free energy computation fails. It appears to be missing the following rule:

```julia
@marginalrule typeof(*)(:A_in) (m_out::MvNormalMeanPrecision, m_A::NormalMeanPrecision, m_in::PointMass{<:AbstractVector}, meta::Any) = begin
    return ...
end
```
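For reference, here is a minimal sketch of what the rule above would need to compute, written in plain Julia rather than with ReactiveMP types. It assumes the marginal over `(A, in)` factorizes as a Gaussian on the scalar `A` times the point mass `v` on `in`. Because `in` is fixed, the message on `out` acts as a Gaussian likelihood for `A`: it adds `v' W_out v` to the precision of `A` and `v' W_out μ_out` to its precision-weighted mean. The helper name `scalar_gain_marginal` and its argument layout are hypothetical, not part of ReactiveMP.

```julia
using LinearAlgebra

# Hypothetical helper (not a ReactiveMP function): Gaussian marginal over the
# scalar A in out = A * in, given
#   m_A, w_A      : mean and precision of the NormalMeanPrecision message on A
#   v             : the PointMass value of in (a vector)
#   μ_out, W_out  : mean and (symmetric) precision of the MvNormalMeanPrecision message on out
function scalar_gain_marginal(m_A::Real, w_A::Real, v::AbstractVector,
                              μ_out::AbstractVector, W_out::AbstractMatrix)
    # exp(-1/2 (A v - μ_out)' W_out (A v - μ_out)) is quadratic in A with
    # precision v' W_out v and linear coefficient v' W_out μ_out
    w_post = w_A + dot(v, W_out * v)
    ξ_post = w_A * m_A + dot(v, W_out * μ_out)
    return (mean = ξ_post / w_post, precision = w_post)
end

# Toy numbers: prior N(0, precision 1) on A, in = [1, 2], out message centred at [3, 6]
q_A = scalar_gain_marginal(0.0, 1.0, [1.0, 2.0], [3.0, 6.0], Matrix(1.0I, 2, 2))
# q_A == (mean = 2.5, precision = 6.0)
```

Under this assumption, the joint marginal the rule returns would pair this Gaussian on `A` with the unchanged point mass on `in`.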

I think we should support this, since Julia itself allows scalar-vector products in either order: `3 * [1, 2] == [1, 2] * 3 == [3, 6]`.
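A quick check in plain Julia (no packages needed) that both operand orders are defined and give the same result:

```julia
# Base Julia defines scalar * vector and vector * scalar
a = 3 * [1, 2]
b = [1, 2] * 3
@assert a == b == [3, 6]
```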
