Merge branch 'master' into compilade/bitnet-ternary

Francis Couture-Harpin 2024-08-22 16:42:24 -04:00
commit cb6d9962c4
77 changed files with 4681 additions and 2212 deletions


@ -0,0 +1,44 @@
ARG ASCEND_VERSION=8.0.rc2.alpha003-910b-openeuler22.03-py3.8
FROM cosdt/cann:$ASCEND_VERSION AS build
WORKDIR /app
COPY . .
RUN yum install -y gcc g++ cmake make
ENV ASCEND_TOOLKIT_HOME=/usr/local/Ascend/ascend-toolkit/latest
ENV LIBRARY_PATH=${ASCEND_TOOLKIT_HOME}/lib64:$LIBRARY_PATH
ENV LD_LIBRARY_PATH=${ASCEND_TOOLKIT_HOME}/lib64:${ASCEND_TOOLKIT_HOME}/lib64/plugin/opskernel:${ASCEND_TOOLKIT_HOME}/lib64/plugin/nnengine:${ASCEND_TOOLKIT_HOME}/opp/built-in/op_impl/ai_core/tbe/op_tiling:${LD_LIBRARY_PATH}
ENV PYTHONPATH=${ASCEND_TOOLKIT_HOME}/python/site-packages:${ASCEND_TOOLKIT_HOME}/opp/built-in/op_impl/ai_core/tbe:${PYTHONPATH}
ENV PATH=${ASCEND_TOOLKIT_HOME}/bin:${ASCEND_TOOLKIT_HOME}/compiler/ccec_compiler/bin:${PATH}
ENV ASCEND_AICPU_PATH=${ASCEND_TOOLKIT_HOME}
ENV ASCEND_OPP_PATH=${ASCEND_TOOLKIT_HOME}/opp
ENV TOOLCHAIN_HOME=${ASCEND_TOOLKIT_HOME}/toolkit
ENV ASCEND_HOME_PATH=${ASCEND_TOOLKIT_HOME}
# find libascend_hal.so, because the driver hasn't been mounted.
ENV LD_LIBRARY_PATH=${ASCEND_TOOLKIT_HOME}/runtime/lib64/stub:$LD_LIBRARY_PATH
RUN echo "Building with static libs" && \
source /usr/local/Ascend/ascend-toolkit/set_env.sh --force && \
cmake -B build -DGGML_CANN=ON -DBUILD_SHARED_LIBS=OFF && \
cmake --build build --config Release --target llama-cli
# TODO: use image with NNRT
FROM cosdt/cann:$ASCEND_VERSION AS runtime
COPY --from=build /app/build/bin/llama-cli /llama-cli
ENV LC_ALL=C.utf8
ENV ASCEND_TOOLKIT_HOME=/usr/local/Ascend/ascend-toolkit/latest
ENV LIBRARY_PATH=${ASCEND_TOOLKIT_HOME}/lib64:$LIBRARY_PATH
ENV LD_LIBRARY_PATH=${ASCEND_TOOLKIT_HOME}/lib64:${ASCEND_TOOLKIT_HOME}/lib64/plugin/opskernel:${ASCEND_TOOLKIT_HOME}/lib64/plugin/nnengine:${ASCEND_TOOLKIT_HOME}/opp/built-in/op_impl/ai_core/tbe/op_tiling:${LD_LIBRARY_PATH}
ENV PYTHONPATH=${ASCEND_TOOLKIT_HOME}/python/site-packages:${ASCEND_TOOLKIT_HOME}/opp/built-in/op_impl/ai_core/tbe:${PYTHONPATH}
ENV PATH=${ASCEND_TOOLKIT_HOME}/bin:${ASCEND_TOOLKIT_HOME}/compiler/ccec_compiler/bin:${PATH}
ENV ASCEND_AICPU_PATH=${ASCEND_TOOLKIT_HOME}
ENV ASCEND_OPP_PATH=${ASCEND_TOOLKIT_HOME}/opp
ENV TOOLCHAIN_HOME=${ASCEND_TOOLKIT_HOME}/toolkit
ENV ASCEND_HOME_PATH=${ASCEND_TOOLKIT_HOME}
ENTRYPOINT ["/llama-cli" ]


@ -1,3 +1,6 @@
# TODO: there have been some issues with the workflow, so disabling for now
# https://github.com/ggerganov/llama.cpp/issues/7893
#
# Benchmark
name: Benchmark

.gitignore

@ -129,3 +129,6 @@ poetry.toml
# Scripts
!/scripts/install-oneapi.bat
# Test models for lora adapters
/lora-tests


@ -28,6 +28,7 @@
{ "name": "release", "hidden": true, "cacheVariables": { "CMAKE_BUILD_TYPE": "Release" } },
{ "name": "reldbg", "hidden": true, "cacheVariables": { "CMAKE_BUILD_TYPE": "RelWithDebInfo" } },
{ "name": "static", "hidden": true, "cacheVariables": { "GGML_STATIC": "ON" } },
{ "name": "sycl_f16", "hidden": true, "cacheVariables": { "GGML_SYCL_F16": "ON" } },
{
"name": "arm64-windows-msvc", "hidden": true,
@ -60,6 +61,8 @@
{ "name": "x64-windows-msvc+static-release", "inherits": [ "base", "reldbg", "static" ] },
{ "name": "x64-windows-sycl-debug" , "inherits": [ "sycl-base", "debug" ] },
{ "name": "x64-windows-sycl-debug-f16", "inherits": [ "sycl-base", "debug", "sycl_f16" ] },
{ "name": "x64-windows-sycl-release", "inherits": [ "sycl-base", "release" ] },
{ "name": "x64-windows-sycl-release-f16", "inherits": [ "sycl-base", "release", "sycl_f16" ] }
]
}


@ -763,6 +763,10 @@ ifdef GGML_VULKAN_MEMORY_DEBUG
MK_CPPFLAGS += -DGGML_VULKAN_MEMORY_DEBUG
endif
ifdef GGML_VULKAN_PERF
MK_CPPFLAGS += -DGGML_VULKAN_PERF
endif
ifdef GGML_VULKAN_VALIDATE
MK_CPPFLAGS += -DGGML_VULKAN_VALIDATE
endif


@ -105,6 +105,8 @@ Typically finetunes of the base models below are supported as well.
- [x] [Open Elm models](https://huggingface.co/collections/apple/openelm-instruct-models-6619ad295d7ae9f868b759ca)
- [x] [ChatGLM3-6b](https://huggingface.co/THUDM/chatglm3-6b) + [ChatGLM4-9b](https://huggingface.co/THUDM/glm-4-9b)
- [x] [SmolLM](https://huggingface.co/collections/HuggingFaceTB/smollm-6695016cad7167254ce15966)
- [x] [EXAONE-3.0-7.8B-Instruct](https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct)
- [x] [FalconMamba Models](https://huggingface.co/collections/tiiuae/falconmamba-7b-66b9a580324dd1598b0f6d4a)
(instructions for supporting more models: [HOWTO-add-model.md](./docs/development/HOWTO-add-model.md))
@ -424,6 +426,7 @@ Please refer to [Build llama.cpp locally](./docs/build.md)
| [CUDA](./docs/build.md#cuda) | Nvidia GPU |
| [hipBLAS](./docs/build.md#hipblas) | AMD GPU |
| [Vulkan](./docs/build.md#vulkan) | GPU |
| [CANN](./docs/build.md#cann) | Ascend NPU |
## Tools


@ -77,6 +77,41 @@
using json = nlohmann::ordered_json;
//
// Environment variable utils
//
template<typename T>
static typename std::enable_if<std::is_same<T, std::string>::value, void>::type
get_env(std::string name, T & target) {
char * value = std::getenv(name.c_str());
target = value ? std::string(value) : target;
}
template<typename T>
static typename std::enable_if<!std::is_same<T, bool>::value && std::is_integral<T>::value, void>::type
get_env(std::string name, T & target) {
char * value = std::getenv(name.c_str());
target = value ? std::stoi(value) : target;
}
template<typename T>
static typename std::enable_if<std::is_floating_point<T>::value, void>::type
get_env(std::string name, T & target) {
char * value = std::getenv(name.c_str());
target = value ? std::stof(value) : target;
}
template<typename T>
static typename std::enable_if<std::is_same<T, bool>::value, void>::type
get_env(std::string name, T & target) {
char * value = std::getenv(name.c_str());
if (value) {
std::string val(value);
target = val == "1" || val == "true";
}
}
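The overloads above dispatch on the target's type: strings are copied verbatim, integral targets are parsed with std::stoi, floating-point ones with std::stof, and bools accept "1" or "true"; unset variables leave the target's default untouched. A minimal usage sketch (the variable names are illustrative only, not from the patch):
```cpp
// assumes the get_env() overloads defined above are in scope
std::string model_path;     // copied verbatim when LLAMA_ARG_MODEL is set
int  n_ctx     = 512;       // parsed via std::stoi when LLAMA_ARG_CTX_SIZE is set
bool embedding = false;     // set from "1" or "true"

void read_server_env() {
    get_env("LLAMA_ARG_MODEL",      model_path);
    get_env("LLAMA_ARG_CTX_SIZE",   n_ctx);
    get_env("LLAMA_ARG_EMBEDDINGS", embedding);
}
```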
//
// CPU utils
//
@ -110,8 +145,34 @@ int32_t cpu_get_num_physical_cores() {
if (result == 0) {
return num_physical_cores;
}
#elif defined(_WIN32) && (_WIN32_WINNT >= 0x0601) && !defined(__MINGW64__) // windows 7 and later
// TODO: windows + arm64 + mingw64
unsigned int n_threads_win = std::thread::hardware_concurrency();
unsigned int default_threads = n_threads_win > 0 ? (n_threads_win <= 4 ? n_threads_win : n_threads_win / 2) : 4;
DWORD buffer_size = 0;
if (!GetLogicalProcessorInformationEx(RelationProcessorCore, nullptr, &buffer_size)) {
if (GetLastError() != ERROR_INSUFFICIENT_BUFFER) {
return default_threads;
}
}
std::vector<char> buffer(buffer_size);
if (!GetLogicalProcessorInformationEx(RelationProcessorCore, reinterpret_cast<PSYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX>(buffer.data()), &buffer_size)) {
return default_threads;
}
int32_t num_physical_cores = 0;
PSYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX info = reinterpret_cast<PSYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX>(buffer.data());
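// note: SYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX records are variable-length,
// so the loop below walks the buffer by info->Size rather than indexing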
while (buffer_size > 0) {
if (info->Relationship == RelationProcessorCore) {
num_physical_cores += info->Processor.GroupCount;
}
buffer_size -= info->Size;
info = reinterpret_cast<PSYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX>(reinterpret_cast<char*>(info) + info->Size);
}
return num_physical_cores > 0 ? num_physical_cores : default_threads;
#endif
unsigned int n_threads = std::thread::hardware_concurrency();
return n_threads > 0 ? (n_threads <= 4 ? n_threads : n_threads / 2) : 4;
@ -194,12 +255,6 @@ int32_t cpu_get_num_math() {
// CLI argument parsing
//
void gpt_params_handle_hf_token(gpt_params & params) {
if (params.hf_token.empty() && std::getenv("HF_TOKEN")) {
params.hf_token = std::getenv("HF_TOKEN");
}
}
void gpt_params_handle_model_default(gpt_params & params) {
if (!params.hf_repo.empty()) {
// short-hand to avoid specifying --hf-file -> default it to --model
@ -247,7 +302,9 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
gpt_params_handle_model_default(params);
if (params.hf_token.empty()) {
get_env("HF_TOKEN", params.hf_token);
}
if (params.escape) {
string_process_escapes(params.prompt);
@ -267,6 +324,25 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
return true;
}
void gpt_params_parse_from_env(gpt_params & params) {
// we only care about server-related params for now
get_env("LLAMA_ARG_MODEL", params.model);
get_env("LLAMA_ARG_THREADS", params.n_threads);
get_env("LLAMA_ARG_CTX_SIZE", params.n_ctx);
get_env("LLAMA_ARG_N_PARALLEL", params.n_parallel);
get_env("LLAMA_ARG_BATCH", params.n_batch);
get_env("LLAMA_ARG_UBATCH", params.n_ubatch);
get_env("LLAMA_ARG_N_GPU_LAYERS", params.n_gpu_layers);
get_env("LLAMA_ARG_THREADS_HTTP", params.n_threads_http);
get_env("LLAMA_ARG_CHAT_TEMPLATE", params.chat_template);
get_env("LLAMA_ARG_N_PREDICT", params.n_predict);
get_env("LLAMA_ARG_ENDPOINT_METRICS", params.endpoint_metrics);
get_env("LLAMA_ARG_ENDPOINT_SLOTS", params.endpoint_slots);
get_env("LLAMA_ARG_EMBEDDINGS", params.embedding);
get_env("LLAMA_ARG_FLASH_ATTN", params.flash_attn);
get_env("LLAMA_ARG_DEFRAG_THOLD", params.defrag_thold);
}
bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
const auto params_org = params; // the example can modify the default params
@ -1727,7 +1803,13 @@ std::string gpt_params_get_system_info(const gpt_params & params) {
if (params.n_threads_batch != -1) {
os << " (n_threads_batch = " << params.n_threads_batch << ")";
}
#if defined(_WIN32) && (_WIN32_WINNT >= 0x0601) && !defined(__MINGW64__) // windows 7 and later
// TODO: windows + arm64 + mingw64
DWORD logicalProcessorCount = GetActiveProcessorCount(ALL_PROCESSOR_GROUPS);
os << " / " << logicalProcessorCount << " | " << llama_print_system_info();
#else
os << " / " << std::thread::hardware_concurrency() << " | " << llama_print_system_info();
#endif
return os.str();
}
@ -2702,12 +2784,6 @@ std::string llama_detokenize(llama_context * ctx, const std::vector<llama_token>
return text;
}
bool llama_should_add_bos_token(const llama_model * model) {
const int add_bos = llama_add_bos_token(model);
return add_bos != -1 ? bool(add_bos) : (llama_vocab_type(model) == LLAMA_VOCAB_TYPE_SPM);
}
//
// Chat template utils
//

View File

@ -267,7 +267,7 @@ struct gpt_params {
std::string lora_outfile = "ggml-lora-merged-f16.gguf";
};
void gpt_params_parse_from_env(gpt_params & params);
void gpt_params_handle_model_default(gpt_params & params);
bool gpt_params_parse_ex (int argc, char ** argv, gpt_params & params);
@ -380,10 +380,6 @@ std::string llama_detokenize(
const std::vector<llama_token> & tokens,
bool special = true);
// Uses the value from the model metadata if possible, otherwise
// defaults to true when model type is SPM, otherwise false.
bool llama_should_add_bos_token(const llama_model * model);
//
// Chat template utils
//

View File

@ -295,6 +295,7 @@ class Model:
gguf.MODEL_TENSOR.FFN_GATE_INP,
gguf.MODEL_TENSOR.POS_EMBD,
gguf.MODEL_TENSOR.TOKEN_TYPES,
gguf.MODEL_TENSOR.SSM_CONV1D,
)
)
or not name.endswith(".weight")
@ -608,6 +609,15 @@ class Model:
if chkhsh == "855059429035d75a914d1eda9f10a876752e281a054a7a3d421ef0533e5b6249":
# ref: https://huggingface.co/HuggingFaceTB/SmolLM-135M
res = "smollm"
if chkhsh == "3c30d3ad1d6b64202cd222813e7736c2db6e1bd6d67197090fc1211fbc612ae7":
# ref: https://huggingface.co/bigscience/bloom
res = "bloom"
if chkhsh == "bc01ce58980e1db43859146dc51b1758b3b88729b217a74792e9f8d43e479d21":
# ref: https://huggingface.co/TurkuNLP/gpt3-finnish-small
res = "gpt3-finnish"
if chkhsh == "4e2b24cc4770243d65a2c9ec19770a72f08cffc161adbb73fcbb6b7dd45a0aae":
# ref: https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct
res = "exaone"
if res is None:
logger.warning("\n")
@ -911,7 +921,7 @@ class GPTNeoXModel(Model):
return tensors
@Model.register("BloomForCausalLM", "BloomModel")
class BloomModel(Model):
model_arch = gguf.MODEL_ARCH.BLOOM
@ -2719,7 +2729,7 @@ class StarCoder2Model(Model):
model_arch = gguf.MODEL_ARCH.STARCODER2
@Model.register("MambaForCausalLM", "MambaLMHeadModel", "FalconMambaForCausalLM")
class MambaModel(Model):
model_arch = gguf.MODEL_ARCH.MAMBA
@ -2750,7 +2760,10 @@ class MambaModel(Model):
# ref: https://github.com/state-spaces/mamba/blob/ce59daea3a090d011d6476c6e5b97f6d58ddad8b/mamba_ssm/modules/mamba_simple.py#L58
dt_rank = self.find_hparam(["time_step_rank", "dt_rank"], optional=True) or -(d_model // -16)
rms_norm_eps = self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-5
use_dt_b_c_norm = False
# For falconmamba we do apply RMS norm on B / DT and C layers
if self.find_hparam(["model_type"], optional=True) in ("falcon_mamba",):
use_dt_b_c_norm = True
# Fail early for models which don't have a block expansion factor of 2
assert d_inner == 2 * d_model
@ -2758,12 +2771,13 @@ class MambaModel(Model):
self.gguf_writer.add_embedding_length(d_model)
self.gguf_writer.add_feed_forward_length(0) # unused, but seemingly required when loading
self.gguf_writer.add_head_count(0) # unused, but seemingly required when loading
self.gguf_writer.add_block_count(self.block_count)
self.gguf_writer.add_ssm_conv_kernel(d_conv)
self.gguf_writer.add_ssm_inner_size(d_inner)
self.gguf_writer.add_ssm_state_size(d_state)
self.gguf_writer.add_ssm_time_step_rank(dt_rank)
self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps)
self.gguf_writer.add_ssm_dt_b_c_rms(use_dt_b_c_norm) # For classic Mamba we don't apply rms norm on B / DT layers
self.gguf_writer.add_file_type(self.ftype)
_tok_embd = None
@ -2790,23 +2804,6 @@ class MambaModel(Model):
return [(new_name, data_torch)]
def tensor_force_quant(self, name: str, new_name: str, bid: int | None, n_dims: int) -> gguf.GGMLQuantizationType | bool:
if bid is not None and new_name in (
self.format_tensor_name(
n, bid, ".weight" if name.endswith(".weight") else ""
)
for n in [
gguf.MODEL_TENSOR.SSM_CONV1D,
gguf.MODEL_TENSOR.SSM_X,
gguf.MODEL_TENSOR.SSM_DT,
gguf.MODEL_TENSOR.SSM_A,
gguf.MODEL_TENSOR.SSM_D,
]
):
return gguf.GGMLQuantizationType.F32
return super().tensor_force_quant(name, new_name, bid, n_dims)
@Model.register("CohereForCausalLM")
class CommandR2Model(Model):
@ -3751,8 +3748,120 @@ class ChatGLMModel(Model):
name = name.removeprefix("transformer.")
return [(self.map_tensor_name(name), data_torch)]
@Model.register("NemotronForCausalLM")
class NemotronModel(Model):
model_arch = gguf.MODEL_ARCH.NEMOTRON
def set_vocab(self):
self._set_vocab_sentencepiece()
self.gguf_writer.add_pad_token_id(0)
self.gguf_writer.add_unk_token_id(1)
def set_gguf_parameters(self):
super().set_gguf_parameters()
hparams = self.hparams
self.gguf_writer.add_vocab_size(hparams["vocab_size"])
f_norm_eps = self.find_hparam(["layer_norm_eps", "layer_norm_epsilon", "norm_epsilon", "norm_eps"])
self.gguf_writer.add_layer_norm_eps(f_norm_eps)
# * Partial RoPE
rot_pct = self.find_hparam(["partial_rotary_factor", "rope_pct", "rope_percent"])
n_embd = self.find_hparam(["hidden_size", "n_embd"])
n_head = self.find_hparam(["num_attention_heads", "n_head"])
self.gguf_writer.add_rope_dimension_count(int(rot_pct * n_embd) // n_head)
# * RopeScaling for Nemotron
if "rope_scaling" not in self.hparams or self.hparams["rope_scaling"] is None:
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE)
else:
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
self.gguf_writer.add_rope_scaling_factor(self.hparams["factor"])
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
# * Adding +1 to LayerNorm's weights here to implement layernorm1p w/o changing anything on the GGML engine side
# model.layers.{l}.input_layernorm.weight
# model.layers.{l}.post_attention_layernorm.weight
# model.norm.weight
if name.endswith("norm.weight"):
data_torch = data_torch + 1
return [(self.map_tensor_name(name), data_torch)]
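The +1 shift is sound because layernorm1p differs from standard LayerNorm only in scaling the normalized input by (1 + w) instead of w, so folding the constant into the stored weight lets the stock GGML norm kernel reproduce it unchanged:

$$
\mathrm{layernorm1p}(x) = \hat{x} \odot (1 + w) + b = \hat{x} \odot w' + b, \qquad w' = w + 1
$$

where $\hat{x}$ is the normalized input.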
@Model.register("ExaoneForCausalLM")
class ExaoneModel(Model):
model_arch = gguf.MODEL_ARCH.EXAONE
def set_gguf_parameters(self):
hparams = self.hparams
assert (hparams["activation_function"] == "silu")
max_position_embeddings = hparams["max_position_embeddings"]
embed_dim = hparams["hidden_size"]
num_heads = hparams["num_attention_heads"]
num_kv_heads = hparams.get("num_key_value_heads", num_heads)
layer_norm_eps = hparams["layer_norm_epsilon"]
intermediate_size = hparams["intermediate_size"] if "intermediate_size" in hparams else 4 * embed_dim
num_layers = hparams["num_layers"]
# ignore for now as EXAONE-3.0-7.8B-Instruct attention_dropout is 0.0
# attention_dropout_rate = hparams["attention_dropout"]
# ignore for now as EXAONE-3.0-7.8B-Instruct embed_dropout is 0.0
# embed_dropout_rate = hparams["embed_dropout"]
self.gguf_writer.add_embedding_length(embed_dim)
self.gguf_writer.add_head_count(num_heads)
self.gguf_writer.add_head_count_kv(num_kv_heads)
self.gguf_writer.add_context_length(max_position_embeddings)
self.gguf_writer.add_layer_norm_rms_eps(layer_norm_eps)
self.gguf_writer.add_feed_forward_length(intermediate_size)
self.gguf_writer.add_block_count(num_layers)
self.gguf_writer.add_file_type(self.ftype)
if (rope_theta := self.hparams.get("rope_theta")) is not None:
self.gguf_writer.add_rope_freq_base(rope_theta)
rotary_factor = self.find_hparam(["partial_rotary_factor", "rope_pct"], optional=True)
rotary_factor = rotary_factor if rotary_factor is not None else 1.0
self.gguf_writer.add_rope_dimension_count(int(rotary_factor * (hparams["hidden_size"] // hparams["num_attention_heads"])))
if hparams.get("rope_scaling") is not None and "factor" in hparams["rope_scaling"]:
if hparams["rope_scaling"].get("type") == "linear":
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
self.gguf_writer.add_rope_scaling_factor(hparams["rope_scaling"]["factor"])
def prepare_tensors(self):
if rope_scaling := self.find_hparam(["rope_scaling"], optional=True):
if rope_scaling.get("rope_type", '').lower() == "llama3":
base = self.hparams.get("rope_theta", 10000.0)
dim = self.hparams["hidden_size"] // self.hparams["num_attention_heads"]
freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
factor = rope_scaling.get("factor", 8.0)
low_freq_factor = rope_scaling.get("low_freq_factor", 1.0)
high_freq_factor = rope_scaling.get("high_freq_factor", 4.0)
old_context_len = self.hparams.get("original_max_position_embeddings", 8192)
low_freq_wavelen = old_context_len / low_freq_factor
high_freq_wavelen = old_context_len / high_freq_factor
assert low_freq_wavelen != high_freq_wavelen
rope_factors = []
for freq in freqs:
wavelen = 2 * math.pi / freq
if wavelen < high_freq_wavelen:
rope_factors.append(1)
elif wavelen > low_freq_wavelen:
rope_factors.append(factor)
else:
smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
rope_factors.append(1 / ((1 - smooth) / factor + smooth))
self.gguf_writer.add_tensor(self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FREQS), np.array(rope_factors, dtype=np.float32))
super().prepare_tensors()
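For reference, the loop above implements the llama3-style piecewise RoPE frequency scaling: wavelengths below the high-frequency cutoff $\lambda_{\mathrm{high}} = L_{\mathrm{old}}/f_{\mathrm{high}}$ keep a factor of 1, wavelengths above the low-frequency cutoff $\lambda_{\mathrm{low}} = L_{\mathrm{old}}/f_{\mathrm{low}}$ get the full factor, and the band in between is interpolated:

$$
s(\lambda) = \begin{cases} 1, & \lambda < \lambda_{\mathrm{high}} \\ \mathrm{factor}, & \lambda > \lambda_{\mathrm{low}} \\ \left(\frac{1-\mu}{\mathrm{factor}} + \mu\right)^{-1}, & \text{otherwise} \end{cases} \qquad \mu = \frac{L_{\mathrm{old}}/\lambda - f_{\mathrm{low}}}{f_{\mathrm{high}} - f_{\mathrm{low}}}
$$

with $L_{\mathrm{old}}$ the original_max_position_embeddings.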
###### CONVERSION LOGIC ######
# tree of lazy tensors
class LazyTorchTensor(gguf.LazyBase):


@ -94,6 +94,9 @@ models = [
{"name": "codeshell", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/WisdomShell/CodeShell-7B", },
{"name": "tekken", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/mistralai/Mistral-Nemo-Base-2407", },
{"name": "smollm", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/HuggingFaceTB/SmolLM-135M", },
{'name': "bloom", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/bigscience/bloom", },
{'name': "gpt3-finnish", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/TurkuNLP/gpt3-finnish-small", },
{"name": "exaone", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct", },
]


@ -116,7 +116,7 @@ class Tensor:
assert quant is not None, 'Unknown tensor type'
(blksize, tysize) = quant
offset += 12
self.dtype = gguf.GGMLQuantizationType(dtype)
self.dims = struct.unpack(f'<{n_dims}I', data[offset:offset + (4 * n_dims)])
offset += 4 * n_dims
self.name = bytes(data[offset:offset + name_len])

docs/backend/CANN.md

@ -0,0 +1,259 @@
# llama.cpp for CANN
- [Background](#background)
- [News](#news)
- [OS](#os)
- [Hardware](#hardware)
- [Model Supports](#model-supports)
- [DataType Supports](#datatype-supports)
- [Docker](#docker)
- [Linux](#linux)
- [TODO](#todo)
## Background
**Ascend NPU** is a range of AI processors built around a Neural Processing Unit. They efficiently handle matrix-matrix multiplication, dot products and scalar operations.
**CANN** (Compute Architecture for Neural Networks) is a heterogeneous computing architecture for AI scenarios, providing support for multiple AI frameworks on the top and serving AI processors and programming at the bottom. It plays a crucial role in bridging the gap between upper and lower layers, and is a key platform for improving the computing efficiency of Ascend AI processors. Meanwhile, it offers a highly efficient and easy-to-use programming interface for diverse application scenarios, allowing users to rapidly build AI applications and services based on the Ascend platform.
**Llama.cpp + CANN**
The llama.cpp CANN backend is designed to support the Ascend NPU. It uses the AscendC and ACLNN capabilities integrated into the CANN Toolkit and kernels to drive the Ascend NPU directly.
## News
- 2024.8
- Support `Q4_0` and `Q8_0` data type for Ascend NPU.
- 2024.7
- Create CANN backend for Ascend NPU.
## OS
| OS | Status | Verified |
|:-------:|:-------:|:----------------------------------------------:|
| Linux | Support | Ubuntu 22.04, OpenEuler22.03 |
## Hardware
### Ascend NPU
**Verified devices**
| Ascend NPU | Status |
|:-----------------------------:|:-------:|
| Atlas 300T A2 | Support |
*Notes:*
- If you have trouble with your Ascend NPU device, please create an issue with the **[CANN]** prefix/tag.
- If you run llama.cpp successfully on your Ascend NPU device, please help update the table above.
## Model Supports
| Model Name | FP16 | Q8_0 | Q4_0 |
|:----------------------------|:-----:|:----:|:----:|
| AquilaChat2-7B | √ | √ | √ |
| Baichuan-7b | √ | √ | √ |
| Baichuan2-7B-Chat | √ | √ | √ |
| bitnet_b1_58-large | √ | √ | √ |
| bloom-560m | √ | x | √ |
| bloomz-alpaca-560m | √ | x | √ |
| c4ai-command-r-35B-v01 | x | x | x |
| chatglm3-6B | x | x | x |
| chinese-alpaca-2-1.3b | √ | √ | √ |
| CodeShell-7B | √ | √ | √ |
| deepseek-ai_deepseek-coder-1.3B-base | x | x | x |
| deepseek-ai_DeepSeek-V2-Lite | x | x | x |
| deepseek-coder-6.7B-instruct | x | x | x |
| DeepSeek-V2-Lite-64x1.5B | x | x | x |
| falcon-7b-instruct | √ | √ | √ |
| flan-t5-large | √ | √ | √ |
| gemma-2-9b-it | √ | √ | √ |
| glm-4-9B | x | x | x |
| gpt2 | √ | √ | √ |
| Gpt2-163M | √ | √ | √ |
| granite-3B-code-instruct | √ | √ | √ |
| GritLM-7B | √ | √ | √ |
| internlm2_5-7b-chat | √ | √ | √ |
| koala-7B-HF | √ | √ | √ |
| Llama-2-7b-chat-hf | √ | √ | √ |
| Llama-3-Smaug-8B | √ | √ | √ |
| Llama2-Chinese-7b-Chat | √ | √ | √ |
| Llama3-8B | √ | √ | √ |
| Llama3-8b-chinese | √ | √ | √ |
| mamba-130m-hf | √ | √ | √ |
| Mistral-7B-Instruct-v0.2 | √ | √ | √ |
| Mixtral-8x7B-Instruct-v0.1 | x | √ | √ |
| mpt-7B | √ | √ | √ |
| OLMo-1B-hf | √ | √ | √ |
| OpenELM-3B-Instruct | √ | √ | √ |
| Orion-14b-base | √ | √ | √ |
| phi1 | x | x | x |
| phi2 | x | x | x |
| Phi-3-mini-4k-instruct | √ | √ | √ |
| plamo-13b | √ | √ | √ |
| pythia-70M | x | x | x |
| Qwen-7B | √ | √ | √ |
| Qwen2-1.5B-Instruct | √ | x | √ |
| Refact-1_6B-fim | √ | √ | √ |
| SmolLM-135M | √ | √ | √ |
| stablelm-zephyr | x | x | x |
| stablelm-2-zephyr-1_6b | x | x | x |
| starcoderbase-1b | √ | √ | √ |
| starcoder2-3b | √ | √ | √ |
| vigogne-7b-chat | √ | √ | √ |
| xverse-7b-chat | √ | √ | √ |
| Yi-6b-Chat | √ | √ | √ |
## DataType Supports
| DataType | Status |
|:----------------------:|:-------:|
| FP16 | Support |
| Q8_0 | Support |
| Q4_0 | Support |
## Docker
### Build Images
You can get an image with llama.cpp in one command.
```sh
docker build -t llama-cpp-cann -f .devops/llama-cli-cann.Dockerfile .
```
### Run container
```sh
# Find all cards.
npu-smi info
# Select the cards that you want to use; make sure these cards are not in use by someone else.
# The following example uses the card on device 0.
docker run --name llamacpp --device /dev/davinci0 --device /dev/davinci_manager --device /dev/devmm_svm --device /dev/hisi_hdc -v /usr/local/dcmi:/usr/local/dcmi -v /usr/local/bin/npu-smi:/usr/local/bin/npu-smi -v /usr/local/Ascend/driver/lib64/:/usr/local/Ascend/driver/lib64/ -v /usr/local/Ascend/driver/version.info:/usr/local/Ascend/driver/version.info -v /PATH_TO_YOUR_MODELS/:/app/models -it llama-cpp-cann -m /app/models/MODEL_PATH -ngl 32 -p "Building a website can be done in 10 simple steps:"
```
*Notes:*
- You may need to install Ascend Driver and firmware on the **host** machine *(Please refer to the [Linux configuration](#linux) for details)*.
## Linux
### I. Setup Environment
1. **Install Ascend Driver and firmware**
```sh
# create driver running user.
sudo groupadd HwHiAiUser
sudo useradd -g HwHiAiUser -d /home/HwHiAiUser -m HwHiAiUser -s /bin/bash
sudo usermod -aG HwHiAiUser $USER
# download driver from https://www.hiascend.com/hardware/firmware-drivers/community according to your system
# and install driver.
sudo sh Ascend-hdk-910b-npu-driver_x.x.x_linux-{arch}.run --full --install-for-all
```
Once installed, run `npu-smi info` to check whether the driver was installed successfully.
```sh
+-------------------------------------------------------------------------------------------+
| npu-smi 24.1.rc2 Version: 24.1.rc2 |
+----------------------+---------------+----------------------------------------------------+
| NPU Name | Health | Power(W) Temp(C) Hugepages-Usage(page)|
| Chip | Bus-Id | AICore(%) Memory-Usage(MB) HBM-Usage(MB) |
+======================+===============+====================================================+
| 2 xxx | OK | 64.4 51 15 / 15 |
| 0 | 0000:01:00.0 | 0 1873 / 15077 0 / 32768 |
+======================+===============+====================================================+
| 5 xxx | OK | 64.0 52 15 / 15 |
| 0 | 0000:81:00.0 | 0 1874 / 15077 0 / 32768 |
+======================+===============+====================================================+
| No running processes found in NPU 2 |
+======================+===============+====================================================+
| No running processes found in NPU 5 |
+======================+===============+====================================================+
```
2. **Install Ascend Firmware**
```sh
# download firmware from https://www.hiascend.com/hardware/firmware-drivers/community according to your system
# and install firmware.
sudo sh Ascend-hdk-910b-npu-firmware_x.x.x.x.X.run --full
```
If the following message appears, the firmware has been installed successfully.
```sh
Firmware package installed successfully!
```
3. **Install CANN toolkit and kernels**
CANN toolkit and kernels can be obtained from the official [CANN Toolkit](https://www.hiascend.com/zh/developer/download/community/result?module=cann) page.
Please download the version that matches your system. The minimum version required is 8.0.RC2.alpha002, and here is the install command:
```sh
pip3 install attrs numpy decorator sympy cffi pyyaml pathlib2 psutil protobuf scipy requests absl-py wheel typing_extensions
sh Ascend-cann-toolkit_8.0.RC2.alpha002_linux-aarch64.run --install
sh Ascend-cann-kernels-910b_8.0.RC2.alpha002_linux.run --install
```
Set Ascend Variables:
```sh
echo "source ~/Ascend/ascend-toolkit/set_env.sh" >> ~/.bashrc
source ~/.bashrc
```
Upon a successful installation, CANN is enabled for the available Ascend devices.
### II. Build llama.cpp
```sh
cmake -B build -DGGML_CANN=on -DCMAKE_BUILD_TYPE=release
cmake --build build --config release
```
### III. Run the inference
1. **Retrieve and prepare model**
You can refer to the general [*Prepare and Quantize*](../../README.md#prepare-and-quantize) guide for model preparation.
**Notes**:
- The CANN backend currently supports only FP16, Q4_0 and Q8_0 models.
2. **Launch inference**
There are two device selection modes:
- Single device: Use one device target specified by the user.
- Multiple devices: Automatically choose the devices with the same backend.
| Device selection | Parameter |
|:----------------:|:--------------------------------------:|
| Single device | --split-mode none --main-gpu DEVICE_ID |
| Multiple devices | --split-mode layer (default) |
Examples:
- Use device 0:
```sh
./build/bin/llama-cli -m path_to_model -p "Building a website can be done in 10 simple steps:" -n 400 -e -ngl 33 -sm none -mg 0
```
- Use multiple devices:
```sh
./build/bin/llama-cli -m path_to_model -p "Building a website can be done in 10 simple steps:" -n 400 -e -ngl 33 -sm layer
```
### **GitHub contribution**:
Please add the **[CANN]** prefix/tag in issues/PRs titles to help the CANN-team check/address them without delay.
## TODO
- Support more models and data types.


@ -20,7 +20,7 @@
**oneAPI** is an open ecosystem and a standard-based specification, supporting multiple architectures including but not limited to intel CPUs, GPUs and FPGAs. The key components of the oneAPI ecosystem include:
- **DPCPP** *(Data Parallel C++)*: The primary oneAPI SYCL implementation, which includes the icpx/icx Compilers.
- **oneAPI Libraries**: A set of highly optimized libraries targeting multiple domains *(e.g. oneMKL and oneDNN)*.
- **oneAPI LevelZero**: A high performance low level interface for fine-grained control over intel iGPUs and dGPUs.
- **Nvidia & AMD Plugins**: These are plugins extending oneAPI's DPCPP support to SYCL on Nvidia and AMD GPU targets.
@ -28,10 +28,6 @@
The llama.cpp SYCL backend is designed to support **Intel GPU** firstly. Based on the cross-platform feature of SYCL, it could support other vendor GPUs: Nvidia GPU (*AMD GPU coming*).
When targeting **Intel CPU**, it is recommended to use llama.cpp for [Intel oneMKL](README.md#intel-onemkl) backend.
It has the similar design of other llama.cpp BLAS-based paths such as *OpenBLAS, cuBLAS, etc..*. In beginning work, the oneAPI's [SYCLomatic](https://github.com/oneapi-src/SYCLomatic) open-source migration tool (Commercial release [Intel® DPC++ Compatibility Tool](https://www.intel.com/content/www/us/en/developer/tools/oneapi/dpc-compatibility-tool.html)) was used for this purpose.
## Recommended Release
The SYCL backend would be broken by some PRs due to no online CI.
@ -45,6 +41,10 @@ The following release is verified with good quality:
## News
- 2024.8
- Use oneDNN as the default GEMM library, improve the compatibility for new Intel GPUs.
- 2024.5
- Performance is increased: 34 -> 37 tokens/s of llama-2-7b.Q4_0 on Arc770.
- Arch Linux is verified successfully.
@ -196,7 +196,7 @@ Please follow the instructions for downloading and installing the Toolkit for Li
Following guidelines/code snippets assume the default installation values. Otherwise, please make sure the necessary changes are reflected where applicable.
Upon a successful installation, SYCL is enabled for the available intel devices, along with relevant libraries such as oneAPI oneDNN for Intel GPUs.
- **Adding support to Nvidia GPUs**
@ -255,8 +255,6 @@ or
# Export relevant ENV variables
source /opt/intel/oneapi/setvars.sh
# Build LLAMA with MKL BLAS acceleration for intel GPU
# Option 1: Use FP32 (recommended for better performance in most cases)
cmake -B build -DGGML_SYCL=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx


@ -352,6 +352,31 @@ cmake --build build --config Release
# ggml_vulkan: Using Intel(R) Graphics (ADL GT2) | uma: 1 | fp16: 1 | warp size: 32
```
### CANN
This provides NPU acceleration using the AI cores of your Ascend NPU. [CANN](https://www.hiascend.com/en/software/cann) is a hierarchical API that helps you quickly build AI applications and services on the Ascend NPU.
For more information about the Ascend NPU, visit the [Ascend Community](https://www.hiascend.com/en/).
Make sure to have the CANN toolkit installed. You can download it from here: [CANN Toolkit](https://www.hiascend.com/developer/download/community/result?module=cann)
Go to `llama.cpp` directory and build using CMake.
```bash
cmake -B build -DGGML_CANN=on -DCMAKE_BUILD_TYPE=release
cmake --build build --config release
```
You can test with:
`./build/llama-cli -m PATH_TO_MODEL -p "Building a website can be done in 10 steps:" -ngl 32`
If the following info is printed on screen, llama.cpp is using the CANN backend:
```bash
llm_load_tensors: CANN buffer size = 13313.00 MiB
llama_new_context_with_model: CANN compute buffer size = 1260.81 MiB
```
For detailed info, such as model/device supports, CANN install, please refer to [llama.cpp for CANN](./backend/CANN.md).
### Android
To read documentation for how to build on Android, [click here](./android.md)


@ -271,7 +271,7 @@ struct tokenized_prompt {
size_t max_seq_len;
tokenized_prompt(llama_context * ctx, std::string pos, std::string neg) {
const bool add_bos = llama_add_bos_token(llama_get_model(ctx));
tokens_pos = ::llama_tokenize(ctx, pos, add_bos, true);
tokens_neg = ::llama_tokenize(ctx, neg, add_bos, true);
max_seq_len = std::max(tokens_pos.size(), tokens_neg.size());


@ -127,7 +127,7 @@ static bool ggml_debug(struct ggml_tensor * t, bool ask, void * user_data) {
}
static bool run(llama_context * ctx, const gpt_params & params) {
const bool add_bos = llama_add_bos_token(llama_get_model(ctx));
std::vector<llama_token> tokens = ::llama_tokenize(ctx, params.prompt, add_bos);


@ -433,8 +433,8 @@ static void process_logits(
}
static bool compute_imatrix(llama_context * ctx, const gpt_params & params) {
const bool add_bos = llama_add_bos_token(llama_get_model(ctx));
GGML_ASSERT(!llama_add_eos_token(llama_get_model(ctx)));
const int n_ctx = llama_n_ctx(ctx);
auto tim1 = std::chrono::high_resolution_clock::now();


@ -203,8 +203,8 @@ int main(int argc, char ** argv) {
LOG_TEE("\n");
LOG_TEE("%s\n", gpt_params_get_system_info(params).c_str());
}
const bool add_bos = llama_add_bos_token(model);
GGML_ASSERT(!llama_add_eos_token(model));
LOG("add_bos: %d\n", add_bos);
std::vector<llama_token> embd_inp;


@ -16,8 +16,8 @@ Convert PyTorch model to gguf files (You can also download the converted [gguf](
```bash
python ./examples/minicpmv/minicpmv-surgery.py -m ../MiniCPM-Llama3-V-2_5
python ./examples/minicpmv/minicpmv-convert-image-encoder-to-gguf.py -m ../MiniCPM-Llama3-V-2_5 --minicpmv-projector ../MiniCPM-Llama3-V-2_5/minicpmv.projector --output-dir ../MiniCPM-Llama3-V-2_5/ --image-mean 0.5 0.5 0.5 --image-std 0.5 0.5 0.5 --minicpmv_version 2
python ./convert_hf_to_gguf.py ../MiniCPM-Llama3-V-2_5/model
# quantize int4 version
./llama-quantize ../MiniCPM-Llama3-V-2_5/model/model-8B-F16.gguf ../MiniCPM-Llama3-V-2_5/model/ggml-model-Q4_K_M.gguf Q4_K_M


@ -0,0 +1,107 @@
## MiniCPM-V 2.6
### Prepare models and code
Download the [MiniCPM-V-2_6](https://huggingface.co/openbmb/MiniCPM-V-2_6) PyTorch model from Hugging Face into the "MiniCPM-V-2_6" folder.
Clone llama.cpp:
```bash
git clone git@github.com:OpenBMB/llama.cpp.git
cd llama.cpp
git checkout minicpmv-main
```
### Usage of MiniCPM-V 2.6
Convert the PyTorch model to gguf files (you can also download the converted [gguf](https://huggingface.co/openbmb/MiniCPM-V-2_6-gguf) files from us):
```bash
python ./examples/llava/minicpmv-surgery.py -m ../MiniCPM-V-2_6
python ./examples/llava/minicpmv-convert-image-encoder-to-gguf.py -m ../MiniCPM-V-2_6 --minicpmv-projector ../MiniCPM-V-2_6/minicpmv.projector --output-dir ../MiniCPM-V-2_6/ --image-mean 0.5 0.5 0.5 --image-std 0.5 0.5 0.5 --minicpmv_version 3
python ./convert_hf_to_gguf.py ../MiniCPM-V-2_6/model
# quantize int4 version
./llama-quantize ../MiniCPM-V-2_6/model/ggml-model-f16.gguf ../MiniCPM-V-2_6/model/ggml-model-Q4_K_M.gguf Q4_K_M
```
Build for Linux or Mac
```bash
make
make llama-minicpmv-cli
```
Inference on Linux or Mac
```
# run f16 version
./llama-minicpmv-cli -m ../MiniCPM-V-2_6/model/ggml-model-f16.gguf --mmproj ../MiniCPM-V-2_6/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg -p "What is in the image?"
# run quantized int4 version
./llama-minicpmv-cli -m ../MiniCPM-V-2_6/model/ggml-model-Q4_K_M.gguf --mmproj ../MiniCPM-V-2_6/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg -p "What is in the image?"
# or run in interactive mode
./llama-minicpmv-cli -m ../MiniCPM-V-2_6/model/ggml-model-Q4_K_M.gguf --mmproj ../MiniCPM-V-2_6/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg -i
```
### Video
Install FFmpeg
```
brew install ffmpeg
brew install pkg-config
```
### Android
#### Build on Android device using Termux
We found that building on the Android device yields better runtime performance, so we recommend building on the device.
[Termux](https://github.com/termux/termux-app#installation) is a terminal app on Android device (no root required).
Install tools in Termux:
```
apt update && apt upgrade -y
apt install git make cmake
```
It's recommended to move your model inside the `~/` directory for best performance:
```
cd storage/downloads
mv model.gguf ~/
```
#### Building the Project using Android NDK
Obtain the [Android NDK](https://developer.android.com/ndk) and then build with CMake.
Execute the following commands on your computer to avoid downloading the NDK to your mobile. Alternatively, you can also do this in Termux:
```bash
mkdir build-android
cd build-android
export NDK=/your_ndk_path
cmake -DCMAKE_TOOLCHAIN_FILE=$NDK/build/cmake/android.toolchain.cmake -DANDROID_ABI=arm64-v8a -DANDROID_PLATFORM=android-23 -DCMAKE_C_FLAGS=-march=armv8.4a+dotprod ..
make
```
Install [termux](https://github.com/termux/termux-app#installation) on your device and run `termux-setup-storage` to get access to your SD card (if Android 11+ then run the command twice).
Finally, copy these built `llama` binaries and the model file to your device storage. Because the file permissions in the Android sdcard cannot be changed, you can copy the executable files to the `/data/data/com.termux/files/home/bin` path, and then execute the following commands in Termux to add executable permission:
(This assumes you have pushed the built executable files to the /sdcard/llama.cpp/bin path using `adb push`.)
```
$ cp -r /sdcard/llama.cpp/bin /data/data/com.termux/files/home/
$ cd /data/data/com.termux/files/home/bin
$ chmod +x ./*
```
Download models and push them to `/sdcard/llama.cpp/`, then move them to `/data/data/com.termux/files/home/model/`
```
$ mv /sdcard/llama.cpp/ggml-model-Q4_K_M.gguf /data/data/com.termux/files/home/model/
$ mv /sdcard/llama.cpp/mmproj-model-f16.gguf /data/data/com.termux/files/home/model/
```
Now, you can start chatting:
```
$ cd /data/data/com.termux/files/home/bin
$ ./llama-minicpmv-cli -m ../model/ggml-model-Q4_K_M.gguf --mmproj ../model/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg -p "What is in the image?"
```


@ -20,6 +20,10 @@
#include "ggml-cann.h"
#endif
#ifdef GGML_USE_VULKAN
#include "ggml-vulkan.h"
#endif
#define STB_IMAGE_IMPLEMENTATION
#include "stb_image.h"
@ -81,6 +85,7 @@ static std::string format(const char * fmt, ...) {
#define KEY_HAS_VIS_ENC "clip.has_vision_encoder"
#define KEY_HAS_LLAVA_PROJ "clip.has_llava_projector"
#define KEY_HAS_MINICPMV_PROJ "clip.has_minicpmv_projector"
#define KEY_MINICPMV_VERSION "clip.minicpmv_version"
#define KEY_USE_GELU "clip.use_gelu"
#define KEY_N_EMBD "clip.%s.embedding_length"
#define KEY_N_FF "clip.%s.feed_forward_length"
@ -526,6 +531,7 @@ struct clip_ctx {
bool has_vision_encoder = false;
bool has_llava_projector = false;
bool has_minicpmv_projector = false;
int minicpmv_version = 2;
struct clip_vision_model vision_model;
projector_type proj_type = PROJECTOR_TYPE_MLP;
@ -641,7 +647,12 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
if (ctx->has_minicpmv_projector) {
int pos_w = image_size_width/patch_size;
int pos_h = image_size_height/patch_size;
if (ctx->minicpmv_version == 2) {
pos_embed = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 4096, pos_w * pos_h, 1);
}
else if (ctx->minicpmv_version == 3) {
pos_embed = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 3584, pos_w * pos_h, 1);
}
ggml_set_name(pos_embed, "pos_embed");
ggml_set_input(pos_embed);
}
@ -768,8 +779,8 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
embeddings = ggml_gelu(ctx0, embeddings);
embeddings = ggml_mul_mat(ctx0, model.mm_2_w, embeddings);
embeddings = ggml_add(ctx0, embeddings, model.mm_2_b);
}
else if (ctx->proj_type == PROJECTOR_TYPE_MLP_NORM) {
embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings);
embeddings = ggml_add(ctx0, embeddings, model.mm_0_b);
// ggml_tensor_printf(embeddings, "mm_0_w",0,true,false);
@ -949,10 +960,20 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
}
{ // attention
int hidden_size = 4096;
const int d_head = 128;
int n_head = hidden_size/d_head;
int num_query = 96;
if (ctx->minicpmv_version == 2) {
hidden_size = 4096;
n_head = hidden_size/d_head;
num_query = 96;
}
else if (ctx->minicpmv_version == 3) {
hidden_size = 3584;
n_head = hidden_size/d_head;
num_query = 64;
}
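// note: version 2 corresponds to MiniCPM-Llama3-V 2.5 (4096-wide LLM) and
// version 3 to MiniCPM-V 2.6 (3584-wide), hence the different resampler dims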
struct ggml_tensor * Q = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_q_w, q), model.mm_model_attn_q_b);
Q = ggml_scale_inplace(ctx0, Q, 1.0f / sqrt((float)d_head));
@ -1091,7 +1112,7 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
}
}
clip_ctx * new_clip = new clip_ctx{};
// update projector type
{
@ -1125,6 +1146,10 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
LOG_TEE("%s: CLIP using CANN backend\n", __func__);
#endif
#ifdef GGML_USE_VULKAN
new_clip->backend = ggml_backend_vk_init(0);
LOG_TEE("%s: CLIP using Vulkan backend\n", __func__);
#endif
if (!new_clip->backend) {
new_clip->backend = ggml_backend_cpu_init();
@ -1149,6 +1174,11 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
new_clip->has_minicpmv_projector = gguf_get_val_bool(ctx, idx);
}
idx = gguf_find_key(ctx, KEY_MINICPMV_VERSION);
if (idx != -1) {
new_clip->minicpmv_version = gguf_get_val_i32(ctx, idx);
}
// GGML_ASSERT(new_clip->has_llava_projector); // see monatis/clip.cpp for image and/or text encoding for semantic search
GGML_ASSERT(new_clip->has_vision_encoder);
@ -1910,10 +1940,12 @@ int clip_uhd_num_image_embeds_col(struct clip_ctx * ctx_clip) {
// returns the normalized float tensor for llava-1.5, for spatial_unpad with anyres processing for llava-1.6 it returns the normalized image patch tensors as a vector
// res_imgs memory is being allocated here, previous allocations will be freed if found
bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, clip_image_f32_batch * res_imgs) {
if(clip_is_minicpmv(ctx)){
int max_slice_nums = 9;
std::vector<std::vector<clip_image_u8 *>> imgs = uhd_slice_image(img, max_slice_nums);
res_imgs->size = 0;
for (size_t i = 0; i < imgs.size(); ++i){
res_imgs->size += imgs[i].size();
}
res_imgs->data = new clip_image_f32[res_imgs->size];
@ -2146,7 +2178,12 @@ int clip_n_patches(const struct clip_ctx * ctx) {
if (ctx->proj_type == PROJECTOR_TYPE_LDP || ctx->proj_type == PROJECTOR_TYPE_LDPV2) {
n_patches /= 4;
} else if (ctx->proj_type == PROJECTOR_TYPE_RESAMPLER) {
if (ctx->minicpmv_version == 2) {
n_patches = 96;
}
else if (ctx->minicpmv_version == 3) {
n_patches = 64;
}
}
return n_patches;
@ -2282,6 +2319,11 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
const int patch_size = hparams.patch_size;
const int num_patches = ((image_size_width / patch_size) * (image_size_height / patch_size));
const int num_positions = num_patches + (ctx->has_class_embedding ? 1 : 0);
if(ctx->load_image_size==nullptr){
ctx->load_image_size= clip_image_size_init();
}
const int pos_w = ctx->load_image_size->width/patch_size;
const int pos_h = ctx->load_image_size->height/patch_size;
{
struct ggml_tensor * inp_raw = ggml_graph_get_tensor(gf, "inp_raw");
@ -2316,8 +2358,18 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
// -> https://huggingface.co/HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit/blob/d66538faeba44480d0bfaa42145eef26f9423199/modeling_siglip.py#L316
struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions");
int* positions_data = (int*)malloc(ggml_nbytes(positions));
int bucket_coords_h[70];
int bucket_coords_w[70];
for (int i = 0; i < pos_h; i++){
bucket_coords_h[i] = std::floor(70.0*i/pos_h);
}
for (int i = 0; i < pos_w; i++){
bucket_coords_w[i] = std::floor(70.0*i/pos_w);
}
for (int i = 0, id = 0; i < pos_h; i++){
for (int j = 0; j < pos_w; j++){
positions_data[id++] = bucket_coords_h[i]*70 + bucket_coords_w[j];
}
} }
ggml_backend_tensor_set(positions, positions_data, 0, ggml_nbytes(positions)); ggml_backend_tensor_set(positions, positions_data, 0, ggml_nbytes(positions));
free(positions_data); free(positions_data);
@@ -2328,12 +2380,13 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
            // -> https://huggingface.co/Qwen/Qwen-VL/tree/main
            // -> https://huggingface.co/Qwen/Qwen-VL/blob/0547ed36a86561e2e42fecec8fd0c4f6953e33c4/visual.py#L23
            struct ggml_tensor * pos_embed = ggml_graph_get_tensor(gf, "pos_embed");
            int embed_dim = 4096;
            if (ctx->minicpmv_version == 2) {
                embed_dim = 4096;
            }
            else if (ctx->minicpmv_version == 3) {
                embed_dim = 3584;
            }
            auto pos_embed_t = get_2d_sincos_pos_embed(embed_dim, std::make_pair(pos_w, pos_h));

            float * pos_embed_data = (float *)malloc(ggml_nbytes(pos_embed));
@@ -2346,7 +2399,8 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
            ggml_backend_tensor_set(pos_embed, pos_embed_data, 0, ggml_nbytes(pos_embed));
            free(pos_embed_data);
        }
    }
    else {
        {
            if (ctx->has_class_embedding) {
                struct ggml_tensor * embeddings = ggml_graph_get_tensor(gf, "embeddings");
@@ -2548,13 +2602,21 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
        return ctx->vision_model.mm_3_b->ne[0];
    }
    if (ctx->proj_type == PROJECTOR_TYPE_RESAMPLER) {
        if (ctx->minicpmv_version == 2) {
            return 4096;
        }
        else if (ctx->minicpmv_version == 3) {
            return 3584;
        }
    }

    std::string proj_type = PROJECTOR_TYPE_NAMES[ctx->proj_type];
    throw std::runtime_error(format("%s: don't support projector with: %s currently\n", __func__, proj_type.c_str()));
}

int clip_is_minicpmv(const struct clip_ctx * ctx) {
    if (ctx->has_minicpmv_projector) {
        return ctx->minicpmv_version;
    }
    return 0;
}

View File

@@ -85,7 +85,7 @@ CLIP_API bool clip_image_batch_encode(struct clip_ctx * ctx, int n_threads, cons
CLIP_API bool clip_model_quantize(const char * fname_inp, const char * fname_out, int itype);

CLIP_API int clip_is_minicpmv(const struct clip_ctx * ctx);

#ifdef __cplusplus
}

View File

@@ -256,7 +256,14 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli
            load_image_size->width = img_res_v.data[i].nx;
            load_image_size->height = img_res_v.data[i].ny;
            clip_add_load_image_size(ctx_clip, load_image_size);
            bool encoded = false;
            int has_minicpmv_projector = clip_is_minicpmv(ctx_clip);
            if (has_minicpmv_projector == 2) {
                // MiniCPM-Llama3-V 2.5 reshapes the image patches before encoding
                encoded = clip_image_encode(ctx_clip, n_threads, only_v2_5_reshape_by_patch(&img_res_v.data[i], patch_size), image_embd_v[i]);
            }
            else if (has_minicpmv_projector == 3) {
                encoded = clip_image_encode(ctx_clip, n_threads, &img_res_v.data[i], image_embd_v[i]);
            }
            if (!encoded) {
                LOG_TEE("Unable to encode image - spatial_unpad - subimage %d of %d\n", (int) i+1, (int) img_res_v.size);
                return false;

View File

@@ -134,7 +134,13 @@ static void process_image(struct llava_context * ctx_llava, struct llava_image_e
    std::string system_prompt;
    int idx = 0;
    int num_image_embeds = embeds->n_image_pos / clip_n_patches(ctx_llava->ctx_clip);
    int has_minicpmv_projector = clip_is_minicpmv(ctx_llava->ctx_clip);
    if (has_minicpmv_projector == 2) {
        system_prompt = "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n";
    }
    else if (has_minicpmv_projector == 3) {
        system_prompt = "<|im_start|>user\n";
    }
    LOG_TEE("%s: image token past: %d\n", __func__, n_past);
    eval_string(ctx_llava->ctx_llama, (system_prompt+"<image>").c_str(), params->n_batch, &n_past, false);
    process_eval_image_embed(ctx_llava, embeds, params->n_batch, &n_past, idx++);
@@ -210,10 +216,24 @@ static struct llava_context * minicpmv_init(gpt_params * params, const std::stri
static struct llama_sampling_context * llama_init(struct llava_context * ctx_llava, gpt_params * params, std::string prompt, int &n_past, bool is_first = false){
    std::string user_prompt = prompt;
    int has_minicpmv_projector = clip_is_minicpmv(ctx_llava->ctx_clip);
    if (!is_first) {
        if (has_minicpmv_projector == 2) {
            user_prompt = "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n" + prompt;
        }
        else if (has_minicpmv_projector == 3) {
            user_prompt = "<|im_start|>user\n" + prompt;
        }
    }

    eval_string(ctx_llava->ctx_llama, user_prompt.c_str(), params->n_batch, &n_past, false);
    if (has_minicpmv_projector == 2) {
        eval_string(ctx_llava->ctx_llama, "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", params->n_batch, &n_past, false);
    }
    else if (has_minicpmv_projector == 3) {
        eval_string(ctx_llava->ctx_llama, "<|im_end|><|im_start|>assistant\n", params->n_batch, &n_past, false);
    }

    // generate the response

    LOG_TEE("\n");

View File

@@ -1,9 +1,416 @@
# coding=utf-8
# Copyright 2024 Google AI and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch Siglip model. """
# Copied from HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit and add tgt_sizes
import os
import math
import warnings
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.nn.init import _calculate_fan_in_and_fan_out
from transformers.activations import ACT2FN
from transformers.modeling_utils import PreTrainedModel
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
logger = logging.get_logger(__name__)
class SiglipVisionConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`SiglipVisionModel`]. It is used to instantiate a
Siglip vision encoder according to the specified arguments, defining the model architecture. Instantiating a
configuration with the defaults will yield a similar configuration to that of the vision encoder of the Siglip
[google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
hidden_size (`int`, *optional*, defaults to 768):
Dimensionality of the encoder layers and the pooler layer.
intermediate_size (`int`, *optional*, defaults to 3072):
Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
num_hidden_layers (`int`, *optional*, defaults to 12):
Number of hidden layers in the Transformer encoder.
num_attention_heads (`int`, *optional*, defaults to 12):
Number of attention heads for each attention layer in the Transformer encoder.
num_channels (`int`, *optional*, defaults to 3):
Number of channels in the input images.
image_size (`int`, *optional*, defaults to 224):
The size (resolution) of each image.
patch_size (`int`, *optional*, defaults to 16):
The size (resolution) of each patch.
hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
`"relu"`, `"selu"` and `"gelu_new"` ``"quick_gelu"` are supported.
layer_norm_eps (`float`, *optional*, defaults to 1e-06):
The epsilon used by the layer normalization layers.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
Example:
```python
>>> from transformers import SiglipVisionConfig, SiglipVisionModel
>>> # Initializing a SiglipVisionConfig with google/siglip-base-patch16-224 style configuration
>>> configuration = SiglipVisionConfig()
>>> # Initializing a SiglipVisionModel (with random weights) from the google/siglip-base-patch16-224 style configuration
>>> model = SiglipVisionModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "siglip_vision_model"
def __init__(
self,
hidden_size=768,
intermediate_size=3072,
num_hidden_layers=12,
num_attention_heads=12,
num_channels=3,
image_size=224,
patch_size=16,
hidden_act="gelu_pytorch_tanh",
layer_norm_eps=1e-6,
attention_dropout=0.0,
**kwargs,
):
super().__init__(**kwargs)
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.num_channels = num_channels
self.patch_size = patch_size
self.image_size = image_size
self.attention_dropout = attention_dropout
self.layer_norm_eps = layer_norm_eps
self.hidden_act = hidden_act
_CHECKPOINT_FOR_DOC = "google/siglip-base-patch16-224"
SIGLIP_PRETRAINED_MODEL_ARCHIVE_LIST = [
"google/siglip-base-patch16-224",
# See all SigLIP models at https://huggingface.co/models?filter=siglip
]
# Copied from transformers.models.llama.modeling_llama._get_unpad_data
def _get_unpad_data(attention_mask):
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
return (
indices,
cu_seqlens,
max_seqlen_in_batch,
)
def _trunc_normal_(tensor, mean, std, a, b):
# Cut & paste from PyTorch official master until it's in a few official releases - RW
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
def norm_cdf(x):
# Computes standard normal cumulative distribution function
return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
if (mean < a - 2 * std) or (mean > b + 2 * std):
warnings.warn(
"mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
"The distribution of values may be incorrect.",
stacklevel=2,
)
# Values are generated by using a truncated uniform distribution and
# then using the inverse CDF for the normal distribution.
# Get upper and lower cdf values
l = norm_cdf((a - mean) / std)
u = norm_cdf((b - mean) / std)
# Uniformly fill tensor with values from [l, u], then translate to
# [2l-1, 2u-1].
tensor.uniform_(2 * l - 1, 2 * u - 1)
# Use inverse cdf transform for normal distribution to get truncated
# standard normal
if tensor.dtype in [torch.float16, torch.bfloat16]:
# The `erfinv_` op is not (yet?) defined in float16+cpu, bfloat16+gpu
og_dtype = tensor.dtype
tensor = tensor.to(torch.float32)
tensor.erfinv_()
tensor = tensor.to(og_dtype)
else:
tensor.erfinv_()
# Transform to proper mean, std
tensor.mul_(std * math.sqrt(2.0))
tensor.add_(mean)
# Clamp to ensure it's in the proper range
if tensor.dtype == torch.float16:
# The `clamp_` op is not (yet?) defined in float16+cpu
tensor = tensor.to(torch.float32)
tensor.clamp_(min=a, max=b)
tensor = tensor.to(torch.float16)
else:
tensor.clamp_(min=a, max=b)
def trunc_normal_tf_(
tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0, a: float = -2.0, b: float = 2.0
):
"""Fills the input Tensor with values drawn from a truncated
normal distribution. The values are effectively drawn from the
normal distribution :math:`\\mathcal{N}(\text{mean}, \text{std}^2)`
with values outside :math:`[a, b]` redrawn until they are within
the bounds. The method used for generating the random values works
best when :math:`a \\leq \text{mean} \\leq b`.
NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the
bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0
and the result is subsquently scaled and shifted by the mean and std args.
Args:
tensor: an n-dimensional `torch.Tensor`
mean: the mean of the normal distribution
std: the standard deviation of the normal distribution
a: the minimum cutoff value
b: the maximum cutoff value
"""
with torch.no_grad():
_trunc_normal_(tensor, 0, 1.0, a, b)
tensor.mul_(std).add_(mean)
def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"):
fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
denom = fan_in
if mode == "fan_in":
denom = fan_in
elif mode == "fan_out":
denom = fan_out
elif mode == "fan_avg":
denom = (fan_in + fan_out) / 2
variance = scale / denom
if distribution == "truncated_normal":
# constant is stddev of standard normal truncated to (-2, 2)
trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978)
elif distribution == "normal":
with torch.no_grad():
tensor.normal_(std=math.sqrt(variance))
elif distribution == "uniform":
bound = math.sqrt(3 * variance)
with torch.no_grad():
tensor.uniform_(-bound, bound)
else:
raise ValueError(f"invalid distribution {distribution}")
def lecun_normal_(tensor):
variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal")
def default_flax_embed_init(tensor):
variance_scaling_(tensor, mode="fan_in", distribution="normal")
class SiglipVisionEmbeddings(nn.Module):
def __init__(self, config: SiglipVisionConfig):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.image_size = config.image_size
self.patch_size = config.patch_size
self.patch_embedding = nn.Conv2d(
in_channels=config.num_channels,
out_channels=self.embed_dim,
kernel_size=self.patch_size,
stride=self.patch_size,
padding="valid",
)
self.num_patches_per_side = self.image_size // self.patch_size
self.num_patches = self.num_patches_per_side**2
self.num_positions = self.num_patches
self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
class SiglipAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
# Copied from transformers.models.clip.modeling_clip.CLIPAttention.__init__
def __init__(self, config):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.embed_dim // self.num_heads
if self.head_dim * self.num_heads != self.embed_dim:
raise ValueError(
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
f" {self.num_heads})."
)
self.scale = self.head_dim**-0.5
self.dropout = config.attention_dropout
self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Siglip
class SiglipMLP(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.activation_fn = ACT2FN[config.hidden_act]
self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
# Copied from transformers.models.clip.modeling_clip.CLIPEncoderLayer with CLIP->Siglip
class SiglipEncoderLayer(nn.Module):
def __init__(self, config: SiglipVisionConfig):
super().__init__()
self.embed_dim = config.hidden_size
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
self.self_attn = SiglipAttention(config)
self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
self.mlp = SiglipMLP(config)
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
class SiglipPreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = SiglipVisionConfig
base_model_prefix = "siglip"
supports_gradient_checkpointing = True
def _init_weights(self, module):
"""Initialize the weights"""
if isinstance(module, SiglipVisionEmbeddings):
width = self.config.hidden_size
nn.init.normal_(module.position_embedding.weight, std=1 / np.sqrt(width))
elif isinstance(module, nn.Embedding):
default_flax_embed_init(module.weight)
elif isinstance(module, SiglipAttention):
nn.init.normal_(module.q_proj.weight)
nn.init.normal_(module.k_proj.weight)
nn.init.normal_(module.v_proj.weight)
nn.init.normal_(module.out_proj.weight)
nn.init.zeros_(module.q_proj.bias)
nn.init.zeros_(module.k_proj.bias)
nn.init.zeros_(module.v_proj.bias)
nn.init.zeros_(module.out_proj.bias)
elif isinstance(module, SiglipMLP):
nn.init.normal_(module.fc1.weight)
nn.init.normal_(module.fc2.weight)
nn.init.normal_(module.fc1.bias, std=1e-6)
nn.init.normal_(module.fc2.bias, std=1e-6)
elif isinstance(module, (nn.Linear, nn.Conv2d)):
lecun_normal_(module.weight)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
SIGLIP_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
and behavior.
Parameters:
config ([`SiglipVisionConfig`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
SIGLIP_VISION_INPUTS_DOCSTRING = r"""
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
[`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
# Copied from transformers.models.clip.modeling_clip.CLIPEncoder with CLIP->Siglip
class SiglipEncoder(nn.Module):
"""
Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
[`SiglipEncoderLayer`].
Args:
config: SiglipConfig
"""
def __init__(self, config: SiglipVisionConfig):
super().__init__()
self.config = config
self.layers = nn.ModuleList([SiglipEncoderLayer(config) for _ in range(config.num_hidden_layers)])
self.gradient_checkpointing = False
class SiglipVisionTransformer(SiglipPreTrainedModel):
config_class = SiglipVisionConfig
main_input_name = "pixel_values"
_supports_flash_attn_2 = True
def __init__(self, config: SiglipVisionConfig):
super().__init__(config)
self.config = config
embed_dim = config.hidden_size
self.embeddings = SiglipVisionEmbeddings(config)
self.encoder = SiglipEncoder(config)
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self) -> nn.Module:
return self.embeddings.patch_embedding
import argparse
import json
import re

import torch
import numpy as np
from gguf import *
from transformers.models.idefics2.modeling_idefics2 import Idefics2VisionTransformer, Idefics2VisionConfig
@@ -94,6 +501,7 @@ default_image_mean = [0.48145466, 0.4578275, 0.40821073]
default_image_std = [0.26862954, 0.26130258, 0.27577711]
ap.add_argument('--image-mean', type=float, nargs='+', help='Mean of the images for normalization (overrides processor) ', default=None)
ap.add_argument('--image-std', type=float, nargs='+', help='Standard deviation of the images for normalization (overrides processor)', default=None)
ap.add_argument('--minicpmv_version', type=int, help='minicpmv_version: MiniCPM-V-2 use 1; MiniCPM-V-2.5 use 2; MiniCPM-V-2.6 use 3', default=2)

# with proper
args = ap.parse_args()
@@ -135,6 +543,15 @@ if args.use_f32:
# model = CLIPModel.from_pretrained(dir_model)
# processor = CLIPProcessor.from_pretrained(dir_model)

minicpmv_version = args.minicpmv_version
emb_dim = 4096
if minicpmv_version == 1:
    emb_dim = 2304
elif minicpmv_version == 2:
    emb_dim = 4096
elif minicpmv_version == 3:
    emb_dim = 3584
default_vision_config = {
    "hidden_size": 1152,
    "image_size": 980,
@@ -144,8 +561,12 @@ default_vision_config = {
    "num_hidden_layers": 27,
    "patch_size": 14,
}

vision_config = Idefics2VisionConfig(**default_vision_config)
model = Idefics2VisionTransformer(vision_config)
if minicpmv_version == 3:
    vision_config = SiglipVisionConfig(**default_vision_config)
    model = SiglipVisionTransformer(vision_config)

processor = None
# if model.attn_pool is not None:
@@ -158,6 +579,7 @@ fname_middle = None
has_text_encoder = True
has_vision_encoder = True
has_minicpmv_projector = False
if args.text_only:
    fname_middle = "text-"
    has_vision_encoder = False
@@ -165,6 +587,7 @@ elif args.minicpmv_projector is not None:
    fname_middle = "mmproj-"
    has_text_encoder = False
    has_minicpmv_projector = True
    minicpmv_version = 3
elif args.vision_only:
    fname_middle = "vision-"
    has_text_encoder = False
@@ -189,6 +612,7 @@ elif has_minicpmv_projector:
    fout.add_description("image encoder for MiniCPM-V")
    # add projector type
    fout.add_string("clip.projector_type", "resampler")
    fout.add_int32("clip.minicpmv_version", minicpmv_version)
else:
    fout.add_description("two-tower CLIP model")
@@ -274,11 +698,11 @@ def _replace_name_resampler(s, v):
    if re.match("resampler.pos_embed", s):
        return {
            s: v,
            re.sub("pos_embed", "pos_embed_k", s): torch.from_numpy(get_2d_sincos_pos_embed(emb_dim, (70, 70))),
        }
    if re.match("resampler.proj", s):
        return {
            re.sub("proj", "pos_embed_k", s): torch.from_numpy(get_2d_sincos_pos_embed(emb_dim, (70, 70))),
            re.sub("proj", "proj.weight", s): v.transpose(-1, -2).contiguous(),
        }
    if re.match("resampler.attn.in_proj_.*", s):

View File

@@ -4,7 +4,7 @@ import torch
from transformers import AutoModel, AutoTokenizer

ap = argparse.ArgumentParser()
ap.add_argument("-m", "--model", help="Path to MiniCPM-V model")
args = ap.parse_args()

# find the model part that includes the multimodal projector weights
@@ -29,7 +29,6 @@ if len(clip_tensors) > 0:
        f.write("{}\n")

config = model.llm.config
config.auto_map = {
    "AutoConfig": "configuration_minicpm.MiniCPMConfig",
    "AutoModel": "modeling_minicpm.MiniCPMModel",
@@ -40,7 +39,6 @@ config.auto_map = {
model.llm.save_pretrained(f"{args.model}/model")
tok = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True)
tok.save_pretrained(f"{args.model}/model")

print("Done!")
print(f"Now you can convert {args.model} to a regular LLaMA GGUF file.")

View File

@@ -267,9 +267,9 @@ int main(int argc, char ** argv) {
        }
    }

    const bool add_bos = llama_add_bos_token(model);
    if (!llama_model_has_encoder(model)) {
        GGML_ASSERT(!llama_add_eos_token(model));
    }

    LOG("add_bos: %d\n", add_bos);

View File

@@ -340,8 +340,8 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params &
    // Output: `perplexity: 13.5106 [114/114]`
    // BOS tokens will be added for each chunk before eval

    const bool add_bos = llama_add_bos_token(llama_get_model(ctx));
    GGML_ASSERT(!llama_add_eos_token(llama_get_model(ctx)));

    fprintf(stderr, "%s: tokenizing the input ..\n", __func__);
@@ -480,8 +480,8 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
    // Output: `perplexity: 13.5106 [114/114]`
    // BOS tokens will be added for each chunk before eval

    const bool add_bos = llama_add_bos_token(llama_get_model(ctx));
    GGML_ASSERT(!llama_add_eos_token(llama_get_model(ctx)));

    std::ofstream logits_stream;
    if (!params.logits_file.empty()) {
@@ -1733,8 +1733,8 @@ static void kl_divergence(llama_context * ctx, const gpt_params & params) {
    const int n_batch = params.n_batch;
    const int num_batches = (n_ctx + n_batch - 1)/n_batch;
    const int nv = 2*((n_vocab + 1)/2) + 4;
    const bool add_bos = llama_add_bos_token(llama_get_model(ctx));
    GGML_ASSERT(!llama_add_eos_token(llama_get_model(ctx)));

    std::vector<uint16_t> log_probs_uint16(size_t(n_ctx - 1 - n_ctx/2) * nv);
    std::vector<float> kld_values(size_t(n_ctx - 1 - n_ctx/2)*n_chunk);

View File

@@ -34,7 +34,7 @@ Run the quantized model:
```bash
# start inference on a gguf model
./llama-cli -m ./models/mymodel/ggml-model-Q4_K_M.gguf -cnv -p "You are a helpful assistant"
```

When running the larger models, make sure you have enough disk space to store all the intermediate files.

View File

@@ -253,6 +253,8 @@ int main(int argc, char ** argv) {
        chunks[i].tokens.clear();
    }

    struct llama_batch query_batch = llama_batch_init(n_batch, 0, 1);

    // start loop, receive query and return top k similar chunks based on cosine similarity
    std::string query;
    while (true) {
@@ -260,7 +262,6 @@ int main(int argc, char ** argv) {
        std::getline(std::cin, query);
        std::vector<int32_t> query_tokens = llama_tokenize(ctx, query, true);

        batch_add_seq(query_batch, query_tokens, 0);

        std::vector<float> query_emb(n_embd, 0);
@@ -293,6 +294,7 @@ int main(int argc, char ** argv) {
    }

    // clean up
    llama_batch_free(query_batch);
    llama_print_timings(ctx);
    llama_free(ctx);
    llama_free_model(model);

View File

@@ -247,6 +247,25 @@ logging:
  --log-append              Don't truncate the old log file.
```
Available environment variables (if specified, these variables will override parameters specified in arguments):
- `LLAMA_CACHE` (cache directory, used by `--hf-repo`)
- `HF_TOKEN` (Hugging Face access token, used when accessing a gated model with `--hf-repo`)
- `LLAMA_ARG_MODEL`
- `LLAMA_ARG_THREADS`
- `LLAMA_ARG_CTX_SIZE`
- `LLAMA_ARG_N_PARALLEL`
- `LLAMA_ARG_BATCH`
- `LLAMA_ARG_UBATCH`
- `LLAMA_ARG_N_GPU_LAYERS`
- `LLAMA_ARG_THREADS_HTTP`
- `LLAMA_ARG_CHAT_TEMPLATE`
- `LLAMA_ARG_N_PREDICT`
- `LLAMA_ARG_ENDPOINT_METRICS`
- `LLAMA_ARG_ENDPOINT_SLOTS`
- `LLAMA_ARG_EMBEDDINGS`
- `LLAMA_ARG_FLASH_ATTN`
- `LLAMA_ARG_DEFRAG_THOLD`
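
For example, a server can be configured largely through the environment (the model path below is illustrative, not part of the repository):

```bash
# hypothetical model path; command-line arguments can still be passed alongside
LLAMA_ARG_MODEL=./models/my-model.gguf \
LLAMA_ARG_CTX_SIZE=4096 \
LLAMA_ARG_N_GPU_LAYERS=99 \
./llama-server
```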
## Build
@@ -368,15 +387,16 @@ node index.js
## API Endpoints
### GET `/health`: Returns health check result

**Response format**

- HTTP status code 503
  - Body: `{"error": {"code": 503, "message": "Loading model", "type": "unavailable_error"}}`
  - Explanation: the model is still being loaded.
- HTTP status code 200
  - Body: `{"status": "ok" }`
  - Explanation: the model is successfully loaded and the server is ready.
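
A minimal readiness probe, assuming the server is listening on the default `http://localhost:8080`:

```bash
curl -s http://localhost:8080/health
# -> {"status": "ok"} once the model has finished loading
```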
### POST `/completion`: Given a `prompt`, it returns the predicted completion.
@@ -639,10 +659,16 @@ Given a ChatML-formatted json description in `messages`, it returns the predicte
}'
```
### GET `/slots`: Returns the current slots processing state

This endpoint can be disabled with `--no-slots`.

If the query param `?fail_on_no_slot=1` is set, this endpoint will respond with status code 503 if there are no available slots.
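
For example, assuming the default host and port:

```bash
curl -s 'http://localhost:8080/slots?fail_on_no_slot=1'
# returns a 503 "unavailable_error" body when every slot is busy
```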
**Response format**

Example:

```json
[
    {
@@ -702,7 +728,13 @@ Given a ChatML-formatted json description in `messages`, it returns the predicte
]
```
Possible values for `slot[i].state` are:
- `0`: SLOT_STATE_IDLE
- `1`: SLOT_STATE_PROCESSING

### GET `/metrics`: Prometheus compatible metrics exporter

This endpoint is only accessible if `--metrics` is set.
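
For instance, assuming a server started with `--metrics` on the default port:

```bash
curl -s http://localhost:8080/metrics
# plain-text Prometheus exposition format, including the counters listed below
```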
Available metrics:

- `llamacpp:prompt_tokens_total`: Number of prompt tokens processed.
@@ -767,6 +799,10 @@ Available metrics:
### GET `/lora-adapters`: Get list of all LoRA adapters

This endpoint returns the loaded LoRA adapters. You can add adapters using `--lora` when starting the server, for example: `--lora my_adapter_1.gguf --lora my_adapter_2.gguf ...`

By default, all adapters are loaded with scale set to 1. To load all adapters with scale set to 0 instead, add `--lora-init-without-apply`. If an adapter is disabled, its scale will be set to 0.
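
A quick query, with default host and port assumed (the exact fields are described under the response format below):

```bash
curl -s http://localhost:8080/lora-adapters
```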
**Response format**

View File

@@ -15,6 +15,8 @@
// Change JSON_ASSERT from assert() to GGML_ASSERT:
#define JSON_ASSERT GGML_ASSERT
#include "json.hpp"
// mime type for sending response
#define MIMETYPE_JSON "application/json; charset=utf-8"

// auto generated files (update with ./deps.sh)
#include "colorthemes.css.hpp"
@@ -67,7 +69,6 @@ enum slot_command {
enum server_state {
    SERVER_STATE_LOADING_MODEL,  // Server is starting up, model not fully loaded yet
    SERVER_STATE_READY,          // Server is ready and model is loaded
};

enum server_task_type {
@@ -693,8 +694,8 @@ struct server_context {
        n_ctx = llama_n_ctx(ctx);

        add_bos_token = llama_add_bos_token(model);
        has_eos_token = !llama_add_eos_token(model);

        return true;
    }
@@ -754,13 +755,13 @@ struct server_context {
        default_generation_settings_for_props = get_formated_generation(slots.front());
        default_generation_settings_for_props["seed"] = -1;

        // the update_slots() logic will always submit a maximum of n_batch or n_parallel tokens
        // note that n_batch can be > n_ctx (e.g. for non-causal attention models such as BERT where the KV cache is not used)
        {
            const int32_t n_batch = llama_n_batch(ctx);

            // only a single seq_id per token is needed
            batch = llama_batch_init(std::max(n_batch, params.n_parallel), 0, 1);
        }

        metrics.init();
@@ -1137,28 +1138,19 @@ struct server_context {
        if (!system_prompt.empty()) {
            system_tokens = ::llama_tokenize(ctx, system_prompt, true);

            const int32_t n_batch = llama_n_batch(ctx);
            const int32_t n_tokens_prompt = system_tokens.size();

            // decode the system prompt in n_batch-sized chunks
            for (int32_t i = 0; i < n_tokens_prompt; i += n_batch) {
                const int32_t n_tokens = std::min(n_batch, n_tokens_prompt - i);

                llama_batch_clear(batch);
                for (int32_t j = 0; j < n_tokens; ++j) {
                    llama_batch_add(batch, system_tokens[i + j], i + j, { 0 }, false);
                }

                if (llama_decode(ctx, batch) != 0) {
                    LOG_ERROR("llama_decode() failed", {});
                    return;
                }
@@ -1331,7 +1323,7 @@ struct server_context {
        return json {
            {"n_ctx",        slot.n_ctx},
            {"n_predict",    slot.n_predict},     // Server configured n_predict
            {"model",        params.model_alias},
            {"seed",         slot.sparams.seed},
            {"temperature",  slot.sparams.temp},
@@ -1353,7 +1345,7 @@ struct server_context {
            {"mirostat_eta", slot.sparams.mirostat_eta},
            {"penalize_nl",  slot.sparams.penalize_nl},
            {"stop",         slot.params.antiprompt},
            {"max_tokens",   slot.params.n_predict},  // User configured n_predict
            {"n_keep",       slot.params.n_keep},
            {"n_discard",    slot.params.n_discard},
            {"ignore_eos",   ignore_eos},
@@ -1861,6 +1853,8 @@ struct server_context {
                    llama_lora_adapters_apply(ctx, lora_adapters);
                    server_task_result result;
                    result.id = task.id;
                    result.stop = true;
                    result.error = false;
                    result.data = json{{ "success", true }};
                    queue_results.send(result);
                } break;
@@ -2045,7 +2039,7 @@ struct server_context {
            slot.t_start_generation = 0;

            if (slot.infill) {
                const bool add_bos = llama_add_bos_token(model);
                bool suff_rm_leading_spc = true;
                if (params.input_suffix.find_first_of(' ') == 0 && params.input_suffix.size() > 1) {
                    params.input_suffix.erase(0, 1);
@@ -2513,6 +2507,9 @@ int main(int argc, char ** argv) {
        return 1;
    }

    // parse arguments from environment variables
    gpt_params_parse_from_env(params);

    // TODO: not great to use extern vars
    server_log_json = params.log_json;
    server_verbose = params.verbosity > 0;
@@ -2563,19 +2560,19 @@ int main(int argc, char ** argv) {
    svr->set_default_headers({{"Server", "llama.cpp"}});

    // CORS preflight
    svr->Options(R"(.*)", [](const httplib::Request &, httplib::Response & res) {
        // Access-Control-Allow-Origin is already set by middleware
        res.set_header("Access-Control-Allow-Credentials", "true");
        res.set_header("Access-Control-Allow-Methods",     "POST");
        res.set_header("Access-Control-Allow-Headers",     "*");
        return res.set_content("", "text/html"); // blank response, no data
    });

    svr->set_logger(log_server_request);

    auto res_error = [](httplib::Response & res, json error_data) {
        json final_response {{"error", error_data}};
        res.set_content(final_response.dump(), MIMETYPE_JSON);
        res.status = json_value(error_data, "code", 500);
    };
@@ -2605,11 +2602,6 @@ int main(int argc, char ** argv) {
    svr->set_read_timeout (params.timeout_read);
    svr->set_write_timeout(params.timeout_write);

    std::unordered_map<std::string, std::string> log_data;

    log_data["hostname"] = params.hostname;
@@ -2625,35 +2617,6 @@ int main(int argc, char ** argv) {
    // Necessary similarity of prompt for slot selection
    ctx_server.slot_prompt_similarity = params.slot_prompt_similarity;

    //
    // Middlewares
    //
@@ -2697,8 +2660,6 @@ int main(int argc, char ** argv) {
        }

        // API key is invalid or not provided
        res_error(res, format_error_response("Invalid API Key", ERROR_TYPE_AUTHENTICATION));

        LOG_WARNING("Unauthorized: Invalid API Key", {});
@@ -2706,8 +2667,21 @@ int main(int argc, char ** argv) {
        return false;
    };

    auto middleware_server_state = [&res_error, &state](const httplib::Request &, httplib::Response & res) {
        server_state current_state = state.load();
        if (current_state == SERVER_STATE_LOADING_MODEL) {
            res_error(res, format_error_response("Loading model", ERROR_TYPE_UNAVAILABLE));
            return false;
        }
        return true;
    };

    // register server middlewares
    svr->set_pre_routing_handler([&middleware_validate_api_key, &middleware_server_state](const httplib::Request & req, httplib::Response & res) {
        res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
        if (!middleware_server_state(req, res)) {
            return httplib::Server::HandlerResponse::Handled;
        }
        if (!middleware_validate_api_key(req, res)) {
            return httplib::Server::HandlerResponse::Handled;
        }
@@ -2718,62 +2692,15 @@ int main(int argc, char ** argv) {
    // Route handlers (or controllers)
    //

    const auto handle_health = [&](const httplib::Request &, httplib::Response & res) {
        // error and loading states are handled by middleware
        json health = {{"status", "ok"}};
        res.set_content(health.dump(), "application/json");
    };

    const auto handle_slots = [&](const httplib::Request & req, httplib::Response & res) {
        if (!params.endpoint_slots) {
            res_error(res, format_error_response("This server does not support slots endpoint. Start it without `--no-slots`", ERROR_TYPE_NOT_SUPPORTED));
            return;
        }
@@ -2791,13 +2718,22 @@ int main(int argc, char ** argv) {
        server_task_result result = ctx_server.queue_results.recv(task.id);
        ctx_server.queue_results.remove_waiting_task_id(task.id);

        // optionally return "fail_on_no_slot" error
        const int n_idle_slots = result.data.at("idle");
        if (req.has_param("fail_on_no_slot")) {
            if (n_idle_slots == 0) {
                res_error(res, format_error_response("no slot available", ERROR_TYPE_UNAVAILABLE));
                return;
            }
        }

        res.set_content(result.data.at("slots").dump(), MIMETYPE_JSON);
        res.status = 200; // HTTP OK
    };

    const auto handle_metrics = [&](const httplib::Request &, httplib::Response & res) {
        if (!params.endpoint_metrics) {
            res_error(res, format_error_response("This server does not support metrics endpoint. Start it with `--metrics`", ERROR_TYPE_NOT_SUPPORTED));
            return;
        }
@@ -2922,7 +2858,7 @@ int main(int argc, char ** argv) {
        if (result.error) {
            res_error(res, result.data);
        } else {
            res.set_content(result.data.dump(), MIMETYPE_JSON);
        }
    };
@@ -2952,7 +2888,7 @@ int main(int argc, char ** argv) {
        if (result.error) {
            res_error(res, result.data);
        } else {
            res.set_content(result.data.dump(), MIMETYPE_JSON);
        }
    };
@@ -2972,13 +2908,11 @@ int main(int argc, char ** argv) {
        if (result.error) {
            res_error(res, result.data);
        } else {
            res.set_content(result.data.dump(), MIMETYPE_JSON);
        }
    };

    const auto handle_slots_action = [&res_error, &handle_slots_save, &handle_slots_restore, &handle_slots_erase](const httplib::Request & req, httplib::Response & res) {
        std::string id_slot_str = req.path_params.at("id_slot");
        int id_slot;
@@ -3002,7 +2936,7 @@ int main(int argc, char ** argv) {
        }
    };

    const auto handle_props = [&ctx_server](const httplib::Request &, httplib::Response & res) {
        std::string template_key = "tokenizer.chat_template", curr_tmpl;
        int32_t tlen = llama_model_meta_val_str(ctx_server.model, template_key.c_str(), nullptr, 0);
        if (tlen > 0) {
@@ -3011,7 +2945,6 @@ int main(int argc, char ** argv) {
                curr_tmpl = std::string(curr_tmpl_buf.data(), tlen);
            }
        }

        json data = {
            { "system_prompt",               ctx_server.system_prompt.c_str() },
            { "default_generation_settings", ctx_server.default_generation_settings_for_props },
@@ -3019,7 +2952,7 @@ int main(int argc, char ** argv) {
            { "chat_template",               curr_tmpl.c_str() }
        };

        res.set_content(data.dump(), MIMETYPE_JSON);
    };
    const auto handle_completions = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) {
@@ -3028,8 +2961,6 @@ int main(int argc, char ** argv) {
            return;
        }

        json data = json::parse(req.body);

        const int id_task = ctx_server.queue_tasks.get_new_id();
@@ -3040,7 +2971,7 @@ int main(int argc, char ** argv) {
        if (!json_value(data, "stream", false)) {
            server_task_result result = ctx_server.queue_results.recv(id_task);
            if (!result.error && result.stop) {
                res.set_content(result.data.dump(-1, ' ', false, json::error_handler_t::replace), MIMETYPE_JSON);
            } else {
                res_error(res, result.data);
            }
@@ -3103,9 +3034,7 @@ int main(int argc, char ** argv) {
        }
    };

    const auto handle_models = [&params, &ctx_server](const httplib::Request &, httplib::Response & res) {
        json models = {
            {"object", "list"},
            {"data", {
@@ -3114,12 +3043,12 @@ int main(int argc, char ** argv) {
                 {"object",   "model"},
                 {"created",  std::time(0)},
                 {"owned_by", "llamacpp"},
                 {"meta",     ctx_server.model_meta()}
                },
            }}
        };

        res.set_content(models.dump(), MIMETYPE_JSON);
    };
    const auto handle_chat_completions = [&ctx_server, &params, &res_error](const httplib::Request & req, httplib::Response & res) {
@@ -3127,8 +3056,6 @@ int main(int argc, char ** argv) {
            res_error(res, format_error_response("This server does not support chat completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
            return;
        }

        json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template);

        const int id_task = ctx_server.queue_tasks.get_new_id();
@@ -3143,7 +3070,7 @@ int main(int argc, char ** argv) {
            if (!result.error && result.stop) {
                json result_oai = format_final_response_oaicompat(data, result.data, completion_id);

                res.set_content(result_oai.dump(-1, ' ', false, json::error_handler_t::replace), MIMETYPE_JSON);
            } else {
                res_error(res, result.data);
            }
@ -3205,8 +3132,6 @@ int main(int argc, char ** argv) {
return; return;
} }
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
json data = json::parse(req.body); json data = json::parse(req.body);
const int id_task = ctx_server.queue_tasks.get_new_id(); const int id_task = ctx_server.queue_tasks.get_new_id();
@ -3217,7 +3142,7 @@ int main(int argc, char ** argv) {
if (!json_value(data, "stream", false)) { if (!json_value(data, "stream", false)) {
server_task_result result = ctx_server.queue_results.recv(id_task); server_task_result result = ctx_server.queue_results.recv(id_task);
if (!result.error && result.stop) { if (!result.error && result.stop) {
res.set_content(result.data.dump(-1, ' ', false, json::error_handler_t::replace), "application/json; charset=utf-8"); res.set_content(result.data.dump(-1, ' ', false, json::error_handler_t::replace), MIMETYPE_JSON);
} else { } else {
res_error(res, result.data); res_error(res, result.data);
} }
@ -3265,7 +3190,6 @@ int main(int argc, char ** argv) {
}; };
const auto handle_tokenize = [&ctx_server](const httplib::Request & req, httplib::Response & res) { const auto handle_tokenize = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
const json body = json::parse(req.body); const json body = json::parse(req.body);
std::vector<llama_token> tokens; std::vector<llama_token> tokens;
@ -3274,11 +3198,10 @@ int main(int argc, char ** argv) {
tokens = ctx_server.tokenize(body.at("content"), add_special); tokens = ctx_server.tokenize(body.at("content"), add_special);
} }
const json data = format_tokenizer_response(tokens); const json data = format_tokenizer_response(tokens);
return res.set_content(data.dump(), "application/json; charset=utf-8"); return res.set_content(data.dump(), MIMETYPE_JSON);
}; };
const auto handle_detokenize = [&ctx_server](const httplib::Request & req, httplib::Response & res) { const auto handle_detokenize = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
const json body = json::parse(req.body); const json body = json::parse(req.body);
std::string content; std::string content;
@ -3288,12 +3211,10 @@ int main(int argc, char ** argv) {
} }
const json data = format_detokenized_response(content); const json data = format_detokenized_response(content);
return res.set_content(data.dump(), "application/json; charset=utf-8"); return res.set_content(data.dump(), MIMETYPE_JSON);
}; };
const auto handle_embeddings = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) { const auto handle_embeddings = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
const json body = json::parse(req.body); const json body = json::parse(req.body);
bool is_openai = false; bool is_openai = false;
@ -3339,11 +3260,10 @@ int main(int argc, char ** argv) {
json root = is_openai json root = is_openai
? format_embeddings_response_oaicompat(body, responses) ? format_embeddings_response_oaicompat(body, responses)
: responses[0]; : responses[0];
return res.set_content(root.dump(), "application/json; charset=utf-8"); return res.set_content(root.dump(), MIMETYPE_JSON);
}; };
const auto handle_lora_adapters_list = [&](const httplib::Request & req, httplib::Response & res) { const auto handle_lora_adapters_list = [&](const httplib::Request &, httplib::Response & res) {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
json result = json::array(); json result = json::array();
for (size_t i = 0; i < ctx_server.lora_adapters.size(); ++i) { for (size_t i = 0; i < ctx_server.lora_adapters.size(); ++i) {
auto & la = ctx_server.lora_adapters[i]; auto & la = ctx_server.lora_adapters[i];
@ -3353,13 +3273,11 @@ int main(int argc, char ** argv) {
{"scale", la.scale}, {"scale", la.scale},
}); });
} }
res.set_content(result.dump(), "application/json"); res.set_content(result.dump(), MIMETYPE_JSON);
res.status = 200; // HTTP OK res.status = 200; // HTTP OK
}; };
const auto handle_lora_adapters_apply = [&](const httplib::Request & req, httplib::Response & res) { const auto handle_lora_adapters_apply = [&](const httplib::Request & req, httplib::Response & res) {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
const std::vector<json> body = json::parse(req.body); const std::vector<json> body = json::parse(req.body);
int max_idx = ctx_server.lora_adapters.size(); int max_idx = ctx_server.lora_adapters.size();
@ -3387,7 +3305,7 @@ int main(int argc, char ** argv) {
server_task_result result = ctx_server.queue_results.recv(id_task); server_task_result result = ctx_server.queue_results.recv(id_task);
ctx_server.queue_results.remove_waiting_task_id(id_task); ctx_server.queue_results.remove_waiting_task_id(id_task);
res.set_content(result.data.dump(), "application/json"); res.set_content(result.data.dump(), MIMETYPE_JSON);
res.status = 200; // HTTP OK res.status = 200; // HTTP OK
}; };
@ -3463,35 +3381,75 @@ int main(int argc, char ** argv) {
log_data["n_threads_http"] = std::to_string(params.n_threads_http); log_data["n_threads_http"] = std::to_string(params.n_threads_http);
svr->new_task_queue = [&params] { return new httplib::ThreadPool(params.n_threads_http); }; svr->new_task_queue = [&params] { return new httplib::ThreadPool(params.n_threads_http); };
LOG_INFO("HTTP server listening", log_data); // clean up function, to be called before exit
auto clean_up = [&svr]() {
svr->stop();
llama_backend_free();
};
// run the HTTP server in a thread - see comment below // bind HTTP listen port, run the HTTP server in a thread
std::thread t([&]() { if (!svr->bind_to_port(params.hostname, params.port)) {
if (!svr->listen_after_bind()) { LOG_ERROR("couldn't bind HTTP server socket", {
state.store(SERVER_STATE_ERROR); {"hostname", params.hostname},
return 1; {"port", params.port},
});
clean_up();
LOG_ERROR("exiting due to HTTP server error", {});
return 1;
}
std::thread t([&]() { svr->listen_after_bind(); });
svr->wait_until_ready();
LOG_INFO("HTTP server is listening", log_data);
// load the model
LOG_INFO("loading model", log_data);
if (!ctx_server.load_model(params)) {
clean_up();
t.join();
LOG_ERROR("exiting due to model loading error", {});
return 1;
} else {
ctx_server.init();
state.store(SERVER_STATE_READY);
LOG_INFO("model loaded", {});
// if a custom chat template is not supplied, we will use the one that comes with the model (if any)
if (params.chat_template.empty()) {
if (!ctx_server.validate_model_chat_template()) {
LOG_WARNING("The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses", {});
params.chat_template = "chatml";
}
} }
return 0; // print sample chat example to make it clear which template is used
}); {
LOG_INFO("chat template", {
{"chat_example", llama_chat_format_example(ctx_server.model, params.chat_template)},
{"built_in", params.chat_template.empty()},
});
}
ctx_server.queue_tasks.on_new_task(std::bind( ctx_server.queue_tasks.on_new_task(std::bind(
&server_context::process_single_task, &ctx_server, std::placeholders::_1)); &server_context::process_single_task, &ctx_server, std::placeholders::_1));
ctx_server.queue_tasks.on_finish_multitask(std::bind( ctx_server.queue_tasks.on_finish_multitask(std::bind(
&server_context::on_finish_multitask, &ctx_server, std::placeholders::_1)); &server_context::on_finish_multitask, &ctx_server, std::placeholders::_1));
ctx_server.queue_tasks.on_update_slots(std::bind( ctx_server.queue_tasks.on_update_slots(std::bind(
&server_context::update_slots, &ctx_server)); &server_context::update_slots, &ctx_server));
ctx_server.queue_results.on_multitask_update(std::bind( ctx_server.queue_results.on_multitask_update(std::bind(
&server_queue::update_multitask, &server_queue::update_multitask,
&ctx_server.queue_tasks, &ctx_server.queue_tasks,
std::placeholders::_1, std::placeholders::_1,
std::placeholders::_2, std::placeholders::_2,
std::placeholders::_3 std::placeholders::_3
)); ));
shutdown_handler = [&](int) { shutdown_handler = [&](int) {
ctx_server.queue_tasks.terminate(); ctx_server.queue_tasks.terminate();
}; };
ctx_server.queue_tasks.start_loop();
}
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
struct sigaction sigint_action; struct sigaction sigint_action;
@ -3507,12 +3465,8 @@ int main(int argc, char ** argv) {
SetConsoleCtrlHandler(reinterpret_cast<PHANDLER_ROUTINE>(console_ctrl_handler), true); SetConsoleCtrlHandler(reinterpret_cast<PHANDLER_ROUTINE>(console_ctrl_handler), true);
#endif #endif
ctx_server.queue_tasks.start_loop(); clean_up();
svr->stop();
t.join(); t.join();
llama_backend_free();
return 0; return 0;
} }
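
The block above reverses the old startup order: the socket is bound and the listener thread is running before the model is loaded, so port conflicts fail fast and the server can be probed while the expensive load is still in progress. A minimal sketch of the same pattern, assuming a recent cpp-httplib (`run_server`, the address, and the `/health` payloads are illustrative, not the server's actual code):

    #include <atomic>
    #include <thread>
    #include "httplib.h"

    int run_server() {
        httplib::Server svr;
        std::atomic<bool> ready{false};

        // the probe endpoint answers even while the heavy initialization runs
        svr.Get("/health", [&](const httplib::Request &, httplib::Response & res) {
            res.set_content(ready ? "{\"status\":\"ok\"}" : "{\"status\":\"loading\"}",
                            "application/json; charset=utf-8");
        });

        // 1. bind first, so failures surface before any expensive work
        if (!svr.bind_to_port("127.0.0.1", 8080)) {
            return 1;
        }
        // 2. serve from a worker thread
        std::thread t([&] { svr.listen_after_bind(); });
        svr.wait_until_ready();

        // 3. the expensive step (the model load above) happens here
        ready = true;

        svr.stop(); // plays the role of clean_up() above
        t.join();
        return 0;
    }
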

View File

@ -205,27 +205,20 @@ def step_start_server(context):
async def step_wait_for_the_server_to_be_started(context, expecting_status: Literal['healthy', 'ready', 'idle', 'busy'] | str):
match expecting_status:
case 'healthy':
-await wait_for_health_status(context, context.base_url, 200, 'ok',
+await wait_for_slots_status(context, context.base_url, 200,
timeout=30)
case 'ready' | 'idle':
-await wait_for_health_status(context, context.base_url, 200, 'ok',
+await wait_for_slots_status(context, context.base_url, 200,
timeout=30,
-params={'fail_on_no_slot': 0, 'include_slots': 0},
+params={'fail_on_no_slot': 1},
slots_idle=context.n_slots,
-slots_processing=0,
-expected_slots=[{'id': slot_id, 'state': 0}
-for slot_id in
-range(context.n_slots if context.n_slots else 1)])
+slots_processing=0)
case 'busy':
-await wait_for_health_status(context, context.base_url, 503,
-'no slot available',
-params={'fail_on_no_slot': 0, 'include_slots': 0},
-slots_idle=0,
-slots_processing=context.n_slots,
-expected_slots=[{'id': slot_id, 'state': 1}
-for slot_id in
-range(context.n_slots if context.n_slots else 1)])
+await wait_for_slots_status(context, context.base_url, 503,
+params={'fail_on_no_slot': 1},
+slots_idle=0,
+slots_processing=context.n_slots)
case _:
assert False, "unknown status"
@ -1187,17 +1180,15 @@ async def gather_tasks_results(context):
return n_completions
-async def wait_for_health_status(context,
-base_url,
-expected_http_status_code,
-expected_health_status,
-timeout=3,
-params=None,
-slots_idle=None,
-slots_processing=None,
-expected_slots=None):
+async def wait_for_slots_status(context,
+base_url,
+expected_http_status_code,
+timeout=3,
+params=None,
+slots_idle=None,
+slots_processing=None):
if context.debug:
-print(f"Starting checking for health for expected_health_status={expected_health_status}")
+print(f"Starting checking for health for expected_http_status_code={expected_http_status_code}")
interval = 0.5
counter = 0
if 'GITHUB_ACTIONS' in os.environ:
@ -1205,26 +1196,19 @@ async def wait_for_health_status(context,
async with aiohttp.ClientSession() as session:
while True:
-async with await session.get(f'{base_url}/health', params=params) as health_response:
-status_code = health_response.status
-health = await health_response.json()
+async with await session.get(f'{base_url}/slots', params=params) as slots_response:
+status_code = slots_response.status
+slots = await slots_response.json()
if context.debug:
-print(f"HEALTH - response for expected health status='{expected_health_status}' on "
-f"'{base_url}/health'?{params} is {health}\n")
-if (status_code == expected_http_status_code
-and health['status'] == expected_health_status
-and (slots_idle is None or health['slots_idle'] == slots_idle)
-and (slots_processing is None or health['slots_processing'] == slots_processing)):
-if expected_slots is not None:
-assert_slots_status(health['slots'], expected_slots)
-return
+print(f"slots responses {slots}\n")
+if status_code == 503 and status_code == expected_http_status_code:
+return
+if status_code == 200 and status_code == expected_http_status_code:
+n_slots_idle = sum(1 if slot["state"] == 0 else 0 for slot in slots)
+n_slots_processing = sum(1 if slot["state"] != 0 else 0 for slot in slots)
+if ((slots_idle is None or slots_idle == n_slots_idle)
+and (slots_processing is None or slots_processing == n_slots_processing)):
+return
await asyncio.sleep(interval)
counter += interval
@ -1238,7 +1222,7 @@ async def wait_for_health_status(context,
if n_completions > 0:
return
-assert False, f'{expected_health_status} timeout exceeded {counter}s>={timeout}'
+assert False, f'slots check timeout exceeded {counter}s>={timeout}'
def assert_embeddings(embeddings):
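
The rewritten step polls `/slots` and derives the idle/processing counts itself instead of trusting the aggregate fields the old `/health` endpoint reported. The same counting rule, sketched in C++ for illustration (the `slot_state` struct and the negative "don't care" sentinel are assumptions of this sketch, not part of the test suite):

    #include <cstdint>
    #include <vector>

    struct slot_state { int state; }; // stand-in for one entry of the /slots array

    static bool slots_match(const std::vector<slot_state> & slots,
                            int64_t want_idle, int64_t want_processing) {
        int64_t n_idle = 0, n_processing = 0;
        for (const auto & s : slots) {
            (s.state == 0 ? n_idle : n_processing)++; // state 0 means idle
        }
        // negative values mean "don't care", like None in the Python step
        return (want_idle < 0 || n_idle == want_idle) &&
               (want_processing < 0 || n_processing == want_processing);
    }
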

View File

@ -362,7 +362,7 @@ int main(int raw_argc, char ** raw_argv) {
prompt = stdin_buffer.str();
}
-const bool model_wants_add_bos = llama_should_add_bos_token(model);
+const bool model_wants_add_bos = llama_add_bos_token(model);
const bool add_bos = model_wants_add_bos && !no_bos;
const bool parse_special = !no_parse_special;

View File

@ -20,11 +20,11 @@
},
"nixpkgs": {
"locked": {
-"lastModified": 1723175592,
+"lastModified": 1723637854,
-"narHash": "sha256-M0xJ3FbDUc4fRZ84dPGx5VvgFsOzds77KiBMW/mMTnI=",
+"narHash": "sha256-med8+5DSWa2UnOqtdICndjDAEjxr5D7zaIiK4pn0Q7c=",
"owner": "NixOS",
"repo": "nixpkgs",
-"rev": "5e0ca22929f3342b19569b21b2f3462f053e497b",
+"rev": "c3aa7b8938b17aebd2deecf7be0636000d62a2b9",
"type": "github"
},
"original": {

View File

@ -129,13 +129,13 @@ option(GGML_CUDA_NO_VMM "ggml: do not try to use CUDA VMM"
option(GGML_CUDA_FA_ALL_QUANTS "ggml: compile all quants for FlashAttention" OFF)
option(GGML_CUDA_USE_GRAPHS "ggml: use CUDA graphs (llama.cpp only)" OFF)
-option(GGML_CURL "ggml: use libcurl to download model from an URL" OFF)
option(GGML_HIPBLAS "ggml: use hipBLAS" OFF)
option(GGML_HIP_UMA "ggml: use HIP unified memory architecture" OFF)
option(GGML_VULKAN "ggml: use Vulkan" OFF)
option(GGML_VULKAN_CHECK_RESULTS "ggml: run Vulkan op checks" OFF)
option(GGML_VULKAN_DEBUG "ggml: enable Vulkan debug output" OFF)
option(GGML_VULKAN_MEMORY_DEBUG "ggml: enable Vulkan memory debug output" OFF)
+option(GGML_VULKAN_PERF "ggml: enable Vulkan perf output" OFF)
option(GGML_VULKAN_VALIDATE "ggml: enable Vulkan validation" OFF)
option(GGML_VULKAN_RUN_TESTS "ggml: run Vulkan tests" OFF)
option(GGML_KOMPUTE "ggml: use Kompute" OFF)

View File

@ -1779,10 +1779,8 @@ extern "C" {
GGML_API struct ggml_tensor * ggml_ssm_conv(
struct ggml_context * ctx,
-struct ggml_tensor * s,
-struct ggml_tensor * x,
-struct ggml_tensor * c,
-struct ggml_tensor * sq);
+struct ggml_tensor * sx,
+struct ggml_tensor * c);
GGML_API struct ggml_tensor * ggml_ssm_scan(
struct ggml_context * ctx,
@ -1791,8 +1789,7 @@ extern "C" {
struct ggml_tensor * dt,
struct ggml_tensor * A,
struct ggml_tensor * B,
-struct ggml_tensor * C,
-struct ggml_tensor * sq);
+struct ggml_tensor * C);
// partition into non-overlapping windows with padding if needed
// example:

View File

@ -549,6 +549,13 @@ if (GGML_SYCL)
file(GLOB GGML_SOURCES_SYCL "ggml-sycl/*.cpp")
list(APPEND GGML_SOURCES_SYCL "ggml-sycl.cpp")
+find_package(DNNL)
+message("-- DNNL found:" ${DNNL_FOUND})
+if (GGML_SYCL_TARGET STREQUAL "INTEL")
+add_compile_definitions(GGML_SYCL_DNNL=${DNNL_FOUND})
+else()
+add_compile_definitions(GGML_SYCL_DNNL=0)
+endif()
if (WIN32)
find_package(IntelSYCL REQUIRED)
find_package(MKL REQUIRED)
@ -561,6 +568,9 @@ if (GGML_SYCL)
set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} -fsycl pthread m dl onemkl)
endif()
endif()
+if (${DNNL_FOUND} AND GGML_SYCL_TARGET STREQUAL "INTEL")
+list(APPEND GGML_EXTRA_LIBS DNNL::dnnl)
+endif()
endif()
if (GGML_RPC)
@ -602,6 +612,10 @@ if (GGML_VULKAN)
add_compile_definitions(GGML_VULKAN_MEMORY_DEBUG)
endif()
+if (GGML_VULKAN_PERF)
+add_compile_definitions(GGML_VULKAN_PERF)
+endif()
if (GGML_VULKAN_VALIDATE)
add_compile_definitions(GGML_VULKAN_VALIDATE)
endif()

View File

@ -1018,10 +1018,6 @@ static bool ggml_is_view_op(enum ggml_op op) {
#define GGML_SCHED_MAX_BACKENDS 16
#endif
-#ifndef GGML_SCHED_MAX_SPLITS
-#define GGML_SCHED_MAX_SPLITS 2048
-#endif
#ifndef GGML_SCHED_MAX_SPLIT_INPUTS
#define GGML_SCHED_MAX_SPLIT_INPUTS GGML_MAX_SRC
#endif
@ -1125,7 +1121,8 @@ static int ggml_backend_sched_backend_from_buffer(ggml_backend_sched_t sched, co
}
#if 0
-static char causes[GGML_DEFAULT_GRAPH_SIZE*16 + GGML_SCHED_MAX_SPLITS*GGML_SCHED_MAX_SPLIT_INPUTS][128]; // debug only
+#define GGML_SCHED_MAX_SPLITS_DEBUG 4096
+static char causes[GGML_DEFAULT_GRAPH_SIZE*16 + GGML_SCHED_MAX_SPLITS_DEBUG*GGML_SCHED_MAX_SPLIT_INPUTS][128]; // debug only
#define SET_CAUSE(node, ...) sprintf(causes[hash_id(node)], __VA_ARGS__)
#define GET_CAUSE(node) causes[hash_id(node)]
#else
@ -1549,7 +1546,6 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
sched->splits = realloc(sched->splits, sched->splits_capacity * sizeof(struct ggml_backend_sched_split));
GGML_ASSERT(sched->splits != NULL);
}
-GGML_ASSERT(i_split < GGML_SCHED_MAX_SPLITS);
split = &sched->splits[i_split];
split->backend_id = node_backend_id;
split->i_start = i;
@ -1865,13 +1861,14 @@ ggml_backend_sched_t ggml_backend_sched_new(
sched->hv_tensor_backend_ids = malloc(sched->hash_set.size * sizeof(sched->hv_tensor_backend_ids[0]));
sched->hv_tensor_copies = malloc(sched->hash_set.size * sched->n_backends * sched->n_copies * sizeof(struct ggml_tensor *));
-const size_t nodes_size = graph_size + GGML_SCHED_MAX_SPLITS*GGML_SCHED_MAX_SPLIT_INPUTS*2;
+const size_t ggml_sched_max_splits = graph_size; // at most there is one split for each node in the graph
+const size_t nodes_size = graph_size + ggml_sched_max_splits*GGML_SCHED_MAX_SPLIT_INPUTS*2;
sched->node_backend_ids = calloc(nodes_size, sizeof(sched->node_backend_ids[0]));
sched->leaf_backend_ids = calloc(nodes_size, sizeof(sched->leaf_backend_ids[0]));
sched->prev_node_backend_ids = calloc(nodes_size, sizeof(sched->prev_node_backend_ids[0]));
sched->prev_leaf_backend_ids = calloc(nodes_size, sizeof(sched->prev_leaf_backend_ids[0]));
-sched->context_buffer_size = GGML_SCHED_MAX_SPLITS*GGML_SCHED_MAX_SPLIT_INPUTS*2*sizeof(struct ggml_tensor) + ggml_graph_overhead_custom(graph_size, false);
+sched->context_buffer_size = ggml_sched_max_splits*GGML_SCHED_MAX_SPLIT_INPUTS*2*sizeof(struct ggml_tensor) + ggml_graph_overhead_custom(graph_size, false);
sched->context_buffer = malloc(sched->context_buffer_size);
const int initial_splits_capacity = 16;
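
With the fixed `GGML_SCHED_MAX_SPLITS` cap gone, the splits array relies entirely on the grow-on-demand reallocation visible in `ggml_backend_sched_split_graph` above: start at `initial_splits_capacity` and double whenever the array fills up. The pattern in isolation, as a hedged C sketch (the struct and names are illustrative):

    #include <assert.h>
    #include <stdlib.h>

    struct split_list {
        void * data;      /* array of split structs; element type elided */
        size_t elem_size;
        int    size;
        int    capacity;  /* starts at 16, like initial_splits_capacity */
    };

    /* make room for one more element, doubling like sched->splits does */
    static void split_list_grow(struct split_list * l) {
        if (l->size == l->capacity) {
            l->capacity = l->capacity > 0 ? l->capacity * 2 : 16;
            l->data = realloc(l->data, (size_t) l->capacity * l->elem_size);
            assert(l->data != NULL);
        }
    }
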

View File

@ -82,17 +82,18 @@ static_assert(sizeof(rpc_tensor) % 8 == 0, "rpc_tensor size must be multiple of
// RPC commands
enum rpc_cmd {
-ALLOC_BUFFER = 0,
-GET_ALIGNMENT,
-GET_MAX_SIZE,
-BUFFER_GET_BASE,
-FREE_BUFFER,
-BUFFER_CLEAR,
-SET_TENSOR,
-GET_TENSOR,
-COPY_TENSOR,
-GRAPH_COMPUTE,
-GET_DEVICE_MEMORY,
+RPC_CMD_ALLOC_BUFFER = 0,
+RPC_CMD_GET_ALIGNMENT,
+RPC_CMD_GET_MAX_SIZE,
+RPC_CMD_BUFFER_GET_BASE,
+RPC_CMD_FREE_BUFFER,
+RPC_CMD_BUFFER_CLEAR,
+RPC_CMD_SET_TENSOR,
+RPC_CMD_GET_TENSOR,
+RPC_CMD_COPY_TENSOR,
+RPC_CMD_GRAPH_COMPUTE,
+RPC_CMD_GET_DEVICE_MEMORY,
+RPC_CMD_COUNT,
};
// RPC data structures
@ -330,7 +331,7 @@ GGML_CALL static void ggml_backend_rpc_buffer_free_buffer(ggml_backend_buffer_t
uint64_t remote_ptr = ctx->remote_ptr;
memcpy(input.data(), &remote_ptr, sizeof(remote_ptr));
std::vector<uint8_t> output;
-bool status = send_rpc_cmd(ctx->sock, FREE_BUFFER, input, output);
+bool status = send_rpc_cmd(ctx->sock, RPC_CMD_FREE_BUFFER, input, output);
GGML_ASSERT(status);
GGML_ASSERT(output.empty());
delete ctx;
@ -346,7 +347,7 @@ GGML_CALL static void * ggml_backend_rpc_buffer_get_base(ggml_backend_buffer_t b
uint64_t remote_ptr = ctx->remote_ptr;
memcpy(input.data(), &remote_ptr, sizeof(remote_ptr));
std::vector<uint8_t> output;
-bool status = send_rpc_cmd(ctx->sock, BUFFER_GET_BASE, input, output);
+bool status = send_rpc_cmd(ctx->sock, RPC_CMD_BUFFER_GET_BASE, input, output);
GGML_ASSERT(status);
GGML_ASSERT(output.size() == sizeof(uint64_t));
// output serialization format: | base_ptr (8 bytes) |
@ -405,7 +406,7 @@ GGML_CALL static void ggml_backend_rpc_buffer_set_tensor(ggml_backend_buffer_t b
memcpy(input.data() + sizeof(rpc_tensor), &offset, sizeof(offset));
memcpy(input.data() + sizeof(rpc_tensor) + sizeof(offset), data, size);
std::vector<uint8_t> output;
-bool status = send_rpc_cmd(ctx->sock, SET_TENSOR, input, output);
+bool status = send_rpc_cmd(ctx->sock, RPC_CMD_SET_TENSOR, input, output);
GGML_ASSERT(status);
}
@ -419,7 +420,7 @@ GGML_CALL static void ggml_backend_rpc_buffer_get_tensor(ggml_backend_buffer_t b
memcpy(input.data() + sizeof(rpc_tensor), &offset, sizeof(offset));
memcpy(input.data() + sizeof(rpc_tensor) + sizeof(offset), &size, sizeof(size));
std::vector<uint8_t> output;
-bool status = send_rpc_cmd(ctx->sock, GET_TENSOR, input, output);
+bool status = send_rpc_cmd(ctx->sock, RPC_CMD_GET_TENSOR, input, output);
GGML_ASSERT(status);
GGML_ASSERT(output.size() == size);
// output serialization format: | data (size bytes) |
@ -444,7 +445,7 @@ GGML_CALL static bool ggml_backend_rpc_buffer_cpy_tensor(ggml_backend_buffer_t b
memcpy(input.data(), &rpc_src, sizeof(rpc_src));
memcpy(input.data() + sizeof(rpc_src), &rpc_dst, sizeof(rpc_dst));
std::vector<uint8_t> output;
-bool status = send_rpc_cmd(ctx->sock, COPY_TENSOR, input, output);
+bool status = send_rpc_cmd(ctx->sock, RPC_CMD_COPY_TENSOR, input, output);
GGML_ASSERT(status);
// output serialization format: | result (1 byte) |
GGML_ASSERT(output.size() == 1);
@ -459,7 +460,7 @@ GGML_CALL static void ggml_backend_rpc_buffer_clear(ggml_backend_buffer_t buffer
memcpy(input.data(), &ctx->remote_ptr, sizeof(ctx->remote_ptr));
memcpy(input.data() + sizeof(ctx->remote_ptr), &value, sizeof(value));
std::vector<uint8_t> output;
-bool status = send_rpc_cmd(ctx->sock, BUFFER_CLEAR, input, output);
+bool status = send_rpc_cmd(ctx->sock, RPC_CMD_BUFFER_CLEAR, input, output);
GGML_ASSERT(status);
}
@ -488,7 +489,7 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer
memcpy(input.data(), &size, sizeof(size));
std::vector<uint8_t> output;
auto sock = get_socket(buft_ctx->endpoint);
-bool status = send_rpc_cmd(sock, ALLOC_BUFFER, input, output);
+bool status = send_rpc_cmd(sock, RPC_CMD_ALLOC_BUFFER, input, output);
GGML_ASSERT(status);
GGML_ASSERT(output.size() == 2*sizeof(uint64_t));
// output serialization format: | remote_ptr (8 bytes) | remote_size (8 bytes) |
@ -511,7 +512,7 @@ static size_t get_alignment(const std::shared_ptr<socket_t> & sock) {
// input serialization format: | 0 bytes |
std::vector<uint8_t> input;
std::vector<uint8_t> output;
-bool status = send_rpc_cmd(sock, GET_ALIGNMENT, input, output);
+bool status = send_rpc_cmd(sock, RPC_CMD_GET_ALIGNMENT, input, output);
GGML_ASSERT(status);
GGML_ASSERT(output.size() == sizeof(uint64_t));
// output serialization format: | alignment (8 bytes) |
@ -529,7 +530,7 @@ static size_t get_max_size(const std::shared_ptr<socket_t> & sock) {
// input serialization format: | 0 bytes |
std::vector<uint8_t> input;
std::vector<uint8_t> output;
-bool status = send_rpc_cmd(sock, GET_MAX_SIZE, input, output);
+bool status = send_rpc_cmd(sock, RPC_CMD_GET_MAX_SIZE, input, output);
GGML_ASSERT(status);
GGML_ASSERT(output.size() == sizeof(uint64_t));
// output serialization format: | max_size (8 bytes) |
@ -622,7 +623,7 @@ GGML_CALL static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t
serialize_graph(cgraph, input);
std::vector<uint8_t> output;
auto sock = get_socket(rpc_ctx->endpoint);
-bool status = send_rpc_cmd(sock, GRAPH_COMPUTE, input, output);
+bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_COMPUTE, input, output);
GGML_ASSERT(status);
GGML_ASSERT(output.size() == 1);
return (enum ggml_status)output[0];
@ -636,7 +637,7 @@ GGML_CALL static bool ggml_backend_rpc_supports_op(ggml_backend_t backend, const
}
GGML_CALL static bool ggml_backend_rpc_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft) {
-if (buft->iface.get_name != ggml_backend_rpc_buffer_type_name) {
+if (!buft || buft->iface.get_name != ggml_backend_rpc_buffer_type_name) {
return false;
}
ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
@ -678,6 +679,7 @@ GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const
}
auto sock = get_socket(endpoint);
if (sock == nullptr) {
+fprintf(stderr, "Failed to connect to %s\n", endpoint);
return nullptr;
}
size_t alignment = get_alignment(sock);
@ -719,7 +721,7 @@ static void get_device_memory(const std::shared_ptr<socket_t> & sock, size_t * f
// input serialization format: | 0 bytes |
std::vector<uint8_t> input;
std::vector<uint8_t> output;
-bool status = send_rpc_cmd(sock, GET_DEVICE_MEMORY, input, output);
+bool status = send_rpc_cmd(sock, RPC_CMD_GET_DEVICE_MEMORY, input, output);
GGML_ASSERT(status);
GGML_ASSERT(output.size() == 2*sizeof(uint64_t));
// output serialization format: | free (8 bytes) | total (8 bytes) |
@ -1098,59 +1100,69 @@ static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t fre
if (!recv_data(sockfd, &cmd, 1)) {
break;
}
+if (cmd >= RPC_CMD_COUNT) {
+// fail fast if the command is invalid
+fprintf(stderr, "Unknown command: %d\n", cmd);
+break;
+}
std::vector<uint8_t> input;
std::vector<uint8_t> output;
uint64_t input_size;
if (!recv_data(sockfd, &input_size, sizeof(input_size))) {
break;
}
-input.resize(input_size);
+try {
+input.resize(input_size);
+} catch (const std::bad_alloc & e) {
+fprintf(stderr, "Failed to allocate input buffer of size %" PRIu64 "\n", input_size);
+break;
+}
if (!recv_data(sockfd, input.data(), input_size)) {
break;
}
bool ok = true;
switch (cmd) {
-case ALLOC_BUFFER: {
+case RPC_CMD_ALLOC_BUFFER: {
ok = server.alloc_buffer(input, output);
break;
}
-case GET_ALIGNMENT: {
+case RPC_CMD_GET_ALIGNMENT: {
server.get_alignment(output);
break;
}
-case GET_MAX_SIZE: {
+case RPC_CMD_GET_MAX_SIZE: {
server.get_max_size(output);
break;
}
-case BUFFER_GET_BASE: {
+case RPC_CMD_BUFFER_GET_BASE: {
ok = server.buffer_get_base(input, output);
break;
}
-case FREE_BUFFER: {
+case RPC_CMD_FREE_BUFFER: {
ok = server.free_buffer(input);
break;
}
-case BUFFER_CLEAR: {
+case RPC_CMD_BUFFER_CLEAR: {
ok = server.buffer_clear(input);
break;
}
-case SET_TENSOR: {
+case RPC_CMD_SET_TENSOR: {
ok = server.set_tensor(input);
break;
}
-case GET_TENSOR: {
+case RPC_CMD_GET_TENSOR: {
ok = server.get_tensor(input, output);
break;
}
-case COPY_TENSOR: {
+case RPC_CMD_COPY_TENSOR: {
ok = server.copy_tensor(input, output);
break;
}
-case GRAPH_COMPUTE: {
+case RPC_CMD_GRAPH_COMPUTE: {
ok = server.graph_compute(input, output);
break;
}
-case GET_DEVICE_MEMORY: {
+case RPC_CMD_GET_DEVICE_MEMORY: {
// output serialization format: | free (8 bytes) | total (8 bytes) |
output.resize(2*sizeof(uint64_t), 0);
memcpy(output.data(), &free_mem, sizeof(free_mem));
@ -1203,8 +1215,10 @@ void start_rpc_server(ggml_backend_t backend, const char * endpoint, size_t free
return;
}
printf("Accepted client connection, free_mem=%zu, total_mem=%zu\n", free_mem, total_mem);
+fflush(stdout);
rpc_serve_client(backend, client_socket->fd, free_mem, total_mem);
printf("Client connection closed\n");
+fflush(stdout);
}
#ifdef _WIN32
WSACleanup();
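
Two defensive changes stand out in `rpc_serve_client`: the one-byte command is range-checked against `RPC_CMD_COUNT` before dispatch, and the peer-controlled `input_size` is treated as untrusted by catching the allocation failure instead of letting it take the server down. A hedged, self-contained sketch of the same guards (the enum and function here are illustrative, not the backend's API):

    #include <cstdint>
    #include <cstdio>
    #include <new>
    #include <vector>

    enum cmd_sketch : uint8_t { CMD_A = 0, CMD_B, CMD_COUNT }; // illustrative

    static bool read_request(uint8_t cmd, uint64_t input_size,
                             std::vector<uint8_t> & input) {
        if (cmd >= CMD_COUNT) {
            fprintf(stderr, "Unknown command: %d\n", cmd); // fail fast
            return false;
        }
        try {
            input.resize(input_size); // may be absurdly large if the peer is hostile
        } catch (const std::bad_alloc &) {
            fprintf(stderr, "Failed to allocate input buffer of size %llu\n",
                    (unsigned long long) input_size);
            return false;
        }
        return true;
    }
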

View File

@ -38,6 +38,7 @@
#include "ggml-sycl/backend.hpp"
#include "ggml-sycl/presets.hpp"
+#include "ggml-sycl/gemm.hpp"
bool ggml_sycl_loaded(void);
void ggml_sycl_free_data(struct ggml_tensor * tensor);
@ -893,43 +894,6 @@ static void clamp_f32(const float * x, float * dst, const float min, const float
dst[i] = x[i] < min ? min : (x[i] > max ? max : x[i]);
}
-template <typename T>
-static void im2col_kernel(const float *x, T *dst, int offset_delta,
-int IW, int IH, int OW, int KW, int KH,
-int pelements, int CHW, int s0, int s1, int p0,
-int p1, int d0, int d1,
-const sycl::nd_item<3> &item_ct1) {
-const int i = item_ct1.get_local_id(2) +
-item_ct1.get_group(2) * item_ct1.get_local_range(2);
-if (i >= pelements) {
-return;
-}
-const int ksize = OW * (KH > 1 ? KW : 1);
-const int kx = i / ksize;
-const int kd = kx * ksize;
-const int ky = (i - kd) / OW;
-const int ix = i % OW;
-const int64_t iiw = ix * s0 + kx * d0 - p0;
-const int64_t iih = item_ct1.get_group(1) * s1 + ky * d1 - p1;
-const int64_t offset_dst =
-(item_ct1.get_group(1) * OW + ix) * CHW +
-(item_ct1.get_group(0) * (KW * KH) + ky * KW + kx);
-if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
-dst[offset_dst] =
-sycl::vec<float, 1>(0.0f)
-.convert<sycl::half, sycl::rounding_mode::automatic>()[0];
-} else {
-const int64_t offset_src = item_ct1.get_group(0) * offset_delta;
-dst[offset_dst] =
-sycl::vec<float, 1>(x[offset_src + iih * IW + iiw])
-.convert<sycl::half, sycl::rounding_mode::automatic>()[0];
-}
-}
template <typename Ti, typename To>
static void pool2d_nchw_kernel(
const int ih, const int iw, const int oh, const int ow,
@ -1742,32 +1706,6 @@ static void diag_mask_inf_f32_sycl(const float *x, float *dst,
});
}
-template <typename T>
-static void im2col_sycl(const float *x, T *dst, int IW, int IH,
-int OW, int OH, int KW, int KH, int IC,
-int offset_delta, int s0, int s1, int p0,
-int p1, int d0, int d1,
-queue_ptr stream) {
-const int parallel_elements = OW * KW * KH;
-const int num_blocks = (parallel_elements + SYCL_IM2COL_BLOCK_SIZE - 1) / SYCL_IM2COL_BLOCK_SIZE;
-sycl::range<3> block_nums(IC, OH, num_blocks);
-{
-dpct::has_capability_or_fail(stream->get_device(),
-{sycl::aspect::fp16});
-stream->parallel_for(
-sycl::nd_range<3>(block_nums *
-sycl::range<3>(1, 1, SYCL_IM2COL_BLOCK_SIZE),
-sycl::range<3>(1, 1, SYCL_IM2COL_BLOCK_SIZE)),
-[=](sycl::nd_item<3> item_ct1) {
-im2col_kernel(x, dst, offset_delta, IW, IH, OW, KW, KH,
-parallel_elements, (IC * KH * KW), s0, s1, p0,
-p1, d0, d1, item_ct1);
-});
-}
-}
static bool g_sycl_loaded = false;
bool ggml_sycl_loaded(void) {
@ -2545,6 +2483,7 @@ inline void ggml_sycl_op_mul_mat_sycl(
const sycl::half alpha_f16 = 1.0f;
const sycl::half beta_f16 = 0.0f;
+#if !GGML_SYCL_DNNL
SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm(
*stream, oneapi::mkl::transpose::trans,
oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10,
@ -2554,6 +2493,13 @@ inline void ggml_sycl_op_mul_mat_sycl(
dpct::library_data_t::real_half)));
const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16);
to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream);
+#else
+auto dnnl_stream = ctx.stream_dnnl(stream);
+DnnlGemmWrapper::row_gemm(dnnl_stream, false, true, src1_ncols, row_diff, ne10, src1_ptr, DnnlGemmWrapper::to_dt<sycl::half>(),
+src0_ptr, DnnlGemmWrapper::to_dt<sycl::half>(), dst_f16.get(), DnnlGemmWrapper::to_dt<sycl::half>());
+const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16);
+to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff* src1_ncols, stream);
+#endif
}
else {
// GGML_SYCL_DEBUG("ggml_sycl_op_mul_mat_sycl - fp32 path\n");
@ -2576,13 +2522,18 @@ inline void ggml_sycl_op_mul_mat_sycl(
const float alpha = 1.0f;
const float beta = 0.0f;
+#if !GGML_SYCL_DNNL
SYCL_CHECK(CHECK_TRY_ERROR(oneapi::mkl::blas::column_major::gemm(
*stream, oneapi::mkl::transpose::trans,
oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10,
dpct::get_value(&alpha, *stream), src0_ddf_i, ne00,
src1_ddf1_i, ne10, dpct::get_value(&beta, *stream),
dst_dd_i, ldc)));
+#else
+auto dnnl_stream = ctx.stream_dnnl(stream);
+DnnlGemmWrapper::row_gemm(dnnl_stream, false, true, src1_ncols, row_diff, ne10, src1_ddf1_i, DnnlGemmWrapper::to_dt<float>(),
+src0_ddf_i, DnnlGemmWrapper::to_dt<float>(), dst_dd_i, DnnlGemmWrapper::to_dt<float>());
+#endif
}
(void) dst;
(void) src1_ddq_i;
@ -2636,47 +2587,6 @@ static void ggml_sycl_op_pool2d(ggml_backend_sycl_context & ctx, const ggml_tens
(void) src1_dd;
}
-inline void ggml_sycl_op_im2col(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
-const ggml_tensor *src1, ggml_tensor *dst,
-const float *src0_dd, const float *src1_dd,
-float *dst_dd,
-const queue_ptr &main_stream) {
-GGML_ASSERT(src0->type == GGML_TYPE_F16);
-GGML_ASSERT(src1->type == GGML_TYPE_F32);
-GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);
-const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
-const int32_t s1 = ((const int32_t*)(dst->op_params))[1];
-const int32_t p0 = ((const int32_t*)(dst->op_params))[2];
-const int32_t p1 = ((const int32_t*)(dst->op_params))[3];
-const int32_t d0 = ((const int32_t*)(dst->op_params))[4];
-const int32_t d1 = ((const int32_t*)(dst->op_params))[5];
-const bool is_2D = ((const int32_t*)(dst->op_params))[6] == 1;
-const int64_t IC = src1->ne[is_2D ? 2 : 1];
-const int64_t IH = is_2D ? src1->ne[1] : 1;
-const int64_t IW = src1->ne[0];
-const int64_t KH = is_2D ? src0->ne[1] : 1;
-const int64_t KW = src0->ne[0];
-const int64_t OH = is_2D ? dst->ne[2] : 1;
-const int64_t OW = dst->ne[1];
-const size_t delta_offset = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32
-if (dst->type == GGML_TYPE_F16) {
-im2col_sycl(src1_dd, (sycl::half *)dst_dd, IW, IH, OW, OH, KW, KH, IC, delta_offset, s0, s1, p0, p1, d0, d1, main_stream);
-} else {
-im2col_sycl(src1_dd, (float *)dst_dd, IW, IH, OW, OH, KW, KH, IC, delta_offset, s0, s1, p0, p1, d0, d1, main_stream);
-}
-(void) src0;
-(void) src0_dd;
-}
inline void ggml_sycl_op_sum_rows(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
const ggml_tensor *src1, ggml_tensor *dst,
const float *src0_dd, const float *src1_dd,
@ -3581,7 +3491,8 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
bool use_mul_mat_vec_q = ggml_is_quantized(src0->type)
&& src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
-&& src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
+&& src1->ne[1] <= MMVQ_MAX_BATCH_SIZE
+&& (ctx.stream()->get_backend() == sycl::backend::ext_oneapi_cuda || src1->ne[1] > MMVQ_MIN_BATCH_SIZE);
bool use_mul_mat_q = ggml_sycl_supports_mmq(src0->type)
&& src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;

View File

@ -25,5 +25,6 @@
#include "norm.hpp"
#include "softmax.hpp"
#include "tsembd.hpp"
+#include "im2col.hpp"
#endif // GGML_SYCL_BACKEND_HPP

View File

@ -51,3 +51,14 @@ void ggml_sycl_host_free(void* ptr) try {
<< ", line:" << __LINE__ << std::endl;
std::exit(1);
}
+int64_t downsample_sycl_global_range(int64_t accumulate_block_num, int64_t block_size) {
+const int64_t max_range = std::numeric_limits<int>::max();
+int64_t sycl_down_blk_size = block_size;
+int64_t global_range = accumulate_block_num * sycl_down_blk_size;
+while(global_range > max_range) {
+sycl_down_blk_size /= 2;
+global_range = accumulate_block_num * sycl_down_blk_size;
+}
+return sycl_down_blk_size;
+}
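
The helper halves the block size until `accumulate_block_num * block_size` fits in a signed 32-bit range, since a SYCL global range may not exceed `INT_MAX`. A self-contained walk-through with hypothetical numbers:

    #include <cstdint>
    #include <limits>

    int64_t downsample_demo() {
        const int64_t num_blocks = 10000000; // hypothetical, very large launch
        const int64_t block_size = 256;
        // 10000000 * 256 = 2.56e9 > INT_MAX, so one halving (256 -> 128)
        // brings the range under the limit; the helper above would return 128
        int64_t bs = block_size;
        while (num_blocks * bs > std::numeric_limits<int>::max()) {
            bs /= 2;
        }
        return bs;
    }
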

View File

@ -19,6 +19,10 @@
#include "dpct/helper.hpp"
#include "ggml-sycl.h"
#include "presets.hpp"
+#if GGML_SYCL_DNNL
+#include "dnnl.hpp"
+#include "dnnl_sycl.hpp"
+#endif
#define GGML_COMMON_DECL_SYCL
#define GGML_COMMON_IMPL_SYCL
@ -130,6 +134,7 @@ typedef sycl::float2 dfloat2;
#endif // GGML_SYCL_F16
#define MMVQ_MAX_BATCH_SIZE 8
+#define MMVQ_MIN_BATCH_SIZE 4
static const int8_t kvalues_iq4nl[16]={-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
@ -276,6 +281,52 @@ struct ggml_backend_sycl_context {
return stream(device, 0);
}
+#if GGML_SYCL_DNNL
+dnnl::engine make_engine(sycl::queue* q) {
+// Get the device associated with the queue
+sycl::device dev = q->get_device();
+// Get the context associated with the queue
+sycl::context ctx = q->get_context();
+const dnnl::engine eng = dnnl::sycl_interop::make_engine(dev, ctx);
+return eng;
+}
+std::unordered_map<sycl::queue*, dnnl::stream> stream_map;
+std::unordered_map<sycl::queue*, dnnl::engine> engine_map;
+dnnl::stream stream_dnnl(int device, int _stream) {
+auto q = stream(device, _stream);
+return stream_dnnl(q);
+}
+dnnl::engine engine_dnnl(sycl::queue* qptr) {
+auto it = engine_map.find(qptr);
+if (it == engine_map.end()) {
+auto eng = make_engine(qptr);
+engine_map[qptr] = eng;
+return eng;
+}
+else
+{
+return it->second;
+}
+}
+dnnl::stream stream_dnnl(sycl::queue* qptr) {
+auto it = stream_map.find(qptr);
+if (it == stream_map.end()) {
+auto eng = engine_dnnl(qptr);
+auto stream = dnnl::sycl_interop::make_stream(eng, *qptr);
+stream_map[qptr] = stream;
+return stream;
+}
+else
+{
+return it->second;
+}
+}
+dnnl::stream stream_dnnl() {
+return stream_dnnl(device, 0);
+}
+#endif
// pool
std::unique_ptr<ggml_sycl_pool> pools[GGML_SYCL_MAX_DEVICES];
@ -352,4 +403,6 @@ static __dpct_inline__ Tp* get_pointer(sycl::local_accessor<Tp, dim> acc) {
return acc.template get_multi_ptr<sycl::access::decorated::no>().get();
}
+int64_t downsample_sycl_global_range(int64_t accumulate_block_num, int64_t block_size);
#endif // GGML_SYCL_COMMON_HPP

View File

@ -3,19 +3,19 @@
#include "presets.hpp"
template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
-static void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int k,
+static void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k,
const sycl::nd_item<3> &item_ct1) {
-const int i = 2 * (item_ct1.get_local_range(2) * item_ct1.get_group(2) +
+const int64_t i = 2 * (item_ct1.get_local_range(2) * item_ct1.get_group(2) +
item_ct1.get_local_id(2));
if (i >= k) {
return;
}
-const int ib = i/qk; // block index
-const int iqs = (i%qk)/qr; // quant index
-const int iybs = i - i%qk; // y block start index
-const int y_offset = qr == 1 ? 1 : qk/2;
+const int64_t ib = i/qk; // block index
+const int64_t iqs = (i%qk)/qr; // quant index
+const int64_t iybs = i - i%qk; // y block start index
+const int64_t y_offset = qr == 1 ? 1 : qk/2;
// dequantize
dfloat2 v;
@ -27,9 +27,9 @@ static void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__
template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
static void dequantize_block_sycl(const void *__restrict__ vx,
-dst_t *__restrict__ y, const int k,
+dst_t *__restrict__ y, const int64_t k,
dpct::queue_ptr stream) {
-const int num_blocks = (k + 2*SYCL_DEQUANTIZE_BLOCK_SIZE - 1) / (2*SYCL_DEQUANTIZE_BLOCK_SIZE);
+const int64_t num_blocks = (k + 2*SYCL_DEQUANTIZE_BLOCK_SIZE - 1) / (2*SYCL_DEQUANTIZE_BLOCK_SIZE);
{
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
@ -45,9 +45,9 @@ static void dequantize_block_sycl(const void *__restrict__ vx,
}
template <typename dst_t>
-static void dequantize_row_q2_K_sycl(const void *vx, dst_t *y, const int k,
+static void dequantize_row_q2_K_sycl(const void *vx, dst_t *y, const int64_t k,
dpct::queue_ptr stream) {
-const int nb = k / QK_K;
+const int64_t nb = k / QK_K;
#if QK_K == 256
{
dpct::has_capability_or_fail(stream->get_device(),
@ -77,9 +77,9 @@ static void dequantize_row_q2_K_sycl(const void *vx, dst_t *y, const int k,
}
template <typename dst_t>
-static void dequantize_row_q3_K_sycl(const void *vx, dst_t *y, const int k,
+static void dequantize_row_q3_K_sycl(const void *vx, dst_t *y, const int64_t k,
dpct::queue_ptr stream) {
-const int nb = k / QK_K;
+const int64_t nb = k / QK_K;
#if QK_K == 256
{
dpct::has_capability_or_fail(stream->get_device(),
@ -108,10 +108,10 @@ static void dequantize_row_q3_K_sycl(const void *vx, dst_t *y, const int k,
}
template <typename dst_t>
-static void dequantize_row_q4_0_sycl(const void *vx, dst_t *y, const int k,
+static void dequantize_row_q4_0_sycl(const void *vx, dst_t *y, const int64_t k,
dpct::queue_ptr stream) {
-const int nb32 = k / 32;
-const int nb = (k + 255) / 256;
+const int64_t nb32 = k / 32;
+const int64_t nb = (k + 255) / 256;
{
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
@ -126,10 +126,10 @@ static void dequantize_row_q4_0_sycl(const void *vx, dst_t *y, const int k,
}
template <typename dst_t>
-static void dequantize_row_q4_1_sycl(const void *vx, dst_t *y, const int k,
+static void dequantize_row_q4_1_sycl(const void *vx, dst_t *y, const int64_t k,
dpct::queue_ptr stream) {
-const int nb32 = k / 32;
-const int nb = (k + 255) / 256;
+const int64_t nb32 = k / 32;
+const int64_t nb = (k + 255) / 256;
{
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
@ -145,9 +145,9 @@ static void dequantize_row_q4_1_sycl(const void *vx, dst_t *y, const int k,
template <typename dst_t>
-static void dequantize_row_q4_K_sycl(const void *vx, dst_t *y, const int k,
+static void dequantize_row_q4_K_sycl(const void *vx, dst_t *y, const int64_t k,
dpct::queue_ptr stream) {
-const int nb = k / QK_K;
+const int64_t nb = k / QK_K;
{
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
@ -165,9 +165,9 @@ static void dequantize_row_q4_K_sycl(const void *vx, dst_t *y, const int k,
}
template <typename dst_t>
-static void dequantize_row_q5_K_sycl(const void *vx, dst_t *y, const int k,
+static void dequantize_row_q5_K_sycl(const void *vx, dst_t *y, const int64_t k,
dpct::queue_ptr stream) {
-const int nb = k / QK_K;
+const int64_t nb = k / QK_K;
#if QK_K == 256
{
dpct::has_capability_or_fail(stream->get_device(),
@ -197,9 +197,9 @@ static void dequantize_row_q5_K_sycl(const void *vx, dst_t *y, const int k,
}
template <typename dst_t>
-static void dequantize_row_q6_K_sycl(const void *vx, dst_t *y, const int k,
+static void dequantize_row_q6_K_sycl(const void *vx, dst_t *y, const int64_t k,
dpct::queue_ptr stream) {
-const int nb = k / QK_K;
+const int64_t nb = k / QK_K;
#if QK_K == 256
{
dpct::has_capability_or_fail(stream->get_device(),
@ -229,9 +229,9 @@ static void dequantize_row_q6_K_sycl(const void *vx, dst_t *y, const int k,
}
template <typename dst_t>
-static void dequantize_row_iq1_s_sycl(const void *vx, dst_t *y, const int k,
+static void dequantize_row_iq1_s_sycl(const void *vx, dst_t *y, const int64_t k,
dpct::queue_ptr stream) {
-const int nb = k / QK_K;
+const int64_t nb = k / QK_K;
{
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
@ -250,9 +250,9 @@ static void dequantize_row_iq1_s_sycl(const void *vx, dst_t *y, const int k,
}
template <typename dst_t>
-static void dequantize_row_iq1_m_sycl(const void *vx, dst_t *y, const int k,
+static void dequantize_row_iq1_m_sycl(const void *vx, dst_t *y, const int64_t k,
dpct::queue_ptr stream) {
-const int nb = k / QK_K;
+const int64_t nb = k / QK_K;
{
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
@ -271,9 +271,9 @@ static void dequantize_row_iq1_m_sycl(const void *vx, dst_t *y, const int k,
}
template <typename dst_t>
-static void dequantize_row_iq2_xxs_sycl(const void *vx, dst_t *y, const int k,
+static void dequantize_row_iq2_xxs_sycl(const void *vx, dst_t *y, const int64_t k,
dpct::queue_ptr stream) {
-const int nb = k / QK_K;
+const int64_t nb = k / QK_K;
{
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
@ -292,9 +292,9 @@ static void dequantize_row_iq2_xxs_sycl(const void *vx, dst_t *y, const int k,
}
template <typename dst_t>
-static void dequantize_row_iq2_xs_sycl(const void *vx, dst_t *y, const int k,
+static void dequantize_row_iq2_xs_sycl(const void *vx, dst_t *y, const int64_t k,
dpct::queue_ptr stream) {
-const int nb = k / QK_K;
+const int64_t nb = k / QK_K;
{
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
@ -313,9 +313,9 @@ static void dequantize_row_iq2_xs_sycl(const void *vx, dst_t *y, const int k,
}
template <typename dst_t>
-static void dequantize_row_iq2_s_sycl(const void *vx, dst_t *y, const int k,
+static void dequantize_row_iq2_s_sycl(const void *vx, dst_t *y, const int64_t k,
dpct::queue_ptr stream) {
-const int nb = k / QK_K;
+const int64_t nb = k / QK_K;
{
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
@ -333,9 +333,9 @@ static void dequantize_row_iq2_s_sycl(const void *vx, dst_t *y, const int k,
template <typename dst_t>
-static void dequantize_row_iq3_xxs_sycl(const void *vx, dst_t *y, const int k,
+static void dequantize_row_iq3_xxs_sycl(const void *vx, dst_t *y, const int64_t k,
dpct::queue_ptr stream) {
-const int nb = k / QK_K;
+const int64_t nb = k / QK_K;
{
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
@ -354,9 +354,9 @@ static void dequantize_row_iq3_xxs_sycl(const void *vx, dst_t *y, const int k,
}
template <typename dst_t>
-static void dequantize_row_iq3_s_sycl(const void *vx, dst_t *y, const int k,
+static void dequantize_row_iq3_s_sycl(const void *vx, dst_t *y, const int64_t k,
dpct::queue_ptr stream) {
-const int nb = k / QK_K;
+const int64_t nb = k / QK_K;
{
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
@ -374,9 +374,9 @@ static void dequantize_row_iq3_s_sycl(const void *vx, dst_t *y, const int k,
}
template <typename dst_t>
-static void dequantize_row_iq4_xs_sycl(const void *vx, dst_t *y, const int k,
+static void dequantize_row_iq4_xs_sycl(const void *vx, dst_t *y, const int64_t k,
dpct::queue_ptr stream) {
-const int nb = (k + QK_K - 1) / QK_K;
+const int64_t nb = (k + QK_K - 1) / QK_K;
#if QK_K == 64
dequantize_row_iq4_nl_sycl(vx, y, k, stream);
#else
@ -398,9 +398,9 @@ static void dequantize_row_iq4_xs_sycl(const void *vx, dst_t *y, const int k,
}
template <typename dst_t>
-static void dequantize_row_iq4_nl_sycl(const void *vx, dst_t *y, const int k,
+static void dequantize_row_iq4_nl_sycl(const void *vx, dst_t *y, const int64_t k,
dpct::queue_ptr stream) {
-const int nb = (k + QK_K - 1) / QK_K;
+const int64_t nb = (k + QK_K - 1) / QK_K;
{
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
@ -418,34 +418,34 @@ static void dequantize_row_iq4_nl_sycl(const void *vx, dst_t *y, const int k,
}
template <typename src_t, typename dst_t> template <typename src_t, typename dst_t>
static void convert_unary(const void * __restrict__ vx, dst_t * __restrict__ y, const int k, static void convert_unary(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k,
const sycl::nd_item<3> &item_ct1) { const sycl::nd_item<3> &item_ct1) {
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + const int64_t work_group_size = item_ct1.get_local_range(2);
item_ct1.get_local_id(2); const int64_t global_id = item_ct1.get_local_id(2) + work_group_size * item_ct1.get_group(2);
if (i >= k) {
return;
}
// make each work-item deal with more elements since sycl global range can not exceed max int
const src_t * x = (src_t *) vx; const src_t * x = (src_t *) vx;
for (int64_t i = global_id; i < k; i += work_group_size * item_ct1.get_group_range(2)) {
y[i] = x[i]; y[i] = x[i];
}
} }
template <typename src_t, typename dst_t> template <typename src_t, typename dst_t>
static void convert_unary_sycl(const void *__restrict__ vx, static void convert_unary_sycl(const void *__restrict__ vx,
dst_t *__restrict__ y, const int k, dst_t *__restrict__ y, const int64_t k,
dpct::queue_ptr stream) { dpct::queue_ptr stream) {
const int num_blocks = (k + SYCL_DEQUANTIZE_BLOCK_SIZE - 1) / SYCL_DEQUANTIZE_BLOCK_SIZE; const int64_t num_blocks = (k + SYCL_DEQUANTIZE_BLOCK_SIZE - 1) / SYCL_DEQUANTIZE_BLOCK_SIZE;
// decrease global range when it exceeds the max int
int64_t local_size = downsample_sycl_global_range(num_blocks, SYCL_DEQUANTIZE_BLOCK_SIZE);
sycl::range<3> block_nums(1, 1, num_blocks);
sycl::range<3> local_range(1, 1, local_size);
{ {
dpct::has_capability_or_fail(stream->get_device(), dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16}); {sycl::aspect::fp16});
stream->parallel_for( stream->parallel_for(
sycl::nd_range<3>( sycl::nd_range<3>(block_nums * local_range, local_range),
sycl::range<3>(1, 1, num_blocks) *
sycl::range<3>(1, 1, SYCL_DEQUANTIZE_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_DEQUANTIZE_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) { [=](sycl::nd_item<3> item_ct1) {
convert_unary<src_t>(vx, y, k, item_ct1); convert_unary<src_t>(vx, y, k, item_ct1);
}); });
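
For reference, the range-capping helper called above, downsample_sycl_global_range, halves the work-group size until the total ND-range fits in a signed 32-bit int; kernels then cover the remainder with the grid-stride loop shown in convert_unary. A self-contained sketch of that logic (signature assumed from the call site; the real helper lives in the backend's common code and may differ in detail):

    #include <cstdint>
    #include <limits>

    // Shrink the per-block local size until blocks * local_size fits in an int,
    // since a SYCL global range may not exceed std::numeric_limits<int>::max().
    static int64_t downsample_sycl_global_range(int64_t accumulate_block_num, int64_t block_size) {
        const int64_t max_range = std::numeric_limits<int>::max();
        int64_t down_blk_size = block_size;
        int64_t global_range  = accumulate_block_num * down_blk_size;
        while (global_range > max_range) {
            down_blk_size /= 2;
            global_range = accumulate_block_num * down_blk_size;
        }
        return down_blk_size;
    }
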
@@ -17,7 +17,7 @@

 template <typename T>
 using to_t_sycl_t = void (*)(const void *__restrict__ x, T *__restrict__ y,
-                             int k, dpct::queue_ptr stream);
+                             int64_t k, dpct::queue_ptr stream);

 typedef to_t_sycl_t<float> to_fp32_sycl_t;
 typedef to_t_sycl_t<sycl::half> to_fp16_sycl_t;
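
The int -> int64_t widening that runs through these signatures is not cosmetic: k counts tensor elements, and a single large tensor can exceed INT_MAX elements, at which point 32-bit indexing silently wraps. A minimal illustration (the element count is hypothetical):

    #include <cstdint>
    #include <cstdio>

    int main() {
        const int64_t k   = 3LL * 1024 * 1024 * 1024; // ~3.2e9 elements, > INT32_MAX
        const int32_t k32 = (int32_t) k;              // truncating conversion wraps
        std::printf("k = %lld, as int32 = %d\n", (long long) k, k32); // negative on typical targets
        return 0;
    }
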
@@ -15,9 +15,9 @@

 #include "common.hpp"

-typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs, dfloat2 & v);
+typedef void (*dequantize_kernel_t)(const void * vx, const int64_t ib, const int iqs, dfloat2 & v);

-static __dpct_inline__ void dequantize_q4_0(const void *vx, const int ib,
+static __dpct_inline__ void dequantize_q4_0(const void *vx, const int64_t ib,
                                             const int iqs, dfloat2 &v) {
     const block_q4_0 * x = (const block_q4_0 *) vx;

@@ -40,7 +40,7 @@ static __dpct_inline__ void dequantize_q4_0(const void *vx, const int ib,
 #endif // GGML_SYCL_F16
 }

-static __dpct_inline__ void dequantize_q4_1(const void *vx, const int ib,
+static __dpct_inline__ void dequantize_q4_1(const void *vx, const int64_t ib,
                                             const int iqs, dfloat2 &v) {
     const block_q4_1 * x = (const block_q4_1 *) vx;

@@ -64,7 +64,7 @@ static __dpct_inline__ void dequantize_q4_1(const void *vx, const int ib,
 #endif // GGML_SYCL_F16
 }

-static __dpct_inline__ void dequantize_q5_0(const void *vx, const int ib,
+static __dpct_inline__ void dequantize_q5_0(const void *vx, const int64_t ib,
                                             const int iqs, dfloat2 &v) {
     const block_q5_0 * x = (const block_q5_0 *) vx;

@@ -91,7 +91,7 @@ static __dpct_inline__ void dequantize_q5_0(const void *vx, const int ib,
 #endif // GGML_SYCL_F16
 }

-static __dpct_inline__ void dequantize_q5_1(const void *vx, const int ib,
+static __dpct_inline__ void dequantize_q5_1(const void *vx, const int64_t ib,
                                             const int iqs, dfloat2 &v) {
     const block_q5_1 * x = (const block_q5_1 *) vx;

@@ -118,7 +118,7 @@ static __dpct_inline__ void dequantize_q5_1(const void *vx, const int ib,
 #endif // GGML_SYCL_F16
 }

-static __dpct_inline__ void dequantize_q8_0(const void *vx, const int ib,
+static __dpct_inline__ void dequantize_q8_0(const void *vx, const int64_t ib,
                                             const int iqs, dfloat2 &v) {
     const block_q8_0 * x = (const block_q8_0 *) vx;

@@ -138,16 +138,16 @@ static __dpct_inline__ void dequantize_q8_0(const void *vx, const int ib,
 }

 template<typename dst_t>
-static void dequantize_block_q4_0(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32,
+static void dequantize_block_q4_0(const void * __restrict__ vx, dst_t * __restrict__ yy, int64_t nb32,
                                   const sycl::nd_item<3> &item_ct1) {

-    const int i = item_ct1.get_group(2);
+    const int64_t i = item_ct1.get_group(2);

     // assume 32 threads
-    const int tid = item_ct1.get_local_id(2);
-    const int il  = tid/8;
-    const int ir  = tid%8;
-    const int ib  = 8*i + ir;
+    const int64_t tid = item_ct1.get_local_id(2);
+    const int64_t il  = tid/8;
+    const int64_t ir  = tid%8;
+    const int64_t ib  = 8*i + ir;
     if (ib >= nb32) {
         return;
     }
@@ -168,16 +168,16 @@ static void dequantize_block_q4_0(const void * __restrict__ vx, dst_t * __restri
 }

 template<typename dst_t>
-static void dequantize_block_q4_1(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32,
+static void dequantize_block_q4_1(const void * __restrict__ vx, dst_t * __restrict__ yy, int64_t nb32,
                                   const sycl::nd_item<3> &item_ct1) {

-    const int i = item_ct1.get_group(2);
+    const int64_t i = item_ct1.get_group(2);

     // assume 32 threads
-    const int tid = item_ct1.get_local_id(2);
-    const int il  = tid/8;
-    const int ir  = tid%8;
-    const int ib  = 8*i + ir;
+    const int64_t tid = item_ct1.get_local_id(2);
+    const int64_t il  = tid/8;
+    const int64_t ir  = tid%8;
+    const int64_t ib  = 8*i + ir;
     if (ib >= nb32) {
         return;
     }
@@ -203,14 +203,14 @@ template<typename dst_t>
 static void dequantize_block_q2_K(const void * __restrict__ vx, dst_t * __restrict__ yy,
                                   const sycl::nd_item<3> &item_ct1) {

-    const int i   = item_ct1.get_group(2);
+    const int64_t i   = item_ct1.get_group(2);
     const block_q2_K * x = (const block_q2_K *) vx;

-    const int tid = item_ct1.get_local_id(2);
+    const int64_t tid = item_ct1.get_local_id(2);
 #if QK_K == 256
-    const int n   = tid/32;
-    const int l   = tid - 32*n;
-    const int is  = 8*n + l/16;
+    const int64_t n   = tid/32;
+    const int64_t l   = tid - 32*n;
+    const int64_t is  = 8*n + l/16;

     const uint8_t q = x[i].qs[32*n + l];
     dst_t * y = yy + i*QK_K + 128*n;
@@ -222,8 +222,8 @@ static void dequantize_block_q2_K(const void * __restrict__ vx, dst_t * __restri
     y[l+64] = dall * (x[i].scales[is+4] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+4] >> 4);
     y[l+96] = dall * (x[i].scales[is+6] & 0xF) * ((q >> 6) & 3) - dmin * (x[i].scales[is+6] >> 4);
 #else
-    const int is = tid/16;  // 0 or 1
-    const int il = tid%16;  // 0...15
+    const int64_t is = tid/16;  // 0 or 1
+    const int64_t il = tid%16;  // 0...15
     const uint8_t q = x[i].qs[il] >> (2*is);
     dst_t * y = yy + i*QK_K + 16*is + il;
@@ -239,19 +239,19 @@ template<typename dst_t>
 static void dequantize_block_q3_K(const void * __restrict__ vx, dst_t * __restrict__ yy,
                                   const sycl::nd_item<3> &item_ct1) {

-    const int i = item_ct1.get_group(2);
+    const int64_t i = item_ct1.get_group(2);
     const block_q3_K * x = (const block_q3_K *) vx;

 #if QK_K == 256
-    const int r = item_ct1.get_local_id(2) / 4;
-    const int tid = r/2;
-    const int is0 = r%2;
-    const int l0 = 16 * is0 + 4 * (item_ct1.get_local_id(2) % 4);
-    const int n = tid / 4;
-    const int j = tid - 4*n;
+    const int64_t r = item_ct1.get_local_id(2) / 4;
+    const int64_t tid = r/2;
+    const int64_t is0 = r%2;
+    const int64_t l0 = 16 * is0 + 4 * (item_ct1.get_local_id(2) % 4);
+    const int64_t n = tid / 4;
+    const int64_t j = tid - 4*n;

     uint8_t m = 1 << (4*n + j);
-    int is = 8*n + 2*j + is0;
+    int64_t is = 8*n + 2*j + is0;
     int shift = 2*j;

     int8_t us = is < 4 ? (x[i].scales[is-0] & 0xF) | (((x[i].scales[is+8] >> 0) & 3) << 4) :
@@ -267,11 +267,11 @@ static void dequantize_block_q3_K(const void * __restrict__ vx, dst_t * __restri

     for (int l = l0; l < l0+4; ++l) y[l] = dl * ((int8_t)((q[l] >> shift) & 3) - ((hm[l] & m) ? 0 : 4));
 #else
-    const int tid = item_ct1.get_local_id(2);
-    const int is  = tid/16;  // 0 or 1
-    const int il  = tid%16;  // 0...15
-    const int im  = il/8;    // 0...1
-    const int in  = il%8;    // 0...7
+    const int64_t tid = item_ct1.get_local_id(2);
+    const int64_t is  = tid/16;  // 0 or 1
+    const int64_t il  = tid%16;  // 0...15
+    const int64_t im  = il/8;    // 0...1
+    const int64_t in  = il%8;    // 0...7

     dst_t * y = yy + i*QK_K + 16*is + il;
@@ -307,15 +307,15 @@ static void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restri
                                   uint8_t* scales_local, const sycl::nd_item<3> &item_ct1) {
     const block_q4_K * x = (const block_q4_K *) vx;

-    const int i = item_ct1.get_group(2);
+    const int64_t i = item_ct1.get_group(2);

 #if QK_K == 256
     // assume 32 threads
-    const int tid = item_ct1.get_local_id(2);
-    const int il  = tid/8;
-    const int ir  = tid%8;
-    const int is  = 2*il;
-    const int n   = 4;
+    const int64_t tid = item_ct1.get_local_id(2);
+    const int64_t il  = tid/8;
+    const int64_t ir  = tid%8;
+    const int64_t is  = 2*il;
+    const int64_t n   = 4;

     dst_t * y = yy + i*QK_K + 64*il + n*ir;
@@ -341,7 +341,7 @@ static void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restri
             y[l +32] = d2 * (q_vec[l] >> 4) - m2;
         }
 #else
-    const int tid = item_ct1.get_local_id(2);
+    const int64_t tid = item_ct1.get_local_id(2);
     const uint8_t * q = x[i].qs;
     dst_t * y = yy + i*QK_K;
     const float d = (float)x[i].dm[0];
@@ -356,14 +356,14 @@ static void dequantize_block_q5_K(const void * __restrict__ vx, dst_t * __restri
                                   const sycl::nd_item<3> &item_ct1) {
     const block_q5_K * x = (const block_q5_K *) vx;

-    const int i = item_ct1.get_group(2);
+    const int64_t i = item_ct1.get_group(2);

 #if QK_K == 256
     // assume 64 threads - this is very slightly better than the one below
-    const int tid = item_ct1.get_local_id(2);
-    const int il  = tid/16;   // il is in 0...3
-    const int ir  = tid%16;   // ir is in 0...15
-    const int is  = 2*il;     // is is in 0...6
+    const int64_t tid = item_ct1.get_local_id(2);
+    const int64_t il  = tid/16;   // il is in 0...3
+    const int64_t ir  = tid%16;   // ir is in 0...15
+    const int64_t is  = 2*il;     // is is in 0...6

     dst_t * y = yy + i*QK_K + 64*il + 2*ir;
@@ -386,11 +386,11 @@ static void dequantize_block_q5_K(const void * __restrict__ vx, dst_t * __restri
     y[32] = d2 * ((ql[ 0] >> 4) + (qh[ 0] & hm ? 16 : 0)) - m2;
     y[33] = d2 * ((ql[ 1] >> 4) + (qh[ 1] & hm ? 16 : 0)) - m2;
 #else
-    const int tid = item_ct1.get_local_id(2);
+    const int64_t tid = item_ct1.get_local_id(2);
     const uint8_t q = x[i].qs[tid];
-    const int im = tid/8;  // 0...3
-    const int in = tid%8;  // 0...7
-    const int is = tid/16; // 0 or 1
+    const int64_t im = tid/8;  // 0...3
+    const int64_t in = tid%8;  // 0...7
+    const int64_t is = tid/16; // 0 or 1
     const uint8_t h = x[i].qh[in] >> im;
     const float d = x[i].d;
     dst_t * y = yy + i*QK_K + tid;
@@ -404,14 +404,14 @@ static void dequantize_block_q6_K(const void * __restrict__ vx, dst_t * __restri
                                   const sycl::nd_item<3> &item_ct1) {
     const block_q6_K * x = (const block_q6_K *) vx;

-    const int i = item_ct1.get_group(2);
+    const int64_t i = item_ct1.get_group(2);
 #if QK_K == 256

     // assume 64 threads - this is very slightly better than the one below
-    const int tid = item_ct1.get_local_id(2);
-    const int ip  = tid/32;      // ip is 0 or 1
-    const int il  = tid - 32*ip; // 0...32
-    const int is  = 8*ip + il/16;
+    const int64_t tid = item_ct1.get_local_id(2);
+    const int64_t ip  = tid/32;      // ip is 0 or 1
+    const int64_t il  = tid - 32*ip; // 0...32
+    const int64_t is  = 8*ip + il/16;

     dst_t * y = yy + i*QK_K + 128*ip + il;
@@ -428,9 +428,9 @@ static void dequantize_block_q6_K(const void * __restrict__ vx, dst_t * __restri
 #else

     // assume 32 threads
-    const int tid = item_ct1.get_local_id(2);
-    const int ip  = tid/16;      // 0 or 1
-    const int il  = tid - 16*ip; // 0...15
+    const int64_t tid = item_ct1.get_local_id(2);
+    const int64_t ip  = tid/16;      // 0 or 1
+    const int64_t il  = tid - 16*ip; // 0...15

     dst_t * y = yy + i*QK_K + 16*ip + il;
@@ -452,13 +452,13 @@ static void dequantize_block_iq2_xxs(const void * __restrict__ vx, dst_t * __res
                                      const uint8_t *ksigns_iq2xs_ptr,
                                      const uint8_t *kmask_iq2xs_ptr) {

-    const int i = item_ct1.get_group(2);
+    const int64_t i = item_ct1.get_group(2);
     const block_iq2_xxs * x = (const block_iq2_xxs *) vx;

-    const int tid = item_ct1.get_local_id(2);
+    const int64_t tid = item_ct1.get_local_id(2);
 #if QK_K == 256
-    const int il = tid/8; // 0...3
-    const int ib = tid%8; // 0...7
+    const int64_t il = tid/8; // 0...3
+    const int64_t ib = tid%8; // 0...7
     dst_t * y = yy + i*QK_K + 32*ib + 8*il;
     const uint16_t * q2 = x[i].qs + 4*ib;
     const uint8_t  * aux8 = (const uint8_t *)q2;
@@ -480,13 +480,13 @@ static void dequantize_block_iq2_xs(const void * __restrict__ vx, dst_t * __rest
                                     const uint8_t *ksigns_iq2xs,
                                     const uint8_t *kmask_iq2xs) {

-    const int i = item_ct1.get_group(2);
+    const int64_t i = item_ct1.get_group(2);
     const block_iq2_xs * x = (const block_iq2_xs *) vx;

-    const int tid = item_ct1.get_local_id(2);
+    const int64_t tid = item_ct1.get_local_id(2);
 #if QK_K == 256
-    const int il = tid/8; // 0...3
-    const int ib = tid%8; // 0...7
+    const int64_t il = tid/8; // 0...3
+    const int64_t ib = tid%8; // 0...7
     dst_t * y = yy + i*QK_K + 32*ib + 8*il;
     const uint16_t * q2 = x[i].qs + 4*ib;
     const uint8_t  * grid = (const uint8_t *)(iq2xs_grid + (q2[il] & 511));
@@ -504,13 +504,13 @@ __dpct_inline__ static void
 dequantize_block_iq2_s(const void *__restrict__ vx, dst_t *__restrict__ yy,
                        const sycl::nd_item<3> &item_ct1) {

-    const int i = item_ct1.get_group(2);
+    const int64_t i = item_ct1.get_group(2);
     const block_iq2_s * x = (const block_iq2_s *) vx;

-    const int tid = item_ct1.get_local_id(2);
+    const int64_t tid = item_ct1.get_local_id(2);
 #if QK_K == 256
-    const int il = tid/8; // 0...3
-    const int ib = tid%8; // 0...7
+    const int64_t il = tid/8; // 0...3
+    const int64_t ib = tid%8; // 0...7
     dst_t * y = yy + i*QK_K + 32*ib + 8*il;
     const uint8_t * grid = (const uint8_t *)(iq2s_grid + (x[i].qs[4*ib+il] | ((x[i].qh[ib] << (8-2*il)) & 0x300)));
     const float d = (float)x[i].d * (0.5f + ((x[i].scales[ib] >> 4*(il/2)) & 0xf)) * 0.25f;
@@ -532,13 +532,13 @@ static void dequantize_block_iq3_xxs(const void * __restrict__ vx, dst_t * __res
                                      const uint8_t *ksigns_iq2xs,
                                      const uint8_t *kmask_iq2xs) {

-    const int i = item_ct1.get_group(2);
+    const int64_t i = item_ct1.get_group(2);
     const block_iq3_xxs * x = (const block_iq3_xxs *) vx;

-    const int tid = item_ct1.get_local_id(2);
+    const int64_t tid = item_ct1.get_local_id(2);
 #if QK_K == 256
-    const int il = tid/8; // 0...3
-    const int ib = tid%8; // 0...7
+    const int64_t il = tid/8; // 0...3
+    const int64_t ib = tid%8; // 0...7
     dst_t * y = yy + i*QK_K + 32*ib + 8*il;
     const uint8_t  * q3 = x[i].qs + 8*ib;
     const uint16_t * gas = (const uint16_t *)(x[i].qs + QK_K/4) + 2*ib;
@@ -563,13 +563,13 @@ dequantize_block_iq3_s(const void *__restrict__ vx, dst_t *__restrict__ yy,
                        const sycl::nd_item<3> &item_ct1,
                        const uint8_t *kmask_iq2xs, const uint32_t *iq3s_grid) {

-    const int i = item_ct1.get_group(2);
+    const int64_t i = item_ct1.get_group(2);
     const block_iq3_s * x = (const block_iq3_s *) vx;

-    const int tid = item_ct1.get_local_id(2);
+    const int64_t tid = item_ct1.get_local_id(2);
 #if QK_K == 256
-    const int il = tid/8; // 0...3
-    const int ib = tid%8; // 0...7
+    const int64_t il = tid/8; // 0...3
+    const int64_t ib = tid%8; // 0...7
     dst_t * y = yy + i*QK_K + 32*ib + 8*il;
     const uint8_t * qs = x[i].qs + 8*ib;
     const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*il+0] | ((x[i].qh[ib] << (8-2*il)) & 256)));
@@ -593,13 +593,13 @@ dequantize_block_iq1_s(const void *__restrict__ vx, dst_t *__restrict__ yy,
                        const sycl::nd_item<3> &item_ct1,
                        const uint32_t *iq1s_grid_gpu) {

-    const int i = item_ct1.get_group(2);
+    const int64_t i = item_ct1.get_group(2);
     const block_iq1_s * x = (const block_iq1_s *) vx;

-    const int tid = item_ct1.get_local_id(2);
+    const int64_t tid = item_ct1.get_local_id(2);
 #if QK_K == 256
-    const int il = tid/8; // 0...3
-    const int ib = tid%8; // 0...7
+    const int64_t il = tid/8; // 0...3
+    const int64_t ib = tid%8; // 0...7
     dst_t * y = yy + i*QK_K + 32*ib + 8*il;
     const float delta = x[i].qh[ib] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA;
     const float d = (float)x[i].d * (2*((x[i].qh[ib] >> 12) & 7) + 1);
@@ -623,13 +623,13 @@ dequantize_block_iq1_m(const void *__restrict__ vx, dst_t *__restrict__ yy,
                        const sycl::nd_item<3> &item_ct1,
                        const uint32_t *iq1s_grid_gpu) {

-    const int i = item_ct1.get_group(2);
+    const int64_t i = item_ct1.get_group(2);
     const block_iq1_m * x = (const block_iq1_m *) vx;

-    const int tid = item_ct1.get_local_id(2);
+    const int64_t tid = item_ct1.get_local_id(2);
 #if QK_K == 256
-    const int il = tid/8; // 0...3
-    const int ib = tid%8; // 0...7
+    const int64_t il = tid/8; // 0...3
+    const int64_t ib = tid%8; // 0...7
     dst_t * y = yy + i*QK_K + 32*ib + 8*il;
     const uint16_t * sc = (const uint16_t *)x[i].scales;
     iq1m_scale_t scale;
@@ -656,12 +656,12 @@ __dpct_inline__ static void
 dequantize_block_iq4_nl(const void *__restrict__ vx, dst_t *__restrict__ yy,
                         const sycl::nd_item<3> &item_ct1) {

-    const int i = item_ct1.get_group(2);
+    const int64_t i = item_ct1.get_group(2);
     const block_iq4_nl * x = (const block_iq4_nl *) vx + i*(QK_K/QK4_NL);

-    const int tid = item_ct1.get_local_id(2);
-    const int il = tid/8; // 0...3
-    const int ib = tid%8; // 0...7
+    const int64_t tid = item_ct1.get_local_id(2);
+    const int64_t il = tid/8; // 0...3
+    const int64_t ib = tid%8; // 0...7
     dst_t * y = yy + i*QK_K + 32*ib + 4*il;
     const uint8_t  * q4 = x[ib].qs + 4*il;
     const float d = (float)x[ib].d;
@@ -678,12 +678,12 @@ template <typename dst_t>
 __dpct_inline__ static void
 dequantize_block_iq4_xs(const void *__restrict__ vx, dst_t *__restrict__ yy,
                         const sycl::nd_item<3> &item_ct1) {
-    const int i = item_ct1.get_group(2);
+    const int64_t i = item_ct1.get_group(2);
     const block_iq4_xs * x = (const block_iq4_xs *)vx;

-    const int tid = item_ct1.get_local_id(2);
-    const int il = tid/8; // 0...3
-    const int ib = tid%8; // 0...7
+    const int64_t tid = item_ct1.get_local_id(2);
+    const int64_t il = tid/8; // 0...3
+    const int64_t ib = tid%8; // 0...7
     dst_t * y = yy + i*QK_K + 32*ib + 4*il;
     const uint8_t  * q4 = x[i].qs + 16*ib + 4*il;
     const float d = (float)x[i].d * ((((x[i].scales_l[ib/2] >> 4*(ib%2)) & 0xf) | (((x[i].scales_h >> 2*ib) & 3) << 4)) - 32);
@@ -4,7 +4,7 @@

 #include "presets.hpp"

-static void convert_f16(const void * vx, const int ib, const int iqs, dfloat2 & v){
+static void convert_f16(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
     const sycl::half *x = (const sycl::half *)vx;

     // automatic half -> float type cast if dfloat == float
@@ -12,7 +12,7 @@ static void convert_f16(const void * vx, const int ib, const int iqs, dfloat2 &
     v.y() = x[ib + iqs + 1];
 }

-static void convert_f32(const void * vx, const int ib, const int iqs, dfloat2 & v){
+static void convert_f32(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
     const float * x = (const float *) vx;

     // automatic half -> float type cast if dfloat == float

ggml/src/ggml-sycl/gemm.hpp (new file, 101 lines)

@@ -0,0 +1,101 @@
//
// MIT license
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: MIT
//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
#ifndef GGML_SYCL_GEMM_HPP
#define GGML_SYCL_GEMM_HPP
#include <fstream>
#include <iostream>
#include "ggml-sycl.h"
#if GGML_SYCL_DNNL
#include "dnnl.hpp"
#include "dnnl_sycl.hpp"
class DnnlGemmWrapper {
public:
using dt = dnnl::memory::data_type;
using tag = dnnl::memory::format_tag;
template<typename T>
static constexpr dt to_dt() {
if constexpr (std::is_same_v<T, float>) return dt::f32;
else if constexpr (std::is_same_v<T, sycl::half>) return dt::f16;
else static_assert(0);
}
static inline void row_gemm(sycl::queue& q, bool a_trans,
bool b_trans, int m, int n, int k,
const void* a, dt at, const void* b, dt bt, void* c, dt ct)
{
// Get the device associated with the queue
sycl::device dev = q.get_device();
// Get the context associated with the queue
sycl::context ctx = q.get_context();
const dnnl::engine eng = dnnl::sycl_interop::make_engine(dev, ctx);
const dnnl::stream stream = dnnl::sycl_interop::make_stream(eng, q);
dnnl::memory::dims a_dims = { m, k };
dnnl::memory::dims b_dims = { k, n };
dnnl::memory::dims c_dims = { m, n };
const auto a_in_md = dnnl::memory::desc(a_dims, at, a_trans ? tag::ba : tag::ab);
const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_trans ? tag::ba : tag::ab);
const auto c_md = dnnl::memory::desc(c_dims, ct, tag::ab);
auto a_mem = dnnl::memory(a_in_md, eng, (void*)a);
auto b_mem = dnnl::memory(b_in_md, eng, (void*)b);
auto matmul_pd = dnnl::matmul::primitive_desc(eng, a_in_md, b_in_md, c_md);
auto c_mem = dnnl::memory(matmul_pd.dst_desc(), eng, c);
// Create the primitive.
auto matmul_prim = dnnl::matmul(matmul_pd);
// Primitive arguments.
std::unordered_map<int, dnnl::memory> matmul_args;
matmul_args.insert({ DNNL_ARG_SRC, a_mem });
matmul_args.insert({ DNNL_ARG_WEIGHTS, b_mem });
matmul_args.insert({ DNNL_ARG_DST, c_mem });
matmul_prim.execute(stream, matmul_args);
}
static inline void row_gemm(const dnnl::stream& stream, bool a_trans,
bool b_trans, int m, int n, int k,
const void* a, dt at, const void* b, dt bt, void* c, dt ct)
{
auto const eng = stream.get_engine();
dnnl::memory::dims a_dims = { m, k };
dnnl::memory::dims b_dims = { k, n };
dnnl::memory::dims c_dims = { m, n };
const auto a_in_md = dnnl::memory::desc(a_dims, at, a_trans ? tag::ba : tag::ab);
const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_trans ? tag::ba : tag::ab);
const auto c_md = dnnl::memory::desc(c_dims, ct, tag::ab);
auto a_mem = dnnl::memory(a_in_md, eng, (void*)a);
auto b_mem = dnnl::memory(b_in_md, eng, (void*)b);
auto matmul_pd = dnnl::matmul::primitive_desc(eng, a_in_md, b_in_md, c_md);
auto c_mem = dnnl::memory(matmul_pd.dst_desc(), eng, c);
// Create the primitive.
auto matmul_prim = dnnl::matmul(matmul_pd);
// Primitive arguments.
std::unordered_map<int, dnnl::memory> matmul_args;
matmul_args.insert({ DNNL_ARG_SRC, a_mem });
matmul_args.insert({ DNNL_ARG_WEIGHTS, b_mem });
matmul_args.insert({ DNNL_ARG_DST, c_mem });
matmul_prim.execute(stream, matmul_args);
}
};
#endif
#endif // GGML_SYCL_GEMM_HPP
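
A hedged usage sketch of the wrapper above (assumes GGML_SYCL_DNNL is enabled, the header is included, and a working SYCL queue; the dimensions follow the {m,k} x {k,n} -> {m,n} row-major convention encoded in row_gemm; not part of the diff):

    #include <sycl/sycl.hpp>
    #include <vector>

    void gemm_f32_example(sycl::queue & q) {
        const int m = 4, n = 8, k = 16;
        // oneDNN needs device-visible memory; USM shared keeps the sketch simple
        float * da = sycl::malloc_shared<float>(m * k, q);
        float * db = sycl::malloc_shared<float>(k * n, q);
        float * dc = sycl::malloc_shared<float>(m * n, q);
        for (int i = 0; i < m * k; ++i) da[i] = 1.0f;
        for (int i = 0; i < k * n; ++i) db[i] = 2.0f;

        using dt = DnnlGemmWrapper::dt;
        DnnlGemmWrapper::row_gemm(q, /*a_trans=*/false, /*b_trans=*/false,
                                  m, n, k, da, dt::f32, db, dt::f32, dc, dt::f32);
        q.wait(); // every element of dc should now be 32.0f (k * 1 * 2)

        sycl::free(da, q); sycl::free(db, q); sycl::free(dc, q);
    }
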
@@ -0,0 +1,125 @@
//
// MIT license
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: MIT
//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
#include "im2col.hpp"
template <typename T>
static void im2col_kernel(
const float *x, T *dst, int64_t batch_offset, int64_t offset_delta,
int64_t IC, int64_t IW, int64_t IH, int64_t OH, int64_t OW, int64_t KW, int64_t KH,
int64_t pelements, int64_t CHW, int s0, int s1, int p0, int p1, int d0, int d1,
const sycl::nd_item<3> &item_ct1) {
const int64_t work_group_size = item_ct1.get_local_range(2);
const int64_t global_id = item_ct1.get_local_id(2) + work_group_size * item_ct1.get_group(2);
// make each work-item deal with more elements since sycl global range can not exceed max int
for (int64_t i = global_id; i < pelements; i += work_group_size * item_ct1.get_group_range(2)) {
const int64_t ksize = OW * (KH > 1 ? KW : 1);
const int64_t kx = i / ksize;
const int64_t kd = kx * ksize;
const int64_t ky = (i - kd) / OW;
const int64_t ix = i % OW;
const int64_t oh = item_ct1.get_group(1);
const int64_t batch = item_ct1.get_group(0) / IC;
const int64_t ic = item_ct1.get_group(0) % IC;
const int64_t iiw = ix * s0 + kx * d0 - p0;
const int64_t iih = oh * s1 + ky * d1 - p1;
const int64_t offset_dst =
((batch * OH + oh) * OW + ix) * CHW +
(ic * (KW * KH) + ky * KW + kx);
if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
dst[offset_dst] =
sycl::vec<float, 1>(0.0f)
.convert<sycl::half, sycl::rounding_mode::automatic>()[0];
} else {
const int64_t offset_src = ic * offset_delta + batch * batch_offset;
dst[offset_dst] =
sycl::vec<float, 1>(x[offset_src + iih * IW + iiw])
.convert<sycl::half, sycl::rounding_mode::automatic>()[0];
}
}
}
template <typename T>
static void im2col_sycl(
const float *x, T *dst, int64_t IW, int64_t IH, int64_t OW, int64_t OH, int64_t KW,
int64_t KH, int64_t IC, int64_t batch, int64_t batch_offset, int64_t offset_delta,
int s0, int s1, int p0, int p1, int d0, int d1,
queue_ptr stream) {
const int64_t parallel_elements = OW * KW * KH;
const int64_t num_blocks = (parallel_elements + SYCL_IM2COL_BLOCK_SIZE - 1) / SYCL_IM2COL_BLOCK_SIZE;
// decrease global range when it exceeds the max int
int64_t local_size = downsample_sycl_global_range(batch * IC * OH * num_blocks, SYCL_IM2COL_BLOCK_SIZE);
sycl::range<3> block_nums(batch * IC, OH, num_blocks);
sycl::range<3> local_range(1, 1, local_size);
{
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
stream->parallel_for(
sycl::nd_range<3>(block_nums * local_range, local_range),
[=](sycl::nd_item<3> item_ct1) {
im2col_kernel(x, dst, batch_offset, offset_delta, IC, IW, IH, OH, OW, KW, KH,
parallel_elements, (IC * KH * KW), s0, s1, p0,
p1, d0, d1, item_ct1);
});
}
}
void ggml_sycl_op_im2col(
ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
ggml_tensor *dst, const float *src0_dd, const float *src1_dd, float *dst_dd,
const queue_ptr &main_stream) {
GGML_ASSERT(src0->type == GGML_TYPE_F16);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);
const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
const int32_t s1 = ((const int32_t*)(dst->op_params))[1];
const int32_t p0 = ((const int32_t*)(dst->op_params))[2];
const int32_t p1 = ((const int32_t*)(dst->op_params))[3];
const int32_t d0 = ((const int32_t*)(dst->op_params))[4];
const int32_t d1 = ((const int32_t*)(dst->op_params))[5];
const bool is_2D = ((const int32_t*)(dst->op_params))[6] == 1;
const int64_t IC = src1->ne[is_2D ? 2 : 1];
const int64_t IH = is_2D ? src1->ne[1] : 1;
const int64_t IW = src1->ne[0];
const int64_t KH = is_2D ? src0->ne[1] : 1;
const int64_t KW = src0->ne[0];
const int64_t OH = is_2D ? dst->ne[2] : 1;
const int64_t OW = dst->ne[1];
const size_t delta_offset = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32
const int64_t batch = src1->ne[3];
const size_t batch_offset = src1->nb[3] / 4; // nb is byte offset, src is type float32
if (dst->type == GGML_TYPE_F16) {
im2col_sycl(src1_dd, (sycl::half *)dst_dd, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, main_stream);
} else {
im2col_sycl(src1_dd, (float *)dst_dd, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, main_stream);
}
(void) src0;
(void) src0_dd;
}
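
The OH/OW values read from dst->ne above were fixed earlier, when the graph was built, using the standard convolution output-size arithmetic. Shown here only for orientation (the helper name is ours, not from the diff):

    #include <cstdint>

    // standard conv output size: floor((in + 2*pad - dilation*(kernel - 1) - 1) / stride) + 1
    static int64_t conv_out_size(int64_t in, int64_t kernel, int64_t stride,
                                 int64_t pad, int64_t dilation) {
        return (in + 2 * pad - dilation * (kernel - 1) - 1) / stride + 1;
    }

    // e.g. IH = 32, KH = 3, s1 = 1, p1 = 1, d1 = 1  ->  OH = 32
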
@@ -0,0 +1,23 @@
//
// MIT license
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: MIT
//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
#ifndef GGML_SYCL_IM2COL_HPP
#define GGML_SYCL_IM2COL_HPP
#include "common.hpp"
void ggml_sycl_op_im2col(
ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
ggml_tensor *dst, const float *src0_dd, const float *src1_dd, float *dst_dd,
const queue_ptr &main_stream);
#endif // GGML_SYCL_IM2COL_HPP

File diff suppressed because it is too large.

@@ -7253,43 +7253,34 @@ struct ggml_tensor * ggml_flash_attn_back(
 struct ggml_tensor * ggml_ssm_conv(
         struct ggml_context * ctx,
-        struct ggml_tensor  * s,
-        struct ggml_tensor  * x,
-        struct ggml_tensor  * c,
-        struct ggml_tensor  * sq) {
-    GGML_ASSERT(ggml_is_3d(s));
-    GGML_ASSERT(ggml_is_matrix(x));
+        struct ggml_tensor  * sx,
+        struct ggml_tensor  * c) {
+    GGML_ASSERT(ggml_is_3d(sx));
     GGML_ASSERT(ggml_is_matrix(c));
-    GGML_ASSERT(ggml_is_matrix(sq));
-    GGML_ASSERT(sq->type == GGML_TYPE_I32);

     const int64_t d_conv  = c->ne[0];
     const int64_t d_inner = c->ne[1];
-    const int64_t n_tokens = x->ne[1];
-    const int64_t n_kv     = s->ne[2];
+    const int64_t n_t     = sx->ne[0] - d_conv + 1; // tokens per sequence
+    const int64_t n_s     = sx->ne[2];

-    GGML_ASSERT( s->ne[0] == d_conv - 1);
-    GGML_ASSERT( s->ne[1] == d_inner);
-    GGML_ASSERT( x->ne[0] == d_inner);
-    GGML_ASSERT(sq->ne[0] == n_kv);
-    GGML_ASSERT(sq->ne[1] == n_tokens);
+    // TODO: maybe support other strides than 1?
+    GGML_ASSERT(sx->ne[0] == d_conv - 1 + n_t);
+    GGML_ASSERT(sx->ne[1] == d_inner);
+    GGML_ASSERT(n_t >= 0);

     bool is_node = false;

-    if (s->grad || x->grad || c->grad || sq->grad) {
+    if (sx->grad || c->grad) {
         GGML_ABORT("fatal error"); // TODO: implement
         is_node = true;
     }

-    // 2-in-1 concatenated x and conv_states, {d_inner, n_tokens} with {d_conv, d_inner, n_kv}
-    struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, (d_inner*n_tokens) + (d_conv*d_inner*n_kv));
+    struct ggml_tensor * result = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_inner, n_t, n_s);

     result->op   = GGML_OP_SSM_CONV;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
-    result->src[0] = s;
-    result->src[1] = x;
-    result->src[2] = c;
-    result->src[3] = sq;
+    result->src[0] = sx;
+    result->src[1] = c;

     return result;
 }
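
In math terms, the reworked ggml_ssm_conv is a per-sequence, per-channel sliding-window (depthwise causal) convolution. A sketch derived from the shapes asserted above, with sx of shape {d_conv - 1 + n_t, d_inner} and c of shape {d_conv, d_inner}:

    y_{i,t} \;=\; \sum_{j=0}^{d_{\mathrm{conv}}-1} sx_{\,t+j,\;i}\, \cdot\, c_{\,j,\;i}

The first d_conv - 1 columns of sx carry the convolution state across batches, which is what lets the old sq bookkeeping disappear.
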
@@ -7303,39 +7294,42 @@ struct ggml_tensor * ggml_ssm_scan(
         struct ggml_tensor  * dt,
         struct ggml_tensor  * A,
         struct ggml_tensor  * B,
-        struct ggml_tensor  * C,
-        struct ggml_tensor  * sq) {
+        struct ggml_tensor  * C) {
     GGML_ASSERT(ggml_is_contiguous(s));
     GGML_ASSERT(ggml_is_contiguous(x));
     GGML_ASSERT(ggml_is_contiguous(dt));
     GGML_ASSERT(ggml_is_contiguous(A));
-    GGML_ASSERT(sq->type == GGML_TYPE_I32);
+    GGML_ASSERT(ggml_is_matrix(A));
+    GGML_ASSERT(ggml_is_3d(B));
+    GGML_ASSERT(ggml_is_3d(s));
     GGML_ASSERT(B->nb[0] == ggml_type_size(B->type));
     GGML_ASSERT(C->nb[0] == ggml_type_size(C->type));
     GGML_ASSERT(ggml_are_same_shape(x, dt));
+    GGML_ASSERT(ggml_are_same_shape(B, C));

     {
         const int64_t d_state = s->ne[0];
         const int64_t d_inner = s->ne[1];
-        const int64_t n_tokens = x->ne[1];
+        const int64_t n_seq_tokens = x->ne[1];
+        const int64_t n_seqs       = x->ne[2];
+
+        GGML_ASSERT(s->ne[2] == n_seqs);
         GGML_ASSERT(x->ne[0] == d_inner);
         GGML_ASSERT(A->ne[0] == d_state);
         GGML_ASSERT(A->ne[1] == d_inner);
         GGML_ASSERT(B->ne[0] == d_state);
-        GGML_ASSERT(B->ne[1] == n_tokens);
-        GGML_ASSERT(C->ne[0] == d_state);
-        GGML_ASSERT(C->ne[1] == n_tokens);
+        GGML_ASSERT(B->ne[1] == n_seq_tokens);
+        GGML_ASSERT(B->ne[2] == n_seqs);
     }

     bool is_node = false;

-    if (s->grad || x->grad || dt->grad || A->grad || B->grad || C->grad || sq->grad) {
+    if (s->grad || x->grad || dt->grad || A->grad || B->grad || C->grad) {
         GGML_ABORT("fatal error"); // TODO: implement
         is_node = true;
     }

-    // 2-in-1 concatenated y and ssm_states, {d_inner, n_tokens} with {d_state, d_inner, n_kv}
+    // concatenated y + ssm_states
     struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ggml_nelements(x) + ggml_nelements(s));

     result->op = GGML_OP_SSM_SCAN;
@@ -7346,7 +7340,6 @@ struct ggml_tensor * ggml_ssm_scan(
     result->src[3] = A;
     result->src[4] = B;
     result->src[5] = C;
-    result->src[6] = sq;

     return result;
 }
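
Dropping sq means the op now consumes whole sequences laid out contiguously as {d_inner, n_seq_tokens, n_seqs}. The recurrence itself is the usual Mamba selective scan; per channel i and state dimension j, with the softplus'd step \Delta computed in the forward kernel further down (sketch):

    h^{(t)}_{j,i} \;=\; e^{\,\Delta^{(t)}_i A_{j,i}}\, h^{(t-1)}_{j,i} \;+\; B^{(t)}_j\, \Delta^{(t)}_i\, x^{(t)}_i
    y^{(t)}_i \;=\; \textstyle\sum_j C^{(t)}_j\, h^{(t)}_{j,i}
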
@@ -11028,11 +11021,6 @@ static void ggml_compute_forward_concat_f32(

     GGML_TENSOR_BINARY_OP_LOCALS

-    // TODO: support for transposed / permuted tensors
-    GGML_ASSERT(nb0  == sizeof(float));
-    GGML_ASSERT(nb00 == sizeof(float));
-    GGML_ASSERT(nb10 == sizeof(float));
-
     const int32_t dim = ggml_get_op_params_i32(dst, 0);

     GGML_ASSERT(dim >= 0 && dim < 4);
@@ -15819,27 +15807,22 @@ static void ggml_compute_forward_flash_attn_back(
 static void ggml_compute_forward_ssm_conv_f32(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
-    const struct ggml_tensor * src0 = dst->src[0]; // conv_state
-    const struct ggml_tensor * src1 = dst->src[1]; // x
-    const struct ggml_tensor * src2 = dst->src[2]; // conv1d.weight
-    const struct ggml_tensor * src3 = dst->src[3]; // state_seq
+    const struct ggml_tensor * src0 = dst->src[0]; // conv_x
+    const struct ggml_tensor * src1 = dst->src[1]; // conv1d.weight

     const int ith = params->ith;
     const int nth = params->nth;

-    const int nc   = src2->ne[0]; // d_conv
-    const int nr   = src0->ne[1]; // d_inner
-    const int n_t  = src1->ne[1]; // n_tokens
-    const int n_kv = src0->ne[2]; // max number of sequences in the batch
+    const int nc  = src1->ne[0]; // d_conv
+    const int ncs = src0->ne[0]; // d_conv - 1 + n_t
+    const int nr  = src0->ne[1]; // d_inner
+    const int n_t =  dst->ne[1]; // tokens per sequence
+    const int n_s =  dst->ne[2]; // number of sequences in the batch

-    GGML_ASSERT((nr*n_t) + (nc*nr*n_kv) == ggml_nelements(dst));
+    GGML_ASSERT( dst->ne[0] == nr);
     GGML_ASSERT(src0->nb[0] == sizeof(float));
     GGML_ASSERT(src1->nb[0] == sizeof(float));
-    GGML_ASSERT(src2->nb[0] == sizeof(float));
-    GGML_ASSERT(src3->nb[0] == sizeof(int32_t));
     GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));
-    // for use with the destination state offset between sequences
-    GGML_ASSERT(src2->nb[2] == src2->ne[1]*src2->ne[0]*sizeof(float));

     // rows per thread
     const int dr = (nr + nth - 1)/nth;
@@ -15849,76 +15832,29 @@ static void ggml_compute_forward_ssm_conv_f32(
     const int ir1 = MIN(ir0 + dr, nr);
     const int ir  = ir1 - ir0;

-    if (n_kv > 1) {
-        // multiple sequences means it's hard to know when it's the first time a state is read,
-        // so copy them all over to the destination, just to be sure.
-        for (int i3 = 0; i3 < n_kv; ++i3) {
-            float * s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]));
-            float * s  = (float *) ((char *)  dst->data + ir0*(src2->nb[1]) + i3*(src2->nb[2]) + nr*n_t*sizeof(float));
-            // can't use memcpy because of d_conv vs d_conv - 1
-            for (int i1 = 0; i1 < ir; ++i1) {
-                for (int i0 = 0; i0 < nc - 1; ++i0) {
-                    // copy s0 to last (d_conv - 1) columns of s
-                    s[1 + i0 + i1*nc] = s0[i0 + i1*(nc - 1)];
-                }
-            }
-        }
-    }
-
-    for (int i2 = 0; i2 < n_t; ++i2) {
-        int32_t * sq = (int32_t *) ((char *) src3->data +  i2*(src3->nb[1])); // {n_kv, n_tokens}
-        float *   x = (float *) ((char *)  dst->data + ir0*sizeof(float) + i2*(nr*sizeof(float))); // {d_inner, n_tokens}
-        float *   s = (float *) ((char *)  dst->data + ir0*(src2->nb[1]) + sq[0]*(src2->nb[2]) + nr*n_t*sizeof(float)); // {d_conv, d_inner, n_kv}
-        float *  s0; // {d_conv - 1, d_inner, n_kv}
-        float *  x0 = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens}
-        float *   c = (float *) ((char *) src2->data + ir0*(src2->nb[1])); // {d_conv, d_inner}
-        int ne0s0;
-
-        GGML_ASSERT(0 <= sq[0] && sq[0] < n_kv);
-
-        // avoid needing to copy the state for the first token
-        if (i2 == 0) {
-            s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + sq[0]*(src0->nb[2])); // {d_conv - 1, d_inner, n_kv}
-            ne0s0 = src0->ne[0];
-        } else {
-            // the source is the last (d_conv - 1) columns of the destination
-            s0 = s + 1;
-            ne0s0 = nc;
-        }
-
-        // d_inner
-        for (int i1 = 0; i1 < ir; ++i1) {
-            // shift state left
-            for (int i0 = 0; i0 < nc - 1; ++i0) {
-                s[i0 + i1*nc] = s0[i0 + i1*ne0s0];
-            }
-            // insert x on the last column
-            s[(nc - 1) + i1*nc] = x0[i1];
-        }
-
-        // handle copies when there are multiple output states
-        for (int i3 = 1; i3 < n_kv; ++i3) {
-            int32_t seq = sq[i3];
-            if (0 <= seq && seq < n_kv) {
-                float * s1 = s + (seq - sq[0])*nc*nr;
-                memcpy(s1, s, nc*ir*sizeof(float));
-            } else {
-                // stop at negative or too big seq_ids
-                break;
-            }
-        }
-
-        // it seems a little faster when this is separate from the state shift
-        for (int i1 = 0; i1 < ir; ++i1) {
-            // rowwise dot product
-            float sumf = 0.0f;
-            for (int i0 = 0; i0 < nc; ++i0) {
-                int i = i0 + i1*nc;
-                sumf += s[i] * c[i];
-            }
-            x[i1] = sumf;
-        }
-    }
+    for (int i3 = 0; i3 < n_s; ++i3) {
+        for (int i2 = 0; i2 < n_t; ++i2) {
+            // {d_conv - 1 + n_t, d_inner, n_seqs}
+            // sliding window
+            const float * s = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i2*(src0->nb[0]) + i3*(src0->nb[2])); // {d_conv, d_inner, n_s}
+            const float * c = (const float *) ((const char *) src1->data + ir0*(src1->nb[1])); // {d_conv, d_inner}
+            float *       x = (float *)       ((char *)       dst->data + ir0*(dst->nb[0]) + i2*(dst->nb[1]) + i3*(dst->nb[2])); // {d_inner, n_t, n_s}
+
+            // TODO: transpose the output for smaller strides for big batches?
+            // d_inner
+            for (int i1 = 0; i1 < ir; ++i1) {
+                // rowwise dot product
+                // NOTE: not using ggml_vec_dot_f32, because its sum is in double precision
+                float sumf = 0.0f;
+
+                // d_conv
+                for (int i0 = 0; i0 < nc; ++i0) {
+                    sumf += s[i0 + i1*ncs] * c[i0 + i1*nc];
+                }
+                x[i1] = sumf;
+            }
+        }
+    }
 }
 static void ggml_compute_forward_ssm_conv(
@@ -15947,15 +15883,14 @@ static void ggml_compute_forward_ssm_scan_f32(
     const struct ggml_tensor * src3 = dst->src[3]; // A
     const struct ggml_tensor * src4 = dst->src[4]; // B
     const struct ggml_tensor * src5 = dst->src[5]; // C
-    const struct ggml_tensor * src6 = dst->src[6]; // sq

     const int ith = params->ith;
     const int nth = params->nth;

     const int64_t nc   = src0->ne[0]; // d_state
     const int64_t nr   = src0->ne[1]; // d_inner
-    const int64_t n_t  = src1->ne[1]; // number of tokens in the batch
-    const int64_t n_kv = src0->ne[2]; // max number of sequences in the batch
+    const int64_t n_t  = src1->ne[1]; // number of tokens per sequence
+    const int64_t n_s  = src0->ne[2]; // number of sequences in the batch

     GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst));
     GGML_ASSERT(src0->nb[0] == sizeof(float));
@@ -15964,12 +15899,12 @@ static void ggml_compute_forward_ssm_scan_f32(
     GGML_ASSERT(src3->nb[0] == sizeof(float));
     GGML_ASSERT(src4->nb[0] == sizeof(float));
     GGML_ASSERT(src5->nb[0] == sizeof(float));
-    // required for the dot product between s and C, and when copying the states
+    // required for the dot product between s and C
     GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));
     // required for per-sequence offsets for states
     GGML_ASSERT(src0->nb[2] == src0->ne[0]*src0->ne[1]*sizeof(float));
-    // required to get correct offset for state destination (i.e. src1->nb[2])
-    GGML_ASSERT(src1->nb[2] == src1->ne[0]*src1->ne[1]*sizeof(float));
+    // required to get correct offset for state destination (i.e. src1->nb[3])
+    GGML_ASSERT(src1->nb[3] == src1->ne[0]*src1->ne[1]*src1->ne[2]*sizeof(float));

     // rows per thread
     const int dr = (nr + nth - 1)/nth;
@@ -15979,64 +15914,36 @@ static void ggml_compute_forward_ssm_scan_f32(
     const int ir1 = MIN(ir0 + dr, nr);
     const int ir  = ir1 - ir0;

-    if (n_kv > 1) {
-        // it's hard to know if the source states have already been copied
-        // when there are multiple, so copy them already.
-        for (int i3 = 0; i3 < n_kv; ++i3) {
-            float * s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]));
-            float * s  = (float *) ((char *)  dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[2]);
-            memcpy(s, s0, nc*ir*sizeof(float));
-        }
-    }
-
-    for (int i2 = 0; i2 < n_t; ++i2) {
-        int32_t * sq = (int32_t *) ((char *) src6->data +  i2*(src6->nb[1])); // {n_kv, n_tokens}
-        float *   y  = (float *) ((char *)  dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens}
-        float *   s  = (float *) ((char *)  dst->data + ir0*(src0->nb[1]) + sq[0]*(src0->nb[2]) + src1->nb[2]); // {d_state, d_inner, n_kv}
-        float *   s0;
-        float *   x  = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens}
-        float *   dt = (float *) ((char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1])); // {d_inner, n_tokens}
-        float *   A  = (float *) ((char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
-        float *   B  = (float *) ((char *) src4->data +  i2*(src4->nb[1])); // {d_state, n_tokens}
-        float *   C  = (float *) ((char *) src5->data +  i2*(src5->nb[1])); // {d_state, n_tokens}
-
-        GGML_ASSERT(0 <= sq[0] && sq[0] < n_kv);
-
-        // avoid needing to copy the state for the first token
-        if (i2 == 0) {
-            s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + sq[0]*(src0->nb[2])); // {d_state, d_inner, n_kv}
-        } else {
-            // otherwise the source is the same as the destination
-            s0 = s;
-        }
-
-        // d_inner
-        for (int i1 = 0; i1 < ir; ++i1) {
-            // ref: https://github.com/state-spaces/mamba/blob/34076d664838588a3c97727b263478ab9f621a07/mamba_ssm/ops/triton/selective_state_update.py#L78
-            float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1];
-            float x_dt = x[i1] * dt_soft_plus;
-            float sumf = 0.0f;
-            // d_state
-            for (int i0 = 0; i0 < nc; ++i0) {
-                int i = i0 + i1*nc;
-                // state = prev_state * dA + dB * x
-                float state = (s0[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
-                // y = rowwise_dotprod(state, C)
-                sumf += state * C[i0];
-                s[i] = state;
-            }
-            y[i1] = sumf;
-        }
-
-        // handle copies when there are multiple output states
-        for (int i3 = 1; i3 < n_kv; ++i3) {
-            int32_t seq = sq[i3];
-            if (0 <= seq && seq < n_kv) {
-                float * s1 = s + (seq - sq[0])*nc*nr;
-                memcpy(s1, s, nc*ir*sizeof(float));
-            } else {
-                // stop at negative or too big seq_ids
-                break;
-            }
-        }
-    }
+    for (int i3 = 0; i3 < n_s; ++i3) {
+        for (int i2 = 0; i2 < n_t; ++i2) {
+            const float * s0 = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s}
+            const float * x  = (const float *) ((const char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
+            const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s}
+            const float * A  = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
+            const float * B  = (const float *) ((const char *) src4->data +  i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s}
+            const float * C  = (const float *) ((const char *) src5->data +  i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s}
+            float *       y  = (float *)       ((char *)       dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
+            float *       s  = (float *)       ((char *)       dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]); // {d_state, d_inner, n_s}
+
+            // use the output as the source for the next token-wise iterations
+            if (i2 > 0) { s0 = s; }
+
+            // d_inner
+            for (int i1 = 0; i1 < ir; ++i1) {
+                // ref: https://github.com/state-spaces/mamba/blob/34076d664838588a3c97727b263478ab9f621a07/mamba_ssm/ops/triton/selective_state_update.py#L78
+                float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1];
+                float x_dt = x[i1] * dt_soft_plus;
+                float sumf = 0.0f;
+                // d_state
+                for (int i0 = 0; i0 < nc; ++i0) {
+                    int i = i0 + i1*nc;
+                    // state = prev_state * dA + dB * x
+                    float state = (s0[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
+                    // y = rowwise_dotprod(state, C)
+                    sumf += state * C[i0];
+                    s[i] = state;
+                }
+                y[i1] = sumf;
+            }
+        }
+    }
 }
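
The dt_soft_plus line above is a numerically-safe softplus: for large inputs expf overflows float, but softplus(x) is already within float epsilon of x there, so the kernel passes the input through unchanged past a threshold. A standalone sketch of the same guard (the 20.0f cutoff mirrors the kernel; the helper name is ours):

    #include <cmath>
    #include <cstdio>

    // softplus(x) = log(1 + e^x); above the cutoff e^x would overflow float,
    // and softplus(x) ~= x anyway, so return x directly.
    static float soft_plus(float x) {
        return x <= 20.0f ? std::log1p(std::exp(x)) : x;
    }

    int main() {
        std::printf("%f\n", soft_plus(0.0f));  // 0.693147 (= ln 2)
        std::printf("%f\n", soft_plus(50.0f)); // 50.000000 (pass-through)
    }
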
@@ -0,0 +1,24 @@
#version 450
#include "types.comp"
#include "generic_binary_head.comp"
void main() {
const uint idx = gl_GlobalInvocationID.x;
if (idx >= p.ne) {
return;
}
const uint offset = p.param3;
const uint src1_i = idx - offset;
const uint oz = src1_i / p.nb02;
const uint oy = (src1_i - (oz * p.nb02)) / p.nb01;
const uint ox = src1_i % p.nb01;
if (ox < p.ne10 && oy < p.ne11 && oz < p.ne12) {
data_d[p.d_offset + dst_idx(idx)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(idx)]) + FLOAT_TYPE(data_b[ox + oy * p.ne10 + oz * p.ne10 * p.ne11]));
} else {
data_d[p.d_offset + dst_idx(idx)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(idx)]));
}
}
@@ -30,6 +30,10 @@ void main() {
 #ifndef OPTIMIZATION_ERROR_WORKAROUND
     data_d[p.d_offset + dst_idx] = D_TYPE(is_src0 ? data_a[src0_idx] : data_b[src1_idx]);
 #else
-    data_d[p.d_offset + dst_idx] = is_src0 ? data_a[src0_idx] : data_b[src1_idx];
+    if (is_src0) {
+        data_d[p.d_offset + dst_idx] = data_a[src0_idx];
+    } else {
+        data_d[p.d_offset + dst_idx] = data_b[src1_idx];
+    }
 #endif
 }
@@ -39,8 +39,7 @@ void main() {
         vec2 v = dequantize(ib, iqs, a_offset / QUANT_K);

         // matrix multiplication
-        tmp[tid] += FLOAT_TYPE(v.x) * FLOAT_TYPE(data_b[b_offset + iybs + iqs]) +
-                    FLOAT_TYPE(v.y) * FLOAT_TYPE(data_b[b_offset + iybs + iqs + y_offset]);
+        tmp[tid] = fma(FLOAT_TYPE(v.x), FLOAT_TYPE(data_b[b_offset + iybs + iqs]), fma(FLOAT_TYPE(v.y), FLOAT_TYPE(data_b[b_offset + iybs + iqs + y_offset]), tmp[tid]));
     }

     // sum up partial sums and write back result
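
These Vulkan shader changes rewrite the accumulations as nested fma() calls. Besides mapping directly to fused multiply-add hardware, fma rounds once instead of twice, which can change (and usually improve) the low bits of the result. A tiny host-side C++ illustration of the single-rounding difference (values chosen by us so the effect is visible; this is not from the diff, and GLSL's fma gives the same contract only where it maps to a real hardware FMA):

    #include <cmath>
    #include <cstdio>

    int main() {
        const float a = 1.0f + 0x1.0p-12f;    // exactly representable
        const float c = -(1.0f + 0x1.0p-11f); // -(1 + 2^-11)
        // a*a = 1 + 2^-11 + 2^-24; the plain product rounds the 2^-24 away,
        // std::fma keeps it because it rounds only once, after the add
        std::printf("%g\n", a * a + c);         // 0
        std::printf("%g\n", std::fma(a, a, c)); // 5.96046e-08 (= 2^-24)
    }
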
@@ -53,7 +53,7 @@ void main() {

         const FLOAT_TYPE xi = FLOAT_TYPE(data_a[ix]);

-        tmp[tid] += xi * FLOAT_TYPE(data_b[iy]);
+        tmp[tid] = fma(xi, FLOAT_TYPE(data_b[iy]), tmp[tid]);
     }

     // sum up partial sums and write back result
@@ -52,7 +52,7 @@ void main() {
         // y is not transposed but permuted
         const uint iy = channel*nrows_y + row_y;

-        tmp[tid] += xi * FLOAT_TYPE(data_b[iy]);
+        tmp[tid] = fma(xi, FLOAT_TYPE(data_b[iy]), tmp[tid]);
     }

     // dst is not transposed and not permuted
@@ -39,24 +39,25 @@ void main() {
         FLOAT_TYPE sum1 = FLOAT_TYPE(0.0);
         FLOAT_TYPE sum2 = FLOAT_TYPE(0.0);
         for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
-            sum1 += FLOAT_TYPE(data_b[b_offset + y_idx + l +  0]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 0] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l + 0] >> 0) & 3)
-                  + FLOAT_TYPE(data_b[b_offset + y_idx + l + 16]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 1] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l +16] >> 0) & 3)
-                  + FLOAT_TYPE(data_b[b_offset + y_idx + l + 32]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 2] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l + 0] >> 2) & 3)
-                  + FLOAT_TYPE(data_b[b_offset + y_idx + l + 48]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 3] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l +16] >> 2) & 3)
-                  + FLOAT_TYPE(data_b[b_offset + y_idx + l + 64]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 4] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l + 0] >> 4) & 3)
-                  + FLOAT_TYPE(data_b[b_offset + y_idx + l + 80]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 5] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l +16] >> 4) & 3)
-                  + FLOAT_TYPE(data_b[b_offset + y_idx + l + 96]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 6] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l + 0] >> 6) & 3)
-                  + FLOAT_TYPE(data_b[b_offset + y_idx + l +112]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 7] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l +16] >> 6) & 3);
-            sum2 += FLOAT_TYPE(data_b[b_offset + y_idx + l +  0]) * FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 0] >> 4) & 0xF)
-                  + FLOAT_TYPE(data_b[b_offset + y_idx + l + 16]) * FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 1] >> 4) & 0xF)
-                  + FLOAT_TYPE(data_b[b_offset + y_idx + l + 32]) * FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 2] >> 4) & 0xF)
-                  + FLOAT_TYPE(data_b[b_offset + y_idx + l + 48]) * FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 3] >> 4) & 0xF)
-                  + FLOAT_TYPE(data_b[b_offset + y_idx + l + 64]) * FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 4] >> 4) & 0xF)
-                  + FLOAT_TYPE(data_b[b_offset + y_idx + l + 80]) * FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 5] >> 4) & 0xF)
-                  + FLOAT_TYPE(data_b[b_offset + y_idx + l + 96]) * FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 6] >> 4) & 0xF)
-                  + FLOAT_TYPE(data_b[b_offset + y_idx + l +112]) * FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 7] >> 4) & 0xF);
+            sum1 = fma(FLOAT_TYPE(data_b[b_offset + y_idx + l +  0]), FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 0] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l + 0] >> 0) & 3),
+                   fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 16]), FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 1] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l +16] >> 0) & 3),
+                   fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 32]), FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 2] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l + 0] >> 2) & 3),
+                   fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 48]), FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 3] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l +16] >> 2) & 3),
+                   fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 64]), FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 4] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l + 0] >> 4) & 3),
+                   fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 80]), FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 5] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l +16] >> 4) & 3),
+                   fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 96]), FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 6] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l + 0] >> 6) & 3),
+                   fma(FLOAT_TYPE(data_b[b_offset + y_idx + l +112]), FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 7] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l +16] >> 6) & 3), sum1))))))));
+            sum2 = fma(FLOAT_TYPE(data_b[b_offset + y_idx + l +  0]), FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 0] >> 4) & 0xF),
+                   fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 16]), FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 1] >> 4) & 0xF),
+                   fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 32]), FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 2] >> 4) & 0xF),
+                   fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 48]), FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 3] >> 4) & 0xF),
+                   fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 64]), FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 4] >> 4) & 0xF),
+                   fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 80]), FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 5] >> 4) & 0xF),
+                   fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 96]), FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 6] >> 4) & 0xF),
+                   fma(FLOAT_TYPE(data_b[b_offset + y_idx + l +112]), FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 7] >> 4) & 0xF), sum2))))))));
         }
-        tmp[16 * ix + tid] += dall * sum1 - dmin * sum2;
+        const uint tmp_idx = 16 * ix + tid;
+        tmp[tmp_idx] = fma(dall, sum1, fma(-dmin, sum2, tmp[tmp_idx]));
     }

     // sum up partial sums and write back result
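
Note: the accumulator update "tmp += dall * sum1 - dmin * sum2" folds into nested FMAs by negating dmin: tmp = fma(dall, sum1, fma(-dmin, sum2, tmp)). A small C++ check of the identity (values are illustrative and chosen to be exact in float):

    #include <cassert>
    #include <cmath>

    int main() {
        float dall = 0.5f, dmin = 0.25f, sum1 = 8.0f, sum2 = 4.0f, tmp = 1.0f;
        float plain = tmp + dall * sum1 - dmin * sum2;                  // original form
        float fused = std::fma(dall, sum1, std::fma(-dmin, sum2, tmp)); // refactored form
        assert(plain == fused); // exact for these values; in general the forms differ only in rounding
        return 0;
    }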


@@ -40,16 +40,17 @@ void main() {
         FLOAT_TYPE sum = FLOAT_TYPE(0.0);
         for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
-            sum += FLOAT_TYPE(data_b[b_offset + y_idx + l +  0]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[0] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[ 8] >> (s_shift + 0) & 0x3) << 4)) - 32) * FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l   ]     ) & 3) - (((data_a[ib0 + i].hmask[l0 + l   ] & (m << 0)) != 0) ? 0 : 4))
-                 + FLOAT_TYPE(data_b[b_offset + y_idx + l + 32]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[2] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[10] >> (s_shift + 0) & 0x3) << 4)) - 32) * FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l   ] >> 2) & 3) - (((data_a[ib0 + i].hmask[l0 + l   ] & (m << 1)) != 0) ? 0 : 4))
-                 + FLOAT_TYPE(data_b[b_offset + y_idx + l + 64]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[4] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[ 8] >> (s_shift + 2) & 0x3) << 4)) - 32) * FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l   ] >> 4) & 3) - (((data_a[ib0 + i].hmask[l0 + l   ] & (m << 2)) != 0) ? 0 : 4))
-                 + FLOAT_TYPE(data_b[b_offset + y_idx + l + 96]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[6] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[10] >> (s_shift + 2) & 0x3) << 4)) - 32) * FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l   ] >> 6) & 3) - (((data_a[ib0 + i].hmask[l0 + l   ] & (m << 3)) != 0) ? 0 : 4))
-                 + FLOAT_TYPE(data_b[b_offset + y_idx + l + 16]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[1] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[ 9] >> (s_shift + 0) & 0x3) << 4)) - 32) * FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16]     ) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 0)) != 0) ? 0 : 4))
-                 + FLOAT_TYPE(data_b[b_offset + y_idx + l + 48]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[3] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[11] >> (s_shift + 0) & 0x3) << 4)) - 32) * FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 2) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 1)) != 0) ? 0 : 4))
-                 + FLOAT_TYPE(data_b[b_offset + y_idx + l + 80]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[5] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[ 9] >> (s_shift + 2) & 0x3) << 4)) - 32) * FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 4) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 2)) != 0) ? 0 : 4))
-                 + FLOAT_TYPE(data_b[b_offset + y_idx + l +112]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[7] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[11] >> (s_shift + 2) & 0x3) << 4)) - 32) * FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 6) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 3)) != 0) ? 0 : 4));
+            sum = fma(FLOAT_TYPE(data_b[b_offset + y_idx + l +  0]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[0] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[ 8] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l   ]     ) & 3) - (((data_a[ib0 + i].hmask[l0 + l   ] & (m << 0)) != 0) ? 0 : 4)),
+                  fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 32]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[2] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[10] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l   ] >> 2) & 3) - (((data_a[ib0 + i].hmask[l0 + l   ] & (m << 1)) != 0) ? 0 : 4)),
+                  fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 64]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[4] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[ 8] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l   ] >> 4) & 3) - (((data_a[ib0 + i].hmask[l0 + l   ] & (m << 2)) != 0) ? 0 : 4)),
+                  fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 96]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[6] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[10] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l   ] >> 6) & 3) - (((data_a[ib0 + i].hmask[l0 + l   ] & (m << 3)) != 0) ? 0 : 4)),
+                  fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 16]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[1] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[ 9] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16]     ) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 0)) != 0) ? 0 : 4)),
+                  fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 48]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[3] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[11] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 2) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 1)) != 0) ? 0 : 4)),
+                  fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 80]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[5] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[ 9] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 4) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 2)) != 0) ? 0 : 4)),
+                  fma(FLOAT_TYPE(data_b[b_offset + y_idx + l +112]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[7] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[11] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 6) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 3)) != 0) ? 0 : 4)), sum))))))));
         }
-        tmp[16 * ix + tid] += d * sum;
+        const uint tmp_idx = 16 * ix + tid;
+        tmp[tmp_idx] = fma(d, sum, tmp[tmp_idx]);
     }

     // sum up partial sums and write back result


@@ -67,17 +67,17 @@ void main() {
         const uint8_t q4_14 = uint8_t(data_a[ib0 + i].qs[q_offset + 66] >> 4);
         const uint8_t q4_15 = uint8_t(data_a[ib0 + i].qs[q_offset + 67] >> 4);

-        const FLOAT_TYPE sx = FLOAT_TYPE(FLOAT_TYPE(data_b[b_offset + y1_idx]) * q4_0 + FLOAT_TYPE(data_b[b_offset + y1_idx + 1]) * q4_1 + FLOAT_TYPE(data_b[b_offset + y1_idx + 2]) * q4_2 + FLOAT_TYPE(data_b[b_offset + y1_idx + 3]) * q4_3);
-        const FLOAT_TYPE sy = FLOAT_TYPE(FLOAT_TYPE(data_b[b_offset + y1_idx + 32]) * q4_4 + FLOAT_TYPE(data_b[b_offset + y1_idx + 33]) * q4_5 + FLOAT_TYPE(data_b[b_offset + y1_idx + 34]) * q4_6 + FLOAT_TYPE(data_b[b_offset + y1_idx + 35]) * q4_7);
-        const FLOAT_TYPE sz = FLOAT_TYPE(FLOAT_TYPE(data_b[b_offset + y2_idx]) * q4_8 + FLOAT_TYPE(data_b[b_offset + y2_idx + 1]) * q4_9 + FLOAT_TYPE(data_b[b_offset + y2_idx + 2]) * q4_10 + FLOAT_TYPE(data_b[b_offset + y2_idx + 3]) * q4_11);
-        const FLOAT_TYPE sw = FLOAT_TYPE(FLOAT_TYPE(data_b[b_offset + y2_idx + 32]) * q4_12 + FLOAT_TYPE(data_b[b_offset + y2_idx + 33]) * q4_13 + FLOAT_TYPE(data_b[b_offset + y2_idx + 34]) * q4_14 + FLOAT_TYPE(data_b[b_offset + y2_idx + 35]) * q4_15);
-        const FLOAT_TYPE smin = FLOAT_TYPE(
-            FLOAT_TYPE(data_b[b_offset + y1_idx    ]) * sc2 + FLOAT_TYPE(data_b[b_offset + y1_idx + 32]) * sc3 + FLOAT_TYPE(data_b[b_offset + y2_idx    ]) * sc6 + FLOAT_TYPE(data_b[b_offset + y2_idx + 32]) * sc7
-          + FLOAT_TYPE(data_b[b_offset + y1_idx + 1]) * sc2 + FLOAT_TYPE(data_b[b_offset + y1_idx + 33]) * sc3 + FLOAT_TYPE(data_b[b_offset + y2_idx + 1]) * sc6 + FLOAT_TYPE(data_b[b_offset + y2_idx + 33]) * sc7
-          + FLOAT_TYPE(data_b[b_offset + y1_idx + 2]) * sc2 + FLOAT_TYPE(data_b[b_offset + y1_idx + 34]) * sc3 + FLOAT_TYPE(data_b[b_offset + y2_idx + 2]) * sc6 + FLOAT_TYPE(data_b[b_offset + y2_idx + 34]) * sc7
-          + FLOAT_TYPE(data_b[b_offset + y1_idx + 3]) * sc2 + FLOAT_TYPE(data_b[b_offset + y1_idx + 35]) * sc3 + FLOAT_TYPE(data_b[b_offset + y2_idx + 3]) * sc6 + FLOAT_TYPE(data_b[b_offset + y2_idx + 35]) * sc7
-        );
-        tmp[16 * ix + tid] += FLOAT_TYPE(dall * (sx * sc0 + sy * sc1 + sz * sc4 + sw * sc5) - dmin * smin);
+        const FLOAT_TYPE sx = fma(FLOAT_TYPE(data_b[b_offset + y1_idx]), q4_0, fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 1]), q4_1, fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 2]), q4_2, FLOAT_TYPE(data_b[b_offset + y1_idx + 3]) * q4_3)));
+        const FLOAT_TYPE sy = fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 32]), q4_4, fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 33]), q4_5, fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 34]), q4_6, FLOAT_TYPE(data_b[b_offset + y1_idx + 35]) * q4_7)));
+        const FLOAT_TYPE sz = fma(FLOAT_TYPE(data_b[b_offset + y2_idx]), q4_8, fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 1]), q4_9, fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 2]), q4_10, FLOAT_TYPE(data_b[b_offset + y2_idx + 3]) * q4_11)));
+        const FLOAT_TYPE sw = fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 32]), q4_12, fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 33]), q4_13, fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 34]), q4_14, FLOAT_TYPE(data_b[b_offset + y2_idx + 35]) * q4_15)));
+        const FLOAT_TYPE smin =
+            fma(FLOAT_TYPE(data_b[b_offset + y1_idx    ]), sc2, fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 32]), sc3, fma(FLOAT_TYPE(data_b[b_offset + y2_idx    ]), sc6, fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 32]), sc7,
+            fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 1]), sc2, fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 33]), sc3, fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 1]), sc6, fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 33]), sc7,
+            fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 2]), sc2, fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 34]), sc3, fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 2]), sc6, fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 34]), sc7,
+            fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 3]), sc2, fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 35]), sc3, fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 3]), sc6, FLOAT_TYPE(data_b[b_offset + y2_idx + 35]) * sc7)))))))))))))));
+        const uint tmp_idx = 16 * ix + tid;
+        tmp[tmp_idx] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, tmp[tmp_idx]));
 #else
         const uint8_t q4_0 = uint8_t(data_a[ib0 + i].qs[q_offset    ] & 0xf);
         const uint8_t q4_1 = uint8_t(data_a[ib0 + i].qs[q_offset + 1] & 0xf);
@@ -88,16 +88,19 @@ void main() {
         const uint8_t q4_6 = uint8_t(data_a[ib0 + i].qs[q_offset + 64] >> 4);
         const uint8_t q4_7 = uint8_t(data_a[ib0 + i].qs[q_offset + 65] >> 4);

-        const FLOAT_TYPE sx = FLOAT_TYPE(FLOAT_TYPE(data_b[b_offset + y1_idx     ]) * q4_0 + FLOAT_TYPE(data_b[b_offset + y1_idx + 1]) * q4_1);
-        const FLOAT_TYPE sy = FLOAT_TYPE(FLOAT_TYPE(data_b[b_offset + y1_idx + 32]) * q4_2 + FLOAT_TYPE(data_b[b_offset + y1_idx + 33]) * q4_3);
-        const FLOAT_TYPE sz = FLOAT_TYPE(FLOAT_TYPE(data_b[b_offset + y2_idx     ]) * q4_4 + FLOAT_TYPE(data_b[b_offset + y2_idx + 1]) * q4_5);
-        const FLOAT_TYPE sw = FLOAT_TYPE(FLOAT_TYPE(data_b[b_offset + y2_idx + 32]) * q4_6 + FLOAT_TYPE(data_b[b_offset + y2_idx + 33]) * q4_7);
-        const FLOAT_TYPE smin = FLOAT_TYPE(
-            FLOAT_TYPE(data_b[b_offset + y1_idx]) * sc2 + FLOAT_TYPE(data_b[b_offset + y1_idx + 32]) * sc3 + FLOAT_TYPE(data_b[b_offset + y2_idx]) * sc6 + FLOAT_TYPE(data_b[b_offset + y2_idx + 32]) * sc7
-          + FLOAT_TYPE(data_b[b_offset + y1_idx + 1]) * sc2 + FLOAT_TYPE(data_b[b_offset + y1_idx + 33]) * sc3 + FLOAT_TYPE(data_b[b_offset + y2_idx + 1]) * sc6 + FLOAT_TYPE(data_b[b_offset + y2_idx + 33]) * sc7
-        );
-        tmp[16 * ix + tid] += FLOAT_TYPE(dall * (sx * FLOAT_TYPE(data_a[ib0 + i].scales[v_im] & 0x3f) + sy * FLOAT_TYPE(data_a[ib0 + i].scales[v_im + 1] & 0x3f) + sz * FLOAT_TYPE((data_a[ib0 + i].scales[v_im + 4] & 0x0f) | ((data_a[ib0 + i].scales[v_im] & 0xc0) >> 2)) + sw * FLOAT_TYPE((data_a[ib0 + i].scales[v_im + 5] & 0x0f) | ((data_a[ib0 + i].scales[v_im + 1] & 0xc0) >> 2))) - dmin * smin);
+        const FLOAT_TYPE sx = fma(FLOAT_TYPE(data_b[b_offset + y1_idx     ]), q4_0, FLOAT_TYPE(data_b[b_offset + y1_idx + 1]) * q4_1);
+        const FLOAT_TYPE sy = fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 32]), q4_2, FLOAT_TYPE(data_b[b_offset + y1_idx + 33]) * q4_3);
+        const FLOAT_TYPE sz = fma(FLOAT_TYPE(data_b[b_offset + y2_idx     ]), q4_4, FLOAT_TYPE(data_b[b_offset + y2_idx + 1]) * q4_5);
+        const FLOAT_TYPE sw = fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 32]), q4_6, FLOAT_TYPE(data_b[b_offset + y2_idx + 33]) * q4_7);
+        const FLOAT_TYPE smin =
+            fma(FLOAT_TYPE(data_b[b_offset + y1_idx    ]), sc2, fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 32]), sc3, fma(FLOAT_TYPE(data_b[b_offset + y2_idx    ]), sc6, fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 32]), sc7,
+          + fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 1]), sc2, fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 33]), sc3, fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 1]), sc6, FLOAT_TYPE(data_b[b_offset + y2_idx + 33]) * sc7)))))));
+        tmp[16 * ix + tid] += FLOAT_TYPE(dall * (sx * FLOAT_TYPE(data_a[ib0 + i].scales[v_im] & 0x3f) + sy * FLOAT_TYPE(data_a[ib0 + i].scales[v_im + 1] & 0x3f) +
+                              sz * FLOAT_TYPE((data_a[ib0 + i].scales[v_im + 4] & 0x0f) | ((data_a[ib0 + i].scales[v_im] & 0xc0) >> 2)) + sw * FLOAT_TYPE((data_a[ib0 + i].scales[v_im + 5] & 0x0f) | ((data_a[ib0 + i].scales[v_im + 1] & 0xc0) >> 2))) - dmin * smin);
+        const uint tmp_idx = 16 * ix + tid;
+        tmp[tmp_idx] = fma(dall, (fma(sx, FLOAT_TYPE(data_a[ib0 + i].scales[v_im] & 0x3f), fma(sy, FLOAT_TYPE(data_a[ib0 + i].scales[v_im + 1] & 0x3f),
+                       fma(sz, FLOAT_TYPE((data_a[ib0 + i].scales[v_im + 4] & 0x0f) | ((data_a[ib0 + i].scales[v_im] & 0xc0) >> 2)), fma(sw, FLOAT_TYPE((data_a[ib0 + i].scales[v_im + 5] & 0x0f) | ((data_a[ib0 + i].scales[v_im + 1] & 0xc0) >> 2))))))), fma(-dmin, smin, tmp[tmp_idx]));
 #endif
     }


@@ -66,35 +66,33 @@ void main() {
         const uint8_t q4_14 = uint8_t(data_a[ib0 + i].qs[q_offset + 80] >> 4);
         const uint8_t q4_15 = uint8_t(data_a[ib0 + i].qs[q_offset + 81] >> 4);

-        const FLOAT_TYPE sx = FLOAT_TYPE(
-            FLOAT_TYPE(data_b[b_offset + y1_idx     ]) * (q4_0 + (((data_a[ib0 + i].qh[l0     ] & hm1) != 0) ? 16 : 0))
-          + FLOAT_TYPE(data_b[b_offset + y1_idx +  1]) * (q4_1 + (((data_a[ib0 + i].qh[l0 +  1] & hm1) != 0) ? 16 : 0))
-          + FLOAT_TYPE(data_b[b_offset + y1_idx + 16]) * (q4_2 + (((data_a[ib0 + i].qh[l0 + 16] & hm1) != 0) ? 16 : 0))
-          + FLOAT_TYPE(data_b[b_offset + y1_idx + 17]) * (q4_3 + (((data_a[ib0 + i].qh[l0 + 17] & hm1) != 0) ? 16 : 0))
-        );
-        const FLOAT_TYPE sy = FLOAT_TYPE(
-            FLOAT_TYPE(data_b[b_offset + y1_idx + 32]) * (q4_4 + (((data_a[ib0 + i].qh[l0     ] & (hm1 << 1)) != 0) ? 16 : 0))
-          + FLOAT_TYPE(data_b[b_offset + y1_idx + 33]) * (q4_5 + (((data_a[ib0 + i].qh[l0 +  1] & (hm1 << 1)) != 0) ? 16 : 0))
-          + FLOAT_TYPE(data_b[b_offset + y1_idx + 48]) * (q4_6 + (((data_a[ib0 + i].qh[l0 + 16] & (hm1 << 1)) != 0) ? 16 : 0))
-          + FLOAT_TYPE(data_b[b_offset + y1_idx + 49]) * (q4_7 + (((data_a[ib0 + i].qh[l0 + 17] & (hm1 << 1)) != 0) ? 16 : 0))
-        );
-        const FLOAT_TYPE sz = FLOAT_TYPE(
-            FLOAT_TYPE(data_b[b_offset + y2_idx     ]) * (q4_8  + (((data_a[ib0 + i].qh[l0     ] & hm2) != 0) ? 16 : 0))
-          + FLOAT_TYPE(data_b[b_offset + y2_idx +  1]) * (q4_9  + (((data_a[ib0 + i].qh[l0 +  1] & hm2) != 0) ? 16 : 0))
-          + FLOAT_TYPE(data_b[b_offset + y2_idx + 16]) * (q4_10 + (((data_a[ib0 + i].qh[l0 + 16] & hm2) != 0) ? 16 : 0))
-          + FLOAT_TYPE(data_b[b_offset + y2_idx + 17]) * (q4_11 + (((data_a[ib0 + i].qh[l0 + 17] & hm2) != 0) ? 16 : 0))
-        );
-        const FLOAT_TYPE sw = FLOAT_TYPE(
-            FLOAT_TYPE(data_b[b_offset + y2_idx + 32]) * (q4_12 + (((data_a[ib0 + i].qh[l0     ] & (hm2 << 1)) != 0) ? 16 : 0))
-          + FLOAT_TYPE(data_b[b_offset + y2_idx + 33]) * (q4_13 + (((data_a[ib0 + i].qh[l0 +  1] & (hm2 << 1)) != 0) ? 16 : 0))
-          + FLOAT_TYPE(data_b[b_offset + y2_idx + 48]) * (q4_14 + (((data_a[ib0 + i].qh[l0 + 16] & (hm2 << 1)) != 0) ? 16 : 0))
-          + FLOAT_TYPE(data_b[b_offset + y2_idx + 49]) * (q4_15 + (((data_a[ib0 + i].qh[l0 + 17] & (hm2 << 1)) != 0) ? 16 : 0))
-        );
-        const FLOAT_TYPE smin = FLOAT_TYPE(
-            (FLOAT_TYPE(data_b[b_offset + y1_idx]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 1]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 16]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 17])) * sc2 + (FLOAT_TYPE(data_b[b_offset + y1_idx + 32]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 33]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 48]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 49])) * sc3
-          + (FLOAT_TYPE(data_b[b_offset + y2_idx]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 1]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 16]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 17])) * sc6 + (FLOAT_TYPE(data_b[b_offset + y2_idx + 32]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 33]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 48]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 49])) * sc7
-        );
-        tmp[16 * ix + tid] += FLOAT_TYPE(dall * (sx * sc0 + sy * sc1 + sz * sc4 + sw * sc5) - dmin * smin);
+        const FLOAT_TYPE sx =
+            fma(FLOAT_TYPE(data_b[b_offset + y1_idx     ]), (q4_0 + (((data_a[ib0 + i].qh[l0     ] & hm1) != 0) ? 16 : 0)),
+            fma(FLOAT_TYPE(data_b[b_offset + y1_idx +  1]), (q4_1 + (((data_a[ib0 + i].qh[l0 +  1] & hm1) != 0) ? 16 : 0)),
+            fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 16]), (q4_2 + (((data_a[ib0 + i].qh[l0 + 16] & hm1) != 0) ? 16 : 0)),
+            FLOAT_TYPE(data_b[b_offset + y1_idx + 17]) * (q4_3 + (((data_a[ib0 + i].qh[l0 + 17] & hm1) != 0) ? 16 : 0)))));
+        const FLOAT_TYPE sy =
+            fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 32]), (q4_4 + (((data_a[ib0 + i].qh[l0     ] & (hm1 << 1)) != 0) ? 16 : 0)),
+            fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 33]), (q4_5 + (((data_a[ib0 + i].qh[l0 +  1] & (hm1 << 1)) != 0) ? 16 : 0)),
+            fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 48]), (q4_6 + (((data_a[ib0 + i].qh[l0 + 16] & (hm1 << 1)) != 0) ? 16 : 0)),
+            FLOAT_TYPE(data_b[b_offset + y1_idx + 49]) * (q4_7 + (((data_a[ib0 + i].qh[l0 + 17] & (hm1 << 1)) != 0) ? 16 : 0)))));
+        const FLOAT_TYPE sz =
+            fma(FLOAT_TYPE(data_b[b_offset + y2_idx     ]), (q4_8  + (((data_a[ib0 + i].qh[l0     ] & hm2) != 0) ? 16 : 0)),
+            fma(FLOAT_TYPE(data_b[b_offset + y2_idx +  1]), (q4_9  + (((data_a[ib0 + i].qh[l0 +  1] & hm2) != 0) ? 16 : 0)),
+            fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 16]), (q4_10 + (((data_a[ib0 + i].qh[l0 + 16] & hm2) != 0) ? 16 : 0)),
+            FLOAT_TYPE(data_b[b_offset + y2_idx + 17]) * (q4_11 + (((data_a[ib0 + i].qh[l0 + 17] & hm2) != 0) ? 16 : 0)))));
+        const FLOAT_TYPE sw =
+            fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 32]), (q4_12 + (((data_a[ib0 + i].qh[l0     ] & (hm2 << 1)) != 0) ? 16 : 0)),
+            fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 33]), (q4_13 + (((data_a[ib0 + i].qh[l0 +  1] & (hm2 << 1)) != 0) ? 16 : 0)),
+            fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 48]), (q4_14 + (((data_a[ib0 + i].qh[l0 + 16] & (hm2 << 1)) != 0) ? 16 : 0)),
+            FLOAT_TYPE(data_b[b_offset + y2_idx + 49]) * (q4_15 + (((data_a[ib0 + i].qh[l0 + 17] & (hm2 << 1)) != 0) ? 16 : 0)))));
+        const FLOAT_TYPE smin =
+            fma(FLOAT_TYPE(data_b[b_offset + y1_idx     ]) + FLOAT_TYPE(data_b[b_offset + y1_idx +  1]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 16]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 17]), sc2,
+            fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 32]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 33]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 48]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 49]), sc3,
+            fma(FLOAT_TYPE(data_b[b_offset + y2_idx     ]) + FLOAT_TYPE(data_b[b_offset + y2_idx +  1]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 16]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 17]), sc6,
+            (FLOAT_TYPE(data_b[b_offset + y2_idx + 32]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 33]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 48]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 49])) * sc7)));
+        const uint tmp_idx = 16 * ix + tid;
+        tmp[tmp_idx] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, tmp[tmp_idx]));
     }

     // sum up partial sums and write back result


@@ -44,22 +44,22 @@ void main() {
         const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d);

 #if K_QUANTS_PER_ITERATION == 1
-        FLOAT_TYPE sum = FLOAT_TYPE(data_b[b_offset + y_idx +  0]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 0]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset +  0] & 0xF) | ((data_a[ib0 + i].qh[qh_offset +  0] & 0x03) << 4)) - 32)
-                       + FLOAT_TYPE(data_b[b_offset + y_idx + 16]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 1]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 16] & 0xF) | ((data_a[ib0 + i].qh[qh_offset + 16] & 0x03) << 4)) - 32)
-                       + FLOAT_TYPE(data_b[b_offset + y_idx + 32]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 2]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 32] & 0xF) | ((data_a[ib0 + i].qh[qh_offset +  0] & 0x0c) << 2)) - 32)
-                       + FLOAT_TYPE(data_b[b_offset + y_idx + 48]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 3]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 48] & 0xF) | ((data_a[ib0 + i].qh[qh_offset + 16] & 0x0c) << 2)) - 32)
-                       + FLOAT_TYPE(data_b[b_offset + y_idx + 64]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 4]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset +  0] >>  4) | ((data_a[ib0 + i].qh[qh_offset +  0] & 0x30) >> 0)) - 32)
-                       + FLOAT_TYPE(data_b[b_offset + y_idx + 80]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 5]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 16] >>  4) | ((data_a[ib0 + i].qh[qh_offset + 16] & 0x30) >> 0)) - 32)
-                       + FLOAT_TYPE(data_b[b_offset + y_idx + 96]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 6]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 32] >>  4) | ((data_a[ib0 + i].qh[qh_offset +  0] & 0xc0) >> 2)) - 32)
-                       + FLOAT_TYPE(data_b[b_offset + y_idx +112]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 7]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 48] >>  4) | ((data_a[ib0 + i].qh[qh_offset + 16] & 0xc0) >> 2)) - 32);
-        tmp[16 * ix + tid] += sum;
+        const uint tmp_idx = 16 * ix + tid;
+        tmp[tmp_idx] = fma(FLOAT_TYPE(data_b[b_offset + y_idx +  0]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 0]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset +  0] & 0xF) | ((data_a[ib0 + i].qh[qh_offset +  0] & 0x03) << 4)) - 32),
+                       fma(FLOAT_TYPE(data_b[b_offset + y_idx + 16]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 1]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 16] & 0xF) | ((data_a[ib0 + i].qh[qh_offset + 16] & 0x03) << 4)) - 32),
+                       fma(FLOAT_TYPE(data_b[b_offset + y_idx + 32]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 2]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 32] & 0xF) | ((data_a[ib0 + i].qh[qh_offset +  0] & 0x0c) << 2)) - 32),
+                       fma(FLOAT_TYPE(data_b[b_offset + y_idx + 48]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 3]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 48] & 0xF) | ((data_a[ib0 + i].qh[qh_offset + 16] & 0x0c) << 2)) - 32),
+                       fma(FLOAT_TYPE(data_b[b_offset + y_idx + 64]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 4]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset +  0] >>  4) | ((data_a[ib0 + i].qh[qh_offset +  0] & 0x30) >> 0)) - 32),
+                       fma(FLOAT_TYPE(data_b[b_offset + y_idx + 80]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 5]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 16] >>  4) | ((data_a[ib0 + i].qh[qh_offset + 16] & 0x30) >> 0)) - 32),
+                       fma(FLOAT_TYPE(data_b[b_offset + y_idx + 96]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 6]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 32] >>  4) | ((data_a[ib0 + i].qh[qh_offset +  0] & 0xc0) >> 2)) - 32),
+                       fma(FLOAT_TYPE(data_b[b_offset + y_idx +112]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 7]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 48] >>  4) | ((data_a[ib0 + i].qh[qh_offset + 16] & 0xc0) >> 2)) - 32), tmp[tmp_idx]))))))));
 #else
         FLOAT_TYPE sum = FLOAT_TYPE(0.0);
         [[unroll]] for (int l = 0; l < 4; ++l) {
-            sum += FLOAT_TYPE(data_b[b_offset + y_idx + l+ 0]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 0]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + l+ 0] & 0xF) | (((data_a[ib0 + i].qh[qh_offset + l] >> 0) & 3) << 4)) - 32)
-                 + FLOAT_TYPE(data_b[b_offset + y_idx + l+32]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 2]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + l+32] & 0xF) | (((data_a[ib0 + i].qh[qh_offset + l] >> 2) & 3) << 4)) - 32)
-                 + FLOAT_TYPE(data_b[b_offset + y_idx + l+64]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 4]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + l+ 0] >> 4) | (((data_a[ib0 + i].qh[qh_offset + l] >> 4) & 3) << 4)) - 32)
-                 + FLOAT_TYPE(data_b[b_offset + y_idx + l+96]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 6]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + l+32] >> 4) | (((data_a[ib0 + i].qh[qh_offset + l] >> 6) & 3) << 4)) - 32);
+            sum = fma(FLOAT_TYPE(data_b[b_offset + y_idx + l+ 0]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 0]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + l+ 0] & 0xF) | (((data_a[ib0 + i].qh[qh_offset + l] >> 0) & 3) << 4)) - 32),
+                  fma(FLOAT_TYPE(data_b[b_offset + y_idx + l+32]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 2]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + l+32] & 0xF) | (((data_a[ib0 + i].qh[qh_offset + l] >> 2) & 3) << 4)) - 32),
+                  fma(FLOAT_TYPE(data_b[b_offset + y_idx + l+64]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 4]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + l+ 0] >> 4) | (((data_a[ib0 + i].qh[qh_offset + l] >> 4) & 3) << 4)) - 32),
+                  fma(FLOAT_TYPE(data_b[b_offset + y_idx + l+96]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 6]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + l+32] >> 4) | (((data_a[ib0 + i].qh[qh_offset + l] >> 6) & 3) << 4)) - 32), sum))));
         }
         tmp[16 * ix + tid] += sum;
 #endif


@@ -326,10 +326,10 @@ void main() {
                 mbyte = uint8_t((data_a[ib].scales[is + 4] >> 4) | ((data_a[ib].scales[is    ] >> 6) << 4));
             }
             const float d = loadd.x * sc;
-            const float m = loadd.y * mbyte;
+            const float m = -loadd.y * mbyte;

-            buf_a[buf_idx    ] = FLOAT_TYPE(d * float((data_a[ib].qs[qsi    ] >> (b * 4)) & 0xF) - m);
-            buf_a[buf_idx + 1] = FLOAT_TYPE(d * float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF) - m);
+            buf_a[buf_idx    ] = FLOAT_TYPE(fma(d, float((data_a[ib].qs[qsi    ] >> (b * 4)) & 0xF), m));
+            buf_a[buf_idx + 1] = FLOAT_TYPE(fma(d, float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF), m));
 #elif defined(DATA_A_Q5_K)
             const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
             const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A;
@@ -357,10 +357,10 @@ void main() {
                 mbyte = uint8_t((data_a[ib].scales[is + 4] >> 4) | ((data_a[ib].scales[is    ] >> 6) << 4));
             }
             const float d = loadd.x * sc;
-            const float m = loadd.y * mbyte;
+            const float m = -loadd.y * mbyte;

-            buf_a[buf_idx    ] = FLOAT_TYPE(d * (float((data_a[ib].qs[qsi    ] >> (b * 4)) & 0xF) + float((data_a[ib].qh[qhi    ] & hm) != 0 ? 16 : 0)) - m);
-            buf_a[buf_idx + 1] = FLOAT_TYPE(d * (float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF) + float((data_a[ib].qh[qhi + 1] & hm) != 0 ? 16 : 0)) - m);
+            buf_a[buf_idx    ] = FLOAT_TYPE(fma(d, float((data_a[ib].qs[qsi    ] >> (b * 4)) & 0xF) + float((data_a[ib].qh[qhi    ] & hm) != 0 ? 16 : 0), m));
+            buf_a[buf_idx + 1] = FLOAT_TYPE(fma(d, float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF) + float((data_a[ib].qh[qhi + 1] & hm) != 0 ? 16 : 0), m));
 #elif defined(DATA_A_Q6_K)
             const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
             const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A;
@@ -463,7 +463,8 @@ void main() {
         [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
             [[unroll]] for (uint cc = 0; cc < TN; cc++) {
                 [[unroll]] for (uint cr = 0; cr < TM; cr++) {
-                    sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr] += float(cache_a[wsir * TM + cr]) * float(cache_b[wsic * TN + cc]);
+                    const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr;
+                    sums[sums_idx] = fma(float(cache_a[wsir * TM + cr]), float(cache_b[wsic * TN + cc]), sums[sums_idx]);
                 }
             }
         }


@@ -0,0 +1,24 @@
#version 450

#include "types.comp"
#include "generic_unary_head.comp"

uint src0_idx_mod(uint idx) {
    const uint i13 = idx / (p.ne12*p.ne11*p.ne10);
    const uint i13_offset = i13 * p.ne12*p.ne11*p.ne10;
    const uint i12 = (idx - i13_offset) / (p.ne11*p.ne10);
    const uint i12_offset = i12*p.ne11*p.ne10;
    const uint i11 = (idx - i13_offset - i12_offset) / p.ne10;
    const uint i10 = idx - i13_offset - i12_offset - i11*p.ne10;
    return (i13 % p.ne03)*p.nb03 + (i12 % p.ne02)*p.nb02 + (i11 % p.ne01)*p.nb01 + (i10 % p.ne00)*p.nb00;
}

void main() {
    const uint idx = get_idx();

    if (idx >= p.ne) {
        return;
    }

    data_d[p.d_offset + dst_idx(idx)] = D_TYPE(data_a[src0_idx_mod(idx)]);
}
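
Note: the new repeat shader implements broadcasting by unflattening the destination index into four coordinates and wrapping each coordinate into the (possibly smaller) source extents with a modulo. A C++ sketch of the same index math (the Params struct and its fields are illustrative stand-ins for the shader's push constants; the %-based decomposition is equivalent to the shader's offset-subtraction form):

    struct Params {
        unsigned ne10, ne11, ne12;       // destination extents, innermost first
        unsigned ne00, ne01, ne02, ne03; // source extents
        unsigned nb00, nb01, nb02, nb03; // source strides, in elements
    };

    unsigned src0_idx_mod(const Params & p, unsigned idx) {
        const unsigned i13 = idx / (p.ne12 * p.ne11 * p.ne10);
        const unsigned i12 = (idx / (p.ne11 * p.ne10)) % p.ne12;
        const unsigned i11 = (idx / p.ne10) % p.ne11;
        const unsigned i10 = idx % p.ne10;
        // wrap each destination coordinate back into the source tensor
        return (i13 % p.ne03) * p.nb03 + (i12 % p.ne02) * p.nb02
             + (i11 % p.ne01) * p.nb01 + (i10 % p.ne00) * p.nb00;
    }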


@@ -368,6 +368,10 @@ void process_shaders(std::vector<std::future<void>>& tasks) {
         string_to_spv("add_f16_f32_f16", "add.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"FLOAT_TYPE", "float"}});
     }));

+    tasks.push_back(std::async(std::launch::async, [] {
+        string_to_spv("acc_f32", "acc.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
+    }));
+
     tasks.push_back(std::async(std::launch::async, [] {
         string_to_spv("split_k_reduce", "mul_mat_split_k_reduce.comp", {});
     }));
@@ -380,6 +384,10 @@ void process_shaders(std::vector<std::future<void>>& tasks) {
         string_to_spv("div_f32", "div.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
     }));

+    tasks.push_back(std::async(std::launch::async, [] {
+        string_to_spv("repeat_f32", "repeat.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
+    }));
+
     tasks.push_back(std::async(std::launch::async, [] {
         string_to_spv("scale_f32", "scale.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
     }));
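
Note: each new shader variant (acc_f32, repeat_f32) is registered as its own std::async task, so all SPIR-V compiles can run in parallel and the caller later drains the futures. A reduced C++ sketch of that pattern (compile_shader is a hypothetical stand-in for string_to_spv):

    #include <future>
    #include <vector>

    void compile_shader(const char * /*name*/) { /* ... invoke the shader compiler ... */ }

    int main() {
        std::vector<std::future<void>> tasks;
        tasks.push_back(std::async(std::launch::async, [] { compile_shader("acc_f32"); }));
        tasks.push_back(std::async(std::launch::async, [] { compile_shader("repeat_f32"); }));
        for (auto & t : tasks) {
            t.get(); // wait for every compile job; get() also propagates exceptions
        }
        return 0;
    }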


@@ -130,6 +130,7 @@ class Keys:
         INNER_SIZE     = "{arch}.ssm.inner_size"
         STATE_SIZE     = "{arch}.ssm.state_size"
         TIME_STEP_RANK = "{arch}.ssm.time_step_rank"
+        DT_B_C_RMS     = "{arch}.ssm.dt_b_c_rms"

     class Tokenizer:
         MODEL = "tokenizer.ggml.model"
@@ -219,6 +220,8 @@ class MODEL_ARCH(IntEnum):
     T5        = auto()
     T5ENCODER = auto()
     JAIS      = auto()
+    NEMOTRON  = auto()
+    EXAONE    = auto()


 class MODEL_TENSOR(IntEnum):
@@ -347,6 +350,8 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
     MODEL_ARCH.T5:        "t5",
     MODEL_ARCH.T5ENCODER: "t5encoder",
     MODEL_ARCH.JAIS:      "jais",
+    MODEL_ARCH.NEMOTRON:  "nemotron",
+    MODEL_ARCH.EXAONE:    "exaone",
 }

 TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
@@ -1065,6 +1070,37 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.FFN_GATE,
         MODEL_TENSOR.FFN_UP,
     ],
+    MODEL_ARCH.NEMOTRON: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ROPE_FREQS,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.ATTN_ROT_EMBD,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+    ],
+    MODEL_ARCH.EXAONE: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ROPE_FREQS,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.ATTN_ROT_EMBD,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_GATE,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+    ],
     # TODO
 }
@@ -1105,6 +1141,10 @@ MODEL_TENSOR_SKIP: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
     MODEL_ARCH.CHATGLM: [
         MODEL_TENSOR.ROPE_FREQS,
     ],
+    MODEL_ARCH.NEMOTRON: [
+        MODEL_TENSOR.ROPE_FREQS,
+        MODEL_TENSOR.ATTN_ROT_EMBD,
+    ],
 }

 #
@@ -1339,6 +1379,7 @@ KEY_SSM_CONV_KERNEL = Keys.SSM.CONV_KERNEL
 KEY_SSM_INNER_SIZE     = Keys.SSM.INNER_SIZE
 KEY_SSM_STATE_SIZE     = Keys.SSM.STATE_SIZE
 KEY_SSM_TIME_STEP_RANK = Keys.SSM.TIME_STEP_RANK
+KEY_SSM_DT_B_C_RMS     = Keys.SSM.DT_B_C_RMS

 # tokenization
 KEY_TOKENIZER_MODEL = Keys.Tokenizer.MODEL


@@ -730,6 +730,9 @@ class GGUFWriter:
     def add_ssm_time_step_rank(self, value: int) -> None:
         self.add_uint32(Keys.SSM.TIME_STEP_RANK.format(arch=self.arch), value)

+    def add_ssm_dt_b_c_rms(self, value: bool) -> None:
+        self.add_bool(Keys.SSM.DT_B_C_RMS.format(arch=self.arch), value)
+
     def add_tokenizer_model(self, model: str) -> None:
         self.add_string(Keys.Tokenizer.MODEL, model)


@@ -10,10 +10,10 @@ class TensorNameMap:
     # Token embeddings
     MODEL_TENSOR.TOKEN_EMBD: (
         "gpt_neox.embed_in", # gptneox
-        "transformer.wte", # gpt2 gpt-j mpt refact qwen dbrx jais
+        "transformer.wte", # gpt2 gpt-j mpt refact qwen dbrx jais exaone
         "transformer.word_embeddings", # falcon
         "word_embeddings", # bloom
-        "model.embed_tokens", # llama-hf
+        "model.embed_tokens", # llama-hf nemotron
         "tok_embeddings", # llama-pth
         "embeddings.word_embeddings", # bert nomic-bert
         "language_model.embedding.word_embeddings", # persimmon
@@ -52,7 +52,7 @@ class TensorNameMap:
     # Output
     MODEL_TENSOR.OUTPUT: (
         "embed_out", # gptneox
-        "lm_head", # gpt2 mpt falcon llama-hf baichuan qwen mamba dbrx jais
+        "lm_head", # gpt2 mpt falcon llama-hf baichuan qwen mamba dbrx jais nemotron exaone
         "output", # llama-pth bloom internlm2
         "word_embeddings_for_head", # persimmon
         "lm_head.linear", # phi2
@@ -62,7 +62,7 @@ class TensorNameMap:
     # Output norm
     MODEL_TENSOR.OUTPUT_NORM: (
         "gpt_neox.final_layer_norm", # gptneox
-        "transformer.ln_f", # gpt2 gpt-j falcon jais
+        "transformer.ln_f", # gpt2 gpt-j falcon jais exaone
         "model.norm", # llama-hf baichuan internlm2
         "norm", # llama-pth
         "transformer.norm_f", # mpt dbrx
@@ -75,6 +75,7 @@ class TensorNameMap:
         "transformer.rms_norm", # Grok
         "encoder.final_layernorm", # chatglm
         "transformer.norm", # openelm
+        "model.norm", # nemotron
     ),

     # Rope frequencies
@@ -88,12 +89,12 @@ class TensorNameMap:
     # Attention norm
     MODEL_TENSOR.ATTN_NORM: (
         "gpt_neox.layers.{bid}.input_layernorm", # gptneox
-        "transformer.h.{bid}.ln_1", # gpt2 gpt-j refact qwen jais
+        "transformer.h.{bid}.ln_1", # gpt2 gpt-j refact qwen jais exaone
         "transformer.blocks.{bid}.norm_1", # mpt
         "transformer.h.{bid}.input_layernorm", # falcon7b
         "h.{bid}.input_layernorm", # bloom
         "transformer.h.{bid}.ln_mlp", # falcon40b
-        "model.layers.{bid}.input_layernorm", # llama-hf
+        "model.layers.{bid}.input_layernorm", # llama-hf nemotron
         "layers.{bid}.attention_norm", # llama-pth
         "language_model.encoder.layers.{bid}.input_layernorm", # persimmon
         "model.layers.{bid}.ln1", # yi
@@ -135,18 +136,19 @@ class TensorNameMap:
     # Attention query
     MODEL_TENSOR.ATTN_Q: (
-        "model.layers.{bid}.self_attn.q_proj", # llama-hf
+        "model.layers.{bid}.self_attn.q_proj", # llama-hf nemotron
         "layers.{bid}.attention.wq", # llama-pth
         "encoder.layer.{bid}.attention.self.query", # bert
         "transformer.h.{bid}.attn.q_proj", # gpt-j
         "model.layers.layers.{bid}.self_attn.q_proj", # plamo
         "model.layers.{bid}.attention.wq", # internlm2
         "transformer.decoder_layer.{bid}.multi_head_attention.query",# Grok
+        "transformer.h.{bid}.attn.attention.q_proj", # exaone
     ),

     # Attention key
     MODEL_TENSOR.ATTN_K: (
-        "model.layers.{bid}.self_attn.k_proj", # llama-hf
+        "model.layers.{bid}.self_attn.k_proj", # llama-hf nemotron
         "layers.{bid}.attention.wk", # llama-pth
         "encoder.layer.{bid}.attention.self.key", # bert
         "transformer.h.{bid}.attn.k_proj", # gpt-j
@@ -154,18 +156,20 @@ class TensorNameMap:
         "model.layers.layers.{bid}.self_attn.k_proj", # plamo
         "model.layers.{bid}.attention.wk", # internlm2
         "transformer.decoder_layer.{bid}.multi_head_attention.key",# Grok
+        "transformer.h.{bid}.attn.attention.k_proj", # exaone
     ),

     # Attention value
     MODEL_TENSOR.ATTN_V: (
-        "model.layers.{bid}.self_attn.v_proj", # llama-hf
+        "model.layers.{bid}.self_attn.v_proj", # llama-hf nemotron
         "layers.{bid}.attention.wv", # llama-pth
         "encoder.layer.{bid}.attention.self.value", # bert
         "transformer.h.{bid}.attn.v_proj", # gpt-j
         "transformer.h.{bid}.attn.v", # refact
         "model.layers.layers.{bid}.self_attn.v_proj", # plamo
         "model.layers.{bid}.attention.wv", # internlm2
-        "transformer.decoder_layer.{bid}.multi_head_attention.value" # Grok
+        "transformer.decoder_layer.{bid}.multi_head_attention.value",# Grok
+        "transformer.h.{bid}.attn.attention.v_proj", # exaone
     ),

     # Attention output
@@ -175,7 +179,7 @@ class TensorNameMap:
         "transformer.blocks.{bid}.attn.out_proj", # mpt
         "transformer.h.{bid}.self_attention.dense", # falcon
         "h.{bid}.self_attention.dense", # bloom
-        "model.layers.{bid}.self_attn.o_proj", # llama-hf
+        "model.layers.{bid}.self_attn.o_proj", # llama-hf nemotron
         "layers.{bid}.attention.wo", # llama-pth
         "encoder.layer.{bid}.attention.output.dense", # bert
         "transformer.h.{bid}.attn.out_proj", # gpt-j
@@ -190,6 +194,7 @@ class TensorNameMap:
         "transformer.blocks.{bid}.norm_attn_norm.attn.out_proj", # dbrx
         "encoder.layers.{bid}.self_attention.dense", # chatglm
         "transformer.layers.{bid}.attn.out_proj", # openelm
+        "transformer.h.{bid}.attn.attention.out_proj", # exaone
     ),

     # Attention output norm
@@ -215,10 +220,10 @@ class TensorNameMap:
     # Feed-forward norm
     MODEL_TENSOR.FFN_NORM: (
         "gpt_neox.layers.{bid}.post_attention_layernorm", # gptneox
-        "transformer.h.{bid}.ln_2", # gpt2 refact qwen jais
+        "transformer.h.{bid}.ln_2", # gpt2 refact qwen jais exaone
         "h.{bid}.post_attention_layernorm", # bloom
         "transformer.blocks.{bid}.norm_2", # mpt
-        "model.layers.{bid}.post_attention_layernorm", # llama-hf
+        "model.layers.{bid}.post_attention_layernorm", # llama-hf nemotron
         "layers.{bid}.ffn_norm", # llama-pth
         "language_model.encoder.layers.{bid}.post_attention_layernorm", # persimmon
         "model.layers.{bid}.ln2", # yi
@@ -258,7 +263,7 @@ class TensorNameMap:
         "transformer.blocks.{bid}.ffn.up_proj", # mpt
         "transformer.h.{bid}.mlp.dense_h_to_4h", # falcon
         "h.{bid}.mlp.dense_h_to_4h", # bloom
-        "model.layers.{bid}.mlp.up_proj", # llama-hf refact
+        "model.layers.{bid}.mlp.up_proj", # llama-hf refact nemotron
         "layers.{bid}.feed_forward.w3", # llama-pth
         "encoder.layer.{bid}.intermediate.dense", # bert
         "transformer.h.{bid}.mlp.fc_in", # gpt-j
@@ -277,6 +282,7 @@ class TensorNameMap:
         "encoder.layer.{bid}.mlp.gated_layers_v", # jina-bert-v2
         "model.layers.{bid}.residual_mlp.w3", # arctic
         "encoder.layers.{bid}.mlp.dense_h_to_4h", # chatglm
+        "transformer.h.{bid}.mlp.c_fc_1", # exaone
     ),

     MODEL_TENSOR.FFN_UP_EXP: (
@@ -308,6 +314,7 @@ class TensorNameMap:
         "encoder.layer.{bid}.mlp.gated_layers_w", # jina-bert-v2
         "transformer.h.{bid}.mlp.linear_1", # refact
         "model.layers.{bid}.residual_mlp.w1", # arctic
+        "transformer.h.{bid}.mlp.c_fc_0", # exaone
     ),

     MODEL_TENSOR.FFN_GATE_EXP: (
@@ -329,7 +336,7 @@ class TensorNameMap:
         "transformer.blocks.{bid}.ffn.down_proj", # mpt
         "transformer.h.{bid}.mlp.dense_4h_to_h", # falcon
         "h.{bid}.mlp.dense_4h_to_h", # bloom
-        "model.layers.{bid}.mlp.down_proj", # llama-hf
+        "model.layers.{bid}.mlp.down_proj", # llama-hf nemotron
         "layers.{bid}.feed_forward.w2", # llama-pth
         "encoder.layer.{bid}.output.dense", # bert
         "transformer.h.{bid}.mlp.fc_out", # gpt-j
@@ -347,6 +354,7 @@ class TensorNameMap:
         "model.layers.{bid}.residual_mlp.w2", # arctic
         "encoder.layer.{bid}.mlp.down_layer", # jina-bert-v2
         "encoder.layers.{bid}.mlp.dense_4h_to_h", # chatglm
+        "model.layers.h.{bid}.mlp.c_proj", # exaone
     ),

     MODEL_TENSOR.FFN_DOWN_EXP: (


@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "gguf"
-version = "0.9.1"
+version = "0.10.0"
 description = "Read and write ML models in GGUF for GGML"
 authors = ["GGML <ggml@ggml.ai>"]
 packages = [


@@ -93,6 +93,9 @@ extern "C" {
         LLAMA_VOCAB_PRE_TYPE_TEKKEN       = 20,
         LLAMA_VOCAB_PRE_TYPE_SMOLLM       = 21,
         LLAMA_VOCAB_PRE_TYPE_CODESHELL    = 22,
+        LLAMA_VOCAB_PRE_TYPE_BLOOM        = 23,
+        LLAMA_VOCAB_PRE_TYPE_GPT3_FINNISH = 24,
+        LLAMA_VOCAB_PRE_TYPE_EXAONE       = 25,
     };

     enum llama_rope_type {
@@ -510,6 +513,9 @@ extern "C" {
     // to the decoder to start generating output sequence. For other models, it returns -1.
     LLAMA_API llama_token llama_model_decoder_start_token(const struct llama_model * model);

+    // Returns true if the model is recurrent (like Mamba, RWKV, etc.)
+    LLAMA_API bool llama_model_is_recurrent(const struct llama_model * model);
+
     // Returns 0 on success
     LLAMA_API uint32_t llama_model_quantize(
             const char * fname_inp,
@@ -914,11 +920,8 @@ extern "C" {
     LLAMA_API llama_token llama_token_nl (const struct llama_model * model); // next-line
     LLAMA_API llama_token llama_token_pad(const struct llama_model * model); // padding

-    // Returns -1 if unknown, 1 for true or 0 for false.
-    LLAMA_API int32_t llama_add_bos_token(const struct llama_model * model);
-
-    // Returns -1 if unknown, 1 for true or 0 for false.
-    LLAMA_API int32_t llama_add_eos_token(const struct llama_model * model);
+    LLAMA_API bool llama_add_bos_token(const struct llama_model * model);
+    LLAMA_API bool llama_add_eos_token(const struct llama_model * model);

     // Codellama infill tokens
     LLAMA_API llama_token llama_token_prefix(const struct llama_model * model); // Beginning of infill prefix
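
Note: llama_add_bos_token()/llama_add_eos_token() change from a tri-state int32_t (per the removed comment: -1 unknown, 0 false, 1 true) to a plain bool, so call sites that compared against -1 need updating. A hedged C++ sketch of a migrated call site (the wrapper function name is illustrative):

    #include "llama.h"

    // Before: int32_t r = llama_add_bos_token(model);
    //         bool add_bos = (r != -1) ? bool(r) : some_fallback_guess;
    // After: the library resolves the default internally and returns a plain bool.
    static bool should_add_bos(const struct llama_model * model) {
        return llama_add_bos_token(model);
    }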


@@ -321,6 +321,21 @@ private:

 // TODO: there are a lot of common parts between spm and bpe tokenizers, should be refactored and reused

+template<typename T, typename Container = std::vector<T>, typename Compare = std::less<typename Container::value_type>>
+class llama_priority_queue : public std::priority_queue<T, Container, Compare> {
+public:
+    using std::priority_queue<T, Container, Compare>::priority_queue;
+
+    T pop_move() {
+        T item = std::move(this->c.front());
+        std::pop_heap(this->c.begin(), this->c.end(), this->comp);
+        this->c.pop_back();
+        return item;
+    }
+
+    void pop() = delete;
+};
+
 struct llm_bigram_bpe {
     struct comparator {
         bool operator()(const llm_bigram_bpe & l, const llm_bigram_bpe & r) const {
@@ -329,7 +344,7 @@ struct llm_bigram_bpe {
     };

     using queue_storage = std::vector<llm_bigram_bpe>;
-    using queue = std::priority_queue<llm_bigram_bpe, queue_storage, comparator>;
+    using queue = llama_priority_queue<llm_bigram_bpe, queue_storage, comparator>;
     llm_symbol::index left;
     llm_symbol::index right;
     std::string text;
@@ -388,6 +403,7 @@ struct llm_tokenizer_bpe {
             case LLAMA_VOCAB_PRE_TYPE_COMMAND_R:
             case LLAMA_VOCAB_PRE_TYPE_SMOLLM:
             case LLAMA_VOCAB_PRE_TYPE_CODESHELL:
+            case LLAMA_VOCAB_PRE_TYPE_EXAONE:
                 regex_exprs = {
                     "\\p{N}",
                     "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
@@ -410,6 +426,8 @@ struct llm_tokenizer_bpe {
                 };
                 break;
             case LLAMA_VOCAB_PRE_TYPE_PORO:
+            case LLAMA_VOCAB_PRE_TYPE_BLOOM:
+            case LLAMA_VOCAB_PRE_TYPE_GPT3_FINNISH:
                 regex_exprs = {
                     " ?[^(\\s|.,!?…。,、।۔،)]+",
                 };
@@ -517,8 +535,7 @@ struct llm_tokenizer_bpe {
         // build token(s)
         while (!work_queue.empty()) {
-            auto bigram = work_queue.top();
-            work_queue.pop();
+            auto bigram = work_queue.pop_move();

             auto & left_symbol = symbols[bigram.left];
             auto & right_symbol = symbols[bigram.right];
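
Note: the new llama_priority_queue adds pop_move() so the BPE merge loop can move the best bigram out of the heap instead of copying it via top() and then calling pop(); deleting pop() in the derived class forces callers through the move path. A self-contained C++ sketch of the same idea with an illustrative element type:

    #include <algorithm>
    #include <queue>
    #include <string>
    #include <vector>

    template<typename T, typename Container = std::vector<T>, typename Compare = std::less<typename Container::value_type>>
    class moving_priority_queue : public std::priority_queue<T, Container, Compare> {
    public:
        using std::priority_queue<T, Container, Compare>::priority_queue;

        T pop_move() {
            T item = std::move(this->c.front()); // move the max element out (front is left valid but unspecified)
            std::pop_heap(this->c.begin(), this->c.end(), this->comp);
            this->c.pop_back();
            return item;
        }
        void pop() = delete; // hide the base class's copying workflow
    };

    int main() {
        moving_priority_queue<std::string> q;
        q.push("aa");
        q.push("zz");
        std::string top = q.pop_move(); // "zz" is moved out, no string copy
        return top == "zz" ? 0 : 1;
    }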
@@ -1466,11 +1483,11 @@ llama_token llama_token_pad_impl(const struct llama_vocab & vocab) {
     return vocab.special_pad_id;
 }

-int32_t llama_add_bos_token_impl(const struct llama_vocab & vocab) {
+bool llama_add_bos_token_impl(const struct llama_vocab & vocab) {
     return vocab.tokenizer_add_bos;
 }

-int32_t llama_add_eos_token_impl(const struct llama_vocab & vocab) {
+bool llama_add_eos_token_impl(const struct llama_vocab & vocab) {
     return vocab.tokenizer_add_eos;
 }


@@ -95,8 +95,8 @@ llama_token llama_token_sep_impl(const struct llama_vocab & vocab);
 llama_token llama_token_nl_impl (const struct llama_vocab & vocab);
 llama_token llama_token_pad_impl(const struct llama_vocab & vocab);

-int32_t llama_add_bos_token_impl(const struct llama_vocab & vocab);
-int32_t llama_add_eos_token_impl(const struct llama_vocab & vocab);
+bool llama_add_bos_token_impl(const struct llama_vocab & vocab);
+bool llama_add_eos_token_impl(const struct llama_vocab & vocab);

 llama_token llama_token_prefix_impl(const struct llama_vocab & vocab);
 llama_token llama_token_middle_impl(const struct llama_vocab & vocab);

(file diff suppressed because it is too large)


@ -2145,6 +2145,13 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {3000, 128, 1, 1}, {3, 128, 1280, 1}, 1, 0, 1, 0, 1, 0, false)); test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {3000, 128, 1, 1}, {3, 128, 1280, 1}, 1, 0, 1, 0, 1, 0, false));
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32, {3000, 128, 1, 1}, {3, 128, 1280, 1}, 1, 0, 1, 0, 1, 0, false)); test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32, {3000, 128, 1, 1}, {3, 128, 1280, 1}, 1, 0, 1, 0, 1, 0, false));
// sycl backend will limit task global_range < MAX_INT
// test cases for 2D im2col with large input W and H (occurs in stable-diffusion)
// however these cases need to alloc more memory which may fail in some devices (Intel Arc770, etc.)
// these cases are verified (pass) in Intel(R) Data Center GPU Max 1100 (sycl backend) and NV A30 (cuda backend)
// test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {1024, 1024, 256, 1}, {3, 3, 256, 1}, 1, 1, 1, 1, 1, 1, true));
// test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32, {1024, 1024, 256, 1}, {3, 3, 256, 1}, 1, 1, 1, 1, 1, 1, true));
test_cases.emplace_back(new test_conv_transpose_1d()); test_cases.emplace_back(new test_conv_transpose_1d());
test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {2,3,2,1}, 3, 0, 1)); test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {2,3,2,1}, 3, 0, 1));
test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {2,3,2,1}, 2, 0, 1)); test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {2,3,2,1}, 2, 0, 1));
@@ -2287,6 +2294,12 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
     test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32,  64, 45, 128, { 8, 1}, {4, 1}));
     test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 128, 45,  64, { 8, 1}, {4, 1}));
+    // the SYCL backend limits the task global_range to < MAX_INT
+    // test case for the f16-to-fp32 type-convert kernel with a large k under fp32 compute dtype (occurs in stable-diffusion)
+    // however, this case needs to allocate more memory, which may fail on some devices (Intel Arc A770, etc.)
+    // this case is verified (pass) on Intel(R) Data Center GPU Max 1100 (SYCL backend) and NVIDIA A30 (CUDA backend)
+    // test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F16, 512, 262144, 9216, {1, 1}, {1, 1}));
 
     for (ggml_type type_a : base_types) {
         for (ggml_type type_b : {GGML_TYPE_F32 /*, GGML_TYPE_F16 */}) {
             for (int n_mats : {4, 8}) {
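The same ceiling accounts for this commented-out matmul: if the f16-to-fp32 convert kernel assigns one work item per element of the 262144 × 9216 operand, that is 262144 × 9216 = 2,415,919,104 elements, already past INT_MAX (2,147,483,647) before the multiplication itself runs.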


@@ -503,7 +503,7 @@ static void test_special_chars() {
             "aaaaabcccc",
             "aaaabccc",
             "aaaabccccc",
-            "🔵🟠✅❌abc❌✅🟠🔵"
+            "🔵🟠✅❌abc❌✅🟠🔵",
             "🔵🟠abc🟠🔵"
         }
     );
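The added comma here is more than cosmetic: in C++, adjacent string literals are concatenated at compile time, so without it the two emoji strings were silently fused into a single test case instead of two. A minimal illustration (hypothetical arrays, not from the test file):

    // without the comma, the two literals merge into ONE array element
    const char * fused[] = {
        "🔵🟠✅❌abc❌✅🟠🔵"
        "🔵🟠abc🟠🔵"
    };
    // with the comma, they stay two separate test strings
    const char * split[] = {
        "🔵🟠✅❌abc❌✅🟠🔵",
        "🔵🟠abc🟠🔵"
    };
    static_assert(sizeof(fused) / sizeof(fused[0]) == 1, "one fused element");
    static_assert(sizeof(split) / sizeof(split[0]) == 2, "two elements");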


@@ -0,0 +1,139 @@
#!/bin/bash
set -e
# Array of models to iterate over
declare -a params=(
"Gemma2ForCausalLM 64"
"LlamaForCausalLM 64"
"Phi3ForCausalLM 64"
)
MODELS_REPO=lora-tests
MODELS_REPO_URL=https://huggingface.co/ggml-org/$MODELS_REPO
# Clone the Hugging Face repository if the directory does not exist
if [ ! -d "$MODELS_REPO" ]; then
echo "Cloning the Hugging Face repository..."
git clone $MODELS_REPO_URL
else
echo "Repository already exists. Skipping clone."
fi
# Array to store results to print
results=()
trim_leading_whitespace() {
local input_string="$1"
echo "${input_string#"${input_string%%[![:space:]]*}"}"
}
extract_starting_substring() {
local reference_string="$1"
local target_string="$2"
local target_length=${#target_string}
echo "${reference_string:0:$target_length}"
}
get_first_word() {
local input_string="$1"
read -r first_word _ <<< "$input_string"
echo "$first_word"
}
# Load the expected strings
EXPECTED_BASE_FULL=$(cat $MODELS_REPO/data/pale_blue_dot.txt)
EXPECTED_LORA_FULL=$(cat $MODELS_REPO/data/bohemian_rhapsody.txt)
EXPECTED_BASE_FIRST_WORD=$(get_first_word "$EXPECTED_BASE_FULL")
EXPECTED_LORA_FIRST_WORD=$(get_first_word "$EXPECTED_LORA_FULL")
run_conversion_and_inference_lora() {
local model_name=$1
local hidden_size=$2
echo -e "\n\n-------- RUNNING TEST FOR MODEL $model_name --------\n\n"
# Convert safetensors to gguf
echo "Running convert_hf_to_gguf.py for $model_name with hidden_size $hidden_size..."
python convert_hf_to_gguf.py $MODELS_REPO/$model_name/hidden_size=$hidden_size/base \
--outfile $MODELS_REPO/$model_name/hidden_size=$hidden_size/base/Base-F32.gguf \
--outtype f32
echo -e "\n\n---------------------------\n\n"
echo "Running convert_lora_to_gguf.py for $model_name with hidden_size $hidden_size..."
python3 convert_lora_to_gguf.py $MODELS_REPO/$model_name/hidden_size=$hidden_size/lora \
--base $MODELS_REPO/$model_name/hidden_size=$hidden_size/base \
--outtype f32
echo -e "\n\n---------------------------\n\n"
echo "Running llama-export-lora with lora for $model_name with hidden_size $hidden_size..."
./llama-export-lora \
-m $MODELS_REPO/$model_name/hidden_size=$hidden_size/base/Base-F32.gguf \
-o $MODELS_REPO/$model_name/hidden_size=$hidden_size/base/Base-F32-lora-merged.gguf \
--lora $MODELS_REPO/$model_name/hidden_size=$hidden_size/lora/Lora-F32-LoRA.gguf
# Run inference
echo -e "\n\n---------------------------\n\n"
echo "Running llama-cli without lora for $model_name with hidden_size $hidden_size..."
OUTPUT_BASE=$(./llama-cli -m $MODELS_REPO/$model_name/hidden_size=$hidden_size/base/Base-F32.gguf \
-p "$EXPECTED_BASE_FIRST_WORD" -n 50 --seed 42 --temp 0)
echo -e "\n\n---------------------------\n\n"
echo "Running llama-cli with hot lora for $model_name with hidden_size $hidden_size..."
OUTPUT_LORA_HOT=$(./llama-cli -m $MODELS_REPO/$model_name/hidden_size=$hidden_size/base/Base-F32.gguf \
--lora $MODELS_REPO/$model_name/hidden_size=$hidden_size/lora/Lora-F32-LoRA.gguf \
-p "$EXPECTED_LORA_FIRST_WORD" -n 50 --seed 42 --temp 0)
echo -e "\n\n---------------------------\n\n"
echo "Running llama-cli with merged lora for $model_name with hidden_size $hidden_size..."
OUTPUT_LORA_MERGED=$(./llama-cli -m $MODELS_REPO/$model_name/hidden_size=$hidden_size/base/Base-F32-lora-merged.gguf \
-p "$EXPECTED_LORA_FIRST_WORD" -n 50 --seed 42 --temp 0)
# Remove any leading whitespace
OUTPUT_BASE=$(trim_leading_whitespace "$OUTPUT_BASE")
OUTPUT_LORA_HOT=$(trim_leading_whitespace "$OUTPUT_LORA_HOT")
OUTPUT_LORA_MERGED=$(trim_leading_whitespace "$OUTPUT_LORA_MERGED")
# Extract the corresponding substring from the full string
EXPECTED_BASE=$(extract_starting_substring "$EXPECTED_BASE_FULL" "$OUTPUT_BASE")
EXPECTED_LORA=$(extract_starting_substring "$EXPECTED_LORA_FULL" "$OUTPUT_LORA_HOT")
# Assert output equals the expected output
if [[ "$OUTPUT_BASE" != "$EXPECTED_BASE" ]]; then
echo "Error: $model_name OUTPUT_BASE does not start with the expected string."
echo -e "Out=$OUTPUT_BASE\n\nExp=$EXPECTED_BASE"
exit 1
fi
if [[ "$OUTPUT_LORA_HOT" != "$EXPECTED_LORA" ]]; then
echo "Error: $model_name OUTPUT_LORA_HOT does not start with the expected string."
echo -e "Out=$OUTPUT_LORA_HOT\n\nExp=$EXPECTED_LORA"
exit 1
fi
if [[ "$OUTPUT_LORA_MERGED" != "$EXPECTED_LORA" ]]; then
echo "Error: $model_name OUTPUT_LORA_MERGED does not start with the expected string."
echo -e "Out=$OUTPUT_LORA_MERGED\n\nExp=$EXPECTED_LORA"
exit 1
fi
# Store the results
results+=("
\n\033[1mResults for $model_name with hidden_size $hidden_size:\033[0m
\n\033[32m • Base:\n$OUTPUT_BASE
\n\033[34m • Lora hot:\n$OUTPUT_LORA_HOT
\n\033[36m • Lora merged:\n$OUTPUT_LORA_MERGED
\n \033[0m
")
echo "All tests passed for $model_name with hidden_size $hidden_size!"
}
# Run test for each model
for param in "${params[@]}"; do
run_conversion_and_inference_lora $param
done
# Print results
echo -e "\n\n---------------------------\n\n"
echo -e "\n\033[1mSummary of All Results:\033[0m"
for result in "${results[@]}"; do
echo -e "$result"
done
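A brief usage note, inferred from the script body rather than stated in the diff: the script expects to run from a directory where llama-cli and llama-export-lora are already built (it invokes both via ./ relative paths and calls convert_hf_to_gguf.py and convert_lora_to_gguf.py from the working directory), and on first run it clones the ggml-org/lora-tests model repository next to itself before exercising each model listed in params.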