Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 95 additions & 0 deletions src/helpers/macrohelpers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -117,4 +117,99 @@ macro test_inferred(T, expression)
end)
end

function check_rule_interfaces(macrotype, fform, lambda, ifaces, on_type, m_names, q_names; mod = __MODULE__)
# skip rules like (typeof(+))(:in1_in2) for which interfaces returns nothing
if ifaces === nothing
return nothing
end
names_expected = valof_set(ifaces, mod)
onames = valof_set(on_type, mod)
mnames = valof_set(m_names, mod)
qnames = valof_set(q_names, mod)
names_used = union(onames, mnames, qnames)

names_unknown = setdiff(names_expected, names_used)
if !isempty(names_unknown)
missing_list = join(sort(collect(names_unknown)), ", ")
expected_list = join(sort(collect(names_expected)), ", ")
provided_list = join(sort(collect(names_used)), ", ")

throw(ArgumentError("""
Interface mismatch for $(macrotype) $(fform) $(lambda):
Expected symbols: $expected_list
Provided symbols: $provided_list
Missing symbols: $missing_list
"""))
end

names_extra = setdiff(names_used, names_expected)
if !isempty(names_extra)
extras_list = join(sort(collect(names_extra)), ", ")
expected_list = join(sort(collect(names_expected)), ", ")
provided_list = join(sort(collect(names_used)), ", ")

throw(ArgumentError("""
Interface mismatch for $(macrotype) $(fform) $(lambda):
Expected symbols: $expected_list
Provided symbols: $provided_list
Extra symbols: $extras_list
"""))
end
end

function valof_set(x::Nothing, mod::Module)
return Set{Symbol}()
end

function valof_set(x::Symbol, mod::Module)
s = Set{Symbol}()
if x === :Nothing
return s
end
# Split joint message symbol by underscores
for part in split(string(x), '_')
push!(s, Symbol(part))
end
return s
end

function valof_set(x::Val, mod::Module)
return valof_set(typeof(x).parameters[1], mod)
end

valof_set(x::Type{<:Val}, mod::Module) = valof_set(first(x.parameters), mod)

function valof_set(x::Type{<:Tuple}, mod::Module)
# Handle tuple types like Tuple{Val{:inputs}, Int}
s = Set{Symbol}()
for p in x.parameters
if p <: Integer
continue
end
s = union(s, valof_set(p, mod))
end
return s
end

function valof_set(x::Tuple, mod::Module)
# Handle **tuple values** (instances)
s = Set{Symbol}()
for xi in x
s = union(s, valof_set(xi, mod))
end
return s
end

function valof_set(x::Expr, mod::Module)
@capture(x, (Val{values_}))
return __split_val(values, mod)
end

__split_val(x::QuoteNode, mod) = valof_set(x.value, mod)
__split_val(x::Expr, mod) = valof_set(Tuple(map(z -> z.value, x.args)), mod)
__split_val(x::Nothing, mod) = error("Unexpected expression encountered (Not of form `Val{...}`).")

# Fallback for other types
valof_set(x, mod::Module) = Set{Symbol}()

end
20 changes: 15 additions & 5 deletions src/nodes/nodes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,15 @@ Converts the given `name` to a correct interface name for a given factor node ty
"""
function alias_interface end

"""
nodesymbol_to_nodefform(symbol_in_val)

Returns the factor node type associated with a given symbol wrapped in a `Val` object. Returns `nothing` for unknown symbol.
"""
function nodesymbol_to_nodefform(::Val)
return nothing
end

node_expression_extract_interface(s::Symbol) = (s, [])

function node_expression_extract_interface(e::Expr)
Expand All @@ -313,10 +322,10 @@ function generate_node_expression(node_fform, node_type, node_interfaces)
interfaces = map(node_expression_extract_interface, node_interfaces.args)

# Determine whether we should dispatch on `typeof($fform)` or `Type{$node_fform}`
dispatch_type = if @capture(node_fform, typeof(fform_))
:(typeof($fform))
dispatch_type, node_symbol = if @capture(node_fform, typeof(fform_))
(:(typeof($fform)), fform)
else
:(Type{<:$node_fform})
(:(Type{<:$node_fform}), node_fform)
end

foreach(interfaces) do (name, aliases)
Expand Down Expand Up @@ -389,9 +398,10 @@ function generate_node_expression(node_fform, node_type, node_interfaces)
result = quote
@doc $doc ReactiveMP.is_predefined_node(::$dispatch_type) = ReactiveMP.PredefinedNodeFunctionalForm()

ReactiveMP.sdtype(::$dispatch_type) = (ReactiveMP.$node_type)()
ReactiveMP.interfaces(::$dispatch_type) = Val($(Tuple(map(first, interfaces))))
ReactiveMP.sdtype(::$dispatch_type) = (ReactiveMP.$node_type)()
ReactiveMP.interfaces(::$dispatch_type) = Val($(Tuple(map(first, interfaces))))
ReactiveMP.inputinterfaces(::$dispatch_type) = Val($(Tuple(map(first, skipindex(interfaces, 1)))))
ReactiveMP.nodesymbol_to_nodefform(::Val{$(QuoteNode(node_symbol))}) = $node_symbol

$collect_factorisation_fn
$nodefunctions
Expand Down
24 changes: 24 additions & 0 deletions src/rule.jl
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,7 @@ macro rule(fform, lambda)
@capture(args, (inputs__, meta::metatype_) | (inputs__,)) || error("Error in macro. Lambda body arguments specification is incorrect")

fuppertype = MacroHelpers.upper_type(fformtype)
fbottomtype = MacroHelpers.bottom_type(fformtype)
on_type, on_index, on_index_init = rule_macro_parse_on_tag(on)
whereargs = whereargs === nothing ? [] : whereargs
metatype = metatype === nothing ? :Nothing : metatype
Expand All @@ -405,6 +406,17 @@ macro rule(fform, lambda)
m_names, m_types, m_init_block = rule_macro_parse_fn_args(inputs; specname = :messages, prefix = :m_, proxy = :(ReactiveMP.Message))
q_names, q_types, q_init_block = rule_macro_parse_fn_args(inputs; specname = :marginals, prefix = :q_, proxy = :(ReactiveMP.Marginal))

# Some `@rules` are more complex in terms of functional form specification, e.g. `NormalMixture{N}`
if fbottomtype isa Symbol
# Not all nodes are defined with the `@node` macro, so we need to check if the node is defined with the `@node` macro
# `nodesymbol_to_nodefform` may return `nothing` for such nodes, in this case we skip the interface check
nodefform_from_symbol = ReactiveMP.nodesymbol_to_nodefform(Val(fbottomtype))
if !isnothing(nodefform_from_symbol)
ifaces = ReactiveMP.interfaces(nodefform_from_symbol)
MacroHelpers.check_rule_interfaces("@rule", fform, lambda, ifaces, on_type, m_names, q_names; mod = __module__)
end
end

output = quote
$(
rule_function_expression(fuppertype, on_type, vconstraint, m_names, m_types, q_names, q_types, metatype, whereargs) do
Expand Down Expand Up @@ -588,6 +600,7 @@ macro marginalrule(fform, lambda)
@capture(args, (inputs__, meta::metatype_) | (inputs__,)) || error("Error in macro. Lambda body arguments specification is incorrect")

fuppertype = MacroHelpers.upper_type(fformtype)
fbottomtype = MacroHelpers.bottom_type(fformtype)
on_type, on_index, on_index_init = rule_macro_parse_on_tag(on)
whereargs = whereargs === nothing ? [] : whereargs
metatype = metatype === nothing ? :Any : metatype
Expand All @@ -602,6 +615,17 @@ macro marginalrule(fform, lambda)
m_names, m_types, m_init_block = rule_macro_parse_fn_args(inputs; specname = :messages, prefix = :m_, proxy = :(ReactiveMP.Message))
q_names, q_types, q_init_block = rule_macro_parse_fn_args(inputs; specname = :marginals, prefix = :q_, proxy = :(ReactiveMP.Marginal))

# Some `@rules` are more complex in terms of functional form specification, e.g. `NormalMixture{N}`
if fbottomtype isa Symbol
# Not all nodes are defined with the `@node` macro, so we need to check if the node is defined with the `@node` macro
# `nodesymbol_to_nodefform` may return `nothing` for such nodes, in this case we skip the interface check
nodefform_from_symbol = ReactiveMP.nodesymbol_to_nodefform(Val(fbottomtype))
if !isnothing(nodefform_from_symbol)
ifaces = ReactiveMP.interfaces(nodefform_from_symbol)
MacroHelpers.check_rule_interfaces("@marginalrule", fform, lambda, ifaces, on_type, m_names, q_names; mod = __module__)
end
end

output = quote
$(
marginalrule_function_expression(fuppertype, on_type, m_names, m_types, q_names, q_types, metatype, whereargs) do
Expand Down
16 changes: 14 additions & 2 deletions src/score/score.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,9 @@ macro average_energy(fformtype, lambda)
@capture(args, (inputs__, meta::metatype_) | (inputs__,)) || error("Error in macro. Lambda body arguments speicifcation is incorrect")

fuppertype = MacroHelpers.upper_type(fformtype)
whereargs = whereargs === nothing ? [] : whereargs
metatype = metatype === nothing ? :Nothing : metatype
fbottomtype = MacroHelpers.bottom_type(fformtype)
whereargs = whereargs === nothing ? [] : whereargs
metatype = metatype === nothing ? :Nothing : metatype

inputs = map(inputs) do input
@capture(input, iname_::itype_) || error("Error in macro. Input $(input) is incorrect")
Expand All @@ -91,6 +92,17 @@ macro average_energy(fformtype, lambda)

q_names, q_types, q_init_block = rule_macro_parse_fn_args(inputs; specname = :marginals, prefix = :q_, proxy = :Marginal)

# Some `@rules` are more complex in terms of functional form specification, e.g. `NormalMixture{N}`
if fbottomtype isa Symbol
# Not all nodes are defined with the `@node` macro, so we need to check if the node is defined with the `@node` macro
# `nodesymbol_to_nodefform` may return `nothing` for such nodes, in this case we skip the interface check
nodefform_from_symbol = ReactiveMP.nodesymbol_to_nodefform(Val(fbottomtype))
if !isnothing(nodefform_from_symbol)
ifaces = ReactiveMP.interfaces(nodefform_from_symbol)
MacroHelpers.check_rule_interfaces("@average_energy", fformtype, lambda, ifaces, nothing, nothing, q_names; mod = __module__)
end
end

result = quote
function ReactiveMP.score(::AverageEnergy, fform::$(fuppertype), marginals_names::$(q_names), marginals::$(q_types), meta::$(metatype)) where {$(whereargs...)}
$(q_init_block...)
Expand Down
82 changes: 82 additions & 0 deletions test/nodes/nodes_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -272,3 +272,85 @@ end
UnknownDistribution, interfaces, ((1, 2, 3),)
)
end

@testitem "new node defined with `@node` macro should define Symbol -> Node function mapping" begin
struct DummyNodeToTestSymbolToNodeFunctionMapping end

@node DummyNodeToTestSymbolToNodeFunctionMapping Stochastic [out, x, y, z]

@test ReactiveMP.nodesymbol_to_nodefform(Val(:DummyNodeToTestSymbolToNodeFunctionMapping)) == DummyNodeToTestSymbolToNodeFunctionMapping
end

@testitem "nodesymbol_to_nodefform returns nothing for an unknown node symbol" begin
@test ReactiveMP.nodesymbol_to_nodefform(Val(:UnknownNode)) === nothing
end

@testitem "`@node` macro should error if defined a rule for undefined interface" begin
struct DummyNodeToTestRuleForUndefinedInterface end

@node DummyNodeToTestRuleForUndefinedInterface Stochastic [out, x]

@test_throws "Interface mismatch for @rule DummyNodeToTestRuleForUndefinedInterface(:out, Marginalisation)" eval(
quote
@rule DummyNodeToTestRuleForUndefinedInterface(:out, Marginalisation) (m_y::PointMass,) = 0.0
end
)
@test_throws "Interface mismatch for @rule DummyNodeToTestRuleForUndefinedInterface(:out, Marginalisation)" eval(
quote
@rule DummyNodeToTestRuleForUndefinedInterface(:out, Marginalisation) (q_y::PointMass,) = 0.0
end
)
@test_throws "Interface mismatch for @rule DummyNodeToTestRuleForUndefinedInterface(:x, Marginalisation)" eval(
quote
@rule DummyNodeToTestRuleForUndefinedInterface(:x, Marginalisation) (m_y::PointMass,) = 0.0
end
)
@test_throws "Interface mismatch for @rule DummyNodeToTestRuleForUndefinedInterface(:x, Marginalisation)" eval(
quote
@rule DummyNodeToTestRuleForUndefinedInterface(:x, Marginalisation) (q_y::PointMass,) = 0.0
end
)

@test_throws "Interface mismatch for @marginalrule DummyNodeToTestRuleForUndefinedInterface(:out) (m_y::Any, m_x::Any)" eval(
quote
@marginalrule DummyNodeToTestRuleForUndefinedInterface(:out) (m_y::Any, m_x::Any) = 0.0
end
)

@test_throws "Interface mismatch for @average_energy DummyNodeToTestRuleForUndefinedInterface (q_y::Any, q_x::Any)" eval(
quote
@average_energy DummyNodeToTestRuleForUndefinedInterface (q_y::Any, q_x::Any) = 0.0
end
)

function dummynodetestruleforundefinedinterface end

@node typeof(dummynodetestruleforundefinedinterface) Stochastic [out, x]

@test_throws "Interface mismatch for @rule (typeof(dummynodetestruleforundefinedinterface))(:out, Marginalisation)" eval(
quote
@rule typeof(dummynodetestruleforundefinedinterface)(:out, Marginalisation) (m_y::PointMass,) = 0.0
end
)
@test_throws "Interface mismatch for @rule (typeof(dummynodetestruleforundefinedinterface))(:out, Marginalisation)" eval(
quote
@rule typeof(dummynodetestruleforundefinedinterface)(:out, Marginalisation) (q_y::PointMass,) = 0.0
end
)
@test_throws "Interface mismatch for @rule (typeof(dummynodetestruleforundefinedinterface))(:x, Marginalisation)" eval(
quote
@rule typeof(dummynodetestruleforundefinedinterface)(:x, Marginalisation) (m_y::PointMass,) = 0.0
end
)
@test_throws "Interface mismatch for @rule (typeof(dummynodetestruleforundefinedinterface))(:x, Marginalisation)" eval(
quote
@rule typeof(dummynodetestruleforundefinedinterface)(:x, Marginalisation) (q_y::PointMass,) = 0.0
end
)

@test_throws "Interface mismatch for @average_energy typeof(dummynodetestruleforundefinedinterface) (q_y::Any, q_x::Any)" eval(
quote
@average_energy typeof(dummynodetestruleforundefinedinterface) (q_y::Any, q_x::Any) = 0.0
end
)
end
Loading