kompute : implement op_getrows_f32 (#6403)
op_getrows_f32 has been required since https://github.com/ggerganov/llama.cpp/pull/6122 for the Vulkan backend with Kompute to be functional. Implement this op to make the backend functional again.
commit 9e405b6e2e
parent 3413ae2193
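
For context: GGML_OP_GET_ROWS gathers whole rows of a source tensor, selected by an integer index tensor (token-embedding lookup is the typical use). A rough CPU-side sketch of the f32 case, with illustrative names only (not ggml's actual API):

    #include <cstdint>
    #include <cstring>
    #include <cstddef>

    // Illustrative sketch: copy the rows of `src` named by `rows` into `dst`.
    // ne00 is the row length in elements, nrows the number of indices.
    static void get_rows_f32_ref(const float * src, const int32_t * rows,
                                 float * dst, int ne00, int nrows) {
        for (int i = 0; i < nrows; ++i) {
            const int32_t r = rows[i];
            std::memcpy(dst + (size_t)i * ne00, src + (size_t)r * ne00,
                        (size_t)ne00 * sizeof(float));
        }
    }

The Kompute backend already ships shaders for the f16 and quantized variants of this op; the changes below add the f32 variant and wire it into shader compilation, the op support check, and graph dispatch.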
@@ -777,6 +777,7 @@ if (LLAMA_KOMPUTE)
         kompute-shaders/op_mul_mat_q4_0.comp
         kompute-shaders/op_mul_mat_q4_1.comp
         kompute-shaders/op_mul_mat_q6_k.comp
+        kompute-shaders/op_getrows_f32.comp
         kompute-shaders/op_getrows_f16.comp
         kompute-shaders/op_getrows_q4_0.comp
         kompute-shaders/op_getrows_q4_1.comp
@@ -809,6 +810,7 @@ if (LLAMA_KOMPUTE)
         shaderop_mul_mat_q4_0.h
         shaderop_mul_mat_q4_1.h
         shaderop_mul_mat_q6_k.h
+        shaderop_getrows_f32.h
         shaderop_getrows_f16.h
         shaderop_getrows_q4_0.h
         shaderop_getrows_q4_1.h
@@ -22,6 +22,7 @@
 #include "shaderop_mul_mat_q4_1.h"
 #include "shaderop_mul_mat_q6_k.h"
 #include "shaderop_mul_mat_mat_f32.h"
+#include "shaderop_getrows_f32.h"
 #include "shaderop_getrows_f16.h"
 #include "shaderop_getrows_q4_0.h"
 #include "shaderop_getrows_q4_1.h"
@@ -1146,6 +1147,14 @@ static void ggml_vk_get_rows(
     seq.record<kp::OpAlgoDispatch>(s_algo);
 }
 
+template <typename... Args>
+static void ggml_vk_get_rows_f32(Args&&... args) {
+    const static auto spirv = getSpirvShader(kp::shader_data::op_getrows_f32_comp_spv,
+        kp::shader_data::op_getrows_f32_comp_spv_len);
+
+    ggml_vk_get_rows(spirv, "f32", sizeof(float), 0, std::forward<Args>(args)...);
+}
+
 template <typename... Args>
 static void ggml_vk_get_rows_f16(Args&&... args) {
     const static auto spirv = getSpirvShader(kp::shader_data::op_getrows_f16_comp_spv,
@@ -1371,6 +1380,7 @@ static bool ggml_vk_supports_op(const struct ggml_tensor * op) {
             return op->ne[3] == 1;
         case GGML_OP_GET_ROWS:
             switch (op->src[0]->type) {
+                case GGML_TYPE_F32:
                 case GGML_TYPE_F16:
                 case GGML_TYPE_Q4_0:
                 case GGML_TYPE_Q4_1:
@@ -1661,7 +1671,9 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
                 } break;
             case GGML_OP_GET_ROWS:
                 {
-                    if (src0t == GGML_TYPE_F16) {
+                    if (src0t == GGML_TYPE_F32) {
+                        ggml_vk_get_rows_f32(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, nb01, nb1, ggml_nelements(src1));
+                    } else if (src0t == GGML_TYPE_F16) {
                         ggml_vk_get_rows_f16(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, nb01, nb1, ggml_nelements(src1));
                     } else if (src0t == GGML_TYPE_Q4_0) {
                         ggml_vk_get_rows_q4_0(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, nb01, nb1, ggml_nelements(src1));
kompute-shaders/op_getrows_f32.comp (new file, 31 lines)
@@ -0,0 +1,31 @@
+#version 450
+
+#include "common.comp"
+
+layout(local_size_x = 1) in;
+
+layout (binding = 0) readonly buffer tensorInA { float inA[]; };
+layout (binding = 1) readonly buffer tensorInB { int inB[]; };
+layout (binding = 2) writeonly buffer tensorOut { float out_[]; };
+
+layout (push_constant) uniform parameter {
+    uint inAOff;
+    uint inBOff;
+    uint outOff;
+    int ne00;
+    int nb01;
+    int nb1;
+} pcs;
+
+void dequantize_row_f32(uint x /*Based from inA unaligned*/, uint y /*Based from out_*/, int k) {
+    for (int j = 0; j < k; j++) {
+        out_[y + j] = inA[x + j];
+    }
+}
+
+void main() {
+    const uint i = gl_WorkGroupID.x;
+    const int r = inB[i + pcs.inBOff];
+
+    dequantize_row_f32(r*pcs.nb01/4 + pcs.inAOff, i*pcs.nb1/4 + pcs.outOff, pcs.ne00);
+}
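
A note on the push constants above: following ggml conventions, nb01 and nb1 are byte strides (bytes per row of src0 and of dst), while ne00 is the row length in elements. Because the shader indexes float arrays directly, those strides are divided by 4 (sizeof(float)) to turn byte offsets into element offsets, which is what r*pcs.nb01/4 and i*pcs.nb1/4 compute. A small host-side sketch of the same index math (illustrative names, not part of the codebase):

    #include <cstddef>

    // Byte strides -> element offsets, mirroring the shader's
    // r*pcs.nb01/4 and i*pcs.nb1/4 expressions for f32 data.
    static size_t src_row_elem_offset(int r, size_t nb01 /* bytes per src0 row */) {
        return (size_t)r * nb01 / sizeof(float);
    }

    static size_t dst_row_elem_offset(int i, size_t nb1 /* bytes per dst row */) {
        return (size_t)i * nb1 / sizeof(float);
    }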