From dbee0a86c1ea144b3f9546e964cee9d7151498e9 Mon Sep 17 00:00:00 2001 From: root Date: Sun, 9 Jun 2024 18:20:32 +0800 Subject: [PATCH] move i2 to quantize --- convert-hf-to-gguf.py | 139 +++--------------------------------------- ggml-quants.c | 53 ++++++++++------ ggml.c | 6 +- 3 files changed, 48 insertions(+), 150 deletions(-) diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index 735630b9c..d98967e25 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -1418,142 +1418,17 @@ class BitnetModel(Model): dtype = weight.dtype weight = weight.float() s = 1 / weight.abs().mean().clamp(min=1e-5) - # from gguf.lazy import LazyNumpyTensor - # np_s = LazyNumpyTensor.to_eager(s.numpy()) - - # print(np_s) result = (weight * s).round().clamp(-1, 1) / s return result.type(dtype) - def transform_to_i2(self, x): - from gguf.lazy import LazyNumpyTensor - x = LazyNumpyTensor.to_eager(x) - x_num = np.prod(x.shape) - x = np.reshape(x, x_num) - scale = 1 - for i in range(x_num): - if x[i] != 0: - scale = x[i] - break - x = np.where(x * scale > 0, 1, np.where(x * scale < 0, -1, x)) - x = x.astype(np.uint8) - x = np.reshape(x, [x.shape[0] // 4, 4]) - keep_bit = {0:192, 1:48, 2:12, 3:3} - ans = np.zeros([x_num // 4], dtype=np.uint8) - for i in range(4): - x_bit_col = x[:, i] - x_bit_shift = np.left_shift(x_bit_col, 6 - i * 2) - x_bit_shift = np.bitwise_and(x_bit_shift, keep_bit[i]) - ans = np.bitwise_or(ans, x_bit_shift) - scale = np.tile(scale, 8) - return ans, scale + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + # transform weight into 1/0/-1 (in fp32) + if name.endswith(("q_proj.weight", "k_proj.weight", "v_proj.weight", + "down_proj.weight", "up_proj.weight", "gate_proj.weight", + "o_proj.weight")): + data_torch = data_torch + (self.weight_quant(data_torch) - data_torch).detach() - # def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: - # # quant weight to i2 (in fp16) - # if name.endswith(("q_proj.weight", "k_proj.weight", "v_proj.weight", - # "down_proj.weight", "up_proj.weight", "gate_proj.weight", - # "o_proj.weight")): - # print(name) - # data_torch = data_torch + (self.weight_quant(data_torch) - data_torch).detach() - - # return [(self.map_tensor_name(name), data_torch)] - - def write_tensors(self): - max_name_len = max(len(s) for _, s in self.tensor_map.mapping.values()) + len(".weight,") - - for name, data_torch in self.get_tensors(): - # we don't need these - if name.endswith((".attention.masked_bias", ".attention.bias", ".rotary_emb.inv_freq")): - continue - - old_dtype = data_torch.dtype - - # convert any unsupported data types to float32 - if data_torch.dtype not in (torch.float16, torch.float32): - data_torch = data_torch.to(torch.float32) - - # use the first number-like part of the tensor name as the block id - bid = None - for part in name.split("."): - if part.isdecimal(): - bid = int(part) - break - - for new_name, data in ((n, d.squeeze().numpy()) for n, d in self.modify_tensors(data_torch, name, bid)): - data: np.ndarray = data # type hint - data_shape = data.shape - n_dims = len(data.shape) - data_dtype = data.dtype - data_qtype: gguf.GGMLQuantizationType | None = None - - # when both are True, f32 should win - extra_f32 = self.extra_f32_tensors(name, new_name, bid, n_dims) - extra_f16 = self.extra_f16_tensors(name, new_name, bid, n_dims) - - # Most of the codebase that takes in 1D tensors or norms only handles F32 tensors - # Conditions should closely match those in llama_model_quantize_internal in llama.cpp - extra_f32 = any(cond for cond in ( - extra_f32, - n_dims == 1, - new_name.endswith("_norm.weight"), - )) - - # Some tensor types are always in float32 - extra_f32 = extra_f32 or any(self.match_model_tensor_name(new_name, key, bid) for key in ( - gguf.MODEL_TENSOR.FFN_GATE_INP, - gguf.MODEL_TENSOR.POS_EMBD, - gguf.MODEL_TENSOR.TOKEN_TYPES, - )) - - # if f16 desired, convert any float32 2-dim weight tensors to float16 - extra_f16 = any(cond for cond in ( - extra_f16, - (name.endswith(".weight") and n_dims >= 2), - )) - - suit_i2 = True - if name.endswith('embed_tokens.weight') or name.endswith('norm.weight'): - suit_i2 = False - - i2_scale = None - if self.ftype != gguf.LlamaFileType.ALL_F32 and extra_f16 and not extra_f32: - if self.ftype == gguf.LlamaFileType.MOSTLY_I2 and suit_i2: - data, i2_scale = self.transform_to_i2(data) - assert data.dtype == np.uint8 - assert i2_scale.dtype == np.float32 - data_qtype = gguf.GGMLQuantizationType.I2 - - elif self.ftype == gguf.LlamaFileType.MOSTLY_BF16: - data = gguf.quantize_bf16(data) - assert data.dtype == np.int16 - data_qtype = gguf.GGMLQuantizationType.BF16 - - elif self.ftype == gguf.LlamaFileType.MOSTLY_Q8_0 and gguf.can_quantize_to_q8_0(data): - data = gguf.quantize_q8_0(data) - assert data.dtype == np.uint8 - data_qtype = gguf.GGMLQuantizationType.Q8_0 - - else: # default to float16 for quantized tensors - if data_dtype != np.float16: - data = data.astype(np.float16) - data_qtype = gguf.GGMLQuantizationType.F16 - - if data_qtype is None: # by default, convert to float32 - if data_dtype != np.float32: - data = data.astype(np.float32) - data_qtype = gguf.GGMLQuantizationType.F32 - - shape = data_shape - # shape = gguf.quant_shape_from_byte_shape(data.shape, data_qtype) if data.dtype == np.uint8 else data.shape - # reverse shape to make it similar to the internal ggml dimension order - shape_str = f"{{{', '.join(str(n) for n in reversed(shape))}}}" - - # n_dims is implicit in the shape - logger.info(f"{f'%-{max_name_len}s' % f'{new_name},'} {old_dtype} --> {data_qtype.name}, shape = {shape_str}") - - self.gguf_writer.add_tensor(new_name, data, raw_shape=shape, raw_dtype=data_qtype) - if i2_scale is not None: - self.gguf_writer.add_tensor(new_name + "_scale", i2_scale, raw_dtype=gguf.GGMLQuantizationType.F32) + return [(self.map_tensor_name(name), data_torch)] @Model.register("GrokForCausalLM") class GrokModel(Model): diff --git a/ggml-quants.c b/ggml-quants.c index 96d3c88f6..a4a72c847 100644 --- a/ggml-quants.c +++ b/ggml-quants.c @@ -3308,29 +3308,47 @@ size_t quantize_q8_0(const float * restrict src, void * restrict dst, int64_t nr size_t quantize_i2_s(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { // 2 bits per weight - size_t row_size = ggml_row_size(GGML_TYPE_I2, n_per_row) / 4; - char * qrow = (char *)dst; - printf("n_row:%d\n", nrow); - printf("n_per_row:%d\n", n_per_row); + size_t row_size = ggml_row_size(GGML_TYPE_I2, n_per_row); + int n = nrow * n_per_row; - float accu = 0.0; - float min = 0.00001; - for (int i = 0; i < n; ++i) { - accu += fabs(src[i]); + + // f32 -> q8 + double i2_scale = 0; + for (int i=0; i 1e-6) { + i2_scale = src[i]; + } } - accu = accu > min ? accu : min; - float scale = n / accu; - printf("\nscale:%f\n", scale); + uint8_t* q8 = (uint8_t*)dst; + for (int i=0; i 0 ? 1 : 3; + } - // for (int64_t row = 0; row < nrow; ++row) { - // quantize_row_q4_0_impl(src, (block_q4_0*)qrow, n_per_row, quant_weights); - // src += n_per_row; - // qrow += row_size; - // } + // q8 -> 0, 1, 3 + // | | | + // 0, 1,-1 + + uint8_t* i2_weight = (uint8_t*)dst; + for (int i=0; i