7
7
#include < cstdint>
8
8
#include < cstring>
9
9
#include < mutex>
10
+ #include < stdexcept>
10
11
#include < string>
11
12
#include < vector>
12
13
@@ -61,6 +62,13 @@ static std::string ggml_hsa_agent_name(hsa_agent_t agent) {
61
62
return GGML_HSA_NAME + std::string{agent_name};
62
63
}
63
64
65
+ // Returns the minimum queue size
66
+ static std::uint32_t ggml_hsa_get_agent_min_queue_size (hsa_agent_t agent) {
67
+ std::uint32_t min_queue_size = 0 ;
68
+ HSA_CHECK (hsa_agent_get_info (agent, HSA_AGENT_INFO_QUEUE_MIN_SIZE, &min_queue_size));
69
+ return min_queue_size;
70
+ }
71
+
64
72
/* *
65
73
* @brief Populates the information in @p info from @p pool.
66
74
*/
@@ -194,8 +202,26 @@ const ggml_hsa_device_info & ggml_hsa_info() {
194
202
return info;
195
203
}
196
204
197
- ggml_backend_hsa_context::ggml_backend_hsa_context (std::int32_t device) :
205
+ ggml_backend_hsa_context::ggml_backend_hsa_context (std::int32_t device, const ggml_hsa_device_info::hsa_device_info & device_info ) :
198
206
device(device), name(ggml_hsa_format_name(device)) {
207
+ // create queue
208
+ const std::uint32_t min_queue_size = ggml_hsa_get_agent_min_queue_size (device_info.agent );
209
+ if (auto status = hsa_queue_create (device_info.agent , min_queue_size, HSA_QUEUE_TYPE_SINGLE, nullptr , nullptr , 0 , 0 , &queue);
210
+ status != HSA_STATUS_SUCCESS) {
211
+ GGML_LOG_ERROR (" %s: hsa_queue_create failed: %s" , __func__, ggml_hsa_get_status_string (status));
212
+ throw std::runtime_error (" hsa_queue_create failed" );
213
+ }
214
+
215
+ // create signal to wait for packets
216
+ if (auto status = hsa_signal_create (0 , 0 , nullptr , &dispatch_signal); status != HSA_STATUS_SUCCESS) {
217
+ GGML_LOG_ERROR (" %s: hsa_signal_create failed: %s" , __func__, ggml_hsa_get_status_string (status));
218
+ throw std::runtime_error (" hsa_signal_create failed" );
219
+ }
220
+ }
221
+
222
+ ggml_backend_hsa_context::~ggml_backend_hsa_context () {
223
+ HSA_CHECK (hsa_signal_destroy (dispatch_signal));
224
+ HSA_CHECK (hsa_queue_destroy (queue));
199
225
}
200
226
201
227
// HSA buffer
@@ -538,13 +564,13 @@ void ggml_hsa_mul_mat_impl(ggml_backend_hsa_context &, const ggml_tensor * src0,
538
564
m_out = m_in1.transpose () * m_in2;
539
565
}
540
566
541
- static void ggml_hsa_mul_mat (ggml_backend_hsa_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
567
+ static enum ggml_status ggml_hsa_mul_mat (ggml_backend_hsa_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
542
568
assert (src0->type == src1->type );
543
- assert (dst ->type == dst->type );
569
+ assert (src0 ->type == dst->type );
544
570
545
- if (ggml_is_transposed (dst-> src [ 0 ] ) ||
546
- ggml_is_transposed ( dst->src [ 1 ])) {
547
- GGML_ABORT ( " %s: %s: matmul on tranposed tensor not supported " , __func__, ggml_op_name (dst-> op )) ;
571
+ if (ggml_is_transposed (src0 ) || ggml_is_transposed (src1)) {
572
+ GGML_LOG_ERROR ( " %s: %s: matmul on tranposed tensor not supported " , __func__, ggml_op_name ( dst->op ));
573
+ return GGML_STATUS_FAILED ;
548
574
}
549
575
550
576
switch (src0->type ) {
@@ -558,9 +584,12 @@ static void ggml_hsa_mul_mat(ggml_backend_hsa_context & ctx, const ggml_tensor *
558
584
ggml_hsa_mul_mat_impl<double >(ctx, src0, src1, dst);
559
585
break ;
560
586
default :
561
- GGML_ABORT (" %s: Unsupported type %s" , __func__, ggml_type_name (src0->type ));
587
+ GGML_LOG_ERROR (" %s: Unsupported type %s" , __func__, ggml_type_name (src0->type ));
588
+ return GGML_STATUS_FAILED;
562
589
}
563
590
}
591
+
592
+ return GGML_STATUS_SUCCESS;
564
593
}
565
594
566
595
// //////////////////////////////////////////////////////////////////////////////
@@ -646,12 +675,15 @@ static bool ggml_backend_hsa_cpy_tensor_async(ggml_backend_t backend_src, ggml_b
646
675
}
647
676
648
677
/**
 * @brief Blocks until the dispatch signal of the context of @p backend is 0.
 *
 * Waits on the context's dispatch signal with an unbounded timeout hint and
 * aborts if the observed value is anything other than 0.
 *
 * @param backend backend whose context holds the dispatch signal
 */
static void ggml_backend_hsa_synchronize(ggml_backend_t backend) {
    auto * ctx = static_cast<ggml_backend_hsa_context *>(backend->context);
    // NOTE(review): the HSA runtime documents that the returned value might not
    // satisfy the wait condition; confirm whether a retry is wanted instead of
    // aborting on a positive value.
    if (auto val = hsa_signal_wait_scacquire(ctx->dispatch_signal, HSA_SIGNAL_CONDITION_EQ, 0, UINT64_MAX, HSA_WAIT_STATE_BLOCKED);
        val != 0) {
        // hsa_signal_value_t can be wider than int, so widen explicitly to
        // match the format specifier instead of passing it to %d.
        GGML_ABORT("%s: error: unexpected signal value (%lld)\n", __func__, static_cast<long long>(val));
    }
}
651
684
652
685
static enum ggml_status ggml_backend_hsa_graph_compute (ggml_backend_t backend, ggml_cgraph * cgraph) {
653
686
auto * ctx = static_cast <ggml_backend_hsa_context *>(backend->context );
654
-
655
687
ggml_status status = GGML_STATUS_SUCCESS;
656
688
auto backend_cpu = ggml_backend_init_by_type (GGML_BACKEND_DEVICE_TYPE_CPU, nullptr );
657
689
@@ -669,8 +701,7 @@ static enum ggml_status ggml_backend_hsa_graph_compute(ggml_backend_t backend, g
669
701
case GGML_OP_VIEW:
670
702
break ;
671
703
case GGML_OP_MUL_MAT: {
672
- // ggml_hsa_mul_mat(*ctx, node->src[0], node->src[1], node);
673
- status = ggml_backend_graph_compute (backend_cpu, cgraph);
704
+ status = ggml_hsa_mul_mat (*ctx, node->src [0 ], node->src [1 ], node);
674
705
break ;
675
706
}
676
707
default : {
@@ -862,12 +893,8 @@ static bool ggml_backend_hsa_device_supports_op(ggml_backend_dev_t dev, const gg
862
893
case GGML_OP_TRANSPOSE:
863
894
case GGML_OP_VIEW:
864
895
return true ;
865
- case GGML_OP_MUL_MAT: {
866
- auto backend_cpu = ggml_backend_init_by_type (GGML_BACKEND_DEVICE_TYPE_CPU, NULL );
867
- auto result = ggml_backend_supports_op (backend_cpu, op);
868
- ggml_backend_free (backend_cpu);
869
- return result;
870
- }
896
+ case GGML_OP_MUL_MAT:
897
+ return true ;
871
898
default : {
872
899
// GGML_LOG_ERROR("%s: error: unknown operator %s", __func__, ggml_op_name(op->op));
873
900
// return false;
@@ -1036,9 +1063,9 @@ ggml_backend_t ggml_backend_hsa_init(int device) {
1036
1063
1037
1064
ggml_backend_hsa_context * ctx = nullptr ;
1038
1065
try {
1039
- ctx = new ggml_backend_hsa_context{device};
1040
- } catch (const std::bad_alloc &) {
1041
- GGML_LOG_ERROR (" %s: failed to allocate context\n " , __func__);
1066
+ ctx = new ggml_backend_hsa_context{device, info. devices [device] };
1067
+ } catch (const std::exception &) {
1068
+ GGML_LOG_ERROR (" %s: failed to create context\n " , __func__);
1042
1069
return nullptr ;
1043
1070
}
1044
1071
0 commit comments