mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-25 02:44:36 +00:00
convert_hf : faster lazy safetensors (#8482)
* convert_hf : faster lazy safetensors This makes '--dry-run' much, much faster. * convert_hf : fix memory leak in lazy MoE conversion The '_lazy' queue was sometimes self-referential, which caused reference cycles of objects old enough to avoid garbage collection until potential memory exhaustion.
This commit is contained in:
parent
97bdd26eee
commit
7acfd4e8d5
@ -148,7 +148,14 @@ class Model:
|
|||||||
tensor_names_from_parts.update(model_part.keys())
|
tensor_names_from_parts.update(model_part.keys())
|
||||||
|
|
||||||
for name in model_part.keys():
|
for name in model_part.keys():
|
||||||
data = model_part.get_tensor(name) if self.is_safetensors else model_part[name]
|
if self.is_safetensors:
|
||||||
|
if self.lazy:
|
||||||
|
data = model_part.get_slice(name)
|
||||||
|
data = LazyTorchTensor.from_safetensors_slice(data)
|
||||||
|
else:
|
||||||
|
data = model_part.get_tensor(name)
|
||||||
|
else:
|
||||||
|
data = model_part[name]
|
||||||
if self.lazy:
|
if self.lazy:
|
||||||
data = LazyTorchTensor.from_eager(data)
|
data = LazyTorchTensor.from_eager(data)
|
||||||
yield name, data
|
yield name, data
|
||||||
@ -3424,19 +3431,46 @@ class LazyTorchTensor(gguf.LazyBase):
|
|||||||
torch.float32: np.float32,
|
torch.float32: np.float32,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# used for safetensors slices
|
||||||
|
# ref: https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/src/lib.rs#L1046
|
||||||
|
# TODO: uncomment U64, U32, and U16, ref: https://github.com/pytorch/pytorch/issues/58734
|
||||||
|
_dtype_str_map: dict[str, torch.dtype] = {
|
||||||
|
"F64": torch.float64,
|
||||||
|
"F32": torch.float32,
|
||||||
|
"BF16": torch.bfloat16,
|
||||||
|
"F16": torch.float16,
|
||||||
|
# "U64": torch.uint64,
|
||||||
|
"I64": torch.int64,
|
||||||
|
# "U32": torch.uint32,
|
||||||
|
"I32": torch.int32,
|
||||||
|
# "U16": torch.uint16,
|
||||||
|
"I16": torch.int16,
|
||||||
|
"U8": torch.uint8,
|
||||||
|
"I8": torch.int8,
|
||||||
|
"BOOL": torch.bool,
|
||||||
|
"F8_E4M3": torch.float8_e4m3fn,
|
||||||
|
"F8_E5M2": torch.float8_e5m2,
|
||||||
|
}
|
||||||
|
|
||||||
def numpy(self) -> gguf.LazyNumpyTensor:
|
def numpy(self) -> gguf.LazyNumpyTensor:
|
||||||
dtype = self._dtype_map[self.dtype]
|
dtype = self._dtype_map[self.dtype]
|
||||||
return gguf.LazyNumpyTensor(
|
return gguf.LazyNumpyTensor(
|
||||||
meta=gguf.LazyNumpyTensor.meta_with_dtype_and_shape(dtype, self.shape),
|
meta=gguf.LazyNumpyTensor.meta_with_dtype_and_shape(dtype, self.shape),
|
||||||
lazy=self._lazy,
|
|
||||||
args=(self,),
|
args=(self,),
|
||||||
func=(lambda s: s[0].numpy())
|
func=(lambda s: s.numpy())
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def meta_with_dtype_and_shape(cls, dtype: torch.dtype, shape: torch.Size) -> Tensor:
|
def meta_with_dtype_and_shape(cls, dtype: torch.dtype, shape: tuple[int, ...]) -> Tensor:
|
||||||
return torch.empty(size=shape, dtype=dtype, device="meta")
|
return torch.empty(size=shape, dtype=dtype, device="meta")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_safetensors_slice(cls, st_slice: Any) -> Tensor:
|
||||||
|
dtype = cls._dtype_str_map[st_slice.get_dtype()]
|
||||||
|
shape: tuple[int, ...] = tuple(st_slice.get_shape())
|
||||||
|
lazy = cls(meta=cls.meta_with_dtype_and_shape(dtype, shape), args=(st_slice,), func=lambda s: s[:])
|
||||||
|
return cast(torch.Tensor, lazy)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def __torch_function__(cls, func, types, args=(), kwargs=None):
|
def __torch_function__(cls, func, types, args=(), kwargs=None):
|
||||||
del types # unused
|
del types # unused
|
||||||
@ -3447,7 +3481,7 @@ class LazyTorchTensor(gguf.LazyBase):
|
|||||||
if func is torch.Tensor.numpy:
|
if func is torch.Tensor.numpy:
|
||||||
return args[0].numpy()
|
return args[0].numpy()
|
||||||
|
|
||||||
return LazyTorchTensor._wrap_fn(func)(*args, **kwargs)
|
return cls._wrap_fn(func)(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def parse_args() -> argparse.Namespace:
|
def parse_args() -> argparse.Namespace:
|
||||||
|
@ -3,7 +3,6 @@ from abc import ABC, ABCMeta, abstractmethod
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Callable
|
from typing import Any, Callable
|
||||||
from collections import deque
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from numpy.typing import DTypeLike
|
from numpy.typing import DTypeLike
|
||||||
@ -74,20 +73,18 @@ class LazyBase(ABC, metaclass=LazyMeta):
|
|||||||
_tensor_type: type
|
_tensor_type: type
|
||||||
_meta: Any
|
_meta: Any
|
||||||
_data: Any | None
|
_data: Any | None
|
||||||
_lazy: deque[LazyBase] # shared within a graph, to avoid deep recursion when making eager
|
|
||||||
_args: tuple
|
_args: tuple
|
||||||
_func: Callable[[tuple], Any] | None
|
_kwargs: dict[str, Any]
|
||||||
|
_func: Callable[[Any], Any] | None
|
||||||
|
|
||||||
def __init__(self, *, meta: Any, data: Any | None = None, lazy: deque[LazyBase] | None = None, args: tuple = (), func: Callable[[tuple], Any] | None = None):
|
def __init__(self, *, meta: Any, data: Any | None = None, args: tuple = (), kwargs: dict[str, Any] | None = None, func: Callable[[Any], Any] | None = None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self._meta = meta
|
self._meta = meta
|
||||||
self._data = data
|
self._data = data
|
||||||
self._lazy = lazy if lazy is not None else deque()
|
|
||||||
self._args = args
|
self._args = args
|
||||||
|
self._kwargs = kwargs if kwargs is not None else {}
|
||||||
self._func = func
|
self._func = func
|
||||||
assert self._func is not None or self._data is not None
|
assert self._func is not None or self._data is not None
|
||||||
if self._data is None:
|
|
||||||
self._lazy.append(self)
|
|
||||||
|
|
||||||
def __init_subclass__(cls) -> None:
|
def __init_subclass__(cls) -> None:
|
||||||
if "_tensor_type" not in cls.__dict__:
|
if "_tensor_type" not in cls.__dict__:
|
||||||
@ -117,6 +114,7 @@ class LazyBase(ABC, metaclass=LazyMeta):
|
|||||||
args = ((use_self,) if use_self is not None else ()) + args
|
args = ((use_self,) if use_self is not None else ()) + args
|
||||||
|
|
||||||
meta_args = LazyBase._recurse_apply(args, lambda t: t._meta)
|
meta_args = LazyBase._recurse_apply(args, lambda t: t._meta)
|
||||||
|
# TODO: maybe handle tensors in kwargs too
|
||||||
|
|
||||||
if isinstance(meta_noop, bool) and not meta_noop:
|
if isinstance(meta_noop, bool) and not meta_noop:
|
||||||
try:
|
try:
|
||||||
@ -140,23 +138,7 @@ class LazyBase(ABC, metaclass=LazyMeta):
|
|||||||
res = cls.meta_with_dtype_and_shape(meta_noop, res.shape)
|
res = cls.meta_with_dtype_and_shape(meta_noop, res.shape)
|
||||||
|
|
||||||
if isinstance(res, cls._tensor_type):
|
if isinstance(res, cls._tensor_type):
|
||||||
class CollectSharedLazy:
|
return cls(meta=cls.eager_to_meta(res), args=args, kwargs=kwargs, func=fn)
|
||||||
# emulating a static variable
|
|
||||||
shared_lazy: None | deque[LazyBase] = None
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def collect_replace(t: LazyBase):
|
|
||||||
if CollectSharedLazy.shared_lazy is None:
|
|
||||||
CollectSharedLazy.shared_lazy = t._lazy
|
|
||||||
else:
|
|
||||||
CollectSharedLazy.shared_lazy.extend(t._lazy)
|
|
||||||
t._lazy = CollectSharedLazy.shared_lazy
|
|
||||||
|
|
||||||
LazyBase._recurse_apply(args, CollectSharedLazy.collect_replace)
|
|
||||||
|
|
||||||
shared_lazy = CollectSharedLazy.shared_lazy
|
|
||||||
|
|
||||||
return cls(meta=cls.eager_to_meta(res), lazy=shared_lazy, args=args, func=lambda a: fn(*a, **kwargs))
|
|
||||||
else:
|
else:
|
||||||
del res # not needed
|
del res # not needed
|
||||||
# non-tensor return likely relies on the contents of the args
|
# non-tensor return likely relies on the contents of the args
|
||||||
@ -168,26 +150,18 @@ class LazyBase(ABC, metaclass=LazyMeta):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def to_eager(cls, t: Any) -> Any:
|
def to_eager(cls, t: Any) -> Any:
|
||||||
def simple_to_eager(_t: LazyBase) -> Any:
|
def simple_to_eager(_t: LazyBase) -> Any:
|
||||||
def already_eager_to_eager(_t: LazyBase) -> Any:
|
if _t._data is not None:
|
||||||
assert _t._data is not None
|
|
||||||
return _t._data
|
return _t._data
|
||||||
|
|
||||||
while _t._data is None:
|
# NOTE: there's a recursion limit in Python (usually 1000)
|
||||||
lt = _t._lazy.popleft()
|
|
||||||
if lt._data is not None:
|
assert _t._func is not None
|
||||||
# Lazy tensor did not belong in the lazy queue.
|
_t._args = cls._recurse_apply(_t._args, simple_to_eager)
|
||||||
# Weirdly only happens with Bloom models...
|
_t._data = _t._func(*_t._args, **_t._kwargs)
|
||||||
# likely because tensors aren't unique in the queue.
|
|
||||||
# The final output is still the same as in eager mode,
|
|
||||||
# so it's safe to ignore this.
|
|
||||||
continue
|
|
||||||
assert lt._func is not None
|
|
||||||
lt._args = cls._recurse_apply(lt._args, already_eager_to_eager)
|
|
||||||
lt._data = lt._func(lt._args)
|
|
||||||
# sanity check
|
# sanity check
|
||||||
assert lt._data is not None
|
assert _t._data is not None
|
||||||
assert lt._data.dtype == lt._meta.dtype
|
assert _t._data.dtype == _t._meta.dtype
|
||||||
assert lt._data.shape == lt._meta.shape
|
assert _t._data.shape == _t._meta.shape
|
||||||
|
|
||||||
return _t._data
|
return _t._data
|
||||||
|
|
||||||
@ -206,7 +180,7 @@ class LazyBase(ABC, metaclass=LazyMeta):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def from_eager(cls, t: Any) -> Any:
|
def from_eager(cls, t: Any) -> Any:
|
||||||
if type(t) is cls:
|
if type(t) is cls:
|
||||||
# already eager
|
# already lazy
|
||||||
return t
|
return t
|
||||||
elif isinstance(t, cls._tensor_type):
|
elif isinstance(t, cls._tensor_type):
|
||||||
return cls(meta=cls.eager_to_meta(t), data=t)
|
return cls(meta=cls.eager_to_meta(t), data=t)
|
||||||
@ -228,8 +202,7 @@ class LazyNumpyTensor(LazyBase):
|
|||||||
def astype(self, dtype, *args, **kwargs):
|
def astype(self, dtype, *args, **kwargs):
|
||||||
meta = type(self).meta_with_dtype_and_shape(dtype, self._meta.shape)
|
meta = type(self).meta_with_dtype_and_shape(dtype, self._meta.shape)
|
||||||
full_args = (self, dtype,) + args
|
full_args = (self, dtype,) + args
|
||||||
# very important to pass the shared _lazy deque, or else there's an infinite loop somewhere.
|
return type(self)(meta=meta, args=full_args, kwargs=kwargs, func=(lambda a, *args, **kwargs: a.astype(*args, **kwargs)))
|
||||||
return type(self)(meta=meta, args=full_args, lazy=self._lazy, func=(lambda a: a[0].astype(*a[1:], **kwargs)))
|
|
||||||
|
|
||||||
def tofile(self, *args, **kwargs):
|
def tofile(self, *args, **kwargs):
|
||||||
eager = LazyNumpyTensor.to_eager(self)
|
eager = LazyNumpyTensor.to_eager(self)
|
||||||
|
@ -602,13 +602,11 @@ class TensorNameMap:
|
|||||||
for tensor, keys in self.block_mappings_cfg.items():
|
for tensor, keys in self.block_mappings_cfg.items():
|
||||||
if tensor not in MODEL_TENSORS[arch]:
|
if tensor not in MODEL_TENSORS[arch]:
|
||||||
continue
|
continue
|
||||||
# TODO: make this configurable
|
|
||||||
n_experts = 160
|
tensor_name = TENSOR_NAMES[tensor].format(bid = bid)
|
||||||
for xid in range(n_experts):
|
|
||||||
tensor_name = TENSOR_NAMES[tensor].format(bid = bid, xid = xid)
|
|
||||||
self.mapping[tensor_name] = (tensor, tensor_name)
|
self.mapping[tensor_name] = (tensor, tensor_name)
|
||||||
for key in keys:
|
for key in keys:
|
||||||
key = key.format(bid = bid, xid = xid)
|
key = key.format(bid = bid)
|
||||||
self.mapping[key] = (tensor, tensor_name)
|
self.mapping[key] = (tensor, tensor_name)
|
||||||
|
|
||||||
def get_type_and_name(self, key: str, try_suffixes: Sequence[str] = ()) -> tuple[MODEL_TENSOR, str] | None:
|
def get_type_and_name(self, key: str, try_suffixes: Sequence[str] = ()) -> tuple[MODEL_TENSOR, str] | None:
|
||||||
|
Loading…
Reference in New Issue
Block a user