move i2 to quantize

This commit is contained in:
root 2024-06-09 18:20:32 +08:00
parent ca09085593
commit dbee0a86c1
3 changed files with 48 additions and 150 deletions


@@ -1418,142 +1418,17 @@ class BitnetModel(Model):
        dtype = weight.dtype
        weight = weight.float()
        s = 1 / weight.abs().mean().clamp(min=1e-5)
        # from gguf.lazy import LazyNumpyTensor
        # np_s = LazyNumpyTensor.to_eager(s.numpy())
        # print(np_s)
        result = (weight * s).round().clamp(-1, 1) / s
        return result.type(dtype)
    def transform_to_i2(self, x):
        from gguf.lazy import LazyNumpyTensor
        x = LazyNumpyTensor.to_eager(x)
        x_num = np.prod(x.shape)
        x = np.reshape(x, x_num)
        scale = 1
        for i in range(x_num):
            if x[i] != 0:
                scale = x[i]
                break
        x = np.where(x * scale > 0, 1, np.where(x * scale < 0, -1, x))
        x = x.astype(np.uint8)
        x = np.reshape(x, [x.shape[0] // 4, 4])
        keep_bit = {0:192, 1:48, 2:12, 3:3}
        ans = np.zeros([x_num // 4], dtype=np.uint8)
        for i in range(4):
            x_bit_col = x[:, i]
            x_bit_shift = np.left_shift(x_bit_col, 6 - i * 2)
            x_bit_shift = np.bitwise_and(x_bit_shift, keep_bit[i])
            ans = np.bitwise_or(ans, x_bit_shift)
        scale = np.tile(scale, 8)
        return ans, scale
    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
        # transform weight into 1/0/-1 (in fp32)
        if name.endswith(("q_proj.weight", "k_proj.weight", "v_proj.weight",
                          "down_proj.weight", "up_proj.weight", "gate_proj.weight",
                          "o_proj.weight")):
            data_torch = data_torch + (self.weight_quant(data_torch) - data_torch).detach()
    # def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
    #     # quant weight to i2 (in fp16)
    #     if name.endswith(("q_proj.weight", "k_proj.weight", "v_proj.weight",
    #                       "down_proj.weight", "up_proj.weight", "gate_proj.weight",
    #                       "o_proj.weight")):
    #         print(name)
    #         data_torch = data_torch + (self.weight_quant(data_torch) - data_torch).detach()
    #     return [(self.map_tensor_name(name), data_torch)]
    def write_tensors(self):
        max_name_len = max(len(s) for _, s in self.tensor_map.mapping.values()) + len(".weight,")
        for name, data_torch in self.get_tensors():
            # we don't need these
            if name.endswith((".attention.masked_bias", ".attention.bias", ".rotary_emb.inv_freq")):
                continue
            old_dtype = data_torch.dtype
            # convert any unsupported data types to float32
            if data_torch.dtype not in (torch.float16, torch.float32):
                data_torch = data_torch.to(torch.float32)
            # use the first number-like part of the tensor name as the block id
            bid = None
            for part in name.split("."):
                if part.isdecimal():
                    bid = int(part)
                    break
            for new_name, data in ((n, d.squeeze().numpy()) for n, d in self.modify_tensors(data_torch, name, bid)):
                data: np.ndarray = data  # type hint
                data_shape = data.shape
                n_dims = len(data.shape)
                data_dtype = data.dtype
                data_qtype: gguf.GGMLQuantizationType | None = None
                # when both are True, f32 should win
                extra_f32 = self.extra_f32_tensors(name, new_name, bid, n_dims)
                extra_f16 = self.extra_f16_tensors(name, new_name, bid, n_dims)
                # Most of the codebase that takes in 1D tensors or norms only handles F32 tensors
                # Conditions should closely match those in llama_model_quantize_internal in llama.cpp
                extra_f32 = any(cond for cond in (
                    extra_f32,
                    n_dims == 1,
                    new_name.endswith("_norm.weight"),
                ))
                # Some tensor types are always in float32
                extra_f32 = extra_f32 or any(self.match_model_tensor_name(new_name, key, bid) for key in (
                    gguf.MODEL_TENSOR.FFN_GATE_INP,
                    gguf.MODEL_TENSOR.POS_EMBD,
                    gguf.MODEL_TENSOR.TOKEN_TYPES,
                ))
                # if f16 desired, convert any float32 2-dim weight tensors to float16
                extra_f16 = any(cond for cond in (
                    extra_f16,
                    (name.endswith(".weight") and n_dims >= 2),
                ))
                suit_i2 = True
                if name.endswith('embed_tokens.weight') or name.endswith('norm.weight'):
                    suit_i2 = False
                i2_scale = None
                if self.ftype != gguf.LlamaFileType.ALL_F32 and extra_f16 and not extra_f32:
                    if self.ftype == gguf.LlamaFileType.MOSTLY_I2 and suit_i2:
                        data, i2_scale = self.transform_to_i2(data)
                        assert data.dtype == np.uint8
                        assert i2_scale.dtype == np.float32
                        data_qtype = gguf.GGMLQuantizationType.I2
                    elif self.ftype == gguf.LlamaFileType.MOSTLY_BF16:
                        data = gguf.quantize_bf16(data)
                        assert data.dtype == np.int16
                        data_qtype = gguf.GGMLQuantizationType.BF16
                    elif self.ftype == gguf.LlamaFileType.MOSTLY_Q8_0 and gguf.can_quantize_to_q8_0(data):
                        data = gguf.quantize_q8_0(data)
                        assert data.dtype == np.uint8
                        data_qtype = gguf.GGMLQuantizationType.Q8_0
                    else:  # default to float16 for quantized tensors
                        if data_dtype != np.float16:
                            data = data.astype(np.float16)
                        data_qtype = gguf.GGMLQuantizationType.F16
                if data_qtype is None:  # by default, convert to float32
                    if data_dtype != np.float32:
                        data = data.astype(np.float32)
                    data_qtype = gguf.GGMLQuantizationType.F32
                shape = data_shape
                # shape = gguf.quant_shape_from_byte_shape(data.shape, data_qtype) if data.dtype == np.uint8 else data.shape
                # reverse shape to make it similar to the internal ggml dimension order
                shape_str = f"{{{', '.join(str(n) for n in reversed(shape))}}}"
                # n_dims is implicit in the shape
                logger.info(f"{f'%-{max_name_len}s' % f'{new_name},'} {old_dtype} --> {data_qtype.name}, shape = {shape_str}")
                self.gguf_writer.add_tensor(new_name, data, raw_shape=shape, raw_dtype=data_qtype)
                if i2_scale is not None:
                    self.gguf_writer.add_tensor(new_name + "_scale", i2_scale, raw_dtype=gguf.GGMLQuantizationType.F32)
        return [(self.map_tensor_name(name), data_torch)]
@Model.register("GrokForCausalLM")
class GrokModel(Model):
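
Note on the surviving Python path: after this commit the converter only snaps weights to ternary values in fp32. weight_quant computes an absmean scale s = 1 / mean(|W|) (clamped at 1e-5), rounds W * s into {-1, 0, +1}, and divides by s again; the 2-bit packing now happens at quantize time on the C side. The sketch below restates that ternarization in plain C so it can be compared with quantize_i2_s further down. ternarize_absmean is a hypothetical helper for illustration, not code added by this commit.

#include <math.h>
#include <stddef.h>

// Hypothetical helper (not part of this commit): per-tensor absmean
// ternarization, mirroring BitnetModel.weight_quant on the Python side.
static void ternarize_absmean(const float * src, float * dst, size_t n) {
    // s = 1 / mean(|w|), with the mean clamped away from zero (clamp(min=1e-5))
    double sum_abs = 0.0;
    for (size_t i = 0; i < n; i++) {
        sum_abs += fabsf(src[i]);
    }
    double mean_abs = sum_abs / (double) n;
    if (mean_abs < 1e-5) {
        mean_abs = 1e-5;
    }
    const float s = (float) (1.0 / mean_abs);

    // round(w * s), clamp to [-1, 1], then scale back by 1/s
    for (size_t i = 0; i < n; i++) {
        float q = roundf(src[i] * s);
        if (q >  1.0f) { q =  1.0f; }
        if (q < -1.0f) { q = -1.0f; }
        dst[i] = q / s;
    }
}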


@@ -3308,29 +3308,47 @@ size_t quantize_q8_0(const float * restrict src, void * restrict dst, int64_t nr
size_t quantize_i2_s(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
    // 2 bits per weight
    size_t row_size = ggml_row_size(GGML_TYPE_I2, n_per_row) / 4;
    char * qrow = (char *)dst;
    printf("n_row:%d\n", nrow);
    printf("n_per_row:%d\n", n_per_row);
    size_t row_size = ggml_row_size(GGML_TYPE_I2, n_per_row);
    int n = nrow * n_per_row;
    float accu = 0.0;
    float min = 0.00001;
    for (int i = 0; i < n; ++i) {
        accu += fabs(src[i]);
    // f32 -> q8
    double i2_scale = 0;
    for (int i = 0; i < n; i++) {
        if (fabs(src[i]) > 1e-6) {
            i2_scale = src[i];
        }
    }
    accu = accu > min ? accu : min;
    float scale = n / accu;
    printf("\nscale:%f\n", scale);
    uint8_t * q8 = (uint8_t *)dst;
    for (int i = 0; i < n; i++) {
        if (fabs(src[i]) < 1e-6) {
            q8[i] = 0;
            continue;
        }
        q8[i] = src[i] * i2_scale > 0 ? 1 : 3;
    }
    // for (int64_t row = 0; row < nrow; ++row) {
    //     quantize_row_q4_0_impl(src, (block_q4_0*)qrow, n_per_row, quant_weights);
    //     src += n_per_row;
    //     qrow += row_size;
    // }
    // q8 -> 0, 1, 3
    //       |  |  |
    //       0, 1, -1
    uint8_t * i2_weight = (uint8_t *)dst;
    for (int i = 0; i < n; i++) {
        int group_idx = i / 4;
        int group_pos = i % 4;
        uint8_t temp = (q8[i] << (6 - 2 * group_pos));
        q8[i] = 0;
        i2_weight[group_idx] |= temp;
    }
    float * scale_ptr = (float *)((char *)i2_weight + n / 4);
    for (int i = 0; i < 8; i++) {
        scale_ptr[i] = i2_scale;
    }
    // 32B for scale
    return nrow * row_size + 32;
    return nrow * row_size / 4 + 32;
}
// ====================== "True" 2-bit (de)-quantization
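
Taken together, quantize_i2_s writes one buffer per tensor: n/4 bytes of packed 2-bit codes (the first element of each group of four sits in the two most significant bits; code 0 means 0, 1 means +1, 3 means -1), followed by a 32-byte tail holding the reference scale replicated as 8 floats, which is where the extra "+ 32" in the return value comes from. Below is a minimal decoding sketch under the assumption that a consumer multiplies the ternary code by the stored scale; the dequantization path itself is not part of this diff, and dequantize_i2_sketch is a hypothetical name.

#include <stddef.h>
#include <stdint.h>

// Hypothetical helper (not part of this commit): unpack a buffer produced by
// quantize_i2_s back into floats. Assumed layout: n/4 bytes of packed 2-bit
// codes, then 32 bytes (8 floats) all holding the same reference scale.
static void dequantize_i2_sketch(const void * packed, float * dst, size_t n) {
    const uint8_t * codes = (const uint8_t *) packed;
    const float   * scale = (const float *) (codes + n / 4);   // 32-byte tail

    for (size_t i = 0; i < n; i++) {
        const size_t  byte_idx = i / 4;
        const int     pos      = (int) (i % 4);
        const uint8_t code     = (codes[byte_idx] >> (6 - 2 * pos)) & 0x3;

        // code 0 -> 0, code 1 -> +1, code 3 -> -1 (the "q8 -> 0, 1, 3" mapping)
        const float w = (code == 1) ? 1.0f : (code == 3) ? -1.0f : 0.0f;
        dst[i] = w * scale[0];
    }
}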
@@ -14413,6 +14431,7 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte
        case GGML_TYPE_I16:
        case GGML_TYPE_I32:
        case GGML_TYPE_I64:
        case GGML_TYPE_I2:
            // nothing to validate
            break;
        default:

ggml.c

@@ -21750,7 +21750,11 @@ size_t ggml_quantize_chunk(
            assert(false);
    }
    GGML_ASSERT(result == nrows * row_size);
    if (type == GGML_TYPE_I2) {
        result = nrows * row_size / 4 + 32;
    } else {
        GGML_ASSERT(result == nrows * row_size);
    }
    return result;
}
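
Because of this special case, callers can no longer assume the I2 result equals nrows * row_size: the packed data is a quarter of the nominal row size plus a fixed 32-byte scale block. A small sketch of that size computation follows, assuming ggml_row_size(GGML_TYPE_I2, n_per_row) still reports one byte per weight, as the "/ 4" above implies; i2_expected_size is a hypothetical helper, not part of this commit.

#include <stddef.h>
#include <stdint.h>

// Hypothetical helper (not part of this commit): destination size for a
// GGML_TYPE_I2 tensor, restating the special case added above.
static size_t i2_expected_size(size_t row_size, int64_t nrows) {
    // row_size is assumed to be ggml_row_size(GGML_TYPE_I2, n_per_row),
    // counting one byte per weight; the packed data uses 2 bits per weight.
    const size_t packed_bytes = (size_t) nrows * row_size / 4;
    const size_t scale_bytes  = 32;   // 8 replicated float scales at the tail
    return packed_bytes + scale_bytes;
}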