This commit is contained in:
FSSRepo 2023-10-12 12:55:08 -04:00
commit b716eeb72a
37 changed files with 13747 additions and 2584 deletions

2
.gitignore vendored
View File

@ -45,6 +45,7 @@ models-mnt
/infill /infill
/libllama.so /libllama.so
/llama-bench /llama-bench
/llava
/main /main
/metal /metal
/perplexity /perplexity
@ -56,6 +57,7 @@ models-mnt
/server /server
/simple /simple
/batched /batched
/batched-bench
/export-lora /export-lora
/finetune /finetune
/speculative /speculative

View File

@ -422,8 +422,7 @@ endif()
if (LLAMA_ALL_WARNINGS) if (LLAMA_ALL_WARNINGS)
if (NOT MSVC) if (NOT MSVC)
set(warning_flags -Wall -Wextra -Wpedantic -Wcast-qual -Wno-unused-function) set(warning_flags -Wall -Wextra -Wpedantic -Wcast-qual -Wno-unused-function)
set(c_flags -Wshadow -Wstrict-prototypes -Wpointer-arith -Wmissing-prototypes -Werror=implicit-int set(c_flags -Wshadow -Wstrict-prototypes -Wpointer-arith -Wmissing-prototypes -Werror=implicit-int -Werror=implicit-function-declaration)
-Werror=implicit-function-declaration)
set(cxx_flags -Wmissing-declarations -Wmissing-noreturn) set(cxx_flags -Wmissing-declarations -Wmissing-noreturn)
set(host_cxx_flags "") set(host_cxx_flags "")
@ -455,7 +454,8 @@ if (LLAMA_ALL_WARNINGS)
set(c_flags ${c_flags} ${warning_flags}) set(c_flags ${c_flags} ${warning_flags})
set(cxx_flags ${cxx_flags} ${warning_flags}) set(cxx_flags ${cxx_flags} ${warning_flags})
add_compile_options("$<$<COMPILE_LANGUAGE:C>:${c_flags}>" add_compile_options("$<$<COMPILE_LANGUAGE:C>:${c_flags}>"
"$<$<COMPILE_LANGUAGE:CXX>:${cxx_flags} ${host_cxx_flags}>") "$<$<COMPILE_LANGUAGE:CXX>:${cxx_flags}>"
"$<$<COMPILE_LANGUAGE:CXX>:${host_cxx_flags}>")
endif() endif()

102
Makefile
View File

@ -1,8 +1,14 @@
# Define the default target now so that it is always the first target # Define the default target now so that it is always the first target
BUILD_TARGETS = main quantize quantize-stats perplexity embedding vdot q8dot train-text-from-scratch convert-llama2c-to-ggml simple batched save-load-state server embd-input-test gguf llama-bench baby-llama beam-search speculative infill benchmark-matmult parallel finetune export-lora tests/test-c.o BUILD_TARGETS = \
main quantize quantize-stats perplexity embedding vdot q8dot train-text-from-scratch convert-llama2c-to-ggml \
simple batched batched-bench save-load-state server embd-input-test gguf llama-bench llava baby-llama beam-search \
speculative infill benchmark-matmult parallel finetune export-lora tests/test-c.o
# Binaries only useful for tests # Binaries only useful for tests
TEST_TARGETS = tests/test-llama-grammar tests/test-grammar-parser tests/test-double-float tests/test-grad0 tests/test-opt tests/test-quantize-fns tests/test-quantize-perf tests/test-sampling tests/test-tokenizer-0-llama tests/test-tokenizer-0-falcon tests/test-tokenizer-1-llama tests/test-tokenizer-1-bpe TEST_TARGETS = \
tests/test-llama-grammar tests/test-grammar-parser tests/test-double-float tests/test-grad0 tests/test-opt \
tests/test-quantize-fns tests/test-quantize-perf tests/test-sampling tests/test-tokenizer-0-llama \
tests/test-tokenizer-0-falcon tests/test-tokenizer-1-llama tests/test-tokenizer-1-bpe
# Code coverage output files # Code coverage output files
COV_TARGETS = *.gcno tests/*.gcno *.gcda tests/*.gcda *.gcov tests/*.gcov lcov-report gcovr-report COV_TARGETS = *.gcno tests/*.gcno *.gcda tests/*.gcda *.gcov tests/*.gcov lcov-report gcovr-report
@ -172,6 +178,24 @@ else
MK_CPPFLAGS += -DNDEBUG MK_CPPFLAGS += -DNDEBUG
endif endif
ifdef LLAMA_SANITIZE_THREAD
MK_CFLAGS += -fsanitize=thread -g
MK_CXXFLAGS += -fsanitize=thread -g
MK_LDFLAGS += -fsanitize=thread -g
endif
ifdef LLAMA_SANITIZE_ADDRESS
MK_CFLAGS += -fsanitize=address -fno-omit-frame-pointer -g
MK_CXXFLAGS += -fsanitize=address -fno-omit-frame-pointer -g
MK_LDFLAGS += -fsanitize=address -fno-omit-frame-pointer -g
endif
ifdef LLAMA_SANITIZE_UNDEFINED
MK_CFLAGS += -fsanitize=undefined -g
MK_CXXFLAGS += -fsanitize=undefined -g
MK_LDFLAGS += -fsanitize=undefined -g
endif
ifdef LLAMA_SERVER_VERBOSE ifdef LLAMA_SERVER_VERBOSE
MK_CPPFLAGS += -DSERVER_VERBOSE=$(LLAMA_SERVER_VERBOSE) MK_CPPFLAGS += -DSERVER_VERBOSE=$(LLAMA_SERVER_VERBOSE)
endif endif
@ -520,7 +544,13 @@ OBJS += ggml-alloc.o ggml-backend.o
llama.o: llama.cpp ggml.h ggml-alloc.h ggml-backend.h ggml-cuda.h ggml-metal.h llama.h llama.o: llama.cpp ggml.h ggml-alloc.h ggml-backend.h ggml-cuda.h ggml-metal.h llama.h
$(CXX) $(CXXFLAGS) -c $< -o $@ $(CXX) $(CXXFLAGS) -c $< -o $@
common.o: common/common.cpp common/common.h build-info.h common/log.h COMMON_H_DEPS = common/common.h common/sampling.h build-info.h common/log.h
COMMON_DEPS = $(COMMON_H_DEPS) common.o sampling.o
common.o: common/common.cpp $(COMMON_H_DEPS)
$(CXX) $(CXXFLAGS) -c $< -o $@
sampling.o: common/sampling.cpp $(COMMON_H_DEPS)
$(CXX) $(CXXFLAGS) -c $< -o $@ $(CXX) $(CXXFLAGS) -c $< -o $@
console.o: common/console.cpp common/console.h console.o: common/console.cpp common/console.h
@ -542,19 +572,22 @@ clean:
# Examples # Examples
# #
main: examples/main/main.cpp build-info.h ggml.o llama.o common.o console.o grammar-parser.o $(OBJS) main: examples/main/main.cpp build-info.h ggml.o llama.o $(COMMON_DEPS) console.o grammar-parser.o $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
@echo @echo
@echo '==== Run ./main -h for help. ====' @echo '==== Run ./main -h for help. ===='
@echo @echo
infill: examples/infill/infill.cpp build-info.h ggml.o llama.o common.o console.o grammar-parser.o $(OBJS) infill: examples/infill/infill.cpp build-info.h ggml.o llama.o $(COMMON_DEPS) console.o grammar-parser.o $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
simple: examples/simple/simple.cpp build-info.h ggml.o llama.o common.o $(OBJS) simple: examples/simple/simple.cpp build-info.h ggml.o llama.o $(COMMON_DEPS) $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
batched: examples/batched/batched.cpp build-info.h ggml.o llama.o common.o $(OBJS) batched: examples/batched/batched.cpp build-info.h ggml.o llama.o $(COMMON_DEPS) $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
batched-bench: examples/batched-bench/batched-bench.cpp build-info.h ggml.o llama.o common.o $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
quantize: examples/quantize/quantize.cpp build-info.h ggml.o llama.o $(OBJS) quantize: examples/quantize/quantize.cpp build-info.h ggml.o llama.o $(OBJS)
@ -563,53 +596,56 @@ quantize: examples/quantize/quantize.cpp build-info.h ggml.
quantize-stats: examples/quantize-stats/quantize-stats.cpp build-info.h ggml.o llama.o $(OBJS) quantize-stats: examples/quantize-stats/quantize-stats.cpp build-info.h ggml.o llama.o $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
perplexity: examples/perplexity/perplexity.cpp build-info.h ggml.o llama.o common.o $(OBJS) perplexity: examples/perplexity/perplexity.cpp build-info.h ggml.o llama.o $(COMMON_DEPS) $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
embedding: examples/embedding/embedding.cpp build-info.h ggml.o llama.o common.o $(OBJS) embedding: examples/embedding/embedding.cpp build-info.h ggml.o llama.o $(COMMON_DEPS) $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
save-load-state: examples/save-load-state/save-load-state.cpp build-info.h ggml.o llama.o common.o $(OBJS) save-load-state: examples/save-load-state/save-load-state.cpp build-info.h ggml.o llama.o $(COMMON_DEPS) $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
server: examples/server/server.cpp examples/server/httplib.h examples/server/json.hpp examples/server/index.html.hpp examples/server/index.js.hpp examples/server/completion.js.hpp build-info.h ggml.o llama.o common.o grammar-parser.o $(OBJS) server: examples/server/server.cpp examples/server/httplib.h examples/server/json.hpp examples/server/index.html.hpp examples/server/index.js.hpp examples/server/completion.js.hpp build-info.h ggml.o llama.o $(COMMON_DEPS) grammar-parser.o $(OBJS)
$(CXX) $(CXXFLAGS) -Iexamples/server $(filter-out %.h,$(filter-out %.hpp,$^)) -o $@ $(LDFLAGS) $(LWINSOCK2) $(CXX) $(CXXFLAGS) -Iexamples/server $(filter-out %.h,$(filter-out %.hpp,$^)) -o $@ $(LDFLAGS) $(LWINSOCK2)
$(LIB_PRE)embdinput$(DSO_EXT): examples/embd-input/embd-input.h examples/embd-input/embd-input-lib.cpp build-info.h ggml.o llama.o common.o $(OBJS) $(LIB_PRE)embdinput$(DSO_EXT): examples/embd-input/embd-input.h examples/embd-input/embd-input-lib.cpp build-info.h ggml.o llama.o $(COMMON_DEPS) $(OBJS)
$(CXX) --shared $(CXXFLAGS) $(filter-out %.h,$(filter-out %.hpp,$^)) -o $@ $(LDFLAGS) $(CXX) --shared $(CXXFLAGS) $(filter-out %.h,$(filter-out %.hpp,$^)) -o $@ $(LDFLAGS)
embd-input-test: $(LIB_PRE)embdinput$(DSO_EXT) examples/embd-input/embd-input-test.cpp build-info.h ggml.o llama.o common.o $(OBJS) embd-input-test: $(LIB_PRE)embdinput$(DSO_EXT) examples/embd-input/embd-input-test.cpp build-info.h ggml.o llama.o $(COMMON_DEPS) $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %$(DSO_EXT),$(filter-out %.h,$(filter-out %.hpp,$^))) -o $@ $(LDFLAGS) -L. -lembdinput $(CXX) $(CXXFLAGS) $(filter-out %$(DSO_EXT),$(filter-out %.h,$(filter-out %.hpp,$^))) -o $@ $(LDFLAGS) -L. -lembdinput
gguf: examples/gguf/gguf.cpp ggml.o llama.o $(OBJS) gguf: examples/gguf/gguf.cpp ggml.o llama.o $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
train-text-from-scratch: examples/train-text-from-scratch/train-text-from-scratch.cpp ggml.o llama.o common.o train.o $(OBJS) train-text-from-scratch: examples/train-text-from-scratch/train-text-from-scratch.cpp ggml.o llama.o $(COMMON_DEPS) train.o $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
convert-llama2c-to-ggml: examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp ggml.o llama.o $(OBJS) convert-llama2c-to-ggml: examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp ggml.o llama.o $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
llama-bench: examples/llama-bench/llama-bench.cpp build-info.h ggml.o llama.o common.o $(OBJS) llama-bench: examples/llama-bench/llama-bench.cpp build-info.h ggml.o llama.o $(COMMON_DEPS) $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
baby-llama: examples/baby-llama/baby-llama.cpp ggml.o llama.o common.o train.o $(OBJS) llava: examples/llava/llava.cpp examples/llava/llava-utils.h examples/llava/clip.cpp examples/llava/clip.h common/stb_image.h ggml.o llama.o $(COMMON_DEPS) $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) -Wno-cast-qual
baby-llama: examples/baby-llama/baby-llama.cpp ggml.o llama.o $(COMMON_DEPS) train.o $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
beam-search: examples/beam-search/beam-search.cpp build-info.h ggml.o llama.o common.o $(OBJS) beam-search: examples/beam-search/beam-search.cpp build-info.h ggml.o llama.o $(COMMON_DEPS) $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
finetune: examples/finetune/finetune.cpp build-info.h ggml.o llama.o common.o train.o $(OBJS) finetune: examples/finetune/finetune.cpp build-info.h ggml.o llama.o $(COMMON_DEPS) train.o $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
export-lora: examples/export-lora/export-lora.cpp build-info.h ggml.o llama.o common.o $(OBJS) export-lora: examples/export-lora/export-lora.cpp build-info.h ggml.o llama.o $(COMMON_DEPS) $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
speculative: examples/speculative/speculative.cpp build-info.h ggml.o llama.o common.o grammar-parser.o $(OBJS) speculative: examples/speculative/speculative.cpp build-info.h ggml.o llama.o $(COMMON_DEPS) grammar-parser.o $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
parallel: examples/parallel/parallel.cpp build-info.h ggml.o llama.o common.o $(OBJS) parallel: examples/parallel/parallel.cpp build-info.h ggml.o llama.o $(COMMON_DEPS) $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
ifdef LLAMA_METAL ifdef LLAMA_METAL
@ -650,40 +686,40 @@ vdot: pocs/vdot/vdot.cpp ggml.o $(OBJS)
q8dot: pocs/vdot/q8dot.cpp ggml.o $(OBJS) q8dot: pocs/vdot/q8dot.cpp ggml.o $(OBJS)
$(CXX) $(CXXFLAGS) $^ -o $@ $(LDFLAGS) $(CXX) $(CXXFLAGS) $^ -o $@ $(LDFLAGS)
tests/test-llama-grammar: tests/test-llama-grammar.cpp build-info.h ggml.o common.o grammar-parser.o $(OBJS) tests/test-llama-grammar: tests/test-llama-grammar.cpp build-info.h ggml.o $(COMMON_DEPS) grammar-parser.o $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
tests/test-grammar-parser: tests/test-grammar-parser.cpp build-info.h ggml.o llama.o common.o grammar-parser.o $(OBJS) tests/test-grammar-parser: tests/test-grammar-parser.cpp build-info.h ggml.o llama.o $(COMMON_DEPS) grammar-parser.o $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
tests/test-double-float: tests/test-double-float.cpp build-info.h ggml.o llama.o common.o $(OBJS) tests/test-double-float: tests/test-double-float.cpp build-info.h ggml.o llama.o $(COMMON_DEPS) $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
tests/test-grad0: tests/test-grad0.cpp build-info.h ggml.o llama.o common.o $(OBJS) tests/test-grad0: tests/test-grad0.cpp build-info.h ggml.o llama.o $(COMMON_DEPS) $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
tests/test-opt: tests/test-opt.cpp build-info.h ggml.o llama.o common.o $(OBJS) tests/test-opt: tests/test-opt.cpp build-info.h ggml.o llama.o $(COMMON_DEPS) $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
tests/test-quantize-fns: tests/test-quantize-fns.cpp build-info.h ggml.o llama.o common.o $(OBJS) tests/test-quantize-fns: tests/test-quantize-fns.cpp build-info.h ggml.o llama.o $(COMMON_DEPS) $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
tests/test-quantize-perf: tests/test-quantize-perf.cpp build-info.h ggml.o llama.o common.o $(OBJS) tests/test-quantize-perf: tests/test-quantize-perf.cpp build-info.h ggml.o llama.o $(COMMON_DEPS) $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
tests/test-sampling: tests/test-sampling.cpp build-info.h ggml.o llama.o common.o $(OBJS) tests/test-sampling: tests/test-sampling.cpp build-info.h ggml.o llama.o $(COMMON_DEPS) $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
tests/test-tokenizer-0-falcon: tests/test-tokenizer-0-falcon.cpp build-info.h ggml.o llama.o common.o $(OBJS) tests/test-tokenizer-0-falcon: tests/test-tokenizer-0-falcon.cpp build-info.h ggml.o llama.o $(COMMON_DEPS) $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
tests/test-tokenizer-0-llama: tests/test-tokenizer-0-llama.cpp build-info.h ggml.o llama.o common.o $(OBJS) tests/test-tokenizer-0-llama: tests/test-tokenizer-0-llama.cpp build-info.h ggml.o llama.o $(COMMON_DEPS) $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
tests/test-tokenizer-1-bpe: tests/test-tokenizer-1-bpe.cpp build-info.h ggml.o llama.o common.o $(OBJS) tests/test-tokenizer-1-bpe: tests/test-tokenizer-1-bpe.cpp build-info.h ggml.o llama.o $(COMMON_DEPS) $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
tests/test-tokenizer-1-llama: tests/test-tokenizer-1-llama.cpp build-info.h ggml.o llama.o common.o $(OBJS) tests/test-tokenizer-1-llama: tests/test-tokenizer-1-llama.cpp build-info.h ggml.o llama.o $(COMMON_DEPS) $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
tests/test-c.o: tests/test-c.c llama.h tests/test-c.o: tests/test-c.c llama.h

View File

@ -279,7 +279,7 @@ In order to build llama.cpp you have three different options.
On MacOS, Metal is enabled by default. Using Metal makes the computation run on the GPU. On MacOS, Metal is enabled by default. Using Metal makes the computation run on the GPU.
To disable the Metal build at compile time use the `LLAMA_NO_METAL=1` flag or the `LLAMA_METAL=OFF` cmake option. To disable the Metal build at compile time use the `LLAMA_NO_METAL=1` flag or the `LLAMA_METAL=OFF` cmake option.
When built with Metal support, you can explicitly disable GPU inference with the `--gpu-layers|-ngl 0` command-line When built with Metal support, you can explicitly disable GPU inference with the `--n-gpu-layers|-ngl 0` command-line
argument. argument.
### MPI Build ### MPI Build

View File

@ -128,17 +128,18 @@ pub fn build(b: *std.build.Builder) !void {
const llama = make.obj("llama", "llama.cpp"); const llama = make.obj("llama", "llama.cpp");
const common = make.obj("common", "common/common.cpp"); const common = make.obj("common", "common/common.cpp");
const console = make.obj("console", "common/console.cpp"); const console = make.obj("console", "common/console.cpp");
const sampling = make.obj("sampling", "common/sampling.cpp");
const grammar_parser = make.obj("grammar-parser", "common/grammar-parser.cpp"); const grammar_parser = make.obj("grammar-parser", "common/grammar-parser.cpp");
const train = make.obj("train", "common/train.cpp"); const train = make.obj("train", "common/train.cpp");
_ = make.exe("main", "examples/main/main.cpp", &.{ ggml, ggml_alloc, ggml_backend, llama, common, console, grammar_parser }); _ = make.exe("main", "examples/main/main.cpp", &.{ ggml, ggml_alloc, ggml_backend, llama, common, sampling, console, grammar_parser });
_ = make.exe("quantize", "examples/quantize/quantize.cpp", &.{ ggml, ggml_alloc, ggml_backend, llama, common }); _ = make.exe("quantize", "examples/quantize/quantize.cpp", &.{ ggml, ggml_alloc, ggml_backend, llama, common });
_ = make.exe("perplexity", "examples/perplexity/perplexity.cpp", &.{ ggml, ggml_alloc, ggml_backend, llama, common }); _ = make.exe("perplexity", "examples/perplexity/perplexity.cpp", &.{ ggml, ggml_alloc, ggml_backend, llama, common });
_ = make.exe("embedding", "examples/embedding/embedding.cpp", &.{ ggml, ggml_alloc, ggml_backend, llama, common }); _ = make.exe("embedding", "examples/embedding/embedding.cpp", &.{ ggml, ggml_alloc, ggml_backend, llama, common });
_ = make.exe("finetune", "examples/finetune/finetune.cpp", &.{ ggml, ggml_alloc, ggml_backend, llama, common, train }); _ = make.exe("finetune", "examples/finetune/finetune.cpp", &.{ ggml, ggml_alloc, ggml_backend, llama, common, train });
_ = make.exe("train-text-from-scratch", "examples/train-text-from-scratch/train-text-from-scratch.cpp", &.{ ggml, ggml_alloc, ggml_backend, llama, common, train }); _ = make.exe("train-text-from-scratch", "examples/train-text-from-scratch/train-text-from-scratch.cpp", &.{ ggml, ggml_alloc, ggml_backend, llama, common, train });
const server = make.exe("server", "examples/server/server.cpp", &.{ ggml, ggml_alloc, ggml_backend, llama, common, grammar_parser }); const server = make.exe("server", "examples/server/server.cpp", &.{ ggml, ggml_alloc, ggml_backend, llama, common, sampling, grammar_parser });
if (server.target.isWindows()) { if (server.target.isWindows()) {
server.linkSystemLibrary("ws2_32"); server.linkSystemLibrary("ws2_32");
} }

View File

@ -496,11 +496,13 @@ test $ret -eq 0 && gg_run ctest_debug
test $ret -eq 0 && gg_run ctest_release test $ret -eq 0 && gg_run ctest_release
if [ -z ${GG_BUILD_LOW_PERF} ]; then if [ -z ${GG_BUILD_LOW_PERF} ]; then
if [ -z ${GG_BUILD_VRAM_GB} ] || [ ${GG_BUILD_VRAM_GB} -ge 8 ]; then
if [ -z ${GG_BUILD_CUDA} ]; then if [ -z ${GG_BUILD_CUDA} ]; then
test $ret -eq 0 && gg_run open_llama_3b_v2 test $ret -eq 0 && gg_run open_llama_3b_v2
else else
test $ret -eq 0 && gg_run open_llama_7b_v2 test $ret -eq 0 && gg_run open_llama_7b_v2
fi fi
fi
fi fi
exit $ret exit $ret

View File

@ -5,6 +5,8 @@ set(TARGET common)
add_library(${TARGET} OBJECT add_library(${TARGET} OBJECT
common.h common.h
common.cpp common.cpp
sampling.h
sampling.cpp
console.h console.h
console.cpp console.cpp
grammar-parser.h grammar-parser.h

View File

@ -107,6 +107,7 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
std::string arg; std::string arg;
gpt_params default_params; gpt_params default_params;
const std::string arg_prefix = "--"; const std::string arg_prefix = "--";
llama_sampling_params & sparams = params.sampling_params;
for (int i = 1; i < argc; i++) { for (int i = 1; i < argc; i++) {
arg = argv[i]; arg = argv[i];
@ -184,7 +185,7 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
invalid_param = true; invalid_param = true;
break; break;
} }
params.top_k = std::stoi(argv[i]); sparams.top_k = std::stoi(argv[i]);
} else if (arg == "-c" || arg == "--ctx-size") { } else if (arg == "-c" || arg == "--ctx-size") {
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
@ -216,73 +217,73 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
invalid_param = true; invalid_param = true;
break; break;
} }
params.top_p = std::stof(argv[i]); sparams.top_p = std::stof(argv[i]);
} else if (arg == "--temp") { } else if (arg == "--temp") {
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
params.temp = std::stof(argv[i]); sparams.temp = std::stof(argv[i]);
} else if (arg == "--tfs") { } else if (arg == "--tfs") {
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
params.tfs_z = std::stof(argv[i]); sparams.tfs_z = std::stof(argv[i]);
} else if (arg == "--typical") { } else if (arg == "--typical") {
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
params.typical_p = std::stof(argv[i]); sparams.typical_p = std::stof(argv[i]);
} else if (arg == "--repeat-last-n") { } else if (arg == "--repeat-last-n") {
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
params.repeat_last_n = std::stoi(argv[i]); sparams.repeat_last_n = std::stoi(argv[i]);
} else if (arg == "--repeat-penalty") { } else if (arg == "--repeat-penalty") {
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
params.repeat_penalty = std::stof(argv[i]); sparams.repeat_penalty = std::stof(argv[i]);
} else if (arg == "--frequency-penalty") { } else if (arg == "--frequency-penalty") {
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
params.frequency_penalty = std::stof(argv[i]); sparams.frequency_penalty = std::stof(argv[i]);
} else if (arg == "--presence-penalty") { } else if (arg == "--presence-penalty") {
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
params.presence_penalty = std::stof(argv[i]); sparams.presence_penalty = std::stof(argv[i]);
} else if (arg == "--mirostat") { } else if (arg == "--mirostat") {
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
params.mirostat = std::stoi(argv[i]); sparams.mirostat = std::stoi(argv[i]);
} else if (arg == "--mirostat-lr") { } else if (arg == "--mirostat-lr") {
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
params.mirostat_eta = std::stof(argv[i]); sparams.mirostat_eta = std::stof(argv[i]);
} else if (arg == "--mirostat-ent") { } else if (arg == "--mirostat-ent") {
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
params.mirostat_tau = std::stof(argv[i]); sparams.mirostat_tau = std::stof(argv[i]);
} else if (arg == "--cfg-negative-prompt") { } else if (arg == "--cfg-negative-prompt") {
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
params.cfg_negative_prompt = argv[i]; sparams.cfg_negative_prompt = argv[i];
} else if (arg == "--cfg-negative-prompt-file") { } else if (arg == "--cfg-negative-prompt-file") {
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
@ -294,16 +295,16 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
invalid_param = true; invalid_param = true;
break; break;
} }
std::copy(std::istreambuf_iterator<char>(file), std::istreambuf_iterator<char>(), back_inserter(params.cfg_negative_prompt)); std::copy(std::istreambuf_iterator<char>(file), std::istreambuf_iterator<char>(), back_inserter(sparams.cfg_negative_prompt));
if (!params.cfg_negative_prompt.empty() && params.cfg_negative_prompt.back() == '\n') { if (!sparams.cfg_negative_prompt.empty() && sparams.cfg_negative_prompt.back() == '\n') {
params.cfg_negative_prompt.pop_back(); sparams.cfg_negative_prompt.pop_back();
} }
} else if (arg == "--cfg-scale") { } else if (arg == "--cfg-scale") {
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
params.cfg_scale = std::stof(argv[i]); sparams.cfg_scale = std::stof(argv[i]);
} else if (arg == "-b" || arg == "--batch-size") { } else if (arg == "-b" || arg == "--batch-size") {
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
@ -383,6 +384,18 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
break; break;
} }
params.lora_base = argv[i]; params.lora_base = argv[i];
} else if (arg == "--mmproj") {
if (++i >= argc) {
invalid_param = true;
break;
}
params.mmproj = argv[i];
} else if (arg == "--image") {
if (++i >= argc) {
invalid_param = true;
break;
}
params.image = argv[i];
} else if (arg == "-i" || arg == "--interactive") { } else if (arg == "-i" || arg == "--interactive") {
params.interactive = true; params.interactive = true;
} else if (arg == "--embedding") { } else if (arg == "--embedding") {
@ -512,7 +525,7 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
} else if (arg == "--ignore-eos") { } else if (arg == "--ignore-eos") {
params.ignore_eos = true; params.ignore_eos = true;
} else if (arg == "--no-penalize-nl") { } else if (arg == "--no-penalize-nl") {
params.penalize_nl = false; sparams.penalize_nl = false;
} else if (arg == "-l" || arg == "--logit-bias") { } else if (arg == "-l" || arg == "--logit-bias") {
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
@ -524,7 +537,7 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
std::string value_str; std::string value_str;
try { try {
if (ss >> key && ss >> sign && std::getline(ss, value_str) && (sign == '+' || sign == '-')) { if (ss >> key && ss >> sign && std::getline(ss, value_str) && (sign == '+' || sign == '-')) {
params.logit_bias[key] = std::stof(value_str) * ((sign == '-') ? -1.0f : 1.0f); sparams.logit_bias[key] = std::stof(value_str) * ((sign == '-') ? -1.0f : 1.0f);
} else { } else {
throw std::exception(); throw std::exception();
} }
@ -627,6 +640,8 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
} }
void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
const llama_sampling_params & sparams = params.sampling_params;
printf("usage: %s [options]\n", argv[0]); printf("usage: %s [options]\n", argv[0]);
printf("\n"); printf("\n");
printf("options:\n"); printf("options:\n");
@ -659,19 +674,19 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
printf(" -n N, --n-predict N number of tokens to predict (default: %d, -1 = infinity, -2 = until context filled)\n", params.n_predict); printf(" -n N, --n-predict N number of tokens to predict (default: %d, -1 = infinity, -2 = until context filled)\n", params.n_predict);
printf(" -c N, --ctx-size N size of the prompt context (default: %d, 0 = loaded from model)\n", params.n_ctx); printf(" -c N, --ctx-size N size of the prompt context (default: %d, 0 = loaded from model)\n", params.n_ctx);
printf(" -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch); printf(" -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch);
printf(" --top-k N top-k sampling (default: %d, 0 = disabled)\n", params.top_k); printf(" --top-k N top-k sampling (default: %d, 0 = disabled)\n", sparams.top_k);
printf(" --top-p N top-p sampling (default: %.1f, 1.0 = disabled)\n", (double)params.top_p); printf(" --top-p N top-p sampling (default: %.1f, 1.0 = disabled)\n", (double)sparams.top_p);
printf(" --tfs N tail free sampling, parameter z (default: %.1f, 1.0 = disabled)\n", (double)params.tfs_z); printf(" --tfs N tail free sampling, parameter z (default: %.1f, 1.0 = disabled)\n", (double)sparams.tfs_z);
printf(" --typical N locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)\n", (double)params.typical_p); printf(" --typical N locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)\n", (double)sparams.typical_p);
printf(" --repeat-last-n N last n tokens to consider for penalize (default: %d, 0 = disabled, -1 = ctx_size)\n", params.repeat_last_n); printf(" --repeat-last-n N last n tokens to consider for penalize (default: %d, 0 = disabled, -1 = ctx_size)\n", sparams.repeat_last_n);
printf(" --repeat-penalty N penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)\n", (double)params.repeat_penalty); printf(" --repeat-penalty N penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)\n", (double)sparams.repeat_penalty);
printf(" --presence-penalty N repeat alpha presence penalty (default: %.1f, 0.0 = disabled)\n", (double)params.presence_penalty); printf(" --presence-penalty N repeat alpha presence penalty (default: %.1f, 0.0 = disabled)\n", (double)sparams.presence_penalty);
printf(" --frequency-penalty N repeat alpha frequency penalty (default: %.1f, 0.0 = disabled)\n", (double)params.frequency_penalty); printf(" --frequency-penalty N repeat alpha frequency penalty (default: %.1f, 0.0 = disabled)\n", (double)sparams.frequency_penalty);
printf(" --mirostat N use Mirostat sampling.\n"); printf(" --mirostat N use Mirostat sampling.\n");
printf(" Top K, Nucleus, Tail Free and Locally Typical samplers are ignored if used.\n"); printf(" Top K, Nucleus, Tail Free and Locally Typical samplers are ignored if used.\n");
printf(" (default: %d, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)\n", params.mirostat); printf(" (default: %d, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)\n", sparams.mirostat);
printf(" --mirostat-lr N Mirostat learning rate, parameter eta (default: %.1f)\n", (double)params.mirostat_eta); printf(" --mirostat-lr N Mirostat learning rate, parameter eta (default: %.1f)\n", (double)sparams.mirostat_eta);
printf(" --mirostat-ent N Mirostat target entropy, parameter tau (default: %.1f)\n", (double)params.mirostat_tau); printf(" --mirostat-ent N Mirostat target entropy, parameter tau (default: %.1f)\n", (double)sparams.mirostat_tau);
printf(" -l TOKEN_ID(+/-)BIAS, --logit-bias TOKEN_ID(+/-)BIAS\n"); printf(" -l TOKEN_ID(+/-)BIAS, --logit-bias TOKEN_ID(+/-)BIAS\n");
printf(" modifies the likelihood of token appearing in the completion,\n"); printf(" modifies the likelihood of token appearing in the completion,\n");
printf(" i.e. `--logit-bias 15043+1` to increase likelihood of token ' Hello',\n"); printf(" i.e. `--logit-bias 15043+1` to increase likelihood of token ' Hello',\n");
@ -682,7 +697,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
printf(" negative prompt to use for guidance. (default: empty)\n"); printf(" negative prompt to use for guidance. (default: empty)\n");
printf(" --cfg-negative-prompt-file FNAME\n"); printf(" --cfg-negative-prompt-file FNAME\n");
printf(" negative prompt file to use for guidance. (default: empty)\n"); printf(" negative prompt file to use for guidance. (default: empty)\n");
printf(" --cfg-scale N strength of guidance (default: %f, 1.0 = disable)\n", params.cfg_scale); printf(" --cfg-scale N strength of guidance (default: %f, 1.0 = disable)\n", sparams.cfg_scale);
printf(" --rope-scale N RoPE context linear scaling factor, inverse of --rope-freq-scale\n"); printf(" --rope-scale N RoPE context linear scaling factor, inverse of --rope-freq-scale\n");
printf(" --rope-freq-base N RoPE base frequency, used by NTK-aware scaling (default: loaded from model)\n"); printf(" --rope-freq-base N RoPE base frequency, used by NTK-aware scaling (default: loaded from model)\n");
printf(" --rope-freq-scale N RoPE frequency linear scaling factor (default: loaded from model)\n"); printf(" --rope-freq-scale N RoPE frequency linear scaling factor (default: loaded from model)\n");
@ -690,7 +705,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
printf(" --no-penalize-nl do not penalize newline token\n"); printf(" --no-penalize-nl do not penalize newline token\n");
printf(" --memory-f32 use f32 instead of f16 for memory key+value (default: disabled)\n"); printf(" --memory-f32 use f32 instead of f16 for memory key+value (default: disabled)\n");
printf(" not recommended: doubles context memory required and no measurable increase in quality\n"); printf(" not recommended: doubles context memory required and no measurable increase in quality\n");
printf(" --temp N temperature (default: %.1f)\n", (double)params.temp); printf(" --temp N temperature (default: %.1f)\n", (double)sparams.temp);
printf(" --logits-all return logits for all tokens in the batch (default: disabled)\n"); printf(" --logits-all return logits for all tokens in the batch (default: disabled)\n");
printf(" --hellaswag compute HellaSwag score over random tasks from datafile supplied with -f\n"); printf(" --hellaswag compute HellaSwag score over random tasks from datafile supplied with -f\n");
printf(" --hellaswag-tasks N number of tasks to use when computing the HellaSwag score (default: %zu)\n", params.hellaswag_tasks); printf(" --hellaswag-tasks N number of tasks to use when computing the HellaSwag score (default: %zu)\n", params.hellaswag_tasks);
@ -700,6 +715,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
printf(" -np N, --parallel N number of parallel sequences to decode (default: %d)\n", params.n_parallel); printf(" -np N, --parallel N number of parallel sequences to decode (default: %d)\n", params.n_parallel);
printf(" -ns N, --sequences N number of sequences to decode (default: %d)\n", params.n_sequences); printf(" -ns N, --sequences N number of sequences to decode (default: %d)\n", params.n_sequences);
printf(" -cb, --cont-batching enable continuous batching (a.k.a dynamic batching) (default: disabled)\n"); printf(" -cb, --cont-batching enable continuous batching (a.k.a dynamic batching) (default: disabled)\n");
printf(" --mmproj MMPROJ_FILE path to a multimodal projector file for LLaVA. see examples/llava/README.md\n");
printf(" --image IMAGE_FILE path to an image file. use with multimodal models\n");
if (llama_mlock_supported()) { if (llama_mlock_supported()) {
printf(" --mlock force system to keep model in RAM rather than swapping or compressing\n"); printf(" --mlock force system to keep model in RAM rather than swapping or compressing\n");
} }
@ -840,7 +857,7 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
} }
if (params.ignore_eos) { if (params.ignore_eos) {
params.logit_bias[llama_token_eos(lctx)] = -INFINITY; params.sampling_params.logit_bias[llama_token_eos(lctx)] = -INFINITY;
} }
{ {
@ -932,127 +949,6 @@ std::string llama_detokenize_bpe(llama_context * ctx, const std::vector<llama_to
return result; return result;
} }
//
// Sampling utils
//
llama_token llama_sample_token(
struct llama_context * ctx,
struct llama_context * ctx_guidance,
struct llama_grammar * grammar,
const struct gpt_params & params,
const std::vector<llama_token> & last_tokens,
std::vector<llama_token_data> & candidates,
int idx) {
const int n_ctx = llama_n_ctx(ctx);
const int n_vocab = llama_n_vocab(llama_get_model(ctx));
const float temp = params.temp;
const int32_t top_k = params.top_k <= 0 ? n_vocab : params.top_k;
const float top_p = params.top_p;
const float tfs_z = params.tfs_z;
const float typical_p = params.typical_p;
const int32_t repeat_last_n = params.repeat_last_n < 0 ? n_ctx : params.repeat_last_n;
const float repeat_penalty = params.repeat_penalty;
const float alpha_presence = params.presence_penalty;
const float alpha_frequency = params.frequency_penalty;
const int mirostat = params.mirostat;
const float mirostat_tau = params.mirostat_tau;
const float mirostat_eta = params.mirostat_eta;
const bool penalize_nl = params.penalize_nl;
llama_token id = 0;
float * logits = llama_get_logits_ith(ctx, idx);
// Apply params.logit_bias map
for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
logits[it->first] += it->second;
}
candidates.clear();
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
}
llama_token_data_array cur_p = { candidates.data(), candidates.size(), false };
if (ctx_guidance) {
llama_sample_classifier_free_guidance(ctx, &cur_p, ctx_guidance, params.cfg_scale);
}
// apply penalties
if (!last_tokens.empty()) {
const float nl_logit = logits[llama_token_nl(ctx)];
const int last_n_repeat = std::min(std::min((int)last_tokens.size(), repeat_last_n), n_ctx);
llama_sample_repetition_penalty(ctx, &cur_p,
last_tokens.data() + last_tokens.size() - last_n_repeat,
last_n_repeat, repeat_penalty);
llama_sample_frequency_and_presence_penalties(ctx, &cur_p,
last_tokens.data() + last_tokens.size() - last_n_repeat,
last_n_repeat, alpha_frequency, alpha_presence);
if (!penalize_nl) {
for (size_t idx = 0; idx < cur_p.size; idx++) {
if (cur_p.data[idx].id == llama_token_nl(ctx)) {
cur_p.data[idx].logit = nl_logit;
break;
}
}
}
}
if (grammar != NULL) {
llama_sample_grammar(ctx, &cur_p, grammar);
}
if (temp <= 0) {
// Greedy sampling
id = llama_sample_token_greedy(ctx, &cur_p);
} else {
if (mirostat == 1) {
static float mirostat_mu = 2.0f * mirostat_tau;
const int mirostat_m = 100;
llama_sample_temp(ctx, &cur_p, temp);
id = llama_sample_token_mirostat(ctx, &cur_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu);
} else if (mirostat == 2) {
static float mirostat_mu = 2.0f * mirostat_tau;
llama_sample_temp(ctx, &cur_p, temp);
id = llama_sample_token_mirostat_v2(ctx, &cur_p, mirostat_tau, mirostat_eta, &mirostat_mu);
} else {
// Temperature sampling
size_t min_keep = std::max(1, params.n_probs);
llama_sample_top_k (ctx, &cur_p, top_k, min_keep);
llama_sample_tail_free (ctx, &cur_p, tfs_z, min_keep);
llama_sample_typical (ctx, &cur_p, typical_p, min_keep);
llama_sample_top_p (ctx, &cur_p, top_p, min_keep);
llama_sample_temp(ctx, &cur_p, temp);
{
const int n_top = 10;
LOG("top %d candidates:\n", n_top);
for (int i = 0; i < n_top; i++) {
const llama_token id = cur_p.data[i].id;
LOG(" - %5d: '%12s' (%.3f)\n", id, llama_token_to_piece(ctx, id).c_str(), cur_p.data[i].p);
}
}
id = llama_sample_token(ctx, &cur_p);
LOG("sampled token: %5d: '%s'\n", id, llama_token_to_piece(ctx, id).c_str());
}
}
// printf("`%d`", candidates_p.size);
if (grammar != NULL) {
llama_grammar_accept_token(ctx, grammar, id);
}
return id;
}
// //
// YAML utils // YAML utils
// //
@ -1204,6 +1100,8 @@ std::string get_sortable_timestamp() {
void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const llama_context * lctx, void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const llama_context * lctx,
const std::string & timestamp, const std::vector<int> & prompt_tokens, const char * model_desc) { const std::string & timestamp, const std::vector<int> & prompt_tokens, const char * model_desc) {
const llama_sampling_params & sparams = params.sampling_params;
fprintf(stream, "build_commit: %s\n", BUILD_COMMIT); fprintf(stream, "build_commit: %s\n", BUILD_COMMIT);
fprintf(stream, "build_number: %d\n", BUILD_NUMBER); fprintf(stream, "build_number: %d\n", BUILD_NUMBER);
fprintf(stream, "cpu_has_arm_fma: %s\n", ggml_cpu_has_arm_fma() ? "true" : "false"); fprintf(stream, "cpu_has_arm_fma: %s\n", ggml_cpu_has_arm_fma() ? "true" : "false");
@ -1250,21 +1148,21 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
fprintf(stream, "alias: %s # default: unknown\n", params.model_alias.c_str()); fprintf(stream, "alias: %s # default: unknown\n", params.model_alias.c_str());
fprintf(stream, "batch_size: %d # default: 512\n", params.n_batch); fprintf(stream, "batch_size: %d # default: 512\n", params.n_batch);
dump_string_yaml_multiline(stream, "cfg_negative_prompt", params.cfg_negative_prompt.c_str()); dump_string_yaml_multiline(stream, "cfg_negative_prompt", sparams.cfg_negative_prompt.c_str());
fprintf(stream, "cfg_scale: %f # default: 1.0\n", params.cfg_scale); fprintf(stream, "cfg_scale: %f # default: 1.0\n", sparams.cfg_scale);
fprintf(stream, "chunks: %d # default: -1 (unlimited)\n", params.n_chunks); fprintf(stream, "chunks: %d # default: -1 (unlimited)\n", params.n_chunks);
fprintf(stream, "color: %s # default: false\n", params.use_color ? "true" : "false"); fprintf(stream, "color: %s # default: false\n", params.use_color ? "true" : "false");
fprintf(stream, "ctx_size: %d # default: 512\n", params.n_ctx); fprintf(stream, "ctx_size: %d # default: 512\n", params.n_ctx);
fprintf(stream, "escape: %s # default: false\n", params.escape ? "true" : "false"); fprintf(stream, "escape: %s # default: false\n", params.escape ? "true" : "false");
fprintf(stream, "file: # never logged, see prompt instead. Can still be specified for input.\n"); fprintf(stream, "file: # never logged, see prompt instead. Can still be specified for input.\n");
fprintf(stream, "frequency_penalty: %f # default: 0.0 \n", params.frequency_penalty); fprintf(stream, "frequency_penalty: %f # default: 0.0 \n", sparams.frequency_penalty);
dump_string_yaml_multiline(stream, "grammar", params.grammar.c_str()); dump_string_yaml_multiline(stream, "grammar", params.grammar.c_str());
fprintf(stream, "grammar-file: # never logged, see grammar instead. Can still be specified for input.\n"); fprintf(stream, "grammar-file: # never logged, see grammar instead. Can still be specified for input.\n");
fprintf(stream, "hellaswag: %s # default: false\n", params.hellaswag ? "true" : "false"); fprintf(stream, "hellaswag: %s # default: false\n", params.hellaswag ? "true" : "false");
fprintf(stream, "hellaswag_tasks: %zu # default: 400\n", params.hellaswag_tasks); fprintf(stream, "hellaswag_tasks: %zu # default: 400\n", params.hellaswag_tasks);
const auto logit_bias_eos = params.logit_bias.find(llama_token_eos(lctx)); const auto logit_bias_eos = sparams.logit_bias.find(llama_token_eos(lctx));
const bool ignore_eos = logit_bias_eos != params.logit_bias.end() && logit_bias_eos->second == -INFINITY; const bool ignore_eos = logit_bias_eos != sparams.logit_bias.end() && logit_bias_eos->second == -INFINITY;
fprintf(stream, "ignore_eos: %s # default: false\n", ignore_eos ? "true" : "false"); fprintf(stream, "ignore_eos: %s # default: false\n", ignore_eos ? "true" : "false");
dump_string_yaml_multiline(stream, "in_prefix", params.input_prefix.c_str()); dump_string_yaml_multiline(stream, "in_prefix", params.input_prefix.c_str());
@ -1277,7 +1175,7 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
fprintf(stream, "logdir: %s # default: unset (no logging)\n", params.logdir.c_str()); fprintf(stream, "logdir: %s # default: unset (no logging)\n", params.logdir.c_str());
fprintf(stream, "logit_bias:\n"); fprintf(stream, "logit_bias:\n");
for (std::pair<llama_token, float> lb : params.logit_bias) { for (std::pair<llama_token, float> lb : sparams.logit_bias) {
if (ignore_eos && lb.first == logit_bias_eos->first) { if (ignore_eos && lb.first == logit_bias_eos->first) {
continue; continue;
} }
@ -1301,30 +1199,30 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
fprintf(stream, "lora_base: %s\n", params.lora_base.c_str()); fprintf(stream, "lora_base: %s\n", params.lora_base.c_str());
fprintf(stream, "main_gpu: %d # default: 0\n", params.main_gpu); fprintf(stream, "main_gpu: %d # default: 0\n", params.main_gpu);
fprintf(stream, "memory_f32: %s # default: false\n", !params.memory_f16 ? "true" : "false"); fprintf(stream, "memory_f32: %s # default: false\n", !params.memory_f16 ? "true" : "false");
fprintf(stream, "mirostat: %d # default: 0 (disabled)\n", params.mirostat); fprintf(stream, "mirostat: %d # default: 0 (disabled)\n", sparams.mirostat);
fprintf(stream, "mirostat_ent: %f # default: 5.0\n", params.mirostat_tau); fprintf(stream, "mirostat_ent: %f # default: 5.0\n", sparams.mirostat_tau);
fprintf(stream, "mirostat_lr: %f # default: 0.1\n", params.mirostat_eta); fprintf(stream, "mirostat_lr: %f # default: 0.1\n", sparams.mirostat_eta);
fprintf(stream, "mlock: %s # default: false\n", params.use_mlock ? "true" : "false"); fprintf(stream, "mlock: %s # default: false\n", params.use_mlock ? "true" : "false");
fprintf(stream, "model: %s # default: models/7B/ggml-model.bin\n", params.model.c_str()); fprintf(stream, "model: %s # default: models/7B/ggml-model.bin\n", params.model.c_str());
fprintf(stream, "model_draft: %s # default:\n", params.model_draft.c_str()); fprintf(stream, "model_draft: %s # default:\n", params.model_draft.c_str());
fprintf(stream, "multiline_input: %s # default: false\n", params.multiline_input ? "true" : "false"); fprintf(stream, "multiline_input: %s # default: false\n", params.multiline_input ? "true" : "false");
fprintf(stream, "n_gpu_layers: %d # default: -1\n", params.n_gpu_layers); fprintf(stream, "n_gpu_layers: %d # default: -1\n", params.n_gpu_layers);
fprintf(stream, "n_predict: %d # default: -1 (unlimited)\n", params.n_predict); fprintf(stream, "n_predict: %d # default: -1 (unlimited)\n", params.n_predict);
fprintf(stream, "n_probs: %d # only used by server binary, default: 0\n", params.n_probs); fprintf(stream, "n_probs: %d # only used by server binary, default: 0\n", sparams.n_probs);
fprintf(stream, "no_mmap: %s # default: false\n", !params.use_mmap ? "true" : "false"); fprintf(stream, "no_mmap: %s # default: false\n", !params.use_mmap ? "true" : "false");
fprintf(stream, "no_mul_mat_q: %s # default: false\n", !params.mul_mat_q ? "true" : "false"); fprintf(stream, "no_mul_mat_q: %s # default: false\n", !params.mul_mat_q ? "true" : "false");
fprintf(stream, "no_penalize_nl: %s # default: false\n", !params.penalize_nl ? "true" : "false"); fprintf(stream, "no_penalize_nl: %s # default: false\n", !sparams.penalize_nl ? "true" : "false");
fprintf(stream, "numa: %s # default: false\n", params.numa ? "true" : "false"); fprintf(stream, "numa: %s # default: false\n", params.numa ? "true" : "false");
fprintf(stream, "ppl_output_type: %d # default: 0\n", params.ppl_output_type); fprintf(stream, "ppl_output_type: %d # default: 0\n", params.ppl_output_type);
fprintf(stream, "ppl_stride: %d # default: 0\n", params.ppl_stride); fprintf(stream, "ppl_stride: %d # default: 0\n", params.ppl_stride);
fprintf(stream, "presence_penalty: %f # default: 0.0\n", params.presence_penalty); fprintf(stream, "presence_penalty: %f # default: 0.0\n", sparams.presence_penalty);
dump_string_yaml_multiline(stream, "prompt", params.prompt.c_str()); dump_string_yaml_multiline(stream, "prompt", params.prompt.c_str());
fprintf(stream, "prompt_cache: %s\n", params.path_prompt_cache.c_str()); fprintf(stream, "prompt_cache: %s\n", params.path_prompt_cache.c_str());
fprintf(stream, "prompt_cache_all: %s # default: false\n", params.prompt_cache_all ? "true" : "false"); fprintf(stream, "prompt_cache_all: %s # default: false\n", params.prompt_cache_all ? "true" : "false");
fprintf(stream, "prompt_cache_ro: %s # default: false\n", params.prompt_cache_ro ? "true" : "false"); fprintf(stream, "prompt_cache_ro: %s # default: false\n", params.prompt_cache_ro ? "true" : "false");
dump_vector_int_yaml(stream, "prompt_tokens", prompt_tokens); dump_vector_int_yaml(stream, "prompt_tokens", prompt_tokens);
fprintf(stream, "random_prompt: %s # default: false\n", params.random_prompt ? "true" : "false"); fprintf(stream, "random_prompt: %s # default: false\n", params.random_prompt ? "true" : "false");
fprintf(stream, "repeat_penalty: %f # default: 1.1\n", params.repeat_penalty); fprintf(stream, "repeat_penalty: %f # default: 1.1\n", sparams.repeat_penalty);
fprintf(stream, "reverse_prompt:\n"); fprintf(stream, "reverse_prompt:\n");
for (std::string ap : params.antiprompt) { for (std::string ap : params.antiprompt) {
@ -1342,15 +1240,15 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
fprintf(stream, "seed: %d # default: -1 (random seed)\n", params.seed); fprintf(stream, "seed: %d # default: -1 (random seed)\n", params.seed);
fprintf(stream, "simple_io: %s # default: false\n", params.simple_io ? "true" : "false"); fprintf(stream, "simple_io: %s # default: false\n", params.simple_io ? "true" : "false");
fprintf(stream, "cont_batching: %s # default: false\n", params.cont_batching ? "true" : "false"); fprintf(stream, "cont_batching: %s # default: false\n", params.cont_batching ? "true" : "false");
fprintf(stream, "temp: %f # default: 0.8\n", params.temp); fprintf(stream, "temp: %f # default: 0.8\n", sparams.temp);
const std::vector<float> tensor_split_vector(params.tensor_split, params.tensor_split + LLAMA_MAX_DEVICES); const std::vector<float> tensor_split_vector(params.tensor_split, params.tensor_split + LLAMA_MAX_DEVICES);
dump_vector_float_yaml(stream, "tensor_split", tensor_split_vector); dump_vector_float_yaml(stream, "tensor_split", tensor_split_vector);
fprintf(stream, "tfs: %f # default: 1.0\n", params.tfs_z); fprintf(stream, "tfs: %f # default: 1.0\n", sparams.tfs_z);
fprintf(stream, "threads: %d # default: %d\n", params.n_threads, std::thread::hardware_concurrency()); fprintf(stream, "threads: %d # default: %d\n", params.n_threads, std::thread::hardware_concurrency());
fprintf(stream, "top_k: %d # default: 40\n", params.top_k); fprintf(stream, "top_k: %d # default: 40\n", sparams.top_k);
fprintf(stream, "top_p: %f # default: 0.95\n", params.top_p); fprintf(stream, "top_p: %f # default: 0.95\n", sparams.top_p);
fprintf(stream, "typical_p: %f # default: 1.0\n", params.typical_p); fprintf(stream, "typical_p: %f # default: 1.0\n", sparams.typical_p);
fprintf(stream, "verbose_prompt: %s # default: false\n", params.verbose_prompt ? "true" : "false"); fprintf(stream, "verbose_prompt: %s # default: false\n", params.verbose_prompt ? "true" : "false");
} }

View File

@ -4,6 +4,8 @@
#include "llama.h" #include "llama.h"
#include "sampling.h"
#define LOG_NO_FILE_LINE_FUNCTION #define LOG_NO_FILE_LINE_FUNCTION
#include "log.h" #include "log.h"
@ -49,31 +51,12 @@ struct gpt_params {
int32_t n_gpu_layers_draft = -1; // number of layers to store in VRAM for the draft model (-1 - use default) int32_t n_gpu_layers_draft = -1; // number of layers to store in VRAM for the draft model (-1 - use default)
int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors
float tensor_split[LLAMA_MAX_DEVICES] = {0}; // how split tensors should be distributed across GPUs float tensor_split[LLAMA_MAX_DEVICES] = {0}; // how split tensors should be distributed across GPUs
int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
int32_t n_beams = 0; // if non-zero then use beam search of given width. int32_t n_beams = 0; // if non-zero then use beam search of given width.
float rope_freq_base = 0.0f; // RoPE base frequency float rope_freq_base = 0.0f; // RoPE base frequency
float rope_freq_scale = 0.0f; // RoPE frequency scaling factor float rope_freq_scale = 0.0f; // RoPE frequency scaling factor
// sampling parameters // // sampling parameters
int32_t top_k = 40; // <= 0 to use vocab size struct llama_sampling_params sampling_params;
float top_p = 0.95f; // 1.0 = disabled
float tfs_z = 1.00f; // 1.0 = disabled
float typical_p = 1.00f; // 1.0 = disabled
float temp = 0.80f; // 1.0 = disabled
float repeat_penalty = 1.10f; // 1.0 = disabled
int32_t repeat_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size)
float frequency_penalty = 0.00f; // 0.0 = disabled
float presence_penalty = 0.00f; // 0.0 = disabled
int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
float mirostat_tau = 5.00f; // target entropy
float mirostat_eta = 0.10f; // learning rate
std::unordered_map<llama_token, float> logit_bias; // logit bias for specific tokens
// Classifier-Free Guidance
// https://arxiv.org/abs/2306.17806
std::string cfg_negative_prompt; // string to help guidance
float cfg_scale = 1.f; // How strong is guidance
std::string model = "models/7B/ggml-model-f16.gguf"; // model path std::string model = "models/7B/ggml-model-f16.gguf"; // model path
std::string model_draft = ""; // draft model for speculative decoding std::string model_draft = ""; // draft model for speculative decoding
@ -115,13 +98,16 @@ struct gpt_params {
bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix
bool ignore_eos = false; // ignore generated EOS tokens bool ignore_eos = false; // ignore generated EOS tokens
bool instruct = false; // instruction mode (used for Alpaca models) bool instruct = false; // instruction mode (used for Alpaca models)
bool penalize_nl = true; // consider newlines as a repeatable token
bool logits_all = false; // return logits for all tokens in the batch bool logits_all = false; // return logits for all tokens in the batch
bool use_mmap = true; // use mmap for faster loads bool use_mmap = true; // use mmap for faster loads
bool use_mlock = false; // use mlock to keep model in memory bool use_mlock = false; // use mlock to keep model in memory
bool numa = false; // attempt optimizations that help on some NUMA systems bool numa = false; // attempt optimizations that help on some NUMA systems
bool verbose_prompt = false; // print prompt tokens before generation bool verbose_prompt = false; // print prompt tokens before generation
bool infill = false; // use infill mode bool infill = false; // use infill mode
// multimodal models (see examples/llava)
std::string mmproj = ""; // path to multimodal projector
std::string image = ""; // path to an image file
}; };
bool gpt_params_parse(int argc, char ** argv, gpt_params & params); bool gpt_params_parse(int argc, char ** argv, gpt_params & params);
@ -180,36 +166,6 @@ std::string llama_detokenize_bpe(
llama_context * ctx, llama_context * ctx,
const std::vector<llama_token> & tokens); const std::vector<llama_token> & tokens);
//
// Sampling utils
//
// this is a common sampling function used across the examples for convenience
// it can serve as a starting point for implementing your own sampling function
//
// required:
// - ctx: context to use for sampling
// - params: sampling parameters
//
// optional:
// - ctx_guidance: context to use for classifier-free guidance, ignore if NULL
// - grammar: grammar to use for sampling, ignore if NULL
// - last_tokens: needed for repetition penalty, ignore if empty
// - idx: sample from llama_get_logits_ith(ctx, idx)
//
// returns:
// - token: sampled token
// - candidates: vector of candidate tokens
//
llama_token llama_sample_token(
struct llama_context * ctx,
struct llama_context * ctx_guidance,
struct llama_grammar * grammar,
const struct gpt_params & params,
const std::vector<llama_token> & last_tokens,
std::vector<llama_token_data> & candidates,
int idx = 0);
// //
// YAML utils // YAML utils
// //

166
common/sampling.cpp Normal file
View File

@ -0,0 +1,166 @@
#include "sampling.h"
llama_sampling_context::~llama_sampling_context() {
for (auto & it : sequence_contexts) {
if (it.second.grammar != NULL) {
llama_grammar_free(it.second.grammar);
it.second.grammar = NULL;
}
}
}
llama_sampling_context llama_sampling_context_init(
const struct gpt_params & params,
llama_grammar * grammar) {
llama_sampling_context result;
result.params = params.sampling_params;
result.grammar = grammar;
return result;
}
// Note: Creates the context if it doesn't exist, so this always return something.
llama_sampler_sequence_context & llama_sampling_get_sequence_context(
llama_sampling_context & ctx_sampling,
const llama_seq_id seq) {
const auto it = ctx_sampling.sequence_contexts.find(seq);
if (it != ctx_sampling.sequence_contexts.end()) {
return it->second;
}
llama_sampler_sequence_context new_ctx = {
2.0f * ctx_sampling.params.mirostat_tau,
ctx_sampling.grammar != NULL ? llama_grammar_copy(ctx_sampling.grammar) : NULL,
};
return ctx_sampling.sequence_contexts.insert({seq, new_ctx}).first->second;
}
bool llama_sampling_context_reset(
llama_sampling_context & ctx_sampling,
const llama_seq_id seq) {
const auto it = ctx_sampling.sequence_contexts.find(seq);
if (it == ctx_sampling.sequence_contexts.end()) return false;
if (it->second.grammar != NULL) {
llama_grammar_free(it->second.grammar);
it->second.grammar = NULL;
}
ctx_sampling.sequence_contexts.erase(it);
return true;
}
llama_token llama_sampling_sample(
struct llama_context * ctx,
struct llama_context * ctx_guidance,
struct llama_sampling_context & ctx_sampling,
const std::vector<llama_token> & last_tokens,
std::vector<llama_token_data> & candidates,
const int idx,
llama_seq_id seq) {
const int n_ctx = llama_n_ctx(ctx);
const int n_vocab = llama_n_vocab(llama_get_model(ctx));
const llama_sampling_params & params = ctx_sampling.params;
const float temp = params.temp;
const int32_t top_k = params.top_k <= 0 ? n_vocab : params.top_k;
const float top_p = params.top_p;
const float tfs_z = params.tfs_z;
const float typical_p = params.typical_p;
const int32_t repeat_last_n = params.repeat_last_n < 0 ? n_ctx : params.repeat_last_n;
const float repeat_penalty = params.repeat_penalty;
const float alpha_presence = params.presence_penalty;
const float alpha_frequency = params.frequency_penalty;
const int mirostat = params.mirostat;
const float mirostat_tau = params.mirostat_tau;
const float mirostat_eta = params.mirostat_eta;
const bool penalize_nl = params.penalize_nl;
llama_token id = 0;
float * logits = llama_get_logits_ith(ctx, idx);
// Apply params.logit_bias map
for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
logits[it->first] += it->second;
}
candidates.clear();
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
}
llama_token_data_array cur_p = { candidates.data(), candidates.size(), false };
if (ctx_guidance) {
llama_sample_classifier_free_guidance(ctx, &cur_p, ctx_guidance, params.cfg_scale);
}
// apply penalties
if (!last_tokens.empty()) {
const float nl_logit = logits[llama_token_nl(ctx)];
const int last_n_repeat = std::min(std::min((int)last_tokens.size(), repeat_last_n), n_ctx);
llama_sample_repetition_penalty(ctx, &cur_p,
last_tokens.data() + last_tokens.size() - last_n_repeat,
last_n_repeat, repeat_penalty);
llama_sample_frequency_and_presence_penalties(ctx, &cur_p,
last_tokens.data() + last_tokens.size() - last_n_repeat,
last_n_repeat, alpha_frequency, alpha_presence);
if (!penalize_nl) {
for (size_t idx = 0; idx < cur_p.size; idx++) {
if (cur_p.data[idx].id == llama_token_nl(ctx)) {
cur_p.data[idx].logit = nl_logit;
break;
}
}
}
}
llama_sampler_sequence_context & ctx_seq = llama_sampling_get_sequence_context(ctx_sampling, seq);
if (ctx_seq.grammar != NULL) {
llama_sample_grammar(ctx, &cur_p, ctx_seq.grammar);
}
if (temp <= 0) {
// Greedy sampling
id = llama_sample_token_greedy(ctx, &cur_p);
} else {
if (mirostat == 1) {
const int mirostat_m = 100;
llama_sample_temp(ctx, &cur_p, temp);
id = llama_sample_token_mirostat(ctx, &cur_p, mirostat_tau, mirostat_eta, mirostat_m, &ctx_seq.mirostat_mu);
} else if (mirostat == 2) {
llama_sample_temp(ctx, &cur_p, temp);
id = llama_sample_token_mirostat_v2(ctx, &cur_p, mirostat_tau, mirostat_eta, &ctx_seq.mirostat_mu);
} else {
// Temperature sampling
size_t min_keep = std::max(1, params.n_probs);
llama_sample_top_k (ctx, &cur_p, top_k, min_keep);
llama_sample_tail_free (ctx, &cur_p, tfs_z, min_keep);
llama_sample_typical (ctx, &cur_p, typical_p, min_keep);
llama_sample_top_p (ctx, &cur_p, top_p, min_keep);
llama_sample_temp(ctx, &cur_p, temp);
{
const int n_top = 10;
LOG("top %d candidates:\n", n_top);
for (int i = 0; i < n_top; i++) {
const llama_token id = cur_p.data[i].id;
(void)id; // To avoid a warning that id is unused when logging is disabled.
LOG(" - %5d: '%12s' (%.3f)\n", id, llama_token_to_piece(ctx, id).c_str(), cur_p.data[i].p);
}
}
id = llama_sample_token(ctx, &cur_p);
LOG("sampled token: %5d: '%s'\n", id, llama_token_to_piece(ctx, id).c_str());
}
}
if (ctx_seq.grammar != NULL) {
llama_grammar_accept_token(ctx, ctx_seq.grammar, id);
}
return id;
}

108
common/sampling.h Normal file
View File

@ -0,0 +1,108 @@
#pragma once
#include "llama.h"
#include <string>
#include <vector>
#include <unordered_map>
// sampling parameters
typedef struct llama_sampling_params {
int32_t top_k = 40; // <= 0 to use vocab size
float top_p = 0.95f; // 1.0 = disabled
float tfs_z = 1.00f; // 1.0 = disabled
float typical_p = 1.00f; // 1.0 = disabled
float temp = 0.80f; // 1.0 = disabled
float repeat_penalty = 1.10f; // 1.0 = disabled
int32_t repeat_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size)
float frequency_penalty = 0.00f; // 0.0 = disabled
float presence_penalty = 0.00f; // 0.0 = disabled
int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
float mirostat_tau = 5.00f; // target entropy
float mirostat_eta = 0.10f; // learning rate
bool penalize_nl = true; // consider newlines as a repeatable token
int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
// Classifier-Free Guidance
// https://arxiv.org/abs/2306.17806
std::string cfg_negative_prompt; // string to help guidance
float cfg_scale = 1.f; // How strong is guidance
std::unordered_map<llama_token, float> logit_bias; // logit bias for specific tokens
} llama_sampling_params;
// per-sequence sampler context
typedef struct llama_sampler_sequence_context {
float mirostat_mu; // mirostat sampler state
llama_grammar * grammar;
} llama_sampler_sequence_context;
// general sampler context
typedef struct llama_sampling_context {
~llama_sampling_context();
// parameters that will be used for sampling and when creating
// new llama_sampler_sequence_context instances
llama_sampling_params params;
// map of sequence ids to sampler contexts
std::unordered_map<llama_seq_id, llama_sampler_sequence_context> sequence_contexts;
// when non-NULL, new instances of llama_sampler_sequence_context
// will get a copy of the grammar here
// note: only the pointer is stored here, it is not a copy of
// the grammar and shouldn't be freed
llama_grammar * grammar;
} llama_sampling_context;
#include "common.h"
// Create a new sampling context instance.
llama_sampling_context llama_sampling_context_init(
const struct gpt_params & params,
llama_grammar * grammar = NULL);
// Fetches the sampler context for the specified sequence id (defaults to 0).
// If the context for that sequence id doesn't already exist, it will be created with
// default values based on the parameters in the ctx_sampling argument.
llama_sampler_sequence_context & llama_sampling_get_sequence_context(
llama_sampling_context & ctx_sampling,
const llama_seq_id seq = 0);
// Reset the sampler context for the supplied sequence id (defaults to 0).
// This is necessary to reuse a sequence id or free memory used by sequences
// that are no longer required.
bool llama_sampling_context_reset(
llama_sampling_context & ctx_sampling,
const llama_seq_id seq = 0);
// this is a common sampling function used across the examples for convenience
// it can serve as a starting point for implementing your own sampling function
// Note: When using multiple sequences, it is the caller's responsibility to call
// llama_sampling_context_reset when a sequence ends
//
// required:
// - ctx: context to use for sampling
// - ctx_sampling: sampling-specific context
//
// optional:
// - ctx_guidance: context to use for classifier-free guidance, ignore if NULL
// - last_tokens: needed for repetition penalty, ignore if empty
// - idx: sample from llama_get_logits_ith(ctx, idx)
// - seq: sequence id to associate sampler state with
//
// returns:
// - token: sampled token
// - candidates: vector of candidate tokens
//
llama_token llama_sampling_sample(
struct llama_context * ctx,
struct llama_context * ctx_guidance,
struct llama_sampling_context & ctx_sampling,
const std::vector<llama_token> & last_tokens,
std::vector<llama_token_data> & candidates,
const int idx = 0,
llama_seq_id seq = 0);

8396
common/stb_image.h Normal file

File diff suppressed because it is too large Load Diff

View File

@ -49,7 +49,7 @@ According to the BLIS documentation, we could set the following
environment variables to modify the behavior of openmp: environment variables to modify the behavior of openmp:
```bash ```bash
export GOMP_GPU_AFFINITY="0-19" export GOMP_CPU_AFFINITY="0-19"
export BLIS_NUM_THREADS=14 export BLIS_NUM_THREADS=14
``` ```

View File

@ -25,9 +25,11 @@ else()
add_subdirectory(convert-llama2c-to-ggml) add_subdirectory(convert-llama2c-to-ggml)
add_subdirectory(simple) add_subdirectory(simple)
add_subdirectory(batched) add_subdirectory(batched)
add_subdirectory(batched-bench)
add_subdirectory(speculative) add_subdirectory(speculative)
add_subdirectory(parallel) add_subdirectory(parallel)
add_subdirectory(embd-input) add_subdirectory(embd-input)
add_subdirectory(llava)
add_subdirectory(llama-bench) add_subdirectory(llama-bench)
add_subdirectory(beam-search) add_subdirectory(beam-search)
if (LLAMA_METAL) if (LLAMA_METAL)

View File

@ -0,0 +1,5 @@
set(TARGET batched-bench)
add_executable(${TARGET} batched-bench.cpp)
install(TARGETS ${TARGET} RUNTIME)
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PRIVATE cxx_std_11)

View File

@ -0,0 +1,51 @@
# llama.cpp/example/batched-bench
Benchmark the batched decoding performance of `llama.cpp`
## Usage
There are 2 modes of operation:
- `prompt not shared` - each batch has a separate prompt of size `PP` (i.e. `N_KV = B*(PP + TG)`)
- `prompt is shared` - there is a common prompt of size `PP` used by all batches (i.e. `N_KV = PP + B*TG`)
```bash
./batched-bench MODEL_PATH [N_KV_MAX] [IS_PP_SHARED] [NGL] [MMQ] <PP> <TG> <PL>
# LLaMA 7B, F16, N_KV_MAX = 16384 (8GB), prompt not shared
./batched-bench ./models/llama-7b/ggml-model-f16.gguf 16384 0 99
# LLaMA 7B, Q8_0, N_KV_MAX = 16384 (8GB), prompt is shared
./batched-bench ./models/llama-7b/ggml-model-q8_0.gguf 16384 1 99
# custom set of batches
./batched-bench ./models/llama-7b/ggml-model-q8_0.gguf 2048 0 999 0 128,256,512 128,256 1,2,4,8,16,32
```
## Sample results
- `PP` - prompt tokens per batch
- `TG` - generated tokens per batch
- `B` - number of batches
- `N_KV` - required KV cache size
- `T_PP` - prompt processing time (i.e. time to first token)
- `S_PP` - prompt processing speed (`(B*PP)/T_PP` or `PP/T_PP`)
- `T_TG` - time to generate all batches
- `S_TG` - text generation speed (`(B*TG)/T_TG`)
- `T` - total time
- `S` - total speed (i.e. all tokens / total time)
| PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s |
|-------|--------|------|--------|----------|----------|----------|----------|----------|----------|
| 128 | 128 | 1 | 256 | 0.108 | 1186.64 | 3.079 | 41.57 | 3.187 | 80.32 |
| 128 | 128 | 2 | 512 | 0.198 | 1295.19 | 5.029 | 50.90 | 5.227 | 97.95 |
| 128 | 128 | 4 | 1024 | 0.373 | 1373.96 | 6.878 | 74.44 | 7.251 | 141.23 |
| 128 | 128 | 8 | 2048 | 0.751 | 1363.27 | 7.344 | 139.43 | 8.095 | 252.99 |
| 128 | 128 | 16 | 4096 | 1.570 | 1304.68 | 8.455 | 242.23 | 10.024 | 408.60 |
| 128 | 128 | 32 | 8192 | 3.408 | 1201.73 | 8.801 | 465.40 | 12.209 | 670.96 |
| 128 | 256 | 1 | 384 | 0.107 | 1196.70 | 6.329 | 40.45 | 6.436 | 59.67 |
| 128 | 256 | 2 | 768 | 0.194 | 1317.45 | 10.239 | 50.00 | 10.433 | 73.61 |
| 128 | 256 | 4 | 1536 | 0.366 | 1399.03 | 13.960 | 73.35 | 14.326 | 107.22 |
| 128 | 256 | 8 | 3072 | 0.751 | 1363.92 | 15.110 | 135.54 | 15.861 | 193.69 |
| 128 | 256 | 16 | 6144 | 1.569 | 1304.93 | 18.073 | 226.64 | 19.642 | 312.80 |
| 128 | 256 | 32 | 12288 | 3.409 | 1201.35 | 19.223 | 426.15 | 22.633 | 542.93 |

View File

@ -0,0 +1,251 @@
#include "common.h"
#include "llama.h"
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <string>
#include <vector>
// mutates the input string
static std::vector<int> parse_list(char * p) {
std::vector<int> ret;
char * q = p;
while (*p) {
if (*p == ',') {
*p = '\0';
ret.push_back(std::atoi(q));
q = p + 1;
}
++p;
}
ret.push_back(std::atoi(q));
return ret;
}
int main(int argc, char ** argv) {
gpt_params params;
if (argc == 1 || argv[1][0] == '-') {
printf("usage: %s MODEL_PATH [N_KV_MAX] [IS_PP_SHARED] [NGL] [MMQ] <PP> <TG> <PL>\n" , argv[0]);
printf(" <PP>, <TG> and PL are comma-separated lists of numbers without spaces\n\n");
printf(" example: %s ggml-model-f16.gguf 2048 0 999 0 128,256,512 128,256 1,2,4,8,16,32\n\n", argv[0]);
return 1 ;
}
int n_kv_max = 2048;
int is_pp_shared = 0;
int n_gpu_layers = 0;
int mmq = 0;
std::vector<int> n_pp = { 128, 256, 512, 1024, 2048, 3584, 7680, };
std::vector<int> n_tg = { 128, 256, };
std::vector<int> n_pl = { 1, 2, 4, 8, 16, 32, };
//std::vector<int> n_pl = { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 32, };
if (argc >= 2) {
params.model = argv[1];
}
if (argc >= 3) {
n_kv_max = std::atoi(argv[2]);
}
if (argc >= 4) {
is_pp_shared = std::atoi(argv[3]);
}
if (argc >= 5) {
n_gpu_layers = std::atoi(argv[4]);
}
if (argc >= 6) {
mmq = std::atoi(argv[5]);
}
if (argc >= 7) {
n_pp = parse_list(argv[6]);
}
if (argc >= 8) {
n_tg = parse_list(argv[7]);
}
if (argc >= 9) {
n_pl = parse_list(argv[8]);
}
// init LLM
llama_backend_init(params.numa);
// initialize the model
llama_model_params model_params = llama_model_default_params();
model_params.n_gpu_layers = n_gpu_layers;
llama_model * model = llama_load_model_from_file(params.model.c_str(), model_params);
if (model == NULL) {
fprintf(stderr , "%s: error: unable to load model\n" , __func__);
return 1;
}
llama_context_params ctx_params = llama_context_default_params();
ctx_params.seed = 1234;
ctx_params.n_ctx = n_kv_max;
ctx_params.n_batch = 512;
ctx_params.mul_mat_q = mmq;
ctx_params.n_threads = params.n_threads;
ctx_params.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
llama_context * ctx = llama_new_context_with_model(model, ctx_params);
if (ctx == NULL) {
fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__);
return 1;
}
llama_batch batch = llama_batch_init(n_kv_max, 0);
// decode in batches of ctx_params.n_batch tokens
auto decode_helper = [](llama_context * ctx, llama_batch & batch, int32_t n_batch) {
for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) {
const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
llama_batch batch_view = {
n_tokens,
batch.token + i,
nullptr,
batch.pos + i,
batch.seq_id + i,
batch.logits + i,
0, 0, 0, // unused
};
const int ret = llama_decode(ctx, batch_view);
if (ret != 0) {
LOG_TEE("failed to decode the batch, n_batch = %d, ret = %d\n", n_batch, ret);
return false;
}
}
return true;
};
// warm up
{
batch.n_tokens = 16;
for (int i = 0; i < batch.n_tokens; ++i) {
batch.token[i] = 0;
batch.pos[i] = i;
batch.seq_id[i] = 0;
batch.logits[i] = false;
}
if (!decode_helper(ctx, batch, ctx_params.n_batch)) {
LOG_TEE("%s: llama_decode() failed\n", __func__);
return 1;
}
}
LOG_TEE("|%6s | %6s | %4s | %6s | %8s | %8s | %8s | %8s | %8s | %8s |\n", "PP", "TG", "B", "N_KV", "T_PP s", "S_PP t/s", "T_TG s", "S_TG t/s", "T s", "S t/s");
LOG_TEE("|%6s-|-%6s-|-%4s-|-%6s-|-%8s-|-%8s-|-%8s-|-%8s-|-%8s-|-%8s-|\n", "------", "------", "----", "------", "--------", "--------", "--------", "--------", "--------", "--------");
for ( int i_pp = 0; i_pp < (int) n_pp.size(); ++i_pp) {
for ( int i_tg = 0; i_tg < (int) n_tg.size(); ++i_tg) {
for (int i_pl = 0; i_pl < (int) n_pl.size(); ++i_pl) {
const int pp = n_pp[i_pp];
const int tg = n_tg[i_tg];
const int pl = n_pl[i_pl];
const int n_ctx_req = is_pp_shared ? pp + pl*tg : pl*(pp + tg);
if (n_ctx_req > n_kv_max) {
continue;
}
batch.n_tokens = is_pp_shared ? pp : pl*pp;
for (int i = 0; i < batch.n_tokens; ++i) {
batch.token[i] = 0;
batch.pos[i] = i;
batch.seq_id[i] = 0;
batch.logits[i] = false;
}
batch.logits[batch.n_tokens - 1] = true;
const auto t_pp_start = ggml_time_us();
llama_kv_cache_tokens_rm(ctx, -1, -1);
if (!decode_helper(ctx, batch, ctx_params.n_batch)) {
LOG_TEE("%s: llama_decode() failed\n", __func__);
return 1;
}
if (is_pp_shared) {
for (int32_t i = 1; i < pl; ++i) {
llama_kv_cache_seq_cp(ctx, 0, i, 0, pp);
}
}
const auto t_pp_end = ggml_time_us();
const auto t_tg_start = ggml_time_us();
for (int i = 0; i < tg; ++i) {
batch.n_tokens = pl;
for (int j = 0; j < pl; ++j) {
batch.token[j] = 0;
batch.pos[j] = pp + i;
batch.seq_id[j] = j;
batch.logits[j] = true;
}
if (!decode_helper(ctx, batch, ctx_params.n_batch)) {
LOG_TEE("%s: llama_decode() failed\n", __func__);
return 1;
}
}
const auto t_tg_end = ggml_time_us();
const int32_t n_kv = n_ctx_req;
const float t_pp = (t_pp_end - t_pp_start) / 1000000.0f;
const float t_tg = (t_tg_end - t_tg_start) / 1000000.0f;
const float t = t_pp + t_tg;
const float speed_pp = is_pp_shared ? pp / t_pp : pl*pp / t_pp;
const float speed_tg = pl*tg / t_tg;
const float speed = n_kv / t;
LOG_TEE("|%6d | %6d | %4d | %6d | %8.3f | %8.2f | %8.3f | %8.2f | %8.3f | %8.2f |\n", pp, tg, pl, n_kv, t_pp, speed_pp, t_tg, speed_tg, t, speed);
}
}
}
llama_print_timings(ctx);
llama_batch_free(batch);
llama_free(ctx);
llama_free_model(model);
llama_backend_free();
fprintf(stderr, "\n\n");
return 0;
}

View File

@ -128,21 +128,22 @@ bool eval_string(struct MyModel * mymodel,const char* str){
llama_token sampling_id(struct MyModel* mymodel) { llama_token sampling_id(struct MyModel* mymodel) {
llama_context* ctx = mymodel->ctx; llama_context* ctx = mymodel->ctx;
gpt_params params = mymodel->params; gpt_params params = mymodel->params;
llama_sampling_params & sparams = params.sampling_params;
// int n_ctx = llama_n_ctx(ctx); // int n_ctx = llama_n_ctx(ctx);
// out of user input, sample next token // out of user input, sample next token
const float temp = params.temp; const float temp = sparams.temp;
const int32_t top_k = params.top_k <= 0 ? llama_n_vocab(llama_get_model(ctx)) : params.top_k; const int32_t top_k = sparams.top_k <= 0 ? llama_n_vocab(llama_get_model(ctx)) : sparams.top_k;
const float top_p = params.top_p; const float top_p = sparams.top_p;
const float tfs_z = params.tfs_z; const float tfs_z = sparams.tfs_z;
const float typical_p = params.typical_p; const float typical_p = sparams.typical_p;
// const int32_t repeat_last_n = params.repeat_last_n < 0 ? n_ctx : params.repeat_last_n; // const int32_t repeat_last_n = params.repeat_last_n < 0 ? n_ctx : params.repeat_last_n;
// const float repeat_penalty = params.repeat_penalty; // const float repeat_penalty = params.repeat_penalty;
// const float alpha_presence = params.presence_penalty; // const float alpha_presence = params.presence_penalty;
// const float alpha_frequency = params.frequency_penalty; // const float alpha_frequency = params.frequency_penalty;
const int mirostat = params.mirostat; const int mirostat = sparams.mirostat;
const float mirostat_tau = params.mirostat_tau; const float mirostat_tau = sparams.mirostat_tau;
const float mirostat_eta = params.mirostat_eta; const float mirostat_eta = sparams.mirostat_eta;
// const bool penalize_nl = params.penalize_nl; // const bool penalize_nl = params.penalize_nl;
llama_token id = 0; llama_token id = 0;
@ -151,7 +152,7 @@ llama_token sampling_id(struct MyModel* mymodel) {
auto n_vocab = llama_n_vocab(llama_get_model(ctx)); auto n_vocab = llama_n_vocab(llama_get_model(ctx));
// Apply params.logit_bias map // Apply params.logit_bias map
for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) { for (auto it = sparams.logit_bias.begin(); it != sparams.logit_bias.end(); it++) {
logits[it->first] += it->second; logits[it->first] += it->second;
} }

View File

@ -104,6 +104,7 @@ static void sigint_handler(int signo) {
int main(int argc, char ** argv) { int main(int argc, char ** argv) {
gpt_params params; gpt_params params;
llama_sampling_params & sparams = params.sampling_params;
g_params = &params; g_params = &params;
if (!gpt_params_parse(argc, argv, params)) { if (!gpt_params_parse(argc, argv, params)) {
@ -206,7 +207,7 @@ int main(int argc, char ** argv) {
// load the model and apply lora adapter, if any // load the model and apply lora adapter, if any
LOG("%s: load the model and apply lora adapter, if any\n", __func__); LOG("%s: load the model and apply lora adapter, if any\n", __func__);
std::tie(model, ctx) = llama_init_from_gpt_params(params); std::tie(model, ctx) = llama_init_from_gpt_params(params);
if (params.cfg_scale > 1.f) { if (sparams.cfg_scale > 1.f) {
struct llama_context_params lparams = llama_context_params_from_gpt_params(params); struct llama_context_params lparams = llama_context_params_from_gpt_params(params);
ctx_guidance = llama_new_context_with_model(model, lparams); ctx_guidance = llama_new_context_with_model(model, lparams);
} }
@ -269,9 +270,9 @@ int main(int argc, char ** argv) {
int guidance_offset = 0; int guidance_offset = 0;
int original_prompt_len = 0; int original_prompt_len = 0;
if (ctx_guidance) { if (ctx_guidance) {
LOG("cfg_negative_prompt: \"%s\"\n", log_tostr(params.cfg_negative_prompt)); LOG("cfg_negative_prompt: \"%s\"\n", log_tostr(sparams.cfg_negative_prompt));
guidance_inp = ::llama_tokenize(ctx_guidance, params.cfg_negative_prompt, add_bos); guidance_inp = ::llama_tokenize(ctx_guidance, sparams.cfg_negative_prompt, add_bos);
LOG("guidance_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_guidance, guidance_inp)); LOG("guidance_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_guidance, guidance_inp));
std::vector<llama_token> original_inp = ::llama_tokenize(ctx, params.prompt, add_bos); std::vector<llama_token> original_inp = ::llama_tokenize(ctx, params.prompt, add_bos);
@ -312,7 +313,7 @@ int main(int argc, char ** argv) {
if (ctx_guidance) { if (ctx_guidance) {
LOG_TEE("\n"); LOG_TEE("\n");
LOG_TEE("%s: negative prompt: '%s'\n", __func__, params.cfg_negative_prompt.c_str()); LOG_TEE("%s: negative prompt: '%s'\n", __func__, sparams.cfg_negative_prompt.c_str());
LOG_TEE("%s: number of tokens in negative prompt = %zu\n", __func__, guidance_inp.size()); LOG_TEE("%s: number of tokens in negative prompt = %zu\n", __func__, guidance_inp.size());
for (int i = 0; i < (int) guidance_inp.size(); i++) { for (int i = 0; i < (int) guidance_inp.size(); i++) {
LOG_TEE("%6d -> '%s'\n", guidance_inp[i], llama_token_to_piece(ctx, guidance_inp[i]).c_str()); LOG_TEE("%6d -> '%s'\n", guidance_inp[i], llama_token_to_piece(ctx, guidance_inp[i]).c_str());
@ -358,7 +359,7 @@ int main(int argc, char ** argv) {
} }
} }
LOG_TEE("sampling: repeat_last_n = %d, repeat_penalty = %f, presence_penalty = %f, frequency_penalty = %f, top_k = %d, tfs_z = %f, top_p = %f, typical_p = %f, temp = %f, mirostat = %d, mirostat_lr = %f, mirostat_ent = %f\n", LOG_TEE("sampling: repeat_last_n = %d, repeat_penalty = %f, presence_penalty = %f, frequency_penalty = %f, top_k = %d, tfs_z = %f, top_p = %f, typical_p = %f, temp = %f, mirostat = %d, mirostat_lr = %f, mirostat_ent = %f\n",
params.repeat_last_n, params.repeat_penalty, params.presence_penalty, params.frequency_penalty, params.top_k, params.tfs_z, params.top_p, params.typical_p, params.temp, params.mirostat, params.mirostat_eta, params.mirostat_tau); sparams.repeat_last_n, sparams.repeat_penalty, sparams.presence_penalty, sparams.frequency_penalty, sparams.top_k, sparams.tfs_z, sparams.top_p, sparams.typical_p, sparams.temp, sparams.mirostat, sparams.mirostat_eta, sparams.mirostat_tau);
LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep); LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
LOG_TEE("\n\n"); LOG_TEE("\n\n");
@ -376,8 +377,8 @@ int main(int argc, char ** argv) {
LOG_TEE("\n"); LOG_TEE("\n");
{ {
auto it = params.logit_bias.find(llama_token_eos(ctx)); auto it = sparams.logit_bias.find(llama_token_eos(ctx));
if (it != params.logit_bias.end() && it->second == -INFINITY) { if (it != sparams.logit_bias.end() && it->second == -INFINITY) {
LOG_TEE("%s: warning: EOS token is disabled, which will cause most grammars to fail\n", __func__); LOG_TEE("%s: warning: EOS token is disabled, which will cause most grammars to fail\n", __func__);
} }
} }
@ -434,6 +435,7 @@ int main(int argc, char ** argv) {
const int n_vocab = llama_n_vocab(model); const int n_vocab = llama_n_vocab(model);
llama_sampling_context ctx_sampling = llama_sampling_context_init(params, grammar);
std::vector<llama_token_data> candidates; std::vector<llama_token_data> candidates;
candidates.reserve(n_vocab); candidates.reserve(n_vocab);
@ -552,7 +554,7 @@ int main(int argc, char ** argv) {
if ((int) embd_inp.size() <= n_consumed && !is_interacting) { if ((int) embd_inp.size() <= n_consumed && !is_interacting) {
const llama_token id = llama_sample_token(ctx, ctx_guidance, grammar, params, last_tokens, candidates); const llama_token id = llama_sampling_sample(ctx, ctx_guidance, ctx_sampling, last_tokens, candidates);
last_tokens.erase(last_tokens.begin()); last_tokens.erase(last_tokens.begin());
last_tokens.push_back(id); last_tokens.push_back(id);

View File

@ -0,0 +1,20 @@
set(TARGET clip)
add_library(${TARGET} clip.cpp clip.h)
install(TARGETS ${TARGET} LIBRARY)
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PRIVATE cxx_std_11)
if (NOT MSVC)
target_compile_options(${TARGET} PRIVATE -Wno-cast-qual) # stb_image.h
endif()
if(TARGET BUILD_INFO)
add_dependencies(${TARGET} BUILD_INFO)
endif()
set(TARGET llava)
add_executable(${TARGET} llava.cpp)
install(TARGETS ${TARGET} RUNTIME)
target_link_libraries(${TARGET} PRIVATE common llama clip ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PRIVATE cxx_std_11)
if(TARGET BUILD_INFO)
add_dependencies(${TARGET} BUILD_INFO)
endif()

57
examples/llava/README.md Normal file
View File

@ -0,0 +1,57 @@
# LLaVA
Currently this implementation supports [llava-v1.5](https://huggingface.co/liuhaotian/llava-v1.5-7b) variants.
The pre-converted [7b](https://huggingface.co/mys/ggml_llava-v1.5-7b)
and [13b](https://huggingface.co/mys/ggml_llava-v1.5-13b)
models are available.
After API is confirmed, more models will be supported / uploaded.
## Usage
Build with cmake or run `make llava` to build it.
After building, run: `./llava` to see the usage. For example:
```sh
./llava -m llava-v1.5-7b/ggml-model-q5_k.gguf --mmproj llava-v1.5-7b/mmproj-model-f16.gguf --image path/to/an/image.jpg
```
**note**: A lower temperature like 0.1 is recommended for better quality. add `--temp 0.1` to the command to do so.
## Model conversion
- Clone `llava-v15-7b`` and `clip-vit-large-patch14-336`` locally:
```sh
git clone https://huggingface.co/liuhaotian/llava-v1.5-7b
git clone https://huggingface.co/openai/clip-vit-large-patch14-336
```
2. Use `llava-surgery.py` to split the LLaVA model to LLaMA and multimodel projector constituents:
```sh
python ./examples/llava/llava-surgery.py -m ../llava-v1.5-7b
```
3. Use `convert-image-encoder-to-gguf.py` to convert the LLaVA image encoder to GGUF:
```sh
python ./examples/llava/convert-image-encoder-to-gguf -m ../clip-vit-large-patch14-336 --llava-projector ../llava-v1.5-7b/llava.projector --output-dir ../llava-v1.5-7b
```
4. Use `convert.py` to convert the LLaMA part of LLaVA to GGUF:
```sh
python ./convert.py ../llava-v1.5-7b
```
Now both the LLaMA part and the image encoder is in the `llava-v1.5-7b` directory.
## TODO
- [ ] Support server mode.
- [ ] Support non-CPU backend for the image encoding part.
- [ ] Support different sampling methods.
- [ ] Support more model variants.

1062
examples/llava/clip.cpp Normal file

File diff suppressed because it is too large Load Diff

73
examples/llava/clip.h Normal file
View File

@ -0,0 +1,73 @@
#ifndef CLIP_H
#define CLIP_H
#include "ggml.h"
struct clip_ctx;
#ifdef __cplusplus
extern "C" {
#endif
struct clip_vision_hparams {
int32_t image_size;
int32_t patch_size;
int32_t hidden_size;
int32_t n_intermediate;
int32_t projection_dim;
int32_t n_head;
int32_t n_layer;
float eps;
};
struct clip_ctx * clip_model_load(const char * fname, const int verbosity);
void clip_free(struct clip_ctx * ctx);
size_t clip_embd_nbytes(struct clip_ctx * ctx);
int clip_n_patches(struct clip_ctx * ctx);
int clip_n_mmproj_embd(struct clip_ctx * ctx);
// RGB uint8 image
struct clip_image_u8 {
int nx;
int ny;
uint8_t * data;
size_t size;
};
// RGB float32 image (NHWC)
// Memory layout: RGBRGBRGB...
struct clip_image_f32 {
int nx;
int ny;
float * data;
size_t size;
};
struct clip_image_u8_batch {
struct clip_image_u8 * data;
size_t size;
};
struct clip_image_f32_batch {
struct clip_image_f32 * data;
size_t size;
};
struct clip_image_u8 * make_clip_image_u8();
struct clip_image_f32 * make_clip_image_f32();
bool clip_image_load_from_file(const char * fname, struct clip_image_u8 * img);
bool clip_image_preprocess(const struct clip_ctx * ctx, const struct clip_image_u8 * img, struct clip_image_f32 * res, const bool pad2square);
bool clip_image_encode(const struct clip_ctx * ctx, const int n_threads, struct clip_image_f32 * img, float * vec);
bool clip_image_batch_encode(const struct clip_ctx * ctx, const int n_threads, const struct clip_image_f32_batch * imgs,
float * vec);
bool clip_model_quantize(const char * fname_inp, const char * fname_out, const int itype);
#ifdef __cplusplus
}
#endif
#endif // CLIP_H

View File

@ -0,0 +1,250 @@
import argparse
import os
import json
import torch
import numpy as np
from gguf import *
from transformers import CLIPModel, CLIPProcessor
TEXT = "clip.text"
VISION = "clip.vision"
def k(raw_key: str, arch: str) -> str:
return raw_key.format(arch=arch)
def should_skip_tensor(name: str, has_text: bool, has_vision: bool, has_llava: bool) -> bool:
if name in (
"logit_scale",
"text_model.embeddings.position_ids",
"vision_model.embeddings.position_ids",
):
return True
if has_llava and name in ["visual_projection.weight", "vision_model.post_layernorm.weight", "vision_model.post_layernorm.bias"]:
return True
if name.startswith("v") and not has_vision:
return True
if name.startswith("t") and not has_text:
return True
return False
def get_tensor_name(name: str) -> str:
if "projection" in name:
return name
if "mm_projector" in name:
return name.replace("model.mm_projector", "mm")
return name.replace("text_model", "t").replace("vision_model", "v").replace("encoder.layers", "blk").replace("embeddings.", "").replace("_proj", "").replace("self_attn.", "attn_").replace("layer_norm", "ln").replace("layernorm", "ln").replace("mlp.fc1", "ffn_down").replace("mlp.fc2", "ffn_up").replace("embedding", "embd").replace("final", "post").replace("layrnorm", "ln")
def bytes_to_unicode():
"""
Returns list of utf-8 byte and a corresponding list of unicode strings.
The reversible bpe codes work on unicode strings.
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
This is a signficant percentage of your normal, say, 32K bpe vocab.
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
And avoids mapping to whitespace/control characters the bpe code barfs on.
"""
bs = (
list(range(ord("!"), ord("~") + 1))
+ list(range(ord("¡"), ord("¬") + 1))
+ list(range(ord("®"), ord("ÿ") + 1))
)
cs = bs[:]
n = 0
for b in range(2**8):
if b not in bs:
bs.append(b)
cs.append(2**8 + n)
n += 1
cs = [chr(n) for n in cs]
return dict(zip(bs, cs))
ap = argparse.ArgumentParser(prog="convert_hf_to_gguf.py")
ap.add_argument("-m", "--model-dir", help="Path to model directory cloned from HF Hub", required=True)
ap.add_argument("--use-f32", action="store_true", default=False, help="Use f32 instead of f16")
ap.add_argument("--text-only", action="store_true", required=False,
help="Save a text-only model. It can't be used to encode images")
ap.add_argument("--vision-only", action="store_true", required=False,
help="Save a vision-only model. It can't be used to encode texts")
ap.add_argument("--llava-projector", help="Path to llava.projector file. If specified, save an image encoder for LLaVA models.")
ap.add_argument("--image-mean", nargs=3, type=float, required=False, help="Override image mean values")
ap.add_argument("--image-std", nargs=3, type=float, required=False, help="Override image std values")
ap.add_argument("-o", "--output-dir", help="Directory to save GGUF files. Default is the original model directory", default=None)
args = ap.parse_args()
if args.text_only and args.vision_only:
print("--text-only and --image-only arguments cannot be specified at the same time.")
exit(1)
if args.use_f32:
print("WARNING: Weights for the convolution op is always saved in f16, as the convolution op in GGML does not support 32-bit kernel weights yet.")
# output in the same directory as the model if output_dir is None
dir_model = args.model_dir
with open(dir_model + "/vocab.json", "r", encoding="utf-8") as f:
vocab = json.load(f)
tokens = [key for key in vocab]
with open(dir_model + "/config.json", "r", encoding="utf-8") as f:
config = json.load(f)
v_hparams = config["vision_config"]
t_hparams = config["text_config"]
# possible data types
# ftype == 0 -> float32
# ftype == 1 -> float16
#
# map from ftype to string
ftype_str = ["f32", "f16"]
ftype = 1
if args.use_f32:
ftype = 0
model = CLIPModel.from_pretrained(dir_model)
processor = CLIPProcessor.from_pretrained(dir_model)
fname_middle = None
has_text_encoder = True
has_vision_encoder = True
has_llava_projector = False
if args.text_only:
fname_middle = "text-"
has_vision_encoder = False
elif args.vision_only:
fname_middle = "vision-"
has_text_encoder = False
elif args.llava_projector is not None:
fname_middle = "mmproj-"
has_text_encoder = False
has_llava_projector = True
else:
fname_middle = ""
output_dir = args.output_dir if args.output_dir is not None else dir_model
os.makedirs(output_dir, exist_ok=True)
output_prefix = os.path.basename(output_dir).replace("ggml_", "")
fname_out = os.path.join(output_dir, f"{fname_middle}model-{ftype_str[ftype]}.gguf")
fout = GGUFWriter(path=fname_out, arch="clip")
fout.add_bool("clip.has_text_encoder", has_text_encoder)
fout.add_bool("clip.has_vision_encoder", has_vision_encoder)
fout.add_bool("clip.has_llava_projector", has_llava_projector)
fout.add_file_type(ftype)
model_name = config["_name_or_path"] if "_name_or_path" in config else os.path.basename(dir_model)
fout.add_name(model_name)
if args.text_only:
fout.add_description("text-only CLIP model")
elif args.vision_only and not has_llava_projector:
fout.add_description("vision-only CLIP model")
elif has_llava_projector:
fout.add_description("image encoder for LLaVA")
else:
fout.add_description("two-tower CLIP model")
if has_text_encoder:
# text_model hparams
fout.add_uint32(k(KEY_CONTEXT_LENGTH, TEXT), t_hparams["max_position_embeddings"])
fout.add_uint32(k(KEY_EMBEDDING_LENGTH, TEXT), t_hparams["hidden_size"])
fout.add_uint32(k(KEY_FEED_FORWARD_LENGTH, TEXT), t_hparams["intermediate_size"])
fout.add_uint32("clip.text.projection_dim", t_hparams.get("projection_dim", config["projection_dim"]))
fout.add_uint32(k(KEY_ATTENTION_HEAD_COUNT, TEXT), t_hparams["num_attention_heads"])
fout.add_float32(k(KEY_ATTENTION_LAYERNORM_EPS, TEXT), t_hparams["layer_norm_eps"])
fout.add_uint32(k(KEY_BLOCK_COUNT, TEXT), t_hparams["num_hidden_layers"])
fout.add_token_list(tokens)
if has_vision_encoder:
# vision_model hparams
fout.add_uint32("clip.vision.image_size", v_hparams["image_size"])
fout.add_uint32("clip.vision.patch_size", v_hparams["patch_size"])
fout.add_uint32(k(KEY_EMBEDDING_LENGTH, VISION), v_hparams["hidden_size"])
fout.add_uint32(k(KEY_FEED_FORWARD_LENGTH, VISION), v_hparams["intermediate_size"])
fout.add_uint32("clip.vision.projection_dim", v_hparams.get("projection_dim", config["projection_dim"]))
fout.add_uint32(k(KEY_ATTENTION_HEAD_COUNT, VISION), v_hparams["num_attention_heads"])
fout.add_float32(k(KEY_ATTENTION_LAYERNORM_EPS, VISION), v_hparams["layer_norm_eps"])
block_count = v_hparams["num_hidden_layers"] - 1 if has_llava_projector else v_hparams["num_hidden_layers"]
fout.add_uint32(k(KEY_BLOCK_COUNT, VISION), block_count)
image_mean = processor.image_processor.image_mean if args.image_mean is None else args.image_mean
image_std = processor.image_processor.image_std if args.image_std is None else args.image_std
fout.add_array("clip.vision.image_mean", image_mean)
fout.add_array("clip.vision.image_std", image_std)
use_gelu = v_hparams["hidden_act"] == "gelu"
fout.add_bool("clip.use_gelu", use_gelu)
if has_llava_projector:
model.vision_model.encoder.layers.pop(-1)
projector = torch.load(args.llava_projector)
for name, data in projector.items():
name = get_tensor_name(name)
if data.ndim == 2:
data = data.squeeze().numpy().astype(np.float16)
else:
data = data.squeeze().numpy().astype(np.float32)
fout.add_tensor(name, data)
print("Projector tensors added\n")
state_dict = model.state_dict()
for name, data in state_dict.items():
if should_skip_tensor(name, has_text_encoder, has_vision_encoder, has_llava_projector):
# we don't need this
print(f"skipping parameter: {name}")
continue
name = get_tensor_name(name)
data = data.squeeze().numpy()
n_dims = len(data.shape)
# ftype == 0 -> float32, ftype == 1 -> float16
ftype_cur = 0
if n_dims == 4:
print(f"tensor {name} is always saved in f16")
data = data.astype(np.float16)
ftype_cur = 1
elif ftype == 1:
if name[-7:] == ".weight" and n_dims == 2:
print(" Converting to float16")
data = data.astype(np.float16)
ftype_cur = 1
else:
print(" Converting to float32")
data = data.astype(np.float32)
ftype_cur = 0
else:
if data.dtype != np.float32:
print(" Converting to float32")
data = data.astype(np.float32)
ftype_cur = 0
print(f"{name} - {ftype_str[ftype_cur]} - shape = {data.shape}")
fout.add_tensor(name, data)
fout.write_header_to_file()
fout.write_kv_data_to_file()
fout.write_tensors_to_file()
fout.close()
print("Done. Output file: " + fname_out)

View File

@ -0,0 +1,30 @@
import argparse
import glob
import os
import torch
ap = argparse.ArgumentParser()
ap.add_argument("-m", "--model", help="Path to LLaVA v1.5 model")
args = ap.parse_args()
# find the model part that includes the the multimodal projector weights
path = sorted(glob.glob(f"{args.model}/pytorch_model*.bin"))[-1]
checkpoint = torch.load(path)
# get a list of mm tensor names
mm_tensors = [k for k, v in checkpoint.items() if k.startswith("model.mm_projector")]
# store these tensors in a new dictionary and torch.save them
projector = {name: checkpoint[name] for name in mm_tensors}
torch.save(projector, f"{args.model}/llava.projector")
# remove these tensors from the checkpoint and save it again
for name in mm_tensors:
del checkpoint[name]
torch.save(checkpoint, path)
print("Done!")
print(f"Now you can convert {args.model} to a a regular LLaMA GGUF file.")
print(f"Also, use {args.model}/llava.projector to prepare a llava-encoder.gguf file.")

View File

@ -0,0 +1,145 @@
#pragma once
// this one and clip lib will be eventually merged to a single lib, let's keep it this way for now
#include "common.h"
#include "llama.h"
#include <cstdio>
#include <cstdlib>
#include <vector>
inline bool eval_image_embd(llama_context * ctx_llama, float * embd, int N, int n_batch, int * n_past) {
int n_embd = llama_n_embd(llama_get_model(ctx_llama));
for (int i = 0; i < N; i += n_batch) {
int n_eval = N - i;
if (n_eval > n_batch) {
n_eval = n_batch;
}
llama_batch batch = {int32_t(n_eval), nullptr, (embd+i*n_embd), nullptr, nullptr, nullptr, *n_past, 1, 0, };
if (llama_decode(ctx_llama, batch)) {
fprintf(stderr, "%s : failed to eval\n", __func__);
return false;
}
*n_past += n_eval;
}
return true;
}
inline bool eval_tokens(struct llama_context * ctx_llama, std::vector<llama_token> tokens, int n_batch, int * n_past) {
int N = (int) tokens.size();
for (int i = 0; i < N; i += n_batch) {
int n_eval = (int) tokens.size() - i;
if (n_eval > n_batch) {
n_eval = n_batch;
}
if (llama_decode(ctx_llama, llama_batch_get_one(&tokens[i], n_eval, *n_past, 0))) {
fprintf(stderr, "%s : failed to eval\n", __func__);
return false;
}
*n_past += n_eval;
}
return true;
}
inline bool eval_id(struct llama_context * ctx_llama, int id, int * n_past) {
std::vector<llama_token> tokens;
tokens.push_back(id);
return eval_tokens(ctx_llama, tokens, 1, n_past);
}
inline bool eval_string(struct llama_context * ctx_llama, const char* str, int n_batch, int * n_past){
std::string str2 = str;
std::vector<llama_token> embd_inp = ::llama_tokenize(ctx_llama, str2, true);
eval_tokens(ctx_llama, embd_inp, n_batch, n_past);
return true;
}
// TODO: use common/sampling.h
inline llama_token sample_id(llama_context * ctx_llama, gpt_params & params) {
// out of user input, sample next token
const float temp = params.sampling_params.temp;
const int32_t top_k = params.sampling_params.top_k <= 0 ? llama_n_vocab(llama_get_model(ctx_llama)) : params.sampling_params.top_k;
const float top_p = params.sampling_params.top_p;
const float tfs_z = params.sampling_params.tfs_z;
const float typical_p = params.sampling_params.typical_p;
// const int32_t repeat_last_n = params.sampling_params.repeat_last_n < 0 ? n_ctx : params.sampling_params.repeat_last_n;
// const float repeat_penalty = params.sampling_params.repeat_penalty;
// const float alpha_presence = params.sampling_params.presence_penalty;
// const float alpha_frequency = params.sampling_params.frequency_penalty;
const int mirostat = params.sampling_params.mirostat;
const float mirostat_tau = params.sampling_params.mirostat_tau;
const float mirostat_eta = params.sampling_params.mirostat_eta;
// const bool penalize_nl = params.sampling_params.penalize_nl;
llama_token id = 0;
{
auto logits = llama_get_logits(ctx_llama);
auto n_vocab = llama_n_vocab(llama_get_model(ctx_llama));
// Apply params.logit_bias map
for (auto it = params.sampling_params.logit_bias.begin(); it != params.sampling_params.logit_bias.end(); it++) {
logits[it->first] += it->second;
}
std::vector<llama_token_data> candidates;
candidates.reserve(n_vocab);
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
}
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
// TODO: Apply penalties
// float nl_logit = logits[llama_token_nl(ctx)];
// auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), n_ctx);
// llama_sample_repetition_penalty(ctx, &candidates_p,
// last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
// last_n_repeat, repeat_penalty);
// llama_sample_frequency_and_presence_penalties(ctx, &candidates_p,
// last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
// last_n_repeat, alpha_frequency, alpha_presence);
// if (!penalize_nl) {
// logits[llama_token_nl(ctx)] = nl_logit;
// }
if (temp <= 0) {
// Greedy sampling
id = llama_sample_token_greedy(ctx_llama, &candidates_p);
} else {
if (mirostat == 1) {
static float mirostat_mu = 2.0f * mirostat_tau;
const int mirostat_m = 100;
llama_sample_temp(ctx_llama, &candidates_p, temp);
id = llama_sample_token_mirostat(ctx_llama, &candidates_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu);
} else if (mirostat == 2) {
static float mirostat_mu = 2.0f * mirostat_tau;
llama_sample_temp(ctx_llama, &candidates_p, temp);
id = llama_sample_token_mirostat_v2(ctx_llama, &candidates_p, mirostat_tau, mirostat_eta, &mirostat_mu);
} else {
// Temperature sampling
llama_sample_top_k(ctx_llama, &candidates_p, top_k, 1);
llama_sample_tail_free(ctx_llama, &candidates_p, tfs_z, 1);
llama_sample_typical(ctx_llama, &candidates_p, typical_p, 1);
llama_sample_top_p(ctx_llama, &candidates_p, top_p, 1);
llama_sample_temp(ctx_llama, &candidates_p, temp);
id = llama_sample_token(ctx_llama, &candidates_p);
}
}
}
return id;
}
inline const char * sample(struct llama_context * ctx_llama, gpt_params & params, int * n_past) {
int id = sample_id(ctx_llama, params);
static std::string ret;
if (id == llama_token_eos(ctx_llama)) {
ret = "</s>";
} else {
ret = llama_token_to_piece(ctx_llama, id);
}
eval_id(ctx_llama, id, n_past);
return ret.c_str();
}

156
examples/llava/llava.cpp Normal file
View File

@ -0,0 +1,156 @@
#include "clip.h"
#include "llava-utils.h"
#include "common.h"
#include "llama.h"
#include <cstdio>
#include <cstdlib>
#include <vector>
static void show_additional_info(int /*argc*/, char ** argv) {
printf("\n example usage: %s -m <llava-v1.5-7b/ggml-model-q5_k.gguf> --mmproj <llava-v1.5-7b/mmproj-model-f16.gguf> --image <path/to/an/image.jpg> [--temp 0.1] [-p \"describe the image in detail.\"]\n", argv[0]);
printf(" note: a lower temperature value like 0.1 is recommended for better quality.\n");
}
int main(int argc, char ** argv) {
ggml_time_init();
gpt_params params;
if (!gpt_params_parse(argc, argv, params)) {
show_additional_info(argc, argv);
return 1;
}
if (params.mmproj.empty() || params.image.empty()) {
gpt_print_usage(argc, argv, params);
show_additional_info(argc, argv);
return 1;
}
const char * clip_path = params.mmproj.c_str();
const char * img_path = params.image.c_str();
if (params.prompt.empty()) {
params.prompt = "describe the image in detail.";
}
auto ctx_clip = clip_model_load(clip_path, /*verbosity=*/ 1);
// load and preprocess the image
clip_image_u8 img;
clip_image_f32 img_res;
if (!clip_image_load_from_file(img_path, &img)) {
fprintf(stderr, "%s: is %s really an image file?\n", __func__, img_path);
clip_free(ctx_clip);
return 1;
}
if (!clip_image_preprocess(ctx_clip, &img, &img_res, /*pad2square =*/ true)) {
fprintf(stderr, "%s: unable to preprocess %s\n", __func__, img_path);
clip_free(ctx_clip);
return 1;
}
int n_img_pos = clip_n_patches(ctx_clip);
int n_img_embd = clip_n_mmproj_embd(ctx_clip);
float * image_embd = (float *)malloc(clip_embd_nbytes(ctx_clip));
if (!image_embd) {
fprintf(stderr, "Unable to allocate memory for image embeddings\n");
return 1;
}
const int64_t t_img_enc_start_us = ggml_time_us();
if (!clip_image_encode(ctx_clip, params.n_threads, &img_res, image_embd)) {
fprintf(stderr, "Unable to encode image\n");
return 1;
}
const int64_t t_img_enc_end_us = ggml_time_us();
// we get the embeddings, free up the memory required for CLIP
clip_free(ctx_clip);
llama_backend_init(params.numa);
llama_model_params model_params = llama_model_default_params();
llama_model * model = llama_load_model_from_file(params.model.c_str(), model_params);
if (model == NULL) {
fprintf(stderr , "%s: error: unable to load model\n" , __func__);
return 1;
}
llama_context_params ctx_params = llama_context_default_params();
ctx_params.n_ctx = params.n_ctx < 2048 ? 2048 : params.n_ctx; // we need a longer context size to process image embeddings
ctx_params.n_threads = params.n_threads;
ctx_params.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
llama_context * ctx_llama = llama_new_context_with_model(model, ctx_params);
if (ctx_llama == NULL) {
fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__);
return 1;
}
// make sure that the correct mmproj was used, i.e., compare apples to apples
int n_llama_embd = llama_n_embd(llama_get_model(ctx_llama));
if (n_img_embd != n_llama_embd) {
printf("%s: embedding dim of the multimodal projector (%d) is not equal to that of LLaMA (%d). Make sure that you use the correct mmproj file.\n", __func__, n_img_embd, n_llama_embd);
llama_free(ctx_llama);
llama_free_model(model);
llama_backend_free();
free(image_embd);
return 1;
}
// process the prompt
// llava chat format is "<system_prompt>USER: <image_embeddings>\n<textual_prompt>\nASSISTANT:"
int n_past = 0;
const int max_tgt_len = params.n_predict < 0 ? 256 : params.n_predict;
// GG: are we sure that the should be a trailing whitespace at the end of this string?
eval_string(ctx_llama, "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\nUSER: ", params.n_batch, &n_past);
eval_image_embd(ctx_llama, image_embd, n_img_pos, params.n_batch, &n_past);
eval_string(ctx_llama, params.prompt.c_str(), params.n_batch, &n_past);
eval_string(ctx_llama, "\nASSISTANT:", params.n_batch, &n_past);
// generate the response
printf("\n");
for (int i = 0; i < max_tgt_len; i++) {
const char * tmp = sample(ctx_llama, params, &n_past);
if (strcmp(tmp, "</s>") == 0) break;
printf("%s", tmp);
fflush(stdout);
}
printf("\n");
{
const float t_img_enc_ms = (t_img_enc_end_us - t_img_enc_start_us) / 1000.0;
printf("\n%s: image encoded in %8.2f ms by CLIP (%8.2f ms per image patch)\n", __func__, t_img_enc_ms, t_img_enc_ms / n_img_pos);
}
llama_print_timings(ctx_llama);
llama_free(ctx_llama);
llama_free_model(model);
llama_backend_free();
free(image_embd);
return 0;
}

View File

@ -109,6 +109,7 @@ int main(int argc, char ** argv) {
if (!gpt_params_parse(argc, argv, params)) { if (!gpt_params_parse(argc, argv, params)) {
return 1; return 1;
} }
llama_sampling_params & sparams = params.sampling_params;
#ifndef LOG_DISABLE_LOGS #ifndef LOG_DISABLE_LOGS
log_set_target(log_filename_generator("main", "log")); log_set_target(log_filename_generator("main", "log"));
@ -179,7 +180,7 @@ int main(int argc, char ** argv) {
// load the model and apply lora adapter, if any // load the model and apply lora adapter, if any
LOG("%s: load the model and apply lora adapter, if any\n", __func__); LOG("%s: load the model and apply lora adapter, if any\n", __func__);
std::tie(model, ctx) = llama_init_from_gpt_params(params); std::tie(model, ctx) = llama_init_from_gpt_params(params);
if (params.cfg_scale > 1.f) { if (sparams.cfg_scale > 1.f) {
struct llama_context_params lparams = llama_context_params_from_gpt_params(params); struct llama_context_params lparams = llama_context_params_from_gpt_params(params);
ctx_guidance = llama_new_context_with_model(model, lparams); ctx_guidance = llama_new_context_with_model(model, lparams);
} }
@ -257,9 +258,9 @@ int main(int argc, char ** argv) {
int guidance_offset = 0; int guidance_offset = 0;
int original_prompt_len = 0; int original_prompt_len = 0;
if (ctx_guidance) { if (ctx_guidance) {
LOG("cfg_negative_prompt: \"%s\"\n", log_tostr(params.cfg_negative_prompt)); LOG("cfg_negative_prompt: \"%s\"\n", log_tostr(sparams.cfg_negative_prompt));
guidance_inp = ::llama_tokenize(ctx_guidance, params.cfg_negative_prompt, add_bos); guidance_inp = ::llama_tokenize(ctx_guidance, sparams.cfg_negative_prompt, add_bos);
LOG("guidance_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_guidance, guidance_inp)); LOG("guidance_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_guidance, guidance_inp));
std::vector<llama_token> original_inp = ::llama_tokenize(ctx, params.prompt, add_bos); std::vector<llama_token> original_inp = ::llama_tokenize(ctx, params.prompt, add_bos);
@ -296,6 +297,9 @@ int main(int argc, char ** argv) {
LOG_TEE("%s: session file matches %zu / %zu tokens of prompt\n", LOG_TEE("%s: session file matches %zu / %zu tokens of prompt\n",
__func__, n_matching_session_tokens, embd_inp.size()); __func__, n_matching_session_tokens, embd_inp.size());
} }
// remove any "future" tokens that we might have inherited from the previous session
llama_kv_cache_tokens_rm(ctx, n_matching_session_tokens, -1);
} }
LOGLN( LOGLN(
@ -343,7 +347,7 @@ int main(int argc, char ** argv) {
if (ctx_guidance) { if (ctx_guidance) {
LOG_TEE("\n"); LOG_TEE("\n");
LOG_TEE("%s: negative prompt: '%s'\n", __func__, params.cfg_negative_prompt.c_str()); LOG_TEE("%s: negative prompt: '%s'\n", __func__, sparams.cfg_negative_prompt.c_str());
LOG_TEE("%s: number of tokens in negative prompt = %zu\n", __func__, guidance_inp.size()); LOG_TEE("%s: number of tokens in negative prompt = %zu\n", __func__, guidance_inp.size());
for (int i = 0; i < (int) guidance_inp.size(); i++) { for (int i = 0; i < (int) guidance_inp.size(); i++) {
LOG_TEE("%6d -> '%s'\n", guidance_inp[i], llama_token_to_piece(ctx, guidance_inp[i]).c_str()); LOG_TEE("%6d -> '%s'\n", guidance_inp[i], llama_token_to_piece(ctx, guidance_inp[i]).c_str());
@ -395,7 +399,7 @@ int main(int argc, char ** argv) {
} }
} }
LOG_TEE("sampling: repeat_last_n = %d, repeat_penalty = %f, presence_penalty = %f, frequency_penalty = %f, top_k = %d, tfs_z = %f, top_p = %f, typical_p = %f, temp = %f, mirostat = %d, mirostat_lr = %f, mirostat_ent = %f\n", LOG_TEE("sampling: repeat_last_n = %d, repeat_penalty = %f, presence_penalty = %f, frequency_penalty = %f, top_k = %d, tfs_z = %f, top_p = %f, typical_p = %f, temp = %f, mirostat = %d, mirostat_lr = %f, mirostat_ent = %f\n",
params.repeat_last_n, params.repeat_penalty, params.presence_penalty, params.frequency_penalty, params.top_k, params.tfs_z, params.top_p, params.typical_p, params.temp, params.mirostat, params.mirostat_eta, params.mirostat_tau); sparams.repeat_last_n, sparams.repeat_penalty, sparams.presence_penalty, sparams.frequency_penalty, sparams.top_k, sparams.tfs_z, sparams.top_p, sparams.typical_p, sparams.temp, sparams.mirostat, sparams.mirostat_eta, sparams.mirostat_tau);
LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep); LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
LOG_TEE("\n\n"); LOG_TEE("\n\n");
@ -413,8 +417,8 @@ int main(int argc, char ** argv) {
LOG_TEE("\n"); LOG_TEE("\n");
{ {
auto it = params.logit_bias.find(llama_token_eos(ctx)); auto it = sparams.logit_bias.find(llama_token_eos(ctx));
if (it != params.logit_bias.end() && it->second == -INFINITY) { if (it != sparams.logit_bias.end() && it->second == -INFINITY) {
LOG_TEE("%s: warning: EOS token is disabled, which will cause most grammars to fail\n", __func__); LOG_TEE("%s: warning: EOS token is disabled, which will cause most grammars to fail\n", __func__);
} }
} }
@ -469,6 +473,7 @@ int main(int argc, char ** argv) {
const int n_vocab = llama_n_vocab(model); const int n_vocab = llama_n_vocab(model);
llama_sampling_context ctx_sampling = llama_sampling_context_init(params, grammar);
std::vector<llama_token_data> candidates; std::vector<llama_token_data> candidates;
candidates.reserve(n_vocab); candidates.reserve(n_vocab);
@ -543,9 +548,6 @@ int main(int argc, char ** argv) {
if (i > 0) { if (i > 0) {
embd.erase(embd.begin(), embd.begin() + i); embd.erase(embd.begin(), embd.begin() + i);
} }
// remove any "future" tokens that we might have inherited from the session from the KV cache
llama_kv_cache_tokens_rm(ctx, n_past, -1);
} }
// evaluate tokens in batches // evaluate tokens in batches
@ -625,7 +627,7 @@ int main(int argc, char ** argv) {
LOG("saved session to %s\n", path_session.c_str()); LOG("saved session to %s\n", path_session.c_str());
} }
const llama_token id = llama_sample_token(ctx, ctx_guidance, grammar, params, last_tokens, candidates); const llama_token id = llama_sampling_sample(ctx, ctx_guidance, ctx_sampling, last_tokens, candidates);
last_tokens.erase(last_tokens.begin()); last_tokens.erase(last_tokens.begin());
last_tokens.push_back(id); last_tokens.push_back(id);

View File

@ -125,6 +125,8 @@ int main(int argc, char ** argv) {
params.logits_all = true; params.logits_all = true;
std::tie(model, ctx) = llama_init_from_gpt_params(params); std::tie(model, ctx) = llama_init_from_gpt_params(params);
llama_sampling_context ctx_sampling = llama_sampling_context_init(params, NULL);
// load the prompts from an external file if there are any // load the prompts from an external file if there are any
if (params.prompt.empty()) { if (params.prompt.empty()) {
printf("\n\033[32mNo new questions so proceed with build-in defaults.\033[0m\n"); printf("\n\033[32mNo new questions so proceed with build-in defaults.\033[0m\n");
@ -339,7 +341,7 @@ int main(int argc, char ** argv) {
//printf("client %d, seq %d, token %d, pos %d, batch %d\n", //printf("client %d, seq %d, token %d, pos %d, batch %d\n",
// client.id, client.seq_id, client.sampled, client.n_decoded, client.i_batch); // client.id, client.seq_id, client.sampled, client.n_decoded, client.i_batch);
const llama_token id = llama_sample_token(ctx, NULL, NULL, params, client.tokens_prev, candidates, client.i_batch - i); const llama_token id = llama_sampling_sample(ctx, NULL, ctx_sampling, client.tokens_prev, candidates, client.i_batch - i, client.seq_id);
if (client.n_decoded == 1) { if (client.n_decoded == 1) {
// start measuring generation time after the first token to make sure all concurrent clients // start measuring generation time after the first token to make sure all concurrent clients
@ -384,7 +386,7 @@ int main(int argc, char ** argv) {
n_total_prompt += client.n_prompt; n_total_prompt += client.n_prompt;
n_total_gen += client.n_decoded; n_total_gen += client.n_decoded;
llama_sampling_context_reset(ctx_sampling, client.seq_id);
client.seq_id = -1; client.seq_id = -1;
} }

View File

@ -8,9 +8,10 @@
int main(int argc, char ** argv) { int main(int argc, char ** argv) {
gpt_params params; gpt_params params;
llama_sampling_params & sparams = params.sampling_params;
params.seed = 42; params.seed = 42;
params.n_threads = 4; params.n_threads = 4;
params.repeat_last_n = 64; sparams.repeat_last_n = 64;
params.prompt = "The quick brown fox"; params.prompt = "The quick brown fox";
if (!gpt_params_parse(argc, argv, params)) { if (!gpt_params_parse(argc, argv, params)) {
@ -24,7 +25,7 @@ int main(int argc, char ** argv) {
} }
auto n_past = 0; auto n_past = 0;
auto last_n_tokens_data = std::vector<llama_token>(params.repeat_last_n, 0); auto last_n_tokens_data = std::vector<llama_token>(sparams.repeat_last_n, 0);
// init // init
llama_model * model; llama_model * model;

File diff suppressed because it is too large Load Diff

View File

@ -136,6 +136,11 @@
display: block; display: block;
} }
fieldset label.slim {
margin: 0 0.5em;
display: inline;
}
header, footer { header, footer {
text-align: center; text-align: center;
} }
@ -145,6 +150,14 @@
color: #888; color: #888;
} }
.mode-chat textarea[name=prompt] {
height: 4.5em;
}
.mode-completion textarea[name=prompt] {
height: 10em;
}
@keyframes loading-bg-wipe { @keyframes loading-bg-wipe {
0% { 0% {
@ -187,7 +200,7 @@
template: "{{prompt}}\n\n{{history}}\n{{char}}:", template: "{{prompt}}\n\n{{history}}\n{{char}}:",
historyTemplate: "{{name}}: {{message}}", historyTemplate: "{{name}}: {{message}}",
transcript: [], transcript: [],
type: "chat", type: "chat", // "chat" | "completion"
char: "Llama", char: "Llama",
user: "User", user: "User",
}) })
@ -365,13 +378,44 @@
return String(str).replaceAll(/\{\{(.*?)\}\}/g, (_, key) => template(settings[key])); return String(str).replaceAll(/\{\{(.*?)\}\}/g, (_, key) => template(settings[key]));
} }
async function runLlama(prompt, llamaParams, char) {
const currentMessages = [];
const history = session.value.transcript;
if (controller.value) {
throw new Error("already running");
}
controller.value = new AbortController();
for await (const chunk of llama(prompt, llamaParams, {controller: controller.value})) {
const data = chunk.data;
if (data.stop) {
while (
currentMessages.length > 0 &&
currentMessages[currentMessages.length - 1].content.match(/\n$/) != null
) {
currentMessages.pop();
}
transcriptUpdate([...history, [char, currentMessages]])
console.log("Completion finished: '", currentMessages.map(msg => msg.content).join(''), "', summary: ", data);
} else {
currentMessages.push(data);
transcriptUpdate([...history, [char, currentMessages]])
}
if (data.timings) {
llamaStats.value = data.timings;
}
}
controller.value = null;
}
// send message to server // send message to server
const chat = async (msg) => { const chat = async (msg) => {
if (controller.value) { if (controller.value) {
console.log('already running...'); console.log('already running...');
return; return;
} }
controller.value = new AbortController();
transcriptUpdate([...session.value.transcript, ["{{user}}", msg]]) transcriptUpdate([...session.value.transcript, ["{{user}}", msg]])
@ -391,42 +435,25 @@
).join("\n"), ).join("\n"),
}); });
const currentMessages = []; await runLlama(prompt, {
const history = session.value.transcript
const llamaParams = {
...params.value, ...params.value,
stop: ["</s>", template("{{char}}:"), template("{{user}}:")], stop: ["</s>", template("{{char}}:"), template("{{user}}:")],
}, "{{char}}");
} }
for await (const chunk of llama(prompt, llamaParams, { controller: controller.value })) { const runCompletion = async () => {
const data = chunk.data; if (controller.value) {
console.log('already running...');
if (data.stop) { return;
while (
currentMessages.length > 0 &&
currentMessages[currentMessages.length - 1].content.match(/\n$/) != null
) {
currentMessages.pop();
} }
transcriptUpdate([...history, ["{{char}}", currentMessages]]) const {prompt} = session.value;
console.log("Completion finished: '", currentMessages.map(msg => msg.content).join(''), "', summary: ", data); transcriptUpdate([...session.value.transcript, ["", prompt]]);
} else { await runLlama(prompt, {
currentMessages.push(data); ...params.value,
transcriptUpdate([...history, ["{{char}}", currentMessages]]) stop: [],
}, "");
} }
if (data.timings) {
llamaStats.value = data.timings;
}
}
controller.value = null;
}
function MessageInput() {
const message = useSignal("")
const stop = (e) => { const stop = (e) => {
e.preventDefault(); e.preventDefault();
if (controller.value) { if (controller.value) {
@ -440,6 +467,9 @@
transcriptUpdate([]); transcriptUpdate([]);
} }
function MessageInput() {
const message = useSignal("")
const submit = (e) => { const submit = (e) => {
stop(e); stop(e);
chat(message.value); chat(message.value);
@ -474,6 +504,19 @@
` `
} }
function CompletionControls() {
const submit = (e) => {
stop(e);
runCompletion();
}
return html`
<div>
<button onclick=${submit} type="button" disabled=${generating.value}>Start</button>
<button onclick=${stop} disabled=${!generating.value}>Stop</button>
<button onclick=${reset}>Reset</button>
</div>`;
}
const ChatLog = (props) => { const ChatLog = (props) => {
const messages = session.value.transcript; const messages = session.value.transcript;
const container = useRef(null) const container = useRef(null)
@ -497,7 +540,11 @@
data; data;
message = html`<${Markdownish} text=${template(text)} />` message = html`<${Markdownish} text=${template(text)} />`
} }
if(user) {
return html`<p key=${index}><strong>${template(user)}:</strong> ${message}</p>` return html`<p key=${index}><strong>${template(user)}:</strong> ${message}</p>`
} else {
return html`<p key=${index}>${message}</p>`
}
}; };
return html` return html`
@ -574,18 +621,31 @@
userTemplateAutosave() userTemplateAutosave()
}, [session.value, params.value]) }, [session.value, params.value])
return html` const GrammarControl = () => (
<form> html`
<fieldset> <div>
<${UserTemplateResetButton}/> <label for="template">Grammar</label>
</fieldset> <textarea id="grammar" name="grammar" placeholder="Use gbnf or JSON Schema+convert" value="${params.value.grammar}" rows=4 oninput=${updateParams}/>
<input type="text" name="prop-order" placeholder="order: prop1,prop2,prop3" oninput=${updateGrammarJsonSchemaPropOrder} />
<button type="button" onclick=${convertJSONSchemaGrammar}>Convert JSON Schema</button>
</div>
`
);
const PromptControlFieldSet = () => (
html`
<fieldset> <fieldset>
<div> <div>
<label for="prompt">Prompt</label> <label htmlFor="prompt">Prompt</label>
<textarea type="text" name="prompt" value="${session.value.prompt}" rows=4 oninput=${updateSession}/> <textarea type="text" name="prompt" value="${session.value.prompt}" oninput=${updateSession}/>
</div> </div>
</fieldset> </fieldset>
`
);
const ChatConfigForm = () => (
html`
${PromptControlFieldSet()}
<fieldset class="two"> <fieldset class="two">
<div> <div>
@ -609,15 +669,30 @@
<label for="template">Chat history template</label> <label for="template">Chat history template</label>
<textarea id="template" name="historyTemplate" value="${session.value.historyTemplate}" rows=1 oninput=${updateSession}/> <textarea id="template" name="historyTemplate" value="${session.value.historyTemplate}" rows=1 oninput=${updateSession}/>
</div> </div>
${GrammarControl()}
</fieldset>
`
);
const CompletionConfigForm = () => (
html`
${PromptControlFieldSet()}
<fieldset>${GrammarControl()}</fieldset>
`
);
return html`
<form>
<fieldset class="two">
<${UserTemplateResetButton}/>
<div> <div>
<label for="template">Grammar</label> <label class="slim"><input type="radio" name="type" value="chat" checked=${session.value.type === "chat"} oninput=${updateSession} /> Chat</label>
<textarea id="grammar" name="grammar" placeholder="Use gbnf or JSON Schema+convert" value="${params.value.grammar}" rows=4 oninput=${updateParams}/> <label class="slim"><input type="radio" name="type" value="completion" checked=${session.value.type === "completion"} oninput=${updateSession} /> Completion</label>
<input type="text" name="prop-order" placeholder="order: prop1,prop2,prop3" oninput=${updateGrammarJsonSchemaPropOrder} />
<button type="button" onclick=${convertJSONSchemaGrammar}>Convert JSON Schema</button>
</div> </div>
</fieldset> </fieldset>
${session.value.type === 'chat' ? ChatConfigForm() : CompletionConfigForm()}
<fieldset class="two"> <fieldset class="two">
${IntField({label: "Predictions", max: 2048, min: -1, name: "n_predict", value: params.value.n_predict})} ${IntField({label: "Predictions", max: 2048, min: -1, name: "n_predict", value: params.value.n_predict})}
${FloatField({label: "Temperature", max: 1.5, min: 0.0, name: "temperature", step: 0.01, value: params.value.temperature})} ${FloatField({label: "Temperature", max: 1.5, min: 0.0, name: "temperature", step: 0.01, value: params.value.temperature})}
@ -851,7 +926,7 @@
function App(props) { function App(props) {
return html` return html`
<div> <div class="mode-${session.value.type}">
<header> <header>
<h1>llama.cpp</h1> <h1>llama.cpp</h1>
</header> </header>
@ -861,7 +936,7 @@
</main> </main>
<section id="write"> <section id="write">
<${MessageInput} /> <${session.value.type === 'chat' ? MessageInput : CompletionControls} />
</section> </section>
<footer> <footer>

View File

@ -380,6 +380,7 @@ struct llama_server_context
std::vector<llama_token_data> candidates; std::vector<llama_token_data> candidates;
bool all_slots_are_idle = false; bool all_slots_are_idle = false;
gpt_params params; gpt_params params;
llama_sampling_context ctx_sampling;
int n_ctx; int n_ctx;
int n_vocab; int n_vocab;
bool clean_kv_cache = true; bool clean_kv_cache = true;
@ -402,11 +403,29 @@ struct llama_server_context
llama_free_model(model); llama_free_model(model);
model = nullptr; model = nullptr;
} }
for(auto &slot : slots) {
if(slot.grammar) {
llama_grammar_free(slot.grammar);
} }
void rewind()
{
params.antiprompt.clear();
params.grammar.clear();
num_prompt_tokens = 0;
num_tokens_predicted = 0;
generated_text = "";
generated_text.reserve(n_ctx);
generated_token_probs.clear();
truncated = false;
stopped_eos = false;
stopped_word = false;
stopped_limit = false;
stopping_word = "";
multibyte_pending = 0;
n_remain = 0;
n_past = 0;
if (grammar != nullptr) {
llama_grammar_free(grammar);
grammar = nullptr;
} }
} }
@ -491,59 +510,28 @@ struct llama_server_context
return prompt_tokens; return prompt_tokens;
} }
void processPrompt() { bool loadGrammar()
//params.n_keep = std::min(n_ctx - 4, params.n_keep);
// if input prompt is too big, truncate like normal
// if (num_prompt_tokens >= (size_t)n_ctx)
// {
// const int n_left = (n_ctx - params.n_keep) / 2;
// std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + params.n_keep);
// const int erased_blocks = (num_prompt_tokens - params.n_keep - n_left - 1) / n_left;
// new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + params.n_keep + erased_blocks * n_left, prompt_tokens.end());
// std::copy(prompt_tokens.end() - n_ctx, prompt_tokens.end(), last_n_tokens.begin());
// LOG_VERBOSE("input truncated", {
// {"n_ctx", n_ctx},
// {"n_keep", params.n_keep},
// {"n_left", n_left},
// {"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())},
// });
// truncated = true;
// prompt_tokens = new_tokens;
// }
// else
// {
// const size_t ps = num_prompt_tokens;
// std::fill(last_n_tokens.begin(), last_n_tokens.end() - ps, 0);
// std::copy(prompt_tokens.begin(), prompt_tokens.end(), last_n_tokens.end() - ps);
// }
// compare the evaluated prompt with the new prompt
}
llama_client_slot* getSlot(int id) {
for (llama_client_slot & slot : slots)
{ {
if ((id == -1 && slot.available()) || slot.id == id) if (!params.grammar.empty()) {
{ parsed_grammar = grammar_parser::parse(params.grammar.c_str());
return &slot; // will be empty (default) if there are parse errors
} if (parsed_grammar.rules.empty()) {
} LOG_ERROR("grammar parse error", {{"grammar", params.grammar}});
return nullptr;
}
bool launchSlot(llama_client_slot* &slot) {
if(!slot->loadGrammar()) {
return false; return false;
} }
all_slots_are_idle = false; grammar_parser::print_grammar(stderr, parsed_grammar);
slot->command = LOAD_PROMPT;
LOG_TEE("slot %i is processing\n", slot->id); {
auto it = params.logit_bias.find(llama_token_eos(ctx));
if (it != params.logit_bias.end() && it->second == -INFINITY) {
LOG_WARNING("EOS token is disabled, which will cause most grammars to fail", {});
}
}
std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules());
grammar = llama_grammar_init(
grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
}
return true; return true;
} }
@ -604,15 +592,15 @@ struct llama_server_context
// std::copy(prompt_tokens.begin(), prompt_tokens.end(), last_n_tokens.end() - ps); // std::copy(prompt_tokens.begin(), prompt_tokens.end(), last_n_tokens.end() - ps);
// } // }
// // compare the evaluated prompt with the new prompt // compare the evaluated prompt with the new prompt
// n_past = common_part(embd, prompt_tokens); n_past = common_part(embd, prompt_tokens);
// embd = prompt_tokens; embd = prompt_tokens;
// if (n_past == num_prompt_tokens) if (n_past == num_prompt_tokens)
// { {
// // we have to evaluate at least 1 token to generate logits. // we have to evaluate at least 1 token to generate logits.
// printf("we have to evaluate at least 1 token to generate logits\n"); printf("we have to evaluate at least 1 token to generate logits\n");
// n_past--; n_past--;
// } }
// LOG_VERBOSE("prompt ingested", { // LOG_VERBOSE("prompt ingested", {
// {"n_past", n_past}, // {"n_past", n_past},
@ -629,77 +617,168 @@ struct llama_server_context
{ {
llama_kv_cache_seq_rm(ctx, i, 0, -1); llama_kv_cache_seq_rm(ctx, i, 0, -1);
} }
clean_kv_cache = false; params.n_keep = std::min(n_ctx - 4, params.n_keep);
}
void updateSystemPrompt() { // if input prompt is too big, truncate like normal
tokens_system = ::llama_tokenize(ctx, system_prompt, true); if (num_prompt_tokens >= (size_t)n_ctx)
n_tokens_system = tokens_system.size();
batch.n_tokens = n_tokens_system;
cleanKVCache();
for (int32_t i = 0; i < batch.n_tokens; ++i)
{ {
batch.token[i] = tokens_system[i]; const int n_left = (n_ctx - params.n_keep) / 2;
batch.pos[i] = i; std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + params.n_keep);
batch.seq_id[i] = 0; const int erased_blocks = (num_prompt_tokens - params.n_keep - n_left - 1) / n_left;
batch.logits[i] = false; new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + params.n_keep + erased_blocks * n_left, prompt_tokens.end());
} std::copy(prompt_tokens.end() - n_ctx, prompt_tokens.end(), last_n_tokens.begin());
if (llama_decode(ctx, batch) != 0) LOG_VERBOSE("input truncated", {
{"n_ctx", n_ctx},
{"n_keep", params.n_keep},
{"n_left", n_left},
{"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())},
});
truncated = true;
prompt_tokens = new_tokens;
}
else
{ {
LOG_TEE("%s: llama_decode() failed\n", __func__); const size_t ps = num_prompt_tokens;
return; std::fill(last_n_tokens.begin(), last_n_tokens.end() - ps, 0);
std::copy(prompt_tokens.begin(), prompt_tokens.end(), last_n_tokens.end() - ps);
} }
// assign the system KV cache to all parallel sequences // compare the evaluated prompt with the new prompt
for (int32_t i = 1; i < params.n_parallel; ++i) n_past = common_part(embd, prompt_tokens);
// since #3228 we now have to manually manage the KV cache
llama_kv_cache_seq_rm(ctx, 0, n_past, -1);
embd = prompt_tokens;
if (n_past == num_prompt_tokens)
{ {
llama_kv_cache_seq_cp(ctx, 0, i, 0, n_tokens_system); // we have to evaluate at least 1 token to generate logits.
n_past--;
} }
LOG_TEE("system prompt updated\n"); LOG_VERBOSE("prompt ingested", {
update_system_prompt = false; {"n_past", n_past},
{"cached", tokens_to_str(ctx, embd.cbegin(), embd.cbegin() + n_past)},
{"to_eval", tokens_to_str(ctx, embd.cbegin() + n_past, embd.cend())},
});
has_next_token = true;
} }
void notifySystemPromptChanged() { void beginCompletion()
// release all slots
for (llama_client_slot &slot : slots)
{ {
slot.release(); // number of tokens to keep when resetting context
} n_remain = params.n_predict;
waitAllAreIdle(); llama_set_rng_seed(ctx, params.seed);
all_slots_are_idle = true;
// wait until system prompt load
update_system_prompt = true;
while(update_system_prompt) {
std::this_thread::sleep_for(std::chrono::milliseconds(5));
}
// system prompt loaded, continue
} }
void processSystemPromptData(json sys_props) { completion_token_output nextToken()
system_prompt = sys_props.value("system_prompt", ""); {
user_name = sys_props.value("anti_prompt", ""); completion_token_output result;
assistant_name = sys_props.value("assistant_name", ""); result.tok = -1;
notifySystemPromptChanged();
if (embd.size() >= (size_t)n_ctx)
{
// Shift context
const int n_left = n_past - params.n_keep - 1;
const int n_discard = n_left/2;
llama_kv_cache_seq_rm (ctx, 0, params.n_keep + 1 , params.n_keep + n_discard + 1);
llama_kv_cache_seq_shift(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard);
for (size_t i = params.n_keep + 1 + n_discard; i < embd.size(); i++)
{
embd[i - n_discard] = embd[i];
}
embd.resize(embd.size() - n_discard);
n_past -= n_discard;
truncated = true;
LOG_VERBOSE("input truncated", {
{"n_ctx", n_ctx},
{"n_keep", params.n_keep},
{"n_left", n_left},
});
} }
void waitAllAreIdle() { bool tg = true;
bool wait = true; while (n_past < embd.size())
while(wait) {
wait = false;
for (auto &slot : slots)
{ {
if (!slot.available()) int n_eval = (int)embd.size() - n_past;
tg = n_eval == 1;
if (n_eval > params.n_batch)
{ {
wait = true; n_eval = params.n_batch;
break; }
if (llama_decode(ctx, llama_batch_get_one(&embd[n_past], n_eval, n_past, 0)))
{
LOG_ERROR("failed to eval", {
{"n_eval", n_eval},
{"n_past", n_past},
{"embd", tokens_to_str(ctx, embd.cbegin() + n_past, embd.cend())},
});
has_next_token = false;
return result;
}
n_past += n_eval;
}
if (params.n_predict == 0)
{
has_next_token = false;
result.tok = llama_token_eos(ctx);
return result;
}
{
// out of user input, sample next token
std::vector<llama_token_data> candidates;
candidates.reserve(llama_n_vocab(model));
result.tok = llama_sample_token(ctx, NULL, grammar, params, last_n_tokens, candidates);
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
const int32_t n_probs = params.n_probs;
if (params.temp <= 0 && n_probs > 0)
{
// For llama_sample_token_greedy we need to sort candidates
llama_sample_softmax(ctx, &candidates_p);
}
for (size_t i = 0; i < std::min(candidates_p.size, (size_t)n_probs); ++i)
{
result.probs.push_back({candidates_p.data[i].id, candidates_p.data[i].p});
}
last_n_tokens.erase(last_n_tokens.begin());
last_n_tokens.push_back(result.tok);
if (tg) {
num_tokens_predicted++;
} }
} }
// add it to the context
embd.push_back(result.tok);
// decrement remaining sampling budget
--n_remain;
if (!embd.empty() && embd.back() == llama_token_eos(ctx))
{
// stopping_word = llama_token_to_piece(ctx, embd.back());
has_next_token = false;
stopped_eos = true;
LOG_VERBOSE("eos token found", {});
return result;
} }
has_next_token = params.n_predict == -1 || n_remain != 0;
return result;
} }
size_t findStoppingStrings(const size_t last_token_size, size_t findStoppingStrings(const size_t last_token_size,
@ -754,7 +833,7 @@ struct llama_server_context
params.n_predict) || params.n_predict) ||
stop_pos != std::string::npos)); stop_pos != std::string::npos));
if (slot.params.n_probs > 0) if (params.n_probs > 0)
{ {
slot.generated_token_probs.push_back(result); slot.generated_token_probs.push_back(result);
} }
@ -1016,6 +1095,7 @@ static void server_print_usage(const char *argv0, const gpt_params &params,
printf(" -h, --help show this help message and exit\n"); printf(" -h, --help show this help message and exit\n");
printf(" -v, --verbose verbose output (default: %s)\n", server_verbose ? "enabled" : "disabled"); printf(" -v, --verbose verbose output (default: %s)\n", server_verbose ? "enabled" : "disabled");
printf(" -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads); printf(" -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads);
printf(" -tb N, --threads-batch N number of threads to use during batch and prompt processing (default: same as --threads)\n");
printf(" -c N, --ctx-size N size of the prompt context (default: %d)\n", params.n_ctx); printf(" -c N, --ctx-size N size of the prompt context (default: %d)\n", params.n_ctx);
printf(" --rope-freq-base N RoPE base frequency (default: loaded from model)\n"); printf(" --rope-freq-base N RoPE base frequency (default: loaded from model)\n");
printf(" --rope-freq-scale N RoPE frequency scaling factor (default: loaded from model)\n"); printf(" --rope-freq-scale N RoPE frequency scaling factor (default: loaded from model)\n");
@ -1166,6 +1246,15 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
} }
params.n_threads = std::stoi(argv[i]); params.n_threads = std::stoi(argv[i]);
} }
else if (arg == "--threads-batch" || arg == "-tb")
{
if (++i >= argc)
{
invalid_param = true;
break;
}
params.n_threads_batch = std::stoi(argv[i]);
}
else if (arg == "-b" || arg == "--batch-size") else if (arg == "-b" || arg == "--batch-size")
{ {
if (++i >= argc) if (++i >= argc)
@ -1343,35 +1432,35 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
static json format_generation_settings(llama_server_context &llama, llama_client_slot* &slot) static json format_generation_settings(llama_server_context &llama, llama_client_slot* &slot)
{ {
const auto eos_bias = slot->params.logit_bias.find(llama_token_eos(llama.ctx)); const auto eos_bias = llama.params.logit_bias.find(llama_token_eos(llama.ctx));
const bool ignore_eos = eos_bias != slot->params.logit_bias.end() && const bool ignore_eos = eos_bias != llama.params.logit_bias.end() &&
eos_bias->second < 0.0f && std::isinf(eos_bias->second); eos_bias->second < 0.0f && std::isinf(eos_bias->second);
return json{ return json{
{"n_ctx", llama.n_ctx}, {"n_ctx", llama.n_ctx},
{"model", llama.params.model_alias}, {"model", llama.params.model_alias},
{"seed", slot->params.seed}, {"seed", llama.params.seed},
{"temp", slot->params.temp}, {"temp", llama.params.temp},
{"top_k", slot->params.top_k}, {"top_k", llama.params.top_k},
{"top_p", slot->params.top_p}, {"top_p", llama.params.top_p},
{"tfs_z", slot->params.tfs_z}, {"tfs_z", llama.params.tfs_z},
{"typical_p", slot->params.typical_p}, {"typical_p", llama.params.typical_p},
{"repeat_last_n", slot->params.repeat_last_n}, {"repeat_last_n", llama.params.repeat_last_n},
{"repeat_penalty", slot->params.repeat_penalty}, {"repeat_penalty", llama.params.repeat_penalty},
{"presence_penalty",slot->params.presence_penalty}, {"presence_penalty", llama.params.presence_penalty},
{"frequency_penalty", slot->params.frequency_penalty}, {"frequency_penalty", llama.params.frequency_penalty},
{"mirostat", slot->params.mirostat}, {"mirostat", llama.params.mirostat},
{"mirostat_tau", slot->params.mirostat_tau}, {"mirostat_tau", llama.params.mirostat_tau},
{"mirostat_eta", slot->params.mirostat_eta}, {"mirostat_eta", llama.params.mirostat_eta},
{"penalize_nl", slot->params.penalize_nl}, {"penalize_nl", llama.params.penalize_nl},
{"stop", slot->params.antiprompt}, {"stop", llama.params.antiprompt},
{"n_predict", slot->params.n_predict}, {"n_predict", llama.params.n_predict},
// {"n_keep", slot.params.n_keep}, {"n_keep", llama.params.n_keep},
{"ignore_eos", ignore_eos}, {"ignore_eos", ignore_eos},
{"stream", slot->params.stream}, {"stream", llama.stream},
{"logit_bias", slot->params.logit_bias}, {"logit_bias", llama.params.logit_bias},
{"n_probs", slot->params.n_probs}, {"n_probs", llama.params.n_probs},
{"grammar", slot->params.grammar}, {"grammar", llama.params.grammar},
}; };
} }
@ -1419,7 +1508,7 @@ static json format_final_response(llama_server_context &llama, llama_client_slot
// {"timings", format_timings(llama)}, // {"timings", format_timings(llama)},
}; };
if (slot->params.n_probs > 0) if (llama.params.n_probs > 0)
{ {
res["completion_probabilities"] = probs_vector_to_json(llama.ctx, probs); res["completion_probabilities"] = probs_vector_to_json(llama.ctx, probs);
} }
@ -1436,7 +1525,7 @@ static json format_partial_response(
{ "slot_id", slot->id } { "slot_id", slot->id }
}; };
if (slot->params.n_probs > 0) if (llama.params.n_probs > 0)
{ {
res["completion_probabilities"] = probs_vector_to_json(llama.ctx, probs); res["completion_probabilities"] = probs_vector_to_json(llama.ctx, probs);
} }
@ -1467,27 +1556,27 @@ static T json_value(const json &body, const std::string &key, const T &default_v
static void parse_options_completion(const json &body, llama_client_slot* &slot, llama_server_context &llama) static void parse_options_completion(const json &body, llama_client_slot* &slot, llama_server_context &llama)
{ {
slot_params default_params; gpt_params default_params;
slot->params.stream = json_value(body, "stream", false); llama.stream = json_value(body, "stream", false);
slot->params.n_predict = json_value(body, "n_predict", default_params.n_predict); llama.params.n_predict = json_value(body, "n_predict", default_params.n_predict);
slot->params.top_k = json_value(body, "top_k", default_params.top_k); llama.params.top_k = json_value(body, "top_k", default_params.top_k);
slot->params.top_p = json_value(body, "top_p", default_params.top_p); llama.params.top_p = json_value(body, "top_p", default_params.top_p);
slot->params.tfs_z = json_value(body, "tfs_z", default_params.tfs_z); llama.params.tfs_z = json_value(body, "tfs_z", default_params.tfs_z);
slot->params.typical_p = json_value(body, "typical_p", default_params.typical_p); llama.params.typical_p = json_value(body, "typical_p", default_params.typical_p);
slot->params.repeat_last_n = json_value(body, "repeat_last_n", default_params.repeat_last_n); llama.params.repeat_last_n = json_value(body, "repeat_last_n", default_params.repeat_last_n);
slot->params.temp = json_value(body, "temperature", default_params.temp); llama.params.temp = json_value(body, "temperature", default_params.temp);
slot->params.repeat_penalty = json_value(body, "repeat_penalty", default_params.repeat_penalty); llama.params.repeat_penalty = json_value(body, "repeat_penalty", default_params.repeat_penalty);
slot->params.presence_penalty = json_value(body, "presence_penalty", default_params.presence_penalty); llama.params.presence_penalty = json_value(body, "presence_penalty", default_params.presence_penalty);
slot->params.frequency_penalty = json_value(body, "frequency_penalty", default_params.frequency_penalty); llama.params.frequency_penalty = json_value(body, "frequency_penalty", default_params.frequency_penalty);
slot->params.mirostat = json_value(body, "mirostat", default_params.mirostat); llama.params.mirostat = json_value(body, "mirostat", default_params.mirostat);
slot->params.mirostat_tau = json_value(body, "mirostat_tau", default_params.mirostat_tau); llama.params.mirostat_tau = json_value(body, "mirostat_tau", default_params.mirostat_tau);
slot->params.mirostat_eta = json_value(body, "mirostat_eta", default_params.mirostat_eta); llama.params.mirostat_eta = json_value(body, "mirostat_eta", default_params.mirostat_eta);
slot->params.penalize_nl = json_value(body, "penalize_nl", default_params.penalize_nl); llama.params.penalize_nl = json_value(body, "penalize_nl", default_params.penalize_nl);
//llama.params.n_keep = json_value(body, "n_keep", default_params.n_keep); llama.params.n_keep = json_value(body, "n_keep", default_params.n_keep);
slot->params.seed = json_value(body, "seed", default_params.seed); llama.params.seed = json_value(body, "seed", default_params.seed);
slot->params.grammar = json_value(body, "grammar", default_params.grammar); llama.params.grammar = json_value(body, "grammar", default_params.grammar);
slot->params.n_probs = json_value(body, "n_probs", default_params.n_probs); llama.params.n_probs = json_value(body, "n_probs", default_params.n_probs);
if (body.count("prompt") != 0) if (body.count("prompt") != 0)
{ {
@ -1498,10 +1587,10 @@ static void parse_options_completion(const json &body, llama_client_slot* &slot,
slot->prompt = ""; slot->prompt = "";
} }
slot->params.logit_bias.clear(); llama.params.logit_bias.clear();
if (json_value(body, "ignore_eos", false)) if (json_value(body, "ignore_eos", false))
{ {
slot->params.logit_bias[llama_token_eos(llama.ctx)] = -INFINITY; llama.params.logit_bias[llama_token_eos(llama.ctx)] = -INFINITY;
} }
const auto &logit_bias = body.find("logit_bias"); const auto &logit_bias = body.find("logit_bias");
@ -1517,11 +1606,11 @@ static void parse_options_completion(const json &body, llama_client_slot* &slot,
{ {
if (el[1].is_number()) if (el[1].is_number())
{ {
slot->params.logit_bias[tok] = el[1].get<float>(); llama.params.logit_bias[tok] = el[1].get<float>();
} }
else if (el[1].is_boolean() && !el[1].get<bool>()) else if (el[1].is_boolean() && !el[1].get<bool>())
{ {
slot->params.logit_bias[tok] = -INFINITY; llama.params.logit_bias[tok] = -INFINITY;
} }
} }
} }
@ -1541,6 +1630,8 @@ static void parse_options_completion(const json &body, llama_client_slot* &slot,
} }
} }
llama.ctx_sampling = llama_sampling_context_init(llama.params, llama.grammar);
LOG_VERBOSE("completion parameters parsed", format_generation_settings(llama, slot)); LOG_VERBOSE("completion parameters parsed", format_generation_settings(llama, slot));
} }
@ -1774,11 +1865,11 @@ int main(int argc, char **argv)
// } // }
// } // }
// auto probs = llama.generated_token_probs; auto probs = llama.generated_token_probs;
// if (llama.params.n_probs > 0 && llama.stopped_word) { if (llama.params.n_probs > 0 && llama.stopped_word) {
// const std::vector<llama_token> stop_word_toks = llama_tokenize(llama.ctx, llama.stopping_word, false); const std::vector<llama_token> stop_word_toks = llama_tokenize(llama.ctx, llama.stopping_word, false);
// probs = std::vector<completion_token_output>(llama.generated_token_probs.begin(), llama.generated_token_probs.end() - stop_word_toks.size()); probs = std::vector<completion_token_output>(llama.generated_token_probs.begin(), llama.generated_token_probs.end() - stop_word_toks.size());
// } }
// const json data = format_final_response(llama, llama.generated_text, probs); // const json data = format_final_response(llama, llama.generated_text, probs);
@ -1796,32 +1887,70 @@ int main(int argc, char **argv)
// const completion_token_output token = slot->next(); // const completion_token_output token = slot->next();
// std::string token_str = llama_token_to_piece(llama.ctx, token.tok); // std::string token_str = llama_token_to_piece(llama.ctx, token.tok);
// std::vector<completion_token_output> probs_output = {}; size_t pos = std::min(sent_count, llama.generated_text.size());
// const json data = format_partial_response(llama, slot, token_str, probs_output); const std::string str_test = llama.generated_text.substr(pos);
// const std::string str = bool is_stop_full = false;
// "data: " + size_t stop_pos =
// data.dump(-1, ' ', false, json::error_handler_t::replace) + llama.findStoppingStrings(str_test, token_text.size(), STOP_FULL);
// "\n\n"; if (stop_pos != std::string::npos) {
is_stop_full = true;
// LOG_VERBOSE("data stream", { llama.generated_text.erase(
// { "to_send", str } llama.generated_text.begin() + pos + stop_pos,
// }); llama.generated_text.end());
// if(!sink.write(str.c_str(), str.size())) { pos = std::min(sent_count, llama.generated_text.size());
// slot->release();
// return false;
// }
} else { } else {
std::this_thread::sleep_for(std::chrono::milliseconds(5)); is_stop_full = false;
stop_pos = llama.findStoppingStrings(str_test, token_text.size(),
STOP_PARTIAL);
}
if (
stop_pos == std::string::npos ||
// Send rest of the text if we are at the end of the generation
(!llama.has_next_token && !is_stop_full && stop_pos > 0)
) {
const std::string to_send = llama.generated_text.substr(pos, std::string::npos);
sent_count += to_send.size();
std::vector<completion_token_output> probs_output = {};
if (llama.params.n_probs > 0) {
const std::vector<llama_token> to_send_toks = llama_tokenize(llama.ctx, to_send, false);
size_t probs_pos = std::min(sent_token_probs_index, llama.generated_token_probs.size());
size_t probs_stop_pos = std::min(sent_token_probs_index + to_send_toks.size(), llama.generated_token_probs.size());
if (probs_pos < probs_stop_pos) {
probs_output = std::vector<completion_token_output>(llama.generated_token_probs.begin() + probs_pos, llama.generated_token_probs.begin() + probs_stop_pos);
}
sent_token_probs_index = probs_stop_pos;
}
const json data = format_partial_response(llama, to_send, probs_output);
const std::string str =
"data: " +
data.dump(-1, ' ', false, json::error_handler_t::replace) +
"\n\n";
LOG_VERBOSE("data stream", {
{ "to_send", str }
});
if (!sink.write(str.data(), str.size())) {
LOG_VERBOSE("stream closed", {});
llama_print_timings(llama.ctx);
return false;
} }
} }
// const json data = format_final_response(
// llama, slot, if (!llama.has_next_token) {
// "", // Generation is done, send extra information.
// std::vector<completion_token_output>( const json data = format_final_response(
// slot->generated_token_probs.begin(), llama,
// slot->generated_token_probs.begin() + sent_token_probs_index) "",
// ); std::vector<completion_token_output>(llama.generated_token_probs.begin(), llama.generated_token_probs.begin() + sent_token_probs_index)
);
// const std::string str = // const std::string str =
// "data: " + // "data: " +
@ -1907,15 +2036,15 @@ int main(int argc, char **argv)
// std::vector<completion_token_output> probs_output = {}; // std::vector<completion_token_output> probs_output = {};
// if (llama.params.n_probs > 0) { if (llama.params.n_probs > 0) {
// const std::vector<llama_token> to_send_toks = llama_tokenize(llama.ctx, to_send, false); const std::vector<llama_token> to_send_toks = llama_tokenize(llama.ctx, to_send, false);
// size_t probs_pos = std::min(sent_token_probs_index, llama.generated_token_probs.size()); size_t probs_pos = std::min(sent_token_probs_index, llama.generated_token_probs.size());
// size_t probs_stop_pos = std::min(sent_token_probs_index + to_send_toks.size(), llama.generated_token_probs.size()); size_t probs_stop_pos = std::min(sent_token_probs_index + to_send_toks.size(), llama.generated_token_probs.size());
// if (probs_pos < probs_stop_pos) { if (probs_pos < probs_stop_pos) {
// probs_output = std::vector<completion_token_output>(llama.generated_token_probs.begin() + probs_pos, llama.generated_token_probs.begin() + probs_stop_pos); probs_output = std::vector<completion_token_output>(llama.generated_token_probs.begin() + probs_pos, llama.generated_token_probs.begin() + probs_stop_pos);
// } }
// sent_token_probs_index = probs_stop_pos; sent_token_probs_index = probs_stop_pos;
// } }
// const json data = format_partial_response(llama, to_send, probs_output); // const json data = format_partial_response(llama, to_send, probs_output);

View File

@ -125,6 +125,8 @@ int main(int argc, char ** argv) {
grammar_tgt = llama_grammar_init(grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root")); grammar_tgt = llama_grammar_init(grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
} }
llama_sampling_context ctx_sampling = llama_sampling_context_init(params, grammar_tgt);
const auto t_dec_start = ggml_time_us(); const auto t_dec_start = ggml_time_us();
while (true) { while (true) {
@ -134,7 +136,7 @@ int main(int argc, char ** argv) {
while (true) { while (true) {
// sample from the target model // sample from the target model
llama_token id = llama_sample_token(ctx_tgt, NULL, grammar_tgt, params, last_tokens, candidates, i_dft); llama_token id = llama_sampling_sample(ctx_tgt, NULL, ctx_sampling, last_tokens, candidates, i_dft);
// remember which tokens were sampled - used for repetition penalties during sampling // remember which tokens were sampled - used for repetition penalties during sampling
last_tokens.erase(last_tokens.begin()); last_tokens.erase(last_tokens.begin());
@ -211,7 +213,13 @@ int main(int argc, char ** argv) {
if (grammar_dft) { if (grammar_dft) {
llama_grammar_free(grammar_dft); llama_grammar_free(grammar_dft);
} }
grammar_dft = llama_grammar_copy(grammar_tgt); // Note: Hardcoded to sequence id 0, if this ever supports parallel generation
// that will need to change.
auto it = ctx_sampling.sequence_contexts.find(0);
GGML_ASSERT(it != ctx_sampling.sequence_contexts.end());
// This is necessary because each sequence id in sequence_contexts
// uses a copy of the original grammar.
grammar_dft = llama_grammar_copy(it->second.grammar);
LOG("copied target grammar to draft grammar\n"); LOG("copied target grammar to draft grammar\n");
} }

2
ggml.c
View File

@ -14428,7 +14428,7 @@ static void ggml_compute_forward_conv_2d_f16_f32(
int64_t t0 = ggml_perf_time_us(); int64_t t0 = ggml_perf_time_us();
UNUSED(t0); UNUSED(t0);
GGML_TENSOR_BINARY_OP_LOCALS GGML_TENSOR_BINARY_OP_LOCALS;
const int ith = params->ith; const int ith = params->ith;
const int nth = params->nth; const int nth = params->nth;

93
prompts/mnemonics.txt Normal file
View File

@ -0,0 +1,93 @@
For each kanji character, write a Markdownformatted mnemonic that uses its keyword and the keyword of all its components.
Kanji: 欠 (lack of)
Components: 𠂊 (hook claw), 人 (person)
Mnemonic: This **person** is a pirate. He lost his hand to a crocodile many years ago. Nowadays, the ***lack of*** a hand does not bother him too much. In fact, the **hook claw** that replaces it is the mark of a true pirate, so he is quite proud of it!
Kanji: 類 (kind (of something))
Components: 米 (rice), 大 (large), 頁 (page)
Mnemonic: The waiter at a Chinese restaurant hands you a **large** menu. Each **page** has all ***kinds*** of **rice** on offer!
Kanji: 燃 (burn)
Components: 火 (fire), 然 (sort of thing)
Mnemonic: ***Burning*** things up with **fire** is just my **sort of thing**. (Spoken like a true pyromaniac.)
Kanji: 頂 (top of)
Components: 丁 (street), 頁 (page)
Mnemonic: To be at the ***top of*** your game, you need both practical knowledge (**street** smarts) and theoretical knowledge (having read many **pages**).
Kanji: 険 (risky and steep)
Components: 阝 (small village), 㑒 (consensus)
Mnemonic: Everyone agrees (there is **consensus**) that the path to the **small village** is ***risky and steep***.
Kanji: 困 (distressed)
Components: 囗 (closed box), 木 (tree)
Mnemonic: You would feel ***distressed*** too if you were a **tree** trapped in a **closed box**! I have no place to grow!
Kanji: 頭 (head)
Components: 豆 (bean), 頁 (page)
Mnemonic: What do you have in that ***head*** of yours? A **bean** for a brain? Go read more **pages** and become more knowledgeable about the world!
Kanji: 確 (certain)
Components: 石 (stone), 冖 (roof without a chimney), 隹 (old bird)
Mnemonic: An **old bird** has made a nest on your **roof**. What do you do? You call Misaka from a <cite>A ***Certain*** Scientific Railgun</cite> to get rid of it, of course! But she doesnt really want to vaporize the poor thing, so she just throws a **stone** to scare it away. (What was the point of calling her, then‽)
Kanji: 魚 (fish)
Components: 𠂊 (hook claw), 田 (rice field), 灬 (fire sparks)
Mnemonic: Catch ***fish*** with a **hook**, collect rice from the **rice field**, cook them with **fire**… And my meal is ready!
Kanji: 警 (to police (something))
Components: 敬 (respect), 言 (say)
Mnemonic: ***To police something*** is to make people **respect** what the law **says**.
Kanji: 筆 (writing brush)
Components: 竹 (bamboo), 聿 (brush)
Mnemonic: A traditional ***writing brush*** is a **brush** made of **bamboo**.
Kanji: 獄 (prison)
Components: 犭 (animal), 言 (say), 犬 (dog)
Mnemonic: In ***prison***, like in the **animal** kingdom, only the toughest survive. You have to watch what you **say**. Its a **dog**eatdog world.
Kanji: 新 (new)
Components: 立 (standing up), 木 (tree), 斤 (axe)
Mnemonic: In order for a ***new*** construction to be made, an empty lot is needed. If there are any **trees** **standing up**, they must be cut down with an **axe**.
Kanji: 怪 (suspicious)
Components: 忄 (weak heart), 圣 (sacred)
Mnemonic: That painting of the **Sacred** **Heart** of Jesus looks ***suspicious***. I think it might be a forgery.
Kanji: 温 (warm (to the touch))
Components: 氵 (water drops), 日 (sun), 皿 (dish)
Mnemonic: If you leave **water** on a **dish** in the **sun**, it will get ***warm***.
Kanji: 階 (floor (of a building))
Components: 阝 (small village), 皆 (all)
Mnemonic: It might be a **small village**, but, despite that, **all** of its buildings have many ***floors***. Its a village of skyscrapers!
Kanji: 多 (many)
Components: 夕 (evening (before sunset)), 夕 (evening (before sunset))
Mnemonic: Two **evenings** in a day would be one too ***many***.
Kanji: 別 (separate)
Components: 口 (mouth), 万 (ten thousand), 刂 (knife)
Mnemonic: Tom Six is at it again. For his next flick, he wants to stitch together **ten thousand** people, **mouth**toanus. One of the most graphic and disturbing scenes will feature one of the victims using a **knife** to ***separate*** perself.
Kanji: 並 (line up)
Components: 䒑 (antlers on a wall), 业 (runway)
Mnemonic: In order to land a plane you have to ***line up*** properly with the **runway**. The things that look like **antlers** at the end of the runway are the control towers; you should follow their instructions.
Kanji: 姿 (figure)
Components: 次 (next), 女 (woman)
Mnemonic: The **next** **woman** that I date will have a perfect **figure**. Because Im done with 3D women—it will *literally* be an anime figure!
Kanji: 実 (real)
Components: 宀 (roof with a chimney), 𡗗 (three people)
Mnemonic: Living under a **roof with a chimney** with **three people** (a wife and two children)—a happy family life—is not something I could have ever imagined. It does not feel ***real***.
Kanji: 謝 (apologize)
Components: 言 (say), 射 (shoot)
Mnemonic: **Shot** first, ***apologize*** (**say** you are sorry) later.
Kanji: 提 (propose)
Components: 扌 (left hand), 是 (go with)
Mnemonic: