Skip to content

Commit 2d0b31e

Browse files
committed
gpu: jit: conv: fix performance with OHWI weights
1 parent c616453 commit 2d0b31e

File tree

1 file changed

+10
-3
lines changed

1 file changed

+10
-3
lines changed

src/gpu/jit/conv/config.cpp

Lines changed: 10 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -592,11 +592,18 @@ bool can_use_2d_send(const conv_config_t &cfg, const layout_t &l, bool is_a) {
592592
// 2D messages do not support the VNNI format with 4-byte elements
593593
if (type_t(prb.b_data_type).size() >= 4) return false;
594594

595+
auto is_plain_wei_ok = [&]() {
596+
if (l.is_empty()) return true;
597+
for (auto *t : {"xba", "xab", "axb"}) {
598+
if (matches_tag_strict(l, t)) return true;
599+
}
600+
return false;
601+
};
602+
595603
auto is_plain_ok = [&]() {
596604
if (is_a || prb.is_bwd_w) return matches_tag_strict(l, "axb");
597-
if (is_b && l.is_empty()) return true;
598-
if (is_b && prb.is_fwd) return matches_tag_strict(l, "xba");
599-
if (is_b && prb.is_bwd_d) return matches_tag_strict(l, "xab");
605+
bool is_wei = (is_b && prb.is_fwd) || (is_b && prb.is_bwd_d);
606+
if (is_wei) return is_plain_wei_ok();
600607
return false;
601608
};
602609

0 commit comments

Comments
 (0)