mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-11-11 13:30:35 +00:00
gguf-py : simplify support for quant types (#8838)
* gguf-py : use classes for quants * convert_hf : simplify internal quantization type selection * gguf-py : fix flake8 lint * gguf-py : fix BF16 numpy view type * gguf-py : remove LlamaFileTypeMap Too specific to 'llama.cpp', and would be a maintenance burden to keep up to date. * gguf-py : add generic quantize and dequantize functions The quant classes no longer need to be known, only the target or the source type, for 'quantize' and 'dequantize', respectively.
This commit is contained in:
parent
afd27f01fe
commit
3a14e00366
@ -251,12 +251,7 @@ class Model:
|
||||
|
||||
return [(self.map_tensor_name(name), data_torch)]
|
||||
|
||||
def extra_f32_tensors(self, name: str, new_name: str, bid: int | None, n_dims: int) -> bool:
|
||||
del name, new_name, bid, n_dims # unused
|
||||
|
||||
return False
|
||||
|
||||
def extra_f16_tensors(self, name: str, new_name: str, bid: int | None, n_dims: int) -> bool:
|
||||
def tensor_force_quant(self, name: str, new_name: str, bid: int | None, n_dims: int) -> gguf.GGMLQuantizationType | bool:
|
||||
del name, new_name, bid, n_dims # unused
|
||||
|
||||
return False
|
||||
@ -285,55 +280,47 @@ class Model:
|
||||
for new_name, data in ((n, d.squeeze().numpy()) for n, d in self.modify_tensors(data_torch, name, bid)):
|
||||
data: np.ndarray # type hint
|
||||
n_dims = len(data.shape)
|
||||
data_dtype = data.dtype
|
||||
data_qtype: gguf.GGMLQuantizationType | None = None
|
||||
|
||||
# when both are True, f32 should win
|
||||
extra_f32 = self.extra_f32_tensors(name, new_name, bid, n_dims)
|
||||
extra_f16 = self.extra_f16_tensors(name, new_name, bid, n_dims)
|
||||
data_qtype: gguf.GGMLQuantizationType | bool = self.tensor_force_quant(name, new_name, bid, n_dims)
|
||||
|
||||
# Most of the codebase that takes in 1D tensors or norms only handles F32 tensors
|
||||
# Conditions should closely match those in llama_model_quantize_internal in llama.cpp
|
||||
extra_f32 = any(cond for cond in (
|
||||
extra_f32,
|
||||
n_dims == 1,
|
||||
new_name.endswith("_norm.weight"),
|
||||
))
|
||||
|
||||
# Some tensor types are always in float32
|
||||
extra_f32 = extra_f32 or any(self.match_model_tensor_name(new_name, key, bid) for key in (
|
||||
gguf.MODEL_TENSOR.FFN_GATE_INP,
|
||||
gguf.MODEL_TENSOR.POS_EMBD,
|
||||
gguf.MODEL_TENSOR.TOKEN_TYPES,
|
||||
))
|
||||
|
||||
# if f16 desired, convert any float32 2-dim weight tensors to float16
|
||||
extra_f16 = any(cond for cond in (
|
||||
extra_f16,
|
||||
(name.endswith(".weight") and n_dims >= 2),
|
||||
))
|
||||
|
||||
if self.ftype != gguf.LlamaFileType.ALL_F32 and extra_f16 and not extra_f32:
|
||||
if self.ftype == gguf.LlamaFileType.MOSTLY_BF16:
|
||||
data = gguf.quantize_bf16(data)
|
||||
assert data.dtype == np.uint16
|
||||
data_qtype = gguf.GGMLQuantizationType.BF16
|
||||
|
||||
elif self.ftype == gguf.LlamaFileType.MOSTLY_Q8_0 and gguf.can_quantize_to_q8_0(data):
|
||||
data = gguf.quantize_q8_0(data)
|
||||
assert data.dtype == np.uint8
|
||||
data_qtype = gguf.GGMLQuantizationType.Q8_0
|
||||
|
||||
else: # default to float16 for quantized tensors
|
||||
if data_dtype != np.float16:
|
||||
data = data.astype(np.float16)
|
||||
data_qtype = gguf.GGMLQuantizationType.F16
|
||||
|
||||
if data_qtype is None: # by default, convert to float32
|
||||
if data_dtype != np.float32:
|
||||
data = data.astype(np.float32)
|
||||
if n_dims <= 1 or new_name.endswith("_norm.weight"):
|
||||
data_qtype = gguf.GGMLQuantizationType.F32
|
||||
|
||||
# Conditions should closely match those in llama_model_quantize_internal in llama.cpp
|
||||
# Some tensor types are always in float32
|
||||
if data_qtype is False and (
|
||||
any(
|
||||
self.match_model_tensor_name(new_name, key, bid)
|
||||
for key in (
|
||||
gguf.MODEL_TENSOR.FFN_GATE_INP,
|
||||
gguf.MODEL_TENSOR.POS_EMBD,
|
||||
gguf.MODEL_TENSOR.TOKEN_TYPES,
|
||||
)
|
||||
)
|
||||
or not name.endswith(".weight")
|
||||
):
|
||||
data_qtype = gguf.GGMLQuantizationType.F32
|
||||
|
||||
# No override (data_qtype is False), or wants to be quantized (data_qtype is True)
|
||||
if isinstance(data_qtype, bool):
|
||||
if self.ftype == gguf.LlamaFileType.ALL_F32:
|
||||
data_qtype = gguf.GGMLQuantizationType.F32
|
||||
elif self.ftype == gguf.LlamaFileType.MOSTLY_F16:
|
||||
data_qtype = gguf.GGMLQuantizationType.F16
|
||||
elif self.ftype == gguf.LlamaFileType.MOSTLY_BF16:
|
||||
data_qtype = gguf.GGMLQuantizationType.BF16
|
||||
elif self.ftype == gguf.LlamaFileType.MOSTLY_Q8_0:
|
||||
data_qtype = gguf.GGMLQuantizationType.Q8_0
|
||||
else:
|
||||
raise ValueError(f"Unknown file type: {self.ftype.name}")
|
||||
|
||||
try:
|
||||
data = gguf.quants.quantize(data, data_qtype)
|
||||
except gguf.QuantError as e:
|
||||
logger.warning("%s, %s", e, "falling back to F16")
|
||||
data_qtype = gguf.GGMLQuantizationType.F16
|
||||
data = gguf.quants.quantize(data, data_qtype)
|
||||
|
||||
shape = gguf.quant_shape_from_byte_shape(data.shape, data_qtype) if data.dtype == np.uint8 else data.shape
|
||||
|
||||
# reverse shape to make it similar to the internal ggml dimension order
|
||||
@ -1765,7 +1752,7 @@ class DbrxModel(Model):
|
||||
|
||||
return [(new_name, data_torch)]
|
||||
|
||||
def extra_f16_tensors(self, name: str, new_name: str, bid: int | None, n_dims: int) -> bool:
|
||||
def tensor_force_quant(self, name: str, new_name: str, bid: int | None, n_dims: int) -> gguf.GGMLQuantizationType | bool:
|
||||
del name, new_name, bid # unused
|
||||
|
||||
return n_dims > 1
|
||||
@ -2786,18 +2773,22 @@ class MambaModel(Model):
|
||||
|
||||
return [(new_name, data_torch)]
|
||||
|
||||
def extra_f32_tensors(self, name: str, new_name: str, bid: int | None, n_dims: int) -> bool:
|
||||
del n_dims # unused
|
||||
|
||||
return bid is not None and new_name in (
|
||||
self.format_tensor_name(n, bid, ".weight" if name.endswith(".weight") else "") for n in [
|
||||
def tensor_force_quant(self, name: str, new_name: str, bid: int | None, n_dims: int) -> gguf.GGMLQuantizationType | bool:
|
||||
if bid is not None and new_name in (
|
||||
self.format_tensor_name(
|
||||
n, bid, ".weight" if name.endswith(".weight") else ""
|
||||
)
|
||||
for n in [
|
||||
gguf.MODEL_TENSOR.SSM_CONV1D,
|
||||
gguf.MODEL_TENSOR.SSM_X,
|
||||
gguf.MODEL_TENSOR.SSM_DT,
|
||||
gguf.MODEL_TENSOR.SSM_A,
|
||||
gguf.MODEL_TENSOR.SSM_D,
|
||||
]
|
||||
)
|
||||
):
|
||||
return gguf.GGMLQuantizationType.F32
|
||||
|
||||
return super().tensor_force_quant(name, new_name, bid, n_dims)
|
||||
|
||||
|
||||
@Model.register("CohereForCausalLM")
|
||||
|
@ -1146,6 +1146,9 @@ class GGMLQuantizationType(IntEnum):
|
||||
F64 = 28
|
||||
IQ1_M = 29
|
||||
BF16 = 30
|
||||
Q4_0_4_4 = 31
|
||||
Q4_0_4_8 = 32
|
||||
Q4_0_8_8 = 33
|
||||
|
||||
|
||||
# TODO: add GGMLFileType from ggml_ftype in ggml.h
|
||||
@ -1158,7 +1161,7 @@ class LlamaFileType(IntEnum):
|
||||
MOSTLY_F16 = 1 # except 1d tensors
|
||||
MOSTLY_Q4_0 = 2 # except 1d tensors
|
||||
MOSTLY_Q4_1 = 3 # except 1d tensors
|
||||
MOSTLY_Q4_1_SOME_F16 = 4 # tok_embeddings.weight and output.weight are F16
|
||||
# MOSTLY_Q4_1_SOME_F16 = 4 # tok_embeddings.weight and output.weight are F16
|
||||
# MOSTLY_Q4_2 = 5 # support has been removed
|
||||
# MOSTLY_Q4_3 = 6 # support has been removed
|
||||
MOSTLY_Q8_0 = 7 # except 1d tensors
|
||||
@ -1187,6 +1190,9 @@ class LlamaFileType(IntEnum):
|
||||
MOSTLY_IQ4_XS = 30 # except 1d tensors
|
||||
MOSTLY_IQ1_M = 31 # except 1d tensors
|
||||
MOSTLY_BF16 = 32 # except 1d tensors
|
||||
MOSTLY_Q4_0_4_4 = 33 # except 1d tensors
|
||||
MOSTLY_Q4_0_4_8 = 34 # except 1d tensors
|
||||
MOSTLY_Q4_0_8_8 = 35 # except 1d tensors
|
||||
|
||||
GUESSED = 1024 # not specified in the model file
|
||||
|
||||
@ -1260,6 +1266,9 @@ GGML_QUANT_SIZES: dict[GGMLQuantizationType, tuple[int, int]] = {
|
||||
GGMLQuantizationType.F64: (1, 8),
|
||||
GGMLQuantizationType.IQ1_M: (256, QK_K // 8 + QK_K // 16 + QK_K // 32),
|
||||
GGMLQuantizationType.BF16: (1, 2),
|
||||
GGMLQuantizationType.Q4_0_4_4:(32, 2 + 16),
|
||||
GGMLQuantizationType.Q4_0_4_8:(32, 2 + 16),
|
||||
GGMLQuantizationType.Q4_0_8_8:(32, 2 + 16),
|
||||
}
|
||||
|
||||
|
||||
|
@ -191,6 +191,8 @@ class LazyBase(ABC, metaclass=LazyMeta):
|
||||
class LazyNumpyTensor(LazyBase):
|
||||
_tensor_type = np.ndarray
|
||||
|
||||
shape: tuple[int, ...] # Makes the type checker happy in quants.py
|
||||
|
||||
@classmethod
|
||||
def meta_with_dtype_and_shape(cls, dtype: DTypeLike, shape: tuple[int, ...]) -> np.ndarray[Any, Any]:
|
||||
# The initial idea was to use np.nan as the fill value,
|
||||
|
@ -1,5 +1,6 @@
|
||||
from __future__ import annotations
|
||||
from typing import Callable, Sequence
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Callable, Sequence
|
||||
|
||||
from numpy.typing import DTypeLike
|
||||
|
||||
@ -9,32 +10,22 @@ from .lazy import LazyNumpyTensor
|
||||
import numpy as np
|
||||
|
||||
|
||||
def quant_shape_to_byte_shape(shape: Sequence[int], quant_type: GGMLQuantizationType):
|
||||
def quant_shape_to_byte_shape(shape: Sequence[int], quant_type: GGMLQuantizationType) -> tuple[int, ...]:
|
||||
block_size, type_size = GGML_QUANT_SIZES[quant_type]
|
||||
if shape[-1] % block_size != 0:
|
||||
raise ValueError(f"Quantized tensor row size ({shape[-1]}) is not a multiple of {quant_type.name} block size ({block_size})")
|
||||
return (*shape[:-1], shape[-1] // block_size * type_size)
|
||||
|
||||
|
||||
def quant_shape_from_byte_shape(shape: Sequence[int], quant_type: GGMLQuantizationType):
|
||||
def quant_shape_from_byte_shape(shape: Sequence[int], quant_type: GGMLQuantizationType) -> tuple[int, ...]:
|
||||
block_size, type_size = GGML_QUANT_SIZES[quant_type]
|
||||
if shape[-1] % type_size != 0:
|
||||
raise ValueError(f"Quantized tensor bytes per row ({shape[-1]}) is not a multiple of {quant_type.name} type size ({type_size})")
|
||||
return (*shape[:-1], shape[-1] // type_size * block_size)
|
||||
|
||||
|
||||
# same as ggml_compute_fp32_to_bf16 in ggml-impl.h
|
||||
def __compute_fp32_to_bf16(n: np.ndarray) -> np.ndarray:
|
||||
n = n.astype(np.float32, copy=False).view(np.uint32)
|
||||
# force nan to quiet
|
||||
n = np.where((n & 0x7fffffff) > 0x7f800000, (n & np.uint32(0xffff0000)) | np.uint32(64 << 16), n)
|
||||
# round to nearest even
|
||||
n = (np.uint64(n) + (0x7fff + ((n >> 16) & 1))) >> 16
|
||||
return n.astype(np.uint16)
|
||||
|
||||
|
||||
# This is faster than np.vectorize and np.apply_along_axis because it works on more than one row at a time
|
||||
def __apply_over_grouped_rows(func: Callable[[np.ndarray], np.ndarray], arr: np.ndarray, otype: DTypeLike, oshape: tuple[int, ...]) -> np.ndarray:
|
||||
def _apply_over_grouped_rows(func: Callable[[np.ndarray], np.ndarray], arr: np.ndarray, otype: DTypeLike, oshape: tuple[int, ...]) -> np.ndarray:
|
||||
rows = arr.reshape((-1, arr.shape[-1]))
|
||||
osize = 1
|
||||
for dim in oshape:
|
||||
@ -46,27 +37,6 @@ def __apply_over_grouped_rows(func: Callable[[np.ndarray], np.ndarray], arr: np.
|
||||
return out.reshape(oshape)
|
||||
|
||||
|
||||
def __quantize_bf16_array(n: np.ndarray) -> np.ndarray:
|
||||
return __apply_over_grouped_rows(__compute_fp32_to_bf16, arr=n, otype=np.uint16, oshape=n.shape)
|
||||
|
||||
|
||||
__quantize_bf16_lazy = LazyNumpyTensor._wrap_fn(__quantize_bf16_array, meta_noop=np.uint16)
|
||||
|
||||
|
||||
def quantize_bf16(n: np.ndarray):
|
||||
if type(n) is LazyNumpyTensor:
|
||||
return __quantize_bf16_lazy(n)
|
||||
else:
|
||||
return __quantize_bf16_array(n)
|
||||
|
||||
|
||||
__q8_block_size, __q8_type_size = GGML_QUANT_SIZES[GGMLQuantizationType.Q8_0]
|
||||
|
||||
|
||||
def can_quantize_to_q8_0(n: np.ndarray) -> bool:
|
||||
return n.shape[-1] % __q8_block_size == 0
|
||||
|
||||
|
||||
# round away from zero
|
||||
# ref: https://stackoverflow.com/a/59143326/22827863
|
||||
def np_roundf(n: np.ndarray) -> np.ndarray:
|
||||
@ -76,46 +46,168 @@ def np_roundf(n: np.ndarray) -> np.ndarray:
|
||||
return np.sign(n) * b
|
||||
|
||||
|
||||
def __quantize_q8_0_shape_change(s: tuple[int, ...]) -> tuple[int, ...]:
|
||||
return (*s[:-1], s[-1] // __q8_block_size * __q8_type_size)
|
||||
class QuantError(Exception): ...
|
||||
|
||||
|
||||
# Implementation of Q8_0 with bit-exact same results as reference implementation in ggml-quants.c
|
||||
def __quantize_q8_0_rows(n: np.ndarray) -> np.ndarray:
|
||||
shape = n.shape
|
||||
assert shape[-1] % __q8_block_size == 0
|
||||
|
||||
n_blocks = n.size // __q8_block_size
|
||||
|
||||
blocks = n.reshape((n_blocks, __q8_block_size)).astype(np.float32, copy=False)
|
||||
|
||||
d = abs(blocks).max(axis=1, keepdims=True) / 127
|
||||
with np.errstate(divide="ignore"):
|
||||
id = np.where(d == 0, 0, 1 / d)
|
||||
qs = np_roundf(blocks * id)
|
||||
|
||||
# (n_blocks, 2)
|
||||
d = d.astype(np.float16).view(np.uint8)
|
||||
# (n_blocks, block_size)
|
||||
qs = qs.astype(np.int8).view(np.uint8)
|
||||
|
||||
assert d.shape[1] + qs.shape[1] == __q8_type_size
|
||||
|
||||
return np.concatenate([d, qs], axis=1).reshape(__quantize_q8_0_shape_change(shape))
|
||||
_type_traits: dict[GGMLQuantizationType, type[__Quant]] = {}
|
||||
|
||||
|
||||
def __quantize_q8_0_array(n: np.ndarray) -> np.ndarray:
|
||||
return __apply_over_grouped_rows(__quantize_q8_0_rows, arr=n, otype=np.uint8, oshape=__quantize_q8_0_shape_change(n.shape))
|
||||
|
||||
|
||||
__quantize_q8_0_lazy = LazyNumpyTensor._wrap_fn(
|
||||
__quantize_q8_0_array,
|
||||
meta_noop=(np.uint8, __quantize_q8_0_shape_change),
|
||||
)
|
||||
|
||||
|
||||
def quantize_q8_0(data: np.ndarray):
|
||||
if type(data) is LazyNumpyTensor:
|
||||
return __quantize_q8_0_lazy(data)
|
||||
def quantize(data: np.ndarray, qtype: GGMLQuantizationType) -> np.ndarray:
|
||||
if qtype == GGMLQuantizationType.F32:
|
||||
return data.astype(np.float32, copy=False)
|
||||
elif qtype == GGMLQuantizationType.F16:
|
||||
return data.astype(np.float16, copy=False)
|
||||
elif (q := _type_traits.get(qtype)) is not None:
|
||||
return q.quantize(data)
|
||||
else:
|
||||
return __quantize_q8_0_array(data)
|
||||
raise NotImplementedError(f"Quantization for {qtype.name} is not yet implemented")
|
||||
|
||||
|
||||
def dequantize(data: np.ndarray, qtype: GGMLQuantizationType) -> np.ndarray:
|
||||
if qtype == GGMLQuantizationType.F32 or qtype == GGMLQuantizationType.F16:
|
||||
return data.astype(np.float32, copy=False)
|
||||
elif (q := _type_traits.get(qtype)) is not None:
|
||||
return q.dequantize(data)
|
||||
else:
|
||||
raise NotImplementedError(f"Dequantization for {qtype.name} is not yet implemented")
|
||||
|
||||
|
||||
class __Quant(ABC):
|
||||
qtype: GGMLQuantizationType
|
||||
block_size: int
|
||||
type_size: int
|
||||
|
||||
def __init__(self):
|
||||
return TypeError("Quant conversion classes can't have instances")
|
||||
|
||||
def __init_subclass__(cls, qtype: GGMLQuantizationType) -> None:
|
||||
cls.qtype = qtype
|
||||
cls.block_size, cls.type_size = GGML_QUANT_SIZES[qtype]
|
||||
cls.__quantize_lazy = LazyNumpyTensor._wrap_fn(
|
||||
cls.__quantize_array,
|
||||
meta_noop=(np.uint8, cls.__shape_to_bytes)
|
||||
)
|
||||
cls.__dequantize_lazy = LazyNumpyTensor._wrap_fn(
|
||||
cls.__dequantize_array,
|
||||
meta_noop=(np.float32, cls.__shape_from_bytes)
|
||||
)
|
||||
assert qtype not in _type_traits
|
||||
_type_traits[qtype] = cls
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def quantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def quantize_rows(cls, rows: np.ndarray) -> np.ndarray:
|
||||
rows = rows.astype(np.float32, copy=False)
|
||||
shape = rows.shape
|
||||
n_blocks = rows.size // cls.block_size
|
||||
blocks = rows.reshape((n_blocks, cls.block_size))
|
||||
blocks = cls.quantize_blocks(blocks)
|
||||
assert blocks.dtype == np.uint8
|
||||
assert blocks.shape[-1] == cls.type_size
|
||||
return blocks.reshape(cls.__shape_to_bytes(shape))
|
||||
|
||||
@classmethod
|
||||
def dequantize_rows(cls, rows: np.ndarray) -> np.ndarray:
|
||||
rows = rows.view(np.uint8)
|
||||
shape = rows.shape
|
||||
n_blocks = rows.size // cls.type_size
|
||||
blocks = rows.reshape((n_blocks, cls.type_size))
|
||||
blocks = cls.dequantize_blocks(blocks)
|
||||
assert blocks.dtype == np.float32
|
||||
assert blocks.shape[-1] == cls.block_size
|
||||
return blocks.reshape(cls.__shape_from_bytes(shape))
|
||||
|
||||
@classmethod
|
||||
def __shape_to_bytes(cls, shape: Sequence[int]):
|
||||
return quant_shape_to_byte_shape(shape, cls.qtype)
|
||||
|
||||
@classmethod
|
||||
def __shape_from_bytes(cls, shape: Sequence[int]):
|
||||
return quant_shape_from_byte_shape(shape, cls.qtype)
|
||||
|
||||
@classmethod
|
||||
def __quantize_array(cls, array: np.ndarray) -> np.ndarray:
|
||||
return _apply_over_grouped_rows(cls.quantize_rows, arr=array, otype=np.uint8, oshape=cls.__shape_to_bytes(array.shape))
|
||||
|
||||
@classmethod
|
||||
def __dequantize_array(cls, array: np.ndarray) -> np.ndarray:
|
||||
return _apply_over_grouped_rows(cls.dequantize_rows, arr=array, otype=np.float32, oshape=cls.__shape_from_bytes(array.shape))
|
||||
|
||||
@classmethod
|
||||
def __quantize_lazy(cls, lazy_tensor: LazyNumpyTensor, /) -> Any:
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def __dequantize_lazy(cls, lazy_tensor: LazyNumpyTensor, /) -> Any:
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def can_quantize(cls, tensor: np.ndarray | LazyNumpyTensor) -> bool:
|
||||
return tensor.shape[-1] % cls.block_size == 0
|
||||
|
||||
@classmethod
|
||||
def quantize(cls, tensor: np.ndarray | LazyNumpyTensor) -> np.ndarray:
|
||||
if not cls.can_quantize(tensor):
|
||||
raise QuantError(f"Can't quantize tensor with shape {tensor.shape} to {cls.qtype.name}")
|
||||
if isinstance(tensor, LazyNumpyTensor):
|
||||
return cls.__quantize_lazy(tensor)
|
||||
else:
|
||||
return cls.__quantize_array(tensor)
|
||||
|
||||
@classmethod
|
||||
def dequantize(cls, tensor: np.ndarray | LazyNumpyTensor) -> np.ndarray:
|
||||
if isinstance(tensor, LazyNumpyTensor):
|
||||
return cls.__dequantize_lazy(tensor)
|
||||
else:
|
||||
return cls.__dequantize_array(tensor)
|
||||
|
||||
|
||||
class BF16(__Quant, qtype=GGMLQuantizationType.BF16):
|
||||
@classmethod
|
||||
# same as ggml_compute_fp32_to_bf16 in ggml-impl.h
|
||||
def quantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
|
||||
n = blocks.view(np.uint32)
|
||||
# force nan to quiet
|
||||
n = np.where((n & 0x7fffffff) > 0x7f800000, (n & np.uint32(0xffff0000)) | np.uint32(64 << 16), n)
|
||||
# round to nearest even
|
||||
n = (np.uint64(n) + (0x7fff + ((n >> 16) & 1))) >> 16
|
||||
return n.astype(np.uint16).view(np.uint8)
|
||||
|
||||
@classmethod
|
||||
def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
|
||||
return (blocks.view(np.int16).astype(np.int32) << 16).view(np.float32)
|
||||
|
||||
|
||||
class Q8_0(__Quant, qtype=GGMLQuantizationType.Q8_0):
|
||||
@classmethod
|
||||
# Implementation of Q8_0 with bit-exact same results as reference implementation in ggml-quants.c
|
||||
def quantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
|
||||
|
||||
d = abs(blocks).max(axis=1, keepdims=True) / 127
|
||||
with np.errstate(divide="ignore"):
|
||||
id = np.where(d == 0, 0, 1 / d)
|
||||
qs = np_roundf(blocks * id)
|
||||
|
||||
# (n_blocks, 2)
|
||||
d = d.astype(np.float16).view(np.uint8)
|
||||
# (n_blocks, block_size)
|
||||
qs = qs.astype(np.int8).view(np.uint8)
|
||||
|
||||
return np.concatenate([d, qs], axis=1)
|
||||
|
||||
@classmethod
|
||||
def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
|
||||
d, x = np.split(blocks, [2], axis=1)
|
||||
d = d.view(np.float16).astype(np.float32)
|
||||
x = x.view(np.int8).astype(np.float32)
|
||||
|
||||
return (x * d)
|
||||
|
Loading…
Reference in New Issue
Block a user