mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-07 09:11:46 +00:00
metal : GGML_OP_NORM
This commit is contained in:
parent
aee082fdeb
commit
7941b6b9ec
@ -643,6 +643,13 @@ typedef struct {
|
|||||||
uint64_t nb1;
|
uint64_t nb1;
|
||||||
} ggml_metal_kargs_mul_mv_id;
|
} ggml_metal_kargs_mul_mv_id;
|
||||||
|
|
||||||
|
typedef struct {
|
||||||
|
int32_t ne00;
|
||||||
|
int32_t ne00_4;
|
||||||
|
uint64_t nb01;
|
||||||
|
float eps;
|
||||||
|
} ggml_metal_kargs_norm;
|
||||||
|
|
||||||
typedef struct {
|
typedef struct {
|
||||||
int32_t ne00;
|
int32_t ne00;
|
||||||
int32_t ne00_4;
|
int32_t ne00_4;
|
||||||
|
@ -2681,22 +2681,35 @@ static void ggml_metal_encode_node(
|
|||||||
} break;
|
} break;
|
||||||
case GGML_OP_NORM:
|
case GGML_OP_NORM:
|
||||||
{
|
{
|
||||||
|
GGML_ASSERT(ne00 % 4 == 0);
|
||||||
GGML_ASSERT(ggml_is_contiguous_1(src0));
|
GGML_ASSERT(ggml_is_contiguous_1(src0));
|
||||||
|
|
||||||
float eps;
|
float eps;
|
||||||
memcpy(&eps, dst->op_params, sizeof(float));
|
memcpy(&eps, dst->op_params, sizeof(float));
|
||||||
|
|
||||||
const int nth = MIN(256, ne00);
|
|
||||||
|
|
||||||
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_NORM].pipeline;
|
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_NORM].pipeline;
|
||||||
|
|
||||||
|
int nth = 32; // SIMD width
|
||||||
|
|
||||||
|
while (nth < ne00/4 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
|
||||||
|
nth *= 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
nth = MIN(nth, ne00/4);
|
||||||
|
|
||||||
|
ggml_metal_kargs_norm args = {
|
||||||
|
/*.ne00 =*/ ne00,
|
||||||
|
/*.ne00_4 =*/ ne00/4,
|
||||||
|
/*.nb01 =*/ nb01,
|
||||||
|
/*.eps =*/ eps,
|
||||||
|
};
|
||||||
|
|
||||||
[encoder setComputePipelineState:pipeline];
|
[encoder setComputePipelineState:pipeline];
|
||||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
[encoder setBytes:&args length:sizeof(args) atIndex:0];
|
||||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
||||||
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
||||||
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
|
|
||||||
[encoder setBytes:&eps length:sizeof( float) atIndex:4];
|
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
|
||||||
[encoder setThreadgroupMemoryLength:GGML_PAD(nth*sizeof(float), 16) atIndex:0];
|
|
||||||
|
|
||||||
const int64_t nrows = ggml_nrows(src0);
|
const int64_t nrows = ggml_nrows(src0);
|
||||||
|
|
||||||
|
@ -1236,53 +1236,68 @@ kernel void kernel_ssm_scan_f32(
|
|||||||
}
|
}
|
||||||
|
|
||||||
kernel void kernel_norm(
|
kernel void kernel_norm(
|
||||||
device const void * src0,
|
constant ggml_metal_kargs_norm & args,
|
||||||
device float * dst,
|
device const char * src0,
|
||||||
constant int64_t & ne00,
|
device char * dst,
|
||||||
constant uint64_t & nb01,
|
threadgroup float * shmem_f32 [[threadgroup(0)]],
|
||||||
constant float & eps,
|
uint tgpig[[threadgroup_position_in_grid]],
|
||||||
threadgroup float * sum [[threadgroup(0)]],
|
ushort tpitg[[thread_position_in_threadgroup]],
|
||||||
uint tgpig[[threadgroup_position_in_grid]],
|
ushort sgitg[[simdgroup_index_in_threadgroup]],
|
||||||
uint tpitg[[thread_position_in_threadgroup]],
|
ushort tiisg[[thread_index_in_simdgroup]],
|
||||||
uint ntg[[threads_per_threadgroup]]) {
|
ushort ntg[[threads_per_threadgroup]]) {
|
||||||
device const float * x = (device const float *) ((device const char *) src0 + tgpig*nb01);
|
if (sgitg == 0) {
|
||||||
// MEAN
|
shmem_f32[tiisg] = 0.0f;
|
||||||
// parallel sum
|
|
||||||
sum[tpitg] = 0.0f;
|
|
||||||
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
|
|
||||||
sum[tpitg] += x[i00];
|
|
||||||
}
|
}
|
||||||
// reduce
|
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
||||||
for (uint i = ntg/2; i > 0; i /= 2) {
|
|
||||||
if (tpitg < i) {
|
|
||||||
sum[tpitg] += sum[tpitg + i];
|
|
||||||
}
|
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
||||||
}
|
|
||||||
const float mean = sum[0] / ne00;
|
|
||||||
|
|
||||||
// recenter and VARIANCE
|
device const float4 * x = (device const float4 *) (src0 + tgpig*args.nb01);
|
||||||
|
|
||||||
|
float4 sumf4(0.0f);
|
||||||
|
|
||||||
|
float sumf = 0.0f;
|
||||||
|
|
||||||
|
for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
|
||||||
|
sumf4 += x[i00];
|
||||||
|
}
|
||||||
|
sumf = sumf4[0] + sumf4[1] + sumf4[2] + sumf4[3];
|
||||||
|
sumf = simd_sum(sumf);
|
||||||
|
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
device float * y = dst + tgpig*ne00;
|
|
||||||
sum[tpitg] = 0.0f;
|
if (tiisg == 0) {
|
||||||
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
|
shmem_f32[sgitg] = sumf;
|
||||||
|
}
|
||||||
|
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
|
sumf = shmem_f32[tiisg];
|
||||||
|
sumf = simd_sum(sumf);
|
||||||
|
|
||||||
|
const float mean = sumf/args.ne00;
|
||||||
|
|
||||||
|
device float4 * y = (device float4 *) dst + tgpig*args.ne00_4;
|
||||||
|
|
||||||
|
sumf = 0.0f;
|
||||||
|
for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
|
||||||
y[i00] = x[i00] - mean;
|
y[i00] = x[i00] - mean;
|
||||||
sum[tpitg] += y[i00] * y[i00];
|
sumf += dot(y[i00], y[i00]);
|
||||||
}
|
}
|
||||||
|
sumf = simd_sum(sumf);
|
||||||
|
|
||||||
// reduce
|
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
for (uint i = ntg/2; i > 0; i /= 2) {
|
|
||||||
if (tpitg < i) {
|
|
||||||
sum[tpitg] += sum[tpitg + i];
|
|
||||||
}
|
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
||||||
}
|
|
||||||
const float variance = sum[0] / ne00;
|
|
||||||
|
|
||||||
const float scale = 1.0f/sqrt(variance + eps);
|
if (tiisg == 0) {
|
||||||
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
|
shmem_f32[sgitg] = sumf;
|
||||||
|
}
|
||||||
|
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
|
sumf = shmem_f32[tiisg];
|
||||||
|
sumf = simd_sum(sumf);
|
||||||
|
|
||||||
|
const float variance = sumf/args.ne00;
|
||||||
|
|
||||||
|
const float scale = 1.0f/sqrt(variance + args.eps);
|
||||||
|
for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
|
||||||
y[i00] = y[i00] * scale;
|
y[i00] = y[i00] * scale;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user