diff --git a/mlir/lib/ExecutionEngine/SyclRuntimeWrappers.cpp b/mlir/lib/ExecutionEngine/SyclRuntimeWrappers.cpp index c250340c38fc7..806d49f342d29 100644 --- a/mlir/lib/ExecutionEngine/SyclRuntimeWrappers.cpp +++ b/mlir/lib/ExecutionEngine/SyclRuntimeWrappers.cpp @@ -11,8 +11,11 @@ //===----------------------------------------------------------------------===// #include +#include +#include #include #include +#include #ifdef _WIN32 #define SYCL_RUNTIME_EXPORT __declspec(dllexport) @@ -49,39 +52,53 @@ auto catchAll(F &&func) { } // namespace -static sycl::device getDefaultDevice() { - static sycl::device syclDevice; - static bool isDeviceInitialised = false; - if (!isDeviceInitialised) { - auto platformList = sycl::platform::get_platforms(); - for (const auto &platform : platformList) { - auto platformName = platform.get_info(); - bool isLevelZero = platformName.find("Level-Zero") != std::string::npos; - if (!isLevelZero) - continue; - - syclDevice = platform.get_devices()[0]; - isDeviceInitialised = true; - return syclDevice; +thread_local static int32_t defaultDevice = 0; +thread_local static bool isGpuPoolInitialized = false; +thread_local static bool isDefaultContextInitialized = false; +thread_local static std::vector *pGpuPool = nullptr; +thread_local static sycl::context *pDefaultContext = nullptr; + +static void initGpuPool() { + if (isGpuPoolInitialized) + return; + auto platformList = sycl::platform::get_platforms(); + for (const auto &platform : platformList) { + if (platform.get_backend() == sycl::backend::ext_oneapi_level_zero) { + auto gpuDevices = platform.get_devices(sycl::info::device_type::gpu); + if (gpuDevices.empty()) { + throw std::runtime_error("SyclRuntime: No GPU devices found!"); + } + pGpuPool = new std::vector{gpuDevices}; + isGpuPoolInitialized = true; + return; } - throw std::runtime_error("getDefaultDevice failed"); - } else - return syclDevice; + } + throw std::runtime_error("SyclRuntime: No GPU devices found!"); } -static sycl::context getDefaultContext() { - static sycl::context syclContext{getDefaultDevice()}; - return syclContext; +static sycl::device *getDefaultDevicePtr() { + initGpuPool(); + return &((*pGpuPool)[defaultDevice]); +} + +static sycl::context *getDefaultContextPtr() { + if (isDefaultContextInitialized) { + return pDefaultContext; + } + initGpuPool(); + pDefaultContext = new sycl::context(*pGpuPool); + isDefaultContextInitialized = true; + return pDefaultContext; } static void *allocDeviceMemory(sycl::queue *queue, size_t size, bool isShared) { void *memPtr = nullptr; if (isShared) { - memPtr = sycl::aligned_alloc_shared(64, size, getDefaultDevice(), - getDefaultContext()); + memPtr = sycl::aligned_alloc_shared(64, size, *getDefaultDevicePtr(), + *getDefaultContextPtr()); } else { - memPtr = sycl::aligned_alloc_device(64, size, getDefaultDevice(), - getDefaultContext()); + memPtr = sycl::aligned_alloc_device(64, size, *getDefaultDevicePtr(), + *getDefaultContextPtr()); } if (memPtr == nullptr) { throw std::runtime_error("mem allocation failed!"); @@ -90,7 +107,13 @@ static void *allocDeviceMemory(sycl::queue *queue, size_t size, bool isShared) { } static void deallocDeviceMemory(sycl::queue *queue, void *ptr) { - sycl::free(ptr, *queue); + if (queue == nullptr) { + queue = new sycl::queue(*getDefaultContextPtr(), *getDefaultDevicePtr()); + sycl::free(ptr, *queue); + delete queue; + } else { + sycl::free(ptr, *queue); + } } static ze_module_handle_t loadModule(const void *data, size_t dataSize) { @@ -104,9 +127,9 @@ static ze_module_handle_t loadModule(const void *data, size_t dataSize) { nullptr, nullptr}; auto zeDevice = sycl::get_native( - getDefaultDevice()); + *getDefaultDevicePtr()); auto zeContext = sycl::get_native( - getDefaultContext()); + *getDefaultContextPtr()); L0_SAFE_CALL(zeModuleCreate(zeContext, zeDevice, &desc, &zeModule, nullptr)); return zeModule; } @@ -115,17 +138,33 @@ static sycl::kernel *getKernel(ze_module_handle_t zeModule, const char *name) { assert(zeModule); assert(name); ze_kernel_handle_t zeKernel; - ze_kernel_desc_t desc = {}; - desc.pKernelName = name; + ze_kernel_desc_t desc = {ZE_STRUCTURE_TYPE_KERNEL_DESC, nullptr, + 0, // flags + name}; + + ze_result_t result = zeKernelCreate(zeModule, &desc, &zeKernel); + + // Check if there are unresolved imports + if (result == ZE_RESULT_ERROR_INVALID_MODULE_UNLINKED) { + fprintf(stdout, "Unresolved imports!!!\n"); + fflush(stdout); + abort(); + } + + // Check to see if the kernel name was found in the supplied module + if (result == ZE_RESULT_ERROR_INVALID_KERNEL_NAME) { + fprintf(stdout, "Invalid kernel name: %s !!!\n", name); + fflush(stdout); + abort(); + } - L0_SAFE_CALL(zeKernelCreate(zeModule, &desc, &zeKernel)); sycl::kernel_bundle kernelBundle = sycl::make_kernel_bundle( - {zeModule}, getDefaultContext()); + {zeModule}, *getDefaultContextPtr()); auto kernel = sycl::make_kernel( - {kernelBundle, zeKernel}, getDefaultContext()); + {kernelBundle, zeKernel}, *getDefaultContextPtr()); return new sycl::kernel(kernel); } @@ -152,7 +191,7 @@ extern "C" SYCL_RUNTIME_EXPORT sycl::queue *mgpuStreamCreate() { return catchAll([&]() { sycl::queue *queue = - new sycl::queue(getDefaultContext(), getDefaultDevice()); + new sycl::queue(*getDefaultContextPtr(), *getDefaultDevicePtr()); return queue; }); } @@ -207,3 +246,11 @@ mgpuModuleUnload(ze_module_handle_t module) { catchAll([&]() { L0_SAFE_CALL(zeModuleDestroy(module)); }); } + +extern "C" SYCL_RUNTIME_EXPORT void mgpuSetDefaultDevice(int32_t device) { + initGpuPool(); + if (device >= pGpuPool->size()) { + throw std::runtime_error("SyclRuntime: Invalid device index!"); + } + defaultDevice = device; +}