kompute: add op_scale_8 shader and select the scale kernel by element count

ggml_vk_scale now picks op_scale_8 (each workgroup scales 8 consecutive
floats) when the element count is a multiple of 8, and op_scale (each
workgroup scales one float) otherwise. GGML_OP_SCALE therefore passes the raw
element count and no longer asserts n % 8 == 0.

Review fix: the algorithm is now created under `name`, the same key that
hasAlgorithm()/getAlgorithm() are queried with. It was previously created
under __func__, so the "_8" variant was stored under a different key than the
one it was looked up by.

NOTE(review): template argument lists appear to have been stripped from this
copy of the patch (e.g. `std::shared_ptr s_algo`, `std::shared_ptr& out`);
confirm those lines against the repository before applying. Indentation was
reconstructed and may need --ignore-whitespace.

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 39dd95eb0..76a03d95f 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -476,6 +476,7 @@ if (LLAMA_KOMPUTE)
         # Compile our shaders
         compile_shader(SOURCES
           kompute/op_scale.comp
+          kompute/op_scale_8.comp
           kompute/op_add.comp
           kompute/op_addrow.comp
           kompute/op_mul.comp
@@ -508,6 +509,7 @@ if (LLAMA_KOMPUTE)
         # Create a custom target for our generated shaders
         add_custom_target(generated_shaders DEPENDS
           shaderop_scale.h
+          shaderop_scale_8.h
           shaderop_add.h
           shaderop_addrow.h
           shaderop_mul.h
diff --git a/ggml-vulkan.cpp b/ggml-vulkan.cpp
index d4d6d1b87..8c048c77d 100644
--- a/ggml-vulkan.cpp
+++ b/ggml-vulkan.cpp
@@ -11,6 +11,7 @@
 
 // These are generated at build time by cmake custom command
 #include "shaderop_scale.h"
+#include "shaderop_scale_8.h"
 #include "shaderop_add.h"
 #include "shaderop_addrow.h"
 #include "shaderop_mul.h"
@@ -724,8 +725,12 @@ void ggml_vk_scale(kp::Sequence& seq,
                    const std::shared_ptr& out,
                    uint32_t inOff, uint32_t outOff,
                    uint32_t size, float scale) {
-    const static auto spirv = getSpirvShader(kp::shader_data::op_scale_comp_spv,
-        kp::shader_data::op_scale_comp_spv_len);
+    const static auto spirv_1 = getSpirvShader(
+        kp::shader_data::op_scale_comp_spv, kp::shader_data::op_scale_comp_spv_len
+    );
+    const static auto spirv_8 = getSpirvShader(
+        kp::shader_data::op_scale_8_comp_spv, kp::shader_data::op_scale_8_comp_spv_len
+    );
 
     struct PushConstants {
         uint32_t inOff, outOff;
@@ -735,11 +740,19 @@ void ggml_vk_scale(kp::Sequence& seq,
         scale
     };
 
+    const auto * spirv = &spirv_1;
+    std::string name(__func__);
+    if (size % 8 == 0) {
+        size /= 8;
+        name += "_8";
+        spirv = &spirv_8;
+    }
+
     std::shared_ptr s_algo = nullptr;
-    if (!komputeManager()->hasAlgorithm(__func__))
-        s_algo = komputeManager()->algorithm(__func__, s_kompute_context->pool.get(), {in, out}, spirv, {size}, {}, {pushConsts});
-    else {
-        s_algo = komputeManager()->getAlgorithm(__func__);
+    if (!komputeManager()->hasAlgorithm(name)) {
+        s_algo = komputeManager()->algorithm(name, s_kompute_context->pool.get(), {in, out}, *spirv, {size}, {}, {pushConsts});
+    } else {
+        s_algo = komputeManager()->getAlgorithm(name);
         s_algo->setTensors({in, out});
         s_algo->setWorkgroup({size});
         s_algo->setPushConstants({pushConsts});
@@ -1416,9 +1429,7 @@ void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph
                 case GGML_OP_SCALE:
                     {
                         const float scale = *(const float *) src1->data;
-                        int64_t n = ggml_nelements(dst);
-                        GGML_ASSERT(n % 8 == 0);
-                        ggml_vk_scale(seq, id_src0, id_dst, off_src0, off_dst, n/8, scale);
+                        ggml_vk_scale(seq, id_src0, id_dst, off_src0, off_dst, ggml_nelements(dst), scale);
                     } break;
                 case GGML_OP_UNARY:
                     {
diff --git a/kompute/op_scale.comp b/kompute/op_scale.comp
index 2ec524435..be6806091 100644
--- a/kompute/op_scale.comp
+++ b/kompute/op_scale.comp
@@ -22,10 +22,6 @@ layout(push_constant) uniform PushConstants {
 } pcs;
 
 void main() {
-    const uint baseIndex = gl_WorkGroupID.x * 8;
-
-    for (uint x = 0; x < 8; x++) {
-        const uint i = baseIndex + x;
-        out_[i + pcs.outOff] = in_[i + pcs.inOff] * pcs.scale;
-    }
-}
\ No newline at end of file
+    const uint i = gl_WorkGroupID.x;
+    out_[i + pcs.outOff] = in_[i + pcs.inOff] * pcs.scale;
+}
diff --git a/kompute/op_scale_8.comp b/kompute/op_scale_8.comp
new file mode 100644
index 000000000..29fa9b35a
--- /dev/null
+++ b/kompute/op_scale_8.comp
@@ -0,0 +1,31 @@
+/**
+ * Copyright (c) 2023 Nomic, Inc. All rights reserved.
+ *
+ * This software is licensed under the terms of the Software for Open Models License (SOM),
+ * version 1.0, as detailed in the LICENSE_SOM.txt file. A copy of this license should accompany
+ * this software. Except as expressly granted in the SOM license, all rights are reserved by Nomic, Inc.
+ */
+
+#version 450
+
+#include "common.comp"
+
+layout(local_size_x = 1) in;
+
+layout(binding = 0) buffer restrict readonly tensorIn { float in_[]; };
+layout(binding = 1) buffer restrict writeonly tensorOut { float out_[]; };
+
+layout(push_constant) uniform PushConstants {
+    uint inOff;
+    uint outOff;
+    float scale;
+} pcs;
+
+void main() {
+    const uint baseIndex = gl_WorkGroupID.x * 8;
+
+    for (uint x = 0; x < 8; x++) {
+        const uint i = baseIndex + x;
+        out_[i + pcs.outOff] = in_[i + pcs.inOff] * pcs.scale;
+    }
+}