From 739648f3e68c90fb823e06e6922e73ff026a26f5 Mon Sep 17 00:00:00 2001 From: Heiner Date: Thu, 23 May 2024 20:24:47 +0200 Subject: [PATCH] Implement Q8_0 quantization fully in PyTorch. This is equivalent to gguf.quantize_q8_0 but does not round-trip through NumPy. --- convert_grok.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/convert_grok.py b/convert_grok.py index 43615d6ba..e5d9aa597 100644 --- a/convert_grok.py +++ b/convert_grok.py @@ -123,12 +123,21 @@ def get_weights(fn): assert len(arrays) in (1, 2) +def torch_roundf(t: torch.Tensor) -> torch.Tensor: + """Round halfway cases away from zero like roundf(3). Cf. gguf/quants.py.""" + a = abs(t) + floored = torch.floor(a) + b = floored + torch.floor(2 * (a - floored)) + return torch.sign(t) * b + + def quantize_q8_0(tensor: torch.Tensor) -> torch.CharTensor: - # equivalent to ggml_quantize_q8_0 in ggml.c (modulo rounding away from zero) + # Equivalent to gguf.quantize_q8_0 but PyTorch instead of Numpy. assert tensor.shape[1] % QK8_0 == 0 tensor = tensor.reshape(-1, QK8_0) scale = tensor.abs().max(dim=-1, keepdim=True).values / ((1 << 7) - 1) - tensor = (tensor / scale).round().clamp(min=-128, max=127).char() + iscale = torch.where(scale != 0.0, 1.0 / scale, 0.0) + tensor = torch_roundf(tensor * iscale).clamp(min=-128, max=127).char() # add scale into each block tensor = torch.cat((scale.half().view(torch.int8), tensor), dim=-1) return tensor @@ -175,9 +184,7 @@ def maybe_quantize_tensor(tensor, ggml_type): elif ggml_type == gguf.GGMLQuantizationType.F16: return tensor.half() elif ggml_type == gguf.GGMLQuantizationType.Q8_0: - if tensor.device.type == "meta": - return quantize_q8_0(tensor) # Cannot convert into numpy array. - return torch.from_numpy(gguf.quantize_q8_0(tensor.numpy())) + return quantize_q8_0(tensor) elif ggml_type == gguf.GGMLQuantizationType.Q4_0: return quantize_q4_0(tensor) elif ggml_type == gguf.GGMLQuantizationType.Q4_1: