llama : fix vram_scratch var

This commit is contained in:
Georgi Gerganov 2023-06-06 22:54:39 +03:00
parent 2a4e41a086
commit 2d7bf110ed
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

View File

@@ -1076,6 +1076,7 @@ static void llama_model_load_internal(
 // prepare memory for the weights
 size_t vram_weights = 0;
+size_t vram_scratch = 0;
 {
 const uint32_t n_embd = hparams.n_embd;
 const uint32_t n_layer = hparams.n_layer;
@@ -1152,8 +1153,9 @@ static void llama_model_load_internal(
 fprintf(stderr, "%s: mem required = %7.2f MB (+ %7.2f MB per state)\n", __func__,
 mem_required / 1024.0 / 1024.0, mem_required_state / 1024.0 / 1024.0);
+(void) vram_scratch;
 #ifdef GGML_USE_CUBLAS
-const size_t vram_scratch = n_batch * MB;
+vram_scratch = n_batch * MB;
 ggml_cuda_set_scratch_size(vram_scratch);
 if (n_gpu_layers > 0) {
 fprintf(stderr, "%s: allocating batch_size x 1 MB = %ld MB VRAM for the scratch buffer\n",