Skip to content

Commit 61c1fd0

Browse files
authored
Replace thrust::tuple implementation with cuda::std::tuple (#262)
1 parent e447ecc commit 61c1fd0

File tree

22 files changed

+264
-2098
lines changed

22 files changed

+264
-2098
lines changed

libcudacxx/include/cuda/std/detail/libcxx/include/tuple

Lines changed: 49 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -669,26 +669,40 @@ template <class... _Tp> class _LIBCUDACXX_TEMPLATE_VIS tuple {
669669
struct _PackExpandsToThisTuple<_Arg>
670670
: is_same<__remove_cvref_t<_Arg>, tuple> {};
671671

672-
template <size_t _Jp, class... _Up>
673-
friend _LIBCUDACXX_CONSTEXPR_AFTER_CXX11
674-
_LIBCUDACXX_INLINE_VISIBILITY __tuple_element_t<_Jp, tuple<_Up...>> &
675-
get(tuple<_Up...> &) noexcept;
676-
template <size_t _Jp, class... _Up>
677-
friend _LIBCUDACXX_CONSTEXPR_AFTER_CXX11 _LIBCUDACXX_INLINE_VISIBILITY const
678-
__tuple_element_t<_Jp, tuple<_Up...>> &
679-
get(const tuple<_Up...> &) noexcept;
680-
template <size_t _Jp, class... _Up>
681-
friend _LIBCUDACXX_CONSTEXPR_AFTER_CXX11
682-
_LIBCUDACXX_INLINE_VISIBILITY __tuple_element_t<_Jp, tuple<_Up...>> &&
683-
get(tuple<_Up...> &&) noexcept;
684-
template <size_t _Jp, class... _Up>
685-
friend _LIBCUDACXX_CONSTEXPR_AFTER_CXX11 _LIBCUDACXX_INLINE_VISIBILITY const
686-
__tuple_element_t<_Jp, tuple<_Up...>> &&
687-
get(const tuple<_Up...> &&) noexcept;
688-
689672
public:
690-
template <
691-
class _Constraints = __tuple_constraints<_Tp...>,
673+
template <size_t _Ip>
674+
_LIBCUDACXX_INLINE_VISIBILITY _LIBCUDACXX_CONSTEXPR_AFTER_CXX11 __tuple_element_t<_Ip, tuple>&
675+
__get_impl() & noexcept
676+
{
677+
typedef _LIBCUDACXX_NODEBUG_TYPE __tuple_element_t<_Ip, tuple> type;
678+
return static_cast<__tuple_leaf<_Ip, type>&>(__base_).get();
679+
}
680+
681+
template <size_t _Ip>
682+
_LIBCUDACXX_INLINE_VISIBILITY _LIBCUDACXX_CONSTEXPR_AFTER_CXX11 const __tuple_element_t<_Ip, tuple>&
683+
__get_impl() const& noexcept
684+
{
685+
typedef _LIBCUDACXX_NODEBUG_TYPE __tuple_element_t<_Ip, tuple> type;
686+
return static_cast<const __tuple_leaf<_Ip, type>&>(__base_).get();
687+
}
688+
689+
template <size_t _Ip>
690+
_LIBCUDACXX_INLINE_VISIBILITY _LIBCUDACXX_CONSTEXPR_AFTER_CXX11 __tuple_element_t<_Ip, tuple>&&
691+
__get_impl() && noexcept
692+
{
693+
typedef _LIBCUDACXX_NODEBUG_TYPE __tuple_element_t<_Ip, tuple> type;
694+
return static_cast<type&&>(static_cast<__tuple_leaf<_Ip, type>&&>(__base_).get());
695+
}
696+
697+
template <size_t _Ip>
698+
_LIBCUDACXX_INLINE_VISIBILITY _LIBCUDACXX_CONSTEXPR_AFTER_CXX11 const __tuple_element_t<_Ip, tuple>&&
699+
__get_impl() const&& noexcept
700+
{
701+
typedef _LIBCUDACXX_NODEBUG_TYPE __tuple_element_t<_Ip, tuple> type;
702+
return static_cast<const type&&>(static_cast<const __tuple_leaf<_Ip, type>&&>(__base_).get());
703+
}
704+
705+
template < class _Constraints = __tuple_constraints<_Tp...>,
692706
__enable_if_t<_Constraints::__implicit_default_constructible, int> = 0>
693707
_LIBCUDACXX_INLINE_VISIBILITY constexpr tuple() noexcept(
694708
_Constraints::__nothrow_default_constructible) {}
@@ -962,37 +976,31 @@ inline _LIBCUDACXX_INLINE_VISIBILITY
962976

963977
// get
964978
template <size_t _Ip, class... _Tp>
965-
inline _LIBCUDACXX_INLINE_VISIBILITY
966-
_LIBCUDACXX_CONSTEXPR_AFTER_CXX11 __tuple_element_t<_Ip, tuple<_Tp...>> &
967-
get(tuple<_Tp...> &__t) noexcept {
968-
typedef _LIBCUDACXX_NODEBUG_TYPE __tuple_element_t<_Ip, tuple<_Tp...>> type;
969-
return static_cast<__tuple_leaf<_Ip, type> &>(__t.__base_).get();
979+
inline _LIBCUDACXX_INLINE_VISIBILITY _LIBCUDACXX_CONSTEXPR_AFTER_CXX11 __tuple_element_t<_Ip, tuple<_Tp...>>&
980+
get(tuple<_Tp...>& __t) noexcept
981+
{
982+
return __t.template __get_impl<_Ip>();
970983
}
971984

972985
template <size_t _Ip, class... _Tp>
973-
inline _LIBCUDACXX_INLINE_VISIBILITY _LIBCUDACXX_CONSTEXPR_AFTER_CXX11 const
974-
__tuple_element_t<_Ip, tuple<_Tp...>> &
975-
get(const tuple<_Tp...> &__t) noexcept {
976-
typedef _LIBCUDACXX_NODEBUG_TYPE __tuple_element_t<_Ip, tuple<_Tp...>> type;
977-
return static_cast<const __tuple_leaf<_Ip, type> &>(__t.__base_).get();
986+
inline _LIBCUDACXX_INLINE_VISIBILITY _LIBCUDACXX_CONSTEXPR_AFTER_CXX11 const __tuple_element_t<_Ip, tuple<_Tp...>>&
987+
get(const tuple<_Tp...>& __t) noexcept
988+
{
989+
return __t.template __get_impl<_Ip>();
978990
}
979991

980992
template <size_t _Ip, class... _Tp>
981-
inline _LIBCUDACXX_INLINE_VISIBILITY
982-
_LIBCUDACXX_CONSTEXPR_AFTER_CXX11 __tuple_element_t<_Ip, tuple<_Tp...>> &&
983-
get(tuple<_Tp...> &&__t) noexcept {
984-
typedef _LIBCUDACXX_NODEBUG_TYPE __tuple_element_t<_Ip, tuple<_Tp...>> type;
985-
return static_cast<type &&>(
986-
static_cast<__tuple_leaf<_Ip, type> &&>(__t.__base_).get());
993+
inline _LIBCUDACXX_INLINE_VISIBILITY _LIBCUDACXX_CONSTEXPR_AFTER_CXX11 __tuple_element_t<_Ip, tuple<_Tp...>>&&
994+
get(tuple<_Tp...>&& __t) noexcept
995+
{
996+
return _CUDA_VSTD::move(__t).template __get_impl<_Ip>();
987997
}
988998

989999
template <size_t _Ip, class... _Tp>
990-
inline _LIBCUDACXX_INLINE_VISIBILITY _LIBCUDACXX_CONSTEXPR_AFTER_CXX11 const
991-
__tuple_element_t<_Ip, tuple<_Tp...>> &&
992-
get(const tuple<_Tp...> &&__t) noexcept {
993-
typedef _LIBCUDACXX_NODEBUG_TYPE __tuple_element_t<_Ip, tuple<_Tp...>> type;
994-
return static_cast<const type &&>(
995-
static_cast<const __tuple_leaf<_Ip, type> &&>(__t.__base_).get());
1000+
inline _LIBCUDACXX_INLINE_VISIBILITY _LIBCUDACXX_CONSTEXPR_AFTER_CXX11 const __tuple_element_t<_Ip, tuple<_Tp...>>&&
1001+
get(const tuple<_Tp...>&& __t) noexcept
1002+
{
1003+
return _CUDA_VSTD::move(__t).template __get_impl<_Ip>();
9961004
}
9971005

9981006
#if _LIBCUDACXX_STD_VER > 11

thrust/testing/functional.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ THRUST_DISABLE_MSVC_POSSIBLE_LOSS_OF_DATA_WARNING_BEGIN
99

1010
// There is a unfortunate miscompilation of the gcc-12 vectorizer leading to OOB writes
1111
// Adding this attribute suffices that this miscompilation does not appear anymore
12-
#if (THRUST_HOST_COMPILER == THRUST_HOST_COMPILER_GCC) && __GNUC__ >= 12 && THRUST_CPP_DIALECT >= 2020
12+
#if (THRUST_HOST_COMPILER == THRUST_HOST_COMPILER_GCC) && __GNUC__ >= 12
1313
#define THRUST_DISABLE_BROKEN_GCC_VECTORIZER __attribute__((optimize("no-tree-vectorize")))
1414
#else
1515
#define THRUST_DISABLE_BROKEN_GCC_VECTORIZER

thrust/testing/pair.cu

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ struct TestPairManipulation
4444
// test copy from pair
4545
p4.first = T(2);
4646
p4.second = T(3);
47-
47+
4848
P p5;
4949
p5 = p4;
5050
ASSERT_EQUAL(p4.first, p5.first);
@@ -217,7 +217,7 @@ using PairConstVolatileTypes =
217217
unittest::type_list<thrust::pair<int, float>, thrust::pair<int, float> const,
218218
thrust::pair<int, float> const volatile>;
219219

220-
template <typename Pair>
220+
template <typename Pair>
221221
struct TestPairTupleSize
222222
{
223223
void operator()()
@@ -289,3 +289,16 @@ void TestPairSwap(void)
289289
}
290290
DECLARE_UNITTEST(TestPairSwap);
291291

292+
#if THRUST_CPP_DIALECT >= 2017
293+
void TestPairStructuredBindings(void)
294+
{
295+
const int a = 42;
296+
const int b = 1337;
297+
thrust::pair<int,int> p(a,b);
298+
299+
auto [a2, b2] = p;
300+
ASSERT_EQUAL(a, a2);
301+
ASSERT_EQUAL(b, b2);
302+
}
303+
DECLARE_UNITTEST(TestPairStructuredBindings);
304+
#endif

thrust/testing/transform.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
// There is a unfortunate miscompilation of the gcc-12 vectorizer leading to OOB writes
1111
// Adding this attribute suffices that this miscompilation does not appear anymore
12-
#if (THRUST_HOST_COMPILER == THRUST_HOST_COMPILER_GCC) && __GNUC__ >= 12 && THRUST_CPP_DIALECT >= 2020
12+
#if (THRUST_HOST_COMPILER == THRUST_HOST_COMPILER_GCC) && __GNUC__ >= 12
1313
#define THRUST_DISABLE_BROKEN_GCC_VECTORIZER __attribute__((optimize("no-tree-vectorize")))
1414
#else
1515
#define THRUST_DISABLE_BROKEN_GCC_VECTORIZER

thrust/testing/tuple.cu

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -491,4 +491,30 @@ void TestTupleSwap(void)
491491
}
492492
DECLARE_UNITTEST(TestTupleSwap);
493493

494-
494+
#if THRUST_CPP_DIALECT >= 2017
495+
void TestTupleStructuredBindings(void)
496+
{
497+
const int a = 0;
498+
const int b = 42;
499+
const int c = 1337;
500+
thrust::tuple<int,int,int> t(a,b,c);
501+
502+
auto [a2, b2, c2] = t;
503+
ASSERT_EQUAL(a, a2);
504+
ASSERT_EQUAL(b, b2);
505+
ASSERT_EQUAL(c, c2);
506+
}
507+
DECLARE_UNITTEST(TestTupleStructuredBindings);
508+
#endif
509+
510+
// Ensure that we are backwards compatible with the old thrust::tuple implementation
511+
static_assert(thrust::tuple_size<thrust::tuple<thrust::null_type, thrust::null_type, thrust::null_type, thrust::null_type, thrust::null_type, thrust::null_type, thrust::null_type, thrust::null_type, thrust::null_type>>::value == 0, "");
512+
static_assert(thrust::tuple_size<thrust::tuple<int, thrust::null_type, thrust::null_type, thrust::null_type, thrust::null_type, thrust::null_type, thrust::null_type, thrust::null_type, thrust::null_type>>::value == 1, "");
513+
static_assert(thrust::tuple_size<thrust::tuple<int, int, thrust::null_type, thrust::null_type, thrust::null_type, thrust::null_type, thrust::null_type, thrust::null_type, thrust::null_type>>::value == 2, "");
514+
static_assert(thrust::tuple_size<thrust::tuple<int, int, int, thrust::null_type, thrust::null_type, thrust::null_type, thrust::null_type, thrust::null_type, thrust::null_type>>::value == 3, "");
515+
static_assert(thrust::tuple_size<thrust::tuple<int, int, int, int, thrust::null_type, thrust::null_type, thrust::null_type, thrust::null_type, thrust::null_type>>::value == 4, "");
516+
static_assert(thrust::tuple_size<thrust::tuple<int, int, int, int, int, thrust::null_type, thrust::null_type, thrust::null_type, thrust::null_type>>::value == 5, "");
517+
static_assert(thrust::tuple_size<thrust::tuple<int, int, int, int, int, int, thrust::null_type, thrust::null_type, thrust::null_type>>::value == 6, "");
518+
static_assert(thrust::tuple_size<thrust::tuple<int, int, int, int, int, int, int, thrust::null_type, thrust::null_type>>::value == 7, "");
519+
static_assert(thrust::tuple_size<thrust::tuple<int, int, int, int, int, int, int, int, thrust::null_type>>::value == 8, "");
520+
static_assert(thrust::tuple_size<thrust::tuple<int, int, int, int, int, int, int, int, int>>::value == 9, "");

thrust/testing/tuple_sort.cu

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,7 @@ struct GetFunctor
2020
{
2121
template<typename Tuple>
2222
__host__ __device__
23-
typename thrust::access_traits<
24-
typename thrust::tuple_element<N, Tuple>::type
25-
>::const_type
26-
operator()(const Tuple &t)
23+
typename thrust::tuple_element<N, Tuple>::type operator()(const Tuple &t)
2724
{
2825
return thrust::get<N>(t);
2926
}
@@ -64,7 +61,7 @@ struct TestTupleStableSort
6461

6562
// select values
6663
transform(h_tuples.begin(), h_tuples.end(), h_values.begin(), GetFunctor<1>());
67-
64+
6865
device_vector<T> d_values(h_values.size());
6966
transform(d_tuples.begin(), d_tuples.end(), d_values.begin(), GetFunctor<1>());
7067

thrust/testing/tuple_transform.cu

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,7 @@ struct GetFunctor
1919
{
2020
template<typename Tuple>
2121
__host__ __device__
22-
typename thrust::access_traits<
23-
typename thrust::tuple_element<N, Tuple>::type
24-
>::const_type
25-
operator()(const Tuple &t)
22+
typename thrust::tuple_element<N, Tuple>::type operator()(const Tuple &t)
2623
{
2724
return thrust::get<N>(t);
2825
}

thrust/testing/zip_iterator.cu

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -285,9 +285,9 @@ void TestZipIteratorCopy(void)
285285
sequence(input0.begin(), input0.end(), T{0});
286286
sequence(input1.begin(), input1.end(), T{13});
287287

288-
copy( make_zip_iterator(make_tuple(input0.begin(), input1.begin())),
289-
make_zip_iterator(make_tuple(input0.end(), input1.end())),
290-
make_zip_iterator(make_tuple(output0.begin(), output1.begin())));
288+
thrust::copy( make_zip_iterator(make_tuple(input0.begin(), input1.begin())),
289+
make_zip_iterator(make_tuple(input0.end(), input1.end())),
290+
make_zip_iterator(make_tuple(output0.begin(), output1.begin())));
291291

292292
ASSERT_EQUAL(input0, output0);
293293
ASSERT_EQUAL(input1, output1);

thrust/thrust/detail/functional/actor.h

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -65,10 +65,6 @@ template<typename Eval>
6565
__host__ __device__
6666
actor(const Eval &base);
6767

68-
__host__ __device__
69-
typename apply_actor<eval_type, thrust::null_type >::type
70-
operator()(void) const;
71-
7268
template <typename... Ts>
7369
__host__ __device__
7470
typename apply_actor<eval_type, thrust::tuple<eval_ref<Ts>...>>::type
@@ -122,7 +118,7 @@ template<typename Eval>
122118
{
123119
typedef typename thrust::detail::functional::apply_actor<
124120
thrust::detail::functional::actor<Eval>,
125-
thrust::null_type
121+
thrust::tuple<>
126122
>::type type;
127123
}; // end result_of
128124

thrust/thrust/detail/functional/actor.inl

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -54,18 +54,6 @@ template<typename Eval>
5454
: eval_type(base)
5555
{}
5656

57-
template<typename Eval>
58-
__host__ __device__
59-
typename apply_actor<
60-
typename actor<Eval>::eval_type,
61-
typename thrust::null_type
62-
>::type
63-
actor<Eval>
64-
::operator()(void) const
65-
{
66-
return eval_type::eval(thrust::null_type());
67-
} // end basic_environment::operator()
68-
6957
// actor::operator() needs to construct a tuple of references to its
7058
// arguments. To make this work with thrust::reference<T>, we need to
7159
// detect thrust proxy references and store them as T rather than T&.

0 commit comments

Comments
 (0)