mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-25 02:44:36 +00:00
metal : enable shader debugging (cmake option) (#4705)
* ggml : disable fast-math for Metal (cmake build only) ggml-ci * metal : fix Metal API debug warnings * cmake : add -fno-inline for Metal build (#4545) * metal : fix API debug warnings * metal : fix compile warnings * metal : use uint64_t for strides * cmake : rename option to LLAMA_METAL_SHADER_DEBUG * metal : fix mat-vec Q8_0 kernel for BS > 1 * metal : normalize mat-vec kernel signatures * cmake : respect LLAMA_QKK_64 option * metal : fix mat-vec Q4_K kernel for QK_K == 64 ggml-ci
This commit is contained in:
parent
edd1ab7bc3
commit
58ba655af0
@ -95,6 +95,7 @@ option(LLAMA_HIP_UMA "llama: use HIP unified memory arch
|
|||||||
option(LLAMA_CLBLAST "llama: use CLBlast" OFF)
|
option(LLAMA_CLBLAST "llama: use CLBlast" OFF)
|
||||||
option(LLAMA_METAL "llama: use Metal" ${LLAMA_METAL_DEFAULT})
|
option(LLAMA_METAL "llama: use Metal" ${LLAMA_METAL_DEFAULT})
|
||||||
option(LLAMA_METAL_NDEBUG "llama: disable Metal debugging" OFF)
|
option(LLAMA_METAL_NDEBUG "llama: disable Metal debugging" OFF)
|
||||||
|
option(LLAMA_METAL_SHADER_DEBUG "llama: compile Metal with -fno-fast-math" OFF)
|
||||||
option(LLAMA_MPI "llama: use MPI" OFF)
|
option(LLAMA_MPI "llama: use MPI" OFF)
|
||||||
option(LLAMA_QKK_64 "llama: use super-block size of 64 for k-quants" OFF)
|
option(LLAMA_QKK_64 "llama: use super-block size of 64 for k-quants" OFF)
|
||||||
|
|
||||||
@ -154,9 +155,9 @@ if (APPLE AND LLAMA_ACCELERATE)
|
|||||||
endif()
|
endif()
|
||||||
|
|
||||||
if (LLAMA_METAL)
|
if (LLAMA_METAL)
|
||||||
find_library(FOUNDATION_LIBRARY Foundation REQUIRED)
|
find_library(FOUNDATION_LIBRARY Foundation REQUIRED)
|
||||||
find_library(METAL_FRAMEWORK Metal REQUIRED)
|
find_library(METAL_FRAMEWORK Metal REQUIRED)
|
||||||
find_library(METALKIT_FRAMEWORK MetalKit REQUIRED)
|
find_library(METALKIT_FRAMEWORK MetalKit REQUIRED)
|
||||||
|
|
||||||
message(STATUS "Metal framework found")
|
message(STATUS "Metal framework found")
|
||||||
set(GGML_HEADERS_METAL ggml-metal.h)
|
set(GGML_HEADERS_METAL ggml-metal.h)
|
||||||
@ -173,6 +174,33 @@ if (LLAMA_METAL)
|
|||||||
# copy ggml-metal.metal to bin directory
|
# copy ggml-metal.metal to bin directory
|
||||||
configure_file(ggml-metal.metal ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal COPYONLY)
|
configure_file(ggml-metal.metal ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal COPYONLY)
|
||||||
|
|
||||||
|
if (LLAMA_METAL_SHADER_DEBUG)
|
||||||
|
# custom command to do the following:
|
||||||
|
# xcrun -sdk macosx metal -fno-fast-math -c ggml-metal.metal -o ggml-metal.air
|
||||||
|
# xcrun -sdk macosx metallib ggml-metal.air -o ggml.metallib
|
||||||
|
#
|
||||||
|
# note: this is the only way I found to disable fast-math in Metal. it's ugly, but at least it works
|
||||||
|
# disabling fast math is needed in order to pass tests/test-backend-ops
|
||||||
|
# note: adding -fno-inline fixes the tests when using MTL_SHADER_VALIDATION=1
|
||||||
|
set(XC_FLAGS -fno-fast-math -fno-inline -g)
|
||||||
|
if (LLAMA_QKK_64)
|
||||||
|
set(XC_FLAGS ${XC_FLAGS} -DQK_K=64)
|
||||||
|
endif()
|
||||||
|
|
||||||
|
add_custom_command(
|
||||||
|
OUTPUT ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml.metallib
|
||||||
|
COMMAND xcrun -sdk macosx metal ${XC_FLAGS} -c ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal -o ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.air
|
||||||
|
COMMAND xcrun -sdk macosx metallib ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.air -o ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml.metallib
|
||||||
|
DEPENDS ggml-metal.metal
|
||||||
|
COMMENT "Compiling Metal kernels"
|
||||||
|
)
|
||||||
|
|
||||||
|
add_custom_target(
|
||||||
|
ggml-metal ALL
|
||||||
|
DEPENDS ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml.metallib
|
||||||
|
)
|
||||||
|
endif()
|
||||||
|
|
||||||
set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS}
|
set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS}
|
||||||
${FOUNDATION_LIBRARY}
|
${FOUNDATION_LIBRARY}
|
||||||
${METAL_FRAMEWORK}
|
${METAL_FRAMEWORK}
|
||||||
|
14
ci/run.sh
14
ci/run.sh
@ -30,6 +30,12 @@ sd=`dirname $0`
|
|||||||
cd $sd/../
|
cd $sd/../
|
||||||
SRC=`pwd`
|
SRC=`pwd`
|
||||||
|
|
||||||
|
CMAKE_EXTRA=""
|
||||||
|
|
||||||
|
if [ ! -z ${GG_BUILD_METAL} ]; then
|
||||||
|
CMAKE_EXTRA="${CMAKE_EXTRA} -DLLAMA_METAL_SHADER_DEBUG=ON"
|
||||||
|
fi
|
||||||
|
|
||||||
## helpers
|
## helpers
|
||||||
|
|
||||||
# download a file if it does not exist or if it is outdated
|
# download a file if it does not exist or if it is outdated
|
||||||
@ -81,8 +87,8 @@ function gg_run_ctest_debug {
|
|||||||
|
|
||||||
set -e
|
set -e
|
||||||
|
|
||||||
(time cmake -DCMAKE_BUILD_TYPE=Debug .. ) 2>&1 | tee -a $OUT/${ci}-cmake.log
|
(time cmake -DCMAKE_BUILD_TYPE=Debug ${CMAKE_EXTRA} .. ) 2>&1 | tee -a $OUT/${ci}-cmake.log
|
||||||
(time make -j ) 2>&1 | tee -a $OUT/${ci}-make.log
|
(time make -j ) 2>&1 | tee -a $OUT/${ci}-make.log
|
||||||
|
|
||||||
(time ctest --output-on-failure -E test-opt ) 2>&1 | tee -a $OUT/${ci}-ctest.log
|
(time ctest --output-on-failure -E test-opt ) 2>&1 | tee -a $OUT/${ci}-ctest.log
|
||||||
|
|
||||||
@ -109,8 +115,8 @@ function gg_run_ctest_release {
|
|||||||
|
|
||||||
set -e
|
set -e
|
||||||
|
|
||||||
(time cmake -DCMAKE_BUILD_TYPE=Release .. ) 2>&1 | tee -a $OUT/${ci}-cmake.log
|
(time cmake -DCMAKE_BUILD_TYPE=Release ${CMAKE_EXTRA} .. ) 2>&1 | tee -a $OUT/${ci}-cmake.log
|
||||||
(time make -j ) 2>&1 | tee -a $OUT/${ci}-make.log
|
(time make -j ) 2>&1 | tee -a $OUT/${ci}-make.log
|
||||||
|
|
||||||
if [ -z ${GG_BUILD_LOW_PERF} ]; then
|
if [ -z ${GG_BUILD_LOW_PERF} ]; then
|
||||||
(time ctest --output-on-failure ) 2>&1 | tee -a $OUT/${ci}-ctest.log
|
(time ctest --output-on-failure ) 2>&1 | tee -a $OUT/${ci}-ctest.log
|
||||||
|
28
ggml-metal.m
28
ggml-metal.m
@ -257,13 +257,14 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|||||||
bundle = [NSBundle bundleForClass:[GGMLMetalClass class]];
|
bundle = [NSBundle bundleForClass:[GGMLMetalClass class]];
|
||||||
#endif
|
#endif
|
||||||
NSError * error = nil;
|
NSError * error = nil;
|
||||||
NSString * libPath = [bundle pathForResource:@"default" ofType:@"metallib"];
|
NSString * libPath = [bundle pathForResource:@"ggml" ofType:@"metallib"];
|
||||||
if (libPath != nil) {
|
if (libPath != nil) {
|
||||||
|
// pre-compiled library found
|
||||||
NSURL * libURL = [NSURL fileURLWithPath:libPath];
|
NSURL * libURL = [NSURL fileURLWithPath:libPath];
|
||||||
GGML_METAL_LOG_INFO("%s: loading '%s'\n", __func__, [libPath UTF8String]);
|
GGML_METAL_LOG_INFO("%s: loading '%s'\n", __func__, [libPath UTF8String]);
|
||||||
ctx->library = [ctx->device newLibraryWithURL:libURL error:&error];
|
ctx->library = [ctx->device newLibraryWithURL:libURL error:&error];
|
||||||
} else {
|
} else {
|
||||||
GGML_METAL_LOG_INFO("%s: default.metallib not found, loading from source\n", __func__);
|
GGML_METAL_LOG_INFO("%s: ggml.metallib not found, loading from source\n", __func__);
|
||||||
|
|
||||||
NSString * sourcePath;
|
NSString * sourcePath;
|
||||||
NSString * ggmlMetalPathResources = [[NSProcessInfo processInfo].environment objectForKey:@"GGML_METAL_PATH_RESOURCES"];
|
NSString * ggmlMetalPathResources = [[NSProcessInfo processInfo].environment objectForKey:@"GGML_METAL_PATH_RESOURCES"];
|
||||||
@ -291,6 +292,13 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|||||||
options = [MTLCompileOptions new];
|
options = [MTLCompileOptions new];
|
||||||
options.preprocessorMacros = @{ @"QK_K" : @(64) };
|
options.preprocessorMacros = @{ @"QK_K" : @(64) };
|
||||||
#endif
|
#endif
|
||||||
|
// try to disable fast-math
|
||||||
|
// NOTE: this seems to have no effect whatsoever
|
||||||
|
// instead, in order to disable fast-math, we have to build ggml.metallib from the command line
|
||||||
|
// using xcrun -sdk macosx metal -fno-fast-math -c ggml-metal.metal -o ggml-metal.air
|
||||||
|
// and go through the "pre-compiled library found" path above
|
||||||
|
//[options setFastMathEnabled:false];
|
||||||
|
|
||||||
ctx->library = [ctx->device newLibraryWithSource:src options:options error:&error];
|
ctx->library = [ctx->device newLibraryWithSource:src options:options error:&error];
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1230,7 +1238,7 @@ void ggml_metal_graph_compute(
|
|||||||
// not sure how to avoid this
|
// not sure how to avoid this
|
||||||
// TODO: make a simpler cpy_bytes kernel
|
// TODO: make a simpler cpy_bytes kernel
|
||||||
|
|
||||||
const int nth = MIN(1024, ne00);
|
const int nth = MIN((int) ctx->pipeline_cpy_f32_f32.maxTotalThreadsPerThreadgroup, ne00);
|
||||||
|
|
||||||
[encoder setComputePipelineState:ctx->pipeline_cpy_f32_f32];
|
[encoder setComputePipelineState:ctx->pipeline_cpy_f32_f32];
|
||||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||||
@ -1285,7 +1293,7 @@ void ggml_metal_graph_compute(
|
|||||||
[encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:26];
|
[encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:26];
|
||||||
[encoder setBytes:&offs length:sizeof(offs) atIndex:27];
|
[encoder setBytes:&offs length:sizeof(offs) atIndex:27];
|
||||||
|
|
||||||
const int nth = MIN(1024, ne0);
|
const int nth = MIN((int) ctx->pipeline_add.maxTotalThreadsPerThreadgroup, ne00);
|
||||||
|
|
||||||
[encoder dispatchThreadgroups:MTLSizeMake(ne11, ne12, ne13) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
[encoder dispatchThreadgroups:MTLSizeMake(ne11, ne12, ne13) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
||||||
} break;
|
} break;
|
||||||
@ -1785,8 +1793,9 @@ void ggml_metal_graph_compute(
|
|||||||
[encoder setBytes:&r3 length:sizeof(r3) atIndex:17];
|
[encoder setBytes:&r3 length:sizeof(r3) atIndex:17];
|
||||||
[encoder setBytes:&idx length:sizeof(idx) atIndex:18];
|
[encoder setBytes:&idx length:sizeof(idx) atIndex:18];
|
||||||
// TODO: how to make this an array? read Metal docs
|
// TODO: how to make this an array? read Metal docs
|
||||||
for (int j = 0; j < n_as; ++j) {
|
for (int j = 0; j < 8; ++j) {
|
||||||
struct ggml_tensor * src_cur = dst->src[2 + j];
|
// NOTE: this is done like this to avoid uninitialized kernel arguments when n_as < 8
|
||||||
|
struct ggml_tensor * src_cur = dst->src[2 + (j % n_as)];
|
||||||
|
|
||||||
size_t offs_src_cur = 0;
|
size_t offs_src_cur = 0;
|
||||||
id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur);
|
id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur);
|
||||||
@ -1909,8 +1918,9 @@ void ggml_metal_graph_compute(
|
|||||||
[encoder setBytes:&r3 length:sizeof(r3) atIndex:21];
|
[encoder setBytes:&r3 length:sizeof(r3) atIndex:21];
|
||||||
[encoder setBytes:&idx length:sizeof(idx) atIndex:22];
|
[encoder setBytes:&idx length:sizeof(idx) atIndex:22];
|
||||||
// TODO: how to make this an array? read Metal docs
|
// TODO: how to make this an array? read Metal docs
|
||||||
for (int j = 0; j < n_as; ++j) {
|
for (int j = 0; j < 8; ++j) {
|
||||||
struct ggml_tensor * src_cur = dst->src[2 + j];
|
// NOTE: this is done like this to avoid uninitialized kernel arguments when n_as < 8
|
||||||
|
struct ggml_tensor * src_cur = dst->src[2 + (j % n_as)];
|
||||||
|
|
||||||
size_t offs_src_cur = 0;
|
size_t offs_src_cur = 0;
|
||||||
id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur);
|
id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur);
|
||||||
@ -2229,7 +2239,7 @@ void ggml_metal_graph_compute(
|
|||||||
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
|
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
|
||||||
[encoder setBytes:&sf length:sizeof(sf) atIndex:18];
|
[encoder setBytes:&sf length:sizeof(sf) atIndex:18];
|
||||||
|
|
||||||
const int nth = MIN(1024, ne0);
|
const int nth = MIN((int) ctx->pipeline_upscale_f32.maxTotalThreadsPerThreadgroup, ne0);
|
||||||
|
|
||||||
[encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
[encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
||||||
} break;
|
} break;
|
||||||
|
475
ggml-metal.metal
475
ggml-metal.metal
@ -59,26 +59,26 @@ kernel void kernel_add(
|
|||||||
constant int64_t & ne01,
|
constant int64_t & ne01,
|
||||||
constant int64_t & ne02,
|
constant int64_t & ne02,
|
||||||
constant int64_t & ne03,
|
constant int64_t & ne03,
|
||||||
constant int64_t & nb00,
|
constant uint64_t & nb00,
|
||||||
constant int64_t & nb01,
|
constant uint64_t & nb01,
|
||||||
constant int64_t & nb02,
|
constant uint64_t & nb02,
|
||||||
constant int64_t & nb03,
|
constant uint64_t & nb03,
|
||||||
constant int64_t & ne10,
|
constant int64_t & ne10,
|
||||||
constant int64_t & ne11,
|
constant int64_t & ne11,
|
||||||
constant int64_t & ne12,
|
constant int64_t & ne12,
|
||||||
constant int64_t & ne13,
|
constant int64_t & ne13,
|
||||||
constant int64_t & nb10,
|
constant uint64_t & nb10,
|
||||||
constant int64_t & nb11,
|
constant uint64_t & nb11,
|
||||||
constant int64_t & nb12,
|
constant uint64_t & nb12,
|
||||||
constant int64_t & nb13,
|
constant uint64_t & nb13,
|
||||||
constant int64_t & ne0,
|
constant int64_t & ne0,
|
||||||
constant int64_t & ne1,
|
constant int64_t & ne1,
|
||||||
constant int64_t & ne2,
|
constant int64_t & ne2,
|
||||||
constant int64_t & ne3,
|
constant int64_t & ne3,
|
||||||
constant int64_t & nb0,
|
constant uint64_t & nb0,
|
||||||
constant int64_t & nb1,
|
constant uint64_t & nb1,
|
||||||
constant int64_t & nb2,
|
constant uint64_t & nb2,
|
||||||
constant int64_t & nb3,
|
constant uint64_t & nb3,
|
||||||
constant int64_t & offs,
|
constant int64_t & offs,
|
||||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||||
@ -109,26 +109,26 @@ kernel void kernel_mul(
|
|||||||
constant int64_t & ne01,
|
constant int64_t & ne01,
|
||||||
constant int64_t & ne02,
|
constant int64_t & ne02,
|
||||||
constant int64_t & ne03,
|
constant int64_t & ne03,
|
||||||
constant int64_t & nb00,
|
constant uint64_t & nb00,
|
||||||
constant int64_t & nb01,
|
constant uint64_t & nb01,
|
||||||
constant int64_t & nb02,
|
constant uint64_t & nb02,
|
||||||
constant int64_t & nb03,
|
constant uint64_t & nb03,
|
||||||
constant int64_t & ne10,
|
constant int64_t & ne10,
|
||||||
constant int64_t & ne11,
|
constant int64_t & ne11,
|
||||||
constant int64_t & ne12,
|
constant int64_t & ne12,
|
||||||
constant int64_t & ne13,
|
constant int64_t & ne13,
|
||||||
constant int64_t & nb10,
|
constant uint64_t & nb10,
|
||||||
constant int64_t & nb11,
|
constant uint64_t & nb11,
|
||||||
constant int64_t & nb12,
|
constant uint64_t & nb12,
|
||||||
constant int64_t & nb13,
|
constant uint64_t & nb13,
|
||||||
constant int64_t & ne0,
|
constant int64_t & ne0,
|
||||||
constant int64_t & ne1,
|
constant int64_t & ne1,
|
||||||
constant int64_t & ne2,
|
constant int64_t & ne2,
|
||||||
constant int64_t & ne3,
|
constant int64_t & ne3,
|
||||||
constant int64_t & nb0,
|
constant uint64_t & nb0,
|
||||||
constant int64_t & nb1,
|
constant uint64_t & nb1,
|
||||||
constant int64_t & nb2,
|
constant uint64_t & nb2,
|
||||||
constant int64_t & nb3,
|
constant uint64_t & nb3,
|
||||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||||
uint3 ntg[[threads_per_threadgroup]]) {
|
uint3 ntg[[threads_per_threadgroup]]) {
|
||||||
@ -158,26 +158,26 @@ kernel void kernel_div(
|
|||||||
constant int64_t & ne01,
|
constant int64_t & ne01,
|
||||||
constant int64_t & ne02,
|
constant int64_t & ne02,
|
||||||
constant int64_t & ne03,
|
constant int64_t & ne03,
|
||||||
constant int64_t & nb00,
|
constant uint64_t & nb00,
|
||||||
constant int64_t & nb01,
|
constant uint64_t & nb01,
|
||||||
constant int64_t & nb02,
|
constant uint64_t & nb02,
|
||||||
constant int64_t & nb03,
|
constant uint64_t & nb03,
|
||||||
constant int64_t & ne10,
|
constant int64_t & ne10,
|
||||||
constant int64_t & ne11,
|
constant int64_t & ne11,
|
||||||
constant int64_t & ne12,
|
constant int64_t & ne12,
|
||||||
constant int64_t & ne13,
|
constant int64_t & ne13,
|
||||||
constant int64_t & nb10,
|
constant uint64_t & nb10,
|
||||||
constant int64_t & nb11,
|
constant uint64_t & nb11,
|
||||||
constant int64_t & nb12,
|
constant uint64_t & nb12,
|
||||||
constant int64_t & nb13,
|
constant uint64_t & nb13,
|
||||||
constant int64_t & ne0,
|
constant int64_t & ne0,
|
||||||
constant int64_t & ne1,
|
constant int64_t & ne1,
|
||||||
constant int64_t & ne2,
|
constant int64_t & ne2,
|
||||||
constant int64_t & ne3,
|
constant int64_t & ne3,
|
||||||
constant int64_t & nb0,
|
constant uint64_t & nb0,
|
||||||
constant int64_t & nb1,
|
constant uint64_t & nb1,
|
||||||
constant int64_t & nb2,
|
constant uint64_t & nb2,
|
||||||
constant int64_t & nb3,
|
constant uint64_t & nb3,
|
||||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||||
uint3 ntg[[threads_per_threadgroup]]) {
|
uint3 ntg[[threads_per_threadgroup]]) {
|
||||||
@ -205,7 +205,7 @@ kernel void kernel_add_row(
|
|||||||
device const float4 * src0,
|
device const float4 * src0,
|
||||||
device const float4 * src1,
|
device const float4 * src1,
|
||||||
device float4 * dst,
|
device float4 * dst,
|
||||||
constant int64_t & nb [[buffer(28)]],
|
constant uint64_t & nb [[buffer(28)]],
|
||||||
uint tpig[[thread_position_in_grid]]) {
|
uint tpig[[thread_position_in_grid]]) {
|
||||||
dst[tpig] = src0[tpig] + src1[tpig % nb];
|
dst[tpig] = src0[tpig] + src1[tpig % nb];
|
||||||
}
|
}
|
||||||
@ -214,7 +214,7 @@ kernel void kernel_mul_row(
|
|||||||
device const float4 * src0,
|
device const float4 * src0,
|
||||||
device const float4 * src1,
|
device const float4 * src1,
|
||||||
device float4 * dst,
|
device float4 * dst,
|
||||||
constant int64_t & nb [[buffer(28)]],
|
constant uint64_t & nb [[buffer(28)]],
|
||||||
uint tpig[[thread_position_in_grid]]) {
|
uint tpig[[thread_position_in_grid]]) {
|
||||||
dst[tpig] = src0[tpig] * src1[tpig % nb];
|
dst[tpig] = src0[tpig] * src1[tpig % nb];
|
||||||
}
|
}
|
||||||
@ -223,7 +223,7 @@ kernel void kernel_div_row(
|
|||||||
device const float4 * src0,
|
device const float4 * src0,
|
||||||
device const float4 * src1,
|
device const float4 * src1,
|
||||||
device float4 * dst,
|
device float4 * dst,
|
||||||
constant int64_t & nb [[buffer(28)]],
|
constant uint64_t & nb [[buffer(28)]],
|
||||||
uint tpig[[thread_position_in_grid]]) {
|
uint tpig[[thread_position_in_grid]]) {
|
||||||
dst[tpig] = src0[tpig] / src1[tpig % nb];
|
dst[tpig] = src0[tpig] / src1[tpig % nb];
|
||||||
}
|
}
|
||||||
@ -307,26 +307,26 @@ kernel void kernel_sum_rows(
|
|||||||
constant int64_t & ne01,
|
constant int64_t & ne01,
|
||||||
constant int64_t & ne02,
|
constant int64_t & ne02,
|
||||||
constant int64_t & ne03,
|
constant int64_t & ne03,
|
||||||
constant int64_t & nb00,
|
constant uint64_t & nb00,
|
||||||
constant int64_t & nb01,
|
constant uint64_t & nb01,
|
||||||
constant int64_t & nb02,
|
constant uint64_t & nb02,
|
||||||
constant int64_t & nb03,
|
constant uint64_t & nb03,
|
||||||
constant int64_t & ne10,
|
constant int64_t & ne10,
|
||||||
constant int64_t & ne11,
|
constant int64_t & ne11,
|
||||||
constant int64_t & ne12,
|
constant int64_t & ne12,
|
||||||
constant int64_t & ne13,
|
constant int64_t & ne13,
|
||||||
constant int64_t & nb10,
|
constant uint64_t & nb10,
|
||||||
constant int64_t & nb11,
|
constant uint64_t & nb11,
|
||||||
constant int64_t & nb12,
|
constant uint64_t & nb12,
|
||||||
constant int64_t & nb13,
|
constant uint64_t & nb13,
|
||||||
constant int64_t & ne0,
|
constant int64_t & ne0,
|
||||||
constant int64_t & ne1,
|
constant int64_t & ne1,
|
||||||
constant int64_t & ne2,
|
constant int64_t & ne2,
|
||||||
constant int64_t & ne3,
|
constant int64_t & ne3,
|
||||||
constant int64_t & nb0,
|
constant uint64_t & nb0,
|
||||||
constant int64_t & nb1,
|
constant uint64_t & nb1,
|
||||||
constant int64_t & nb2,
|
constant uint64_t & nb2,
|
||||||
constant int64_t & nb3,
|
constant uint64_t & nb3,
|
||||||
uint3 tpig[[thread_position_in_grid]]) {
|
uint3 tpig[[thread_position_in_grid]]) {
|
||||||
int64_t i3 = tpig.z;
|
int64_t i3 = tpig.z;
|
||||||
int64_t i2 = tpig.y;
|
int64_t i2 = tpig.y;
|
||||||
@ -920,14 +920,21 @@ kernel void kernel_mul_mv_q4_0_f32(
|
|||||||
device const float * src1,
|
device const float * src1,
|
||||||
device float * dst,
|
device float * dst,
|
||||||
constant int64_t & ne00,
|
constant int64_t & ne00,
|
||||||
constant int64_t & ne01[[buffer(4)]],
|
constant int64_t & ne01,
|
||||||
constant int64_t & ne02[[buffer(5)]],
|
constant int64_t & ne02,
|
||||||
constant int64_t & ne10[[buffer(9)]],
|
constant uint64_t & nb00,
|
||||||
constant int64_t & ne12[[buffer(11)]],
|
constant uint64_t & nb01,
|
||||||
constant int64_t & ne0 [[buffer(15)]],
|
constant uint64_t & nb02,
|
||||||
constant int64_t & ne1 [[buffer(16)]],
|
constant int64_t & ne10,
|
||||||
constant uint & r2 [[buffer(17)]],
|
constant int64_t & ne11,
|
||||||
constant uint & r3 [[buffer(18)]],
|
constant int64_t & ne12,
|
||||||
|
constant uint64_t & nb10,
|
||||||
|
constant uint64_t & nb11,
|
||||||
|
constant uint64_t & nb12,
|
||||||
|
constant int64_t & ne0,
|
||||||
|
constant int64_t & ne1,
|
||||||
|
constant uint & r2,
|
||||||
|
constant uint & r3,
|
||||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
uint tiisg[[thread_index_in_simdgroup]],
|
uint tiisg[[thread_index_in_simdgroup]],
|
||||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||||
@ -939,14 +946,21 @@ kernel void kernel_mul_mv_q4_1_f32(
|
|||||||
device const float * src1,
|
device const float * src1,
|
||||||
device float * dst,
|
device float * dst,
|
||||||
constant int64_t & ne00,
|
constant int64_t & ne00,
|
||||||
constant int64_t & ne01[[buffer(4)]],
|
constant int64_t & ne01,
|
||||||
constant int64_t & ne02[[buffer(5)]],
|
constant int64_t & ne02,
|
||||||
constant int64_t & ne10[[buffer(9)]],
|
constant uint64_t & nb00,
|
||||||
constant int64_t & ne12[[buffer(11)]],
|
constant uint64_t & nb01,
|
||||||
constant int64_t & ne0 [[buffer(15)]],
|
constant uint64_t & nb02,
|
||||||
constant int64_t & ne1 [[buffer(16)]],
|
constant int64_t & ne10,
|
||||||
constant uint & r2 [[buffer(17)]],
|
constant int64_t & ne11,
|
||||||
constant uint & r3 [[buffer(18)]],
|
constant int64_t & ne12,
|
||||||
|
constant uint64_t & nb10,
|
||||||
|
constant uint64_t & nb11,
|
||||||
|
constant uint64_t & nb12,
|
||||||
|
constant int64_t & ne0,
|
||||||
|
constant int64_t & ne1,
|
||||||
|
constant uint & r2,
|
||||||
|
constant uint & r3,
|
||||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
uint tiisg[[thread_index_in_simdgroup]],
|
uint tiisg[[thread_index_in_simdgroup]],
|
||||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||||
@ -958,14 +972,21 @@ kernel void kernel_mul_mv_q5_0_f32(
|
|||||||
device const float * src1,
|
device const float * src1,
|
||||||
device float * dst,
|
device float * dst,
|
||||||
constant int64_t & ne00,
|
constant int64_t & ne00,
|
||||||
constant int64_t & ne01[[buffer(4)]],
|
constant int64_t & ne01,
|
||||||
constant int64_t & ne02[[buffer(5)]],
|
constant int64_t & ne02,
|
||||||
constant int64_t & ne10[[buffer(9)]],
|
constant uint64_t & nb00,
|
||||||
constant int64_t & ne12[[buffer(11)]],
|
constant uint64_t & nb01,
|
||||||
constant int64_t & ne0 [[buffer(15)]],
|
constant uint64_t & nb02,
|
||||||
constant int64_t & ne1 [[buffer(16)]],
|
constant int64_t & ne10,
|
||||||
constant uint & r2 [[buffer(17)]],
|
constant int64_t & ne11,
|
||||||
constant uint & r3 [[buffer(18)]],
|
constant int64_t & ne12,
|
||||||
|
constant uint64_t & nb10,
|
||||||
|
constant uint64_t & nb11,
|
||||||
|
constant uint64_t & nb12,
|
||||||
|
constant int64_t & ne0,
|
||||||
|
constant int64_t & ne1,
|
||||||
|
constant uint & r2,
|
||||||
|
constant uint & r3,
|
||||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
uint tiisg[[thread_index_in_simdgroup]],
|
uint tiisg[[thread_index_in_simdgroup]],
|
||||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||||
@ -977,14 +998,21 @@ kernel void kernel_mul_mv_q5_1_f32(
|
|||||||
device const float * src1,
|
device const float * src1,
|
||||||
device float * dst,
|
device float * dst,
|
||||||
constant int64_t & ne00,
|
constant int64_t & ne00,
|
||||||
constant int64_t & ne01[[buffer(4)]],
|
constant int64_t & ne01,
|
||||||
constant int64_t & ne02[[buffer(5)]],
|
constant int64_t & ne02,
|
||||||
constant int64_t & ne10[[buffer(9)]],
|
constant uint64_t & nb00,
|
||||||
constant int64_t & ne12[[buffer(11)]],
|
constant uint64_t & nb01,
|
||||||
constant int64_t & ne0 [[buffer(15)]],
|
constant uint64_t & nb02,
|
||||||
constant int64_t & ne1 [[buffer(16)]],
|
constant int64_t & ne10,
|
||||||
constant uint & r2 [[buffer(17)]],
|
constant int64_t & ne11,
|
||||||
constant uint & r3 [[buffer(18)]],
|
constant int64_t & ne12,
|
||||||
|
constant uint64_t & nb10,
|
||||||
|
constant uint64_t & nb11,
|
||||||
|
constant uint64_t & nb12,
|
||||||
|
constant int64_t & ne0,
|
||||||
|
constant int64_t & ne1,
|
||||||
|
constant uint & r2,
|
||||||
|
constant uint & r3,
|
||||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
uint tiisg[[thread_index_in_simdgroup]],
|
uint tiisg[[thread_index_in_simdgroup]],
|
||||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||||
@ -1071,12 +1099,19 @@ kernel void kernel_mul_mv_q8_0_f32(
|
|||||||
constant int64_t & ne00,
|
constant int64_t & ne00,
|
||||||
constant int64_t & ne01,
|
constant int64_t & ne01,
|
||||||
constant int64_t & ne02,
|
constant int64_t & ne02,
|
||||||
|
constant uint64_t & nb00,
|
||||||
|
constant uint64_t & nb01,
|
||||||
|
constant uint64_t & nb02,
|
||||||
constant int64_t & ne10,
|
constant int64_t & ne10,
|
||||||
|
constant int64_t & ne11,
|
||||||
constant int64_t & ne12,
|
constant int64_t & ne12,
|
||||||
|
constant uint64_t & nb10,
|
||||||
|
constant uint64_t & nb11,
|
||||||
|
constant uint64_t & nb12,
|
||||||
constant int64_t & ne0,
|
constant int64_t & ne0,
|
||||||
constant int64_t & ne1,
|
constant int64_t & ne1,
|
||||||
constant uint & r2 [[buffer(17)]],
|
constant uint & r2,
|
||||||
constant uint & r3 [[buffer(18)]],
|
constant uint & r3,
|
||||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
uint tiisg[[thread_index_in_simdgroup]],
|
uint tiisg[[thread_index_in_simdgroup]],
|
||||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||||
@ -1182,8 +1217,8 @@ kernel void kernel_mul_mv_f32_f32(
|
|||||||
constant uint64_t & nb12,
|
constant uint64_t & nb12,
|
||||||
constant int64_t & ne0,
|
constant int64_t & ne0,
|
||||||
constant int64_t & ne1,
|
constant int64_t & ne1,
|
||||||
constant uint & r2 [[buffer(17)]],
|
constant uint & r2,
|
||||||
constant uint & r3 [[buffer(18)]],
|
constant uint & r3,
|
||||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
uint tiisg[[thread_index_in_simdgroup]]) {
|
uint tiisg[[thread_index_in_simdgroup]]) {
|
||||||
kernel_mul_mv_f32_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
|
kernel_mul_mv_f32_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
|
||||||
@ -1209,8 +1244,8 @@ kernel void kernel_mul_mv_f16_f16(
|
|||||||
constant uint64_t & nb12,
|
constant uint64_t & nb12,
|
||||||
constant int64_t & ne0,
|
constant int64_t & ne0,
|
||||||
constant int64_t & ne1,
|
constant int64_t & ne1,
|
||||||
constant uint & r2 [[buffer(17)]],
|
constant uint & r2,
|
||||||
constant uint & r3 [[buffer(18)]],
|
constant uint & r3,
|
||||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
uint tiisg[[thread_index_in_simdgroup]]) {
|
uint tiisg[[thread_index_in_simdgroup]]) {
|
||||||
|
|
||||||
@ -1346,8 +1381,8 @@ kernel void kernel_mul_mv_f16_f32_1row(
|
|||||||
constant uint64_t & nb12,
|
constant uint64_t & nb12,
|
||||||
constant int64_t & ne0,
|
constant int64_t & ne0,
|
||||||
constant int64_t & ne1,
|
constant int64_t & ne1,
|
||||||
constant uint & r2 [[buffer(17)]],
|
constant uint & r2,
|
||||||
constant uint & r3 [[buffer(18)]],
|
constant uint & r3,
|
||||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
uint tiisg[[thread_index_in_simdgroup]]) {
|
uint tiisg[[thread_index_in_simdgroup]]) {
|
||||||
kernel_mul_mv_f16_f32_1row_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
|
kernel_mul_mv_f16_f32_1row_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
|
||||||
@ -1452,8 +1487,8 @@ kernel void kernel_mul_mv_f16_f32(
|
|||||||
constant uint64_t & nb12,
|
constant uint64_t & nb12,
|
||||||
constant int64_t & ne0,
|
constant int64_t & ne0,
|
||||||
constant int64_t & ne1,
|
constant int64_t & ne1,
|
||||||
constant uint & r2 [[buffer(17)]],
|
constant uint & r2,
|
||||||
constant uint & r3 [[buffer(18)]],
|
constant uint & r3,
|
||||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
uint tiisg[[thread_index_in_simdgroup]]) {
|
uint tiisg[[thread_index_in_simdgroup]]) {
|
||||||
kernel_mul_mv_f16_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
|
kernel_mul_mv_f16_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
|
||||||
@ -1478,8 +1513,8 @@ kernel void kernel_mul_mv_f16_f32_l4(
|
|||||||
constant uint64_t & nb12,
|
constant uint64_t & nb12,
|
||||||
constant int64_t & ne0,
|
constant int64_t & ne0,
|
||||||
constant int64_t & ne1,
|
constant int64_t & ne1,
|
||||||
constant uint & r2 [[buffer(17)]],
|
constant uint & r2,
|
||||||
constant uint & r3 [[buffer(18)]],
|
constant uint & r3,
|
||||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
uint tiisg[[thread_index_in_simdgroup]]) {
|
uint tiisg[[thread_index_in_simdgroup]]) {
|
||||||
|
|
||||||
@ -1543,7 +1578,8 @@ kernel void kernel_alibi_f32(
|
|||||||
const int64_t i3 = n / (ne2*ne1*ne0);
|
const int64_t i3 = n / (ne2*ne1*ne0);
|
||||||
const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
|
const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
|
||||||
const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
|
const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
|
||||||
const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
|
//const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
|
||||||
|
|
||||||
const int64_t k = i3*ne3 + i2;
|
const int64_t k = i3*ne3 + i2;
|
||||||
|
|
||||||
float m_k;
|
float m_k;
|
||||||
@ -2410,22 +2446,6 @@ typedef struct {
|
|||||||
} block_q6_K;
|
} block_q6_K;
|
||||||
// 210 bytes / block
|
// 210 bytes / block
|
||||||
|
|
||||||
static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) {
|
|
||||||
uchar4 r;
|
|
||||||
if (j < 4) {
|
|
||||||
r[0] = q[j+0] & 63;
|
|
||||||
r[2] = q[j+1] & 63;
|
|
||||||
r[1] = q[j+4] & 63;
|
|
||||||
r[3] = q[j+5] & 63;
|
|
||||||
} else {
|
|
||||||
r[0] = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4);
|
|
||||||
r[2] = (q[j+5] & 0xF) | ((q[j-3] >> 6) << 4);
|
|
||||||
r[1] = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4);
|
|
||||||
r[3] = (q[j+5] >> 4) | ((q[j+1] >> 6) << 4);
|
|
||||||
}
|
|
||||||
return r;
|
|
||||||
}
|
|
||||||
|
|
||||||
//====================================== dot products =========================
|
//====================================== dot products =========================
|
||||||
|
|
||||||
void kernel_mul_mv_q2_K_f32_impl(
|
void kernel_mul_mv_q2_K_f32_impl(
|
||||||
@ -2584,14 +2604,21 @@ kernel void kernel_mul_mv_q2_K_f32(
|
|||||||
device const float * src1,
|
device const float * src1,
|
||||||
device float * dst,
|
device float * dst,
|
||||||
constant int64_t & ne00,
|
constant int64_t & ne00,
|
||||||
constant int64_t & ne01[[buffer(4)]],
|
constant int64_t & ne01,
|
||||||
constant int64_t & ne02[[buffer(5)]],
|
constant int64_t & ne02,
|
||||||
constant int64_t & ne10[[buffer(9)]],
|
constant uint64_t & nb00,
|
||||||
constant int64_t & ne12[[buffer(11)]],
|
constant uint64_t & nb01,
|
||||||
constant int64_t & ne0 [[buffer(15)]],
|
constant uint64_t & nb02,
|
||||||
constant int64_t & ne1 [[buffer(16)]],
|
constant int64_t & ne10,
|
||||||
constant uint & r2 [[buffer(17)]],
|
constant int64_t & ne11,
|
||||||
constant uint & r3 [[buffer(18)]],
|
constant int64_t & ne12,
|
||||||
|
constant uint64_t & nb10,
|
||||||
|
constant uint64_t & nb11,
|
||||||
|
constant uint64_t & nb12,
|
||||||
|
constant int64_t & ne0,
|
||||||
|
constant int64_t & ne1,
|
||||||
|
constant uint & r2,
|
||||||
|
constant uint & r3,
|
||||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
uint tiisg[[thread_index_in_simdgroup]],
|
uint tiisg[[thread_index_in_simdgroup]],
|
||||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||||
@ -2841,14 +2868,21 @@ kernel void kernel_mul_mv_q3_K_f32(
|
|||||||
device const float * src1,
|
device const float * src1,
|
||||||
device float * dst,
|
device float * dst,
|
||||||
constant int64_t & ne00,
|
constant int64_t & ne00,
|
||||||
constant int64_t & ne01[[buffer(4)]],
|
constant int64_t & ne01,
|
||||||
constant int64_t & ne02[[buffer(5)]],
|
constant int64_t & ne02,
|
||||||
constant int64_t & ne10[[buffer(9)]],
|
constant uint64_t & nb00,
|
||||||
constant int64_t & ne12[[buffer(11)]],
|
constant uint64_t & nb01,
|
||||||
constant int64_t & ne0 [[buffer(15)]],
|
constant uint64_t & nb02,
|
||||||
constant int64_t & ne1 [[buffer(16)]],
|
constant int64_t & ne10,
|
||||||
constant uint & r2 [[buffer(17)]],
|
constant int64_t & ne11,
|
||||||
constant uint & r3 [[buffer(18)]],
|
constant int64_t & ne12,
|
||||||
|
constant uint64_t & nb10,
|
||||||
|
constant uint64_t & nb11,
|
||||||
|
constant uint64_t & nb12,
|
||||||
|
constant int64_t & ne0,
|
||||||
|
constant int64_t & ne1,
|
||||||
|
constant uint & r2,
|
||||||
|
constant uint & r3,
|
||||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
uint tiisg[[thread_index_in_simdgroup]],
|
uint tiisg[[thread_index_in_simdgroup]],
|
||||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||||
@ -2984,8 +3018,8 @@ void kernel_mul_mv_q4_K_f32_impl(
|
|||||||
constant uint & r2,
|
constant uint & r2,
|
||||||
constant uint & r3,
|
constant uint & r3,
|
||||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
uint tiisg[[thread_index_in_simdgroup]],
|
uint tiisg[[thread_index_in_simdgroup]],
|
||||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||||
|
|
||||||
const int ix = tiisg/4; // 0...7
|
const int ix = tiisg/4; // 0...7
|
||||||
const int it = tiisg%4; // 0...3
|
const int it = tiisg%4; // 0...3
|
||||||
@ -2994,7 +3028,7 @@ void kernel_mul_mv_q4_K_f32_impl(
|
|||||||
const int r0 = tgpig.x;
|
const int r0 = tgpig.x;
|
||||||
const int r1 = tgpig.y;
|
const int r1 = tgpig.y;
|
||||||
const int im = tgpig.z;
|
const int im = tgpig.z;
|
||||||
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
|
const int first_row = r0 * N_DST;
|
||||||
const int ib_row = first_row * nb;
|
const int ib_row = first_row * nb;
|
||||||
|
|
||||||
const uint i12 = im%ne12;
|
const uint i12 = im%ne12;
|
||||||
@ -3060,7 +3094,7 @@ void kernel_mul_mv_q4_K_f32_impl(
|
|||||||
for (int row = 0; row < N_DST; ++row) {
|
for (int row = 0; row < N_DST; ++row) {
|
||||||
all_sum = simd_sum(sumf[row]);
|
all_sum = simd_sum(sumf[row]);
|
||||||
if (tiisg == 0) {
|
if (tiisg == 0) {
|
||||||
dst[r1*ne0+ im*ne0*ne1 + first_row + row] = all_sum;
|
dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -3072,14 +3106,21 @@ kernel void kernel_mul_mv_q4_K_f32(
|
|||||||
device const float * src1,
|
device const float * src1,
|
||||||
device float * dst,
|
device float * dst,
|
||||||
constant int64_t & ne00,
|
constant int64_t & ne00,
|
||||||
constant int64_t & ne01[[buffer(4)]],
|
constant int64_t & ne01,
|
||||||
constant int64_t & ne02[[buffer(5)]],
|
constant int64_t & ne02,
|
||||||
constant int64_t & ne10[[buffer(9)]],
|
constant uint64_t & nb00,
|
||||||
constant int64_t & ne12[[buffer(11)]],
|
constant uint64_t & nb01,
|
||||||
constant int64_t & ne0 [[buffer(15)]],
|
constant uint64_t & nb02,
|
||||||
constant int64_t & ne1 [[buffer(16)]],
|
constant int64_t & ne10,
|
||||||
constant uint & r2 [[buffer(17)]],
|
constant int64_t & ne11,
|
||||||
constant uint & r3 [[buffer(18)]],
|
constant int64_t & ne12,
|
||||||
|
constant uint64_t & nb10,
|
||||||
|
constant uint64_t & nb11,
|
||||||
|
constant uint64_t & nb12,
|
||||||
|
constant int64_t & ne0,
|
||||||
|
constant int64_t & ne1,
|
||||||
|
constant uint & r2,
|
||||||
|
constant uint & r3,
|
||||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
uint tiisg[[thread_index_in_simdgroup]],
|
uint tiisg[[thread_index_in_simdgroup]],
|
||||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||||
@ -3271,14 +3312,21 @@ kernel void kernel_mul_mv_q5_K_f32(
|
|||||||
device const float * src1,
|
device const float * src1,
|
||||||
device float * dst,
|
device float * dst,
|
||||||
constant int64_t & ne00,
|
constant int64_t & ne00,
|
||||||
constant int64_t & ne01[[buffer(4)]],
|
constant int64_t & ne01,
|
||||||
constant int64_t & ne02[[buffer(5)]],
|
constant int64_t & ne02,
|
||||||
constant int64_t & ne10[[buffer(9)]],
|
constant uint64_t & nb00,
|
||||||
constant int64_t & ne12[[buffer(11)]],
|
constant uint64_t & nb01,
|
||||||
constant int64_t & ne0 [[buffer(15)]],
|
constant uint64_t & nb02,
|
||||||
constant int64_t & ne1 [[buffer(16)]],
|
constant int64_t & ne10,
|
||||||
constant uint & r2 [[buffer(17)]],
|
constant int64_t & ne11,
|
||||||
constant uint & r3 [[buffer(18)]],
|
constant int64_t & ne12,
|
||||||
|
constant uint64_t & nb10,
|
||||||
|
constant uint64_t & nb11,
|
||||||
|
constant uint64_t & nb12,
|
||||||
|
constant int64_t & ne0,
|
||||||
|
constant int64_t & ne1,
|
||||||
|
constant uint & r2,
|
||||||
|
constant uint & r3,
|
||||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
uint tiisg[[thread_index_in_simdgroup]],
|
uint tiisg[[thread_index_in_simdgroup]],
|
||||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||||
@ -3398,14 +3446,21 @@ kernel void kernel_mul_mv_q6_K_f32(
|
|||||||
device const float * src1,
|
device const float * src1,
|
||||||
device float * dst,
|
device float * dst,
|
||||||
constant int64_t & ne00,
|
constant int64_t & ne00,
|
||||||
constant int64_t & ne01[[buffer(4)]],
|
constant int64_t & ne01,
|
||||||
constant int64_t & ne02[[buffer(5)]],
|
constant int64_t & ne02,
|
||||||
constant int64_t & ne10[[buffer(9)]],
|
constant uint64_t & nb00,
|
||||||
constant int64_t & ne12[[buffer(11)]],
|
constant uint64_t & nb01,
|
||||||
constant int64_t & ne0 [[buffer(15)]],
|
constant uint64_t & nb02,
|
||||||
constant int64_t & ne1 [[buffer(16)]],
|
constant int64_t & ne10,
|
||||||
constant uint & r2 [[buffer(17)]],
|
constant int64_t & ne11,
|
||||||
constant uint & r3 [[buffer(18)]],
|
constant int64_t & ne12,
|
||||||
|
constant uint64_t & nb10,
|
||||||
|
constant uint64_t & nb11,
|
||||||
|
constant uint64_t & nb12,
|
||||||
|
constant int64_t & ne0,
|
||||||
|
constant int64_t & ne1,
|
||||||
|
constant uint & r2,
|
||||||
|
constant uint & r3,
|
||||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
uint tiisg[[thread_index_in_simdgroup]],
|
uint tiisg[[thread_index_in_simdgroup]],
|
||||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||||
@ -3523,7 +3578,7 @@ void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg
|
|||||||
device const int8_t * qs = ((device const int8_t *)xb->qs);
|
device const int8_t * qs = ((device const int8_t *)xb->qs);
|
||||||
const half d = xb->d;
|
const half d = xb->d;
|
||||||
|
|
||||||
for (int i=0;i<16;i++) {
|
for (int i = 0; i < 16; i++) {
|
||||||
reg[i/4][i%4] = (qs[i + 16*il] * d);
|
reg[i/4][i%4] = (qs[i + 16*il] * d);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -3792,12 +3847,12 @@ void kernel_mul_mm_impl(device const uchar * src0,
|
|||||||
device float * dst,
|
device float * dst,
|
||||||
constant int64_t & ne00,
|
constant int64_t & ne00,
|
||||||
constant int64_t & ne02,
|
constant int64_t & ne02,
|
||||||
constant int64_t & nb01,
|
constant uint64_t & nb01,
|
||||||
constant int64_t & nb02,
|
constant uint64_t & nb02,
|
||||||
constant int64_t & ne12,
|
constant int64_t & ne12,
|
||||||
constant int64_t & nb10,
|
constant uint64_t & nb10,
|
||||||
constant int64_t & nb11,
|
constant uint64_t & nb11,
|
||||||
constant int64_t & nb12,
|
constant uint64_t & nb12,
|
||||||
constant int64_t & ne0,
|
constant int64_t & ne0,
|
||||||
constant int64_t & ne1,
|
constant int64_t & ne1,
|
||||||
constant uint & r2,
|
constant uint & r2,
|
||||||
@ -3924,12 +3979,12 @@ kernel void kernel_mul_mm(device const uchar * src0,
|
|||||||
device float * dst,
|
device float * dst,
|
||||||
constant int64_t & ne00,
|
constant int64_t & ne00,
|
||||||
constant int64_t & ne02,
|
constant int64_t & ne02,
|
||||||
constant int64_t & nb01,
|
constant uint64_t & nb01,
|
||||||
constant int64_t & nb02,
|
constant uint64_t & nb02,
|
||||||
constant int64_t & ne12,
|
constant int64_t & ne12,
|
||||||
constant int64_t & nb10,
|
constant uint64_t & nb10,
|
||||||
constant int64_t & nb11,
|
constant uint64_t & nb11,
|
||||||
constant int64_t & nb12,
|
constant uint64_t & nb12,
|
||||||
constant int64_t & ne0,
|
constant int64_t & ne0,
|
||||||
constant int64_t & ne1,
|
constant int64_t & ne1,
|
||||||
constant uint & r2,
|
constant uint & r2,
|
||||||
@ -3965,19 +4020,19 @@ kernel void kernel_mul_mm_id(
|
|||||||
device const uchar * ids,
|
device const uchar * ids,
|
||||||
device const uchar * src1,
|
device const uchar * src1,
|
||||||
device uchar * dst,
|
device uchar * dst,
|
||||||
constant int64_t & nbi1,
|
constant uint64_t & nbi1,
|
||||||
constant int64_t & ne00,
|
constant int64_t & ne00,
|
||||||
constant int64_t & ne02,
|
constant int64_t & ne02,
|
||||||
constant int64_t & nb01,
|
constant uint64_t & nb01,
|
||||||
constant int64_t & nb02,
|
constant uint64_t & nb02,
|
||||||
constant int64_t & ne12,
|
constant int64_t & ne12,
|
||||||
constant int64_t & ne13,
|
constant int64_t & ne13,
|
||||||
constant int64_t & nb10,
|
constant uint64_t & nb10,
|
||||||
constant int64_t & nb11,
|
constant uint64_t & nb11,
|
||||||
constant int64_t & nb12,
|
constant uint64_t & nb12,
|
||||||
constant int64_t & ne0,
|
constant int64_t & ne0,
|
||||||
constant int64_t & ne1,
|
constant int64_t & ne1,
|
||||||
constant int64_t & nb1,
|
constant uint64_t & nb1,
|
||||||
constant uint & r2,
|
constant uint & r2,
|
||||||
constant uint & r3,
|
constant uint & r3,
|
||||||
constant int & idx,
|
constant int & idx,
|
||||||
@ -4070,12 +4125,12 @@ typedef void (mat_mm_t)(
|
|||||||
device float * dst,
|
device float * dst,
|
||||||
constant int64_t & ne00,
|
constant int64_t & ne00,
|
||||||
constant int64_t & ne02,
|
constant int64_t & ne02,
|
||||||
constant int64_t & nb01,
|
constant uint64_t & nb01,
|
||||||
constant int64_t & nb02,
|
constant uint64_t & nb02,
|
||||||
constant int64_t & ne12,
|
constant int64_t & ne12,
|
||||||
constant int64_t & nb10,
|
constant uint64_t & nb10,
|
||||||
constant int64_t & nb11,
|
constant uint64_t & nb11,
|
||||||
constant int64_t & nb12,
|
constant uint64_t & nb12,
|
||||||
constant int64_t & ne0,
|
constant int64_t & ne0,
|
||||||
constant int64_t & ne1,
|
constant int64_t & ne1,
|
||||||
constant uint & r2,
|
constant uint & r2,
|
||||||
@ -4104,19 +4159,19 @@ typedef void (mat_mm_id_t)(
|
|||||||
device const uchar * ids,
|
device const uchar * ids,
|
||||||
device const uchar * src1,
|
device const uchar * src1,
|
||||||
device uchar * dst,
|
device uchar * dst,
|
||||||
constant int64_t & nbi1,
|
constant uint64_t & nbi1,
|
||||||
constant int64_t & ne00,
|
constant int64_t & ne00,
|
||||||
constant int64_t & ne02,
|
constant int64_t & ne02,
|
||||||
constant int64_t & nb01,
|
constant uint64_t & nb01,
|
||||||
constant int64_t & nb02,
|
constant uint64_t & nb02,
|
||||||
constant int64_t & ne12,
|
constant int64_t & ne12,
|
||||||
constant int64_t & ne13,
|
constant int64_t & ne13,
|
||||||
constant int64_t & nb10,
|
constant uint64_t & nb10,
|
||||||
constant int64_t & nb11,
|
constant uint64_t & nb11,
|
||||||
constant int64_t & nb12,
|
constant uint64_t & nb12,
|
||||||
constant int64_t & ne0,
|
constant int64_t & ne0,
|
||||||
constant int64_t & ne1,
|
constant int64_t & ne1,
|
||||||
constant int64_t & nb1,
|
constant uint64_t & nb1,
|
||||||
constant uint & r2,
|
constant uint & r2,
|
||||||
constant uint & r3,
|
constant uint & r3,
|
||||||
constant int & idx,
|
constant int & idx,
|
||||||
@ -4153,7 +4208,7 @@ kernel void kernel_mul_mv_id_f32_f32(
|
|||||||
device const char * ids,
|
device const char * ids,
|
||||||
device const char * src1,
|
device const char * src1,
|
||||||
device uchar * dst,
|
device uchar * dst,
|
||||||
constant int64_t & nbi1,
|
constant uint64_t & nbi1,
|
||||||
constant int64_t & ne00,
|
constant int64_t & ne00,
|
||||||
constant int64_t & ne01,
|
constant int64_t & ne01,
|
||||||
constant int64_t & ne02,
|
constant int64_t & ne02,
|
||||||
@ -4169,7 +4224,7 @@ kernel void kernel_mul_mv_id_f32_f32(
|
|||||||
constant uint64_t & nb12,
|
constant uint64_t & nb12,
|
||||||
constant int64_t & ne0,
|
constant int64_t & ne0,
|
||||||
constant int64_t & ne1,
|
constant int64_t & ne1,
|
||||||
constant int64_t & nb1,
|
constant uint64_t & nb1,
|
||||||
constant uint & r2,
|
constant uint & r2,
|
||||||
constant uint & r3,
|
constant uint & r3,
|
||||||
constant int & idx,
|
constant int & idx,
|
||||||
@ -4222,7 +4277,7 @@ kernel void kernel_mul_mv_id_f16_f32(
|
|||||||
device const char * ids,
|
device const char * ids,
|
||||||
device const char * src1,
|
device const char * src1,
|
||||||
device uchar * dst,
|
device uchar * dst,
|
||||||
constant int64_t & nbi1,
|
constant uint64_t & nbi1,
|
||||||
constant int64_t & ne00,
|
constant int64_t & ne00,
|
||||||
constant int64_t & ne01,
|
constant int64_t & ne01,
|
||||||
constant int64_t & ne02,
|
constant int64_t & ne02,
|
||||||
@ -4238,7 +4293,7 @@ kernel void kernel_mul_mv_id_f16_f32(
|
|||||||
constant uint64_t & nb12,
|
constant uint64_t & nb12,
|
||||||
constant int64_t & ne0,
|
constant int64_t & ne0,
|
||||||
constant int64_t & ne1,
|
constant int64_t & ne1,
|
||||||
constant int64_t & nb1,
|
constant uint64_t & nb1,
|
||||||
constant uint & r2,
|
constant uint & r2,
|
||||||
constant uint & r3,
|
constant uint & r3,
|
||||||
constant int & idx,
|
constant int & idx,
|
||||||
@ -4291,7 +4346,7 @@ kernel void kernel_mul_mv_id_q8_0_f32(
|
|||||||
device const char * ids,
|
device const char * ids,
|
||||||
device const char * src1,
|
device const char * src1,
|
||||||
device uchar * dst,
|
device uchar * dst,
|
||||||
constant int64_t & nbi1,
|
constant uint64_t & nbi1,
|
||||||
constant int64_t & ne00,
|
constant int64_t & ne00,
|
||||||
constant int64_t & ne01,
|
constant int64_t & ne01,
|
||||||
constant int64_t & ne02,
|
constant int64_t & ne02,
|
||||||
@ -4307,7 +4362,7 @@ kernel void kernel_mul_mv_id_q8_0_f32(
|
|||||||
constant uint64_t & nb12,
|
constant uint64_t & nb12,
|
||||||
constant int64_t & ne0,
|
constant int64_t & ne0,
|
||||||
constant int64_t & ne1,
|
constant int64_t & ne1,
|
||||||
constant int64_t & nb1,
|
constant uint64_t & nb1,
|
||||||
constant uint & r2,
|
constant uint & r2,
|
||||||
constant uint & r3,
|
constant uint & r3,
|
||||||
constant int & idx,
|
constant int & idx,
|
||||||
@ -4354,7 +4409,7 @@ kernel void kernel_mul_mv_id_q4_0_f32(
|
|||||||
device const char * ids,
|
device const char * ids,
|
||||||
device const char * src1,
|
device const char * src1,
|
||||||
device uchar * dst,
|
device uchar * dst,
|
||||||
constant int64_t & nbi1,
|
constant uint64_t & nbi1,
|
||||||
constant int64_t & ne00,
|
constant int64_t & ne00,
|
||||||
constant int64_t & ne01,
|
constant int64_t & ne01,
|
||||||
constant int64_t & ne02,
|
constant int64_t & ne02,
|
||||||
@ -4370,7 +4425,7 @@ kernel void kernel_mul_mv_id_q4_0_f32(
|
|||||||
constant uint64_t & nb12,
|
constant uint64_t & nb12,
|
||||||
constant int64_t & ne0,
|
constant int64_t & ne0,
|
||||||
constant int64_t & ne1,
|
constant int64_t & ne1,
|
||||||
constant int64_t & nb1,
|
constant uint64_t & nb1,
|
||||||
constant uint & r2,
|
constant uint & r2,
|
||||||
constant uint & r3,
|
constant uint & r3,
|
||||||
constant int & idx,
|
constant int & idx,
|
||||||
@ -4417,7 +4472,7 @@ kernel void kernel_mul_mv_id_q4_1_f32(
|
|||||||
device const char * ids,
|
device const char * ids,
|
||||||
device const char * src1,
|
device const char * src1,
|
||||||
device uchar * dst,
|
device uchar * dst,
|
||||||
constant int64_t & nbi1,
|
constant uint64_t & nbi1,
|
||||||
constant int64_t & ne00,
|
constant int64_t & ne00,
|
||||||
constant int64_t & ne01,
|
constant int64_t & ne01,
|
||||||
constant int64_t & ne02,
|
constant int64_t & ne02,
|
||||||
@ -4433,7 +4488,7 @@ kernel void kernel_mul_mv_id_q4_1_f32(
|
|||||||
constant uint64_t & nb12,
|
constant uint64_t & nb12,
|
||||||
constant int64_t & ne0,
|
constant int64_t & ne0,
|
||||||
constant int64_t & ne1,
|
constant int64_t & ne1,
|
||||||
constant int64_t & nb1,
|
constant uint64_t & nb1,
|
||||||
constant uint & r2,
|
constant uint & r2,
|
||||||
constant uint & r3,
|
constant uint & r3,
|
||||||
constant int & idx,
|
constant int & idx,
|
||||||
@ -4480,7 +4535,7 @@ kernel void kernel_mul_mv_id_q5_0_f32(
|
|||||||
device const char * ids,
|
device const char * ids,
|
||||||
device const char * src1,
|
device const char * src1,
|
||||||
device uchar * dst,
|
device uchar * dst,
|
||||||
constant int64_t & nbi1,
|
constant uint64_t & nbi1,
|
||||||
constant int64_t & ne00,
|
constant int64_t & ne00,
|
||||||
constant int64_t & ne01,
|
constant int64_t & ne01,
|
||||||
constant int64_t & ne02,
|
constant int64_t & ne02,
|
||||||
@ -4496,7 +4551,7 @@ kernel void kernel_mul_mv_id_q5_0_f32(
|
|||||||
constant uint64_t & nb12,
|
constant uint64_t & nb12,
|
||||||
constant int64_t & ne0,
|
constant int64_t & ne0,
|
||||||
constant int64_t & ne1,
|
constant int64_t & ne1,
|
||||||
constant int64_t & nb1,
|
constant uint64_t & nb1,
|
||||||
constant uint & r2,
|
constant uint & r2,
|
||||||
constant uint & r3,
|
constant uint & r3,
|
||||||
constant int & idx,
|
constant int & idx,
|
||||||
@ -4543,7 +4598,7 @@ kernel void kernel_mul_mv_id_q5_1_f32(
|
|||||||
device const char * ids,
|
device const char * ids,
|
||||||
device const char * src1,
|
device const char * src1,
|
||||||
device uchar * dst,
|
device uchar * dst,
|
||||||
constant int64_t & nbi1,
|
constant uint64_t & nbi1,
|
||||||
constant int64_t & ne00,
|
constant int64_t & ne00,
|
||||||
constant int64_t & ne01,
|
constant int64_t & ne01,
|
||||||
constant int64_t & ne02,
|
constant int64_t & ne02,
|
||||||
@ -4559,7 +4614,7 @@ kernel void kernel_mul_mv_id_q5_1_f32(
|
|||||||
constant uint64_t & nb12,
|
constant uint64_t & nb12,
|
||||||
constant int64_t & ne0,
|
constant int64_t & ne0,
|
||||||
constant int64_t & ne1,
|
constant int64_t & ne1,
|
||||||
constant int64_t & nb1,
|
constant uint64_t & nb1,
|
||||||
constant uint & r2,
|
constant uint & r2,
|
||||||
constant uint & r3,
|
constant uint & r3,
|
||||||
constant int & idx,
|
constant int & idx,
|
||||||
@ -4606,7 +4661,7 @@ kernel void kernel_mul_mv_id_q2_K_f32(
|
|||||||
device const char * ids,
|
device const char * ids,
|
||||||
device const char * src1,
|
device const char * src1,
|
||||||
device uchar * dst,
|
device uchar * dst,
|
||||||
constant int64_t & nbi1,
|
constant uint64_t & nbi1,
|
||||||
constant int64_t & ne00,
|
constant int64_t & ne00,
|
||||||
constant int64_t & ne01,
|
constant int64_t & ne01,
|
||||||
constant int64_t & ne02,
|
constant int64_t & ne02,
|
||||||
@ -4622,7 +4677,7 @@ kernel void kernel_mul_mv_id_q2_K_f32(
|
|||||||
constant uint64_t & nb12,
|
constant uint64_t & nb12,
|
||||||
constant int64_t & ne0,
|
constant int64_t & ne0,
|
||||||
constant int64_t & ne1,
|
constant int64_t & ne1,
|
||||||
constant int64_t & nb1,
|
constant uint64_t & nb1,
|
||||||
constant uint & r2,
|
constant uint & r2,
|
||||||
constant uint & r3,
|
constant uint & r3,
|
||||||
constant int & idx,
|
constant int & idx,
|
||||||
@ -4669,7 +4724,7 @@ kernel void kernel_mul_mv_id_q3_K_f32(
|
|||||||
device const char * ids,
|
device const char * ids,
|
||||||
device const char * src1,
|
device const char * src1,
|
||||||
device uchar * dst,
|
device uchar * dst,
|
||||||
constant int64_t & nbi1,
|
constant uint64_t & nbi1,
|
||||||
constant int64_t & ne00,
|
constant int64_t & ne00,
|
||||||
constant int64_t & ne01,
|
constant int64_t & ne01,
|
||||||
constant int64_t & ne02,
|
constant int64_t & ne02,
|
||||||
@ -4685,7 +4740,7 @@ kernel void kernel_mul_mv_id_q3_K_f32(
|
|||||||
constant uint64_t & nb12,
|
constant uint64_t & nb12,
|
||||||
constant int64_t & ne0,
|
constant int64_t & ne0,
|
||||||
constant int64_t & ne1,
|
constant int64_t & ne1,
|
||||||
constant int64_t & nb1,
|
constant uint64_t & nb1,
|
||||||
constant uint & r2,
|
constant uint & r2,
|
||||||
constant uint & r3,
|
constant uint & r3,
|
||||||
constant int & idx,
|
constant int & idx,
|
||||||
@ -4732,7 +4787,7 @@ kernel void kernel_mul_mv_id_q4_K_f32(
|
|||||||
device const char * ids,
|
device const char * ids,
|
||||||
device const char * src1,
|
device const char * src1,
|
||||||
device uchar * dst,
|
device uchar * dst,
|
||||||
constant int64_t & nbi1,
|
constant uint64_t & nbi1,
|
||||||
constant int64_t & ne00,
|
constant int64_t & ne00,
|
||||||
constant int64_t & ne01,
|
constant int64_t & ne01,
|
||||||
constant int64_t & ne02,
|
constant int64_t & ne02,
|
||||||
@ -4748,7 +4803,7 @@ kernel void kernel_mul_mv_id_q4_K_f32(
|
|||||||
constant uint64_t & nb12,
|
constant uint64_t & nb12,
|
||||||
constant int64_t & ne0,
|
constant int64_t & ne0,
|
||||||
constant int64_t & ne1,
|
constant int64_t & ne1,
|
||||||
constant int64_t & nb1,
|
constant uint64_t & nb1,
|
||||||
constant uint & r2,
|
constant uint & r2,
|
||||||
constant uint & r3,
|
constant uint & r3,
|
||||||
constant int & idx,
|
constant int & idx,
|
||||||
@ -4795,7 +4850,7 @@ kernel void kernel_mul_mv_id_q5_K_f32(
|
|||||||
device const char * ids,
|
device const char * ids,
|
||||||
device const char * src1,
|
device const char * src1,
|
||||||
device uchar * dst,
|
device uchar * dst,
|
||||||
constant int64_t & nbi1,
|
constant uint64_t & nbi1,
|
||||||
constant int64_t & ne00,
|
constant int64_t & ne00,
|
||||||
constant int64_t & ne01,
|
constant int64_t & ne01,
|
||||||
constant int64_t & ne02,
|
constant int64_t & ne02,
|
||||||
@ -4811,7 +4866,7 @@ kernel void kernel_mul_mv_id_q5_K_f32(
|
|||||||
constant uint64_t & nb12,
|
constant uint64_t & nb12,
|
||||||
constant int64_t & ne0,
|
constant int64_t & ne0,
|
||||||
constant int64_t & ne1,
|
constant int64_t & ne1,
|
||||||
constant int64_t & nb1,
|
constant uint64_t & nb1,
|
||||||
constant uint & r2,
|
constant uint & r2,
|
||||||
constant uint & r3,
|
constant uint & r3,
|
||||||
constant int & idx,
|
constant int & idx,
|
||||||
@ -4858,7 +4913,7 @@ kernel void kernel_mul_mv_id_q6_K_f32(
|
|||||||
device const char * ids,
|
device const char * ids,
|
||||||
device const char * src1,
|
device const char * src1,
|
||||||
device uchar * dst,
|
device uchar * dst,
|
||||||
constant int64_t & nbi1,
|
constant uint64_t & nbi1,
|
||||||
constant int64_t & ne00,
|
constant int64_t & ne00,
|
||||||
constant int64_t & ne01,
|
constant int64_t & ne01,
|
||||||
constant int64_t & ne02,
|
constant int64_t & ne02,
|
||||||
@ -4874,7 +4929,7 @@ kernel void kernel_mul_mv_id_q6_K_f32(
|
|||||||
constant uint64_t & nb12,
|
constant uint64_t & nb12,
|
||||||
constant int64_t & ne0,
|
constant int64_t & ne0,
|
||||||
constant int64_t & ne1,
|
constant int64_t & ne1,
|
||||||
constant int64_t & nb1,
|
constant uint64_t & nb1,
|
||||||
constant uint & r2,
|
constant uint & r2,
|
||||||
constant uint & r3,
|
constant uint & r3,
|
||||||
constant int & idx,
|
constant int & idx,
|
||||||
|
@ -15,19 +15,18 @@
|
|||||||
#include <thread>
|
#include <thread>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
|
||||||
static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float max = 1.0f) {
|
static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float max = 1.0f) {
|
||||||
size_t size = ggml_nelements(tensor);
|
size_t size = ggml_nelements(tensor);
|
||||||
std::vector<float> data(size);
|
std::vector<float> data(size);
|
||||||
|
|
||||||
#if 0
|
#if 0
|
||||||
std::default_random_engine generator(rd());
|
static std::default_random_engine generator(1234);
|
||||||
std::uniform_real_distribution<float> distribution(min, max);
|
std::uniform_real_distribution<float> distribution(min, max);
|
||||||
|
|
||||||
for (size_t i = 0; i < size; i++) {
|
for (size_t i = 0; i < size; i++) {
|
||||||
data[i] = distribution(generator);
|
data[i] = distribution(generator);
|
||||||
}
|
}
|
||||||
#endif
|
#else
|
||||||
auto init_thread = [&](size_t start, size_t end) {
|
auto init_thread = [&](size_t start, size_t end) {
|
||||||
std::random_device rd;
|
std::random_device rd;
|
||||||
std::default_random_engine generator(rd());
|
std::default_random_engine generator(rd());
|
||||||
@ -49,6 +48,7 @@ static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float m
|
|||||||
for (auto & t : threads) {
|
for (auto & t : threads) {
|
||||||
t.join();
|
t.join();
|
||||||
}
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
if (tensor->type == GGML_TYPE_F32 || tensor->type == GGML_TYPE_I32) {
|
if (tensor->type == GGML_TYPE_F32 || tensor->type == GGML_TYPE_I32) {
|
||||||
ggml_backend_tensor_set(tensor, data.data(), 0, size * sizeof(float));
|
ggml_backend_tensor_set(tensor, data.data(), 0, size * sizeof(float));
|
||||||
@ -437,7 +437,7 @@ struct test_case {
|
|||||||
double err = nmse(f1.data(), f2.data(), f1.size());
|
double err = nmse(f1.data(), f2.data(), f1.size());
|
||||||
if (err > ud->max_err) {
|
if (err > ud->max_err) {
|
||||||
printf("[%s] NMSE = %f ", ggml_op_desc(t1), err);
|
printf("[%s] NMSE = %f ", ggml_op_desc(t1), err);
|
||||||
//for (int i = 0; i < f1.size(); i++) {
|
//for (int i = 0; i < (int) f1.size(); i++) {
|
||||||
// printf("%5d %9.6f %9.6f, diff = %9.6f\n", i, f1[i], f2[i], f1[i] - f2[i]);
|
// printf("%5d %9.6f %9.6f, diff = %9.6f\n", i, f1[i], f2[i], f1[i] - f2[i]);
|
||||||
//}
|
//}
|
||||||
//printf("\n");
|
//printf("\n");
|
||||||
|
Loading…
Reference in New Issue
Block a user