Lower the workgroup count for some shaders by providing a loop that processes

four floats at a time.
This commit is contained in:
Adam Treat 2023-10-26 11:48:36 -04:00 committed by cebtenzzre
parent 752f7ebd61
commit 8d9efbf97a
6 changed files with 37 additions and 21 deletions

View File

@ -1358,7 +1358,7 @@ void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph
// src1 is a row // src1 is a row
ggml_vk_addrow(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ggml_nelements(dst), ne00); ggml_vk_addrow(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ggml_nelements(dst), ne00);
} else { } else {
ggml_vk_add(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ggml_nelements(dst)); ggml_vk_add(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ggml_nelements(dst)/4);
} }
} break; } break;
case GGML_OP_MUL: case GGML_OP_MUL:
@ -1367,7 +1367,7 @@ void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph
// src1 is a row // src1 is a row
ggml_vk_mulrow(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ggml_nelements(dst), ne00); ggml_vk_mulrow(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ggml_nelements(dst), ne00);
} else { } else {
ggml_vk_mul(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ggml_nelements(dst)); ggml_vk_mul(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ggml_nelements(dst)/4);
} }
} break; } break;
case GGML_OP_SCALE: case GGML_OP_SCALE:
@ -1379,15 +1379,15 @@ void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph
switch (ggml_get_unary_op(gf->nodes[i])) { switch (ggml_get_unary_op(gf->nodes[i])) {
case GGML_UNARY_OP_SILU: case GGML_UNARY_OP_SILU:
{ {
ggml_vk_silu(seq, id_src0, id_dst, off_src0, off_dst, ggml_nelements(dst)); ggml_vk_silu(seq, id_src0, id_dst, off_src0, off_dst, ggml_nelements(dst)/4);
} break; } break;
case GGML_UNARY_OP_RELU: case GGML_UNARY_OP_RELU:
{ {
ggml_vk_relu(seq, id_src0, id_dst, off_src0, off_dst, ggml_nelements(dst)); ggml_vk_relu(seq, id_src0, id_dst, off_src0, off_dst, ggml_nelements(dst)/4);
} break; } break;
case GGML_UNARY_OP_GELU: case GGML_UNARY_OP_GELU:
{ {
ggml_vk_gelu(seq, id_src0, id_dst, off_src0, off_dst, ggml_nelements(dst)); ggml_vk_gelu(seq, id_src0, id_dst, off_src0, off_dst, ggml_nelements(dst)/4);
} break; } break;
default: default:
{ {

View File

@ -23,7 +23,10 @@ layout(push_constant) uniform PushConstants {
} pcs; } pcs;
void main() { void main() {
const uint i = gl_WorkGroupID.x; const uint baseIndex = gl_WorkGroupID.x * 4;
out_[i + pcs.outOff] = inA[i + pcs.inAOff] + inB[(i) + pcs.inBOff]; for (uint x = 0; x < 4; x++) {
const uint i = baseIndex + x;
out_[i + pcs.outOff] = inA[i + pcs.inAOff] + inB[i + pcs.inBOff];
}
} }

View File

@ -20,8 +20,11 @@ layout(push_constant) uniform PushConstants {
} pcs; } pcs;
void main() { void main() {
const uint i = gl_WorkGroupID.x; const uint baseIndex = gl_WorkGroupID.x * 4;
const float x = in_[i + pcs.inOff];
out_[i + pcs.outOff] = 0.5*x*(1.0 + tanh(SQRT_2_OVER_PI*x*(1.0 + GELU_COEF_A*x*x))); for (uint x = 0; x < 4; x++) {
const uint i = baseIndex + x;
const float y = in_[i + pcs.inOff];
out_[i + pcs.outOff] = 0.5*y*(1.0 + tanh(SQRT_2_OVER_PI*y*(1.0 + GELU_COEF_A*y*y)));
}
} }

View File

@ -23,7 +23,10 @@ layout(push_constant) uniform PushConstants {
} pcs; } pcs;
void main() { void main() {
const uint i = gl_WorkGroupID.x; const uint baseIndex = gl_WorkGroupID.x * 4;
for (uint x = 0; x < 4; x++) {
const uint i = baseIndex + x;
out_[i + pcs.outOff] = inA[i + pcs.inAOff] * inB[(i) + pcs.inBOff]; out_[i + pcs.outOff] = inA[i + pcs.inAOff] * inB[(i) + pcs.inBOff];
} }
}

View File

@ -20,7 +20,10 @@ layout(push_constant) uniform PushConstants {
} pcs; } pcs;
void main() { void main() {
const uint i = gl_WorkGroupID.x; const uint baseIndex = gl_WorkGroupID.x * 4;
for (uint x = 0; x < 4; x++) {
const uint i = baseIndex + x;
out_[i + pcs.outOff] = max(0.0, in_[i + pcs.inOff]); out_[i + pcs.outOff] = max(0.0, in_[i + pcs.inOff]);
} }
}

View File

@ -19,8 +19,12 @@ layout(push_constant) uniform PushConstants {
uint outOff; uint outOff;
} pcs; } pcs;
void main() { void main() {
const uint i = gl_WorkGroupID.x;
const float x = in_[i + pcs.inOff];
out_[i + pcs.outOff] = x / (1.0 + exp(-x)); const uint baseIndex = gl_WorkGroupID.x * 4;
for (uint x = 0; x < 4; x++) {
const uint i = baseIndex + x;
const float y = in_[i + pcs.inOff];
out_[i + pcs.outOff] = y / (1.0 + exp(-y));
}
} }