flake : fix

This commit is contained in:
Georgi Gerganov 2024-03-04 21:50:50 +02:00
parent a1c6d96ed8
commit e0843afe1b
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

View File

@ -36,8 +36,10 @@ class SentencePieceTokenTypes(IntEnum):
UNUSED = 5
BYTE = 6
AnyModel = TypeVar("AnyModel", bound="type[Model]")
class Model(ABC):
_model_classes: dict[str, type[Model]] = {}
@ -187,6 +189,7 @@ class Model(ABC):
@classmethod
def register(cls, *names: str) -> Callable[[AnyModel], AnyModel]:
assert names
def func(modelcls: type[Model]):
for name in names:
cls._model_classes[name] = modelcls