Merge remote-tracking branch 'origin/master' into tool-call

This commit is contained in:
ochafik 2024-12-14 15:04:45 +00:00
commit 055053c859
68 changed files with 3951 additions and 2713 deletions

View File

@ -552,35 +552,44 @@ jobs:
-DCMAKE_XCODE_ATTRIBUTE_DEVELOPMENT_TEAM=ggml -DCMAKE_XCODE_ATTRIBUTE_DEVELOPMENT_TEAM=ggml
cmake --build . --config Release -j $(sysctl -n hw.logicalcpu) -- CODE_SIGNING_ALLOWED=NO cmake --build . --config Release -j $(sysctl -n hw.logicalcpu) -- CODE_SIGNING_ALLOWED=NO
# TODO: tmp disabled. see for possible re-enable: macOS-latest-swift:
# https://github.com/ggerganov/llama.cpp/pull/10525 runs-on: macos-latest
# macOS-latest-swift:
# runs-on: macos-latest strategy:
# matrix:
# strategy: destination: ['generic/platform=macOS', 'generic/platform=iOS', 'generic/platform=tvOS']
# matrix:
# destination: ['generic/platform=macOS', 'generic/platform=iOS', 'generic/platform=tvOS'] steps:
# - name: Clone
# steps: id: checkout
# - name: Clone uses: actions/checkout@v4
# id: checkout
# uses: actions/checkout@v4 - name: Dependencies
# id: depends
# - name: Dependencies continue-on-error: true
# id: depends run: |
# continue-on-error: true brew update
# run: |
# brew update - name: Build llama.cpp with CMake
# id: cmake_build
# - name: xcodebuild for swift package run: |
# id: xcodebuild sysctl -a
# run: | mkdir build
# xcodebuild -scheme llama -destination "${{ matrix.destination }}" cd build
# cmake -G Xcode .. \
# - name: Build Swift Example -DGGML_METAL_USE_BF16=ON \
# id: make_build_swift_example -DGGML_METAL_EMBED_LIBRARY=ON \
# run: | -DLLAMA_BUILD_EXAMPLES=OFF \
# make swift -DLLAMA_BUILD_TESTS=OFF \
-DLLAMA_BUILD_SERVER=OFF \
-DCMAKE_OSX_ARCHITECTURES="arm64;x86_64"
cmake --build . --config Release -j $(sysctl -n hw.logicalcpu)
sudo cmake --install . --config Release
- name: xcodebuild for swift package
id: xcodebuild
run: |
xcodebuild -scheme llama-Package -destination "${{ matrix.destination }}"
windows-msys2: windows-msys2:
runs-on: windows-latest runs-on: windows-latest
@ -1104,6 +1113,29 @@ jobs:
- name: Checkout code - name: Checkout code
uses: actions/checkout@v4 uses: actions/checkout@v4
- name: Build
id: cmake_build
run: |
sysctl -a
mkdir build
cd build
cmake -G Xcode .. \
-DGGML_METAL_USE_BF16=ON \
-DGGML_METAL_EMBED_LIBRARY=ON \
-DLLAMA_BUILD_EXAMPLES=OFF \
-DLLAMA_BUILD_TESTS=OFF \
-DLLAMA_BUILD_SERVER=OFF \
-DCMAKE_SYSTEM_NAME=iOS \
-DCMAKE_OSX_DEPLOYMENT_TARGET=14.0 \
-DCMAKE_XCODE_ATTRIBUTE_DEVELOPMENT_TEAM=ggml
cmake --build . --config Release -j $(sysctl -n hw.logicalcpu) -- CODE_SIGNING_ALLOWED=NO
sudo cmake --install . --config Release
- name: xcodebuild for swift package
id: xcodebuild
run: |
xcodebuild -scheme llama-Package -destination 'generic/platform=iOS'
- name: Build Xcode project - name: Build Xcode project
run: xcodebuild -project examples/llama.swiftui/llama.swiftui.xcodeproj -scheme llama.swiftui -sdk iphoneos CODE_SIGNING_REQUIRED=NO CODE_SIGN_IDENTITY= -destination 'generic/platform=iOS' build run: xcodebuild -project examples/llama.swiftui/llama.swiftui.xcodeproj -scheme llama.swiftui -sdk iphoneos CODE_SIGNING_REQUIRED=NO CODE_SIGN_IDENTITY= -destination 'generic/platform=iOS' build
@ -1131,23 +1163,6 @@ jobs:
./gradlew build --no-daemon ./gradlew build --no-daemon
# freeBSD-latest:
# runs-on: macos-12
# steps:
# - name: Clone
# uses: actions/checkout@v4
#
# - name: Build
# uses: cross-platform-actions/action@v0.19.0
# with:
# operating_system: freebsd
# version: '13.2'
# hypervisor: 'qemu'
# run: |
# sudo pkg update
# sudo pkg install -y gmake automake autoconf pkgconf llvm15 openblas
# gmake CC=/usr/local/bin/clang15 CXX=/usr/local/bin/clang++15 -j `sysctl -n hw.ncpu`
release: release:
if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }} if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }}

View File

@ -46,11 +46,9 @@ if (WIN32)
add_compile_definitions(_CRT_SECURE_NO_WARNINGS) add_compile_definitions(_CRT_SECURE_NO_WARNINGS)
endif() endif()
if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "MSVC") if (MSVC)
add_compile_options("$<$<COMPILE_LANGUAGE:C>:/source-charset:utf-8>") add_compile_options("$<$<COMPILE_LANGUAGE:C>:/utf-8>")
add_compile_options("$<$<COMPILE_LANGUAGE:CXX>:/source-charset:utf-8>") add_compile_options("$<$<COMPILE_LANGUAGE:CXX>:/utf-8>")
add_compile_options("$<$<COMPILE_LANGUAGE:C>:/execution-charset:utf-8>")
add_compile_options("$<$<COMPILE_LANGUAGE:CXX>:/execution-charset:utf-8>")
endif() endif()
# #

View File

@ -31,6 +31,13 @@
{ "name": "sycl_f16", "hidden": true, "cacheVariables": { "GGML_SYCL_F16": "ON" } }, { "name": "sycl_f16", "hidden": true, "cacheVariables": { "GGML_SYCL_F16": "ON" } },
{ "name": "vulkan", "hidden": true, "cacheVariables": { "GGML_VULKAN": "ON" } }, { "name": "vulkan", "hidden": true, "cacheVariables": { "GGML_VULKAN": "ON" } },
{
"name": "x64-windows-llvm", "hidden": true,
"cacheVariables": {
"CMAKE_TOOLCHAIN_FILE": "${sourceDir}/cmake/x64-windows-llvm.cmake"
}
},
{ {
"name": "arm64-windows-msvc", "hidden": true, "name": "arm64-windows-msvc", "hidden": true,
"architecture": { "value": "arm64", "strategy": "external" }, "architecture": { "value": "arm64", "strategy": "external" },
@ -70,6 +77,11 @@
{ "name": "arm64-windows-msvc-release", "inherits": [ "base", "arm64-windows-msvc", "reldbg" ] }, { "name": "arm64-windows-msvc-release", "inherits": [ "base", "arm64-windows-msvc", "reldbg" ] },
{ "name": "arm64-windows-msvc+static-release", "inherits": [ "base", "arm64-windows-msvc", "reldbg", "static" ] }, { "name": "arm64-windows-msvc+static-release", "inherits": [ "base", "arm64-windows-msvc", "reldbg", "static" ] },
{ "name": "x64-windows-llvm-debug", "inherits": [ "base", "x64-windows-llvm", "debug" ] },
{ "name": "x64-windows-llvm-release", "inherits": [ "base", "x64-windows-llvm", "release" ] },
{ "name": "x64-windows-llvm-reldbg", "inherits": [ "base", "x64-windows-llvm", "reldbg" ] },
{ "name": "x64-windows-llvm+static-release", "inherits": [ "base", "x64-windows-llvm", "reldbg", "static" ] },
{ "name": "x64-windows-msvc-debug", "inherits": [ "base", "debug" ] }, { "name": "x64-windows-msvc-debug", "inherits": [ "base", "debug" ] },
{ "name": "x64-windows-msvc-release", "inherits": [ "base", "reldbg" ] }, { "name": "x64-windows-msvc-release", "inherits": [ "base", "reldbg" ] },
{ "name": "x64-windows-msvc+static-release", "inherits": [ "base", "reldbg", "static" ] }, { "name": "x64-windows-msvc+static-release", "inherits": [ "base", "reldbg", "static" ] },

View File

@ -448,6 +448,10 @@ ifeq ($(UNAME_M),$(filter $(UNAME_M),x86_64 i686 amd64))
MK_CFLAGS += -march=native -mtune=native MK_CFLAGS += -march=native -mtune=native
HOST_CXXFLAGS += -march=native -mtune=native HOST_CXXFLAGS += -march=native -mtune=native
# Usage AMX build test
#MK_CFLAGS += -march=graniterapids -mtune=graniterapids
#HOST_CXXFLAGS += -march=graniterapids -mtune=graniterapids
# Usage AVX-only # Usage AVX-only
#MK_CFLAGS += -mfma -mf16c -mavx #MK_CFLAGS += -mfma -mf16c -mavx
#MK_CXXFLAGS += -mfma -mf16c -mavx #MK_CXXFLAGS += -mfma -mf16c -mavx
@ -951,7 +955,6 @@ DIR_COMMON = common
OBJ_GGML = \ OBJ_GGML = \
$(DIR_GGML)/src/ggml.o \ $(DIR_GGML)/src/ggml.o \
$(DIR_GGML)/src/ggml-aarch64.o \
$(DIR_GGML)/src/ggml-alloc.o \ $(DIR_GGML)/src/ggml-alloc.o \
$(DIR_GGML)/src/ggml-backend.o \ $(DIR_GGML)/src/ggml-backend.o \
$(DIR_GGML)/src/ggml-backend-reg.o \ $(DIR_GGML)/src/ggml-backend-reg.o \
@ -959,9 +962,11 @@ OBJ_GGML = \
$(DIR_GGML)/src/ggml-quants.o \ $(DIR_GGML)/src/ggml-quants.o \
$(DIR_GGML)/src/ggml-threading.o \ $(DIR_GGML)/src/ggml-threading.o \
$(DIR_GGML)/src/ggml-cpu/ggml-cpu.o \ $(DIR_GGML)/src/ggml-cpu/ggml-cpu.o \
$(DIR_GGML)/src/ggml-cpu/ggml-cpu-cpp.o \ $(DIR_GGML)/src/ggml-cpu/ggml-cpu_cpp.o \
$(DIR_GGML)/src/ggml-cpu/ggml-cpu-aarch64.o \ $(DIR_GGML)/src/ggml-cpu/ggml-cpu-aarch64.o \
$(DIR_GGML)/src/ggml-cpu/ggml-cpu-hbm.o \
$(DIR_GGML)/src/ggml-cpu/ggml-cpu-quants.o \ $(DIR_GGML)/src/ggml-cpu/ggml-cpu-quants.o \
$(DIR_GGML)/src/ggml-cpu/ggml-cpu-traits.o \
$(OBJ_GGML_EXT) $(OBJ_GGML_EXT)
OBJ_LLAMA = \ OBJ_LLAMA = \
@ -1102,17 +1107,10 @@ DEP_FILES = $(OBJ_GGML:.o=.d) $(OBJ_LLAMA:.o=.d) $(OBJ_COMMON:.o=.d)
# Default target # Default target
all: $(BUILD_TARGETS) all: $(BUILD_TARGETS)
# force c++ build for source file that have same name as c file
# Note: need this exception because `ggml-cpu.c` and `ggml-cpu.cpp` both produce the same obj/dep files # Note: need this exception because `ggml-cpu.c` and `ggml-cpu.cpp` both produce the same obj/dep files
# g++ -M -I ./ggml/include/ -I ./ggml/src ggml/src/ggml-cpu/ggml-cpu.cpp | grep ggml $(DIR_GGML)/%_cpp.o: $(DIR_GGML)/%.cpp
$(DIR_GGML)/src/ggml-cpu/ggml-cpu-cpp.o: \ $(CXX) $(CXXFLAGS) -MMD -c $< -o $@
ggml/src/ggml-cpu/ggml-cpu.cpp \
ggml/include/ggml-backend.h \
ggml/include/ggml.h \
ggml/include/ggml-alloc.h \
ggml/src/ggml-backend-impl.h \
ggml/include/ggml-cpu.h \
ggml/src/ggml-impl.h
$(CXX) $(CXXFLAGS) -c $< -o $@
# Rules for building object files # Rules for building object files
$(DIR_GGML)/%.o: $(DIR_GGML)/%.c $(DIR_GGML)/%.o: $(DIR_GGML)/%.c

View File

@ -2,59 +2,6 @@
import PackageDescription import PackageDescription
var sources = [
"src/llama.cpp",
"src/llama-vocab.cpp",
"src/llama-grammar.cpp",
"src/llama-sampling.cpp",
"src/unicode.cpp",
"src/unicode-data.cpp",
"ggml/src/ggml.c",
"ggml/src/ggml-aarch64.c",
"ggml/src/ggml-alloc.c",
"ggml/src/ggml-backend.cpp",
"ggml/src/ggml-backend-reg.cpp",
"ggml/src/ggml-cpu/ggml-cpu.c",
"ggml/src/ggml-cpu/ggml-cpu.cpp",
"ggml/src/ggml-cpu/ggml-cpu-aarch64.c",
"ggml/src/ggml-cpu/ggml-cpu-quants.c",
"ggml/src/ggml-threading.cpp",
"ggml/src/ggml-quants.c",
]
var resources: [Resource] = []
var linkerSettings: [LinkerSetting] = []
var cSettings: [CSetting] = [
.unsafeFlags(["-Wno-shorten-64-to-32", "-O3", "-DNDEBUG"]),
.unsafeFlags(["-fno-objc-arc"]),
.headerSearchPath("ggml/src"),
.headerSearchPath("ggml/src/ggml-cpu"),
// NOTE: NEW_LAPACK will required iOS version 16.4+
// We should consider add this in the future when we drop support for iOS 14
// (ref: ref: https://developer.apple.com/documentation/accelerate/1513264-cblas_sgemm?language=objc)
// .define("ACCELERATE_NEW_LAPACK"),
// .define("ACCELERATE_LAPACK_ILP64")
.define("GGML_USE_CPU"),
]
#if canImport(Darwin)
sources.append("ggml/src/ggml-common.h")
sources.append("ggml/src/ggml-metal/ggml-metal.m")
resources.append(.process("ggml/src/ggml-metal/ggml-metal.metal"))
linkerSettings.append(.linkedFramework("Accelerate"))
cSettings.append(
contentsOf: [
.define("GGML_USE_ACCELERATE"),
.define("GGML_USE_METAL"),
]
)
#endif
#if os(Linux)
cSettings.append(.define("_GNU_SOURCE"))
#endif
let package = Package( let package = Package(
name: "llama", name: "llama",
platforms: [ platforms: [
@ -67,26 +14,6 @@ let package = Package(
.library(name: "llama", targets: ["llama"]), .library(name: "llama", targets: ["llama"]),
], ],
targets: [ targets: [
.target( .systemLibrary(name: "llama", pkgConfig: "llama"),
name: "llama", ]
path: ".",
exclude: [
"build",
"cmake",
"examples",
"scripts",
"models",
"tests",
"CMakeLists.txt",
"Makefile",
"ggml/src/ggml-metal-embed.metal"
],
sources: sources,
resources: resources,
publicHeadersPath: "spm-headers",
cSettings: cSettings,
linkerSettings: linkerSettings
)
],
cxxLanguageStandard: .cxx17
) )

4
Sources/llama/llama.h Normal file
View File

@ -0,0 +1,4 @@
#pragma once
#include <llama.h>

View File

@ -0,0 +1,5 @@
module llama [system] {
header "llama.h"
link "llama"
export *
}

View File

@ -6,5 +6,5 @@ includedir=${prefix}/include
Name: llama Name: llama
Description: Port of Facebook's LLaMA model in C/C++ Description: Port of Facebook's LLaMA model in C/C++
Version: @PROJECT_VERSION@ Version: @PROJECT_VERSION@
Libs: -L${libdir} -lllama Libs: -L${libdir} -lggml -lggml-base -lllama
Cflags: -I${includedir} Cflags: -I${includedir}

View File

@ -0,0 +1,11 @@
set( CMAKE_SYSTEM_NAME Windows )
set( CMAKE_SYSTEM_PROCESSOR x86_64 )
set( CMAKE_C_COMPILER clang )
set( CMAKE_CXX_COMPILER clang++ )
set( arch_c_flags "-march=native" )
set( CMAKE_C_FLAGS_INIT "${arch_c_flags}" )
set( CMAKE_CXX_FLAGS_INIT "${arch_c_flags}" )

View File

@ -222,7 +222,7 @@ struct common_params {
struct common_params_speculative speculative; struct common_params_speculative speculative;
std::string model = ""; // model path // NOLINT std::string model = ""; // model path // NOLINT
std::string model_alias = "unknown"; // model alias // NOLINT std::string model_alias = ""; // model alias // NOLINT
std::string model_url = ""; // model url to download // NOLINT std::string model_url = ""; // model url to download // NOLINT
std::string hf_token = ""; // HF token // NOLINT std::string hf_token = ""; // HF token // NOLINT
std::string hf_repo = ""; // HF repo // NOLINT std::string hf_repo = ""; // HF repo // NOLINT

View File

@ -62,6 +62,10 @@ struct common_speculative * common_speculative_init(
} }
void common_speculative_free(struct common_speculative * spec) { void common_speculative_free(struct common_speculative * spec) {
if (spec == nullptr) {
return;
}
common_sampler_free(spec->smpl); common_sampler_free(spec->smpl);
llama_batch_free(spec->batch); llama_batch_free(spec->batch);

View File

@ -661,6 +661,9 @@ class Model:
if chkhsh == "1431a23e583c97432bc230bff598d103ddb5a1f89960c8f1d1051aaa944d0b35": if chkhsh == "1431a23e583c97432bc230bff598d103ddb5a1f89960c8f1d1051aaa944d0b35":
# ref: https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0 # ref: https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0
res = "minerva-7b" res = "minerva-7b"
if chkhsh == "8b5a93ed704057481f240da0be7e7dca721d7f8f4755263b6807227a2cbeae65":
# ref: https://huggingface.co/sentence-transformers/stsb-roberta-base
res = "roberta-bpe"
if res is None: if res is None:
logger.warning("\n") logger.warning("\n")
@ -1989,6 +1992,14 @@ class Qwen2Model(Model):
except FileNotFoundError: except FileNotFoundError:
self._set_vocab_gpt2() self._set_vocab_gpt2()
def set_gguf_parameters(self):
super().set_gguf_parameters()
if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]:
if self.hparams["rope_scaling"].get("type") == "yarn":
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
self.gguf_writer.add_rope_scaling_factor(self.hparams["rope_scaling"]["factor"])
self.gguf_writer.add_rope_scaling_orig_ctx_len(self.hparams["rope_scaling"]["original_max_position_embeddings"])
@Model.register("Qwen2MoeForCausalLM") @Model.register("Qwen2MoeForCausalLM")
class Qwen2MoeModel(Model): class Qwen2MoeModel(Model):
@ -2533,7 +2544,7 @@ class InternLM2Model(Model):
return [(self.map_tensor_name(name), data_torch)] return [(self.map_tensor_name(name), data_torch)]
@Model.register("BertModel", "CamembertModel") @Model.register("BertModel", "CamembertModel", "RobertaModel")
class BertModel(Model): class BertModel(Model):
model_arch = gguf.MODEL_ARCH.BERT model_arch = gguf.MODEL_ARCH.BERT
@ -2574,7 +2585,8 @@ class BertModel(Model):
# we need this to validate the size of the token_type embeddings # we need this to validate the size of the token_type embeddings
# though currently we are passing all zeros to the token_type embeddings # though currently we are passing all zeros to the token_type embeddings
self.gguf_writer.add_token_type_count(2) # "Sequence A" or "Sequence B" # "Sequence A" or "Sequence B"
self.gguf_writer.add_token_type_count(self.hparams.get("type_vocab_size", 1))
# convert to phantom space vocab # convert to phantom space vocab
def phantom(tok): def phantom(tok):

View File

@ -103,6 +103,7 @@ models = [
{"name": "phi-2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/microsoft/phi-2", }, {"name": "phi-2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/microsoft/phi-2", },
{"name": "chameleon", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/facebook/chameleon-7b", }, {"name": "chameleon", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/facebook/chameleon-7b", },
{"name": "minerva-7b", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0", }, {"name": "minerva-7b", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0", },
{"name": "roberta-bpe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/sentence-transformers/stsb-roberta-base"},
] ]

View File

@ -55,7 +55,14 @@ cmake --build build --config Release
cmake --preset arm64-windows-llvm-release -D GGML_OPENMP=OFF cmake --preset arm64-windows-llvm-release -D GGML_OPENMP=OFF
cmake --build build-arm64-windows-llvm-release cmake --build build-arm64-windows-llvm-release
``` ```
Building for arm64 can also be done with the MSVC compiler with the build-arm64-windows-MSVC preset, or the standard CMake build instructions. However, note that the MSVC compiler does not support inline ARM assembly code, used e.g. for the accelerated Q4_0_4_8 CPU kernels. Building for arm64 can also be done with the MSVC compiler with the build-arm64-windows-MSVC preset, or the standard CMake build instructions. However, note that the MSVC compiler does not support inline ARM assembly code, used e.g. for the accelerated Q4_0_N_M CPU kernels.
To build with the Ninja generator and the clang compiler as default:
- Set the library path: `set LIB=C:\Program Files (x86)\Windows Kits\10\Lib\10.0.22621.0\um\x64;C:\Program Files\Microsoft Visual Studio\2022\Community\VC\Tools\MSVC\14.41.34120\lib\x64\uwp;C:\Program Files (x86)\Windows Kits\10\Lib\10.0.22621.0\ucrt\x64`
```bash
cmake --preset x64-windows-llvm-release
cmake --build build-x64-windows-llvm-release
```
## BLAS Build ## BLAS Build

View File

@ -210,20 +210,20 @@ actor LlamaContext {
llama_kv_cache_clear(context) llama_kv_cache_clear(context)
let t_pp_start = ggml_time_us() let t_pp_start = DispatchTime.now().uptimeNanoseconds / 1000;
if llama_decode(context, batch) != 0 { if llama_decode(context, batch) != 0 {
print("llama_decode() failed during prompt") print("llama_decode() failed during prompt")
} }
llama_synchronize(context) llama_synchronize(context)
let t_pp_end = ggml_time_us() let t_pp_end = DispatchTime.now().uptimeNanoseconds / 1000;
// bench text generation // bench text generation
llama_kv_cache_clear(context) llama_kv_cache_clear(context)
let t_tg_start = ggml_time_us() let t_tg_start = DispatchTime.now().uptimeNanoseconds / 1000;
for i in 0..<tg { for i in 0..<tg {
llama_batch_clear(&batch) llama_batch_clear(&batch)
@ -238,7 +238,7 @@ actor LlamaContext {
llama_synchronize(context) llama_synchronize(context)
} }
let t_tg_end = ggml_time_us() let t_tg_end = DispatchTime.now().uptimeNanoseconds / 1000;
llama_kv_cache_clear(context) llama_kv_cache_clear(context)

View File

@ -7,6 +7,7 @@
objects = { objects = {
/* Begin PBXBuildFile section */ /* Begin PBXBuildFile section */
1809696D2D05A39F00400EE8 /* llama in Frameworks */ = {isa = PBXBuildFile; productRef = 1809696C2D05A39F00400EE8 /* llama */; };
549479CB2AC9E16000E0F78B /* Metal.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 549479CA2AC9E16000E0F78B /* Metal.framework */; }; 549479CB2AC9E16000E0F78B /* Metal.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 549479CA2AC9E16000E0F78B /* Metal.framework */; };
79E1D9CD2B4CD16E005F8E46 /* InputButton.swift in Sources */ = {isa = PBXBuildFile; fileRef = 79E1D9CC2B4CD16E005F8E46 /* InputButton.swift */; }; 79E1D9CD2B4CD16E005F8E46 /* InputButton.swift in Sources */ = {isa = PBXBuildFile; fileRef = 79E1D9CC2B4CD16E005F8E46 /* InputButton.swift */; };
7FA3D2B32B2EA2F600543F92 /* DownloadButton.swift in Sources */ = {isa = PBXBuildFile; fileRef = 7FA3D2B22B2EA2F600543F92 /* DownloadButton.swift */; }; 7FA3D2B32B2EA2F600543F92 /* DownloadButton.swift in Sources */ = {isa = PBXBuildFile; fileRef = 7FA3D2B22B2EA2F600543F92 /* DownloadButton.swift */; };
@ -17,7 +18,6 @@
8A3F84242AC4C891005E2EE8 /* models in Resources */ = {isa = PBXBuildFile; fileRef = 8A3F84232AC4C891005E2EE8 /* models */; }; 8A3F84242AC4C891005E2EE8 /* models in Resources */ = {isa = PBXBuildFile; fileRef = 8A3F84232AC4C891005E2EE8 /* models */; };
8A907F332AC7138A006146EA /* LibLlama.swift in Sources */ = {isa = PBXBuildFile; fileRef = 8A907F322AC7134E006146EA /* LibLlama.swift */; }; 8A907F332AC7138A006146EA /* LibLlama.swift in Sources */ = {isa = PBXBuildFile; fileRef = 8A907F322AC7134E006146EA /* LibLlama.swift */; };
8A9F7C4D2AC332EE008AE1EA /* LlamaState.swift in Sources */ = {isa = PBXBuildFile; fileRef = 8A9F7C4C2AC332EE008AE1EA /* LlamaState.swift */; }; 8A9F7C4D2AC332EE008AE1EA /* LlamaState.swift in Sources */ = {isa = PBXBuildFile; fileRef = 8A9F7C4C2AC332EE008AE1EA /* LlamaState.swift */; };
DF810E132B4A5BA200301144 /* llama in Frameworks */ = {isa = PBXBuildFile; productRef = DF810E122B4A5BA200301144 /* llama */; };
F1FE20E22B465ECA00B45541 /* LoadCustomButton.swift in Sources */ = {isa = PBXBuildFile; fileRef = F1FE20E12B465EC900B45541 /* LoadCustomButton.swift */; }; F1FE20E22B465ECA00B45541 /* LoadCustomButton.swift in Sources */ = {isa = PBXBuildFile; fileRef = F1FE20E12B465EC900B45541 /* LoadCustomButton.swift */; };
/* End PBXBuildFile section */ /* End PBXBuildFile section */
@ -42,7 +42,7 @@
isa = PBXFrameworksBuildPhase; isa = PBXFrameworksBuildPhase;
buildActionMask = 2147483647; buildActionMask = 2147483647;
files = ( files = (
DF810E132B4A5BA200301144 /* llama in Frameworks */, 1809696D2D05A39F00400EE8 /* llama in Frameworks */,
549479CB2AC9E16000E0F78B /* Metal.framework in Frameworks */, 549479CB2AC9E16000E0F78B /* Metal.framework in Frameworks */,
8A39BE0A2AC7601100BFEB40 /* Accelerate.framework in Frameworks */, 8A39BE0A2AC7601100BFEB40 /* Accelerate.framework in Frameworks */,
); );
@ -151,7 +151,7 @@
); );
name = llama.swiftui; name = llama.swiftui;
packageProductDependencies = ( packageProductDependencies = (
DF810E122B4A5BA200301144 /* llama */, 1809696C2D05A39F00400EE8 /* llama */,
); );
productName = llama.swiftui; productName = llama.swiftui;
productReference = 8A1C83732AC328BD0096AF73 /* llama.swiftui.app */; productReference = 8A1C83732AC328BD0096AF73 /* llama.swiftui.app */;
@ -429,7 +429,7 @@
/* End XCConfigurationList section */ /* End XCConfigurationList section */
/* Begin XCSwiftPackageProductDependency section */ /* Begin XCSwiftPackageProductDependency section */
DF810E122B4A5BA200301144 /* llama */ = { 1809696C2D05A39F00400EE8 /* llama */ = {
isa = XCSwiftPackageProductDependency; isa = XCSwiftPackageProductDependency;
productName = llama; productName = llama;
}; };

View File

@ -54,8 +54,6 @@ As the models are currently fully loaded into memory, you will need adequate dis
Several quantization methods are supported. They differ in the resulting model disk size and inference speed. Several quantization methods are supported. They differ in the resulting model disk size and inference speed.
The quantization formats `Q4_0_4_4`, `Q4_0_4_8` and `Q4_0_8_8` are block interleaved variants of the `Q4_0` format, providing a data layout that is better suited for specific implementations of optimized mulmat kernels. Since these formats differ only in data layout, they have the same quantized size as the `Q4_0` format.
*(outdated)* *(outdated)*
| Model | Measure | F16 | Q4_0 | Q4_1 | Q5_0 | Q5_1 | Q8_0 | | Model | Measure | F16 | Q4_0 | Q4_1 | Q5_0 | Q5_1 | Q8_0 |

View File

@ -48,9 +48,6 @@ static const std::vector<struct quant_option> QUANT_OPTIONS = {
{ "Q5_K_M", LLAMA_FTYPE_MOSTLY_Q5_K_M, " 5.33G, +0.0569 ppl @ Llama-3-8B", }, { "Q5_K_M", LLAMA_FTYPE_MOSTLY_Q5_K_M, " 5.33G, +0.0569 ppl @ Llama-3-8B", },
{ "Q6_K", LLAMA_FTYPE_MOSTLY_Q6_K, " 6.14G, +0.0217 ppl @ Llama-3-8B", }, { "Q6_K", LLAMA_FTYPE_MOSTLY_Q6_K, " 6.14G, +0.0217 ppl @ Llama-3-8B", },
{ "Q8_0", LLAMA_FTYPE_MOSTLY_Q8_0, " 7.96G, +0.0026 ppl @ Llama-3-8B", }, { "Q8_0", LLAMA_FTYPE_MOSTLY_Q8_0, " 7.96G, +0.0026 ppl @ Llama-3-8B", },
{ "Q4_0_4_4", LLAMA_FTYPE_MOSTLY_Q4_0_4_4, " 4.34G, +0.4685 ppl @ Llama-3-8B", },
{ "Q4_0_4_8", LLAMA_FTYPE_MOSTLY_Q4_0_4_8, " 4.34G, +0.4685 ppl @ Llama-3-8B", },
{ "Q4_0_8_8", LLAMA_FTYPE_MOSTLY_Q4_0_8_8, " 4.34G, +0.4685 ppl @ Llama-3-8B", },
{ "F16", LLAMA_FTYPE_MOSTLY_F16, "14.00G, +0.0020 ppl @ Mistral-7B", }, { "F16", LLAMA_FTYPE_MOSTLY_F16, "14.00G, +0.0020 ppl @ Mistral-7B", },
{ "BF16", LLAMA_FTYPE_MOSTLY_BF16, "14.00G, -0.0050 ppl @ Mistral-7B", }, { "BF16", LLAMA_FTYPE_MOSTLY_BF16, "14.00G, -0.0050 ppl @ Mistral-7B", },
{ "F32", LLAMA_FTYPE_ALL_F32, "26.00G @ 7B", }, { "F32", LLAMA_FTYPE_ALL_F32, "26.00G @ 7B", },

View File

@ -34,14 +34,6 @@ endforeach()
add_executable(${TARGET} ${TARGET_SRCS}) add_executable(${TARGET} ${TARGET_SRCS})
install(TARGETS ${TARGET} RUNTIME) install(TARGETS ${TARGET} RUNTIME)
# clean up generated files in pre-build step
foreach(asset ${PUBLIC_ASSETS})
set(output "${CMAKE_CURRENT_BINARY_DIR}/${asset}.hpp")
add_custom_command(TARGET ${TARGET} PRE_BUILD
COMMAND "${CMAKE_COMMAND}" -E remove -f "${output}"
)
endforeach()
target_link_libraries(${TARGET} PRIVATE common ${CMAKE_THREAD_LIBS_INIT}) target_link_libraries(${TARGET} PRIVATE common ${CMAKE_THREAD_LIBS_INIT})
if (LLAMA_SERVER_SSL) if (LLAMA_SERVER_SSL)

View File

@ -473,9 +473,11 @@ Notice that each `probs` is an array of length `n_probs`.
- `generation_settings`: The provided options above excluding `prompt` but including `n_ctx`, `model`. These options may differ from the original ones in some way (e.g. bad values filtered out, strings converted to tokens, etc.). - `generation_settings`: The provided options above excluding `prompt` but including `n_ctx`, `model`. These options may differ from the original ones in some way (e.g. bad values filtered out, strings converted to tokens, etc.).
- `model`: The path to the model loaded with `-m` - `model`: The path to the model loaded with `-m`
- `prompt`: The provided `prompt` - `prompt`: The provided `prompt`
- `stopped_eos`: Indicating whether the completion has stopped because it encountered the EOS token - `stop_type`: Indicating whether the completion has stopped. Possible values are:
- `stopped_limit`: Indicating whether the completion stopped because `n_predict` tokens were generated before stop words or EOS was encountered - `none`: Generating (not stopped)
- `stopped_word`: Indicating whether the completion stopped due to encountering a stopping word from `stop` JSON array provided - `eos`: Stopped because it encountered the EOS token
- `limit`: Stopped because `n_predict` tokens were generated before stop words or EOS was encountered
- `word`: Stopped due to encountering a stopping word from `stop` JSON array provided
- `stopping_word`: The stopping word encountered which stopped the generation (or "" if not stopped due to a stopping word) - `stopping_word`: The stopping word encountered which stopped the generation (or "" if not stopped due to a stopping word)
- `timings`: Hash of timing information about the completion such as the number of tokens `predicted_per_second` - `timings`: Hash of timing information about the completion such as the number of tokens `predicted_per_second`
- `tokens_cached`: Number of tokens from the prompt which could be re-used from previous completion (`n_past`) - `tokens_cached`: Number of tokens from the prompt which could be re-used from previous completion (`n_past`)
@ -616,14 +618,83 @@ This endpoint is public (no API key check). By default, it is read-only. To make
```json ```json
{ {
"default_generation_settings": { ... }, "default_generation_settings": {
"id": 0,
"id_task": -1,
"n_ctx": 1024,
"speculative": false,
"is_processing": false,
"params": {
"n_predict": -1,
"seed": 4294967295,
"temperature": 0.800000011920929,
"dynatemp_range": 0.0,
"dynatemp_exponent": 1.0,
"top_k": 40,
"top_p": 0.949999988079071,
"min_p": 0.05000000074505806,
"xtc_probability": 0.0,
"xtc_threshold": 0.10000000149011612,
"typical_p": 1.0,
"repeat_last_n": 64,
"repeat_penalty": 1.0,
"presence_penalty": 0.0,
"frequency_penalty": 0.0,
"dry_multiplier": 0.0,
"dry_base": 1.75,
"dry_allowed_length": 2,
"dry_penalty_last_n": -1,
"dry_sequence_breakers": [
"\n",
":",
"\"",
"*"
],
"mirostat": 0,
"mirostat_tau": 5.0,
"mirostat_eta": 0.10000000149011612,
"penalize_nl": false,
"stop": [],
"max_tokens": -1,
"n_keep": 0,
"n_discard": 0,
"ignore_eos": false,
"stream": true,
"n_probs": 0,
"min_keep": 0,
"grammar": "",
"samplers": [
"dry",
"top_k",
"typ_p",
"top_p",
"min_p",
"xtc",
"temperature"
],
"speculative.n_max": 16,
"speculative.n_min": 5,
"speculative.p_min": 0.8999999761581421,
"timings_per_token": false
},
"prompt": "",
"next_token": {
"has_next_token": true,
"has_new_line": false,
"n_remain": -1,
"n_decoded": 0,
"stopping_word": ""
}
},
"total_slots": 1, "total_slots": 1,
"chat_template": "" "model_path": "../models/Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf",
"chat_template": "..."
} }
``` ```
- `default_generation_settings` - the default generation settings for the `/completion` endpoint, which has the same fields as the `generation_settings` response object from the `/completion` endpoint. - `default_generation_settings` - the default generation settings for the `/completion` endpoint, which has the same fields as the `generation_settings` response object from the `/completion` endpoint.
- `total_slots` - the total number of slots for processing requests (defined by the `--parallel` option) - `total_slots` - the total number of slots for processing requests (defined by the `--parallel` option)
- `model_path` - the path to the model file (same as the `-m` argument)
- `chat_template` - the model's original Jinja2 prompt template - `chat_template` - the model's original Jinja2 prompt template
### POST `/props`: Change server global properties. ### POST `/props`: Change server global properties.
@ -817,54 +888,72 @@ Example:
```json ```json
[ [
{ {
"dynatemp_exponent": 1.0,
"dynatemp_range": 0.0,
"frequency_penalty": 0.0,
"grammar": "",
"id": 0, "id": 0,
"ignore_eos": false, "id_task": -1,
"n_ctx": 1024,
"speculative": false,
"is_processing": false, "is_processing": false,
"logit_bias": [], "params": {
"min_p": 0.05000000074505806, "n_predict": -1,
"mirostat": 0, "seed": 4294967295,
"mirostat_eta": 0.10000000149011612, "temperature": 0.800000011920929,
"mirostat_tau": 5.0, "dynatemp_range": 0.0,
"model": "llama-2-7b-32k-instruct.Q2_K.gguf", "dynatemp_exponent": 1.0,
"n_ctx": 2048,
"n_keep": 0,
"n_predict": 100000,
"n_probs": 0,
"next_token": {
"has_next_token": true,
"n_remain": -1,
"n_decoded": 0,
"stopped_eos": false,
"stopped_limit": false,
"stopped_word": false,
"stopping_word": ""
},
"penalize_nl": true,
"presence_penalty": 0.0,
"prompt": "Say hello to llama.cpp",
"repeat_last_n": 64,
"repeat_penalty": 1.100000023841858,
"samplers": [
"top_k",
"typical_p",
"top_p",
"min_p",
"temperature"
],
"seed": 42,
"stop": [
"\n"
],
"stream": false,
"task_id": 0,
"temperature": 0.0,
"top_k": 40, "top_k": 40,
"top_p": 0.949999988079071, "top_p": 0.949999988079071,
"typical_p": 1.0 "min_p": 0.05000000074505806,
"xtc_probability": 0.0,
"xtc_threshold": 0.10000000149011612,
"typical_p": 1.0,
"repeat_last_n": 64,
"repeat_penalty": 1.0,
"presence_penalty": 0.0,
"frequency_penalty": 0.0,
"dry_multiplier": 0.0,
"dry_base": 1.75,
"dry_allowed_length": 2,
"dry_penalty_last_n": -1,
"dry_sequence_breakers": [
"\n",
":",
"\"",
"*"
],
"mirostat": 0,
"mirostat_tau": 5.0,
"mirostat_eta": 0.10000000149011612,
"penalize_nl": false,
"stop": [],
"max_tokens": -1,
"n_keep": 0,
"n_discard": 0,
"ignore_eos": false,
"stream": true,
"n_probs": 0,
"min_keep": 0,
"grammar": "",
"samplers": [
"dry",
"top_k",
"typ_p",
"top_p",
"min_p",
"xtc",
"temperature"
],
"speculative.n_max": 16,
"speculative.n_min": 5,
"speculative.p_min": 0.8999999761581421,
"timings_per_token": false
},
"prompt": "",
"next_token": {
"has_next_token": true,
"has_new_line": false,
"n_remain": -1,
"n_decoded": 0,
"stopping_word": ""
}
} }
] ]
``` ```

File diff suppressed because it is too large Load Diff

View File

@ -44,11 +44,10 @@ To run with stdout/stderr display in real time (verbose output, but useful for d
DEBUG=1 ./tests.sh -s -v -x DEBUG=1 ./tests.sh -s -v -x
``` ```
Some tests (especially `@slow` ones) require model downloads. Since this can time out the tests, you can pre-download them in the cache ahead of time with: Hint: You can compile and run the tests in a single command, useful for local development:
```shell ```shell
pip install -r examples/server/tests/requirements.txt cmake --build build -j --target llama-server && ./examples/server/tests/tests.sh
python scripts/fetch_server_test_models.py
``` ```
To see all available arguments, please refer to [pytest documentation](https://docs.pytest.org/en/stable/how-to/usage.html) To see all available arguments, please refer to [pytest documentation](https://docs.pytest.org/en/stable/how-to/usage.html)

View File

@ -1,5 +1,9 @@
#!/bin/bash #!/bin/bash
# make sure we are in the right directory
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
cd $SCRIPT_DIR
set -eu set -eu
if [ $# -lt 1 ] if [ $# -lt 1 ]

View File

@ -22,7 +22,12 @@ def test_server_props():
server.start() server.start()
res = server.make_request("GET", "/props") res = server.make_request("GET", "/props")
assert res.status_code == 200 assert res.status_code == 200
assert ".gguf" in res.body["model_path"]
assert res.body["total_slots"] == server.n_slots assert res.body["total_slots"] == server.n_slots
default_val = res.body["default_generation_settings"]
assert server.n_ctx is not None and server.n_slots is not None
assert default_val["n_ctx"] == server.n_ctx / server.n_slots
assert default_val["params"]["seed"] == server.seed
def test_server_models(): def test_server_models():
@ -33,6 +38,31 @@ def test_server_models():
assert len(res.body["data"]) == 1 assert len(res.body["data"]) == 1
assert res.body["data"][0]["id"] == server.model_alias assert res.body["data"][0]["id"] == server.model_alias
def test_server_slots():
global server
# without slots endpoint enabled, this should return error
server.server_slots = False
server.start()
res = server.make_request("GET", "/slots")
assert res.status_code == 501 # ERROR_TYPE_NOT_SUPPORTED
assert "error" in res.body
server.stop()
# with slots endpoint enabled, this should return slots info
server.server_slots = True
server.n_slots = 2
server.start()
res = server.make_request("GET", "/slots")
assert res.status_code == 200
assert len(res.body) == server.n_slots
assert server.n_ctx is not None and server.n_slots is not None
assert res.body[0]["n_ctx"] == server.n_ctx / server.n_slots
assert "params" in res.body[0]
assert res.body[0]["params"]["seed"] == server.seed
def test_load_split_model(): def test_load_split_model():
global server global server
server.model_hf_repo = "ggml-org/models" server.model_hf_repo = "ggml-org/models"

View File

@ -12,13 +12,13 @@ def create_server():
@pytest.mark.parametrize( @pytest.mark.parametrize(
"model,system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,truncated", "model,system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,finish_reason",
[ [
("llama-2", "Book", "What is the best book", 8, "(Suddenly)+", 77, 8, False), (None, "Book", "What is the best book", 8, "(Suddenly)+", 77, 8, "length"),
("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, False), ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length"),
] ]
) )
def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, truncated): def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, finish_reason):
global server global server
server.start() server.start()
res = server.make_request("POST", "/chat/completions", data={ res = server.make_request("POST", "/chat/completions", data={
@ -30,29 +30,28 @@ def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_conte
], ],
}) })
assert res.status_code == 200 assert res.status_code == 200
assert "cmpl" in res.body["id"] # make sure the completion id has the expected format
assert res.body["model"] == model if model is not None else server.model_alias
assert res.body["usage"]["prompt_tokens"] == n_prompt assert res.body["usage"]["prompt_tokens"] == n_prompt
assert res.body["usage"]["completion_tokens"] == n_predicted assert res.body["usage"]["completion_tokens"] == n_predicted
choice = res.body["choices"][0] choice = res.body["choices"][0]
assert "assistant" == choice["message"]["role"] assert "assistant" == choice["message"]["role"]
assert match_regex(re_content, choice["message"]["content"]) assert match_regex(re_content, choice["message"]["content"])
if truncated: assert choice["finish_reason"] == finish_reason
assert choice["finish_reason"] == "length"
else:
assert choice["finish_reason"] == "stop"
@pytest.mark.parametrize( @pytest.mark.parametrize(
"model,system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,truncated", "system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,finish_reason",
[ [
("llama-2", "Book", "What is the best book", 8, "(Suddenly)+", 77, 8, False), ("Book", "What is the best book", 8, "(Suddenly)+", 77, 8, "length"),
("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, False), ("You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length"),
] ]
) )
def test_chat_completion_stream(model, system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, truncated): def test_chat_completion_stream(system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, finish_reason):
global server global server
server.model_alias = None # try using DEFAULT_OAICOMPAT_MODEL
server.start() server.start()
res = server.make_stream_request("POST", "/chat/completions", data={ res = server.make_stream_request("POST", "/chat/completions", data={
"model": model,
"max_tokens": max_tokens, "max_tokens": max_tokens,
"messages": [ "messages": [
{"role": "system", "content": system_prompt}, {"role": "system", "content": system_prompt},
@ -61,18 +60,19 @@ def test_chat_completion_stream(model, system_prompt, user_prompt, max_tokens, r
"stream": True, "stream": True,
}) })
content = "" content = ""
last_cmpl_id = None
for data in res: for data in res:
choice = data["choices"][0] choice = data["choices"][0]
assert "gpt-3.5" in data["model"] # DEFAULT_OAICOMPAT_MODEL, maybe changed in the future
if last_cmpl_id is None:
last_cmpl_id = data["id"]
assert last_cmpl_id == data["id"] # make sure the completion id is the same for all events in the stream
if choice["finish_reason"] in ["stop", "length"]: if choice["finish_reason"] in ["stop", "length"]:
assert data["usage"]["prompt_tokens"] == n_prompt assert data["usage"]["prompt_tokens"] == n_prompt
assert data["usage"]["completion_tokens"] == n_predicted assert data["usage"]["completion_tokens"] == n_predicted
assert "content" not in choice["delta"] assert "content" not in choice["delta"]
assert match_regex(re_content, content) assert match_regex(re_content, content)
# FIXME: not sure why this is incorrect in stream mode assert choice["finish_reason"] == finish_reason
# if truncated:
# assert choice["finish_reason"] == "length"
# else:
# assert choice["finish_reason"] == "stop"
else: else:
assert choice["finish_reason"] is None assert choice["finish_reason"] is None
content += choice["delta"]["content"] content += choice["delta"]["content"]
@ -93,7 +93,7 @@ def test_chat_completion_with_openai_library():
temperature=0.8, temperature=0.8,
) )
print(res) print(res)
assert res.choices[0].finish_reason == "stop" assert res.choices[0].finish_reason == "length"
assert res.choices[0].message.content is not None assert res.choices[0].message.content is not None
assert match_regex("(Suddenly)+", res.choices[0].message.content) assert match_regex("(Suddenly)+", res.choices[0].message.content)

View File

@ -42,15 +42,39 @@ def test_completion_stream(prompt: str, n_predict: int, re_content: str, n_promp
}) })
content = "" content = ""
for data in res: for data in res:
assert "stop" in data and type(data["stop"]) == bool
if data["stop"]: if data["stop"]:
assert data["timings"]["prompt_n"] == n_prompt assert data["timings"]["prompt_n"] == n_prompt
assert data["timings"]["predicted_n"] == n_predicted assert data["timings"]["predicted_n"] == n_predicted
assert data["truncated"] == truncated assert data["truncated"] == truncated
assert data["stop_type"] == "limit"
assert "generation_settings" in data
assert server.n_predict is not None
assert data["generation_settings"]["n_predict"] == min(n_predict, server.n_predict)
assert data["generation_settings"]["seed"] == server.seed
assert match_regex(re_content, content) assert match_regex(re_content, content)
else: else:
content += data["content"] content += data["content"]
def test_completion_stream_vs_non_stream():
global server
server.start()
res_stream = server.make_stream_request("POST", "/completion", data={
"n_predict": 8,
"prompt": "I believe the meaning of life is",
"stream": True,
})
res_non_stream = server.make_request("POST", "/completion", data={
"n_predict": 8,
"prompt": "I believe the meaning of life is",
})
content_stream = ""
for data in res_stream:
content_stream += data["content"]
assert content_stream == res_non_stream.body["content"]
@pytest.mark.parametrize("n_slots", [1, 2]) @pytest.mark.parametrize("n_slots", [1, 2])
def test_consistent_result_same_seed(n_slots: int): def test_consistent_result_same_seed(n_slots: int):
global server global server
@ -221,3 +245,24 @@ def test_completion_parallel_slots(n_slots: int, n_requests: int):
assert len(res.body["content"]) > 10 assert len(res.body["content"]) > 10
# FIXME: the result is not deterministic when using other slot than slot 0 # FIXME: the result is not deterministic when using other slot than slot 0
# assert match_regex(re_content, res.body["content"]) # assert match_regex(re_content, res.body["content"])
def test_n_probs():
global server
server.start()
res = server.make_request("POST", "/completion", data={
"prompt": "I believe the meaning of life is",
"n_probs": 10,
"temperature": 0.0,
"n_predict": 5,
})
assert res.status_code == 200
assert "completion_probabilities" in res.body
assert len(res.body["completion_probabilities"]) == 5
for tok in res.body["completion_probabilities"]:
assert "probs" in tok
assert len(tok["probs"]) == 10
for prob in tok["probs"]:
assert "prob" in prob
assert "tok_str" in prob
assert 0.0 <= prob["prob"] <= 1.0

View File

@ -13,28 +13,28 @@ def test_infill_without_input_extra():
global server global server
server.start() server.start()
res = server.make_request("POST", "/infill", data={ res = server.make_request("POST", "/infill", data={
"prompt": "Complete this", "input_prefix": "#include <cstdio>\n#include \"llama.h\"\n\nint main() {\n",
"input_prefix": "#include <cstdio>\n#include \"llama.h\"\n\nint main() {\n int n_threads = llama_", "prompt": " int n_threads = llama_",
"input_suffix": "}\n", "input_suffix": "}\n",
}) })
assert res.status_code == 200 assert res.status_code == 200
assert match_regex("(One|day|she|saw|big|scary|bird)+", res.body["content"]) assert match_regex("(Ann|small|shiny)+", res.body["content"])
def test_infill_with_input_extra(): def test_infill_with_input_extra():
global server global server
server.start() server.start()
res = server.make_request("POST", "/infill", data={ res = server.make_request("POST", "/infill", data={
"prompt": "Complete this",
"input_extra": [{ "input_extra": [{
"filename": "llama.h", "filename": "llama.h",
"text": "LLAMA_API int32_t llama_n_threads();\n" "text": "LLAMA_API int32_t llama_n_threads();\n"
}], }],
"input_prefix": "#include <cstdio>\n#include \"llama.h\"\n\nint main() {\n int n_threads = llama_", "input_prefix": "#include <cstdio>\n#include \"llama.h\"\n\nint main() {\n",
"prompt": " int n_threads = llama_",
"input_suffix": "}\n", "input_suffix": "}\n",
}) })
assert res.status_code == 200 assert res.status_code == 200
assert match_regex("(cuts|Jimmy|mom|came|into|the|room)+", res.body["content"]) assert match_regex("(Dad|excited|park)+", res.body["content"])
@pytest.mark.parametrize("input_extra", [ @pytest.mark.parametrize("input_extra", [
@ -48,10 +48,30 @@ def test_invalid_input_extra_req(input_extra):
global server global server
server.start() server.start()
res = server.make_request("POST", "/infill", data={ res = server.make_request("POST", "/infill", data={
"prompt": "Complete this",
"input_extra": [input_extra], "input_extra": [input_extra],
"input_prefix": "#include <cstdio>\n#include \"llama.h\"\n\nint main() {\n int n_threads = llama_", "input_prefix": "#include <cstdio>\n#include \"llama.h\"\n\nint main() {\n",
"prompt": " int n_threads = llama_",
"input_suffix": "}\n", "input_suffix": "}\n",
}) })
assert res.status_code == 400 assert res.status_code == 400
assert "error" in res.body assert "error" in res.body
@pytest.mark.skipif(not is_slow_test_allowed(), reason="skipping slow test")
def test_with_qwen_model():
global server
server.model_file = None
server.model_hf_repo = "ggml-org/Qwen2.5-Coder-1.5B-IQ3_XXS-GGUF"
server.model_hf_file = "qwen2.5-coder-1.5b-iq3_xxs-imat.gguf"
server.start(timeout_seconds=600)
res = server.make_request("POST", "/infill", data={
"input_extra": [{
"filename": "llama.h",
"text": "LLAMA_API int32_t llama_n_threads();\n"
}],
"input_prefix": "#include <cstdio>\n#include \"llama.h\"\n\nint main() {\n",
"prompt": " int n_threads = llama_",
"input_suffix": "}\n",
})
assert res.status_code == 200
assert res.body["content"] == "n_threads();\n printf(\"Number of threads: %d\\n\", n_threads);\n return 0;\n"

View File

@ -64,6 +64,7 @@ class ServerProcess:
server_embeddings: bool | None = False server_embeddings: bool | None = False
server_reranking: bool | None = False server_reranking: bool | None = False
server_metrics: bool | None = False server_metrics: bool | None = False
server_slots: bool | None = False
draft: int | None = None draft: int | None = None
api_key: str | None = None api_key: str | None = None
response_format: str | None = None response_format: str | None = None
@ -93,7 +94,6 @@ class ServerProcess:
else: else:
server_path = "../../../build/bin/llama-server" server_path = "../../../build/bin/llama-server"
server_args = [ server_args = [
"--slots", # requires to get slot status via /slots endpoint
"--host", "--host",
self.server_host, self.server_host,
"--port", "--port",
@ -131,6 +131,8 @@ class ServerProcess:
server_args.append("--reranking") server_args.append("--reranking")
if self.server_metrics: if self.server_metrics:
server_args.append("--metrics") server_args.append("--metrics")
if self.server_slots:
server_args.append("--slots")
if self.model_alias: if self.model_alias:
server_args.extend(["--alias", self.model_alias]) server_args.extend(["--alias", self.model_alias])
if self.n_ctx: if self.n_ctx:
@ -187,7 +189,7 @@ class ServerProcess:
start_time = time.time() start_time = time.time()
while time.time() - start_time < timeout_seconds: while time.time() - start_time < timeout_seconds:
try: try:
response = self.make_request("GET", "/slots", headers={ response = self.make_request("GET", "/health", headers={
"Authorization": f"Bearer {self.api_key}" if self.api_key else None "Authorization": f"Bearer {self.api_key}" if self.api_key else None
}) })
if response.status_code == 200: if response.status_code == 200:
@ -230,7 +232,7 @@ class ServerProcess:
result.headers = dict(response.headers) result.headers = dict(response.headers)
result.status_code = response.status_code result.status_code = response.status_code
result.body = response.json() if parse_body else None result.body = response.json() if parse_body else None
print("Response from server", result.body) print("Response from server", json.dumps(result.body, indent=2))
return result return result
def make_stream_request( def make_stream_request(
@ -251,7 +253,7 @@ class ServerProcess:
break break
elif line.startswith('data: '): elif line.startswith('data: '):
data = json.loads(line[6:]) data = json.loads(line[6:])
print("Partial response from server", data) print("Partial response from server", json.dumps(data, indent=2))
yield data yield data
@ -375,3 +377,6 @@ def match_regex(regex: str, text: str) -> bool:
).search(text) ).search(text)
is not None is not None
) )
def is_slow_test_allowed():
return os.environ.get("SLOW_TESTS") == "1" or os.environ.get("SLOW_TESTS") == "ON"

View File

@ -22,6 +22,7 @@
#include <sstream> #include <sstream>
#include <string> #include <string>
#include <vector> #include <vector>
#include <memory>
#define DEFAULT_OAICOMPAT_MODEL "gpt-3.5-turbo-0613" #define DEFAULT_OAICOMPAT_MODEL "gpt-3.5-turbo-0613"
@ -42,17 +43,6 @@ using json = nlohmann::ordered_json;
#define QUE_ERR(fmt, ...) LOG_ERR("que %12.*s: " fmt, 12, __func__, __VA_ARGS__) #define QUE_ERR(fmt, ...) LOG_ERR("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
#define QUE_DBG(fmt, ...) LOG_DBG("que %12.*s: " fmt, 12, __func__, __VA_ARGS__) #define QUE_DBG(fmt, ...) LOG_DBG("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
// https://community.openai.com/t/openai-chat-list-of-error-codes-and-types/357791/11
enum error_type {
ERROR_TYPE_INVALID_REQUEST,
ERROR_TYPE_AUTHENTICATION,
ERROR_TYPE_SERVER,
ERROR_TYPE_NOT_FOUND,
ERROR_TYPE_PERMISSION,
ERROR_TYPE_UNAVAILABLE, // custom error
ERROR_TYPE_NOT_SUPPORTED, // custom error
};
template <typename T> template <typename T>
static T json_value(const json & body, const std::string & key, const T & default_value) { static T json_value(const json & body, const std::string & key, const T & default_value) {
// Fallback null to default value // Fallback null to default value
@ -176,6 +166,9 @@ static std::vector<llama_tokens> tokenize_input_prompts(llama_context * ctx, con
} else { } else {
throw std::runtime_error("\"prompt\" must be a string, an list of tokens, a list of mixed strings & tokens, or a list of prompts"); throw std::runtime_error("\"prompt\" must be a string, an list of tokens, a list of mixed strings & tokens, or a list of prompts");
} }
if (result.empty()) {
throw std::runtime_error("\"prompt\" must not be empty");
}
return result; return result;
} }
@ -474,48 +467,11 @@ static std::string tokens_to_output_formatted_string(const llama_context * ctx,
return out; return out;
} }
struct completion_token_output {
llama_token tok;
std::string text_to_send;
struct token_prob {
llama_token tok;
float prob;
};
std::vector<token_prob> probs;
};
// convert a vector of completion_token_output to json
static json probs_vector_to_json(const llama_context * ctx, const std::vector<completion_token_output> & probs) {
json out = json::array();
for (const auto & prob : probs) {
json probs_for_token = json::array();
for (const auto & p : prob.probs) {
const std::string tok_str = tokens_to_output_formatted_string(ctx, p.tok);
probs_for_token.push_back(json {
{"tok_str", tok_str},
{"prob", p.prob},
});
}
const std::string tok_str = tokens_to_output_formatted_string(ctx, prob.tok);
out.push_back(json {
{"content", tok_str},
{"probs", probs_for_token},
});
}
return out;
}
static bool server_sent_event(httplib::DataSink & sink, const char * event, const json & data) { static bool server_sent_event(httplib::DataSink & sink, const char * event, const json & data) {
const std::string str = const std::string str =
std::string(event) + ": " + std::string(event) + ": " +
data.dump(-1, ' ', false, json::error_handler_t::replace) + data.dump(-1, ' ', false, json::error_handler_t::replace) +
"\n\n"; // note: these newlines are important (not sure why though, if you know, add a comment to explain) "\n\n"; // required by RFC 8895 - A message is terminated by a blank line (two line terminators in a row).
LOG_DBG("data stream, to_send: %s", str.c_str()); LOG_DBG("data stream, to_send: %s", str.c_str());
@ -535,11 +491,14 @@ static json oaicompat_completion_params_parse(
{ {
json llama_params; json llama_params;
llama_params["__oaicompat"] = true;
auto tools = json_value(body, "tools", json()); auto tools = json_value(body, "tools", json());
auto has_tools = tools.is_array() && !tools.empty(); auto has_tools = tools.is_array() && !tools.empty();
auto stream = json_value(body, "stream", json());
if (stream && has_tools) {
throw std::runtime_error("Cannot use tools with stream");
}
// Apply chat template to the list of messages // Apply chat template to the list of messages
llama_params["chat_template"] = tmpl.source(); llama_params["chat_template"] = tmpl.source();
@ -589,22 +548,24 @@ static json oaicompat_completion_params_parse(
if (use_jinja) { if (use_jinja) {
bool allow_content = tool_choice != "required"; bool allow_content = tool_choice != "required";
if (tool_choice != "none" && has_tools) { if (tool_choice != "none" && has_tools) {
llama_params["tools"] = tools;
llama_params["tool_call_style"] = tool_call_style;
auto parallel_tool_calls = body.contains("parallel_tool_calls") ? body.at("parallel_tool_calls") : json(); auto parallel_tool_calls = body.contains("parallel_tool_calls") ? body.at("parallel_tool_calls") : json();
llama_params["parse_tool_calls"] = true;
llama_params["parallel_tool_calls"] = parallel_tool_calls; llama_params["parallel_tool_calls"] = parallel_tool_calls;
auto handler = llama_tool_call_handler_init(tool_call_style, tmpl, allow_content, parallel_tool_calls, body.at("messages"), tools, llama_params["json_schema"]); auto handler = llama_tool_call_handler_init(tool_call_style, tmpl, allow_content, parallel_tool_calls, body.at("messages"), tools, llama_params["json_schema"]);
llama_params["prompt"] = handler.prompt; llama_params["prompt"] = handler.prompt;
for (const auto & stop : handler.additional_stop_words) { for (const auto & stop : handler.additional_stops) {
llama_params["stop"].push_back(stop); llama_params["stop"].push_back(stop);
} }
if (!handler.grammar_trigger_words.empty()) { if (!handler.grammar_triggers.empty()) {
auto triggers = json::array(); auto triggers = json::array();
for (const auto & word : handler.grammar_trigger_words) { for (const auto & word : handler.grammar_triggers) {
triggers.push_back(word); triggers.push_back(word);
} }
llama_params["grammar_trigger_words"] = triggers; llama_params["grammar_triggers"] = triggers;
} }
if (!handler.grammar.empty()) { if (!handler.grammar.empty()) {
if (llama_params.contains("grammar")) { if (llama_params.contains("grammar")) {
@ -656,192 +617,6 @@ static json oaicompat_completion_params_parse(
return llama_params; return llama_params;
} }
static json format_final_response_oaicompat(const json & request, const json & result, const std::string & completion_id, llama_tool_call_style tool_call_style, bool streaming = false, bool verbose = false) {
bool stopped_word = result.count("stopped_word") != 0;
bool stopped_eos = json_value(result, "stopped_eos", false);
int num_tokens_predicted = json_value(result, "tokens_predicted", 0);
int num_prompt_tokens = json_value(result, "tokens_evaluated", 0);
std::string content = json_value(result, "content", std::string(""));
std::string finish_reason = "length";
if (stopped_word || stopped_eos) {
finish_reason = "stop";
}
auto chat_template = json_value(request, "chat_template", std::string());
llama_tool_calls parsed_tool_calls;
auto tools = json_value(request, "tools", json::array());
json tool_calls;
json message_content;
if (json_value(request, "parse_tool_calls", false)) {
parsed_tool_calls = parse_tool_calls(tool_call_style, tools, content);
if (!parsed_tool_calls.tool_calls.empty()) {
finish_reason = "tool_calls";
message_content = parsed_tool_calls.content;
tool_calls = json::array();
for (const auto & tc : parsed_tool_calls.tool_calls) {
tool_calls.push_back({
{"type", "function"},
{"function", {
{"name", tc.name},
{"arguments", tc.arguments},
}},
{"id", tc.id.empty() ? json() : json(tc.id)},
});
}
} else {
message_content = parsed_tool_calls.content;
}
} else {
message_content = content;
}
json choices =
streaming ? json::array({json{{"finish_reason", finish_reason},
{"index", 0},
{"delta", json::object()}}})
: json::array({json{{"finish_reason", finish_reason},
{"index", 0},
{"message", json{{"content", message_content},
{"tool_calls", tool_calls},
{"role", "assistant"}}}}});
std::time_t t = std::time(0);
json res = json {
{"choices", choices},
{"created", t},
{"model",
json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
{"object", streaming ? "chat.completion.chunk" : "chat.completion"},
{"usage", json {
{"completion_tokens", num_tokens_predicted},
{"prompt_tokens", num_prompt_tokens},
{"total_tokens", num_tokens_predicted + num_prompt_tokens}
}},
{"id", completion_id}
};
// extra fields for debugging purposes
if (verbose) {
res["__verbose"] = result;
}
if (result.contains("completion_probabilities")) {
res["completion_probabilities"] = json_value(result, "completion_probabilities", json::array());
}
if (result.contains("timings")) {
res.push_back({"timings", json_value(result, "timings", json::object())});
}
return res;
}
// return value is vector as there is one case where we might need to generate two responses
static std::vector<json> format_partial_response_oaicompat(const json & result, const std::string & completion_id) {
if (!result.contains("model") || !result.contains("oaicompat_token_ctr")) {
return std::vector<json>({result});
}
bool first = json_value(result, "oaicompat_token_ctr", 0) == 0;
std::string modelname = json_value(result, "model", std::string(DEFAULT_OAICOMPAT_MODEL));
bool stopped_word = json_value(result, "stopped_word", false);
bool stopped_eos = json_value(result, "stopped_eos", false);
bool stopped_limit = json_value(result, "stopped_limit", false);
std::string content = json_value(result, "content", std::string(""));
std::string finish_reason;
if (stopped_word || stopped_eos) {
finish_reason = "stop";
}
if (stopped_limit) {
finish_reason = "length";
}
std::time_t t = std::time(0);
json choices;
if (!finish_reason.empty()) {
choices = json::array({json{{"finish_reason", finish_reason},
{"index", 0},
{"delta", json::object()}}});
} else {
if (first) {
if (content.empty()) {
choices = json::array({json{{"finish_reason", nullptr},
{"index", 0},
{"delta", json{{"role", "assistant"}}}}});
} else {
// We have to send this as two updates to conform to openai behavior
json initial_ret = json{{"choices", json::array({json{
{"finish_reason", nullptr},
{"index", 0},
{"delta", json{
{"role", "assistant"}
}}}})},
{"created", t},
{"id", completion_id},
{"model", modelname},
{"object", "chat.completion.chunk"}};
json second_ret = json{
{"choices", json::array({json{{"finish_reason", nullptr},
{"index", 0},
{"delta", json{
{"content", content}}}
}})},
{"created", t},
{"id", completion_id},
{"model", modelname},
{"object", "chat.completion.chunk"}};
return std::vector<json>({initial_ret, second_ret});
}
} else {
// Some idiosyncrasy in task processing logic makes several trailing calls
// with empty content, we ignore these at the callee site.
if (content.empty()) {
return std::vector<json>({json::object()});
}
choices = json::array({json{
{"finish_reason", nullptr},
{"index", 0},
{"delta",
json{
{"content", content},
}},
}});
}
}
json ret = json {
{"choices", choices},
{"created", t},
{"id", completion_id},
{"model", modelname},
{"object", "chat.completion.chunk"}
};
if (result.contains("timings")) {
ret.push_back({"timings", json_value(result, "timings", json::object())});
}
if (!finish_reason.empty()) {
int num_tokens_predicted = json_value(result, "tokens_predicted", 0);
int num_prompt_tokens = json_value(result, "tokens_evaluated", 0);
ret.push_back({"usage", json {
{"completion_tokens", num_tokens_predicted},
{"prompt_tokens", num_prompt_tokens},
{"total_tokens", num_tokens_predicted + num_prompt_tokens}
}});
}
return std::vector<json>({ret});
}
static json format_embeddings_response_oaicompat(const json & request, const json & embeddings) { static json format_embeddings_response_oaicompat(const json & request, const json & embeddings) {
json data = json::array(); json data = json::array();
int i = 0; int i = 0;
@ -934,42 +709,17 @@ static json format_detokenized_response(const std::string & content) {
}; };
} }
static json format_error_response(const std::string & message, const enum error_type type) { static json format_logit_bias(const std::vector<llama_logit_bias> & logit_bias) {
std::string type_str; json data = json::array();
int code = 500; for (const auto & lb : logit_bias) {
switch (type) { data.push_back(json{
case ERROR_TYPE_INVALID_REQUEST: {"bias", lb.bias},
type_str = "invalid_request_error"; {"token", lb.token},
code = 400; });
break;
case ERROR_TYPE_AUTHENTICATION:
type_str = "authentication_error";
code = 401;
break;
case ERROR_TYPE_NOT_FOUND:
type_str = "not_found_error";
code = 404;
break;
case ERROR_TYPE_SERVER:
type_str = "server_error";
code = 500;
break;
case ERROR_TYPE_PERMISSION:
type_str = "permission_error";
code = 403;
break;
case ERROR_TYPE_NOT_SUPPORTED:
type_str = "not_supported_error";
code = 501;
break;
case ERROR_TYPE_UNAVAILABLE:
type_str = "unavailable_error";
code = 503;
break;
} }
return json { return data;
{"code", code}, }
{"message", message},
{"type", type_str}, static std::string safe_json_to_str(json data) {
}; return data.dump(-1, ' ', false, json::error_handler_t::replace);
} }

View File

@ -103,24 +103,14 @@ extern "C" {
// Internal types and functions exposed for tests and benchmarks // Internal types and functions exposed for tests and benchmarks
typedef void (*ggml_from_float_to_mat_t)
(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t nr, int64_t k, int64_t bs);
typedef void (*ggml_vec_dot_t) (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT x, size_t bx, typedef void (*ggml_vec_dot_t) (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT x, size_t bx,
const void * GGML_RESTRICT y, size_t by, int nrc); const void * GGML_RESTRICT y, size_t by, int nrc);
typedef void (*ggml_gemv_t) (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT x,
const void * GGML_RESTRICT y, int nr, int nc);
typedef void (*ggml_gemm_t) (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT x,
const void * GGML_RESTRICT y, int nr, int nc);
struct ggml_type_traits_cpu { struct ggml_type_traits_cpu {
ggml_from_float_t from_float; ggml_from_float_t from_float;
ggml_from_float_to_mat_t from_float_to_mat;
ggml_vec_dot_t vec_dot; ggml_vec_dot_t vec_dot;
enum ggml_type vec_dot_type; enum ggml_type vec_dot_type;
int64_t nrows; // number of rows to process simultaneously int64_t nrows; // number of rows to process simultaneously
int64_t ncols; // number of columns to process simultaneously
ggml_gemv_t gemv;
ggml_gemm_t gemm;
}; };
GGML_BACKEND_API const struct ggml_type_traits_cpu * ggml_get_type_traits_cpu(enum ggml_type type); GGML_BACKEND_API const struct ggml_type_traits_cpu * ggml_get_type_traits_cpu(enum ggml_type type);
@ -140,13 +130,6 @@ extern "C" {
GGML_BACKEND_API ggml_backend_reg_t ggml_backend_cpu_reg(void); GGML_BACKEND_API ggml_backend_reg_t ggml_backend_cpu_reg(void);
#ifdef GGML_USE_CPU_HBM
GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_cpu_hbm_buffer_type(void);
#endif
GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_cpu_aarch64_buffer_type(void);
GGML_BACKEND_API bool ggml_backend_cpu_buft_is_aarch64(ggml_backend_buffer_type_t buft);
#ifdef __cplusplus #ifdef __cplusplus
} }
#endif #endif

View File

@ -384,15 +384,15 @@ extern "C" {
GGML_TYPE_F64 = 28, GGML_TYPE_F64 = 28,
GGML_TYPE_IQ1_M = 29, GGML_TYPE_IQ1_M = 29,
GGML_TYPE_BF16 = 30, GGML_TYPE_BF16 = 30,
GGML_TYPE_Q4_0_4_4 = 31, // GGML_TYPE_Q4_0_4_4 = 31, support has been removed from gguf files
GGML_TYPE_Q4_0_4_8 = 32, // GGML_TYPE_Q4_0_4_8 = 32,
GGML_TYPE_Q4_0_8_8 = 33, // GGML_TYPE_Q4_0_8_8 = 33,
GGML_TYPE_TQ1_0 = 34, GGML_TYPE_TQ1_0 = 34,
GGML_TYPE_TQ2_0 = 35, GGML_TYPE_TQ2_0 = 35,
GGML_TYPE_IQ4_NL_4_4 = 36, // GGML_TYPE_IQ4_NL_4_4 = 36,
// GGML_TYPE_IQ4_NL_4_8 = 37, // GGML_TYPE_IQ4_NL_4_8 = 37,
// GGML_TYPE_IQ4_NL_8_8 = 38, // GGML_TYPE_IQ4_NL_8_8 = 38,
GGML_TYPE_COUNT, GGML_TYPE_COUNT = 39,
}; };
// precision // precision
@ -433,9 +433,6 @@ extern "C" {
GGML_FTYPE_MOSTLY_IQ4_XS = 22, // except 1d tensors GGML_FTYPE_MOSTLY_IQ4_XS = 22, // except 1d tensors
GGML_FTYPE_MOSTLY_IQ1_M = 23, // except 1d tensors GGML_FTYPE_MOSTLY_IQ1_M = 23, // except 1d tensors
GGML_FTYPE_MOSTLY_BF16 = 24, // except 1d tensors GGML_FTYPE_MOSTLY_BF16 = 24, // except 1d tensors
GGML_FTYPE_MOSTLY_Q4_0_4_4 = 25, // except 1d tensors
GGML_FTYPE_MOSTLY_Q4_0_4_8 = 26, // except 1d tensors
GGML_FTYPE_MOSTLY_Q4_0_8_8 = 27, // except 1d tensors
}; };
// available tensor operations: // available tensor operations:
@ -2207,7 +2204,15 @@ extern "C" {
#ifdef __cplusplus #ifdef __cplusplus
// restrict not standard in C++ // restrict not standard in C++
# if defined(__GNUC__)
# define GGML_RESTRICT __restrict__
# elif defined(__clang__)
# define GGML_RESTRICT __restrict
# elif defined(_MSC_VER)
# define GGML_RESTRICT __restrict
# else
# define GGML_RESTRICT # define GGML_RESTRICT
# endif
#else #else
# define GGML_RESTRICT restrict # define GGML_RESTRICT restrict
#endif #endif

View File

@ -220,9 +220,7 @@ add_library(ggml-base
ggml-threading.cpp ggml-threading.cpp
ggml-threading.h ggml-threading.h
ggml-quants.c ggml-quants.c
ggml-quants.h ggml-quants.h)
ggml-aarch64.c
ggml-aarch64.h)
target_include_directories(ggml-base PRIVATE .) target_include_directories(ggml-base PRIVATE .)

View File

@ -1,129 +0,0 @@
#define GGML_COMMON_DECL_C
#include "ggml-common.h"
#include "ggml-aarch64.h"
#include "ggml-impl.h"
#include "ggml-quants.h"
#include <assert.h>
#define UNUSED GGML_UNUSED
// Interleave 4 block_q4_0s in chunks of blck_size_interleave bytes (4 or 8).
// The deltas of the 4 source blocks are stored first; quant chunk i is then taken
// from source block i % 4. Every quant byte is XORed with 0x88, which flips the
// top bit of both nibbles. Any other chunk size aborts via GGML_ASSERT.
static block_q4_0x4 make_block_q4_0x4(block_q4_0 * in, unsigned int blck_size_interleave) {
    block_q4_0x4 out;

    for (int row = 0; row < 4; ++row) {
        out.d[row] = in[row].d;
    }

    const int n_chunks = QK4_0 * 2 / blck_size_interleave;

    if (blck_size_interleave != 8 && blck_size_interleave != 4) {
        GGML_ASSERT(false);
    }

    for (int chunk = 0; chunk < n_chunks; ++chunk) {
        const uint8_t * from = &in[chunk % 4].qs[(chunk / 4) * blck_size_interleave];
        uint8_t       * to   = &out.qs[chunk * blck_size_interleave];

        // XOR is byte-wise, so a per-byte loop matches the 32/64-bit masked copies
        for (unsigned int j = 0; j < blck_size_interleave; ++j) {
            to[j] = from[j] ^ 0x88;
        }
    }

    return out;
}
// Interleave 8 block_q4_0s in chunks of blck_size_interleave bytes.
// Returns an interleaved block_q4_0x8: the deltas of the 8 source blocks are placed
// first, then the quants are interleaved chunk by chunk (chunk i comes from source
// block i % 8). Every quant byte is XORed with 0x88, which flips the top bit of
// both nibbles.
//
// blck_size_interleave must be non-zero and divide QK4_0 / 2 (the number of quant
// bytes in one source block); anything else aborts via GGML_ASSERT.
static block_q4_0x8 make_block_q4_0x8(block_q4_0 * in, unsigned int blck_size_interleave) {
    GGML_ASSERT(blck_size_interleave > 0 && (QK4_0 / 2) % blck_size_interleave == 0);

    block_q4_0x8 out;

    for (int i = 0; i < 8; i++) {
        out.d[i] = in[i].d;
    }

    const int end = QK4_0 * 4 / blck_size_interleave;

    for (int i = 0; i < end; ++i) {
        const int src_id     = i % 8;
        const int src_offset = (i / 8) * blck_size_interleave;
        const int dst_offset = i * blck_size_interleave;

        // Copy exactly one chunk. A fixed 8-byte copy is only correct for
        // blck_size_interleave == 8; for smaller chunks it reads past in[].qs
        // and writes past out.qs.
        for (unsigned int j = 0; j < blck_size_interleave; ++j) {
            out.qs[dst_offset + j] = in[src_id].qs[src_offset + j] ^ 0x88;
        }
    }

    return out;
}
// Quantize `nrow` rows of `n_per_row` floats to q4_0 and store them interleaved.
//
// Rows are processed in groups of `nrows_interleaved` (4 or 8). For every group and
// every QK4_0-wide column block, one q4_0 block per row is quantized into a temporary
// array and then packed into a single block_q4_0x4 / block_q4_0x8 using
// `blck_size_interleave`-byte chunks.
//
//   src                - input floats, nrow * n_per_row values
//   dst                - output buffer receiving the interleaved blocks
//   n_per_row          - must be a multiple of QK4_0
//   nrows_interleaved  - 4 or 8; nrow must be a multiple of it
//
// Returns (nrow * n_per_row) / QK4_0 * sizeof(block_q4_0).
static size_t quantize_q4_0_nr_bl(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, int nrows_interleaved, int blck_size_interleave) {
    assert(n_per_row % QK4_0 == 0);
    // validate before use: other values would leave nothing written, and
    // non-positive values would never terminate the outer loop
    GGML_ASSERT(nrows_interleaved == 4 || nrows_interleaved == 8);
    // a partial trailing group would read past src
    assert(nrow % nrows_interleaved == 0);

    // keep all element counts in 64 bits: nrow * n_per_row can exceed INT_MAX
    const int64_t nb           = n_per_row / QK4_0;
    const int64_t n_elements   = nrow * n_per_row;
    const int64_t group_stride = (int64_t) nrows_interleaved * n_per_row;

    void * out_ptr = dst;

    block_q4_0 dst_tmp[8];

    for (int64_t b = 0; b < n_elements; b += group_stride) {
        for (int64_t x = 0; x < nb; x++) {
            for (int i = 0; i < nrows_interleaved; i++) {
                quantize_row_q4_0_ref(src + b + i * n_per_row + x * QK4_0, (block_q4_0 *) dst_tmp + i, QK4_0);
            }

            if (nrows_interleaved == 8) {
                *(block_q4_0x8 *) out_ptr = make_block_q4_0x8(dst_tmp, blck_size_interleave);
                out_ptr = (block_q4_0x8 *) out_ptr + 1;
            } else {
                *(block_q4_0x4 *) out_ptr = make_block_q4_0x4(dst_tmp, blck_size_interleave);
                out_ptr = (block_q4_0x4 *) out_ptr + 1;
            }
        }
    }

    return (n_elements / QK4_0 * sizeof(block_q4_0));
}
// Quantizes `nrow` rows of `n_per_row` floats by forwarding to quantize_q4_0_nr_bl
// with 4 interleaved rows and an interleave chunk size of 4.
// `quant_weights` is accepted but not used. Returns the value from quantize_q4_0_nr_bl.
size_t quantize_q4_0_4x4(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
UNUSED(quant_weights);
return quantize_q4_0_nr_bl(src, dst, nrow, n_per_row, 4, 4);
}
// Quantizes `nrow` rows of `n_per_row` floats by forwarding to quantize_q4_0_nr_bl
// with 4 interleaved rows and an interleave chunk size of 8.
// `quant_weights` is accepted but not used. Returns the value from quantize_q4_0_nr_bl.
size_t quantize_q4_0_4x8(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
UNUSED(quant_weights);
return quantize_q4_0_nr_bl(src, dst, nrow, n_per_row, 4, 8);
}
// Quantizes `nrow` rows of `n_per_row` floats by forwarding to quantize_q4_0_nr_bl
// with 8 interleaved rows and an interleave chunk size of 8.
// `quant_weights` is accepted but not used. Returns the value from quantize_q4_0_nr_bl.
size_t quantize_q4_0_8x8(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
UNUSED(quant_weights);
return quantize_q4_0_nr_bl(src, dst, nrow, n_per_row, 8, 8);
}

View File

@ -1,19 +0,0 @@
#pragma once
#include "ggml.h"
// GGML internal header
#ifdef __cplusplus
extern "C" {
#endif
// Quantization utilizing an importance matrix (a.k.a. "Activation aWare Quantization")
size_t quantize_q4_0_4x4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
size_t quantize_q4_0_4x8(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
size_t quantize_q4_0_8x8(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
#ifdef __cplusplus
}
#endif

View File

@ -2089,7 +2089,7 @@ static void * ggml_backend_cann_reg_get_proc_address(ggml_backend_reg_t reg, con
static const ggml_backend_reg_i ggml_backend_cann_reg_interface = { static const ggml_backend_reg_i ggml_backend_cann_reg_interface = {
/* .get_name = */ ggml_backend_cann_reg_get_name, /* .get_name = */ ggml_backend_cann_reg_get_name,
/* .get_device_count = */ ggml_backend_cann_reg_get_device_count, /* .get_device_count = */ ggml_backend_cann_reg_get_device_count,
/* .get_device_get = */ ggml_backend_cann_reg_get_device, /* .get_device = */ ggml_backend_cann_reg_get_device,
/* .get_proc_address = */ ggml_backend_cann_reg_get_proc_address, /* .get_proc_address = */ ggml_backend_cann_reg_get_proc_address,
}; };

View File

@ -6,7 +6,20 @@
typedef uint16_t ggml_half; typedef uint16_t ggml_half;
typedef uint32_t ggml_half2; typedef uint32_t ggml_half2;
#define GGML_COMMON_AGGR #define GGML_COMMON_AGGR_U
#define GGML_COMMON_AGGR_S
#define GGML_COMMON_DECL
#elif defined(GGML_COMMON_DECL_CPP)
#include <cstdint>
typedef uint16_t ggml_half;
typedef uint32_t ggml_half2;
// standard C++ allows anonymous unions, but some compilers warn about them
#define GGML_COMMON_AGGR_U data
// standard C++ does not allow anonymous structs
#define GGML_COMMON_AGGR_S data
#define GGML_COMMON_DECL #define GGML_COMMON_DECL
#elif defined(GGML_COMMON_DECL_METAL) #elif defined(GGML_COMMON_DECL_METAL)
@ -15,7 +28,8 @@ typedef uint32_t ggml_half2;
typedef half ggml_half; typedef half ggml_half;
typedef half2 ggml_half2; typedef half2 ggml_half2;
#define GGML_COMMON_AGGR #define GGML_COMMON_AGGR_U
#define GGML_COMMON_AGGR_S
#define GGML_COMMON_DECL #define GGML_COMMON_DECL
#elif defined(GGML_COMMON_DECL_CUDA) #elif defined(GGML_COMMON_DECL_CUDA)
@ -29,7 +43,8 @@ typedef half2 ggml_half2;
typedef half ggml_half; typedef half ggml_half;
typedef half2 ggml_half2; typedef half2 ggml_half2;
#define GGML_COMMON_AGGR data #define GGML_COMMON_AGGR_U
#define GGML_COMMON_AGGR_S data
#define GGML_COMMON_DECL #define GGML_COMMON_DECL
#elif defined(GGML_COMMON_DECL_HIP) #elif defined(GGML_COMMON_DECL_HIP)
@ -39,7 +54,8 @@ typedef half2 ggml_half2;
typedef half ggml_half; typedef half ggml_half;
typedef half2 ggml_half2; typedef half2 ggml_half2;
#define GGML_COMMON_AGGR data #define GGML_COMMON_AGGR_U
#define GGML_COMMON_AGGR_S data
#define GGML_COMMON_DECL #define GGML_COMMON_DECL
#elif defined(GGML_COMMON_DECL_SYCL) #elif defined(GGML_COMMON_DECL_SYCL)
@ -49,7 +65,8 @@ typedef half2 ggml_half2;
typedef sycl::half ggml_half; typedef sycl::half ggml_half;
typedef sycl::half2 ggml_half2; typedef sycl::half2 ggml_half2;
#define GGML_COMMON_AGGR data #define GGML_COMMON_AGGR_U
#define GGML_COMMON_AGGR_S data
#define GGML_COMMON_DECL #define GGML_COMMON_DECL
#endif #endif
@ -154,9 +171,9 @@ typedef struct {
struct { struct {
ggml_half d; // delta ggml_half d; // delta
ggml_half m; // min ggml_half m; // min
} GGML_COMMON_AGGR; } GGML_COMMON_AGGR_S;
ggml_half2 dm; ggml_half2 dm;
}; } GGML_COMMON_AGGR_U;
uint8_t qs[QK4_1 / 2]; // nibbles / quants uint8_t qs[QK4_1 / 2]; // nibbles / quants
} block_q4_1; } block_q4_1;
static_assert(sizeof(block_q4_1) == 2 * sizeof(ggml_half) + QK4_1 / 2, "wrong q4_1 block size/padding"); static_assert(sizeof(block_q4_1) == 2 * sizeof(ggml_half) + QK4_1 / 2, "wrong q4_1 block size/padding");
@ -175,9 +192,9 @@ typedef struct {
struct { struct {
ggml_half d; // delta ggml_half d; // delta
ggml_half m; // min ggml_half m; // min
} GGML_COMMON_AGGR; } GGML_COMMON_AGGR_S;
ggml_half2 dm; ggml_half2 dm;
}; } GGML_COMMON_AGGR_U;
uint8_t qh[4]; // 5-th bit of quants uint8_t qh[4]; // 5-th bit of quants
uint8_t qs[QK5_1 / 2]; // nibbles / quants uint8_t qs[QK5_1 / 2]; // nibbles / quants
} block_q5_1; } block_q5_1;
@ -196,37 +213,13 @@ typedef struct {
struct { struct {
ggml_half d; // delta ggml_half d; // delta
ggml_half s; // d * sum(qs[i]) ggml_half s; // d * sum(qs[i])
} GGML_COMMON_AGGR; } GGML_COMMON_AGGR_S;
ggml_half2 ds; ggml_half2 ds;
}; } GGML_COMMON_AGGR_U;
int8_t qs[QK8_1]; // quants int8_t qs[QK8_1]; // quants
} block_q8_1; } block_q8_1;
static_assert(sizeof(block_q8_1) == 2*sizeof(ggml_half) + QK8_1, "wrong q8_1 block size/padding"); static_assert(sizeof(block_q8_1) == 2*sizeof(ggml_half) + QK8_1, "wrong q8_1 block size/padding");
typedef struct {
ggml_half d[4]; // deltas for 4 q4_0 blocks
uint8_t qs[QK4_0 * 2]; // nibbles / quants for 4 q4_0 blocks
} block_q4_0x4;
static_assert(sizeof(block_q4_0x4) == 4 * sizeof(ggml_half) + QK4_0 * 2, "wrong q4_0x4 block size/padding");
typedef struct {
ggml_half d[8]; // deltas for 8 q4_0 blocks
uint8_t qs[QK4_0 * 4]; // nibbles / quants for 8 q4_0 blocks
} block_q4_0x8;
static_assert(sizeof(block_q4_0x8) == 8 * sizeof(ggml_half) + QK4_0 * 4, "wrong q4_0x8 block size/padding");
typedef struct {
ggml_half d[4]; // deltas for 4 q8_0 blocks
int8_t qs[QK8_0 * 4]; // quants for 4 q8_0 blocks
} block_q8_0x4;
static_assert(sizeof(block_q8_0x4) == 4 * sizeof(ggml_half) + QK8_0 * 4, "wrong q8_0x4 block size/padding");
typedef struct {
ggml_half d[8]; // deltas for 8 q8_0 blocks
int8_t qs[QK8_0 * 8]; // quants for 8 q8_0 blocks
} block_q8_0x8;
static_assert(sizeof(block_q8_0x8) == 8 * sizeof(ggml_half) + QK8_0 * 8, "wrong q8_0x8 block size/padding");
// //
// Ternary quantization // Ternary quantization
// //
@ -261,9 +254,9 @@ typedef struct {
struct { struct {
ggml_half d; // super-block scale for quantized scales ggml_half d; // super-block scale for quantized scales
ggml_half dmin; // super-block scale for quantized mins ggml_half dmin; // super-block scale for quantized mins
} GGML_COMMON_AGGR; } GGML_COMMON_AGGR_S;
ggml_half2 dm; ggml_half2 dm;
}; } GGML_COMMON_AGGR_U;
} block_q2_K; } block_q2_K;
static_assert(sizeof(block_q2_K) == 2*sizeof(ggml_half) + QK_K/16 + QK_K/4, "wrong q2_K block size/padding"); static_assert(sizeof(block_q2_K) == 2*sizeof(ggml_half) + QK_K/16 + QK_K/4, "wrong q2_K block size/padding");
@ -288,9 +281,9 @@ typedef struct {
struct { struct {
ggml_half d; // super-block scale for quantized scales ggml_half d; // super-block scale for quantized scales
ggml_half dmin; // super-block scale for quantized mins ggml_half dmin; // super-block scale for quantized mins
} GGML_COMMON_AGGR; } GGML_COMMON_AGGR_S;
ggml_half2 dm; ggml_half2 dm;
}; } GGML_COMMON_AGGR_U;
uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits
uint8_t qs[QK_K/2]; // 4--bit quants uint8_t qs[QK_K/2]; // 4--bit quants
} block_q4_K; } block_q4_K;
@ -305,9 +298,9 @@ typedef struct {
struct { struct {
ggml_half d; // super-block scale for quantized scales ggml_half d; // super-block scale for quantized scales
ggml_half dmin; // super-block scale for quantized mins ggml_half dmin; // super-block scale for quantized mins
} GGML_COMMON_AGGR; } GGML_COMMON_AGGR_S;
ggml_half2 dm; ggml_half2 dm;
}; } GGML_COMMON_AGGR_U;
uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits
uint8_t qh[QK_K/8]; // quants, high bit uint8_t qh[QK_K/8]; // quants, high bit
uint8_t qs[QK_K/2]; // quants, low 4 bits uint8_t qs[QK_K/2]; // quants, low 4 bits
@ -418,12 +411,6 @@ typedef struct {
} block_iq4_xs; } block_iq4_xs;
static_assert(sizeof(block_iq4_xs) == sizeof(ggml_half) + sizeof(uint16_t) + QK_K/64 + QK_K/2, "wrong iq4_xs block size/padding"); static_assert(sizeof(block_iq4_xs) == sizeof(ggml_half) + sizeof(uint16_t) + QK_K/64 + QK_K/2, "wrong iq4_xs block size/padding");
typedef struct {
ggml_half d[4]; // deltas for 4 iq4_nl blocks
uint8_t qs[QK4_NL * 2];// nibbles / quants for 4 iq4_nl blocks
} block_iq4_nlx4;
static_assert(sizeof(block_iq4_nlx4) == 4 * sizeof(ggml_half) + QK4_NL * 2, "wrong iq4_nlx4 block size/padding");
#endif // GGML_COMMON_DECL #endif // GGML_COMMON_DECL
#endif // GGML_COMMON_DECL #endif // GGML_COMMON_DECL
@ -437,6 +424,13 @@ static_assert(sizeof(block_iq4_nlx4) == 4 * sizeof(ggml_half) + QK4_NL * 2, "wro
#define GGML_TABLE_BEGIN(type, name, size) static const type name[size] = { #define GGML_TABLE_BEGIN(type, name, size) static const type name[size] = {
#define GGML_TABLE_END() }; #define GGML_TABLE_END() };
#define GGML_COMMON_IMPL
#elif defined(GGML_COMMON_IMPL_CPP)
#include <cstdint>
#define GGML_TABLE_BEGIN(type, name, size) static const type name[size] = {
#define GGML_TABLE_END() };
#define GGML_COMMON_IMPL #define GGML_COMMON_IMPL
#elif defined(GGML_COMMON_IMPL_METAL) #elif defined(GGML_COMMON_IMPL_METAL)
#include <metal_stdlib> #include <metal_stdlib>

View File

@ -10,10 +10,14 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
list (APPEND GGML_CPU_SOURCES list (APPEND GGML_CPU_SOURCES
ggml-cpu/ggml-cpu.c ggml-cpu/ggml-cpu.c
ggml-cpu/ggml-cpu.cpp ggml-cpu/ggml-cpu.cpp
ggml-cpu/ggml-cpu-aarch64.c ggml-cpu/ggml-cpu-aarch64.cpp
ggml-cpu/ggml-cpu-aarch64.h ggml-cpu/ggml-cpu-aarch64.h
ggml-cpu/ggml-cpu-hbm.cpp
ggml-cpu/ggml-cpu-hbm.h
ggml-cpu/ggml-cpu-quants.c ggml-cpu/ggml-cpu-quants.c
ggml-cpu/ggml-cpu-quants.h ggml-cpu/ggml-cpu-quants.h
ggml-cpu/ggml-cpu-traits.cpp
ggml-cpu/ggml-cpu-traits.h
ggml-cpu/amx/amx.cpp ggml-cpu/amx/amx.cpp
ggml-cpu/amx/amx.h ggml-cpu/amx/amx.h
ggml-cpu/amx/mmq.cpp ggml-cpu/amx/mmq.cpp

View File

@ -5,6 +5,7 @@
#include "ggml-backend.h" #include "ggml-backend.h"
#include "ggml-impl.h" #include "ggml-impl.h"
#include "ggml-cpu.h" #include "ggml-cpu.h"
#include "ggml-cpu-traits.h"
#if defined(__gnu_linux__) #if defined(__gnu_linux__)
#include <sys/syscall.h> #include <sys/syscall.h>
@ -17,6 +18,29 @@
#if defined(__AMX_INT8__) && defined(__AVX512VNNI__) #if defined(__AMX_INT8__) && defined(__AVX512VNNI__)
// AMX type_traits
namespace ggml::cpu::amx {

// CPU tensor traits backed by the AMX kernels.
class tensor_traits : public ggml::cpu::tensor_traits {
    // Stores the work-buffer size requested by the AMX backend for `op` in `size`.
    // Always reports success; the thread count is ignored.
    bool work_size(int /* n_threads */, const struct ggml_tensor * op, size_t & size) override {
        size = ggml_backend_amx_desired_wsize(op);
        return true;
    }

    // Runs `op` through the AMX matmul and returns true; every op other than
    // GGML_OP_MUL_MAT is declined by returning false.
    bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * op) override {
        if (op->op != GGML_OP_MUL_MAT) {
            return false;
        }
        ggml_backend_amx_mul_mat(params, op);
        return true;
    }
};

// Returns the traits object for a tensor; a single function-local instance is
// handed out regardless of the buffer or tensor passed in.
static ggml::cpu::tensor_traits * get_tensor_traits(ggml_backend_buffer_t, struct ggml_tensor *) {
    static tensor_traits shared_traits;
    return &shared_traits;
}

}  // namespace ggml::cpu::amx
// AMX buffer interface // AMX buffer interface
static void ggml_backend_amx_buffer_free_buffer(ggml_backend_buffer_t buffer) { static void ggml_backend_amx_buffer_free_buffer(ggml_backend_buffer_t buffer) {
free(buffer->context); free(buffer->context);
@ -26,14 +50,23 @@ static void * ggml_backend_amx_buffer_get_base(ggml_backend_buffer_t buffer) {
return (void *) (buffer->context); return (void *) (buffer->context);
} }
static void ggml_backend_amx_buffer_memset_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) { static void ggml_backend_amx_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
tensor->extra = (void *) ggml::cpu::amx::get_tensor_traits(buffer, tensor);
GGML_UNUSED(buffer);
}
static void ggml_backend_amx_buffer_memset_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor,
uint8_t value, size_t offset, size_t size) {
memset((char *) tensor->data + offset, value, size); memset((char *) tensor->data + offset, value, size);
GGML_UNUSED(buffer); GGML_UNUSED(buffer);
} }
static void ggml_backend_amx_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) { static void ggml_backend_amx_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor,
const void * data, size_t offset, size_t size) {
if (qtype_has_amx_kernels(tensor->type)) { if (qtype_has_amx_kernels(tensor->type)) {
GGML_LOG_DEBUG("%s: amx repack tensor %s of type %s\n", __func__, tensor->name, ggml_type_name(tensor->type));
ggml_backend_amx_convert_weight(tensor, data, offset, size); ggml_backend_amx_convert_weight(tensor, data, offset, size);
} else { } else {
memcpy((char *) tensor->data + offset, data, size); memcpy((char *) tensor->data + offset, data, size);
@ -42,6 +75,8 @@ static void ggml_backend_amx_buffer_set_tensor(ggml_backend_buffer_t buffer, str
GGML_UNUSED(buffer); GGML_UNUSED(buffer);
} }
/*
// TODO: figure out what needs to be done with buffer->extra.
static void ggml_backend_amx_buffer_get_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) { static void ggml_backend_amx_buffer_get_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
GGML_ASSERT(!qtype_has_amx_kernels(tensor->type)); GGML_ASSERT(!qtype_has_amx_kernels(tensor->type));
memcpy(data, (const char *)tensor->data + offset, size); memcpy(data, (const char *)tensor->data + offset, size);
@ -62,6 +97,7 @@ static bool ggml_backend_amx_buffer_cpy_tensor(ggml_backend_buffer_t buffer, con
GGML_UNUSED(buffer); GGML_UNUSED(buffer);
} }
*/
static void ggml_backend_amx_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { static void ggml_backend_amx_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
memset(buffer->context, value, buffer->size); memset(buffer->context, value, buffer->size);
@ -70,13 +106,13 @@ static void ggml_backend_amx_buffer_clear(ggml_backend_buffer_t buffer, uint8_t
static ggml_backend_buffer_i ggml_backend_amx_buffer_interface = { static ggml_backend_buffer_i ggml_backend_amx_buffer_interface = {
/* .free_buffer = */ ggml_backend_amx_buffer_free_buffer, /* .free_buffer = */ ggml_backend_amx_buffer_free_buffer,
/* .get_base = */ ggml_backend_amx_buffer_get_base, /* .get_base = */ ggml_backend_amx_buffer_get_base,
/* .init_tensor = */ NULL, // no initialization required /* .init_tensor = */ ggml_backend_amx_buffer_init_tensor,
/* .memset_tensor = */ ggml_backend_amx_buffer_memset_tensor, /* .memset_tensor = */ ggml_backend_amx_buffer_memset_tensor,
/* .set_tensor = */ ggml_backend_amx_buffer_set_tensor, /* .set_tensor = */ ggml_backend_amx_buffer_set_tensor,
/* .get_tensor = */ ggml_backend_amx_buffer_get_tensor, /* .get_tensor = */ nullptr,
/* .cpy_tensor = */ ggml_backend_amx_buffer_cpy_tensor, /* .cpy_tensor = */ nullptr,
/* .clear = */ ggml_backend_amx_buffer_clear, /* .clear = */ ggml_backend_amx_buffer_clear,
/* .reset = */ NULL, /* .reset = */ nullptr,
}; };
static const char * ggml_backend_amx_buffer_type_get_name(ggml_backend_buffer_type_t buft) { static const char * ggml_backend_amx_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
@ -101,14 +137,44 @@ static size_t ggml_backend_amx_buffer_type_get_alignment(ggml_backend_buffer_typ
GGML_UNUSED(buft); GGML_UNUSED(buft);
} }
static size_t ggml_backend_amx_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor* tensor) { namespace ggml::cpu::amx {
return ggml_backend_amx_get_alloc_size(tensor); class extra_buffer_type : ggml::cpu::extra_buffer_type {
bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override {
// handle only 2d gemm for now
auto is_contiguous_2d = [](const struct ggml_tensor * t) {
return ggml_is_contiguous(t) && t->ne[3] == 1 && t->ne[2] == 1;
};
GGML_UNUSED(buft); if (op->op == GGML_OP_MUL_MAT && is_contiguous_2d(op->src[0]) && // src0 must be contiguous
is_contiguous_2d(op->src[1]) && // src1 must be contiguous
op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_amx_buffer_type() &&
op->ne[0] % (TILE_N * 2) == 0 && // out_features is 32x
(qtype_has_amx_kernels(op->src[0]->type) || (op->src[0]->type == GGML_TYPE_F16))) {
// src1 must be host buffer
if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
return false;
}
// src1 must be float32
if (op->src[1]->type == GGML_TYPE_F32) {
return true;
}
}
return false;
} }
static bool ggml_backend_amx_buffer_type_is_host(ggml_backend_buffer_type_t buft) { ggml::cpu::tensor_traits * get_tensor_traits(const struct ggml_tensor * op) override {
return false; if (op->op == GGML_OP_MUL_MAT && op->src[0]->buffer &&
op->src[0]->buffer->buft == ggml_backend_amx_buffer_type()) {
return (ggml::cpu::tensor_traits *) op->src[0]->extra;
}
return nullptr;
}
};
} // namespace ggml::cpu::amx
static size_t ggml_backend_amx_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
return ggml_backend_amx_get_alloc_size(tensor);
GGML_UNUSED(buft); GGML_UNUSED(buft);
} }
@ -129,68 +195,26 @@ static bool ggml_amx_init() {
return true; return true;
#endif #endif
} }
ggml_backend_buffer_type_t ggml_backend_amx_buffer_type() { ggml_backend_buffer_type_t ggml_backend_amx_buffer_type() {
static struct ggml_backend_buffer_type ggml_backend_buffer_type_amx = { static struct ggml_backend_buffer_type ggml_backend_buffer_type_amx = {
/* .iface = */ { /* .iface = */ {
/* .get_name = */ ggml_backend_amx_buffer_type_get_name, /* .get_name = */ ggml_backend_amx_buffer_type_get_name,
/* .alloc_buffer = */ ggml_backend_amx_buffer_type_alloc_buffer, /* .alloc_buffer = */ ggml_backend_amx_buffer_type_alloc_buffer,
/* .get_alignment = */ ggml_backend_amx_buffer_type_get_alignment, /* .get_alignment = */ ggml_backend_amx_buffer_type_get_alignment,
/* .get_max_size = */ NULL, // defaults to SIZE_MAX /* .get_max_size = */ nullptr, // defaults to SIZE_MAX
/* .get_alloc_size = */ ggml_backend_amx_buffer_type_get_alloc_size, /* .get_alloc_size = */ ggml_backend_amx_buffer_type_get_alloc_size,
/* .is_host = */ ggml_backend_amx_buffer_type_is_host, /* .is_host = */ nullptr,
}, },
/* .device = */ ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0), /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0),
/* .context = */ NULL, /* .context = */ new ggml::cpu::amx::extra_buffer_type(),
}; };
if (!ggml_amx_init()) { if (!ggml_amx_init()) {
return NULL; return nullptr;
} }
return &ggml_backend_buffer_type_amx; return &ggml_backend_buffer_type_amx;
} }
bool ggml_backend_amx_buft_is_amx(ggml_backend_buffer_type_t buft) {
return buft->iface.get_name == ggml_backend_amx_buffer_type_get_name;
}
bool ggml_backend_amx_device_supports_op(const struct ggml_tensor * op) {
// handle only 2d gemm for now
auto is_contiguous_2d = [](const struct ggml_tensor * t) {
return ggml_is_contiguous(t) && t->ne[3] == 1 && t->ne[2] == 1;
};
switch (op->op) {
case GGML_OP_NONE:
case GGML_OP_RESHAPE:
case GGML_OP_VIEW:
case GGML_OP_PERMUTE:
case GGML_OP_TRANSPOSE:
return true;
case GGML_OP_MUL_MAT: {
const struct ggml_tensor * src0 = op->src[0];
const struct ggml_tensor * src1 = op->src[1];
const enum ggml_type type = src0->type;
const int64_t ne0 = op->ne[0];
// amx kernels enables for Q4_0, Q4_1, Q8_0, F16
// Q4_K, Q5_K, Q6_K, IQ4_XS enabled for QK_K = 256
bool has_amx_kernels = qtype_has_amx_kernels(type) || (type == GGML_TYPE_F16);
bool can_use_amx =
is_contiguous_2d(src0) && // src0 must be contiguous
is_contiguous_2d(src1) && // src1 must be contiguous
src1->type == GGML_TYPE_F32 && // src1 must be float32
has_amx_kernels && // with amx kernel impls
ne0 % (TILE_N * 2) == 0; // out_features is 32x
return can_use_amx;
}
default:
return false;
}
}
#endif // defined(__AMX_INT8__) && defined(__AVX512VNNI__) #endif // defined(__AMX_INT8__) && defined(__AVX512VNNI__)

View File

@ -1,20 +1,8 @@
#include "ggml-backend.h" #include "ggml-backend.h"
#include "ggml-cpu-impl.h" #include "ggml-cpu-impl.h"
#ifdef __cplusplus // GGML internal header
extern "C" {
#endif
#if defined(__AMX_INT8__) && defined(__AVX512VNNI__) #if defined(__AMX_INT8__) && defined(__AVX512VNNI__)
ggml_backend_buffer_type_t ggml_backend_amx_buffer_type(void); ggml_backend_buffer_type_t ggml_backend_amx_buffer_type(void);
bool ggml_backend_amx_buft_is_amx(ggml_backend_buffer_type_t buft);
bool ggml_backend_amx_device_supports_op(const struct ggml_tensor * op);
void ggml_backend_amx_mul_mat(const struct ggml_compute_params * params, struct ggml_tensor * dst);
size_t ggml_backend_amx_desired_wsize(const struct ggml_tensor * dst);
#endif
#ifdef __cplusplus
}
#endif #endif

View File

@ -7,7 +7,7 @@
#include <memory> #include <memory>
#include <type_traits> #include <type_traits>
#if defined(_OPENMP) #if defined(GGML_USE_OPENMP)
#include <omp.h> #include <omp.h>
#endif #endif
@ -56,11 +56,11 @@ inline void balance211(T n, T nth, T ith, T& n_start, T& n_end) {
} }
template <typename func_t> template <typename func_t>
inline void parallel_for(int nth, int n, const func_t& f) { inline void parallel_for(int n, const func_t& f) {
#if defined(_OPENMP) #if defined(GGML_USE_OPENMP)
#pragma omp parallel num_threads(nth) #pragma omp parallel
{ {
//int nth = omp_get_num_threads(); int nth = omp_get_num_threads();
int ith = omp_get_thread_num(); int ith = omp_get_thread_num();
int tbegin, tend; int tbegin, tend;
balance211(n, nth, ith, tbegin, tend); balance211(n, nth, ith, tbegin, tend);
@ -68,8 +68,6 @@ inline void parallel_for(int nth, int n, const func_t& f) {
} }
#else #else
f(0, n); f(0, n);
GGML_UNUSED(nth);
#endif #endif
} }
@ -91,10 +89,3 @@ inline bool qtype_has_amx_kernels(const enum ggml_type type) {
(type == GGML_TYPE_Q6_K) || (type == GGML_TYPE_Q6_K) ||
(type == GGML_TYPE_IQ4_XS); (type == GGML_TYPE_IQ4_XS);
} }
// ggml backend context
// Holds a thread count together with an owned char buffer and a size value.
struct ggml_backend_amx_context {
int n_threads = GGML_DEFAULT_N_THREADS; // starts at GGML_DEFAULT_N_THREADS
std::unique_ptr<char[]> work_data; // owned buffer, empty by default; presumably compute scratch space — TODO confirm
size_t work_size = 0; // presumably the byte size of work_data — TODO confirm
};

View File

@ -18,10 +18,6 @@
#include <unistd.h> #include <unistd.h>
#endif #endif
#if defined(_OPENMP)
#include <omp.h>
#endif
#if (defined(_WIN32) || defined(_WIN64)) #if (defined(_WIN32) || defined(_WIN64))
#define RESTRICT __restrict #define RESTRICT __restrict
#else #else
@ -1382,13 +1378,13 @@ struct tinygemm_kernel_avx<float, ggml_fp16_t, float, BLOCK_M, BLOCK_N, BLOCK_K>
#define PACKED_INDEX(n, k, KB, tile_size) (n * KB + k) * tile_size #define PACKED_INDEX(n, k, KB, tile_size) (n * KB + k) * tile_size
template<typename TB, int BLOCK_K> template<typename TB, int BLOCK_K>
void convert_B_packed_format(void * RESTRICT packed_B, const TB * RESTRICT B, int N, int K, int n_threads) { void convert_B_packed_format(void * RESTRICT packed_B, const TB * RESTRICT B, int N, int K) {
const int NB = N / TILE_N; const int NB = N / TILE_N;
const int KB = K / BLOCK_K; const int KB = K / BLOCK_K;
const int TILE_SIZE = get_tile_size<TB>(); const int TILE_SIZE = get_tile_size<TB>();
// parallel on NB should be enough // parallel on NB should be enough
parallel_for(n_threads, NB, [&](int begin, int end) { parallel_for(NB, [&](int begin, int end) {
for (int n = begin; n < end; ++n) { for (int n = begin; n < end; ++n) {
for (int k = 0; k < KB; ++k) { for (int k = 0; k < KB; ++k) {
int n0 = n * TILE_N; int n0 = n * TILE_N;
@ -2334,15 +2330,8 @@ void ggml_backend_amx_convert_weight(struct ggml_tensor * tensor, const void * d
const int K = tensor->ne[0]; // ne0: in_features const int K = tensor->ne[0]; // ne0: in_features
const int N = tensor->ne[1]; // ne1: out_features const int N = tensor->ne[1]; // ne1: out_features
#if defined(_OPENMP)
// the buffer ctx is not initialized when .set_tensor is called
int n_threads = omp_get_num_threads();
#else
int n_threads = 1;
#endif
GGML_DISPATCH_QTYPES(TYPE, [&] { GGML_DISPATCH_QTYPES(TYPE, [&] {
convert_B_packed_format<type, blck_size>((void *)((char *)tensor->data + offset), (const type *)data, N, K, n_threads); convert_B_packed_format<type, blck_size>((void *)((char *)tensor->data + offset), (const type *)data, N, K);
}); });
} }

View File

@ -1,16 +1,10 @@
#pragma once #pragma once
#include "common.h" #include "common.h"
#ifdef __cplusplus size_t ggml_backend_amx_desired_wsize(const struct ggml_tensor * dst);
extern "C" {
#endif
size_t ggml_backend_amx_get_alloc_size(const struct ggml_tensor * tensor); size_t ggml_backend_amx_get_alloc_size(const struct ggml_tensor * tensor);
void ggml_backend_amx_convert_weight(struct ggml_tensor * tensor, const void * data, size_t offset, size_t size); void ggml_backend_amx_convert_weight(struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
void ggml_backend_amx_mul_mat(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_backend_amx_mul_mat(const struct ggml_compute_params * params, struct ggml_tensor * dst);
#ifdef __cplusplus
}
#endif

View File

@ -1,20 +1,57 @@
#define GGML_COMMON_IMPL_C #define GGML_COMMON_IMPL_CPP
#define GGML_COMMON_DECL_CPP
#include "ggml-common.h" #include "ggml-common.h"
#include "ggml-backend-impl.h"
#include "ggml-quants.h" #include "ggml-quants.h"
#include "ggml-impl.h" #include "ggml-impl.h"
#include "ggml-cpu.h" #include "ggml-cpu.h"
#include "ggml-cpu/ggml-cpu-impl.h" #include "ggml-cpu-impl.h"
#include "ggml-cpu-traits.h"
#include <math.h> #include <cmath>
#include <string.h> #include <cstring>
#include <assert.h> #include <cassert>
#include <float.h> #include <cfloat>
#include <stdlib.h> // for qsort #include <cstdlib> // for qsort
#include <stdio.h> // for GGML_ASSERT #include <cstdio> // for GGML_ASSERT
#include "ggml-cpu-aarch64.h" #include "ggml-cpu-aarch64.h"
// TODO: move to include file?
// Compile-time lookup of the block length constant for K-bit "_0" quants:
// QK4_0 for K == 4, QK8_0 for K == 8, and -1 for any other K.
template <int K> constexpr int QK_0() {
    if constexpr (K == 4) {
        return QK4_0;
    } else if constexpr (K == 8) {
        return QK8_0;
    } else {
        return -1;
    }
}
// Group of N qK_0 blocks stored together: the N deltas first, then all quants.
// The quant array holds QK_0<K>() values per block, N blocks, K bits each,
// hence (QK_0<K>() * N * K) / 8 bytes.
template <int K, int N> struct block {
ggml_half d[N]; // deltas for N qK_0 blocks
int8_t qs[(QK_0<K>() * N * K) / 8]; // quants for N qK_0 blocks
};
// control size: every block<K, N> instantiation used below must be exactly its N deltas
// plus its quant bytes, i.e. the compiler must not have added any padding.
// NOTE(review): all expected quant sizes are written in terms of QK8_0, including the
// K == 4 variants - presumably this relies on QK4_0 == QK8_0; confirm.
static_assert(sizeof(block<4, 4>) == 4 * sizeof(ggml_half) + QK8_0 * 2, "wrong block<4,4> size/padding");
static_assert(sizeof(block<4, 8>) == 8 * sizeof(ggml_half) + QK8_0 * 4, "wrong block<4,8> size/padding");
static_assert(sizeof(block<8, 4>) == 4 * sizeof(ggml_half) + QK8_0 * 4, "wrong block<8,4> size/padding");
static_assert(sizeof(block<8, 8>) == 8 * sizeof(ggml_half) + QK8_0 * 8, "wrong block<8,8> size/padding");
// Named aliases for the block<K, N> instantiations: block_q<K>_0x<N> is block<K, N>.
using block_q4_0x4 = block<4, 4>;
using block_q4_0x8 = block<4, 8>;
using block_q8_0x4 = block<8, 4>;
using block_q8_0x8 = block<8, 8>;
// Group of 4 iq4_nl blocks: 4 deltas followed by QK4_NL * 2 bytes of quants.
// Declared as its own struct rather than through block<K, N>; note that it stores
// its quants as uint8_t, whereas block<K, N> uses int8_t.
struct block_iq4_nlx4 {
ggml_half d[4]; // deltas for 4 iq4_nl blocks
uint8_t qs[QK4_NL * 2]; // nibbles / quants for 4 iq4_nl blocks
};
// The struct must be exactly its deltas plus its quant bytes, with no padding.
static_assert(sizeof(block_iq4_nlx4) == 4 * sizeof(ggml_half) + QK4_NL * 2, "wrong iq4_nlx4 block size/padding");
#if defined(__GNUC__) #if defined(__GNUC__)
#pragma GCC diagnostic ignored "-Woverlength-strings" #pragma GCC diagnostic ignored "-Woverlength-strings"
#elif defined(_MSC_VER) #elif defined(_MSC_VER)
@ -185,12 +222,12 @@ static inline __m256i mul_sum_i8_pairs_int32x8(const __m256i x, const __m256i y)
static const int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113}; static const int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
static void quantize_q8_0_4x4(const float * restrict x, void * restrict vy, int64_t k) { static void quantize_q8_0_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
assert(QK8_0 == 32); assert(QK8_0 == 32);
assert(k % QK8_0 == 0); assert(k % QK8_0 == 0);
const int nb = k / QK8_0; const int nb = k / QK8_0;
block_q8_0x4 * restrict y = (block_q8_0x4 *) vy; block_q8_0x4 * GGML_RESTRICT y = (block_q8_0x4 *) vy;
#if defined(__ARM_NEON) #if defined(__ARM_NEON)
float32x4_t srcv[4][8]; float32x4_t srcv[4][8];
@ -279,12 +316,12 @@ static void quantize_q8_0_4x4(const float * restrict x, void * restrict vy, int6
#endif #endif
} }
static void quantize_q8_0_4x8(const float * restrict x, void * restrict vy, int64_t k) { static void quantize_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
assert(QK8_0 == 32); assert(QK8_0 == 32);
assert(k % QK8_0 == 0); assert(k % QK8_0 == 0);
const int nb = k / QK8_0; const int nb = k / QK8_0;
block_q8_0x4 * restrict y = (block_q8_0x4 *) vy; block_q8_0x4 * GGML_RESTRICT y = (block_q8_0x4 *) vy;
#if defined(__ARM_NEON) #if defined(__ARM_NEON)
float32x4_t srcv[4][8]; float32x4_t srcv[4][8];
@ -494,7 +531,7 @@ static void quantize_q8_0_4x8(const float * restrict x, void * restrict vy, int6
#endif #endif
} }
void quantize_mat_q8_0(const float * restrict x, void * restrict vy, int64_t nrow, int64_t n_per_row, int64_t blck_size_interleave) { static void quantize_mat_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row, int64_t blck_size_interleave) {
assert(nrow == 4); assert(nrow == 4);
UNUSED(nrow); UNUSED(nrow);
if (blck_size_interleave == 4) { if (blck_size_interleave == 4) {
@ -506,7 +543,7 @@ void quantize_mat_q8_0(const float * restrict x, void * restrict vy, int64_t nro
} }
} }
void ggml_gemv_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) { static void ggml_gemv_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
const int qk = QK8_0; const int qk = QK8_0;
const int nb = n / qk; const int nb = n / qk;
const int ncols_interleaved = 4; const int ncols_interleaved = 4;
@ -591,7 +628,7 @@ void ggml_gemv_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void *
} }
} }
void ggml_gemv_q4_0_4x8_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) { static void ggml_gemv_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
const int qk = QK8_0; const int qk = QK8_0;
const int nb = n / qk; const int nb = n / qk;
const int ncols_interleaved = 4; const int ncols_interleaved = 4;
@ -701,7 +738,7 @@ void ggml_gemv_q4_0_4x8_q8_0(int n, float * restrict s, size_t bs, const void *
} }
} }
void ggml_gemv_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) { static void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
const int qk = QK8_0; const int qk = QK8_0;
const int nb = n / qk; const int nb = n / qk;
const int ncols_interleaved = 8; const int ncols_interleaved = 8;
@ -974,7 +1011,7 @@ void ggml_gemv_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
} }
} }
void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) { static void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
const int qk = QK8_0; const int qk = QK8_0;
const int nb = n / qk; const int nb = n / qk;
const int ncols_interleaved = 4; const int ncols_interleaved = 4;
@ -1070,7 +1107,7 @@ void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * restrict s, size_t bs, const void
} }
} }
void ggml_gemm_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) { static void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
const int qk = QK8_0; const int qk = QK8_0;
const int nb = n / qk; const int nb = n / qk;
const int ncols_interleaved = 4; const int ncols_interleaved = 4;
@ -1586,7 +1623,7 @@ void ggml_gemm_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void *
} }
} }
void ggml_gemm_q4_0_4x8_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) { static void ggml_gemm_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
const int qk = QK8_0; const int qk = QK8_0;
const int nb = n / qk; const int nb = n / qk;
const int ncols_interleaved = 4; const int ncols_interleaved = 4;
@ -2040,7 +2077,7 @@ void ggml_gemm_q4_0_4x8_q8_0(int n, float * restrict s, size_t bs, const void *
} }
} }
void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) { static void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
const int qk = QK8_0; const int qk = QK8_0;
const int nb = n / qk; const int nb = n / qk;
const int ncols_interleaved = 8; const int ncols_interleaved = 8;
@ -2560,31 +2597,31 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
const __m512i rhs_mat_2367ABEF_3 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_1, 4), m4bexpanded)); //B2(24-31) B3(24-31) B6(24-31) B7(24-31) BA(24-31) BB(24-31) BE(24-31) BF(24-31) const __m512i rhs_mat_2367ABEF_3 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_1, 4), m4bexpanded)); //B2(24-31) B3(24-31) B6(24-31) B7(24-31) BA(24-31) BB(24-31) BE(24-31) BF(24-31)
// Shuffle pattern one - right side input // Shuffle pattern one - right side input
const __m512i rhs_mat_014589CD_0_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, 136); //B0(0-3) B1(0-3) B0(0-3) B1(0-3) B4(0-3) B5(0-3) B4(0-3) B5(0-3) B8(0-3) B9(0-3) B8(0-3) B9(0-3) BC(0-3) BD(0-3) BC(0-3) BD(0-3) const __m512i rhs_mat_014589CD_0_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, (_MM_PERM_ENUM)136); //B0(0-3) B1(0-3) B0(0-3) B1(0-3) B4(0-3) B5(0-3) B4(0-3) B5(0-3) B8(0-3) B9(0-3) B8(0-3) B9(0-3) BC(0-3) BD(0-3) BC(0-3) BD(0-3)
const __m512i rhs_mat_2367ABEF_0_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, 136); //B2(0-3) B3(0-3) B2(0-3) B3(0-3) B6(0-3) B7(0-3) B6(0-3) B7(0-3) BA(0-3) BB(0-3) BA(0-3) BB(0-3) BE(0-3) BF(0-3) BE(0-3) BF(0-3) const __m512i rhs_mat_2367ABEF_0_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, (_MM_PERM_ENUM)136); //B2(0-3) B3(0-3) B2(0-3) B3(0-3) B6(0-3) B7(0-3) B6(0-3) B7(0-3) BA(0-3) BB(0-3) BA(0-3) BB(0-3) BE(0-3) BF(0-3) BE(0-3) BF(0-3)
const __m512i rhs_mat_014589CD_1_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, 136); //B0(8-11) B1(8-11) B0(8-11) B1(8-11) B4(8-11) B5(8-11) B4(8-11) B5(8-11) B8(8-11) B9(8-11) B8(8-11) B9(8-11) BC(8-11) BD(8-11) BC(8-11) BD(8-11) const __m512i rhs_mat_014589CD_1_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, (_MM_PERM_ENUM)136); //B0(8-11) B1(8-11) B0(8-11) B1(8-11) B4(8-11) B5(8-11) B4(8-11) B5(8-11) B8(8-11) B9(8-11) B8(8-11) B9(8-11) BC(8-11) BD(8-11) BC(8-11) BD(8-11)
const __m512i rhs_mat_2367ABEF_1_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, 136); //B2(8-11) B3(8-11) B2(8-11) B3(8-11) B6(8-11) B7(8-11) B6(8-11) B7(8-11) BA(8-11) BB(8-11) BA(8-11) BB(8-11) BE(8-11) BF(8-11) BE(8-11) BF(8-11) const __m512i rhs_mat_2367ABEF_1_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, (_MM_PERM_ENUM)136); //B2(8-11) B3(8-11) B2(8-11) B3(8-11) B6(8-11) B7(8-11) B6(8-11) B7(8-11) BA(8-11) BB(8-11) BA(8-11) BB(8-11) BE(8-11) BF(8-11) BE(8-11) BF(8-11)
const __m512i rhs_mat_014589CD_2_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, 136); //B0(16-19) B1(16-19) B0(16-19) B1(16-19) B4(16-19) B5(16-19) B4(16-19) B5(16-19) B8(16-19) B9(16-19) B8(16-19) B9(16-19) BC(16-19) BD(16-19) BC(16-19) BD(16-19) const __m512i rhs_mat_014589CD_2_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, (_MM_PERM_ENUM)136); //B0(16-19) B1(16-19) B0(16-19) B1(16-19) B4(16-19) B5(16-19) B4(16-19) B5(16-19) B8(16-19) B9(16-19) B8(16-19) B9(16-19) BC(16-19) BD(16-19) BC(16-19) BD(16-19)
const __m512i rhs_mat_2367ABEF_2_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, 136); //B2(16-19) B3(16-19) B2(16-19) B3(16-19) B6(16-19) B7(16-19) B6(16-19) B7(16-19) BA(16-19) BB(16-19) BA(16-19) BB(16-19) BE(16-19) BF(16-19) BE(16-19) BF(16-19) const __m512i rhs_mat_2367ABEF_2_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, (_MM_PERM_ENUM)136); //B2(16-19) B3(16-19) B2(16-19) B3(16-19) B6(16-19) B7(16-19) B6(16-19) B7(16-19) BA(16-19) BB(16-19) BA(16-19) BB(16-19) BE(16-19) BF(16-19) BE(16-19) BF(16-19)
const __m512i rhs_mat_014589CD_3_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, 136); //B0(24-27) B1(24-27) B0(24-27) B1(24-27) B4(24-27) B5(24-27) B4(24-27) B5(24-27) B8(24-27) B9(24-27) B8(24-27) B9(24-27) BC(24-27) BD(24-27) BC(24-27) BD(24-27) const __m512i rhs_mat_014589CD_3_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, (_MM_PERM_ENUM)136); //B0(24-27) B1(24-27) B0(24-27) B1(24-27) B4(24-27) B5(24-27) B4(24-27) B5(24-27) B8(24-27) B9(24-27) B8(24-27) B9(24-27) BC(24-27) BD(24-27) BC(24-27) BD(24-27)
const __m512i rhs_mat_2367ABEF_3_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, 136); //B2(24-27) B3(24-27) B2(24-27) B3(24-27) B6(24-27) B7(24-27) B6(24-27) B7(24-27) BA(24-27) BB(24-27) BA(24-27) BB(24-27) BE(24-27) BF(24-27) BE(24-27) BF(24-27) const __m512i rhs_mat_2367ABEF_3_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, (_MM_PERM_ENUM)136); //B2(24-27) B3(24-27) B2(24-27) B3(24-27) B6(24-27) B7(24-27) B6(24-27) B7(24-27) BA(24-27) BB(24-27) BA(24-27) BB(24-27) BE(24-27) BF(24-27) BE(24-27) BF(24-27)
// Shuffle pattern two - right side input // Shuffle pattern two - right side input
const __m512i rhs_mat_014589CD_0_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, 221); //B0(4-7) B1(4-7) B0(4-7) B1(4-7) B4(4-7) B5(4-7) B4(4-7) B5(4-7) B8(4-7) B9(4-7) B8(4-7) B9(4-7) BC(4-7) BD(4-7) BC(4-7) BD(4-7) const __m512i rhs_mat_014589CD_0_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, (_MM_PERM_ENUM)221); //B0(4-7) B1(4-7) B0(4-7) B1(4-7) B4(4-7) B5(4-7) B4(4-7) B5(4-7) B8(4-7) B9(4-7) B8(4-7) B9(4-7) BC(4-7) BD(4-7) BC(4-7) BD(4-7)
const __m512i rhs_mat_2367ABEF_0_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, 221); //B2(4-7) B3(4-7) B2(4-7) B3(4-7) B6(4-7) B7(4-7) B6(4-7) B7(4-7) BA(4-7) BB(4-7) BA(4-7) BB(4-7) BE(4-7) BF(4-7) BE(4-7) BF(4-7) const __m512i rhs_mat_2367ABEF_0_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, (_MM_PERM_ENUM)221); //B2(4-7) B3(4-7) B2(4-7) B3(4-7) B6(4-7) B7(4-7) B6(4-7) B7(4-7) BA(4-7) BB(4-7) BA(4-7) BB(4-7) BE(4-7) BF(4-7) BE(4-7) BF(4-7)
const __m512i rhs_mat_014589CD_1_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, 221); //B0(12-15) B1(12-15) B0(12-15) B1(12-15) B4(12-15) B5(12-15) B4(12-15) B5(12-15) B8(12-15) B9(12-15) B8(12-15) B9(12-15) BC(12-15) BD(12-15) BC(12-15) BD(12-15) const __m512i rhs_mat_014589CD_1_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, (_MM_PERM_ENUM)221); //B0(12-15) B1(12-15) B0(12-15) B1(12-15) B4(12-15) B5(12-15) B4(12-15) B5(12-15) B8(12-15) B9(12-15) B8(12-15) B9(12-15) BC(12-15) BD(12-15) BC(12-15) BD(12-15)
const __m512i rhs_mat_2367ABEF_1_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, 221); //B2(12-15) B3(12-15) B2(12-15) B3(12-15) B6(12-15) B7(12-15) B6(12-15) B7(12-15) BA(12-15) BB(12-15) BA(12-15) BB(12-15) BE(12-15) BF(12-15) BE(12-15) BF(12-15) const __m512i rhs_mat_2367ABEF_1_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, (_MM_PERM_ENUM)221); //B2(12-15) B3(12-15) B2(12-15) B3(12-15) B6(12-15) B7(12-15) B6(12-15) B7(12-15) BA(12-15) BB(12-15) BA(12-15) BB(12-15) BE(12-15) BF(12-15) BE(12-15) BF(12-15)
const __m512i rhs_mat_014589CD_2_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, 221); //B0(20-23) B1(20-23) B0(20-23) B1(20-23) B4(20-23) B5(20-23) B4(20-23) B5(20-23) B8(20-23) B9(20-23) B8(20-23) B9(20-23) BC(20-23) BD(20-23) BC(20-23) BD(20-23) const __m512i rhs_mat_014589CD_2_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, (_MM_PERM_ENUM)221); //B0(20-23) B1(20-23) B0(20-23) B1(20-23) B4(20-23) B5(20-23) B4(20-23) B5(20-23) B8(20-23) B9(20-23) B8(20-23) B9(20-23) BC(20-23) BD(20-23) BC(20-23) BD(20-23)
const __m512i rhs_mat_2367ABEF_2_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, 221); //B2(20-23) B3(20-23) B2(20-23) B3(20-23) B6(20-23) B7(20-23) B6(20-23) B7(20-23) BA(20-23) BB(20-23) BA(20-23) BB(20-23) BE(20-23) BF(20-23) BE(20-23) BF(20-23) const __m512i rhs_mat_2367ABEF_2_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, (_MM_PERM_ENUM)221); //B2(20-23) B3(20-23) B2(20-23) B3(20-23) B6(20-23) B7(20-23) B6(20-23) B7(20-23) BA(20-23) BB(20-23) BA(20-23) BB(20-23) BE(20-23) BF(20-23) BE(20-23) BF(20-23)
const __m512i rhs_mat_014589CD_3_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, 221); //B0(28-31) B1(28-31) B0(28-31) B1(28-31) B4(28-31) B5(28-31) B4(28-31) B5(28-31) B8(28-31) B9(28-31) B8(28-31) B9(28-31) BC(28-31) BD(28-31) BC(28-31) BD(28-31) const __m512i rhs_mat_014589CD_3_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, (_MM_PERM_ENUM)221); //B0(28-31) B1(28-31) B0(28-31) B1(28-31) B4(28-31) B5(28-31) B4(28-31) B5(28-31) B8(28-31) B9(28-31) B8(28-31) B9(28-31) BC(28-31) BD(28-31) BC(28-31) BD(28-31)
const __m512i rhs_mat_2367ABEF_3_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, 221); //B2(28-31) B3(28-31) B2(28-31) B3(28-31) B6(28-31) B7(28-31) B6(28-31) B7(28-31) BA(28-31) BB(28-31) BA(28-31) BB(28-31) BE(28-31) BF(28-31) BE(28-31) BF(28-31) const __m512i rhs_mat_2367ABEF_3_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, (_MM_PERM_ENUM)221); //B2(28-31) B3(28-31) B2(28-31) B3(28-31) B6(28-31) B7(28-31) B6(28-31) B7(28-31) BA(28-31) BB(28-31) BA(28-31) BB(28-31) BE(28-31) BF(28-31) BE(28-31) BF(28-31)
// Scale values - Load the weight scale values of two block_q4_0x8 // Scale values - Load the weight scale values of two block_q4_0x8
const __m512 col_scale_f32 = GGML_F32Cx8x2_LOAD(b_ptr_0[b].d, b_ptr_1[b].d); const __m512 col_scale_f32 = GGML_F32Cx8x2_LOAD(b_ptr_0[b].d, b_ptr_1[b].d);
@ -2618,31 +2655,31 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
// Shuffle pattern one - left side input // Shuffle pattern one - left side input
const __m512i lhs_mat_01_0_sp1 = _mm512_shuffle_epi32(lhs_mat_01_0, 160); //A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) const __m512i lhs_mat_01_0_sp1 = _mm512_shuffle_epi32(lhs_mat_01_0, (_MM_PERM_ENUM)160); //A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3)
const __m512i lhs_mat_23_0_sp1 = _mm512_shuffle_epi32(lhs_mat_23_0, 160); //A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) const __m512i lhs_mat_23_0_sp1 = _mm512_shuffle_epi32(lhs_mat_23_0, (_MM_PERM_ENUM)160); //A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3)
const __m512i lhs_mat_01_1_sp1 = _mm512_shuffle_epi32(lhs_mat_01_1, 160); //A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) const __m512i lhs_mat_01_1_sp1 = _mm512_shuffle_epi32(lhs_mat_01_1, (_MM_PERM_ENUM)160); //A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11)
const __m512i lhs_mat_23_1_sp1 = _mm512_shuffle_epi32(lhs_mat_23_1, 160); //A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) const __m512i lhs_mat_23_1_sp1 = _mm512_shuffle_epi32(lhs_mat_23_1, (_MM_PERM_ENUM)160); //A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11)
const __m512i lhs_mat_01_2_sp1 = _mm512_shuffle_epi32(lhs_mat_01_2, 160); //A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) const __m512i lhs_mat_01_2_sp1 = _mm512_shuffle_epi32(lhs_mat_01_2, (_MM_PERM_ENUM)160); //A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19)
const __m512i lhs_mat_23_2_sp1 = _mm512_shuffle_epi32(lhs_mat_23_2, 160); //A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) const __m512i lhs_mat_23_2_sp1 = _mm512_shuffle_epi32(lhs_mat_23_2, (_MM_PERM_ENUM)160); //A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19)
const __m512i lhs_mat_01_3_sp1 = _mm512_shuffle_epi32(lhs_mat_01_3, 160); //A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) const __m512i lhs_mat_01_3_sp1 = _mm512_shuffle_epi32(lhs_mat_01_3, (_MM_PERM_ENUM)160); //A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27)
const __m512i lhs_mat_23_3_sp1 = _mm512_shuffle_epi32(lhs_mat_23_3, 160); //A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) const __m512i lhs_mat_23_3_sp1 = _mm512_shuffle_epi32(lhs_mat_23_3, (_MM_PERM_ENUM)160); //A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27)
// Shuffle pattern two - left side input // Shuffle pattern two - left side input
const __m512i lhs_mat_01_0_sp2 = _mm512_shuffle_epi32(lhs_mat_01_0, 245); //A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) const __m512i lhs_mat_01_0_sp2 = _mm512_shuffle_epi32(lhs_mat_01_0, (_MM_PERM_ENUM)245); //A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7)
const __m512i lhs_mat_23_0_sp2 = _mm512_shuffle_epi32(lhs_mat_23_0, 245); //A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) const __m512i lhs_mat_23_0_sp2 = _mm512_shuffle_epi32(lhs_mat_23_0, (_MM_PERM_ENUM)245); //A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7)
const __m512i lhs_mat_01_1_sp2 = _mm512_shuffle_epi32(lhs_mat_01_1, 245); //A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) const __m512i lhs_mat_01_1_sp2 = _mm512_shuffle_epi32(lhs_mat_01_1, (_MM_PERM_ENUM)245); //A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15)
const __m512i lhs_mat_23_1_sp2 = _mm512_shuffle_epi32(lhs_mat_23_1, 245); //A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) const __m512i lhs_mat_23_1_sp2 = _mm512_shuffle_epi32(lhs_mat_23_1, (_MM_PERM_ENUM)245); //A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15)
const __m512i lhs_mat_01_2_sp2 = _mm512_shuffle_epi32(lhs_mat_01_2, 245); //A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) const __m512i lhs_mat_01_2_sp2 = _mm512_shuffle_epi32(lhs_mat_01_2, (_MM_PERM_ENUM)245); //A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23)
const __m512i lhs_mat_23_2_sp2 = _mm512_shuffle_epi32(lhs_mat_23_2, 245); //A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) const __m512i lhs_mat_23_2_sp2 = _mm512_shuffle_epi32(lhs_mat_23_2, (_MM_PERM_ENUM)245); //A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23)
const __m512i lhs_mat_01_3_sp2 = _mm512_shuffle_epi32(lhs_mat_01_3, 245); //A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) const __m512i lhs_mat_01_3_sp2 = _mm512_shuffle_epi32(lhs_mat_01_3, (_MM_PERM_ENUM)245); //A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31)
const __m512i lhs_mat_23_3_sp2 = _mm512_shuffle_epi32(lhs_mat_23_3, 245); //A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) const __m512i lhs_mat_23_3_sp2 = _mm512_shuffle_epi32(lhs_mat_23_3, (_MM_PERM_ENUM)245); //A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31)
// The values arranged in shuffle patterns are operated with dot product operation within 32 bit lane i.e corresponding bytes and multiplied and added into 32 bit integers within 32 bit lane // The values arranged in shuffle patterns are operated with dot product operation within 32 bit lane i.e corresponding bytes and multiplied and added into 32 bit integers within 32 bit lane
// Resembles MMLAs into 2x2 matrices in ARM Version // Resembles MMLAs into 2x2 matrices in ARM Version
@ -2671,10 +2708,10 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
// Straighten out to make 4 row vectors // Straighten out to make 4 row vectors
__m512i iacc_row_0 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_00, _mm512_shuffle_epi32(iacc_mat_01, 78)); __m512i iacc_row_0 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_00, _mm512_shuffle_epi32(iacc_mat_01, (_MM_PERM_ENUM)78));
__m512i iacc_row_1 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_00, 78), iacc_mat_01); __m512i iacc_row_1 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_00, (_MM_PERM_ENUM)78), iacc_mat_01);
__m512i iacc_row_2 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_10, _mm512_shuffle_epi32(iacc_mat_11, 78)); __m512i iacc_row_2 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_10, _mm512_shuffle_epi32(iacc_mat_11, (_MM_PERM_ENUM)78));
__m512i iacc_row_3 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_10, 78), iacc_mat_11); __m512i iacc_row_3 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_10, (_MM_PERM_ENUM)78), iacc_mat_11);
// Load the scale(d) values for all the 4 Q8_0 blocks and repeat it across lanes // Load the scale(d) values for all the 4 Q8_0 blocks and repeat it across lanes
const __m128i row_scale_f16 = _mm_shuffle_epi32(_mm_maskload_epi32((int const*)(a_ptrs[rp][b].d), loadMask), 68); const __m128i row_scale_f16 = _mm_shuffle_epi32(_mm_maskload_epi32((int const*)(a_ptrs[rp][b].d), loadMask), 68);
@ -2753,31 +2790,31 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
const __m512i rhs_mat_2367ABEF_3 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_1, 4), m4bexpanded)); //B2(24-31) B3(24-31) B6(24-31) B7(24-31) BA(24-31) BB(24-31) BE(24-31) BF(24-31) const __m512i rhs_mat_2367ABEF_3 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_1, 4), m4bexpanded)); //B2(24-31) B3(24-31) B6(24-31) B7(24-31) BA(24-31) BB(24-31) BE(24-31) BF(24-31)
// Shuffle pattern one - right side input // Shuffle pattern one - right side input
const __m512i rhs_mat_014589CD_0_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, 136); //B0(0-3) B1(0-3) B0(0-3) B1(0-3) B4(0-3) B5(0-3) B4(0-3) B5(0-3) B8(0-3) B9(0-3) B8(0-3) B9(0-3) BC(0-3) BD(0-3) BC(0-3) BD(0-3) const __m512i rhs_mat_014589CD_0_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, (_MM_PERM_ENUM)136); //B0(0-3) B1(0-3) B0(0-3) B1(0-3) B4(0-3) B5(0-3) B4(0-3) B5(0-3) B8(0-3) B9(0-3) B8(0-3) B9(0-3) BC(0-3) BD(0-3) BC(0-3) BD(0-3)
const __m512i rhs_mat_2367ABEF_0_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, 136); //B2(0-3) B3(0-3) B2(0-3) B3(0-3) B6(0-3) B7(0-3) B6(0-3) B7(0-3) BA(0-3) BB(0-3) BA(0-3) BB(0-3) BE(0-3) BF(0-3) BE(0-3) BF(0-3) const __m512i rhs_mat_2367ABEF_0_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, (_MM_PERM_ENUM)136); //B2(0-3) B3(0-3) B2(0-3) B3(0-3) B6(0-3) B7(0-3) B6(0-3) B7(0-3) BA(0-3) BB(0-3) BA(0-3) BB(0-3) BE(0-3) BF(0-3) BE(0-3) BF(0-3)
const __m512i rhs_mat_014589CD_1_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, 136); //B0(8-11) B1(8-11) B0(8-11) B1(8-11) B4(8-11) B5(8-11) B4(8-11) B5(8-11) B8(8-11) B9(8-11) B8(8-11) B9(8-11) BC(8-11) BD(8-11) BC(8-11) BD(8-11) const __m512i rhs_mat_014589CD_1_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, (_MM_PERM_ENUM)136); //B0(8-11) B1(8-11) B0(8-11) B1(8-11) B4(8-11) B5(8-11) B4(8-11) B5(8-11) B8(8-11) B9(8-11) B8(8-11) B9(8-11) BC(8-11) BD(8-11) BC(8-11) BD(8-11)
const __m512i rhs_mat_2367ABEF_1_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, 136); //B2(8-11) B3(8-11) B2(8-11) B3(8-11) B6(8-11) B7(8-11) B6(8-11) B7(8-11) BA(8-11) BB(8-11) BA(8-11) BB(8-11) BE(8-11) BF(8-11) BE(8-11) BF(8-11) const __m512i rhs_mat_2367ABEF_1_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, (_MM_PERM_ENUM)136); //B2(8-11) B3(8-11) B2(8-11) B3(8-11) B6(8-11) B7(8-11) B6(8-11) B7(8-11) BA(8-11) BB(8-11) BA(8-11) BB(8-11) BE(8-11) BF(8-11) BE(8-11) BF(8-11)
const __m512i rhs_mat_014589CD_2_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, 136); //B0(16-19) B1(16-19) B0(16-19) B1(16-19) B4(16-19) B5(16-19) B4(16-19) B5(16-19) B8(16-19) B9(16-19) B8(16-19) B9(16-19) BC(16-19) BD(16-19) BC(16-19) BD(16-19) const __m512i rhs_mat_014589CD_2_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, (_MM_PERM_ENUM)136); //B0(16-19) B1(16-19) B0(16-19) B1(16-19) B4(16-19) B5(16-19) B4(16-19) B5(16-19) B8(16-19) B9(16-19) B8(16-19) B9(16-19) BC(16-19) BD(16-19) BC(16-19) BD(16-19)
const __m512i rhs_mat_2367ABEF_2_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, 136); //B2(16-19) B3(16-19) B2(16-19) B3(16-19) B6(16-19) B7(16-19) B6(16-19) B7(16-19) BA(16-19) BB(16-19) BA(16-19) BB(16-19) BE(16-19) BF(16-19) BE(16-19) BF(16-19) const __m512i rhs_mat_2367ABEF_2_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, (_MM_PERM_ENUM)136); //B2(16-19) B3(16-19) B2(16-19) B3(16-19) B6(16-19) B7(16-19) B6(16-19) B7(16-19) BA(16-19) BB(16-19) BA(16-19) BB(16-19) BE(16-19) BF(16-19) BE(16-19) BF(16-19)
const __m512i rhs_mat_014589CD_3_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, 136); //B0(24-27) B1(24-27) B0(24-27) B1(24-27) B4(24-27) B5(24-27) B4(24-27) B5(24-27) B8(24-27) B9(24-27) B8(24-27) B9(24-27) BC(24-27) BD(24-27) BC(24-27) BD(24-27) const __m512i rhs_mat_014589CD_3_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, (_MM_PERM_ENUM)136); //B0(24-27) B1(24-27) B0(24-27) B1(24-27) B4(24-27) B5(24-27) B4(24-27) B5(24-27) B8(24-27) B9(24-27) B8(24-27) B9(24-27) BC(24-27) BD(24-27) BC(24-27) BD(24-27)
const __m512i rhs_mat_2367ABEF_3_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, 136); //B2(24-27) B3(24-27) B2(24-27) B3(24-27) B6(24-27) B7(24-27) B6(24-27) B7(24-27) BA(24-27) BB(24-27) BA(24-27) BB(24-27) BE(24-27) BF(24-27) BE(24-27) BF(24-27) const __m512i rhs_mat_2367ABEF_3_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, (_MM_PERM_ENUM)136); //B2(24-27) B3(24-27) B2(24-27) B3(24-27) B6(24-27) B7(24-27) B6(24-27) B7(24-27) BA(24-27) BB(24-27) BA(24-27) BB(24-27) BE(24-27) BF(24-27) BE(24-27) BF(24-27)
// Shuffle pattern two - right side input // Shuffle pattern two - right side input
const __m512i rhs_mat_014589CD_0_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, 221); //B0(4-7) B1(4-7) B0(4-7) B1(4-7) B4(4-7) B5(4-7) B4(4-7) B5(4-7) B8(4-7) B9(4-7) B8(4-7) B9(4-7) BC(4-7) BD(4-7) BC(4-7) BD(4-7) const __m512i rhs_mat_014589CD_0_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, (_MM_PERM_ENUM)221); //B0(4-7) B1(4-7) B0(4-7) B1(4-7) B4(4-7) B5(4-7) B4(4-7) B5(4-7) B8(4-7) B9(4-7) B8(4-7) B9(4-7) BC(4-7) BD(4-7) BC(4-7) BD(4-7)
const __m512i rhs_mat_2367ABEF_0_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, 221); //B2(4-7) B3(4-7) B2(4-7) B3(4-7) B6(4-7) B7(4-7) B6(4-7) B7(4-7) BA(4-7) BB(4-7) BA(4-7) BB(4-7) BE(4-7) BF(4-7) BE(4-7) BF(4-7) const __m512i rhs_mat_2367ABEF_0_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, (_MM_PERM_ENUM)221); //B2(4-7) B3(4-7) B2(4-7) B3(4-7) B6(4-7) B7(4-7) B6(4-7) B7(4-7) BA(4-7) BB(4-7) BA(4-7) BB(4-7) BE(4-7) BF(4-7) BE(4-7) BF(4-7)
const __m512i rhs_mat_014589CD_1_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, 221); //B0(12-15) B1(12-15) B0(12-15) B1(12-15) B4(12-15) B5(12-15) B4(12-15) B5(12-15) B8(12-15) B9(12-15) B8(12-15) B9(12-15) BC(12-15) BD(12-15) BC(12-15) BD(12-15) const __m512i rhs_mat_014589CD_1_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, (_MM_PERM_ENUM)221); //B0(12-15) B1(12-15) B0(12-15) B1(12-15) B4(12-15) B5(12-15) B4(12-15) B5(12-15) B8(12-15) B9(12-15) B8(12-15) B9(12-15) BC(12-15) BD(12-15) BC(12-15) BD(12-15)
const __m512i rhs_mat_2367ABEF_1_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, 221); //B2(12-15) B3(12-15) B2(12-15) B3(12-15) B6(12-15) B7(12-15) B6(12-15) B7(12-15) BA(12-15) BB(12-15) BA(12-15) BB(12-15) BE(12-15) BF(12-15) BE(12-15) BF(12-15) const __m512i rhs_mat_2367ABEF_1_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, (_MM_PERM_ENUM)221); //B2(12-15) B3(12-15) B2(12-15) B3(12-15) B6(12-15) B7(12-15) B6(12-15) B7(12-15) BA(12-15) BB(12-15) BA(12-15) BB(12-15) BE(12-15) BF(12-15) BE(12-15) BF(12-15)
const __m512i rhs_mat_014589CD_2_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, 221); //B0(20-23) B1(20-23) B0(20-23) B1(20-23) B4(20-23) B5(20-23) B4(20-23) B5(20-23) B8(20-23) B9(20-23) B8(20-23) B9(20-23) BC(20-23) BD(20-23) BC(20-23) BD(20-23) const __m512i rhs_mat_014589CD_2_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, (_MM_PERM_ENUM)221); //B0(20-23) B1(20-23) B0(20-23) B1(20-23) B4(20-23) B5(20-23) B4(20-23) B5(20-23) B8(20-23) B9(20-23) B8(20-23) B9(20-23) BC(20-23) BD(20-23) BC(20-23) BD(20-23)
const __m512i rhs_mat_2367ABEF_2_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, 221); //B2(20-23) B3(20-23) B2(20-23) B3(20-23) B6(20-23) B7(20-23) B6(20-23) B7(20-23) BA(20-23) BB(20-23) BA(20-23) BB(20-23) BE(20-23) BF(20-23) BE(20-23) BF(20-23) const __m512i rhs_mat_2367ABEF_2_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, (_MM_PERM_ENUM)221); //B2(20-23) B3(20-23) B2(20-23) B3(20-23) B6(20-23) B7(20-23) B6(20-23) B7(20-23) BA(20-23) BB(20-23) BA(20-23) BB(20-23) BE(20-23) BF(20-23) BE(20-23) BF(20-23)
const __m512i rhs_mat_014589CD_3_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, 221); //B0(28-31) B1(28-31) B0(28-31) B1(28-31) B4(28-31) B5(28-31) B4(28-31) B5(28-31) B8(28-31) B9(28-31) B8(28-31) B9(28-31) BC(28-31) BD(28-31) BC(28-31) BD(28-31) const __m512i rhs_mat_014589CD_3_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, (_MM_PERM_ENUM)221); //B0(28-31) B1(28-31) B0(28-31) B1(28-31) B4(28-31) B5(28-31) B4(28-31) B5(28-31) B8(28-31) B9(28-31) B8(28-31) B9(28-31) BC(28-31) BD(28-31) BC(28-31) BD(28-31)
const __m512i rhs_mat_2367ABEF_3_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, 221); //B2(28-31) B3(28-31) B2(28-31) B3(28-31) B6(28-31) B7(28-31) B6(28-31) B7(28-31) BA(28-31) BB(28-31) BA(28-31) BB(28-31) BE(28-31) BF(28-31) BE(28-31) BF(28-31) const __m512i rhs_mat_2367ABEF_3_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, (_MM_PERM_ENUM)221); //B2(28-31) B3(28-31) B2(28-31) B3(28-31) B6(28-31) B7(28-31) B6(28-31) B7(28-31) BA(28-31) BB(28-31) BA(28-31) BB(28-31) BE(28-31) BF(28-31) BE(28-31) BF(28-31)
// Scale values - Load the weight scale values of two block_q4_0x8 // Scale values - Load the weight scale values of two block_q4_0x8
@ -2809,31 +2846,31 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
// Shuffle pattern one - left side input // Shuffle pattern one - left side input
const __m512i lhs_mat_01_0_sp1 = _mm512_shuffle_epi32(lhs_mat_01_0, 160); //A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) const __m512i lhs_mat_01_0_sp1 = _mm512_shuffle_epi32(lhs_mat_01_0, (_MM_PERM_ENUM)160); //A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3)
const __m512i lhs_mat_23_0_sp1 = _mm512_shuffle_epi32(lhs_mat_23_0, 160); //A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) const __m512i lhs_mat_23_0_sp1 = _mm512_shuffle_epi32(lhs_mat_23_0, (_MM_PERM_ENUM)160); //A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3)
const __m512i lhs_mat_01_1_sp1 = _mm512_shuffle_epi32(lhs_mat_01_1, 160); //A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) const __m512i lhs_mat_01_1_sp1 = _mm512_shuffle_epi32(lhs_mat_01_1, (_MM_PERM_ENUM)160); //A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11)
const __m512i lhs_mat_23_1_sp1 = _mm512_shuffle_epi32(lhs_mat_23_1, 160); //A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) const __m512i lhs_mat_23_1_sp1 = _mm512_shuffle_epi32(lhs_mat_23_1, (_MM_PERM_ENUM)160); //A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11)
const __m512i lhs_mat_01_2_sp1 = _mm512_shuffle_epi32(lhs_mat_01_2, 160); //A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) const __m512i lhs_mat_01_2_sp1 = _mm512_shuffle_epi32(lhs_mat_01_2, (_MM_PERM_ENUM)160); //A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19)
const __m512i lhs_mat_23_2_sp1 = _mm512_shuffle_epi32(lhs_mat_23_2, 160); //A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) const __m512i lhs_mat_23_2_sp1 = _mm512_shuffle_epi32(lhs_mat_23_2, (_MM_PERM_ENUM)160); //A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19)
const __m512i lhs_mat_01_3_sp1 = _mm512_shuffle_epi32(lhs_mat_01_3, 160); //A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) const __m512i lhs_mat_01_3_sp1 = _mm512_shuffle_epi32(lhs_mat_01_3, (_MM_PERM_ENUM)160); //A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27)
const __m512i lhs_mat_23_3_sp1 = _mm512_shuffle_epi32(lhs_mat_23_3, 160); //A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) const __m512i lhs_mat_23_3_sp1 = _mm512_shuffle_epi32(lhs_mat_23_3, (_MM_PERM_ENUM)160); //A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27)
// Shuffle pattern two - left side input // Shuffle pattern two - left side input
const __m512i lhs_mat_01_0_sp2 = _mm512_shuffle_epi32(lhs_mat_01_0, 245); //A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) const __m512i lhs_mat_01_0_sp2 = _mm512_shuffle_epi32(lhs_mat_01_0, (_MM_PERM_ENUM)245); //A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7)
const __m512i lhs_mat_23_0_sp2 = _mm512_shuffle_epi32(lhs_mat_23_0, 245); //A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) const __m512i lhs_mat_23_0_sp2 = _mm512_shuffle_epi32(lhs_mat_23_0, (_MM_PERM_ENUM)245); //A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7)
const __m512i lhs_mat_01_1_sp2 = _mm512_shuffle_epi32(lhs_mat_01_1, 245); //A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) const __m512i lhs_mat_01_1_sp2 = _mm512_shuffle_epi32(lhs_mat_01_1, (_MM_PERM_ENUM)245); //A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15)
const __m512i lhs_mat_23_1_sp2 = _mm512_shuffle_epi32(lhs_mat_23_1, 245); //A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) const __m512i lhs_mat_23_1_sp2 = _mm512_shuffle_epi32(lhs_mat_23_1, (_MM_PERM_ENUM)245); //A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15)
const __m512i lhs_mat_01_2_sp2 = _mm512_shuffle_epi32(lhs_mat_01_2, 245); //A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) const __m512i lhs_mat_01_2_sp2 = _mm512_shuffle_epi32(lhs_mat_01_2, (_MM_PERM_ENUM)245); //A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23)
const __m512i lhs_mat_23_2_sp2 = _mm512_shuffle_epi32(lhs_mat_23_2, 245); //A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) const __m512i lhs_mat_23_2_sp2 = _mm512_shuffle_epi32(lhs_mat_23_2, (_MM_PERM_ENUM)245); //A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23)
const __m512i lhs_mat_01_3_sp2 = _mm512_shuffle_epi32(lhs_mat_01_3, 245); //A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) const __m512i lhs_mat_01_3_sp2 = _mm512_shuffle_epi32(lhs_mat_01_3, (_MM_PERM_ENUM)245); //A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31)
const __m512i lhs_mat_23_3_sp2 = _mm512_shuffle_epi32(lhs_mat_23_3, 245); //A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) const __m512i lhs_mat_23_3_sp2 = _mm512_shuffle_epi32(lhs_mat_23_3, (_MM_PERM_ENUM)245); //A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31)
// The values arranged in shuffle patterns are operated with dot product operation within 32 bit lane i.e corresponding bytes and multiplied and added into 32 bit integers within 32 bit lane // The values arranged in shuffle patterns are operated with dot product operation within 32 bit lane i.e corresponding bytes and multiplied and added into 32 bit integers within 32 bit lane
// Resembles MMLAs into 2x2 matrices in ARM Version // Resembles MMLAs into 2x2 matrices in ARM Version
@ -2862,10 +2899,10 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
// Straighten out to make 4 row vectors // Straighten out to make 4 row vectors
__m512i iacc_row_0 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_00, _mm512_shuffle_epi32(iacc_mat_01, 78)); __m512i iacc_row_0 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_00, _mm512_shuffle_epi32(iacc_mat_01, (_MM_PERM_ENUM)78));
__m512i iacc_row_1 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_00, 78), iacc_mat_01); __m512i iacc_row_1 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_00, (_MM_PERM_ENUM)78), iacc_mat_01);
__m512i iacc_row_2 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_10, _mm512_shuffle_epi32(iacc_mat_11, 78)); __m512i iacc_row_2 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_10, _mm512_shuffle_epi32(iacc_mat_11, (_MM_PERM_ENUM)78));
__m512i iacc_row_3 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_10, 78), iacc_mat_11); __m512i iacc_row_3 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_10, (_MM_PERM_ENUM)78), iacc_mat_11);
// Load the scale(d) values for all the 4 Q8_0 blocks and repeat it across lanes // Load the scale(d) values for all the 4 Q8_0 blocks and repeat it across lanes
const __m128i row_scale_f16 = _mm_shuffle_epi32(_mm_maskload_epi32((int const*)(a_ptr[b].d), loadMask), 68); const __m128i row_scale_f16 = _mm_shuffle_epi32(_mm_maskload_epi32((int const*)(a_ptr[b].d), loadMask), 68);
@ -3460,7 +3497,7 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
} }
} }
void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) { static void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
const int qk = QK8_0; const int qk = QK8_0;
const int nb = n / qk; const int nb = n / qk;
const int ncols_interleaved = 4; const int ncols_interleaved = 4;
@ -3571,7 +3608,6 @@ void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * restrict s, size_t bs, const void
} }
} }
// FIXME: this code is duplicated from ggml-aarch64.c
static block_q4_0x4 make_block_q4_0x4(block_q4_0 * in, unsigned int blck_size_interleave) { static block_q4_0x4 make_block_q4_0x4(block_q4_0 * in, unsigned int blck_size_interleave) {
block_q4_0x4 out; block_q4_0x4 out;
@ -3641,20 +3677,20 @@ static block_q4_0x8 make_block_q4_0x8(block_q4_0 * in, unsigned int blck_size_in
return out; return out;
} }
static int repack_q4_0_to_q4_0_4_bl(struct ggml_tensor * t, int interleave_block, const void * restrict data, size_t data_size) { static int repack_q4_0_to_q4_0_4_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
GGML_ASSERT(t->type == GGML_TYPE_Q4_0); GGML_ASSERT(t->type == GGML_TYPE_Q4_0);
GGML_ASSERT(interleave_block == 4 || interleave_block == 8); GGML_ASSERT(interleave_block == 4 || interleave_block == 8);
constexpr int nrows_interleaved = 4;
block_q4_0x4 * dst = (block_q4_0x4 *)t->data; block_q4_0x4 * dst = (block_q4_0x4 *)t->data;
const block_q4_0 * src = (const block_q4_0 *)data; const block_q4_0 * src = (const block_q4_0 *)data;
block_q4_0 dst_tmp[4]; block_q4_0 dst_tmp[4];
int nrow = t->ne[1]; // Number of rows int nrow = ggml_nrows(t);
int nrows_interleaved = 4;
int nblocks = t->ne[0] / QK4_0; int nblocks = t->ne[0] / QK4_0;
GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_0)); GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_0));
if (nrow % nrows_interleaved != 0 || t->ne[0] % 8 != 0) { if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {
return -1; return -1;
} }
@ -3672,20 +3708,20 @@ static int repack_q4_0_to_q4_0_4_bl(struct ggml_tensor * t, int interleave_block
GGML_UNUSED(data_size); GGML_UNUSED(data_size);
} }
static int repack_q4_0_to_q4_0_8_bl(struct ggml_tensor *t, int interleave_block, const void * restrict data, size_t data_size) { static int repack_q4_0_to_q4_0_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
GGML_ASSERT(t->type == GGML_TYPE_Q4_0); GGML_ASSERT(t->type == GGML_TYPE_Q4_0);
GGML_ASSERT(interleave_block == 8); GGML_ASSERT(interleave_block == 8);
constexpr int nrows_interleaved = 8;
block_q4_0x8 * dst = (block_q4_0x8*)t->data; block_q4_0x8 * dst = (block_q4_0x8*)t->data;
const block_q4_0 * src = (const block_q4_0*) data; const block_q4_0 * src = (const block_q4_0*) data;
block_q4_0 dst_tmp[8]; block_q4_0 dst_tmp[8];
int nrow = t->ne[1]; // Number of rows int nrow = ggml_nrows(t);
int nrows_interleaved = 8;
int nblocks = t->ne[0] / QK4_0; int nblocks = t->ne[0] / QK4_0;
GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_0)); GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_0));
if (nrow % nrows_interleaved != 0 || t->ne[0] % 8 != 0) { if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {
return -1; return -1;
} }
@ -3712,16 +3748,18 @@ static block_iq4_nlx4 make_block_iq4_nlx4(block_iq4_nl * in, unsigned int blck_s
const int end = QK4_NL * 2 / blck_size_interleave; const int end = QK4_NL * 2 / blck_size_interleave;
if (blck_size_interleave == 8) { // TODO: this branch seems wrong
for (int i = 0; i < end; ++i) { //if (blck_size_interleave == 8) {
int src_id = i % 4; // for (int i = 0; i < end; ++i) {
int src_offset = (i / 4) * blck_size_interleave; // int src_id = i % 4;
int dst_offset = i * blck_size_interleave; // int src_offset = (i / 4) * blck_size_interleave;
// int dst_offset = i * blck_size_interleave;
// Using memcpy to avoid unaligned memory accesses // // Using memcpy to avoid unaligned memory accesses
memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], sizeof(uint64_t)); // memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], sizeof(uint64_t));
} // }
} else if (blck_size_interleave == 4) { //} else
if (blck_size_interleave == 4) {
for (int i = 0; i < end; ++i) { for (int i = 0; i < end; ++i) {
int src_id = i % 4; int src_id = i % 4;
int src_offset = (i / 4) * blck_size_interleave; int src_offset = (i / 4) * blck_size_interleave;
@ -3736,20 +3774,21 @@ static block_iq4_nlx4 make_block_iq4_nlx4(block_iq4_nl * in, unsigned int blck_s
return out; return out;
} }
static int repack_iq4_nl_to_iq4_nl_4_bl(struct ggml_tensor * t, int interleave_block, const void * restrict data, size_t data_size) { static int repack_iq4_nl_to_iq4_nl_4_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
GGML_ASSERT(t->type == GGML_TYPE_IQ4_NL); GGML_ASSERT(t->type == GGML_TYPE_IQ4_NL);
GGML_ASSERT(interleave_block == 4 || interleave_block == 8); //GGML_ASSERT(interleave_block == 4 || interleave_block == 8);
GGML_ASSERT(interleave_block == 4);
block_iq4_nlx4 * dst = (block_iq4_nlx4 *)t->data; block_iq4_nlx4 * dst = (block_iq4_nlx4 *)t->data;
const block_iq4_nl * src = (const block_iq4_nl *)data; const block_iq4_nl * src = (const block_iq4_nl *)data;
block_iq4_nl dst_tmp[4]; block_iq4_nl dst_tmp[4];
int nrow = t->ne[1]; // Number of rows int nrow = ggml_nrows(t);
int nrows_interleaved = 4; int nrows_interleaved = 4;
int nblocks = t->ne[0] / QK4_0; int nblocks = t->ne[0] / QK4_0;
GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_iq4_nl)); GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_iq4_nl));
if (nrow % nrows_interleaved != 0 || t->ne[0] % 8 != 0) { if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {
return -1; return -1;
} }
@ -3767,57 +3806,457 @@ static int repack_iq4_nl_to_iq4_nl_4_bl(struct ggml_tensor * t, int interleave_b
GGML_UNUSED(data_size); GGML_UNUSED(data_size);
} }
// Prepare for optimized kernels if applicable namespace ggml::cpu::aarch64 {
void ggml_aarch64_repack_tensor(struct ggml_tensor * cur, enum ggml_type repack_type, const void * restrict data, size_t data_size) { // repack
if (cur->type == repack_type) { template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS>
memcpy(cur->data, data, data_size); int repack(struct ggml_tensor *, const void *, size_t);
// TODO: generalise.
// Explicit specializations of repack<BLOC_TYPE, INTER_SIZE, NB_COLS>.
// Each one forwards to the file-local repack helper for its block type and
// passes INTER_SIZE as the helper's `interleave_block` argument; the helper's
// return value is passed through unchanged (the visible helpers return -1 when
// the tensor dimensions are not compatible with the interleaved layout).

// Q4_0, interleave size 4, 4 columns per group.
template <> int repack<block_q4_0, 4, 4>(struct ggml_tensor * t, const void * data, size_t data_size) {
return repack_q4_0_to_q4_0_4_bl(t, 4, data, data_size);
}
// Q4_0, interleave size 8, 4 columns per group (same helper, interleave_block = 8).
template <> int repack<block_q4_0, 8, 4>(struct ggml_tensor * t, const void * data, size_t data_size) {
return repack_q4_0_to_q4_0_4_bl(t, 8, data, data_size);
}
// Q4_0, interleave size 8, 8 columns per group.
template <> int repack<block_q4_0, 8, 8>(struct ggml_tensor * t, const void * data, size_t data_size) {
return repack_q4_0_to_q4_0_8_bl(t, 8, data, data_size);
}
// IQ4_NL, interleave size 4, 4 columns per group.
template <> int repack<block_iq4_nl, 4, 4>(struct ggml_tensor * t, const void * data, size_t data_size) {
return repack_iq4_nl_to_iq4_nl_4_bl(t, 4, data, data_size);
}
// TODO: needs to be revisited
//template <> int repack<block_iq4_nl, 8, 4>(struct ggml_tensor * t, const void * data, size_t data_size) {
// return repack_iq4_nl_to_iq4_nl_4_bl(t, 8, data, data_size);
//}
// gemv
// Compile-time dispatch for the matrix-vector kernels.
// The primary template is only declared; each supported
// <BLOC_TYPE, INTER_SIZE, NB_COLS> combination is an explicit specialization
// that forwards all seven arguments, unchanged and in order, to the matching
// ggml_gemv_*_q8_0 kernel. Note the kernel names are spelled
// <NB_COLS>x<INTER_SIZE> (e.g. <block_q4_0, 8, 4> -> ggml_gemv_q4_0_4x8_q8_0).
template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS>
void gemv(int, float *, size_t, const void *, const void *, int, int);
// Q4_0, 4 columns, interleave 4.
template <> void gemv<block_q4_0, 4, 4>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
ggml_gemv_q4_0_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
}
// Q4_0, 4 columns, interleave 8.
template <> void gemv<block_q4_0, 8, 4>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
ggml_gemv_q4_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc);
}
// Q4_0, 8 columns, interleave 8.
template <> void gemv<block_q4_0, 8, 8>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
ggml_gemv_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
}
// IQ4_NL, 4 columns, interleave 4.
template <>
void gemv<block_iq4_nl, 4, 4>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
ggml_gemv_iq4_nl_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
}
// gemm
// Compile-time dispatch for the matrix-matrix kernels.
// Mirrors the gemv family: the primary template is only declared and every
// supported <BLOC_TYPE, INTER_SIZE, NB_COLS> combination forwards its
// arguments, unchanged and in order, to the matching ggml_gemm_*_q8_0 kernel
// (kernel names are spelled <NB_COLS>x<INTER_SIZE>).
template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS>
void gemm(int, float *, size_t, const void *, const void *, int, int);
// Q4_0, 4 columns, interleave 4.
template <> void gemm<block_q4_0, 4, 4>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
ggml_gemm_q4_0_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
}
// Q4_0, 4 columns, interleave 8.
template <> void gemm<block_q4_0, 8, 4>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
ggml_gemm_q4_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc);
}
// Q4_0, 8 columns, interleave 8.
template <> void gemm<block_q4_0, 8, 8>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
ggml_gemm_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
}
// IQ4_NL, 4 columns, interleave 4.
template <>
void gemm<block_iq4_nl, 4, 4>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
ggml_gemm_iq4_nl_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
}
// Common base for the repacking tensor traits in this namespace.
// Extends ggml::cpu::tensor_traits with one extra operation so that code
// holding only a type-erased traits pointer can still trigger the repack.
class tensor_traits_base : public ggml::cpu::tensor_traits {
public:
// Repack `data` (`data_size` bytes) into the interleaved layout stored in `t`.
// The implementations visible in this file return 0 on success and a
// non-zero value when the tensor cannot be repacked.
virtual int repack(struct ggml_tensor * t, const void * data, size_t data_size) = 0;
};
template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS> class tensor_traits : public tensor_traits_base {
// Reports how many bytes of scratch ("work") memory compute_forward() needs
// for `op`.
//
// @param op    the graph node about to be computed
// @param size  [out] required scratch size in bytes; written only when the
//              function returns true
// @return true for GGML_OP_MUL_MAT and GGML_OP_MUL_MAT_ID, false for every
//         other op (size is left untouched in that case)
bool work_size(int /* n_threads */, const struct ggml_tensor * op, size_t & size) override {
    // src1 is re-quantized into the scratch buffer. The interleaved layout is
    // not really GGML_TYPE_Q8_0, but it occupies the same number of bytes.
    switch (op->op) {
        case GGML_OP_MUL_MAT:
            size = ggml_row_size(GGML_TYPE_Q8_0, ggml_nelements(op->src[1]));
            return true;
        case GGML_OP_MUL_MAT_ID: {
            const int64_t n_as = op->src[0]->ne[2];  // number of src0 matrices (experts)
            const int64_t ne12 = op->src[1]->ne[2];

            size = ggml_row_size(GGML_TYPE_Q8_0, ggml_nelements(op->src[1]));
            size = GGML_PAD(size, sizeof(int64_t));  // align the bookkeeping area that follows

            // forward_mul_mat_id() stores, right after the quantized src1:
            //   int64_t          matrix_row_counts[n_as]
            //   mmid_row_mapping matrix_rows[n_as][ne12]   (two int32_t = 8 bytes each)
            // i.e. n_as * (1 + ne12) 8-byte slots. The previous expression,
            // (1 + n_as) * ne12, under-reserved whenever n_as > ne12.
            size += sizeof(int64_t) * n_as * (1 + ne12);
            return true;
        }
        default:
            // GGML_ABORT("fatal error");
            break;
    }
    return false;
}
// Runs the repacked kernel for `op` when it is one of the supported ops.
// Returns true when the op was handled here (MUL_MAT or MUL_MAT_ID) and
// false otherwise, in which case nothing is computed.
bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * op) override {
    if (op->op == GGML_OP_MUL_MAT) {
        forward_mul_mat(params, op);
        return true;
    }
    if (op->op == GGML_OP_MUL_MAT_ID) {
        forward_mul_mat_id(params, op);
        return true;
    }
    // Any other op is not handled by this traits object.
    // GGML_ABORT("fatal error");
    return false;
}
// Computes dst = src0 * src1 for a repacked src0 (op->src[0]) and an F32 src1.
//
// Phase 1: all threads cooperatively quantize src1 into params->wdata.
// Phase 2: after a barrier, each thread multiplies its own slice of src0 rows
//          against the quantized src1, using gemm for the rows handled in
//          groups of four and gemv for the remaining rows.
//
// NOTE(review): this text is a side-by-side diff rendering. From the line
// containing "return; return;" onwards each physical line may hold removed
// text from the old ggml_aarch64_repack_tensor() followed by the added text
// of this function; the lines are kept verbatim.
void forward_mul_mat(ggml_compute_params * params, ggml_tensor * op) {
const ggml_tensor * src0 = op->src[0];
const ggml_tensor * src1 = op->src[1];
ggml_tensor * dst = op;
// Brings the ne0x/nb0x, ne1x/nb1x and ne/nb locals used below into scope.
GGML_TENSOR_BINARY_OP_LOCALS
const int ith = params->ith;
const int nth = params->nth;
GGML_ASSERT(ne0 == ne01);
GGML_ASSERT(ne1 == ne11);
GGML_ASSERT(ne2 == ne12);
GGML_ASSERT(ne3 == ne13);
// dst cannot be transposed or permuted
GGML_ASSERT(nb0 == sizeof(float));
GGML_ASSERT(nb0 <= nb1);
GGML_ASSERT(nb1 <= nb2);
GGML_ASSERT(nb2 <= nb3);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT(ggml_n_dims(op->src[0]) == 2);
// GGML_ASSERT(ggml_n_dims(op->src[1]) == 2);
char * wdata = static_cast<char *>(params->wdata);
// Bytes per quantized src1 row inside the work buffer.
const size_t nbw1 = ggml_row_size(GGML_TYPE_Q8_0, ne10);
assert(params->wsize >= nbw1 * ne11);
const ggml_from_float_t from_float = ggml_get_type_traits_cpu(GGML_TYPE_Q8_0)->from_float;
int64_t i11_processed = 0;
// Rows that form complete groups of four are quantized four at a time;
// thread ith takes groups ith, ith + nth, ...
for (int64_t i11 = ith * 4; i11 < ne11 - ne11 % 4; i11 += nth * 4) {
quantize_mat_q8_0((float *) ((char *) src1->data + i11 * nb11), (void *) (wdata + i11 * nbw1), 4, ne10,
INTER_SIZE);
}
// The remaining (ne11 % 4) rows are quantized one by one with from_float.
i11_processed = ne11 - ne11 % 4;
for (int64_t i11 = i11_processed + ith; i11 < ne11; i11 += nth) {
from_float((float *) ((char *) src1->data + i11 * nb11), (void *) (wdata + i11 * nbw1), ne10);
}
// Every thread must finish quantizing before any thread reads wdata.
ggml_barrier(params->threadpool);
const void * src1_wdata = params->wdata;
const size_t src1_col_stride = ggml_row_size(GGML_TYPE_Q8_0, ne10);
// Split the ne01 rows of src0 evenly across threads, then round both ends
// up to a multiple of NB_COLS so each slice covers whole interleaved groups.
int64_t src0_start = (ith * ne01) / nth;
int64_t src0_end = ((ith + 1) * ne01) / nth;
src0_start = (src0_start % NB_COLS) ? src0_start + NB_COLS - (src0_start % NB_COLS) : src0_start;
src0_end = (src0_end % NB_COLS) ? src0_end + NB_COLS - (src0_end % NB_COLS) : src0_end;
// After rounding, a thread may be left with an empty slice.
if (src0_start >= src0_end) {
return; return;
} }
if (cur->type == GGML_TYPE_Q4_0) { // If there are more than three rows in src1, use gemm; otherwise, use gemv.
switch (repack_type) { if (ne11 > 3) {
case GGML_TYPE_Q4_0_8_8: gemm<BLOC_TYPE, INTER_SIZE, NB_COLS>(ne00, (float *) ((char *) dst->data) + src0_start, ne01,
repack_q4_0_to_q4_0_8_bl(cur, 8, data, data_size); (const char *) src0->data + src0_start * nb01,
break; (const char *) src1_wdata, ne11 - ne11 % 4, src0_end - src0_start);
case GGML_TYPE_Q4_0_4_8:
repack_q4_0_to_q4_0_4_bl(cur, 8, data, data_size);
break;
case GGML_TYPE_Q4_0_4_4:
repack_q4_0_to_q4_0_4_bl(cur, 4, data, data_size);
break;
default:
GGML_ABORT("Unsupported type");
} }
// Added side: the src1 rows not covered by the gemm call (row indices
// ne11 - ne11 % 4 .. ne11 - 1) are processed one at a time with gemv.
} else if (cur->type == GGML_TYPE_IQ4_NL) { for (int iter = ne11 - ne11 % 4; iter < ne11; iter++) {
switch (repack_type) { gemv<BLOC_TYPE, INTER_SIZE, NB_COLS>(ne00, (float *) ((char *) dst->data + (iter * nb1)) + src0_start, ne01,
case GGML_TYPE_IQ4_NL_4_4: (const char *) src0->data + src0_start * nb01,
repack_iq4_nl_to_iq4_nl_4_bl(cur, 4, data, data_size); (const char *) src1_wdata + (src1_col_stride * iter), 1,
break; src0_end - src0_start);
default:
GGML_ABORT("Unsupported type");
}
} else {
GGML_ABORT("Unsupported type");
} }
} }
// Computes MUL_MAT_ID: for every entry of `ids` (op->src[2]) the selected
// src0 matrix is multiplied with one quantized src1 row, using gemv.
//
// Work buffer layout (params->wdata):
//   [ src1 quantized, nbw3 bytes, padded to sizeof(int64_t) ]
//   [ int64_t          matrix_row_counts[n_as]              ]
//   [ mmid_row_mapping matrix_rows[n_as][ne12]              ]
//
// NOTE(review): this text is a side-by-side diff rendering. The first line
// below starts with the removed signature of the old
// ggml_aarch64_get_optimal_repack_type(), followed by the added signature of
// forward_mul_mat_id(); it is kept verbatim.
enum ggml_type ggml_aarch64_get_optimal_repack_type(const struct ggml_tensor * cur) { void forward_mul_mat_id(ggml_compute_params * params, ggml_tensor * op) {
const ggml_tensor * src0 = op->src[0];
const ggml_tensor * src1 = op->src[1];
const ggml_tensor * ids = op->src[2];
ggml_tensor * dst = op;
// Brings the ne0x/nb0x, ne1x/nb1x and ne/nb locals used below into scope.
GGML_TENSOR_BINARY_OP_LOCALS
const int ith = params->ith;
const int nth = params->nth;
const ggml_from_float_t from_float = ggml_get_type_traits_cpu(GGML_TYPE_Q8_0)->from_float;
// we don't support permuted src0 or src1
GGML_ASSERT(nb00 == ggml_type_size(src0->type));
GGML_ASSERT(nb10 == ggml_type_size(src1->type));
// dst cannot be transposed or permuted
GGML_ASSERT(nb0 == sizeof(float));
GGML_ASSERT(nb0 <= nb1);
GGML_ASSERT(nb1 <= nb2);
GGML_ASSERT(nb2 <= nb3);
GGML_ASSERT(ne03 == 1);
GGML_ASSERT(ne13 == 1);
GGML_ASSERT(ne3 == 1);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
// row groups
const int n_ids = ids->ne[0]; // n_expert_used
const int n_as = ne02; // n_expert
// Byte strides of the quantized copy of src1 inside the work buffer.
const size_t nbw1 = ggml_row_size(GGML_TYPE_Q8_0, ne10);
const size_t nbw2 = nbw1*ne11;
const size_t nbw3 = nbw2*ne12;
// One entry per (ids column, ids row) pair assigned to a src0 matrix.
struct mmid_row_mapping {
int32_t i1;
int32_t i2;
};
// The scratch size reported by work_size() has to cover this layout.
GGML_ASSERT(params->wsize >= (GGML_PAD(nbw3, sizeof(int64_t)) + n_as * sizeof(int64_t) +
n_as * ne12 * sizeof(mmid_row_mapping)));
auto wdata = (char *) params->wdata;
auto wdata_src1_end = (char *) wdata + GGML_PAD(nbw3, sizeof(int64_t));
int64_t * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as]
struct mmid_row_mapping * matrix_rows = (struct mmid_row_mapping *) (matrix_row_counts + n_as); // [n_as][ne12]
// src1: float32 => block_q8_0
// Thread ith converts rows ith, ith + nth, ... of every i12 plane.
for (int64_t i12 = 0; i12 < ne12; ++i12) {
for (int64_t i11 = ith; i11 < ne11; i11 += nth) {
from_float((float *)((char *) src1->data + i12 * nb12 + i11 * nb11),
(void *) (wdata + i12 * nbw2 + i11 * nbw1),
ne10);
}
}
#define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id) * ne12 + (i1)]
// Only thread 0 builds the per-matrix row lists; the other threads wait at
// the barrier below before reading them.
if (ith == 0) {
// initialize matrix_row_counts
memset(matrix_row_counts, 0, n_as * sizeof(int64_t));
// group rows by src0 matrix
for (int32_t iid1 = 0; iid1 < ids->ne[1]; ++iid1) {
for (int32_t id = 0; id < n_ids; ++id) {
const int32_t i02 =
*(const int32_t *) ((const char *) ids->data + iid1 * ids->nb[1] + id * ids->nb[0]);
GGML_ASSERT(i02 >= 0 && i02 < n_as);
MMID_MATRIX_ROW(i02, matrix_row_counts[i02]) = { id, iid1 };
matrix_row_counts[i02] += 1;
}
}
}
ggml_barrier(params->threadpool);
// compute each matrix multiplication in sequence
for (int cur_a = 0; cur_a < n_as; ++cur_a) {
const int64_t cne1 = matrix_row_counts[cur_a];
// Nothing was routed to this src0 matrix.
if (cne1 == 0) {
continue;
}
auto src0_cur = (const char *) src0->data + cur_a*nb02;
//const int64_t nr0 = ne01; // src0 rows
const int64_t nr1 = cne1; // src1 rows
// Same per-thread row split as forward_mul_mat(): even split of ne01,
// both ends rounded up to a multiple of NB_COLS.
int64_t src0_cur_start = (ith * ne01) / nth;
int64_t src0_cur_end = ((ith + 1) * ne01) / nth;
src0_cur_start =
(src0_cur_start % NB_COLS) ? src0_cur_start + NB_COLS - (src0_cur_start % NB_COLS) : src0_cur_start;
src0_cur_end = (src0_cur_end % NB_COLS) ? src0_cur_end + NB_COLS - (src0_cur_end % NB_COLS) : src0_cur_end;
// The slice depends only on ith/nth/ne01, so an empty slice is empty for
// every cur_a and the thread can leave the function.
if (src0_cur_start >= src0_cur_end) return;
for (int ir1 = 0; ir1 < nr1; ir1++) {
struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, ir1);
const int id = row_mapping.i1; // selected expert index
const int64_t i11 = id % ne11;
const int64_t i12 = row_mapping.i2; // row index in src1
const int64_t i1 = id; // selected expert index
const int64_t i2 = i12; // row
auto src1_col = (const char *) wdata + (i11 * nbw1 + i12 * nbw2);
// One quantized src1 row times this thread's slice of the selected matrix.
gemv<BLOC_TYPE, INTER_SIZE, NB_COLS>(
ne00, (float *)((char *) dst->data + (i1 * nb1 + i2 * nb2)) + src0_cur_start,
ne01, src0_cur + src0_cur_start * nb01,
src1_col, 1, src0_cur_end - src0_cur_start);
}
}
#undef MMID_MATRIX_ROW
}
int repack(struct ggml_tensor * t, const void * data, size_t data_size) override {
GGML_LOG_DEBUG("%s: repack tensor %s with %s_%dx%d\n", __func__, t->name, ggml_type_name(t->type),
(int) NB_COLS, (int) INTER_SIZE);
return ggml::cpu::aarch64::repack<BLOC_TYPE, INTER_SIZE, NB_COLS>(t, data, data_size);
}
};
// One shared, file-local traits object per supported interleaved layout.
// Their addresses are what ggml_aarch64_get_optimal_repack_type() hands out.
// Template arguments are <block type, INTER_SIZE, NB_COLS>; the variable names
// are spelled <NB_COLS>x<INTER_SIZE>.
// instance for Q4
static const tensor_traits<block_q4_0, 4, 4> q4_0_4x4_q8_0;
static const tensor_traits<block_q4_0, 8, 4> q4_0_4x8_q8_0;
static const tensor_traits<block_q4_0, 8, 8> q4_0_8x8_q8_0;
// instance for IQ4
static const tensor_traits<block_iq4_nl, 4, 4> iq4_nl_4x4_q8_0;
} // namespace ggml::cpu::aarch64
// Picks the repacking traits object to use for tensor `cur`, or nullptr when
// no interleaved layout applies.
//
// Added-side logic, in priority order:
//   Q4_0   + (AVX2, or SVE with int8 matmul and sve_cnt == QK8_0) + ne[1] % 8 == 0 -> q4_0_8x8_q8_0
//   Q4_0   + NEON with int8 matmul                                + ne[1] % 4 == 0 -> q4_0_4x8_q8_0
//   Q4_0   + NEON with dotprod                                    + ne[1] % 4 == 0 -> q4_0_4x4_q8_0
//   IQ4_NL + NEON with dotprod                                    + ne[1] % 4 == 0 -> iq4_nl_4x4_q8_0
//   anything else                                                                   -> nullptr
//
// NOTE(review): this text is a side-by-side diff rendering. Most physical
// lines below hold the removed text of the old enum-returning version
// (e.g. "return GGML_TYPE_Q4_0_8_8;") followed by the added text; they are
// kept verbatim.
static const ggml::cpu::tensor_traits * ggml_aarch64_get_optimal_repack_type(const struct ggml_tensor * cur) {
if (cur->type == GGML_TYPE_Q4_0) { if (cur->type == GGML_TYPE_Q4_0) {
// TODO: enable for AVX2 - currently disabled due to bad gemv performance if (ggml_cpu_has_avx2() || (ggml_cpu_has_sve() && ggml_cpu_has_matmul_int8() && ggml_cpu_get_sve_cnt() == QK8_0)) {
if (/* ggml_cpu_has_avx2() || */ (ggml_cpu_has_sve() && ggml_cpu_has_matmul_int8() && ggml_cpu_get_sve_cnt() == QK8_0)) { if (cur->ne[1] % 8 == 0) {
return GGML_TYPE_Q4_0_8_8; return &ggml::cpu::aarch64::q4_0_8x8_q8_0;
}
} }
if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) { if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) {
return GGML_TYPE_Q4_0_4_8; if (cur->ne[1] % 4 == 0) {
return &ggml::cpu::aarch64::q4_0_4x8_q8_0;
}
} }
if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) { if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
return GGML_TYPE_Q4_0_4_4; if (cur->ne[1] % 4 == 0) {
return &ggml::cpu::aarch64::q4_0_4x4_q8_0;
}
} }
} else if (cur->type == GGML_TYPE_IQ4_NL) { } else if (cur->type == GGML_TYPE_IQ4_NL) {
if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) { if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
return GGML_TYPE_IQ4_NL_4_4; if (cur->ne[1] % 4 == 0) {
return &ggml::cpu::aarch64::iq4_nl_4x4_q8_0;
}
} }
} }
// Added side: no suitable layout, the caller gets nullptr.
return cur->type; return nullptr;
}
// Buffer hook called when a tensor is placed in an aarch64 repack buffer.
// Stores the traits object chosen for the tensor (or nullptr when none
// applies) in tensor->extra; the buffer argument is not used.
static void ggml_backend_cpu_aarch64_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
    GGML_UNUSED(buffer);

    const ggml::cpu::tensor_traits * traits = ggml_aarch64_get_optimal_repack_type(tensor);

    // `extra` is an untyped, non-const slot, so constness is dropped for storage.
    tensor->extra = (void *) const_cast<ggml::cpu::tensor_traits *>(traits);
}
// Buffer hook that uploads tensor data: instead of copying `data` verbatim,
// it repacks it into the interleaved layout via the traits object stored in
// tensor->extra.
//
// @param buffer  unused
// @param tensor  destination tensor; tensor->extra must hold the traits
//                installed by the init_tensor hook
// @param data    source bytes in the tensor's regular layout
// @param offset  must be 0 (partial uploads cannot be repacked)
// @param size    must equal ggml_nbytes(tensor)
//
// Aborts through GGML_ASSERT when any precondition is violated or when the
// repack reports failure (non-zero result).
static void ggml_backend_cpu_aarch64_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor,
                                                       const void * data, size_t offset, size_t size) {
    GGML_ASSERT(offset == 0);
    GGML_ASSERT(size == ggml_nbytes(tensor));

    auto * traits = static_cast<ggml::cpu::aarch64::tensor_traits_base *>(tensor->extra);
    // init_tensor stores nullptr when no repack layout applies to the tensor;
    // fail with a diagnostic instead of calling through a null pointer.
    GGML_ASSERT(traits != nullptr);

    const int status = traits->repack(tensor, data, size);
    GGML_ASSERT(status == 0);

    GGML_UNUSED(buffer);
}
// Returns the fixed display name of this buffer type; the argument is unused.
static const char * ggml_backend_cpu_aarch64_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
    GGML_UNUSED(buft);
    return "CPU_AARCH64";
}
// Allocates `size` bytes through the regular CPU buffer type, then re-labels
// the buffer as belonging to `buft` and installs the repacking init_tensor /
// set_tensor hooks. Returns nullptr when the underlying allocation fails.
static ggml_backend_buffer_t ggml_backend_cpu_aarch64_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
    ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size);

    if (buffer != nullptr) {
        buffer->buft              = buft;
        buffer->iface.init_tensor = ggml_backend_cpu_aarch64_buffer_init_tensor;
        buffer->iface.set_tensor  = ggml_backend_cpu_aarch64_buffer_set_tensor;
    }

    return buffer;
}
// Alignment reported for CPU_AARCH64 buffers: always TENSOR_ALIGNMENT.
static size_t ggml_backend_cpu_aarch64_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
    GGML_UNUSED(buft);

    return TENSOR_ALIGNMENT;
}
namespace ggml::cpu::aarch64 {
// extra_buffer_type implementation for the CPU_AARCH64 buffer type: decides
// which matmul ops it takes over and exposes the traits stored on the weights.
class extra_buffer_type : ggml::cpu::extra_buffer_type {
    // Returns true when `op` is a MUL_MAT / MUL_MAT_ID whose src[0] lives in a
    // CPU_AARCH64 buffer, has a repack layout available, and whose src[1] is F32
    // in a host buffer (or has no buffer assigned).
    bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override {
        // MUL_MAT requires src[0] with 2 dims, MUL_MAT_ID requires 3 dims
        int wanted_dims = 0;
        if (op->op == GGML_OP_MUL_MAT) {
            wanted_dims = 2;
        } else if (op->op == GGML_OP_MUL_MAT_ID) {
            wanted_dims = 3;
        } else {
            return false;
        }

        const struct ggml_tensor * src0 = op->src[0];
        const bool src0_repackable =
            src0->buffer &&
            ggml_n_dims(src0) == wanted_dims &&
            src0->buffer->buft == ggml_backend_cpu_aarch64_buffer_type() &&
            ggml_aarch64_get_optimal_repack_type(src0);
        if (!src0_repackable) {
            return false;
        }

        const struct ggml_tensor * src1 = op->src[1];
        if (src1->buffer && !ggml_backend_buft_is_host(src1->buffer->buft)) {
            return false;
        }

        // Only F32 for src[1] is accepted.
        // Q8_0 may be possible if it is packed; that path is not enabled.
        return src1->type == GGML_TYPE_F32;
    }

    // Returns the traits stored in src[0]->extra for MUL_MAT / MUL_MAT_ID ops
    // whose src[0] lives in a CPU_AARCH64 buffer; nullptr for everything else.
    ggml::cpu::tensor_traits * get_tensor_traits(const struct ggml_tensor * op) override {
        const bool is_matmul = op->op == GGML_OP_MUL_MAT || op->op == GGML_OP_MUL_MAT_ID;
        if (!is_matmul) {
            return nullptr;
        }

        const struct ggml_tensor * src0 = op->src[0];
        if (!src0->buffer || src0->buffer->buft != ggml_backend_cpu_aarch64_buffer_type()) {
            return nullptr;
        }

        return static_cast<ggml::cpu::tensor_traits *>(src0->extra);
    }
};
}  // namespace ggml::cpu::aarch64
// Returns the process-wide CPU_AARCH64 buffer type.
// The descriptor is a function-local static built on first call; its context is
// a heap-allocated ggml::cpu::aarch64::extra_buffer_type that is never freed.
ggml_backend_buffer_type_t ggml_backend_cpu_aarch64_buffer_type(void) {
    static struct ggml_backend_buffer_type aarch64_buft = {
        /* .iface    = */ {
            /* .get_name       = */ ggml_backend_cpu_aarch64_buffer_type_get_name,
            /* .alloc_buffer   = */ ggml_backend_cpu_aarch64_buffer_type_alloc_buffer,
            /* .get_alignment  = */ ggml_backend_cpu_aarch64_buffer_type_get_alignment,
            /* .get_max_size   = */ nullptr,  // defaults to SIZE_MAX
            /* .get_alloc_size = */ nullptr,  // defaults to ggml_nbytes
            /* .is_host        = */ nullptr,
        },
        /* .device   = */ ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0),
        /* .context  = */ new ggml::cpu::aarch64::extra_buffer_type(),
    };

    return &aarch64_buft;
}

View File

@ -1,32 +1,8 @@
#pragma once #pragma once
#include "ggml-cpu-traits.h"
#include "ggml.h" #include "ggml.h"
// GGML internal header // GGML internal header
#ifdef __cplusplus ggml_backend_buffer_type_t ggml_backend_cpu_aarch64_buffer_type(void);
extern "C" {
#endif
// Quantization
void quantize_mat_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t nrows, int64_t n_per_row, int64_t blck_size_interleave);
// GEMV
void ggml_gemv_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
// GEMM
void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_aarch64_repack_tensor(struct ggml_tensor * cur, enum ggml_type repack_type, const void * data, size_t data_size);
enum ggml_type ggml_aarch64_get_optimal_repack_type(const struct ggml_tensor * cur);
#ifdef __cplusplus
}
#endif

View File

@ -0,0 +1,55 @@
#ifdef GGML_USE_CPU_HBM
#include "ggml-backend.h"
#include "ggml-backend-impl.h"
#include "ggml-cpu.h"
#include "ggml-impl.h"
#include "ggml-cpu-hbm.h"
// buffer type HBM
#include <hbwmalloc.h>
// Name reported for the HBM-backed CPU buffer type.
static const char * ggml_backend_cpu_hbm_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
    GGML_UNUSED(buft);

    return "CPU_HBM";
}
// free_buffer hook for HBM buffers: releases buffer->context with hbw_free.
// NOTE(review): presumably context is the pointer obtained from
// hbw_posix_memalign in the alloc hook - confirm against ggml_backend_cpu_buffer_from_ptr.
static void ggml_backend_cpu_hbm_buffer_free_buffer(ggml_backend_buffer_t buffer) {
    void * hbm_block = buffer->context;
    hbw_free(hbm_block);
}
// Allocates `size` bytes with hbw_posix_memalign (aligned as the CPU buffer type
// requires), wraps them in a CPU buffer, and overrides the buffer's type and
// free hook so the memory is later released with hbw_free.
// Logs an error and returns NULL when the allocation fails.
static ggml_backend_buffer_t ggml_backend_cpu_hbm_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft,
                                                                           size_t                     size) {
    void *       hbm_block = nullptr;
    const size_t alignment = ggml_backend_cpu_buffer_type_get_alignment(buft);

    if (hbw_posix_memalign(&hbm_block, alignment, size) != 0) {
        GGML_LOG_ERROR("failed to allocate HBM buffer of size %zu\n", size);
        return nullptr;
    }

    ggml_backend_buffer_t hbm_buffer = ggml_backend_cpu_buffer_from_ptr(hbm_block, size);
    hbm_buffer->buft              = buft;
    hbm_buffer->iface.free_buffer = ggml_backend_cpu_hbm_buffer_free_buffer;

    return hbm_buffer;
}
// Returns the process-wide "CPU_HBM" buffer type.
// The descriptor is a function-local static; alignment and is_host are shared
// with the regular CPU buffer type, allocation goes through the HBM hook.
ggml_backend_buffer_type_t ggml_backend_cpu_hbm_buffer_type(void) {
    static struct ggml_backend_buffer_type ggml_backend_cpu_buffer_type_hbm = {
        /* .iface    = */ {
            /* .get_name       = */ ggml_backend_cpu_hbm_buffer_type_get_name,
            /* .alloc_buffer   = */ ggml_backend_cpu_hbm_buffer_type_alloc_buffer,
            /* .get_alignment  = */ ggml_backend_cpu_buffer_type_get_alignment,
            /* .get_max_size   = */ nullptr,  // defaults to SIZE_MAX
            /* .get_alloc_size = */ nullptr,  // defaults to ggml_nbytes
            /* .is_host        = */ ggml_backend_cpu_buffer_type_is_host,
        },
        // The initializer is positional and the member after .iface is .device,
        // so both remaining members are spelled out to keep labels and slots in sync.
        // NOTE(review): the CPU_AARCH64 buffer type sets .device to CPU device 0;
        // confirm whether HBM should do the same. nullptr keeps existing behavior.
        /* .device   = */ nullptr,
        /* .context  = */ nullptr,
    };

    return &ggml_backend_cpu_buffer_type_hbm;
}
#endif

View File

@ -0,0 +1,8 @@
#pragma once
#include "ggml-backend.h"
#include "ggml.h"
// GGML CPU internal header
ggml_backend_buffer_type_t ggml_backend_cpu_hbm_buffer_type(void);

View File

@ -0,0 +1,36 @@
#include "ggml-cpu-traits.h"
#include "ggml-backend-impl.h"
#include "ggml-backend.h"
namespace ggml::cpu {

// Out-of-line definitions of the virtual destructors declared in ggml-cpu-traits.h.
tensor_traits::~tensor_traits() = default;

extra_buffer_type::~extra_buffer_type() = default;

}  // namespace ggml::cpu
bool ggml_cpu_extra_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * op) {
for (auto extra : ggml_backend_cpu_get_extra_buffers_type()) {
if (extra && extra->context) {
auto buf_extra = (ggml::cpu::extra_buffer_type *) extra->context;
auto tensor_traits = buf_extra->get_tensor_traits(op);
if (tensor_traits && tensor_traits->compute_forward(params, op)) {
return true;
}
}
}
return false;
}
bool ggml_cpu_extra_work_size(int n_threads, const struct ggml_tensor * op, size_t * size) {
for (auto extra : ggml_backend_cpu_get_extra_buffers_type()) {
if (extra && extra->context) {
auto buf_extra = (ggml::cpu::extra_buffer_type *) extra->context;
auto tensor_traits = buf_extra->get_tensor_traits(op);
if (tensor_traits && tensor_traits->work_size(n_threads, op, *size)) {
return true;
}
}
}
return false;
}

View File

@ -0,0 +1,38 @@
#pragma once
#include "ggml-backend-impl.h"
#include "ggml-cpu-impl.h"
#include "ggml.h"
#ifdef __cplusplus
# include <vector>
extern "C" {
#endif
// Both return true when the op is handled by one of the extra "accelerator" buffer types.
bool ggml_cpu_extra_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * op);
bool ggml_cpu_extra_work_size(int n_threads, const struct ggml_tensor * op, size_t * size);
#ifdef __cplusplus
}
namespace ggml::cpu {
// Per-tensor hooks for an extra buffer type.
// register in tensor->extra
class tensor_traits {
public:
virtual ~tensor_traits();
// Called through ggml_cpu_extra_work_size(); `size` is passed by reference.
// Returning true makes ggml_cpu_extra_work_size() return true for this op.
virtual bool work_size(int n_threads, const struct ggml_tensor * op, size_t & size) = 0;
// Called through ggml_cpu_extra_compute_forward().
// Returning true makes ggml_cpu_extra_compute_forward() return true for this op.
virtual bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * op) = 0;
};
// Per-buffer-type hooks. ggml_cpu_extra_compute_forward() and
// ggml_cpu_extra_work_size() read an instance from the buffer type's `context` pointer.
class extra_buffer_type {
public:
virtual ~extra_buffer_type();
// Whether this buffer type handles `op` on device `dev`.
virtual bool supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) = 0;
// Traits to use for `op`; callers treat nullptr as "not handled by this buffer type".
virtual tensor_traits * get_tensor_traits(const struct ggml_tensor * op) = 0;
};
} // namespace ggml::cpu
// implemented in ggml-cpu.cpp.
std::vector<ggml_backend_buffer_type_t> & ggml_backend_cpu_get_extra_buffers_type();
#endif

View File

@ -3,7 +3,7 @@
#include "ggml-backend-impl.h" #include "ggml-backend-impl.h"
#include "ggml-backend.h" #include "ggml-backend.h"
#include "ggml-cpu-aarch64.h" #include "ggml-cpu-traits.h"
#include "ggml-cpu-impl.h" #include "ggml-cpu-impl.h"
#include "ggml-cpu.h" #include "ggml-cpu.h"
#include "ggml-impl.h" #include "ggml-impl.h"
@ -224,10 +224,6 @@ typedef void * thread_ret_t;
typedef pthread_t ggml_thread_t; typedef pthread_t ggml_thread_t;
#ifdef GGML_USE_CPU_HBM
#include <hbwmalloc.h>
#endif
#if defined(__APPLE__) #if defined(__APPLE__)
#include <unistd.h> #include <unistd.h>
#include <mach/mach.h> #include <mach/mach.h>
@ -301,7 +297,6 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = {
}, },
[GGML_TYPE_Q8_0] = { [GGML_TYPE_Q8_0] = {
.from_float = quantize_row_q8_0, .from_float = quantize_row_q8_0,
.from_float_to_mat = quantize_mat_q8_0,
.vec_dot = ggml_vec_dot_q8_0_q8_0, .vec_dot = ggml_vec_dot_q8_0_q8_0,
.vec_dot_type = GGML_TYPE_Q8_0, .vec_dot_type = GGML_TYPE_Q8_0,
#if defined (__ARM_FEATURE_MATMUL_INT8) #if defined (__ARM_FEATURE_MATMUL_INT8)
@ -409,33 +404,6 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = {
.vec_dot_type = GGML_TYPE_BF16, .vec_dot_type = GGML_TYPE_BF16,
.nrows = 1, .nrows = 1,
}, },
[GGML_TYPE_Q4_0_4_4] = {
.from_float = NULL,
.vec_dot = NULL,
.vec_dot_type = GGML_TYPE_Q8_0,
.nrows = 1,
.ncols = 4,
.gemv = ggml_gemv_q4_0_4x4_q8_0,
.gemm = ggml_gemm_q4_0_4x4_q8_0,
},
[GGML_TYPE_Q4_0_4_8] = {
.from_float = NULL,
.vec_dot = NULL,
.vec_dot_type = GGML_TYPE_Q8_0,
.nrows = 1,
.ncols = 4,
.gemv = ggml_gemv_q4_0_4x8_q8_0,
.gemm = ggml_gemm_q4_0_4x8_q8_0,
},
[GGML_TYPE_Q4_0_8_8] = {
.from_float = NULL,
.vec_dot = NULL,
.vec_dot_type = GGML_TYPE_Q8_0,
.nrows = 1,
.ncols = 8,
.gemv = ggml_gemv_q4_0_8x8_q8_0,
.gemm = ggml_gemm_q4_0_8x8_q8_0,
},
[GGML_TYPE_TQ1_0] = { [GGML_TYPE_TQ1_0] = {
.from_float = quantize_row_tq1_0, .from_float = quantize_row_tq1_0,
.vec_dot = ggml_vec_dot_tq1_0_q8_K, .vec_dot = ggml_vec_dot_tq1_0_q8_K,
@ -448,15 +416,6 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = {
.vec_dot_type = GGML_TYPE_Q8_K, .vec_dot_type = GGML_TYPE_Q8_K,
.nrows = 1, .nrows = 1,
}, },
[GGML_TYPE_IQ4_NL_4_4] = {
.from_float = NULL,
.vec_dot = NULL,
.vec_dot_type = GGML_TYPE_Q8_0,
.nrows = 1,
.ncols = 4,
.gemv = ggml_gemv_iq4_nl_4x4_q8_0,
.gemm = ggml_gemm_iq4_nl_4x4_q8_0,
},
}; };
const struct ggml_type_traits_cpu * ggml_get_type_traits_cpu(enum ggml_type type) { const struct ggml_type_traits_cpu * ggml_get_type_traits_cpu(enum ggml_type type) {
@ -4509,9 +4468,6 @@ static void ggml_compute_forward_add(
case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ3_S: case GGML_TYPE_IQ3_S:
case GGML_TYPE_IQ2_S: case GGML_TYPE_IQ2_S:
case GGML_TYPE_Q4_0_4_4:
case GGML_TYPE_Q4_0_4_8:
case GGML_TYPE_Q4_0_8_8:
{ {
ggml_compute_forward_add_q_f32(params, dst); ggml_compute_forward_add_q_f32(params, dst);
} break; } break;
@ -4889,9 +4845,6 @@ static void ggml_compute_forward_add1(
case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ3_S: case GGML_TYPE_IQ3_S:
case GGML_TYPE_IQ2_S: case GGML_TYPE_IQ2_S:
case GGML_TYPE_Q4_0_4_4:
case GGML_TYPE_Q4_0_4_8:
case GGML_TYPE_Q4_0_8_8:
{ {
ggml_compute_forward_add1_q_f32(params, dst); ggml_compute_forward_add1_q_f32(params, dst);
} break; } break;
@ -5019,9 +4972,6 @@ static void ggml_compute_forward_acc(
case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ3_S: case GGML_TYPE_IQ3_S:
case GGML_TYPE_IQ2_S: case GGML_TYPE_IQ2_S:
case GGML_TYPE_Q4_0_4_4:
case GGML_TYPE_Q4_0_4_8:
case GGML_TYPE_Q4_0_8_8:
default: default:
{ {
GGML_ABORT("fatal error"); GGML_ABORT("fatal error");
@ -7437,27 +7387,9 @@ static void ggml_compute_forward_mul_mat(
const int ith = params->ith; const int ith = params->ith;
const int nth = params->nth; const int nth = params->nth;
enum ggml_type type = src0->type; enum ggml_type const vec_dot_type = type_traits_cpu[src0->type].vec_dot_type;
if (src0->buffer && ggml_backend_cpu_buft_is_aarch64(src0->buffer->buft)) {
type = (enum ggml_type)(intptr_t)src0->extra;
}
#if defined(__AMX_INT8__) && defined(__AVX512VNNI__)
if (src0->buffer && ggml_backend_amx_buft_is_amx(src0->buffer->buft)) {
ggml_backend_amx_mul_mat(params, dst);
return;
}
#endif
enum ggml_type const vec_dot_type = type_traits_cpu[type].vec_dot_type;
ggml_from_float_t const from_float = type_traits_cpu[vec_dot_type].from_float; ggml_from_float_t const from_float = type_traits_cpu[vec_dot_type].from_float;
ggml_from_float_to_mat_t const from_float_to_mat = type_traits_cpu[vec_dot_type].from_float_to_mat; int64_t const vec_dot_num_rows = type_traits_cpu[src0->type].nrows;
int64_t const vec_dot_num_rows = type_traits_cpu[type].nrows;
int64_t const matmul_num_cols = type_traits_cpu[type].ncols;
int64_t const blck_size_interleave = ggml_get_type_traits(type)->blck_size_interleave;
ggml_gemv_t const gemv = type_traits_cpu[type].gemv;
ggml_gemm_t const gemm = type_traits_cpu[type].gemm;
GGML_ASSERT(ne0 == ne01); GGML_ASSERT(ne0 == ne01);
GGML_ASSERT(ne1 == ne11); GGML_ASSERT(ne1 == ne11);
@ -7465,7 +7397,7 @@ static void ggml_compute_forward_mul_mat(
GGML_ASSERT(ne3 == ne13); GGML_ASSERT(ne3 == ne13);
// we don't support permuted src0 or src1 // we don't support permuted src0 or src1
GGML_ASSERT(nb00 == ggml_type_size(type)); GGML_ASSERT(nb00 == ggml_type_size(src0->type));
GGML_ASSERT(nb10 == ggml_type_size(src1->type)); GGML_ASSERT(nb10 == ggml_type_size(src1->type));
// dst cannot be transposed or permuted // dst cannot be transposed or permuted
@ -7477,6 +7409,7 @@ static void ggml_compute_forward_mul_mat(
// nb01 >= nb00 - src0 is not transposed // nb01 >= nb00 - src0 is not transposed
// compute by src0 rows // compute by src0 rows
// TODO: extract to "extra_op"
#if GGML_USE_LLAMAFILE #if GGML_USE_LLAMAFILE
// broadcast factors // broadcast factors
const int64_t r2 = ne12 / ne02; const int64_t r2 = ne12 / ne02;
@ -7487,15 +7420,15 @@ static void ggml_compute_forward_mul_mat(
if (src1_cont) { if (src1_cont) {
for (int64_t i13 = 0; i13 < ne13; i13++) for (int64_t i13 = 0; i13 < ne13; i13++)
for (int64_t i12 = 0; i12 < ne12; i12++) for (int64_t i12 = 0; i12 < ne12; i12++)
if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(type), if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(src0->type),
(const char *)src0->data + i12/r2*nb02 + i13/r3*nb03, (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03,
nb01/ggml_type_size(type), nb01/ggml_type_size(src0->type),
(const char *)src1->data + i12*nb12 + i13*nb13, (const char *)src1->data + i12*nb12 + i13*nb13,
nb11/ggml_type_size(src1->type), nb11/ggml_type_size(src1->type),
(char *)dst->data + i12*nb2 + i13*nb3, (char *)dst->data + i12*nb2 + i13*nb3,
nb1/ggml_type_size(dst->type), nb1/ggml_type_size(dst->type),
ith, nth, ith, nth,
type, src0->type,
src1->type, src1->type,
dst->type)) dst->type))
goto UseGgmlGemm1; goto UseGgmlGemm1;
@ -7516,16 +7449,7 @@ UseGgmlGemm1:;
for (int64_t i13 = 0; i13 < ne13; ++i13) { for (int64_t i13 = 0; i13 < ne13; ++i13) {
for (int64_t i12 = 0; i12 < ne12; ++i12) { for (int64_t i12 = 0; i12 < ne12; ++i12) {
int64_t i11_processed = 0; for (int64_t i11 = ith; i11 < ne11; i11 += nth) {
if ((ggml_n_dims(src1) == 2) && from_float_to_mat && gemm) {
for (int64_t i11 = ith * 4; i11 < ne11 - ne11 % 4; i11 += nth * 4) {
from_float_to_mat((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11),
(void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1),
4, ne10, blck_size_interleave);
}
i11_processed = ne11 - ne11 % 4;
}
for (int64_t i11 = i11_processed + ith; i11 < ne11; i11 += nth) {
from_float((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), from_float((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11),
(void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1), (void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1),
ne10); ne10);
@ -7548,15 +7472,15 @@ UseGgmlGemm1:;
for (int64_t i13 = 0; i13 < ne13; i13++) for (int64_t i13 = 0; i13 < ne13; i13++)
for (int64_t i12 = 0; i12 < ne12; i12++) for (int64_t i12 = 0; i12 < ne12; i12++)
if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(type), if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(src0->type),
(const char *)src0->data + i12/r2*nb02 + i13/r3*nb03, (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03,
nb01/ggml_type_size(type), nb01/ggml_type_size(src0->type),
(const char *)wdata + (i12*ne11 + i13*ne12*ne11)*row_size, (const char *)wdata + (i12*ne11 + i13*ne12*ne11)*row_size,
row_size/ggml_type_size(vec_dot_type), row_size/ggml_type_size(vec_dot_type),
(char *)dst->data + i12*nb2 + i13*nb3, (char *)dst->data + i12*nb2 + i13*nb3,
nb1/ggml_type_size(dst->type), nb1/ggml_type_size(dst->type),
ith, nth, ith, nth,
type, src0->type,
vec_dot_type, vec_dot_type,
dst->type)) dst->type))
goto UseGgmlGemm2; goto UseGgmlGemm2;
@ -7598,28 +7522,6 @@ UseGgmlGemm2:;
const int64_t dr0 = (nr0 + nchunk0 - 1) / nchunk0; const int64_t dr0 = (nr0 + nchunk0 - 1) / nchunk0;
const int64_t dr1 = (nr1 + nchunk1 - 1) / nchunk1; const int64_t dr1 = (nr1 + nchunk1 - 1) / nchunk1;
if ((ggml_n_dims(src0) == 2) && gemv) {
const void * src1_wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
const size_t src1_col_stride = ggml_is_contiguous(src1) || src1->type != vec_dot_type ? ggml_row_size(vec_dot_type, ne10) : nb11;
int64_t src0_start = (ith * ne01) / nth;
int64_t src0_end = ((ith + 1) * ne01) / nth;
src0_start = (src0_start % matmul_num_cols) ? src0_start + matmul_num_cols - (src0_start % matmul_num_cols): src0_start;
src0_end = (src0_end % matmul_num_cols) ? src0_end + matmul_num_cols - (src0_end % matmul_num_cols): src0_end;
if (src0_start >= src0_end) return;
// If there are more than three rows in src1, use gemm; otherwise, use gemv.
if (gemm && (ne11 > 3)) {
gemm(ne00, (float *)((char *) dst->data) + src0_start, ne01, (const char *) src0->data + src0_start * nb01,
(const char *) src1_wdata, ne11 - ne11 % 4, src0_end - src0_start);
}
for (int iter = gemm ? ne11 - ne11 % 4 : 0; iter < ne11; iter++) {
gemv(ne00, (float *)((char *) dst->data + (iter * nb1)) + src0_start, ne01,
(const char *) src0->data + src0_start * nb01, (const char *) src1_wdata + (src1_col_stride * iter), 1,
src0_end - src0_start);
}
return;
}
// The first chunk comes from our thread_id, the rest will get auto-assigned. // The first chunk comes from our thread_id, the rest will get auto-assigned.
int current_chunk = ith; int current_chunk = ith;
@ -7642,7 +7544,7 @@ UseGgmlGemm2:;
num_rows_per_vec_dot = 1; num_rows_per_vec_dot = 1;
} }
ggml_compute_forward_mul_mat_one_chunk(params, dst, type, num_rows_per_vec_dot, ir0_start, ir0_end, ir1_start, ir1_end); ggml_compute_forward_mul_mat_one_chunk(params, dst, src0->type, num_rows_per_vec_dot, ir0_start, ir0_end, ir1_start, ir1_end);
if (nth >= nchunk0 * nchunk1) { if (nth >= nchunk0 * nchunk1) {
break; break;
@ -7674,8 +7576,6 @@ static void ggml_compute_forward_mul_mat_id(
ggml_vec_dot_t const vec_dot = type_traits_cpu[type].vec_dot; ggml_vec_dot_t const vec_dot = type_traits_cpu[type].vec_dot;
enum ggml_type const vec_dot_type = type_traits_cpu[type].vec_dot_type; enum ggml_type const vec_dot_type = type_traits_cpu[type].vec_dot_type;
ggml_from_float_t const from_float = type_traits_cpu[vec_dot_type].from_float; ggml_from_float_t const from_float = type_traits_cpu[vec_dot_type].from_float;
int64_t const matmul_num_cols = type_traits_cpu[type].ncols;
ggml_gemv_t const gemv = type_traits_cpu[type].gemv;
// we don't support permuted src0 or src1 // we don't support permuted src0 or src1
GGML_ASSERT(nb00 == ggml_type_size(type)); GGML_ASSERT(nb00 == ggml_type_size(type));
@ -7761,34 +7661,6 @@ static void ggml_compute_forward_mul_mat_id(
const int64_t nr0 = ne01; // src0 rows const int64_t nr0 = ne01; // src0 rows
const int64_t nr1 = cne1; // src1 rows const int64_t nr1 = cne1; // src1 rows
if (((ggml_n_dims(src0) - 1) == 2) && gemv) {
int64_t src0_cur_start = (ith * ne01) / nth;
int64_t src0_cur_end = ((ith + 1) * ne01) / nth;
src0_cur_start = (src0_cur_start % matmul_num_cols) ? src0_cur_start + matmul_num_cols - (src0_cur_start % matmul_num_cols): src0_cur_start;
src0_cur_end = (src0_cur_end % matmul_num_cols) ? src0_cur_end + matmul_num_cols - (src0_cur_end % matmul_num_cols): src0_cur_end;
if (src0_cur_start >= src0_cur_end) return;
for (int ir1 = 0; ir1 < nr1; ir1++) {
struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, ir1);
const int id = row_mapping.i1; // selected expert index
const int64_t i11 = id % ne11;
const int64_t i12 = row_mapping.i2; // row index in src1
const int64_t i1 = id; // selected expert index
const int64_t i2 = i12; // row
const char * src1_col = (const char *) wdata +
(src1_cont || src1->type != vec_dot_type
? (i11 + i12 * ne11) * row_size
: (i11 * nb11 + i12 * nb12));
gemv(ne00, (float *)((char *) dst->data + (i1 * nb1 + i2 * nb2)) + src0_cur_start, ne01,
(const char *) src0_cur + src0_cur_start * nb01, src1_col, 1, src0_cur_end - src0_cur_start);
}
continue;
}
// distribute the thread work across the inner or outer loop based on which one is larger // distribute the thread work across the inner or outer loop based on which one is larger
const int64_t nth0 = nr0 > nr1 ? nth : 1; // parallelize by src0 rows const int64_t nth0 = nr0 > nr1 ? nth : 1; // parallelize by src0 rows
@ -8096,9 +7968,6 @@ static void ggml_compute_forward_out_prod(
case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ3_S: case GGML_TYPE_IQ3_S:
case GGML_TYPE_IQ2_S: case GGML_TYPE_IQ2_S:
case GGML_TYPE_Q4_0_4_4:
case GGML_TYPE_Q4_0_4_8:
case GGML_TYPE_Q4_0_8_8:
{ {
ggml_compute_forward_out_prod_q_f32(params, dst); ggml_compute_forward_out_prod_q_f32(params, dst);
} break; } break;
@ -8361,9 +8230,6 @@ static void ggml_compute_forward_set(
case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ3_S: case GGML_TYPE_IQ3_S:
case GGML_TYPE_IQ2_S: case GGML_TYPE_IQ2_S:
case GGML_TYPE_Q4_0_4_4:
case GGML_TYPE_Q4_0_4_8:
case GGML_TYPE_Q4_0_8_8:
default: default:
{ {
GGML_ABORT("fatal error"); GGML_ABORT("fatal error");
@ -8625,9 +8491,6 @@ static void ggml_compute_forward_get_rows(
case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ3_S: case GGML_TYPE_IQ3_S:
case GGML_TYPE_IQ2_S: case GGML_TYPE_IQ2_S:
case GGML_TYPE_Q4_0_4_4:
case GGML_TYPE_Q4_0_4_8:
case GGML_TYPE_Q4_0_8_8:
{ {
ggml_compute_forward_get_rows_q(params, dst); ggml_compute_forward_get_rows_q(params, dst);
} break; } break;
@ -9217,10 +9080,6 @@ static void ggml_compute_forward_clamp(
case GGML_TYPE_IQ3_S: case GGML_TYPE_IQ3_S:
case GGML_TYPE_IQ2_S: case GGML_TYPE_IQ2_S:
case GGML_TYPE_Q8_K: case GGML_TYPE_Q8_K:
case GGML_TYPE_Q4_0_4_4:
case GGML_TYPE_Q4_0_4_8:
case GGML_TYPE_Q4_0_8_8:
case GGML_TYPE_IQ4_NL_4_4:
case GGML_TYPE_I8: case GGML_TYPE_I8:
case GGML_TYPE_I16: case GGML_TYPE_I16:
case GGML_TYPE_I32: case GGML_TYPE_I32:
@ -12426,6 +12285,9 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
return; return;
} }
// extra_buffer op?
if (ggml_cpu_extra_compute_forward(params, tensor)) return;
switch (tensor->op) { switch (tensor->op) {
case GGML_OP_DUP: case GGML_OP_DUP:
{ {
@ -13373,6 +13235,8 @@ struct ggml_cplan ggml_graph_plan(
size_t cur = 0; size_t cur = 0;
if (!ggml_cpu_extra_work_size(n_threads, node, &cur)) {
switch (node->op) { switch (node->op) {
case GGML_OP_CPY: case GGML_OP_CPY:
case GGML_OP_DUP: case GGML_OP_DUP:
@ -13403,16 +13267,10 @@ struct ggml_cplan ggml_graph_plan(
} break; } break;
case GGML_OP_MUL_MAT: case GGML_OP_MUL_MAT:
{ {
#if defined(__AMX_INT8__) && defined(__AVX512VNNI__)
if (node->src[0]->buffer && ggml_backend_amx_buft_is_amx(node->src[0]->buffer->buft)) {
cur = ggml_backend_amx_desired_wsize(node);
}
#endif
const enum ggml_type vec_dot_type = type_traits_cpu[node->src[0]->type].vec_dot_type; const enum ggml_type vec_dot_type = type_traits_cpu[node->src[0]->type].vec_dot_type;
if (node->src[1]->type != vec_dot_type) { if (node->src[1]->type != vec_dot_type) {
size_t cur2 = ggml_row_size(vec_dot_type, ggml_nelements(node->src[1])); cur = ggml_row_size(vec_dot_type, ggml_nelements(node->src[1]));
cur = MAX(cur, cur2);
} }
} break; } break;
case GGML_OP_MUL_MAT_ID: case GGML_OP_MUL_MAT_ID:
@ -13449,7 +13307,6 @@ struct ggml_cplan ggml_graph_plan(
const int64_t ne00 = node->src[0]->ne[0]; // K const int64_t ne00 = node->src[0]->ne[0]; // K
const int64_t ne01 = node->src[0]->ne[1]; // Cout const int64_t ne01 = node->src[0]->ne[1]; // Cout
const int64_t ne02 = node->src[0]->ne[2]; // Cin const int64_t ne02 = node->src[0]->ne[2]; // Cin
const int64_t ne10 = node->src[1]->ne[0]; // L const int64_t ne10 = node->src[1]->ne[0]; // L
const int64_t ne11 = node->src[1]->ne[1]; // Cin const int64_t ne11 = node->src[1]->ne[1]; // Cin
@ -13514,6 +13371,7 @@ struct ggml_cplan ggml_graph_plan(
default: default:
break; break;
} }
}
work_size = MAX(work_size, cur); work_size = MAX(work_size, cur);
} }

View File

@ -2,12 +2,18 @@
#include "ggml-backend-impl.h" #include "ggml-backend-impl.h"
#include "ggml-cpu.h" #include "ggml-cpu.h"
#include "ggml-cpu-aarch64.h" #include "ggml-cpu-aarch64.h"
#include "ggml-cpu-traits.h"
#include "ggml-impl.h" #include "ggml-impl.h"
#include "amx/amx.h" #include "amx/amx.h"
#include <cctype> #include <cctype>
#include <string> #include <string>
#include <vector> #include <vector>
#ifdef GGML_USE_CPU_HBM
#include "ggml-cpu-hbm.h"
#endif
#if defined(__APPLE__) #if defined(__APPLE__)
#include <sys/types.h> #include <sys/types.h>
#include <sys/sysctl.h> #include <sys/sysctl.h>
@ -23,115 +29,7 @@
// ggml-backend interface // ggml-backend interface
#ifdef GGML_USE_CPU_HBM std::vector<ggml_backend_buffer_type_t>& ggml_backend_cpu_get_extra_buffers_type() {
// buffer type HBM
#include <hbwmalloc.h>
static const char * ggml_backend_cpu_hbm_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
return "CPU_HBM";
GGML_UNUSED(buft);
}
static void ggml_backend_cpu_hbm_buffer_free_buffer(ggml_backend_buffer_t buffer) {
hbw_free(buffer->context);
}
static ggml_backend_buffer_t ggml_backend_cpu_hbm_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
void * ptr;
int result = hbw_posix_memalign(&ptr, ggml_backend_cpu_buffer_type_get_alignment(buft), size);
if (result != 0) {
GGML_LOG_ERROR("failed to allocate HBM buffer of size %zu\n", size);
return NULL;
}
ggml_backend_buffer_t buffer = ggml_backend_cpu_buffer_from_ptr(ptr, size);
buffer->buft = buft;
buffer->iface.free_buffer = ggml_backend_cpu_hbm_buffer_free_buffer;
return buffer;
}
ggml_backend_buffer_type_t ggml_backend_cpu_hbm_buffer_type(void) {
static struct ggml_backend_buffer_type ggml_backend_cpu_buffer_type_hbm = {
/* .iface = */ {
/* .get_name = */ ggml_backend_cpu_hbm_buffer_type_get_name,
/* .alloc_buffer = */ ggml_backend_cpu_hbm_buffer_type_alloc_buffer,
/* .get_alignment = */ ggml_backend_cpu_buffer_type_get_alignment,
/* .get_max_size = */ NULL, // defaults to SIZE_MAX
/* .get_alloc_size = */ NULL, // defaults to ggml_nbytes
/* .is_host = */ ggml_backend_cpu_buffer_type_is_host,
},
/* .context = */ NULL,
};
return &ggml_backend_cpu_buffer_type_hbm;
}
#endif
// buffer type AARCH64
static void ggml_backend_cpu_aarch64_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
tensor->extra = (void *)ggml_aarch64_get_optimal_repack_type(tensor); // NOLINT
GGML_UNUSED(buffer);
}
static void ggml_backend_cpu_aarch64_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
GGML_ASSERT(offset == 0);
GGML_ASSERT(size == ggml_nbytes(tensor));
enum ggml_type repack_type = (enum ggml_type)(intptr_t)tensor->extra;
ggml_aarch64_repack_tensor(tensor, repack_type, data, size);
GGML_UNUSED(buffer);
}
static const char * ggml_backend_cpu_aarch64_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
return "CPU_AARCH64";
GGML_UNUSED(buft);
}
static ggml_backend_buffer_t ggml_backend_cpu_aarch64_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
auto * buffer = ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size);
if (buffer == NULL) {
return NULL;
}
buffer->buft = buft;
buffer->iface.init_tensor = ggml_backend_cpu_aarch64_buffer_init_tensor;
buffer->iface.set_tensor = ggml_backend_cpu_aarch64_buffer_set_tensor;
return buffer;
}
ggml_backend_buffer_type_t ggml_backend_cpu_aarch64_buffer_type(void) {
static struct ggml_backend_buffer_type ggml_backend_cpu_buffer_type_aarch64 = {
/* .iface = */ {
/* .get_name = */ ggml_backend_cpu_aarch64_buffer_type_get_name,
/* .alloc_buffer = */ ggml_backend_cpu_aarch64_buffer_type_alloc_buffer,
/* .get_alignment = */ ggml_backend_cpu_buffer_type()->iface.get_alignment,
/* .get_max_size = */ NULL, // defaults to SIZE_MAX
/* .get_alloc_size = */ NULL, // defaults to ggml_nbytes
/* .is_host = */ NULL,
},
/* .device = */ ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0),
/* .context = */ NULL,
};
return &ggml_backend_cpu_buffer_type_aarch64;
}
bool ggml_backend_cpu_buft_is_aarch64(ggml_backend_buffer_type_t buft) {
return buft == ggml_backend_cpu_aarch64_buffer_type();
}
static ggml_backend_buffer_type_t * ggml_backend_cpu_get_extra_bufts(ggml_backend_dev_t device) {
static std::vector<ggml_backend_buffer_type_t> bufts = []() { static std::vector<ggml_backend_buffer_type_t> bufts = []() {
std::vector<ggml_backend_buffer_type_t> bufts; std::vector<ggml_backend_buffer_type_t> bufts;
@ -152,11 +50,22 @@ static ggml_backend_buffer_type_t * ggml_backend_cpu_get_extra_bufts(ggml_backen
return bufts; return bufts;
}(); }();
return bufts.data(); return bufts;
}
// Device-level accessor: exposes the extra buffer type list as a raw array
// (the storage of the shared vector). The device argument is ignored.
static ggml_backend_buffer_type_t * ggml_backend_cpu_device_get_extra_buffers_type(ggml_backend_dev_t device) {
    GGML_UNUSED(device);

    std::vector<ggml_backend_buffer_type_t> & extra_bufts = ggml_backend_cpu_get_extra_buffers_type();
    return extra_bufts.data();
}
// Returns true when `buft` is one of the non-null entries in the CPU backend's
// list of extra buffer types.
static bool ggml_backend_cpu_is_extra_buffer_type(ggml_backend_buffer_type_t buft) {
    for (ggml_backend_buffer_type_t candidate : ggml_backend_cpu_get_extra_buffers_type()) {
        // null entries never match, even when buft itself is null
        if (candidate == nullptr) {
            continue;
        }
        if (candidate == buft) {
            return true;
        }
    }

    return false;
}
// CPU backend - backend (stream) // CPU backend - backend (stream)
struct ggml_backend_cpu_context { struct ggml_backend_cpu_context {
@ -465,25 +374,19 @@ static bool ggml_backend_cpu_device_supports_op(ggml_backend_dev_t dev, const st
return true; return true;
} }
if (src0 && src0->buffer && ggml_backend_cpu_buft_is_aarch64(src0->buffer->buft)) { // extra_buffer_op?
if (op->op != GGML_OP_MUL_MAT || src0->type == ggml_aarch64_get_optimal_repack_type(src0)) { for (auto extra : ggml_backend_cpu_get_extra_buffers_type()) {
return false; if (extra) {
auto buf_extra = (ggml::cpu::extra_buffer_type*) extra->context;
if (buf_extra && buf_extra->supports_op(dev, op)) {
return true;
}
} }
} }
#if defined(__AMX_INT8__) && defined(__AVX512VNNI__) // the other case need host buffer.
if (src0 && src0->buffer && ggml_backend_amx_buft_is_amx(src0->buffer->buft)) { for (int i = 0; i < GGML_MAX_SRC; i++) {
return ggml_backend_amx_device_supports_op(op); if (op->src[i] && op->src[i]->buffer && !ggml_backend_buft_is_host(op->src[i]->buffer->buft)) {
}
for (int i = 1; i < GGML_MAX_SRC; i++) {
if (op->src[i] && op->src[i]->buffer && ggml_backend_amx_buft_is_amx(op->src[i]->buffer->buft)) {
return false;
}
}
#endif
for (int i = 1; i < GGML_MAX_SRC; i++) {
if (op->src[i] && op->src[i]->buffer && ggml_backend_cpu_buft_is_aarch64(op->src[i]->buffer->buft)) {
return false; return false;
} }
} }
@ -506,19 +409,10 @@ static bool ggml_backend_cpu_device_supports_op(ggml_backend_dev_t dev, const st
default: default:
return true; return true;
} }
GGML_UNUSED(dev);
} }
static bool ggml_backend_cpu_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) { static bool ggml_backend_cpu_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
bool supported = ggml_backend_buft_is_host(buft) || ggml_backend_cpu_buft_is_aarch64(buft); return ggml_backend_buft_is_host(buft) || ggml_backend_cpu_is_extra_buffer_type(buft);
#if defined(__AMX_INT8__) && defined(__AVX512VNNI__)
supported = supported || ggml_backend_amx_buft_is_amx(buft);
#endif
return supported;
GGML_UNUSED(dev); GGML_UNUSED(dev);
} }
@ -666,10 +560,12 @@ static ggml_backend_feature * ggml_backend_cpu_get_features(ggml_backend_reg_t r
static void * ggml_backend_cpu_get_proc_address(ggml_backend_reg_t reg, const char * name) { static void * ggml_backend_cpu_get_proc_address(ggml_backend_reg_t reg, const char * name) {
if (strcmp(name, "ggml_backend_set_n_threads") == 0) { if (strcmp(name, "ggml_backend_set_n_threads") == 0) {
return (void *)ggml_backend_cpu_set_n_threads; ggml_backend_set_n_threads_t fct = ggml_backend_cpu_set_n_threads;
return (void *)fct;
} }
if (strcmp(name, "ggml_backend_dev_get_extra_bufts") == 0) { if (strcmp(name, "ggml_backend_dev_get_extra_bufts") == 0) {
return (void *)ggml_backend_cpu_get_extra_bufts; ggml_backend_dev_get_extra_bufts_t fct = ggml_backend_cpu_device_get_extra_buffers_type;
return (void *)fct;
} }
if (strcmp(name, "ggml_backend_get_features") == 0) { if (strcmp(name, "ggml_backend_get_features") == 0) {
return (void *)ggml_backend_cpu_get_features; return (void *)ggml_backend_cpu_get_features;

View File

@ -3210,7 +3210,7 @@ static void * ggml_backend_cuda_reg_get_proc_address(ggml_backend_reg_t reg, con
static const ggml_backend_reg_i ggml_backend_cuda_reg_interface = { static const ggml_backend_reg_i ggml_backend_cuda_reg_interface = {
/* .get_name = */ ggml_backend_cuda_reg_get_name, /* .get_name = */ ggml_backend_cuda_reg_get_name,
/* .get_device_count = */ ggml_backend_cuda_reg_get_device_count, /* .get_device_count = */ ggml_backend_cuda_reg_get_device_count,
/* .get_device_get = */ ggml_backend_cuda_reg_get_device, /* .get_device = */ ggml_backend_cuda_reg_get_device,
/* .get_proc_address = */ ggml_backend_cuda_reg_get_proc_address, /* .get_proc_address = */ ggml_backend_cuda_reg_get_proc_address,
}; };

View File

@ -57,7 +57,7 @@ static __global__ void mul_mat_vec(
if (block_size > WARP_SIZE) { if (block_size > WARP_SIZE) {
buf_iw[tid/WARP_SIZE] = sumf; buf_iw[tid/WARP_SIZE] = sumf;
__syncthreads(); __syncthreads();
if (tid > WARP_SIZE) { if (tid >= WARP_SIZE) {
return; return;
} }
sumf = buf_iw[tid]; sumf = buf_iw[tid];

View File

@ -510,6 +510,35 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
#endif #endif
NSString * path_lib = [bundle pathForResource:@"default" ofType:@"metallib"]; NSString * path_lib = [bundle pathForResource:@"default" ofType:@"metallib"];
// Fallback lookup: the bundle query above returned no path, so look for a
// "default.metallib" file sitting next to the running executable.
if (path_lib == nil) {
// Try to find the resource in the directory where the current binary is located.
// NOTE(review): arguments[0] is the path the process was launched with; it may be
// relative to the working directory - confirm that is acceptable here.
NSString * current_binary = [[NSProcessInfo processInfo] arguments][0];
NSString * bin_dir = [current_binary stringByDeletingLastPathComponent];
NSString * default_metallib_path = [NSString pathWithComponents:@[bin_dir, @"default.metallib"]];
if ([[NSFileManager defaultManager] isReadableFileAtPath:default_metallib_path]) {
GGML_LOG_INFO("%s: found '%s'\n", __func__, [default_metallib_path UTF8String]);
// `error` is declared outside this span; it is only used as the out-parameter of the
// NSFileManager calls below and is not inspected here.
NSDictionary * atts = [[NSFileManager defaultManager] attributesOfItemAtPath:default_metallib_path error:&error];
// NOTE(review): `==` compares the two NSString object pointers, not their contents -
// confirm this is reliable for the NSFileType constants.
if (atts && atts[NSFileType] == NSFileTypeSymbolicLink) {
// Optionally, if this is a symlink, try to resolve it.
default_metallib_path = [[NSFileManager defaultManager] destinationOfSymbolicLinkAtPath:default_metallib_path error:&error];
if (default_metallib_path && [default_metallib_path length] > 0 && ![[default_metallib_path substringToIndex:1] isEqualToString:@"/"]) {
// It is a relative path, adding the binary directory as directory prefix.
default_metallib_path = [NSString pathWithComponents:@[bin_dir, default_metallib_path]];
}
if (!default_metallib_path || ![[NSFileManager defaultManager] isReadableFileAtPath:default_metallib_path]) {
// Link to the resource could not be resolved.
default_metallib_path = nil;
} else {
GGML_LOG_INFO("%s: symlink resolved '%s'\n", __func__, [default_metallib_path UTF8String]);
}
}
} else {
// The resource couldn't be found in the binary's directory.
default_metallib_path = nil;
}
// Either a readable path to the library or nil; nil leaves path_lib unset as before.
path_lib = default_metallib_path;
}
if (try_metallib && path_lib != nil) { if (try_metallib && path_lib != nil) {
// pre-compiled library found // pre-compiled library found
NSURL * libURL = [NSURL fileURLWithPath:path_lib]; NSURL * libURL = [NSURL fileURLWithPath:path_lib];

View File

@ -5220,15 +5220,6 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte
{ {
VALIDATE_ROW_DATA_D_F16_IMPL(block_iq4_nl, data, nb); VALIDATE_ROW_DATA_D_F16_IMPL(block_iq4_nl, data, nb);
} break; } break;
case GGML_TYPE_Q4_0_4_4:
case GGML_TYPE_Q4_0_4_8:
{
VALIDATE_ROW_DATA_DVEC_F16_IMPL(block_q4_0x4, data, nbytes / sizeof(block_q4_0x4), 4);
} break;
case GGML_TYPE_Q4_0_8_8:
{
VALIDATE_ROW_DATA_DVEC_F16_IMPL(block_q4_0x8, data, nbytes / sizeof(block_q4_0x8), 8);
} break;
case GGML_TYPE_I8: case GGML_TYPE_I8:
case GGML_TYPE_I16: case GGML_TYPE_I16:

View File

@ -4630,7 +4630,7 @@ static void *ggml_backend_sycl_reg_get_proc_address(ggml_backend_reg_t reg, cons
static const ggml_backend_reg_i ggml_backend_sycl_reg_interface = { static const ggml_backend_reg_i ggml_backend_sycl_reg_interface = {
/* .get_name = */ ggml_backend_sycl_reg_get_name, /* .get_name = */ ggml_backend_sycl_reg_get_name,
/* .get_device_count = */ ggml_backend_sycl_reg_get_device_count, /* .get_device_count = */ ggml_backend_sycl_reg_get_device_count,
/* .get_device_get = */ ggml_backend_sycl_reg_get_device, /* .get_device = */ ggml_backend_sycl_reg_get_device,
/* .get_proc_address = */ ggml_backend_sycl_reg_get_proc_address, /* .get_proc_address = */ ggml_backend_sycl_reg_get_proc_address,
}; };

View File

@ -8,6 +8,20 @@ if (Vulkan_FOUND)
../../include/ggml-vulkan.h ../../include/ggml-vulkan.h
) )
# Compile a test shader to determine whether GL_NV_cooperative_matrix2 is supported.
# If it's not, there will be an error to stderr.
# If it's supported, set a define to indicate that we should compile those shaders.
# stdout is captured in glslc_output and is not inspected by this check.
execute_process(COMMAND ${Vulkan_GLSLC_EXECUTABLE} -o - -fshader-stage=compute --target-env=vulkan1.3 "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_coopmat2_support.comp"
OUTPUT_VARIABLE glslc_output
ERROR_VARIABLE glslc_error)
# The expansion is quoted on purpose: glslc_error is empty when the compile succeeds,
# and an unquoted empty expansion would leave MATCHES without a left-hand operand
# (stderr text containing ';' would likewise be split into several arguments).
if ("${glslc_error}" MATCHES ".*extension not supported: GL_NV_cooperative_matrix2.*")
message(STATUS "GL_NV_cooperative_matrix2 not supported by glslc")
else()
message(STATUS "GL_NV_cooperative_matrix2 supported by glslc")
add_compile_definitions(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
endif()
target_link_libraries(ggml-vulkan PRIVATE Vulkan::Vulkan) target_link_libraries(ggml-vulkan PRIVATE Vulkan::Vulkan)
target_include_directories(ggml-vulkan PRIVATE ${CMAKE_CURRENT_BINARY_DIR}) target_include_directories(ggml-vulkan PRIVATE ${CMAKE_CURRENT_BINARY_DIR})

File diff suppressed because it is too large Load Diff

View File

@ -7,6 +7,12 @@
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require #extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
#endif #endif
#ifdef COOPMAT
#extension GL_KHR_cooperative_matrix : enable
#extension GL_KHR_memory_scope_semantics : enable
#extension GL_KHR_shader_subgroup_basic : enable
#endif
#ifdef MUL_MAT_ID #ifdef MUL_MAT_ID
#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require #extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
#endif #endif
@ -57,6 +63,7 @@ layout (push_constant) uniform parameter
#endif #endif
} p; } p;
layout (constant_id = 0) const uint BLOCK_SIZE = 64;
layout (constant_id = 1) const uint BM = 64; layout (constant_id = 1) const uint BM = 64;
layout (constant_id = 2) const uint BN = 64; layout (constant_id = 2) const uint BN = 64;
layout (constant_id = 3) const uint BK = 16; // Assumed to be 32 if working with a quant layout (constant_id = 3) const uint BK = 16; // Assumed to be 32 if working with a quant
@ -65,13 +72,26 @@ layout (constant_id = 5) const uint WN = 32;
layout (constant_id = 6) const uint WMITER = 2; layout (constant_id = 6) const uint WMITER = 2;
layout (constant_id = 7) const uint TM = 4; layout (constant_id = 7) const uint TM = 4;
layout (constant_id = 8) const uint TN = 2; layout (constant_id = 8) const uint TN = 2;
layout (constant_id = 9) const uint WARP = 32; layout (constant_id = 9) const uint TK = 1; // Only needed for coopmat
layout (constant_id = 10) const uint WARP = 32;
shared FLOAT_TYPE buf_a[BM * (BK+1)]; #ifdef COOPMAT
shared FLOAT_TYPE buf_b[BN * (BK+1)]; #define SHMEM_STRIDE (BK + 8)
#else
#define SHMEM_STRIDE (BK + 1)
#endif
shared FLOAT_TYPE buf_a[BM * SHMEM_STRIDE];
shared FLOAT_TYPE buf_b[BN * SHMEM_STRIDE];
#ifdef MUL_MAT_ID #ifdef MUL_MAT_ID
shared u16vec2 row_ids[3072]; shared u16vec2 row_ids[3072];
#endif // MUL_MAT_ID
#define NUM_WARPS (BLOCK_SIZE / WARP)
#ifdef COOPMAT
shared ACC_TYPE coopmat_stage[TM * TN * NUM_WARPS];
#endif #endif
void main() { void main() {
@ -98,17 +118,32 @@ void main() {
const uint ik = gl_WorkGroupID.x / blocks_m; const uint ik = gl_WorkGroupID.x / blocks_m;
const uint ic = gl_WorkGroupID.y; const uint ic = gl_WorkGroupID.y;
const uint warp_i = gl_LocalInvocationID.x / WARP;
const uint warp_r = warp_i % (BM / WM);
const uint warp_c = warp_i / (BM / WM);
const uint WNITER = (WM * WN) / (WARP * TM * TN * WMITER); const uint WNITER = (WM * WN) / (WARP * TM * TN * WMITER);
const uint WSUBM = WM / WMITER; const uint WSUBM = WM / WMITER;
const uint WSUBN = WN / WNITER; const uint WSUBN = WN / WNITER;
#ifdef COOPMAT
const uint warp_i = gl_SubgroupID;
const uint tiw = gl_SubgroupInvocationID;
const uint cms_per_row = WM / TM;
const uint cms_per_col = WN / TN;
const uint storestride = WARP / TM;
const uint store_r = tiw % TM;
const uint store_c = tiw / TM;
#else
const uint warp_i = gl_LocalInvocationID.x / WARP;
const uint tiw = gl_LocalInvocationID.x % WARP; const uint tiw = gl_LocalInvocationID.x % WARP;
const uint tiwr = tiw % (WSUBM / TM); const uint tiwr = tiw % (WSUBM / TM);
const uint tiwc = tiw / (WSUBM / TM); const uint tiwc = tiw / (WSUBM / TM);
#endif
const uint warp_r = warp_i % (BM / WM);
const uint warp_c = warp_i / (BM / WM);
const uint loadr_a = gl_LocalInvocationID.x % (BK / LOAD_VEC_A); const uint loadr_a = gl_LocalInvocationID.x % (BK / LOAD_VEC_A);
const uint loadc_a = gl_LocalInvocationID.x / (BK / LOAD_VEC_A); const uint loadc_a = gl_LocalInvocationID.x / (BK / LOAD_VEC_A);
@ -156,21 +191,31 @@ void main() {
uint pos_b = (batch_idx * p.batch_stride_b + ic * BN * p.stride_b + start_k) / LOAD_VEC_B; uint pos_b = (batch_idx * p.batch_stride_b + ic * BN * p.stride_b + start_k) / LOAD_VEC_B;
#endif #endif
float sums[WMITER * TM * WNITER * TN]; #ifdef COOPMAT
coopmat<float16_t, gl_ScopeSubgroup, TM, TK, gl_MatrixUseA> cache_a;
coopmat<float16_t, gl_ScopeSubgroup, TK, TN, gl_MatrixUseB> cache_b;
coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> sums[cms_per_row * cms_per_col];
[[unroll]] for (uint i = 0; i < cms_per_row * cms_per_col; i++) {
sums[i] = coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(0.0f);
}
#else
ACC_TYPE sums[WMITER * TM * WNITER * TN];
FLOAT_TYPE cache_a[WMITER * TM]; FLOAT_TYPE cache_a[WMITER * TM];
FLOAT_TYPE cache_b[WNITER * TN]; FLOAT_TYPE cache_b[WNITER * TN];
[[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) { [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) {
sums[i] = 0.0f; sums[i] = ACC_TYPE(0.0f);
} }
#endif
[[unroll]] for (uint block = start_k; block < end_k; block += BK) { for (uint block = start_k; block < end_k; block += BK) {
[[unroll]] for (uint l = 0; l < BM; l += loadstride_a) { [[unroll]] for (uint l = 0; l < BM; l += loadstride_a) {
#if defined(DATA_A_F32) || defined(DATA_A_F16) #if defined(DATA_A_F32) || defined(DATA_A_F16)
#if LOAD_VEC_A == 8 #if LOAD_VEC_A == 8
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
buf_a[buf_idx ] = FLOAT_TYPE(data_a[idx][0].x); buf_a[buf_idx ] = FLOAT_TYPE(data_a[idx][0].x);
buf_a[buf_idx + 1] = FLOAT_TYPE(data_a[idx][0].y); buf_a[buf_idx + 1] = FLOAT_TYPE(data_a[idx][0].y);
buf_a[buf_idx + 2] = FLOAT_TYPE(data_a[idx][0].z); buf_a[buf_idx + 2] = FLOAT_TYPE(data_a[idx][0].z);
@ -181,21 +226,21 @@ void main() {
buf_a[buf_idx + 7] = FLOAT_TYPE(data_a[idx][1].w); buf_a[buf_idx + 7] = FLOAT_TYPE(data_a[idx][1].w);
#elif LOAD_VEC_A == 4 #elif LOAD_VEC_A == 4
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
buf_a[buf_idx ] = FLOAT_TYPE(data_a[idx].x); buf_a[buf_idx ] = FLOAT_TYPE(data_a[idx].x);
buf_a[buf_idx + 1] = FLOAT_TYPE(data_a[idx].y); buf_a[buf_idx + 1] = FLOAT_TYPE(data_a[idx].y);
buf_a[buf_idx + 2] = FLOAT_TYPE(data_a[idx].z); buf_a[buf_idx + 2] = FLOAT_TYPE(data_a[idx].z);
buf_a[buf_idx + 3] = FLOAT_TYPE(data_a[idx].w); buf_a[buf_idx + 3] = FLOAT_TYPE(data_a[idx].w);
#else #else
if (ir * BM + loadc_a + l < p.M && block + loadr_a < end_k) { if (ir * BM + loadc_a + l < p.M && block + loadr_a < end_k) {
buf_a[(loadc_a + l) * (BK+1) + loadr_a] = FLOAT_TYPE(data_a[pos_a + (loadc_a + l) * p.stride_a + loadr_a]); buf_a[(loadc_a + l) * SHMEM_STRIDE + loadr_a] = FLOAT_TYPE(data_a[pos_a + (loadc_a + l) * p.stride_a + loadr_a]);
} else { } else {
buf_a[(loadc_a + l) * (BK+1) + loadr_a] = FLOAT_TYPE(0.0f); buf_a[(loadc_a + l) * SHMEM_STRIDE + loadr_a] = FLOAT_TYPE(0.0f);
} }
#endif #endif
#elif defined(DATA_A_Q4_0) #elif defined(DATA_A_Q4_0)
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a;
const uint ib = idx / 16; const uint ib = idx / 16;
const uint iqs = idx & 0xF; const uint iqs = idx & 0xF;
@ -208,7 +253,7 @@ void main() {
buf_a[buf_idx + 16] = FLOAT_TYPE(v.y); buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);
#elif defined(DATA_A_Q4_1) #elif defined(DATA_A_Q4_1)
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a;
const uint ib = idx / 16; const uint ib = idx / 16;
const uint iqs = idx & 0xF; const uint iqs = idx & 0xF;
@ -222,7 +267,7 @@ void main() {
buf_a[buf_idx + 16] = FLOAT_TYPE(v.y); buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);
#elif defined(DATA_A_Q5_0) #elif defined(DATA_A_Q5_0)
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a;
const uint ib = idx / 16; const uint ib = idx / 16;
const uint iqs = idx & 0xF; const uint iqs = idx & 0xF;
@ -237,7 +282,7 @@ void main() {
buf_a[buf_idx + 16] = FLOAT_TYPE(v.y); buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);
#elif defined(DATA_A_Q5_1) #elif defined(DATA_A_Q5_1)
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a;
const uint ib = idx / 16; const uint ib = idx / 16;
const uint iqs = idx & 0xF; const uint iqs = idx & 0xF;
@ -253,7 +298,7 @@ void main() {
buf_a[buf_idx + 16] = FLOAT_TYPE(v.y); buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);
#elif defined(DATA_A_Q8_0) #elif defined(DATA_A_Q8_0)
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
const uint ib = idx / 16; const uint ib = idx / 16;
const uint iqs = (idx & 0xF) * 2; const uint iqs = (idx & 0xF) * 2;
@ -265,7 +310,7 @@ void main() {
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
#elif defined(DATA_A_Q2_K) #elif defined(DATA_A_Q2_K)
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
const uint ib = idx / 128; // 2 values per idx const uint ib = idx / 128; // 2 values per idx
const uint iqs = idx % 128; // 0..127 const uint iqs = idx % 128; // 0..127
@ -284,7 +329,7 @@ void main() {
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
#elif defined(DATA_A_Q3_K) #elif defined(DATA_A_Q3_K)
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
const uint ib = idx / 128; // 2 values per idx const uint ib = idx / 128; // 2 values per idx
const uint iqs = idx % 128; // 0..127 const uint iqs = idx % 128; // 0..127
@ -308,7 +353,7 @@ void main() {
buf_a[buf_idx + 1] = FLOAT_TYPE(dl * float(int8_t((data_a[ib].qs[qsi + 1] >> qsshift) & 3) - (((data_a[ib].hmask[hmi + 1] & m) != 0) ? 0 : 4))); buf_a[buf_idx + 1] = FLOAT_TYPE(dl * float(int8_t((data_a[ib].qs[qsi + 1] >> qsshift) & 3) - (((data_a[ib].hmask[hmi + 1] & m) != 0) ? 0 : 4)));
#elif defined(DATA_A_Q4_K) #elif defined(DATA_A_Q4_K)
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
const uint ib = idx / 128; // 2 values per idx const uint ib = idx / 128; // 2 values per idx
const uint iqs = idx % 128; // 0..127 const uint iqs = idx % 128; // 0..127
@ -320,15 +365,20 @@ void main() {
const vec2 loadd = vec2(data_a[ib].d); const vec2 loadd = vec2(data_a[ib].d);
uint8_t sc; const uint scidx0 = (is < 4) ? is : (is + 4);
uint8_t mbyte; const uint scidx1 = (is < 4) ? is : (is - 4);
if (is < 4) { const uint scidxmask1 = (is < 4) ? 0x30 : 0xC0;
sc = uint8_t(data_a[ib].scales[is ] & 63); const uint scidxshift1 = (is < 4) ? 0 : 2;
mbyte = uint8_t(data_a[ib].scales[is + 4] & 63); const uint mbidx0 = is + 4;
} else { const uint mbidx1 = (is < 4) ? is + 4 : is;
sc = uint8_t((data_a[ib].scales[is + 4] & 0xF) | ((data_a[ib].scales[is - 4] >> 6) << 4)); const uint mbidxmask0 = (is < 4) ? 0xF : 0xF0;
mbyte = uint8_t((data_a[ib].scales[is + 4] >> 4) | ((data_a[ib].scales[is ] >> 6) << 4)); const uint mbidxshift0 = (is < 4) ? 0 : 4;
} const uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0;
const uint mbidxshift1 = (is < 4) ? 0 : 2;
const uint8_t sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1));
const uint8_t mbyte = uint8_t((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0 | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1));
const float d = loadd.x * sc; const float d = loadd.x * sc;
const float m = -loadd.y * mbyte; const float m = -loadd.y * mbyte;
@ -336,7 +386,7 @@ void main() {
buf_a[buf_idx + 1] = FLOAT_TYPE(fma(d, float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF), m)); buf_a[buf_idx + 1] = FLOAT_TYPE(fma(d, float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF), m));
#elif defined(DATA_A_Q5_K) #elif defined(DATA_A_Q5_K)
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
const uint ib = idx / 128; // 2 values per idx const uint ib = idx / 128; // 2 values per idx
const uint iqs = idx % 128; // 0..127 const uint iqs = idx % 128; // 0..127
@ -351,15 +401,20 @@ void main() {
const vec2 loadd = vec2(data_a[ib].d); const vec2 loadd = vec2(data_a[ib].d);
uint8_t sc; const uint scidx0 = (is < 4) ? is : (is + 4);
uint8_t mbyte; const uint scidx1 = (is < 4) ? is : (is - 4);
if (is < 4) { const uint scidxmask1 = (is < 4) ? 0x30 : 0xC0;
sc = uint8_t(data_a[ib].scales[is ] & 63); const uint scidxshift1 = (is < 4) ? 0 : 2;
mbyte = uint8_t(data_a[ib].scales[is + 4] & 63); const uint mbidx0 = is + 4;
} else { const uint mbidx1 = (is < 4) ? is + 4 : is;
sc = uint8_t((data_a[ib].scales[is + 4] & 0xF) | ((data_a[ib].scales[is - 4] >> 6) << 4)); const uint mbidxmask0 = (is < 4) ? 0xF : 0xF0;
mbyte = uint8_t((data_a[ib].scales[is + 4] >> 4) | ((data_a[ib].scales[is ] >> 6) << 4)); const uint mbidxshift0 = (is < 4) ? 0 : 4;
} const uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0;
const uint mbidxshift1 = (is < 4) ? 0 : 2;
const uint8_t sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1));
const uint8_t mbyte = uint8_t(((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0) | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1));
const float d = loadd.x * sc; const float d = loadd.x * sc;
const float m = -loadd.y * mbyte; const float m = -loadd.y * mbyte;
@ -367,7 +422,7 @@ void main() {
buf_a[buf_idx + 1] = FLOAT_TYPE(fma(d, float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF) + float((data_a[ib].qh[qhi + 1] & hm) != 0 ? 16 : 0), m)); buf_a[buf_idx + 1] = FLOAT_TYPE(fma(d, float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF) + float((data_a[ib].qh[qhi + 1] & hm) != 0 ? 16 : 0), m));
#elif defined(DATA_A_Q6_K) #elif defined(DATA_A_Q6_K)
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
const uint ib = idx / 128; // 2 values per idx const uint ib = idx / 128; // 2 values per idx
const uint iqs = idx % 128; // 0..127 const uint iqs = idx % 128; // 0..127
@ -386,7 +441,7 @@ void main() {
buf_a[buf_idx + 1] = FLOAT_TYPE(dscale * float(int8_t(((data_a[ib].ql[qsi + 1] >> (b * 4)) & 0xF) | (((data_a[ib].qh[qhi + 1] >> qhshift) & 3) << 4)) - 32)); buf_a[buf_idx + 1] = FLOAT_TYPE(dscale * float(int8_t(((data_a[ib].ql[qsi + 1] >> (b * 4)) & 0xF) | (((data_a[ib].qh[qhi + 1] >> qhshift) & 3) << 4)) - 32));
#elif defined(DATA_A_IQ4_NL) #elif defined(DATA_A_IQ4_NL)
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a;
const uint ib = idx / 16; const uint ib = idx / 16;
const uint iqs = idx & 0xF; const uint iqs = idx & 0xF;
@ -407,7 +462,7 @@ void main() {
#else #else
const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b; const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b;
#endif #endif
const uint buf_idx = (loadc_b + l) * (BK+1) + loadr_b * LOAD_VEC_B; const uint buf_idx = (loadc_b + l) * SHMEM_STRIDE + loadr_b * LOAD_VEC_B;
buf_b[buf_idx + 0] = FLOAT_TYPE(data_b[idx][0].x); buf_b[buf_idx + 0] = FLOAT_TYPE(data_b[idx][0].x);
buf_b[buf_idx + 1] = FLOAT_TYPE(data_b[idx][0].y); buf_b[buf_idx + 1] = FLOAT_TYPE(data_b[idx][0].y);
buf_b[buf_idx + 2] = FLOAT_TYPE(data_b[idx][0].z); buf_b[buf_idx + 2] = FLOAT_TYPE(data_b[idx][0].z);
@ -423,24 +478,24 @@ void main() {
#else #else
const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b; const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b;
#endif #endif
const uint buf_idx = (loadc_b + l) * (BK+1) + loadr_b * LOAD_VEC_B; const uint buf_idx = (loadc_b + l) * SHMEM_STRIDE + loadr_b * LOAD_VEC_B;
buf_b[buf_idx + 0] = FLOAT_TYPE(data_b[idx].x); buf_b[buf_idx + 0] = FLOAT_TYPE(data_b[idx].x);
buf_b[buf_idx + 1] = FLOAT_TYPE(data_b[idx].y); buf_b[buf_idx + 1] = FLOAT_TYPE(data_b[idx].y);
buf_b[buf_idx + 2] = FLOAT_TYPE(data_b[idx].z); buf_b[buf_idx + 2] = FLOAT_TYPE(data_b[idx].z);
buf_b[buf_idx + 3] = FLOAT_TYPE(data_b[idx].w); buf_b[buf_idx + 3] = FLOAT_TYPE(data_b[idx].w);
#elif !MUL_MAT_ID #elif !MUL_MAT_ID
if (ic * BN + loadc_b + l < p.N && block + loadr_b < end_k) { if (ic * BN + loadc_b + l < p.N && block + loadr_b < end_k) {
buf_b[(loadc_b + l) * (BK+1) + loadr_b] = FLOAT_TYPE(data_b[pos_b + (loadc_b + l) * p.stride_b + loadr_b]); buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(data_b[pos_b + (loadc_b + l) * p.stride_b + loadr_b]);
} else { } else {
buf_b[(loadc_b + l) * (BK+1) + loadr_b] = FLOAT_TYPE(0.0f); buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(0.0f);
} }
#else #else
const uint row_i = ic * BN + loadc_b + l; const uint row_i = ic * BN + loadc_b + l;
if (row_i < _ne1) { if (row_i < _ne1) {
const u16vec2 row_idx = row_ids[row_i]; const u16vec2 row_idx = row_ids[row_i];
buf_b[(loadc_b + l) * (BK+1) + loadr_b] = FLOAT_TYPE(data_b[pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + loadr_b]); buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(data_b[pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + loadr_b]);
} else { } else {
buf_b[(loadc_b + l) * (BK+1) + loadr_b] = FLOAT_TYPE(0.0f); buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(0.0f);
} }
#endif #endif
} }
@ -450,16 +505,30 @@ void main() {
pos_a += BK / LOAD_VEC_A; pos_a += BK / LOAD_VEC_A;
pos_b += BK / LOAD_VEC_B; pos_b += BK / LOAD_VEC_B;
for (uint i = 0; i < BK; i++) { #ifdef COOPMAT
[[unroll]] for (uint i = 0; i < BK; i += TK) {
[[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
// Load from shared into cache
coopMatLoad(cache_a, buf_a, (warp_r * WM + cm_row * TM) * SHMEM_STRIDE + i, SHMEM_STRIDE, gl_CooperativeMatrixLayoutRowMajor);
[[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
coopMatLoad(cache_b, buf_b, (warp_c * WN + cm_col * TN) * SHMEM_STRIDE + i, SHMEM_STRIDE, gl_CooperativeMatrixLayoutColumnMajor);
sums[cm_col * cms_per_row + cm_row] = coopMatMulAdd(cache_a, cache_b, sums[cm_col * cms_per_row + cm_row]);
}
}
}
#else
[[unroll]] for (uint i = 0; i < BK; i++) {
// Load from shared into cache // Load from shared into cache
[[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) { [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
[[unroll]] for (uint j = 0; j < TM; j++) { [[unroll]] for (uint j = 0; j < TM; j++) {
cache_a[wsir * TM + j] = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * (BK+1) + i]; cache_a[wsir * TM + j] = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * SHMEM_STRIDE + i];
} }
} }
[[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) { [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
[[unroll]] for (uint j = 0; j < TN; j++) { [[unroll]] for (uint j = 0; j < TN; j++) {
cache_b[wsic * TN + j] = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + j) * (BK+1) + i]; cache_b[wsic * TN + j] = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + j) * SHMEM_STRIDE + i];
} }
} }
@ -468,12 +537,13 @@ void main() {
[[unroll]] for (uint cc = 0; cc < TN; cc++) { [[unroll]] for (uint cc = 0; cc < TN; cc++) {
[[unroll]] for (uint cr = 0; cr < TM; cr++) { [[unroll]] for (uint cr = 0; cr < TM; cr++) {
const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr; const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr;
sums[sums_idx] = fma(float(cache_a[wsir * TM + cr]), float(cache_b[wsic * TN + cc]), sums[sums_idx]); sums[sums_idx] = fma(ACC_TYPE(cache_a[wsir * TM + cr]), ACC_TYPE(cache_b[wsic * TN + cc]), sums[sums_idx]);
} }
} }
} }
} }
} }
#endif
barrier(); barrier();
} }
@ -485,6 +555,54 @@ void main() {
const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z; const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z;
#endif #endif
#ifdef COOPMAT
#ifdef MUL_MAT_ID
[[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
[[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
[[unroll]] for (uint col = 0; col < BN; col += storestride) {
const uint row_i = dc + cm_col * TN + col + store_c;
if (row_i >= _ne1) break;
const u16vec2 row_idx = row_ids[row_i];
data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
}
}
}
#else
const bool is_aligned = p.stride_d % 4 == 0; // Assumption: D_TYPE == float
[[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
[[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
const bool is_in_bounds = dr + (cm_row + 1) * TM <= p.M && dc + (cm_col + 1) * TN <= p.N;
if (is_aligned && is_in_bounds) {
// Full coopMat is within bounds and stride_d is aligned with 16B
coopmat<D_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> cm_dtype = coopmat<D_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(sums[cm_col * cms_per_row + cm_row]);
coopMatStore(cm_dtype, data_d, offsets + (dc + cm_col * TN) * p.stride_d + dr + cm_row * TM, p.stride_d, gl_CooperativeMatrixLayoutColumnMajor);
} else if (is_in_bounds) {
// Full coopMat is within bounds, but stride_d is not aligned
coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
[[unroll]] for (uint col = 0; col < TN; col += storestride) {
data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
}
} else if (dr + cm_row * TM < p.M && dc + cm_col * TN < p.N) {
// Partial coopMat is within bounds
coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
[[unroll]] for (uint col = 0; col < TN; col += storestride) {
if (dr + cm_row * TM + store_r < p.M && dc + cm_col * TN + col + store_c < p.N) {
data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
}
}
}
}
}
#endif // MUL_MAT_ID
#else
[[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) { [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
[[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) { [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
@ -496,7 +614,7 @@ void main() {
if (row_i >= _ne1) break; if (row_i >= _ne1) break;
const u16vec2 row_idx = row_ids[row_i]; const u16vec2 row_idx = row_ids[row_i];
#endif #endif // MUL_MAT_ID
[[unroll]] for (uint cr = 0; cr < TM; cr++) { [[unroll]] for (uint cr = 0; cr < TM; cr++) {
#ifdef MUL_MAT_ID #ifdef MUL_MAT_ID
data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]); data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);
@ -504,9 +622,10 @@ void main() {
if (dr_warp + cr < p.M && dc_warp + cc < p.N) { if (dr_warp + cr < p.M && dc_warp + cc < p.N) {
data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]); data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);
} }
#endif #endif // MUL_MAT_ID
} }
} }
} }
} }
#endif // COOPMAT
} }

View File

@ -16,6 +16,5 @@ void main() {
if (i >= p.KX) { if (i >= p.KX) {
return; return;
} }
data_d[i] = D_TYPE(1. - 2. / (exp(2.*data_a[i]) + 1.));
data_d[i] = D_TYPE(tanh(data_a[i]));
} }

View File

@ -0,0 +1,7 @@
#version 460
// Minimal shader: it requires GL_NV_cooperative_matrix2 and has an empty entry
// point, so it compiles only when the shader compiler accepts that extension.
// NOTE(review): presumably a build-time probe for coopmat2 compiler support —
// confirm against the build scripts that reference this file.
#extension GL_NV_cooperative_matrix2 : require
void main()
{
}

View File

@ -60,6 +60,7 @@ const std::vector<std::string> type_names = {
"iq4_nl" "iq4_nl"
}; };
namespace {
void execute_command(const std::string& command, std::string& stdout_str, std::string& stderr_str) { void execute_command(const std::string& command, std::string& stdout_str, std::string& stderr_str) {
#ifdef _WIN32 #ifdef _WIN32
HANDLE stdout_read, stdout_write; HANDLE stdout_read, stdout_write;
@ -198,8 +199,8 @@ static uint32_t compile_count = 0;
static std::mutex compile_count_mutex; static std::mutex compile_count_mutex;
static std::condition_variable compile_count_cond; static std::condition_variable compile_count_cond;
void string_to_spv_func(const std::string& _name, const std::string& in_fname, const std::map<std::string, std::string>& defines, bool fp16 = true, bool coopmat2 = false, bool f16acc = false) { void string_to_spv_func(const std::string& _name, const std::string& in_fname, const std::map<std::string, std::string>& defines, bool fp16 = true, bool coopmat = false, bool coopmat2 = false, bool f16acc = false) {
std::string name = _name + (f16acc ? "_f16acc" : "") + (coopmat2 ? "_cm2" : (fp16 ? "" : "_fp32")); std::string name = _name + (f16acc ? "_f16acc" : "") + (coopmat ? "_coopmat" : "") + (coopmat2 ? "_cm2" : (fp16 ? "" : "_fp32"));
std::string out_fname = join_paths(output_dir, name + ".spv"); std::string out_fname = join_paths(output_dir, name + ".spv");
std::string in_path = join_paths(input_dir, in_fname); std::string in_path = join_paths(input_dir, in_fname);
@ -258,7 +259,7 @@ std::map<std::string, std::string> merge_maps(const std::map<std::string, std::s
} }
static std::vector<std::future<void>> compiles; static std::vector<std::future<void>> compiles;
void string_to_spv(const std::string& _name, const std::string& in_fname, const std::map<std::string, std::string>& defines, bool fp16 = true, bool coopmat2 = false, bool f16acc = false) { void string_to_spv(const std::string& _name, const std::string& in_fname, const std::map<std::string, std::string>& defines, bool fp16 = true, bool coopmat = false, bool coopmat2 = false, bool f16acc = false) {
{ {
// wait until fewer than N compiles are in progress. // wait until fewer than N compiles are in progress.
// 16 is an arbitrary limit, the goal is to avoid "failed to create pipe" errors. // 16 is an arbitrary limit, the goal is to avoid "failed to create pipe" errors.
@ -269,10 +270,10 @@ void string_to_spv(const std::string& _name, const std::string& in_fname, const
} }
compile_count++; compile_count++;
} }
compiles.push_back(std::async(string_to_spv_func, _name, in_fname, defines, fp16, coopmat2, f16acc)); compiles.push_back(std::async(string_to_spv_func, _name, in_fname, defines, fp16, coopmat, coopmat2, f16acc));
} }
void matmul_shaders(bool fp16, bool matmul_id, bool coopmat2, bool f16acc) { void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool f16acc) {
std::string load_vec = coopmat2 ? "1" : fp16 ? "8" : "4"; std::string load_vec = coopmat2 ? "1" : fp16 ? "8" : "4";
std::string aligned_b_type_f32 = coopmat2 ? "float" : fp16 ? "mat2x4" : "vec4"; std::string aligned_b_type_f32 = coopmat2 ? "float" : fp16 ? "mat2x4" : "vec4";
std::string aligned_b_type_f16 = coopmat2 ? "float16_t" : fp16 ? "f16mat2x4" : "f16vec4"; std::string aligned_b_type_f16 = coopmat2 ? "float16_t" : fp16 ? "f16mat2x4" : "f16vec4";
@ -291,14 +292,20 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat2, bool f16acc) {
base_dict["ACC_TYPE"] = f16acc ? "float16_t" : "float"; base_dict["ACC_TYPE"] = f16acc ? "float16_t" : "float";
if (coopmat) {
base_dict["COOPMAT"] = "1";
}
base_dict["ACC_TYPE"] = f16acc ? "float16_t" : "float";
std::string source_name = coopmat2 ? "mul_mm_cm2.comp" : "mul_mm.comp"; std::string source_name = coopmat2 ? "mul_mm_cm2.comp" : "mul_mm.comp";
// Shaders with f16 B_TYPE // Shaders with f16 B_TYPE
string_to_spv(shader_name + "_f32_f16", source_name, merge_maps(base_dict, {{"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, }), fp16, coopmat2, f16acc); string_to_spv(shader_name + "_f32_f16", source_name, merge_maps(base_dict, {{"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, }), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_f32_f16_aligned", source_name, merge_maps(base_dict, {{"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat2, f16acc); string_to_spv(shader_name + "_f32_f16_aligned", source_name, merge_maps(base_dict, {{"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_f16_aligned", source_name, merge_maps(base_dict, {{"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat2, f16acc); string_to_spv(shader_name + "_f16_aligned", source_name, merge_maps(base_dict, {{"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_f16", source_name, merge_maps(base_dict, {{"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat2, f16acc); string_to_spv(shader_name + "_f16", source_name, merge_maps(base_dict, {{"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
for (const auto& tname : type_names) { for (const auto& tname : type_names) {
std::string data_a_key = "DATA_A_" + to_uppercase(tname); std::string data_a_key = "DATA_A_" + to_uppercase(tname);
@ -307,12 +314,12 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat2, bool f16acc) {
// For aligned matmul loads // For aligned matmul loads
std::string load_vec_a = (coopmat2 || tname == "f32" || tname == "f16") ? load_vec : "2"; std::string load_vec_a = (coopmat2 || tname == "f32" || tname == "f16") ? load_vec : "2";
string_to_spv(shader_name + "_" + tname + "_f32", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat2, f16acc); string_to_spv(shader_name + "_" + tname + "_f32", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat2, f16acc); string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
if (tname != "f16" && tname != "f32") { if (tname != "f16" && tname != "f32") {
string_to_spv(shader_name + "_" + tname + "_f16", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat2, f16acc); string_to_spv(shader_name + "_" + tname + "_f16", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_" + tname + "_f16_aligned", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat2, f16acc); string_to_spv(shader_name + "_" + tname + "_f16_aligned", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
} }
} }
} }
@ -322,28 +329,27 @@ void process_shaders() {
std::map<std::string, std::string> base_dict = {{"FLOAT_TYPE", "float"}}; std::map<std::string, std::string> base_dict = {{"FLOAT_TYPE", "float"}};
// matmul // matmul
for (const auto& fp16 : {false, true}) {
for (const auto& matmul_id : {false, true}) { for (const auto& matmul_id : {false, true}) {
for (const auto& coopmat2 : {false, true}) { // No coopmats
for (const auto& f16acc : {false, true}) { // fp32
#if !defined(VK_NV_cooperative_matrix2) matmul_shaders(false, matmul_id, false, false, false);
if (coopmat2) {
continue; // fp16, fp32acc and fp16acc
} matmul_shaders(true, matmul_id, false, false, false);
matmul_shaders(true, matmul_id, false, false, true);
// Coopmat, fp32acc and fp16acc
matmul_shaders(true, matmul_id, true, false, false);
matmul_shaders(true, matmul_id, true, false, true);
#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
// Coopmat2, fp32acc and fp16acc
matmul_shaders(true, matmul_id, false, true, false);
matmul_shaders(true, matmul_id, false, true, true);
#endif #endif
if (coopmat2 && !fp16) {
continue;
}
if (!coopmat2 && f16acc) {
continue;
}
matmul_shaders(fp16, matmul_id, coopmat2, f16acc);
}
}
}
} }
#if defined(VK_NV_cooperative_matrix2) #if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
// flash attention // flash attention
for (const auto& f16acc : {false, true}) { for (const auto& f16acc : {false, true}) {
std::string acctype = f16acc ? "float16_t" : "float"; std::string acctype = f16acc ? "float16_t" : "float";
@ -355,11 +361,11 @@ void process_shaders() {
if (tname == "f16") { if (tname == "f16") {
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp", string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
merge_maps(base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}}), true, true, f16acc); merge_maps(base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}}), true, false, true, f16acc);
} else { } else {
std::string data_a_key = "DATA_A_" + to_uppercase(tname); std::string data_a_key = "DATA_A_" + to_uppercase(tname);
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp", string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"DEQUANTFUNC", "dequantFunc"+to_uppercase(tname) }, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, true, f16acc); merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"DEQUANTFUNC", "dequantFunc"+to_uppercase(tname) }, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, true, f16acc);
} }
} }
} }
@ -524,6 +530,7 @@ void write_output_files() {
fclose(hdr); fclose(hdr);
fclose(src); fclose(src);
} }
}
int main(int argc, char** argv) { int main(int argc, char** argv) {
std::map<std::string, std::string> args; std::map<std::string, std::string> args;

View File

@ -8,7 +8,10 @@
// FIXME: required here for quantization functions // FIXME: required here for quantization functions
#include "ggml-quants.h" #include "ggml-quants.h"
#include "ggml-aarch64.h"
#ifdef GGML_USE_CPU_HBM
#include <hbwmalloc.h>
#endif
#if defined(_MSC_VER) || defined(__MINGW32__) #if defined(_MSC_VER) || defined(__MINGW32__)
#include <malloc.h> // using malloc.h with MSC/MINGW #include <malloc.h> // using malloc.h with MSC/MINGW
@ -788,32 +791,23 @@ static const struct ggml_type_traits type_traits[GGML_TYPE_COUNT] = {
.to_float = (ggml_to_float_t) ggml_bf16_to_fp32_row, .to_float = (ggml_to_float_t) ggml_bf16_to_fp32_row,
.from_float_ref = (ggml_from_float_t) ggml_fp32_to_bf16_row_ref, .from_float_ref = (ggml_from_float_t) ggml_fp32_to_bf16_row_ref,
}, },
[GGML_TYPE_Q4_0_4_4] = { [31] = { // GGML_TYPE_Q4_0_4_4
.type_name = "q4_0_4x4", .type_name = "TYPE_Q4_0_4_4 REMOVED, use Q4_0 with runtime repacking",
.blck_size = QK4_0, .blck_size = 0,
.blck_size_interleave = 4, .type_size = 0,
.type_size = sizeof(block_q4_0), .is_quantized = false,
.is_quantized = true,
.to_float = NULL,
.from_float_ref = NULL,
}, },
[GGML_TYPE_Q4_0_4_8] = { [32] = { // GGML_TYPE_Q4_0_4_8
.type_name = "q4_0_4x8", .type_name = "TYPE_Q4_0_4_8 REMOVED, use Q4_0 with runtime repacking",
.blck_size = QK4_0, .blck_size = 0,
.blck_size_interleave = 8, .type_size = 0,
.type_size = sizeof(block_q4_0), .is_quantized = false,
.is_quantized = true,
.to_float = NULL,
.from_float_ref = NULL,
}, },
[GGML_TYPE_Q4_0_8_8] = { [33] = { // GGML_TYPE_Q4_0_8_8
.type_name = "q4_0_8x8", .type_name = "TYPE_Q4_0_8_8 REMOVED, use Q4_0 with runtime repacking",
.blck_size = QK4_0, .blck_size = 0,
.blck_size_interleave = 8, .type_size = 0,
.type_size = sizeof(block_q4_0), .is_quantized = false,
.is_quantized = true,
.to_float = NULL,
.from_float_ref = NULL,
}, },
[GGML_TYPE_TQ1_0] = { [GGML_TYPE_TQ1_0] = {
.type_name = "tq1_0", .type_name = "tq1_0",
@ -831,14 +825,23 @@ static const struct ggml_type_traits type_traits[GGML_TYPE_COUNT] = {
.to_float = (ggml_to_float_t) dequantize_row_tq2_0, .to_float = (ggml_to_float_t) dequantize_row_tq2_0,
.from_float_ref = (ggml_from_float_t) quantize_row_tq2_0_ref, .from_float_ref = (ggml_from_float_t) quantize_row_tq2_0_ref,
}, },
[GGML_TYPE_IQ4_NL_4_4] = { [36] = { // GGML_TYPE_IQ4_NL_4_4
.type_name = "iq4_nl_4x4", .type_name = "TYPE_IQ4_NL_4_4 REMOVED, use IQ4_NL with runtime repacking",
.blck_size = QK4_NL, .blck_size = 0,
.blck_size_interleave = 4, .type_size = 0,
.type_size = sizeof(block_iq4_nl), .is_quantized = false,
.is_quantized = true, },
.to_float = NULL, [37] = { // GGML_TYPE_IQ4_NL_4_8
.from_float_ref = NULL, .type_name = "TYPE_IQ4_NL_4_8 REMOVED, use IQ4_NL with runtime repacking",
.blck_size = 0,
.type_size = 0,
.is_quantized = false,
},
[38] = { // GGML_TYPE_IQ4_NL_8_8
.type_name = "TYPE_IQ4_NL_8_8 REMOVED, use IQ4_NL with runtime repacking",
.blck_size = 0,
.type_size = 0,
.is_quantized = false,
}, },
}; };
@ -1270,9 +1273,6 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) {
case GGML_FTYPE_MOSTLY_IQ4_XS: wtype = GGML_TYPE_IQ4_XS; break; case GGML_FTYPE_MOSTLY_IQ4_XS: wtype = GGML_TYPE_IQ4_XS; break;
case GGML_FTYPE_MOSTLY_IQ3_S: wtype = GGML_TYPE_IQ3_S; break; case GGML_FTYPE_MOSTLY_IQ3_S: wtype = GGML_TYPE_IQ3_S; break;
case GGML_FTYPE_MOSTLY_IQ2_S: wtype = GGML_TYPE_IQ2_S; break; case GGML_FTYPE_MOSTLY_IQ2_S: wtype = GGML_TYPE_IQ2_S; break;
case GGML_FTYPE_MOSTLY_Q4_0_4_4: wtype = GGML_TYPE_Q4_0_4_4; break;
case GGML_FTYPE_MOSTLY_Q4_0_4_8: wtype = GGML_TYPE_Q4_0_4_8; break;
case GGML_FTYPE_MOSTLY_Q4_0_8_8: wtype = GGML_TYPE_Q4_0_8_8; break;
case GGML_FTYPE_UNKNOWN: wtype = GGML_TYPE_COUNT; break; case GGML_FTYPE_UNKNOWN: wtype = GGML_TYPE_COUNT; break;
case GGML_FTYPE_MOSTLY_Q4_1_SOME_F16: wtype = GGML_TYPE_COUNT; break; case GGML_FTYPE_MOSTLY_Q4_1_SOME_F16: wtype = GGML_TYPE_COUNT; break;
} }
@ -6304,9 +6304,6 @@ size_t ggml_quantize_chunk(
case GGML_TYPE_IQ1_M: result = quantize_iq1_m (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ1_M: result = quantize_iq1_m (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_IQ4_NL: result = quantize_iq4_nl (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ4_NL: result = quantize_iq4_nl (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_IQ4_XS: result = quantize_iq4_xs (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ4_XS: result = quantize_iq4_xs (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_Q4_0_4_4: result = quantize_q4_0_4x4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_Q4_0_4_8: result = quantize_q4_0_4x8(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_Q4_0_8_8: result = quantize_q4_0_8x8(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_F16: case GGML_TYPE_F16:
{ {
size_t elemsize = sizeof(ggml_fp16_t); size_t elemsize = sizeof(ggml_fp16_t);
@ -6838,7 +6835,16 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
(int64_t) info->ne[2] * (int64_t) info->ne[2] *
(int64_t) info->ne[3]; (int64_t) info->ne[3];
if (ggml_blck_size(info->type) == 0 || ne % ggml_blck_size(info->type) != 0) { if (ggml_blck_size(info->type) == 0 ) {
// support for this tensor type has been removed:
fprintf(stderr, "%s: tensor '%s' of type %d: %s\n",
__func__, info->name.data, (int) info->type, ggml_type_name(info->type));
fclose(file);
gguf_free(ctx);
return NULL;
}
if (ne % ggml_blck_size(info->type) != 0) {
fprintf(stderr, "%s: tensor '%s' of type %d (%s) number of elements (%" PRId64 ") is not a multiple of block size (%" PRId64 ")\n", fprintf(stderr, "%s: tensor '%s' of type %d (%s) number of elements (%" PRId64 ") is not a multiple of block size (%" PRId64 ")\n",
__func__, info->name.data, (int) info->type, ggml_type_name(info->type), ne, ggml_blck_size(info->type)); __func__, info->name.data, (int) info->type, ggml_type_name(info->type), ne, ggml_blck_size(info->type));
fclose(file); fclose(file);

View File

@ -761,6 +761,7 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM, MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT, MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ROPE_FREQS,
MODEL_TENSOR.ATTN_NORM, MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_K,
@ -1432,9 +1433,6 @@ class GGMLQuantizationType(IntEnum):
F64 = 28 F64 = 28
IQ1_M = 29 IQ1_M = 29
BF16 = 30 BF16 = 30
Q4_0_4_4 = 31
Q4_0_4_8 = 32
Q4_0_8_8 = 33
TQ1_0 = 34 TQ1_0 = 34
TQ2_0 = 35 TQ2_0 = 35
@ -1478,9 +1476,9 @@ class LlamaFileType(IntEnum):
MOSTLY_IQ4_XS = 30 # except 1d tensors MOSTLY_IQ4_XS = 30 # except 1d tensors
MOSTLY_IQ1_M = 31 # except 1d tensors MOSTLY_IQ1_M = 31 # except 1d tensors
MOSTLY_BF16 = 32 # except 1d tensors MOSTLY_BF16 = 32 # except 1d tensors
MOSTLY_Q4_0_4_4 = 33 # except 1d tensors # MOSTLY_Q4_0_4_4 = 33 # removed from gguf files, use Q4_0 and runtime repack
MOSTLY_Q4_0_4_8 = 34 # except 1d tensors # MOSTLY_Q4_0_4_8 = 34 # removed from gguf files, use Q4_0 and runtime repack
MOSTLY_Q4_0_8_8 = 35 # except 1d tensors # MOSTLY_Q4_0_8_8 = 35 # removed from gguf files, use Q4_0 and runtime repack
MOSTLY_TQ1_0 = 36 # except 1d tensors MOSTLY_TQ1_0 = 36 # except 1d tensors
MOSTLY_TQ2_0 = 37 # except 1d tensors MOSTLY_TQ2_0 = 37 # except 1d tensors
@ -1556,9 +1554,6 @@ GGML_QUANT_SIZES: dict[GGMLQuantizationType, tuple[int, int]] = {
GGMLQuantizationType.F64: (1, 8), GGMLQuantizationType.F64: (1, 8),
GGMLQuantizationType.IQ1_M: (256, QK_K // 8 + QK_K // 16 + QK_K // 32), GGMLQuantizationType.IQ1_M: (256, QK_K // 8 + QK_K // 16 + QK_K // 32),
GGMLQuantizationType.BF16: (1, 2), GGMLQuantizationType.BF16: (1, 2),
GGMLQuantizationType.Q4_0_4_4:(32, 2 + 16),
GGMLQuantizationType.Q4_0_4_8:(32, 2 + 16),
GGMLQuantizationType.Q4_0_8_8:(32, 2 + 16),
GGMLQuantizationType.TQ1_0: (256, 2 + 4 * 13), GGMLQuantizationType.TQ1_0: (256, 2 + 4 * 13),
GGMLQuantizationType.TQ2_0: (256, 2 + 64), GGMLQuantizationType.TQ2_0: (256, 2 + 64),
} }

View File

@ -146,6 +146,7 @@ class TensorNameMap:
# Attention query # Attention query
MODEL_TENSOR.ATTN_Q: ( MODEL_TENSOR.ATTN_Q: (
"model.layers.{bid}.self_attn.q_proj", # llama-hf nemotron olmoe olmo2 "model.layers.{bid}.self_attn.q_proj", # llama-hf nemotron olmoe olmo2
"model.layers.{bid}.self_attn.q_proj_no_perm", # llama-custom
"layers.{bid}.attention.wq", # llama-pth "layers.{bid}.attention.wq", # llama-pth
"encoder.layer.{bid}.attention.self.query", # bert "encoder.layer.{bid}.attention.self.query", # bert
"transformer.h.{bid}.attn.q_proj", # gpt-j "transformer.h.{bid}.attn.q_proj", # gpt-j
@ -158,6 +159,7 @@ class TensorNameMap:
# Attention key # Attention key
MODEL_TENSOR.ATTN_K: ( MODEL_TENSOR.ATTN_K: (
"model.layers.{bid}.self_attn.k_proj", # llama-hf nemotron olmoe olmo2 "model.layers.{bid}.self_attn.k_proj", # llama-hf nemotron olmoe olmo2
"model.layers.{bid}.self_attn.k_proj_no_perm", # llama-custom
"layers.{bid}.attention.wk", # llama-pth "layers.{bid}.attention.wk", # llama-pth
"encoder.layer.{bid}.attention.self.key", # bert "encoder.layer.{bid}.attention.self.key", # bert
"transformer.h.{bid}.attn.k_proj", # gpt-j "transformer.h.{bid}.attn.k_proj", # gpt-j

View File

@ -172,9 +172,9 @@ extern "C" {
LLAMA_FTYPE_MOSTLY_IQ4_XS = 30, // except 1d tensors LLAMA_FTYPE_MOSTLY_IQ4_XS = 30, // except 1d tensors
LLAMA_FTYPE_MOSTLY_IQ1_M = 31, // except 1d tensors LLAMA_FTYPE_MOSTLY_IQ1_M = 31, // except 1d tensors
LLAMA_FTYPE_MOSTLY_BF16 = 32, // except 1d tensors LLAMA_FTYPE_MOSTLY_BF16 = 32, // except 1d tensors
LLAMA_FTYPE_MOSTLY_Q4_0_4_4 = 33, // except 1d tensors //LLAMA_FTYPE_MOSTLY_Q4_0_4_4 = 33, // removed from gguf files, use Q4_0 and runtime repack
LLAMA_FTYPE_MOSTLY_Q4_0_4_8 = 34, // except 1d tensors //LLAMA_FTYPE_MOSTLY_Q4_0_4_8 = 34, // removed from gguf files, use Q4_0 and runtime repack
LLAMA_FTYPE_MOSTLY_Q4_0_8_8 = 35, // except 1d tensors //LLAMA_FTYPE_MOSTLY_Q4_0_8_8 = 35, // removed from gguf files, use Q4_0 and runtime repack
LLAMA_FTYPE_MOSTLY_TQ1_0 = 36, // except 1d tensors LLAMA_FTYPE_MOSTLY_TQ1_0 = 36, // except 1d tensors
LLAMA_FTYPE_MOSTLY_TQ2_0 = 37, // except 1d tensors LLAMA_FTYPE_MOSTLY_TQ2_0 = 37, // except 1d tensors

View File

@ -0,0 +1,112 @@
ied 4 ½ months
__ggml_vocab_test__
Führer
__ggml_vocab_test__
__ggml_vocab_test__
__ggml_vocab_test__
__ggml_vocab_test__
__ggml_vocab_test__
__ggml_vocab_test__
__ggml_vocab_test__
__ggml_vocab_test__
__ggml_vocab_test__
__ggml_vocab_test__
Hello world
__ggml_vocab_test__
Hello world
__ggml_vocab_test__
Hello World
__ggml_vocab_test__
Hello World
__ggml_vocab_test__
Hello World!
__ggml_vocab_test__
Hello, world!
__ggml_vocab_test__
Hello, world!
__ggml_vocab_test__
this is 🦙.cpp
__ggml_vocab_test__
w048 7tuijk dsdfhu
__ggml_vocab_test__
нещо на Български
__ggml_vocab_test__
កាន់តែពិសេសអាចខលចេញ
__ggml_vocab_test__
🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ (only emoji that has its own token)
__ggml_vocab_test__
Hello
__ggml_vocab_test__
Hello
__ggml_vocab_test__
Hello
__ggml_vocab_test__
Hello
__ggml_vocab_test__
Hello
__ggml_vocab_test__
Hello
Hello
__ggml_vocab_test__
(
__ggml_vocab_test__
=
__ggml_vocab_test__
' era
__ggml_vocab_test__
Hello, y'all! How are you 😁 ?我想在apple工作1314151天
__ggml_vocab_test__
!!!!!!
__ggml_vocab_test__
3
__ggml_vocab_test__
33
__ggml_vocab_test__
333
__ggml_vocab_test__
3333
__ggml_vocab_test__
33333
__ggml_vocab_test__
333333
__ggml_vocab_test__
3333333
__ggml_vocab_test__
33333333
__ggml_vocab_test__
333333333
__ggml_vocab_test__
Cửa Việt
__ggml_vocab_test__
discards
__ggml_vocab_test__
🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ 🦙🦙 3 33 333 3333 33333 333333 3333333 33333333 3.3 3..3 3...3 កាន់តែពិសេសអាច😁 ?我想在apple工作1314151天 ------======= нещо на Български ''''''```````""""......!!!!!!?????? I've been 'told he's there, 'RE you sure? 'M not sure I'll make it, 'D you like some tea? We'Ve a'lL
__ggml_vocab_test__

View File

@ -0,0 +1,46 @@
2550 204 18430 377
597 2768 298 8564
1437
1437 1437
1437 1437 1437
50117
50118
50140
50140 50118
50117 50118
31414 232
20920 232
31414 623
20920 623
20920 623 328
31414 6 232 328
20920 6 232 328
42 16 8103 18164 27 4 49317
605 40976 262 10109 18474 385 29 36807 6455
36765 25482 22063 23171 34251 18697 10809 26161 18697 3602 22063 27969 40966 25417 15264 26161 24269 36709 41171 35328
1376 17772 7471 1376 17772 19002 1376 17772 9085 1376 4333 13859 1376 17772 9357 1376 4333 9264 1376 17772 25448 1376 17772 18400 1376 17772 4333 1376 4333 10172 1376 17772 4333 1376 17772 7258 1376 17772 19002 1376 17772 5782 1376 17772 10172 1376 17772 3726 1376 17772 5782 1376 4333 10172 1376 17772 23171
6569 15113 7471 36 21113 43 17841 19002 17 8384 6569 14285 4958 12605 36 34654 2841 4203 354 10146 26511 1070 43 36174 5782 36 8338 21554 14 34 63 308 19233 43
31414
20920
1437 20920
1437 1437 20920
1437 1437 1437 20920
1437 1437 1437 20920 50118 1437 1437 1437 20920
36
50118 5457
108 3567
31414 6 1423 108 1250 328 1336 32 47 17841 10172 17487 47876 3602 48617 15264 46537 11423 27326 48494 8210 49233 1558 1570 27761 49429 43251 10809 17772
32376 12846
246
3103
25631
46152
3103 25631
46152 3103
46152 25631
46152 46152
46152 3103 25631
347 1376 2023 12410 102 16376 1376 2023 6382 90
9553 5954
50118 1437 50140 1437 50140 50118 1437 50117 1437 50117 50117 1437 50117 50118 1437 1437 50118 1437 1437 1437 50118 1437 1437 1437 1437 50118 1437 1437 1437 1437 1437 50118 6569 15113 7471 36 21113 43 17841 19002 17 8384 6569 14285 4958 12605 36 34654 2841 4203 354 10146 26511 1070 43 36174 5782 8103 18164 27 6569 18164 27 155 2357 30242 155 25631 30242 3103 30242 25631 30242 46152 30242 3103 25631 155 4 246 155 7586 246 155 734 246 25974 17772 7471 1376 17772 19002 1376 17772 9085 1376 4333 13859 1376 17772 9357 1376 4333 9264 1376 17772 25448 1376 17772 18400 1376 17772 4333 1376 4333 10172 1376 17772 4333 1376 17772 7258 1376 17772 19002 1376 17772 5782 18636 10172 17487 47876 3602 48617 15264 46537 11423 27326 48494 8210 49233 1558 1570 27761 49429 43251 10809 17772 36738 48332 47463 18697 10809 25482 22063 23171 34251 18697 10809 26161 18697 3602 22063 27969 40966 25417 15264 26161 24269 36709 41171 35328 128 49690 108 49972 49519 12905 48149 48149 43796 32376 12846 27282 28749 38 348 57 128 41042 37 18 89 6 128 4629 47 686 116 128 448 45 686 38 581 146 24 6 128 495 47 101 103 6845 116 166 108 30660 10 108 462 574

View File

@ -4578,9 +4578,6 @@ struct llama_model_loader {
case GGML_TYPE_IQ4_NL: ftype = LLAMA_FTYPE_MOSTLY_IQ4_NL; break; case GGML_TYPE_IQ4_NL: ftype = LLAMA_FTYPE_MOSTLY_IQ4_NL; break;
case GGML_TYPE_IQ4_XS: ftype = LLAMA_FTYPE_MOSTLY_IQ4_XS; break; case GGML_TYPE_IQ4_XS: ftype = LLAMA_FTYPE_MOSTLY_IQ4_XS; break;
case GGML_TYPE_IQ3_S: ftype = LLAMA_FTYPE_MOSTLY_IQ3_S; break; case GGML_TYPE_IQ3_S: ftype = LLAMA_FTYPE_MOSTLY_IQ3_S; break;
case GGML_TYPE_Q4_0_4_4: ftype = LLAMA_FTYPE_MOSTLY_Q4_0_4_4; break;
case GGML_TYPE_Q4_0_4_8: ftype = LLAMA_FTYPE_MOSTLY_Q4_0_4_8; break;
case GGML_TYPE_Q4_0_8_8: ftype = LLAMA_FTYPE_MOSTLY_Q4_0_8_8; break;
default: default:
{ {
LLAMA_LOG_WARN("%s: unknown type %s\n", __func__, ggml_type_name(type_max)); LLAMA_LOG_WARN("%s: unknown type %s\n", __func__, ggml_type_name(type_max));
@ -5344,9 +5341,6 @@ static std::string llama_model_ftype_name(llama_ftype ftype) {
case LLAMA_FTYPE_MOSTLY_IQ4_XS: return "IQ4_XS - 4.25 bpw"; case LLAMA_FTYPE_MOSTLY_IQ4_XS: return "IQ4_XS - 4.25 bpw";
case LLAMA_FTYPE_MOSTLY_IQ3_S: return "IQ3_S - 3.4375 bpw"; case LLAMA_FTYPE_MOSTLY_IQ3_S: return "IQ3_S - 3.4375 bpw";
case LLAMA_FTYPE_MOSTLY_IQ3_M: return "IQ3_S mix - 3.66 bpw"; case LLAMA_FTYPE_MOSTLY_IQ3_M: return "IQ3_S mix - 3.66 bpw";
case LLAMA_FTYPE_MOSTLY_Q4_0_4_4: return "Q4_0_4_4";
case LLAMA_FTYPE_MOSTLY_Q4_0_4_8: return "Q4_0_4_8";
case LLAMA_FTYPE_MOSTLY_Q4_0_8_8: return "Q4_0_8_8";
default: return "unknown, may not work"; default: return "unknown, may not work";
} }
@ -18367,10 +18361,6 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n
else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) { else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) {
new_type = GGML_TYPE_IQ3_S; new_type = GGML_TYPE_IQ3_S;
} }
else if (new_type == GGML_TYPE_Q4_0_4_4 || new_type == GGML_TYPE_Q4_0_4_8 ||
new_type == GGML_TYPE_Q4_0_8_8) {
new_type = GGML_TYPE_Q4_0;
}
else if (ftype == LLAMA_FTYPE_MOSTLY_TQ1_0 || ftype == LLAMA_FTYPE_MOSTLY_TQ2_0) { else if (ftype == LLAMA_FTYPE_MOSTLY_TQ1_0 || ftype == LLAMA_FTYPE_MOSTLY_TQ2_0) {
new_type = GGML_TYPE_Q4_K; new_type = GGML_TYPE_Q4_K;
} }
@ -18693,9 +18683,6 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
case LLAMA_FTYPE_MOSTLY_IQ4_XS: default_type = GGML_TYPE_IQ4_XS; break; case LLAMA_FTYPE_MOSTLY_IQ4_XS: default_type = GGML_TYPE_IQ4_XS; break;
case LLAMA_FTYPE_MOSTLY_IQ3_S: default_type = GGML_TYPE_IQ3_S; break; case LLAMA_FTYPE_MOSTLY_IQ3_S: default_type = GGML_TYPE_IQ3_S; break;
case LLAMA_FTYPE_MOSTLY_IQ3_M: default_type = GGML_TYPE_IQ3_S; break; case LLAMA_FTYPE_MOSTLY_IQ3_M: default_type = GGML_TYPE_IQ3_S; break;
case LLAMA_FTYPE_MOSTLY_Q4_0_4_4: default_type = GGML_TYPE_Q4_0_4_4; break;
case LLAMA_FTYPE_MOSTLY_Q4_0_4_8: default_type = GGML_TYPE_Q4_0_4_8; break;
case LLAMA_FTYPE_MOSTLY_Q4_0_8_8: default_type = GGML_TYPE_Q4_0_8_8; break;
default: throw std::runtime_error(format("invalid output file type %d\n", ftype)); default: throw std::runtime_error(format("invalid output file type %d\n", ftype));
} }
@ -19034,14 +19021,6 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
f32_data = (float *) f32_conv_buf.data(); f32_data = (float *) f32_conv_buf.data();
} }
int chunk_size_multiplier = 1;
if (new_type == GGML_TYPE_Q4_0_4_4 || new_type == GGML_TYPE_Q4_0_4_8 || new_type == GGML_TYPE_Q4_0_8_8) {
if ((new_type == GGML_TYPE_Q4_0_8_8) && (tensor->ne[1] % 8 != 0)) new_type = GGML_TYPE_Q4_0;
else if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_Q4_0;
if (new_type == GGML_TYPE_Q4_0_8_8) chunk_size_multiplier = 8;
else if (new_type == GGML_TYPE_Q4_0_4_4 || new_type == GGML_TYPE_Q4_0_4_8) chunk_size_multiplier = 4;
}
LLAMA_LOG_INFO("converting to %s .. ", ggml_type_name(new_type)); LLAMA_LOG_INFO("converting to %s .. ", ggml_type_name(new_type));
fflush(stdout); fflush(stdout);
@ -19054,8 +19033,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
const int64_t nrows = tensor->ne[1]; const int64_t nrows = tensor->ne[1];
static const int64_t min_chunk_size = 32 * 512; static const int64_t min_chunk_size = 32 * 512;
const int64_t chunk_size = (n_per_row >= min_chunk_size ? n_per_row : n_per_row * ((min_chunk_size + n_per_row - 1)/n_per_row)) * const int64_t chunk_size = (n_per_row >= min_chunk_size ? n_per_row : n_per_row * ((min_chunk_size + n_per_row - 1)/n_per_row));
chunk_size_multiplier;
const int64_t nelements_matrix = tensor->ne[0] * tensor->ne[1]; const int64_t nelements_matrix = tensor->ne[0] * tensor->ne[1];
const int64_t nchunk = (nelements_matrix + chunk_size - 1)/chunk_size; const int64_t nchunk = (nelements_matrix + chunk_size - 1)/chunk_size;