Merge branch 'master' of https://github.com/ggerganov/llama.cpp into ceb/fix-cuda-warning-flags

This commit is contained in:
Jared Van Bortel 2023-12-13 12:06:01 -05:00
commit c8554b80be
33 changed files with 2430 additions and 440 deletions

View File

@ -639,6 +639,11 @@ else()
message(STATUS "Unknown architecture") message(STATUS "Unknown architecture")
endif() endif()
if (MINGW)
# Target Windows 8 for PrefetchVirtualMemory
add_compile_definitions(_WIN32_WINNT=0x602)
endif()
# #
# POSIX conformance # POSIX conformance
# #

View File

@ -268,12 +268,15 @@ ifeq ($(UNAME_M),$(filter $(UNAME_M),x86_64 i686 amd64))
#MK_CXXFLAGS += -mssse3 #MK_CXXFLAGS += -mssse3
endif endif
# The stack is only 16-byte aligned on Windows, so don't let gcc emit aligned moves.
# https://gcc.gnu.org/bugzilla/show_bug.cgi?id=54412
# https://github.com/ggerganov/llama.cpp/issues/2922
ifneq '' '$(findstring mingw,$(shell $(CC) -dumpmachine))' ifneq '' '$(findstring mingw,$(shell $(CC) -dumpmachine))'
# The stack is only 16-byte aligned on Windows, so don't let gcc emit aligned moves.
# https://gcc.gnu.org/bugzilla/show_bug.cgi?id=54412
# https://github.com/ggerganov/llama.cpp/issues/2922
MK_CFLAGS += -Xassembler -muse-unaligned-vector-move MK_CFLAGS += -Xassembler -muse-unaligned-vector-move
MK_CXXFLAGS += -Xassembler -muse-unaligned-vector-move MK_CXXFLAGS += -Xassembler -muse-unaligned-vector-move
# Target Windows 8 for PrefetchVirtualMemory
MK_CPPFLAGS += -D_WIN32_WINNT=0x602
endif endif
ifneq ($(filter aarch64%,$(UNAME_M)),) ifneq ($(filter aarch64%,$(UNAME_M)),)
@ -358,6 +361,11 @@ ifdef LLAMA_CUBLAS
MK_LDFLAGS += -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64 -L/opt/cuda/lib64 -L$(CUDA_PATH)/targets/x86_64-linux/lib MK_LDFLAGS += -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64 -L/opt/cuda/lib64 -L$(CUDA_PATH)/targets/x86_64-linux/lib
OBJS += ggml-cuda.o OBJS += ggml-cuda.o
MK_NVCCFLAGS = --forward-unknown-to-host-compiler -use_fast_math MK_NVCCFLAGS = --forward-unknown-to-host-compiler -use_fast_math
ifdef LLAMA_DEBUG
MK_NVCCFLAGS += -lineinfo
endif
ifdef LLAMA_CUDA_NVCC ifdef LLAMA_CUDA_NVCC
NVCC = $(LLAMA_CUDA_NVCC) NVCC = $(LLAMA_CUDA_NVCC)
else else
@ -696,16 +704,16 @@ tests/test-quantize-perf: tests/test-quantize-perf.cpp ggml.o $(OBJS)
tests/test-sampling: tests/test-sampling.cpp ggml.o llama.o $(OBJS) tests/test-sampling: tests/test-sampling.cpp ggml.o llama.o $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
tests/test-tokenizer-0-falcon: tests/test-tokenizer-0-falcon.cpp ggml.o llama.o $(COMMON_DEPS) $(OBJS) tests/test-tokenizer-0-falcon: tests/test-tokenizer-0-falcon.cpp ggml.o llama.o $(COMMON_DEPS) console.o $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
tests/test-tokenizer-0-llama: tests/test-tokenizer-0-llama.cpp ggml.o llama.o $(COMMON_DEPS) $(OBJS) tests/test-tokenizer-0-llama: tests/test-tokenizer-0-llama.cpp ggml.o llama.o $(COMMON_DEPS) console.o $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
tests/test-tokenizer-1-bpe: tests/test-tokenizer-1-bpe.cpp ggml.o llama.o $(COMMON_DEPS) $(OBJS) tests/test-tokenizer-1-bpe: tests/test-tokenizer-1-bpe.cpp ggml.o llama.o $(COMMON_DEPS) console.o $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
tests/test-tokenizer-1-llama: tests/test-tokenizer-1-llama.cpp ggml.o llama.o $(COMMON_DEPS) $(OBJS) tests/test-tokenizer-1-llama: tests/test-tokenizer-1-llama.cpp ggml.o llama.o $(COMMON_DEPS) console.o $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
tests/test-rope: tests/test-rope.cpp ggml.o $(OBJS) tests/test-rope: tests/test-rope.cpp ggml.o $(OBJS)

View File

@ -10,6 +10,7 @@ Inference of [LLaMA](https://arxiv.org/abs/2302.13971) model in pure C/C++
### Hot topics ### Hot topics
- Added Mixtral support: https://github.com/ggerganov/llama.cpp/pull/4406
- **llama.h API change for handling KV cache offloading and data type: https://github.com/ggerganov/llama.cpp/pull/4309** - **llama.h API change for handling KV cache offloading and data type: https://github.com/ggerganov/llama.cpp/pull/4309**
- Using `llama.cpp` with AWS instances: https://github.com/ggerganov/llama.cpp/discussions/4225 - Using `llama.cpp` with AWS instances: https://github.com/ggerganov/llama.cpp/discussions/4225
- Looking for contributions to improve and maintain the `server` example: https://github.com/ggerganov/llama.cpp/issues/4216 - Looking for contributions to improve and maintain the `server` example: https://github.com/ggerganov/llama.cpp/issues/4216

View File

@ -656,6 +656,10 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
} else if (arg == "-h" || arg == "--help") { } else if (arg == "-h" || arg == "--help") {
return false; return false;
} else if (arg == "--version") {
fprintf(stderr, "version: %d (%s)\n", LLAMA_BUILD_NUMBER, LLAMA_COMMIT);
fprintf(stderr, "built with %s for %s\n", LLAMA_COMPILER, LLAMA_BUILD_TARGET);
exit(0);
} else if (arg == "--random-prompt") { } else if (arg == "--random-prompt") {
params.random_prompt = true; params.random_prompt = true;
} else if (arg == "--in-prefix-bos") { } else if (arg == "--in-prefix-bos") {
@ -794,6 +798,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
printf("\n"); printf("\n");
printf("options:\n"); printf("options:\n");
printf(" -h, --help show this help message and exit\n"); printf(" -h, --help show this help message and exit\n");
printf(" --version show version and build info\n");
printf(" -i, --interactive run in interactive mode\n"); printf(" -i, --interactive run in interactive mode\n");
printf(" --interactive-first run in interactive mode and wait for input right away\n"); printf(" --interactive-first run in interactive mode and wait for input right away\n");
printf(" -ins, --instruct run in instruction mode (use with Alpaca models)\n"); printf(" -ins, --instruct run in instruction mode (use with Alpaca models)\n");

View File

@ -61,13 +61,13 @@
// #define LOG_TARGET stderr // #define LOG_TARGET stderr
// #include "log.h" // #include "log.h"
// //
// The log target can also be redirected to a diffrent function // The log target can also be redirected to a different function
// like so: // like so:
// //
// #define LOG_TARGET log_handler_diffrent() // #define LOG_TARGET log_handler_different()
// #include "log.h" // #include "log.h"
// //
// FILE* log_handler_diffrent() // FILE* log_handler_different()
// { // {
// return stderr; // return stderr;
// } // }
@ -421,7 +421,7 @@ inline FILE *log_handler2_impl(bool change = false, LogTriState append = LogTriS
// Disables logs entirely at runtime. // Disables logs entirely at runtime.
// Makes LOG() and LOG_TEE() produce no output, // Makes LOG() and LOG_TEE() produce no output,
// untill enabled back. // until enabled back.
#define log_disable() log_disable_impl() #define log_disable() log_disable_impl()
// INTERNAL, DO NOT USE // INTERNAL, DO NOT USE

View File

@ -77,8 +77,18 @@ class Model:
self.gguf_writer.add_embedding_length(n_embd) self.gguf_writer.add_embedding_length(n_embd)
if (n_ff := self.hparams.get("intermediate_size")) is not None: if (n_ff := self.hparams.get("intermediate_size")) is not None:
self.gguf_writer.add_feed_forward_length(n_ff) self.gguf_writer.add_feed_forward_length(n_ff)
if (n_head := self.hparams.get("num_attention_head")) is not None: if (n_head := self.hparams.get("num_attention_heads")) is not None:
self.gguf_writer.add_head_count(n_head) self.gguf_writer.add_head_count(n_head)
if (n_head_kv := self.hparams.get("num_key_value_heads")) is not None:
self.gguf_writer.add_head_count_kv(n_head_kv)
if (n_rms_eps := self.hparams.get("rms_norm_eps")) is not None:
self.gguf_writer.add_layer_norm_rms_eps(n_rms_eps)
if (n_experts := self.hparams.get("num_local_experts")) is not None:
self.gguf_writer.add_expert_count(n_experts)
if (n_experts_used := self.hparams.get("num_experts_per_tok")) is not None:
self.gguf_writer.add_expert_used_count(n_experts_used)
self.gguf_writer.add_parallel_residual(self.hparams.get("use_parallel_residual", True)) self.gguf_writer.add_parallel_residual(self.hparams.get("use_parallel_residual", True))
def write_tensors(self): def write_tensors(self):
@ -170,6 +180,8 @@ class Model:
return StableLMModel return StableLMModel
if model_architecture == "QWenLMHeadModel": if model_architecture == "QWenLMHeadModel":
return QwenModel return QwenModel
if model_architecture == "MixtralForCausalLM":
return MixtralModel
return Model return Model
def _is_model_safetensors(self) -> bool: def _is_model_safetensors(self) -> bool:
@ -207,6 +219,8 @@ class Model:
return gguf.MODEL_ARCH.STABLELM return gguf.MODEL_ARCH.STABLELM
if arch == "QWenLMHeadModel": if arch == "QWenLMHeadModel":
return gguf.MODEL_ARCH.QWEN return gguf.MODEL_ARCH.QWEN
if arch == "MixtralForCausalLM":
return gguf.MODEL_ARCH.LLAMA
raise NotImplementedError(f'Architecture "{arch}" not supported!') raise NotImplementedError(f'Architecture "{arch}" not supported!')
@ -837,6 +851,11 @@ class StableLMModel(Model):
self.gguf_writer.add_layer_norm_eps(1e-5) self.gguf_writer.add_layer_norm_eps(1e-5)
class MixtralModel(Model):
def set_vocab(self):
self._set_vocab_sentencepiece()
class QwenModel(Model): class QwenModel(Model):
@staticmethod @staticmethod
def token_bytes_to_string(b): def token_bytes_to_string(b):

View File

@ -42,6 +42,7 @@ NDArray: TypeAlias = 'np.ndarray[Any, Any]'
ARCH = gguf.MODEL_ARCH.LLAMA ARCH = gguf.MODEL_ARCH.LLAMA
DEFAULT_CONCURRENCY = 8 DEFAULT_CONCURRENCY = 8
# #
# data types # data types
# #
@ -62,10 +63,10 @@ class UnquantizedDataType(DataType):
pass pass
DT_F16 = UnquantizedDataType('F16', dtype = np.dtype(np.float16), valid_conversions = ['F32', 'Q8_0']) DT_F16 = UnquantizedDataType('F16', dtype = np.dtype(np.float16), valid_conversions = ['F32', 'Q8_0'])
DT_F32 = UnquantizedDataType('F32', dtype = np.dtype(np.float32), valid_conversions = ['F16', 'Q8_0']) DT_F32 = UnquantizedDataType('F32', dtype = np.dtype(np.float32), valid_conversions = ['F16', 'Q8_0'])
DT_I32 = UnquantizedDataType('I32', dtype = np.dtype(np.int16), valid_conversions = []) DT_I32 = UnquantizedDataType('I32', dtype = np.dtype(np.int16), valid_conversions = [])
DT_BF16 = UnquantizedDataType('BF16', dtype = np.dtype(np.uint16), valid_conversions = ['F32', 'F16', 'Q8_0']) DT_BF16 = UnquantizedDataType('BF16', dtype = np.dtype(np.uint16), valid_conversions = ['F32', 'F16', 'Q8_0'])
@dataclass(frozen=True) @dataclass(frozen=True)
@ -151,14 +152,16 @@ GGML_FILE_TYPE_TO_DATA_TYPE: dict[GGMLFileType, DataType] = {
@dataclass @dataclass
class Params: class Params:
n_vocab: int n_vocab: int
n_embd: int n_embd: int
n_layer: int n_layer: int
n_ctx: int n_ctx: int
n_ff: int n_ff: int
n_head: int n_head: int
n_head_kv: int n_head_kv: int
f_norm_eps: float n_experts: int | None = None
n_experts_used: int | None = None
f_norm_eps: float | None = None
rope_scaling_type: gguf.RopeScalingType | None = None rope_scaling_type: gguf.RopeScalingType | None = None
f_rope_freq_base: float | None = None f_rope_freq_base: float | None = None
@ -233,6 +236,13 @@ class Params:
raise Exception("failed to guess 'n_ctx'. This model is unknown or unsupported.\n" raise Exception("failed to guess 'n_ctx'. This model is unknown or unsupported.\n"
"Suggestion: provide 'config.json' of the model in the same directory containing model files.") "Suggestion: provide 'config.json' of the model in the same directory containing model files.")
n_experts = None
n_experts_used = None
if "num_local_experts" in config:
n_experts = config["num_local_experts"]
n_experts_used = config["num_experts_per_tok"]
return Params( return Params(
n_vocab = config["vocab_size"], n_vocab = config["vocab_size"],
n_embd = config["hidden_size"], n_embd = config["hidden_size"],
@ -241,6 +251,8 @@ class Params:
n_ff = config["intermediate_size"], n_ff = config["intermediate_size"],
n_head = (n_head := config["num_attention_heads"]), n_head = (n_head := config["num_attention_heads"]),
n_head_kv = config.get("num_key_value_heads", n_head), n_head_kv = config.get("num_key_value_heads", n_head),
n_experts = n_experts,
n_experts_used = n_experts_used,
f_norm_eps = config["rms_norm_eps"], f_norm_eps = config["rms_norm_eps"],
f_rope_freq_base = config.get("rope_theta"), f_rope_freq_base = config.get("rope_theta"),
rope_scaling_type = rope_scaling_type, rope_scaling_type = rope_scaling_type,
@ -255,8 +267,15 @@ class Params:
def loadOriginalParamsJson(model: LazyModel, config_path: Path) -> Params: def loadOriginalParamsJson(model: LazyModel, config_path: Path) -> Params:
config = json.load(open(config_path)) config = json.load(open(config_path))
n_experts = None
n_experts_used = None
f_rope_freq_base = None
# hack to determine LLaMA v1 vs v2 vs CodeLlama # hack to determine LLaMA v1 vs v2 vs CodeLlama
if config.get("rope_theta") == 1000000: if config.get("moe"):
# Mixtral
n_ctx = 32768
elif config.get("rope_theta") == 1000000:
# CodeLlama # CodeLlama
n_ctx = 16384 n_ctx = 16384
elif config["norm_eps"] == 1e-05: elif config["norm_eps"] == 1e-05:
@ -266,16 +285,27 @@ class Params:
# LLaMA v1 # LLaMA v1
n_ctx = 2048 n_ctx = 2048
if "layers.0.feed_forward.w1.weight" in model:
n_ff = model["layers.0.feed_forward.w1.weight"].shape[0]
if config.get("moe"):
n_ff = model["layers.0.feed_forward.experts.0.w1.weight"].shape[0]
n_experts = config["moe"]["num_experts"]
n_experts_used = config["moe"]["num_experts_per_tok"]
f_rope_freq_base = 1e6
return Params( return Params(
n_vocab = model["tok_embeddings.weight"].shape[0], n_vocab = model["tok_embeddings.weight"].shape[0],
n_embd = config["dim"], n_embd = config["dim"],
n_layer = config["n_layers"], n_layer = config["n_layers"],
n_ctx = n_ctx, n_ctx = n_ctx,
n_ff = model["layers.0.feed_forward.w1.weight"].shape[0], n_ff = n_ff,
n_head = (n_head := config["n_heads"]), n_head = (n_head := config["n_heads"]),
n_head_kv = config.get("n_kv_heads", n_head), n_head_kv = config.get("n_kv_heads", n_head),
n_experts = n_experts,
n_experts_used = n_experts_used,
f_norm_eps = config["norm_eps"], f_norm_eps = config["norm_eps"],
f_rope_freq_base = config.get("rope_theta"), f_rope_freq_base = config.get("rope_theta", f_rope_freq_base),
) )
@staticmethod @staticmethod
@ -585,7 +615,7 @@ def merge_multifile_models(models_plus: list[ModelPlus]) -> ModelPlus:
if any("model.embed_tokens.weight" in mp.model for mp in models_plus): if any("model.embed_tokens.weight" in mp.model for mp in models_plus):
# Transformers models put different tensors in different files, but # Transformers models put different tensors in different files, but
# don't split indivdual tensors between files. # don't split individual tensors between files.
model: LazyModel = {} model: LazyModel = {}
for mp in models_plus: for mp in models_plus:
model.update(mp.model) model.update(mp.model)
@ -678,7 +708,7 @@ class LazyUnpickler(pickle.Unpickler):
return func(*args) return func(*args)
CLASSES: dict[tuple[str, str], Any] = { CLASSES: dict[tuple[str, str], Any] = {
# getattr used here as a workaround for mypy not being smart enough to detrmine # getattr used here as a workaround for mypy not being smart enough to determine
# the staticmethods have a __func__ attribute. # the staticmethods have a __func__ attribute.
('torch._tensor', '_rebuild_from_type_v2'): getattr(rebuild_from_type_v2, '__func__'), ('torch._tensor', '_rebuild_from_type_v2'): getattr(rebuild_from_type_v2, '__func__'),
('torch._utils', '_rebuild_tensor_v2'): getattr(lazy_rebuild_tensor_v2, '__func__'), ('torch._utils', '_rebuild_tensor_v2'): getattr(lazy_rebuild_tensor_v2, '__func__'),
@ -832,7 +862,17 @@ class OutputFile:
self.gguf.add_rope_dimension_count(params.n_embd // params.n_head) self.gguf.add_rope_dimension_count(params.n_embd // params.n_head)
self.gguf.add_head_count (params.n_head) self.gguf.add_head_count (params.n_head)
self.gguf.add_head_count_kv (params.n_head_kv) self.gguf.add_head_count_kv (params.n_head_kv)
self.gguf.add_layer_norm_rms_eps (params.f_norm_eps)
if params.n_experts:
self.gguf.add_expert_count(params.n_experts)
if params.n_experts_used:
self.gguf.add_expert_used_count(params.n_experts_used)
if params.f_norm_eps:
self.gguf.add_layer_norm_rms_eps(params.f_norm_eps)
else:
raise ValueError('f_norm_eps is None')
if params.f_rope_freq_base is not None: if params.f_rope_freq_base is not None:
self.gguf.add_rope_freq_base(params.f_rope_freq_base) self.gguf.add_rope_freq_base(params.f_rope_freq_base)
@ -956,7 +996,7 @@ class OutputFile:
def pick_output_type(model: LazyModel, output_type_str: str | None) -> GGMLFileType: def pick_output_type(model: LazyModel, output_type_str: str | None) -> GGMLFileType:
wq_type = model[gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.ATTN_Q].format(bid=0) +".weight"].data_type wq_type = model[gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.ATTN_Q].format(bid=0) + ".weight"].data_type
if output_type_str == "f32" or (output_type_str is None and wq_type == DT_F32): if output_type_str == "f32" or (output_type_str is None and wq_type == DT_F32):
return GGMLFileType.AllF32 return GGMLFileType.AllF32

View File

@ -739,7 +739,7 @@ bool clip_image_preprocess(const clip_ctx * ctx, const clip_image_u8 * img, clip
temp->ny = longer_side; temp->ny = longer_side;
temp->size = 3 * longer_side * longer_side; temp->size = 3 * longer_side * longer_side;
temp->data = new uint8_t[temp->size](); temp->data = new uint8_t[temp->size]();
uint8_t bc[3] = {122, 116, 104}; // bakground color in RGB from LLaVA uint8_t bc[3] = {122, 116, 104}; // background color in RGB from LLaVA
// fill with background color // fill with background color
for (size_t i = 0; i < temp->size; i++) { for (size_t i = 0; i < temp->size; i++) {

View File

@ -51,7 +51,7 @@ def bytes_to_unicode():
The reversible bpe codes work on unicode strings. The reversible bpe codes work on unicode strings.
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
This is a signficant percentage of your normal, say, 32K bpe vocab. This is a significant percentage of your normal, say, 32K bpe vocab.
To avoid that, we want lookup tables between utf-8 bytes and unicode strings. To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
And avoids mapping to whitespace/control characters the bpe code barfs on. And avoids mapping to whitespace/control characters the bpe code barfs on.
""" """

View File

@ -1,6 +1,6 @@
# llama.cpp/examples/lookahead # llama.cpp/examples/lookahead
Demonstartion of lookahead decoding technique: Demonstration of lookahead decoding technique:
https://lmsys.org/blog/2023-11-21-lookahead-decoding/ https://lmsys.org/blog/2023-11-21-lookahead-decoding/

View File

@ -11227,7 +11227,7 @@ class binary_reader
} }
if (is_ndarray) // ndarray dimensional vector can only contain integers, and can not embed another array if (is_ndarray) // ndarray dimensional vector can only contain integers, and can not embed another array
{ {
return sax->parse_error(chars_read, get_token_string(), parse_error::create(113, chars_read, exception_message(input_format, "ndarray dimentional vector is not allowed", "size"), nullptr)); return sax->parse_error(chars_read, get_token_string(), parse_error::create(113, chars_read, exception_message(input_format, "ndarray dimensional vector is not allowed", "size"), nullptr));
} }
std::vector<size_t> dim; std::vector<size_t> dim;
if (JSON_HEDLEY_UNLIKELY(!get_ubjson_ndarray_size(dim))) if (JSON_HEDLEY_UNLIKELY(!get_ubjson_ndarray_size(dim)))

View File

@ -114,7 +114,7 @@ export async function* llama(prompt, params = {}, config = {}) {
return content; return content;
} }
// Call llama, return an event target that you can subcribe to // Call llama, return an event target that you can subscribe to
// //
// Example: // Example:
// //

View File

@ -223,7 +223,7 @@
repeat_last_n: 256, // 0 = disable penalty, -1 = context size repeat_last_n: 256, // 0 = disable penalty, -1 = context size
repeat_penalty: 1.18, // 1.0 = disabled repeat_penalty: 1.18, // 1.0 = disabled
top_k: 40, // <= 0 to use vocab size top_k: 40, // <= 0 to use vocab size
top_p: 0.5, // 1.0 = disabled top_p: 0.95, // 1.0 = disabled
min_p: 0.05, // 0 = disabled min_p: 0.05, // 0 = disabled
tfs_z: 1.0, // 1.0 = disabled tfs_z: 1.0, // 1.0 = disabled
typical_p: 1.0, // 1.0 = disabled typical_p: 1.0, // 1.0 = disabled
@ -238,7 +238,7 @@
cache_prompt: true cache_prompt: true
}) })
/* START: Support for storing prompt templates and parameters in borwser LocalStorage */ /* START: Support for storing prompt templates and parameters in browsers LocalStorage */
const local_storage_storageKey = "llamacpp_server_local_storage"; const local_storage_storageKey = "llamacpp_server_local_storage";
@ -282,7 +282,7 @@
let importedTemplates = local_storage_getDataAsObject('user_templates') let importedTemplates = local_storage_getDataAsObject('user_templates')
if (importedTemplates) { if (importedTemplates) {
// saved templates were successfuly imported. // saved templates were successfully imported.
console.log('Processing saved templates and updating default template') console.log('Processing saved templates and updating default template')
params.value = { ...params.value, image_data: [] }; params.value = { ...params.value, image_data: [] };
@ -303,7 +303,7 @@
} }
function userTemplateResetToDefault() { function userTemplateResetToDefault() {
console.log('Reseting themplate to default') console.log('Resetting template to default')
selectedUserTemplate.value.name = 'default'; selectedUserTemplate.value.name = 'default';
selectedUserTemplate.value.data = savedUserTemplates.value['default']; selectedUserTemplate.value.data = savedUserTemplates.value['default'];
} }
@ -762,7 +762,7 @@
<fieldset class="two"> <fieldset class="two">
${IntField({ label: "Predictions", max: 2048, min: -1, name: "n_predict", value: params.value.n_predict })} ${IntField({ label: "Predictions", max: 2048, min: -1, name: "n_predict", value: params.value.n_predict })}
${FloatField({ label: "Temperature", max: 1.5, min: 0.0, name: "temperature", step: 0.01, value: params.value.temperature })} ${FloatField({ label: "Temperature", max: 2.0, min: 0.0, name: "temperature", step: 0.01, value: params.value.temperature })}
${FloatField({ label: "Penalize repeat sequence", max: 2.0, min: 0.0, name: "repeat_penalty", step: 0.01, value: params.value.repeat_penalty })} ${FloatField({ label: "Penalize repeat sequence", max: 2.0, min: 0.0, name: "repeat_penalty", step: 0.01, value: params.value.repeat_penalty })}
${IntField({ label: "Consider N tokens for penalize", max: 2048, min: 0, name: "repeat_last_n", value: params.value.repeat_last_n })} ${IntField({ label: "Consider N tokens for penalize", max: 2048, min: 0, name: "repeat_last_n", value: params.value.repeat_last_n })}
${IntField({ label: "Top-K sampling", max: 100, min: -1, name: "top_k", value: params.value.top_k })} ${IntField({ label: "Top-K sampling", max: 100, min: -1, name: "top_k", value: params.value.top_k })}

View File

@ -2382,6 +2382,7 @@ json oaicompat_completion_params_parse(
llama_params["__oaicompat"] = true; llama_params["__oaicompat"] = true;
// Map OpenAI parameters to llama.cpp parameters // Map OpenAI parameters to llama.cpp parameters
llama_params["model"] = json_value(body, "model", std::string("uknown"));
llama_params["prompt"] = format_chatml(body["messages"]); // OpenAI 'messages' to llama.cpp 'prompt' llama_params["prompt"] = format_chatml(body["messages"]); // OpenAI 'messages' to llama.cpp 'prompt'
llama_params["cache_prompt"] = json_value(body, "cache_prompt", false); llama_params["cache_prompt"] = json_value(body, "cache_prompt", false);
llama_params["temperature"] = json_value(body, "temperature", 0.8); llama_params["temperature"] = json_value(body, "temperature", 0.8);

View File

@ -1,6 +1,6 @@
# llama.cpp/examples/speculative # llama.cpp/examples/speculative
Demonstartion of speculative decoding and tree-based speculative decoding techniques Demonstration of speculative decoding and tree-based speculative decoding techniques
More info: More info:

View File

@ -428,7 +428,7 @@ int main(int argc, char ** argv) {
++n_past_tgt; ++n_past_tgt;
} }
// the first token is always proposed by the traget model before the speculation loop so we erase it here // the first token is always proposed by the target model before the speculation loop so we erase it here
for (int s = 0; s < n_seq_dft; ++s) { for (int s = 0; s < n_seq_dft; ++s) {
if (!drafts[s].active) { if (!drafts[s].active) {
continue; continue;

View File

@ -43,7 +43,7 @@ GGML_API size_t ggml_allocr_alloc_graph(ggml_allocr_t alloc, struct ggml_cgraph
// ggml-backend v2 API // ggml-backend v2 API
// //
// Seperate tensor and graph allocator objects // Separate tensor and graph allocator objects
// This is necessary for multi-backend allocation because the graph allocator needs to use multiple tensor allocators // This is necessary for multi-backend allocation because the graph allocator needs to use multiple tensor allocators
// The original API is kept as a wrapper around the new API // The original API is kept as a wrapper around the new API

View File

@ -1,13 +1,15 @@
#include <algorithm> #include <algorithm>
#include <assert.h>
#include <atomic>
#include <cinttypes>
#include <cstddef> #include <cstddef>
#include <cstdint> #include <cstdint>
#include <cinttypes>
#include <float.h> #include <float.h>
#include <limits> #include <limits>
#include <stdint.h> #include <stdint.h>
#include <stdio.h> #include <stdio.h>
#include <atomic> #include <vector>
#include <assert.h>
#if defined(GGML_USE_HIPBLAS) #if defined(GGML_USE_HIPBLAS)
#include <hip/hip_runtime.h> #include <hip/hip_runtime.h>
@ -1684,31 +1686,65 @@ static __global__ void quantize_q8_1(const float * __restrict__ x, void * __rest
} }
template<int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t> template<int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
static __global__ void k_get_rows(const void * x, const int32_t * y, dst_t * dst, const int ncols) { static __global__ void k_get_rows(
const int col = (blockIdx.x*blockDim.x + threadIdx.x)*2; const void * src0, const int32_t * src1, dst_t * dst,
const int row = blockDim.y*blockIdx.y + threadIdx.y; int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/
/*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/
/*size_t s0,*/ size_t s1, size_t s2, size_t s3,
/*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03,
size_t s10, size_t s11, size_t s12/*, size_t s13*/) {
if (col >= ncols) { const int i00 = (blockIdx.x*blockDim.x + threadIdx.x)*2;
const int i10 = blockDim.y*blockIdx.y + threadIdx.y;
const int i11 = (blockIdx.z*blockDim.z + threadIdx.z)/ne12;
const int i12 = (blockIdx.z*blockDim.z + threadIdx.z)%ne12;
if (i00 >= ne00) {
return; return;
} }
const int r = y[row]; const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
// copy x[r*ncols + col] to dst[row*ncols + col] dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
const int xi = r*ncols + col; const void * src0_row = (const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03;
const int di = row*ncols + col;
const int ib = xi/qk; // block index const int ib = i00/qk; // block index
const int iqs = (xi%qk)/qr; // quant index const int iqs = (i00%qk)/qr; // quant index
const int iybs = di - di%qk; // y block start index const int iybs = i00 - i00%qk; // dst block start index
const int y_offset = qr == 1 ? 1 : qk/2; const int y_offset = qr == 1 ? 1 : qk/2;
// dequantize // dequantize
dfloat2 v; dfloat2 v;
dequantize_kernel(x, ib, iqs, v); dequantize_kernel(src0_row, ib, iqs, v);
dst[iybs + iqs + 0] = v.x; dst_row[iybs + iqs + 0] = v.x;
dst[iybs + iqs + y_offset] = v.y; dst_row[iybs + iqs + y_offset] = v.y;
}
template<typename src0_t, typename dst_t>
static __global__ void k_get_rows_float(
const src0_t * src0, const int32_t * src1, dst_t * dst,
int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/
/*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/
/*size_t s0,*/ size_t s1, size_t s2, size_t s3,
/*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03,
size_t s10, size_t s11, size_t s12/*, size_t s13*/) {
const int i00 = blockIdx.x*blockDim.x + threadIdx.x;
const int i10 = blockDim.y*blockIdx.y + threadIdx.y;
const int i11 = (blockIdx.z*blockDim.z + threadIdx.z)/ne12;
const int i12 = (blockIdx.z*blockDim.z + threadIdx.z)%ne12;
if (i00 >= ne00) {
return;
}
const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
const src0_t * src0_row = (const src0_t *)((const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03);
dst_row[i00] = src0_row[i00];
} }
template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t> template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
@ -5053,11 +5089,69 @@ static __global__ void im2col_f32_f16(
} }
template<int qk, int qr, dequantize_kernel_t dq> template<int qk, int qr, dequantize_kernel_t dq>
static void get_rows_cuda(const void * x, const int32_t * y, float * dst, const int nrows, const int ncols, cudaStream_t stream) { static void get_rows_cuda(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
const void * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) {
GGML_TENSOR_BINARY_OP_LOCALS
const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1); const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
const int block_num_x = (ncols + 2*CUDA_GET_ROWS_BLOCK_SIZE - 1) / (2*CUDA_GET_ROWS_BLOCK_SIZE); const int block_num_x = (ne00 + 2*CUDA_GET_ROWS_BLOCK_SIZE - 1) / (2*CUDA_GET_ROWS_BLOCK_SIZE);
const dim3 block_nums(block_num_x, nrows, 1); const dim3 block_nums(block_num_x, ne10, ne11*ne12);
k_get_rows<qk, qr, dq><<<block_nums, block_dims, 0, stream>>>(x, y, dst, ncols);
// strides in elements
//const size_t s0 = nb0 / ggml_element_size(dst);
const size_t s1 = nb1 / ggml_element_size(dst);
const size_t s2 = nb2 / ggml_element_size(dst);
const size_t s3 = nb3 / ggml_element_size(dst);
const size_t s10 = nb10 / ggml_element_size(src1);
const size_t s11 = nb11 / ggml_element_size(src1);
const size_t s12 = nb12 / ggml_element_size(src1);
//const size_t s13 = nb13 / ggml_element_size(src1);
GGML_ASSERT(ne00 % 2 == 0);
k_get_rows<qk, qr, dq><<<block_nums, block_dims, 0, stream>>>(
src0_dd, src1_dd, dst_dd,
ne00, /*ne01, ne02, ne03,*/
/*ne10, ne11,*/ ne12, /*ne13,*/
/* s0,*/ s1, s2, s3,
/* nb00,*/ nb01, nb02, nb03,
s10, s11, s12/*, s13*/);
(void) dst;
}
template<typename src0_t>
static void get_rows_cuda_float(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
const src0_t * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) {
GGML_TENSOR_BINARY_OP_LOCALS
const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
const int block_num_x = (ne00 + CUDA_GET_ROWS_BLOCK_SIZE - 1) / CUDA_GET_ROWS_BLOCK_SIZE;
const dim3 block_nums(block_num_x, ne10, ne11*ne12);
// strides in elements
//const size_t s0 = nb0 / ggml_element_size(dst);
const size_t s1 = nb1 / ggml_element_size(dst);
const size_t s2 = nb2 / ggml_element_size(dst);
const size_t s3 = nb3 / ggml_element_size(dst);
const size_t s10 = nb10 / ggml_element_size(src1);
const size_t s11 = nb11 / ggml_element_size(src1);
const size_t s12 = nb12 / ggml_element_size(src1);
//const size_t s13 = nb13 / ggml_element_size(src1);
k_get_rows_float<<<block_nums, block_dims, 0, stream>>>(
src0_dd, src1_dd, dst_dd,
ne00, /*ne01, ne02, ne03,*/
/*ne10, ne11,*/ ne12, /*ne13,*/
/* s0,*/ s1, s2, s3,
/* nb00,*/ nb01, nb02, nb03,
s10, s11, s12/*, s13*/);
(void) dst;
} }
template<float (*bin_op)(const float, const float)> template<float (*bin_op)(const float, const float)>
@ -5069,7 +5163,6 @@ struct bin_bcast_cuda {
GGML_TENSOR_BINARY_OP_LOCALS GGML_TENSOR_BINARY_OP_LOCALS
int nr0 = ne10/ne0; int nr0 = ne10/ne0;
int nr1 = ne11/ne1; int nr1 = ne11/ne1;
int nr2 = ne12/ne2; int nr2 = ne12/ne2;
@ -5117,26 +5210,28 @@ struct bin_bcast_cuda {
int64_t ne12 = cne1[2]; int64_t ne12 = cne1[2];
int64_t ne13 = cne1[3]; int64_t ne13 = cne1[3];
//size_t nb0 = cnb0[0]; size_t nb0 = cnb0[0];
size_t nb1 = cnb0[1]; size_t nb1 = cnb0[1];
size_t nb2 = cnb0[2]; size_t nb2 = cnb0[2];
size_t nb3 = cnb0[3]; size_t nb3 = cnb0[3];
//size_t nb10 = cnb1[0]; size_t nb10 = cnb1[0];
size_t nb11 = cnb1[1]; size_t nb11 = cnb1[1];
size_t nb12 = cnb1[2]; size_t nb12 = cnb1[2];
size_t nb13 = cnb1[3]; size_t nb13 = cnb1[3];
//size_t s0 = nb0 / sizeof(src1_t); size_t s0 = nb0 / sizeof(src1_t);
size_t s1 = nb1 / sizeof(src1_t); size_t s1 = nb1 / sizeof(src1_t);
size_t s2 = nb2 / sizeof(src1_t); size_t s2 = nb2 / sizeof(src1_t);
size_t s3 = nb3 / sizeof(src1_t); size_t s3 = nb3 / sizeof(src1_t);
//size_t s10 = nb10 / sizeof(src1_t); size_t s10 = nb10 / sizeof(src1_t);
size_t s11 = nb11 / sizeof(src1_t); size_t s11 = nb11 / sizeof(src1_t);
size_t s12 = nb12 / sizeof(src1_t); size_t s12 = nb12 / sizeof(src1_t);
size_t s13 = nb13 / sizeof(src1_t); size_t s13 = nb13 / sizeof(src1_t);
GGML_ASSERT(s0 == 1);
GGML_ASSERT(s10 == 1);
const int block_size = 128; const int block_size = 128;
@ -6447,36 +6542,34 @@ static void ggml_cuda_op_get_rows(
GGML_ASSERT(src1->type == GGML_TYPE_I32); GGML_ASSERT(src1->type == GGML_TYPE_I32);
GGML_ASSERT(dst->type == GGML_TYPE_F32); GGML_ASSERT(dst->type == GGML_TYPE_F32);
GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(ggml_is_contiguous(src1));
GGML_ASSERT(ggml_is_contiguous(dst));
const int ncols = src0->ne[0]; GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
const int nrows = ggml_nelements(src1); GGML_ASSERT(src1->nb[0] == ggml_type_size(src1->type));
GGML_ASSERT(dst->nb[0] == ggml_type_size(dst->type));
const int32_t * src1_i32 = (const int32_t *) src1_d; const int32_t * src1_i32 = (const int32_t *) src1_d;
switch (src0->type) { switch (src0->type) {
case GGML_TYPE_F16: case GGML_TYPE_F16:
get_rows_cuda<1, 1, convert_f16>(src0_d, src1_i32, dst_d, nrows, ncols, stream); get_rows_cuda_float(src0, src1, dst, (const half *)src0_d, src1_i32, dst_d, stream);
break; break;
case GGML_TYPE_F32: case GGML_TYPE_F32:
get_rows_cuda<1, 1, convert_f32>(src0_d, src1_i32, dst_d, nrows, ncols, stream); get_rows_cuda_float(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
break; break;
case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_0:
get_rows_cuda<QK4_0, QR4_0, dequantize_q4_0>(src0_d, src1_i32, dst_d, nrows, ncols, stream); get_rows_cuda<QK4_0, QR4_0, dequantize_q4_0>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
break; break;
case GGML_TYPE_Q4_1: case GGML_TYPE_Q4_1:
get_rows_cuda<QK4_1, QR4_1, dequantize_q4_1>(src0_d, src1_i32, dst_d, nrows, ncols, stream); get_rows_cuda<QK4_1, QR4_1, dequantize_q4_1>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
break; break;
case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_0:
get_rows_cuda<QK5_0, QR5_0, dequantize_q5_0>(src0_d, src1_i32, dst_d, nrows, ncols, stream); get_rows_cuda<QK5_0, QR5_0, dequantize_q5_0>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
break; break;
case GGML_TYPE_Q5_1: case GGML_TYPE_Q5_1:
get_rows_cuda<QK5_1, QR5_1, dequantize_q5_1>(src0_d, src1_i32, dst_d, nrows, ncols, stream); get_rows_cuda<QK5_1, QR5_1, dequantize_q5_1>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
break; break;
case GGML_TYPE_Q8_0: case GGML_TYPE_Q8_0:
get_rows_cuda<QK8_0, QR8_0, dequantize_q8_0>(src0_d, src1_i32, dst_d, nrows, ncols, stream); get_rows_cuda<QK8_0, QR8_0, dequantize_q8_0>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
break; break;
default: default:
// TODO: k-quants // TODO: k-quants
@ -8234,36 +8327,69 @@ static void ggml_cuda_mul_mat_id_cublas(ggml_tensor * dst) {
} }
#endif #endif
static void ggml_cuda_mul_mat_id(const ggml_tensor * _src0, const ggml_tensor * _src1, ggml_tensor * dst) { static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
#if 0 #if 0
//#ifdef CUDA_USE_TENSOR_CORES
// const bool use_tensor_cores = true;
//#else
// const bool use_tensor_cores = false;
//#endif
ggml_cuda_mul_mat_id_cublas(dst); ggml_cuda_mul_mat_id_cublas(dst);
// TODO: mmq/mmv support // TODO: mmq/mmv support
#else
const struct ggml_tensor * ids = dst->src[0];
const struct ggml_tensor * src1 = dst->src[1];
const int id = dst->op_params[0];
int32_t * ids_dev = (int32_t *)((ggml_tensor_extra_gpu *)ids->extra)->data_device[g_main_device];
int32_t a_id;
CUDA_CHECK(cudaMemcpyAsync(&a_id, ids_dev + id, sizeof(int32_t), cudaMemcpyDeviceToHost, g_cudaStreams[g_main_device][0]));
CUDA_CHECK(cudaStreamSynchronize(g_cudaStreams[g_main_device][0]));
GGML_ASSERT(a_id >= 0 && a_id < ids->ne[0]);
const struct ggml_tensor * src0 = dst->src[a_id + 2];
ggml_cuda_mul_mat(src0, src1, dst);
#endif #endif
(void) _src0; GGML_ASSERT(dst->backend == GGML_BACKEND_GPU);
(void) _src1;
const struct ggml_tensor * ids = src0;
const int32_t id = ((int32_t *) dst->op_params)[0];
const int32_t n_as = ((int32_t *) dst->op_params)[1];
std::vector<char> ids_host(ggml_nbytes(ids));
if (ids->backend == GGML_BACKEND_GPU) {
const char * ids_dev = (const char *)((const ggml_tensor_extra_gpu *)ids->extra)->data_device[g_main_device];
CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids_dev, ggml_nbytes(ids), cudaMemcpyDeviceToHost, g_cudaStreams[g_main_device][0]));
CUDA_CHECK(cudaStreamSynchronize(g_cudaStreams[g_main_device][0]));
} else {
memcpy(ids_host.data(), ids->data, ggml_nbytes(ids));
}
const ggml_tensor_extra_gpu * src1_extra = (const ggml_tensor_extra_gpu *) src1->extra;
const ggml_tensor_extra_gpu * dst_extra = (const ggml_tensor_extra_gpu *) dst->extra;
ggml_tensor_extra_gpu src1_row_extra;
ggml_tensor_extra_gpu dst_row_extra;
ggml_tensor src1_row = *src1;
ggml_tensor dst_row = *dst;
src1_row.ne[1] = 1;
dst_row.ne[1] = 1;
src1_row.nb[2] = src1_row.nb[1];
dst_row.nb[2] = dst_row.nb[1];
src1_row.nb[3] = src1_row.nb[1];
dst_row.nb[3] = dst_row.nb[1];
src1_row.extra = &src1_row_extra;
dst_row.extra = &dst_row_extra;
for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
//int32_t row_id;
//CUDA_CHECK(cudaMemcpyAsync(&row_id, ids_dev + i01*ids->nb[1] + id*ids->nb[0], sizeof(int32_t), cudaMemcpyDeviceToHost, g_cudaStreams[g_main_device][0]));
//CUDA_CHECK(cudaStreamSynchronize(g_cudaStreams[g_main_device][0]));
const int32_t row_id = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]);
GGML_ASSERT(row_id >= 0 && row_id < n_as);
const struct ggml_tensor * src0_row = dst->src[row_id + 2];
src1_row_extra.data_device[g_main_device] = (char *) src1_extra->data_device[g_main_device] + i01*src1->nb[1];
src1_row.data = (char *) src1->data + i01*src1->nb[1];
dst_row_extra.data_device[g_main_device] = (char *) dst_extra->data_device[g_main_device] + i01*dst->nb[1];
dst_row.data = (char *) dst->data + i01*dst->nb[1];
ggml_cuda_mul_mat(src0_row, &src1_row, &dst_row);
}
} }
static void ggml_cuda_scale(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { static void ggml_cuda_scale(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
@ -9181,6 +9307,45 @@ static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, const ggml_ten
} }
return true; return true;
} break; } break;
case GGML_OP_GET_ROWS:
{
switch (op->src[0]->type) {
case GGML_TYPE_F16:
case GGML_TYPE_F32:
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
return true;
default:
return false;
}
} break;
case GGML_OP_CPY:
{
ggml_type src0_type = op->src[0]->type;
ggml_type src1_type = op->src[1]->type;
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
return true;
}
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F16) {
return true;
}
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q8_0) {
return true;
}
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q4_0) {
return true;
}
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q4_1) {
return true;
}
if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {
return true;
}
return false;
} break;
case GGML_OP_NONE: case GGML_OP_NONE:
case GGML_OP_RESHAPE: case GGML_OP_RESHAPE:
case GGML_OP_VIEW: case GGML_OP_VIEW:
@ -9188,7 +9353,6 @@ static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, const ggml_ten
case GGML_OP_TRANSPOSE: case GGML_OP_TRANSPOSE:
case GGML_OP_NORM: case GGML_OP_NORM:
case GGML_OP_REPEAT: case GGML_OP_REPEAT:
case GGML_OP_GET_ROWS:
case GGML_OP_DUP: case GGML_OP_DUP:
case GGML_OP_ADD: case GGML_OP_ADD:
case GGML_OP_MUL: case GGML_OP_MUL:
@ -9197,7 +9361,6 @@ static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, const ggml_ten
case GGML_OP_SCALE: case GGML_OP_SCALE:
case GGML_OP_SQR: case GGML_OP_SQR:
case GGML_OP_CLAMP: case GGML_OP_CLAMP:
case GGML_OP_CPY:
case GGML_OP_CONT: case GGML_OP_CONT:
case GGML_OP_DIAG_MASK_INF: case GGML_OP_DIAG_MASK_INF:
case GGML_OP_SOFT_MAX: case GGML_OP_SOFT_MAX:
@ -9264,7 +9427,9 @@ static ggml_backend_t ggml_backend_reg_cuda_init(const char * params, void * use
UNUSED(params); UNUSED(params);
} }
extern "C" int ggml_backend_cuda_reg_devices() { extern "C" int ggml_backend_cuda_reg_devices();
int ggml_backend_cuda_reg_devices() {
int device_count = ggml_cuda_get_device_count(); int device_count = ggml_cuda_get_device_count();
//int device_count = 1; // DEBUG: some tools require delaying CUDA initialization //int device_count = 1; // DEBUG: some tools require delaying CUDA initialization
for (int i = 0; i < device_count; i++) { for (int i = 0; i < device_count; i++) {

View File

@ -102,6 +102,21 @@ struct ggml_metal_context {
GGML_METAL_DECL_KERNEL(mul_mv_q4_K_f32); GGML_METAL_DECL_KERNEL(mul_mv_q4_K_f32);
GGML_METAL_DECL_KERNEL(mul_mv_q5_K_f32); GGML_METAL_DECL_KERNEL(mul_mv_q5_K_f32);
GGML_METAL_DECL_KERNEL(mul_mv_q6_K_f32); GGML_METAL_DECL_KERNEL(mul_mv_q6_K_f32);
GGML_METAL_DECL_KERNEL(mul_mv_id_f32_f32);
//GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f16);
GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32);
//GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32_1row);
//GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32_l4);
GGML_METAL_DECL_KERNEL(mul_mv_id_q4_0_f32);
GGML_METAL_DECL_KERNEL(mul_mv_id_q4_1_f32);
GGML_METAL_DECL_KERNEL(mul_mv_id_q5_0_f32);
GGML_METAL_DECL_KERNEL(mul_mv_id_q5_1_f32);
GGML_METAL_DECL_KERNEL(mul_mv_id_q8_0_f32);
GGML_METAL_DECL_KERNEL(mul_mv_id_q2_K_f32);
GGML_METAL_DECL_KERNEL(mul_mv_id_q3_K_f32);
GGML_METAL_DECL_KERNEL(mul_mv_id_q4_K_f32);
GGML_METAL_DECL_KERNEL(mul_mv_id_q5_K_f32);
GGML_METAL_DECL_KERNEL(mul_mv_id_q6_K_f32);
GGML_METAL_DECL_KERNEL(mul_mm_f32_f32); GGML_METAL_DECL_KERNEL(mul_mm_f32_f32);
GGML_METAL_DECL_KERNEL(mul_mm_f16_f32); GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32); GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
@ -140,6 +155,7 @@ struct ggml_metal_context {
//GGML_METAL_DECL_KERNEL(cpy_f32_q5_0); //GGML_METAL_DECL_KERNEL(cpy_f32_q5_0);
//GGML_METAL_DECL_KERNEL(cpy_f32_q5_1); //GGML_METAL_DECL_KERNEL(cpy_f32_q5_1);
GGML_METAL_DECL_KERNEL(cpy_f16_f16); GGML_METAL_DECL_KERNEL(cpy_f16_f16);
GGML_METAL_DECL_KERNEL(cpy_f16_f32);
GGML_METAL_DECL_KERNEL(concat); GGML_METAL_DECL_KERNEL(concat);
GGML_METAL_DECL_KERNEL(sqr); GGML_METAL_DECL_KERNEL(sqr);
GGML_METAL_DECL_KERNEL(sum_rows); GGML_METAL_DECL_KERNEL(sum_rows);
@ -177,6 +193,8 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
ggml_metal_log_callback(level, buffer, ggml_metal_log_user_data); ggml_metal_log_callback(level, buffer, ggml_metal_log_user_data);
} else { } else {
char* buffer2 = malloc(len+1); char* buffer2 = malloc(len+1);
va_end(args);
va_start(args, format);
vsnprintf(buffer2, len+1, format, args); vsnprintf(buffer2, len+1, format, args);
buffer2[len] = 0; buffer2[len] = 0;
ggml_metal_log_callback(level, buffer2, ggml_metal_log_user_data); ggml_metal_log_callback(level, buffer2, ggml_metal_log_user_data);
@ -352,6 +370,21 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(mul_mv_q4_K_f32); GGML_METAL_ADD_KERNEL(mul_mv_q4_K_f32);
GGML_METAL_ADD_KERNEL(mul_mv_q5_K_f32); GGML_METAL_ADD_KERNEL(mul_mv_q5_K_f32);
GGML_METAL_ADD_KERNEL(mul_mv_q6_K_f32); GGML_METAL_ADD_KERNEL(mul_mv_q6_K_f32);
GGML_METAL_ADD_KERNEL(mul_mv_id_f32_f32);
//GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f16);
GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f32);
//GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f32_1row);
//GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f32_l4);
GGML_METAL_ADD_KERNEL(mul_mv_id_q4_0_f32);
GGML_METAL_ADD_KERNEL(mul_mv_id_q4_1_f32);
GGML_METAL_ADD_KERNEL(mul_mv_id_q5_0_f32);
GGML_METAL_ADD_KERNEL(mul_mv_id_q5_1_f32);
GGML_METAL_ADD_KERNEL(mul_mv_id_q8_0_f32);
GGML_METAL_ADD_KERNEL(mul_mv_id_q2_K_f32);
GGML_METAL_ADD_KERNEL(mul_mv_id_q3_K_f32);
GGML_METAL_ADD_KERNEL(mul_mv_id_q4_K_f32);
GGML_METAL_ADD_KERNEL(mul_mv_id_q5_K_f32);
GGML_METAL_ADD_KERNEL(mul_mv_id_q6_K_f32);
if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) { if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
GGML_METAL_ADD_KERNEL(mul_mm_f32_f32); GGML_METAL_ADD_KERNEL(mul_mm_f32_f32);
GGML_METAL_ADD_KERNEL(mul_mm_f16_f32); GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
@ -392,6 +425,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
//GGML_METAL_ADD_KERNEL(cpy_f32_q5_0); //GGML_METAL_ADD_KERNEL(cpy_f32_q5_0);
//GGML_METAL_ADD_KERNEL(cpy_f32_q5_1); //GGML_METAL_ADD_KERNEL(cpy_f32_q5_1);
GGML_METAL_ADD_KERNEL(cpy_f16_f16); GGML_METAL_ADD_KERNEL(cpy_f16_f16);
GGML_METAL_ADD_KERNEL(cpy_f16_f32);
GGML_METAL_ADD_KERNEL(concat); GGML_METAL_ADD_KERNEL(concat);
GGML_METAL_ADD_KERNEL(sqr); GGML_METAL_ADD_KERNEL(sqr);
GGML_METAL_ADD_KERNEL(sum_rows); GGML_METAL_ADD_KERNEL(sum_rows);
@ -452,6 +486,21 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
GGML_METAL_DEL_KERNEL(mul_mv_q4_K_f32); GGML_METAL_DEL_KERNEL(mul_mv_q4_K_f32);
GGML_METAL_DEL_KERNEL(mul_mv_q5_K_f32); GGML_METAL_DEL_KERNEL(mul_mv_q5_K_f32);
GGML_METAL_DEL_KERNEL(mul_mv_q6_K_f32); GGML_METAL_DEL_KERNEL(mul_mv_q6_K_f32);
GGML_METAL_DEL_KERNEL(mul_mv_id_f32_f32);
//GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f16);
GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32);
//GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32_1row);
//GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32_l4);
GGML_METAL_DEL_KERNEL(mul_mv_id_q4_0_f32);
GGML_METAL_DEL_KERNEL(mul_mv_id_q4_1_f32);
GGML_METAL_DEL_KERNEL(mul_mv_id_q5_0_f32);
GGML_METAL_DEL_KERNEL(mul_mv_id_q5_1_f32);
GGML_METAL_DEL_KERNEL(mul_mv_id_q8_0_f32);
GGML_METAL_DEL_KERNEL(mul_mv_id_q2_K_f32);
GGML_METAL_DEL_KERNEL(mul_mv_id_q3_K_f32);
GGML_METAL_DEL_KERNEL(mul_mv_id_q4_K_f32);
GGML_METAL_DEL_KERNEL(mul_mv_id_q5_K_f32);
GGML_METAL_DEL_KERNEL(mul_mv_id_q6_K_f32);
if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) { if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
GGML_METAL_DEL_KERNEL(mul_mm_f32_f32); GGML_METAL_DEL_KERNEL(mul_mm_f32_f32);
GGML_METAL_DEL_KERNEL(mul_mm_f16_f32); GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
@ -492,6 +541,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
//GGML_METAL_DEL_KERNEL(cpy_f32_q5_0); //GGML_METAL_DEL_KERNEL(cpy_f32_q5_0);
//GGML_METAL_DEL_KERNEL(cpy_f32_q5_1); //GGML_METAL_DEL_KERNEL(cpy_f32_q5_1);
GGML_METAL_DEL_KERNEL(cpy_f16_f16); GGML_METAL_DEL_KERNEL(cpy_f16_f16);
GGML_METAL_DEL_KERNEL(cpy_f16_f32);
GGML_METAL_DEL_KERNEL(concat); GGML_METAL_DEL_KERNEL(concat);
GGML_METAL_DEL_KERNEL(sqr); GGML_METAL_DEL_KERNEL(sqr);
GGML_METAL_DEL_KERNEL(sum_rows); GGML_METAL_DEL_KERNEL(sum_rows);
@ -803,8 +853,9 @@ static bool ggml_metal_supports_op(const struct ggml_tensor * op) {
case GGML_OP_NONE: case GGML_OP_NONE:
case GGML_OP_RESHAPE: case GGML_OP_RESHAPE:
case GGML_OP_VIEW: case GGML_OP_VIEW:
case GGML_OP_TRANSPOSE:
case GGML_OP_PERMUTE: case GGML_OP_PERMUTE:
case GGML_OP_TRANSPOSE:
case GGML_OP_GET_ROWS:
case GGML_OP_CONCAT: case GGML_OP_CONCAT:
case GGML_OP_ADD: case GGML_OP_ADD:
case GGML_OP_MUL: case GGML_OP_MUL:
@ -819,14 +870,38 @@ static bool ggml_metal_supports_op(const struct ggml_tensor * op) {
case GGML_OP_ROPE: case GGML_OP_ROPE:
case GGML_OP_IM2COL: case GGML_OP_IM2COL:
case GGML_OP_ARGSORT: case GGML_OP_ARGSORT:
case GGML_OP_DUP:
case GGML_OP_CPY:
case GGML_OP_CONT:
case GGML_OP_MUL_MAT: case GGML_OP_MUL_MAT:
case GGML_OP_MUL_MAT_ID: case GGML_OP_MUL_MAT_ID:
return true; return true;
case GGML_OP_CPY:
case GGML_OP_DUP:
case GGML_OP_CONT:
{
switch (op->src[0]->type) {
case GGML_TYPE_F32:
switch (op->type) {
case GGML_TYPE_F16:
case GGML_TYPE_F32:
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
return true;
default:
return false;
}
case GGML_TYPE_F16:
switch (op->type) {
case GGML_TYPE_F16:
case GGML_TYPE_F32:
return true;
default:
return false;
}
default:
return false;
};
}
case GGML_OP_DIAG_MASK_INF: case GGML_OP_DIAG_MASK_INF:
case GGML_OP_GET_ROWS:
{ {
return op->ne[0] % 4 == 0; return op->ne[0] % 4 == 0;
} }
@ -1001,34 +1076,37 @@ void ggml_metal_graph_compute(
case GGML_OP_MUL: case GGML_OP_MUL:
case GGML_OP_DIV: case GGML_OP_DIV:
{ {
GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(ggml_is_contiguous(src1));
bool bcast_row = false; bool bcast_row = false;
int64_t nb = ne00; int64_t nb = ne00;
if (ggml_nelements(src1) == ne10 && ne00 % 4 == 0) { id<MTLComputePipelineState> pipeline = nil;
if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
GGML_ASSERT(ggml_is_contiguous(src0));
// src1 is a row // src1 is a row
GGML_ASSERT(ne11 == 1); GGML_ASSERT(ne11 == 1);
nb = ne00 / 4; nb = ne00 / 4;
switch (dst->op) { switch (dst->op) {
case GGML_OP_ADD: [encoder setComputePipelineState:ctx->pipeline_add_row]; break; case GGML_OP_ADD: pipeline = ctx->pipeline_add_row; break;
case GGML_OP_MUL: [encoder setComputePipelineState:ctx->pipeline_mul_row]; break; case GGML_OP_MUL: pipeline = ctx->pipeline_mul_row; break;
case GGML_OP_DIV: [encoder setComputePipelineState:ctx->pipeline_div_row]; break; case GGML_OP_DIV: pipeline = ctx->pipeline_div_row; break;
default: GGML_ASSERT(false); default: GGML_ASSERT(false);
} }
bcast_row = true; bcast_row = true;
} else { } else {
switch (dst->op) { switch (dst->op) {
case GGML_OP_ADD: [encoder setComputePipelineState:ctx->pipeline_add]; break; case GGML_OP_ADD: pipeline = ctx->pipeline_add; break;
case GGML_OP_MUL: [encoder setComputePipelineState:ctx->pipeline_mul]; break; case GGML_OP_MUL: pipeline = ctx->pipeline_mul; break;
case GGML_OP_DIV: [encoder setComputePipelineState:ctx->pipeline_div]; break; case GGML_OP_DIV: pipeline = ctx->pipeline_div; break;
default: GGML_ASSERT(false); default: GGML_ASSERT(false);
} }
} }
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2]; [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
@ -1063,7 +1141,7 @@ void ggml_metal_graph_compute(
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} else { } else {
const int nth = MIN(1024, ne0); const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
} }
@ -1193,7 +1271,11 @@ void ggml_metal_graph_compute(
const float scale = ((float *) dst->op_params)[0]; const float scale = ((float *) dst->op_params)[0];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; if (id_src1) {
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
} else {
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
}
[encoder setBuffer:id_dst offset:offs_dst atIndex:2]; [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
@ -1444,7 +1526,7 @@ void ggml_metal_graph_compute(
else if (src0t == GGML_TYPE_Q6_K) { else if (src0t == GGML_TYPE_Q6_K) {
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
} else { } else {
int64_t ny = (ne11 + nrows - 1)/nrows; const int64_t ny = (ne11 + nrows - 1)/nrows;
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
} }
} }
@ -1456,7 +1538,7 @@ void ggml_metal_graph_compute(
GGML_ASSERT(src0t == GGML_TYPE_I32); GGML_ASSERT(src0t == GGML_TYPE_I32);
const int n_as = ne00; const int n_as = ((int32_t *) dst->op_params)[1];
// TODO: make this more general // TODO: make this more general
GGML_ASSERT(n_as <= 8); GGML_ASSERT(n_as <= 8);
@ -1488,14 +1570,22 @@ void ggml_metal_graph_compute(
// find the break-even point where the matrix-matrix kernel becomes more efficient compared // find the break-even point where the matrix-matrix kernel becomes more efficient compared
// to the matrix-vector kernel // to the matrix-vector kernel
int ne11_mm_min = 0; int ne11_mm_min = 1;
const int idx = ((int32_t *) dst->op_params)[0]; const int idx = ((int32_t *) dst->op_params)[0];
// batch size
GGML_ASSERT(ne01 == ne11);
const int64_t _ne1 = 1; // kernel_mul_mm_impl needs a reference in constant memory
// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
// AMD GPU and older A-chips will reuse matrix-vector multiplication kernel // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
if ([ctx->device supportsFamily:MTLGPUFamilyApple7] && // !!!
ne11 > ne11_mm_min) { // TODO: for now, always use mat-vec kernels until we figure out how to improve the
// indirect matrix multiplication
// !!!
if ([ctx->device supportsFamily:MTLGPUFamilyApple7] && _ne1 > ne11_mm_min) {
switch (src2->type) { switch (src2->type) {
case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f32_f32]; break; case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f32_f32]; break;
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f16_f32]; break; case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f16_f32]; break;
@ -1514,19 +1604,22 @@ void ggml_metal_graph_compute(
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2]; [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder setBytes:&ne20 length:sizeof(ne20) atIndex:3]; [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:3];
[encoder setBytes:&ne22 length:sizeof(ne22) atIndex:4]; [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
[encoder setBytes:&nb21 length:sizeof(nb21) atIndex:5]; [encoder setBytes:&ne22 length:sizeof(ne22) atIndex:5];
[encoder setBytes:&nb22 length:sizeof(nb22) atIndex:6]; [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6];
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7]; [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:7];
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:8]; [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:8];
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:9]; [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:9];
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:10]; [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:10];
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:11]; [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:11];
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:12]; [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:12];
[encoder setBytes:&r2 length:sizeof(r2) atIndex:13]; [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13];
[encoder setBytes:&r3 length:sizeof(r3) atIndex:14]; [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:14];
[encoder setBytes:&idx length:sizeof(idx) atIndex:15]; [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
[encoder setBytes:&r2 length:sizeof(r2) atIndex:16];
[encoder setBytes:&r3 length:sizeof(r3) atIndex:17];
[encoder setBytes:&idx length:sizeof(idx) atIndex:18];
// TODO: how to make this an array? read Metal docs // TODO: how to make this an array? read Metal docs
for (int j = 0; j < n_as; ++j) { for (int j = 0; j < n_as; ++j) {
struct ggml_tensor * src_cur = dst->src[2 + j]; struct ggml_tensor * src_cur = dst->src[2 + j];
@ -1534,11 +1627,157 @@ void ggml_metal_graph_compute(
size_t offs_src_cur = 0; size_t offs_src_cur = 0;
id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur); id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur);
[encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:16 + j]; [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:19 + j];
} }
[encoder setThreadgroupMemoryLength:8192 atIndex:0]; [encoder setThreadgroupMemoryLength:8192 atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne21 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
// TODO: processing one row at a time (ne11 -> 1) is not efficient
[encoder dispatchThreadgroups:MTLSizeMake( (_ne1 + 31)/32, (ne21 + 63)/64, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
} else {
int nth0 = 32;
int nth1 = 1;
int nrows = 1;
//printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
// use custom matrix x vector kernel
switch (src2t) {
case GGML_TYPE_F32:
{
GGML_ASSERT(src1t == GGML_TYPE_F32);
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_f32_f32];
} break;
case GGML_TYPE_F16:
{
GGML_ASSERT(src1t == GGML_TYPE_F32);
nth0 = 32;
nth1 = 1;
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_f16_f32];
} break;
case GGML_TYPE_Q4_0:
{
nth0 = 8;
nth1 = 8;
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q4_0_f32];
} break;
case GGML_TYPE_Q4_1:
{
nth0 = 8;
nth1 = 8;
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q4_1_f32];
} break;
case GGML_TYPE_Q5_0:
{
nth0 = 8;
nth1 = 8;
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q5_0_f32];
} break;
case GGML_TYPE_Q5_1:
{
nth0 = 8;
nth1 = 8;
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q5_1_f32];
} break;
case GGML_TYPE_Q8_0:
{
nth0 = 8;
nth1 = 8;
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q8_0_f32];
} break;
case GGML_TYPE_Q2_K:
{
nth0 = 2;
nth1 = 32;
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q2_K_f32];
} break;
case GGML_TYPE_Q3_K:
{
nth0 = 2;
nth1 = 32;
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q3_K_f32];
} break;
case GGML_TYPE_Q4_K:
{
nth0 = 4; //1;
nth1 = 8; //32;
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q4_K_f32];
} break;
case GGML_TYPE_Q5_K:
{
nth0 = 2;
nth1 = 32;
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q5_K_f32];
} break;
case GGML_TYPE_Q6_K:
{
nth0 = 2;
nth1 = 32;
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q6_K_f32];
} break;
default:
{
GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t);
GGML_ASSERT(false && "not implemented");
}
};
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:3];
[encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
[encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5];
[encoder setBytes:&ne22 length:sizeof(ne22) atIndex:6];
[encoder setBytes:&nb20 length:sizeof(nb20) atIndex:7];
[encoder setBytes:&nb21 length:sizeof(nb21) atIndex:8];
[encoder setBytes:&nb22 length:sizeof(nb22) atIndex:9];
[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
[encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:11];
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
[encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:17];
[encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:18];
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:19];
[encoder setBytes:&r2 length:sizeof(r2) atIndex:20];
[encoder setBytes:&r3 length:sizeof(r3) atIndex:21];
[encoder setBytes:&idx length:sizeof(idx) atIndex:22];
// TODO: how to make this an array? read Metal docs
for (int j = 0; j < n_as; ++j) {
struct ggml_tensor * src_cur = dst->src[2 + j];
size_t offs_src_cur = 0;
id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur);
[encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:23 + j];
}
if (src2t == GGML_TYPE_Q4_0 || src2t == GGML_TYPE_Q4_1 ||
src2t == GGML_TYPE_Q5_0 || src2t == GGML_TYPE_Q5_1 || src2t == GGML_TYPE_Q8_0 ||
src2t == GGML_TYPE_Q2_K) { // || src2t == GGML_TYPE_Q4_K) {
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
}
else if (src2t == GGML_TYPE_Q4_K) {
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
}
else if (src2t == GGML_TYPE_Q3_K) {
#ifdef GGML_QKK_64
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 1)/2, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
#else
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
#endif
}
else if (src2t == GGML_TYPE_Q5_K) {
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
}
else if (src2t == GGML_TYPE_Q6_K) {
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 1)/2, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
} else {
const int64_t ny = (_ne1 + nrows - 1)/nrows;
[encoder dispatchThreadgroups:MTLSizeMake(ne21, ny, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
}
} }
} break; } break;
case GGML_OP_GET_ROWS: case GGML_OP_GET_ROWS:
@ -1559,16 +1798,19 @@ void ggml_metal_graph_compute(
default: GGML_ASSERT(false && "not implemented"); default: GGML_ASSERT(false && "not implemented");
} }
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2]; [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3]; [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4]; [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4];
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:5]; [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:5];
[encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:6];
[encoder setBytes:&nb10 length:sizeof( int64_t) atIndex:7];
[encoder setBytes:&nb11 length:sizeof( int64_t) atIndex:8];
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:9];
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:10];
const int64_t n = ggml_nelements(src1); [encoder dispatchThreadgroups:MTLSizeMake(ne10, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break; } break;
case GGML_OP_RMS_NORM: case GGML_OP_RMS_NORM:
{ {
@ -1813,7 +2055,7 @@ void ggml_metal_graph_compute(
{ {
switch (dstt) { switch (dstt) {
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f16_f16]; break; case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f16_f16]; break;
case GGML_TYPE_F32: GGML_ASSERT(false && "cpy_f16_f32 not implemented"); break; case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_cpy_f16_f32]; break;
default: GGML_ASSERT(false && "not implemented"); default: GGML_ASSERT(false && "not implemented");
}; };
} break; } break;

File diff suppressed because it is too large Load Diff

View File

@ -3114,7 +3114,7 @@ void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * restri
size_t vl = __riscv_vsetvl_e8m1(qk/2); size_t vl = __riscv_vsetvl_e8m1(qk/2);
// These tempory registers are for masking and shift operations // These temporary registers are for masking and shift operations
vuint32m2_t vt_1 = __riscv_vid_v_u32m2(vl); vuint32m2_t vt_1 = __riscv_vid_v_u32m2(vl);
vuint32m2_t vt_2 = __riscv_vsll_vv_u32m2(__riscv_vmv_v_x_u32m2(1, vl), vt_1, vl); vuint32m2_t vt_2 = __riscv_vsll_vv_u32m2(__riscv_vmv_v_x_u32m2(1, vl), vt_1, vl);
@ -4757,7 +4757,7 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
vl = 16; vl = 16;
// retreive lane to multiply with scale // retrieve lane to multiply with scale
vint32m2_t aux0_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 0), (scale[0]), vl); vint32m2_t aux0_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 0), (scale[0]), vl);
vint32m2_t aux0_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 1), (scale[1]), vl); vint32m2_t aux0_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 1), (scale[1]), vl);
vint32m2_t aux1_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 0), (scale[2]), vl); vint32m2_t aux1_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 0), (scale[2]), vl);

180
ggml.c
View File

@ -1,4 +1,4 @@
#define _CRT_SECURE_NO_DEPRECATE // Disables ridiculous "unsafe" warnigns on Windows #define _CRT_SECURE_NO_DEPRECATE // Disables ridiculous "unsafe" warnings on Windows
#define _USE_MATH_DEFINES // For M_PI on MSVC #define _USE_MATH_DEFINES // For M_PI on MSVC
#include "ggml-impl.h" #include "ggml-impl.h"
@ -33,7 +33,7 @@
// we should just be careful :) // we should just be careful :)
#pragma warning(disable: 4244 4267) #pragma warning(disable: 4244 4267)
// disable POSIX deprecation warnigns // disable POSIX deprecation warnings
// these functions are never going away, anyway // these functions are never going away, anyway
#pragma warning(disable: 4996) #pragma warning(disable: 4996)
#endif #endif
@ -1760,7 +1760,7 @@ static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size
static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN"); static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN");
// WARN: // WARN:
// Mis-confguration can lead to problem that's hard to reason about: // Mis-configuration can lead to problem that's hard to reason about:
// * At best it crash or talks nosense. // * At best it crash or talks nosense.
// * At worst it talks slightly difference but hard to perceive. // * At worst it talks slightly difference but hard to perceive.
// //
@ -4075,17 +4075,18 @@ struct ggml_tensor * ggml_mul_mat(
struct ggml_tensor * ggml_mul_mat_id( struct ggml_tensor * ggml_mul_mat_id(
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * as[], struct ggml_tensor * const as[],
int n_as,
struct ggml_tensor * ids, struct ggml_tensor * ids,
int id, int id,
struct ggml_tensor * b) { struct ggml_tensor * b) {
int64_t n_as = ids->ne[0];
GGML_ASSERT(ids->type == GGML_TYPE_I32); GGML_ASSERT(ids->type == GGML_TYPE_I32);
GGML_ASSERT(ggml_is_vector(ids)); GGML_ASSERT(ids->ne[2] == 1 && ids->ne[3] == 1);
GGML_ASSERT(ids->ne[1] == b->ne[1]);
GGML_ASSERT(ids->ne[2] == b->ne[2] && ids->ne[3] == b->ne[3]);
GGML_ASSERT(n_as > 0 && n_as <= GGML_MAX_SRC - 2); GGML_ASSERT(n_as > 0 && n_as <= GGML_MAX_SRC - 2);
GGML_ASSERT(id >= 0 && id < n_as); GGML_ASSERT(id >= 0 && id < ids->ne[0]);
bool is_node = false; bool is_node = false;
@ -4097,13 +4098,14 @@ struct ggml_tensor * ggml_mul_mat_id(
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, MAX(as[0]->n_dims, b->n_dims), ne); struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, MAX(as[0]->n_dims, b->n_dims), ne);
ggml_set_op_params_i32(result, 0, id); ggml_set_op_params_i32(result, 0, id);
ggml_set_op_params_i32(result, 1, n_as);
result->op = GGML_OP_MUL_MAT_ID; result->op = GGML_OP_MUL_MAT_ID;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
result->src[0] = ids; result->src[0] = ids;
result->src[1] = b; result->src[1] = b;
for (int64_t i = 0; i < n_as; i++) { for (int i = 0; i < n_as; i++) {
struct ggml_tensor * a = as[i]; struct ggml_tensor * a = as[i];
GGML_ASSERT(ggml_are_same_shape(as[0], a)); GGML_ASSERT(ggml_are_same_shape(as[0], a));
GGML_ASSERT(ggml_can_mul_mat(a, b)); GGML_ASSERT(ggml_can_mul_mat(a, b));
@ -4731,7 +4733,9 @@ struct ggml_tensor * ggml_get_rows(
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * a, struct ggml_tensor * a,
struct ggml_tensor * b) { struct ggml_tensor * b) {
GGML_ASSERT(ggml_is_matrix(a) && ggml_is_vector(b) && b->type == GGML_TYPE_I32); GGML_ASSERT(a->ne[2] == b->ne[1]);
GGML_ASSERT(b->ne[3] == 1);
GGML_ASSERT(b->type == GGML_TYPE_I32);
bool is_node = false; bool is_node = false;
@ -4741,7 +4745,7 @@ struct ggml_tensor * ggml_get_rows(
// TODO: implement non F32 return // TODO: implement non F32 return
//struct ggml_tensor * result = ggml_new_tensor_2d(ctx, a->type, a->ne[0], b->ne[0]); //struct ggml_tensor * result = ggml_new_tensor_2d(ctx, a->type, a->ne[0], b->ne[0]);
struct ggml_tensor * result = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, a->ne[0], b->ne[0]); struct ggml_tensor * result = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, a->ne[0], b->ne[0], b->ne[1], b->ne[2]);
result->op = GGML_OP_GET_ROWS; result->op = GGML_OP_GET_ROWS;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
@ -7520,7 +7524,7 @@ static void ggml_compute_forward_acc_f32(
GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0)); GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
// view src0 and dst with these strides and data offset inbytes during acc // view src0 and dst with these strides and data offset inbytes during acc
// nb0 is implicitely element_size because src0 and dst are contiguous // nb0 is implicitly element_size because src0 and dst are contiguous
size_t nb1 = ((int32_t *) dst->op_params)[0]; size_t nb1 = ((int32_t *) dst->op_params)[0];
size_t nb2 = ((int32_t *) dst->op_params)[1]; size_t nb2 = ((int32_t *) dst->op_params)[1];
size_t nb3 = ((int32_t *) dst->op_params)[2]; size_t nb3 = ((int32_t *) dst->op_params)[2];
@ -9504,8 +9508,11 @@ static bool ggml_compute_forward_mul_mat_use_blas(
const int64_t ne0 = dst->ne[0]; const int64_t ne0 = dst->ne[0];
const int64_t ne1 = dst->ne[1]; const int64_t ne1 = dst->ne[1];
// NOTE: with GGML_OP_MUL_MAT_ID we don't want to go through the BLAS branch because it will dequantize (to_float)
// all the experts for each batch element and the processing would become incredibly slow
// TODO: find the optimal values for these // TODO: find the optimal values for these
if (ggml_is_contiguous(src0) && if (dst->op != GGML_OP_MUL_MAT_ID &&
ggml_is_contiguous(src0) &&
ggml_is_contiguous(src1) && ggml_is_contiguous(src1) &&
//src0->type == GGML_TYPE_F32 && //src0->type == GGML_TYPE_F32 &&
src1->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 &&
@ -9519,11 +9526,16 @@ static bool ggml_compute_forward_mul_mat_use_blas(
} }
#endif #endif
// off1 = offset in i11 and i1
// cne1 = ne11 and ne1
// in a normal matrix multiplication, off1 = 0 and cne1 = ne1
// during GGML_TASK_INIT, the full src1 is converted regardless of off1 and cne1
static void ggml_compute_forward_mul_mat( static void ggml_compute_forward_mul_mat(
const struct ggml_compute_params * params, const struct ggml_compute_params * params,
const struct ggml_tensor * src0, const struct ggml_tensor * src0,
const struct ggml_tensor * src1, const struct ggml_tensor * src1,
struct ggml_tensor * dst) { struct ggml_tensor * dst,
int64_t off1, int64_t cne1) {
int64_t t0 = ggml_perf_time_us(); int64_t t0 = ggml_perf_time_us();
UNUSED(t0); UNUSED(t0);
@ -9591,10 +9603,9 @@ static void ggml_compute_forward_mul_mat(
const int64_t i03 = i13/r3; const int64_t i03 = i13/r3;
const int64_t i02 = i12/r2; const int64_t i02 = i12/r2;
const void * x = (char *) src0->data + i02*nb02 + i03*nb03; const void * x = (char *) src0->data + i02*nb02 + i03*nb03;
const float * y = (float *) ((char *) src1->data + i12*nb12 + i13*nb13); const float * y = (float *) ((char *) src1->data + off1*nb11 + i12*nb12 + i13*nb13);
float * d = (float *) ((char *) dst->data + off1*nb1 + i12*nb2 + i13*nb3);
float * d = (float *) ((char *) dst->data + i12*nb2 + i13*nb3);
if (type != GGML_TYPE_F32) { if (type != GGML_TYPE_F32) {
float * const wdata = params->wdata; float * const wdata = params->wdata;
@ -9611,10 +9622,10 @@ static void ggml_compute_forward_mul_mat(
} }
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
ne11, ne01, ne10, cne1, ne01, ne10,
1.0f, y, ne10, 1.0f, y, ne10,
x, ne00, x, ne00,
0.0f, d, ne01); 0.0f, d, ne01);
} }
} }
@ -9630,6 +9641,7 @@ static void ggml_compute_forward_mul_mat(
const size_t row_size = ne10*ggml_type_size(vec_dot_type)/ggml_blck_size(vec_dot_type); const size_t row_size = ne10*ggml_type_size(vec_dot_type)/ggml_blck_size(vec_dot_type);
assert(params->wsize >= ne11*ne12*ne13*row_size); assert(params->wsize >= ne11*ne12*ne13*row_size);
assert(src1->type == GGML_TYPE_F32);
for (int64_t i13 = 0; i13 < ne13; ++i13) { for (int64_t i13 = 0; i13 < ne13; ++i13) {
for (int64_t i12 = 0; i12 < ne12; ++i12) { for (int64_t i12 = 0; i12 < ne12; ++i12) {
@ -9652,7 +9664,7 @@ static void ggml_compute_forward_mul_mat(
const size_t row_size = ne10*ggml_type_size(vec_dot_type)/ggml_blck_size(vec_dot_type); const size_t row_size = ne10*ggml_type_size(vec_dot_type)/ggml_blck_size(vec_dot_type);
const int64_t nr0 = ne01; // src0 rows const int64_t nr0 = ne01; // src0 rows
const int64_t nr1 = ne11*ne12*ne13; // src1 rows const int64_t nr1 = cne1*ne12*ne13; // src1 rows
//printf("nr0 = %lld, nr1 = %lld\n", nr0, nr1); //printf("nr0 = %lld, nr1 = %lld\n", nr0, nr1);
@ -9694,9 +9706,9 @@ static void ggml_compute_forward_mul_mat(
for (int64_t iir1 = ir110; iir1 < ir111; iir1 += blck_1) { for (int64_t iir1 = ir110; iir1 < ir111; iir1 += blck_1) {
for (int64_t iir0 = ir010; iir0 < ir011; iir0 += blck_0) { for (int64_t iir0 = ir010; iir0 < ir011; iir0 += blck_0) {
for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir111; ++ir1) { for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir111; ++ir1) {
const int64_t i13 = (ir1/(ne12*ne11)); const int64_t i13 = (ir1/(ne12*cne1));
const int64_t i12 = (ir1 - i13*ne12*ne11)/ne11; const int64_t i12 = (ir1 - i13*ne12*cne1)/cne1;
const int64_t i11 = (ir1 - i13*ne12*ne11 - i12*ne11); const int64_t i11 = (ir1 - i13*ne12*cne1 - i12*cne1) + off1;
// broadcast src0 into src1 // broadcast src0 into src1
const int64_t i03 = i13/r3; const int64_t i03 = i13/r3;
@ -9736,20 +9748,28 @@ static void ggml_compute_forward_mul_mat(
static void ggml_compute_forward_mul_mat_id( static void ggml_compute_forward_mul_mat_id(
const struct ggml_compute_params * params, const struct ggml_compute_params * params,
const struct ggml_tensor * src0,
const struct ggml_tensor * src1,
struct ggml_tensor * dst) { struct ggml_tensor * dst) {
const struct ggml_tensor * ids = dst->src[0]; if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
const struct ggml_tensor * src1 = dst->src[1]; // during GGML_TASK_INIT the entire src1 is converted to vec_dot_type
ggml_compute_forward_mul_mat(params, dst->src[2], src1, dst, 0, dst->ne[1]);
return;
}
const int id = ggml_get_op_params_i32(dst, 0); const struct ggml_tensor * ids = src0;
const int id = ggml_get_op_params_i32(dst, 0);
const int n_as = ggml_get_op_params_i32(dst, 1);
const int a_id = ((int32_t *)ids->data)[id]; for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
const int32_t row_id = *(const int32_t *) ((const char *) ids->data + i01*ids->nb[1] + id*ids->nb[0]);
GGML_ASSERT(a_id >= 0 && a_id < ids->ne[0]); GGML_ASSERT(row_id >= 0 && row_id < n_as);
const struct ggml_tensor * src0 = dst->src[a_id + 2]; const struct ggml_tensor * src0_row = dst->src[row_id + 2];
ggml_compute_forward_mul_mat(params, src0_row, src1, dst, i01, 1);
ggml_compute_forward_mul_mat(params, src0, src1, dst); }
} }
// ggml_compute_forward_out_prod // ggml_compute_forward_out_prod
@ -10161,7 +10181,7 @@ static void ggml_compute_forward_set_f32(
GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0)); GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
// view src0 and dst with these strides and data offset inbytes during set // view src0 and dst with these strides and data offset inbytes during set
// nb0 is implicitely element_size because src0 and dst are contiguous // nb0 is implicitly element_size because src0 and dst are contiguous
size_t nb1 = ((int32_t *) dst->op_params)[0]; size_t nb1 = ((int32_t *) dst->op_params)[0];
size_t nb2 = ((int32_t *) dst->op_params)[1]; size_t nb2 = ((int32_t *) dst->op_params)[1];
size_t nb3 = ((int32_t *) dst->op_params)[2]; size_t nb3 = ((int32_t *) dst->op_params)[2];
@ -10325,21 +10345,30 @@ static void ggml_compute_forward_get_rows_q(
return; return;
} }
const int nc = src0->ne[0]; GGML_TENSOR_BINARY_OP_LOCALS
const int nr = ggml_nelements(src1);
const int64_t nc = ne00;
const int64_t nr = ggml_nelements(src1); GGML_UNUSED(nr);
const enum ggml_type type = src0->type; const enum ggml_type type = src0->type;
ggml_to_float_t const dequantize_row_q = type_traits[type].to_float; ggml_to_float_t const dequantize_row_q = type_traits[type].to_float;
assert( dst->ne[0] == nc); assert(ne0 == nc);
assert( dst->ne[1] == nr); assert(ne02 == ne11);
assert(src0->nb[0] == ggml_type_size(type)); assert(nb00 == ggml_type_size(type));
assert(ggml_nrows(dst) == nr);
for (int i = 0; i < nr; ++i) { // TODO: multi-thread
const int r = ((int32_t *) src1->data)[i]; for (int64_t i12 = 0; i12 < ne12; ++i12) {
for (int64_t i11 = 0; i11 < ne11; ++i11) {
for (int64_t i10 = 0; i10 < ne10; ++i10) {
const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
dequantize_row_q( dequantize_row_q(
(const void *) ((char *) src0->data + r*src0->nb[1]), (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
(float *) ((char *) dst->data + i*dst->nb[1]), nc); (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc);
}
}
} }
} }
@ -10354,19 +10383,26 @@ static void ggml_compute_forward_get_rows_f16(
return; return;
} }
const int nc = src0->ne[0]; GGML_TENSOR_BINARY_OP_LOCALS
const int nr = ggml_nelements(src1);
assert( dst->ne[0] == nc); const int64_t nc = ne00;
assert( dst->ne[1] == nr); const int64_t nr = ggml_nelements(src1); GGML_UNUSED(nr);
assert(src0->nb[0] == sizeof(ggml_fp16_t));
for (int i = 0; i < nr; ++i) { assert(ne0 == nc);
const int r = ((int32_t *) src1->data)[i]; assert(ne02 == ne11);
assert(nb00 == sizeof(ggml_fp16_t));
assert(ggml_nrows(dst) == nr);
for (int j = 0; j < nc; ++j) { // TODO: multi-thread
ggml_fp16_t v = ((ggml_fp16_t *) ((char *) src0->data + r*src0->nb[1]))[j]; for (int64_t i12 = 0; i12 < ne12; ++i12) {
((float *) ((char *) dst->data + i*dst->nb[1]))[j] = GGML_FP16_TO_FP32(v); for (int64_t i11 = 0; i11 < ne11; ++i11) {
for (int64_t i10 = 0; i10 < ne10; ++i10) {
const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
ggml_fp16_to_fp32_row(
(const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
(float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc);
}
} }
} }
} }
@ -10382,19 +10418,27 @@ static void ggml_compute_forward_get_rows_f32(
return; return;
} }
const int nc = src0->ne[0]; GGML_TENSOR_BINARY_OP_LOCALS
const int nr = ggml_nelements(src1);
assert( dst->ne[0] == nc); const int64_t nc = ne00;
assert( dst->ne[1] == nr); const int64_t nr = ggml_nelements(src1); GGML_UNUSED(nr);
assert(src0->nb[0] == sizeof(float));
for (int i = 0; i < nr; ++i) { assert(ne0 == nc);
const int r = ((int32_t *) src1->data)[i]; assert(ne02 == ne11);
assert(nb00 == sizeof(float));
assert(ggml_nrows(dst) == nr);
ggml_vec_cpy_f32(nc, // TODO: multi-thread
(float *) ((char *) dst->data + i*dst->nb[1]), for (int64_t i12 = 0; i12 < ne12; ++i12) {
(float *) ((char *) src0->data + r*src0->nb[1])); for (int64_t i11 = 0; i11 < ne11; ++i11) {
for (int64_t i10 = 0; i10 < ne10; ++i10) {
const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
ggml_vec_cpy_f32(nc,
(float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3),
(float *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03));
}
}
} }
} }
@ -14037,11 +14081,11 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
} break; } break;
case GGML_OP_MUL_MAT: case GGML_OP_MUL_MAT:
{ {
ggml_compute_forward_mul_mat(params, tensor->src[0], tensor->src[1], tensor); ggml_compute_forward_mul_mat(params, tensor->src[0], tensor->src[1], tensor, 0, tensor->ne[1]);
} break; } break;
case GGML_OP_MUL_MAT_ID: case GGML_OP_MUL_MAT_ID:
{ {
ggml_compute_forward_mul_mat_id(params, tensor); ggml_compute_forward_mul_mat_id(params, tensor->src[0], tensor->src[1], tensor);
} break; } break;
case GGML_OP_OUT_PROD: case GGML_OP_OUT_PROD:
{ {
@ -14475,7 +14519,7 @@ void ggml_build_backward_gradient_checkpointing(
// insert new tensors recomputing src, reusing already made replacements, // insert new tensors recomputing src, reusing already made replacements,
// remember replacements: remember new tensors with mapping from corresponding gf nodes // remember replacements: remember new tensors with mapping from corresponding gf nodes
// recurse for input tensors, // recurse for input tensors,
// unless (i.e. terminating when) input tensors are replacments (like checkpoints) // unless (i.e. terminating when) input tensors are replacements (like checkpoints)
node->src[k] = ggml_recompute_graph_node(ctx, gf, replacements, node->src[k]); node->src[k] = ggml_recompute_graph_node(ctx, gf, replacements, node->src[k]);
} }
// insert rewritten backward node with replacements made into resulting backward graph gb // insert rewritten backward node with replacements made into resulting backward graph gb

8
ggml.h
View File

@ -215,9 +215,9 @@
#define GGML_QNT_VERSION_FACTOR 1000 // do not change this #define GGML_QNT_VERSION_FACTOR 1000 // do not change this
#define GGML_MAX_DIMS 4 #define GGML_MAX_DIMS 4
#define GGML_MAX_PARAMS 1024 #define GGML_MAX_PARAMS 2048
#define GGML_MAX_CONTEXTS 64 #define GGML_MAX_CONTEXTS 64
#define GGML_MAX_SRC 6 #define GGML_MAX_SRC 10
#define GGML_MAX_NAME 64 #define GGML_MAX_NAME 64
#define GGML_MAX_OP_PARAMS 64 #define GGML_MAX_OP_PARAMS 64
#define GGML_DEFAULT_N_THREADS 4 #define GGML_DEFAULT_N_THREADS 4
@ -1051,7 +1051,8 @@ extern "C" {
// ggml_mul_mat_id(ctx, as, ids, id, b) ~= ggml_mul_mat(as[ids[id]], b) // ggml_mul_mat_id(ctx, as, ids, id, b) ~= ggml_mul_mat(as[ids[id]], b)
GGML_API struct ggml_tensor * ggml_mul_mat_id( GGML_API struct ggml_tensor * ggml_mul_mat_id(
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * as[], struct ggml_tensor * const as[],
int n_as,
struct ggml_tensor * ids, struct ggml_tensor * ids,
int id, int id,
struct ggml_tensor * b); struct ggml_tensor * b);
@ -1263,6 +1264,7 @@ extern "C" {
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * a); struct ggml_tensor * a);
// supports 3D: a->ne[2] == b->ne[1]
GGML_API struct ggml_tensor * ggml_get_rows( GGML_API struct ggml_tensor * ggml_get_rows(
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * a, struct ggml_tensor * a,

View File

@ -61,7 +61,7 @@ If you want to publish the package manually for any reason, you need to have `tw
pip install build twine pip install build twine
``` ```
Then, folow these steps to release a new version: Then, follow these steps to release a new version:
1. Bump the version in `pyproject.toml`. 1. Bump the version in `pyproject.toml`.
2. Build the package: 2. Build the package:

View File

@ -38,6 +38,8 @@ class Keys:
FEED_FORWARD_LENGTH = "{arch}.feed_forward_length" FEED_FORWARD_LENGTH = "{arch}.feed_forward_length"
USE_PARALLEL_RESIDUAL = "{arch}.use_parallel_residual" USE_PARALLEL_RESIDUAL = "{arch}.use_parallel_residual"
TENSOR_DATA_LAYOUT = "{arch}.tensor_data_layout" TENSOR_DATA_LAYOUT = "{arch}.tensor_data_layout"
EXPERT_COUNT = "{arch}.expert_count"
EXPERT_USED_COUNT = "{arch}.expert_used_count"
class Attention: class Attention:
HEAD_COUNT = "{arch}.attention.head_count" HEAD_COUNT = "{arch}.attention.head_count"
@ -111,10 +113,14 @@ class MODEL_TENSOR(IntEnum):
ATTN_NORM = auto() ATTN_NORM = auto()
ATTN_NORM_2 = auto() ATTN_NORM_2 = auto()
ATTN_ROT_EMBD = auto() ATTN_ROT_EMBD = auto()
FFN_GATE_INP = auto()
FFN_NORM = auto()
FFN_GATE = auto() FFN_GATE = auto()
FFN_DOWN = auto() FFN_DOWN = auto()
FFN_UP = auto() FFN_UP = auto()
FFN_NORM = auto() FFN_GATE_EXP = auto()
FFN_DOWN_EXP = auto()
FFN_UP_EXP = auto()
ATTN_Q_NORM = auto() ATTN_Q_NORM = auto()
ATTN_K_NORM = auto() ATTN_K_NORM = auto()
@ -154,10 +160,14 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
MODEL_TENSOR.ATTN_ROT_EMBD: "blk.{bid}.attn_rot_embd", MODEL_TENSOR.ATTN_ROT_EMBD: "blk.{bid}.attn_rot_embd",
MODEL_TENSOR.ATTN_Q_NORM: "blk.{bid}.attn_q_norm", MODEL_TENSOR.ATTN_Q_NORM: "blk.{bid}.attn_q_norm",
MODEL_TENSOR.ATTN_K_NORM: "blk.{bid}.attn_k_norm", MODEL_TENSOR.ATTN_K_NORM: "blk.{bid}.attn_k_norm",
MODEL_TENSOR.FFN_GATE_INP: "blk.{bid}.ffn_gate_inp",
MODEL_TENSOR.FFN_NORM: "blk.{bid}.ffn_norm", MODEL_TENSOR.FFN_NORM: "blk.{bid}.ffn_norm",
MODEL_TENSOR.FFN_GATE: "blk.{bid}.ffn_gate", MODEL_TENSOR.FFN_GATE: "blk.{bid}.ffn_gate",
MODEL_TENSOR.FFN_DOWN: "blk.{bid}.ffn_down", MODEL_TENSOR.FFN_DOWN: "blk.{bid}.ffn_down",
MODEL_TENSOR.FFN_UP: "blk.{bid}.ffn_up", MODEL_TENSOR.FFN_UP: "blk.{bid}.ffn_up",
MODEL_TENSOR.FFN_GATE_EXP: "blk.{bid}.ffn_gate.{xid}",
MODEL_TENSOR.FFN_DOWN_EXP: "blk.{bid}.ffn_down.{xid}",
MODEL_TENSOR.FFN_UP_EXP: "blk.{bid}.ffn_up.{xid}",
} }
MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
@ -172,10 +182,14 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.ATTN_V, MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT, MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.ATTN_ROT_EMBD, MODEL_TENSOR.ATTN_ROT_EMBD,
MODEL_TENSOR.FFN_GATE_INP,
MODEL_TENSOR.FFN_NORM, MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_GATE, MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_DOWN, MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP, MODEL_TENSOR.FFN_UP,
MODEL_TENSOR.FFN_GATE_EXP,
MODEL_TENSOR.FFN_DOWN_EXP,
MODEL_TENSOR.FFN_UP_EXP,
], ],
MODEL_ARCH.GPTNEOX: [ MODEL_ARCH.GPTNEOX: [
MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.TOKEN_EMBD,

View File

@ -339,6 +339,12 @@ class GGUFWriter:
def add_clamp_kqv(self, value: float) -> None: def add_clamp_kqv(self, value: float) -> None:
self.add_float32(Keys.Attention.CLAMP_KQV.format(arch=self.arch), value) self.add_float32(Keys.Attention.CLAMP_KQV.format(arch=self.arch), value)
def add_expert_count(self, count: int) -> None:
self.add_uint32(Keys.LLM.EXPERT_COUNT.format(arch=self.arch), count)
def add_expert_used_count(self, count: int) -> None:
self.add_uint32(Keys.LLM.EXPERT_USED_COUNT.format(arch=self.arch), count)
def add_layer_norm_eps(self, value: float) -> None: def add_layer_norm_eps(self, value: float) -> None:
self.add_float32(Keys.Attention.LAYERNORM_EPS.format(arch=self.arch), value) self.add_float32(Keys.Attention.LAYERNORM_EPS.format(arch=self.arch), value)

View File

@ -149,6 +149,11 @@ class TensorNameMap:
"model.layers.{bid}.ln2", # yi "model.layers.{bid}.ln2", # yi
), ),
MODEL_TENSOR.FFN_GATE_INP: (
"layers.{bid}.feed_forward.gate", # mixtral
"model.layers.{bid}.block_sparse_moe.gate", # mixtral
),
# Feed-forward up # Feed-forward up
MODEL_TENSOR.FFN_UP: ( MODEL_TENSOR.FFN_UP: (
"gpt_neox.layers.{bid}.mlp.dense_h_to_4h", # gptneox "gpt_neox.layers.{bid}.mlp.dense_h_to_4h", # gptneox
@ -164,11 +169,21 @@ class TensorNameMap:
"transformer.h.{bid}.mlp.w1", # qwen "transformer.h.{bid}.mlp.w1", # qwen
), ),
MODEL_TENSOR.FFN_UP_EXP: (
"layers.{bid}.feed_forward.experts.{xid}.w3", # mixtral
"model.layers.{bid}.block_sparse_moe.experts.{xid}.w3", # mixtral
),
# Feed-forward gate # Feed-forward gate
MODEL_TENSOR.FFN_GATE: ( MODEL_TENSOR.FFN_GATE: (
"model.layers.{bid}.mlp.gate_proj", # llama-hf refact "model.layers.{bid}.mlp.gate_proj", # llama-hf refact
"layers.{bid}.feed_forward.w1", # llama-pth "layers.{bid}.feed_forward.w1", # llama-pth
"transformer.h.{bid}.mlp.w2", # qwen "transformer.h.{bid}.mlp.w2", # qwen
),
MODEL_TENSOR.FFN_GATE_EXP: (
"layers.{bid}.feed_forward.experts.{xid}.w1", # mixtral
"model.layers.{bid}.block_sparse_moe.experts.{xid}.w1", # mixtral
), ),
# Feed-forward down # Feed-forward down
@ -185,6 +200,11 @@ class TensorNameMap:
"language_model.encoder.layers.{bid}.mlp.dense_4h_to_h", # persimmon "language_model.encoder.layers.{bid}.mlp.dense_4h_to_h", # persimmon
), ),
MODEL_TENSOR.FFN_DOWN_EXP: (
"layers.{bid}.feed_forward.experts.{xid}.w2", # mixtral
"model.layers.{bid}.block_sparse_moe.experts.{xid}.w2", # mixtral
),
MODEL_TENSOR.ATTN_Q_NORM: ( MODEL_TENSOR.ATTN_Q_NORM: (
"language_model.encoder.layers.{bid}.self_attention.q_layernorm", "language_model.encoder.layers.{bid}.self_attention.q_layernorm",
), ),
@ -213,11 +233,14 @@ class TensorNameMap:
for tensor, keys in self.block_mappings_cfg.items(): for tensor, keys in self.block_mappings_cfg.items():
if tensor not in MODEL_TENSORS[arch]: if tensor not in MODEL_TENSORS[arch]:
continue continue
tensor_name = TENSOR_NAMES[tensor].format(bid = bid) # TODO: make this configurable
self.mapping[tensor_name] = (tensor, tensor_name) n_experts = 8
for key in keys: for xid in range(n_experts):
key = key.format(bid = bid) tensor_name = TENSOR_NAMES[tensor].format(bid = bid, xid = xid)
self.mapping[key] = (tensor, tensor_name) self.mapping[tensor_name] = (tensor, tensor_name)
for key in keys:
key = key.format(bid = bid, xid = xid)
self.mapping[key] = (tensor, tensor_name)
def get_type_and_name(self, key: str, try_suffixes: Sequence[str] = ()) -> tuple[MODEL_TENSOR, str] | None: def get_type_and_name(self, key: str, try_suffixes: Sequence[str] = ()) -> tuple[MODEL_TENSOR, str] | None:
result = self.mapping.get(key) result = self.mapping.get(key)

View File

@ -1,6 +1,6 @@
[tool.poetry] [tool.poetry]
name = "gguf" name = "gguf"
version = "0.6.0" version = "0.7.0"
description = "Read and write ML models in GGUF for GGML" description = "Read and write ML models in GGUF for GGML"
authors = ["GGML <ggml@ggml.ai>"] authors = ["GGML <ggml@ggml.ai>"]
packages = [ packages = [

212
llama.cpp
View File

@ -91,7 +91,8 @@
#define LLAMA_ATTRIBUTE_FORMAT(...) #define LLAMA_ATTRIBUTE_FORMAT(...)
#endif #endif
#define LLAMA_MAX_NODES 8192 #define LLAMA_MAX_NODES 8192
#define LLAMA_MAX_EXPERTS 8
// //
// logging // logging
@ -231,6 +232,8 @@ enum llm_kv {
LLM_KV_FEED_FORWARD_LENGTH, LLM_KV_FEED_FORWARD_LENGTH,
LLM_KV_USE_PARALLEL_RESIDUAL, LLM_KV_USE_PARALLEL_RESIDUAL,
LLM_KV_TENSOR_DATA_LAYOUT, LLM_KV_TENSOR_DATA_LAYOUT,
LLM_KV_EXPERT_COUNT,
LLM_KV_EXPERT_USED_COUNT,
LLM_KV_ATTENTION_HEAD_COUNT, LLM_KV_ATTENTION_HEAD_COUNT,
LLM_KV_ATTENTION_HEAD_COUNT_KV, LLM_KV_ATTENTION_HEAD_COUNT_KV,
@ -281,6 +284,8 @@ static std::map<llm_kv, std::string> LLM_KV_NAMES = {
{ LLM_KV_FEED_FORWARD_LENGTH, "%s.feed_forward_length" }, { LLM_KV_FEED_FORWARD_LENGTH, "%s.feed_forward_length" },
{ LLM_KV_USE_PARALLEL_RESIDUAL, "%s.use_parallel_residual" }, { LLM_KV_USE_PARALLEL_RESIDUAL, "%s.use_parallel_residual" },
{ LLM_KV_TENSOR_DATA_LAYOUT, "%s.tensor_data_layout" }, { LLM_KV_TENSOR_DATA_LAYOUT, "%s.tensor_data_layout" },
{ LLM_KV_EXPERT_COUNT, "%s.expert_count" },
{ LLM_KV_EXPERT_USED_COUNT, "%s.expert_used_count" },
{ LLM_KV_ATTENTION_HEAD_COUNT, "%s.attention.head_count" }, { LLM_KV_ATTENTION_HEAD_COUNT, "%s.attention.head_count" },
{ LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" }, { LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" },
@ -338,10 +343,14 @@ enum llm_tensor {
LLM_TENSOR_ATTN_NORM, LLM_TENSOR_ATTN_NORM,
LLM_TENSOR_ATTN_NORM_2, LLM_TENSOR_ATTN_NORM_2,
LLM_TENSOR_ATTN_ROT_EMBD, LLM_TENSOR_ATTN_ROT_EMBD,
LLM_TENSOR_FFN_GATE_INP,
LLM_TENSOR_FFN_NORM,
LLM_TENSOR_FFN_GATE, LLM_TENSOR_FFN_GATE,
LLM_TENSOR_FFN_DOWN, LLM_TENSOR_FFN_DOWN,
LLM_TENSOR_FFN_UP, LLM_TENSOR_FFN_UP,
LLM_TENSOR_FFN_NORM, LLM_TENSOR_FFN_DOWN_EXP,
LLM_TENSOR_FFN_GATE_EXP,
LLM_TENSOR_FFN_UP_EXP,
LLM_TENSOR_ATTN_Q_NORM, LLM_TENSOR_ATTN_Q_NORM,
LLM_TENSOR_ATTN_K_NORM, LLM_TENSOR_ATTN_K_NORM,
}; };
@ -360,10 +369,14 @@ static std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NAMES =
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" }, { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" },
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
{ LLM_TENSOR_FFN_GATE_EXP, "blk.%d.ffn_gate.%d" },
{ LLM_TENSOR_FFN_DOWN_EXP, "blk.%d.ffn_down.%d" },
{ LLM_TENSOR_FFN_UP_EXP, "blk.%d.ffn_up.%d" },
}, },
}, },
{ {
@ -585,6 +598,10 @@ struct LLM_TN {
std::string operator()(llm_tensor tensor, const std::string & suffix, int bid) const { std::string operator()(llm_tensor tensor, const std::string & suffix, int bid) const {
return ::format(LLM_TENSOR_NAMES[arch].at(tensor).c_str(), bid) + "." + suffix; return ::format(LLM_TENSOR_NAMES[arch].at(tensor).c_str(), bid) + "." + suffix;
} }
std::string operator()(llm_tensor tensor, const std::string & suffix, int bid, int xid) const {
return ::format(LLM_TENSOR_NAMES[arch].at(tensor).c_str(), bid, xid) + "." + suffix;
}
}; };
// //
@ -1164,6 +1181,8 @@ struct llama_hparams {
uint32_t n_layer; uint32_t n_layer;
uint32_t n_rot; uint32_t n_rot;
uint32_t n_ff; uint32_t n_ff;
uint32_t n_expert = 0;
uint32_t n_expert_used = 0;
float f_norm_eps; float f_norm_eps;
float f_norm_rms_eps; float f_norm_rms_eps;
@ -1178,15 +1197,18 @@ struct llama_hparams {
float f_max_alibi_bias; float f_max_alibi_bias;
bool operator!=(const llama_hparams & other) const { bool operator!=(const llama_hparams & other) const {
if (this->vocab_only != other.vocab_only) return true; if (this->vocab_only != other.vocab_only) return true;
if (this->n_vocab != other.n_vocab) return true; if (this->n_vocab != other.n_vocab) return true;
if (this->n_ctx_train != other.n_ctx_train) return true; if (this->n_ctx_train != other.n_ctx_train) return true;
if (this->n_embd != other.n_embd) return true; if (this->n_embd != other.n_embd) return true;
if (this->n_head != other.n_head) return true; if (this->n_head != other.n_head) return true;
if (this->n_head_kv != other.n_head_kv) return true; if (this->n_head_kv != other.n_head_kv) return true;
if (this->n_layer != other.n_layer) return true; if (this->n_layer != other.n_layer) return true;
if (this->n_rot != other.n_rot) return true; if (this->n_rot != other.n_rot) return true;
if (this->n_ff != other.n_ff) return true; if (this->n_ff != other.n_ff) return true;
if (this->n_expert != other.n_expert) return true;
if (this->n_expert_used != other.n_expert_used) return true;
if (this->rope_finetuned != other.rope_finetuned) return true; if (this->rope_finetuned != other.rope_finetuned) return true;
if (this->n_yarn_orig_ctx != other.n_yarn_orig_ctx) return true; if (this->n_yarn_orig_ctx != other.n_yarn_orig_ctx) return true;
@ -1268,6 +1290,12 @@ struct llama_layer {
struct ggml_tensor * ffn_down; // w2 struct ggml_tensor * ffn_down; // w2
struct ggml_tensor * ffn_up; // w3 struct ggml_tensor * ffn_up; // w3
// ff MoE
struct ggml_tensor * ffn_gate_inp;
struct ggml_tensor * ffn_gate_exp[LLAMA_MAX_EXPERTS];
struct ggml_tensor * ffn_down_exp[LLAMA_MAX_EXPERTS];
struct ggml_tensor * ffn_up_exp [LLAMA_MAX_EXPERTS];
// ff bias // ff bias
struct ggml_tensor * ffn_down_b; // b2 struct ggml_tensor * ffn_down_b; // b2
struct ggml_tensor * ffn_up_b; // b3 struct ggml_tensor * ffn_up_b; // b3
@ -2440,6 +2468,16 @@ static void llm_load_hparams(
ml.get_key (LLM_KV_FEED_FORWARD_LENGTH, hparams.n_ff); ml.get_key (LLM_KV_FEED_FORWARD_LENGTH, hparams.n_ff);
ml.get_key (LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head); ml.get_key (LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head);
ml.get_key (LLM_KV_BLOCK_COUNT, hparams.n_layer); ml.get_key (LLM_KV_BLOCK_COUNT, hparams.n_layer);
ml.get_key (LLM_KV_EXPERT_COUNT, hparams.n_expert, false);
ml.get_key (LLM_KV_EXPERT_USED_COUNT, hparams.n_expert_used, false);
GGML_ASSERT(hparams.n_expert <= LLAMA_MAX_EXPERTS);
GGML_ASSERT(hparams.n_expert_used <= hparams.n_expert);
if (hparams.n_expert > 0) {
GGML_ASSERT(hparams.n_expert_used > 0);
} else {
GGML_ASSERT(hparams.n_expert_used == 0);
}
// n_head_kv is optional, default to n_head // n_head_kv is optional, default to n_head
hparams.n_head_kv = hparams.n_head; hparams.n_head_kv = hparams.n_head;
@ -2758,7 +2796,7 @@ static void llm_load_vocab(
// The assumption is, since special tokens aren't meant to be exposed to end user, they are designed // The assumption is, since special tokens aren't meant to be exposed to end user, they are designed
// to be unmatchable by the tokenizer, therefore tokens from the vocab, which are unmatchable by the tokenizer // to be unmatchable by the tokenizer, therefore tokens from the vocab, which are unmatchable by the tokenizer
// are special tokens. // are special tokens.
// From testing, this appears to corelate 1:1 with special tokens. // From testing, this appears to correlate 1:1 with special tokens.
// //
// Counting special tokens and verifying in only one direction // Counting special tokens and verifying in only one direction
@ -2871,6 +2909,8 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
LLAMA_LOG_INFO("%s: f_clamp_kqv = %.1e\n", __func__, hparams.f_clamp_kqv); LLAMA_LOG_INFO("%s: f_clamp_kqv = %.1e\n", __func__, hparams.f_clamp_kqv);
LLAMA_LOG_INFO("%s: f_max_alibi_bias = %.1e\n", __func__, hparams.f_max_alibi_bias); LLAMA_LOG_INFO("%s: f_max_alibi_bias = %.1e\n", __func__, hparams.f_max_alibi_bias);
LLAMA_LOG_INFO("%s: n_ff = %u\n", __func__, hparams.n_ff); LLAMA_LOG_INFO("%s: n_ff = %u\n", __func__, hparams.n_ff);
LLAMA_LOG_INFO("%s: n_expert = %u\n", __func__, hparams.n_expert);
LLAMA_LOG_INFO("%s: n_expert_used = %u\n", __func__, hparams.n_expert_used);
LLAMA_LOG_INFO("%s: rope scaling = %s\n", __func__, rope_scaling_type.c_str()); LLAMA_LOG_INFO("%s: rope scaling = %s\n", __func__, rope_scaling_type.c_str());
LLAMA_LOG_INFO("%s: freq_base_train = %.1f\n", __func__, hparams.rope_freq_base_train); LLAMA_LOG_INFO("%s: freq_base_train = %.1f\n", __func__, hparams.rope_freq_base_train);
LLAMA_LOG_INFO("%s: freq_scale_train = %g\n", __func__, hparams.rope_freq_scale_train); LLAMA_LOG_INFO("%s: freq_scale_train = %g\n", __func__, hparams.rope_freq_scale_train);
@ -3025,9 +3065,26 @@ static void llm_load_tensors(
layer.ffn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, backend); layer.ffn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, backend);
layer.ffn_gate = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, backend_split); layer.ffn_gate_inp = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd}, backend, false);
layer.ffn_down = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, backend_split);
layer.ffn_up = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split); if (layer.ffn_gate_inp == nullptr) {
GGML_ASSERT(hparams.n_expert == 0);
GGML_ASSERT(hparams.n_expert_used == 0);
layer.ffn_gate = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, backend_split);
layer.ffn_down = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, backend_split);
layer.ffn_up = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split);
} else {
GGML_ASSERT(hparams.n_expert > 0);
GGML_ASSERT(hparams.n_expert_used > 0);
// MoE branch
for (uint32_t x = 0; x < hparams.n_expert; ++x) {
layer.ffn_gate_exp[x] = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_GATE_EXP, "weight", i, x), {n_embd, n_ff}, backend_split);
layer.ffn_down_exp[x] = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN_EXP, "weight", i, x), { n_ff, n_embd}, backend_split);
layer.ffn_up_exp[x] = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP_EXP, "weight", i, x), {n_embd, n_ff}, backend_split);
}
}
if (backend == GGML_BACKEND_GPU) { if (backend == GGML_BACKEND_GPU) {
vram_weights += vram_weights +=
@ -3037,8 +3094,18 @@ static void llm_load_tensors(
(layer.bk ? ggml_nbytes(layer.bk) : 0) + (layer.bk ? ggml_nbytes(layer.bk) : 0) +
(layer.bv ? ggml_nbytes(layer.bv) : 0) + (layer.bv ? ggml_nbytes(layer.bv) : 0) +
(layer.bo ? ggml_nbytes(layer.bo) : 0) + (layer.bo ? ggml_nbytes(layer.bo) : 0) +
ggml_nbytes(layer.ffn_norm) + ggml_nbytes(layer.ffn_gate) + ggml_nbytes(layer.ffn_norm);
ggml_nbytes(layer.ffn_down) + ggml_nbytes(layer.ffn_up);
if (layer.ffn_gate_inp == nullptr) {
vram_weights +=
ggml_nbytes(layer.ffn_gate) + ggml_nbytes(layer.ffn_down) + ggml_nbytes(layer.ffn_up);
} else {
vram_weights += ggml_nbytes(layer.ffn_gate_inp);
for (uint32_t x = 0; x < hparams.n_expert; ++x) {
vram_weights +=
ggml_nbytes(layer.ffn_gate_exp[x]) + ggml_nbytes(layer.ffn_down_exp[x]) + ggml_nbytes(layer.ffn_up_exp[x]);
}
}
} }
} }
} break; } break;
@ -4019,6 +4086,8 @@ struct llm_build_context {
const int64_t n_head_kv; const int64_t n_head_kv;
const int64_t n_embd_head; const int64_t n_embd_head;
const int64_t n_embd_gqa; const int64_t n_embd_gqa;
const int64_t n_expert;
const int64_t n_expert_used;
const float freq_base; const float freq_base;
const float freq_scale; const float freq_scale;
@ -4060,6 +4129,8 @@ struct llm_build_context {
n_head_kv (hparams.n_head_kv), n_head_kv (hparams.n_head_kv),
n_embd_head (hparams.n_embd_head()), n_embd_head (hparams.n_embd_head()),
n_embd_gqa (hparams.n_embd_gqa()), n_embd_gqa (hparams.n_embd_gqa()),
n_expert (hparams.n_expert),
n_expert_used (hparams.n_expert_used),
freq_base (cparams.rope_freq_base), freq_base (cparams.rope_freq_base),
freq_scale (cparams.rope_freq_scale), freq_scale (cparams.rope_freq_scale),
ext_factor (cparams.yarn_ext_factor), ext_factor (cparams.yarn_ext_factor),
@ -4184,7 +4255,7 @@ struct llm_build_context {
cb(ffn_inp, "ffn_inp", il); cb(ffn_inp, "ffn_inp", il);
// feed-forward network // feed-forward network
{ if (model.layers[il].ffn_gate_inp == nullptr) {
cur = llm_build_norm(ctx0, ffn_inp, hparams, cur = llm_build_norm(ctx0, ffn_inp, hparams,
model.layers[il].ffn_norm, NULL, model.layers[il].ffn_norm, NULL,
LLM_NORM_RMS, cb, il); LLM_NORM_RMS, cb, il);
@ -4196,6 +4267,69 @@ struct llm_build_context {
model.layers[il].ffn_down, NULL, model.layers[il].ffn_down, NULL,
LLM_FFN_SILU, LLM_FFN_PAR, cb, il); LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
cb(cur, "ffn_out", il); cb(cur, "ffn_out", il);
} else {
// MoE branch
cur = llm_build_norm(ctx0, ffn_inp, hparams,
model.layers[il].ffn_norm, NULL,
LLM_NORM_RMS, cb, il);
cb(cur, "ffn_norm", il);
ggml_tensor * logits = ggml_mul_mat(ctx0, model.layers[il].ffn_gate_inp, cur); // [n_tokens, num_experts]
cb(logits, "ffn_moe_logits", il);
ggml_tensor * probs = ggml_soft_max(ctx0, logits); // [n_tokens, num_experts]
cb(probs, "ffn_moe_probs", il);
// select experts
ggml_tensor * selected_experts = ggml_top_k(ctx0, probs, n_expert_used); // [n_tokens, num_experts_per_tok]
cb(selected_experts->src[0], "ffn_moe_argsort", il);
ggml_tensor * weights = ggml_get_rows(ctx0,
ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens), selected_experts);
cb(weights, "ffn_moe_weights", il);
weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens); // [n_tokens, num_experts_per_tok]
ggml_tensor * weights_sum = ggml_sum_rows(ctx0, weights);
cb(weights_sum, "ffn_moe_weights_sum", il);
weights = ggml_div(ctx0, weights, weights_sum); // [n_tokens, num_experts_per_tok]
cb(weights, "ffn_moe_weights_norm", il);
// compute expert outputs
ggml_tensor * moe_out = nullptr;
for (int i = 0; i < n_expert_used; ++i) {
ggml_tensor * cur_expert;
ggml_tensor * cur_up = ggml_mul_mat_id(ctx0, model.layers[il].ffn_up_exp, n_expert, selected_experts, i, cur);
cb(cur_up, "ffn_moe_up", il);
ggml_tensor * cur_gate = ggml_mul_mat_id(ctx0, model.layers[il].ffn_gate_exp, n_expert, selected_experts, i, cur);
cb(cur_gate, "ffn_moe_gate", il);
cur_gate = ggml_silu(ctx0, cur_gate);
cb(cur_gate, "ffn_moe_silu", il);
cur_expert = ggml_mul(ctx0, cur_up, cur_gate); // [n_tokens, n_embd]
cb(cur_expert, "ffn_moe_gate_par", il);
cur_expert = ggml_mul_mat_id(ctx0, model.layers[il].ffn_down_exp, n_expert, selected_experts, i, cur_expert); // [n_tokens, n_embd]
cb(cur_expert, "ffn_moe_down", il);
cur_expert = ggml_mul(ctx0, cur_expert,
ggml_view_2d(ctx0, weights, 1, n_tokens, weights->nb[1], i*weights->nb[0]));
cb(cur_expert, "ffn_moe_weighted", il);
if (i == 0) {
moe_out = cur_expert;
} else {
moe_out = ggml_add(ctx0, moe_out, cur_expert);
cb(moe_out, "ffn_moe_out", il);
}
}
cur = moe_out;
} }
cur = ggml_add(ctx0, cur, ffn_inp); cur = ggml_add(ctx0, cur, ffn_inp);
@ -5450,6 +5584,20 @@ static const std::unordered_map<const char *, llm_offload_func_e> k_offload_map
{ "ffn_relu", OFFLOAD_FUNC }, { "ffn_relu", OFFLOAD_FUNC },
{ "ffn_sqr(relu)", OFFLOAD_FUNC }, { "ffn_sqr(relu)", OFFLOAD_FUNC },
{ "ffn_moe_logits", OFFLOAD_FUNC },
{ "ffn_moe_probs", OFFLOAD_FUNC },
{ "ffn_moe_argsort", OFFLOAD_FUNC },
{ "ffn_moe_weights", OFFLOAD_FUNC },
{ "ffn_moe_weights_sum", OFFLOAD_FUNC },
{ "ffn_moe_weights_norm", OFFLOAD_FUNC },
{ "ffn_moe_weighted", OFFLOAD_FUNC },
{ "ffn_moe_up", OFFLOAD_FUNC },
{ "ffn_moe_gate", OFFLOAD_FUNC },
{ "ffn_moe_silu", OFFLOAD_FUNC },
{ "ffn_moe_gate_par", OFFLOAD_FUNC },
{ "ffn_moe_down", OFFLOAD_FUNC },
{ "ffn_moe_out", OFFLOAD_FUNC },
{ "l_out", OFFLOAD_FUNC }, { "l_out", OFFLOAD_FUNC },
{ "result_norm", OFFLOAD_FUNC_EMB }, { "result_norm", OFFLOAD_FUNC_EMB },
@ -5846,7 +5994,7 @@ static int llama_decode_internal(
const int64_t n_embd = hparams.n_embd; const int64_t n_embd = hparams.n_embd;
const int64_t n_vocab = hparams.n_vocab; const int64_t n_vocab = hparams.n_vocab;
// helpers for smoother batch API transistion // helpers for smoother batch API transition
// after deprecating the llama_eval calls, these will be removed // after deprecating the llama_eval calls, these will be removed
std::vector<llama_pos> pos; std::vector<llama_pos> pos;
@ -6625,12 +6773,12 @@ static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<
// loop over the text // loop over the text
while (true) { while (true) {
// find the first occurence of a given special token in this fragment // find the first occurrence of a given special token in this fragment
// passing offset argument only limit the "search area" but match coordinates // passing offset argument only limit the "search area" but match coordinates
// are still relative to the source full raw_text // are still relative to the source full raw_text
auto match = raw_text->find(special_token, raw_text_base_offset); auto match = raw_text->find(special_token, raw_text_base_offset);
// no occurences found, stop processing this fragment for a given special token // no occurrences found, stop processing this fragment for a given special token
if (match == std::string::npos) break; if (match == std::string::npos) break;
// check if match is within bounds of offset <-> length // check if match is within bounds of offset <-> length
@ -7829,7 +7977,7 @@ struct llama_beam_search_data {
} }
// Min-heaps are used to efficiently collect the top-k elements (k=n_beams). // Min-heaps are used to efficiently collect the top-k elements (k=n_beams).
// The repetative patterns below reflect the 2 stages of heaps: // The repetitive patterns below reflect the 2 stages of heaps:
// * Gather elements until the vector is full, then call std::make_heap() on it. // * Gather elements until the vector is full, then call std::make_heap() on it.
// * If the heap is full and a new element is found that should be included, pop the // * If the heap is full and a new element is found that should be included, pop the
// least element to the back(), replace it with the new, then push it into the heap. // least element to the back(), replace it with the new, then push it into the heap.
@ -8067,11 +8215,9 @@ static void llama_convert_tensor_internal(
workers.clear(); workers.clear();
} }
static ggml_type get_k_quant_type( static ggml_type get_k_quant_type(quantize_state_internal & qs, ggml_type new_type, const ggml_tensor * tensor, llama_ftype ftype) {
quantize_state_internal & qs,
ggml_type new_type, const ggml_tensor * tensor, llama_ftype ftype
) {
const std::string name = ggml_get_name(tensor); const std::string name = ggml_get_name(tensor);
// TODO: avoid hardcoded tensor names - use the TN_* constants // TODO: avoid hardcoded tensor names - use the TN_* constants
const llm_arch arch = qs.model.arch; const llm_arch arch = qs.model.arch;
const auto tn = LLM_TN(arch); const auto tn = LLM_TN(arch);
@ -8105,7 +8251,18 @@ static ggml_type get_k_quant_type(
// nearly negligible increase in model size by quantizing this tensor with more bits: // nearly negligible increase in model size by quantizing this tensor with more bits:
if (new_type == GGML_TYPE_Q3_K || new_type == GGML_TYPE_Q4_K) new_type = GGML_TYPE_Q5_K; if (new_type == GGML_TYPE_Q3_K || new_type == GGML_TYPE_Q4_K) new_type = GGML_TYPE_Q5_K;
} }
if (qs.model.hparams.n_expert == 8) {
// for the 8-expert model, bumping this to Q8_0 trades just ~128MB
// TODO: explore better strategies
new_type = GGML_TYPE_Q8_0;
}
++qs.i_attention_wv; ++qs.i_attention_wv;
} else if (name.find("attn_k.weight") != std::string::npos) {
if (qs.model.hparams.n_expert == 8) {
// for the 8-expert model, bumping this to Q8_0 trades just ~128MB
// TODO: explore better strategies
new_type = GGML_TYPE_Q8_0;
}
} else if (name.find("ffn_down.weight") != std::string::npos) { } else if (name.find("ffn_down.weight") != std::string::npos) {
if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q3_K; if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q3_K;
else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M) { else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M) {
@ -8318,6 +8475,9 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
quantize &= params->quantize_output_tensor || name != "output.weight"; quantize &= params->quantize_output_tensor || name != "output.weight";
quantize &= !params->only_copy; quantize &= !params->only_copy;
// do not quantize expert gating tensors
quantize &= name.find("ffn_gate_inp.weight") == std::string::npos;
enum ggml_type new_type; enum ggml_type new_type;
void * new_data; void * new_data;
size_t new_size; size_t new_size;

View File

@ -216,7 +216,7 @@ extern "C" {
// Keep the booleans together to avoid misalignment during copy-by-value. // Keep the booleans together to avoid misalignment during copy-by-value.
bool mul_mat_q; // if true, use experimental mul_mat_q kernels (DEPRECATED - always true) bool mul_mat_q; // if true, use experimental mul_mat_q kernels (DEPRECATED - always true)
bool logits_all; // the llama_eval() call computes all logits, not just the last one bool logits_all; // the llama_eval() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead)
bool embedding; // embedding mode only bool embedding; // embedding mode only
bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
}; };

View File

@ -20,8 +20,6 @@ static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float m
size_t size = ggml_nelements(tensor); size_t size = ggml_nelements(tensor);
std::vector<float> data(size); std::vector<float> data(size);
std::random_device rd;
#if 0 #if 0
std::default_random_engine generator(rd()); std::default_random_engine generator(rd());
std::uniform_real_distribution<float> distribution(min, max); std::uniform_real_distribution<float> distribution(min, max);
@ -31,6 +29,7 @@ static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float m
} }
#endif #endif
auto init_thread = [&](size_t start, size_t end) { auto init_thread = [&](size_t start, size_t end) {
std::random_device rd;
std::default_random_engine generator(rd()); std::default_random_engine generator(rd());
std::uniform_real_distribution<float> distribution(min, max); std::uniform_real_distribution<float> distribution(min, max);
@ -51,7 +50,7 @@ static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float m
t.join(); t.join();
} }
if (tensor->type == GGML_TYPE_F32) { if (tensor->type == GGML_TYPE_F32 || tensor->type == GGML_TYPE_I32) {
ggml_backend_tensor_set(tensor, data.data(), 0, size * sizeof(float)); ggml_backend_tensor_set(tensor, data.data(), 0, size * sizeof(float));
} else if (ggml_is_quantized(tensor->type) || tensor->type == GGML_TYPE_F16) { } else if (ggml_is_quantized(tensor->type) || tensor->type == GGML_TYPE_F16) {
GGML_ASSERT(size % ggml_blck_size(tensor->type) == 0); GGML_ASSERT(size % ggml_blck_size(tensor->type) == 0);
@ -71,23 +70,28 @@ static std::vector<float> tensor_to_float(const ggml_tensor * t) {
std::vector<uint8_t> buf(ggml_nbytes(t)); std::vector<uint8_t> buf(ggml_nbytes(t));
ggml_backend_tensor_get(t, buf.data(), 0, ggml_nbytes(t)); ggml_backend_tensor_get(t, buf.data(), 0, ggml_nbytes(t));
ggml_type_traits_t tt = ggml_internal_get_type_traits(t->type);
size_t bs = ggml_blck_size(t->type);
// access elements by index to avoid gaps in views // access elements by index to avoid gaps in views
for (int64_t i3 = 0; i3 < t->ne[3]; i3++) { for (int64_t i3 = 0; i3 < t->ne[3]; i3++) {
for (int64_t i2 = 0; i2 < t->ne[2]; i2++) { for (int64_t i2 = 0; i2 < t->ne[2]; i2++) {
for (int64_t i1 = 0; i1 < t->ne[1]; i1++) { for (int64_t i1 = 0; i1 < t->ne[1]; i1++) {
for (int64_t i0 = 0; i0 < t->ne[0]; i0++) { for (int64_t i0 = 0; i0 < t->ne[0]; i0 += bs) {
size_t i = i3*t->nb[3] + i2*t->nb[2] + i1*t->nb[1] + i0*t->nb[0]; size_t i = i3*t->nb[3] + i2*t->nb[2] + i1*t->nb[1] + i0/bs*t->nb[0];
float v;
if (t->type == GGML_TYPE_F16) { if (t->type == GGML_TYPE_F16) {
v = (float) ggml_fp16_to_fp32(*(ggml_fp16_t*)&buf[i]); tv.push_back(ggml_fp16_to_fp32(*(ggml_fp16_t*)&buf[i]));
} else if (t->type == GGML_TYPE_F32) { } else if (t->type == GGML_TYPE_F32) {
v = *(float *) &buf[i]; tv.push_back(*(float *) &buf[i]);
} else if (t->type == GGML_TYPE_I32) { } else if (t->type == GGML_TYPE_I32) {
v = *(int32_t *) &buf[i]; tv.push_back((float)*(int32_t *) &buf[i]);
} else if (ggml_is_quantized(t->type)) {
std::vector<float> vq(ggml_blck_size(t->type));
tt.to_float(&buf[i], vq.data(), ggml_blck_size(t->type));
tv.insert(tv.end(), vq.begin(), vq.end());
} else { } else {
GGML_ASSERT(false); GGML_ASSERT(false);
} }
tv.push_back(v);
} }
} }
} }
@ -233,6 +237,10 @@ static bool ggml_is_view_op(enum ggml_op op) {
struct test_case { struct test_case {
virtual ~test_case() {} virtual ~test_case() {}
virtual std::string op_desc(ggml_tensor * t) {
return ggml_op_desc(t);
}
virtual std::string vars() { virtual std::string vars() {
return ""; return "";
} }
@ -240,7 +248,7 @@ struct test_case {
virtual ggml_tensor * build_graph(ggml_context * ctx) = 0; virtual ggml_tensor * build_graph(ggml_context * ctx) = 0;
virtual double max_nmse_err() { virtual double max_nmse_err() {
return 1e-6; return 1e-7;
} }
virtual void initialize_tensors(ggml_context * ctx) { virtual void initialize_tensors(ggml_context * ctx) {
@ -270,13 +278,13 @@ struct test_case {
ggml_tensor * out = build_graph(ctx); ggml_tensor * out = build_graph(ctx);
if (op_name != nullptr && strcmp(ggml_op_desc(out), op_name) != 0) { if (op_name != nullptr && op_desc(out) != op_name) {
//printf(" %s: skipping\n", ggml_op_desc(out)); //printf(" %s: skipping\n", op_desc(out).c_str());
ggml_free(ctx); ggml_free(ctx);
return true; return true;
} }
printf(" %s(%s): ", ggml_op_desc(out), vars().c_str()); printf(" %s(%s): ", op_desc(out).c_str(), vars().c_str());
fflush(stdout); fflush(stdout);
// check if backends support op // check if backends support op
@ -317,7 +325,7 @@ struct test_case {
for (size_t i = 0; i < f1.size(); i++) { for (size_t i = 0; i < f1.size(); i++) {
// check for nans // check for nans
if (std::isnan(f1[i]) || std::isnan(f2[i])) { if (std::isnan(f1[i]) || std::isnan(f2[i])) {
printf("NaN at index %zu ", i); printf("[%s] NaN at index %zu (%f %f) ", ggml_op_desc(t1), i, f1[i], f2[i]);
ud->ok = false; ud->ok = false;
return true; return true;
} }
@ -325,12 +333,12 @@ struct test_case {
if (isinf_or_max(f1[i]) || isinf_or_max(f2[i])) { if (isinf_or_max(f1[i]) || isinf_or_max(f2[i])) {
if (isinf_or_max(f1[i]) && isinf_or_max(f2[i])) { if (isinf_or_max(f1[i]) && isinf_or_max(f2[i])) {
if (std::signbit(f1[i]) != std::signbit(f2[i])) { if (std::signbit(f1[i]) != std::signbit(f2[i])) {
printf("inf sign mismatch: %f %f ", f1[i], f2[i]); printf("[%s] inf sign mismatch: %f %f ", ggml_op_desc(t1), f1[i], f2[i]);
ud->ok = false; ud->ok = false;
return true; return true;
} }
} else { } else {
printf("inf mismatch: %f %f ", f1[i], f2[i]); printf("[%s] inf mismatch: %f %f ", ggml_op_desc(t1), f1[i], f2[i]);
ud->ok = false; ud->ok = false;
return true; return true;
} }
@ -339,10 +347,16 @@ struct test_case {
double err = nmse(f1.data(), f2.data(), f1.size()); double err = nmse(f1.data(), f2.data(), f1.size());
if (err > ud->max_err) { if (err > ud->max_err) {
printf("NMSE = %f ", err); printf("[%s] NMSE = %f ", ggml_op_desc(t1), err);
//for (int i = 0; i < f1.size(); i++) {
// printf("(%f, %f) ", f1[i], f2[i]);
//}
//printf("\n");
ud->ok = false; ud->ok = false;
} }
return true; return true;
GGML_UNUSED(index);
}; };
ggml_backend_compare_graph_backend(backend1, backend2, gf, callback, &ud); ggml_backend_compare_graph_backend(backend1, backend2, gf, callback, &ud);
@ -372,13 +386,13 @@ struct test_case {
ggml_tensor * out = build_graph(ctx); ggml_tensor * out = build_graph(ctx);
if (op_name != nullptr && strcmp(ggml_op_desc(out), op_name) != 0) { if (op_name != nullptr && op_desc(out) != op_name) {
//printf(" %s: skipping\n", ggml_op_desc(out)); //printf(" %s: skipping\n", op_desc(out).c_str());
ggml_free(ctx); ggml_free(ctx);
return true; return true;
} }
int len = printf(" %s(%s): ", ggml_op_desc(out), vars().c_str()); int len = printf(" %s(%s): ", op_desc(out).c_str(), vars().c_str());
fflush(stdout); fflush(stdout);
// check if backends support op // check if backends support op
@ -430,8 +444,9 @@ struct test_case {
return size; return size;
}; };
for (int i = 0; i < gf->n_nodes; i++) { for (int i = 0; i < gf->n_nodes; i++) {
if (ggml_is_view_op(gf->nodes[i]->op) || gf->nodes[i] == out) if (ggml_is_view_op(gf->nodes[i]->op) || gf->nodes[i] == out) {
continue; continue;
}
mem += tensor_op_size(gf->nodes[i]); mem += tensor_op_size(gf->nodes[i]);
} }
@ -486,17 +501,22 @@ struct test_get_rows : public test_case {
const int n; // cols const int n; // cols
const int m; // rows const int m; // rows
const int r; // rows to get const int r; // rows to get
const int b; // batch size
const bool v; // view (non-contiguous src1)
std::string vars() override { std::string vars() override {
return VARS_TO_STR4(type, n, m, r); return VARS_TO_STR6(type, n, m, r, b, v);
} }
test_get_rows(ggml_type type = GGML_TYPE_F32, int n = 10, int m = 5, int r = 3) test_get_rows(ggml_type type = GGML_TYPE_F32, int n = 10, int m = 5, int r = 3, int b = 1, bool v = false)
: type(type), n(n), m(m), r(r) {} : type(type), n(n), m(m), r(r), b(b), v(v) {}
ggml_tensor * build_graph(ggml_context * ctx) override { ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * in = ggml_new_tensor_2d(ctx, type, n, m); ggml_tensor * in = ggml_new_tensor_3d(ctx, type, n, m, b);
ggml_tensor * rows = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, r); ggml_tensor * rows = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, r, b);
if (v) {
rows = ggml_view_2d(ctx, rows, r/2, b, rows->nb[1], 0);
}
ggml_tensor * out = ggml_get_rows(ctx, in, rows); ggml_tensor * out = ggml_get_rows(ctx, in, rows);
return out; return out;
} }
@ -504,12 +524,13 @@ struct test_get_rows : public test_case {
void initialize_tensors(ggml_context * ctx) override { void initialize_tensors(ggml_context * ctx) override {
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
if (t->type == GGML_TYPE_I32) { if (t->type == GGML_TYPE_I32) {
if (ggml_is_view_op(t->op)) { continue; }
// rows // rows
std::vector<int> data(r); std::vector<int> data(r*b);
for (int i = 0; i < r; i++) { for (int i = 0; i < r*b; i++) {
data[i] = rand() % m; data[i] = rand() % m;
} }
ggml_backend_tensor_set(t, data.data(), 0, r * sizeof(int)); ggml_backend_tensor_set(t, data.data(), 0, r * b * sizeof(int));
} else { } else {
init_tensor_uniform(t); init_tensor_uniform(t);
} }
@ -770,11 +791,10 @@ struct test_mul_mat_id : public test_case {
const int64_t m; const int64_t m;
const int64_t n; const int64_t n;
const int64_t k; const int64_t k;
const std::array<int64_t, 2> bs; // dims 3 and 4 const bool v; // view (non-contiguous ids)
const std::array<int64_t, 2> nr; // repeat in dims 3 and 4
std::string vars() override { std::string vars() override {
return VARS_TO_STR9(type_a, type_b, n_mats, id, m, n, k, bs, nr); return VARS_TO_STR8(type_a, type_b, n_mats, id, m, n, k, v);
} }
double max_nmse_err() override { double max_nmse_err() override {
@ -782,7 +802,7 @@ struct test_mul_mat_id : public test_case {
} }
size_t op_size(ggml_tensor * t) override { size_t op_size(ggml_tensor * t) override {
size_t a = ggml_nbytes(t->src[2]) * n * nr[0] * nr[1]; size_t a = ggml_nbytes(t->src[2]) * n;
size_t b = ggml_nbytes(t->src[1]) * m; size_t b = ggml_nbytes(t->src[1]) * m;
size_t c = ggml_nbytes(t); size_t c = ggml_nbytes(t);
return a + b + c; return a + b + c;
@ -792,35 +812,41 @@ struct test_mul_mat_id : public test_case {
test_mul_mat_id(ggml_type type_a = GGML_TYPE_F32, ggml_type type_b = GGML_TYPE_F32, test_mul_mat_id(ggml_type type_a = GGML_TYPE_F32, ggml_type type_b = GGML_TYPE_F32,
int n_mats = 2, int id = 0, int n_mats = 2, int id = 0,
int64_t m = 32, int64_t n = 32, int64_t k = 32, int64_t m = 32, int64_t n = 32, int64_t k = 32, bool v = false)
std::array<int64_t, 2> bs = {10, 10},
std::array<int64_t, 2> nr = {2, 2})
: type_a(type_a), type_b(type_b), n_mats(n_mats), id(id), : type_a(type_a), type_b(type_b), n_mats(n_mats), id(id),
m(m), n(n), k(k), bs(bs), nr(nr) {} m(m), n(n), k(k), v(v) {}
ggml_tensor * build_graph(ggml_context * ctx) override { ggml_tensor * build_graph(ggml_context * ctx) override {
// C^T = A * B^T: (k, m) * (k, n) => (m, n) // C^T = A * B^T: (k, m) * (k, n) => (m, n)
std::vector<ggml_tensor *> mats; std::vector<ggml_tensor *> mats;
for (int i = 0; i < n_mats; i++) { for (int i = 0; i < n_mats; i++) {
ggml_tensor * a = ggml_new_tensor_4d(ctx, type_a, k, m, bs[0], bs[1]); ggml_tensor * a = ggml_new_tensor_2d(ctx, type_a, k, m);
mats.push_back(a); mats.push_back(a);
} }
ggml_tensor * ids = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_mats); ggml_tensor * ids = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, n_mats, n);
ggml_tensor * b = ggml_new_tensor_4d(ctx, type_b, k, n, bs[0]*nr[0], bs[1]*nr[1]); if (v) {
ggml_tensor * out = ggml_mul_mat_id(ctx, mats.data(), ids, id, b); ids = ggml_view_2d(ctx, ids, n_mats/2, ids->ne[1], ids->nb[1], 0);
}
ggml_tensor * b = ggml_new_tensor_2d(ctx, type_b, k, n);
ggml_tensor * out = ggml_mul_mat_id(ctx, mats.data(), n_mats, ids, v ? id/2 : id, b);
return out; return out;
} }
void initialize_tensors(ggml_context * ctx) override { void initialize_tensors(ggml_context * ctx) override {
std::random_device rd;
std::default_random_engine rng(rd());
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
if (t->type == GGML_TYPE_I32) { if (t->type == GGML_TYPE_I32) {
if (ggml_is_view_op(t->op)) { continue; }
// ids // ids
std::vector<int> data(n_mats); for (int64_t r = 0; r < ggml_nrows(t); r++) {
for (int i = 0; i < n_mats; i++) { std::vector<int32_t> data(t->ne[0]);
data[i] = i; for (int i = 0; i < t->ne[0]; i++) {
data[i] = i % n_mats;
}
std::shuffle(data.begin(), data.end(), rng);
ggml_backend_tensor_set(t, data.data(), r * t->nb[1], t->ne[0] * sizeof(int32_t));
} }
std::shuffle(data.begin(), data.end(), std::default_random_engine(std::random_device()()));
ggml_backend_tensor_set(t, data.data(), 0, n_mats * sizeof(int));
} else { } else {
init_tensor_uniform(t); init_tensor_uniform(t);
} }
@ -1109,6 +1135,90 @@ struct test_sum_rows : public test_case {
} }
}; };
// Mixtral MOE
struct test_moe : public test_case {
const int n_experts;
const int n_experts_per_tok;
const int n_tokens;
const int n_embd;
const int n_ff;
std::string op_desc(ggml_tensor * t) override {
return "MOE";
GGML_UNUSED(t);
}
std::string vars() override {
return VARS_TO_STR5(n_experts, n_experts_per_tok, n_tokens, n_embd, n_ff);
}
test_moe(int n_experts = 8, int n_experts_per_tok = 2, int n_tokens = 1, int n_embd = 4096, int n_ff = 14336)
: n_experts(n_experts), n_experts_per_tok(n_experts_per_tok), n_tokens(n_tokens), n_embd(n_embd), n_ff(n_ff) {
}
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * ffn_gate_inp = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_experts);
std::vector<ggml_tensor *> ffn_up_exp(n_experts);
std::vector<ggml_tensor *> ffn_gate_exp(n_experts);
std::vector<ggml_tensor *> ffn_down_exp(n_experts);
for (int i = 0; i < n_experts; ++i) {
ffn_up_exp[i] = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_ff);
ffn_gate_exp[i] = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_ff);
ffn_down_exp[i] = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_ff, n_embd);
}
ggml_tensor * cur = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_tokens);
ggml_tensor * logits = ggml_mul_mat(ctx, ffn_gate_inp, cur);
ggml_tensor * probs = ggml_soft_max_ext(ctx, logits, nullptr, 1.0f/sqrtf(n_embd));
// select experts
ggml_tensor * selected_experts = ggml_top_k(ctx, probs, n_experts_per_tok);
ggml_tensor * weights = ggml_get_rows(ctx,
ggml_reshape_3d(ctx, probs, 1, n_experts, n_tokens), selected_experts);
weights = ggml_reshape_2d(ctx, weights, n_experts_per_tok, n_tokens);
ggml_tensor * weights_sum = ggml_sum_rows(ctx, weights);
weights = ggml_div(ctx, weights, weights_sum);
// compute expert outputs
ggml_tensor * moe_out = nullptr;
for (int i = 0; i < n_experts_per_tok; ++i) {
ggml_tensor * cur_expert;
ggml_tensor * cur_up = ggml_mul_mat_id(ctx, ffn_up_exp.data(), n_experts, selected_experts, i, cur);
ggml_tensor * cur_gate = ggml_mul_mat_id(ctx, ffn_gate_exp.data(), n_experts, selected_experts, i, cur);
cur_gate = ggml_silu(ctx, cur_gate);
cur_expert = ggml_mul(ctx, cur_up, cur_gate);
cur_expert = ggml_mul_mat_id(ctx, ffn_down_exp.data(), n_experts, selected_experts, i, cur_expert);
cur_expert = ggml_mul(ctx, cur_expert,
ggml_view_2d(ctx, weights, 1, n_tokens, weights->nb[1], i*weights->nb[0]));
if (i == 0) {
moe_out = cur_expert;
} else {
moe_out = ggml_add(ctx, moe_out, cur_expert);
}
}
cur = moe_out;
return cur;
}
};
enum test_mode { enum test_mode {
MODE_TEST, MODE_TEST,
MODE_PERF, MODE_PERF,
@ -1117,14 +1227,28 @@ enum test_mode {
static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op_name) { static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op_name) {
std::vector<std::unique_ptr<test_case>> test_cases; std::vector<std::unique_ptr<test_case>> test_cases;
const ggml_type all_types[] = {
GGML_TYPE_F32, GGML_TYPE_F16,
GGML_TYPE_Q4_0, GGML_TYPE_Q4_1,
GGML_TYPE_Q5_0, GGML_TYPE_Q5_1,
GGML_TYPE_Q8_0,
GGML_TYPE_Q2_K, GGML_TYPE_Q3_K,
GGML_TYPE_Q4_K, GGML_TYPE_Q5_K,
GGML_TYPE_Q6_K
};
// unary ops // unary ops
for (int op = 0; op < GGML_UNARY_OP_COUNT; op++) { for (int op = 0; op < GGML_UNARY_OP_COUNT; op++) {
test_cases.emplace_back(new test_unary((ggml_unary_op) op)); test_cases.emplace_back(new test_unary((ggml_unary_op) op));
} }
for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) { test_cases.emplace_back(new test_get_rows(GGML_TYPE_F32, 1, 8, 2, 1, false));
test_cases.emplace_back(new test_get_rows(type, 10, 5, 3)); for (ggml_type type : all_types) {
test_cases.emplace_back(new test_get_rows(type, 16, 5, 3)); for (int b : {1, 7}) {
for (bool v : {false, true}) {
test_cases.emplace_back(new test_get_rows(type, 256, 5, 4, b, v));
}
}
} }
test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 10, 10, 10}, {1, 1, 1, 1})); test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 10, 10, 10}, {1, 1, 1, 1}));
@ -1134,7 +1258,11 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 10, 10, 10}, {1, 1, 1, 2})); test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 10, 10, 10}, {1, 1, 1, 2}));
test_cases.emplace_back(new test_dup()); test_cases.emplace_back(new test_dup());
test_cases.emplace_back(new test_cpy());
for (ggml_type type : all_types) {
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, type, {256, 10, 10, 1}));
}
test_cases.emplace_back(new test_cont()); test_cases.emplace_back(new test_cont());
auto add_test_bin_bcast = [&](ggml_type type, std::array<int64_t, 4> ne, std::array<int, 4> nr) { auto add_test_bin_bcast = [&](ggml_type type, std::array<int64_t, 4> ne, std::array<int, 4> nr) {
@ -1144,6 +1272,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
}; };
add_test_bin_bcast(GGML_TYPE_F32, {1, 1, 8, 1}, {1, 1, 1, 1}); add_test_bin_bcast(GGML_TYPE_F32, {1, 1, 8, 1}, {1, 1, 1, 1});
add_test_bin_bcast(GGML_TYPE_F32, {1, 1, 1, 1}, {32, 1, 1, 1});
add_test_bin_bcast(GGML_TYPE_F32, {1, 1, 320, 320}, {1, 1, 1, 1}); add_test_bin_bcast(GGML_TYPE_F32, {1, 1, 320, 320}, {1, 1, 1, 1});
add_test_bin_bcast(GGML_TYPE_F32, {16, 10, 1, 1}, {1, 1, 1, 1}); add_test_bin_bcast(GGML_TYPE_F32, {16, 10, 1, 1}, {1, 1, 1, 1});
add_test_bin_bcast(GGML_TYPE_F32, {16, 10, 10, 1}, {1, 1, 1, 1}); add_test_bin_bcast(GGML_TYPE_F32, {16, 10, 10, 1}, {1, 1, 1, 1});
@ -1170,8 +1299,8 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
add_test_bin_bcast(GGML_TYPE_F32, {1, 1, 640, 1}, {32, 32, 1, 1}); add_test_bin_bcast(GGML_TYPE_F32, {1, 1, 640, 1}, {32, 32, 1, 1});
add_test_bin_bcast(GGML_TYPE_F32, {5120, 1, 1, 1}, {1, 256, 1, 1}); add_test_bin_bcast(GGML_TYPE_F32, {5120, 1, 1, 1}, {1, 256, 1, 1});
add_test_bin_bcast(GGML_TYPE_F32, {640, 1, 1, 1}, {1, 1, 1, 1}); add_test_bin_bcast(GGML_TYPE_F32, {640, 1, 1, 1}, {1, 1, 1, 1});
add_test_bin_bcast(GGML_TYPE_F32, {3, 3, 2560, 1280}, {1, 1, 1, 1}); //add_test_bin_bcast(GGML_TYPE_F32, {3, 3, 2560, 1280}, {1, 1, 1, 1});
add_test_bin_bcast(GGML_TYPE_F32, {3, 3, 2560, 1280}, {2, 1, 1, 1}); //add_test_bin_bcast(GGML_TYPE_F32, {3, 3, 2560, 1280}, {2, 1, 1, 1});
test_cases.emplace_back(new test_scale()); test_cases.emplace_back(new test_scale());
@ -1180,16 +1309,6 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
test_cases.emplace_back(new test_rms_norm(GGML_TYPE_F32, {64, 10, 10, 10}, eps)); test_cases.emplace_back(new test_rms_norm(GGML_TYPE_F32, {64, 10, 10, 10}, eps));
} }
const ggml_type all_types[] = {
GGML_TYPE_F32, GGML_TYPE_F16,
GGML_TYPE_Q4_0, GGML_TYPE_Q4_1,
GGML_TYPE_Q5_0, GGML_TYPE_Q5_1,
GGML_TYPE_Q8_0,
GGML_TYPE_Q2_K, GGML_TYPE_Q3_K,
GGML_TYPE_Q4_K, GGML_TYPE_Q5_K,
GGML_TYPE_Q6_K
};
for (ggml_type type_a : all_types) { for (ggml_type type_a : all_types) {
for (ggml_type type_b : {GGML_TYPE_F32 /*, GGML_TYPE_F16 */}) { for (ggml_type type_b : {GGML_TYPE_F32 /*, GGML_TYPE_F16 */}) {
// FIXME: CPU crashes on f16xf16 // FIXME: CPU crashes on f16xf16
@ -1213,9 +1332,11 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
for (ggml_type type_a : all_types) { for (ggml_type type_a : all_types) {
for (ggml_type type_b : {GGML_TYPE_F32 /*, GGML_TYPE_F16 */}) { for (ggml_type type_b : {GGML_TYPE_F32 /*, GGML_TYPE_F16 */}) {
for (int n_mats : {1, 2, 4}) { for (int n_mats : {2, 4, 8}) {
for (int id = 0; id < n_mats; id++) { for (int id = 0; id < n_mats; id++) {
test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, n_mats, id, 16, 16, 256, {1, 1}, {1, 1})); for (bool v : {false, true}) {
test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, n_mats, id, 16, 16, 256, v));
}
} }
} }
} }
@ -1247,10 +1368,18 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
test_cases.emplace_back(new test_concat()); test_cases.emplace_back(new test_concat());
for (ggml_sort_order order : {GGML_SORT_ASC, GGML_SORT_DESC}) { for (ggml_sort_order order : {GGML_SORT_ASC, GGML_SORT_DESC}) {
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {8, 1, 1, 1}, order));
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {16, 10, 10, 10}, order)); test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {16, 10, 10, 10}, order));
} }
test_cases.emplace_back(new test_sum_rows()); test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, {10, 10, 10, 10}));
test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, {2, 1, 1, 1}));
#if !defined(__SANITIZE_THREAD__)
// FIXME: these tests use too much memory with thread sanitizer
test_cases.emplace_back(new test_moe(8, 2, 1, 4096, 14336));
//test_cases.emplace_back(new test_moe(8, 2, 8, 4096, 14336));
#endif
// run tests // run tests
if (mode == MODE_TEST) { if (mode == MODE_TEST) {
@ -1267,14 +1396,17 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
ggml_backend_free(backend_cpu); ggml_backend_free(backend_cpu);
return n_ok == test_cases.size(); return n_ok == test_cases.size();
} else if (mode == MODE_PERF) { }
if (mode == MODE_PERF) {
for (auto & test : test_cases) { for (auto & test : test_cases) {
test->eval_perf(backend, op_name); test->eval_perf(backend, op_name);
} }
return true; return true;
} else {
GGML_ASSERT(false);
} }
GGML_ASSERT(false);
return false;
} }
static void usage(char ** argv) { static void usage(char ** argv) {
@ -1347,11 +1479,12 @@ int main(int argc, char ** argv) {
} }
printf("%zu/%zu backends passed\n", n_ok, ggml_backend_reg_get_count()); printf("%zu/%zu backends passed\n", n_ok, ggml_backend_reg_get_count());
if (n_ok != ggml_backend_reg_get_count()) { if (n_ok != ggml_backend_reg_get_count()) {
printf("\033[1;31mFAIL\033[0m\n"); printf("\033[1;31mFAIL\033[0m\n");
return 1; return 1;
} else {
printf("\033[1;32mOK\033[0m\n");
return 0;
} }
printf("\033[1;32mOK\033[0m\n");
return 0;
} }

View File

@ -1,4 +1,4 @@
#define _CRT_SECURE_NO_DEPRECATE // Disables ridiculous "unsafe" warnigns on Windows #define _CRT_SECURE_NO_DEPRECATE // Disables ridiculous "unsafe" warnings on Windows
#include "ggml.h" #include "ggml.h"
#include <cmath> #include <cmath>

View File

@ -117,7 +117,7 @@ static void usage(char * argv[]) {
printf(" --size SIZE set test size, divisible by 32 (L1_SIZE:%d)\n", L1_SIZE); printf(" --size SIZE set test size, divisible by 32 (L1_SIZE:%d)\n", L1_SIZE);
printf(" -3 use size as L1, L2, L3 sizes (L1:%d L2:%d L3:%d)\n", L1_SIZE, L2_SIZE, L3_SIZE); printf(" -3 use size as L1, L2, L3 sizes (L1:%d L2:%d L3:%d)\n", L1_SIZE, L2_SIZE, L3_SIZE);
printf(" -4 use size as L1, L2, L3, MEM sizes (L1:%d L2:%d L3:%d MEM:%d)\n", L1_SIZE, L2_SIZE, L3_SIZE, MEM_SIZE); printf(" -4 use size as L1, L2, L3, MEM sizes (L1:%d L2:%d L3:%d MEM:%d)\n", L1_SIZE, L2_SIZE, L3_SIZE, MEM_SIZE);
printf(" --op OP set test opration as quantize_row_q_reference, quantize_row_q, dequantize_row_q,\n"); printf(" --op OP set test operation as quantize_row_q_reference, quantize_row_q, dequantize_row_q,\n");
printf(" quantize_row_q_dot, vec_dot_q (all)\n"); printf(" quantize_row_q_dot, vec_dot_q (all)\n");
printf(" --type TYPE set test type as"); printf(" --type TYPE set test type as");
for (int i = 0; i < GGML_TYPE_COUNT; i++) { for (int i = 0; i < GGML_TYPE_COUNT; i++) {
@ -202,7 +202,7 @@ int main(int argc, char * argv[]) {
} }
int alignment = std::stoi(argv[i]); int alignment = std::stoi(argv[i]);
if (alignment < 0 || alignment > MAX_ALIGNMENT) { if (alignment < 0 || alignment > MAX_ALIGNMENT) {
fprintf(stderr, "error: aligment-offset must be less than %d\n", MAX_ALIGNMENT); fprintf(stderr, "error: alignment-offset must be less than %d\n", MAX_ALIGNMENT);
invalid_param = true; invalid_param = true;
break; break;
} }