Merge branch 'master' into compilade/mamba2

Commit: 7d16e1bc8c
Mirror of https://github.com/ggerganov/llama.cpp.git (synced 2024-11-15 07:19:53 +00:00)
@@ -88,6 +88,10 @@ if (NOT DEFINED GGML_LLAMAFILE)
     set(GGML_LLAMAFILE_DEFAULT ON)
 endif()
 
+if (NOT DEFINED GGML_AMX)
+    set(GGML_AMX ON)
+endif()
+
 if (NOT DEFINED GGML_CUDA_GRAPHS)
     set(GGML_CUDA_GRAPHS_DEFAULT ON)
 endif()
Makefile (24 changed lines)

@@ -93,11 +93,6 @@ GGML_METAL := 1
 DEPRECATE_WARNING := 1
 endif
 
-ifdef LLAMA_OPENMP
-GGML_OPENMP := 1
-DEPRECATE_WARNING := 1
-endif
-
 ifdef LLAMA_RPC
 GGML_RPC := 1
 DEPRECATE_WARNING := 1

@@ -584,6 +579,11 @@ ifndef GGML_NO_LLAMAFILE
 	OBJ_GGML += ggml/src/llamafile/sgemm.o
 endif
 
+ifndef GGML_NO_AMX
+	MK_CPPFLAGS += -DGGML_USE_AMX
+	OBJ_GGML    += ggml/src/ggml-amx.o ggml/src/ggml-amx/mmq.o
+endif
+
 ifdef GGML_RPC
 	MK_CPPFLAGS += -DGGML_USE_RPC
 	OBJ_GGML    += ggml/src/ggml-rpc.o

@@ -1087,6 +1087,19 @@ ggml/src/llamafile/sgemm.o: \
 	$(CXX) $(CXXFLAGS) -c $< -o $@
 endif # GGML_NO_LLAMAFILE
 
+ifndef GGML_NO_AMX
+ggml/src/ggml-amx.o: \
+	ggml/src/ggml-amx.cpp \
+	ggml/include/ggml-amx.h
+	$(CXX) $(CXXFLAGS) -c $< -o $@
+
+ggml/src/ggml-amx/mmq.o: \
+	ggml/src/ggml-amx/mmq.cpp \
+	ggml/src/ggml-amx/mmq.h \
+	ggml/include/ggml.h
+	$(CXX) $(CXXFLAGS) -c $< -o $@
+endif
+
 ifdef GGML_RPC
 ggml/src/ggml-rpc.o: \
 	ggml/src/ggml-rpc.cpp \

@@ -1238,6 +1251,7 @@ clean:
 	rm -vrf ggml/src/ggml-metal-embed.metal
 	rm -vrf ggml/src/ggml-cuda/*.o
 	rm -vrf ggml/src/ggml-cuda/template-instances/*.o
+	rm -vrf ggml/src/ggml-amx/*.o
 	rm -rvf $(BUILD_TARGETS)
 	rm -rvf $(TEST_TARGETS)
 	rm -f vulkan-shaders-gen ggml/src/ggml-vulkan-shaders.hpp ggml/src/ggml-vulkan-shaders.cpp
@@ -29,7 +29,7 @@ variety of hardware - locally and in the cloud.
 
 - Plain C/C++ implementation without any dependencies
 - Apple silicon is a first-class citizen - optimized via ARM NEON, Accelerate and Metal frameworks
-- AVX, AVX2 and AVX512 support for x86 architectures
+- AVX, AVX2, AVX512 and AMX support for x86 architectures
 - 1.5-bit, 2-bit, 3-bit, 4-bit, 5-bit, 6-bit, and 8-bit integer quantization for faster inference and reduced memory use
 - Custom CUDA kernels for running LLMs on NVIDIA GPUs (support for AMD GPUs via HIP and Moore Threads MTT GPUs via MUSA)
 - Vulkan and SYCL backend support

@@ -93,6 +93,7 @@ Typically finetunes of the base models below are supported as well.
 - [x] [FalconMamba Models](https://huggingface.co/collections/tiiuae/falconmamba-7b-66b9a580324dd1598b0f6d4a)
 - [x] [Jais](https://huggingface.co/inceptionai/jais-13b-chat)
 - [x] [Bielik-11B-v2.3](https://huggingface.co/collections/speakleash/bielik-11b-v23-66ee813238d9b526a072408a)
+- [x] [RWKV-6](https://github.com/BlinkDL/RWKV-LM)
 
 (instructions for supporting more models: [HOWTO-add-model.md](./docs/development/HOWTO-add-model.md))
 

@@ -122,6 +123,7 @@ Typically finetunes of the base models below are supported as well.
 - Rust (nicer API): [mdrokz/rust-llama.cpp](https://github.com/mdrokz/rust-llama.cpp)
 - Rust (more direct bindings): [utilityai/llama-cpp-rs](https://github.com/utilityai/llama-cpp-rs)
 - C#/.NET: [SciSharp/LLamaSharp](https://github.com/SciSharp/LLamaSharp)
+- C#/VB.NET (more features - community license): [LM-Kit.NET](https://docs.lm-kit.com/lm-kit-net/index.html)
 - Scala 3: [donderom/llm4s](https://github.com/donderom/llm4s)
 - Clojure: [phronmophobic/llama.clj](https://github.com/phronmophobic/llama.clj)
 - React Native: [mybigday/llama.rn](https://github.com/mybigday/llama.rn)

@@ -130,6 +132,8 @@ Typically finetunes of the base models below are supported as well.
 - Flutter/Dart: [netdur/llama_cpp_dart](https://github.com/netdur/llama_cpp_dart)
 - PHP (API bindings and features built on top of llama.cpp): [distantmagic/resonance](https://github.com/distantmagic/resonance) [(more info)](https://github.com/ggerganov/llama.cpp/pull/6326)
 - Guile Scheme: [guile_llama_cpp](https://savannah.nongnu.org/projects/guile-llama-cpp)
+- Swift [srgtuszy/llama-cpp-swift](https://github.com/srgtuszy/llama-cpp-swift)
+- Swift [ShenghaiWang/SwiftLlama](https://github.com/ShenghaiWang/SwiftLlama)
 
 **UI:**
 

@@ -170,6 +174,7 @@ Unless otherwise noted these projects are open-source with permissive licensing:
 - [LARS - The LLM & Advanced Referencing Solution](https://github.com/abgulati/LARS) (AGPL)
 - [LLMUnity](https://github.com/undreamai/LLMUnity) (MIT)
 - [Llama Assistant](https://github.com/vietanhdev/llama-assistant) (GPL)
+- [PocketPal AI - An iOS and Android App](https://github.com/a-ghorbani/pocketpal-ai) (MIT)
 
 *(to have a project listed here, it should clearly state that it depends on `llama.cpp`)*
 

@@ -185,6 +190,7 @@ Unless otherwise noted these projects are open-source with permissive licensing:
 
 - [Paddler](https://github.com/distantmagic/paddler) - Stateful load balancer custom-tailored for llama.cpp
 - [GPUStack](https://github.com/gpustack/gpustack) - Manage GPU clusters for running LLMs
+- [llama_cpp_canister](https://github.com/onicai/llama_cpp_canister) - llama.cpp as a smart contract on the Internet Computer, using WebAssembly
 
 **Games:**
 - [Lucy's Labyrinth](https://github.com/MorganRO8/Lucys_Labyrinth) - A simple maze game where agents controlled by an AI model will try to trick you.
@@ -53,7 +53,7 @@ if [ ! -z ${GG_BUILD_SYCL} ]; then
         exit 1
     fi
 
-    CMAKE_EXTRA="${CMAKE_EXTRA} -DGGML_SYCL=1 DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx -DGGML_SYCL_F16=ON"
+    CMAKE_EXTRA="${CMAKE_EXTRA} -DGGML_SYCL=1 -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx -DGGML_SYCL_F16=ON"
 fi
 
 if [ ! -z ${GG_BUILD_VULKAN} ]; then
@@ -128,13 +128,13 @@ static void common_params_handle_model_default(common_params & params) {
             }
             params.hf_file = params.model;
         } else if (params.model.empty()) {
-            params.model = fs_get_cache_file(string_split(params.hf_file, '/').back());
+            params.model = fs_get_cache_file(string_split<std::string>(params.hf_file, '/').back());
         }
     } else if (!params.model_url.empty()) {
         if (params.model.empty()) {
-            auto f = string_split(params.model_url, '#').front();
-            f = string_split(f, '?').front();
-            params.model = fs_get_cache_file(string_split(f, '/').back());
+            auto f = string_split<std::string>(params.model_url, '#').front();
+            f = string_split<std::string>(f, '?').front();
+            params.model = fs_get_cache_file(string_split<std::string>(f, '/').back());
         }
     } else if (params.model.empty()) {
         params.model = DEFAULT_MODEL_PATH;

@@ -251,6 +251,9 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
         for (auto & antiprompt : params.antiprompt) {
             string_process_escapes(antiprompt);
         }
+        for (auto & seq_breaker : params.sparams.dry_sequence_breakers) {
+            string_process_escapes(seq_breaker);
+        }
     }
 
     if (!params.kv_overrides.empty()) {

@@ -879,7 +882,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         {"--samplers"}, "SAMPLERS",
         string_format("samplers that will be used for generation in the order, separated by \';\'\n(default: %s)", sampler_type_names.c_str()),
         [](common_params & params, const std::string & value) {
-            const auto sampler_names = string_split(value, ';');
+            const auto sampler_names = string_split<std::string>(value, ';');
             params.sparams.samplers = common_sampler_types_from_names(sampler_names, true);
         }
     ).set_sparam());
@@ -941,10 +944,17 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         }
     ).set_sparam());
     add_opt(common_arg(
-        {"--tfs"}, "N",
-        string_format("tail free sampling, parameter z (default: %.1f, 1.0 = disabled)", (double)params.sparams.tfs_z),
+        {"--xtc-probability"}, "N",
+        string_format("xtc probability (default: %.1f, 0.0 = disabled)", (double)params.sparams.xtc_probability),
         [](common_params & params, const std::string & value) {
-            params.sparams.tfs_z = std::stof(value);
+            params.sparams.xtc_probability = std::stof(value);
+        }
+    ).set_sparam());
+    add_opt(common_arg(
+        {"--xtc-threshold"}, "N",
+        string_format("xtc threshold (default: %.1f, 1.0 = disabled)", (double)params.sparams.xtc_threshold),
+        [](common_params & params, const std::string & value) {
+            params.sparams.xtc_threshold = std::stof(value);
         }
     ).set_sparam());
     add_opt(common_arg(
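Context for the new flags (not part of the diff itself): XTC ("exclude top choices") is usually described as removing, with probability `--xtc-probability`, every candidate token whose probability is at least `--xtc-threshold`, except the least likely of those, which steers generation away from the single most predictable continuation. A rough standalone sketch of that idea under those assumptions; the `Candidate` struct and function are illustrative only, not the llama.cpp implementation:

    #include <algorithm>
    #include <random>
    #include <vector>

    struct Candidate { int token; float p; };  // hypothetical token/probability pair

    // Sketch of the XTC idea: with probability xtc_probability, drop all candidates
    // whose p >= xtc_threshold except the least probable one among them.
    void xtc_sketch(std::vector<Candidate> & cands, float xtc_probability, float xtc_threshold, std::mt19937 & rng) {
        std::uniform_real_distribution<float> coin(0.0f, 1.0f);
        if (coin(rng) >= xtc_probability) {
            return; // XTC not applied this step
        }
        // sort by probability, highest first
        std::sort(cands.begin(), cands.end(), [](const Candidate & a, const Candidate & b) { return a.p > b.p; });
        size_t above = 0;
        while (above < cands.size() && cands[above].p >= xtc_threshold) {
            above++;
        }
        if (above >= 2) {
            // keep only the last (least probable) of the "top choices"
            cands.erase(cands.begin(), cands.begin() + above - 1);
        }
    }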
@@ -983,6 +993,64 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.sparams.penalty_freq = std::stof(value);
         }
     ).set_sparam());
+    add_opt(common_arg(
+        {"--dry-multiplier"}, "N",
+        string_format("set DRY sampling multiplier (default: %.1f, 0.0 = disabled)", (double)params.sparams.dry_multiplier),
+        [](common_params & params, const std::string & value) {
+            params.sparams.dry_multiplier = std::stof(value);
+        }
+    ).set_sparam());
+    add_opt(common_arg(
+        {"--dry-base"}, "N",
+        string_format("set DRY sampling base value (default: %.2f)", (double)params.sparams.dry_base),
+        [](common_params & params, const std::string & value) {
+            float potential_base = std::stof(value);
+            if (potential_base >= 1.0f)
+            {
+                params.sparams.dry_base = potential_base;
+            }
+        }
+    ).set_sparam());
+    add_opt(common_arg(
+        {"--dry-allowed-length"}, "N",
+        string_format("set allowed length for DRY sampling (default: %d)", params.sparams.dry_allowed_length),
+        [](common_params & params, int value) {
+            params.sparams.dry_allowed_length = value;
+        }
+    ).set_sparam());
+    add_opt(common_arg(
+        {"--dry-penalty-last-n"}, "N",
+        string_format("set DRY penalty for the last n tokens (default: %d, 0 = disable, -1 = context size)", params.sparams.dry_penalty_last_n),
+        [](common_params & params, int value) {
+            params.sparams.dry_penalty_last_n = value;
+        }
+    ).set_sparam());
+    add_opt(common_arg(
+        {"--dry-sequence-breaker"}, "STRING",
+        string_format("add sequence breaker for DRY sampling, clearing out default breakers (%s) in the process; use \"none\" to not use any sequence breakers\n",
+            params.sparams.dry_sequence_breakers.empty() ? "none" :
+                std::accumulate(std::next(params.sparams.dry_sequence_breakers.begin()),
+                    params.sparams.dry_sequence_breakers.end(),
+                    std::string("'") + (params.sparams.dry_sequence_breakers[0] == "\n" ? "\\n" : params.sparams.dry_sequence_breakers[0]) + "'",
+                    [](const std::string& a, const std::string& b) {
+                        std::string formatted_b = (b == "\n") ? "\\n" : b;
+                        return a + ", '" + formatted_b + "'";
+                    }).c_str()),
+        [](common_params & params, const std::string & value) {
+            static bool defaults_cleared = false;
+
+            if (!defaults_cleared) {
+                params.sparams.dry_sequence_breakers.clear();
+                defaults_cleared = true;
+            }
+
+            if (value == "none") {
+                params.sparams.dry_sequence_breakers.clear();
+            } else {
+                params.sparams.dry_sequence_breakers.emplace_back(value);
+            }
+        }
+    ).set_sparam());
     add_opt(common_arg(
         {"--dynatemp-range"}, "N",
         string_format("dynamic temperature range (default: %.1f, 0.0 = disabled)", (double)params.sparams.dynatemp_range),
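As a quick reference for how these DRY parameters combine: the common.h comments later in this diff give the penalty as multiplier * base ^ (length of repeated sequence before the token - allowed length). The helper below is only an illustrative sketch of that formula, using placeholder values:

    #include <cmath>
    #include <cstdio>

    // Sketch only: penalty for a token that would extend a repetition of `repeat_len` tokens,
    // per the formula quoted in the common.h comments:
    //   penalty = dry_multiplier * pow(dry_base, repeat_len - dry_allowed_length)
    float dry_penalty_sketch(float dry_multiplier, float dry_base, int dry_allowed_length, int repeat_len) {
        if (dry_multiplier <= 0.0f || repeat_len < dry_allowed_length) {
            return 0.0f; // disabled, or the repetition is still within the allowed length
        }
        return dry_multiplier * std::pow(dry_base, (float)(repeat_len - dry_allowed_length));
    }

    int main() {
        // e.g. --dry-multiplier 0.8 --dry-base 1.75 --dry-allowed-length 2
        for (int len = 2; len <= 6; ++len) {
            std::printf("repeat_len=%d -> penalty=%.3f\n", len, dry_penalty_sketch(0.8f, 1.75f, 2, len));
        }
        return 0;
    }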
@@ -999,7 +1067,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
     ).set_sparam());
     add_opt(common_arg(
         {"--mirostat"}, "N",
-        string_format("use Mirostat sampling.\nTop K, Nucleus, Tail Free and Locally Typical samplers are ignored if used.\n"
+        string_format("use Mirostat sampling.\nTop K, Nucleus and Locally Typical samplers are ignored if used.\n"
             "(default: %d, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)", params.sparams.mirostat),
         [](common_params & params, int value) {
             params.sparams.mirostat = value;

@@ -1083,7 +1151,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         }
     ).set_examples({LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_RETRIEVAL, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_POOLING"));
     add_opt(common_arg(
-        {"--attention"}, "{causal,non,causal}",
+        {"--attention"}, "{causal,non-causal}",
         "attention type for embeddings, use model default if unspecified",
         [](common_params & params, const std::string & value) {
             /**/ if (value == "causal") { params.attention_type = LLAMA_ATTENTION_TYPE_CAUSAL; }

@@ -1681,7 +1749,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
     ).set_examples({LLAMA_EXAMPLE_BENCH}));
     add_opt(common_arg(
         {"--embd-normalize"}, "N",
-        string_format("normalisation for embendings (default: %d) (-1=none, 0=max absolute int16, 1=taxicab, 2=euclidean, >2=p-norm)", params.embd_normalize),
+        string_format("normalisation for embeddings (default: %d) (-1=none, 0=max absolute int16, 1=taxicab, 2=euclidean, >2=p-norm)", params.embd_normalize),
         [](common_params & params, int value) {
             params.embd_normalize = value;
         }

@@ -1695,7 +1763,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
     ).set_examples({LLAMA_EXAMPLE_EMBEDDING}));
     add_opt(common_arg(
         {"--embd-separator"}, "STRING",
-        "separator of embendings (default \\n) for example \"<#sep#>\"",
+        "separator of embeddings (default \\n) for example \"<#sep#>\"",
         [](common_params & params, const std::string & value) {
             params.embd_sep = value;
         }

@@ -1788,6 +1856,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.n_threads_http = value;
         }
     ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_THREADS_HTTP"));
+    add_opt(common_arg(
+        {"--cache-reuse"}, "N",
+        string_format("min chunk size to attempt reusing from the cache via KV shifting (default: %d)", params.n_cache_reuse),
+        [](common_params & params, int value) {
+            params.n_cache_reuse = value;
+        }
+    ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_CACHE_REUSE"));
     add_opt(common_arg(
         {"--metrics"},
         string_format("enable prometheus compatible metrics endpoint (default: %s)", params.endpoint_metrics ? "enabled" : "disabled"),
@@ -416,19 +416,6 @@ std::string string_format(const char * fmt, ...) {
     return std::string(buf.data(), size);
 }
 
-std::vector<std::string> string_split(std::string input, char separator) {
-    std::vector<std::string> parts;
-    size_t separator_pos = input.find(separator);
-    while (separator_pos != std::string::npos) {
-        std::string part = input.substr(0, separator_pos);
-        parts.emplace_back(part);
-        input = input.substr(separator_pos + 1);
-        separator_pos = input.find(separator);
-    }
-    parts.emplace_back(input);
-    return parts;
-}
-
 std::string string_strip(const std::string & str) {
     size_t start = 0;
     size_t end = str.size();
@@ -955,7 +942,7 @@ struct common_init_result common_init_from_params(common_params & params) {
     }
 
     if (llama_model_has_encoder(model)) {
-        llama_encode(lctx, llama_batch_get_one(tmp.data(), tmp.size(), 0, 0));
+        llama_encode(lctx, llama_batch_get_one(tmp.data(), tmp.size()));
         llama_token decoder_start_token_id = llama_model_decoder_start_token(model);
         if (decoder_start_token_id == -1) {
             decoder_start_token_id = bos;

@@ -964,7 +951,7 @@ struct common_init_result common_init_from_params(common_params & params) {
         tmp.push_back(decoder_start_token_id);
     }
     if (llama_model_has_decoder(model)) {
-        llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0));
+        llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch)));
     }
     llama_kv_cache_clear(lctx);
     llama_synchronize(lctx);
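The two hunks above reflect the simplified llama_batch_get_one() signature: it now takes only the token pointer and the token count, with the trailing position and sequence arguments gone. A minimal hedged sketch of the new call pattern, assuming a context `lctx` and a token buffer as in the surrounding code (the function name is illustrative):

    #include <algorithm>
    #include <vector>
    #include "llama.h"  // assumed include path, as elsewhere in the tree

    // Sketch: warm-up decode with the two-argument llama_batch_get_one().
    void warmup_decode_sketch(llama_context * lctx, std::vector<llama_token> & tmp, int32_t n_batch) {
        const int32_t n = (int32_t) std::min(tmp.size(), (size_t) n_batch);
        if (llama_decode(lctx, llama_batch_get_one(tmp.data(), n))) {
            // handle failure
        }
        llama_kv_cache_clear(lctx);
        llama_synchronize(lctx);
    }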
@@ -1035,7 +1022,7 @@ static ggml_type kv_cache_type_from_str(const std::string & s) {
         return GGML_TYPE_Q5_1;
     }
 
-    throw std::runtime_error("Invalid cache type: " + s);
+    throw std::runtime_error("Unsupported cache type: " + s);
 }
 
 struct llama_context_params common_context_params_to_llama(const common_params & params) {

@@ -2019,6 +2006,10 @@ void yaml_dump_non_result_info(FILE * stream, const common_params & params, cons
     fprintf(stream, "chunks: %d # default: -1 (unlimited)\n", params.n_chunks);
     fprintf(stream, "color: %s # default: false\n", params.use_color ? "true" : "false");
     fprintf(stream, "ctx_size: %d # default: 512\n", params.n_ctx);
+    fprintf(stream, "dry_allowed_length: %d # default: 2\n", sparams.dry_allowed_length);
+    fprintf(stream, "dry_base: %.2f # default: 1.75\n", sparams.dry_base);
+    fprintf(stream, "dry_multiplier: %.1f # default: 0.0\n", sparams.dry_multiplier);
+    fprintf(stream, "dry_penalty_last_n: %d # default: -1 (0 = disable, -1 = context size)\n", sparams.dry_penalty_last_n);
     fprintf(stream, "escape: %s # default: false\n", params.escape ? "true" : "false");
     fprintf(stream, "file: # never logged, see prompt instead. Can still be specified for input.\n");
     fprintf(stream, "frequency_penalty: %f # default: 0.0 \n", sparams.penalty_freq);

@@ -2099,11 +2090,12 @@ void yaml_dump_non_result_info(FILE * stream, const common_params & params, cons
     const std::vector<float> tensor_split_vector(params.tensor_split, params.tensor_split + llama_max_devices());
     yaml_dump_vector_float(stream, "tensor_split", tensor_split_vector);
 
-    fprintf(stream, "tfs: %f # default: 1.0\n", sparams.tfs_z);
     fprintf(stream, "threads: %d # default: %u\n", params.cpuparams.n_threads, std::thread::hardware_concurrency());
     fprintf(stream, "top_k: %d # default: 40\n", sparams.top_k);
     fprintf(stream, "top_p: %f # default: 0.95\n", sparams.top_p);
     fprintf(stream, "min_p: %f # default: 0.0\n", sparams.min_p);
+    fprintf(stream, "xtc_probability: %f # default: 0.0\n", sparams.xtc_probability);
+    fprintf(stream, "xtc_threshold: %f # default: 0.1\n", sparams.xtc_threshold);
     fprintf(stream, "typ_p: %f # default: 1.0\n", sparams.typ_p);
     fprintf(stream, "verbose_prompt: %s # default: false\n", params.verbose_prompt ? "true" : "false");
     fprintf(stream, "display_prompt: %s # default: true\n", params.display_prompt ? "true" : "false");
@@ -84,12 +84,15 @@ enum llama_example {
 
 enum common_sampler_type {
     COMMON_SAMPLER_TYPE_NONE        = 0,
-    COMMON_SAMPLER_TYPE_TOP_K       = 1,
-    COMMON_SAMPLER_TYPE_TOP_P       = 2,
-    COMMON_SAMPLER_TYPE_MIN_P       = 3,
-    COMMON_SAMPLER_TYPE_TFS_Z       = 4,
-    COMMON_SAMPLER_TYPE_TYPICAL_P   = 5,
-    COMMON_SAMPLER_TYPE_TEMPERATURE = 6,
+    COMMON_SAMPLER_TYPE_DRY         = 1,
+    COMMON_SAMPLER_TYPE_TOP_K       = 2,
+    COMMON_SAMPLER_TYPE_TOP_P       = 3,
+    COMMON_SAMPLER_TYPE_MIN_P       = 4,
+  //COMMON_SAMPLER_TYPE_TFS_Z       = 5,
+    COMMON_SAMPLER_TYPE_TYPICAL_P   = 6,
+    COMMON_SAMPLER_TYPE_TEMPERATURE = 7,
+    COMMON_SAMPLER_TYPE_XTC         = 8,
+    COMMON_SAMPLER_TYPE_INFILL      = 9,
 };
 
 // dimensionality reduction methods, used by cvector-generator

@@ -108,7 +111,8 @@ struct common_sampler_params {
     int32_t top_k = 40;    // <= 0 to use vocab size
     float   top_p = 0.95f; // 1.0 = disabled
     float   min_p = 0.05f; // 0.0 = disabled
-    float   tfs_z = 1.00f; // 1.0 = disabled
+    float   xtc_probability = 0.00f; // 0.0 = disabled
+    float   xtc_threshold   = 0.10f; // > 0.5 disables XTC
     float   typ_p = 1.00f; // typical_p, 1.0 = disabled
     float   temp  = 0.80f; // <= 0.0 to sample greedily, 0.0 to not output probabilities
     float   dynatemp_range = 0.00f; // 0.0 = disabled

@@ -117,6 +121,10 @@ struct common_sampler_params {
     float   penalty_repeat  = 1.00f; // 1.0 = disabled
     float   penalty_freq    = 0.00f; // 0.0 = disabled
     float   penalty_present = 0.00f; // 0.0 = disabled
+    float   dry_multiplier      = 0.0f;  // 0.0 = disabled; DRY repetition penalty for tokens extending repetition:
+    float   dry_base            = 1.75f; // 0.0 = disabled; multiplier * base ^ (length of sequence before token - allowed length)
+    int32_t dry_allowed_length  = 2;     // tokens extending repetitions beyond this receive penalty
+    int32_t dry_penalty_last_n  = -1;    // how many tokens to scan for repetitions (0 = disable penalty, -1 = context size)
     int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
     float   mirostat_tau = 5.00f; // target entropy
     float   mirostat_eta = 0.10f; // learning rate

@@ -124,13 +132,17 @@ struct common_sampler_params {
     bool ignore_eos = false;
     bool no_perf    = false; // disable performance metrics
 
+    std::vector<std::string> dry_sequence_breakers = {"\n", ":", "\"", "*"}; // default sequence breakers for DRY
+
+
     std::vector<enum common_sampler_type> samplers = {
+        COMMON_SAMPLER_TYPE_DRY,
         COMMON_SAMPLER_TYPE_TOP_K,
-        COMMON_SAMPLER_TYPE_TFS_Z,
         COMMON_SAMPLER_TYPE_TYPICAL_P,
         COMMON_SAMPLER_TYPE_TOP_P,
         COMMON_SAMPLER_TYPE_MIN_P,
-        COMMON_SAMPLER_TYPE_TEMPERATURE
+        COMMON_SAMPLER_TYPE_XTC,
+        COMMON_SAMPLER_TYPE_TEMPERATURE,
     };
 
     std::string grammar; // optional BNF-like grammar to constrain sampling
@@ -268,16 +280,17 @@ struct common_params {
 
     // embedding
     bool embedding = false; // get only sentence embedding
-    int32_t embd_normalize = 2; // normalisation for embendings (-1=none, 0=max absolute int16, 1=taxicab, 2=euclidean, >2=p-norm)
+    int32_t embd_normalize = 2; // normalisation for embeddings (-1=none, 0=max absolute int16, 1=taxicab, 2=euclidean, >2=p-norm)
     std::string embd_out = ""; // empty = default, "array" = [[],[]...], "json" = openai style, "json+" = same "json" + cosine similarity matrix
-    std::string embd_sep = "\n"; // separator of embendings
+    std::string embd_sep = "\n"; // separator of embeddings
     bool reranking = false; // enable reranking support on server
 
     // server params
     int32_t port           = 8080;         // server listens on this network port
     int32_t timeout_read   = 600;          // http read timeout in seconds
     int32_t timeout_write  = timeout_read; // http write timeout in seconds
-    int     n_threads_http = -1;           // number of threads to process HTTP requests (TODO: support threadpool)
+    int32_t n_threads_http = -1;           // number of threads to process HTTP requests (TODO: support threadpool)
+    int32_t n_cache_reuse  = 0;            // min chunk size to reuse from the cache via KV shifting
 
     std::string hostname      = "127.0.0.1";
     std::string public_path   = "";         // NOLINT
@@ -373,8 +386,6 @@ bool set_process_priority(enum ggml_sched_priority prio);
 LLAMA_COMMON_ATTRIBUTE_FORMAT(1, 2)
 std::string string_format(const char * fmt, ...);
 
-std::vector<std::string> string_split(std::string input, char separator);
-
 std::string string_strip(const std::string & str);
 std::string string_get_sortable_timestamp();
 

@@ -382,6 +393,7 @@ void string_replace_all(std::string & s, const std::string & search, const std::
 
 template<class T>
 static std::vector<T> string_split(const std::string & str, char delim) {
+    static_assert(!std::is_same<T, std::string>::value, "Please use the specialized version for std::string");
     std::vector<T> values;
     std::istringstream str_stream(str);
     std::string token;

@@ -394,6 +406,22 @@ static std::vector<T> string_split(const std::string & str, char delim) {
     return values;
 }
 
+template<>
+std::vector<std::string> string_split<std::string>(const std::string & input, char separator)
+{
+    std::vector<std::string> parts;
+    size_t begin_pos = 0;
+    size_t separator_pos = input.find(separator);
+    while (separator_pos != std::string::npos) {
+        std::string part = input.substr(begin_pos, separator_pos - begin_pos);
+        parts.emplace_back(part);
+        begin_pos = separator_pos + 1;
+        separator_pos = input.find(separator, begin_pos);
+    }
+    parts.emplace_back(input.substr(begin_pos, separator_pos - begin_pos));
+    return parts;
+}
+
 bool string_parse_kv_override(const char * data, std::vector<llama_model_kv_override> & overrides);
 void string_process_escapes(std::string & input);
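Usage stays the same apart from the explicit template argument. A short hedged sketch of how the two forms behave (assuming common.h is included as elsewhere in the tree; the example values are illustrative):

    #include <string>
    #include <vector>
    #include "common.h"  // assumed include path for string_split

    void split_examples() {
        // specialization added above: splits on the separator and keeps empty parts,
        // matching the old free function that this change removes
        std::vector<std::string> files = string_split<std::string>("a.gguf,b.gguf,", ',');
        // -> {"a.gguf", "b.gguf", ""}

        // generic template: each delimited piece is parsed into T via a string stream
        std::vector<int> layers = string_split<int>("0;1;2", ';');
        // -> {0, 1, 2}
    }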
@@ -611,7 +611,7 @@ private:
             }
             return join_seq();
         };
-        return _add_rule(name, "\"\\\"\" " + to_rule(transform()) + " \"\\\"\" space");
+        return _add_rule(name, "\"\\\"\" (" + to_rule(transform()) + ") \"\\\"\" space");
     }
 
     /*

@@ -130,10 +130,12 @@ std::string common_sampler_params::print() const {
 
     snprintf(result, sizeof(result),
             "\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n"
-            "\ttop_k = %d, tfs_z = %.3f, top_p = %.3f, min_p = %.3f, typical_p = %.3f, temp = %.3f\n"
+            "\tdry_multiplier = %.3f, dry_base = %.3f, dry_allowed_length = %d, dry_penalty_last_n = %d\n"
+            "\ttop_k = %d, top_p = %.3f, min_p = %.3f, xtc_probability = %.3f, xtc_threshold = %.3f, typical_p = %.3f, temp = %.3f\n"
             "\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f",
             penalty_last_n, penalty_repeat, penalty_freq, penalty_present,
-            top_k, tfs_z, top_p, min_p, typ_p, temp,
+            dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n,
+            top_k, top_p, min_p, xtc_probability, xtc_threshold, typ_p, temp,
             mirostat, mirostat_eta, mirostat_tau);
 
     return std::string(result);
@@ -171,10 +173,20 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
                 params.penalize_nl,
                 params.ignore_eos));
 
-    if (params.temp > 0.0f) {
     if (params.mirostat == 0) {
         for (const auto & cnstr : params.samplers) {
             switch (cnstr) {
+                case COMMON_SAMPLER_TYPE_DRY:
+                    {
+                        std::vector<const char*> c_breakers;
+                        c_breakers.reserve(params.dry_sequence_breakers.size());
+                        for (const auto& str : params.dry_sequence_breakers) {
+                            c_breakers.push_back(str.c_str());
+                        }
+
+                        llama_sampler_chain_add(result->chain, llama_sampler_init_dry      (model, params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
+                    }
+                    break;
                 case COMMON_SAMPLER_TYPE_TOP_K:
                     llama_sampler_chain_add(result->chain, llama_sampler_init_top_k    (params.top_k));
                     break;
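The call above shows the full llama_sampler_init_dry() signature used by this merge: model, multiplier, base, allowed length, penalty window, then the breaker strings and their count. For readers wiring it up outside common_sampler_init(), a rough sketch, assuming the usual chain helpers from llama.h (llama_sampler_chain_init, llama_sampler_chain_default_params, llama_sampler_init_dist) are available; the function name and parameter values are placeholders:

    #include <cstddef>
    #include "llama.h"  // assumed include path

    // Sketch: a minimal chain with DRY followed by a final distribution sampler.
    static llama_sampler * make_dry_chain_sketch(const llama_model * model) {
        const char * breakers[] = { "\n", ":", "\"", "*" };  // default DRY sequence breakers from this merge

        llama_sampler * chain = llama_sampler_chain_init(llama_sampler_chain_default_params());
        llama_sampler_chain_add(chain, llama_sampler_init_dry(model,
            /*dry_multiplier     =*/ 0.8f,
            /*dry_base           =*/ 1.75f,
            /*dry_allowed_length =*/ 2,
            /*dry_penalty_last_n =*/ -1,
            breakers, sizeof(breakers) / sizeof(breakers[0])));
        llama_sampler_chain_add(chain, llama_sampler_init_dist(LLAMA_DEFAULT_SEED));
        return chain;
    }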
@@ -184,8 +196,8 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
                 case COMMON_SAMPLER_TYPE_MIN_P:
                     llama_sampler_chain_add(result->chain, llama_sampler_init_min_p    (params.min_p, params.min_keep));
                     break;
-                case COMMON_SAMPLER_TYPE_TFS_Z:
-                    llama_sampler_chain_add(result->chain, llama_sampler_init_tail_free(params.tfs_z, params.min_keep));
+                case COMMON_SAMPLER_TYPE_XTC:
+                    llama_sampler_chain_add(result->chain, llama_sampler_init_xtc      (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
                     break;
                 case COMMON_SAMPLER_TYPE_TYPICAL_P:
                     llama_sampler_chain_add(result->chain, llama_sampler_init_typical  (params.typ_p, params.min_keep));

@@ -193,11 +205,13 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
                 case COMMON_SAMPLER_TYPE_TEMPERATURE:
                     llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
                     break;
+                case COMMON_SAMPLER_TYPE_INFILL:
+                    llama_sampler_chain_add(result->chain, llama_sampler_init_infill   (model));
+                    break;
                 default:
                     GGML_ASSERT(false && "unknown sampler type");
             }
         }
-        llama_sampler_chain_add(result->chain, llama_sampler_init_softmax());
         llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed));
     } else if (params.mirostat == 1) {
         llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));

@@ -208,18 +222,6 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
     } else {
         GGML_ASSERT(false && "unknown mirostat version");
     }
-    } else {
-        if (params.n_probs > 0) {
-            // some use cases require to sample greedily, but still obtain the probabilities of the top tokens
-            // ref: https://github.com/ggerganov/llama.cpp/pull/9605
-            //
-            // the following will not produce exactly the same probs as applyging softmax to the full vocabulary, but
-            // it is much faster, since we avoid sorting all tokens and should give a good approximation
-            llama_sampler_chain_add(result->chain, llama_sampler_init_top_k(params.n_probs));
-            llama_sampler_chain_add(result->chain, llama_sampler_init_softmax());
-        }
-        llama_sampler_chain_add(result->chain, llama_sampler_init_greedy());
-    }
 
     return result;
 }
@@ -366,36 +368,42 @@ std::string common_sampler_prev_str(common_sampler * gsmpl, llama_context * ctx_
 
 char common_sampler_type_to_chr(enum common_sampler_type cnstr) {
     switch (cnstr) {
+        case COMMON_SAMPLER_TYPE_DRY:         return 'd';
         case COMMON_SAMPLER_TYPE_TOP_K:       return 'k';
-        case COMMON_SAMPLER_TYPE_TFS_Z:       return 'f';
         case COMMON_SAMPLER_TYPE_TYPICAL_P:   return 'y';
         case COMMON_SAMPLER_TYPE_TOP_P:       return 'p';
         case COMMON_SAMPLER_TYPE_MIN_P:       return 'm';
         case COMMON_SAMPLER_TYPE_TEMPERATURE: return 't';
+        case COMMON_SAMPLER_TYPE_XTC:         return 'x';
+        case COMMON_SAMPLER_TYPE_INFILL:      return 'i';
         default : return '?';
     }
 }
 
 std::string common_sampler_type_to_str(enum common_sampler_type cnstr) {
     switch (cnstr) {
+        case COMMON_SAMPLER_TYPE_DRY:         return "dry";
         case COMMON_SAMPLER_TYPE_TOP_K:       return "top_k";
-        case COMMON_SAMPLER_TYPE_TFS_Z:       return "tfs_z";
         case COMMON_SAMPLER_TYPE_TYPICAL_P:   return "typ_p";
         case COMMON_SAMPLER_TYPE_TOP_P:       return "top_p";
        case COMMON_SAMPLER_TYPE_MIN_P:       return "min_p";
         case COMMON_SAMPLER_TYPE_TEMPERATURE: return "temperature";
+        case COMMON_SAMPLER_TYPE_XTC:         return "xtc";
+        case COMMON_SAMPLER_TYPE_INFILL:      return "infill";
         default : return "";
     }
 }
 
 std::vector<common_sampler_type> common_sampler_types_from_names(const std::vector<std::string> & names, bool allow_alt_names) {
     std::unordered_map<std::string, common_sampler_type> sampler_canonical_name_map {
+        { "dry",         COMMON_SAMPLER_TYPE_DRY },
         { "top_k",       COMMON_SAMPLER_TYPE_TOP_K },
         { "top_p",       COMMON_SAMPLER_TYPE_TOP_P },
         { "typ_p",       COMMON_SAMPLER_TYPE_TYPICAL_P },
         { "min_p",       COMMON_SAMPLER_TYPE_MIN_P },
-        { "tfs_z",       COMMON_SAMPLER_TYPE_TFS_Z },
         { "temperature", COMMON_SAMPLER_TYPE_TEMPERATURE },
+        { "xtc",         COMMON_SAMPLER_TYPE_XTC },
+        { "infill",      COMMON_SAMPLER_TYPE_INFILL },
     };
 
     // since samplers names are written multiple ways

@@ -409,8 +417,6 @@ std::vector<common_sampler_type> common_sampler_types_from_names(const std::vect
         { "typ-p",       COMMON_SAMPLER_TYPE_TYPICAL_P },
         { "typ",         COMMON_SAMPLER_TYPE_TYPICAL_P },
         { "min-p",       COMMON_SAMPLER_TYPE_MIN_P },
-        { "tfs-z",       COMMON_SAMPLER_TYPE_TFS_Z },
-        { "tfs",         COMMON_SAMPLER_TYPE_TFS_Z },
         { "temp",        COMMON_SAMPLER_TYPE_TEMPERATURE },
     };
 

@@ -436,12 +442,14 @@ std::vector<common_sampler_type> common_sampler_types_from_names(const std::vect
 
 std::vector<common_sampler_type> common_sampler_types_from_chars(const std::string & chars) {
     std::unordered_map<char, common_sampler_type> sampler_name_map = {
+        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_DRY),         COMMON_SAMPLER_TYPE_DRY },
         { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_K),       COMMON_SAMPLER_TYPE_TOP_K },
-        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TFS_Z),       COMMON_SAMPLER_TYPE_TFS_Z },
         { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TYPICAL_P),   COMMON_SAMPLER_TYPE_TYPICAL_P },
         { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_P),       COMMON_SAMPLER_TYPE_TOP_P },
         { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_MIN_P),       COMMON_SAMPLER_TYPE_MIN_P },
-        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TEMPERATURE), COMMON_SAMPLER_TYPE_TEMPERATURE }
+        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TEMPERATURE), COMMON_SAMPLER_TYPE_TEMPERATURE },
+        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_XTC),         COMMON_SAMPLER_TYPE_XTC },
+        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_INFILL),      COMMON_SAMPLER_TYPE_INFILL },
     };
 
     std::vector<common_sampler_type> samplers;
@@ -573,6 +573,9 @@ class Model:
         if chkhsh == "0876d13b50744004aa9aeae05e7b0647eac9d801b5ba4668afc01e709c15e19f":
             # ref: https://huggingface.co/BAAI/bge-small-en-v1.5
             res = "bert-bge"
+        if chkhsh == "8e62295832751ca1e8f92f2226f403dea30dc5165e448b5bfa05af5340c64ec7":
+            # ref: https://huggingface.co/BAAI/bge-large-zh-v1.5
+            res = "bert-bge-large"
         if chkhsh == "b6dc8df998e1cfbdc4eac8243701a65afe638679230920b50d6f17d81c098166":
             # ref: https://huggingface.co/mosaicml/mpt-7b
             res = "mpt"

@@ -2864,6 +2867,9 @@ class Rwkv6Model(Model):
         self.gguf_writer.add_token_list(tokens)
         self.gguf_writer.add_token_types(toktypes)
         special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=False)
+        special_vocab.chat_template = "rwkv-world"
+        # hack: Add '\n\n' as the EOT token to make it chat normally
+        special_vocab._set_special_token("eot", 261)
         special_vocab.add_to_gguf(self.gguf_writer)
 
     def set_gguf_parameters(self):

@@ -72,6 +72,7 @@ models = [
     {"name": "deepseek-coder", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/deepseek-ai/deepseek-coder-6.7b-base", },
     {"name": "falcon",         "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tiiuae/falcon-7b", },
     {"name": "bert-bge",       "tokt": TOKENIZER_TYPE.WPM, "repo": "https://huggingface.co/BAAI/bge-small-en-v1.5", },
+    {"name": "bert-bge-large", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/BAAI/bge-large-zh-v1.5", },
     {"name": "mpt",            "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/mosaicml/mpt-7b", },
     {"name": "starcoder",      "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/bigcode/starcoder2-3b", },
     {"name": "gpt-2",          "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/openai-community/gpt2", },
@@ -230,7 +230,7 @@ def get_base_tensor_name(lora_tensor_name: str) -> str:
 
 def parse_args() -> argparse.Namespace:
     parser = argparse.ArgumentParser(
-        description="Convert a huggingface PEFT LoRA adapter to a GGML compatible file")
+        description="Convert a Hugging Face PEFT LoRA adapter to a GGUF file")
     parser.add_argument(
         "--outfile", type=Path,
         help="path to write to; default: based on input. {ftype} will be replaced by the outtype.",

@@ -257,11 +257,11 @@ def parse_args() -> argparse.Namespace:
     )
     parser.add_argument(
         "--base", type=Path, required=True,
-        help="directory containing base model file",
+        help="directory containing Hugging Face model config files (config.json, tokenizer.json) for the base model that the adapter is based on - only config is needed, actual model weights are not required",
     )
     parser.add_argument(
         "lora_path", type=Path,
-        help="directory containing LoRA adapter file",
+        help="directory containing Hugging Face PEFT LoRA config (adapter_model.json) and weights (adapter_model.safetensors or adapter_model.bin)",
     )
 
     return parser.parse_args()

@@ -348,6 +348,9 @@ if __name__ == '__main__':
             if ".base_layer.weight" in name:
                 continue
             logger.error(f"Unexpected name '{name}': Not a lora_A or lora_B tensor")
+            if ".embed_tokens.weight" in name or ".lm_head.weight" in name:
+                logger.error("Embeddings is present in the adapter. This can be due to new tokens added during fine tuning")
+                logger.error("Hint: if you are using TRL, make sure not to call setup_chat_format()")
             sys.exit(1)
 
         if base_name in tensor_map:
@@ -74,7 +74,6 @@ int main(int argc, char ** argv) {
                 batch.n_seq_id + i,
                 batch.seq_id   + i,
                 batch.logits   + i,
-                0, 0, 0, // unused
             };
 
             const int ret = llama_decode(ctx, batch_view);

@@ -339,7 +339,7 @@ static bool cb_eval(struct ggml_tensor * t, bool ask, void * user_data) {
 
 static bool get_hidden_layers(llama_context * ctx, std::vector<llama_token> & tokens) {
     llama_kv_cache_clear(ctx);
-    if (llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size(), 0, 0))) {
+    if (llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size()))) {
         fprintf(stderr, "%s : failed to eval\n", __func__);
         return false;
     }

@@ -131,7 +131,7 @@ static bool run(llama_context * ctx, const common_params & params) {
 
     std::vector<llama_token> tokens = common_tokenize(ctx, params.prompt, add_bos);
 
-    if (llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size(), 0, 0))) {
+    if (llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size()))) {
         LOG_ERR("%s : failed to eval\n", __func__);
         return false;
     }

@@ -496,6 +496,8 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params) {
         // clear the KV cache
         llama_kv_cache_clear(ctx);
 
+        llama_batch batch = llama_batch_init(n_batch, 0, 1);
+
         for (int j = 0; j < num_batches; ++j) {
             const int batch_start = start + j * n_batch;
             const int batch_size  = std::min(end - batch_start, n_batch);
@ -508,9 +510,14 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params) {
|
|||||||
tokens[batch_start] = llama_token_bos(llama_get_model(ctx));
|
tokens[batch_start] = llama_token_bos(llama_get_model(ctx));
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: use batch.logits to save computations instead of relying on logits_all == true
|
common_batch_clear(batch);
|
||||||
if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0))) {
|
for (int i = 0; i < batch_size; i++) {
|
||||||
|
common_batch_add(batch, tokens[batch_start + i], j*n_batch + i, {0}, true);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (llama_decode(ctx, batch)) {
|
||||||
LOG_ERR("%s : failed to eval\n", __func__);
|
LOG_ERR("%s : failed to eval\n", __func__);
|
||||||
|
llama_batch_free(batch);
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -523,6 +530,8 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
llama_batch_free(batch);
|
||||||
|
|
||||||
const auto t_end = std::chrono::high_resolution_clock::now();
|
const auto t_end = std::chrono::high_resolution_clock::now();
|
||||||
|
|
||||||
if (i == 0) {
|
if (i == 0) {
|
||||||
|
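Aside (not part of the commit): the rewritten imatrix loop above replaces the implicit `logits_all` behaviour with an explicitly constructed batch that requests logits for every token it adds. A condensed sketch of that pattern, assuming the `common_batch_clear`/`common_batch_add` helpers and the surrounding variables (`ctx`, `tokens`, `n_batch`, `num_batches`, `batch_start`, `batch_size`) exactly as they appear in the hunk:

```cpp
// sketch of the explicit batching pattern used above (error handling trimmed)
llama_batch batch = llama_batch_init(n_batch, /*embd=*/0, /*n_seq_max=*/1);

for (int j = 0; j < num_batches; ++j) {
    common_batch_clear(batch);
    for (int i = 0; i < batch_size; i++) {
        // token id, absolute position, sequence ids, request logits for this token
        common_batch_add(batch, tokens[batch_start + i], j*n_batch + i, {0}, true);
    }
    if (llama_decode(ctx, batch)) {
        break; // eval failed
    }
}

llama_batch_free(batch);
```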
@@ -396,7 +396,7 @@ int main(int argc, char ** argv) {

         LOG_DBG("eval: %s\n", string_from(ctx, embd).c_str());

-        if (llama_decode(ctx, llama_batch_get_one(&embd[i], n_eval, n_past, 0))) {
+        if (llama_decode(ctx, llama_batch_get_one(&embd[i], n_eval))) {
             LOG_ERR("%s : failed to eval\n", __func__);
             return 1;
         }
@@ -540,7 +540,7 @@ class SchemaConverter:
         return self._add_rule(
             name,
             to_rule(transform()) if self._raw_pattern \
-                else "\"\\\"\" " + to_rule(transform()) + " \"\\\"\" space")
+                else "\"\\\"\" (" + to_rule(transform()) + ") \"\\\"\" space")


     def _resolve_ref(self, ref):
@@ -21,12 +21,6 @@
 #include "ggml.h"
 #include "llama.h"
 #include "common.h"
-#include "ggml-cuda.h"
-#include "ggml-sycl.h"
-
-#ifdef GGML_USE_CANN
-#include "ggml-cann.h"
-#endif

 #ifdef _WIN32
 #define WIN32_LEAN_AND_MEAN
@ -82,95 +76,27 @@ static T stdev(const std::vector<T> & v) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
static std::string get_cpu_info() {
|
static std::string get_cpu_info() {
|
||||||
std::string id;
|
std::vector<std::string> cpu_list;
|
||||||
#ifdef __linux__
|
for (size_t i = 0; i < ggml_backend_dev_count(); i++) {
|
||||||
FILE * f = fopen("/proc/cpuinfo", "r");
|
auto * dev = ggml_backend_dev_get(i);
|
||||||
if (f) {
|
auto dev_type = ggml_backend_dev_type(dev);
|
||||||
char buf[1024];
|
if (dev_type == GGML_BACKEND_DEVICE_TYPE_CPU || dev_type == GGML_BACKEND_DEVICE_TYPE_ACCEL) {
|
||||||
while (fgets(buf, sizeof(buf), f)) {
|
cpu_list.push_back(ggml_backend_dev_description(dev));
|
||||||
if (strncmp(buf, "model name", 10) == 0) {
|
|
||||||
char * p = strchr(buf, ':');
|
|
||||||
if (p) {
|
|
||||||
p++;
|
|
||||||
while (std::isspace(*p)) {
|
|
||||||
p++;
|
|
||||||
}
|
|
||||||
while (std::isspace(p[strlen(p) - 1])) {
|
|
||||||
p[strlen(p) - 1] = '\0';
|
|
||||||
}
|
|
||||||
id = p;
|
|
||||||
break;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
return join(cpu_list, ", ");
|
||||||
fclose(f);
|
|
||||||
}
|
|
||||||
#elif defined(_WIN32)
|
|
||||||
HKEY hKey;
|
|
||||||
if (RegOpenKeyEx(HKEY_LOCAL_MACHINE,
|
|
||||||
TEXT("HARDWARE\\DESCRIPTION\\System\\CentralProcessor\\0"),
|
|
||||||
0,
|
|
||||||
KEY_READ,
|
|
||||||
&hKey) != ERROR_SUCCESS) {
|
|
||||||
// fail to open registry key
|
|
||||||
return "";
|
|
||||||
}
|
|
||||||
char cpu_brand[256];
|
|
||||||
DWORD cpu_brand_size = sizeof(cpu_brand);
|
|
||||||
if (RegQueryValueExA(hKey,
|
|
||||||
TEXT("ProcessorNameString"),
|
|
||||||
NULL,
|
|
||||||
NULL,
|
|
||||||
(LPBYTE)cpu_brand,
|
|
||||||
&cpu_brand_size) == ERROR_SUCCESS) {
|
|
||||||
id.assign(cpu_brand, cpu_brand_size);
|
|
||||||
if (id.find('\0') != std::string::npos) {
|
|
||||||
id.resize(id.find('\0'));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
RegCloseKey(hKey);
|
|
||||||
#endif
|
|
||||||
// TODO: other platforms
|
|
||||||
return id;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static std::string get_gpu_info() {
|
static std::string get_gpu_info() {
|
||||||
std::string id;
|
std::vector<std::string> gpu_list;
|
||||||
#ifdef GGML_USE_CUDA
|
for (size_t i = 0; i < ggml_backend_dev_count(); i++) {
|
||||||
int count = ggml_backend_cuda_get_device_count();
|
auto * dev = ggml_backend_dev_get(i);
|
||||||
for (int i = 0; i < count; i++) {
|
auto dev_type = ggml_backend_dev_type(dev);
|
||||||
char buf[128];
|
if (dev_type == GGML_BACKEND_DEVICE_TYPE_GPU) {
|
||||||
ggml_backend_cuda_get_device_description(i, buf, sizeof(buf));
|
gpu_list.push_back(ggml_backend_dev_description(dev));
|
||||||
id += buf;
|
|
||||||
if (i < count - 1) {
|
|
||||||
id += "/";
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
#endif
|
return join(gpu_list, ", ");
|
||||||
#ifdef GGML_USE_SYCL
|
|
||||||
int count = ggml_backend_sycl_get_device_count();
|
|
||||||
for (int i = 0; i < count; i++) {
|
|
||||||
char buf[128];
|
|
||||||
ggml_sycl_get_device_description(i, buf, sizeof(buf));
|
|
||||||
id += buf;
|
|
||||||
if (i < count - 1) {
|
|
||||||
id += "/";
|
|
||||||
}
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
#ifdef GGML_USE_CANN
|
|
||||||
uint32_t count = ggml_backend_cann_get_device_count();
|
|
||||||
for (uint32_t i = 0; i < count; i++) {
|
|
||||||
char buf[128];
|
|
||||||
ggml_backend_cann_get_device_description(i, buf, sizeof(buf));
|
|
||||||
id += buf;
|
|
||||||
if (i < count - 1) {
|
|
||||||
id += "/";
|
|
||||||
}
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
// TODO: other backends
|
|
||||||
return id;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// command line params
|
// command line params
|
||||||
@ -938,29 +864,15 @@ struct test {
|
|||||||
}
|
}
|
||||||
|
|
||||||
static std::string get_backend() {
|
static std::string get_backend() {
|
||||||
if (cuda) {
|
std::vector<std::string> backends;
|
||||||
return GGML_CUDA_NAME;
|
for (size_t i = 0; i < ggml_backend_reg_count(); i++) {
|
||||||
|
auto * reg = ggml_backend_reg_get(i);
|
||||||
|
std::string name = ggml_backend_reg_name(reg);
|
||||||
|
if (name != "CPU") {
|
||||||
|
backends.push_back(ggml_backend_reg_name(reg));
|
||||||
}
|
}
|
||||||
if (vulkan) {
|
|
||||||
return "Vulkan";
|
|
||||||
}
|
}
|
||||||
if (kompute) {
|
return backends.empty() ? "CPU" : join(backends, ",");
|
||||||
return "Kompute";
|
|
||||||
}
|
|
||||||
if (metal) {
|
|
||||||
return "Metal";
|
|
||||||
}
|
|
||||||
if (sycl) {
|
|
||||||
return GGML_SYCL_NAME;
|
|
||||||
}
|
|
||||||
if (gpu_blas) {
|
|
||||||
return "GPU BLAS";
|
|
||||||
}
|
|
||||||
if (blas) {
|
|
||||||
return "BLAS";
|
|
||||||
}
|
|
||||||
|
|
||||||
return "CPU";
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static const std::vector<std::string> & get_fields() {
|
static const std::vector<std::string> & get_fields() {
|
||||||
@ -1428,7 +1340,7 @@ struct sql_printer : public printer {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
static void test_prompt(llama_context * ctx, int n_prompt, int n_past, int n_batch, int n_threads) {
|
static void test_prompt(llama_context * ctx, int n_prompt, int n_batch, int n_threads) {
|
||||||
llama_set_n_threads(ctx, n_threads, n_threads);
|
llama_set_n_threads(ctx, n_threads, n_threads);
|
||||||
|
|
||||||
const llama_model * model = llama_get_model(ctx);
|
const llama_model * model = llama_get_model(ctx);
|
||||||
@ -1444,14 +1356,14 @@ static void test_prompt(llama_context * ctx, int n_prompt, int n_past, int n_bat
|
|||||||
for (int i = 1; i < n_tokens; i++) {
|
for (int i = 1; i < n_tokens; i++) {
|
||||||
tokens[i] = std::rand() % n_vocab;
|
tokens[i] = std::rand() % n_vocab;
|
||||||
}
|
}
|
||||||
llama_decode(ctx, llama_batch_get_one(tokens.data(), n_tokens, n_past + n_processed, 0));
|
llama_decode(ctx, llama_batch_get_one(tokens.data(), n_tokens));
|
||||||
n_processed += n_tokens;
|
n_processed += n_tokens;
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_synchronize(ctx);
|
llama_synchronize(ctx);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void test_gen(llama_context * ctx, int n_gen, int n_past, int n_threads) {
|
static void test_gen(llama_context * ctx, int n_gen, int n_threads) {
|
||||||
llama_set_n_threads(ctx, n_threads, n_threads);
|
llama_set_n_threads(ctx, n_threads, n_threads);
|
||||||
|
|
||||||
const llama_model * model = llama_get_model(ctx);
|
const llama_model * model = llama_get_model(ctx);
|
||||||
@ -1460,7 +1372,7 @@ static void test_gen(llama_context * ctx, int n_gen, int n_past, int n_threads)
|
|||||||
llama_token token = llama_add_bos_token(model) ? llama_token_bos(model) : std::rand() % n_vocab;
|
llama_token token = llama_add_bos_token(model) ? llama_token_bos(model) : std::rand() % n_vocab;
|
||||||
|
|
||||||
for (int i = 0; i < n_gen; i++) {
|
for (int i = 0; i < n_gen; i++) {
|
||||||
llama_decode(ctx, llama_batch_get_one(&token, 1, n_past + i, 0));
|
llama_decode(ctx, llama_batch_get_one(&token, 1));
|
||||||
llama_synchronize(ctx);
|
llama_synchronize(ctx);
|
||||||
token = std::rand() % n_vocab;
|
token = std::rand() % n_vocab;
|
||||||
}
|
}
|
||||||
@ -1596,13 +1508,13 @@ int main(int argc, char ** argv) {
|
|||||||
fprintf(stderr, "llama-bench: benchmark %d/%ld: warmup prompt run\n", params_idx, params_count);
|
fprintf(stderr, "llama-bench: benchmark %d/%ld: warmup prompt run\n", params_idx, params_count);
|
||||||
}
|
}
|
||||||
//test_prompt(ctx, std::min(t.n_batch, std::min(t.n_prompt, 32)), 0, t.n_batch, t.n_threads);
|
//test_prompt(ctx, std::min(t.n_batch, std::min(t.n_prompt, 32)), 0, t.n_batch, t.n_threads);
|
||||||
test_prompt(ctx, t.n_prompt, 0, t.n_batch, t.n_threads);
|
test_prompt(ctx, t.n_prompt, t.n_batch, t.n_threads);
|
||||||
}
|
}
|
||||||
if (t.n_gen > 0) {
|
if (t.n_gen > 0) {
|
||||||
if (params.progress) {
|
if (params.progress) {
|
||||||
fprintf(stderr, "llama-bench: benchmark %d/%ld: warmup generation run\n", params_idx, params_count);
|
fprintf(stderr, "llama-bench: benchmark %d/%ld: warmup generation run\n", params_idx, params_count);
|
||||||
}
|
}
|
||||||
test_gen(ctx, 1, 0, t.n_threads);
|
test_gen(ctx, 1, t.n_threads);
|
||||||
}
|
}
|
||||||
|
|
||||||
for (int i = 0; i < params.reps; i++) {
|
for (int i = 0; i < params.reps; i++) {
|
||||||
@ -1614,13 +1526,13 @@ int main(int argc, char ** argv) {
|
|||||||
if (params.progress) {
|
if (params.progress) {
|
||||||
fprintf(stderr, "llama-bench: benchmark %d/%ld: prompt run %d/%d\n", params_idx, params_count, i + 1, params.reps);
|
fprintf(stderr, "llama-bench: benchmark %d/%ld: prompt run %d/%d\n", params_idx, params_count, i + 1, params.reps);
|
||||||
}
|
}
|
||||||
test_prompt(ctx, t.n_prompt, 0, t.n_batch, t.n_threads);
|
test_prompt(ctx, t.n_prompt, t.n_batch, t.n_threads);
|
||||||
}
|
}
|
||||||
if (t.n_gen > 0) {
|
if (t.n_gen > 0) {
|
||||||
if (params.progress) {
|
if (params.progress) {
|
||||||
fprintf(stderr, "llama-bench: benchmark %d/%ld: generation run %d/%d\n", params_idx, params_count, i + 1, params.reps);
|
fprintf(stderr, "llama-bench: benchmark %d/%ld: generation run %d/%d\n", params_idx, params_count, i + 1, params.reps);
|
||||||
}
|
}
|
||||||
test_gen(ctx, t.n_gen, t.n_prompt, t.n_threads);
|
test_gen(ctx, t.n_gen, t.n_threads);
|
||||||
}
|
}
|
||||||
|
|
||||||
uint64_t t_ns = get_time_ns() - t_start;
|
uint64_t t_ns = get_time_ns() - t_start;
|
||||||
|
@@ -283,9 +283,6 @@ Java_android_llama_cpp_LLamaAndroid_new_1batch(JNIEnv *, jobject, jint n_tokens,
         nullptr,
         nullptr,
         nullptr,
-        0,
-        0,
-        0,
     };

     if (embd) {
@@ -46,7 +46,6 @@ actor LlamaContext {
        let sparams = llama_sampler_chain_default_params()
        self.sampling = llama_sampler_chain_init(sparams)
        llama_sampler_chain_add(self.sampling, llama_sampler_init_temp(0.4))
-       llama_sampler_chain_add(self.sampling, llama_sampler_init_softmax())
        llama_sampler_chain_add(self.sampling, llama_sampler_init_dist(1234))
    }

|
examples/llama.vim (new file, 783 lines)
@ -0,0 +1,783 @@
|
|||||||
|
" LLM-based text completion using llama.cpp
|
||||||
|
"
|
||||||
|
" requires:
|
||||||
|
"
|
||||||
|
" - neovim or vim
|
||||||
|
" - curl
|
||||||
|
" - llama.cpp server instance
|
||||||
|
" - FIM-compatible model
|
||||||
|
"
|
||||||
|
" sample config:
|
||||||
|
"
|
||||||
|
" - Tab - accept the current suggestion
|
||||||
|
" - Shift+Tab - accept just the first line of the suggestion
|
||||||
|
" - Ctrl+F - toggle FIM completion manually
|
||||||
|
"
|
||||||
|
" make symlink or copy this file to ~/.config/nvim/autoload/llama.vim
|
||||||
|
"
|
||||||
|
" start the llama.cpp server with a FIM-compatible model. for example:
|
||||||
|
"
|
||||||
|
" $ llama-server -m {model.gguf} --port 8012 -ngl 99 -fa -dt 0.1 --ubatch-size 512 --batch-size 1024 --cache-reuse 256
|
||||||
|
"
|
||||||
|
" --batch-size [512, model max context]
|
||||||
|
"
|
||||||
|
" adjust the batch size to control how much of the provided local context will be used during the inference
|
||||||
|
" lower values will use smaller part of the context around the cursor, which will result in faster processing
|
||||||
|
"
|
||||||
|
" --ubatch-size [64, 2048]
|
||||||
|
"
|
||||||
|
" chunks the batch into smaller chunks for faster processing
|
||||||
|
" depends on the specific hardware. use llama-bench to profile and determine the best size
|
||||||
|
"
|
||||||
|
" --cache-reuse (ge:llama_config.n_predict, 1024]
|
||||||
|
"
|
||||||
|
" this should be either 0 (disabled) or strictly larger than g:llama_config.n_predict
|
||||||
|
" using non-zero value enables context reuse on the server side which dramatically improves the performance at
|
||||||
|
" large contexts. a value of 256 should be good for all cases
|
||||||
|
"
|
||||||
|
" run this once to initialise llama.vim:
|
||||||
|
"
|
||||||
|
" :call llama#init()
|
||||||
|
"
|
||||||
|
" more info: https://github.com/ggerganov/llama.cpp/pull/9787
|
||||||
|
"
|
||||||
|
|
||||||
|
" colors (adjust to your liking)
|
||||||
|
highlight llama_hl_hint guifg=#ff772f ctermfg=202
|
||||||
|
highlight llama_hl_info guifg=#77ff2f ctermfg=119
|
||||||
|
|
||||||
|
" general parameters:
|
||||||
|
"
|
||||||
|
" endpoint: llama.cpp server endpoint
|
||||||
|
" n_prefix: number of lines before the cursor location to include in the local prefix
|
||||||
|
" n_suffix: number of lines after the cursor location to include in the local suffix
|
||||||
|
" n_predict: max number of tokens to predict
|
||||||
|
" t_max_prompt_ms: max alloted time for the prompt processing (TODO: not yet supported)
|
||||||
|
" t_max_predict_ms: max alloted time for the prediction
|
||||||
|
" show_info: show extra info about the inference (0 - disabled, 1 - statusline, 2 - inline)
|
||||||
|
" auto_fim: trigger FIM completion automatically on cursor movement
|
||||||
|
" max_line_suffix: do not auto-trigger FIM completion if there are more than this number of characters to the right of the cursor
|
||||||
|
"
|
||||||
|
" ring buffer of chunks, accumulated with time upon:
|
||||||
|
"
|
||||||
|
" - completion request
|
||||||
|
" - yank
|
||||||
|
" - entering a buffer
|
||||||
|
" - leaving a buffer
|
||||||
|
" - writing a file
|
||||||
|
"
|
||||||
|
" parameters for the ring-buffer with extra context:
|
||||||
|
"
|
||||||
|
" ring_n_chunks: max number of chunks to pass as extra context to the server (0 to disable)
|
||||||
|
" ring_chunk_size: max size of the chunks (in number of lines)
|
||||||
|
" note: adjust these numbers so that you don't overrun your context
|
||||||
|
" at ring_n_chunks = 64 and ring_chunk_size = 64 you need ~32k context
|
||||||
|
" ring_scope: the range around the cursor position (in number of lines) for gathering chunks after FIM
|
||||||
|
" ring_update_ms: how often to process queued chunks in normal mode
|
||||||
|
"
|
||||||
|
let s:default_config = {
|
||||||
|
\ 'endpoint': 'http://127.0.0.1:8012/infill',
|
||||||
|
\ 'n_prefix': 256,
|
||||||
|
\ 'n_suffix': 64,
|
||||||
|
\ 'n_predict': 128,
|
||||||
|
\ 't_max_prompt_ms': 500,
|
||||||
|
\ 't_max_predict_ms': 3000,
|
||||||
|
\ 'show_info': 2,
|
||||||
|
\ 'auto_fim': v:true,
|
||||||
|
\ 'max_line_suffix': 8,
|
||||||
|
\ 'ring_n_chunks': 64,
|
||||||
|
\ 'ring_chunk_size': 64,
|
||||||
|
\ 'ring_scope': 1024,
|
||||||
|
\ 'ring_update_ms': 1000,
|
||||||
|
\ }
|
||||||
|
|
||||||
|
let g:llama_config = get(g:, 'llama_config', s:default_config)
|
||||||
|
|
||||||
|
function! s:get_indent(str)
|
||||||
|
let l:count = 0
|
||||||
|
for i in range(len(a:str))
|
||||||
|
if a:str[i] == "\t"
|
||||||
|
let l:count += &tabstop - 1
|
||||||
|
else
|
||||||
|
break
|
||||||
|
endif
|
||||||
|
endfor
|
||||||
|
return l:count
|
||||||
|
endfunction
|
||||||
|
|
||||||
|
function! s:rand(i0, i1) abort
|
||||||
|
return a:i0 + rand() % (a:i1 - a:i0 + 1)
|
||||||
|
endfunction
|
||||||
|
|
||||||
|
function! llama#init()
|
||||||
|
if !executable('curl')
|
||||||
|
echohl WarningMsg
|
||||||
|
echo 'llama.vim requires the "curl" command to be available'
|
||||||
|
echohl None
|
||||||
|
return
|
||||||
|
endif
|
||||||
|
|
||||||
|
let s:pos_x = 0 " cursor position upon start of completion
|
||||||
|
let s:pos_y = 0
|
||||||
|
|
||||||
|
let s:line_cur = ''
|
||||||
|
|
||||||
|
let s:line_cur_prefix = ''
|
||||||
|
let s:line_cur_suffix = ''
|
||||||
|
|
||||||
|
let s:ring_chunks = [] " current set of chunks used as extra context
|
||||||
|
let s:ring_queued = [] " chunks that are queued to be sent for processing
|
||||||
|
let s:ring_n_evict = 0
|
||||||
|
|
||||||
|
let s:hint_shown = v:false
|
||||||
|
let s:pos_y_pick = -9999 " last y where we picked a chunk
|
||||||
|
let s:pos_dx = 0
|
||||||
|
let s:content = []
|
||||||
|
let s:can_accept = v:false
|
||||||
|
|
||||||
|
let s:timer_fim = -1
|
||||||
|
let s:t_fim_start = reltime() " used to measure total FIM time
|
||||||
|
let s:t_last_move = reltime() " last time the cursor moved
|
||||||
|
|
||||||
|
let s:current_job = v:null
|
||||||
|
|
||||||
|
let s:ghost_text_nvim = exists('*nvim_buf_get_mark')
|
||||||
|
let s:ghost_text_vim = has('textprop')
|
||||||
|
|
||||||
|
if s:ghost_text_vim
|
||||||
|
let s:hlgroup_hint = 'llama_hl_hint'
|
||||||
|
let s:hlgroup_info = 'llama_hl_info'
|
||||||
|
|
||||||
|
if empty(prop_type_get(s:hlgroup_hint))
|
||||||
|
call prop_type_add(s:hlgroup_hint, {'highlight': s:hlgroup_hint})
|
||||||
|
endif
|
||||||
|
if empty(prop_type_get(s:hlgroup_info))
|
||||||
|
call prop_type_add(s:hlgroup_info, {'highlight': s:hlgroup_info})
|
||||||
|
endif
|
||||||
|
endif
|
||||||
|
|
||||||
|
augroup llama
|
||||||
|
autocmd!
|
||||||
|
autocmd InsertEnter * inoremap <expr> <silent> <C-F> llama#fim_inline(v:false)
|
||||||
|
autocmd InsertLeavePre * call llama#fim_cancel()
|
||||||
|
|
||||||
|
autocmd CursorMoved * call s:on_move()
|
||||||
|
autocmd CursorMovedI * call s:on_move()
|
||||||
|
autocmd CompleteChanged * call llama#fim_cancel()
|
||||||
|
|
||||||
|
if g:llama_config.auto_fim
|
||||||
|
autocmd CursorMovedI * call llama#fim(v:true)
|
||||||
|
endif
|
||||||
|
|
||||||
|
" gather chunks upon yanking
|
||||||
|
autocmd TextYankPost * if v:event.operator ==# 'y' | call s:pick_chunk(v:event.regcontents, v:false, v:true) | endif
|
||||||
|
|
||||||
|
" gather chunks upon entering/leaving a buffer
|
||||||
|
autocmd BufEnter * call timer_start(100, {-> s:pick_chunk(getline(max([1, line('.') - g:llama_config.ring_chunk_size/2]), min([line('.') + g:llama_config.ring_chunk_size/2, line('$')])), v:true, v:true)})
|
||||||
|
autocmd BufLeave * call s:pick_chunk(getline(max([1, line('.') - g:llama_config.ring_chunk_size/2]), min([line('.') + g:llama_config.ring_chunk_size/2, line('$')])), v:true, v:true)
|
||||||
|
|
||||||
|
" gather chunk upon saving the file
|
||||||
|
autocmd BufWritePost * call s:pick_chunk(getline(max([1, line('.') - g:llama_config.ring_chunk_size/2]), min([line('.') + g:llama_config.ring_chunk_size/2, line('$')])), v:true, v:true)
|
||||||
|
augroup END
|
||||||
|
|
||||||
|
silent! call llama#fim_cancel()
|
||||||
|
|
||||||
|
" init background update of the ring buffer
|
||||||
|
if g:llama_config.ring_n_chunks > 0
|
||||||
|
call s:ring_update()
|
||||||
|
endif
|
||||||
|
endfunction
|
||||||
|
|
||||||
|
" compute how similar two chunks of text are
|
||||||
|
" 0 - no similarity, 1 - high similarity
|
||||||
|
" TODO: figure out something better
|
||||||
|
function! s:chunk_sim(c0, c1)
|
||||||
|
let l:lines0 = len(a:c0)
|
||||||
|
let l:lines1 = len(a:c1)
|
||||||
|
|
||||||
|
let l:common = 0
|
||||||
|
|
||||||
|
for l:line0 in a:c0
|
||||||
|
for l:line1 in a:c1
|
||||||
|
if l:line0 == l:line1
|
||||||
|
let l:common += 1
|
||||||
|
break
|
||||||
|
endif
|
||||||
|
endfor
|
||||||
|
endfor
|
||||||
|
|
||||||
|
return 2.0 * l:common / (l:lines0 + l:lines1)
|
||||||
|
endfunction
|
||||||
|
|
||||||
|
" pick a random chunk of size g:llama_config.ring_chunk_size from the provided text and queue it for processing
|
||||||
|
"
|
||||||
|
" no_mod - do not pick chunks from buffers with pending changes
|
||||||
|
" do_evict - evict chunks that are very similar to the new one
|
||||||
|
"
|
||||||
|
function! s:pick_chunk(text, no_mod, do_evict)
|
||||||
|
" do not pick chunks from buffers with pending changes or buffers that are not files
|
||||||
|
if a:no_mod && (getbufvar(bufnr('%'), '&modified') || !buflisted(bufnr('%')) || !filereadable(expand('%')))
|
||||||
|
return
|
||||||
|
endif
|
||||||
|
|
||||||
|
" if the extra context option is disabled - do nothing
|
||||||
|
if g:llama_config.ring_n_chunks <= 0
|
||||||
|
return
|
||||||
|
endif
|
||||||
|
|
||||||
|
" don't pick very small chunks
|
||||||
|
if len(a:text) < 3
|
||||||
|
return
|
||||||
|
endif
|
||||||
|
|
||||||
|
if len(a:text) + 1 < g:llama_config.ring_chunk_size
|
||||||
|
let l:chunk = a:text
|
||||||
|
else
|
||||||
|
let l:l0 = s:rand(0, max([0, len(a:text) - g:llama_config.ring_chunk_size/2]))
|
||||||
|
let l:l1 = min([l:l0 + g:llama_config.ring_chunk_size/2, len(a:text)])
|
||||||
|
|
||||||
|
let l:chunk = a:text[l:l0:l:l1]
|
||||||
|
endif
|
||||||
|
|
||||||
|
let l:chunk_str = join(l:chunk, "\n") . "\n"
|
||||||
|
|
||||||
|
" check if this chunk is already added
|
||||||
|
let l:exist = v:false
|
||||||
|
|
||||||
|
for i in range(len(s:ring_chunks))
|
||||||
|
if s:ring_chunks[i].data == l:chunk
|
||||||
|
let l:exist = v:true
|
||||||
|
break
|
||||||
|
endif
|
||||||
|
endfor
|
||||||
|
|
||||||
|
for i in range(len(s:ring_queued))
|
||||||
|
if s:ring_queued[i].data == l:chunk
|
||||||
|
let l:exist = v:true
|
||||||
|
break
|
||||||
|
endif
|
||||||
|
endfor
|
||||||
|
|
||||||
|
if l:exist
|
||||||
|
return
|
||||||
|
endif
|
||||||
|
|
||||||
|
" evict queued chunks that are very similar to the new one
|
||||||
|
for i in range(len(s:ring_queued) - 1, 0, -1)
|
||||||
|
if s:chunk_sim(s:ring_queued[i].data, l:chunk) > 0.9
|
||||||
|
if a:do_evict
|
||||||
|
call remove(s:ring_queued, i)
|
||||||
|
let s:ring_n_evict += 1
|
||||||
|
else
|
||||||
|
return
|
||||||
|
endif
|
||||||
|
endif
|
||||||
|
endfor
|
||||||
|
|
||||||
|
" also from s:ring_chunks
|
||||||
|
for i in range(len(s:ring_chunks) - 1, 0, -1)
|
||||||
|
if s:chunk_sim(s:ring_chunks[i].data, l:chunk) > 0.9
|
||||||
|
if a:do_evict
|
||||||
|
call remove(s:ring_chunks, i)
|
||||||
|
let s:ring_n_evict += 1
|
||||||
|
else
|
||||||
|
return
|
||||||
|
endif
|
||||||
|
endif
|
||||||
|
endfor
|
||||||
|
|
||||||
|
" TODO: become parameter ?
|
||||||
|
if len(s:ring_queued) == 16
|
||||||
|
call remove(s:ring_queued, 0)
|
||||||
|
endif
|
||||||
|
|
||||||
|
call add(s:ring_queued, {'data': l:chunk, 'str': l:chunk_str, 'time': reltime(), 'filename': expand('%')})
|
||||||
|
|
||||||
|
"let &statusline = 'extra context: ' . len(s:ring_chunks) . ' / ' . len(s:ring_queued)
|
||||||
|
endfunction
|
||||||
|
|
||||||
|
" picks a queued chunk, sends it for processing and adds it to s:ring_chunks
|
||||||
|
" called every g:llama_config.ring_update_ms
|
||||||
|
function! s:ring_update()
|
||||||
|
call timer_start(g:llama_config.ring_update_ms, {-> s:ring_update()})
|
||||||
|
|
||||||
|
" update only if in normal mode or if the cursor hasn't moved for a while
|
||||||
|
if mode() !=# 'n' && reltimefloat(reltime(s:t_last_move)) < 3.0
|
||||||
|
return
|
||||||
|
endif
|
||||||
|
|
||||||
|
if len(s:ring_queued) == 0
|
||||||
|
return
|
||||||
|
endif
|
||||||
|
|
||||||
|
" move the first queued chunk to the ring buffer
|
||||||
|
if len(s:ring_chunks) == g:llama_config.ring_n_chunks
|
||||||
|
call remove(s:ring_chunks, 0)
|
||||||
|
endif
|
||||||
|
|
||||||
|
call add(s:ring_chunks, remove(s:ring_queued, 0))
|
||||||
|
|
||||||
|
"let &statusline = 'updated context: ' . len(s:ring_chunks) . ' / ' . len(s:ring_queued)
|
||||||
|
|
||||||
|
" send asynchronous job with the new extra context so that it is ready for the next FIM
|
||||||
|
let l:extra_context = []
|
||||||
|
for l:chunk in s:ring_chunks
|
||||||
|
call add(l:extra_context, {
|
||||||
|
\ 'text': l:chunk.str,
|
||||||
|
\ 'time': l:chunk.time,
|
||||||
|
\ 'filename': l:chunk.filename
|
||||||
|
\ })
|
||||||
|
endfor
|
||||||
|
|
||||||
|
" no samplers needed here
|
||||||
|
let l:request = json_encode({
|
||||||
|
\ 'input_prefix': "",
|
||||||
|
\ 'input_suffix': "",
|
||||||
|
\ 'input_extra': l:extra_context,
|
||||||
|
\ 'prompt': "",
|
||||||
|
\ 'n_predict': 1,
|
||||||
|
\ 'temperature': 0.0,
|
||||||
|
\ 'stream': v:false,
|
||||||
|
\ 'samplers': ["temperature"],
|
||||||
|
\ 'cache_prompt': v:true,
|
||||||
|
\ 't_max_prompt_ms': 1,
|
||||||
|
\ 't_max_predict_ms': 1
|
||||||
|
\ })
|
||||||
|
|
||||||
|
let l:curl_command = [
|
||||||
|
\ "curl",
|
||||||
|
\ "--silent",
|
||||||
|
\ "--no-buffer",
|
||||||
|
\ "--request", "POST",
|
||||||
|
\ "--url", g:llama_config.endpoint,
|
||||||
|
\ "--header", "Content-Type: application/json",
|
||||||
|
\ "--data", l:request
|
||||||
|
\ ]
|
||||||
|
|
||||||
|
" no callbacks because we don't need to process the response
|
||||||
|
if s:ghost_text_nvim
|
||||||
|
call jobstart(l:curl_command, {})
|
||||||
|
elseif s:ghost_text_vim
|
||||||
|
call job_start(l:curl_command, {})
|
||||||
|
endif
|
||||||
|
endfunction
|
||||||
|
|
||||||
|
" necessary for 'inoremap <expr>'
|
||||||
|
function! llama#fim_inline(is_auto) abort
|
||||||
|
call llama#fim(a:is_auto)
|
||||||
|
return ''
|
||||||
|
endfunction
|
||||||
|
|
||||||
|
" the main FIM call
|
||||||
|
" takes local context around the cursor and sends it together with the extra context to the server for completion
|
||||||
|
function! llama#fim(is_auto) abort
|
||||||
|
" we already have a suggestion for the current cursor position
|
||||||
|
if s:hint_shown && !a:is_auto
|
||||||
|
call llama#fim_cancel()
|
||||||
|
return
|
||||||
|
endif
|
||||||
|
|
||||||
|
call llama#fim_cancel()
|
||||||
|
|
||||||
|
" avoid sending repeated requests too fast
|
||||||
|
if reltimefloat(reltime(s:t_fim_start)) < 0.6
|
||||||
|
if s:timer_fim != -1
|
||||||
|
call timer_stop(s:timer_fim)
|
||||||
|
let s:timer_fim = -1
|
||||||
|
endif
|
||||||
|
|
||||||
|
let s:t_fim_start = reltime()
|
||||||
|
let s:timer_fim = timer_start(600, {-> llama#fim(v:true)})
|
||||||
|
return
|
||||||
|
endif
|
||||||
|
|
||||||
|
let s:t_fim_start = reltime()
|
||||||
|
|
||||||
|
let s:content = []
|
||||||
|
let s:can_accept = v:false
|
||||||
|
|
||||||
|
let s:pos_x = col('.') - 1
|
||||||
|
let s:pos_y = line('.')
|
||||||
|
let l:max_y = line('$')
|
||||||
|
|
||||||
|
let l:lines_prefix = getline(max([1, s:pos_y - g:llama_config.n_prefix]), s:pos_y - 1)
|
||||||
|
let l:lines_suffix = getline(s:pos_y + 1, min([l:max_y, s:pos_y + g:llama_config.n_suffix]))
|
||||||
|
|
||||||
|
let s:line_cur = getline('.')
|
||||||
|
|
||||||
|
let s:line_cur_prefix = strpart(s:line_cur, 0, s:pos_x)
|
||||||
|
let s:line_cur_suffix = strpart(s:line_cur, s:pos_x)
|
||||||
|
|
||||||
|
if a:is_auto && len(s:line_cur_suffix) > g:llama_config.max_line_suffix
|
||||||
|
return
|
||||||
|
endif
|
||||||
|
|
||||||
|
let l:prefix = ""
|
||||||
|
\ . join(l:lines_prefix, "\n")
|
||||||
|
\ . "\n"
|
||||||
|
|
||||||
|
let l:prompt = ""
|
||||||
|
\ . s:line_cur_prefix
|
||||||
|
|
||||||
|
let l:suffix = ""
|
||||||
|
\ . s:line_cur_suffix
|
||||||
|
\ . "\n"
|
||||||
|
\ . join(l:lines_suffix, "\n")
|
||||||
|
\ . "\n"
|
||||||
|
|
||||||
|
" prepare the extra context data
|
||||||
|
let l:extra_context = []
|
||||||
|
for l:chunk in s:ring_chunks
|
||||||
|
call add(l:extra_context, {
|
||||||
|
\ 'text': l:chunk.str,
|
||||||
|
\ 'time': l:chunk.time,
|
||||||
|
\ 'filename': l:chunk.filename
|
||||||
|
\ })
|
||||||
|
endfor
|
||||||
|
|
||||||
|
" the indentation of the current line
|
||||||
|
let l:indent = strlen(matchstr(s:line_cur_prefix, '^\s*'))
|
||||||
|
|
||||||
|
let l:request = json_encode({
|
||||||
|
\ 'input_prefix': l:prefix,
|
||||||
|
\ 'input_suffix': l:suffix,
|
||||||
|
\ 'input_extra': l:extra_context,
|
||||||
|
\ 'prompt': l:prompt,
|
||||||
|
\ 'n_predict': g:llama_config.n_predict,
|
||||||
|
\ 'n_indent': l:indent,
|
||||||
|
\ 'top_k': 40,
|
||||||
|
\ 'top_p': 0.99,
|
||||||
|
\ 'stream': v:false,
|
||||||
|
\ 'samplers': ["top_k", "top_p", "infill"],
|
||||||
|
\ 'cache_prompt': v:true,
|
||||||
|
\ 't_max_prompt_ms': g:llama_config.t_max_prompt_ms,
|
||||||
|
\ 't_max_predict_ms': g:llama_config.t_max_predict_ms
|
||||||
|
\ })
|
||||||
|
|
||||||
|
let l:curl_command = [
|
||||||
|
\ "curl",
|
||||||
|
\ "--silent",
|
||||||
|
\ "--no-buffer",
|
||||||
|
\ "--request", "POST",
|
||||||
|
\ "--url", g:llama_config.endpoint,
|
||||||
|
\ "--header", "Content-Type: application/json",
|
||||||
|
\ "--data", l:request
|
||||||
|
\ ]
|
||||||
|
|
||||||
|
if s:current_job != v:null
|
||||||
|
if s:ghost_text_nvim
|
||||||
|
call jobstop(s:current_job)
|
||||||
|
elseif s:ghost_text_vim
|
||||||
|
call job_stop(s:current_job)
|
||||||
|
endif
|
||||||
|
endif
|
||||||
|
|
||||||
|
" send the request asynchronously
|
||||||
|
if s:ghost_text_nvim
|
||||||
|
let s:current_job = jobstart(l:curl_command, {
|
||||||
|
\ 'on_stdout': function('s:fim_on_stdout', [s:pos_x, s:pos_y, a:is_auto]),
|
||||||
|
\ 'on_exit': function('s:fim_on_exit'),
|
||||||
|
\ 'stdout_buffered': v:true
|
||||||
|
\ })
|
||||||
|
elseif s:ghost_text_vim
|
||||||
|
let s:current_job = job_start(l:curl_command, {
|
||||||
|
\ 'out_cb': function('s:fim_on_stdout', [s:pos_x, s:pos_y, a:is_auto]),
|
||||||
|
\ 'exit_cb': function('s:fim_on_exit')
|
||||||
|
\ })
|
||||||
|
endif
|
||||||
|
|
||||||
|
" TODO: per-file location
|
||||||
|
let l:delta_y = abs(s:pos_y - s:pos_y_pick)
|
||||||
|
|
||||||
|
" gather some extra context nearby and process it in the background
|
||||||
|
" only gather chunks if the cursor has moved a lot
|
||||||
|
" TODO: something more clever? reranking?
|
||||||
|
if a:is_auto && l:delta_y > 32
|
||||||
|
" expand the prefix even further
|
||||||
|
call s:pick_chunk(getline(max([1, s:pos_y - g:llama_config.ring_scope]), max([1, s:pos_y - g:llama_config.n_prefix])), v:false, v:false)
|
||||||
|
|
||||||
|
" pick a suffix chunk
|
||||||
|
call s:pick_chunk(getline(min([l:max_y, s:pos_y + g:llama_config.n_suffix]), min([l:max_y, s:pos_y + g:llama_config.n_suffix + g:llama_config.ring_chunk_size])), v:false, v:false)
|
||||||
|
|
||||||
|
let s:pos_y_pick = s:pos_y
|
||||||
|
endif
|
||||||
|
endfunction
|
||||||
|
|
||||||
|
" if first_line == v:true accept only the first line of the response
|
||||||
|
function! llama#fim_accept(first_line)
|
||||||
|
" insert the suggestion at the cursor location
|
||||||
|
if s:can_accept && len(s:content) > 0
|
||||||
|
call setline(s:pos_y, s:line_cur[:(s:pos_x - 1)] . s:content[0])
|
||||||
|
if len(s:content) > 1
|
||||||
|
if !a:first_line
|
||||||
|
call append(s:pos_y, s:content[1:-1])
|
||||||
|
endif
|
||||||
|
endif
|
||||||
|
|
||||||
|
" move the cursor to the end of the accepted text
|
||||||
|
if !a:first_line && len(s:content) > 1
|
||||||
|
call cursor(s:pos_y + len(s:content) - 1, s:pos_x + s:pos_dx + 1)
|
||||||
|
else
|
||||||
|
call cursor(s:pos_y, s:pos_x + len(s:content[0]))
|
||||||
|
endif
|
||||||
|
endif
|
||||||
|
|
||||||
|
call llama#fim_cancel()
|
||||||
|
endfunction
|
||||||
|
|
||||||
|
function! llama#fim_cancel()
|
||||||
|
let s:hint_shown = v:false
|
||||||
|
|
||||||
|
" clear the virtual text
|
||||||
|
let l:bufnr = bufnr('%')
|
||||||
|
|
||||||
|
if s:ghost_text_nvim
|
||||||
|
let l:id_vt_fim = nvim_create_namespace('vt_fim')
|
||||||
|
call nvim_buf_clear_namespace(l:bufnr, l:id_vt_fim, 0, -1)
|
||||||
|
elseif s:ghost_text_vim
|
||||||
|
call prop_remove({'type': s:hlgroup_hint, 'all': v:true})
|
||||||
|
call prop_remove({'type': s:hlgroup_info, 'all': v:true})
|
||||||
|
endif
|
||||||
|
|
||||||
|
" remove the mappings
|
||||||
|
silent! iunmap <buffer> <Tab>
|
||||||
|
silent! iunmap <buffer> <S-Tab>
|
||||||
|
silent! iunmap <buffer> <Esc>
|
||||||
|
endfunction
|
||||||
|
|
||||||
|
function! s:on_move()
|
||||||
|
let s:t_last_move = reltime()
|
||||||
|
|
||||||
|
call llama#fim_cancel()
|
||||||
|
endfunction
|
||||||
|
|
||||||
|
" callback that processes the FIM result from the server and displays the suggestion
|
||||||
|
function! s:fim_on_stdout(pos_x, pos_y, is_auto, job_id, data, event = v:null)
|
||||||
|
if s:ghost_text_nvim
|
||||||
|
let l:raw = join(a:data, "\n")
|
||||||
|
elseif s:ghost_text_vim
|
||||||
|
let l:raw = a:data
|
||||||
|
endif
|
||||||
|
|
||||||
|
if len(l:raw) == 0
|
||||||
|
return
|
||||||
|
endif
|
||||||
|
|
||||||
|
if a:pos_x != col('.') - 1 || a:pos_y != line('.')
|
||||||
|
return
|
||||||
|
endif
|
||||||
|
|
||||||
|
" show the suggestion only in insert mode
|
||||||
|
if mode() !=# 'i'
|
||||||
|
return
|
||||||
|
endif
|
||||||
|
|
||||||
|
let s:pos_x = a:pos_x
|
||||||
|
let s:pos_y = a:pos_y
|
||||||
|
|
||||||
|
let s:can_accept = v:true
|
||||||
|
let l:has_info = v:false
|
||||||
|
|
||||||
|
if s:can_accept && v:shell_error
|
||||||
|
if !a:is_auto
|
||||||
|
call add(s:content, "<| curl error: is the server on? |>")
|
||||||
|
endif
|
||||||
|
let s:can_accept = v:false
|
||||||
|
endif
|
||||||
|
|
||||||
|
let l:n_prompt = 0
|
||||||
|
let l:t_prompt_ms = 1.0
|
||||||
|
let l:s_prompt = 0
|
||||||
|
|
||||||
|
let l:n_predict = 0
|
||||||
|
let l:t_predict_ms = 1.0
|
||||||
|
let l:s_predict = 0
|
||||||
|
|
||||||
|
" get the generated suggestion
|
||||||
|
if s:can_accept
|
||||||
|
let l:response = json_decode(l:raw)
|
||||||
|
|
||||||
|
for l:part in split(get(l:response, 'content', ''), "\n", 1)
|
||||||
|
call add(s:content, l:part)
|
||||||
|
endfor
|
||||||
|
|
||||||
|
" remove trailing new lines
|
||||||
|
while len(s:content) > 0 && s:content[-1] == ""
|
||||||
|
call remove(s:content, -1)
|
||||||
|
endwhile
|
||||||
|
|
||||||
|
let l:generation_settings = get(l:response, 'generation_settings', {})
|
||||||
|
let l:n_ctx = get(l:generation_settings, 'n_ctx', 0)
|
||||||
|
|
||||||
|
let l:n_cached = get(l:response, 'tokens_cached', 0)
|
||||||
|
let l:truncated = get(l:response, 'truncated', v:false)
|
||||||
|
|
||||||
|
" if response.timings is available
|
||||||
|
if len(get(l:response, 'timings', {})) > 0
|
||||||
|
let l:has_info = v:true
|
||||||
|
let l:timings = get(l:response, 'timings', {})
|
||||||
|
|
||||||
|
let l:n_prompt = get(l:timings, 'prompt_n', 0)
|
||||||
|
let l:t_prompt_ms = get(l:timings, 'prompt_ms', 1)
|
||||||
|
let l:s_prompt = get(l:timings, 'prompt_per_second', 0)
|
||||||
|
|
||||||
|
let l:n_predict = get(l:timings, 'predicted_n', 0)
|
||||||
|
let l:t_predict_ms = get(l:timings, 'predicted_ms', 1)
|
||||||
|
let l:s_predict = get(l:timings, 'predicted_per_second', 0)
|
||||||
|
endif
|
||||||
|
endif
|
||||||
|
|
||||||
|
if len(s:content) == 0
|
||||||
|
call add(s:content, "")
|
||||||
|
let s:can_accept = v:false
|
||||||
|
endif
|
||||||
|
|
||||||
|
if len(s:content) == 0
|
||||||
|
return
|
||||||
|
endif
|
||||||
|
|
||||||
|
" NOTE: the following is logic for discarding predictions that repeat existing text
|
||||||
|
" the code is quite ugly and there is very likely a simpler and more canonical way to implement this
|
||||||
|
"
|
||||||
|
" still, I wonder if there is some better way that avoids having to do these special hacks?
|
||||||
|
" on one hand, the LLM 'sees' the contents of the file before we start editing, so it is normal that it would
|
||||||
|
" start generating whatever we have given it via the extra context. but on the other hand, it's not very
|
||||||
|
" helpful to re-generate the same code that is already there
|
||||||
|
|
||||||
|
" truncate the suggestion if the first line is empty
|
||||||
|
if len(s:content) == 1 && s:content[0] == ""
|
||||||
|
let s:content = [""]
|
||||||
|
endif
|
||||||
|
|
||||||
|
" ... and the next lines are repeated
|
||||||
|
if len(s:content) > 1 && s:content[0] == "" && s:content[1:] == getline(s:pos_y + 1, s:pos_y + len(s:content) - 1)
|
||||||
|
let s:content = [""]
|
||||||
|
endif
|
||||||
|
|
||||||
|
" truncate the suggestion if it repeats the suffix
|
||||||
|
if len(s:content) == 1 && s:content[0] == s:line_cur_suffix
|
||||||
|
let s:content = [""]
|
||||||
|
endif
|
||||||
|
|
||||||
|
" find the first non-empty line (strip whitespace)
|
||||||
|
let l:cmp_y = s:pos_y + 1
|
||||||
|
while l:cmp_y < line('$') && getline(l:cmp_y) =~? '^\s*$'
|
||||||
|
let l:cmp_y += 1
|
||||||
|
endwhile
|
||||||
|
|
||||||
|
if (s:line_cur_prefix . s:content[0]) == getline(l:cmp_y)
|
||||||
|
" truncate the suggestion if it repeats the next line
|
||||||
|
if len(s:content) == 1
|
||||||
|
let s:content = [""]
|
||||||
|
endif
|
||||||
|
|
||||||
|
" ... or if the second line of the suggestion is the prefix of line l:cmp_y + 1
|
||||||
|
if len(s:content) == 2 && s:content[-1] == getline(l:cmp_y + 1)[:len(s:content[-1]) - 1]
|
||||||
|
let s:content = [""]
|
||||||
|
endif
|
||||||
|
|
||||||
|
" ... or if the middle chunk of lines of the suggestion is the same as [l:cmp_y + 1, l:cmp_y + len(s:content) - 1)
|
||||||
|
if len(s:content) > 2 && join(s:content[1:-1], "\n") == join(getline(l:cmp_y + 1, l:cmp_y + len(s:content) - 1), "\n")
|
||||||
|
let s:content = [""]
|
||||||
|
endif
|
||||||
|
endif
|
||||||
|
|
||||||
|
" keep only lines that have the same or larger whitespace prefix as s:line_cur_prefix
|
||||||
|
"let l:indent = strlen(matchstr(s:line_cur_prefix, '^\s*'))
|
||||||
|
"for i in range(1, len(s:content) - 1)
|
||||||
|
" if strlen(matchstr(s:content[i], '^\s*')) < l:indent
|
||||||
|
" let s:content = s:content[:i - 1]
|
||||||
|
" break
|
||||||
|
" endif
|
||||||
|
"endfor
|
||||||
|
|
||||||
|
let s:pos_dx = len(s:content[-1])
|
||||||
|
|
||||||
|
let s:content[-1] .= s:line_cur_suffix
|
||||||
|
|
||||||
|
call llama#fim_cancel()
|
||||||
|
|
||||||
|
" display virtual text with the suggestion
|
||||||
|
let l:bufnr = bufnr('%')
|
||||||
|
|
||||||
|
if s:ghost_text_nvim
|
||||||
|
let l:id_vt_fim = nvim_create_namespace('vt_fim')
|
||||||
|
endif
|
||||||
|
|
||||||
|
" construct the info message
|
||||||
|
if g:llama_config.show_info > 0 && l:has_info
|
||||||
|
let l:prefix = ' '
|
||||||
|
|
||||||
|
if l:truncated
|
||||||
|
let l:info = printf("%s | WARNING: the context is full: %d / %d, increase the server context size or reduce g:llama_config.ring_n_chunks",
|
||||||
|
\ g:llama_config.show_info == 2 ? l:prefix : 'llama.vim',
|
||||||
|
\ l:n_cached, l:n_ctx
|
||||||
|
\ )
|
||||||
|
else
|
||||||
|
let l:info = printf("%s | c: %d / %d, r: %d / %d, e: %d, q: %d / 16 | p: %d (%.2f ms, %.2f t/s) | g: %d (%.2f ms, %.2f t/s) | t: %.2f ms",
|
||||||
|
\ g:llama_config.show_info == 2 ? l:prefix : 'llama.vim',
|
||||||
|
\ l:n_cached, l:n_ctx, len(s:ring_chunks), g:llama_config.ring_n_chunks, s:ring_n_evict, len(s:ring_queued),
|
||||||
|
\ l:n_prompt, l:t_prompt_ms, l:s_prompt,
|
||||||
|
\ l:n_predict, l:t_predict_ms, l:s_predict,
|
||||||
|
\ 1000.0 * reltimefloat(reltime(s:t_fim_start))
|
||||||
|
\ )
|
||||||
|
endif
|
||||||
|
|
||||||
|
if g:llama_config.show_info == 1
|
||||||
|
" display the info in the statusline
|
||||||
|
let &statusline = l:info
|
||||||
|
let l:info = ''
|
||||||
|
endif
|
||||||
|
endif
|
||||||
|
|
||||||
|
" display the suggestion and append the info to the end of the first line
|
||||||
|
if s:ghost_text_nvim
|
||||||
|
call nvim_buf_set_extmark(l:bufnr, l:id_vt_fim, s:pos_y - 1, s:pos_x - 1, {
|
||||||
|
\ 'virt_text': [[s:content[0], 'llama_hl_hint'], [l:info, 'llama_hl_info']],
|
||||||
|
\ 'virt_text_win_col': virtcol('.') - 1
|
||||||
|
\ })
|
||||||
|
|
||||||
|
call nvim_buf_set_extmark(l:bufnr, l:id_vt_fim, s:pos_y - 1, 0, {
|
||||||
|
\ 'virt_lines': map(s:content[1:], {idx, val -> [[val, 'llama_hl_hint']]}),
|
||||||
|
\ 'virt_text_win_col': virtcol('.')
|
||||||
|
\ })
|
||||||
|
elseif s:ghost_text_vim
|
||||||
|
let l:new_suffix = s:content[0]
|
||||||
|
if !empty(l:new_suffix)
|
||||||
|
call prop_add(s:pos_y, s:pos_x + 1, {
|
||||||
|
\ 'type': s:hlgroup_hint,
|
||||||
|
\ 'text': l:new_suffix
|
||||||
|
\ })
|
||||||
|
endif
|
||||||
|
for line in s:content[1:]
|
||||||
|
call prop_add(s:pos_y, 0, {
|
||||||
|
\ 'type': s:hlgroup_hint,
|
||||||
|
\ 'text': line,
|
||||||
|
\ 'text_padding_left': s:get_indent(line),
|
||||||
|
\ 'text_align': 'below'
|
||||||
|
\ })
|
||||||
|
endfor
|
||||||
|
if !empty(l:info)
|
||||||
|
call prop_add(s:pos_y, 0, {
|
||||||
|
\ 'type': s:hlgroup_info,
|
||||||
|
\ 'text': l:info,
|
||||||
|
\ 'text_padding_left': col('$'),
|
||||||
|
\ 'text_wrap': 'truncate'
|
||||||
|
\ })
|
||||||
|
endif
|
||||||
|
endif
|
||||||
|
|
||||||
|
" setup accept shortcuts
|
||||||
|
inoremap <buffer> <Tab> <C-O>:call llama#fim_accept(v:false)<CR>
|
||||||
|
inoremap <buffer> <S-Tab> <C-O>:call llama#fim_accept(v:true)<CR>
|
||||||
|
|
||||||
|
let s:hint_shown = v:true
|
||||||
|
endfunction
|
||||||
|
|
||||||
|
function! s:fim_on_exit(job_id, exit_code, event = v:null)
|
||||||
|
if a:exit_code != 0
|
||||||
|
echom "Job failed with exit code: " . a:exit_code
|
||||||
|
endif
|
||||||
|
|
||||||
|
let s:current_job = v:null
|
||||||
|
endfunction
|
@@ -20,7 +20,7 @@ static bool eval_tokens(struct llama_context * ctx_llama, std::vector<llama_toke
         if (n_eval > n_batch) {
             n_eval = n_batch;
         }
-        if (llama_decode(ctx_llama, llama_batch_get_one(&tokens[i], n_eval, *n_past, 0))) {
+        if (llama_decode(ctx_llama, llama_batch_get_one(&tokens[i], n_eval))) {
             LOG_ERR("%s : failed to eval. token %d/%d (batch size %d, n_past %d)\n", __func__, i, N, n_batch, *n_past);
             return false;
         }
@@ -401,6 +401,39 @@ bool llava_image_embed_make_with_clip_img(clip_ctx * ctx_clip, int n_threads, co
     return true;
 }

+struct llava_embd_batch {
+    std::vector<llama_pos>      pos;
+    std::vector<int32_t>        n_seq_id;
+    std::vector<llama_seq_id>   seq_id_0;
+    std::vector<llama_seq_id *> seq_ids;
+    std::vector<int8_t>         logits;
+    llama_batch batch;
+    llava_embd_batch(float * embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id) {
+        pos     .resize(n_tokens);
+        n_seq_id.resize(n_tokens);
+        seq_ids .resize(n_tokens + 1);
+        logits  .resize(n_tokens);
+        seq_id_0.resize(1);
+        seq_id_0[0] = seq_id;
+        seq_ids [n_tokens] = nullptr;
+        batch = {
+            /*n_tokens =*/ n_tokens,
+            /*tokens   =*/ nullptr,
+            /*embd     =*/ embd,
+            /*pos      =*/ pos.data(),
+            /*n_seq_id =*/ n_seq_id.data(),
+            /*seq_id   =*/ seq_ids.data(),
+            /*logits   =*/ logits.data(),
+        };
+        for (int i = 0; i < n_tokens; i++) {
+            batch.pos     [i] = pos_0 + i;
+            batch.n_seq_id[i] = 1;
+            batch.seq_id  [i] = seq_id_0.data();
+            batch.logits  [i] = false;
+        }
+    }
+};
+
 bool llava_eval_image_embed(llama_context * ctx_llama, const struct llava_image_embed * image_embed, int n_batch, int * n_past) {
     int n_embd = llama_n_embd(llama_get_model(ctx_llama));

@@ -409,8 +442,9 @@ bool llava_eval_image_embed(llama_context * ctx_llama, const struct llava_image_
         if (n_eval > n_batch) {
             n_eval = n_batch;
         }
-        llama_batch batch = {int32_t(n_eval), nullptr, (image_embed->embed+i*n_embd), nullptr, nullptr, nullptr, nullptr, *n_past, 1, 0, };
-        if (llama_decode(ctx_llama, batch)) {
+        float * embd = image_embed->embed+i*n_embd;
+        llava_embd_batch llava_batch = llava_embd_batch(embd, n_eval, *n_past, 0);
+        if (llama_decode(ctx_llama, llava_batch.batch)) {
             LOG_ERR("%s : failed to eval\n", __func__);
             return false;
         }
@@ -432,7 +466,7 @@ struct llava_image_embed * llava_image_embed_make_with_bytes(struct clip_ctx * c
     bool image_embed_result = llava_image_embed_make_with_clip_img(ctx_clip, n_threads, img, &image_embed, &n_image_pos);
     if (!image_embed_result) {
         clip_image_u8_free(img);
-        LOG_ERR("%s: coulnd't embed the image\n", __func__);
+        LOG_ERR("%s: couldn't embed the image\n", __func__);
         return NULL;
     }

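Aside (not part of the commit): `llava_embd_batch` above exists because the trimmed-down `llama_batch` no longer carries the trailing position/sequence shortcut fields the old aggregate initializer relied on, so an embedding-only batch has to supply per-token positions, sequence ids and logits flags itself. A minimal sketch of the call-site change, using only names visible in the hunk above:

```cpp
// sketch: decode one chunk of image embeddings via the wrapper defined above
float * embd = image_embed->embed + i * n_embd;

// before: aggregate-initialize llama_batch, relying on the old shortcut fields
//   llama_batch batch = {int32_t(n_eval), nullptr, embd, nullptr, nullptr, nullptr, nullptr, *n_past, 1, 0, };

// after: the wrapper owns the pos/n_seq_id/seq_id/logits arrays
llava_embd_batch llava_batch = llava_embd_batch(embd, n_eval, *n_past, /*seq_id=*/0);
if (llama_decode(ctx_llama, llava_batch.batch)) {
    LOG_ERR("%s : failed to eval\n", __func__);
    return false;
}
```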
@@ -97,7 +97,7 @@ static bool eval_tokens(struct llama_context * ctx_llama, std::vector<llama_toke
         if (n_eval > n_batch) {
             n_eval = n_batch;
         }
-        if (llama_decode(ctx_llama, llama_batch_get_one(&tokens[i], n_eval, *n_past, 0))) {
+        if (llama_decode(ctx_llama, llama_batch_get_one(&tokens[i], n_eval))) {
             LOG_ERR("%s : failed to eval. token %d/%d (batch size %d, n_past %d)\n", __func__, i, N, n_batch, *n_past);
             return false;
         }
@@ -89,8 +89,8 @@ int main(int argc, char ** argv) {
     const auto t_enc_start = ggml_time_us();

     // eval the prompt
-    llama_decode(ctx, llama_batch_get_one( inp.data(), n_input - 1, 0, 0));
-    llama_decode(ctx, llama_batch_get_one(&inp.back(), 1, n_input - 1, 0));
+    llama_decode(ctx, llama_batch_get_one( inp.data(), n_input - 1));
+    llama_decode(ctx, llama_batch_get_one(&inp.back(), 1));

     for (int s = 1; s < W + G + 1; ++s) {
         llama_kv_cache_seq_cp(ctx, 0, s, -1, -1);
@@ -89,8 +89,8 @@ int main(int argc, char ** argv){

     const auto t_enc_start = ggml_time_us();

-    llama_decode(ctx, llama_batch_get_one( inp.data(), n_input - 1, 0, 0));
-    llama_decode(ctx, llama_batch_get_one(&inp.back(), 1, n_input - 1, 0));
+    llama_decode(ctx, llama_batch_get_one( inp.data(), n_input - 1));
+    llama_decode(ctx, llama_batch_get_one(&inp.back(), 1));

     const auto t_enc_end = ggml_time_us();

@@ -187,6 +187,30 @@ Use the `--no-penalize-nl` option to disable newline penalization when applying

 Example usage: `--repeat-penalty 1.15 --repeat-last-n 128 --no-penalize-nl`

+### DRY Repetition Penalty
+
+DRY (Don't Repeat Yourself) sampling is an effective technique for reducing repetition in generated text even across long contexts by penalizing tokens based on their recent usage patterns (original [PR link](https://github.com/oobabooga/text-generation-webui/pull/5677)).
+
+- `--dry-multiplier N`: Set the DRY sampling multiplier (default: 0.0, 0.0 = disabled).
+- `--dry-base N`: Set the DRY sampling base value (default: 1.75).
+- `--dry-allowed-length N`: Set the allowed length for DRY sampling (default: 2).
+- `--dry-penalty-last-n N`: Set DRY penalty for the last n tokens (default: -1, 0 = disable, -1 = context size).
+- `--dry-sequence-breaker STRING`: Add a sequence breaker for DRY sampling. Can be used more than once to add multiple sequence breakers. Using this clears out the default breakers, which consist of: `['\n', ':', '"', '*']`. If the string `"none"` is supplied, no sequence breakers are used.
+
+The `dry-multiplier` option controls the strength of the DRY sampling effect. A value of 0.0 disables DRY sampling, while higher values increase its influence. A typical recommended value is 0.8.
+
+The `dry-base` option sets the base value for the exponential penalty calculation in DRY sampling. Higher values lead to more aggressive penalization of repetitions.
+
+The `dry-allowed-length` option sets the maximum length of repeated sequences that will not be penalized. Repetitions shorter than or equal to this length are not penalized, allowing for natural repetitions of short phrases or common words.
+
+The `dry-penalty-last-n` option controls how many recent tokens to consider when applying the DRY penalty. A value of -1 considers the entire context. Use a positive value to limit the consideration to a specific number of recent tokens.
+
+The `dry-sequence-breaker` option adds a single sequence breaker and can be used more than once to specify multiple sequence breakers. Sequence breakers interrupt sequence matching and break the input into parts where matching can be applied.
+
+DRY sampling provides more nuanced control over text generation, particularly for reducing long-range repetitions and maintaining global coherence.
+
+Example usage: `--dry-multiplier 0.8 --dry-base 1.75 --dry-allowed-length 2 --dry-penalty-last-n -1 --dry-sequence-breaker "—" --dry-sequence-breaker "##"`
+
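A rough worked example, not part of the commit and based on the original PR linked above: a token that would extend an already-repeated sequence of length `n` is penalized by roughly `multiplier * base^(n - allowed_length)` subtracted from its logit. With the settings from the example (`--dry-multiplier 0.8 --dry-base 1.75 --dry-allowed-length 2`), extending a 5-token repetition costs about `0.8 * 1.75^3 ≈ 4.3` logits, while repetitions of length 2 or shorter are not penalized at all.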
 ### Top-K Sampling

 - `--top-k N`: Limit the next token selection to the K most probable tokens (default: 40).
@@ -211,14 +235,6 @@ The Min-P sampling method was designed as an alternative to Top-P, and aims to e

 Example usage: `--min-p 0.05`

-### Tail-Free Sampling (TFS)
-
-- `--tfs N`: Enable tail free sampling with parameter z (default: 1.0, 1.0 = disabled).
-
-Tail-free sampling (TFS) is a text generation technique that aims to reduce the impact of less likely tokens, which may be less relevant, less coherent, or nonsensical, on the output. Similar to Top-P it tries to determine the bulk of the most likely tokens dynamically. But TFS filters out logits based on the second derivative of their probabilities. Adding tokens is stopped after the sum of the second derivatives reaches the parameter z. In short: TFS looks at how quickly the probabilities of the tokens decrease and cuts off the tail of unlikely tokens using the parameter z. Typical values for z are in the range of 0.9 to 0.95. A value of 1.0 would include all tokens and thus disables the effect of TFS.
-
-Example usage: `--tfs 0.95`
-
 ### Locally Typical Sampling

 - `--typical N`: Enable locally typical sampling with parameter p (default: 1.0, 1.0 = disabled).
@@ -241,6 +257,19 @@ The `--mirostat-ent` option sets the Mirostat target entropy (tau), which repres

 Example usage: `--mirostat 2 --mirostat-lr 0.05 --mirostat-ent 3.0`

+### XTC Sampling
+
+- `--xtc-probability N`: Sets the chance for token removal (checked once on sampler start) (default: 0.0).
+- `--xtc-threshold N`: Sets a minimum probability threshold for tokens to be removed (default: 0.1).
+
+Exclude Top Choices (XTC) is a unique sampler that is designed to remove top tokens from consideration and avoid more obvious and repetitive outputs. With a chance of `xtc-probability` it searches for tokens with probabilities of `xtc-threshold` and above, then removes all such tokens except the least probable one.
+
+By removing top tokens XTC can improve the variety of answers, break writing clichés and inhibit repetition, since clichés and repeated phrases are usually more likely to appear. By keeping the last token above the threshold, XTC ensures that the answer is still coherent. XTC is meant to be used for creative tasks, but feel free to experiment with different settings for different models.
+
+Being experimental and unique, XTC is disabled by default. The recommended combination of samplers is Min-P followed by XTC on its default settings: `--sampling-seq mx --min-p 0.02 --xtc-probability 0.5`.
+
+Example usage: `--xtc-probability 0.5 --xtc-threshold 0.1`
+
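A rough worked example, not part of the commit and derived from the description above: with `--xtc-probability 0.5 --xtc-threshold 0.1`, XTC triggers on roughly half of the sampling steps. On a step where it triggers and the candidate probabilities are 0.45, 0.25, 0.12, 0.08, ..., the 0.45 and 0.25 tokens are removed and the 0.12 token is kept, because it is the least probable candidate still at or above the 0.1 threshold; candidates below the threshold are never touched.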
 ### Logit Bias

 - `-l TOKEN_ID(+/-)BIAS, --logit-bias TOKEN_ID(+/-)BIAS`: Modify the likelihood of a token appearing in the generated text completion.
@@ -284,10 +313,6 @@ These options help improve the performance and memory usage of the LLaMA models.

 These flags attempt optimizations that help on some systems with non-uniform memory access. This currently consists of one of the above strategies, and disabling prefetch and readahead for mmap. The latter causes mapped pages to be faulted in on first access instead of all at once, and in combination with pinning threads to NUMA nodes, more of the pages end up on the NUMA node where they are used. Note that if the model is already in the system page cache, for example because of a previous run without this option, this will have little effect unless you drop the page cache first. This can be done by rebooting the system or on Linux by writing '3' to '/proc/sys/vm/drop_caches' as root.

-### Memory Float 32
-
-- `--memory-f32`: Use 32-bit floats instead of 16-bit floats for memory key+value. This doubles the context memory requirement and cached prompt file size but does not appear to increase generation quality in a measurable way. Not recommended.
-
 ### Batch Size

 - `-b N, --batch-size N`: Set the batch size for prompt processing (default: `2048`). This large batch size benefits users who have BLAS installed and enabled it during the build. If you don't have BLAS enabled ("BLAS=0"), you can use a smaller number, such as 8, to see the prompt progress as it's evaluated in some situations.
@ -308,6 +333,15 @@ These options help improve the performance and memory usage of the LLaMA models.
|
|||||||
|
|
||||||
For information about 4-bit quantization, which can significantly improve performance and reduce memory usage, please refer to llama.cpp's primary [README](../../README.md#prepare-and-quantize).
|
For information about 4-bit quantization, which can significantly improve performance and reduce memory usage, please refer to llama.cpp's primary [README](../../README.md#prepare-and-quantize).
|
||||||
|
|
||||||
|
## LoRA (Low-Rank Adaptation) adapters
|
||||||
|
|
||||||
|
- `--lora FNAME`: Optional path to a LoRA adapter to use with scaling of 1.0. Can be mixed with `--lora-scaled` and can be repeated to use multiple adapters.
|
||||||
|
- `--lora-scaled FNAME`: Optional path to a LoRA adapter with user-defined scaling. Can be mixed with `--lora` and can be repeated to use multiple adapters.
|
||||||
|
|
||||||
|
You can add LoRA adapters using `--lora` or `--lora-scaled`. For example: `--lora my_adapter_1.gguf --lora my_adapter_2.gguf ...` or `--lora-scaled lora_task_A.gguf 0.5 --lora-scaled lora_task_B.gguf 0.5`.
|
||||||
|
|
||||||
|
LoRA adapters should be in GGUF format. To convert from Hugging Face format use the `convert-lora-to-gguf.py` script. LoRA adapters are loaded separately and applied during inference - they are not merged with the main model. This means that mmap model loading is fully supported when using LoRA adapters. The old `--lora-base` flag has been removed now that merging is no longer performed.
|
||||||
|
|
||||||
## Additional Options
|
## Additional Options
|
||||||
|
|
||||||
These options provide extra functionality and customization when running the LLaMA models:
|
These options provide extra functionality and customization when running the LLaMA models:
|
||||||
@ -316,6 +350,4 @@ These options provide extra functionality and customization when running the LLa
|
|||||||
- `--verbose-prompt`: Print the prompt before generating text.
|
- `--verbose-prompt`: Print the prompt before generating text.
|
||||||
- `-mg i, --main-gpu i`: When using multiple GPUs this option controls which GPU is used for small tensors for which the overhead of splitting the computation across all GPUs is not worthwhile. The GPU in question will use slightly more VRAM to store a scratch buffer for temporary results. By default GPU 0 is used.
|
- `-mg i, --main-gpu i`: When using multiple GPUs this option controls which GPU is used for small tensors for which the overhead of splitting the computation across all GPUs is not worthwhile. The GPU in question will use slightly more VRAM to store a scratch buffer for temporary results. By default GPU 0 is used.
|
||||||
- `-ts SPLIT, --tensor-split SPLIT`: When using multiple GPUs this option controls how large tensors should be split across all GPUs. `SPLIT` is a comma-separated list of non-negative values that assigns the proportion of data that each GPU should get in order. For example, "3,2" will assign 60% of the data to GPU 0 and 40% to GPU 1. By default the data is split in proportion to VRAM but this may not be optimal for performance.
|
- `-ts SPLIT, --tensor-split SPLIT`: When using multiple GPUs this option controls how large tensors should be split across all GPUs. `SPLIT` is a comma-separated list of non-negative values that assigns the proportion of data that each GPU should get in order. For example, "3,2" will assign 60% of the data to GPU 0 and 40% to GPU 1. By default the data is split in proportion to VRAM but this may not be optimal for performance.
|
||||||
- `--lora FNAME`: Apply a LoRA (Low-Rank Adaptation) adapter to the model (implies --no-mmap). This allows you to adapt the pretrained model to specific tasks or domains.
|
|
||||||
- `--lora-base FNAME`: Optional model to use as a base for the layers modified by the LoRA adapter. This flag is used in conjunction with the `--lora` flag, and specifies the base model for the adaptation.
|
|
||||||
- `-hfr URL --hf-repo URL`: The url to the Hugging Face model repository. Used in conjunction with `--hf-file` or `-hff`. The model is downloaded and stored in the file provided by `-m` or `--model`. If `-m` is not provided, the model is auto-stored in the path specified by the `LLAMA_CACHE` environment variable or in an OS-specific local cache.
|
- `-hfr URL --hf-repo URL`: The url to the Hugging Face model repository. Used in conjunction with `--hf-file` or `-hff`. The model is downloaded and stored in the file provided by `-m` or `--model`. If `-m` is not provided, the model is auto-stored in the path specified by the `LLAMA_CACHE` environment variable or in an OS-specific local cache.
|
||||||
|
@ -528,7 +528,7 @@ int main(int argc, char ** argv) {
|
|||||||
int enc_input_size = embd_inp.size();
|
int enc_input_size = embd_inp.size();
|
||||||
llama_token * enc_input_buf = embd_inp.data();
|
llama_token * enc_input_buf = embd_inp.data();
|
||||||
|
|
||||||
if (llama_encode(ctx, llama_batch_get_one(enc_input_buf, enc_input_size, 0, 0))) {
|
if (llama_encode(ctx, llama_batch_get_one(enc_input_buf, enc_input_size))) {
|
||||||
LOG_ERR("%s : failed to eval\n", __func__);
|
LOG_ERR("%s : failed to eval\n", __func__);
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
@ -569,7 +569,8 @@ int main(int argc, char ** argv) {
|
|||||||
if (!params.ctx_shift){
|
if (!params.ctx_shift){
|
||||||
LOG_DBG("\n\n%s: context full and context shift is disabled => stopping\n", __func__);
|
LOG_DBG("\n\n%s: context full and context shift is disabled => stopping\n", __func__);
|
||||||
break;
|
break;
|
||||||
} else {
|
}
|
||||||
|
|
||||||
if (params.n_predict == -2) {
|
if (params.n_predict == -2) {
|
||||||
LOG_DBG("\n\n%s: context full and n_predict == -%d => stopping\n", __func__, params.n_predict);
|
LOG_DBG("\n\n%s: context full and n_predict == -%d => stopping\n", __func__, params.n_predict);
|
||||||
break;
|
break;
|
||||||
@ -593,7 +594,6 @@ int main(int argc, char ** argv) {
|
|||||||
LOG_DBG("clear session path\n");
|
LOG_DBG("clear session path\n");
|
||||||
path_session.clear();
|
path_session.clear();
|
||||||
}
|
}
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
// context extension via Self-Extend
|
// context extension via Self-Extend
|
||||||
while (n_past >= ga_i + ga_w) {
|
while (n_past >= ga_i + ga_w) {
|
||||||
@ -648,7 +648,7 @@ int main(int argc, char ** argv) {
|
|||||||
|
|
||||||
LOG_DBG("eval: %s\n", string_from(ctx, embd).c_str());
|
LOG_DBG("eval: %s\n", string_from(ctx, embd).c_str());
|
||||||
|
|
||||||
if (llama_decode(ctx, llama_batch_get_one(&embd[i], n_eval, n_past, 0))) {
|
if (llama_decode(ctx, llama_batch_get_one(&embd[i], n_eval))) {
|
||||||
LOG_ERR("%s : failed to eval\n", __func__);
|
LOG_ERR("%s : failed to eval\n", __func__);
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
@ -308,7 +308,6 @@ int main(int argc, char ** argv) {
|
|||||||
batch.n_seq_id + i,
|
batch.n_seq_id + i,
|
||||||
batch.seq_id + i,
|
batch.seq_id + i,
|
||||||
batch.logits + i,
|
batch.logits + i,
|
||||||
0, 0, 0, // unused
|
|
||||||
};
|
};
|
||||||
|
|
||||||
const int ret = llama_decode(ctx, batch_view);
|
const int ret = llama_decode(ctx, batch_view);
|
||||||
|
@ -408,14 +408,21 @@ static results_perplexity perplexity_v2(llama_context * ctx, const common_params
|
|||||||
// clear the KV cache
|
// clear the KV cache
|
||||||
llama_kv_cache_clear(ctx);
|
llama_kv_cache_clear(ctx);
|
||||||
|
|
||||||
|
llama_batch batch = llama_batch_init(n_batch, 0, 1);
|
||||||
|
|
||||||
for (int j = 0; j < num_batches; ++j) {
|
for (int j = 0; j < num_batches; ++j) {
|
||||||
const int batch_start = start + j * n_batch;
|
const int batch_start = start + j * n_batch;
|
||||||
const int batch_size = std::min(end - batch_start, n_batch);
|
const int batch_size = std::min(end - batch_start, n_batch);
|
||||||
|
|
||||||
|
common_batch_clear(batch);
|
||||||
|
for (int i = 0; i < batch_size; i++) {
|
||||||
|
common_batch_add(batch, tokens[batch_start + i], j*n_batch + i, {0}, true);
|
||||||
|
}
|
||||||
|
|
||||||
//LOG_DBG(" Batch %d: starts at %d, size is %d, n_past is %d\n",j,batch_start,batch_size,j * n_batch);
|
//LOG_DBG(" Batch %d: starts at %d, size is %d, n_past is %d\n",j,batch_start,batch_size,j * n_batch);
|
||||||
// TODO: use llama_batch.logits instead of relying on logits_all == true
|
if (llama_decode(ctx, batch)) {
|
||||||
if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0))) {
|
|
||||||
//LOG_ERR("%s : failed to eval\n", __func__);
|
//LOG_ERR("%s : failed to eval\n", __func__);
|
||||||
|
llama_batch_free(batch);
|
||||||
return {tokens, -1, logit_history, prob_history};
|
return {tokens, -1, logit_history, prob_history};
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -435,6 +442,8 @@ static results_perplexity perplexity_v2(llama_context * ctx, const common_params
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
llama_batch_free(batch);
|
||||||
|
|
||||||
const auto t_end = std::chrono::high_resolution_clock::now();
|
const auto t_end = std::chrono::high_resolution_clock::now();
|
||||||
|
|
||||||
if (i == 0) {
|
if (i == 0) {
|
||||||
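The repeated change in these hunks is a migration away from `llama_batch_get_one(tokens, n_tokens, pos, seq_id)` (which loses its position and sequence arguments) toward explicitly built batches. A minimal, hedged sketch of the new pattern, assuming `ctx` is a valid `llama_context *` and `tokens` is a `std::vector<llama_token>`:

```cpp
// Sketch of the batch construction pattern adopted by the callers in this diff.
// `ctx` and `tokens` are assumed to exist; error handling is reduced to a comment.
llama_batch batch = llama_batch_init(tokens.size(), 0, 1);

common_batch_clear(batch);
for (size_t i = 0; i < tokens.size(); i++) {
    // token value, explicit position, sequence id list, request logits for this token?
    common_batch_add(batch, tokens[i], i, {0}, false);
}
batch.logits[batch.n_tokens - 1] = true; // only the last token needs logits

if (llama_decode(ctx, batch)) {
    // decoding failed; free the batch and bail out
}

llama_batch_free(batch);
```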
@ -704,7 +713,6 @@ static bool decode_helper(llama_context * ctx, llama_batch & batch, std::vector<
|
|||||||
batch.n_seq_id + i,
|
batch.n_seq_id + i,
|
||||||
batch.seq_id + i,
|
batch.seq_id + i,
|
||||||
batch.logits + i,
|
batch.logits + i,
|
||||||
0, 0, 0, // unused
|
|
||||||
};
|
};
|
||||||
|
|
||||||
const int ret = llama_decode(ctx, batch_view);
|
const int ret = llama_decode(ctx, batch_view);
|
||||||
@ -1791,6 +1799,8 @@ static void kl_divergence(llama_context * ctx, const common_params & params) {
|
|||||||
// clear the KV cache
|
// clear the KV cache
|
||||||
llama_kv_cache_clear(ctx);
|
llama_kv_cache_clear(ctx);
|
||||||
|
|
||||||
|
llama_batch batch = llama_batch_init(n_batch, 0, 1);
|
||||||
|
|
||||||
for (int j = 0; j < num_batches; ++j) {
|
for (int j = 0; j < num_batches; ++j) {
|
||||||
const int batch_start = start + j * n_batch;
|
const int batch_start = start + j * n_batch;
|
||||||
const int batch_size = std::min(end - batch_start, n_batch);
|
const int batch_size = std::min(end - batch_start, n_batch);
|
||||||
@ -1803,9 +1813,14 @@ static void kl_divergence(llama_context * ctx, const common_params & params) {
|
|||||||
tokens[batch_start] = llama_token_bos(llama_get_model(ctx));
|
tokens[batch_start] = llama_token_bos(llama_get_model(ctx));
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: use llama_batch.logits instead of relying on logits_all == true
|
common_batch_clear(batch);
|
||||||
if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0))) {
|
for (int i = 0; i < batch_size; i++) {
|
||||||
|
common_batch_add(batch, tokens[batch_start + i], j*n_batch + i, {0}, true);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (llama_decode(ctx, batch)) {
|
||||||
LOG_ERR("%s : failed to eval\n", __func__);
|
LOG_ERR("%s : failed to eval\n", __func__);
|
||||||
|
llama_batch_free(batch);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1818,6 +1833,8 @@ static void kl_divergence(llama_context * ctx, const common_params & params) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
llama_batch_free(batch);
|
||||||
|
|
||||||
const auto t_end = std::chrono::high_resolution_clock::now();
|
const auto t_end = std::chrono::high_resolution_clock::now();
|
||||||
|
|
||||||
if (i == 0) {
|
if (i == 0) {
|
||||||
|
@ -42,15 +42,21 @@ int main(int argc, char ** argv) {
|
|||||||
|
|
||||||
llama_sampler * smpl = llama_sampler_chain_init(sparams);
|
llama_sampler * smpl = llama_sampler_chain_init(sparams);
|
||||||
|
|
||||||
llama_sampler_chain_add(smpl, llama_sampler_init_softmax());
|
|
||||||
llama_sampler_chain_add(smpl, llama_sampler_init_dist(params.sparams.seed));
|
llama_sampler_chain_add(smpl, llama_sampler_init_dist(params.sparams.seed));
|
||||||
|
|
||||||
// tokenize prompt
|
// tokenize prompt
|
||||||
auto tokens = common_tokenize(ctx, params.prompt, true);
|
auto tokens = common_tokenize(ctx, params.prompt, true);
|
||||||
|
|
||||||
|
// prepare the batch
|
||||||
|
llama_batch batch = llama_batch_init(tokens.size(), 0, 1);
|
||||||
|
for (size_t i = 0; i < tokens.size(); i++) {
|
||||||
|
common_batch_add(batch, tokens[i], i, {0}, false);
|
||||||
|
}
|
||||||
|
batch.logits[batch.n_tokens - 1] = true; // generate next token
|
||||||
|
|
||||||
// evaluate prompt
|
// evaluate prompt
|
||||||
llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size(), n_past, 0));
|
llama_decode(ctx, batch);
|
||||||
n_past += tokens.size();
|
n_past += batch.n_tokens;
|
||||||
|
|
||||||
// save state (rng, logits, embedding and kv_cache) to file
|
// save state (rng, logits, embedding and kv_cache) to file
|
||||||
{
|
{
|
||||||
@ -77,8 +83,12 @@ int main(int argc, char ** argv) {
|
|||||||
printf("%s", next_token_str.c_str());
|
printf("%s", next_token_str.c_str());
|
||||||
result0 += next_token_str;
|
result0 += next_token_str;
|
||||||
|
|
||||||
if (llama_decode(ctx, llama_batch_get_one(&next_token, 1, n_past, 0))) {
|
common_batch_clear(batch);
|
||||||
|
common_batch_add(batch, next_token, n_past, {0}, true);
|
||||||
|
|
||||||
|
if (llama_decode(ctx, batch)) {
|
||||||
fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
|
fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
|
||||||
|
llama_batch_free(batch);
|
||||||
llama_free(ctx);
|
llama_free(ctx);
|
||||||
llama_free_model(model);
|
llama_free_model(model);
|
||||||
return 1;
|
return 1;
|
||||||
@ -96,7 +106,6 @@ int main(int argc, char ** argv) {
|
|||||||
|
|
||||||
llama_sampler * smpl2 = llama_sampler_chain_init(sparams);
|
llama_sampler * smpl2 = llama_sampler_chain_init(sparams);
|
||||||
|
|
||||||
llama_sampler_chain_add(smpl2, llama_sampler_init_softmax());
|
|
||||||
llama_sampler_chain_add(smpl2, llama_sampler_init_dist(params.sparams.seed));
|
llama_sampler_chain_add(smpl2, llama_sampler_init_dist(params.sparams.seed));
|
||||||
|
|
||||||
printf("\nsecond run: %s", params.prompt.c_str());
|
printf("\nsecond run: %s", params.prompt.c_str());
|
||||||
@ -133,8 +142,12 @@ int main(int argc, char ** argv) {
|
|||||||
printf("%s", next_token_str.c_str());
|
printf("%s", next_token_str.c_str());
|
||||||
result1 += next_token_str;
|
result1 += next_token_str;
|
||||||
|
|
||||||
if (llama_decode(ctx2, llama_batch_get_one(&next_token, 1, n_past, 0))) {
|
common_batch_clear(batch);
|
||||||
|
common_batch_add(batch, next_token, n_past, {0}, true);
|
||||||
|
|
||||||
|
if (llama_decode(ctx2, batch)) {
|
||||||
fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
|
fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
|
||||||
|
llama_batch_free(batch);
|
||||||
llama_free(ctx2);
|
llama_free(ctx2);
|
||||||
llama_free_model(model);
|
llama_free_model(model);
|
||||||
return 1;
|
return 1;
|
||||||
@ -156,7 +169,6 @@ int main(int argc, char ** argv) {
|
|||||||
|
|
||||||
llama_sampler * smpl3 = llama_sampler_chain_init(sparams);
|
llama_sampler * smpl3 = llama_sampler_chain_init(sparams);
|
||||||
|
|
||||||
llama_sampler_chain_add(smpl3, llama_sampler_init_softmax());
|
|
||||||
llama_sampler_chain_add(smpl3, llama_sampler_init_dist(params.sparams.seed));
|
llama_sampler_chain_add(smpl3, llama_sampler_init_dist(params.sparams.seed));
|
||||||
|
|
||||||
printf("\nsingle seq run: %s", params.prompt.c_str());
|
printf("\nsingle seq run: %s", params.prompt.c_str());
|
||||||
@ -221,8 +233,12 @@ int main(int argc, char ** argv) {
|
|||||||
printf("%s", next_token_str.c_str());
|
printf("%s", next_token_str.c_str());
|
||||||
result2 += next_token_str;
|
result2 += next_token_str;
|
||||||
|
|
||||||
if (llama_decode(ctx3, llama_batch_get_one(&next_token, 1, n_past, 1))) {
|
common_batch_clear(batch);
|
||||||
|
common_batch_add(batch, next_token, n_past, {1}, true);
|
||||||
|
|
||||||
|
if (llama_decode(ctx3, batch)) {
|
||||||
fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
|
fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
|
||||||
|
llama_batch_free(batch);
|
||||||
llama_free(ctx3);
|
llama_free(ctx3);
|
||||||
llama_free_model(model);
|
llama_free_model(model);
|
||||||
return 1;
|
return 1;
|
||||||
@ -236,6 +252,7 @@ int main(int argc, char ** argv) {
|
|||||||
llama_sampler_free(smpl2);
|
llama_sampler_free(smpl2);
|
||||||
llama_sampler_free(smpl3);
|
llama_sampler_free(smpl3);
|
||||||
|
|
||||||
|
llama_batch_free(batch);
|
||||||
llama_free(ctx3);
|
llama_free(ctx3);
|
||||||
llama_free_model(model);
|
llama_free_model(model);
|
||||||
|
|
||||||
|
@ -99,7 +99,7 @@ The project is under active development, and we are [looking for feedback and co
|
|||||||
|
|
||||||
| Argument | Explanation |
|
| Argument | Explanation |
|
||||||
| -------- | ----------- |
|
| -------- | ----------- |
|
||||||
| `--samplers SAMPLERS` | samplers that will be used for generation in the order, separated by ';'<br/>(default: top_k;tfs_z;typ_p;top_p;min_p;temperature) |
|
| `--samplers SAMPLERS` | samplers that will be used for generation in the order, separated by ';'<br/>(default: top_k;typ_p;top_p;min_p;temperature) |
|
||||||
| `-s, --seed SEED` | RNG seed (default: -1, use random seed for -1) |
|
| `-s, --seed SEED` | RNG seed (default: -1, use random seed for -1) |
|
||||||
| `--sampling-seq SEQUENCE` | simplified sequence for samplers that will be used (default: kfypmt) |
|
| `--sampling-seq SEQUENCE` | simplified sequence for samplers that will be used (default: kfypmt) |
|
||||||
| `--ignore-eos` | ignore end of stream token and continue generating (implies --logit-bias EOS-inf) |
|
| `--ignore-eos` | ignore end of stream token and continue generating (implies --logit-bias EOS-inf) |
|
||||||
@ -108,15 +108,19 @@ The project is under active development, and we are [looking for feedback and co
|
|||||||
| `--top-k N` | top-k sampling (default: 40, 0 = disabled) |
|
| `--top-k N` | top-k sampling (default: 40, 0 = disabled) |
|
||||||
| `--top-p N` | top-p sampling (default: 0.9, 1.0 = disabled) |
|
| `--top-p N` | top-p sampling (default: 0.9, 1.0 = disabled) |
|
||||||
| `--min-p N` | min-p sampling (default: 0.1, 0.0 = disabled) |
|
| `--min-p N` | min-p sampling (default: 0.1, 0.0 = disabled) |
|
||||||
| `--tfs N` | tail free sampling, parameter z (default: 1.0, 1.0 = disabled) |
|
|
||||||
| `--typical N` | locally typical sampling, parameter p (default: 1.0, 1.0 = disabled) |
|
| `--typical N` | locally typical sampling, parameter p (default: 1.0, 1.0 = disabled) |
|
||||||
| `--repeat-last-n N` | last n tokens to consider for penalize (default: 64, 0 = disabled, -1 = ctx_size) |
|
| `--repeat-last-n N` | last n tokens to consider for penalize (default: 64, 0 = disabled, -1 = ctx_size) |
|
||||||
| `--repeat-penalty N` | penalize repeat sequence of tokens (default: 1.0, 1.0 = disabled) |
|
| `--repeat-penalty N` | penalize repeat sequence of tokens (default: 1.0, 1.0 = disabled) |
|
||||||
| `--presence-penalty N` | repeat alpha presence penalty (default: 0.0, 0.0 = disabled) |
|
| `--presence-penalty N` | repeat alpha presence penalty (default: 0.0, 0.0 = disabled) |
|
||||||
| `--frequency-penalty N` | repeat alpha frequency penalty (default: 0.0, 0.0 = disabled) |
|
| `--frequency-penalty N` | repeat alpha frequency penalty (default: 0.0, 0.0 = disabled) |
|
||||||
|
| `--dry-multiplier N` | DRY sampling multiplier (default: 0.0, 0.0 = disabled) |
|
||||||
|
| `--dry-base N` | DRY sampling base value (default: 1.75) |
|
||||||
|
| `--dry-allowed-length N` | allowed length for DRY sampling (default: 2) |
|
||||||
|
| `--dry-penalty-last-n N` | DRY penalty for the last n tokens (default: -1, 0 = disable, -1 = context size) |
|
||||||
|
| `--dry-sequence-breaker STRING` | add sequence breaker for DRY sampling, clearing out default breakers (`['\n', ':', '"', '*']`) in the process; use `"none"` to not use any sequence breakers |
|
||||||
| `--dynatemp-range N` | dynamic temperature range (default: 0.0, 0.0 = disabled) |
|
| `--dynatemp-range N` | dynamic temperature range (default: 0.0, 0.0 = disabled) |
|
||||||
| `--dynatemp-exp N` | dynamic temperature exponent (default: 1.0) |
|
| `--dynatemp-exp N` | dynamic temperature exponent (default: 1.0) |
|
||||||
| `--mirostat N` | use Mirostat sampling.<br/>Top K, Nucleus, Tail Free and Locally Typical samplers are ignored if used.<br/>(default: 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0) |
|
| `--mirostat N` | use Mirostat sampling.<br/>Top K, Nucleus and Locally Typical samplers are ignored if used.<br/>(default: 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0) |
|
||||||
| `--mirostat-lr N` | Mirostat learning rate, parameter eta (default: 0.1) |
|
| `--mirostat-lr N` | Mirostat learning rate, parameter eta (default: 0.1) |
|
||||||
| `--mirostat-ent N` | Mirostat target entropy, parameter tau (default: 5.0) |
|
| `--mirostat-ent N` | Mirostat target entropy, parameter tau (default: 5.0) |
|
||||||
| `-l, --logit-bias TOKEN_ID(+/-)BIAS` | modifies the likelihood of token appearing in the completion,<br/>i.e. `--logit-bias 15043+1` to increase likelihood of token ' Hello',<br/>or `--logit-bias 15043-1` to decrease likelihood of token ' Hello' |
|
| `-l, --logit-bias TOKEN_ID(+/-)BIAS` | modifies the likelihood of token appearing in the completion,<br/>i.e. `--logit-bias 15043+1` to increase likelihood of token ' Hello',<br/>or `--logit-bias 15043-1` to decrease likelihood of token ' Hello' |
|
||||||
@ -147,6 +151,7 @@ The project is under active development, and we are [looking for feedback and co
|
|||||||
| `--ssl-cert-file FNAME` | path to a PEM-encoded SSL certificate file<br/>(env: LLAMA_ARG_SSL_CERT_FILE) |
|
| `--ssl-cert-file FNAME` | path to a PEM-encoded SSL certificate file<br/>(env: LLAMA_ARG_SSL_CERT_FILE) |
|
||||||
| `-to, --timeout N` | server read/write timeout in seconds (default: 600)<br/>(env: LLAMA_ARG_TIMEOUT) |
|
| `-to, --timeout N` | server read/write timeout in seconds (default: 600)<br/>(env: LLAMA_ARG_TIMEOUT) |
|
||||||
| `--threads-http N` | number of threads used to process HTTP requests (default: -1)<br/>(env: LLAMA_ARG_THREADS_HTTP) |
|
| `--threads-http N` | number of threads used to process HTTP requests (default: -1)<br/>(env: LLAMA_ARG_THREADS_HTTP) |
|
||||||
|
| `--cache-reuse N` | min chunk size to attempt reusing from the cache via KV shifting (default: 0)<br/>(env: LLAMA_ARG_CACHE_REUSE) |
|
||||||
| `--metrics` | enable prometheus compatible metrics endpoint (default: disabled)<br/>(env: LLAMA_ARG_ENDPOINT_METRICS) |
|
| `--metrics` | enable prometheus compatible metrics endpoint (default: disabled)<br/>(env: LLAMA_ARG_ENDPOINT_METRICS) |
|
||||||
| `--slots` | enable slots monitoring endpoint (default: disabled)<br/>(env: LLAMA_ARG_ENDPOINT_SLOTS) |
|
| `--slots` | enable slots monitoring endpoint (default: disabled)<br/>(env: LLAMA_ARG_ENDPOINT_SLOTS) |
|
||||||
| `--props` | enable changing global properties via POST /props (default: disabled)<br/>(env: LLAMA_ARG_ENDPOINT_PROPS) |
|
| `--props` | enable changing global properties via POST /props (default: disabled)<br/>(env: LLAMA_ARG_ENDPOINT_PROPS) |
|
||||||
@ -318,6 +323,18 @@ node index.js
|
|||||||
- The prompt is a string or an array with the first element given as a string
|
- The prompt is a string or an array with the first element given as a string
|
||||||
- The model's `tokenizer.ggml.add_bos_token` metadata is `true`
|
- The model's `tokenizer.ggml.add_bos_token` metadata is `true`
|
||||||
|
|
||||||
|
These input shapes and data types are allowed for `prompt`:
|
||||||
|
|
||||||
|
- Single string: `"string"`
|
||||||
|
- Single sequence of tokens: `[12, 34, 56]`
|
||||||
|
- Mixed tokens and strings: `[12, 34, "string", 56, 78]`
|
||||||
|
|
||||||
|
Multiple prompts are also supported. In this case, the completion result will be an array.
|
||||||
|
|
||||||
|
- Only strings: `["string1", "string2"]`
|
||||||
|
- Strings and sequences of tokens: `["string1", [12, 34, 56]]`
|
||||||
|
- Mixed types: `[[12, 34, "string", 56, 78], [12, 34, 56], "string"]`
|
||||||
|
|
||||||
`temperature`: Adjust the randomness of the generated text. Default: `0.8`
|
`temperature`: Adjust the randomness of the generated text. Default: `0.8`
|
||||||
|
|
||||||
`dynatemp_range`: Dynamic temperature range. The final temperature will be in the range of `[temperature - dynatemp_range; temperature + dynatemp_range]` Default: `0.0`, which is disabled.
|
`dynatemp_range`: Dynamic temperature range. The final temperature will be in the range of `[temperature - dynatemp_range; temperature + dynatemp_range]` Default: `0.0`, which is disabled.
|
||||||
@ -332,6 +349,8 @@ node index.js
|
|||||||
|
|
||||||
`n_predict`: Set the maximum number of tokens to predict when generating text. **Note:** May exceed the set limit slightly if the last token is a partial multibyte character. When 0, no tokens will be generated but the prompt is evaluated into the cache. Default: `-1`, where `-1` is infinity.
|
`n_predict`: Set the maximum number of tokens to predict when generating text. **Note:** May exceed the set limit slightly if the last token is a partial multibyte character. When 0, no tokens will be generated but the prompt is evaluated into the cache. Default: `-1`, where `-1` is infinity.
|
||||||
|
|
||||||
|
`n_indent`: Specify the minimum line indentation for the generated text in number of whitespace characters. Useful for code completion tasks. Default: `0`
|
||||||
|
|
||||||
`n_keep`: Specify the number of tokens from the prompt to retain when the context size is exceeded and tokens need to be discarded. The number excludes the BOS token.
|
`n_keep`: Specify the number of tokens from the prompt to retain when the context size is exceeded and tokens need to be discarded. The number excludes the BOS token.
|
||||||
By default, this value is set to `0`, meaning no tokens are kept. Use `-1` to retain all tokens from the prompt.
|
By default, this value is set to `0`, meaning no tokens are kept. Use `-1` to retain all tokens from the prompt.
|
||||||
|
|
||||||
@ -340,8 +359,6 @@ node index.js
|
|||||||
`stop`: Specify a JSON array of stopping strings.
|
`stop`: Specify a JSON array of stopping strings.
|
||||||
These words will not be included in the completion, so make sure to add them to the prompt for the next iteration. Default: `[]`
|
These words will not be included in the completion, so make sure to add them to the prompt for the next iteration. Default: `[]`
|
||||||
|
|
||||||
`tfs_z`: Enable tail free sampling with parameter z. Default: `1.0`, which is disabled.
|
|
||||||
|
|
||||||
`typical_p`: Enable locally typical sampling with parameter p. Default: `1.0`, which is disabled.
|
`typical_p`: Enable locally typical sampling with parameter p. Default: `1.0`, which is disabled.
|
||||||
|
|
||||||
`repeat_penalty`: Control the repetition of token sequences in the generated text. Default: `1.1`
|
`repeat_penalty`: Control the repetition of token sequences in the generated text. Default: `1.1`
|
||||||
@ -354,6 +371,16 @@ node index.js
|
|||||||
|
|
||||||
`frequency_penalty`: Repeat alpha frequency penalty. Default: `0.0`, which is disabled.
|
`frequency_penalty`: Repeat alpha frequency penalty. Default: `0.0`, which is disabled.
|
||||||
|
|
||||||
|
`dry_multiplier`: Set the DRY (Don't Repeat Yourself) repetition penalty multiplier. Default: `0.0`, which is disabled.
|
||||||
|
|
||||||
|
`dry_base`: Set the DRY repetition penalty base value. Default: `1.75`
|
||||||
|
|
||||||
|
`dry_allowed_length`: Tokens that extend repetition beyond this receive exponentially increasing penalty: multiplier * base ^ (length of repeating sequence before token - allowed length). Default: `2`
|
||||||
|
|
||||||
|
`dry_penalty_last_n`: How many tokens to scan for repetitions. Default: `-1`, where `0` is disabled and `-1` is context size.
|
||||||
|
|
||||||
|
`dry_sequence_breakers`: Specify an array of sequence breakers for DRY sampling. Only a JSON array of strings is accepted. Default: `['\n', ':', '"', '*']`
|
||||||
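To make the penalty formula above concrete, here is a small, hedged C++ sketch (not the server's implementation) that evaluates `multiplier * base ^ (repetition length - allowed length)` for a few repetition lengths:

```cpp
// Sketch only: evaluates the DRY penalty formula quoted above.
// penalty = dry_multiplier * dry_base ^ (rep_len - dry_allowed_length)
#include <cmath>
#include <cstdio>

static float dry_penalty(float multiplier, float base, int allowed_length, int rep_len) {
    if (multiplier == 0.0f || rep_len < allowed_length) {
        return 0.0f; // DRY disabled, or the repetition is still within the allowed length
    }
    return multiplier * std::pow(base, (float)(rep_len - allowed_length));
}

int main() {
    // defaults from above: dry_base = 1.75, dry_allowed_length = 2;
    // dry_multiplier = 0.8 is an illustrative (non-default) value
    for (int rep_len = 2; rep_len <= 6; ++rep_len) {
        std::printf("rep_len = %d -> penalty = %.3f\n", rep_len, dry_penalty(0.8f, 1.75f, 2, rep_len));
    }
    return 0;
}
```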
|
|
||||||
`mirostat`: Enable Mirostat sampling, controlling perplexity during text generation. Default: `0`, where `0` is disabled, `1` is Mirostat, and `2` is Mirostat 2.0.
|
`mirostat`: Enable Mirostat sampling, controlling perplexity during text generation. Default: `0`, where `0` is disabled, `1` is Mirostat, and `2` is Mirostat 2.0.
|
||||||
|
|
||||||
`mirostat_tau`: Set the Mirostat target entropy, parameter tau. Default: `5.0`
|
`mirostat_tau`: Set the Mirostat target entropy, parameter tau. Default: `5.0`
|
||||||
@ -382,7 +409,7 @@ node index.js
|
|||||||
|
|
||||||
`cache_prompt`: Re-use KV cache from a previous request if possible. This way the common prefix does not have to be re-processed, only the suffix that differs between the requests. Because (depending on the backend) the logits are **not** guaranteed to be bit-for-bit identical for different batch sizes (prompt processing vs. token generation), enabling this option can cause nondeterministic results. Default: `false`
|
`cache_prompt`: Re-use KV cache from a previous request if possible. This way the common prefix does not have to be re-processed, only the suffix that differs between the requests. Because (depending on the backend) the logits are **not** guaranteed to be bit-for-bit identical for different batch sizes (prompt processing vs. token generation), enabling this option can cause nondeterministic results. Default: `false`
|
||||||
|
|
||||||
`samplers`: The order the samplers should be applied in. An array of strings representing sampler type names. If a sampler is not set, it will not be used. If a sampler is specified more than once, it will be applied multiple times. Default: `["top_k", "tfs_z", "typical_p", "top_p", "min_p", "temperature"]` - these are all the available values.
|
`samplers`: The order the samplers should be applied in. An array of strings representing sampler type names. If a sampler is not set, it will not be used. If a sampler is specified more than once, it will be applied multiple times. Default: `["top_k", "typical_p", "top_p", "min_p", "temperature"]` - these are all the available values.
|
||||||
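For readers driving the C API directly, the default order above corresponds roughly to a sampler chain like the following hedged sketch; the exact constructor signatures (in particular the `min_keep` arguments) may differ between llama.cpp versions, and the values shown are illustrative:

```cpp
// Hedged sketch: building a sampler chain in the default order
// ["top_k", "typ_p", "top_p", "min_p", "temperature"], followed by a final
// distribution sampler that actually picks the token.
#include "llama.h"

llama_sampler * make_default_chain() {
    llama_sampler * smpl = llama_sampler_chain_init(llama_sampler_chain_default_params());

    llama_sampler_chain_add(smpl, llama_sampler_init_top_k(40));
    llama_sampler_chain_add(smpl, llama_sampler_init_typical(1.00f, 1)); // 1.0 = disabled
    llama_sampler_chain_add(smpl, llama_sampler_init_top_p(0.90f, 1));
    llama_sampler_chain_add(smpl, llama_sampler_init_min_p(0.05f, 1));
    llama_sampler_chain_add(smpl, llama_sampler_init_temp(0.80f));
    llama_sampler_chain_add(smpl, llama_sampler_init_dist(LLAMA_DEFAULT_SEED));

    return smpl;
}
```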
|
|
||||||
**Response format**
|
**Response format**
|
||||||
|
|
||||||
@ -523,8 +550,31 @@ Takes a prefix and a suffix and returns the predicted completion as stream.
|
|||||||
|
|
||||||
- `input_prefix`: Set the prefix of the code to infill.
|
- `input_prefix`: Set the prefix of the code to infill.
|
||||||
- `input_suffix`: Set the suffix of the code to infill.
|
- `input_suffix`: Set the suffix of the code to infill.
|
||||||
|
- `input_extra`: Additional context inserted before the FIM prefix.
|
||||||
|
- `prompt`: Added after the `FIM_MID` token
|
||||||
|
|
||||||
It also accepts all the options of `/completion`.
|
`input_extra` is an array of `{"filename": string, "text": string}` objects.
|
||||||
|
|
||||||
|
The endpoint also accepts all the options of `/completion`.
|
||||||
|
|
||||||
|
If the model has `FIM_REPO` and `FIM_FILE_SEP` tokens, the [repo-level pattern](https://arxiv.org/pdf/2409.12186) is used:
|
||||||
|
|
||||||
|
```txt
|
||||||
|
<FIM_REP>myproject
|
||||||
|
<FIM_SEP>{chunk 0 filename}
|
||||||
|
{chunk 0 text}
|
||||||
|
<FIM_SEP>{chunk 1 filename}
|
||||||
|
{chunk 1 text}
|
||||||
|
...
|
||||||
|
<FIM_SEP>filename
|
||||||
|
<FIM_PRE>[input_prefix]<FIM_SUF>[input_suffix]<FIM_MID>[prompt]
|
||||||
|
```
|
||||||
|
|
||||||
|
If the tokens are missing, then the extra context is simply prefixed at the start:
|
||||||
|
|
||||||
|
```txt
|
||||||
|
[input_extra]<FIM_PRE>[input_prefix]<FIM_SUF>[input_suffix]<FIM_MID>[prompt]
|
||||||
|
```
|
||||||
|
|
||||||
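As a rough illustration of how these pieces fit together, here is a hedged C++ sketch that assembles the fallback (no `FIM_REPO`/`FIM_FILE_SEP`) layout from `input_extra`, `input_prefix`, `input_suffix` and `prompt`; the literal `<FIM_*>` strings are placeholders, since the real special tokens are model-specific and inserted by the server:

```cpp
// Sketch only: assemble the fallback infill layout
//   [input_extra]<FIM_PRE>[input_prefix]<FIM_SUF>[input_suffix]<FIM_MID>[prompt]
// The "<FIM_*>" strings below stand in for the model's actual special tokens.
#include <string>
#include <vector>

struct extra_chunk {
    std::string filename;
    std::string text;
};

std::string build_infill_prompt(const std::vector<extra_chunk> & input_extra,
                                const std::string & input_prefix,
                                const std::string & input_suffix,
                                const std::string & prompt) {
    std::string result;
    for (const auto & chunk : input_extra) {
        result += chunk.text; // extra context is simply prefixed at the start
    }
    result += "<FIM_PRE>" + input_prefix;
    result += "<FIM_SUF>" + input_suffix;
    result += "<FIM_MID>" + prompt;
    return result;
}
```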
### **GET** `/props`: Get server global properties.
|
### **GET** `/props`: Get server global properties.
|
||||||
|
|
||||||
@ -685,7 +735,6 @@ Example:
|
|||||||
"repeat_penalty": 1.100000023841858,
|
"repeat_penalty": 1.100000023841858,
|
||||||
"samplers": [
|
"samplers": [
|
||||||
"top_k",
|
"top_k",
|
||||||
"tfs_z",
|
|
||||||
"typical_p",
|
"typical_p",
|
||||||
"top_p",
|
"top_p",
|
||||||
"min_p",
|
"min_p",
|
||||||
@ -699,7 +748,6 @@ Example:
|
|||||||
"stream": false,
|
"stream": false,
|
||||||
"task_id": 0,
|
"task_id": 0,
|
||||||
"temperature": 0.0,
|
"temperature": 0.0,
|
||||||
"tfs_z": 1.0,
|
|
||||||
"top_k": 40,
|
"top_k": 40,
|
||||||
"top_p": 0.949999988079071,
|
"top_p": 0.949999988079071,
|
||||||
"typical_p": 1.0
|
"typical_p": 1.0
|
||||||
|
@ -40,10 +40,15 @@
|
|||||||
repeat_last_n: 0, // 0 = disable penalty, -1 = context size
|
repeat_last_n: 0, // 0 = disable penalty, -1 = context size
|
||||||
repeat_penalty: 1.0, // 1.0 = disabled
|
repeat_penalty: 1.0, // 1.0 = disabled
|
||||||
penalize_nl: false, // true only useful for infinite completion
|
penalize_nl: false, // true only useful for infinite completion
|
||||||
|
dry_multiplier: 0.0, // 0.0 = disabled, 0.8 works well
|
||||||
|
dry_base: 1.75, // 0.0 = disabled
|
||||||
|
dry_allowed_length: 2, // tokens extending repetitions beyond this receive penalty, 2 works well
|
||||||
|
dry_penalty_last_n: -1, // how many tokens to scan for repetitions (0 = disable penalty, -1 = context size)
|
||||||
top_k: 0, // <= 0 to use vocab size
|
top_k: 0, // <= 0 to use vocab size
|
||||||
top_p: 1.0, // 1.0 = disabled
|
top_p: 1.0, // 1.0 = disabled
|
||||||
min_p: 0.05, // 0 = disabled; recommended for non-english: ~ 0.4
|
min_p: 0.05, // 0 = disabled; recommended for non-english: ~ 0.4
|
||||||
tfs_z: 1.0, // 1.0 = disabled
|
xtc_probability: 0.0, // 0 = disabled;
|
||||||
|
xtc_threshold: 0.1, // > 0.5 disables XTC;
|
||||||
typical_p: 1.0, // 1.0 = disabled
|
typical_p: 1.0, // 1.0 = disabled
|
||||||
presence_penalty: 0.0, // 0.0 = disabled
|
presence_penalty: 0.0, // 0.0 = disabled
|
||||||
frequency_penalty: 0.0, // 0.0 = disabled
|
frequency_penalty: 0.0, // 0.0 = disabled
|
||||||
@ -831,11 +836,16 @@ return html`
|
|||||||
<fieldset class="params">
|
<fieldset class="params">
|
||||||
${IntField({ label: "Top-K", title: "Limits the selection of the next token to the K most probable tokens. 1 means no randomness = greedy sampling. If set to 0, it means the entire vocabulary size is considered.", max: 100, min: 0, step: 1, name: "top_k", value: params.value.top_k })}
|
${IntField({ label: "Top-K", title: "Limits the selection of the next token to the K most probable tokens. 1 means no randomness = greedy sampling. If set to 0, it means the entire vocabulary size is considered.", max: 100, min: 0, step: 1, name: "top_k", value: params.value.top_k })}
|
||||||
${IntField({ label: "Penalize Last N", title: "The last n tokens that are taken into account to penalise repetitions. A value of 0 means that this function is deactivated and -1 means that the entire size of the context is taken into account.", max: 2048, min: 0, step: 16, name: "repeat_last_n", value: params.value.repeat_last_n })}
|
${IntField({ label: "Penalize Last N", title: "The last n tokens that are taken into account to penalise repetitions. A value of 0 means that this function is deactivated and -1 means that the entire size of the context is taken into account.", max: 2048, min: 0, step: 16, name: "repeat_last_n", value: params.value.repeat_last_n })}
|
||||||
${FloatField({ label: "Top-P", title: "Limits the selection of the next token to a subset of tokens whose combined probability reaches a threshold value P = top-P. If set to 1, it means the entire vocabulary size is considered.", max: 1.0, min: 0.0, name: "top_p", step: 0.01, value: params.value.top_p })}
|
|
||||||
${FloatField({ label: "Presence Penalty", title: "A penalty that is applied if certain tokens appear repeatedly in the generated text. A higher value leads to fewer repetitions.", max: 1.0, min: 0.0, name: "presence_penalty", step: 0.01, value: params.value.presence_penalty })}
|
${FloatField({ label: "Presence Penalty", title: "A penalty that is applied if certain tokens appear repeatedly in the generated text. A higher value leads to fewer repetitions.", max: 1.0, min: 0.0, name: "presence_penalty", step: 0.01, value: params.value.presence_penalty })}
|
||||||
${FloatField({ label: "TFS-Z", title: "Activates tail-free sampling, a method used to limit the prediction of tokens that are too frequent. The parameter z controls the strength of this limitation. A value of 1.0 means that this function is deactivated.", max: 1.0, min: 0.0, name: "tfs_z", step: 0.01, value: params.value.tfs_z })}
|
|
||||||
${FloatField({ label: "Frequency Penalty", title: "A penalty that is applied based on the frequency with which certain tokens occur in the training data set. A higher value results in rare tokens being favoured.", max: 1.0, min: 0.0, name: "frequency_penalty", step: 0.01, value: params.value.frequency_penalty })}
|
${FloatField({ label: "Frequency Penalty", title: "A penalty that is applied based on the frequency with which certain tokens occur in the training data set. A higher value results in rare tokens being favoured.", max: 1.0, min: 0.0, name: "frequency_penalty", step: 0.01, value: params.value.frequency_penalty })}
|
||||||
|
${FloatField({ label: "Top-P", title: "Limits the selection of the next token to a subset of tokens whose combined probability reaches a threshold value P = top-P. If set to 1, it means the entire vocabulary size is considered.", max: 1.0, min: 0.0, name: "top_p", step: 0.01, value: params.value.top_p })}
|
||||||
${FloatField({ label: "Typical-P", title: "Activates local typical sampling, a method used to limit the prediction of tokens that are atypical in the current context. The parameter p controls the strength of this limitation. A value of 1.0 means that this function is deactivated.", max: 1.0, min: 0.0, name: "typical_p", step: 0.01, value: params.value.typical_p })}
|
${FloatField({ label: "Typical-P", title: "Activates local typical sampling, a method used to limit the prediction of tokens that are atypical in the current context. The parameter p controls the strength of this limitation. A value of 1.0 means that this function is deactivated.", max: 1.0, min: 0.0, name: "typical_p", step: 0.01, value: params.value.typical_p })}
|
||||||
|
${FloatField({ label: "XTC probability", title: "Sets the chance for token removal (checked once on sampler start)", max: 1.0, min: 0.0, name: "xtc_probability", step: 0.01, value: params.value.xtc_probability })}
|
||||||
|
${FloatField({ label: "XTC threshold", title: "Sets a minimum probability threshold for tokens to be removed", max: 0.5, min: 0.0, name: "xtc_threshold", step: 0.01, value: params.value.xtc_threshold })}
|
||||||
|
${FloatField({ label: "DRY Penalty Multiplier", title: "Set the DRY repetition penalty multiplier. Default is 0.0, which disables DRY.", max: 5.0, min: 0.0, name: "dry_multiplier", step: 0.01, value: params.value.dry_multiplier })}
|
||||||
|
${FloatField({ label: "DRY Base", title: "Set the DRY repetition penalty base value. Default is 1.75", max: 3.0, min: 1.0, name: "dry_base", step: 0.01, value: params.value.dry_base })}
|
||||||
|
${IntField({ label: "DRY Allowed Length", title: "Tokens that extend repetition beyond this receive exponentially increasing penalty. Default is 2", max: 10, min: 1, step: 1, name: "dry_allowed_length", value: params.value.dry_allowed_length })}
|
||||||
|
${IntField({ label: "DRY Penalty Last N", title: "How many tokens to scan for repetitions. Default is -1, where 0 is disabled and -1 is context size", max: 2048, min: -1, step: 16, name: "dry_penalty_last_n", value: params.value.dry_penalty_last_n })}
|
||||||
${IntField({ label: "Min Keep", title: "If greater than 0, samplers are forced to return N possible tokens at minimum. Default is 0", max: 10, min: 0, name: "min_keep", value: params.value.min_keep })}
|
${IntField({ label: "Min Keep", title: "If greater than 0, samplers are forced to return N possible tokens at minimum. Default is 0", max: 10, min: 0, name: "min_keep", value: params.value.min_keep })}
|
||||||
</fieldset>
|
</fieldset>
|
||||||
|
|
||||||
@ -1132,12 +1142,15 @@ document.addEventListener('DOMContentLoaded', (event) => {
|
|||||||
const snapSettings = {
|
const snapSettings = {
|
||||||
temperature: { snapValue: 1.0, snapRangeMultiplier: 6 },
|
temperature: { snapValue: 1.0, snapRangeMultiplier: 6 },
|
||||||
min_p: { snapValue: 0.05, snapRangeMultiplier: 2 },
|
min_p: { snapValue: 0.05, snapRangeMultiplier: 2 },
|
||||||
|
xtc_probability: { snapValue: 0.0, snapRangeMultiplier: 4 },
|
||||||
|
xtc_threshold: { snapValue: 0.5, snapRangeMultiplier: 4 },
|
||||||
top_p: { snapValue: 1.0, snapRangeMultiplier: 4 },
|
top_p: { snapValue: 1.0, snapRangeMultiplier: 4 },
|
||||||
tfs_z: { snapValue: 1.0, snapRangeMultiplier: 4 },
|
|
||||||
typical_p: { snapValue: 1.0, snapRangeMultiplier: 4 },
|
typical_p: { snapValue: 1.0, snapRangeMultiplier: 4 },
|
||||||
repeat_penalty: { snapValue: 1.0, snapRangeMultiplier: 4 },
|
repeat_penalty: { snapValue: 1.0, snapRangeMultiplier: 4 },
|
||||||
presence_penalty: { snapValue: 0.0, snapRangeMultiplier: 4 },
|
presence_penalty: { snapValue: 0.0, snapRangeMultiplier: 4 },
|
||||||
frequency_penalty: { snapValue: 0.0, snapRangeMultiplier: 4 },
|
frequency_penalty: { snapValue: 0.0, snapRangeMultiplier: 4 },
|
||||||
|
dry_multiplier: { snapValue: 0.0, snapRangeMultiplier: 4 },
|
||||||
|
dry_base: { snapValue: 1.75, snapRangeMultiplier: 4 },
|
||||||
};
|
};
|
||||||
// add an event listener for each slider
|
// add an event listener for each slider
|
||||||
Object.keys(snapSettings).forEach(sliderName => {
|
Object.keys(snapSettings).forEach(sliderName => {
|
||||||
|
@ -304,10 +304,15 @@
|
|||||||
repeat_last_n: 256, // 0 = disable penalty, -1 = context size
|
repeat_last_n: 256, // 0 = disable penalty, -1 = context size
|
||||||
repeat_penalty: 1.18, // 1.0 = disabled
|
repeat_penalty: 1.18, // 1.0 = disabled
|
||||||
penalize_nl: false,
|
penalize_nl: false,
|
||||||
|
dry_multiplier: 0.0, // 0.0 = disabled, 0.8 works well
|
||||||
|
dry_base: 1.75, // 0.0 = disabled
|
||||||
|
dry_allowed_length: 2, // tokens extending repetitions beyond this receive penalty, 2 works well
|
||||||
|
dry_penalty_last_n: -1, // how many tokens to scan for repetitions (0 = disable penalty, -1 = context size)
|
||||||
top_k: 40, // <= 0 to use vocab size
|
top_k: 40, // <= 0 to use vocab size
|
||||||
top_p: 0.95, // 1.0 = disabled
|
top_p: 0.95, // 1.0 = disabled
|
||||||
min_p: 0.05, // 0 = disabled
|
min_p: 0.05, // 0 = disabled
|
||||||
tfs_z: 1.0, // 1.0 = disabled
|
xtc_probability: 0.0, // 0 = disabled;
|
||||||
|
xtc_threshold: 0.1, // > 0.5 disables XTC;
|
||||||
typical_p: 1.0, // 1.0 = disabled
|
typical_p: 1.0, // 1.0 = disabled
|
||||||
presence_penalty: 0.0, // 0.0 = disabled
|
presence_penalty: 0.0, // 0.0 = disabled
|
||||||
frequency_penalty: 0.0, // 0.0 = disabled
|
frequency_penalty: 0.0, // 0.0 = disabled
|
||||||
@ -1009,10 +1014,15 @@
|
|||||||
<details>
|
<details>
|
||||||
<summary>More options</summary>
|
<summary>More options</summary>
|
||||||
<fieldset class="two">
|
<fieldset class="two">
|
||||||
${FloatField({ label: "TFS-Z", max: 1.0, min: 0.0, name: "tfs_z", step: 0.01, value: params.value.tfs_z })}
|
|
||||||
${FloatField({ label: "Typical P", max: 1.0, min: 0.0, name: "typical_p", step: 0.01, value: params.value.typical_p })}
|
${FloatField({ label: "Typical P", max: 1.0, min: 0.0, name: "typical_p", step: 0.01, value: params.value.typical_p })}
|
||||||
${FloatField({ label: "Presence penalty", max: 1.0, min: 0.0, name: "presence_penalty", step: 0.01, value: params.value.presence_penalty })}
|
${FloatField({ label: "Presence penalty", max: 1.0, min: 0.0, name: "presence_penalty", step: 0.01, value: params.value.presence_penalty })}
|
||||||
${FloatField({ label: "Frequency penalty", max: 1.0, min: 0.0, name: "frequency_penalty", step: 0.01, value: params.value.frequency_penalty })}
|
${FloatField({ label: "Frequency penalty", max: 1.0, min: 0.0, name: "frequency_penalty", step: 0.01, value: params.value.frequency_penalty })}
|
||||||
|
${FloatField({ label: "DRY Penalty Multiplier", max: 5.0, min: 0.0, name: "dry_multiplier", step: 0.01, value: params.value.dry_multiplier })}
|
||||||
|
${FloatField({ label: "DRY Base", max: 3.0, min: 1.0, name: "dry_base", step: 0.01, value: params.value.dry_base })}
|
||||||
|
${IntField({ label: "DRY Allowed Length", max: 10, min: 2, step: 1, name: "dry_allowed_length", value: params.value.dry_allowed_length })}
|
||||||
|
${IntField({ label: "DRY Penalty Last N", max: 2048, min: -1, step: 16, name: "dry_penalty_last_n", value: params.value.dry_penalty_last_n })}
|
||||||
|
${FloatField({ label: "XTC probability", max: 1.0, min: 0.0, name: "xtc_probability", step: 0.01, value: params.value.xtc_probability })}
|
||||||
|
${FloatField({ label: "XTC threshold", max: 0.5, min: 0.0, name: "xtc_threshold", step: 0.01, value: params.value.xtc_threshold })}
|
||||||
</fieldset>
|
</fieldset>
|
||||||
<hr />
|
<hr />
|
||||||
<fieldset class="three">
|
<fieldset class="three">
|
||||||
|
File diff suppressed because one or more lines are too long
@ -529,7 +529,7 @@ export class SchemaConverter {
|
|||||||
return joinSeq();
|
return joinSeq();
|
||||||
};
|
};
|
||||||
|
|
||||||
return this._addRule(name, "\"\\\"\" " + toRule(transform()) + " \"\\\"\" space")
|
return this._addRule(name, "\"\\\"\" (" + toRule(transform()) + ") \"\\\"\" space")
|
||||||
}
|
}
|
||||||
|
|
||||||
_notStrings(strings) {
|
_notStrings(strings) {
|
||||||
|
0
examples/server/public/style.css
Executable file → Normal file
@ -43,21 +43,6 @@
|
|||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
#include <unordered_set>
|
#include <unordered_set>
|
||||||
|
|
||||||
#define SLT_INF(slot, fmt, ...) LOG_INF("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
|
|
||||||
#define SLT_WRN(slot, fmt, ...) LOG_WRN("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
|
|
||||||
#define SLT_ERR(slot, fmt, ...) LOG_ERR("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
|
|
||||||
#define SLT_DBG(slot, fmt, ...) LOG_DBG("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
|
|
||||||
|
|
||||||
#define SRV_INF(fmt, ...) LOG_INF("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
|
||||||
#define SRV_WRN(fmt, ...) LOG_WRN("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
|
||||||
#define SRV_ERR(fmt, ...) LOG_ERR("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
|
||||||
#define SRV_DBG(fmt, ...) LOG_DBG("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
|
||||||
|
|
||||||
#define QUE_INF(fmt, ...) LOG_INF("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
|
||||||
#define QUE_WRN(fmt, ...) LOG_WRN("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
|
||||||
#define QUE_ERR(fmt, ...) LOG_ERR("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
|
||||||
#define QUE_DBG(fmt, ...) LOG_DBG("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
|
||||||
|
|
||||||
using json = nlohmann::ordered_json;
|
using json = nlohmann::ordered_json;
|
||||||
|
|
||||||
enum stop_type {
|
enum stop_type {
|
||||||
@ -68,6 +53,7 @@ enum stop_type {
|
|||||||
// state diagram: https://github.com/ggerganov/llama.cpp/pull/9283
|
// state diagram: https://github.com/ggerganov/llama.cpp/pull/9283
|
||||||
enum slot_state {
|
enum slot_state {
|
||||||
SLOT_STATE_IDLE,
|
SLOT_STATE_IDLE,
|
||||||
|
SLOT_STATE_STARTED, // TODO: this state is only used for setting up the initial prompt processing; maybe merge it with launch_slot_with_task in the future
|
||||||
SLOT_STATE_PROCESSING_PROMPT,
|
SLOT_STATE_PROCESSING_PROMPT,
|
||||||
SLOT_STATE_DONE_PROMPT,
|
SLOT_STATE_DONE_PROMPT,
|
||||||
SLOT_STATE_GENERATING,
|
SLOT_STATE_GENERATING,
|
||||||
@ -79,7 +65,7 @@ enum server_state {
|
|||||||
};
|
};
|
||||||
|
|
||||||
enum server_task_type {
|
enum server_task_type {
|
||||||
SERVER_TASK_TYPE_COMPLETION,
|
SERVER_TASK_TYPE_INFERENCE,
|
||||||
SERVER_TASK_TYPE_CANCEL,
|
SERVER_TASK_TYPE_CANCEL,
|
||||||
SERVER_TASK_TYPE_NEXT_RESPONSE,
|
SERVER_TASK_TYPE_NEXT_RESPONSE,
|
||||||
SERVER_TASK_TYPE_METRICS,
|
SERVER_TASK_TYPE_METRICS,
|
||||||
@ -89,21 +75,22 @@ enum server_task_type {
|
|||||||
SERVER_TASK_TYPE_SET_LORA,
|
SERVER_TASK_TYPE_SET_LORA,
|
||||||
};
|
};
|
||||||
|
|
||||||
enum server_task_cmpl_type {
|
enum server_task_inf_type {
|
||||||
SERVER_TASK_CMPL_TYPE_NORMAL,
|
SERVER_TASK_INF_TYPE_COMPLETION,
|
||||||
SERVER_TASK_CMPL_TYPE_EMBEDDING,
|
SERVER_TASK_INF_TYPE_EMBEDDING,
|
||||||
SERVER_TASK_CMPL_TYPE_RERANK,
|
SERVER_TASK_INF_TYPE_RERANK,
|
||||||
SERVER_TASK_CMPL_TYPE_INFILL,
|
SERVER_TASK_INF_TYPE_INFILL,
|
||||||
};
|
};
|
||||||
|
|
||||||
struct server_task {
|
struct server_task {
|
||||||
int id = -1; // to be filled by server_queue
|
int id = -1; // to be filled by server_queue
|
||||||
int id_target = -1; // used by SERVER_TASK_TYPE_CANCEL
|
int id_target = -1; // used by SERVER_TASK_TYPE_CANCEL
|
||||||
|
|
||||||
|
llama_tokens prompt_tokens;
|
||||||
server_task_type type;
|
server_task_type type;
|
||||||
json data;
|
json data;
|
||||||
|
|
||||||
server_task_cmpl_type cmpl_type = SERVER_TASK_CMPL_TYPE_NORMAL;
|
server_task_inf_type inf_type = SERVER_TASK_INF_TYPE_COMPLETION;
|
||||||
|
|
||||||
// utility function
|
// utility function
|
||||||
static std::unordered_set<int> get_list_id(const std::vector<server_task> & tasks) {
|
static std::unordered_set<int> get_list_id(const std::vector<server_task> & tasks) {
|
||||||
@ -131,14 +118,12 @@ struct slot_params {
|
|||||||
int32_t n_keep = 0; // number of tokens to keep from initial prompt
|
int32_t n_keep = 0; // number of tokens to keep from initial prompt
|
||||||
int32_t n_discard = 0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half
|
int32_t n_discard = 0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half
|
||||||
int32_t n_predict = -1; // new tokens to predict
|
int32_t n_predict = -1; // new tokens to predict
|
||||||
|
int32_t n_indent = 0; // minimum line indentation for the generated text in number of whitespace characters
|
||||||
|
|
||||||
int64_t t_max_prompt_ms = -1; // TODO: implement
|
int64_t t_max_prompt_ms = -1; // TODO: implement
|
||||||
int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit
|
int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit
|
||||||
|
|
||||||
std::vector<std::string> antiprompt;
|
std::vector<std::string> antiprompt;
|
||||||
|
|
||||||
json input_prefix;
|
|
||||||
json input_suffix;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
struct server_slot {
|
struct server_slot {
|
||||||
@ -163,19 +148,20 @@ struct server_slot {
|
|||||||
int32_t i_batch = -1;
|
int32_t i_batch = -1;
|
||||||
int32_t n_predict = -1; // TODO: disambiguate from params.n_predict
|
int32_t n_predict = -1; // TODO: disambiguate from params.n_predict
|
||||||
|
|
||||||
|
// n_prompt_tokens may not be equal to prompt_tokens.size(), because the prompt may be truncated
|
||||||
int32_t n_prompt_tokens = 0;
|
int32_t n_prompt_tokens = 0;
|
||||||
int32_t n_prompt_tokens_processed = 0;
|
int32_t n_prompt_tokens_processed = 0;
|
||||||
|
|
||||||
json prompt; // can be either a string, array of strings or array of token ids
|
// input prompt tokens
|
||||||
|
llama_tokens prompt_tokens;
|
||||||
|
|
||||||
// when a task is submitted, we first tokenize the prompt and store it here
|
size_t last_nl_pos = 0;
|
||||||
std::vector<llama_token> prompt_tokens;
|
|
||||||
|
|
||||||
std::string generated_text;
|
std::string generated_text;
|
||||||
std::vector<llama_token> cache_tokens;
|
llama_tokens cache_tokens;
|
||||||
std::vector<completion_token_output> generated_token_probs;
|
std::vector<completion_token_output> generated_token_probs;
|
||||||
|
|
||||||
server_task_cmpl_type cmpl_type = SERVER_TASK_CMPL_TYPE_NORMAL;
|
server_task_inf_type inf_type = SERVER_TASK_INF_TYPE_COMPLETION;
|
||||||
|
|
||||||
bool has_next_token = true;
|
bool has_next_token = true;
|
||||||
bool has_new_line = false;
|
bool has_new_line = false;
|
||||||
@ -213,6 +199,7 @@ struct server_slot {
|
|||||||
SLT_DBG(*this, "%s", "\n");
|
SLT_DBG(*this, "%s", "\n");
|
||||||
|
|
||||||
n_prompt_tokens = 0;
|
n_prompt_tokens = 0;
|
||||||
|
last_nl_pos = 0;
|
||||||
generated_text = "";
|
generated_text = "";
|
||||||
has_new_line = false;
|
has_new_line = false;
|
||||||
truncated = false;
|
truncated = false;
|
||||||
@ -223,7 +210,7 @@ struct server_slot {
|
|||||||
n_past = 0;
|
n_past = 0;
|
||||||
n_sent_text = 0;
|
n_sent_text = 0;
|
||||||
n_sent_token_probs = 0;
|
n_sent_token_probs = 0;
|
||||||
cmpl_type = SERVER_TASK_CMPL_TYPE_NORMAL;
|
inf_type = SERVER_TASK_INF_TYPE_COMPLETION;
|
||||||
|
|
||||||
generated_token_probs.clear();
|
generated_token_probs.clear();
|
||||||
}
|
}
|
||||||
@ -728,42 +715,6 @@ struct server_context {
|
|||||||
metrics.init();
|
metrics.init();
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<llama_token> tokenize(const json & json_prompt, bool add_special, bool parse_special) const {
|
|
||||||
// If `add_bos` is true, we only add BOS, when json_prompt is a string,
|
|
||||||
// or the first element of the json_prompt array is a string.
|
|
||||||
std::vector<llama_token> prompt_tokens;
|
|
||||||
|
|
||||||
if (json_prompt.is_array()) {
|
|
||||||
bool first = true;
|
|
||||||
for (const auto & p : json_prompt) {
|
|
||||||
if (p.is_string()) {
|
|
||||||
auto s = p.template get<std::string>();
|
|
||||||
|
|
||||||
std::vector<llama_token> p;
|
|
||||||
if (first) {
|
|
||||||
p = common_tokenize(ctx, s, add_special, parse_special);
|
|
||||||
first = false;
|
|
||||||
} else {
|
|
||||||
p = common_tokenize(ctx, s, false, parse_special);
|
|
||||||
}
|
|
||||||
|
|
||||||
prompt_tokens.insert(prompt_tokens.end(), p.begin(), p.end());
|
|
||||||
} else {
|
|
||||||
if (first) {
|
|
||||||
first = false;
|
|
||||||
}
|
|
||||||
|
|
||||||
prompt_tokens.push_back(p.template get<llama_token>());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
auto s = json_prompt.template get<std::string>();
|
|
||||||
prompt_tokens = common_tokenize(ctx, s, add_special, parse_special);
|
|
||||||
}
|
|
||||||
|
|
||||||
return prompt_tokens;
|
|
||||||
}
|
|
||||||
|
|
||||||
     server_slot * get_slot_by_id(int id) {
         for (server_slot & slot : slots) {
             if (slot.id == id) {
@@ -774,12 +725,12 @@ struct server_context {
         return nullptr;
     }

-    server_slot * get_available_slot(const std::string & prompt) {
+    server_slot * get_available_slot(const server_task & task) {
         server_slot * ret = nullptr;

         // find the slot that has at least n% prompt similarity
-        if (ret == nullptr && slot_prompt_similarity != 0.0f && !prompt.empty()) {
-            int max_lcp_len = 0;
+        if (ret == nullptr && slot_prompt_similarity != 0.0f) {
+            int max_lcs_len = 0;
             float similarity = 0;

             for (server_slot & slot : slots) {
@@ -788,32 +739,26 @@ struct server_context {
                     continue;
                 }

-                // skip the slot if it does not contains prompt
-                if (!slot.prompt.is_string()) {
+                // skip the slot if it does not contains cached tokens
+                if (slot.cache_tokens.empty()) {
                     continue;
                 }

-                // current slot's prompt
-                std::string slot_prompt = slot.prompt.get<std::string>();
+                // length of the Longest Common Subsequence between the current slot's prompt and the input prompt
+                int lcs_len = longest_common_subsequence(slot.cache_tokens, task.prompt_tokens);

-                // length of the current slot's prompt
-                int slot_prompt_len = slot_prompt.size();
-
-                // length of the Longest Common Prefix between the current slot's prompt and the input prompt
-                int lcp_len = common_part(slot_prompt, prompt);
-
-                // fraction of the common substring length compared to the current slot's prompt length
-                similarity = static_cast<float>(lcp_len) / slot_prompt_len;
+                // fraction of the common subsequence length compared to the current slot's prompt length
+                similarity = static_cast<float>(lcs_len) / static_cast<int>(slot.cache_tokens.size());

                 // select the current slot if the criteria match
-                if (lcp_len > max_lcp_len && similarity > slot_prompt_similarity) {
-                    max_lcp_len = lcp_len;
+                if (lcs_len > max_lcs_len && similarity > slot_prompt_similarity) {
+                    max_lcs_len = lcs_len;
                     ret = &slot;
                 }
             }

             if (ret != nullptr) {
-                SLT_DBG(*ret, "selected slot by lcp similarity, max_lcp_len = %d, similarity = %f\n", max_lcp_len, similarity);
+                SLT_DBG(*ret, "selected slot by lcs similarity, max_lcs_len = %d, similarity = %f\n", max_lcs_len, similarity);
             }
         }

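The longest_common_subsequence helper referenced above scores how much of a slot's cached token sequence overlaps the incoming prompt, even when the overlap is not a prefix. For reference, a minimal sketch of such an LCS-length routine follows; it assumes token ids are plain integers and is not the server's actual implementation, which lives in the server utilities.

#include <algorithm>
#include <cstdint>
#include <vector>

// Hypothetical sketch: O(n*m) two-row dynamic programming for the LCS length of two token-id sequences.
static size_t lcs_length(const std::vector<int32_t> & a, const std::vector<int32_t> & b) {
    std::vector<size_t> prev(b.size() + 1, 0); // LCS lengths for the previous row of the DP table
    std::vector<size_t> cur (b.size() + 1, 0); // LCS lengths for the current row
    for (size_t i = 1; i <= a.size(); ++i) {
        for (size_t j = 1; j <= b.size(); ++j) {
            cur[j] = (a[i - 1] == b[j - 1]) ? prev[j - 1] + 1 : std::max(prev[j], cur[j - 1]);
        }
        std::swap(prev, cur);
    }
    return prev[b.size()];
}

The slot whose cache yields the largest LCS length, and whose similarity ratio clears slot_prompt_similarity, is the one reused.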
@@ -858,10 +803,12 @@ struct server_context {
         slot.params.stream           = json_value(data, "stream",            false);
         slot.params.cache_prompt     = json_value(data, "cache_prompt",      false);
         slot.params.n_predict        = json_value(data, "n_predict",         json_value(data, "max_tokens", default_params.n_predict));
+        slot.params.n_indent         = json_value(data, "n_indent",          default_params.n_indent);
         slot.sparams.top_k           = json_value(data, "top_k",             default_sparams.top_k);
         slot.sparams.top_p           = json_value(data, "top_p",             default_sparams.top_p);
         slot.sparams.min_p           = json_value(data, "min_p",             default_sparams.min_p);
-        slot.sparams.tfs_z           = json_value(data, "tfs_z",             default_sparams.tfs_z);
+        slot.sparams.xtc_probability = json_value(data, "xtc_probability",   default_sparams.xtc_probability);
+        slot.sparams.xtc_threshold   = json_value(data, "xtc_threshold",     default_sparams.xtc_threshold);
         slot.sparams.typ_p           = json_value(data, "typical_p",         default_sparams.typ_p);
         slot.sparams.temp            = json_value(data, "temperature",       default_sparams.temp);
         slot.sparams.dynatemp_range  = json_value(data, "dynatemp_range",    default_sparams.dynatemp_range);
@@ -870,11 +817,15 @@ struct server_context {
         slot.sparams.penalty_repeat  = json_value(data, "repeat_penalty",    default_sparams.penalty_repeat);
         slot.sparams.penalty_freq    = json_value(data, "frequency_penalty", default_sparams.penalty_freq);
         slot.sparams.penalty_present = json_value(data, "presence_penalty",  default_sparams.penalty_present);
+        slot.sparams.dry_multiplier  = json_value(data, "dry_multiplier",    default_sparams.dry_multiplier);
+        slot.sparams.dry_base        = json_value(data, "dry_base",          default_sparams.dry_base);
+        slot.sparams.dry_allowed_length = json_value(data, "dry_allowed_length", default_sparams.dry_allowed_length);
+        slot.sparams.dry_penalty_last_n = json_value(data, "dry_penalty_last_n", default_sparams.dry_penalty_last_n);
         slot.sparams.mirostat        = json_value(data, "mirostat",          default_sparams.mirostat);
         slot.sparams.mirostat_tau    = json_value(data, "mirostat_tau",      default_sparams.mirostat_tau);
         slot.sparams.mirostat_eta    = json_value(data, "mirostat_eta",      default_sparams.mirostat_eta);
         slot.sparams.penalize_nl     = json_value(data, "penalize_nl",       default_sparams.penalize_nl);
-        slot.params.n_keep           = json_value(data, "n_keep",            slot.params.n_keep);
+        slot.params.n_keep           = json_value(data, "n_keep",            default_params.n_keep);
         slot.params.n_discard        = json_value(data, "n_discard",         default_params.n_discard);
         slot.sparams.seed            = json_value(data, "seed",              default_sparams.seed);
         slot.sparams.n_probs         = json_value(data, "n_probs",           default_sparams.n_probs);
@@ -882,6 +833,25 @@ struct server_context {
       //slot.params.t_max_prompt_ms  = json_value(data, "t_max_prompt_ms",  default_params.t_max_prompt_ms); // TODO: implement
         slot.params.t_max_predict_ms = json_value(data, "t_max_predict_ms", default_params.t_max_predict_ms);

+        if (slot.sparams.dry_base < 1.0f)
+        {
+            slot.sparams.dry_base = default_sparams.dry_base;
+        }
+
+        // sequence breakers for DRY
+        {
+            // Currently, this is not compatible with TextGen WebUI, Koboldcpp and SillyTavern format
+            // Ref: https://github.com/oobabooga/text-generation-webui/blob/d1af7a41ade7bd3c3a463bfa640725edb818ebaf/extensions/openai/typing.py#L39
+
+            if (data.contains("dry_sequence_breakers")) {
+                slot.sparams.dry_sequence_breakers = json_value(data, "dry_sequence_breakers", std::vector<std::string>());
+                if (slot.sparams.dry_sequence_breakers.empty()) {
+                    send_error(task, "Error: dry_sequence_breakers must be a non-empty array of strings", ERROR_TYPE_INVALID_REQUEST);
+                    return false;
+                }
+            }
+        }
+
         // process "json_schema" and "grammar"
         if (data.contains("json_schema") && !data.at("json_schema").is_null() && data.contains("grammar") && !data.at("grammar").is_null()) {
             send_error(task, "Either \"json_schema\" or \"grammar\" can be specified, but not both", ERROR_TYPE_INVALID_REQUEST);
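The DRY repetition penalty and the XTC sampler wired up above are configured per request. As an illustration only (the values below are examples, not defaults taken from this change), a /completion request enabling both could look like:

{
  "prompt": "Once upon a time",
  "n_predict": 128,
  "dry_multiplier": 0.8,
  "dry_base": 1.75,
  "dry_allowed_length": 2,
  "dry_penalty_last_n": -1,
  "dry_sequence_breakers": ["\n", ":", "\"", "*"],
  "xtc_probability": 0.5,
  "xtc_threshold": 0.1
}

An empty dry_sequence_breakers array is rejected with ERROR_TYPE_INVALID_REQUEST, as handled above.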
@ -905,39 +875,6 @@ struct server_context {
|
|||||||
SLT_WRN(slot, "n_predict = %d exceeds server configuration, setting to %d", slot.n_predict, slot.n_predict);
|
SLT_WRN(slot, "n_predict = %d exceeds server configuration, setting to %d", slot.n_predict, slot.n_predict);
|
||||||
}
|
}
|
||||||
|
|
||||||
// infill
|
|
||||||
slot.params.input_prefix = json_value(data, "input_prefix", default_params.input_prefix);
|
|
||||||
slot.params.input_suffix = json_value(data, "input_suffix", default_params.input_suffix);
|
|
||||||
|
|
||||||
// get prompt
|
|
||||||
if (task.cmpl_type != SERVER_TASK_CMPL_TYPE_INFILL) {
|
|
||||||
const auto & prompt = data.find("prompt");
|
|
||||||
if (prompt == data.end()) {
|
|
||||||
send_error(task, "\"prompt\" must be provided", ERROR_TYPE_INVALID_REQUEST);
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
if ((prompt->is_string()) ||
|
|
||||||
(prompt->is_array() && prompt->size() == 1 && prompt->at(0).is_string()) ||
|
|
||||||
(prompt->is_array() && !prompt->empty() && prompt->at(0).is_number_integer())) {
|
|
||||||
slot.prompt = *prompt;
|
|
||||||
} else if (prompt->is_array() && prompt->size() == 1 && prompt->at(0).is_array()) {
|
|
||||||
slot.prompt = prompt->at(0);
|
|
||||||
} else if (prompt->is_array() && prompt->size() > 1) {
|
|
||||||
// array of strings
|
|
||||||
for (const auto & el : *prompt) {
|
|
||||||
if (!el.is_string()) {
|
|
||||||
send_error(task, "\"prompt\" must be a string, an array of strings or an array of integers", ERROR_TYPE_INVALID_REQUEST);
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
slot.prompt = *prompt;
|
|
||||||
} else {
|
|
||||||
send_error(task, "\"prompt\" must be a string, an array of strings or an array of integers", ERROR_TYPE_INVALID_REQUEST);
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
{
|
{
|
||||||
slot.sparams.logit_bias.clear();
|
slot.sparams.logit_bias.clear();
|
||||||
|
|
||||||
@ -1017,8 +954,7 @@ struct server_context {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
slot.state = SLOT_STATE_PROCESSING_PROMPT;
|
slot.state = SLOT_STATE_STARTED;
|
||||||
slot.prompt_tokens.clear();
|
|
||||||
|
|
||||||
SLT_INF(slot, "%s", "processing task\n");
|
SLT_INF(slot, "%s", "processing task\n");
|
||||||
|
|
||||||
@ -1068,22 +1004,21 @@ struct server_context {
|
|||||||
size_t pos = std::min(slot.n_sent_text, slot.generated_text.size());
|
size_t pos = std::min(slot.n_sent_text, slot.generated_text.size());
|
||||||
|
|
||||||
const std::string str_test = slot.generated_text.substr(pos);
|
const std::string str_test = slot.generated_text.substr(pos);
|
||||||
bool is_stop_full = false;
|
bool send_text = true;
|
||||||
|
|
||||||
size_t stop_pos = slot.find_stopping_strings(str_test, token_str.size(), STOP_TYPE_FULL);
|
size_t stop_pos = slot.find_stopping_strings(str_test, token_str.size(), STOP_TYPE_FULL);
|
||||||
if (stop_pos != std::string::npos) {
|
if (stop_pos != std::string::npos) {
|
||||||
is_stop_full = true;
|
|
||||||
slot.generated_text.erase(
|
slot.generated_text.erase(
|
||||||
slot.generated_text.begin() + pos + stop_pos,
|
slot.generated_text.begin() + pos + stop_pos,
|
||||||
slot.generated_text.end());
|
slot.generated_text.end());
|
||||||
pos = std::min(slot.n_sent_text, slot.generated_text.size());
|
pos = std::min(slot.n_sent_text, slot.generated_text.size());
|
||||||
} else {
|
} else if (slot.has_next_token) {
|
||||||
is_stop_full = false;
|
|
||||||
stop_pos = slot.find_stopping_strings(str_test, token_str.size(), STOP_TYPE_PARTIAL);
|
stop_pos = slot.find_stopping_strings(str_test, token_str.size(), STOP_TYPE_PARTIAL);
|
||||||
|
send_text = stop_pos == std::string::npos;
|
||||||
}
|
}
|
||||||
|
|
||||||
// check if there is any token to predict
|
// check if there is any token to predict
|
||||||
if (stop_pos == std::string::npos || (!slot.has_next_token && !is_stop_full && stop_pos > 0)) {
|
if (send_text) {
|
||||||
// no send the stop word in the response
|
// no send the stop word in the response
|
||||||
result.text_to_send = slot.generated_text.substr(pos, std::string::npos);
|
result.text_to_send = slot.generated_text.substr(pos, std::string::npos);
|
||||||
slot.n_sent_text += result.text_to_send.size();
|
slot.n_sent_text += result.text_to_send.size();
|
||||||
@@ -1108,15 +1043,50 @@ struct server_context {
             SLT_DBG(slot, "stopped by limit, n_decoded = %d, n_predict = %d\n", slot.n_decoded, slot.params.n_predict);
         }

+        if (slot.has_new_line) {
             // if we have already seen a new line, we stop after a certain time limit
-        if (slot.has_new_line && slot.params.t_max_predict_ms > 0 &&
-            (ggml_time_us() - slot.t_start_generation > 1000.0f*slot.params.t_max_predict_ms)) {
+            if (slot.params.t_max_predict_ms > 0 && (ggml_time_us() - slot.t_start_generation > 1000.0f*slot.params.t_max_predict_ms)) {
                 slot.stopped_limit  = true;
                 slot.has_next_token = false;

                 SLT_DBG(slot, "stopped by time limit, n_decoded = %d, t_max_predict_ms = %d ms\n", slot.n_decoded, (int) slot.params.t_max_predict_ms);
             }

+            // require that each new line has a whitespace prefix (i.e. indentation) of at least slot.params.n_indent
+            if (slot.params.n_indent > 0) {
+                // check the current indentation
+                // TODO: improve by not doing it more than once for each new line
+                if (slot.last_nl_pos > 0) {
+                    size_t pos = slot.last_nl_pos;
+
+                    int n_indent = 0;
+                    while (pos < slot.generated_text.size() && (slot.generated_text[pos] == ' ' || slot.generated_text[pos] == '\t')) {
+                        n_indent++;
+                        pos++;
+                    }
+
+                    if (pos < slot.generated_text.size() && n_indent < slot.params.n_indent) {
+                        slot.stopped_limit  = true;
+                        slot.has_next_token = false;
+
+                        // cut the last line
+                        slot.generated_text.erase(pos, std::string::npos);
+
+                        SLT_DBG(slot, "stopped by indentation limit, n_decoded = %d, n_indent = %d\n", slot.n_decoded, n_indent);
+                    }
+                }
+
+                // find the next new line
+                {
+                    const size_t pos = slot.generated_text.find('\n', slot.last_nl_pos);
+
+                    if (pos != std::string::npos) {
+                        slot.last_nl_pos = pos + 1;
+                    }
+                }
+            }
+        }

         // check if there is a new line in the generated text
         if (result.text_to_send.find('\n') != std::string::npos) {
             slot.has_new_line = true;
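Both stop conditions are request-scoped. An illustrative payload that caps generation two seconds after the first newline and stops as soon as a new line is indented by fewer than 4 whitespace characters (the values are examples, not defaults):

{
  "prompt": "def quicksort(arr):",
  "n_predict": 256,
  "n_indent": 4,
  "t_max_predict_ms": 2000
}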
@@ -1176,12 +1146,18 @@ struct server_context {
             {"top_k",                 slot.sparams.top_k},
             {"top_p",                 slot.sparams.top_p},
             {"min_p",                 slot.sparams.min_p},
-            {"tfs_z",                 slot.sparams.tfs_z},
+            {"xtc_probability",       slot.sparams.xtc_probability},
+            {"xtc_threshold",         slot.sparams.xtc_threshold},
             {"typical_p",             slot.sparams.typ_p},
             {"repeat_last_n",         slot.sparams.penalty_last_n},
             {"repeat_penalty",        slot.sparams.penalty_repeat},
             {"presence_penalty",      slot.sparams.penalty_present},
             {"frequency_penalty",     slot.sparams.penalty_freq},
+            {"dry_multiplier",        slot.sparams.dry_multiplier},
+            {"dry_base",              slot.sparams.dry_base},
+            {"dry_allowed_length",    slot.sparams.dry_allowed_length},
+            {"dry_penalty_last_n",    slot.sparams.dry_penalty_last_n},
+            {"dry_sequence_breakers", slot.sparams.dry_sequence_breakers},
             {"mirostat",              slot.sparams.mirostat},
             {"mirostat_tau",          slot.sparams.mirostat_tau},
             {"mirostat_eta",          slot.sparams.mirostat_eta},
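With these fields added, the generation_settings object echoed back in completion responses reports the new samplers as well; a trimmed, illustrative fragment (actual values depend on the request and on the server's sampling defaults):

"generation_settings": {
    "xtc_probability": 0.0,
    "xtc_threshold": 0.1,
    "dry_multiplier": 0.0,
    "dry_base": 1.75,
    "dry_allowed_length": 2,
    "dry_penalty_last_n": -1,
    "dry_sequence_breakers": ["\n", ":", "\"", "*"]
}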
@ -1234,7 +1210,7 @@ struct server_context {
|
|||||||
};
|
};
|
||||||
|
|
||||||
if (slot.sparams.n_probs > 0) {
|
if (slot.sparams.n_probs > 0) {
|
||||||
const std::vector<llama_token> to_send_toks = common_tokenize(ctx, tkn.text_to_send, false);
|
const llama_tokens to_send_toks = common_tokenize(ctx, tkn.text_to_send, false);
|
||||||
const size_t probs_pos = std::min(slot.n_sent_token_probs, slot.generated_token_probs.size());
|
const size_t probs_pos = std::min(slot.n_sent_token_probs, slot.generated_token_probs.size());
|
||||||
const size_t probs_stop_pos = std::min(slot.n_sent_token_probs + to_send_toks.size(), slot.generated_token_probs.size());
|
const size_t probs_stop_pos = std::min(slot.n_sent_token_probs + to_send_toks.size(), slot.generated_token_probs.size());
|
||||||
|
|
||||||
@ -1270,7 +1246,7 @@ struct server_context {
|
|||||||
{"tokens_predicted", slot.n_decoded},
|
{"tokens_predicted", slot.n_decoded},
|
||||||
{"tokens_evaluated", slot.n_prompt_tokens},
|
{"tokens_evaluated", slot.n_prompt_tokens},
|
||||||
{"generation_settings", get_formated_generation(slot)},
|
{"generation_settings", get_formated_generation(slot)},
|
||||||
{"prompt", slot.prompt},
|
{"prompt", common_detokenize(ctx, slot.prompt_tokens)},
|
||||||
{"has_new_line", slot.has_new_line},
|
{"has_new_line", slot.has_new_line},
|
||||||
{"truncated", slot.truncated},
|
{"truncated", slot.truncated},
|
||||||
{"stopped_eos", slot.stopped_eos},
|
{"stopped_eos", slot.stopped_eos},
|
||||||
@ -1285,7 +1261,7 @@ struct server_context {
|
|||||||
if (slot.sparams.n_probs > 0) {
|
if (slot.sparams.n_probs > 0) {
|
||||||
std::vector<completion_token_output> probs;
|
std::vector<completion_token_output> probs;
|
||||||
if (!slot.params.stream && slot.stopped_word) {
|
if (!slot.params.stream && slot.stopped_word) {
|
||||||
const std::vector<llama_token> stop_word_toks = common_tokenize(ctx, slot.stopping_word, false);
|
const llama_tokens stop_word_toks = common_tokenize(ctx, slot.stopping_word, false);
|
||||||
|
|
||||||
size_t safe_offset = std::min(slot.generated_token_probs.size(), stop_word_toks.size());
|
size_t safe_offset = std::min(slot.generated_token_probs.size(), stop_word_toks.size());
|
||||||
probs = std::vector<completion_token_output>(
|
probs = std::vector<completion_token_output>(
|
||||||
@@ -1394,19 +1370,17 @@ struct server_context {
     // Functions to create new task(s) and receive result(s)
     //

-    std::vector<server_task> create_tasks_cmpl(json data, server_task_cmpl_type cmpl_type) {
+    // break the input "prompt" into multiple tasks if needed, then format and tokenize the input prompt(s)
+    std::vector<server_task> create_tasks_inference(json data, server_task_inf_type inf_type) {
         std::vector<server_task> tasks;
-        auto create_task = [&](json & task_data, bool replace_prompt, json prompt) {
+        auto create_task = [&](json & task_data, llama_tokens & prompt_tokens) {
+            SRV_DBG("create task, n_tokens = %d\n", (int) prompt_tokens.size());
             server_task task;
             task.id        = queue_tasks.get_new_id();
-            task.cmpl_type = cmpl_type;
-            task.type      = SERVER_TASK_TYPE_COMPLETION;
-            if (replace_prompt) {
+            task.inf_type  = inf_type;
+            task.type      = SERVER_TASK_TYPE_INFERENCE;
             task.data      = task_data;
-                task.data["prompt"] = std::move(prompt);
-            } else {
-                task.data  = std::move(task_data);
-            }
+            task.prompt_tokens = std::move(prompt_tokens);
             tasks.push_back(std::move(task));
         };

@@ -1415,42 +1389,50 @@ struct server_context {
             throw std::runtime_error(error_msg);
         }

-        json prompt = data.at("prompt");
-
-        // if the prompt is a singleton (i.e. a string or a list of tokens), we only need to create single task
-        if (prompt.is_string() || json_is_array_of_numbers(prompt)) {
-            data["index"] = 0;
-            create_task(data, false, nullptr);
-        } else if (prompt.is_array()) {
-            // otherwise, it's a multiple-prompt task, we break it into smaller tasks
-            std::vector<json> prompts = prompt;
-            if (cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
-                // prompts[0] is the question
-                // the rest are the answers/documents
-                SRV_DBG("creating rerank tasks, n_prompts = %d\n", (int) prompts.size() - 1);
-                for (size_t i = 1; i < prompts.size(); i++) {
-                    json qd;
-                    qd.push_back(prompts[0]);
-                    qd.push_back(prompts[i]);
-                    data["index"] = i - 1;
-                    create_task(data, true, qd);
-                }
-            } else {
-                SRV_DBG("creating multi-prompt tasks, n_prompts = %d\n", (int) prompts.size());
-                for (size_t i = 0; i < prompts.size(); i++) {
-                    const auto & e = prompts[i];
-                    if (e.is_string() || json_is_array_of_numbers(e)) {
-                        data["index"] = i;
-                        create_task(data, true, e);
-                    } else {
-                        throw std::runtime_error(error_msg);
-                    }
-                }
-            }
-        } else {
-            // invalid case
-            throw std::runtime_error(error_msg);
-        }
+        // because llama_tokenize api is thread-safe, we can tokenize the prompt from HTTP thread
+        bool add_special = inf_type != SERVER_TASK_INF_TYPE_RERANK && inf_type != SERVER_TASK_INF_TYPE_INFILL;
+        std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx, data.at("prompt"), add_special, true);
+        switch (inf_type) {
+            case SERVER_TASK_INF_TYPE_RERANK:
+                {
+                    // prompts[0] is the question
+                    // the rest are the answers/documents
+                    GGML_ASSERT(tokenized_prompts.size() > 1);
+                    SRV_DBG("creating rerank tasks, n_prompts = %d\n", (int) tokenized_prompts.size() - 1);
+                    for (size_t i = 1; i < tokenized_prompts.size(); i++) {
+                        data["index"] = i - 1;
+                        auto tokens = format_rerank(model, tokenized_prompts[0], tokenized_prompts[i]);
+                        create_task(data, tokens);
+                    }
+                } break;
+            case SERVER_TASK_INF_TYPE_INFILL:
+                {
+                    SRV_DBG("creating infill tasks, n_prompts = %d\n", (int) tokenized_prompts.size());
+                    for (size_t i = 0; i < tokenized_prompts.size(); i++) {
+                        data["index"] = i;
+                        auto tokens = format_infill(
+                            ctx,
+                            data.at("input_prefix"),
+                            data.at("input_suffix"),
+                            data.at("input_extra"),
+                            params.n_batch,
+                            params.n_predict,
+                            slots[0].n_ctx, // TODO: there should be a better way
+                            params.spm_infill,
+                            tokenized_prompts[i]
+                        );
+                        create_task(data, tokens);
+                    }
+                } break;
+            default:
+                {
+                    SRV_DBG("creating multi-prompt tasks, n_prompts = %d\n", (int) tokenized_prompts.size());
+                    for (size_t i = 0; i < tokenized_prompts.size(); i++) {
+                        data["index"] = i;
+                        create_task(data, tokenized_prompts[i]);
+                    }
+                } break;
+        }

         return tasks;
     }
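The format_rerank call used above assembles the cross-encoder layout that the removed in-slot code built by hand, i.e. [BOS]query[EOS][SEP]doc[EOS]. A minimal sketch under that assumption is shown below (llama_token_bos/eos/sep are existing llama.h accessors; the helper name and exact behaviour are assumptions, the real helper lives in the server utilities and may differ):

#include <vector>
#include "llama.h"

// Hypothetical sketch of rerank prompt assembly: [BOS]query[EOS][SEP]doc[EOS]
static std::vector<llama_token> format_rerank_sketch(const llama_model * model,
                                                     const std::vector<llama_token> & query,
                                                     const std::vector<llama_token> & doc) {
    std::vector<llama_token> result;
    result.reserve(query.size() + doc.size() + 4);
    result.push_back(llama_token_bos(model));                     // [BOS]
    result.insert(result.end(), query.begin(), query.end());      // query tokens
    result.push_back(llama_token_eos(model));                     // [EOS]
    result.push_back(llama_token_sep(model));                     // [SEP]
    result.insert(result.end(), doc.begin(), doc.end());          // document tokens
    result.push_back(llama_token_eos(model));                     // [EOS]
    return result;
}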
@ -1471,7 +1453,7 @@ struct server_context {
|
|||||||
queue_tasks.post(cancel_tasks, true);
|
queue_tasks.post(cancel_tasks, true);
|
||||||
}
|
}
|
||||||
|
|
||||||
// receive the results from task(s) created by create_tasks_cmpl
|
// receive the results from task(s) created by create_tasks_inference
|
||||||
void receive_cmpl_results(
|
void receive_cmpl_results(
|
||||||
const std::unordered_set<int> & id_tasks,
|
const std::unordered_set<int> & id_tasks,
|
||||||
const std::function<void(std::vector<server_task_result>&)> & result_handler,
|
const std::function<void(std::vector<server_task_result>&)> & result_handler,
|
||||||
@ -1495,7 +1477,7 @@ struct server_context {
|
|||||||
result_handler(results);
|
result_handler(results);
|
||||||
}
|
}
|
||||||
|
|
||||||
// receive the results from task(s) created by create_tasks_cmpl, in stream mode
|
// receive the results from task(s) created by create_tasks_inference, in stream mode
|
||||||
void receive_cmpl_results_stream(
|
void receive_cmpl_results_stream(
|
||||||
const std::unordered_set<int> & id_tasks, const
|
const std::unordered_set<int> & id_tasks, const
|
||||||
std::function<bool(server_task_result&)> & result_handler, const
|
std::function<bool(server_task_result&)> & result_handler, const
|
||||||
@ -1528,22 +1510,11 @@ struct server_context {
|
|||||||
|
|
||||||
void process_single_task(const server_task & task) {
|
void process_single_task(const server_task & task) {
|
||||||
switch (task.type) {
|
switch (task.type) {
|
||||||
case SERVER_TASK_TYPE_COMPLETION:
|
case SERVER_TASK_TYPE_INFERENCE:
|
||||||
{
|
{
|
||||||
const int id_slot = json_value(task.data, "id_slot", -1);
|
const int id_slot = json_value(task.data, "id_slot", -1);
|
||||||
|
|
||||||
server_slot * slot;
|
server_slot * slot = id_slot != -1 ? get_slot_by_id(id_slot) : get_available_slot(task);
|
||||||
|
|
||||||
if (id_slot != -1) {
|
|
||||||
slot = get_slot_by_id(id_slot);
|
|
||||||
} else {
|
|
||||||
std::string prompt;
|
|
||||||
if (task.data.contains("prompt") && task.data.at("prompt").is_string()) {
|
|
||||||
prompt = json_value(task.data, "prompt", std::string());
|
|
||||||
}
|
|
||||||
|
|
||||||
slot = get_available_slot(prompt);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (slot == nullptr) {
|
if (slot == nullptr) {
|
||||||
// if no slot is available, we defer this task for processing later
|
// if no slot is available, we defer this task for processing later
|
||||||
@ -1561,8 +1532,9 @@ struct server_context {
|
|||||||
slot->reset();
|
slot->reset();
|
||||||
|
|
||||||
slot->id_task = task.id;
|
slot->id_task = task.id;
|
||||||
slot->cmpl_type = task.cmpl_type;
|
slot->inf_type = task.inf_type;
|
||||||
slot->index = json_value(task.data, "index", 0);
|
slot->index = json_value(task.data, "index", 0);
|
||||||
|
slot->prompt_tokens = std::move(task.prompt_tokens);
|
||||||
|
|
||||||
if (!launch_slot_with_task(*slot, task)) {
|
if (!launch_slot_with_task(*slot, task)) {
|
||||||
SRV_ERR("failed to launch slot with task, id_task = %d\n", task.id);
|
SRV_ERR("failed to launch slot with task, id_task = %d\n", task.id);
|
||||||
@ -1595,7 +1567,7 @@ struct server_context {
|
|||||||
slot_data["id"] = slot.id;
|
slot_data["id"] = slot.id;
|
||||||
slot_data["id_task"] = slot.id_task;
|
slot_data["id_task"] = slot.id_task;
|
||||||
slot_data["state"] = slot.state;
|
slot_data["state"] = slot.state;
|
||||||
slot_data["prompt"] = slot.prompt;
|
slot_data["prompt"] = common_detokenize(ctx, slot.prompt_tokens);
|
||||||
slot_data["next_token"] = {
|
slot_data["next_token"] = {
|
||||||
{"has_next_token", slot.has_next_token},
|
{"has_next_token", slot.has_next_token},
|
||||||
{"has_new_line", slot.has_new_line},
|
{"has_new_line", slot.has_new_line},
|
||||||
@ -1722,9 +1694,6 @@ struct server_context {
|
|||||||
}
|
}
|
||||||
slot->cache_tokens.resize(token_count);
|
slot->cache_tokens.resize(token_count);
|
||||||
|
|
||||||
// TODO: maybe detokenize the slot->cache_tokens instead?
|
|
||||||
slot->prompt = string_format("[restored %d tokens from file]", (int) token_count);
|
|
||||||
|
|
||||||
const int64_t t_end = ggml_time_us();
|
const int64_t t_end = ggml_time_us();
|
||||||
const double t_restore_ms = (t_end - t_start) / 1000.0;
|
const double t_restore_ms = (t_end - t_start) / 1000.0;
|
||||||
|
|
||||||
@ -1891,80 +1860,19 @@ struct server_context {
|
|||||||
if (params.cont_batching || batch.n_tokens == 0) {
|
if (params.cont_batching || batch.n_tokens == 0) {
|
||||||
for (auto & slot : slots) {
|
for (auto & slot : slots) {
|
||||||
// this slot still has a prompt to be processed
|
// this slot still has a prompt to be processed
|
||||||
if (slot.state == SLOT_STATE_PROCESSING_PROMPT) {
|
if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_STARTED) {
|
||||||
auto & prompt_tokens = slot.prompt_tokens;
|
auto & prompt_tokens = slot.prompt_tokens;
|
||||||
|
|
||||||
// we haven't tokenized the prompt yet - do it now:
|
// TODO: maybe move branch to outside of this loop in the future
|
||||||
if (prompt_tokens.empty()) {
|
if (slot.state == SLOT_STATE_STARTED) {
|
||||||
SLT_INF(slot, "tokenizing prompt, len = %d\n", (int) slot.prompt.size());
|
|
||||||
|
|
||||||
slot.t_start_process_prompt = ggml_time_us();
|
slot.t_start_process_prompt = ggml_time_us();
|
||||||
slot.t_start_generation = 0;
|
slot.t_start_generation = 0;
|
||||||
|
|
||||||
switch (slot.cmpl_type) {
|
|
||||||
case SERVER_TASK_CMPL_TYPE_NORMAL:
|
|
||||||
case SERVER_TASK_CMPL_TYPE_EMBEDDING:
|
|
||||||
{
|
|
||||||
prompt_tokens = tokenize(slot.prompt, llama_add_bos_token(model), true);
|
|
||||||
} break;
|
|
||||||
case SERVER_TASK_CMPL_TYPE_RERANK:
|
|
||||||
{
|
|
||||||
// require slot.prompt to be array of 2 strings
|
|
||||||
if (!slot.prompt.is_array() || slot.prompt.size() != 2) {
|
|
||||||
SLT_ERR(slot, "%s", "invalid prompt for rerank task\n");
|
|
||||||
slot.release();
|
|
||||||
send_error(slot, "invalid prompt for rerank task", ERROR_TYPE_INVALID_REQUEST);
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
// prompt: [BOS]query[EOS][SEP]doc[EOS]
|
|
||||||
prompt_tokens.clear();
|
|
||||||
prompt_tokens.push_back(llama_token_bos(model));
|
|
||||||
{
|
|
||||||
const auto part = tokenize(slot.prompt[0], false, false);
|
|
||||||
prompt_tokens.insert(prompt_tokens.end(), part.begin(), part.end());
|
|
||||||
}
|
|
||||||
prompt_tokens.push_back(llama_token_eos(model));
|
|
||||||
prompt_tokens.push_back(llama_token_sep(model));
|
|
||||||
{
|
|
||||||
const auto part = tokenize(slot.prompt[1], false, false);
|
|
||||||
prompt_tokens.insert(prompt_tokens.end(), part.begin(), part.end());
|
|
||||||
}
|
|
||||||
prompt_tokens.push_back(llama_token_eos(model));
|
|
||||||
} break;
|
|
||||||
case SERVER_TASK_CMPL_TYPE_INFILL:
|
|
||||||
{
|
|
||||||
auto prefix_tokens = tokenize(slot.params.input_prefix, false, false);
|
|
||||||
auto suffix_tokens = tokenize(slot.params.input_suffix, false, false);
|
|
||||||
|
|
||||||
// for now pick context to fit in a single batch (ratio prefix:suffix = 3:1, TODO: configurable?)
|
|
||||||
const int n_suffix_take = std::min<int>(suffix_tokens.size(), n_batch/4);
|
|
||||||
const int n_prefix_take = std::min<int>(prefix_tokens.size(), (n_batch - 3) - n_suffix_take);
|
|
||||||
|
|
||||||
prefix_tokens.erase(prefix_tokens.begin(), prefix_tokens.begin() + prefix_tokens.size() - n_prefix_take);
|
|
||||||
suffix_tokens.resize(n_suffix_take);
|
|
||||||
|
|
||||||
prefix_tokens.insert(prefix_tokens.begin(), llama_token_fim_pre(model));
|
|
||||||
suffix_tokens.insert(suffix_tokens.begin(), llama_token_fim_suf(model));
|
|
||||||
|
|
||||||
auto embd_inp = params.spm_infill ? suffix_tokens : prefix_tokens;
|
|
||||||
auto embd_end = params.spm_infill ? prefix_tokens : suffix_tokens;
|
|
||||||
|
|
||||||
if (llama_add_bos_token(model)) {
|
|
||||||
embd_inp.insert(embd_inp.begin(), llama_token_bos(model));
|
|
||||||
}
|
|
||||||
|
|
||||||
embd_inp.insert(embd_inp.end(), embd_end.begin(), embd_end.end());
|
|
||||||
embd_inp.push_back(llama_token_fim_mid(model));
|
|
||||||
|
|
||||||
prompt_tokens = std::move(embd_inp);
|
|
||||||
} break;
|
|
||||||
}
|
|
||||||
|
|
||||||
slot.n_past = 0;
|
slot.n_past = 0;
|
||||||
slot.n_prompt_tokens = prompt_tokens.size();
|
slot.n_prompt_tokens = prompt_tokens.size();
|
||||||
|
slot.state = SLOT_STATE_PROCESSING_PROMPT;
|
||||||
|
|
||||||
SLT_INF(slot, "prompt tokenized, n_ctx_slot = %d, n_keep = %d, n_prompt_tokens = %d\n", slot.n_ctx, slot.params.n_keep, slot.n_prompt_tokens);
|
SLT_INF(slot, "new prompt, n_ctx_slot = %d, n_keep = %d, n_prompt_tokens = %d\n", slot.n_ctx, slot.params.n_keep, slot.n_prompt_tokens);
|
||||||
|
|
||||||
// print prompt tokens (for debugging)
|
// print prompt tokens (for debugging)
|
||||||
if (1) {
|
if (1) {
|
||||||
@ -1989,13 +1897,18 @@ struct server_context {
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING || slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
|
if (slot.inf_type == SERVER_TASK_INF_TYPE_EMBEDDING || slot.inf_type == SERVER_TASK_INF_TYPE_RERANK) {
|
||||||
// this prompt is too large to process - discard it
|
|
||||||
if (slot.n_prompt_tokens > n_ubatch) {
|
if (slot.n_prompt_tokens > n_ubatch) {
|
||||||
slot.release();
|
slot.release();
|
||||||
send_error(slot, "input is too large to process. increase the physical batch size", ERROR_TYPE_SERVER);
|
send_error(slot, "input is too large to process. increase the physical batch size", ERROR_TYPE_SERVER);
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (slot.n_prompt_tokens > slot.n_ctx) {
|
||||||
|
slot.release();
|
||||||
|
send_error(slot, "input is larger than the max context size. skipping", ERROR_TYPE_SERVER);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
if (!params.ctx_shift) {
|
if (!params.ctx_shift) {
|
||||||
// if context shift is disabled, we make sure prompt size is smaller than KV size
|
// if context shift is disabled, we make sure prompt size is smaller than KV size
|
||||||
@ -2012,14 +1925,14 @@ struct server_context {
|
|||||||
}
|
}
|
||||||
slot.params.n_keep = std::min(slot.n_ctx - 4, slot.params.n_keep);
|
slot.params.n_keep = std::min(slot.n_ctx - 4, slot.params.n_keep);
|
||||||
|
|
||||||
// if input prompt is too big, truncate it (if group attention self-extend is disabled)
|
// if input prompt is too big, truncate it
|
||||||
if (slot.n_prompt_tokens >= slot.n_ctx) {
|
if (slot.n_prompt_tokens >= slot.n_ctx) {
|
||||||
const int n_left = slot.n_ctx - slot.params.n_keep;
|
const int n_left = slot.n_ctx - slot.params.n_keep;
|
||||||
|
|
||||||
const int n_block_size = n_left / 2;
|
const int n_block_size = n_left / 2;
|
||||||
const int erased_blocks = (slot.n_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size;
|
const int erased_blocks = (slot.n_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size;
|
||||||
|
|
||||||
std::vector<llama_token> new_tokens(
|
llama_tokens new_tokens(
|
||||||
prompt_tokens.begin(),
|
prompt_tokens.begin(),
|
||||||
prompt_tokens.begin() + slot.params.n_keep);
|
prompt_tokens.begin() + slot.params.n_keep);
|
||||||
|
|
||||||
@@ -2038,15 +1951,52 @@ struct server_context {
                         GGML_ASSERT(slot.n_prompt_tokens < slot.n_ctx);
                     }

-                    common_sampler_reset(slot.smpl);
-
                     if (slot.params.cache_prompt) {
                         // reuse any previously computed tokens that are common with the new prompt
-                        slot.n_past = common_part(slot.cache_tokens, prompt_tokens);
+                        slot.n_past = longest_common_prefix(slot.cache_tokens, prompt_tokens);

-                        // push the prompt into the sampling context (do not apply grammar)
-                        for (int i = 0; i < slot.n_past; ++i) {
-                            common_sampler_accept(slot.smpl, slot.cache_tokens[i], false);
+                        // reuse chunks from the cached prompt by shifting their KV cache in the new position
+                        if (params.n_cache_reuse > 0) {
+                            size_t head_c = slot.n_past; // cache
+                            size_t head_p = slot.n_past; // current prompt
+
+                            SLT_DBG(slot, "trying to reuse chunks with size > %d, slot.n_past = %d\n", params.n_cache_reuse, slot.n_past);
+
+                            while (head_c < slot.cache_tokens.size() &&
+                                   head_p < prompt_tokens.size()) {
+
+                                size_t n_match = 0;
+                                while (head_c + n_match < slot.cache_tokens.size() &&
+                                       head_p + n_match < prompt_tokens.size() &&
+                                       slot.cache_tokens[head_c + n_match] == prompt_tokens[head_p + n_match]) {
+
+                                    n_match++;
+                                }
+
+                                if (n_match >= (size_t) params.n_cache_reuse) {
+                                    SLT_INF(slot, "reusing chunk with size %zu, shifting KV cache [%zu, %zu) -> [%zu, %zu)\n", n_match, head_c, head_c + n_match, head_p, head_p + n_match);
+                                    //for (size_t i = head_p; i < head_p + n_match; i++) {
+                                    //    SLT_DBG(slot, "cache token %3zu: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str());
+                                    //}
+
+                                    const int64_t kv_shift = (int64_t) head_p - (int64_t) head_c;
+
+                                    llama_kv_cache_seq_rm (ctx, slot.id + 1, head_p, head_c);
+                                    llama_kv_cache_seq_add(ctx, slot.id + 1, head_c, -1, kv_shift);
+
+                                    for (size_t i = 0; i < n_match; i++) {
+                                        slot.cache_tokens[head_p + i] = slot.cache_tokens[head_c + i];
+                                        slot.n_past++;
+                                    }
+
+                                    head_c += n_match;
+                                    head_p += n_match;
+                                } else {
+                                    head_c += 1;
+                                }
+                            }
+
+                            SLT_DBG(slot, "after context reuse, new slot.n_past = %d\n", slot.n_past);
                         }
                     }
                 }
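This chunk-reuse path only runs when the server is configured with a non-zero n_cache_reuse and the request sets cache_prompt. Assuming the parameter is exposed as a --cache-reuse option on the server binary (the flag name is an assumption here; check llama-server --help), enabling it could look like:

# assumption: n_cache_reuse is exposed as --cache-reuse
./llama-server -m model.gguf --cache-reuse 256
# per request: { "prompt": "...", "cache_prompt": true, ... }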
@ -2062,7 +2012,7 @@ struct server_context {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// non-causal tasks require to fit the entire prompt in the physical batch
|
// non-causal tasks require to fit the entire prompt in the physical batch
|
||||||
if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING || slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
|
if (slot.inf_type == SERVER_TASK_INF_TYPE_EMBEDDING || slot.inf_type == SERVER_TASK_INF_TYPE_RERANK) {
|
||||||
// cannot fit the prompt in the current batch - will try next iter
|
// cannot fit the prompt in the current batch - will try next iter
|
||||||
if (batch.n_tokens + slot.n_prompt_tokens > n_batch) {
|
if (batch.n_tokens + slot.n_prompt_tokens > n_batch) {
|
||||||
continue;
|
continue;
|
||||||
@ -2071,8 +2021,8 @@ struct server_context {
|
|||||||
|
|
||||||
// check that we are in the right batch_type, if not defer the slot
|
// check that we are in the right batch_type, if not defer the slot
|
||||||
const bool slot_type =
|
const bool slot_type =
|
||||||
slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING ||
|
slot.inf_type == SERVER_TASK_INF_TYPE_EMBEDDING ||
|
||||||
slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK ? 1 : 0;
|
slot.inf_type == SERVER_TASK_INF_TYPE_RERANK ? 1 : 0;
|
||||||
|
|
||||||
if (batch_type == -1) {
|
if (batch_type == -1) {
|
||||||
batch_type = slot_type;
|
batch_type = slot_type;
|
||||||
@ -2087,8 +2037,6 @@ struct server_context {
|
|||||||
|
|
||||||
// there is no common part left
|
// there is no common part left
|
||||||
slot.n_past = 0;
|
slot.n_past = 0;
|
||||||
|
|
||||||
common_sampler_reset(slot.smpl);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
SLT_INF(slot, "kv cache rm [%d, end)\n", slot.n_past);
|
SLT_INF(slot, "kv cache rm [%d, end)\n", slot.n_past);
|
||||||
@ -2116,6 +2064,13 @@ struct server_context {
|
|||||||
|
|
||||||
GGML_ASSERT(batch.n_tokens > 0);
|
GGML_ASSERT(batch.n_tokens > 0);
|
||||||
|
|
||||||
|
common_sampler_reset(slot.smpl);
|
||||||
|
|
||||||
|
// Process all prompt tokens through sampler system
|
||||||
|
for (int i = 0; i < slot.n_prompt_tokens; ++i) {
|
||||||
|
common_sampler_accept(slot.smpl, prompt_tokens[i], false);
|
||||||
|
}
|
||||||
|
|
||||||
// extract the logits only for the last token
|
// extract the logits only for the last token
|
||||||
batch.logits[batch.n_tokens - 1] = true;
|
batch.logits[batch.n_tokens - 1] = true;
|
||||||
|
|
||||||
@ -2154,7 +2109,6 @@ struct server_context {
|
|||||||
batch.n_seq_id + i,
|
batch.n_seq_id + i,
|
||||||
batch.seq_id + i,
|
batch.seq_id + i,
|
||||||
batch.logits + i,
|
batch.logits + i,
|
||||||
0, 0, 0, // unused
|
|
||||||
};
|
};
|
||||||
|
|
||||||
const int ret = llama_decode(ctx, batch_view);
|
const int ret = llama_decode(ctx, batch_view);
|
||||||
@ -2186,7 +2140,7 @@ struct server_context {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (slot.state == SLOT_STATE_DONE_PROMPT) {
|
if (slot.state == SLOT_STATE_DONE_PROMPT) {
|
||||||
if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING) {
|
if (slot.inf_type == SERVER_TASK_INF_TYPE_EMBEDDING) {
|
||||||
// prompt evaluated for embedding
|
// prompt evaluated for embedding
|
||||||
send_embedding(slot, batch_view);
|
send_embedding(slot, batch_view);
|
||||||
slot.release();
|
slot.release();
|
||||||
@ -2194,7 +2148,7 @@ struct server_context {
|
|||||||
continue; // continue loop of slots
|
continue; // continue loop of slots
|
||||||
}
|
}
|
||||||
|
|
||||||
if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
|
if (slot.inf_type == SERVER_TASK_INF_TYPE_RERANK) {
|
||||||
send_rerank(slot, batch_view);
|
send_rerank(slot, batch_view);
|
||||||
slot.release();
|
slot.release();
|
||||||
slot.i_batch = -1;
|
slot.i_batch = -1;
|
||||||
@ -2441,7 +2395,7 @@ int main(int argc, char ** argv) {
|
|||||||
auto middleware_server_state = [&res_error, &state](const httplib::Request & req, httplib::Response & res) {
|
auto middleware_server_state = [&res_error, &state](const httplib::Request & req, httplib::Response & res) {
|
||||||
server_state current_state = state.load();
|
server_state current_state = state.load();
|
||||||
if (current_state == SERVER_STATE_LOADING_MODEL) {
|
if (current_state == SERVER_STATE_LOADING_MODEL) {
|
||||||
auto tmp = string_split(req.path, '.');
|
auto tmp = string_split<std::string>(req.path, '.');
|
||||||
if (req.path == "/" || tmp.back() == "html") {
|
if (req.path == "/" || tmp.back() == "html") {
|
||||||
res.set_content(reinterpret_cast<const char*>(loading_html), loading_html_len, "text/html; charset=utf-8");
|
res.set_content(reinterpret_cast<const char*>(loading_html), loading_html_len, "text/html; charset=utf-8");
|
||||||
res.status = 503;
|
res.status = 503;
|
||||||
@ -2748,13 +2702,13 @@ int main(int argc, char ** argv) {
|
|||||||
res_ok(res, {{ "success", true }});
|
res_ok(res, {{ "success", true }});
|
||||||
};
|
};
|
||||||
|
|
||||||
const auto handle_completions_generic = [&ctx_server, &res_error, &res_ok](server_task_cmpl_type cmpl_type, json & data, httplib::Response & res) {
|
const auto handle_completions_generic = [&ctx_server, &res_error, &res_ok](server_task_inf_type inf_type, json & data, httplib::Response & res) {
|
||||||
if (ctx_server.params.embedding || ctx_server.params.reranking) {
|
if (ctx_server.params.embedding || ctx_server.params.reranking) {
|
||||||
res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings` or `--reranking`", ERROR_TYPE_NOT_SUPPORTED));
|
res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings` or `--reranking`", ERROR_TYPE_NOT_SUPPORTED));
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<server_task> tasks = ctx_server.create_tasks_cmpl(data, cmpl_type);
|
std::vector<server_task> tasks = ctx_server.create_tasks_inference(data, inf_type);
|
||||||
ctx_server.queue_results.add_waiting_tasks(tasks);
|
ctx_server.queue_results.add_waiting_tasks(tasks);
|
||||||
ctx_server.queue_tasks.post(tasks);
|
ctx_server.queue_tasks.post(tasks);
|
||||||
|
|
||||||
@@ -2800,10 +2754,11 @@ int main(int argc, char ** argv) {

     const auto handle_completions = [&handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
         json data = json::parse(req.body);
-        return handle_completions_generic(SERVER_TASK_CMPL_TYPE_NORMAL, data, res);
+        return handle_completions_generic(SERVER_TASK_INF_TYPE_COMPLETION, data, res);
     };

     const auto handle_infill = [&ctx_server, &res_error, &handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
+        // check model compatibility
         std::string err;
         if (llama_token_fim_pre(ctx_server.model) == LLAMA_TOKEN_NULL) {
             err += "prefix token is missing. ";
@@ -2814,14 +2769,42 @@ int main(int argc, char ** argv) {
         if (llama_token_fim_mid(ctx_server.model) == LLAMA_TOKEN_NULL) {
             err += "middle token is missing. ";
         }

         if (!err.empty()) {
             res_error(res, format_error_response(string_format("Infill is not supported by this model: %s", err.c_str()), ERROR_TYPE_NOT_SUPPORTED));
             return;
         }

         json data = json::parse(req.body);
-        return handle_completions_generic(SERVER_TASK_CMPL_TYPE_INFILL, data, res);
+
+        // validate input
+        if (!data.contains("input_prefix")) {
+            res_error(res, format_error_response("\"input_prefix\" is required", ERROR_TYPE_INVALID_REQUEST));
+        }
+
+        if (!data.contains("input_suffix")) {
+            res_error(res, format_error_response("\"input_suffix\" is required", ERROR_TYPE_INVALID_REQUEST));
+        }
+
+        if (data.contains("input_extra") && !data.at("input_extra").is_array()) {
+            res_error(res, format_error_response("\"input_extra\" must be an array of {\"filename\": string, \"text\": string}", ERROR_TYPE_INVALID_REQUEST));
+            return;
+        }
+        json input_extra = json_value(data, "input_extra", json::array());
+        for (const auto & chunk : input_extra) {
+            // { "text": string, "filename": string }
+            if (!chunk.contains("text") || !chunk.at("text").is_string()) {
+                res_error(res, format_error_response("extra_context chunk must contain a \"text\" field with a string value", ERROR_TYPE_INVALID_REQUEST));
+                return;
+            }
+            // filename is optional
+            if (chunk.contains("filename") && !chunk.at("filename").is_string()) {
+                res_error(res, format_error_response("extra_context chunk's \"filename\" field must be a string", ERROR_TYPE_INVALID_REQUEST));
+                return;
+            }
+        }
+        data["input_extra"] = input_extra; // default to empty array if it's not exist
+
+        return handle_completions_generic(SERVER_TASK_INF_TYPE_INFILL, data, res);
     };
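After this validation, a well-formed /infill request carries the prefix, the suffix and optional extra-context chunks; an illustrative payload (the code fragments are the ones used by the test scenarios added later in this change):

{
  "input_prefix": "#include <cstdio>\n#include \"llama.h\"\n\nint main() {\n int n_threads = llama_",
  "input_suffix": "}\n",
  "input_extra": [ { "filename": "llama.h", "text": "LLAMA_API int32_t llama_n_threads();\n" } ],
  "n_predict": 64,
  "temperature": 0.0
}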
|
|
||||||
// TODO: maybe merge this function with "handle_completions_generic"
|
// TODO: maybe merge this function with "handle_completions_generic"
|
||||||
@ -2833,7 +2816,7 @@ int main(int argc, char ** argv) {
|
|||||||
|
|
||||||
json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template);
|
json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template);
|
||||||
|
|
||||||
std::vector<server_task> tasks = ctx_server.create_tasks_cmpl(data, SERVER_TASK_CMPL_TYPE_NORMAL);
|
std::vector<server_task> tasks = ctx_server.create_tasks_inference(data, SERVER_TASK_INF_TYPE_COMPLETION);
|
||||||
ctx_server.queue_results.add_waiting_tasks(tasks);
|
ctx_server.queue_results.add_waiting_tasks(tasks);
|
||||||
ctx_server.queue_tasks.post(tasks);
|
ctx_server.queue_tasks.post(tasks);
|
||||||
|
|
||||||
@ -2906,7 +2889,7 @@ int main(int argc, char ** argv) {
|
|||||||
const bool add_special = json_value(body, "add_special", false);
|
const bool add_special = json_value(body, "add_special", false);
|
||||||
const bool with_pieces = json_value(body, "with_pieces", false);
|
const bool with_pieces = json_value(body, "with_pieces", false);
|
||||||
|
|
||||||
std::vector<llama_token> tokens = ctx_server.tokenize(body.at("content"), add_special, true);
|
llama_tokens tokens = tokenize_mixed(ctx_server.ctx, body.at("content"), add_special, true);
|
||||||
|
|
||||||
if (with_pieces) {
|
if (with_pieces) {
|
||||||
for (const auto& token : tokens) {
|
for (const auto& token : tokens) {
|
||||||
@ -2943,7 +2926,7 @@ int main(int argc, char ** argv) {
|
|||||||
|
|
||||||
std::string content;
|
std::string content;
|
||||||
if (body.count("tokens") != 0) {
|
if (body.count("tokens") != 0) {
|
||||||
const std::vector<llama_token> tokens = body.at("tokens");
|
const llama_tokens tokens = body.at("tokens");
|
||||||
content = tokens_to_str(ctx_server.ctx, tokens.cbegin(), tokens.cend());
|
content = tokens_to_str(ctx_server.ctx, tokens.cbegin(), tokens.cend());
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2977,7 +2960,7 @@ int main(int argc, char ** argv) {
|
|||||||
json responses = json::array();
|
json responses = json::array();
|
||||||
bool error = false;
|
bool error = false;
|
||||||
{
|
{
|
||||||
std::vector<server_task> tasks = ctx_server.create_tasks_cmpl({{"prompt", prompt}}, SERVER_TASK_CMPL_TYPE_EMBEDDING);
|
std::vector<server_task> tasks = ctx_server.create_tasks_inference({{"prompt", prompt}}, SERVER_TASK_INF_TYPE_EMBEDDING);
|
||||||
ctx_server.queue_results.add_waiting_tasks(tasks);
|
ctx_server.queue_results.add_waiting_tasks(tasks);
|
||||||
ctx_server.queue_tasks.post(tasks);
|
ctx_server.queue_tasks.post(tasks);
|
||||||
|
|
||||||
@ -3054,7 +3037,7 @@ int main(int argc, char ** argv) {
|
|||||||
json responses = json::array();
|
json responses = json::array();
|
||||||
bool error = false;
|
bool error = false;
|
||||||
{
|
{
|
||||||
std::vector<server_task> tasks = ctx_server.create_tasks_cmpl({{"prompt", prompt}}, SERVER_TASK_CMPL_TYPE_RERANK);
|
std::vector<server_task> tasks = ctx_server.create_tasks_inference({{"prompt", prompt}}, SERVER_TASK_INF_TYPE_RERANK);
|
||||||
ctx_server.queue_results.add_waiting_tasks(tasks);
|
ctx_server.queue_results.add_waiting_tasks(tasks);
|
||||||
ctx_server.queue_tasks.post(tasks);
|
ctx_server.queue_tasks.post(tasks);
|
||||||
|
|
||||||
@ -3257,6 +3240,7 @@ int main(int argc, char ** argv) {
|
|||||||
|
|
||||||
ctx_server.queue_tasks.on_new_task(std::bind(
|
ctx_server.queue_tasks.on_new_task(std::bind(
|
||||||
&server_context::process_single_task, &ctx_server, std::placeholders::_1));
|
&server_context::process_single_task, &ctx_server, std::placeholders::_1));
|
||||||
|
|
||||||
ctx_server.queue_tasks.on_update_slots(std::bind(
|
ctx_server.queue_tasks.on_update_slots(std::bind(
|
||||||
&server_context::update_slots, &ctx_server));
|
&server_context::update_slots, &ctx_server));
|
||||||
|
|
||||||
@ -3264,7 +3248,7 @@ int main(int argc, char ** argv) {
|
|||||||
ctx_server.queue_tasks.terminate();
|
ctx_server.queue_tasks.terminate();
|
||||||
};
|
};
|
||||||
|
|
||||||
LOG_INF("%s: server is listening on %s:%d - starting the main loop\n", __func__, params.hostname.c_str(), params.port);
|
LOG_INF("%s: server is listening on http://%s:%d - starting the main loop\n", __func__, params.hostname.c_str(), params.port);
|
||||||
|
|
||||||
ctx_server.queue_tasks.start_loop();
|
ctx_server.queue_tasks.start_loop();
|
||||||
|
|
||||||
|
 36  examples/server/tests/features/infill.feature  (new file)
@@ -0,0 +1,36 @@
+@llama.cpp
+@infill
+Feature: llama.cpp server
+
+  # The current model is made by adding FIM tokens to the existing stories260K
+  # We may want to use a better model in the future, maybe something like SmolLM 360M
+
+  Background: Server startup
+    Given a server listening on localhost:8080
+    And   a model file tinyllamas/stories260K-infill.gguf from HF repo ggml-org/models
+    And   a model file test-model-infill.gguf
+    And   a model alias tinyllama-infill
+    And   42 as server seed
+    And   1024 as batch size
+    And   1024 as ubatch size
+    And   2048 KV cache size
+    And   64 max tokens to predict
+    And   0.0 temperature
+    Then  the server is starting
+    Then  the server is healthy
+
+  Scenario: Infill without input_extra
+    Given a prompt "Complete this"
+    And   an infill input extra none none
+    And   an infill input prefix "#include <cstdio>\n#include \"llama.h\"\n\nint main() {\n int n_threads = llama_"
+    And   an infill input suffix "}\n"
+    And   an infill request with no api error
+    Then  64 tokens are predicted matching One|day|she|saw|big|scary|bird
+
+  Scenario: Infill with input_extra
+    Given a prompt "Complete this"
+    And   an infill input extra "llama.h" "LLAMA_API int32_t llama_n_threads();\n"
+    And   an infill input prefix "#include <cstdio>\n#include \"llama.h\"\n\nint main() {\n int n_threads = llama_"
+    And   an infill input suffix "}\n"
+    And   an infill request with no api error
+    Then  64 tokens are predicted matching cuts|Jimmy|mom|came|into|the|room"
@ -80,6 +80,11 @@ def step_server_config(context, server_fqdn: str, server_port: str):
|
|||||||
context.lora_file = None
|
context.lora_file = None
|
||||||
context.disable_ctx_shift = False
|
context.disable_ctx_shift = False
|
||||||
|
|
||||||
|
# infill
|
||||||
|
context.infill_input_extra = None
|
||||||
|
context.infill_input_suffix = ''
|
||||||
|
context.infill_input_prefix = ''
|
||||||
|
|
||||||
context.tasks_result = []
|
context.tasks_result = []
|
||||||
context.concurrent_tasks = []
|
context.concurrent_tasks = []
|
||||||
context.prompts = []
|
context.prompts = []
|
||||||
@@ -291,6 +296,28 @@ async def step_request_completion(context, api_error: Literal['raised'] | str):
         assert completion == api_error_code, f"completion must be an {api_error_code} status code: {completion}"
 
 
+@step('an infill request with {api_error} api error')
+@async_run_until_complete
+async def step_request_completion(context, api_error: Literal['raised'] | str):
+    if api_error != 'no':
+        raise ValueError(f'api_error={api_error} is not yet implemented')
+    payload = {
+        "prompt": context.prompts[0],
+        "input_suffix": context.infill_input_suffix,
+        "input_prefix": context.infill_input_prefix,
+        "n_predict": context.n_predict,
+        "seed": context.seed,
+        "temperature": context.temperature,
+    }
+    if context.infill_input_extra is not None:
+        payload['input_extra'] = context.infill_input_extra
+    async with aiohttp.ClientSession(timeout=DEFAULT_TIMEOUT_SECONDS) as session:
+        async with session.post(f'{context.base_url}/infill',
+                                json=payload) as response:
+            assert response.status == 200
+            context.tasks_result = [await response.json()]
+
+
 @step('{predicted_n:d} tokens are predicted matching {re_content}')
 def step_n_tokens_predicted_with_content(context, predicted_n, re_content):
     context.completion = context.tasks_result.pop()
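For readers who want to exercise the endpoint outside the behave harness, the request body assembled by the step above can be reproduced directly. This is only an illustrative sketch: it assumes a server is already running on localhost:8080 (as in the scenarios above) and uses the `requests` package for brevity, whereas the test suite itself goes through aiohttp.

import requests

# values mirror the "Infill with input_extra" scenario above
payload = {
    "prompt": "Complete this",
    "input_prefix": "#include <cstdio>\n#include \"llama.h\"\n\nint main() {\n int n_threads = llama_",
    "input_suffix": "}\n",
    "input_extra": [{"filename": "llama.h", "text": "LLAMA_API int32_t llama_n_threads();\n"}],
    "n_predict": 64,
    "seed": 42,
    "temperature": 0.0,
}
response = requests.post("http://localhost:8080/infill", json=payload)
response.raise_for_status()
print(response.json())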
@@ -539,6 +566,25 @@ def step_a_prompt_prompt(context, prompt):
     context.n_prompts = len(context.prompts)
 
 
+# TODO: allow this to be repeated
+@step('an infill input extra {filename} {text}')
+def step_infill_input_extra(context, filename, text):
+    if filename == 'none':
+        context.infill_input_extra = None
+    else:
+        context.infill_input_extra = [{'filename': filename, 'text': text}]
+
+
+@step('an infill input suffix {text}')
+def step_infill_input_suffix(context, text):
+    context.infill_input_suffix = text
+
+
+@step('an infill input prefix {text}')
+def step_infill_input_prefix(context, text):
+    context.infill_input_prefix = text
+
+
 @step('{num_prompts:d} prompts {prompt} with seed {seed:d}')
 def step_many_prompts(context, num_prompts, prompt, seed):
     if context.seed is None:
@@ -226,7 +226,6 @@
     top_k: 40, // <= 0 to use vocab size
     top_p: 0.95, // 1.0 = disabled
     min_p: 0.05, // 0 = disabled
-    tfs_z: 1.0, // 1.0 = disabled
     typical_p: 1.0, // 1.0 = disabled
     presence_penalty: 0.0, // 0.0 = disabled
     frequency_penalty: 0.0, // 0.0 = disabled
@@ -788,7 +787,6 @@
           <details>
           <summary>More options</summary>
           <fieldset class="two">
-          ${FloatField({ label: "TFS-Z", max: 1.0, min: 0.0, name: "tfs_z", step: 0.01, value: params.value.tfs_z })}
           ${FloatField({ label: "Typical P", max: 1.0, min: 0.0, name: "typical_p", step: 0.01, value: params.value.typical_p })}
           ${FloatField({ label: "Presence penalty", max: 1.0, min: 0.0, name: "presence_penalty", step: 0.01, value: params.value.presence_penalty })}
           ${FloatField({ label: "Frequency penalty", max: 1.0, min: 0.0, name: "frequency_penalty", step: 0.01, value: params.value.frequency_penalty })}
@@ -229,7 +229,6 @@
     top_k: 40, // <= 0 to use vocab size
     top_p: 0.95, // 1.0 = disabled
     min_p: 0.05, // 0 = disabled
-    tfs_z: 1.0, // 1.0 = disabled
     typical_p: 1.0, // 1.0 = disabled
     presence_penalty: 0.0, // 0.0 = disabled
     frequency_penalty: 0.0, // 0.0 = disabled
@@ -791,7 +790,6 @@
           <details>
           <summary>More options</summary>
           <fieldset class="two">
-          ${FloatField({ label: "TFS-Z", max: 1.0, min: 0.0, name: "tfs_z", step: 0.01, value: params.value.tfs_z })}
           ${FloatField({ label: "Typical P", max: 1.0, min: 0.0, name: "typical_p", step: 0.01, value: params.value.typical_p })}
           ${FloatField({ label: "Presence penalty", max: 1.0, min: 0.0, name: "presence_penalty", step: 0.01, value: params.value.presence_penalty })}
           ${FloatField({ label: "Frequency penalty", max: 1.0, min: 0.0, name: "frequency_penalty", step: 0.01, value: params.value.frequency_penalty })}
@@ -24,6 +24,22 @@
 #define DEFAULT_OAICOMPAT_MODEL "gpt-3.5-turbo-0613"
 
 using json = nlohmann::ordered_json;
+using llama_tokens = std::vector<llama_token>;
+
+#define SLT_INF(slot, fmt, ...) LOG_INF("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
+#define SLT_WRN(slot, fmt, ...) LOG_WRN("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
+#define SLT_ERR(slot, fmt, ...) LOG_ERR("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
+#define SLT_DBG(slot, fmt, ...) LOG_DBG("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
+
+#define SRV_INF(fmt, ...) LOG_INF("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
+#define SRV_WRN(fmt, ...) LOG_WRN("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
+#define SRV_ERR(fmt, ...) LOG_ERR("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
+#define SRV_DBG(fmt, ...) LOG_DBG("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
+
+#define QUE_INF(fmt, ...) LOG_INF("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
+#define QUE_WRN(fmt, ...) LOG_WRN("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
+#define QUE_ERR(fmt, ...) LOG_ERR("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
+#define QUE_DBG(fmt, ...) LOG_DBG("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
 
 // https://community.openai.com/t/openai-chat-list-of-error-codes-and-types/357791/11
 enum error_type {
@@ -52,9 +68,237 @@ static T json_value(const json & body, const std::string & key, const T & defaul
 }
 
 //
-// chat template utils
+// tokenizer and input processing utils
 //
 
+static bool json_is_array_of_numbers(const json & data) {
+    if (data.is_array()) {
+        for (const auto & e : data) {
+            if (!e.is_number_integer()) {
+                return false;
+            }
+        }
+        return true;
+    }
+    return false;
+}
+
+// is array having BOTH numbers & strings?
+static bool json_is_array_of_mixed_numbers_strings(const json & data) {
+    bool seen_string = false;
+    bool seen_number = false;
+    if (data.is_array()) {
+        for (const auto & e : data) {
+            seen_string |= e.is_string();
+            seen_number |= e.is_number_integer();
+            if (seen_number && seen_string) {
+                return true;
+            }
+        }
+    }
+    return false;
+}
+
+/**
+ * this handles 2 cases:
+ * - only string, example: "string"
+ * - mixed string and tokens, example: [12, 34, "string", 56, 78]
+ */
+static llama_tokens tokenize_mixed(const llama_context * ctx, const json & json_prompt, bool add_special, bool parse_special) {
+    // If `add_bos` is true, we only add BOS, when json_prompt is a string,
+    // or the first element of the json_prompt array is a string.
+    llama_tokens prompt_tokens;
+
+    if (json_prompt.is_array()) {
+        bool first = true;
+        for (const auto & p : json_prompt) {
+            if (p.is_string()) {
+                auto s = p.template get<std::string>();
+
+                llama_tokens p;
+                if (first) {
+                    p = common_tokenize(ctx, s, add_special, parse_special);
+                    first = false;
+                } else {
+                    p = common_tokenize(ctx, s, false, parse_special);
+                }
+
+                prompt_tokens.insert(prompt_tokens.end(), p.begin(), p.end());
+            } else {
+                if (first) {
+                    first = false;
+                }
+
+                prompt_tokens.push_back(p.template get<llama_token>());
+            }
+        }
+    } else {
+        auto s = json_prompt.template get<std::string>();
+        prompt_tokens = common_tokenize(ctx, s, add_special, parse_special);
+    }
+
+    return prompt_tokens;
+}
+
+/**
+ * break the input "prompt" object into multiple prompt if needed, then tokenize them
+ * this supports these cases:
+ * - "prompt": "string"
+ * - "prompt": [12, 34, 56]
+ * - "prompt": [12, 34, "string", 56, 78]
+ * and multiple prompts (multi-tasks):
+ * - "prompt": ["string1", "string2"]
+ * - "prompt": ["string1", [12, 34, 56]]
+ * - "prompt": [[12, 34, "string", 56, 78], [12, 34, 56]]
+ */
+static std::vector<llama_tokens> tokenize_input_prompts(llama_context * ctx, const json & json_prompt, bool add_special, bool parse_special) {
+    std::vector<llama_tokens> result;
+    if (json_prompt.is_string() || json_is_array_of_mixed_numbers_strings(json_prompt)) {
+        // string or mixed
+        result.push_back(tokenize_mixed(ctx, json_prompt, add_special, parse_special));
+    } else if (json_is_array_of_numbers(json_prompt)) {
+        // array of tokens
+        result.push_back(json_prompt.get<llama_tokens>());
+    } else if (json_prompt.is_array()) {
+        // array of prompts
+        result.reserve(json_prompt.size());
+        for (const auto & p : json_prompt) {
+            if (p.is_string() || json_is_array_of_mixed_numbers_strings(p)) {
+                result.push_back(tokenize_mixed(ctx, p, add_special, parse_special));
+            } else if (json_is_array_of_numbers(p)) {
+                // array of tokens
+                result.push_back(p.get<llama_tokens>());
+            } else {
+                throw std::runtime_error("element of \"prompt\" must be a string, an list of tokens, or a list of mixed strings & tokens");
+            }
+        }
+    } else {
+        throw std::runtime_error("\"prompt\" must be a string, an list of tokens, a list of mixed strings & tokens, or a list of prompts");
+    }
+    return result;
+}
+
+//
+// template utils
+//
+
+// format rerank task: [BOS]query[EOS][SEP]doc[EOS]
+static llama_tokens format_rerank(const struct llama_model * model, const llama_tokens & query, const llama_tokens & doc) {
+    llama_tokens result;
+    result.reserve(doc.size() + query.size() + 4);
+    result.push_back(llama_token_bos(model));
+    result.insert(result.end(), query.begin(), query.end());
+    result.push_back(llama_token_eos(model));
+    result.push_back(llama_token_sep(model));
+    result.insert(result.end(), doc.begin(), doc.end());
+    result.push_back(llama_token_eos(model));
+    return result;
+}
+
+// format infill task
+static llama_tokens format_infill(
+        const llama_context * ctx,
+        const json & input_prefix,
+        const json & input_suffix,
+        const json & input_extra,
+        const int n_batch,
+        const int n_predict,
+        const int n_ctx,
+        const bool spm_infill,
+        const llama_tokens & tokens_prompt
+    ) {
+    // TODO: optimize this block by reducing memory allocations and movement
+
+    // use FIM repo-level pattern:
+    // ref: https://arxiv.org/pdf/2409.12186
+    //
+    // [FIM_REP]myproject
+    // [FIM_SEP]filename0
+    // extra chunk 0
+    // [FIM_SEP]filename1
+    // extra chunk 1
+    // ...
+    // [FIM_SEP]filename
+    // [FIM_PRE]prefix[FIM_SUF]suffix[FIM_MID]prompt
+    //
+    llama_tokens extra_tokens;
+    extra_tokens.reserve(n_ctx);
+
+    auto model = llama_get_model(ctx);
+    auto tokens_prefix = tokenize_mixed(ctx, input_prefix, false, false);
+    auto tokens_suffix = tokenize_mixed(ctx, input_suffix, false, false);
+
+    if (llama_token_fim_rep(model) != LLAMA_TOKEN_NULL) {
+        // TODO: make project name an input
+        static const auto k_fim_repo = common_tokenize(ctx, "myproject\n", false, false);
+
+        extra_tokens.push_back(llama_token_fim_rep(model));
+        extra_tokens.insert(extra_tokens.end(), k_fim_repo.begin(), k_fim_repo.end());
+    }
+    for (const auto & chunk : input_extra) {
+        // { "text": string, "filename": string }
+        const std::string text = json_value(chunk, "text", std::string());
+        const std::string filename = json_value(chunk, "filename", std::string("tmp"));
+
+        if (llama_token_fim_sep(model) != LLAMA_TOKEN_NULL) {
+            const auto k_fim_file = common_tokenize(ctx, filename + "\n", false, false);
+
+            extra_tokens.insert(extra_tokens.end(), llama_token_fim_sep(model));
+            extra_tokens.insert(extra_tokens.end(), k_fim_file.begin(), k_fim_file.end());
+        } else {
+            // chunk separator in binary form to avoid confusing the AI
+            static const char k_chunk_prefix_str[] = {0x0a, 0x0a, 0x2d, 0x2d, 0x2d, 0x20, 0x73, 0x6e, 0x69, 0x70, 0x70, 0x65, 0x74, 0x20, 0x2d, 0x2d, 0x2d, 0x0a, 0x0a, 0x00};
+            static const auto k_chunk_prefix_tokens = common_tokenize(ctx, k_chunk_prefix_str, false, false);
+
+            extra_tokens.insert(extra_tokens.end(), k_chunk_prefix_tokens.begin(), k_chunk_prefix_tokens.end());
+        }
+
+        const auto chunk_tokens = common_tokenize(ctx, text, false, false);
+        extra_tokens.insert(extra_tokens.end(), chunk_tokens.begin(), chunk_tokens.end());
+    }
+
+    if (llama_token_fim_sep(model) != LLAMA_TOKEN_NULL) {
+        // TODO: current filename
+        static const auto k_fim_file = common_tokenize(ctx, "filename\n", false, false);
+
+        extra_tokens.insert(extra_tokens.end(), llama_token_fim_sep(model));
+        extra_tokens.insert(extra_tokens.end(), k_fim_file.begin(), k_fim_file.end());
+    }
+
+    // for now pick FIM context to fit in a batch (ratio prefix:suffix = 3:1, TODO: configurable?)
+    const int n_prefix_take = std::min<int>(tokens_prefix.size(), 3*(n_batch/4));
+    const int n_suffix_take = std::min<int>(tokens_suffix.size(), std::max<int>(0, (n_batch/4) - (2 + tokens_prompt.size())));
+
+    SRV_DBG("n_prefix_take = %d, n_suffix_take = %d, total = %d\n", n_prefix_take, n_suffix_take, (n_prefix_take + n_suffix_take));
+
+    // fill the rest of the context with extra chunks
+    const int n_extra_take = std::min<int>(std::max<int>(0, n_ctx - (n_batch) - 2*n_predict), extra_tokens.size());
+
+    tokens_prefix.erase(tokens_prefix.begin(), tokens_prefix.begin() + tokens_prefix.size() - n_prefix_take);
+    tokens_suffix.resize(n_suffix_take);
+
+    tokens_prefix.insert(tokens_prefix.begin(), llama_token_fim_pre(model));
+    tokens_prefix.insert(tokens_prefix.end(), tokens_prompt.begin(), tokens_prompt.end());
+    tokens_suffix.insert(tokens_suffix.begin(), llama_token_fim_suf(model));
+
+    auto embd_inp = spm_infill ? tokens_suffix : tokens_prefix;
+    auto embd_end = spm_infill ? tokens_prefix : tokens_suffix;
+
+    if (llama_add_bos_token(model)) {
+        embd_inp.insert(embd_inp.begin(), llama_token_bos(model));
+    }
+
+    SRV_DBG("extra: n_ctx = %d, n_extra_take = %d, n_extra = %d\n", n_ctx, n_extra_take, (int) extra_tokens.size());
+
+    // put the extra context before the FIM prefix
+    embd_inp.insert(embd_inp.begin(), extra_tokens.end() - n_extra_take, extra_tokens.end());
+
+    embd_inp.insert(embd_inp.end(), embd_end.begin(), embd_end.end());
+    embd_inp.push_back(llama_token_fim_mid(model));
+
+    return embd_inp;
+}
+
 // Format given chat. If tmpl is empty, we take the template from model metadata
 inline std::string format_chat(const struct llama_model * model, const std::string & tmpl, const std::vector<json> & messages) {
     std::vector<common_chat_msg> chat;
@@ -195,18 +439,60 @@ static std::string gen_chatcmplid() {
 // other common utils
 //
 
-static size_t common_part(const std::vector<llama_token> & a, const std::vector<llama_token> & b) {
+static size_t longest_common_prefix(const llama_tokens & a, const llama_tokens & b) {
     size_t i;
     for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++) {}
 
     return i;
 }
 
-static size_t common_part(const std::string & a, const std::string & b) {
-    size_t i;
-    for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++) {}
-
-    return i;
+static size_t longest_common_subsequence(const llama_tokens & a, const llama_tokens & b) {
+    // check for empty sequences
+    if (a.empty() || b.empty()) {
+        return 0;
+    }
+
+    // get the lengths of the input sequences
+    int a_len = a.size();
+    int b_len = b.size();
+
+    // initialize the maximum length of the longest common subsequence (LCS)
+    int max_length = 0;
+
+    // use two rows instead of a 2D matrix to optimize space
+    std::vector<int> prev_row(b_len + 1, 0);
+    std::vector<int> curr_row(b_len + 1, 0);
+
+    // iterate through the elements of a
+    for (int i = 1; i <= a_len; i++) {
+        // iterate through the elements of b
+        for (int j = 1; j <= b_len; j++) {
+            // if elements at the current positions match
+            if (a[i - 1] == b[j - 1]) {
+                // if it's the first element of either sequences, set LCS length to 1
+                if (i == 1 || j == 1) {
+                    curr_row[j] = 1;
+                } else {
+                    // increment LCS length by 1 compared to the previous element
+                    curr_row[j] = prev_row[j - 1] + 1;
+                }
+
+                // update max_length if necessary
+                if (curr_row[j] > max_length) {
+                    max_length = curr_row[j];
+                }
+            } else {
+                // reset LCS length if elements don't match
+                curr_row[j] = 0;
+            }
+        }
+
+        // update the previous row for the next iteration
+        prev_row = curr_row;
+    }
+
+    // return the maximum length of the LCS
+    return max_length;
 }
 
 static bool ends_with(const std::string & str, const std::string & suffix) {
@@ -229,18 +515,6 @@ static size_t find_partial_stop_string(const std::string &stop, const std::strin
     return std::string::npos;
 }
 
-static bool json_is_array_of_numbers(const json & data) {
-    if (data.is_array()) {
-        for (const auto & e : data) {
-            if (!e.is_number()) {
-                return false;
-            }
-        }
-        return true;
-    }
-    return false;
-}
-
 // TODO: reuse llama_detokenize
 template <class Iter>
 static std::string tokens_to_str(llama_context * ctx, Iter begin, Iter end) {
@@ -360,9 +634,9 @@ static json oaicompat_completion_params_parse(
 
     // Handle "logprobs" field
     // TODO: The response format of this option is not yet OAI-compatible, but seems like no one really using it; We may need to fix it in the future
-    if (body.contains("logprobs")) {
+    if (json_value(body, "logprobs", false)) {
         llama_params["n_probs"] = json_value(body, "top_logprobs", 20);
-    } else if (body.contains("top_logprobs")) {
+    } else if (body.contains("top_logprobs") && !body.at("top_logprobs").is_null()) {
         throw std::runtime_error("top_logprobs requires logprobs to be set to true");
     }
 
@@ -375,7 +649,7 @@ static json oaicompat_completion_params_parse(
     }
 
     // Copy remaining properties to llama_params
-    // This allows user to use llama.cpp-specific params like "mirostat", "tfs_z",... via OAI endpoint.
+    // This allows user to use llama.cpp-specific params like "mirostat", ... via OAI endpoint.
     // See "launch_slot_with_task()" for a complete list of params supported by llama.cpp
     for (const auto & item : body.items()) {
         // Exception: if "n_predict" is present, we overwrite the value specified earlier by "max_tokens"
@@ -138,7 +138,7 @@ int main(int argc, char ** argv) {
 
     // prepare a batch for the prompt
 
-    llama_batch batch = llama_batch_get_one(prompt_tokens.data(), prompt_tokens.size(), 0, 0);
+    llama_batch batch = llama_batch_get_one(prompt_tokens.data(), prompt_tokens.size());
 
     // main loop
 
@@ -175,7 +175,7 @@ int main(int argc, char ** argv) {
             fflush(stdout);
 
             // prepare the next batch with the sampled token
-            batch = llama_batch_get_one(&new_token_id, 1, n_pos, 0);
+            batch = llama_batch_get_one(&new_token_id, 1);
 
             n_decode += 1;
         }
@@ -39,6 +39,11 @@ int main(int argc, char ** argv) {
         return 1;
     }
 
+    if (params.n_predict < -1) {
+        LOG_ERR("%s: --n-predict must be >= -1\n", __func__);
+        return 1;
+    }
+
     common_init();
 
     if (params.model_draft.empty()) {
@@ -155,9 +160,9 @@ int main(int argc, char ** argv) {
     const auto t_enc_start = ggml_time_us();
 
     // eval the prompt with both models
-    llama_decode(ctx_tgt, llama_batch_get_one( inp.data(), n_input - 1, 0, 0));
-    llama_decode(ctx_tgt, llama_batch_get_one(&inp.back(), 1, n_input - 1, 0));
-    llama_decode(ctx_dft, llama_batch_get_one( inp.data(), n_input, 0, 0));
+    llama_decode(ctx_tgt, llama_batch_get_one( inp.data(), n_input - 1));
+    llama_decode(ctx_tgt, llama_batch_get_one(&inp.back(), 1));
+    llama_decode(ctx_dft, llama_batch_get_one( inp.data(), n_input));
 
     const auto t_enc_end = ggml_time_us();
 
@@ -180,8 +185,6 @@ int main(int argc, char ** argv) {
     // target model sampling context (reuse the llama_context's sampling instance)
     struct common_sampler * smpl = common_sampler_init(model_tgt, params.sparams);
 
-    struct llama_sampler * softmax = llama_sampler_init_softmax();
-
     // draft sequence data
     std::vector<seq_draft> drafts(n_seq_dft);
 
@@ -190,8 +193,8 @@ int main(int argc, char ** argv) {
         drafts[s].smpl = common_sampler_init(model_dft, params.sparams);
     }
 
-    llama_batch batch_dft = llama_batch_init(params.n_ctx, 0, 1);
-    llama_batch batch_tgt = llama_batch_init(params.n_ctx, 0, n_seq_dft);
+    llama_batch batch_dft = llama_batch_init(llama_n_batch(ctx_dft), 0, 1);
+    llama_batch batch_tgt = llama_batch_init(llama_n_batch(ctx_tgt), 0, n_seq_dft);
 
     const auto t_dec_start = ggml_time_us();
 
@@ -441,7 +444,7 @@ int main(int argc, char ** argv) {
             ++n_past_dft;
         }
 
-        if (n_predict > params.n_predict || has_eos) {
+        if ((params.n_predict >= 0 && n_predict > params.n_predict) || has_eos) {
             break;
         }
 
@@ -624,7 +627,6 @@ int main(int argc, char ** argv) {
         common_sampler_free(drafts[s].smpl);
     }
 
-    llama_sampler_free(softmax);
     llama_batch_free(batch_dft);
 
     llama_free(ctx_tgt);
@@ -20,11 +20,11 @@
     },
     "nixpkgs": {
       "locked": {
-        "lastModified": 1728018373,
-        "narHash": "sha256-NOiTvBbRLIOe5F6RbHaAh6++BNjsb149fGZd1T4+KBg=",
+        "lastModified": 1729665710,
+        "narHash": "sha256-AlcmCXJZPIlO5dmFzV3V2XF6x/OpNWUV8Y/FMPGd8Z4=",
         "owner": "NixOS",
         "repo": "nixpkgs",
-        "rev": "bc947f541ae55e999ffdb4013441347d83b00feb",
+        "rev": "2768c7d042a37de65bb1b5b3268fc987e534c49d",
         "type": "github"
       },
       "original": {
@@ -99,6 +99,9 @@ option(GGML_AVX512 "ggml: enable AVX512" OFF)
 option(GGML_AVX512_VBMI "ggml: enable AVX512-VBMI" OFF)
 option(GGML_AVX512_VNNI "ggml: enable AVX512-VNNI" OFF)
 option(GGML_AVX512_BF16 "ggml: enable AVX512-BF16" OFF)
+option(GGML_AMX_TILE "ggml: enable AMX-TILE" OFF)
+option(GGML_AMX_INT8 "ggml: enable AMX-INT8" OFF)
+option(GGML_AMX_BF16 "ggml: enable AMX-BF16" OFF)
 option(GGML_FMA "ggml: enable FMA" ${INS_ENB})
 if (NOT MSVC)
     option(GGML_F16C "ggml: enable F16C" ${INS_ENB}) # in MSVC F16C is implied with AVX2/AVX512
@@ -158,6 +161,7 @@ set (GGML_METAL_MACOSX_VERSION_MIN "" CACHE STRING
 set (GGML_METAL_STD "" CACHE STRING "ggml: metal standard version (-std flag)")
 option(GGML_OPENMP "ggml: use OpenMP" ON)
 option(GGML_RPC "ggml: use RPC" OFF)
+option(GGML_AMX "ggml: use AMX" OFF)
 option(GGML_SYCL "ggml: use SYCL" OFF)
 option(GGML_SYCL_F16 "ggml: use 16 bit floats for sycl calculations" OFF)
 set (GGML_SYCL_TARGET "INTEL" CACHE STRING

ggml/include/ggml-amx.h (new file, 25 lines)
@@ -0,0 +1,25 @@
+#pragma once
+
+#include "ggml.h"
+#include "ggml-backend.h"
+
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+// buffer_type API
+GGML_API ggml_backend_buffer_type_t ggml_backend_amx_buffer_type(void);
+
+GGML_API bool ggml_backend_is_amx(ggml_backend_t backend);
+
+// backend API
+GGML_API ggml_backend_t ggml_backend_amx_init(void);
+
+GGML_API void ggml_backend_amx_set_n_threads(ggml_backend_t backend_amx, int n_threads);
+
+GGML_API ggml_backend_reg_t ggml_backend_amx_reg(void);
+
+#ifdef __cplusplus
+}
+#endif
#endif
|
@ -114,11 +114,12 @@ extern "C" {
|
|||||||
//
|
//
|
||||||
|
|
||||||
enum ggml_backend_dev_type {
|
enum ggml_backend_dev_type {
|
||||||
|
// CPU device using system memory
|
||||||
GGML_BACKEND_DEVICE_TYPE_CPU,
|
GGML_BACKEND_DEVICE_TYPE_CPU,
|
||||||
|
// GPU device using dedicated memory
|
||||||
GGML_BACKEND_DEVICE_TYPE_GPU,
|
GGML_BACKEND_DEVICE_TYPE_GPU,
|
||||||
// devices with full capabilities (excludes backends such as BLAS that only support matrix multiplication)
|
// accelerator devices intended to be used together with the CPU backend (e.g. BLAS or AMX)
|
||||||
GGML_BACKEND_DEVICE_TYPE_CPU_FULL,
|
GGML_BACKEND_DEVICE_TYPE_ACCEL
|
||||||
GGML_BACKEND_DEVICE_TYPE_GPU_FULL
|
|
||||||
};
|
};
|
||||||
|
|
||||||
// functionality supported by the device
|
// functionality supported by the device
|
||||||
@ -167,10 +168,14 @@ extern "C" {
|
|||||||
GGML_API ggml_backend_dev_t ggml_backend_reg_dev_get(ggml_backend_reg_t reg, size_t index);
|
GGML_API ggml_backend_dev_t ggml_backend_reg_dev_get(ggml_backend_reg_t reg, size_t index);
|
||||||
GGML_API void * ggml_backend_reg_get_proc_address(ggml_backend_reg_t reg, const char * name);
|
GGML_API void * ggml_backend_reg_get_proc_address(ggml_backend_reg_t reg, const char * name);
|
||||||
|
|
||||||
|
// Common functions that may be obtained using ggml_backend_reg_get_proc_address
|
||||||
|
|
||||||
// Functions that may be obtained using ggml_backend_reg_get_proc_address
|
// Split buffer type for tensor parallelism
|
||||||
typedef ggml_backend_buffer_type_t (*ggml_backend_split_buffer_type_t)(const float *);
|
typedef ggml_backend_buffer_type_t (*ggml_backend_split_buffer_type_t)(int main_device, const float * tensor_split);
|
||||||
typedef void (*ggml_backend_set_n_threads_t)(ggml_backend_t, int);
|
// Set the number of threads for the backend
|
||||||
|
typedef void (*ggml_backend_set_n_threads_t)(ggml_backend_t backend, int n_threads);
|
||||||
|
// Get additional buffer types provided by the device (returns a NULL-terminated array)
|
||||||
|
typedef ggml_backend_buffer_type_t * (*ggml_backend_dev_get_extra_bufts_t)(ggml_backend_dev_t device);
|
||||||
|
|
||||||
//
|
//
|
||||||
// Backend registry
|
// Backend registry
|
||||||
@ -192,7 +197,7 @@ extern "C" {
|
|||||||
GGML_API ggml_backend_t ggml_backend_init_by_name(const char * name, const char * params);
|
GGML_API ggml_backend_t ggml_backend_init_by_name(const char * name, const char * params);
|
||||||
// = ggml_backend_dev_init(ggml_backend_dev_by_type(type), params)
|
// = ggml_backend_dev_init(ggml_backend_dev_by_type(type), params)
|
||||||
GGML_API ggml_backend_t ggml_backend_init_by_type(enum ggml_backend_dev_type type, const char * params);
|
GGML_API ggml_backend_t ggml_backend_init_by_type(enum ggml_backend_dev_type type, const char * params);
|
||||||
// = ggml_backend_dev_init(ggml_backend_dev_by_type(GPU_FULL) OR ggml_backend_dev_by_type(CPU_FULL), NULL)
|
// = ggml_backend_dev_init(ggml_backend_dev_by_type(GPU) OR ggml_backend_dev_by_type(CPU), NULL)
|
||||||
GGML_API ggml_backend_t ggml_backend_init_best(void);
|
GGML_API ggml_backend_t ggml_backend_init_best(void);
|
||||||
|
|
||||||
//
|
//
|
@@ -34,6 +34,8 @@ extern "C" {
  */
 #define GGML_CANN_MAX_DEVICES 16
 
+GGML_API ggml_backend_reg_t ggml_backend_cann_reg(void);
+
 /**
  * @brief Initializes the CANN backend for a specified device.
  *
@@ -28,7 +28,7 @@ GGML_API bool ggml_backend_is_cuda(ggml_backend_t backend);
 GGML_API ggml_backend_buffer_type_t ggml_backend_cuda_buffer_type(int device);
 
 // split tensor buffer that splits matrices by rows across multiple devices
-GGML_API ggml_backend_buffer_type_t ggml_backend_cuda_split_buffer_type(const float * tensor_split);
+GGML_API ggml_backend_buffer_type_t ggml_backend_cuda_split_buffer_type(int main_device, const float * tensor_split);
 
 // pinned host buffer for use with the CPU backend for faster copies between CPU and GPU
 GGML_API ggml_backend_buffer_type_t ggml_backend_cuda_host_buffer_type(void);
@@ -11,6 +11,8 @@
 extern "C" {
 #endif
 
+#define GGML_KOMPUTE_MAX_DEVICES 16
+
 struct ggml_vk_device {
     int index;
     int type; // same as VkPhysicalDeviceType
@@ -41,6 +43,8 @@ GGML_API bool ggml_backend_is_kompute(ggml_backend_t backend);
 
 GGML_API ggml_backend_buffer_type_t ggml_backend_kompute_buffer_type(int device);
 
+GGML_API ggml_backend_reg_t ggml_backend_kompute_reg(void);
+
 #ifdef __cplusplus
 }
 #endif
@@ -19,6 +19,8 @@ extern "C" {
 // backend API
 GGML_API ggml_backend_t ggml_backend_sycl_init(int device);
 
+GGML_API bool ggml_backend_is_sycl(ggml_backend_t backend);
+
 // devide buffer
 GGML_API ggml_backend_buffer_type_t ggml_backend_sycl_buffer_type(int device);
 
@@ -29,14 +31,19 @@ GGML_API ggml_backend_buffer_type_t ggml_backend_sycl_split_buffer_type(const fl
 GGML_API ggml_backend_buffer_type_t ggml_backend_sycl_host_buffer_type(void);
 
 GGML_API void ggml_backend_sycl_print_sycl_devices(void);
-GGML_API void ggml_sycl_get_gpu_list(int *id_list, int max_len);
-GGML_API void ggml_sycl_get_device_description(int device, char *description, size_t description_size);
+GGML_API void ggml_backend_sycl_get_gpu_list(int *id_list, int max_len);
+GGML_API void ggml_backend_sycl_get_device_description(int device,
+                                                       char *description,
+                                                       size_t description_size);
 GGML_API int ggml_backend_sycl_get_device_count();
 GGML_API void ggml_backend_sycl_get_device_memory(int device, size_t *free, size_t *total);
 
 // SYCL doesn't support registering host memory, keep here for reference
 // GGML_API bool ggml_backend_sycl_register_host_buffer(void * buffer, size_t size);
 // GGML_API void ggml_backend_sycl_unregister_host_buffer(void * buffer);
 
+GGML_API ggml_backend_reg_t ggml_backend_sycl_reg(void);
+
 #ifdef __cplusplus
 }
 #endif
@@ -24,6 +24,8 @@ GGML_API ggml_backend_buffer_type_t ggml_backend_vk_buffer_type(size_t dev_num);
 // pinned host buffer for use with the CPU backend for faster copies between CPU and GPU
 GGML_API ggml_backend_buffer_type_t ggml_backend_vk_host_buffer_type(void);
 
+GGML_API ggml_backend_reg_t ggml_backend_vk_reg(void);
+
 #ifdef __cplusplus
 }
 #endif
@@ -217,7 +217,6 @@
 
 #define GGML_MAX_DIMS 4
 #define GGML_MAX_PARAMS 2048
-#define GGML_MAX_CONTEXTS 64
 #define GGML_MAX_SRC 10
 #define GGML_MAX_N_THREADS 512
 #define GGML_MAX_OP_PARAMS 64
@@ -656,13 +655,6 @@ extern "C" {
         void * abort_callback_data;
     };
 
-    // scratch buffer
-    struct ggml_scratch {
-        size_t offs;
-        size_t size;
-        void * data;
-    };
-
     struct ggml_init_params {
         // memory pool
         size_t mem_size; // bytes
@@ -760,12 +752,12 @@ extern "C" {
 
     // main
 
-    GGML_API struct ggml_context * ggml_init(struct ggml_init_params params);
-    GGML_API void ggml_free(struct ggml_context * ctx);
+    GGML_API struct ggml_context * ggml_init (struct ggml_init_params params);
+    GGML_API void ggml_reset(struct ggml_context * ctx);
+    GGML_API void ggml_free (struct ggml_context * ctx);
 
     GGML_API size_t ggml_used_mem(const struct ggml_context * ctx);
 
-    GGML_API size_t ggml_set_scratch (struct ggml_context * ctx, struct ggml_scratch scratch);
     GGML_API bool ggml_get_no_alloc(struct ggml_context * ctx);
     GGML_API void ggml_set_no_alloc(struct ggml_context * ctx, bool no_alloc);
 
@@ -2490,6 +2482,7 @@ extern "C" {
     GGML_API int ggml_cpu_has_avx512_vbmi(void);
     GGML_API int ggml_cpu_has_avx512_vnni(void);
     GGML_API int ggml_cpu_has_avx512_bf16(void);
+    GGML_API int ggml_cpu_has_amx_int8 (void);
     GGML_API int ggml_cpu_has_fma (void);
     GGML_API int ggml_cpu_has_neon (void);
     GGML_API int ggml_cpu_has_sve (void);
@@ -267,6 +267,26 @@ if (GGML_LLAMAFILE)
     set(GGML_SOURCES_LLAMAFILE llamafile/sgemm.cpp)
 endif()
 
+if (GGML_AMX)
+    if (CMAKE_COMPILER_IS_GNUCC AND CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 11.0)
+    else()
+        set(GGML_AMX OFF)
+        message(WARNING "AMX requires gcc version > 11.0. Turning off GGML_AMX.")
+    endif()
+
+    if (GGML_AMX)
+        message(STATUS "Using AMX")
+
+        list(APPEND GGML_CDEF_PUBLIC GGML_USE_AMX)
+
+        file(GLOB   GGML_HEADERS_AMX "ggml-amx/*.h")
+        list(APPEND GGML_HEADERS_AMX "../include/ggml-amx.h")
+
+        file(GLOB   GGML_SOURCES_AMX "ggml-amx/*.cpp")
+        list(APPEND GGML_SOURCES_AMX "ggml-amx.cpp")
+    endif()
+endif()
+
 if (GGML_CUDA)
     cmake_minimum_required(VERSION 3.18) # for CMAKE_CUDA_ARCHITECTURES
 
@@ -780,6 +800,7 @@ if (GGML_KOMPUTE)
         kompute-shaders/op_mul_mat_q8_0.comp
         kompute-shaders/op_mul_mat_q4_0.comp
         kompute-shaders/op_mul_mat_q4_1.comp
+        kompute-shaders/op_mul_mat_q4_k.comp
         kompute-shaders/op_mul_mat_q6_k.comp
         kompute-shaders/op_getrows_f32.comp
         kompute-shaders/op_getrows_f16.comp
@@ -813,6 +834,7 @@ if (GGML_KOMPUTE)
         shaderop_mul_mat_q8_0.h
         shaderop_mul_mat_q4_0.h
         shaderop_mul_mat_q4_1.h
+        shaderop_mul_mat_q4_k.h
         shaderop_mul_mat_q6_k.h
         shaderop_getrows_f32.h
         shaderop_getrows_f16.h
@@ -1180,6 +1202,18 @@ elseif (CMAKE_OSX_ARCHITECTURES STREQUAL "x86_64" OR CMAKE_GENERATOR_PLATFORM_LW
             add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AVX512BF16__>)
             add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AVX512BF16__>)
         endif()
+        if (GGML_AMX_TILE)
+            add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AMX_TILE__>)
+            add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AMX_TILE__>)
+        endif()
+        if (GGML_AMX_INT8)
+            add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AMX_INT8__>)
+            add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AMX_INT8__>)
+        endif()
+        if (GGML_AMX_BF16)
+            add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AMX_BF16__>)
+            add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AMX_BF16__>)
+        endif()
     elseif (GGML_AVX2)
         list(APPEND ARCH_FLAGS /arch:AVX2)
     elseif (GGML_AVX)
@@ -1215,6 +1249,15 @@ elseif (CMAKE_OSX_ARCHITECTURES STREQUAL "x86_64" OR CMAKE_GENERATOR_PLATFORM_LW
         if (GGML_AVX512_BF16)
             list(APPEND ARCH_FLAGS -mavx512bf16)
         endif()
+        if (GGML_AMX_TILE)
+            list(APPEND ARCH_FLAGS -mamx-tile)
+        endif()
+        if (GGML_AMX_INT8)
+            list(APPEND ARCH_FLAGS -mamx-int8)
+        endif()
+        if (GGML_AMX_BF16)
+            list(APPEND ARCH_FLAGS -mamx-bf16)
+        endif()
     endif()
 elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64")
     message(STATUS "PowerPC detected")
@@ -1340,6 +1383,7 @@ add_library(ggml
     ${GGML_SOURCES_ROCM} ${GGML_HEADERS_ROCM}
     ${GGML_SOURCES_BLAS} ${GGML_HEADERS_BLAS}
     ${GGML_SOURCES_LLAMAFILE} ${GGML_HEADERS_LLAMAFILE}
+    ${GGML_SOURCES_AMX} ${GGML_HEADERS_AMX}
     ${GGML_SOURCES_CANN} ${GGML_HEADERS_CANN}
     ggml-aarch64.c ggml-aarch64.h
     )
@@ -1358,7 +1402,7 @@ list(APPEND GGML_EXTRA_LIBS_PRIVATE Threads::Threads)
 
 find_library(MATH_LIBRARY m)
 if (MATH_LIBRARY)
-    if (NOT WIN32 OR NOT GGML_SYCL)
+    if (NOT WIN32 OR NOT DEFINED ENV{ONEAPI_ROOT})
         list(APPEND GGML_EXTRA_LIBS_PRIVATE m)
     endif()
 endif()
@@ -991,6 +991,73 @@ void ggml_gemv_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
             }
         }
     }
     return;
+#elif defined(__riscv_v_intrinsic)
+    if (__riscv_vlenb() >= QK4_0) {
+        const size_t vl = QK4_0;
+
+        const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
+        for (int x = 0; x < nc / ncols_interleaved; x++) {
+            const block_q4_0x8 * b_ptr = (const block_q4_0x8 *) vx + (x * nb);
+
+            vfloat32m1_t sumf = __riscv_vfmv_v_f_f32m1(0.0, vl / 4);
+            for (int l = 0; l < nb; l++) {
+                const int64_t a0 = *(const int64_t *)&a_ptr[l].qs[0];
+                const int64_t a1 = *(const int64_t *)&a_ptr[l].qs[8];
+                const int64_t a2 = *(const int64_t *)&a_ptr[l].qs[16];
+                const int64_t a3 = *(const int64_t *)&a_ptr[l].qs[24];
+                __asm__ __volatile__("" ::: "memory"); // prevent gcc from emitting fused vlse64, violating alignment
+                const vint8m2_t lhs_0_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a0, vl / 4));
+                const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a1, vl / 4));
+                const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a2, vl / 4));
+                const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a3, vl / 4));
+
+                const vint8m4_t rhs_raw_vec = __riscv_vle8_v_i8m4((const int8_t *)b_ptr[l].qs, vl * 4);
+                const vint8m4_t rhs_vec_lo = __riscv_vsra_vx_i8m4(__riscv_vsll_vx_i8m4(rhs_raw_vec, 4, vl * 4), 4, vl * 4);
+                const vint8m4_t rhs_vec_hi = __riscv_vsra_vx_i8m4(rhs_raw_vec, 4, vl * 4);
+                const vint8m2_t rhs_vec_lo_0 = __riscv_vget_v_i8m4_i8m2(rhs_vec_lo, 0);
+                const vint8m2_t rhs_vec_lo_1 = __riscv_vget_v_i8m4_i8m2(rhs_vec_lo, 1);
+                const vint8m2_t rhs_vec_hi_0 = __riscv_vget_v_i8m4_i8m2(rhs_vec_hi, 0);
+                const vint8m2_t rhs_vec_hi_1 = __riscv_vget_v_i8m4_i8m2(rhs_vec_hi, 1);
+
+                const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(rhs_vec_lo_0, lhs_0_8, vl * 2);
+                const vint16m4_t sumi_lo_1 = __riscv_vwmacc_vv_i16m4(sumi_lo_0, rhs_vec_lo_1, lhs_1_8, vl * 2);
+                const vint16m4_t sumi_hi_0 = __riscv_vwmacc_vv_i16m4(sumi_lo_1, rhs_vec_hi_0, lhs_2_8, vl * 2);
+                const vint16m4_t sumi_hi_m = __riscv_vwmacc_vv_i16m4(sumi_hi_0, rhs_vec_hi_1, lhs_3_8, vl * 2);
+
+                const vuint32m4_t sumi_i32 = __riscv_vreinterpret_v_i32m4_u32m4(__riscv_vreinterpret_v_i16m4_i32m4(sumi_hi_m));
+                const vuint16m2_t sumi_h2_0 = __riscv_vnsrl_wx_u16m2(sumi_i32, 0, vl);
+                const vuint16m2_t sumi_h2_1 = __riscv_vnsrl_wx_u16m2(sumi_i32, 16, vl);
+                const vuint16m2_t sumi_h2 = __riscv_vadd_vv_u16m2(sumi_h2_0, sumi_h2_1, vl);
+                const vuint32m2_t sumi_h2_i32 = __riscv_vreinterpret_v_u16m2_u32m2(sumi_h2);
+                const vuint16m1_t sumi_h4_0 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 0, vl / 2);
+                const vuint16m1_t sumi_h4_1 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 16, vl / 2);
+                const vuint16m1_t sumi_h4 = __riscv_vadd_vv_u16m1(sumi_h4_0, sumi_h4_1, vl / 2);
+                const vuint32m1_t sumi_h4_i32 = __riscv_vreinterpret_v_u16m1_u32m1(sumi_h4);
+                const vint16mf2_t sumi_h8_0 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 0, vl / 4));
+                const vint16mf2_t sumi_h8_1 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 16, vl / 4));
+                const vint32m1_t sumi_h8 = __riscv_vwadd_vv_i32m1(sumi_h8_0, sumi_h8_1, vl / 4);
+                const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(sumi_h8, vl / 4);
+
+                // vector version needs Zvfhmin extension
+                const float a_scale = GGML_FP16_TO_FP32(a_ptr[l].d);
+                const float b_scales[8] = {
+                    GGML_FP16_TO_FP32(b_ptr[l].d[0]),
+                    GGML_FP16_TO_FP32(b_ptr[l].d[1]),
+                    GGML_FP16_TO_FP32(b_ptr[l].d[2]),
+                    GGML_FP16_TO_FP32(b_ptr[l].d[3]),
+                    GGML_FP16_TO_FP32(b_ptr[l].d[4]),
+                    GGML_FP16_TO_FP32(b_ptr[l].d[5]),
+                    GGML_FP16_TO_FP32(b_ptr[l].d[6]),
+                    GGML_FP16_TO_FP32(b_ptr[l].d[7])
+                };
+                const vfloat32m1_t b_scales_vec = __riscv_vle32_v_f32m1(b_scales, vl / 4);
+                const vfloat32m1_t tmp1 = __riscv_vfmul_vf_f32m1(facc, a_scale, vl / 4);
+                sumf = __riscv_vfmacc_vv_f32m1(sumf, tmp1, b_scales_vec, vl / 4);
+            }
+            __riscv_vse32_v_f32m1(s + x * ncols_interleaved, sumf, vl / 4);
+        }
+        return;
+    }
 #endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__)
     {
         float sumf[8];
@@ -3171,6 +3238,207 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
             }
         }
     }
+        return;
+    }
+#elif defined(__riscv_v_intrinsic)
+    if (__riscv_vlenb() >= QK4_0) {
+        const size_t vl = QK4_0;
+
+        for (int y = 0; y < nr / 4; y++) {
+            const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
+            for (int x = 0; x < nc / ncols_interleaved; x++) {
+                const block_q4_0x8 * b_ptr = (const block_q4_0x8 *) vx + (x * nb);
+                vfloat32m1_t sumf0 = __riscv_vfmv_v_f_f32m1(0.0, vl / 4);
+                vfloat32m1_t sumf1 = __riscv_vfmv_v_f_f32m1(0.0, vl / 4);
+                vfloat32m1_t sumf2 = __riscv_vfmv_v_f_f32m1(0.0, vl / 4);
+                vfloat32m1_t sumf3 = __riscv_vfmv_v_f_f32m1(0.0, vl / 4);
+                for (int l = 0; l < nb; l++) {
+                    const vint8m4_t rhs_raw_vec = __riscv_vle8_v_i8m4((const int8_t *)b_ptr[l].qs, vl * 4);
+                    const vint8m4_t rhs_vec_lo = __riscv_vsra_vx_i8m4(__riscv_vsll_vx_i8m4(rhs_raw_vec, 4, vl * 4), 4, vl * 4);
+                    const vint8m4_t rhs_vec_hi = __riscv_vsra_vx_i8m4(rhs_raw_vec, 4, vl * 4);
+                    const vint8m2_t rhs_vec_lo_0 = __riscv_vget_v_i8m4_i8m2(rhs_vec_lo, 0);
+                    const vint8m2_t rhs_vec_lo_1 = __riscv_vget_v_i8m4_i8m2(rhs_vec_lo, 1);
+                    const vint8m2_t rhs_vec_hi_0 = __riscv_vget_v_i8m4_i8m2(rhs_vec_hi, 0);
+                    const vint8m2_t rhs_vec_hi_1 = __riscv_vget_v_i8m4_i8m2(rhs_vec_hi, 1);
+
+                    // vector version needs Zvfhmin extension
+                    const float a_scales[4] = {
+                        GGML_FP16_TO_FP32(a_ptr[l].d[0]),
+                        GGML_FP16_TO_FP32(a_ptr[l].d[1]),
+                        GGML_FP16_TO_FP32(a_ptr[l].d[2]),
+                        GGML_FP16_TO_FP32(a_ptr[l].d[3])
+                    };
+                    const float b_scales[8] = {
+                        GGML_FP16_TO_FP32(b_ptr[l].d[0]),
+                        GGML_FP16_TO_FP32(b_ptr[l].d[1]),
+                        GGML_FP16_TO_FP32(b_ptr[l].d[2]),
+                        GGML_FP16_TO_FP32(b_ptr[l].d[3]),
+                        GGML_FP16_TO_FP32(b_ptr[l].d[4]),
+                        GGML_FP16_TO_FP32(b_ptr[l].d[5]),
+                        GGML_FP16_TO_FP32(b_ptr[l].d[6]),
+                        GGML_FP16_TO_FP32(b_ptr[l].d[7])
+                    };
+                    const vfloat32m1_t b_scales_vec = __riscv_vle32_v_f32m1(b_scales, vl / 4);
+
+                    const int64_t A0 = *(const int64_t *)&a_ptr[l].qs[0];
+                    const int64_t A4 = *(const int64_t *)&a_ptr[l].qs[32];
+                    const int64_t A8 = *(const int64_t *)&a_ptr[l].qs[64];
+                    const int64_t Ac = *(const int64_t *)&a_ptr[l].qs[96];
+                    __asm__ __volatile__("" ::: "memory"); // prevent gcc from emitting fused vlse64, violating alignment
+                    vint16m4_t sumi_l0;
+                    {
+                        const vint8m2_t lhs_0_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A0, vl / 4));
+                        const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A4, vl / 4));
+                        const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A8, vl / 4));
+                        const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Ac, vl / 4));
+                        const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(rhs_vec_lo_0, lhs_0_8, vl * 2);
+                        const vint16m4_t sumi_lo_1 = __riscv_vwmacc_vv_i16m4(sumi_lo_0, rhs_vec_lo_1, lhs_1_8, vl * 2);
+                        const vint16m4_t sumi_hi_0 = __riscv_vwmacc_vv_i16m4(sumi_lo_1, rhs_vec_hi_0, lhs_2_8, vl * 2);
+                        const vint16m4_t sumi_hi_m = __riscv_vwmacc_vv_i16m4(sumi_hi_0, rhs_vec_hi_1, lhs_3_8, vl * 2);
+
+                        sumi_l0 = sumi_hi_m;
+                    }
+
+                    {
+                        const vuint32m4_t sumi_i32 = __riscv_vreinterpret_v_i32m4_u32m4(__riscv_vreinterpret_v_i16m4_i32m4(sumi_l0));
+                        const vuint16m2_t sumi_h2_0 = __riscv_vnsrl_wx_u16m2(sumi_i32, 0, vl);
+                        const vuint16m2_t sumi_h2_1 = __riscv_vnsrl_wx_u16m2(sumi_i32, 16, vl);
+                        const vuint16m2_t sumi_h2 = __riscv_vadd_vv_u16m2(sumi_h2_0, sumi_h2_1, vl);
+                        const vuint32m2_t sumi_h2_i32 = __riscv_vreinterpret_v_u16m2_u32m2(sumi_h2);
+                        const vuint16m1_t sumi_h4_0 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 0, vl / 2);
+                        const vuint16m1_t sumi_h4_1 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 16, vl / 2);
+                        const vuint16m1_t sumi_h4 = __riscv_vadd_vv_u16m1(sumi_h4_0, sumi_h4_1, vl / 2);
+                        const vuint32m1_t sumi_h4_i32 = __riscv_vreinterpret_v_u16m1_u32m1(sumi_h4);
+                        const vint16mf2_t sumi_h8_0 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 0, vl / 4));
+                        const vint16mf2_t sumi_h8_1 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 16, vl / 4));
+                        const vint32m1_t sumi_h8 = __riscv_vwadd_vv_i32m1(sumi_h8_0, sumi_h8_1, vl / 4);
+                        const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(sumi_h8, vl / 4);
+
+                        const vfloat32m1_t tmp1 = __riscv_vfmul_vf_f32m1(facc, a_scales[0], vl / 4);
+                        sumf0 = __riscv_vfmacc_vv_f32m1(sumf0, tmp1, b_scales_vec, vl / 4);
+                    }
+
+                    const int64_t A1 = *(const int64_t *)&a_ptr[l].qs[8];
+                    const int64_t A5 = *(const int64_t *)&a_ptr[l].qs[40];
+                    const int64_t A9 = *(const int64_t *)&a_ptr[l].qs[72];
+                    const int64_t Ad = *(const int64_t *)&a_ptr[l].qs[104];
+                    __asm__ __volatile__("" ::: "memory"); // prevent gcc from emitting fused vlse64, violating alignment
+                    vint16m4_t sumi_l1;
+                    {
+                        const vint8m2_t lhs_0_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A1, vl / 4));
+                        const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A5, vl / 4));
+                        const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A9, vl / 4));
+                        const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Ad, vl / 4));
+                        const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(rhs_vec_lo_0, lhs_0_8, vl * 2);
+                        const vint16m4_t sumi_lo_1 = __riscv_vwmacc_vv_i16m4(sumi_lo_0, rhs_vec_lo_1, lhs_1_8, vl * 2);
+                        const vint16m4_t sumi_hi_0 = __riscv_vwmacc_vv_i16m4(sumi_lo_1, rhs_vec_hi_0, lhs_2_8, vl * 2);
+                        const vint16m4_t sumi_hi_m = __riscv_vwmacc_vv_i16m4(sumi_hi_0, rhs_vec_hi_1, lhs_3_8, vl * 2);
+
+                        sumi_l1 = sumi_hi_m;
+                    }
+
+                    {
+                        const vuint32m4_t sumi_i32 = __riscv_vreinterpret_v_i32m4_u32m4(__riscv_vreinterpret_v_i16m4_i32m4(sumi_l1));
+                        const vuint16m2_t sumi_h2_0 = __riscv_vnsrl_wx_u16m2(sumi_i32, 0, vl);
+                        const vuint16m2_t sumi_h2_1 = __riscv_vnsrl_wx_u16m2(sumi_i32, 16, vl);
+                        const vuint16m2_t sumi_h2 = __riscv_vadd_vv_u16m2(sumi_h2_0, sumi_h2_1, vl);
+                        const vuint32m2_t sumi_h2_i32 = __riscv_vreinterpret_v_u16m2_u32m2(sumi_h2);
+                        const vuint16m1_t sumi_h4_0 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 0, vl / 2);
+                        const vuint16m1_t sumi_h4_1 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 16, vl / 2);
+                        const vuint16m1_t sumi_h4 = __riscv_vadd_vv_u16m1(sumi_h4_0, sumi_h4_1, vl / 2);
+                        const vuint32m1_t sumi_h4_i32 = __riscv_vreinterpret_v_u16m1_u32m1(sumi_h4);
+                        const vint16mf2_t sumi_h8_0 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 0, vl / 4));
+                        const vint16mf2_t sumi_h8_1 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 16, vl / 4));
+                        const vint32m1_t sumi_h8 = __riscv_vwadd_vv_i32m1(sumi_h8_0, sumi_h8_1, vl / 4);
+                        const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(sumi_h8, vl / 4);
+
+                        const vfloat32m1_t tmp1 = __riscv_vfmul_vf_f32m1(facc, a_scales[1], vl / 4);
+                        sumf1 = __riscv_vfmacc_vv_f32m1(sumf1, tmp1, b_scales_vec, vl / 4);
+                    }
+
+                    const int64_t A2 = *(const int64_t *)&a_ptr[l].qs[16];
+                    const int64_t A6 = *(const int64_t *)&a_ptr[l].qs[48];
+                    const int64_t Aa = *(const int64_t *)&a_ptr[l].qs[80];
+                    const int64_t Ae = *(const int64_t *)&a_ptr[l].qs[112];
+                    __asm__ __volatile__("" ::: "memory"); // prevent gcc from emitting fused vlse64, violating alignment
+                    vint16m4_t sumi_l2;
+                    {
+                        const vint8m2_t lhs_0_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A2, vl / 4));
+                        const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A6, vl / 4));
+                        const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Aa, vl / 4));
+                        const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Ae, vl / 4));
+                        const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(rhs_vec_lo_0, lhs_0_8, vl * 2);
|
||||||
|
const vint16m4_t sumi_lo_1 = __riscv_vwmacc_vv_i16m4(sumi_lo_0, rhs_vec_lo_1, lhs_1_8, vl * 2);
|
||||||
|
const vint16m4_t sumi_hi_0 = __riscv_vwmacc_vv_i16m4(sumi_lo_1, rhs_vec_hi_0, lhs_2_8, vl * 2);
|
||||||
|
const vint16m4_t sumi_hi_m = __riscv_vwmacc_vv_i16m4(sumi_hi_0, rhs_vec_hi_1, lhs_3_8, vl * 2);
|
||||||
|
|
||||||
|
sumi_l2 = sumi_hi_m;
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
const vuint32m4_t sumi_i32 = __riscv_vreinterpret_v_i32m4_u32m4(__riscv_vreinterpret_v_i16m4_i32m4(sumi_l2));
|
||||||
|
const vuint16m2_t sumi_h2_0 = __riscv_vnsrl_wx_u16m2(sumi_i32, 0, vl);
|
||||||
|
const vuint16m2_t sumi_h2_1 = __riscv_vnsrl_wx_u16m2(sumi_i32, 16, vl);
|
||||||
|
const vuint16m2_t sumi_h2 = __riscv_vadd_vv_u16m2(sumi_h2_0, sumi_h2_1, vl);
|
||||||
|
const vuint32m2_t sumi_h2_i32 = __riscv_vreinterpret_v_u16m2_u32m2(sumi_h2);
|
||||||
|
const vuint16m1_t sumi_h4_0 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 0, vl / 2);
|
||||||
|
const vuint16m1_t sumi_h4_1 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 16, vl / 2);
|
||||||
|
const vuint16m1_t sumi_h4 = __riscv_vadd_vv_u16m1(sumi_h4_0, sumi_h4_1, vl / 2);
|
||||||
|
const vuint32m1_t sumi_h4_i32 = __riscv_vreinterpret_v_u16m1_u32m1(sumi_h4);
|
||||||
|
const vint16mf2_t sumi_h8_0 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 0, vl / 4));
|
||||||
|
const vint16mf2_t sumi_h8_1 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 16, vl / 4));
|
||||||
|
const vint32m1_t sumi_h8 = __riscv_vwadd_vv_i32m1(sumi_h8_0, sumi_h8_1, vl / 4);
|
||||||
|
const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(sumi_h8, vl / 4);
|
||||||
|
|
||||||
|
const vfloat32m1_t tmp1 = __riscv_vfmul_vf_f32m1(facc, a_scales[2], vl / 4);
|
||||||
|
sumf2 = __riscv_vfmacc_vv_f32m1(sumf2, tmp1, b_scales_vec, vl / 4);
|
||||||
|
}
|
||||||
|
|
||||||
|
const int64_t A3 = *(const int64_t *)&a_ptr[l].qs[24];
|
||||||
|
const int64_t A7 = *(const int64_t *)&a_ptr[l].qs[56];
|
||||||
|
const int64_t Ab = *(const int64_t *)&a_ptr[l].qs[88];
|
||||||
|
const int64_t Af = *(const int64_t *)&a_ptr[l].qs[120];
|
||||||
|
__asm__ __volatile__("" ::: "memory"); // prevent gcc from emitting fused vlse64, violating alignment
|
||||||
|
vint16m4_t sumi_l3;
|
||||||
|
{
|
||||||
|
const vint8m2_t lhs_0_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A3, vl / 4));
|
||||||
|
const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A7, vl / 4));
|
||||||
|
const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Ab, vl / 4));
|
||||||
|
const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Af, vl / 4));
|
||||||
|
const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(rhs_vec_lo_0, lhs_0_8, vl * 2);
|
||||||
|
const vint16m4_t sumi_lo_1 = __riscv_vwmacc_vv_i16m4(sumi_lo_0, rhs_vec_lo_1, lhs_1_8, vl * 2);
|
||||||
|
const vint16m4_t sumi_hi_0 = __riscv_vwmacc_vv_i16m4(sumi_lo_1, rhs_vec_hi_0, lhs_2_8, vl * 2);
|
||||||
|
const vint16m4_t sumi_hi_m = __riscv_vwmacc_vv_i16m4(sumi_hi_0, rhs_vec_hi_1, lhs_3_8, vl * 2);
|
||||||
|
|
||||||
|
sumi_l3 = sumi_hi_m;
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
const vuint32m4_t sumi_i32 = __riscv_vreinterpret_v_i32m4_u32m4(__riscv_vreinterpret_v_i16m4_i32m4(sumi_l3));
|
||||||
|
const vuint16m2_t sumi_h2_0 = __riscv_vnsrl_wx_u16m2(sumi_i32, 0, vl);
|
||||||
|
const vuint16m2_t sumi_h2_1 = __riscv_vnsrl_wx_u16m2(sumi_i32, 16, vl);
|
||||||
|
const vuint16m2_t sumi_h2 = __riscv_vadd_vv_u16m2(sumi_h2_0, sumi_h2_1, vl);
|
||||||
|
const vuint32m2_t sumi_h2_i32 = __riscv_vreinterpret_v_u16m2_u32m2(sumi_h2);
|
||||||
|
const vuint16m1_t sumi_h4_0 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 0, vl / 2);
|
||||||
|
const vuint16m1_t sumi_h4_1 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 16, vl / 2);
|
||||||
|
const vuint16m1_t sumi_h4 = __riscv_vadd_vv_u16m1(sumi_h4_0, sumi_h4_1, vl / 2);
|
||||||
|
const vuint32m1_t sumi_h4_i32 = __riscv_vreinterpret_v_u16m1_u32m1(sumi_h4);
|
||||||
|
const vint16mf2_t sumi_h8_0 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 0, vl / 4));
|
||||||
|
const vint16mf2_t sumi_h8_1 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 16, vl / 4));
|
||||||
|
const vint32m1_t sumi_h8 = __riscv_vwadd_vv_i32m1(sumi_h8_0, sumi_h8_1, vl / 4);
|
||||||
|
const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(sumi_h8, vl / 4);
|
||||||
|
|
||||||
|
const vfloat32m1_t tmp1 = __riscv_vfmul_vf_f32m1(facc, a_scales[3], vl / 4);
|
||||||
|
sumf3 = __riscv_vfmacc_vv_f32m1(sumf3, tmp1, b_scales_vec, vl / 4);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
__riscv_vse32_v_f32m1(&s[(y * 4 + 0) * bs + x * ncols_interleaved], sumf0, vl / 4);
|
||||||
|
__riscv_vse32_v_f32m1(&s[(y * 4 + 1) * bs + x * ncols_interleaved], sumf1, vl / 4);
|
||||||
|
__riscv_vse32_v_f32m1(&s[(y * 4 + 2) * bs + x * ncols_interleaved], sumf2, vl / 4);
|
||||||
|
__riscv_vse32_v_f32m1(&s[(y * 4 + 3) * bs + x * ncols_interleaved], sumf3, vl / 4);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__)
|
#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__)
|
||||||
|
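The vnsrl/vadd ladder above collapses the int16 partial products down to one int32 accumulator per output column before scaling: each pass splits the vector (viewed as 32-bit pairs) into its low and high 16-bit halves and adds them, halving the element count. A minimal scalar restatement of that reduction, with hypothetical names and a plain group-sum instead of the pairwise passes:

```c
#include <stdint.h>

// Scalar sketch of the reduction: summing a contiguous group of int16 partial
// products into one int32 per output accumulator is equivalent to repeatedly
// adding adjacent lane pairs, which is what the vnsrl/vadd sequence does.
static void reduce_groups(const int16_t * in, int32_t * out, int n_in, int n_out) {
    const int group = n_in / n_out;           // lanes contributing to each output
    for (int j = 0; j < n_out; j++) {
        int32_t acc = 0;
        for (int k = 0; k < group; k++) {
            acc += in[j * group + k];
        }
        out[j] = acc;                         // then scaled by a_scale * b_scale[j]
    }
}
```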
@ -348,7 +348,6 @@ struct tensor_alloc {
};

struct leaf_alloc {
-    int buffer_id;
    struct tensor_alloc leaf;
};

@ -740,7 +739,6 @@ bool ggml_gallocr_reserve_n(ggml_gallocr_t galloc, struct ggml_cgraph * graph, c
    for (int i = 0; i < graph->n_leafs; i++) {
        struct ggml_tensor * leaf = graph->leafs[i];
        struct hash_node * hn = ggml_gallocr_hash_get(galloc, leaf);
-        galloc->leaf_allocs[i].buffer_id = hn->buffer_id;
        if (leaf->view_src || leaf->data) {
            galloc->leaf_allocs[i].leaf.buffer_id = -1;
            galloc->leaf_allocs[i].leaf.offset = SIZE_MAX;
ggml/src/ggml-amx.cpp (new file, 436 lines)
@ -0,0 +1,436 @@
#include "ggml-amx.h"
|
||||||
|
#include "ggml-amx/common.h"
|
||||||
|
#include "ggml-amx/mmq.h"
|
||||||
|
#include "ggml-backend-impl.h"
|
||||||
|
#include "ggml-impl.h"
|
||||||
|
|
||||||
|
#if defined(__gnu_linux__)
|
||||||
|
#include <sys/syscall.h>
|
||||||
|
#include <unistd.h>
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#include <cstdlib>
|
||||||
|
#include <cstring>
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
|
#if defined(__AMX_INT8__)
|
||||||
|
|
||||||
|
// AMX buffer interface
|
||||||
|
static void ggml_backend_amx_buffer_free_buffer(ggml_backend_buffer_t buffer) {
|
||||||
|
free(buffer->context);
|
||||||
|
}
|
||||||
|
|
||||||
|
static void * ggml_backend_amx_buffer_get_base(ggml_backend_buffer_t buffer) {
|
||||||
|
return (void *)(buffer->context);
|
||||||
|
}
|
||||||
|
|
||||||
|
static void ggml_backend_amx_buffer_memset_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
|
||||||
|
memset((char *)tensor->data + offset, value, size);
|
||||||
|
|
||||||
|
GGML_UNUSED(buffer);
|
||||||
|
}
|
||||||
|
|
||||||
|
static void ggml_backend_amx_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
|
||||||
|
if (qtype_has_amx_kernels(tensor->type)) {
|
||||||
|
ggml_backend_amx_convert_weight(tensor, data, offset, size);
|
||||||
|
} else {
|
||||||
|
memcpy((char *)tensor->data + offset, data, size);
|
||||||
|
}
|
||||||
|
|
||||||
|
GGML_UNUSED(buffer);
|
||||||
|
}
|
||||||
|
|
||||||
|
static void ggml_backend_amx_buffer_get_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
|
||||||
|
GGML_ASSERT(!qtype_has_amx_kernels(tensor->type));
|
||||||
|
memcpy(data, (const char *)tensor->data + offset, size);
|
||||||
|
|
||||||
|
GGML_UNUSED(buffer);
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool ggml_backend_amx_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * src, struct ggml_tensor * dst) {
|
||||||
|
if (ggml_backend_buffer_is_host(src->buffer)) {
|
||||||
|
if (qtype_has_amx_kernels(src->type)) {
|
||||||
|
ggml_backend_amx_convert_weight(dst, src->data, 0, ggml_backend_amx_get_alloc_size(dst));
|
||||||
|
} else {
|
||||||
|
memcpy(dst->data, src->data, ggml_nbytes(src));
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
|
||||||
|
GGML_UNUSED(buffer);
|
||||||
|
}
|
||||||
|
|
||||||
|
static void ggml_backend_amx_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
|
||||||
|
memset(buffer->context, value, buffer->size);
|
||||||
|
}
|
||||||
|
|
||||||
|
static ggml_backend_buffer_i ggml_backend_amx_buffer_interface = {
|
||||||
|
/* .free_buffer = */ ggml_backend_amx_buffer_free_buffer,
|
||||||
|
/* .get_base = */ ggml_backend_amx_buffer_get_base,
|
||||||
|
/* .init_tensor = */ NULL, // no initialization required
|
||||||
|
/* .memset_tensor = */ ggml_backend_amx_buffer_memset_tensor,
|
||||||
|
/* .set_tensor = */ ggml_backend_amx_buffer_set_tensor,
|
||||||
|
/* .get_tensor = */ ggml_backend_amx_buffer_get_tensor,
|
||||||
|
/* .cpy_tensor = */ ggml_backend_amx_buffer_cpy_tensor,
|
||||||
|
/* .clear = */ ggml_backend_amx_buffer_clear,
|
||||||
|
/* .reset = */ NULL,
|
||||||
|
};
|
||||||
|
|
||||||
|
static const char * ggml_backend_amx_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
|
||||||
|
return "AMX";
|
||||||
|
|
||||||
|
GGML_UNUSED(buft);
|
||||||
|
}
|
||||||
|
|
||||||
|
static ggml_backend_buffer_t ggml_backend_amx_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
|
||||||
|
void * data = aligned_alloc(TENSOR_ALIGNMENT, size);
|
||||||
|
if (data == NULL) {
|
||||||
|
fprintf(stderr, "%s: failed to allocate buffer of size %zu\n", __func__, size);
|
||||||
|
return NULL;
|
||||||
|
}
|
||||||
|
|
||||||
|
return ggml_backend_buffer_init(buft, ggml_backend_amx_buffer_interface, data, size);
|
||||||
|
}
|
||||||
|
|
||||||
|
static size_t ggml_backend_amx_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
|
||||||
|
return TENSOR_ALIGNMENT;
|
||||||
|
|
||||||
|
GGML_UNUSED(buft);
|
||||||
|
}
|
||||||
|
|
||||||
|
static size_t ggml_backend_amx_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor* tensor) {
|
||||||
|
return ggml_backend_amx_get_alloc_size(tensor);
|
||||||
|
|
||||||
|
GGML_UNUSED(buft);
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool ggml_backend_amx_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
|
||||||
|
return false;
|
||||||
|
|
||||||
|
GGML_UNUSED(buft);
|
||||||
|
}
|
||||||
|
|
||||||
|
ggml_backend_buffer_type_t ggml_backend_amx_buffer_type() {
|
||||||
|
static struct ggml_backend_buffer_type ggml_backend_buffer_type_amx = {
|
||||||
|
/* .iface = */ {
|
||||||
|
/* .get_name = */ ggml_backend_amx_buffer_type_get_name,
|
||||||
|
/* .alloc_buffer = */ ggml_backend_amx_buffer_type_alloc_buffer,
|
||||||
|
/* .get_alignment = */ ggml_backend_amx_buffer_type_get_alignment,
|
||||||
|
/* .get_max_size = */ NULL, // defaults to SIZE_MAX
|
||||||
|
/* .get_alloc_size = */ ggml_backend_amx_buffer_type_get_alloc_size,
|
||||||
|
/* .is_host = */ ggml_backend_amx_buffer_type_is_host,
|
||||||
|
},
|
||||||
|
/* .device = */ ggml_backend_reg_dev_get(ggml_backend_amx_reg(), 0),
|
||||||
|
/* .context = */ NULL,
|
||||||
|
};
|
||||||
|
|
||||||
|
return &ggml_backend_buffer_type_amx;
|
||||||
|
}
|
||||||
|
|
||||||
|
// backend interface
|
||||||
|
|
||||||
|
static const char * ggml_backend_amx_name(ggml_backend_t backend) {
|
||||||
|
return "AMX";
|
||||||
|
|
||||||
|
GGML_UNUSED(backend);
|
||||||
|
}
|
||||||
|
|
||||||
|
static void ggml_backend_amx_free(ggml_backend_t backend) {
|
||||||
|
ggml_backend_amx_context * ctx = (ggml_backend_amx_context *)backend->context;
|
||||||
|
delete ctx;
|
||||||
|
delete backend;
|
||||||
|
}
|
||||||
|
|
||||||
|
static enum ggml_status ggml_backend_amx_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
|
||||||
|
ggml_backend_amx_context * ctx = (ggml_backend_amx_context *)backend->context;
|
||||||
|
|
||||||
|
for (int i = 0; i < cgraph->n_nodes; i++) {
|
||||||
|
struct ggml_tensor * node = cgraph->nodes[i];
|
||||||
|
|
||||||
|
switch (node->op) {
|
||||||
|
case GGML_OP_MUL_MAT:
|
||||||
|
ggml_backend_amx_mul_mat(ctx, node);
|
||||||
|
break;
|
||||||
|
|
||||||
|
case GGML_OP_NONE:
|
||||||
|
case GGML_OP_RESHAPE:
|
||||||
|
case GGML_OP_VIEW:
|
||||||
|
case GGML_OP_PERMUTE:
|
||||||
|
case GGML_OP_TRANSPOSE:
|
||||||
|
break;
|
||||||
|
|
||||||
|
default:
|
||||||
|
fprintf(stderr, "%s: unsupported op %s\n", __func__, ggml_op_desc(node));
|
||||||
|
GGML_ASSERT(false);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return GGML_STATUS_SUCCESS;
|
||||||
|
|
||||||
|
GGML_UNUSED(backend);
|
||||||
|
}
|
||||||
|
|
||||||
|
static struct ggml_backend_i ggml_backend_amx_i = {
|
||||||
|
/* .get_name = */ ggml_backend_amx_name,
|
||||||
|
/* .free = */ ggml_backend_amx_free,
|
||||||
|
/* .set_tensor_async = */ NULL,
|
||||||
|
/* .get_tensor_async = */ NULL,
|
||||||
|
/* .cpy_tensor_async = */ NULL,
|
||||||
|
/* .synchronize = */ NULL,
|
||||||
|
/* .graph_plan_create = */ NULL,
|
||||||
|
/* .graph_plan_free = */ NULL,
|
||||||
|
/* .graph_plan_update = */ NULL,
|
||||||
|
/* .graph_plan_compute = */ NULL,
|
||||||
|
/* .graph_compute = */ ggml_backend_amx_graph_compute,
|
||||||
|
/* .event_record = */ NULL,
|
||||||
|
/* .event_wait = */ NULL,
|
||||||
|
};
|
||||||
|
|
||||||
|
static ggml_guid_t ggml_backend_amx_guid() {
|
||||||
|
static ggml_guid guid = { 0x13, 0xb8, 0xa4, 0xc4, 0xba, 0xfe, 0x51, 0x67, 0x87, 0x44, 0x55, 0x15, 0xb2, 0x35, 0x62, 0x3e };
|
||||||
|
return &guid;
|
||||||
|
}
|
||||||
|
|
||||||
|
#define ARCH_GET_XCOMP_PERM 0x1022
|
||||||
|
#define ARCH_REQ_XCOMP_PERM 0x1023
|
||||||
|
#define XFEATURE_XTILECFG 17
|
||||||
|
#define XFEATURE_XTILEDATA 18
|
||||||
|
|
||||||
|
static bool ggml_amx_init() {
|
||||||
|
#if defined(__gnu_linux__)
|
||||||
|
if (syscall(SYS_arch_prctl, ARCH_REQ_XCOMP_PERM, XFEATURE_XTILEDATA)) {
|
||||||
|
fprintf(stderr, "AMX is not ready to be used!\n");
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
#elif defined(_WIN32)
|
||||||
|
return true;
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
ggml_backend_t ggml_backend_amx_init() {
|
||||||
|
|
||||||
|
// invoke a Linux system call to request access to AMX features
|
||||||
|
ggml_amx_init();
|
||||||
|
|
||||||
|
// backend context
|
||||||
|
ggml_backend_amx_context * ctx = new ggml_backend_amx_context;
|
||||||
|
|
||||||
|
// ggml amx backend
|
||||||
|
ggml_backend_t backend = new ggml_backend {
|
||||||
|
/* .guid = */ ggml_backend_amx_guid(),
|
||||||
|
/* .interface = */ ggml_backend_amx_i,
|
||||||
|
/* .device = */ ggml_backend_reg_dev_get(ggml_backend_amx_reg(), 0),
|
||||||
|
/* .context = */ ctx,
|
||||||
|
};
|
||||||
|
|
||||||
|
return backend;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool ggml_backend_is_amx(ggml_backend_t backend) {
|
||||||
|
return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_amx_guid());
|
||||||
|
}
|
||||||
|
|
||||||
|
void ggml_backend_amx_set_n_threads(ggml_backend_t backend_amx, int n_threads) {
|
||||||
|
GGML_ASSERT(ggml_backend_is_amx(backend_amx));
|
||||||
|
|
||||||
|
ggml_backend_amx_context * ctx = (ggml_backend_amx_context *)backend_amx->context;
|
||||||
|
ctx->n_threads = n_threads;
|
||||||
|
}
|
||||||
|
|
||||||
|
// device interface
|
||||||
|
|
||||||
|
static const char * ggml_backend_amx_device_get_name(ggml_backend_dev_t dev) {
|
||||||
|
return "AMX";
|
||||||
|
|
||||||
|
GGML_UNUSED(dev);
|
||||||
|
}
|
||||||
|
|
||||||
|
static const char * ggml_backend_amx_device_get_description(ggml_backend_dev_t dev) {
|
||||||
|
return "Intel Advanced Matrix Extensions";
|
||||||
|
|
||||||
|
GGML_UNUSED(dev);
|
||||||
|
}
|
||||||
|
|
||||||
|
static void ggml_backend_amx_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
|
||||||
|
// TODO
|
||||||
|
*free = 0;
|
||||||
|
*total = 0;
|
||||||
|
|
||||||
|
GGML_UNUSED(dev);
|
||||||
|
}
|
||||||
|
|
||||||
|
static enum ggml_backend_dev_type ggml_backend_amx_device_get_type(ggml_backend_dev_t dev) {
|
||||||
|
return GGML_BACKEND_DEVICE_TYPE_ACCEL;
|
||||||
|
|
||||||
|
GGML_UNUSED(dev);
|
||||||
|
}
|
||||||
|
|
||||||
|
static void ggml_backend_amx_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
|
||||||
|
props->name = ggml_backend_amx_device_get_name(dev);
|
||||||
|
props->description = ggml_backend_amx_device_get_description(dev);
|
||||||
|
props->type = ggml_backend_amx_device_get_type(dev);
|
||||||
|
ggml_backend_amx_device_get_memory(dev, &props->memory_free, &props->memory_total);
|
||||||
|
|
||||||
|
// `buffer_from_host_ptr` is intended to be used in mmap, when memory layout unchanged
|
||||||
|
props->caps = {
|
||||||
|
/* .async = */ false,
|
||||||
|
/* .host_buffer = */ false,
|
||||||
|
/* .buffer_from_host_ptr = */ false,
|
||||||
|
/* .events = */ false,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
static ggml_backend_t ggml_backend_amx_device_init(ggml_backend_dev_t dev, const char * params) {
|
||||||
|
return ggml_backend_amx_init();
|
||||||
|
|
||||||
|
GGML_UNUSED(dev);
|
||||||
|
GGML_UNUSED(params);
|
||||||
|
}
|
||||||
|
|
||||||
|
static ggml_backend_buffer_type_t ggml_backend_amx_device_get_buffer_type(ggml_backend_dev_t dev) {
|
||||||
|
return ggml_backend_amx_buffer_type();
|
||||||
|
|
||||||
|
GGML_UNUSED(dev);
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool ggml_backend_amx_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
|
||||||
|
|
||||||
|
// handle only 2d gemm for now
|
||||||
|
auto is_contiguous_2d = [](const struct ggml_tensor * t) {
|
||||||
|
return ggml_is_contiguous(t) && t->ne[3] == 1 && t->ne[2] == 1;
|
||||||
|
};
|
||||||
|
|
||||||
|
switch (op->op) {
|
||||||
|
case GGML_OP_NONE:
|
||||||
|
case GGML_OP_RESHAPE:
|
||||||
|
case GGML_OP_VIEW:
|
||||||
|
case GGML_OP_PERMUTE:
|
||||||
|
case GGML_OP_TRANSPOSE:
|
||||||
|
return true;
|
||||||
|
|
||||||
|
case GGML_OP_MUL_MAT: {
|
||||||
|
const struct ggml_tensor * src0 = op->src[0];
|
||||||
|
const struct ggml_tensor * src1 = op->src[1];
|
||||||
|
|
||||||
|
const enum ggml_type type = src0->type;
|
||||||
|
const int64_t ne0 = op->ne[0];
|
||||||
|
|
||||||
|
bool is_training = src0->grad || src1->grad;
|
||||||
|
|
||||||
|
// amx kernels enables for Q4_0, Q4_1, Q8_0, F16
|
||||||
|
// Q4_K, Q5_K, Q6_K, IQ4_XS enabled for QK_K = 256
|
||||||
|
bool has_amx_kernels = qtype_has_amx_kernels(type) || (type == GGML_TYPE_F16);
|
||||||
|
|
||||||
|
bool can_use_amx =
|
||||||
|
is_contiguous_2d(src0) && // src0 must be contiguous
|
||||||
|
is_contiguous_2d(src1) && // src1 must be contiguous
|
||||||
|
!is_training && // inference only
|
||||||
|
src1->type == GGML_TYPE_F32 && // src1 must be float32
|
||||||
|
has_amx_kernels && // with amx kernel impls
|
||||||
|
ne0 % (TILE_N * 2) == 0; // out_features is 32x
|
||||||
|
|
||||||
|
return can_use_amx;
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
GGML_UNUSED(dev);
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool ggml_backend_amx_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
|
||||||
|
return buft->iface.get_name == ggml_backend_amx_buffer_type_get_name;
|
||||||
|
|
||||||
|
GGML_UNUSED(dev);
|
||||||
|
}
|
||||||
|
|
||||||
|
static const struct ggml_backend_device_i ggml_backend_amx_device_i = {
|
||||||
|
/* .get_name = */ ggml_backend_amx_device_get_name,
|
||||||
|
/* .get_description = */ ggml_backend_amx_device_get_description,
|
||||||
|
/* .get_memory = */ ggml_backend_amx_device_get_memory,
|
||||||
|
/* .get_type = */ ggml_backend_amx_device_get_type,
|
||||||
|
/* .get_props = */ ggml_backend_amx_device_get_props,
|
||||||
|
/* .init_backend = */ ggml_backend_amx_device_init,
|
||||||
|
/* .get_buffer_type = */ ggml_backend_amx_device_get_buffer_type,
|
||||||
|
/* .get_host_buffer_type = */ NULL,
|
||||||
|
/* .buffer_from_host_ptr = */ NULL,
|
||||||
|
/* .supports_op = */ ggml_backend_amx_device_supports_op,
|
||||||
|
/* .supports_buft = */ ggml_backend_amx_device_supports_buft,
|
||||||
|
/* .offload_op = */ NULL,
|
||||||
|
/* .event_new = */ NULL,
|
||||||
|
/* .event_free = */ NULL,
|
||||||
|
/* .event_synchronize = */ NULL,
|
||||||
|
};
|
||||||
|
|
||||||
|
// backend reg interface
|
||||||
|
|
||||||
|
static const char * ggml_backend_amx_reg_get_name(ggml_backend_reg_t reg) {
|
||||||
|
return "AMX";
|
||||||
|
|
||||||
|
GGML_UNUSED(reg);
|
||||||
|
}
|
||||||
|
|
||||||
|
static size_t ggml_backend_amx_reg_get_device_count(ggml_backend_reg_t reg) {
|
||||||
|
return 1;
|
||||||
|
|
||||||
|
GGML_UNUSED(reg);
|
||||||
|
}
|
||||||
|
|
||||||
|
static ggml_backend_dev_t ggml_backend_amx_reg_get_device(ggml_backend_reg_t reg, size_t index) {
|
||||||
|
GGML_ASSERT(index == 0);
|
||||||
|
|
||||||
|
static ggml_backend_device ggml_backend_amx_device = {
|
||||||
|
/* .iface = */ ggml_backend_amx_device_i,
|
||||||
|
/* .reg = */ reg,
|
||||||
|
/* .context = */ nullptr,
|
||||||
|
};
|
||||||
|
|
||||||
|
return &ggml_backend_amx_device;
|
||||||
|
|
||||||
|
GGML_UNUSED(reg);
|
||||||
|
GGML_UNUSED(index);
|
||||||
|
}
|
||||||
|
|
||||||
|
static void * ggml_backend_amx_get_proc_address(ggml_backend_reg_t reg, const char * name) {
|
||||||
|
if (std::strcmp(name, "ggml_backend_set_n_threads") == 0) {
|
||||||
|
return (void *)ggml_backend_amx_set_n_threads;
|
||||||
|
}
|
||||||
|
return NULL;
|
||||||
|
|
||||||
|
GGML_UNUSED(reg);
|
||||||
|
GGML_UNUSED(name);
|
||||||
|
}
|
||||||
|
|
||||||
|
static const struct ggml_backend_reg_i ggml_backend_amx_reg_i = {
|
||||||
|
/* .get_name = */ ggml_backend_amx_reg_get_name,
|
||||||
|
/* .get_device_count = */ ggml_backend_amx_reg_get_device_count,
|
||||||
|
/* .get_device = */ ggml_backend_amx_reg_get_device,
|
||||||
|
/* .get_proc_address = */ ggml_backend_amx_get_proc_address,
|
||||||
|
};
|
||||||
|
|
||||||
|
ggml_backend_reg_t ggml_backend_amx_reg(void) {
|
||||||
|
static struct ggml_backend_reg ggml_backend_amx_reg = {
|
||||||
|
/* .iface = */ ggml_backend_amx_reg_i,
|
||||||
|
/* .context = */ NULL,
|
||||||
|
};
|
||||||
|
|
||||||
|
return &ggml_backend_amx_reg;
|
||||||
|
}
|
||||||
|
|
||||||
|
#else // if defined(__AMX_INT8__)
|
||||||
|
|
||||||
|
ggml_backend_t ggml_backend_amx_init(void) {
|
||||||
|
fprintf(stderr, "GGML is not compiled with AMX support!\n");
|
||||||
|
return ggml_backend_t{};
|
||||||
|
}
|
||||||
|
|
||||||
|
void ggml_backend_amx_set_n_threads(ggml_backend_t backend_amx, int n_threads) {
|
||||||
|
fprintf(stderr, "GGML is not compiled with AMX support!\n");
|
||||||
|
|
||||||
|
GGML_UNUSED(backend_amx);
|
||||||
|
GGML_UNUSED(n_threads);
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
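The new backend registers itself through ggml_backend_amx_reg() and is normally picked up automatically, but the public entry points above can also be used directly. A hedged usage sketch (the wrapper function and its caller are hypothetical; the ggml calls are the ones defined in this file):

```cpp
#include "ggml-amx.h"

// Create the AMX backend directly and pin its thread count, assuming GGML was
// built with GGML_AMX and the CPU exposes __AMX_INT8__.
ggml_backend_t make_amx_backend(int n_threads) {
    ggml_backend_t backend = ggml_backend_amx_init();   // requests AMX tile permission on Linux
    if (backend != NULL && ggml_backend_is_amx(backend)) {
        ggml_backend_amx_set_n_threads(backend, n_threads);
    }
    return backend;
}
```

Weights placed in a buffer from ggml_backend_amx_buffer_type() are repacked by ggml_backend_amx_convert_weight() on set_tensor for the types that qtype_has_amx_kernels() accepts; everything else falls back to a plain memcpy.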
ggml/src/ggml-amx/common.h (new file, 93 lines)
@ -0,0 +1,93 @@
#pragma once

#include "ggml.h"
#include "ggml-cpu-impl.h" // <immintrin.h>

#include <algorithm>
#include <memory>
#include <type_traits>

#if defined(_OPENMP)
#include <omp.h>
#endif

#define TILE_M 16
#define TILE_N 16
#define TILE_K 32
#define VNNI_BLK 4

#define AMX_BLK_SIZE 32

#define TMM0 0
#define TMM1 1
#define TMM2 2
#define TMM3 3
#define TMM4 4
#define TMM5 5
#define TMM6 6
#define TMM7 7

// parallel routines
template <typename T, typename std::enable_if<std::is_integral<T>::value, int>::type = 0>
inline T div_up(T x, T y) { return (x + y - 1) / y; }

template <typename T>
inline void balance211(T n, T nth, T ith, T& n_start, T& n_end) {
#if 0
    // onednn partition pattern
    T& n_my = n_end;
    if (nth <= 1 || n == 0) {
        n_start = 0;
        n_my = n;
    } else {
        T n1 = div_up(n, nth);
        T n2 = n1 - 1;
        T T1 = n - n2 * nth;
        n_my = ith < T1 ? n1 : n2;
        n_start = ith <= T1 ? ith*n1 : T1 * n1 + (ith - T1) * n2;
    }
    n_end += n_start;
#else
    // pytorch aten partition pattern
    T n_my = div_up(n, nth);
    n_start = ith * n_my;
    n_end = std::min(n_start + n_my, n);
#endif
}

template <typename func_t>
inline void parallel_for(int nth, int n, const func_t& f) {
#if defined(_OPENMP)
#pragma omp parallel num_threads(nth)
{
    //int nth = omp_get_num_threads();
    int ith = omp_get_thread_num();
    int tbegin, tend;
    balance211(n, nth, ith, tbegin, tend);
    f(tbegin, tend);
}
#else
    f(0, n);

    GGML_UNUSED(nth);
#endif
}

// quantized types that have AMX support
inline bool qtype_has_amx_kernels(const enum ggml_type type) {
    // TODO: fix padding for vnni format
    return (type == GGML_TYPE_Q4_0) ||
        (type == GGML_TYPE_Q4_1);
        //(type == GGML_TYPE_Q8_0) ||
        //(type == GGML_TYPE_Q4_K) ||
        //(type == GGML_TYPE_Q5_K) ||
        //(type == GGML_TYPE_Q6_K) ||
        //(type == GGML_TYPE_IQ4_XS);
}

// ggml backend context
struct ggml_backend_amx_context {
    int n_threads = GGML_DEFAULT_N_THREADS;
    std::unique_ptr<char[]> work_data;
    size_t work_size = 0;
};
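The active "pytorch aten" branch of balance211 hands each thread a contiguous chunk of ceil(n/nth) items, with the last thread taking whatever remains. A standalone restatement of that arithmetic, independent of the ggml headers (names here are illustrative only):

```cpp
#include <algorithm>
#include <cstdio>

// Same partition arithmetic as div_up/balance211 (aten pattern) in common.h.
static int div_up_i(int x, int y) { return (x + y - 1) / y; }

int main() {
    const int n = 10, nth = 3;                         // 10 work items, 3 threads
    for (int ith = 0; ith < nth; ith++) {
        const int chunk = div_up_i(n, nth);            // 4
        const int begin = ith * chunk;                 // 0, 4, 8
        const int end   = std::min(begin + chunk, n);  // 4, 8, 10
        std::printf("thread %d -> [%d, %d)\n", ith, begin, end);
    }
    return 0;
}
```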
ggml/src/ggml-amx/mmq.cpp (new file, 2509 lines; diff not shown — too large)
ggml/src/ggml-amx/mmq.h (new file, 17 lines)
@ -0,0 +1,17 @@
#pragma once
#include "common.h"
#include <stdint.h>

#ifdef __cplusplus
extern "C" {
#endif

size_t ggml_backend_amx_get_alloc_size(const struct ggml_tensor * tensor);

void ggml_backend_amx_convert_weight(struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);

void ggml_backend_amx_mul_mat(ggml_backend_amx_context * ctx, struct ggml_tensor * dst);

#ifdef __cplusplus
}
#endif
@ -22,7 +22,7 @@ extern "C" {
|
|||||||
size_t (*get_max_size) (ggml_backend_buffer_type_t buft);
|
size_t (*get_max_size) (ggml_backend_buffer_type_t buft);
|
||||||
// (optional) data size needed to allocate the tensor, including padding (defaults to ggml_nbytes)
|
// (optional) data size needed to allocate the tensor, including padding (defaults to ggml_nbytes)
|
||||||
size_t (*get_alloc_size)(ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor);
|
size_t (*get_alloc_size)(ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor);
|
||||||
// (optional) check if tensor data is in host memory (defaults to false)
|
// (optional) check if tensor data is in host memory and uses standard ggml tensor layout (defaults to false)
|
||||||
bool (*is_host) (ggml_backend_buffer_type_t buft);
|
bool (*is_host) (ggml_backend_buffer_type_t buft);
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -37,7 +37,6 @@ extern "C" {
|
|||||||
//
|
//
|
||||||
|
|
||||||
struct ggml_backend_buffer_i {
|
struct ggml_backend_buffer_i {
|
||||||
const char * (*get_name) (ggml_backend_buffer_t buffer);
|
|
||||||
// (optional) free the buffer
|
// (optional) free the buffer
|
||||||
void (*free_buffer) (ggml_backend_buffer_t buffer);
|
void (*free_buffer) (ggml_backend_buffer_t buffer);
|
||||||
// base address of the buffer
|
// base address of the buffer
|
||||||
@ -88,19 +87,16 @@ extern "C" {
|
|||||||
|
|
||||||
void (*free)(ggml_backend_t backend);
|
void (*free)(ggml_backend_t backend);
|
||||||
|
|
||||||
// Will be moved to the device interface
|
|
||||||
// buffer allocation
|
|
||||||
ggml_backend_buffer_type_t (*get_default_buffer_type)(ggml_backend_t backend);
|
|
||||||
|
|
||||||
// (optional) asynchronous tensor data access
|
// (optional) asynchronous tensor data access
|
||||||
void (*set_tensor_async)(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
|
void (*set_tensor_async)(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
|
||||||
void (*get_tensor_async)(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
|
void (*get_tensor_async)(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
|
||||||
bool (*cpy_tensor_async)(ggml_backend_t backend_src, ggml_backend_t backend_dst, const struct ggml_tensor * src, struct ggml_tensor * dst);
|
bool (*cpy_tensor_async)(ggml_backend_t backend_src, ggml_backend_t backend_dst, const struct ggml_tensor * src, struct ggml_tensor * dst);
|
||||||
|
|
||||||
// (optional) complete all pending operations
|
// (optional) complete all pending operations (required if the backend supports async operations)
|
||||||
void (*synchronize)(ggml_backend_t backend);
|
void (*synchronize)(ggml_backend_t backend);
|
||||||
|
|
||||||
// (optional) compute graph with a plan (not used currently)
|
// (optional) graph plans (not used currently)
|
||||||
|
// compute graph with a plan
|
||||||
ggml_backend_graph_plan_t (*graph_plan_create) (ggml_backend_t backend, const struct ggml_cgraph * cgraph);
|
ggml_backend_graph_plan_t (*graph_plan_create) (ggml_backend_t backend, const struct ggml_cgraph * cgraph);
|
||||||
void (*graph_plan_free) (ggml_backend_t backend, ggml_backend_graph_plan_t plan);
|
void (*graph_plan_free) (ggml_backend_t backend, ggml_backend_graph_plan_t plan);
|
||||||
// update the plan with a new graph - this should be faster than creating a new plan when the graph has the same topology
|
// update the plan with a new graph - this should be faster than creating a new plan when the graph has the same topology
|
||||||
@ -111,13 +107,6 @@ extern "C" {
|
|||||||
// compute graph (always async if supported by the backend)
|
// compute graph (always async if supported by the backend)
|
||||||
enum ggml_status (*graph_compute) (ggml_backend_t backend, struct ggml_cgraph * cgraph);
|
enum ggml_status (*graph_compute) (ggml_backend_t backend, struct ggml_cgraph * cgraph);
|
||||||
|
|
||||||
// IMPORTANT: these functions have been moved to the device interface and will be removed from the backend interface
|
|
||||||
// new backends should implement the device interface instead
|
|
||||||
// These functions are being moved to the device interface
|
|
||||||
bool (*supports_op) (ggml_backend_t backend, const struct ggml_tensor * op);
|
|
||||||
bool (*supports_buft)(ggml_backend_t backend, ggml_backend_buffer_type_t buft);
|
|
||||||
bool (*offload_op) (ggml_backend_t backend, const struct ggml_tensor * op);
|
|
||||||
|
|
||||||
// (optional) event synchronization
|
// (optional) event synchronization
|
||||||
// record an event on this stream
|
// record an event on this stream
|
||||||
void (*event_record)(ggml_backend_t backend, ggml_backend_event_t event);
|
void (*event_record)(ggml_backend_t backend, ggml_backend_event_t event);
|
||||||
|
@ -34,6 +34,11 @@ const char * ggml_backend_buft_name(ggml_backend_buffer_type_t buft) {
}

ggml_backend_buffer_t ggml_backend_buft_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
+    if (size == 0) {
+        // return a dummy buffer for zero-sized allocations
+        return ggml_backend_buffer_init(buft, {}, NULL, 0);
+    }
+
    return buft->iface.alloc_buffer(buft, size);
}

@ -89,7 +94,7 @@ ggml_backend_buffer_t ggml_backend_buffer_init(
}

const char * ggml_backend_buffer_name(ggml_backend_buffer_t buffer) {
-    return buffer->iface.get_name(buffer);
+    return ggml_backend_buft_name(ggml_backend_buffer_get_type(buffer));
}

void ggml_backend_buffer_free(ggml_backend_buffer_t buffer) {
@ -108,6 +113,11 @@ size_t ggml_backend_buffer_get_size(ggml_backend_buffer_t buffer) {
}

void * ggml_backend_buffer_get_base(ggml_backend_buffer_t buffer) {
+    // get_base is optional if the buffer is zero-sized
+    if (buffer->size == 0) {
+        return NULL;
+    }
+
    void * base = buffer->iface.get_base(buffer);

    GGML_ASSERT(base != NULL && "backend buffer base cannot be NULL");
@ -122,6 +132,15 @@ void ggml_backend_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_t
    }
}

+void ggml_backend_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
+    // clear is optional if the buffer is zero-sized
+    if (buffer->size == 0) {
+        return;
+    }
+
+    buffer->iface.clear(buffer, value);
+}
+
size_t ggml_backend_buffer_get_alignment(ggml_backend_buffer_t buffer) {
    return ggml_backend_buft_get_alignment(ggml_backend_buffer_get_type(buffer));
}
@ -134,10 +153,6 @@ size_t ggml_backend_buffer_get_alloc_size(ggml_backend_buffer_t buffer, struct g
    return ggml_backend_buft_get_alloc_size(ggml_backend_buffer_get_type(buffer), tensor);
}

-void ggml_backend_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
-    buffer->iface.clear(buffer, value);
-}
-
bool ggml_backend_buffer_is_host(ggml_backend_buffer_t buffer) {
    return ggml_backend_buft_is_host(ggml_backend_buffer_get_type(buffer));
}
@ -198,7 +213,7 @@ void ggml_backend_free(ggml_backend_t backend) {
}

ggml_backend_buffer_type_t ggml_backend_get_default_buffer_type(ggml_backend_t backend) {
-    return backend->iface.get_default_buffer_type(backend);
+    return ggml_backend_dev_buffer_type(backend->device);
}

ggml_backend_buffer_t ggml_backend_alloc_buffer(ggml_backend_t backend, size_t size) {
@ -238,43 +253,42 @@ void ggml_backend_tensor_get_async(ggml_backend_t backend, const struct ggml_ten
void ggml_backend_tensor_set(struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
    ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;

+    if (size == 0) {
+        return;
+    }
+
    GGML_ASSERT(buf != NULL && "tensor buffer not set");
    GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
    GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds");

-    if (!size) {
-        return;
-    }
-
    buf->iface.set_tensor(buf, tensor, data, offset, size);
}

void ggml_backend_tensor_get(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
    ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;

+    if (size == 0) {
+        return;
+    }
+
    GGML_ASSERT(buf != NULL && "tensor buffer not set");
    GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
    GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds");

-    if (!size) {
-        return;
-    }
-
    buf->iface.get_tensor(buf, tensor, data, offset, size);
}

GGML_API void ggml_backend_tensor_memset(struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
    ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;

-    GGML_ASSERT(buf != NULL && "tensor buffer not set");
-    GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
-    GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds");
-
-    if (!size) {
+    if (size == 0) {
        return;
    }

-    GGML_ASSERT(buf->iface.memset_tensor != NULL && "memset not supported by backend buffer");
+    GGML_ASSERT(buf != NULL && "tensor buffer not set");
+    GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
+    GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds");
+    GGML_ASSERT(buf->iface.memset_tensor != NULL && "memset not implemented by backend buffer");

    buf->iface.memset_tensor(buf, tensor, value, offset, size);
}
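Taken together, the hunks above make empty allocations and empty tensor accesses explicit no-ops. A hedged illustration (the variable names are placeholders, not part of the diff):

```cpp
// A zero-byte request never reaches the backend's alloc_buffer callback;
// the caller gets back a valid dummy buffer whose base is NULL and whose
// clear() returns immediately.
ggml_backend_buffer_t buf = ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), 0);
ggml_backend_buffer_clear(buf, 0); // no-op

// Likewise, a zero-byte tensor write now returns before any of the
// buffer/data assertions fire (some_tensor is a hypothetical tensor).
ggml_backend_tensor_set(some_tensor, /*data=*/NULL, /*offset=*/0, /*size=*/0);
```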
@ -316,33 +330,15 @@ enum ggml_status ggml_backend_graph_compute_async(ggml_backend_t backend, struct
}

bool ggml_backend_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
-    // helper to ease transition to device interface
-    if (backend->device) {
-        return ggml_backend_dev_supports_op(backend->device, op);
-    }
-
-    return backend->iface.supports_op(backend, op);
+    return ggml_backend_dev_supports_op(backend->device, op);
}

bool ggml_backend_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft) {
-    // helper to ease transition to device interface
-    if (backend->device) {
-        return ggml_backend_dev_supports_buft(backend->device, buft);
-    }
-
-    return backend->iface.supports_buft(backend, buft);
+    return ggml_backend_dev_supports_buft(backend->device, buft);
}

bool ggml_backend_offload_op(ggml_backend_t backend, const struct ggml_tensor * op) {
-    // helper to ease transition to device interface
-    if (backend->device) {
-        return ggml_backend_dev_offload_op(backend->device, op);
-    }
-
-    if (backend->iface.offload_op != NULL) {
-        return backend->iface.offload_op(backend, op);
-    }
-    return false;
+    return ggml_backend_dev_offload_op(backend->device, op);
}

ggml_backend_dev_t ggml_backend_get_device(ggml_backend_t backend) {
@ -538,6 +534,14 @@ void * ggml_backend_reg_get_proc_address(ggml_backend_reg_t reg, const char * na
#include "ggml-metal.h"
#endif

+#ifdef GGML_USE_SYCL
+#include "ggml-sycl.h"
+#endif
+
+#ifdef GGML_USE_VULKAN
+#include "ggml-vulkan.h"
+#endif
+
#ifdef GGML_USE_BLAS
#include "ggml-blas.h"
#endif
@ -546,6 +550,22 @@ void * ggml_backend_reg_get_proc_address(ggml_backend_reg_t reg, const char * na
#include "ggml-rpc.h"
#endif

+#ifndef __AMX_INT8__
+#undef GGML_USE_AMX
+#endif
+
+#ifdef GGML_USE_AMX
+#  include "ggml-amx.h"
+#endif
+
+#ifdef GGML_USE_CANN
+#include "ggml-cann.h"
+#endif
+
+#ifdef GGML_USE_KOMPUTE
+#include "ggml-kompute.h"
+#endif
+
struct ggml_backend_registry {
    std::vector<ggml_backend_reg_t> backends;
    std::vector<ggml_backend_dev_t> devices;
@ -557,14 +577,27 @@ struct ggml_backend_registry {
#ifdef GGML_USE_METAL
        register_backend(ggml_backend_metal_reg());
#endif
+#ifdef GGML_USE_SYCL
+        register_backend(ggml_backend_sycl_reg());
+#endif
+#ifdef GGML_USE_VULKAN
+        register_backend(ggml_backend_vk_reg());
+#endif
+#ifdef GGML_USE_CANN
+        register_backend(ggml_backend_cann_reg());
+#endif
#ifdef GGML_USE_BLAS
        register_backend(ggml_backend_blas_reg());
#endif
#ifdef GGML_USE_RPC
        register_backend(ggml_backend_rpc_reg());
#endif
-        // TODO: sycl, vulkan, kompute, cann
+#ifdef GGML_USE_AMX
+        register_backend(ggml_backend_amx_reg());
+#endif
+#ifdef GGML_USE_KOMPUTE
+        register_backend(ggml_backend_kompute_reg());
+#endif

        register_backend(ggml_backend_cpu_reg());
    }
@ -670,9 +703,9 @@ ggml_backend_t ggml_backend_init_by_type(enum ggml_backend_dev_type type, const
}

ggml_backend_t ggml_backend_init_best(void) {
-    ggml_backend_dev_t dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_GPU_FULL);
+    ggml_backend_dev_t dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_GPU);
    if (!dev) {
-        dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU_FULL);
+        dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
    }
    if (!dev) {
        return NULL;
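With the renamed device types above, ggml_backend_init_best() still picks the first GPU device if one was registered and otherwise falls back to the CPU device. A hypothetical caller sketch:

```cpp
// Pick whatever the registry considers the best available device.
ggml_backend_t backend = ggml_backend_init_best();
if (backend == NULL) {
    // no usable device was registered at all
}
```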
@ -680,15 +713,7 @@ ggml_backend_t ggml_backend_init_best(void) {
    return ggml_backend_dev_init(dev, NULL);
}

-// backend CPU
+// CPU backend - buffer

-static const size_t TENSOR_ALIGNMENT = 32; // required for mmap as gguf only guarantees 32-byte alignment
-
-static const char * ggml_backend_cpu_buffer_get_name(ggml_backend_buffer_t buffer) {
-    return "CPU";
-
-    GGML_UNUSED(buffer);
-}
-
static void * ggml_backend_cpu_buffer_get_base(ggml_backend_buffer_t buffer) {
    uintptr_t data = (uintptr_t)buffer->context;
@ -702,7 +727,7 @@ static void * ggml_backend_cpu_buffer_get_base(ggml_backend_buffer_t buffer) {
}

static void ggml_backend_cpu_buffer_free_buffer(ggml_backend_buffer_t buffer) {
-    free(buffer->context);
+    ggml_aligned_free(buffer->context, buffer->size);
}

static void ggml_backend_cpu_buffer_memset_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
@ -738,7 +763,6 @@ static void ggml_backend_cpu_buffer_clear(ggml_backend_buffer_t buffer, uint8_t
}

static const struct ggml_backend_buffer_i ggml_backend_cpu_buffer_i = {
-    /* .get_name        = */ ggml_backend_cpu_buffer_get_name,
    /* .free_buffer     = */ ggml_backend_cpu_buffer_free_buffer,
    /* .get_base        = */ ggml_backend_cpu_buffer_get_base,
    /* .init_tensor     = */ NULL, // no initialization required
@ -751,7 +775,6 @@ static const struct ggml_backend_buffer_i ggml_backend_cpu_buffer_i = {
};

static const struct ggml_backend_buffer_i ggml_backend_cpu_buffer_from_ptr_i = {
-    /* .get_name        = */ ggml_backend_cpu_buffer_get_name,
    /* .free_buffer     = */ NULL, // ptr is not owned by the buffer, so it does not need to be freed
    /* .get_base        = */ ggml_backend_cpu_buffer_get_base,
    /* .init_tensor     = */ NULL, // no initialization required
@ -763,6 +786,8 @@ static const struct ggml_backend_buffer_i ggml_backend_cpu_buffer_from_ptr_i = {
    /* .reset           = */ NULL,
};

+// CPU backend - buffer type
+
static const char * ggml_backend_cpu_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
    return "CPU";

@ -770,8 +795,8 @@ static const char * ggml_backend_cpu_buffer_type_get_name(ggml_backend_buffer_ty
}

static ggml_backend_buffer_t ggml_backend_cpu_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
-    size += TENSOR_ALIGNMENT;   // malloc may return an address that is not aligned
-    void * data = malloc(size); // TODO: use GGML_ALIGNED_MALLOC (move to ggml-impl.h)
+    void * data = ggml_aligned_malloc(size);
+
    if (data == NULL) {
        GGML_LOG_ERROR("%s: failed to allocate buffer of size %zu\n", __func__, size);
        return NULL;
@ -809,6 +834,29 @@ ggml_backend_buffer_type_t ggml_backend_cpu_buffer_type(void) {
    return &ggml_backend_cpu_buffer_type;
}

+static const char * ggml_backend_cpu_buffer_from_ptr_type_get_name(ggml_backend_buffer_type_t buft) {
+    return "CPU_Mapped";
+
+    GGML_UNUSED(buft);
+}
+
+static ggml_backend_buffer_type_t ggml_backend_cpu_buffer_from_ptr_type(void) {
+    static struct ggml_backend_buffer_type ggml_backend_cpu_buffer_type = {
+        /* .iface   = */ {
+            /* .get_name       = */ ggml_backend_cpu_buffer_from_ptr_type_get_name,
+            /* .alloc_buffer   = */ ggml_backend_cpu_buffer_type_alloc_buffer,
+            /* .get_alignment  = */ ggml_backend_cpu_buffer_type_get_alignment,
+            /* .get_max_size   = */ NULL, // defaults to SIZE_MAX
+            /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes
+            /* .is_host        = */ ggml_backend_cpu_buffer_type_is_host,
+        },
+        /* .device  = */ ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0),
+        /* .context = */ NULL,
+    };
+
+    return &ggml_backend_cpu_buffer_type;
+}
+
#ifdef GGML_USE_CPU_HBM

// buffer type HBM
@ -821,18 +869,11 @@ static const char * ggml_backend_cpu_hbm_buffer_type_get_name(ggml_backend_buffe
    GGML_UNUSED(buft);
}

-static const char * ggml_backend_cpu_hbm_buffer_get_name(ggml_backend_buffer_t buf) {
-    return "CPU_HBM";
-
-    GGML_UNUSED(buf);
-}
-
static void ggml_backend_cpu_hbm_buffer_free_buffer(ggml_backend_buffer_t buffer) {
    hbw_free(buffer->context);
}

static ggml_backend_buffer_t ggml_backend_cpu_hbm_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
-    //void * ptr = hbw_malloc(size);
    void * ptr;
    int result = hbw_posix_memalign(&ptr, ggml_backend_cpu_buffer_type_get_alignment(buft), size);
    if (result != 0) {
@ -842,7 +883,6 @@ static ggml_backend_buffer_t ggml_backend_cpu_hbm_buffer_type_alloc_buffer(ggml_

    ggml_backend_buffer_t buffer = ggml_backend_cpu_buffer_from_ptr(ptr, size);
    buffer->buft = buft;
-    buffer->iface.get_name = ggml_backend_cpu_hbm_buffer_get_name;
    buffer->iface.free_buffer = ggml_backend_cpu_hbm_buffer_free_buffer;

    return buffer;
@ -865,6 +905,21 @@ ggml_backend_buffer_type_t ggml_backend_cpu_hbm_buffer_type(void) {
}
#endif

+static ggml_backend_buffer_type_t * ggml_backend_cpu_get_extra_bufts(ggml_backend_dev_t device) {
+    static ggml_backend_buffer_type_t bufts[] = {
+#ifdef GGML_USE_CPU_HBM
+        ggml_backend_cpu_hbm_buffer_type(),
+#endif
+        NULL
+    };
+
+    return bufts;
+
+    GGML_UNUSED(device);
+}
+
+// CPU backend - backend (stream)
+
struct ggml_backend_cpu_context {
    int n_threads;
    ggml_threadpool_t threadpool;
@ -889,12 +944,6 @@ static void ggml_backend_cpu_free(ggml_backend_t backend) {
    delete backend;
}

-static ggml_backend_buffer_type_t ggml_backend_cpu_get_default_buffer_type(ggml_backend_t backend) {
-    return ggml_backend_cpu_buffer_type();
-
-    GGML_UNUSED(backend);
-}
-
struct ggml_backend_plan_cpu {
    struct ggml_cplan cplan;
    struct ggml_cgraph cgraph;
@ -964,7 +1013,6 @@ static enum ggml_status ggml_backend_cpu_graph_compute(ggml_backend_t backend, s
static const struct ggml_backend_i ggml_backend_cpu_i = {
    /* .get_name                = */ ggml_backend_cpu_get_name,
    /* .free                    = */ ggml_backend_cpu_free,
-    /* .get_default_buffer_type = */ ggml_backend_cpu_get_default_buffer_type,
    /* .set_tensor_async        = */ NULL,
    /* .get_tensor_async        = */ NULL,
    /* .cpy_tensor_async        = */ NULL,
@ -974,9 +1022,6 @@ static const struct ggml_backend_i ggml_backend_cpu_i = {
    /* .graph_plan_update       = */ NULL,
    /* .graph_plan_compute      = */ ggml_backend_cpu_graph_plan_compute,
    /* .graph_compute           = */ ggml_backend_cpu_graph_compute,
-    /* .supports_op             = */ NULL,
-    /* .supports_buft           = */ NULL,
-    /* .offload_op              = */ NULL,
    /* .event_record            = */ NULL,
    /* .event_wait              = */ NULL,
};
@ -1047,10 +1092,10 @@ void ggml_backend_cpu_set_abort_callback(ggml_backend_t backend_cpu, ggml_abort_

ggml_backend_buffer_t ggml_backend_cpu_buffer_from_ptr(void * ptr, size_t size) {
    GGML_ASSERT((uintptr_t)ptr % TENSOR_ALIGNMENT == 0 && "buffer pointer must be aligned");
-    return ggml_backend_buffer_init(ggml_backend_cpu_buffer_type(), ggml_backend_cpu_buffer_from_ptr_i, ptr, size);
+    return ggml_backend_buffer_init(ggml_backend_cpu_buffer_from_ptr_type(), ggml_backend_cpu_buffer_from_ptr_i, ptr, size);
}

-////////////////////////
+// CPU backend - device

struct ggml_backend_cpu_device_context {
|
||||||
std::string description = "CPU";
|
std::string description = "CPU";
|
||||||
@ -1137,7 +1182,7 @@ static void ggml_backend_cpu_device_get_memory(ggml_backend_dev_t dev, size_t *
|
|||||||
}
|
}
|
||||||
|
|
||||||
static enum ggml_backend_dev_type ggml_backend_cpu_device_get_type(ggml_backend_dev_t dev) {
|
static enum ggml_backend_dev_type ggml_backend_cpu_device_get_type(ggml_backend_dev_t dev) {
|
||||||
return GGML_BACKEND_DEVICE_TYPE_CPU_FULL;
|
return GGML_BACKEND_DEVICE_TYPE_CPU;
|
||||||
|
|
||||||
GGML_UNUSED(dev);
|
GGML_UNUSED(dev);
|
||||||
}
|
}
|
||||||
@ -1155,7 +1200,7 @@ static void ggml_backend_cpu_device_get_props(ggml_backend_dev_t dev, struct ggm
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
static ggml_backend_t ggml_backend_cpu_device_init(ggml_backend_dev_t dev, const char * params) {
|
static ggml_backend_t ggml_backend_cpu_device_init_backend(ggml_backend_dev_t dev, const char * params) {
|
||||||
return ggml_backend_cpu_init();
|
return ggml_backend_cpu_init();
|
||||||
|
|
||||||
GGML_UNUSED(dev);
|
GGML_UNUSED(dev);
|
||||||
@ -1168,7 +1213,7 @@ static ggml_backend_buffer_type_t ggml_backend_cpu_device_get_buffer_type(ggml_b
|
|||||||
GGML_UNUSED(dev);
|
GGML_UNUSED(dev);
|
||||||
}
|
}
|
||||||
|
|
||||||
static ggml_backend_buffer_t ggml_backend_cpu_device_buffer_from_ptr(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) {
|
static ggml_backend_buffer_t ggml_backend_cpu_device_buffer_from_host_ptr(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) {
|
||||||
return ggml_backend_cpu_buffer_from_ptr(ptr, size);
|
return ggml_backend_cpu_buffer_from_ptr(ptr, size);
|
||||||
|
|
||||||
GGML_UNUSED(dev);
|
GGML_UNUSED(dev);
|
||||||
@ -1210,10 +1255,10 @@ static const struct ggml_backend_device_i ggml_backend_cpu_device_i = {
|
|||||||
/* .get_memory = */ ggml_backend_cpu_device_get_memory,
|
/* .get_memory = */ ggml_backend_cpu_device_get_memory,
|
||||||
/* .get_type = */ ggml_backend_cpu_device_get_type,
|
/* .get_type = */ ggml_backend_cpu_device_get_type,
|
||||||
/* .get_props = */ ggml_backend_cpu_device_get_props,
|
/* .get_props = */ ggml_backend_cpu_device_get_props,
|
||||||
/* .init_backend = */ ggml_backend_cpu_device_init,
|
/* .init_backend = */ ggml_backend_cpu_device_init_backend,
|
||||||
/* .get_buffer_type = */ ggml_backend_cpu_device_get_buffer_type,
|
/* .get_buffer_type = */ ggml_backend_cpu_device_get_buffer_type,
|
||||||
/* .get_host_buffer_type = */ NULL,
|
/* .get_host_buffer_type = */ NULL,
|
||||||
/* .buffer_from_host_ptr = */ ggml_backend_cpu_device_buffer_from_ptr,
|
/* .buffer_from_host_ptr = */ ggml_backend_cpu_device_buffer_from_host_ptr,
|
||||||
/* .supports_op = */ ggml_backend_cpu_device_supports_op,
|
/* .supports_op = */ ggml_backend_cpu_device_supports_op,
|
||||||
/* .supports_buft = */ ggml_backend_cpu_device_supports_buft,
|
/* .supports_buft = */ ggml_backend_cpu_device_supports_buft,
|
||||||
/* .offload_op = */ NULL,
|
/* .offload_op = */ NULL,
|
||||||
@ -1222,7 +1267,7 @@ static const struct ggml_backend_device_i ggml_backend_cpu_device_i = {
|
|||||||
/* .event_synchronize = */ NULL,
|
/* .event_synchronize = */ NULL,
|
||||||
};
|
};
|
||||||
|
|
||||||
////////////////////////
|
// CPU backend - backend (reg)
|
||||||
|
|
||||||
static const char * ggml_backend_cpu_reg_get_name(ggml_backend_reg_t reg) {
|
static const char * ggml_backend_cpu_reg_get_name(ggml_backend_reg_t reg) {
|
||||||
return "CPU";
|
return "CPU";
|
||||||
@ -1253,6 +1298,10 @@ static void * ggml_backend_cpu_get_proc_address(ggml_backend_reg_t reg, const ch
|
|||||||
if (strcmp(name, "ggml_backend_set_n_threads") == 0) {
|
if (strcmp(name, "ggml_backend_set_n_threads") == 0) {
|
||||||
return (void *)ggml_backend_cpu_set_n_threads;
|
return (void *)ggml_backend_cpu_set_n_threads;
|
||||||
}
|
}
|
||||||
|
if (strcmp(name, "ggml_backend_dev_get_extra_bufts") == 0) {
|
||||||
|
return (void *)ggml_backend_cpu_get_extra_bufts;
|
||||||
|
}
|
||||||
|
|
||||||
return NULL;
|
return NULL;
|
||||||
|
|
||||||
GGML_UNUSED(reg);
|
GGML_UNUSED(reg);
|
||||||
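
The hunk above registers a second named entry point, "ggml_backend_dev_get_extra_bufts", next to the existing "ggml_backend_set_n_threads" one. A minimal sketch of how a caller could pick it up through the registry — the cast type mirrors the signature of ggml_backend_cpu_get_extra_bufts shown above, and it assumes ggml_backend_reg_get_proc_address() and ggml_backend_buft_name() behave as declared in ggml-backend.h at this revision:

    #include <cstdio>

    // hypothetical helper, not part of this commit
    typedef ggml_backend_buffer_type_t * (*ggml_backend_dev_get_extra_bufts_t)(ggml_backend_dev_t device);

    static void print_cpu_extra_bufts(void) {
        ggml_backend_reg_t reg = ggml_backend_cpu_reg();
        ggml_backend_dev_get_extra_bufts_t fn =
            (ggml_backend_dev_get_extra_bufts_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_dev_get_extra_bufts");
        if (fn == NULL) {
            return; // this backend does not expose extra buffer types
        }
        ggml_backend_dev_t dev = ggml_backend_reg_dev_get(reg, 0);
        // the array returned above is NULL-terminated
        for (ggml_backend_buffer_type_t * bufts = fn(dev); bufts != NULL && *bufts != NULL; bufts++) {
            printf("extra buffer type: %s\n", ggml_backend_buft_name(*bufts));
        }
    }

With the CPU backend the list is empty unless GGML_USE_CPU_HBM is defined, in which case it contains the HBM buffer type followed by the NULL terminator.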
@@ -1281,12 +1330,6 @@ struct ggml_backend_multi_buffer_context {
     size_t n_buffers;
 };
 
-static const char * ggml_backend_multi_buffer_get_name(ggml_backend_buffer_t buffer) {
-    ggml_backend_multi_buffer_context * ctx = (ggml_backend_multi_buffer_context *) buffer->context;
-
-    return ctx->buffers[0]->iface.get_name(ctx->buffers[0]);
-}
-
 static void ggml_backend_multi_buffer_free_buffer(ggml_backend_buffer_t buffer) {
     ggml_backend_multi_buffer_context * ctx = (ggml_backend_multi_buffer_context *) buffer->context;
     for (size_t i = 0; i < ctx->n_buffers; i++) {
@@ -1305,7 +1348,6 @@ static void ggml_backend_multi_buffer_clear(ggml_backend_buffer_t buffer, uint8_
 }
 
 static const struct ggml_backend_buffer_i ggml_backend_multi_buffer_i = {
-    /* .get_name = */ ggml_backend_multi_buffer_get_name,
     /* .free_buffer = */ ggml_backend_multi_buffer_free_buffer,
     /* .get_base = */ NULL,
     /* .init_tensor = */ NULL,
@@ -1334,7 +1376,7 @@ ggml_backend_buffer_t ggml_backend_multi_buffer_alloc_buffer(ggml_backend_buffer
 }
 
 bool ggml_backend_buffer_is_multi_buffer(ggml_backend_buffer_t buffer) {
-    return buffer->iface.get_name == ggml_backend_multi_buffer_get_name;
+    return buffer->iface.free_buffer == ggml_backend_multi_buffer_free_buffer;
 }
 
 void ggml_backend_multi_buffer_set_usage(ggml_backend_buffer_t buffer, enum ggml_backend_buffer_usage usage) {
@@ -1426,7 +1468,7 @@ struct ggml_backend_sched {
     char * context_buffer;
     size_t context_buffer_size;
 
-    bool debug;
+    int debug;
 };
 
 #define hash_id(tensor) ggml_hash_find_or_insert(&sched->hash_set, tensor)
@@ -1514,7 +1556,9 @@ static int ggml_backend_sched_backend_id_from_cur(ggml_backend_sched_t sched, st
         if (src == NULL) {
             continue;
         }
-        if (src->buffer != NULL && src->buffer->usage == GGML_BACKEND_BUFFER_USAGE_WEIGHTS) {
+        // skip ROPE since the rope freqs tensor is too small to choose a backend based on it
+        // not an ideal solution
+        if (tensor->op != GGML_OP_ROPE && src->buffer != NULL && src->buffer->usage == GGML_BACKEND_BUFFER_USAGE_WEIGHTS) {
             int src_backend_id = ggml_backend_sched_backend_from_buffer(sched, src, tensor);
             // check if a backend with higher prio wants to offload the op
             if (src_backend_id == sched->n_backends - 1) {
@@ -1561,6 +1605,7 @@ static void ggml_backend_sched_print_assignments(ggml_backend_sched_t sched, str
         if (ggml_is_view_op(node->op)) {
             continue;
         }
+        if (sched->debug > 1) {
         ggml_backend_t tensor_backend = ggml_backend_sched_get_tensor_backend(sched, node);
         GGML_LOG_DEBUG("node #%3d (%10.10s): %20.20s (%5.5s) [%5.5s %8.8s]:", i, ggml_op_name(node->op), node->name,
             fmt_size(ggml_nbytes(node)), tensor_backend ? ggml_backend_name(tensor_backend) : "NULL", GET_CAUSE(node));
@@ -1575,6 +1620,7 @@ static void ggml_backend_sched_print_assignments(ggml_backend_sched_t sched, str
             }
             GGML_LOG_DEBUG("\n");
         }
+        }
 }
 
 static bool ggml_backend_sched_buffer_supported(ggml_backend_sched_t sched, struct ggml_tensor * t, int backend_id) {
@@ -1865,11 +1911,11 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
             if (src == NULL) {
                 continue;
             }
-            // check if a weight is on a different backend
+            // check if a weight is on a different and incompatible backend
             // by starting a new split, the memory of the previously offloaded weights can be reused
             if (src->buffer != NULL && src->buffer->usage == GGML_BACKEND_BUFFER_USAGE_WEIGHTS) {
                 int src_backend_id = tensor_backend_id(src);
-                if (src_backend_id != cur_backend_id) {
+                if (src_backend_id != cur_backend_id && !ggml_backend_sched_buffer_supported(sched, src, cur_backend_id)) {
                     need_new_split = true;
                     break;
                 }
@@ -1881,7 +1927,6 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
                     int src_backend_id = sched->hv_tensor_backend_ids[id];
                     bool supported = ggml_backend_sched_buffer_supported(sched, src, cur_backend_id);
                     if (src_backend_id != cur_backend_id && tensor_id_copy(id, cur_backend_id, 0) == NULL && !supported) {
-                        //printf("starting new split because of too many inputs: node %s, input %s\n", node->name, src->name);
                         need_new_split = true;
                         break;
                     }
@@ -2206,7 +2251,8 @@ ggml_backend_sched_t ggml_backend_sched_new(
 
     struct ggml_backend_sched * sched = (ggml_backend_sched *) calloc(1, sizeof(struct ggml_backend_sched));
 
-    sched->debug = getenv("GGML_SCHED_DEBUG") != NULL;
+    const char * GGML_SCHED_DEBUG = getenv("GGML_SCHED_DEBUG");
+    sched->debug = GGML_SCHED_DEBUG ? atoi(GGML_SCHED_DEBUG) : 0;
     sched->n_backends = n_backends;
     sched->n_copies = parallel ? GGML_SCHED_MAX_COPIES : 1;
 
@@ -2234,6 +2280,7 @@ ggml_backend_sched_t ggml_backend_sched_new(
         sched->backends[b] = backends[b];
         sched->bufts[b] = bufts ? bufts[b] : ggml_backend_get_default_buffer_type(backends[b]);
         GGML_ASSERT(ggml_backend_supports_buft(backends[b], sched->bufts[b]));
+
         if (sched->n_copies > 1) {
             for (int c = 0; c < sched->n_copies; c++) {
                 sched->events[b][c] = ggml_backend_event_new(backends[b]->device);
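
The scheduler debug flag changes from a boolean to an integer level in the hunks above: GGML_SCHED_DEBUG is now parsed with atoi(), and the per-node dump in ggml_backend_sched_print_assignments is gated behind sched->debug > 1. A small standalone sketch of the same parsing; the level meanings are inferred from the hunks above, not documented elsewhere in this diff:

    #include <cstdlib>

    // 0 (unset)           : no scheduler debug output
    // GGML_SCHED_DEBUG=1  : split/assignment summary only
    // GGML_SCHED_DEBUG=2+ : additionally the per-node dump guarded by `sched->debug > 1`
    static int ggml_sched_debug_level(void) {
        const char * env = std::getenv("GGML_SCHED_DEBUG");
        return env ? std::atoi(env) : 0;
    }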
@@ -224,12 +224,6 @@ static void ggml_backend_blas_free(ggml_backend_t backend) {
     delete backend;
 }
 
-static ggml_backend_buffer_type_t ggml_backend_blas_get_default_buffer_type(ggml_backend_t backend) {
-    return ggml_backend_cpu_buffer_type();
-
-    GGML_UNUSED(backend);
-}
-
 static enum ggml_status ggml_backend_blas_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
     ggml_backend_blas_context * ctx = (ggml_backend_blas_context *)backend->context;
 
@@ -265,7 +259,6 @@ static enum ggml_status ggml_backend_blas_graph_compute(ggml_backend_t backend,
 static struct ggml_backend_i blas_backend_i = {
     /* .get_name = */ ggml_backend_blas_get_name,
     /* .free = */ ggml_backend_blas_free,
-    /* .get_default_buffer_type = */ ggml_backend_blas_get_default_buffer_type,
     /* .set_tensor_async = */ NULL,
     /* .get_tensor_async = */ NULL,
     /* .cpy_tensor_async = */ NULL,
@@ -275,9 +268,6 @@ static struct ggml_backend_i blas_backend_i = {
     /* .graph_plan_update = */ NULL,
     /* .graph_plan_compute = */ NULL,
     /* .graph_compute = */ ggml_backend_blas_graph_compute,
-    /* .supports_op = */ NULL,
-    /* .supports_buft = */ NULL,
-    /* .offload_op = */ NULL,
     /* .event_record = */ NULL,
     /* .event_wait = */ NULL,
 };
@@ -356,7 +346,7 @@ static void ggml_backend_blas_device_get_memory(ggml_backend_dev_t dev, size_t *
 }
 
 static enum ggml_backend_dev_type ggml_backend_blas_device_get_type(ggml_backend_dev_t dev) {
-    return GGML_BACKEND_DEVICE_TYPE_CPU;
+    return GGML_BACKEND_DEVICE_TYPE_ACCEL;
 
     GGML_UNUSED(dev);
 }
@@ -374,7 +364,7 @@ static void ggml_backend_blas_device_get_props(ggml_backend_dev_t dev, struct gg
     };
 }
 
-static ggml_backend_t ggml_backend_blas_device_init(ggml_backend_dev_t dev, const char * params) {
+static ggml_backend_t ggml_backend_blas_device_init_backend(ggml_backend_dev_t dev, const char * params) {
     return ggml_backend_blas_init();
 
     GGML_UNUSED(dev);
@@ -387,7 +377,7 @@ static ggml_backend_buffer_type_t ggml_backend_blas_device_get_buffer_type(ggml_
     GGML_UNUSED(dev);
 }
 
-static ggml_backend_buffer_t ggml_backend_blas_device_buffer_from_ptr(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) {
+static ggml_backend_buffer_t ggml_backend_blas_device_buffer_from_host_ptr(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) {
     return ggml_backend_cpu_buffer_from_ptr(ptr, size);
 
     GGML_UNUSED(dev);
@@ -456,10 +446,10 @@ static const struct ggml_backend_device_i ggml_backend_blas_device_i = {
     /* .get_memory = */ ggml_backend_blas_device_get_memory,
     /* .get_type = */ ggml_backend_blas_device_get_type,
     /* .get_props = */ ggml_backend_blas_device_get_props,
-    /* .init_backend = */ ggml_backend_blas_device_init,
+    /* .init_backend = */ ggml_backend_blas_device_init_backend,
     /* .get_buffer_type = */ ggml_backend_blas_device_get_buffer_type,
     /* .get_host_buffer_type = */ NULL,
-    /* .buffer_from_host_ptr = */ ggml_backend_blas_device_buffer_from_ptr,
+    /* .buffer_from_host_ptr = */ ggml_backend_blas_device_buffer_from_host_ptr,
     /* .supports_op = */ ggml_backend_blas_device_supports_op,
     /* .supports_buft = */ ggml_backend_blas_device_supports_buft,
     /* .offload_op = */ NULL,
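
In the BLAS hunks above the device reports itself as GGML_BACKEND_DEVICE_TYPE_ACCEL instead of a CPU device, and its buffer_from_host_ptr hook — like the CPU device's — forwards to ggml_backend_cpu_buffer_from_ptr(). A minimal sketch of wrapping an existing host allocation that way; posix_memalign and the 64-byte figure are illustrative assumptions, the only hard requirement visible above being the TENSOR_ALIGNMENT assert inside ggml_backend_cpu_buffer_from_ptr():

    #include <cstdlib>

    static ggml_backend_buffer_t wrap_host_block(size_t n_bytes) {
        void * data = NULL;
        if (posix_memalign(&data, 64, n_bytes) != 0) { // assumed alignment, large enough for TENSOR_ALIGNMENT
            return NULL;
        }
        // the pointer is adopted, not copied; it must stay valid for the buffer's lifetime
        return ggml_backend_cpu_buffer_from_ptr(data, n_bytes);
    }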
@@ -39,6 +39,8 @@
 
 #include "ggml-common.h"
 
+#define GGML_CANN_NAME "CANN"
+
 /**
  * @brief Handles CANN errors by printing an error message and aborting.
  *
@@ -487,23 +489,6 @@ struct ggml_backend_cann_buffer_context {
     ~ggml_backend_cann_buffer_context() { ACL_CHECK(aclrtFree(dev_ptr)); }
 };
 
-/**
- * @brief Retrieve the name associated with a CANN buffer.
- *
- * This function returns the name of a CANN buffer, which is stored in the
- * context of the buffer.
- *
- * @param buffer The CANN buffer whose name is to be retrieved.
- * @return A pointer to a C-string containing the name of the buffer.
- */
-
-static const char* ggml_backend_cann_buffer_get_name(
-    ggml_backend_buffer_t buffer) {
-    return "CANN";
-
-    GGML_UNUSED(buffer);
-}
-
 /**
  * @brief Check if a buffer is a CANN buffer.
  *
@@ -513,9 +498,10 @@ static const char* ggml_backend_cann_buffer_get_name(
  * @param buffer The buffer to check.
  * @return true if the buffer is a CANN buffer, false otherwise.
  */
+static bool ggml_backend_buft_is_cann(ggml_backend_buffer_type_t buft);
 static bool ggml_backend_buffer_is_cann(
     ggml_backend_buffer_t buffer) {
-    return buffer->iface.get_name == ggml_backend_cann_buffer_get_name;
+    return ggml_backend_buft_is_cann(buffer->buft);
 }
 
 /**
@@ -851,13 +837,6 @@ static void ggml_backend_cann_buffer_set_tensor(
         void *transform_buffer = malloc(size);
         ggml_backend_cann_transform(tensor, data, transform_buffer);
 
-#ifndef NDEBUG
-        void *check_buffer = malloc(size);
-        ggml_backend_cann_transform_back(tensor, transform_buffer,
-                                         check_buffer);
-        GGML_ASSERT(memcmp(data, check_buffer, size) == 0);
-        free(check_buffer);
-#endif
         ACL_CHECK(aclrtMemcpy((char *)tensor->data + offset, size,
                               transform_buffer, size,
                               ACL_MEMCPY_HOST_TO_DEVICE));
@@ -969,8 +948,7 @@ static void ggml_backend_cann_buffer_clear(
  * This structure defines function pointers to operations that can be performed
  * on a CANN buffer within the backend.
  */
-static ggml_backend_buffer_i ggml_backend_cann_buffer_interface = {
-    /* .get_name = */ ggml_backend_cann_buffer_get_name,
+static const ggml_backend_buffer_i ggml_backend_cann_buffer_interface = {
     /* .free_buffer = */ ggml_backend_cann_buffer_free_buffer,
     /* .get_base = */ ggml_backend_cann_buffer_get_base,
     /* .init_tensor = */ ggml_backend_cann_buffer_init_tensor,
@@ -1004,9 +982,10 @@ struct ggml_backend_cann_buffer_type_context {
  */
 static const char* ggml_backend_cann_buffer_type_name(
     ggml_backend_buffer_type_t buft) {
-    return "CANN";
+    ggml_backend_cann_buffer_type_context* buft_ctx =
+        (ggml_backend_cann_buffer_type_context*)buft->context;
 
-    GGML_UNUSED(buft);
+    return buft_ctx->name.c_str();
 }
 
 /**
@@ -1105,19 +1084,25 @@ static size_t ggml_backend_cann_buffer_type_get_alloc_size(
     GGML_UNUSED(buft);
 }
 
+static bool ggml_backend_cann_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
+    return false;
+
+    GGML_UNUSED(buft);
+}
+
 /**
  * @brief Interface for managing CANN buffer types in the GGML backend.
  *
 * Provides function pointers for allocating, querying properties, and managing
 * memory for CANN buffer types in the GGML backend.
 */
-static ggml_backend_buffer_type_i ggml_backend_cann_buffer_type_interface = {
+static const ggml_backend_buffer_type_i ggml_backend_cann_buffer_type_interface = {
     /* .get_name = */ ggml_backend_cann_buffer_type_name,
     /* .alloc_buffer = */ ggml_backend_cann_buffer_type_alloc_buffer,
     /* .get_alignment = */ ggml_backend_cann_buffer_type_get_alignment,
     /* .get_max_size = */ NULL, // defaults to SIZE_MAX
     /* .get_alloc_size = */ ggml_backend_cann_buffer_type_get_alloc_size,
-    /* .is_host = */ NULL,
+    /* .is_host = */ ggml_backend_cann_buffer_type_is_host,
 };
 
 /**
@@ -1148,6 +1133,7 @@ ggml_backend_cann_buffer_type(int32_t device) {
         for (int32_t i = 0; i < GGML_CANN_MAX_DEVICES; i++) {
             ggml_backend_cann_buffer_types[i] = {
                 /* .iface = */ ggml_backend_cann_buffer_type_interface,
+                /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cann_reg(), device),
                 /* .context = */
                     new ggml_backend_cann_buffer_type_context{
                         i, "CANN" + std::to_string(i)},
@@ -1263,7 +1249,7 @@ ggml_backend_buffer_type_t ggml_backend_cann_host_buffer_type() {
             /* .get_alloc_size = */ ggml_backend_cpu_buffer_type()->iface.get_alloc_size,
             /* .is_host = */ ggml_backend_cpu_buffer_type()->iface.is_host,
         },
-        /* .device = */ nullptr,
+        /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cann_reg(), 0),
         /* .context = */ nullptr,
     };
 
@@ -1463,24 +1449,6 @@ static void ggml_backend_cann_free(ggml_backend_t backend) {
     delete backend;
 }
 
-/**
- * @brief Retrieves the default buffer type associated with the CANN backend.
- *
- * This function returns the buffer type specific to the device associated
- * with the CANN backend. It is used to allocate buffers for computations
- * performed by the backend.
- *
- * @param backend Pointer to the CANN backend structure.
- * @return Pointer to the buffer type structure for the CANN backend.
- */
-static ggml_backend_buffer_type_t
-ggml_backend_cann_get_default_buffer_type(ggml_backend_t backend) {
-    ggml_backend_cann_context* cann_ctx =
-        (ggml_backend_cann_context*)backend->context;
-
-    return ggml_backend_cann_buffer_type(cann_ctx->device);
-}
-
 /**
  * @brief Sets tensor data asynchronously in the CANN backend.
  *
@@ -1510,13 +1478,6 @@ static void ggml_backend_cann_set_tensor_async(ggml_backend_t backend,
         void *transform_buffer = malloc(size);
         ggml_backend_cann_transform(tensor, data, transform_buffer);
 
-#ifndef NDEBUG
-        void *check_buffer = malloc(size);
-        ggml_backend_cann_transform_back(tensor, transform_buffer,
-                                         check_buffer);
-        GGML_ASSERT(memcmp(data, check_buffer, size));
-        free(check_buffer);
-#endif
         ACL_CHECK(aclrtMemcpyAsync(
             (char *)tensor->data + offset, size, transform_buffer, size,
             ACL_MEMCPY_HOST_TO_DEVICE, cann_ctx->stream()));
@@ -1691,7 +1652,7 @@ static enum ggml_status ggml_backend_cann_graph_compute(
  * @return bool Returns true if the operation is supported by the backend,
  * otherwise false.
  */
-static bool ggml_backend_cann_supports_op(ggml_backend_t backend,
+static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
                                           const ggml_tensor* op) {
     switch (op->op) {
         case GGML_OP_UNARY:
@@ -1782,7 +1743,7 @@ static bool ggml_backend_cann_supports_op(ggml_backend_t backend,
             return false;
     }
 
-    GGML_UNUSED(backend);
+    GGML_UNUSED(dev);
 }
 
 /**
@@ -1800,31 +1761,6 @@ static bool ggml_backend_buft_is_cann(ggml_backend_buffer_type_t buft) {
     return buft->iface.get_name == ggml_backend_cann_buffer_type_name;
 }
 
-/**
- * @brief Checks if the CANN backend supports a specific backend buffer type.
- *
- * This function determines whether the CANN backend supports the given backend
- * buffer type by comparing the device context of the backend and buffer type.
- * It returns true if the devices are same between the backend context and
- * buffer type context.
- *
- * @param backend Pointer to the CANN backend.
- * @param buft Pointer to the backend buffer type to check.
- * @return bool Returns true if the CANN backend supports the buffer type,
- * otherwise false.
- */
-static bool ggml_backend_cann_supports_buft(
-    ggml_backend_t backend, ggml_backend_buffer_type_t buft) {
-    if (ggml_backend_buft_is_cann(buft)) {
-        ggml_backend_cann_context * cann_ctx =
-            (ggml_backend_cann_context *)backend->context;
-        ggml_backend_cann_buffer_type_context * buft_ctx =
-            (ggml_backend_cann_buffer_type_context *)buft->context;
-        return buft_ctx->device == cann_ctx->device;
-    }
-    return false;
-}
-
 /**
  * @brief Determines if a tensor operation should be offloaded to the CANN
  * backend.
@@ -1839,54 +1775,14 @@ static bool ggml_backend_cann_supports_buft(
  * @return bool Returns true if the operation should be offloaded, otherwise
  * false.
  */
-static bool ggml_backend_cann_offload_op(ggml_backend_t backend,
+static bool ggml_backend_cann_offload_op(ggml_backend_dev_t dev,
                                          const ggml_tensor* op) {
     const int min_batch_size = 32;
-    GGML_UNUSED(backend);
+    GGML_UNUSED(dev);
 
     return op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS;
 }
 
-/**
- * @brief Creates a new event for the CANN backend.
- *
- * This function initializes a new event for the CANN backend by setting the
- * device and creating an ACL runtime event. The created event is then wrapped
- * in a ggml_backend_event structure and returned.
- *
- * @param backend Pointer to the CANN backend.
- * @return ggml_backend_event_t Returns a pointer to the new event structure.
- */
-static ggml_backend_event_t ggml_backend_cann_event_new(
-    ggml_backend_t backend) {
-    ggml_backend_cann_context* cann_ctx =
-        (ggml_backend_cann_context*)backend->context;
-
-    ggml_cann_set_device(cann_ctx->device);
-
-    aclrtEvent event;
-    ACL_CHECK(aclrtCreateEvent(&event));
-
-    return new ggml_backend_event{
-        /* .backend = */ backend,
-        /* .context = */ event,
-    };
-}
-
-/**
- * @brief Frees a CANN backend event.
- *
- * This function destroys the ACL runtime event associated with the given CANN
- * backend event and then deletes the event structure itself.
- *
- * @param event Pointer to the event structure to be freed.
- */
-static void ggml_backend_cann_event_free(ggml_backend_event_t event) {
-    ACL_CHECK(aclrtDestroyEvent((aclrtEvent)event->context));
-
-    delete event;
-}
-
 /**
  * @brief Records an event on the CANN backend stream.
  *
@@ -1895,10 +1791,9 @@ static void ggml_backend_cann_event_free(ggml_backend_event_t event) {
  *
  * @param event Pointer to the event structure to be recorded.
  */
-static void ggml_backend_cann_event_record(ggml_backend_event_t event) {
+static void ggml_backend_cann_event_record(ggml_backend_t backend, ggml_backend_event_t event) {
     ggml_backend_cann_context* cann_ctx =
-        (ggml_backend_cann_context*)event->backend->context;
+        (ggml_backend_cann_context*)backend->context;
 
     ACL_CHECK(aclrtRecordEvent((aclrtEvent)event->context, cann_ctx->stream()));
 }
 
@@ -1916,8 +1811,7 @@ static void ggml_backend_cann_event_wait(ggml_backend_t backend,
                                          ggml_backend_event_t event) {
     ggml_backend_cann_context* cann_ctx =
         (ggml_backend_cann_context*)backend->context;
-    if (ggml_backend_is_cann(event->backend)) {
+    if (ggml_backend_is_cann(backend)) {
         ACL_CHECK(aclrtStreamWaitEvent(cann_ctx->stream(),
                                        (aclrtEvent)event->context));
     } else {
@@ -1925,17 +1819,6 @@ static void ggml_backend_cann_event_wait(ggml_backend_t backend,
     }
 }
 
-/**
- * @brief Synchronizes the given event on the CANN backend.
- *
- * This function waits for the specified event to complete on the ACL runtime.
- *
- * @param event Pointer to the event structure to be synchronized.
- */
-static void ggml_backend_cann_event_synchronize(ggml_backend_event_t event) {
-    ACL_CHECK(aclrtSynchronizeEvent((aclrtEvent)event->context));
-}
-
 /**
  * @brief Structure defining the interface for the CANN backend.
  *
@@ -1943,10 +1826,9 @@ static void ggml_backend_cann_event_synchronize(ggml_backend_event_t event) {
  * supported by the CANN backend, including name retrieval, memory
  * management, tensor operations, synchronization, and event handling.
  */
-static ggml_backend_i ggml_backend_cann_interface = {
+static const ggml_backend_i ggml_backend_cann_interface = {
     /* .get_name = */ ggml_backend_cann_name,
     /* .free = */ ggml_backend_cann_free,
-    /* .get_default_buffer_type = */ ggml_backend_cann_get_default_buffer_type,
     /* .set_tensor_async = */ ggml_backend_cann_set_tensor_async,
     /* .get_tensor_async = */ ggml_backend_cann_get_tensor_async,
     /* .cpy_tensor_async = */ ggml_backend_cann_cpy_tensor_async,
@@ -1956,9 +1838,6 @@ static ggml_backend_i ggml_backend_cann_interface = {
     /* .graph_plan_update = */ NULL,
     /* .graph_plan_compute = */ NULL,
     /* .graph_compute = */ ggml_backend_cann_graph_compute,
-    /* .supports_op = */ ggml_backend_cann_supports_op,
-    /* .supports_buft = */ ggml_backend_cann_supports_buft,
-    /* .offload_op = */ ggml_backend_cann_offload_op,
     /* .event_record = */ ggml_backend_cann_event_record,
     /* .event_wait = */ ggml_backend_cann_event_wait,
 };
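
supports_op, supports_buft and offload_op drop out of the CANN ggml_backend_i table above because they now live on the device interface (they reappear in ggml_backend_cann_device_interface further down, with ggml_backend_dev_t parameters). Callers therefore query the device rather than a backend instance. A sketch, assuming the public wrapper ggml_backend_dev_supports_op() forwards to the device vtable as declared in ggml-backend.h at this revision:

    // hypothetical helper, not part of this commit
    static bool cann_can_take(const struct ggml_tensor * node) {
        ggml_backend_dev_t dev = ggml_backend_reg_dev_get(ggml_backend_cann_reg(), 0);
        return ggml_backend_dev_supports_op(dev, node);
    }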
@@ -1977,6 +1856,234 @@ static ggml_guid_t ggml_backend_cann_guid() {
     return &guid;
 }
 
+// backend device
+struct ggml_backend_cann_device_context {
+    int device;
+    std::string name;
+    std::string description;
+};
+
+static const char * ggml_backend_cann_device_get_name(ggml_backend_dev_t dev) {
+    ggml_backend_cann_device_context * ctx = (ggml_backend_cann_device_context *)dev->context;
+    return ctx->name.c_str();
+}
+
+static const char* ggml_backend_cann_device_get_description(ggml_backend_dev_t dev) {
+    ggml_backend_cann_device_context * ctx = (ggml_backend_cann_device_context *)dev->context;
+    return ctx->description.c_str();
+}
+
+static void ggml_backend_cann_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
+    ggml_backend_cann_device_context * ctx = (ggml_backend_cann_device_context *)dev->context;
+    ggml_backend_cann_get_device_memory(ctx->device, free, total);
+}
+
+static enum ggml_backend_dev_type ggml_backend_cann_device_get_type(ggml_backend_dev_t dev) {
+    GGML_UNUSED(dev);
+    return GGML_BACKEND_DEVICE_TYPE_GPU;
+}
+
+static void ggml_backend_cann_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) {
+    props->name = ggml_backend_cann_device_get_name(dev);
+    props->description = ggml_backend_cann_device_get_description(dev);
+    props->type = ggml_backend_cann_device_get_type(dev);
+    ggml_backend_cann_device_get_memory(dev, &props->memory_free, &props->memory_total);
+
+    bool host_buffer = getenv("GGML_CANN_NO_PINNED") == nullptr;
+
+    props->caps = {
+        /* .async = */ false,
+        /* .host_buffer = */ host_buffer,
+        /* .buffer_from_host_ptr = */ false,
+        /* .events = */ true,
+    };
+}
+
+static ggml_backend_t ggml_backend_cann_device_init(ggml_backend_dev_t dev, const char * params) {
+    GGML_UNUSED(params);
+    ggml_backend_cann_device_context * ctx = (ggml_backend_cann_device_context *)dev->context;
+    return ggml_backend_cann_init(ctx->device);
+}
+
+/**
+ * @brief Checks if the CANN backend supports a specific backend buffer type.
+ *
+ * This function determines whether the CANN backend supports the given backend
+ * buffer type by comparing the device context of the backend and buffer type.
+ * It returns true if the devices are same between the backend context and
+ * buffer type context.
+ *
+ * @param backend Pointer to the CANN backend.
+ * @param buft Pointer to the backend buffer type to check.
+ * @return bool Returns true if the CANN backend supports the buffer type,
+ * otherwise false.
+ */
+static bool ggml_backend_cann_supports_buft(
+    ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
+    if (ggml_backend_buft_is_cann(buft)) {
+        ggml_backend_cann_device_context * dev_ctx = (ggml_backend_cann_device_context *)dev->context;
+        ggml_backend_cann_buffer_type_context * buft_ctx =
+            (ggml_backend_cann_buffer_type_context *)buft->context;
+        return buft_ctx->device == dev_ctx->device;
+    }
+    return false;
+}
+
+static ggml_backend_buffer_type_t ggml_backend_cann_device_get_buffer_type(ggml_backend_dev_t dev) {
+    ggml_backend_cann_device_context * ctx = (ggml_backend_cann_device_context *)dev->context;
+    return ggml_backend_cann_buffer_type(ctx->device);
+}
+
+static ggml_backend_buffer_type_t ggml_backend_cann_device_get_host_buffer_type(ggml_backend_dev_t dev) {
+    GGML_UNUSED(dev);
+    return ggml_backend_cann_host_buffer_type();
+}
+
+/**
+ * @brief Creates a new event for the CANN backend device.
+ *
+ * This function initializes a new event for the CANN backend by setting the
+ * device and creating an ACL runtime event. The created event is then wrapped
+ * in a ggml_backend_event structure and returned.
+ *
+ * @param backend Pointer to the CANN backend.
+ * @return ggml_backend_event_t Returns a pointer to the new event structure.
+ */
+static ggml_backend_event_t ggml_backend_cann_device_event_new(
+    ggml_backend_dev_t dev) {
+    ggml_backend_cann_device_context * dev_ctx = (ggml_backend_cann_device_context *)dev->context;
+
+    ggml_cann_set_device(dev_ctx->device);
+
+    aclrtEvent event;
+    ACL_CHECK(aclrtCreateEvent(&event));
+
+    return new ggml_backend_event{
+        /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cann_reg(), dev_ctx->device),
+        /* .context = */ event,
+    };
+}
+
+/**
+ * @brief Frees a CANN backend event.
+ *
+ * This function destroys the ACL runtime event associated with the given CANN
+ * backend event and then deletes the event structure itself.
+ *
+ * @param event Pointer to the event structure to be freed.
+ */
+static void ggml_backend_cann_device_event_free(ggml_backend_dev_t dev, ggml_backend_event_t event) {
+    ACL_CHECK(aclrtDestroyEvent((aclrtEvent)event->context));
+
+    delete event;
+    GGML_UNUSED(dev);
+}
+
+/**
+ * @brief Synchronizes the given event on the CANN backend.
+ *
+ * This function waits for the specified event to complete on the ACL runtime.
+ *
+ * @param event Pointer to the event structure to be synchronized.
+ */
+static void ggml_backend_cann_device_event_synchronize(ggml_backend_dev_t dev, ggml_backend_event_t event) {
+    ACL_CHECK(aclrtSynchronizeEvent((aclrtEvent)event->context));
+
+    GGML_UNUSED(dev);
+}
+
+static const ggml_backend_device_i ggml_backend_cann_device_interface = {
+    /* .get_name = */ ggml_backend_cann_device_get_name,
+    /* .get_description = */ ggml_backend_cann_device_get_description,
+    /* .get_memory = */ ggml_backend_cann_device_get_memory,
+    /* .get_type = */ ggml_backend_cann_device_get_type,
+    /* .get_props = */ ggml_backend_cann_device_get_props,
+    /* .init_backend = */ ggml_backend_cann_device_init, // called for every card
+    /* .get_buffer_type = */ ggml_backend_cann_device_get_buffer_type,
+    /* .get_host_buffer_type = */ ggml_backend_cann_device_get_host_buffer_type,
+    /* .buffer_from_host_ptr = */ NULL, // not supported for CANN
+    /* .supports_op = */ ggml_backend_cann_supports_op,
+    /* .supports_buft = */ ggml_backend_cann_supports_buft,
+    /* .offload_op = */ ggml_backend_cann_offload_op,
+    /* .event_new = */ ggml_backend_cann_device_event_new,
+    /* .event_free = */ ggml_backend_cann_device_event_free,
+    /* .event_synchronize = */ ggml_backend_cann_device_event_synchronize,
+};
+
+
+// backend reg
+struct ggml_backend_cann_reg_context {
+    std::vector<ggml_backend_dev_t> devices;
+};
+
+static const char * ggml_backend_cann_reg_get_name(ggml_backend_reg_t reg) {
+    GGML_UNUSED(reg);
+    return GGML_CANN_NAME;
+}
+
+static size_t ggml_backend_cann_reg_get_device_count(ggml_backend_reg_t reg) {
+    ggml_backend_cann_reg_context * ctx = (ggml_backend_cann_reg_context *)reg->context;
+    return ctx->devices.size();
+}
+
+static ggml_backend_dev_t ggml_backend_cann_reg_get_device(ggml_backend_reg_t reg, size_t index) {
+    ggml_backend_cann_reg_context * ctx = (ggml_backend_cann_reg_context *)reg->context;
+    GGML_ASSERT(index < ctx->devices.size());
+    return ctx->devices[index];
+}
+
+static void * ggml_backend_cann_reg_get_proc_address(ggml_backend_reg_t reg, const char * name) {
+    GGML_UNUSED(reg);
+    GGML_UNUSED(name);
+    // reserved for future use
+    return nullptr;
+}
+
+static const ggml_backend_reg_i ggml_backend_cann_reg_interface = {
+    /* .get_name = */ ggml_backend_cann_reg_get_name,
+    /* .get_device_count = */ ggml_backend_cann_reg_get_device_count,
+    /* .get_device_get = */ ggml_backend_cann_reg_get_device,
+    /* .get_proc_address = */ ggml_backend_cann_reg_get_proc_address,
+};
+
+// backend registry, called only once for cann backend
+ggml_backend_reg_t ggml_backend_cann_reg() {
+    static ggml_backend_reg reg;
+    static bool initialized = false;
+
+    {
+        static std::mutex mutex;
+        std::lock_guard<std::mutex> lock(mutex);
+        if (!initialized) {
+            aclInit(nullptr);
+            ggml_backend_cann_reg_context * ctx = new ggml_backend_cann_reg_context;
+
+            for (int i = 0; i < ggml_cann_info().device_count; i++) {
+                ggml_backend_cann_device_context* dev_ctx = new ggml_backend_cann_device_context();
+                dev_ctx->description = aclrtGetSocName();
+                dev_ctx->device = i;
+                dev_ctx->name = GGML_CANN_NAME + std::to_string(i);
+                ggml_cann_set_device(i);
+                ggml_backend_dev_t dev = new ggml_backend_device {
+                    /* .interface = */ ggml_backend_cann_device_interface,
+                    /* .reg = */ &reg,
+                    /* .context = */ dev_ctx
+                };
+                ctx->devices.push_back(dev);
+            }
+
+            reg = ggml_backend_reg {
+                /* .interface = */ ggml_backend_cann_reg_interface,
+                /* .context = */ ctx
+            };
+        }
+
+        initialized = true;
+    }
+
+    return &reg;
+}
+
 ggml_backend_t ggml_backend_cann_init(int32_t device) {
     aclInit(nullptr);
     if (device < 0 || device >= ggml_backend_cann_get_device_count()) {
@@ -1993,7 +2100,7 @@ ggml_backend_t ggml_backend_cann_init(int32_t device) {
     ggml_backend_t cann_backend =
         new ggml_backend{/* .guid = */ ggml_backend_cann_guid(),
                          /* .interface = */ ggml_backend_cann_interface,
-                         /* .device = */ nullptr,
+                         /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cann_reg(), device),
                          /* .context = */ ctx};
 
     return cann_backend;
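
The block above gives CANN a lazily-initialized, mutex-guarded registry that owns one ggml_backend_device per visible card, each named "CANN" plus its index and described by aclrtGetSocName(). A sketch of walking it — ggml_backend_reg_dev_count() and the ggml_backend_dev_name()/_description() accessors are assumed to be the usual public wrappers over the interfaces defined above (ggml_backend_reg_dev_get() appears in the diff itself):

    #include <cstdio>

    static void list_cann_devices(void) {
        ggml_backend_reg_t reg = ggml_backend_cann_reg();
        for (size_t i = 0; i < ggml_backend_reg_dev_count(reg); i++) {
            ggml_backend_dev_t dev = ggml_backend_reg_dev_get(reg, i);
            // prints e.g. "CANN0: <SoC name>" (illustrative)
            printf("%s: %s\n", ggml_backend_dev_name(dev), ggml_backend_dev_description(dev));
        }
    }

Initializing a backend for one of these devices goes through ggml_backend_cann_device_init(), which the device table wires to .init_backend, or directly through ggml_backend_cann_init(device) as before.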
@@ -421,20 +421,15 @@ struct ggml_backend_cuda_buffer_context {
     }
 };
 
-static const char * ggml_backend_cuda_buffer_get_name(ggml_backend_buffer_t buffer) {
-    ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context;
-    return ctx->name.c_str();
-}
-
-static bool ggml_backend_buffer_is_cuda(ggml_backend_buffer_t buffer) {
-    return buffer->iface.get_name == ggml_backend_cuda_buffer_get_name;
-}
-
 static void ggml_backend_cuda_buffer_free_buffer(ggml_backend_buffer_t buffer) {
     ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context;
     delete ctx;
 }
 
+static bool ggml_backend_buffer_is_cuda(ggml_backend_buffer_t buffer) {
+    return buffer->iface.free_buffer == ggml_backend_cuda_buffer_free_buffer;
+}
+
 static void * ggml_backend_cuda_buffer_get_base(ggml_backend_buffer_t buffer) {
     ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context;
     return ctx->dev_ptr;
@@ -515,7 +510,6 @@ static void ggml_backend_cuda_buffer_clear(ggml_backend_buffer_t buffer, uint8_t
 }
 
 static const ggml_backend_buffer_i ggml_backend_cuda_buffer_interface = {
-    /* .get_name = */ ggml_backend_cuda_buffer_get_name,
     /* .free_buffer = */ ggml_backend_cuda_buffer_free_buffer,
     /* .get_base = */ ggml_backend_cuda_buffer_get_base,
     /* .init_tensor = */ ggml_backend_cuda_buffer_init_tensor,
@@ -548,8 +542,6 @@ static ggml_backend_buffer_t ggml_backend_cuda_buffer_type_alloc_buffer(ggml_bac
 
     ggml_cuda_set_device(buft_ctx->device);
 
-    size = std::max(size, (size_t)1); // cudaMalloc returns null for size 0
-
     void * dev_ptr;
     cudaError_t err = ggml_cuda_device_malloc(&dev_ptr, size, buft_ctx->device);
     if (err != cudaSuccess) {
@@ -657,7 +649,9 @@ static size_t ggml_nbytes_split(const struct ggml_tensor * tensor, int nrows_spl
 }
 
 struct ggml_backend_cuda_split_buffer_type_context {
+    int main_device;
     std::array<float, GGML_CUDA_MAX_DEVICES> tensor_split;
+    std::string name;
 };
 
 struct ggml_backend_cuda_split_buffer_context {
@@ -680,16 +674,6 @@ struct ggml_backend_cuda_split_buffer_context {
     std::vector<ggml_tensor_extra_gpu *> tensor_extras;
 };
 
-static const char * ggml_backend_cuda_split_buffer_get_name(ggml_backend_buffer_t buffer) {
-    return GGML_CUDA_NAME "_Split";
-
-    GGML_UNUSED(buffer);
-}
-
-static bool ggml_backend_buffer_is_cuda_split(ggml_backend_buffer_t buffer) {
-    return buffer->iface.get_name == ggml_backend_cuda_split_buffer_get_name;
-    GGML_UNUSED(ggml_backend_buffer_is_cuda_split); // only used in debug builds currently, avoid unused function warning in release builds
-}
 
 static void ggml_backend_cuda_split_buffer_free_buffer(ggml_backend_buffer_t buffer) {
     ggml_backend_cuda_split_buffer_context * ctx = (ggml_backend_cuda_split_buffer_context *)buffer->context;
@@ -833,7 +817,6 @@ static void ggml_backend_cuda_split_buffer_clear(ggml_backend_buffer_t buffer, u
 }
 
 static const ggml_backend_buffer_i ggml_backend_cuda_split_buffer_interface = {
-    /* .get_name = */ ggml_backend_cuda_split_buffer_get_name,
     /* .free_buffer = */ ggml_backend_cuda_split_buffer_free_buffer,
     /* .get_base = */ ggml_backend_cuda_split_buffer_get_base,
     /* .init_tensor = */ ggml_backend_cuda_split_buffer_init_tensor,
@@ -848,9 +831,9 @@ static const ggml_backend_buffer_i ggml_backend_cuda_split_buffer_interface = {
 // cuda split buffer type
 
 static const char * ggml_backend_cuda_split_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
-    return GGML_CUDA_NAME "_Split";
+    ggml_backend_cuda_split_buffer_type_context * ctx = (ggml_backend_cuda_split_buffer_type_context *)buft->context;
 
-    GGML_UNUSED(buft);
+    return ctx->name.c_str();
 }
 
 static bool ggml_backend_buft_is_cuda_split(ggml_backend_buffer_type_t buft) {
@@ -915,11 +898,11 @@ static const ggml_backend_buffer_type_i ggml_backend_cuda_split_buffer_type_inte
     /* .is_host = */ ggml_backend_cuda_split_buffer_type_is_host,
 };
 
-ggml_backend_buffer_type_t ggml_backend_cuda_split_buffer_type(const float * tensor_split) {
+ggml_backend_buffer_type_t ggml_backend_cuda_split_buffer_type(int main_device, const float * tensor_split) {
     static std::mutex mutex;
     std::lock_guard<std::mutex> lock(mutex);
 
-    static std::map<std::array<float, GGML_CUDA_MAX_DEVICES>, struct ggml_backend_buffer_type> buft_map;
+    static std::map<std::pair<int, std::array<float, GGML_CUDA_MAX_DEVICES>>, struct ggml_backend_buffer_type> buft_map;
 
     std::array<float, GGML_CUDA_MAX_DEVICES> tensor_split_arr = {};
 
@@ -937,18 +920,23 @@ ggml_backend_buffer_type_t ggml_backend_cuda_split_buffer_type(const float * ten
         }
     }
 
-    auto it = buft_map.find(tensor_split_arr);
+    auto it = buft_map.find({main_device, tensor_split_arr});
    if (it != buft_map.end()) {
         return &it->second;
     }
+    auto * ctx = new ggml_backend_cuda_split_buffer_type_context{
+        main_device,
+        tensor_split_arr,
+        GGML_CUDA_NAME + std::to_string(main_device) + "_Split",
+    };
 
     struct ggml_backend_buffer_type buft {
         /* .iface = */ ggml_backend_cuda_split_buffer_type_interface,
-        /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cuda_reg(), 0),
-        /* .context = */ new ggml_backend_cuda_split_buffer_type_context{tensor_split_arr},
+        /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cuda_reg(), main_device),
+        /* .context = */ ctx,
     };
 
-    auto result = buft_map.emplace(tensor_split_arr, buft);
+    auto result = buft_map.emplace(std::make_pair(main_device, tensor_split_arr), buft);
|
||||||
return &result.first->second;
|
return &result.first->second;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -960,12 +948,6 @@ static const char * ggml_backend_cuda_host_buffer_type_name(ggml_backend_buffer_
|
|||||||
GGML_UNUSED(buft);
|
GGML_UNUSED(buft);
|
||||||
}
|
}
|
||||||
|
|
||||||
static const char * ggml_backend_cuda_host_buffer_name(ggml_backend_buffer_t buffer) {
|
|
||||||
return GGML_CUDA_NAME "_Host";
|
|
||||||
|
|
||||||
GGML_UNUSED(buffer);
|
|
||||||
}
|
|
||||||
|
|
||||||
static void ggml_backend_cuda_host_buffer_free_buffer(ggml_backend_buffer_t buffer) {
|
static void ggml_backend_cuda_host_buffer_free_buffer(ggml_backend_buffer_t buffer) {
|
||||||
CUDA_CHECK(cudaFreeHost(buffer->context));
|
CUDA_CHECK(cudaFreeHost(buffer->context));
|
||||||
}
|
}
|
||||||
@ -998,7 +980,6 @@ static ggml_backend_buffer_t ggml_backend_cuda_host_buffer_type_alloc_buffer(ggm
|
|||||||
|
|
||||||
ggml_backend_buffer_t buffer = ggml_backend_cpu_buffer_from_ptr(ptr, size);
|
ggml_backend_buffer_t buffer = ggml_backend_cpu_buffer_from_ptr(ptr, size);
|
||||||
buffer->buft = buft;
|
buffer->buft = buft;
|
||||||
buffer->iface.get_name = ggml_backend_cuda_host_buffer_name;
|
|
||||||
buffer->iface.free_buffer = ggml_backend_cuda_host_buffer_free_buffer;
|
buffer->iface.free_buffer = ggml_backend_cuda_host_buffer_free_buffer;
|
||||||
|
|
||||||
return buffer;
|
return buffer;
|
||||||
@ -1151,7 +1132,7 @@ static cudaError_t ggml_cuda_cpy_tensor_2d(
|
|||||||
void * dst, const struct ggml_tensor * src, int64_t i3, int64_t i2, int64_t i1_low, int64_t i1_high, cudaStream_t stream) {
|
void * dst, const struct ggml_tensor * src, int64_t i3, int64_t i2, int64_t i1_low, int64_t i1_high, cudaStream_t stream) {
|
||||||
|
|
||||||
GGML_ASSERT(ggml_backend_buffer_is_cuda(src->buffer));
|
GGML_ASSERT(ggml_backend_buffer_is_cuda(src->buffer));
|
||||||
char * src_ptr = (char *) src->data;
|
const char * src_ptr = (const char *) src->data;
|
||||||
char * dst_ptr = (char *) dst;
|
char * dst_ptr = (char *) dst;
|
||||||
|
|
||||||
const int64_t ne0 = src->ne[0];
|
const int64_t ne0 = src->ne[0];
|
||||||
@ -1162,7 +1143,7 @@ static cudaError_t ggml_cuda_cpy_tensor_2d(
|
|||||||
const enum ggml_type type = src->type;
|
const enum ggml_type type = src->type;
|
||||||
const int64_t ts = ggml_type_size(type);
|
const int64_t ts = ggml_type_size(type);
|
||||||
const int64_t bs = ggml_blck_size(type);
|
const int64_t bs = ggml_blck_size(type);
|
||||||
int64_t i1_diff = i1_high - i1_low;
|
const int64_t i1_diff = i1_high - i1_low;
|
||||||
|
|
||||||
const char * x = src_ptr + i1_low*nb1 + i2*nb2 + i3*nb3;
|
const char * x = src_ptr + i1_low*nb1 + i2*nb2 + i3*nb3;
|
||||||
if (nb0 == ts && nb1 == ts*ne0/bs) {
|
if (nb0 == ts && nb1 == ts*ne0/bs) {
|
||||||
@ -1400,7 +1381,7 @@ static void ggml_cuda_op_mul_mat(
|
|||||||
|
|
||||||
const int64_t src1_padded_col_size = GGML_PAD(ne10, MATRIX_ROW_PADDING);
|
const int64_t src1_padded_col_size = GGML_PAD(ne10, MATRIX_ROW_PADDING);
|
||||||
|
|
||||||
const bool split = ggml_backend_buffer_is_cuda_split(src0->buffer);
|
const bool split = ggml_backend_buft_is_cuda_split(src0->buffer->buft);
|
||||||
GGML_ASSERT(!(split && ne02 > 1));
|
GGML_ASSERT(!(split && ne02 > 1));
|
||||||
GGML_ASSERT(!(split && ne03 > 1));
|
GGML_ASSERT(!(split && ne03 > 1));
|
||||||
GGML_ASSERT(!(split && ne02 < ne12));
|
GGML_ASSERT(!(split && ne02 < ne12));
|
||||||
@ -1479,14 +1460,24 @@ static void ggml_cuda_op_mul_mat(
|
|||||||
if (src0_is_contiguous) {
|
if (src0_is_contiguous) {
|
||||||
dev[id].src0_dd = split ? (char *) src0_extra->data_device[id] : (char *) src0->data;
|
dev[id].src0_dd = split ? (char *) src0_extra->data_device[id] : (char *) src0->data;
|
||||||
} else {
|
} else {
|
||||||
dev[id].src0_dd = dev[id].src0_dd_alloc.alloc(ctx.pool(id), ggml_nbytes(src0));
|
// If src0 is not contiguous it will be copied to a temporary buffer.
|
||||||
|
// This buffer needs to be cleared entirely because multiple regions will function as padding.
|
||||||
|
const size_t nbytes_data = ggml_nbytes(src0);
|
||||||
|
const size_t nbytes_padding = ggml_row_size(src0->type, MATRIX_ROW_PADDING - ne00 % MATRIX_ROW_PADDING);
|
||||||
|
dev[id].src0_dd = dev[id].src0_dd_alloc.alloc(ctx.pool(id), nbytes_data + nbytes_padding);
|
||||||
|
// TODO: remove this for MUSA once the Guilty Lockup issue is resolved
|
||||||
|
#ifndef GGML_USE_MUSA
|
||||||
|
CUDA_CHECK(cudaMemsetAsync(dev[id].src0_dd, 0, nbytes_data + nbytes_padding, stream));
|
||||||
|
#else // GGML_USE_MUSA
|
||||||
|
CUDA_CHECK(cudaMemsetAsync(dev[id].src0_dd + nbytes_data, 0, nbytes_padding, stream));
|
||||||
|
#endif // !GGML_USE_MUSA
|
||||||
}
|
}
|
||||||
|
|
||||||
// If src0 is on a temporary compute buffers (partial offloading) there may be some padding that needs to be cleared:
|
// If src0 is on a temporary compute buffer (partial offloading) there may be some padding that needs to be cleared:
|
||||||
if (ne00 % MATRIX_ROW_PADDING != 0 && ggml_is_quantized(src0->type) && ggml_backend_buffer_get_usage(src0->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE && src0->view_src == nullptr) {
|
if (ne00 % MATRIX_ROW_PADDING != 0 && ggml_is_quantized(src0->type) && ggml_backend_buffer_get_usage(src0->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE && src0->view_src == nullptr) {
|
||||||
const int64_t nbytes_data = ggml_row_size(src0->type, (dev[id].row_high - dev[id].row_low)*ne00);
|
const size_t nbytes_data = ggml_row_size(src0->type, (dev[id].row_high - dev[id].row_low)*ne00);
|
||||||
const int64_t nbytes_padding = ggml_row_size(src0->type, MATRIX_ROW_PADDING - ne00 % MATRIX_ROW_PADDING);
|
const size_t nbytes_padding = ggml_row_size(src0->type, MATRIX_ROW_PADDING - ne00 % MATRIX_ROW_PADDING);
|
||||||
CUDA_CHECK(cudaMemsetAsync(dev[id].src0_dd + nbytes_data , 0, nbytes_padding, stream));
|
CUDA_CHECK(cudaMemsetAsync(dev[id].src0_dd + nbytes_data, 0, nbytes_padding, stream));
|
||||||
}
|
}
|
||||||
|
|
||||||
if (src1_on_device && src1_is_contiguous) {
|
if (src1_on_device && src1_is_contiguous) {
|
||||||
@ -1880,7 +1871,7 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
|
|||||||
}
|
}
|
||||||
|
|
||||||
static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||||
const bool split = ggml_backend_buffer_is_cuda_split(src0->buffer);
|
const bool split = ggml_backend_buft_is_cuda_split(src0->buffer->buft);
|
||||||
|
|
||||||
bool use_dequantize_mul_mat_vec = ggml_cuda_dmmv_type_supported(src0->type)
|
bool use_dequantize_mul_mat_vec = ggml_cuda_dmmv_type_supported(src0->type)
|
||||||
&& src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
|
&& src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
|
||||||
@ -2007,7 +1998,7 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
|
|||||||
|
|
||||||
GGML_TENSOR_BINARY_OP_LOCALS
|
GGML_TENSOR_BINARY_OP_LOCALS
|
||||||
|
|
||||||
GGML_ASSERT(!ggml_backend_buffer_is_cuda_split(src0->buffer) && "mul_mat_id does not support split buffers");
|
GGML_ASSERT(!ggml_backend_buft_is_cuda_split(src0->buffer->buft) && "mul_mat_id does not support split buffers");
|
||||||
|
|
||||||
cudaStream_t stream = ctx.stream();
|
cudaStream_t stream = ctx.stream();
|
||||||
|
|
||||||
@ -2140,7 +2131,7 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
|
|||||||
|
|
||||||
static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct ggml_tensor * dst) {
|
static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct ggml_tensor * dst) {
|
||||||
// why is this here instead of mul_mat?
|
// why is this here instead of mul_mat?
|
||||||
if (dst->src[0] != nullptr && ggml_backend_buffer_is_cuda_split(dst->src[0]->buffer)) {
|
if (dst->src[0] != nullptr && ggml_backend_buft_is_cuda_split(dst->src[0]->buffer->buft)) {
|
||||||
ggml_cuda_set_peer_access(dst->src[1]->ne[1], ctx.device);
|
ggml_cuda_set_peer_access(dst->src[1]->ne[1], ctx.device);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2361,12 +2352,6 @@ static void ggml_backend_cuda_free(ggml_backend_t backend) {
|
|||||||
delete backend;
|
delete backend;
|
||||||
}
|
}
|
||||||
|
|
||||||
static ggml_backend_buffer_type_t ggml_backend_cuda_get_default_buffer_type(ggml_backend_t backend) {
|
|
||||||
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
|
|
||||||
|
|
||||||
return ggml_backend_cuda_buffer_type(cuda_ctx->device);
|
|
||||||
}
|
|
||||||
|
|
||||||
static void ggml_backend_cuda_set_tensor_async(ggml_backend_t backend, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
|
static void ggml_backend_cuda_set_tensor_async(ggml_backend_t backend, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
|
||||||
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
|
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
|
||||||
ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
|
ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
|
||||||
@ -2572,7 +2557,7 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (node->src[0] && node->src[0]->buffer && ggml_backend_buffer_is_cuda_split(node->src[0]->buffer)) {
|
if (node->src[0] && node->src[0]->buffer && ggml_backend_buft_is_cuda_split(node->src[0]->buffer->buft)) {
|
||||||
use_cuda_graph = false; // Split buffers are not supported by CUDA graph capture
|
use_cuda_graph = false; // Split buffers are not supported by CUDA graph capture
|
||||||
#ifndef NDEBUG
|
#ifndef NDEBUG
|
||||||
GGML_LOG_DEBUG("%s: disabling CUDA graphs due to split buffer\n", __func__);
|
GGML_LOG_DEBUG("%s: disabling CUDA graphs due to split buffer\n", __func__);
|
||||||
@ -2659,7 +2644,8 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
|
|||||||
for (int j = 0; j < GGML_MAX_SRC; j++) {
|
for (int j = 0; j < GGML_MAX_SRC; j++) {
|
||||||
if (node->src[j] != nullptr) {
|
if (node->src[j] != nullptr) {
|
||||||
assert(node->src[j]->buffer);
|
assert(node->src[j]->buffer);
|
||||||
assert(node->src[j]->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) || ggml_backend_buffer_is_cuda_split(node->src[j]->buffer));
|
assert(node->src[j]->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) ||
|
||||||
|
ggml_backend_buft_is_cuda_split(node->src[j]->buffer->buft));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
@ -2752,7 +2738,7 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
|
|||||||
cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &result_info);
|
cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &result_info);
|
||||||
if (stat == cudaErrorGraphExecUpdateFailure) {
|
if (stat == cudaErrorGraphExecUpdateFailure) {
|
||||||
#ifndef NDEBUG
|
#ifndef NDEBUG
|
||||||
GGML_LOG_ERROR("%s: CUDA graph update failed\n", __func__);
|
GGML_LOG_DEBUG("%s: CUDA graph update failed\n", __func__);
|
||||||
#endif
|
#endif
|
||||||
// The pre-existing graph exec cannot be updated due to violated constraints
|
// The pre-existing graph exec cannot be updated due to violated constraints
|
||||||
// so instead clear error and re-instantiate
|
// so instead clear error and re-instantiate
|
||||||
@ -2801,7 +2787,6 @@ static void ggml_backend_cuda_event_wait(ggml_backend_t backend, ggml_backend_ev
|
|||||||
static const ggml_backend_i ggml_backend_cuda_interface = {
|
static const ggml_backend_i ggml_backend_cuda_interface = {
|
||||||
/* .get_name = */ ggml_backend_cuda_get_name,
|
/* .get_name = */ ggml_backend_cuda_get_name,
|
||||||
/* .free = */ ggml_backend_cuda_free,
|
/* .free = */ ggml_backend_cuda_free,
|
||||||
/* .get_default_buffer_type = */ ggml_backend_cuda_get_default_buffer_type,
|
|
||||||
/* .set_tensor_async = */ ggml_backend_cuda_set_tensor_async,
|
/* .set_tensor_async = */ ggml_backend_cuda_set_tensor_async,
|
||||||
/* .get_tensor_async = */ ggml_backend_cuda_get_tensor_async,
|
/* .get_tensor_async = */ ggml_backend_cuda_get_tensor_async,
|
||||||
/* .cpy_tensor_async = */ ggml_backend_cuda_cpy_tensor_async,
|
/* .cpy_tensor_async = */ ggml_backend_cuda_cpy_tensor_async,
|
||||||
@ -2811,9 +2796,6 @@ static const ggml_backend_i ggml_backend_cuda_interface = {
|
|||||||
/* .graph_plan_update = */ NULL,
|
/* .graph_plan_update = */ NULL,
|
||||||
/* .graph_plan_compute = */ NULL,
|
/* .graph_plan_compute = */ NULL,
|
||||||
/* .graph_compute = */ ggml_backend_cuda_graph_compute,
|
/* .graph_compute = */ ggml_backend_cuda_graph_compute,
|
||||||
/* .supports_op = */ NULL, // moved to device
|
|
||||||
/* .supports_buft = */ NULL, // moved to device
|
|
||||||
/* .offload_op = */ NULL, // moved to device
|
|
||||||
/* .event_record = */ ggml_backend_cuda_event_record,
|
/* .event_record = */ ggml_backend_cuda_event_record,
|
||||||
/* .event_wait = */ ggml_backend_cuda_event_wait,
|
/* .event_wait = */ ggml_backend_cuda_event_wait,
|
||||||
};
|
};
|
||||||
@ -2903,7 +2885,7 @@ static void ggml_backend_cuda_device_get_memory(ggml_backend_dev_t dev, size_t *
|
|||||||
|
|
||||||
static enum ggml_backend_dev_type ggml_backend_cuda_device_get_type(ggml_backend_dev_t dev) {
|
static enum ggml_backend_dev_type ggml_backend_cuda_device_get_type(ggml_backend_dev_t dev) {
|
||||||
GGML_UNUSED(dev);
|
GGML_UNUSED(dev);
|
||||||
return GGML_BACKEND_DEVICE_TYPE_GPU_FULL;
|
return GGML_BACKEND_DEVICE_TYPE_GPU;
|
||||||
}
|
}
|
||||||
|
|
||||||
static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) {
|
static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) {
|
||||||
@ -2927,7 +2909,7 @@ static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_back
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
static ggml_backend_t ggml_backend_cuda_device_init(ggml_backend_dev_t dev, const char * params) {
|
static ggml_backend_t ggml_backend_cuda_device_init_backend(ggml_backend_dev_t dev, const char * params) {
|
||||||
GGML_UNUSED(params);
|
GGML_UNUSED(params);
|
||||||
ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context;
|
ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context;
|
||||||
return ggml_backend_cuda_init(ctx->device);
|
return ggml_backend_cuda_init(ctx->device);
|
||||||
@ -2943,18 +2925,29 @@ static ggml_backend_buffer_type_t ggml_backend_cuda_device_get_host_buffer_type(
|
|||||||
return ggml_backend_cuda_host_buffer_type();
|
return ggml_backend_cuda_host_buffer_type();
|
||||||
}
|
}
|
||||||
|
|
||||||
static ggml_backend_buffer_t ggml_backend_cuda_device_buffer_from_host_ptr(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) {
|
|
||||||
GGML_UNUSED(dev);
|
|
||||||
GGML_UNUSED(ptr);
|
|
||||||
GGML_UNUSED(size);
|
|
||||||
GGML_UNUSED(max_tensor_size);
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: move these functions here
|
// TODO: move these functions here
|
||||||
static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
|
static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
|
||||||
ggml_backend_cuda_device_context * dev_ctx = (ggml_backend_cuda_device_context *) dev->context;
|
ggml_backend_cuda_device_context * dev_ctx = (ggml_backend_cuda_device_context *) dev->context;
|
||||||
|
|
||||||
|
// split buffers can only be used with GGML_OP_MUL_MAT
|
||||||
|
if (op->op != GGML_OP_MUL_MAT) {
|
||||||
|
for (int i = 0; i < GGML_MAX_SRC; i++) {
|
||||||
|
if (op->src[i] && op->src[i]->buffer && ggml_backend_buft_is_cuda_split(op->src[i]->buffer->buft)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// check if all the sources are allocated on this device
|
||||||
|
for (int i = 0; i < GGML_MAX_SRC; i++) {
|
||||||
|
if (op->src[i] && op->src[i]->buffer && ggml_backend_buft_is_cuda(op->src[i]->buffer->buft)) {
|
||||||
|
ggml_backend_cuda_buffer_type_context * buft_ctx = (ggml_backend_cuda_buffer_type_context *)op->src[i]->buffer->buft->context;
|
||||||
|
if (buft_ctx->device != dev_ctx->device) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
switch (op->op) {
|
switch (op->op) {
|
||||||
case GGML_OP_UNARY:
|
case GGML_OP_UNARY:
|
||||||
switch (ggml_get_unary_op(op)) {
|
switch (ggml_get_unary_op(op)) {
|
||||||
@ -3114,18 +3107,20 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
|||||||
}
|
}
|
||||||
return false;
|
return false;
|
||||||
} break;
|
} break;
|
||||||
|
case GGML_OP_NORM:
|
||||||
|
case GGML_OP_RMS_NORM:
|
||||||
|
return ggml_is_contiguous(op->src[0]) && op->ne[0] % WARP_SIZE == 0;
|
||||||
|
break;
|
||||||
case GGML_OP_NONE:
|
case GGML_OP_NONE:
|
||||||
case GGML_OP_RESHAPE:
|
case GGML_OP_RESHAPE:
|
||||||
case GGML_OP_VIEW:
|
case GGML_OP_VIEW:
|
||||||
case GGML_OP_PERMUTE:
|
case GGML_OP_PERMUTE:
|
||||||
case GGML_OP_TRANSPOSE:
|
case GGML_OP_TRANSPOSE:
|
||||||
case GGML_OP_NORM:
|
|
||||||
case GGML_OP_ADD:
|
case GGML_OP_ADD:
|
||||||
case GGML_OP_ADD1:
|
case GGML_OP_ADD1:
|
||||||
case GGML_OP_SUB:
|
case GGML_OP_SUB:
|
||||||
case GGML_OP_MUL:
|
case GGML_OP_MUL:
|
||||||
case GGML_OP_DIV:
|
case GGML_OP_DIV:
|
||||||
case GGML_OP_RMS_NORM:
|
|
||||||
case GGML_OP_SCALE:
|
case GGML_OP_SCALE:
|
||||||
case GGML_OP_SQR:
|
case GGML_OP_SQR:
|
||||||
case GGML_OP_SQRT:
|
case GGML_OP_SQRT:
|
||||||
@ -3141,7 +3136,6 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
|||||||
case GGML_OP_ROPE:
|
case GGML_OP_ROPE:
|
||||||
return ggml_is_contiguous(op->src[0]);
|
return ggml_is_contiguous(op->src[0]);
|
||||||
case GGML_OP_IM2COL:
|
case GGML_OP_IM2COL:
|
||||||
return op->src[0]->type == GGML_TYPE_F16;
|
|
||||||
case GGML_OP_POOL_2D:
|
case GGML_OP_POOL_2D:
|
||||||
case GGML_OP_SUM:
|
case GGML_OP_SUM:
|
||||||
case GGML_OP_SUM_ROWS:
|
case GGML_OP_SUM_ROWS:
|
||||||
@ -3181,24 +3175,27 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
|||||||
}
|
}
|
||||||
|
|
||||||
static bool ggml_backend_cuda_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
|
static bool ggml_backend_cuda_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
|
||||||
if (ggml_backend_buft_is_cuda_split(buft)) {
|
return (ggml_backend_buft_is_cuda(buft) || ggml_backend_buft_is_cuda_split(buft)) && buft->device == dev;
|
||||||
return true;
|
}
|
||||||
}
|
|
||||||
|
|
||||||
if (ggml_backend_buft_is_cuda(buft)) {
|
static int64_t get_op_batch_size(const ggml_tensor * op) {
|
||||||
ggml_backend_cuda_device_context * dev_ctx = (ggml_backend_cuda_device_context *)dev->context;
|
switch (op->op) {
|
||||||
ggml_backend_cuda_buffer_type_context * buft_ctx = (ggml_backend_cuda_buffer_type_context *)buft->context;
|
case GGML_OP_GET_ROWS:
|
||||||
return buft_ctx->device == dev_ctx->device;
|
return 0;
|
||||||
|
case GGML_OP_MUL_MAT:
|
||||||
|
return op->ne[1];
|
||||||
|
case GGML_OP_MUL_MAT_ID:
|
||||||
|
case GGML_OP_ROPE:
|
||||||
|
return op->ne[2];
|
||||||
|
default:
|
||||||
|
return ggml_nrows(op);
|
||||||
}
|
}
|
||||||
|
|
||||||
return false;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static bool ggml_backend_cuda_device_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
|
static bool ggml_backend_cuda_device_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
|
||||||
const int min_batch_size = 32;
|
const int min_batch_size = 32;
|
||||||
|
|
||||||
return (op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS) ||
|
return get_op_batch_size(op) >= min_batch_size;
|
||||||
(op->ne[2] >= min_batch_size && op->op == GGML_OP_MUL_MAT_ID);
|
|
||||||
|
|
||||||
GGML_UNUSED(dev);
|
GGML_UNUSED(dev);
|
||||||
}
|
}
|
||||||
@ -3239,10 +3236,10 @@ static const ggml_backend_device_i ggml_backend_cuda_device_interface = {
|
|||||||
/* .get_memory = */ ggml_backend_cuda_device_get_memory,
|
/* .get_memory = */ ggml_backend_cuda_device_get_memory,
|
||||||
/* .get_type = */ ggml_backend_cuda_device_get_type,
|
/* .get_type = */ ggml_backend_cuda_device_get_type,
|
||||||
/* .get_props = */ ggml_backend_cuda_device_get_props,
|
/* .get_props = */ ggml_backend_cuda_device_get_props,
|
||||||
/* .init_backend = */ ggml_backend_cuda_device_init,
|
/* .init_backend = */ ggml_backend_cuda_device_init_backend,
|
||||||
/* .get_buffer_type = */ ggml_backend_cuda_device_get_buffer_type,
|
/* .get_buffer_type = */ ggml_backend_cuda_device_get_buffer_type,
|
||||||
/* .get_host_buffer_type = */ ggml_backend_cuda_device_get_host_buffer_type,
|
/* .get_host_buffer_type = */ ggml_backend_cuda_device_get_host_buffer_type,
|
||||||
/* .buffer_from_host_ptr = */ ggml_backend_cuda_device_buffer_from_host_ptr,
|
/* .buffer_from_host_ptr = */ NULL,
|
||||||
/* .supports_op = */ ggml_backend_cuda_device_supports_op,
|
/* .supports_op = */ ggml_backend_cuda_device_supports_op,
|
||||||
/* .supports_buft = */ ggml_backend_cuda_device_supports_buft,
|
/* .supports_buft = */ ggml_backend_cuda_device_supports_buft,
|
||||||
/* .offload_op = */ ggml_backend_cuda_device_offload_op,
|
/* .offload_op = */ ggml_backend_cuda_device_offload_op,
|
||||||
|
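For readers skimming the hunk above: the new get_op_batch_size() helper replaces the old ad-hoc ne[1]/ne[2] checks in the CUDA offload heuristic. A minimal self-contained sketch of the same logic, using a hypothetical toy_tensor type in place of the real ggml_tensor, is:

// Sketch only: toy_tensor and TOY_OP_* are stand-ins, not part of ggml.
#include <cstdint>

enum toy_op { TOY_OP_GET_ROWS, TOY_OP_MUL_MAT, TOY_OP_MUL_MAT_ID, TOY_OP_ROPE, TOY_OP_OTHER };

struct toy_tensor {
    toy_op  op;
    int64_t ne[4]; // ne[1] is the batch for MUL_MAT, ne[2] for MUL_MAT_ID/ROPE
    int64_t nrows() const { return ne[1]*ne[2]*ne[3]; }
};

static int64_t toy_get_op_batch_size(const toy_tensor & t) {
    switch (t.op) {
        case TOY_OP_GET_ROWS:   return 0;       // never offloaded on its own
        case TOY_OP_MUL_MAT:    return t.ne[1];
        case TOY_OP_MUL_MAT_ID:
        case TOY_OP_ROPE:       return t.ne[2];
        default:                return t.nrows();
    }
}

static bool toy_offload_op(const toy_tensor & t) {
    const int min_batch_size = 32; // same threshold as the backend above
    return toy_get_op_batch_size(t) >= min_batch_size;
}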
@@ -1,6 +1,6 @@
 #include "common.cuh"

-#define CUDA_CPY_BLOCK_SIZE 32
+#define CUDA_CPY_BLOCK_SIZE 64

 void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1);

@@ -416,10 +416,11 @@ static __global__ void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx,

 static __device__ void convert_f16(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
 const half * x = (const half *) vx;
+// load 2 halfs into register in a single instruction
+const half2 x_reg = *((half2 *) &(x[ib + iqs]));
 // automatic half -> float type cast if dfloat == float
-v.x = x[ib + iqs + 0];
-v.y = x[ib + iqs + 1];
+v.x = __low2float(x_reg);
+v.y = __high2float(x_reg);
 }

 static constexpr __device__ dequantize_kernel_t get_dequantize_kernel(ggml_type type) {
@@ -476,13 +477,28 @@ static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, cons
 // matrix multiplication
 // for qr = 2 the y index needs to increase by 1 per j iter because of y_offset = qk/2
 #ifdef GGML_CUDA_F16
+if ( y_offset == 1 ) {
+// load 2 dfloats into register in a single instruction
+const dfloat2 y_reg = *((dfloat2 *) &(y[iybs + iqs + j/qr]));
+tmp += __hmul2(v, y_reg);
+}
+else {
 tmp += __hmul2(v, {
 y[iybs + iqs + j/qr + 0],
 y[iybs + iqs + j/qr + y_offset]
 });
+}
 #else
+if ( y_offset == 1 ) {
+// load 2 dfloats into register in a single instruction
+const dfloat2 y_reg = *((dfloat2 *) &(y[iybs + iqs + j/qr]));
+tmp += v.x * y_reg.x;
+tmp += v.y * y_reg.y;
+}
+else {
 tmp += v.x * y[iybs + iqs + j/qr + 0];
 tmp += v.y * y[iybs + iqs + j/qr + y_offset];
+}
 #endif // GGML_CUDA_F16
 }
 }
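The convert_f16 change above turns two scalar half loads into a single half2 load that is then widened with __low2float/__high2float. A standalone CUDA sketch of that pattern (kernel name and layout are illustrative, not part of ggml) is:

// Sketch only: widens pairs of halfs to float2 using one 32-bit half2 load per pair.
#include <cuda_fp16.h>

__global__ void convert_pairs_f16_to_f32(const half * x, float2 * y, int n_pairs) {
    const int i = blockIdx.x*blockDim.x + threadIdx.x;
    if (i >= n_pairs) {
        return;
    }
    const half2 x_reg = *((const half2 *) &x[2*i]); // one load instead of two
    y[i] = make_float2(__low2float(x_reg), __high2float(x_reg));
}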
@@ -92,8 +92,8 @@ void ggml_cuda_op_im2col(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
 const int64_t OW = dst->ne[1];

 const size_t delta_offset = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32
-const int64_t batch = src1->ne[3];
-const size_t batch_offset = src1->nb[3] / 4; // nb is byte offset, src is type float32
+const int64_t batch = src1->ne[is_2D ? 3 : 2];
+const size_t batch_offset = src1->nb[is_2D ? 3 : 2] / 4; // nb is byte offset, src is type float32

 if(dst->type == GGML_TYPE_F16) {
 im2col_cuda_f16(src1_d, (half *) dst_d, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, stream);
@@ -8,8 +8,6 @@ void ggml_cuda_op_mul_mat_q(

 const int64_t ne00 = src0->ne[0];

-const int64_t nb01 = src0->nb[1];
-
 const int64_t ne10 = src1->ne[0];
 const int64_t ne11 = src1->ne[1];
 GGML_ASSERT(ne10 % QK8_1 == 0);
@@ -17,7 +15,7 @@ void ggml_cuda_op_mul_mat_q(
 const int64_t ne0 = dst->ne[0];

 const int64_t row_diff = row_high - row_low;
-const int64_t stride00 = nb01 / ggml_type_size(src0->type);
+const int64_t stride00 = ne00 / ggml_blck_size(src0->type);

 int id = ggml_cuda_get_device();
 const int compute_capability = ggml_cuda_info().devices[id].cc;
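The stride00 change above derives the row stride from the logical row length and the quantization block size instead of the byte stride nb01. A worked example with assumed numbers (a Q4_0 row of ne00 = 4096 elements, block size 32) is:

// Assumed values for illustration only; QK4_0 = 32 in ggml's quantization scheme.
constexpr long long ne00      = 4096;
constexpr long long blck_size = 32;              // QK4_0
constexpr long long stride00  = ne00 / blck_size; // blocks between consecutive rows
static_assert(stride00 == 128, "one Q4_0 row of 4096 elements spans 128 blocks");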
@@ -19,6 +19,9 @@ extern "C" {
 #define MIN(a, b) ((a) < (b) ? (a) : (b))
 #define MAX(a, b) ((a) > (b) ? (a) : (b))

+// required for mmap as gguf only guarantees 32-byte alignment
+#define TENSOR_ALIGNMENT 32
+
 // static_assert should be a #define, but if it's not,
 // fall back to the _Static_assert C11 keyword.
 // if C99 - static_assert is noop
@@ -196,6 +199,11 @@ struct ggml_cgraph {

 struct ggml_cgraph ggml_graph_view(struct ggml_cgraph * cgraph, int i0, int i1);

+// Memory allocation
+
+void * ggml_aligned_malloc(size_t size);
+void ggml_aligned_free(void * ptr, size_t size);
+
 #ifdef __cplusplus
 }
 #endif
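ggml_aligned_malloc and ggml_aligned_free are only declared in the hunk above; their definitions live elsewhere. As a rough sketch of what a 32-byte-aligned allocator with this signature can look like (assuming POSIX posix_memalign; this is not the actual ggml implementation):

// Sketch only: example_* helpers are hypothetical and assume a POSIX libc.
#include <cstdlib>

static void * example_aligned_malloc(size_t size) {
    void * ptr = nullptr;
    if (posix_memalign(&ptr, 32 /* TENSOR_ALIGNMENT */, size) != 0) {
        return nullptr;
    }
    return ptr;
}

static void example_aligned_free(void * ptr, size_t size) {
    (void) size; // unused here; kept to mirror the declared interface
    free(ptr);
}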
@@ -20,6 +20,7 @@
 #include "shaderop_mul_mat_q8_0.h"
 #include "shaderop_mul_mat_q4_0.h"
 #include "shaderop_mul_mat_q4_1.h"
+#include "shaderop_mul_mat_q4_k.h"
 #include "shaderop_mul_mat_q6_k.h"
 #include "shaderop_mul_mat_mat_f32.h"
 #include "shaderop_getrows_f32.h"
@@ -42,6 +43,7 @@
 #include <cstring>
 #include <iostream>
 #include <memory>
+#include <mutex>
 #include <stdexcept>
 #include <string>
 #include <unordered_map>
@@ -273,18 +275,9 @@ static std::vector<ggml_vk_device> ggml_vk_available_devices_internal(size_t mem
 return results;
 }

-// public API returns a C-style array
-ggml_vk_device * ggml_vk_available_devices(size_t memoryRequired, size_t * count) {
-auto devices = ggml_vk_available_devices_internal(memoryRequired);
-*count = devices.size();
-if (devices.empty()) {
-return nullptr;
-}
-
-size_t nbytes = sizeof (ggml_vk_device) * (devices.size());
-auto * arr = static_cast<ggml_vk_device *>(malloc(nbytes));
-memcpy(arr, devices.data(), nbytes);
-return arr;
+static std::vector<ggml_vk_device>& ggml_vk_available_devices() {
+static std::vector<ggml_vk_device> devices = ggml_vk_available_devices_internal(0);
+return devices;
 }

 static void ggml_vk_filterByVendor(std::vector<ggml_vk_device>& devices, const std::string& targetVendor) {
@@ -341,7 +334,7 @@ ggml_vk_device ggml_vk_current_device() {
 if (!komputeManager()->hasDevice())
 return ggml_vk_device();

-auto devices = ggml_vk_available_devices_internal(0);
+auto devices = ggml_vk_available_devices();
 ggml_vk_filterByName(devices, komputeManager()->physicalDevice()->getProperties().deviceName.data());
 GGML_ASSERT(!devices.empty());
 return devices.front();
@@ -1075,6 +1068,40 @@ static void ggml_vk_mul_mat_q8_0(Args&&... args) {
 ggml_vk_mul_mat_impl(spirv, "q8_0", 1/*We access blocks unaligned*/, std::forward<Args>(args)...);
 }

+static void ggml_vk_mul_mat_q4_k(
+kp::Sequence& seq,
+const std::shared_ptr<kp::Tensor>& inA,
+const std::shared_ptr<kp::Tensor>& inB,
+const std::shared_ptr<kp::Tensor>& out,
+uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
+int32_t ne00, int32_t ne01, int32_t ne02, int32_t ne10,
+int32_t ne11, int32_t ne12, int32_t ne13, int32_t ne0,
+int32_t ne1, int32_t r2, int32_t r3
+) {
+const static auto spirv = getSpirvShader(kp::shader_data::op_mul_mat_q4_k_comp_spv,
+kp::shader_data::op_mul_mat_q4_k_comp_spv_len);
+
+struct PushConstants {
+uint32_t inAOff, inBOff, outOff;
+int32_t ne00, ne10, ne0, ne1, ne01, ne02, ne12, r2, r3;
+} pushConsts {
+0, 0, 0,
+ne00, ne10, ne0, ne1, ne01, ne02, ne12, r2, r3
+};
+
+std::shared_ptr<kp::Algorithm> s_algo = nullptr;
+if (!komputeManager()->hasAlgorithm(__func__)) {
+s_algo = komputeManager()->algorithm<uint32_t, PushConstants>(__func__, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {unsigned((ne01 + 3)/4), unsigned(ne11), unsigned(ne12) * unsigned(ne13)}, {}, {pushConsts});
+} else {
+s_algo = komputeManager()->getAlgorithm(__func__);
+s_algo->setTensors({inA, inB, out});
+s_algo->setWorkgroup({unsigned((ne01 + 3)/4), unsigned(ne11), unsigned(ne12) * unsigned(ne13)});
+s_algo->setPushConstants<PushConstants>({pushConsts});
+s_algo->updateDescriptors(s_kompute_context->pool.get());
+}
+seq.record<kp::OpAlgoDispatch>(s_algo);
+}
+
 static void ggml_vk_mul_mat_q6_k(
 kp::Sequence& seq,
 const std::shared_ptr<kp::Tensor>& inA,
@@ -1323,17 +1350,7 @@ static void ggml_vk_cpy_f16_f32(Args&&... args) {
 ggml_vk_cpy(spirv, 2, 4, std::forward<Args>(args)...);
 }

-static bool ggml_vk_supports_op(const struct ggml_tensor * op) {
-switch (op->type) {
-case GGML_TYPE_F16:
-case GGML_TYPE_F32:
-case GGML_TYPE_Q4_0:
-case GGML_TYPE_Q4_1:
-break;
-default:
-return false;
-}
-
+static bool ggml_backend_kompute_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
 switch (op->op) {
 case GGML_OP_UNARY:
 switch (ggml_get_unary_op(op)) {
@@ -1402,6 +1419,7 @@ static bool ggml_vk_supports_op(const struct ggml_tensor * op) {
 case GGML_TYPE_Q8_0:
 case GGML_TYPE_Q4_0:
 case GGML_TYPE_Q4_1:
+case GGML_TYPE_Q4_K:
 return true;
 default:
 ;
@@ -1410,6 +1428,8 @@ static bool ggml_vk_supports_op(const struct ggml_tensor * op) {
 ;
 }
 return false;
+
+GGML_UNUSED(dev);
 }

 static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph * gf) {
@@ -1458,11 +1478,6 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml

 any_commands_recorded = true;

-if (!ggml_vk_supports_op(dst)) {
-fprintf(stderr, "%s: error: unsupported op '%s'\n", __func__, ggml_op_desc(dst));
-GGML_ABORT("unsupported op");
-}
-
 const int32_t ne00 = src0 ? src0->ne[0] : 0;
 const int32_t ne01 = src0 ? src0->ne[1] : 0;
 const int32_t ne02 = src0 ? src0->ne[2] : 0;
@@ -1656,6 +1671,12 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
 ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1, r2, r3
 );
 break;
+case GGML_TYPE_Q4_K:
+ggml_vk_mul_mat_q4_k(
+seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
+ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1, ne12/ne02, ne13/ne03
+);
+break;
 case GGML_TYPE_Q6_K:
 ggml_vk_mul_mat_q6_k(
 seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
@@ -1820,11 +1841,6 @@ static void ggml_backend_kompute_device_unref(ggml_backend_buffer_type_t buft) {
 }
 }

-static const char * ggml_backend_kompute_buffer_get_name(ggml_backend_buffer_t buffer) {
-auto * ctx = static_cast<ggml_backend_kompute_buffer_type_context *>(buffer->buft->context);
-return ctx->name.c_str();
-}
-
 static void ggml_backend_kompute_buffer_free_buffer(ggml_backend_buffer_t buffer) {
 auto * memory = (ggml_vk_memory *)buffer->context;
 if (ggml_vk_has_device()) {
@@ -1868,7 +1884,6 @@ static void ggml_backend_kompute_buffer_clear(ggml_backend_buffer_t buffer, uint
 }

 static ggml_backend_buffer_i ggml_backend_kompute_buffer_i = {
-/* .get_name = */ ggml_backend_kompute_buffer_get_name,
 /* .free_buffer = */ ggml_backend_kompute_buffer_free_buffer,
 /* .get_base = */ ggml_backend_kompute_buffer_get_base,
 /* .init_tensor = */ NULL,
@@ -1913,25 +1928,31 @@ static ggml_backend_buffer_type_i ggml_backend_kompute_buffer_type_interface = {
 };

 ggml_backend_buffer_type_t ggml_backend_kompute_buffer_type(int device) {
-static std::vector<ggml_backend_buffer_type> bufts = []() {
-std::vector<ggml_backend_buffer_type> vec;
-auto devices = ggml_vk_available_devices_internal(0);
-vec.reserve(devices.size());
+static std::mutex mutex;
+std::lock_guard<std::mutex> lock(mutex);

-for (const auto & dev : devices) {
-vec.push_back({
+auto devices = ggml_vk_available_devices();
+int32_t device_count = (int32_t) devices.size();
+GGML_ASSERT(device < device_count);
+GGML_ASSERT(devices.size() <= GGML_KOMPUTE_MAX_DEVICES);
+
+static ggml_backend_buffer_type
+ggml_backend_kompute_buffer_types[GGML_KOMPUTE_MAX_DEVICES];
+
+static bool ggml_backend_kompute_buffer_type_initialized = false;
+
+if (!ggml_backend_kompute_buffer_type_initialized) {
+for (int32_t i = 0; i < device_count; i++) {
+ggml_backend_kompute_buffer_types[i] = {
 /* .iface = */ ggml_backend_kompute_buffer_type_interface,
-/* .device = */ nullptr,
-/* .context = */ new ggml_backend_kompute_buffer_type_context(dev.index, dev.bufferAlignment, dev.maxAlloc)
-});
+/* .device = */ ggml_backend_reg_dev_get(ggml_backend_kompute_reg(), i),
+/* .context = */ new ggml_backend_kompute_buffer_type_context{ i, devices[i].bufferAlignment, devices[i].maxAlloc },
+};
 }
-return vec;
-}();
+ggml_backend_kompute_buffer_type_initialized = true;
+}

-auto it = std::find_if(bufts.begin(), bufts.end(), [device](const ggml_backend_buffer_type & t) {
-return device == static_cast<ggml_backend_kompute_buffer_type_context *>(t.context)->device;
-});
-return it < bufts.end() ? &*it : nullptr;
+return &ggml_backend_kompute_buffer_types[device];
 }

 // backend
@@ -1953,31 +1974,15 @@ static void ggml_backend_kompute_free(ggml_backend_t backend) {
 delete backend;
 }

-static ggml_backend_buffer_type_t ggml_backend_kompute_get_default_buffer_type(ggml_backend_t backend) {
-auto * ctx = static_cast<ggml_kompute_context *>(backend->context);
-return ggml_backend_kompute_buffer_type(ctx->device);
-}
-
 static ggml_status ggml_backend_kompute_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
 auto * ctx = static_cast<ggml_kompute_context *>(backend->context);
 ggml_vk_graph_compute(ctx, cgraph);
 return GGML_STATUS_SUCCESS;
 }

-static bool ggml_backend_kompute_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
-GGML_UNUSED(backend);
-return ggml_vk_supports_op(op);
-}
-
-static bool ggml_backend_kompute_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft) {
-GGML_UNUSED(backend);
-return buft->iface.get_name == ggml_backend_kompute_buffer_type_get_name;
-}
-
 static struct ggml_backend_i kompute_backend_i = {
 /* .get_name = */ ggml_backend_kompute_name,
 /* .free = */ ggml_backend_kompute_free,
-/* .get_default_buffer_type = */ ggml_backend_kompute_get_default_buffer_type,
 /* .set_tensor_async = */ NULL,
 /* .get_tensor_async = */ NULL,
 /* .cpy_tensor_async = */ NULL,
@@ -1987,9 +1992,6 @@ static struct ggml_backend_i kompute_backend_i = {
 /* .graph_plan_update = */ NULL,
 /* .graph_plan_compute = */ NULL,
 /* .graph_compute = */ ggml_backend_kompute_graph_compute,
-/* .supports_op = */ ggml_backend_kompute_supports_op,
-/* .supports_buft = */ ggml_backend_kompute_supports_buft,
-/* .offload_op = */ NULL,
 /* .event_record = */ NULL,
 /* .event_wait = */ NULL,
 };
@@ -2006,7 +2008,7 @@ ggml_backend_t ggml_backend_kompute_init(int device) {
 ggml_backend_t kompute_backend = new ggml_backend {
 /* .guid = */ ggml_backend_kompute_guid(),
 /* .interface = */ kompute_backend_i,
-/* .device = */ nullptr,
+/* .device = */ ggml_backend_reg_dev_get(ggml_backend_kompute_reg(), device),
 /* .context = */ s_kompute_context,
 };

@@ -2016,3 +2018,167 @@ ggml_backend_t ggml_backend_kompute_init(int device) {
 bool ggml_backend_is_kompute(ggml_backend_t backend) {
 return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_kompute_guid());
 }
+
+static size_t ggml_backend_kompute_get_device_count() {
+auto devices = ggml_vk_available_devices();
+return devices.size();
+}
+
+static void ggml_backend_kompute_get_device_description(int device, char * description, size_t description_size) {
+auto devices = ggml_vk_available_devices();
+GGML_ASSERT((size_t) device < devices.size());
+snprintf(description, description_size, "%s", devices[device].name);
+}
+
+static void ggml_backend_kompute_get_device_memory(int device, size_t * free, size_t * total) {
+auto devices = ggml_vk_available_devices();
+GGML_ASSERT((size_t) device < devices.size());
+*total = devices[device].heapSize;
+*free = devices[device].heapSize;
+}
+
+//////////////////////////
+
+struct ggml_backend_kompute_device_context {
+int device;
+std::string name;
+std::string description;
+};
+
+static const char * ggml_backend_kompute_device_get_name(ggml_backend_dev_t dev) {
+ggml_backend_kompute_device_context * ctx = (ggml_backend_kompute_device_context *)dev->context;
+return ctx->name.c_str();
+}
+
+static const char * ggml_backend_kompute_device_get_description(ggml_backend_dev_t dev) {
+ggml_backend_kompute_device_context * ctx = (ggml_backend_kompute_device_context *)dev->context;
+return ctx->description.c_str();
+}
+
+static void ggml_backend_kompute_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
+ggml_backend_kompute_device_context * ctx = (ggml_backend_kompute_device_context *)dev->context;
+ggml_backend_kompute_get_device_memory(ctx->device, free, total);
+}
+
+static ggml_backend_buffer_type_t ggml_backend_kompute_device_get_buffer_type(ggml_backend_dev_t dev) {
+ggml_backend_kompute_device_context * ctx = (ggml_backend_kompute_device_context *)dev->context;
+return ggml_backend_kompute_buffer_type(ctx->device);
+}
+
+static bool ggml_backend_kompute_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
+if (buft->iface.get_name != ggml_backend_kompute_buffer_type_get_name) {
+return false;
+}
+
+ggml_backend_kompute_device_context * ctx = (ggml_backend_kompute_device_context *)dev->context;
+ggml_backend_kompute_buffer_type_context * buft_ctx = (ggml_backend_kompute_buffer_type_context *)buft->context;
+
+return buft_ctx->device == ctx->device;
+}
+
+static enum ggml_backend_dev_type ggml_backend_kompute_device_get_type(ggml_backend_dev_t dev) {
+GGML_UNUSED(dev);
+return GGML_BACKEND_DEVICE_TYPE_GPU;
+}
+
+static void ggml_backend_kompute_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
+props->name = ggml_backend_kompute_device_get_name(dev);
+props->description = ggml_backend_kompute_device_get_description(dev);
+props->type = ggml_backend_kompute_device_get_type(dev);
+ggml_backend_kompute_device_get_memory(dev, &props->memory_free, &props->memory_total);
+props->caps = {
+/* async = */ false,
+/* host_buffer = */ false,
+/* .buffer_from_host_ptr = */ false,
+/* events = */ false,
+};
+}
+
+static ggml_backend_t ggml_backend_kompute_device_init(ggml_backend_dev_t dev, const char * params) {
+GGML_UNUSED(params);
+ggml_backend_kompute_device_context * ctx = (ggml_backend_kompute_device_context *)dev->context;
+return ggml_backend_kompute_init(ctx->device);
+}
+
+static bool ggml_backend_kompute_device_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
+const int min_batch_size = 32;
+
+return (op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS) ||
+(op->ne[2] >= min_batch_size && op->op == GGML_OP_MUL_MAT_ID);
+
+GGML_UNUSED(dev);
+}
+
+static const struct ggml_backend_device_i ggml_backend_kompute_device_i = {
+/* .get_name = */ ggml_backend_kompute_device_get_name,
+/* .get_description = */ ggml_backend_kompute_device_get_description,
+/* .get_memory = */ ggml_backend_kompute_device_get_memory,
+/* .get_type = */ ggml_backend_kompute_device_get_type,
+/* .get_props = */ ggml_backend_kompute_device_get_props,
+/* .init_backend = */ ggml_backend_kompute_device_init,
+/* .get_buffer_type = */ ggml_backend_kompute_device_get_buffer_type,
+/* .get_host_buffer_type = */ NULL,
+/* .buffer_from_host_ptr = */ NULL,
+/* .supports_op = */ ggml_backend_kompute_device_supports_op,
+/* .supports_buft = */ ggml_backend_kompute_device_supports_buft,
+/* .offload_op = */ ggml_backend_kompute_device_offload_op,
+/* .event_new = */ NULL,
+/* .event_free = */ NULL,
+/* .event_synchronize = */ NULL,
+};
+
+static const char * ggml_backend_kompute_reg_get_name(ggml_backend_reg_t reg) {
+GGML_UNUSED(reg);
+return "Kompute";
+}
+
+static size_t ggml_backend_kompute_reg_get_device_count(ggml_backend_reg_t reg) {
+GGML_UNUSED(reg);
+return ggml_backend_kompute_get_device_count();
+}
+
+static ggml_backend_dev_t ggml_backend_kompute_reg_get_device(ggml_backend_reg_t reg, size_t device) {
+static std::vector<ggml_backend_dev_t> devices;
+
+static bool initialized = false;
+
+{
+static std::mutex mutex;
+std::lock_guard<std::mutex> lock(mutex);
+if (!initialized) {
+for (size_t i = 0; i < ggml_backend_kompute_get_device_count(); i++) {
+ggml_backend_kompute_device_context * ctx = new ggml_backend_kompute_device_context;
+char desc[256];
+ggml_backend_kompute_get_device_description(i, desc, sizeof(desc));
+ctx->device = i;
+ctx->name = "Kompute" + std::to_string(i);
+ctx->description = desc;
+devices.push_back(new ggml_backend_device {
+/* .iface = */ ggml_backend_kompute_device_i,
+/* .reg = */ reg,
+/* .context = */ ctx,
+});
+}
+initialized = true;
+}
+}
+
+GGML_ASSERT(device < devices.size());
+return devices[device];
+}
+
+static const struct ggml_backend_reg_i ggml_backend_kompute_reg_i = {
+/* .get_name = */ ggml_backend_kompute_reg_get_name,
+/* .get_device_count = */ ggml_backend_kompute_reg_get_device_count,
+/* .get_device = */ ggml_backend_kompute_reg_get_device,
+/* .get_proc_address = */ NULL,
+};
+
+ggml_backend_reg_t ggml_backend_kompute_reg() {
+static ggml_backend_reg reg = {
+/* .iface = */ ggml_backend_kompute_reg_i,
+/* .context = */ nullptr,
+};
+
+return &reg;
+}
@ -242,6 +242,8 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16,
     GGML_METAL_KERNEL_TYPE_IM2COL_F16,
     GGML_METAL_KERNEL_TYPE_IM2COL_F32,
+    GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16,
+    GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32,
     GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
     GGML_METAL_KERNEL_TYPE_PAD_F32,
     GGML_METAL_KERNEL_TYPE_ARANGE_F32,
@ -273,6 +275,8 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_SIN,
     GGML_METAL_KERNEL_TYPE_COS,
     GGML_METAL_KERNEL_TYPE_SUM_ROWS,
+    GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
+    GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,

     GGML_METAL_KERNEL_TYPE_COUNT
 };
@ -687,6 +691,8 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16, rope_neox_f16, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16, im2col_ext_f16, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32, im2col_ext_f32, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32, timestep_embedding_f32, true);
@ -718,6 +724,8 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIN, sin, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, pool_2d_avg_f32, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, pool_2d_max_f32, true);
     }

     [metal_library release];
@ -846,8 +854,8 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
         case GGML_OP_IM2COL:
             return op->src[0]->type == GGML_TYPE_F16;
         case GGML_OP_POOL_1D:
-        case GGML_OP_POOL_2D:
             return false;
+        case GGML_OP_POOL_2D:
         case GGML_OP_UPSCALE:
         case GGML_OP_PAD:
         case GGML_OP_ARANGE:
@ -1009,19 +1017,21 @@ static void ggml_metal_encode_node(
     id<MTLBuffer> id_src2 = src2 ? ggml_metal_get_buffer(src2, &offs_src2) : nil;
     id<MTLBuffer> id_dst  = dst  ? ggml_metal_get_buffer(dst,  &offs_dst)  : nil;

-    //GGML_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op));
-    //if (src0) {
-    //    GGML_LOG_INFO("%s: src0 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02,
-    //            ggml_is_contiguous(src0), src0->name);
-    //}
-    //if (src1) {
-    //    GGML_LOG_INFO("%s: src1 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12,
-    //            ggml_is_contiguous(src1), src1->name);
-    //}
-    //if (dst) {
-    //    GGML_LOG_INFO("%s: dst  - %4s [%5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2,
-    //            dst->name);
-    //}
+#if 0
+    GGML_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op));
+    if (src0) {
+        GGML_LOG_INFO("%s: src0 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03,
+                ggml_is_contiguous(src0), src0->name);
+    }
+    if (src1) {
+        GGML_LOG_INFO("%s: src1 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12, ne13, nb10, nb11, nb12, nb13,
+                ggml_is_contiguous(src1), src1->name);
+    }
+    if (dst) {
+        GGML_LOG_INFO("%s: dst  - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2, ne3, nb0, nb1, nb2, nb3,
+                dst->name);
+    }
+#endif

     id<MTLDevice> device = ctx_dev->mtl_device;

@ -1841,14 +1851,16 @@ static void ggml_metal_encode_node(
             [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
             [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5];
             [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6];
-            [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7];
-            [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:8];
-            [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:9];
-            [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:10];
-            [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:11];
-            [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:12];
-            [encoder setBytes:&r2 length:sizeof(r2) atIndex:13];
-            [encoder setBytes:&r3 length:sizeof(r3) atIndex:14];
+            [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:7];
+            [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:8];
+            [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:9];
+            [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:10];
+            [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:11];
+            [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:12];
+            [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13];
+            [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14];
+            [encoder setBytes:&r2 length:sizeof(r2) atIndex:15];
+            [encoder setBytes:&r3 length:sizeof(r3) atIndex:16];
             [encoder setThreadgroupMemoryLength:8192 atIndex:0];
             [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
         } else {
@ -2017,16 +2029,18 @@ static void ggml_metal_encode_node(
             [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
             [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
             [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
-            [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:9];
-            [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:10];
-            [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:11];
-            [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:12];
-            [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:13];
-            [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:14];
-            [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:15];
-            [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16];
-            [encoder setBytes:&r2 length:sizeof(r2) atIndex:17];
-            [encoder setBytes:&r3 length:sizeof(r3) atIndex:18];
+            [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
+            [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
+            [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11];
+            [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
+            [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:13];
+            [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:14];
+            [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:15];
+            [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:16];
+            [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:17];
+            [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:18];
+            [encoder setBytes:&r2 length:sizeof(r2) atIndex:19];
+            [encoder setBytes:&r3 length:sizeof(r3) atIndex:20];

             if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 ||
                 src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K ||
@ -2079,6 +2093,9 @@ static void ggml_metal_encode_node(

             GGML_ASSERT(src1t == GGML_TYPE_F32);

+            GGML_ASSERT(ne03 == 1);
+            GGML_ASSERT(ne13 == 1);
+
             // find the break-even point where the matrix-matrix kernel becomes more efficient compared
             // to the matrix-vector kernel
             // ne20 = n_used_experts
@ -2584,6 +2601,8 @@ static void ggml_metal_encode_node(
         } break;
     case GGML_OP_IM2COL:
         {
+            GGML_ASSERT(ggml_is_contiguous(src0));
+            GGML_ASSERT(ggml_is_contiguous(src1));
            GGML_ASSERT(src0->type == GGML_TYPE_F16);
            GGML_ASSERT(src1->type == GGML_TYPE_F32);
            GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);
@ -2613,30 +2632,54 @@ static void ggml_metal_encode_node(
            const int32_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4;
            const int32_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4;

-           id<MTLComputePipelineState> pipeline = nil;
+           id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline;
+
+           const bool is_gt_mttpt = ((size_t)(N * KH * KW)) > pipeline.maxTotalThreadsPerThreadgroup;

            switch (dst->type) {
-               case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline; break;
-               case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F16].pipeline; break;
+               case GGML_TYPE_F32: {
+                   pipeline = (is_gt_mttpt ?
+                                   ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32].pipeline
+                                   :
+                                   ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline);
+               } break;
+               case GGML_TYPE_F16: {
+                   pipeline = (is_gt_mttpt ?
+                                   ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16].pipeline
+                                   :
+                                   ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F16].pipeline);
+               } break;
                default: GGML_ABORT("fatal error");
            };

            [encoder setComputePipelineState:pipeline];
            [encoder setBuffer:id_src1 offset:offs_src1 atIndex:0];
            [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
-           [encoder setBytes:&ofs0 length:sizeof( int32_t) atIndex:2];
-           [encoder setBytes:&ofs1 length:sizeof( int32_t) atIndex:3];
-           [encoder setBytes:&IW length:sizeof( int32_t) atIndex:4];
-           [encoder setBytes:&IH length:sizeof( int32_t) atIndex:5];
-           [encoder setBytes:&CHW length:sizeof( int32_t) atIndex:6];
-           [encoder setBytes:&s0 length:sizeof( int32_t) atIndex:7];
-           [encoder setBytes:&s1 length:sizeof( int32_t) atIndex:8];
-           [encoder setBytes:&p0 length:sizeof( int32_t) atIndex:9];
-           [encoder setBytes:&p1 length:sizeof( int32_t) atIndex:10];
-           [encoder setBytes:&d0 length:sizeof( int32_t) atIndex:11];
-           [encoder setBytes:&d1 length:sizeof( int32_t) atIndex:12];
+           [encoder setBytes:&ofs0 length:sizeof(int32_t) atIndex:2];
+           [encoder setBytes:&ofs1 length:sizeof(int32_t) atIndex:3];
+           [encoder setBytes:&IW length:sizeof(int32_t) atIndex:4];
+           [encoder setBytes:&IH length:sizeof(int32_t) atIndex:5];
+           [encoder setBytes:&CHW length:sizeof(int32_t) atIndex:6];
+           [encoder setBytes:&s0 length:sizeof(int32_t) atIndex:7];
+           [encoder setBytes:&s1 length:sizeof(int32_t) atIndex:8];
+           [encoder setBytes:&p0 length:sizeof(int32_t) atIndex:9];
+           [encoder setBytes:&p1 length:sizeof(int32_t) atIndex:10];
+           [encoder setBytes:&d0 length:sizeof(int32_t) atIndex:11];
+           [encoder setBytes:&d1 length:sizeof(int32_t) atIndex:12];

+           if (is_gt_mttpt) {
+               [encoder setBytes:&N length:sizeof(int32_t) atIndex:13];
+               [encoder setBytes:&KH length:sizeof(int32_t) atIndex:14];
+               [encoder setBytes:&KW length:sizeof(int32_t) atIndex:15];
+
+               const uint64_t n_threads = MIN(pipeline.maxTotalThreadsPerThreadgroup, (uint64_t)N);
+
+               const int64_t quotient = N / n_threads + (N % n_threads > 0 ? 1 : 0);
+
+               [encoder dispatchThreadgroups:MTLSizeMake(quotient * CHW, OH, OW) threadsPerThreadgroup:MTLSizeMake(n_threads, 1, 1)];
+           } else {
               [encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
+           }
        } break;
    case GGML_OP_UPSCALE:
        {
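The new is_gt_mttpt path above exists because Metal caps threadsPerThreadgroup at the pipeline's maxTotalThreadsPerThreadgroup; when N * KH * KW exceeds that cap, the work is re-tiled into more, smaller threadgroups. A standalone sketch of that dispatch arithmetic follows; all values are examples, and the cap of 1024 is a stand-in for the real pipeline limit.

    // Sketch of the im2col dispatch-size arithmetic used above (example values only).
    #include <algorithm>
    #include <cstdint>
    #include <cstdio>

    int main() {
        const int64_t N = 3000, KH = 3, KW = 3, CHW = 9, OH = 32, OW = 32, IC = 16;
        const uint64_t max_threads = 1024; // stand-in for pipeline.maxTotalThreadsPerThreadgroup

        const bool is_gt_mttpt = (uint64_t)(N * KH * KW) > max_threads;
        if (is_gt_mttpt) {
            const uint64_t n_threads = std::min<uint64_t>(max_threads, (uint64_t) N);
            const int64_t  quotient  = N / (int64_t) n_threads + (N % (int64_t) n_threads > 0 ? 1 : 0); // ceil(N / n_threads)
            // threadgroups: (quotient * CHW, OH, OW), threads per group: (n_threads, 1, 1)
            printf("ext path: %lld x %lld x %lld groups of %llu threads\n",
                   (long long)(quotient * CHW), (long long) OH, (long long) OW, (unsigned long long) n_threads);
        } else {
            // threadgroups: (IC, OH, OW), threads per group: (N, KH, KW) -- the original path
            printf("basic path: (%lld, %lld, %lld) groups of (%lld, %lld, %lld) threads\n",
                   (long long) IC, (long long) OH, (long long) OW, (long long) N, (long long) KH, (long long) KW);
        }
        return 0;
    }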
@ -3040,6 +3083,64 @@ static void ggml_metal_encode_node(

            [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
        } break;
+   case GGML_OP_POOL_2D:
+       {
+           GGML_ASSERT(ggml_is_contiguous(src0));
+           GGML_ASSERT(src0t == GGML_TYPE_F32 && src0t == dstt);
+
+           const int32_t * opts = dst->op_params;
+           enum ggml_op_pool op = opts[0];
+
+           id<MTLComputePipelineState> pipeline = nil;
+           switch (src0t) {
+               case GGML_TYPE_F32: {
+                   switch(op) {
+                       case GGML_OP_POOL_AVG:
+                           pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32].pipeline; break;
+                       case GGML_OP_POOL_MAX:
+                           pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32].pipeline; break;
+                       default: GGML_ASSERT(false && "not implemented");
+                   }
+               } break;
+               default: GGML_ASSERT(false && "not implemented");
+           }
+
+           const int32_t k0 = opts[1];
+           const int32_t k1 = opts[2];
+           const int32_t s0 = opts[3];
+           const int32_t s1 = opts[4];
+           const int32_t p0 = opts[5];
+           const int32_t p1 = opts[6];
+
+           const int64_t IH = src0->ne[1];
+           const int64_t IW = src0->ne[0];
+
+           const int64_t N = dst->ne[3];
+           const int64_t OC = dst->ne[2];
+           const int64_t OH = dst->ne[1];
+           const int64_t OW = dst->ne[0];
+
+           const int64_t parallel_elements = N * OC * OH * OW;
+           const int64_t n_threads = MIN((int64_t)[pipeline maxTotalThreadsPerThreadgroup], parallel_elements);
+           const int64_t n_tg = (parallel_elements + n_threads - 1) / n_threads;
+
+           [encoder setComputePipelineState:pipeline];
+           [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+           [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+           [encoder setBytes:&k0 length:sizeof(int32_t) atIndex:2];
+           [encoder setBytes:&k1 length:sizeof(int32_t) atIndex:3];
+           [encoder setBytes:&s0 length:sizeof(int32_t) atIndex:4];
+           [encoder setBytes:&s1 length:sizeof(int32_t) atIndex:5];
+           [encoder setBytes:&p0 length:sizeof(int32_t) atIndex:6];
+           [encoder setBytes:&p1 length:sizeof(int32_t) atIndex:7];
+           [encoder setBytes:&IH length:sizeof(int64_t) atIndex:8];
+           [encoder setBytes:&IW length:sizeof(int64_t) atIndex:9];
+           [encoder setBytes:&OH length:sizeof(int64_t) atIndex:10];
+           [encoder setBytes:&OW length:sizeof(int64_t) atIndex:11];
+           [encoder setBytes:&parallel_elements length:sizeof(int64_t) atIndex:12];
+
+           [encoder dispatchThreadgroups:MTLSizeMake(n_tg, 1, 1) threadsPerThreadgroup:MTLSizeMake(n_threads, 1, 1)];
+       } break;
    default:
        {
            GGML_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(dst->op));
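The new GGML_OP_POOL_2D case above flattens the whole output tensor into parallel_elements threads and chunks them into 1-D threadgroups of at most maxTotalThreadsPerThreadgroup threads each. A standalone sketch of that sizing (the output shape and thread cap below are example values only):

    // Sketch of the 1-D work split used by the POOL_2D case above.
    #include <algorithm>
    #include <cstdint>
    #include <cstdio>

    int main() {
        const int64_t N = 1, OC = 64, OH = 56, OW = 56; // example output shape
        const int64_t max_threads = 1024;               // stand-in for [pipeline maxTotalThreadsPerThreadgroup]

        const int64_t parallel_elements = N * OC * OH * OW;
        const int64_t n_threads = std::min<int64_t>(max_threads, parallel_elements);
        const int64_t n_tg = (parallel_elements + n_threads - 1) / n_threads; // ceil division

        printf("%lld elements -> %lld threadgroups x %lld threads\n",
               (long long) parallel_elements, (long long) n_tg, (long long) n_threads);
        return 0;
    }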
@ -3185,12 +3286,6 @@ static enum ggml_status ggml_metal_graph_compute(

 // backend interface

-static const char * ggml_backend_metal_buffer_get_name(ggml_backend_buffer_t buffer) {
-    return "Metal";
-
-    UNUSED(buffer);
-}
-
 static void ggml_backend_metal_buffer_free_buffer(ggml_backend_buffer_t buffer) {
     struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context;

@ -3245,7 +3340,6 @@ static void ggml_backend_metal_buffer_clear(ggml_backend_buffer_t buffer, uint8_
 }

 static struct ggml_backend_buffer_i ggml_backend_metal_buffer_i = {
-    /* .get_name        = */ ggml_backend_metal_buffer_get_name,
     /* .free_buffer     = */ ggml_backend_metal_buffer_free_buffer,
     /* .get_base        = */ ggml_backend_metal_buffer_get_base,
     /* .init_tensor     = */ NULL,
@ -3370,6 +3464,29 @@ ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void) {
     return &ggml_backend_buffer_type_metal;
 }

+static const char * ggml_backend_metal_buffer_from_ptr_type_get_name(ggml_backend_buffer_type_t buft) {
+    return "Metal_Mapped";
+
+    UNUSED(buft);
+}
+
+static ggml_backend_buffer_type_t ggml_backend_metal_buffer_from_ptr_type(void) {
+    static struct ggml_backend_buffer_type ggml_backend_buffer_from_ptr_type_metal = {
+        /* .iface = */ {
+            /* .get_name         = */ ggml_backend_metal_buffer_from_ptr_type_get_name,
+            /* .alloc_buffer     = */ ggml_backend_metal_buffer_type_alloc_buffer,
+            /* .get_alignment    = */ ggml_backend_metal_buffer_type_get_alignment,
+            /* .get_max_size     = */ ggml_backend_metal_buffer_type_get_max_size,
+            /* .get_alloc_size   = */ NULL, // defaults to ggml_nbytes
+            /* .is_host          = */ ggml_backend_metal_buffer_type_is_host,
+        },
+        /* .device  = */ &g_ggml_backend_metal_device,
+        /* .context = */ NULL,
+    };
+
+    return &ggml_backend_buffer_from_ptr_type_metal;
+}
+
 // TODO: obsoleted by ggml_backend_metal_device_buffer_from_ptr
 ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t size, size_t max_size) {
     struct ggml_backend_metal_buffer_context * ctx = calloc(1, sizeof(struct ggml_backend_metal_buffer_context));
@ -3446,7 +3563,7 @@ ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t siz
         }
     }

-    return ggml_backend_buffer_init(ggml_backend_metal_buffer_type(), ggml_backend_metal_buffer_i, ctx, size);
+    return ggml_backend_buffer_init(ggml_backend_metal_buffer_from_ptr_type(), ggml_backend_metal_buffer_i, ctx, size);
 }

 // backend
@ -3467,12 +3584,6 @@ static void ggml_backend_metal_free(ggml_backend_t backend) {
     free(backend);
 }

-static ggml_backend_buffer_type_t ggml_backend_metal_get_default_buffer_type(ggml_backend_t backend) {
-    return ggml_backend_metal_buffer_type();
-
-    UNUSED(backend);
-}
-
 static enum ggml_status ggml_backend_metal_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
     return ggml_metal_graph_compute(backend, cgraph);
 }
@ -3539,7 +3650,6 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
 static struct ggml_backend_i ggml_backend_metal_i = {
     /* .get_name                = */ ggml_backend_metal_name,
     /* .free                    = */ ggml_backend_metal_free,
-    /* .get_default_buffer_type = */ ggml_backend_metal_get_default_buffer_type,
     /* .set_tensor_async        = */ NULL,
     /* .get_tensor_async        = */ NULL,
     /* .cpy_tensor_async        = */ NULL,
@ -3549,9 +3659,6 @@ static struct ggml_backend_i ggml_backend_metal_i = {
     /* .graph_plan_update       = */ NULL,
     /* .graph_plan_compute      = */ NULL,
     /* .graph_compute           = */ ggml_backend_metal_graph_compute,
-    /* .supports_op             = */ NULL,
-    /* .supports_buft           = */ NULL,
-    /* .offload_op              = */ NULL,
     /* .event_record            = */ NULL,
     /* .event_wait              = */ NULL,
 };
@ -3646,7 +3753,7 @@ static void ggml_backend_metal_device_get_memory(ggml_backend_dev_t dev, size_t
 }

 static enum ggml_backend_dev_type ggml_backend_metal_device_get_type(ggml_backend_dev_t dev) {
-    return GGML_BACKEND_DEVICE_TYPE_GPU_FULL;
+    return GGML_BACKEND_DEVICE_TYPE_GPU;

     GGML_UNUSED(dev);
 }
File diff suppressed because it is too large
@ -57,8 +57,9 @@ struct socket_t {
     }
 };

-// ggml_tensor is serialized into rpc_tensor
+// all RPC structures must be packed
 #pragma pack(push, 1)
+// ggml_tensor is serialized into rpc_tensor
 struct rpc_tensor {
     uint64_t id;
     uint32_t type;
@ -76,7 +77,6 @@ struct rpc_tensor {

     char padding[4];
 };
-#pragma pack(pop)

 static_assert(sizeof(rpc_tensor) % 8 == 0, "rpc_tensor size must be multiple of 8");

@ -96,6 +96,65 @@ enum rpc_cmd {
     RPC_CMD_COUNT,
 };

+struct rpc_msg_alloc_buffer_req {
+    uint64_t size;
+};
+
+struct rpc_msg_alloc_buffer_rsp {
+    uint64_t remote_ptr;
+    uint64_t remote_size;
+};
+
+struct rpc_msg_get_alignment_rsp {
+    uint64_t alignment;
+};
+
+struct rpc_msg_get_max_size_rsp {
+    uint64_t max_size;
+};
+
+struct rpc_msg_buffer_get_base_req {
+    uint64_t remote_ptr;
+};
+
+struct rpc_msg_buffer_get_base_rsp {
+    uint64_t base_ptr;
+};
+
+struct rpc_msg_free_buffer_req {
+    uint64_t remote_ptr;
+};
+
+struct rpc_msg_buffer_clear_req {
+    uint64_t remote_ptr;
+    uint8_t value;
+};
+
+struct rpc_msg_get_tensor_req {
+    rpc_tensor tensor;
+    uint64_t offset;
+    uint64_t size;
+};
+
+struct rpc_msg_copy_tensor_req {
+    rpc_tensor src;
+    rpc_tensor dst;
+};
+
+struct rpc_msg_copy_tensor_rsp {
+    uint8_t result;
+};
+
+struct rpc_msg_graph_compute_rsp {
+    uint8_t result;
+};
+
+struct rpc_msg_get_device_memory_rsp {
+    uint64_t free_mem;
+    uint64_t total_mem;
+};
+#pragma pack(pop)
+
 // RPC data structures

 static ggml_guid_t ggml_backend_rpc_guid() {
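The structs above replace hand-rolled byte buffers with packed, fixed-size request/response messages, so each command's wire size is a compile-time constant that both client and server can validate. The checks below are illustrative only, are not part of this commit, and assume the pack(1) layout shown above:

    // Sketch: expected wire sizes of a few of the packed messages defined above.
    static_assert(sizeof(rpc_msg_alloc_buffer_req)      == 8,  "size (8)");
    static_assert(sizeof(rpc_msg_alloc_buffer_rsp)      == 16, "remote_ptr (8) + remote_size (8)");
    static_assert(sizeof(rpc_msg_buffer_clear_req)      == 9,  "remote_ptr (8) + value (1), no padding due to pack(1)");
    static_assert(sizeof(rpc_msg_get_device_memory_rsp) == 16, "free_mem (8) + total_mem (8)");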
@ -119,7 +178,6 @@ struct ggml_backend_rpc_buffer_context {
     std::shared_ptr<socket_t> sock;
     std::unordered_map<ggml_backend_buffer_t, void *> base_cache;
     uint64_t remote_ptr;
-    std::string name;
 };

 // RPC helper functions
@ -240,6 +298,38 @@ static bool recv_data(sockfd_t sockfd, void * data, size_t size) {
     return true;
 }

+static bool send_msg(sockfd_t sockfd, const void * msg, size_t msg_size) {
+    if (!send_data(sockfd, &msg_size, sizeof(msg_size))) {
+        return false;
+    }
+    return send_data(sockfd, msg, msg_size);
+}
+
+static bool recv_msg(sockfd_t sockfd, void * msg, size_t msg_size) {
+    uint64_t size;
+    if (!recv_data(sockfd, &size, sizeof(size))) {
+        return false;
+    }
+    if (size != msg_size) {
+        return false;
+    }
+    return recv_data(sockfd, msg, msg_size);
+}
+
+static bool recv_msg(sockfd_t sockfd, std::vector<uint8_t> & input) {
+    uint64_t size;
+    if (!recv_data(sockfd, &size, sizeof(size))) {
+        return false;
+    }
+    try {
+        input.resize(size);
+    } catch (const std::bad_alloc & e) {
+        fprintf(stderr, "Failed to allocate input buffer of size %" PRIu64 "\n", size);
+        return false;
+    }
+    return recv_data(sockfd, input.data(), size);
+}
+
 static bool parse_endpoint(const std::string & endpoint, std::string & host, int & port) {
     size_t pos = endpoint.find(':');
     if (pos == std::string::npos) {
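send_msg and recv_msg keep the existing length-prefixed framing, an 8-byte size followed by the payload, but let fixed-size commands use a struct directly while variable-size payloads still go through the std::vector overload. A sketch of how a server-side handler might use them; the handler name and the dispatch loop it would live in are assumptions, while rpc_server and its buffer_clear method are declared later in this file:

    // Illustrative only: dispatching one fixed-size command with the new recv_msg/send_msg helpers.
    static bool handle_buffer_clear_sketch(sockfd_t fd, rpc_server & server) {
        rpc_msg_buffer_clear_req request;
        if (!recv_msg(fd, &request, sizeof(request))) {   // fails if the peer sent a different size
            return false;
        }
        if (!server.buffer_clear(request)) {
            return false;
        }
        return send_msg(fd, nullptr, 0);                  // empty response body
    }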
@ -252,28 +342,27 @@ static bool parse_endpoint(const std::string & endpoint, std::string & host, int

 // RPC request : | rpc_cmd (1 byte) | request_size (8 bytes) | request_data (request_size bytes) |
 // RPC response: | response_size (8 bytes) | response_data (response_size bytes) |
-static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cmd, const std::vector<uint8_t> & input, std::vector<uint8_t> & output) {
+static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cmd, const void * input, size_t input_size, void * output, size_t output_size) {
     uint8_t cmd_byte = cmd;
     if (!send_data(sock->fd, &cmd_byte, sizeof(cmd_byte))) {
         return false;
     }
-    uint64_t input_size = input.size();
     if (!send_data(sock->fd, &input_size, sizeof(input_size))) {
         return false;
     }
-    if (!send_data(sock->fd, input.data(), input.size())) {
+    if (!send_data(sock->fd, input, input_size)) {
         return false;
     }
-    uint64_t output_size;
-    if (!recv_data(sock->fd, &output_size, sizeof(output_size))) {
+    // TODO: currently the output_size is always known, do we need support for commands with variable output size?
+    // even if we do, we can skip sending output_size from the server for commands with known output size
+    uint64_t out_size;
+    if (!recv_data(sock->fd, &out_size, sizeof(out_size))) {
         return false;
     }
-    if (output_size == 0) {
-        output.clear();
-        return true;
+    if (out_size != output_size) {
+        return false;
     }
-    output.resize(output_size);
-    if (!recv_data(sock->fd, output.data(), output_size)) {
+    if (!recv_data(sock->fd, output, output_size)) {
         return false;
     }
     return true;
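With the reworked send_rpc_cmd above, the caller states both the request size and the expected response size up front, and the call fails if the server replies with any other size. A hedged usage sketch follows; the helper name query_device_memory_sketch is illustrative and simply mirrors the get_device_memory client code that appears later in this diff:

    // Sketch of a client-side round trip through the reworked helper above.
    static bool query_device_memory_sketch(const std::shared_ptr<socket_t> & sock,
                                           uint64_t & free_mem, uint64_t & total_mem) {
        rpc_msg_get_device_memory_rsp response;
        if (!send_rpc_cmd(sock, RPC_CMD_GET_DEVICE_MEMORY, nullptr, 0, &response, sizeof(response))) {
            return false; // transport error or unexpected response size
        }
        free_mem  = response.free_mem;
        total_mem = response.total_mem;
        return true;
    }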
@ -319,21 +408,11 @@ static std::shared_ptr<socket_t> get_socket(const std::string & endpoint) {
     return sock;
 }

-static const char * ggml_backend_rpc_buffer_get_name(ggml_backend_buffer_t buffer) {
-    ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
-    return ctx->name.c_str();
-}
-
 static void ggml_backend_rpc_buffer_free_buffer(ggml_backend_buffer_t buffer) {
     ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
-    // input serialization format: | remote_ptr (8 bytes) |
-    std::vector<uint8_t> input(sizeof(uint64_t), 0);
-    uint64_t remote_ptr = ctx->remote_ptr;
-    memcpy(input.data(), &remote_ptr, sizeof(remote_ptr));
-    std::vector<uint8_t> output;
-    bool status = send_rpc_cmd(ctx->sock, RPC_CMD_FREE_BUFFER, input, output);
+    rpc_msg_free_buffer_req request = {ctx->remote_ptr};
+    bool status = send_rpc_cmd(ctx->sock, RPC_CMD_FREE_BUFFER, &request, sizeof(request), nullptr, 0);
     GGML_ASSERT(status);
-    GGML_ASSERT(output.empty());
     delete ctx;
 }

@ -342,20 +421,13 @@ static void * ggml_backend_rpc_buffer_get_base(ggml_backend_buffer_t buffer) {
     if (ctx->base_cache.find(buffer) != ctx->base_cache.end()) {
         return ctx->base_cache[buffer];
     }
-    // input serialization format: | remote_ptr (8 bytes) |
-    std::vector<uint8_t> input(sizeof(uint64_t), 0);
-    uint64_t remote_ptr = ctx->remote_ptr;
-    memcpy(input.data(), &remote_ptr, sizeof(remote_ptr));
-    std::vector<uint8_t> output;
-    bool status = send_rpc_cmd(ctx->sock, RPC_CMD_BUFFER_GET_BASE, input, output);
+    rpc_msg_buffer_get_base_req request = {ctx->remote_ptr};
+    rpc_msg_buffer_get_base_rsp response;
+    bool status = send_rpc_cmd(ctx->sock, RPC_CMD_BUFFER_GET_BASE, &request, sizeof(request), &response, sizeof(response));
     GGML_ASSERT(status);
-    GGML_ASSERT(output.size() == sizeof(uint64_t));
-    // output serialization format: | base_ptr (8 bytes) |
-    uint64_t base_ptr;
-    memcpy(&base_ptr, output.data(), sizeof(base_ptr));
-    void * base = reinterpret_cast<void *>(base_ptr);
-    ctx->base_cache[buffer] = base;
-    return base;
+    void * base_ptr = reinterpret_cast<void *>(response.base_ptr);
+    ctx->base_cache[buffer] = base_ptr;
+    return base_ptr;
 }

 static rpc_tensor serialize_tensor(const ggml_tensor * tensor) {
@ -405,26 +477,18 @@ static void ggml_backend_rpc_buffer_set_tensor(ggml_backend_buffer_t buffer, ggm
     memcpy(input.data(), &rpc_tensor, sizeof(rpc_tensor));
     memcpy(input.data() + sizeof(rpc_tensor), &offset, sizeof(offset));
     memcpy(input.data() + sizeof(rpc_tensor) + sizeof(offset), data, size);
-    std::vector<uint8_t> output;
-    bool status = send_rpc_cmd(ctx->sock, RPC_CMD_SET_TENSOR, input, output);
+    bool status = send_rpc_cmd(ctx->sock, RPC_CMD_SET_TENSOR, input.data(), input.size(), nullptr, 0);
     GGML_ASSERT(status);
 }

 static void ggml_backend_rpc_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
     ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
-    // input serialization format: | rpc_tensor | offset (8 bytes) | size (8 bytes) |
-    int input_size = sizeof(rpc_tensor) + 2*sizeof(uint64_t);
-    std::vector<uint8_t> input(input_size, 0);
-    rpc_tensor rpc_tensor = serialize_tensor(tensor);
-    memcpy(input.data(), &rpc_tensor, sizeof(rpc_tensor));
-    memcpy(input.data() + sizeof(rpc_tensor), &offset, sizeof(offset));
-    memcpy(input.data() + sizeof(rpc_tensor) + sizeof(offset), &size, sizeof(size));
-    std::vector<uint8_t> output;
-    bool status = send_rpc_cmd(ctx->sock, RPC_CMD_GET_TENSOR, input, output);
+    rpc_msg_get_tensor_req request;
+    request.tensor = serialize_tensor(tensor);
+    request.offset = offset;
+    request.size = size;
+    bool status = send_rpc_cmd(ctx->sock, RPC_CMD_GET_TENSOR, &request, sizeof(request), data, size);
     GGML_ASSERT(status);
-    GGML_ASSERT(output.size() == size);
-    // output serialization format: | data (size bytes) |
-    memcpy(data, output.data(), size);
 }

 static bool ggml_backend_rpc_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) {
@ -437,35 +501,23 @@ static bool ggml_backend_rpc_buffer_cpy_tensor(ggml_backend_buffer_t buffer, con
         return false;
     }
     ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
-    // input serialization format: | rpc_tensor src | rpc_tensor dst |
-    int input_size = 2*sizeof(rpc_tensor);
-    std::vector<uint8_t> input(input_size, 0);
-    rpc_tensor rpc_src = serialize_tensor(src);
-    rpc_tensor rpc_dst = serialize_tensor(dst);
-    memcpy(input.data(), &rpc_src, sizeof(rpc_src));
-    memcpy(input.data() + sizeof(rpc_src), &rpc_dst, sizeof(rpc_dst));
-    std::vector<uint8_t> output;
-    bool status = send_rpc_cmd(ctx->sock, RPC_CMD_COPY_TENSOR, input, output);
+    rpc_msg_copy_tensor_req request;
+    request.src = serialize_tensor(src);
+    request.dst = serialize_tensor(dst);
+    rpc_msg_copy_tensor_rsp response;
+    bool status = send_rpc_cmd(ctx->sock, RPC_CMD_COPY_TENSOR, &request, sizeof(request), &response, sizeof(response));
     GGML_ASSERT(status);
-    // output serialization format: | result (1 byte) |
-    GGML_ASSERT(output.size() == 1);
-    return output[0];
+    return response.result;
 }

 static void ggml_backend_rpc_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
     ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
-    // serialization format: | bufptr (8 bytes) | value (1 byte) |
-    int input_size = sizeof(uint64_t) + sizeof(uint8_t);
-    std::vector<uint8_t> input(input_size, 0);
-    memcpy(input.data(), &ctx->remote_ptr, sizeof(ctx->remote_ptr));
-    memcpy(input.data() + sizeof(ctx->remote_ptr), &value, sizeof(value));
-    std::vector<uint8_t> output;
-    bool status = send_rpc_cmd(ctx->sock, RPC_CMD_BUFFER_CLEAR, input, output);
+    rpc_msg_buffer_clear_req request = {ctx->remote_ptr, value};
+    bool status = send_rpc_cmd(ctx->sock, RPC_CMD_BUFFER_CLEAR, &request, sizeof(request), nullptr, 0);
     GGML_ASSERT(status);
 }

 static ggml_backend_buffer_i ggml_backend_rpc_buffer_interface = {
-    /* .get_name        = */ ggml_backend_rpc_buffer_get_name,
     /* .free_buffer     = */ ggml_backend_rpc_buffer_free_buffer,
     /* .get_base        = */ ggml_backend_rpc_buffer_get_base,
     /* .init_tensor     = */ ggml_backend_rpc_buffer_init_tensor,
@ -484,25 +536,16 @@ static const char * ggml_backend_rpc_buffer_type_name(ggml_backend_buffer_type_t

 static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
     ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
-    // input serialization format: | size (8 bytes) |
-    int input_size = sizeof(uint64_t);
-    std::vector<uint8_t> input(input_size, 0);
-    memcpy(input.data(), &size, sizeof(size));
-    std::vector<uint8_t> output;
+    rpc_msg_alloc_buffer_req request = {size};
+    rpc_msg_alloc_buffer_rsp response;
     auto sock = get_socket(buft_ctx->endpoint);
-    bool status = send_rpc_cmd(sock, RPC_CMD_ALLOC_BUFFER, input, output);
+    bool status = send_rpc_cmd(sock, RPC_CMD_ALLOC_BUFFER, &request, sizeof(request), &response, sizeof(response));
     GGML_ASSERT(status);
-    GGML_ASSERT(output.size() == 2*sizeof(uint64_t));
-    // output serialization format: | remote_ptr (8 bytes) | remote_size (8 bytes) |
-    uint64_t remote_ptr;
-    memcpy(&remote_ptr, output.data(), sizeof(remote_ptr));
-    size_t remote_size;
-    memcpy(&remote_size, output.data() + sizeof(uint64_t), sizeof(remote_size));
-    if (remote_ptr != 0) {
+    if (response.remote_ptr != 0) {
         ggml_backend_buffer_t buffer = ggml_backend_buffer_init(buft,
             ggml_backend_rpc_buffer_interface,
-            new ggml_backend_rpc_buffer_context{sock, {}, remote_ptr, "RPC[" + std::string(buft_ctx->endpoint) + "]"},
-            remote_size);
+            new ggml_backend_rpc_buffer_context{sock, {}, response.remote_ptr},
+            response.remote_size);
         return buffer;
     } else {
         return nullptr;
@ -510,16 +553,10 @@ static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer(ggml_back
 }

 static size_t get_alignment(const std::shared_ptr<socket_t> & sock) {
-    // input serialization format: | 0 bytes |
-    std::vector<uint8_t> input;
-    std::vector<uint8_t> output;
-    bool status = send_rpc_cmd(sock, RPC_CMD_GET_ALIGNMENT, input, output);
+    rpc_msg_get_alignment_rsp response;
+    bool status = send_rpc_cmd(sock, RPC_CMD_GET_ALIGNMENT, nullptr, 0, &response, sizeof(response));
     GGML_ASSERT(status);
-    GGML_ASSERT(output.size() == sizeof(uint64_t));
-    // output serialization format: | alignment (8 bytes) |
-    uint64_t alignment;
-    memcpy(&alignment, output.data(), sizeof(alignment));
-    return alignment;
+    return response.alignment;
 }

 static size_t ggml_backend_rpc_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
@ -528,16 +565,10 @@ static size_t ggml_backend_rpc_buffer_type_get_alignment(ggml_backend_buffer_typ
 }

 static size_t get_max_size(const std::shared_ptr<socket_t> & sock) {
-    // input serialization format: | 0 bytes |
-    std::vector<uint8_t> input;
-    std::vector<uint8_t> output;
-    bool status = send_rpc_cmd(sock, RPC_CMD_GET_MAX_SIZE, input, output);
+    rpc_msg_get_max_size_rsp response;
+    bool status = send_rpc_cmd(sock, RPC_CMD_GET_MAX_SIZE, nullptr, 0, &response, sizeof(response));
     GGML_ASSERT(status);
-    GGML_ASSERT(output.size() == sizeof(uint64_t));
-    // output serialization format: | max_size (8 bytes) |
-    uint64_t max_size;
-    memcpy(&max_size, output.data(), sizeof(max_size));
-    return max_size;
+    return response.max_size;
 }

 static size_t ggml_backend_rpc_get_max_size(ggml_backend_buffer_type_t buft) {
@ -571,11 +602,6 @@ static void ggml_backend_rpc_free(ggml_backend_t backend) {
     delete backend;
 }

-static ggml_backend_buffer_type_t ggml_backend_rpc_get_default_buffer_type(ggml_backend_t backend) {
-    ggml_backend_rpc_context * ctx = (ggml_backend_rpc_context *)backend->context;
-    return ggml_backend_rpc_buffer_type(ctx->endpoint.c_str());
-}
-
 static void ggml_backend_rpc_synchronize(ggml_backend_t backend) {
     UNUSED(backend);
     // this is no-op because we don't have any async operations
@ -622,18 +648,16 @@ static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, g
     ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
     std::vector<uint8_t> input;
     serialize_graph(cgraph, input);
-    std::vector<uint8_t> output;
+    rpc_msg_graph_compute_rsp response;
     auto sock = get_socket(rpc_ctx->endpoint);
-    bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_COMPUTE, input, output);
+    bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_COMPUTE, input.data(), input.size(), &response, sizeof(response));
     GGML_ASSERT(status);
-    GGML_ASSERT(output.size() == 1);
-    return (enum ggml_status)output[0];
+    return (enum ggml_status)response.result;
 }

 static ggml_backend_i ggml_backend_rpc_interface = {
     /* .get_name                = */ ggml_backend_rpc_name,
     /* .free                    = */ ggml_backend_rpc_free,
-    /* .get_default_buffer_type = */ ggml_backend_rpc_get_default_buffer_type,
     /* .set_tensor_async        = */ NULL,
     /* .get_tensor_async        = */ NULL,
     /* .cpy_tensor_async        = */ NULL,
@ -643,9 +667,6 @@ static ggml_backend_i ggml_backend_rpc_interface = {
     /* .graph_plan_update       = */ NULL,
     /* .graph_plan_compute      = */ NULL,
     /* .graph_compute           = */ ggml_backend_rpc_graph_compute,
-    /* .supports_op             = */ NULL,
-    /* .supports_buft           = */ NULL,
-    /* .offload_op              = */ NULL,
     /* .event_record            = */ NULL,
     /* .event_wait              = */ NULL,
 };
@ -702,19 +723,11 @@ GGML_API bool ggml_backend_is_rpc(ggml_backend_t backend) {
 }

 static void get_device_memory(const std::shared_ptr<socket_t> & sock, size_t * free, size_t * total) {
-    // input serialization format: | 0 bytes |
-    std::vector<uint8_t> input;
-    std::vector<uint8_t> output;
-    bool status = send_rpc_cmd(sock, RPC_CMD_GET_DEVICE_MEMORY, input, output);
+    rpc_msg_get_device_memory_rsp response;
+    bool status = send_rpc_cmd(sock, RPC_CMD_GET_DEVICE_MEMORY, nullptr, 0, &response, sizeof(response));
     GGML_ASSERT(status);
-    GGML_ASSERT(output.size() == 2*sizeof(uint64_t));
-    // output serialization format: | free (8 bytes) | total (8 bytes) |
-    uint64_t free_mem;
-    memcpy(&free_mem, output.data(), sizeof(free_mem));
-    uint64_t total_mem;
-    memcpy(&total_mem, output.data() + sizeof(uint64_t), sizeof(total_mem));
-    *free = free_mem;
-    *total = total_mem;
+    *free = response.free_mem;
+    *total = response.total_mem;
 }

 GGML_API void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, size_t * total) {
@ -734,16 +747,16 @@ public:
     rpc_server(ggml_backend_t backend) : backend(backend) {}
     ~rpc_server();

-    bool alloc_buffer(const std::vector<uint8_t> & input, std::vector<uint8_t> & output);
-    void get_alignment(std::vector<uint8_t> & output);
-    void get_max_size(std::vector<uint8_t> & output);
-    bool buffer_get_base(const std::vector<uint8_t> & input, std::vector<uint8_t> & output);
-    bool free_buffer(const std::vector<uint8_t> & input);
-    bool buffer_clear(const std::vector<uint8_t> & input);
+    void alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response);
+    void get_alignment(rpc_msg_get_alignment_rsp & response);
+    void get_max_size(rpc_msg_get_max_size_rsp & response);
+    bool buffer_get_base(const rpc_msg_buffer_get_base_req & request, rpc_msg_buffer_get_base_rsp & response);
+    bool free_buffer(const rpc_msg_free_buffer_req & request);
+    bool buffer_clear(const rpc_msg_buffer_clear_req & request);
     bool set_tensor(const std::vector<uint8_t> & input);
-    bool get_tensor(const std::vector<uint8_t> & input, std::vector<uint8_t> & output);
-    bool copy_tensor(const std::vector<uint8_t> & input, std::vector<uint8_t> & output);
-    bool graph_compute(const std::vector<uint8_t> & input, std::vector<uint8_t> & output);
+    bool get_tensor(const rpc_msg_get_tensor_req & request, std::vector<uint8_t> & response);
+    bool copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_copy_tensor_rsp & response);
+    bool graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph_compute_rsp & response);

 private:
     ggml_tensor * deserialize_tensor(struct ggml_context * ctx, const rpc_tensor * tensor);
@ -757,80 +770,50 @@ private:
     std::unordered_set<ggml_backend_buffer_t> buffers;
 };

-bool rpc_server::alloc_buffer(const std::vector<uint8_t> & input, std::vector<uint8_t> & output) {
-    // input serialization format: | size (8 bytes) |
-    if (input.size() != sizeof(uint64_t)) {
-        return false;
-    }
-    uint64_t size;
-    memcpy(&size, input.data(), sizeof(size));
+void rpc_server::alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response) {
     ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend);
-    ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(buft, size);
-    uint64_t remote_ptr = 0;
-    uint64_t remote_size = 0;
+    ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(buft, request.size);
+    response.remote_ptr = 0;
+    response.remote_size = 0;
     if (buffer != nullptr) {
-        remote_ptr = reinterpret_cast<uint64_t>(buffer);
-        remote_size = buffer->size;
-        GGML_PRINT_DEBUG("[%s] size: %" PRIu64 " -> remote_ptr: %" PRIx64 ", remote_size: %" PRIu64 "\n", __func__, size, remote_ptr, remote_size);
+        response.remote_ptr = reinterpret_cast<uint64_t>(buffer);
+        response.remote_size = buffer->size;
+        GGML_PRINT_DEBUG("[%s] size: %" PRIu64 " -> remote_ptr: %" PRIx64 ", remote_size: %" PRIu64 "\n", __func__, request.size, response.remote_ptr, response.remote_size);
         buffers.insert(buffer);
     } else {
-        GGML_PRINT_DEBUG("[%s] size: %" PRIu64 " -> failed\n", __func__, size);
+        GGML_PRINT_DEBUG("[%s] size: %" PRIu64 " -> failed\n", __func__, request.size);
     }
-    // output serialization format: | remote_ptr (8 bytes) | remote_size (8 bytes) |
-    output.resize(2*sizeof(uint64_t), 0);
-    memcpy(output.data(), &remote_ptr, sizeof(remote_ptr));
-    memcpy(output.data() + sizeof(uint64_t), &remote_size, sizeof(remote_size));
-    return true;
 }

-void rpc_server::get_alignment(std::vector<uint8_t> & output) {
+void rpc_server::get_alignment(rpc_msg_get_alignment_rsp & response) {
     ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend);
     size_t alignment = ggml_backend_buft_get_alignment(buft);
     GGML_PRINT_DEBUG("[%s] alignment: %lu\n", __func__, alignment);
-    // output serialization format: | alignment (8 bytes) |
-    output.resize(sizeof(uint64_t), 0);
-    memcpy(output.data(), &alignment, sizeof(alignment));
+    response.alignment = alignment;
 }

-void rpc_server::get_max_size(std::vector<uint8_t> & output) {
+void rpc_server::get_max_size(rpc_msg_get_max_size_rsp & response) {
     ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend);
     size_t max_size = ggml_backend_buft_get_max_size(buft);
     GGML_PRINT_DEBUG("[%s] max_size: %lu\n", __func__, max_size);
-    // output serialization format: | max_size (8 bytes) |
-    output.resize(sizeof(uint64_t), 0);
-    memcpy(output.data(), &max_size, sizeof(max_size));
+    response.max_size = max_size;
 }

-bool rpc_server::buffer_get_base(const std::vector<uint8_t> & input, std::vector<uint8_t> & output) {
-    // input serialization format: | remote_ptr (8 bytes) |
-    if (input.size() != sizeof(uint64_t)) {
-        return false;
-    }
-    uint64_t remote_ptr;
-    memcpy(&remote_ptr, input.data(), sizeof(remote_ptr));
-    GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 "\n", __func__, remote_ptr);
-    ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(remote_ptr);
+bool rpc_server::buffer_get_base(const rpc_msg_buffer_get_base_req & request, rpc_msg_buffer_get_base_rsp & response) {
+    GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 "\n", __func__, request.remote_ptr);
+    ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(request.remote_ptr);
     if (buffers.find(buffer) == buffers.end()) {
         GGML_PRINT_DEBUG("[%s] buffer not found\n", __func__);
         return false;
     }
     void * base = ggml_backend_buffer_get_base(buffer);
-    // output serialization format: | base_ptr (8 bytes) |
-    uint64_t base_ptr = reinterpret_cast<uint64_t>(base);
-    output.resize(sizeof(uint64_t), 0);
-    memcpy(output.data(), &base_ptr, sizeof(base_ptr));
+    response.base_ptr = reinterpret_cast<uint64_t>(base);
     return true;
 }

-bool rpc_server::free_buffer(const std::vector<uint8_t> & input) {
-    // input serialization format: | remote_ptr (8 bytes) |
-    if (input.size() != sizeof(uint64_t)) {
-        return false;
-    }
-    uint64_t remote_ptr;
-    memcpy(&remote_ptr, input.data(), sizeof(remote_ptr));
-    GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 "\n", __func__, remote_ptr);
-    ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(remote_ptr);
+bool rpc_server::free_buffer(const rpc_msg_free_buffer_req & request) {
+    GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 "\n", __func__, request.remote_ptr);
+    ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(request.remote_ptr);
     if (buffers.find(buffer) == buffers.end()) {
         GGML_PRINT_DEBUG("[%s] buffer not found\n", __func__);
         return false;
@ -840,22 +823,14 @@ bool rpc_server::free_buffer(const std::vector<uint8_t> & input) {
     return true;
 }

-bool rpc_server::buffer_clear(const std::vector<uint8_t> & input) {
-    // input serialization format: | remote_ptr (8 bytes) | value (1 byte) |
-    if (input.size() != sizeof(uint64_t) + sizeof(uint8_t)) {
-        return false;
-    }
-    uint64_t remote_ptr;
-    memcpy(&remote_ptr, input.data(), sizeof(remote_ptr));
-    uint8_t value;
-    memcpy(&value, input.data() + sizeof(uint64_t), sizeof(value));
-    GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 ", value: %u\n", __func__, remote_ptr, value);
-    ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(remote_ptr);
+bool rpc_server::buffer_clear(const rpc_msg_buffer_clear_req & request) {
+    GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 ", value: %u\n", __func__, request.remote_ptr, request.value);
+    ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(request.remote_ptr);
     if (buffers.find(buffer) == buffers.end()) {
         GGML_PRINT_DEBUG("[%s] buffer not found\n", __func__);
         return false;
     }
-    ggml_backend_buffer_clear(buffer, value);
+    ggml_backend_buffer_clear(buffer, request.value);
     return true;
 }

@ -930,74 +905,55 @@ bool rpc_server::set_tensor(const std::vector<uint8_t> & input) {
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
-bool rpc_server::get_tensor(const std::vector<uint8_t> & input, std::vector<uint8_t> & output) {
-    // serialization format: | rpc_tensor | offset (8 bytes) | size (8 bytes) |
-    if (input.size() != sizeof(rpc_tensor) + 2*sizeof(uint64_t)) {
-        return false;
-    }
-    const rpc_tensor * in_tensor = (const rpc_tensor *)input.data();
-    uint64_t offset;
-    memcpy(&offset, input.data() + sizeof(rpc_tensor), sizeof(offset));
-    uint64_t size;
-    memcpy(&size, input.data() + sizeof(rpc_tensor) + sizeof(offset), sizeof(size));
-
+bool rpc_server::get_tensor(const rpc_msg_get_tensor_req & request, std::vector<uint8_t> & response) {
     struct ggml_init_params params {
         /*.mem_size =*/ ggml_tensor_overhead(),
         /*.mem_buffer =*/ NULL,
         /*.no_alloc =*/ true,
     };
     struct ggml_context * ctx = ggml_init(params);
-    ggml_tensor * tensor = deserialize_tensor(ctx, in_tensor);
+    ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor);
     if (tensor == nullptr) {
         GGML_PRINT_DEBUG("[%s] error deserializing tensor\n", __func__);
         ggml_free(ctx);
         return false;
     }
-    GGML_PRINT_DEBUG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %" PRIu64 "\n", __func__, (void*)tensor->buffer, tensor->data, offset, size);
+    GGML_PRINT_DEBUG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %" PRIu64 "\n", __func__, (void*)tensor->buffer, tensor->data, request.offset, request.size);

     // sanitize tensor->data
     {
         const size_t p0 = (size_t) ggml_backend_buffer_get_base(tensor->buffer);
         const size_t p1 = p0 + ggml_backend_buffer_get_size(tensor->buffer);

-        if (in_tensor->data + offset < p0 || in_tensor->data + offset >= p1 || size > (p1 - in_tensor->data - offset)) {
+        if (request.tensor.data + request.offset < p0 ||
+            request.tensor.data + request.offset >= p1 ||
+            request.size > (p1 - request.tensor.data - request.offset)) {
             GGML_ABORT("[%s] tensor->data out of bounds\n", __func__);
         }
     }

-    // output serialization format: | data (size bytes) |
-    output.resize(size, 0);
-    ggml_backend_tensor_get(tensor, output.data(), offset, size);
+    response.resize(request.size, 0);
+    ggml_backend_tensor_get(tensor, response.data(), request.offset, request.size);
     ggml_free(ctx);
     return true;
 }

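The sanitize block above rejects any read that falls outside the destination buffer. A hypothetical standalone helper (not part of the diff) expressing the same predicate:

#include <cstddef>
#include <cstdint>

// A read of `size` bytes starting at `data + offset` must lie entirely
// inside the buffer span [p0, p1).
static bool read_in_bounds(uint64_t data, uint64_t offset, uint64_t size,
                           size_t p0, size_t p1) {
    return data + offset >= p0 &&
           data + offset <  p1 &&
           size <= (p1 - data - offset);
}
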
-bool rpc_server::copy_tensor(const std::vector<uint8_t> & input, std::vector<uint8_t> & output) {
-    // serialization format: | rpc_tensor src | rpc_tensor dst |
-    if (input.size() != 2*sizeof(rpc_tensor)) {
-        return false;
-    }
-    const rpc_tensor * rpc_src = (const rpc_tensor *)input.data();
-    const rpc_tensor * rpc_dst = (const rpc_tensor *)(input.data() + sizeof(rpc_src));
-
+bool rpc_server::copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_copy_tensor_rsp & response) {
     struct ggml_init_params params {
         /*.mem_size =*/ 2*ggml_tensor_overhead(),
         /*.mem_buffer =*/ NULL,
         /*.no_alloc =*/ true,
     };
     struct ggml_context * ctx = ggml_init(params);
-    ggml_tensor * src = deserialize_tensor(ctx, rpc_src);
-    ggml_tensor * dst = deserialize_tensor(ctx, rpc_dst);
+    ggml_tensor * src = deserialize_tensor(ctx, &request.src);
+    ggml_tensor * dst = deserialize_tensor(ctx, &request.dst);
     if (src == nullptr || dst == nullptr) {
         GGML_PRINT_DEBUG("[%s] error deserializing tensors\n", __func__);
         ggml_free(ctx);
         return false;
     }
     GGML_PRINT_DEBUG("[%s] src->buffer: %p, dst->buffer: %p\n", __func__, (void*)src->buffer, (void*)dst->buffer);
-    bool result = ggml_backend_buffer_copy_tensor(src, dst);
-    // output serialization format: | result (1 byte) |
-    output.resize(1, 0);
-    output[0] = result;
+    response.result = ggml_backend_buffer_copy_tensor(src, dst);
     ggml_free(ctx);
     return true;
 }
@@ -1026,7 +982,7 @@ ggml_tensor * rpc_server::create_node(uint64_t id,
     return result;
 }

-bool rpc_server::graph_compute(const std::vector<uint8_t> & input, std::vector<uint8_t> & output) {
+bool rpc_server::graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph_compute_rsp & response) {
     // serialization format:
     // | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) |
     if (input.size() < sizeof(uint32_t)) {
@@ -1066,9 +1022,7 @@ bool rpc_server::graph_compute(const std::vector<uint8_t> & input, std::vector<u
         graph->nodes[i] = create_node(id, ctx, tensor_ptrs, tensor_map);
     }
     ggml_status status = ggml_backend_graph_compute(backend, graph);
-    // output serialization format: | status (1 byte) |
-    output.resize(1, 0);
-    output[0] = status;
+    response.result = status;
     ggml_free(ctx);
     return true;
 }
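graph_compute keeps a variable-length input blob; only the reply becomes a fixed-size message. An illustrative sketch (not from the diff) of walking the header of that blob, assuming the layout named in the comment above:

#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

// | n_nodes (4 bytes) | nodes (n_nodes * 8 bytes) | n_tensors (4 bytes) | tensors |
static bool parse_graph_header(const std::vector<uint8_t> & input,
                               uint32_t & n_nodes, uint32_t & n_tensors) {
    if (input.size() < sizeof(uint32_t)) {
        return false;
    }
    std::memcpy(&n_nodes, input.data(), sizeof(n_nodes));
    const size_t tensors_offset = sizeof(uint32_t) + (size_t)n_nodes * sizeof(uint64_t);
    if (input.size() < tensors_offset + sizeof(uint32_t)) {
        return false;
    }
    std::memcpy(&n_tensors, input.data() + tensors_offset, sizeof(n_tensors));
    return true;
}
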
@@ -1091,85 +1045,153 @@ static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t fre
             fprintf(stderr, "Unknown command: %d\n", cmd);
             break;
         }
-        std::vector<uint8_t> input;
-        std::vector<uint8_t> output;
-        uint64_t input_size;
-        if (!recv_data(sockfd, &input_size, sizeof(input_size))) {
-            break;
-        }
-        try {
-            input.resize(input_size);
-        } catch (const std::bad_alloc & e) {
-            fprintf(stderr, "Failed to allocate input buffer of size %" PRIu64 "\n", input_size);
-            break;
-        }
-        if (!recv_data(sockfd, input.data(), input_size)) {
-            break;
-        }
-        bool ok = true;
         switch (cmd) {
             case RPC_CMD_ALLOC_BUFFER: {
-                ok = server.alloc_buffer(input, output);
+                rpc_msg_alloc_buffer_req request;
+                if (!recv_msg(sockfd, &request, sizeof(request))) {
+                    return;
+                }
+                rpc_msg_alloc_buffer_rsp response;
+                server.alloc_buffer(request, response);
+                if (!send_msg(sockfd, &response, sizeof(response))) {
+                    return;
+                }
                 break;
             }
             case RPC_CMD_GET_ALIGNMENT: {
-                server.get_alignment(output);
+                if (!recv_msg(sockfd, nullptr, 0)) {
+                    return;
+                }
+                rpc_msg_get_alignment_rsp response;
+                server.get_alignment(response);
+                if (!send_msg(sockfd, &response, sizeof(response))) {
+                    return;
+                }
                 break;
             }
             case RPC_CMD_GET_MAX_SIZE: {
-                server.get_max_size(output);
+                if (!recv_msg(sockfd, nullptr, 0)) {
+                    return;
+                }
+                rpc_msg_get_max_size_rsp response;
+                server.get_max_size(response);
+                if (!send_msg(sockfd, &response, sizeof(response))) {
+                    return;
+                }
                 break;
             }
             case RPC_CMD_BUFFER_GET_BASE: {
-                ok = server.buffer_get_base(input, output);
+                rpc_msg_buffer_get_base_req request;
+                if (!recv_msg(sockfd, &request, sizeof(request))) {
+                    return;
+                }
+                rpc_msg_buffer_get_base_rsp response;
+                if (!server.buffer_get_base(request, response)) {
+                    return;
+                }
+                if (!send_msg(sockfd, &response, sizeof(response))) {
+                    return;
+                }
                 break;
             }
             case RPC_CMD_FREE_BUFFER: {
-                ok = server.free_buffer(input);
+                rpc_msg_free_buffer_req request;
+                if (!recv_msg(sockfd, &request, sizeof(request))) {
+                    return;
+                }
+                if (!server.free_buffer(request)) {
+                    return;
+                }
+                if (!send_msg(sockfd, nullptr, 0)) {
+                    return;
+                }
                 break;
             }
             case RPC_CMD_BUFFER_CLEAR: {
-                ok = server.buffer_clear(input);
+                rpc_msg_buffer_clear_req request;
+                if (!recv_msg(sockfd, &request, sizeof(request))) {
+                    return;
+                }
+                if (!server.buffer_clear(request)) {
+                    return;
+                }
+                if (!send_msg(sockfd, nullptr, 0)) {
+                    return;
+                }
                 break;
             }
             case RPC_CMD_SET_TENSOR: {
-                ok = server.set_tensor(input);
+                std::vector<uint8_t> input;
+                if (!recv_msg(sockfd, input)) {
+                    return;
+                }
+                if (!server.set_tensor(input)) {
+                    return;
+                }
+                if (!send_msg(sockfd, nullptr, 0)) {
+                    return;
+                }
                 break;
             }
             case RPC_CMD_GET_TENSOR: {
-                ok = server.get_tensor(input, output);
+                rpc_msg_get_tensor_req request;
+                if (!recv_msg(sockfd, &request, sizeof(request))) {
+                    return;
+                }
+                std::vector<uint8_t> response;
+                if (!server.get_tensor(request, response)) {
+                    return;
+                }
+                if (!send_msg(sockfd, response.data(), response.size())) {
+                    return;
+                }
                 break;
             }
             case RPC_CMD_COPY_TENSOR: {
-                ok = server.copy_tensor(input, output);
+                rpc_msg_copy_tensor_req request;
+                if (!recv_msg(sockfd, &request, sizeof(request))) {
+                    return;
+                }
+                rpc_msg_copy_tensor_rsp response;
+                if (!server.copy_tensor(request, response)) {
+                    return;
+                }
+                if (!send_msg(sockfd, &response, sizeof(response))) {
+                    return;
+                }
                 break;
             }
             case RPC_CMD_GRAPH_COMPUTE: {
-                ok = server.graph_compute(input, output);
+                std::vector<uint8_t> input;
+                if (!recv_msg(sockfd, input)) {
+                    return;
+                }
+                rpc_msg_graph_compute_rsp response;
+                if (!server.graph_compute(input, response)) {
+                    return;
+                }
+                if (!send_msg(sockfd, &response, sizeof(response))) {
+                    return;
+                }
                 break;
             }
             case RPC_CMD_GET_DEVICE_MEMORY: {
-                // output serialization format: | free (8 bytes) | total (8 bytes) |
-                output.resize(2*sizeof(uint64_t), 0);
-                memcpy(output.data(), &free_mem, sizeof(free_mem));
-                memcpy(output.data() + sizeof(uint64_t), &total_mem, sizeof(total_mem));
+                if (!recv_msg(sockfd, nullptr, 0)) {
+                    return;
+                }
+                rpc_msg_get_device_memory_rsp response;
+                response.free_mem = free_mem;
+                response.total_mem = total_mem;
+                if (!send_msg(sockfd, &response, sizeof(response))) {
+                    return;
+                }
                 break;
             }
             default: {
                 fprintf(stderr, "Unknown command: %d\n", cmd);
-                ok = false;
+                return;
             }
         }
-        if (!ok) {
-            break;
-        }
-        uint64_t output_size = output.size();
-        if (!send_data(sockfd, &output_size, sizeof(output_size))) {
-            break;
-        }
-        if (!send_data(sockfd, output.data(), output_size)) {
-            break;
-        }
     }
 }

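The command loop above now exchanges one framed message per request via recv_msg/send_msg instead of a single length-prefixed blob per command. A plausible sketch of those helpers, assuming they wrap the pre-existing recv_data/send_data primitives with a length prefix (the real implementations are elsewhere in ggml-rpc.cpp and may differ; a std::vector overload of recv_msg is also used for the variable-length commands):

// Minimal sketch, assuming recv_data/send_data and sockfd_t from ggml-rpc.cpp.
static bool recv_msg(sockfd_t sockfd, void * msg, size_t msg_size) {
    uint64_t size;
    if (!recv_data(sockfd, &size, sizeof(size))) {
        return false;
    }
    if (size != msg_size) {
        // the peer must send exactly the fixed-size struct expected here
        return false;
    }
    return recv_data(sockfd, msg, msg_size);
}

static bool send_msg(sockfd_t sockfd, const void * msg, size_t msg_size) {
    uint64_t size = msg_size;
    if (!send_data(sockfd, &size, sizeof(size))) {
        return false;
    }
    return send_data(sockfd, msg, msg_size);
}
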
@@ -1240,7 +1262,7 @@ static void ggml_backend_rpc_device_get_memory(ggml_backend_dev_t dev, size_t *

 static enum ggml_backend_dev_type ggml_backend_rpc_device_get_type(ggml_backend_dev_t dev) {
     // TODO: obtain value from the server
-    return GGML_BACKEND_DEVICE_TYPE_GPU_FULL;
+    return GGML_BACKEND_DEVICE_TYPE_GPU;

     UNUSED(dev);
 }

File diff suppressed because it is too large
@@ -1,6 +1,6 @@
 #include "mmvq.hpp"
 #include "vecdotq.hpp"
+#include <cassert>

 template <int qk, int qi, typename block_q_t, int vdr, vec_dot_q_sycl_t vec_dot_q_sycl>
 static void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols, const int nrows,
@@ -13,7 +13,8 @@ static void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict_
     }

     const int blocks_per_row = ncols / qk;
-    const int blocks_per_warp = vdr * WARP_SIZE / qi;
+    const int blocks_per_warp = vdr * QK_WARP_SIZE / qi;
+    assert(blocks_per_warp>0);

     // partial sum for each thread
     float tmp = 0.0f;
@@ -37,7 +38,7 @@ static void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict_

     // sum up partial sums and write back result
 #pragma unroll
-    for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
+    for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
         tmp +=
             dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
     }
@@ -61,7 +62,8 @@ static void mul_mat_vec_q_iq2_xxs_q8_1(const void *__restrict__ vx,
     }

     const int blocks_per_row = ncols / qk;
-    const int blocks_per_warp = vdr * WARP_SIZE / qi;
+    const int blocks_per_warp = vdr * QK_WARP_SIZE / qi;
+    assert(blocks_per_warp>0);

     // partial sum for each thread
     float tmp = 0.0f;
@@ -85,7 +87,7 @@ static void mul_mat_vec_q_iq2_xxs_q8_1(const void *__restrict__ vx,

     // sum up partial sums and write back result
 #pragma unroll
-    for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
+    for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
         tmp +=
             dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
     }
@@ -109,8 +111,8 @@ static void mul_mat_vec_q_iq2_xs_q8_1(const void *__restrict__ vx,
     }

     const int blocks_per_row = ncols / qk;
-    const int blocks_per_warp = vdr * WARP_SIZE / qi;
+    const int blocks_per_warp = vdr * QK_WARP_SIZE / qi;
+    assert(blocks_per_warp>0);
     // partial sum for each thread
     float tmp = 0.0f;

@@ -133,7 +135,7 @@ static void mul_mat_vec_q_iq2_xs_q8_1(const void *__restrict__ vx,

     // sum up partial sums and write back result
 #pragma unroll
-    for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
+    for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
         tmp +=
             dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
     }
@@ -157,8 +159,8 @@ static void mul_mat_vec_q_iq2_s_q8_1(const void *__restrict__ vx,
     }

     const int blocks_per_row = ncols / qk;
-    const int blocks_per_warp = vdr * WARP_SIZE / qi;
+    const int blocks_per_warp = vdr * QK_WARP_SIZE / qi;
+    assert(blocks_per_warp>0);
     // partial sum for each thread
     float tmp = 0.0f;

@@ -181,7 +183,7 @@ static void mul_mat_vec_q_iq2_s_q8_1(const void *__restrict__ vx,

     // sum up partial sums and write back result
 #pragma unroll
-    for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
+    for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
         tmp +=
             dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
     }
@@ -205,8 +207,8 @@ static void mul_mat_vec_q_iq3_xxs_q8_1(const void *__restrict__ vx,
     }

     const int blocks_per_row = ncols / qk;
-    const int blocks_per_warp = vdr * WARP_SIZE / qi;
+    const int blocks_per_warp = vdr * QK_WARP_SIZE / qi;
+    assert(blocks_per_warp>0);
     // partial sum for each thread
     float tmp = 0.0f;

@@ -229,7 +231,7 @@ static void mul_mat_vec_q_iq3_xxs_q8_1(const void *__restrict__ vx,

     // sum up partial sums and write back result
 #pragma unroll
-    for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
+    for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
         tmp +=
             dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
     }
@@ -253,8 +255,8 @@ static void mul_mat_vec_q_iq3_s_q8_1(const void *__restrict__ vx,
     }

     const int blocks_per_row = ncols / qk;
-    const int blocks_per_warp = vdr * WARP_SIZE / qi;
+    const int blocks_per_warp = vdr * QK_WARP_SIZE / qi;
+    assert(blocks_per_warp>0);
     // partial sum for each thread
     float tmp = 0.0f;

@@ -277,7 +279,7 @@ static void mul_mat_vec_q_iq3_s_q8_1(const void *__restrict__ vx,

     // sum up partial sums and write back result
 #pragma unroll
-    for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
+    for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
         tmp +=
             dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
     }
@@ -301,8 +303,8 @@ static void mul_mat_vec_q_iq1_s_q8_1(const void *__restrict__ vx,
     }

     const int blocks_per_row = ncols / qk;
-    const int blocks_per_warp = vdr * WARP_SIZE / qi;
+    const int blocks_per_warp = vdr * QK_WARP_SIZE / qi;
+    assert(blocks_per_warp>0);
     // partial sum for each thread
     float tmp = 0.0f;

@@ -325,7 +327,7 @@ static void mul_mat_vec_q_iq1_s_q8_1(const void *__restrict__ vx,

     // sum up partial sums and write back result
 #pragma unroll
-    for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
+    for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
         tmp +=
             dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
     }
@@ -349,8 +351,8 @@ static void mul_mat_vec_q_iq1_m_q8_1(const void *__restrict__ vx,
     }

     const int blocks_per_row = ncols / qk;
-    const int blocks_per_warp = vdr * WARP_SIZE / qi;
+    const int blocks_per_warp = vdr * QK_WARP_SIZE / qi;
+    assert(blocks_per_warp>0);
     // partial sum for each thread
     float tmp = 0.0f;

@@ -373,7 +375,7 @@ static void mul_mat_vec_q_iq1_m_q8_1(const void *__restrict__ vx,

     // sum up partial sums and write back result
 #pragma unroll
-    for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
+    for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
         tmp +=
             dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
     }
@@ -397,8 +399,8 @@ static void mul_mat_vec_q_iq4_nl_q8_1(const void *__restrict__ vx,
     }

     const int blocks_per_row = ncols / qk;
-    const int blocks_per_warp = vdr * WARP_SIZE / qi;
+    const int blocks_per_warp = vdr * QK_WARP_SIZE / qi;
+    assert(blocks_per_warp>0);
     // partial sum for each thread
     float tmp = 0.0f;

@@ -421,7 +423,7 @@ static void mul_mat_vec_q_iq4_nl_q8_1(const void *__restrict__ vx,

     // sum up partial sums and write back result
 #pragma unroll
-    for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
+    for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
         tmp +=
             dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
     }
@@ -446,8 +448,8 @@ static void mul_mat_vec_q_iq4_xs_q8_1(const void *__restrict__ vx,
     }

     const int blocks_per_row = ncols / qk;
-    const int blocks_per_warp = vdr * WARP_SIZE / qi;
+    const int blocks_per_warp = vdr * QK_WARP_SIZE / qi;
+    assert(blocks_per_warp>0);
     // partial sum for each thread
     float tmp = 0.0f;

@@ -470,7 +472,7 @@ static void mul_mat_vec_q_iq4_xs_q8_1(const void *__restrict__ vx,

     // sum up partial sums and write back result
 #pragma unroll
-    for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
+    for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
         tmp +=
             dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
     }
@@ -487,7 +489,7 @@ static void mul_mat_vec_q4_0_q8_1_sycl(const void *vx, const void *vy,
     GGML_ASSERT(ncols % QK4_0 == 0);
     const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
     const sycl::range<3> block_nums(1, 1, block_num_y);
-    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
     {

         stream->submit([&](sycl::handler &cgh) {
@@ -495,7 +497,7 @@ static void mul_mat_vec_q4_0_q8_1_sycl(const void *vx, const void *vy,
             cgh.parallel_for(
                 sycl::nd_range<3>(block_nums * block_dims, block_dims),
                 [=](sycl::nd_item<3> item_ct1)
-                    [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+                    [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
                         mul_mat_vec_q<QK4_0, QI4_0, block_q4_0,
                                       VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>(
                             vx, vy, dst, ncols, nrows, item_ct1);
@@ -511,7 +513,7 @@ static void mul_mat_vec_q4_1_q8_1_sycl(const void *vx, const void *vy,
     GGML_ASSERT(ncols % QK4_1 == 0);
     const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
     const sycl::range<3> block_nums(1, 1, block_num_y);
-    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
     {

         stream->submit([&](sycl::handler &cgh) {
@@ -519,7 +521,7 @@ static void mul_mat_vec_q4_1_q8_1_sycl(const void *vx, const void *vy,
             cgh.parallel_for(
                 sycl::nd_range<3>(block_nums * block_dims, block_dims),
                 [=](sycl::nd_item<3> item_ct1)
-                    [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+                    [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
                         mul_mat_vec_q<QK4_0, QI4_1, block_q4_1,
                                       VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>(
                             vx, vy, dst, ncols, nrows, item_ct1);
@@ -535,7 +537,7 @@ static void mul_mat_vec_q5_0_q8_1_sycl(const void *vx, const void *vy,
     GGML_ASSERT(ncols % QK5_0 == 0);
     const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
     const sycl::range<3> block_nums(1, 1, block_num_y);
-    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
     {

         stream->submit([&](sycl::handler &cgh) {
@@ -543,7 +545,7 @@ static void mul_mat_vec_q5_0_q8_1_sycl(const void *vx, const void *vy,
             cgh.parallel_for(
                 sycl::nd_range<3>(block_nums * block_dims, block_dims),
                 [=](sycl::nd_item<3> item_ct1)
-                    [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+                    [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
                         mul_mat_vec_q<QK5_0, QI5_0, block_q5_0,
                                       VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>(
                             vx, vy, dst, ncols, nrows, item_ct1);
@@ -559,7 +561,7 @@ static void mul_mat_vec_q5_1_q8_1_sycl(const void *vx, const void *vy,
     GGML_ASSERT(ncols % QK5_1 == 0);
     const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
     const sycl::range<3> block_nums(1, 1, block_num_y);
-    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
     {

         stream->submit([&](sycl::handler &cgh) {
@@ -567,7 +569,7 @@ static void mul_mat_vec_q5_1_q8_1_sycl(const void *vx, const void *vy,
             cgh.parallel_for(
                 sycl::nd_range<3>(block_nums * block_dims, block_dims),
                 [=](sycl::nd_item<3> item_ct1)
-                    [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+                    [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
                         mul_mat_vec_q<QK5_1, QI5_1, block_q5_1,
                                       VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>(
                             vx, vy, dst, ncols, nrows, item_ct1);
@@ -583,7 +585,7 @@ static void mul_mat_vec_q8_0_q8_1_sycl(const void *vx, const void *vy,
     GGML_ASSERT(ncols % QK8_0 == 0);
     const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
     const sycl::range<3> block_nums(1, 1, block_num_y);
-    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
     {

         stream->submit([&](sycl::handler &cgh) {
@@ -591,7 +593,7 @@ static void mul_mat_vec_q8_0_q8_1_sycl(const void *vx, const void *vy,
             cgh.parallel_for(
                 sycl::nd_range<3>(block_nums * block_dims, block_dims),
                 [=](sycl::nd_item<3> item_ct1)
-                    [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+                    [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
                         mul_mat_vec_q<QK8_0, QI8_0, block_q8_0,
                                       VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>(
                             vx, vy, dst, ncols, nrows, item_ct1);
@@ -607,7 +609,7 @@ static void mul_mat_vec_q2_K_q8_1_sycl(const void *vx, const void *vy,
     GGML_ASSERT(ncols % QK_K == 0);
     const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
     const sycl::range<3> block_nums(1, 1, block_num_y);
-    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
     {

         stream->submit([&](sycl::handler &cgh) {
@@ -615,7 +617,7 @@ static void mul_mat_vec_q2_K_q8_1_sycl(const void *vx, const void *vy,
             cgh.parallel_for(
                 sycl::nd_range<3>(block_nums * block_dims, block_dims),
                 [=](sycl::nd_item<3> item_ct1)
-                    [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+                    [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
                         mul_mat_vec_q<QK_K, QI2_K, block_q2_K,
                                       VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>(
                             vx, vy, dst, ncols, nrows, item_ct1);
@@ -631,7 +633,7 @@ static void mul_mat_vec_q3_K_q8_1_sycl(const void *vx, const void *vy,
     GGML_ASSERT(ncols % QK_K == 0);
     const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
     const sycl::range<3> block_nums(1, 1, block_num_y);
-    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
     {

         stream->submit([&](sycl::handler &cgh) {
@@ -639,7 +641,7 @@ static void mul_mat_vec_q3_K_q8_1_sycl(const void *vx, const void *vy,
             cgh.parallel_for(
                 sycl::nd_range<3>(block_nums * block_dims, block_dims),
                 [=](sycl::nd_item<3> item_ct1)
-                    [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+                    [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
                         mul_mat_vec_q<QK_K, QI3_K, block_q3_K,
                                       VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>(
                             vx, vy, dst, ncols, nrows, item_ct1);
@@ -655,7 +657,7 @@ static void mul_mat_vec_q4_K_q8_1_sycl(const void *vx, const void *vy,
     GGML_ASSERT(ncols % QK_K == 0);
     const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
     const sycl::range<3> block_nums(1, 1, block_num_y);
-    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
     {

         stream->submit([&](sycl::handler &cgh) {
@@ -663,7 +665,7 @@ static void mul_mat_vec_q4_K_q8_1_sycl(const void *vx, const void *vy,
             cgh.parallel_for(
                 sycl::nd_range<3>(block_nums * block_dims, block_dims),
                 [=](sycl::nd_item<3> item_ct1)
-                    [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+                    [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
                         mul_mat_vec_q<QK_K, QI4_K, block_q4_K,
                                       VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>(
                             vx, vy, dst, ncols, nrows, item_ct1);
@@ -679,7 +681,7 @@ static void mul_mat_vec_q5_K_q8_1_sycl(const void *vx, const void *vy,
     GGML_ASSERT(ncols % QK_K == 0);
     const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
     const sycl::range<3> block_nums(1, 1, block_num_y);
-    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
     {

         stream->submit([&](sycl::handler &cgh) {
@@ -687,7 +689,7 @@ static void mul_mat_vec_q5_K_q8_1_sycl(const void *vx, const void *vy,
             cgh.parallel_for(
                 sycl::nd_range<3>(block_nums * block_dims, block_dims),
                 [=](sycl::nd_item<3> item_ct1)
-                    [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+                    [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
                         mul_mat_vec_q<QK_K, QI5_K, block_q5_K,
                                       VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>(
                             vx, vy, dst, ncols, nrows, item_ct1);
@@ -703,7 +705,7 @@ static void mul_mat_vec_q6_K_q8_1_sycl(const void *vx, const void *vy,
     GGML_ASSERT(ncols % QK_K == 0);
     const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
     const sycl::range<3> block_nums(1, 1, block_num_y);
-    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
     {

         stream->submit([&](sycl::handler &cgh) {
@@ -711,7 +713,7 @@ static void mul_mat_vec_q6_K_q8_1_sycl(const void *vx, const void *vy,
             cgh.parallel_for(
                 sycl::nd_range<3>(block_nums * block_dims, block_dims),
                 [=](sycl::nd_item<3> item_ct1)
-                    [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+                    [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
                         mul_mat_vec_q<QK_K, QI6_K, block_q6_K,
                                       VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>(
                             vx, vy, dst, ncols, nrows, item_ct1);
@@ -728,13 +730,13 @@ static void mul_mat_vec_iq2_xxs_q8_1_sycl(const void *vx, const void *vy,
     GGML_ASSERT(ncols % QK_K == 0);
     const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
     const sycl::range<3> block_nums(1, 1, block_num_y);
-    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
     {
         stream->submit([&](sycl::handler &cgh) {
             cgh.parallel_for(
                 sycl::nd_range<3>(block_nums * block_dims, block_dims),
                 [=](sycl::nd_item<3> item_ct1)
-                    [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+                    [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
                         mul_mat_vec_q_iq2_xxs_q8_1<QK_K, QI2_XXS/2, block_iq2_xxs, 1>(
                             vx, vy, dst, ncols, nrows, item_ct1);
                     });
@@ -749,7 +751,7 @@ static void mul_mat_vec_iq2_xs_q8_1_sycl(const void *vx, const void *vy,
     GGML_ASSERT(ncols % QK_K == 0);
     const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
     const sycl::range<3> block_nums(1, 1, block_num_y);
-    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
     {

         stream->submit([&](sycl::handler &cgh) {
@@ -759,7 +761,7 @@ static void mul_mat_vec_iq2_xs_q8_1_sycl(const void *vx, const void *vy,
             cgh.parallel_for(
                 sycl::nd_range<3>(block_nums * block_dims, block_dims),
                 [=](sycl::nd_item<3> item_ct1)
-                    [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+                    [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
                         mul_mat_vec_q_iq2_xs_q8_1<QK_K, QI2_XS/2, block_iq2_xs, 1>(
                             vx, vy, dst, ncols, nrows, item_ct1);
                     });
@@ -774,7 +776,7 @@ static void mul_mat_vec_iq2_s_q8_1_sycl(const void *vx, const void *vy,
     GGML_ASSERT(ncols % QK_K == 0);
     const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
     const sycl::range<3> block_nums(1, 1, block_num_y);
-    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
     {

         stream->submit([&](sycl::handler &cgh) {
@@ -784,7 +786,7 @@ static void mul_mat_vec_iq2_s_q8_1_sycl(const void *vx, const void *vy,
             cgh.parallel_for(
                 sycl::nd_range<3>(block_nums * block_dims, block_dims),
                 [=](sycl::nd_item<3> item_ct1)
-                    [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+                    [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
                         mul_mat_vec_q_iq2_s_q8_1<QK_K, QI2_S/2, block_iq2_s, 1>(
                             vx, vy, dst, ncols, nrows, item_ct1);
                     });
@@ -799,7 +801,7 @@ static void mul_mat_vec_iq3_xxs_q8_1_sycl(const void *vx, const void *vy,
     GGML_ASSERT(ncols % QK_K == 0);
     const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
     const sycl::range<3> block_nums(1, 1, block_num_y);
-    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
     {

         stream->submit([&](sycl::handler &cgh) {
@@ -809,7 +811,7 @@ static void mul_mat_vec_iq3_xxs_q8_1_sycl(const void *vx, const void *vy,
             cgh.parallel_for(
                 sycl::nd_range<3>(block_nums * block_dims, block_dims),
                 [=](sycl::nd_item<3> item_ct1)
-                    [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+                    [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
                         mul_mat_vec_q_iq3_xxs_q8_1<QK_K, QI3_XXS/2, block_iq3_xxs, 1>(
                             vx, vy, dst, ncols, nrows, item_ct1);
                     });
@@ -824,7 +826,7 @@ static void mul_mat_vec_iq3_s_q8_1_sycl(const void *vx, const void *vy,
     GGML_ASSERT(ncols % QK_K == 0);
     const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
     const sycl::range<3> block_nums(1, 1, block_num_y);
-    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
     {

         stream->submit([&](sycl::handler &cgh) {
@@ -833,7 +835,7 @@ static void mul_mat_vec_iq3_s_q8_1_sycl(const void *vx, const void *vy,
             cgh.parallel_for(
                 sycl::nd_range<3>(block_nums * block_dims, block_dims),
                 [=](sycl::nd_item<3> item_ct1)
-                    [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+                    [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
                         mul_mat_vec_q_iq3_s_q8_1<QK_K, QI3_S/2, block_iq3_s, 1>(
                             vx, vy, dst, ncols, nrows, item_ct1);
                     });
@@ -848,7 +850,7 @@ static void mul_mat_vec_iq1_s_q8_1_sycl(const void *vx, const void *vy,
     GGML_ASSERT(ncols % QK_K == 0);
     const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
     const sycl::range<3> block_nums(1, 1, block_num_y);
-    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
     {

         stream->submit([&](sycl::handler &cgh) {
@@ -858,7 +860,7 @@ static void mul_mat_vec_iq1_s_q8_1_sycl(const void *vx, const void *vy,
             cgh.parallel_for(
                 sycl::nd_range<3>(block_nums * block_dims, block_dims),
                 [=](sycl::nd_item<3> item_ct1)
-                    [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+                    [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
                         mul_mat_vec_q_iq1_s_q8_1<QK_K, QI1_S, block_iq1_s, 1>(
                             vx, vy, dst, ncols, nrows, item_ct1);
                     });
@@ -873,13 +875,13 @@ static void mul_mat_vec_iq1_m_q8_1_sycl(const void *vx, const void *vy,
     GGML_ASSERT(ncols % QK_K == 0);
     const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
     const sycl::range<3> block_nums(1, 1, block_num_y);
-    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
     {
         stream->submit([&](sycl::handler &cgh) {
             cgh.parallel_for(
                 sycl::nd_range<3>(block_nums * block_dims, block_dims),
                 [=](sycl::nd_item<3> item_ct1)
-                    [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+                    [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
                         mul_mat_vec_q_iq1_m_q8_1<QK_K, QI1_S, block_iq1_m, 1>(
                             vx, vy, dst, ncols, nrows, item_ct1);
                     });
@@ -894,14 +896,14 @@ static void mul_mat_vec_iq4_nl_q8_1_sycl(const void *vx, const void *vy,
     GGML_ASSERT(ncols % QK4_NL == 0);
     const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
     const sycl::range<3> block_nums(1, 1, block_num_y);
-    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
     {

         stream->submit([&](sycl::handler &cgh) {
             cgh.parallel_for(
                 sycl::nd_range<3>(block_nums * block_dims, block_dims),
                 [=](sycl::nd_item<3> item_ct1)
-                    [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+                    [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
                         mul_mat_vec_q_iq4_nl_q8_1<QK4_NL, QI4_NL, block_iq4_nl, 2>(
                             vx, vy, dst, ncols, nrows, item_ct1);
                     });
@@ -916,14 +918,14 @@ static void mul_mat_vec_iq4_xs_q8_1_sycl(const void *vx, const void *vy,
     GGML_ASSERT(ncols % QK_K == 0);
     const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
     const sycl::range<3> block_nums(1, 1, block_num_y);
-    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
     {

         stream->submit([&](sycl::handler &cgh) {
             cgh.parallel_for(
                 sycl::nd_range<3>(block_nums * block_dims, block_dims),
                 [=](sycl::nd_item<3> item_ct1)
-                    [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+                    [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
                         mul_mat_vec_q_iq4_xs_q8_1<QK_K, QI4_XS/4, block_iq4_xs, 1>(
                             vx, vy, dst, ncols, nrows, item_ct1);
                     });
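Every mul_mat_vec_* kernel above finishes with the same XOR butterfly reduction over a sub-group of QK_WARP_SIZE lanes. A plain C++ model of that pattern (not SYCL; QK_WARP_SIZE assumed to be 32 here purely for illustration):

#include <cstddef>

constexpr int QK_WARP_SIZE = 32; // assumption for this sketch

// After log2(QK_WARP_SIZE) steps every lane holds the full sub-group sum,
// mirroring dpct::permute_sub_group_by_xor in the kernels above.
void butterfly_reduce(float tmp[QK_WARP_SIZE]) {
    for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
        float next[QK_WARP_SIZE];
        for (int lane = 0; lane < QK_WARP_SIZE; ++lane) {
            // each lane adds the value held by lane ^ mask
            next[lane] = tmp[lane] + tmp[lane ^ mask];
        }
        for (int lane = 0; lane < QK_WARP_SIZE; ++lane) {
            tmp[lane] = next[lane];
        }
    }
}
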
@ -213,6 +213,7 @@ struct vk_device_struct {
|
|||||||
vk_pipeline pipeline_sum_rows_f32;
|
vk_pipeline pipeline_sum_rows_f32;
|
||||||
vk_pipeline pipeline_im2col_f32, pipeline_im2col_f32_f16;
|
vk_pipeline pipeline_im2col_f32, pipeline_im2col_f32_f16;
|
||||||
vk_pipeline pipeline_timestep_embedding_f32;
|
vk_pipeline pipeline_timestep_embedding_f32;
|
||||||
|
vk_pipeline pipeline_pool2d_f32;
|
||||||
|
|
||||||
std::unordered_map<std::string, vk_pipeline_ref> pipelines;
|
std::unordered_map<std::string, vk_pipeline_ref> pipelines;
|
||||||
std::unordered_map<std::string, uint64_t> pipeline_descriptor_set_requirements;
|
std::unordered_map<std::string, uint64_t> pipeline_descriptor_set_requirements;
|
||||||
@ -403,6 +404,17 @@ struct vk_op_timestep_embedding_push_constants {
|
|||||||
uint32_t max_period;
|
uint32_t max_period;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
struct vk_op_pool2d_push_constants {
|
||||||
|
uint32_t IW; uint32_t IH;
|
||||||
|
uint32_t OW; uint32_t OH;
|
||||||
|
uint32_t OC;
|
||||||
|
uint32_t pelements;
|
||||||
|
uint32_t op;
|
||||||
|
int32_t k0; int32_t k1;
|
||||||
|
int32_t s0; int32_t s1;
|
||||||
|
int32_t p0; int32_t p1;
|
||||||
|
};
|
||||||
|
|
||||||
// Allow pre-recording command buffers
|
// Allow pre-recording command buffers
|
||||||
struct vk_staging_memcpy {
|
struct vk_staging_memcpy {
|
||||||
vk_staging_memcpy(void * _dst, const void * _src, size_t _n) : dst(_dst), src(_src), n(_n) {}
|
vk_staging_memcpy(void * _dst, const void * _src, size_t _n) : dst(_dst), src(_src), n(_n) {}
|
||||||
@ -1803,6 +1815,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|||||||
|
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_timestep_embedding_f32, "timestep_embedding_f32", timestep_embedding_f32_len, timestep_embedding_f32_data, "main", 2, sizeof(vk_op_timestep_embedding_push_constants), {256, 1, 1}, {}, 1);
|
ggml_vk_create_pipeline(device, device->pipeline_timestep_embedding_f32, "timestep_embedding_f32", timestep_embedding_f32_len, timestep_embedding_f32_data, "main", 2, sizeof(vk_op_timestep_embedding_push_constants), {256, 1, 1}, {}, 1);
|
||||||
|
|
||||||
|
ggml_vk_create_pipeline(device, device->pipeline_pool2d_f32, "pool2d_f32", pool2d_f32_len, pool2d_f32_data, "main", 2, sizeof(vk_op_pool2d_push_constants), {512, 1, 1}, {}, 1);
|
||||||
|
|
||||||
for (auto &c : compiles) {
|
for (auto &c : compiles) {
|
||||||
c.wait();
|
c.wait();
|
||||||
}
|
}
|
||||||
@ -1941,7 +1955,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
|||||||
if (device->fp16) {
|
if (device->fp16) {
|
||||||
device_extensions.push_back("VK_KHR_shader_float16_int8");
|
device_extensions.push_back("VK_KHR_shader_float16_int8");
|
||||||
}
|
}
|
||||||
device->name = device->properties.deviceName.data();
|
device->name = GGML_VK_NAME + std::to_string(idx);
|
||||||
|
|
||||||
device_create_info = {
|
device_create_info = {
|
||||||
vk::DeviceCreateFlags(),
|
vk::DeviceCreateFlags(),
|
||||||
@ -1968,7 +1982,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
|||||||
|
|
||||||
device->buffer_type = {
|
device->buffer_type = {
|
||||||
/* .iface = */ ggml_backend_vk_buffer_type_interface,
|
/* .iface = */ ggml_backend_vk_buffer_type_interface,
|
||||||
/* .device = */ nullptr,
|
/* .device = */ ggml_backend_reg_dev_get(ggml_backend_vk_reg(), idx),
|
||||||
/* .context = */ new ggml_backend_vk_buffer_type_context{ device->name, device },
|
/* .context = */ new ggml_backend_vk_buffer_type_context{ device->name, device },
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -4234,6 +4248,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
             return ctx->device->pipeline_timestep_embedding_f32;
         }
         return nullptr;
+    case GGML_OP_POOL_2D:
+        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+            return ctx->device->pipeline_pool2d_f32;
+        }
+        return nullptr;
     case GGML_OP_LEAKY_RELU:
         if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
             return ctx->device->pipeline_leaky_relu_f32;
@@ -4464,6 +4483,14 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
             uint32_t half_ceil = (dim + 1) / 2;
             elements = { half_ceil, (uint32_t)src0->ne[0], 1 };
         } break;
+    case GGML_OP_POOL_2D:
+        {
+            const uint32_t N = dst->ne[3];
+            const uint32_t OC = dst->ne[2];
+            const uint32_t OH = dst->ne[1];
+            const uint32_t OW = dst->ne[0];
+            elements = { N * OC * OH * OW, 1, 1};
+        } break;
     case GGML_OP_ADD:
     case GGML_OP_DIV:
     case GGML_OP_MUL:
@@ -4914,6 +4941,34 @@ static void ggml_vk_timestep_embedding(ggml_backend_vk_context * ctx, vk_context
     }, dryrun);
 }
 
+static void ggml_vk_pool_2d(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
+    uint32_t op = static_cast<uint32_t>(dst->op_params[0]);
+    const int32_t k1 = dst->op_params[1];
+    const int32_t k0 = dst->op_params[2];
+    const int32_t s1 = dst->op_params[3];
+    const int32_t s0 = dst->op_params[4];
+    const int32_t p1 = dst->op_params[5];
+    const int32_t p0 = dst->op_params[6];
+
+    const uint32_t IH = src0->ne[1];
+    const uint32_t IW = src0->ne[0];
+
+    const uint32_t N = dst->ne[3];
+
+    const uint32_t OC = dst->ne[2];
+    const uint32_t OH = dst->ne[1];
+    const uint32_t OW = dst->ne[0];
+
+    const uint32_t parallel_elements = N * OC * OH * OW;
+
+    ggml_vk_op_f32<vk_op_pool2d_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_POOL_2D, {
+        IW, IH, OW, OH, OC,
+        parallel_elements,
+        op,
+        k0, k1, s0, s1, p0, p1,
+    }, dryrun);
+}
+
 static void ggml_vk_leaky_relu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
     const float * op_params = (const float *)dst->op_params;
     ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_LEAKY_RELU, { (uint32_t)ggml_nelements(src0), 0, op_params[0], 0.0f }, dryrun);
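For orientation, this is how the new kernel gets exercised from the graph side: a POOL_2D node is created with ggml_pool_2d() and the Vulkan backend dispatches it with the push constants above. A minimal sketch, assuming the usual ggml_pool_2d() API from ggml.h (the exact kernel/stride/padding argument order should be checked against the header; the shapes are illustrative):

    struct ggml_init_params ip = { /*.mem_size =*/ 16*1024*1024, /*.mem_buffer =*/ NULL, /*.no_alloc =*/ false };
    struct ggml_context * ctx = ggml_init(ip);

    // input tensor: ne = { W, H, C, N }
    struct ggml_tensor * x = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 64, 64, 8, 1);

    // 2x2 max pooling, stride 2, no padding; op_params carry op/kernel/stride/padding,
    // which ggml_vk_pool_2d() above unpacks into vk_op_pool2d_push_constants
    struct ggml_tensor * y = ggml_pool_2d(ctx, x, GGML_OP_POOL_MAX, 2, 2, 2, 2, 0, 0);
    // y->ne = { OW, OH, OC, N }, matching the N*OC*OH*OW dispatch size used here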
@@ -5792,6 +5847,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_SUM_ROWS:
     case GGML_OP_IM2COL:
     case GGML_OP_TIMESTEP_EMBEDDING:
+    case GGML_OP_POOL_2D:
     case GGML_OP_LEAKY_RELU:
         break;
     default:
@@ -5927,6 +5983,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_TIMESTEP_EMBEDDING:
         ggml_vk_timestep_embedding(ctx, compute_ctx, src0, node, dryrun);
 
+        break;
+    case GGML_OP_POOL_2D:
+        ggml_vk_pool_2d(ctx, compute_ctx, src0, node, dryrun);
+
         break;
     case GGML_OP_LEAKY_RELU:
         ggml_vk_leaky_relu(ctx, compute_ctx, src0, node, dryrun);
@@ -6018,6 +6078,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
     case GGML_OP_SUM_ROWS:
     case GGML_OP_IM2COL:
     case GGML_OP_TIMESTEP_EMBEDDING:
+    case GGML_OP_POOL_2D:
     case GGML_OP_LEAKY_RELU:
     case GGML_OP_REPEAT:
         buf = tensor->buffer;
@@ -6186,13 +6247,8 @@ static void ggml_vk_get_device_description(int device, char * description, size_
 
 // device backend
 
-static const char * ggml_backend_vk_buffer_get_name(ggml_backend_buffer_t buffer) {
-    ggml_backend_vk_buffer_context * ctx = (ggml_backend_vk_buffer_context *)buffer->context;
-    return ctx->name.c_str();
-}
-
 static bool ggml_backend_buffer_is_vk(ggml_backend_buffer_t buffer) {
-    return buffer->iface.get_name == ggml_backend_vk_buffer_get_name;
+    return buffer->buft->iface.get_name == ggml_backend_vk_buffer_type_name;
 }
 
 static void ggml_backend_vk_buffer_free_buffer(ggml_backend_buffer_t buffer) {
@@ -6256,7 +6312,6 @@ static void ggml_backend_vk_buffer_clear(ggml_backend_buffer_t buffer, uint8_t v
 }
 
 static ggml_backend_buffer_i ggml_backend_vk_buffer_interface = {
-    /* .get_name    = */ ggml_backend_vk_buffer_get_name,
     /* .free_buffer = */ ggml_backend_vk_buffer_free_buffer,
     /* .get_base    = */ ggml_backend_vk_buffer_get_base,
     /* .init_tensor = */ ggml_backend_vk_buffer_init_tensor,
@@ -6352,7 +6407,6 @@ static ggml_backend_buffer_t ggml_backend_vk_host_buffer_type_alloc_buffer(ggml_
 
     ggml_backend_buffer_t buffer = ggml_backend_cpu_buffer_from_ptr(ptr, size);
     buffer->buft = buft;
-    buffer->iface.get_name = ggml_backend_vk_host_buffer_name;
     buffer->iface.free_buffer = ggml_backend_vk_host_buffer_free_buffer;
 
     return buffer;
@@ -6378,7 +6432,7 @@ ggml_backend_buffer_type_t ggml_backend_vk_host_buffer_type() {
         /* .get_alloc_size = */ ggml_backend_cpu_buffer_type()->iface.get_alloc_size,
         /* .is_host        = */ ggml_backend_cpu_buffer_type()->iface.is_host,
     },
-    /* .device  = */ nullptr,
+    /* .device  = */ ggml_backend_reg_dev_get(ggml_backend_vk_reg(), 0),
     /* .context = */ nullptr,
 };
 
@ -6581,9 +6635,132 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
|
|||||||
UNUSED(backend);
|
UNUSED(backend);
|
||||||
}
|
}
|
||||||
|
|
||||||
static bool ggml_backend_vk_supports_op(ggml_backend_t backend, const ggml_tensor * op) {
|
// TODO: enable async and synchronize
|
||||||
// ggml_backend_vk_context * ctx = (ggml_backend_vk_context *) backend->context;
|
static ggml_backend_i ggml_backend_vk_interface = {
|
||||||
|
/* .get_name = */ ggml_backend_vk_name,
|
||||||
|
/* .free = */ ggml_backend_vk_free,
|
||||||
|
/* .set_tensor_async = */ NULL, // ggml_backend_vk_set_tensor_async,
|
||||||
|
/* .get_tensor_async = */ NULL, // ggml_backend_vk_get_tensor_async,
|
||||||
|
/* .cpy_tensor_async = */ NULL, // ggml_backend_vk_cpy_tensor_async,
|
||||||
|
/* .synchronize = */ NULL, // ggml_backend_vk_synchronize,
|
||||||
|
/* .graph_plan_create = */ NULL,
|
||||||
|
/* .graph_plan_free = */ NULL,
|
||||||
|
/* .graph_plan_update = */ NULL,
|
||||||
|
/* .graph_plan_compute = */ NULL,
|
||||||
|
/* .graph_compute = */ ggml_backend_vk_graph_compute,
|
||||||
|
/* .event_record = */ NULL,
|
||||||
|
/* .event_wait = */ NULL,
|
||||||
|
};
|
||||||
|
|
||||||
|
static ggml_guid_t ggml_backend_vk_guid() {
|
||||||
|
static ggml_guid guid = { 0xb8, 0xf7, 0x4f, 0x86, 0x40, 0x3c, 0xe1, 0x02, 0x91, 0xc8, 0xdd, 0xe9, 0x02, 0x3f, 0xc0, 0x2b };
|
||||||
|
return &guid;
|
||||||
|
}
|
||||||
|
|
||||||
|
ggml_backend_t ggml_backend_vk_init(size_t dev_num) {
|
||||||
|
VK_LOG_DEBUG("ggml_backend_vk_init(" << dev_num << ")");
|
||||||
|
|
||||||
|
ggml_backend_vk_context * ctx = new ggml_backend_vk_context;
|
||||||
|
ggml_vk_init(ctx, dev_num);
|
||||||
|
|
||||||
|
ggml_backend_t vk_backend = new ggml_backend {
|
||||||
|
/* .guid = */ ggml_backend_vk_guid(),
|
||||||
|
/* .interface = */ ggml_backend_vk_interface,
|
||||||
|
/* .device = */ ggml_backend_reg_dev_get(ggml_backend_vk_reg(), dev_num),
|
||||||
|
/* .context = */ ctx,
|
||||||
|
};
|
||||||
|
|
||||||
|
return vk_backend;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool ggml_backend_is_vk(ggml_backend_t backend) {
|
||||||
|
return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_vk_guid());
|
||||||
|
}
|
||||||
|
|
||||||
|
int ggml_backend_vk_get_device_count() {
|
||||||
|
return ggml_vk_get_device_count();
|
||||||
|
}
|
||||||
|
|
||||||
|
void ggml_backend_vk_get_device_description(int device, char * description, size_t description_size) {
|
||||||
|
GGML_ASSERT(device < (int) vk_instance.device_indices.size());
|
||||||
|
int dev_idx = vk_instance.device_indices[device];
|
||||||
|
ggml_vk_get_device_description(dev_idx, description, description_size);
|
||||||
|
}
|
||||||
|
|
||||||
|
void ggml_backend_vk_get_device_memory(int device, size_t * free, size_t * total) {
|
||||||
|
GGML_ASSERT(device < (int) vk_instance.device_indices.size());
|
||||||
|
|
||||||
|
vk::PhysicalDevice vkdev = vk_instance.instance.enumeratePhysicalDevices()[vk_instance.device_indices[device]];
|
||||||
|
|
||||||
|
vk::PhysicalDeviceMemoryProperties memprops = vkdev.getMemoryProperties();
|
||||||
|
|
||||||
|
for (const vk::MemoryHeap& heap : memprops.memoryHeaps) {
|
||||||
|
if (heap.flags & vk::MemoryHeapFlagBits::eDeviceLocal) {
|
||||||
|
*total = heap.size;
|
||||||
|
*free = heap.size;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//////////////////////////
|
||||||
|
|
||||||
|
struct ggml_backend_vk_device_context {
|
||||||
|
size_t device;
|
||||||
|
std::string name;
|
||||||
|
std::string description;
|
||||||
|
};
|
||||||
|
|
||||||
|
static const char * ggml_backend_vk_device_get_name(ggml_backend_dev_t dev) {
|
||||||
|
ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
|
||||||
|
return ctx->name.c_str();
|
||||||
|
}
|
||||||
|
|
||||||
|
static const char * ggml_backend_vk_device_get_description(ggml_backend_dev_t dev) {
|
||||||
|
ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
|
||||||
|
return ctx->description.c_str();
|
||||||
|
}
|
||||||
|
|
||||||
|
static void ggml_backend_vk_device_get_memory(ggml_backend_dev_t device, size_t * free, size_t * total) {
|
||||||
|
ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)device->context;
|
||||||
|
ggml_backend_vk_get_device_memory(ctx->device, free, total);
|
||||||
|
}
|
||||||
|
|
||||||
|
static ggml_backend_buffer_type_t ggml_backend_vk_device_get_buffer_type(ggml_backend_dev_t dev) {
|
||||||
|
ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
|
||||||
|
return ggml_backend_vk_buffer_type(ctx->device);
|
||||||
|
}
|
||||||
|
|
||||||
|
static ggml_backend_buffer_type_t ggml_backend_vk_device_get_host_buffer_type(ggml_backend_dev_t dev) {
|
||||||
|
UNUSED(dev);
|
||||||
|
return ggml_backend_vk_host_buffer_type();
|
||||||
|
}
|
||||||
|
|
||||||
|
static enum ggml_backend_dev_type ggml_backend_vk_device_get_type(ggml_backend_dev_t dev) {
|
||||||
|
UNUSED(dev);
|
||||||
|
return GGML_BACKEND_DEVICE_TYPE_GPU;
|
||||||
|
}
|
||||||
|
|
||||||
|
static void ggml_backend_vk_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
|
||||||
|
props->name = ggml_backend_vk_device_get_name(dev);
|
||||||
|
props->description = ggml_backend_vk_device_get_description(dev);
|
||||||
|
props->type = ggml_backend_vk_device_get_type(dev);
|
||||||
|
ggml_backend_vk_device_get_memory(dev, &props->memory_free, &props->memory_total);
|
||||||
|
props->caps = {
|
||||||
|
/* .async = */ false,
|
||||||
|
/* .host_buffer = */ true,
|
||||||
|
/* .buffer_from_host_ptr = */ false,
|
||||||
|
/* .events = */ false,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
static ggml_backend_t ggml_backend_vk_device_init(ggml_backend_dev_t dev, const char * params) {
|
||||||
|
UNUSED(params);
|
||||||
|
ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
|
||||||
|
return ggml_backend_vk_init(ctx->device);
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
|
||||||
switch (op->op) {
|
switch (op->op) {
|
||||||
case GGML_OP_UNARY:
|
case GGML_OP_UNARY:
|
||||||
switch (ggml_get_unary_op(op)) {
|
switch (ggml_get_unary_op(op)) {
|
||||||
@ -6695,103 +6872,108 @@ static bool ggml_backend_vk_supports_op(ggml_backend_t backend, const ggml_tenso
|
|||||||
case GGML_OP_SUM_ROWS:
|
case GGML_OP_SUM_ROWS:
|
||||||
case GGML_OP_IM2COL:
|
case GGML_OP_IM2COL:
|
||||||
case GGML_OP_TIMESTEP_EMBEDDING:
|
case GGML_OP_TIMESTEP_EMBEDDING:
|
||||||
|
case GGML_OP_POOL_2D:
|
||||||
case GGML_OP_LEAKY_RELU:
|
case GGML_OP_LEAKY_RELU:
|
||||||
return true;
|
return true;
|
||||||
default:
|
default:
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
UNUSED(backend);
|
UNUSED(dev);
|
||||||
}
|
}
|
||||||
|
|
||||||
static bool ggml_backend_vk_offload_op(ggml_backend_t backend, const ggml_tensor * op) {
|
static bool ggml_backend_vk_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
|
||||||
|
if (buft->iface.get_name != ggml_backend_vk_buffer_type_name) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
|
||||||
|
ggml_backend_vk_buffer_type_context * buft_ctx = (ggml_backend_vk_buffer_type_context *)buft->context;
|
||||||
|
|
||||||
|
return buft_ctx->device->idx == ctx->device;
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool ggml_backend_vk_device_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
|
||||||
const int min_batch_size = 32;
|
const int min_batch_size = 32;
|
||||||
|
|
||||||
return (op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS) ||
|
return (op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS) ||
|
||||||
(op->ne[2] >= min_batch_size && op->op == GGML_OP_MUL_MAT_ID);
|
(op->ne[2] >= min_batch_size && op->op == GGML_OP_MUL_MAT_ID);
|
||||||
|
|
||||||
UNUSED(backend);
|
UNUSED(dev);
|
||||||
}
|
}
|
||||||
|
|
||||||
static bool ggml_backend_vk_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft) {
|
static const struct ggml_backend_device_i ggml_backend_vk_device_i = {
|
||||||
if (buft->iface.get_name != ggml_backend_vk_buffer_type_name) {
|
/* .get_name = */ ggml_backend_vk_device_get_name,
|
||||||
return false;
|
/* .get_description = */ ggml_backend_vk_device_get_description,
|
||||||
}
|
/* .get_memory = */ ggml_backend_vk_device_get_memory,
|
||||||
|
/* .get_type = */ ggml_backend_vk_device_get_type,
|
||||||
ggml_backend_vk_buffer_type_context * buft_ctx = (ggml_backend_vk_buffer_type_context *)buft->context;
|
/* .get_props = */ ggml_backend_vk_device_get_props,
|
||||||
ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
|
/* .init_backend = */ ggml_backend_vk_device_init,
|
||||||
|
/* .get_buffer_type = */ ggml_backend_vk_device_get_buffer_type,
|
||||||
return buft_ctx->device == ctx->device;
|
/* .get_host_buffer_type = */ ggml_backend_vk_device_get_host_buffer_type,
|
||||||
}
|
/* .buffer_from_host_ptr = */ NULL,
|
||||||
|
/* .supports_op = */ ggml_backend_vk_device_supports_op,
|
||||||
// TODO: enable async and synchronize
|
/* .supports_buft = */ ggml_backend_vk_device_supports_buft,
|
||||||
static ggml_backend_i ggml_backend_vk_interface = {
|
/* .offload_op = */ ggml_backend_vk_device_offload_op,
|
||||||
/* .get_name = */ ggml_backend_vk_name,
|
/* .event_new = */ NULL,
|
||||||
/* .free = */ ggml_backend_vk_free,
|
/* .event_free = */ NULL,
|
||||||
/* .get_default_buffer_type = */ ggml_backend_vk_get_default_buffer_type,
|
/* .event_synchronize = */ NULL,
|
||||||
/* .set_tensor_async = */ NULL, // ggml_backend_vk_set_tensor_async,
|
|
||||||
/* .get_tensor_async = */ NULL, // ggml_backend_vk_get_tensor_async,
|
|
||||||
/* .cpy_tensor_async = */ NULL, // ggml_backend_vk_cpy_tensor_async,
|
|
||||||
/* .synchronize = */ NULL, // ggml_backend_vk_synchronize,
|
|
||||||
/* .graph_plan_create = */ NULL,
|
|
||||||
/* .graph_plan_free = */ NULL,
|
|
||||||
/* .graph_plan_update = */ NULL,
|
|
||||||
/* .graph_plan_compute = */ NULL,
|
|
||||||
/* .graph_compute = */ ggml_backend_vk_graph_compute,
|
|
||||||
/* .supports_op = */ ggml_backend_vk_supports_op,
|
|
||||||
/* .supports_buft = */ ggml_backend_vk_supports_buft,
|
|
||||||
/* .offload_op = */ ggml_backend_vk_offload_op,
|
|
||||||
/* .event_record = */ NULL,
|
|
||||||
/* .event_wait = */ NULL,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
static ggml_guid_t ggml_backend_vk_guid() {
|
static const char * ggml_backend_vk_reg_get_name(ggml_backend_reg_t reg) {
|
||||||
static ggml_guid guid = { 0xb8, 0xf7, 0x4f, 0x86, 0x40, 0x3c, 0xe1, 0x02, 0x91, 0xc8, 0xdd, 0xe9, 0x02, 0x3f, 0xc0, 0x2b };
|
UNUSED(reg);
|
||||||
return &guid;
|
return GGML_VK_NAME;
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_backend_t ggml_backend_vk_init(size_t dev_num) {
|
static size_t ggml_backend_vk_reg_get_device_count(ggml_backend_reg_t reg) {
|
||||||
VK_LOG_DEBUG("ggml_backend_vk_init(" << dev_num << ")");
|
UNUSED(reg);
|
||||||
|
return ggml_backend_vk_get_device_count();
|
||||||
|
}
|
||||||
|
|
||||||
ggml_backend_vk_context * ctx = new ggml_backend_vk_context;
|
static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg, size_t device) {
|
||||||
ggml_vk_init(ctx, dev_num);
|
static std::vector<ggml_backend_dev_t> devices;
|
||||||
|
|
||||||
ggml_backend_t vk_backend = new ggml_backend {
|
static bool initialized = false;
|
||||||
/* .guid = */ ggml_backend_vk_guid(),
|
|
||||||
/* .interface = */ ggml_backend_vk_interface,
|
{
|
||||||
/* .device = */ nullptr,
|
static std::mutex mutex;
|
||||||
|
std::lock_guard<std::mutex> lock(mutex);
|
||||||
|
if (!initialized) {
|
||||||
|
for (int i = 0; i < ggml_backend_vk_get_device_count(); i++) {
|
||||||
|
ggml_backend_vk_device_context * ctx = new ggml_backend_vk_device_context;
|
||||||
|
char desc[256];
|
||||||
|
ggml_backend_vk_get_device_description(i, desc, sizeof(desc));
|
||||||
|
ctx->device = i;
|
||||||
|
ctx->name = GGML_VK_NAME + std::to_string(i);
|
||||||
|
ctx->description = desc;
|
||||||
|
devices.push_back(new ggml_backend_device {
|
||||||
|
/* .iface = */ ggml_backend_vk_device_i,
|
||||||
|
/* .reg = */ reg,
|
||||||
/* .context = */ ctx,
|
/* .context = */ ctx,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
initialized = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
GGML_ASSERT(device < devices.size());
|
||||||
|
return devices[device];
|
||||||
|
}
|
||||||
|
|
||||||
|
static const struct ggml_backend_reg_i ggml_backend_vk_reg_i = {
|
||||||
|
/* .get_name = */ ggml_backend_vk_reg_get_name,
|
||||||
|
/* .get_device_count = */ ggml_backend_vk_reg_get_device_count,
|
||||||
|
/* .get_device = */ ggml_backend_vk_reg_get_device,
|
||||||
|
/* .get_proc_address = */ NULL,
|
||||||
|
};
|
||||||
|
|
||||||
|
ggml_backend_reg_t ggml_backend_vk_reg() {
|
||||||
|
static ggml_backend_reg reg = {
|
||||||
|
/* .iface = */ ggml_backend_vk_reg_i,
|
||||||
|
/* .context = */ nullptr,
|
||||||
};
|
};
|
||||||
|
|
||||||
return vk_backend;
|
return ®
|
||||||
}
|
|
||||||
|
|
||||||
bool ggml_backend_is_vk(ggml_backend_t backend) {
|
|
||||||
return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_vk_guid());
|
|
||||||
}
|
|
||||||
|
|
||||||
int ggml_backend_vk_get_device_count() {
|
|
||||||
return ggml_vk_get_device_count();
|
|
||||||
}
|
|
||||||
|
|
||||||
void ggml_backend_vk_get_device_description(int device, char * description, size_t description_size) {
|
|
||||||
ggml_vk_get_device_description(device, description, description_size);
|
|
||||||
}
|
|
||||||
|
|
||||||
void ggml_backend_vk_get_device_memory(int device, size_t * free, size_t * total) {
|
|
||||||
GGML_ASSERT(device < (int) vk_instance.device_indices.size());
|
|
||||||
|
|
||||||
vk::PhysicalDevice vkdev = vk_instance.instance.enumeratePhysicalDevices()[vk_instance.device_indices[device]];
|
|
||||||
|
|
||||||
vk::PhysicalDeviceMemoryProperties memprops = vkdev.getMemoryProperties();
|
|
||||||
|
|
||||||
for (const vk::MemoryHeap& heap : memprops.memoryHeaps) {
|
|
||||||
if (heap.flags & vk::MemoryHeapFlagBits::eDeviceLocal) {
|
|
||||||
*total = heap.size;
|
|
||||||
*free = heap.size;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Extension availability
|
// Extension availability
|
||||||
@@ -7204,6 +7386,16 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
         const int32_t dim = tensor->op_params[0];
         const int32_t max_period = tensor->op_params[1];
         tensor_clone = ggml_timestep_embedding(ggml_ctx, src0_clone, dim, max_period);
+    } else if (tensor->op == GGML_OP_POOL_2D) {
+        enum ggml_op_pool op = static_cast<ggml_op_pool>(dst->op_params[0]);
+        const int32_t k0 = tensor->op_params[1];
+        const int32_t k1 = tensor->op_params[2];
+        const int32_t s0 = tensor->op_params[3];
+        const int32_t s1 = tensor->op_params[4];
+        const int32_t p0 = tensor->op_params[5];
+        const int32_t p1 = tensor->op_params[6];
+
+        tensor_clone = ggml_pool_2d(ggml_ctx, src0_clone, op, k0, k1, s0, s1, p0, p1);
     } else if (tensor->op == GGML_OP_LEAKY_RELU) {
         const float * op_params = (const float *)tensor->op_params;
         tensor_clone = ggml_leaky_relu(ggml_ctx, src0_clone, op_params[0], false);
ggml/src/ggml.c (331 changes)

@@ -35,10 +35,6 @@
 #include <omp.h>
 #endif
 
-#ifdef GGML_USE_METAL
-#include <unistd.h>
-#endif
-
 #if defined(__ARM_FEATURE_SVE) || defined(__ARM_FEATURE_MATMUL_INT8)
 #undef GGML_USE_LLAMAFILE
 #endif
@@ -189,6 +185,8 @@ typedef pthread_t ggml_thread_t;
 #endif
 
 #if defined(__APPLE__)
+#include <unistd.h>
+#include <mach/mach.h>
 #include <TargetConditionals.h>
 #endif
 
@ -308,6 +306,7 @@ void ggml_abort(const char * file, int line, const char * fmt, ...) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#define GGML_DEBUG 0
|
#define GGML_DEBUG 0
|
||||||
|
|
||||||
#define GGML_GELU_FP16
|
#define GGML_GELU_FP16
|
||||||
#define GGML_GELU_QUICK_FP16
|
#define GGML_GELU_QUICK_FP16
|
||||||
|
|
||||||
@@ -326,8 +325,9 @@ struct ggml_logger_state {
 static struct ggml_logger_state g_logger_state = {ggml_log_callback_default, NULL};
 
 static void ggml_log_internal_v(enum ggml_log_level level, const char * format, va_list args) {
-    if (format == NULL)
+    if (format == NULL) {
         return;
+    }
     va_list args_copy;
     va_copy(args_copy, args);
     char buffer[128];
@@ -386,22 +386,40 @@ void ggml_log_callback_default(enum ggml_log_level level, const char * text, voi
 //#define GGML_SOFT_MAX_ACCELERATE
 #endif
 
+
+void * ggml_aligned_malloc(size_t size) {
 #if defined(_MSC_VER) || defined(__MINGW32__)
-#define GGML_ALIGNED_MALLOC(size) _aligned_malloc(size, GGML_MEM_ALIGN)
-#define GGML_ALIGNED_FREE(ptr)    _aligned_free(ptr)
+    return _aligned_malloc(size, TENSOR_ALIGNMENT);
 #else
-inline static void * ggml_aligned_malloc(size_t size) {
     if (size == 0) {
         GGML_LOG_WARN("Behavior may be unexpected when allocating 0 bytes for ggml_aligned_malloc!\n");
         return NULL;
     }
     void * aligned_memory = NULL;
 #ifdef GGML_USE_CPU_HBM
-    int result = hbw_posix_memalign(&aligned_memory, 16, size);
+    int result = hbw_posix_memalign(&aligned_memory, TENSOR_ALIGNMENT, size);
+#elif TARGET_OS_OSX
+    kern_return_t alloc_status = vm_allocate((vm_map_t) mach_task_self(), (vm_address_t *) &aligned_memory, size, VM_FLAGS_ANYWHERE);
+    int result = EFAULT;
+    switch (alloc_status) {
+        case KERN_SUCCESS:
+            result = 0;
+            break;
+        case KERN_INVALID_ADDRESS:
+            result = EINVAL;
+            break;
+        case KERN_NO_SPACE:
+            result = ENOMEM;
+            break;
+        default:
+            result = EFAULT;
+            break;
+    }
 #elif GGML_USE_METAL
-    int result = posix_memalign(&aligned_memory, sysconf(_SC_PAGESIZE), size);
+    const long page_size = sysconf(_SC_PAGESIZE);
+    int result = posix_memalign(&aligned_memory, MAX(TENSOR_ALIGNMENT, page_size), size);
 #else
-    int result = posix_memalign(&aligned_memory, GGML_MEM_ALIGN, size);
+    int result = posix_memalign(&aligned_memory, TENSOR_ALIGNMENT, size);
 #endif
     if (result != 0) {
         // Handle allocation failure
@@ -419,14 +437,26 @@ inline static void * ggml_aligned_malloc(size_t size) {
         return NULL;
     }
     return aligned_memory;
+#endif
 }
-#define GGML_ALIGNED_MALLOC(size) ggml_aligned_malloc(size)
-#ifdef GGML_USE_CPU_HBM
-#define GGML_ALIGNED_FREE(ptr)    if(NULL != ptr) hbw_free(ptr)
+
+void ggml_aligned_free(void * ptr, size_t size) {
+    GGML_UNUSED(size);
+#if defined(_MSC_VER) || defined(__MINGW32__)
+    _aligned_free(ptr);
+#elif GGML_USE_CPU_HBM
+    if (ptr != NULL) {
+        hbw_free(ptr);
+    }
+#elif TARGET_OS_OSX
+    if (ptr != NULL) {
+        vm_deallocate((vm_map_t)mach_task_self(), (vm_address_t)ptr, size);
+    }
 #else
-#define GGML_ALIGNED_FREE(ptr)    free(ptr)
-#endif
+    free(ptr);
 #endif
+}
 
+
 inline static void * ggml_malloc(size_t size) {
     if (size == 0) {
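Call-site note: the GGML_ALIGNED_MALLOC/GGML_ALIGNED_FREE macros are replaced by functions, and the free side now needs the allocation size (vm_deallocate() on macOS requires it). A minimal usage sketch of the new pair, based only on the declarations above:

    const size_t n = 1u << 20;
    void * buf = ggml_aligned_malloc(n);   // returns NULL on failure (or for n == 0)
    if (buf != NULL) {
        // ... use buf ...
        ggml_aligned_free(buf, n);         // pass the same size that was requested
    }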
@ -1985,18 +2015,14 @@ static const size_t GGML_OBJECT_SIZE = sizeof(struct ggml_object);
|
|||||||
|
|
||||||
struct ggml_context {
|
struct ggml_context {
|
||||||
size_t mem_size;
|
size_t mem_size;
|
||||||
void* mem_buffer;
|
void * mem_buffer;
|
||||||
bool mem_buffer_owned;
|
bool mem_buffer_owned;
|
||||||
bool no_alloc;
|
bool no_alloc;
|
||||||
bool no_alloc_save; // this is used to save the no_alloc state when using scratch buffers
|
|
||||||
|
|
||||||
int n_objects;
|
int n_objects;
|
||||||
|
|
||||||
struct ggml_object * objects_begin;
|
struct ggml_object * objects_begin;
|
||||||
struct ggml_object * objects_end;
|
struct ggml_object * objects_end;
|
||||||
|
|
||||||
struct ggml_scratch scratch;
|
|
||||||
struct ggml_scratch scratch_save;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
struct ggml_context_container {
|
struct ggml_context_container {
|
||||||
@ -3234,7 +3260,6 @@ struct ggml_numa_nodes {
|
|||||||
//
|
//
|
||||||
|
|
||||||
struct ggml_state {
|
struct ggml_state {
|
||||||
struct ggml_context_container contexts[GGML_MAX_CONTEXTS];
|
|
||||||
struct ggml_numa_nodes numa;
|
struct ggml_numa_nodes numa;
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -3435,7 +3460,7 @@ int64_t ggml_nrows(const struct ggml_tensor * tensor) {
|
|||||||
|
|
||||||
size_t ggml_nbytes(const struct ggml_tensor * tensor) {
|
size_t ggml_nbytes(const struct ggml_tensor * tensor) {
|
||||||
size_t nbytes;
|
size_t nbytes;
|
||||||
size_t blck_size = ggml_blck_size(tensor->type);
|
const size_t blck_size = ggml_blck_size(tensor->type);
|
||||||
if (blck_size == 1) {
|
if (blck_size == 1) {
|
||||||
nbytes = ggml_type_size(tensor->type);
|
nbytes = ggml_type_size(tensor->type);
|
||||||
for (int i = 0; i < GGML_MAX_DIMS; ++i) {
|
for (int i = 0; i < GGML_MAX_DIMS; ++i) {
|
||||||
@ -3816,17 +3841,12 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
|
|||||||
const uint64_t t_start = ggml_time_us(); UNUSED(t_start);
|
const uint64_t t_start = ggml_time_us(); UNUSED(t_start);
|
||||||
|
|
||||||
g_state = (struct ggml_state) {
|
g_state = (struct ggml_state) {
|
||||||
/*.contexts =*/ { { 0 } },
|
|
||||||
/*.numa =*/ {
|
/*.numa =*/ {
|
||||||
.n_nodes = 0,
|
.n_nodes = 0,
|
||||||
.total_cpus = 0,
|
.total_cpus = 0,
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
for (int i = 0; i < GGML_MAX_CONTEXTS; ++i) {
|
|
||||||
g_state.contexts[i].used = false;
|
|
||||||
}
|
|
||||||
|
|
||||||
const uint64_t t_end = ggml_time_us(); UNUSED(t_end);
|
const uint64_t t_end = ggml_time_us(); UNUSED(t_end);
|
||||||
|
|
||||||
GGML_PRINT_DEBUG("%s: g_state initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f);
|
GGML_PRINT_DEBUG("%s: g_state initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f);
|
||||||
@ -3839,26 +3859,9 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
|
|||||||
is_first_call = false;
|
is_first_call = false;
|
||||||
}
|
}
|
||||||
|
|
||||||
// find non-used context in g_state
|
|
||||||
struct ggml_context * ctx = NULL;
|
|
||||||
|
|
||||||
for (int i = 0; i < GGML_MAX_CONTEXTS; i++) {
|
|
||||||
if (!g_state.contexts[i].used) {
|
|
||||||
g_state.contexts[i].used = true;
|
|
||||||
ctx = &g_state.contexts[i].context;
|
|
||||||
|
|
||||||
GGML_PRINT_DEBUG("%s: found unused context %d\n", __func__, i);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (ctx == NULL) {
|
|
||||||
GGML_PRINT_DEBUG("%s: no unused context found\n", __func__);
|
|
||||||
|
|
||||||
ggml_critical_section_end();
|
ggml_critical_section_end();
|
||||||
|
|
||||||
return NULL;
|
struct ggml_context * ctx = GGML_MALLOC(sizeof(struct ggml_context));
|
||||||
}
|
|
||||||
|
|
||||||
// allow to call ggml_init with 0 size
|
// allow to call ggml_init with 0 size
|
||||||
if (params.mem_size == 0) {
|
if (params.mem_size == 0) {
|
||||||
@ -3869,15 +3872,12 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
|
|||||||
|
|
||||||
*ctx = (struct ggml_context) {
|
*ctx = (struct ggml_context) {
|
||||||
/*.mem_size =*/ mem_size,
|
/*.mem_size =*/ mem_size,
|
||||||
/*.mem_buffer =*/ params.mem_buffer ? params.mem_buffer : GGML_ALIGNED_MALLOC(mem_size),
|
/*.mem_buffer =*/ params.mem_buffer ? params.mem_buffer : ggml_aligned_malloc(mem_size),
|
||||||
/*.mem_buffer_owned =*/ params.mem_buffer ? false : true,
|
/*.mem_buffer_owned =*/ params.mem_buffer ? false : true,
|
||||||
/*.no_alloc =*/ params.no_alloc,
|
/*.no_alloc =*/ params.no_alloc,
|
||||||
/*.no_alloc_save =*/ params.no_alloc,
|
|
||||||
/*.n_objects =*/ 0,
|
/*.n_objects =*/ 0,
|
||||||
/*.objects_begin =*/ NULL,
|
/*.objects_begin =*/ NULL,
|
||||||
/*.objects_end =*/ NULL,
|
/*.objects_end =*/ NULL,
|
||||||
/*.scratch =*/ { 0, 0, NULL, },
|
|
||||||
/*.scratch_save =*/ { 0, 0, NULL, },
|
|
||||||
};
|
};
|
||||||
|
|
||||||
GGML_ASSERT(ctx->mem_buffer != NULL);
|
GGML_ASSERT(ctx->mem_buffer != NULL);
|
||||||
@@ -3886,56 +3886,35 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
 
     GGML_PRINT_DEBUG("%s: context initialized\n", __func__);
 
-    ggml_critical_section_end();
-
     return ctx;
 }
 
+void ggml_reset(struct ggml_context * ctx) {
+    if (ctx == NULL) {
+        return;
+    }
+
+    ctx->n_objects     = 0;
+    ctx->objects_begin = NULL;
+    ctx->objects_end   = NULL;
+}
+
 void ggml_free(struct ggml_context * ctx) {
     if (ctx == NULL) {
         return;
     }
 
-    // make this function thread safe
-    ggml_critical_section_start();
-
-    bool found = false;
-
-    for (int i = 0; i < GGML_MAX_CONTEXTS; i++) {
-        if (&g_state.contexts[i].context == ctx) {
-            g_state.contexts[i].used = false;
-
-            GGML_PRINT_DEBUG("%s: context %d has been freed. memory used = %zu\n",
-                    __func__, i, ggml_used_mem(ctx));
-
-            if (ctx->mem_buffer_owned) {
-                GGML_ALIGNED_FREE(ctx->mem_buffer);
-            }
-
-            found = true;
-            break;
-        }
-    }
-
-    if (!found) {
-        GGML_PRINT_DEBUG("%s: context not found\n", __func__);
-    }
-
-    ggml_critical_section_end();
+    if (ctx->mem_buffer_owned) {
+        ggml_aligned_free(ctx->mem_buffer, ctx->mem_size);
+    }
+
+    GGML_FREE(ctx);
 }
 
 size_t ggml_used_mem(const struct ggml_context * ctx) {
     return ctx->objects_end == NULL ? 0 : ctx->objects_end->offs + ctx->objects_end->size;
 }
 
-size_t ggml_set_scratch(struct ggml_context * ctx, struct ggml_scratch scratch) {
-    const size_t result = ctx->scratch.data ? ctx->scratch.offs : 0;
-
-    ctx->scratch = scratch;
-
-    return result;
-}
-
 bool ggml_get_no_alloc(struct ggml_context * ctx) {
     return ctx->no_alloc;
 }
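The new ggml_reset() lets a caller reuse a context's memory pool across iterations instead of tearing it down and re-initializing. A rough usage sketch under that assumption (n_steps and the tensor shape are illustrative only):

    struct ggml_init_params ip = { /*.mem_size =*/ 16*1024*1024, /*.mem_buffer =*/ NULL, /*.no_alloc =*/ false };
    struct ggml_context * ctx = ggml_init(ip);

    const int n_steps = 8;                   // illustrative
    for (int step = 0; step < n_steps; ++step) {
        ggml_reset(ctx);                     // drops all objects, keeps the underlying buffer
        struct ggml_tensor * t = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 128);
        // ... build and evaluate this step's graph ...
        (void) t;
    }

    ggml_free(ctx);                          // with this patch, also frees the context struct itself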
@ -3963,27 +3942,6 @@ size_t ggml_get_max_tensor_size(const struct ggml_context * ctx) {
|
|||||||
return max_size;
|
return max_size;
|
||||||
}
|
}
|
||||||
|
|
||||||
// IMPORTANT:
|
|
||||||
// when creating "opt" tensors, always save and load the scratch buffer
|
|
||||||
// this is an error prone process, but it is necessary to support inplace
|
|
||||||
// operators when using scratch buffers
|
|
||||||
// TODO: implement a better way
|
|
||||||
static void ggml_scratch_save(struct ggml_context * ctx) {
|
|
||||||
// this is needed to allow opt tensors to store their data
|
|
||||||
// TODO: again, need to find a better way
|
|
||||||
ctx->no_alloc_save = ctx->no_alloc;
|
|
||||||
ctx->no_alloc = false;
|
|
||||||
|
|
||||||
ctx->scratch_save = ctx->scratch;
|
|
||||||
ctx->scratch.data = NULL;
|
|
||||||
}
|
|
||||||
|
|
||||||
static void ggml_scratch_load(struct ggml_context * ctx) {
|
|
||||||
ctx->no_alloc = ctx->no_alloc_save;
|
|
||||||
|
|
||||||
ctx->scratch = ctx->scratch_save;
|
|
||||||
}
|
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
static struct ggml_object * ggml_new_object(struct ggml_context * ctx, enum ggml_object_type type, size_t size) {
|
static struct ggml_object * ggml_new_object(struct ggml_context * ctx, enum ggml_object_type type, size_t size) {
|
||||||
@ -4003,7 +3961,9 @@ static struct ggml_object * ggml_new_object(struct ggml_context * ctx, enum ggml
|
|||||||
if (cur_end + size_needed + GGML_OBJECT_SIZE > ctx->mem_size) {
|
if (cur_end + size_needed + GGML_OBJECT_SIZE > ctx->mem_size) {
|
||||||
GGML_LOG_WARN("%s: not enough space in the context's memory pool (needed %zu, available %zu)\n",
|
GGML_LOG_WARN("%s: not enough space in the context's memory pool (needed %zu, available %zu)\n",
|
||||||
__func__, cur_end + size_needed + GGML_OBJECT_SIZE, ctx->mem_size);
|
__func__, cur_end + size_needed + GGML_OBJECT_SIZE, ctx->mem_size);
|
||||||
assert(false);
|
#ifndef NDEBUG
|
||||||
|
GGML_ABORT("not enough space in the context's memory pool");
|
||||||
|
#endif
|
||||||
return NULL;
|
return NULL;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -4062,29 +4022,13 @@ static struct ggml_tensor * ggml_new_tensor_impl(
|
|||||||
size_t obj_alloc_size = 0;
|
size_t obj_alloc_size = 0;
|
||||||
|
|
||||||
if (view_src == NULL && !ctx->no_alloc) {
|
if (view_src == NULL && !ctx->no_alloc) {
|
||||||
if (ctx->scratch.data != NULL) {
|
|
||||||
// allocate tensor data in the scratch buffer
|
|
||||||
if (ctx->scratch.offs + data_size > ctx->scratch.size) {
|
|
||||||
GGML_LOG_WARN("%s: not enough space in the scratch memory pool (needed %zu, available %zu)\n",
|
|
||||||
__func__, ctx->scratch.offs + data_size, ctx->scratch.size);
|
|
||||||
assert(false);
|
|
||||||
return NULL;
|
|
||||||
}
|
|
||||||
|
|
||||||
data = (char * const) ctx->scratch.data + ctx->scratch.offs;
|
|
||||||
|
|
||||||
ctx->scratch.offs += data_size;
|
|
||||||
} else {
|
|
||||||
// allocate tensor data in the context's memory pool
|
// allocate tensor data in the context's memory pool
|
||||||
obj_alloc_size = data_size;
|
obj_alloc_size = data_size;
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
struct ggml_object * const obj_new = ggml_new_object(ctx, GGML_OBJECT_TYPE_TENSOR, GGML_TENSOR_SIZE + obj_alloc_size);
|
struct ggml_object * const obj_new = ggml_new_object(ctx, GGML_OBJECT_TYPE_TENSOR, GGML_TENSOR_SIZE + obj_alloc_size);
|
||||||
GGML_ASSERT(obj_new);
|
GGML_ASSERT(obj_new);
|
||||||
|
|
||||||
// TODO: for recoverable errors, we would need to free the data allocated from the scratch buffer here
|
|
||||||
|
|
||||||
struct ggml_tensor * const result = (struct ggml_tensor *)((char *)ctx->mem_buffer + obj_new->offs);
|
struct ggml_tensor * const result = (struct ggml_tensor *)((char *)ctx->mem_buffer + obj_new->offs);
|
||||||
|
|
||||||
#ifdef __clang__
|
#ifdef __clang__
|
||||||
@ -4180,24 +4124,16 @@ struct ggml_tensor * ggml_new_tensor_4d(
|
|||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor * ggml_new_i32(struct ggml_context * ctx, int32_t value) {
|
struct ggml_tensor * ggml_new_i32(struct ggml_context * ctx, int32_t value) {
|
||||||
ggml_scratch_save(ctx);
|
|
||||||
|
|
||||||
struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 1);
|
struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 1);
|
||||||
|
|
||||||
ggml_scratch_load(ctx);
|
|
||||||
|
|
||||||
ggml_set_i32(result, value);
|
ggml_set_i32(result, value);
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor * ggml_new_f32(struct ggml_context * ctx, float value) {
|
struct ggml_tensor * ggml_new_f32(struct ggml_context * ctx, float value) {
|
||||||
ggml_scratch_save(ctx);
|
|
||||||
|
|
||||||
struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);
|
struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);
|
||||||
|
|
||||||
ggml_scratch_load(ctx);
|
|
||||||
|
|
||||||
ggml_set_f32(result, value);
|
ggml_set_f32(result, value);
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
@@ -7245,6 +7181,7 @@ struct ggml_tensor * ggml_ssm_conv(
     const int64_t n_s = sx->ne[2];
 
     // TODO: maybe support other strides than 1?
+    // FIXME: this is always true?
     GGML_ASSERT(sx->ne[0] == d_conv - 1 + n_t);
     GGML_ASSERT(sx->ne[1] == d_inner);
     GGML_ASSERT(n_t >= 0);
@@ -15713,6 +15650,9 @@ static void ggml_compute_forward_flash_attn_ext_f16(
     ggml_vec_dot_t    const kq_vec_dot = type_traits[k->type].vec_dot;
     ggml_to_float_t   const v_to_float = type_traits[v->type].to_float;
 
+    GGML_ASSERT(q_to_vec_dot && "fattn: unsupported K-type");
+    GGML_ASSERT(v_to_float   && "fattn: unsupported V-type");
+
     // loop over n_batch and n_head
     for (int ir = ir0; ir < ir1; ++ir) {
         // q indices
@@ -19706,9 +19646,10 @@ static void ggml_thread_cpumask_next(const bool * global_mask, bool * local_mask
 void ggml_threadpool_free(struct ggml_threadpool* threadpool) {
     if (!threadpool) return;
 
+    const int n_threads = threadpool->n_threads_max;
+
 #ifndef GGML_USE_OPENMP
     struct ggml_compute_state* workers = threadpool->workers;
-    const int n_threads = threadpool->n_threads_max;
 
     ggml_mutex_lock(&threadpool->mutex);
 
@@ -19728,8 +19669,9 @@ void ggml_threadpool_free(struct ggml_threadpool* threadpool) {
     ggml_cond_destroy(&threadpool->cond);
 #endif // GGML_USE_OPENMP
 
-    GGML_ALIGNED_FREE(threadpool->workers);
-    GGML_ALIGNED_FREE(threadpool);
+    const size_t workers_size = sizeof(struct ggml_compute_state) * n_threads;
+    ggml_aligned_free(threadpool->workers, workers_size);
+    ggml_aligned_free(threadpool, sizeof(struct ggml_threadpool));
 }
 
 #ifndef GGML_USE_OPENMP
@ -20161,7 +20103,7 @@ static struct ggml_threadpool * ggml_threadpool_new_impl(
|
|||||||
struct ggml_cplan * cplan) {
|
struct ggml_cplan * cplan) {
|
||||||
|
|
||||||
struct ggml_threadpool * threadpool =
|
struct ggml_threadpool * threadpool =
|
||||||
GGML_ALIGNED_MALLOC(sizeof(struct ggml_threadpool));
|
ggml_aligned_malloc(sizeof(struct ggml_threadpool));
|
||||||
{
|
{
|
||||||
threadpool->cgraph = cgraph;
|
threadpool->cgraph = cgraph;
|
||||||
threadpool->cplan = cplan;
|
threadpool->cplan = cplan;
|
||||||
@ -20182,7 +20124,7 @@ static struct ggml_threadpool * ggml_threadpool_new_impl(
|
|||||||
|
|
||||||
// Allocate and init workers state
|
// Allocate and init workers state
|
||||||
const size_t workers_size = sizeof(struct ggml_compute_state) * tpp->n_threads;
|
const size_t workers_size = sizeof(struct ggml_compute_state) * tpp->n_threads;
|
||||||
struct ggml_compute_state * workers = GGML_ALIGNED_MALLOC(workers_size);
|
struct ggml_compute_state * workers = ggml_aligned_malloc(workers_size);
|
||||||
|
|
||||||
memset(workers, 0, workers_size);
|
memset(workers, 0, workers_size);
|
||||||
for (int j = 0; j < tpp->n_threads; j++) {
|
for (int j = 0; j < tpp->n_threads; j++) {
|
||||||
@ -20357,7 +20299,6 @@ void ggml_graph_export(const struct ggml_cgraph * cgraph, const char * fname) {
|
|||||||
uint64_t size_eval = 0;
|
uint64_t size_eval = 0;
|
||||||
|
|
||||||
// compute size of intermediate results
|
// compute size of intermediate results
|
||||||
// TODO: does not take into account scratch buffers !!!!
|
|
||||||
for (int i = 0; i < cgraph->n_nodes; ++i) {
|
for (int i = 0; i < cgraph->n_nodes; ++i) {
|
||||||
size_eval += ggml_nbytes_pad(cgraph->nodes[i]);
|
size_eval += ggml_nbytes_pad(cgraph->nodes[i]);
|
||||||
}
|
}
|
||||||
@@ -22168,18 +22109,46 @@ static size_t gguf_type_size(enum gguf_type type) {
     return GGUF_TYPE_SIZE[type];
 }
 
-static void gguf_tensor_info_sanitize(struct gguf_tensor_info * info) {
-    GGML_ASSERT(info->n_dims <= GGML_MAX_DIMS);
-    GGML_ASSERT(0 <= info->type && info->type < GGML_TYPE_COUNT);
+static bool gguf_tensor_info_sanitize(struct gguf_tensor_info * info) {
+    if (info->n_dims > GGML_MAX_DIMS) {
+        fprintf(stderr, "%s: invalid number of dimensions (%" PRIu32 ")\n", __func__, info->n_dims);
+        return false;
+    }
+
+    if (info->type < 0 || info->type >= GGML_TYPE_COUNT) {
+        fprintf(stderr, "%s: invalid type (%d)\n", __func__, info->type);
+        return false;
+    }
+
+    if (strlen(info->name.data) >= GGML_MAX_NAME) {
+        fprintf(stderr, "%s: tensor '%s' name is too long\n", __func__, info->name.data);
+        return false;
+    }
 
     for (uint32_t i = 0; i < info->n_dims; ++i) {
-        GGML_ASSERT(info->ne[i] > 0);
+        if (info->ne[i] <= 0) {
+            fprintf(stderr, "%s: invalid number of elements (%" PRIu64 ")\n", __func__, info->ne[i]);
+            return false;
+        }
     }
 
     // prevent overflow for total number of elements
-    GGML_ASSERT(INT64_MAX/info->ne[1] > info->ne[0]);
-    GGML_ASSERT(INT64_MAX/info->ne[2] > info->ne[0]*info->ne[1]);
-    GGML_ASSERT(INT64_MAX/info->ne[3] > info->ne[0]*info->ne[1]*info->ne[2]);
+    if (INT64_MAX/info->ne[1] <= info->ne[0]) {
+        fprintf(stderr, "%s: invalid number of elements (%" PRIu64 ")\n", __func__, info->ne[1]);
+        return false;
+    }
+
+    if (INT64_MAX/info->ne[2] <= info->ne[0]*info->ne[1]) {
+        fprintf(stderr, "%s: invalid number of elements (%" PRIu64 ")\n", __func__, info->ne[2]);
+        return false;
+    }
+
+    if (INT64_MAX/info->ne[3] <= info->ne[0]*info->ne[1]*info->ne[2]) {
+        fprintf(stderr, "%s: invalid number of elements (%" PRIu64 ")\n", __func__, info->ne[3]);
+        return false;
+    }
+
+    return true;
 }
 
 static bool gguf_fread_el(FILE * file, void * dst, size_t size, size_t * offset) {
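Because the sanitizer now reports failure instead of asserting, gguf_init_from_file() can return NULL on malformed input rather than aborting the process. A hedged caller-side sketch (the file name is a placeholder, and the gguf_init_params layout is assumed to match gguf.h/ggml.h):

    struct ggml_context * meta = NULL;
    struct gguf_init_params gp = { /*.no_alloc =*/ true, /*.ctx =*/ &meta };

    struct gguf_context * gctx = gguf_init_from_file("model.gguf", gp);
    if (gctx == NULL) {
        fprintf(stderr, "failed to read GGUF metadata\n");   // malformed header, kv data or tensor info
    } else {
        // ... inspect kv pairs / tensor infos ...
        gguf_free(gctx);
        ggml_free(meta);
    }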
@ -22202,7 +22171,11 @@ static bool gguf_fread_str(FILE * file, struct gguf_str * p, size_t * offset) {
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
p->data = GGML_CALLOC(p->n + 1, 1);
|
p->data = calloc(p->n + 1, 1);
|
||||||
|
if (!p->data) {
|
||||||
|
fprintf(stderr, "%s: failed to allocate memory for string of length %" PRIu64 "\n", __func__, p->n);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
ok = ok && gguf_fread_el(file, p->data, p->n, offset);
|
ok = ok && gguf_fread_el(file, p->data, p->n, offset);
|
||||||
|
|
||||||
@ -22236,7 +22209,11 @@ static void gguf_free_kv(struct gguf_kv * kv) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
struct gguf_context * gguf_init_empty(void) {
|
struct gguf_context * gguf_init_empty(void) {
|
||||||
struct gguf_context * ctx = GGML_CALLOC(1, sizeof(struct gguf_context));
|
struct gguf_context * ctx = calloc(1, sizeof(struct gguf_context));
|
||||||
|
if (!ctx) {
|
||||||
|
fprintf(stderr, "%s: failed to allocate memory for context\n", __func__);
|
||||||
|
return NULL;
|
||||||
|
}
|
||||||
|
|
||||||
memcpy(ctx->header.magic, GGUF_MAGIC, sizeof(ctx->header.magic));
|
memcpy(ctx->header.magic, GGUF_MAGIC, sizeof(ctx->header.magic));
|
||||||
ctx->header.version = GGUF_VERSION;
|
ctx->header.version = GGUF_VERSION;
|
||||||
@ -22282,7 +22259,12 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
|
|||||||
|
|
||||||
bool ok = true;
|
bool ok = true;
|
||||||
|
|
||||||
struct gguf_context * ctx = GGML_CALLOC(1, sizeof(struct gguf_context));
|
struct gguf_context * ctx = calloc(1, sizeof(struct gguf_context));
|
||||||
|
if (!ctx) {
|
||||||
|
fprintf(stderr, "%s: failed to allocate memory for context\n", __func__);
|
||||||
|
fclose(file);
|
||||||
|
return NULL;
|
||||||
|
}
|
||||||
|
|
||||||
// read the header
|
// read the header
|
||||||
{
|
{
|
||||||
@ -22321,9 +22303,13 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
|
|||||||
{
|
{
|
||||||
const uint64_t n_kv = ctx->header.n_kv;
|
const uint64_t n_kv = ctx->header.n_kv;
|
||||||
|
|
||||||
// header.n_kv will hold the actual value of pairs that were successfully read in the loop below
|
ctx->kv = calloc(n_kv, sizeof(struct gguf_kv));
|
||||||
ctx->header.n_kv = 0;
|
if (!ctx->kv) {
|
||||||
ctx->kv = GGML_CALLOC(n_kv, sizeof(struct gguf_kv));
|
fprintf(stderr, "%s: failed to allocate memory for kv pairs\n", __func__);
|
||||||
|
fclose(file);
|
||||||
|
gguf_free(ctx);
|
||||||
|
return NULL;
|
||||||
|
}
|
||||||
|
|
||||||
for (uint64_t i = 0; i < n_kv; ++i) {
|
for (uint64_t i = 0; i < n_kv; ++i) {
|
||||||
struct gguf_kv * kv = &ctx->kv[i];
|
struct gguf_kv * kv = &ctx->kv[i];
|
||||||
@ -22374,7 +22360,13 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
|
|||||||
return NULL;
|
return NULL;
|
||||||
}
|
}
|
||||||
|
|
||||||
kv->value.arr.data = GGML_CALLOC(kv->value.arr.n, gguf_type_size(kv->value.arr.type));
|
kv->value.arr.data = calloc(kv->value.arr.n, gguf_type_size(kv->value.arr.type));
|
||||||
|
if (!kv->value.arr.data) {
|
||||||
|
fprintf(stderr, "%s: failed to allocate memory for array\n", __func__);
|
||||||
|
fclose(file);
|
||||||
|
gguf_free(ctx);
|
||||||
|
return NULL;
|
||||||
|
}
|
||||||
|
|
||||||
ok = ok && gguf_fread_el(file, kv->value.arr.data, kv->value.arr.n * gguf_type_size(kv->value.arr.type), &offset);
|
ok = ok && gguf_fread_el(file, kv->value.arr.data, kv->value.arr.n * gguf_type_size(kv->value.arr.type), &offset);
|
||||||
} break;
|
} break;
|
||||||
@ -22388,24 +22380,36 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
|
|||||||
return NULL;
|
return NULL;
|
||||||
}
|
}
|
||||||
|
|
||||||
kv->value.arr.data = GGML_CALLOC(kv->value.arr.n, sizeof(struct gguf_str));
|
kv->value.arr.data = calloc(kv->value.arr.n, sizeof(struct gguf_str));
|
||||||
|
if (!kv->value.arr.data) {
|
||||||
|
fprintf(stderr, "%s: failed to allocate memory for array\n", __func__);
|
||||||
|
fclose(file);
|
||||||
|
gguf_free(ctx);
|
||||||
|
return NULL;
|
||||||
|
}
|
||||||
|
|
||||||
for (uint64_t j = 0; j < kv->value.arr.n; ++j) {
|
for (uint64_t j = 0; j < kv->value.arr.n; ++j) {
|
||||||
ok = ok && gguf_fread_str(file, &((struct gguf_str *) kv->value.arr.data)[j], &offset);
|
ok = ok && gguf_fread_str(file, &((struct gguf_str *) kv->value.arr.data)[j], &offset);
|
||||||
}
|
}
|
||||||
} break;
|
} break;
|
||||||
case GGUF_TYPE_ARRAY:
|
case GGUF_TYPE_ARRAY:
|
||||||
default: GGML_ABORT("invalid type");
|
default:
|
||||||
|
{
|
||||||
|
fprintf(stderr, "%s: invalid array type %d\n", __func__, kv->value.arr.type);
|
||||||
|
ok = false;
|
||||||
|
} break;
|
||||||
}
|
}
|
||||||
} break;
|
} break;
|
||||||
default: GGML_ABORT("invalid type");
|
default:
|
||||||
|
{
|
||||||
|
fprintf(stderr, "%s: invalid type %d\n", __func__, kv->type);
|
||||||
|
ok = false;
|
||||||
|
} break;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!ok) {
|
if (!ok) {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx->header.n_kv++;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!ok) {
|
if (!ok) {
|
||||||
@ -22418,7 +22422,13 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
|
|||||||
|
|
||||||
// read the tensor infos
|
// read the tensor infos
|
||||||
if (ctx->header.n_tensors > 0) {
|
if (ctx->header.n_tensors > 0) {
|
||||||
ctx->infos = GGML_CALLOC(ctx->header.n_tensors, sizeof(struct gguf_tensor_info));
|
ctx->infos = calloc(ctx->header.n_tensors, sizeof(struct gguf_tensor_info));
|
||||||
|
if (!ctx->infos) {
|
||||||
|
fprintf(stderr, "%s: failed to allocate memory for tensor infos\n", __func__);
|
||||||
|
fclose(file);
|
||||||
|
gguf_free(ctx);
|
||||||
|
return NULL;
|
||||||
|
}
|
||||||
|
|
||||||
for (uint64_t i = 0; i < ctx->header.n_tensors; ++i) {
|
for (uint64_t i = 0; i < ctx->header.n_tensors; ++i) {
|
||||||
struct gguf_tensor_info * info = &ctx->infos[i];
|
struct gguf_tensor_info * info = &ctx->infos[i];
|
||||||
@ -22439,8 +22449,7 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
|
|||||||
ok = ok && gguf_fread_el (file, &info->type, sizeof(info->type), &offset);
|
ok = ok && gguf_fread_el (file, &info->type, sizeof(info->type), &offset);
|
||||||
ok = ok && gguf_fread_el (file, &info->offset, sizeof(info->offset), &offset);
|
ok = ok && gguf_fread_el (file, &info->offset, sizeof(info->offset), &offset);
|
||||||
|
|
||||||
// TODO: return an error instead of crashing with GGML_ASSERT
|
ok = ok && gguf_tensor_info_sanitize(info);
|
||||||
gguf_tensor_info_sanitize(info);
|
|
||||||
|
|
||||||
// make sure there is no duplicated tensor names
|
// make sure there is no duplicated tensor names
|
||||||
for (uint64_t j = 0; j < i && ok; ++j) {
|
for (uint64_t j = 0; j < i && ok; ++j) {
|
||||||
@@ -23320,6 +23329,14 @@ int ggml_cpu_has_avx512_bf16(void) {
 #endif
 }
 
+int ggml_cpu_has_amx_int8(void) {
+#if defined(__AMX_INT8__)
+    return 1;
+#else
+    return 0;
+#endif
+}
+
 int ggml_cpu_has_fma(void) {
 #if defined(__FMA__)
     return 1;
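The new ggml_cpu_has_amx_int8() follows the same pattern as the other CPU-feature probes in this file, so a system-info dump could report it alongside the existing flags. A sketch, not part of the patch, using only the probes visible in this hunk:

    printf("AVX512_BF16 = %d | AMX_INT8 = %d | FMA = %d\n",
           ggml_cpu_has_avx512_bf16(), ggml_cpu_has_amx_int8(), ggml_cpu_has_fma());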
|
@ -15,6 +15,7 @@
|
|||||||
#define TWOPI_F 6.283185307179586f
|
#define TWOPI_F 6.283185307179586f
|
||||||
|
|
||||||
#define QK_K 256
|
#define QK_K 256
|
||||||
|
#define K_SCALE_SIZE 12
|
||||||
|
|
||||||
#define u8BufToU16(buf, idx) (((uint16_t(buf[idx + 1]) << 8)) | buf[idx])
|
#define u8BufToU16(buf, idx) (((uint16_t(buf[idx + 1]) << 8)) | buf[idx])
|
||||||
#define u8BufToFloat16(buf, idx) uint16BitsToHalf u8BufToU16(buf, idx)
|
#define u8BufToFloat16(buf, idx) uint16BitsToHalf u8BufToU16(buf, idx)
|
||||||
@ -64,6 +65,14 @@ mat4 dequantize_q4_1(const block_q4_1 xb, uint il) {
|
|||||||
return reg;
|
return reg;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#define sizeof_block_q4_k 144
|
||||||
|
struct block_q4_k {
|
||||||
|
float16_t d;
|
||||||
|
float16_t dmin;
|
||||||
|
uint8_t scales[K_SCALE_SIZE];
|
||||||
|
uint8_t qs[QK_K/2];
|
||||||
|
};
|
||||||
|
|
||||||
#define sizeof_block_q6_k 210
|
#define sizeof_block_q6_k 210
|
||||||
struct block_q6_k {
|
struct block_q6_k {
|
||||||
uint8_t ql[QK_K/2]; // quants, lower 4 bits
|
uint8_t ql[QK_K/2]; // quants, lower 4 bits
|
||||||
|
ggml/src/kompute-shaders/op_mul_mat_q4_k.comp (new file)
@@ -0,0 +1,133 @@
#version 450

#include "common.comp"

#define N_DST 4
#define SIZE_OF_BLOCK sizeof_block_q4_k

layout(local_size_x = 4) in;
layout(local_size_y = 8) in;
layout(local_size_z = 1) in;

layout (binding = 0) readonly buffer tensorInA { block_q4_k inA[]; };
layout (binding = 1) readonly buffer tensorInB { float inB[]; };
layout (binding = 2) writeonly buffer tensorOut { float out_[]; };

layout (push_constant) uniform parameter {
    uint inAOff;
    uint inBOff;
    uint outOff;
    int ne00;
    int ne10;
    int ne0;
    int ne1;
    int ne01;
    int ne02;
    int ne12;
    int r2;
    int r3;
} pcs;

void main() {
    const uint16_t kmask1 = uint16_t(0x3f3f);
    const uint16_t kmask2 = uint16_t(0x0f0f);
    const uint16_t kmask3 = uint16_t(0xc0c0);

    const uint ix = gl_SubgroupInvocationID/8;  // 0...3
    const uint it = gl_SubgroupInvocationID%8;  // 0...7
    const uint iq = it/4;  // 0 or 1
    const uint ir = it%4;  // 0...3

    const uint nb = pcs.ne00/QK_K;

    const uint r0 = gl_WorkGroupID.x;
    const uint r1 = gl_WorkGroupID.y;
    const uint im = gl_WorkGroupID.z;

    const uint first_row = r0 * N_DST;
    const uint ib_row = first_row * nb;

    const uint i12 = im%pcs.ne12;
    const uint i13 = im/pcs.ne12;

    const uint offset0 = (i12/pcs.r2)*(nb*pcs.ne01) + (i13/pcs.r3)*(nb*pcs.ne01*pcs.ne02);

    const uint xblk = ib_row + offset0 + pcs.inAOff;
    const uint y = r1*pcs.ne10 + im*pcs.ne00*pcs.ne1 + pcs.inBOff;

    float yl[16];
    float yh[16];
    float sumf[N_DST] = {0.f, 0.f, 0.f, 0.f};
    float all_sum = 0.f;

    uint y4 = y + ix * QK_K + 64 * iq + 8 * ir;

    for (uint ib = ix; ib < nb; ib += 4) {
        const uint blk_idx = ib + xblk;

        float sumy[4] = {0.f, 0.f, 0.f, 0.f};
        for (int i = 0; i < 8; ++i) {
            yl[i+0] = inB[y4+i+ 0]; sumy[0] += yl[i+0];
            yl[i+8] = inB[y4+i+ 32]; sumy[1] += yl[i+8];
            yh[i+0] = inB[y4+i+128]; sumy[2] += yh[i+0];
            yh[i+8] = inB[y4+i+160]; sumy[3] += yh[i+8];
        }

        for (int row = 0; row < N_DST; row++) {
            uint row_idx = row * nb;

            uint16_t sc_0 = u8BufToU16(inA[blk_idx + row_idx].scales, iq * 2 + 0);
            uint16_t sc_1 = u8BufToU16(inA[blk_idx + row_idx].scales, iq * 2 + 2);
            uint16_t sc_2 = u8BufToU16(inA[blk_idx + row_idx].scales, iq * 2 + 4);
            uint16_t sc_3 = u8BufToU16(inA[blk_idx + row_idx].scales, iq * 2 + 6);
            uint16_t sc_4 = u8BufToU16(inA[blk_idx + row_idx].scales, iq * 2 + 8);

            uint16_t sc16[4];
            sc16[0] = sc_0 & kmask1;
            sc16[1] = sc_2 & kmask1;
            sc16[2] = ((sc_4 >> 0) & kmask2) | ((sc_0 & kmask3) >> 2);
            sc16[3] = ((sc_4 >> 4) & kmask2) | ((sc_2 & kmask3) >> 2);

            float acc1[4] = {0.f, 0.f, 0.f, 0.f};
            float acc2[4] = {0.f, 0.f, 0.f, 0.f};
            for (int i = 0; i < 8; i += 2) {
                uint16_t q1 = u8BufToU16(inA[blk_idx + row_idx].qs, 32 * iq + 8 * ir + i);
                uint16_t q2 = u8BufToU16(inA[blk_idx + row_idx].qs, 64 + 32 * iq + 8 * ir + i);
                acc1[0] += yl[i+0] * (q1 & 0x000F);
                acc1[1] += yl[i+1] * (q1 & 0x0F00);
                acc1[2] += yl[i+8] * (q1 & 0x00F0);
                acc1[3] += yl[i+9] * (q1 & 0xF000);
                acc2[0] += yh[i+0] * (q2 & 0x000F);
                acc2[1] += yh[i+1] * (q2 & 0x0F00);
                acc2[2] += yh[i+8] * (q2 & 0x00F0);
                acc2[3] += yh[i+9] * (q2 & 0xF000);
            }

            uint8_t sc8_0 = uint8_t(sc16[0] & 0xFF);
            uint8_t sc8_1 = uint8_t(sc16[0] >> 8 );
            uint8_t sc8_2 = uint8_t(sc16[1] & 0xFF);
            uint8_t sc8_3 = uint8_t(sc16[1] >> 8 );
            uint8_t sc8_4 = uint8_t(sc16[2] & 0xFF);
            uint8_t sc8_5 = uint8_t(sc16[2] >> 8 );
            uint8_t sc8_6 = uint8_t(sc16[3] & 0xFF);
            uint8_t sc8_7 = uint8_t(sc16[3] >> 8 );

            float dall = float(inA[blk_idx + row_idx].d);
            float dmin = float(inA[blk_idx + row_idx].dmin);
            sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc8_0 +
                                 (acc1[2] + 1.f/256.f * acc1[3]) * sc8_1 * 1.f/16.f +
                                 (acc2[0] + 1.f/256.f * acc2[1]) * sc8_4 +
                                 (acc2[2] + 1.f/256.f * acc2[3]) * sc8_5 * 1.f/16.f) -
                         dmin * (sumy[0] * sc8_2 + sumy[1] * sc8_3 + sumy[2] * sc8_6 + sumy[3] * sc8_7);
        }

        y4 += 4 * QK_K;
    }

    for (int row = 0; row < N_DST; ++row) {
        all_sum = subgroupAdd(sumf[row]);
        if (subgroupElect()) {
            out_[r1*pcs.ne0 + im*pcs.ne0*pcs.ne1 + first_row + row + pcs.outOff] = all_sum;
        }
    }
}
@@ -942,6 +942,36 @@ class tinyBLAS_Q0_AVX {
        return _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(x, 4)), _mm_set1_epi8(8));
    }

+   inline __m256i load(const block_q5_0 *b) {
+       return _mm256_or_si256(denibble(b->qs), bittobyte(b->qh));
+   }
+
+   inline __m128i load0(const block_q5_0* b) {
+       const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
+       uint32_t x32;
+       memcpy(&x32, b->qh, sizeof(uint32_t));
+       __m128i qxl = _mm_and_si128(_mm_set1_epi8(15), x);
+       __m128i bytesl = _mm_cmpeq_epi8(_mm_set1_epi64x(-1),
+                                       _mm_or_si128(_mm_set1_epi64x(0x7fbfdfeff7fbfdfe),
+                                                    _mm_shuffle_epi8(_mm_set1_epi32(x32),
+                                                                     _mm_set_epi64x(0x0101010101010101, 0x0000000000000000))));
+       bytesl = _mm_andnot_si128(bytesl, _mm_set1_epi8((char)0xF0));
+       return _mm_or_si128(qxl, bytesl);
+   }
+
+   inline __m128i load1(const block_q5_0* b) {
+       const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
+       uint32_t x32;
+       memcpy(&x32, b->qh, sizeof(uint32_t));
+       __m128i qxh = _mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(x, 4));
+       __m128i bytesh = _mm_cmpeq_epi8(_mm_set1_epi64x(-1),
+                                       _mm_or_si128(_mm_set1_epi64x(0x7fbfdfeff7fbfdfe),
+                                                    _mm_shuffle_epi8(_mm_set1_epi32(x32),
+                                                                     _mm_set_epi64x(0x0303030303030303, 0x0202020202020202))));
+       bytesh = _mm_andnot_si128(bytesh, _mm_set1_epi8((char)0xF0));
+       return _mm_or_si128(qxh, bytesh);
+   }
+
    inline __m256i load(const block_iq4_nl *b) {
        return MM256_SET_M128I(load1(b), load0(b));
    }

@@ -973,6 +1003,17 @@ class tinyBLAS_Q0_AVX {
                                                     _mm_srli_epi16(x, 4), 1));
    }

+   static inline __m256i bittobyte(const uint8_t *p) {
+       uint32_t x32;
+       memcpy(&x32, p, sizeof(uint32_t));
+       __m256i bytes = _mm256_cmpeq_epi8(_mm256_set1_epi64x(-1),
+                                         _mm256_or_si256(_mm256_set1_epi64x(0x7fbfdfeff7fbfdfe),
+                                                         _mm256_shuffle_epi8(_mm256_set1_epi32(x32),
+                                                                             _mm256_set_epi64x(0x0303030303030303, 0x0202020202020202,
+                                                                                               0x0101010101010101, 0x0000000000000000))));
+       return _mm256_andnot_si256(bytes, _mm256_set1_epi8((char)0xF0));
+   }
+
    const TA *const A;
    const TB *const B;
    TC *const C;

@@ -1182,6 +1223,22 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
#endif
    }

+   case GGML_TYPE_Q5_0: {
+       if (Btype != GGML_TYPE_Q8_0)
+           return false;
+#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
+       tinyBLAS_Q0_AVX<block_q5_0, block_q8_0, float> tb{
+           k, (const block_q5_0 *)A, lda,
+           (const block_q8_0 *)B, ldb,
+           (float *)C, ldc,
+           ith, nth};
+       tb.matmul(m, n);
+       return true;
+#else
+       return false;
+#endif
+   }
+
    case GGML_TYPE_IQ4_NL: {
        if (Btype != GGML_TYPE_Q8_0)
            return false;
ggml/src/vulkan-shaders/pool2d.comp (new file)
@@ -0,0 +1,74 @@
#version 450

#include "types.comp"

#extension GL_EXT_shader_16bit_storage : require

layout(push_constant) uniform parameter {
    uint IW; uint IH;
    uint OW; uint OH;
    uint OC;
    uint pelements;
    uint op;
    int k0; int k1;
    int s0; int s1;
    int p0; int p1;
} p;

#define BLOCK_SIZE 512
#define FLT_MAX 3.402823466e+38F
#define OP_POOL_MAX 0u
#define OP_POOL_AVG 1u

layout (local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;

layout(binding = 0) readonly buffer X {A_TYPE data_a[];};
layout(binding = 1) writeonly buffer D {D_TYPE data_d[];};

void main() {
    const uint idx = gl_GlobalInvocationID.x;
    if (idx >= p.pelements) {
        return;
    }

    const uint O_HW = p.OW * p.OH;

    const uint nc = idx / O_HW;
    const uint cur_oh = (idx % O_HW) / p.OW;
    const uint cur_ow = (idx % O_HW) % p.OW;

    const int start_h = int(cur_oh) * p.s0 - p.p0;
    const uint bh = max(start_h, 0);
    const uint eh = min(start_h + p.k0, p.IH);

    const int start_w = int(cur_ow) * p.s1 - p.p1;
    const uint bw = max(start_w, 0);
    const uint ew = min(start_w + p.k1, p.IW);

    const float scale = 1.0 / float(p.k0 * p.k1);
    float res;

    if (p.op == OP_POOL_AVG) {
        res = 0.0;
    } else if (p.op == OP_POOL_MAX) {
        res = -FLT_MAX;
    } else {
        return;
    }

    #pragma unroll
    for (uint i = bh; i < eh; i++) {
        #pragma unroll
        for (uint j = bw; j < ew; j++) {
            const float cur = D_TYPE(data_a[nc * p.IH * p.IW + i * p.IW + j]);

            if (p.op == OP_POOL_AVG) {
                res += cur * scale;
            } else if (p.op == OP_POOL_MAX) {
                res = max(res, cur);
            }
        }
    }

    data_d[nc * O_HW + cur_oh * p.OW + cur_ow] = res;
}
@@ -493,6 +493,10 @@ void process_shaders(std::vector<std::future<void>>& tasks) {
    tasks.push_back(std::async(std::launch::async, [=] {
        string_to_spv("timestep_embedding_f32", "timestep_embedding.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
    }));
+
+   tasks.push_back(std::async(std::launch::async, [=] {
+       string_to_spv("pool2d_f32", "pool2d.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
+   }));
}

void write_output_files() {
@@ -205,7 +205,7 @@ extern "C" {
    enum llama_split_mode {
        LLAMA_SPLIT_MODE_NONE  = 0, // single GPU
        LLAMA_SPLIT_MODE_LAYER = 1, // split layers and KV across GPUs
-       LLAMA_SPLIT_MODE_ROW   = 2, // split rows across GPUs
+       LLAMA_SPLIT_MODE_ROW   = 2, // split layers and KV across GPUs, use tensor parallelism if supported
    };

    // TODO: simplify (https://github.com/ggerganov/llama.cpp/pull/9294#pullrequestreview-2286561979)

@@ -217,6 +217,7 @@ extern "C" {

    typedef struct llama_token_data_array {
        // TODO: consider SoA
+       // NOTE: this pointer can be modified by the samplers
        llama_token_data * data;
        size_t size;
        int64_t selected; // this is the index in the data array (i.e. not the token id)

@@ -232,8 +233,11 @@ extern "C" {
    // - token  : the token ids of the input (used when embd is NULL)
    // - embd   : token embeddings (i.e. float vector of size n_embd) (used when token is NULL)
    // - pos    : the positions of the respective token in the sequence
+   //            (if set to NULL, the token position will be tracked automatically by llama_decode)
    // - seq_id : the sequence to which the respective token belongs
+   //            (if set to NULL, the sequence ID will be assumed to be 0)
    // - logits : if zero, the logits (and/or the embeddings) for the respective token will not be output
+   //            (if set to NULL, only the logits for last token will be returned)
    //
    typedef struct llama_batch {
        int32_t n_tokens;

@@ -244,15 +248,6 @@ extern "C" {
        int32_t * n_seq_id;
        llama_seq_id ** seq_id;
        int8_t * logits; // TODO: rename this to "output"
-
-       // NOTE: helpers for smooth API transition - can be deprecated in the future
-       // for future-proof code, use the above fields instead and ignore everything below
-       //
-       // pos[i] = all_pos_0 + i*all_pos_1
-       //
-       llama_pos    all_pos_0;  // used if pos == NULL
-       llama_pos    all_pos_1;  // used if pos == NULL
-       llama_seq_id all_seq_id; // used if seq_id == NULL
    } llama_batch;

    enum llama_model_kv_override_type {

@@ -279,10 +274,7 @@ extern "C" {
        int32_t n_gpu_layers; // number of layers to store in VRAM
        enum llama_split_mode split_mode; // how to split the model across multiple GPUs

-       // main_gpu interpretation depends on split_mode:
-       // LLAMA_SPLIT_MODE_NONE: the GPU that is used for the entire model
-       // LLAMA_SPLIT_MODE_ROW: the GPU that is used for small tensors and intermediate results
-       // LLAMA_SPLIT_MODE_LAYER: ignored
+       // the GPU that is used for the entire model when split_mode is LLAMA_SPLIT_MODE_NONE
        int32_t main_gpu;

        // proportion of the model (layers or rows) to offload to each GPU, size: llama_max_devices()

@@ -776,15 +768,15 @@ extern "C" {
    // Decoding
    //

-   // Return batch for single sequence of tokens starting at pos_0
+   // Return batch for single sequence of tokens
+   // The sequence ID will be fixed to 0
+   // The position of the tokens will be tracked automatically by llama_decode
    //
    // NOTE: this is a helper function to facilitate transition to the new batch API - avoid using it
    //
    LLAMA_API struct llama_batch llama_batch_get_one(
            llama_token * tokens,
-           int32_t n_tokens,
-           llama_pos pos_0,
-           llama_seq_id seq_id);
+           int32_t n_tokens);

    // Allocates a batch of tokens on the heap that can hold a maximum of n_tokens
    // Each token can be assigned up to n_seq_max sequence ids

@@ -1075,12 +1067,13 @@ extern "C" {

    // available samplers:

-   LLAMA_API struct llama_sampler * llama_sampler_init_greedy (void);
+   LLAMA_API struct llama_sampler * llama_sampler_init_greedy(void);
    LLAMA_API struct llama_sampler * llama_sampler_init_dist (uint32_t seed);

    /// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.
    /// NOTE: Avoid using on the full vocabulary as the sorting can become slow. For example, apply top-k or top-p sampling first.
-   LLAMA_API struct llama_sampler * llama_sampler_init_softmax (void);
+   DEPRECATED(LLAMA_API struct llama_sampler * llama_sampler_init_softmax (void),
+       "will be removed in the future (see https://github.com/ggerganov/llama.cpp/pull/9896#discussion_r1800920915)");

    /// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
    LLAMA_API struct llama_sampler * llama_sampler_init_top_k (int32_t k);

@@ -1091,16 +1084,18 @@ extern "C" {
    /// @details Minimum P sampling as described in https://github.com/ggerganov/llama.cpp/pull/3841
    LLAMA_API struct llama_sampler * llama_sampler_init_min_p (float p, size_t min_keep);

-   /// @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
-   LLAMA_API struct llama_sampler * llama_sampler_init_tail_free (float z, size_t min_keep);
-
    /// @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666.
    LLAMA_API struct llama_sampler * llama_sampler_init_typical (float p, size_t min_keep);

+   /// #details Updates the logits l_i` = l_i/t. When t <= 0.0f, the maximum logit is kept at it's original value, the rest are set to -inf
    LLAMA_API struct llama_sampler * llama_sampler_init_temp (float t);

    /// @details Dynamic temperature implementation (a.k.a. entropy) described in the paper https://arxiv.org/abs/2309.02772.
    LLAMA_API struct llama_sampler * llama_sampler_init_temp_ext (float t, float delta, float exponent);

+   /// @details XTC sampler as described in https://github.com/oobabooga/text-generation-webui/pull/6335
+   LLAMA_API struct llama_sampler * llama_sampler_init_xtc (float p, float t, size_t min_keep, uint32_t seed);
+
    /// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
    /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
    /// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.

@@ -1140,11 +1135,43 @@ extern "C" {
            bool penalize_nl, // consider newlines as a repeatable token
            bool ignore_eos); // ignore the end-of-sequence token

+   /// @details DRY sampler, designed by p-e-w, as described in: https://github.com/oobabooga/text-generation-webui/pull/5677, porting Koboldcpp implementation authored by pi6am: https://github.com/LostRuins/koboldcpp/pull/982
+   LLAMA_API struct llama_sampler * llama_sampler_init_dry(
+           const struct llama_model * model,
+           float dry_multiplier,
+           float dry_base,
+           int32_t dry_allowed_length,
+           int32_t dry_penalty_last_n,
+           const char ** seq_breakers,
+           size_t num_breakers);
+
    LLAMA_API struct llama_sampler * llama_sampler_init_logit_bias(
            int32_t n_vocab,
            int32_t n_logit_bias,
            const llama_logit_bias * logit_bias);

+   // this sampler is meant to be used for fill-in-the-middle infilling
+   // it's supposed to be used after top_k + top_p sampling
+   //
+   // 1. if the sum of the EOG probs times the number of candidates is higher than the sum of the other probs -> pick EOG
+   // 2. combine probs of tokens that have the same prefix
+   //
+   // example:
+   //
+   // - before:
+   //   "hel":   0.5
+   //   "hell":  0.2
+   //   "hello": 0.1
+   //   "dummy": 0.1
+   //
+   // - after:
+   //   "hel":   0.8
+   //   "dummy": 0.1
+   //
+   // 3. discard non-EOG tokens with low prob
+   // 4. if no tokens are left -> pick EOT
+   //
+   LLAMA_API struct llama_sampler * llama_sampler_init_infill(const struct llama_model * model);
+
    // Returns the seed used by the sampler if applicable, LLAMA_DEFAULT_SEED otherwise
    LLAMA_API uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl);
@@ -20,7 +20,7 @@ logger = logging.getLogger("compare-llama-bench")
# Properties by which to differentiate results per commit:
KEY_PROPERTIES = [
    "cpu_info", "gpu_info", "n_gpu_layers", "cuda", "vulkan", "kompute", "metal", "sycl", "rpc", "gpu_blas",
-   "blas", "model_filename", "model_type", "model_size", "model_n_params", "n_batch", "n_ubatch", "embeddings", "n_threads",
+   "blas", "model_filename", "model_type", "n_batch", "n_ubatch", "embeddings", "n_threads",
    "type_k", "type_v", "use_mmap", "no_kv_offload", "split_mode", "main_gpu", "tensor_split", "flash_attn", "n_prompt", "n_gen"
]

@@ -15,22 +15,22 @@ CLI_ARGS_LLAMA_CLI_PERPLEXITY = [
    "export", "file", "frequency-penalty", "grammar", "grammar-file", "hellaswag",
    "hellaswag-tasks", "ignore-eos", "in-prefix", "in-prefix-bos", "in-suffix",
    "interactive", "interactive-first", "keep", "logdir", "logit-bias", "lora", "lora-base",
-   "low-vram", "main-gpu", "memory-f32", "mirostat", "mirostat-ent", "mirostat-lr", "mlock",
+   "low-vram", "main-gpu", "mirostat", "mirostat-ent", "mirostat-lr", "mlock",
    "model", "multiline-input", "n-gpu-layers", "n-predict", "no-mmap", "no-mul-mat-q",
    "np-penalize-nl", "numa", "ppl-output-type", "ppl-stride", "presence-penalty", "prompt",
    "prompt-cache", "prompt-cache-all", "prompt-cache-ro", "repeat-last-n",
    "repeat-penalty", "reverse-prompt", "rope-freq-base", "rope-freq-scale", "rope-scale", "seed",
-   "simple-io", "tensor-split", "threads", "temp", "tfs", "top-k", "top-p", "typical",
+   "simple-io", "tensor-split", "threads", "temp", "top-k", "top-p", "typical",
    "verbose-prompt"
]

CLI_ARGS_LLAMA_BENCH = [
-   "batch-size", "memory-f32", "low-vram", "model", "mul-mat-q", "n-gen", "n-gpu-layers",
+   "batch-size", "low-vram", "model", "mul-mat-q", "n-gen", "n-gpu-layers",
    "n-prompt", "output", "repetitions", "tensor-split", "threads", "verbose"
]

CLI_ARGS_LLAMA_SERVER = [
-   "alias", "batch-size", "ctx-size", "embedding", "host", "memory-f32", "lora", "lora-base",
+   "alias", "batch-size", "ctx-size", "embedding", "host", "lora", "lora-base",
    "low-vram", "main-gpu", "mlock", "model", "n-gpu-layers", "n-probs", "no-mmap", "no-mul-mat-q",
    "numa", "path", "port", "rope-freq-base", "timeout", "rope-freq-scale", "tensor-split",
    "threads", "verbose"
@@ -76,6 +76,7 @@ while read c; do
        src/ggml*.m \
        src/ggml*.metal \
        src/ggml*.cu \
+       src/ggml-amx/* \
        src/ggml-cann/* \
        src/ggml-cuda/* \
        src/ggml-sycl/* \

@@ -121,6 +122,8 @@ if [ -f $SRC_LLAMA/ggml-src.patch ]; then
    # src/ggml-aarch64.c -> ggml/src/ggml-aarch64.c
    # src/ggml-aarch64.h -> ggml/src/ggml-aarch64.h
    # src/ggml-alloc.c -> ggml/src/ggml-alloc.c
+   # src/ggml-amx/* -> ggml/src/ggml-amx/
+   # src/ggml-amx.cpp -> ggml/src/ggml-amx.cpp
    # src/ggml-backend-impl.h -> ggml/src/ggml-backend-impl.h
    # src/ggml-backend.cpp -> ggml/src/ggml-backend.cpp
    # src/ggml-cann/* -> ggml/src/ggml-cann/

@@ -141,6 +144,7 @@ if [ -f $SRC_LLAMA/ggml-src.patch ]; then
    #
    # include/ggml.h -> ggml/include/ggml.h
    # include/ggml-alloc.h -> ggml/include/ggml-alloc.h
+   # include/ggml-amx.h -> ggml/include/ggml-amx.h
    # include/ggml-backend.h -> ggml/include/ggml-backend.h
    # include/ggml-blas.h -> ggml/include/ggml-blas.h
    # include/ggml-cann.h -> ggml/include/ggml-cann.h

@@ -168,6 +172,8 @@ if [ -f $SRC_LLAMA/ggml-src.patch ]; then
        -e 's/([[:space:]]|[ab]\/)src\/ggml-aarch64\.c/\1ggml\/src\/ggml-aarch64.c/g' \
        -e 's/([[:space:]]|[ab]\/)src\/ggml-aarch64\.h/\1ggml\/src\/ggml-aarch64.h/g' \
        -e 's/([[:space:]]|[ab]\/)src\/ggml-alloc\.c/\1ggml\/src\/ggml-alloc.c/g' \
+       -e 's/([[:space:]]|[ab]\/)src\/ggml-amx\//\1ggml\/src\/ggml-amx\//g' \
+       -e 's/([[:space:]]|[ab]\/)src\/ggml-amx\.cpp/\1ggml\/src\/ggml-amx.cpp/g' \
        -e 's/([[:space:]]|[ab]\/)src\/ggml-backend-impl\.h/\1ggml\/src\/ggml-backend-impl.h/g' \
        -e 's/([[:space:]]|[ab]\/)src\/ggml-backend\.cpp/\1ggml\/src\/ggml-backend.cpp/g' \
        -e 's/([[:space:]]|[ab]\/)src\/ggml-cann\//\1ggml\/src\/ggml-cann\//g' \

@@ -187,6 +193,7 @@ if [ -f $SRC_LLAMA/ggml-src.patch ]; then
        -e 's/([[:space:]]|[ab]\/)src\/vulkan-shaders\//\1ggml\/src\/vulkan-shaders\//g' \
        -e 's/([[:space:]]|[ab]\/)include\/ggml\.h/\1ggml\/include\/ggml.h/g' \
        -e 's/([[:space:]]|[ab]\/)include\/ggml-alloc\.h/\1ggml\/include\/ggml-alloc.h/g' \
+       -e 's/([[:space:]]|[ab]\/)include\/ggml-amx\.h/\1ggml\/include\/ggml-amx.h/g' \
        -e 's/([[:space:]]|[ab]\/)include\/ggml-backend\.h/\1ggml\/include\/ggml-backend.h/g' \
        -e 's/([[:space:]]|[ab]\/)include\/ggml-blas\.h/\1ggml\/include\/ggml-blas.h/g' \
        -e 's/([[:space:]]|[ab]\/)include\/ggml-cann\.h/\1ggml\/include\/ggml-cann.h/g' \

@@ -1 +1 @@
-564f42082f858f9674b2a2e06e9e779d9ed2c754
+bb78a40dc60e04c626bac2b65840b509988e990d

@@ -8,6 +8,8 @@ cp -rpv ../ggml/src/ggml.c ./ggml/src/ggml.c
cp -rpv ../ggml/src/ggml-aarch64.c ./ggml/src/ggml-aarch64.c
cp -rpv ../ggml/src/ggml-aarch64.h ./ggml/src/ggml-aarch64.h
cp -rpv ../ggml/src/ggml-alloc.c ./ggml/src/ggml-alloc.c
+cp -rpv ../ggml/src/ggml-amx/* ./ggml/src/ggml-amx/
+cp -rpv ../ggml/src/ggml-amx.cpp ./ggml/src/ggml-amx.cpp
cp -rpv ../ggml/src/ggml-backend-impl.h ./ggml/src/ggml-backend-impl.h
cp -rpv ../ggml/src/ggml-backend.cpp ./ggml/src/ggml-backend.cpp
cp -rpv ../ggml/src/ggml-cann/* ./ggml/src/ggml-cann/

@@ -29,6 +31,7 @@ cp -rpv ../ggml/src/vulkan-shaders/* ./ggml/src/vulkan-shaders/

cp -rpv ../ggml/include/ggml.h ./ggml/include/ggml.h
cp -rpv ../ggml/include/ggml-alloc.h ./ggml/include/ggml-alloc.h
+cp -rpv ../ggml/include/ggml-amx.h ./ggml/include/ggml-amx.h
cp -rpv ../ggml/include/ggml-backend.h ./ggml/include/ggml-backend.h
cp -rpv ../ggml/include/ggml-blas.h ./ggml/include/ggml-blas.h
cp -rpv ../ggml/include/ggml-cann.h ./ggml/include/ggml-cann.h
@@ -63,6 +63,30 @@ static void llama_log_softmax(float * array, size_t size) {
}
*/

+static void llama_sampler_temp_impl(llama_token_data_array * cur_p, float temp) {
+    if (temp <= 0.0f) {
+        // find the token with the highest logit and set the rest to -inf
+        size_t max_i = 0;
+        float max_l = cur_p->data[0].logit;
+
+        for (size_t i = 1; i < cur_p->size; ++i) {
+            if (cur_p->data[i ].logit > max_l) {
+                cur_p->data[max_i].logit = -INFINITY;
+                max_i = i;
+                max_l = cur_p->data[i].logit;
+            } else {
+                cur_p->data[i].logit = -INFINITY;
+            }
+        }
+
+        return;
+    }
+
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        cur_p->data[i].logit /= temp;
+    }
+}
+
static void llama_sampler_softmax_impl(llama_token_data_array * cur_p) {
    GGML_ASSERT(cur_p->size > 0);

@@ -89,7 +113,7 @@ static void llama_sampler_softmax_impl(llama_token_data_array * cur_p) {
}

static void llama_sampler_top_k_impl(llama_token_data_array * cur_p, int32_t k) {
-   // TODO: move bucket sort to separate function so that top_p/tail_free/typical/softmax first is equally fast
+   // TODO: move bucket sort to separate function so that top_p/typical/softmax first is equally fast
    // if (k >= (int32_t)cur_p->size) {
    //     return;
    // }

@@ -427,6 +451,9 @@ static const char * llama_sampler_dist_name(const struct llama_sampler * /*smpl*

static void llama_sampler_dist_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
    auto * ctx = (llama_sampler_dist *) smpl->ctx;
+
+   llama_sampler_softmax_impl(cur_p);
+
    cur_p->selected = llama_sample_dist(cur_p, ctx->rng);
}

@@ -706,101 +733,6 @@ struct llama_sampler * llama_sampler_init_min_p(float p, size_t min_keep) {
    };
}

-// tail-free
-
-struct llama_sampler_tail_free {
-    const float z;
-    const size_t min_keep;
-};
-
-static const char * llama_sampler_tail_free_name(const struct llama_sampler * /*smpl*/) {
-    return "tail-free";
-}
-
-static void llama_sampler_tail_free_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
-    const auto * ctx = (llama_sampler_tail_free *) smpl->ctx;
-
-    if (ctx->z >= 1.0f || cur_p->size <= 2) {
-        return;
-    }
-
-    llama_sampler_softmax_impl(cur_p);
-
-    // Compute the first and second derivatives
-    std::vector<float> first_derivatives(cur_p->size - 1);
-    std::vector<float> second_derivatives(cur_p->size - 2);
-
-    for (size_t i = 0; i < first_derivatives.size(); ++i) {
-        first_derivatives[i] = cur_p->data[i].p - cur_p->data[i + 1].p;
-    }
-    for (size_t i = 0; i < second_derivatives.size(); ++i) {
-        second_derivatives[i] = first_derivatives[i] - first_derivatives[i + 1];
-    }
-
-    // Calculate absolute value of second derivatives
-    for (size_t i = 0; i < second_derivatives.size(); ++i) {
-        second_derivatives[i] = std::abs(second_derivatives[i]);
-    }
-
-    // Normalize the second derivatives
-    {
-        const float second_derivatives_sum = std::accumulate(second_derivatives.begin(), second_derivatives.end(), 0.0f);
-
-        if (second_derivatives_sum > 1e-6f) {
-            for (float & value : second_derivatives) {
-                value /= second_derivatives_sum;
-            }
-        } else {
-            for (float & value : second_derivatives) {
-                value = 1.0f / second_derivatives.size();
-            }
-        }
-    }
-
-    float cum_sum = 0.0f;
-    size_t last_idx = cur_p->size;
-    for (size_t i = 0; i < second_derivatives.size(); ++i) {
-        cum_sum += second_derivatives[i];
-
-        // Check if the running sum is greater than z or if we have kept at least min_keep tokens
-        if (cum_sum > ctx->z && i >= ctx->min_keep) {
-            last_idx = i;
-            break;
-        }
-    }
-
-    // Resize the output vector to keep only the tokens above the tail location
-    cur_p->size = last_idx;
-}
-
-static struct llama_sampler * llama_sampler_tail_free_clone(const struct llama_sampler * smpl) {
-    const auto * ctx = (const llama_sampler_tail_free *) smpl->ctx;
-    return llama_sampler_init_tail_free(ctx->z, ctx->min_keep);
-}
-
-static void llama_sampler_tail_free_free(struct llama_sampler * smpl) {
-    delete (llama_sampler_tail_free *) smpl->ctx;
-}
-
-static struct llama_sampler_i llama_sampler_tail_free_i = {
-    /* .name   = */ llama_sampler_tail_free_name,
-    /* .accept = */ nullptr,
-    /* .apply  = */ llama_sampler_tail_free_apply,
-    /* .reset  = */ nullptr,
-    /* .clone  = */ llama_sampler_tail_free_clone,
-    /* .free   = */ llama_sampler_tail_free_free,
-};
-
-struct llama_sampler * llama_sampler_init_tail_free(float z, size_t min_keep) {
-    return new llama_sampler {
-        /* .iface = */ &llama_sampler_tail_free_i,
-        /* .ctx   = */ new llama_sampler_tail_free {
-            /* .z        = */ z,
-            /*. min_keep = */ min_keep,
-        },
-    };
-}
-
// typical

struct llama_sampler_typical {
@@ -912,9 +844,8 @@ static const char * llama_sampler_temp_name(const struct llama_sampler * /*smpl*

static void llama_sampler_temp_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
    const auto * ctx = (llama_sampler_temp *) smpl->ctx;
-   for (size_t i = 0; i < cur_p->size; ++i) {
-       cur_p->data[i].logit /= ctx->temp;
-   }
+
+   llama_sampler_temp_impl(cur_p, ctx->temp);
}

static struct llama_sampler * llama_sampler_temp_clone(const struct llama_sampler * smpl) {

@@ -961,6 +892,7 @@ static void llama_sampler_temp_ext_apply(struct llama_sampler * smpl, llama_toke
    if (ctx->delta > 0) {
        const float min_temp = std::max(0.0f, ctx->temp - ctx->delta);
        const float max_temp = ctx->temp + ctx->delta;
+
        float exponent_val = ctx->exponent;

        // no need to do anything if there is only one (or zero) candidates

@@ -998,9 +930,7 @@ static void llama_sampler_temp_ext_apply(struct llama_sampler * smpl, llama_toke
#endif

        // Apply the dynamically calculated temperature scaling
-       for (size_t i = 0; i < cur_p->size; ++i) {
-           cur_p->data[i].logit /= dyn_temp;
-       }
+       llama_sampler_temp_impl(cur_p, dyn_temp);

        // Re-compute softmax probabilities after scaling logits with dynamic temperature
        const double max_l_double = cur_p->data[0].logit;

@@ -1024,9 +954,7 @@ static void llama_sampler_temp_ext_apply(struct llama_sampler * smpl, llama_toke
        }
#endif
    } else {
-       for (size_t i = 0; i < cur_p->size; ++i) {
-           cur_p->data[i].logit /= ctx->temp;
-       }
+       llama_sampler_temp_impl(cur_p, ctx->temp);
    }
}

@@ -1059,6 +987,101 @@ struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, floa
    };
}

+// xtc
+
+struct llama_sampler_xtc {
+    const float probability;
+    const float threshold;
+    const size_t min_keep;
+
+    const uint32_t seed;
+    uint32_t seed_cur;
+
+    std::mt19937 rng;
+};
+
+static const char * llama_sampler_xtc_name(const struct llama_sampler * /*smpl*/) {
+    return "xtc";
+}
+
+static void llama_sample_xtc_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+    auto * ctx = (llama_sampler_xtc *) smpl->ctx;
+
+    if (ctx->probability <= 0.0f
+        || ctx->threshold > 0.5f
+        || cur_p->size < 2) {
+        return;
+    }
+
+    std::uniform_real_distribution<float> distribution(0.0f, 1.0f);
+    float chance = distribution(ctx->rng);
+    if (chance > ctx->probability) return;
+
+    // in case it's not sorted/recalculated yet
+    llama_sampler_softmax_impl(cur_p);
+
+    int pos_last = 0;
+
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        if (cur_p->data[i].p >= ctx->threshold) {
+            pos_last = i;
+        } else break;
+    }
+
+    if (cur_p->size - pos_last >= ctx->min_keep && pos_last > 0) {
+        cur_p->data += pos_last;
+        cur_p->size -= pos_last;
+    }
+}
+
+static struct llama_sampler * llama_sampler_xtc_clone(const struct llama_sampler * smpl) {
+    const auto * ctx = (const llama_sampler_xtc *) smpl->ctx;
+    auto * result = llama_sampler_init_xtc(ctx->probability, ctx->threshold, ctx->min_keep, ctx->seed);
+
+    // copy the state
+    {
+        auto * result_ctx = (llama_sampler_xtc *) result->ctx;
+
+        result_ctx->rng = ctx->rng;
+    }
+
+    return result;
+}
+
+static void llama_sampler_xtc_free(struct llama_sampler * smpl) {
+    delete (llama_sampler_xtc *) smpl->ctx;
+}
+
+static void llama_sampler_xtc_reset(struct llama_sampler * smpl) {
+    auto * ctx = (llama_sampler_xtc *) smpl->ctx;
+    ctx->seed_cur = get_rng_seed(ctx->seed);
+    ctx->rng.seed(ctx->seed_cur);
+}
+
+static struct llama_sampler_i llama_sampler_xtc_i = {
+    /* .name   = */ llama_sampler_xtc_name,
+    /* .accept = */ nullptr,
+    /* .apply  = */ llama_sample_xtc_apply,
+    /* .reset  = */ llama_sampler_xtc_reset,
+    /* .clone  = */ llama_sampler_xtc_clone,
+    /* .free   = */ llama_sampler_xtc_free,
+};
+
+struct llama_sampler * llama_sampler_init_xtc(float p, float t, size_t min_keep, uint32_t seed) {
+    auto seed_cur = get_rng_seed(seed);
+    return new llama_sampler {
+        /* .iface = */ &llama_sampler_xtc_i,
+        /* .ctx   = */ new llama_sampler_xtc {
+            /* .probability = */ p,
+            /* .threshold   = */ t,
+            /* .min_keep    = */ min_keep,
+            /* .seed        = */ seed,
+            /* .seed_cur    = */ seed_cur,
+            /* .rng         = */ std::mt19937(seed_cur),
+        },
+    };
+}
+
// mirostat

struct llama_sampler_mirostat {
@ -1565,6 +1588,397 @@ struct llama_sampler * llama_sampler_init_penalties(
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DRY
|
||||||
|
|
||||||
|
struct llama_sampler_dry {
|
||||||
|
int32_t total_context_size;
|
||||||
|
|
||||||
|
const float dry_multiplier;
|
||||||
|
const float dry_base;
|
||||||
|
const int32_t dry_allowed_length;
|
||||||
|
const int32_t dry_penalty_last_n;
|
||||||
|
|
||||||
|
std::unordered_multimap<llama_token, std::vector<llama_token>> dry_processed_breakers;
|
||||||
|
std::vector<int> dry_repeat_count;
|
||||||
|
std::unordered_map<llama_token, int> dry_max_token_repeat;
|
||||||
|
ring_buffer<llama_token> last_tokens;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Ported from Koboldcpp, original PR: https://github.com/LostRuins/koboldcpp/pull/982 (Original author: pi6am)
|
||||||
|
static void get_overlapping_token_sequences(const llama_vocab & vocab, const std::string& str, std::unordered_multimap<llama_token, std::vector<llama_token>>& token_sequences, int max_tail_len = -1) {
|
||||||
|
for (llama_token token_id = 0; token_id < (llama_token)vocab.n_vocab; token_id++) {
|
||||||
|
std::string word = llama_detokenize(vocab, {token_id}, true);
|
||||||
|
if (word.find(str) != std::string::npos) {
|
||||||
|
token_sequences.emplace(token_id, std::vector<llama_token>());
|
||||||
|
} else {
|
||||||
|
size_t word_len = word.size(), str_len = str.size();
|
||||||
|
size_t pos = -1;
|
||||||
|
while ((pos = word.find(str[0], pos + 1)) != std::string::npos) {
|
||||||
|
bool match = true;
|
||||||
|
size_t i;
|
||||||
|
for (i = 1; i < str_len && i + pos < word_len; ++i) {
|
||||||
|
if (word[pos + i] != str[i]) {
|
||||||
|
match = false;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (match) {
|
||||||
|
std::vector<llama_token> tokenization = llama_tokenize_internal(vocab, str.substr(i), false, false);
|
||||||
|
if (max_tail_len >= 0 && tokenization.size() > (size_t)max_tail_len) {
|
||||||
|
tokenization.resize(max_tail_len);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure we don't already have a duplicate matching tokenization
|
||||||
|
auto its = token_sequences.equal_range(token_id);
|
||||||
|
bool found = false;
|
||||||
|
for (auto it = its.first; it != its.second; ++it) {
|
||||||
|
if (tokenization == it->second) {
|
||||||
|
found = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (!found) {
|
||||||
|
token_sequences.emplace(token_id, tokenization);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static const char * llama_sampler_dry_name(const struct llama_sampler * /*smpl*/) {
|
||||||
|
return "dry";
|
||||||
|
}
|
||||||
|
|
||||||
|
static void llama_sampler_dry_accept(struct llama_sampler * smpl, llama_token token) {
|
||||||
|
auto * ctx = (llama_sampler_dry *) smpl->ctx;
|
||||||
|
if (ctx->dry_multiplier == 0.0f || ctx->dry_base < 1.0f || ctx->dry_penalty_last_n == 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx->last_tokens.push_back(token);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ported from Koboldcpp, original PR: https://github.com/LostRuins/koboldcpp/pull/982 (Original author: pi6am)
|
||||||
|
static void llama_sampler_dry_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
||||||
|
auto * ctx = (llama_sampler_dry *) smpl->ctx;
|
||||||
|
|
||||||
|
if (ctx->dry_multiplier == 0.0f || ctx->dry_base < 1.0f || ctx->dry_penalty_last_n == 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
int32_t effective_dry_penalty_last_n = (ctx->dry_penalty_last_n == -1) ? ctx->total_context_size : std::max(ctx->dry_penalty_last_n, 0);
|
||||||
|
int last_n_repeat = std::min(std::min((int)ctx->last_tokens.size(), effective_dry_penalty_last_n), ctx->total_context_size);
|
||||||
|
|
||||||
|
if (last_n_repeat <= ctx->dry_allowed_length) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx->dry_repeat_count.assign(last_n_repeat, 0);
|
||||||
|
ctx->dry_max_token_repeat.clear();
|
||||||
|
|
||||||
|
// Step 1: Look for restart sequences to limit the maximum repetition length.
|
||||||
|
// Work backwards through the context looking for any token that begins a restart sequence.
|
||||||
|
//
|
||||||
|
// The collection `restart_sequences` is a mapping from a "head" token to all "tail"
|
||||||
|
// sequences that together comprise a restart sequence. This allows us to quickly check
|
||||||
|
// whether each token is the head of a complete sequence. Most restart sequences are actually
|
||||||
|
// a single token, and for these the "tail" is an empty vector.
|
||||||
|
//
|
||||||
|
// If the token is a "head", test all restart sequences that begin with this token
|
||||||
|
// (there will often only be one sequence for each token, but if sequences like 'aaaq1' and
|
||||||
|
// 'aaa1' are used as restart strings, both could start with 'aaa' when tokenized). The
|
||||||
|
// longest matching sequence (if any) is used to limit the maximum repetition length.
|
||||||
|
//
|
||||||
|
// Note that in the case case of a short sequence contained in a longer one, this might fail to
|
||||||
|
// find the smallest value for `rep_limit`. For example, if 'amniotic' and 'ni' are both used as
|
||||||
|
// restart sequences, 'ni' will be found first, and since it's shorter it will fail to suppress
|
||||||
|
// 'otic'. This is a minor issue since fully contained restart sequences are likely to be rare.
|
||||||
|
//
|
||||||
|
// This is theoretically worst-case O(N^2) for arbitrary restart sequences, which is why we
|
||||||
|
// have already clamped the maximum tail sequence length when generating `restart_sequences`.
|
||||||
|
// With clamping, this scan is O(N) in the context length.
|
||||||
|
|
||||||
|
int rep_limit = last_n_repeat;
|
||||||
|
for (int i = 0; i < last_n_repeat; ++i) {
|
||||||
|
llama_token token = ctx->last_tokens.rat(i);
|
||||||
|
auto its = ctx->dry_processed_breakers.equal_range(token);
|
||||||
|
if (its.first == ctx->dry_processed_breakers.end()) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
int longest_match = -1;
|
||||||
|
for (auto it = its.first; it != its.second; ++it) {
|
||||||
|
// Note that (*it) does not contain the head character, so seq_len will be
|
||||||
|
// the restart sequence length minus 1.
|
||||||
|
// In the common case of a single-token restart sequence, (*it) will be empty
|
||||||
|
// and we will trivially match.
|
||||||
|
int seq_len = (int)it->second.size();
|
||||||
|
if (seq_len > longest_match && seq_len <= (int)i) {
|
||||||
|
bool match = true;
|
||||||
|
for (int offset = 0; offset < seq_len; ++offset) {
|
||||||
|
// The -1 when indexing `last_tokens` is because we already matched the head.
|
||||||
|
if (it->second[offset] != ctx->last_tokens.rat(i - offset - 1)) {
|
||||||
|
match = false;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (match) {
|
||||||
|
longest_match = seq_len;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (longest_match >= 0) {
|
||||||
|
// We found a restart sequence starting `i` tokens from the end and continuing for
|
||||||
|
// `longest_match` tokens.
|
||||||
|
rep_limit = i - longest_match;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (rep_limit < ctx->dry_allowed_length) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
    // Step 2: Iterate in reverse over the last N tokens of the context, using the "Z-algorithm" (in
    // the reverse direction) to efficiently compute the positions and lengths of suffixes appearing
    // elsewhere in the context. We limit the suffix length to `rep_limit` to respect restart sequences.
    //
    // This algorithm is not currently documented on Wikipedia, but there is a clear description here:
    //   https://ivanyu.me/blog/2014/10/15/z-algorithm/
    //
    // The code below is adapted from the public domain implementation by the same author here:
    //   https://github.com/ivanyu/string-algorithms/blob/master/z_algorithm.py
    //
    // Example:
    //   Last N tokens: a b c c b c y a b c
    //   Repeat counts: 0 0 3 1 0 2 0 0 0 0
    //                      ^
    //   This `3` means that the last three tokens of the context (a b c) also appear here.
    //
    // This step is worst case O(N) since the Z-algorithm is linear, despite the appearance of nested
    // for/while loops. This can be seen by observing that the `lt` and `rt` bounds are set after each
    // repeated suffix is detected (i.e. after each while loop when n > 0). These bound variables
    // ensure that the inner while loops only examine each token in the context once as the outer
    // for loop iterates over the context.

    {
        const int last = last_n_repeat - 1;
        int rt = 0, lt = 0;

        for (int k = 1; k < last_n_repeat; ++k) {
            if (k > rt) {
                // If k is outside the current Z-box, do naive computation.
                int n = 0;
                while (n + k < last_n_repeat && ctx->last_tokens.rat(n) == ctx->last_tokens.rat(n+k)) {
                    ++n;
                }
                ctx->dry_repeat_count[last - k] = std::min(n, rep_limit);
                if (n > 0) {
                    lt = k;
                    rt = k+n-1;
                }
            } else {
                // If k is inside the current Z-box, consider two cases.

                int p = k - lt; // Pair index.
                int right_part_len = rt - k + 1;

                if (ctx->dry_repeat_count[last - p] < right_part_len) {
                    int n = std::min(ctx->dry_repeat_count[last - p], rep_limit);
                    ctx->dry_repeat_count[last - k] = n;
                } else {
                    int i = rt + 1;
                    while (i < last_n_repeat && ctx->last_tokens.rat(i) == ctx->last_tokens.rat(i - k)) {
                        i += 1;
                    }

                    int n = std::min(i - k, rep_limit);
                    ctx->dry_repeat_count[last - k] = n;
                    lt = k;
                    rt = i - 1;
                }
            }
        }
    }

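For reference, the per-position values produced by this reverse Z-pass can also be computed naively in O(N^2). The sketch below is not part of the diff; the token strings and the expected output are taken from the example in the comments above, and the real sampler additionally clamps each count to `rep_limit`.

// Naive O(N^2) reference for the quantity computed above:
// repeat_count[i] = length of the longest common suffix of tokens[0..i]
// and the whole sequence tokens[0..N-1]; the final position is left at 0.
#include <cstdio>
#include <string>
#include <vector>

int main() {
    const std::vector<std::string> tokens = {"a","b","c","c","b","c","y","a","b","c"};
    const int n = (int) tokens.size();
    std::vector<int> repeat_count(n, 0);
    for (int i = 0; i + 1 < n; ++i) {
        int len = 0;
        while (len <= i && tokens[i - len] == tokens[n - 1 - len]) {
            ++len;
        }
        repeat_count[i] = len;
    }
    for (int v : repeat_count) {
        printf("%d ", v);  // expected: 0 0 3 1 0 2 0 0 0 0
    }
    printf("\n");
    return 0;
}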
    // Step 3: Iterate over dry_repeat_count and last_tokens, examining the maximum repeat length
    // that would be generated by emitting each new token that would extend a sequence.
    //
    // Following the same example as above:
    //   Last N tokens: a b c c b c y a b c
    //   Repeat counts: 0 0 3 1 0 2 0 0 0 0
    //
    // For each non-zero, look ahead one token. This token, if emitted, would extend the repetition.
    //   c: 3 -> 4 (from `a b c` to `a b c c`)
    //   b: 1 -> 2 (from `c` to `c b`)
    //   y: 2 -> 3 (from `b c` to `b c y`)

    for (int i = 0; i < last_n_repeat - 1; ++i) {
        int repeat_len = ctx->dry_repeat_count[i];
        if (repeat_len >= ctx->dry_allowed_length) {
            // This token ends a repeat, so the next token would continue one.
            // By convention, the value of `repeat_len` only includes the tokens currently
            // in the context, not the new token that would be added.
            llama_token token = ctx->last_tokens.rat(last_n_repeat - 2 - i);
            // Track the maximum sequence ending in this token.
            const auto& it = ctx->dry_max_token_repeat.find(token);
            if (it == ctx->dry_max_token_repeat.end() || it->second < repeat_len) {
                ctx->dry_max_token_repeat[token] = repeat_len;
            }
        }
    }

    // Step 4: Apply logit penalties based on the maximum repeat length for relevant tokens.

    // Prevent floating point overflow in `pow(penalty_base, exponent)` by clamping to `max_exponent`.
    // Compute it from `penalty_base` and the approximate log of `std::numeric_limits<float>::max()`
    const float FLOAT_MAX_LOG = 88.7228391f;
    int max_exponent = 0;
    if (ctx->dry_base > 1.000001f) {
        max_exponent = FLOAT_MAX_LOG / std::log(ctx->dry_base);
    }

    for (size_t i = 0; i < cur_p->size; ++i) {
        const auto& af_kvp = ctx->dry_max_token_repeat.find(cur_p->data[i].id);
        if (af_kvp != ctx->dry_max_token_repeat.end()) {
            // Check all sequence breakers starting with this token
            auto range = ctx->dry_processed_breakers.equal_range(cur_p->data[i].id);
            bool is_single_token_breaker = false;

            for (auto it = range.first; it != range.second; ++it) {
                if (it->second.empty()) {
                    is_single_token_breaker = true;
                    break;
                }
            }

            // Apply penalty only if it's not a single-token sequence breaker
            if (!is_single_token_breaker) {
                int repeat_exp = af_kvp->second - ctx->dry_allowed_length;
                if (max_exponent > 0 && repeat_exp > max_exponent) {
                    repeat_exp = max_exponent;
                }
                float penalty = ctx->dry_multiplier * std::pow(ctx->dry_base, repeat_exp);
                cur_p->data[i].logit -= penalty;
            }
        }
    }

    cur_p->sorted = false;
}
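To make the effect of Step 4 concrete, here is a small standalone sketch of the penalty curve (not part of the diff; the parameter values are illustrative, not defaults taken from this change). The formula is penalty = dry_multiplier * dry_base^(repeat_len - dry_allowed_length), so every additional repeated token multiplies the penalty by dry_base.

#include <cmath>
#include <cstdio>

int main() {
    const float dry_multiplier     = 0.8f;   // illustrative values, not defaults from this diff
    const float dry_base           = 1.75f;
    const int   dry_allowed_length = 2;
    for (int repeat_len = dry_allowed_length; repeat_len <= 6; ++repeat_len) {
        const float penalty = dry_multiplier * std::pow(dry_base, repeat_len - dry_allowed_length);
        printf("repeat_len = %d -> penalty = %.3f\n", repeat_len, penalty);
    }
    return 0;
}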

static void llama_sampler_dry_reset(struct llama_sampler * smpl) {
    auto * ctx = (llama_sampler_dry *) smpl->ctx;
    ctx->last_tokens.clear();
    ctx->dry_repeat_count.clear();
    ctx->dry_max_token_repeat.clear();
}

static struct llama_sampler * llama_sampler_dry_clone(const struct llama_sampler * smpl) {
    const auto * ctx = (llama_sampler_dry *) smpl->ctx;

    // nullptr is passed as vocab because it is only needed for raw sequence breaker processing, which we have already done and will be copying
    auto * result = llama_sampler_init_dry(nullptr, ctx->dry_multiplier, ctx->dry_base, ctx->dry_allowed_length, ctx->dry_penalty_last_n, NULL, 0);

    // Copy the state, including the processed breakers
    {
        auto * result_ctx = (llama_sampler_dry *) result->ctx;
        result_ctx->dry_processed_breakers = ctx->dry_processed_breakers;
        result_ctx->dry_repeat_count       = ctx->dry_repeat_count;
        result_ctx->dry_max_token_repeat   = ctx->dry_max_token_repeat;
        result_ctx->last_tokens            = ctx->last_tokens;
    }

    return result;
}

static void llama_sampler_dry_free(struct llama_sampler * smpl) {
    delete (llama_sampler_dry *) smpl->ctx;
}

static struct llama_sampler_i llama_sampler_dry_i = {
    /* .name   = */ llama_sampler_dry_name,
    /* .accept = */ llama_sampler_dry_accept,
    /* .apply  = */ llama_sampler_dry_apply,
    /* .reset  = */ llama_sampler_dry_reset,
    /* .clone  = */ llama_sampler_dry_clone,
    /* .free   = */ llama_sampler_dry_free,
};

struct llama_sampler * llama_sampler_init_dry_impl(const struct llama_vocab & vocab, int32_t context_size, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const char** seq_breakers, size_t num_breakers) {
    int32_t effective_dry_penalty_last_n = (dry_penalty_last_n == -1) ? context_size : std::max(dry_penalty_last_n, 0);
    std::unordered_multimap<llama_token, std::vector<llama_token>> processed_breakers;
    const int MAX_CHAR_LEN = 40;
    const int MAX_SEQ_LEN = 20;

    const bool dry_enabled = (dry_multiplier != 0.0f && dry_base >= 1.0f && dry_penalty_last_n != 0);

    if (dry_enabled && seq_breakers != nullptr && num_breakers > 0) {
        // Process sequence breakers
        for (size_t i = 0; i < num_breakers; ++i) {
            if (seq_breakers[i] == nullptr || std::strlen(seq_breakers[i]) == 0) {
                LLAMA_LOG_WARN("skipping null or empty DRY sequence breaker at index %zu\n", i);
                continue;
            }

            std::string sequence_break(seq_breakers[i]);
            if (sequence_break.empty()) {
                LLAMA_LOG_WARN("skipping empty DRY sequence breaker\n");
                continue;
            }

            if (sequence_break.size() > MAX_CHAR_LEN) {
                LLAMA_LOG_WARN("truncating DRY sequence breaker to %d characters\n", MAX_CHAR_LEN);
                sequence_break.resize(MAX_CHAR_LEN);
            }

            get_overlapping_token_sequences(vocab, sequence_break, processed_breakers, MAX_SEQ_LEN);
        }
    }

    return new llama_sampler {
        /* .iface = */ &llama_sampler_dry_i,
        /* .ctx   = */ new llama_sampler_dry {
            /* .total_context_size     = */ context_size,
            /* .dry_multiplier         = */ dry_multiplier,
            /* .dry_base               = */ dry_base,
            /* .dry_allowed_length     = */ dry_allowed_length,
            /* .dry_penalty_last_n     = */ dry_penalty_last_n,
            /* .dry_processed_breakers = */ std::move(processed_breakers),
            /* .dry_repeat_count       = */ dry_enabled ? std::vector<int>(effective_dry_penalty_last_n, 0) : std::vector<int>{},
            /* .dry_max_token_repeat   = */ {},
            /* .last_tokens            = */ dry_enabled ? ring_buffer<llama_token>(effective_dry_penalty_last_n) : ring_buffer<llama_token>(0),
        },
    };
}

// wrapper for test-sampling.cpp
struct llama_sampler * llama_sampler_init_dry_testing(int32_t context_size, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const std::vector<std::vector<llama_token>>& seq_breakers) {
    llama_vocab dummy_vocab;
    auto * result = llama_sampler_init_dry_impl(dummy_vocab, context_size, dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n, NULL, 0);
    auto * ctx = (llama_sampler_dry *) result->ctx;

    // Process the token-based sequence breakers
    ctx->dry_processed_breakers.clear();
    if (seq_breakers.empty()) {
        LLAMA_LOG_WARN("empty DRY sequence breakers list in llama_sampler_init_dry_testing\n");
    } else {
        for (const auto& breaker : seq_breakers) {
            if (breaker.empty()) {
                LLAMA_LOG_WARN("skipping DRY empty sequence breaker\n");
                continue;
            }
            llama_token head_token = breaker[0];
            std::vector<llama_token> tail_tokens(breaker.begin() + 1, breaker.end());
            ctx->dry_processed_breakers.emplace(head_token, std::move(tail_tokens));
        }

        if (ctx->dry_processed_breakers.empty()) {
            LLAMA_LOG_WARN("no valid DRY sequence breakers processed in llama_sampler_init_dry_testing\n");
        }
    }

    return result;
}
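A minimal usage sketch of the test-only constructor declared above (not part of the diff; the token ids are hypothetical and the header include mirrors how test-sampling.cpp uses the internal API):

#include "llama-sampling.h"

#include <vector>

static struct llama_sampler * make_dry_for_tests() {
    // Each inner vector is one restart sequence, already tokenized.
    std::vector<std::vector<llama_token>> seq_breakers = {
        { 10, 11 },  // hypothetical two-token breaker
        { 42 },      // hypothetical single-token breaker (excluded from the penalty itself)
    };
    return llama_sampler_init_dry_testing(
        /* context_size       = */ 4096,
        /* dry_multiplier     = */ 0.8f,
        /* dry_base           = */ 1.75f,
        /* dry_allowed_length = */ 2,
        /* dry_penalty_last_n = */ -1,   // -1 -> use the whole context window
        seq_breakers);
}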

// logit-bias

struct llama_sampler_logit_bias {

@@ -1644,6 +2058,229 @@ struct llama_sampler * llama_sampler_init_logit_bias(
    };
}

// infill

//#define GGML_DEBUG_SAMPLER_INFILL

struct llama_sampler_infill {
    const struct llama_vocab * vocab;

    std::vector<char> buf0;
    std::vector<char> buf1;
};

static const char * llama_sampler_infill_name(const struct llama_sampler * /*smpl*/) {
    return "infill";
}

static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
    auto * ctx = (llama_sampler_infill *) smpl->ctx;

    llama_sampler_softmax_impl(cur_p);

#if defined(GGML_DEBUG_SAMPLER_INFILL)
#define LOG_DBG_CUR LLAMA_LOG_DEBUG
#else
#define LOG_DBG_CUR(...)
#endif

    for (size_t i = 0; i < cur_p->size; ++i) {
        LOG_DBG_CUR("%s: cur_p[%3zu] = { id: %6d, p: %.6f, logit: %6.3f }\n", __func__, i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
    }

    float p_txt_sum = 0.0f;
    float p_eog_sum = 0.0f;

    for (size_t i = 0; i < cur_p->size; ++i) {
        if (llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id)) {
            p_eog_sum += cur_p->data[i].p;
        } else {
            p_txt_sum += cur_p->data[i].p;
        }
    }

    const float rat = p_eog_sum == 0.0 ? INFINITY : p_txt_sum / p_eog_sum; GGML_UNUSED(rat);

    LOG_DBG_CUR("%s: p_txt_sum = %.2f, p_eog_sum = %.2f, rat = %.2f, n = %zu\n", __func__, p_txt_sum, p_eog_sum, rat, cur_p->size);

    if (3*p_eog_sum*cur_p->size > p_txt_sum) {
        LOG_DBG_CUR("%s: the ratio p_txt/p_eog = %.2f is too low -> sampling EOG\n", __func__, p_txt_sum/p_eog_sum);

        // keep just the EOG tokens
        const auto size_org = cur_p->size;

        cur_p->size = 0;

        float p_sum = 0.0f;

        for (size_t i = 0; i < size_org; ++i) {
            if (llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id)) {
                p_sum += cur_p->data[i].p;

                cur_p->data[cur_p->size++] = cur_p->data[i];
            }
        }

        // normalize probs
        for (size_t i = 0; i < cur_p->size; ++i) {
            cur_p->data[i].p /= p_sum;
        }

        return;
    }

    size_t n_combined = 0; GGML_UNUSED(n_combined);

    // combine tokens with common prefix
    for (size_t i0 = 0; i0 < cur_p->size; ++i0) {
        for (size_t i1 = 0; i1 < cur_p->size; ++i1) {
            if (cur_p->data[i0].logit == -INFINITY) {
                break;
            }

            if (i0 == i1 || cur_p->data[i1].logit == -INFINITY) {
                continue;
            }

            int len0 = llama_token_to_piece_impl(*ctx->vocab, cur_p->data[i0].id, ctx->buf0.data(), ctx->buf0.size(), 0, false);
            if (len0 < 0) {
                ctx->buf0.resize(len0);
                len0 = llama_token_to_piece_impl(*ctx->vocab, cur_p->data[i0].id, ctx->buf0.data(), ctx->buf0.size(), 0, false);
                assert(len0 > 0);
            }

            int len1 = llama_token_to_piece_impl(*ctx->vocab, cur_p->data[i1].id, ctx->buf1.data(), ctx->buf1.size(), 0, false);
            if (len1 < 0) {
                ctx->buf1.resize(len1);
                len1 = llama_token_to_piece_impl(*ctx->vocab, cur_p->data[i1].id, ctx->buf1.data(), ctx->buf1.size(), 0, false);
                assert(len1 > 0);
            }

            // token i0 is a prefix of token i1
            if (len0 > 0 && len0 <= len1 && memcmp(ctx->buf0.data(), ctx->buf1.data(), len0) == 0) {
                int dst = i0;
                int src = i1;

                // merge into the token with higher probability
                if (cur_p->data[i1].p > cur_p->data[i0].p) {
                    std::swap(dst, src);
                }

                cur_p->data[dst].p += cur_p->data[src].p;
                cur_p->data[src].logit = -INFINITY;
                cur_p->data[src].p     = 0.0f;

                n_combined++;
            }
        }
    }

    size_t n_non_eog = 0;

    size_t size_org = cur_p->size;

    float p_sum = 0.0f;
    float thold = 0.2f;

    cur_p->size = 0;

    LOG_DBG_CUR("%s: n_combined = %zu, applying thold = %.3f\n", __func__, n_combined, thold);

    for (size_t i = 0; i < size_org; ++i) {
        const bool is_eog = llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id);

        if (cur_p->data[i].p < thold && !is_eog) {
            continue;
        }

        if (!is_eog) {
            ++n_non_eog;
        }

        p_sum += cur_p->data[i].p;

        // keep this token
        cur_p->data[cur_p->size++] = cur_p->data[i];
    }

    LOG_DBG_CUR("%s: n_non_eog = %zu\n", __func__, n_non_eog);

    // if no non-EOG tokens are left -> reduce cur_p to single EOT token
    if (n_non_eog == 0) {
        cur_p->size = 1;
        cur_p->data[0].id = llama_token_eot_impl(*ctx->vocab);
        cur_p->data[0].logit = 1.0f;

        return;
    }

    // normalize probs
    for (size_t i = 0; i < cur_p->size; ++i) {
        cur_p->data[i].p /= p_sum;

        LOG_DBG_CUR("%s: cur_p[%3zu] = { id: %6d, p: %.6f, logit: %6.3f }\n", __func__, i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
    }

    size_org = cur_p->size;
    p_sum = 0.0f;
    thold = 1.0/(n_non_eog + 1);

    cur_p->size = 0;

    LOG_DBG_CUR("%s: applying thold = %.3f\n", __func__, thold);

    for (size_t i = 0; i < size_org; ++i) {
        const bool is_eog = llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id);

        if (cur_p->data[i].p < thold && !is_eog) {
            continue;
        }

        p_sum += cur_p->data[i].p;

        cur_p->data[cur_p->size++] = cur_p->data[i];
    }

    // normalize probs
    for (size_t i = 0; i < cur_p->size; ++i) {
        cur_p->data[i].p /= p_sum;

        LOG_DBG_CUR("%s: cur_p[%3zu] = { id: %6d, p: %.6f, logit: %6.3f }\n", __func__, i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
    }

#undef LOG_DBG_CUR
}
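As a side note, the EOG-vs-text decision above depends on the number of candidates as well as the probability mass. A tiny standalone sketch of the `3*p_eog_sum*n > p_txt_sum` check (not part of the diff; the probabilities are made up):

#include <cstdio>

int main() {
    const float p_eog_sum = 0.02f;  // total probability mass of end-of-generation tokens
    const float p_txt_sum = 0.98f;  // total probability mass of regular text tokens
    const int   counts[]  = {8, 16, 32};
    for (int n_candidates : counts) {
        const bool force_eog = 3.0f * p_eog_sum * n_candidates > p_txt_sum;
        printf("n_candidates = %2d -> force EOG: %s\n", n_candidates, force_eog ? "yes" : "no");
    }
    return 0;
}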

static struct llama_sampler * llama_sampler_infill_clone(const struct llama_sampler * smpl) {
    const auto * ctx = (const llama_sampler_infill *) smpl->ctx;
    return llama_sampler_init_infill_impl(*ctx->vocab);
}

static void llama_sampler_infill_free(struct llama_sampler * smpl) {
    delete (llama_sampler_infill *) smpl->ctx;
}

static struct llama_sampler_i llama_sampler_infill_i = {
    /* .name   = */ llama_sampler_infill_name,
    /* .accept = */ nullptr,
    /* .apply  = */ llama_sampler_infill_apply,
    /* .reset  = */ nullptr,
    /* .clone  = */ llama_sampler_infill_clone,
    /* .free   = */ llama_sampler_infill_free,
};

struct llama_sampler * llama_sampler_init_infill_impl(
        const struct llama_vocab & vocab) {
    return new llama_sampler {
        /* .iface = */ &llama_sampler_infill_i,
        /* .ctx   = */ new llama_sampler_infill {
            /* .vocab = */ &vocab,
            /* .buf0 = */ std::vector<char>(512),
            /* .buf1 = */ std::vector<char>(512),
        },
    };
}

// utils

uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl) {
@@ -4,8 +4,6 @@

#include "llama-grammar.h"

#include <unordered_map>

struct llama_vocab;
struct llama_grammar;

@@ -27,3 +25,24 @@ struct llama_sampler * llama_sampler_init_grammar_impl(
        const struct llama_vocab & vocab,
        const char * grammar_str,
        const char * grammar_root);

struct llama_sampler * llama_sampler_init_infill_impl(
        const struct llama_vocab & vocab);

struct llama_sampler * llama_sampler_init_dry_impl(
        const struct llama_vocab & vocab,
        int32_t context_size,
        float dry_multiplier,
        float dry_base,
        int32_t dry_allowed_length,
        int32_t dry_penalty_last_n,
        const char ** seq_breakers,
        size_t num_breakers);

struct llama_sampler * llama_sampler_init_dry_testing(
        int32_t context_size,
        float dry_multiplier,
        float dry_base,
        int32_t dry_allowed_length,
        int32_t dry_penalty_last_n,
        const std::vector<std::vector<llama_token>>& seq_breakers);
@@ -221,7 +221,7 @@ struct llm_tokenizer_spm_session {
        }

        // seed the work queue with all possible 2-character tokens.
        for (size_t i = 1; i < symbols.size(); ++i) {
        for (int i = 1; i < (int) symbols.size(); ++i) {
            try_add_bigram(i - 1, i);
        }

@@ -563,7 +563,7 @@ struct llm_tokenizer_bpe_session {
            index++;
            symbols.emplace_back(sym);
        }
        for (size_t i = 1; i < symbols.size(); ++i) {
        for (int i = 1; i < (int) symbols.size(); ++i) {
            add_new_bigram(i - 1, i);
        }

@@ -1966,3 +1966,19 @@ int32_t llama_detokenize_impl(

    return total <= text_len_max ? total : -total;
}

std::string llama_detokenize(const struct llama_vocab & vocab, const std::vector<llama_token> & tokens, bool special) {
    std::string text;
    text.resize(std::max(text.capacity(), tokens.size()));
    int32_t n_chars = llama_detokenize_impl(vocab, tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special);
    if (n_chars < 0) {
        text.resize(-n_chars);
        n_chars = llama_detokenize_impl(vocab, tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special);
        GGML_ASSERT(n_chars <= (int32_t)text.size());  // whitespace trimming is performed after per-token detokenization
    }

    text.resize(n_chars);

    // NOTE: the original tokenizer decodes bytes after collecting the pieces.
    return text;
}
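A short usage sketch of the new helper (not part of the diff; it assumes the internal llama-vocab.h header and an existing `llama_vocab` instance):

#include "llama-vocab.h"

#include <string>
#include <vector>

static std::string detok(const llama_vocab & vocab, const std::vector<llama_token> & tokens) {
    // special = false: special tokens are not rendered back into the output text
    return llama_detokenize(vocab, tokens, /*special=*/false);
}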
@@ -149,6 +149,12 @@ int32_t llama_token_to_piece_impl(
        int32_t lstrip,
        bool special);

// check if token0 is contained as a prefix in token1
bool llama_token_is_prefix_impl(
        const struct llama_vocab & vocab,
        llama_token token0,
        llama_token token1);

int32_t llama_detokenize_impl(
        const struct llama_vocab & vocab,
        const llama_token * tokens,
@@ -157,3 +163,8 @@ int32_t llama_detokenize_impl(
        int32_t text_len_max,
        bool remove_special,
        bool unparse_special);

std::string llama_detokenize(
        const struct llama_vocab & vocab,
        const std::vector<llama_token> & tokens,
        bool special);
3673  src/llama.cpp  (file diff suppressed because it is too large)
@@ -1683,9 +1683,10 @@ struct test_mul_mat : public test_case {
    const int64_t k;
    const std::array<int64_t, 2> bs;  // dims 3 and 4
    const std::array<int64_t, 2> nr;  // repeat in dims 3 and 4
    const std::array<int64_t, 4> per; // permutation of dimensions

    std::string vars() override {
        return VARS_TO_STR7(type_a, type_b, m, n, k, bs, nr);
        return VARS_TO_STR8(type_a, type_b, m, n, k, bs, nr, per);
    }

    double max_nmse_err() override {
@@ -1700,18 +1701,45 @@ struct test_mul_mat : public test_case {
    test_mul_mat(ggml_type type_a = GGML_TYPE_F32, ggml_type type_b = GGML_TYPE_F32,
            int64_t m = 32, int64_t n = 32, int64_t k = 32,
            std::array<int64_t, 2> bs = {10, 10},
            std::array<int64_t, 2> nr = {2, 2})
            std::array<int64_t, 2> nr = {2, 2},
        : type_a(type_a), type_b(type_b), m(m), n(n), k(k), bs(bs), nr(nr) {}
            std::array<int64_t, 4> per = {0, 1, 2, 3})
        : type_a(type_a), type_b(type_b), m(m), n(n), k(k), bs(bs), nr(nr), per(per) {}

    ggml_tensor * build_graph(ggml_context * ctx) override {
        // C^T = A * B^T: (k, m) * (k, n) => (m, n)
        ggml_tensor * a = ggml_new_tensor_4d(ctx, type_a, k, m, bs[0] , bs[1]);
        ggml_tensor * a;
        ggml_tensor * b = ggml_new_tensor_4d(ctx, type_b, k, n, bs[0]*nr[0], bs[1]*nr[1]);
        ggml_tensor * b;

        const int npermuted = (per[0] != 0) + (per[1] != 1) + (per[2] != 2) + (per[3] != 3);
        if (npermuted > 0) {
            GGML_ASSERT(npermuted == 2);
            GGML_ASSERT(!ggml_is_quantized(type_a) || per[0] == 0);
            GGML_ASSERT(!ggml_is_quantized(type_b) || per[0] == 0);

            // Create tensors with the permuted dimensions, then permute them back to the dimensions given by m,n,k.
            const int64_t ne_a[4] = {k, m, bs[0], bs[1]};
            const int64_t ne_b[4] = {k, n, bs[0]*nr[0], bs[1]*nr[1]};

            a = ggml_new_tensor_4d(ctx, type_a, ne_a[per[0]], ne_a[per[1]], ne_a[per[2]], ne_a[per[3]]);
            b = ggml_new_tensor_4d(ctx, type_b, ne_b[per[0]], ne_b[per[1]], ne_b[per[2]], ne_b[per[3]]);
            ggml_set_param(ctx, a);
            ggml_set_param(ctx, b);
            ggml_set_name(a, "a");
            ggml_set_name(b, "b");

            a = ggml_permute(ctx, a, per[0], per[1], per[2], per[3]);
            b = ggml_permute(ctx, b, per[0], per[1], per[2], per[3]);
            ggml_set_name(a, "a_permuted");
            ggml_set_name(b, "b_permuted");
        } else {
            a = ggml_new_tensor_4d(ctx, type_a, k, m, bs[0], bs[1]);
            b = ggml_new_tensor_4d(ctx, type_b, k, n, bs[0]*nr[0], bs[1]*nr[1]);
            ggml_set_param(ctx, a);
            ggml_set_param(ctx, b);
            ggml_set_name(a, "a");
            ggml_set_name(b, "b");
        }

        ggml_tensor * out = ggml_mul_mat(ctx, a, b);
        ggml_set_name(out, "out");
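To illustrate the trick used in the permuted branch above: the tensor is allocated with its dimensions already shuffled according to `per`, and `ggml_permute` is then applied so that the matmul still sees the logical (k, m, bs0, bs1) shape. A small standalone sketch of the shape bookkeeping only (assumed example values; it does not call ggml):

#include <cstdint>
#include <cstdio>

int main() {
    const int64_t k = 256, m = 16, bs0 = 2, bs1 = 3;
    const int     per[4]  = {0, 2, 1, 3};     // swap dims 1 and 2, as in the new test cases
    const int64_t ne_a[4] = {k, m, bs0, bs1}; // logical shape expected by the matmul
    // shape actually passed to ggml_new_tensor_4d in the permuted branch:
    printf("allocated shape: %lld x %lld x %lld x %lld\n",
           (long long) ne_a[per[0]], (long long) ne_a[per[1]],
           (long long) ne_a[per[2]], (long long) ne_a[per[3]]);
    return 0;
}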
||||||
@ -3339,13 +3367,49 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32));
|
// im2col 1D
|
||||||
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32));
|
|
||||||
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16));
|
|
||||||
// test cases for 1D im2col
|
|
||||||
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32, {3000, 128, 1, 1}, {3, 128, 1280, 1}, 1, 0, 1, 0, 1, 0, false));
|
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32, {3000, 128, 1, 1}, {3, 128, 1280, 1}, 1, 0, 1, 0, 1, 0, false));
|
||||||
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32, {3000, 128, 1, 1}, {3, 128, 1280, 1}, 1, 0, 1, 0, 1, 0, false));
|
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32, {3000, 128, 1, 1}, {3, 128, 1280, 1}, 1, 0, 1, 0, 1, 0, false));
|
||||||
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {3000, 128, 1, 1}, {3, 128, 1280, 1}, 1, 0, 1, 0, 1, 0, false));
|
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {3000, 128, 1, 1}, {3, 128, 1280, 1}, 1, 0, 1, 0, 1, 0, false));
|
||||||
|
for (int s0 : {1, 3}) {
|
||||||
|
for (int p0 : {0, 3}) {
|
||||||
|
for (int d0 : {1, 3}) {
|
||||||
|
test_cases.emplace_back(new test_im2col(
|
||||||
|
GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32, {20, 2, 2, 1}, {3, 2, 2, 1},
|
||||||
|
s0, 0, p0, 0, d0, 0, false));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// im2col 2D
|
||||||
|
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32));
|
||||||
|
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32));
|
||||||
|
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16));
|
||||||
|
for (int s0 : {1, 3}) {
|
||||||
|
for (int s1 : {1, 3}) {
|
||||||
|
for (int p0 : {0, 3}) {
|
||||||
|
for (int p1 : {0, 3}) {
|
||||||
|
for (int d0 : {1, 3}) {
|
||||||
|
for (int d1 : {1, 3}) {
|
||||||
|
test_cases.emplace_back(new test_im2col(
|
||||||
|
GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32, {20, 20, 2, 2}, {3, 3, 2, 2},
|
||||||
|
s0, s1, p0, p1, d0, d1, true));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// extra tests for im2col 2D
|
||||||
|
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 1, 32}, {3, 3, 1, 32}, 1, 1, 1, 1, 1, 1, true));
|
||||||
|
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 2, 32}, {3, 3, 2, 32}, 1, 1, 1, 1, 1, 1, true));
|
||||||
|
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 1, 1024}, {3, 3, 1, 1024}, 1, 1, 1, 1, 1, 1, true));
|
||||||
|
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 2, 1024}, {3, 3, 2, 1024}, 1, 1, 1, 1, 1, 1, true));
|
||||||
|
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 1, 2048}, {3, 3, 1, 2048}, 1, 1, 1, 1, 1, 1, true));
|
||||||
|
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 2, 2048}, {3, 3, 2, 2048}, 1, 1, 1, 1, 1, 1, true));
|
||||||
|
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 1, 2560}, {3, 3, 1, 2560}, 1, 1, 1, 1, 1, 1, true));
|
||||||
|
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 2, 2560}, {3, 3, 2, 2560}, 1, 1, 1, 1, 1, 1, true));
|
||||||
|
|
||||||
// sycl backend will limit task global_range < MAX_INT
|
// sycl backend will limit task global_range < MAX_INT
|
||||||
// test cases for 2D im2col with large input W and H (occurs in stable-diffusion)
|
// test cases for 2D im2col with large input W and H (occurs in stable-diffusion)
|
||||||
@@ -3474,6 +3538,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
#if 1
    for (ggml_type type_a : base_types) {
        for (ggml_type type_b : {GGML_TYPE_F32, GGML_TYPE_F16}) {
            // test cases without permutation
            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, { 1, 1}, {1, 1}));
            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 1}, {1, 1}));
            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 1}, {2, 1}));
@@ -3489,6 +3554,19 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 10}, {2, 1}));
            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 10}, {1, 2}));
            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 10}, {2, 2}));

            // test cases with permutation
            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {2, 3}, {1, 1}, {0, 2, 1, 3}));
            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {2, 3}, {1, 1}, {0, 1, 3, 2}));
            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {2, 3}, {1, 1}, {0, 3, 2, 1}));

            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 8, 256, {2, 3}, {1, 1}, {0, 2, 1, 3}));
            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 8, 256, {2, 3}, {1, 1}, {0, 1, 3, 2}));
            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 8, 256, {2, 3}, {1, 1}, {0, 3, 2, 1}));

            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {2, 3}, {1, 1}, {0, 2, 1, 3}));
            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {2, 3}, {1, 1}, {0, 1, 3, 2}));
            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {2, 3}, {1, 1}, {0, 3, 2, 1}));
        }
    }
    for (ggml_type type_a : other_types) {
@@ -65,6 +65,8 @@ int main(void) {
        u8"{% for message in messages %}{% if message['role'] == 'user' %}{{'<用户>' + message['content'].strip() + '<AI>'}}{% else %}{{message['content'].strip()}}{% endif %}{% endfor %}",
        // DeepSeek-V2
        "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{{ bos_token }}{% for message in messages %}{% if message['role'] == 'user' %}{{ 'User: ' + message['content'] + '\n\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + message['content'] + eos_token }}{% elif message['role'] == 'system' %}{{ message['content'] + '\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}",
        // ibm-granite/granite-3.0-8b-instruct
        "{%- if tools %}\n {{- '<|start_of_role|>available_tools<|end_of_role|>\n' }}\n {%- for tool in tools %}\n {{- tool | tojson(indent=4) }}\n {%- if not loop.last %}\n {{- '\n\n' }}\n {%- endif %}\n {%- endfor %}\n {{- '<|end_of_text|>\n' }}\n{%- endif %}\n{%- for message in messages %}\n {%- if message['role'] == 'system' %}\n {{- '<|start_of_role|>system<|end_of_role|>' + message['content'] + '<|end_of_text|>\n' }}\n {%- elif message['role'] == 'user' %}\n {{- '<|start_of_role|>user<|end_of_role|>' + message['content'] + '<|end_of_text|>\n' }}\n {%- elif message['role'] == 'assistant' %}\n {{- '<|start_of_role|>assistant<|end_of_role|>' + message['content'] + '<|end_of_text|>\n' }}\n {%- elif message['role'] == 'assistant_tool_call' %}\n {{- '<|start_of_role|>assistant<|end_of_role|><|tool_call|>' + message['content'] + '<|end_of_text|>\n' }}\n {%- elif message['role'] == 'tool_response' %}\n {{- '<|start_of_role|>tool_response<|end_of_role|>' + message['content'] + '<|end_of_text|>\n' }}\n {%- endif %}\n {%- if loop.last and add_generation_prompt %}\n {{- '<|start_of_role|>assistant<|end_of_role|>' }}\n {%- endif %}\n{%- endfor %}",
    };
    std::vector<std::string> expected_output = {
        // teknium/OpenHermes-2.5-Mistral-7B
@@ -109,6 +111,8 @@ int main(void) {
        u8"You are a helpful assistant<用户>Hello<AI>Hi there<用户>Who are you<AI>I am an assistant<用户>Another question<AI>",
        // DeepSeek-V2
        u8"You are a helpful assistant\n\nUser: Hello\n\nAssistant: Hi there<|end▁of▁sentence|>User: Who are you\n\nAssistant: I am an assistant <|end▁of▁sentence|>User: Another question\n\nAssistant:",
        // ibm-granite/granite-3.0-8b-instruct
        "<|start_of_role|>system<|end_of_role|>You are a helpful assistant<|end_of_text|>\n<|start_of_role|>user<|end_of_role|>Hello<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>Hi there<|end_of_text|>\n<|start_of_role|>user<|end_of_role|>Who are you<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|> I am an assistant <|end_of_text|>\n<|start_of_role|>user<|end_of_role|>Another question<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>\n",
    };
    std::vector<char> formatted_chat(1024);
    int32_t res;
@@ -696,7 +696,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
"pattern": "^abc?d*efg+(hij)?kl$"
})""",
R"""(
root ::= "\"" "ab" "c"? "d"* "ef" "g"+ ("hij")? "kl" "\"" space
root ::= "\"" ("ab" "c"? "d"* "ef" "g"+ ("hij")? "kl") "\"" space
space ::= | " " | "\n" [ \t]{0,20}
)"""
});
@@ -709,7 +709,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
"pattern": "^\\[\\]\\{\\}\\(\\)\\|\\+\\*\\?$"
})""",
R"""(
root ::= "\"" "[]{}()|+*?" "\"" space
root ::= "\"" ("[]{}()|+*?") "\"" space
space ::= | " " | "\n" [ \t]{0,20}
)"""
});
@@ -722,7 +722,20 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
"pattern": "^\"$"
})""",
R"""(
root ::= "\"" "\"" "\"" space
root ::= "\"" ("\"") "\"" space
space ::= | " " | "\n" [ \t]{0,20}
)"""
});

test({
SUCCESS,
"regexp with top-level alternation",
R"""({
"type": "string",
"pattern": "^A|B|C|D$"
})""",
R"""(
root ::= "\"" ("A" | "B" | "C" | "D") "\"" space
space ::= | " " | "\n" [ \t]{0,20}
)"""
});
@@ -736,7 +749,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
})""",
R"""(
dot ::= [^\x0A\x0D]
root ::= "\"" ("(" root-1{1,3} ")")? root-1{3,3} "-" root-1{4,4} " " "a"{3,5} "nd" dot dot dot "\"" space
root ::= "\"" (("(" root-1{1,3} ")")? root-1{3,3} "-" root-1{4,4} " " "a"{3,5} "nd" dot dot dot) "\"" space
root-1 ::= [0-9]
space ::= | " " | "\n" [ \t]{0,20}
)"""
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user