2828# The default implementation via `@generated` function fails to infer this
2929exponential_family_typetag (:: Categorical ) = Categorical
3030
31- isproper (:: NaturalParametersSpace , :: Type{Categorical} , η, conditioner) = isinteger (conditioner) && (conditioner === length (η)) && (length (η) >= 2 ) && (η[end ] ≈ 0 )
31+ isproper (:: NaturalParametersSpace , :: Type{Categorical} , η, conditioner) =
32+ isinteger (conditioner) && (conditioner === length (η)) && (length (η) >= 2 ) && (η[end ] ≈ 0 )
3233isproper (:: MeanParametersSpace , :: Type{Categorical} , θ, conditioner) =
3334 isinteger (conditioner) && (conditioner === length (θ)) && (length (θ) >= 2 ) && all (> (0 ), θ) && isapprox (sum (θ), 1 )
3435
@@ -51,14 +52,14 @@ function (::MeanToNatural{Categorical})(tuple_of_θ::Tuple{Any}, _)
5152 return (LoopVectorization. vmap (pᵢ -> log (pᵢ / pₖ), p),)
5253end
5354
54- function (:: NaturalToMean{Categorical} )(tuple_of_η:: Tuple{V} , _) where { V <: Vector }
55+ function (:: NaturalToMean{Categorical} )(tuple_of_η:: Tuple{V} , _) where {V <: Vector }
5556 (η,) = tuple_of_η
5657 return (softmax (η),)
5758end
5859
5960# We use `Categorical` from `Distributions.jl` for the `MeanParametersSpace`
6061# and their implementation supports only `Vector`s
61- function (:: NaturalToMean{Categorical} )(tuple_of_η:: Tuple{V} , _) where { V <: AbstractVector }
62+ function (:: NaturalToMean{Categorical} )(tuple_of_η:: Tuple{V} , _) where {V <: AbstractVector }
6263 (η,) = tuple_of_η
6364 return (softmax (convert (Vector, η)),)
6465end
@@ -87,7 +88,7 @@ getlogpartition(::NaturalParametersSpace, ::Type{Categorical}, conditioner) =
8788 return logsumexp (η)
8889 end
8990
90- getgradlogpartition (:: NaturalParametersSpace , :: Type{Categorical} , conditioner) =
91+ getgradlogpartition (:: NaturalParametersSpace , :: Type{Categorical} , conditioner) =
9192 (η) -> begin
9293 if (conditioner != = length (η))
9394 throw (
@@ -97,7 +98,7 @@ getgradlogpartition(::NaturalParametersSpace, ::Type{Categorical}, conditioner)
9798 )
9899 end
99100 sumη = vmapreduce (exp, + , η)
100- return vmap (d-> exp (d)/ sumη , η)
101+ return vmap (d -> exp (d) / sumη, η)
101102 end
102103
103104getfisherinformation (:: NaturalParametersSpace , :: Type{Categorical} , conditioner) =
0 commit comments