Lower the workgroup count for some shaders by providing a loop that processes

four floats at a time.
This commit is contained in:
Adam Treat 2023-10-26 11:48:36 -04:00 committed by cebtenzzre
parent 752f7ebd61
commit 8d9efbf97a
6 changed files with 37 additions and 21 deletions

View File

@ -1358,7 +1358,7 @@ void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph
// src1 is a row
ggml_vk_addrow(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ggml_nelements(dst), ne00);
} else {
ggml_vk_add(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ggml_nelements(dst));
ggml_vk_add(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ggml_nelements(dst)/4);
}
} break;
case GGML_OP_MUL:
@ -1367,7 +1367,7 @@ void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph
// src1 is a row
ggml_vk_mulrow(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ggml_nelements(dst), ne00);
} else {
ggml_vk_mul(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ggml_nelements(dst));
ggml_vk_mul(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ggml_nelements(dst)/4);
}
} break;
case GGML_OP_SCALE:
@ -1379,15 +1379,15 @@ void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph
switch (ggml_get_unary_op(gf->nodes[i])) {
case GGML_UNARY_OP_SILU:
{
ggml_vk_silu(seq, id_src0, id_dst, off_src0, off_dst, ggml_nelements(dst));
ggml_vk_silu(seq, id_src0, id_dst, off_src0, off_dst, ggml_nelements(dst)/4);
} break;
case GGML_UNARY_OP_RELU:
{
ggml_vk_relu(seq, id_src0, id_dst, off_src0, off_dst, ggml_nelements(dst));
ggml_vk_relu(seq, id_src0, id_dst, off_src0, off_dst, ggml_nelements(dst)/4);
} break;
case GGML_UNARY_OP_GELU:
{
ggml_vk_gelu(seq, id_src0, id_dst, off_src0, off_dst, ggml_nelements(dst));
ggml_vk_gelu(seq, id_src0, id_dst, off_src0, off_dst, ggml_nelements(dst)/4);
} break;
default:
{

View File

@ -23,7 +23,10 @@ layout(push_constant) uniform PushConstants {
} pcs;
void main() {
const uint i = gl_WorkGroupID.x;
const uint baseIndex = gl_WorkGroupID.x * 4;
out_[i + pcs.outOff] = inA[i + pcs.inAOff] + inB[(i) + pcs.inBOff];
for (uint x = 0; x < 4; x++) {
const uint i = baseIndex + x;
out_[i + pcs.outOff] = inA[i + pcs.inAOff] + inB[i + pcs.inBOff];
}
}

View File

@ -20,8 +20,11 @@ layout(push_constant) uniform PushConstants {
} pcs;
void main() {
const uint i = gl_WorkGroupID.x;
const float x = in_[i + pcs.inOff];
const uint baseIndex = gl_WorkGroupID.x * 4;
out_[i + pcs.outOff] = 0.5*x*(1.0 + tanh(SQRT_2_OVER_PI*x*(1.0 + GELU_COEF_A*x*x)));
for (uint x = 0; x < 4; x++) {
const uint i = baseIndex + x;
const float y = in_[i + pcs.inOff];
out_[i + pcs.outOff] = 0.5*y*(1.0 + tanh(SQRT_2_OVER_PI*y*(1.0 + GELU_COEF_A*y*y)));
}
}

View File

@ -23,7 +23,10 @@ layout(push_constant) uniform PushConstants {
} pcs;
void main() {
const uint i = gl_WorkGroupID.x;
const uint baseIndex = gl_WorkGroupID.x * 4;
out_[i + pcs.outOff] = inA[i + pcs.inAOff] * inB[(i) + pcs.inBOff];
for (uint x = 0; x < 4; x++) {
const uint i = baseIndex + x;
out_[i + pcs.outOff] = inA[i + pcs.inAOff] * inB[(i) + pcs.inBOff];
}
}

View File

@ -20,7 +20,10 @@ layout(push_constant) uniform PushConstants {
} pcs;
void main() {
const uint i = gl_WorkGroupID.x;
const uint baseIndex = gl_WorkGroupID.x * 4;
out_[i + pcs.outOff] = max(0.0, in_[i + pcs.inOff]);
for (uint x = 0; x < 4; x++) {
const uint i = baseIndex + x;
out_[i + pcs.outOff] = max(0.0, in_[i + pcs.inOff]);
}
}

View File

@ -19,8 +19,12 @@ layout(push_constant) uniform PushConstants {
uint outOff;
} pcs;
void main() {
const uint i = gl_WorkGroupID.x;
const float x = in_[i + pcs.inOff];
out_[i + pcs.outOff] = x / (1.0 + exp(-x));
const uint baseIndex = gl_WorkGroupID.x * 4;
for (uint x = 0; x < 4; x++) {
const uint i = baseIndex + x;
const float y = in_[i + pcs.inOff];
out_[i + pcs.outOff] = y / (1.0 + exp(-y));
}
}