Skip to content

Commit 59c0014

Browse files
authored
Merge pull request #3165 from stan-dev/hyper1F0-naming
Minor fixes for hypergeometric 1F0 and 2F1
2 parents cf3d0dd + 72ba948 commit 59c0014

16 files changed

+88
-69
lines changed

stan/math/fwd/fun/hypergeometric_1F0.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,10 @@ namespace math {
3131
template <typename Ta, typename Tz, typename FvarT = return_type_t<Ta, Tz>,
3232
require_all_stan_scalar_t<Ta, Tz>* = nullptr,
3333
require_any_fvar_t<Ta, Tz>* = nullptr>
34-
FvarT hypergeometric_1f0(const Ta& a, const Tz& z) {
34+
FvarT hypergeometric_1F0(const Ta& a, const Tz& z) {
3535
partials_type_t<Ta> a_val = value_of(a);
3636
partials_type_t<Tz> z_val = value_of(z);
37-
FvarT rtn = FvarT(hypergeometric_1f0(a_val, z_val), 0.0);
37+
FvarT rtn = FvarT(hypergeometric_1F0(a_val, z_val), 0.0);
3838
if (!is_constant_all<Ta>::value) {
3939
rtn.d_ += forward_as<FvarT>(a).d() * -rtn.val() * log1m(z_val);
4040
}

stan/math/fwd/fun/hypergeometric_2F1.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,11 @@ namespace math {
3030
template <typename Ta1, typename Ta2, typename Tb, typename Tz,
3131
require_all_stan_scalar_t<Ta1, Ta2, Tb, Tz>* = nullptr,
3232
require_any_fvar_t<Ta1, Ta2, Tb, Tz>* = nullptr>
33-
inline return_type_t<Ta1, Ta1, Tb, Tz> hypergeometric_2F1(const Ta1& a1,
33+
inline return_type_t<Ta1, Ta2, Tb, Tz> hypergeometric_2F1(const Ta1& a1,
3434
const Ta2& a2,
3535
const Tb& b,
3636
const Tz& z) {
37-
using fvar_t = return_type_t<Ta1, Ta1, Tb, Tz>;
37+
using fvar_t = return_type_t<Ta1, Ta2, Tb, Tz>;
3838

3939
auto a1_val = value_of(a1);
4040
auto a2_val = value_of(a2);

stan/math/fwd/fun/hypergeometric_pFq.hpp

Lines changed: 13 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
#include <stan/math/prim/fun/dot_product.hpp>
88
#include <stan/math/prim/fun/grad_pFq.hpp>
99
#include <stan/math/prim/fun/hypergeometric_pFq.hpp>
10+
#include <stan/math/prim/fun/as_column_vector_or_scalar.hpp>
11+
#include <stan/math/prim/fun/to_ref.hpp>
1012

1113
namespace stan {
1214
namespace math {
@@ -30,33 +32,27 @@ template <typename Ta, typename Tb, typename Tz,
3032
bool grad_z = !is_constant<Tz>::value,
3133
require_all_vector_t<Ta, Tb>* = nullptr,
3234
require_fvar_t<FvarT>* = nullptr>
33-
inline FvarT hypergeometric_pFq(const Ta& a, const Tb& b, const Tz& z) {
34-
using PartialsT = partials_type_t<FvarT>;
35-
using ARefT = ref_type_t<Ta>;
36-
using BRefT = ref_type_t<Tb>;
37-
38-
ARefT a_ref = a;
39-
BRefT b_ref = b;
35+
inline FvarT hypergeometric_pFq(Ta&& a, Tb&& b, Tz&& z) {
36+
auto&& a_ref = to_ref(as_column_vector_or_scalar(a));
37+
auto&& b_ref = to_ref(as_column_vector_or_scalar(b));
4038
auto&& a_val = value_of(a_ref);
4139
auto&& b_val = value_of(b_ref);
4240
auto&& z_val = value_of(z);
43-
PartialsT pfq_val = hypergeometric_pFq(a_val, b_val, z_val);
41+
42+
partials_type_t<FvarT> pfq_val = hypergeometric_pFq(a_val, b_val, z_val);
4443
auto grad_tuple
4544
= grad_pFq<grad_a, grad_b, grad_z>(pfq_val, a_val, b_val, z_val);
4645

4746
FvarT rtn = FvarT(pfq_val, 0.0);
4847

49-
if (grad_a) {
50-
rtn.d_ += dot_product(forward_as<promote_scalar_t<FvarT, ARefT>>(a_ref).d(),
51-
std::get<0>(grad_tuple));
48+
if constexpr (grad_a) {
49+
rtn.d_ += dot_product(a_ref.d(), std::get<0>(grad_tuple));
5250
}
53-
if (grad_b) {
54-
rtn.d_ += dot_product(forward_as<promote_scalar_t<FvarT, BRefT>>(b_ref).d(),
55-
std::get<1>(grad_tuple));
51+
if constexpr (grad_b) {
52+
rtn.d_ += dot_product(b_ref.d(), std::get<1>(grad_tuple));
5653
}
57-
if (grad_z) {
58-
rtn.d_ += forward_as<promote_scalar_t<FvarT, Tz>>(z).d_
59-
* std::get<2>(grad_tuple);
54+
if constexpr (grad_z) {
55+
rtn.d_ += z.d_ * std::get<2>(grad_tuple);
6056
}
6157

6258
return rtn;

stan/math/prim/fun/hypergeometric_1F0.hpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,8 @@ namespace math {
2828
* @return Hypergeometric 1F0 function
2929
*/
3030
template <typename Ta, typename Tz, require_all_arithmetic_t<Ta, Tz>* = nullptr>
31-
return_type_t<Ta, Tz> hypergeometric_1f0(const Ta& a, const Tz& z) {
32-
constexpr const char* function = "hypergeometric_1f0";
33-
check_less("hypergeometric_1f0", "abs(z)", std::fabs(z), 1.0);
31+
return_type_t<Ta, Tz> hypergeometric_1F0(const Ta& a, const Tz& z) {
32+
check_less("hypergeometric_1F0", "abs(z)", std::fabs(z), 1.0);
3433

3534
return boost::math::hypergeometric_1F0(a, z, boost_policy_t<>());
3635
}

stan/math/prim/fun/hypergeometric_2F1.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ namespace internal {
4343
* @return Gauss hypergeometric function
4444
*/
4545
template <typename Ta1, typename Ta2, typename Tb, typename Tz,
46-
typename RtnT = boost::optional<return_type_t<Ta1, Ta1, Tb, Tz>>,
46+
typename RtnT = boost::optional<return_type_t<Ta1, Ta2, Tb, Tz>>,
4747
require_all_arithmetic_t<Ta1, Ta2, Tb, Tz>* = nullptr>
4848
inline RtnT hyper_2F1_special_cases(const Ta1& a1, const Ta2& a2, const Tb& b,
4949
const Tz& z) {
@@ -148,10 +148,10 @@ inline RtnT hyper_2F1_special_cases(const Ta1& a1, const Ta2& a2, const Tb& b,
148148
* @return Gauss hypergeometric function
149149
*/
150150
template <typename Ta1, typename Ta2, typename Tb, typename Tz,
151-
typename ScalarT = return_type_t<Ta1, Ta1, Tb, Tz>,
151+
typename ScalarT = return_type_t<Ta1, Ta2, Tb, Tz>,
152152
typename OptT = boost::optional<ScalarT>,
153153
require_all_arithmetic_t<Ta1, Ta2, Tb, Tz>* = nullptr>
154-
inline return_type_t<Ta1, Ta1, Tb, Tz> hypergeometric_2F1(const Ta1& a1,
154+
inline return_type_t<Ta1, Ta2, Tb, Tz> hypergeometric_2F1(const Ta1& a1,
155155
const Ta2& a2,
156156
const Tb& b,
157157
const Tz& z) {

stan/math/prim/fun/hypergeometric_3F2.hpp

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -114,13 +114,20 @@ template <typename Ta, typename Tb, typename Tz,
114114
require_all_vector_t<Ta, Tb>* = nullptr,
115115
require_stan_scalar_t<Tz>* = nullptr>
116116
inline auto hypergeometric_3F2(const Ta& a, const Tb& b, const Tz& z) {
117-
check_3F2_converges("hypergeometric_3F2", a[0], a[1], a[2], b[0], b[1], z);
117+
check_size_match("hypergeometric_3F2", "a", a.size(), "3", 3);
118+
check_size_match("hypergeometric_3F2", "b", b.size(), "2", 2);
119+
120+
auto a_ref = to_vector(a);
121+
auto b_ref = to_vector(b);
122+
123+
check_3F2_converges("hypergeometric_3F2", a_ref[0], a_ref[1], a_ref[2],
124+
b_ref[0], b_ref[1], z);
118125
// Boost's pFq throws convergence errors in some cases, fallback to naive
119126
// infinite-sum approach (tests pass for these)
120-
if (z == 1.0 && (sum(b) - sum(a)) < 0.0) {
121-
return internal::hypergeometric_3F2_infsum(a, b, z);
127+
if (z == 1.0 && (sum(b_ref) - sum(a_ref)) < 0.0) {
128+
return internal::hypergeometric_3F2_infsum(a_ref, b_ref, z);
122129
}
123-
return hypergeometric_pFq(to_vector(a), to_vector(b), z);
130+
return hypergeometric_pFq(a_ref, b_ref, z);
124131
}
125132

126133
/**

stan/math/prim/fun/hypergeometric_pFq.hpp

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include <stan/math/prim/meta.hpp>
55
#include <stan/math/prim/err/check_not_nan.hpp>
66
#include <stan/math/prim/err/check_finite.hpp>
7+
#include <stan/math/prim/fun/to_row_vector.hpp>
78
#include <boost/math/special_functions/hypergeometric_pFq.hpp>
89

910
namespace stan {
@@ -14,10 +15,6 @@ namespace math {
1415
* input arguments:
1516
* \f$_pF_q(a_1,...,a_p;b_1,...,b_q;z)\f$
1617
*
17-
* This function is not intended to be exposed to end users, only
18-
* used for p & q values that are stable with the grad_pFq
19-
* implementation.
20-
*
2118
* See 'grad_pFq.hpp' for the derivatives wrt each parameter
2219
*
2320
* @param[in] a Vector of 'a' arguments to function
@@ -26,7 +23,7 @@ namespace math {
2623
* @return Generalized hypergeometric function
2724
*/
2825
template <typename Ta, typename Tb, typename Tz,
29-
require_all_eigen_st<std::is_arithmetic, Ta, Tb>* = nullptr,
26+
require_all_vector_st<std::is_arithmetic, Ta, Tb>* = nullptr,
3027
require_arithmetic_t<Tz>* = nullptr>
3128
return_type_t<Ta, Tb, Tz> hypergeometric_pFq(const Ta& a, const Tb& b,
3229
const Tz& z) {
@@ -47,8 +44,9 @@ return_type_t<Ta, Tb, Tz> hypergeometric_pFq(const Ta& a, const Tb& b,
4744
std::stringstream msg;
4845
msg << "hypergeometric function pFq does not meet convergence "
4946
<< "conditions with given arguments. "
50-
<< "a: " << a_ref << ", b: " << b_ref << ", "
51-
<< ", z: " << z;
47+
<< "a: " << to_row_vector(a_ref) << ", "
48+
<< "b: " << to_row_vector(b_ref) << ", "
49+
<< "z: " << z;
5250
throw std::domain_error(msg.str());
5351
}
5452

stan/math/rev/fun/hypergeometric_1F0.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,10 @@ namespace math {
3131
template <typename Ta, typename Tz,
3232
require_all_stan_scalar_t<Ta, Tz>* = nullptr,
3333
require_any_var_t<Ta, Tz>* = nullptr>
34-
var hypergeometric_1f0(const Ta& a, const Tz& z) {
34+
var hypergeometric_1F0(const Ta& a, const Tz& z) {
3535
double a_val = value_of(a);
3636
double z_val = value_of(z);
37-
double rtn = hypergeometric_1f0(a_val, z_val);
37+
double rtn = hypergeometric_1F0(a_val, z_val);
3838
return make_callback_var(rtn, [rtn, a, z, a_val, z_val](auto& vi) mutable {
3939
if (!is_constant_all<Ta>::value) {
4040
forward_as<var>(a).adj() += vi.adj() * -rtn * log1m(z_val);

stan/math/rev/fun/hypergeometric_2F1.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ namespace math {
2929
template <typename Ta1, typename Ta2, typename Tb, typename Tz,
3030
require_all_stan_scalar_t<Ta1, Ta2, Tb, Tz>* = nullptr,
3131
require_any_var_t<Ta1, Ta2, Tb, Tz>* = nullptr>
32-
inline return_type_t<Ta1, Ta1, Tb, Tz> hypergeometric_2F1(const Ta1& a1,
32+
inline return_type_t<Ta1, Ta2, Tb, Tz> hypergeometric_2F1(const Ta1& a1,
3333
const Ta2& a2,
3434
const Tb& b,
3535
const Tz& z) {

stan/math/rev/fun/hypergeometric_pFq.hpp

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
#include <stan/math/rev/core.hpp>
55
#include <stan/math/rev/meta.hpp>
6+
#include <stan/math/prim/fun/as_column_vector_or_scalar.hpp>
67
#include <stan/math/prim/fun/grad_pFq.hpp>
78
#include <stan/math/prim/fun/hypergeometric_pFq.hpp>
89

@@ -25,27 +26,24 @@ template <typename Ta, typename Tb, typename Tz,
2526
bool grad_a = !is_constant<Ta>::value,
2627
bool grad_b = !is_constant<Tb>::value,
2728
bool grad_z = !is_constant<Tz>::value,
28-
require_all_matrix_t<Ta, Tb>* = nullptr,
29+
require_all_vector_t<Ta, Tb>* = nullptr,
2930
require_return_type_t<is_var, Ta, Tb, Tz>* = nullptr>
30-
inline var hypergeometric_pFq(const Ta& a, const Tb& b, const Tz& z) {
31-
arena_t<Ta> arena_a = a;
32-
arena_t<Tb> arena_b = b;
33-
auto pfq_val = hypergeometric_pFq(a.val(), b.val(), value_of(z));
31+
inline var hypergeometric_pFq(Ta&& a, Tb&& b, Tz&& z) {
32+
auto&& arena_a = to_arena(as_column_vector_or_scalar(std::forward<Ta>(a)));
33+
auto&& arena_b = to_arena(as_column_vector_or_scalar(std::forward<Tb>(b)));
34+
auto pfq_val = hypergeometric_pFq(arena_a.val(), arena_b.val(), value_of(z));
3435
return make_callback_var(
3536
pfq_val, [arena_a, arena_b, z, pfq_val](auto& vi) mutable {
3637
auto grad_tuple = grad_pFq<grad_a, grad_b, grad_z>(
3738
pfq_val, arena_a.val(), arena_b.val(), value_of(z));
3839
if constexpr (grad_a) {
39-
forward_as<promote_scalar_t<var, Ta>>(arena_a).adj()
40-
+= vi.adj() * std::get<0>(grad_tuple);
40+
arena_a.adj() += vi.adj() * std::get<0>(grad_tuple);
4141
}
4242
if constexpr (grad_b) {
43-
forward_as<promote_scalar_t<var, Tb>>(arena_b).adj()
44-
+= vi.adj() * std::get<1>(grad_tuple);
43+
arena_b.adj() += vi.adj() * std::get<1>(grad_tuple);
4544
}
4645
if constexpr (grad_z) {
47-
forward_as<promote_scalar_t<var, Tz>>(z).adj()
48-
+= vi.adj() * std::get<2>(grad_tuple);
46+
z.adj() += vi.adj() * std::get<2>(grad_tuple);
4947
}
5048
});
5149
}

0 commit comments

Comments
 (0)