|
| 1 | +//===- RISCVVectorMaskDAGMutation.cpp - RISCV Vector Mask DAGMutation -----===// |
| 2 | +// |
| 3 | +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | +// See https://llvm.org/LICENSE.txt for license information. |
| 5 | +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | +// |
| 7 | +//===----------------------------------------------------------------------===// |
| 8 | +// |
| 9 | +// A schedule mutation that adds an artificial dependency between mask producer |
| 10 | +// instructions and masked instructions, so that we can reduce the live range |
| 11 | +// overlaps of mask registers. |
| 12 | +// |
| 13 | +// The reason why we need to do this: |
| 14 | +// 1. When tracking register pressure, we don't track physical registers. |
| 15 | +// 2. We have a RegisterClass for mask register (which is `VMV0`), but we don't |
| 16 | +// use it in most RVV pseudos (only used in inline asm constraint and add/sub |
| 17 | +// with carry instructions). Instead, we use physical register V0 directly |
| 18 | +// and insert a `$v0 = COPY ...` before the use. And, there is a fundamental |
| 19 | +// issue in register allocator when handling RegisterClass with only one |
| 20 | +// physical register, so we can't simply replace V0 with VMV0. |
| 21 | +// 3. For mask producers, we are using VR RegisterClass (we can allocate V0-V31 |
| 22 | +// to it). So if V0 is not available, there are still 31 available registers |
| 23 | +// out there. |
| 24 | +// |
| 25 | +// This means that the RegPressureTracker can't track the pressure of mask |
| 26 | +// registers correctly. |
| 27 | +// |
| 28 | +// This schedule mutation is a workaround to fix this issue. |
| 29 | +// |
| 30 | +//===----------------------------------------------------------------------===// |
| 31 | + |
| 32 | +#include "MCTargetDesc/RISCVBaseInfo.h" |
| 33 | +#include "MCTargetDesc/RISCVMCTargetDesc.h" |
| 34 | +#include "RISCVRegisterInfo.h" |
| 35 | +#include "RISCVTargetMachine.h" |
| 36 | +#include "llvm/CodeGen/LiveIntervals.h" |
| 37 | +#include "llvm/CodeGen/MachineInstr.h" |
| 38 | +#include "llvm/CodeGen/ScheduleDAGInstrs.h" |
| 39 | +#include "llvm/CodeGen/ScheduleDAGMutation.h" |
| 40 | +#include "llvm/TargetParser/RISCVTargetParser.h" |
| 41 | + |
| 42 | +#define DEBUG_TYPE "machine-scheduler" |
| 43 | + |
| 44 | +namespace llvm { |
| 45 | + |
| 46 | +static inline bool isVectorMaskProducer(const MachineInstr *MI) { |
| 47 | + switch (RISCV::getRVVMCOpcode(MI->getOpcode())) { |
| 48 | + // Vector Mask Instructions |
| 49 | + case RISCV::VMAND_MM: |
| 50 | + case RISCV::VMNAND_MM: |
| 51 | + case RISCV::VMANDN_MM: |
| 52 | + case RISCV::VMXOR_MM: |
| 53 | + case RISCV::VMOR_MM: |
| 54 | + case RISCV::VMNOR_MM: |
| 55 | + case RISCV::VMORN_MM: |
| 56 | + case RISCV::VMXNOR_MM: |
| 57 | + case RISCV::VMSBF_M: |
| 58 | + case RISCV::VMSIF_M: |
| 59 | + case RISCV::VMSOF_M: |
| 60 | + // Vector Integer Add-with-Carry / Subtract-with-Borrow Instructions |
| 61 | + case RISCV::VMADC_VV: |
| 62 | + case RISCV::VMADC_VX: |
| 63 | + case RISCV::VMADC_VI: |
| 64 | + case RISCV::VMADC_VVM: |
| 65 | + case RISCV::VMADC_VXM: |
| 66 | + case RISCV::VMADC_VIM: |
| 67 | + case RISCV::VMSBC_VV: |
| 68 | + case RISCV::VMSBC_VX: |
| 69 | + case RISCV::VMSBC_VVM: |
| 70 | + case RISCV::VMSBC_VXM: |
| 71 | + // Vector Integer Compare Instructions |
| 72 | + case RISCV::VMSEQ_VV: |
| 73 | + case RISCV::VMSEQ_VX: |
| 74 | + case RISCV::VMSEQ_VI: |
| 75 | + case RISCV::VMSNE_VV: |
| 76 | + case RISCV::VMSNE_VX: |
| 77 | + case RISCV::VMSNE_VI: |
| 78 | + case RISCV::VMSLT_VV: |
| 79 | + case RISCV::VMSLT_VX: |
| 80 | + case RISCV::VMSLTU_VV: |
| 81 | + case RISCV::VMSLTU_VX: |
| 82 | + case RISCV::VMSLE_VV: |
| 83 | + case RISCV::VMSLE_VX: |
| 84 | + case RISCV::VMSLE_VI: |
| 85 | + case RISCV::VMSLEU_VV: |
| 86 | + case RISCV::VMSLEU_VX: |
| 87 | + case RISCV::VMSLEU_VI: |
| 88 | + case RISCV::VMSGTU_VX: |
| 89 | + case RISCV::VMSGTU_VI: |
| 90 | + case RISCV::VMSGT_VX: |
| 91 | + case RISCV::VMSGT_VI: |
| 92 | + // Vector Floating-Point Compare Instructions |
| 93 | + case RISCV::VMFEQ_VV: |
| 94 | + case RISCV::VMFEQ_VF: |
| 95 | + case RISCV::VMFNE_VV: |
| 96 | + case RISCV::VMFNE_VF: |
| 97 | + case RISCV::VMFLT_VV: |
| 98 | + case RISCV::VMFLT_VF: |
| 99 | + case RISCV::VMFLE_VV: |
| 100 | + case RISCV::VMFLE_VF: |
| 101 | + case RISCV::VMFGT_VF: |
| 102 | + case RISCV::VMFGE_VF: |
| 103 | + return true; |
| 104 | + } |
| 105 | + return false; |
| 106 | +} |
| 107 | + |
| 108 | +class RISCVVectorMaskDAGMutation : public ScheduleDAGMutation { |
| 109 | +private: |
| 110 | + const TargetRegisterInfo *TRI; |
| 111 | + |
| 112 | +public: |
| 113 | + RISCVVectorMaskDAGMutation(const TargetRegisterInfo *TRI) : TRI(TRI) {} |
| 114 | + |
| 115 | + void apply(ScheduleDAGInstrs *DAG) override { |
| 116 | + SUnit *NearestUseV0SU = nullptr; |
| 117 | + for (SUnit &SU : DAG->SUnits) { |
| 118 | + const MachineInstr *MI = SU.getInstr(); |
| 119 | + if (MI->findRegisterUseOperand(RISCV::V0, TRI)) |
| 120 | + NearestUseV0SU = &SU; |
| 121 | + |
| 122 | + if (NearestUseV0SU && NearestUseV0SU != &SU && isVectorMaskProducer(MI) && |
| 123 | + // For LMUL=8 cases, there will be more possibilities to spill. |
| 124 | + // FIXME: We should use RegPressureTracker to do fine-grained |
| 125 | + // controls. |
| 126 | + RISCVII::getLMul(MI->getDesc().TSFlags) != RISCVII::LMUL_8) |
| 127 | + DAG->addEdge(&SU, SDep(NearestUseV0SU, SDep::Artificial)); |
| 128 | + } |
| 129 | + } |
| 130 | +}; |
| 131 | + |
| 132 | +std::unique_ptr<ScheduleDAGMutation> |
| 133 | +createRISCVVectorMaskDAGMutation(const TargetRegisterInfo *TRI) { |
| 134 | + return std::make_unique<RISCVVectorMaskDAGMutation>(TRI); |
| 135 | +} |
| 136 | + |
| 137 | +} // namespace llvm |
0 commit comments