s/_MODEL_CLASSES/_model_classes/

This commit is contained in:
Jared Van Bortel 2024-03-02 12:14:37 -05:00
parent 7f0a1d66b5
commit 0b673ca187

View File

@ -39,7 +39,7 @@ class SentencePieceTokenTypes(IntEnum):
AnyModel = TypeVar("AnyModel", bound="type[Model]")
class Model(ABC):
_MODEL_CLASSES: dict[str, type[Model]] = {}
_model_classes: dict[str, type[Model]] = {}
def __init__(self, dir_model: Path, ftype: int, fname_out: Path, is_big_endian: bool):
self.dir_model = dir_model
@ -189,14 +189,14 @@ class Model(ABC):
assert names
def func(modelcls: type[Model]):
for name in names:
cls._MODEL_CLASSES[name] = modelcls
cls._model_classes[name] = modelcls
return modelcls
return func
@classmethod
def from_model_architecture(cls, arch):
try:
return cls._MODEL_CLASSES[arch]
return cls._model_classes[arch]
except KeyError:
raise NotImplementedError(f'Architecture {arch!r} not supported!') from None