convert_hf : faster lazy safetensors (#8482)

* convert_hf : faster lazy safetensors

This makes '--dry-run' much, much faster.

* convert_hf : fix memory leak in lazy MoE conversion

The '_lazy' queue was sometimes self-referential,
which created reference cycles among objects old enough
to escape timely garbage collection, potentially until memory was exhausted.
compilade, 2024-07-15 23:13:10 -04:00 (committed by GitHub)
commit 7acfd4e8d5, parent 97bdd26eee

3 changed files with 65 additions and 60 deletions
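
For context on the leak fixed here: a shared deque that ends up holding the very objects that reference it forms cycles which only Python's cyclic collector can reclaim, and long-lived objects are scanned rarely. A minimal illustrative sketch (not code from this repository):

```python
import gc
from collections import deque

# Toy version of the old pattern: every lazy node appends itself to a
# deque shared across the whole graph.
class Node:
    def __init__(self, shared: deque):
        self._lazy = shared      # the node references the deque...
        self._lazy.append(self)  # ...and the deque references the node

queue: deque = deque()
nodes = [Node(queue) for _ in range(3)]
del nodes, queue
# The nodes and the deque are now unreachable but keep each other alive;
# only the cyclic garbage collector can free them, and objects promoted
# to older GC generations are scanned infrequently.
print(gc.collect())  # > 0: unreachable cycle participants were collected
```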

convert_hf_to_gguf.py

@@ -148,9 +148,16 @@ class Model:
                 tensor_names_from_parts.update(model_part.keys())

                 for name in model_part.keys():
-                    data = model_part.get_tensor(name) if self.is_safetensors else model_part[name]
-                    if self.lazy:
-                        data = LazyTorchTensor.from_eager(data)
+                    if self.is_safetensors:
+                        if self.lazy:
+                            data = model_part.get_slice(name)
+                            data = LazyTorchTensor.from_safetensors_slice(data)
+                        else:
+                            data = model_part.get_tensor(name)
+                    else:
+                        data = model_part[name]
+                        if self.lazy:
+                            data = LazyTorchTensor.from_eager(data)
                     yield name, data

         # only verify tensor name presence; it doesn't matter if they are not in the right files
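
The new safetensors branch relies on `get_slice()`, which returns a handle exposing dtype and shape without reading the tensor's bytes; data is only read when the slice is indexed. A minimal sketch of that API, assuming a local `model.safetensors` with a tensor named `weight` (hypothetical names):

```python
from safetensors import safe_open

with safe_open("model.safetensors", framework="pt") as f:
    sl = f.get_slice("weight")
    print(sl.get_dtype())   # e.g. "F32" -- a string, hence _dtype_str_map below
    print(sl.get_shape())   # e.g. [4096, 4096], metadata only
    tensor = sl[:]          # bytes are read here, on demand
```

This is why `--dry-run` gets so much faster: the metadata alone is enough to plan the conversion, so no tensor data needs to be loaded.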
@@ -3424,19 +3431,46 @@ class LazyTorchTensor(gguf.LazyBase):
         torch.float32: np.float32,
     }

+    # used for safetensors slices
+    # ref: https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/src/lib.rs#L1046
+    # TODO: uncomment U64, U32, and U16, ref: https://github.com/pytorch/pytorch/issues/58734
+    _dtype_str_map: dict[str, torch.dtype] = {
+        "F64": torch.float64,
+        "F32": torch.float32,
+        "BF16": torch.bfloat16,
+        "F16": torch.float16,
+        # "U64": torch.uint64,
+        "I64": torch.int64,
+        # "U32": torch.uint32,
+        "I32": torch.int32,
+        # "U16": torch.uint16,
+        "I16": torch.int16,
+        "U8": torch.uint8,
+        "I8": torch.int8,
+        "BOOL": torch.bool,
+        "F8_E4M3": torch.float8_e4m3fn,
+        "F8_E5M2": torch.float8_e5m2,
+    }
+
     def numpy(self) -> gguf.LazyNumpyTensor:
         dtype = self._dtype_map[self.dtype]
         return gguf.LazyNumpyTensor(
             meta=gguf.LazyNumpyTensor.meta_with_dtype_and_shape(dtype, self.shape),
-            lazy=self._lazy,
             args=(self,),
-            func=(lambda s: s[0].numpy())
+            func=(lambda s: s.numpy())
         )

     @classmethod
-    def meta_with_dtype_and_shape(cls, dtype: torch.dtype, shape: torch.Size) -> Tensor:
+    def meta_with_dtype_and_shape(cls, dtype: torch.dtype, shape: tuple[int, ...]) -> Tensor:
         return torch.empty(size=shape, dtype=dtype, device="meta")

+    @classmethod
+    def from_safetensors_slice(cls, st_slice: Any) -> Tensor:
+        dtype = cls._dtype_str_map[st_slice.get_dtype()]
+        shape: tuple[int, ...] = tuple(st_slice.get_shape())
+        lazy = cls(meta=cls.meta_with_dtype_and_shape(dtype, shape), args=(st_slice,), func=lambda s: s[:])
+        return cast(torch.Tensor, lazy)
+
     @classmethod
     def __torch_function__(cls, func, types, args=(), kwargs=None):
         del types  # unused
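
`meta_with_dtype_and_shape` builds on PyTorch's `meta` device: such tensors carry shape and dtype but allocate no storage, so shapes can be traced without ever touching the weights. A small demonstration of the underlying PyTorch behavior:

```python
import torch

# A meta tensor has metadata but no backing storage.
t = torch.empty(size=(4096, 11008), dtype=torch.bfloat16, device="meta")
print(t.shape, t.dtype)   # torch.Size([4096, 11008]) torch.bfloat16
print(t.t().shape)        # ops still propagate shapes: [11008, 4096]
# t.tolist() would fail: there is no data to read.
```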
@@ -3447,7 +3481,7 @@ class LazyTorchTensor(gguf.LazyBase):
         if func is torch.Tensor.numpy:
             return args[0].numpy()

-        return LazyTorchTensor._wrap_fn(func)(*args, **kwargs)
+        return cls._wrap_fn(func)(*args, **kwargs)


 def parse_args() -> argparse.Namespace:
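
`__torch_function__` is PyTorch's standard override hook; the converter uses it to reroute every torch operation on a lazy tensor through `_wrap_fn`. A stripped-down sketch of the interception mechanism (illustrative, not the converter's class):

```python
import torch

class TracingTensor(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs if kwargs is not None else {}
        print(f"intercepted: {getattr(func, '__name__', func)}")
        # fall back to the default dispatch
        return super().__torch_function__(func, types, args, kwargs)

x = torch.ones(2).as_subclass(TracingTensor)
y = x + 1  # the addition is reported before it is computed
```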

gguf-py/gguf/lazy.py

@@ -3,7 +3,6 @@ from abc import ABC, ABCMeta, abstractmethod
 import logging
 from typing import Any, Callable
-from collections import deque

 import numpy as np
 from numpy.typing import DTypeLike
@@ -74,20 +73,18 @@ class LazyBase(ABC, metaclass=LazyMeta):
     _tensor_type: type
     _meta: Any
     _data: Any | None
-    _lazy: deque[LazyBase]  # shared within a graph, to avoid deep recursion when making eager
     _args: tuple
-    _func: Callable[[tuple], Any] | None
+    _kwargs: dict[str, Any]
+    _func: Callable[[Any], Any] | None

-    def __init__(self, *, meta: Any, data: Any | None = None, lazy: deque[LazyBase] | None = None, args: tuple = (), func: Callable[[tuple], Any] | None = None):
+    def __init__(self, *, meta: Any, data: Any | None = None, args: tuple = (), kwargs: dict[str, Any] | None = None, func: Callable[[Any], Any] | None = None):
         super().__init__()
         self._meta = meta
         self._data = data
-        self._lazy = lazy if lazy is not None else deque()
         self._args = args
+        self._kwargs = kwargs if kwargs is not None else {}
         self._func = func
         assert self._func is not None or self._data is not None
-        if self._data is None:
-            self._lazy.append(self)

     def __init_subclass__(cls) -> None:
         if "_tensor_type" not in cls.__dict__:
@@ -117,6 +114,7 @@ class LazyBase(ABC, metaclass=LazyMeta):
             args = ((use_self,) if use_self is not None else ()) + args

             meta_args = LazyBase._recurse_apply(args, lambda t: t._meta)
+            # TODO: maybe handle tensors in kwargs too

             if isinstance(meta_noop, bool) and not meta_noop:
                 try:
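
`_recurse_apply` (used above to substitute each lazy tensor with its `meta` stand-in) walks nested tuples and lists and transforms only the matching leaves. A simplified sketch of that helper's behavior, with the leaf test passed in as a predicate:

```python
# Simplified sketch, not the exact gguf helper.
def recurse_apply(obj, fn, is_leaf):
    if isinstance(obj, (list, tuple)):
        return type(obj)(recurse_apply(x, fn, is_leaf) for x in obj)
    return fn(obj) if is_leaf(obj) else obj

args = (1, [2, (3, 4)], "x")
print(recurse_apply(args, lambda n: n * 10, lambda o: isinstance(o, int)))
# -> (10, [20, (30, 40)], 'x')
```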
@@ -140,23 +138,7 @@ class LazyBase(ABC, metaclass=LazyMeta):
                     res = cls.meta_with_dtype_and_shape(meta_noop, res.shape)

             if isinstance(res, cls._tensor_type):
-                class CollectSharedLazy:
-                    # emulating a static variable
-                    shared_lazy: None | deque[LazyBase] = None
-
-                    @staticmethod
-                    def collect_replace(t: LazyBase):
-                        if CollectSharedLazy.shared_lazy is None:
-                            CollectSharedLazy.shared_lazy = t._lazy
-                        else:
-                            CollectSharedLazy.shared_lazy.extend(t._lazy)
-                            t._lazy = CollectSharedLazy.shared_lazy
-
-                LazyBase._recurse_apply(args, CollectSharedLazy.collect_replace)
-
-                shared_lazy = CollectSharedLazy.shared_lazy
-
-                return cls(meta=cls.eager_to_meta(res), lazy=shared_lazy, args=args, func=lambda a: fn(*a, **kwargs))
+                return cls(meta=cls.eager_to_meta(res), args=args, kwargs=kwargs, func=fn)
             else:
                 del res  # not needed
                 # non-tensor return likely relies on the contents of the args
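
The pattern in `_wrap_fn`: run the wrapped function once on `meta` stand-ins to learn the result's shape and dtype cheaply, then store the real call (now with its `kwargs`) in a new lazy node. A compact sketch of that two-step idea using meta tensors:

```python
import torch

# Illustrative two-step wrap: metadata now, computation later.
def wrap_lazy(fn, *args):
    meta_args = tuple(torch.empty_like(a, device="meta") for a in args)
    meta_out = fn(*meta_args)       # shape/dtype inference, no real work
    thunk = lambda: fn(*args)       # the deferred real computation
    return meta_out, thunk

x = torch.ones(8, 4)
meta, thunk = wrap_lazy(lambda t: t.transpose(0, 1), x)
print(meta.shape)     # torch.Size([4, 8]), known without computing
print(thunk().shape)  # torch.Size([4, 8]), computed on demand
```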
@@ -168,26 +150,18 @@ class LazyBase(ABC, metaclass=LazyMeta):
     @classmethod
     def to_eager(cls, t: Any) -> Any:
         def simple_to_eager(_t: LazyBase) -> Any:
-            def already_eager_to_eager(_t: LazyBase) -> Any:
-                assert _t._data is not None
+            if _t._data is not None:
                 return _t._data

-            while _t._data is None:
-                lt = _t._lazy.popleft()
-                if lt._data is not None:
-                    # Lazy tensor did not belong in the lazy queue.
-                    # Weirdly only happens with Bloom models...
-                    # likely because tensors aren't unique in the queue.
-                    # The final output is still the same as in eager mode,
-                    # so it's safe to ignore this.
-                    continue
-                assert lt._func is not None
-                lt._args = cls._recurse_apply(lt._args, already_eager_to_eager)
-                lt._data = lt._func(lt._args)
-                # sanity check
-                assert lt._data is not None
-                assert lt._data.dtype == lt._meta.dtype
-                assert lt._data.shape == lt._meta.shape
+            # NOTE: there's a recursion limit in Python (usually 1000)
+
+            assert _t._func is not None
+            _t._args = cls._recurse_apply(_t._args, simple_to_eager)
+            _t._data = _t._func(*_t._args, **_t._kwargs)
+            # sanity check
+            assert _t._data is not None
+            assert _t._data.dtype == _t._meta.dtype
+            assert _t._data.shape == _t._meta.shape

             return _t._data
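
Materialization is now plain recursion: make the dependencies eager, then apply the stored function. The NOTE about Python's recursion limit matters for very deep graphs. A self-contained toy of the same shape:

```python
import sys

# Toy deferred graph: a node is (func, args); anything else is eager.
def to_eager(node):
    if not isinstance(node, tuple):
        return node
    func, args = node
    eager_args = []
    for a in args:              # recurse into dependencies first
        eager_args.append(to_eager(a))
    return func(*eager_args)

node = 1
for _ in range(300):            # a chain of 300 deferred increments
    node = (lambda x: x + 1, (node,))
print(to_eager(node))           # 301, all computed just now
# Each chain link costs one stack frame, so depth is bounded by
# sys.getrecursionlimit() (usually 1000), as the NOTE above warns.
```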
@@ -206,7 +180,7 @@ class LazyBase(ABC, metaclass=LazyMeta):
     @classmethod
     def from_eager(cls, t: Any) -> Any:
         if type(t) is cls:
-            # already eager
+            # already lazy
             return t
         elif isinstance(t, cls._tensor_type):
             return cls(meta=cls.eager_to_meta(t), data=t)
@@ -228,8 +202,7 @@ class LazyNumpyTensor(LazyBase):
     def astype(self, dtype, *args, **kwargs):
         meta = type(self).meta_with_dtype_and_shape(dtype, self._meta.shape)
         full_args = (self, dtype,) + args
-        # very important to pass the shared _lazy deque, or else there's an infinite loop somewhere.
-        return type(self)(meta=meta, args=full_args, lazy=self._lazy, func=(lambda a: a[0].astype(*a[1:], **kwargs)))
+        return type(self)(meta=meta, args=full_args, kwargs=kwargs, func=(lambda a, *args, **kwargs: a.astype(*args, **kwargs)))

     def tofile(self, *args, **kwargs):
         eager = LazyNumpyTensor.to_eager(self)
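
The rewritten `astype` keeps the target dtype in `args` and lets the new `func` signature unpack them directly, instead of indexing into a packed tuple as before. The calling convention, sketched with plain NumPy:

```python
import numpy as np

# The deferred-call shape used above: func receives unpacked args/kwargs.
func = lambda a, *args, **kwargs: a.astype(*args, **kwargs)
full_args = (np.ones(3, dtype=np.float32), np.float16)

out = func(*full_args)  # the conversion runs here, not at definition time
print(out.dtype)        # float16
```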

gguf-py/gguf/tensor_mapping.py

@@ -602,14 +602,12 @@ class TensorNameMap:
         for tensor, keys in self.block_mappings_cfg.items():
             if tensor not in MODEL_TENSORS[arch]:
                 continue
-            # TODO: make this configurable
-            n_experts = 160
-            for xid in range(n_experts):
-                tensor_name = TENSOR_NAMES[tensor].format(bid = bid, xid = xid)
-                self.mapping[tensor_name] = (tensor, tensor_name)
-                for key in keys:
-                    key = key.format(bid = bid, xid = xid)
-                    self.mapping[key] = (tensor, tensor_name)
+
+            tensor_name = TENSOR_NAMES[tensor].format(bid = bid)
+            self.mapping[tensor_name] = (tensor, tensor_name)
+            for key in keys:
+                key = key.format(bid = bid)
+                self.mapping[key] = (tensor, tensor_name)

     def get_type_and_name(self, key: str, try_suffixes: Sequence[str] = ()) -> tuple[MODEL_TENSOR, str] | None:
         result = self.mapping.get(key)
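
With the hard-coded `n_experts = 160` loop removed, each block tensor maps a single GGUF name to all of its candidate source names; experts are no longer expanded per `xid`. A condensed sketch of the mapping being built, with illustrative names:

```python
# Illustrative names, mirroring the loop above for one block id.
TENSOR_NAMES = {"ATTN_Q": "blk.{bid}.attn_q"}
block_mappings_cfg = {"ATTN_Q": ("model.layers.{bid}.self_attn.q_proj",)}

mapping: dict[str, tuple[str, str]] = {}
bid = 2
for tensor, keys in block_mappings_cfg.items():
    tensor_name = TENSOR_NAMES[tensor].format(bid=bid)
    mapping[tensor_name] = (tensor, tensor_name)
    for key in keys:
        mapping[key.format(bid=bid)] = (tensor, tensor_name)

# Both the GGUF name and the source name resolve to the same target:
print(mapping["model.layers.2.self_attn.q_proj"])  # ('ATTN_Q', 'blk.2.attn_q')
```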