Skip to content

Commit 4024775

Browse files
committed
gpu: jit: conv: work around MSVC bug
1 parent 23576f9 commit 4024775

File tree

1 file changed

+24
-9
lines changed

1 file changed

+24
-9
lines changed

src/gpu/jit/conv/config.hpp

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*******************************************************************************
2-
* Copyright 2021-2022 Intel Corporation
2+
* Copyright 2021-2023 Intel Corporation
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -1309,14 +1309,29 @@ class bmnk_dim_helper_t {
13091309
static const char *bwd_w_n_dims[] = {"oc", nullptr};
13101310
static const char *bwd_w_k_dims[] = {"mb", "od", "oh", "ow", nullptr};
13111311

1312-
const char **b_dims = prb_.pick_by_dir<const char **>(
1313-
fwd_b_dims, bwd_d_b_dims, bwd_w_b_dims);
1314-
const char **m_dims = prb_.pick_by_dir<const char **>(
1315-
fwd_m_dims, bwd_d_m_dims, bwd_w_m_dims);
1316-
const char **n_dims = prb_.pick_by_dir<const char **>(
1317-
fwd_n_dims, bwd_d_n_dims, bwd_w_n_dims);
1318-
const char **k_dims = prb_.pick_by_dir<const char **>(
1319-
fwd_k_dims, bwd_d_k_dims, bwd_w_k_dims);
1312+
// XXX: Do not use pick_by_dir() to work around MSVC compiler bug.
1313+
const char **b_dims = nullptr;
1314+
const char **m_dims = nullptr;
1315+
const char **n_dims = nullptr;
1316+
const char **k_dims = nullptr;
1317+
if (prb_.is_fwd) {
1318+
b_dims = fwd_b_dims;
1319+
m_dims = fwd_m_dims;
1320+
n_dims = fwd_n_dims;
1321+
k_dims = fwd_k_dims;
1322+
} else if (prb_.is_bwd_d) {
1323+
b_dims = bwd_d_b_dims;
1324+
m_dims = bwd_d_m_dims;
1325+
n_dims = bwd_d_n_dims;
1326+
k_dims = bwd_d_k_dims;
1327+
} else if (prb_.is_bwd_w) {
1328+
b_dims = bwd_w_b_dims;
1329+
m_dims = bwd_w_m_dims;
1330+
n_dims = bwd_w_n_dims;
1331+
k_dims = bwd_w_k_dims;
1332+
} else {
1333+
ir_error_not_expected();
1334+
}
13201335

13211336
if (contains(b_dims, dim_name)) return 'b';
13221337
if (contains(m_dims, dim_name)) return 'm';

0 commit comments

Comments
 (0)