vulkan : rope n_past is now KQ_pos, f16 rope kernel

Jared Van Bortel 2023-11-23 17:18:42 -05:00
parent 71565eb0c3
commit 84f7fc4553
5 changed files with 169 additions and 47 deletions
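
Summary of the change: the ROPE op no longer takes a scalar n_past push constant; src1 now supplies a per-row I32 position tensor (KQ_pos), and the single rope shader is split into f16 and f32 variants. As a minimal sketch of the new input, assuming the upstream convention that token i of a batch sits at absolute position n_past + i (illustrative code, not part of this commit):

    // Hypothetical helper: builds the per-token position buffer that replaces
    // the old scalar n_past. In llama.cpp the analogous tensor is KQ_pos
    // (GGML_TYPE_I32, one entry per token).
    #include <cstdint>
    #include <vector>

    std::vector<int32_t> make_positions(int n_past, int n_tokens) {
        std::vector<int32_t> pos(n_tokens);
        for (int i = 0; i < n_tokens; ++i) {
            pos[i] = n_past + i; // token i rotates by its absolute position
        }
        return pos;
    }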

CMakeLists.txt

@@ -490,7 +490,8 @@ if (LLAMA_KOMPUTE)
         kompute/op_getrows_q4_0.comp
         kompute/op_getrows_q4_1.comp
         kompute/op_getrows_q6_k.comp
-        kompute/op_rope.comp
+        kompute/op_rope_f16.comp
+        kompute/op_rope_f32.comp
         kompute/op_cpy_f16_f16.comp
         kompute/op_cpy_f16_f32.comp
         kompute/op_cpy_f32_f16.comp

@@ -521,7 +522,8 @@ if (LLAMA_KOMPUTE)
         shaderop_getrows_q4_0.h
         shaderop_getrows_q4_1.h
         shaderop_getrows_q6_k.h
-        shaderop_rope.h
+        shaderop_rope_f16.h
+        shaderop_rope_f32.h
         shaderop_cpy_f16_f16.h
         shaderop_cpy_f16_f32.h
        shaderop_cpy_f32_f16.h
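
Each shader listed above is compiled to SPIR-V and embedded in the corresponding generated header, which the C++ backend consumes through kp::shader_data (see the getSpirvShader calls below). The generated header's rough shape, as an illustrative sketch rather than the tool's exact output:

    // Illustrative only: approximate contents of a generated shaderop_*.h.
    // The array holds the compiled SPIR-V; 0x07230203 is the SPIR-V magic number.
    namespace kp {
    namespace shader_data {
    static const unsigned char op_rope_f16_comp_spv[] = { 0x03, 0x02, 0x23, 0x07, /* ... */ };
    static const unsigned int  op_rope_f16_comp_spv_len = sizeof(op_rope_f16_comp_spv);
    } // namespace shader_data
    } // namespace kp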

ggml-vulkan.cpp

@@ -32,7 +32,8 @@
 #include "shaderop_getrows_q4_0.h"
 #include "shaderop_getrows_q4_1.h"
 #include "shaderop_getrows_q6_k.h"
-#include "shaderop_rope.h"
+#include "shaderop_rope_f16.h"
+#include "shaderop_rope_f32.h"
 #include "shaderop_cpy_f16_f16.h"
 #include "shaderop_cpy_f16_f32.h"
 #include "shaderop_cpy_f32_f16.h"
@@ -1175,51 +1176,66 @@ void ggml_vk_get_rows_q6_k(Args&&... args) {
     ggml_vk_get_rows(spirv, 1/*We access blocks unaligned*/, QK_NL, std::forward<Args>(args)...);
 }
 
-void ggml_vk_rope(kp::Sequence& seq,
-                  const std::shared_ptr<kp::Tensor>& in,
-                  const std::shared_ptr<kp::Tensor>& out,
-                  uint32_t inOff, uint32_t outOff,
-                  uint32_t n_past, int32_t n_dims, int32_t mode,
-                  float freq_base, float freq_scale,
-                  int32_t ne01, int32_t ne02, int32_t ne03,
-                  uint32_t nb00, uint32_t nb01, uint32_t nb02, uint32_t nb03,
-                  int32_t ne0,
-                  uint32_t nb0, uint32_t nb1, uint32_t nb2, uint32_t nb3) {
-    const static auto spirv = getSpirvShader(kp::shader_data::op_rope_comp_spv,
-                                             kp::shader_data::op_rope_comp_spv_len);
-
-    GGML_ASSERT(nb03%sizeof(float) == 0);
-    GGML_ASSERT(nb02%sizeof(float) == 0);
-    GGML_ASSERT(nb01%sizeof(float) == 0);
-    GGML_ASSERT(nb00%sizeof(float) == 0);
-    GGML_ASSERT(nb3%sizeof(float) == 0);
-    GGML_ASSERT(nb2%sizeof(float) == 0);
-    GGML_ASSERT(nb1%sizeof(float) == 0);
-    GGML_ASSERT(nb0%sizeof(float) == 0);
+void ggml_vk_rope(
+    kp::Sequence& seq,
+    const std::shared_ptr<kp::Tensor>& inA,
+    const std::shared_ptr<kp::Tensor>& inB,
+    const std::shared_ptr<kp::Tensor>& out,
+    uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
+    ggml_type src0t, int32_t n_dims, int32_t mode,
+    float freq_base, float freq_scale,
+    int32_t ne01, int32_t ne02, int32_t ne03,
+    uint32_t nb00, uint32_t nb01, uint32_t nb02, uint32_t nb03,
+    int32_t ne0,
+    uint32_t nb0, uint32_t nb1, uint32_t nb2, uint32_t nb3
+) {
+    GGML_ASSERT(src0t == GGML_TYPE_F16 || src0t == GGML_TYPE_F32);
+
+    static const auto spirv_f16 = getSpirvShader(
+        kp::shader_data::op_rope_f16_comp_spv, kp::shader_data::op_rope_f16_comp_spv_len
+    );
+    static const auto spirv_f32 = getSpirvShader(
+        kp::shader_data::op_rope_f32_comp_spv, kp::shader_data::op_rope_f32_comp_spv_len
+    );
+
+    int type_size = src0t == GGML_TYPE_F16 ? 2 : 4;
+
+    GGML_ASSERT(nb03 % type_size == 0);
+    GGML_ASSERT(nb02 % type_size == 0);
+    GGML_ASSERT(nb01 % type_size == 0);
+    GGML_ASSERT(nb00 % type_size == 0);
+    GGML_ASSERT(nb3 % type_size == 0);
+    GGML_ASSERT(nb2 % type_size == 0);
+    GGML_ASSERT(nb1 % type_size == 0);
+    GGML_ASSERT(nb0 % type_size == 0);
 
     struct PushConstants {
-        uint32_t inOff, outOff;
-        uint32_t n_past;
+        uint32_t inAOff, inBOff, outOff;
         int32_t n_dims, mode;
         float freq_base, freq_scale;
         uint32_t nb00, nb01, nb02, nb03;
         int32_t ne0;
         uint32_t nb0, nb1, nb2, nb3;
     } pushConsts {
-        safe_divide(inOff, 4), safe_divide(outOff, 4),
-        n_past, n_dims, mode,
+        safe_divide(inAOff, type_size), safe_divide(inBOff, 4), safe_divide(outOff, type_size),
+        n_dims, mode,
         freq_base, freq_scale,
         nb00, nb01, nb02, nb03,
         ne0,
         nb0, nb1, nb2, nb3
     };
 
+    auto name = std::string(__func__) + (src0t == GGML_TYPE_F16 ? "_f16" : "_f32");
+
     std::shared_ptr<kp::Algorithm> s_algo = nullptr;
-    if (!komputeManager()->hasAlgorithm(__func__))
-        s_algo = komputeManager()->algorithm<float, PushConstants>(__func__, s_kompute_context->pool.get(), {in, out}, spirv, {unsigned(ne01), unsigned(ne02), unsigned(ne03)}, {}, {pushConsts});
-    else {
-        s_algo = komputeManager()->getAlgorithm(__func__);
-        s_algo->setTensors({in, out});
+    if (!komputeManager()->hasAlgorithm(name)) {
+        s_algo = komputeManager()->algorithm<float, PushConstants>(
+            name, s_kompute_context->pool.get(), {inA, inB, out},
+            src0t == GGML_TYPE_F16 ? spirv_f16 : spirv_f32,
+            {unsigned(ne01), unsigned(ne02), unsigned(ne03)}, {}, {pushConsts}
+        );
+    } else {
+        s_algo = komputeManager()->getAlgorithm(name);
+        s_algo->setTensors({inA, inB, out});
         s_algo->setWorkgroup({unsigned(ne01), unsigned(ne02), unsigned(ne03)});
         s_algo->setPushConstants<PushConstants>({pushConsts});
         s_algo->updateDescriptors(s_kompute_context->pool.get());
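
Note the cache-key change in the hunk above: the compiled algorithm is now looked up by name (__func__ plus an "_f16" or "_f32" suffix) instead of bare __func__, so the two shader variants get distinct pipeline-cache entries rather than colliding on one. The pattern, sketched with hypothetical stand-ins for Kompute's manager:

    // Sketch of a name-keyed pipeline cache; Algorithm is a hypothetical stand-in.
    #include <map>
    #include <memory>
    #include <string>

    struct Algorithm { std::string shader; /* compiled pipeline, bound tensors, ... */ };

    static std::map<std::string, std::shared_ptr<Algorithm>> cache;

    std::shared_ptr<Algorithm> get_or_create(const std::string & name, const std::string & shader) {
        auto it = cache.find(name);
        if (it != cache.end()) {
            return it->second;     // reuse: caller rebinds tensors and push constants
        }
        auto algo = std::make_shared<Algorithm>();
        algo->shader = shader;     // compile once per variant, e.g. "ggml_vk_rope_f16"
        cache.emplace(name, algo);
        return algo;
    }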
@@ -1506,14 +1522,16 @@ void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph
                 } break;
             case GGML_OP_ROPE:
                 {
-                    const int n_past = ((int32_t *) dst->op_params)[0];
+                    GGML_ASSERT(ne10 == ne02);
+                    GGML_ASSERT(src0t == dstt);
+                    // const int n_past = ((int32_t *) dst->op_params)[0];
                     const int n_dims = ((int32_t *) dst->op_params)[1];
                     const int mode   = ((int32_t *) dst->op_params)[2];
                     float freq_base;
                     float freq_scale;
                     memcpy(&freq_base,  (int32_t *) dst->op_params + 4, sizeof(float));
                     memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
-                    ggml_vk_rope(seq, id_src0, id_dst, off_src0, off_dst, n_past, n_dims, mode, freq_base, freq_scale, ne01, ne02, ne03, nb00, nb01, nb02, nb03, ne0, nb0, nb1, nb2, nb3);
+                    ggml_vk_rope(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, src0t, n_dims, mode, freq_base, freq_scale, ne01, ne02, ne03, nb00, nb01, nb02, nb03, ne0, nb0, nb1, nb2, nb3);
                 } break;
             case GGML_OP_DUP:
             case GGML_OP_CPY:

kompute/op_rope_f16.comp (new file)

@@ -0,0 +1,89 @@
+/**
+ * Copyright (c) 2023 Nomic, Inc. All rights reserved.
+ *
+ * This software is licensed under the terms of the Software for Open Models License (SOM),
+ * version 1.0, as detailed in the LICENSE_SOM.txt file. A copy of this license should accompany
+ * this software. Except as expressly granted in the SOM license, all rights are reserved by Nomic, Inc.
+ */
+
+#version 450
+
+#include "common.comp"
+
+// TODO: use a local size of 32 or more (Metal uses 1024)
+layout(local_size_x = 1) in;
+
+layout(binding = 0) buffer restrict readonly tensorInA { float16_t inA[]; };
+layout(binding = 1) buffer restrict readonly tensorInB { int inB[]; };
+layout(binding = 2) buffer restrict writeonly tensorOut { float16_t out_[]; };
+
+layout (push_constant) uniform parameter {
+    uint inAOff;
+    uint inBOff;
+    uint outOff;
+    int n_dims;
+    int mode;
+    float freq_base;
+    float freq_scale;
+    uint nb00;
+    uint nb01;
+    uint nb02;
+    uint nb03;
+    int ne0;
+    uint nb0;
+    uint nb1;
+    uint nb2;
+    uint nb3;
+} pcs;
+
+void main() {
+    const uint i3 = gl_WorkGroupID.z;
+    const uint i2 = gl_WorkGroupID.y;
+    const uint i1 = gl_WorkGroupID.x;
+
+    const bool is_neox = (pcs.mode & 2) != 0;
+    const float theta_scale = pow(pcs.freq_base, -2.0/pcs.n_dims);
+
+    const int p = inB[pcs.inBOff + i2];
+
+    float theta = pcs.freq_scale * float(p);
+
+    if (!is_neox) {
+        for (uint i0 = 0; i0 < pcs.ne0; i0 += 2) {
+            const float cos_theta = cos(theta);
+            const float sin_theta = sin(theta);
+
+            theta *= theta_scale;
+
+            const uint src      = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 2) + pcs.inAOff; // Based from in
+            const uint dst_data = uint((i3*pcs.nb3  + i2*pcs.nb2  + i1*pcs.nb1  + i0*pcs.nb0)  / 2) + pcs.outOff; // Based from out_
+
+            const float x0 = float(inA[src]);
+            const float x1 = float(inA[src+1]);
+
+            out_[dst_data]   = float16_t(x0*cos_theta - x1*sin_theta);
+            out_[dst_data+1] = float16_t(x0*sin_theta + x1*cos_theta);
+        }
+    } else {
+        const float inv_ndims = -1.f/pcs.n_dims;
+        for (uint ib = 0; ib < pcs.ne0/pcs.n_dims; ++ib) {
+            for (uint ic = 0; ic < pcs.n_dims; ic += 2) {
+                const float cos_theta = cos(theta);
+                const float sin_theta = sin(theta);
+
+                theta *= theta_scale;
+
+                const uint i0 = ib*pcs.n_dims + ic/2;
+
+                const uint src      = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 2) + pcs.inAOff; // Based from in
+                const uint dst_data = uint((i3*pcs.nb3  + i2*pcs.nb2  + i1*pcs.nb1  + i0*pcs.nb0)  / 2) + pcs.outOff; // Based from out_
+
+                const float x0 = float(inA[src]);
+                const float x1 = float(inA[src+pcs.n_dims/2]);
+
+                out_[dst_data]              = float16_t(x0*cos_theta - x1*sin_theta);
+                out_[dst_data+pcs.n_dims/2] = float16_t(x0*sin_theta + x1*cos_theta);
+            }
+        }
+    }
+}
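
For reference, the pair update this shader computes on the non-neox path is the standard RoPE rotation. With s = pcs.freq_scale, b = pcs.freq_base, d = pcs.n_dims, and p = inB[pcs.inBOff + i2] the per-row position:

    \theta_i(p) = s \, p \, b^{-2i/d}, \qquad
    \begin{pmatrix} y_{2i} \\ y_{2i+1} \end{pmatrix} =
    \begin{pmatrix} \cos\theta_i & -\sin\theta_i \\ \sin\theta_i & \cos\theta_i \end{pmatrix}
    \begin{pmatrix} x_{2i} \\ x_{2i+1} \end{pmatrix}

The loop realizes this incrementally: theta starts at s * p and is multiplied by theta_scale = b^(-2/d) once per pair.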

kompute/op_rope.comp → kompute/op_rope_f32.comp

@@ -12,13 +12,14 @@
 layout(local_size_x = 1) in;
 
-layout (binding = 0) readonly buffer tensorIn { float in_[]; };
-layout (binding = 1) writeonly buffer tensorOut { float out_[]; };
+layout(binding = 0) buffer restrict readonly tensorInA { float inA[]; };
+layout(binding = 1) buffer restrict readonly tensorInB { int inB[]; };
+layout(binding = 2) buffer restrict writeonly tensorOut { float out_[]; };
 
 layout (push_constant) uniform parameter {
-    uint inOff;
+    uint inAOff;
+    uint inBOff;
     uint outOff;
-    uint n_past;
     int n_dims;
     int mode;
     float freq_base;

@@ -42,7 +43,7 @@ void main() {
     const bool is_neox = (pcs.mode & 2) != 0;
     const float theta_scale = pow(pcs.freq_base, -2.0/pcs.n_dims);
 
-    const uint p = ((pcs.mode & 1) == 0 ? pcs.n_past + i2 : i2);
+    const int p = inB[pcs.inBOff + i2];
 
     float theta = pcs.freq_scale * float(p);

@@ -53,11 +54,11 @@ void main() {
             theta *= theta_scale;
 
-            const uint src      = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 4) + pcs.inOff; // Based from in
+            const uint src      = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 4) + pcs.inAOff; // Based from in
             const uint dst_data = uint((i3*pcs.nb3  + i2*pcs.nb2  + i1*pcs.nb1  + i0*pcs.nb0)  / 4) + pcs.outOff; // Based from out_
 
-            const float x0 = in_[src];
-            const float x1 = in_[src+1];
+            const float x0 = inA[src];
+            const float x1 = inA[src+1];
 
             out_[dst_data]   = x0*cos_theta - x1*sin_theta;
             out_[dst_data+1] = x0*sin_theta + x1*cos_theta;

@@ -73,11 +74,11 @@ void main() {
                 const uint i0 = ib*pcs.n_dims + ic/2;
 
-                const uint src      = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 4) + pcs.inOff; // Based from in
+                const uint src      = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 4) + pcs.inAOff; // Based from in
                 const uint dst_data = uint((i3*pcs.nb3  + i2*pcs.nb2  + i1*pcs.nb1  + i0*pcs.nb0)  / 4) + pcs.outOff; // Based from out_
 
-                const float x0 = in_[src];
-                const float x1 = in_[src+pcs.n_dims/2];
+                const float x0 = inA[src];
+                const float x1 = inA[src+pcs.n_dims/2];
 
                 out_[dst_data]              = x0*cos_theta - x1*sin_theta;
                 out_[dst_data+pcs.n_dims/2] = x0*sin_theta + x1*cos_theta;
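
The functional change in this shader is the position lookup: the old code derived p on the device from the scalar n_past push constant (subject to mode bit 0), while the new code reads it from the inB buffer, which permits arbitrary per-row positions such as K-cache shift deltas. Side by side, as an illustrative sketch:

    // Old vs. new position lookup, mirroring the GLSL above.
    #include <cstdint>

    uint32_t pos_old(uint32_t mode, uint32_t n_past, uint32_t i2) {
        return (mode & 1) == 0 ? n_past + i2 : i2; // derived on-device from a scalar
    }

    int32_t pos_new(const int32_t * inB, uint32_t i2) {
        return inB[i2];                            // read from the positions buffer
    }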

llama.cpp

@@ -2772,8 +2772,9 @@ static struct ggml_cgraph * llm_build_llama(
     }
 
     // shift the entire K-cache if needed
+    struct ggml_tensor * K_shift = nullptr;
     if (do_rope_shift) {
-        struct ggml_tensor * K_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_ctx);
+        K_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_ctx);
         offload_func_kq(K_shift);
         ggml_set_name(K_shift, "K_shift");
         ggml_allocr_alloc(lctx.alloc, K_shift);

@@ -3024,6 +3025,11 @@ static struct ggml_cgraph * llm_build_llama(
             ggml_vk_h2d_all(lctx.ctx_kompute);
         } else {
             ggml_vk_h2d_tensor(lctx.ctx_kompute, toDeviceTensor);
+            ggml_vk_h2d_tensor(lctx.ctx_kompute, KQ_pos);
+            ggml_vk_h2d_tensor(lctx.ctx_kompute, KQ_mask);
+            if (K_shift) {
+                ggml_vk_h2d_tensor(lctx.ctx_kompute, K_shift);
+            }
         }
     }
 #endif

@@ -3589,8 +3595,9 @@ static struct ggml_cgraph * llm_build_falcon(
     }
 
     // shift the entire K-cache if needed
+    struct ggml_tensor * K_shift = nullptr;
     if (do_rope_shift) {
-        struct ggml_tensor * K_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_ctx);
+        K_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_ctx);
         offload_func_kq(K_shift);
         ggml_set_name(K_shift, "K_shift");
         ggml_allocr_alloc(lctx.alloc, K_shift);

@@ -3820,6 +3827,11 @@ static struct ggml_cgraph * llm_build_falcon(
             ggml_vk_h2d_all(lctx.ctx_kompute);
         } else {
             ggml_vk_h2d_tensor(lctx.ctx_kompute, toDeviceTensor);
+            ggml_vk_h2d_tensor(lctx.ctx_kompute, KQ_pos);
+            ggml_vk_h2d_tensor(lctx.ctx_kompute, KQ_mask);
+            if (K_shift) {
+                ggml_vk_h2d_tensor(lctx.ctx_kompute, K_shift);
+            }
         }
     }
 #endif
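
K_shift is hoisted out of the do_rope_shift block in both graph builders so the Kompute path can upload it to the device alongside KQ_pos and KQ_mask after the graph is built. As a hedged sketch of what that tensor holds, assuming it is filled the way upstream llama.cpp fills it at this vintage (one I32 delta per KV-cache cell, consumed by a RoPE pass over the K cache); this is not code from this commit:

    // Assumption: mirrors upstream llama.cpp's K_shift fill, shown illustratively.
    #include <cstddef>
    #include <cstdint>
    #include <vector>

    struct kv_cell { int32_t delta; }; // hypothetical stand-in for llama_kv_cell

    std::vector<int32_t> make_k_shift(const std::vector<kv_cell> & cells) {
        std::vector<int32_t> shift(cells.size());
        for (std::size_t i = 0; i < cells.size(); ++i) {
            shift[i] = cells[i].delta; // how far cache slot i has moved; 0 = no rotation
        }
        return shift;
    }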