metal : GGML_OP_RMS_NORM

This commit is contained in:
Georgi Gerganov 2024-11-10 15:31:43 +02:00
parent 967727a8ed
commit aee082fdeb
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
3 changed files with 52 additions and 42 deletions

View File

@ -642,6 +642,13 @@ typedef struct {
int32_t ne1; int32_t ne1;
uint64_t nb1; uint64_t nb1;
} ggml_metal_kargs_mul_mv_id; } ggml_metal_kargs_mul_mv_id;
typedef struct {
int32_t ne00;
int32_t ne00_4;
uint64_t nb01;
float eps;
} ggml_metal_kargs_rms_norm;
#endif #endif
#endif // GGML_COMMON_DECL #endif // GGML_COMMON_DECL

View File

@ -2618,20 +2618,28 @@ static void ggml_metal_encode_node(
float eps; float eps;
memcpy(&eps, dst->op_params, sizeof(float)); memcpy(&eps, dst->op_params, sizeof(float));
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM].pipeline;
int nth = 32; // SIMD width int nth = 32; // SIMD width
while (nth < ne00/4 && nth < 1024) { while (nth < ne00/4 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
nth *= 2; nth *= 2;
} }
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM].pipeline; nth = MIN(nth, ne00/4);
ggml_metal_kargs_rms_norm args = {
/*.ne00 =*/ ne00,
/*.ne00_4 =*/ ne00/4,
/*.nb01 =*/ nb01,
/*.eps =*/ eps,
};
[encoder setComputePipelineState:pipeline]; [encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBytes:&args length:sizeof(args) atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
[encoder setBytes:&eps length:sizeof( float) atIndex:4];
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
const int64_t nrows = ggml_nrows(src0); const int64_t nrows = ggml_nrows(src0);

View File

@ -1288,50 +1288,45 @@ kernel void kernel_norm(
} }
kernel void kernel_rms_norm( kernel void kernel_rms_norm(
device const void * src0, constant ggml_metal_kargs_rms_norm & args,
device float * dst, device const char * src0,
constant int64_t & ne00, device char * dst,
constant uint64_t & nb01, threadgroup float * shmem_f32 [[threadgroup(0)]],
constant float & eps, uint tgpig[[threadgroup_position_in_grid]],
threadgroup float * buf [[threadgroup(0)]], ushort tpitg[[thread_position_in_threadgroup]],
uint tgpig[[threadgroup_position_in_grid]], ushort sgitg[[simdgroup_index_in_threadgroup]],
uint tpitg[[thread_position_in_threadgroup]], ushort tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]], ushort ntg[[threads_per_threadgroup]]) {
uint tiisg[[thread_index_in_simdgroup]], if (sgitg == 0) {
uint ntg[[threads_per_threadgroup]]) { shmem_f32[tiisg] = 0.0f;
device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01); }
float4 sumf = 0; device const float4 * x = (device const float4 *) (src0 + tgpig*args.nb01);
float all_sum = 0;
float sumf = 0.0f;
// parallel sum // parallel sum
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
sumf += x[i00] * x[i00]; sumf += dot(x[i00], x[i00]);
} }
all_sum = sumf[0] + sumf[1] + sumf[2] + sumf[3]; sumf = simd_sum(sumf);
all_sum = simd_sum(all_sum);
if (ntg > N_SIMDWIDTH) {
if (sgitg == 0) {
buf[tiisg] = 0.0f;
}
threadgroup_barrier(mem_flags::mem_threadgroup); threadgroup_barrier(mem_flags::mem_threadgroup);
if (tiisg == 0) { if (tiisg == 0) {
buf[sgitg] = all_sum; shmem_f32[sgitg] = sumf;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
all_sum = buf[tiisg];
all_sum = simd_sum(all_sum);
} }
const float mean = all_sum/ne00; threadgroup_barrier(mem_flags::mem_threadgroup);
const float scale = 1.0f/sqrt(mean + eps);
device float4 * y = (device float4 *) (dst + tgpig*ne00); sumf = shmem_f32[tiisg];
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { sumf = simd_sum(sumf);
const float mean = sumf/args.ne00;
const float scale = 1.0f/sqrt(mean + args.eps);
device float4 * y = (device float4 *) dst + tgpig*args.ne00_4;
for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
y[i00] = x[i00] * scale; y[i00] = x[i00] * scale;
} }
} }