mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-11-11 21:39:52 +00:00
Add ReLU and SQR CUDA ops to (partially) fix Persimmon offloading (#4041)
* Add ReLU and SQR CUDA ops to fix Persimmon offloading * Persimmon loader: More helpful error on CUDA/ROCM when offloading too many layers
This commit is contained in:
parent
21fd874c8d
commit
bb50a792ec
72
ggml-cuda.cu
72
ggml-cuda.cu
@ -433,6 +433,8 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_
|
||||
#define CUDA_MUL_BLOCK_SIZE 256
|
||||
#define CUDA_GELU_BLOCK_SIZE 256
|
||||
#define CUDA_SILU_BLOCK_SIZE 256
|
||||
#define CUDA_RELU_BLOCK_SIZE 256
|
||||
#define CUDA_SQR_BLOCK_SIZE 256
|
||||
#define CUDA_CPY_BLOCK_SIZE 32
|
||||
#define CUDA_SCALE_BLOCK_SIZE 256
|
||||
#define CUDA_CLAMP_BLOCK_SIZE 256
|
||||
@ -553,6 +555,24 @@ static __global__ void silu_f32(const float * x, float * dst, const int k) {
|
||||
dst[i] = x[i] / (1.0f + expf(-x[i]));
|
||||
}
|
||||
|
||||
static __global__ void relu_f32(const float * x, float * dst, const int k) {
|
||||
const int i = blockDim.x*blockIdx.x + threadIdx.x;
|
||||
|
||||
if (i >= k) {
|
||||
return;
|
||||
}
|
||||
dst[i] = fmaxf(x[i], 0);
|
||||
}
|
||||
|
||||
static __global__ void sqr_f32(const float * x, float * dst, const int k) {
|
||||
const int i = blockDim.x*blockIdx.x + threadIdx.x;
|
||||
|
||||
if (i >= k) {
|
||||
return;
|
||||
}
|
||||
dst[i] = x[i] * x[i];
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
|
||||
#pragma unroll
|
||||
for (int mask = 16; mask > 0; mask >>= 1) {
|
||||
@ -4759,6 +4779,16 @@ static void silu_f32_cuda(const float * x, float * dst, const int k, cudaStream_
|
||||
silu_f32<<<num_blocks, CUDA_SILU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
|
||||
}
|
||||
|
||||
static void relu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
|
||||
const int num_blocks = (k + CUDA_RELU_BLOCK_SIZE - 1) / CUDA_RELU_BLOCK_SIZE;
|
||||
relu_f32<<<num_blocks, CUDA_RELU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
|
||||
}
|
||||
|
||||
static void sqr_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
|
||||
const int num_blocks = (k + CUDA_SQR_BLOCK_SIZE - 1) / CUDA_SQR_BLOCK_SIZE;
|
||||
sqr_f32<<<num_blocks, CUDA_SQR_BLOCK_SIZE, 0, stream>>>(x, dst, k);
|
||||
}
|
||||
|
||||
static void norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
GGML_ASSERT(ncols % WARP_SIZE == 0);
|
||||
if (ncols < 1024) {
|
||||
@ -6128,6 +6158,34 @@ inline void ggml_cuda_op_silu(
|
||||
(void) src1_dd;
|
||||
}
|
||||
|
||||
inline void ggml_cuda_op_relu(
|
||||
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
||||
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||
|
||||
relu_f32_cuda(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
|
||||
|
||||
(void) src1;
|
||||
(void) dst;
|
||||
(void) src1_dd;
|
||||
}
|
||||
|
||||
inline void ggml_cuda_op_sqr(
|
||||
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
||||
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||
|
||||
sqr_f32_cuda(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
|
||||
|
||||
(void) src1;
|
||||
(void) dst;
|
||||
(void) src1_dd;
|
||||
}
|
||||
|
||||
inline void ggml_cuda_op_norm(
|
||||
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
||||
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
|
||||
@ -7160,6 +7218,14 @@ static void ggml_cuda_silu(const ggml_tensor * src0, const ggml_tensor * src1, g
|
||||
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_silu);
|
||||
}
|
||||
|
||||
static void ggml_cuda_relu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_relu);
|
||||
}
|
||||
|
||||
static void ggml_cuda_sqr(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_sqr);
|
||||
}
|
||||
|
||||
static void ggml_cuda_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_norm);
|
||||
}
|
||||
@ -7891,6 +7957,9 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
|
||||
case GGML_UNARY_OP_SILU:
|
||||
func = ggml_cuda_silu;
|
||||
break;
|
||||
case GGML_UNARY_OP_RELU:
|
||||
func = ggml_cuda_relu;
|
||||
break;
|
||||
default:
|
||||
return false;
|
||||
} break;
|
||||
@ -7909,6 +7978,9 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
|
||||
case GGML_OP_SCALE:
|
||||
func = ggml_cuda_scale;
|
||||
break;
|
||||
case GGML_OP_SQR:
|
||||
func = ggml_cuda_sqr;
|
||||
break;
|
||||
case GGML_OP_CLAMP:
|
||||
if (!any_on_device) {
|
||||
return false;
|
||||
|
@ -2877,6 +2877,13 @@ static void llm_load_tensors(
|
||||
ggml_backend_type backend_output;
|
||||
|
||||
if (n_gpu_layers > int(n_layer)) {
|
||||
#ifdef GGML_USE_CUBLAS
|
||||
if (n_gpu_layers > int(n_layer + 1)) {
|
||||
LLAMA_LOG_ERROR("%s: CUDA backend missing Persimmon CUDA ops, can offload at most %ld layers. See: https://github.com/ggerganov/llama.cpp/issues/4038\n",
|
||||
__func__, n_layer + 1);
|
||||
throw std::runtime_error("Persimmon CUDA offload failed");
|
||||
}
|
||||
#endif
|
||||
// norm is not performance relevant on its own but keeping it in VRAM reduces data copying
|
||||
// on Windows however this is detrimental unless everything is on the GPU
|
||||
#ifndef _WIN32
|
||||
|
Loading…
Reference in New Issue
Block a user