Skip to content

Commit c46ad16

Browse files
Ryker0627assistant-librarian[bot]
authored andcommitted
[rocm-libraries] ROCm/rocm-libraries#4489 (commit 56d5c20)
[hipTensor] Add IDs of newly added contraction instances with unary ops to Actor-Critic model (#4489) ## Motivation Add IDs of newly added contraction instances with unary ops to Actor-Critic model and add new contraction instances with unary ops for f16-input-f16-compute and bf16-input-bf16-compute contractions ## Technical Details 1. Since Estevan has added f16-input-f16-compute and bf16-input-bf16-compute contraction instances (supported by gfx11 and gfx12 architectures) in CK, add new contraction instances with unary ops for these two combinations in hipTensor as well. 2. Add IDs of newly added contraction instances with unary ops to Actor-Critic model. These IDs are best-performance instances selected from benchmark tests results on MI300X GPU and Navi4x GPU. ## Test Plan run regular tests. ## Test Result all tests are passed. ## Submission Checklist - [ ] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
1 parent e41a07d commit c46ad16

File tree

24 files changed

+4604
-530
lines changed

24 files changed

+4604
-530
lines changed

library/src/contraction/contraction_selection.cpp

Lines changed: 2390 additions & 0 deletions
Large diffs are not rendered by default.

library/src/contraction/contraction_selection.hpp

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,4 +108,56 @@ namespace hiptensor
108108
hiptensorComputeDescriptor_t computeType,
109109
const uint64_t workspaceSize);
110110

111+
template <typename A,
112+
typename B,
113+
typename C,
114+
typename D,
115+
ContractionOpId_t ContractionOp,
116+
typename ComputeType>
117+
struct ActorCriticSelectionUnaryOps
118+
{
119+
static hiptensorStatus_t
120+
selectWinner(ContractionSolution** winner,
121+
std::unordered_map<size_t, ContractionSolution*> const& candidates,
122+
hiptensorDataType_t typeA,
123+
std::vector<std::size_t> const& a_ms_ks_lengths,
124+
std::vector<std::size_t> const& a_ms_ks_strides,
125+
std::vector<int32_t> const& a_ms_ks_modes,
126+
hiptensorDataType_t typeB,
127+
std::vector<std::size_t> const& b_ns_ks_lengths,
128+
std::vector<std::size_t> const& b_ns_ks_strides,
129+
std::vector<int32_t> const& b_ns_ks_modes,
130+
hiptensorDataType_t typeD,
131+
std::vector<std::size_t> const& d_ms_ns_lengths,
132+
std::vector<std::size_t> const& d_ms_ns_strides,
133+
std::vector<int32_t> const& d_ms_ns_modes,
134+
hiptensorDataType_t typeE,
135+
std::vector<std::size_t> const& e_ms_ns_lengths,
136+
std::vector<std::size_t> const& e_ms_ns_strides,
137+
std::vector<int32_t> const& e_ms_ns_modes,
138+
const uint64_t workspaceSize);
139+
};
140+
141+
hiptensorStatus_t
142+
actorCriticModelUnaryOps(ContractionSolution** winner,
143+
std::unordered_map<size_t, ContractionSolution*> const& candidates,
144+
hiptensorDataType_t typeA,
145+
std::vector<std::size_t> const& a_ms_ks_lengths,
146+
std::vector<std::size_t> const& a_ms_ks_strides,
147+
std::vector<int32_t> const& a_ms_ks_modes,
148+
hiptensorDataType_t typeB,
149+
std::vector<std::size_t> const& b_ns_ks_lengths,
150+
std::vector<std::size_t> const& b_ns_ks_strides,
151+
std::vector<int32_t> const& b_ns_ks_modes,
152+
hiptensorDataType_t typeD,
153+
std::vector<std::size_t> const& d_ms_ns_lengths,
154+
std::vector<std::size_t> const& d_ms_ns_strides,
155+
std::vector<int32_t> const& d_ms_ns_modes,
156+
hiptensorDataType_t typeE,
157+
std::vector<std::size_t> const& e_ms_ns_lengths,
158+
std::vector<std::size_t> const& e_ms_ns_strides,
159+
std::vector<int32_t> const& e_ms_ns_modes,
160+
hiptensorComputeDescriptor_t computeType,
161+
const uint64_t workspaceSize);
162+
111163
} // namespace hiptensor

library/src/contraction/contraction_solution_impl.hpp

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -169,11 +169,9 @@ namespace hiptensor
169169
= applyCKColMajorStridesOptimizationForContraction(normal_e_ms_ns_lengths);
170170

171171
// Initialize the argument pointer
172-
if constexpr(std::
173-
is_same_v<typename Traits::AOp,
174-
ck::
175-
tensor_operation::
176-
element_wise::PassThrough> && std::is_same_v<typename Traits::BOp, ck::tensor_operation::element_wise::PassThrough> && (std::is_same_v<typename Traits::CDEOp, ck::tensor_operation::element_wise::Bilinear> || std::is_same_v<typename Traits::CDEOp, ck::tensor_operation::element_wise::BilinearComplex>))
172+
if constexpr(std::is_same_v<typename Traits::AOp, ck::tensor_operation::element_wise::PassThrough>
173+
&& std::is_same_v<typename Traits::BOp, ck::tensor_operation::element_wise::PassThrough>
174+
&&(std::is_same_v<typename Traits::CDEOp, ck::tensor_operation::element_wise::Bilinear> || std::is_same_v<typename Traits::CDEOp, ck::tensor_operation::element_wise::BilinearComplex>))
177175
{
178176
Base::mInvokerArgPtr = std::move(deviceOp->MakeArgumentPointer(
179177
A,

library/src/contraction/contraction_solution_instances.cpp

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,18 @@ namespace hiptensor
6060
ck::tensor_operation::element_wise::Bilinear,
6161
ck::bhalf_t>());
6262

63+
registerSolutions(enumerateContractionSolutions<6,
64+
6,
65+
6,
66+
ck::bhalf_t,
67+
ck::bhalf_t,
68+
ck::Tuple<ck::bhalf_t>,
69+
ck::bhalf_t,
70+
CkHiptensorUnaryOp,
71+
CkHiptensorUnaryOp,
72+
CkBilinearUnary,
73+
ck::bhalf_t>());
74+
6375
registerSolutions(
6476
enumerateContractionSolutions<6,
6577
6,
@@ -99,6 +111,18 @@ namespace hiptensor
99111
ck::tensor_operation::element_wise::Bilinear,
100112
ck::half_t>());
101113

114+
registerSolutions(enumerateContractionSolutions<6,
115+
6,
116+
6,
117+
ck::half_t,
118+
ck::half_t,
119+
ck::Tuple<ck::half_t>,
120+
ck::half_t,
121+
CkHiptensorUnaryOp,
122+
CkHiptensorUnaryOp,
123+
CkBilinearUnary,
124+
ck::half_t>());
125+
102126
registerSolutions(
103127
enumerateContractionSolutions<6,
104128
6,
@@ -293,6 +317,18 @@ namespace hiptensor
293317
ck::tensor_operation::element_wise::Scale,
294318
ck::bhalf_t>());
295319

320+
registerSolutions(enumerateContractionSolutions<6,
321+
6,
322+
6,
323+
ck::bhalf_t,
324+
ck::bhalf_t,
325+
ck::Tuple<>,
326+
ck::bhalf_t,
327+
CkHiptensorUnaryOp,
328+
CkHiptensorUnaryOp,
329+
ck::tensor_operation::element_wise::Scale,
330+
ck::bhalf_t>());
331+
296332
registerSolutions(
297333
enumerateContractionSolutions<6,
298334
6,
@@ -332,6 +368,18 @@ namespace hiptensor
332368
ck::tensor_operation::element_wise::Scale,
333369
ck::half_t>());
334370

371+
registerSolutions(enumerateContractionSolutions<6,
372+
6,
373+
6,
374+
ck::half_t,
375+
ck::half_t,
376+
ck::Tuple<>,
377+
ck::half_t,
378+
CkHiptensorUnaryOp,
379+
CkHiptensorUnaryOp,
380+
ck::tensor_operation::element_wise::Scale,
381+
ck::half_t>());
382+
335383
registerSolutions(
336384
enumerateContractionSolutions<6,
337385
6,

library/src/contraction/device/CMakeLists.txt

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,14 @@
141141
${CMAKE_CURRENT_SOURCE_DIR}/device_contraction_bilinear_unary_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_knnn_instance.cpp
142142
${CMAKE_CURRENT_SOURCE_DIR}/device_contraction_bilinear_unary_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_mknn_instance.cpp
143143
${CMAKE_CURRENT_SOURCE_DIR}/device_contraction_bilinear_unary_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_mnnn_instance.cpp
144+
${CMAKE_CURRENT_SOURCE_DIR}/device_contraction_bilinear_unary_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_f16_kknn_instance.cpp
145+
${CMAKE_CURRENT_SOURCE_DIR}/device_contraction_bilinear_unary_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_f16_knnn_instance.cpp
146+
${CMAKE_CURRENT_SOURCE_DIR}/device_contraction_bilinear_unary_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_f16_mknn_instance.cpp
147+
${CMAKE_CURRENT_SOURCE_DIR}/device_contraction_bilinear_unary_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_f16_mnnn_instance.cpp
148+
${CMAKE_CURRENT_SOURCE_DIR}/device_contraction_bilinear_unary_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_bf16_kknn_instance.cpp
149+
${CMAKE_CURRENT_SOURCE_DIR}/device_contraction_bilinear_unary_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_bf16_knnn_instance.cpp
150+
${CMAKE_CURRENT_SOURCE_DIR}/device_contraction_bilinear_unary_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_bf16_mknn_instance.cpp
151+
${CMAKE_CURRENT_SOURCE_DIR}/device_contraction_bilinear_unary_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_bf16_mnnn_instance.cpp
144152
${CMAKE_CURRENT_SOURCE_DIR}/device_contraction_scale_unary_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_kkn_instance.cpp
145153
${CMAKE_CURRENT_SOURCE_DIR}/device_contraction_scale_unary_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_knn_instance.cpp
146154
${CMAKE_CURRENT_SOURCE_DIR}/device_contraction_scale_unary_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_mkn_instance.cpp
@@ -169,6 +177,14 @@
169177
${CMAKE_CURRENT_SOURCE_DIR}/device_contraction_scale_unary_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_knn_instance.cpp
170178
${CMAKE_CURRENT_SOURCE_DIR}/device_contraction_scale_unary_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_mkn_instance.cpp
171179
${CMAKE_CURRENT_SOURCE_DIR}/device_contraction_scale_unary_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_mnn_instance.cpp
180+
${CMAKE_CURRENT_SOURCE_DIR}/device_contraction_scale_unary_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_kkn_instance.cpp
181+
${CMAKE_CURRENT_SOURCE_DIR}/device_contraction_scale_unary_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_knn_instance.cpp
182+
${CMAKE_CURRENT_SOURCE_DIR}/device_contraction_scale_unary_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_mkn_instance.cpp
183+
${CMAKE_CURRENT_SOURCE_DIR}/device_contraction_scale_unary_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_mnn_instance.cpp
184+
${CMAKE_CURRENT_SOURCE_DIR}/device_contraction_scale_unary_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_kkn_instance.cpp
185+
${CMAKE_CURRENT_SOURCE_DIR}/device_contraction_scale_unary_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_knn_instance.cpp
186+
${CMAKE_CURRENT_SOURCE_DIR}/device_contraction_scale_unary_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_mkn_instance.cpp
187+
${CMAKE_CURRENT_SOURCE_DIR}/device_contraction_scale_unary_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_mnn_instance.cpp
172188
)
173189

174190
add_hiptensor_component(hiptensor_contraction_instances ${CK_CONTRACTION_INSTANCE_SOURCES})
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
/*******************************************************************************
2+
*
3+
* MIT License
4+
*
5+
* Copyright (C) 2023-2026 Advanced Micro Devices, Inc. All rights reserved.
6+
*
7+
* Permission is hereby granted, free of charge, to any person obtaining a copy
8+
* of this software and associated documentation files (the "Software"), to deal
9+
* in the Software without restriction, including without limitation the rights
10+
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11+
* copies of the Software, and to permit persons to whom the Software is
12+
* furnished to do so, subject to the following conditions:
13+
*
14+
* The above copyright notice and this permission notice shall be included in
15+
* all copies or substantial portions of the Software.
16+
*
17+
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18+
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19+
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20+
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21+
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22+
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
23+
* THE SOFTWARE.
24+
*
25+
*******************************************************************************/
26+
27+
#include <ck/ck.hpp>
28+
#include <ck/library/tensor_operation_instance/add_device_operation_instance.hpp>
29+
#include <ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp>
30+
#include <ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp>
31+
#include <ck/tensor_operation/gpu/element/element_wise_operation.hpp>
32+
33+
#include "hiptensor_ck_types.hpp"
34+
35+
namespace ck {
36+
namespace tensor_operation {
37+
namespace device {
38+
namespace instance {
39+
40+
// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1]
41+
// k/k/n/n are the fast changing dimension for A/B/D/E
42+
using device_contraction_bilinear_unary_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_bf16_kknn_instance =
43+
device_contraction_kk_instance<BF16,
44+
BF16,
45+
F32,
46+
BF16,
47+
BF16_Tuple,
48+
BF16,
49+
BF16,
50+
hiptensor::CkHiptensorUnaryOp,
51+
hiptensor::CkHiptensorUnaryOp,
52+
hiptensor::CkBilinearUnary,
53+
6>;
54+
55+
void add_device_contraction_bilinear_unary_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_bf16_kknn_instance(
56+
std::vector<std::unique_ptr<DeviceContractionMultipleD<6,
57+
6,
58+
6,
59+
BF16,
60+
BF16,
61+
BF16_Tuple,
62+
BF16,
63+
hiptensor::CkHiptensorUnaryOp,
64+
hiptensor::CkHiptensorUnaryOp,
65+
hiptensor::CkBilinearUnary,
66+
BF16>>>& instances)
67+
{
68+
add_device_operation_instances(
69+
instances,
70+
device_contraction_bilinear_unary_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_bf16_kknn_instance{});
71+
}
72+
73+
} // namespace instance
74+
} // namespace device
75+
} // namespace tensor_operation
76+
} // namespace ck

0 commit comments

Comments
 (0)