diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 88e4fb732..6c7b11ec6 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -258,7 +258,6 @@ void ggml_abort(const char * file, int line, const char * fmt, ...) { } #define GGML_DEBUG 0 -#define GGML_GELU_FP16 #define GGML_GELU_QUICK_FP16 #define GGML_SOFT_MAX_UNROLL 4 @@ -393,9 +392,6 @@ typedef double ggml_float; // global data // -// precomputed gelu table for f16 (128 KB) -static ggml_fp16_t ggml_table_gelu_f16[1 << 16]; - // precomputed quick gelu table for f16 (128 KB) static ggml_fp16_t ggml_table_gelu_quick_f16[1 << 16]; @@ -1842,6 +1838,19 @@ static inline void __lsx_f16x4_store(ggml_fp16_t * x, __m128 y) { #define GGML_F16_ARR (GGML_F16_STEP/GGML_F16_EPR) #endif +// for GeLU and SiLU +#ifdef __FMA__ +#define MADD128(x, y, z) _mm_fmadd_ps(x, y, z) +#define NMADD128(x, y, z) _mm_fnmadd_ps(x, y, z) +#define MADD256(x, y, z) _mm256_fmadd_ps(x, y, z) +#define NMADD256(x, y, z) _mm256_fnmadd_ps(x, y, z) +#else +#define MADD128(x, y, z) _mm_add_ps(_mm_mul_ps(x, y), z) +#define NMADD128(x, y, z) _mm_sub_ps(z, _mm_mul_ps(x, y)) +#define MADD256(x, y, z) _mm256_add_ps(_mm256_mul_ps(x, y), z) +#define NMADD256(x, y, z) _mm256_sub_ps(z, _mm256_mul_ps(x, y)) +#endif + // // ggml context // @@ -2323,55 +2332,343 @@ inline static void ggml_vec_sigmoid_f32 (const int n, float * y, const float * x inline static void ggml_vec_hardswish_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i] * fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f)); } inline static void ggml_vec_hardsigmoid_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f)); } -static const float GELU_COEF_A = 0.044715f; +//////////////////////////////////////////////////////////////////////////////// +// There's always room for GeLU + +static const float GELU_COEF_A = .044715f; static const float GELU_QUICK_COEF = -1.702f; -static const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; +static const float SQRT_2_OVER_PI = .79788456080286535587989211986876f; inline static float ggml_gelu_f32(float x) { - return 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); + return .5f*x*(1.f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); } -inline static void ggml_vec_gelu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) { - const uint16_t * i16 = (const uint16_t *) x; - for (int i = 0; i < n; ++i) { - y[i] = ggml_table_gelu_f16[i16[i]]; +#if defined(__ARM_NEON) && defined(__aarch64__) + +/* Approximation for single-precision vector tanh (2.58 ULP) + There is no support for signed zero whose sign is removed + There is no support for floating point exception handling + This code is based on the ARM Limited Optimized Routines. 
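+
+   The core of the approximation computes q = expm1(2*x) with an
+   exponent/mantissa range reduction and a small polynomial, then forms
+   tanh(x) = q / (q + 2). Inputs with |x| above the saturation threshold
+   return copysign(1, x), and lanes holding NaN fall back to scalar tanhf().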
+*/
+inline static float32x4_t
+ggml_vtanhf(float32x4_t x)
+{
+    const uint32x4_t ix = vreinterpretq_u32_f32(x);
+    const float32x4_t ax = vabsq_f32(x);
+    const uint32x4_t iax = vreinterpretq_u32_f32(ax);
+    const uint32x4_t sign = veorq_u32(ix, iax);
+    const uint32x4_t is_boring = vcgtq_u32(iax, vdupq_n_u32(0x41102cb3));
+    const float32x4_t boring = vreinterpretq_f32_u32(vorrq_u32(sign, vdupq_n_u32(0x3f800000)));
+    const uint32x4_t special = vcgtq_u32(iax, vdupq_n_u32(0x7f800000));
+    const float32x4_t ex = vmulq_n_f32(x, 2);
+    // lane 0: log2(e), lane 1: ln2 (hi), lane 2: ln2 (lo); lane 3 is unused
+    const float32x4_t e = { 0x1.715476p+0f, 0x1.62e4p-1f, 0x1.7f7d1cp-20f };
+    const float32x4_t j = vsubq_f32(vfmaq_laneq_f32(vdupq_n_f32(0x1.8p23f), ex, e, 0), vdupq_n_f32(0x1.8p23f));
+    const int32x4_t i = vcvtq_s32_f32(j);
+    const float32x4_t f = vfmsq_laneq_f32(ex, j, e, 1);
+    const float32x4_t f1 = vfmsq_laneq_f32(f, j, e, 2);
+    const float32x4_t f2 = vmulq_f32(f1, f1);
+    const float32x4_t f4 = vmulq_f32(f2, f2);
+    const float32x4_t p01 = vfmaq_f32(vdupq_n_f32(0x1.fffffep-2), vdupq_n_f32(0x1.5554aep-3), f1);
+    const float32x4_t p23 = vfmaq_f32(vdupq_n_f32(0x1.555736p-5), vdupq_n_f32(0x1.12287cp-7), f1);
+    const float32x4_t p03 = vfmaq_f32(p01, p23, f2);
+    const float32x4_t p = vfmaq_f32(p03, vdupq_n_f32(0x1.6b55a2p-10), f4);
+    const float32x4_t p2 = vfmaq_f32(f1, f2, p);
+    const int32x4_t u = vaddq_s32(vshlq_n_s32(i, 23), vdupq_n_s32(0x3f800000));
+    const float32x4_t t = vreinterpretq_f32_s32(u);
+    const float32x4_t q = vfmaq_f32(vsubq_f32(t, vdupq_n_f32(1)), p2, t);
+    const float32x4_t y = vdivq_f32(q, vaddq_f32(q, vdupq_n_f32(2)));
+    const float32x4_t result = vbslq_f32(is_boring, boring, y);
+    if (!vpaddd_u64(vreinterpretq_u64_u32(special)))
+        return result;
+    return (float32x4_t){ special[0] ? tanhf(x[0]) : result[0],
+                          special[1] ? tanhf(x[1]) : result[1],
+                          special[2] ? tanhf(x[2]) : result[2],
+                          special[3] ? tanhf(x[3]) : result[3] };
+}
+
+inline static float32x4_t
+ggml_vgeluf(float32x4_t x)
+{
+    const float32x4_t one = vdupq_n_f32(1);
+    const float32x4_t half = vdupq_n_f32(.5);
+    const float32x4_t coef_a = vdupq_n_f32(GELU_COEF_A);
+    const float32x4_t sqrt_2_over_pi = vdupq_n_f32(SQRT_2_OVER_PI);
+    const float32x4_t x_squared = vmulq_f32(x, x);
+    const float32x4_t ax2 = vmulq_f32(coef_a, x_squared);
+    const float32x4_t one_plus_ax2 = vaddq_f32(one, ax2);
+    const float32x4_t inner = vmulq_f32(vmulq_f32(sqrt_2_over_pi, x), one_plus_ax2);
+    const float32x4_t tanh_inner = ggml_vtanhf(inner);
+    const float32x4_t one_plus_tanh = vaddq_f32(one, tanh_inner);
+    return vmulq_f32(vmulq_f32(half, x), one_plus_tanh);
+}
+
+#elif defined(__AVX512F__) && defined(__AVX512DQ__)
+
+/* Approximation for single-precision vector tanh(x) using a
+   branchless algorithm that offers a maximum error of 4 ULP
+
+       108638843x off by one errors
+        18273656x 2 to 3 ulp errors
+             124x 4 ulp errors (e.g. 0.203652 [3e508a10])
+               1x sign flip
+
+   There is no support for signed zero whose sign is removed
+   There is no support for floating point exception handling
+   This code is based on the ARM Limited Optimized Routines.
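+
+   Unlike the NEON path there is no scalar fallback for special values:
+   NaN inputs simply propagate through the polynomial evaluation, and
+   inputs beyond the saturation threshold are blended to copysign(1, x)
+   with a mask rather than a branch.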
+*/
+inline static __m512
+ggml_vtanhf(__m512 x)
+{
+    const __m512 sign_mask = _mm512_castsi512_ps(_mm512_set1_epi32(0x80000000));
+    const __m512 one = _mm512_set1_ps(1);
+    const __m512 two = _mm512_set1_ps(2);
+    const __m512 ax = _mm512_abs_ps(x);
+    const __m512 sign = _mm512_and_ps(x, sign_mask);
+    const __mmask16 is_boring = _mm512_cmp_ps_mask(ax, _mm512_set1_ps(0x1.205966p+3), _CMP_GT_OQ);
+    const __m512 boring = _mm512_or_ps(sign, one);
+    const __m512 ex = _mm512_mul_ps(x, two);
+    const __m512 j = _mm512_fmadd_ps(ex, _mm512_set1_ps(0x1.715476p+0f), _mm512_set1_ps(0x1.8p23f));
+    const __m512 jj = _mm512_sub_ps(j, _mm512_set1_ps(0x1.8p23f));
+    const __m512i i = _mm512_cvttps_epi32(jj);
+    const __m512 f = _mm512_fnmadd_ps(_mm512_set1_ps(0x1.62e4p-1f), jj, ex);
+    const __m512 f1 = _mm512_fnmadd_ps(_mm512_set1_ps(0x1.7f7d1cp-20f), jj, f);
+    const __m512 f2 = _mm512_mul_ps(f1, f1);
+    const __m512 f4 = _mm512_mul_ps(f2, f2);
+    const __m512 p01 = _mm512_fmadd_ps(f1, _mm512_set1_ps(0x1.5554aep-3), _mm512_set1_ps(0x1.fffffep-2));
+    const __m512 p23 = _mm512_fmadd_ps(f1, _mm512_set1_ps(0x1.12287cp-7), _mm512_set1_ps(0x1.555736p-5));
+    const __m512 p03 = _mm512_fmadd_ps(f2, p23, p01);
+    const __m512 p = _mm512_fmadd_ps(f4, _mm512_set1_ps(0x1.6b55a2p-10), p03);
+    const __m512 p2 = _mm512_fmadd_ps(f2, p, f1);
+    const __m512i u = _mm512_add_epi32(_mm512_slli_epi32(i, 23), _mm512_set1_epi32(0x3f800000));
+    const __m512 t = _mm512_castsi512_ps(u);
+    const __m512 q = _mm512_fmadd_ps(p2, t, _mm512_sub_ps(t, one));
+    const __m512 y = _mm512_div_ps(q, _mm512_add_ps(q, two));
+    return _mm512_mask_blend_ps(is_boring, y, boring);
+}
+
+inline static __m512
+ggml_vgeluf(__m512 x)
+{
+    const __m512 one = _mm512_set1_ps(1);
+    const __m512 half = _mm512_set1_ps(.5);
+    const __m512 coef_a = _mm512_set1_ps(GELU_COEF_A);
+    const __m512 sqrt_2_over_pi = _mm512_set1_ps(SQRT_2_OVER_PI);
+    const __m512 x_squared = _mm512_mul_ps(x, x);
+    const __m512 ax2 = _mm512_mul_ps(coef_a, x_squared);
+    const __m512 one_plus_ax2 = _mm512_add_ps(one, ax2);
+    const __m512 inner = _mm512_mul_ps(_mm512_mul_ps(sqrt_2_over_pi, x), one_plus_ax2);
+    const __m512 tanh_inner = ggml_vtanhf(inner);
+    const __m512 one_plus_tanh = _mm512_add_ps(one, tanh_inner);
+    return _mm512_mul_ps(_mm512_mul_ps(half, x), one_plus_tanh);
+}
+
+#elif defined(__AVX2__)
+
+/* Approximation for single-precision vector tanh(x) using a
+   branchless algorithm that offers a maximum error of 4 ULP
+
+   With fused multiply add:
+
+       108638843x off by one errors
+        18273656x 2 to 3 ulp errors
+             124x 4 ulp errors (e.g. 0.203652 [3e508a10])
+               1x sign flip
+
+   Without fused multiply add:
+
+       108479590x off by one errors
+        18209645x 2 to 3 ulp errors
+              70x 4 ulp errors (e.g. 0.205979 [3e52ec19])
+               1x sign flip
+
+   There is no support for signed zero whose sign is removed
+   There is no support for floating point exception handling
+   This code is based on the ARM Limited Optimized Routines.
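+
+   The MADD256/NMADD256 macros defined above expand to
+   _mm256_fmadd_ps/_mm256_fnmadd_ps when __FMA__ is available and to a
+   separate multiply and add otherwise, which is why two sets of error
+   statistics are listed.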
+*/
+inline static __m256
+ggml_vtanhf(__m256 x)
+{
+    const __m256 abs_mask = _mm256_castsi256_ps(_mm256_set1_epi32(0x7FFFFFFF));
+    const __m256 one = _mm256_set1_ps(1);
+    const __m256 two = _mm256_set1_ps(2);
+    const __m256 ax = _mm256_and_ps(x, abs_mask);
+    const __m256 sign = _mm256_and_ps(x, _mm256_set1_ps(-0.f));
+    const __m256 is_boring = _mm256_cmp_ps(ax, _mm256_set1_ps(0x1.205966p+3), _CMP_GT_OQ);
+    const __m256 boring = _mm256_or_ps(sign, one);
+    const __m256 ex = _mm256_mul_ps(x, two);
+    const __m256 j = MADD256(ex, _mm256_set1_ps(0x1.715476p+0f), _mm256_set1_ps(0x1.8p23f));
+    const __m256 jj = _mm256_sub_ps(j, _mm256_set1_ps(0x1.8p23f));
+    const __m256i i = _mm256_cvttps_epi32(jj);
+    const __m256 f = NMADD256(_mm256_set1_ps(0x1.62e4p-1f), jj, ex);
+    const __m256 f1 = NMADD256(_mm256_set1_ps(0x1.7f7d1cp-20f), jj, f);
+    const __m256 f2 = _mm256_mul_ps(f1, f1);
+    const __m256 f4 = _mm256_mul_ps(f2, f2);
+    const __m256 p01 = MADD256(f1, _mm256_set1_ps(0x1.5554aep-3), _mm256_set1_ps(0x1.fffffep-2));
+    const __m256 p23 = MADD256(f1, _mm256_set1_ps(0x1.12287cp-7), _mm256_set1_ps(0x1.555736p-5));
+    const __m256 p03 = MADD256(f2, p23, p01);
+    const __m256 p = MADD256(f4, _mm256_set1_ps(0x1.6b55a2p-10), p03);
+    const __m256 p2 = MADD256(f2, p, f1);
+    const __m256i u = _mm256_add_epi32(_mm256_slli_epi32(i, 23), _mm256_set1_epi32(0x3f800000));
+    const __m256 t = _mm256_castsi256_ps(u);
+    const __m256 q = MADD256(p2, t, _mm256_sub_ps(t, one));
+    const __m256 y = _mm256_div_ps(q, _mm256_add_ps(q, two));
+    return _mm256_or_ps(_mm256_and_ps(is_boring, boring), _mm256_andnot_ps(is_boring, y));
+}
+
+inline static __m256
+ggml_vgeluf(__m256 x)
+{
+    const __m256 one = _mm256_set1_ps(1);
+    const __m256 half = _mm256_set1_ps(.5);
+    const __m256 coef_a = _mm256_set1_ps(GELU_COEF_A);
+    const __m256 sqrt_2_over_pi = _mm256_set1_ps(SQRT_2_OVER_PI);
+    const __m256 x_squared = _mm256_mul_ps(x, x);
+    const __m256 ax2 = _mm256_mul_ps(coef_a, x_squared);
+    const __m256 one_plus_ax2 = _mm256_add_ps(one, ax2);
+    const __m256 inner = _mm256_mul_ps(_mm256_mul_ps(sqrt_2_over_pi, x), one_plus_ax2);
+    const __m256 tanh_inner = ggml_vtanhf(inner);
+    const __m256 one_plus_tanh = _mm256_add_ps(one, tanh_inner);
+    return _mm256_mul_ps(_mm256_mul_ps(half, x), one_plus_tanh);
+}
+
+#elif defined(__SSE2__)
+
+/* Approximation for single-precision vector tanh(x) using a
+   branchless algorithm that offers a maximum error of 4 ULP
+
+   Without fused multiply add:
+
+       108479590x off by one errors
+        18209645x 2 to 3 ulp errors
+              70x 4 ulp errors (e.g. 0.205979 [3e52ec19])
+               1x sign flip
+
+   With fused multiply add:
+
+       108638843x off by one errors
+        18273656x 2 to 3 ulp errors
+             124x 4 ulp errors (e.g. 0.203652 [3e508a10])
+               1x sign flip
+
+   There is no support for signed zero whose sign is removed
+   There is no support for floating point exception handling
+   This code is based on the ARM Limited Optimized Routines.
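+
+   SSE2 has no blend instruction, so the final select between the
+   saturated result and the polynomial result is done with the usual
+   and/andnot/or mask idiom.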
*/ +inline static __m128 +ggml_vtanhf(__m128 x) +{ + const __m128 abs_mask = _mm_castsi128_ps(_mm_set1_epi32(0x7FFFFFFF)); + const __m128 one = _mm_set1_ps(1); + const __m128 two = _mm_set1_ps(2); + const __m128 ax = _mm_and_ps(x, abs_mask); + const __m128 sign = _mm_and_ps(x, _mm_set1_ps(-0.f)); + const __m128 is_boring = _mm_cmpgt_ps(ax, _mm_set1_ps(0x1.205966p+3)); + const __m128 boring = _mm_or_ps(sign, one); + const __m128 ex = _mm_mul_ps(x, two); + const __m128 j = MADD128(ex, _mm_set1_ps(0x1.715476p+0f), _mm_set1_ps(0x1.8p23f)); + const __m128 jj = _mm_sub_ps(j, _mm_set1_ps(0x1.8p23f)); + const __m128i i = _mm_cvttps_epi32(jj); + const __m128 f = NMADD128(_mm_set1_ps(0x1.62e4p-1f), jj, ex); + const __m128 f1 = NMADD128(_mm_set1_ps(0x1.7f7d1cp-20f), jj, f); + const __m128 f2 = _mm_mul_ps(f1, f1); + const __m128 f4 = _mm_mul_ps(f2, f2); + const __m128 p01 = MADD128(f1, _mm_set1_ps(0x1.5554aep-3), _mm_set1_ps(0x1.fffffep-2)); + const __m128 p23 = MADD128(f1, _mm_set1_ps(0x1.12287cp-7), _mm_set1_ps(0x1.555736p-5)); + const __m128 p03 = MADD128(f2, p23, p01); + const __m128 p = MADD128(f4, _mm_set1_ps(0x1.6b55a2p-10), p03); + const __m128 p2 = MADD128(f2, p, f1); + const __m128i u = _mm_add_epi32(_mm_slli_epi32(i, 23), _mm_set1_epi32(0x3f800000)); + const __m128 t = _mm_castsi128_ps(u); + const __m128 q = MADD128(p2, t, _mm_sub_ps(t, one)); + const __m128 y = _mm_div_ps(q, _mm_add_ps(q, two)); + return _mm_or_ps(_mm_and_ps(is_boring, boring), _mm_andnot_ps(is_boring, y)); +} + +inline static __m128 +ggml_vgeluf(__m128 x) +{ + const __m128 one = _mm_set1_ps(1); + const __m128 half = _mm_set1_ps(.5); + const __m128 coef_a = _mm_set1_ps(GELU_COEF_A); + const __m128 sqrt_2_over_pi = _mm_set1_ps(SQRT_2_OVER_PI); + const __m128 x_squared = _mm_mul_ps(x, x); + const __m128 ax2 = _mm_mul_ps(coef_a, x_squared); + const __m128 one_plus_ax2 = _mm_add_ps(one, ax2); + const __m128 inner = _mm_mul_ps(_mm_mul_ps(sqrt_2_over_pi, x), one_plus_ax2); + const __m128 tanh_inner = ggml_vtanhf(inner); + const __m128 one_plus_tanh = _mm_add_ps(one, tanh_inner); + return _mm_mul_ps(_mm_mul_ps(half, x), one_plus_tanh); +} + +#endif + +static void ggml_vec_gelu_f32(const int n, float * y, const float * x) { + int i = 0; +#if defined(__ARM_NEON) && defined(__aarch64__) + for (; i + 3 < n; i += 4) { + vst1q_f32(y + i, ggml_vgeluf(vld1q_f32(x + i))); } -} - -#ifdef GGML_GELU_FP16 -inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) { - uint16_t t; - for (int i = 0; i < n; ++i) { - if (x[i] <= -10.0f) { - y[i] = 0.0f; - } else if (x[i] >= 10.0f) { - y[i] = x[i]; - } else { - ggml_fp16_t fp16 = GGML_FP32_TO_FP16(x[i]); - memcpy(&t, &fp16, sizeof(uint16_t)); - y[i] = GGML_FP16_TO_FP32(ggml_table_gelu_f16[t]); + if (i < n) { + float temp_x[4] = {0}; + float temp_y[4] = {0}; + int rem = n - i; + for (int j = 0; j < rem; j++) { + temp_x[j] = x[i + j]; + } + float32x4_t x_vec = vld1q_f32(temp_x); + float32x4_t y_vec = ggml_vgeluf(x_vec); + vst1q_f32(temp_y, y_vec); + for (int j = 0; j < rem; j++) { + y[i + j] = temp_y[j]; + } + } +#elif defined(__AVX512F__) && defined(__AVX512DQ__) + for (; i + 15 < n; i += 16) { + _mm512_storeu_ps(y + i, ggml_vgeluf(_mm512_loadu_ps(x + i))); + } + if (i < n) { + __mmask16 mask = _cvtu32_mask16((1U << (n - i)) - 1); + __m512 x_vec = _mm512_maskz_loadu_ps(mask, x + i); + __m512 y_vec = ggml_vgeluf(x_vec); + _mm512_mask_storeu_ps(y + i, mask, y_vec); + } + return; +#elif defined(__AVX2__) + for (; i + 7 < n; i += 8) { + _mm256_storeu_ps(y + i, 
ggml_vgeluf(_mm256_loadu_ps(x + i))); + } + if (i < n) { + __m256i mask = _mm256_set_epi32(7, 6, 5, 4, 3, 2, 1, 0); + mask = _mm256_cmpgt_epi32(_mm256_set1_epi32(n - i), mask); + __m256 x_vec = _mm256_maskload_ps(x + i, mask); + __m256 y_vec = ggml_vgeluf(x_vec); + _mm256_maskstore_ps(y + i, mask, y_vec); + } +#elif defined(__SSE2__) + for (; i + 3 < n; i += 4) { + _mm_storeu_ps(y + i, ggml_vgeluf(_mm_loadu_ps(x + i))); + } + if (i < n) { + float temp_x[4] = {0}; + float temp_y[4] = {0}; + int rem = n - i; + for (int j = 0; j < rem; j++) { + temp_x[j] = x[i + j]; + } + __m128 x_vec = _mm_loadu_ps(temp_x); + __m128 y_vec = ggml_vgeluf(x_vec); + _mm_storeu_ps(temp_y, y_vec); + for (int j = 0; j < rem; j++) { + y[i + j] = temp_y[j]; } } -} #else -inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) { - for (int i = 0; i < n; ++i) { + for (; i < n; ++i) { y[i] = ggml_gelu_f32(x[i]); } -} #endif +} inline static float ggml_gelu_quick_f32(float x) { return x*(1.0f/(1.0f+expf(GELU_QUICK_COEF*x))); } -//inline static void ggml_vec_gelu_quick_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) { -// const uint16_t * i16 = (const uint16_t *) x; -// for (int i = 0; i < n; ++i) { -// y[i] = ggml_table_gelu_quick_f16[i16[i]]; -// } -//} - #ifdef GGML_GELU_QUICK_FP16 inline static void ggml_vec_gelu_quick_f32(const int n, float * y, const float * x) { uint16_t t; @@ -2538,14 +2835,6 @@ inline static __m256 ggml_v_silu(__m256 x) { #elif defined(__SSE2__) // __AVX2__ / __ARM_NEON -#if defined(__FMA__) -#define MADD128(x, y, z) _mm_fmadd_ps(x, y, z) -#define NMADD128(x, y, z) _mm_fnmadd_ps(x, y, z) -#else -#define MADD128(x, y, z) _mm_add_ps(_mm_mul_ps(x, y), z) -#define NMADD128(x, y, z) _mm_sub_ps(z, _mm_mul_ps(x, y)) -#endif - // adapted from arm limited optimized routine // the maximum error is 1.45358 plus 0.5 ulps // numbers above 88.38 will flush to infinity @@ -3486,7 +3775,6 @@ struct ggml_context * ggml_init(struct ggml_init_params params) { ggml_fp16_t fp16; } u = {i}; float f = ggml_table_f32_f16[i] = GGML_COMPUTE_FP16_TO_FP32(u.fp16); - ggml_table_gelu_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_f32(f)); ggml_table_gelu_quick_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_quick_f32(f)); }
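
A quick way to sanity-check the vectorized ggml_vec_gelu_f32() above against the
scalar ggml_gelu_f32() reference is to sweep a dense input range and compare
element-wise. The harness below is an illustrative sketch only, not part of the
patch; it assumes it is built in a translation unit that can see those static
definitions (for example by including ggml.c from a test file).

    // gelu_check.c -- illustrative accuracy check, not part of this patch
    #include <math.h>
    #include <stdio.h>
    #include <stdlib.h>

    // assumes ggml_vec_gelu_f32() and ggml_gelu_f32() are visible here,
    // e.g. via #include "ggml.c"

    int main(void) {
        enum { N = 1 << 20 };
        float * x = malloc(N * sizeof *x);
        float * y = malloc(N * sizeof *y);
        if (!x || !y) return 1;
        for (int i = 0; i < N; i++) {
            // cover the region around the tanh saturation threshold (~9) on both sides
            x[i] = -12.0f + 24.0f * (float) i / (float) (N - 1);
        }
        ggml_vec_gelu_f32(N, y, x);
        float max_err = 0.0f;
        for (int i = 0; i < N; i++) {
            float err = fabsf(y[i] - ggml_gelu_f32(x[i]));
            if (err > max_err) max_err = err;
        }
        printf("max abs error vs scalar reference: %g\n", max_err);
        free(x);
        free(y);
        return 0;
    }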