Skip to content

Commit 81cdf4c

Browse files
committed
[Cute] Don't need i64_to_f32x2 anymore
1 parent 8c348fd commit 81cdf4c

File tree

1 file changed

+45
-51
lines changed

1 file changed

+45
-51
lines changed

flash_attn/cute/utils.py

Lines changed: 45 additions & 51 deletions
Original file line number | Diff line number | Diff line change
@@ -485,59 +485,53 @@ def cvt_f16(src: cute.Tensor, dst: cute.Tensor):
485485

486486

487487
@dsl_user_op
def e2e_asm2(x: Float32, y: Float32, *, loc=None, ip=None) -> Tuple[Float32, Float32]:
    """Run a packed ``f32x2`` PTX sequence on two scalars and return both lanes.

    The inline assembly, step by step:

    1. Clamp each input from below at ``-127.0`` (``0fC2FE0000``).
    2. Add ``1.5 * 2**23`` (``0f4B400000``) with round-toward-minus-infinity,
       subtract it again, and take the difference from the clamped input as
       the remainder.
    3. Evaluate a cubic in that remainder with Horner's scheme, using
       coefficients of roughly ``0.0771``, ``0.2276``, ``0.6951`` and ``1.0``.
    4. Shift the bits of the biased sum left by 23 and integer-add them to
       the bits of the cubic, lane by lane.

    NOTE(review): this looks like a software approximation of ``2**x`` per
    lane -- confirm against the callers.

    Args:
        x: Value for lane 0; wrapped in ``Float32`` before use.
        y: Value for lane 1; wrapped in ``Float32`` before use.
        loc: Forwarded to ``ir_value`` and ``vector.extract``.
        ip: Forwarded to ``ir_value`` and ``vector.extract``.

    Returns:
        ``(out0, out1)``: elements 0 and 1 of the ``vector<2xf32>`` produced
        by the inline assembly, each wrapped in ``Float32``.
    """
    # Both operands forward loc/ip the same way, through ir_value().
    operands = [
        Float32(x).ir_value(loc=loc, ip=ip),
        Float32(y).ir_value(loc=loc, ip=ip),
    ]
    out_f32x2 = llvm.inline_asm(
        T.vector(2, T.f32()),
        operands,
        "{\n\t"
        ".reg .f32 f1, f2, f3, f4, f5, f6, f7;\n\t"
        ".reg .b64 l1, l2, l3, l4, l5, l6, l7, l8, l9, l10;\n\t"
        ".reg .s32 r1, r2, r3, r4, r5, r6, r7, r8;\n\t"
        # Clamp both inputs from below at -127.0 and pack them into l1.
        "max.ftz.f32 f1, $1, 0fC2FE0000;\n\t"
        "max.ftz.f32 f2, $2, 0fC2FE0000;\n\t"
        "mov.b64 l1, {f1, f2};\n\t"
        # l2 = {1.5 * 2**23, 1.5 * 2**23}; l7 = l1 + l2 rounded down,
        # l8 = l7 - l2, l9 = l1 - l8 (the remainder).
        "mov.f32 f3, 0f4B400000;\n\t"
        "mov.b64 l2, {f3, f3};\n\t"
        "add.rm.ftz.f32x2 l7, l1, l2;\n\t"
        "sub.rn.ftz.f32x2 l8, l7, l2;\n\t"
        "sub.rn.ftz.f32x2 l9, l1, l8;\n\t"
        # Broadcast the four polynomial coefficients into packed registers.
        "mov.f32 f7, 0f3D9DF09D;\n\t"
        "mov.b64 l6, {f7, f7};\n\t"
        "mov.f32 f6, 0f3E6906A4;\n\t"
        "mov.b64 l5, {f6, f6};\n\t"
        "mov.f32 f5, 0f3F31F519;\n\t"
        "mov.b64 l4, {f5, f5};\n\t"
        "mov.f32 f4, 0f3F800000;\n\t"
        "mov.b64 l3, {f4, f4};\n\t"
        # Horner evaluation: ((l6 * l9 + l5) * l9 + l4) * l9 + l3.
        "fma.rn.ftz.f32x2 l10, l9, l6, l5;\n\t"
        "fma.rn.ftz.f32x2 l10, l10, l9, l4;\n\t"
        "fma.rn.ftz.f32x2 l10, l10, l9, l3;\n\t"
        # Unpack both 64-bit values, shift the biased-sum bits left by 23
        # and add them to the polynomial bits, then repack into the output.
        "mov.b64 {r1, r2}, l7;\n\t"
        "mov.b64 {r3, r4}, l10;\n\t"
        "shl.b32 r5, r1, 23;\n\t"
        "add.s32 r7, r5, r3;\n\t"
        "shl.b32 r6, r2, 23;\n\t"
        "add.s32 r8, r6, r4;\n\t"
        "mov.b64 $0, {r7, r8};\n\t"
        "}\n",
        # One output operand ($0) followed by two input operands ($1, $2).
        "=l,f,f",
        has_side_effects=False,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
    )
    out0 = Float32(
        vector.extract(out_f32x2, dynamic_position=[], static_position=[0], loc=loc, ip=ip)
    )
    out1 = Float32(
        vector.extract(out_f32x2, dynamic_position=[], static_position=[1], loc=loc, ip=ip)
    )
    return out0, out1
498535

499536

500-
@cute.jit
def e2e_asm2(x: Float32, y: Float32) -> Tuple[Float32, Float32]:
    """Run a packed ``f32x2`` PTX sequence on two scalars and return both lanes.

    The inline assembly clamps each input from below at ``-127.0``, splits
    it using the ``1.5 * 2**23`` add/subtract pair, evaluates a cubic in the
    remainder, and finally adds the biased-sum bits (shifted left by 23) to
    the bits of the cubic.

    The assembly result is requested as ``vector<2xf32>`` and the two
    elements are extracted here, so no ``Int64`` intermediate or separate
    bitcast helper is needed.

    NOTE(review): this looks like a software approximation of ``2**x`` per
    lane -- confirm against the callers.

    Args:
        x: Value for lane 0; wrapped in ``Float32`` before use.
        y: Value for lane 1; wrapped in ``Float32`` before use.

    Returns:
        ``(out0, out1)``: elements 0 and 1 of the packed result, each
        wrapped in ``Float32``.
    """
    out_f32x2 = llvm.inline_asm(
        T.vector(2, T.f32()),
        [Float32(x).ir_value(), Float32(y).ir_value()],
        "{\n\t"
        ".reg .f32 f1, f2, f3, f4, f5, f6, f7;\n\t"
        ".reg .b64 l1, l2, l3, l4, l5, l6, l7, l8, l9, l10;\n\t"
        ".reg .s32 r1, r2, r3, r4, r5, r6, r7, r8;\n\t"
        # Clamp both inputs from below at -127.0 and pack them into l1.
        "max.ftz.f32 f1, $1, 0fC2FE0000;\n\t"
        "max.ftz.f32 f2, $2, 0fC2FE0000;\n\t"
        "mov.b64 l1, {f1, f2};\n\t"
        # l7 = l1 + 1.5 * 2**23 rounded down; l9 = l1 - (l7 - 1.5 * 2**23).
        "mov.f32 f3, 0f4B400000;\n\t"
        "mov.b64 l2, {f3, f3};\n\t"
        "add.rm.ftz.f32x2 l7, l1, l2;\n\t"
        "sub.rn.ftz.f32x2 l8, l7, l2;\n\t"
        "sub.rn.ftz.f32x2 l9, l1, l8;\n\t"
        # Broadcast the four polynomial coefficients into packed registers.
        "mov.f32 f7, 0f3D9DF09D;\n\t"
        "mov.b64 l6, {f7, f7};\n\t"
        "mov.f32 f6, 0f3E6906A4;\n\t"
        "mov.b64 l5, {f6, f6};\n\t"
        "mov.f32 f5, 0f3F31F519;\n\t"
        "mov.b64 l4, {f5, f5};\n\t"
        "mov.f32 f4, 0f3F800000;\n\t"
        "mov.b64 l3, {f4, f4};\n\t"
        # Horner evaluation: ((l6 * l9 + l5) * l9 + l4) * l9 + l3.
        "fma.rn.ftz.f32x2 l10, l9, l6, l5;\n\t"
        "fma.rn.ftz.f32x2 l10, l10, l9, l4;\n\t"
        "fma.rn.ftz.f32x2 l10, l10, l9, l3;\n\t"
        # Shift the biased-sum bits left by 23 and add the polynomial bits.
        "mov.b64 {r1, r2}, l7;\n\t"
        "mov.b64 {r3, r4}, l10;\n\t"
        "shl.b32 r5, r1, 23;\n\t"
        "add.s32 r7, r5, r3;\n\t"
        "shl.b32 r6, r2, 23;\n\t"
        "add.s32 r8, r6, r4;\n\t"
        "mov.b64 $0, {r7, r8};\n\t"
        "}\n",
        # One output operand ($0) followed by two input operands ($1, $2).
        "=l,f,f",
        has_side_effects=False,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
    )
    out0 = Float32(vector.extract(out_f32x2, dynamic_position=[], static_position=[0]))
    out1 = Float32(vector.extract(out_f32x2, dynamic_position=[], static_position=[1]))
    return out0, out1

0 commit comments

Comments
 (0)