ggml : make GeLU more accurate on CPU

This change makes GeLU more accurate on amd64 and arm64 by using a tanhf
approximation that's been explicitly vectorized for avx512f, avx2, sse2,
and neon. No performance is traded away on these architectures, compared
to the 16-bit lookup table that was being used previously. The impact of
this change can be demonstrated easily with whisper, where it leads to a
measurable improvement in Levenshtein distance of model output.
This commit is contained in:
Justine Tunney 2024-08-05 08:50:50 -07:00
parent 8b3befc0e2
commit bb668b608e
No known key found for this signature in database
GPG Key ID: 52965314629936D4

View File

@ -258,7 +258,6 @@ void ggml_abort(const char * file, int line, const char * fmt, ...) {
}
#define GGML_DEBUG 0
#define GGML_GELU_FP16
#define GGML_GELU_QUICK_FP16
#define GGML_SOFT_MAX_UNROLL 4
@ -393,9 +392,6 @@ typedef double ggml_float;
// global data
//
// precomputed gelu table for f16 (128 KB)
static ggml_fp16_t ggml_table_gelu_f16[1 << 16];
// precomputed quick gelu table for f16 (128 KB)
static ggml_fp16_t ggml_table_gelu_quick_f16[1 << 16];
@ -1842,6 +1838,19 @@ static inline void __lsx_f16x4_store(ggml_fp16_t * x, __m128 y) {
#define GGML_F16_ARR (GGML_F16_STEP/GGML_F16_EPR)
#endif
// for GeLU and SiLU
//
// MADD(x, y, z) computes x*y + z and NMADD(x, y, z) computes z - x*y.
// When the target has FMA these map to single fused instructions (one
// rounding); otherwise they fall back to a separate multiply and
// add/subtract with the same interface (two roundings).
#ifdef __FMA__
#define MADD128(x, y, z) _mm_fmadd_ps(x, y, z)
#define NMADD128(x, y, z) _mm_fnmadd_ps(x, y, z)
#define MADD256(x, y, z) _mm256_fmadd_ps(x, y, z)
#define NMADD256(x, y, z) _mm256_fnmadd_ps(x, y, z)
#else
#define MADD128(x, y, z) _mm_add_ps(_mm_mul_ps(x, y), z)
#define NMADD128(x, y, z) _mm_sub_ps(z, _mm_mul_ps(x, y))
#define MADD256(x, y, z) _mm256_add_ps(_mm256_mul_ps(x, y), z)
#define NMADD256(x, y, z) _mm256_sub_ps(z, _mm256_mul_ps(x, y))
#endif
//
// ggml context
//
@ -2323,55 +2332,343 @@ inline static void ggml_vec_sigmoid_f32 (const int n, float * y, const float * x
inline static void ggml_vec_hardswish_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i] * fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f)); }
inline static void ggml_vec_hardsigmoid_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f)); }
static const float GELU_COEF_A = 0.044715f;
////////////////////////////////////////////////////////////////////////////////
// There's always room for GeLU
static const float GELU_COEF_A = .044715f;
static const float GELU_QUICK_COEF = -1.702f;
static const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
static const float SQRT_2_OVER_PI = .79788456080286535587989211986876f;
inline static float ggml_gelu_f32(float x) {
return 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
return .5f*x*(1.f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
}
inline static void ggml_vec_gelu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
const uint16_t * i16 = (const uint16_t *) x;
for (int i = 0; i < n; ++i) {
y[i] = ggml_table_gelu_f16[i16[i]];
#if defined(__ARM_NEON) && defined(__aarch64__)
/* Approximation for single-precision vector tanh (2.58 ULP)
There is no support for signed zero whose sign is removed
There is no support for floating point exception handling
This code is based on the ARM Limited Optimized Routines. */
inline static float32x4_t
ggml_vtanhf(float32x4_t x)
{
// Work on |x|; the sign bit is isolated here and reapplied via the
// bitwise "boring" path below.
const uint32x4_t ix = vreinterpretq_u32_f32(x);
const float32x4_t ax = vabsq_f32(x);
const uint32x4_t iax = vreinterpretq_u32_f32(ax);
const uint32x4_t sign = veorq_u32(ix, iax);
// 0x41102cb3 is the bit pattern of 0x1.205966p+3 (~9.01); beyond that
// the lane result is just sign|1.0f.
const uint32x4_t is_boring = vcgtq_u32(iax, vdupq_n_u32(0x41102cb3));
const float32x4_t boring = vreinterpretq_f32_u32(vorrq_u32(sign, vdupq_n_u32(0x3f800000)));
// Lanes whose magnitude bits exceed the infinity encoding (NaNs) take
// the scalar fallback at the bottom.
const uint32x4_t special = vcgtq_u32(iax, vdupq_n_u32(0x7f800000));
// tanh(x) = expm1(2x) / (expm1(2x) + 2); start from ex = 2x.
const float32x4_t ex = vmulq_n_f32(x, 2);
// Lane constants: log2(e), ln2 hi part, ln2 lo part (lane 3 unused).
const float32x4_t e = { 0x1.715476p+0f, 0x1.62e4p-1f, 0x1.7f7d1cp-20f };
// Round ex*log2(e) to the nearest integer j via the 0x1.8p23 shifter.
const float32x4_t j = vsubq_f32(vfmaq_laneq_f32(vdupq_n_f32(0x1.8p23f), ex, e, 0), vdupq_n_f32(0x1.8p23f));
const int32x4_t i = vcvtq_s32_f32(j);
// Reduced argument f1 = ex - j*ln2, using the hi/lo ln2 split.
const float32x4_t f = vfmsq_laneq_f32(ex, j, e, 1);
const float32x4_t f1 = vfmsq_laneq_f32(f, j, e, 2);
const float32x4_t f2 = vmulq_f32(f1, f1);
const float32x4_t f4 = vmulq_f32(f2, f2);
// Degree-5 polynomial for expm1 on the reduced range, Estrin form.
const float32x4_t p01 = vfmaq_f32(vdupq_n_f32(0x1.fffffep-2), vdupq_n_f32(0x1.5554aep-3), f1);
const float32x4_t p23 = vfmaq_f32(vdupq_n_f32(0x1.555736p-5), vdupq_n_f32(0x1.12287cp-7), f1);
const float32x4_t p03 = vfmaq_f32(p01, p23, f2);
const float32x4_t p = vfmaq_f32(p03, vdupq_n_f32(0x1.6b55a2p-10), f4);
const float32x4_t p2 = vfmaq_f32(f1, f2, p);
// Build t = 2^j by writing j straight into the float exponent field.
const int32x4_t u = vaddq_s32(vshlq_n_s32(i, 23), vdupq_n_s32(0x3f800000));
const float32x4_t t = vreinterpretq_f32_s32(u);
// q = expm1(2x); the final ratio q/(q+2) is tanh(x).
const float32x4_t q = vfmaq_f32(vsubq_f32(t, vdupq_n_f32(1)), p2, t);
const float32x4_t y = vdivq_f32(q, vaddq_f32(q, vdupq_n_f32(2)));
const float32x4_t result = vbslq_f32(is_boring, boring, y);
// Fast path: no special (NaN) lanes, which is the common case.
if (!vpaddd_u64(vreinterpretq_u64_u32(special)))
return result;
// Slow path: patch each special lane with the scalar libm result.
return (float32x4_t){ special[0] ? tanhf(x[0]) : result[0],
special[1] ? tanhf(x[1]) : result[1],
special[2] ? tanhf(x[2]) : result[2],
special[3] ? tanhf(x[3]) : result[3] };
}
// GeLU(x) = 0.5*x*(1 + tanh(sqrt(2/pi)*x*(1 + 0.044715*x^2))),
// four lanes at a time, with ggml_vtanhf supplying the tanh.
inline static float32x4_t
ggml_vgeluf(float32x4_t x)
{
const float32x4_t k1 = vdupq_n_f32(1);
// inner argument: sqrt(2/pi)*x * (1 + GELU_COEF_A*x^2)
const float32x4_t cubic = vaddq_f32(k1, vmulq_f32(vdupq_n_f32(GELU_COEF_A), vmulq_f32(x, x)));
const float32x4_t arg = vmulq_f32(vmulq_f32(vdupq_n_f32(SQRT_2_OVER_PI), x), cubic);
// 0.5*x * (1 + tanh(arg))
const float32x4_t th = ggml_vtanhf(arg);
return vmulq_f32(vmulq_f32(vdupq_n_f32(.5), x), vaddq_f32(k1, th));
}
#elif defined(__AVX512F__) && defined(__AVX512DQ__)
/* Approximation for single-precision vector tanh(x) using a
branchless algorithm that offers a maximum error of 4 ULP
108638843x off by one errors
18273656x 2 to 3 ulp errors
124x 4 ulp errors (e.g. 0.203652 [3e508a10])
1x sign flip
There is no support for signed zero whose sign is removed
There is no support for floating point exception handling
This code is based on the ARM Limited Optimized Routines. */
inline static __m512
ggml_vtanhf(__m512 x)
{
const __m512 sign_mask = _mm512_castsi512_ps(_mm512_set1_epi32(0x80000000));
const __m512 one = _mm512_set1_ps(1);
const __m512 two = _mm512_set1_ps(2);
const __m512 ax = _mm512_abs_ps(x);
const __m512 sign = _mm512_and_ps(x, sign_mask);
// |x| > 0x1.205966p+3 (~9.01): lane result is just sign|1.0 ("boring").
const __mmask16 is_boring = _mm512_cmp_ps_mask(ax, _mm512_set1_ps(0x1.205966p+3), _CMP_GT_OQ);
const __m512 boring = _mm512_or_ps(sign, one);
// tanh(x) = expm1(2x) / (expm1(2x) + 2); start from ex = 2x.
const __m512 ex = _mm512_mul_ps(x, two);
// Round ex*log2(e) (0x1.715476p+0 = log2(e)) to the nearest integer j
// via the 0x1.8p23 shift trick.
const __m512 j = _mm512_fmadd_ps( ex, _mm512_set1_ps(0x1.715476p+0f), _mm512_set1_ps(0x1.8p23f));
const __m512 jj = _mm512_sub_ps(j, _mm512_set1_ps(0x1.8p23f));
const __m512i i = _mm512_cvttps_epi32(jj);
// Reduced argument f1 = ex - jj*ln2, with ln2 split hi/lo for accuracy.
const __m512 f = _mm512_fnmadd_ps(_mm512_set1_ps(0x1.62e4p-1f), jj, ex);
const __m512 f1 = _mm512_fnmadd_ps(_mm512_set1_ps(0x1.7f7d1cp-20f), jj, f);
const __m512 f2 = _mm512_mul_ps(f1, f1);
const __m512 f4 = _mm512_mul_ps(f2, f2);
// Degree-5 polynomial for expm1 on the reduced range, Estrin form.
const __m512 p01 = _mm512_fmadd_ps( f1, _mm512_set1_ps(0x1.5554aep-3), _mm512_set1_ps(0x1.fffffep-2));
const __m512 p23 = _mm512_fmadd_ps( f1, _mm512_set1_ps(0x1.12287cp-7), _mm512_set1_ps(0x1.555736p-5));
const __m512 p03 = _mm512_fmadd_ps(f2, p23, p01);
const __m512 p = _mm512_fmadd_ps(f4, _mm512_set1_ps(0x1.6b55a2p-10), p03);
const __m512 p2 = _mm512_fmadd_ps(f2, p, f1);
// Build t = 2^j by writing j straight into the float exponent field.
const __m512i u = _mm512_add_epi32(_mm512_slli_epi32(i, 23), _mm512_set1_epi32(0x3f800000));
const __m512 t = _mm512_castsi512_ps(u);
// q = expm1(2x); the final ratio q/(q+2) is tanh(x).
const __m512 q = _mm512_fmadd_ps(p2, t, _mm512_sub_ps(t, one));
const __m512 y = _mm512_div_ps(q, _mm512_add_ps(q, two));
// Select the saturated value for the boring lanes.
return _mm512_mask_blend_ps(is_boring, y, boring);
}
// GeLU(x) = 0.5*x*(1 + tanh(sqrt(2/pi)*x*(1 + 0.044715*x^2))),
// sixteen lanes at a time, with ggml_vtanhf supplying the tanh.
inline static __m512
ggml_vgeluf(__m512 x)
{
const __m512 k1 = _mm512_set1_ps(1);
// inner argument: sqrt(2/pi)*x * (1 + GELU_COEF_A*x^2)
const __m512 cubic = _mm512_add_ps(k1, _mm512_mul_ps(_mm512_set1_ps(GELU_COEF_A), _mm512_mul_ps(x, x)));
const __m512 arg = _mm512_mul_ps(_mm512_mul_ps(_mm512_set1_ps(SQRT_2_OVER_PI), x), cubic);
// 0.5*x * (1 + tanh(arg))
const __m512 th = ggml_vtanhf(arg);
return _mm512_mul_ps(_mm512_mul_ps(_mm512_set1_ps(.5), x), _mm512_add_ps(k1, th));
}
#elif defined(__AVX2__)
/* Approximation for single-precision vector tanh(x) using a
branchless algorithm that offers a maximum error of 4 ULP
With fused multiply add:
108638843x off by one errors
18273656x 2 to 3 ulp errors
124x 4 ulp errors (e.g. 0.203652 [3e508a10])
1x sign flip
Without fused multiply add:
108479590x off by one errors
18209645x 2 to 3 ulp errors
70x 4 ulp errors (e.g. 0.205979 [3e52ec19])
1x sign flip
There is no support for signed zero whose sign is removed
There is no support for floating point exception handling
This code is based on the ARM Limited Optimized Routines. */
inline static __m256
ggml_vtanhf(__m256 x)
{
const __m256 abs_mask = _mm256_castsi256_ps(_mm256_set1_epi32(0x7FFFFFFF));
const __m256 one = _mm256_set1_ps(1);
const __m256 two = _mm256_set1_ps(2);
const __m256 ax = _mm256_and_ps(x, abs_mask);
const __m256 sign = _mm256_and_ps(x, _mm256_set1_ps(-0.f));
// |x| > 0x1.205966p+3 (~9.01): lane result is just sign|1.0 ("boring").
const __m256 is_boring = _mm256_cmp_ps(ax, _mm256_set1_ps(0x1.205966p+3), _CMP_GT_OQ);
const __m256 boring = _mm256_or_ps(sign, one);
// tanh(x) = expm1(2x) / (expm1(2x) + 2); start from ex = 2x.
const __m256 ex = _mm256_mul_ps(x, two);
// Round ex*log2(e) (0x1.715476p+0 = log2(e)) to the nearest integer j
// via the 0x1.8p23 shift trick.
const __m256 j = MADD256(ex, _mm256_set1_ps(0x1.715476p+0f), _mm256_set1_ps(0x1.8p23f));
const __m256 jj = _mm256_sub_ps(j, _mm256_set1_ps(0x1.8p23f));
const __m256i i = _mm256_cvttps_epi32(jj);
// Reduced argument f1 = ex - jj*ln2, with ln2 split hi/lo for accuracy.
const __m256 f = NMADD256(_mm256_set1_ps(0x1.62e4p-1f), jj, ex);
const __m256 f1 = NMADD256(_mm256_set1_ps(0x1.7f7d1cp-20f), jj, f);
const __m256 f2 = _mm256_mul_ps(f1, f1);
const __m256 f4 = _mm256_mul_ps(f2, f2);
// Degree-5 polynomial for expm1 on the reduced range, Estrin form.
const __m256 p01 = MADD256(f1, _mm256_set1_ps(0x1.5554aep-3), _mm256_set1_ps(0x1.fffffep-2));
const __m256 p23 = MADD256(f1, _mm256_set1_ps(0x1.12287cp-7), _mm256_set1_ps(0x1.555736p-5));
const __m256 p03 = MADD256(f2, p23, p01);
const __m256 p = MADD256(f4, _mm256_set1_ps(0x1.6b55a2p-10), p03);
const __m256 p2 = MADD256(f2, p, f1);
// Build t = 2^j by writing j straight into the float exponent field.
const __m256i u = _mm256_add_epi32(_mm256_slli_epi32(i, 23), _mm256_set1_epi32(0x3f800000));
const __m256 t = _mm256_castsi256_ps(u);
// q = expm1(2x); the final ratio q/(q+2) is tanh(x).
const __m256 q = MADD256(p2, t, _mm256_sub_ps(t, one));
const __m256 y = _mm256_div_ps(q, _mm256_add_ps(q, two));
// Branchless select of boring vs. computed lanes (no blend dependency).
return _mm256_or_ps(_mm256_and_ps(is_boring, boring), _mm256_andnot_ps(is_boring, y));
}
// GeLU(x) = 0.5*x*(1 + tanh(sqrt(2/pi)*x*(1 + 0.044715*x^2))),
// eight lanes at a time, with ggml_vtanhf supplying the tanh.
inline static __m256
ggml_vgeluf(__m256 x)
{
const __m256 k1 = _mm256_set1_ps(1);
// inner argument: sqrt(2/pi)*x * (1 + GELU_COEF_A*x^2)
const __m256 cubic = _mm256_add_ps(k1, _mm256_mul_ps(_mm256_set1_ps(GELU_COEF_A), _mm256_mul_ps(x, x)));
const __m256 arg = _mm256_mul_ps(_mm256_mul_ps(_mm256_set1_ps(SQRT_2_OVER_PI), x), cubic);
// 0.5*x * (1 + tanh(arg))
const __m256 th = ggml_vtanhf(arg);
return _mm256_mul_ps(_mm256_mul_ps(_mm256_set1_ps(.5), x), _mm256_add_ps(k1, th));
}
#elif defined(__SSE2__)
/* Approximation for single-precision vector tanh(x) using a
branchless algorithm that offers a maximum error of 4 ULP
Without fused multiply add:
108479590x off by one errors
18209645x 2 to 3 ulp errors
70x 4 ulp errors (e.g. 0.205979 [3e52ec19])
1x sign flip
With fused multiply add:
108638843x off by one errors
18273656x 2 to 3 ulp errors
124x 4 ulp errors (e.g. 0.203652 [3e508a10])
1x sign flip
There is no support for signed zero whose sign is removed
There is no support for floating point exception handling
This code is based on the ARM Limited Optimized Routines. */
inline static __m128
ggml_vtanhf(__m128 x)
{
const __m128 abs_mask = _mm_castsi128_ps(_mm_set1_epi32(0x7FFFFFFF));
const __m128 one = _mm_set1_ps(1);
const __m128 two = _mm_set1_ps(2);
const __m128 ax = _mm_and_ps(x, abs_mask);
const __m128 sign = _mm_and_ps(x, _mm_set1_ps(-0.f));
// |x| > 0x1.205966p+3 (~9.01): lane result is just sign|1.0 ("boring").
const __m128 is_boring = _mm_cmpgt_ps(ax, _mm_set1_ps(0x1.205966p+3));
const __m128 boring = _mm_or_ps(sign, one);
// tanh(x) = expm1(2x) / (expm1(2x) + 2); start from ex = 2x.
const __m128 ex = _mm_mul_ps(x, two);
// Round ex*log2(e) (0x1.715476p+0 = log2(e)) to the nearest integer j
// via the 0x1.8p23 shift trick.
const __m128 j = MADD128(ex, _mm_set1_ps(0x1.715476p+0f), _mm_set1_ps(0x1.8p23f));
const __m128 jj = _mm_sub_ps(j, _mm_set1_ps(0x1.8p23f));
const __m128i i = _mm_cvttps_epi32(jj);
// Reduced argument f1 = ex - jj*ln2, with ln2 split hi/lo for accuracy.
const __m128 f = NMADD128(_mm_set1_ps(0x1.62e4p-1f), jj, ex);
const __m128 f1 = NMADD128(_mm_set1_ps(0x1.7f7d1cp-20f), jj, f);
const __m128 f2 = _mm_mul_ps(f1, f1);
const __m128 f4 = _mm_mul_ps(f2, f2);
// Degree-5 polynomial for expm1 on the reduced range, Estrin form.
const __m128 p01 = MADD128(f1, _mm_set1_ps(0x1.5554aep-3), _mm_set1_ps(0x1.fffffep-2));
const __m128 p23 = MADD128(f1, _mm_set1_ps(0x1.12287cp-7), _mm_set1_ps(0x1.555736p-5));
const __m128 p03 = MADD128(f2, p23, p01);
const __m128 p = MADD128(f4, _mm_set1_ps(0x1.6b55a2p-10), p03);
const __m128 p2 = MADD128(f2, p, f1);
// Build t = 2^j by writing j straight into the float exponent field.
const __m128i u = _mm_add_epi32(_mm_slli_epi32(i, 23), _mm_set1_epi32(0x3f800000));
const __m128 t = _mm_castsi128_ps(u);
// q = expm1(2x); the final ratio q/(q+2) is tanh(x).
const __m128 q = MADD128(p2, t, _mm_sub_ps(t, one));
const __m128 y = _mm_div_ps(q, _mm_add_ps(q, two));
// SSE2 has no blend instruction: emulate select with and/andnot/or.
return _mm_or_ps(_mm_and_ps(is_boring, boring), _mm_andnot_ps(is_boring, y));
}
// GeLU(x) = 0.5*x*(1 + tanh(sqrt(2/pi)*x*(1 + 0.044715*x^2))),
// four lanes at a time, with ggml_vtanhf supplying the tanh.
inline static __m128
ggml_vgeluf(__m128 x)
{
const __m128 k1 = _mm_set1_ps(1);
// inner argument: sqrt(2/pi)*x * (1 + GELU_COEF_A*x^2)
const __m128 cubic = _mm_add_ps(k1, _mm_mul_ps(_mm_set1_ps(GELU_COEF_A), _mm_mul_ps(x, x)));
const __m128 arg = _mm_mul_ps(_mm_mul_ps(_mm_set1_ps(SQRT_2_OVER_PI), x), cubic);
// 0.5*x * (1 + tanh(arg))
const __m128 th = ggml_vtanhf(arg);
return _mm_mul_ps(_mm_mul_ps(_mm_set1_ps(.5), x), _mm_add_ps(k1, th));
}
#endif
static void ggml_vec_gelu_f32(const int n, float * y, const float * x) {
int i = 0;
#if defined(__ARM_NEON) && defined(__aarch64__)
for (; i + 3 < n; i += 4) {
vst1q_f32(y + i, ggml_vgeluf(vld1q_f32(x + i)));
}
}
#ifdef GGML_GELU_FP16
inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) {
uint16_t t;
for (int i = 0; i < n; ++i) {
if (x[i] <= -10.0f) {
y[i] = 0.0f;
} else if (x[i] >= 10.0f) {
y[i] = x[i];
} else {
ggml_fp16_t fp16 = GGML_FP32_TO_FP16(x[i]);
memcpy(&t, &fp16, sizeof(uint16_t));
y[i] = GGML_FP16_TO_FP32(ggml_table_gelu_f16[t]);
if (i < n) {
float temp_x[4] = {0};
float temp_y[4] = {0};
int rem = n - i;
for (int j = 0; j < rem; j++) {
temp_x[j] = x[i + j];
}
float32x4_t x_vec = vld1q_f32(temp_x);
float32x4_t y_vec = ggml_vgeluf(x_vec);
vst1q_f32(temp_y, y_vec);
for (int j = 0; j < rem; j++) {
y[i + j] = temp_y[j];
}
}
#elif defined(__AVX512F__) && defined(__AVX512DQ__)
for (; i + 15 < n; i += 16) {
_mm512_storeu_ps(y + i, ggml_vgeluf(_mm512_loadu_ps(x + i)));
}
if (i < n) {
__mmask16 mask = _cvtu32_mask16((1U << (n - i)) - 1);
__m512 x_vec = _mm512_maskz_loadu_ps(mask, x + i);
__m512 y_vec = ggml_vgeluf(x_vec);
_mm512_mask_storeu_ps(y + i, mask, y_vec);
}
return;
#elif defined(__AVX2__)
for (; i + 7 < n; i += 8) {
_mm256_storeu_ps(y + i, ggml_vgeluf(_mm256_loadu_ps(x + i)));
}
if (i < n) {
__m256i mask = _mm256_set_epi32(7, 6, 5, 4, 3, 2, 1, 0);
mask = _mm256_cmpgt_epi32(_mm256_set1_epi32(n - i), mask);
__m256 x_vec = _mm256_maskload_ps(x + i, mask);
__m256 y_vec = ggml_vgeluf(x_vec);
_mm256_maskstore_ps(y + i, mask, y_vec);
}
#elif defined(__SSE2__)
for (; i + 3 < n; i += 4) {
_mm_storeu_ps(y + i, ggml_vgeluf(_mm_loadu_ps(x + i)));
}
if (i < n) {
float temp_x[4] = {0};
float temp_y[4] = {0};
int rem = n - i;
for (int j = 0; j < rem; j++) {
temp_x[j] = x[i + j];
}
__m128 x_vec = _mm_loadu_ps(temp_x);
__m128 y_vec = ggml_vgeluf(x_vec);
_mm_storeu_ps(temp_y, y_vec);
for (int j = 0; j < rem; j++) {
y[i + j] = temp_y[j];
}
}
}
#else
inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) {
for (int i = 0; i < n; ++i) {
for (; i < n; ++i) {
y[i] = ggml_gelu_f32(x[i]);
}
}
#endif
}
inline static float ggml_gelu_quick_f32(float x) {
return x*(1.0f/(1.0f+expf(GELU_QUICK_COEF*x)));
}
//inline static void ggml_vec_gelu_quick_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
// const uint16_t * i16 = (const uint16_t *) x;
// for (int i = 0; i < n; ++i) {
// y[i] = ggml_table_gelu_quick_f16[i16[i]];
// }
//}
#ifdef GGML_GELU_QUICK_FP16
inline static void ggml_vec_gelu_quick_f32(const int n, float * y, const float * x) {
uint16_t t;
@ -2538,14 +2835,6 @@ inline static __m256 ggml_v_silu(__m256 x) {
#elif defined(__SSE2__) // __AVX2__ / __ARM_NEON
#if defined(__FMA__)
#define MADD128(x, y, z) _mm_fmadd_ps(x, y, z)
#define NMADD128(x, y, z) _mm_fnmadd_ps(x, y, z)
#else
#define MADD128(x, y, z) _mm_add_ps(_mm_mul_ps(x, y), z)
#define NMADD128(x, y, z) _mm_sub_ps(z, _mm_mul_ps(x, y))
#endif
// adapted from arm limited optimized routine
// the maximum error is 1.45358 plus 0.5 ulps
// numbers above 88.38 will flush to infinity
@ -3486,7 +3775,6 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
ggml_fp16_t fp16;
} u = {i};
float f = ggml_table_f32_f16[i] = GGML_COMPUTE_FP16_TO_FP32(u.fp16);
ggml_table_gelu_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_f32(f));
ggml_table_gelu_quick_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_quick_f32(f));
}