ggml : sync (im2col, GPU conv, 32-bit arm compat) (#4060)

ggml-ci
This commit is contained in:
Georgi Gerganov 2023-11-13 16:55:52 +02:00 committed by GitHub
parent c049b37d7b
commit 3d68f364f1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 586 additions and 1075 deletions

View File

@ -4489,6 +4489,13 @@ static __device__ void cpy_1_f32_f16(const char * cxi, char * cdsti) {
*dsti = __float2half(*xi); *dsti = __float2half(*xi);
} }
static __device__ void cpy_1_f16_f16(const char * cxi, char * cdsti) {
const half * xi = (const half *) cxi;
half * dsti = (half *) cdsti;
*dsti = *xi;
}
template <cpy_kernel_t cpy_1> template <cpy_kernel_t cpy_1>
static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne, static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int nb00, const int nb01, const int nb02, const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
@ -4742,6 +4749,25 @@ static __global__ void clamp_f32(const float * x, float * dst, const float min,
dst[i] = x[i] < min ? min : (x[i] > max ? max : x[i]); dst[i] = x[i] < min ? min : (x[i] > max ? max : x[i]);
} }
static __global__ void im2col_f32_f16(
const float * x, half * dst,
int ofs0, int ofs1, int IW, int IH, int CHW,
int s0, int s1, int p0, int p1, int d0, int d1) {
const int iiw = blockIdx.z * s0 + threadIdx.z * d0 - p0;
const int iih = blockIdx.y * s1 + threadIdx.y * d1 - p1;
const int offset_dst =
(threadIdx.x * gridDim.y * gridDim.z + blockIdx.y * gridDim.z + blockIdx.z) * CHW +
(blockIdx.x * (blockDim.y * blockDim.z) + threadIdx.y * blockDim.z + threadIdx.z);
if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
dst[offset_dst] = __float2half(0.0f);
} else {
const int offset_src = threadIdx.x * ofs0 + blockIdx.x * ofs1;
dst[offset_dst] = __float2half(x[offset_src + iih * IW + iiw]);
}
}
template<int qk, int qr, dequantize_kernel_t dq> template<int qk, int qr, dequantize_kernel_t dq>
static void get_rows_cuda(const void * x, const int32_t * y, float * dst, const int nrows, const int ncols, cudaStream_t stream) { static void get_rows_cuda(const void * x, const int32_t * y, float * dst, const int nrows, const int ncols, cudaStream_t stream) {
const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1); const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
@ -5642,6 +5668,16 @@ static void ggml_cpy_f32_f16_cuda(
(cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12); (cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12);
} }
static void ggml_cpy_f16_f16_cuda(
const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
const int ne10, const int ne11, const int nb10, const int nb11, const int nb12, cudaStream_t stream) {
const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
cpy_f32_f16<cpy_1_f16_f16><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
(cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12);
}
static void scale_f32_cuda(const float * x, float * dst, const float scale, const int k, cudaStream_t stream) { static void scale_f32_cuda(const float * x, float * dst, const float scale, const int k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_SCALE_BLOCK_SIZE - 1) / CUDA_SCALE_BLOCK_SIZE; const int num_blocks = (k + CUDA_SCALE_BLOCK_SIZE - 1) / CUDA_SCALE_BLOCK_SIZE;
scale_f32<<<num_blocks, CUDA_SCALE_BLOCK_SIZE, 0, stream>>>(x, dst, scale, k); scale_f32<<<num_blocks, CUDA_SCALE_BLOCK_SIZE, 0, stream>>>(x, dst, scale, k);
@ -5725,6 +5761,15 @@ static void soft_max_f32_cuda(const float * x, float * dst, const int ncols_x, c
soft_max_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x); soft_max_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x);
} }
static void im2col_f32_f16_cuda(const float * x, half * dst,
int OH, int IW, int IH, int OW, int IC,
int KH, int KW, int N, int ofs0, int ofs1,
int s0, int s1, int p0, int p1, int d0, int d1, cudaStream_t stream) {
dim3 block_nums(IC, OH, OW);
dim3 block_dims(N, KH, KW);
im2col_f32_f16<<<block_nums, block_dims, 0, stream>>>(x, dst, ofs0, ofs1, IW, IH, (IC * KH * KW), s0, s1, p0, p1, d0, d1);
}
// buffer pool for cuda // buffer pool for cuda
#define MAX_CUDA_BUFFERS 256 #define MAX_CUDA_BUFFERS 256
@ -6522,8 +6567,7 @@ inline void ggml_cuda_op_mul_mat_cublas(
src1_as_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &src1_as); src1_as_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &src1_as);
to_fp16_cuda(src1_ddf_i, src1_as_f16, ne, stream); to_fp16_cuda(src1_ddf_i, src1_as_f16, ne, stream);
} }
const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddq_i : src1_as_f16; const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddf_i : src1_as_f16;
size_t dst_as = 0; size_t dst_as = 0;
half * dst_f16 = (half *) ggml_cuda_pool_malloc(row_diff*src1_ncols * sizeof(half), &dst_as); half * dst_f16 = (half *) ggml_cuda_pool_malloc(row_diff*src1_ncols * sizeof(half), &dst_as);
@ -6698,6 +6742,45 @@ inline void ggml_cuda_op_alibi(
(void) src1_dd; (void) src1_dd;
} }
inline void ggml_cuda_op_im2col(
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
GGML_ASSERT(src0->type == GGML_TYPE_F16);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F16);
const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
const int32_t s1 = ((const int32_t*)(dst->op_params))[1];
const int32_t p0 = ((const int32_t*)(dst->op_params))[2];
const int32_t p1 = ((const int32_t*)(dst->op_params))[3];
const int32_t d0 = ((const int32_t*)(dst->op_params))[4];
const int32_t d1 = ((const int32_t*)(dst->op_params))[5];
const bool is_2D = ((const int32_t*)(dst->op_params))[6] == 1;
const int64_t N = src1->ne[is_2D ? 3 : 2];
const int64_t IC = src1->ne[is_2D ? 2 : 1];
const int64_t IH = is_2D ? src1->ne[1] : 1;
const int64_t IW = src1->ne[0];
const int64_t KH = is_2D ? src0->ne[1] : 1;
const int64_t KW = src0->ne[0];
const int64_t OH = is_2D ? dst->ne[2] : 1;
const int64_t OW = dst->ne[1];
const size_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4; // nb is byte offset, src is type float32
const size_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32
im2col_f32_f16_cuda(src1_dd, (half*) dst_dd,
OH, IW, IH, OW, IC, KH, KW, N,
ofs0, ofs1, s0, s1, p0, p1, d0, d1, main_stream);
(void) src0;
(void) src0_dd;
}
inline void ggml_cuda_op_diag_mask_inf( inline void ggml_cuda_op_diag_mask_inf(
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
@ -7610,6 +7693,9 @@ static void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, gg
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) { } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
ggml_cpy_f32_f16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, ggml_cpy_f32_f16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02,
ne10, ne11, nb10, nb11, nb12, main_stream); ne10, ne11, nb10, nb11, nb12, main_stream);
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
ggml_cpy_f16_f16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02,
ne10, ne11, nb10, nb11, nb12, main_stream);
} else { } else {
fprintf(stderr, "%s: unsupported type combination (%s to %s)\n", __func__, fprintf(stderr, "%s: unsupported type combination (%s to %s)\n", __func__,
ggml_type_name(src0->type), ggml_type_name(src1->type)); ggml_type_name(src0->type), ggml_type_name(src1->type));
@ -7641,6 +7727,10 @@ static void ggml_cuda_alibi(const ggml_tensor * src0, const ggml_tensor * src1,
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_alibi); ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_alibi);
} }
void ggml_cuda_im2col(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_im2col);
}
static void ggml_cuda_nop(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { static void ggml_cuda_nop(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
(void) src0; (void) src0;
(void) src1; (void) src1;
@ -7934,6 +8024,15 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
return false; return false;
} }
if (tensor->op == GGML_OP_MUL_MAT) {
if (tensor->src[0]->ne[3] != tensor->src[1]->ne[3]) {
#ifndef NDEBUG
fprintf(stderr, "%s: cannot compute %s: src0->ne[3] = %d, src1->ne[3] = %d - fallback to CPU\n", __func__, tensor->name, tensor->src[0]->ne[3], tensor->src[1]->ne[3]);
#endif
return false;
}
}
switch (tensor->op) { switch (tensor->op) {
case GGML_OP_REPEAT: case GGML_OP_REPEAT:
func = ggml_cuda_repeat; func = ggml_cuda_repeat;
@ -8012,6 +8111,9 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
case GGML_OP_ALIBI: case GGML_OP_ALIBI:
func = ggml_cuda_alibi; func = ggml_cuda_alibi;
break; break;
case GGML_OP_IM2COL:
func = ggml_cuda_im2col;
break;
default: default:
return false; return false;
} }

View File

@ -39,12 +39,6 @@ extern "C" {
#endif #endif
#endif #endif
#undef MIN
#undef MAX
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#define MAX(a, b) ((a) > (b) ? (a) : (b))
// 16-bit float // 16-bit float
// on Arm, we use __fp16 // on Arm, we use __fp16
// on x86, we use uint16_t // on x86, we use uint16_t

View File

@ -26,7 +26,7 @@
#include <stdbool.h> #include <stdbool.h>
// max memory buffers that can be mapped to the device // max memory buffers that can be mapped to the device
#define GGML_METAL_MAX_BUFFERS 16 #define GGML_METAL_MAX_BUFFERS 64
#define GGML_METAL_MAX_COMMAND_BUFFERS 32 #define GGML_METAL_MAX_COMMAND_BUFFERS 32
struct ggml_tensor; struct ggml_tensor;

View File

@ -86,6 +86,7 @@ struct ggml_metal_context {
GGML_METAL_DECL_KERNEL(rms_norm); GGML_METAL_DECL_KERNEL(rms_norm);
GGML_METAL_DECL_KERNEL(norm); GGML_METAL_DECL_KERNEL(norm);
GGML_METAL_DECL_KERNEL(mul_mv_f32_f32); GGML_METAL_DECL_KERNEL(mul_mv_f32_f32);
GGML_METAL_DECL_KERNEL(mul_mv_f16_f16);
GGML_METAL_DECL_KERNEL(mul_mv_f16_f32); GGML_METAL_DECL_KERNEL(mul_mv_f16_f32);
GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_1row); GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_1row);
GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_l4); GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_l4);
@ -114,6 +115,7 @@ struct ggml_metal_context {
GGML_METAL_DECL_KERNEL(rope_f32); GGML_METAL_DECL_KERNEL(rope_f32);
GGML_METAL_DECL_KERNEL(rope_f16); GGML_METAL_DECL_KERNEL(rope_f16);
GGML_METAL_DECL_KERNEL(alibi_f32); GGML_METAL_DECL_KERNEL(alibi_f32);
GGML_METAL_DECL_KERNEL(im2col_f16);
GGML_METAL_DECL_KERNEL(cpy_f32_f16); GGML_METAL_DECL_KERNEL(cpy_f32_f16);
GGML_METAL_DECL_KERNEL(cpy_f32_f32); GGML_METAL_DECL_KERNEL(cpy_f32_f32);
GGML_METAL_DECL_KERNEL(cpy_f16_f16); GGML_METAL_DECL_KERNEL(cpy_f16_f16);
@ -126,7 +128,7 @@ struct ggml_metal_context {
// MSL code // MSL code
// TODO: move the contents here when ready // TODO: move the contents here when ready
// for now it is easier to work in a separate file // for now it is easier to work in a separate file
static NSString * const msl_library_source = @"see metal.metal"; //static NSString * const msl_library_source = @"see metal.metal";
// Here to assist with NSBundle Path Hack // Here to assist with NSBundle Path Hack
@interface GGMLMetalClass : NSObject @interface GGMLMetalClass : NSObject
@ -142,6 +144,7 @@ void ggml_metal_log_set_callback(ggml_log_callback log_callback, void * user_dat
ggml_metal_log_user_data = user_data; ggml_metal_log_user_data = user_data;
} }
GGML_ATTRIBUTE_FORMAT(2, 3)
static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
if (ggml_metal_log_callback != NULL) { if (ggml_metal_log_callback != NULL) {
va_list args; va_list args;
@ -210,7 +213,13 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
} else { } else {
GGML_METAL_LOG_INFO("%s: default.metallib not found, loading from source\n", __func__); GGML_METAL_LOG_INFO("%s: default.metallib not found, loading from source\n", __func__);
NSString * sourcePath = [bundle pathForResource:@"ggml-metal" ofType:@"metal"]; NSString * sourcePath;
NSString * ggmlMetalPathResources = [[NSProcessInfo processInfo].environment objectForKey:@"GGML_METAL_PATH_RESOURCES"];
if (ggmlMetalPathResources) {
sourcePath = [ggmlMetalPathResources stringByAppendingPathComponent:@"ggml-metal.metal"];
} else {
sourcePath = [bundle pathForResource:@"ggml-metal" ofType:@"metal"];
}
if (sourcePath == nil) { if (sourcePath == nil) {
GGML_METAL_LOG_WARN("%s: error: could not use bundle path to find ggml-metal.metal, falling back to trying cwd\n", __func__); GGML_METAL_LOG_WARN("%s: error: could not use bundle path to find ggml-metal.metal, falling back to trying cwd\n", __func__);
sourcePath = @"ggml-metal.metal"; sourcePath = @"ggml-metal.metal";
@ -281,6 +290,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(rms_norm); GGML_METAL_ADD_KERNEL(rms_norm);
GGML_METAL_ADD_KERNEL(norm); GGML_METAL_ADD_KERNEL(norm);
GGML_METAL_ADD_KERNEL(mul_mv_f32_f32); GGML_METAL_ADD_KERNEL(mul_mv_f32_f32);
GGML_METAL_ADD_KERNEL(mul_mv_f16_f16);
GGML_METAL_ADD_KERNEL(mul_mv_f16_f32); GGML_METAL_ADD_KERNEL(mul_mv_f16_f32);
GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_1row); GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_1row);
GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_l4); GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_l4);
@ -311,6 +321,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(rope_f32); GGML_METAL_ADD_KERNEL(rope_f32);
GGML_METAL_ADD_KERNEL(rope_f16); GGML_METAL_ADD_KERNEL(rope_f16);
GGML_METAL_ADD_KERNEL(alibi_f32); GGML_METAL_ADD_KERNEL(alibi_f32);
GGML_METAL_ADD_KERNEL(im2col_f16);
GGML_METAL_ADD_KERNEL(cpy_f32_f16); GGML_METAL_ADD_KERNEL(cpy_f32_f16);
GGML_METAL_ADD_KERNEL(cpy_f32_f32); GGML_METAL_ADD_KERNEL(cpy_f32_f32);
GGML_METAL_ADD_KERNEL(cpy_f16_f16); GGML_METAL_ADD_KERNEL(cpy_f16_f16);
@ -329,7 +340,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
// https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf // https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
for (int i = MTLGPUFamilyApple1 + 20; i >= MTLGPUFamilyApple1; --i) { for (int i = MTLGPUFamilyApple1 + 20; i >= MTLGPUFamilyApple1; --i) {
if ([ctx->device supportsFamily:i]) { if ([ctx->device supportsFamily:i]) {
GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyApple%d (%d)\n", __func__, i - MTLGPUFamilyApple1 + 1, i); GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyApple%d (%d)\n", __func__, i - (int) MTLGPUFamilyApple1 + 1, i);
break; break;
} }
} }
@ -380,6 +391,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
GGML_METAL_DEL_KERNEL(rms_norm); GGML_METAL_DEL_KERNEL(rms_norm);
GGML_METAL_DEL_KERNEL(norm); GGML_METAL_DEL_KERNEL(norm);
GGML_METAL_DEL_KERNEL(mul_mv_f32_f32); GGML_METAL_DEL_KERNEL(mul_mv_f32_f32);
GGML_METAL_DEL_KERNEL(mul_mv_f16_f16);
GGML_METAL_DEL_KERNEL(mul_mv_f16_f32); GGML_METAL_DEL_KERNEL(mul_mv_f16_f32);
GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_1row); GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_1row);
GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_l4); GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_l4);
@ -410,6 +422,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
GGML_METAL_DEL_KERNEL(rope_f32); GGML_METAL_DEL_KERNEL(rope_f32);
GGML_METAL_DEL_KERNEL(rope_f16); GGML_METAL_DEL_KERNEL(rope_f16);
GGML_METAL_DEL_KERNEL(alibi_f32); GGML_METAL_DEL_KERNEL(alibi_f32);
GGML_METAL_DEL_KERNEL(im2col_f16);
GGML_METAL_DEL_KERNEL(cpy_f32_f16); GGML_METAL_DEL_KERNEL(cpy_f32_f16);
GGML_METAL_DEL_KERNEL(cpy_f32_f32); GGML_METAL_DEL_KERNEL(cpy_f32_f32);
GGML_METAL_DEL_KERNEL(cpy_f16_f16); GGML_METAL_DEL_KERNEL(cpy_f16_f16);
@ -467,6 +480,10 @@ static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_metal_context * ctx, stru
const int64_t tsize = ggml_nbytes(t); const int64_t tsize = ggml_nbytes(t);
if (t->buffer && t->buffer->backend && t->buffer->backend->context) {
ctx = t->buffer->backend->context;
}
// find the view that contains the tensor fully // find the view that contains the tensor fully
for (int i = 0; i < ctx->n_buffers; ++i) { for (int i = 0; i < ctx->n_buffers; ++i) {
const int64_t ioffs = (int64_t) t->data - (int64_t) ctx->buffers[i].data; const int64_t ioffs = (int64_t) t->data - (int64_t) ctx->buffers[i].data;
@ -567,7 +584,7 @@ bool ggml_metal_add_buffer(
ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0); ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
if (ctx->device.currentAllocatedSize > ctx->device.recommendedMaxWorkingSetSize) { if (ctx->device.currentAllocatedSize > ctx->device.recommendedMaxWorkingSetSize) {
GGML_METAL_LOG_WARN(", warning: current allocated size is greater than the recommended max working set size\n", __func__); GGML_METAL_LOG_WARN("%s: warning: current allocated size is greater than the recommended max working set size\n", __func__);
} else { } else {
GGML_METAL_LOG_INFO("\n"); GGML_METAL_LOG_INFO("\n");
} }
@ -1024,7 +1041,7 @@ void ggml_metal_graph_compute(
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3]; [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4]; [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
[encoder setThreadgroupMemoryLength:MAX(16, nth/32*sizeof(float)) atIndex:0]; [encoder setThreadgroupMemoryLength:GGML_PAD(nth/32*sizeof(float), 16) atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
} break; } break;
@ -1133,6 +1150,7 @@ void ggml_metal_graph_compute(
switch (src0t) { switch (src0t) {
case GGML_TYPE_F32: case GGML_TYPE_F32:
{ {
GGML_ASSERT(src1t == GGML_TYPE_F32);
[encoder setComputePipelineState:ctx->pipeline_mul_mv_f32_f32]; [encoder setComputePipelineState:ctx->pipeline_mul_mv_f32_f32];
nrows = 4; nrows = 4;
} break; } break;
@ -1140,6 +1158,7 @@ void ggml_metal_graph_compute(
{ {
nth0 = 32; nth0 = 32;
nth1 = 1; nth1 = 1;
if (src1t == GGML_TYPE_F32) {
if (ne11 * ne12 < 4) { if (ne11 * ne12 < 4) {
[encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_1row]; [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_1row];
} else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) { } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
@ -1149,6 +1168,10 @@ void ggml_metal_graph_compute(
[encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32]; [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32];
nrows = 4; nrows = 4;
} }
} else {
[encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f16];
nrows = 4;
}
} break; } break;
case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_0:
{ {
@ -1336,7 +1359,7 @@ void ggml_metal_graph_compute(
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3]; [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
[encoder setBytes:&eps length:sizeof( float) atIndex:4]; [encoder setBytes:&eps length:sizeof( float) atIndex:4];
[encoder setThreadgroupMemoryLength:nth/32*sizeof(float) atIndex:0]; [encoder setThreadgroupMemoryLength:GGML_PAD(nth/32*sizeof(float), 16) atIndex:0];
const int64_t nrows = ggml_nrows(src0); const int64_t nrows = ggml_nrows(src0);
@ -1355,7 +1378,7 @@ void ggml_metal_graph_compute(
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3]; [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
[encoder setBytes:&eps length:sizeof( float) atIndex:4]; [encoder setBytes:&eps length:sizeof( float) atIndex:4];
[encoder setThreadgroupMemoryLength:MAX(16, nth*sizeof(float)) atIndex:0]; [encoder setThreadgroupMemoryLength:GGML_PAD(nth*sizeof(float), 16) atIndex:0];
const int64_t nrows = ggml_nrows(src0); const int64_t nrows = ggml_nrows(src0);
@ -1410,8 +1433,7 @@ void ggml_metal_graph_compute(
const int n_past = ((int32_t *) dst->op_params)[0]; const int n_past = ((int32_t *) dst->op_params)[0];
const int n_dims = ((int32_t *) dst->op_params)[1]; const int n_dims = ((int32_t *) dst->op_params)[1];
const int mode = ((int32_t *) dst->op_params)[2]; const int mode = ((int32_t *) dst->op_params)[2];
// skip 3, n_ctx, used in GLM RoPE, unimplemented in metal const int n_orig_ctx = ((int32_t *) dst->op_params)[3];
const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow; float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float)); memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
@ -1459,6 +1481,58 @@ void ggml_metal_graph_compute(
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
} break; } break;
case GGML_OP_IM2COL:
{
GGML_ASSERT(src0->type == GGML_TYPE_F16);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F16);
const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
const int32_t N = src1->ne[is_2D ? 3 : 2];
const int32_t IC = src1->ne[is_2D ? 2 : 1];
const int32_t IH = is_2D ? src1->ne[1] : 1;
const int32_t IW = src1->ne[0];
const int32_t KH = is_2D ? src0->ne[1] : 1;
const int32_t KW = src0->ne[0];
const int32_t OH = is_2D ? dst->ne[2] : 1;
const int32_t OW = dst->ne[1];
const int32_t CHW = IC * KH * KW;
const int32_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4;
const int32_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4;
switch (src0->type) {
case GGML_TYPE_F32: GGML_ASSERT(false && "not implemented"); break;
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_im2col_f16]; break;
default: GGML_ASSERT(false);
};
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&ofs0 length:sizeof( int32_t) atIndex:2];
[encoder setBytes:&ofs1 length:sizeof( int32_t) atIndex:3];
[encoder setBytes:&IW length:sizeof( int32_t) atIndex:4];
[encoder setBytes:&IH length:sizeof( int32_t) atIndex:5];
[encoder setBytes:&CHW length:sizeof( int32_t) atIndex:6];
[encoder setBytes:&s0 length:sizeof( int32_t) atIndex:7];
[encoder setBytes:&s1 length:sizeof( int32_t) atIndex:8];
[encoder setBytes:&p0 length:sizeof( int32_t) atIndex:9];
[encoder setBytes:&p1 length:sizeof( int32_t) atIndex:10];
[encoder setBytes:&d0 length:sizeof( int32_t) atIndex:11];
[encoder setBytes:&d1 length:sizeof( int32_t) atIndex:12];
[encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
} break;
case GGML_OP_DUP: case GGML_OP_DUP:
case GGML_OP_CPY: case GGML_OP_CPY:
case GGML_OP_CONT: case GGML_OP_CONT:

View File

@ -844,6 +844,79 @@ kernel void kernel_mul_mv_f32_f32(
} }
} }
#define N_F16_F16 4
kernel void kernel_mul_mv_f16_f16(
device const char * src0,
device const char * src1,
device float * dst,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
constant uint64_t & nb00,
constant uint64_t & nb01,
constant uint64_t & nb02,
constant int64_t & ne10,
constant int64_t & ne11,
constant int64_t & ne12,
constant uint64_t & nb10,
constant uint64_t & nb11,
constant uint64_t & nb12,
constant int64_t & ne0,
constant int64_t & ne1,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]]) {
const int64_t r0 = tgpig.x;
const int64_t rb = tgpig.y*N_F16_F16;
const int64_t im = tgpig.z;
device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
if (ne00 < 128) {
for (int row = 0; row < N_F16_F16; ++row) {
int r1 = rb + row;
if (r1 >= ne11) {
break;
}
device const half * y = (device const half *) (src1 + r1*nb11 + im*nb12);
float sumf = 0;
for (int i = tiisg; i < ne00; i += 32) {
sumf += (half) x[i] * (half) y[i];
}
float all_sum = simd_sum(sumf);
if (tiisg == 0) {
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
}
}
} else {
device const half4 * x4 = (device const half4 *)x;
for (int row = 0; row < N_F16_F16; ++row) {
int r1 = rb + row;
if (r1 >= ne11) {
break;
}
device const half * y = (device const half *) (src1 + r1*nb11 + im*nb12);
device const half4 * y4 = (device const half4 *) y;
float sumf = 0;
for (int i = tiisg; i < ne00/4; i += 32) {
for (int k = 0; k < 4; ++k) sumf += (half) x4[i][k] * y4[i][k];
}
float all_sum = simd_sum(sumf);
if (tiisg == 0) {
for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (half) x[i] * y[i];
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
}
}
}
}
kernel void kernel_mul_mv_f16_f32_1row( kernel void kernel_mul_mv_f16_f32_1row(
device const char * src0, device const char * src0,
device const char * src1, device const char * src1,
@ -1229,6 +1302,39 @@ kernel void kernel_rope(
template [[host_name("kernel_rope_f32")]] kernel rope_t kernel_rope<float>; template [[host_name("kernel_rope_f32")]] kernel rope_t kernel_rope<float>;
template [[host_name("kernel_rope_f16")]] kernel rope_t kernel_rope<half>; template [[host_name("kernel_rope_f16")]] kernel rope_t kernel_rope<half>;
kernel void kernel_im2col_f16(
device const float * x,
device half * dst,
constant int32_t & ofs0,
constant int32_t & ofs1,
constant int32_t & IW,
constant int32_t & IH,
constant int32_t & CHW,
constant int32_t & s0,
constant int32_t & s1,
constant int32_t & p0,
constant int32_t & p1,
constant int32_t & d0,
constant int32_t & d1,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tgpg[[threadgroups_per_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
const int32_t iiw = tgpig[2] * s0 + tpitg[2] * d0 - p0;
const int32_t iih = tgpig[1] * s1 + tpitg[1] * d1 - p1;
const int32_t offset_dst =
(tpitg[0] * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * CHW +
(tgpig[0] * (ntg[1] * ntg[2]) + tpitg[1] * ntg[2] + tpitg[2]);
if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
dst[offset_dst] = 0.0f;
} else {
const int32_t offset_src = tpitg[0] * ofs0 + tgpig[0] * ofs1;
dst[offset_dst] = x[offset_src + iih * IW + iiw];
}
}
kernel void kernel_cpy_f16_f16( kernel void kernel_cpy_f16_f16(
device const half * src0, device const half * src0,
device half * dst, device half * dst,

View File

@ -14,26 +14,6 @@
// //
#include <arm_neon.h> #include <arm_neon.h>
#if !defined(__aarch64__)
inline static int32_t vaddvq_s16(int16x8_t v) {
return
(int32_t)vgetq_lane_s16(v, 0) + (int32_t)vgetq_lane_s16(v, 1) +
(int32_t)vgetq_lane_s16(v, 2) + (int32_t)vgetq_lane_s16(v, 3) +
(int32_t)vgetq_lane_s16(v, 4) + (int32_t)vgetq_lane_s16(v, 5) +
(int32_t)vgetq_lane_s16(v, 6) + (int32_t)vgetq_lane_s16(v, 7);
}
inline static int16x8_t vpaddq_s16(int16x8_t a, int16x8_t b) {
int16x4_t a0 = vpadd_s16(vget_low_s16(a), vget_high_s16(a));
int16x4_t b0 = vpadd_s16(vget_low_s16(b), vget_high_s16(b));
return vcombine_s16(a0, b0);
}
inline static int32_t vaddvq_s32(int32x4_t v) {
return vgetq_lane_s32(v, 0) + vgetq_lane_s32(v, 1) + vgetq_lane_s32(v, 2) + vgetq_lane_s32(v, 3);
}
#endif
#else #else
#ifdef __wasm_simd128__ #ifdef __wasm_simd128__
@ -47,13 +27,15 @@ inline static int32_t vaddvq_s32(int32x4_t v) {
#if defined(_MSC_VER) || defined(__MINGW32__) #if defined(_MSC_VER) || defined(__MINGW32__)
#include <intrin.h> #include <intrin.h>
#else #else
#if !defined(__riscv) && !defined(__s390__) #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) || defined(__SSE3__)
#if !defined(__riscv)
#include <immintrin.h> #include <immintrin.h>
#endif #endif
#endif #endif
#endif #endif
#endif #endif
#endif #endif
#endif
#ifdef __riscv_v_intrinsic #ifdef __riscv_v_intrinsic
#include <riscv_vector.h> #include <riscv_vector.h>
@ -61,6 +43,7 @@ inline static int32_t vaddvq_s32(int32x4_t v) {
#undef MIN #undef MIN
#undef MAX #undef MAX
#define MIN(a, b) ((a) < (b) ? (a) : (b)) #define MIN(a, b) ((a) < (b) ? (a) : (b))
#define MAX(a, b) ((a) > (b) ? (a) : (b)) #define MAX(a, b) ((a) > (b) ? (a) : (b))
@ -283,9 +266,31 @@ static inline float hsum_float_4x4(const __m128 a, const __m128 b, const __m128
#endif // defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) #endif // defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__)
#if defined(__ARM_NEON) #if defined(__ARM_NEON)
#if !defined(__aarch64__) #if !defined(__aarch64__)
// 64-bit compatibility
// vaddvq_s16
// vpaddq_s16
// vaddvq_s32
// vaddvq_f32
// vmaxvq_f32
// vcvtnq_s32_f32
inline static int32_t vaddvq_s16(int16x8_t v) {
return
(int32_t)vgetq_lane_s16(v, 0) + (int32_t)vgetq_lane_s16(v, 1) +
(int32_t)vgetq_lane_s16(v, 2) + (int32_t)vgetq_lane_s16(v, 3) +
(int32_t)vgetq_lane_s16(v, 4) + (int32_t)vgetq_lane_s16(v, 5) +
(int32_t)vgetq_lane_s16(v, 6) + (int32_t)vgetq_lane_s16(v, 7);
}
inline static int16x8_t vpaddq_s16(int16x8_t a, int16x8_t b) {
int16x4_t a0 = vpadd_s16(vget_low_s16(a), vget_high_s16(a));
int16x4_t b0 = vpadd_s16(vget_low_s16(b), vget_high_s16(b));
return vcombine_s16(a0, b0);
}
inline static int32_t vaddvq_s32(int32x4_t v) { inline static int32_t vaddvq_s32(int32x4_t v) {
return vgetq_lane_s32(v, 0) + vgetq_lane_s32(v, 1) + vgetq_lane_s32(v, 2) + vgetq_lane_s32(v, 3); return vgetq_lane_s32(v, 0) + vgetq_lane_s32(v, 1) + vgetq_lane_s32(v, 2) + vgetq_lane_s32(v, 3);
} }
@ -311,6 +316,96 @@ inline static int32x4_t vcvtnq_s32_f32(float32x4_t v) {
return res; return res;
} }
// vld1q_s16_x2
// vld1q_u8_x2
// vld1q_u8_x4
// vld1q_s8_x2
// vld1q_s8_x4
// TODO: double-check these work correctly
typedef struct ggml_int16x8x2_t {
int16x8_t val[2];
} ggml_int16x8x2_t;
inline static ggml_int16x8x2_t ggml_vld1q_s16_x2(const int16_t * ptr) {
ggml_int16x8x2_t res;
res.val[0] = vld1q_s16(ptr + 0);
res.val[1] = vld1q_s16(ptr + 8);
return res;
}
typedef struct ggml_uint8x16x2_t {
uint8x16_t val[2];
} ggml_uint8x16x2_t;
inline static ggml_uint8x16x2_t ggml_vld1q_u8_x2(const uint8_t * ptr) {
ggml_uint8x16x2_t res;
res.val[0] = vld1q_u8(ptr + 0);
res.val[1] = vld1q_u8(ptr + 16);
return res;
}
typedef struct ggml_uint8x16x4_t {
uint8x16_t val[4];
} ggml_uint8x16x4_t;
inline static ggml_uint8x16x4_t ggml_vld1q_u8_x4(const uint8_t * ptr) {
ggml_uint8x16x4_t res;
res.val[0] = vld1q_u8(ptr + 0);
res.val[1] = vld1q_u8(ptr + 16);
res.val[2] = vld1q_u8(ptr + 32);
res.val[3] = vld1q_u8(ptr + 48);
return res;
}
typedef struct ggml_int8x16x2_t {
int8x16_t val[2];
} ggml_int8x16x2_t;
inline static ggml_int8x16x2_t ggml_vld1q_s8_x2(const int8_t * ptr) {
ggml_int8x16x2_t res;
res.val[0] = vld1q_s8(ptr + 0);
res.val[1] = vld1q_s8(ptr + 16);
return res;
}
typedef struct ggml_int8x16x4_t {
int8x16_t val[4];
} ggml_int8x16x4_t;
inline static ggml_int8x16x4_t ggml_vld1q_s8_x4(const int8_t * ptr) {
ggml_int8x16x4_t res;
res.val[0] = vld1q_s8(ptr + 0);
res.val[1] = vld1q_s8(ptr + 16);
res.val[2] = vld1q_s8(ptr + 32);
res.val[3] = vld1q_s8(ptr + 48);
return res;
}
#else
#define ggml_int16x8x2_t int16x8x2_t
#define ggml_uint8x16x2_t uint8x16x2_t
#define ggml_uint8x16x4_t uint8x16x4_t
#define ggml_int8x16x2_t int8x16x2_t
#define ggml_int8x16x4_t int8x16x4_t
#define ggml_vld1q_s16_x2 vld1q_s16_x2
#define ggml_vld1q_u8_x2 vld1q_u8_x2
#define ggml_vld1q_u8_x4 vld1q_u8_x4
#define ggml_vld1q_s8_x2 vld1q_s8_x2
#define ggml_vld1q_s8_x4 vld1q_s8_x4
#endif #endif
#endif #endif
@ -3557,7 +3652,7 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
const int32x4_t vzero = vdupq_n_s32(0); const int32x4_t vzero = vdupq_n_s32(0);
#endif #endif
int8x16x2_t q2bytes; ggml_int8x16x2_t q2bytes;
uint8_t aux[16]; uint8_t aux[16];
float sum = 0; float sum = 0;
@ -3576,8 +3671,8 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
vst1q_u8(aux, scales); vst1q_u8(aux, scales);
const uint8x16_t mins = vshrq_n_u8(mins_and_scales, 4); const uint8x16_t mins = vshrq_n_u8(mins_and_scales, 4);
const int16x8x2_t q8sums = vld1q_s16_x2(y[i].bsums); const ggml_int16x8x2_t q8sums = ggml_vld1q_s16_x2(y[i].bsums);
const int16x8x2_t mins16 = {vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(mins))), vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(mins)))}; const ggml_int16x8x2_t mins16 = {vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(mins))), vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(mins)))};
const int32x4_t s0 = vaddq_s32(vmull_s16(vget_low_s16 (mins16.val[0]), vget_low_s16 (q8sums.val[0])), const int32x4_t s0 = vaddq_s32(vmull_s16(vget_low_s16 (mins16.val[0]), vget_low_s16 (q8sums.val[0])),
vmull_s16(vget_high_s16(mins16.val[0]), vget_high_s16(q8sums.val[0]))); vmull_s16(vget_high_s16(mins16.val[0]), vget_high_s16(q8sums.val[0])));
const int32x4_t s1 = vaddq_s32(vmull_s16(vget_low_s16 (mins16.val[1]), vget_low_s16 (q8sums.val[1])), const int32x4_t s1 = vaddq_s32(vmull_s16(vget_low_s16 (mins16.val[1]), vget_low_s16 (q8sums.val[1])),
@ -3605,7 +3700,7 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
#endif #endif
#define SHIFT_MULTIPLY_ACCUM_WITH_SCALE(shift, index)\ #define SHIFT_MULTIPLY_ACCUM_WITH_SCALE(shift, index)\
q8bytes = vld1q_s8_x2(q8); q8 += 32;\ q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;\
q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.val[0], (shift)), m3));\ q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.val[0], (shift)), m3));\
q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.val[1], (shift)), m3));\ q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.val[1], (shift)), m3));\
MULTIPLY_ACCUM_WITH_SCALE((index)); MULTIPLY_ACCUM_WITH_SCALE((index));
@ -3613,9 +3708,9 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
for (int j = 0; j < QK_K/128; ++j) { for (int j = 0; j < QK_K/128; ++j) {
const uint8x16x2_t q2bits = vld1q_u8_x2(q2); q2 += 32; const ggml_uint8x16x2_t q2bits = ggml_vld1q_u8_x2(q2); q2 += 32;
int8x16x2_t q8bytes = vld1q_s8_x2(q8); q8 += 32; ggml_int8x16x2_t q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;
q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(q2bits.val[0], m3)); q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(q2bits.val[0], m3));
q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(q2bits.val[1], m3)); q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(q2bits.val[1], m3));
MULTIPLY_ACCUM_WITH_SCALE(0); MULTIPLY_ACCUM_WITH_SCALE(0);
@ -3949,7 +4044,7 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
const int32x4_t vzero = vdupq_n_s32(0); const int32x4_t vzero = vdupq_n_s32(0);
#endif #endif
int8x16x4_t q2bytes; ggml_int8x16x4_t q2bytes;
uint32_t aux32[2]; uint32_t aux32[2];
const uint8_t * scales = (const uint8_t *)aux32; const uint8_t * scales = (const uint8_t *)aux32;
@ -3974,7 +4069,7 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
const uint8x16_t q2bits = vld1q_u8(q2); const uint8x16_t q2bits = vld1q_u8(q2);
const int8x16x4_t q8bytes = vld1q_s8_x4(q8); const ggml_int8x16x4_t q8bytes = ggml_vld1q_s8_x4(q8);
q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(q2bits, m3)); q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(q2bits, m3));
q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits, 2), m3)); q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits, 2), m3));
@ -4238,7 +4333,7 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
const uint8x16_t m3 = vshlq_n_u8(m0, 3); const uint8x16_t m3 = vshlq_n_u8(m0, 3);
const int8_t m32 = 32; const int8_t m32 = 32;
int8x16x4_t q3bytes; ggml_int8x16x4_t q3bytes;
float sum = 0; float sum = 0;
@ -4250,9 +4345,9 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
const uint8_t * restrict qh = x[i].hmask; const uint8_t * restrict qh = x[i].hmask;
const int8_t * restrict q8 = y[i].qs; const int8_t * restrict q8 = y[i].qs;
uint8x16x2_t qhbits = vld1q_u8_x2(qh); ggml_uint8x16x2_t qhbits = ggml_vld1q_u8_x2(qh);
uint8x16x4_t q3h; ggml_uint8x16x4_t q3h;
int32_t isum = 0; int32_t isum = 0;
@ -4268,9 +4363,9 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
for (int j = 0; j < QK_K/128; ++j) { for (int j = 0; j < QK_K/128; ++j) {
const uint8x16x2_t q3bits = vld1q_u8_x2(q3); q3 += 32; const ggml_uint8x16x2_t q3bits = ggml_vld1q_u8_x2(q3); q3 += 32;
const int8x16x4_t q8bytes_1 = vld1q_s8_x4(q8); q8 += 64; const ggml_int8x16x4_t q8bytes_1 = ggml_vld1q_s8_x4(q8); q8 += 64;
const int8x16x4_t q8bytes_2 = vld1q_s8_x4(q8); q8 += 64; const ggml_int8x16x4_t q8bytes_2 = ggml_vld1q_s8_x4(q8); q8 += 64;
q3h.val[0] = vshlq_n_u8(vbicq_u8(m0, qhbits.val[0]), 2); q3h.val[0] = vshlq_n_u8(vbicq_u8(m0, qhbits.val[0]), 2);
q3h.val[1] = vshlq_n_u8(vbicq_u8(m0, qhbits.val[1]), 2); q3h.val[1] = vshlq_n_u8(vbicq_u8(m0, qhbits.val[1]), 2);
@ -4772,7 +4867,7 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
const uint8x16_t m3b = vdupq_n_u8(0x3); const uint8x16_t m3b = vdupq_n_u8(0x3);
const uint8x16_t mh = vdupq_n_u8(4); const uint8x16_t mh = vdupq_n_u8(4);
int8x16x4_t q3bytes; ggml_int8x16x4_t q3bytes;
uint16_t aux16[2]; uint16_t aux16[2];
int8_t * scales = (int8_t *)aux16; int8_t * scales = (int8_t *)aux16;
@ -4781,11 +4876,11 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
for (int i = 0; i < nb; ++i) { for (int i = 0; i < nb; ++i) {
uint8x16x4_t q3h; ggml_uint8x16x4_t q3h;
const uint8x8_t hbits = vld1_u8(x[i].hmask); const uint8x8_t hbits = vld1_u8(x[i].hmask);
const uint8x16_t q3bits = vld1q_u8(x[i].qs); const uint8x16_t q3bits = vld1q_u8(x[i].qs);
const int8x16x4_t q8bytes = vld1q_s8_x4(y[i].qs); const ggml_int8x16x4_t q8bytes = ggml_vld1q_s8_x4(y[i].qs);
const uint16_t a = *(const uint16_t *)x[i].scales; const uint16_t a = *(const uint16_t *)x[i].scales;
aux16[0] = a & 0x0f0f; aux16[0] = a & 0x0f0f;
@ -5134,8 +5229,8 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
const int32x4_t mzero = vdupq_n_s32(0); const int32x4_t mzero = vdupq_n_s32(0);
#endif #endif
int8x16x2_t q4bytes; ggml_int8x16x2_t q4bytes;
int8x16x2_t q8bytes; ggml_int8x16x2_t q8bytes;
float sumf = 0; float sumf = 0;
@ -5170,17 +5265,17 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
for (int j = 0; j < QK_K/64; ++j) { for (int j = 0; j < QK_K/64; ++j) {
const uint8x16x2_t q4bits = vld1q_u8_x2(q4); q4 += 32; const ggml_uint8x16x2_t q4bits = ggml_vld1q_u8_x2(q4); q4 += 32;
#ifdef __ARM_FEATURE_DOTPROD #ifdef __ARM_FEATURE_DOTPROD
q8bytes = vld1q_s8_x2(q8); q8 += 32; q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;
q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b)); q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b));
q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b)); q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b));
const int32x4_t p1 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]); const int32x4_t p1 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);
sumi1 += vaddvq_s32(p1) * scales[2*j+0]; sumi1 += vaddvq_s32(p1) * scales[2*j+0];
q8bytes = vld1q_s8_x2(q8); q8 += 32; q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;
q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4)); q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4));
q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4)); q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4));
@ -5188,7 +5283,7 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
sumi2 += vaddvq_s32(p2) * scales[2*j+1]; sumi2 += vaddvq_s32(p2) * scales[2*j+1];
#else #else
q8bytes = vld1q_s8_x2(q8); q8 += 32; q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;
q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b)); q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b));
q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b)); q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b));
const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[0])), const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
@ -5197,7 +5292,7 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
vmull_s8(vget_high_s8(q4bytes.val[1]), vget_high_s8(q8bytes.val[1]))); vmull_s8(vget_high_s8(q4bytes.val[1]), vget_high_s8(q8bytes.val[1])));
sumi1 += vaddvq_s16(vaddq_s16(p0, p1)) * scales[2*j+0]; sumi1 += vaddvq_s16(vaddq_s16(p0, p1)) * scales[2*j+0];
q8bytes = vld1q_s8_x2(q8); q8 += 32; q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;
q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4)); q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4));
q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4)); q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4));
const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[0])), const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
@ -5512,8 +5607,8 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
float sumf = 0; float sumf = 0;
int8x16x2_t q4bytes; ggml_int8x16x2_t q4bytes;
int8x16x4_t q8bytes; ggml_int8x16x4_t q8bytes;
float sum_mins = 0.f; float sum_mins = 0.f;
@ -5534,10 +5629,10 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
const float d = y[i].d * (float)x[i].d[0]; const float d = y[i].d * (float)x[i].d[0];
const uint8x16x2_t q4bits = vld1q_u8_x2(q4); const ggml_uint8x16x2_t q4bits = ggml_vld1q_u8_x2(q4);
#ifdef __ARM_FEATURE_DOTPROD #ifdef __ARM_FEATURE_DOTPROD
q8bytes = vld1q_s8_x4(q8); q8bytes = ggml_vld1q_s8_x4(q8);
q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b)); q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b));
q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b)); q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b));
@ -5551,7 +5646,7 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
const int32_t sumi2 = vaddvq_s32(p2) * scales[1]; const int32_t sumi2 = vaddvq_s32(p2) * scales[1];
#else #else
q8bytes = vld1q_s8_x4(q8); q8bytes = ggml_vld1q_s8_x4(q8);
q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b)); q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b));
q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b)); q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b));
const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[0])), const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
@ -5785,7 +5880,7 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
const int32x4_t mzero = vdupq_n_s32(0); const int32x4_t mzero = vdupq_n_s32(0);
#endif #endif
int8x16x4_t q5bytes; ggml_int8x16x4_t q5bytes;
float sumf = 0; float sumf = 0;
@ -5815,16 +5910,16 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
const uint8_t * restrict qh = x[i].qh; const uint8_t * restrict qh = x[i].qh;
const int8_t * restrict q8 = y[i].qs; const int8_t * restrict q8 = y[i].qs;
uint8x16x2_t qhbits = vld1q_u8_x2(qh); ggml_uint8x16x2_t qhbits = ggml_vld1q_u8_x2(qh);
uint8x16x4_t q5h; ggml_uint8x16x4_t q5h;
int32_t sumi = 0; int32_t sumi = 0;
for (int j = 0; j < QK_K/64; ++j) { for (int j = 0; j < QK_K/64; ++j) {
const uint8x16x2_t q5bits = vld1q_u8_x2(q5); q5 += 32; const ggml_uint8x16x2_t q5bits = ggml_vld1q_u8_x2(q5); q5 += 32;
const int8x16x4_t q8bytes = vld1q_s8_x4(q8); q8 += 64; const ggml_int8x16x4_t q8bytes = ggml_vld1q_s8_x4(q8); q8 += 64;
q5h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits.val[0]), 4); q5h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits.val[0]), 4);
q5h.val[1] = vshlq_n_u8(vandq_u8(mone, qhbits.val[1]), 4); q5h.val[1] = vshlq_n_u8(vandq_u8(mone, qhbits.val[1]), 4);
@ -6218,8 +6313,8 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
const int32x4_t mzero = vdupq_n_s32(0); const int32x4_t mzero = vdupq_n_s32(0);
#endif #endif
int8x16x4_t q5bytes; ggml_int8x16x4_t q5bytes;
uint8x16x4_t q5h; ggml_uint8x16x4_t q5h;
float sumf = 0; float sumf = 0;
@ -6234,8 +6329,8 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
const uint8x8_t qhbits = vld1_u8(qh); const uint8x8_t qhbits = vld1_u8(qh);
const uint8x16x2_t q5bits = vld1q_u8_x2(q5); const ggml_uint8x16x2_t q5bits = ggml_vld1q_u8_x2(q5);
const int8x16x4_t q8bytes = vld1q_s8_x4(q8); const ggml_int8x16x4_t q8bytes = ggml_vld1q_s8_x4(q8);
const uint8x16_t htmp = vcombine_u8(qhbits, vshr_n_u8(qhbits, 1)); const uint8x16_t htmp = vcombine_u8(qhbits, vshr_n_u8(qhbits, 1));
q5h.val[0] = vbicq_u8(mh, vshlq_n_u8(htmp, 4)); q5h.val[0] = vbicq_u8(mh, vshlq_n_u8(htmp, 4));
@ -6511,8 +6606,8 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
const uint8x16_t mone = vdupq_n_u8(3); const uint8x16_t mone = vdupq_n_u8(3);
int8x16x4_t q6bytes; ggml_int8x16x4_t q6bytes;
uint8x16x4_t q6h; ggml_uint8x16x4_t q6h;
for (int i = 0; i < nb; ++i) { for (int i = 0; i < nb; ++i) {
@ -6524,9 +6619,9 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
const int8_t * restrict scale = x[i].scales; const int8_t * restrict scale = x[i].scales;
const int16x8x2_t q8sums = vld1q_s16_x2(y[i].bsums); const ggml_int16x8x2_t q8sums = ggml_vld1q_s16_x2(y[i].bsums);
const int8x16_t scales = vld1q_s8(scale); const int8x16_t scales = vld1q_s8(scale);
const int16x8x2_t q6scales = {vmovl_s8(vget_low_s8(scales)), vmovl_s8(vget_high_s8(scales))}; const ggml_int16x8x2_t q6scales = {vmovl_s8(vget_low_s8(scales)), vmovl_s8(vget_high_s8(scales))};
const int32x4_t prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums.val[0]), vget_low_s16 (q6scales.val[0])), const int32x4_t prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums.val[0]), vget_low_s16 (q6scales.val[0])),
vmull_s16(vget_high_s16(q8sums.val[0]), vget_high_s16(q6scales.val[0]))), vmull_s16(vget_high_s16(q8sums.val[0]), vget_high_s16(q6scales.val[0]))),
@ -6538,9 +6633,9 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
for (int j = 0; j < QK_K/128; ++j) { for (int j = 0; j < QK_K/128; ++j) {
uint8x16x2_t qhbits = vld1q_u8_x2(qh); qh += 32; ggml_uint8x16x2_t qhbits = ggml_vld1q_u8_x2(qh); qh += 32;
uint8x16x4_t q6bits = vld1q_u8_x4(q6); q6 += 64; ggml_uint8x16x4_t q6bits = ggml_vld1q_u8_x4(q6); q6 += 64;
int8x16x4_t q8bytes = vld1q_s8_x4(q8); q8 += 64; ggml_int8x16x4_t q8bytes = ggml_vld1q_s8_x4(q8); q8 += 64;
q6h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits.val[0]), 4); q6h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits.val[0]), 4);
q6h.val[1] = vshlq_n_u8(vandq_u8(mone, qhbits.val[1]), 4); q6h.val[1] = vshlq_n_u8(vandq_u8(mone, qhbits.val[1]), 4);
@ -6583,7 +6678,7 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
scale += 2; scale += 2;
#endif #endif
q8bytes = vld1q_s8_x4(q8); q8 += 64; q8bytes = ggml_vld1q_s8_x4(q8); q8 += 64;
shifted = vshrq_n_u8(qhbits.val[0], 4); shifted = vshrq_n_u8(qhbits.val[0], 4);
q6h.val[0] = vshlq_n_u8(vandq_u8(mone, shifted), 4); q6h.val[0] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
@ -6987,8 +7082,8 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
const uint8x16_t mone = vdupq_n_u8(3); const uint8x16_t mone = vdupq_n_u8(3);
int8x16x4_t q6bytes; ggml_int8x16x4_t q6bytes;
uint8x16x4_t q6h; ggml_uint8x16x4_t q6h;
for (int i = 0; i < nb; ++i) { for (int i = 0; i < nb; ++i) {
@ -7003,8 +7098,8 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
int32_t isum = 0; int32_t isum = 0;
uint8x16_t qhbits = vld1q_u8(qh); uint8x16_t qhbits = vld1q_u8(qh);
uint8x16x2_t q6bits = vld1q_u8_x2(q6); ggml_uint8x16x2_t q6bits = ggml_vld1q_u8_x2(q6);
int8x16x4_t q8bytes = vld1q_s8_x4(q8); ggml_int8x16x4_t q8bytes = ggml_vld1q_s8_x4(q8);
q6h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits), 4); q6h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits), 4);
uint8x16_t shifted = vshrq_n_u8(qhbits, 2); uint8x16_t shifted = vshrq_n_u8(qhbits, 2);

1045
ggml.c

File diff suppressed because it is too large Load Diff

19
ggml.h
View File

@ -403,13 +403,8 @@ extern "C" {
GGML_OP_ROPE_BACK, GGML_OP_ROPE_BACK,
GGML_OP_ALIBI, GGML_OP_ALIBI,
GGML_OP_CLAMP, GGML_OP_CLAMP,
GGML_OP_CONV_1D,
GGML_OP_CONV_1D_STAGE_0, // internal
GGML_OP_CONV_1D_STAGE_1, // internal
GGML_OP_CONV_TRANSPOSE_1D, GGML_OP_CONV_TRANSPOSE_1D,
GGML_OP_CONV_2D, GGML_OP_IM2COL,
GGML_OP_CONV_2D_STAGE_0, // internal
GGML_OP_CONV_2D_STAGE_1, // internal
GGML_OP_CONV_TRANSPOSE_2D, GGML_OP_CONV_TRANSPOSE_2D,
GGML_OP_POOL_1D, GGML_OP_POOL_1D,
GGML_OP_POOL_2D, GGML_OP_POOL_2D,
@ -1403,6 +1398,18 @@ extern "C" {
float min, float min,
float max); float max);
GGML_API struct ggml_tensor * ggml_im2col(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
int s0,
int s1,
int p0,
int p1,
int d0,
int d1,
bool is_2D);
GGML_API struct ggml_tensor * ggml_conv_1d( GGML_API struct ggml_tensor * ggml_conv_1d(
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * a, struct ggml_tensor * a,