Skip to content

Commit fc33745

Browse files
authored
[analysis] Let Flat lattice take multiple types (#8052)
Previously, one could approximate a Flat lattice whose elements could have multiple types by creating a Flat lattice of a variant type. However, this would produce elements that were variants of variants, wasting space on an extra discriminant. To make this use case more efficient and ergonomic, support taking multiple type parameters in Flat. The multiple type parameters all become part of the element variant type. To handle the case where types are repeated, also add element accessors templatized on the type index.
1 parent 89688ba commit fc33745

File tree

2 files changed

+73
-31
lines changed

2 files changed

+73
-31
lines changed

src/analysis/lattices/flat.h

Lines changed: 47 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,15 @@
1717
#ifndef wasm_analysis_lattices_flat_h
1818
#define wasm_analysis_lattices_flat_h
1919

20+
#include <tuple>
21+
#include <type_traits>
2022
#include <variant>
2123

2224
#if __cplusplus >= 202002L
2325
#include <concepts>
2426
#endif
2527

26-
#include "../lattice.h"
28+
#include "analysis/lattice.h"
2729
#include "support/utilities.h"
2830

2931
namespace wasm::analysis {
@@ -33,27 +35,48 @@ namespace wasm::analysis {
3335
template<typename T>
3436
concept Flattenable = std::copyable<T> && std::equality_comparable<T>;
3537

36-
// Given a type T, Flat<T> is the lattice where none of the values of T are
37-
// comparable except with themselves, but they are all greater than a common
38-
// bottom element not in T and less than a common top element also not in T.
39-
template<Flattenable T>
38+
// Given types Ts..., Flat<T...> is the lattice where none of the values of any
39+
// T are comparable except with themselves, but they are all greater than a
40+
// common bottom element and less than a common top element.
41+
template<Flattenable T, Flattenable... Ts>
4042
#else
41-
template<typename T>
43+
template<typename T, typename... Ts>
4244
#endif
4345
struct Flat {
4446
private:
45-
struct Bot {};
46-
struct Top {};
47+
struct Bot : std::monostate {};
48+
struct Top : std::monostate {};
49+
50+
template<std::size_t I>
51+
using TI = std::tuple_element_t<I, std::tuple<T, Ts...>>;
4752

4853
public:
49-
struct Element : std::variant<Bot, T, Top> {
54+
struct Element : std::variant<T, Ts..., Bot, Top> {
5055
bool isBottom() const noexcept { return std::get_if<Bot>(this); }
5156
bool isTop() const noexcept { return std::get_if<Top>(this); }
52-
const T* getVal() const noexcept { return std::get_if<T>(this); }
53-
T* getVal() noexcept { return std::get_if<T>(this); }
57+
template<typename U = T> const U* getVal() const noexcept {
58+
return std::get_if<U>(this);
59+
}
60+
template<typename U = T> U* getVal() noexcept {
61+
return std::get_if<U>(this);
62+
}
63+
template<std::size_t I> const TI<I>* getVal() const noexcept {
64+
return std::get_if<I>(this);
65+
}
66+
template<std::size_t I> TI<I>* getVal() noexcept {
67+
return std::get_if<I>(this);
68+
}
5469
bool operator==(const Element& other) const noexcept {
55-
return ((isBottom() && other.isBottom()) || (isTop() && other.isTop()) ||
56-
(getVal() && other.getVal() && *getVal() == *other.getVal()));
70+
return this->index() == other.index() &&
71+
std::visit(
72+
[](const auto& a, const auto& b) {
73+
if constexpr (std::is_same_v<decltype(a), decltype(b)>) {
74+
return a == b;
75+
}
76+
return false;
77+
},
78+
*this,
79+
other);
5780
}
5881
bool operator!=(const Element& other) const noexcept {
5982
return !(*this == other);
@@ -62,18 +85,21 @@ struct Flat {
6285

6386
Element getBottom() const noexcept { return Element{Bot{}}; }
6487
Element getTop() const noexcept { return Element{Top{}}; }
65-
Element get(T&& val) const noexcept { return Element{std::move(val)}; }
88+
template<typename U> Element get(U&& val) const noexcept {
89+
return Element{std::move(val)};
90+
}
6691

6792
LatticeComparison compare(const Element& a, const Element& b) const noexcept {
68-
if (a.index() < b.index()) {
69-
return LESS;
70-
} else if (a.index() > b.index()) {
71-
return GREATER;
72-
} else if (auto pA = a.getVal(); pA && *pA != *b.getVal()) {
73-
return NO_RELATION;
74-
} else {
93+
if (a == b) {
7594
return EQUAL;
7695
}
96+
if (a.isTop() || b.isBottom()) {
97+
return GREATER;
98+
}
99+
if (a.isBottom() || b.isTop()) {
100+
return LESS;
101+
}
102+
return NO_RELATION;
77103
}
78104

79105
bool join(Element& joinee, const Element& joiner) const noexcept {

test/gtest/lattices.cpp

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -348,6 +348,22 @@ TEST(FlatLattice, Join) {
348348
flat, flat.getBottom(), flat.get(0), flat.get(1), flat.getTop());
349349
}
350350

351+
TEST(FlatLattice, MultipleTypes) {
352+
analysis::Flat<int, std::string> flat;
353+
testDiamondJoin(
354+
flat, flat.getBottom(), flat.get(0), flat.get("foo"), flat.getTop());
355+
356+
auto stringElem = flat.get("foo");
357+
358+
EXPECT_EQ(stringElem.getVal<0>(), nullptr);
359+
ASSERT_NE(stringElem.getVal<1>(), nullptr);
360+
EXPECT_EQ(*stringElem.getVal<1>(), std::string("foo"));
361+
362+
EXPECT_EQ(stringElem.getVal<int>(), nullptr);
363+
ASSERT_NE(stringElem.getVal<std::string>(), nullptr);
364+
EXPECT_EQ(*stringElem.getVal<std::string>(), std::string("foo"));
365+
}
366+
351367
TEST(LiftLattice, GetBottom) {
352368
analysis::Lift lift{analysis::Bool{}};
353369
EXPECT_TRUE(lift.getBottom().isBottom());
@@ -711,19 +727,19 @@ TEST(StackLattice, Compare) {
711727
auto& flat = stack.lattice;
712728
testDiamondCompare(stack,
713729
{},
714-
{flat.get(0)},
715-
{flat.get(0), flat.get(1)},
716-
{flat.get(0), flat.getTop()});
730+
{flat.get(0u)},
731+
{flat.get(0u), flat.get(1u)},
732+
{flat.get(0u), flat.getTop()});
717733
}
718734

719735
TEST(StackLattice, Join) {
720736
analysis::Stack stack{analysis::Flat<uint32_t>{}};
721737
auto& flat = stack.lattice;
722738
testDiamondJoin(stack,
723739
{},
724-
{flat.get(0)},
725-
{flat.get(0), flat.get(1)},
726-
{flat.get(0), flat.getTop()});
740+
{flat.get(0u)},
741+
{flat.get(0u), flat.get(1u)},
742+
{flat.get(0u), flat.getTop()});
727743
}
728744

729745
using OddEvenInt = analysis::Flat<uint32_t>;
@@ -815,10 +831,10 @@ TEST(AbstractionLattice, Join) {
815831
#define JOIN(a, b, c) expectJoin(__FILE__, __LINE__, a, b, c)
816832

817833
auto bot = abstraction.getBottom();
818-
auto one = OddEvenAbstraction::Element(OddEvenInt{}.get(1));
819-
auto two = OddEvenAbstraction::Element(OddEvenInt{}.get(2));
820-
auto three = OddEvenAbstraction::Element(OddEvenInt{}.get(3));
821-
auto four = OddEvenAbstraction::Element(OddEvenInt{}.get(4));
834+
auto one = OddEvenAbstraction::Element(OddEvenInt{}.get(1u));
835+
auto two = OddEvenAbstraction::Element(OddEvenInt{}.get(2u));
836+
auto three = OddEvenAbstraction::Element(OddEvenInt{}.get(3u));
837+
auto four = OddEvenAbstraction::Element(OddEvenInt{}.get(4u));
822838
auto even = OddEvenAbstraction::Element(OddEvenBool{}.get(true));
823839
auto odd = OddEvenAbstraction::Element(OddEvenBool{}.get(false));
824840
auto top = OddEvenAbstraction::Element(OddEvenBool{}.getTop());

0 commit comments

Comments
 (0)