Skip to content

Arm64/Sve: Implement ConditionalSelect API #100718

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 14 additions & 6 deletions src/coreclr/jit/hwintrinsiccodegenarm64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -434,15 +434,23 @@ void CodeGen::genHWIntrinsic(GenTreeHWIntrinsic* node)
break;

case 3:
assert(isRMW);
if (targetReg != op1Reg)
if (isRMW)
{
assert(targetReg != op2Reg);
assert(targetReg != op3Reg);
if (targetReg != op1Reg)
{
assert(targetReg != op2Reg);
assert(targetReg != op3Reg);

GetEmitter()->emitIns_Mov(INS_mov, emitTypeSize(node), targetReg, op1Reg, /* canSkip */ true);
GetEmitter()->emitIns_Mov(INS_mov, emitTypeSize(node), targetReg, op1Reg,
/* canSkip */ true);
}
GetEmitter()->emitIns_R_R_R(ins, emitSize, targetReg, op2Reg, op3Reg, opt);
}
else
{
GetEmitter()->emitIns_R_R_R_R(ins, emitSize, targetReg, op1Reg, op2Reg, op3Reg,
opt, INS_SCALABLE_OPTS_UNPREDICATED);
}
GetEmitter()->emitIns_R_R_R(ins, emitSize, targetReg, op2Reg, op3Reg, opt);
break;

default:
Expand Down
3 changes: 2 additions & 1 deletion src/coreclr/jit/hwintrinsiclistarm64sve.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,11 @@ HARDWARE_INTRINSIC(Sve, LoadVector,
// ***************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************
// Special intrinsics that are generated during importing or lowering

HARDWARE_INTRINSIC(Sve, CreateTrueMaskAll, -1, -1, false, {INS_sve_ptrue, INS_sve_ptrue, INS_sve_ptrue, INS_sve_ptrue, INS_sve_ptrue, INS_sve_ptrue, INS_sve_ptrue, INS_sve_ptrue, INS_sve_ptrue, INS_sve_ptrue}, HW_Category_Helper, HW_Flag_Scalable|HW_Flag_ReturnsPerElementMask)
HARDWARE_INTRINSIC(Sve, ConditionalSelect, -1, 3, true, {INS_sve_sel, INS_sve_sel, INS_sve_sel, INS_sve_sel, INS_sve_sel, INS_sve_sel, INS_sve_sel, INS_sve_sel, INS_sve_sel, INS_sve_sel}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_MaskedOperation)
HARDWARE_INTRINSIC(Sve, ConvertMaskToVector, -1, 1, true, {INS_sve_mov, INS_sve_mov, INS_sve_mov, INS_sve_mov, INS_sve_mov, INS_sve_mov, INS_sve_mov, INS_sve_mov, INS_sve_mov, INS_sve_mov}, HW_Category_Helper, HW_Flag_Scalable|HW_Flag_MaskedOperation)
HARDWARE_INTRINSIC(Sve, ConvertVectorToMask, -1, 2, true, {INS_sve_cmpne, INS_sve_cmpne, INS_sve_cmpne, INS_sve_cmpne, INS_sve_cmpne, INS_sve_cmpne, INS_sve_cmpne, INS_sve_cmpne, INS_sve_cmpne, INS_sve_cmpne}, HW_Category_Helper, HW_Flag_Scalable|HW_Flag_ReturnsPerElementMask|HW_Flag_LowMaskedOperation)

HARDWARE_INTRINSIC(Sve, CreateTrueMaskAll, -1, -1, false, {INS_sve_ptrue, INS_sve_ptrue, INS_sve_ptrue, INS_sve_ptrue, INS_sve_ptrue, INS_sve_ptrue, INS_sve_ptrue, INS_sve_ptrue, INS_sve_ptrue, INS_sve_ptrue}, HW_Category_Helper, HW_Flag_Scalable|HW_Flag_ReturnsPerElementMask)

#endif // FEATURE_HW_INTRINSIC

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,65 @@ internal Arm64() { }
/// </summary>
public static unsafe Vector<ulong> CreateTrueMaskUInt64([ConstantExpected] SveMaskPattern pattern = SveMaskPattern.All) { throw new PlatformNotSupportedException(); }

/// ConditionalSelect : Conditionally select elements

/// <summary>
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: It's preferred to order these "alphabetically" as well, since that's what tooling will do in various places.

This is done based on the type name, not the language keyword:

  • byte (Byte), double (Double), short (Int16), int (Int32), long (Int64), nint (IntPtr), sbyte (SByte), float (Single), ushort (UInt16), uint (UInt32), ulong (UInt64), nuint (UIntPtr)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@a74nh - do you mind fixing the tool to generate these alphabetically?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

> @a74nh - do you mind fixing the tool to generate these alphabetically?

Done. The branch with the autogenerated files should now be in order for all the .cs files.

/// svint8_t svsel[_s8](svbool_t pg, svint8_t op1, svint8_t op2)
/// svbool_t svsel[_b](svbool_t pg, svbool_t op1, svbool_t op2)
/// </summary>
/// <param name="mask">Selector vector; occupies the position of <c>pg</c> in the native signature above.</param>
/// <param name="left">First source vector; occupies the position of <c>op1</c>.</param>
/// <param name="right">Second source vector; occupies the position of <c>op2</c>.</param>
/// <returns>This stub never returns a value.</returns>
/// <exception cref="PlatformNotSupportedException">Always thrown; the body of this overload does nothing else.</exception>
public static unsafe Vector<sbyte> ConditionalSelect(Vector<sbyte> mask, Vector<sbyte> left, Vector<sbyte> right) { throw new PlatformNotSupportedException(); }

/// <summary>
/// Conditionally selects elements from <paramref name="left"/> and <paramref name="right"/> under control of <paramref name="mask"/>.
/// svint16_t svsel[_s16](svbool_t pg, svint16_t op1, svint16_t op2)
/// svbool_t svsel[_b](svbool_t pg, svbool_t op1, svbool_t op2)
/// </summary>
/// <param name="mask">Selector vector; occupies the position of <c>pg</c> in the native signature above.</param>
/// <param name="left">First source vector; occupies the position of <c>op1</c>.</param>
/// <param name="right">Second source vector; occupies the position of <c>op2</c>.</param>
/// <returns>This stub never returns a value.</returns>
/// <exception cref="PlatformNotSupportedException">Always thrown; the body of this overload does nothing else.</exception>
public static unsafe Vector<short> ConditionalSelect(Vector<short> mask, Vector<short> left, Vector<short> right) { throw new PlatformNotSupportedException(); }

/// <summary>
/// Conditionally selects elements from <paramref name="left"/> and <paramref name="right"/> under control of <paramref name="mask"/>.
/// svint32_t svsel[_s32](svbool_t pg, svint32_t op1, svint32_t op2)
/// svbool_t svsel[_b](svbool_t pg, svbool_t op1, svbool_t op2)
/// </summary>
/// <param name="mask">Selector vector; occupies the position of <c>pg</c> in the native signature above.</param>
/// <param name="left">First source vector; occupies the position of <c>op1</c>.</param>
/// <param name="right">Second source vector; occupies the position of <c>op2</c>.</param>
/// <returns>This stub never returns a value.</returns>
/// <exception cref="PlatformNotSupportedException">Always thrown; the body of this overload does nothing else.</exception>
public static unsafe Vector<int> ConditionalSelect(Vector<int> mask, Vector<int> left, Vector<int> right) { throw new PlatformNotSupportedException(); }

/// <summary>
/// Conditionally selects elements from <paramref name="left"/> and <paramref name="right"/> under control of <paramref name="mask"/>.
/// svint64_t svsel[_s64](svbool_t pg, svint64_t op1, svint64_t op2)
/// svbool_t svsel[_b](svbool_t pg, svbool_t op1, svbool_t op2)
/// </summary>
/// <param name="mask">Selector vector; occupies the position of <c>pg</c> in the native signature above.</param>
/// <param name="left">First source vector; occupies the position of <c>op1</c>.</param>
/// <param name="right">Second source vector; occupies the position of <c>op2</c>.</param>
/// <returns>This stub never returns a value.</returns>
/// <exception cref="PlatformNotSupportedException">Always thrown; the body of this overload does nothing else.</exception>
public static unsafe Vector<long> ConditionalSelect(Vector<long> mask, Vector<long> left, Vector<long> right) { throw new PlatformNotSupportedException(); }

/// <summary>
/// Conditionally selects elements from <paramref name="left"/> and <paramref name="right"/> under control of <paramref name="mask"/>.
/// svuint8_t svsel[_u8](svbool_t pg, svuint8_t op1, svuint8_t op2)
/// svbool_t svsel[_b](svbool_t pg, svbool_t op1, svbool_t op2)
/// </summary>
/// <param name="mask">Selector vector; occupies the position of <c>pg</c> in the native signature above.</param>
/// <param name="left">First source vector; occupies the position of <c>op1</c>.</param>
/// <param name="right">Second source vector; occupies the position of <c>op2</c>.</param>
/// <returns>This stub never returns a value.</returns>
/// <exception cref="PlatformNotSupportedException">Always thrown; the body of this overload does nothing else.</exception>
public static unsafe Vector<byte> ConditionalSelect(Vector<byte> mask, Vector<byte> left, Vector<byte> right) { throw new PlatformNotSupportedException(); }

/// <summary>
/// Conditionally selects elements from <paramref name="left"/> and <paramref name="right"/> under control of <paramref name="mask"/>.
/// svuint16_t svsel[_u16](svbool_t pg, svuint16_t op1, svuint16_t op2)
/// svbool_t svsel[_b](svbool_t pg, svbool_t op1, svbool_t op2)
/// </summary>
/// <param name="mask">Selector vector; occupies the position of <c>pg</c> in the native signature above.</param>
/// <param name="left">First source vector; occupies the position of <c>op1</c>.</param>
/// <param name="right">Second source vector; occupies the position of <c>op2</c>.</param>
/// <returns>This stub never returns a value.</returns>
/// <exception cref="PlatformNotSupportedException">Always thrown; the body of this overload does nothing else.</exception>
public static unsafe Vector<ushort> ConditionalSelect(Vector<ushort> mask, Vector<ushort> left, Vector<ushort> right) { throw new PlatformNotSupportedException(); }

/// <summary>
/// Conditionally selects elements from <paramref name="left"/> and <paramref name="right"/> under control of <paramref name="mask"/>.
/// svuint32_t svsel[_u32](svbool_t pg, svuint32_t op1, svuint32_t op2)
/// svbool_t svsel[_b](svbool_t pg, svbool_t op1, svbool_t op2)
/// </summary>
/// <param name="mask">Selector vector; occupies the position of <c>pg</c> in the native signature above.</param>
/// <param name="left">First source vector; occupies the position of <c>op1</c>.</param>
/// <param name="right">Second source vector; occupies the position of <c>op2</c>.</param>
/// <returns>This stub never returns a value.</returns>
/// <exception cref="PlatformNotSupportedException">Always thrown; the body of this overload does nothing else.</exception>
public static unsafe Vector<uint> ConditionalSelect(Vector<uint> mask, Vector<uint> left, Vector<uint> right) { throw new PlatformNotSupportedException(); }

/// <summary>
/// Conditionally selects elements from <paramref name="left"/> and <paramref name="right"/> under control of <paramref name="mask"/>.
/// svuint64_t svsel[_u64](svbool_t pg, svuint64_t op1, svuint64_t op2)
/// svbool_t svsel[_b](svbool_t pg, svbool_t op1, svbool_t op2)
/// </summary>
/// <param name="mask">Selector vector; occupies the position of <c>pg</c> in the native signature above.</param>
/// <param name="left">First source vector; occupies the position of <c>op1</c>.</param>
/// <param name="right">Second source vector; occupies the position of <c>op2</c>.</param>
/// <returns>This stub never returns a value.</returns>
/// <exception cref="PlatformNotSupportedException">Always thrown; the body of this overload does nothing else.</exception>
public static unsafe Vector<ulong> ConditionalSelect(Vector<ulong> mask, Vector<ulong> left, Vector<ulong> right) { throw new PlatformNotSupportedException(); }

/// <summary>
/// Conditionally selects elements from <paramref name="left"/> and <paramref name="right"/> under control of <paramref name="mask"/>.
/// svfloat32_t svsel[_f32](svbool_t pg, svfloat32_t op1, svfloat32_t op2)
/// </summary>
/// <param name="mask">Selector vector; occupies the position of <c>pg</c> in the native signature above.</param>
/// <param name="left">First source vector; occupies the position of <c>op1</c>.</param>
/// <param name="right">Second source vector; occupies the position of <c>op2</c>.</param>
/// <returns>This stub never returns a value.</returns>
/// <exception cref="PlatformNotSupportedException">Always thrown; the body of this overload does nothing else.</exception>
public static unsafe Vector<float> ConditionalSelect(Vector<float> mask, Vector<float> left, Vector<float> right) { throw new PlatformNotSupportedException(); }

/// <summary>
/// Conditionally selects elements from <paramref name="left"/> and <paramref name="right"/> under control of <paramref name="mask"/>.
/// svfloat64_t svsel[_f64](svbool_t pg, svfloat64_t op1, svfloat64_t op2)
/// </summary>
/// <param name="mask">Selector vector; occupies the position of <c>pg</c> in the native signature above.</param>
/// <param name="left">First source vector; occupies the position of <c>op1</c>.</param>
/// <param name="right">Second source vector; occupies the position of <c>op2</c>.</param>
/// <returns>This stub never returns a value.</returns>
/// <exception cref="PlatformNotSupportedException">Always thrown; the body of this overload does nothing else.</exception>
public static unsafe Vector<double> ConditionalSelect(Vector<double> mask, Vector<double> left, Vector<double> right) { throw new PlatformNotSupportedException(); }

/// LoadVector : Unextended load

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,93 @@ internal Arm64() { }
/// </summary>
public static unsafe Vector<ulong> CreateTrueMaskUInt64([ConstantExpected] SveMaskPattern pattern = SveMaskPattern.All) => CreateTrueMaskUInt64(pattern);

/// ConditionalSelect : Conditionally select elements

// NOTE(review): the body calls this same overload with identical arguments; presumably the JIT
// substitutes the SEL instruction named below instead of executing the call — confirm.
/// <summary>
/// Conditionally selects elements from <paramref name="left"/> and <paramref name="right"/> under control of <paramref name="mask"/>.
/// svint8_t svsel[_s8](svbool_t pg, svint8_t op1, svint8_t op2)
///   SEL Zresult.B, Pg, Zop1.B, Zop2.B
/// svbool_t svsel[_b](svbool_t pg, svbool_t op1, svbool_t op2)
///   SEL Presult.B, Pg, Pop1.B, Pop2.B
/// </summary>
/// <param name="mask">Selector vector; occupies the position of <c>pg</c> in the native signature above.</param>
/// <param name="left">First source vector; occupies the position of <c>op1</c>.</param>
/// <param name="right">Second source vector; occupies the position of <c>op2</c>.</param>
/// <returns>A vector of <see cref="sbyte"/> elements holding the selection result.</returns>
public static unsafe Vector<sbyte> ConditionalSelect(Vector<sbyte> mask, Vector<sbyte> left, Vector<sbyte> right) => ConditionalSelect(mask, left, right);

// NOTE(review): the body calls this same overload with identical arguments; presumably the JIT
// substitutes the SEL instruction named below instead of executing the call — confirm.
/// <summary>
/// Conditionally selects elements from <paramref name="left"/> and <paramref name="right"/> under control of <paramref name="mask"/>.
/// svint16_t svsel[_s16](svbool_t pg, svint16_t op1, svint16_t op2)
///   SEL Zresult.H, Pg, Zop1.H, Zop2.H
/// svbool_t svsel[_b](svbool_t pg, svbool_t op1, svbool_t op2)
///   SEL Presult.B, Pg, Pop1.B, Pop2.B
/// </summary>
/// <param name="mask">Selector vector; occupies the position of <c>pg</c> in the native signature above.</param>
/// <param name="left">First source vector; occupies the position of <c>op1</c>.</param>
/// <param name="right">Second source vector; occupies the position of <c>op2</c>.</param>
/// <returns>A vector of <see cref="short"/> elements holding the selection result.</returns>
public static unsafe Vector<short> ConditionalSelect(Vector<short> mask, Vector<short> left, Vector<short> right) => ConditionalSelect(mask, left, right);

// NOTE(review): the body calls this same overload with identical arguments; presumably the JIT
// substitutes the SEL instruction named below instead of executing the call — confirm.
/// <summary>
/// Conditionally selects elements from <paramref name="left"/> and <paramref name="right"/> under control of <paramref name="mask"/>.
/// svint32_t svsel[_s32](svbool_t pg, svint32_t op1, svint32_t op2)
///   SEL Zresult.S, Pg, Zop1.S, Zop2.S
/// svbool_t svsel[_b](svbool_t pg, svbool_t op1, svbool_t op2)
///   SEL Presult.B, Pg, Pop1.B, Pop2.B
/// </summary>
/// <param name="mask">Selector vector; occupies the position of <c>pg</c> in the native signature above.</param>
/// <param name="left">First source vector; occupies the position of <c>op1</c>.</param>
/// <param name="right">Second source vector; occupies the position of <c>op2</c>.</param>
/// <returns>A vector of <see cref="int"/> elements holding the selection result.</returns>
public static unsafe Vector<int> ConditionalSelect(Vector<int> mask, Vector<int> left, Vector<int> right) => ConditionalSelect(mask, left, right);

// NOTE(review): the body calls this same overload with identical arguments; presumably the JIT
// substitutes the SEL instruction named below instead of executing the call — confirm.
/// <summary>
/// Conditionally selects elements from <paramref name="left"/> and <paramref name="right"/> under control of <paramref name="mask"/>.
/// svint64_t svsel[_s64](svbool_t pg, svint64_t op1, svint64_t op2)
///   SEL Zresult.D, Pg, Zop1.D, Zop2.D
/// svbool_t svsel[_b](svbool_t pg, svbool_t op1, svbool_t op2)
///   SEL Presult.B, Pg, Pop1.B, Pop2.B
/// </summary>
/// <param name="mask">Selector vector; occupies the position of <c>pg</c> in the native signature above.</param>
/// <param name="left">First source vector; occupies the position of <c>op1</c>.</param>
/// <param name="right">Second source vector; occupies the position of <c>op2</c>.</param>
/// <returns>A vector of <see cref="long"/> elements holding the selection result.</returns>
public static unsafe Vector<long> ConditionalSelect(Vector<long> mask, Vector<long> left, Vector<long> right) => ConditionalSelect(mask, left, right);

// NOTE(review): the body calls this same overload with identical arguments; presumably the JIT
// substitutes the SEL instruction named below instead of executing the call — confirm.
/// <summary>
/// Conditionally selects elements from <paramref name="left"/> and <paramref name="right"/> under control of <paramref name="mask"/>.
/// svuint8_t svsel[_u8](svbool_t pg, svuint8_t op1, svuint8_t op2)
///   SEL Zresult.B, Pg, Zop1.B, Zop2.B
/// svbool_t svsel[_b](svbool_t pg, svbool_t op1, svbool_t op2)
///   SEL Presult.B, Pg, Pop1.B, Pop2.B
/// </summary>
/// <param name="mask">Selector vector; occupies the position of <c>pg</c> in the native signature above.</param>
/// <param name="left">First source vector; occupies the position of <c>op1</c>.</param>
/// <param name="right">Second source vector; occupies the position of <c>op2</c>.</param>
/// <returns>A vector of <see cref="byte"/> elements holding the selection result.</returns>
public static unsafe Vector<byte> ConditionalSelect(Vector<byte> mask, Vector<byte> left, Vector<byte> right) => ConditionalSelect(mask, left, right);

// NOTE(review): the body calls this same overload with identical arguments; presumably the JIT
// substitutes the SEL instruction named below instead of executing the call — confirm.
/// <summary>
/// Conditionally selects elements from <paramref name="left"/> and <paramref name="right"/> under control of <paramref name="mask"/>.
/// svuint16_t svsel[_u16](svbool_t pg, svuint16_t op1, svuint16_t op2)
///   SEL Zresult.H, Pg, Zop1.H, Zop2.H
/// svbool_t svsel[_b](svbool_t pg, svbool_t op1, svbool_t op2)
///   SEL Presult.B, Pg, Pop1.B, Pop2.B
/// </summary>
/// <param name="mask">Selector vector; occupies the position of <c>pg</c> in the native signature above.</param>
/// <param name="left">First source vector; occupies the position of <c>op1</c>.</param>
/// <param name="right">Second source vector; occupies the position of <c>op2</c>.</param>
/// <returns>A vector of <see cref="ushort"/> elements holding the selection result.</returns>
public static unsafe Vector<ushort> ConditionalSelect(Vector<ushort> mask, Vector<ushort> left, Vector<ushort> right) => ConditionalSelect(mask, left, right);

// NOTE(review): the body calls this same overload with identical arguments; presumably the JIT
// substitutes the SEL instruction named below instead of executing the call — confirm.
/// <summary>
/// Conditionally selects elements from <paramref name="left"/> and <paramref name="right"/> under control of <paramref name="mask"/>.
/// svuint32_t svsel[_u32](svbool_t pg, svuint32_t op1, svuint32_t op2)
///   SEL Zresult.S, Pg, Zop1.S, Zop2.S
/// svbool_t svsel[_b](svbool_t pg, svbool_t op1, svbool_t op2)
///   SEL Presult.B, Pg, Pop1.B, Pop2.B
/// </summary>
/// <param name="mask">Selector vector; occupies the position of <c>pg</c> in the native signature above.</param>
/// <param name="left">First source vector; occupies the position of <c>op1</c>.</param>
/// <param name="right">Second source vector; occupies the position of <c>op2</c>.</param>
/// <returns>A vector of <see cref="uint"/> elements holding the selection result.</returns>
public static unsafe Vector<uint> ConditionalSelect(Vector<uint> mask, Vector<uint> left, Vector<uint> right) => ConditionalSelect(mask, left, right);

// NOTE(review): the body calls this same overload with identical arguments; presumably the JIT
// substitutes the SEL instruction named below instead of executing the call — confirm.
/// <summary>
/// Conditionally selects elements from <paramref name="left"/> and <paramref name="right"/> under control of <paramref name="mask"/>.
/// svuint64_t svsel[_u64](svbool_t pg, svuint64_t op1, svuint64_t op2)
///   SEL Zresult.D, Pg, Zop1.D, Zop2.D
/// svbool_t svsel[_b](svbool_t pg, svbool_t op1, svbool_t op2)
///   SEL Presult.B, Pg, Pop1.B, Pop2.B
/// </summary>
/// <param name="mask">Selector vector; occupies the position of <c>pg</c> in the native signature above.</param>
/// <param name="left">First source vector; occupies the position of <c>op1</c>.</param>
/// <param name="right">Second source vector; occupies the position of <c>op2</c>.</param>
/// <returns>A vector of <see cref="ulong"/> elements holding the selection result.</returns>
public static unsafe Vector<ulong> ConditionalSelect(Vector<ulong> mask, Vector<ulong> left, Vector<ulong> right) => ConditionalSelect(mask, left, right);

// NOTE(review): the body calls this same overload with identical arguments; presumably the JIT
// substitutes the SEL instruction named below instead of executing the call — confirm.
/// <summary>
/// Conditionally selects elements from <paramref name="left"/> and <paramref name="right"/> under control of <paramref name="mask"/>.
/// svfloat32_t svsel[_f32](svbool_t pg, svfloat32_t op1, svfloat32_t op2)
///   SEL Zresult.S, Pg, Zop1.S, Zop2.S
/// </summary>
/// <param name="mask">Selector vector; occupies the position of <c>pg</c> in the native signature above.</param>
/// <param name="left">First source vector; occupies the position of <c>op1</c>.</param>
/// <param name="right">Second source vector; occupies the position of <c>op2</c>.</param>
/// <returns>A vector of <see cref="float"/> elements holding the selection result.</returns>
public static unsafe Vector<float> ConditionalSelect(Vector<float> mask, Vector<float> left, Vector<float> right) => ConditionalSelect(mask, left, right);

// NOTE(review): the body calls this same overload with identical arguments; presumably the JIT
// substitutes the SEL instruction named below instead of executing the call — confirm.
/// <summary>
/// Conditionally selects elements from <paramref name="left"/> and <paramref name="right"/> under control of <paramref name="mask"/>.
/// svfloat64_t svsel[_f64](svbool_t pg, svfloat64_t op1, svfloat64_t op2)
///   SEL Zresult.D, Pg, Zop1.D, Zop2.D
/// </summary>
/// <param name="mask">Selector vector; occupies the position of <c>pg</c> in the native signature above.</param>
/// <param name="left">First source vector; occupies the position of <c>op1</c>.</param>
/// <param name="right">Second source vector; occupies the position of <c>op2</c>.</param>
/// <returns>A vector of <see cref="double"/> elements holding the selection result.</returns>
public static unsafe Vector<double> ConditionalSelect(Vector<double> mask, Vector<double> left, Vector<double> right) => ConditionalSelect(mask, left, right);

/// LoadVector : Unextended load

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4149,7 +4149,16 @@ internal Arm64() { }
public static System.Numerics.Vector<ushort> CreateTrueMaskUInt16([ConstantExpected] SveMaskPattern pattern = SveMaskPattern.All) { throw null; }
public static System.Numerics.Vector<uint> CreateTrueMaskUInt32([ConstantExpected] SveMaskPattern pattern = SveMaskPattern.All) { throw null; }
public static System.Numerics.Vector<ulong> CreateTrueMaskUInt64([ConstantExpected] SveMaskPattern pattern = SveMaskPattern.All) { throw null; }

// ConditionalSelect overload for sbyte elements; declaration-only stub whose body is `throw null`.
public static System.Numerics.Vector<sbyte> ConditionalSelect(System.Numerics.Vector<sbyte> mask, System.Numerics.Vector<sbyte> left, System.Numerics.Vector<sbyte> right) { throw null; }
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If this file is ever regenerated by the tooling, these entries will be reordered alphabetically, hence the comment above.

// ConditionalSelect overloads, one per element type (short, int, long, byte, ushort, uint, ulong,
// float, double). Each takes (mask, left, right) vectors of that element type, returns a vector of
// the same type, and is a declaration-only stub whose body is `throw null`.
public static System.Numerics.Vector<short> ConditionalSelect(System.Numerics.Vector<short> mask, System.Numerics.Vector<short> left, System.Numerics.Vector<short> right) { throw null; }
public static System.Numerics.Vector<int> ConditionalSelect(System.Numerics.Vector<int> mask, System.Numerics.Vector<int> left, System.Numerics.Vector<int> right) { throw null; }
public static System.Numerics.Vector<long> ConditionalSelect(System.Numerics.Vector<long> mask, System.Numerics.Vector<long> left, System.Numerics.Vector<long> right) { throw null; }
public static System.Numerics.Vector<byte> ConditionalSelect(System.Numerics.Vector<byte> mask, System.Numerics.Vector<byte> left, System.Numerics.Vector<byte> right) { throw null; }
public static System.Numerics.Vector<ushort> ConditionalSelect(System.Numerics.Vector<ushort> mask, System.Numerics.Vector<ushort> left, System.Numerics.Vector<ushort> right) { throw null; }
public static System.Numerics.Vector<uint> ConditionalSelect(System.Numerics.Vector<uint> mask, System.Numerics.Vector<uint> left, System.Numerics.Vector<uint> right) { throw null; }
public static System.Numerics.Vector<ulong> ConditionalSelect(System.Numerics.Vector<ulong> mask, System.Numerics.Vector<ulong> left, System.Numerics.Vector<ulong> right) { throw null; }
public static System.Numerics.Vector<float> ConditionalSelect(System.Numerics.Vector<float> mask, System.Numerics.Vector<float> left, System.Numerics.Vector<float> right) { throw null; }
public static System.Numerics.Vector<double> ConditionalSelect(System.Numerics.Vector<double> mask, System.Numerics.Vector<double> left, System.Numerics.Vector<double> right) { throw null; }
public static unsafe System.Numerics.Vector<sbyte> LoadVector(System.Numerics.Vector<sbyte> mask, sbyte* address) { throw null; }
public static unsafe System.Numerics.Vector<short> LoadVector(System.Numerics.Vector<short> mask, short* address) { throw null; }
public static unsafe System.Numerics.Vector<int> LoadVector(System.Numerics.Vector<int> mask, int* address) { throw null; }
Expand Down
Loading