vulkan : assert various kernel requirements

This commit is contained in:
Jared Van Bortel 2023-11-14 11:59:58 -05:00
parent f194e1b6a6
commit a934b2cb8a

View File

@ -1416,27 +1416,34 @@ void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph
case GGML_OP_SCALE: case GGML_OP_SCALE:
{ {
const float scale = *(const float *) src1->data; const float scale = *(const float *) src1->data;
ggml_vk_scale(seq, id_src0, id_dst, off_src0, off_dst, ggml_nelements(dst)/8, scale); int64_t n = ggml_nelements(dst);
GGML_ASSERT(n % 8 == 0);
ggml_vk_scale(seq, id_src0, id_dst, off_src0, off_dst, n/8, scale);
} break; } break;
case GGML_OP_UNARY: case GGML_OP_UNARY:
switch (ggml_get_unary_op(gf->nodes[i])) { {
case GGML_UNARY_OP_SILU: int64_t n = ggml_nelements(dst);
{ GGML_ASSERT(n % 4 == 0);
ggml_vk_silu(seq, id_src0, id_dst, off_src0, off_dst, ggml_nelements(dst)/4); switch (ggml_get_unary_op(gf->nodes[i])) {
} break; case GGML_UNARY_OP_SILU:
case GGML_UNARY_OP_RELU: {
{ ggml_vk_silu(seq, id_src0, id_dst, off_src0, off_dst, n/4);
ggml_vk_relu(seq, id_src0, id_dst, off_src0, off_dst, ggml_nelements(dst)/4); } break;
} break; case GGML_UNARY_OP_RELU:
case GGML_UNARY_OP_GELU: {
{ ggml_vk_relu(seq, id_src0, id_dst, off_src0, off_dst, n/4);
ggml_vk_gelu(seq, id_src0, id_dst, off_src0, off_dst, ggml_nelements(dst)/8); } break;
} break; case GGML_UNARY_OP_GELU:
default: {
{ GGML_ASSERT(n % 8 == 0);
fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op)); ggml_vk_gelu(seq, id_src0, id_dst, off_src0, off_dst, n/8);
GGML_ASSERT(false); } break;
} default:
{
fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
GGML_ASSERT(false);
}
}
} break; } break;
case GGML_OP_SOFT_MAX: case GGML_OP_SOFT_MAX:
{ {
@ -1455,6 +1462,8 @@ void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph
} break; } break;
case GGML_OP_RMS_NORM: case GGML_OP_RMS_NORM:
{ {
GGML_ASSERT(ne00 % 4 == 0);
float eps; float eps;
memcpy(&eps, dst->op_params, sizeof(float)); memcpy(&eps, dst->op_params, sizeof(float));
ggml_vk_rms_norm(seq, id_src0, id_dst, off_src0, off_dst, ne00, nb01, ggml_nrows(src0), eps); ggml_vk_rms_norm(seq, id_src0, id_dst, off_src0, off_dst, ne00, nb01, ggml_nrows(src0), eps);