mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-25 02:44:36 +00:00
convert-hf : support bfloat16 conversion (#7158)
* convert-hf : support bfloat16 conversion * gguf-py : flake8 fixes * convert-hf : add missing space after comma * convert-hf : get bit-exact same output as ./quantize The quantization version was missing. * convert-hf : don't round bf16 NANs * convert-hf : save some memory with np.int16 intermediate bf16 weights * convert-hf : more closely match llama.cpp with which weights to keep in f32 * convert-hf : add --outtype auto-f16 A reason for this to exist is for model quantizers who want an initial GGUF with the most fidelity to the original model while still using a 16-bit float type instead of 32-bit floats. * convert-hf : remove a semicolon because flake8 doesn't like it It's a reflex from when programming in C/C++, I guess. * convert-hf : support outtype templating in outfile name * convert-hf : rename --outtype auto-f16 to --outtype auto
This commit is contained in:
parent
fae9d234b6
commit
5a419926b0
@ -12,7 +12,7 @@ import sys
|
|||||||
from enum import IntEnum
|
from enum import IntEnum
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from hashlib import sha256
|
from hashlib import sha256
|
||||||
from typing import TYPE_CHECKING, Any, Callable, ContextManager, Iterable, Iterator, Sequence, TypeVar, cast, overload
|
from typing import TYPE_CHECKING, Any, Callable, ContextManager, Iterable, Iterator, Sequence, TypeVar, cast
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@ -48,7 +48,6 @@ class Model:
|
|||||||
|
|
||||||
dir_model: Path
|
dir_model: Path
|
||||||
ftype: int
|
ftype: int
|
||||||
fname_out: Path
|
|
||||||
is_big_endian: bool
|
is_big_endian: bool
|
||||||
endianess: gguf.GGUFEndian
|
endianess: gguf.GGUFEndian
|
||||||
use_temp_file: bool
|
use_temp_file: bool
|
||||||
@ -56,20 +55,20 @@ class Model:
|
|||||||
part_names: list[str]
|
part_names: list[str]
|
||||||
is_safetensors: bool
|
is_safetensors: bool
|
||||||
hparams: dict[str, Any]
|
hparams: dict[str, Any]
|
||||||
gguf_writer: gguf.GGUFWriter
|
|
||||||
block_count: int
|
block_count: int
|
||||||
tensor_map: gguf.TensorNameMap
|
tensor_map: gguf.TensorNameMap
|
||||||
tensor_names: set[str] | None
|
tensor_names: set[str] | None
|
||||||
|
fname_out: Path
|
||||||
|
gguf_writer: gguf.GGUFWriter
|
||||||
|
|
||||||
# subclasses should define this!
|
# subclasses should define this!
|
||||||
model_arch: gguf.MODEL_ARCH
|
model_arch: gguf.MODEL_ARCH
|
||||||
|
|
||||||
def __init__(self, dir_model: Path, ftype: int, fname_out: Path, is_big_endian: bool, use_temp_file: bool, eager: bool):
|
def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, is_big_endian: bool, use_temp_file: bool, eager: bool):
|
||||||
if self.__class__ == Model:
|
if type(self) is Model:
|
||||||
raise TypeError(f"{self.__class__.__name__!r} should not be directly instantiated")
|
raise TypeError(f"{type(self).__name__!r} should not be directly instantiated")
|
||||||
self.dir_model = dir_model
|
self.dir_model = dir_model
|
||||||
self.ftype = ftype
|
self.ftype = ftype
|
||||||
self.fname_out = fname_out
|
|
||||||
self.is_big_endian = is_big_endian
|
self.is_big_endian = is_big_endian
|
||||||
self.endianess = gguf.GGUFEndian.BIG if is_big_endian else gguf.GGUFEndian.LITTLE
|
self.endianess = gguf.GGUFEndian.BIG if is_big_endian else gguf.GGUFEndian.LITTLE
|
||||||
self.use_temp_file = use_temp_file
|
self.use_temp_file = use_temp_file
|
||||||
@ -79,10 +78,23 @@ class Model:
|
|||||||
if not self.is_safetensors:
|
if not self.is_safetensors:
|
||||||
self.part_names = Model.get_model_part_names(self.dir_model, ".bin")
|
self.part_names = Model.get_model_part_names(self.dir_model, ".bin")
|
||||||
self.hparams = Model.load_hparams(self.dir_model)
|
self.hparams = Model.load_hparams(self.dir_model)
|
||||||
self.gguf_writer = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=self.use_temp_file)
|
|
||||||
self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer"])
|
self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer"])
|
||||||
self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
|
self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
|
||||||
self.tensor_names = None
|
self.tensor_names = None
|
||||||
|
if self.ftype == gguf.LlamaFileType.GUESSED:
|
||||||
|
# NOTE: can't use field "torch_dtype" in config.json, because some finetunes lie.
|
||||||
|
_, first_tensor = next(self.get_tensors())
|
||||||
|
if first_tensor.dtype == torch.float16:
|
||||||
|
logger.info(f"choosing --outtype f16 from first tensor type ({first_tensor.dtype})")
|
||||||
|
self.ftype = gguf.LlamaFileType.MOSTLY_F16
|
||||||
|
else:
|
||||||
|
logger.info(f"choosing --outtype bf16 from first tensor type ({first_tensor.dtype})")
|
||||||
|
self.ftype = gguf.LlamaFileType.MOSTLY_BF16
|
||||||
|
ftype_up: str = self.ftype.name.partition("_")[2].upper()
|
||||||
|
ftype_lw: str = ftype_up.lower()
|
||||||
|
# allow templating the file name with the output ftype, useful with the "auto" ftype
|
||||||
|
self.fname_out = fname_out.parent / fname_out.name.format(ftype_lw, outtype=ftype_lw, ftype=ftype_lw, OUTTYPE=ftype_up, FTYPE=ftype_up)
|
||||||
|
self.gguf_writer = gguf.GGUFWriter(self.fname_out, gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=self.use_temp_file)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def __init_subclass__(cls):
|
def __init_subclass__(cls):
|
||||||
@ -142,14 +154,27 @@ class Model:
|
|||||||
raise ValueError(f"Mismatch between weight map and model parts for tensor names: {sym_diff}")
|
raise ValueError(f"Mismatch between weight map and model parts for tensor names: {sym_diff}")
|
||||||
|
|
||||||
def format_tensor_name(self, key: gguf.MODEL_TENSOR, bid: int | None = None, suffix: str = ".weight") -> str:
|
def format_tensor_name(self, key: gguf.MODEL_TENSOR, bid: int | None = None, suffix: str = ".weight") -> str:
|
||||||
name: str = gguf.TENSOR_NAMES[key]
|
|
||||||
if key not in gguf.MODEL_TENSORS[self.model_arch]:
|
if key not in gguf.MODEL_TENSORS[self.model_arch]:
|
||||||
raise ValueError(f"Missing {key!r} for MODEL_TENSORS of {self.model_arch!r}")
|
raise ValueError(f"Missing {key!r} for MODEL_TENSORS of {self.model_arch!r}")
|
||||||
|
name: str = gguf.TENSOR_NAMES[key]
|
||||||
if "{bid}" in name:
|
if "{bid}" in name:
|
||||||
assert bid is not None
|
assert bid is not None
|
||||||
name = name.format(bid=bid)
|
name = name.format(bid=bid)
|
||||||
return name + suffix
|
return name + suffix
|
||||||
|
|
||||||
|
def match_model_tensor_name(self, name: str, key: gguf.MODEL_TENSOR, bid: int | None, suffix: str = ".weight") -> bool:
|
||||||
|
if key not in gguf.MODEL_TENSORS[self.model_arch]:
|
||||||
|
return False
|
||||||
|
key_name: str = gguf.TENSOR_NAMES[key]
|
||||||
|
if "{bid}" in key_name:
|
||||||
|
if bid is None:
|
||||||
|
return False
|
||||||
|
key_name = key_name.format(bid=bid)
|
||||||
|
else:
|
||||||
|
if bid is not None:
|
||||||
|
return False
|
||||||
|
return name == (key_name + suffix)
|
||||||
|
|
||||||
def map_tensor_name(self, name: str, try_suffixes: Sequence[str] = (".weight", ".bias")) -> str:
|
def map_tensor_name(self, name: str, try_suffixes: Sequence[str] = (".weight", ".bias")) -> str:
|
||||||
new_name = self.tensor_map.get_name(key=name, try_suffixes=try_suffixes)
|
new_name = self.tensor_map.get_name(key=name, try_suffixes=try_suffixes)
|
||||||
if new_name is None:
|
if new_name is None:
|
||||||
@ -215,6 +240,23 @@ class Model:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
def write_tensors(self):
|
def write_tensors(self):
|
||||||
|
# same as ggml_compute_fp32_to_bf16 in ggml-impl.h
|
||||||
|
def np_fp32_to_bf16(n: np.ndarray):
|
||||||
|
# force nan to quiet
|
||||||
|
n = np.where((n & 0x7fffffff) > 0x7f800000, (n & 0xffff0000) | (64 << 16), n)
|
||||||
|
# flush subnormals to zero
|
||||||
|
n = np.where((n & 0x7f800000) == 0, n & 0x80000000, n)
|
||||||
|
# round to nearest even
|
||||||
|
n = (n + (0x7fff + ((n >> 16) & 1))) >> 16
|
||||||
|
return n.astype(np.int16)
|
||||||
|
|
||||||
|
# Doing this row-wise is much, much faster than element-wise, hence the signature
|
||||||
|
v_fp32_to_bf16 = np.vectorize(np_fp32_to_bf16, otypes=[np.int16], signature="(n)->(n)")
|
||||||
|
if self.lazy:
|
||||||
|
# TODO: find a way to implicitly wrap np.vectorize functions
|
||||||
|
# NOTE: the type is changed to reflect otypes passed to np.vectorize above
|
||||||
|
v_fp32_to_bf16 = gguf.LazyNumpyTensor._wrap_fn(v_fp32_to_bf16, meta_noop=np.int16)
|
||||||
|
|
||||||
max_name_len = max(len(s) for _, s in self.tensor_map.mapping.values()) + len(".weight,")
|
max_name_len = max(len(s) for _, s in self.tensor_map.mapping.values()) + len(".weight,")
|
||||||
|
|
||||||
for name, data_torch in self.get_tensors():
|
for name, data_torch in self.get_tensors():
|
||||||
@ -239,35 +281,60 @@ class Model:
|
|||||||
data: np.ndarray = data # type hint
|
data: np.ndarray = data # type hint
|
||||||
n_dims = len(data.shape)
|
n_dims = len(data.shape)
|
||||||
data_dtype = data.dtype
|
data_dtype = data.dtype
|
||||||
|
data_qtype: gguf.GGMLQuantizationType | None = None
|
||||||
# if f32 desired, convert any float16 to float32
|
|
||||||
if self.ftype == 0 and data_dtype == np.float16:
|
|
||||||
data = data.astype(np.float32)
|
|
||||||
|
|
||||||
# when both are True, f32 should win
|
# when both are True, f32 should win
|
||||||
extra_f32 = self.extra_f32_tensors(name, new_name, bid, n_dims)
|
extra_f32 = self.extra_f32_tensors(name, new_name, bid, n_dims)
|
||||||
extra_f16 = self.extra_f16_tensors(name, new_name, bid, n_dims)
|
extra_f16 = self.extra_f16_tensors(name, new_name, bid, n_dims)
|
||||||
|
|
||||||
# Most of the codebase that takes in 1D tensors or norms only handles F32 tensors
|
# Most of the codebase that takes in 1D tensors or norms only handles F32 tensors
|
||||||
extra_f32 = extra_f32 or n_dims == 1 or new_name.endswith("_norm.weight")
|
# Conditions should closely match those in llama_model_quantize_internal in llama.cpp
|
||||||
|
extra_f32 = any(cond for cond in (
|
||||||
|
extra_f32,
|
||||||
|
n_dims == 1,
|
||||||
|
new_name.endswith("_norm.weight"),
|
||||||
|
))
|
||||||
|
|
||||||
|
# Some tensor types are always in float32
|
||||||
|
extra_f32 = extra_f32 or any(self.match_model_tensor_name(new_name, key, bid) for key in (
|
||||||
|
gguf.MODEL_TENSOR.FFN_GATE_INP,
|
||||||
|
gguf.MODEL_TENSOR.POS_EMBD,
|
||||||
|
gguf.MODEL_TENSOR.TOKEN_TYPES,
|
||||||
|
))
|
||||||
|
|
||||||
# if f16 desired, convert any float32 2-dim weight tensors to float16
|
# if f16 desired, convert any float32 2-dim weight tensors to float16
|
||||||
extra_f16 = extra_f16 or (name.endswith(".weight") and n_dims >= 2)
|
extra_f16 = any(cond for cond in (
|
||||||
|
extra_f16,
|
||||||
|
(name.endswith(".weight") and n_dims >= 2),
|
||||||
|
))
|
||||||
|
|
||||||
# when both extra_f32 and extra_f16 are False, convert to float32 by default
|
if self.ftype != gguf.LlamaFileType.ALL_F32 and extra_f16 and not extra_f32:
|
||||||
if self.ftype == 1 and data_dtype == np.float16 and (extra_f32 or not extra_f16):
|
if self.ftype == gguf.LlamaFileType.MOSTLY_F16:
|
||||||
data = data.astype(np.float32)
|
if data_dtype != np.float16:
|
||||||
|
|
||||||
if self.ftype == 1 and data_dtype == np.float32 and extra_f16 and not extra_f32:
|
|
||||||
data = data.astype(np.float16)
|
data = data.astype(np.float16)
|
||||||
|
data_qtype = gguf.GGMLQuantizationType.F16
|
||||||
|
|
||||||
|
elif self.ftype == gguf.LlamaFileType.MOSTLY_BF16:
|
||||||
|
if data_dtype != np.float32:
|
||||||
|
data = data.astype(np.float32)
|
||||||
|
data = v_fp32_to_bf16(data.view(np.int32))
|
||||||
|
assert data.dtype == np.int16
|
||||||
|
data_qtype = gguf.GGMLQuantizationType.BF16
|
||||||
|
|
||||||
|
else: # by default, convert to float32
|
||||||
|
if data_dtype != np.float32:
|
||||||
|
data = data.astype(np.float32)
|
||||||
|
data_qtype = gguf.GGMLQuantizationType.F32
|
||||||
|
|
||||||
|
assert data_qtype is not None
|
||||||
|
|
||||||
# reverse shape to make it similar to the internal ggml dimension order
|
# reverse shape to make it similar to the internal ggml dimension order
|
||||||
shape_str = f"{{{', '.join(str(n) for n in reversed(data.shape))}}}"
|
shape_str = f"{{{', '.join(str(n) for n in reversed(data.shape))}}}"
|
||||||
|
|
||||||
# n_dims is implicit in the shape
|
# n_dims is implicit in the shape
|
||||||
logger.info(f"{f'%-{max_name_len}s' % f'{new_name},'} {old_dtype} --> {data.dtype}, shape = {shape_str}")
|
logger.info(f"{f'%-{max_name_len}s' % f'{new_name},'} {old_dtype} --> {data_qtype.name}, shape = {shape_str}")
|
||||||
|
|
||||||
self.gguf_writer.add_tensor(new_name, data)
|
self.gguf_writer.add_tensor(new_name, data, raw_dtype=data_qtype)
|
||||||
|
|
||||||
def write(self):
|
def write(self):
|
||||||
self.write_tensors()
|
self.write_tensors()
|
||||||
@ -2044,12 +2111,6 @@ class BertModel(Model):
|
|||||||
|
|
||||||
return [(self.map_tensor_name(name), data_torch)]
|
return [(self.map_tensor_name(name), data_torch)]
|
||||||
|
|
||||||
def extra_f32_tensors(self, name: str, new_name: str, bid: int | None, n_dims: int) -> bool:
|
|
||||||
del new_name, bid, n_dims # unused
|
|
||||||
|
|
||||||
# not used with get_rows, must be F32
|
|
||||||
return name == "embeddings.token_type_embeddings.weight"
|
|
||||||
|
|
||||||
|
|
||||||
@Model.register("NomicBertModel")
|
@Model.register("NomicBertModel")
|
||||||
class NomicBertModel(BertModel):
|
class NomicBertModel(BertModel):
|
||||||
@ -2339,92 +2400,40 @@ class JinaBertV2Model(BertModel):
|
|||||||
|
|
||||||
|
|
||||||
# tree of lazy tensors
|
# tree of lazy tensors
|
||||||
class LazyTorchTensor:
|
class LazyTorchTensor(gguf.LazyBase):
|
||||||
_meta: Tensor
|
_tensor_type = torch.Tensor
|
||||||
_data: Tensor | None
|
# to keep the type-checker happy
|
||||||
_args: tuple
|
dtype: torch.dtype
|
||||||
_func: Callable[[tuple], Tensor] | None
|
shape: torch.Size
|
||||||
|
|
||||||
def __init__(self, *, meta: Tensor, data: Tensor | None = None, args: tuple = (), func: Callable[[tuple], Tensor] | None = None):
|
|
||||||
self._meta = meta
|
|
||||||
self._data = data
|
|
||||||
self._args = args
|
|
||||||
self._func = func
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _recurse_apply(o: Any, fn: Callable[[Any], Any]) -> Any:
|
|
||||||
# TODO: dict and set
|
|
||||||
if isinstance(o, (list, tuple)):
|
|
||||||
L = []
|
|
||||||
for item in o:
|
|
||||||
L.append(LazyTorchTensor._recurse_apply(item, fn))
|
|
||||||
if isinstance(o, tuple):
|
|
||||||
L = tuple(L)
|
|
||||||
return L
|
|
||||||
elif isinstance(o, LazyTorchTensor):
|
|
||||||
return fn(o)
|
|
||||||
else:
|
|
||||||
return o
|
|
||||||
|
|
||||||
def _wrap_fn(self, fn: Callable, use_self: bool = False) -> Callable[[Any], LazyTorchTensor]:
|
|
||||||
def wrapped_fn(*args, **kwargs):
|
|
||||||
if kwargs is None:
|
|
||||||
kwargs = {}
|
|
||||||
args = ((self,) if use_self else ()) + args
|
|
||||||
|
|
||||||
meta_args = LazyTorchTensor._recurse_apply(args, lambda t: t._meta)
|
|
||||||
|
|
||||||
return LazyTorchTensor(meta=fn(*meta_args, **kwargs), args=args, func=lambda a: fn(*a, **kwargs))
|
|
||||||
return wrapped_fn
|
|
||||||
|
|
||||||
def __getattr__(self, __name: str) -> Any:
|
|
||||||
meta_attr = getattr(self._meta, __name)
|
|
||||||
if callable(meta_attr):
|
|
||||||
return self._wrap_fn(getattr(torch.Tensor, __name), use_self=True)
|
|
||||||
elif isinstance(meta_attr, torch.Tensor):
|
|
||||||
# for things like self.T
|
|
||||||
return self._wrap_fn(lambda s: getattr(s, __name))(self)
|
|
||||||
else:
|
|
||||||
return meta_attr
|
|
||||||
|
|
||||||
|
# only used when converting a torch.Tensor to a np.ndarray
|
||||||
_dtype_map: dict[torch.dtype, type] = {
|
_dtype_map: dict[torch.dtype, type] = {
|
||||||
torch.float16: np.float16,
|
torch.float16: np.float16,
|
||||||
torch.float32: np.float32,
|
torch.float32: np.float32,
|
||||||
}
|
}
|
||||||
|
|
||||||
def numpy(self) -> gguf.LazyTensor:
|
def numpy(self) -> gguf.LazyNumpyTensor:
|
||||||
dtype = self._dtype_map[self.dtype]
|
dtype = self._dtype_map[self.dtype]
|
||||||
return gguf.LazyTensor(lambda: LazyTorchTensor.to_eager(self).numpy(), dtype=dtype, shape=self.shape)
|
return gguf.LazyNumpyTensor(
|
||||||
|
meta=np.lib.stride_tricks.as_strided(np.zeros(1, dtype), self.shape, (0 for _ in self.shape)),
|
||||||
|
lazy=self._lazy,
|
||||||
|
args=(self,),
|
||||||
|
func=(lambda s: s[0].numpy())
|
||||||
|
)
|
||||||
|
|
||||||
@overload
|
@classmethod
|
||||||
@staticmethod
|
def eager_to_meta(cls, t: Tensor) -> Tensor:
|
||||||
def to_eager(t: Tensor | LazyTorchTensor) -> Tensor: ...
|
if t.is_meta:
|
||||||
|
|
||||||
@overload
|
|
||||||
@staticmethod
|
|
||||||
def to_eager(t: tuple) -> tuple: ...
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def to_eager(t: Any) -> Any:
|
|
||||||
def simple_to_eager(_t: LazyTorchTensor) -> Tensor:
|
|
||||||
# wake up the lazy tensor
|
|
||||||
if _t._data is None and _t._func is not None:
|
|
||||||
# recurse into its arguments
|
|
||||||
_t._args = LazyTorchTensor.to_eager(_t._args)
|
|
||||||
_t._data = _t._func(_t._args)
|
|
||||||
if _t._data is not None:
|
|
||||||
return _t._data
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Could not compute lazy tensor {_t!r} with args {_t._args!r}")
|
|
||||||
|
|
||||||
# recurse into lists and/or tuples, keeping their structure
|
|
||||||
return LazyTorchTensor._recurse_apply(t, simple_to_eager)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def from_eager(t: Tensor) -> Tensor:
|
|
||||||
if (t.__class__ == LazyTorchTensor):
|
|
||||||
return t
|
return t
|
||||||
return LazyTorchTensor(meta=t.detach().to("meta"), data=t) # type: ignore
|
return t.detach().to("meta")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def meta_with_dtype(cls, m: Tensor, dtype: torch.dtype) -> Tensor:
|
||||||
|
m = m.detach()
|
||||||
|
if not m.is_meta:
|
||||||
|
m = m.to("meta")
|
||||||
|
m.dtype = dtype
|
||||||
|
return m
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def __torch_function__(cls, func, types, args=(), kwargs=None):
|
def __torch_function__(cls, func, types, args=(), kwargs=None):
|
||||||
@ -2435,28 +2444,8 @@ class LazyTorchTensor:
|
|||||||
|
|
||||||
if func is torch.Tensor.numpy:
|
if func is torch.Tensor.numpy:
|
||||||
return args[0].numpy()
|
return args[0].numpy()
|
||||||
if func is torch.equal:
|
|
||||||
eager_args = LazyTorchTensor.to_eager(args)
|
|
||||||
return func(*eager_args, **kwargs)
|
|
||||||
|
|
||||||
return LazyTorchTensor._wrap_fn(args[0], func)(*args, **kwargs)
|
return LazyTorchTensor._wrap_fn(func)(*args, **kwargs)
|
||||||
|
|
||||||
# special methods bypass __getattr__, so they need to be added manually
|
|
||||||
# ref: https://docs.python.org/3/reference/datamodel.html#special-lookup
|
|
||||||
# NOTE: LazyTorchTensor can't be a subclass of Tensor (and then be used
|
|
||||||
# as self._meta is currently used), because then the following
|
|
||||||
# operations would by default not be wrapped, and so not propagated
|
|
||||||
# when the tensor is made eager.
|
|
||||||
# It's better to get non-silent errors for not-yet-supported operators.
|
|
||||||
# TODO: add more when needed to avoid clutter, or find a more concise way
|
|
||||||
def __neg__(self, *args): # mamba
|
|
||||||
return self._wrap_fn(torch.Tensor.__neg__)(self, *args)
|
|
||||||
|
|
||||||
def __add__(self, *args): # gemma
|
|
||||||
return self._wrap_fn(torch.Tensor.__add__)(self, *args)
|
|
||||||
|
|
||||||
def __getitem__(self, *args): # bloom falcon refact internlm2
|
|
||||||
return self._wrap_fn(torch.Tensor.__getitem__)(self, *args)
|
|
||||||
|
|
||||||
|
|
||||||
def parse_args() -> argparse.Namespace:
|
def parse_args() -> argparse.Namespace:
|
||||||
@ -2472,11 +2461,11 @@ def parse_args() -> argparse.Namespace:
|
|||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--outfile", type=Path,
|
"--outfile", type=Path,
|
||||||
help="path to write to; default: based on input",
|
help="path to write to; default: based on input. {ftype} will be replaced by the outtype.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--outtype", type=str, choices=["f32", "f16"], default="f16",
|
"--outtype", type=str, choices=["f32", "f16", "bf16", "auto"], default="f16",
|
||||||
help="output format - use f32 for float32, f16 for float16",
|
help="output format - use f32 for float32, f16 for float16, bf16 for bfloat16, auto for the highest-fidelity 16-bit float type depending on the first loaded tensor type",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--bigendian", action="store_true",
|
"--bigendian", action="store_true",
|
||||||
@ -2530,16 +2519,18 @@ def main() -> None:
|
|||||||
logger.error(f'Error: {args.model} is not a directory')
|
logger.error(f'Error: {args.model} is not a directory')
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
ftype_map = {
|
ftype_map: dict[str, gguf.LlamaFileType] = {
|
||||||
"f32": gguf.GGMLQuantizationType.F32,
|
"f32": gguf.LlamaFileType.ALL_F32,
|
||||||
"f16": gguf.GGMLQuantizationType.F16,
|
"f16": gguf.LlamaFileType.MOSTLY_F16,
|
||||||
|
"bf16": gguf.LlamaFileType.MOSTLY_BF16,
|
||||||
|
"auto": gguf.LlamaFileType.GUESSED,
|
||||||
}
|
}
|
||||||
|
|
||||||
if args.outfile is not None:
|
if args.outfile is not None:
|
||||||
fname_out = args.outfile
|
fname_out = args.outfile
|
||||||
else:
|
else:
|
||||||
# output in the same directory as the model by default
|
# output in the same directory as the model by default
|
||||||
fname_out = dir_model / f'ggml-model-{args.outtype}.gguf'
|
fname_out = dir_model / 'ggml-model-{ftype}.gguf'
|
||||||
|
|
||||||
logger.info(f"Loading model: {dir_model.name}")
|
logger.info(f"Loading model: {dir_model.name}")
|
||||||
|
|
||||||
@ -2555,14 +2546,16 @@ def main() -> None:
|
|||||||
logger.info("Set model tokenizer")
|
logger.info("Set model tokenizer")
|
||||||
model_instance.set_vocab()
|
model_instance.set_vocab()
|
||||||
|
|
||||||
|
model_instance.gguf_writer.add_quantization_version(gguf.GGML_QUANT_VERSION)
|
||||||
|
|
||||||
if args.vocab_only:
|
if args.vocab_only:
|
||||||
logger.info(f"Exporting model vocab to '{fname_out}'")
|
logger.info(f"Exporting model vocab to '{model_instance.fname_out}'")
|
||||||
model_instance.write_vocab()
|
model_instance.write_vocab()
|
||||||
else:
|
else:
|
||||||
logger.info(f"Exporting model to '{fname_out}'")
|
logger.info(f"Exporting model to '{model_instance.fname_out}'")
|
||||||
model_instance.write()
|
model_instance.write()
|
||||||
|
|
||||||
logger.info(f"Model successfully exported to '{fname_out}'")
|
logger.info(f"Model successfully exported to '{model_instance.fname_out}'")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
from .constants import *
|
from .constants import *
|
||||||
|
from .lazy import *
|
||||||
from .gguf_reader import *
|
from .gguf_reader import *
|
||||||
from .gguf_writer import *
|
from .gguf_writer import *
|
||||||
from .tensor_mapping import *
|
from .tensor_mapping import *
|
||||||
|
@ -10,6 +10,7 @@ from typing import Any
|
|||||||
GGUF_MAGIC = 0x46554747 # "GGUF"
|
GGUF_MAGIC = 0x46554747 # "GGUF"
|
||||||
GGUF_VERSION = 3
|
GGUF_VERSION = 3
|
||||||
GGUF_DEFAULT_ALIGNMENT = 32
|
GGUF_DEFAULT_ALIGNMENT = 32
|
||||||
|
GGML_QUANT_VERSION = 2 # GGML_QNT_VERSION from ggml.h
|
||||||
|
|
||||||
#
|
#
|
||||||
# metadata keys
|
# metadata keys
|
||||||
@ -838,6 +839,49 @@ class GGMLQuantizationType(IntEnum):
|
|||||||
BF16 = 30
|
BF16 = 30
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: add GGMLFileType from ggml_ftype in ggml.h
|
||||||
|
|
||||||
|
|
||||||
|
# from llama_ftype in llama.h
|
||||||
|
# ALL VALUES SHOULD BE THE SAME HERE AS THEY ARE OVER THERE.
|
||||||
|
class LlamaFileType(IntEnum):
|
||||||
|
ALL_F32 = 0
|
||||||
|
MOSTLY_F16 = 1 # except 1d tensors
|
||||||
|
MOSTLY_Q4_0 = 2 # except 1d tensors
|
||||||
|
MOSTLY_Q4_1 = 3 # except 1d tensors
|
||||||
|
MOSTLY_Q4_1_SOME_F16 = 4 # tok_embeddings.weight and output.weight are F16
|
||||||
|
# MOSTLY_Q4_2 = 5 # support has been removed
|
||||||
|
# MOSTLY_Q4_3 = 6 # support has been removed
|
||||||
|
MOSTLY_Q8_0 = 7 # except 1d tensors
|
||||||
|
MOSTLY_Q5_0 = 8 # except 1d tensors
|
||||||
|
MOSTLY_Q5_1 = 9 # except 1d tensors
|
||||||
|
MOSTLY_Q2_K = 10 # except 1d tensors
|
||||||
|
MOSTLY_Q3_K_S = 11 # except 1d tensors
|
||||||
|
MOSTLY_Q3_K_M = 12 # except 1d tensors
|
||||||
|
MOSTLY_Q3_K_L = 13 # except 1d tensors
|
||||||
|
MOSTLY_Q4_K_S = 14 # except 1d tensors
|
||||||
|
MOSTLY_Q4_K_M = 15 # except 1d tensors
|
||||||
|
MOSTLY_Q5_K_S = 16 # except 1d tensors
|
||||||
|
MOSTLY_Q5_K_M = 17 # except 1d tensors
|
||||||
|
MOSTLY_Q6_K = 18 # except 1d tensors
|
||||||
|
MOSTLY_IQ2_XXS = 19 # except 1d tensors
|
||||||
|
MOSTLY_IQ2_XS = 20 # except 1d tensors
|
||||||
|
MOSTLY_Q2_K_S = 21 # except 1d tensors
|
||||||
|
MOSTLY_IQ3_XS = 22 # except 1d tensors
|
||||||
|
MOSTLY_IQ3_XXS = 23 # except 1d tensors
|
||||||
|
MOSTLY_IQ1_S = 24 # except 1d tensors
|
||||||
|
MOSTLY_IQ4_NL = 25 # except 1d tensors
|
||||||
|
MOSTLY_IQ3_S = 26 # except 1d tensors
|
||||||
|
MOSTLY_IQ3_M = 27 # except 1d tensors
|
||||||
|
MOSTLY_IQ2_S = 28 # except 1d tensors
|
||||||
|
MOSTLY_IQ2_M = 29 # except 1d tensors
|
||||||
|
MOSTLY_IQ4_XS = 30 # except 1d tensors
|
||||||
|
MOSTLY_IQ1_M = 31 # except 1d tensors
|
||||||
|
MOSTLY_BF16 = 32 # except 1d tensors
|
||||||
|
|
||||||
|
GUESSED = 1024 # not specified in the model file
|
||||||
|
|
||||||
|
|
||||||
class GGUFEndian(IntEnum):
|
class GGUFEndian(IntEnum):
|
||||||
LITTLE = 0
|
LITTLE = 0
|
||||||
BIG = 1
|
BIG = 1
|
||||||
|
@ -7,7 +7,7 @@ import struct
|
|||||||
import tempfile
|
import tempfile
|
||||||
from enum import Enum, auto
|
from enum import Enum, auto
|
||||||
from io import BufferedWriter
|
from io import BufferedWriter
|
||||||
from typing import IO, Any, Callable, Sequence, Mapping
|
from typing import IO, Any, Sequence, Mapping
|
||||||
from string import ascii_letters, digits
|
from string import ascii_letters, digits
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -28,47 +28,6 @@ from .constants import (
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class LazyTensor:
|
|
||||||
data: Callable[[], np.ndarray[Any, Any]]
|
|
||||||
# to avoid too deep recursion
|
|
||||||
functions: list[Callable[[np.ndarray[Any, Any]], np.ndarray[Any, Any]]]
|
|
||||||
dtype: np.dtype[Any]
|
|
||||||
shape: tuple[int, ...]
|
|
||||||
|
|
||||||
def __init__(self, data: Callable[[], np.ndarray[Any, Any]], *, dtype: type, shape: tuple[int, ...]):
|
|
||||||
self.data = data
|
|
||||||
self.functions = []
|
|
||||||
self.dtype = np.dtype(dtype)
|
|
||||||
self.shape = shape
|
|
||||||
|
|
||||||
def astype(self, dtype: type, **kwargs) -> LazyTensor:
|
|
||||||
self.functions.append(lambda n: n.astype(dtype, **kwargs))
|
|
||||||
self.dtype = np.dtype(dtype)
|
|
||||||
return self
|
|
||||||
|
|
||||||
@property
|
|
||||||
def nbytes(self) -> int:
|
|
||||||
size = 1
|
|
||||||
for n in self.shape:
|
|
||||||
size *= n
|
|
||||||
return size * self.dtype.itemsize
|
|
||||||
|
|
||||||
def tofile(self, *args, **kwargs) -> None:
|
|
||||||
data = self.data()
|
|
||||||
for f in self.functions:
|
|
||||||
data = f(data)
|
|
||||||
assert data.shape == self.shape
|
|
||||||
assert data.dtype == self.dtype
|
|
||||||
assert data.nbytes == self.nbytes
|
|
||||||
self.functions = []
|
|
||||||
self.data = lambda: data
|
|
||||||
data.tofile(*args, **kwargs)
|
|
||||||
|
|
||||||
def byteswap(self, *args, **kwargs) -> LazyTensor:
|
|
||||||
self.functions.append(lambda n: n.byteswap(*args, **kwargs))
|
|
||||||
return self
|
|
||||||
|
|
||||||
|
|
||||||
class WriterState(Enum):
|
class WriterState(Enum):
|
||||||
EMPTY = auto()
|
EMPTY = auto()
|
||||||
HEADER = auto()
|
HEADER = auto()
|
||||||
@ -79,7 +38,7 @@ class WriterState(Enum):
|
|||||||
class GGUFWriter:
|
class GGUFWriter:
|
||||||
fout: BufferedWriter
|
fout: BufferedWriter
|
||||||
temp_file: tempfile.SpooledTemporaryFile[bytes] | None
|
temp_file: tempfile.SpooledTemporaryFile[bytes] | None
|
||||||
tensors: list[np.ndarray[Any, Any] | LazyTensor]
|
tensors: list[np.ndarray[Any, Any]]
|
||||||
_simple_value_packing = {
|
_simple_value_packing = {
|
||||||
GGUFValueType.UINT8: "B",
|
GGUFValueType.UINT8: "B",
|
||||||
GGUFValueType.INT8: "b",
|
GGUFValueType.INT8: "b",
|
||||||
@ -278,7 +237,7 @@ class GGUFWriter:
|
|||||||
self.ti_data_count += 1
|
self.ti_data_count += 1
|
||||||
|
|
||||||
def add_tensor(
|
def add_tensor(
|
||||||
self, name: str, tensor: np.ndarray[Any, Any] | LazyTensor, raw_shape: Sequence[int] | None = None,
|
self, name: str, tensor: np.ndarray[Any, Any], raw_shape: Sequence[int] | None = None,
|
||||||
raw_dtype: GGMLQuantizationType | None = None,
|
raw_dtype: GGMLQuantizationType | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
if self.endianess == GGUFEndian.BIG:
|
if self.endianess == GGUFEndian.BIG:
|
||||||
@ -303,7 +262,7 @@ class GGUFWriter:
|
|||||||
if pad != 0:
|
if pad != 0:
|
||||||
fp.write(bytes([0] * pad))
|
fp.write(bytes([0] * pad))
|
||||||
|
|
||||||
def write_tensor_data(self, tensor: np.ndarray[Any, Any] | LazyTensor) -> None:
|
def write_tensor_data(self, tensor: np.ndarray[Any, Any]) -> None:
|
||||||
if self.state is not WriterState.TI_DATA:
|
if self.state is not WriterState.TI_DATA:
|
||||||
raise ValueError(f'Expected output file to contain tensor info, got {self.state}')
|
raise ValueError(f'Expected output file to contain tensor info, got {self.state}')
|
||||||
|
|
||||||
@ -391,7 +350,7 @@ class GGUFWriter:
|
|||||||
def add_name(self, name: str) -> None:
|
def add_name(self, name: str) -> None:
|
||||||
self.add_string(Keys.General.NAME, name)
|
self.add_string(Keys.General.NAME, name)
|
||||||
|
|
||||||
def add_quantization_version(self, quantization_version: GGMLQuantizationType) -> None:
|
def add_quantization_version(self, quantization_version: int) -> None:
|
||||||
self.add_uint32(
|
self.add_uint32(
|
||||||
Keys.General.QUANTIZATION_VERSION, quantization_version)
|
Keys.General.QUANTIZATION_VERSION, quantization_version)
|
||||||
|
|
||||||
|
225
gguf-py/gguf/lazy.py
Normal file
225
gguf-py/gguf/lazy.py
Normal file
@ -0,0 +1,225 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
from abc import ABC, ABCMeta, abstractmethod
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any, Callable
|
||||||
|
from collections import deque
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from numpy.typing import DTypeLike
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class LazyMeta(ABCMeta):
|
||||||
|
|
||||||
|
def __new__(cls, name: str, bases: tuple[type, ...], namespace: dict[str, Any], **kwargs):
|
||||||
|
def __getattr__(self, __name: str) -> Any:
|
||||||
|
meta_attr = getattr(self._meta, __name)
|
||||||
|
if callable(meta_attr):
|
||||||
|
return type(self)._wrap_fn(
|
||||||
|
(lambda s, *args, **kwargs: getattr(s, __name)(*args, **kwargs)),
|
||||||
|
use_self=self,
|
||||||
|
)
|
||||||
|
elif isinstance(meta_attr, self._tensor_type):
|
||||||
|
# e.g. self.T with torch.Tensor should still be wrapped
|
||||||
|
return type(self)._wrap_fn(lambda s: getattr(s, __name))(self)
|
||||||
|
else:
|
||||||
|
# no need to wrap non-tensor properties,
|
||||||
|
# and they likely don't depend on the actual contents of the tensor
|
||||||
|
return meta_attr
|
||||||
|
|
||||||
|
namespace["__getattr__"] = __getattr__
|
||||||
|
|
||||||
|
# need to make a builder for the wrapped wrapper to copy the name,
|
||||||
|
# or else it fails with very cryptic error messages,
|
||||||
|
# because somehow the same string would end up in every closures
|
||||||
|
def mk_wrap(op_name: str, *, meta_noop: bool = False):
|
||||||
|
# need to wrap the wrapper to get self
|
||||||
|
def wrapped_special_op(self, *args, **kwargs):
|
||||||
|
return type(self)._wrap_fn(
|
||||||
|
getattr(type(self)._tensor_type, op_name),
|
||||||
|
meta_noop=meta_noop,
|
||||||
|
)(self, *args, **kwargs)
|
||||||
|
return wrapped_special_op
|
||||||
|
|
||||||
|
# special methods bypass __getattr__, so they need to be added manually
|
||||||
|
# ref: https://docs.python.org/3/reference/datamodel.html#special-lookup
|
||||||
|
# NOTE: doing this from a metaclass is very convenient
|
||||||
|
# TODO: make this even more comprehensive
|
||||||
|
for binary_op in (
|
||||||
|
"lt", "le", "eq", "ne", "ge", "gt", "not"
|
||||||
|
"abs", "add", "and", "floordiv", "invert", "lshift", "mod", "mul", "matmul",
|
||||||
|
"neg", "or", "pos", "pow", "rshift", "sub", "truediv", "xor",
|
||||||
|
"iadd", "iand", "ifloordiv", "ilshift", "imod", "imul", "ior", "irshift", "isub", "ixor",
|
||||||
|
"radd", "rand", "rfloordiv", "rmul", "ror", "rpow", "rsub", "rtruediv", "rxor",
|
||||||
|
):
|
||||||
|
attr_name = f"__{binary_op}__"
|
||||||
|
# the result of these operators usually has the same shape and dtype as the input,
|
||||||
|
# so evaluation on the meta tensor can be skipped.
|
||||||
|
namespace[attr_name] = mk_wrap(attr_name, meta_noop=True)
|
||||||
|
|
||||||
|
for special_op in (
|
||||||
|
"getitem", "setitem", "len",
|
||||||
|
):
|
||||||
|
attr_name = f"__{special_op}__"
|
||||||
|
namespace[attr_name] = mk_wrap(attr_name, meta_noop=False)
|
||||||
|
|
||||||
|
return super().__new__(cls, name, bases, namespace, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
# Tree of lazy tensors
|
||||||
|
class LazyBase(ABC, metaclass=LazyMeta):
|
||||||
|
_tensor_type: type
|
||||||
|
_meta: Any
|
||||||
|
_data: Any | None
|
||||||
|
_lazy: deque[LazyBase] # shared within a graph, to avoid deep recursion when making eager
|
||||||
|
_args: tuple
|
||||||
|
_func: Callable[[tuple], Any] | None
|
||||||
|
|
||||||
|
def __init__(self, *, meta: Any, data: Any | None = None, lazy: deque[LazyBase] | None = None, args: tuple = (), func: Callable[[tuple], Any] | None = None):
|
||||||
|
super().__init__()
|
||||||
|
self._meta = meta
|
||||||
|
self._data = data
|
||||||
|
self._lazy = lazy if lazy is not None else deque()
|
||||||
|
self._args = args
|
||||||
|
self._func = func
|
||||||
|
assert self._func is not None or self._data is not None
|
||||||
|
if self._data is None:
|
||||||
|
self._lazy.append(self)
|
||||||
|
|
||||||
|
def __init_subclass__(cls) -> None:
|
||||||
|
if "_tensor_type" not in cls.__dict__:
|
||||||
|
raise TypeError(f"property '_tensor_type' must be defined for {cls!r}")
|
||||||
|
return super().__init_subclass__()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _recurse_apply(o: Any, fn: Callable[[Any], Any]) -> Any:
|
||||||
|
# TODO: dict and set
|
||||||
|
if isinstance(o, (list, tuple)):
|
||||||
|
L = []
|
||||||
|
for item in o:
|
||||||
|
L.append(LazyBase._recurse_apply(item, fn))
|
||||||
|
if isinstance(o, tuple):
|
||||||
|
L = tuple(L)
|
||||||
|
return L
|
||||||
|
elif isinstance(o, LazyBase):
|
||||||
|
return fn(o)
|
||||||
|
else:
|
||||||
|
return o
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _wrap_fn(cls, fn: Callable, *, use_self: LazyBase | None = None, meta_noop: bool | DTypeLike = False) -> Callable[[Any], Any]:
|
||||||
|
def wrapped_fn(*args, **kwargs):
|
||||||
|
if kwargs is None:
|
||||||
|
kwargs = {}
|
||||||
|
args = ((use_self,) if use_self is not None else ()) + args
|
||||||
|
|
||||||
|
meta_args = LazyBase._recurse_apply(args, lambda t: t._meta)
|
||||||
|
|
||||||
|
if isinstance(meta_noop, bool) and not meta_noop:
|
||||||
|
try:
|
||||||
|
res = fn(*meta_args, **kwargs)
|
||||||
|
except NotImplementedError:
|
||||||
|
# running some operations on PyTorch's Meta tensors can cause this exception
|
||||||
|
res = None
|
||||||
|
else:
|
||||||
|
# some operators don't need to actually run on the meta tensors
|
||||||
|
assert len(args) > 0
|
||||||
|
res = args[0]
|
||||||
|
assert isinstance(res, cls)
|
||||||
|
res = res._meta
|
||||||
|
# allow operations to override the dtype
|
||||||
|
if meta_noop is not True:
|
||||||
|
res = cls.meta_with_dtype(res, meta_noop)
|
||||||
|
|
||||||
|
if isinstance(res, cls._tensor_type):
|
||||||
|
def collect_replace(t: LazyBase):
|
||||||
|
if collect_replace.shared_lazy is None:
|
||||||
|
collect_replace.shared_lazy = t._lazy
|
||||||
|
else:
|
||||||
|
collect_replace.shared_lazy.extend(t._lazy)
|
||||||
|
t._lazy = collect_replace.shared_lazy
|
||||||
|
|
||||||
|
# emulating a static variable
|
||||||
|
collect_replace.shared_lazy = None
|
||||||
|
|
||||||
|
LazyBase._recurse_apply(args, collect_replace)
|
||||||
|
|
||||||
|
shared_lazy = collect_replace.shared_lazy
|
||||||
|
|
||||||
|
return cls(meta=cls.eager_to_meta(res), lazy=shared_lazy, args=args, func=lambda a: fn(*a, **kwargs))
|
||||||
|
else:
|
||||||
|
del res # not needed
|
||||||
|
# non-tensor return likely relies on the contents of the args
|
||||||
|
# (e.g. the result of torch.equal)
|
||||||
|
eager_args = cls.to_eager(args)
|
||||||
|
return fn(*eager_args, **kwargs)
|
||||||
|
return wrapped_fn
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def to_eager(cls, t: Any) -> Any:
|
||||||
|
def simple_to_eager(_t: LazyBase) -> Any:
|
||||||
|
def already_eager_to_eager(_t: LazyBase) -> Any:
|
||||||
|
assert _t._data is not None
|
||||||
|
return _t._data
|
||||||
|
|
||||||
|
while _t._data is None:
|
||||||
|
lt = _t._lazy.popleft()
|
||||||
|
if lt._data is not None:
|
||||||
|
raise ValueError(f"{lt} did not belong in the lazy queue")
|
||||||
|
assert lt._func is not None
|
||||||
|
lt._args = cls._recurse_apply(lt._args, already_eager_to_eager)
|
||||||
|
lt._data = lt._func(lt._args)
|
||||||
|
# sanity check
|
||||||
|
assert lt._data.dtype == lt._meta.dtype
|
||||||
|
assert lt._data.shape == lt._meta.shape
|
||||||
|
|
||||||
|
return _t._data
|
||||||
|
|
||||||
|
# recurse into lists and/or tuples, keeping their structure
|
||||||
|
return cls._recurse_apply(t, simple_to_eager)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def eager_to_meta(cls, t: Any) -> Any:
|
||||||
|
return cls.meta_with_dtype(t, t.dtype)
|
||||||
|
|
||||||
|
# must be overridden, meta tensor init is backend-specific
|
||||||
|
@classmethod
|
||||||
|
@abstractmethod
|
||||||
|
def meta_with_dtype(cls, m: Any, dtype: Any) -> Any: pass
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_eager(cls, t: Any) -> Any:
|
||||||
|
if type(t) is cls:
|
||||||
|
# already eager
|
||||||
|
return t
|
||||||
|
elif isinstance(t, cls._tensor_type):
|
||||||
|
return cls(meta=cls.eager_to_meta(t), data=t)
|
||||||
|
else:
|
||||||
|
return TypeError(f"{type(t)!r} is not compatible with {cls._tensor_type!r}")
|
||||||
|
|
||||||
|
|
||||||
|
class LazyNumpyTensor(LazyBase):
|
||||||
|
_tensor_type = np.ndarray
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def meta_with_dtype(cls, m: np.ndarray[Any, Any], dtype: DTypeLike) -> np.ndarray[Any, Any]:
|
||||||
|
# The initial idea was to use np.nan as the fill value,
|
||||||
|
# but non-float types like np.int16 can't use that.
|
||||||
|
# So zero it is.
|
||||||
|
cheat = np.zeros(1, dtype)
|
||||||
|
return np.lib.stride_tricks.as_strided(cheat, m.shape, (0 for _ in m.shape))
|
||||||
|
|
||||||
|
def astype(self, dtype, *args, **kwargs):
|
||||||
|
meta = type(self).meta_with_dtype(self._meta, dtype)
|
||||||
|
full_args = (self, dtype,) + args
|
||||||
|
# very important to pass the shared _lazy deque, or else there's an infinite loop somewhere.
|
||||||
|
return type(self)(meta=meta, args=full_args, lazy=self._lazy, func=(lambda a: a[0].astype(*a[1:], **kwargs)))
|
||||||
|
|
||||||
|
def tofile(self, *args, **kwargs):
|
||||||
|
eager = LazyNumpyTensor.to_eager(self)
|
||||||
|
return eager.tofile(*args, **kwargs)
|
||||||
|
|
||||||
|
# TODO: __array_function__
|
Loading…
Reference in New Issue
Block a user