Skip to content

Commit 4c82561

Browse files
authored
Merge branch 'main' into atom_tests
2 parents d1a9bd9 + 3b2346f commit 4c82561

File tree

2 files changed

+39
-37
lines changed

2 files changed

+39
-37
lines changed

aiter/ops/rmsnorm.py

Lines changed: 21 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -82,18 +82,15 @@ def rmsnorm2d_fwd_with_add(
8282
epsilon: float,
8383
use_model_sensitive_rmsnorm: int = 0,
8484
) -> None:
85-
if use_model_sensitive_rmsnorm > 0 or input.shape[-1] > 8192:
86-
rmsnorm2d_fwd_with_add_ck(
87-
out,
88-
input,
89-
residual_in,
90-
residual_out,
91-
weight,
92-
epsilon,
93-
use_model_sensitive_rmsnorm,
94-
)
95-
else:
96-
add_rmsnorm(out, input, residual_in, residual_out, weight, epsilon)
85+
rmsnorm2d_fwd_with_add_ck(
86+
out,
87+
input,
88+
residual_in,
89+
residual_out,
90+
weight,
91+
epsilon,
92+
use_model_sensitive_rmsnorm,
93+
)
9794

9895

9996
@compile_ops("module_rmsnorm")
@@ -155,31 +152,18 @@ def rmsnorm2d_fwd_with_add_dynamicquant(
155152
group_size: int = 0,
156153
shuffle_scale: bool = False,
157154
) -> None:
158-
if use_model_sensitive_rmsnorm > 0 or input.shape[-1] > 8192:
159-
assert group_size == 0, "group_size is not supported for ck rmsnorm"
160-
assert not shuffle_scale, "shuffle_scale is not supported for ck rmsnorm"
161-
rmsnorm2d_fwd_with_add_dynamicquant_ck(
162-
out,
163-
input,
164-
residual_in,
165-
residual_out,
166-
yscale,
167-
weight,
168-
epsilon,
169-
use_model_sensitive_rmsnorm,
170-
)
171-
else:
172-
add_rmsnorm_quant(
173-
out,
174-
input,
175-
residual_in,
176-
residual_out,
177-
yscale,
178-
weight,
179-
epsilon,
180-
group_size,
181-
shuffle_scale,
182-
)
155+
assert group_size == 0, "group_size is not supported for ck rmsnorm"
156+
assert not shuffle_scale, "shuffle_scale is not supported for ck rmsnorm"
157+
rmsnorm2d_fwd_with_add_dynamicquant_ck(
158+
out,
159+
input,
160+
residual_in,
161+
residual_out,
162+
yscale,
163+
weight,
164+
epsilon,
165+
use_model_sensitive_rmsnorm,
166+
)
183167

184168

185169
@compile_ops(

csrc/include/opus/opus.hpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -837,7 +837,9 @@ REGISTER_DTYPE(bf8 , unsigned _BitInt(8))
837837
REGISTER_DTYPE(i32 , int32_t)
838838
REGISTER_DTYPE(u32 , uint32_t)
839839
REGISTER_DTYPE(i16 , int16_t)
840+
#if __clang_major__ >= 20
840841
REGISTER_DTYPE(u16 , uint16_t)
842+
#endif
841843
REGISTER_DTYPE(i8 , int8_t)
842844
REGISTER_DTYPE(u8 , uint8_t)
843845

@@ -1163,6 +1165,12 @@ OPUS_D constexpr auto buffer_default_config() {
11631165
OPUS_D __amdgpu_buffer_rsrc_t make_buffer_rsrc(const void* ptr, uint32_t size = 0xffffffff, uint32_t config = buffer_default_config()) {
11641166
return __builtin_amdgcn_make_buffer_rsrc(const_cast<void*>(static_cast<const void*>(ptr)), 0, size, config); // void *p, short stride, int num, int flags
11651167
}
1168+
#if __clang_major__ < 20
1169+
#pragma clang diagnostic push
1170+
#pragma clang diagnostic ignored "-Wundefined-inline"
1171+
OPUS_D void llvm_amdgcn_raw_buffer_load_lds(i32x4_t r, __attribute__((address_space(3))) uint32_t* p, index_t size, index_t vos, index_t sos, index_t ios, index_t aux) __asm("llvm.amdgcn.raw.buffer.load.lds");
1172+
#pragma clang diagnostic pop
1173+
#endif
11661174
template<typename T_>
11671175
struct gmem {
11681176
using T = remove_cvref_t<T_>;
@@ -1193,6 +1201,16 @@ struct gmem {
11931201
else if constexpr (sizeof(type) == 12) { __builtin_amdgcn_raw_ptr_buffer_load_lds(cached_rsrc, dst, 12, v_os, s_os, 0, aux); }
11941202
else if constexpr (sizeof(type) == 16) { __builtin_amdgcn_raw_ptr_buffer_load_lds(cached_rsrc, dst, 16, v_os, s_os, 0, aux); }
11951203
#endif
1204+
#else
1205+
i32x4_t cached_rsrc_;
1206+
__builtin_memcpy(&cached_rsrc_, &cached_rsrc, sizeof(i32x4_t)); // builtin memcpy, __builtin_bit_cast() can not use here due to __amdgpu_buffer_rsrc_t is non copyable
1207+
if constexpr (sizeof(type) == 1) {llvm_amdgcn_raw_buffer_load_lds(cached_rsrc_, reinterpret_cast<__attribute__((address_space(3))) u32_t*>(reinterpret_cast<unsigned long int>(dst)), 1, v_os, s_os, 0, aux); }
1208+
else if constexpr (sizeof(type) == 2) {llvm_amdgcn_raw_buffer_load_lds(cached_rsrc_, reinterpret_cast<__attribute__((address_space(3))) u32_t*>(reinterpret_cast<unsigned long int>(dst)), 2, v_os, s_os, 0, aux); }
1209+
else if constexpr (sizeof(type) == 4) {llvm_amdgcn_raw_buffer_load_lds(cached_rsrc_, reinterpret_cast<__attribute__((address_space(3))) u32_t*>(reinterpret_cast<unsigned long int>(dst)), 4, v_os, s_os, 0, aux); }
1210+
#if defined(__gfx950__)
1211+
else if constexpr (sizeof(type) == 12) {llvm_amdgcn_raw_buffer_load_lds(cached_rsrc_, reinterpret_cast<__attribute__((address_space(3))) u32_t*>(reinterpret_cast<unsigned long int>(dst)), 12, v_os, s_os, 0, aux); }
1212+
else if constexpr (sizeof(type) == 16) {llvm_amdgcn_raw_buffer_load_lds(cached_rsrc_, reinterpret_cast<__attribute__((address_space(3))) u32_t*>(reinterpret_cast<unsigned long int>(dst)), 16, v_os, s_os, 0, aux); }
1213+
#endif
11961214
#endif
11971215
}
11981216

0 commit comments

Comments
 (0)