mirror of https://github.com/ggerganov/llama.cpp.git

Store layers in VRAM

commit 3ed4588e22
parent d052a0ed4c
examples/common.cpp

@@ -271,6 +271,12 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
             params.use_color = true;
         } else if (arg == "--mlock") {
             params.use_mlock = true;
+        } else if (arg == "--gpu_layers") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.gpu_layers = std::stoi(argv[i]);
         } else if (arg == "--no-mmap") {
             params.use_mmap = false;
         } else if (arg == "--mtest") {
@@ -406,6 +412,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     if (llama_mmap_supported()) {
         fprintf(stderr, "  --no-mmap             do not memory-map model (slower load but may reduce pageouts if not using mlock)\n");
     }
+    fprintf(stderr, "  --gpu_layers          number of layers to store in VRAM");
     fprintf(stderr, "  --mtest               compute maximum memory usage\n");
     fprintf(stderr, "  --verbose-prompt      print prompt before generation\n");
     fprintf(stderr, "  --lora FNAME          apply LoRA adapter (implies --no-mmap)\n");
@@ -454,6 +461,7 @@ struct llama_context * llama_init_from_gpt_params(const gpt_params & params) {
     lparams.f16_kv     = params.memory_f16;
     lparams.use_mmap   = params.use_mmap;
     lparams.use_mlock  = params.use_mlock;
+    lparams.gpu_layers = params.gpu_layers;
     lparams.logits_all = params.perplexity;
     lparams.embedding  = params.embedding;
 
examples/common.h

@@ -68,6 +68,7 @@ struct gpt_params {
     bool perplexity     = false; // compute perplexity over the prompt
     bool use_mmap       = true;  // use mmap for faster loads
     bool use_mlock      = false; // use mlock to keep model in memory
+    int  gpu_layers     = 0;     // number of layers to store in VRAM
     bool mem_test       = false; // compute maximum memory usage
     bool verbose_prompt = false; // print prompt tokens before generation
 };
ggml-cuda.cu (41 changed lines)
@@ -349,7 +349,7 @@ static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
 }
 
 // buffer pool for cuda
-#define MAX_CUDA_BUFFERS 16
+#define MAX_CUDA_BUFFERS 256
 
 struct scoped_spin_lock {
     std::atomic_flag& lock;
@@ -678,9 +678,15 @@ static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor
             float * c_D = d_D + i * d_ne;
             char  * c_Q = d_Q + i * q_sz;
 
-            if (ne11 == 1) {
-                // copy src0 to device
+            // copy src0 to device if necessary
+            if (src0->backend == GGML_BACKEND_CPU) {
                 CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_Q, src0, i03, i02, cudaStream2));
+            } else if (src0->backend == GGML_BACKEND_CUDA) {
+                c_Q = ((char *) src0->data) + i * q_sz;
+            } else {
+                GGML_ASSERT(false);
+            }
+            if (ne11 == 1) {
                 CUDA_CHECK(cudaEventRecord(cudaEvent, cudaStream2));
 
                 // copy src1 to device
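The reshuffled branch above is the heart of the change: the matmul path now asks where src0 lives before touching it, uploading only host-resident weights and reusing the device pointer when the tensor has already been placed in VRAM. A minimal standalone sketch of that dispatch pattern, with hypothetical names and a plain cudaMemcpy instead of the commit's pooled, stream-based copy helpers:

#include <cuda_runtime.h>
#include <cstddef>

enum class Backend { CPU, CUDA };

struct TensorView {
    Backend backend;   // where the data currently lives
    void *  data;      // host pointer if CPU, device pointer if CUDA
    size_t  nbytes;
};

// Return a device pointer usable by a kernel: upload into `scratch`
// (a pre-allocated VRAM buffer) only if the tensor is still host-resident.
static const void * device_view(const TensorView & t, void * scratch) {
    if (t.backend == Backend::CPU) {
        cudaMemcpy(scratch, t.data, t.nbytes, cudaMemcpyHostToDevice);
        return scratch;
    }
    return t.data; // already in VRAM: no copy needed
}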
@@ -696,8 +702,7 @@ static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor
             } else {
                 float * c_X = d_X + i * x_ne;
 
-                // copy src0 and convert to fp32 on device
-                CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_Q, src0, i03, i02, cudaStream2));
+                // convert src0 to fp32 on device
                 to_fp32_cuda(c_Q, c_X, x_ne, cudaStream2);
                 CUDA_CHECK(cudaGetLastError());
                 CUDA_CHECK(cudaEventRecord(cudaEvent, cudaStream2));
@@ -742,8 +747,8 @@ bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_te
     // TODO: find the optimal values for these
     if ((src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) &&
         src1->type == GGML_TYPE_F32 &&
-        dst->type == GGML_TYPE_F32) {
+        dst->type == GGML_TYPE_F32 &&
+        ((ne0 >= 32 && ne1 >= 32 && ne10 >= 32) || src0->backend == GGML_BACKEND_CUDA)) {
         return true;
     }
 
@@ -795,3 +800,25 @@ size_t ggml_cuda_mul_mat_get_wsize(const struct ggml_tensor * src0, const struct
         return 0;
     }
 }
+
+void ggml_cuda_transform_tensor(ggml_tensor * tensor) {
+    const int64_t ne0 = tensor->ne[0];
+    const int64_t ne1 = tensor->ne[1];
+    const int64_t ne2 = tensor->ne[2];
+    const int64_t ne3 = tensor->ne[3];
+
+    const ggml_type type = tensor->type;
+    const size_t q_sz = ggml_type_size(type) * ne0 * ne1 * ne2 * ne3 / ggml_blck_size(type);
+
+    size_t q_size;
+    char * d_Q = (char *) ggml_cuda_pool_malloc(q_sz, &q_size);
+
+    cudaStream_t cudaStream2 = g_cudaStreams2[0];
+
+    // copy tensor to device
+    CUDA_CHECK(ggml_cuda_h2d_tensor_2d(d_Q, tensor, 0, 0, cudaStream2));
+    CUDA_CHECK(cudaDeviceSynchronize());
+
+    tensor->data = d_Q;
+    tensor->backend = GGML_BACKEND_CUDA;
+}
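In essence, ggml_cuda_transform_tensor allocates a VRAM buffer from the CUDA buffer pool, copies the tensor's (possibly quantized) data into it, then repoints tensor->data at device memory and tags the tensor with GGML_BACKEND_CUDA. A stripped-down sketch of the same idea, assuming plain cudaMalloc/cudaMemcpy rather than the commit's pooled allocator and 2-D copy helper (names here are illustrative):

#include <cuda_runtime.h>
#include <cstdio>
#include <cstdlib>

// Move a contiguous blob of weights to the GPU and return the device pointer.
// After this call the caller treats the data as device-resident.
static void * upload_to_vram(const void * host_data, size_t nbytes) {
    void * d_ptr = nullptr;
    if (cudaMalloc(&d_ptr, nbytes) != cudaSuccess) {
        fprintf(stderr, "cudaMalloc of %zu bytes failed\n", nbytes);
        exit(1);
    }
    cudaMemcpy(d_ptr, host_data, nbytes, cudaMemcpyHostToDevice);
    cudaDeviceSynchronize(); // ensure the copy finished before the host buffer is reused
    return d_ptr;
}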
ggml-cuda.h

@@ -14,6 +14,8 @@ void ggml_cuda_mul_mat(const struct ggml_tensor * src0, const struct ggml_tens
 void * ggml_cuda_host_malloc(size_t size);
 void   ggml_cuda_host_free(void * ptr);
 
+void ggml_cuda_transform_tensor(struct ggml_tensor * tensor);
+
 #ifdef __cplusplus
 }
 #endif
ggml.c (1 changed line)
@@ -4711,6 +4711,7 @@ struct ggml_tensor * ggml_new_tensor_impl(
 
     *result = (struct ggml_tensor) {
         /*.type    =*/ type,
+        /*.backend =*/ GGML_BACKEND_CPU,
         /*.n_dims  =*/ n_dims,
         /*.ne      =*/ { 1, 1, 1, 1 },
         /*.nb      =*/ { 0, 0, 0, 0 },
ggml.h (8 changed lines)
@@ -243,6 +243,11 @@ extern "C" {
         GGML_TYPE_COUNT,
     };
 
+    enum ggml_backend {
+        GGML_BACKEND_CPU = 0,
+        GGML_BACKEND_CUDA = 1,
+    };
+
     // model file types
     enum ggml_ftype {
         GGML_FTYPE_UNKNOWN = -1,
@@ -323,6 +328,7 @@ extern "C" {
     // n-dimensional tensor
     struct ggml_tensor {
         enum ggml_type type;
+        enum ggml_backend backend;
 
         int n_dims;
         int64_t ne[GGML_MAX_DIMS]; // number of elements
@@ -353,7 +359,7 @@ extern "C" {
 
         char name[32];
 
-        char padding[8]; // TODO: remove and add padding to name?
+        char padding[9]; // TODO: remove and add padding to name?
     };
 
     // computation graph
llama.cpp (22 changed lines)
@@ -9,6 +9,9 @@
 #include "llama.h"
 
 #include "ggml.h"
+#ifdef GGML_USE_CUBLAS
+#include "ggml-cuda.h"
+#endif
 
 #include <array>
 #include <ctime>
@@ -815,6 +818,7 @@ struct llama_context_params llama_context_default_params() {
         /*.vocab_only                  =*/ false,
         /*.use_mmap                    =*/ true,
         /*.use_mlock                   =*/ false,
+        /*.gpu_layers                  =*/ 0,
         /*.embedding                   =*/ false,
         /*.progress_callback           =*/ nullptr,
         /*.progress_callback_user_data =*/ nullptr,
@@ -877,6 +881,7 @@ static void llama_model_load_internal(
         ggml_type memory_type,
         bool use_mmap,
         bool use_mlock,
+        int gpu_layers,
         bool vocab_only,
         llama_progress_callback progress_callback,
         void * progress_callback_user_data) {
@@ -1011,6 +1016,18 @@ static void llama_model_load_internal(
     ml->load_all_data(progress_callback, progress_callback_user_data, use_mlock ? &lctx.model.mlock_mmap : NULL);
 
     model.mapping = std::move(ml->mapping);
+#ifdef GGML_USE_CUBLAS
+    for (int i = 0; i < std::min(gpu_layers, int(hparams.n_layer)); ++i) {
+        auto & layer = model.layers[i];
+        ggml_cuda_transform_tensor(layer.wq);
+        ggml_cuda_transform_tensor(layer.wk);
+        ggml_cuda_transform_tensor(layer.wv);
+        ggml_cuda_transform_tensor(layer.wo);
+        ggml_cuda_transform_tensor(layer.w1);
+        ggml_cuda_transform_tensor(layer.w2);
+        ggml_cuda_transform_tensor(layer.w3);
+    }
+#endif
 
     // loading time will be recalculate after the first eval, so
     // we take page faults deferred by mmap() into consideration
@@ -1024,11 +1041,12 @@ static bool llama_model_load(
         ggml_type memory_type,
         bool use_mmap,
         bool use_mlock,
+        int gpu_layers,
         bool vocab_only,
         llama_progress_callback progress_callback,
         void *progress_callback_user_data) {
     try {
-        llama_model_load_internal(fname, lctx, n_ctx, memory_type, use_mmap, use_mlock,
+        llama_model_load_internal(fname, lctx, n_ctx, memory_type, use_mmap, use_mlock, gpu_layers,
                                   vocab_only, progress_callback, progress_callback_user_data);
         return true;
     } catch (const std::string & err) {
@@ -2088,7 +2106,7 @@ struct llama_context * llama_init_from_file(
     ggml_type memory_type = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32;
 
     if (!llama_model_load(path_model, *ctx, params.n_ctx, memory_type,
-                          params.use_mmap, params.use_mlock, params.vocab_only,
+                          params.use_mmap, params.use_mlock, params.gpu_layers, params.vocab_only,
                           params.progress_callback, params.progress_callback_user_data)) {
         fprintf(stderr, "%s: failed to load model\n", __func__);
         llama_free(ctx);
llama.h (1 changed line)
@@ -63,6 +63,7 @@ extern "C" {
         bool vocab_only; // only load the vocabulary, no weights
         bool use_mmap;   // use mmap if possible
         bool use_mlock;  // force system to keep model in RAM
+        int  gpu_layers; // number of layers to store in VRAM
         bool embedding;  // embedding mode only
 
         // called with a progress value between 0 and 1, pass NULL to disable
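From the API side, the new gpu_layers field is just another knob on llama_context_params. A minimal sketch of setting it from user code, assuming the llama.h API as it stood at this commit (llama_context_default_params / llama_init_from_file / llama_free); the model path and layer count below are placeholder values:

#include "llama.h"
#include <cstdio>

int main() {
    llama_context_params params = llama_context_default_params();
    params.gpu_layers = 20; // keep the first 20 transformer layers resident in VRAM

    llama_context * ctx = llama_init_from_file("models/7B/ggml-model-q4_0.bin", params);
    if (ctx == nullptr) {
        fprintf(stderr, "failed to load model\n");
        return 1;
    }

    // ... tokenize and evaluate as usual; offloaded layers are read from VRAM ...

    llama_free(ctx);
    return 0;
}

On the command line, the same value is exposed through the --gpu_layers flag added in examples/common.cpp above.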