Skip to content

Commit c6389b7

Browse files
committed
Fix vector handling in 3F2 and pFq
1 parent a92796d commit c6389b7

File tree

4 files changed

+20
-11
lines changed

4 files changed

+20
-11
lines changed

stan/math/prim/fun/hypergeometric_3F2.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ inline auto hypergeometric_3F2(const Ta& a, const Tb& b, const Tz& z) {
118118
// Boost's pFq throws convergence errors in some cases, fallback to naive
119119
// infinite-sum approach (tests pass for these)
120120
if (z == 1.0 && (sum(b) - sum(a)) < 0.0) {
121-
return internal::hypergeometric_3F2_infsum(a, b, z);
121+
return internal::hypergeometric_3F2_infsum(to_vector(a), to_vector(b), z);
122122
}
123123
return hypergeometric_pFq(to_vector(a), to_vector(b), z);
124124
}

stan/math/prim/fun/hypergeometric_pFq.hpp

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include <stan/math/prim/meta.hpp>
55
#include <stan/math/prim/err/check_not_nan.hpp>
66
#include <stan/math/prim/err/check_finite.hpp>
7+
#include <stan/math/prim/fun/to_row_vector.hpp>
78
#include <boost/math/special_functions/hypergeometric_pFq.hpp>
89

910
namespace stan {
@@ -14,10 +15,6 @@ namespace math {
1415
* input arguments:
1516
* \f$_pF_q(a_1,...,a_p;b_1,...,b_q;z)\f$
1617
*
17-
* This function is not intended to be exposed to end users, only
18-
* used for p & q values that are stable with the grad_pFq
19-
* implementation.
20-
*
2118
* See 'grad_pFq.hpp' for the derivatives wrt each parameter
2219
*
2320
* @param[in] a Vector of 'a' arguments to function
@@ -26,7 +23,7 @@ namespace math {
2623
* @return Generalized hypergeometric function
2724
*/
2825
template <typename Ta, typename Tb, typename Tz,
29-
require_all_eigen_st<std::is_arithmetic, Ta, Tb>* = nullptr,
26+
require_all_vector_st<std::is_arithmetic, Ta, Tb>* = nullptr,
3027
require_arithmetic_t<Tz>* = nullptr>
3128
return_type_t<Ta, Tb, Tz> hypergeometric_pFq(const Ta& a, const Tb& b,
3229
const Tz& z) {
@@ -47,8 +44,9 @@ return_type_t<Ta, Tb, Tz> hypergeometric_pFq(const Ta& a, const Tb& b,
4744
std::stringstream msg;
4845
msg << "hypergeometric function pFq does not meet convergence "
4946
<< "conditions with given arguments. "
50-
<< "a: " << a_ref << ", b: " << b_ref << ", "
51-
<< ", z: " << z;
47+
<< "a: " << to_row_vector(a_ref) << ", "
48+
<< "b: " << to_row_vector(b_ref) << ", "
49+
<< "z: " << z;
5250
throw std::domain_error(msg.str());
5351
}
5452

test/unit/math/prim/fun/hypergeometric_3F2_test.cpp

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,16 @@
33

44
// converge
55
TEST(MathPrimScalFun, F32_converges_by_z) {
6-
EXPECT_NEAR(2.5,
7-
stan::math::hypergeometric_3F2({1.0, 1.0, 1.0}, {1.0, 1.0}, 0.6),
8-
1e-8);
6+
using stan::math::hypergeometric_3F2;
7+
using stan::math::to_vector;
8+
using stan::math::to_row_vector;
9+
std::vector<double> a = {1.0, 1.0, 1.0};
10+
std::vector<double> b = {1.0, 1.0};
11+
double z = 0.6;
12+
13+
EXPECT_NEAR(2.5, hypergeometric_3F2(a, b, z), 1e-8);
14+
EXPECT_NEAR(2.5, hypergeometric_3F2(to_vector(a), to_vector(b), z), 1e-8);
15+
EXPECT_NEAR(2.5, hypergeometric_3F2(to_row_vector(a), to_row_vector(b), z), 1e-8);
916
}
1017
// terminate by zero numerator, no sign-flip
1118
TEST(MathPrimScalFun, F32_polynomial) {

test/unit/math/prim/fun/hypergeometric_pFq_test.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
TEST(MathFunctions, hypergeometric_pFq_values) {
55
using Eigen::VectorXd;
66
using stan::math::hypergeometric_pFq;
7+
using stan::math::to_array_1d;
8+
using stan::math::to_row_vector;
79

810
VectorXd a(2);
911
VectorXd b(2);
@@ -12,6 +14,8 @@ TEST(MathFunctions, hypergeometric_pFq_values) {
1214
double z = 2;
1315

1416
EXPECT_FLOAT_EQ(3.8420514314107791, hypergeometric_pFq(a, b, z));
17+
EXPECT_FLOAT_EQ(3.8420514314107791, hypergeometric_pFq(to_row_vector(a), to_row_vector(b), z));
18+
EXPECT_FLOAT_EQ(3.8420514314107791, hypergeometric_pFq(to_array_1d(a), to_array_1d(b), z));
1519

1620
a << 6, 4;
1721
b << 3, 1;

0 commit comments

Comments
 (0)