Merge branch 'master' into xsn/fix_lora

This commit is contained in:
ngxson 2024-07-15 11:06:47 +02:00
commit 703573f608
51 changed files with 2212 additions and 146823 deletions

View File

@ -18,6 +18,7 @@
vulkan-headers, vulkan-headers,
vulkan-loader, vulkan-loader,
curl, curl,
shaderc,
useBlas ? builtins.all (x: !x) [ useBlas ? builtins.all (x: !x) [
useCuda useCuda
useMetalKit useMetalKit
@ -146,6 +147,7 @@ let
vulkanBuildInputs = [ vulkanBuildInputs = [
vulkan-headers vulkan-headers
vulkan-loader vulkan-loader
shaderc
]; ];
in in

View File

@ -8,7 +8,7 @@ arg1="$1"
shift shift
if [[ "$arg1" == '--convert' || "$arg1" == '-c' ]]; then if [[ "$arg1" == '--convert' || "$arg1" == '-c' ]]; then
python3 ./convert-hf-to-gguf.py "$@" python3 ./convert_hf_to_gguf.py "$@"
elif [[ "$arg1" == '--quantize' || "$arg1" == '-q' ]]; then elif [[ "$arg1" == '--quantize' || "$arg1" == '-q' ]]; then
./llama-quantize "$@" ./llama-quantize "$@"
elif [[ "$arg1" == '--run' || "$arg1" == '-r' ]]; then elif [[ "$arg1" == '--run' || "$arg1" == '-r' ]]; then

View File

@ -355,8 +355,10 @@ jobs:
- name: Dependencies - name: Dependencies
id: depends id: depends
run: | run: |
sudo apt-get update wget -qO - https://packages.lunarg.com/lunarg-signing-key-pub.asc | sudo apt-key add -
sudo apt-get install build-essential libvulkan-dev sudo wget -qO /etc/apt/sources.list.d/lunarg-vulkan-jammy.list https://packages.lunarg.com/vulkan/lunarg-vulkan-jammy.list
sudo apt-get update -y
sudo apt-get install -y build-essential vulkan-sdk
- name: Build - name: Build
id: cmake_build id: cmake_build

5
.gitignore vendored
View File

@ -61,6 +61,11 @@ llama-batched-swift
out/ out/
tmp/ tmp/
# Deprecated
/main
/server
# CI # CI
!.github/workflows/*.yml !.github/workflows/*.yml

View File

@ -132,7 +132,16 @@ set(LLAMA_INCLUDE_INSTALL_DIR ${CMAKE_INSTALL_INCLUDEDIR} CACHE PATH "Location o
set(LLAMA_LIB_INSTALL_DIR ${CMAKE_INSTALL_LIBDIR} CACHE PATH "Location of library files") set(LLAMA_LIB_INSTALL_DIR ${CMAKE_INSTALL_LIBDIR} CACHE PATH "Location of library files")
set(LLAMA_BIN_INSTALL_DIR ${CMAKE_INSTALL_BINDIR} CACHE PATH "Location of binary files") set(LLAMA_BIN_INSTALL_DIR ${CMAKE_INSTALL_BINDIR} CACHE PATH "Location of binary files")
get_directory_property(LLAMA_TRANSIENT_DEFINES COMPILE_DEFINITIONS)
# At the moment some compile definitions are placed within the ggml/src
# directory but not exported on the `ggml` target. This could be improved by
# determining _precisely_ which defines are necessary for the llama-config
# package.
#
get_directory_property(GGML_DIR_DEFINES DIRECTORY ggml/src COMPILE_DEFINITIONS)
get_target_property(GGML_TARGET_DEFINES ggml COMPILE_DEFINITIONS)
set(GGML_TRANSIENT_DEFINES ${GGML_TARGET_DEFINES} ${GGML_DIR_DEFINES})
get_target_property(GGML_LINK_LIBRARIES ggml LINK_LIBRARIES)
set_target_properties(llama PROPERTIES PUBLIC_HEADER ${CMAKE_CURRENT_SOURCE_DIR}/include/llama.h) set_target_properties(llama PROPERTIES PUBLIC_HEADER ${CMAKE_CURRENT_SOURCE_DIR}/include/llama.h)
install(TARGETS llama LIBRARY PUBLIC_HEADER) install(TARGETS llama LIBRARY PUBLIC_HEADER)

View File

@ -197,6 +197,10 @@ ifdef GGML_RPC
BUILD_TARGETS += rpc-server BUILD_TARGETS += rpc-server
endif endif
ifdef GGML_VULKAN
BUILD_TARGETS += vulkan-shaders-gen
endif
default: $(BUILD_TARGETS) $(LEGACY_TARGETS_BUILD) default: $(BUILD_TARGETS) $(LEGACY_TARGETS_BUILD)
test: $(TEST_TARGETS) test: $(TEST_TARGETS)
@ -547,11 +551,17 @@ ifdef GGML_OPENBLAS64
endif # GGML_OPENBLAS64 endif # GGML_OPENBLAS64
ifdef GGML_BLIS ifdef GGML_BLIS
MK_CPPFLAGS += -DGGML_USE_BLAS -I/usr/local/include/blis -I/usr/include/blis MK_CPPFLAGS += -DGGML_USE_BLAS -DGGML_BLAS_USE_BLIS -I/usr/local/include/blis -I/usr/include/blis
MK_LDFLAGS += -lblis -L/usr/local/lib MK_LDFLAGS += -lblis -L/usr/local/lib
OBJ_GGML += ggml/src/ggml-blas.o OBJ_GGML += ggml/src/ggml-blas.o
endif # GGML_BLIS endif # GGML_BLIS
ifdef GGML_NVPL
MK_CPPFLAGS += -DGGML_USE_BLAS -DGGML_BLAS_USE_NVPL -DNVPL_ILP64 -I/usr/local/include/nvpl_blas -I/usr/include/nvpl_blas
MK_LDFLAGS += -L/usr/local/lib -lnvpl_blas_core -lnvpl_blas_ilp64_gomp
OBJ_GGML += ggml/src/ggml-blas.o
endif # GGML_NVPL
ifndef GGML_NO_LLAMAFILE ifndef GGML_NO_LLAMAFILE
MK_CPPFLAGS += -DGGML_USE_LLAMAFILE MK_CPPFLAGS += -DGGML_USE_LLAMAFILE
OBJ_GGML += ggml/src/llamafile/sgemm.o OBJ_GGML += ggml/src/llamafile/sgemm.o
@ -704,8 +714,8 @@ endif # GGML_CUDA
ifdef GGML_VULKAN ifdef GGML_VULKAN
MK_CPPFLAGS += -DGGML_USE_VULKAN MK_CPPFLAGS += -DGGML_USE_VULKAN
MK_LDFLAGS += -lvulkan MK_LDFLAGS += $(shell pkg-config --libs vulkan)
OBJ_GGML += ggml/src/ggml-vulkan.o OBJ_GGML += ggml/src/ggml-vulkan.o ggml/src/ggml-vulkan-shaders.o
ifdef GGML_VULKAN_CHECK_RESULTS ifdef GGML_VULKAN_CHECK_RESULTS
MK_CPPFLAGS += -DGGML_VULKAN_CHECK_RESULTS MK_CPPFLAGS += -DGGML_VULKAN_CHECK_RESULTS
@ -727,10 +737,28 @@ ifdef GGML_VULKAN_RUN_TESTS
MK_CPPFLAGS += -DGGML_VULKAN_RUN_TESTS MK_CPPFLAGS += -DGGML_VULKAN_RUN_TESTS
endif endif
ggml/src/ggml-vulkan.o: \ GLSLC_CMD = glslc
ggml/src/ggml-vulkan.cpp \ _ggml_vk_genshaders_cmd = $(shell pwd)/vulkan-shaders-gen
ggml/include/ggml-vulkan.h _ggml_vk_header = ggml/src/ggml-vulkan-shaders.hpp
$(CXX) $(CXXFLAGS) -c $< -o $@ _ggml_vk_source = ggml/src/ggml-vulkan-shaders.cpp
_ggml_vk_input_dir = ggml/src/vulkan-shaders
_ggml_vk_shader_deps = $(echo $(_ggml_vk_input_dir)/*.comp)
ggml/src/ggml-vulkan.o: ggml/src/ggml-vulkan.cpp ggml/include/ggml-vulkan.h $(_ggml_vk_header) $(_ggml_vk_source)
$(CXX) $(CXXFLAGS) $(shell pkg-config --cflags vulkan) -c $< -o $@
$(_ggml_vk_header): $(_ggml_vk_source)
$(_ggml_vk_source): $(_ggml_vk_shader_deps) vulkan-shaders-gen
$(_ggml_vk_genshaders_cmd) \
--glslc $(GLSLC_CMD) \
--input-dir $(_ggml_vk_input_dir) \
--target-hpp $(_ggml_vk_header) \
--target-cpp $(_ggml_vk_source)
vulkan-shaders-gen: ggml/src/vulkan-shaders/vulkan-shaders-gen.cpp
$(CXX) $(CXXFLAGS) -o $@ $(LDFLAGS) ggml/src/vulkan-shaders/vulkan-shaders-gen.cpp
endif # GGML_VULKAN endif # GGML_VULKAN
ifdef GGML_HIPBLAS ifdef GGML_HIPBLAS
@ -1110,6 +1138,7 @@ clean:
rm -vrf ggml/src/ggml-cuda/template-instances/*.o rm -vrf ggml/src/ggml-cuda/template-instances/*.o
rm -rvf $(BUILD_TARGETS) rm -rvf $(BUILD_TARGETS)
rm -rvf $(TEST_TARGETS) rm -rvf $(TEST_TARGETS)
rm -f vulkan-shaders-gen ggml/src/ggml-vulkan-shaders.hpp ggml/src/ggml-vulkan-shaders.cpp
rm -rvf $(LEGACY_TARGETS_CLEAN) rm -rvf $(LEGACY_TARGETS_CLEAN)
find examples pocs -type f -name "*.o" -delete find examples pocs -type f -name "*.o" -delete

View File

@ -8,6 +8,13 @@ set(GGML_CUDA @GGML_CUDA@)
set(GGML_METAL @GGML_METAL@) set(GGML_METAL @GGML_METAL@)
set(GGML_HIPBLAS @GGML_HIPBLAS@) set(GGML_HIPBLAS @GGML_HIPBLAS@)
set(GGML_ACCELERATE @GGML_ACCELERATE@) set(GGML_ACCELERATE @GGML_ACCELERATE@)
set(GGML_VULKAN @GGML_VULKAN@)
set(GGML_VULKAN_CHECK_RESULTS @GGML_VULKAN_CHECK_RESULTS@)
set(GGML_VULKAN_DEBUG @GGML_VULKAN_DEBUG@)
set(GGML_VULKAN_MEMORY_DEBUG @GGML_VULKAN_MEMORY_DEBUG@)
set(GGML_VULKAN_VALIDATE @GGML_VULKAN_VALIDATE@)
set(GGML_SYCL @GGML_SYCL@)
set(GGML_OPENMP @GGML_OPENMP@)
@PACKAGE_INIT@ @PACKAGE_INIT@
@ -37,18 +44,36 @@ if (GGML_METAL)
find_library(METALKIT_FRAMEWORK MetalKit REQUIRED) find_library(METALKIT_FRAMEWORK MetalKit REQUIRED)
endif() endif()
if (GGML_VULKAN)
find_package(Vulkan REQUIRED)
endif()
if (GGML_HIPBLAS) if (GGML_HIPBLAS)
find_package(hip REQUIRED) find_package(hip REQUIRED)
find_package(hipblas REQUIRED) find_package(hipblas REQUIRED)
find_package(rocblas REQUIRED) find_package(rocblas REQUIRED)
endif() endif()
if (GGML_SYCL)
find_package(IntelSYCL REQUIRED)
find_package(MKL REQUIRED)
endif()
if (GGML_OPENMP)
find_package(OpenMP REQUIRED)
endif()
find_library(ggml_LIBRARY ggml
REQUIRED
HINTS ${LLAMA_LIB_DIR})
find_library(llama_LIBRARY llama find_library(llama_LIBRARY llama
REQUIRED REQUIRED
HINTS ${LLAMA_LIB_DIR}) HINTS ${LLAMA_LIB_DIR})
set(_llama_link_deps "Threads::Threads" "@LLAMA_EXTRA_LIBS@") set(_llama_link_deps "${ggml_LIBRARY}" "@GGML_LINK_LIBRARIES@")
set(_llama_transient_defines "@LLAMA_TRANSIENT_DEFINES@") set(_llama_transient_defines "@GGML_TRANSIENT_DEFINES@")
add_library(llama UNKNOWN IMPORTED) add_library(llama UNKNOWN IMPORTED)

View File

@ -373,6 +373,29 @@ class Model:
except KeyError: except KeyError:
raise NotImplementedError(f'Architecture {arch!r} not supported!') from None raise NotImplementedError(f'Architecture {arch!r} not supported!') from None
def does_token_look_special(self, token: str | bytes) -> bool:
if isinstance(token, (bytes, bytearray)):
token_text = token.decode(encoding="utf-8")
elif isinstance(token, memoryview):
token_text = token.tobytes().decode(encoding="utf-8")
else:
token_text = token
# Some models mark some added tokens which ought to be control tokens as not special.
# (e.g. command-r, command-r-plus, deepseek-coder, gemma{,-2})
seems_special = token_text in (
"<pad>", # deepseek-coder
"<mask>", "<2mass>", "[@BOS@]", # gemma{,-2}
)
seems_special = seems_special or (token_text.startswith("<|") and token_text.endswith("|>"))
seems_special = seems_special or (token_text.startswith("<") and token_text.endswith(">")) # deepseek-coder
# TODO: should these be marked as UNUSED instead? (maybe not)
seems_special = seems_special or (token_text.startswith("<unused") and token_text.endswith(">")) # gemma{,-2}
return seems_special
# used for GPT-2 BPE and WordPiece vocabs # used for GPT-2 BPE and WordPiece vocabs
def get_vocab_base(self) -> tuple[list[str], list[int], str]: def get_vocab_base(self) -> tuple[list[str], list[int], str]:
tokens: list[str] = [] tokens: list[str] = []
@ -391,16 +414,18 @@ class Model:
for i in range(vocab_size): for i in range(vocab_size):
if i not in reverse_vocab: if i not in reverse_vocab:
tokens.append(f"[PAD{i}]") tokens.append(f"[PAD{i}]")
toktypes.append(gguf.TokenType.USER_DEFINED) toktypes.append(gguf.TokenType.UNUSED)
elif reverse_vocab[i] in added_vocab: else:
tokens.append(reverse_vocab[i]) token: str = reverse_vocab[i]
if tokenizer.added_tokens_decoder[i].special: if token in added_vocab:
if tokenizer.added_tokens_decoder[i].special or self.does_token_look_special(token):
toktypes.append(gguf.TokenType.CONTROL) toktypes.append(gguf.TokenType.CONTROL)
else: else:
token = token.replace(b"\xe2\x96\x81".decode("utf-8"), " ") # pre-normalize user-defined spaces
toktypes.append(gguf.TokenType.USER_DEFINED) toktypes.append(gguf.TokenType.USER_DEFINED)
else: else:
tokens.append(reverse_vocab[i])
toktypes.append(gguf.TokenType.NORMAL) toktypes.append(gguf.TokenType.NORMAL)
tokens.append(token)
return tokens, toktypes, tokpre return tokens, toktypes, tokpre
@ -559,7 +584,7 @@ class Model:
for i in range(vocab_size): for i in range(vocab_size):
if i not in reverse_vocab: if i not in reverse_vocab:
tokens.append(f"[PAD{i}]") tokens.append(f"[PAD{i}]")
toktypes.append(gguf.TokenType.USER_DEFINED) toktypes.append(gguf.TokenType.UNUSED)
elif reverse_vocab[i] in added_vocab: elif reverse_vocab[i] in added_vocab:
tokens.append(reverse_vocab[i]) tokens.append(reverse_vocab[i])
toktypes.append(gguf.TokenType.CONTROL) toktypes.append(gguf.TokenType.CONTROL)
@ -609,7 +634,7 @@ class Model:
tokens: list[bytes] = [f"[PAD{i}]".encode("utf-8") for i in range(vocab_size)] tokens: list[bytes] = [f"[PAD{i}]".encode("utf-8") for i in range(vocab_size)]
scores: list[float] = [-10000.0] * vocab_size scores: list[float] = [-10000.0] * vocab_size
toktypes: list[int] = [SentencePieceTokenTypes.UNKNOWN] * vocab_size toktypes: list[int] = [SentencePieceTokenTypes.UNUSED] * vocab_size
for token_id in range(tokenizer.vocab_size()): for token_id in range(tokenizer.vocab_size()):
piece = tokenizer.IdToPiece(token_id) piece = tokenizer.IdToPiece(token_id)
@ -644,6 +669,25 @@ class Model:
scores[token_id] = -1000.0 scores[token_id] = -1000.0
toktypes[token_id] = SentencePieceTokenTypes.USER_DEFINED toktypes[token_id] = SentencePieceTokenTypes.USER_DEFINED
tokenizer_config_file = self.dir_model / 'tokenizer_config.json'
if tokenizer_config_file.is_file():
with open(tokenizer_config_file, "r", encoding="utf-8") as f:
tokenizer_config_json = json.load(f)
added_tokens_decoder = tokenizer_config_json.get("added_tokens_decoder", {})
for token_id, token_data in added_tokens_decoder.items():
token_id = int(token_id)
token: str = token_data["content"]
if toktypes[token_id] != SentencePieceTokenTypes.UNUSED:
assert tokens[token_id] == token.encode("utf-8")
if token_data.get("special") or self.does_token_look_special(token):
toktypes[token_id] = SentencePieceTokenTypes.CONTROL
else:
token = token.replace(b"\xe2\x96\x81".decode("utf-8"), " ") # pre-normalize user-defined spaces
toktypes[token_id] = SentencePieceTokenTypes.USER_DEFINED
scores[token_id] = -1000.0
tokens[token_id] = token.encode("utf-8")
if vocab_size > len(tokens): if vocab_size > len(tokens):
pad_count = vocab_size - len(tokens) pad_count = vocab_size - len(tokens)
logger.debug(f"Padding vocab with {pad_count} token(s) - [PAD1] through [PAD{pad_count}]") logger.debug(f"Padding vocab with {pad_count} token(s) - [PAD1] through [PAD{pad_count}]")
@ -1203,11 +1247,10 @@ class RefactModel(Model):
# TODO: how to determine special FIM tokens automatically? # TODO: how to determine special FIM tokens automatically?
special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=False, special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=False,
special_token_types = ['prefix', 'suffix', 'middle', 'fsep', 'eot']) special_token_types = ['prefix', 'suffix', 'middle', 'eot'])
special_vocab._set_special_token("prefix", 1) special_vocab._set_special_token("prefix", 1)
special_vocab._set_special_token("suffix", 3) special_vocab._set_special_token("suffix", 3)
special_vocab._set_special_token("middle", 2) special_vocab._set_special_token("middle", 2)
special_vocab._set_special_token("fsep", 4) # is this correct?
special_vocab.add_to_gguf(self.gguf_writer) special_vocab.add_to_gguf(self.gguf_writer)
def set_gguf_parameters(self): def set_gguf_parameters(self):
@ -1267,7 +1310,7 @@ class StableLMModel(Model):
if (self.dir_model / "tokenizer.json").is_file(): if (self.dir_model / "tokenizer.json").is_file():
self._set_vocab_gpt2() self._set_vocab_gpt2()
else: else:
# StableLM 2 1.6B uses a vocab in a similar format to Qwen's vocab # StableLM 2 1.6B used to have a vocab in a similar format to Qwen's vocab
self._set_vocab_qwen() self._set_vocab_qwen()
def set_gguf_parameters(self): def set_gguf_parameters(self):
@ -1579,7 +1622,6 @@ class DbrxModel(Model):
self.gguf_writer.add_rope_freq_base(attn_config["rope_theta"]) self.gguf_writer.add_rope_freq_base(attn_config["rope_theta"])
self.gguf_writer.add_clamp_kqv(attn_config["clip_qkv"]) self.gguf_writer.add_clamp_kqv(attn_config["clip_qkv"])
self.gguf_writer.add_file_type(self.ftype)
self.gguf_writer.add_expert_count(ffn_config["moe_num_experts"]) self.gguf_writer.add_expert_count(ffn_config["moe_num_experts"])
self.gguf_writer.add_expert_used_count(ffn_config["moe_top_k"]) self.gguf_writer.add_expert_used_count(ffn_config["moe_top_k"])
@ -1873,7 +1915,7 @@ class Phi3MiniModel(Model):
tokens: list[bytes] = [f"[PAD{i}]".encode("utf-8") for i in range(vocab_size)] tokens: list[bytes] = [f"[PAD{i}]".encode("utf-8") for i in range(vocab_size)]
scores: list[float] = [-10000.0] * vocab_size scores: list[float] = [-10000.0] * vocab_size
toktypes: list[int] = [SentencePieceTokenTypes.UNKNOWN] * vocab_size toktypes: list[int] = [SentencePieceTokenTypes.UNUSED] * vocab_size
for token_id in range(tokenizer.vocab_size()): for token_id in range(tokenizer.vocab_size()):
@ -1918,7 +1960,7 @@ class Phi3MiniModel(Model):
for token_id, foken_data in added_tokens_decoder.items(): for token_id, foken_data in added_tokens_decoder.items():
token_id = int(token_id) token_id = int(token_id)
token = foken_data["content"].encode("utf-8") token = foken_data["content"].encode("utf-8")
if toktypes[token_id] != SentencePieceTokenTypes.UNKNOWN: if toktypes[token_id] != SentencePieceTokenTypes.UNUSED:
assert tokens[token_id] == token assert tokens[token_id] == token
tokens[token_id] = token tokens[token_id] = token
scores[token_id] = -1000.0 scores[token_id] = -1000.0
@ -1934,7 +1976,7 @@ class Phi3MiniModel(Model):
for foken_data in added_tokens: for foken_data in added_tokens:
token_id = int(foken_data["id"]) token_id = int(foken_data["id"])
token = foken_data["content"].encode("utf-8") token = foken_data["content"].encode("utf-8")
if toktypes[token_id] != SentencePieceTokenTypes.UNKNOWN: if toktypes[token_id] != SentencePieceTokenTypes.UNUSED:
assert tokens[token_id] == token assert tokens[token_id] == token
tokens[token_id] = token tokens[token_id] = token
scores[token_id] = -1000.0 scores[token_id] = -1000.0
@ -2146,7 +2188,7 @@ class InternLM2Model(Model):
toktype = SentencePieceTokenTypes.BYTE toktype = SentencePieceTokenTypes.BYTE
# take care of ununsed raw token # take care of ununsed raw token
if piece.startswith('[UNUSED'): if piece.startswith('[UNUSED'):
toktype = SentencePieceTokenTypes.UNKNOWN toktype = SentencePieceTokenTypes.UNUSED
tokens.append(text) tokens.append(text)
scores.append(score) scores.append(score)
@ -2176,7 +2218,7 @@ class InternLM2Model(Model):
if token == chat_eos_token: if token == chat_eos_token:
chat_eos_token_id = token_id chat_eos_token_id = token_id
token = token.encode("utf-8") token = token.encode("utf-8")
if toktypes[token_id] != SentencePieceTokenTypes.UNKNOWN: if toktypes[token_id] != SentencePieceTokenTypes.UNUSED:
assert(tokens[token_id] == token) assert(tokens[token_id] == token)
tokens[token_id] = token tokens[token_id] = token
scores[token_id] = -1000.0 scores[token_id] = -1000.0
@ -2195,7 +2237,7 @@ class InternLM2Model(Model):
if token == chat_eos_token: if token == chat_eos_token:
chat_eos_token_id = token_id chat_eos_token_id = token_id
token = token.encode("utf-8") token = token.encode("utf-8")
if toktypes[token_id] != SentencePieceTokenTypes.UNKNOWN: if toktypes[token_id] != SentencePieceTokenTypes.UNUSED:
assert(tokens[token_id] == token) assert(tokens[token_id] == token)
tokens[token_id] = token tokens[token_id] = token
scores[token_id] = -1000.0 scores[token_id] = -1000.0
@ -2424,19 +2466,7 @@ class Gemma2Model(Model):
model_arch = gguf.MODEL_ARCH.GEMMA2 model_arch = gguf.MODEL_ARCH.GEMMA2
def set_vocab(self): def set_vocab(self):
tokens, scores, toktypes = self._create_vocab_sentencepiece() self._set_vocab_sentencepiece()
# hack: This is required so that we can properly use start/end-of-turn for chat template
for i in range(108):
# including <unusedX>, <start_of_turn>, <end_of_turn>
toktypes[i] = SentencePieceTokenTypes.CONTROL
self.gguf_writer.add_tokenizer_model("llama")
self.gguf_writer.add_tokenizer_pre("default")
self.gguf_writer.add_token_list(tokens)
self.gguf_writer.add_token_scores(scores)
self.gguf_writer.add_token_types(toktypes)
special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
special_vocab.add_to_gguf(self.gguf_writer)
self.gguf_writer.add_add_space_prefix(False) self.gguf_writer.add_add_space_prefix(False)
@ -2463,11 +2493,6 @@ class Gemma2Model(Model):
) )
self.gguf_writer.add_sliding_window(self.hparams["sliding_window"]) self.gguf_writer.add_sliding_window(self.hparams["sliding_window"])
# sanity check
attn_scalar = self.hparams["query_pre_attn_scalar"]
if attn_scalar != hparams["hidden_size"] / hparams["num_attention_heads"]:
raise ValueError("query_pre_attn_scalar must be equal to n_embd / n_head")
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
del bid # unused del bid # unused
@ -2760,7 +2785,7 @@ class ArcticModel(Model):
tokens: list[bytes] = [f"[PAD{i}]".encode("utf-8") for i in range(vocab_size)] tokens: list[bytes] = [f"[PAD{i}]".encode("utf-8") for i in range(vocab_size)]
scores: list[float] = [-10000.0] * vocab_size scores: list[float] = [-10000.0] * vocab_size
toktypes: list[int] = [SentencePieceTokenTypes.UNKNOWN] * vocab_size toktypes: list[int] = [SentencePieceTokenTypes.UNUSED] * vocab_size
for token_id in range(tokenizer.vocab_size()): for token_id in range(tokenizer.vocab_size()):
@ -3015,7 +3040,7 @@ class T5Model(Model):
tokens: list[bytes] = [f"[PAD{i}]".encode("utf-8") for i in range(vocab_size)] tokens: list[bytes] = [f"[PAD{i}]".encode("utf-8") for i in range(vocab_size)]
scores: list[float] = [-10000.0] * vocab_size scores: list[float] = [-10000.0] * vocab_size
toktypes: list[int] = [SentencePieceTokenTypes.UNKNOWN] * vocab_size toktypes: list[int] = [SentencePieceTokenTypes.UNUSED] * vocab_size
for token_id in range(tokenizer.vocab_size()): for token_id in range(tokenizer.vocab_size()):
piece = tokenizer.IdToPiece(token_id) piece = tokenizer.IdToPiece(token_id)
@ -3233,15 +3258,14 @@ class ChatGLMModel(Model):
if len(piece) != 0 and token_id < tokenizer.tokenizer.sp_model.vocab_size(): if len(piece) != 0 and token_id < tokenizer.tokenizer.sp_model.vocab_size():
score = tokenizer.tokenizer.sp_model.get_score(token_id) score = tokenizer.tokenizer.sp_model.get_score(token_id)
if len(piece) == 0:
text = f"[PAD{token_id}]".encode("utf-8")
if token_id >= tokenizer.tokenizer.sp_model.vocab_size(): if token_id >= tokenizer.tokenizer.sp_model.vocab_size():
if piece in special_tokens: if piece in special_tokens:
# show special tokens in prompt toktype = SentencePieceTokenTypes.CONTROL
toktype = SentencePieceTokenTypes.USER_DEFINED elif len(piece) == 0:
text = f"[PAD{token_id}]".encode("utf-8")
toktype = SentencePieceTokenTypes.UNUSED
else: else:
toktype = SentencePieceTokenTypes.UNKNOWN toktype = SentencePieceTokenTypes.USER_DEFINED
tokens.append(text) tokens.append(text)
scores.append(score) scores.append(score)
toktypes.append(toktype) toktypes.append(toktype)
@ -3330,7 +3354,7 @@ class ChatGLMModel(Model):
for i in range(vocab_size): for i in range(vocab_size):
if i not in reverse_vocab: if i not in reverse_vocab:
tokens.append(f"[PAD{i}]") tokens.append(f"[PAD{i}]")
toktypes.append(gguf.TokenType.USER_DEFINED) toktypes.append(gguf.TokenType.UNUSED)
elif reverse_vocab[i] in added_vocab: elif reverse_vocab[i] in added_vocab:
tokens.append(reverse_vocab[i]) tokens.append(reverse_vocab[i])
if tokenizer.added_tokens_decoder[i].special: if tokenizer.added_tokens_decoder[i].special:

View File

@ -242,6 +242,45 @@ The following compilation options are also available to tweak performance (yes,
### Vulkan ### Vulkan
**Windows**
#### w64devkit
Download and extract [w64devkit](https://github.com/skeeto/w64devkit/releases).
Download and install the [Vulkan SDK](https://vulkan.lunarg.com/sdk/home#windows). When selecting components, only the Vulkan SDK Core is required.
Launch `w64devkit.exe` and run the following commands to copy Vulkan dependencies:
```sh
SDK_VERSION=1.3.283.0
cp /VulkanSDK/$SDK_VERSION/Bin/glslc.exe $W64DEVKIT_HOME/bin/
cp /VulkanSDK/$SDK_VERSION/Lib/vulkan-1.lib $W64DEVKIT_HOME/x86_64-w64-mingw32/lib/
cp -r /VulkanSDK/$SDK_VERSION/Include/* $W64DEVKIT_HOME/x86_64-w64-mingw32/include/
cat > $W64DEVKIT_HOME/x86_64-w64-mingw32/lib/pkgconfig/vulkan.pc <<EOF
Name: Vulkan-Loader
Description: Vulkan Loader
Version: $SDK_VERSION
Libs: -lvulkan-1
EOF
```
Switch into the `llama.cpp` directory and run `make GGML_VULKAN=1`.
#### MSYS2
Install [MSYS2](https://www.msys2.org/) and then run the following commands in a UCRT terminal to install dependencies.
```sh
pacman -S git \
mingw-w64-ucrt-x86_64-gcc \
mingw-w64-ucrt-x86_64-cmake \
mingw-w64-ucrt-x86_64-vulkan-devel \
mingw-w64-ucrt-x86_64-shaderc
```
Switch into `llama.cpp` directory and build using CMake.
```sh
cmake -B build -DGGML_VULKAN=ON
cmake --build build --config Release
```
**With docker**: **With docker**:
You don't need to install Vulkan SDK. It will be installed inside the container. You don't need to install Vulkan SDK. It will be installed inside the container.

View File

@ -99,7 +99,7 @@ static bool ggml_debug(struct ggml_tensor * t, bool ask, void * user_data) {
char src1_str[128] = {0}; char src1_str[128] = {0};
if (src1) { if (src1) {
sprintf(src1_str, "%s{%s}", src1->name, ggml_ne_string(src1).c_str()); snprintf(src1_str, sizeof(src1_str), "%s{%s}", src1->name, ggml_ne_string(src1).c_str());
} }
printf("%s: %24s = (%s) %10s(%s{%s}, %s}) = {%s}\n", __func__, printf("%s: %24s = (%s) %10s(%s{%s}, %s}) = {%s}\n", __func__,

View File

@ -347,7 +347,7 @@ static hash_exit_code_t gguf_hash(const hash_params & hash_params) {
char hex_result[17]; char hex_result[17];
for (int offset = 0; offset < 8; offset++) { for (int offset = 0; offset < 8; offset++) {
unsigned int shift_bits_by = (8 * (8 - offset - 1)); unsigned int shift_bits_by = (8 * (8 - offset - 1));
sprintf( ( hex_result + (2*offset)), "%02x", (unsigned char) (hash >> shift_bits_by)&0xff); snprintf( ( hex_result + (2*offset)), sizeof(hex_result) - (2*offset), "%02x", (unsigned char) (hash >> shift_bits_by)&0xff);
} }
if (hash_params.manifest_is_usable) { if (hash_params.manifest_is_usable) {
@ -384,7 +384,7 @@ static hash_exit_code_t gguf_hash(const hash_params & hash_params) {
char hex_result[41] = {0}; char hex_result[41] = {0};
for (int offset = 0; offset < 20; offset++) { for (int offset = 0; offset < 20; offset++) {
sprintf( ( hex_result + (2*offset)), "%02x", result[offset]&0xff); snprintf( ( hex_result + (2*offset)), sizeof(hex_result) - (2*offset), "%02x", result[offset]&0xff);
} }
if (hash_params.manifest_is_usable) { if (hash_params.manifest_is_usable) {
@ -421,7 +421,7 @@ static hash_exit_code_t gguf_hash(const hash_params & hash_params) {
char hex_result[SHA256_DIGEST_SIZE * 2 + 1] = {0}; char hex_result[SHA256_DIGEST_SIZE * 2 + 1] = {0};
for (int offset = 0; offset < SHA256_DIGEST_SIZE; offset++) { for (int offset = 0; offset < SHA256_DIGEST_SIZE; offset++) {
sprintf( ( hex_result + (2*offset)), "%02x", result[offset]&0xff); snprintf( ( hex_result + (2*offset)), sizeof(hex_result) - (2*offset), "%02x", result[offset]&0xff);
} }
if (hash_params.manifest_is_usable) { if (hash_params.manifest_is_usable) {
@ -460,7 +460,7 @@ static hash_exit_code_t gguf_hash(const hash_params & hash_params) {
char hex_result[17]; char hex_result[17];
for (int offset = 0; offset < 8; offset++) { for (int offset = 0; offset < 8; offset++) {
unsigned int shift_bits_by = (8 * (8 - offset - 1)); unsigned int shift_bits_by = (8 * (8 - offset - 1));
sprintf( ( hex_result + (2*offset)), "%02x", (unsigned char) (hash >> shift_bits_by)&0xff); snprintf( ( hex_result + (2*offset)), sizeof(hex_result) - (2*offset), "%02x", (unsigned char) (hash >> shift_bits_by)&0xff);
} }
if (hash_params.manifest_is_usable) { if (hash_params.manifest_is_usable) {
@ -490,7 +490,7 @@ static hash_exit_code_t gguf_hash(const hash_params & hash_params) {
char hex_result[41]; char hex_result[41];
for (int offset = 0; offset < 20; offset++) { for (int offset = 0; offset < 20; offset++) {
sprintf( ( hex_result + (2*offset)), "%02x", result[offset]&0xff); snprintf( ( hex_result + (2*offset)), sizeof(hex_result) - (2*offset), "%02x", result[offset]&0xff);
} }
if (hash_params.manifest_is_usable) { if (hash_params.manifest_is_usable) {
@ -520,7 +520,7 @@ static hash_exit_code_t gguf_hash(const hash_params & hash_params) {
char hex_result[SHA256_DIGEST_SIZE * 2 + 1] = {0}; char hex_result[SHA256_DIGEST_SIZE * 2 + 1] = {0};
for (int offset = 0; offset < SHA256_DIGEST_SIZE; offset++) { for (int offset = 0; offset < SHA256_DIGEST_SIZE; offset++) {
sprintf( ( hex_result + (2*offset)), "%02x", result[offset]&0xff); snprintf( ( hex_result + (2*offset)), sizeof(hex_result) - (2*offset), "%02x", result[offset]&0xff);
} }
if (hash_params.manifest_is_usable) { if (hash_params.manifest_is_usable) {
@ -552,7 +552,7 @@ static hash_exit_code_t gguf_hash(const hash_params & hash_params) {
generate_uuidv5(result, uuid); generate_uuidv5(result, uuid);
char string_buffer[37] = {0}; char string_buffer[37] = {0};
sprintf(string_buffer, "%02x%02x%02x%02x-%02x%02x-%02x%02x-%02x%02x-%02x%02x%02x%02x%02x%02x", snprintf(string_buffer, sizeof(string_buffer), "%02x%02x%02x%02x-%02x%02x-%02x%02x-%02x%02x-%02x%02x%02x%02x%02x%02x",
uuid[0], uuid[1], uuid[2], uuid[3], uuid[0], uuid[1], uuid[2], uuid[3],
uuid[4], uuid[5], uuid[6], uuid[7], uuid[4], uuid[5], uuid[6], uuid[7],
uuid[8], uuid[9], uuid[10], uuid[11], uuid[8], uuid[9], uuid[10], uuid[11],

View File

@ -289,8 +289,13 @@ int main(int argc, char ** argv) {
// Should not run without any tokens // Should not run without any tokens
if (embd_inp.empty()) { if (embd_inp.empty()) {
if (add_bos) {
embd_inp.push_back(llama_token_bos(model)); embd_inp.push_back(llama_token_bos(model));
LOG("embd_inp was considered empty and bos was added: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp).c_str()); LOG("embd_inp was considered empty and bos was added: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp).c_str());
} else {
LOG_TEE("error: input is empty\n");
return -1;
}
} }
// Tokenize negative prompt // Tokenize negative prompt

View File

@ -6,7 +6,7 @@ import re
from copy import copy from copy import copy
from enum import Enum from enum import Enum
from inspect import getdoc, isclass from inspect import getdoc, isclass
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Union, get_args, get_origin from typing import TYPE_CHECKING, Any, Callable, List, Optional, Union, get_args, get_origin, get_type_hints
from docstring_parser import parse from docstring_parser import parse
from pydantic import BaseModel, create_model from pydantic import BaseModel, create_model
@ -53,35 +53,38 @@ class PydanticDataType(Enum):
def map_pydantic_type_to_gbnf(pydantic_type: type[Any]) -> str: def map_pydantic_type_to_gbnf(pydantic_type: type[Any]) -> str:
if isclass(pydantic_type) and issubclass(pydantic_type, str): origin_type = get_origin(pydantic_type)
origin_type = pydantic_type if origin_type is None else origin_type
if isclass(origin_type) and issubclass(origin_type, str):
return PydanticDataType.STRING.value return PydanticDataType.STRING.value
elif isclass(pydantic_type) and issubclass(pydantic_type, bool): elif isclass(origin_type) and issubclass(origin_type, bool):
return PydanticDataType.BOOLEAN.value return PydanticDataType.BOOLEAN.value
elif isclass(pydantic_type) and issubclass(pydantic_type, int): elif isclass(origin_type) and issubclass(origin_type, int):
return PydanticDataType.INTEGER.value return PydanticDataType.INTEGER.value
elif isclass(pydantic_type) and issubclass(pydantic_type, float): elif isclass(origin_type) and issubclass(origin_type, float):
return PydanticDataType.FLOAT.value return PydanticDataType.FLOAT.value
elif isclass(pydantic_type) and issubclass(pydantic_type, Enum): elif isclass(origin_type) and issubclass(origin_type, Enum):
return PydanticDataType.ENUM.value return PydanticDataType.ENUM.value
elif isclass(pydantic_type) and issubclass(pydantic_type, BaseModel): elif isclass(origin_type) and issubclass(origin_type, BaseModel):
return format_model_and_field_name(pydantic_type.__name__) return format_model_and_field_name(origin_type.__name__)
elif get_origin(pydantic_type) is list: elif origin_type is list:
element_type = get_args(pydantic_type)[0] element_type = get_args(pydantic_type)[0]
return f"{map_pydantic_type_to_gbnf(element_type)}-list" return f"{map_pydantic_type_to_gbnf(element_type)}-list"
elif get_origin(pydantic_type) is set: elif origin_type is set:
element_type = get_args(pydantic_type)[0] element_type = get_args(pydantic_type)[0]
return f"{map_pydantic_type_to_gbnf(element_type)}-set" return f"{map_pydantic_type_to_gbnf(element_type)}-set"
elif get_origin(pydantic_type) is Union: elif origin_type is Union:
union_types = get_args(pydantic_type) union_types = get_args(pydantic_type)
union_rules = [map_pydantic_type_to_gbnf(ut) for ut in union_types] union_rules = [map_pydantic_type_to_gbnf(ut) for ut in union_types]
return f"union-{'-or-'.join(union_rules)}" return f"union-{'-or-'.join(union_rules)}"
elif get_origin(pydantic_type) is Optional: elif origin_type is Optional:
element_type = get_args(pydantic_type)[0] element_type = get_args(pydantic_type)[0]
return f"optional-{map_pydantic_type_to_gbnf(element_type)}" return f"optional-{map_pydantic_type_to_gbnf(element_type)}"
elif isclass(pydantic_type): elif isclass(origin_type):
return f"{PydanticDataType.CUSTOM_CLASS.value}-{format_model_and_field_name(pydantic_type.__name__)}" return f"{PydanticDataType.CUSTOM_CLASS.value}-{format_model_and_field_name(origin_type.__name__)}"
elif get_origin(pydantic_type) is dict: elif origin_type is dict:
key_type, value_type = get_args(pydantic_type) key_type, value_type = get_args(pydantic_type)
return f"custom-dict-key-type-{format_model_and_field_name(map_pydantic_type_to_gbnf(key_type))}-value-type-{format_model_and_field_name(map_pydantic_type_to_gbnf(value_type))}" return f"custom-dict-key-type-{format_model_and_field_name(map_pydantic_type_to_gbnf(key_type))}-value-type-{format_model_and_field_name(map_pydantic_type_to_gbnf(value_type))}"
else: else:
@ -118,7 +121,7 @@ def get_members_structure(cls, rule_name):
# Modify this comprehension # Modify this comprehension
members = [ members = [
f' "\\"{name}\\"" ":" {map_pydantic_type_to_gbnf(param_type)}' f' "\\"{name}\\"" ":" {map_pydantic_type_to_gbnf(param_type)}'
for name, param_type in cls.__annotations__.items() for name, param_type in get_type_hints(cls).items()
if name != "self" if name != "self"
] ]
@ -297,17 +300,20 @@ def generate_gbnf_rule_for_type(
field_name = format_model_and_field_name(field_name) field_name = format_model_and_field_name(field_name)
gbnf_type = map_pydantic_type_to_gbnf(field_type) gbnf_type = map_pydantic_type_to_gbnf(field_type)
if isclass(field_type) and issubclass(field_type, BaseModel): origin_type = get_origin(field_type)
origin_type = field_type if origin_type is None else origin_type
if isclass(origin_type) and issubclass(origin_type, BaseModel):
nested_model_name = format_model_and_field_name(field_type.__name__) nested_model_name = format_model_and_field_name(field_type.__name__)
nested_model_rules, _ = generate_gbnf_grammar(field_type, processed_models, created_rules) nested_model_rules, _ = generate_gbnf_grammar(field_type, processed_models, created_rules)
rules.extend(nested_model_rules) rules.extend(nested_model_rules)
gbnf_type, rules = nested_model_name, rules gbnf_type, rules = nested_model_name, rules
elif isclass(field_type) and issubclass(field_type, Enum): elif isclass(origin_type) and issubclass(origin_type, Enum):
enum_values = [f'"\\"{e.value}\\""' for e in field_type] # Adding escaped quotes enum_values = [f'"\\"{e.value}\\""' for e in field_type] # Adding escaped quotes
enum_rule = f"{model_name}-{field_name} ::= {' | '.join(enum_values)}" enum_rule = f"{model_name}-{field_name} ::= {' | '.join(enum_values)}"
rules.append(enum_rule) rules.append(enum_rule)
gbnf_type, rules = model_name + "-" + field_name, rules gbnf_type, rules = model_name + "-" + field_name, rules
elif get_origin(field_type) == list: # Array elif origin_type is list: # Array
element_type = get_args(field_type)[0] element_type = get_args(field_type)[0]
element_rule_name, additional_rules = generate_gbnf_rule_for_type( element_rule_name, additional_rules = generate_gbnf_rule_for_type(
model_name, f"{field_name}-element", element_type, is_optional, processed_models, created_rules model_name, f"{field_name}-element", element_type, is_optional, processed_models, created_rules
@ -317,7 +323,7 @@ def generate_gbnf_rule_for_type(
rules.append(array_rule) rules.append(array_rule)
gbnf_type, rules = model_name + "-" + field_name, rules gbnf_type, rules = model_name + "-" + field_name, rules
elif get_origin(field_type) == set or field_type == set: # Array elif origin_type is set: # Array
element_type = get_args(field_type)[0] element_type = get_args(field_type)[0]
element_rule_name, additional_rules = generate_gbnf_rule_for_type( element_rule_name, additional_rules = generate_gbnf_rule_for_type(
model_name, f"{field_name}-element", element_type, is_optional, processed_models, created_rules model_name, f"{field_name}-element", element_type, is_optional, processed_models, created_rules
@ -371,7 +377,7 @@ def generate_gbnf_rule_for_type(
gbnf_type = f"{model_name}-{field_name}-optional" gbnf_type = f"{model_name}-{field_name}-optional"
else: else:
gbnf_type = f"{model_name}-{field_name}-union" gbnf_type = f"{model_name}-{field_name}-union"
elif isclass(field_type) and issubclass(field_type, str): elif isclass(origin_type) and issubclass(origin_type, str):
if field_info and hasattr(field_info, "json_schema_extra") and field_info.json_schema_extra is not None: if field_info and hasattr(field_info, "json_schema_extra") and field_info.json_schema_extra is not None:
triple_quoted_string = field_info.json_schema_extra.get("triple_quoted_string", False) triple_quoted_string = field_info.json_schema_extra.get("triple_quoted_string", False)
markdown_string = field_info.json_schema_extra.get("markdown_code_block", False) markdown_string = field_info.json_schema_extra.get("markdown_code_block", False)
@ -387,8 +393,8 @@ def generate_gbnf_rule_for_type(
gbnf_type = PydanticDataType.STRING.value gbnf_type = PydanticDataType.STRING.value
elif ( elif (
isclass(field_type) isclass(origin_type)
and issubclass(field_type, float) and issubclass(origin_type, float)
and field_info and field_info
and hasattr(field_info, "json_schema_extra") and hasattr(field_info, "json_schema_extra")
and field_info.json_schema_extra is not None and field_info.json_schema_extra is not None
@ -413,8 +419,8 @@ def generate_gbnf_rule_for_type(
) )
elif ( elif (
isclass(field_type) isclass(origin_type)
and issubclass(field_type, int) and issubclass(origin_type, int)
and field_info and field_info
and hasattr(field_info, "json_schema_extra") and hasattr(field_info, "json_schema_extra")
and field_info.json_schema_extra is not None and field_info.json_schema_extra is not None
@ -462,7 +468,7 @@ def generate_gbnf_grammar(model: type[BaseModel], processed_models: set[type[Bas
if not issubclass(model, BaseModel): if not issubclass(model, BaseModel):
# For non-Pydantic classes, generate model_fields from __annotations__ or __init__ # For non-Pydantic classes, generate model_fields from __annotations__ or __init__
if hasattr(model, "__annotations__") and model.__annotations__: if hasattr(model, "__annotations__") and model.__annotations__:
model_fields = {name: (typ, ...) for name, typ in model.__annotations__.items()} # pyright: ignore[reportGeneralTypeIssues] model_fields = {name: (typ, ...) for name, typ in get_type_hints(model).items()}
else: else:
init_signature = inspect.signature(model.__init__) init_signature = inspect.signature(model.__init__)
parameters = init_signature.parameters parameters = init_signature.parameters
@ -470,7 +476,7 @@ def generate_gbnf_grammar(model: type[BaseModel], processed_models: set[type[Bas
name != "self"} name != "self"}
else: else:
# For Pydantic models, use model_fields and check for ellipsis (required fields) # For Pydantic models, use model_fields and check for ellipsis (required fields)
model_fields = model.__annotations__ model_fields = get_type_hints(model)
model_rule_parts = [] model_rule_parts = []
nested_rules = [] nested_rules = []
@ -706,7 +712,7 @@ def generate_markdown_documentation(
else: else:
documentation += f" Fields:\n" # noqa: F541 documentation += f" Fields:\n" # noqa: F541
if isclass(model) and issubclass(model, BaseModel): if isclass(model) and issubclass(model, BaseModel):
for name, field_type in model.__annotations__.items(): for name, field_type in get_type_hints(model).items():
# if name == "markdown_code_block": # if name == "markdown_code_block":
# continue # continue
if get_origin(field_type) == list: if get_origin(field_type) == list:
@ -754,14 +760,17 @@ def generate_field_markdown(
field_info = model.model_fields.get(field_name) field_info = model.model_fields.get(field_name)
field_description = field_info.description if field_info and field_info.description else "" field_description = field_info.description if field_info and field_info.description else ""
if get_origin(field_type) == list: origin_type = get_origin(field_type)
origin_type = field_type if origin_type is None else origin_type
if origin_type == list:
element_type = get_args(field_type)[0] element_type = get_args(field_type)[0]
field_text = f"{indent}{field_name} ({format_model_and_field_name(field_type.__name__)} of {format_model_and_field_name(element_type.__name__)})" field_text = f"{indent}{field_name} ({format_model_and_field_name(field_type.__name__)} of {format_model_and_field_name(element_type.__name__)})"
if field_description != "": if field_description != "":
field_text += ":\n" field_text += ":\n"
else: else:
field_text += "\n" field_text += "\n"
elif get_origin(field_type) == Union: elif origin_type == Union:
element_types = get_args(field_type) element_types = get_args(field_type)
types = [] types = []
for element_type in element_types: for element_type in element_types:
@ -792,9 +801,9 @@ def generate_field_markdown(
example_text = f"'{field_example}'" if isinstance(field_example, str) else field_example example_text = f"'{field_example}'" if isinstance(field_example, str) else field_example
field_text += f"{indent} Example: {example_text}\n" field_text += f"{indent} Example: {example_text}\n"
if isclass(field_type) and issubclass(field_type, BaseModel): if isclass(origin_type) and issubclass(origin_type, BaseModel):
field_text += f"{indent} Details:\n" field_text += f"{indent} Details:\n"
for name, type_ in field_type.__annotations__.items(): for name, type_ in get_type_hints(field_type).items():
field_text += generate_field_markdown(name, type_, field_type, depth + 2) field_text += generate_field_markdown(name, type_, field_type, depth + 2)
return field_text return field_text
@ -855,7 +864,7 @@ def generate_text_documentation(
if isclass(model) and issubclass(model, BaseModel): if isclass(model) and issubclass(model, BaseModel):
documentation_fields = "" documentation_fields = ""
for name, field_type in model.__annotations__.items(): for name, field_type in get_type_hints(model).items():
# if name == "markdown_code_block": # if name == "markdown_code_block":
# continue # continue
if get_origin(field_type) == list: if get_origin(field_type) == list:
@ -948,7 +957,7 @@ def generate_field_text(
if isclass(field_type) and issubclass(field_type, BaseModel): if isclass(field_type) and issubclass(field_type, BaseModel):
field_text += f"{indent} Details:\n" field_text += f"{indent} Details:\n"
for name, type_ in field_type.__annotations__.items(): for name, type_ in get_type_hints(field_type).items():
field_text += generate_field_text(name, type_, field_type, depth + 2) field_text += generate_field_text(name, type_, field_type, depth + 2)
return field_text return field_text

View File

@ -20,6 +20,8 @@ def create_completion(prompt, grammar):
response = requests.post("http://127.0.0.1:8080/completion", headers=headers, json=data) response = requests.post("http://127.0.0.1:8080/completion", headers=headers, json=data)
data = response.json() data = response.json()
assert data.get("error") is None, data
print(data["content"]) print(data["content"])
return data["content"] return data["content"]

View File

@ -154,7 +154,7 @@ static void test_roundtrip_on_chunk(
} }
if (use_reference) { if (use_reference) {
qfns.from_float_reference(input_scratch, quantized_scratch, chunk_size); qfns.from_float_ref(input_scratch, quantized_scratch, chunk_size);
} else { } else {
qfns.from_float(input_scratch, quantized_scratch, chunk_size); qfns.from_float(input_scratch, quantized_scratch, chunk_size);
} }

View File

@ -737,6 +737,8 @@ struct server_context {
slot.ga_n = ga_n; slot.ga_n = ga_n;
slot.ga_w = ga_w; slot.ga_w = ga_w;
slot.sparams = params.sparams;
slot.reset(); slot.reset();
slots.push_back(slot); slots.push_back(slot);
@ -2003,6 +2005,11 @@ struct server_context {
int32_t n_batch = llama_n_batch(ctx); int32_t n_batch = llama_n_batch(ctx);
int32_t n_ubatch = llama_n_ubatch(ctx); int32_t n_ubatch = llama_n_ubatch(ctx);
// track if this is an embedding or non-embedding batch
// if we've added sampled tokens above, we are in non-embedding mode
// -1: none, 0: non-embedding, 1: embedding
int32_t batch_type = batch.n_tokens > 0 ? 0 : -1;
// next, batch any pending prompts without exceeding n_batch // next, batch any pending prompts without exceeding n_batch
if (params.cont_batching || batch.n_tokens == 0) { if (params.cont_batching || batch.n_tokens == 0) {
for (auto & slot : slots) { for (auto & slot : slots) {
@ -2173,6 +2180,14 @@ struct server_context {
} }
} }
// check that we are in the right batch_type, if not defer the slot
bool slot_type = slot.embedding ? 1 : 0;
if (batch_type == -1) {
batch_type = slot_type;
} else if (batch_type != slot_type) {
continue;
}
// keep only the common part // keep only the common part
int p0 = (int) system_tokens.size() + slot.n_past; int p0 = (int) system_tokens.size() + slot.n_past;
if (!llama_kv_cache_seq_rm(ctx, slot.id + 1, p0, -1)) { if (!llama_kv_cache_seq_rm(ctx, slot.id + 1, p0, -1)) {
@ -2274,6 +2289,9 @@ struct server_context {
{"n_tokens", batch.n_tokens}, {"n_tokens", batch.n_tokens},
}); });
// make sure we're in the right embedding mode
llama_set_embeddings(ctx, batch_type == 1);
// process the created batch of tokens // process the created batch of tokens
for (int32_t i = 0; i < batch.n_tokens; i += n_batch) { for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i); const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);
@ -2988,6 +3006,11 @@ int main(int argc, char ** argv) {
}; };
const auto handle_completions = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) { const auto handle_completions = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) {
if (ctx_server.params.embedding) {
res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
return;
}
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
json data = json::parse(req.body); json data = json::parse(req.body);
@ -3083,6 +3106,11 @@ int main(int argc, char ** argv) {
}; };
const auto handle_chat_completions = [&ctx_server, &params, &res_error](const httplib::Request & req, httplib::Response & res) { const auto handle_chat_completions = [&ctx_server, &params, &res_error](const httplib::Request & req, httplib::Response & res) {
if (ctx_server.params.embedding) {
res_error(res, format_error_response("This server does not support chat completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
return;
}
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template); json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template);
@ -3155,6 +3183,11 @@ int main(int argc, char ** argv) {
}; };
const auto handle_infill = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) { const auto handle_infill = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) {
if (ctx_server.params.embedding) {
res_error(res, format_error_response("This server does not support infill. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
return;
}
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
json data = json::parse(req.body); json data = json::parse(req.body);
@ -3241,13 +3274,8 @@ int main(int argc, char ** argv) {
return res.set_content(data.dump(), "application/json; charset=utf-8"); return res.set_content(data.dump(), "application/json; charset=utf-8");
}; };
const auto handle_embeddings = [&params, &ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) { const auto handle_embeddings = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
if (!params.embedding) {
res.status = 501;
res.set_content("This server does not support embeddings. Start it with `--embeddings`", "text/plain; charset=utf-8");
return;
}
const json body = json::parse(req.body); const json body = json::parse(req.body);
bool is_openai = false; bool is_openai = false;

View File

@ -122,8 +122,26 @@ inline std::string format_chat(const struct llama_model * model, const std::stri
for (size_t i = 0; i < messages.size(); ++i) { for (size_t i = 0; i < messages.size(); ++i) {
const auto & curr_msg = messages[i]; const auto & curr_msg = messages[i];
std::string role = json_value(curr_msg, "role", std::string("")); std::string role = json_value(curr_msg, "role", std::string(""));
std::string content = json_value(curr_msg, "content", std::string(""));
std::string content;
if (curr_msg.contains("content")) {
if (curr_msg["content"].is_string()) {
content = curr_msg["content"].get<std::string>();
} else if (curr_msg["content"].is_array()) {
for (const auto & part : curr_msg["content"]) {
if (part.contains("text")) {
content += "\n" + part["text"].get<std::string>();
}
}
} else {
throw std::runtime_error("Invalid 'content' type (ref: https://github.com/ggerganov/llama.cpp/issues/8367)");
}
} else {
throw std::runtime_error("Missing 'content' (ref: https://github.com/ggerganov/llama.cpp/issues/8367)");
}
chat.push_back({role, content}); chat.push_back({role, content});
} }

View File

@ -29,6 +29,7 @@ static void print_usage_information(const char * argv0, FILE * stream) {
fprintf(stream, " -p PROMPT, --prompt PROMPT read prompt from the argument.\n"); fprintf(stream, " -p PROMPT, --prompt PROMPT read prompt from the argument.\n");
fprintf(stream, " --stdin read prompt from standard input.\n"); fprintf(stream, " --stdin read prompt from standard input.\n");
fprintf(stream, " --no-bos do not ever add a BOS token to the prompt, even if normally the model uses a BOS token.\n"); fprintf(stream, " --no-bos do not ever add a BOS token to the prompt, even if normally the model uses a BOS token.\n");
fprintf(stream, " --no-parse-special do not parse control tokens.\n");
fprintf(stream, " --log-disable disable logs. Makes stderr quiet when loading the model.\n"); fprintf(stream, " --log-disable disable logs. Makes stderr quiet when loading the model.\n");
fprintf(stream, " --show-count print the total number of tokens.\n"); fprintf(stream, " --show-count print the total number of tokens.\n");
} }
@ -195,6 +196,7 @@ int main(int raw_argc, char ** raw_argv) {
// variables where to put any arguments we see. // variables where to put any arguments we see.
bool printing_ids = false; bool printing_ids = false;
bool no_bos = false; bool no_bos = false;
bool no_parse_special = false;
bool disable_logging = false; bool disable_logging = false;
bool show_token_count = false; bool show_token_count = false;
const char * model_path = NULL; const char * model_path = NULL;
@ -229,6 +231,9 @@ int main(int raw_argc, char ** raw_argv) {
else if (arg == "--no-bos") { else if (arg == "--no-bos") {
no_bos = true; no_bos = true;
} }
else if (arg == "--no-parse-special") {
no_parse_special = true;
}
else if (arg == "-p" || arg == "--prompt") { else if (arg == "-p" || arg == "--prompt") {
if (prompt_set) { if (prompt_set) {
fprintf(stderr, "Error: -p or --prompt specified multiple times.\n"); fprintf(stderr, "Error: -p or --prompt specified multiple times.\n");
@ -359,9 +364,10 @@ int main(int raw_argc, char ** raw_argv) {
const bool model_wants_add_bos = llama_should_add_bos_token(model); const bool model_wants_add_bos = llama_should_add_bos_token(model);
const bool add_bos = model_wants_add_bos && !no_bos; const bool add_bos = model_wants_add_bos && !no_bos;
const bool parse_special = !no_parse_special;
std::vector<llama_token> tokens; std::vector<llama_token> tokens;
tokens = ::llama_tokenize(model, prompt, add_bos, true); tokens = ::llama_tokenize(model, prompt, add_bos, parse_special);
if (printing_ids) { if (printing_ids) {
printf("["); printf("[");

View File

@ -20,11 +20,11 @@
}, },
"nixpkgs": { "nixpkgs": {
"locked": { "locked": {
"lastModified": 1720031269, "lastModified": 1720768451,
"narHash": "sha256-rwz8NJZV+387rnWpTYcXaRNvzUSnnF9aHONoJIYmiUQ=", "narHash": "sha256-EYekUHJE2gxeo2pM/zM9Wlqw1Uw2XTJXOSAO79ksc4Y=",
"owner": "NixOS", "owner": "NixOS",
"repo": "nixpkgs", "repo": "nixpkgs",
"rev": "9f4128e00b0ae8ec65918efeba59db998750ead6", "rev": "7e7c39ea35c5cdd002cd4588b03a3fb9ece6fad9",
"type": "github" "type": "github"
}, },
"original": { "original": {

2
ggml/.gitignore vendored Normal file
View File

@ -0,0 +1,2 @@
src/ggml-vulkan-shaders.hpp
src/ggml-vulkan-shaders.cpp

View File

@ -1,220 +0,0 @@
#!/usr/bin/env python
import logging
import argparse
import asyncio
import os
from tempfile import gettempdir
logger = logging.getLogger("ggml-vk-generate-shaders")
GLSLC = "glslc"
type_names = [
"f32",
"f16",
"q4_0",
"q4_1",
"q5_0",
"q5_1",
"q8_0",
"q2_k",
"q3_k",
"q4_k",
"q5_k",
"q6_k",
]
ASYNCIO_CONCURRENCY = 64
input_dir = "vulkan-shaders"
output_dir = gettempdir()
lock = asyncio.Lock()
shader_fnames = []
async def string_to_spv(name, in_fname, defines, fp16=True):
name = f"{name}{'_fp32' if not fp16 else ''}"
out_fname = os.path.join(output_dir, f"{name}.spv")
in_path = os.path.join(input_dir, in_fname)
cmd = [GLSLC, "-fshader-stage=compute", "--target-env=vulkan1.2", "-O", in_path, "-o", out_fname]
cmd.extend([f"-D{key}={value}" for key, value in defines.items()])
proc = await asyncio.create_subprocess_exec(*cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE)
stdout, stderr = await proc.communicate()
stdout = stdout.decode()
error = stderr.decode()
if proc.returncode:
cmd = " ".join(cmd)
logger.error(f"cannot compile {name}\n\n{cmd}\n\n{error}")
return
async with lock:
shader_fnames.append((name, out_fname))
def matmul_shaders(tasks, fp16, matmul_id):
if fp16:
load_vec = "8"
aligned_b_type_f32 = "mat2x4"
aligned_b_type_f16 = "f16mat2x4"
else:
load_vec = "4"
aligned_b_type_f32 = "vec4"
aligned_b_type_f16 = "f16vec4"
base_dict = {"FLOAT_TYPE": "float" if not fp16 else "float16_t"}
shader_name = "matmul"
if matmul_id:
base_dict["MUL_MAT_ID"] = "1"
shader_name = "matmul_id"
if fp16:
base_dict["FLOAT16"] = "1"
# Shaders with f16 B_TYPE
tasks.append(string_to_spv(f"{shader_name}_f32_f16", "mul_mm.comp", base_dict | {"DATA_A_F32": "1", "B_TYPE": "float16_t", "D_TYPE": "float"}, fp16))
tasks.append(string_to_spv(f"{shader_name}_f32_f16_aligned", "mul_mm.comp", base_dict | {"DATA_A_F32": "1", "LOAD_VEC_A": load_vec, "LOAD_VEC_B": load_vec, "B_TYPE": aligned_b_type_f16, "D_TYPE": "float"}, fp16))
tasks.append(string_to_spv(f"{shader_name}_f16", "mul_mm.comp", base_dict | {"DATA_A_F16": "1", "B_TYPE": "float16_t", "D_TYPE": "float"}, fp16))
tasks.append(string_to_spv(f"{shader_name}_f16_aligned", "mul_mm.comp", base_dict | {"DATA_A_F16": "1", "LOAD_VEC_A": load_vec, "LOAD_VEC_B": load_vec, "B_TYPE": aligned_b_type_f16, "D_TYPE": "float"}, fp16))
for tname in type_names:
data_a_key = f"DATA_A_{tname.upper()}"
load_vec_a = load_vec if tname in ("f32", "f16") else "2"
tasks.append(string_to_spv(f"{shader_name}_{tname}_f32", "mul_mm.comp", base_dict | {data_a_key: "1", "B_TYPE": "float", "D_TYPE": "float"}, fp16))
tasks.append(string_to_spv(f"{shader_name}_{tname}_f32_aligned", "mul_mm.comp", base_dict | {data_a_key: "2", "LOAD_VEC_A": load_vec_a, "LOAD_VEC_B": load_vec, "B_TYPE": aligned_b_type_f32, "D_TYPE": "float"}, fp16))
async def main():
logger.info("ggml_vulkan: Generating and compiling shaders to SPIR-V")
tasks = []
base_dict = {"FLOAT_TYPE": "float"}
for fp16 in (False, True):
# MUL_MAT
matmul_shaders(tasks, fp16, False)
# MUL_MAT_ID
matmul_shaders(tasks, fp16, True)
for tname in type_names:
# mul mat vec
data_a_key = f"DATA_A_{tname.upper()}"
shader = f"mul_mat_vec_{tname}.comp" if tname.endswith("_k") else "mul_mat_vec.comp"
tasks.append(string_to_spv(f"mul_mat_vec_{tname}_f32_f32", shader, base_dict | {data_a_key: "1", "B_TYPE": "float", "D_TYPE": "float"}))
tasks.append(string_to_spv(f"mul_mat_vec_{tname}_f16_f32", shader, base_dict | {data_a_key: "1", "B_TYPE": "float16_t", "D_TYPE": "float"}))
tasks.append(string_to_spv(f"mul_mat_vec_id_{tname}_f32", shader, base_dict | {"MUL_MAT_ID": "1", data_a_key: "1", "B_TYPE": "float", "D_TYPE": "float"}))
# Dequant shaders
if tname != "f16":
tasks.append(string_to_spv(f"dequant_{tname}", f"dequant_{tname}.comp", base_dict | {data_a_key: "1", "D_TYPE": "float16_t"}))
# get_rows
if not tname.endswith("_k"):
shader = "get_rows.comp" if tname in ("f32", "f16") else "get_rows_quant.comp"
if tname == "f16":
tasks.append(string_to_spv(f"get_rows_{tname}", shader, {data_a_key: "1", "B_TYPE": "int", "D_TYPE": "float16_t", "OPTIMIZATION_ERROR_WORKAROUND": "1"}))
else:
tasks.append(string_to_spv(f"get_rows_{tname}", shader, {data_a_key: "1", "B_TYPE": "int", "D_TYPE": "float16_t"}))
tasks.append(string_to_spv(f"get_rows_{tname}_f32", shader, {data_a_key: "1", "B_TYPE": "int", "D_TYPE": "float"}))
tasks.append(string_to_spv("mul_mat_vec_p021_f16_f32", "mul_mat_vec_p021.comp", {"A_TYPE": "float16_t", "B_TYPE": "float", "D_TYPE": "float"}))
tasks.append(string_to_spv("mul_mat_vec_nc_f16_f32", "mul_mat_vec_nc.comp", {"A_TYPE": "float16_t", "B_TYPE": "float", "D_TYPE": "float"}))
# Norms
tasks.append(string_to_spv("norm_f32", "norm.comp", base_dict | {"A_TYPE": "float", "D_TYPE": "float"}))
tasks.append(string_to_spv("rms_norm_f32", "rms_norm.comp", base_dict | {"A_TYPE": "float", "D_TYPE": "float"}))
tasks.append(string_to_spv("cpy_f32_f32", "copy.comp", {"A_TYPE": "float", "D_TYPE": "float"}))
tasks.append(string_to_spv("cpy_f32_f16", "copy.comp", {"A_TYPE": "float", "D_TYPE": "float16_t"}))
tasks.append(string_to_spv("cpy_f16_f16", "copy.comp", {"A_TYPE": "float16_t", "D_TYPE": "float16_t", "OPTIMIZATION_ERROR_WORKAROUND": "1"}))
tasks.append(string_to_spv("add_f32", "add.comp", {"A_TYPE": "float", "B_TYPE": "float", "D_TYPE": "float", "FLOAT_TYPE": "float"}))
tasks.append(string_to_spv("split_k_reduce", "mul_mat_split_k_reduce.comp", {}))
tasks.append(string_to_spv("mul_f32", "mul.comp", {"A_TYPE": "float", "B_TYPE": "float", "D_TYPE": "float", "FLOAT_TYPE": "float"}))
tasks.append(string_to_spv("div_f32", "div.comp", {"A_TYPE": "float", "B_TYPE": "float", "D_TYPE": "float", "FLOAT_TYPE": "float"}))
tasks.append(string_to_spv("scale_f32", "scale.comp", {"A_TYPE": "float", "D_TYPE": "float", "FLOAT_TYPE": "float"}))
tasks.append(string_to_spv("sqr_f32", "square.comp", {"A_TYPE": "float", "D_TYPE": "float", "FLOAT_TYPE": "float"}))
tasks.append(string_to_spv("clamp_f32", "clamp.comp", {"A_TYPE": "float", "D_TYPE": "float", "FLOAT_TYPE": "float"}))
tasks.append(string_to_spv("gelu_f32", "gelu.comp", {"A_TYPE": "float", "D_TYPE": "float"}))
tasks.append(string_to_spv("silu_f32", "silu.comp", {"A_TYPE": "float", "D_TYPE": "float"}))
tasks.append(string_to_spv("relu_f32", "relu.comp", {"A_TYPE": "float", "D_TYPE": "float"}))
tasks.append(string_to_spv("diag_mask_inf_f32", "diag_mask_inf.comp", {"A_TYPE": "float", "D_TYPE": "float"}))
tasks.append(string_to_spv("soft_max_f32", "soft_max.comp", base_dict | {"A_TYPE": "float", "B_TYPE": "float", "D_TYPE": "float"}))
tasks.append(string_to_spv("soft_max_f32_f16", "soft_max.comp", base_dict | {"A_TYPE": "float", "B_TYPE": "float16_t", "D_TYPE": "float"}))
tasks.append(string_to_spv("rope_norm_f32", "rope_norm.comp", {"A_TYPE": "float", "D_TYPE": "float"}))
tasks.append(string_to_spv("rope_norm_f16", "rope_norm.comp", {"A_TYPE": "float16_t", "D_TYPE": "float16_t"}))
tasks.append(string_to_spv("rope_neox_f32", "rope_neox.comp", {"A_TYPE": "float", "D_TYPE": "float"}))
tasks.append(string_to_spv("rope_neox_f16", "rope_neox.comp", {"A_TYPE": "float16_t", "D_TYPE": "float16_t"}))
tasks.append(string_to_spv("argsort_f32", "argsort.comp", {"A_TYPE": "float"}))
tasks.append(string_to_spv("sum_rows_f32", "sum_rows.comp", base_dict | {"A_TYPE": "float", "D_TYPE": "float"}))
# Helper to decorate tasks with semaphore acquisition.
async def withSemaphore(sem, task):
async with sem:
return await task
# Run tasks concurrently guarded by a concurrency limit.
sem = asyncio.Semaphore(ASYNCIO_CONCURRENCY)
await asyncio.gather(*(withSemaphore(sem, task) for task in tasks))
with open("ggml-vulkan-shaders.hpp", "w") as f:
f.write("#include <cstdint>\n\n")
for name, path in sorted(shader_fnames):
with open(path, "rb") as spv:
counter = 0
newline_counter = 0
f.write(f"unsigned char {name}_data[] = {{\n")
for val in spv.read():
f.write(f"0x{val:02x},")
newline_counter += 1
counter += 1
if newline_counter >= 12:
newline_counter = 0
f.write("\n")
f.write("\n};\n")
f.write(f"const uint64_t {name}_len = {counter};\n\n")
os.remove(path)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="GGML Vulkan Shader Generator")
parser.add_argument("--glslc", help="Path to glslc")
parser.add_argument("--verbose", action="store_true", help="increase output verbosity")
args = parser.parse_args()
logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)
if args.glslc:
GLSLC = args.glslc
asyncio.run(main())

View File

@ -714,7 +714,7 @@ extern "C" {
GGML_API GGML_CALL size_t ggml_nbytes (const struct ggml_tensor * tensor); GGML_API GGML_CALL size_t ggml_nbytes (const struct ggml_tensor * tensor);
GGML_API size_t ggml_nbytes_pad (const struct ggml_tensor * tensor); // same as ggml_nbytes() but padded to GGML_MEM_ALIGN GGML_API size_t ggml_nbytes_pad (const struct ggml_tensor * tensor); // same as ggml_nbytes() but padded to GGML_MEM_ALIGN
GGML_API GGML_CALL int ggml_blck_size(enum ggml_type type); GGML_API GGML_CALL int64_t ggml_blck_size(enum ggml_type type);
GGML_API GGML_CALL size_t ggml_type_size(enum ggml_type type); // size in bytes for all elements in a block GGML_API GGML_CALL size_t ggml_type_size(enum ggml_type type); // size in bytes for all elements in a block
GGML_API GGML_CALL size_t ggml_row_size (enum ggml_type type, int64_t ne); // size in bytes for all elements in a row GGML_API GGML_CALL size_t ggml_row_size (enum ggml_type type, int64_t ne); // size in bytes for all elements in a row
@ -2410,10 +2410,10 @@ extern "C" {
#endif #endif
typedef void (*ggml_to_float_t) (const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); typedef void (*ggml_to_float_t) (const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
typedef void (*ggml_from_float_t)(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); typedef void (*ggml_from_float_t)(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
typedef void (*ggml_from_float_to_mat_t)
(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t nr, int64_t k, int64_t bs);
typedef void (*ggml_vec_dot_t) (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT x, size_t bx, typedef void (*ggml_vec_dot_t) (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT x, size_t bx,
const void * GGML_RESTRICT y, size_t by, int nrc); const void * GGML_RESTRICT y, size_t by, int nrc);
typedef void (*ggml_from_float_to_mat_t)(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t nr,
int64_t k, int64_t bx);
typedef void (*ggml_gemv_t) (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT x, typedef void (*ggml_gemv_t) (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT x,
const void * GGML_RESTRICT y, int nr, int nc); const void * GGML_RESTRICT y, int nr, int nc);
typedef void (*ggml_gemm_t) (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT x, typedef void (*ggml_gemm_t) (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT x,
@ -2421,18 +2421,18 @@ extern "C" {
typedef struct { typedef struct {
const char * type_name; const char * type_name;
int blck_size; int64_t blck_size;
int64_t blck_size_interleave; // interleave elements in blocks
size_t type_size; size_t type_size;
bool is_quantized; bool is_quantized;
ggml_to_float_t to_float; ggml_to_float_t to_float;
ggml_from_float_t from_float; ggml_from_float_t from_float;
ggml_from_float_t from_float_reference; ggml_from_float_t from_float_ref;
ggml_from_float_to_mat_t from_float_to_mat;
ggml_vec_dot_t vec_dot; ggml_vec_dot_t vec_dot;
enum ggml_type vec_dot_type; enum ggml_type vec_dot_type;
int64_t nrows; // number of rows to process simultaneously; int64_t nrows; // number of rows to process simultaneously
int64_t ncols; // number of columns to process simultaneously; int64_t ncols; // number of columns to process simultaneously
int64_t interleave_blcksize; // interleave elements in blocks of interleave_blcksize;
ggml_from_float_to_mat_t from_float_to_mat;
ggml_gemv_t gemv; ggml_gemv_t gemv;
ggml_gemm_t gemm; ggml_gemm_t gemm;
} ggml_type_traits_t; } ggml_type_traits_t;

View File

@ -527,14 +527,11 @@ if (GGML_RPC)
endif() endif()
if (GGML_VULKAN) if (GGML_VULKAN)
find_package(Vulkan) find_package(Vulkan COMPONENTS glslc REQUIRED)
if (Vulkan_FOUND) if (Vulkan_FOUND)
message(STATUS "Vulkan found") message(STATUS "Vulkan found")
set(GGML_HEADERS_VULKAN ../include/ggml-vulkan.h)
set(GGML_SOURCES_VULKAN ggml-vulkan.cpp)
list(APPEND GGML_CDEF_PUBLIC GGML_USE_VULKAN) list(APPEND GGML_CDEF_PUBLIC GGML_USE_VULKAN)
# Workaround to the "can't dereference invalidated vector iterator" bug in clang-cl debug build # Workaround to the "can't dereference invalidated vector iterator" bug in clang-cl debug build
@ -563,7 +560,37 @@ if (GGML_VULKAN)
add_compile_definitions(GGML_VULKAN_RUN_TESTS) add_compile_definitions(GGML_VULKAN_RUN_TESTS)
endif() endif()
add_subdirectory(vulkan-shaders)
set (_ggml_vk_genshaders_cmd vulkan-shaders-gen)
set (_ggml_vk_header ${CMAKE_CURRENT_BINARY_DIR}/ggml-vulkan-shaders.hpp)
set (_ggml_vk_source ${CMAKE_CURRENT_BINARY_DIR}/ggml-vulkan-shaders.cpp)
set (_ggml_vk_input_dir ${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders)
set (_ggml_vk_output_dir ${CMAKE_CURRENT_BINARY_DIR}/vulkan-shaders.spv)
file(GLOB _ggml_vk_shader_deps "${_ggml_vk_input_dir}/*.comp")
add_custom_command(
OUTPUT ${_ggml_vk_header}
${_ggml_vk_source}
COMMAND ${_ggml_vk_genshaders_cmd}
--glslc ${Vulkan_GLSLC_EXECUTABLE}
--input-dir ${_ggml_vk_input_dir}
--output-dir ${_ggml_vk_output_dir}
--target-hpp ${_ggml_vk_header}
--target-cpp ${_ggml_vk_source}
--no-clean
DEPENDS ${_ggml_vk_shader_deps}
COMMENT "Generate vulkan shaders"
)
set(GGML_HEADERS_VULKAN ${CMAKE_CURRENT_SOURCE_DIR}/../include/ggml-vulkan.h ${_ggml_vk_header})
set(GGML_SOURCES_VULKAN ggml-vulkan.cpp ${_ggml_vk_source})
set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} Vulkan::Vulkan) set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} Vulkan::Vulkan)
set(GGML_EXTRA_INCLUDES ${GGML_EXTRA_INCLUDES} ${CMAKE_CURRENT_BINARY_DIR})
else() else()
message(WARNING "Vulkan not found") message(WARNING "Vulkan not found")
endif() endif()

View File

@ -20,19 +20,19 @@
// Functions to create the interleaved data layout formats // Functions to create the interleaved data layout formats
// interleave 4 block_q4_0s in blocks of interleave_blcksize // interleave 4 block_q4_0s in blocks of blck_size_interleave
// returns an interleaved block_q4_0x4 // returns an interleaved block_q4_0x4
// in the interleaved block_q4_0x4, place deltas for 4 block_q4_0 blocks // in the interleaved block_q4_0x4, place deltas for 4 block_q4_0 blocks
// first, then interleave quants from 4 block_q4_0s in blocks of interleave_blcksize // first, then interleave quants from 4 block_q4_0s in blocks of blck_size_interleave
// //
// - in : an array of block_q4_0 pointers // - in : an array of block_q4_0 pointers
// - interleave_blcksize : the block_q4_0 quants bytes are interleaved in blocks of // - blck_size_interleave : the block_q4_0 quants bytes are interleaved in blocks of
// interleave_blcksize bytes // blck_size_interleave bytes
// - xor_mask : the mask to convert the nibbles in block_q4_0 quants bytes // - xor_mask : the mask to convert the nibbles in block_q4_0 quants bytes
// from bias offset form to pure sign form (this saves subtract // from bias offset form to pure sign form (this saves subtract
// operations durin unpacking) // operations durin unpacking)
// //
static block_q4_0x4 make_block_q4_0x4(block_q4_0 * in, unsigned int interleave_blcksize, unsigned int xor_mask) { static block_q4_0x4 make_block_q4_0x4(block_q4_0 * in, unsigned int blck_size_interleave, unsigned int xor_mask) {
block_q4_0x4 out; block_q4_0x4 out;
for (int i = 0; i < 4; i++) { for (int i = 0; i < 4; i++) {
@ -40,9 +40,9 @@ static block_q4_0x4 make_block_q4_0x4(block_q4_0 * in, unsigned int interleave_b
} }
for (int i = 0; i < QK4_0 * 2; i++) { for (int i = 0; i < QK4_0 * 2; i++) {
int src_offset = (i / (4 * interleave_blcksize)) * interleave_blcksize; int src_offset = (i / (4 * blck_size_interleave)) * blck_size_interleave;
int src_id = (i % (4 * interleave_blcksize)) / interleave_blcksize; int src_id = (i % (4 * blck_size_interleave)) / blck_size_interleave;
src_offset += (i % interleave_blcksize); src_offset += (i % blck_size_interleave);
out.qs[i] = in[src_id].qs[src_offset] ^ xor_mask; out.qs[i] = in[src_id].qs[src_offset] ^ xor_mask;
} }
@ -50,11 +50,11 @@ static block_q4_0x4 make_block_q4_0x4(block_q4_0 * in, unsigned int interleave_b
return out; return out;
} }
// interleave 8 block_q4_0s in blocks of interleave_blcksize // interleave 8 block_q4_0s in blocks of blck_size_interleave
// returns an interleaved block_q4_0x8 // returns an interleaved block_q4_0x8
// in the interleaved block_q4_0x8, place deltas for 8 block_q4_0 blocks // in the interleaved block_q4_0x8, place deltas for 8 block_q4_0 blocks
// first, then interleave quants from 8 block_q4_0s in blocks of interleave_blcksize // first, then interleave quants from 8 block_q4_0s in blocks of blck_size_interleave
static block_q4_0x8 make_block_q4_0x8(block_q4_0 * in, unsigned int interleave_blcksize, unsigned int xor_mask) { static block_q4_0x8 make_block_q4_0x8(block_q4_0 * in, unsigned int blck_size_interleave, unsigned int xor_mask) {
block_q4_0x8 out; block_q4_0x8 out;
for (int i = 0; i < 8; i++) { for (int i = 0; i < 8; i++) {
@ -62,9 +62,9 @@ static block_q4_0x8 make_block_q4_0x8(block_q4_0 * in, unsigned int interleave_b
} }
for (int i = 0; i < QK4_0 * 4; i++) { for (int i = 0; i < QK4_0 * 4; i++) {
int src_offset = (i / (8 * interleave_blcksize)) * interleave_blcksize; int src_offset = (i / (8 * blck_size_interleave)) * blck_size_interleave;
int src_id = (i % (8 * interleave_blcksize)) / interleave_blcksize; int src_id = (i % (8 * blck_size_interleave)) / blck_size_interleave;
src_offset += (i % interleave_blcksize); src_offset += (i % blck_size_interleave);
out.qs[i] = in[src_id].qs[src_offset] ^ xor_mask; out.qs[i] = in[src_id].qs[src_offset] ^ xor_mask;
} }
@ -135,7 +135,7 @@ void quantize_q8_0_4x4(const float * restrict x, void * restrict vy, int64_t k)
} }
#else #else
// scalar // scalar
const int interleave_blcksize = 4; const int blck_size_interleave = 4;
float srcv[4][QK8_0]; float srcv[4][QK8_0];
float id[4]; float id[4];
@ -155,12 +155,12 @@ void quantize_q8_0_4x4(const float * restrict x, void * restrict vy, int64_t k)
} }
for (int j = 0; j < QK8_0 * 4; j++) { for (int j = 0; j < QK8_0 * 4; j++) {
int src_offset = (j / (4 * interleave_blcksize)) * interleave_blcksize; int src_offset = (j / (4 * blck_size_interleave)) * blck_size_interleave;
int src_id = (j % (4 * interleave_blcksize)) / interleave_blcksize; int src_id = (j % (4 * blck_size_interleave)) / blck_size_interleave;
src_offset += (j % interleave_blcksize); src_offset += (j % blck_size_interleave);
float x0 = srcv[src_id][src_offset] * id[src_id]; float x0 = srcv[src_id][src_offset] * id[src_id];
y[i].qs[j] = roundf(x0);; y[i].qs[j] = roundf(x0);
} }
} }
#endif #endif
@ -253,7 +253,7 @@ void quantize_q8_0_4x8(const float * restrict x, void * restrict vy, int64_t k)
} }
#else #else
// scalar // scalar
const int interleave_blcksize = 8; const int blck_size_interleave = 8;
float srcv[4][QK8_0]; float srcv[4][QK8_0];
float id[4]; float id[4];
@ -273,26 +273,30 @@ void quantize_q8_0_4x8(const float * restrict x, void * restrict vy, int64_t k)
} }
for (int j = 0; j < QK8_0 * 4; j++) { for (int j = 0; j < QK8_0 * 4; j++) {
int src_offset = (j / (4 * interleave_blcksize)) * interleave_blcksize; int src_offset = (j / (4 * blck_size_interleave)) * blck_size_interleave;
int src_id = (j % (4 * interleave_blcksize)) / interleave_blcksize; int src_id = (j % (4 * blck_size_interleave)) / blck_size_interleave;
src_offset += (j % interleave_blcksize); src_offset += (j % blck_size_interleave);
float x0 = srcv[src_id][src_offset] * id[src_id]; float x0 = srcv[src_id][src_offset] * id[src_id];
y[i].qs[j] = roundf(x0);; y[i].qs[j] = roundf(x0);
} }
} }
#endif #endif
} }
void quantize_mat_q8_0(const float * restrict x, void * restrict vy, int64_t nrow, int64_t n_per_row, int64_t interleave_blcksize) { void quantize_mat_q8_0(const float * restrict x, void * restrict vy, int64_t nrow, int64_t n_per_row, int64_t blck_size_interleave) {
assert(nrow == 4); assert(nrow == 4);
UNUSED(nrow); UNUSED(nrow);
if (interleave_blcksize == 4) quantize_q8_0_4x4(x, vy, n_per_row); if (blck_size_interleave == 4) {
else if (interleave_blcksize == 8) quantize_q8_0_4x8(x, vy, n_per_row); quantize_q8_0_4x4(x, vy, n_per_row);
else assert(false); } else if (blck_size_interleave == 8) {
quantize_q8_0_4x8(x, vy, n_per_row);
} else {
assert(false);
}
} }
static size_t quantize_q4_0_nr_bl(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, int nrows_interleaved, int interleave_blcksize) { static size_t quantize_q4_0_nr_bl(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, int nrows_interleaved, int blck_size_interleave) {
assert(n_per_row % QK4_0 == 0); assert(n_per_row % QK4_0 == 0);
const int nb = n_per_row / QK4_0; const int nb = n_per_row / QK4_0;
@ -311,15 +315,15 @@ static size_t quantize_q4_0_nr_bl(const float * restrict src, void * restrict ds
for (int64_t x = 0; x < nb; x++) { for (int64_t x = 0; x < nb; x++) {
for (int i = 0; i < nrows_interleaved; i++ ) { for (int i = 0; i < nrows_interleaved; i++ ) {
quantize_row_q4_0_reference(src + b + i * n_per_row + x * QK4_0, (block_q4_0 *) dst_tmp + i, QK4_0); quantize_row_q4_0_ref(src + b + i * n_per_row + x * QK4_0, (block_q4_0 *) dst_tmp + i, QK4_0);
} }
if (nrows_interleaved == 8) { if (nrows_interleaved == 8) {
*(block_q4_0x8 *) out_ptr = make_block_q4_0x8(dst_tmp, interleave_blcksize, 0x88); *(block_q4_0x8 *) out_ptr = make_block_q4_0x8(dst_tmp, blck_size_interleave, 0x88);
out_ptr = (block_q4_0x8 *) out_ptr + 1; out_ptr = (block_q4_0x8 *) out_ptr + 1;
} }
else if (nrows_interleaved == 4) { else if (nrows_interleaved == 4) {
*(block_q4_0x4 *) out_ptr = make_block_q4_0x4(dst_tmp, interleave_blcksize, 0x88); *(block_q4_0x4 *) out_ptr = make_block_q4_0x4(dst_tmp, blck_size_interleave, 0x88);
out_ptr = (block_q4_0x4 *) out_ptr + 1; out_ptr = (block_q4_0x4 *) out_ptr + 1;
} }
} }

View File

@ -16,7 +16,7 @@ extern "C" {
void quantize_q8_0_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_q8_0_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
void quantize_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
void quantize_mat_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t nrows, int64_t n_per_row, int64_t interleave_blcksize); void quantize_mat_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t nrows, int64_t n_per_row, int64_t blck_size_interleave);
// Quantization utilizing an importance matrix (a.k.a. "Activation aWare Quantization") // Quantization utilizing an importance matrix (a.k.a. "Activation aWare Quantization")
size_t quantize_q4_0_4x4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); size_t quantize_q4_0_4x4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
@ -24,14 +24,14 @@ size_t quantize_q4_0_4x8(const float * GGML_RESTRICT src, void * GGML_RESTRICT d
size_t quantize_q4_0_8x8(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); size_t quantize_q4_0_8x8(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
// GEMV // GEMV
void ggml_gemv_q4_0_4x4_q8_0 (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q4_0_4x8_q8_0 (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q4_0_8x8_q8_0 (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
// GEMM // GEMM
void ggml_gemm_q4_0_4x4_q8_0 (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q4_0_4x8_q8_0 (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q4_0_8x8_q8_0 (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
#ifdef __cplusplus #ifdef __cplusplus
} }

View File

@ -394,7 +394,7 @@ void ggml_backend_event_wait(ggml_backend_t backend, ggml_backend_event_t event)
// backend registry // backend registry
#define GGML_REG_MAX_BACKENDS 16 #define GGML_REG_MAX_BACKENDS 64
struct ggml_backend_reg { struct ggml_backend_reg {
char name[128]; char name[128];

View File

@ -8,11 +8,12 @@
# include <Accelerate/Accelerate.h> # include <Accelerate/Accelerate.h>
#elif defined(GGML_BLAS_USE_MKL) #elif defined(GGML_BLAS_USE_MKL)
# include <mkl.h> # include <mkl.h>
#elif defined(GGML_BLAS_USE_BLIS)
# include <blis.h>
#elif defined(GGML_BLAS_USE_NVPL)
# include <nvpl_blas.h>
#else #else
# include <cblas.h> # include <cblas.h>
# ifdef BLIS_ENABLE_CBLAS
# include <blis.h>
# endif
#endif #endif
struct ggml_backend_blas_context { struct ggml_backend_blas_context {
@ -140,10 +141,14 @@ static void ggml_backend_blas_mul_mat(ggml_backend_blas_context * ctx, struct gg
openblas_set_num_threads(ctx->n_threads); openblas_set_num_threads(ctx->n_threads);
#endif #endif
#if defined(BLIS_ENABLE_CBLAS) #if defined(GGML_BLAS_USE_BLIS)
bli_thread_set_num_threads(ctx->n_threads); bli_thread_set_num_threads(ctx->n_threads);
#endif #endif
#if defined(GGML_BLAS_USE_NVPL)
nvpl_blas_set_num_threads(ctx->n_threads);
#endif
for (int64_t i13 = 0; i13 < ne13; i13++) { for (int64_t i13 = 0; i13 < ne13; i13++) {
for (int64_t i12 = 0; i12 < ne12; i12++) { for (int64_t i12 = 0; i12 < ne12; i12++) {
const int64_t i03 = i13/r3; const int64_t i03 = i13/r3;

View File

@ -104,7 +104,7 @@
#define cudaStreamWaitEvent(stream, event, flags) hipStreamWaitEvent(stream, event, flags) #define cudaStreamWaitEvent(stream, event, flags) hipStreamWaitEvent(stream, event, flags)
#define cudaStream_t hipStream_t #define cudaStream_t hipStream_t
#define cudaSuccess hipSuccess #define cudaSuccess hipSuccess
#define __trap abort #define __trap() do { abort(); __builtin_unreachable(); } while(0)
#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS #define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
#define CUBLAS_STATUS_NOT_INITIALIZED HIPBLAS_STATUS_NOT_INITIALIZED #define CUBLAS_STATUS_NOT_INITIALIZED HIPBLAS_STATUS_NOT_INITIALIZED
#define CUBLAS_STATUS_ALLOC_FAILED HIPBLAS_STATUS_ALLOC_FAILED #define CUBLAS_STATUS_ALLOC_FAILED HIPBLAS_STATUS_ALLOC_FAILED

View File

@ -70,6 +70,10 @@ struct mma_int_A_I16K8 {
} }
#endif // defined(INT8_MMA_AVAILABLE) #endif // defined(INT8_MMA_AVAILABLE)
} }
__device__ __forceinline__ void load_low(const int * __restrict__ xs0, const int & stride) {
((mma_int_A_I16K4 *) x)[0].load(xs0, stride);
}
}; };
struct mma_int_B_J8K4 { struct mma_int_B_J8K4 {

File diff suppressed because it is too large Load Diff

View File

@ -37,47 +37,92 @@ static __global__ void quantize_q8_1(const float * __restrict__ x, void * __rest
reinterpret_cast<half&>(y[ib].ds.y) = sum; reinterpret_cast<half&>(y[ib].ds.y) = sum;
} }
template <bool need_sum> template <mmq_q8_1_ds_layout ds_layout>
static __global__ void quantize_mmq_q8_1( static __global__ void quantize_mmq_q8_1(
const float * __restrict__ x, void * __restrict__ vy, const int64_t kx0, const int64_t kx1, const int64_t kx0_padded) { const float * __restrict__ x, void * __restrict__ vy, const int64_t kx0, const int64_t kx1, const int64_t kx0_padded) {
const int64_t ix0 = (int64_t)blockDim.x*blockIdx.x + threadIdx.x; constexpr int vals_per_scale = ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6 ? 64 : 32;
constexpr int vals_per_sum = ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6 ? 16 : 32;
const int64_t ix0 = ((int64_t)blockDim.x*blockIdx.x + threadIdx.x)*4;
if (ix0 >= kx0_padded) { if (ix0 >= kx0_padded) {
return; return;
} }
const float4 * x4 = (const float4 *) x;
const int64_t ix1 = kx1*blockIdx.z + blockIdx.y; const int64_t ix1 = kx1*blockIdx.z + blockIdx.y;
block_q8_1_mmq * y = (block_q8_1_mmq *) vy; block_q8_1_mmq * y = (block_q8_1_mmq *) vy;
const int64_t ib0 = blockIdx.z*(gridDim.y*gridDim.x*blockDim.x/(4*QK8_1)); // first block of channel const int64_t ib0 = blockIdx.z*((int64_t)gridDim.y*gridDim.x*blockDim.x/QK8_1); // first block of channel
const int64_t ib = ib0 + (ix0 / (4*QK8_1))*kx1 + blockIdx.y; // block index in channel const int64_t ib = ib0 + (ix0 / (4*QK8_1))*kx1 + blockIdx.y; // block index in channel
const int64_t iqs = ix0 % (4*QK8_1); // quant index in block const int64_t iqs = ix0 % (4*QK8_1); // quant index in block
const float xi = ix0 < kx0 ? x[ix1*kx0 + ix0] : 0.0f; // Load 4 floats per thread and calculate max. abs. value between them:
float amax = fabsf(xi); const float4 xi = ix0 < kx0 ? x4[(ix1*kx0 + ix0)/4] : make_float4(0.0f, 0.0f, 0.0f, 0.0f);
float amax = fabsf(xi.x);
amax = fmaxf(amax, fabsf(xi.y));
amax = fmaxf(amax, fabsf(xi.z));
amax = fmaxf(amax, fabsf(xi.w));
amax = warp_reduce_max(amax); // Exchange max. abs. value between vals_per_scale/4 threads.
#pragma unroll
float sum; for (int mask = vals_per_scale/8; mask > 0; mask >>= 1) {
if (need_sum) { amax = fmaxf(amax, __shfl_xor_sync(0xFFFFFFFF, amax, mask, WARP_SIZE));
sum = warp_reduce_sum(xi);
} }
const float d = amax / 127; float sum;
const int8_t q = amax == 0.0f ? 0 : roundf(xi / d); if (ds_layout != MMQ_Q8_1_DS_LAYOUT_D4) {
sum = xi.x + xi.y + xi.z + xi.w;
y[ib].qs[iqs] = q; // Exchange calculate sum across vals_per_sum/4 threads.
#pragma unroll
for (int mask = vals_per_sum/8; mask > 0; mask >>= 1) {
sum += __shfl_xor_sync(0xFFFFFFFF, sum, mask, WARP_SIZE);
}
}
if (iqs % QK8_1 != 0) { const float d_inv = 127.0f / amax;
char4 q;
q.x = roundf(xi.x*d_inv);
q.y = roundf(xi.y*d_inv);
q.z = roundf(xi.z*d_inv);
q.w = roundf(xi.w*d_inv);
// Write back 4 int8 values as a single 32 bit value for better memroy bandwidth:
char4 * yqs4 = (char4 *) y[ib].qs;
yqs4[iqs/4] = q;
if (ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6) {
if (iqs % 16 != 0 || iqs >= 96) {
return; return;
} }
if (need_sum) { y[ib].d2s6[2 + iqs/16] = sum;
y[ib].ds[iqs/QK8_1] = make_half2(d, sum);
if (iqs % 64 != 0) {
return;
}
const float d = 1.0f / d_inv;
y[ib].d2s6[iqs/64] = d;
return;
}
if (iqs % 32 != 0) {
return;
}
const float d = 1.0f / d_inv;
if (ds_layout == MMQ_Q8_1_DS_LAYOUT_DS4) {
y[ib].ds4[iqs/32] = make_half2(d, sum);
} else { } else {
((float *) y[ib].ds)[iqs/QK8_1] = d; y[ib].d4[iqs/32] = d;
} }
} }
@ -101,12 +146,24 @@ void quantize_mmq_q8_1_cuda(
GGML_ASSERT(kx0_padded % (4*QK8_1) == 0); GGML_ASSERT(kx0_padded % (4*QK8_1) == 0);
const int64_t block_num_x = (kx0_padded + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE; const int64_t block_num_x = (kx0_padded + 4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ - 1) / (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ);
const dim3 num_blocks(block_num_x, kx1, channels); const dim3 num_blocks(block_num_x, kx1, channels);
const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE, 1, 1); const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE_MMQ, 1, 1);
if (mmq_need_sum(type_x)) { switch (mmq_get_q8_1_ds_layout(type_x)) {
quantize_mmq_q8_1<true><<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx1, kx0_padded); case MMQ_Q8_1_DS_LAYOUT_D4:
} else { quantize_mmq_q8_1<MMQ_Q8_1_DS_LAYOUT_D4>
quantize_mmq_q8_1<false><<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx1, kx0_padded); <<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx1, kx0_padded);
break;
case MMQ_Q8_1_DS_LAYOUT_DS4:
quantize_mmq_q8_1<MMQ_Q8_1_DS_LAYOUT_DS4>
<<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx1, kx0_padded);
break;
case MMQ_Q8_1_DS_LAYOUT_D2S6:
quantize_mmq_q8_1<MMQ_Q8_1_DS_LAYOUT_D2S6>
<<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx1, kx0_padded);
break;
default:
GGML_ASSERT(false);
break;
} }
} }

View File

@ -6,6 +6,10 @@
#include <cstdint> #include <cstdint>
#define CUDA_QUANTIZE_BLOCK_SIZE 256 #define CUDA_QUANTIZE_BLOCK_SIZE 256
#define CUDA_QUANTIZE_BLOCK_SIZE_MMQ 128
static_assert(MATRIX_ROW_PADDING % CUDA_QUANTIZE_BLOCK_SIZE == 0, "Risk of out-of-bounds access.");
static_assert(MATRIX_ROW_PADDING % (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ) == 0, "Risk of out-of-bounds access.");
typedef void (*quantize_cuda_t)( typedef void (*quantize_cuda_t)(
const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels, const int64_t kx0_padded, const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels, const int64_t kx0_padded,

View File

@ -189,7 +189,7 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q8_1_q8_1_imp
} }
#define VDR_Q2_K_Q8_1_MMVQ 1 #define VDR_Q2_K_Q8_1_MMVQ 1
#define VDR_Q2_K_Q8_1_MMQ 2 #define VDR_Q2_K_Q8_1_MMQ 4
// contiguous v/x values // contiguous v/x values
static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmvq( static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmvq(
@ -219,32 +219,56 @@ static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmvq(
return dm2f.x*sumf_d - dm2f.y*sumf_m; return dm2f.x*sumf_d - dm2f.y*sumf_m;
} }
// contiguous u/y values // contiguous v/x + u/y values
template <int ns8>
static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmq( static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmq(
const int * __restrict__ v, const int * __restrict__ u, const half2 * dm2, const float & d8) { const int * __restrict__ v, const int * __restrict__ u, const half2 * dm2, const float & d8, const half2 * s8) {
float sumf_d = 0.0f; float sumf = 0.0f;
float sumf_m = 0.0f; float sumf_d8 = 0.0f;
#pragma unroll #pragma unroll
for (int i0 = 0; i0 < QI8_1; i0 += QI8_1/2) { for (int i0 = 0; i0 < QR2_K*VDR_Q2_K_Q8_1_MMQ; i0 += QI8_1) {
const float2 dm2f = __half22float2(dm2[i0/(QI8_1/2)]); const float2 dm2f0 = __half22float2(dm2[i0/(QI8_1/2) + 0]);
int sumi_d = 0; int sumi_d0 = 0;
int sumi_m = 0;
const float2 dm2f1 = __half22float2(dm2[i0/(QI8_1/2) + 1]);
int sumi_d1 = 0;
const int vi0 = v[i0/(QI8_1/2)];
#pragma unroll #pragma unroll
for (int i = i0; i < i0 + QI8_1/2; ++i) { for (int i = i0; i < i0 + QI8_1/2; ++i) {
const int vi = (vi0 >> (2*(i % (QI8_1/2)))) & 0x03030303; sumi_d0 = ggml_cuda_dp4a(v[i], u[i], sumi_d0);
sumi_d = ggml_cuda_dp4a(vi, u[i], sumi_d); // SIMD dot product }
sumi_m = ggml_cuda_dp4a(0x01010101, u[i], sumi_m); sumf_d8 += dm2f0.x * sumi_d0;
#pragma unroll
for (int i = i0 + QI8_1/2; i < i0 + QI8_1; ++i) {
sumi_d1 = ggml_cuda_dp4a(v[i], u[i], sumi_d1);
}
sumf_d8 += dm2f1.x * sumi_d1;
if (i0/QI8_1 < ns8) {
const float2 s8f = __half22float2(s8[i0/QI8_1]);
sumf -= dm2f0.y*s8f.x;
sumf -= dm2f1.y*s8f.y;
} else {
int sumi_m0 = 0;
#pragma unroll
for (int i = i0; i < i0 + QI8_1/2; ++i) {
sumi_m0 = ggml_cuda_dp4a(0x01010101, u[i], sumi_m0);
}
sumf_d8 -= dm2f0.y * sumi_m0;
int sumi_m1 = 0;
#pragma unroll
for (int i = i0 + QI8_1/2; i < i0 + QI8_1; ++i) {
sumi_m1 = ggml_cuda_dp4a(0x01010101, u[i], sumi_m1);
}
sumf_d8 -= dm2f1.y * sumi_m1;
}
} }
sumf_d += dm2f.x * sumi_d; return sumf + d8*sumf_d8;
sumf_m += dm2f.y * sumi_m;
}
return d8*(sumf_d - sumf_m);
} }
#define VDR_Q3_K_Q8_1_MMVQ 1 #define VDR_Q3_K_Q8_1_MMVQ 1
@ -283,7 +307,7 @@ static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmvq(
return d3 * sumf; return d3 * sumf;
} }
// contiguous u/y values // contiguous v/x + u/y values
static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmq( static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmq(
const int * __restrict__ v, const int * __restrict__ u, const int8_t * __restrict__ scales, const int * __restrict__ v, const int * __restrict__ u, const int8_t * __restrict__ scales,
const float & d3, const float & d8) { const float & d3, const float & d8) {
@ -296,8 +320,7 @@ static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmq(
#pragma unroll #pragma unroll
for (int i = i0; i < i0 + QI8_1/2; ++i) { for (int i = i0; i < i0 + QI8_1/2; ++i) {
const int vi = __vsubss4((v[i/2] >> (4*(i%2))) & 0x0F0F0F0F, 0x04040404); sumi_sc = ggml_cuda_dp4a(v[i], u[i], sumi_sc); // SIMD dot product
sumi_sc = ggml_cuda_dp4a(vi, u[i], sumi_sc); // SIMD dot product
} }
sumi += sumi_sc * scales[i0 / (QI8_1/2)]; sumi += sumi_sc * scales[i0 / (QI8_1/2)];
@ -334,7 +357,7 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_vmmq(
return dm4f.x*sumf_d - dm4f.y*sumf_m; return dm4f.x*sumf_d - dm4f.y*sumf_m;
} }
// contiguous u/y values // contiguous v/x + u/y values
static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_mmq( static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_mmq(
const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ sc, const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ sc,
const uint8_t * __restrict__ m, const half2 & dm4, const half2 * __restrict__ ds8) { const uint8_t * __restrict__ m, const half2 & dm4, const half2 * __restrict__ ds8) {
@ -397,7 +420,7 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_vmmq(
return dm5f.x*sumf_d - dm5f.y*sumf_m; return dm5f.x*sumf_d - dm5f.y*sumf_m;
} }
// contiguous u/y values // contiguous v/x + u/y values
static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_mmq( static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_mmq(
const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ sc, const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ sc,
const uint8_t * __restrict__ m, const half2 & dm4, const half2 * __restrict__ ds8) { const uint8_t * __restrict__ m, const half2 & dm4, const half2 * __restrict__ ds8) {
@ -451,13 +474,16 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmvq(
return d*sumf; return d*sumf;
} }
// contiguous u/y values // contiguous v/x + u/y values
static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmq( static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmq(
const int * __restrict__ v, const int * __restrict__ u, const int8_t * __restrict__ sc, const int * __restrict__ v, const int * __restrict__ u, const int8_t * __restrict__ sc,
const float & d6, const float * __restrict__ d8) { const float & d6, const float * __restrict__ d8) {
float sumf_d = 0.0f; float sumf_d = 0.0f;
const int sc_packed = get_int_b4(sc, 0);
const int8_t * sc_reg = (const int8_t *) &sc_packed;
#pragma unroll #pragma unroll
for (int i0 = 0; i0 < VDR_Q6_K_Q8_1_MMQ; i0 += 4) { for (int i0 = 0; i0 < VDR_Q6_K_Q8_1_MMQ; i0 += 4) {
int2 sumi_d = {0, 0}; // 2 q6_K scales per q8_1 scale int2 sumi_d = {0, 0}; // 2 q6_K scales per q8_1 scale
@ -471,7 +497,7 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmq(
sumi_d.y = ggml_cuda_dp4a(v[2*i+5], u[2*i+5], sumi_d.y); // SIMD dot product sumi_d.y = ggml_cuda_dp4a(v[2*i+5], u[2*i+5], sumi_d.y); // SIMD dot product
} }
sumf_d += d8[i0/4] * (sc[i0/2+0]*sumi_d.x + sc[i0/2+1]*sumi_d.y); sumf_d += d8[i0/4] * (sc_reg[i0/2+0]*sumi_d.x + sc_reg[i0/2+1]*sumi_d.y);
} }
return d6 * sumf_d; return d6 * sumf_d;

View File

@ -193,16 +193,16 @@ enum ggml_metal_kernel_type {
//GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261 //GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,
//GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261 //GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261
GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
GGML_METAL_KERNEL_TYPE_CPY_F32_F32, GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
GGML_METAL_KERNEL_TYPE_CPY_F16_F16,
GGML_METAL_KERNEL_TYPE_CPY_F16_F32,
GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,
GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0,
GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1,
GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0,
GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1,
GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL,
GGML_METAL_KERNEL_TYPE_CPY_F16_F16,
GGML_METAL_KERNEL_TYPE_CPY_F16_F32,
GGML_METAL_KERNEL_TYPE_CONCAT, GGML_METAL_KERNEL_TYPE_CONCAT,
GGML_METAL_KERNEL_TYPE_SQR, GGML_METAL_KERNEL_TYPE_SQR,
GGML_METAL_KERNEL_TYPE_SUM_ROWS, GGML_METAL_KERNEL_TYPE_SUM_ROWS,
@ -651,14 +651,14 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, ctx->support_simdgroup_reduction); //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, cpy_f32_q4_0, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, cpy_f32_q4_0, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, cpy_f32_q5_0, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, cpy_f32_q5_0, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, cpy_f32_q5_1, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, cpy_f32_q5_1, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, cpy_f32_iq4_nl, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, cpy_f32_iq4_nl, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT, concat, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT, concat, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQR, sqr, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQR, sqr, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
@ -810,8 +810,8 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
switch (op->src[0]->type) { switch (op->src[0]->type) {
case GGML_TYPE_F32: case GGML_TYPE_F32:
switch (op->type) { switch (op->type) {
case GGML_TYPE_F16:
case GGML_TYPE_F32: case GGML_TYPE_F32:
case GGML_TYPE_F16:
case GGML_TYPE_Q8_0: case GGML_TYPE_Q8_0:
case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1: case GGML_TYPE_Q4_1:
@ -824,8 +824,8 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
} }
case GGML_TYPE_F16: case GGML_TYPE_F16:
switch (op->type) { switch (op->type) {
case GGML_TYPE_F16:
case GGML_TYPE_F32: case GGML_TYPE_F32:
case GGML_TYPE_F16:
return true; return true;
default: default:
return false; return false;
@ -837,7 +837,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
case GGML_OP_DIAG_MASK_INF: case GGML_OP_DIAG_MASK_INF:
case GGML_OP_GET_ROWS: case GGML_OP_GET_ROWS:
{ {
return op->src[0]->type != GGML_TYPE_BF16 && op->ne[3] == 1; return op->ne[3] == 1;
} }
default: default:
return false; return false;
@ -2775,8 +2775,8 @@ static enum ggml_status ggml_metal_graph_compute(
GGML_ASSERT(ne0 % ggml_blck_size(dst->type) == 0); GGML_ASSERT(ne0 % ggml_blck_size(dst->type) == 0);
switch (dstt) { switch (dstt) {
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F16].pipeline; break;
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; break; case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; break;
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F16].pipeline; break;
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0].pipeline; break; case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0].pipeline; break;
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0].pipeline; break; case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0].pipeline; break;
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1].pipeline; break; case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1].pipeline; break;
@ -2789,8 +2789,8 @@ static enum ggml_status ggml_metal_graph_compute(
case GGML_TYPE_F16: case GGML_TYPE_F16:
{ {
switch (dstt) { switch (dstt) {
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F16].pipeline; break;
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F32].pipeline; break; case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F32].pipeline; break;
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F16].pipeline; break;
default: GGML_ASSERT(false && "not implemented"); default: GGML_ASSERT(false && "not implemented");
}; };
} break; } break;

View File

@ -1219,9 +1219,10 @@ kernel void kernel_mul_mv_q8_0_f32(
kernel_mul_mv_q8_0_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg); kernel_mul_mv_q8_0_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
} }
#define N_F32_F32 4 #define N_MV_T_T 4
void kernel_mul_mv_f32_f32_impl( template<typename T0, typename T04, typename T1, typename T14>
void kernel_mul_mv_impl(
device const char * src0, device const char * src0,
device const char * src1, device const char * src1,
device float * dst, device float * dst,
@ -1243,9 +1244,8 @@ void kernel_mul_mv_f32_f32_impl(
uint r3, uint r3,
uint3 tgpig, uint3 tgpig,
uint tiisg) { uint tiisg) {
const int64_t r0 = tgpig.x; const int64_t r0 = tgpig.x;
const int64_t rb = tgpig.y*N_F32_F32; const int64_t rb = tgpig.y*N_MV_T_T;
const int64_t im = tgpig.z; const int64_t im = tgpig.z;
const uint i12 = im%ne12; const uint i12 = im%ne12;
@ -1253,20 +1253,20 @@ void kernel_mul_mv_f32_f32_impl(
const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02; const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
device const float * x = (device const float *) (src0 + offset0); device const T0 * x = (device const T0 *) (src0 + offset0);
if (ne00 < 128) { if (ne00 < 128) {
for (int row = 0; row < N_F32_F32; ++row) { for (int row = 0; row < N_MV_T_T; ++row) {
int r1 = rb + row; int r1 = rb + row;
if (r1 >= ne11) { if (r1 >= ne11) {
break; break;
} }
device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12); device const T1 * y = (device const T1 *) (src1 + r1*nb11 + im*nb12);
float sumf = 0; float sumf = 0;
for (int i = tiisg; i < ne00; i += 32) { for (int i = tiisg; i < ne00; i += 32) {
sumf += (float) x[i] * (float) y[i]; sumf += (T0) x[i] * (T1) y[i];
} }
float all_sum = simd_sum(sumf); float all_sum = simd_sum(sumf);
@ -1275,32 +1275,32 @@ void kernel_mul_mv_f32_f32_impl(
} }
} }
} else { } else {
device const float4 * x4 = (device const float4 *)x; device const T04 * x4 = (device const T04 *) x;
for (int row = 0; row < N_F32_F32; ++row) { for (int row = 0; row < N_MV_T_T; ++row) {
int r1 = rb + row; int r1 = rb + row;
if (r1 >= ne11) { if (r1 >= ne11) {
break; break;
} }
device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12); device const T1 * y = (device const T1 *) (src1 + r1*nb11 + im*nb12);
device const float4 * y4 = (device const float4 *) y; device const T14 * y4 = (device const T14 *) y;
float sumf = 0; float sumf = 0;
for (int i = tiisg; i < ne00/4; i += 32) { for (int i = tiisg; i < ne00/4; i += 32) {
for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k]; for (int k = 0; k < 4; ++k) sumf += (float) (x4[i][k] * y4[i][k]);
} }
float all_sum = simd_sum(sumf); float all_sum = simd_sum(sumf);
if (tiisg == 0) { if (tiisg == 0) {
for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i]; for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) (x[i] * y[i]);
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
} }
} }
} }
} }
[[host_name("kernel_mul_mv_f32_f32")]] template<typename T0, typename T04, typename T1, typename T14>
kernel void kernel_mul_mv_f32_f32( kernel void kernel_mul_mv(
device const char * src0, device const char * src0,
device const char * src1, device const char * src1,
device float * dst, device float * dst,
@ -1322,90 +1322,38 @@ kernel void kernel_mul_mv_f32_f32(
constant uint & r3, constant uint & r3,
uint3 tgpig[[threadgroup_position_in_grid]], uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]]) { uint tiisg[[thread_index_in_simdgroup]]) {
kernel_mul_mv_f32_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg); kernel_mul_mv_impl<T0, T04, T1, T14>(
src0,
src1,
dst,
ne00,
ne01,
ne02,
nb00,
nb01,
nb02,
ne10,
ne11,
ne12,
nb10,
nb11,
nb12,
ne0,
ne1,
r2,
r3,
tgpig,
tiisg);
} }
#define N_F16_F16 4 typedef decltype(kernel_mul_mv<half, half4, half, half4>) mul_mv_t;
kernel void kernel_mul_mv_f16_f16( template [[host_name("kernel_mul_mv_f32_f32")]] kernel mul_mv_t kernel_mul_mv<float, float4, float, float4>;
device const char * src0, template [[host_name("kernel_mul_mv_f16_f32")]] kernel mul_mv_t kernel_mul_mv<half, half4, float, float4>;
device const char * src1, template [[host_name("kernel_mul_mv_f16_f16")]] kernel mul_mv_t kernel_mul_mv<half, half4, half, half4>;
device float * dst,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
constant uint64_t & nb00,
constant uint64_t & nb01,
constant uint64_t & nb02,
constant int64_t & ne10,
constant int64_t & ne11,
constant int64_t & ne12,
constant uint64_t & nb10,
constant uint64_t & nb11,
constant uint64_t & nb12,
constant int64_t & ne0,
constant int64_t & ne1,
constant uint & r2,
constant uint & r3,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]]) {
const int64_t r0 = tgpig.x; template<typename T, typename T4>
const int64_t rb = tgpig.y*N_F16_F16; kernel void kernel_mul_mv_1row(
const int64_t im = tgpig.z;
const uint i12 = im%ne12;
const uint i13 = im/ne12;
const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
device const half * x = (device const half *) (src0 + offset0);
if (ne00 < 128) {
for (int row = 0; row < N_F16_F16; ++row) {
int r1 = rb + row;
if (r1 >= ne11) {
break;
}
device const half * y = (device const half *) (src1 + r1*nb11 + im*nb12);
float sumf = 0;
for (int i = tiisg; i < ne00; i += 32) {
sumf += (half) x[i] * (half) y[i];
}
float all_sum = simd_sum(sumf);
if (tiisg == 0) {
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
}
}
} else {
device const half4 * x4 = (device const half4 *)x;
for (int row = 0; row < N_F16_F16; ++row) {
int r1 = rb + row;
if (r1 >= ne11) {
break;
}
device const half * y = (device const half *) (src1 + r1*nb11 + im*nb12);
device const half4 * y4 = (device const half4 *) y;
float sumf = 0;
for (int i = tiisg; i < ne00/4; i += 32) {
for (int k = 0; k < 4; ++k) sumf += (half) x4[i][k] * y4[i][k];
}
float all_sum = simd_sum(sumf);
if (tiisg == 0) {
for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (half) x[i] * y[i];
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
}
}
}
}
void kernel_mul_mv_f16_f32_1row_impl(
device const char * src0, device const char * src0,
device const char * src1, device const char * src1,
device float * dst, device float * dst,
@ -1437,7 +1385,7 @@ void kernel_mul_mv_f16_f32_1row_impl(
const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02; const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
device const half * x = (device const half *) (src0 + offset0); device const T * x = (device const T *) (src0 + offset0);
device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12); device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
float sumf = 0; float sumf = 0;
@ -1450,153 +1398,29 @@ void kernel_mul_mv_f16_f32_1row_impl(
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
} }
} else { } else {
device const half4 * x4 = (device const half4 *) x; device const T4 * x4 = (device const T4 *) x;
device const float4 * y4 = (device const float4 *) y;
for (int i = tiisg; i < ne00/4; i += 32) {
for (int k = 0; k < 4; ++k) sumf += (float)x4[i][k] * y4[i][k];
}
float all_sum = simd_sum(sumf);
if (tiisg == 0) {
for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i];
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
}
}
}
[[host_name("kernel_mul_mv_f16_f32_1row")]]
kernel void kernel_mul_mv_f16_f32_1row(
device const char * src0,
device const char * src1,
device float * dst,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
constant uint64_t & nb00,
constant uint64_t & nb01,
constant uint64_t & nb02,
constant int64_t & ne10,
constant int64_t & ne11,
constant int64_t & ne12,
constant uint64_t & nb10,
constant uint64_t & nb11,
constant uint64_t & nb12,
constant int64_t & ne0,
constant int64_t & ne1,
constant uint & r2,
constant uint & r3,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]]) {
kernel_mul_mv_f16_f32_1row_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
}
#define N_F16_F32 4
void kernel_mul_mv_f16_f32_impl(
device const char * src0,
device const char * src1,
device float * dst,
int64_t ne00,
int64_t ne01,
int64_t ne02,
uint64_t nb00,
uint64_t nb01,
uint64_t nb02,
int64_t ne10,
int64_t ne11,
int64_t ne12,
uint64_t nb10,
uint64_t nb11,
uint64_t nb12,
int64_t ne0,
int64_t ne1,
uint r2,
uint r3,
uint3 tgpig,
uint tiisg) {
const int64_t r0 = tgpig.x;
const int64_t rb = tgpig.y*N_F16_F32;
const int64_t im = tgpig.z;
const uint i12 = im%ne12;
const uint i13 = im/ne12;
const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
device const half * x = (device const half *) (src0 + offset0);
if (ne00 < 128) {
for (int row = 0; row < N_F16_F32; ++row) {
int r1 = rb + row;
if (r1 >= ne11) {
break;
}
device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
float sumf = 0;
for (int i = tiisg; i < ne00; i += 32) {
sumf += (float) x[i] * (float) y[i];
}
float all_sum = simd_sum(sumf);
if (tiisg == 0) {
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
}
}
} else {
device const half4 * x4 = (device const half4 *)x;
for (int row = 0; row < N_F16_F32; ++row) {
int r1 = rb + row;
if (r1 >= ne11) {
break;
}
device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
device const float4 * y4 = (device const float4 *) y; device const float4 * y4 = (device const float4 *) y;
float sumf = 0;
for (int i = tiisg; i < ne00/4; i += 32) { for (int i = tiisg; i < ne00/4; i += 32) {
for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k]; for (int k = 0; k < 4; ++k) sumf += (float) (x4[i][k] * y4[i][k]);
} }
float all_sum = simd_sum(sumf); float all_sum = simd_sum(sumf);
if (tiisg == 0) { if (tiisg == 0) {
for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i]; for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) (x[i] * y[i]);
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
} }
} }
}
} }
[[host_name("kernel_mul_mv_f16_f32")]] typedef decltype(kernel_mul_mv_1row<half, half4>) mul_mv_1row_t;
kernel void kernel_mul_mv_f16_f32(
device const char * src0, template [[host_name("kernel_mul_mv_f16_f32_1row")]] kernel mul_mv_1row_t kernel_mul_mv_1row<half, half4>;
device const char * src1,
device float * dst,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
constant uint64_t & nb00,
constant uint64_t & nb01,
constant uint64_t & nb02,
constant int64_t & ne10,
constant int64_t & ne11,
constant int64_t & ne12,
constant uint64_t & nb10,
constant uint64_t & nb11,
constant uint64_t & nb12,
constant int64_t & ne0,
constant int64_t & ne1,
constant uint & r2,
constant uint & r3,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]]) {
kernel_mul_mv_f16_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
}
// Assumes row size (ne00) is a multiple of 4 // Assumes row size (ne00) is a multiple of 4
kernel void kernel_mul_mv_f16_f32_l4( template<typename T, typename T4>
kernel void kernel_mul_mv_l4(
device const char * src0, device const char * src0,
device const char * src1, device const char * src1,
device float * dst, device float * dst,
@ -1628,14 +1452,14 @@ kernel void kernel_mul_mv_f16_f32_l4(
const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02; const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
device const half4 * x4 = (device const half4 *) (src0 + offset0); device const T4 * x4 = (device const T4 *) (src0 + offset0);
for (int r1 = 0; r1 < nrows; ++r1) { for (int r1 = 0; r1 < nrows; ++r1) {
device const float4 * y4 = (device const float4 *) (src1 + r1*nb11 + im*nb12); device const float4 * y4 = (device const float4 *) (src1 + r1*nb11 + im*nb12);
float sumf = 0; float sumf = 0;
for (int i = tiisg; i < ne00/4; i += 32) { for (int i = tiisg; i < ne00/4; i += 32) {
for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k]; for (int k = 0; k < 4; ++k) sumf += (float) (x4[i][k] * y4[i][k]);
} }
float all_sum = simd_sum(sumf); float all_sum = simd_sum(sumf);
@ -1645,6 +1469,10 @@ kernel void kernel_mul_mv_f16_f32_l4(
} }
} }
typedef decltype(kernel_mul_mv_l4<half, half4>) mul_mv_l4_t;
template [[host_name("kernel_mul_mv_f16_f32_l4")]] kernel mul_mv_l4_t kernel_mul_mv_l4<half, half4>;
static float rope_yarn_ramp(const float low, const float high, const int i0) { static float rope_yarn_ramp(const float low, const float high, const int i0) {
const float y = (i0 / 2 - low) / max(0.001f, high - low); const float y = (i0 / 2 - low) / max(0.001f, high - low);
return 1.0f - min(1.0f, max(0.0f, y)); return 1.0f - min(1.0f, max(0.0f, y));
@ -2765,9 +2593,10 @@ kernel void kernel_flash_attn_ext_vec_f16(
template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128>; template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128>;
//template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256>; //template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256>;
kernel void kernel_cpy_f16_f16( template<typename T0, typename T1>
device const half * src0, kernel void kernel_cpy(
device half * dst, device const void * src0,
device void * dst,
constant int64_t & ne00, constant int64_t & ne00,
constant int64_t & ne01, constant int64_t & ne01,
constant int64_t & ne02, constant int64_t & ne02,
@ -2798,138 +2627,20 @@ kernel void kernel_cpy_f16_f16(
const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
device half * dst_data = (device half *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); device T1 * dst_data = (device T1 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) { for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
device const half * src = (device half *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); device const T0 * src = (device T0 *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
dst_data[i00] = src[0]; dst_data[i00] = (T1) src[0];
} }
} }
kernel void kernel_cpy_f16_f32( typedef decltype(kernel_cpy<float, float>) kernel_cpy_t;
device const half * src0,
device float * dst,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
constant int64_t & ne03,
constant uint64_t & nb00,
constant uint64_t & nb01,
constant uint64_t & nb02,
constant uint64_t & nb03,
constant int64_t & ne0,
constant int64_t & ne1,
constant int64_t & ne2,
constant int64_t & ne3,
constant uint64_t & nb0,
constant uint64_t & nb1,
constant uint64_t & nb2,
constant uint64_t & nb3,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
const int64_t i03 = tgpig[2];
const int64_t i02 = tgpig[1];
const int64_t i01 = tgpig[0];
const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; template [[host_name("kernel_cpy_f32_f32")]] kernel kernel_cpy_t kernel_cpy<float, float>;
template [[host_name("kernel_cpy_f32_f16")]] kernel kernel_cpy_t kernel_cpy<float, half>;
const int64_t i3 = n / (ne2*ne1*ne0); template [[host_name("kernel_cpy_f16_f16")]] kernel kernel_cpy_t kernel_cpy<half, half>;
const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); template [[host_name("kernel_cpy_f16_f32")]] kernel kernel_cpy_t kernel_cpy<half, float>;
const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
device const half * src = (device half *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
dst_data[i00] = src[0];
}
}
kernel void kernel_cpy_f32_f16(
device const float * src0,
device half * dst,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
constant int64_t & ne03,
constant uint64_t & nb00,
constant uint64_t & nb01,
constant uint64_t & nb02,
constant uint64_t & nb03,
constant int64_t & ne0,
constant int64_t & ne1,
constant int64_t & ne2,
constant int64_t & ne3,
constant uint64_t & nb0,
constant uint64_t & nb1,
constant uint64_t & nb2,
constant uint64_t & nb3,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
const int64_t i03 = tgpig[2];
const int64_t i02 = tgpig[1];
const int64_t i01 = tgpig[0];
const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
const int64_t i3 = n / (ne2*ne1*ne0);
const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
device half * dst_data = (device half *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
dst_data[i00] = src[0];
}
}
kernel void kernel_cpy_f32_f32(
device const float * src0,
device float * dst,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
constant int64_t & ne03,
constant uint64_t & nb00,
constant uint64_t & nb01,
constant uint64_t & nb02,
constant uint64_t & nb03,
constant int64_t & ne0,
constant int64_t & ne1,
constant int64_t & ne2,
constant int64_t & ne3,
constant uint64_t & nb0,
constant uint64_t & nb1,
constant uint64_t & nb2,
constant uint64_t & nb3,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
const int64_t i03 = tgpig[2];
const int64_t i02 = tgpig[1];
const int64_t i01 = tgpig[0];
const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
const int64_t i3 = n / (ne2*ne1*ne0);
const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
dst_data[i00] = src[0];
}
}
kernel void kernel_cpy_f32_q8_0( kernel void kernel_cpy_f32_q8_0(
device const float * src0, device const float * src0,
@ -5730,9 +5441,9 @@ void dequantize_iq4_xs(device const block_iq4_xs * xb, short il, thread type4x4
} }
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)> template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
kernel void kernel_get_rows( kernel void kernel_get_rows_q(
device const void * src0, device const void * src0,
device const char * src1, device const void * src1,
device float * dst, device float * dst,
constant int64_t & ne00, constant int64_t & ne00,
constant uint64_t & nb01, constant uint64_t & nb01,
@ -5745,27 +5456,24 @@ kernel void kernel_get_rows(
uint3 tgpig[[threadgroup_position_in_grid]], uint3 tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]], uint tiitg[[thread_index_in_threadgroup]],
uint3 tptg [[threads_per_threadgroup]]) { uint3 tptg [[threads_per_threadgroup]]) {
//const int64_t i = tgpig;
//const int64_t r = ((device int32_t *) src1)[i];
const int64_t i10 = tgpig.x; const int64_t i10 = tgpig.x;
const int64_t i11 = tgpig.y; const int64_t i11 = tgpig.y;
const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0]; const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*nb11 + i10*nb10))[0];
const int64_t i02 = i11; const int64_t i02 = i11;
for (int64_t ind = tiitg; ind < ne00/16; ind += tptg.x) { for (int64_t ind = tiitg; ind < ne00/16; ind += tptg.x) {
float4x4 temp; float4x4 temp;
dequantize_func( dequantize_func(((device const block_q *) ((const device char *) src0 + r*nb01 + i02*nb02)) + ind/nl, ind%nl, temp);
((device const block_q *) ((device char *) src0 + r*nb01 + i02*nb02)) + ind/nl, ind%nl, temp);
*(((device float4x4 *) ((device char *) dst + i11*nb2 + i10*nb1)) + ind) = temp; *(((device float4x4 *) ((device char *) dst + i11*nb2 + i10*nb1)) + ind) = temp;
} }
} }
kernel void kernel_get_rows_f32( template<typename T>
kernel void kernel_get_rows_f(
device const void * src0, device const void * src0,
device const char * src1, device const void * src1,
device float * dst, device float * dst,
constant int64_t & ne00, constant int64_t & ne00,
constant uint64_t & nb01, constant uint64_t & nb01,
@ -5781,47 +5489,19 @@ kernel void kernel_get_rows_f32(
const int64_t i10 = tgpig.x; const int64_t i10 = tgpig.x;
const int64_t i11 = tgpig.y; const int64_t i11 = tgpig.y;
const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0]; const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*nb11 + i10*nb10))[0];
const int64_t i02 = i11; const int64_t i02 = i11;
for (int ind = tiitg; ind < ne00; ind += tptg.x) { for (int ind = tiitg; ind < ne00; ind += tptg.x) {
((device float *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] = (( device float *) (( device char *) dst + i11*nb2 + i10*nb1))[ind] =
((device float *) ((device char *) src0 + r*nb01 + i02*nb02))[ind]; ((const device T *) ((const device char *) src0 + i02*nb02 + r*nb01))[ind];
}
}
kernel void kernel_get_rows_f16(
device const void * src0,
device const char * src1,
device float * dst,
constant int64_t & ne00,
constant uint64_t & nb01,
constant uint64_t & nb02,
constant int64_t & ne10,
constant uint64_t & nb10,
constant uint64_t & nb11,
constant uint64_t & nb1,
constant uint64_t & nb2,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]],
uint3 tptg [[threads_per_threadgroup]]) {
const int64_t i10 = tgpig.x;
const int64_t i11 = tgpig.y;
const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
const int64_t i02 = i11;
for (int ind = tiitg; ind < ne00; ind += tptg.x) {
((device float *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] =
((device half *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
} }
} }
kernel void kernel_get_rows_i32( kernel void kernel_get_rows_i32(
device const void * src0, device const void * src0,
device const char * src1, device const void * src1,
device int32_t * dst, device int32_t * dst,
constant int64_t & ne00, constant int64_t & ne00,
constant uint64_t & nb01, constant uint64_t & nb01,
@ -5837,13 +5517,13 @@ kernel void kernel_get_rows_i32(
const int64_t i10 = tgpig.x; const int64_t i10 = tgpig.x;
const int64_t i11 = tgpig.y; const int64_t i11 = tgpig.y;
const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0]; const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*nb11 + i10*nb10))[0];
const int64_t i02 = i11; const int64_t i02 = i11;
for (int ind = tiitg; ind < ne00; ind += tptg.x) { for (int ind = tiitg; ind < ne00; ind += tptg.x) {
((device int32_t *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] = (( device int32_t *) (( device char *) dst + i11*nb2 + i10*nb1))[ind] =
((device int32_t *) ((device char *) src0 + r*nb01 + i02*nb02))[ind]; ((const device int32_t *) ((const device char *) src0 + i02*nb02 + r*nb01))[ind];
} }
} }
@ -5860,8 +5540,8 @@ kernel void kernel_get_rows_i32(
#define SG_MAT_ROW 8 #define SG_MAT_ROW 8
// each block_q contains 16*nl weights // each block_q contains 16*nl weights
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)> template<typename T, typename T4x4, typename simdgroup_T8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread T4x4 &)>
void kernel_mul_mm_impl(device const uchar * src0, kernel void kernel_mul_mm(device const uchar * src0,
device const uchar * src1, device const uchar * src1,
device float * dst, device float * dst,
constant int64_t & ne00, constant int64_t & ne00,
@ -5881,7 +5561,7 @@ void kernel_mul_mm_impl(device const uchar * src0,
uint tiitg[[thread_index_in_threadgroup]], uint tiitg[[thread_index_in_threadgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) { uint sgitg[[simdgroup_index_in_threadgroup]]) {
threadgroup half * sa = (threadgroup half *)(shared_memory); threadgroup T * sa = (threadgroup T *)(shared_memory);
threadgroup float * sb = (threadgroup float *)(shared_memory + 4096); threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);
const uint r0 = tgpig.y; const uint r0 = tgpig.y;
@ -5896,7 +5576,7 @@ void kernel_mul_mm_impl(device const uchar * src0,
short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1; short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1; short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
simdgroup_half8x8 ma[4]; simdgroup_T8x8 ma[4];
simdgroup_float8x8 mb[2]; simdgroup_float8x8 mb[2];
simdgroup_float8x8 c_res[8]; simdgroup_float8x8 c_res[8];
for (int i = 0; i < 8; i++){ for (int i = 0; i < 8; i++){
@ -5919,7 +5599,7 @@ void kernel_mul_mm_impl(device const uchar * src0,
for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) { for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
// load data and store to threadgroup memory // load data and store to threadgroup memory
half4x4 temp_a; T4x4 temp_a;
dequantize_func(x, il, temp_a); dequantize_func(x, il, temp_a);
threadgroup_barrier(mem_flags::mem_threadgroup); threadgroup_barrier(mem_flags::mem_threadgroup);
@ -5939,7 +5619,7 @@ void kernel_mul_mm_impl(device const uchar * src0,
threadgroup_barrier(mem_flags::mem_threadgroup); threadgroup_barrier(mem_flags::mem_threadgroup);
// load matrices from threadgroup memory and conduct outer products // load matrices from threadgroup memory and conduct outer products
threadgroup half * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2)); threadgroup T * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2));
threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2)); threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2));
#pragma unroll(4) #pragma unroll(4)
@ -6115,48 +5795,6 @@ void kernel_mul_mm_id_impl(
} }
} }
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
kernel void kernel_mul_mm(device const uchar * src0,
device const uchar * src1,
device float * dst,
constant int64_t & ne00,
constant int64_t & ne02,
constant uint64_t & nb01,
constant uint64_t & nb02,
constant int64_t & ne12,
constant uint64_t & nb10,
constant uint64_t & nb11,
constant uint64_t & nb12,
constant int64_t & ne0,
constant int64_t & ne1,
constant uint & r2,
constant uint & r3,
threadgroup uchar * shared_memory [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
kernel_mul_mm_impl<block_q, nl, dequantize_func>(
src0,
src1,
dst,
ne00,
ne02,
nb01,
nb02,
ne12,
nb10,
nb11,
nb12,
ne0,
ne1,
r2,
r3,
shared_memory,
tgpig,
tiitg,
sgitg);
}
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)> template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
kernel void kernel_mul_mm_id( kernel void kernel_mul_mm_id(
device const uchar * src0s, device const uchar * src0s,
@ -6237,69 +5875,60 @@ kernel void kernel_mul_mm_id(
// get rows // get rows
// //
typedef void (get_rows_t)( typedef decltype(kernel_get_rows_f<float>) get_rows_f_t;
device const void * src0,
device const char * src1,
device float * dst,
constant int64_t & ne00,
constant uint64_t & nb01,
constant uint64_t & nb02,
constant int64_t & ne10,
constant uint64_t & nb10,
constant uint64_t & nb11,
constant uint64_t & nb1,
constant uint64_t & nb2,
uint3, uint, uint3);
//template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows<float4x4, 1, dequantize_f32>; template [[host_name("kernel_get_rows_f32")]] kernel get_rows_f_t kernel_get_rows_f<float>;
//template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>; template [[host_name("kernel_get_rows_f16")]] kernel get_rows_f_t kernel_get_rows_f<half>;
template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>;
template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>; typedef decltype(kernel_get_rows_q<block_q4_0, 2, dequantize_q4_0>) get_rows_q_t;
template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_t kernel_get_rows<block_q5_0, 2, dequantize_q5_0>;
template [[host_name("kernel_get_rows_q5_1")]] kernel get_rows_t kernel_get_rows<block_q5_1, 2, dequantize_q5_1>; template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q4_0, 2, dequantize_q4_0>;
template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_t kernel_get_rows<block_q8_0, 2, dequantize_q8_0>; template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_q_t kernel_get_rows_q<block_q4_1, 2, dequantize_q4_1>;
template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_t kernel_get_rows<block_q2_K, QK_NL, dequantize_q2_K>; template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q5_0, 2, dequantize_q5_0>;
template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_t kernel_get_rows<block_q3_K, QK_NL, dequantize_q3_K>; template [[host_name("kernel_get_rows_q5_1")]] kernel get_rows_q_t kernel_get_rows_q<block_q5_1, 2, dequantize_q5_1>;
template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows<block_q4_K, QK_NL, dequantize_q4_K>; template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q8_0, 2, dequantize_q8_0>;
template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows<block_q5_K, QK_NL, dequantize_q5_K>; template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q2_K, QK_NL, dequantize_q2_K>;
template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows<block_q6_K, QK_NL, dequantize_q6_K>; template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q3_K, QK_NL, dequantize_q3_K>;
template [[host_name("kernel_get_rows_iq2_xxs")]] kernel get_rows_t kernel_get_rows<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>; template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q4_K, QK_NL, dequantize_q4_K>;
template [[host_name("kernel_get_rows_iq2_xs")]] kernel get_rows_t kernel_get_rows<block_iq2_xs, QK_NL, dequantize_iq2_xs>; template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q5_K, QK_NL, dequantize_q5_K>;
template [[host_name("kernel_get_rows_iq3_xxs")]] kernel get_rows_t kernel_get_rows<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>; template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q6_K, QK_NL, dequantize_q6_K>;
template [[host_name("kernel_get_rows_iq3_s")]] kernel get_rows_t kernel_get_rows<block_iq3_s, QK_NL, dequantize_iq3_s>; template [[host_name("kernel_get_rows_iq2_xxs")]] kernel get_rows_q_t kernel_get_rows_q<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
template [[host_name("kernel_get_rows_iq2_s")]] kernel get_rows_t kernel_get_rows<block_iq2_s, QK_NL, dequantize_iq2_s>; template [[host_name("kernel_get_rows_iq2_xs")]] kernel get_rows_q_t kernel_get_rows_q<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
template [[host_name("kernel_get_rows_iq1_s")]] kernel get_rows_t kernel_get_rows<block_iq1_s, QK_NL, dequantize_iq1_s>; template [[host_name("kernel_get_rows_iq3_xxs")]] kernel get_rows_q_t kernel_get_rows_q<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
template [[host_name("kernel_get_rows_iq1_m")]] kernel get_rows_t kernel_get_rows<block_iq1_m, QK_NL, dequantize_iq1_m>; template [[host_name("kernel_get_rows_iq3_s")]] kernel get_rows_q_t kernel_get_rows_q<block_iq3_s, QK_NL, dequantize_iq3_s>;
template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_t kernel_get_rows<block_iq4_nl, 2, dequantize_iq4_nl>; template [[host_name("kernel_get_rows_iq2_s")]] kernel get_rows_q_t kernel_get_rows_q<block_iq2_s, QK_NL, dequantize_iq2_s>;
template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_t kernel_get_rows<block_iq4_xs, QK_NL, dequantize_iq4_xs>; template [[host_name("kernel_get_rows_iq1_s")]] kernel get_rows_q_t kernel_get_rows_q<block_iq1_s, QK_NL, dequantize_iq1_s>;
template [[host_name("kernel_get_rows_iq1_m")]] kernel get_rows_q_t kernel_get_rows_q<block_iq1_m, QK_NL, dequantize_iq1_m>;
template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_q_t kernel_get_rows_q<block_iq4_nl, 2, dequantize_iq4_nl>;
template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_q_t kernel_get_rows_q<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
// //
// matrix-matrix multiplication // matrix-matrix multiplication
// //
typedef decltype(kernel_mul_mm<float4x4, 1, dequantize_f32>) mat_mm_t; typedef decltype(kernel_mul_mm<half, half4x4, simdgroup_half8x8, float4x4, 1, dequantize_f32>) mat_mm_t;
template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm<float4x4, 1, dequantize_f32>; template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, float4x4, 1, dequantize_f32>;
template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>; template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half4x4, 1, dequantize_f16>;
template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2, dequantize_q4_0>; template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0>;
template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2, dequantize_q4_1>; template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1>;
template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_0, 2, dequantize_q5_0>; template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0>;
template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_1, 2, dequantize_q5_1>; template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_1, 2, dequantize_q5_1>;
template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q8_0, 2, dequantize_q8_0>; template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0>;
template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q2_K, QK_NL, dequantize_q2_K>; template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K>;
template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q3_K, QK_NL, dequantize_q3_K>; template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K>;
template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_K, QK_NL, dequantize_q4_K>; template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K>;
template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_K, QK_NL, dequantize_q5_K>; template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_K, QK_NL, dequantize_q5_K>;
template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q6_K, QK_NL, dequantize_q6_K>; template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q6_K, QK_NL, dequantize_q6_K>;
template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>; template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xs, QK_NL, dequantize_iq2_xs>; template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_xs, QK_NL, dequantize_iq2_xs>;
template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>; template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
template [[host_name("kernel_mul_mm_iq3_s_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq3_s, QK_NL, dequantize_iq3_s>; template [[host_name("kernel_mul_mm_iq3_s_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq3_s, QK_NL, dequantize_iq3_s>;
template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_s, QK_NL, dequantize_iq2_s>; template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_s, QK_NL, dequantize_iq2_s>;
template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq1_s, QK_NL, dequantize_iq1_s>; template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq1_s, QK_NL, dequantize_iq1_s>;
template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq1_m, QK_NL, dequantize_iq1_m>; template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq1_m, QK_NL, dequantize_iq1_m>;
template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_nl, 2, dequantize_iq4_nl>; template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq4_nl, 2, dequantize_iq4_nl>;
template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_xs, QK_NL, dequantize_iq4_xs>; template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq4_xs, QK_NL, dequantize_iq4_xs>;
// //
// indirect matrix-matrix multiplication // indirect matrix-matrix multiplication
@ -6436,7 +6065,7 @@ void mmv_fn(
impl_fn(src0,(const device float *)src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,shared_values,tgpig,tiisg,sgitg); impl_fn(src0,(const device float *)src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,shared_values,tgpig,tiisg,sgitg);
} }
typedef decltype(mmv_fn<kernel_mul_mv_f32_f32_impl>) mul_mv_impl_fn_t; typedef decltype(mmv_fn<kernel_mul_mv_impl<half, half4, half, half4>>) mul_mv_impl_fn_t;
template<mul_mv_impl_fn_t impl_fn> template<mul_mv_impl_fn_t impl_fn>
kernel void kernel_mul_mv_id( kernel void kernel_mul_mv_id(
@ -6514,10 +6143,10 @@ kernel void kernel_mul_mv_id(
sgitg); sgitg);
} }
typedef decltype(kernel_mul_mv_id<mmv_fn<kernel_mul_mv_f32_f32_impl>>) kernel_mul_mv_id_t; typedef decltype(kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<float, float4, float, float4>>>) kernel_mul_mv_id_t;
template [[host_name("kernel_mul_mv_id_f32_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_f32_f32_impl>>; template [[host_name("kernel_mul_mv_id_f32_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<float, float4, float, float4>>>;
template [[host_name("kernel_mul_mv_id_f16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_f16_f32_impl>>; template [[host_name("kernel_mul_mv_id_f16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<half, half4, float, float4>>>;
template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q8_0_f32_impl>>; template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q8_0_f32_impl>>;
template [[host_name("kernel_mul_mv_id_q4_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>; template [[host_name("kernel_mul_mv_id_q4_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
template [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>; template [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;

View File

@ -658,7 +658,7 @@ static inline __m128i packNibbles( __m256i bytes ) {
#endif //__loongarch_asx #endif //__loongarch_asx
// reference implementation for deterministic creation of model files // reference implementation for deterministic creation of model files
void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int64_t k) { void quantize_row_q4_0_ref(const float * restrict x, block_q4_0 * restrict y, int64_t k) {
static const int qk = QK4_0; static const int qk = QK4_0;
assert(k % qk == 0); assert(k % qk == 0);
@ -696,11 +696,11 @@ void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict
} }
void quantize_row_q4_0(const float * restrict x, void * restrict y, int64_t k) { void quantize_row_q4_0(const float * restrict x, void * restrict y, int64_t k) {
quantize_row_q4_0_reference(x, y, k); quantize_row_q4_0_ref(x, y, k);
} }
void quantize_row_q4_1_reference(const float * restrict x, block_q4_1 * restrict y, int64_t k) { void quantize_row_q4_1_ref(const float * restrict x, block_q4_1 * restrict y, int64_t k) {
const int qk = QK4_1; const int qk = QK4_1;
assert(k % qk == 0); assert(k % qk == 0);
@ -738,10 +738,10 @@ void quantize_row_q4_1_reference(const float * restrict x, block_q4_1 * restrict
} }
void quantize_row_q4_1(const float * restrict x, void * restrict y, int64_t k) { void quantize_row_q4_1(const float * restrict x, void * restrict y, int64_t k) {
quantize_row_q4_1_reference(x, y, k); quantize_row_q4_1_ref(x, y, k);
} }
void quantize_row_q5_0_reference(const float * restrict x, block_q5_0 * restrict y, int64_t k) { void quantize_row_q5_0_ref(const float * restrict x, block_q5_0 * restrict y, int64_t k) {
static const int qk = QK5_0; static const int qk = QK5_0;
assert(k % qk == 0); assert(k % qk == 0);
@ -786,10 +786,10 @@ void quantize_row_q5_0_reference(const float * restrict x, block_q5_0 * restrict
} }
void quantize_row_q5_0(const float * restrict x, void * restrict y, int64_t k) { void quantize_row_q5_0(const float * restrict x, void * restrict y, int64_t k) {
quantize_row_q5_0_reference(x, y, k); quantize_row_q5_0_ref(x, y, k);
} }
void quantize_row_q5_1_reference(const float * restrict x, block_q5_1 * restrict y, int64_t k) { void quantize_row_q5_1_ref(const float * restrict x, block_q5_1 * restrict y, int64_t k) {
const int qk = QK5_1; const int qk = QK5_1;
assert(k % qk == 0); assert(k % qk == 0);
@ -834,11 +834,11 @@ void quantize_row_q5_1_reference(const float * restrict x, block_q5_1 * restrict
} }
void quantize_row_q5_1(const float * restrict x, void * restrict y, int64_t k) { void quantize_row_q5_1(const float * restrict x, void * restrict y, int64_t k) {
quantize_row_q5_1_reference(x, y, k); quantize_row_q5_1_ref(x, y, k);
} }
// reference implementation for deterministic creation of model files // reference implementation for deterministic creation of model files
void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * restrict y, int64_t k) { void quantize_row_q8_0_ref(const float * restrict x, block_q8_0 * restrict y, int64_t k) {
assert(k % QK8_0 == 0); assert(k % QK8_0 == 0);
const int nb = k / QK8_0; const int nb = k / QK8_0;
@ -1144,12 +1144,12 @@ void quantize_row_q8_0(const float * restrict x, void * restrict vy, int64_t k)
#else #else
GGML_UNUSED(nb); GGML_UNUSED(nb);
// scalar // scalar
quantize_row_q8_0_reference(x, y, k); quantize_row_q8_0_ref(x, y, k);
#endif #endif
} }
// reference implementation for deterministic creation of model files // reference implementation for deterministic creation of model files
void quantize_row_q8_1_reference(const float * restrict x, block_q8_1 * restrict y, int64_t k) { void quantize_row_q8_1_ref(const float * restrict x, block_q8_1 * restrict y, int64_t k) {
assert(QK8_1 == 32); assert(QK8_1 == 32);
assert(k % QK8_1 == 0); assert(k % QK8_1 == 0);
const int nb = k / QK8_1; const int nb = k / QK8_1;
@ -1508,7 +1508,7 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k)
#else #else
GGML_UNUSED(nb); GGML_UNUSED(nb);
// scalar // scalar
quantize_row_q8_1_reference(x, y, k); quantize_row_q8_1_ref(x, y, k);
#endif #endif
} }
@ -1899,7 +1899,7 @@ static inline void get_scale_min_k4(int j, const uint8_t * restrict q, uint8_t *
//========================- 2-bit (de)-quantization //========================- 2-bit (de)-quantization
void quantize_row_q2_K_reference(const float * restrict x, block_q2_K * restrict y, int64_t k) { void quantize_row_q2_K_ref(const float * restrict x, block_q2_K * restrict y, int64_t k) {
assert(k % QK_K == 0); assert(k % QK_K == 0);
const int nb = k / QK_K; const int nb = k / QK_K;
@ -2002,7 +2002,7 @@ void dequantize_row_q2_K(const block_q2_K * restrict x, float * restrict y, int6
} }
void quantize_row_q2_K(const float * restrict x, void * restrict vy, int64_t k) { void quantize_row_q2_K(const float * restrict x, void * restrict vy, int64_t k) {
quantize_row_q2_K_reference(x, vy, k); quantize_row_q2_K_ref(x, vy, k);
} }
static float make_qkx3_quants(int n, int nmax, const float * restrict x, const float * restrict weights, static float make_qkx3_quants(int n, int nmax, const float * restrict x, const float * restrict weights,
@ -2226,7 +2226,7 @@ static void quantize_row_q2_K_impl(const float * restrict x, block_q2_K * restri
size_t quantize_q2_K(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { size_t quantize_q2_K(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
size_t row_size = ggml_row_size(GGML_TYPE_Q2_K, n_per_row); size_t row_size = ggml_row_size(GGML_TYPE_Q2_K, n_per_row);
if (!quant_weights) { if (!quant_weights) {
quantize_row_q2_K_reference(src, dst, (int64_t)nrow*n_per_row); quantize_row_q2_K_ref(src, dst, (int64_t)nrow*n_per_row);
} }
else { else {
char * qrow = (char *)dst; char * qrow = (char *)dst;
@ -2241,7 +2241,7 @@ size_t quantize_q2_K(const float * restrict src, void * restrict dst, int64_t nr
//========================= 3-bit (de)-quantization //========================= 3-bit (de)-quantization
void quantize_row_q3_K_reference(const float * restrict x, block_q3_K * restrict y, int64_t k) { void quantize_row_q3_K_ref(const float * restrict x, block_q3_K * restrict y, int64_t k) {
assert(k % QK_K == 0); assert(k % QK_K == 0);
const int nb = k / QK_K; const int nb = k / QK_K;
@ -2368,7 +2368,7 @@ void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int6
} }
void quantize_row_q3_K(const float * restrict x, void * restrict vy, int64_t k) { void quantize_row_q3_K(const float * restrict x, void * restrict vy, int64_t k) {
quantize_row_q3_K_reference(x, vy, k); quantize_row_q3_K_ref(x, vy, k);
} }
static void quantize_row_q3_K_impl(const float * restrict x, block_q3_K * restrict y, int64_t n_per_row, const float * restrict quant_weights) { static void quantize_row_q3_K_impl(const float * restrict x, block_q3_K * restrict y, int64_t n_per_row, const float * restrict quant_weights) {
@ -2458,7 +2458,7 @@ static void quantize_row_q3_K_impl(const float * restrict x, block_q3_K * restri
size_t quantize_q3_K(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { size_t quantize_q3_K(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
size_t row_size = ggml_row_size(GGML_TYPE_Q3_K, n_per_row); size_t row_size = ggml_row_size(GGML_TYPE_Q3_K, n_per_row);
if (!quant_weights) { if (!quant_weights) {
quantize_row_q3_K_reference(src, dst, (int64_t)nrow*n_per_row); quantize_row_q3_K_ref(src, dst, (int64_t)nrow*n_per_row);
} }
else { else {
char * qrow = (char *)dst; char * qrow = (char *)dst;
@ -2473,7 +2473,7 @@ size_t quantize_q3_K(const float * restrict src, void * restrict dst, int64_t nr
// ====================== 4-bit (de)-quantization // ====================== 4-bit (de)-quantization
void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict y, int64_t k) { void quantize_row_q4_K_ref(const float * restrict x, block_q4_K * restrict y, int64_t k) {
assert(k % QK_K == 0); assert(k % QK_K == 0);
const int nb = k / QK_K; const int nb = k / QK_K;
@ -2572,7 +2572,7 @@ void dequantize_row_q4_K(const block_q4_K * restrict x, float * restrict y, int6
void quantize_row_q4_K(const float * restrict x, void * restrict vy, int64_t k) { void quantize_row_q4_K(const float * restrict x, void * restrict vy, int64_t k) {
assert(k % QK_K == 0); assert(k % QK_K == 0);
block_q4_K * restrict y = vy; block_q4_K * restrict y = vy;
quantize_row_q4_K_reference(x, y, k); quantize_row_q4_K_ref(x, y, k);
} }
static void quantize_row_q4_K_impl(const float * restrict x, block_q4_K * restrict y, int64_t n_per_row, const float * quant_weights) { static void quantize_row_q4_K_impl(const float * restrict x, block_q4_K * restrict y, int64_t n_per_row, const float * quant_weights) {
@ -2651,7 +2651,7 @@ static void quantize_row_q4_K_impl(const float * restrict x, block_q4_K * restri
size_t quantize_q4_K(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { size_t quantize_q4_K(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
size_t row_size = ggml_row_size(GGML_TYPE_Q4_K, n_per_row); size_t row_size = ggml_row_size(GGML_TYPE_Q4_K, n_per_row);
if (!quant_weights) { if (!quant_weights) {
quantize_row_q4_K_reference(src, dst, (int64_t)nrow*n_per_row); quantize_row_q4_K_ref(src, dst, (int64_t)nrow*n_per_row);
} }
else { else {
char * qrow = (char *)dst; char * qrow = (char *)dst;
@ -2666,7 +2666,7 @@ size_t quantize_q4_K(const float * restrict src, void * restrict dst, int64_t nr
// ====================== 5-bit (de)-quantization // ====================== 5-bit (de)-quantization
void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict y, int64_t k) { void quantize_row_q5_K_ref(const float * restrict x, block_q5_K * restrict y, int64_t k) {
assert(k % QK_K == 0); assert(k % QK_K == 0);
const int64_t nb = k / QK_K; const int64_t nb = k / QK_K;
@ -2783,7 +2783,7 @@ void dequantize_row_q5_K(const block_q5_K * restrict x, float * restrict y, int6
void quantize_row_q5_K(const float * restrict x, void * restrict vy, int64_t k) { void quantize_row_q5_K(const float * restrict x, void * restrict vy, int64_t k) {
assert(k % QK_K == 0); assert(k % QK_K == 0);
block_q5_K * restrict y = vy; block_q5_K * restrict y = vy;
quantize_row_q5_K_reference(x, y, k); quantize_row_q5_K_ref(x, y, k);
} }
static void quantize_row_q5_K_impl(const float * restrict x, block_q5_K * restrict y, int64_t n_per_row, const float * quant_weights) { static void quantize_row_q5_K_impl(const float * restrict x, block_q5_K * restrict y, int64_t n_per_row, const float * quant_weights) {
@ -2882,7 +2882,7 @@ static void quantize_row_q5_K_impl(const float * restrict x, block_q5_K * restri
size_t quantize_q5_K(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { size_t quantize_q5_K(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
size_t row_size = ggml_row_size(GGML_TYPE_Q5_K, n_per_row); size_t row_size = ggml_row_size(GGML_TYPE_Q5_K, n_per_row);
if (!quant_weights) { if (!quant_weights) {
quantize_row_q5_K_reference(src, dst, (int64_t)nrow*n_per_row); quantize_row_q5_K_ref(src, dst, (int64_t)nrow*n_per_row);
} }
else { else {
char * qrow = (char *)dst; char * qrow = (char *)dst;
@ -2897,7 +2897,7 @@ size_t quantize_q5_K(const float * restrict src, void * restrict dst, int64_t nr
// ====================== 6-bit (de)-quantization // ====================== 6-bit (de)-quantization
void quantize_row_q6_K_reference(const float * restrict x, block_q6_K * restrict y, int64_t k) { void quantize_row_q6_K_ref(const float * restrict x, block_q6_K * restrict y, int64_t k) {
assert(k % QK_K == 0); assert(k % QK_K == 0);
const int64_t nb = k / QK_K; const int64_t nb = k / QK_K;
@ -3001,7 +3001,7 @@ void dequantize_row_q6_K(const block_q6_K * restrict x, float * restrict y, int6
void quantize_row_q6_K(const float * restrict x, void * restrict vy, int64_t k) { void quantize_row_q6_K(const float * restrict x, void * restrict vy, int64_t k) {
assert(k % QK_K == 0); assert(k % QK_K == 0);
block_q6_K * restrict y = vy; block_q6_K * restrict y = vy;
quantize_row_q6_K_reference(x, y, k); quantize_row_q6_K_ref(x, y, k);
} }
static void quantize_row_q6_K_impl(const float * restrict x, block_q6_K * restrict y, int64_t n_per_row, const float * quant_weights) { static void quantize_row_q6_K_impl(const float * restrict x, block_q6_K * restrict y, int64_t n_per_row, const float * quant_weights) {
@ -3091,7 +3091,7 @@ static void quantize_row_q6_K_impl(const float * restrict x, block_q6_K * restri
size_t quantize_q6_K(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { size_t quantize_q6_K(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
size_t row_size = ggml_row_size(GGML_TYPE_Q6_K, n_per_row); size_t row_size = ggml_row_size(GGML_TYPE_Q6_K, n_per_row);
if (!quant_weights) { if (!quant_weights) {
quantize_row_q6_K_reference(src, dst, (int64_t)nrow*n_per_row); quantize_row_q6_K_ref(src, dst, (int64_t)nrow*n_per_row);
} }
else { else {
char * qrow = (char *)dst; char * qrow = (char *)dst;
@ -3108,7 +3108,7 @@ static void quantize_row_q4_0_impl(const float * restrict x, block_q4_0 * restri
static_assert(QK4_0 == 32, "QK4_0 must be 32"); static_assert(QK4_0 == 32, "QK4_0 must be 32");
if (!quant_weights) { if (!quant_weights) {
quantize_row_q4_0_reference(x, y, n_per_row); quantize_row_q4_0_ref(x, y, n_per_row);
return; return;
} }
@ -3134,7 +3134,7 @@ static void quantize_row_q4_0_impl(const float * restrict x, block_q4_0 * restri
size_t quantize_q4_0(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { size_t quantize_q4_0(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
if (!quant_weights) { if (!quant_weights) {
quantize_row_q4_0_reference(src, dst, (int64_t)nrow*n_per_row); quantize_row_q4_0_ref(src, dst, (int64_t)nrow*n_per_row);
return nrow * ggml_row_size(GGML_TYPE_Q4_0, n_per_row); return nrow * ggml_row_size(GGML_TYPE_Q4_0, n_per_row);
} }
size_t row_size = ggml_row_size(GGML_TYPE_Q4_0, n_per_row); size_t row_size = ggml_row_size(GGML_TYPE_Q4_0, n_per_row);
@ -3151,7 +3151,7 @@ static void quantize_row_q4_1_impl(const float * restrict x, block_q4_1 * restri
static_assert(QK4_1 == 32, "QK4_1 must be 32"); static_assert(QK4_1 == 32, "QK4_1 must be 32");
if (!quant_weights) { if (!quant_weights) {
quantize_row_q4_1_reference(x, y, n_per_row); quantize_row_q4_1_ref(x, y, n_per_row);
return; return;
} }
@ -3179,7 +3179,7 @@ static void quantize_row_q4_1_impl(const float * restrict x, block_q4_1 * restri
size_t quantize_q4_1(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { size_t quantize_q4_1(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
if (!quant_weights) { if (!quant_weights) {
quantize_row_q4_1_reference(src, dst, (int64_t)nrow*n_per_row); quantize_row_q4_1_ref(src, dst, (int64_t)nrow*n_per_row);
return nrow * ggml_row_size(GGML_TYPE_Q4_1, n_per_row); return nrow * ggml_row_size(GGML_TYPE_Q4_1, n_per_row);
} }
size_t row_size = ggml_row_size(GGML_TYPE_Q4_1, n_per_row); size_t row_size = ggml_row_size(GGML_TYPE_Q4_1, n_per_row);
@ -3196,7 +3196,7 @@ static void quantize_row_q5_0_impl(const float * restrict x, block_q5_0 * restri
static_assert(QK5_0 == 32, "QK5_0 must be 32"); static_assert(QK5_0 == 32, "QK5_0 must be 32");
if (!quant_weights) { if (!quant_weights) {
quantize_row_q5_0_reference(x, y, n_per_row); quantize_row_q5_0_ref(x, y, n_per_row);
return; return;
} }
@ -3233,7 +3233,7 @@ static void quantize_row_q5_0_impl(const float * restrict x, block_q5_0 * restri
size_t quantize_q5_0(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { size_t quantize_q5_0(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
if (!quant_weights) { if (!quant_weights) {
quantize_row_q5_0_reference(src, dst, (int64_t)nrow*n_per_row); quantize_row_q5_0_ref(src, dst, (int64_t)nrow*n_per_row);
return nrow * ggml_row_size(GGML_TYPE_Q5_0, n_per_row); return nrow * ggml_row_size(GGML_TYPE_Q5_0, n_per_row);
} }
size_t row_size = ggml_row_size(GGML_TYPE_Q5_0, n_per_row); size_t row_size = ggml_row_size(GGML_TYPE_Q5_0, n_per_row);
@ -3250,7 +3250,7 @@ static void quantize_row_q5_1_impl(const float * restrict x, block_q5_1 * restri
static_assert(QK5_1 == 32, "QK5_1 must be 32"); static_assert(QK5_1 == 32, "QK5_1 must be 32");
if (!quant_weights) { if (!quant_weights) {
quantize_row_q5_1_reference(x, y, n_per_row); quantize_row_q5_1_ref(x, y, n_per_row);
return; return;
} }
@ -3286,7 +3286,7 @@ static void quantize_row_q5_1_impl(const float * restrict x, block_q5_1 * restri
size_t quantize_q5_1(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { size_t quantize_q5_1(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
if (!quant_weights) { if (!quant_weights) {
quantize_row_q5_1_reference(src, dst, (int64_t)nrow*n_per_row); quantize_row_q5_1_ref(src, dst, (int64_t)nrow*n_per_row);
return nrow * ggml_row_size(GGML_TYPE_Q5_1, n_per_row); return nrow * ggml_row_size(GGML_TYPE_Q5_1, n_per_row);
} }
size_t row_size = ggml_row_size(GGML_TYPE_Q5_1, n_per_row); size_t row_size = ggml_row_size(GGML_TYPE_Q5_1, n_per_row);
@ -3302,7 +3302,7 @@ size_t quantize_q5_1(const float * restrict src, void * restrict dst, int64_t nr
size_t quantize_q8_0(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { size_t quantize_q8_0(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
(void)quant_weights; // not used (void)quant_weights; // not used
const size_t row_size = ggml_row_size(GGML_TYPE_Q8_0, n_per_row); const size_t row_size = ggml_row_size(GGML_TYPE_Q8_0, n_per_row);
quantize_row_q8_0_reference(src, dst, (int64_t)nrow*n_per_row); quantize_row_q8_0_ref(src, dst, (int64_t)nrow*n_per_row);
return nrow * row_size; return nrow * row_size;
} }
@ -3590,7 +3590,7 @@ void dequantize_row_iq4_xs(const block_iq4_xs * restrict x, float * restrict y,
//===================================== Q8_K ============================================== //===================================== Q8_K ==============================================
void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict y, int64_t k) { void quantize_row_q8_K_ref(const float * restrict x, block_q8_K * restrict y, int64_t k) {
assert(k % QK_K == 0); assert(k % QK_K == 0);
const int64_t nb = k / QK_K; const int64_t nb = k / QK_K;
@ -3641,7 +3641,7 @@ void dequantize_row_q8_K(const block_q8_K * restrict x, float * restrict y, int6
} }
void quantize_row_q8_K(const float * restrict x, void * restrict y, int64_t k) { void quantize_row_q8_K(const float * restrict x, void * restrict y, int64_t k) {
quantize_row_q8_K_reference(x, y, k); quantize_row_q8_K_ref(x, y, k);
} }
//===================================== Dot ptoducts ================================= //===================================== Dot ptoducts =================================
@ -13530,10 +13530,10 @@ size_t quantize_iq3_xxs(const float * restrict src, void * restrict dst, int64_t
void quantize_row_iq3_xxs(const float * restrict x, void * restrict vy, int64_t k) { void quantize_row_iq3_xxs(const float * restrict x, void * restrict vy, int64_t k) {
assert(k % QK_K == 0); assert(k % QK_K == 0);
block_iq3_xxs * restrict y = vy; block_iq3_xxs * restrict y = vy;
quantize_row_iq3_xxs_reference(x, y, k); quantize_row_iq3_xxs_ref(x, y, k);
} }
void quantize_row_iq3_xxs_reference(const float * restrict x, block_iq3_xxs * restrict y, int64_t k) { void quantize_row_iq3_xxs_ref(const float * restrict x, block_iq3_xxs * restrict y, int64_t k) {
assert(k % QK_K == 0); assert(k % QK_K == 0);
quantize_row_iq3_xxs_impl(256, x, y, k, NULL); quantize_row_iq3_xxs_impl(256, x, y, k, NULL);
} }
@ -13746,10 +13746,10 @@ size_t quantize_iq3_s(const float * restrict src, void * restrict dst, int64_t n
void quantize_row_iq3_s(const float * restrict x, void * restrict vy, int64_t k) { void quantize_row_iq3_s(const float * restrict x, void * restrict vy, int64_t k) {
assert(k % QK_K == 0); assert(k % QK_K == 0);
block_iq3_s * restrict y = vy; block_iq3_s * restrict y = vy;
quantize_row_iq3_s_reference(x, y, k); quantize_row_iq3_s_ref(x, y, k);
} }
void quantize_row_iq3_s_reference(const float * restrict x, block_iq3_s * restrict y, int64_t k) { void quantize_row_iq3_s_ref(const float * restrict x, block_iq3_s * restrict y, int64_t k) {
assert(k % QK_K == 0); assert(k % QK_K == 0);
quantize_iq3_s(x, y, 1, k, NULL); quantize_iq3_s(x, y, 1, k, NULL);
} }
@ -14487,7 +14487,7 @@ void quantize_row_iq4_nl(const float * restrict x, void * restrict vy, int64_t k
} }
} }
void quantize_row_iq4_nl_reference(const float * restrict x, block_iq4_nl * restrict y, int64_t k) { void quantize_row_iq4_nl_ref(const float * restrict x, block_iq4_nl * restrict y, int64_t k) {
assert(k % QK4_NL == 0); assert(k % QK4_NL == 0);
quantize_row_iq4_nl(x, y, k); quantize_row_iq4_nl(x, y, k);
} }
@ -14515,10 +14515,10 @@ size_t quantize_iq4_xs(const float * restrict src, void * restrict dst, int64_t
void quantize_row_iq4_xs(const float * restrict x, void * restrict vy, int64_t k) { void quantize_row_iq4_xs(const float * restrict x, void * restrict vy, int64_t k) {
assert(k % QK_K == 0); assert(k % QK_K == 0);
block_iq4_xs * restrict y = vy; block_iq4_xs * restrict y = vy;
quantize_row_iq4_xs_reference(x, y, k); quantize_row_iq4_xs_ref(x, y, k);
} }
void quantize_row_iq4_xs_reference(const float * restrict x, block_iq4_xs * restrict y, int64_t k) { void quantize_row_iq4_xs_ref(const float * restrict x, block_iq4_xs * restrict y, int64_t k) {
assert(k % QK_K == 0); assert(k % QK_K == 0);
quantize_iq4_xs(x, y, 1, k, NULL); quantize_iq4_xs(x, y, 1, k, NULL);
} }
@ -14705,7 +14705,7 @@ size_t quantize_iq2_s(const float * restrict src, void * restrict dst, int64_t n
return nrow * nblock * sizeof(block_iq2_s); return nrow * nblock * sizeof(block_iq2_s);
} }
void quantize_row_iq2_s_reference(const float * restrict x, block_iq2_s * restrict y, int64_t k) { void quantize_row_iq2_s_ref(const float * restrict x, block_iq2_s * restrict y, int64_t k) {
assert(k % QK_K == 0); assert(k % QK_K == 0);
quantize_iq2_s(x, y, 1, k, NULL); quantize_iq2_s(x, y, 1, k, NULL);
} }
@ -14713,7 +14713,7 @@ void quantize_row_iq2_s_reference(const float * restrict x, block_iq2_s * restri
void quantize_row_iq2_s(const float * restrict x, void * restrict vy, int64_t k) { void quantize_row_iq2_s(const float * restrict x, void * restrict vy, int64_t k) {
assert(k % QK_K == 0); assert(k % QK_K == 0);
block_iq2_s * restrict y = vy; block_iq2_s * restrict y = vy;
quantize_row_iq2_s_reference(x, y, k); quantize_row_iq2_s_ref(x, y, k);
} }
static bool validate_float(float f, size_t i) { static bool validate_float(float f, size_t i) {

View File

@ -12,25 +12,25 @@ extern "C" {
#endif #endif
// Quantization // Quantization
void quantize_row_q4_0_reference(const float * GGML_RESTRICT x, block_q4_0 * GGML_RESTRICT y, int64_t k); void quantize_row_q4_0_ref(const float * GGML_RESTRICT x, block_q4_0 * GGML_RESTRICT y, int64_t k);
void quantize_row_q4_1_reference(const float * GGML_RESTRICT x, block_q4_1 * GGML_RESTRICT y, int64_t k); void quantize_row_q4_1_ref(const float * GGML_RESTRICT x, block_q4_1 * GGML_RESTRICT y, int64_t k);
void quantize_row_q5_0_reference(const float * GGML_RESTRICT x, block_q5_0 * GGML_RESTRICT y, int64_t k); void quantize_row_q5_0_ref(const float * GGML_RESTRICT x, block_q5_0 * GGML_RESTRICT y, int64_t k);
void quantize_row_q5_1_reference(const float * GGML_RESTRICT x, block_q5_1 * GGML_RESTRICT y, int64_t k); void quantize_row_q5_1_ref(const float * GGML_RESTRICT x, block_q5_1 * GGML_RESTRICT y, int64_t k);
void quantize_row_q8_0_reference(const float * GGML_RESTRICT x, block_q8_0 * GGML_RESTRICT y, int64_t k); void quantize_row_q8_0_ref(const float * GGML_RESTRICT x, block_q8_0 * GGML_RESTRICT y, int64_t k);
void quantize_row_q8_1_reference(const float * GGML_RESTRICT x, block_q8_1 * GGML_RESTRICT y, int64_t k); void quantize_row_q8_1_ref(const float * GGML_RESTRICT x, block_q8_1 * GGML_RESTRICT y, int64_t k);
void quantize_row_q2_K_reference(const float * GGML_RESTRICT x, block_q2_K * GGML_RESTRICT y, int64_t k); void quantize_row_q2_K_ref(const float * GGML_RESTRICT x, block_q2_K * GGML_RESTRICT y, int64_t k);
void quantize_row_q3_K_reference(const float * GGML_RESTRICT x, block_q3_K * GGML_RESTRICT y, int64_t k); void quantize_row_q3_K_ref(const float * GGML_RESTRICT x, block_q3_K * GGML_RESTRICT y, int64_t k);
void quantize_row_q4_K_reference(const float * GGML_RESTRICT x, block_q4_K * GGML_RESTRICT y, int64_t k); void quantize_row_q4_K_ref(const float * GGML_RESTRICT x, block_q4_K * GGML_RESTRICT y, int64_t k);
void quantize_row_q5_K_reference(const float * GGML_RESTRICT x, block_q5_K * GGML_RESTRICT y, int64_t k); void quantize_row_q5_K_ref(const float * GGML_RESTRICT x, block_q5_K * GGML_RESTRICT y, int64_t k);
void quantize_row_q6_K_reference(const float * GGML_RESTRICT x, block_q6_K * GGML_RESTRICT y, int64_t k); void quantize_row_q6_K_ref(const float * GGML_RESTRICT x, block_q6_K * GGML_RESTRICT y, int64_t k);
void quantize_row_q8_K_reference(const float * GGML_RESTRICT x, block_q8_K * GGML_RESTRICT y, int64_t k); void quantize_row_q8_K_ref(const float * GGML_RESTRICT x, block_q8_K * GGML_RESTRICT y, int64_t k);
void quantize_row_iq3_xxs_reference(const float * GGML_RESTRICT x, block_iq3_xxs * GGML_RESTRICT y, int64_t k); void quantize_row_iq3_xxs_ref(const float * GGML_RESTRICT x, block_iq3_xxs * GGML_RESTRICT y, int64_t k);
void quantize_row_iq4_nl_reference (const float * GGML_RESTRICT x, block_iq4_nl * GGML_RESTRICT y, int64_t k); void quantize_row_iq4_nl_ref (const float * GGML_RESTRICT x, block_iq4_nl * GGML_RESTRICT y, int64_t k);
void quantize_row_iq4_xs_reference (const float * GGML_RESTRICT x, block_iq4_xs * GGML_RESTRICT y, int64_t k); void quantize_row_iq4_xs_ref (const float * GGML_RESTRICT x, block_iq4_xs * GGML_RESTRICT y, int64_t k);
void quantize_row_iq3_s_reference (const float * GGML_RESTRICT x, block_iq3_s * GGML_RESTRICT y, int64_t k); void quantize_row_iq3_s_ref (const float * GGML_RESTRICT x, block_iq3_s * GGML_RESTRICT y, int64_t k);
void quantize_row_iq2_s_reference (const float * GGML_RESTRICT x, block_iq2_s * GGML_RESTRICT y, int64_t k); void quantize_row_iq2_s_ref (const float * GGML_RESTRICT x, block_iq2_s * GGML_RESTRICT y, int64_t k);
void quantize_row_q4_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_row_q4_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
void quantize_row_q4_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_row_q4_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);

View File

@ -3768,37 +3768,13 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx, const ggml_ten
stream->memcpy(ids_host.data(), ids_dev, ggml_nbytes(ids)))); stream->memcpy(ids_host.data(), ids_dev, ggml_nbytes(ids))));
SYCL_CHECK(CHECK_TRY_ERROR(stream->wait())); SYCL_CHECK(CHECK_TRY_ERROR(stream->wait()));
const ggml_tensor_extra_gpu *src0_extra =
(const ggml_tensor_extra_gpu *)src0->extra;
const ggml_tensor_extra_gpu *src1_extra =
(const ggml_tensor_extra_gpu *)src1->extra;
const ggml_tensor_extra_gpu *dst_extra =
(const ggml_tensor_extra_gpu *)dst->extra;
ggml_tensor_extra_gpu src0_row_extra;
ggml_tensor_extra_gpu src1_row_extra;
ggml_tensor_extra_gpu dst_row_extra;
ggml_tensor src0_row = *src0; ggml_tensor src0_row = *src0;
ggml_tensor src1_row = *src1; ggml_tensor src1_row = *src1;
ggml_tensor dst_row = *dst; ggml_tensor dst_row = *dst;
src1_row.backend = GGML_BACKEND_TYPE_GPU; char *src0_original = (char *)src0->data;
dst_row.backend = GGML_BACKEND_TYPE_GPU; char *src1_original = (char *)src1->data;
char *dst_original = (char *)dst->data;
src0_row.extra = &src0_row_extra;
src1_row.extra = &src1_row_extra;
dst_row.extra = &dst_row_extra;
char *src0_original = src1->backend == GGML_BACKEND_TYPE_CPU
? (char *)src0->data
: (char *)src0_extra->data_device[ctx.device];
char *src1_original = src1->backend == GGML_BACKEND_TYPE_CPU
? (char *)src1->data
: (char *)src1_extra->data_device[ctx.device];
char *dst_original = dst->backend == GGML_BACKEND_TYPE_CPU
? (char *)dst->data
: (char *)dst_extra->data_device[ctx.device];
src0_row.ne[2] = 1; src0_row.ne[2] = 1;
src0_row.ne[3] = 1; src0_row.ne[3] = 1;
@ -3827,12 +3803,9 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx, const ggml_ten
const int64_t i1 = id; const int64_t i1 = id;
const int64_t i2 = i12; const int64_t i2 = i12;
src0_row_extra.data_device[ctx.device] = src0_row.data = src0_original + i02*nb02;
src0_original + i02*nb02; src1_row.data = src1_original + + i11*nb11 + i12*nb12;
src1_row_extra.data_device[ctx.device] = dst_row.data = dst_original + i1*nb1 + i2*nb2;
src1_original + + i11*nb11 + i12*nb12;
dst_row_extra.data_device[ctx.device] =
dst_original + i1*nb1 + i2*nb2;
ggml_sycl_mul_mat(ctx, &src0_row, &src1_row, &dst_row); ggml_sycl_mul_mat(ctx, &src0_row, &src1_row, &dst_row);
} }
@ -3841,8 +3814,8 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx, const ggml_ten
ggml_sycl_pool_alloc<char> src1_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(src1)); ggml_sycl_pool_alloc<char> src1_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(src1));
ggml_sycl_pool_alloc<char> dst_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst)); ggml_sycl_pool_alloc<char> dst_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst));
src1_row_extra.data_device[ctx.device] = src1_contiguous.get(); src1_row.data = src1_contiguous.get();
dst_row_extra.data_device[ctx.device] = dst_contiguous.get(); dst_row.data = dst_contiguous.get();
for (int64_t i02 = 0; i02 < n_as; i02++) { for (int64_t i02 = 0; i02 < n_as; i02++) {
int64_t num_src1_rows = 0; int64_t num_src1_rows = 0;
@ -3898,7 +3871,7 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx, const ggml_ten
}); });
} }
src0_row_extra.data_device[ctx.device] = src0_original + i02*nb02; src0_row.data = src0_original + i02*nb02;
GGML_ASSERT(nb11 == sizeof(float)*ne10); GGML_ASSERT(nb11 == sizeof(float)*ne10);
GGML_ASSERT(nb1 == sizeof(float)*ne0); GGML_ASSERT(nb1 == sizeof(float)*ne0);
@ -5221,6 +5194,10 @@ GGML_CALL static bool ggml_backend_sycl_supports_op(ggml_backend_t backend, cons
return false; return false;
} }
} }
ggml_type src0_type = op->src[0]->type;
if (src0_type == GGML_TYPE_BF16) {
return false;
}
return true; return true;
} break; } break;
case GGML_OP_GET_ROWS: case GGML_OP_GET_ROWS:

File diff suppressed because it is too large Load Diff

View File

@ -6561,7 +6561,7 @@ static void ggml_vk_print_tensor(ggml_backend_vk_context * ctx, const ggml_tenso
ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) tensor->extra; ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) tensor->extra;
vk_buffer buffer_gpu = extra->buffer_gpu.lock(); vk_buffer buffer_gpu = extra->buffer_gpu.lock();
ggml_vk_buffer_read(ctx, buffer_gpu, extra->offset + tensor->view_offs, tensor_data, tensor_size); ggml_vk_buffer_read(buffer_gpu, extra->offset + tensor->view_offs, tensor_data, tensor_size);
} }
std::cerr << "TENSOR CHECK " << name << " (" << tensor->name << "): " << ggml_op_name(tensor->op) << std::endl; std::cerr << "TENSOR CHECK " << name << " (" << tensor->name << "): " << ggml_op_name(tensor->op) << std::endl;
@ -6645,7 +6645,7 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_tensor *
for (int i3 = 0; i3 < src0->ne[3]; i3++) { for (int i3 = 0; i3 < src0->ne[3]; i3++) {
for (int i2 = 0; i2 < src0->ne[2]; i2++) { for (int i2 = 0; i2 < src0->ne[2]; i2++) {
const int idx = i3*src0->ne[2] + i2; const int idx = i3*src0->ne[2] + i2;
ggml_vk_buffer_read(ctx, buffer_gpu, offset + idx * src0->nb[2], ((char *)src0_clone->data + idx * src0_clone->nb[2]), src0->ne[1] * src0->nb[1]); ggml_vk_buffer_read(buffer_gpu, offset + idx * src0->nb[2], ((char *)src0_clone->data + idx * src0_clone->nb[2]), src0->ne[1] * src0->nb[1]);
} }
} }
@ -6658,7 +6658,7 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_tensor *
if (offset + src0_size >= buffer_gpu->size) { if (offset + src0_size >= buffer_gpu->size) {
src0_size = buffer_gpu->size - offset; src0_size = buffer_gpu->size - offset;
} }
ggml_vk_buffer_read(ctx, buffer_gpu, offset, src0_clone->data, src0_size); ggml_vk_buffer_read(buffer_gpu, offset, src0_clone->data, src0_size);
memcpy(src0_clone->nb, src0->nb, sizeof(size_t) * GGML_MAX_DIMS); memcpy(src0_clone->nb, src0->nb, sizeof(size_t) * GGML_MAX_DIMS);
} }
} else { } else {
@ -6687,7 +6687,7 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_tensor *
for (int i3 = 0; i3 < src1->ne[3]; i3++) { for (int i3 = 0; i3 < src1->ne[3]; i3++) {
for (int i2 = 0; i2 < src1->ne[2]; i2++) { for (int i2 = 0; i2 < src1->ne[2]; i2++) {
const int idx = i3*src1->ne[2] + i2; const int idx = i3*src1->ne[2] + i2;
ggml_vk_buffer_read(ctx, buffer_gpu, offset + idx * src1->nb[2], ((char *)src1_clone->data + idx * src1_clone->nb[2]), src1->ne[1] * src1->nb[1]); ggml_vk_buffer_read(buffer_gpu, offset + idx * src1->nb[2], ((char *)src1_clone->data + idx * src1_clone->nb[2]), src1->ne[1] * src1->nb[1]);
} }
} }
@ -6700,7 +6700,7 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_tensor *
if (offset + src1_size >= buffer_gpu->size) { if (offset + src1_size >= buffer_gpu->size) {
src1_size = buffer_gpu->size - offset; src1_size = buffer_gpu->size - offset;
} }
ggml_vk_buffer_read(ctx, buffer_gpu, offset, src1_clone->data, src1_size); ggml_vk_buffer_read(buffer_gpu, offset, src1_clone->data, src1_size);
memcpy(src1_clone->nb, src1->nb, sizeof(size_t) * GGML_MAX_DIMS); memcpy(src1_clone->nb, src1->nb, sizeof(size_t) * GGML_MAX_DIMS);
} }
} else { } else {
@ -6745,7 +6745,7 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_tensor *
for (int i3 = 0; i3 < src2->ne[3]; i3++) { for (int i3 = 0; i3 < src2->ne[3]; i3++) {
for (int i2 = 0; i2 < src2->ne[2]; i2++) { for (int i2 = 0; i2 < src2->ne[2]; i2++) {
const int idx = i3*src2->ne[2] + i2; const int idx = i3*src2->ne[2] + i2;
ggml_vk_buffer_read(ctx, buffer_gpu, offset + idx * src2->nb[2], ((char *)src2_clone->data + idx * src2_clone->nb[2]), src2->ne[1] * src2->nb[1]); ggml_vk_buffer_read(buffer_gpu, offset + idx * src2->nb[2], ((char *)src2_clone->data + idx * src2_clone->nb[2]), src2->ne[1] * src2->nb[1]);
} }
} }
@ -6758,7 +6758,7 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_tensor *
if (offset + src2_size >= buffer_gpu->size) { if (offset + src2_size >= buffer_gpu->size) {
src2_size = buffer_gpu->size - offset; src2_size = buffer_gpu->size - offset;
} }
ggml_vk_buffer_read(ctx, buffer_gpu, offset, src2_clone->data, src2_size); ggml_vk_buffer_read(buffer_gpu, offset, src2_clone->data, src2_size);
memcpy(src2_clone->nb, src2->nb, sizeof(size_t) * GGML_MAX_DIMS); memcpy(src2_clone->nb, src2->nb, sizeof(size_t) * GGML_MAX_DIMS);
} }
} else { } else {
@ -6922,7 +6922,7 @@ static void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_tensor *
tensor_size = buffer_gpu->size - (extra->offset + tensor->view_offs); tensor_size = buffer_gpu->size - (extra->offset + tensor->view_offs);
} }
ggml_vk_buffer_read(ctx, buffer_gpu, extra->offset + tensor->view_offs, tensor_data, tensor_size); ggml_vk_buffer_read(buffer_gpu, extra->offset + tensor->view_offs, tensor_data, tensor_size);
} }
float first_error_result = -1.0f; float first_error_result = -1.0f;

View File

@ -592,7 +592,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.is_quantized = false, .is_quantized = false,
.to_float = (ggml_to_float_t) ggml_fp16_to_fp32_row, .to_float = (ggml_to_float_t) ggml_fp16_to_fp32_row,
.from_float = (ggml_from_float_t) ggml_fp32_to_fp16_row, .from_float = (ggml_from_float_t) ggml_fp32_to_fp16_row,
.from_float_reference = (ggml_from_float_t) ggml_fp32_to_fp16_row, .from_float_ref = (ggml_from_float_t) ggml_fp32_to_fp16_row,
.vec_dot = (ggml_vec_dot_t) ggml_vec_dot_f16, .vec_dot = (ggml_vec_dot_t) ggml_vec_dot_f16,
.vec_dot_type = GGML_TYPE_F16, .vec_dot_type = GGML_TYPE_F16,
.nrows = 1, .nrows = 1,
@ -604,7 +604,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.is_quantized = true, .is_quantized = true,
.to_float = (ggml_to_float_t) dequantize_row_q4_0, .to_float = (ggml_to_float_t) dequantize_row_q4_0,
.from_float = quantize_row_q4_0, .from_float = quantize_row_q4_0,
.from_float_reference = (ggml_from_float_t) quantize_row_q4_0_reference, .from_float_ref = (ggml_from_float_t) quantize_row_q4_0_ref,
.vec_dot = ggml_vec_dot_q4_0_q8_0, .vec_dot = ggml_vec_dot_q4_0_q8_0,
.vec_dot_type = GGML_TYPE_Q8_0, .vec_dot_type = GGML_TYPE_Q8_0,
#if defined (__ARM_FEATURE_MATMUL_INT8) #if defined (__ARM_FEATURE_MATMUL_INT8)
@ -620,7 +620,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.is_quantized = true, .is_quantized = true,
.to_float = (ggml_to_float_t) dequantize_row_q4_1, .to_float = (ggml_to_float_t) dequantize_row_q4_1,
.from_float = quantize_row_q4_1, .from_float = quantize_row_q4_1,
.from_float_reference = (ggml_from_float_t) quantize_row_q4_1_reference, .from_float_ref = (ggml_from_float_t) quantize_row_q4_1_ref,
.vec_dot = ggml_vec_dot_q4_1_q8_1, .vec_dot = ggml_vec_dot_q4_1_q8_1,
.vec_dot_type = GGML_TYPE_Q8_1, .vec_dot_type = GGML_TYPE_Q8_1,
#if defined (__ARM_FEATURE_MATMUL_INT8) #if defined (__ARM_FEATURE_MATMUL_INT8)
@ -636,7 +636,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.is_quantized = false, .is_quantized = false,
.to_float = NULL, .to_float = NULL,
.from_float = NULL, .from_float = NULL,
.from_float_reference = NULL, .from_float_ref = NULL,
.vec_dot = NULL, .vec_dot = NULL,
.vec_dot_type = GGML_TYPE_COUNT, .vec_dot_type = GGML_TYPE_COUNT,
.nrows = 1, .nrows = 1,
@ -648,7 +648,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.is_quantized = false, .is_quantized = false,
.to_float = NULL, .to_float = NULL,
.from_float = NULL, .from_float = NULL,
.from_float_reference = NULL, .from_float_ref = NULL,
.vec_dot = NULL, .vec_dot = NULL,
.vec_dot_type = GGML_TYPE_COUNT, .vec_dot_type = GGML_TYPE_COUNT,
.nrows = 1, .nrows = 1,
@ -660,7 +660,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.is_quantized = true, .is_quantized = true,
.to_float = (ggml_to_float_t) dequantize_row_q5_0, .to_float = (ggml_to_float_t) dequantize_row_q5_0,
.from_float = quantize_row_q5_0, .from_float = quantize_row_q5_0,
.from_float_reference = (ggml_from_float_t) quantize_row_q5_0_reference, .from_float_ref = (ggml_from_float_t) quantize_row_q5_0_ref,
.vec_dot = ggml_vec_dot_q5_0_q8_0, .vec_dot = ggml_vec_dot_q5_0_q8_0,
.vec_dot_type = GGML_TYPE_Q8_0, .vec_dot_type = GGML_TYPE_Q8_0,
.nrows = 1, .nrows = 1,
@ -672,7 +672,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.is_quantized = true, .is_quantized = true,
.to_float = (ggml_to_float_t) dequantize_row_q5_1, .to_float = (ggml_to_float_t) dequantize_row_q5_1,
.from_float = quantize_row_q5_1, .from_float = quantize_row_q5_1,
.from_float_reference = (ggml_from_float_t) quantize_row_q5_1_reference, .from_float_ref = (ggml_from_float_t) quantize_row_q5_1_ref,
.vec_dot = ggml_vec_dot_q5_1_q8_1, .vec_dot = ggml_vec_dot_q5_1_q8_1,
.vec_dot_type = GGML_TYPE_Q8_1, .vec_dot_type = GGML_TYPE_Q8_1,
.nrows = 1, .nrows = 1,
@ -684,7 +684,8 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.is_quantized = true, .is_quantized = true,
.to_float = (ggml_to_float_t) dequantize_row_q8_0, .to_float = (ggml_to_float_t) dequantize_row_q8_0,
.from_float = quantize_row_q8_0, .from_float = quantize_row_q8_0,
.from_float_reference = (ggml_from_float_t) quantize_row_q8_0_reference, .from_float_ref = (ggml_from_float_t) quantize_row_q8_0_ref,
.from_float_to_mat = quantize_mat_q8_0,
.vec_dot = ggml_vec_dot_q8_0_q8_0, .vec_dot = ggml_vec_dot_q8_0_q8_0,
.vec_dot_type = GGML_TYPE_Q8_0, .vec_dot_type = GGML_TYPE_Q8_0,
#if defined (__ARM_FEATURE_MATMUL_INT8) #if defined (__ARM_FEATURE_MATMUL_INT8)
@ -692,7 +693,6 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
#else #else
.nrows = 1, .nrows = 1,
#endif #endif
.from_float_to_mat = quantize_mat_q8_0,
}, },
[GGML_TYPE_Q8_1] = { [GGML_TYPE_Q8_1] = {
.type_name = "q8_1", .type_name = "q8_1",
@ -700,7 +700,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.type_size = sizeof(block_q8_1), .type_size = sizeof(block_q8_1),
.is_quantized = true, .is_quantized = true,
.from_float = quantize_row_q8_1, .from_float = quantize_row_q8_1,
.from_float_reference = (ggml_from_float_t) quantize_row_q8_1_reference, .from_float_ref = (ggml_from_float_t) quantize_row_q8_1_ref,
.vec_dot_type = GGML_TYPE_Q8_1, .vec_dot_type = GGML_TYPE_Q8_1,
.nrows = 1, .nrows = 1,
}, },
@ -711,7 +711,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.is_quantized = true, .is_quantized = true,
.to_float = (ggml_to_float_t) dequantize_row_q2_K, .to_float = (ggml_to_float_t) dequantize_row_q2_K,
.from_float = quantize_row_q2_K, .from_float = quantize_row_q2_K,
.from_float_reference = (ggml_from_float_t) quantize_row_q2_K_reference, .from_float_ref = (ggml_from_float_t) quantize_row_q2_K_ref,
.vec_dot = ggml_vec_dot_q2_K_q8_K, .vec_dot = ggml_vec_dot_q2_K_q8_K,
.vec_dot_type = GGML_TYPE_Q8_K, .vec_dot_type = GGML_TYPE_Q8_K,
.nrows = 1, .nrows = 1,
@ -723,7 +723,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.is_quantized = true, .is_quantized = true,
.to_float = (ggml_to_float_t) dequantize_row_q3_K, .to_float = (ggml_to_float_t) dequantize_row_q3_K,
.from_float = quantize_row_q3_K, .from_float = quantize_row_q3_K,
.from_float_reference = (ggml_from_float_t) quantize_row_q3_K_reference, .from_float_ref = (ggml_from_float_t) quantize_row_q3_K_ref,
.vec_dot = ggml_vec_dot_q3_K_q8_K, .vec_dot = ggml_vec_dot_q3_K_q8_K,
.vec_dot_type = GGML_TYPE_Q8_K, .vec_dot_type = GGML_TYPE_Q8_K,
.nrows = 1, .nrows = 1,
@ -735,7 +735,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.is_quantized = true, .is_quantized = true,
.to_float = (ggml_to_float_t) dequantize_row_q4_K, .to_float = (ggml_to_float_t) dequantize_row_q4_K,
.from_float = quantize_row_q4_K, .from_float = quantize_row_q4_K,
.from_float_reference = (ggml_from_float_t) quantize_row_q4_K_reference, .from_float_ref = (ggml_from_float_t) quantize_row_q4_K_ref,
.vec_dot = ggml_vec_dot_q4_K_q8_K, .vec_dot = ggml_vec_dot_q4_K_q8_K,
.vec_dot_type = GGML_TYPE_Q8_K, .vec_dot_type = GGML_TYPE_Q8_K,
.nrows = 1, .nrows = 1,
@ -747,7 +747,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.is_quantized = true, .is_quantized = true,
.to_float = (ggml_to_float_t) dequantize_row_q5_K, .to_float = (ggml_to_float_t) dequantize_row_q5_K,
.from_float = quantize_row_q5_K, .from_float = quantize_row_q5_K,
.from_float_reference = (ggml_from_float_t) quantize_row_q5_K_reference, .from_float_ref = (ggml_from_float_t) quantize_row_q5_K_ref,
.vec_dot = ggml_vec_dot_q5_K_q8_K, .vec_dot = ggml_vec_dot_q5_K_q8_K,
.vec_dot_type = GGML_TYPE_Q8_K, .vec_dot_type = GGML_TYPE_Q8_K,
.nrows = 1, .nrows = 1,
@ -759,7 +759,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.is_quantized = true, .is_quantized = true,
.to_float = (ggml_to_float_t) dequantize_row_q6_K, .to_float = (ggml_to_float_t) dequantize_row_q6_K,
.from_float = quantize_row_q6_K, .from_float = quantize_row_q6_K,
.from_float_reference = (ggml_from_float_t) quantize_row_q6_K_reference, .from_float_ref = (ggml_from_float_t) quantize_row_q6_K_ref,
.vec_dot = ggml_vec_dot_q6_K_q8_K, .vec_dot = ggml_vec_dot_q6_K_q8_K,
.vec_dot_type = GGML_TYPE_Q8_K, .vec_dot_type = GGML_TYPE_Q8_K,
.nrows = 1, .nrows = 1,
@ -771,7 +771,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.is_quantized = true, .is_quantized = true,
.to_float = (ggml_to_float_t) dequantize_row_iq2_xxs, .to_float = (ggml_to_float_t) dequantize_row_iq2_xxs,
.from_float = NULL, .from_float = NULL,
.from_float_reference = NULL, .from_float_ref = NULL,
.vec_dot = ggml_vec_dot_iq2_xxs_q8_K, .vec_dot = ggml_vec_dot_iq2_xxs_q8_K,
.vec_dot_type = GGML_TYPE_Q8_K, .vec_dot_type = GGML_TYPE_Q8_K,
.nrows = 1, .nrows = 1,
@ -783,7 +783,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.is_quantized = true, .is_quantized = true,
.to_float = (ggml_to_float_t) dequantize_row_iq2_xs, .to_float = (ggml_to_float_t) dequantize_row_iq2_xs,
.from_float = NULL, .from_float = NULL,
.from_float_reference = NULL, .from_float_ref = NULL,
.vec_dot = ggml_vec_dot_iq2_xs_q8_K, .vec_dot = ggml_vec_dot_iq2_xs_q8_K,
.vec_dot_type = GGML_TYPE_Q8_K, .vec_dot_type = GGML_TYPE_Q8_K,
.nrows = 1, .nrows = 1,
@ -795,7 +795,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.is_quantized = true, .is_quantized = true,
.to_float = (ggml_to_float_t) dequantize_row_iq3_xxs, .to_float = (ggml_to_float_t) dequantize_row_iq3_xxs,
.from_float = quantize_row_iq3_xxs, .from_float = quantize_row_iq3_xxs,
.from_float_reference = (ggml_from_float_t)quantize_row_iq3_xxs_reference, .from_float_ref = (ggml_from_float_t)quantize_row_iq3_xxs_ref,
.vec_dot = ggml_vec_dot_iq3_xxs_q8_K, .vec_dot = ggml_vec_dot_iq3_xxs_q8_K,
.vec_dot_type = GGML_TYPE_Q8_K, .vec_dot_type = GGML_TYPE_Q8_K,
.nrows = 1, .nrows = 1,
@ -807,7 +807,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.is_quantized = true, .is_quantized = true,
.to_float = (ggml_to_float_t) dequantize_row_iq3_s, .to_float = (ggml_to_float_t) dequantize_row_iq3_s,
.from_float = quantize_row_iq3_s, .from_float = quantize_row_iq3_s,
.from_float_reference = (ggml_from_float_t)quantize_row_iq3_s_reference, .from_float_ref = (ggml_from_float_t)quantize_row_iq3_s_ref,
.vec_dot = ggml_vec_dot_iq3_s_q8_K, .vec_dot = ggml_vec_dot_iq3_s_q8_K,
.vec_dot_type = GGML_TYPE_Q8_K, .vec_dot_type = GGML_TYPE_Q8_K,
.nrows = 1, .nrows = 1,
@ -819,7 +819,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.is_quantized = true, .is_quantized = true,
.to_float = (ggml_to_float_t) dequantize_row_iq2_s, .to_float = (ggml_to_float_t) dequantize_row_iq2_s,
.from_float = quantize_row_iq2_s, .from_float = quantize_row_iq2_s,
.from_float_reference = (ggml_from_float_t)quantize_row_iq2_s_reference, .from_float_ref = (ggml_from_float_t)quantize_row_iq2_s_ref,
.vec_dot = ggml_vec_dot_iq2_s_q8_K, .vec_dot = ggml_vec_dot_iq2_s_q8_K,
.vec_dot_type = GGML_TYPE_Q8_K, .vec_dot_type = GGML_TYPE_Q8_K,
.nrows = 1, .nrows = 1,
@ -831,7 +831,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.is_quantized = true, .is_quantized = true,
.to_float = (ggml_to_float_t) dequantize_row_iq1_s, .to_float = (ggml_to_float_t) dequantize_row_iq1_s,
.from_float = NULL, .from_float = NULL,
.from_float_reference = NULL, .from_float_ref = NULL,
.vec_dot = ggml_vec_dot_iq1_s_q8_K, .vec_dot = ggml_vec_dot_iq1_s_q8_K,
.vec_dot_type = GGML_TYPE_Q8_K, .vec_dot_type = GGML_TYPE_Q8_K,
.nrows = 1, .nrows = 1,
@ -843,7 +843,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.is_quantized = true, .is_quantized = true,
.to_float = (ggml_to_float_t) dequantize_row_iq1_m, .to_float = (ggml_to_float_t) dequantize_row_iq1_m,
.from_float = NULL, .from_float = NULL,
.from_float_reference = NULL, .from_float_ref = NULL,
.vec_dot = ggml_vec_dot_iq1_m_q8_K, .vec_dot = ggml_vec_dot_iq1_m_q8_K,
.vec_dot_type = GGML_TYPE_Q8_K, .vec_dot_type = GGML_TYPE_Q8_K,
.nrows = 1, .nrows = 1,
@ -855,7 +855,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.is_quantized = true, .is_quantized = true,
.to_float = (ggml_to_float_t) dequantize_row_iq4_nl, .to_float = (ggml_to_float_t) dequantize_row_iq4_nl,
.from_float = quantize_row_iq4_nl, .from_float = quantize_row_iq4_nl,
.from_float_reference = (ggml_from_float_t)quantize_row_iq4_nl_reference, .from_float_ref = (ggml_from_float_t)quantize_row_iq4_nl_ref,
.vec_dot = ggml_vec_dot_iq4_nl_q8_0, .vec_dot = ggml_vec_dot_iq4_nl_q8_0,
.vec_dot_type = GGML_TYPE_Q8_0, .vec_dot_type = GGML_TYPE_Q8_0,
.nrows = 1, .nrows = 1,
@ -867,7 +867,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.is_quantized = true, .is_quantized = true,
.to_float = (ggml_to_float_t) dequantize_row_iq4_xs, .to_float = (ggml_to_float_t) dequantize_row_iq4_xs,
.from_float = quantize_row_iq4_xs, .from_float = quantize_row_iq4_xs,
.from_float_reference = (ggml_from_float_t)quantize_row_iq4_xs_reference, .from_float_ref = (ggml_from_float_t)quantize_row_iq4_xs_ref,
.vec_dot = ggml_vec_dot_iq4_xs_q8_K, .vec_dot = ggml_vec_dot_iq4_xs_q8_K,
.vec_dot_type = GGML_TYPE_Q8_K, .vec_dot_type = GGML_TYPE_Q8_K,
.nrows = 1, .nrows = 1,
@ -886,7 +886,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.is_quantized = false, .is_quantized = false,
.to_float = (ggml_to_float_t) ggml_bf16_to_fp32_row, .to_float = (ggml_to_float_t) ggml_bf16_to_fp32_row,
.from_float = (ggml_from_float_t) ggml_fp32_to_bf16_row, .from_float = (ggml_from_float_t) ggml_fp32_to_bf16_row,
.from_float_reference = (ggml_from_float_t) ggml_fp32_to_bf16_row, .from_float_ref = (ggml_from_float_t) ggml_fp32_to_bf16_row,
.vec_dot = (ggml_vec_dot_t) ggml_vec_dot_bf16, .vec_dot = (ggml_vec_dot_t) ggml_vec_dot_bf16,
.vec_dot_type = GGML_TYPE_BF16, .vec_dot_type = GGML_TYPE_BF16,
.nrows = 1, .nrows = 1,
@ -894,48 +894,48 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
[GGML_TYPE_Q4_0_4_4] = { [GGML_TYPE_Q4_0_4_4] = {
.type_name = "q4_0_4x4", .type_name = "q4_0_4x4",
.blck_size = QK4_0, .blck_size = QK4_0,
.blck_size_interleave = 4,
.type_size = sizeof(block_q4_0), .type_size = sizeof(block_q4_0),
.is_quantized = true, .is_quantized = true,
.to_float = NULL, .to_float = NULL,
.from_float = NULL, .from_float = NULL,
.from_float_reference = NULL, .from_float_ref = NULL,
.vec_dot = NULL, .vec_dot = NULL,
.vec_dot_type = GGML_TYPE_Q8_0, .vec_dot_type = GGML_TYPE_Q8_0,
.nrows = 1, .nrows = 1,
.ncols = 4, .ncols = 4,
.interleave_blcksize = 4,
.gemv = ggml_gemv_q4_0_4x4_q8_0, .gemv = ggml_gemv_q4_0_4x4_q8_0,
.gemm = ggml_gemm_q4_0_4x4_q8_0, .gemm = ggml_gemm_q4_0_4x4_q8_0,
}, },
[GGML_TYPE_Q4_0_4_8] = { [GGML_TYPE_Q4_0_4_8] = {
.type_name = "q4_0_4x8", .type_name = "q4_0_4x8",
.blck_size = QK4_0, .blck_size = QK4_0,
.blck_size_interleave = 8,
.type_size = sizeof(block_q4_0), .type_size = sizeof(block_q4_0),
.is_quantized = true, .is_quantized = true,
.to_float = NULL, .to_float = NULL,
.from_float = NULL, .from_float = NULL,
.from_float_reference = NULL, .from_float_ref = NULL,
.vec_dot = NULL, .vec_dot = NULL,
.vec_dot_type = GGML_TYPE_Q8_0, .vec_dot_type = GGML_TYPE_Q8_0,
.nrows = 1, .nrows = 1,
.ncols = 4, .ncols = 4,
.interleave_blcksize = 8,
.gemv = ggml_gemv_q4_0_4x8_q8_0, .gemv = ggml_gemv_q4_0_4x8_q8_0,
.gemm = ggml_gemm_q4_0_4x8_q8_0, .gemm = ggml_gemm_q4_0_4x8_q8_0,
}, },
[GGML_TYPE_Q4_0_8_8] = { [GGML_TYPE_Q4_0_8_8] = {
.type_name = "q4_0_8x8", .type_name = "q4_0_8x8",
.blck_size = QK4_0, .blck_size = QK4_0,
.blck_size_interleave = 8,
.type_size = sizeof(block_q4_0), .type_size = sizeof(block_q4_0),
.is_quantized = true, .is_quantized = true,
.to_float = NULL, .to_float = NULL,
.from_float = NULL, .from_float = NULL,
.from_float_reference = NULL, .from_float_ref = NULL,
.vec_dot = NULL, .vec_dot = NULL,
.vec_dot_type = GGML_TYPE_Q8_0, .vec_dot_type = GGML_TYPE_Q8_0,
.nrows = 1, .nrows = 1,
.ncols = 8, .ncols = 8,
.interleave_blcksize = 8,
.gemv = ggml_gemv_q4_0_8x8_q8_0, .gemv = ggml_gemv_q4_0_8x8_q8_0,
.gemm = ggml_gemm_q4_0_8x8_q8_0, .gemm = ggml_gemm_q4_0_8x8_q8_0,
} }
@ -3115,7 +3115,7 @@ size_t ggml_nbytes_pad(const struct ggml_tensor * tensor) {
return GGML_PAD(ggml_nbytes(tensor), GGML_MEM_ALIGN); return GGML_PAD(ggml_nbytes(tensor), GGML_MEM_ALIGN);
} }
GGML_CALL int ggml_blck_size(enum ggml_type type) { GGML_CALL int64_t ggml_blck_size(enum ggml_type type) {
return type_traits[type].blck_size; return type_traits[type].blck_size;
} }
@ -12193,12 +12193,11 @@ static void ggml_compute_forward_mul_mat(
const enum ggml_type type = src0->type; const enum ggml_type type = src0->type;
enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type; enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type;
ggml_from_float_t const from_float_to_vec_dot = type_traits[vec_dot_type].from_float; ggml_from_float_t const from_float = type_traits[vec_dot_type].from_float;
ggml_from_float_to_mat_t const from_float_to_mat = type_traits[vec_dot_type].from_float_to_mat;
int64_t const vec_dot_num_rows = type_traits[type].nrows; int64_t const vec_dot_num_rows = type_traits[type].nrows;
int64_t const matmul_num_cols = type_traits[type].ncols; int64_t const matmul_num_cols = type_traits[type].ncols;
int64_t const interleave_blcksize = type_traits[type].interleave_blcksize; int64_t const blck_size_interleave = type_traits[type].blck_size_interleave;
ggml_from_float_to_mat_t const from_float_to_mat
= type_traits[vec_dot_type].from_float_to_mat;
ggml_gemv_t const gemv = type_traits[type].gemv; ggml_gemv_t const gemv = type_traits[type].gemv;
ggml_gemm_t const gemm = type_traits[type].gemm; ggml_gemm_t const gemm = type_traits[type].gemm;
@ -12264,12 +12263,12 @@ UseGgmlGemm1:;
for (int64_t i11 = ith * 4; i11 < ne11 - ne11 % 4; i11 += nth * 4) { for (int64_t i11 = ith * 4; i11 < ne11 - ne11 % 4; i11 += nth * 4) {
from_float_to_mat((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), from_float_to_mat((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11),
(void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1), (void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1),
4, ne10, interleave_blcksize); 4, ne10, blck_size_interleave);
} }
i11_processed = ne11 - ne11 % 4; i11_processed = ne11 - ne11 % 4;
} }
for (int64_t i11 = i11_processed + ith; i11 < ne11; i11 += nth) { for (int64_t i11 = i11_processed + ith; i11 < ne11; i11 += nth) {
from_float_to_vec_dot((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), from_float((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11),
(void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1), (void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1),
ne10); ne10);
} }
@ -12415,7 +12414,7 @@ static void ggml_compute_forward_mul_mat_id(
ggml_vec_dot_t const vec_dot = type_traits[type].vec_dot; ggml_vec_dot_t const vec_dot = type_traits[type].vec_dot;
enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type; enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type;
ggml_from_float_t const from_float_to_vec_dot = type_traits[vec_dot_type].from_float; ggml_from_float_t const from_float = type_traits[vec_dot_type].from_float;
int64_t const matmul_num_cols = type_traits[type].ncols; int64_t const matmul_num_cols = type_traits[type].ncols;
ggml_gemv_t const gemv = type_traits[type].gemv; ggml_gemv_t const gemv = type_traits[type].gemv;
@ -12458,7 +12457,7 @@ static void ggml_compute_forward_mul_mat_id(
for (int64_t i13 = 0; i13 < ne13; ++i13) { for (int64_t i13 = 0; i13 < ne13; ++i13) {
for (int64_t i12 = 0; i12 < ne12; ++i12) { for (int64_t i12 = 0; i12 < ne12; ++i12) {
for (int64_t i11 = ith; i11 < ne11; i11 += nth) { for (int64_t i11 = ith; i11 < ne11; i11 += nth) {
from_float_to_vec_dot((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), from_float((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11),
(void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1), (void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1),
ne10); ne10);
} }
@ -21063,8 +21062,8 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
(int64_t) info->ne[3]; (int64_t) info->ne[3];
if (ne % ggml_blck_size(info->type) != 0) { if (ne % ggml_blck_size(info->type) != 0) {
fprintf(stderr, "%s: tensor '%s' of type %d (%s) number of elements (%" PRId64 ") is not a multiple of block size (%d)\n", fprintf(stderr, "%s: tensor '%s' of type %d (%s) number of elements (%" PRId64 ") is not a multiple of block size (%" PRId64 ")\n",
__func__, info->name.data, (int)info->type, ggml_type_name(info->type), ne, ggml_blck_size(info->type)); __func__, info->name.data, (int) info->type, ggml_type_name(info->type), ne, ggml_blck_size(info->type));
fclose(file); fclose(file);
gguf_free(ctx); gguf_free(ctx);
return NULL; return NULL;

View File

@ -0,0 +1,5 @@
set(TARGET vulkan-shaders-gen)
add_executable(${TARGET} vulkan-shaders-gen.cpp)
install(TARGETS ${TARGET} RUNTIME)
target_compile_features(${TARGET} PRIVATE cxx_std_11)

View File

@ -0,0 +1,524 @@
#include <iostream>
#include <fstream>
#include <sstream>
#include <string>
#include <stdexcept>
#include <array>
#include <vector>
#include <map>
#include <thread>
#include <mutex>
#include <future>
#include <queue>
#include <condition_variable>
#include <cstdio>
#include <cstring>
#include <cstdlib>
#include <sys/stat.h>
#include <sys/types.h>
#ifdef _WIN32
#include <windows.h>
#include <direct.h> // For _mkdir on Windows
#else
#include <unistd.h>
#include <sys/wait.h>
#include <fcntl.h>
#endif
#define ASYNCIO_CONCURRENCY 64
std::mutex lock;
std::vector<std::pair<std::string, std::string>> shader_fnames;
std::string GLSLC = "glslc";
std::string input_dir = "vulkan-shaders";
std::string output_dir = "/tmp";
std::string target_hpp = "ggml-vulkan-shaders.hpp";
std::string target_cpp = "ggml-vulkan-shaders.cpp";
bool no_clean = false;
const std::vector<std::string> type_names = {
"f32",
"f16",
"q4_0",
"q4_1",
"q5_0",
"q5_1",
"q8_0",
"q2_k",
"q3_k",
"q4_k",
"q5_k",
"q6_k"
};
void execute_command(const std::string& command, std::string& stdout_str, std::string& stderr_str) {
#ifdef _WIN32
HANDLE stdout_read, stdout_write;
HANDLE stderr_read, stderr_write;
SECURITY_ATTRIBUTES sa = { sizeof(SECURITY_ATTRIBUTES), NULL, TRUE };
if (!CreatePipe(&stdout_read, &stdout_write, &sa, 0) ||
!SetHandleInformation(stdout_read, HANDLE_FLAG_INHERIT, 0)) {
throw std::runtime_error("Failed to create stdout pipe");
}
if (!CreatePipe(&stderr_read, &stderr_write, &sa, 0) ||
!SetHandleInformation(stderr_read, HANDLE_FLAG_INHERIT, 0)) {
throw std::runtime_error("Failed to create stderr pipe");
}
PROCESS_INFORMATION pi;
STARTUPINFOA si = { sizeof(STARTUPINFOA) };
si.dwFlags = STARTF_USESTDHANDLES;
si.hStdOutput = stdout_write;
si.hStdError = stderr_write;
std::vector<char> cmd(command.begin(), command.end());
cmd.push_back('\0');
if (!CreateProcessA(NULL, cmd.data(), NULL, NULL, TRUE, 0, NULL, NULL, &si, &pi)) {
throw std::runtime_error("Failed to create process");
}
CloseHandle(stdout_write);
CloseHandle(stderr_write);
std::array<char, 128> buffer;
DWORD bytes_read;
while (ReadFile(stdout_read, buffer.data(), buffer.size(), &bytes_read, NULL) && bytes_read > 0) {
stdout_str.append(buffer.data(), bytes_read);
}
while (ReadFile(stderr_read, buffer.data(), buffer.size(), &bytes_read, NULL) && bytes_read > 0) {
stderr_str.append(buffer.data(), bytes_read);
}
CloseHandle(stdout_read);
CloseHandle(stderr_read);
WaitForSingleObject(pi.hProcess, INFINITE);
CloseHandle(pi.hProcess);
CloseHandle(pi.hThread);
#else
int stdout_pipe[2];
int stderr_pipe[2];
if (pipe(stdout_pipe) != 0 || pipe(stderr_pipe) != 0) {
throw std::runtime_error("Failed to create pipes");
}
pid_t pid = fork();
if (pid < 0) {
throw std::runtime_error("Failed to fork process");
}
if (pid == 0) {
close(stdout_pipe[0]);
close(stderr_pipe[0]);
dup2(stdout_pipe[1], STDOUT_FILENO);
dup2(stderr_pipe[1], STDERR_FILENO);
close(stdout_pipe[1]);
close(stderr_pipe[1]);
execl("/bin/sh", "sh", "-c", command.c_str(), (char*) nullptr);
_exit(EXIT_FAILURE);
} else {
close(stdout_pipe[1]);
close(stderr_pipe[1]);
std::array<char, 128> buffer;
ssize_t bytes_read;
while ((bytes_read = read(stdout_pipe[0], buffer.data(), buffer.size())) > 0) {
stdout_str.append(buffer.data(), bytes_read);
}
while ((bytes_read = read(stderr_pipe[0], buffer.data(), buffer.size())) > 0) {
stderr_str.append(buffer.data(), bytes_read);
}
close(stdout_pipe[0]);
close(stderr_pipe[0]);
waitpid(pid, nullptr, 0);
}
#endif
}
bool directory_exists(const std::string& path) {
struct stat info;
if (stat(path.c_str(), &info) != 0) {
return false; // Path doesn't exist or can't be accessed
}
return (info.st_mode & S_IFDIR) != 0; // Check if it is a directory
}
bool create_directory(const std::string& path) {
#ifdef _WIN32
return _mkdir(path.c_str()) == 0 || errno == EEXIST; // EEXIST means the directory already exists
#else
return mkdir(path.c_str(), 0755) == 0 || errno == EEXIST; // 0755 is the directory permissions
#endif
}
std::string to_uppercase(const std::string& input) {
std::string result = input;
for (char& c : result) {
c = std::toupper(c);
}
return result;
}
bool string_ends_with(const std::string& str, const std::string& suffix) {
if (suffix.size() > str.size()) {
return false;
}
return std::equal(suffix.rbegin(), suffix.rend(), str.rbegin());
}
#ifdef _WIN32
static const char path_separator = '\\';
#else
static const char path_separator = '/';
#endif
std::string join_paths(const std::string& path1, const std::string& path2) {
return path1 + path_separator + path2;
}
std::string basename(const std::string &path) {
return path.substr(path.find_last_of("/\\") + 1);
}
void string_to_spv(const std::string& _name, const std::string& in_fname, const std::map<std::string, std::string>& defines, bool fp16 = true) {
std::string name = _name + (fp16 ? "" : "_fp32");
std::string out_fname = join_paths(output_dir, name + ".spv");
std::string in_path = join_paths(input_dir, in_fname);
std::vector<std::string> cmd = {GLSLC, "-fshader-stage=compute", "--target-env=vulkan1.2", "-O", in_path, "-o", out_fname};
for (const auto& define : defines) {
cmd.push_back("-D" + define.first + "=" + define.second);
}
std::string command;
for (const auto& part : cmd) {
command += part + " ";
}
std::string stdout_str, stderr_str;
try {
// std::cout << "Executing command: ";
// for (const auto& part : cmd) {
// std::cout << part << " ";
// }
// std::cout << std::endl;
execute_command(command, stdout_str, stderr_str);
if (!stderr_str.empty()) {
std::cerr << "cannot compile " << name << "\n\n" << command << "\n\n" << stderr_str << std::endl;
return;
}
std::lock_guard<std::mutex> guard(lock);
shader_fnames.push_back(std::make_pair(name, out_fname));
} catch (const std::exception& e) {
std::cerr << "Error executing command for " << name << ": " << e.what() << std::endl;
}
}
std::map<std::string, std::string> merge_maps(const std::map<std::string, std::string>& a, const std::map<std::string, std::string>& b) {
std::map<std::string, std::string> result = a;
result.insert(b.begin(), b.end());
return result;
}
void matmul_shaders(std::vector<std::future<void>>& tasks, bool fp16, bool matmul_id) {
std::string load_vec = fp16 ? "8" : "4";
std::string aligned_b_type_f32 = fp16 ? "mat2x4" : "vec4";
std::string aligned_b_type_f16 = fp16 ? "f16mat2x4" : "f16vec4";
std::map<std::string, std::string> base_dict = {{"FLOAT_TYPE", fp16 ? "float16_t" : "float"}};
std::string shader_name = "matmul";
if (matmul_id) {
base_dict["MUL_MAT_ID"] = "1";
shader_name = "matmul_id";
}
if (fp16) {
base_dict["FLOAT16"] = "1";
}
// Shaders with f16 B_TYPE
tasks.push_back(std::async(std::launch::async, [=] {
string_to_spv(shader_name + "_f32_f16", "mul_mm.comp", merge_maps(base_dict, {{"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16);
}));
tasks.push_back(std::async(std::launch::async, [=] {
string_to_spv(shader_name + "_f32_f16_aligned", "mul_mm.comp", merge_maps(base_dict, {{"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}}), fp16);
}));
tasks.push_back(std::async(std::launch::async, [=] {
string_to_spv(shader_name + "_f16", "mul_mm.comp", merge_maps(base_dict, {{"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16);
}));
tasks.push_back(std::async(std::launch::async, [=] {
string_to_spv(shader_name + "_f16_aligned", "mul_mm.comp", merge_maps(base_dict, {{"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}}), fp16);
}));
for (const auto& tname : type_names) {
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
std::string load_vec_a = (tname == "f32" || tname == "f16") ? load_vec : "2";
tasks.push_back(std::async(std::launch::async, [=] {
string_to_spv(shader_name + "_" + tname + "_f32", "mul_mm.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}), fp16);
}));
tasks.push_back(std::async(std::launch::async, [=] {
string_to_spv(shader_name + "_" + tname + "_f32_aligned", "mul_mm.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}}), fp16);
}));
}
}
void process_shaders(std::vector<std::future<void>>& tasks) {
std::cout << "ggml_vulkan: Generating and compiling shaders to SPIR-V" << std::endl;
std::map<std::string, std::string> base_dict = {{"FLOAT_TYPE", "float"}};
for (const auto& fp16 : {false, true}) {
matmul_shaders(tasks, fp16, false);
matmul_shaders(tasks, fp16, true);
}
for (const auto& tname : type_names) {
// mul mat vec
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
std::string shader = (string_ends_with(tname, "_k")) ? "mul_mat_vec_" + tname + ".comp" : "mul_mat_vec.comp";
tasks.push_back(std::async(std::launch::async, [=] {
string_to_spv("mul_mat_vec_" + tname + "_f32_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
}));
tasks.push_back(std::async(std::launch::async, [=] {
string_to_spv("mul_mat_vec_" + tname + "_f16_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}));
}));
tasks.push_back(std::async(std::launch::async, [=] {
string_to_spv("mul_mat_vec_id_" + tname + "_f32", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
}));
// Dequant shaders
if (tname != "f16") {
tasks.push_back(std::async(std::launch::async, [=] {
string_to_spv("dequant_" + tname, "dequant_" + tname + ".comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float16_t"}}));
}));
}
if (!string_ends_with(tname, "_k")) {
shader = (tname == "f32" || tname == "f16") ? "get_rows.comp" : "get_rows_quant.comp";
if (tname == "f16") {
tasks.push_back(std::async(std::launch::async, [=] {
string_to_spv("get_rows_" + tname, shader, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}});
}));
} else {
tasks.push_back(std::async(std::launch::async, [=] {
string_to_spv("get_rows_" + tname, shader, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float16_t"}});
}));
}
tasks.push_back(std::async(std::launch::async, [=] {
string_to_spv("get_rows_" + tname + "_f32", shader, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float"}});
}));
}
}
tasks.push_back(std::async(std::launch::async, [] {
string_to_spv("mul_mat_vec_p021_f16_f32", "mul_mat_vec_p021.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
}));
tasks.push_back(std::async(std::launch::async, [] {
string_to_spv("mul_mat_vec_nc_f16_f32", "mul_mat_vec_nc.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
}));
// Norms
tasks.push_back(std::async(std::launch::async, [=] {
string_to_spv("norm_f32", "norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
}));
tasks.push_back(std::async(std::launch::async, [=] {
string_to_spv("rms_norm_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
}));
tasks.push_back(std::async(std::launch::async, [] {
string_to_spv("cpy_f32_f32", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
}));
tasks.push_back(std::async(std::launch::async, [] {
string_to_spv("cpy_f32_f16", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}});
}));
tasks.push_back(std::async(std::launch::async, [] {
string_to_spv("cpy_f16_f16", "copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}});
}));
tasks.push_back(std::async(std::launch::async, [] {
string_to_spv("add_f32", "add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
}));
tasks.push_back(std::async(std::launch::async, [] {
string_to_spv("split_k_reduce", "mul_mat_split_k_reduce.comp", {});
}));
tasks.push_back(std::async(std::launch::async, [] {
string_to_spv("mul_f32", "mul.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
}));
tasks.push_back(std::async(std::launch::async, [] {
string_to_spv("div_f32", "div.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
}));
tasks.push_back(std::async(std::launch::async, [] {
string_to_spv("scale_f32", "scale.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
}));
tasks.push_back(std::async(std::launch::async, [] {
string_to_spv("sqr_f32", "square.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
}));
tasks.push_back(std::async(std::launch::async, [] {
string_to_spv("clamp_f32", "clamp.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
}));
tasks.push_back(std::async(std::launch::async, [] {
string_to_spv("gelu_f32", "gelu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
}));
tasks.push_back(std::async(std::launch::async, [] {
string_to_spv("silu_f32", "silu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
}));
tasks.push_back(std::async(std::launch::async, [] {
string_to_spv("relu_f32", "relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
}));
tasks.push_back(std::async(std::launch::async, [] {
string_to_spv("diag_mask_inf_f32", "diag_mask_inf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
}));
tasks.push_back(std::async(std::launch::async, [=] {
string_to_spv("soft_max_f32", "soft_max.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
}));
tasks.push_back(std::async(std::launch::async, [=] {
string_to_spv("soft_max_f32_f16", "soft_max.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}));
}));
tasks.push_back(std::async(std::launch::async, [] {
string_to_spv("rope_norm_f32", "rope_norm.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
}));
tasks.push_back(std::async(std::launch::async, [] {
string_to_spv("rope_norm_f16", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
}));
tasks.push_back(std::async(std::launch::async, [] {
string_to_spv("rope_neox_f32", "rope_neox.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
}));
tasks.push_back(std::async(std::launch::async, [] {
string_to_spv("rope_neox_f16", "rope_neox.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
}));
tasks.push_back(std::async(std::launch::async, [] {
string_to_spv("argsort_f32", "argsort.comp", {{"A_TYPE", "float"}});
}));
tasks.push_back(std::async(std::launch::async, [=] {
string_to_spv("sum_rows_f32", "sum_rows.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
}));
}
void write_output_files() {
FILE* hdr = fopen(target_hpp.c_str(), "w");
FILE* src = fopen(target_cpp.c_str(), "w");
fprintf(hdr, "#include <cstdint>\n\n");
fprintf(src, "#include \"%s\"\n\n", basename(target_hpp).c_str());
for (const auto& pair : shader_fnames) {
const std::string& name = pair.first;
const std::string& path = pair.second;
FILE* spv = fopen(path.c_str(), "rb");
if (!spv) {
std::cerr << "Error opening SPIR-V file: " << path << "\n";
continue;
}
fseek(spv, 0, SEEK_END);
size_t size = ftell(spv);
fseek(spv, 0, SEEK_SET);
std::vector<unsigned char> data(size);
size_t read_size = fread(data.data(), 1, size, spv);
fclose(spv);
if (read_size != size) {
std::cerr << "Error reading SPIR-V file: " << path << "\n";
continue;
}
fprintf(hdr, "extern unsigned char %s_data[%zu];\n", name.c_str(), size);
fprintf(hdr, "const uint64_t %s_len = %zu;\n\n", name.c_str(), size);
fprintf(src, "unsigned char %s_data[%zu] = {\n", name.c_str(), size);
for (size_t i = 0; i < size; ++i) {
fprintf(src, "0x%02x,", data[i]);
if ((i + 1) % 12 == 0) fprintf(src, "\n");
}
fprintf(src, "\n};\n\n");
if (!no_clean) {
std::remove(path.c_str());
}
}
fclose(hdr);
fclose(src);
}
int main(int argc, char** argv) {
std::map<std::string, std::string> args;
for (int i = 1; i < argc; i += 2) {
if (i + 1 < argc) {
args[argv[i]] = argv[i + 1];
}
}
if (args.find("--glslc") != args.end()) {
GLSLC = args["--glslc"]; // Path to glslc
}
if (args.find("--input-dir") != args.end()) {
input_dir = args["--input-dir"]; // Directory containing shader sources
}
if (args.find("--output-dir") != args.end()) {
output_dir = args["--output-dir"]; // Directory for containing SPIR-V output
}
if (args.find("--target-hpp") != args.end()) {
target_hpp = args["--target-hpp"]; // Path to generated header file
}
if (args.find("--target-cpp") != args.end()) {
target_cpp = args["--target-cpp"]; // Path to generated cpp file
}
if (args.find("--no-clean") != args.end()) {
no_clean = true; // Keep temporary SPIR-V files in output-dir after build
}
if (!directory_exists(input_dir)) {
std::cerr << "\"" << input_dir << "\" must be a valid directory containing shader sources" << std::endl;
return EXIT_FAILURE;
}
if (!directory_exists(output_dir)) {
if (!create_directory(output_dir)) {
std::cerr << "Error creating output directory: " << output_dir << "\n";
return EXIT_FAILURE;
}
}
std::vector<std::future<void>> tasks;
process_shaders(tasks);
for (auto& task : tasks) {
task.get();
}
write_output_files();
return EXIT_SUCCESS;
}

View File

@ -27,8 +27,9 @@ UUID_NAMESPACE_LLAMA_CPP = uuid.UUID('ef001206-dadc-5f6d-a15f-3359e577d4e5')
# For more information about what field.parts and field.data represent, # For more information about what field.parts and field.data represent,
# please see the comments in the modify_gguf.py example. # please see the comments in the modify_gguf.py example.
def gguf_hash(reader: GGUFReader, filename: str, disable_progress_bar) -> None: def gguf_hash(reader: GGUFReader, filename: str, disable_progress_bar: bool, no_layer: bool) -> None:
sha1 = hashlib.sha1() sha1 = hashlib.sha1()
sha256 = hashlib.sha256()
uuidv5_sha1 = hashlib.sha1() uuidv5_sha1 = hashlib.sha1()
uuidv5_sha1.update(UUID_NAMESPACE_LLAMA_CPP.bytes) uuidv5_sha1.update(UUID_NAMESPACE_LLAMA_CPP.bytes)
@ -50,7 +51,7 @@ def gguf_hash(reader: GGUFReader, filename: str, disable_progress_bar) -> None:
bar = tqdm(desc="Hashing", total=total_weights, unit="weights", unit_scale=True, disable=disable_progress_bar) bar = tqdm(desc="Hashing", total=total_weights, unit="weights", unit_scale=True, disable=disable_progress_bar)
# Hashing Process # Hashing Process
for n, tensor in enumerate(reader.tensors, 1): for tensor in reader.tensors:
# We don't need these # We don't need these
if tensor.name.endswith((".attention.masked_bias", ".attention.bias", ".rotary_emb.inv_freq")): if tensor.name.endswith((".attention.masked_bias", ".attention.bias", ".rotary_emb.inv_freq")):
@ -62,29 +63,39 @@ def gguf_hash(reader: GGUFReader, filename: str, disable_progress_bar) -> None:
sum_weights_in_tensor *= dim sum_weights_in_tensor *= dim
bar.update(sum_weights_in_tensor) bar.update(sum_weights_in_tensor)
if not no_layer:
sha1_layer = hashlib.sha1() sha1_layer = hashlib.sha1()
sha1_layer.update(tensor.data.data) sha1_layer.update(tensor.data.data)
sha1.update(tensor.data.data)
uuidv5_sha1.update(tensor.data.data)
print("sha1 {0} {1}:{2}".format(sha1_layer.hexdigest(), filename, tensor.name)) # noqa: NP100 print("sha1 {0} {1}:{2}".format(sha1_layer.hexdigest(), filename, tensor.name)) # noqa: NP100
sha256_layer = hashlib.sha256()
sha256_layer.update(tensor.data.data)
print("sha256 {0} {1}:{2}".format(sha256_layer.hexdigest(), filename, tensor.name)) # noqa: NP100
sha1.update(tensor.data.data)
sha256.update(tensor.data.data)
uuidv5_sha1.update(tensor.data.data)
# Flush Hash Progress Bar # Flush Hash Progress Bar
bar.close() bar.close()
# Display Hash Output # Display Hash Output
print("sha1 {0} {1}".format(sha1.hexdigest(), filename)) # noqa: NP100 print("sha1 {0} {1}".format(sha1.hexdigest(), filename)) # noqa: NP100
print("UUIDv5 {0} {1}".format(uuid.UUID(bytes=uuidv5_sha1.digest()[:16], version=5), filename)) # noqa: NP100 print("sha256 {0} {1}".format(sha256.hexdigest(), filename)) # noqa: NP100
print("uuid {0} {1}".format(uuid.UUID(bytes=uuidv5_sha1.digest()[:16], version=5), filename)) # noqa: NP100
def main() -> None: def main() -> None:
parser = argparse.ArgumentParser(description="Dump GGUF file metadata") parser = argparse.ArgumentParser(description="Dump GGUF file metadata")
parser.add_argument("model", type=str, help="GGUF format model filename") parser.add_argument("model", type=str, help="GGUF format model filename")
parser.add_argument("--no-layer", action="store_true", help="exclude per layer hash")
parser.add_argument("--verbose", action="store_true", help="increase output verbosity") parser.add_argument("--verbose", action="store_true", help="increase output verbosity")
parser.add_argument("--progressbar", action="store_true", help="enable progressbar") parser.add_argument("--progressbar", action="store_true", help="enable progressbar")
args = parser.parse_args(None if len(sys.argv) > 1 else ["--help"]) args = parser.parse_args(None if len(sys.argv) > 1 else ["--help"])
logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO) logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)
reader = GGUFReader(args.model, 'r') reader = GGUFReader(args.model, 'r')
gguf_hash(reader, args.model, not args.progressbar) gguf_hash(reader, args.model, not args.progressbar, args.no_layer)
if __name__ == '__main__': if __name__ == '__main__':

View File

@ -1,2 +1,3 @@
docstring_parser~=0.15 docstring_parser~=0.15
pydantic~=2.6.3 pydantic~=2.6.3
requests

View File

@ -5477,6 +5477,7 @@ static void llm_load_vocab(
} else if ( } else if (
tokenizer_pre == "command-r") { tokenizer_pre == "command-r") {
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_COMMAND_R; vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_COMMAND_R;
vocab.tokenizer_clean_spaces = false;
} else if ( } else if (
tokenizer_pre == "qwen2") { tokenizer_pre == "qwen2") {
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_QWEN2; vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_QWEN2;
@ -5710,7 +5711,7 @@ static void llm_load_vocab(
// build special tokens cache // build special tokens cache
{ {
for (llama_vocab::id id = 0; id < (llama_vocab::id)n_vocab; ++id) { for (llama_vocab::id id = 0; id < (llama_vocab::id)n_vocab; ++id) {
if (!(vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_NORMAL)) { if (vocab.id_to_token[id].attr & (LLAMA_TOKEN_ATTR_CONTROL | LLAMA_TOKEN_ATTR_USER_DEFINED | LLAMA_TOKEN_ATTR_UNKNOWN)) {
vocab.cache_special_tokens.push_back(id); vocab.cache_special_tokens.push_back(id);
} }
} }
@ -5941,13 +5942,6 @@ static bool llm_load_tensors(
auto & hparams = model.hparams; auto & hparams = model.hparams;
#ifdef GGML_USE_SYCL
// disable MoE with SYCL until mul_mat_id is updated
if (hparams.n_expert > 0) {
n_gpu_layers = 0;
}
#endif
model.split_mode = split_mode; model.split_mode = split_mode;
model.main_gpu = main_gpu; model.main_gpu = main_gpu;
model.n_gpu_layers = n_gpu_layers; model.n_gpu_layers = n_gpu_layers;
@ -8247,7 +8241,7 @@ static struct ggml_tensor * llm_build_kqv(
struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q); struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
cb(kq, "kq", il); cb(kq, "kq", il);
if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX) { if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX || model.arch == LLM_ARCH_QWEN2) {
// for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs // for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs
// ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847 // ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847
ggml_mul_mat_set_prec(kq, GGML_PREC_F32); ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
@ -11800,7 +11794,12 @@ struct llm_build_context {
ext_factor, attn_factor, beta_fast, beta_slow); ext_factor, attn_factor, beta_fast, beta_slow);
cb(Qcur, "Qcur", il); cb(Qcur, "Qcur", il);
Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd / n_head))); // ref: https://github.com/google/gemma_pytorch/commit/03e657582d17cb5a8617ebf333c1c16f3694670e
switch (model.type) {
case e_model::MODEL_9B: Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head_k))); break;
case e_model::MODEL_27B: Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd / n_head))); break;
default: GGML_ASSERT(false);
};
cb(Qcur, "Qcur_scaled", il); cb(Qcur, "Qcur_scaled", il);
Kcur = ggml_rope_ext( Kcur = ggml_rope_ext(
@ -15532,17 +15531,6 @@ struct llm_tokenizer_bpe {
"[0-9][0-9][0-9]", "[0-9][0-9][0-9]",
}; };
break; break;
case LLAMA_VOCAB_PRE_TYPE_MPT:
// TODO: MPT pre-tokenization regexes are unknown
// the following are close, but not exact. run the following:
// ./bin/test-tokenizer-0 ../models/ggml-vocab-mpt.gguf
GGML_ASSERT("MPT pre-tokenization regexes are unknown - fixes needed");
regex_exprs = {
"\\s?\\p{L}+",
"\\s?\\p{P}+",
"'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
};
break;
case LLAMA_VOCAB_PRE_TYPE_STARCODER: case LLAMA_VOCAB_PRE_TYPE_STARCODER:
case LLAMA_VOCAB_PRE_TYPE_REFACT: case LLAMA_VOCAB_PRE_TYPE_REFACT:
case LLAMA_VOCAB_PRE_TYPE_COMMAND_R: case LLAMA_VOCAB_PRE_TYPE_COMMAND_R:
@ -15552,6 +15540,7 @@ struct llm_tokenizer_bpe {
}; };
break; break;
case LLAMA_VOCAB_PRE_TYPE_GPT2: case LLAMA_VOCAB_PRE_TYPE_GPT2:
case LLAMA_VOCAB_PRE_TYPE_MPT:
case LLAMA_VOCAB_PRE_TYPE_OLMO: case LLAMA_VOCAB_PRE_TYPE_OLMO:
case LLAMA_VOCAB_PRE_TYPE_JAIS: case LLAMA_VOCAB_PRE_TYPE_JAIS:
regex_exprs = { regex_exprs = {
@ -15578,8 +15567,8 @@ struct llm_tokenizer_bpe {
break; break;
case LLAMA_VOCAB_PRE_TYPE_VIKING: case LLAMA_VOCAB_PRE_TYPE_VIKING:
regex_exprs = { regex_exprs = {
"\\p{N}",
" ?[^(\\s|.,!?…。,、।۔،)]+", " ?[^(\\s|.,!?…。,、।۔،)]+",
"\\p{N}",
}; };
break; break;
default: default:
@ -16299,12 +16288,20 @@ struct fragment_buffer_variant {
// #define PRETOKENIZERDEBUG // #define PRETOKENIZERDEBUG
static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<fragment_buffer_variant> & buffer) { static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<fragment_buffer_variant> & buffer, bool parse_special) {
// for each special token // for each special token
for (const llama_vocab::id special_id : vocab.cache_special_tokens) { for (const llama_vocab::id special_id : vocab.cache_special_tokens) {
const auto & data = vocab.id_to_token[special_id]; const auto & data = vocab.id_to_token[special_id];
const auto & special_token = data.text; const auto & special_token = data.text;
if (!parse_special && (data.attr & (LLAMA_TOKEN_ATTR_CONTROL | LLAMA_TOKEN_ATTR_UNKNOWN))) {
// Ignore control and unknown tokens when parse_special == false
continue;
// User-defined tokens are still pre-tokenized before everything else
// ref: https://github.com/huggingface/tokenizers/blob/fdd26ba9a3f0c133427aab0423888cbde91362d7/tokenizers/src/tokenizer/mod.rs#L726
// This is mostly relevant for neox-style tokenizers (mpt, olmo, stablelm, etc.)
}
// for each text fragment // for each text fragment
std::forward_list<fragment_buffer_variant>::iterator it = buffer.begin(); std::forward_list<fragment_buffer_variant>::iterator it = buffer.begin();
while (it != buffer.end()) { while (it != buffer.end()) {
@ -16417,7 +16414,7 @@ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab &
if (!raw_text.empty()) { if (!raw_text.empty()) {
fragment_buffer.emplace_front(raw_text, 0, raw_text.length()); fragment_buffer.emplace_front(raw_text, 0, raw_text.length());
if (parse_special) tokenizer_st_partition(vocab, fragment_buffer); tokenizer_st_partition(vocab, fragment_buffer, parse_special);
} }
switch (vocab.type) { switch (vocab.type) {
@ -21188,7 +21185,7 @@ int32_t llama_token_to_piece(const struct llama_model * model, llama_token token
size--; size--;
} }
if (length < (int32_t)size) { if (length < (int32_t)size) {
return (int32_t) -size; return -(int32_t) size;
} }
memcpy(buf, token, size); memcpy(buf, token, size);
return (int32_t) size; return (int32_t) size;

View File

@ -14,7 +14,7 @@
#pragma GCC diagnostic push #pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wdouble-promotion" #pragma GCC diagnostic ignored "-Wdouble-promotion"
// ggml.c::quantize_row_q4_0_reference // ggml.c::quantize_row_q4_0_ref
inline static uint8_t round_orig(float v0) { return ((int8_t) (round(v0))) + 8; } inline static uint8_t round_orig(float v0) { return ((int8_t) (round(v0))) + 8; }
// ggml.c::ggml_silu_f32 // ggml.c::ggml_silu_f32
@ -24,7 +24,7 @@ inline static float silu_orig(float x) {
#pragma GCC diagnostic pop #pragma GCC diagnostic pop
// ggml.c::quantize_row_q4_0_reference // ggml.c::quantize_row_q4_0_ref
inline static uint8_t round_float(float v0) { return (int8_t)roundf(v0) + 8; } inline static uint8_t round_float(float v0) { return (int8_t)roundf(v0) + 8; }
// ggml.c::ggml_silu_f32 // ggml.c::ggml_silu_f32

View File

@ -60,7 +60,7 @@ static float reference_quantization_error(ggml_type_traits_t & qfns, size_t test
qfns.from_float(test_data, tmp_q.data(), test_size); qfns.from_float(test_data, tmp_q.data(), test_size);
qfns.to_float(tmp_q.data(), tmp_out.data(), test_size); qfns.to_float(tmp_q.data(), tmp_out.data(), test_size);
qfns.from_float_reference(test_data, tmp_q.data(), test_size); qfns.from_float_ref(test_data, tmp_q.data(), test_size);
qfns.to_float(tmp_q.data(), tmp_out_ref.data(), test_size); qfns.to_float(tmp_q.data(), tmp_out_ref.data(), test_size);
return array_rmse(tmp_out.data(), tmp_out_ref.data(), test_size); return array_rmse(tmp_out.data(), tmp_out_ref.data(), test_size);

View File

@ -285,7 +285,7 @@ int main(int argc, char * argv[]) {
for (size_t size : params.test_sizes) { for (size_t size : params.test_sizes) {
printf(" %zu values (%.2f MB)\n", size, 4*size/(float)(1024*1024)); printf(" %zu values (%.2f MB)\n", size, 4*size/(float)(1024*1024));
auto quantize_fn = [&](void) -> float { auto quantize_fn = [&](void) -> float {
qfns.from_float_reference(test_data1, test_q1, size); qfns.from_float_ref(test_data1, test_q1, size);
return test_q1[0]; return test_q1[0];
}; };
size_t quantized_size = ggml_row_size(type, size); size_t quantized_size = ggml_row_size(type, size);

View File

@ -195,7 +195,7 @@ int main(int argc, char **argv) {
const bool add_special = false; const bool add_special = false;
for (const auto & test_kv : k_tests) { for (const auto & test_kv : k_tests) {
const std::vector<llama_token> res = llama_tokenize(ctx, test_kv.first, add_special, true); const std::vector<llama_token> res = llama_tokenize(ctx, test_kv.first, add_special, false);
printf("\n"); printf("\n");
printf("src: '%s'\n", test_kv.first.c_str()); printf("src: '%s'\n", test_kv.first.c_str());
@ -253,7 +253,7 @@ int main(int argc, char **argv) {
{ {
const auto t_start = ggml_time_us(); const auto t_start = ggml_time_us();
res = llama_tokenize(ctx, text, add_special, true); res = llama_tokenize(ctx, text, add_special, false);
const auto t_end = ggml_time_us(); const auto t_end = ggml_time_us();

View File

@ -20,7 +20,7 @@ from typing import Any, Iterator, cast
from typing_extensions import Buffer from typing_extensions import Buffer
import cffi import cffi
from transformers import AutoTokenizer from transformers import AutoTokenizer, PreTrainedTokenizer
logger = logging.getLogger("test-tokenizer-random") logger = logging.getLogger("test-tokenizer-random")
@ -129,7 +129,7 @@ class Tokenizer:
class TokenizerGroundtruth (Tokenizer): class TokenizerGroundtruth (Tokenizer):
def __init__(self, dir_tokenizer: str): def __init__(self, dir_tokenizer: str):
self.model = AutoTokenizer.from_pretrained(dir_tokenizer) self.model: PreTrainedTokenizer = AutoTokenizer.from_pretrained(dir_tokenizer)
# guess BOS and EOS # guess BOS and EOS
ids = self.encode("a") ids = self.encode("a")
assert 1 <= len(ids) <= 3 assert 1 <= len(ids) <= 3
@ -143,7 +143,7 @@ class TokenizerGroundtruth (Tokenizer):
self.vocab = list(sorted(self.vocab)) self.vocab = list(sorted(self.vocab))
# tokens and lists # tokens and lists
self.special_tokens = list(self.model.all_special_tokens) self.special_tokens = list(self.model.all_special_tokens)
self.added_tokens = list(self.model.added_tokens_encoder) self.added_tokens = self.model.batch_decode(self.model.added_tokens_encoder.values(), skip_special_tokens=False)
self.bos_token = self.model.bos_token self.bos_token = self.model.bos_token
self.eos_token = self.model.eos_token self.eos_token = self.model.eos_token
@ -232,6 +232,7 @@ def generator_custom_text_edge_cases() -> Iterator[str]:
'a\na', # bert fail 'a\na', # bert fail
'"`', # falcon '"`', # falcon
' \u2e4e', # falcon ' \u2e4e', # falcon
'\n\x0b ', # falcon
'a\xa0\xa0\x00b', # jina-v2-es 'a\xa0\xa0\x00b', # jina-v2-es
'one <mask>', # jina-v2-es <mask> lstrip=true 'one <mask>', # jina-v2-es <mask> lstrip=true
'a </s> b', # rstrip phi-3 'a </s> b', # rstrip phi-3