Skip to content

Commit 7641af6

Browse files
committed
HSA queue support (ggml-org#12)
* Creating HSA queue * Zero-init all members * Adding signal support
1 parent 5d90690 commit 7641af6

File tree

2 files changed

+65
-28
lines changed

2 files changed

+65
-28
lines changed

src/ggml-hsa/common.hpp

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -59,11 +59,11 @@ struct ggml_hsa_device_info {
5959
* @brief Information about a single HSA device.
6060
*/
6161
struct hsa_device_info {
62-
hsa_agent_t agent{}; ///< HSA agent associated with the device.
63-
hsa_device_type_t type{}; ///< Agent type.
64-
std::string name; ///< Agent name.
65-
hsa_memory_pool_info data_memory; ///< Pool for data.
66-
hsa_memory_pool_info kernarg_memory; ///< Pool for kernel arguments.
62+
hsa_agent_t agent{}; ///< HSA agent associated with the device.
63+
hsa_device_type_t type{}; ///< Agent type.
64+
std::string name; ///< Agent name.
65+
hsa_memory_pool_info data_memory{}; ///< Pool for data.
66+
hsa_memory_pool_info kernarg_memory{}; ///< Pool for kernel arguments.
6767
};
6868

6969
std::array<hsa_device_info, GGML_HSA_MAX_DEVICES> devices = {};
@@ -84,8 +84,18 @@ const ggml_hsa_device_info & ggml_hsa_info();
8484
* @brief Context for HSA backend operations.
8585
*/
8686
struct ggml_backend_hsa_context {
87-
std::int32_t device; ///< Device ID.
88-
std::string name; ///< Device name.
87+
std::int32_t device{}; ///< Device ID.
88+
std::string name; ///< Device name.
89+
hsa_queue_t* queue{}; ///< HSA queue associated with the context.
90+
hsa_signal_t dispatch_signal{}; ///< Signal to wait for dispatches.
8991

90-
explicit ggml_backend_hsa_context(std::int32_t device);
92+
ggml_backend_hsa_context(std::int32_t device, const ggml_hsa_device_info::hsa_device_info& device_info);
93+
94+
ggml_backend_hsa_context(const ggml_backend_hsa_context &) = delete;
95+
ggml_backend_hsa_context(ggml_backend_hsa_context &&) = delete;
96+
97+
~ggml_backend_hsa_context();
98+
99+
ggml_backend_hsa_context& operator=(const ggml_backend_hsa_context &) = delete;
100+
ggml_backend_hsa_context& operator=(ggml_backend_hsa_context &&) = delete;
91101
};

src/ggml-hsa/ggml-hsa.cpp

Lines changed: 47 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#include <cstdint>
88
#include <cstring>
99
#include <mutex>
10+
#include <stdexcept>
1011
#include <string>
1112
#include <vector>
1213

@@ -61,6 +62,13 @@ static std::string ggml_hsa_agent_name(hsa_agent_t agent) {
6162
return GGML_HSA_NAME + std::string{agent_name};
6263
}
6364

65+
// Returns the minimum queue size
66+
static std::uint32_t ggml_hsa_get_agent_min_queue_size(hsa_agent_t agent) {
67+
std::uint32_t min_queue_size = 0;
68+
HSA_CHECK(hsa_agent_get_info(agent, HSA_AGENT_INFO_QUEUE_MIN_SIZE, &min_queue_size));
69+
return min_queue_size;
70+
}
71+
6472
/**
6573
* @brief Populates the information in @p info from @p pool.
6674
*/
@@ -194,8 +202,26 @@ const ggml_hsa_device_info & ggml_hsa_info() {
194202
return info;
195203
}
196204

197-
/**
 * @brief Creates the context for device @p device.
 *
 * Creates an HSA queue on the agent of @p device_info, sized to the agent's minimum queue size,
 * and a signal with initial value 0 that is used to wait for dispatches.
 *
 * @param device Device ID.
 * @param device_info Information about the device; only its agent is used here.
 * @throws std::runtime_error if the queue or the signal cannot be created.
 */
ggml_backend_hsa_context::ggml_backend_hsa_context(std::int32_t device, const ggml_hsa_device_info::hsa_device_info & device_info) :
    device(device), name(ggml_hsa_format_name(device)) {
    // create queue
    const std::uint32_t min_queue_size = ggml_hsa_get_agent_min_queue_size(device_info.agent);
    if (auto status = hsa_queue_create(device_info.agent, min_queue_size, HSA_QUEUE_TYPE_SINGLE, nullptr, nullptr, 0, 0, &queue);
        status != HSA_STATUS_SUCCESS) {
        GGML_LOG_ERROR("%s: hsa_queue_create failed: %s\n", __func__, ggml_hsa_get_status_string(status));
        throw std::runtime_error("hsa_queue_create failed");
    }

    // create signal to wait for packets
    if (auto status = hsa_signal_create(0, 0, nullptr, &dispatch_signal); status != HSA_STATUS_SUCCESS) {
        GGML_LOG_ERROR("%s: hsa_signal_create failed: %s\n", __func__, ggml_hsa_get_status_string(status));
        // The destructor does not run when the constructor throws, so the queue created above
        // has to be released here. Best effort: the original failure is the one reported.
        static_cast<void>(hsa_queue_destroy(queue));
        queue = nullptr;
        throw std::runtime_error("hsa_signal_create failed");
    }
}
221+
222+
/**
 * @brief Releases the HSA resources held by the context.
 *
 * The dispatch signal is destroyed first, then the queue.
 */
ggml_backend_hsa_context::~ggml_backend_hsa_context() {
    HSA_CHECK(hsa_signal_destroy(dispatch_signal));

    HSA_CHECK(hsa_queue_destroy(queue));
}
200226

201227
// HSA buffer
@@ -538,13 +564,13 @@ void ggml_hsa_mul_mat_impl(ggml_backend_hsa_context &, const ggml_tensor * src0,
538564
m_out = m_in1.transpose() * m_in2;
539565
}
540566

541-
static void ggml_hsa_mul_mat(ggml_backend_hsa_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
567+
static enum ggml_status ggml_hsa_mul_mat(ggml_backend_hsa_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
542568
assert(src0->type == src1->type);
543-
assert(dst->type == dst->type);
569+
assert(src0->type == dst->type);
544570

545-
if (ggml_is_transposed(dst->src[0]) ||
546-
ggml_is_transposed(dst->src[1])) {
547-
GGML_ABORT("%s: %s: matmul on tranposed tensor not supported", __func__, ggml_op_name(dst->op));
571+
if (ggml_is_transposed(src0) || ggml_is_transposed(src1)) {
572+
GGML_LOG_ERROR("%s: %s: matmul on tranposed tensor not supported", __func__, ggml_op_name(dst->op));
573+
return GGML_STATUS_FAILED;
548574
}
549575

550576
switch (src0->type) {
@@ -558,9 +584,12 @@ static void ggml_hsa_mul_mat(ggml_backend_hsa_context & ctx, const ggml_tensor *
558584
ggml_hsa_mul_mat_impl<double>(ctx, src0, src1, dst);
559585
break;
560586
default:
561-
GGML_ABORT("%s: Unsupported type %s", __func__, ggml_type_name(src0->type));
587+
GGML_LOG_ERROR("%s: Unsupported type %s", __func__, ggml_type_name(src0->type));
588+
return GGML_STATUS_FAILED;
562589
}
563590
}
591+
592+
return GGML_STATUS_SUCCESS;
564593
}
565594

566595
////////////////////////////////////////////////////////////////////////////////
@@ -646,12 +675,15 @@ static bool ggml_backend_hsa_cpy_tensor_async(ggml_backend_t backend_src, ggml_b
646675
}
647676

648677
/**
 * @brief Waits until the dispatch signal of the context of @p backend is 0.
 *
 * The wait uses no timeout (@c UINT64_MAX) and the blocked wait state. Aborts if the value
 * returned by the wait is not 0.
 */
static void ggml_backend_hsa_synchronize(ggml_backend_t backend) {
    auto * ctx = static_cast<ggml_backend_hsa_context *>(backend->context);
    if (auto val = hsa_signal_wait_scacquire(ctx->dispatch_signal, HSA_SIGNAL_CONDITION_EQ, 0, UINT64_MAX, HSA_WAIT_STATE_BLOCKED);
        val != 0) {
        // The signal value may be 64 bits wide, so it is widened explicitly and printed with
        // %lld; passing it to %d is undefined behaviour in a varargs call.
        GGML_ABORT("%s: error: unexpected signal value (%lld)\n", __func__, static_cast<long long>(val));
    }
}
651684

652685
static enum ggml_status ggml_backend_hsa_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
653686
auto * ctx = static_cast<ggml_backend_hsa_context *>(backend->context);
654-
655687
ggml_status status = GGML_STATUS_SUCCESS;
656688
auto backend_cpu = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr);
657689

@@ -669,8 +701,7 @@ static enum ggml_status ggml_backend_hsa_graph_compute(ggml_backend_t backend, g
669701
case GGML_OP_VIEW:
670702
break;
671703
case GGML_OP_MUL_MAT: {
672-
// ggml_hsa_mul_mat(*ctx, node->src[0], node->src[1], node);
673-
status = ggml_backend_graph_compute(backend_cpu, cgraph);
704+
status = ggml_hsa_mul_mat(*ctx, node->src[0], node->src[1], node);
674705
break;
675706
}
676707
default: {
@@ -862,12 +893,8 @@ static bool ggml_backend_hsa_device_supports_op(ggml_backend_dev_t dev, const gg
862893
case GGML_OP_TRANSPOSE:
863894
case GGML_OP_VIEW:
864895
return true;
865-
case GGML_OP_MUL_MAT: {
866-
auto backend_cpu = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, NULL);
867-
auto result = ggml_backend_supports_op(backend_cpu, op);
868-
ggml_backend_free(backend_cpu);
869-
return result;
870-
}
896+
case GGML_OP_MUL_MAT:
897+
return true;
871898
default: {
872899
// GGML_LOG_ERROR("%s: error: unknown operator %s", __func__, ggml_op_name(op->op));
873900
// return false;
@@ -1036,9 +1063,9 @@ ggml_backend_t ggml_backend_hsa_init(int device) {
10361063

10371064
ggml_backend_hsa_context * ctx = nullptr;
10381065
try {
1039-
ctx = new ggml_backend_hsa_context{device};
1040-
} catch (const std::bad_alloc&) {
1041-
GGML_LOG_ERROR("%s: failed to allocate context\n", __func__);
1066+
ctx = new ggml_backend_hsa_context{device, info.devices[device]};
1067+
} catch (const std::exception&) {
1068+
GGML_LOG_ERROR("%s: failed to create context\n", __func__);
10421069
return nullptr;
10431070
}
10441071

0 commit comments

Comments
 (0)