Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 22 additions & 4 deletions aiter/jit/utils/chip_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,15 +145,33 @@ def get_cu_num():
return cu_num


def _get_pci_chip_id(device_id=0):
import ctypes

libhip = ctypes.CDLL("libamdhip64.so")
chip_id = ctypes.c_int(0)
hipDeviceAttributePciChipId = 10019
err = libhip.hipDeviceGetAttribute(
ctypes.byref(chip_id),
hipDeviceAttributePciChipId,
device_id,
)
if err != 0:
raise RuntimeError(f"hipDeviceGetAttribute(PciChipId) failed with error {err}")
return chip_id.value


MI308_CHIP_IDS = {0x74A2, 0x74A8, 0x74B6, 0x74BC}


def get_device_name():
gfx = get_gfx()

if gfx == "gfx942":
cu = get_cu_num()
if cu == 304:
return "MI300"
elif cu == 80 or cu == 64:
chip_id = _get_pci_chip_id()
if chip_id in MI308_CHIP_IDS:
return "MI308"
return "MI300"
elif gfx == "gfx950":
return "MI350"
else:
Expand Down
11 changes: 5 additions & 6 deletions csrc/cpp_itfs/mha_fwd.cu
Original file line number Diff line number Diff line change
Expand Up @@ -65,15 +65,14 @@ std::string get_kernel_co_name(const std::string& cfg_co_name, const std::string
std::string co_name = cfg_co_name;
if(arch_id == "gfx942")
{
auto pos = cfg_co_name.rfind('/');
uint32_t cu_num = get_num_cu_func();
if(cu_num == 304)
auto pos = cfg_co_name.rfind('/');
if(is_mi308_device())
{
co_name = cfg_co_name.substr(0, pos + 1) + "MI300/" + cfg_co_name.substr(pos + 1);
co_name = cfg_co_name.substr(0, pos + 1) + "MI308/" + cfg_co_name.substr(pos + 1);
}
else if(cu_num == 80 || cu_num == 64)
else
{
co_name = cfg_co_name.substr(0, pos + 1) + "MI308/" + cfg_co_name.substr(pos + 1);
co_name = cfg_co_name.substr(0, pos + 1) + "MI300/" + cfg_co_name.substr(pos + 1);
}
}
return co_name;
Expand Down
21 changes: 21 additions & 0 deletions csrc/include/aiter_hip_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -216,3 +216,24 @@ static uint32_t get_num_cu_func()
static const uint32_t num_cu = get_num_cu_local();
return num_cu;
}

static int get_pci_chip_id()
{
static const int chip_id = []() {
hipDevice_t dev;
int id = 0;
HIP_CALL(hipGetDevice(&dev));
HIP_CALL(hipDeviceGetAttribute(&id, hipDeviceAttributePciChipId, dev));
AITER_LOG_INFO("pciChipId: 0x" << std::hex << id << std::dec
<< ", CU count: " << get_num_cu_func());
return id;
}();
return chip_id;
}

static bool is_mi308_device()
{
int chip_id = get_pci_chip_id();
return chip_id == 0x74a2 || chip_id == 0x74a8 ||
chip_id == 0x74b6 || chip_id == 0x74bc;
}
Loading