Reviewer notes (free text ahead of the first "diff --git" header; patch tools skip it).

What this patch does:
  * CMakeLists.txt, Makefile, flake.nix: stop linking the MetalPerformanceShaders
    framework(s) in the Metal build. CMake additionally installs ggml-metal.metal
    into ${CMAKE_INSTALL_BINDIR} when LLAMA_METAL is set.
  * Makefile: adds the tests/test-llama-grammar target.
  * examples/common.cpp: adds --cfg-negative-prompt-file, which reads the negative
    prompt from a file and strips one trailing '\n'.
  * examples/server/server.cpp: serves the embedded json-schema-to-grammar.mjs.
  * scripts/get-wikitext-2.sh: new script that downloads wikitext-2-raw-v1.zip.
  * tests/test-llama-grammar.cpp: new test that builds a grammar by hand and checks
    the initial stacks and the per-stack rejected candidates.

Fixes applied in this revision (all are in-place line edits, hunk counts unchanged):
  * common.cpp: guard .back() with !empty(); calling back() on an empty string
    (empty prompt file) is undefined behaviour.
  * server.cpp: the comment above the new handler named index.html (copy/paste);
    it now names the file the handler actually serves.
  * test-llama-grammar.cpp: release the grammar with llama_grammar_free() instead of
    a raw delete, so allocation and release go through the same API.
  * Restored the <...> spans that were missing (bare "#include", template arguments).

NOTE(review): this text was rebuilt from a whitespace-collapsed copy. Indentation is
reconstructed, and spacing inside context lines and string literals is kept as it
appeared there, so it may not match the tree byte-for-byte. The "index" blob hashes
are the original ones and do not cover the fixes above. Regenerate the patch from a
working tree before applying it.

diff --git a/CMakeLists.txt b/CMakeLists.txt
index c6ed458b3..6c5a3e09e 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -296,7 +296,6 @@ if (LLAMA_METAL)
     find_library(FOUNDATION_LIBRARY Foundation REQUIRED)
     find_library(METAL_FRAMEWORK Metal REQUIRED)
     find_library(METALKIT_FRAMEWORK MetalKit REQUIRED)
-    find_library(METALPERFORMANCE_FRAMEWORK MetalPerformanceShaders REQUIRED)
 
     set(GGML_SOURCES_METAL ggml-metal.m ggml-metal.h)
 
@@ -313,7 +312,6 @@ if (LLAMA_METAL)
         ${FOUNDATION_LIBRARY}
         ${METAL_FRAMEWORK}
         ${METALKIT_FRAMEWORK}
-        ${METALPERFORMANCE_FRAMEWORK}
         )
 endif()
 
@@ -570,6 +568,16 @@ install(
         WORLD_READ
         WORLD_EXECUTE
     DESTINATION ${CMAKE_INSTALL_BINDIR})
+if (LLAMA_METAL)
+    install(
+        FILES ggml-metal.metal
+        PERMISSIONS
+            OWNER_READ
+            OWNER_WRITE
+            GROUP_READ
+            WORLD_READ
+        DESTINATION ${CMAKE_INSTALL_BINDIR})
+endif()
 
 #
 # programs, examples and tests
diff --git a/Makefile b/Makefile
index 4d2a8cfa9..0f768fd19 100644
--- a/Makefile
+++ b/Makefile
@@ -2,7 +2,7 @@
 BUILD_TARGETS = main quantize quantize-stats perplexity embedding vdot train-text-from-scratch convert-llama2c-to-ggml simple server embd-input-test gguf gptneox-main
 
 # Binaries only useful for tests
-TEST_TARGETS = tests/test-grammar-parser tests/test-double-float tests/test-grad0 tests/test-opt tests/test-quantize-fns tests/test-quantize-perf tests/test-sampling tests/test-tokenizer-0
+TEST_TARGETS = tests/test-llama-grammar tests/test-grammar-parser tests/test-double-float tests/test-grad0 tests/test-opt tests/test-quantize-fns tests/test-quantize-perf tests/test-sampling tests/test-tokenizer-0
 
 default: $(BUILD_TARGETS)
 
@@ -283,7 +283,7 @@ endif # LLAMA_CLBLAST
 ifdef LLAMA_METAL
 	CFLAGS += -DGGML_USE_METAL -DGGML_METAL_NDEBUG
 	CXXFLAGS += -DGGML_USE_METAL
-	LDFLAGS += -framework Foundation -framework Metal -framework MetalKit -framework MetalPerformanceShaders
+	LDFLAGS += -framework Foundation -framework Metal -framework MetalKit
 	OBJS += ggml-metal.o
 endif # LLAMA_METAL
 
@@ -418,6 +418,9 @@ benchmark-matmult: examples/benchmark/benchmark-matmult.cpp build-info.h ggml.o
 vdot: pocs/vdot/vdot.cpp ggml.o $(OBJS)
 	$(CXX) $(CXXFLAGS) $^ -o $@ $(LDFLAGS)
 
+tests/test-llama-grammar: tests/test-llama-grammar.cpp build-info.h ggml.o llama.o common.o $(OBJS)
+	$(CXX) $(CXXFLAGS) $(filter-out %.txt,$^) -o $@ $(LDFLAGS)
+
 tests/test-grammar-parser: tests/test-grammar-parser.cpp examples/grammar-parser.cpp build-info.h ggml.o llama.o common.o $(OBJS)
 	$(CXX) $(CXXFLAGS) $(filter-out %.txt,$^) -o $@ $(LDFLAGS)
 
diff --git a/examples/common.cpp b/examples/common.cpp
index 8beb63f36..ea6c9d499 100644
--- a/examples/common.cpp
+++ b/examples/common.cpp
@@ -262,6 +262,21 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
                 break;
             }
             params.cfg_negative_prompt = argv[i];
+        } else if (arg == "--cfg-negative-prompt-file") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            std::ifstream file(argv[i]);
+            if (!file) {
+                fprintf(stderr, "error: failed to open file '%s'\n", argv[i]);
+                invalid_param = true;
+                break;
+            }
+            std::copy(std::istreambuf_iterator<char>(file), std::istreambuf_iterator<char>(), back_inserter(params.cfg_negative_prompt));
+            if (!params.cfg_negative_prompt.empty() && params.cfg_negative_prompt.back() == '\n') {
+                params.cfg_negative_prompt.pop_back();
+            }
         } else if (arg == "--cfg-scale") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -553,8 +568,10 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     fprintf(stdout, " or `--logit-bias 15043-1` to decrease likelihood of token ' Hello'\n");
     fprintf(stdout, " --grammar GRAMMAR BNF-like grammar to constrain generations (see samples in grammars/ dir)\n");
     fprintf(stdout, " --grammar-file FNAME file to read grammar from\n");
-    fprintf(stdout, " --cfg-negative-prompt PROMPT \n");
+    fprintf(stdout, " --cfg-negative-prompt PROMPT\n");
     fprintf(stdout, " negative prompt to use for guidance. (default: empty)\n");
+    fprintf(stdout, " --cfg-negative-prompt-file FNAME\n");
+    fprintf(stdout, " negative prompt file to use for guidance. (default: empty)\n");
     fprintf(stdout, " --cfg-scale N strength of guidance (default: %f, 1.0 = disable)\n", params.cfg_scale);
     fprintf(stdout, " --rope-scale N RoPE context linear scaling factor, inverse of --rope-freq-scale (default: %g)\n", 1.0f/params.rope_freq_scale);
     fprintf(stdout, " --rope-freq-base N RoPE base frequency, used by NTK-aware scaling (default: %.1f)\n", params.rope_freq_base);
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 144372061..9337e2104 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -15,6 +15,7 @@
 #include "index.html.hpp"
 #include "index.js.hpp"
 #include "completion.js.hpp"
+#include "json-schema-to-grammar.mjs.hpp"
 
 #ifndef SERVER_VERBOSE
 #define SERVER_VERBOSE 1
@@ -1199,6 +1200,12 @@ int main(int argc, char **argv)
                 res.set_content(reinterpret_cast<const char*>(&completion_js), completion_js_len, "application/javascript");
                 return false; });
 
+    // this is only called if no json-schema-to-grammar.mjs is found in the public --path
+    svr.Get("/json-schema-to-grammar.mjs", [](const Request &, Response &res)
+            {
+                res.set_content(reinterpret_cast<const char*>(&json_schema_to_grammar_mjs), json_schema_to_grammar_mjs_len, "application/javascript");
+                return false; });
+
     svr.Post("/completion", [&llama](const Request &req, Response &res)
              {
         auto lock = llama.lock();
diff --git a/flake.nix b/flake.nix
index 4178e97ff..616b90252 100644
--- a/flake.nix
+++ b/flake.nix
@@ -14,8 +14,6 @@
           with pkgs.darwin.apple_sdk_11_0.frameworks; [
             Accelerate
             MetalKit
-            MetalPerformanceShaders
-            MetalPerformanceShadersGraph
           ]
         else if isAarch32 && isDarwin then
           with pkgs.darwin.apple_sdk.frameworks; [
diff --git a/scripts/get-wikitext-2.sh b/scripts/get-wikitext-2.sh
new file mode 100644
index 000000000..98aec3e3e
--- /dev/null
+++ b/scripts/get-wikitext-2.sh
@@ -0,0 +1,3 @@
+#!/bin/bash
+
+wget https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip
diff --git a/tests/test-llama-grammar.cpp b/tests/test-llama-grammar.cpp
new file mode 100644
index 000000000..f98c6531f
--- /dev/null
+++ b/tests/test-llama-grammar.cpp
@@ -0,0 +1,403 @@
+#ifdef NDEBUG
+#undef NDEBUG
+#endif
+
+#include "llama.cpp"
+#include "examples/common.cpp"
+#include "examples/grammar-parser.cpp"
+#include <cassert>
+
+int main()
+{
+    grammar_parser::parse_state parsed_grammar;
+
+    std::vector<std::pair<std::string, uint32_t>> expected = {
+        {"expr", 2},
+        {"expr_6", 6},
+        {"expr_7", 7},
+        {"ident", 8},
+        {"ident_10", 10},
+        {"num", 9},
+        {"num_11", 11},
+        {"root", 0},
+        {"root_1", 1},
+        {"root_5", 5},
+        {"term", 4},
+        {"ws", 3},
+        {"ws_12", 12},
+    };
+
+    std::vector<std::vector<llama_grammar_element>> expected_rules = {
+        {{LLAMA_GRETYPE_RULE_REF, 5}, {LLAMA_GRETYPE_END, 0}},
+        {
+            {LLAMA_GRETYPE_RULE_REF, 2},
+            {LLAMA_GRETYPE_CHAR, 61},
+            {LLAMA_GRETYPE_RULE_REF, 3},
+            {LLAMA_GRETYPE_RULE_REF, 4},
+            {LLAMA_GRETYPE_CHAR, 10},
+            {LLAMA_GRETYPE_END, 0},
+        },
+        {{LLAMA_GRETYPE_RULE_REF, 4}, {LLAMA_GRETYPE_RULE_REF, 7}, {LLAMA_GRETYPE_END, 0}},
+        {{LLAMA_GRETYPE_RULE_REF, 12}, {LLAMA_GRETYPE_END, 0}},
+        {
+            {LLAMA_GRETYPE_RULE_REF, 8},
+            {LLAMA_GRETYPE_ALT, 0},
+            {LLAMA_GRETYPE_RULE_REF, 9},
+            {LLAMA_GRETYPE_ALT, 0},
+            {LLAMA_GRETYPE_CHAR, 40},
+            {LLAMA_GRETYPE_RULE_REF, 3},
+            {LLAMA_GRETYPE_RULE_REF, 2},
+            {LLAMA_GRETYPE_CHAR, 41},
+            {LLAMA_GRETYPE_RULE_REF, 3},
+            {LLAMA_GRETYPE_END, 0},
+        },
+        {{LLAMA_GRETYPE_RULE_REF, 1}, {LLAMA_GRETYPE_RULE_REF, 5}, {LLAMA_GRETYPE_ALT, 0}, {LLAMA_GRETYPE_RULE_REF, 1}, {LLAMA_GRETYPE_END, 0}},
+        {
+            {LLAMA_GRETYPE_CHAR, 45},
+            {LLAMA_GRETYPE_CHAR_ALT, 43},
+            {LLAMA_GRETYPE_CHAR_ALT, 42},
+            {LLAMA_GRETYPE_CHAR_ALT, 47},
+            {LLAMA_GRETYPE_RULE_REF, 4},
+            {LLAMA_GRETYPE_END, 0},
+        },
+        {{LLAMA_GRETYPE_RULE_REF, 6}, {LLAMA_GRETYPE_RULE_REF, 7}, {LLAMA_GRETYPE_ALT, 0}, {LLAMA_GRETYPE_END, 0}},
+        {
+            {LLAMA_GRETYPE_CHAR, 97},
+            {LLAMA_GRETYPE_CHAR_RNG_UPPER, 122},
+            {LLAMA_GRETYPE_RULE_REF, 10},
+            {LLAMA_GRETYPE_RULE_REF, 3},
+            {LLAMA_GRETYPE_END, 0},
+        },
+        {{LLAMA_GRETYPE_RULE_REF, 11}, {LLAMA_GRETYPE_RULE_REF, 3}, {LLAMA_GRETYPE_END, 0}},
+        {
+            {LLAMA_GRETYPE_CHAR, 97},
+            {LLAMA_GRETYPE_CHAR_RNG_UPPER, 122},
+            {LLAMA_GRETYPE_CHAR_ALT, 48},
+            {LLAMA_GRETYPE_CHAR_RNG_UPPER, 57},
+            {LLAMA_GRETYPE_CHAR_ALT, 95},
+            {LLAMA_GRETYPE_RULE_REF, 10},
+            {LLAMA_GRETYPE_ALT, 0},
+            {LLAMA_GRETYPE_END, 0},
+        },
+        {
+            {LLAMA_GRETYPE_CHAR, 48},
+            {LLAMA_GRETYPE_CHAR_RNG_UPPER, 57},
+            {LLAMA_GRETYPE_RULE_REF, 11},
+            {LLAMA_GRETYPE_ALT, 0},
+            {LLAMA_GRETYPE_CHAR, 48},
+            {LLAMA_GRETYPE_CHAR_RNG_UPPER, 57},
+            {LLAMA_GRETYPE_END, 0},
+        },
+        {
+            {LLAMA_GRETYPE_CHAR, 32},
+            {LLAMA_GRETYPE_CHAR_ALT, 9},
+            {LLAMA_GRETYPE_CHAR_ALT, 10},
+            {LLAMA_GRETYPE_RULE_REF, 12},
+            {LLAMA_GRETYPE_ALT, 0},
+            {LLAMA_GRETYPE_END, 0},
+        },
+    };
+
+    for (auto pair : expected)
+    {
+        parsed_grammar.symbol_ids[pair.first] = pair.second;
+    }
+
+    for (auto rule : expected_rules)
+    {
+        parsed_grammar.rules.push_back({});
+        for (auto element : rule)
+        {
+            parsed_grammar.rules.back().push_back(element);
+        }
+    }
+
+    llama_grammar *grammar = NULL;
+    std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules());
+    grammar = llama_grammar_init(
+        grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
+
+    std::vector<std::vector<llama_grammar_element>> expected_stacks = {
+        {
+            {LLAMA_GRETYPE_RULE_REF, 5},
+            {LLAMA_GRETYPE_CHAR, 61},
+            {LLAMA_GRETYPE_RULE_REF, 7},
+            {LLAMA_GRETYPE_CHAR, 97},
+        },
+        {
+            {LLAMA_GRETYPE_RULE_REF, 5},
+            {LLAMA_GRETYPE_CHAR, 61},
+            {LLAMA_GRETYPE_RULE_REF, 7},
+            {LLAMA_GRETYPE_RULE_REF, 3},
+            {LLAMA_GRETYPE_CHAR, 48},
+        },
+        {
+            {LLAMA_GRETYPE_RULE_REF, 5},
+            {LLAMA_GRETYPE_CHAR, 61},
+            {LLAMA_GRETYPE_RULE_REF, 7},
+            {LLAMA_GRETYPE_RULE_REF, 3},
+            {LLAMA_GRETYPE_CHAR, 48},
+        },
+        {
+            {LLAMA_GRETYPE_RULE_REF, 5},
+            {LLAMA_GRETYPE_CHAR, 61},
+            {LLAMA_GRETYPE_RULE_REF, 7},
+            {LLAMA_GRETYPE_CHAR, 40},
+        },
+        {
+            {LLAMA_GRETYPE_CHAR, 61},
+            {LLAMA_GRETYPE_RULE_REF, 7},
+            {LLAMA_GRETYPE_CHAR, 97},
+        },
+        {
+            {LLAMA_GRETYPE_CHAR, 61},
+            {LLAMA_GRETYPE_RULE_REF, 7},
+            {LLAMA_GRETYPE_RULE_REF, 3},
+            {LLAMA_GRETYPE_CHAR, 48},
+        },
+        {
+            {LLAMA_GRETYPE_CHAR, 61},
+            {LLAMA_GRETYPE_RULE_REF, 7},
+            {LLAMA_GRETYPE_RULE_REF, 3},
+            {LLAMA_GRETYPE_CHAR, 48},
+        },
+        {
+            {LLAMA_GRETYPE_CHAR, 61},
+            {LLAMA_GRETYPE_RULE_REF, 7},
+            {LLAMA_GRETYPE_CHAR, 40},
+        }};
+
+    auto index = 0;
+    for (auto stack : grammar->stacks)
+    {
+        // compare stack to expected_stack
+        for (uint32_t i = 0; i < stack.size(); i++)
+        {
+            auto element = stack[i];
+            auto expected_element = expected_stacks[index][i];
+
+            // pretty print error message before asserting
+            if (expected_element.type != element->type || expected_element.value != element->value)
+            {
+                fprintf(stderr, "index: %d\n", index);
+                fprintf(stderr, "expected_element: %d, %d\n", expected_element.type, expected_element.value);
+                fprintf(stderr, "actual_element: %d, %d\n", element->type, element->value);
+                fprintf(stderr, "expected_element != actual_element\n");
+            }
+
+            assert(expected_element.type == element->type && expected_element.value == element->value);
+        }
+        index++;
+    }
+
+    std::vector<std::vector<const llama_grammar_element *>> next_stacks;
+    std::vector<llama_grammar_candidate> next_candidates;
+    next_candidates.resize(24);
+
+    for (size_t i = 0; i < 24; ++i)
+    {
+        uint32_t *cp = new uint32_t[2]; // dynamically allocate memory for code_point
+        cp[0] = 37 + i;
+        cp[1] = 0;
+        next_candidates[i] = {i, cp};
+    }
+
+    std::vector<std::vector<std::pair<uint32_t, uint16_t>>> expected_reject = {
+        {
+            {0, 37},
+            {1, 38},
+            {2, 39},
+            {3, 40},
+            {4, 41},
+            {5, 42},
+            {6, 43},
+            {7, 44},
+            {8, 45},
+            {9, 46},
+            {10, 47},
+            {11, 48},
+            {12, 49},
+            {13, 50},
+            {14, 51},
+            {15, 52},
+            {16, 53},
+            {17, 54},
+            {18, 55},
+            {19, 56},
+            {20, 57},
+            {21, 58},
+            {22, 59},
+            {23, 60},
+        },
+        {
+            {0, 37},
+            {1, 38},
+            {2, 39},
+            {3, 40},
+            {4, 41},
+            {5, 42},
+            {6, 43},
+            {7, 44},
+            {8, 45},
+            {9, 46},
+            {10, 47},
+            {21, 58},
+            {22, 59},
+            {23, 60},
+        },
+        {
+            {0, 37},
+            {1, 38},
+            {2, 39},
+            {3, 40},
+            {4, 41},
+            {5, 42},
+            {6, 43},
+            {7, 44},
+            {8, 45},
+            {9, 46},
+            {10, 47},
+            {21, 58},
+            {22, 59},
+            {23, 60},
+        },
+        {
+            {0, 37},
+            {1, 38},
+            {2, 39},
+            {4, 41},
+            {5, 42},
+            {6, 43},
+            {7, 44},
+            {8, 45},
+            {9, 46},
+            {10, 47},
+            {11, 48},
+            {12, 49},
+            {13, 50},
+            {14, 51},
+            {15, 52},
+            {16, 53},
+            {17, 54},
+            {18, 55},
+            {19, 56},
+            {20, 57},
+            {21, 58},
+            {22, 59},
+            {23, 60},
+        },
+        {
+            {0, 37},
+            {1, 38},
+            {2, 39},
+            {3, 40},
+            {4, 41},
+            {5, 42},
+            {6, 43},
+            {7, 44},
+            {8, 45},
+            {9, 46},
+            {10, 47},
+            {11, 48},
+            {12, 49},
+            {13, 50},
+            {14, 51},
+            {15, 52},
+            {16, 53},
+            {17, 54},
+            {18, 55},
+            {19, 56},
+            {20, 57},
+            {21, 58},
+            {22, 59},
+            {23, 60},
+        },
+        {
+            {0, 37},
+            {1, 38},
+            {2, 39},
+            {3, 40},
+            {4, 41},
+            {5, 42},
+            {6, 43},
+            {7, 44},
+            {8, 45},
+            {9, 46},
+            {10, 47},
+            {21, 58},
+            {22, 59},
+            {23, 60},
+        },
+        {
+            {0, 37},
+            {1, 38},
+            {2, 39},
+            {3, 40},
+            {4, 41},
+            {5, 42},
+            {6, 43},
+            {7, 44},
+            {8, 45},
+            {9, 46},
+            {10, 47},
+            {21, 58},
+            {22, 59},
+            {23, 60},
+        },
+        {
+            {0, 37},
+            {1, 38},
+            {2, 39},
+            {4, 41},
+            {5, 42},
+            {6, 43},
+            {7, 44},
+            {8, 45},
+            {9, 46},
+            {10, 47},
+            {11, 48},
+            {12, 49},
+            {13, 50},
+            {14, 51},
+            {15, 52},
+            {16, 53},
+            {17, 54},
+            {18, 55},
+            {19, 56},
+            {20, 57},
+            {21, 58},
+            {22, 59},
+            {23, 60},
+        },
+    };
+
+    std::vector<llama_grammar_candidate> rejects = llama_grammar_reject_candidates_for_stack(grammar->rules, grammar->stacks[0], next_candidates);
+
+    std::vector<std::vector<llama_grammar_candidate>> all_rejects;
+
+    for (std::size_t count = 0; count < grammar->stacks.size(); ++count)
+    {
+        rejects = llama_grammar_reject_candidates_for_stack(grammar->rules, grammar->stacks[count], next_candidates);
+        all_rejects.push_back(rejects);
+    }
+
+    index = 0;
+    for (auto rej : all_rejects)
+    {
+        for (uint32_t i = 0; i < rej.size(); i++)
+        {
+            auto element = rej[i];
+            auto expected_element = expected_reject[index][i];
+            assert(element.index == expected_element.first && *element.code_points == expected_element.second);
+        }
+        index++;
+    }
+
+    for (auto &candidate : next_candidates)
+    {
+        delete[] candidate.code_points;
+        candidate.code_points = nullptr;
+    }
+    llama_grammar_free(grammar);
+    return 0;
+}