Skip to content

Commit 31ac0e0

Browse files
committed
gpu: jit: conv: avoid explicit cast when fixing 32-bit overflow if possible
1 parent e3cb07d commit 31ac0e0

File tree

3 files changed

+133
-91
lines changed

3 files changed

+133
-91
lines changed

src/gpu/jit/ir/ir.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -410,6 +410,7 @@ expr_t nary_op_back_transform(const expr_t &e);
410410
expr_t nary_op_canonicalize(const expr_t &_e);
411411
expr_t make_nary_op(op_kind_t op_kind, const std::vector<expr_t> &args);
412412
std::vector<expr_t> cvt_expr_to_nary_op_args(const expr_t &e);
413+
expr_t reorder_nary_add_args(const expr_t &e, bool x64_first);
413414

414415
// Substitutes all occurrences of `from` with `to` in `root`.
415416
object_t substitute(const object_t &root, const object_t &from,

src/gpu/jit/pass/overflow.cpp

Lines changed: 107 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,100 @@ class overflow_bound_finder_t : public bound_finder_base_t {
7070
object_map_t<expr_t, std::pair<int64_t, int64_t>> var_bounds_;
7171
};
7272

73+
struct overflow_context_t {
74+
overflow_bound_finder_t bound_finder;
75+
object_map_t<expr_t, std::vector<expr_t>> vec_vars;
76+
object_set_t<expr_t> vars_with_load;
77+
78+
bool contains_load(const expr_t &e) const {
79+
if (!find_objects<load_t>(e).empty()) return true;
80+
for (auto &v : find_objects<var_t>(e)) {
81+
if (vars_with_load.count(v) != 0) return true;
82+
}
83+
return false;
84+
}
85+
};
86+
87+
class expr_overflow_fixer_t : public ir_mutator_t {
88+
public:
89+
expr_overflow_fixer_t(const overflow_context_t &ctx) : ctx_(ctx) {}
90+
91+
object_t _mutate(const binary_op_t &obj) override {
92+
return mutate_expr(obj);
93+
}
94+
95+
object_t _mutate(const unary_op_t &obj) override {
96+
return mutate_expr(obj);
97+
}
98+
99+
private:
100+
template <typename T>
101+
object_t mutate_expr(const T &obj) {
102+
expr_t new_obj = ir_mutator_t::_mutate(obj);
103+
if (!new_obj.type().is_x32()) return std::move(new_obj);
104+
if (ctx_.contains_load(new_obj)) return std::move(new_obj);
105+
106+
bool found_overflow = false;
107+
int elems = new_obj.type().elems();
108+
for (int i = 0; i < elems; i++) {
109+
expr_scalarizer_t scalarizer(elems, i, ctx_.vec_vars);
110+
expr_t value = scalarizer.mutate(new_obj);
111+
int64_t lo = ctx_.bound_finder.find_low_bound(value);
112+
int64_t hi = ctx_.bound_finder.find_high_bound(value);
113+
bool ok = bound_finder_base_t::is_good_bound(lo)
114+
&& bound_finder_base_t::is_good_bound(hi);
115+
if (ok) {
116+
int64_t type_lo = value.type().is_s32()
117+
? (int64_t)std::numeric_limits<int32_t>::min()
118+
: (int64_t)std::numeric_limits<uint32_t>::min();
119+
int64_t type_hi = value.type().is_s32()
120+
? (int64_t)std::numeric_limits<int32_t>::max()
121+
: (int64_t)std::numeric_limits<uint32_t>::max();
122+
123+
bool is_overflow = (lo < type_lo || hi > type_hi);
124+
if (is_overflow) {
125+
found_overflow = true;
126+
ir_warning() << "Found overflow: " << value
127+
<< " low bound: " << lo
128+
<< " high bound: " << hi << std::endl;
129+
break;
130+
}
131+
}
132+
}
133+
if (found_overflow) return fix_overflow(new_obj);
134+
return std::move(new_obj);
135+
}
136+
137+
static expr_t fix_overflow(const expr_t &e) {
138+
auto *binary = e.as_ptr<binary_op_t>();
139+
if (binary) {
140+
return binary_op_t::make(binary->op_kind,
141+
cast(binary->a, type_t::u64(e.type().elems())), binary->b);
142+
}
143+
144+
ir_error_not_expected() << "Can't fix overflow: " << e;
145+
return e;
146+
}
147+
148+
const overflow_context_t &ctx_;
149+
};
150+
151+
// Fixes 32-bit overflow in `e`. When the fixer changes the expression, first
// tries to reorder the summands (64-bit ones first) so that no explicit cast
// is required; the cast-based result is the fallback.
expr_t fix_expr_overflow(const expr_t &e, const overflow_context_t &ctx) {
    auto with_casts = expr_overflow_fixer_t(ctx).mutate(e);
    const bool overflow_found = !with_casts.is_same(e);
    if (!overflow_found) return e;

    // Overflow detected, try to rearrange summands and avoid explicit casting.
    auto canonical = nary_op_canonicalize(e);
    auto reordered_nary = reorder_nary_add_args(canonical, /*x64_first=*/true);
    auto reordered = nary_op_back_transform(reordered_nary);

    // Re-run the fixer on the reordered expression: if it is left untouched,
    // no overflow remains and the reordered form can be used as is.
    auto reordered_checked = expr_overflow_fixer_t(ctx).mutate(reordered);
    if (!reordered_checked.is_same(reordered)) return with_casts;
    return reordered;
}
166+
73167
class overflow_fixer_t : public ir_mutator_t {
74168
public:
75169
overflow_fixer_t(ir_context_t &ir_ctx) : ir_ctx_(ir_ctx) {
@@ -90,7 +184,7 @@ class overflow_fixer_t : public ir_mutator_t {
90184
<< to_string(rel.op_kind());
91185
}
92186
}
93-
bound_finder_.set_var_bounds(kv.first, {lo, hi});
187+
ctx_.bound_finder.set_var_bounds(kv.first, {lo, hi});
94188
}
95189
}
96190

@@ -99,13 +193,13 @@ class overflow_fixer_t : public ir_mutator_t {
99193
}
100194

101195
object_t _mutate(const binary_op_t &obj) override {
102-
return mutate_expr(obj);
196+
return fix_expr_overflow(obj, ctx_);
103197
}
104198

105199
object_t _mutate(const for_t &obj) override {
106200
auto lo = to_cpp<int64_t>(obj.init);
107201
auto hi = to_cpp<int64_t>(obj.bound) - 1;
108-
bound_finder_.set_var_bounds(obj.var, {lo, hi});
202+
ctx_.bound_finder.set_var_bounds(obj.var, {lo, hi});
109203
return ir_mutator_t::_mutate(obj);
110204
}
111205

@@ -114,25 +208,25 @@ class overflow_fixer_t : public ir_mutator_t {
114208
if (!obj.var.type().is_int()) ok = false;
115209
if (ok && obj.value.is_empty()) ok = false;
116210
if (ok && obj.value.type().is_bool()) ok = false;
117-
if (ok && bound_finder_.has_var(obj.var)) ok = false;
211+
if (ok && ctx_.bound_finder.has_var(obj.var)) ok = false;
118212

119213
if (ok) {
120-
if (contains_load(obj.value)) {
121-
vars_with_load_.insert(obj.var);
214+
if (ctx_.contains_load(obj.value)) {
215+
ctx_.vars_with_load.insert(obj.var);
122216
ok = false;
123217
}
124218
}
125219

126220
if (ok) {
127221
int elems = obj.var.type().elems();
128-
vec_vars_[obj.var].reserve(elems);
222+
ctx_.vec_vars[obj.var].reserve(elems);
129223
for (int i = 0; i < elems; i++) {
130224
auto var_i = make_vec_var(obj.var, elems, i);
131-
expr_scalarizer_t scalarizer(elems, i, vec_vars_);
225+
expr_scalarizer_t scalarizer(elems, i, ctx_.vec_vars);
132226
auto value_i = scalarizer.mutate(obj.value);
133-
auto lo_hi = bound_finder_.find_bounds(value_i);
134-
bound_finder_.set_var_bounds(var_i, lo_hi);
135-
vec_vars_[obj.var].push_back(var_i);
227+
auto lo_hi = ctx_.bound_finder.find_bounds(value_i);
228+
ctx_.bound_finder.set_var_bounds(var_i, lo_hi);
229+
ctx_.vec_vars[obj.var].push_back(var_i);
136230
}
137231
}
138232
expr_t var = obj.var;
@@ -150,77 +244,19 @@ class overflow_fixer_t : public ir_mutator_t {
150244
}
151245

152246
object_t _mutate(const unary_op_t &obj) override {
153-
return mutate_expr(obj);
247+
return fix_expr_overflow(obj, ctx_);
154248
}
155249

156250
private:
157-
template <typename T>
158-
object_t mutate_expr(const T &obj) {
159-
expr_t new_obj = ir_mutator_t::_mutate(obj);
160-
if (!new_obj.type().is_x32()) return std::move(new_obj);
161-
if (contains_load(new_obj)) return std::move(new_obj);
162-
163-
bool found_overflow = false;
164-
int elems = new_obj.type().elems();
165-
for (int i = 0; i < elems; i++) {
166-
expr_scalarizer_t scalarizer(elems, i, vec_vars_);
167-
expr_t value = scalarizer.mutate(new_obj);
168-
int64_t lo = bound_finder_.find_low_bound(value);
169-
int64_t hi = bound_finder_.find_high_bound(value);
170-
bool ok = bound_finder_base_t::is_good_bound(lo)
171-
&& bound_finder_base_t::is_good_bound(hi);
172-
if (ok) {
173-
int64_t type_lo = value.type().is_s32()
174-
? (int64_t)std::numeric_limits<int32_t>::min()
175-
: (int64_t)std::numeric_limits<uint32_t>::min();
176-
int64_t type_hi = value.type().is_s32()
177-
? (int64_t)std::numeric_limits<int32_t>::max()
178-
: (int64_t)std::numeric_limits<uint32_t>::max();
179-
180-
bool is_overflow = (lo < type_lo || hi > type_hi);
181-
if (is_overflow) {
182-
found_overflow = true;
183-
ir_warning() << "Found overflow: " << value
184-
<< " low bound: " << lo
185-
<< " high bound: " << hi << std::endl;
186-
break;
187-
}
188-
}
189-
}
190-
if (found_overflow) return fix_overflow(new_obj);
191-
return std::move(new_obj);
192-
}
193-
194-
bool contains_load(const expr_t &e) const {
195-
if (!find_objects<load_t>(e).empty()) return true;
196-
for (auto &v : find_objects<var_t>(e)) {
197-
if (vars_with_load_.count(v) != 0) return true;
198-
}
199-
return false;
200-
}
201-
202251
// Returns the variable standing for element `idx` of vector variable `_var`.
// A single-element variable is returned as is; otherwise a new scalar
// variable named "<name>_<idx>_" is created.
static expr_t make_vec_var(const expr_t &_var, int elems, int idx) {
    const bool is_scalar = (elems == 1);
    if (is_scalar) return _var;

    auto &vec = _var.as<var_t>();
    auto elem_name = vec.name + "_" + std::to_string(idx) + "_";
    auto elem_type = vec.type.scalar();
    return var_t::make(elem_type, elem_name);
}
208257

209-
static expr_t fix_overflow(const expr_t &e) {
210-
auto *binary = e.as_ptr<binary_op_t>();
211-
if (binary) {
212-
return binary_op_t::make(binary->op_kind,
213-
cast(binary->a, type_t::u64(e.type().elems())), binary->b);
214-
}
215-
216-
ir_error_not_expected() << "Can't fix overflow: " << e;
217-
return e;
218-
}
219-
220258
ir_context_t &ir_ctx_;
221-
overflow_bound_finder_t bound_finder_;
222-
object_map_t<expr_t, std::vector<expr_t>> vec_vars_;
223-
object_set_t<expr_t> vars_with_load_;
259+
overflow_context_t ctx_;
224260
};
225261

226262
stmt_t fix_int32_overflow(const stmt_t &s, ir_context_t &ir_ctx) {

src/gpu/jit/pass/simplify.cpp

Lines changed: 25 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1497,6 +1497,30 @@ class common_factor_simplifier_t : public nary_op_mutator_t {
14971497
}
14981498
};
14991499

1500+
// Reorders the arguments of an N-ary addition so that all 64-bit arguments are
// grouped either before (x64_first == true) or after (x64_first == false) the
// remaining arguments. The relative order inside each group is preserved.
//
// Returns `e` unchanged when it is not an N-ary operation, is not an addition,
// has two or fewer arguments, or does not mix 64-bit and other arguments.
expr_t reorder_nary_add_args(const expr_t &e, bool x64_first) {
    auto *nary_op = e.as_ptr<nary_op_t>();
    // This is a free function, so `e` is not guaranteed to be an N-ary op
    // (e.g. it may be the result of canonicalizing an arbitrary expression).
    // Guard against a null pointer instead of dereferencing it.
    if (!nary_op) return e;
    if (nary_op->op_kind != op_kind_t::_add || nary_op->args.size() <= 2)
        return e;

    std::vector<expr_t> other_args;
    std::vector<expr_t> x64_args;
    for (auto &a : nary_op->args) {
        if (a.type().is_x64()) {
            x64_args.push_back(a);
        } else {
            other_args.push_back(a);
        }
    }

    // Nothing to reorder unless both kinds of arguments are present.
    if (other_args.empty() || x64_args.empty()) return e;

    const auto &head = x64_first ? x64_args : other_args;
    const auto &tail = x64_first ? other_args : x64_args;

    std::vector<expr_t> new_args;
    new_args.reserve(head.size() + tail.size());
    new_args.insert(new_args.end(), head.begin(), head.end());
    new_args.insert(new_args.end(), tail.begin(), tail.end());

    return nary_op_t::make(nary_op->op_kind, new_args);
}
1523+
15001524
// Rewrites addition with mixed 64-bit/32-bit expressions to reduce 64-bit
15011525
// arithmetic. Example:
15021526
// Before: ((x.s64 + y.s32) + z.s32) [two 64-bit add]
@@ -1505,26 +1529,7 @@ class _64_bit_add_optimizer_t : public nary_op_mutator_t {
15051529
public:
15061530
object_t _mutate(const nary_op_t &obj) override {
15071531
auto new_obj = nary_op_mutator_t::_mutate(obj);
1508-
auto *nary_op = new_obj.as_ptr<nary_op_t>();
1509-
if (nary_op->op_kind != op_kind_t::_add || nary_op->args.size() <= 2)
1510-
return new_obj;
1511-
1512-
std::vector<expr_t> other_args;
1513-
std::vector<expr_t> x64_args;
1514-
for (auto &a : nary_op->args) {
1515-
if (a.type().is_x64()) {
1516-
x64_args.push_back(a);
1517-
} else {
1518-
other_args.push_back(a);
1519-
}
1520-
}
1521-
1522-
if (other_args.empty() || x64_args.empty()) return new_obj;
1523-
1524-
std::vector<expr_t> new_args = std::move(other_args);
1525-
new_args.insert(new_args.end(), x64_args.begin(), x64_args.end());
1526-
1527-
return nary_op_t::make(nary_op->op_kind, new_args);
1532+
return reorder_nary_add_args(new_obj, /*x64_first=*/false);
15281533
}
15291534
};
15301535

0 commit comments

Comments
 (0)