3
3
4
4
#include < stan/math/rev/core.hpp>
5
5
#include < stan/math/rev/meta.hpp>
6
+ #include < stan/math/prim/fun/as_column_vector_or_scalar.hpp>
6
7
#include < stan/math/prim/fun/grad_pFq.hpp>
7
8
#include < stan/math/prim/fun/hypergeometric_pFq.hpp>
8
9
@@ -25,27 +26,24 @@ template <typename Ta, typename Tb, typename Tz,
25
26
bool grad_a = !is_constant<Ta>::value,
26
27
bool grad_b = !is_constant<Tb>::value,
27
28
bool grad_z = !is_constant<Tz>::value,
28
- require_all_matrix_t <Ta, Tb>* = nullptr ,
29
+ require_all_vector_t <Ta, Tb>* = nullptr ,
29
30
require_return_type_t <is_var, Ta, Tb, Tz>* = nullptr >
30
- inline var hypergeometric_pFq (const Ta& a, const Tb& b, const Tz & z) {
31
- arena_t <Ta> arena_a = a ;
32
- arena_t <Tb> arena_b = b ;
33
- auto pfq_val = hypergeometric_pFq (a .val (), b .val (), value_of (z));
31
+ inline var hypergeometric_pFq (Ta&& a, Tb&& b, Tz& & z) {
32
+ auto && arena_a = to_arena ( as_column_vector_or_scalar (std::forward<Ta>(a))) ;
33
+ auto && arena_b = to_arena ( as_column_vector_or_scalar (std::forward<Tb>(b))) ;
34
+ auto pfq_val = hypergeometric_pFq (arena_a .val (), arena_b .val (), value_of (z));
34
35
return make_callback_var (
35
36
pfq_val, [arena_a, arena_b, z, pfq_val](auto & vi) mutable {
36
37
auto grad_tuple = grad_pFq<grad_a, grad_b, grad_z>(
37
38
pfq_val, arena_a.val (), arena_b.val (), value_of (z));
38
39
if constexpr (grad_a) {
39
- forward_as<promote_scalar_t <var, Ta>>(arena_a).adj ()
40
- += vi.adj () * std::get<0 >(grad_tuple);
40
+ arena_a.adj () += vi.adj () * std::get<0 >(grad_tuple);
41
41
}
42
42
if constexpr (grad_b) {
43
- forward_as<promote_scalar_t <var, Tb>>(arena_b).adj ()
44
- += vi.adj () * std::get<1 >(grad_tuple);
43
+ arena_b.adj () += vi.adj () * std::get<1 >(grad_tuple);
45
44
}
46
45
if constexpr (grad_z) {
47
- forward_as<promote_scalar_t <var, Tz>>(z).adj ()
48
- += vi.adj () * std::get<2 >(grad_tuple);
46
+ z.adj () += vi.adj () * std::get<2 >(grad_tuple);
49
47
}
50
48
});
51
49
}
0 commit comments