diff --git a/llvm/include/llvm/ADT/STLExtras.h b/llvm/include/llvm/ADT/STLExtras.h index eea06cfb99ba2..31804b4c13d08 100644 --- a/llvm/include/llvm/ADT/STLExtras.h +++ b/llvm/include/llvm/ADT/STLExtras.h @@ -1032,13 +1032,17 @@ class concat_iterator static constexpr bool ReturnsByValue = !(std::is_reference_v())> && ...); + static constexpr bool ReturnsConvertiblePointer = + std::is_pointer_v && + (std::is_convertible_v()), ValueT> && ...); using reference_type = - typename std::conditional_t; + typename std::conditional_t; - using handle_type = - typename std::conditional_t, - ValueT *>; + using handle_type = typename std::conditional_t< + ReturnsConvertiblePointer, ValueT, + std::conditional_t, ValueT *>>; /// We store both the current and end iterators for each concatenated /// sequence in a tuple of pairs. @@ -1088,7 +1092,7 @@ class concat_iterator if (Begin == End) return {}; - if constexpr (ReturnsByValue) + if constexpr (ReturnsByValue || ReturnsConvertiblePointer) return *Begin; else return &*Begin; @@ -1105,8 +1109,12 @@ class concat_iterator // Loop over them, and return the first result we find. for (auto &GetHelperFn : GetHelperFns) - if (auto P = (this->*GetHelperFn)()) - return *P; + if (auto P = (this->*GetHelperFn)()) { + if constexpr (ReturnsConvertiblePointer) + return P; + else + return *P; + } llvm_unreachable("Attempted to get a pointer from an end concat iterator!"); } diff --git a/llvm/unittests/ADT/STLExtrasTest.cpp b/llvm/unittests/ADT/STLExtrasTest.cpp index 286cfa745fd14..7e9b6f19a3d32 100644 --- a/llvm/unittests/ADT/STLExtrasTest.cpp +++ b/llvm/unittests/ADT/STLExtrasTest.cpp @@ -398,6 +398,8 @@ struct some_struct { std::string swap_val; }; +struct derives_from_some_struct : some_struct {}; + std::vector::const_iterator begin(const some_struct &s) { return s.data.begin(); } @@ -532,6 +534,18 @@ TEST(STLExtrasTest, ConcatRangeADL) { EXPECT_THAT(concat(S0, S1), ElementsAre(1, 2, 3, 4)); } +TEST(STLExtrasTest, ConcatRangePtrToDerivedClass) { + some_namespace::some_struct S0{}; + some_namespace::derives_from_some_struct S1{}; + SmallVector V0{&S0}; + SmallVector V1{&S1, &S1}; + + // Use concat over ranges of pointers to different (but related) types. + EXPECT_THAT(concat(V0, V1), + ElementsAre(&S0, static_cast(&S1), + static_cast(&S1))); +} + TEST(STLExtrasTest, MakeFirstSecondRangeADL) { // Make sure that we use the `begin`/`end` functions from `some_namespace`, // using ADL.