vulkan : handle ggml_scale for n%8 != 0

ref ggerganov/llama.cpp#3754
This commit is contained in:
Jared Van Bortel 2023-11-14 12:10:52 -05:00
parent 2a41ba7258
commit 6474fc879a
4 changed files with 56 additions and 16 deletions

View File

@ -476,6 +476,7 @@ if (LLAMA_KOMPUTE)
# Compile our shaders
compile_shader(SOURCES
kompute/op_scale.comp
kompute/op_scale_8.comp
kompute/op_add.comp
kompute/op_addrow.comp
kompute/op_mul.comp
@ -508,6 +509,7 @@ if (LLAMA_KOMPUTE)
# Create a custom target for our generated shaders
add_custom_target(generated_shaders DEPENDS
shaderop_scale.h
shaderop_scale_8.h
shaderop_add.h
shaderop_addrow.h
shaderop_mul.h

View File

@ -11,6 +11,7 @@
// These are generated at build time by cmake custom command
#include "shaderop_scale.h"
#include "shaderop_scale_8.h"
#include "shaderop_add.h"
#include "shaderop_addrow.h"
#include "shaderop_mul.h"
@ -724,8 +725,12 @@ void ggml_vk_scale(kp::Sequence& seq,
const std::shared_ptr<kp::Tensor>& out,
uint32_t inOff, uint32_t outOff,
uint32_t size, float scale) {
const static auto spirv = getSpirvShader(kp::shader_data::op_scale_comp_spv,
kp::shader_data::op_scale_comp_spv_len);
const static auto spirv_1 = getSpirvShader(
kp::shader_data::op_scale_comp_spv, kp::shader_data::op_scale_comp_spv_len
);
const static auto spirv_8 = getSpirvShader(
kp::shader_data::op_scale_8_comp_spv, kp::shader_data::op_scale_8_comp_spv_len
);
struct PushConstants {
uint32_t inOff, outOff;
@ -735,11 +740,19 @@ void ggml_vk_scale(kp::Sequence& seq,
scale
};
const auto * spirv = &spirv_1;
std::string name(__func__);
if (size % 8 == 0) {
size /= 8;
name += "_8";
spirv = &spirv_8;
}
std::shared_ptr<kp::Algorithm> s_algo = nullptr;
if (!komputeManager()->hasAlgorithm(__func__))
s_algo = komputeManager()->algorithm<float, PushConstants>(__func__, s_kompute_context->pool.get(), {in, out}, spirv, {size}, {}, {pushConsts});
else {
s_algo = komputeManager()->getAlgorithm(__func__);
if (!komputeManager()->hasAlgorithm(name)) {
s_algo = komputeManager()->algorithm<float, PushConstants>(__func__, s_kompute_context->pool.get(), {in, out}, *spirv, {size}, {}, {pushConsts});
} else {
s_algo = komputeManager()->getAlgorithm(name);
s_algo->setTensors({in, out});
s_algo->setWorkgroup({size});
s_algo->setPushConstants<PushConstants>({pushConsts});
@ -1416,9 +1429,7 @@ void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph
case GGML_OP_SCALE:
{
const float scale = *(const float *) src1->data;
int64_t n = ggml_nelements(dst);
GGML_ASSERT(n % 8 == 0);
ggml_vk_scale(seq, id_src0, id_dst, off_src0, off_dst, n/8, scale);
ggml_vk_scale(seq, id_src0, id_dst, off_src0, off_dst, ggml_nelements(dst), scale);
} break;
case GGML_OP_UNARY:
{

View File

@ -22,10 +22,6 @@ layout(push_constant) uniform PushConstants {
} pcs;
void main() {
const uint baseIndex = gl_WorkGroupID.x * 8;
for (uint x = 0; x < 8; x++) {
const uint i = baseIndex + x;
out_[i + pcs.outOff] = in_[i + pcs.inOff] * pcs.scale;
}
const uint i = gl_WorkGroupID.x;
out_[i + pcs.outOff] = in_[i + pcs.inOff] * pcs.scale;
}

31
kompute/op_scale_8.comp Normal file
View File

@ -0,0 +1,31 @@
/**
* Copyright (c) 2023 Nomic, Inc. All rights reserved.
*
* This software is licensed under the terms of the Software for Open Models License (SOM),
* version 1.0, as detailed in the LICENSE_SOM.txt file. A copy of this license should accompany
* this software. Except as expressly granted in the SOM license, all rights are reserved by Nomic, Inc.
*/
#version 450
#include "common.comp"
layout(local_size_x = 1) in;
layout(binding = 0) buffer restrict readonly tensorIn { float in_[]; };
layout(binding = 1) buffer restrict writeonly tensorOut { float out_[]; };
layout(push_constant) uniform PushConstants {
uint inOff;
uint outOff;
float scale;
} pcs;
void main() {
const uint baseIndex = gl_WorkGroupID.x * 8;
for (uint x = 0; x < 8; x++) {
const uint i = baseIndex + x;
out_[i + pcs.outOff] = in_[i + pcs.inOff] * pcs.scale;
}
}