refactor: Further refine functionality, improve user interaction, and streamline vocabulary handling

- Renamed command-line arguments for clarity and consistency.
- Improved path resolution and import adjustments for robustness.
- Thoughtfully handled 'awq-path' and conditional logic for the weighted model.
- Enhanced model and vocabulary loading with the 'VocabFactory' class for structured and adaptable loading.
- Strengthened error handling and user feedback for a more user-friendly experience.
- Structured output file handling with clear conditions and defaults.
- Streamlined and organized the 'main' function for better logic flow.
- Passed 'sys.argv[1:]' to 'main' for adaptability and testability.

These changes solidify the script's functionality, making it more robust, user-friendly, and adaptable. The use of the 'VocabFactory' class is a notable enhancement in efficient vocabulary handling, reflecting a thoughtful and iterative approach to script development.
This commit is contained in:
teleprint-me 2024-01-07 21:54:42 -05:00
parent 226cea270e
commit 0614c338f8
No known key found for this signature in database
GPG Key ID: B0D11345E65C4D48

View File

@ -1555,8 +1555,9 @@ def main(argv: Optional[list[str]] = None) -> None:
args = parser.parse_args(argv)
if args.awq_path:
sys.path.insert(1, str(Path(__file__).parent / 'awq-py'))
sys.path.insert(1, str(Path(__file__).resolve().parent / "awq-py"))
from awq.apply_awq import add_scale_weights
tmp_model_path = args.model / "weighted_model"
if tmp_model_path.is_dir():
print(f"{tmp_model_path} exists as a weighted model.")
@ -1575,74 +1576,83 @@ def main(argv: Optional[list[str]] = None) -> None:
if not args.vocab_only:
model_plus = load_some_model(args.model)
else:
model_plus = ModelPlus(model = {}, paths = [args.model / 'dummy'], format = 'none', vocab = None)
model_plus = ModelPlus(
model={}, paths=[args.model / "dummy"], format="none", vocab=None
)
if args.dump:
do_dump_model(model_plus)
return
endianess = gguf.GGUFEndian.LITTLE
if args.bigendian:
if args.big_endian:
endianess = gguf.GGUFEndian.BIG
params = Params.load(model_plus)
if params.n_ctx == -1:
if args.ctx is None:
raise Exception("The model doesn't have a context size, and you didn't specify one with --ctx\n"
"Please specify one with --ctx:\n"
" - LLaMA v1: --ctx 2048\n"
" - LLaMA v2: --ctx 4096\n")
raise Exception(
"The model doesn't have a context size, and you didn't specify one with --ctx\n"
"Please specify one with --ctx:\n"
" - LLaMA v1: --ctx 2048\n"
" - LLaMA v2: --ctx 4096\n"
)
params.n_ctx = args.ctx
if args.outtype:
if args.out_type:
params.ftype = {
"f32": GGMLFileType.AllF32,
"f16": GGMLFileType.MostlyF16,
"q8_0": GGMLFileType.MostlyQ8_0,
}[args.outtype]
}[args.out_type]
print(f"params = {params}")
vocab: Vocab
model_parent_path = model_plus.paths[0].parent
vocab_path = Path(args.vocab_dir or args.model or model_parent_path)
vocab_factory = VocabFactory(vocab_path)
vocab, special_vocab = vocab_factory.load_vocab(args.vocab_type, model_parent_path)
if args.vocab_only:
if not args.outfile:
raise ValueError("need --outfile if using --vocab-only")
# FIXME: Try to respect vocab_dir somehow?
vocab = VocabLoader(params, args.vocab_dir or args.model)
special_vocab = gguf.SpecialVocab(model_plus.paths[0].parent,
load_merges = True,
n_vocab = vocab.vocab_size)
outfile = args.outfile
OutputFile.write_vocab_only(outfile, params, vocab, special_vocab,
endianess = endianess, pad_vocab = args.padvocab)
print(f"Wrote {outfile}")
if not args.out_file:
raise ValueError("need --out-file if using --vocab-only")
out_file = args.out_file
OutputFile.write_vocab_only(
out_file,
params,
vocab,
special_vocab,
endianess=endianess,
pad_vocab=args.pad_vocab,
)
print(f"Wrote {out_file}")
return
if model_plus.vocab is not None and args.vocab_dir is None:
vocab = model_plus.vocab
else:
vocab_dir = args.vocab_dir if args.vocab_dir else model_plus.paths[0].parent
vocab = VocabLoader(params, vocab_dir)
# FIXME: Try to respect vocab_dir somehow?
print(f"Vocab info: {vocab}")
special_vocab = gguf.SpecialVocab(model_plus.paths[0].parent,
load_merges = True,
n_vocab = vocab.vocab_size)
print(f"Special vocab info: {special_vocab}")
model = model_plus.model
model = convert_model_names(model, params)
ftype = pick_output_type(model, args.outtype)
model = convert_to_output_type(model, ftype)
outfile = args.outfile or default_outfile(model_plus.paths, ftype)
model = model_plus.model
model = convert_model_names(model, params)
ftype = pick_output_type(model, args.out_type)
model = convert_to_output_type(model, ftype)
out_file = args.out_file or default_output_file(model_plus.paths, ftype)
params.ftype = ftype
print(f"Writing {outfile}, format {ftype}")
print(f"Writing {out_file}, format {ftype}")
OutputFile.write_all(outfile, ftype, params, model, vocab, special_vocab,
concurrency = args.concurrency, endianess = endianess, pad_vocab = args.padvocab)
print(f"Wrote {outfile}")
OutputFile.write_all(
out_file,
ftype,
params,
model,
vocab,
special_vocab,
concurrency=args.concurrency,
endianess=endianess,
pad_vocab=args.pad_vocab,
)
print(f"Wrote {out_file}")
if __name__ == '__main__':
main()
if __name__ == "__main__":
main(sys.argv[1:]) # Exclude the first element (script name) from sys.argv