Implement Q8_0 quantization fully in PyTorch.

This is equivalent to gguf.quantize_q8_0 but doesn't round-trip to
Numpy.
This commit is contained in:
Heiner 2024-05-23 20:24:47 +02:00
parent abc958b07e
commit 739648f3e6

View File

@ -123,12 +123,21 @@ def get_weights(fn):
assert len(arrays) in (1, 2)
def torch_roundf(t: torch.Tensor) -> torch.Tensor:
"""Round halfway cases away from zero like roundf(3). Cf. gguf/quants.py."""
a = abs(t)
floored = torch.floor(a)
b = floored + torch.floor(2 * (a - floored))
return torch.sign(t) * b
def quantize_q8_0(tensor: torch.Tensor) -> torch.CharTensor:
# equivalent to ggml_quantize_q8_0 in ggml.c (modulo rounding away from zero)
# Equivalent to gguf.quantize_q8_0 but PyTorch instead of Numpy.
assert tensor.shape[1] % QK8_0 == 0
tensor = tensor.reshape(-1, QK8_0)
scale = tensor.abs().max(dim=-1, keepdim=True).values / ((1 << 7) - 1)
tensor = (tensor / scale).round().clamp(min=-128, max=127).char()
iscale = torch.where(scale != 0.0, 1.0 / scale, 0.0)
tensor = torch_roundf(tensor * iscale).clamp(min=-128, max=127).char()
# add scale into each block
tensor = torch.cat((scale.half().view(torch.int8), tensor), dim=-1)
return tensor
@ -175,9 +184,7 @@ def maybe_quantize_tensor(tensor, ggml_type):
elif ggml_type == gguf.GGMLQuantizationType.F16:
return tensor.half()
elif ggml_type == gguf.GGMLQuantizationType.Q8_0:
if tensor.device.type == "meta":
return quantize_q8_0(tensor) # Cannot convert into numpy array.
return torch.from_numpy(gguf.quantize_q8_0(tensor.numpy()))
return quantize_q8_0(tensor)
elif ggml_type == gguf.GGMLQuantizationType.Q4_0:
return quantize_q4_0(tensor)
elif ggml_type == gguf.GGMLQuantizationType.Q4_1: