mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-11-13 14:29:52 +00:00
Compare commits
5 Commits
698e28cb12
...
034fc2aa43
Author | SHA1 | Date | |
---|---|---|---|
|
034fc2aa43 | ||
|
0a683e8088 | ||
|
dea5e86051 | ||
|
1329c0a75e | ||
|
bb668b608e |
@ -3259,7 +3259,7 @@ int main(int argc, char ** argv) {
|
||||
ctx_server.queue_tasks.terminate();
|
||||
};
|
||||
|
||||
LOG_INF("%s: server is listening on %s:%d - starting the main loop\n", __func__, params.hostname.c_str(), params.port);
|
||||
LOG_INF("%s: server is listening on http://%s:%d - starting the main loop\n", __func__, params.hostname.c_str(), params.port);
|
||||
|
||||
ctx_server.queue_tasks.start_loop();
|
||||
|
||||
|
@ -800,6 +800,7 @@ if (GGML_KOMPUTE)
|
||||
kompute-shaders/op_mul_mat_q8_0.comp
|
||||
kompute-shaders/op_mul_mat_q4_0.comp
|
||||
kompute-shaders/op_mul_mat_q4_1.comp
|
||||
kompute-shaders/op_mul_mat_q4_k.comp
|
||||
kompute-shaders/op_mul_mat_q6_k.comp
|
||||
kompute-shaders/op_getrows_f32.comp
|
||||
kompute-shaders/op_getrows_f16.comp
|
||||
@ -833,6 +834,7 @@ if (GGML_KOMPUTE)
|
||||
shaderop_mul_mat_q8_0.h
|
||||
shaderop_mul_mat_q4_0.h
|
||||
shaderop_mul_mat_q4_1.h
|
||||
shaderop_mul_mat_q4_k.h
|
||||
shaderop_mul_mat_q6_k.h
|
||||
shaderop_getrows_f32.h
|
||||
shaderop_getrows_f16.h
|
||||
|
@ -20,6 +20,7 @@
|
||||
#include "shaderop_mul_mat_q8_0.h"
|
||||
#include "shaderop_mul_mat_q4_0.h"
|
||||
#include "shaderop_mul_mat_q4_1.h"
|
||||
#include "shaderop_mul_mat_q4_k.h"
|
||||
#include "shaderop_mul_mat_q6_k.h"
|
||||
#include "shaderop_mul_mat_mat_f32.h"
|
||||
#include "shaderop_getrows_f32.h"
|
||||
@ -1067,6 +1068,40 @@ static void ggml_vk_mul_mat_q8_0(Args&&... args) {
|
||||
ggml_vk_mul_mat_impl(spirv, "q8_0", 1/*We access blocks unaligned*/, std::forward<Args>(args)...);
|
||||
}
|
||||
|
||||
static void ggml_vk_mul_mat_q4_k(
|
||||
kp::Sequence& seq,
|
||||
const std::shared_ptr<kp::Tensor>& inA,
|
||||
const std::shared_ptr<kp::Tensor>& inB,
|
||||
const std::shared_ptr<kp::Tensor>& out,
|
||||
uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
|
||||
int32_t ne00, int32_t ne01, int32_t ne02, int32_t ne10,
|
||||
int32_t ne11, int32_t ne12, int32_t ne13, int32_t ne0,
|
||||
int32_t ne1, int32_t r2, int32_t r3
|
||||
) {
|
||||
const static auto spirv = getSpirvShader(kp::shader_data::op_mul_mat_q4_k_comp_spv,
|
||||
kp::shader_data::op_mul_mat_q4_k_comp_spv_len);
|
||||
|
||||
struct PushConstants {
|
||||
uint32_t inAOff, inBOff, outOff;
|
||||
int32_t ne00, ne10, ne0, ne1, ne01, ne02, ne12, r2, r3;
|
||||
} pushConsts {
|
||||
0, 0, 0,
|
||||
ne00, ne10, ne0, ne1, ne01, ne02, ne12, r2, r3
|
||||
};
|
||||
|
||||
std::shared_ptr<kp::Algorithm> s_algo = nullptr;
|
||||
if (!komputeManager()->hasAlgorithm(__func__)) {
|
||||
s_algo = komputeManager()->algorithm<uint32_t, PushConstants>(__func__, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {unsigned((ne01 + 3)/4), unsigned(ne11), unsigned(ne12) * unsigned(ne13)}, {}, {pushConsts});
|
||||
} else {
|
||||
s_algo = komputeManager()->getAlgorithm(__func__);
|
||||
s_algo->setTensors({inA, inB, out});
|
||||
s_algo->setWorkgroup({unsigned((ne01 + 3)/4), unsigned(ne11), unsigned(ne12) * unsigned(ne13)});
|
||||
s_algo->setPushConstants<PushConstants>({pushConsts});
|
||||
s_algo->updateDescriptors(s_kompute_context->pool.get());
|
||||
}
|
||||
seq.record<kp::OpAlgoDispatch>(s_algo);
|
||||
}
|
||||
|
||||
static void ggml_vk_mul_mat_q6_k(
|
||||
kp::Sequence& seq,
|
||||
const std::shared_ptr<kp::Tensor>& inA,
|
||||
@ -1384,6 +1419,7 @@ static bool ggml_backend_kompute_device_supports_op(ggml_backend_dev_t dev, cons
|
||||
case GGML_TYPE_Q8_0:
|
||||
case GGML_TYPE_Q4_0:
|
||||
case GGML_TYPE_Q4_1:
|
||||
case GGML_TYPE_Q4_K:
|
||||
return true;
|
||||
default:
|
||||
;
|
||||
@ -1635,6 +1671,12 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
|
||||
ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1, r2, r3
|
||||
);
|
||||
break;
|
||||
case GGML_TYPE_Q4_K:
|
||||
ggml_vk_mul_mat_q4_k(
|
||||
seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
|
||||
ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1, ne12/ne02, ne13/ne03
|
||||
);
|
||||
break;
|
||||
case GGML_TYPE_Q6_K:
|
||||
ggml_vk_mul_mat_q6_k(
|
||||
seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
|
||||
|
423
ggml/src/ggml.c
423
ggml/src/ggml.c
@ -306,7 +306,6 @@ void ggml_abort(const char * file, int line, const char * fmt, ...) {
|
||||
}
|
||||
|
||||
#define GGML_DEBUG 0
|
||||
#define GGML_GELU_FP16
|
||||
#define GGML_GELU_QUICK_FP16
|
||||
|
||||
#define GGML_SOFT_MAX_UNROLL 4
|
||||
@ -509,9 +508,6 @@ typedef double ggml_float;
|
||||
// global data
|
||||
//
|
||||
|
||||
// precomputed gelu table for f16 (128 KB)
|
||||
static ggml_fp16_t ggml_table_gelu_f16[1 << 16];
|
||||
|
||||
// precomputed quick gelu table for f16 (128 KB)
|
||||
static ggml_fp16_t ggml_table_gelu_quick_f16[1 << 16];
|
||||
|
||||
@ -1991,6 +1987,19 @@ static inline void __lsx_f16x4_store(ggml_fp16_t * x, __m128 y) {
|
||||
#define GGML_F16_ARR (GGML_F16_STEP/GGML_F16_EPR)
|
||||
#endif
|
||||
|
||||
// for GeLU and SiLU
|
||||
#ifdef __FMA__
|
||||
#define MADD128(x, y, z) _mm_fmadd_ps(x, y, z)
|
||||
#define NMADD128(x, y, z) _mm_fnmadd_ps(x, y, z)
|
||||
#define MADD256(x, y, z) _mm256_fmadd_ps(x, y, z)
|
||||
#define NMADD256(x, y, z) _mm256_fnmadd_ps(x, y, z)
|
||||
#else
|
||||
#define MADD128(x, y, z) _mm_add_ps(_mm_mul_ps(x, y), z)
|
||||
#define NMADD128(x, y, z) _mm_sub_ps(z, _mm_mul_ps(x, y))
|
||||
#define MADD256(x, y, z) _mm256_add_ps(_mm256_mul_ps(x, y), z)
|
||||
#define NMADD256(x, y, z) _mm256_sub_ps(z, _mm256_mul_ps(x, y))
|
||||
#endif
|
||||
|
||||
//
|
||||
// ggml object
|
||||
//
|
||||
@ -2567,55 +2576,343 @@ inline static void ggml_vec_hardswish_f32 (const int n, float * y, const float *
|
||||
inline static void ggml_vec_hardsigmoid_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f)); }
|
||||
inline static void ggml_vec_exp_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = expf(x[i]); }
|
||||
|
||||
static const float GELU_COEF_A = 0.044715f;
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// There's always room for GeLU
|
||||
|
||||
static const float GELU_COEF_A = .044715f;
|
||||
static const float GELU_QUICK_COEF = -1.702f;
|
||||
static const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
|
||||
static const float SQRT_2_OVER_PI = .79788456080286535587989211986876f;
|
||||
|
||||
inline static float ggml_gelu_f32(float x) {
|
||||
return 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
|
||||
return .5f*x*(1.f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
|
||||
}
|
||||
|
||||
inline static void ggml_vec_gelu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
|
||||
const uint16_t * i16 = (const uint16_t *) x;
|
||||
for (int i = 0; i < n; ++i) {
|
||||
y[i] = ggml_table_gelu_f16[i16[i]];
|
||||
#if defined(__ARM_NEON) && defined(__aarch64__)
|
||||
|
||||
/* Approximation for single-precision vector tanh (2.58 ULP)
|
||||
There is no support for signed zero whose sign is removed
|
||||
There is no support for floating point exception handling
|
||||
This code is based on the ARM Limited Optimized Routines. */
|
||||
inline static float32x4_t
|
||||
ggml_vtanhf(float32x4_t x)
|
||||
{
|
||||
const uint32x4_t ix = vreinterpretq_u32_f32(x);
|
||||
const float32x4_t ax = vabsq_f32(x);
|
||||
const uint32x4_t iax = vreinterpretq_u32_f32(ax);
|
||||
const uint32x4_t sign = veorq_u32(ix, iax);
|
||||
const uint32x4_t is_boring = vcgtq_u32(iax, vdupq_n_u32(0x41102cb3));
|
||||
const float32x4_t boring = vreinterpretq_f32_u32(vorrq_u32(sign, vdupq_n_u32(0x3f800000)));
|
||||
const uint32x4_t special = vcgtq_u32(iax, vdupq_n_u32(0x7f800000));
|
||||
const float32x4_t ex = vmulq_n_f32(x, 2);
|
||||
const float32x4_t e = { 0x1.715476p+0f, 0x1.62e4p-1f, 0x1.7f7d1cp-20f };
|
||||
const float32x4_t j = vsubq_f32(vfmaq_laneq_f32(vdupq_n_f32(0x1.8p23f), ex, e, 0), vdupq_n_f32(0x1.8p23f));
|
||||
const int32x4_t i = vcvtq_s32_f32(j);
|
||||
const float32x4_t f = vfmsq_laneq_f32(ex, j, e, 1);
|
||||
const float32x4_t f1 = vfmsq_laneq_f32(f, j, e, 2);
|
||||
const float32x4_t f2 = vmulq_f32(f1, f1);
|
||||
const float32x4_t f4 = vmulq_f32(f2, f2);
|
||||
const float32x4_t p01 = vfmaq_f32(vdupq_n_f32(0x1.fffffep-2), vdupq_n_f32(0x1.5554aep-3), f1);
|
||||
const float32x4_t p23 = vfmaq_f32(vdupq_n_f32(0x1.555736p-5), vdupq_n_f32(0x1.12287cp-7), f1);
|
||||
const float32x4_t p03 = vfmaq_f32(p01, p23, f2);
|
||||
const float32x4_t p = vfmaq_f32(p03, vdupq_n_f32(0x1.6b55a2p-10), f4);
|
||||
const float32x4_t p2 = vfmaq_f32(f1, f2, p);
|
||||
const int32x4_t u = vaddq_s32(vshlq_n_s32(i, 23), vdupq_n_s32(0x3f800000));
|
||||
const float32x4_t t = vreinterpretq_f32_s32(u);
|
||||
const float32x4_t q = vfmaq_f32(vsubq_f32(t, vdupq_n_f32(1)), p2, t);
|
||||
const float32x4_t y = vdivq_f32(q, vaddq_f32(q, vdupq_n_f32(2)));
|
||||
const float32x4_t result = vbslq_f32(is_boring, boring, y);
|
||||
if (!vpaddd_u64(vreinterpretq_u64_u32(special)))
|
||||
return result;
|
||||
return (float32x4_t){ special[0] ? tanhf(x[0]) : result[0],
|
||||
special[1] ? tanhf(x[1]) : result[1],
|
||||
special[2] ? tanhf(x[2]) : result[2],
|
||||
special[3] ? tanhf(x[3]) : result[3] };
|
||||
}
|
||||
|
||||
inline static float32x4_t
|
||||
ggml_vgeluf(float32x4_t x)
|
||||
{
|
||||
const float32x4_t one = vdupq_n_f32(1);
|
||||
const float32x4_t half = vdupq_n_f32(.5);
|
||||
const float32x4_t coef_a = vdupq_n_f32(GELU_COEF_A);
|
||||
const float32x4_t sqrt_2_over_pi = vdupq_n_f32(SQRT_2_OVER_PI);
|
||||
const float32x4_t x_squared = vmulq_f32(x, x);
|
||||
const float32x4_t ax2 = vmulq_f32(coef_a, x_squared);
|
||||
const float32x4_t one_plus_ax2 = vaddq_f32(one, ax2);
|
||||
const float32x4_t inner = vmulq_f32(vmulq_f32(sqrt_2_over_pi, x), one_plus_ax2);
|
||||
const float32x4_t tanh_inner = ggml_vtanhf(inner);
|
||||
const float32x4_t one_plus_tanh = vaddq_f32(one, tanh_inner);
|
||||
return vmulq_f32(vmulq_f32(half, x), one_plus_tanh);
|
||||
}
|
||||
|
||||
#elif defined(__AVX512F__) && defined(__AVX512DQ__)
|
||||
|
||||
/* Approximation for single-precision vector tanh(x) using a
|
||||
branchless algorithm that offers a maximum error of 4 ULP
|
||||
|
||||
108638843x off by one errors
|
||||
18273656x 2 to 3 ulp errors
|
||||
124x 4 ulp erors (e.g. 0.203652 [3e508a10])
|
||||
1x sign flip
|
||||
|
||||
There is no support for signed zero whose sign is removed
|
||||
There is no support for floating point exception handling
|
||||
This code is based on the ARM Limited Optimized Routines. */
|
||||
inline static __m512
|
||||
ggml_vtanhf(__m512 x)
|
||||
{
|
||||
const __m512 sign_mask = _mm512_castsi512_ps(_mm512_set1_epi32(0x80000000));
|
||||
const __m512 one = _mm512_set1_ps(1);
|
||||
const __m512 two = _mm512_set1_ps(2);
|
||||
const __m512 ax = _mm512_abs_ps(x);
|
||||
const __m512 sign = _mm512_and_ps(x, sign_mask);
|
||||
const __mmask16 is_boring = _mm512_cmp_ps_mask(ax, _mm512_set1_ps(0x1.205966p+3), _CMP_GT_OQ);
|
||||
const __m512 boring = _mm512_or_ps(sign, one);
|
||||
const __m512 ex = _mm512_mul_ps(x, two);
|
||||
const __m512 j = _mm512_fmadd_ps( ex, _mm512_set1_ps(0x1.715476p+0f), _mm512_set1_ps(0x1.8p23f));
|
||||
const __m512 jj = _mm512_sub_ps(j, _mm512_set1_ps(0x1.8p23f));
|
||||
const __m512i i = _mm512_cvttps_epi32(jj);
|
||||
const __m512 f = _mm512_fnmadd_ps(_mm512_set1_ps(0x1.62e4p-1f), jj, ex);
|
||||
const __m512 f1 = _mm512_fnmadd_ps(_mm512_set1_ps(0x1.7f7d1cp-20f), jj, f);
|
||||
const __m512 f2 = _mm512_mul_ps(f1, f1);
|
||||
const __m512 f4 = _mm512_mul_ps(f2, f2);
|
||||
const __m512 p01 = _mm512_fmadd_ps( f1, _mm512_set1_ps(0x1.5554aep-3), _mm512_set1_ps(0x1.fffffep-2));
|
||||
const __m512 p23 = _mm512_fmadd_ps( f1, _mm512_set1_ps(0x1.12287cp-7), _mm512_set1_ps(0x1.555736p-5));
|
||||
const __m512 p03 = _mm512_fmadd_ps(f2, p23, p01);
|
||||
const __m512 p = _mm512_fmadd_ps(f4, _mm512_set1_ps(0x1.6b55a2p-10), p03);
|
||||
const __m512 p2 = _mm512_fmadd_ps(f2, p, f1);
|
||||
const __m512i u = _mm512_add_epi32(_mm512_slli_epi32(i, 23), _mm512_set1_epi32(0x3f800000));
|
||||
const __m512 t = _mm512_castsi512_ps(u);
|
||||
const __m512 q = _mm512_fmadd_ps(p2, t, _mm512_sub_ps(t, one));
|
||||
const __m512 y = _mm512_div_ps(q, _mm512_add_ps(q, two));
|
||||
return _mm512_mask_blend_ps(is_boring, y, boring);
|
||||
}
|
||||
|
||||
inline static __m512
|
||||
ggml_vgeluf(__m512 x)
|
||||
{
|
||||
const __m512 one = _mm512_set1_ps(1);
|
||||
const __m512 half = _mm512_set1_ps(.5);
|
||||
const __m512 coef_a = _mm512_set1_ps(GELU_COEF_A);
|
||||
const __m512 sqrt_2_over_pi = _mm512_set1_ps(SQRT_2_OVER_PI);
|
||||
const __m512 x_squared = _mm512_mul_ps(x, x);
|
||||
const __m512 ax2 = _mm512_mul_ps(coef_a, x_squared);
|
||||
const __m512 one_plus_ax2 = _mm512_add_ps(one, ax2);
|
||||
const __m512 inner = _mm512_mul_ps(_mm512_mul_ps(sqrt_2_over_pi, x), one_plus_ax2);
|
||||
const __m512 tanh_inner = ggml_vtanhf(inner);
|
||||
const __m512 one_plus_tanh = _mm512_add_ps(one, tanh_inner);
|
||||
return _mm512_mul_ps(_mm512_mul_ps(half, x), one_plus_tanh);
|
||||
}
|
||||
|
||||
#elif defined(__AVX2__)
|
||||
|
||||
/* Approximation for single-precision vector tanh(x) using a
|
||||
branchless algorithm that offers a maximum error of 4 ULP
|
||||
|
||||
With fused multiply add:
|
||||
|
||||
108638843x off by one errors
|
||||
18273656x 2 to 3 ulp errors
|
||||
124x 4 ulp erors (e.g. 0.203652 [3e508a10])
|
||||
1x sign flip
|
||||
|
||||
Without fused multiply add:
|
||||
|
||||
108479590x off by one errors
|
||||
18209645x 2 to 3 ulp errors
|
||||
70x 4 ulp errors (e.g. 0.205979 [3e52ec19])
|
||||
1x sign flip
|
||||
|
||||
There is no support for signed zero whose sign is removed
|
||||
There is no support for floating point exception handling
|
||||
This code is based on the ARM Limited Optimized Routines. */
|
||||
inline static __m256
|
||||
ggml_vtanhf(__m256 x)
|
||||
{
|
||||
const __m256 abs_mask = _mm256_castsi256_ps(_mm256_set1_epi32(0x7FFFFFFF));
|
||||
const __m256 one = _mm256_set1_ps(1);
|
||||
const __m256 two = _mm256_set1_ps(2);
|
||||
const __m256 ax = _mm256_and_ps(x, abs_mask);
|
||||
const __m256 sign = _mm256_and_ps(x, _mm256_set1_ps(-0.f));
|
||||
const __m256 is_boring = _mm256_cmp_ps(ax, _mm256_set1_ps(0x1.205966p+3), _CMP_GT_OQ);
|
||||
const __m256 boring = _mm256_or_ps(sign, one);
|
||||
const __m256 ex = _mm256_mul_ps(x, two);
|
||||
const __m256 j = MADD256(ex, _mm256_set1_ps(0x1.715476p+0f), _mm256_set1_ps(0x1.8p23f));
|
||||
const __m256 jj = _mm256_sub_ps(j, _mm256_set1_ps(0x1.8p23f));
|
||||
const __m256i i = _mm256_cvttps_epi32(jj);
|
||||
const __m256 f = NMADD256(_mm256_set1_ps(0x1.62e4p-1f), jj, ex);
|
||||
const __m256 f1 = NMADD256(_mm256_set1_ps(0x1.7f7d1cp-20f), jj, f);
|
||||
const __m256 f2 = _mm256_mul_ps(f1, f1);
|
||||
const __m256 f4 = _mm256_mul_ps(f2, f2);
|
||||
const __m256 p01 = MADD256(f1, _mm256_set1_ps(0x1.5554aep-3), _mm256_set1_ps(0x1.fffffep-2));
|
||||
const __m256 p23 = MADD256(f1, _mm256_set1_ps(0x1.12287cp-7), _mm256_set1_ps(0x1.555736p-5));
|
||||
const __m256 p03 = MADD256(f2, p23, p01);
|
||||
const __m256 p = MADD256(f4, _mm256_set1_ps(0x1.6b55a2p-10), p03);
|
||||
const __m256 p2 = MADD256(f2, p, f1);
|
||||
const __m256i u = _mm256_add_epi32(_mm256_slli_epi32(i, 23), _mm256_set1_epi32(0x3f800000));
|
||||
const __m256 t = _mm256_castsi256_ps(u);
|
||||
const __m256 q = MADD256(p2, t, _mm256_sub_ps(t, one));
|
||||
const __m256 y = _mm256_div_ps(q, _mm256_add_ps(q, two));
|
||||
return _mm256_or_ps(_mm256_and_ps(is_boring, boring), _mm256_andnot_ps(is_boring, y));
|
||||
}
|
||||
|
||||
inline static __m256
|
||||
ggml_vgeluf(__m256 x)
|
||||
{
|
||||
const __m256 one = _mm256_set1_ps(1);
|
||||
const __m256 half = _mm256_set1_ps(.5);
|
||||
const __m256 coef_a = _mm256_set1_ps(GELU_COEF_A);
|
||||
const __m256 sqrt_2_over_pi = _mm256_set1_ps(SQRT_2_OVER_PI);
|
||||
const __m256 x_squared = _mm256_mul_ps(x, x);
|
||||
const __m256 ax2 = _mm256_mul_ps(coef_a, x_squared);
|
||||
const __m256 one_plus_ax2 = _mm256_add_ps(one, ax2);
|
||||
const __m256 inner = _mm256_mul_ps(_mm256_mul_ps(sqrt_2_over_pi, x), one_plus_ax2);
|
||||
const __m256 tanh_inner = ggml_vtanhf(inner);
|
||||
const __m256 one_plus_tanh = _mm256_add_ps(one, tanh_inner);
|
||||
return _mm256_mul_ps(_mm256_mul_ps(half, x), one_plus_tanh);
|
||||
}
|
||||
|
||||
#elif defined(__SSE2__)
|
||||
|
||||
/* Approximation for single-precision vector tanh(x) using a
|
||||
branchless algorithm that offers a maximum error of 4 ULP
|
||||
|
||||
Without fused multiply add:
|
||||
|
||||
108479590x off by one errors
|
||||
18209645x 2 to 3 ulp errors
|
||||
70x 4 ulp errors (e.g. 0.205979 [3e52ec19])
|
||||
1x sign flip
|
||||
|
||||
With fused multiply add:
|
||||
|
||||
108638843x off by one errors
|
||||
18273656x 2 to 3 ulp errors
|
||||
124x 4 ulp erors (e.g. 0.203652 [3e508a10])
|
||||
1x sign flip
|
||||
|
||||
There is no support for signed zero whose sign is removed
|
||||
There is no support for floating point exception handling
|
||||
This code is based on the ARM Limited Optimized Routines. */
|
||||
inline static __m128
|
||||
ggml_vtanhf(__m128 x)
|
||||
{
|
||||
const __m128 abs_mask = _mm_castsi128_ps(_mm_set1_epi32(0x7FFFFFFF));
|
||||
const __m128 one = _mm_set1_ps(1);
|
||||
const __m128 two = _mm_set1_ps(2);
|
||||
const __m128 ax = _mm_and_ps(x, abs_mask);
|
||||
const __m128 sign = _mm_and_ps(x, _mm_set1_ps(-0.f));
|
||||
const __m128 is_boring = _mm_cmpgt_ps(ax, _mm_set1_ps(0x1.205966p+3));
|
||||
const __m128 boring = _mm_or_ps(sign, one);
|
||||
const __m128 ex = _mm_mul_ps(x, two);
|
||||
const __m128 j = MADD128(ex, _mm_set1_ps(0x1.715476p+0f), _mm_set1_ps(0x1.8p23f));
|
||||
const __m128 jj = _mm_sub_ps(j, _mm_set1_ps(0x1.8p23f));
|
||||
const __m128i i = _mm_cvttps_epi32(jj);
|
||||
const __m128 f = NMADD128(_mm_set1_ps(0x1.62e4p-1f), jj, ex);
|
||||
const __m128 f1 = NMADD128(_mm_set1_ps(0x1.7f7d1cp-20f), jj, f);
|
||||
const __m128 f2 = _mm_mul_ps(f1, f1);
|
||||
const __m128 f4 = _mm_mul_ps(f2, f2);
|
||||
const __m128 p01 = MADD128(f1, _mm_set1_ps(0x1.5554aep-3), _mm_set1_ps(0x1.fffffep-2));
|
||||
const __m128 p23 = MADD128(f1, _mm_set1_ps(0x1.12287cp-7), _mm_set1_ps(0x1.555736p-5));
|
||||
const __m128 p03 = MADD128(f2, p23, p01);
|
||||
const __m128 p = MADD128(f4, _mm_set1_ps(0x1.6b55a2p-10), p03);
|
||||
const __m128 p2 = MADD128(f2, p, f1);
|
||||
const __m128i u = _mm_add_epi32(_mm_slli_epi32(i, 23), _mm_set1_epi32(0x3f800000));
|
||||
const __m128 t = _mm_castsi128_ps(u);
|
||||
const __m128 q = MADD128(p2, t, _mm_sub_ps(t, one));
|
||||
const __m128 y = _mm_div_ps(q, _mm_add_ps(q, two));
|
||||
return _mm_or_ps(_mm_and_ps(is_boring, boring), _mm_andnot_ps(is_boring, y));
|
||||
}
|
||||
|
||||
inline static __m128
|
||||
ggml_vgeluf(__m128 x)
|
||||
{
|
||||
const __m128 one = _mm_set1_ps(1);
|
||||
const __m128 half = _mm_set1_ps(.5);
|
||||
const __m128 coef_a = _mm_set1_ps(GELU_COEF_A);
|
||||
const __m128 sqrt_2_over_pi = _mm_set1_ps(SQRT_2_OVER_PI);
|
||||
const __m128 x_squared = _mm_mul_ps(x, x);
|
||||
const __m128 ax2 = _mm_mul_ps(coef_a, x_squared);
|
||||
const __m128 one_plus_ax2 = _mm_add_ps(one, ax2);
|
||||
const __m128 inner = _mm_mul_ps(_mm_mul_ps(sqrt_2_over_pi, x), one_plus_ax2);
|
||||
const __m128 tanh_inner = ggml_vtanhf(inner);
|
||||
const __m128 one_plus_tanh = _mm_add_ps(one, tanh_inner);
|
||||
return _mm_mul_ps(_mm_mul_ps(half, x), one_plus_tanh);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
static void ggml_vec_gelu_f32(const int n, float * y, const float * x) {
|
||||
int i = 0;
|
||||
#if defined(__ARM_NEON) && defined(__aarch64__)
|
||||
for (; i + 3 < n; i += 4) {
|
||||
vst1q_f32(y + i, ggml_vgeluf(vld1q_f32(x + i)));
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef GGML_GELU_FP16
|
||||
inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) {
|
||||
uint16_t t;
|
||||
for (int i = 0; i < n; ++i) {
|
||||
if (x[i] <= -10.0f) {
|
||||
y[i] = 0.0f;
|
||||
} else if (x[i] >= 10.0f) {
|
||||
y[i] = x[i];
|
||||
} else {
|
||||
ggml_fp16_t fp16 = GGML_FP32_TO_FP16(x[i]);
|
||||
memcpy(&t, &fp16, sizeof(uint16_t));
|
||||
y[i] = GGML_FP16_TO_FP32(ggml_table_gelu_f16[t]);
|
||||
if (i < n) {
|
||||
float temp_x[4] = {0};
|
||||
float temp_y[4] = {0};
|
||||
int rem = n - i;
|
||||
for (int j = 0; j < rem; j++) {
|
||||
temp_x[j] = x[i + j];
|
||||
}
|
||||
float32x4_t x_vec = vld1q_f32(temp_x);
|
||||
float32x4_t y_vec = ggml_vgeluf(x_vec);
|
||||
vst1q_f32(temp_y, y_vec);
|
||||
for (int j = 0; j < rem; j++) {
|
||||
y[i + j] = temp_y[j];
|
||||
}
|
||||
}
|
||||
#elif defined(__AVX512F__) && defined(__AVX512DQ__)
|
||||
for (; i + 15 < n; i += 16) {
|
||||
_mm512_storeu_ps(y + i, ggml_vgeluf(_mm512_loadu_ps(x + i)));
|
||||
}
|
||||
if (i < n) {
|
||||
__mmask16 mask = _cvtu32_mask16((1U << (n - i)) - 1);
|
||||
__m512 x_vec = _mm512_maskz_loadu_ps(mask, x + i);
|
||||
__m512 y_vec = ggml_vgeluf(x_vec);
|
||||
_mm512_mask_storeu_ps(y + i, mask, y_vec);
|
||||
}
|
||||
return;
|
||||
#elif defined(__AVX2__)
|
||||
for (; i + 7 < n; i += 8) {
|
||||
_mm256_storeu_ps(y + i, ggml_vgeluf(_mm256_loadu_ps(x + i)));
|
||||
}
|
||||
if (i < n) {
|
||||
__m256i mask = _mm256_set_epi32(7, 6, 5, 4, 3, 2, 1, 0);
|
||||
mask = _mm256_cmpgt_epi32(_mm256_set1_epi32(n - i), mask);
|
||||
__m256 x_vec = _mm256_maskload_ps(x + i, mask);
|
||||
__m256 y_vec = ggml_vgeluf(x_vec);
|
||||
_mm256_maskstore_ps(y + i, mask, y_vec);
|
||||
}
|
||||
#elif defined(__SSE2__)
|
||||
for (; i + 3 < n; i += 4) {
|
||||
_mm_storeu_ps(y + i, ggml_vgeluf(_mm_loadu_ps(x + i)));
|
||||
}
|
||||
if (i < n) {
|
||||
float temp_x[4] = {0};
|
||||
float temp_y[4] = {0};
|
||||
int rem = n - i;
|
||||
for (int j = 0; j < rem; j++) {
|
||||
temp_x[j] = x[i + j];
|
||||
}
|
||||
__m128 x_vec = _mm_loadu_ps(temp_x);
|
||||
__m128 y_vec = ggml_vgeluf(x_vec);
|
||||
_mm_storeu_ps(temp_y, y_vec);
|
||||
for (int j = 0; j < rem; j++) {
|
||||
y[i + j] = temp_y[j];
|
||||
}
|
||||
}
|
||||
}
|
||||
#else
|
||||
inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) {
|
||||
for (int i = 0; i < n; ++i) {
|
||||
for (; i < n; ++i) {
|
||||
y[i] = ggml_gelu_f32(x[i]);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
inline static float ggml_gelu_quick_f32(float x) {
|
||||
return x*(1.0f/(1.0f+expf(GELU_QUICK_COEF*x)));
|
||||
}
|
||||
|
||||
//inline static void ggml_vec_gelu_quick_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
|
||||
// const uint16_t * i16 = (const uint16_t *) x;
|
||||
// for (int i = 0; i < n; ++i) {
|
||||
// y[i] = ggml_table_gelu_quick_f16[i16[i]];
|
||||
// }
|
||||
//}
|
||||
|
||||
#ifdef GGML_GELU_QUICK_FP16
|
||||
inline static void ggml_vec_gelu_quick_f32(const int n, float * y, const float * x) {
|
||||
uint16_t t;
|
||||
@ -2782,14 +3079,6 @@ inline static __m256 ggml_v_silu(__m256 x) {
|
||||
|
||||
#elif defined(__SSE2__) // __AVX2__ / __ARM_NEON
|
||||
|
||||
#if defined(__FMA__)
|
||||
#define MADD128(x, y, z) _mm_fmadd_ps(x, y, z)
|
||||
#define NMADD128(x, y, z) _mm_fnmadd_ps(x, y, z)
|
||||
#else
|
||||
#define MADD128(x, y, z) _mm_add_ps(_mm_mul_ps(x, y), z)
|
||||
#define NMADD128(x, y, z) _mm_sub_ps(z, _mm_mul_ps(x, y))
|
||||
#endif
|
||||
|
||||
// adapted from arm limited optimized routine
|
||||
// the maximum error is 1.45358 plus 0.5 ulps
|
||||
// numbers above 88.38 will flush to infinity
|
||||
@ -3831,7 +4120,6 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
|
||||
ggml_fp16_t fp16;
|
||||
} u = {i};
|
||||
float f = ggml_table_f32_f16[i] = GGML_COMPUTE_FP16_TO_FP32(u.fp16);
|
||||
ggml_table_gelu_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_f32(f));
|
||||
ggml_table_gelu_quick_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_quick_f32(f));
|
||||
}
|
||||
|
||||
@ -22102,18 +22390,46 @@ static size_t gguf_type_size(enum gguf_type type) {
|
||||
return GGUF_TYPE_SIZE[type];
|
||||
}
|
||||
|
||||
static void gguf_tensor_info_sanitize(struct gguf_tensor_info * info) {
|
||||
GGML_ASSERT(info->n_dims <= GGML_MAX_DIMS);
|
||||
GGML_ASSERT(0 <= info->type && info->type < GGML_TYPE_COUNT);
|
||||
static bool gguf_tensor_info_sanitize(struct gguf_tensor_info * info) {
|
||||
if (info->n_dims > GGML_MAX_DIMS) {
|
||||
fprintf(stderr, "%s: invalid number of dimensions (%" PRIu32 ")\n", __func__, info->n_dims);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (info->type < 0 || info->type >= GGML_TYPE_COUNT) {
|
||||
fprintf(stderr, "%s: invalid type (%d)\n", __func__, info->type);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (strlen(info->name.data) >= GGML_MAX_NAME) {
|
||||
fprintf(stderr, "%s: tensor '%s' name is too long\n", __func__, info->name.data);
|
||||
return false;
|
||||
}
|
||||
|
||||
for (uint32_t i = 0; i < info->n_dims; ++i) {
|
||||
GGML_ASSERT(info->ne[i] > 0);
|
||||
if (info->ne[i] <= 0) {
|
||||
fprintf(stderr, "%s: invalid number of elements (%" PRIu64 ")\n", __func__, info->ne[i]);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// prevent overflow for total number of elements
|
||||
GGML_ASSERT(INT64_MAX/info->ne[1] > info->ne[0]);
|
||||
GGML_ASSERT(INT64_MAX/info->ne[2] > info->ne[0]*info->ne[1]);
|
||||
GGML_ASSERT(INT64_MAX/info->ne[3] > info->ne[0]*info->ne[1]*info->ne[2]);
|
||||
if (INT64_MAX/info->ne[1] <= info->ne[0]) {
|
||||
fprintf(stderr, "%s: invalid number of elements (%" PRIu64 ")\n", __func__, info->ne[1]);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (INT64_MAX/info->ne[2] <= info->ne[0]*info->ne[1]) {
|
||||
fprintf(stderr, "%s: invalid number of elements (%" PRIu64 ")\n", __func__, info->ne[2]);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (INT64_MAX/info->ne[3] <= info->ne[0]*info->ne[1]*info->ne[2]) {
|
||||
fprintf(stderr, "%s: invalid number of elements (%" PRIu64 ")\n", __func__, info->ne[3]);
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool gguf_fread_el(FILE * file, void * dst, size_t size, size_t * offset) {
|
||||
@ -22414,8 +22730,7 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
|
||||
ok = ok && gguf_fread_el (file, &info->type, sizeof(info->type), &offset);
|
||||
ok = ok && gguf_fread_el (file, &info->offset, sizeof(info->offset), &offset);
|
||||
|
||||
// TODO: return an error instead of crashing with GGML_ASSERT
|
||||
gguf_tensor_info_sanitize(info);
|
||||
ok = ok && gguf_tensor_info_sanitize(info);
|
||||
|
||||
// make sure there is no duplicated tensor names
|
||||
for (uint64_t j = 0; j < i && ok; ++j) {
|
||||
|
@ -15,6 +15,7 @@
|
||||
#define TWOPI_F 6.283185307179586f
|
||||
|
||||
#define QK_K 256
|
||||
#define K_SCALE_SIZE 12
|
||||
|
||||
#define u8BufToU16(buf, idx) (((uint16_t(buf[idx + 1]) << 8)) | buf[idx])
|
||||
#define u8BufToFloat16(buf, idx) uint16BitsToHalf u8BufToU16(buf, idx)
|
||||
@ -64,6 +65,14 @@ mat4 dequantize_q4_1(const block_q4_1 xb, uint il) {
|
||||
return reg;
|
||||
}
|
||||
|
||||
#define sizeof_block_q4_k 144
|
||||
struct block_q4_k {
|
||||
float16_t d;
|
||||
float16_t dmin;
|
||||
uint8_t scales[K_SCALE_SIZE];
|
||||
uint8_t qs[QK_K/2];
|
||||
};
|
||||
|
||||
#define sizeof_block_q6_k 210
|
||||
struct block_q6_k {
|
||||
uint8_t ql[QK_K/2]; // quants, lower 4 bits
|
||||
|
133
ggml/src/kompute-shaders/op_mul_mat_q4_k.comp
Normal file
133
ggml/src/kompute-shaders/op_mul_mat_q4_k.comp
Normal file
@ -0,0 +1,133 @@
|
||||
#version 450
|
||||
|
||||
#include "common.comp"
|
||||
|
||||
#define N_DST 4
|
||||
#define SIZE_OF_BLOCK sizeof_block_q4_k
|
||||
|
||||
layout(local_size_x = 4) in;
|
||||
layout(local_size_y = 8) in;
|
||||
layout(local_size_z = 1) in;
|
||||
|
||||
layout (binding = 0) readonly buffer tensorInA { block_q4_k inA[]; };
|
||||
layout (binding = 1) readonly buffer tensorInB { float inB[]; };
|
||||
layout (binding = 2) writeonly buffer tensorOut { float out_[]; };
|
||||
|
||||
layout (push_constant) uniform parameter {
|
||||
uint inAOff;
|
||||
uint inBOff;
|
||||
uint outOff;
|
||||
int ne00;
|
||||
int ne10;
|
||||
int ne0;
|
||||
int ne1;
|
||||
int ne01;
|
||||
int ne02;
|
||||
int ne12;
|
||||
int r2;
|
||||
int r3;
|
||||
} pcs;
|
||||
|
||||
void main() {
|
||||
const uint16_t kmask1 = uint16_t(0x3f3f);
|
||||
const uint16_t kmask2 = uint16_t(0x0f0f);
|
||||
const uint16_t kmask3 = uint16_t(0xc0c0);
|
||||
|
||||
const uint ix = gl_SubgroupInvocationID/8; // 0...3
|
||||
const uint it = gl_SubgroupInvocationID%8; // 0...7
|
||||
const uint iq = it/4; // 0 or 1
|
||||
const uint ir = it%4; // 0...3
|
||||
|
||||
const uint nb = pcs.ne00/QK_K;
|
||||
|
||||
const uint r0 = gl_WorkGroupID.x;
|
||||
const uint r1 = gl_WorkGroupID.y;
|
||||
const uint im = gl_WorkGroupID.z;
|
||||
|
||||
const uint first_row = r0 * N_DST;
|
||||
const uint ib_row = first_row * nb;
|
||||
|
||||
const uint i12 = im%pcs.ne12;
|
||||
const uint i13 = im/pcs.ne12;
|
||||
|
||||
const uint offset0 = (i12/pcs.r2)*(nb*pcs.ne01) + (i13/pcs.r3)*(nb*pcs.ne01*pcs.ne02);
|
||||
|
||||
const uint xblk = ib_row + offset0 + pcs.inAOff;
|
||||
const uint y = r1*pcs.ne10 + im*pcs.ne00*pcs.ne1 + pcs.inBOff;
|
||||
|
||||
float yl[16];
|
||||
float yh[16];
|
||||
float sumf[N_DST] = {0.f, 0.f, 0.f, 0.f};
|
||||
float all_sum = 0.f;
|
||||
|
||||
uint y4 = y + ix * QK_K + 64 * iq + 8 * ir;
|
||||
|
||||
for (uint ib = ix; ib < nb; ib += 4) {
|
||||
const uint blk_idx = ib + xblk;
|
||||
|
||||
float sumy[4] = {0.f, 0.f, 0.f, 0.f};
|
||||
for (int i = 0; i < 8; ++i) {
|
||||
yl[i+0] = inB[y4+i+ 0]; sumy[0] += yl[i+0];
|
||||
yl[i+8] = inB[y4+i+ 32]; sumy[1] += yl[i+8];
|
||||
yh[i+0] = inB[y4+i+128]; sumy[2] += yh[i+0];
|
||||
yh[i+8] = inB[y4+i+160]; sumy[3] += yh[i+8];
|
||||
}
|
||||
|
||||
for (int row = 0; row < N_DST; row++) {
|
||||
uint row_idx = row * nb;
|
||||
|
||||
uint16_t sc_0 = u8BufToU16(inA[blk_idx + row_idx].scales, iq * 2 + 0);
|
||||
uint16_t sc_1 = u8BufToU16(inA[blk_idx + row_idx].scales, iq * 2 + 2);
|
||||
uint16_t sc_2 = u8BufToU16(inA[blk_idx + row_idx].scales, iq * 2 + 4);
|
||||
uint16_t sc_3 = u8BufToU16(inA[blk_idx + row_idx].scales, iq * 2 + 6);
|
||||
uint16_t sc_4 = u8BufToU16(inA[blk_idx + row_idx].scales, iq * 2 + 8);
|
||||
|
||||
uint16_t sc16[4];
|
||||
sc16[0] = sc_0 & kmask1;
|
||||
sc16[1] = sc_2 & kmask1;
|
||||
sc16[2] = ((sc_4 >> 0) & kmask2) | ((sc_0 & kmask3) >> 2);
|
||||
sc16[3] = ((sc_4 >> 4) & kmask2) | ((sc_2 & kmask3) >> 2);
|
||||
|
||||
float acc1[4] = {0.f, 0.f, 0.f, 0.f};
|
||||
float acc2[4] = {0.f, 0.f, 0.f, 0.f};
|
||||
for (int i = 0; i < 8; i += 2) {
|
||||
uint16_t q1 = u8BufToU16(inA[blk_idx + row_idx].qs, 32 * iq + 8 * ir + i);
|
||||
uint16_t q2 = u8BufToU16(inA[blk_idx + row_idx].qs, 64 + 32 * iq + 8 * ir + i);
|
||||
acc1[0] += yl[i+0] * (q1 & 0x000F);
|
||||
acc1[1] += yl[i+1] * (q1 & 0x0F00);
|
||||
acc1[2] += yl[i+8] * (q1 & 0x00F0);
|
||||
acc1[3] += yl[i+9] * (q1 & 0xF000);
|
||||
acc2[0] += yh[i+0] * (q2 & 0x000F);
|
||||
acc2[1] += yh[i+1] * (q2 & 0x0F00);
|
||||
acc2[2] += yh[i+8] * (q2 & 0x00F0);
|
||||
acc2[3] += yh[i+9] * (q2 & 0xF000);
|
||||
}
|
||||
|
||||
uint8_t sc8_0 = uint8_t(sc16[0] & 0xFF);
|
||||
uint8_t sc8_1 = uint8_t(sc16[0] >> 8 );
|
||||
uint8_t sc8_2 = uint8_t(sc16[1] & 0xFF);
|
||||
uint8_t sc8_3 = uint8_t(sc16[1] >> 8 );
|
||||
uint8_t sc8_4 = uint8_t(sc16[2] & 0xFF);
|
||||
uint8_t sc8_5 = uint8_t(sc16[2] >> 8 );
|
||||
uint8_t sc8_6 = uint8_t(sc16[3] & 0xFF);
|
||||
uint8_t sc8_7 = uint8_t(sc16[3] >> 8 );
|
||||
|
||||
float dall = float(inA[blk_idx + row_idx].d);
|
||||
float dmin = float(inA[blk_idx + row_idx].dmin);
|
||||
sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc8_0 +
|
||||
(acc1[2] + 1.f/256.f * acc1[3]) * sc8_1 * 1.f/16.f +
|
||||
(acc2[0] + 1.f/256.f * acc2[1]) * sc8_4 +
|
||||
(acc2[2] + 1.f/256.f * acc2[3]) * sc8_5 * 1.f/16.f) -
|
||||
dmin * (sumy[0] * sc8_2 + sumy[1] * sc8_3 + sumy[2] * sc8_6 + sumy[3] * sc8_7);
|
||||
}
|
||||
|
||||
y4 += 4 * QK_K;
|
||||
}
|
||||
|
||||
for (int row = 0; row < N_DST; ++row) {
|
||||
all_sum = subgroupAdd(sumf[row]);
|
||||
if (subgroupElect()) {
|
||||
out_[r1*pcs.ne0 + im*pcs.ne0*pcs.ne1 + first_row + row + pcs.outOff] = all_sum;
|
||||
}
|
||||
}
|
||||
}
|
@ -4273,8 +4273,11 @@ struct llama_model_loader {
|
||||
|
||||
llama_tensor_weight(const llama_file * file, uint16_t idx, const char * name, const struct gguf_context * gguf_ctx, ggml_tensor * tensor) : idx(idx), tensor(tensor) {
|
||||
const int tensor_idx = gguf_find_tensor(gguf_ctx, name);
|
||||
offs = gguf_get_data_offset(gguf_ctx) + gguf_get_tensor_offset(gguf_ctx, tensor_idx);
|
||||
if (tensor_idx < 0) {
|
||||
throw std::runtime_error(format("tensor '%s' not found in the model", name));
|
||||
}
|
||||
|
||||
offs = gguf_get_data_offset(gguf_ctx) + gguf_get_tensor_offset(gguf_ctx, tensor_idx);
|
||||
if (offs + ggml_nbytes(tensor) < offs || offs + ggml_nbytes(tensor) > file->size) {
|
||||
throw std::runtime_error(format("tensor '%s' data is not within the file bounds, model is corrupted or incomplete", name));
|
||||
}
|
||||
@ -7426,7 +7429,7 @@ static bool llm_load_tensors(
|
||||
if (flags & llama_model_loader::TENSOR_NOT_REQUIRED) {
|
||||
return nullptr;
|
||||
}
|
||||
throw std::runtime_error(format("missing tensor %s", tn.str().c_str()));
|
||||
throw std::runtime_error(format("missing tensor '%s'", tn.str().c_str()));
|
||||
}
|
||||
|
||||
// some models use the token embedding tensor as the output, but since these are used in different layers and with different ops
|
||||
|
Loading…
Reference in New Issue
Block a user