Skip to content

Commit 72ba948

Browse files
committed
Fix and test autodiff containers
1 parent 15f3264 commit 72ba948

File tree

3 files changed

+27
-27
lines changed

3 files changed

+27
-27
lines changed

stan/math/fwd/fun/hypergeometric_pFq.hpp

Lines changed: 13 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
#include <stan/math/prim/fun/dot_product.hpp>
88
#include <stan/math/prim/fun/grad_pFq.hpp>
99
#include <stan/math/prim/fun/hypergeometric_pFq.hpp>
10+
#include <stan/math/prim/fun/as_column_vector_or_scalar.hpp>
11+
#include <stan/math/prim/fun/to_ref.hpp>
1012

1113
namespace stan {
1214
namespace math {
@@ -30,33 +32,27 @@ template <typename Ta, typename Tb, typename Tz,
3032
bool grad_z = !is_constant<Tz>::value,
3133
require_all_vector_t<Ta, Tb>* = nullptr,
3234
require_fvar_t<FvarT>* = nullptr>
33-
inline FvarT hypergeometric_pFq(const Ta& a, const Tb& b, const Tz& z) {
34-
using PartialsT = partials_type_t<FvarT>;
35-
using ARefT = ref_type_t<Ta>;
36-
using BRefT = ref_type_t<Tb>;
37-
38-
ARefT a_ref = a;
39-
BRefT b_ref = b;
35+
inline FvarT hypergeometric_pFq(Ta&& a, Tb&& b, Tz&& z) {
36+
auto&& a_ref = to_ref(as_column_vector_or_scalar(a));
37+
auto&& b_ref = to_ref(as_column_vector_or_scalar(b));
4038
auto&& a_val = value_of(a_ref);
4139
auto&& b_val = value_of(b_ref);
4240
auto&& z_val = value_of(z);
43-
PartialsT pfq_val = hypergeometric_pFq(a_val, b_val, z_val);
41+
42+
partials_type_t<FvarT> pfq_val = hypergeometric_pFq(a_val, b_val, z_val);
4443
auto grad_tuple
4544
= grad_pFq<grad_a, grad_b, grad_z>(pfq_val, a_val, b_val, z_val);
4645

4746
FvarT rtn = FvarT(pfq_val, 0.0);
4847

49-
if (grad_a) {
50-
rtn.d_ += dot_product(forward_as<promote_scalar_t<FvarT, ARefT>>(a_ref).d(),
51-
std::get<0>(grad_tuple));
48+
if constexpr (grad_a) {
49+
rtn.d_ += dot_product(a_ref.d(), std::get<0>(grad_tuple));
5250
}
53-
if (grad_b) {
54-
rtn.d_ += dot_product(forward_as<promote_scalar_t<FvarT, BRefT>>(b_ref).d(),
55-
std::get<1>(grad_tuple));
51+
if constexpr (grad_b) {
52+
rtn.d_ += dot_product(b_ref.d(), std::get<1>(grad_tuple));
5653
}
57-
if (grad_z) {
58-
rtn.d_ += forward_as<promote_scalar_t<FvarT, Tz>>(z).d_
59-
* std::get<2>(grad_tuple);
54+
if constexpr (grad_z) {
55+
rtn.d_ += z.d_ * std::get<2>(grad_tuple);
6056
}
6157

6258
return rtn;

stan/math/rev/fun/hypergeometric_pFq.hpp

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
#include <stan/math/rev/core.hpp>
55
#include <stan/math/rev/meta.hpp>
6+
#include <stan/math/prim/fun/as_column_vector_or_scalar.hpp>
67
#include <stan/math/prim/fun/grad_pFq.hpp>
78
#include <stan/math/prim/fun/hypergeometric_pFq.hpp>
89

@@ -25,27 +26,24 @@ template <typename Ta, typename Tb, typename Tz,
2526
bool grad_a = !is_constant<Ta>::value,
2627
bool grad_b = !is_constant<Tb>::value,
2728
bool grad_z = !is_constant<Tz>::value,
28-
require_all_matrix_t<Ta, Tb>* = nullptr,
29+
require_all_vector_t<Ta, Tb>* = nullptr,
2930
require_return_type_t<is_var, Ta, Tb, Tz>* = nullptr>
30-
inline var hypergeometric_pFq(const Ta& a, const Tb& b, const Tz& z) {
31-
arena_t<Ta> arena_a = a;
32-
arena_t<Tb> arena_b = b;
31+
inline var hypergeometric_pFq(Ta&& a, Tb&& b, Tz&& z) {
32+
auto&& arena_a = to_arena(as_column_vector_or_scalar(std::forward<Ta>(a)));
33+
auto&& arena_b = to_arena(as_column_vector_or_scalar(std::forward<Tb>(b)));
3334
auto pfq_val = hypergeometric_pFq(arena_a.val(), arena_b.val(), value_of(z));
3435
return make_callback_var(
3536
pfq_val, [arena_a, arena_b, z, pfq_val](auto& vi) mutable {
3637
auto grad_tuple = grad_pFq<grad_a, grad_b, grad_z>(
3738
pfq_val, arena_a.val(), arena_b.val(), value_of(z));
3839
if constexpr (grad_a) {
39-
forward_as<promote_scalar_t<var, Ta>>(arena_a).adj()
40-
+= vi.adj() * std::get<0>(grad_tuple);
40+
arena_a.adj() += vi.adj() * std::get<0>(grad_tuple);
4141
}
4242
if constexpr (grad_b) {
43-
forward_as<promote_scalar_t<var, Tb>>(arena_b).adj()
44-
+= vi.adj() * std::get<1>(grad_tuple);
43+
arena_b.adj() += vi.adj() * std::get<1>(grad_tuple);
4544
}
4645
if constexpr (grad_z) {
47-
forward_as<promote_scalar_t<var, Tz>>(z).adj()
48-
+= vi.adj() * std::get<2>(grad_tuple);
46+
z.adj() += vi.adj() * std::get<2>(grad_tuple);
4947
}
5048
});
5149
}

test/unit/math/mix/fun/hypergeometric_pFq_test.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22
#include <limits>
33

44
TEST(mathMixScalFun, hyper_2f2) {
5+
using stan::math::to_array_1d;
6+
using stan::math::to_row_vector;
7+
58
auto f = [](const auto& a, const auto& b, const auto& z) {
69
using stan::math::hypergeometric_pFq;
710
return hypergeometric_pFq(a, b, z);
@@ -14,6 +17,9 @@ TEST(mathMixScalFun, hyper_2f2) {
1417
double z = 4;
1518

1619
stan::test::expect_ad(f, in1, in2, z);
20+
stan::test::expect_ad(f, to_array_1d(in1), to_row_vector(in2), z);
21+
stan::test::expect_ad(f, to_row_vector(in1), to_array_1d(in2), z);
22+
stan::test::expect_ad(f, to_array_1d(in1), to_array_1d(in2), z);
1723
}
1824

1925
TEST(mathMixScalFun, hyper_2f3) {

0 commit comments

Comments
 (0)