mirror of https://github.com/ggerganov/llama.cpp.git

finish f16 hf bitnet e2e

commit 5e59660173 (parent 1f2e0ee012)
@@ -1421,7 +1421,31 @@ class BitnetModel(Model):
        result = (weight * s).round().clamp(-1, 1) / s
        return result.type(dtype)

    def transform_to_i2(self, x):
        from gguf.lazy import LazyNumpyTensor
        x = LazyNumpyTensor.to_eager(x)
        x_num = np.prod(x.shape)
        x = np.reshape(x, x_num)
        scale = 1
        for i in range(x_num):
            if x[i] != 0:
                scale = x[i]
                break
        x = np.divide(x, scale)
        x = x.astype(np.uint8)
        x = np.reshape(x, [x.shape[0] // 4, 4])
        keep_bit = {0:192, 1:48, 2:12, 3:3}
        ans = np.zeros([x_num // 4], dtype=np.uint8)
        for i in range(4):
            x_bit_col = x[:, i]
            x_bit_shift = np.left_shift(x_bit_col, 6 - i * 2)
            x_bit_shift = np.bitwise_and(x_bit_shift, keep_bit[i])
            ans = np.bitwise_or(ans, x_bit_shift)
        scale = np.tile(scale, 8)
        return ans, scale
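
For reference: transform_to_i2 packs four ternary weights into each byte, two bits per weight, with the first weight of a group in the two most significant bits; the per-tensor scale is taken from the first non-zero weight and tiled to 8 float32 values (32 bytes). Below is a minimal sketch of the inverse mapping, not part of the commit; the helper name unpack_i2 is made up for illustration.

import numpy as np

def unpack_i2(packed: np.ndarray, scale: float) -> np.ndarray:
    # hypothetical helper, not from the commit: inverse of transform_to_i2
    # each byte holds 4 weights, 2 bits each, first weight in the high bits
    crumbs = np.stack([(packed >> (6 - 2 * i)) & 3 for i in range(4)], axis=1)
    # crumb 1 -> +1, crumb 3 -> -1 (uint8 wrap of -1), crumbs 0/2 -> 0
    lut = np.array([0, 1, 0, -1], dtype=np.float32)
    return lut[crumbs].reshape(-1) * scale

The packed bytes produced this way are exactly the indices used by the i2_q8 lookup table further down.
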
    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
        # quant weight to i2 (in fp16)
        if name.endswith(("q_proj.weight", "k_proj.weight", "v_proj.weight",
                          "down_proj.weight", "up_proj.weight", "gate_proj.weight",
                          "o_proj.weight")):
@@ -1429,6 +1453,103 @@ class BitnetModel(Model):

        return [(self.map_tensor_name(name), data_torch)]

    def write_tensors(self):
        max_name_len = max(len(s) for _, s in self.tensor_map.mapping.values()) + len(".weight,")

        for name, data_torch in self.get_tensors():
            # we don't need these
            if name.endswith((".attention.masked_bias", ".attention.bias", ".rotary_emb.inv_freq")):
                continue

            old_dtype = data_torch.dtype

            # convert any unsupported data types to float32
            if data_torch.dtype not in (torch.float16, torch.float32):
                data_torch = data_torch.to(torch.float32)

            # use the first number-like part of the tensor name as the block id
            bid = None
            for part in name.split("."):
                if part.isdecimal():
                    bid = int(part)
                    break

            for new_name, data in ((n, d.squeeze().numpy()) for n, d in self.modify_tensors(data_torch, name, bid)):
                data: np.ndarray = data  # type hint
                data_shape = data.shape
                n_dims = len(data.shape)
                data_dtype = data.dtype
                data_qtype: gguf.GGMLQuantizationType | None = None

                # when both are True, f32 should win
                extra_f32 = self.extra_f32_tensors(name, new_name, bid, n_dims)
                extra_f16 = self.extra_f16_tensors(name, new_name, bid, n_dims)

                # Most of the codebase that takes in 1D tensors or norms only handles F32 tensors
                # Conditions should closely match those in llama_model_quantize_internal in llama.cpp
                extra_f32 = any(cond for cond in (
                    extra_f32,
                    n_dims == 1,
                    new_name.endswith("_norm.weight"),
                ))

                # Some tensor types are always in float32
                extra_f32 = extra_f32 or any(self.match_model_tensor_name(new_name, key, bid) for key in (
                    gguf.MODEL_TENSOR.FFN_GATE_INP,
                    gguf.MODEL_TENSOR.POS_EMBD,
                    gguf.MODEL_TENSOR.TOKEN_TYPES,
                ))

                # if f16 desired, convert any float32 2-dim weight tensors to float16
                extra_f16 = any(cond for cond in (
                    extra_f16,
                    (name.endswith(".weight") and n_dims >= 2),
                ))

                suit_i2 = True
                if name.endswith('embed_tokens.weight') or name.endswith('norm.weight'):
                    suit_i2 = False

                i2_scale = None
                if self.ftype != gguf.LlamaFileType.ALL_F32 and extra_f16 and not extra_f32:
                    if self.ftype == gguf.LlamaFileType.MOSTLY_I2 and suit_i2:
                        data, i2_scale = self.transform_to_i2(data)
                        assert data.dtype == np.uint8
                        assert i2_scale.dtype == np.float32
                        data_qtype = gguf.GGMLQuantizationType.I2

                    elif self.ftype == gguf.LlamaFileType.MOSTLY_BF16:
                        data = gguf.quantize_bf16(data)
                        assert data.dtype == np.int16
                        data_qtype = gguf.GGMLQuantizationType.BF16

                    elif self.ftype == gguf.LlamaFileType.MOSTLY_Q8_0 and gguf.can_quantize_to_q8_0(data):
                        data = gguf.quantize_q8_0(data)
                        assert data.dtype == np.uint8
                        data_qtype = gguf.GGMLQuantizationType.Q8_0

                    else:  # default to float16 for quantized tensors
                        if data_dtype != np.float16:
                            data = data.astype(np.float16)
                        data_qtype = gguf.GGMLQuantizationType.F16

                if data_qtype is None:  # by default, convert to float32
                    if data_dtype != np.float32:
                        data = data.astype(np.float32)
                    data_qtype = gguf.GGMLQuantizationType.F32

                shape = data_shape
                # shape = gguf.quant_shape_from_byte_shape(data.shape, data_qtype) if data.dtype == np.uint8 else data.shape
                # reverse shape to make it similar to the internal ggml dimension order
                shape_str = f"{{{', '.join(str(n) for n in reversed(shape))}}}"

                # n_dims is implicit in the shape
                logger.info(f"{f'%-{max_name_len}s' % f'{new_name},'} {old_dtype} --> {data_qtype.name}, shape = {shape_str}")

                self.gguf_writer.add_tensor(new_name, data, raw_shape=shape, raw_dtype=data_qtype)
                if i2_scale is not None:
                    self.gguf_writer.add_tensor(new_name + "_scale", i2_scale, raw_dtype=gguf.GGMLQuantizationType.F32)


@Model.register("GrokForCausalLM")
class GrokModel(Model):
    model_arch = gguf.MODEL_ARCH.GROK
@@ -2804,7 +2925,7 @@ def parse_args() -> argparse.Namespace:
        help="path to write to; default: based on input. {ftype} will be replaced by the outtype.",
    )
    parser.add_argument(
-        "--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0", "auto"], default="f16",
+        "--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0", "i2", "auto"], default="f16",
        help="output format - use f32 for float32, f16 for float16, bf16 for bfloat16, q8_0 for Q8_0, auto for the highest-fidelity 16-bit float type depending on the first loaded tensor type",
    )
    parser.add_argument(
@@ -2864,6 +2985,7 @@ def main() -> None:
        "f16":  gguf.LlamaFileType.MOSTLY_F16,
        "bf16": gguf.LlamaFileType.MOSTLY_BF16,
        "q8_0": gguf.LlamaFileType.MOSTLY_Q8_0,
        "i2":   gguf.LlamaFileType.MOSTLY_I2,
        "auto": gguf.LlamaFileType.GUESSED,
    }
@@ -1016,6 +1016,73 @@ GGML_TABLE_BEGIN(uint32_t, iq3s_grid, 512)
    0x0f090307, 0x0f090501, 0x0f090b01, 0x0f0b0505, 0x0f0b0905, 0x0f0d0105, 0x0f0d0703, 0x0f0f0101,
GGML_TABLE_END()

GGML_TABLE_BEGIN(uint32_t, i2_q8, 256)
    0x00000000, 0x01000000, 0x00000000, 0xff000000,
    0x00010000, 0x01010000, 0x00010000, 0xff010000,
    0x00000000, 0x01000000, 0x00000000, 0xff000000,
    0x00ff0000, 0x01ff0000, 0x00ff0000, 0xffff0000,
    0x00000100, 0x01000100, 0x00000100, 0xff000100,
    0x00010100, 0x01010100, 0x00010100, 0xff010100,
    0x00000100, 0x01000100, 0x00000100, 0xff000100,
    0x00ff0100, 0x01ff0100, 0x00ff0100, 0xffff0100,
    0x00000000, 0x01000000, 0x00000000, 0xff000000,
    0x00010000, 0x01010000, 0x00010000, 0xff010000,
    0x00000000, 0x01000000, 0x00000000, 0xff000000,
    0x00ff0000, 0x01ff0000, 0x00ff0000, 0xffff0000,
    0x0000ff00, 0x0100ff00, 0x0000ff00, 0xff00ff00,
    0x0001ff00, 0x0101ff00, 0x0001ff00, 0xff01ff00,
    0x0000ff00, 0x0100ff00, 0x0000ff00, 0xff00ff00,
    0x00ffff00, 0x01ffff00, 0x00ffff00, 0xffffff00,
    0x00000001, 0x01000001, 0x00000001, 0xff000001,
    0x00010001, 0x01010001, 0x00010001, 0xff010001,
    0x00000001, 0x01000001, 0x00000001, 0xff000001,
    0x00ff0001, 0x01ff0001, 0x00ff0001, 0xffff0001,
    0x00000101, 0x01000101, 0x00000101, 0xff000101,
    0x00010101, 0x01010101, 0x00010101, 0xff010101,
    0x00000101, 0x01000101, 0x00000101, 0xff000101,
    0x00ff0101, 0x01ff0101, 0x00ff0101, 0xffff0101,
    0x00000001, 0x01000001, 0x00000001, 0xff000001,
    0x00010001, 0x01010001, 0x00010001, 0xff010001,
    0x00000001, 0x01000001, 0x00000001, 0xff000001,
    0x00ff0001, 0x01ff0001, 0x00ff0001, 0xffff0001,
    0x0000ff01, 0x0100ff01, 0x0000ff01, 0xff00ff01,
    0x0001ff01, 0x0101ff01, 0x0001ff01, 0xff01ff01,
    0x0000ff01, 0x0100ff01, 0x0000ff01, 0xff00ff01,
    0x00ffff01, 0x01ffff01, 0x00ffff01, 0xffffff01,
    0x00000000, 0x01000000, 0x00000000, 0xff000000,
    0x00010000, 0x01010000, 0x00010000, 0xff010000,
    0x00000000, 0x01000000, 0x00000000, 0xff000000,
    0x00ff0000, 0x01ff0000, 0x00ff0000, 0xffff0000,
    0x00000100, 0x01000100, 0x00000100, 0xff000100,
    0x00010100, 0x01010100, 0x00010100, 0xff010100,
    0x00000100, 0x01000100, 0x00000100, 0xff000100,
    0x00ff0100, 0x01ff0100, 0x00ff0100, 0xffff0100,
    0x00000000, 0x01000000, 0x00000000, 0xff000000,
    0x00010000, 0x01010000, 0x00010000, 0xff010000,
    0x00000000, 0x01000000, 0x00000000, 0xff000000,
    0x00ff0000, 0x01ff0000, 0x00ff0000, 0xffff0000,
    0x0000ff00, 0x0100ff00, 0x0000ff00, 0xff00ff00,
    0x0001ff00, 0x0101ff00, 0x0001ff00, 0xff01ff00,
    0x0000ff00, 0x0100ff00, 0x0000ff00, 0xff00ff00,
    0x00ffff00, 0x01ffff00, 0x00ffff00, 0xffffff00,
    0x000000ff, 0x010000ff, 0x000000ff, 0xff0000ff,
    0x000100ff, 0x010100ff, 0x000100ff, 0xff0100ff,
    0x000000ff, 0x010000ff, 0x000000ff, 0xff0000ff,
    0x00ff00ff, 0x01ff00ff, 0x00ff00ff, 0xffff00ff,
    0x000001ff, 0x010001ff, 0x000001ff, 0xff0001ff,
    0x000101ff, 0x010101ff, 0x000101ff, 0xff0101ff,
    0x000001ff, 0x010001ff, 0x000001ff, 0xff0001ff,
    0x00ff01ff, 0x01ff01ff, 0x00ff01ff, 0xffff01ff,
    0x000000ff, 0x010000ff, 0x000000ff, 0xff0000ff,
    0x000100ff, 0x010100ff, 0x000100ff, 0xff0100ff,
    0x000000ff, 0x010000ff, 0x000000ff, 0xff0000ff,
    0x00ff00ff, 0x01ff00ff, 0x00ff00ff, 0xffff00ff,
    0x0000ffff, 0x0100ffff, 0x0000ffff, 0xff00ffff,
    0x0001ffff, 0x0101ffff, 0x0001ffff, 0xff01ffff,
    0x0000ffff, 0x0100ffff, 0x0000ffff, 0xff00ffff,
    0x00ffffff, 0x01ffffff, 0x00ffffff, 0xffffffff,
GGML_TABLE_END()
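
Each i2_q8 entry expands one packed byte into four signed int8 weights, stored little-endian so that weight[0] is the low byte, with the 2-bit fields decoded as 0 -> 0, 1 -> +1, 2 -> 0, 3 -> -1 (that decoding rule is inferred from the listed values). A small generator, not part of the commit, that reproduces the table:

import numpy as np

# hypothetical generator: rebuilds the i2_q8 entries from the crumb mapping
lut = np.array([0, 1, 0, -1], dtype=np.int8)

def i2_q8_entry(b: int) -> int:
    w = [int(lut[(b >> (6 - 2 * j)) & 3]) & 0xff for j in range(4)]
    # little-endian: weight[0] is the low byte of the printed uint32
    return w[0] | (w[1] << 8) | (w[2] << 16) | (w[3] << 24)

table = [i2_q8_entry(b) for b in range(256)]
assert table[1] == 0x01000000 and table[3] == 0xff000000 and table[255] == 0xffffffff
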

#define NGRID_IQ1S 2048
#define IQ1S_DELTA 0.125f
#define IQ1M_DELTA 0.125f
@@ -3726,6 +3726,28 @@ static inline __m128i get_scale_shuffle(int i) {
}
#endif

//====================================== I2 ===============================================

void ggml_vec_dot_i2_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
    const int qk = QK8_0;
    const int nb = n / qk;

    const uint8_t * restrict x = vx;
    const int8_t  * restrict y = vy;

    int sumi = 0;

    for (int i = 0; i < n / 4; i++) {
        const int8_t * weight = (const int8_t *)(i2_q8 + x[i]);
        sumi += (int)y[i*4+0] * weight[0];
        sumi += (int)y[i*4+1] * weight[1];
        sumi += (int)y[i*4+2] * weight[2];
        sumi += (int)y[i*4+3] * weight[3];
    }
    *s = (float)(sumi);
}
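
In plain terms, the scalar loop above computes the raw integer dot product of the int8 activations against the unpacked ternary weights; the weight scale and the activation scale are applied later in the mul_mat path. A rough numpy equivalent (illustration only, not from the commit):

import numpy as np

def vec_dot_i2_q8_ref(packed_w: np.ndarray, q_act: np.ndarray) -> float:
    # illustrative reference, not part of the commit
    # packed_w: uint8, one byte per 4 weights; q_act: int8, len == 4 * len(packed_w)
    lut = np.array([0, 1, 0, -1], dtype=np.int32)
    crumbs = np.stack([(packed_w >> (6 - 2 * j)) & 3 for j in range(4)], axis=1).reshape(-1)
    w = lut[crumbs]
    return float(np.dot(q_act.astype(np.int32), w))
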

void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
    const int qk = QK8_0;
    const int nb = n / qk;
@@ -99,6 +99,7 @@ void ggml_vec_dot_iq1_m_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const
void ggml_vec_dot_iq4_nl_q8_0 (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_iq4_xs_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_iq3_s_q8_K  (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_i2_q8_0     (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);

// Quantization utilizing an importance matrix (a.k.a. "Activation aWare Quantization")
size_t quantize_iq2_xxs(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
ggml.c (202 changed lines)
@@ -569,6 +569,15 @@ static void ggml_vec_dot_f16(int n, float * restrict s, size_t bs, ggml_fp16_t *
static void ggml_vec_dot_bf16(int n, float * restrict s, size_t bs, ggml_bf16_t * restrict x, size_t bx, ggml_bf16_t * restrict y, size_t by, int nrc);

static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
    [GGML_TYPE_I2] = {
        .type_name    = "i2",
        .blck_size    = 1,
        .type_size    = sizeof(int8_t),
        .is_quantized = false,
        .vec_dot      = (ggml_vec_dot_t) ggml_vec_dot_i2_q8_0,
        .vec_dot_type = GGML_TYPE_Q8_0,
        .nrows        = 1,
    },
    [GGML_TYPE_I8] = {
        .type_name    = "i8",
        .blck_size    = 1,
@@ -1805,6 +1814,7 @@ inline static void ggml_vec_cpy_f32 (const int n, float * y, const float * x)
inline static void ggml_vec_neg_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = -x[i]; }
inline static void ggml_vec_mul_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]*y[i]; }
inline static void ggml_vec_div_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]/y[i]; }
inline static void ggml_vec_mul_f32_bitnet (const int n, float * y, const float x) { for (int i = 0; i < n; ++i) y[i] = y[i] * x; }

static void ggml_vec_dot_f32(int n, float * restrict s, size_t bs, const float * restrict x, size_t bx, const float * restrict y, size_t by, int nrc) {
    assert(nrc == 1);
@@ -2636,6 +2646,16 @@ inline static void ggml_vec_scaleroundclamp_f32(const int n, float * s, const fl
        s[i] /= scale;
    }
}

inline static void ggml_vec_scaleroundclamp_f32_v2(const int n, float * s, int8_t* inp, float scale, float min, float max) {
    for (int i = 0; i < n; ++i) {
        s[i] = round(s[i] * scale);
        if (s[i] > max) s[i] = max;
        if (s[i] < min) s[i] = min;
        inp[i] = (int8_t)(s[i]);
    }
}

//
// data types
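
These helpers are used by the bitnet matmul INIT pass further down to quantize each activation row: the row absmax (clamped to at least 1e-5 via ggml_vec_absmaxclamp_f32) gives s = 127 / absmax, values are rounded and clamped to [-128, 127], and 1/s is kept as the row's activation scale. A rough numpy sketch of that per-row step (illustrative; the function name is made up):

import numpy as np

def quantize_act_row(row: np.ndarray):
    # illustrative sketch, not part of the commit
    rowmax = max(np.max(np.abs(row)), 1e-5)                     # ggml_vec_absmaxclamp_f32
    s = 127.0 / rowmax
    q = np.clip(np.round(row * s), -128, 127).astype(np.int8)   # ggml_vec_scaleroundclamp_f32_v2
    act_scale = 1.0 / s
    return q, act_scale
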
@@ -3081,6 +3101,10 @@ GGML_CALL size_t ggml_nbytes(const struct ggml_tensor * tensor) {
        for (int i = 0; i < GGML_MAX_DIMS; ++i) {
            nbytes += (tensor->ne[i] - 1)*tensor->nb[i];
        }
        if (tensor->type == 31) {
            nbytes = nbytes / 4 + 32;
        }
    }
    else {
        nbytes = tensor->ne[0]*tensor->nb[0]/blck_size;
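
For GGML_TYPE_I2 (type id 31) the byte count becomes the element count divided by 4, four 2-bit weights per byte, plus 32 extra bytes, which appears to correspond to the scale that transform_to_i2 tiles to 8 float32 values; that reading is an inference from the converter and mul_mat code rather than something stated in the commit. The arithmetic for a contiguous 2-D tensor, as a sketch:

# illustrative only: expected data size of an i2 tensor with ne0 x ne1 weights
def i2_nbytes(ne0: int, ne1: int) -> int:
    packed = (ne0 * ne1) // 4      # 4 ternary weights per byte
    scale  = 8 * 4                 # scale tiled to 8 float32 values = 32 bytes
    return packed + scale
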
@@ -12411,7 +12435,10 @@ static void ggml_compute_forward_mul_mat_one_chunk(
    }

    const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
-    const size_t row_size = ggml_row_size(vec_dot_type, ne10);
+    size_t row_size = ggml_row_size(vec_dot_type, ne10);
+    if (src0->type == 31) {
+        row_size = ne10;
+    }

    assert(ne12 % ne02 == 0);
    assert(ne13 % ne03 == 0);
@@ -12425,6 +12452,9 @@ static void ggml_compute_forward_mul_mat_one_chunk(
    // attempt to reduce false-sharing (does not seem to make a difference)
    // 16 * 2, accounting for mmla kernels
    float tmp[32];
    uint8_t * i_weight   = (uint8_t *) (src0->data);
    float   * scale      = (float *) (i_weight + (ne00 * ne01 / 4));
    float   * act_scales = (float *) ((char *) wdata + ((ne11*nb11) / 4));

    for (int64_t iir1 = ir1_start; iir1 < ir1_end; iir1 += blck_1) {
        for (int64_t iir0 = ir0_start; iir0 < ir0_end; iir0 += blck_0) {
@@ -12458,9 +12488,15 @@ static void ggml_compute_forward_mul_mat_one_chunk(
            //}

            for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ir0 += num_rows_per_vec_dot) {
                if (src0->type == 31) {
                    vec_dot(ne00, &tmp[ir0 - iir0], (num_rows_per_vec_dot > 1 ? 16 : 0), src0_row + ir0 * nb01 / 4, (num_rows_per_vec_dot > 1 ? nb01 : 0), src1_col, (num_rows_per_vec_dot > 1 ? src1_col_stride : 0), num_rows_per_vec_dot);
                    tmp[ir0 - iir0] = tmp[ir0 - iir0] * (*scale) * (act_scales[i11]);
                } else {
                    vec_dot(ne00, &tmp[ir0 - iir0], (num_rows_per_vec_dot > 1 ? 16 : 0), src0_row + ir0 * nb01, (num_rows_per_vec_dot > 1 ? nb01 : 0), src1_col, (num_rows_per_vec_dot > 1 ? src1_col_stride : 0), num_rows_per_vec_dot);
                }
            }

            for (int cn = 0; cn < num_rows_per_vec_dot; ++cn) {
                memcpy(&dst_col[iir0 + cn * nb1 / nb0], tmp + (cn * 16), (MIN(iir0 + blck_0, ir0_end) - iir0) * sizeof(float));
            }
@@ -12469,6 +12505,164 @@ static void ggml_compute_forward_mul_mat_one_chunk(
    }
}


static void ggml_compute_forward_bitnet_mul_mat(
        const struct ggml_compute_params * params,
        struct ggml_tensor * dst,
        struct ggml_compute_state * state) {

    const struct ggml_tensor * src0 = dst->src[0];
    const struct ggml_tensor * src1 = dst->src[1];

    int64_t t0 = ggml_perf_time_us();
    UNUSED(t0);

    GGML_TENSOR_BINARY_OP_LOCALS

    const int ith = params->ith;
    const int nth = params->nth;

    const enum ggml_type type = src0->type;
    const bool src1_cont = ggml_is_contiguous(src1);

    GGML_ASSERT(ne0 == ne01);
    GGML_ASSERT(ne1 == ne11);
    GGML_ASSERT(ne2 == ne12);
    GGML_ASSERT(ne3 == ne13);

    // we don't support permuted src0 or src1
    GGML_ASSERT(nb00 == ggml_type_size(type));
    GGML_ASSERT(nb10 == ggml_type_size(src1->type));

    // dst cannot be transposed or permuted
    GGML_ASSERT(nb0 == sizeof(float));
    GGML_ASSERT(nb0 <= nb1);
    GGML_ASSERT(nb1 <= nb2);
    GGML_ASSERT(nb2 <= nb3);

    // broadcast factors
    const int64_t r2 = ne12 / ne02;
    const int64_t r3 = ne13 / ne03;
    UNUSED(r2);
    UNUSED(r3);

    // nb01 >= nb00 - src0 is not transposed
    // compute by src0 rows
    if (params->type == GGML_TASK_TYPE_INIT) {
        if (ith != 0) {
            return;
        }
        atomic_store(&state->shared->current_chunk, nth);
        char * wdata = params->wdata;
        float * act_scales = (float *) ((char *) wdata + ((ne11*nb11) / 4));
        for (int64_t i13 = 0; i13 < ne13; i13++) {
            for (int64_t i12 = 0; i12 < ne12; i12++) {
                for (int64_t i11 = 0; i11 < ne11; i11++) {
                    float rowmax = 0.00001;
                    ggml_vec_absmaxclamp_f32(ne10, &rowmax, (float *) ((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13), 0.00001);
                    float s = 127 / rowmax;
                    act_scales[i11] = 1/s;
                    ggml_vec_scaleroundclamp_f32_v2(ne10,
                            (float *)  ((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13),
                            (int8_t *) ((char *) wdata + ((i11*nb11 + i12*nb12 + i13*nb13) / 4)),
                            s, -128, 127);
                }
            }
        }
        // Every thread starts at ith, so the first unprocessed chunk is nth. This saves a bit of coordination right at the start.
        // atomic_store(&state->shared->current_chunk, nth);
        // // char * wdata = params->wdata;
        // const size_t row_size = ggml_row_size(GGML_TYPE_Q8_0, ne10);
        // // printf("vec_dot_type:%d\n", vec_dot_type);
        // // printf("row_size:%ld\n", row_size);
        // assert(params->wsize >= ne11*ne12*ne13*row_size);
        // GGML_ASSERT(src1->type == GGML_TYPE_F32);

        // for (int64_t i13 = 0; i13 < ne13; ++i13) {
        //     for (int64_t i12 = 0; i12 < ne12; ++i12) {
        //         for (int64_t i11 = 0; i11 < ne11; ++i11) {
        //             quantize_row_q8_0((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10);
        //             wdata += row_size;
        //         }
        //     }
        // }

        return;
    }

    if (params->type == GGML_TASK_TYPE_FINALIZE) {
        return;
    }

    // This is the size of the first dimension of the result, so we can iterate that way. (see the ASSERT above, these are the same numbers)
    const int64_t nr0 = ne0;

    // This is the size of the rest of the dimensions of the result
    const int64_t nr1 = ne1 * ne2 * ne3;

    // dot kernels can handle 1 row and col at a time, but mmla kernels can process 2 rows and cols
    int64_t num_rows_per_vec_dot = 1;
    // TODO: currently the mmla kernels support only even numbered rows/cols.
    // this check can be removed once they are extended to support odd numbered rows/cols too
    if ((nr0 % 2 != 0) || (ne11 % 2 != 0)) {
        num_rows_per_vec_dot = 1;
    }

    // Now select a reasonable chunk size.
    int chunk_size = 16;

    // We need to step up the size if it's small
    if (nr0 == 1 || nr1 == 1) {
        chunk_size = 64;
    }

    // distribute the work across the inner or outer loop based on which one is larger
    // The number of chunks in the 0/1 dim.
    // CEIL(nr0/chunk_size)
    int64_t nchunk0 = (nr0 + chunk_size - 1) / chunk_size;
    int64_t nchunk1 = (nr1 + chunk_size - 1) / chunk_size;

    // If the chunking is poor for the number of threads on this setup, scrap the whole plan. Re-chunk it by thread.
    // Also, chunking by thread was measured to perform better on NUMA systems. See https://github.com/ggerganov/llama.cpp/pull/6915
    // In theory, chunking should be just as useful on NUMA and non NUMA systems, but testing disagreed with that.
    if (nchunk0 * nchunk1 < nth * 4 || ggml_is_numa()) {
        // distribute the thread work across the inner or outer loop based on which one is larger
        nchunk0 = nr0 > nr1 ? nth : 1; // parallelize by src0 rows
        nchunk1 = nr0 > nr1 ? 1 : nth; // parallelize by src1 rows
    }

    // The number of elements in each chunk
    const int64_t dr0 = (nr0 + nchunk0 - 1) / nchunk0;
    const int64_t dr1 = (nr1 + nchunk1 - 1) / nchunk1;

    //if (ith == 0)
    //    printf("MUL_MAT = [%d, %d, %d, %d] x [%d, %d, %d, %d] = %d x %d = %d. Fp Ops/Ch %d\n", ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, nchunk0, nchunk1, nchunk0 * nchunk1, ne00 * nr0 * nr1 / nchunk0 / nchunk1);

    // The first chunk comes from our thread_id, the rest will get auto-assigned.
    int current_chunk = ith;

    while (current_chunk < nchunk0 * nchunk1) {
        const int64_t ith0 = current_chunk % nchunk0;
        const int64_t ith1 = current_chunk / nchunk0;

        const int64_t ir0_start = dr0 * ith0;
        const int64_t ir0_end = MIN(ir0_start + dr0, nr0);

        const int64_t ir1_start = dr1 * ith1;
        const int64_t ir1_end = MIN(ir1_start + dr1, nr1);

        ggml_compute_forward_mul_mat_one_chunk(params, dst, num_rows_per_vec_dot, ir0_start, ir0_end, ir1_start, ir1_end);

        if (nth >= nchunk0 * nchunk1) {
            break;
        }

        current_chunk = atomic_fetch_add(&state->shared->current_chunk, 1);
    }
}
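
Taken together, for an i2 weight row the path above computes dst ≈ (Σ q_k · w_k) · weight_scale · act_scale: quantize the activation row to int8, take the integer dot product against the ternary weights, then rescale by the two stored scales. A compact, self-contained numpy reference for one output element (illustration only, not from the commit):

import numpy as np

def bitnet_row_ref(w_ternary: np.ndarray, w_scale: float, act: np.ndarray) -> float:
    # illustrative end-to-end reference, not part of the commit
    s = 127.0 / max(np.max(np.abs(act)), 1e-5)                   # activation scale
    q = np.clip(np.round(act * s), -128, 127).astype(np.int8)    # quantized activations
    sumi = int(np.dot(q.astype(np.int32), w_ternary.astype(np.int32)))  # raw dot (ggml_vec_dot_i2_q8_0)
    return sumi * w_scale * (1.0 / s)                            # rescale, as in the mul_mat chunk

The result approximates np.dot(act, w_ternary * w_scale), up to the int8 rounding of the activations.
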

static void ggml_compute_forward_mul_mat(
        const struct ggml_compute_params * params,
        struct ggml_tensor * dst,
@@ -12482,6 +12676,11 @@ static void ggml_compute_forward_mul_mat(

    GGML_TENSOR_BINARY_OP_LOCALS

    if (src0->type == 31) {
        ggml_compute_forward_bitnet_mul_mat(params, dst, state);
        return;
    }

    const int ith = params->ith;
    const int nth = params->nth;
@@ -14349,6 +14548,7 @@ static void ggml_compute_forward_clamp(
        case GGML_TYPE_I32:
        case GGML_TYPE_I64:
        case GGML_TYPE_F64:
        case GGML_TYPE_I2:
        case GGML_TYPE_COUNT:
            {
                GGML_ASSERT(false);
ggml.h (1 changed line)
@@ -377,6 +377,7 @@ extern "C" {
        GGML_TYPE_F64   = 28,
        GGML_TYPE_IQ1_M = 29,
        GGML_TYPE_BF16  = 30,
        GGML_TYPE_I2    = 31,
        GGML_TYPE_COUNT,
    };
@@ -925,6 +925,7 @@ class GGMLQuantizationType(IntEnum):
    F64   = 28
    IQ1_M = 29
    BF16  = 30
    I2    = 31


# TODO: add GGMLFileType from ggml_ftype in ggml.h
@@ -966,6 +967,7 @@ class LlamaFileType(IntEnum):
    MOSTLY_IQ4_XS = 30  # except 1d tensors
    MOSTLY_IQ1_M  = 31  # except 1d tensors
    MOSTLY_BF16   = 32  # except 1d tensors
    MOSTLY_I2     = 33  # except 1d tensors

    GUESSED = 1024  # not specified in the model file

@@ -1032,6 +1034,7 @@ GGML_QUANT_SIZES: dict[GGMLQuantizationType, tuple[int, int]] = {
    GGMLQuantizationType.IQ3_S:  (256, 2 + QK_K // 4 + QK_K // 8 + QK_K // 32 + 4),
    GGMLQuantizationType.IQ2_S:  (256, 2 + QK_K // 4 + QK_K // 16),
    GGMLQuantizationType.IQ4_XS: (256, 2 + 2 + QK_K // 2 + QK_K // 64),
    GGMLQuantizationType.I2:     (1, 1),
    GGMLQuantizationType.I8:     (1, 1),
    GGMLQuantizationType.I16:    (1, 2),
    GGMLQuantizationType.I32:    (1, 4),
@@ -225,8 +225,10 @@ class GGUFWriter:
                dtype = GGMLQuantizationType.I32
            elif tensor_dtype == np.int64:
                dtype = GGMLQuantizationType.I64
            elif tensor_dtype == np.uint8:
                dtype = GGMLQuantizationType.I2
            else:
-                raise ValueError("Only F16, F32, F64, I8, I16, I32, I64 tensors are supported for now")
+                raise ValueError("Only F16, F32, F64, I8, I16, I32, I64, I2 tensors are supported for now")
        else:
            dtype = raw_dtype
            if tensor_dtype == np.uint8:
@@ -237,7 +239,10 @@ class GGUFWriter:
            self.ti_data += self._pack("Q", tensor_shape[n_dims - 1 - i])
        self.ti_data += self._pack("I", dtype)
        self.ti_data += self._pack("Q", self.offset_tensor)
-        self.offset_tensor += GGUFWriter.ggml_pad(tensor_nbytes, self.data_alignment)
+        if dtype == GGMLQuantizationType.I2:
+            self.offset_tensor += GGUFWriter.ggml_pad(tensor_nbytes, self.data_alignment) + self.data_alignment
+        else:
+            self.offset_tensor += GGUFWriter.ggml_pad(tensor_nbytes, self.data_alignment)
        self.ti_data_count += 1

    def add_tensor(
@@ -252,7 +257,9 @@ class GGUFWriter:
            self.temp_file = fp

        shape: Sequence[int] = raw_shape if raw_shape is not None else tensor.shape
-        self.add_tensor_info(name, shape, tensor.dtype, tensor.nbytes, raw_dtype = raw_dtype)
+        if (raw_dtype != GGMLQuantizationType.F32 or not name.endswith("scale")):
+            self.add_tensor_info(name, shape, tensor.dtype, tensor.nbytes, raw_dtype = raw_dtype)

        if self.temp_file is None:
            self.tensors.append(tensor)
llama.cpp (17 changed lines)
@@ -3192,8 +3192,9 @@ struct llama_model_loader {

        llama_tensor_weight(const llama_file * file, uint16_t idx, const char * name, const struct gguf_context * gguf_ctx, ggml_tensor * tensor) : idx(idx), tensor(tensor) {
            const int tensor_idx = gguf_find_tensor(gguf_ctx, name);
            printf("name:%s\n", name);
            offs = gguf_get_data_offset(gguf_ctx) + gguf_get_tensor_offset(gguf_ctx, tensor_idx);

            printf("offs:%ld\n", offs + ggml_nbytes(tensor));
            if (offs + ggml_nbytes(tensor) < offs || offs + ggml_nbytes(tensor) > file->size) {
                throw std::runtime_error(format("tensor '%s' data is not within the file bounds, model is corrupted or incomplete", name));
            }
@@ -7140,7 +7141,7 @@ static struct ggml_tensor * llm_build_kqv(
        cb(cur, "attn_sub_norm", il);

        // B2 for wo
-        cur = llm_build_qbitlinear(ctx, cur);
+        // cur = llm_build_qbitlinear(ctx, cur);
    }

    ggml_build_forward_expand(graph, cur);
@@ -11563,7 +11564,7 @@ struct llm_build_context {
            {
                // compute Q and K and RoPE them
                // B1.Q
-                cur = llm_build_qbitlinear(ctx0, cur);
+                // cur = llm_build_qbitlinear(ctx0, cur);
                struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
                cb(Qcur, "Qcur", il);
                if (model.layers[il].bq) {
@@ -11635,7 +11636,7 @@ struct llm_build_context {
            // cb(cur, "ffn_out", il);

-            cur = llm_build_qbitlinear(ctx0, cur);
+            // cur = llm_build_qbitlinear(ctx0, cur);

            struct ggml_tensor * tmp = ggml_mul_mat(ctx0, model.layers[il].ffn_up, cur);

@@ -11658,7 +11659,7 @@ struct llm_build_context {
            cb(cur, "ffn_sub_norm", il);

            // B4 for w2
-            cur = llm_build_qbitlinear(ctx0, cur);
+            // cur = llm_build_qbitlinear(ctx0, cur);

            cur = ggml_mul_mat(ctx0, model.layers[il].ffn_down, cur);
            cb(cur, "ffn_down", il);
@@ -15684,6 +15685,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
        case LLAMA_FTYPE_MOSTLY_IQ4_XS: default_type = GGML_TYPE_IQ4_XS; break;
        case LLAMA_FTYPE_MOSTLY_IQ3_S:  default_type = GGML_TYPE_IQ3_S;  break;
        case LLAMA_FTYPE_MOSTLY_IQ3_M:  default_type = GGML_TYPE_IQ3_S;  break;
        case LLAMA_FTYPE_MOSTLY_I2:     default_type = GGML_TYPE_I2;     break;

        default: throw std::runtime_error(format("invalid output file type %d\n", ftype));
    }
@@ -15921,7 +15923,10 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
        if (params->output_tensor_type < GGML_TYPE_COUNT && strcmp(tensor->name, "output.weight") == 0) {
            new_type = params->output_tensor_type;
        }

        if (tensor->type == 31) {
            // no need to quantize for i2
            new_type = tensor->type;
        }
        // If we've decided to quantize to the same type the tensor is already
        // in then there's nothing to do.
        quantize = tensor->type != new_type;
llama.h (1 changed line)
@@ -156,6 +156,7 @@ extern "C" {
        LLAMA_FTYPE_MOSTLY_IQ4_XS = 30, // except 1d tensors
        LLAMA_FTYPE_MOSTLY_IQ1_M  = 31, // except 1d tensors
        LLAMA_FTYPE_MOSTLY_BF16   = 32, // except 1d tensors
        LLAMA_FTYPE_MOSTLY_I2     = 33,

        LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file
    };