mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-11-11 13:30:35 +00:00
RWKV v6: RWKV_WKV op CUDA implementation (#9454)
* ggml: CUDA unary op EXP Signed-off-by: Molly Sophia <mollysophia379@gmail.com> * ggml: rwkv_wkv op CUDA impl Signed-off-by: Molly Sophia <mollysophia379@gmail.com> --------- Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
This commit is contained in:
parent
d09770cae7
commit
2a63caaa69
@ -34,6 +34,7 @@
|
|||||||
#include "ggml-cuda/tsembd.cuh"
|
#include "ggml-cuda/tsembd.cuh"
|
||||||
#include "ggml-cuda/unary.cuh"
|
#include "ggml-cuda/unary.cuh"
|
||||||
#include "ggml-cuda/upscale.cuh"
|
#include "ggml-cuda/upscale.cuh"
|
||||||
|
#include "ggml-cuda/rwkv-wkv.cuh"
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <array>
|
#include <array>
|
||||||
@ -2243,6 +2244,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
|||||||
case GGML_UNARY_OP_HARDSWISH:
|
case GGML_UNARY_OP_HARDSWISH:
|
||||||
ggml_cuda_op_hardswish(ctx, dst);
|
ggml_cuda_op_hardswish(ctx, dst);
|
||||||
break;
|
break;
|
||||||
|
case GGML_UNARY_OP_EXP:
|
||||||
|
ggml_cuda_op_exp(ctx, dst);
|
||||||
|
break;
|
||||||
default:
|
default:
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
@ -2345,6 +2349,8 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
|||||||
case GGML_OP_CROSS_ENTROPY_LOSS:
|
case GGML_OP_CROSS_ENTROPY_LOSS:
|
||||||
ggml_cuda_cross_entropy_loss(ctx, dst);
|
ggml_cuda_cross_entropy_loss(ctx, dst);
|
||||||
break;
|
break;
|
||||||
|
case GGML_OP_RWKV_WKV:
|
||||||
|
ggml_cuda_op_rwkv_wkv(ctx, dst);
|
||||||
case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
|
case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
|
||||||
ggml_cuda_cross_entropy_loss_back(ctx, dst);
|
ggml_cuda_cross_entropy_loss_back(ctx, dst);
|
||||||
break;
|
break;
|
||||||
@ -2806,6 +2812,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
|
|||||||
case GGML_UNARY_OP_HARDSWISH:
|
case GGML_UNARY_OP_HARDSWISH:
|
||||||
case GGML_UNARY_OP_GELU_QUICK:
|
case GGML_UNARY_OP_GELU_QUICK:
|
||||||
case GGML_UNARY_OP_TANH:
|
case GGML_UNARY_OP_TANH:
|
||||||
|
case GGML_UNARY_OP_EXP:
|
||||||
return ggml_is_contiguous(op->src[0]);
|
return ggml_is_contiguous(op->src[0]);
|
||||||
default:
|
default:
|
||||||
return false;
|
return false;
|
||||||
@ -2967,6 +2974,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
|
|||||||
case GGML_OP_ARANGE:
|
case GGML_OP_ARANGE:
|
||||||
case GGML_OP_TIMESTEP_EMBEDDING:
|
case GGML_OP_TIMESTEP_EMBEDDING:
|
||||||
case GGML_OP_LEAKY_RELU:
|
case GGML_OP_LEAKY_RELU:
|
||||||
|
case GGML_OP_RWKV_WKV:
|
||||||
return true;
|
return true;
|
||||||
case GGML_OP_FLASH_ATTN_EXT:
|
case GGML_OP_FLASH_ATTN_EXT:
|
||||||
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
|
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
|
||||||
|
89
ggml/src/ggml-cuda/rwkv-wkv.cu
Normal file
89
ggml/src/ggml-cuda/rwkv-wkv.cu
Normal file
@ -0,0 +1,89 @@
|
|||||||
|
#include "common.cuh"
|
||||||
|
#include "rwkv-wkv.cuh"
|
||||||
|
|
||||||
|
static __global__ void rwkv_wkv_f32(const int B, const int T, const int C, const int H, const float * k, const float * v, const float * r, const float * tf, const float * td, const float * s, float * dst) {
|
||||||
|
const int tid = threadIdx.x;
|
||||||
|
const int bid = blockIdx.x;
|
||||||
|
|
||||||
|
const int head_size = CUDA_WKV_BLOCK_SIZE;
|
||||||
|
const int batch_i = bid / H;
|
||||||
|
const int head_i = bid % H;
|
||||||
|
const int state_size = C * head_size;
|
||||||
|
const int n_seq_tokens = T / B;
|
||||||
|
|
||||||
|
float state[head_size];
|
||||||
|
__shared__ float _k[head_size], _r[head_size], _tf[head_size], _td[head_size];
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < head_size; i++) {
|
||||||
|
state[i] = s[batch_i * state_size + head_i * head_size * head_size + i * head_size + tid];
|
||||||
|
}
|
||||||
|
|
||||||
|
__syncthreads();
|
||||||
|
_tf[tid] = tf[head_i * head_size + tid];
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid; t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid; t += C) {
|
||||||
|
__syncthreads();
|
||||||
|
_k[tid] = k[t];
|
||||||
|
_r[tid] = r[t];
|
||||||
|
_td[tid] = td[t];
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
const float _v = v[t];
|
||||||
|
float y = 0;
|
||||||
|
for (int j = 0; j < head_size; j += 4) {
|
||||||
|
const float4& k = (float4&)(_k[j]);
|
||||||
|
const float4& r = (float4&)(_r[j]);
|
||||||
|
const float4& tf = (float4&)(_tf[j]);
|
||||||
|
const float4& td = (float4&)(_td[j]);
|
||||||
|
float4& s = (float4&)(state[j]);
|
||||||
|
float4 kv;
|
||||||
|
|
||||||
|
kv.x = k.x * _v;
|
||||||
|
kv.y = k.y * _v;
|
||||||
|
kv.z = k.z * _v;
|
||||||
|
kv.w = k.w * _v;
|
||||||
|
|
||||||
|
y += r.x * (tf.x * kv.x + s.x);
|
||||||
|
y += r.y * (tf.y * kv.y + s.y);
|
||||||
|
y += r.z * (tf.z * kv.z + s.z);
|
||||||
|
y += r.w * (tf.w * kv.w + s.w);
|
||||||
|
|
||||||
|
s.x = s.x * td.x + kv.x;
|
||||||
|
s.y = s.y * td.y + kv.y;
|
||||||
|
s.z = s.z * td.z + kv.z;
|
||||||
|
s.w = s.w * td.w + kv.w;
|
||||||
|
}
|
||||||
|
dst[t] = y;
|
||||||
|
}
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < head_size; i++) {
|
||||||
|
dst[T * C + batch_i * state_size + head_i * head_size * head_size + i * head_size + tid] = state[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void ggml_cuda_op_rwkv_wkv(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||||
|
const float * k_d = (const float *)dst->src[0]->data;
|
||||||
|
const float * v_d = (const float *)dst->src[1]->data;
|
||||||
|
const float * r_d = (const float *)dst->src[2]->data;
|
||||||
|
const float * tf_d = (const float *)dst->src[3]->data;
|
||||||
|
const float * td_d = (const float *)dst->src[4]->data;
|
||||||
|
const float * s_d = (const float *)dst->src[5]->data;
|
||||||
|
|
||||||
|
const int64_t B = dst->src[5]->ne[1];
|
||||||
|
const int64_t T = dst->src[0]->ne[3];
|
||||||
|
const int64_t C = dst->ne[0];
|
||||||
|
const int64_t H = dst->src[0]->ne[2];
|
||||||
|
|
||||||
|
float * dst_d = (float *)dst->data;
|
||||||
|
|
||||||
|
cudaStream_t stream = ctx.stream();
|
||||||
|
|
||||||
|
GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32);
|
||||||
|
GGML_ASSERT(C % H == 0);
|
||||||
|
GGML_ASSERT(C / H == CUDA_WKV_BLOCK_SIZE);
|
||||||
|
|
||||||
|
rwkv_wkv_f32<<<B * H, C / H, 0, stream>>>(B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d);
|
||||||
|
}
|
5
ggml/src/ggml-cuda/rwkv-wkv.cuh
Normal file
5
ggml/src/ggml-cuda/rwkv-wkv.cuh
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
#include "common.cuh"
|
||||||
|
|
||||||
|
#define CUDA_WKV_BLOCK_SIZE 64
|
||||||
|
|
||||||
|
void ggml_cuda_op_rwkv_wkv(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
@ -95,6 +95,15 @@ static __global__ void hardswish_f32(const float * x, float * dst, const int k)
|
|||||||
dst[i] = x[i] * fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f));
|
dst[i] = x[i] * fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static __global__ void exp_f32(const float * x, float * dst, const int k) {
|
||||||
|
const int i = blockDim.x*blockIdx.x + threadIdx.x;
|
||||||
|
|
||||||
|
if (i >= k) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
dst[i] = expf(x[i]);
|
||||||
|
}
|
||||||
|
|
||||||
static __global__ void leaky_relu_f32(const float * x, float * dst, const int k, const float negative_slope) {
|
static __global__ void leaky_relu_f32(const float * x, float * dst, const int k, const float negative_slope) {
|
||||||
const int i = blockDim.x*blockIdx.x + threadIdx.x;
|
const int i = blockDim.x*blockIdx.x + threadIdx.x;
|
||||||
if (i >= k) {
|
if (i >= k) {
|
||||||
@ -189,6 +198,11 @@ static void hardswish_f32_cuda(const float * x, float * dst, const int k, cudaSt
|
|||||||
hardswish_f32<<<num_blocks, CUDA_HARDSWISH_BLOCK_SIZE, 0, stream>>>(x, dst, k);
|
hardswish_f32<<<num_blocks, CUDA_HARDSWISH_BLOCK_SIZE, 0, stream>>>(x, dst, k);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static void exp_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
|
||||||
|
const int num_blocks = (k + CUDA_EXP_BLOCK_SIZE - 1) / CUDA_EXP_BLOCK_SIZE;
|
||||||
|
exp_f32<<<num_blocks, CUDA_EXP_BLOCK_SIZE, 0, stream>>>(x, dst, k);
|
||||||
|
}
|
||||||
|
|
||||||
static void leaky_relu_f32_cuda(const float * x, float * dst, const int k, const float negative_slope, cudaStream_t stream) {
|
static void leaky_relu_f32_cuda(const float * x, float * dst, const int k, const float negative_slope, cudaStream_t stream) {
|
||||||
const int num_blocks = (k + CUDA_RELU_BLOCK_SIZE - 1) / CUDA_RELU_BLOCK_SIZE;
|
const int num_blocks = (k + CUDA_RELU_BLOCK_SIZE - 1) / CUDA_RELU_BLOCK_SIZE;
|
||||||
leaky_relu_f32<<<num_blocks, CUDA_RELU_BLOCK_SIZE, 0, stream>>>(x, dst, k, negative_slope);
|
leaky_relu_f32<<<num_blocks, CUDA_RELU_BLOCK_SIZE, 0, stream>>>(x, dst, k, negative_slope);
|
||||||
@ -354,6 +368,20 @@ void ggml_cuda_op_hardswish(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
|
|||||||
hardswish_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
|
hardswish_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void ggml_cuda_op_exp(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||||
|
const ggml_tensor * src0 = dst->src[0];
|
||||||
|
const float * src0_d = (const float *)src0->data;
|
||||||
|
float * dst_d = (float *)dst->data;
|
||||||
|
cudaStream_t stream = ctx.stream();
|
||||||
|
|
||||||
|
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||||
|
|
||||||
|
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||||
|
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||||
|
|
||||||
|
exp_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
|
||||||
|
}
|
||||||
|
|
||||||
void ggml_cuda_op_leaky_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
void ggml_cuda_op_leaky_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||||
const ggml_tensor * src0 = dst->src[0];
|
const ggml_tensor * src0 = dst->src[0];
|
||||||
const float * src0_d = (const float *)src0->data;
|
const float * src0_d = (const float *)src0->data;
|
||||||
|
@ -8,6 +8,7 @@
|
|||||||
#define CUDA_RELU_BLOCK_SIZE 256
|
#define CUDA_RELU_BLOCK_SIZE 256
|
||||||
#define CUDA_SIGMOID_BLOCK_SIZE 256
|
#define CUDA_SIGMOID_BLOCK_SIZE 256
|
||||||
#define CUDA_HARDSIGMOID_BLOCK_SIZE 256
|
#define CUDA_HARDSIGMOID_BLOCK_SIZE 256
|
||||||
|
#define CUDA_EXP_BLOCK_SIZE 256
|
||||||
#define CUDA_HARDSWISH_BLOCK_SIZE 256
|
#define CUDA_HARDSWISH_BLOCK_SIZE 256
|
||||||
#define CUDA_SQR_BLOCK_SIZE 256
|
#define CUDA_SQR_BLOCK_SIZE 256
|
||||||
#define CUDA_SQRT_BLOCK_SIZE 256
|
#define CUDA_SQRT_BLOCK_SIZE 256
|
||||||
@ -32,6 +33,8 @@ void ggml_cuda_op_sigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
|||||||
|
|
||||||
void ggml_cuda_op_hardsigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
void ggml_cuda_op_hardsigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||||
|
|
||||||
|
void ggml_cuda_op_exp(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||||
|
|
||||||
void ggml_cuda_op_hardswish(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
void ggml_cuda_op_hardswish(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||||
|
|
||||||
void ggml_cuda_op_leaky_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
void ggml_cuda_op_leaky_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||||
|
@ -1543,6 +1543,36 @@ struct test_ssm_scan : public test_case {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// GGML_OP_RWKV_WKV
|
||||||
|
struct test_rwkv_wkv : public test_case {
|
||||||
|
const ggml_type type;
|
||||||
|
|
||||||
|
const int64_t head_count;
|
||||||
|
const int64_t head_size;
|
||||||
|
const int64_t n_seq_tokens;
|
||||||
|
const int64_t n_seqs;
|
||||||
|
|
||||||
|
std::string vars() override {
|
||||||
|
return VARS_TO_STR5(type, head_count, head_size, n_seq_tokens, n_seqs);
|
||||||
|
}
|
||||||
|
|
||||||
|
test_rwkv_wkv(ggml_type type = GGML_TYPE_F32,
|
||||||
|
int64_t head_count = 32, int64_t head_size = 64, int64_t n_seq_tokens = 32, int64_t n_seqs = 32)
|
||||||
|
: type(type), head_count(head_count), head_size(head_size), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs) {}
|
||||||
|
|
||||||
|
ggml_tensor * build_graph(ggml_context * ctx) override {
|
||||||
|
const int64_t n_tokens = n_seq_tokens * n_seqs;
|
||||||
|
ggml_tensor * r = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ 1, head_size, head_count, n_tokens }.data());
|
||||||
|
ggml_tensor * k = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ head_size, 1, head_count, n_tokens }.data());
|
||||||
|
ggml_tensor * v = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ 1, head_size, head_count, n_tokens }.data());
|
||||||
|
ggml_tensor * tf = ggml_new_tensor(ctx, type, 2, std::vector<int64_t>{ head_size, head_count }.data());
|
||||||
|
ggml_tensor * td = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ 1, head_size, head_count, n_tokens }.data());
|
||||||
|
ggml_tensor * s = ggml_new_tensor(ctx, type, 2, std::vector<int64_t>{ head_size * head_size * head_count, n_seqs }.data());
|
||||||
|
ggml_tensor * out = ggml_rwkv_wkv(ctx, k, v, r, tf, td, s);
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
// GGML_OP_MUL_MAT
|
// GGML_OP_MUL_MAT
|
||||||
struct test_mul_mat : public test_case {
|
struct test_mul_mat : public test_case {
|
||||||
const ggml_type type_a;
|
const ggml_type type_a;
|
||||||
@ -3337,6 +3367,11 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
|
|||||||
|
|
||||||
test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 16, 1024, 32, 4));
|
test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 16, 1024, 32, 4));
|
||||||
|
|
||||||
|
test_cases.emplace_back(new test_rwkv_wkv(GGML_TYPE_F32, 32, 64, 1, 1));
|
||||||
|
test_cases.emplace_back(new test_rwkv_wkv(GGML_TYPE_F32, 32, 64, 32, 1));
|
||||||
|
test_cases.emplace_back(new test_rwkv_wkv(GGML_TYPE_F32, 32, 64, 32, 4));
|
||||||
|
test_cases.emplace_back(new test_rwkv_wkv(GGML_TYPE_F32, 32, 64, 128, 4));
|
||||||
|
|
||||||
#if 1
|
#if 1
|
||||||
for (ggml_type type_a : base_types) {
|
for (ggml_type type_a : base_types) {
|
||||||
for (ggml_type type_b : {GGML_TYPE_F32, GGML_TYPE_F16}) {
|
for (ggml_type type_b : {GGML_TYPE_F32, GGML_TYPE_F16}) {
|
||||||
|
Loading…
Reference in New Issue
Block a user