Skip to content

Commit 865d042

Browse files
committed
use cudaMemcpy3DPeerAsync
1 parent 1659cd1 commit 865d042

File tree

1 file changed

+17
-11
lines changed

1 file changed

+17
-11
lines changed

ggml-cuda.cu

+17-11
Original file line numberDiff line numberDiff line change
@@ -68,9 +68,9 @@
6868
#define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size, hipHostMallocDefault)
6969
#endif
7070
#define cudaMemcpy hipMemcpy
71-
#define cudaMemcpy2DAsync hipMemcpy2DAsync
7271
#define cudaMemcpyAsync hipMemcpyAsync
7372
#define cudaMemcpyPeerAsync hipMemcpyPeerAsync
73+
#define cudaMemcpy2DAsync hipMemcpy2DAsync
7474
#define cudaMemcpyDeviceToDevice hipMemcpyDeviceToDevice
7575
#define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost
7676
#define cudaMemcpyHostToDevice hipMemcpyHostToDevice
@@ -8258,17 +8258,23 @@ static void ggml_cuda_op_mul_mat(
82588258
float * dhf_dst_i = (float *) ((char *) dst_off_device + i02*nb2 + i03*nb3);
82598259
GGML_ASSERT(dst->nb[1] == ne0*sizeof(float));
82608260
dhf_dst_i += src1_col_0*ne0 + row_low[id];
8261-
8261+
#if !defined(GGML_USE_HIPBLAS)
82628262
if (kind == cudaMemcpyDeviceToDevice && id != g_main_device) {
8263-
// there is no cudaMemcpy2DPeerAsync so we need to copy each row separately
8264-
for (int64_t i = 0; i < src1_ncols; ++i) {
8265-
CUDA_CHECK(cudaMemcpyPeerAsync(dhf_dst_i + i*ne0, g_main_device,
8266-
dst_dd_i + i*row_diff, id,
8267-
row_diff*sizeof(float), stream));
8268-
}
8269-
} else {
8270-
CUDA_CHECK(cudaMemcpy2DAsync(dhf_dst_i, ne0*sizeof(float), dst_dd_i, row_diff*sizeof(float),
8271-
row_diff*sizeof(float), src1_ncols, kind, stream));
8263+
// cudaMemcpy2DAsync may fail with copies between vmm pools of different devices
8264+
cudaMemcpy3DPeerParms p = {};
8265+
p.dstDevice = g_main_device;
8266+
p.dstPtr = make_cudaPitchedPtr(dhf_dst_i, ne0*sizeof(float), ne0, src1_ncols);
8267+
p.srcDevice = id;
8268+
p.srcPtr = make_cudaPitchedPtr(dst_dd_i, row_diff*sizeof(float), row_diff, src1_ncols);
8269+
p.extent = make_cudaExtent(row_diff*sizeof(float), src1_ncols, 1);
8270+
CUDA_CHECK(cudaMemcpy3DPeerAsync(&p, stream));
8271+
} else
8272+
#endif
8273+
{
8274+
CUDA_CHECK(cudaMemcpy2DAsync(dhf_dst_i, ne0*sizeof(float),
8275+
dst_dd_i, row_diff*sizeof(float),
8276+
row_diff*sizeof(float), src1_ncols,
8277+
kind, stream));
82728278
}
82738279
} else {
82748280
float * dhf_dst_i = (float *) ((char *) dst_off_device + i02*nb2 + i03*nb3);

0 commit comments

Comments
 (0)