Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 34 additions & 4 deletions src/nodes/nodes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ end
# `PredefinedNodeFunctionalForm` are generally the nodes that are defined with the `@node` macro
# The `UndefinedNodeFunctionalForm` nodes can be created as well, but only if the `fform` is a `Function` (see `predefined/delta.jl`)
function factornode(::PredefinedNodeFunctionalForm, fform::F, interfaces::I, factorization) where {F, I}
processed_interfaces = __prepare_interfaces_generic(fform, interfaces)
processed_interfaces = prepare_interfaces_generic(fform, interfaces)
localclusters = FactorNodeLocalClusters(processed_interfaces, collect_factorisation(fform, factorization))
return FactorNode(fform, processed_interfaces, localclusters)
end
Expand All @@ -197,14 +197,44 @@ interfaceindex(factornode::FactorNode, iname::Symbol) =
interfaceindices(factornode::FactorNode, iname::Symbol) = (interfaceindex(factornode, iname),)
interfaceindices(factornode::FactorNode, inames::NTuple{N, Symbol}) where {N} = map(iname -> interfaceindex(factornode, iname), inames)

# Takes a named tuple of abstract variables and converts to a tuple of NodeInterfaces with the same order
function __prepare_interfaces_generic(fform::F, interfaces::AbstractVector) where {F}
function prepare_interfaces_generic(fform::F, interfaces::AbstractVector) where {F}
prepare_interfaces_check_nonempty(fform, interfaces)
prepare_interfaces_check_adjacent_duplicates(fform, interfaces)
prepare_interfaces_check_numarguments(fform, interfaces)
return map(enumerate(interfaces)) do (index, (name, variable))
return NodeInterface(alias_interface(fform, index, name), variable)
end
end

## activate!
function prepare_interfaces_check_nonempty(fform, interfaces)
length(interfaces) > 0 || error(lazy"At least one argument is required for a factor node. Got none for `$(fform)`")
end

function prepare_interfaces_check_adjacent_duplicates(fform, interfaces)
# Here we create an iterator that checks ONLY adjacent interfaces
# The reason here is that we don't want to check all possible combinations of all input interfaces
# because that would require allocating an intermediate storage for `Set`, which would harm the
# performance of nodes creation. The `zip(interfaces, Iterators.drop(interfaces, 1))` creates a generic
# iterator of adjacent interface pairs
foreach(zip(interfaces, Iterators.drop(interfaces, 1))) do (left, right)
lname, _ = left
rname, _ = right
if isequal(lname, rname)
error(
lazy"`$fform` has duplicate entry for interface `$lname`. Did you pass an array (e.g. `x`) instead of an array element (e.g. `x[i]`)? Check your variable indices."
)
end
end
end

function prepare_interfaces_check_numarguments(fform::F, interfaces) where {F}
prepare_interfaces_check_num_inputarguments(fform, inputinterfaces(fform), interfaces)
end

function prepare_interfaces_check_num_inputarguments(fform, inputinterfaces::Val{Input}, interfaces) where {Input}
(length(interfaces) - 1) === length(Input) ||
error(lazy"Expected $(length(Input)) input arguments for `$(fform)`, got $(length(interfaces) - 1): $(join(map(first, Iterators.drop(interfaces, 1)), \", \"))")
end

struct FactorNodeActivationOptions{M, D, P, A, S}
metadata::M
Expand Down
64 changes: 64 additions & 0 deletions test/nodes/nodes_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -156,3 +156,67 @@ end
@test occursin(r"DummyNodeForDocumentationStochastic.*Stochastic.*out, x, y \(or yy\)", documentation)
@test occursin(r"DummyNodeForDocumentationDeterministic.*Deterministic.*out, x \(or xx, xxx\), y", documentation)
end

@testitem "Predefined nodes should check the arguments supplied" begin
struct StochasticNodeWithThreeArguments end
struct DeterministicNodeWithFourArguments end

@node StochasticNodeWithThreeArguments Stochastic [out, x, y, z]
@node DeterministicNodeWithFourArguments Deterministic [out, x, y, z, w]

out = randomvar()
x = randomvar()
y = randomvar()
z = randomvar()
w = randomvar()

@test factornode(StochasticNodeWithThreeArguments, [(:out, out), (:x, x), (:y, y), (:z, z)], ((1, 2, 3),)) isa ReactiveMP.FactorNode
@test factornode(DeterministicNodeWithFourArguments, [(:out, out), (:x, x), (:y, y), (:z, z), (:w, w)], ((1, 2, 3, 4),)) isa ReactiveMP.FactorNode

@test_throws r"At least one argument is required for a factor node. Got none for `.*StochasticNodeWithThreeArguments`" factornode(StochasticNodeWithThreeArguments, [], ())
@test_throws r"At least one argument is required for a factor node. Got none for `.*DeterministicNodeWithFourArguments`" factornode(DeterministicNodeWithFourArguments, [], ())
@test_throws r"Expected 3 input arguments for `.*StochasticNodeWithThreeArguments`, got 1: x" factornode(StochasticNodeWithThreeArguments, [(:out, out), (:x, x)], ((1,),))
@test_throws r"Expected 3 input arguments for `.*StochasticNodeWithThreeArguments`, got 2: x, y" factornode(
StochasticNodeWithThreeArguments, [(:out, out), (:x, x), (:y, y)], ((1, 2),)
)
@test_throws r"Expected 3 input arguments for `.*StochasticNodeWithThreeArguments`, got 4: x, y, z, w" factornode(
StochasticNodeWithThreeArguments, [(:out, out), (:x, x), (:y, y), (:z, z), (:w, w)], ((1, 2, 3, 4),)
)
@test_throws r"Expected 4 input arguments for `.*DeterministicNodeWithFourArguments`, got 1: x" factornode(DeterministicNodeWithFourArguments, [(:out, out), (:x, x)], ((1,),))
@test_throws r"Expected 4 input arguments for `.*DeterministicNodeWithFourArguments`, got 2: x, y" factornode(
DeterministicNodeWithFourArguments, [(:out, out), (:x, x), (:y, y)], ((1, 2),)
)
@test_throws r"Expected 4 input arguments for `.*DeterministicNodeWithFourArguments`, got 3: x, y, z" factornode(
DeterministicNodeWithFourArguments, [(:out, out), (:x, x), (:y, y), (:z, z)], ((1, 2, 3),)
)

@test_throws r"`.*StochasticNodeWithThreeArguments` has duplicate entry for interface `w`. Did you pass an array \(e.g. `x`\) instead of an array element \(e\.g\. `x\[i\]`\)\? Check your variable indices\." factornode(
StochasticNodeWithThreeArguments, [(:out, out), (:x, x), (:w, w), (:w, w)], ((1, 2, 3, 4),)
)
@test_throws r"`.*StochasticNodeWithThreeArguments` has duplicate entry for interface `w`. Did you pass an array \(e.g. `x`\) instead of an array element \(e\.g\. `x\[i\]`\)\? Check your variable indices\." factornode(
StochasticNodeWithThreeArguments, [(:out, out), (:x, x), (:y, y), (:z, z), (:w, w), (:w, w)], ((1, 2, 3, 4, 5, 6),)
)
end

@testitem "Generic node construction checks should not allocate" begin
import ReactiveMP: prepare_interfaces_check_adjacent_duplicates, prepare_interfaces_check_nonempty, prepare_interfaces_check_numarguments

struct NodeForCheckDuplicatesTest end
@node NodeForCheckDuplicatesTest Stochastic [out, x, y, z]

out = randomvar()
x = randomvar()
y = randomvar()
z = randomvar()

interfaces = [(:out, out), (:x, x), (:y, y), (:z, z)]
# compile first
function foo(interfaces)
prepare_interfaces_check_nonempty(NodeForCheckDuplicatesTest, interfaces)
prepare_interfaces_check_adjacent_duplicates(NodeForCheckDuplicatesTest, interfaces)
prepare_interfaces_check_numarguments(NodeForCheckDuplicatesTest, interfaces)
end
foo(interfaces)
@test (@allocated(foo(interfaces)) == 0)
@test (@allocations(foo(interfaces)) == 0)
end