mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-11-13 14:29:52 +00:00
cuBLAS: use host pinned memory and dequantize while copying (#1207)
* cuBLAS: dequantize simultaneously while copying memory * cuBLAS: use host pinned memory * cuBLAS: improve ggml_compute_forward_mul_mat_f16_f32 with pinned memory * cuBLAS: also pin kv cache * fix rebase
This commit is contained in:
parent
b1ee8f59b4
commit
7fc50c051a
5
Makefile
5
Makefile
@ -106,6 +106,7 @@ ifdef LLAMA_OPENBLAS
|
|||||||
endif
|
endif
|
||||||
ifdef LLAMA_CUBLAS
|
ifdef LLAMA_CUBLAS
|
||||||
CFLAGS += -DGGML_USE_CUBLAS -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/x86_64-linux/include
|
CFLAGS += -DGGML_USE_CUBLAS -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/x86_64-linux/include
|
||||||
|
CXXFLAGS += -DGGML_USE_CUBLAS -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/x86_64-linux/include
|
||||||
LDFLAGS += -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64 -L/opt/cuda/lib64 -L$(CUDA_PATH)/targets/x86_64-linux/lib
|
LDFLAGS += -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64 -L/opt/cuda/lib64 -L$(CUDA_PATH)/targets/x86_64-linux/lib
|
||||||
OBJS += ggml-cuda.o
|
OBJS += ggml-cuda.o
|
||||||
NVCC = nvcc
|
NVCC = nvcc
|
||||||
@ -164,10 +165,10 @@ $(info )
|
|||||||
# Build library
|
# Build library
|
||||||
#
|
#
|
||||||
|
|
||||||
ggml.o: ggml.c ggml.h
|
ggml.o: ggml.c ggml.h ggml-cuda.h
|
||||||
$(CC) $(CFLAGS) -c $< -o $@
|
$(CC) $(CFLAGS) -c $< -o $@
|
||||||
|
|
||||||
llama.o: llama.cpp ggml.h llama.h llama_util.h
|
llama.o: llama.cpp ggml.h ggml-cuda.h llama.h llama_util.h
|
||||||
$(CXX) $(CXXFLAGS) -c $< -o $@
|
$(CXX) $(CXXFLAGS) -c $< -o $@
|
||||||
|
|
||||||
common.o: examples/common.cpp examples/common.h
|
common.o: examples/common.cpp examples/common.h
|
||||||
|
45
ggml-cuda.cu
45
ggml-cuda.cu
@ -227,6 +227,25 @@ void dequantize_row_q8_0_cuda(const void * vx, float * y, int k, cudaStream_t st
|
|||||||
dequantize_block_q8_0<<<nb, 1, 0, stream>>>(vx, y);
|
dequantize_block_q8_0<<<nb, 1, 0, stream>>>(vx, y);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
dequantize_row_q_cuda_t ggml_get_dequantize_row_q_cuda(ggml_type type) {
|
||||||
|
switch (type) {
|
||||||
|
case GGML_TYPE_Q4_0:
|
||||||
|
return dequantize_row_q4_0_cuda;
|
||||||
|
case GGML_TYPE_Q4_1:
|
||||||
|
return dequantize_row_q4_1_cuda;
|
||||||
|
case GGML_TYPE_Q4_2:
|
||||||
|
return dequantize_row_q4_2_cuda;
|
||||||
|
case GGML_TYPE_Q5_0:
|
||||||
|
return dequantize_row_q5_0_cuda;
|
||||||
|
case GGML_TYPE_Q5_1:
|
||||||
|
return dequantize_row_q5_1_cuda;
|
||||||
|
case GGML_TYPE_Q8_0:
|
||||||
|
return dequantize_row_q8_0_cuda;
|
||||||
|
default:
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// buffer pool for cuda
|
// buffer pool for cuda
|
||||||
#define MAX_CUDA_BUFFERS 16
|
#define MAX_CUDA_BUFFERS 16
|
||||||
|
|
||||||
@ -286,18 +305,22 @@ void ggml_cuda_pool_free(void * ptr, size_t size) {
|
|||||||
CUDA_CHECK(cudaFree(ptr));
|
CUDA_CHECK(cudaFree(ptr));
|
||||||
}
|
}
|
||||||
|
|
||||||
cublasHandle_t g_cublasH = NULL;
|
cublasHandle_t g_cublasH = nullptr;
|
||||||
cudaStream_t g_cudaStream = NULL;
|
cudaStream_t g_cudaStream = nullptr;
|
||||||
|
cudaStream_t g_cudaStream2 = nullptr;
|
||||||
|
cudaEvent_t g_cudaEvent = nullptr;
|
||||||
|
|
||||||
void ggml_init_cublas(void) {
|
void ggml_init_cublas() {
|
||||||
if (g_cublasH == NULL) {
|
if (g_cublasH == nullptr) {
|
||||||
// create cublas handle, bind a stream
|
// create cublas handle, bind a stream
|
||||||
CUBLAS_CHECK(cublasCreate(&g_cublasH));
|
CUBLAS_CHECK(cublasCreate(&g_cublasH));
|
||||||
|
|
||||||
CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStream, cudaStreamNonBlocking));
|
CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStream, cudaStreamNonBlocking));
|
||||||
|
|
||||||
CUBLAS_CHECK(cublasSetStream(g_cublasH, g_cudaStream));
|
CUBLAS_CHECK(cublasSetStream(g_cublasH, g_cudaStream));
|
||||||
|
|
||||||
|
// create additional stream and event for synchronization
|
||||||
|
CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStream2, cudaStreamNonBlocking));
|
||||||
|
CUDA_CHECK(cudaEventCreateWithFlags(&g_cudaEvent, cudaEventDisableTiming));
|
||||||
|
|
||||||
// configure logging to stdout
|
// configure logging to stdout
|
||||||
// CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, NULL));
|
// CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, NULL));
|
||||||
}
|
}
|
||||||
@ -330,3 +353,13 @@ cudaError_t ggml_cuda_h2d_tensor_2d(void * dst, const struct ggml_tensor * src,
|
|||||||
return cudaSuccess;
|
return cudaSuccess;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void * ggml_cuda_host_malloc(size_t size) {
|
||||||
|
void * ptr;
|
||||||
|
CUDA_CHECK(cudaMallocHost((void **) &ptr, size));
|
||||||
|
return ptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
void ggml_cuda_host_free(void * ptr) {
|
||||||
|
CUDA_CHECK(cudaFreeHost(ptr));
|
||||||
|
}
|
||||||
|
10
ggml-cuda.h
10
ggml-cuda.h
@ -26,9 +26,14 @@ extern "C" {
|
|||||||
} while (0)
|
} while (0)
|
||||||
|
|
||||||
extern cublasHandle_t g_cublasH;
|
extern cublasHandle_t g_cublasH;
|
||||||
extern cudaStream_t g_cudaStream;
|
extern cudaStream_t g_cudaStream;
|
||||||
|
extern cudaStream_t g_cudaStream2;
|
||||||
|
extern cudaEvent_t g_cudaEvent;
|
||||||
|
|
||||||
void ggml_init_cublas(void);
|
void ggml_init_cublas(void);
|
||||||
|
void * ggml_cuda_host_malloc(size_t size);
|
||||||
|
void ggml_cuda_host_free(void * ptr);
|
||||||
|
|
||||||
void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size);
|
void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size);
|
||||||
void ggml_cuda_pool_free(void * ptr, size_t size);
|
void ggml_cuda_pool_free(void * ptr, size_t size);
|
||||||
|
|
||||||
@ -41,6 +46,9 @@ void dequantize_row_q8_0_cuda(const void * vx, float * y, int k, cudaStream_t st
|
|||||||
|
|
||||||
cudaError_t ggml_cuda_h2d_tensor_2d(void * dst, const struct ggml_tensor * src, uint64_t i3, uint64_t i2, cudaStream_t stream);
|
cudaError_t ggml_cuda_h2d_tensor_2d(void * dst, const struct ggml_tensor * src, uint64_t i3, uint64_t i2, cudaStream_t stream);
|
||||||
|
|
||||||
|
typedef void (*dequantize_row_q_cuda_t)(const void * x, float * y, int k, cudaStream_t stream);
|
||||||
|
dequantize_row_q_cuda_t ggml_get_dequantize_row_q_cuda(enum ggml_type type);
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
70
ggml.c
70
ggml.c
@ -8033,7 +8033,7 @@ static void ggml_compute_forward_mul_mat_f32(
|
|||||||
#if defined(GGML_USE_CUBLAS)
|
#if defined(GGML_USE_CUBLAS)
|
||||||
const float alpha = 1.0f;
|
const float alpha = 1.0f;
|
||||||
const float beta = 0.0f;
|
const float beta = 0.0f;
|
||||||
const int x_ne = ne01 * ne10;
|
const int x_ne = ne01 * ne00;
|
||||||
const int y_ne = ne11 * ne10;
|
const int y_ne = ne11 * ne10;
|
||||||
const int d_ne = ne11 * ne01;
|
const int d_ne = ne11 * ne01;
|
||||||
|
|
||||||
@ -8235,25 +8235,27 @@ static void ggml_compute_forward_mul_mat_f16_f32(
|
|||||||
}
|
}
|
||||||
|
|
||||||
#if defined(GGML_USE_CUBLAS)
|
#if defined(GGML_USE_CUBLAS)
|
||||||
ggml_fp16_t * const wdata = params->wdata;
|
|
||||||
|
|
||||||
const float alpha = 1.0f;
|
const float alpha = 1.0f;
|
||||||
const float beta = 0.0f;
|
const float beta = 0.0f;
|
||||||
const int x_ne = ne01 * ne10;
|
const int x_ne = ne01 * ne00;
|
||||||
const int y_ne = ne11 * ne10;
|
const int y_ne = ne11 * ne10;
|
||||||
const int d_ne = ne11 * ne01;
|
const int d_ne = ne11 * ne01;
|
||||||
|
|
||||||
size_t x_size, y_size, d_size;
|
size_t x_size, y_size, d_size;
|
||||||
float *d_X = ggml_cuda_pool_malloc(sizeof(float) * x_ne, &x_size);
|
ggml_fp16_t * d_X = ggml_cuda_pool_malloc(sizeof(float) * x_ne, &x_size);
|
||||||
float *d_Y = ggml_cuda_pool_malloc(sizeof(float) * y_ne, &y_size);
|
ggml_fp16_t * d_Y = ggml_cuda_pool_malloc(sizeof(float) * y_ne, &y_size);
|
||||||
float *d_D = ggml_cuda_pool_malloc(sizeof(float) * d_ne, &d_size);
|
float * d_D = ggml_cuda_pool_malloc(sizeof(float) * d_ne, &d_size);
|
||||||
#else
|
#else
|
||||||
float * const wdata = params->wdata;
|
float * const wdata = params->wdata;
|
||||||
#endif
|
#endif
|
||||||
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
||||||
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
||||||
#if defined(GGML_USE_CUBLAS)
|
#if defined(GGML_USE_CUBLAS)
|
||||||
|
// copy src0 while converting src1
|
||||||
|
CUDA_CHECK(ggml_cuda_h2d_tensor_2d(d_X, src0, i03, i02, g_cudaStream));
|
||||||
|
|
||||||
// with cuBlAS, instead of converting src0 to fp32, we convert src1 to fp16
|
// with cuBlAS, instead of converting src0 to fp32, we convert src1 to fp16
|
||||||
|
ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + (ne11 * ne10) * (i03 * ne02 + i02);
|
||||||
{
|
{
|
||||||
size_t id = 0;
|
size_t id = 0;
|
||||||
for (int64_t i01 = 0; i01 < ne11; ++i01) {
|
for (int64_t i01 = 0; i01 < ne11; ++i01) {
|
||||||
@ -8275,11 +8277,9 @@ static void ggml_compute_forward_mul_mat_f16_f32(
|
|||||||
|
|
||||||
#if defined(GGML_USE_CUBLAS)
|
#if defined(GGML_USE_CUBLAS)
|
||||||
const ggml_fp16_t * y = (ggml_fp16_t *) wdata;
|
const ggml_fp16_t * y = (ggml_fp16_t *) wdata;
|
||||||
|
|
||||||
float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
|
float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
|
||||||
|
|
||||||
// copy data to device
|
// copy data to device
|
||||||
CUDA_CHECK(ggml_cuda_h2d_tensor_2d(d_X, src0, i03, i02, g_cudaStream));
|
|
||||||
CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(ggml_fp16_t) * y_ne, cudaMemcpyHostToDevice, g_cudaStream));
|
CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(ggml_fp16_t) * y_ne, cudaMemcpyHostToDevice, g_cudaStream));
|
||||||
|
|
||||||
// compute
|
// compute
|
||||||
@ -8498,39 +8498,19 @@ static void ggml_compute_forward_mul_mat_q_f32(
|
|||||||
#if defined(GGML_USE_CUBLAS)
|
#if defined(GGML_USE_CUBLAS)
|
||||||
const float alpha = 1.0f;
|
const float alpha = 1.0f;
|
||||||
const float beta = 0.0f;
|
const float beta = 0.0f;
|
||||||
const int x_ne = ne01 * ne10;
|
const int x_ne = ne01 * ne00;
|
||||||
const int y_ne = ne11 * ne10;
|
const int y_ne = ne11 * ne10;
|
||||||
const int d_ne = ne11 * ne01;
|
const int d_ne = ne11 * ne01;
|
||||||
|
|
||||||
size_t x_size, y_size, d_size, q_size;
|
size_t x_size, y_size, d_size, q_size;
|
||||||
float *d_X = ggml_cuda_pool_malloc(sizeof(float) * x_ne, &x_size);
|
float * d_X = ggml_cuda_pool_malloc(sizeof(float) * x_ne, &x_size);
|
||||||
float *d_Y = ggml_cuda_pool_malloc(sizeof(float) * y_ne, &y_size);
|
float * d_Y = ggml_cuda_pool_malloc(sizeof(float) * y_ne, &y_size);
|
||||||
float *d_D = ggml_cuda_pool_malloc(sizeof(float) * d_ne, &d_size);
|
float * d_D = ggml_cuda_pool_malloc(sizeof(float) * d_ne, &d_size);
|
||||||
float *d_Q = ggml_cuda_pool_malloc(GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type], &q_size);
|
void * d_Q = ggml_cuda_pool_malloc(GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type], &q_size);
|
||||||
|
|
||||||
void (*dequantize_row_q_cuda)(const void * x, float * y, int k, cudaStream_t stream) = NULL;
|
const dequantize_row_q_cuda_t dequantize_row_q_cuda = ggml_get_dequantize_row_q_cuda(type);
|
||||||
if (type == GGML_TYPE_Q4_0) {
|
GGML_ASSERT(dequantize_row_q_cuda != NULL);
|
||||||
dequantize_row_q_cuda = dequantize_row_q4_0_cuda;
|
#else
|
||||||
}
|
|
||||||
else if (type == GGML_TYPE_Q4_1) {
|
|
||||||
dequantize_row_q_cuda = dequantize_row_q4_1_cuda;
|
|
||||||
}
|
|
||||||
else if (type == GGML_TYPE_Q4_2) {
|
|
||||||
dequantize_row_q_cuda = dequantize_row_q4_2_cuda;
|
|
||||||
}
|
|
||||||
else if (type == GGML_TYPE_Q5_0) {
|
|
||||||
dequantize_row_q_cuda = dequantize_row_q5_0_cuda;
|
|
||||||
}
|
|
||||||
else if (type == GGML_TYPE_Q5_1) {
|
|
||||||
dequantize_row_q_cuda = dequantize_row_q5_1_cuda;
|
|
||||||
}
|
|
||||||
else if (type == GGML_TYPE_Q8_0) {
|
|
||||||
dequantize_row_q_cuda = dequantize_row_q8_0_cuda;
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
GGML_ASSERT(false);
|
|
||||||
}
|
|
||||||
#elif !defined(GGML_USE_CLBLAST)
|
|
||||||
float * const wdata = params->wdata;
|
float * const wdata = params->wdata;
|
||||||
dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q;
|
dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q;
|
||||||
#endif
|
#endif
|
||||||
@ -8543,10 +8523,11 @@ static void ggml_compute_forward_mul_mat_q_f32(
|
|||||||
|
|
||||||
#if defined(GGML_USE_CUBLAS)
|
#if defined(GGML_USE_CUBLAS)
|
||||||
// copy and dequantize on device
|
// copy and dequantize on device
|
||||||
CUDA_CHECK(ggml_cuda_h2d_tensor_2d(d_Q, src0, i03, i02, g_cudaStream));
|
CUDA_CHECK(ggml_cuda_h2d_tensor_2d(d_Q, src0, i03, i02, g_cudaStream2));
|
||||||
|
|
||||||
dequantize_row_q_cuda(d_Q, d_X, ne01 * ne00, g_cudaStream);
|
dequantize_row_q_cuda(d_Q, d_X, x_ne, g_cudaStream2);
|
||||||
CUDA_CHECK(cudaGetLastError());
|
CUDA_CHECK(cudaGetLastError());
|
||||||
|
CUDA_CHECK(cudaEventRecord(g_cudaEvent, g_cudaStream2));
|
||||||
#elif defined(GGML_USE_CLBLAST)
|
#elif defined(GGML_USE_CLBLAST)
|
||||||
const void* x = (char *) src0->data + i03*nb03 + i02*nb02;
|
const void* x = (char *) src0->data + i03*nb03 + i02*nb02;
|
||||||
#else
|
#else
|
||||||
@ -8560,11 +8541,13 @@ static void ggml_compute_forward_mul_mat_q_f32(
|
|||||||
const float * x = wdata;
|
const float * x = wdata;
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
|
||||||
#if defined(GGML_USE_CUBLAS)
|
#if defined(GGML_USE_CUBLAS)
|
||||||
// copy data to device
|
// copy data to device
|
||||||
CUDA_CHECK(ggml_cuda_h2d_tensor_2d(d_Y, src1, i03, i02, g_cudaStream));
|
CUDA_CHECK(ggml_cuda_h2d_tensor_2d(d_Y, src1, i03, i02, g_cudaStream));
|
||||||
|
|
||||||
|
// wait for dequantization
|
||||||
|
CUDA_CHECK(cudaStreamWaitEvent(g_cudaStream, g_cudaEvent, 0));
|
||||||
|
|
||||||
// compute
|
// compute
|
||||||
CUBLAS_CHECK(
|
CUBLAS_CHECK(
|
||||||
cublasSgemm(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
|
cublasSgemm(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
|
||||||
@ -11588,7 +11571,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
|
|||||||
if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
|
if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
|
||||||
node->n_tasks = 1; // TODO: this actually is doing nothing
|
node->n_tasks = 1; // TODO: this actually is doing nothing
|
||||||
// the threads are still spinning
|
// the threads are still spinning
|
||||||
cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]);
|
cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*MAX(ggml_nelements(node->src1), ggml_nelements(node->src0));
|
||||||
//printf("src0: ne0 = %d, ne1 = %d, ne = %d\n", node->src0->ne[0], node->src0->ne[1], node->src0->ne[0]*node->src0->ne[1]);
|
//printf("src0: ne0 = %d, ne1 = %d, ne = %d\n", node->src0->ne[0], node->src0->ne[1], node->src0->ne[0]*node->src0->ne[1]);
|
||||||
//printf("src1: ne0 = %d, ne1 = %d, ne = %d\n", node->src1->ne[0], node->src1->ne[1], node->src1->ne[0]*node->src1->ne[1]);
|
//printf("src1: ne0 = %d, ne1 = %d, ne = %d\n", node->src1->ne[0], node->src1->ne[1], node->src1->ne[0]*node->src1->ne[1]);
|
||||||
//printf("cur = %zu\n", cur);
|
//printf("cur = %zu\n", cur);
|
||||||
@ -11600,6 +11583,11 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
|
|||||||
#endif
|
#endif
|
||||||
} else if (node->src0->type == GGML_TYPE_F32 && node->src1->type == GGML_TYPE_F32) {
|
} else if (node->src0->type == GGML_TYPE_F32 && node->src1->type == GGML_TYPE_F32) {
|
||||||
cur = 0;
|
cur = 0;
|
||||||
|
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
|
||||||
|
if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
|
||||||
|
node->n_tasks = 1;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
} else if (ggml_is_quantized(node->src0->type) && node->src1->type == GGML_TYPE_F32) {
|
} else if (ggml_is_quantized(node->src0->type) && node->src1->type == GGML_TYPE_F32) {
|
||||||
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
|
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
|
||||||
if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
|
if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
|
||||||
|
@ -136,7 +136,7 @@ struct llama_kv_cache {
|
|||||||
|
|
||||||
struct ggml_context * ctx = NULL;
|
struct ggml_context * ctx = NULL;
|
||||||
|
|
||||||
llama_buffer buf;
|
llama_ctx_buffer buf;
|
||||||
|
|
||||||
int n; // number of tokens currently in the cache
|
int n; // number of tokens currently in the cache
|
||||||
|
|
||||||
@ -167,7 +167,7 @@ struct llama_model {
|
|||||||
struct llama_kv_cache kv_self;
|
struct llama_kv_cache kv_self;
|
||||||
|
|
||||||
// the model memory buffer
|
// the model memory buffer
|
||||||
llama_buffer buf;
|
llama_ctx_buffer buf;
|
||||||
|
|
||||||
// model memory mapped file
|
// model memory mapped file
|
||||||
std::unique_ptr<llama_mmap> mapping;
|
std::unique_ptr<llama_mmap> mapping;
|
||||||
@ -228,8 +228,8 @@ struct llama_context {
|
|||||||
|
|
||||||
// memory buffers used to evaluate the model
|
// memory buffers used to evaluate the model
|
||||||
// TODO: move in llama_state
|
// TODO: move in llama_state
|
||||||
llama_buffer buf_compute;
|
llama_ctx_buffer buf_compute;
|
||||||
llama_buffer buf_scratch[LLAMA_MAX_SCRATCH_BUFFERS];
|
llama_ctx_buffer buf_scratch[LLAMA_MAX_SCRATCH_BUFFERS];
|
||||||
|
|
||||||
int buf_last = 0;
|
int buf_last = 0;
|
||||||
size_t buf_max_size[LLAMA_MAX_SCRATCH_BUFFERS] = { 0 };
|
size_t buf_max_size[LLAMA_MAX_SCRATCH_BUFFERS] = { 0 };
|
||||||
|
26
llama_util.h
26
llama_util.h
@ -405,4 +405,30 @@ struct llama_buffer {
|
|||||||
delete[] addr;
|
delete[] addr;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
#ifdef GGML_USE_CUBLAS
|
||||||
|
#include "ggml-cuda.h"
|
||||||
|
struct llama_ctx_buffer {
|
||||||
|
uint8_t * addr = NULL;
|
||||||
|
size_t size = 0;
|
||||||
|
|
||||||
|
void resize(size_t size) {
|
||||||
|
if (addr) {
|
||||||
|
ggml_cuda_host_free(addr);
|
||||||
|
}
|
||||||
|
addr = (uint8_t *) ggml_cuda_host_malloc(size);
|
||||||
|
this->size = size;
|
||||||
|
}
|
||||||
|
|
||||||
|
~llama_ctx_buffer() {
|
||||||
|
if (addr) {
|
||||||
|
ggml_cuda_host_free(addr);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
#else
|
||||||
|
typedef llama_buffer llama_ctx_buffer;
|
||||||
|
#endif
|
||||||
|
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
Loading…
Reference in New Issue
Block a user