-
Notifications
You must be signed in to change notification settings - Fork 34
Closed
Description
I created a custom node and want to inspect the message it sends forward (out of its `out` interface) by using it in a model. Here is the implementation:
# Custom factor node type: a field-less singleton used only as a dispatch tag.
struct MyNode end
# Register MyNode as a Stochastic factor node with three interfaces, in order:
# `out`, `in1`, `in2`. The `@rule` definitions below refer to these names.
@node MyNode Stochastic [out, in1, in2]
# rule specification: one message update rule per (target interface, input types) pair
# Forward message towards `out` from Gaussian messages on `in1` and `in2`:
# the outgoing mean is the sum of the input means and the outgoing variance is
# the sum of the input variances (consistent with out = in1 + in2).
@rule MyNode(:out, Marginalisation) (
    m_in1::UnivariateNormalDistributionsFamily,
    m_in2::UnivariateNormalDistributionsFamily
) = begin
    μ1, v1 = mean_var(m_in1)
    μ2, v2 = mean_var(m_in2)
    μ_sum = μ1 + μ2
    v_sum = v1 + v2
    return NormalMeanVariance(μ_sum, v_sum)
end
# Backward message towards `in1` from Gaussian messages on `out` and `in2`:
# mean is mean(out) - mean(in2), variance is var(out) + var(in2).
@rule MyNode(:in1, Marginalisation) (
    m_out::UnivariateNormalDistributionsFamily,
    m_in2::UnivariateNormalDistributionsFamily
) = begin
    μ_other, v_other = mean_var(m_in2)
    μ_total, v_total = mean_var(m_out)
    μ_diff = μ_total - μ_other
    v_combined = v_total + v_other
    return NormalMeanVariance(μ_diff, v_combined)
end
# Backward message towards `in2` from Gaussian messages on `out` and `in1`:
# mean is mean(out) - mean(in1), variance is var(out) + var(in1).
@rule MyNode(:in2, Marginalisation) (
    m_out::UnivariateNormalDistributionsFamily,
    m_in1::UnivariateNormalDistributionsFamily
) = begin
    μ_other, v_other = mean_var(m_in1)
    μ_total, v_total = mean_var(m_out)
    μ_diff = μ_total - μ_other
    v_combined = v_total + v_other
    return NormalMeanVariance(μ_diff, v_combined)
end
# Message towards `in1` when `out` enters as a marginal `q_out` rather than a message.
# Only mean(q_out) is used; the variance comes from `in2` alone.
# NOTE(review): presumably this variant fires when `out` is observed data (q_out a
# point mass with zero variance) — TODO confirm; otherwise var(q_out) is dropped.
@rule MyNode(:in1, Marginalisation) (
    q_out::Any,
    m_in2::UnivariateNormalDistributionsFamily
) = begin
    μ_other, v_other = mean_var(m_in2)
    μ_diff = mean(q_out) - μ_other
    return NormalMeanVariance(μ_diff, v_other)
end
# Message towards `in2` when `out` enters as a marginal `q_out` rather than a message.
# Only mean(q_out) is used; the variance comes from `in1` alone.
# NOTE(review): presumably this variant fires when `out` is observed data (q_out a
# point mass with zero variance) — TODO confirm; otherwise var(q_out) is dropped.
@rule MyNode(:in2, Marginalisation) (
    q_out::Any,
    m_in1::UnivariateNormalDistributionsFamily
) = begin
    μ_other, v_other = mean_var(m_in1)
    μ_diff = mean(q_out) - μ_other
    return NormalMeanVariance(μ_diff, v_other)
end
# Model: two Gaussian-prior latent variables A and B feed the custom MyNode,
# whose left-hand side `y` is connected to the node's first interface (`out`).
@model function My_model(y)
A ~ NormalMeanVariance(2.0,1.0)
B ~ NormalMeanVariance(1.0,1.0)
# A and B connect to MyNode's `in1` and `in2` interfaces, respectively.
y ~ MyNode(A,B)
end
# Run inference without supplying `y`, asking RxInfer to keep the last
# prediction for `y`.
result = infer(;
    model = My_model(),
    predictvars = (y = KeepLast(),),
)
But when I try to read the prediction for `y` with:
# Look up the prediction stored for `y` (reported to be `missing`).
result.predictions[:y]
it returns `missing`. As an alternative, I also tried:
# Alternative attempt: pass `y` explicitly as data whose value is `missing`,
# instead of requesting it via `predictvars`.
result = infer(;
    model = My_model(),
    data = (y = missing,),
)
The result is the same.
Metadata
Metadata
Assignees
Labels
No labels