convert-hf : support bfloat16 conversion (#7158)

* convert-hf : support bfloat16 conversion

* gguf-py : flake8 fixes

* convert-hf : add missing space after comma

* convert-hf : get bit-exact same output as ./quantize

The quantization version was missing.
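
Concretely, the fix is just writing GGML_QUANT_VERSION into the GGUF metadata, which the updated main() now does:

    model_instance.gguf_writer.add_quantization_version(gguf.GGML_QUANT_VERSION)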

* convert-hf : don't round bf16 NANs

* convert-hf : save some memory with np.int16 intermediate bf16 weights
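
A minimal standalone sketch of the conversion (my own illustration, mirroring ggml_compute_fp32_to_bf16 like the helper added below): reinterpret the float32 bits as integers, quiet NaNs without rounding them, flush subnormals, round to nearest even, and keep only the upper 16 bits, which fit in a 16-bit integer.

    import numpy as np

    def fp32_to_bf16_bits(a: np.ndarray) -> np.ndarray:
        # reinterpret float32 bit patterns as unsigned 32-bit integers
        n = a.astype(np.float32).view(np.uint32)
        # quiet NaNs by keeping the upper 16 bits and setting the quiet bit (NaNs are never rounded)
        n = np.where((n & 0x7fffffff) > 0x7f800000, (n & np.uint32(0xffff0000)) | np.uint32(64 << 16), n)
        # flush subnormals to (signed) zero
        n = np.where((n & 0x7f800000) == 0, n & np.uint32(0x80000000), n)
        # round to nearest even, then keep the upper 16 bits
        n = (n + (0x7fff + ((n >> 16) & 1))) >> 16
        # the converter stores these as np.int16 to halve the intermediate's memory; same bit pattern
        return n.astype(np.uint16)

    print([hex(b) for b in fp32_to_bf16_bits(np.array([1.0, -2.5], dtype=np.float32))])
    # ['0x3f80', '0xc020']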

* convert-hf : more closely match llama.cpp with which weights to keep in f32
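
Roughly, the set of tensors kept in f32 now mirrors llama_model_quantize_internal: anything 1-dimensional, *_norm.weight, and a few tensor types that llama.cpp always keeps in f32. Condensed from the write_tensors() changes below:

    extra_f32 = (
        self.extra_f32_tensors(name, new_name, bid, n_dims)
        or n_dims == 1
        or new_name.endswith("_norm.weight")
        or any(self.match_model_tensor_name(new_name, key, bid) for key in (
            gguf.MODEL_TENSOR.FFN_GATE_INP,  # MoE routing gate
            gguf.MODEL_TENSOR.POS_EMBD,      # position embeddings
            gguf.MODEL_TENSOR.TOKEN_TYPES,   # token type embeddings
        ))
    )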

* convert-hf : add --outtype auto-f16

This exists for model quantizers who want an initial GGUF with the most
fidelity to the original model while still using a 16-bit float type
instead of 32-bit floats.
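
With this option the converter picks between f16 and bf16 based on the dtype of the first tensor instead of trusting torch_dtype in config.json (which some finetunes misreport); the selection added to Model.__init__ boils down to:

    if self.ftype == gguf.LlamaFileType.GUESSED:
        _, first_tensor = next(self.get_tensors())
        if first_tensor.dtype == torch.float16:
            self.ftype = gguf.LlamaFileType.MOSTLY_F16   # f16 checkpoints stay f16
        else:
            self.ftype = gguf.LlamaFileType.MOSTLY_BF16  # bf16 (and f32) checkpoints become bf16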

* convert-hf : remove a semicolon because flake8 doesn't like it

It's a reflex from programming in C/C++, I guess.

* convert-hf : support outtype templating in outfile name
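
The templating is plain str.format on the output file name, so {ftype}/{outtype} (lowercase), {FTYPE}/{OUTTYPE} (uppercase) and a bare {} all get substituted; for example:

    ftype_up = "BF16"
    ftype_lw = ftype_up.lower()
    name = "ggml-model-{ftype}.gguf"
    print(name.format(ftype_lw, outtype=ftype_lw, ftype=ftype_lw, OUTTYPE=ftype_up, FTYPE=ftype_up))
    # ggml-model-bf16.gguf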

* convert-hf : rename --outtype auto-f16 to --outtype auto
compilade 2024-05-11 11:06:26 -04:00 committed by GitHub
parent fae9d234b6
commit 5a419926b0
5 changed files with 404 additions and 182 deletions

convert-hf-to-gguf.py

@@ -12,7 +12,7 @@ import sys
from enum import IntEnum
from pathlib import Path
from hashlib import sha256
from typing import TYPE_CHECKING, Any, Callable, ContextManager, Iterable, Iterator, Sequence, TypeVar, cast, overload
from typing import TYPE_CHECKING, Any, Callable, ContextManager, Iterable, Iterator, Sequence, TypeVar, cast
import numpy as np
import torch
@@ -48,7 +48,6 @@ class Model:
dir_model: Path
ftype: int
fname_out: Path
is_big_endian: bool
endianess: gguf.GGUFEndian
use_temp_file: bool
@@ -56,20 +55,20 @@ class Model:
part_names: list[str]
is_safetensors: bool
hparams: dict[str, Any]
gguf_writer: gguf.GGUFWriter
block_count: int
tensor_map: gguf.TensorNameMap
tensor_names: set[str] | None
fname_out: Path
gguf_writer: gguf.GGUFWriter
# subclasses should define this!
model_arch: gguf.MODEL_ARCH
def __init__(self, dir_model: Path, ftype: int, fname_out: Path, is_big_endian: bool, use_temp_file: bool, eager: bool):
if self.__class__ == Model:
raise TypeError(f"{self.__class__.__name__!r} should not be directly instantiated")
def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, is_big_endian: bool, use_temp_file: bool, eager: bool):
if type(self) is Model:
raise TypeError(f"{type(self).__name__!r} should not be directly instantiated")
self.dir_model = dir_model
self.ftype = ftype
self.fname_out = fname_out
self.is_big_endian = is_big_endian
self.endianess = gguf.GGUFEndian.BIG if is_big_endian else gguf.GGUFEndian.LITTLE
self.use_temp_file = use_temp_file
@@ -79,10 +78,23 @@ class Model:
if not self.is_safetensors:
self.part_names = Model.get_model_part_names(self.dir_model, ".bin")
self.hparams = Model.load_hparams(self.dir_model)
self.gguf_writer = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=self.use_temp_file)
self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer"])
self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
self.tensor_names = None
if self.ftype == gguf.LlamaFileType.GUESSED:
# NOTE: can't use field "torch_dtype" in config.json, because some finetunes lie.
_, first_tensor = next(self.get_tensors())
if first_tensor.dtype == torch.float16:
logger.info(f"choosing --outtype f16 from first tensor type ({first_tensor.dtype})")
self.ftype = gguf.LlamaFileType.MOSTLY_F16
else:
logger.info(f"choosing --outtype bf16 from first tensor type ({first_tensor.dtype})")
self.ftype = gguf.LlamaFileType.MOSTLY_BF16
ftype_up: str = self.ftype.name.partition("_")[2].upper()
ftype_lw: str = ftype_up.lower()
# allow templating the file name with the output ftype, useful with the "auto" ftype
self.fname_out = fname_out.parent / fname_out.name.format(ftype_lw, outtype=ftype_lw, ftype=ftype_lw, OUTTYPE=ftype_up, FTYPE=ftype_up)
self.gguf_writer = gguf.GGUFWriter(self.fname_out, gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=self.use_temp_file)
@classmethod
def __init_subclass__(cls):
@@ -142,14 +154,27 @@ class Model:
raise ValueError(f"Mismatch between weight map and model parts for tensor names: {sym_diff}")
def format_tensor_name(self, key: gguf.MODEL_TENSOR, bid: int | None = None, suffix: str = ".weight") -> str:
name: str = gguf.TENSOR_NAMES[key]
if key not in gguf.MODEL_TENSORS[self.model_arch]:
raise ValueError(f"Missing {key!r} for MODEL_TENSORS of {self.model_arch!r}")
name: str = gguf.TENSOR_NAMES[key]
if "{bid}" in name:
assert bid is not None
name = name.format(bid=bid)
return name + suffix
def match_model_tensor_name(self, name: str, key: gguf.MODEL_TENSOR, bid: int | None, suffix: str = ".weight") -> bool:
if key not in gguf.MODEL_TENSORS[self.model_arch]:
return False
key_name: str = gguf.TENSOR_NAMES[key]
if "{bid}" in key_name:
if bid is None:
return False
key_name = key_name.format(bid=bid)
else:
if bid is not None:
return False
return name == (key_name + suffix)
def map_tensor_name(self, name: str, try_suffixes: Sequence[str] = (".weight", ".bias")) -> str:
new_name = self.tensor_map.get_name(key=name, try_suffixes=try_suffixes)
if new_name is None:
@@ -215,6 +240,23 @@ class Model:
return False
def write_tensors(self):
# same as ggml_compute_fp32_to_bf16 in ggml-impl.h
def np_fp32_to_bf16(n: np.ndarray):
# force nan to quiet
n = np.where((n & 0x7fffffff) > 0x7f800000, (n & 0xffff0000) | (64 << 16), n)
# flush subnormals to zero
n = np.where((n & 0x7f800000) == 0, n & 0x80000000, n)
# round to nearest even
n = (n + (0x7fff + ((n >> 16) & 1))) >> 16
return n.astype(np.int16)
# Doing this row-wise is much, much faster than element-wise, hence the signature
v_fp32_to_bf16 = np.vectorize(np_fp32_to_bf16, otypes=[np.int16], signature="(n)->(n)")
if self.lazy:
# TODO: find a way to implicitly wrap np.vectorize functions
# NOTE: the type is changed to reflect otypes passed to np.vectorize above
v_fp32_to_bf16 = gguf.LazyNumpyTensor._wrap_fn(v_fp32_to_bf16, meta_noop=np.int16)
max_name_len = max(len(s) for _, s in self.tensor_map.mapping.values()) + len(".weight,")
for name, data_torch in self.get_tensors():
@@ -239,35 +281,60 @@ class Model:
data: np.ndarray = data # type hint
n_dims = len(data.shape)
data_dtype = data.dtype
# if f32 desired, convert any float16 to float32
if self.ftype == 0 and data_dtype == np.float16:
data = data.astype(np.float32)
data_qtype: gguf.GGMLQuantizationType | None = None
# when both are True, f32 should win
extra_f32 = self.extra_f32_tensors(name, new_name, bid, n_dims)
extra_f16 = self.extra_f16_tensors(name, new_name, bid, n_dims)
# Most of the codebase that takes in 1D tensors or norms only handles F32 tensors
extra_f32 = extra_f32 or n_dims == 1 or new_name.endswith("_norm.weight")
# Conditions should closely match those in llama_model_quantize_internal in llama.cpp
extra_f32 = any(cond for cond in (
extra_f32,
n_dims == 1,
new_name.endswith("_norm.weight"),
))
# Some tensor types are always in float32
extra_f32 = extra_f32 or any(self.match_model_tensor_name(new_name, key, bid) for key in (
gguf.MODEL_TENSOR.FFN_GATE_INP,
gguf.MODEL_TENSOR.POS_EMBD,
gguf.MODEL_TENSOR.TOKEN_TYPES,
))
# if f16 desired, convert any float32 2-dim weight tensors to float16
extra_f16 = extra_f16 or (name.endswith(".weight") and n_dims >= 2)
extra_f16 = any(cond for cond in (
extra_f16,
(name.endswith(".weight") and n_dims >= 2),
))
# when both extra_f32 and extra_f16 are False, convert to float32 by default
if self.ftype == 1 and data_dtype == np.float16 and (extra_f32 or not extra_f16):
data = data.astype(np.float32)
if self.ftype != gguf.LlamaFileType.ALL_F32 and extra_f16 and not extra_f32:
if self.ftype == gguf.LlamaFileType.MOSTLY_F16:
if data_dtype != np.float16:
data = data.astype(np.float16)
data_qtype = gguf.GGMLQuantizationType.F16
if self.ftype == 1 and data_dtype == np.float32 and extra_f16 and not extra_f32:
data = data.astype(np.float16)
elif self.ftype == gguf.LlamaFileType.MOSTLY_BF16:
if data_dtype != np.float32:
data = data.astype(np.float32)
data = v_fp32_to_bf16(data.view(np.int32))
assert data.dtype == np.int16
data_qtype = gguf.GGMLQuantizationType.BF16
else: # by default, convert to float32
if data_dtype != np.float32:
data = data.astype(np.float32)
data_qtype = gguf.GGMLQuantizationType.F32
assert data_qtype is not None
# reverse shape to make it similar to the internal ggml dimension order
shape_str = f"{{{', '.join(str(n) for n in reversed(data.shape))}}}"
# n_dims is implicit in the shape
logger.info(f"{f'%-{max_name_len}s' % f'{new_name},'} {old_dtype} --> {data.dtype}, shape = {shape_str}")
logger.info(f"{f'%-{max_name_len}s' % f'{new_name},'} {old_dtype} --> {data_qtype.name}, shape = {shape_str}")
self.gguf_writer.add_tensor(new_name, data)
self.gguf_writer.add_tensor(new_name, data, raw_dtype=data_qtype)
def write(self):
self.write_tensors()
@@ -2044,12 +2111,6 @@ class BertModel(Model):
return [(self.map_tensor_name(name), data_torch)]
def extra_f32_tensors(self, name: str, new_name: str, bid: int | None, n_dims: int) -> bool:
del new_name, bid, n_dims # unused
# not used with get_rows, must be F32
return name == "embeddings.token_type_embeddings.weight"
@Model.register("NomicBertModel")
class NomicBertModel(BertModel):
@@ -2339,92 +2400,40 @@ class JinaBertV2Model(BertModel):
# tree of lazy tensors
class LazyTorchTensor:
_meta: Tensor
_data: Tensor | None
_args: tuple
_func: Callable[[tuple], Tensor] | None
def __init__(self, *, meta: Tensor, data: Tensor | None = None, args: tuple = (), func: Callable[[tuple], Tensor] | None = None):
self._meta = meta
self._data = data
self._args = args
self._func = func
@staticmethod
def _recurse_apply(o: Any, fn: Callable[[Any], Any]) -> Any:
# TODO: dict and set
if isinstance(o, (list, tuple)):
L = []
for item in o:
L.append(LazyTorchTensor._recurse_apply(item, fn))
if isinstance(o, tuple):
L = tuple(L)
return L
elif isinstance(o, LazyTorchTensor):
return fn(o)
else:
return o
def _wrap_fn(self, fn: Callable, use_self: bool = False) -> Callable[[Any], LazyTorchTensor]:
def wrapped_fn(*args, **kwargs):
if kwargs is None:
kwargs = {}
args = ((self,) if use_self else ()) + args
meta_args = LazyTorchTensor._recurse_apply(args, lambda t: t._meta)
return LazyTorchTensor(meta=fn(*meta_args, **kwargs), args=args, func=lambda a: fn(*a, **kwargs))
return wrapped_fn
def __getattr__(self, __name: str) -> Any:
meta_attr = getattr(self._meta, __name)
if callable(meta_attr):
return self._wrap_fn(getattr(torch.Tensor, __name), use_self=True)
elif isinstance(meta_attr, torch.Tensor):
# for things like self.T
return self._wrap_fn(lambda s: getattr(s, __name))(self)
else:
return meta_attr
class LazyTorchTensor(gguf.LazyBase):
_tensor_type = torch.Tensor
# to keep the type-checker happy
dtype: torch.dtype
shape: torch.Size
# only used when converting a torch.Tensor to a np.ndarray
_dtype_map: dict[torch.dtype, type] = {
torch.float16: np.float16,
torch.float32: np.float32,
}
def numpy(self) -> gguf.LazyTensor:
def numpy(self) -> gguf.LazyNumpyTensor:
dtype = self._dtype_map[self.dtype]
return gguf.LazyTensor(lambda: LazyTorchTensor.to_eager(self).numpy(), dtype=dtype, shape=self.shape)
return gguf.LazyNumpyTensor(
meta=np.lib.stride_tricks.as_strided(np.zeros(1, dtype), self.shape, (0 for _ in self.shape)),
lazy=self._lazy,
args=(self,),
func=(lambda s: s[0].numpy())
)
@overload
@staticmethod
def to_eager(t: Tensor | LazyTorchTensor) -> Tensor: ...
@overload
@staticmethod
def to_eager(t: tuple) -> tuple: ...
@staticmethod
def to_eager(t: Any) -> Any:
def simple_to_eager(_t: LazyTorchTensor) -> Tensor:
# wake up the lazy tensor
if _t._data is None and _t._func is not None:
# recurse into its arguments
_t._args = LazyTorchTensor.to_eager(_t._args)
_t._data = _t._func(_t._args)
if _t._data is not None:
return _t._data
else:
raise ValueError(f"Could not compute lazy tensor {_t!r} with args {_t._args!r}")
# recurse into lists and/or tuples, keeping their structure
return LazyTorchTensor._recurse_apply(t, simple_to_eager)
@staticmethod
def from_eager(t: Tensor) -> Tensor:
if (t.__class__ == LazyTorchTensor):
@classmethod
def eager_to_meta(cls, t: Tensor) -> Tensor:
if t.is_meta:
return t
return LazyTorchTensor(meta=t.detach().to("meta"), data=t) # type: ignore
return t.detach().to("meta")
@classmethod
def meta_with_dtype(cls, m: Tensor, dtype: torch.dtype) -> Tensor:
m = m.detach()
if not m.is_meta:
m = m.to("meta")
m.dtype = dtype
return m
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
@@ -2435,28 +2444,8 @@ class LazyTorchTensor:
if func is torch.Tensor.numpy:
return args[0].numpy()
if func is torch.equal:
eager_args = LazyTorchTensor.to_eager(args)
return func(*eager_args, **kwargs)
return LazyTorchTensor._wrap_fn(args[0], func)(*args, **kwargs)
# special methods bypass __getattr__, so they need to be added manually
# ref: https://docs.python.org/3/reference/datamodel.html#special-lookup
# NOTE: LazyTorchTensor can't be a subclass of Tensor (and then be used
# as self._meta is currently used), because then the following
# operations would by default not be wrapped, and so not propagated
# when the tensor is made eager.
# It's better to get non-silent errors for not-yet-supported operators.
# TODO: add more when needed to avoid clutter, or find a more concise way
def __neg__(self, *args): # mamba
return self._wrap_fn(torch.Tensor.__neg__)(self, *args)
def __add__(self, *args): # gemma
return self._wrap_fn(torch.Tensor.__add__)(self, *args)
def __getitem__(self, *args): # bloom falcon refact internlm2
return self._wrap_fn(torch.Tensor.__getitem__)(self, *args)
return LazyTorchTensor._wrap_fn(func)(*args, **kwargs)
def parse_args() -> argparse.Namespace:
@@ -2472,11 +2461,11 @@ def parse_args() -> argparse.Namespace:
)
parser.add_argument(
"--outfile", type=Path,
help="path to write to; default: based on input",
help="path to write to; default: based on input. {ftype} will be replaced by the outtype.",
)
parser.add_argument(
"--outtype", type=str, choices=["f32", "f16"], default="f16",
help="output format - use f32 for float32, f16 for float16",
"--outtype", type=str, choices=["f32", "f16", "bf16", "auto"], default="f16",
help="output format - use f32 for float32, f16 for float16, bf16 for bfloat16, auto for the highest-fidelity 16-bit float type depending on the first loaded tensor type",
)
parser.add_argument(
"--bigendian", action="store_true",
@@ -2530,16 +2519,18 @@ def main() -> None:
logger.error(f'Error: {args.model} is not a directory')
sys.exit(1)
ftype_map = {
"f32": gguf.GGMLQuantizationType.F32,
"f16": gguf.GGMLQuantizationType.F16,
ftype_map: dict[str, gguf.LlamaFileType] = {
"f32": gguf.LlamaFileType.ALL_F32,
"f16": gguf.LlamaFileType.MOSTLY_F16,
"bf16": gguf.LlamaFileType.MOSTLY_BF16,
"auto": gguf.LlamaFileType.GUESSED,
}
if args.outfile is not None:
fname_out = args.outfile
else:
# output in the same directory as the model by default
fname_out = dir_model / f'ggml-model-{args.outtype}.gguf'
fname_out = dir_model / 'ggml-model-{ftype}.gguf'
logger.info(f"Loading model: {dir_model.name}")
@@ -2555,14 +2546,16 @@ def main() -> None:
logger.info("Set model tokenizer")
model_instance.set_vocab()
model_instance.gguf_writer.add_quantization_version(gguf.GGML_QUANT_VERSION)
if args.vocab_only:
logger.info(f"Exporting model vocab to '{fname_out}'")
logger.info(f"Exporting model vocab to '{model_instance.fname_out}'")
model_instance.write_vocab()
else:
logger.info(f"Exporting model to '{fname_out}'")
logger.info(f"Exporting model to '{model_instance.fname_out}'")
model_instance.write()
logger.info(f"Model successfully exported to '{fname_out}'")
logger.info(f"Model successfully exported to '{model_instance.fname_out}'")
if __name__ == '__main__':

gguf-py/gguf/__init__.py

@@ -1,4 +1,5 @@
from .constants import *
from .lazy import *
from .gguf_reader import *
from .gguf_writer import *
from .tensor_mapping import *

gguf-py/gguf/constants.py

@@ -10,6 +10,7 @@ from typing import Any
GGUF_MAGIC = 0x46554747 # "GGUF"
GGUF_VERSION = 3
GGUF_DEFAULT_ALIGNMENT = 32
GGML_QUANT_VERSION = 2 # GGML_QNT_VERSION from ggml.h
#
# metadata keys
@@ -838,6 +839,49 @@ class GGMLQuantizationType(IntEnum):
BF16 = 30
# TODO: add GGMLFileType from ggml_ftype in ggml.h
# from llama_ftype in llama.h
# ALL VALUES SHOULD BE THE SAME HERE AS THEY ARE OVER THERE.
class LlamaFileType(IntEnum):
ALL_F32 = 0
MOSTLY_F16 = 1 # except 1d tensors
MOSTLY_Q4_0 = 2 # except 1d tensors
MOSTLY_Q4_1 = 3 # except 1d tensors
MOSTLY_Q4_1_SOME_F16 = 4 # tok_embeddings.weight and output.weight are F16
# MOSTLY_Q4_2 = 5 # support has been removed
# MOSTLY_Q4_3 = 6 # support has been removed
MOSTLY_Q8_0 = 7 # except 1d tensors
MOSTLY_Q5_0 = 8 # except 1d tensors
MOSTLY_Q5_1 = 9 # except 1d tensors
MOSTLY_Q2_K = 10 # except 1d tensors
MOSTLY_Q3_K_S = 11 # except 1d tensors
MOSTLY_Q3_K_M = 12 # except 1d tensors
MOSTLY_Q3_K_L = 13 # except 1d tensors
MOSTLY_Q4_K_S = 14 # except 1d tensors
MOSTLY_Q4_K_M = 15 # except 1d tensors
MOSTLY_Q5_K_S = 16 # except 1d tensors
MOSTLY_Q5_K_M = 17 # except 1d tensors
MOSTLY_Q6_K = 18 # except 1d tensors
MOSTLY_IQ2_XXS = 19 # except 1d tensors
MOSTLY_IQ2_XS = 20 # except 1d tensors
MOSTLY_Q2_K_S = 21 # except 1d tensors
MOSTLY_IQ3_XS = 22 # except 1d tensors
MOSTLY_IQ3_XXS = 23 # except 1d tensors
MOSTLY_IQ1_S = 24 # except 1d tensors
MOSTLY_IQ4_NL = 25 # except 1d tensors
MOSTLY_IQ3_S = 26 # except 1d tensors
MOSTLY_IQ3_M = 27 # except 1d tensors
MOSTLY_IQ2_S = 28 # except 1d tensors
MOSTLY_IQ2_M = 29 # except 1d tensors
MOSTLY_IQ4_XS = 30 # except 1d tensors
MOSTLY_IQ1_M = 31 # except 1d tensors
MOSTLY_BF16 = 32 # except 1d tensors
GUESSED = 1024 # not specified in the model file
class GGUFEndian(IntEnum):
LITTLE = 0
BIG = 1

gguf-py/gguf/gguf_writer.py

@@ -7,7 +7,7 @@ import struct
import tempfile
from enum import Enum, auto
from io import BufferedWriter
from typing import IO, Any, Callable, Sequence, Mapping
from typing import IO, Any, Sequence, Mapping
from string import ascii_letters, digits
import numpy as np
@@ -28,47 +28,6 @@ from .constants import (
logger = logging.getLogger(__name__)
class LazyTensor:
data: Callable[[], np.ndarray[Any, Any]]
# to avoid too deep recursion
functions: list[Callable[[np.ndarray[Any, Any]], np.ndarray[Any, Any]]]
dtype: np.dtype[Any]
shape: tuple[int, ...]
def __init__(self, data: Callable[[], np.ndarray[Any, Any]], *, dtype: type, shape: tuple[int, ...]):
self.data = data
self.functions = []
self.dtype = np.dtype(dtype)
self.shape = shape
def astype(self, dtype: type, **kwargs) -> LazyTensor:
self.functions.append(lambda n: n.astype(dtype, **kwargs))
self.dtype = np.dtype(dtype)
return self
@property
def nbytes(self) -> int:
size = 1
for n in self.shape:
size *= n
return size * self.dtype.itemsize
def tofile(self, *args, **kwargs) -> None:
data = self.data()
for f in self.functions:
data = f(data)
assert data.shape == self.shape
assert data.dtype == self.dtype
assert data.nbytes == self.nbytes
self.functions = []
self.data = lambda: data
data.tofile(*args, **kwargs)
def byteswap(self, *args, **kwargs) -> LazyTensor:
self.functions.append(lambda n: n.byteswap(*args, **kwargs))
return self
class WriterState(Enum):
EMPTY = auto()
HEADER = auto()
@@ -79,7 +38,7 @@ class WriterState(Enum):
class GGUFWriter:
fout: BufferedWriter
temp_file: tempfile.SpooledTemporaryFile[bytes] | None
tensors: list[np.ndarray[Any, Any] | LazyTensor]
tensors: list[np.ndarray[Any, Any]]
_simple_value_packing = {
GGUFValueType.UINT8: "B",
GGUFValueType.INT8: "b",
@@ -278,7 +237,7 @@ class GGUFWriter:
self.ti_data_count += 1
def add_tensor(
self, name: str, tensor: np.ndarray[Any, Any] | LazyTensor, raw_shape: Sequence[int] | None = None,
self, name: str, tensor: np.ndarray[Any, Any], raw_shape: Sequence[int] | None = None,
raw_dtype: GGMLQuantizationType | None = None,
) -> None:
if self.endianess == GGUFEndian.BIG:
@@ -303,7 +262,7 @@ class GGUFWriter:
if pad != 0:
fp.write(bytes([0] * pad))
def write_tensor_data(self, tensor: np.ndarray[Any, Any] | LazyTensor) -> None:
def write_tensor_data(self, tensor: np.ndarray[Any, Any]) -> None:
if self.state is not WriterState.TI_DATA:
raise ValueError(f'Expected output file to contain tensor info, got {self.state}')
@@ -391,7 +350,7 @@ class GGUFWriter:
def add_name(self, name: str) -> None:
self.add_string(Keys.General.NAME, name)
def add_quantization_version(self, quantization_version: GGMLQuantizationType) -> None:
def add_quantization_version(self, quantization_version: int) -> None:
self.add_uint32(
Keys.General.QUANTIZATION_VERSION, quantization_version)

gguf-py/gguf/lazy.py (new file, 225 lines)

@@ -0,0 +1,225 @@
from __future__ import annotations
from abc import ABC, ABCMeta, abstractmethod
import logging
from typing import Any, Callable
from collections import deque
import numpy as np
from numpy.typing import DTypeLike
logger = logging.getLogger(__name__)
class LazyMeta(ABCMeta):
def __new__(cls, name: str, bases: tuple[type, ...], namespace: dict[str, Any], **kwargs):
def __getattr__(self, __name: str) -> Any:
meta_attr = getattr(self._meta, __name)
if callable(meta_attr):
return type(self)._wrap_fn(
(lambda s, *args, **kwargs: getattr(s, __name)(*args, **kwargs)),
use_self=self,
)
elif isinstance(meta_attr, self._tensor_type):
# e.g. self.T with torch.Tensor should still be wrapped
return type(self)._wrap_fn(lambda s: getattr(s, __name))(self)
else:
# no need to wrap non-tensor properties,
# and they likely don't depend on the actual contents of the tensor
return meta_attr
namespace["__getattr__"] = __getattr__
# need to make a builder for the wrapped wrapper to copy the name,
# or else it fails with very cryptic error messages,
# because somehow the same string would end up in every closure
def mk_wrap(op_name: str, *, meta_noop: bool = False):
# need to wrap the wrapper to get self
def wrapped_special_op(self, *args, **kwargs):
return type(self)._wrap_fn(
getattr(type(self)._tensor_type, op_name),
meta_noop=meta_noop,
)(self, *args, **kwargs)
return wrapped_special_op
# special methods bypass __getattr__, so they need to be added manually
# ref: https://docs.python.org/3/reference/datamodel.html#special-lookup
# NOTE: doing this from a metaclass is very convenient
# TODO: make this even more comprehensive
for binary_op in (
"lt", "le", "eq", "ne", "ge", "gt", "not"
"abs", "add", "and", "floordiv", "invert", "lshift", "mod", "mul", "matmul",
"neg", "or", "pos", "pow", "rshift", "sub", "truediv", "xor",
"iadd", "iand", "ifloordiv", "ilshift", "imod", "imul", "ior", "irshift", "isub", "ixor",
"radd", "rand", "rfloordiv", "rmul", "ror", "rpow", "rsub", "rtruediv", "rxor",
):
attr_name = f"__{binary_op}__"
# the result of these operators usually has the same shape and dtype as the input,
# so evaluation on the meta tensor can be skipped.
namespace[attr_name] = mk_wrap(attr_name, meta_noop=True)
for special_op in (
"getitem", "setitem", "len",
):
attr_name = f"__{special_op}__"
namespace[attr_name] = mk_wrap(attr_name, meta_noop=False)
return super().__new__(cls, name, bases, namespace, **kwargs)
# Tree of lazy tensors
class LazyBase(ABC, metaclass=LazyMeta):
_tensor_type: type
_meta: Any
_data: Any | None
_lazy: deque[LazyBase] # shared within a graph, to avoid deep recursion when making eager
_args: tuple
_func: Callable[[tuple], Any] | None
def __init__(self, *, meta: Any, data: Any | None = None, lazy: deque[LazyBase] | None = None, args: tuple = (), func: Callable[[tuple], Any] | None = None):
super().__init__()
self._meta = meta
self._data = data
self._lazy = lazy if lazy is not None else deque()
self._args = args
self._func = func
assert self._func is not None or self._data is not None
if self._data is None:
self._lazy.append(self)
def __init_subclass__(cls) -> None:
if "_tensor_type" not in cls.__dict__:
raise TypeError(f"property '_tensor_type' must be defined for {cls!r}")
return super().__init_subclass__()
@staticmethod
def _recurse_apply(o: Any, fn: Callable[[Any], Any]) -> Any:
# TODO: dict and set
if isinstance(o, (list, tuple)):
L = []
for item in o:
L.append(LazyBase._recurse_apply(item, fn))
if isinstance(o, tuple):
L = tuple(L)
return L
elif isinstance(o, LazyBase):
return fn(o)
else:
return o
@classmethod
def _wrap_fn(cls, fn: Callable, *, use_self: LazyBase | None = None, meta_noop: bool | DTypeLike = False) -> Callable[[Any], Any]:
def wrapped_fn(*args, **kwargs):
if kwargs is None:
kwargs = {}
args = ((use_self,) if use_self is not None else ()) + args
meta_args = LazyBase._recurse_apply(args, lambda t: t._meta)
if isinstance(meta_noop, bool) and not meta_noop:
try:
res = fn(*meta_args, **kwargs)
except NotImplementedError:
# running some operations on PyTorch's Meta tensors can cause this exception
res = None
else:
# some operators don't need to actually run on the meta tensors
assert len(args) > 0
res = args[0]
assert isinstance(res, cls)
res = res._meta
# allow operations to override the dtype
if meta_noop is not True:
res = cls.meta_with_dtype(res, meta_noop)
if isinstance(res, cls._tensor_type):
def collect_replace(t: LazyBase):
if collect_replace.shared_lazy is None:
collect_replace.shared_lazy = t._lazy
else:
collect_replace.shared_lazy.extend(t._lazy)
t._lazy = collect_replace.shared_lazy
# emulating a static variable
collect_replace.shared_lazy = None
LazyBase._recurse_apply(args, collect_replace)
shared_lazy = collect_replace.shared_lazy
return cls(meta=cls.eager_to_meta(res), lazy=shared_lazy, args=args, func=lambda a: fn(*a, **kwargs))
else:
del res # not needed
# non-tensor return likely relies on the contents of the args
# (e.g. the result of torch.equal)
eager_args = cls.to_eager(args)
return fn(*eager_args, **kwargs)
return wrapped_fn
@classmethod
def to_eager(cls, t: Any) -> Any:
def simple_to_eager(_t: LazyBase) -> Any:
def already_eager_to_eager(_t: LazyBase) -> Any:
assert _t._data is not None
return _t._data
while _t._data is None:
lt = _t._lazy.popleft()
if lt._data is not None:
raise ValueError(f"{lt} did not belong in the lazy queue")
assert lt._func is not None
lt._args = cls._recurse_apply(lt._args, already_eager_to_eager)
lt._data = lt._func(lt._args)
# sanity check
assert lt._data.dtype == lt._meta.dtype
assert lt._data.shape == lt._meta.shape
return _t._data
# recurse into lists and/or tuples, keeping their structure
return cls._recurse_apply(t, simple_to_eager)
@classmethod
def eager_to_meta(cls, t: Any) -> Any:
return cls.meta_with_dtype(t, t.dtype)
# must be overridden, meta tensor init is backend-specific
@classmethod
@abstractmethod
def meta_with_dtype(cls, m: Any, dtype: Any) -> Any: pass
@classmethod
def from_eager(cls, t: Any) -> Any:
if type(t) is cls:
# already lazy
return t
elif isinstance(t, cls._tensor_type):
return cls(meta=cls.eager_to_meta(t), data=t)
else:
raise TypeError(f"{type(t)!r} is not compatible with {cls._tensor_type!r}")
class LazyNumpyTensor(LazyBase):
_tensor_type = np.ndarray
@classmethod
def meta_with_dtype(cls, m: np.ndarray[Any, Any], dtype: DTypeLike) -> np.ndarray[Any, Any]:
# The initial idea was to use np.nan as the fill value,
# but non-float types like np.int16 can't use that.
# So zero it is.
cheat = np.zeros(1, dtype)
return np.lib.stride_tricks.as_strided(cheat, m.shape, (0 for _ in m.shape))
def astype(self, dtype, *args, **kwargs):
meta = type(self).meta_with_dtype(self._meta, dtype)
full_args = (self, dtype,) + args
# very important to pass the shared _lazy deque, or else there's an infinite loop somewhere.
return type(self)(meta=meta, args=full_args, lazy=self._lazy, func=(lambda a: a[0].astype(*a[1:], **kwargs)))
def tofile(self, *args, **kwargs):
eager = LazyNumpyTensor.to_eager(self)
return eager.tofile(*args, **kwargs)
# TODO: __array_function__
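
Not part of the commit, but a rough usage sketch of the new lazy machinery (assuming the gguf-py package is importable): operations on a LazyNumpyTensor are recorded against a zero-strided meta array and only materialize when to_eager() is called.

    import numpy as np
    from gguf.lazy import LazyNumpyTensor

    eager = np.arange(8, dtype=np.float32)
    lazy = LazyNumpyTensor.from_eager(eager)   # wraps the array without copying it
    halved = lazy * 0.5                        # recorded on the shared deque, not computed yet
    result = LazyNumpyTensor.to_eager(halved)  # the multiplication actually runs here
    assert np.allclose(result, eager * 0.5)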