Compare commits

...

5 Commits

Author SHA1 Message Date
Justine Tunney
034fc2aa43
Merge bb668b608e into 0a683e8088 2024-10-31 11:35:30 -03:00
Kevin Gibbons
0a683e8088
server : include scheme when printing URL (#10106)
Some checks are pending
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/full-cuda.Dockerfile platforms:linux/amd64 tag:full-cuda]) (push) Waiting to run
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/full-musa.Dockerfile platforms:linux/amd64 tag:full-musa]) (push) Waiting to run
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/full.Dockerfile platforms:linux/amd64,linux/arm64 tag:full]) (push) Waiting to run
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/llama-cli-cuda.Dockerfile platforms:linux/amd64 tag:light-cuda]) (push) Waiting to run
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/llama-cli-intel.Dockerfile platforms:linux/amd64 tag:light-intel]) (push) Waiting to run
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/llama-cli-musa.Dockerfile platforms:linux/amd64 tag:light-musa]) (push) Waiting to run
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/llama-cli.Dockerfile platforms:linux/amd64,linux/arm64 tag:light]) (push) Waiting to run
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/llama-server-cuda.Dockerfile platforms:linux/amd64 tag:server-cuda]) (push) Waiting to run
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/llama-server-intel.Dockerfile platforms:linux/amd64 tag:server-intel]) (push) Waiting to run
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/llama-server-musa.Dockerfile platforms:linux/amd64 tag:server-musa]) (push) Waiting to run
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/llama-server.Dockerfile platforms:linux/amd64,linux/arm64 tag:server]) (push) Waiting to run
Nix CI / nix-eval (macos-latest) (push) Waiting to run
Nix CI / nix-eval (ubuntu-latest) (push) Waiting to run
Nix CI / nix-build (macos-latest) (push) Waiting to run
Nix CI / nix-build (ubuntu-latest) (push) Waiting to run
flake8 Lint / Lint (push) Waiting to run
2024-10-31 14:02:35 +01:00
Diego Devesa
dea5e86051
ggml : check tensor name lengths in gguf files (#10100) 2024-10-31 11:40:59 +01:00
Sergio López
1329c0a75e
kompute: add mul_mat_q4_k shader (#10097)
This is a more or less direct translation from the Metal implementation
to GLSL.

Signed-off-by: Sergio Lopez <slp@redhat.com>
2024-10-31 11:09:52 +02:00
Justine Tunney
bb668b608e
ggml : make GeLU more accurate on CPU
This change makes GeLU more accurate on amd64 and arm64 by using a tanhf
approximation that's been explicitly vectorized for avx512f, avx2, sse2,
and neon. No performance is traded away on these architectures, compared
to the 16-bit lookup table that was being used previously. The impact of
this change can be demonstrated easily with whisper, where it leads to a
measurable improvement in levenshtein distance of model output.
2024-08-18 05:47:02 -07:00
7 changed files with 561 additions and 57 deletions

View File

@ -3259,7 +3259,7 @@ int main(int argc, char ** argv) {
ctx_server.queue_tasks.terminate();
};
LOG_INF("%s: server is listening on %s:%d - starting the main loop\n", __func__, params.hostname.c_str(), params.port);
LOG_INF("%s: server is listening on http://%s:%d - starting the main loop\n", __func__, params.hostname.c_str(), params.port);
ctx_server.queue_tasks.start_loop();

View File

@ -800,6 +800,7 @@ if (GGML_KOMPUTE)
kompute-shaders/op_mul_mat_q8_0.comp
kompute-shaders/op_mul_mat_q4_0.comp
kompute-shaders/op_mul_mat_q4_1.comp
kompute-shaders/op_mul_mat_q4_k.comp
kompute-shaders/op_mul_mat_q6_k.comp
kompute-shaders/op_getrows_f32.comp
kompute-shaders/op_getrows_f16.comp
@ -833,6 +834,7 @@ if (GGML_KOMPUTE)
shaderop_mul_mat_q8_0.h
shaderop_mul_mat_q4_0.h
shaderop_mul_mat_q4_1.h
shaderop_mul_mat_q4_k.h
shaderop_mul_mat_q6_k.h
shaderop_getrows_f32.h
shaderop_getrows_f16.h

View File

@ -20,6 +20,7 @@
#include "shaderop_mul_mat_q8_0.h"
#include "shaderop_mul_mat_q4_0.h"
#include "shaderop_mul_mat_q4_1.h"
#include "shaderop_mul_mat_q4_k.h"
#include "shaderop_mul_mat_q6_k.h"
#include "shaderop_mul_mat_mat_f32.h"
#include "shaderop_getrows_f32.h"
@ -1067,6 +1068,40 @@ static void ggml_vk_mul_mat_q8_0(Args&&... args) {
ggml_vk_mul_mat_impl(spirv, "q8_0", 1/*We access blocks unaligned*/, std::forward<Args>(args)...);
}
static void ggml_vk_mul_mat_q4_k(
kp::Sequence& seq,
const std::shared_ptr<kp::Tensor>& inA,
const std::shared_ptr<kp::Tensor>& inB,
const std::shared_ptr<kp::Tensor>& out,
uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
int32_t ne00, int32_t ne01, int32_t ne02, int32_t ne10,
int32_t ne11, int32_t ne12, int32_t ne13, int32_t ne0,
int32_t ne1, int32_t r2, int32_t r3
) {
const static auto spirv = getSpirvShader(kp::shader_data::op_mul_mat_q4_k_comp_spv,
kp::shader_data::op_mul_mat_q4_k_comp_spv_len);
struct PushConstants {
uint32_t inAOff, inBOff, outOff;
int32_t ne00, ne10, ne0, ne1, ne01, ne02, ne12, r2, r3;
} pushConsts {
0, 0, 0,
ne00, ne10, ne0, ne1, ne01, ne02, ne12, r2, r3
};
std::shared_ptr<kp::Algorithm> s_algo = nullptr;
if (!komputeManager()->hasAlgorithm(__func__)) {
s_algo = komputeManager()->algorithm<uint32_t, PushConstants>(__func__, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {unsigned((ne01 + 3)/4), unsigned(ne11), unsigned(ne12) * unsigned(ne13)}, {}, {pushConsts});
} else {
s_algo = komputeManager()->getAlgorithm(__func__);
s_algo->setTensors({inA, inB, out});
s_algo->setWorkgroup({unsigned((ne01 + 3)/4), unsigned(ne11), unsigned(ne12) * unsigned(ne13)});
s_algo->setPushConstants<PushConstants>({pushConsts});
s_algo->updateDescriptors(s_kompute_context->pool.get());
}
seq.record<kp::OpAlgoDispatch>(s_algo);
}
static void ggml_vk_mul_mat_q6_k(
kp::Sequence& seq,
const std::shared_ptr<kp::Tensor>& inA,
@ -1384,6 +1419,7 @@ static bool ggml_backend_kompute_device_supports_op(ggml_backend_dev_t dev, cons
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q4_K:
return true;
default:
;
@ -1635,6 +1671,12 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1, r2, r3
);
break;
case GGML_TYPE_Q4_K:
ggml_vk_mul_mat_q4_k(
seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1, ne12/ne02, ne13/ne03
);
break;
case GGML_TYPE_Q6_K:
ggml_vk_mul_mat_q6_k(
seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,

View File

@ -306,7 +306,6 @@ void ggml_abort(const char * file, int line, const char * fmt, ...) {
}
#define GGML_DEBUG 0
#define GGML_GELU_FP16
#define GGML_GELU_QUICK_FP16
#define GGML_SOFT_MAX_UNROLL 4
@ -509,9 +508,6 @@ typedef double ggml_float;
// global data
//
// precomputed gelu table for f16 (128 KB)
static ggml_fp16_t ggml_table_gelu_f16[1 << 16];
// precomputed quick gelu table for f16 (128 KB)
static ggml_fp16_t ggml_table_gelu_quick_f16[1 << 16];
@ -1991,6 +1987,19 @@ static inline void __lsx_f16x4_store(ggml_fp16_t * x, __m128 y) {
#define GGML_F16_ARR (GGML_F16_STEP/GGML_F16_EPR)
#endif
// for GeLU and SiLU
#ifdef __FMA__
#define MADD128(x, y, z) _mm_fmadd_ps(x, y, z)
#define NMADD128(x, y, z) _mm_fnmadd_ps(x, y, z)
#define MADD256(x, y, z) _mm256_fmadd_ps(x, y, z)
#define NMADD256(x, y, z) _mm256_fnmadd_ps(x, y, z)
#else
#define MADD128(x, y, z) _mm_add_ps(_mm_mul_ps(x, y), z)
#define NMADD128(x, y, z) _mm_sub_ps(z, _mm_mul_ps(x, y))
#define MADD256(x, y, z) _mm256_add_ps(_mm256_mul_ps(x, y), z)
#define NMADD256(x, y, z) _mm256_sub_ps(z, _mm256_mul_ps(x, y))
#endif
//
// ggml object
//
@ -2567,55 +2576,343 @@ inline static void ggml_vec_hardswish_f32 (const int n, float * y, const float *
inline static void ggml_vec_hardsigmoid_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f)); }
inline static void ggml_vec_exp_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = expf(x[i]); }
static const float GELU_COEF_A = 0.044715f;
////////////////////////////////////////////////////////////////////////////////
// There's always room for GeLU
static const float GELU_COEF_A = .044715f;
static const float GELU_QUICK_COEF = -1.702f;
static const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
static const float SQRT_2_OVER_PI = .79788456080286535587989211986876f;
inline static float ggml_gelu_f32(float x) {
return 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
return .5f*x*(1.f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
}
inline static void ggml_vec_gelu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
const uint16_t * i16 = (const uint16_t *) x;
for (int i = 0; i < n; ++i) {
y[i] = ggml_table_gelu_f16[i16[i]];
#if defined(__ARM_NEON) && defined(__aarch64__)
/* Approximation for single-precision vector tanh (2.58 ULP)
There is no support for signed zero whose sign is removed
There is no support for floating point exception handling
This code is based on the ARM Limited Optimized Routines. */
inline static float32x4_t
ggml_vtanhf(float32x4_t x)
{
const uint32x4_t ix = vreinterpretq_u32_f32(x);
const float32x4_t ax = vabsq_f32(x);
const uint32x4_t iax = vreinterpretq_u32_f32(ax);
const uint32x4_t sign = veorq_u32(ix, iax);
const uint32x4_t is_boring = vcgtq_u32(iax, vdupq_n_u32(0x41102cb3));
const float32x4_t boring = vreinterpretq_f32_u32(vorrq_u32(sign, vdupq_n_u32(0x3f800000)));
const uint32x4_t special = vcgtq_u32(iax, vdupq_n_u32(0x7f800000));
const float32x4_t ex = vmulq_n_f32(x, 2);
const float32x4_t e = { 0x1.715476p+0f, 0x1.62e4p-1f, 0x1.7f7d1cp-20f };
const float32x4_t j = vsubq_f32(vfmaq_laneq_f32(vdupq_n_f32(0x1.8p23f), ex, e, 0), vdupq_n_f32(0x1.8p23f));
const int32x4_t i = vcvtq_s32_f32(j);
const float32x4_t f = vfmsq_laneq_f32(ex, j, e, 1);
const float32x4_t f1 = vfmsq_laneq_f32(f, j, e, 2);
const float32x4_t f2 = vmulq_f32(f1, f1);
const float32x4_t f4 = vmulq_f32(f2, f2);
const float32x4_t p01 = vfmaq_f32(vdupq_n_f32(0x1.fffffep-2), vdupq_n_f32(0x1.5554aep-3), f1);
const float32x4_t p23 = vfmaq_f32(vdupq_n_f32(0x1.555736p-5), vdupq_n_f32(0x1.12287cp-7), f1);
const float32x4_t p03 = vfmaq_f32(p01, p23, f2);
const float32x4_t p = vfmaq_f32(p03, vdupq_n_f32(0x1.6b55a2p-10), f4);
const float32x4_t p2 = vfmaq_f32(f1, f2, p);
const int32x4_t u = vaddq_s32(vshlq_n_s32(i, 23), vdupq_n_s32(0x3f800000));
const float32x4_t t = vreinterpretq_f32_s32(u);
const float32x4_t q = vfmaq_f32(vsubq_f32(t, vdupq_n_f32(1)), p2, t);
const float32x4_t y = vdivq_f32(q, vaddq_f32(q, vdupq_n_f32(2)));
const float32x4_t result = vbslq_f32(is_boring, boring, y);
if (!vpaddd_u64(vreinterpretq_u64_u32(special)))
return result;
return (float32x4_t){ special[0] ? tanhf(x[0]) : result[0],
special[1] ? tanhf(x[1]) : result[1],
special[2] ? tanhf(x[2]) : result[2],
special[3] ? tanhf(x[3]) : result[3] };
}
inline static float32x4_t
ggml_vgeluf(float32x4_t x)
{
const float32x4_t one = vdupq_n_f32(1);
const float32x4_t half = vdupq_n_f32(.5);
const float32x4_t coef_a = vdupq_n_f32(GELU_COEF_A);
const float32x4_t sqrt_2_over_pi = vdupq_n_f32(SQRT_2_OVER_PI);
const float32x4_t x_squared = vmulq_f32(x, x);
const float32x4_t ax2 = vmulq_f32(coef_a, x_squared);
const float32x4_t one_plus_ax2 = vaddq_f32(one, ax2);
const float32x4_t inner = vmulq_f32(vmulq_f32(sqrt_2_over_pi, x), one_plus_ax2);
const float32x4_t tanh_inner = ggml_vtanhf(inner);
const float32x4_t one_plus_tanh = vaddq_f32(one, tanh_inner);
return vmulq_f32(vmulq_f32(half, x), one_plus_tanh);
}
#elif defined(__AVX512F__) && defined(__AVX512DQ__)
/* Approximation for single-precision vector tanh(x) using a
branchless algorithm that offers a maximum error of 4 ULP
108638843x off by one errors
18273656x 2 to 3 ulp errors
124x 4 ulp erors (e.g. 0.203652 [3e508a10])
1x sign flip
There is no support for signed zero whose sign is removed
There is no support for floating point exception handling
This code is based on the ARM Limited Optimized Routines. */
inline static __m512
ggml_vtanhf(__m512 x)
{
const __m512 sign_mask = _mm512_castsi512_ps(_mm512_set1_epi32(0x80000000));
const __m512 one = _mm512_set1_ps(1);
const __m512 two = _mm512_set1_ps(2);
const __m512 ax = _mm512_abs_ps(x);
const __m512 sign = _mm512_and_ps(x, sign_mask);
const __mmask16 is_boring = _mm512_cmp_ps_mask(ax, _mm512_set1_ps(0x1.205966p+3), _CMP_GT_OQ);
const __m512 boring = _mm512_or_ps(sign, one);
const __m512 ex = _mm512_mul_ps(x, two);
const __m512 j = _mm512_fmadd_ps( ex, _mm512_set1_ps(0x1.715476p+0f), _mm512_set1_ps(0x1.8p23f));
const __m512 jj = _mm512_sub_ps(j, _mm512_set1_ps(0x1.8p23f));
const __m512i i = _mm512_cvttps_epi32(jj);
const __m512 f = _mm512_fnmadd_ps(_mm512_set1_ps(0x1.62e4p-1f), jj, ex);
const __m512 f1 = _mm512_fnmadd_ps(_mm512_set1_ps(0x1.7f7d1cp-20f), jj, f);
const __m512 f2 = _mm512_mul_ps(f1, f1);
const __m512 f4 = _mm512_mul_ps(f2, f2);
const __m512 p01 = _mm512_fmadd_ps( f1, _mm512_set1_ps(0x1.5554aep-3), _mm512_set1_ps(0x1.fffffep-2));
const __m512 p23 = _mm512_fmadd_ps( f1, _mm512_set1_ps(0x1.12287cp-7), _mm512_set1_ps(0x1.555736p-5));
const __m512 p03 = _mm512_fmadd_ps(f2, p23, p01);
const __m512 p = _mm512_fmadd_ps(f4, _mm512_set1_ps(0x1.6b55a2p-10), p03);
const __m512 p2 = _mm512_fmadd_ps(f2, p, f1);
const __m512i u = _mm512_add_epi32(_mm512_slli_epi32(i, 23), _mm512_set1_epi32(0x3f800000));
const __m512 t = _mm512_castsi512_ps(u);
const __m512 q = _mm512_fmadd_ps(p2, t, _mm512_sub_ps(t, one));
const __m512 y = _mm512_div_ps(q, _mm512_add_ps(q, two));
return _mm512_mask_blend_ps(is_boring, y, boring);
}
inline static __m512
ggml_vgeluf(__m512 x)
{
const __m512 one = _mm512_set1_ps(1);
const __m512 half = _mm512_set1_ps(.5);
const __m512 coef_a = _mm512_set1_ps(GELU_COEF_A);
const __m512 sqrt_2_over_pi = _mm512_set1_ps(SQRT_2_OVER_PI);
const __m512 x_squared = _mm512_mul_ps(x, x);
const __m512 ax2 = _mm512_mul_ps(coef_a, x_squared);
const __m512 one_plus_ax2 = _mm512_add_ps(one, ax2);
const __m512 inner = _mm512_mul_ps(_mm512_mul_ps(sqrt_2_over_pi, x), one_plus_ax2);
const __m512 tanh_inner = ggml_vtanhf(inner);
const __m512 one_plus_tanh = _mm512_add_ps(one, tanh_inner);
return _mm512_mul_ps(_mm512_mul_ps(half, x), one_plus_tanh);
}
#elif defined(__AVX2__)
/* Approximation for single-precision vector tanh(x) using a
branchless algorithm that offers a maximum error of 4 ULP
With fused multiply add:
108638843x off by one errors
18273656x 2 to 3 ulp errors
124x 4 ulp erors (e.g. 0.203652 [3e508a10])
1x sign flip
Without fused multiply add:
108479590x off by one errors
18209645x 2 to 3 ulp errors
70x 4 ulp errors (e.g. 0.205979 [3e52ec19])
1x sign flip
There is no support for signed zero whose sign is removed
There is no support for floating point exception handling
This code is based on the ARM Limited Optimized Routines. */
inline static __m256
ggml_vtanhf(__m256 x)
{
const __m256 abs_mask = _mm256_castsi256_ps(_mm256_set1_epi32(0x7FFFFFFF));
const __m256 one = _mm256_set1_ps(1);
const __m256 two = _mm256_set1_ps(2);
const __m256 ax = _mm256_and_ps(x, abs_mask);
const __m256 sign = _mm256_and_ps(x, _mm256_set1_ps(-0.f));
const __m256 is_boring = _mm256_cmp_ps(ax, _mm256_set1_ps(0x1.205966p+3), _CMP_GT_OQ);
const __m256 boring = _mm256_or_ps(sign, one);
const __m256 ex = _mm256_mul_ps(x, two);
const __m256 j = MADD256(ex, _mm256_set1_ps(0x1.715476p+0f), _mm256_set1_ps(0x1.8p23f));
const __m256 jj = _mm256_sub_ps(j, _mm256_set1_ps(0x1.8p23f));
const __m256i i = _mm256_cvttps_epi32(jj);
const __m256 f = NMADD256(_mm256_set1_ps(0x1.62e4p-1f), jj, ex);
const __m256 f1 = NMADD256(_mm256_set1_ps(0x1.7f7d1cp-20f), jj, f);
const __m256 f2 = _mm256_mul_ps(f1, f1);
const __m256 f4 = _mm256_mul_ps(f2, f2);
const __m256 p01 = MADD256(f1, _mm256_set1_ps(0x1.5554aep-3), _mm256_set1_ps(0x1.fffffep-2));
const __m256 p23 = MADD256(f1, _mm256_set1_ps(0x1.12287cp-7), _mm256_set1_ps(0x1.555736p-5));
const __m256 p03 = MADD256(f2, p23, p01);
const __m256 p = MADD256(f4, _mm256_set1_ps(0x1.6b55a2p-10), p03);
const __m256 p2 = MADD256(f2, p, f1);
const __m256i u = _mm256_add_epi32(_mm256_slli_epi32(i, 23), _mm256_set1_epi32(0x3f800000));
const __m256 t = _mm256_castsi256_ps(u);
const __m256 q = MADD256(p2, t, _mm256_sub_ps(t, one));
const __m256 y = _mm256_div_ps(q, _mm256_add_ps(q, two));
return _mm256_or_ps(_mm256_and_ps(is_boring, boring), _mm256_andnot_ps(is_boring, y));
}
inline static __m256
ggml_vgeluf(__m256 x)
{
const __m256 one = _mm256_set1_ps(1);
const __m256 half = _mm256_set1_ps(.5);
const __m256 coef_a = _mm256_set1_ps(GELU_COEF_A);
const __m256 sqrt_2_over_pi = _mm256_set1_ps(SQRT_2_OVER_PI);
const __m256 x_squared = _mm256_mul_ps(x, x);
const __m256 ax2 = _mm256_mul_ps(coef_a, x_squared);
const __m256 one_plus_ax2 = _mm256_add_ps(one, ax2);
const __m256 inner = _mm256_mul_ps(_mm256_mul_ps(sqrt_2_over_pi, x), one_plus_ax2);
const __m256 tanh_inner = ggml_vtanhf(inner);
const __m256 one_plus_tanh = _mm256_add_ps(one, tanh_inner);
return _mm256_mul_ps(_mm256_mul_ps(half, x), one_plus_tanh);
}
#elif defined(__SSE2__)
/* Approximation for single-precision vector tanh(x) using a
branchless algorithm that offers a maximum error of 4 ULP
Without fused multiply add:
108479590x off by one errors
18209645x 2 to 3 ulp errors
70x 4 ulp errors (e.g. 0.205979 [3e52ec19])
1x sign flip
With fused multiply add:
108638843x off by one errors
18273656x 2 to 3 ulp errors
124x 4 ulp erors (e.g. 0.203652 [3e508a10])
1x sign flip
There is no support for signed zero whose sign is removed
There is no support for floating point exception handling
This code is based on the ARM Limited Optimized Routines. */
inline static __m128
ggml_vtanhf(__m128 x)
{
const __m128 abs_mask = _mm_castsi128_ps(_mm_set1_epi32(0x7FFFFFFF));
const __m128 one = _mm_set1_ps(1);
const __m128 two = _mm_set1_ps(2);
const __m128 ax = _mm_and_ps(x, abs_mask);
const __m128 sign = _mm_and_ps(x, _mm_set1_ps(-0.f));
const __m128 is_boring = _mm_cmpgt_ps(ax, _mm_set1_ps(0x1.205966p+3));
const __m128 boring = _mm_or_ps(sign, one);
const __m128 ex = _mm_mul_ps(x, two);
const __m128 j = MADD128(ex, _mm_set1_ps(0x1.715476p+0f), _mm_set1_ps(0x1.8p23f));
const __m128 jj = _mm_sub_ps(j, _mm_set1_ps(0x1.8p23f));
const __m128i i = _mm_cvttps_epi32(jj);
const __m128 f = NMADD128(_mm_set1_ps(0x1.62e4p-1f), jj, ex);
const __m128 f1 = NMADD128(_mm_set1_ps(0x1.7f7d1cp-20f), jj, f);
const __m128 f2 = _mm_mul_ps(f1, f1);
const __m128 f4 = _mm_mul_ps(f2, f2);
const __m128 p01 = MADD128(f1, _mm_set1_ps(0x1.5554aep-3), _mm_set1_ps(0x1.fffffep-2));
const __m128 p23 = MADD128(f1, _mm_set1_ps(0x1.12287cp-7), _mm_set1_ps(0x1.555736p-5));
const __m128 p03 = MADD128(f2, p23, p01);
const __m128 p = MADD128(f4, _mm_set1_ps(0x1.6b55a2p-10), p03);
const __m128 p2 = MADD128(f2, p, f1);
const __m128i u = _mm_add_epi32(_mm_slli_epi32(i, 23), _mm_set1_epi32(0x3f800000));
const __m128 t = _mm_castsi128_ps(u);
const __m128 q = MADD128(p2, t, _mm_sub_ps(t, one));
const __m128 y = _mm_div_ps(q, _mm_add_ps(q, two));
return _mm_or_ps(_mm_and_ps(is_boring, boring), _mm_andnot_ps(is_boring, y));
}
inline static __m128
ggml_vgeluf(__m128 x)
{
const __m128 one = _mm_set1_ps(1);
const __m128 half = _mm_set1_ps(.5);
const __m128 coef_a = _mm_set1_ps(GELU_COEF_A);
const __m128 sqrt_2_over_pi = _mm_set1_ps(SQRT_2_OVER_PI);
const __m128 x_squared = _mm_mul_ps(x, x);
const __m128 ax2 = _mm_mul_ps(coef_a, x_squared);
const __m128 one_plus_ax2 = _mm_add_ps(one, ax2);
const __m128 inner = _mm_mul_ps(_mm_mul_ps(sqrt_2_over_pi, x), one_plus_ax2);
const __m128 tanh_inner = ggml_vtanhf(inner);
const __m128 one_plus_tanh = _mm_add_ps(one, tanh_inner);
return _mm_mul_ps(_mm_mul_ps(half, x), one_plus_tanh);
}
#endif
static void ggml_vec_gelu_f32(const int n, float * y, const float * x) {
int i = 0;
#if defined(__ARM_NEON) && defined(__aarch64__)
for (; i + 3 < n; i += 4) {
vst1q_f32(y + i, ggml_vgeluf(vld1q_f32(x + i)));
}
}
#ifdef GGML_GELU_FP16
inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) {
uint16_t t;
for (int i = 0; i < n; ++i) {
if (x[i] <= -10.0f) {
y[i] = 0.0f;
} else if (x[i] >= 10.0f) {
y[i] = x[i];
} else {
ggml_fp16_t fp16 = GGML_FP32_TO_FP16(x[i]);
memcpy(&t, &fp16, sizeof(uint16_t));
y[i] = GGML_FP16_TO_FP32(ggml_table_gelu_f16[t]);
if (i < n) {
float temp_x[4] = {0};
float temp_y[4] = {0};
int rem = n - i;
for (int j = 0; j < rem; j++) {
temp_x[j] = x[i + j];
}
float32x4_t x_vec = vld1q_f32(temp_x);
float32x4_t y_vec = ggml_vgeluf(x_vec);
vst1q_f32(temp_y, y_vec);
for (int j = 0; j < rem; j++) {
y[i + j] = temp_y[j];
}
}
#elif defined(__AVX512F__) && defined(__AVX512DQ__)
for (; i + 15 < n; i += 16) {
_mm512_storeu_ps(y + i, ggml_vgeluf(_mm512_loadu_ps(x + i)));
}
if (i < n) {
__mmask16 mask = _cvtu32_mask16((1U << (n - i)) - 1);
__m512 x_vec = _mm512_maskz_loadu_ps(mask, x + i);
__m512 y_vec = ggml_vgeluf(x_vec);
_mm512_mask_storeu_ps(y + i, mask, y_vec);
}
return;
#elif defined(__AVX2__)
for (; i + 7 < n; i += 8) {
_mm256_storeu_ps(y + i, ggml_vgeluf(_mm256_loadu_ps(x + i)));
}
if (i < n) {
__m256i mask = _mm256_set_epi32(7, 6, 5, 4, 3, 2, 1, 0);
mask = _mm256_cmpgt_epi32(_mm256_set1_epi32(n - i), mask);
__m256 x_vec = _mm256_maskload_ps(x + i, mask);
__m256 y_vec = ggml_vgeluf(x_vec);
_mm256_maskstore_ps(y + i, mask, y_vec);
}
#elif defined(__SSE2__)
for (; i + 3 < n; i += 4) {
_mm_storeu_ps(y + i, ggml_vgeluf(_mm_loadu_ps(x + i)));
}
if (i < n) {
float temp_x[4] = {0};
float temp_y[4] = {0};
int rem = n - i;
for (int j = 0; j < rem; j++) {
temp_x[j] = x[i + j];
}
__m128 x_vec = _mm_loadu_ps(temp_x);
__m128 y_vec = ggml_vgeluf(x_vec);
_mm_storeu_ps(temp_y, y_vec);
for (int j = 0; j < rem; j++) {
y[i + j] = temp_y[j];
}
}
}
#else
inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) {
for (int i = 0; i < n; ++i) {
for (; i < n; ++i) {
y[i] = ggml_gelu_f32(x[i]);
}
}
#endif
}
inline static float ggml_gelu_quick_f32(float x) {
return x*(1.0f/(1.0f+expf(GELU_QUICK_COEF*x)));
}
//inline static void ggml_vec_gelu_quick_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
// const uint16_t * i16 = (const uint16_t *) x;
// for (int i = 0; i < n; ++i) {
// y[i] = ggml_table_gelu_quick_f16[i16[i]];
// }
//}
#ifdef GGML_GELU_QUICK_FP16
inline static void ggml_vec_gelu_quick_f32(const int n, float * y, const float * x) {
uint16_t t;
@ -2782,14 +3079,6 @@ inline static __m256 ggml_v_silu(__m256 x) {
#elif defined(__SSE2__) // __AVX2__ / __ARM_NEON
#if defined(__FMA__)
#define MADD128(x, y, z) _mm_fmadd_ps(x, y, z)
#define NMADD128(x, y, z) _mm_fnmadd_ps(x, y, z)
#else
#define MADD128(x, y, z) _mm_add_ps(_mm_mul_ps(x, y), z)
#define NMADD128(x, y, z) _mm_sub_ps(z, _mm_mul_ps(x, y))
#endif
// adapted from arm limited optimized routine
// the maximum error is 1.45358 plus 0.5 ulps
// numbers above 88.38 will flush to infinity
@ -3831,7 +4120,6 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
ggml_fp16_t fp16;
} u = {i};
float f = ggml_table_f32_f16[i] = GGML_COMPUTE_FP16_TO_FP32(u.fp16);
ggml_table_gelu_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_f32(f));
ggml_table_gelu_quick_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_quick_f32(f));
}
@ -22102,18 +22390,46 @@ static size_t gguf_type_size(enum gguf_type type) {
return GGUF_TYPE_SIZE[type];
}
static void gguf_tensor_info_sanitize(struct gguf_tensor_info * info) {
GGML_ASSERT(info->n_dims <= GGML_MAX_DIMS);
GGML_ASSERT(0 <= info->type && info->type < GGML_TYPE_COUNT);
static bool gguf_tensor_info_sanitize(struct gguf_tensor_info * info) {
if (info->n_dims > GGML_MAX_DIMS) {
fprintf(stderr, "%s: invalid number of dimensions (%" PRIu32 ")\n", __func__, info->n_dims);
return false;
}
if (info->type < 0 || info->type >= GGML_TYPE_COUNT) {
fprintf(stderr, "%s: invalid type (%d)\n", __func__, info->type);
return false;
}
if (strlen(info->name.data) >= GGML_MAX_NAME) {
fprintf(stderr, "%s: tensor '%s' name is too long\n", __func__, info->name.data);
return false;
}
for (uint32_t i = 0; i < info->n_dims; ++i) {
GGML_ASSERT(info->ne[i] > 0);
if (info->ne[i] <= 0) {
fprintf(stderr, "%s: invalid number of elements (%" PRIu64 ")\n", __func__, info->ne[i]);
return false;
}
}
// prevent overflow for total number of elements
GGML_ASSERT(INT64_MAX/info->ne[1] > info->ne[0]);
GGML_ASSERT(INT64_MAX/info->ne[2] > info->ne[0]*info->ne[1]);
GGML_ASSERT(INT64_MAX/info->ne[3] > info->ne[0]*info->ne[1]*info->ne[2]);
if (INT64_MAX/info->ne[1] <= info->ne[0]) {
fprintf(stderr, "%s: invalid number of elements (%" PRIu64 ")\n", __func__, info->ne[1]);
return false;
}
if (INT64_MAX/info->ne[2] <= info->ne[0]*info->ne[1]) {
fprintf(stderr, "%s: invalid number of elements (%" PRIu64 ")\n", __func__, info->ne[2]);
return false;
}
if (INT64_MAX/info->ne[3] <= info->ne[0]*info->ne[1]*info->ne[2]) {
fprintf(stderr, "%s: invalid number of elements (%" PRIu64 ")\n", __func__, info->ne[3]);
return false;
}
return true;
}
static bool gguf_fread_el(FILE * file, void * dst, size_t size, size_t * offset) {
@ -22414,8 +22730,7 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
ok = ok && gguf_fread_el (file, &info->type, sizeof(info->type), &offset);
ok = ok && gguf_fread_el (file, &info->offset, sizeof(info->offset), &offset);
// TODO: return an error instead of crashing with GGML_ASSERT
gguf_tensor_info_sanitize(info);
ok = ok && gguf_tensor_info_sanitize(info);
// make sure there is no duplicated tensor names
for (uint64_t j = 0; j < i && ok; ++j) {

View File

@ -15,6 +15,7 @@
#define TWOPI_F 6.283185307179586f
#define QK_K 256
#define K_SCALE_SIZE 12
#define u8BufToU16(buf, idx) (((uint16_t(buf[idx + 1]) << 8)) | buf[idx])
#define u8BufToFloat16(buf, idx) uint16BitsToHalf u8BufToU16(buf, idx)
@ -64,6 +65,14 @@ mat4 dequantize_q4_1(const block_q4_1 xb, uint il) {
return reg;
}
#define sizeof_block_q4_k 144
struct block_q4_k {
float16_t d;
float16_t dmin;
uint8_t scales[K_SCALE_SIZE];
uint8_t qs[QK_K/2];
};
#define sizeof_block_q6_k 210
struct block_q6_k {
uint8_t ql[QK_K/2]; // quants, lower 4 bits

View File

@ -0,0 +1,133 @@
#version 450
#include "common.comp"
#define N_DST 4
#define SIZE_OF_BLOCK sizeof_block_q4_k
layout(local_size_x = 4) in;
layout(local_size_y = 8) in;
layout(local_size_z = 1) in;
layout (binding = 0) readonly buffer tensorInA { block_q4_k inA[]; };
layout (binding = 1) readonly buffer tensorInB { float inB[]; };
layout (binding = 2) writeonly buffer tensorOut { float out_[]; };
layout (push_constant) uniform parameter {
uint inAOff;
uint inBOff;
uint outOff;
int ne00;
int ne10;
int ne0;
int ne1;
int ne01;
int ne02;
int ne12;
int r2;
int r3;
} pcs;
void main() {
const uint16_t kmask1 = uint16_t(0x3f3f);
const uint16_t kmask2 = uint16_t(0x0f0f);
const uint16_t kmask3 = uint16_t(0xc0c0);
const uint ix = gl_SubgroupInvocationID/8; // 0...3
const uint it = gl_SubgroupInvocationID%8; // 0...7
const uint iq = it/4; // 0 or 1
const uint ir = it%4; // 0...3
const uint nb = pcs.ne00/QK_K;
const uint r0 = gl_WorkGroupID.x;
const uint r1 = gl_WorkGroupID.y;
const uint im = gl_WorkGroupID.z;
const uint first_row = r0 * N_DST;
const uint ib_row = first_row * nb;
const uint i12 = im%pcs.ne12;
const uint i13 = im/pcs.ne12;
const uint offset0 = (i12/pcs.r2)*(nb*pcs.ne01) + (i13/pcs.r3)*(nb*pcs.ne01*pcs.ne02);
const uint xblk = ib_row + offset0 + pcs.inAOff;
const uint y = r1*pcs.ne10 + im*pcs.ne00*pcs.ne1 + pcs.inBOff;
float yl[16];
float yh[16];
float sumf[N_DST] = {0.f, 0.f, 0.f, 0.f};
float all_sum = 0.f;
uint y4 = y + ix * QK_K + 64 * iq + 8 * ir;
for (uint ib = ix; ib < nb; ib += 4) {
const uint blk_idx = ib + xblk;
float sumy[4] = {0.f, 0.f, 0.f, 0.f};
for (int i = 0; i < 8; ++i) {
yl[i+0] = inB[y4+i+ 0]; sumy[0] += yl[i+0];
yl[i+8] = inB[y4+i+ 32]; sumy[1] += yl[i+8];
yh[i+0] = inB[y4+i+128]; sumy[2] += yh[i+0];
yh[i+8] = inB[y4+i+160]; sumy[3] += yh[i+8];
}
for (int row = 0; row < N_DST; row++) {
uint row_idx = row * nb;
uint16_t sc_0 = u8BufToU16(inA[blk_idx + row_idx].scales, iq * 2 + 0);
uint16_t sc_1 = u8BufToU16(inA[blk_idx + row_idx].scales, iq * 2 + 2);
uint16_t sc_2 = u8BufToU16(inA[blk_idx + row_idx].scales, iq * 2 + 4);
uint16_t sc_3 = u8BufToU16(inA[blk_idx + row_idx].scales, iq * 2 + 6);
uint16_t sc_4 = u8BufToU16(inA[blk_idx + row_idx].scales, iq * 2 + 8);
uint16_t sc16[4];
sc16[0] = sc_0 & kmask1;
sc16[1] = sc_2 & kmask1;
sc16[2] = ((sc_4 >> 0) & kmask2) | ((sc_0 & kmask3) >> 2);
sc16[3] = ((sc_4 >> 4) & kmask2) | ((sc_2 & kmask3) >> 2);
float acc1[4] = {0.f, 0.f, 0.f, 0.f};
float acc2[4] = {0.f, 0.f, 0.f, 0.f};
for (int i = 0; i < 8; i += 2) {
uint16_t q1 = u8BufToU16(inA[blk_idx + row_idx].qs, 32 * iq + 8 * ir + i);
uint16_t q2 = u8BufToU16(inA[blk_idx + row_idx].qs, 64 + 32 * iq + 8 * ir + i);
acc1[0] += yl[i+0] * (q1 & 0x000F);
acc1[1] += yl[i+1] * (q1 & 0x0F00);
acc1[2] += yl[i+8] * (q1 & 0x00F0);
acc1[3] += yl[i+9] * (q1 & 0xF000);
acc2[0] += yh[i+0] * (q2 & 0x000F);
acc2[1] += yh[i+1] * (q2 & 0x0F00);
acc2[2] += yh[i+8] * (q2 & 0x00F0);
acc2[3] += yh[i+9] * (q2 & 0xF000);
}
uint8_t sc8_0 = uint8_t(sc16[0] & 0xFF);
uint8_t sc8_1 = uint8_t(sc16[0] >> 8 );
uint8_t sc8_2 = uint8_t(sc16[1] & 0xFF);
uint8_t sc8_3 = uint8_t(sc16[1] >> 8 );
uint8_t sc8_4 = uint8_t(sc16[2] & 0xFF);
uint8_t sc8_5 = uint8_t(sc16[2] >> 8 );
uint8_t sc8_6 = uint8_t(sc16[3] & 0xFF);
uint8_t sc8_7 = uint8_t(sc16[3] >> 8 );
float dall = float(inA[blk_idx + row_idx].d);
float dmin = float(inA[blk_idx + row_idx].dmin);
sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc8_0 +
(acc1[2] + 1.f/256.f * acc1[3]) * sc8_1 * 1.f/16.f +
(acc2[0] + 1.f/256.f * acc2[1]) * sc8_4 +
(acc2[2] + 1.f/256.f * acc2[3]) * sc8_5 * 1.f/16.f) -
dmin * (sumy[0] * sc8_2 + sumy[1] * sc8_3 + sumy[2] * sc8_6 + sumy[3] * sc8_7);
}
y4 += 4 * QK_K;
}
for (int row = 0; row < N_DST; ++row) {
all_sum = subgroupAdd(sumf[row]);
if (subgroupElect()) {
out_[r1*pcs.ne0 + im*pcs.ne0*pcs.ne1 + first_row + row + pcs.outOff] = all_sum;
}
}
}

View File

@ -4273,8 +4273,11 @@ struct llama_model_loader {
llama_tensor_weight(const llama_file * file, uint16_t idx, const char * name, const struct gguf_context * gguf_ctx, ggml_tensor * tensor) : idx(idx), tensor(tensor) {
const int tensor_idx = gguf_find_tensor(gguf_ctx, name);
offs = gguf_get_data_offset(gguf_ctx) + gguf_get_tensor_offset(gguf_ctx, tensor_idx);
if (tensor_idx < 0) {
throw std::runtime_error(format("tensor '%s' not found in the model", name));
}
offs = gguf_get_data_offset(gguf_ctx) + gguf_get_tensor_offset(gguf_ctx, tensor_idx);
if (offs + ggml_nbytes(tensor) < offs || offs + ggml_nbytes(tensor) > file->size) {
throw std::runtime_error(format("tensor '%s' data is not within the file bounds, model is corrupted or incomplete", name));
}
@ -7426,7 +7429,7 @@ static bool llm_load_tensors(
if (flags & llama_model_loader::TENSOR_NOT_REQUIRED) {
return nullptr;
}
throw std::runtime_error(format("missing tensor %s", tn.str().c_str()));
throw std::runtime_error(format("missing tensor '%s'", tn.str().c_str()));
}
// some models use the token embedding tensor as the output, but since these are used in different layers and with different ops