From 5e59660173d050628647f9a8c6f96830f57ce052 Mon Sep 17 00:00:00 2001 From: Eddie-Wang1120 Date: Fri, 7 Jun 2024 14:42:52 +0800 Subject: [PATCH] finish f16 hf bitnet e2e --- convert-hf-to-gguf.py | 124 +++++++++++++++++++++- ggml-common.h | 67 ++++++++++++ ggml-quants.c | 22 ++++ ggml-quants.h | 1 + ggml.c | 202 +++++++++++++++++++++++++++++++++++- ggml.h | 1 + gguf-py/gguf/constants.py | 3 + gguf-py/gguf/gguf_writer.py | 13 ++- llama.cpp | 17 +-- llama.h | 1 + 10 files changed, 440 insertions(+), 11 deletions(-) diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index d39bb3bd1..42d67aca4 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -1421,7 +1421,31 @@ class BitnetModel(Model): result = (weight * s).round().clamp(-1, 1) / s return result.type(dtype) + def transform_to_i2(self, x): + from gguf.lazy import LazyNumpyTensor + x = LazyNumpyTensor.to_eager(x) + x_num = np.prod(x.shape) + x = np.reshape(x, x_num) + scale = 1 + for i in range(x_num): + if x[i] != 0: + scale = x[i] + break + x = np.divide(x, scale) + x = x.astype(np.uint8) + x = np.reshape(x, [x.shape[0] // 4, 4]) + keep_bit = {0:192, 1:48, 2:12, 3:3} + ans = np.zeros([x_num // 4], dtype=np.uint8) + for i in range(4): + x_bit_col = x[:, i] + x_bit_shift = np.left_shift(x_bit_col, 6 - i * 2) + x_bit_shift = np.bitwise_and(x_bit_shift, keep_bit[i]) + ans = np.bitwise_or(ans, x_bit_shift) + scale = np.tile(scale, 8) + return ans, scale + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + # quant weight to i2 (in fp16) if name.endswith(("q_proj.weight", "k_proj.weight", "v_proj.weight", "down_proj.weight", "up_proj.weight", "gate_proj.weight", "o_proj.weight")): @@ -1429,6 +1453,103 @@ class BitnetModel(Model): return [(self.map_tensor_name(name), data_torch)] + def write_tensors(self): + max_name_len = max(len(s) for _, s in self.tensor_map.mapping.values()) + len(".weight,") + + for name, data_torch in self.get_tensors(): + # we don't need these + if name.endswith((".attention.masked_bias", ".attention.bias", ".rotary_emb.inv_freq")): + continue + + old_dtype = data_torch.dtype + + # convert any unsupported data types to float32 + if data_torch.dtype not in (torch.float16, torch.float32): + data_torch = data_torch.to(torch.float32) + + # use the first number-like part of the tensor name as the block id + bid = None + for part in name.split("."): + if part.isdecimal(): + bid = int(part) + break + + for new_name, data in ((n, d.squeeze().numpy()) for n, d in self.modify_tensors(data_torch, name, bid)): + data: np.ndarray = data # type hint + data_shape = data.shape + n_dims = len(data.shape) + data_dtype = data.dtype + data_qtype: gguf.GGMLQuantizationType | None = None + + # when both are True, f32 should win + extra_f32 = self.extra_f32_tensors(name, new_name, bid, n_dims) + extra_f16 = self.extra_f16_tensors(name, new_name, bid, n_dims) + + # Most of the codebase that takes in 1D tensors or norms only handles F32 tensors + # Conditions should closely match those in llama_model_quantize_internal in llama.cpp + extra_f32 = any(cond for cond in ( + extra_f32, + n_dims == 1, + new_name.endswith("_norm.weight"), + )) + + # Some tensor types are always in float32 + extra_f32 = extra_f32 or any(self.match_model_tensor_name(new_name, key, bid) for key in ( + gguf.MODEL_TENSOR.FFN_GATE_INP, + gguf.MODEL_TENSOR.POS_EMBD, + gguf.MODEL_TENSOR.TOKEN_TYPES, + )) + + # if f16 desired, convert any float32 2-dim weight tensors to float16 + extra_f16 = any(cond 
for cond in ( + extra_f16, + (name.endswith(".weight") and n_dims >= 2), + )) + + suit_i2 = True + if name.endswith('embed_tokens.weight') or name.endswith('norm.weight'): + suit_i2 = False + + i2_scale = None + if self.ftype != gguf.LlamaFileType.ALL_F32 and extra_f16 and not extra_f32: + if self.ftype == gguf.LlamaFileType.MOSTLY_I2 and suit_i2: + data, i2_scale = self.transform_to_i2(data) + assert data.dtype == np.uint8 + assert i2_scale.dtype == np.float32 + data_qtype = gguf.GGMLQuantizationType.I2 + + elif self.ftype == gguf.LlamaFileType.MOSTLY_BF16: + data = gguf.quantize_bf16(data) + assert data.dtype == np.int16 + data_qtype = gguf.GGMLQuantizationType.BF16 + + elif self.ftype == gguf.LlamaFileType.MOSTLY_Q8_0 and gguf.can_quantize_to_q8_0(data): + data = gguf.quantize_q8_0(data) + assert data.dtype == np.uint8 + data_qtype = gguf.GGMLQuantizationType.Q8_0 + + else: # default to float16 for quantized tensors + if data_dtype != np.float16: + data = data.astype(np.float16) + data_qtype = gguf.GGMLQuantizationType.F16 + + if data_qtype is None: # by default, convert to float32 + if data_dtype != np.float32: + data = data.astype(np.float32) + data_qtype = gguf.GGMLQuantizationType.F32 + + shape = data_shape + # shape = gguf.quant_shape_from_byte_shape(data.shape, data_qtype) if data.dtype == np.uint8 else data.shape + # reverse shape to make it similar to the internal ggml dimension order + shape_str = f"{{{', '.join(str(n) for n in reversed(shape))}}}" + + # n_dims is implicit in the shape + logger.info(f"{f'%-{max_name_len}s' % f'{new_name},'} {old_dtype} --> {data_qtype.name}, shape = {shape_str}") + + self.gguf_writer.add_tensor(new_name, data, raw_shape=shape, raw_dtype=data_qtype) + if i2_scale is not None: + self.gguf_writer.add_tensor(new_name + "_scale", i2_scale, raw_dtype=gguf.GGMLQuantizationType.F32) + @Model.register("GrokForCausalLM") class GrokModel(Model): model_arch = gguf.MODEL_ARCH.GROK @@ -2804,7 +2925,7 @@ def parse_args() -> argparse.Namespace: help="path to write to; default: based on input. 
{ftype} will be replaced by the outtype.", ) parser.add_argument( - "--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0", "auto"], default="f16", + "--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0", "i2", "auto"], default="f16", help="output format - use f32 for float32, f16 for float16, bf16 for bfloat16, q8_0 for Q8_0, auto for the highest-fidelity 16-bit float type depending on the first loaded tensor type", ) parser.add_argument( @@ -2864,6 +2985,7 @@ def main() -> None: "f16": gguf.LlamaFileType.MOSTLY_F16, "bf16": gguf.LlamaFileType.MOSTLY_BF16, "q8_0": gguf.LlamaFileType.MOSTLY_Q8_0, + "i2" : gguf.LlamaFileType.MOSTLY_I2, "auto": gguf.LlamaFileType.GUESSED, } diff --git a/ggml-common.h b/ggml-common.h index 77e6bfba4..a3d4f7a56 100644 --- a/ggml-common.h +++ b/ggml-common.h @@ -1016,6 +1016,73 @@ GGML_TABLE_BEGIN(uint32_t, iq3s_grid, 512) 0x0f090307, 0x0f090501, 0x0f090b01, 0x0f0b0505, 0x0f0b0905, 0x0f0d0105, 0x0f0d0703, 0x0f0f0101, GGML_TABLE_END() +GGML_TABLE_BEGIN(uint32_t, i2_q8, 256) +0x00000000, 0x01000000, 0x00000000, 0xff000000, +0x00010000, 0x01010000, 0x00010000, 0xff010000, +0x00000000, 0x01000000, 0x00000000, 0xff000000, +0x00ff0000, 0x01ff0000, 0x00ff0000, 0xffff0000, +0x00000100, 0x01000100, 0x00000100, 0xff000100, +0x00010100, 0x01010100, 0x00010100, 0xff010100, +0x00000100, 0x01000100, 0x00000100, 0xff000100, +0x00ff0100, 0x01ff0100, 0x00ff0100, 0xffff0100, +0x00000000, 0x01000000, 0x00000000, 0xff000000, +0x00010000, 0x01010000, 0x00010000, 0xff010000, +0x00000000, 0x01000000, 0x00000000, 0xff000000, +0x00ff0000, 0x01ff0000, 0x00ff0000, 0xffff0000, +0x0000ff00, 0x0100ff00, 0x0000ff00, 0xff00ff00, +0x0001ff00, 0x0101ff00, 0x0001ff00, 0xff01ff00, +0x0000ff00, 0x0100ff00, 0x0000ff00, 0xff00ff00, +0x00ffff00, 0x01ffff00, 0x00ffff00, 0xffffff00, +0x00000001, 0x01000001, 0x00000001, 0xff000001, +0x00010001, 0x01010001, 0x00010001, 0xff010001, +0x00000001, 0x01000001, 0x00000001, 0xff000001, +0x00ff0001, 0x01ff0001, 0x00ff0001, 0xffff0001, +0x00000101, 0x01000101, 0x00000101, 0xff000101, +0x00010101, 0x01010101, 0x00010101, 0xff010101, +0x00000101, 0x01000101, 0x00000101, 0xff000101, +0x00ff0101, 0x01ff0101, 0x00ff0101, 0xffff0101, +0x00000001, 0x01000001, 0x00000001, 0xff000001, +0x00010001, 0x01010001, 0x00010001, 0xff010001, +0x00000001, 0x01000001, 0x00000001, 0xff000001, +0x00ff0001, 0x01ff0001, 0x00ff0001, 0xffff0001, +0x0000ff01, 0x0100ff01, 0x0000ff01, 0xff00ff01, +0x0001ff01, 0x0101ff01, 0x0001ff01, 0xff01ff01, +0x0000ff01, 0x0100ff01, 0x0000ff01, 0xff00ff01, +0x00ffff01, 0x01ffff01, 0x00ffff01, 0xffffff01, +0x00000000, 0x01000000, 0x00000000, 0xff000000, +0x00010000, 0x01010000, 0x00010000, 0xff010000, +0x00000000, 0x01000000, 0x00000000, 0xff000000, +0x00ff0000, 0x01ff0000, 0x00ff0000, 0xffff0000, +0x00000100, 0x01000100, 0x00000100, 0xff000100, +0x00010100, 0x01010100, 0x00010100, 0xff010100, +0x00000100, 0x01000100, 0x00000100, 0xff000100, +0x00ff0100, 0x01ff0100, 0x00ff0100, 0xffff0100, +0x00000000, 0x01000000, 0x00000000, 0xff000000, +0x00010000, 0x01010000, 0x00010000, 0xff010000, +0x00000000, 0x01000000, 0x00000000, 0xff000000, +0x00ff0000, 0x01ff0000, 0x00ff0000, 0xffff0000, +0x0000ff00, 0x0100ff00, 0x0000ff00, 0xff00ff00, +0x0001ff00, 0x0101ff00, 0x0001ff00, 0xff01ff00, +0x0000ff00, 0x0100ff00, 0x0000ff00, 0xff00ff00, +0x00ffff00, 0x01ffff00, 0x00ffff00, 0xffffff00, +0x000000ff, 0x010000ff, 0x000000ff, 0xff0000ff, +0x000100ff, 0x010100ff, 0x000100ff, 0xff0100ff, +0x000000ff, 0x010000ff, 0x000000ff, 0xff0000ff, +0x00ff00ff, 0x01ff00ff, 
0x00ff00ff, 0xffff00ff, +0x000001ff, 0x010001ff, 0x000001ff, 0xff0001ff, +0x000101ff, 0x010101ff, 0x000101ff, 0xff0101ff, +0x000001ff, 0x010001ff, 0x000001ff, 0xff0001ff, +0x00ff01ff, 0x01ff01ff, 0x00ff01ff, 0xffff01ff, +0x000000ff, 0x010000ff, 0x000000ff, 0xff0000ff, +0x000100ff, 0x010100ff, 0x000100ff, 0xff0100ff, +0x000000ff, 0x010000ff, 0x000000ff, 0xff0000ff, +0x00ff00ff, 0x01ff00ff, 0x00ff00ff, 0xffff00ff, +0x0000ffff, 0x0100ffff, 0x0000ffff, 0xff00ffff, +0x0001ffff, 0x0101ffff, 0x0001ffff, 0xff01ffff, +0x0000ffff, 0x0100ffff, 0x0000ffff, 0xff00ffff, +0x00ffffff, 0x01ffffff, 0x00ffffff, 0xffffffff, +GGML_TABLE_END() + #define NGRID_IQ1S 2048 #define IQ1S_DELTA 0.125f #define IQ1M_DELTA 0.125f diff --git a/ggml-quants.c b/ggml-quants.c index 9f864e5c4..a13211307 100644 --- a/ggml-quants.c +++ b/ggml-quants.c @@ -3726,6 +3726,28 @@ static inline __m128i get_scale_shuffle(int i) { } #endif +//====================================== I2 =============================================== + +void ggml_vec_dot_i2_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { + const int qk = QK8_0; + const int nb = n / qk; + + const uint8_t * restrict x = vx; + const int8_t * restrict y = vy; + + int sumi = 0; + + for (int i = 0; i < n / 4; i++) { + int8_t* weight = (const int8_t *)(i2_q8 + x[i]); + sumi += (int)y[i*4+0] * weight[0]; + sumi += (int)y[i*4+1] * weight[1]; + sumi += (int)y[i*4+2] * weight[2]; + sumi += (int)y[i*4+3] * weight[3]; + } + *s = (float)(sumi); + +} + void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { const int qk = QK8_0; const int nb = n / qk; diff --git a/ggml-quants.h b/ggml-quants.h index 4d436a8f0..1c8e3839d 100644 --- a/ggml-quants.h +++ b/ggml-quants.h @@ -99,6 +99,7 @@ void ggml_vec_dot_iq1_m_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void ggml_vec_dot_iq4_nl_q8_0 (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_iq4_xs_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void ggml_vec_dot_i2_q8_0 (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); // Quantization utilizing an importance matrix (a.k.a. 
"Activation aWare Quantization") size_t quantize_iq2_xxs(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); diff --git a/ggml.c b/ggml.c index 4c3e6f723..335cb1cdd 100644 --- a/ggml.c +++ b/ggml.c @@ -569,6 +569,15 @@ static void ggml_vec_dot_f16(int n, float * restrict s, size_t bs, ggml_fp16_t * static void ggml_vec_dot_bf16(int n, float * restrict s, size_t bs, ggml_bf16_t * restrict x, size_t bx, ggml_bf16_t * restrict y, size_t by, int nrc); static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { + [GGML_TYPE_I2] = { + .type_name = "i2", + .blck_size = 1, + .type_size = sizeof(int8_t), + .is_quantized = false, + .vec_dot = (ggml_vec_dot_t) ggml_vec_dot_i2_q8_0, + .vec_dot_type = GGML_TYPE_Q8_0, + .nrows = 1, + }, [GGML_TYPE_I8] = { .type_name = "i8", .blck_size = 1, @@ -1805,6 +1814,7 @@ inline static void ggml_vec_cpy_f32 (const int n, float * y, const float * x) inline static void ggml_vec_neg_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = -x[i]; } inline static void ggml_vec_mul_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]*y[i]; } inline static void ggml_vec_div_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]/y[i]; } +inline static void ggml_vec_mul_f32_bitnet (const int n, float * y, const float x) { for (int i = 0; i < n; ++i) y[i] = y[i] * x; } static void ggml_vec_dot_f32(int n, float * restrict s, size_t bs, const float * restrict x, size_t bx, const float * restrict y, size_t by, int nrc) { assert(nrc == 1); @@ -2636,6 +2646,16 @@ inline static void ggml_vec_scaleroundclamp_f32(const int n, float * s, const fl s[i] /= scale; } } +inline static void ggml_vec_scaleroundclamp_f32_v2(const int n, float * s, int8_t* inp, float scale, float min, float max) { + + for (int i = 0; i < n; ++i) { + s[i] = round(s[i] * scale); + if (s[i] > max) s[i] = max; + if (s[i] < min) s[i] = min; + inp[i] = (int8_t)(s[i]); + } + +} // // data types @@ -3081,6 +3101,10 @@ GGML_CALL size_t ggml_nbytes(const struct ggml_tensor * tensor) { for (int i = 0; i < GGML_MAX_DIMS; ++i) { nbytes += (tensor->ne[i] - 1)*tensor->nb[i]; } + if(tensor->type == 31){ + nbytes = nbytes / 4 + 32; + } + } else { nbytes = tensor->ne[0]*tensor->nb[0]/blck_size; @@ -12411,7 +12435,10 @@ static void ggml_compute_forward_mul_mat_one_chunk( } const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata; - const size_t row_size = ggml_row_size(vec_dot_type, ne10); + size_t row_size = ggml_row_size(vec_dot_type, ne10); + if (src0->type == 31) { + row_size = ne10; + } assert(ne12 % ne02 == 0); assert(ne13 % ne03 == 0); @@ -12425,6 +12452,9 @@ static void ggml_compute_forward_mul_mat_one_chunk( // attempt to reduce false-sharing (does not seem to make a difference) // 16 * 2, accounting for mmla kernels float tmp[32]; + uint8_t *i_weight = (uint8_t*) (src0->data); + float * scale = (float * )((i_weight) + (ne00 * ne01 / 4)); + float* act_scales = (float*) ((char *) wdata + ((ne11*nb11) / 4)); for (int64_t iir1 = ir1_start; iir1 < ir1_end; iir1 += blck_1) { for (int64_t iir0 = ir0_start; iir0 < ir0_end; iir0 += blck_0) { @@ -12458,9 +12488,15 @@ static void ggml_compute_forward_mul_mat_one_chunk( //} for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ir0 += num_rows_per_vec_dot) { + if (src0->type == 31) { + vec_dot(ne00, &tmp[ir0 - iir0], (num_rows_per_vec_dot > 1 ? 
16 : 0), src0_row + ir0 * nb01 / 4, (num_rows_per_vec_dot > 1 ? nb01 : 0), src1_col, (num_rows_per_vec_dot > 1 ? src1_col_stride : 0), num_rows_per_vec_dot); + tmp[ir0 - iir0] = tmp[ir0 - iir0] * (*scale) * (act_scales[i11]); + }else { vec_dot(ne00, &tmp[ir0 - iir0], (num_rows_per_vec_dot > 1 ? 16 : 0), src0_row + ir0 * nb01, (num_rows_per_vec_dot > 1 ? nb01 : 0), src1_col, (num_rows_per_vec_dot > 1 ? src1_col_stride : 0), num_rows_per_vec_dot); } + } + for (int cn = 0; cn < num_rows_per_vec_dot; ++cn) { memcpy(&dst_col[iir0 + cn * nb1 / nb0], tmp + (cn * 16), (MIN(iir0 + blck_0, ir0_end) - iir0) * sizeof(float)); } @@ -12469,6 +12505,164 @@ static void ggml_compute_forward_mul_mat_one_chunk( } } + +static void ggml_compute_forward_bitnet_mul_mat( + const struct ggml_compute_params * params, + struct ggml_tensor * dst, + struct ggml_compute_state * state) { + + const struct ggml_tensor * src0 = dst->src[0]; + const struct ggml_tensor * src1 = dst->src[1]; + + int64_t t0 = ggml_perf_time_us(); + UNUSED(t0); + + GGML_TENSOR_BINARY_OP_LOCALS + + const int ith = params->ith; + const int nth = params->nth; + + const enum ggml_type type = src0->type; + const bool src1_cont = ggml_is_contiguous(src1); + + GGML_ASSERT(ne0 == ne01); + GGML_ASSERT(ne1 == ne11); + GGML_ASSERT(ne2 == ne12); + GGML_ASSERT(ne3 == ne13); + + // we don't support permuted src0 or src1 + GGML_ASSERT(nb00 == ggml_type_size(type)); + GGML_ASSERT(nb10 == ggml_type_size(src1->type)); + + // dst cannot be transposed or permuted + GGML_ASSERT(nb0 == sizeof(float)); + GGML_ASSERT(nb0 <= nb1); + GGML_ASSERT(nb1 <= nb2); + GGML_ASSERT(nb2 <= nb3); + + // broadcast factors + const int64_t r2 = ne12 / ne02; + const int64_t r3 = ne13 / ne03; + UNUSED(r2); + UNUSED(r3); + + // nb01 >= nb00 - src0 is not transposed + // compute by src0 rows + if (params->type == GGML_TASK_TYPE_INIT) { + if (ith != 0) { + return; + } + atomic_store(&state->shared->current_chunk, nth); + char * wdata = params->wdata; + float* act_scales = (float*) ((char *) wdata + ((ne11*nb11) / 4)); + for (int64_t i13 = 0; i13 < ne13; i13++) { + for (int64_t i12 = 0; i12 < ne12; i12++) { + for (int64_t i11 = 0; i11 < ne11; i11++) { + float rowmax = 0.00001; + ggml_vec_absmaxclamp_f32(ne10, &rowmax, (float *) ((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13), 0.00001); + float s = 127 / rowmax; + act_scales[i11] = 1/s; + ggml_vec_scaleroundclamp_f32_v2(ne10, + (float *) ((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13), + (int8_t*) ((char *) wdata + ((i11*nb11 + i12*nb12 + i13*nb13) / 4)), + s, -128, 127); + } + } + } + // Every thread starts at ith, so the first unprocessed chunk is nth. This save a bit of coordination right at the start. + // atomic_store(&state->shared->current_chunk, nth); + // // char * wdata = params->wdata; + // const size_t row_size = ggml_row_size(GGML_TYPE_Q8_0, ne10); + // // printf("vec_dot_type:%d\n", vec_dot_type); + // // printf("row_size:%ld\n", row_size); + // assert(params->wsize >= ne11*ne12*ne13*row_size); + // GGML_ASSERT(src1->type == GGML_TYPE_F32); + + // for (int64_t i13 = 0; i13 < ne13; ++i13) { + // for (int64_t i12 = 0; i12 < ne12; ++i12) { + // for (int64_t i11 = 0; i11 < ne11; ++i11) { + // quantize_row_q8_0((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10); + // wdata += row_size; + // } + // } + // } + + + return; + } + + if (params->type == GGML_TASK_TYPE_FINALIZE) { + return; + } + + // This is the size of the first dimension of the result, so we can iterate that way. 
(see the ASSERT above, these are the same numbers) + const int64_t nr0 = ne0; + + // This is the size of the rest of the dimensions of the result + const int64_t nr1 = ne1 * ne2 * ne3; + + // dot kernels can handle 1 row and col at a time, but mmla kernels can process 2 rows and cols + int64_t num_rows_per_vec_dot = 1; + // TODO: currently the mmla kernels support only even numbered rows/cols. + // this check can be removed once they are extended to support odd numbered rows/cols too + if ((nr0 % 2 != 0) || (ne11 % 2 != 0)) { + num_rows_per_vec_dot = 1; + } + + // Now select a reasonable chunk size. + int chunk_size = 16; + + // We need to step up the size if it's small + if (nr0 == 1 || nr1 == 1) { + chunk_size = 64; + } + + // distribute the work across the inner or outer loop based on which one is larger + // The number of chunks in the 0/1 dim. + // CEIL(nr0/chunk_size) + int64_t nchunk0 = (nr0 + chunk_size - 1) / chunk_size; + int64_t nchunk1 = (nr1 + chunk_size - 1) / chunk_size; + + // If the chunking is poor for the number of threads on this setup, scrap the whole plan. Re-chunk it by thread. + // Also, chunking by thread was measured to have perform better on NUMA systems. See https://github.com/ggerganov/llama.cpp/pull/6915 + // In theory, chunking should be just as useful on NUMA and non NUMA systems, but testing disagreed with that. + if (nchunk0 * nchunk1 < nth * 4 || ggml_is_numa()) { + // distribute the thread work across the inner or outer loop based on which one is larger + nchunk0 = nr0 > nr1 ? nth : 1; // parallelize by src0 rows + nchunk1 = nr0 > nr1 ? 1 : nth; // parallelize by src1 rows + } + + // The number of elements in each chunk + const int64_t dr0 = (nr0 + nchunk0 - 1) / nchunk0; + const int64_t dr1 = (nr1 + nchunk1 - 1) / nchunk1; + + //if (ith == 0) + // printf("MUL_MAT = [%d, %d, %d, %d] x [%d, %d, %d, %d] = %d x %d = %d. Fp Ops/Ch %d\n", ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, nchunk0, nchunk1, nchunk0 * nchunk1, ne00 * nr0 * nr1 / nchunk0 / nchunk1); + + // The first chunk comes from our thread_id, the rest will get auto-assigned. 
+ int current_chunk = ith; + + while (current_chunk < nchunk0 * nchunk1) { + const int64_t ith0 = current_chunk % nchunk0; + const int64_t ith1 = current_chunk / nchunk0; + + const int64_t ir0_start = dr0 * ith0; + const int64_t ir0_end = MIN(ir0_start + dr0, nr0); + + const int64_t ir1_start = dr1 * ith1; + const int64_t ir1_end = MIN(ir1_start + dr1, nr1); + + ggml_compute_forward_mul_mat_one_chunk(params, dst, num_rows_per_vec_dot, ir0_start, ir0_end, ir1_start, ir1_end); + + if (nth >= nchunk0 * nchunk1) { + break; + } + + current_chunk = atomic_fetch_add(&state->shared->current_chunk, 1); + } + +} + static void ggml_compute_forward_mul_mat( const struct ggml_compute_params * params, struct ggml_tensor * dst, @@ -12482,6 +12676,11 @@ static void ggml_compute_forward_mul_mat( GGML_TENSOR_BINARY_OP_LOCALS + if (src0->type == 31) { + ggml_compute_forward_bitnet_mul_mat(params, dst, state); + return; + } + const int ith = params->ith; const int nth = params->nth; @@ -14349,6 +14548,7 @@ static void ggml_compute_forward_clamp( case GGML_TYPE_I32: case GGML_TYPE_I64: case GGML_TYPE_F64: + case GGML_TYPE_I2: case GGML_TYPE_COUNT: { GGML_ASSERT(false); diff --git a/ggml.h b/ggml.h index 98ef96132..5d540aa30 100644 --- a/ggml.h +++ b/ggml.h @@ -377,6 +377,7 @@ extern "C" { GGML_TYPE_F64 = 28, GGML_TYPE_IQ1_M = 29, GGML_TYPE_BF16 = 30, + GGML_TYPE_I2 = 31, GGML_TYPE_COUNT, }; diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 429f38189..5e94edb22 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -925,6 +925,7 @@ class GGMLQuantizationType(IntEnum): F64 = 28 IQ1_M = 29 BF16 = 30 + I2 = 31 # TODO: add GGMLFileType from ggml_ftype in ggml.h @@ -966,6 +967,7 @@ class LlamaFileType(IntEnum): MOSTLY_IQ4_XS = 30 # except 1d tensors MOSTLY_IQ1_M = 31 # except 1d tensors MOSTLY_BF16 = 32 # except 1d tensors + MOSTLY_I2 = 33 # except 1d tensors GUESSED = 1024 # not specified in the model file @@ -1032,6 +1034,7 @@ GGML_QUANT_SIZES: dict[GGMLQuantizationType, tuple[int, int]] = { GGMLQuantizationType.IQ3_S: (256, 2 + QK_K // 4 + QK_K // 8 + QK_K // 32 + 4), GGMLQuantizationType.IQ2_S: (256, 2 + QK_K // 4 + QK_K // 16), GGMLQuantizationType.IQ4_XS: (256, 2 + 2 + QK_K // 2 + QK_K // 64), + GGMLQuantizationType.I2: (1, 1), GGMLQuantizationType.I8: (1, 1), GGMLQuantizationType.I16: (1, 2), GGMLQuantizationType.I32: (1, 4), diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index b93747aff..2d19cd44c 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -225,8 +225,10 @@ class GGUFWriter: dtype = GGMLQuantizationType.I32 elif tensor_dtype == np.int64: dtype = GGMLQuantizationType.I64 + elif tensor_dtype == np.uint8: + dtype = GGMLQuantizationType.I2 else: - raise ValueError("Only F16, F32, F64, I8, I16, I32, I64 tensors are supported for now") + raise ValueError("Only F16, F32, F64, I8, I16, I32, I64, I2 tensors are supported for now") else: dtype = raw_dtype if tensor_dtype == np.uint8: @@ -237,7 +239,10 @@ class GGUFWriter: self.ti_data += self._pack("Q", tensor_shape[n_dims - 1 - i]) self.ti_data += self._pack("I", dtype) self.ti_data += self._pack("Q", self.offset_tensor) - self.offset_tensor += GGUFWriter.ggml_pad(tensor_nbytes, self.data_alignment) + if dtype == GGMLQuantizationType.I2: + self.offset_tensor += GGUFWriter.ggml_pad(tensor_nbytes, self.data_alignment) + self.data_alignment + else: + self.offset_tensor += GGUFWriter.ggml_pad(tensor_nbytes, self.data_alignment) self.ti_data_count += 1 def add_tensor( @@ 
-252,7 +257,9 @@ class GGUFWriter: self.temp_file = fp shape: Sequence[int] = raw_shape if raw_shape is not None else tensor.shape - self.add_tensor_info(name, shape, tensor.dtype, tensor.nbytes, raw_dtype = raw_dtype) + + if (raw_dtype != GGMLQuantizationType.F32 or not name.endswith("scale")): + self.add_tensor_info(name, shape, tensor.dtype, tensor.nbytes, raw_dtype = raw_dtype) if self.temp_file is None: self.tensors.append(tensor) diff --git a/llama.cpp b/llama.cpp index 38dbf31e0..53f047333 100644 --- a/llama.cpp +++ b/llama.cpp @@ -3192,8 +3192,9 @@ struct llama_model_loader { llama_tensor_weight(const llama_file * file, uint16_t idx, const char * name, const struct gguf_context * gguf_ctx, ggml_tensor * tensor) : idx(idx), tensor(tensor) { const int tensor_idx = gguf_find_tensor(gguf_ctx, name); + printf("name:%s\n", name); offs = gguf_get_data_offset(gguf_ctx) + gguf_get_tensor_offset(gguf_ctx, tensor_idx); - + printf("offs:%ld\n", offs + ggml_nbytes(tensor)); if (offs + ggml_nbytes(tensor) < offs || offs + ggml_nbytes(tensor) > file->size) { throw std::runtime_error(format("tensor '%s' data is not within the file bounds, model is corrupted or incomplete", name)); } @@ -7140,7 +7141,7 @@ static struct ggml_tensor * llm_build_kqv( cb(cur, "attn_sub_norm", il); // B2 for wo - cur = llm_build_qbitlinear(ctx, cur); + // cur = llm_build_qbitlinear(ctx, cur); } ggml_build_forward_expand(graph, cur); @@ -11563,7 +11564,7 @@ struct llm_build_context { { // compute Q and K and RoPE them // B1.Q - cur = llm_build_qbitlinear(ctx0, cur); + // cur = llm_build_qbitlinear(ctx0, cur); struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur); cb(Qcur, "Qcur", il); if (model.layers[il].bq) { @@ -11635,7 +11636,7 @@ struct llm_build_context { // cb(cur, "ffn_out", il); - cur = llm_build_qbitlinear(ctx0, cur); + // cur = llm_build_qbitlinear(ctx0, cur); struct ggml_tensor *tmp = ggml_mul_mat(ctx0, model.layers[il].ffn_up, cur); @@ -11658,7 +11659,7 @@ struct llm_build_context { cb(cur, "ffn_sub_norm", il); // B4 for w2 - cur = llm_build_qbitlinear(ctx0, cur); + // cur = llm_build_qbitlinear(ctx0, cur); cur = ggml_mul_mat(ctx0, model.layers[il].ffn_down, cur); cb(cur, "ffn_down", il); @@ -15684,6 +15685,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s case LLAMA_FTYPE_MOSTLY_IQ4_XS: default_type = GGML_TYPE_IQ4_XS; break; case LLAMA_FTYPE_MOSTLY_IQ3_S: default_type = GGML_TYPE_IQ3_S; break; case LLAMA_FTYPE_MOSTLY_IQ3_M: default_type = GGML_TYPE_IQ3_S; break; + case LLAMA_FTYPE_MOSTLY_I2 : default_type = GGML_TYPE_I2; break; default: throw std::runtime_error(format("invalid output file type %d\n", ftype)); } @@ -15921,7 +15923,10 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s if (params->output_tensor_type < GGML_TYPE_COUNT && strcmp(tensor->name, "output.weight") == 0) { new_type = params->output_tensor_type; } - + if (tensor->type == 31) { + // no need quantize for i2 + new_type = tensor->type; + } // If we've decided to quantize to the same type the tensor is already // in then there's nothing to do. quantize = tensor->type != new_type; diff --git a/llama.h b/llama.h index a78ccdaf5..1a225fa61 100644 --- a/llama.h +++ b/llama.h @@ -156,6 +156,7 @@ extern "C" { LLAMA_FTYPE_MOSTLY_IQ4_XS = 30, // except 1d tensors LLAMA_FTYPE_MOSTLY_IQ1_M = 31, // except 1d tensors LLAMA_FTYPE_MOSTLY_BF16 = 32, // except 1d tensors + LLAMA_FTYPE_MOSTLY_I2 = 33, LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file };
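
Appendix (reviewer note, not part of the patch): a minimal NumPy sketch of the 2-bit weight scheme the patch introduces, assuming the same bit layout as transform_to_i2() and the i2_q8 table above -- four ternary weights per byte, first weight in the high bits, with -1 encoded as 0b11, +1 as 0b01 and 0 as 0b00. The names pack_i2, dot_i2_q8 and DECODE are illustrative only; in the real ggml_vec_dot_i2_q8_0 path the integer sum is additionally multiplied by the weight scale stored behind the packed data and by the per-row activation scale (act_scales), which this sketch leaves out by using unscaled ternary weights and raw int8 activations.

import numpy as np

# 2-bit code -> signed weight, mirroring what the i2_q8 lookup table encodes
# (code 2 is unused and decodes to 0, as in the table).
DECODE = np.array([0, 1, 0, -1], dtype=np.int8)

def pack_i2(w: np.ndarray) -> np.ndarray:
    """Pack ternary weights {-1, 0, +1} into bytes, 4 weights per byte."""
    assert w.size % 4 == 0
    # -1 -> 0b11, 0 -> 0b00, +1 -> 0b01 (go through int8 so -1 wraps predictably)
    codes = (w.astype(np.int8).astype(np.uint8) & 3).reshape(-1, 4)
    packed = np.zeros(codes.shape[0], dtype=np.uint8)
    for j in range(4):
        # weight j of each group lands in bits (7-2j)..(6-2j); first weight in bits 7..6
        packed |= (codes[:, j] << (6 - 2 * j)).astype(np.uint8)
    return packed

def dot_i2_q8(packed: np.ndarray, y: np.ndarray) -> int:
    """Integer dot product of packed 2-bit weights with int8 activations."""
    acc = 0
    for j in range(4):
        w_j = DECODE[(packed >> (6 - 2 * j)) & 3]          # decode weight j of every byte
        acc += int(np.dot(w_j.astype(np.int32), y[j::4].astype(np.int32)))
    return acc

rng = np.random.default_rng(0)
w = rng.integers(-1, 2, size=32).astype(np.float32)   # ternary weights, already divided by their scale
y = rng.integers(-128, 128, size=32).astype(np.int8)  # int8 activations, as produced by the Q8 rounding step

# decode(pack(w)) reproduces w exactly, so the integer dot matches the float reference
assert dot_i2_q8(pack_i2(w), y) == int(np.dot(w, y.astype(np.float32)))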