metal : assert various kernel requirements

This commit is contained in:
Georgi Gerganov 2023-10-08 11:04:20 +03:00
parent 42833bc7a8
commit 0f8df395ce
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
2 changed files with 25 additions and 13 deletions

View File

@ -774,8 +774,8 @@ void ggml_metal_graph_compute(
} break;
case GGML_OP_CONCAT:
{
const int64_t nb = ne00;
int64_t nb = ne00;
[encoder setComputePipelineState:ctx->pipeline_concat];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
@ -807,6 +807,7 @@ void ggml_metal_graph_compute(
[encoder setBytes:&nb length:sizeof(nb) atIndex:27];
const int nth = MIN(1024, ne0);
[encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
} break;
case GGML_OP_ADD:
@ -904,9 +905,10 @@ void ggml_metal_graph_compute(
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&scale length:sizeof(scale) atIndex:2];
const int64_t n = ggml_nelements(dst)/4;
const int64_t n = ggml_nelements(dst);
GGML_ASSERT(n % 4 == 0);
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
[encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
case GGML_OP_UNARY:
switch (ggml_get_unary_op(gf->nodes[i])) {
@ -916,9 +918,10 @@ void ggml_metal_graph_compute(
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
const int64_t n = ggml_nelements(dst)/4;
const int64_t n = ggml_nelements(dst);
GGML_ASSERT(n % 4 == 0);
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
[encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
case GGML_UNARY_OP_RELU:
{
@ -936,9 +939,10 @@ void ggml_metal_graph_compute(
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
const int64_t n = ggml_nelements(dst)/4;
const int64_t n = ggml_nelements(dst);
GGML_ASSERT(n % 4 == 0);
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
[encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
default:
{
@ -1220,6 +1224,8 @@ void ggml_metal_graph_compute(
} break;
case GGML_OP_RMS_NORM:
{
GGML_ASSERT(ne00 % 4 == 0);
float eps;
memcpy(&eps, dst->op_params, sizeof(float));

View File

@ -347,6 +347,7 @@ kernel void kernel_rms_norm(
uint ntg[[threads_per_threadgroup]]) {
device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01);
device const float * x_scalar = (device const float *) x;
float4 sumf = 0;
float all_sum = 0;
@ -361,6 +362,7 @@ kernel void kernel_rms_norm(
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// broadcast, simd group number is ntg / 32
for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
if (tpitg < i) {
@ -368,7 +370,9 @@ kernel void kernel_rms_norm(
}
}
if (tpitg == 0) {
for (int i = 4 * (ne00 / 4); i < ne00; i++) {sum[0] += x_scalar[i];}
for (int i = 4 * (ne00 / 4); i < ne00; i++) {
sum[0] += x_scalar[i];
}
sum[0] /= ne00;
}
@ -383,7 +387,9 @@ kernel void kernel_rms_norm(
y[i00] = x[i00] * scale;
}
if (tpitg == 0) {
for (int i00 = 4 * (ne00 / 4); i00 < ne00; i00++) {y_scalar[i00] = x_scalar[i00] * scale;}
for (int i00 = 4 * (ne00 / 4); i00 < ne00; i00++) {
y_scalar[i00] = x_scalar[i00] * scale;
}
}
}