7
7
#include < stan/math/prim/fun/dot_product.hpp>
8
8
#include < stan/math/prim/fun/grad_pFq.hpp>
9
9
#include < stan/math/prim/fun/hypergeometric_pFq.hpp>
10
+ #include < stan/math/prim/fun/as_column_vector_or_scalar.hpp>
11
+ #include < stan/math/prim/fun/to_ref.hpp>
10
12
11
13
namespace stan {
12
14
namespace math {
@@ -30,33 +32,27 @@ template <typename Ta, typename Tb, typename Tz,
30
32
bool grad_z = !is_constant<Tz>::value,
31
33
require_all_vector_t <Ta, Tb>* = nullptr ,
32
34
require_fvar_t <FvarT>* = nullptr >
33
- inline FvarT hypergeometric_pFq (const Ta& a, const Tb& b, const Tz& z) {
34
- using PartialsT = partials_type_t <FvarT>;
35
- using ARefT = ref_type_t <Ta>;
36
- using BRefT = ref_type_t <Tb>;
37
-
38
- ARefT a_ref = a;
39
- BRefT b_ref = b;
35
+ inline FvarT hypergeometric_pFq (Ta&& a, Tb&& b, Tz&& z) {
36
+ auto && a_ref = to_ref (as_column_vector_or_scalar (a));
37
+ auto && b_ref = to_ref (as_column_vector_or_scalar (b));
40
38
auto && a_val = value_of (a_ref);
41
39
auto && b_val = value_of (b_ref);
42
40
auto && z_val = value_of (z);
43
- PartialsT pfq_val = hypergeometric_pFq (a_val, b_val, z_val);
41
+
42
+ partials_type_t <FvarT> pfq_val = hypergeometric_pFq (a_val, b_val, z_val);
44
43
auto grad_tuple
45
44
= grad_pFq<grad_a, grad_b, grad_z>(pfq_val, a_val, b_val, z_val);
46
45
47
46
FvarT rtn = FvarT (pfq_val, 0.0 );
48
47
49
- if (grad_a) {
50
- rtn.d_ += dot_product (forward_as<promote_scalar_t <FvarT, ARefT>>(a_ref).d (),
51
- std::get<0 >(grad_tuple));
48
+ if constexpr (grad_a) {
49
+ rtn.d_ += dot_product (a_ref.d (), std::get<0 >(grad_tuple));
52
50
}
53
- if (grad_b) {
54
- rtn.d_ += dot_product (forward_as<promote_scalar_t <FvarT, BRefT>>(b_ref).d (),
55
- std::get<1 >(grad_tuple));
51
+ if constexpr (grad_b) {
52
+ rtn.d_ += dot_product (b_ref.d (), std::get<1 >(grad_tuple));
56
53
}
57
- if (grad_z) {
58
- rtn.d_ += forward_as<promote_scalar_t <FvarT, Tz>>(z).d_
59
- * std::get<2 >(grad_tuple);
54
+ if constexpr (grad_z) {
55
+ rtn.d_ += z.d_ * std::get<2 >(grad_tuple);
60
56
}
61
57
62
58
return rtn;
0 commit comments