@@ -485,59 +485,53 @@ def cvt_f16(src: cute.Tensor, dst: cute.Tensor):
485485
486486
487487@dsl_user_op
488- def i64_to_f32x2 (c : cutlass .Int64 , * , loc = None , ip = None ) -> Tuple [Float32 , Float32 ]:
489- vec_i64x1 = vector .from_elements (T .vector (1 , T .i64 ()), (c .ir_value (),), loc = loc , ip = ip )
490- vec_f32x2 = vector .bitcast (T .vector (2 , T .f32 ()), vec_i64x1 )
491- res0 = Float32 (
492- vector .extract (vec_f32x2 , dynamic_position = [], static_position = [0 ], loc = loc , ip = ip )
488+ def e2e_asm2 (x : Float32 , y : Float32 , * , loc = None , ip = None ) -> Tuple [Float32 , Float32 ]:
489+ out_f32x2 = llvm .inline_asm (
490+ T .vector (2 , T .f32 ()),
491+ [Float32 (x ).ir_value (loc = loc , ip = ip ), Float32 (y , loc = loc , ip = ip ).ir_value ()],
492+ "{\n \t "
493+ ".reg .f32 f1, f2, f3, f4, f5, f6, f7;\n \t "
494+ ".reg .b64 l1, l2, l3, l4, l5, l6, l7, l8, l9, l10;\n \t "
495+ ".reg .s32 r1, r2, r3, r4, r5, r6, r7, r8;\n \t "
496+ "max.ftz.f32 f1, $1, 0fC2FE0000;\n \t "
497+ "max.ftz.f32 f2, $2, 0fC2FE0000;\n \t "
498+ "mov.b64 l1, {f1, f2};\n \t "
499+ "mov.f32 f3, 0f4B400000;\n \t "
500+ "mov.b64 l2, {f3, f3};\n \t "
501+ "add.rm.ftz.f32x2 l7, l1, l2;\n \t "
502+ "sub.rn.ftz.f32x2 l8, l7, l2;\n \t "
503+ "sub.rn.ftz.f32x2 l9, l1, l8;\n \t "
504+ "mov.f32 f7, 0f3D9DF09D;\n \t "
505+ "mov.b64 l6, {f7, f7};\n \t "
506+ "mov.f32 f6, 0f3E6906A4;\n \t "
507+ "mov.b64 l5, {f6, f6};\n \t "
508+ "mov.f32 f5, 0f3F31F519;\n \t "
509+ "mov.b64 l4, {f5, f5};\n \t "
510+ "mov.f32 f4, 0f3F800000;\n \t "
511+ "mov.b64 l3, {f4, f4};\n \t "
512+ "fma.rn.ftz.f32x2 l10, l9, l6, l5;\n \t "
513+ "fma.rn.ftz.f32x2 l10, l10, l9, l4;\n \t "
514+ "fma.rn.ftz.f32x2 l10, l10, l9, l3;\n \t "
515+ "mov.b64 {r1, r2}, l7;\n \t "
516+ "mov.b64 {r3, r4}, l10;\n \t "
517+ "shl.b32 r5, r1, 23;\n \t "
518+ "add.s32 r7, r5, r3;\n \t "
519+ "shl.b32 r6, r2, 23;\n \t "
520+ "add.s32 r8, r6, r4;\n \t "
521+ "mov.b64 $0, {r7, r8};\n \t "
522+ "}\n " ,
523+ "=l,f,f" ,
524+ has_side_effects = False ,
525+ is_align_stack = False ,
526+ asm_dialect = llvm .AsmDialect .AD_ATT ,
493527 )
494- res1 = Float32 (
495- vector .extract (vec_f32x2 , dynamic_position = [], static_position = [1 ], loc = loc , ip = ip )
528+ out0 = Float32 (
529+ vector .extract (out_f32x2 , dynamic_position = [], static_position = [0 ], loc = loc , ip = ip )
496530 )
497- return res0 , res1
531+ out1 = Float32 (
532+ vector .extract (out_f32x2 , dynamic_position = [], static_position = [1 ], loc = loc , ip = ip )
533+ )
534+ return out0 , out1
498535
499536
500- @cute .jit
501- def e2e_asm2 (x : Float32 , y : Float32 ) -> Tuple [Float32 , Float32 ]:
502- out_i64 = cutlass .Int64 (
503- llvm .inline_asm (
504- T .i64 (),
505- [Float32 (x ).ir_value (), Float32 (y ).ir_value ()],
506- "{\n \t "
507- ".reg .f32 f1, f2, f3, f4, f5, f6, f7;\n \t "
508- ".reg .b64 l1, l2, l3, l4, l5, l6, l7, l8, l9, l10;\n \t "
509- ".reg .s32 r1, r2, r3, r4, r5, r6, r7, r8;\n \t "
510- "max.ftz.f32 f1, $1, 0fC2FE0000;\n \t "
511- "max.ftz.f32 f2, $2, 0fC2FE0000;\n \t "
512- "mov.b64 l1, {f1, f2};\n \t "
513- "mov.f32 f3, 0f4B400000;\n \t "
514- "mov.b64 l2, {f3, f3};\n \t "
515- "add.rm.ftz.f32x2 l7, l1, l2;\n \t "
516- "sub.rn.ftz.f32x2 l8, l7, l2;\n \t "
517- "sub.rn.ftz.f32x2 l9, l1, l8;\n \t "
518- "mov.f32 f7, 0f3D9DF09D;\n \t "
519- "mov.b64 l6, {f7, f7};\n \t "
520- "mov.f32 f6, 0f3E6906A4;\n \t "
521- "mov.b64 l5, {f6, f6};\n \t "
522- "mov.f32 f5, 0f3F31F519;\n \t "
523- "mov.b64 l4, {f5, f5};\n \t "
524- "mov.f32 f4, 0f3F800000;\n \t "
525- "mov.b64 l3, {f4, f4};\n \t "
526- "fma.rn.ftz.f32x2 l10, l9, l6, l5;\n \t "
527- "fma.rn.ftz.f32x2 l10, l10, l9, l4;\n \t "
528- "fma.rn.ftz.f32x2 l10, l10, l9, l3;\n \t "
529- "mov.b64 {r1, r2}, l7;\n \t "
530- "mov.b64 {r3, r4}, l10;\n \t "
531- "shl.b32 r5, r1, 23;\n \t "
532- "add.s32 r7, r5, r3;\n \t "
533- "shl.b32 r6, r2, 23;\n \t "
534- "add.s32 r8, r6, r4;\n \t "
535- "mov.b64 $0, {r7, r8};\n \t "
536- "}\n " ,
537- "=l,f,f" ,
538- has_side_effects = False ,
539- is_align_stack = False ,
540- asm_dialect = llvm .AsmDialect .AD_ATT ,
541- )
542537 )
543- return i64_to_f32x2 (out_i64 )
0 commit comments