s/_MODEL_CLASSES/_model_classes/

This commit is contained in:
Jared Van Bortel 2024-03-02 12:14:37 -05:00
parent 7f0a1d66b5
commit 0b673ca187

View File

@ -39,7 +39,7 @@ class SentencePieceTokenTypes(IntEnum):
AnyModel = TypeVar("AnyModel", bound="type[Model]") AnyModel = TypeVar("AnyModel", bound="type[Model]")
class Model(ABC): class Model(ABC):
_MODEL_CLASSES: dict[str, type[Model]] = {} _model_classes: dict[str, type[Model]] = {}
def __init__(self, dir_model: Path, ftype: int, fname_out: Path, is_big_endian: bool): def __init__(self, dir_model: Path, ftype: int, fname_out: Path, is_big_endian: bool):
self.dir_model = dir_model self.dir_model = dir_model
@ -189,14 +189,14 @@ class Model(ABC):
assert names assert names
def func(modelcls: type[Model]): def func(modelcls: type[Model]):
for name in names: for name in names:
cls._MODEL_CLASSES[name] = modelcls cls._model_classes[name] = modelcls
return modelcls return modelcls
return func return func
@classmethod @classmethod
def from_model_architecture(cls, arch): def from_model_architecture(cls, arch):
try: try:
return cls._MODEL_CLASSES[arch] return cls._model_classes[arch]
except KeyError: except KeyError:
raise NotImplementedError(f'Architecture {arch!r} not supported!') from None raise NotImplementedError(f'Architecture {arch!r} not supported!') from None