Compare commits

...

6 Commits

Author SHA1 Message Date
Andrei
946cef3314
Merge 951f1d9053 into a5b57b08ce 2024-09-22 14:16:20 +02:00
Johannes Gäßler
a5b57b08ce
CUDA: enable Gemma FA for HIP/Pascal (#9581)
Some checks are pending
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/full-cuda.Dockerfile platforms:linux/amd64 tag:full-cuda]) (push) Waiting to run
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/full.Dockerfile platforms:linux/amd64,linux/arm64 tag:full]) (push) Waiting to run
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/llama-cli-cuda.Dockerfile platforms:linux/amd64 tag:light-cuda]) (push) Waiting to run
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/llama-cli-intel.Dockerfile platforms:linux/amd64 tag:light-intel]) (push) Waiting to run
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/llama-cli.Dockerfile platforms:linux/amd64,linux/arm64 tag:light]) (push) Waiting to run
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/llama-server-cuda.Dockerfile platforms:linux/amd64 tag:server-cuda]) (push) Waiting to run
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/llama-server-intel.Dockerfile platforms:linux/amd64 tag:server-intel]) (push) Waiting to run
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/llama-server.Dockerfile platforms:linux/amd64,linux/arm64 tag:server]) (push) Waiting to run
Nix CI / nix-eval (macos-latest) (push) Waiting to run
Nix CI / nix-eval (ubuntu-latest) (push) Waiting to run
Nix CI / nix-build (macos-latest) (push) Waiting to run
Nix CI / nix-build (ubuntu-latest) (push) Waiting to run
flake8 Lint / Lint (push) Waiting to run
2024-09-22 09:34:52 +02:00
Shankar
ecd5d6b65b
llama: remove redundant loop when constructing ubatch (#9574)
Some checks are pending
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/full-cuda.Dockerfile platforms:linux/amd64 tag:full-cuda]) (push) Waiting to run
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/full.Dockerfile platforms:linux/amd64,linux/arm64 tag:full]) (push) Waiting to run
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/llama-cli-cuda.Dockerfile platforms:linux/amd64 tag:light-cuda]) (push) Waiting to run
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/llama-cli-intel.Dockerfile platforms:linux/amd64 tag:light-intel]) (push) Waiting to run
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/llama-cli.Dockerfile platforms:linux/amd64,linux/arm64 tag:light]) (push) Waiting to run
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/llama-server-cuda.Dockerfile platforms:linux/amd64 tag:server-cuda]) (push) Waiting to run
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/llama-server-intel.Dockerfile platforms:linux/amd64 tag:server-intel]) (push) Waiting to run
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/llama-server.Dockerfile platforms:linux/amd64,linux/arm64 tag:server]) (push) Waiting to run
Nix CI / nix-eval (macos-latest) (push) Waiting to run
Nix CI / nix-eval (ubuntu-latest) (push) Waiting to run
Nix CI / nix-build (macos-latest) (push) Waiting to run
Nix CI / nix-build (ubuntu-latest) (push) Waiting to run
flake8 Lint / Lint (push) Waiting to run
2024-09-22 04:30:34 +02:00
Molly Sophia
2a63caaa69
RWKV v6: RWKV_WKV op CUDA implementation (#9454)
* ggml: CUDA unary op EXP

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* ggml: rwkv_wkv op CUDA impl

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

---------

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
2024-09-22 04:29:12 +02:00
Andrei Betlen
951f1d9053 Merge remote-tracking branch 'origin' into add-support-for-phi3-vision 2024-08-27 18:13:54 -04:00
Andrei Betlen
dc0625ab8f Add support for Phi3-vision-instruct 2024-08-27 18:11:41 -04:00
9 changed files with 323 additions and 18 deletions

View File

@ -136,6 +136,8 @@ static std::string format(const char * fmt, ...) {
#define TN_MVLM_PROJ_BLOCK "mm.model.mb_block.%d.block.%d.%s"
#define TN_MVLM_PROJ_PEG "mm.model.peg.%d.%s"
#define TN_IMAGE_NEWLINE "model.image_newline"
#define TN_SUB_GN "v.sub_gn"
#define TN_GLB_GN "v.glb_gn"
#define TN_MINICPMV_POS_EMBD_K "resampler.pos_embed_k"
#define TN_MINICPMV_QUERY "resampler.query"
@ -534,6 +536,9 @@ struct clip_vision_model {
struct ggml_tensor * mm_model_ln_kv_b;
struct ggml_tensor * mm_model_ln_post_w;
struct ggml_tensor * mm_model_ln_post_b;
struct ggml_tensor * sub_gn;
struct ggml_tensor * glb_gn;
};
struct clip_ctx {
@ -781,6 +786,138 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
// print_tensor_info(embeddings, "embeddings");
// phi-3.5-vision-instruct
if (model.sub_gn && model.glb_gn) {
// Phi3VisionEmbedding.hd_transform()
ggml_tensor * x = embeddings;
int num_images = batch_size;
int h_crop = 1, w_crop = 1;
int C = x->ne[0];
int L = x->ne[1];
int N = x->ne[2];
int H = (int)sqrt((float)L);
GGML_ASSERT(H * H == L);
x = ggml_cont(ctx0, ggml_permute(ctx0, x, 3, 2, 1, 0));
// Phi3ImageEmbedding.reshape_hd_patches_2x2merge()
x = ggml_reshape_4d(ctx0, x, N, H, H, C);
x = ggml_cont(ctx0, ggml_permute(ctx0, x, 0, 2, 1, 3));
x = ggml_cont(ctx0, ggml_permute(ctx0, x, 3, 0, 1, 2));
x = ggml_cont(ctx0, ggml_permute(ctx0, x, 2, 3, 1, 0));
x = ggml_cont(ctx0, ggml_permute(ctx0, x, 3, 2, 1, 0));
x = ggml_reshape_4d(ctx0, x, 2, H / 2, 2, H / 2 * C * N);
x = ggml_cont(ctx0, ggml_permute(ctx0, x, 3, 2, 1, 0));
x = ggml_cont(ctx0, ggml_permute(ctx0, x, 0, 2, 1, 3));
x = ggml_cont(ctx0, ggml_permute(ctx0, x, 0, 1, 3, 2));
x = ggml_reshape_3d(ctx0, x, N * C * (H / 2), (H / 2), 4);
x = ggml_cont(ctx0, ggml_permute(ctx0, x, 3, 2, 1, 0));
x = ggml_reshape_4d(ctx0, x, 4, H / 2, H / 2, N * C);
x = ggml_cont(ctx0, ggml_permute(ctx0, x, 3, 2, 1, 0));
x = ggml_cont(ctx0, ggml_permute(ctx0, x, 3, 2, 1, 0));
x = ggml_reshape_4d(ctx0, x, 4, (H / 2) * (H / 2), C, N);
x = ggml_cont(ctx0, ggml_permute(ctx0, x, 3, 2, 1, 0));
x = ggml_cont(ctx0, ggml_permute(ctx0, x, 0, 3, 1, 2));
x = ggml_cont(ctx0, ggml_permute(ctx0, x, 3, 2, 1, 0));
x = ggml_reshape_4d(ctx0, x, 4 * C, H / 2, H / 2, N);
x = ggml_cont(ctx0, ggml_permute(ctx0, x, 3, 2, 1, 0));
x = ggml_cont(ctx0, ggml_permute(ctx0, x, 3, 2, 1, 0));
x = ggml_reshape_4d(ctx0, x, (H / 2) * 4 * C, (H / 2), w_crop, num_images * h_crop);
x = ggml_cont(ctx0, ggml_permute(ctx0, x, 3, 2, 1, 0));
x = ggml_cont(ctx0, ggml_permute(ctx0, x, 0, 2, 1, 3));
x = ggml_cont(ctx0, ggml_permute(ctx0, x, 3, 2, 1, 0));
x = ggml_reshape_4d(ctx0, x, 4 * C, w_crop * (H / 2), h_crop * (H / 2), num_images);
x = ggml_cont(ctx0, ggml_permute(ctx0, x, 3, 2, 1, 0));
ggml_tensor * global_image_features_hd = ggml_cont(ctx0, ggml_permute(ctx0, x, 3, 2, 1, 0));
// Phi3ImageEmbedding.add_image_newline()
ggml_tensor * newline_embedding = model.sub_gn;
for (int i = 0; i < H/2-1; i++) {
newline_embedding = ggml_concat(ctx0, newline_embedding, model.sub_gn, 2);
}
ggml_tensor * global_image_features_hd_newline = ggml_concat(ctx0, global_image_features_hd, newline_embedding, 1);
global_image_features_hd_newline = ggml_cont(ctx0, ggml_permute(ctx0, global_image_features_hd_newline, 3, 2, 1, 0));
global_image_features_hd_newline = ggml_reshape_4d(ctx0, global_image_features_hd_newline, 1, 1, (w_crop*(H/2)+1) * h_crop*(H/2), 4*C);
global_image_features_hd_newline = ggml_cont(ctx0, ggml_permute(ctx0, global_image_features_hd_newline, 3, 2, 1, 0));
h_crop = image_size / 336;
w_crop = image_size / 336;
// sub_image_features_hd
x = embeddings;
x = ggml_cont(ctx0, ggml_permute(ctx0, x, 3, 2, 1, 0));
// Phi3ImageEmbedding.reshape_hd_patches_2x2merge()
x = ggml_reshape_4d(ctx0, x, N, H, H, C);
x = ggml_cont(ctx0, ggml_permute(ctx0, x, 0, 2, 1, 3));
x = ggml_cont(ctx0, ggml_permute(ctx0, x, 3, 0, 1, 2));
x = ggml_cont(ctx0, ggml_permute(ctx0, x, 2, 3, 1, 0));
x = ggml_cont(ctx0, ggml_permute(ctx0, x, 3, 2, 1, 0));
x = ggml_reshape_4d(ctx0, x, 2, H / 2, 2, H / 2 * C * N);
x = ggml_cont(ctx0, ggml_permute(ctx0, x, 3, 2, 1, 0));
x = ggml_cont(ctx0, ggml_permute(ctx0, x, 0, 2, 1, 3));
x = ggml_cont(ctx0, ggml_permute(ctx0, x, 0, 1, 3, 2));
x = ggml_reshape_3d(ctx0, x, N * C * (H / 2), (H / 2), 4);
x = ggml_cont(ctx0, ggml_permute(ctx0, x, 3, 2, 1, 0));
x = ggml_reshape_4d(ctx0, x, 4, H / 2, H / 2, N * C);
x = ggml_cont(ctx0, ggml_permute(ctx0, x, 3, 2, 1, 0));
x = ggml_cont(ctx0, ggml_permute(ctx0, x, 3, 2, 1, 0));
x = ggml_reshape_4d(ctx0, x, 4, (H / 2) * (H / 2), C, N);
x = ggml_cont(ctx0, ggml_permute(ctx0, x, 3, 2, 1, 0));
x = ggml_cont(ctx0, ggml_permute(ctx0, x, 0, 3, 1, 2));
x = ggml_cont(ctx0, ggml_permute(ctx0, x, 3, 2, 1, 0));
x = ggml_reshape_4d(ctx0, x, 4 * C, H / 2, H / 2, N);
x = ggml_cont(ctx0, ggml_permute(ctx0, x, 3, 2, 1, 0));
x = ggml_cont(ctx0, ggml_permute(ctx0, x, 3, 2, 1, 0));
x = ggml_reshape_4d(ctx0, x, (H / 2) * 4 * C, (H / 2), w_crop, num_images * h_crop);
x = ggml_cont(ctx0, ggml_permute(ctx0, x, 3, 2, 1, 0));
x = ggml_cont(ctx0, ggml_permute(ctx0, x, 0, 2, 1, 3));
x = ggml_cont(ctx0, ggml_permute(ctx0, x, 3, 2, 1, 0));
x = ggml_reshape_4d(ctx0, x, 4 * C, w_crop * (H / 2), h_crop * (H / 2), num_images);
x = ggml_cont(ctx0, ggml_permute(ctx0, x, 3, 2, 1, 0));
ggml_tensor * sub_image_features_hd = ggml_cont(ctx0, ggml_permute(ctx0, x, 3, 2, 1, 0));
// Phi3ImageEmbedding.add_image_newline()
newline_embedding = model.sub_gn;
for (int i = 0; i < (H/2-1); i++) {
newline_embedding = ggml_concat(ctx0, newline_embedding, model.sub_gn, 2);
}
ggml_tensor * sub_image_features_hd_newline = ggml_concat(ctx0, sub_image_features_hd, newline_embedding, 1);
sub_image_features_hd_newline = ggml_cont(ctx0, ggml_permute(ctx0, sub_image_features_hd_newline, 3, 2, 1, 0));
sub_image_features_hd_newline = ggml_reshape_4d(ctx0, sub_image_features_hd_newline, 1, 1, (w_crop*(H/2)+1) * h_crop*(H/2), 4*C);
sub_image_features_hd_newline = ggml_cont(ctx0, ggml_permute(ctx0, sub_image_features_hd_newline, 3, 2, 1, 0));
embeddings = ggml_concat(ctx0, sub_image_features_hd_newline, model.glb_gn, 1);
embeddings = ggml_concat(ctx0, embeddings, global_image_features_hd_newline, 1);
}
// llava projector
if (ctx->proj_type == PROJECTOR_TYPE_MLP) {
embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings);
@ -1406,6 +1543,10 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
vision_model.image_newline = get_tensor(new_clip->ctx_data, TN_IMAGE_NEWLINE);
// LOG_INF("%s: image_newline tensor (llava-1.6) found\n", __func__);
} catch (std::runtime_error & /*e*/) { }
try {
vision_model.sub_gn = get_tensor(new_clip->ctx_data, TN_SUB_GN);
vision_model.glb_gn = get_tensor(new_clip->ctx_data, TN_GLB_GN);
} catch (std::runtime_error & /*e*/) { }
} else if (new_clip->proj_type == PROJECTOR_TYPE_LDP) {
// MobileVLM projection
vision_model.mm_model_mlp_1_w = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_MLP, 1, "weight"));

View File

@ -34,6 +34,7 @@
#include "ggml-cuda/tsembd.cuh"
#include "ggml-cuda/unary.cuh"
#include "ggml-cuda/upscale.cuh"
#include "ggml-cuda/rwkv-wkv.cuh"
#include <algorithm>
#include <array>
@ -2243,6 +2244,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_UNARY_OP_HARDSWISH:
ggml_cuda_op_hardswish(ctx, dst);
break;
case GGML_UNARY_OP_EXP:
ggml_cuda_op_exp(ctx, dst);
break;
default:
return false;
}
@ -2345,6 +2349,8 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_CROSS_ENTROPY_LOSS:
ggml_cuda_cross_entropy_loss(ctx, dst);
break;
case GGML_OP_RWKV_WKV:
ggml_cuda_op_rwkv_wkv(ctx, dst);
case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
ggml_cuda_cross_entropy_loss_back(ctx, dst);
break;
@ -2806,6 +2812,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
case GGML_UNARY_OP_HARDSWISH:
case GGML_UNARY_OP_GELU_QUICK:
case GGML_UNARY_OP_TANH:
case GGML_UNARY_OP_EXP:
return ggml_is_contiguous(op->src[0]);
default:
return false;
@ -2967,20 +2974,21 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
case GGML_OP_ARANGE:
case GGML_OP_TIMESTEP_EMBEDDING:
case GGML_OP_LEAKY_RELU:
case GGML_OP_RWKV_WKV:
return true;
case GGML_OP_FLASH_ATTN_EXT:
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
return (op->src[0]->ne[0] == 64 && op->src[1]->type == GGML_TYPE_F16) || op->src[0]->ne[0] == 128;
#else
if (op->src[0]->ne[0] == 128) {
return true;
}
case GGML_OP_FLASH_ATTN_EXT: {
if (op->src[0]->ne[0] == 64 && op->src[1]->type == GGML_TYPE_F16) {
return true;
}
return ggml_cuda_info().devices[cuda_ctx->device].cc >= CC_VOLTA &&
op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16;
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
if (op->src[0]->ne[0] == 128) {
return true;
}
if (op->src[0]->ne[0] == 256 && op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16) {
return true;
}
const int cc = ggml_cuda_info().devices[cuda_ctx->device].cc;
return cc >= CC_VOLTA && cc < CC_OFFSET_AMD && op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16;
}
case GGML_OP_CROSS_ENTROPY_LOSS:
case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
case GGML_OP_OPT_STEP_ADAMW:

View File

@ -314,7 +314,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
}
if (!fast_fp16_available(cc)) {
if (Q->ne[1] <= 8) {
if (Q->ne[1] <= 8 || Q->ne[0] == 256) {
ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
} else {
ggml_cuda_flash_attn_ext_tile_f32(ctx, dst);

View File

@ -0,0 +1,89 @@
#include "common.cuh"
#include "rwkv-wkv.cuh"
static __global__ void rwkv_wkv_f32(const int B, const int T, const int C, const int H, const float * k, const float * v, const float * r, const float * tf, const float * td, const float * s, float * dst) {
const int tid = threadIdx.x;
const int bid = blockIdx.x;
const int head_size = CUDA_WKV_BLOCK_SIZE;
const int batch_i = bid / H;
const int head_i = bid % H;
const int state_size = C * head_size;
const int n_seq_tokens = T / B;
float state[head_size];
__shared__ float _k[head_size], _r[head_size], _tf[head_size], _td[head_size];
#pragma unroll
for (int i = 0; i < head_size; i++) {
state[i] = s[batch_i * state_size + head_i * head_size * head_size + i * head_size + tid];
}
__syncthreads();
_tf[tid] = tf[head_i * head_size + tid];
__syncthreads();
for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid; t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid; t += C) {
__syncthreads();
_k[tid] = k[t];
_r[tid] = r[t];
_td[tid] = td[t];
__syncthreads();
const float _v = v[t];
float y = 0;
for (int j = 0; j < head_size; j += 4) {
const float4& k = (float4&)(_k[j]);
const float4& r = (float4&)(_r[j]);
const float4& tf = (float4&)(_tf[j]);
const float4& td = (float4&)(_td[j]);
float4& s = (float4&)(state[j]);
float4 kv;
kv.x = k.x * _v;
kv.y = k.y * _v;
kv.z = k.z * _v;
kv.w = k.w * _v;
y += r.x * (tf.x * kv.x + s.x);
y += r.y * (tf.y * kv.y + s.y);
y += r.z * (tf.z * kv.z + s.z);
y += r.w * (tf.w * kv.w + s.w);
s.x = s.x * td.x + kv.x;
s.y = s.y * td.y + kv.y;
s.z = s.z * td.z + kv.z;
s.w = s.w * td.w + kv.w;
}
dst[t] = y;
}
#pragma unroll
for (int i = 0; i < head_size; i++) {
dst[T * C + batch_i * state_size + head_i * head_size * head_size + i * head_size + tid] = state[i];
}
}
void ggml_cuda_op_rwkv_wkv(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const float * k_d = (const float *)dst->src[0]->data;
const float * v_d = (const float *)dst->src[1]->data;
const float * r_d = (const float *)dst->src[2]->data;
const float * tf_d = (const float *)dst->src[3]->data;
const float * td_d = (const float *)dst->src[4]->data;
const float * s_d = (const float *)dst->src[5]->data;
const int64_t B = dst->src[5]->ne[1];
const int64_t T = dst->src[0]->ne[3];
const int64_t C = dst->ne[0];
const int64_t H = dst->src[0]->ne[2];
float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32);
GGML_ASSERT(C % H == 0);
GGML_ASSERT(C / H == CUDA_WKV_BLOCK_SIZE);
rwkv_wkv_f32<<<B * H, C / H, 0, stream>>>(B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d);
}

View File

@ -0,0 +1,5 @@
#include "common.cuh"
#define CUDA_WKV_BLOCK_SIZE 64
void ggml_cuda_op_rwkv_wkv(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

View File

@ -95,6 +95,15 @@ static __global__ void hardswish_f32(const float * x, float * dst, const int k)
dst[i] = x[i] * fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f));
}
static __global__ void exp_f32(const float * x, float * dst, const int k) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;
if (i >= k) {
return;
}
dst[i] = expf(x[i]);
}
static __global__ void leaky_relu_f32(const float * x, float * dst, const int k, const float negative_slope) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;
if (i >= k) {
@ -189,6 +198,11 @@ static void hardswish_f32_cuda(const float * x, float * dst, const int k, cudaSt
hardswish_f32<<<num_blocks, CUDA_HARDSWISH_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}
static void exp_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_EXP_BLOCK_SIZE - 1) / CUDA_EXP_BLOCK_SIZE;
exp_f32<<<num_blocks, CUDA_EXP_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}
static void leaky_relu_f32_cuda(const float * x, float * dst, const int k, const float negative_slope, cudaStream_t stream) {
const int num_blocks = (k + CUDA_RELU_BLOCK_SIZE - 1) / CUDA_RELU_BLOCK_SIZE;
leaky_relu_f32<<<num_blocks, CUDA_RELU_BLOCK_SIZE, 0, stream>>>(x, dst, k, negative_slope);
@ -354,6 +368,20 @@ void ggml_cuda_op_hardswish(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
hardswish_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
}
void ggml_cuda_op_exp(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *)src0->data;
float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
exp_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
}
void ggml_cuda_op_leaky_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *)src0->data;

View File

@ -8,6 +8,7 @@
#define CUDA_RELU_BLOCK_SIZE 256
#define CUDA_SIGMOID_BLOCK_SIZE 256
#define CUDA_HARDSIGMOID_BLOCK_SIZE 256
#define CUDA_EXP_BLOCK_SIZE 256
#define CUDA_HARDSWISH_BLOCK_SIZE 256
#define CUDA_SQR_BLOCK_SIZE 256
#define CUDA_SQRT_BLOCK_SIZE 256
@ -32,6 +33,8 @@ void ggml_cuda_op_sigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_hardsigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_exp(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_hardswish(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_leaky_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

View File

@ -3056,18 +3056,14 @@ struct llama_sbatch {
} else {
// simple split
if (batch->n_seq_id) {
for (size_t i = 0; i < length; ++i) {
ubatch.n_seq_id = batch->n_seq_id + seq.offset;
}
ubatch.n_seq_id = batch->n_seq_id + seq.offset;
} else {
for (size_t i = 0; i < length; ++i) {
ubatch.n_seq_id[ubatch.n_seqs + i] = 1;
}
}
if (batch->seq_id) {
for (size_t i = 0; i < length; ++i) {
ubatch.seq_id = batch->seq_id + seq.offset;
}
ubatch.seq_id = batch->seq_id + seq.offset;
} else {
for (size_t i = 0; i < length; ++i) {
ubatch.seq_id[ubatch.n_seqs + i] = &seq.all_seq_id;

View File

@ -1543,6 +1543,36 @@ struct test_ssm_scan : public test_case {
}
};
// GGML_OP_RWKV_WKV
struct test_rwkv_wkv : public test_case {
const ggml_type type;
const int64_t head_count;
const int64_t head_size;
const int64_t n_seq_tokens;
const int64_t n_seqs;
std::string vars() override {
return VARS_TO_STR5(type, head_count, head_size, n_seq_tokens, n_seqs);
}
test_rwkv_wkv(ggml_type type = GGML_TYPE_F32,
int64_t head_count = 32, int64_t head_size = 64, int64_t n_seq_tokens = 32, int64_t n_seqs = 32)
: type(type), head_count(head_count), head_size(head_size), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
const int64_t n_tokens = n_seq_tokens * n_seqs;
ggml_tensor * r = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ 1, head_size, head_count, n_tokens }.data());
ggml_tensor * k = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ head_size, 1, head_count, n_tokens }.data());
ggml_tensor * v = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ 1, head_size, head_count, n_tokens }.data());
ggml_tensor * tf = ggml_new_tensor(ctx, type, 2, std::vector<int64_t>{ head_size, head_count }.data());
ggml_tensor * td = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ 1, head_size, head_count, n_tokens }.data());
ggml_tensor * s = ggml_new_tensor(ctx, type, 2, std::vector<int64_t>{ head_size * head_size * head_count, n_seqs }.data());
ggml_tensor * out = ggml_rwkv_wkv(ctx, k, v, r, tf, td, s);
return out;
}
};
// GGML_OP_MUL_MAT
struct test_mul_mat : public test_case {
const ggml_type type_a;
@ -3337,6 +3367,11 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 16, 1024, 32, 4));
test_cases.emplace_back(new test_rwkv_wkv(GGML_TYPE_F32, 32, 64, 1, 1));
test_cases.emplace_back(new test_rwkv_wkv(GGML_TYPE_F32, 32, 64, 32, 1));
test_cases.emplace_back(new test_rwkv_wkv(GGML_TYPE_F32, 32, 64, 32, 4));
test_cases.emplace_back(new test_rwkv_wkv(GGML_TYPE_F32, 32, 64, 128, 4));
#if 1
for (ggml_type type_a : base_types) {
for (ggml_type type_b : {GGML_TYPE_F32, GGML_TYPE_F16}) {
@ -3564,7 +3599,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
if (hs != 128 && logit_softcap != 0.0f) continue;
for (int nh : { 32, }) {
for (int kv : { 512, 1024, }) {
for (int nb : { 1, 2, 4, 8, }) {
for (int nb : { 1, 3, 32, 35, }) {
for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb, mask, max_bias, logit_softcap, type_KV));
}