mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-24 02:14:35 +00:00
ggml : introduce bfloat16 support (#6412)
* Introduce bfloat16 support Many models on Hugging Face (e.g. Mistral, TinyLLaMA) use bfloat16 as their canonical floating point format. ┌sign │ │ ┌exponent │ │ │ │ ┌mantissa │ │ │ │┌──┴───┐┌─┴───┐ 0b0000000000000000 brain16 This encoding has the same number of exponent bits as float32. That makes conversion relatively straightforward, even in the absence of hardware support. For example, converting brain16 to binary32 means simply shifting 16 bits to the left. ┌sign │ │ ┌exponent │ │ │ │ ┌mantissa │ │ │ │┌──┴───┐┌─┴───────────────────┐ 0b00000000000000000000000000000000 IEEE binary32 The issue is that converting bf16 to fp16 can result in information loss. Only 13% of bf16 numbers can be precisely represented in fp16 which in practice ends up being 99.71% of Mistral 7b v0.2's weights however there is currently no way other than fp32 to get the others ┌sign │ │ ┌exponent │ │ │ │ ┌mantissa │ │ │ │┌─┴─┐┌─┴──────┐ 0b0000000000000000 IEEE binary16 This change fixes that, by adding a bf16 data type to GGML. Support for CPU inference has been implemented along with optimizations for the AVX2, AVX512, and AVX512BF16 ISAs. Perplexity on Mistral 7b 0.2 improves somewhere around -0.0024 to -0.0046 compared to using fp16 * Remove GGML code that's not needed * Minimize the GGML API surface area for BF16 * Remove bf16 luts * Make the GGML header look nicer * Fix documentation * Apply ggerganov's fixes for test-backend-ops * Add BF16 code for new ggml_validate_row_data() function
This commit is contained in:
parent
c0e6fbf8c3
commit
3855416027
@ -575,7 +575,7 @@ static struct ggml_tensor * llama_build_lora_finetune_graphs(
|
||||
GGML_ASSERT(tokens_input->type == GGML_TYPE_I32);
|
||||
|
||||
auto add_to_f32 = [] (struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b) {
|
||||
if (ggml_is_quantized(a->type) || a->type == GGML_TYPE_F16) {
|
||||
if (ggml_is_quantized(a->type) || a->type == GGML_TYPE_F16 || a->type == GGML_TYPE_BF16) {
|
||||
return ggml_add_cast(ctx, a, b, GGML_TYPE_F32);
|
||||
} else if (a->type == GGML_TYPE_F32) {
|
||||
return ggml_add(ctx, a, b);
|
||||
|
@ -46,7 +46,8 @@ static const std::vector<struct quant_option> QUANT_OPTIONS = {
|
||||
{ "Q5_K_M", LLAMA_FTYPE_MOSTLY_Q5_K_M, " 4.45G, +0.0122 ppl @ LLaMA-v1-7B", },
|
||||
{ "Q6_K", LLAMA_FTYPE_MOSTLY_Q6_K, " 5.15G, +0.0008 ppl @ LLaMA-v1-7B", },
|
||||
{ "Q8_0", LLAMA_FTYPE_MOSTLY_Q8_0, " 6.70G, +0.0004 ppl @ LLaMA-v1-7B", },
|
||||
{ "F16", LLAMA_FTYPE_MOSTLY_F16, "13.00G @ 7B", },
|
||||
{ "F16", LLAMA_FTYPE_MOSTLY_F16, "14.00G, -0.0020 ppl @ Mistral-7B", },
|
||||
{ "BF16", LLAMA_FTYPE_MOSTLY_BF16, "14.00G, -0.0050 ppl @ Mistral-7B", },
|
||||
{ "F32", LLAMA_FTYPE_ALL_F32, "26.00G @ 7B", },
|
||||
// Note: Ensure COPY comes after F32 to avoid ftype 0 from matching.
|
||||
{ "COPY", LLAMA_FTYPE_ALL_F32, "only copy tensors, no quantizing", },
|
||||
|
77
ggml-impl.h
77
ggml-impl.h
@ -17,6 +17,83 @@
|
||||
#define MIN(a, b) ((a) < (b) ? (a) : (b))
|
||||
#define MAX(a, b) ((a) > (b) ? (a) : (b))
|
||||
|
||||
/**
|
||||
* Converts brain16 to float32.
|
||||
*
|
||||
* The bfloat16 floating point format has the following structure:
|
||||
*
|
||||
* ┌sign
|
||||
* │
|
||||
* │ ┌exponent
|
||||
* │ │
|
||||
* │ │ ┌mantissa
|
||||
* │ │ │
|
||||
* │┌──┴───┐┌─┴───┐
|
||||
* 0b0000000000000000 brain16
|
||||
*
|
||||
* Since bf16 has the same number of exponent bits as a 32bit float,
|
||||
* encoding and decoding numbers becomes relatively straightforward.
|
||||
*
|
||||
* ┌sign
|
||||
* │
|
||||
* │ ┌exponent
|
||||
* │ │
|
||||
* │ │ ┌mantissa
|
||||
* │ │ │
|
||||
* │┌──┴───┐┌─┴───────────────────┐
|
||||
* 0b00000000000000000000000000000000 IEEE binary32
|
||||
*
|
||||
* For comparison, the standard fp16 format has fewer exponent bits.
|
||||
*
|
||||
* ┌sign
|
||||
* │
|
||||
* │ ┌exponent
|
||||
* │ │
|
||||
* │ │ ┌mantissa
|
||||
* │ │ │
|
||||
* │┌─┴─┐┌─┴──────┐
|
||||
* 0b0000000000000000 IEEE binary16
|
||||
*
|
||||
* @see IEEE 754-2008
|
||||
*/
|
||||
static inline float ggml_compute_bf16_to_fp32(ggml_bf16_t h) {
|
||||
union {
|
||||
float f;
|
||||
uint32_t i;
|
||||
} u;
|
||||
u.i = (uint32_t)h.bits << 16;
|
||||
return u.f;
|
||||
}
|
||||
|
||||
/**
|
||||
* Converts float32 to brain16.
|
||||
*
|
||||
* This function is binary identical to AMD Zen4 VCVTNEPS2BF16.
|
||||
* Subnormals shall be flushed to zero, and NANs will be quiet.
|
||||
* This code should vectorize nicely if using modern compilers.
|
||||
*/
|
||||
static inline ggml_bf16_t ggml_compute_fp32_to_bf16(float s) {
|
||||
ggml_bf16_t h;
|
||||
union {
|
||||
float f;
|
||||
uint32_t i;
|
||||
} u;
|
||||
u.f = s;
|
||||
if ((u.i & 0x7fffffff) > 0x7f800000) { /* nan */
|
||||
h.bits = (u.i >> 16) | 64; /* force to quiet */
|
||||
return h;
|
||||
}
|
||||
if (!(u.i & 0x7f800000)) { /* subnormal */
|
||||
h.bits = (u.i & 0x80000000) >> 16; /* flush to zero */
|
||||
return h;
|
||||
}
|
||||
h.bits = (u.i + (0x7fff + ((u.i >> 16) & 1))) >> 16;
|
||||
return h;
|
||||
}
|
||||
|
||||
#define GGML_FP32_TO_BF16(x) ggml_compute_fp32_to_bf16(x)
|
||||
#define GGML_BF16_TO_FP32(x) ggml_compute_bf16_to_fp32(x)
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
@ -803,7 +803,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
|
||||
case GGML_OP_DIAG_MASK_INF:
|
||||
case GGML_OP_GET_ROWS:
|
||||
{
|
||||
return op->ne[3] == 1;
|
||||
return op->src[0]->type != GGML_TYPE_BF16 && op->ne[3] == 1;
|
||||
}
|
||||
default:
|
||||
return false;
|
||||
|
@ -12450,6 +12450,24 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte
|
||||
const size_t nb = nbytes/ggml_type_size(type);
|
||||
|
||||
switch (type) {
|
||||
case GGML_TYPE_BF16:
|
||||
{
|
||||
int nans = 0;
|
||||
int infs = 0;
|
||||
const unsigned short * f = (const unsigned short *) data;
|
||||
for (size_t i = 0; i < nb; ++i) {
|
||||
nans += (f[i] & 0x7fff) > 0x7f80;
|
||||
infs += (f[i] & 0x7fff) == 0x7f80;
|
||||
}
|
||||
if (nans) {
|
||||
fprintf(stderr, "%s: found %d NaNs in row of %zu BF16 values\n", __func__, nans, nb);
|
||||
return false;
|
||||
}
|
||||
if (infs) {
|
||||
fprintf(stderr, "%s: found %d infinities in row of %zu BF16 values\n", __func__, infs, nb);
|
||||
return false;
|
||||
}
|
||||
} break;
|
||||
case GGML_TYPE_F16:
|
||||
{
|
||||
const ggml_fp16_t * f = (const ggml_fp16_t *) data;
|
||||
|
20
ggml.h
20
ggml.h
@ -326,14 +326,20 @@ extern "C" {
|
||||
// get ggml_status name string
|
||||
GGML_API GGML_CALL const char * ggml_status_to_string(enum ggml_status status);
|
||||
|
||||
// ieee 754-2008 half-precision float16
|
||||
// todo: make this not an integral type
|
||||
typedef uint16_t ggml_fp16_t;
|
||||
GGML_API float ggml_fp16_to_fp32(ggml_fp16_t);
|
||||
GGML_API ggml_fp16_t ggml_fp32_to_fp16(float);
|
||||
GGML_API void ggml_fp16_to_fp32_row(const ggml_fp16_t *, float *, int64_t);
|
||||
GGML_API void ggml_fp32_to_fp16_row(const float *, ggml_fp16_t *, int64_t);
|
||||
|
||||
// convert FP16 <-> FP32
|
||||
GGML_API float ggml_fp16_to_fp32(ggml_fp16_t x);
|
||||
GGML_API ggml_fp16_t ggml_fp32_to_fp16(float x);
|
||||
|
||||
GGML_API void ggml_fp16_to_fp32_row(const ggml_fp16_t * x, float * y, int64_t n);
|
||||
GGML_API void ggml_fp32_to_fp16_row(const float * x, ggml_fp16_t * y, int64_t n);
|
||||
// google brain half-precision bfloat16
|
||||
typedef struct { uint16_t bits; } ggml_bf16_t;
|
||||
GGML_API ggml_bf16_t ggml_fp32_to_bf16(float);
|
||||
GGML_API float ggml_bf16_to_fp32(ggml_bf16_t); // consider just doing << 16
|
||||
GGML_API void ggml_bf16_to_fp32_row(const ggml_bf16_t *, float *, int64_t);
|
||||
GGML_API void ggml_fp32_to_bf16_row(const float *, ggml_bf16_t *, int64_t);
|
||||
|
||||
struct ggml_object;
|
||||
struct ggml_context;
|
||||
@ -370,6 +376,7 @@ extern "C" {
|
||||
GGML_TYPE_I64 = 27,
|
||||
GGML_TYPE_F64 = 28,
|
||||
GGML_TYPE_IQ1_M = 29,
|
||||
GGML_TYPE_BF16 = 30,
|
||||
GGML_TYPE_COUNT,
|
||||
};
|
||||
|
||||
@ -410,6 +417,7 @@ extern "C" {
|
||||
GGML_FTYPE_MOSTLY_IQ2_S = 21, // except 1d tensors
|
||||
GGML_FTYPE_MOSTLY_IQ4_XS = 22, // except 1d tensors
|
||||
GGML_FTYPE_MOSTLY_IQ1_M = 23, // except 1d tensors
|
||||
GGML_FTYPE_MOSTLY_BF16 = 24, // except 1d tensors
|
||||
};
|
||||
|
||||
// available tensor operations:
|
||||
|
@ -817,6 +817,7 @@ class GGMLQuantizationType(IntEnum):
|
||||
I64 = 27
|
||||
F64 = 28
|
||||
IQ1_M = 29
|
||||
BF16 = 30
|
||||
|
||||
|
||||
class GGUFEndian(IntEnum):
|
||||
@ -888,6 +889,7 @@ GGML_QUANT_SIZES = {
|
||||
GGMLQuantizationType.I64: (1, 8),
|
||||
GGMLQuantizationType.F64: (1, 8),
|
||||
GGMLQuantizationType.IQ1_M: (256, QK_K // 8 + QK_K // 16 + QK_K // 32),
|
||||
GGMLQuantizationType.BF16: (1, 2),
|
||||
}
|
||||
|
||||
|
||||
|
20
llama.cpp
20
llama.cpp
@ -3175,6 +3175,7 @@ struct llama_model_loader {
|
||||
switch (type_max) {
|
||||
case GGML_TYPE_F32: ftype = LLAMA_FTYPE_ALL_F32; break;
|
||||
case GGML_TYPE_F16: ftype = LLAMA_FTYPE_MOSTLY_F16; break;
|
||||
case GGML_TYPE_BF16: ftype = LLAMA_FTYPE_MOSTLY_BF16; break;
|
||||
case GGML_TYPE_Q4_0: ftype = LLAMA_FTYPE_MOSTLY_Q4_0; break;
|
||||
case GGML_TYPE_Q4_1: ftype = LLAMA_FTYPE_MOSTLY_Q4_1; break;
|
||||
case GGML_TYPE_Q5_0: ftype = LLAMA_FTYPE_MOSTLY_Q5_0; break;
|
||||
@ -3666,6 +3667,7 @@ static std::string llama_model_ftype_name(llama_ftype ftype) {
|
||||
switch (ftype) {
|
||||
case LLAMA_FTYPE_ALL_F32: return "all F32";
|
||||
case LLAMA_FTYPE_MOSTLY_F16: return "F16";
|
||||
case LLAMA_FTYPE_MOSTLY_BF16: return "BF16";
|
||||
case LLAMA_FTYPE_MOSTLY_Q4_0: return "Q4_0";
|
||||
case LLAMA_FTYPE_MOSTLY_Q4_1: return "Q4_1";
|
||||
case LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16:
|
||||
@ -6129,6 +6131,7 @@ static int llama_model_load(const std::string & fname, llama_model & model, llam
|
||||
|| !(
|
||||
model.ftype == LLAMA_FTYPE_ALL_F32 ||
|
||||
model.ftype == LLAMA_FTYPE_MOSTLY_F16 ||
|
||||
model.ftype == LLAMA_FTYPE_MOSTLY_BF16 ||
|
||||
model.ftype == LLAMA_FTYPE_MOSTLY_Q4_0 ||
|
||||
model.ftype == LLAMA_FTYPE_MOSTLY_Q4_1
|
||||
)
|
||||
@ -14158,13 +14161,16 @@ static void llama_tensor_dequantize_internal(
|
||||
if (qtype.to_float == NULL) {
|
||||
throw std::runtime_error(format("type %s unsupported for integer quantization: no dequantization available", ggml_type_name(tensor->type)));
|
||||
}
|
||||
} else if (tensor->type != GGML_TYPE_F16) {
|
||||
} else if (tensor->type != GGML_TYPE_F16 &&
|
||||
tensor->type != GGML_TYPE_BF16) {
|
||||
throw std::runtime_error(format("cannot dequantize/convert tensor type %s", ggml_type_name(tensor->type)));
|
||||
}
|
||||
|
||||
if (nthread < 2) {
|
||||
if (tensor->type == GGML_TYPE_F16) {
|
||||
ggml_fp16_to_fp32_row((ggml_fp16_t *)tensor->data, f32_output, nelements);
|
||||
} else if (tensor->type == GGML_TYPE_BF16) {
|
||||
ggml_bf16_to_fp32_row((ggml_bf16_t *)tensor->data, f32_output, nelements);
|
||||
} else if (ggml_is_quantized(tensor->type)) {
|
||||
qtype.to_float(tensor->data, f32_output, nelements);
|
||||
} else {
|
||||
@ -14173,7 +14179,14 @@ static void llama_tensor_dequantize_internal(
|
||||
return;
|
||||
}
|
||||
|
||||
size_t block_size = tensor->type == GGML_TYPE_F16 ? 1 : (size_t)ggml_blck_size(tensor->type);
|
||||
size_t block_size;
|
||||
if (tensor->type == GGML_TYPE_F16 ||
|
||||
tensor->type == GGML_TYPE_BF16) {
|
||||
block_size = 1;
|
||||
} else {
|
||||
block_size = (size_t)ggml_blck_size(tensor->type);
|
||||
}
|
||||
|
||||
size_t block_size_bytes = ggml_type_size(tensor->type);
|
||||
|
||||
GGML_ASSERT(nelements % block_size == 0);
|
||||
@ -14192,6 +14205,8 @@ static void llama_tensor_dequantize_internal(
|
||||
auto compute = [qtype] (ggml_type typ, uint8_t * inbuf, float * outbuf, int nels) {
|
||||
if (typ == GGML_TYPE_F16) {
|
||||
ggml_fp16_to_fp32_row((ggml_fp16_t *)inbuf, outbuf, nels);
|
||||
} else if (typ == GGML_TYPE_BF16) {
|
||||
ggml_bf16_to_fp32_row((ggml_bf16_t *)inbuf, outbuf, nels);
|
||||
} else {
|
||||
qtype.to_float(inbuf, outbuf, nels);
|
||||
}
|
||||
@ -14552,6 +14567,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
|
||||
case LLAMA_FTYPE_MOSTLY_Q5_1: default_type = GGML_TYPE_Q5_1; break;
|
||||
case LLAMA_FTYPE_MOSTLY_Q8_0: default_type = GGML_TYPE_Q8_0; break;
|
||||
case LLAMA_FTYPE_MOSTLY_F16: default_type = GGML_TYPE_F16; break;
|
||||
case LLAMA_FTYPE_MOSTLY_BF16: default_type = GGML_TYPE_BF16; break;
|
||||
case LLAMA_FTYPE_ALL_F32: default_type = GGML_TYPE_F32; break;
|
||||
|
||||
// K-quants
|
||||
|
1
llama.h
1
llama.h
@ -137,6 +137,7 @@ extern "C" {
|
||||
LLAMA_FTYPE_MOSTLY_IQ2_M = 29, // except 1d tensors
|
||||
LLAMA_FTYPE_MOSTLY_IQ4_XS = 30, // except 1d tensors
|
||||
LLAMA_FTYPE_MOSTLY_IQ1_M = 31, // except 1d tensors
|
||||
LLAMA_FTYPE_MOSTLY_BF16 = 32, // except 1d tensors
|
||||
|
||||
LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file
|
||||
};
|
||||
|
@ -50,7 +50,7 @@ static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float m
|
||||
|
||||
if (tensor->type == GGML_TYPE_F32 || tensor->type == GGML_TYPE_I32) {
|
||||
ggml_backend_tensor_set(tensor, data.data(), 0, size * sizeof(float));
|
||||
} else if (ggml_is_quantized(tensor->type) || tensor->type == GGML_TYPE_F16) {
|
||||
} else if (ggml_is_quantized(tensor->type) || tensor->type == GGML_TYPE_F16 || tensor->type == GGML_TYPE_BF16) {
|
||||
GGML_ASSERT(size % ggml_blck_size(tensor->type) == 0);
|
||||
std::vector<uint8_t> dataq(ggml_row_size(tensor->type, size));
|
||||
std::vector<float> imatrix(tensor->ne[0], 1.0f); // dummy importance matrix
|
||||
@ -92,6 +92,8 @@ static std::vector<float> tensor_to_float(const ggml_tensor * t) {
|
||||
size_t i = i3*t->nb[3] + i2*t->nb[2] + i1*t->nb[1] + i0/bs*t->nb[0];
|
||||
if (t->type == GGML_TYPE_F16) {
|
||||
tv.push_back(ggml_fp16_to_fp32(*(ggml_fp16_t*)&buf[i]));
|
||||
} else if (t->type == GGML_TYPE_BF16) {
|
||||
tv.push_back(ggml_bf16_to_fp32(*(ggml_bf16_t*)&buf[i]));
|
||||
} else if (t->type == GGML_TYPE_F32) {
|
||||
tv.push_back(*(float *) &buf[i]);
|
||||
} else if (t->type == GGML_TYPE_I32) {
|
||||
@ -1898,7 +1900,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
|
||||
std::default_random_engine rng(0);
|
||||
|
||||
const ggml_type all_types[] = {
|
||||
GGML_TYPE_F32, GGML_TYPE_F16,
|
||||
GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_BF16,
|
||||
GGML_TYPE_Q4_0, GGML_TYPE_Q4_1,
|
||||
GGML_TYPE_Q5_0, GGML_TYPE_Q5_1,
|
||||
GGML_TYPE_Q8_0,
|
||||
|
Loading…
Reference in New Issue
Block a user