mirror of https://github.com/ggerganov/llama.cpp.git (synced 2025-01-12 19:50:17 +00:00)

commit dbee0a86c1 ("move i2 to quantize"), parent ca09085593
@@ -1418,142 +1418,17 @@ class BitnetModel(Model):
        dtype = weight.dtype
        weight = weight.float()
        s = 1 / weight.abs().mean().clamp(min=1e-5)
        # from gguf.lazy import LazyNumpyTensor
        # np_s = LazyNumpyTensor.to_eager(s.numpy())

        # print(np_s)
        result = (weight * s).round().clamp(-1, 1) / s
        return result.type(dtype)
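The lines above are the tail of the weight_quant helper: a BitNet-style absmean ternarization that scales by the inverse of the mean absolute weight, rounds, clamps to {-1, 0, +1}, and divides the scale back out (the real method then casts back to the original dtype). A minimal sketch of the same arithmetic on a toy tensor, illustrative only and not part of the commit:

    import torch

    weight = torch.tensor([[0.4, -0.1], [0.0, 0.9]])
    s = 1 / weight.abs().mean().clamp(min=1e-5)   # mean |w| = 0.35, so s is roughly 2.857
    ternary = (weight * s).round().clamp(-1, 1)   # [[1, 0], [0, 1]]  (the -0.1 entry rounds to 0)
    result = ternary / s                          # [[0.35, 0], [0, 0.35]]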
    def transform_to_i2(self, x):
        from gguf.lazy import LazyNumpyTensor
        x = LazyNumpyTensor.to_eager(x)
        x_num = np.prod(x.shape)
        x = np.reshape(x, x_num)
        scale = 1
        for i in range(x_num):
            if x[i] != 0:
                scale = x[i]
                break
        x = np.where(x * scale > 0, 1, np.where(x * scale < 0, -1, x))
        x = x.astype(np.uint8)
        x = np.reshape(x, [x.shape[0] // 4, 4])
        keep_bit = {0:192, 1:48, 2:12, 3:3}
        ans = np.zeros([x_num // 4], dtype=np.uint8)
        for i in range(4):
            x_bit_col = x[:, i]
            x_bit_shift = np.left_shift(x_bit_col, 6 - i * 2)
            x_bit_shift = np.bitwise_and(x_bit_shift, keep_bit[i])
            ans = np.bitwise_or(ans, x_bit_shift)
        scale = np.tile(scale, 8)
        return ans, scale
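transform_to_i2 packs four ternary weights per byte, most-significant field first, using the column masks keep_bit = {192, 48, 12, 3} (0xC0, 0x30, 0x0C, 0x03). Because -1 passes through a uint8 cast, it is effectively stored as 255, whose low two bits are 0b11, so the 2-bit codes are 0b00 for 0, 0b01 for +1, and 0b11 for -1. A sketch of the inverse unpacking under those assumptions, illustrative only and not part of the commit:

    import numpy as np

    def unpack_i2(packed: np.ndarray) -> np.ndarray:
        # Invert the packing above: byte j holds weights 4*j .. 4*j+3, MSB-first.
        cols = []
        for i in range(4):
            field = ((packed >> (6 - 2 * i)) & 0x03).astype(np.int8)
            cols.append(np.where(field == 3, -1, field))   # 0b11 decodes to -1
        return np.stack(cols, axis=1).reshape(-1)

Round-tripping ans through unpack_i2 should reproduce the ternary pattern that went in, up to the global sign flip introduced by multiplying with the first non-zero element before taking the sign.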
    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
        # transform weight into 1/0/-1 (in fp32)
        if name.endswith(("q_proj.weight", "k_proj.weight", "v_proj.weight",
                          "down_proj.weight", "up_proj.weight", "gate_proj.weight",
                          "o_proj.weight")):
            data_torch = data_torch + (self.weight_quant(data_torch) - data_torch).detach()

    # def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
    #     # quant weight to i2 (in fp16)
    #     if name.endswith(("q_proj.weight", "k_proj.weight", "v_proj.weight",
    #                       "down_proj.weight", "up_proj.weight", "gate_proj.weight",
    #                       "o_proj.weight")):
    #         print(name)
    #         data_torch = data_torch + (self.weight_quant(data_torch) - data_torch).detach()

    #     return [(self.map_tensor_name(name), data_torch)]

    def write_tensors(self):
        max_name_len = max(len(s) for _, s in self.tensor_map.mapping.values()) + len(".weight,")

        for name, data_torch in self.get_tensors():
            # we don't need these
            if name.endswith((".attention.masked_bias", ".attention.bias", ".rotary_emb.inv_freq")):
                continue

            old_dtype = data_torch.dtype

            # convert any unsupported data types to float32
            if data_torch.dtype not in (torch.float16, torch.float32):
                data_torch = data_torch.to(torch.float32)

            # use the first number-like part of the tensor name as the block id
            bid = None
            for part in name.split("."):
                if part.isdecimal():
                    bid = int(part)
                    break

            for new_name, data in ((n, d.squeeze().numpy()) for n, d in self.modify_tensors(data_torch, name, bid)):
                data: np.ndarray = data  # type hint
                data_shape = data.shape
                n_dims = len(data.shape)
                data_dtype = data.dtype
                data_qtype: gguf.GGMLQuantizationType | None = None

                # when both are True, f32 should win
                extra_f32 = self.extra_f32_tensors(name, new_name, bid, n_dims)
                extra_f16 = self.extra_f16_tensors(name, new_name, bid, n_dims)

                # Most of the codebase that takes in 1D tensors or norms only handles F32 tensors
                # Conditions should closely match those in llama_model_quantize_internal in llama.cpp
                extra_f32 = any(cond for cond in (
                    extra_f32,
                    n_dims == 1,
                    new_name.endswith("_norm.weight"),
                ))

                # Some tensor types are always in float32
                extra_f32 = extra_f32 or any(self.match_model_tensor_name(new_name, key, bid) for key in (
                    gguf.MODEL_TENSOR.FFN_GATE_INP,
                    gguf.MODEL_TENSOR.POS_EMBD,
                    gguf.MODEL_TENSOR.TOKEN_TYPES,
                ))

                # if f16 desired, convert any float32 2-dim weight tensors to float16
                extra_f16 = any(cond for cond in (
                    extra_f16,
                    (name.endswith(".weight") and n_dims >= 2),
                ))

                suit_i2 = True
                if name.endswith('embed_tokens.weight') or name.endswith('norm.weight'):
                    suit_i2 = False

                i2_scale = None
                if self.ftype != gguf.LlamaFileType.ALL_F32 and extra_f16 and not extra_f32:
                    if self.ftype == gguf.LlamaFileType.MOSTLY_I2 and suit_i2:
                        data, i2_scale = self.transform_to_i2(data)
                        assert data.dtype == np.uint8
                        assert i2_scale.dtype == np.float32
                        data_qtype = gguf.GGMLQuantizationType.I2

                    elif self.ftype == gguf.LlamaFileType.MOSTLY_BF16:
                        data = gguf.quantize_bf16(data)
                        assert data.dtype == np.int16
                        data_qtype = gguf.GGMLQuantizationType.BF16

                    elif self.ftype == gguf.LlamaFileType.MOSTLY_Q8_0 and gguf.can_quantize_to_q8_0(data):
                        data = gguf.quantize_q8_0(data)
                        assert data.dtype == np.uint8
                        data_qtype = gguf.GGMLQuantizationType.Q8_0

                    else:  # default to float16 for quantized tensors
                        if data_dtype != np.float16:
                            data = data.astype(np.float16)
                        data_qtype = gguf.GGMLQuantizationType.F16

                if data_qtype is None:  # by default, convert to float32
                    if data_dtype != np.float32:
                        data = data.astype(np.float32)
                    data_qtype = gguf.GGMLQuantizationType.F32

                shape = data_shape
                # shape = gguf.quant_shape_from_byte_shape(data.shape, data_qtype) if data.dtype == np.uint8 else data.shape
                # reverse shape to make it similar to the internal ggml dimension order
                shape_str = f"{{{', '.join(str(n) for n in reversed(shape))}}}"

                # n_dims is implicit in the shape
                logger.info(f"{f'%-{max_name_len}s' % f'{new_name},'} {old_dtype} --> {data_qtype.name}, shape = {shape_str}")

                self.gguf_writer.add_tensor(new_name, data, raw_shape=shape, raw_dtype=data_qtype)
                if i2_scale is not None:
                    self.gguf_writer.add_tensor(new_name + "_scale", i2_scale, raw_dtype=gguf.GGMLQuantizationType.F32)

        return [(self.map_tensor_name(name), data_torch)]
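Stripped of the bookkeeping, the dtype selection above gives F32 absolute priority, then lets the requested file type decide, with embed_tokens.weight and norm weights excluded from I2 and falling back to F16; I2 tensors additionally get a companion "_scale" F32 tensor. A condensed, hypothetical restatement of that precedence (names illustrative, extra_f16/extra_f32 overrides and the Q8_0 eligibility check omitted, not part of the diff):

    def pick_qtype(ftype: str, name: str, new_name: str, n_dims: int) -> str:
        # F32 always wins for 1-D tensors and norms (always-f32 kinds like FFN_GATE_INP omitted here)
        if ftype == "ALL_F32" or n_dims == 1 or new_name.endswith("_norm.weight"):
            return "F32"
        if not (name.endswith(".weight") and n_dims >= 2):
            return "F32"
        if ftype == "MOSTLY_I2":
            suit_i2 = not name.endswith(("embed_tokens.weight", "norm.weight"))
            return "I2" if suit_i2 else "F16"
        return {"MOSTLY_BF16": "BF16", "MOSTLY_Q8_0": "Q8_0"}.get(ftype, "F16")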

@Model.register("GrokForCausalLM")
class GrokModel(Model):
@@ -3308,29 +3308,47 @@ size_t quantize_q8_0(const float * restrict src, void * restrict dst, int64_t nr

size_t quantize_i2_s(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
    // 2 bits per weight
    size_t row_size = ggml_row_size(GGML_TYPE_I2, n_per_row) / 4;
    char * qrow = (char *)dst;
    printf("n_row:%d\n", nrow);
    printf("n_per_row:%d\n", n_per_row);
    size_t row_size = ggml_row_size(GGML_TYPE_I2, n_per_row);

    int n = nrow * n_per_row;
    float accu = 0.0;
    float min = 0.00001;
    for (int i = 0; i < n; ++i) {
        accu += fabs(src[i]);

    // f32 -> q8
    double i2_scale = 0;
    for (int i=0; i<n; i++) {
        if (fabs(src[i]) > 1e-6) {
            i2_scale = src[i];
        }
    }
    accu = accu > min ? accu : min;
    float scale = n / accu;

    printf("\nscale:%f\n", scale);
    uint8_t* q8 = (uint8_t*)dst;
    for (int i=0; i<n; i++) {
        if (fabs(src[i]) < 1e-6) {
            q8[i] = 0;
            continue;
        }
        q8[i] = src[i] * i2_scale > 0 ? 1 : 3;
    }

    // for (int64_t row = 0; row < nrow; ++row) {
    //     quantize_row_q4_0_impl(src, (block_q4_0*)qrow, n_per_row, quant_weights);
    //     src += n_per_row;
    //     qrow += row_size;
    // }
    // q8 -> 0, 1,  3
    //       |  |   |
    //       0, 1, -1

    uint8_t* i2_weight = (uint8_t*)dst;
    for (int i=0; i<n; i++) {
        int group_idx = i / 4;
        int group_pos = i % 4;
        uint8_t temp = (q8[i] << (6 - 2 * group_pos));
        q8[i] = 0;
        i2_weight[group_idx] |= temp;
    }

    float* scale_ptr = (float*)((char*)i2_weight + n / 4);
    for (int i=0; i<8; i++) {
        scale_ptr[i] = i2_scale;
    }

    // 32B for scale
    return nrow * row_size + 32;
    return nrow * row_size / 4 + 32;
}
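For a tensor of n weights, quantize_i2_s therefore emits n/4 bytes of packed 2-bit codes (four weights per byte, MSB-first; 0b00 for 0, 0b01 for +1, 0b11 for -1) followed by a 32-byte block holding i2_scale replicated as eight float32 values. A hypothetical inspection helper for that layout, mirroring the packing above and not part of ggml:

    import numpy as np

    def read_i2_buffer(buf: bytes, n: int):
        # n/4 bytes of 2-bit codes, then 8 float32 copies of the scale (32 bytes).
        packed = np.frombuffer(buf, dtype=np.uint8, count=n // 4)
        scale = float(np.frombuffer(buf, dtype=np.float32, count=8, offset=n // 4)[0])
        codes = np.zeros(n, dtype=np.int8)
        for pos in range(4):                                   # field `pos` holds weight 4*i + pos
            field = ((packed >> (6 - 2 * pos)) & 0x03).astype(np.int8)
            codes[pos::4] = np.where(field == 3, -1, field)    # 0b11 is the 2-bit code for -1
        return codes, scale

Note that, per the code above, i2_scale is simply a raw weight value picked from the tensor (the last one with |w| > 1e-6), so the decoded +1/-1 codes carry the sign of src[i] * i2_scale rather than of src[i] itself.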

// ====================== "True" 2-bit (de)-quantization
@@ -14413,6 +14431,7 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte
        case GGML_TYPE_I16:
        case GGML_TYPE_I32:
        case GGML_TYPE_I64:
        case GGML_TYPE_I2:
            // nothing to validate
            break;
        default: