Skip to content

Commit 12d5743

Browse files
committed
gpu: jit: do not rewrite 64-bit exprs after overflow fix pass
1 parent 31ac0e0 commit 12d5743

File tree

5 files changed

+60
-14
lines changed

5 files changed

+60
-14
lines changed

src/gpu/jit/conv/ir_builder.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -755,6 +755,7 @@ void conv_ir_builder_t::build() {
755755
cfg_.reserved_regs());
756756
stmt_ = split_shuffle(stmt_, ir_ctx);
757757
stmt_ = fixup_if_conditions(stmt_, ir_ctx);
758+
stmt_ = optimize_int64_exprs(stmt_, ir_ctx);
758759
stmt_ = fix_int32_overflow(stmt_, ir_ctx);
759760
stmt_ = eliminate_common_subexprs(
760761
stmt_, ir_ctx, cfg_.reserved_regs(), cfg_.slm().gmem_bufs());

src/gpu/jit/pass/pass.cpp

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
#include "gpu/jit/ir/message.hpp"
2020
#include "gpu/jit/ir/reorder.hpp"
21+
#include "gpu/jit/pass/simplify.hpp"
2122
#include "gpu/jit/utils/trace.hpp"
2223

2324
namespace dnnl {
@@ -176,6 +177,35 @@ stmt_t fixup_if_conditions(const stmt_t &s, ir_context_t &ir_ctx) {
176177
return ret;
177178
}
178179

180+
// IR mutator that passes every binary addition it produces through
// simplify_64_bit_add(). All other objects are returned exactly as the base
// ir_mutator_t produced them.
class int64_expr_optimizer_t : public ir_mutator_t {
181+
public:
182+
// Generates one _mutate() override per type listed by
// HANDLE_EXPR_IR_OBJECTS(), each forwarding to mutate_expr().
// NOTE(review): HANDLE_EXPR_IR_OBJECTS is defined elsewhere; presumably it
// enumerates the expression IR object types - confirm.
#define HANDLE_IR_OBJECT(type) \
183+
object_t _mutate(const type &obj) override { return mutate_expr(obj); }
184+
185+
HANDLE_EXPR_IR_OBJECTS()
186+
187+
#undef HANDLE_IR_OBJECT
188+
189+
private:
190+
// Shared implementation behind all generated _mutate() overrides.
// Returns the (possibly rewritten) object.
template <typename T>
191+
object_t mutate_expr(const T &obj) {
192+
// Run the base-class mutation first so the rewrite below is applied to
// the already-mutated result rather than to the original object.
auto new_obj = ir_mutator_t::_mutate(obj);
193+
// `.template` is required: new_obj's type depends on T, so as_ptr is a
// dependent member template.
if (auto *binary = new_obj.template as_ptr<binary_op_t>()) {
194+
// Only additions are rewritten; other binary operations pass through.
if (binary->op_kind == op_kind_t::_add) {
195+
new_obj = simplify_64_bit_add(new_obj);
196+
}
197+
}
198+
return new_obj;
199+
}
200+
};
201+
202+
// Pass entry point: applies int64_expr_optimizer_t to statement `s` and
// returns the mutated statement. `ir_ctx` is only forwarded to trace_pass().
// In conv_ir_builder_t::build() this pass is invoked immediately before
// fix_int32_overflow().
stmt_t optimize_int64_exprs(const stmt_t &s, ir_context_t &ir_ctx) {
203+
trace_start();
204+
auto ret = int64_expr_optimizer_t().mutate(s);
205+
// Report the result under the pass name "optimize_int64_exprs".
trace_pass("optimize_int64_exprs", ret, ir_ctx);
206+
return ret;
207+
}
208+
179209
} // namespace jit
180210
} // namespace gpu
181211
} // namespace impl

src/gpu/jit/pass/pass.hpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,10 @@ stmt_t split_wide_stores(const stmt_t &s, ir_context_t &ir_ctx);
5757
// if (bcast8(cond)) { ... }
5858
stmt_t fixup_if_conditions(const stmt_t &s, ir_context_t &ir_ctx);
5959

60+
// Rewrites mixed 64-bit/32-bit expressions to reduce 64-bit arithmetic.
61+
// Potential overflow is ignored and must be checked/fixed by further passes.
62+
stmt_t optimize_int64_exprs(const stmt_t &s, ir_context_t &ir_ctx);
63+
6064
} // namespace jit
6165
} // namespace gpu
6266
} // namespace impl

src/gpu/jit/pass/simplify.cpp

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1521,18 +1521,6 @@ expr_t reorder_nary_add_args(const expr_t &e, bool x64_first) {
15211521
return nary_op_t::make(nary_op->op_kind, new_args);
15221522
}
15231523

1524-
// Rewrites addition with mixed 64-bit/32-bit expressions to reduce 64-bit
1525-
// arithmetic. Example:
1526-
// Before: ((x.s64 + y.s32) + z.s32) [two 64-bit add]
1527-
// After: ((y.s32 + z.s32) + x.s64) [one 32-bit add and one 64-bit add]
1528-
class _64_bit_add_optimizer_t : public nary_op_mutator_t {
1529-
public:
1530-
object_t _mutate(const nary_op_t &obj) override {
1531-
auto new_obj = nary_op_mutator_t::_mutate(obj);
1532-
return reorder_nary_add_args(new_obj, /*x64_first=*/false);
1533-
}
1534-
};
1535-
15361524
// Simplifies using the N-ary form.
15371525
expr_t simplify_with_nary(const expr_t &_e, const constraint_set_t &cset) {
15381526
auto e = _e;
@@ -1545,13 +1533,30 @@ expr_t simplify_with_nary(const expr_t &_e, const constraint_set_t &cset) {
15451533
e = int_div_mod_expander_t(cset).mutate(e);
15461534
e = common_factor_simplifier_t().mutate(e);
15471535
e = int_div_mod_range_simplifier_t(cset).mutate(e);
1548-
e = _64_bit_add_optimizer_t().mutate(e);
15491536

15501537
e = nary_op_back_transform(e);
15511538

15521539
return e;
15531540
}
15541541

1542+
// N-ary mutator used by simplify_64_bit_add(): after the base-class mutation
// of an nary_op_t, its arguments are reordered via reorder_nary_add_args()
// with x64_first = false.
// NOTE(review): presumably this places 32-bit terms ahead of 64-bit ones so
// they are summed in 32-bit arithmetic first, e.g.
//   ((x.s64 + y.s32) + z.s32) -> ((y.s32 + z.s32) + x.s64)
// - confirm against reorder_nary_add_args().
class _64_bit_add_optimizer_t : public nary_op_mutator_t {
1543+
public:
1544+
object_t _mutate(const nary_op_t &obj) override {
1545+
// Mutate via the base class first, then reorder the result.
auto new_obj = nary_op_mutator_t::_mutate(obj);
1546+
return reorder_nary_add_args(new_obj, /*x64_first=*/false);
1547+
}
1548+
};
1549+
1550+
// Applies the 64-bit add reordering to `_e` and returns the rewritten
// expression; the input is not modified (work is done on a copy).
// Steps: nary_op_canonicalize() -> _64_bit_add_optimizer_t ->
// nary_op_back_transform().
// NOTE(review): the first and last steps presumably convert to and from the
// N-ary form that the optimizer operates on - confirm.
expr_t simplify_64_bit_add(const expr_t &_e) {
1551+
auto e = _e;
1552+
1553+
e = nary_op_canonicalize(e);
1554+
e = _64_bit_add_optimizer_t().mutate(e);
1555+
e = nary_op_back_transform(e);
1556+
1557+
return e;
1558+
}
1559+
15551560
class stmt_simplifier_t : public ir_mutator_t {
15561561
public:
15571562
stmt_simplifier_t(const constraint_set_t &cset) : cset_(cset) {}

src/gpu/jit/pass/simplify.hpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*******************************************************************************
2-
* Copyright 2022 Intel Corporation
2+
* Copyright 2022-2023 Intel Corporation
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -43,6 +43,12 @@ expr_t simplify_rewrite_with_ternary(const expr_t &e, bool recursive = true);
4343
// Example: (c0 + x) op c1 -> x op (c1 - c0)
4444
expr_t simplify_cmp_move_const_to_rhs(const expr_t &e);
4545

46+
// Rewrites addition with mixed 64-bit/32-bit expressions to reduce 64-bit
47+
// arithmetic. Example:
48+
// Before: ((x.s64 + y.s32) + z.s32) [two 64-bit add]
49+
// After: ((y.s32 + z.s32) + x.s64) [one 32-bit add and one 64-bit add]
50+
expr_t simplify_64_bit_add(const expr_t &e);
51+
4652
// Reduces left and right hand sides of an expression.
4753
// Example: A * x < A * B -> x < B (if A > 0).
4854
expr_t simplify_cmp_reduce_lhs_rhs(const expr_t &e);

0 commit comments

Comments
 (0)