From 3a14e00366399040a139c67dd5951177a8cb5695 Mon Sep 17 00:00:00 2001 From: compilade Date: Thu, 8 Aug 2024 13:33:09 -0400 Subject: [PATCH] gguf-py : simplify support for quant types (#8838) * gguf-py : use classes for quants * convert_hf : simplify internal quantization type selection * gguf-py : fix flake8 lint * gguf-py : fix BF16 numpy view type * gguf-py : remove LlamaFileTypeMap Too specific to 'llama.cpp', and would be a maintenance burden to keep up to date. * gguf-py : add generic quantize and dequantize functions The quant classes no longer need to be known, only the target or the source type, for 'quantize' and 'dequantize', respectively. --- convert_hf_to_gguf.py | 107 ++++++++--------- gguf-py/gguf/constants.py | 11 +- gguf-py/gguf/lazy.py | 2 + gguf-py/gguf/quants.py | 238 ++++++++++++++++++++++++++------------ 4 files changed, 226 insertions(+), 132 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 38b92bc81..7136db440 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -251,12 +251,7 @@ class Model: return [(self.map_tensor_name(name), data_torch)] - def extra_f32_tensors(self, name: str, new_name: str, bid: int | None, n_dims: int) -> bool: - del name, new_name, bid, n_dims # unused - - return False - - def extra_f16_tensors(self, name: str, new_name: str, bid: int | None, n_dims: int) -> bool: + def tensor_force_quant(self, name: str, new_name: str, bid: int | None, n_dims: int) -> gguf.GGMLQuantizationType | bool: del name, new_name, bid, n_dims # unused return False @@ -285,55 +280,47 @@ class Model: for new_name, data in ((n, d.squeeze().numpy()) for n, d in self.modify_tensors(data_torch, name, bid)): data: np.ndarray # type hint n_dims = len(data.shape) - data_dtype = data.dtype - data_qtype: gguf.GGMLQuantizationType | None = None - - # when both are True, f32 should win - extra_f32 = self.extra_f32_tensors(name, new_name, bid, n_dims) - extra_f16 = self.extra_f16_tensors(name, new_name, bid, n_dims) + data_qtype: gguf.GGMLQuantizationType | bool = self.tensor_force_quant(name, new_name, bid, n_dims) # Most of the codebase that takes in 1D tensors or norms only handles F32 tensors - # Conditions should closely match those in llama_model_quantize_internal in llama.cpp - extra_f32 = any(cond for cond in ( - extra_f32, - n_dims == 1, - new_name.endswith("_norm.weight"), - )) - - # Some tensor types are always in float32 - extra_f32 = extra_f32 or any(self.match_model_tensor_name(new_name, key, bid) for key in ( - gguf.MODEL_TENSOR.FFN_GATE_INP, - gguf.MODEL_TENSOR.POS_EMBD, - gguf.MODEL_TENSOR.TOKEN_TYPES, - )) - - # if f16 desired, convert any float32 2-dim weight tensors to float16 - extra_f16 = any(cond for cond in ( - extra_f16, - (name.endswith(".weight") and n_dims >= 2), - )) - - if self.ftype != gguf.LlamaFileType.ALL_F32 and extra_f16 and not extra_f32: - if self.ftype == gguf.LlamaFileType.MOSTLY_BF16: - data = gguf.quantize_bf16(data) - assert data.dtype == np.uint16 - data_qtype = gguf.GGMLQuantizationType.BF16 - - elif self.ftype == gguf.LlamaFileType.MOSTLY_Q8_0 and gguf.can_quantize_to_q8_0(data): - data = gguf.quantize_q8_0(data) - assert data.dtype == np.uint8 - data_qtype = gguf.GGMLQuantizationType.Q8_0 - - else: # default to float16 for quantized tensors - if data_dtype != np.float16: - data = data.astype(np.float16) - data_qtype = gguf.GGMLQuantizationType.F16 - - if data_qtype is None: # by default, convert to float32 - if data_dtype != np.float32: - data = data.astype(np.float32) + if n_dims <= 1 or new_name.endswith("_norm.weight"): data_qtype = gguf.GGMLQuantizationType.F32 + # Conditions should closely match those in llama_model_quantize_internal in llama.cpp + # Some tensor types are always in float32 + if data_qtype is False and ( + any( + self.match_model_tensor_name(new_name, key, bid) + for key in ( + gguf.MODEL_TENSOR.FFN_GATE_INP, + gguf.MODEL_TENSOR.POS_EMBD, + gguf.MODEL_TENSOR.TOKEN_TYPES, + ) + ) + or not name.endswith(".weight") + ): + data_qtype = gguf.GGMLQuantizationType.F32 + + # No override (data_qtype is False), or wants to be quantized (data_qtype is True) + if isinstance(data_qtype, bool): + if self.ftype == gguf.LlamaFileType.ALL_F32: + data_qtype = gguf.GGMLQuantizationType.F32 + elif self.ftype == gguf.LlamaFileType.MOSTLY_F16: + data_qtype = gguf.GGMLQuantizationType.F16 + elif self.ftype == gguf.LlamaFileType.MOSTLY_BF16: + data_qtype = gguf.GGMLQuantizationType.BF16 + elif self.ftype == gguf.LlamaFileType.MOSTLY_Q8_0: + data_qtype = gguf.GGMLQuantizationType.Q8_0 + else: + raise ValueError(f"Unknown file type: {self.ftype.name}") + + try: + data = gguf.quants.quantize(data, data_qtype) + except gguf.QuantError as e: + logger.warning("%s, %s", e, "falling back to F16") + data_qtype = gguf.GGMLQuantizationType.F16 + data = gguf.quants.quantize(data, data_qtype) + shape = gguf.quant_shape_from_byte_shape(data.shape, data_qtype) if data.dtype == np.uint8 else data.shape # reverse shape to make it similar to the internal ggml dimension order @@ -1765,7 +1752,7 @@ class DbrxModel(Model): return [(new_name, data_torch)] - def extra_f16_tensors(self, name: str, new_name: str, bid: int | None, n_dims: int) -> bool: + def tensor_force_quant(self, name: str, new_name: str, bid: int | None, n_dims: int) -> gguf.GGMLQuantizationType | bool: del name, new_name, bid # unused return n_dims > 1 @@ -2786,18 +2773,22 @@ class MambaModel(Model): return [(new_name, data_torch)] - def extra_f32_tensors(self, name: str, new_name: str, bid: int | None, n_dims: int) -> bool: - del n_dims # unused - - return bid is not None and new_name in ( - self.format_tensor_name(n, bid, ".weight" if name.endswith(".weight") else "") for n in [ + def tensor_force_quant(self, name: str, new_name: str, bid: int | None, n_dims: int) -> gguf.GGMLQuantizationType | bool: + if bid is not None and new_name in ( + self.format_tensor_name( + n, bid, ".weight" if name.endswith(".weight") else "" + ) + for n in [ gguf.MODEL_TENSOR.SSM_CONV1D, gguf.MODEL_TENSOR.SSM_X, gguf.MODEL_TENSOR.SSM_DT, gguf.MODEL_TENSOR.SSM_A, gguf.MODEL_TENSOR.SSM_D, ] - ) + ): + return gguf.GGMLQuantizationType.F32 + + return super().tensor_force_quant(name, new_name, bid, n_dims) @Model.register("CohereForCausalLM") diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 59ffd92ea..89efe0c80 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -1146,6 +1146,9 @@ class GGMLQuantizationType(IntEnum): F64 = 28 IQ1_M = 29 BF16 = 30 + Q4_0_4_4 = 31 + Q4_0_4_8 = 32 + Q4_0_8_8 = 33 # TODO: add GGMLFileType from ggml_ftype in ggml.h @@ -1158,7 +1161,7 @@ class LlamaFileType(IntEnum): MOSTLY_F16 = 1 # except 1d tensors MOSTLY_Q4_0 = 2 # except 1d tensors MOSTLY_Q4_1 = 3 # except 1d tensors - MOSTLY_Q4_1_SOME_F16 = 4 # tok_embeddings.weight and output.weight are F16 + # MOSTLY_Q4_1_SOME_F16 = 4 # tok_embeddings.weight and output.weight are F16 # MOSTLY_Q4_2 = 5 # support has been removed # MOSTLY_Q4_3 = 6 # support has been removed MOSTLY_Q8_0 = 7 # except 1d tensors @@ -1187,6 +1190,9 @@ class LlamaFileType(IntEnum): MOSTLY_IQ4_XS = 30 # except 1d tensors MOSTLY_IQ1_M = 31 # except 1d tensors MOSTLY_BF16 = 32 # except 1d tensors + MOSTLY_Q4_0_4_4 = 33 # except 1d tensors + MOSTLY_Q4_0_4_8 = 34 # except 1d tensors + MOSTLY_Q4_0_8_8 = 35 # except 1d tensors GUESSED = 1024 # not specified in the model file @@ -1260,6 +1266,9 @@ GGML_QUANT_SIZES: dict[GGMLQuantizationType, tuple[int, int]] = { GGMLQuantizationType.F64: (1, 8), GGMLQuantizationType.IQ1_M: (256, QK_K // 8 + QK_K // 16 + QK_K // 32), GGMLQuantizationType.BF16: (1, 2), + GGMLQuantizationType.Q4_0_4_4:(32, 2 + 16), + GGMLQuantizationType.Q4_0_4_8:(32, 2 + 16), + GGMLQuantizationType.Q4_0_8_8:(32, 2 + 16), } diff --git a/gguf-py/gguf/lazy.py b/gguf-py/gguf/lazy.py index ac98d9a92..8d4fece2d 100644 --- a/gguf-py/gguf/lazy.py +++ b/gguf-py/gguf/lazy.py @@ -191,6 +191,8 @@ class LazyBase(ABC, metaclass=LazyMeta): class LazyNumpyTensor(LazyBase): _tensor_type = np.ndarray + shape: tuple[int, ...] # Makes the type checker happy in quants.py + @classmethod def meta_with_dtype_and_shape(cls, dtype: DTypeLike, shape: tuple[int, ...]) -> np.ndarray[Any, Any]: # The initial idea was to use np.nan as the fill value, diff --git a/gguf-py/gguf/quants.py b/gguf-py/gguf/quants.py index f4361d751..a443dd27e 100644 --- a/gguf-py/gguf/quants.py +++ b/gguf-py/gguf/quants.py @@ -1,5 +1,6 @@ from __future__ import annotations -from typing import Callable, Sequence +from abc import ABC, abstractmethod +from typing import Any, Callable, Sequence from numpy.typing import DTypeLike @@ -9,32 +10,22 @@ from .lazy import LazyNumpyTensor import numpy as np -def quant_shape_to_byte_shape(shape: Sequence[int], quant_type: GGMLQuantizationType): +def quant_shape_to_byte_shape(shape: Sequence[int], quant_type: GGMLQuantizationType) -> tuple[int, ...]: block_size, type_size = GGML_QUANT_SIZES[quant_type] if shape[-1] % block_size != 0: raise ValueError(f"Quantized tensor row size ({shape[-1]}) is not a multiple of {quant_type.name} block size ({block_size})") return (*shape[:-1], shape[-1] // block_size * type_size) -def quant_shape_from_byte_shape(shape: Sequence[int], quant_type: GGMLQuantizationType): +def quant_shape_from_byte_shape(shape: Sequence[int], quant_type: GGMLQuantizationType) -> tuple[int, ...]: block_size, type_size = GGML_QUANT_SIZES[quant_type] if shape[-1] % type_size != 0: raise ValueError(f"Quantized tensor bytes per row ({shape[-1]}) is not a multiple of {quant_type.name} type size ({type_size})") return (*shape[:-1], shape[-1] // type_size * block_size) -# same as ggml_compute_fp32_to_bf16 in ggml-impl.h -def __compute_fp32_to_bf16(n: np.ndarray) -> np.ndarray: - n = n.astype(np.float32, copy=False).view(np.uint32) - # force nan to quiet - n = np.where((n & 0x7fffffff) > 0x7f800000, (n & np.uint32(0xffff0000)) | np.uint32(64 << 16), n) - # round to nearest even - n = (np.uint64(n) + (0x7fff + ((n >> 16) & 1))) >> 16 - return n.astype(np.uint16) - - # This is faster than np.vectorize and np.apply_along_axis because it works on more than one row at a time -def __apply_over_grouped_rows(func: Callable[[np.ndarray], np.ndarray], arr: np.ndarray, otype: DTypeLike, oshape: tuple[int, ...]) -> np.ndarray: +def _apply_over_grouped_rows(func: Callable[[np.ndarray], np.ndarray], arr: np.ndarray, otype: DTypeLike, oshape: tuple[int, ...]) -> np.ndarray: rows = arr.reshape((-1, arr.shape[-1])) osize = 1 for dim in oshape: @@ -46,27 +37,6 @@ def __apply_over_grouped_rows(func: Callable[[np.ndarray], np.ndarray], arr: np. return out.reshape(oshape) -def __quantize_bf16_array(n: np.ndarray) -> np.ndarray: - return __apply_over_grouped_rows(__compute_fp32_to_bf16, arr=n, otype=np.uint16, oshape=n.shape) - - -__quantize_bf16_lazy = LazyNumpyTensor._wrap_fn(__quantize_bf16_array, meta_noop=np.uint16) - - -def quantize_bf16(n: np.ndarray): - if type(n) is LazyNumpyTensor: - return __quantize_bf16_lazy(n) - else: - return __quantize_bf16_array(n) - - -__q8_block_size, __q8_type_size = GGML_QUANT_SIZES[GGMLQuantizationType.Q8_0] - - -def can_quantize_to_q8_0(n: np.ndarray) -> bool: - return n.shape[-1] % __q8_block_size == 0 - - # round away from zero # ref: https://stackoverflow.com/a/59143326/22827863 def np_roundf(n: np.ndarray) -> np.ndarray: @@ -76,46 +46,168 @@ def np_roundf(n: np.ndarray) -> np.ndarray: return np.sign(n) * b -def __quantize_q8_0_shape_change(s: tuple[int, ...]) -> tuple[int, ...]: - return (*s[:-1], s[-1] // __q8_block_size * __q8_type_size) +class QuantError(Exception): ... -# Implementation of Q8_0 with bit-exact same results as reference implementation in ggml-quants.c -def __quantize_q8_0_rows(n: np.ndarray) -> np.ndarray: - shape = n.shape - assert shape[-1] % __q8_block_size == 0 - - n_blocks = n.size // __q8_block_size - - blocks = n.reshape((n_blocks, __q8_block_size)).astype(np.float32, copy=False) - - d = abs(blocks).max(axis=1, keepdims=True) / 127 - with np.errstate(divide="ignore"): - id = np.where(d == 0, 0, 1 / d) - qs = np_roundf(blocks * id) - - # (n_blocks, 2) - d = d.astype(np.float16).view(np.uint8) - # (n_blocks, block_size) - qs = qs.astype(np.int8).view(np.uint8) - - assert d.shape[1] + qs.shape[1] == __q8_type_size - - return np.concatenate([d, qs], axis=1).reshape(__quantize_q8_0_shape_change(shape)) +_type_traits: dict[GGMLQuantizationType, type[__Quant]] = {} -def __quantize_q8_0_array(n: np.ndarray) -> np.ndarray: - return __apply_over_grouped_rows(__quantize_q8_0_rows, arr=n, otype=np.uint8, oshape=__quantize_q8_0_shape_change(n.shape)) - - -__quantize_q8_0_lazy = LazyNumpyTensor._wrap_fn( - __quantize_q8_0_array, - meta_noop=(np.uint8, __quantize_q8_0_shape_change), -) - - -def quantize_q8_0(data: np.ndarray): - if type(data) is LazyNumpyTensor: - return __quantize_q8_0_lazy(data) +def quantize(data: np.ndarray, qtype: GGMLQuantizationType) -> np.ndarray: + if qtype == GGMLQuantizationType.F32: + return data.astype(np.float32, copy=False) + elif qtype == GGMLQuantizationType.F16: + return data.astype(np.float16, copy=False) + elif (q := _type_traits.get(qtype)) is not None: + return q.quantize(data) else: - return __quantize_q8_0_array(data) + raise NotImplementedError(f"Quantization for {qtype.name} is not yet implemented") + + +def dequantize(data: np.ndarray, qtype: GGMLQuantizationType) -> np.ndarray: + if qtype == GGMLQuantizationType.F32 or qtype == GGMLQuantizationType.F16: + return data.astype(np.float32, copy=False) + elif (q := _type_traits.get(qtype)) is not None: + return q.dequantize(data) + else: + raise NotImplementedError(f"Dequantization for {qtype.name} is not yet implemented") + + +class __Quant(ABC): + qtype: GGMLQuantizationType + block_size: int + type_size: int + + def __init__(self): + return TypeError("Quant conversion classes can't have instances") + + def __init_subclass__(cls, qtype: GGMLQuantizationType) -> None: + cls.qtype = qtype + cls.block_size, cls.type_size = GGML_QUANT_SIZES[qtype] + cls.__quantize_lazy = LazyNumpyTensor._wrap_fn( + cls.__quantize_array, + meta_noop=(np.uint8, cls.__shape_to_bytes) + ) + cls.__dequantize_lazy = LazyNumpyTensor._wrap_fn( + cls.__dequantize_array, + meta_noop=(np.float32, cls.__shape_from_bytes) + ) + assert qtype not in _type_traits + _type_traits[qtype] = cls + + @classmethod + @abstractmethod + def quantize_blocks(cls, blocks: np.ndarray) -> np.ndarray: + raise NotImplementedError + + @classmethod + @abstractmethod + def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray: + raise NotImplementedError + + @classmethod + def quantize_rows(cls, rows: np.ndarray) -> np.ndarray: + rows = rows.astype(np.float32, copy=False) + shape = rows.shape + n_blocks = rows.size // cls.block_size + blocks = rows.reshape((n_blocks, cls.block_size)) + blocks = cls.quantize_blocks(blocks) + assert blocks.dtype == np.uint8 + assert blocks.shape[-1] == cls.type_size + return blocks.reshape(cls.__shape_to_bytes(shape)) + + @classmethod + def dequantize_rows(cls, rows: np.ndarray) -> np.ndarray: + rows = rows.view(np.uint8) + shape = rows.shape + n_blocks = rows.size // cls.type_size + blocks = rows.reshape((n_blocks, cls.type_size)) + blocks = cls.dequantize_blocks(blocks) + assert blocks.dtype == np.float32 + assert blocks.shape[-1] == cls.block_size + return blocks.reshape(cls.__shape_from_bytes(shape)) + + @classmethod + def __shape_to_bytes(cls, shape: Sequence[int]): + return quant_shape_to_byte_shape(shape, cls.qtype) + + @classmethod + def __shape_from_bytes(cls, shape: Sequence[int]): + return quant_shape_from_byte_shape(shape, cls.qtype) + + @classmethod + def __quantize_array(cls, array: np.ndarray) -> np.ndarray: + return _apply_over_grouped_rows(cls.quantize_rows, arr=array, otype=np.uint8, oshape=cls.__shape_to_bytes(array.shape)) + + @classmethod + def __dequantize_array(cls, array: np.ndarray) -> np.ndarray: + return _apply_over_grouped_rows(cls.dequantize_rows, arr=array, otype=np.float32, oshape=cls.__shape_from_bytes(array.shape)) + + @classmethod + def __quantize_lazy(cls, lazy_tensor: LazyNumpyTensor, /) -> Any: + pass + + @classmethod + def __dequantize_lazy(cls, lazy_tensor: LazyNumpyTensor, /) -> Any: + pass + + @classmethod + def can_quantize(cls, tensor: np.ndarray | LazyNumpyTensor) -> bool: + return tensor.shape[-1] % cls.block_size == 0 + + @classmethod + def quantize(cls, tensor: np.ndarray | LazyNumpyTensor) -> np.ndarray: + if not cls.can_quantize(tensor): + raise QuantError(f"Can't quantize tensor with shape {tensor.shape} to {cls.qtype.name}") + if isinstance(tensor, LazyNumpyTensor): + return cls.__quantize_lazy(tensor) + else: + return cls.__quantize_array(tensor) + + @classmethod + def dequantize(cls, tensor: np.ndarray | LazyNumpyTensor) -> np.ndarray: + if isinstance(tensor, LazyNumpyTensor): + return cls.__dequantize_lazy(tensor) + else: + return cls.__dequantize_array(tensor) + + +class BF16(__Quant, qtype=GGMLQuantizationType.BF16): + @classmethod + # same as ggml_compute_fp32_to_bf16 in ggml-impl.h + def quantize_blocks(cls, blocks: np.ndarray) -> np.ndarray: + n = blocks.view(np.uint32) + # force nan to quiet + n = np.where((n & 0x7fffffff) > 0x7f800000, (n & np.uint32(0xffff0000)) | np.uint32(64 << 16), n) + # round to nearest even + n = (np.uint64(n) + (0x7fff + ((n >> 16) & 1))) >> 16 + return n.astype(np.uint16).view(np.uint8) + + @classmethod + def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray: + return (blocks.view(np.int16).astype(np.int32) << 16).view(np.float32) + + +class Q8_0(__Quant, qtype=GGMLQuantizationType.Q8_0): + @classmethod + # Implementation of Q8_0 with bit-exact same results as reference implementation in ggml-quants.c + def quantize_blocks(cls, blocks: np.ndarray) -> np.ndarray: + + d = abs(blocks).max(axis=1, keepdims=True) / 127 + with np.errstate(divide="ignore"): + id = np.where(d == 0, 0, 1 / d) + qs = np_roundf(blocks * id) + + # (n_blocks, 2) + d = d.astype(np.float16).view(np.uint8) + # (n_blocks, block_size) + qs = qs.astype(np.int8).view(np.uint8) + + return np.concatenate([d, qs], axis=1) + + @classmethod + def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray: + d, x = np.split(blocks, [2], axis=1) + d = d.view(np.float16).astype(np.float32) + x = x.view(np.int8).astype(np.float32) + + return (x * d)