Skip to content

Fix condition and fix in submodels #892

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 17 commits into from
Apr 23, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 69 additions & 29 deletions HISTORY.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,38 +4,25 @@

**Breaking changes**

### AD testing utilities

`DynamicPPL.TestUtils.AD.run_ad` now links the VarInfo by default.
To disable this, pass the `linked=false` keyword argument.
If the calculated value or gradient is incorrect, it also throws a `DynamicPPL.TestUtils.AD.ADIncorrectException` rather than a test failure.
This exception contains the actual and expected gradient so you can inspect it if needed; see the documentation for more information.
From a practical perspective, this means that if you need to add this to a test suite, you need to use `@test run_ad(...) isa Any` rather than just `run_ad(...)`.
### Submodels: conditioning

### SimpleVarInfo linking / invlinking

Linking a linked SimpleVarInfo, or invlinking an unlinked SimpleVarInfo, now displays a warning instead of an error.
Variables in a submodel can now be conditioned and fixed in a correct way.
See https://github.com/TuringLang/DynamicPPL.jl/issues/857 for a full illustration, but essentially it means you can now do this:

### VarInfo constructors

`VarInfo(vi::VarInfo, values)` has been removed. You can replace this directly with `unflatten(vi, values)` instead.

The `metadata` argument to `VarInfo([rng, ]model[, sampler, context, metadata])` has been removed.
If you were not using this argument (most likely), then there is no change needed.
If you were using the `metadata` argument to specify a blank `VarNamedVector`, then you should replace calls to `VarInfo` with `DynamicPPL.typed_vector_varinfo` instead (see 'Other changes' below).

The `UntypedVarInfo` constructor and type is no longer exported.
If you needed to construct one, you should now use `DynamicPPL.untyped_varinfo` instead.

The `TypedVarInfo` constructor and type is no longer exported.
The _type_ has been replaced with `DynamicPPL.NTVarInfo`.
The _constructor_ has been replaced with `DynamicPPL.typed_varinfo`.
```julia
@model function inner()
x ~ Normal()
return y ~ Normal()
end
@model function outer()
return a ~ to_submodel(inner() | (x=1.0,))
end
```

Note that the exact kind of VarInfo returned by `VarInfo(rng, model, ...)` is an implementation detail.
Previously, it was guaranteed that this would always be a VarInfo whose metadata was a `NamedTuple` containing `Metadata` structs.
Going forward, this is no longer the case, and you should only assume that the returned object obeys the `AbstractVarInfo` interface.
and the `a.x` variable will be correctly conditioned.
(Previously, you would have to condition `inner()` with the variable `a.x`, meaning that you would need to know what prefix to use before you had actually prefixed it.)

### VarName prefixing behaviour
### Submodel prefixing

The way in which VarNames in submodels are prefixed has been changed.
This is best explained through an example.
Expand Down Expand Up @@ -77,9 +64,62 @@ outer() | (@varname(var"a.x") => 1.0,)
outer() | (a.x=1.0,)
```

If you are sampling from a model with submodels, this doesn't affect the way you interact with the `MCMCChains.Chains` object, because VarNames are converted into Symbols when stored in the chain.
In a similar way, if the variable on the left-hand side of your tilde statement is not just a single identifier, any fields or indices it accesses are now properly respected.
Consider the following setup:

```julia
using DynamicPPL, Distributions
@model inner() = x ~ Normal()
@model function outer()
a = Vector{Float64}(undef, 1)
a[1] ~ to_submodel(inner())
return a
end
```

In this case, the variable sampled is actually the `x` field of the first element of `a`:

```julia
julia> only(keys(VarInfo(outer()))) == @varname(a[1].x)
true
```

Before this version, it used to be a single variable called `var"a[1].x"`.

Note that if you are sampling from a model with submodels, this doesn't affect the way you interact with the `MCMCChains.Chains` object, because VarNames are converted into Symbols when stored in the chain.
(This behaviour will likely be changed in the future, in that Chains should be indexable by VarNames and not just Symbols, but that has not been implemented yet.)

### AD testing utilities

`DynamicPPL.TestUtils.AD.run_ad` now links the VarInfo by default.
To disable this, pass the `linked=false` keyword argument.
If the calculated value or gradient is incorrect, it also throws a `DynamicPPL.TestUtils.AD.ADIncorrectException` rather than a test failure.
This exception contains the actual and expected gradient so you can inspect it if needed; see the documentation for more information.
From a practical perspective, this means that if you need to add this to a test suite, you need to use `@test run_ad(...) isa Any` rather than just `run_ad(...)`.

### SimpleVarInfo linking / invlinking

Linking a linked SimpleVarInfo, or invlinking an unlinked SimpleVarInfo, now displays a warning instead of an error.

### VarInfo constructors

`VarInfo(vi::VarInfo, values)` has been removed. You can replace this directly with `unflatten(vi, values)` instead.

The `metadata` argument to `VarInfo([rng, ]model[, sampler, context, metadata])` has been removed.
If you were not using this argument (most likely), then there is no change needed.
If you were using the `metadata` argument to specify a blank `VarNamedVector`, then you should replace calls to `VarInfo` with `DynamicPPL.typed_vector_varinfo` instead (see 'Other changes' below).

The `UntypedVarInfo` constructor and type is no longer exported.
If you needed to construct one, you should now use `DynamicPPL.untyped_varinfo` instead.

The `TypedVarInfo` constructor and type is no longer exported.
The _type_ has been replaced with `DynamicPPL.NTVarInfo`.
The _constructor_ has been replaced with `DynamicPPL.typed_varinfo`.

Note that the exact kind of VarInfo returned by `VarInfo(rng, model, ...)` is an implementation detail.
Previously, it was guaranteed that this would always be a VarInfo whose metadata was a `NamedTuple` containing `Metadata` structs.
Going forward, this is no longer the case, and you should only assume that the returned object obeys the `AbstractVarInfo` interface.

**Other changes**

While these are technically breaking, they are only internal changes and do not affect the public API.
Expand Down
1 change: 1 addition & 0 deletions docs/Project.toml
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
[deps]
AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf"
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Expand Down
4 changes: 3 additions & 1 deletion docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@ makedocs(;
format=Documenter.HTML(; size_threshold=2^10 * 400),
modules=[DynamicPPL, Base.get_extension(DynamicPPL, :DynamicPPLMCMCChainsExt)],
pages=[
"Home" => "index.md", "API" => "api.md", "Internals" => ["internals/varinfo.md"]
"Home" => "index.md",
"API" => "api.md",
"Internals" => ["internals/varinfo.md", "internals/submodel_condition.md"],
],
checkdocs=:exports,
doctest=false,
Expand Down
14 changes: 7 additions & 7 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,9 @@ decondition

## Fixing and unfixing

We can also _fix_ a collection of variables in a [`Model`](@ref) to certain using [`fix`](@ref).
We can also _fix_ a collection of variables in a [`Model`](@ref) to certain values using [`DynamicPPL.fix`](@ref).

This might seem quite similar to the aforementioned [`condition`](@ref) and its siblings,
This is quite similar to the aforementioned [`condition`](@ref) and its siblings,
but they are indeed different operations:

- `condition`ed variables are considered to be _observations_, and are thus
Expand All @@ -89,19 +89,19 @@ but they are indeed different operations:
- `fix`ed variables are considered to be _constant_, and are thus not included
in any log-probability computations.

The differences are more clearly spelled out in the docstring of [`fix`](@ref) below.
The differences are more clearly spelled out in the docstring of [`DynamicPPL.fix`](@ref) below.

```@docs
fix
DynamicPPL.fix
DynamicPPL.fixed
```

The difference between [`fix`](@ref) and [`condition`](@ref) is described in the docstring of [`fix`](@ref) above.
The difference between [`DynamicPPL.fix`](@ref) and [`DynamicPPL.condition`](@ref) is described in the docstring of [`DynamicPPL.fix`](@ref) above.

Similarly, we can [`unfix`](@ref) variables, i.e. return them to their original meaning:
Similarly, we can revert this with [`DynamicPPL.unfix`](@ref), i.e. return the variables to their original meaning:

```@docs
unfix
DynamicPPL.unfix
```

## Predicting
Expand Down
Loading
Loading