Commit e35e377 (parent 59aa825)

--main-gpu CLI option

9 files changed: +79 -18 lines
examples/common.cpp (+13 -1)

@@ -295,6 +295,16 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
 #else
             fprintf(stderr, "warning: not compiled with GPU offload support, --n-gpu-layers option will be ignored\n");
             fprintf(stderr, "warning: see main README.md for information on enabling GPU BLAS support\n");
+#endif
+        } else if (arg == "--main-gpu" || arg == "-mg") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+#ifdef GGML_USE_CUBLAS
+            params.main_gpu = std::stoi(argv[i]);
+#else
+            fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS. It is not possible to set a main GPU.\n");
 #endif
         } else if (arg == "--tensor-split" || arg == "-ts") {
             if (++i >= argc) {
@@ -318,7 +328,7 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
                 }
             }
 #else
-            fprintf(stderr, "WARNING: llama.cpp was compiled without cuBLAS. It is not possible to set a tensor split.\n");
+            fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS. It is not possible to set a tensor split.\n");
 #endif // GGML_USE_CUBLAS
         } else if (arg == "--no-mmap") {
             params.use_mmap = false;
@@ -465,6 +475,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     fprintf(stderr, " number of layers to store in VRAM\n");
     fprintf(stderr, " -ts SPLIT --tensor-split SPLIT\n");
     fprintf(stderr, " how to split tensors across multiple GPUs, comma-separated list of proportions, e.g. 3,1\n");
+    fprintf(stderr, " -mg i, --main-gpu i the GPU to use for scratch and small tensors\n" );
 #endif
     fprintf(stderr, " --mtest compute maximum memory usage\n");
     fprintf(stderr, " --export export the computation graph to 'llama.ggml'\n");
@@ -512,6 +523,7 @@ struct llama_context * llama_init_from_gpt_params(const gpt_params & params) {
     lparams.n_ctx = params.n_ctx;
     lparams.n_batch = params.n_batch;
     lparams.n_gpu_layers = params.n_gpu_layers;
+    lparams.main_gpu = params.main_gpu;
     memcpy(lparams.tensor_split, params.tensor_split, LLAMA_MAX_DEVICES*sizeof(float));
     lparams.seed = params.seed;
     lparams.f16_kv = params.memory_f16;

examples/common.h (+7 -6)

@@ -21,13 +21,14 @@
 int32_t get_num_physical_cores();
 
 struct gpt_params {
-    int32_t seed          = -1;  // RNG seed
+    int32_t seed          = -1;   // RNG seed
     int32_t n_threads     = get_num_physical_cores();
-    int32_t n_predict     = -1;  // new tokens to predict
-    int32_t n_ctx         = 512; // context size
-    int32_t n_batch       = 512; // batch size for prompt processing (must be >=32 to use BLAS)
-    int32_t n_keep        = 0;   // number of tokens to keep from initial prompt
-    int32_t n_gpu_layers  = 0;   // number of layers to store in VRAM
+    int32_t n_predict     = -1;   // new tokens to predict
+    int32_t n_ctx         = 512;  // context size
+    int32_t n_batch       = 512;  // batch size for prompt processing (must be >=32 to use BLAS)
+    int32_t n_keep        = 0;    // number of tokens to keep from initial prompt
+    int32_t n_gpu_layers  = 0;    // number of layers to store in VRAM
+    int32_t main_gpu      = 0;    // the GPU that is used for scratch and small tensors
     float   tensor_split[LLAMA_MAX_DEVICES] = {0}; // how split tensors should be distributed across GPUs
 
     // sampling parameters

examples/main/README.md (+2 -0)

@@ -286,5 +286,7 @@ These options provide extra functionality and customization when running the LLa
 - `--verbose-prompt`: Print the prompt before generating text.
 - `--mtest`: Test the model's functionality by running a series of tests to ensure it's working properly.
 - `-ngl N, --n-gpu-layers N`: When compiled with appropriate support (currently CLBlast or cuBLAS), this option allows offloading some layers to the GPU for computation. Generally results in increased performance.
+- `-mg i, --main-gpu i`: When using multiple GPUs this option controls which GPU is used for small tensors for which the overhead of splitting the computation across all GPUs is not worthwhile. The GPU in question will use slightly more VRAM to store a scratch buffer for temporary results. By default GPU 0 is used. Requires cuBLAS.
+- `-ts SPLIT, --tensor-split SPLIT`: When using multiple GPUs this option controls how large tensors should be split across all GPUs. `SPLIT` is a comma-separated list of non-negative values that assigns the proportion of data that each GPU should get in order. For example, "3,2" will assign 60% of the data to GPU 0 and 40% to GPU 1. By default the data is split in proportion to VRAM but this may not be optimal for performance. Requires cuBLAS.
 - `--lora FNAME`: Apply a LoRA (Low-Rank Adaptation) adapter to the model (implies --no-mmap). This allows you to adapt the pretrained model to specific tasks or domains.
 - `--lora-base FNAME`: Optional model to use as a base for the layers modified by the LoRA adapter. This flag is used in conjunction with the `--lora` flag, and specifies the base model for the adaptation.
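
As a concrete illustration of the two options documented above, a hypothetical invocation (model path, layer count, and split values are placeholders, not taken from the commit) could look like:

./main -m models/7B/ggml-model.bin -ngl 40 -mg 1 -ts 3,1

With these settings GPU 1 keeps the scratch buffer for small tensors while large tensors are split roughly 75%/25% between GPU 0 and GPU 1.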

examples/server/README.md (+2 -0)

@@ -287,6 +287,8 @@ Test();
 - `-m FNAME, --model FNAME`: Specify the path to the LLaMA model file (e.g., `models/7B/ggml-model.bin`).
 - `-c N, --ctx-size N`: Set the size of the prompt context. The default is 512, but LLaMA models were built with a context of 2048, which will provide better results for longer input/inference.
 - `-ngl N, --n-gpu-layers N`: When compiled with appropriate support (currently CLBlast or cuBLAS), this option allows offloading some layers to the GPU for computation. Generally results in increased performance.
+- `-mg i, --main-gpu i`: When using multiple GPUs this option controls which GPU is used for small tensors for which the overhead of splitting the computation across all GPUs is not worthwhile. The GPU in question will use slightly more VRAM to store a scratch buffer for temporary results. By default GPU 0 is used. Requires cuBLAS.
+- `-ts SPLIT, --tensor-split SPLIT`: When using multiple GPUs this option controls how large tensors should be split across all GPUs. `SPLIT` is a comma-separated list of non-negative values that assigns the proportion of data that each GPU should get in order. For example, "3,2" will assign 60% of the data to GPU 0 and 40% to GPU 1. By default the data is split in proportion to VRAM but this may not be optimal for performance. Requires cuBLAS.
 - `--embedding`: Enable the embedding mode. **Completion function doesn't work in this mode**.
 - `--host`: Set the hostname or ip address to listen. Default `127.0.0.1`;
 - `--port`: Set the port to listen. Default: `8080`.
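
The same two flags are accepted by the server example once this commit is applied; a hypothetical launch (values are placeholders) would be:

./server -m models/7B/ggml-model.bin -ngl 40 -mg 0 -ts 3,1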

examples/server/server.cpp (+15 -0)

@@ -403,6 +403,8 @@ void server_print_usage(int /*argc*/, char **argv, const gpt_params &params)
     fprintf(stderr, " number of layers to store in VRAM\n");
     fprintf(stderr, " -ts SPLIT --tensor-split SPLIT\n");
     fprintf(stderr, " how to split tensors across multiple GPUs, comma-separated list of proportions, e.g. 3,1\n");
+    fprintf(stderr, " how to split tensors across multiple GPUs, comma-separated list of proportions, e.g. 3,1\n");
+    fprintf(stderr, " -mg i, --main-gpu i the GPU to use for scratch and small tensors\n" );
 #endif
     fprintf(stderr, " -m FNAME, --model FNAME\n");
     fprintf(stderr, " model path (default: %s)\n", params.model.c_str());
@@ -536,6 +538,19 @@ bool server_params_parse(int argc, char **argv, server_params &sparams, gpt_para
 #else
             fprintf(stderr, "WARNING: llama.cpp was compiled without cuBLAS. It is not possible to set a tensor split.\n");
 #endif // GGML_USE_CUBLAS
+        }
+        else if (arg == "--main-gpu" || arg == "-mg")
+        {
+            if (++i >= argc)
+            {
+                invalid_param = true;
+                break;
+            }
+#ifdef GGML_USE_CUBLAS
+            params.main_gpu = std::stoi(argv[i]);
+#else
+            fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS. It is not possible to set a main GPU.\n");
+#endif
         }
         else
         {

ggml-cuda.cu (+30 -7)

@@ -561,15 +561,15 @@ void ggml_init_cublas() {
         GGML_ASSERT(g_device_count <= GGML_CUDA_MAX_DEVICES);
         int64_t total_vram = 0;
         fprintf(stderr, "%s: found %d CUDA devices:\n", __func__, g_device_count);
-        for (int i = 0; i < g_device_count; ++i) {
+        for (int id = 0; id < g_device_count; ++id) {
             cudaDeviceProp prop;
-            CUDA_CHECK(cudaGetDeviceProperties(&prop, i));
-            fprintf(stderr, " %d. %s\n", i+1, prop.name);
-            g_tensor_split[i] = total_vram;
+            CUDA_CHECK(cudaGetDeviceProperties(&prop, id));
+            fprintf(stderr, " Device %d: %s\n", id, prop.name);
+            g_tensor_split[id] = total_vram;
             total_vram += prop.totalGlobalMem;
         }
-        for (int i = 0; i < g_device_count; ++i) {
-            g_tensor_split[i] /= total_vram;
+        for (int id = 0; id < g_device_count; ++id) {
+            g_tensor_split[id] /= total_vram;
         }
 
         for (int id = 0; id < g_device_count; ++id) {
@@ -835,7 +835,10 @@ inline void ggml_cuda_op_mul_mat_cublas(
 
     int id;
     CUDA_CHECK(cudaGetDevice(&id));
-    const int ldc = dst->backend == GGML_BACKEND_GPU && id == g_main_device ? ne0 : i01_diff;
+
+    // the main device has a larger memory buffer to hold the results from all GPUs
+    // ldc == nrows of the matrix that cuBLAS writes into
+    int ldc = dst->backend == GGML_BACKEND_GPU && id == g_main_device ? ne0 : i01_diff;
 
     CUBLAS_CHECK(cublasSetStream(g_cublas_handles[id], cudaStream_main));
     CUBLAS_CHECK(
@@ -1041,6 +1044,12 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
             dst_ddf_i -= (row_low % ne0)*ne1;
         }
 
+        // the main device memory buffer can be on VRAM scratch, with space for all partial results
+        // in that case an offset on dst_ddf_i is needed
+        if (dst->backend == GGML_BACKEND_GPU && id == g_main_device) {
+            dst_ddf_i += i01_low; // offset is 0 if no tensor split
+        }
+
         // copy src0, src1 to device if necessary
         if (use_src1) {
             if (src1->backend == GGML_BACKEND_CPU) {
@@ -1322,6 +1331,20 @@ void ggml_cuda_assign_buffers(struct ggml_tensor * tensor) {
     tensor->extra = extra;
 }
 
+void ggml_cuda_set_main_device(int main_device) {
+    if (main_device > g_device_count) {
+        fprintf(stderr, "warning: cannot set main_device=%d because there are only %d devices. Using device %d instead.\n",
+                main_device, g_device_count, g_main_device);
+        return;
+    }
+    g_main_device = main_device;
+    if (g_device_count > 1) {
+        cudaDeviceProp prop;
+        CUDA_CHECK(cudaGetDeviceProperties(&prop, g_main_device));
+        fprintf(stderr, "%s: using device %d (%s) as main device\n", __func__, g_main_device, prop.name);
+    }
+}
+
 void ggml_cuda_set_scratch_size(size_t scratch_size) {
     g_scratch_size = scratch_size;
 }
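
For context on the loop above that fills g_tensor_split: the default split is a prefix sum of each device's VRAM divided by the total, so entry id marks the fraction at which device id's share begins. A standalone C++ sketch of that computation, with made-up device sizes (not part of the commit):

#include <cstdio>

int main() {
    // hypothetical VRAM per device, e.g. a 24 GB and an 8 GB card
    const double vram[] = {24e9, 8e9};
    const int n_devices = 2;

    double split[2];
    double total_vram = 0.0;
    for (int id = 0; id < n_devices; ++id) {
        split[id]   = total_vram;  // cumulative VRAM of the devices before this one
        total_vram += vram[id];
    }
    for (int id = 0; id < n_devices; ++id) {
        split[id] /= total_vram;   // normalize to a fraction of the total
        printf("g_tensor_split[%d] = %.3f\n", id, split[id]);
    }
    return 0;
}

This prints 0.000 and 0.750, i.e. device 0's share starts at fraction 0 and device 1's at 0.75, which matches the README's statement that by default the data is split in proportion to VRAM.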

ggml-cuda.h (+1 -0)

@@ -27,6 +27,7 @@ void ggml_cuda_host_free(void * ptr);
 void ggml_cuda_load_data(const char * fname, struct ggml_tensor * tensors, size_t offset);
 void ggml_cuda_free_data(struct ggml_tensor * tensor);
 void ggml_cuda_assign_buffers(struct ggml_tensor * tensor);
+void ggml_cuda_set_main_device(int main_device);
 void ggml_cuda_set_scratch_size(size_t scratch_size);
 bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor);

llama.cpp (+8 -4)

@@ -857,6 +857,7 @@ struct llama_context_params llama_context_default_params() {
        /*.n_ctx =*/ 512,
        /*.n_batch =*/ 512,
        /*.gpu_layers =*/ 0,
+       /*.main_gpu =*/ 0,
        /*.tensor_split =*/ {0},
        /*.seed =*/ -1,
        /*.f16_kv =*/ true,
@@ -943,6 +944,7 @@ static void llama_model_load_internal(
         int n_ctx,
         int n_batch,
         int n_gpu_layers,
+        int main_gpu,
         const float * tensor_split,
         ggml_type memory_type,
         bool use_mmap,
@@ -1039,6 +1041,7 @@ static void llama_model_load_internal(
 
 #if defined(GGML_USE_CUBLAS)
     fprintf(stderr, "%s: using CUDA for GPU acceleration\n", __func__);
+    ggml_cuda_set_main_device(main_gpu);
 #define LLAMA_BACKEND_OFFLOAD GGML_BACKEND_GPU
 #define LLAMA_BACKEND_OFFLOAD_SPLIT GGML_BACKEND_GPU_SPLIT
 #elif defined(GGML_USE_CLBLAST)
@@ -1225,6 +1228,7 @@ static bool llama_model_load(
         int n_ctx,
         int n_batch,
         int n_gpu_layers,
+        int main_gpu,
         float * tensor_split,
         ggml_type memory_type,
         bool use_mmap,
@@ -1233,8 +1237,8 @@ static bool llama_model_load(
         llama_progress_callback progress_callback,
         void *progress_callback_user_data) {
     try {
-        llama_model_load_internal(fname, lctx, n_ctx, n_batch, n_gpu_layers, tensor_split, memory_type, use_mmap,
-                                  use_mlock, vocab_only, progress_callback, progress_callback_user_data);
+        llama_model_load_internal(fname, lctx, n_ctx, n_batch, n_gpu_layers, main_gpu, tensor_split, memory_type,
+                                  use_mmap, use_mlock, vocab_only, progress_callback, progress_callback_user_data);
         return true;
     } catch (const std::string & err) {
         fprintf(stderr, "error loading model: %s\n", err.c_str());
@@ -2400,8 +2404,8 @@ struct llama_context * llama_init_from_file(
     ggml_type memory_type = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32;
 
     if (!llama_model_load(path_model, *ctx, params.n_ctx, params.n_batch, params.n_gpu_layers,
-            params.tensor_split, memory_type, params.use_mmap, params.use_mlock, params.vocab_only,
-            params.progress_callback, params.progress_callback_user_data)) {
+            params.main_gpu, params.tensor_split, memory_type, params.use_mmap, params.use_mlock,
+            params.vocab_only, params.progress_callback, params.progress_callback_user_data)) {
         fprintf(stderr, "%s: failed to load model\n", __func__);
         llama_free(ctx);
         return nullptr;

llama.h (+1 -0)

@@ -75,6 +75,7 @@ extern "C" {
         int n_ctx; // text context
         int n_batch; // prompt processing batch size
         int n_gpu_layers; // number of layers to store in VRAM
+        int main_gpu; // the GPU that is used for scratch and small tensors
         float tensor_split[LLAMA_MAX_DEVICES]; // how to split layers across multiple GPUs
         int seed; // RNG seed, -1 for random
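
A minimal sketch of how a caller could drive the new field through the C API, mirroring `-mg 1 -ts 3,1` on the command line (model path, layer count, and split values are illustrative assumptions, not part of the commit):

#include <cstdio>

#include "llama.h"

int main() {
    llama_context_params lparams = llama_context_default_params();

    lparams.n_gpu_layers    = 40;  // hypothetical number of layers to offload
    lparams.main_gpu        = 1;   // GPU 1 holds the scratch buffer for small tensors
    lparams.tensor_split[0] = 3;   // large tensors: ~75% on GPU 0 ...
    lparams.tensor_split[1] = 1;   // ... and ~25% on GPU 1

    llama_context * ctx = llama_init_from_file("models/7B/ggml-model.bin", lparams);
    if (ctx == NULL) {
        fprintf(stderr, "failed to load model\n");
        return 1;
    }

    llama_free(ctx);
    return 0;
}

This is the same forwarding that llama_init_from_gpt_params in examples/common.cpp performs for the command-line programs (lparams.main_gpu = params.main_gpu).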
