@@ -22,18 +22,38 @@ where
2222
2323import Control.Monad.Bayes.Class (MonadDistribution , MonadMeasure )
2424import Control.Monad.Bayes.Population
25- ( PopulationT ,
25+ ( PopulationT (.. ),
26+ flatten ,
2627 pushEvidence ,
28+ single ,
2729 withParticles ,
2830 )
31+ import Control.Monad.Bayes.Population.Applicative qualified as Applicative
2932import Control.Monad.Bayes.Sequential.Coroutine as Coroutine
33+ import Control.Monad.Bayes.Sequential.Coroutine qualified as SequentialT
34+ import Control.Monad.Bayes.Weighted (WeightedT (.. ), weightedT )
35+ import Control.Monad.Coroutine
36+ import Control.Monad.Trans.Free (FreeF (.. ), FreeT (.. ))
3037
3138data SMCConfig m = SMCConfig
3239 { resampler :: forall x . PopulationT m x -> PopulationT m x ,
3340 numSteps :: Int ,
3441 numParticles :: Int
3542 }
3643
44+ sequentialToPopulation :: (Monad m ) => Coroutine. SequentialT (Applicative. PopulationT m ) a -> PopulationT m a
45+ sequentialToPopulation =
46+ PopulationT
47+ . weightedT
48+ . coroutineToFree
49+ . Coroutine. runSequentialT
50+ where
51+ coroutineToFree =
52+ FreeT
53+ . fmap (Free . fmap (\ (cont, p) -> either (coroutineToFree . extract) (pure . (,p)) cont))
54+ . Applicative. runPopulationT
55+ . resume
56+
3757-- | Sequential importance resampling.
3858-- Basically an SMC template that takes a custom resampler.
3959smc ::
@@ -42,12 +62,15 @@ smc ::
4262 Coroutine. SequentialT (PopulationT m ) a ->
4363 PopulationT m a
4464smc SMCConfig {.. } =
45- Coroutine. sequentially resampler numSteps
65+ (single . flatten)
66+ . Coroutine. sequentially resampler numSteps
67+ . SequentialT. hoist (single . flatten)
4668 . Coroutine. hoistFirst (withParticles numParticles)
69+ . SequentialT. hoist (single . flatten)
4770
4871-- | Sequential Monte Carlo with multinomial resampling at each timestep.
4972-- Weights are normalized at each timestep and the total weight is pushed
5073-- as a score into the transformed monad.
5174smcPush ::
5275 (MonadMeasure m ) => SMCConfig m -> Coroutine. SequentialT (PopulationT m ) a -> PopulationT m a
53- smcPush config = smc config {resampler = (pushEvidence . resampler config)}
76+ smcPush config = smc config {resampler = (single . flatten . pushEvidence . resampler config)}
0 commit comments