From 32289aa447344fa8a5a8d9f6289af41fb15fd910 Mon Sep 17 00:00:00 2001 From: Adam Treat Date: Mon, 2 Oct 2023 21:00:48 -0400 Subject: [PATCH] Fixes for norm. --- kompute/op_norm.comp | 2 +- kompute/op_rmsnorm.comp | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/kompute/op_norm.comp b/kompute/op_norm.comp index 4b2db25e3..5aafeaac5 100644 --- a/kompute/op_norm.comp +++ b/kompute/op_norm.comp @@ -56,7 +56,7 @@ void main() { const float mean = sum[0]; // recenter - const uint y = (gl_WorkGroupID.x*pcs.ne00/4) + pcs.outOff; // Based from out_ + const uint y = (gl_WorkGroupID.x*pcs.ne00) + pcs.outOff; // Based from out_ for (uint i00 = gl_LocalInvocationID.x; i00 < pcs.ne00; i00 += nth) { out_[y+i00] = in_[x+i00] - mean; } diff --git a/kompute/op_rmsnorm.comp b/kompute/op_rmsnorm.comp index dd2c5cdde..8d6c0fa6a 100644 --- a/kompute/op_rmsnorm.comp +++ b/kompute/op_rmsnorm.comp @@ -10,7 +10,7 @@ #include "common.comp" -#define nth 256 +#define nth 512 layout(local_size_x = nth) in; @@ -56,7 +56,7 @@ void main() { const float scale = 1.0f/sqrt(sum[0] + pcs.eps); - const uint y = (gl_WorkGroupID.x*pcs.ne00/4) + pcs.outOff; // Based from out_ + const uint y = (gl_WorkGroupID.x*pcs.ne00) + pcs.outOff; // Based from out_ for (uint i00 = gl_LocalInvocationID.x; i00 < pcs.ne00; i00 += nth) { out_[y+i00] = in_[x+i00] * scale; }