Skip to content
Merged
51 changes: 34 additions & 17 deletions stan/math/prim/fun/inv_logit.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,15 +49,14 @@ namespace math {
* @return Inverse logit of argument.
*/
inline double inv_logit(double a) {
using std::exp;
if (a < 0) {
double exp_a = exp(a);
double exp_a = std::exp(a);
if (a < LOG_EPSILON) {
return exp_a;
}
return exp_a / (1 + exp_a);
return exp_a / (1.0 + exp_a);
}
return inv(1 + exp(-a));
return inv(1 + std::exp(-a));
}

/**
Expand All @@ -69,28 +68,46 @@ inline double inv_logit(double a) {
*/
struct inv_logit_fun {
template <typename T>
static inline auto fun(const T& x) {
return inv_logit(x);
static inline auto fun(T&& x) {
return inv_logit(std::forward<T>(x));
}
};

/**
* Vectorized version of inv_logit().
* Vectorized version of inv_logit() for containers containing ad types.
*
* @tparam T type of container
* @param x container
* @tparam T type of std::vector
* @param x std::vector
* @return Inverse logit applied to each value in x.
*/
template <
typename T, require_not_var_matrix_t<T>* = nullptr,
require_all_not_nonscalar_prim_or_rev_kernel_expression_t<T>* = nullptr>
inline auto inv_logit(const T& x) {
return apply_scalar_unary<inv_logit_fun, T>::apply(x);
template <typename Container, require_ad_container_t<Container>* = nullptr,
require_all_not_nonscalar_prim_or_rev_kernel_expression_t<
Container>* = nullptr,
require_not_rev_matrix_t<Container>* = nullptr>
inline auto inv_logit(Container&& x) {
return apply_scalar_unary<inv_logit_fun, Container>::apply(
std::forward<Container>(x));
}

// TODO(Tadej): Eigen is introducing their implementation logistic() of this
// in 3.4. Use that once we switch to Eigen 3.4

/**
* Vectorized version of inv_logit() for containers with arithmetic scalar
* types.
*
* @tparam T A type of either `std::vector` or a type that directly inherits
* from `Eigen::DenseBase`. The inner scalar type must not have an auto diff
* scalar type.
* @param x Eigen expression
* @return Inverse logit applied to each value in x.
*/
template <typename Container,
require_container_bt<std::is_arithmetic, Container>* = nullptr,
require_all_not_nonscalar_prim_or_rev_kernel_expression_t<
Container>* = nullptr>
inline auto inv_logit(Container&& x) {
return apply_vector_unary<Container>::apply(
std::forward<Container>(x),
[](const auto& v) { return v.array().logistic(); });
}
} // namespace math
} // namespace stan

Expand Down
22 changes: 12 additions & 10 deletions stan/math/prim/functor/apply_scalar_unary.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ struct apply_scalar_unary<F, T, require_eigen_t<T>> {
* @return Componentwise application of the function specified
* by F to the specified matrix.
*/
static inline auto apply(const T& x) {
static inline auto apply(const std::decay_t<T>& x) {
return x.unaryExpr([](auto&& x) {
return apply_scalar_unary<F, std::decay_t<decltype(x)>>::apply(x);
});
Expand All @@ -69,7 +69,7 @@ struct apply_scalar_unary<F, T, require_eigen_t<T>> {
* expression template of type T.
*/
using return_t = std::decay_t<decltype(
apply_scalar_unary<F, T>::apply(std::declval<T>()))>;
apply_scalar_unary<F, std::decay_t<T>>::apply(std::declval<T>()))>;
};

/**
Expand All @@ -83,7 +83,8 @@ struct apply_scalar_unary<F, T, require_floating_point_t<T>> {
/**
* The return type, double.
*/
using return_t = std::decay_t<decltype(F::fun(std::declval<T>()))>;
using return_t
= std::decay_t<decltype(F::fun(std::declval<std::decay_t<T>>()))>;

/**
* Apply the function specified by F to the specified argument.
Expand Down Expand Up @@ -114,11 +115,12 @@ struct apply_scalar_unary<F, T, require_complex_t<T>> {
* @param x Argument scalar.
* @return Result of applying F to the scalar.
*/
static inline auto apply(const T& x) { return F::fun(x); }
static inline auto apply(const std::decay_t<T>& x) { return F::fun(x); }
/**
* The return type
*/
using return_t = std::decay_t<decltype(F::fun(std::declval<T>()))>;
using return_t
= std::decay_t<decltype(F::fun(std::declval<std::decay_t<T>>()))>;
};

/**
Expand Down Expand Up @@ -157,13 +159,13 @@ struct apply_scalar_unary<F, T, require_integral_t<T>> {
* @tparam T Type of element contained in standard vector.
*/
template <typename F, typename T>
struct apply_scalar_unary<F, std::vector<T>> {
struct apply_scalar_unary<F, T, require_std_vector_t<T>> {
/**
* Return type, which is calculated recursively as a standard
* vector of the return type of the contained type T.
*/
using return_t = typename std::vector<
plain_type_t<typename apply_scalar_unary<F, T>::return_t>>;
using return_t = typename std::vector<plain_type_t<
typename apply_scalar_unary<F, value_type_t<std::decay_t<T>>>::return_t>>;

/**
* Apply the function specified by F elementwise to the
Expand All @@ -174,10 +176,10 @@ struct apply_scalar_unary<F, std::vector<T>> {
* @return Elementwise application of F to the elements of the
* container.
*/
static inline auto apply(const std::vector<T>& x) {
static inline auto apply(const std::decay_t<T>& x) {
return_t fx(x.size());
for (size_t i = 0; i < x.size(); ++i) {
fx[i] = apply_scalar_unary<F, T>::apply(x[i]);
fx[i] = apply_scalar_unary<F, value_type_t<T>>::apply(x[i]);
}
return fx;
}
Expand Down
25 changes: 25 additions & 0 deletions stan/math/rev/fun/inv_logit.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,31 @@ inline auto inv_logit(const var_value<T>& a) {
});
}

/**
* The inverse logit function for Eigen expressions with var value type.
*
* See inv_logit() for the double-based version.
*
* The derivative of inverse logit is
*
* \f$\frac{d}{dx} \mbox{logit}^{-1}(x) = \mbox{logit}^{-1}(x) (1 -
* \mbox{logit}^{-1}(x))\f$.
*
* @tparam T type of Eigen expression
* @param x Eigen expression
* @return Inverse logit of argument.
*/
template <typename T, require_eigen_vt<is_var, T>* = nullptr>
inline auto inv_logit(T&& x) {
auto x_arena = to_arena(std::forward<T>(x));
arena_t<T> ret = inv_logit(x_arena.val());
reverse_pass_callback([x_arena, ret]() mutable {
x_arena.adj().array()
+= ret.adj().array() * ret.val().array() * (1.0 - ret.val().array());
});
return ret;
}

} // namespace math
} // namespace stan
#endif