Mirror of https://github.com/ggerganov/llama.cpp.git, synced 2024-11-13 14:29:52 +00:00
cuda : supports running on CPU for GGML_USE_CUBLAS=ON build (#3946)

* prototype the idea of supporting running on the CPU for a GGML_USE_CUBLAS=on build
* doc: add comments to ggml_cublas_loaded()
* fix defined(...)
This commit is contained in:
parent 381efbf480
commit 46876d2a2c
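The change lets a binary built with GGML_USE_CUBLAS run on machines without a usable CUDA device: ggml_init_cublas() no longer assumes device enumeration succeeds, and callers ask the new ggml_cublas_loaded() before taking any GPU path. A minimal usage sketch (the main() scaffolding is illustrative only, not part of the commit):

```cpp
#include <cstdio>
#include "ggml-cuda.h"

int main() {
    // Always "succeeds": it only records internally whether CUDA devices and cuBLAS were found.
    ggml_init_cublas();

    if (ggml_cublas_loaded()) {
        std::printf("CUDA available: GPU offloading can be used\n");
    } else {
        std::printf("no usable CUDA device: falling back to CPU\n");
    }
    return 0;
}
```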
ggml-cuda.cu (17 changed lines)
@@ -5790,6 +5790,11 @@ static void ggml_cuda_pool_free(void * ptr, size_t size) {
     CUDA_CHECK(cudaFree(ptr));
 }
 
+static bool g_cublas_loaded = false;
+
+bool ggml_cublas_loaded(void) {
+    return g_cublas_loaded;
+}
 
 void ggml_init_cublas() {
     static bool initialized = false;
@@ -5803,7 +5808,12 @@ void ggml_init_cublas() {
         CUDA_CHECK(cudaDeviceSynchronize());
 #endif
 
-        CUDA_CHECK(cudaGetDeviceCount(&g_device_count));
+        if (cudaGetDeviceCount(&g_device_count) != cudaSuccess) {
+            initialized = true;
+            g_cublas_loaded = false;
+            return;
+        }
+
         GGML_ASSERT(g_device_count <= GGML_CUDA_MAX_DEVICES);
         int64_t total_vram = 0;
 #if defined(GGML_CUDA_FORCE_MMQ)
@@ -5851,6 +5861,7 @@ void ggml_init_cublas() {
         // CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, nullptr));
 
         initialized = true;
+        g_cublas_loaded = true;
     }
 }
 
@@ -7158,6 +7169,8 @@ static void ggml_cuda_rms_norm(const ggml_tensor * src0, const ggml_tensor * src
 }
 
 bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
+    if (!g_cublas_loaded) return false;
+
     const int64_t ne10 = src1->ne[0];
 
     const int64_t ne0 = dst->ne[0];
@@ -7843,6 +7856,8 @@ void ggml_cuda_free_scratch() {
 }
 
 bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
+    if (!g_cublas_loaded) return false;
+
     ggml_cuda_func_t func;
     const bool any_on_device = tensor->backend == GGML_BACKEND_GPU
         || (tensor->src[0] != nullptr && (tensor->src[0]->backend == GGML_BACKEND_GPU || tensor->src[0]->backend == GGML_BACKEND_GPU_SPLIT))
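With these guards, a CUDA-enabled build behaves like a CPU build on machines where cuBLAS never loaded: the CUDA entry points simply decline the work and ggml runs its CPU kernels instead. A hedged sketch of that dispatch idea (compute_node and compute_node_on_cpu are illustrative names for this example, not part of ggml):

```cpp
#include "ggml.h"
#ifdef GGML_USE_CUBLAS
#include "ggml-cuda.h"
#endif

// Hypothetical CPU kernel, present only to keep the sketch self-contained.
static void compute_node_on_cpu(struct ggml_compute_params * params, struct ggml_tensor * node) {
    (void) params; (void) node; // real code would run the regular ggml CPU implementation here
}

// Offer the node to the CUDA backend first; if it declines (for example because
// cuBLAS never loaded on this machine), fall through to the CPU path.
static void compute_node(struct ggml_compute_params * params, struct ggml_tensor * node) {
#ifdef GGML_USE_CUBLAS
    if (ggml_cuda_compute_forward(params, node)) {
        return; // handled by the CUDA backend
    }
#endif
    compute_node_on_cpu(params, node);
}
```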
ggml-cuda.h
@@ -17,7 +17,12 @@ extern "C" {
 
 #define GGML_CUDA_MAX_DEVICES       16
 
+// Always success. To check if CUDA is actually loaded, use `ggml_cublas_loaded`.
 GGML_API void   ggml_init_cublas(void);
+
+// Returns `true` if there are available CUDA devices and cublas loads successfully; otherwise, it returns `false`.
+GGML_API bool   ggml_cublas_loaded(void);
+
 GGML_API void * ggml_cuda_host_malloc(size_t size);
 GGML_API void   ggml_cuda_host_free(void * ptr);
 
llama.cpp (181 changed lines)
@@ -596,19 +596,37 @@ static void ggml_graph_compute_helper(std::vector<uint8_t> & buf, ggml_cgraph *
 // llama helpers
 //
 
+inline void * llama_host_malloc(size_t n) {
 #ifdef GGML_USE_CUBLAS
-#   define llama_host_malloc(n)  ggml_cuda_host_malloc(n)
-#   define llama_host_free(data) ggml_cuda_host_free(data)
+    if (ggml_cublas_loaded()) {
+        return ggml_cuda_host_malloc(n);
+    } else {
+        return malloc(n);
+    }
 #elif GGML_USE_METAL
-#   define llama_host_malloc(n)  ggml_metal_host_malloc(n)
-#   define llama_host_free(data) ggml_metal_host_free(data)
+    return ggml_metal_host_malloc(n);
 #elif GGML_USE_CPU_HBM
-#   define llama_host_malloc(n)  hbw_malloc(n)
-#   define llama_host_free(data) if (data != NULL) hbw_free(data)
+    return hbw_malloc(n);
 #else
-#   define llama_host_malloc(n)  malloc(n)
-#   define llama_host_free(data) free(data)
+    return malloc(n);
 #endif
+}
+
+inline void llama_host_free(void * ptr) {
+#ifdef GGML_USE_CUBLAS
+    if (ggml_cublas_loaded()) {
+        return ggml_cuda_host_free(ptr);
+    } else {
+        return free(ptr);
+    }
+#elif GGML_USE_METAL
+    return ggml_metal_host_free(ptr);
+#elif GGML_USE_CPU_HBM
+    return hbw_free(ptr);
+#else
+    return free(ptr);
+#endif
+}
 
 #if defined(_WIN32)
 static std::string llama_format_win_err(DWORD err) {
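llama_host_malloc and llama_host_free turn from preprocessor macros into inline functions because the choice of allocator is now a runtime decision: pinned CUDA host memory is used only when ggml_cublas_loaded() reports success, and plain malloc/free otherwise, so call sites keep working unchanged. A small illustrative sketch of a buffer built on the two helpers (host_buffer is a made-up name for this example, loosely modeled on llama.cpp's internal buffer type):

```cpp
#include <cstddef>

// Illustrative only: allocate with llama_host_malloc and release with llama_host_free,
// without caring whether the memory is pinned CUDA host memory or ordinary heap memory.
struct host_buffer {
    void * data = nullptr;
    size_t size = 0;

    void resize(size_t n) {
        if (data) {
            llama_host_free(data);
        }
        data = llama_host_malloc(n);
        size = (data != nullptr) ? n : 0;
    }

    ~host_buffer() {
        if (data) {
            llama_host_free(data);
        }
    }
};
```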
@@ -1200,9 +1218,11 @@ struct llama_kv_cache {
         }
 
 #ifdef GGML_USE_CUBLAS
-        ggml_cuda_free_data(k);
-        ggml_cuda_free_data(v);
-#endif // GGML_USE_CUBLAS
+        if (ggml_cublas_loaded()) {
+            ggml_cuda_free_data(k);
+            ggml_cuda_free_data(v);
+        }
+#endif
     }
 };
 
@@ -1302,11 +1322,15 @@ struct llama_model {
         }
 
 #ifdef GGML_USE_CUBLAS
-        for (size_t i = 0; i < tensors_by_name.size(); ++i) {
-            ggml_cuda_free_data(tensors_by_name[i].second);
+        if (ggml_cublas_loaded()) {
+            for (size_t i = 0; i < tensors_by_name.size(); ++i) {
+                ggml_cuda_free_data(tensors_by_name[i].second);
+            }
+            ggml_cuda_free_scratch();
         }
-        ggml_cuda_free_scratch();
-#elif defined(GGML_USE_CLBLAST)
+#endif
+
+#if defined(GGML_USE_CLBLAST)
         for (size_t i = 0; i < tensors_by_name.size(); ++i) {
             ggml_cl_free_data(tensors_by_name[i].second);
         }
@@ -1418,23 +1442,26 @@ static bool llama_kv_cache_init(
     ggml_set_name(cache.v, "cache_v");
 
     (void) n_gpu_layers;
+
 #ifdef GGML_USE_CUBLAS
-    size_t vram_kv_cache = 0;
+    if (ggml_cublas_loaded()) {
+        size_t vram_kv_cache = 0;
 
-    if (n_gpu_layers > (int)n_layer + 1) {
-        ggml_cuda_assign_buffers_no_scratch(cache.v);
-        LLAMA_LOG_INFO("%s: offloading v cache to GPU\n", __func__);
-        vram_kv_cache += ggml_nbytes(cache.v);
-    }
-    if (n_gpu_layers > (int)n_layer + 2) {
-        ggml_cuda_assign_buffers_no_scratch(cache.k);
-        LLAMA_LOG_INFO("%s: offloading k cache to GPU\n", __func__);
-        vram_kv_cache += ggml_nbytes(cache.k);
-    }
-    if (vram_kv_cache > 0) {
-        LLAMA_LOG_INFO("%s: VRAM kv self = %.2f MB\n", __func__, vram_kv_cache / 1024.0 / 1024.0);
-    }
-#endif // GGML_USE_CUBLAS
+        if (n_gpu_layers > (int)n_layer + 1) {
+            ggml_cuda_assign_buffers_no_scratch(cache.v);
+            LLAMA_LOG_INFO("%s: offloading v cache to GPU\n", __func__);
+            vram_kv_cache += ggml_nbytes(cache.v);
+        }
+        if (n_gpu_layers > (int)n_layer + 2) {
+            ggml_cuda_assign_buffers_no_scratch(cache.k);
+            LLAMA_LOG_INFO("%s: offloading k cache to GPU\n", __func__);
+            vram_kv_cache += ggml_nbytes(cache.k);
+        }
+        if (vram_kv_cache > 0) {
+            LLAMA_LOG_INFO("%s: VRAM kv self = %.2f MB\n", __func__, vram_kv_cache / 1024.0 / 1024.0);
+        }
+    }
+#endif
 
     return true;
 }
@@ -2521,18 +2548,22 @@ static void llm_load_tensors(
     }
 
     (void) main_gpu;
+
+    enum ggml_backend_type llama_backend_offload = GGML_BACKEND_CPU;
+    enum ggml_backend_type llama_backend_offload_split = GGML_BACKEND_CPU;
+
 #ifdef GGML_USE_CUBLAS
-    LLAMA_LOG_INFO("%s: using " GGML_CUDA_NAME " for GPU acceleration\n", __func__);
-    ggml_cuda_set_main_device(main_gpu);
-#define LLAMA_BACKEND_OFFLOAD       GGML_BACKEND_GPU
-#define LLAMA_BACKEND_OFFLOAD_SPLIT GGML_BACKEND_GPU_SPLIT
+    if (ggml_cublas_loaded()) {
+        LLAMA_LOG_INFO("%s: using " GGML_CUDA_NAME " for GPU acceleration\n", __func__);
+        ggml_cuda_set_main_device(main_gpu);
+
+        llama_backend_offload = GGML_BACKEND_GPU;
+        llama_backend_offload_split = GGML_BACKEND_GPU_SPLIT;
+    }
 #elif defined(GGML_USE_CLBLAST)
     LLAMA_LOG_INFO("%s: using OpenCL for GPU acceleration\n", __func__);
-#define LLAMA_BACKEND_OFFLOAD       GGML_BACKEND_GPU
-#define LLAMA_BACKEND_OFFLOAD_SPLIT GGML_BACKEND_GPU
-#else
-#define LLAMA_BACKEND_OFFLOAD       GGML_BACKEND_CPU
-#define LLAMA_BACKEND_OFFLOAD_SPLIT GGML_BACKEND_CPU
+    llama_backend_offload = GGML_BACKEND_GPU;
+    llama_backend_offload_split = GGML_BACKEND_GPU;
 #endif
 
     // prepare memory for the weights
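The compile-time LLAMA_BACKEND_OFFLOAD and LLAMA_BACKEND_OFFLOAD_SPLIT macros become the runtime variables llama_backend_offload and llama_backend_offload_split, which default to GGML_BACKEND_CPU and are promoted to GPU backends only when ggml_cublas_loaded() reports success; the remaining llm_load_tensors hunks below are the mechanical substitution of the old macro names for the new variables. A condensed sketch of the selection logic, using the identifiers from the diff (pick_offload_target is a hypothetical helper added just for illustration; llm_load_tensors inlines this logic directly):

```cpp
#include "ggml.h"
#ifdef GGML_USE_CUBLAS
#include "ggml-cuda.h"
#endif

static enum ggml_backend_type pick_offload_target(void) {
    // Default: a CPU-only build, or a CUDA build running on a machine without a usable GPU.
    enum ggml_backend_type target = GGML_BACKEND_CPU;
#ifdef GGML_USE_CUBLAS
    if (ggml_cublas_loaded()) {
        target = GGML_BACKEND_GPU; // promote only when CUDA devices and cuBLAS are actually available
    }
#endif
    return target;
}
```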
@@ -2559,12 +2590,12 @@ static void llm_load_tensors(
                         // norm is not performance relevant on its own but keeping it in VRAM reduces data copying
                         // on Windows however this is detrimental unless everything is on the GPU
 #ifndef _WIN32
-                        backend_norm = LLAMA_BACKEND_OFFLOAD;
+                        backend_norm = llama_backend_offload;
 #else
-                        backend_norm = n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD;
+                        backend_norm = n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : llama_backend_offload;
 #endif // _WIN32
 
-                        backend_output = LLAMA_BACKEND_OFFLOAD_SPLIT;
+                        backend_output = llama_backend_offload_split;
                     } else {
                         backend_norm = GGML_BACKEND_CPU;
                         backend_output = GGML_BACKEND_CPU;
@@ -2588,8 +2619,8 @@ static void llm_load_tensors(
                 model.layers.resize(n_layer);
 
                 for (uint32_t i = 0; i < n_layer; ++i) {
-                    const ggml_backend_type backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; // NOLINT
-                    const ggml_backend_type backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD_SPLIT; // NOLINT
+                    const ggml_backend_type backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : llama_backend_offload; // NOLINT
+                    const ggml_backend_type backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : llama_backend_offload_split; // NOLINT
 
                     auto & layer = model.layers[i];
 
@@ -2625,12 +2656,12 @@ static void llm_load_tensors(
                         // norm is not performance relevant on its own but keeping it in VRAM reduces data copying
                         // on Windows however this is detrimental unless everything is on the GPU
 #ifndef _WIN32
-                        backend_norm = LLAMA_BACKEND_OFFLOAD;
+                        backend_norm = llama_backend_offload;
 #else
-                        backend_norm = n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD;
+                        backend_norm = n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : llama_backend_offload;
 #endif // _WIN32
 
-                        backend_output = LLAMA_BACKEND_OFFLOAD_SPLIT;
+                        backend_output = llama_backend_offload_split;
                     } else {
                         backend_norm = GGML_BACKEND_CPU;
                         backend_output = GGML_BACKEND_CPU;
@@ -2654,8 +2685,8 @@ static void llm_load_tensors(
                 model.layers.resize(n_layer);
 
                 for (uint32_t i = 0; i < n_layer; ++i) {
-                    const ggml_backend_type backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; // NOLINT
-                    const ggml_backend_type backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD_SPLIT; // NOLINT
+                    const ggml_backend_type backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : llama_backend_offload; // NOLINT
+                    const ggml_backend_type backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : llama_backend_offload_split; // NOLINT
 
                     auto & layer = model.layers[i];
 
@@ -2695,12 +2726,12 @@ static void llm_load_tensors(
                         // norm is not performance relevant on its own but keeping it in VRAM reduces data copying
                         // on Windows however this is detrimental unless everything is on the GPU
 #ifndef _WIN32
-                        backend_norm = LLAMA_BACKEND_OFFLOAD;
+                        backend_norm = llama_backend_offload;
 #else
-                        backend_norm = n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD;
+                        backend_norm = n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : llama_backend_offload;
 #endif // _WIN32
 
-                        backend_output = LLAMA_BACKEND_OFFLOAD_SPLIT;
+                        backend_output = llama_backend_offload_split;
                     } else {
                         backend_norm = GGML_BACKEND_CPU;
                         backend_output = GGML_BACKEND_CPU;
@@ -2726,8 +2757,8 @@ static void llm_load_tensors(
                 model.layers.resize(n_layer);
 
                 for (uint32_t i = 0; i < n_layer; ++i) {
-                    const ggml_backend_type backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; // NOLINT
-                    const ggml_backend_type backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD_SPLIT; // NOLINT
+                    const ggml_backend_type backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : llama_backend_offload; // NOLINT
+                    const ggml_backend_type backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : llama_backend_offload_split; // NOLINT
 
                     auto & layer = model.layers[i];
 
@@ -2772,12 +2803,12 @@ static void llm_load_tensors(
                         // norm is not performance relevant on its own but keeping it in VRAM reduces data copying
                         // on Windows however this is detrimental unless everything is on the GPU
 #ifndef _WIN32
-                        backend_norm = LLAMA_BACKEND_OFFLOAD;
+                        backend_norm = llama_backend_offload;
 #else
-                        backend_norm = n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD;
+                        backend_norm = n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : llama_backend_offload;
 #endif // _WIN32
 
-                        backend_output = LLAMA_BACKEND_OFFLOAD_SPLIT;
+                        backend_output = llama_backend_offload_split;
                     } else {
                         backend_norm = GGML_BACKEND_CPU;
                         backend_output = GGML_BACKEND_CPU;
@@ -2803,8 +2834,8 @@ static void llm_load_tensors(
                 model.layers.resize(n_layer);
 
                 for (uint32_t i = 0; i < n_layer; ++i) {
-                    const ggml_backend_type backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; // NOLINT
-                    const ggml_backend_type backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD_SPLIT; // NOLINT
+                    const ggml_backend_type backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : llama_backend_offload; // NOLINT
+                    const ggml_backend_type backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : llama_backend_offload_split; // NOLINT
 
                     auto & layer = model.layers[i];
 
@@ -2849,12 +2880,12 @@ static void llm_load_tensors(
                         // norm is not performance relevant on its own but keeping it in VRAM reduces data copying
                         // on Windows however this is detrimental unless everything is on the GPU
 #ifndef _WIN32
-                        backend_norm = LLAMA_BACKEND_OFFLOAD;
+                        backend_norm = llama_backend_offload;
 #else
-                        backend_norm = n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD;
+                        backend_norm = n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : llama_backend_offload;
 #endif // _WIN32
 
-                        backend_output = LLAMA_BACKEND_OFFLOAD_SPLIT;
+                        backend_output = llama_backend_offload_split;
                     } else {
                         backend_norm = GGML_BACKEND_CPU;
                         backend_output = GGML_BACKEND_CPU;
@@ -2877,8 +2908,8 @@ static void llm_load_tensors(
                 const int i_gpu_start = n_layer - n_gpu_layers;
                 model.layers.resize(n_layer);
                 for (uint32_t i = 0; i < n_layer; ++i) {
-                    const ggml_backend_type backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD;
-                    const ggml_backend_type backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD_SPLIT;
+                    const ggml_backend_type backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : llama_backend_offload;
+                    const ggml_backend_type backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : llama_backend_offload_split;
                     auto & layer = model.layers[i];
                     layer.attn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, backend);
                     layer.attn_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, backend);
@@ -2915,12 +2946,12 @@ static void llm_load_tensors(
                         // norm is not performance relevant on its own but keeping it in VRAM reduces data copying
                         // on Windows however this is detrimental unless everything is on the GPU
 #ifndef _WIN32
-                        backend_norm = LLAMA_BACKEND_OFFLOAD;
+                        backend_norm = llama_backend_offload;
 #else
-                        backend_norm = n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD;
+                        backend_norm = n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : llama_backend_offload;
 #endif // _WIN32
 
-                        backend_output = LLAMA_BACKEND_OFFLOAD_SPLIT;
+                        backend_output = llama_backend_offload_split;
                     } else {
                         backend_norm = GGML_BACKEND_CPU;
                         backend_output = GGML_BACKEND_CPU;
@@ -2946,8 +2977,8 @@ static void llm_load_tensors(
                 model.layers.resize(n_layer);
 
                 for (uint32_t i = 0; i < n_layer; ++i) {
-                    const ggml_backend_type backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; // NOLINT
-                    const ggml_backend_type backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD_SPLIT; // NOLINT
+                    const ggml_backend_type backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : llama_backend_offload; // NOLINT
+                    const ggml_backend_type backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : llama_backend_offload_split; // NOLINT
 
                     auto & layer = model.layers[i];
 
@@ -2993,12 +3024,12 @@ static void llm_load_tensors(
                         // norm is not performance relevant on its own but keeping it in VRAM reduces data copying
                         // on Windows however this is detrimental unless everything is on the GPU
 #ifndef _WIN32
-                        backend_norm = LLAMA_BACKEND_OFFLOAD;
+                        backend_norm = llama_backend_offload;
 #else
-                        backend_norm = n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD;
+                        backend_norm = n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : llama_backend_offload;
 #endif // _WIN32
 
-                        backend_output = LLAMA_BACKEND_OFFLOAD_SPLIT;
+                        backend_output = llama_backend_offload_split;
                     } else {
                         backend_norm = GGML_BACKEND_CPU;
                         backend_output = GGML_BACKEND_CPU;
@@ -3022,8 +3053,8 @@ static void llm_load_tensors(
                 model.layers.resize(n_layer);
 
                 for (uint32_t i = 0; i < n_layer; ++i) {
-                    const ggml_backend_type backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; // NOLINT
-                    const ggml_backend_type backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD_SPLIT; // NOLINT
+                    const ggml_backend_type backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : llama_backend_offload; // NOLINT
+                    const ggml_backend_type backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : llama_backend_offload_split; // NOLINT
 
                     auto & layer = model.layers[i];
 