Implement Q8_0 quantization fully in PyTorch.
This is equivalent to gguf.quantize_q8_0 but doesn't round-trip to Numpy.
commit 739648f3e6
parent abc958b07e
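Background, not part of the commit: the Numpy round-trip breaks for tensors on PyTorch's "meta" device, which carry shape and dtype but no backing storage, so a pure-PyTorch quantizer is needed there. A minimal illustration:

import torch

t = torch.empty(2, 32, device="meta")  # metadata only, no backing storage
try:
    t.numpy()  # conversion needs real data
except Exception as e:  # exact error type varies across PyTorch versions
    print("meta tensors cannot be converted to numpy:", type(e).__name__)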
@@ -123,12 +123,21 @@ def get_weights(fn):
     assert len(arrays) in (1, 2)
 
 
+def torch_roundf(t: torch.Tensor) -> torch.Tensor:
+    """Round halfway cases away from zero like roundf(3). Cf. gguf/quants.py."""
+    a = abs(t)
+    floored = torch.floor(a)
+    b = floored + torch.floor(2 * (a - floored))
+    return torch.sign(t) * b
+
+
 def quantize_q8_0(tensor: torch.Tensor) -> torch.CharTensor:
-    # equivalent to ggml_quantize_q8_0 in ggml.c (modulo rounding away from zero)
+    # Equivalent to gguf.quantize_q8_0 but PyTorch instead of Numpy.
     assert tensor.shape[1] % QK8_0 == 0
     tensor = tensor.reshape(-1, QK8_0)
     scale = tensor.abs().max(dim=-1, keepdim=True).values / ((1 << 7) - 1)
-    tensor = (tensor / scale).round().clamp(min=-128, max=127).char()
+    iscale = torch.where(scale != 0.0, 1.0 / scale, 0.0)
+    tensor = torch_roundf(tensor * iscale).clamp(min=-128, max=127).char()
     # add scale into each block
     tensor = torch.cat((scale.half().view(torch.int8), tensor), dim=-1)
     return tensor
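A quick sanity check, not part of the commit, of what torch_roundf changes: torch.round implements round-half-to-even, while ggml's roundf-based reference rounds halfway cases away from zero.

import torch

def torch_roundf(t: torch.Tensor) -> torch.Tensor:
    # as in the diff above: round halfway cases away from zero
    a = abs(t)
    floored = torch.floor(a)
    return torch.sign(t) * (floored + torch.floor(2 * (a - floored)))

halves = torch.tensor([-2.5, -1.5, -0.5, 0.5, 1.5, 2.5])
print(torch.round(halves))   # tensor([-2., -2., -0., 0., 2., 2.])  half-to-even
print(torch_roundf(halves))  # tensor([-3., -2., -1., 1., 2., 3.])  away from zero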
@@ -175,9 +184,7 @@ def maybe_quantize_tensor(tensor, ggml_type):
     elif ggml_type == gguf.GGMLQuantizationType.F16:
         return tensor.half()
     elif ggml_type == gguf.GGMLQuantizationType.Q8_0:
-        if tensor.device.type == "meta":
-            return quantize_q8_0(tensor)  # Cannot convert into numpy array.
-        return torch.from_numpy(gguf.quantize_q8_0(tensor.numpy()))
+        return quantize_q8_0(tensor)
     elif ggml_type == gguf.GGMLQuantizationType.Q4_0:
         return quantize_q4_0(tensor)
     elif ggml_type == gguf.GGMLQuantizationType.Q4_1:
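For reference, a round-trip sketch, not part of the commit, showing the block layout quantize_q8_0 produces: each row is a 2-byte fp16 scale followed by QK8_0 int8 quants. dequantize_q8_0 here is a hypothetical helper, and QK8_0 = 32 is assumed.

import torch

QK8_0 = 32  # Q8_0 block size in ggml (assumed here)

def torch_roundf(t: torch.Tensor) -> torch.Tensor:
    a = abs(t)
    floored = torch.floor(a)
    return torch.sign(t) * (floored + torch.floor(2 * (a - floored)))

def quantize_q8_0(tensor: torch.Tensor) -> torch.CharTensor:
    # as in the diff above
    assert tensor.shape[1] % QK8_0 == 0
    tensor = tensor.reshape(-1, QK8_0)
    scale = tensor.abs().max(dim=-1, keepdim=True).values / ((1 << 7) - 1)
    iscale = torch.where(scale != 0.0, 1.0 / scale, 0.0)
    tensor = torch_roundf(tensor * iscale).clamp(min=-128, max=127).char()
    return torch.cat((scale.half().view(torch.int8), tensor), dim=-1)

def dequantize_q8_0(blocks: torch.Tensor) -> torch.Tensor:
    # hypothetical inverse: reinterpret the first 2 bytes of each row as the fp16 scale
    scale = blocks[:, :2].contiguous().view(torch.half).float()
    return blocks[:, 2:].float() * scale

x = torch.randn(4, 2 * QK8_0)
y = dequantize_q8_0(quantize_q8_0(x)).reshape(x.shape)
print((x - y).abs().max().item())  # small error, on the order of scale / 2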