Compare commits

...

5 Commits

Author SHA1 Message Date
Radoslav Gerganov
34c686c3f7
Merge adf3bce13b into a5b57b08ce 2024-09-22 12:48:16 +03:00
Johannes Gäßler
a5b57b08ce
CUDA: enable Gemma FA for HIP/Pascal (#9581)
Some checks are pending
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/full-cuda.Dockerfile platforms:linux/amd64 tag:full-cuda]) (push) Waiting to run
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/full.Dockerfile platforms:linux/amd64,linux/arm64 tag:full]) (push) Waiting to run
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/llama-cli-cuda.Dockerfile platforms:linux/amd64 tag:light-cuda]) (push) Waiting to run
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/llama-cli-intel.Dockerfile platforms:linux/amd64 tag:light-intel]) (push) Waiting to run
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/llama-cli.Dockerfile platforms:linux/amd64,linux/arm64 tag:light]) (push) Waiting to run
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/llama-server-cuda.Dockerfile platforms:linux/amd64 tag:server-cuda]) (push) Waiting to run
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/llama-server-intel.Dockerfile platforms:linux/amd64 tag:server-intel]) (push) Waiting to run
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/llama-server.Dockerfile platforms:linux/amd64,linux/arm64 tag:server]) (push) Waiting to run
Nix CI / nix-eval (macos-latest) (push) Waiting to run
Nix CI / nix-eval (ubuntu-latest) (push) Waiting to run
Nix CI / nix-build (macos-latest) (push) Waiting to run
Nix CI / nix-build (ubuntu-latest) (push) Waiting to run
flake8 Lint / Lint (push) Waiting to run
2024-09-22 09:34:52 +02:00
Shankar
ecd5d6b65b
llama: remove redundant loop when constructing ubatch (#9574)
Some checks are pending
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/full-cuda.Dockerfile platforms:linux/amd64 tag:full-cuda]) (push) Waiting to run
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/full.Dockerfile platforms:linux/amd64,linux/arm64 tag:full]) (push) Waiting to run
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/llama-cli-cuda.Dockerfile platforms:linux/amd64 tag:light-cuda]) (push) Waiting to run
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/llama-cli-intel.Dockerfile platforms:linux/amd64 tag:light-intel]) (push) Waiting to run
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/llama-cli.Dockerfile platforms:linux/amd64,linux/arm64 tag:light]) (push) Waiting to run
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/llama-server-cuda.Dockerfile platforms:linux/amd64 tag:server-cuda]) (push) Waiting to run
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/llama-server-intel.Dockerfile platforms:linux/amd64 tag:server-intel]) (push) Waiting to run
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/llama-server.Dockerfile platforms:linux/amd64,linux/arm64 tag:server]) (push) Waiting to run
Nix CI / nix-eval (macos-latest) (push) Waiting to run
Nix CI / nix-eval (ubuntu-latest) (push) Waiting to run
Nix CI / nix-build (macos-latest) (push) Waiting to run
Nix CI / nix-build (ubuntu-latest) (push) Waiting to run
flake8 Lint / Lint (push) Waiting to run
2024-09-22 04:30:34 +02:00
Molly Sophia
2a63caaa69
RWKV v6: RWKV_WKV op CUDA implementation (#9454)
* ggml: CUDA unary op EXP

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* ggml: rwkv_wkv op CUDA impl

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

---------

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
2024-09-22 04:29:12 +02:00
Radoslav Gerganov
adf3bce13b vulkan : do not use tensor->extra
This patch allows using the Vulkan backend with the RPC backend as
tensor->extra is no longer used.

Ref: #8536
2024-09-10 16:24:27 +03:00
9 changed files with 311 additions and 194 deletions

View File

@ -34,6 +34,7 @@
#include "ggml-cuda/tsembd.cuh"
#include "ggml-cuda/unary.cuh"
#include "ggml-cuda/upscale.cuh"
#include "ggml-cuda/rwkv-wkv.cuh"
#include <algorithm>
#include <array>
@ -2243,6 +2244,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_UNARY_OP_HARDSWISH:
ggml_cuda_op_hardswish(ctx, dst);
break;
case GGML_UNARY_OP_EXP:
ggml_cuda_op_exp(ctx, dst);
break;
default:
return false;
}
@ -2345,6 +2349,8 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_CROSS_ENTROPY_LOSS:
ggml_cuda_cross_entropy_loss(ctx, dst);
break;
case GGML_OP_RWKV_WKV:
ggml_cuda_op_rwkv_wkv(ctx, dst);
case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
ggml_cuda_cross_entropy_loss_back(ctx, dst);
break;
@ -2806,6 +2812,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
case GGML_UNARY_OP_HARDSWISH:
case GGML_UNARY_OP_GELU_QUICK:
case GGML_UNARY_OP_TANH:
case GGML_UNARY_OP_EXP:
return ggml_is_contiguous(op->src[0]);
default:
return false;
@ -2967,20 +2974,21 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
case GGML_OP_ARANGE:
case GGML_OP_TIMESTEP_EMBEDDING:
case GGML_OP_LEAKY_RELU:
case GGML_OP_RWKV_WKV:
return true;
case GGML_OP_FLASH_ATTN_EXT:
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
return (op->src[0]->ne[0] == 64 && op->src[1]->type == GGML_TYPE_F16) || op->src[0]->ne[0] == 128;
#else
if (op->src[0]->ne[0] == 128) {
return true;
}
case GGML_OP_FLASH_ATTN_EXT: {
if (op->src[0]->ne[0] == 64 && op->src[1]->type == GGML_TYPE_F16) {
return true;
}
return ggml_cuda_info().devices[cuda_ctx->device].cc >= CC_VOLTA &&
op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16;
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
if (op->src[0]->ne[0] == 128) {
return true;
}
if (op->src[0]->ne[0] == 256 && op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16) {
return true;
}
const int cc = ggml_cuda_info().devices[cuda_ctx->device].cc;
return cc >= CC_VOLTA && cc < CC_OFFSET_AMD && op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16;
}
case GGML_OP_CROSS_ENTROPY_LOSS:
case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
case GGML_OP_OPT_STEP_ADAMW:

View File

@ -314,7 +314,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
}
if (!fast_fp16_available(cc)) {
if (Q->ne[1] <= 8) {
if (Q->ne[1] <= 8 || Q->ne[0] == 256) {
ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
} else {
ggml_cuda_flash_attn_ext_tile_f32(ctx, dst);

View File

@ -0,0 +1,89 @@
#include "common.cuh"
#include "rwkv-wkv.cuh"
static __global__ void rwkv_wkv_f32(const int B, const int T, const int C, const int H, const float * k, const float * v, const float * r, const float * tf, const float * td, const float * s, float * dst) {
const int tid = threadIdx.x;
const int bid = blockIdx.x;
const int head_size = CUDA_WKV_BLOCK_SIZE;
const int batch_i = bid / H;
const int head_i = bid % H;
const int state_size = C * head_size;
const int n_seq_tokens = T / B;
float state[head_size];
__shared__ float _k[head_size], _r[head_size], _tf[head_size], _td[head_size];
#pragma unroll
for (int i = 0; i < head_size; i++) {
state[i] = s[batch_i * state_size + head_i * head_size * head_size + i * head_size + tid];
}
__syncthreads();
_tf[tid] = tf[head_i * head_size + tid];
__syncthreads();
for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid; t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid; t += C) {
__syncthreads();
_k[tid] = k[t];
_r[tid] = r[t];
_td[tid] = td[t];
__syncthreads();
const float _v = v[t];
float y = 0;
for (int j = 0; j < head_size; j += 4) {
const float4& k = (float4&)(_k[j]);
const float4& r = (float4&)(_r[j]);
const float4& tf = (float4&)(_tf[j]);
const float4& td = (float4&)(_td[j]);
float4& s = (float4&)(state[j]);
float4 kv;
kv.x = k.x * _v;
kv.y = k.y * _v;
kv.z = k.z * _v;
kv.w = k.w * _v;
y += r.x * (tf.x * kv.x + s.x);
y += r.y * (tf.y * kv.y + s.y);
y += r.z * (tf.z * kv.z + s.z);
y += r.w * (tf.w * kv.w + s.w);
s.x = s.x * td.x + kv.x;
s.y = s.y * td.y + kv.y;
s.z = s.z * td.z + kv.z;
s.w = s.w * td.w + kv.w;
}
dst[t] = y;
}
#pragma unroll
for (int i = 0; i < head_size; i++) {
dst[T * C + batch_i * state_size + head_i * head_size * head_size + i * head_size + tid] = state[i];
}
}
void ggml_cuda_op_rwkv_wkv(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const float * k_d = (const float *)dst->src[0]->data;
const float * v_d = (const float *)dst->src[1]->data;
const float * r_d = (const float *)dst->src[2]->data;
const float * tf_d = (const float *)dst->src[3]->data;
const float * td_d = (const float *)dst->src[4]->data;
const float * s_d = (const float *)dst->src[5]->data;
const int64_t B = dst->src[5]->ne[1];
const int64_t T = dst->src[0]->ne[3];
const int64_t C = dst->ne[0];
const int64_t H = dst->src[0]->ne[2];
float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32);
GGML_ASSERT(C % H == 0);
GGML_ASSERT(C / H == CUDA_WKV_BLOCK_SIZE);
rwkv_wkv_f32<<<B * H, C / H, 0, stream>>>(B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d);
}

View File

@ -0,0 +1,5 @@
#include "common.cuh"
#define CUDA_WKV_BLOCK_SIZE 64
void ggml_cuda_op_rwkv_wkv(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

View File

@ -95,6 +95,15 @@ static __global__ void hardswish_f32(const float * x, float * dst, const int k)
dst[i] = x[i] * fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f));
}
static __global__ void exp_f32(const float * x, float * dst, const int k) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;
if (i >= k) {
return;
}
dst[i] = expf(x[i]);
}
static __global__ void leaky_relu_f32(const float * x, float * dst, const int k, const float negative_slope) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;
if (i >= k) {
@ -189,6 +198,11 @@ static void hardswish_f32_cuda(const float * x, float * dst, const int k, cudaSt
hardswish_f32<<<num_blocks, CUDA_HARDSWISH_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}
static void exp_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_EXP_BLOCK_SIZE - 1) / CUDA_EXP_BLOCK_SIZE;
exp_f32<<<num_blocks, CUDA_EXP_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}
static void leaky_relu_f32_cuda(const float * x, float * dst, const int k, const float negative_slope, cudaStream_t stream) {
const int num_blocks = (k + CUDA_RELU_BLOCK_SIZE - 1) / CUDA_RELU_BLOCK_SIZE;
leaky_relu_f32<<<num_blocks, CUDA_RELU_BLOCK_SIZE, 0, stream>>>(x, dst, k, negative_slope);
@ -354,6 +368,20 @@ void ggml_cuda_op_hardswish(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
hardswish_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
}
void ggml_cuda_op_exp(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *)src0->data;
float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
exp_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
}
void ggml_cuda_op_leaky_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *)src0->data;

View File

@ -8,6 +8,7 @@
#define CUDA_RELU_BLOCK_SIZE 256
#define CUDA_SIGMOID_BLOCK_SIZE 256
#define CUDA_HARDSIGMOID_BLOCK_SIZE 256
#define CUDA_EXP_BLOCK_SIZE 256
#define CUDA_HARDSWISH_BLOCK_SIZE 256
#define CUDA_SQR_BLOCK_SIZE 256
#define CUDA_SQRT_BLOCK_SIZE 256
@ -32,6 +33,8 @@ void ggml_cuda_op_sigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_hardsigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_exp(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_hardswish(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_leaky_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

View File

@ -431,16 +431,6 @@ struct vk_context_struct {
typedef std::shared_ptr<vk_context_struct> vk_context;
typedef std::weak_ptr<vk_context_struct> vk_context_ref;
struct ggml_tensor_extra_gpu {
vk_buffer_ref buffer_gpu;
uint64_t offset;
void reset() {
buffer_gpu.reset();
offset = 0;
}
};
struct ggml_vk_garbage_collector {
std::vector<vk_semaphore> tl_semaphores;
std::vector<vk_semaphore> semaphores;
@ -551,6 +541,31 @@ struct ggml_backend_vk_context {
std::vector<vk_context_ref> tensor_ctxs;
};
static void * const vk_ptr_base = (void *)(uintptr_t) 0x1000; // NOLINT
static uint64_t vk_tensor_offset(const ggml_tensor * tensor) {
if (tensor->view_src) {
return (uint8_t *) tensor->view_src->data - (uint8_t *) vk_ptr_base;
}
return (uint8_t *) tensor->data - (uint8_t *) vk_ptr_base;
}
struct ggml_backend_vk_buffer_context {
vk_device_ref device;
vk_buffer dev_buffer;
std::string name;
ggml_backend_vk_buffer_context(vk_device_ref device, vk_buffer&& dev_buffer, std::string& name) :
device(device),
dev_buffer(dev_buffer),
name(name) {
}
~ggml_backend_vk_buffer_context() {
ggml_vk_destroy_buffer(dev_buffer);
}
};
#ifdef GGML_VULKAN_MEMORY_DEBUG
void vk_memory_logger::log_allocation(vk_buffer_ref buf_ref, size_t size) {
std::lock_guard<std::mutex> guard(log_mutex);
@ -3038,9 +3053,9 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
const uint64_t r2 = ne12 / ne02;
const uint64_t r3 = ne13 / ne03;
ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) dst->extra;
ggml_tensor_extra_gpu * extra_src0 = (ggml_tensor_extra_gpu *) src0->extra;
ggml_tensor_extra_gpu * extra_src1 = (ggml_tensor_extra_gpu *) src1->extra;
ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context;
ggml_backend_vk_buffer_context * src1_buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context;
vk_buffer d_Qx;
size_t qx_buf_offset = 0;
@ -3142,8 +3157,8 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
return;
}
vk_buffer d_D = extra->buffer_gpu.lock();
const uint64_t d_buf_offset = extra->offset + dst->view_offs;
vk_buffer d_D = dst_buf_ctx->dev_buffer;
const uint64_t d_buf_offset = vk_tensor_offset(dst) + dst->view_offs;
GGML_ASSERT(d_D != nullptr);
GGML_ASSERT(d_D->size >= d_buf_offset + d_sz * ne02 * ne03);
vk_buffer d_X;
@ -3151,13 +3166,13 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
vk_buffer d_Y;
uint64_t y_buf_offset = 0;
if (!src0_uma) {
d_Qx = extra_src0->buffer_gpu.lock();
qx_buf_offset = extra_src0->offset + src0->view_offs;
d_Qx = src0_buf_ctx->dev_buffer;
qx_buf_offset = vk_tensor_offset(src0) + src0->view_offs;
GGML_ASSERT(d_Qx != nullptr);
}
if (!src1_uma) {
d_Qy = extra_src1->buffer_gpu.lock();
qy_buf_offset = extra_src1->offset + src1->view_offs;
d_Qy = src1_buf_ctx->dev_buffer;
qy_buf_offset = vk_tensor_offset(src1) + src1->view_offs;
GGML_ASSERT(d_Qy != nullptr);
}
if (qx_needs_dequant) {
@ -3238,9 +3253,9 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
const uint64_t r2 = ne12 / ne02;
const uint64_t r3 = ne13 / ne03;
ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) dst->extra;
ggml_tensor_extra_gpu * extra_src0 = (ggml_tensor_extra_gpu *) src0->extra;
ggml_tensor_extra_gpu * extra_src1 = (ggml_tensor_extra_gpu *) src1->extra;
ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context;
ggml_backend_vk_buffer_context * src1_buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context;
vk_buffer d_Qx;
size_t qx_buf_offset = 0;
@ -3319,21 +3334,21 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
return;
}
vk_buffer d_D = extra->buffer_gpu.lock();
const uint64_t d_buf_offset = extra->offset + dst->view_offs;
vk_buffer d_D = dst_buf_ctx->dev_buffer;
const uint64_t d_buf_offset = vk_tensor_offset(dst) + dst->view_offs;
GGML_ASSERT(d_D != nullptr);
vk_buffer d_X;
uint64_t x_buf_offset = 0;
vk_buffer d_Y;
uint64_t y_buf_offset = 0;
if(!src0_uma) {
d_Qx = extra_src0->buffer_gpu.lock();
qx_buf_offset = extra_src0->offset + src0->view_offs;
d_Qx = src0_buf_ctx->dev_buffer;
qx_buf_offset = vk_tensor_offset(src0) + src0->view_offs;
GGML_ASSERT(d_Qx != nullptr);
}
if(!src1_uma) {
d_Qy = extra_src1->buffer_gpu.lock();
qy_buf_offset = extra_src1->offset + src1->view_offs;
d_Qy = src1_buf_ctx->dev_buffer;
qy_buf_offset = vk_tensor_offset(src1) + src1->view_offs;
GGML_ASSERT(d_Qy != nullptr);
}
if (qx_needs_dequant) {
@ -3416,9 +3431,9 @@ static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_c
GGML_ASSERT(ne11 == 1);
ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) dst->extra;
ggml_tensor_extra_gpu * extra_src0 = (ggml_tensor_extra_gpu *) src0->extra;
ggml_tensor_extra_gpu * extra_src1 = (ggml_tensor_extra_gpu *) src1->extra;
ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context;
ggml_backend_vk_buffer_context * src1_buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context;
vk_buffer d_Qy;
size_t qy_buf_offset = 0;
@ -3444,15 +3459,15 @@ static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_c
return;
}
vk_buffer d_D = extra->buffer_gpu.lock();
const uint64_t d_buf_offset = extra->offset + dst->view_offs;
vk_buffer d_D = dst_buf_ctx->dev_buffer;
const uint64_t d_buf_offset = vk_tensor_offset(dst) + dst->view_offs;
GGML_ASSERT(d_D != nullptr);
vk_buffer d_Qx = extra_src0->buffer_gpu.lock();
const uint64_t qx_buf_offset = extra_src0->offset + src0->view_offs;
vk_buffer d_Qx = src0_buf_ctx->dev_buffer;
const uint64_t qx_buf_offset = vk_tensor_offset(src0) + src0->view_offs;
GGML_ASSERT(d_Qx != nullptr);
if (!src1_uma) {
d_Qy = extra_src1->buffer_gpu.lock();
qy_buf_offset = extra_src1->offset + src1->view_offs;
d_Qy = src1_buf_ctx->dev_buffer;
qy_buf_offset = vk_tensor_offset(src1) + src1->view_offs;
GGML_ASSERT(d_Qx != nullptr);
}
@ -3494,9 +3509,9 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con
GGML_ASSERT(ne11 == 1);
ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) dst->extra;
ggml_tensor_extra_gpu * extra_src0 = (ggml_tensor_extra_gpu *) src0->extra;
ggml_tensor_extra_gpu * extra_src1 = (ggml_tensor_extra_gpu *) src1->extra;
ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context;
ggml_backend_vk_buffer_context * src1_buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context;
vk_buffer d_Qy = nullptr;
size_t qy_buf_offset = 0;
@ -3523,15 +3538,15 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con
return;
}
vk_buffer d_D = extra->buffer_gpu.lock();
const uint64_t d_buf_offset = extra->offset + dst->view_offs;
vk_buffer d_D = dst_buf_ctx->dev_buffer;
const uint64_t d_buf_offset = vk_tensor_offset(dst) + dst->view_offs;
GGML_ASSERT(d_D != nullptr);
vk_buffer d_Qx = extra_src0->buffer_gpu.lock();
const uint64_t qx_buf_offset = extra_src0->offset + src0->view_offs;
vk_buffer d_Qx = src0_buf_ctx->dev_buffer;
const uint64_t qx_buf_offset = vk_tensor_offset(src0) + src0->view_offs;
GGML_ASSERT(d_Qx != nullptr);
if (!src1_uma) {
d_Qy = extra_src1->buffer_gpu.lock();
qy_buf_offset = extra_src1->offset + src1->view_offs;
d_Qy = src1_buf_ctx->dev_buffer;
qy_buf_offset = vk_tensor_offset(src1) + src1->view_offs;
GGML_ASSERT(d_Qx != nullptr);
}
@ -3593,10 +3608,10 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
const uint64_t n_as = ne02;
ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) dst->extra;
ggml_tensor_extra_gpu * extra_src0 = (ggml_tensor_extra_gpu *) src0->extra;
ggml_tensor_extra_gpu * extra_src1 = (ggml_tensor_extra_gpu *) src1->extra;
ggml_tensor_extra_gpu * extra_ids = (ggml_tensor_extra_gpu *) ids->extra;
ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context;
ggml_backend_vk_buffer_context * src1_buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context;
ggml_backend_vk_buffer_context * ids_buf_ctx = (ggml_backend_vk_buffer_context *)ids->buffer->context;
vk_buffer d_Qx;
size_t qx_buf_offset = 0;
@ -3693,26 +3708,26 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
return;
}
vk_buffer d_D = extra->buffer_gpu.lock();
const uint64_t d_buf_offset = extra->offset + dst->view_offs;
vk_buffer d_D = dst_buf_ctx->dev_buffer;
const uint64_t d_buf_offset = vk_tensor_offset(dst) + dst->view_offs;
GGML_ASSERT(d_D != nullptr);
vk_buffer d_X;
uint64_t x_buf_offset = 0;
vk_buffer d_Y;
uint64_t y_buf_offset = 0;
if (!src0_uma) {
d_Qx = extra_src0->buffer_gpu.lock();
qx_buf_offset = extra_src0->offset + src0->view_offs;
d_Qx = src0_buf_ctx->dev_buffer;
qx_buf_offset = vk_tensor_offset(src0) + src0->view_offs;
GGML_ASSERT(d_Qx != nullptr);
}
if (!src1_uma) {
d_Qy = extra_src1->buffer_gpu.lock();
qy_buf_offset = extra_src1->offset + src1->view_offs;
d_Qy = src1_buf_ctx->dev_buffer;
qy_buf_offset = vk_tensor_offset(src1) + src1->view_offs;
GGML_ASSERT(d_Qy != nullptr);
}
if (!ids_uma) {
d_ids = extra_ids->buffer_gpu.lock();
ids_buf_offset = extra_ids->offset + ids->view_offs;
d_ids = ids_buf_ctx->dev_buffer;
ids_buf_offset = vk_tensor_offset(ids) + ids->view_offs;
GGML_ASSERT(d_ids != nullptr);
}
if (qx_needs_dequant) {
@ -3798,10 +3813,10 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
const uint64_t ne22 = dst->ne[2];
const uint64_t ne23 = dst->ne[3];
ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) dst->extra;
ggml_tensor_extra_gpu * extra_src0 = (ggml_tensor_extra_gpu *) src0->extra;
ggml_tensor_extra_gpu * extra_src1 = (ggml_tensor_extra_gpu *) src1->extra;
ggml_tensor_extra_gpu * extra_ids = (ggml_tensor_extra_gpu *) ids->extra;
ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context;
ggml_backend_vk_buffer_context * src1_buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context;
ggml_backend_vk_buffer_context * ids_buf_ctx = (ggml_backend_vk_buffer_context *)ids->buffer->context;
vk_buffer d_Qx;
size_t qx_buf_offset = 0;
@ -3886,26 +3901,26 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
return;
}
vk_buffer d_D = extra->buffer_gpu.lock();
const uint64_t d_buf_offset = extra->offset + dst->view_offs;
vk_buffer d_D = dst_buf_ctx->dev_buffer;
const uint64_t d_buf_offset = vk_tensor_offset(dst) + dst->view_offs;
GGML_ASSERT(d_D != nullptr);
vk_buffer d_X;
uint64_t x_buf_offset = 0;
vk_buffer d_Y;
uint64_t y_buf_offset = 0;
if(!src0_uma) {
d_Qx = extra_src0->buffer_gpu.lock();
qx_buf_offset = extra_src0->offset + src0->view_offs;
d_Qx = src0_buf_ctx->dev_buffer;
qx_buf_offset = vk_tensor_offset(src0) + src0->view_offs;
GGML_ASSERT(d_Qx != nullptr);
}
if(!src1_uma) {
d_Qy = extra_src1->buffer_gpu.lock();
qy_buf_offset = extra_src1->offset + src1->view_offs;
d_Qy = src1_buf_ctx->dev_buffer;
qy_buf_offset = vk_tensor_offset(src1) + src1->view_offs;
GGML_ASSERT(d_Qy != nullptr);
}
if(!ids_uma) {
d_ids = extra_ids->buffer_gpu.lock();
ids_buf_offset = extra_ids->offset + ids->view_offs;
d_ids = ids_buf_ctx->dev_buffer;
ids_buf_offset = vk_tensor_offset(ids) + ids->view_offs;
GGML_ASSERT(d_ids != nullptr);
}
if (qx_needs_dequant) {
@ -4212,7 +4227,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
std::cerr << "), " << ggml_op_name(op) << ", " << (dryrun ? "dryrun" : "") << ")");
GGML_ASSERT(op == GGML_OP_GET_ROWS || (!ggml_is_quantized(src0->type) && (src1 == nullptr || !ggml_is_quantized(src1->type)))); // NOLINT
GGML_ASSERT(ggml_vk_op_supports_incontiguous(op) || ggml_vk_dim01_contiguous(src0)); // NOLINT
GGML_ASSERT(dst->extra != nullptr);
GGML_ASSERT(dst->buffer != nullptr);
const uint64_t ne00 = src0->ne[0];
const uint64_t ne01 = src0->ne[1];
const uint64_t ne02 = src0->ne[2];
@ -4258,10 +4273,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
const bool op_supports_incontiguous = ggml_vk_op_supports_incontiguous(op);
ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) dst->extra;
ggml_tensor_extra_gpu * extra_src0 = (ggml_tensor_extra_gpu *) src0->extra;
ggml_tensor_extra_gpu * extra_src1 = use_src1 ? (ggml_tensor_extra_gpu *) src1->extra : nullptr;
ggml_tensor_extra_gpu * extra_src2 = use_src2 ? (ggml_tensor_extra_gpu *) src2->extra : nullptr;
ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context;
ggml_backend_vk_buffer_context * src1_buf_ctx = use_src1 ? (ggml_backend_vk_buffer_context *)src1->buffer->context : nullptr;
ggml_backend_vk_buffer_context * src2_buf_ctx = use_src2 ? (ggml_backend_vk_buffer_context *)src2->buffer->context : nullptr;
vk_buffer d_X = nullptr;
size_t x_buf_offset = 0;
@ -4292,7 +4307,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
uint64_t z_sz = use_src2 ? ggml_type_size(src2->type) * ne2 : 0;
uint64_t d_sz = ggml_type_size(dst->type) * ned;
vk_buffer d_D = extra->buffer_gpu.lock();
vk_buffer d_D = dst_buf_ctx->dev_buffer;
// Workaround for tiny tensor inputs on ROPE
if (op == GGML_OP_ROPE && use_src1 && y_sz > d_D->size) {
@ -4300,21 +4315,21 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
}
GGML_ASSERT(d_D != nullptr);
uint64_t d_buf_offset = ((extra->offset + dst->view_offs) / ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ctx->device->properties.limits.minStorageBufferOffsetAlignment;
GGML_ASSERT(d_buf_offset == extra->offset || op == GGML_OP_CPY); // NOLINT
uint64_t d_buf_offset = ((vk_tensor_offset(dst) + dst->view_offs) / ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ctx->device->properties.limits.minStorageBufferOffsetAlignment;
GGML_ASSERT(d_buf_offset == vk_tensor_offset(dst) || op == GGML_OP_CPY); // NOLINT
if(!src0_uma) {
d_X = extra_src0->buffer_gpu.lock();
x_buf_offset = extra_src0->offset + src0->view_offs;
d_X = src0_buf_ctx->dev_buffer;
x_buf_offset = vk_tensor_offset(src0) + src0->view_offs;
GGML_ASSERT(d_X != nullptr);
}
if (use_src1 && !src1_uma) {
d_Y = extra_src1->buffer_gpu.lock();
y_buf_offset = extra_src1->offset + src1->view_offs;
d_Y = src1_buf_ctx->dev_buffer;
y_buf_offset = vk_tensor_offset(src1) + src1->view_offs;
GGML_ASSERT(d_Y != nullptr);
}
if (use_src2 && !src2_uma) {
d_Z = extra_src2->buffer_gpu.lock();
z_buf_offset = extra_src2->offset + src2->view_offs;
d_Z = src2_buf_ctx->dev_buffer;
z_buf_offset = vk_tensor_offset(src2) + src2->view_offs;
GGML_ASSERT(d_Z != nullptr);
}
@ -4493,11 +4508,10 @@ static void ggml_vk_get_rows(ggml_backend_vk_context * ctx, vk_context& subctx,
}
static void ggml_vk_acc(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) dst->extra;
const uint32_t src0_type_size = ggml_type_size(src0->type);
const uint32_t src1_type_size = ggml_type_size(src1->type);
const uint32_t dst_type_size = ggml_type_size(dst->type);
const uint32_t d_offset = ((extra->offset + dst->view_offs) % ctx->device->properties.limits.minStorageBufferOffsetAlignment) / dst_type_size;
const uint32_t d_offset = ((vk_tensor_offset(dst) + dst->view_offs) % ctx->device->properties.limits.minStorageBufferOffsetAlignment) / dst_type_size;
int nb1 = dst->op_params[0] / 4; // 4 bytes of float32
int nb2 = dst->op_params[1] / 4; // 4 bytes of float32
@ -4686,10 +4700,9 @@ static void ggml_vk_repeat(ggml_backend_vk_context * ctx, vk_context& subctx, co
}
static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) dst->extra;
const uint32_t src0_type_size = ggml_type_size(src0->type);
const uint32_t dst_type_size = ggml_type_size(dst->type);
const uint32_t d_offset = ((extra->offset + dst->view_offs) % ctx->device->properties.limits.minStorageBufferOffsetAlignment) / dst_type_size;
const uint32_t d_offset = ((vk_tensor_offset(dst) + dst->view_offs) % ctx->device->properties.limits.minStorageBufferOffsetAlignment) / dst_type_size;
ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_CPY, {
(uint32_t)ggml_nelements(src0),
@ -5487,14 +5500,6 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
}
#endif
static ggml_tensor_extra_gpu * ggml_vk_tensor_create_extra(ggml_tensor * tensor) {
VK_LOG_DEBUG("ggml_vk_create_extra(" << tensor << " (" << tensor->name << ", " << ggml_op_name(tensor->op) << "))");
ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu;
extra->reset();
tensor->extra = extra;
return extra;
}
static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) {
#if defined(GGML_VULKAN_RUN_TESTS)
ctx->staging = ggml_vk_create_buffer_check(ctx->device, 100ul * 1024ul * 1024ul,
@ -5666,9 +5671,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context* ctx, ggml_tensor* t
// Returns true if node has enqueued work into the queue, false otherwise
// If submit is true the current all operations queued so far are being submitted to Vulkan to overlap cmdlist creation and GPU execution.
static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * node, int node_idx, ggml_tensor *node_begin, int node_idx_begin, bool dryrun, bool last_node, bool submit){
ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) node->extra;
if (ggml_is_empty(node) || extra == nullptr) {
if (ggml_is_empty(node) || !node->buffer) {
return false;
}
@ -5920,7 +5923,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
}
static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor * tensor, int tensor_idx, bool use_fence = true){
ggml_tensor_extra_gpu * extra = nullptr;
ggml_backend_buffer * buf = nullptr;
switch (tensor->op) {
case GGML_OP_ADD:
@ -5956,7 +5959,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
case GGML_OP_TIMESTEP_EMBEDDING:
case GGML_OP_LEAKY_RELU:
case GGML_OP_REPEAT:
extra = (ggml_tensor_extra_gpu *) tensor->extra;
buf = tensor->buffer;
break;
case GGML_OP_UNARY:
@ -5966,7 +5969,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
case GGML_UNARY_OP_GELU_QUICK:
case GGML_UNARY_OP_RELU:
case GGML_UNARY_OP_TANH:
extra = (ggml_tensor_extra_gpu *) tensor->extra;
buf = tensor->buffer;
break;
default:
return false;
@ -5974,14 +5977,14 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
break;
case GGML_OP_MUL_MAT:
case GGML_OP_MUL_MAT_ID:
extra = (ggml_tensor_extra_gpu *) tensor->extra;
buf = tensor->buffer;
break;
default:
return false;
}
if (extra == nullptr) {
if (buf == nullptr) {
return false;
}
@ -6122,42 +6125,6 @@ GGML_CALL static void ggml_vk_get_device_description(int device, char * descript
// device backend
static void * const vk_ptr_base = (void *)(uintptr_t) 0x1000; // NOLINT
struct ggml_backend_vk_buffer_context {
vk_device_ref device;
vk_buffer dev_buffer;
ggml_tensor_extra_gpu * temp_tensor_extras = nullptr;
size_t temp_tensor_extra_index = 0;
std::string name;
ggml_backend_vk_buffer_context(vk_device_ref device, vk_buffer&& dev_buffer, std::string& name) :
device(device),
dev_buffer(dev_buffer),
name(name) {
}
~ggml_backend_vk_buffer_context() {
ggml_vk_destroy_buffer(dev_buffer);
if (temp_tensor_extras != nullptr) {
delete[] temp_tensor_extras;
}
}
ggml_tensor_extra_gpu * ggml_vk_alloc_temp_tensor_extra() {
if (temp_tensor_extras == nullptr) {
temp_tensor_extras = new ggml_tensor_extra_gpu[GGML_VK_MAX_NODES];
}
size_t alloc_index = temp_tensor_extra_index;
temp_tensor_extra_index = (temp_tensor_extra_index + 1) % GGML_VK_MAX_NODES;
ggml_tensor_extra_gpu * extra = &temp_tensor_extras[alloc_index];
extra->reset();
return extra;
}
};
GGML_CALL static const char * ggml_backend_vk_buffer_get_name(ggml_backend_buffer_t buffer) {
ggml_backend_vk_buffer_context * ctx = (ggml_backend_vk_buffer_context *)buffer->context;
return ctx->name.c_str();
@ -6182,51 +6149,37 @@ GGML_CALL static void * ggml_backend_vk_buffer_get_base(ggml_backend_buffer_t bu
GGML_CALL static void ggml_backend_vk_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
VK_LOG_DEBUG("ggml_backend_vk_buffer_init_tensor(" << buffer << " (" << buffer->context << "), " << tensor << ")");
ggml_backend_vk_buffer_context * ctx = (ggml_backend_vk_buffer_context *)buffer->context;
if (tensor->view_src != nullptr) {
GGML_ASSERT(tensor->view_src->buffer->buft == buffer->buft);
GGML_ASSERT(tensor->view_src->extra != nullptr);
tensor->extra = tensor->view_src->extra;
} else {
ggml_tensor_extra_gpu * extra = ctx->ggml_vk_alloc_temp_tensor_extra();
extra->buffer_gpu = ctx->dev_buffer;
extra->offset = (uint8_t *) tensor->data - (uint8_t *) vk_ptr_base;
tensor->extra = extra;
}
}
GGML_CALL static void ggml_backend_vk_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
VK_LOG_DEBUG("ggml_backend_vk_buffer_set_tensor(" << buffer << ", " << tensor << ", " << data << ", " << offset << ", " << size << ")");
ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) tensor->extra;
ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)buffer->context;
vk_buffer buf = buf_ctx->dev_buffer;
vk_buffer buf = extra->buffer_gpu.lock();
ggml_vk_buffer_write(buf, extra->offset + tensor->view_offs + offset, data, size);
GGML_UNUSED(buffer);
ggml_vk_buffer_write(buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data, size);
}
GGML_CALL static void ggml_backend_vk_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
VK_LOG_DEBUG("ggml_backend_vk_buffer_get_tensor(" << buffer << ", " << tensor << ", " << data << ", " << offset << ", " << size << ")");
ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) tensor->extra;
ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)buffer->context;
vk_buffer buf = extra->buffer_gpu.lock();
vk_buffer buf = buf_ctx->dev_buffer;
ggml_vk_buffer_read(buf, extra->offset + tensor->view_offs + offset, data, size);
GGML_UNUSED(buffer);
ggml_vk_buffer_read(buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data, size);
}
GGML_CALL static bool ggml_backend_vk_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) {
if (ggml_backend_buffer_is_vk(src->buffer)) {
ggml_tensor_extra_gpu * src_extra = (ggml_tensor_extra_gpu *) src->extra;
ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
ggml_backend_vk_buffer_context * src_buf_ctx = (ggml_backend_vk_buffer_context *)src->buffer->context;
ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
vk_buffer src_buf = src_extra->buffer_gpu.lock();
vk_buffer dst_buf = dst_extra->buffer_gpu.lock();
vk_buffer src_buf = src_buf_ctx->dev_buffer;
vk_buffer dst_buf = dst_buf_ctx->dev_buffer;
ggml_vk_buffer_copy(dst_buf, dst_extra->offset + dst->view_offs, src_buf, src_extra->offset + src->view_offs, ggml_nbytes(src));
ggml_vk_buffer_copy(dst_buf, vk_tensor_offset(dst) + dst->view_offs, src_buf, vk_tensor_offset(src) + src->view_offs, ggml_nbytes(src));
return true;
}
@ -6404,7 +6357,7 @@ GGML_CALL static void ggml_backend_vk_set_tensor_async(ggml_backend_t backend, g
ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
GGML_ASSERT((tensor->buffer->buft == ggml_backend_vk_get_default_buffer_type(backend) || tensor->buffer->buft == ggml_backend_vk_host_buffer_type()) && "unsupported buffer type");
ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) tensor->extra;
ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)tensor->buffer->context;
vk_context transfer_ctx;
@ -6417,9 +6370,9 @@ GGML_CALL static void ggml_backend_vk_set_tensor_async(ggml_backend_t backend, g
transfer_ctx = ctx->transfer_ctx.lock();
}
vk_buffer buf = extra->buffer_gpu.lock();
vk_buffer buf = buf_ctx->dev_buffer;
ggml_vk_buffer_write_async(transfer_ctx, buf, extra->offset + tensor->view_offs + offset, data, size);
ggml_vk_buffer_write_async(transfer_ctx, buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data, size);
}
GGML_CALL static void ggml_backend_vk_get_tensor_async(ggml_backend_t backend, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
@ -6427,7 +6380,7 @@ GGML_CALL static void ggml_backend_vk_get_tensor_async(ggml_backend_t backend, c
ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
GGML_ASSERT((tensor->buffer->buft == ggml_backend_vk_get_default_buffer_type(backend) || tensor->buffer->buft == ggml_backend_vk_host_buffer_type()) && "unsupported buffer type");
ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) tensor->extra;
ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)tensor->buffer->context;
vk_context transfer_ctx;
@ -6440,17 +6393,17 @@ GGML_CALL static void ggml_backend_vk_get_tensor_async(ggml_backend_t backend, c
transfer_ctx = ctx->transfer_ctx.lock();
}
vk_buffer buf = extra->buffer_gpu.lock();
vk_buffer buf = buf_ctx->dev_buffer;
ggml_vk_buffer_read_async(transfer_ctx, buf, extra->offset + tensor->view_offs + offset, data, size);
ggml_vk_buffer_read_async(transfer_ctx, buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data, size);
}
GGML_CALL static bool ggml_backend_vk_cpy_tensor_async(ggml_backend_t backend, const ggml_tensor * src, ggml_tensor * dst) {
VK_LOG_DEBUG("ggml_backend_vk_cpy_tensor_async()");
ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
if ((dst->buffer->buft == ggml_backend_vk_get_default_buffer_type(backend) || dst->buffer->buft == ggml_backend_vk_host_buffer_type()) && ggml_backend_buffer_is_vk(src->buffer)) {
ggml_tensor_extra_gpu * src_extra = (ggml_tensor_extra_gpu *) src->extra;
ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
ggml_backend_vk_buffer_context * src_buf_ctx = (ggml_backend_vk_buffer_context *)src->buffer->context;
ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
vk_context transfer_ctx;
@ -6463,10 +6416,10 @@ GGML_CALL static bool ggml_backend_vk_cpy_tensor_async(ggml_backend_t backend, c
transfer_ctx = ctx->transfer_ctx.lock();
}
vk_buffer src_buf = src_extra->buffer_gpu.lock();
vk_buffer dst_buf = dst_extra->buffer_gpu.lock();
vk_buffer src_buf = src_buf_ctx->dev_buffer;
vk_buffer dst_buf = dst_buf_ctx->dev_buffer;
ggml_vk_buffer_copy_async(transfer_ctx, dst_buf, dst_extra->offset + dst->view_offs, src_buf, src_extra->offset + src->view_offs, ggml_nbytes(src));
ggml_vk_buffer_copy_async(transfer_ctx, dst_buf, vk_tensor_offset(dst) + dst->view_offs, src_buf, vk_tensor_offset(src) + src->view_offs, ggml_nbytes(src));
return true;
}

View File

@ -3056,18 +3056,14 @@ struct llama_sbatch {
} else {
// simple split
if (batch->n_seq_id) {
for (size_t i = 0; i < length; ++i) {
ubatch.n_seq_id = batch->n_seq_id + seq.offset;
}
ubatch.n_seq_id = batch->n_seq_id + seq.offset;
} else {
for (size_t i = 0; i < length; ++i) {
ubatch.n_seq_id[ubatch.n_seqs + i] = 1;
}
}
if (batch->seq_id) {
for (size_t i = 0; i < length; ++i) {
ubatch.seq_id = batch->seq_id + seq.offset;
}
ubatch.seq_id = batch->seq_id + seq.offset;
} else {
for (size_t i = 0; i < length; ++i) {
ubatch.seq_id[ubatch.n_seqs + i] = &seq.all_seq_id;

View File

@ -1543,6 +1543,36 @@ struct test_ssm_scan : public test_case {
}
};
// GGML_OP_RWKV_WKV
struct test_rwkv_wkv : public test_case {
const ggml_type type;
const int64_t head_count;
const int64_t head_size;
const int64_t n_seq_tokens;
const int64_t n_seqs;
std::string vars() override {
return VARS_TO_STR5(type, head_count, head_size, n_seq_tokens, n_seqs);
}
test_rwkv_wkv(ggml_type type = GGML_TYPE_F32,
int64_t head_count = 32, int64_t head_size = 64, int64_t n_seq_tokens = 32, int64_t n_seqs = 32)
: type(type), head_count(head_count), head_size(head_size), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
const int64_t n_tokens = n_seq_tokens * n_seqs;
ggml_tensor * r = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ 1, head_size, head_count, n_tokens }.data());
ggml_tensor * k = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ head_size, 1, head_count, n_tokens }.data());
ggml_tensor * v = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ 1, head_size, head_count, n_tokens }.data());
ggml_tensor * tf = ggml_new_tensor(ctx, type, 2, std::vector<int64_t>{ head_size, head_count }.data());
ggml_tensor * td = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ 1, head_size, head_count, n_tokens }.data());
ggml_tensor * s = ggml_new_tensor(ctx, type, 2, std::vector<int64_t>{ head_size * head_size * head_count, n_seqs }.data());
ggml_tensor * out = ggml_rwkv_wkv(ctx, k, v, r, tf, td, s);
return out;
}
};
// GGML_OP_MUL_MAT
struct test_mul_mat : public test_case {
const ggml_type type_a;
@ -3337,6 +3367,11 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 16, 1024, 32, 4));
test_cases.emplace_back(new test_rwkv_wkv(GGML_TYPE_F32, 32, 64, 1, 1));
test_cases.emplace_back(new test_rwkv_wkv(GGML_TYPE_F32, 32, 64, 32, 1));
test_cases.emplace_back(new test_rwkv_wkv(GGML_TYPE_F32, 32, 64, 32, 4));
test_cases.emplace_back(new test_rwkv_wkv(GGML_TYPE_F32, 32, 64, 128, 4));
#if 1
for (ggml_type type_a : base_types) {
for (ggml_type type_b : {GGML_TYPE_F32, GGML_TYPE_F16}) {
@ -3564,7 +3599,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
if (hs != 128 && logit_softcap != 0.0f) continue;
for (int nh : { 32, }) {
for (int kv : { 512, 1024, }) {
for (int nb : { 1, 2, 4, 8, }) {
for (int nb : { 1, 3, 32, 35, }) {
for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb, mask, max_bias, logit_softcap, type_KV));
}