Lower the workgroup count for some shaders by providing a loop that processes

four floats at a time.
This commit is contained in:
Adam Treat 2023-10-26 11:48:36 -04:00 committed by cebtenzzre
parent 752f7ebd61
commit 8d9efbf97a
6 changed files with 37 additions and 21 deletions

View File

@ -1358,7 +1358,7 @@ void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph
// src1 is a row // src1 is a row
ggml_vk_addrow(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ggml_nelements(dst), ne00); ggml_vk_addrow(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ggml_nelements(dst), ne00);
} else { } else {
ggml_vk_add(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ggml_nelements(dst)); ggml_vk_add(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ggml_nelements(dst)/4);
} }
} break; } break;
case GGML_OP_MUL: case GGML_OP_MUL:
@ -1367,7 +1367,7 @@ void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph
// src1 is a row // src1 is a row
ggml_vk_mulrow(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ggml_nelements(dst), ne00); ggml_vk_mulrow(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ggml_nelements(dst), ne00);
} else { } else {
ggml_vk_mul(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ggml_nelements(dst)); ggml_vk_mul(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ggml_nelements(dst)/4);
} }
} break; } break;
case GGML_OP_SCALE: case GGML_OP_SCALE:
@ -1379,15 +1379,15 @@ void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph
switch (ggml_get_unary_op(gf->nodes[i])) { switch (ggml_get_unary_op(gf->nodes[i])) {
case GGML_UNARY_OP_SILU: case GGML_UNARY_OP_SILU:
{ {
ggml_vk_silu(seq, id_src0, id_dst, off_src0, off_dst, ggml_nelements(dst)); ggml_vk_silu(seq, id_src0, id_dst, off_src0, off_dst, ggml_nelements(dst)/4);
} break; } break;
case GGML_UNARY_OP_RELU: case GGML_UNARY_OP_RELU:
{ {
ggml_vk_relu(seq, id_src0, id_dst, off_src0, off_dst, ggml_nelements(dst)); ggml_vk_relu(seq, id_src0, id_dst, off_src0, off_dst, ggml_nelements(dst)/4);
} break; } break;
case GGML_UNARY_OP_GELU: case GGML_UNARY_OP_GELU:
{ {
ggml_vk_gelu(seq, id_src0, id_dst, off_src0, off_dst, ggml_nelements(dst)); ggml_vk_gelu(seq, id_src0, id_dst, off_src0, off_dst, ggml_nelements(dst)/4);
} break; } break;
default: default:
{ {
@ -1427,9 +1427,9 @@ void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph
ggml_is_transposed(src1)) { ggml_is_transposed(src1)) {
fprintf(stderr, "%s: %s: matmul on tranposed tensor not supported: %u/%u\n", __func__, ggml_op_name(dst->op), src0t, src1t); fprintf(stderr, "%s: %s: matmul on tranposed tensor not supported: %u/%u\n", __func__, ggml_op_name(dst->op), src0t, src1t);
goto not_implemented; goto not_implemented;
} }
switch (src0t) { switch (src0t) {
case GGML_TYPE_F32: case GGML_TYPE_F32:
ggml_vk_mul_mat_mat_f32(seq, ggml_vk_mul_mat_mat_f32(seq,
id_src0, id_src1, id_dst, id_src0, id_src1, id_dst,
@ -1459,7 +1459,7 @@ void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph
goto not_implemented; goto not_implemented;
} }
} }
} break; } break;
case GGML_OP_GET_ROWS: case GGML_OP_GET_ROWS:
{ {

View File

@ -23,7 +23,10 @@ layout(push_constant) uniform PushConstants {
} pcs; } pcs;
void main() { void main() {
const uint i = gl_WorkGroupID.x; const uint baseIndex = gl_WorkGroupID.x * 4;
out_[i + pcs.outOff] = inA[i + pcs.inAOff] + inB[(i) + pcs.inBOff]; for (uint x = 0; x < 4; x++) {
} const uint i = baseIndex + x;
out_[i + pcs.outOff] = inA[i + pcs.inAOff] + inB[i + pcs.inBOff];
}
}

View File

@ -20,8 +20,11 @@ layout(push_constant) uniform PushConstants {
} pcs; } pcs;
void main() { void main() {
const uint i = gl_WorkGroupID.x; const uint baseIndex = gl_WorkGroupID.x * 4;
const float x = in_[i + pcs.inOff];
out_[i + pcs.outOff] = 0.5*x*(1.0 + tanh(SQRT_2_OVER_PI*x*(1.0 + GELU_COEF_A*x*x))); for (uint x = 0; x < 4; x++) {
const uint i = baseIndex + x;
const float y = in_[i + pcs.inOff];
out_[i + pcs.outOff] = 0.5*y*(1.0 + tanh(SQRT_2_OVER_PI*y*(1.0 + GELU_COEF_A*y*y)));
}
} }

View File

@ -23,7 +23,10 @@ layout(push_constant) uniform PushConstants {
} pcs; } pcs;
void main() { void main() {
const uint i = gl_WorkGroupID.x; const uint baseIndex = gl_WorkGroupID.x * 4;
out_[i + pcs.outOff] = inA[i + pcs.inAOff] * inB[(i) + pcs.inBOff]; for (uint x = 0; x < 4; x++) {
const uint i = baseIndex + x;
out_[i + pcs.outOff] = inA[i + pcs.inAOff] * inB[(i) + pcs.inBOff];
}
} }

View File

@ -20,7 +20,10 @@ layout(push_constant) uniform PushConstants {
} pcs; } pcs;
void main() { void main() {
const uint i = gl_WorkGroupID.x; const uint baseIndex = gl_WorkGroupID.x * 4;
out_[i + pcs.outOff] = max(0.0, in_[i + pcs.inOff]); for (uint x = 0; x < 4; x++) {
const uint i = baseIndex + x;
out_[i + pcs.outOff] = max(0.0, in_[i + pcs.inOff]);
}
} }

View File

@ -19,8 +19,12 @@ layout(push_constant) uniform PushConstants {
uint outOff; uint outOff;
} pcs; } pcs;
void main() { void main() {
const uint i = gl_WorkGroupID.x;
const float x = in_[i + pcs.inOff];
out_[i + pcs.outOff] = x / (1.0 + exp(-x)); const uint baseIndex = gl_WorkGroupID.x * 4;
for (uint x = 0; x < 4; x++) {
const uint i = baseIndex + x;
const float y = in_[i + pcs.inOff];
out_[i + pcs.outOff] = y / (1.0 + exp(-y));
}
} }