mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-11-14 06:49:54 +00:00
parent
c049b37d7b
commit
3d68f364f1
106
ggml-cuda.cu
106
ggml-cuda.cu
@ -4489,6 +4489,13 @@ static __device__ void cpy_1_f32_f16(const char * cxi, char * cdsti) {
|
|||||||
*dsti = __float2half(*xi);
|
*dsti = __float2half(*xi);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static __device__ void cpy_1_f16_f16(const char * cxi, char * cdsti) {
|
||||||
|
const half * xi = (const half *) cxi;
|
||||||
|
half * dsti = (half *) cdsti;
|
||||||
|
|
||||||
|
*dsti = *xi;
|
||||||
|
}
|
||||||
|
|
||||||
template <cpy_kernel_t cpy_1>
|
template <cpy_kernel_t cpy_1>
|
||||||
static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
|
static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
|
||||||
const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
|
const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
|
||||||
@ -4742,6 +4749,25 @@ static __global__ void clamp_f32(const float * x, float * dst, const float min,
|
|||||||
dst[i] = x[i] < min ? min : (x[i] > max ? max : x[i]);
|
dst[i] = x[i] < min ? min : (x[i] > max ? max : x[i]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static __global__ void im2col_f32_f16(
|
||||||
|
const float * x, half * dst,
|
||||||
|
int ofs0, int ofs1, int IW, int IH, int CHW,
|
||||||
|
int s0, int s1, int p0, int p1, int d0, int d1) {
|
||||||
|
const int iiw = blockIdx.z * s0 + threadIdx.z * d0 - p0;
|
||||||
|
const int iih = blockIdx.y * s1 + threadIdx.y * d1 - p1;
|
||||||
|
|
||||||
|
const int offset_dst =
|
||||||
|
(threadIdx.x * gridDim.y * gridDim.z + blockIdx.y * gridDim.z + blockIdx.z) * CHW +
|
||||||
|
(blockIdx.x * (blockDim.y * blockDim.z) + threadIdx.y * blockDim.z + threadIdx.z);
|
||||||
|
|
||||||
|
if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
|
||||||
|
dst[offset_dst] = __float2half(0.0f);
|
||||||
|
} else {
|
||||||
|
const int offset_src = threadIdx.x * ofs0 + blockIdx.x * ofs1;
|
||||||
|
dst[offset_dst] = __float2half(x[offset_src + iih * IW + iiw]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
template<int qk, int qr, dequantize_kernel_t dq>
|
template<int qk, int qr, dequantize_kernel_t dq>
|
||||||
static void get_rows_cuda(const void * x, const int32_t * y, float * dst, const int nrows, const int ncols, cudaStream_t stream) {
|
static void get_rows_cuda(const void * x, const int32_t * y, float * dst, const int nrows, const int ncols, cudaStream_t stream) {
|
||||||
const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
|
const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
|
||||||
@ -5642,6 +5668,16 @@ static void ggml_cpy_f32_f16_cuda(
|
|||||||
(cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12);
|
(cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static void ggml_cpy_f16_f16_cuda(
|
||||||
|
const char * cx, char * cdst, const int ne,
|
||||||
|
const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
|
||||||
|
const int ne10, const int ne11, const int nb10, const int nb11, const int nb12, cudaStream_t stream) {
|
||||||
|
|
||||||
|
const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
|
||||||
|
cpy_f32_f16<cpy_1_f16_f16><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
|
||||||
|
(cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12);
|
||||||
|
}
|
||||||
|
|
||||||
static void scale_f32_cuda(const float * x, float * dst, const float scale, const int k, cudaStream_t stream) {
|
static void scale_f32_cuda(const float * x, float * dst, const float scale, const int k, cudaStream_t stream) {
|
||||||
const int num_blocks = (k + CUDA_SCALE_BLOCK_SIZE - 1) / CUDA_SCALE_BLOCK_SIZE;
|
const int num_blocks = (k + CUDA_SCALE_BLOCK_SIZE - 1) / CUDA_SCALE_BLOCK_SIZE;
|
||||||
scale_f32<<<num_blocks, CUDA_SCALE_BLOCK_SIZE, 0, stream>>>(x, dst, scale, k);
|
scale_f32<<<num_blocks, CUDA_SCALE_BLOCK_SIZE, 0, stream>>>(x, dst, scale, k);
|
||||||
@ -5725,6 +5761,15 @@ static void soft_max_f32_cuda(const float * x, float * dst, const int ncols_x, c
|
|||||||
soft_max_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x);
|
soft_max_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static void im2col_f32_f16_cuda(const float * x, half * dst,
|
||||||
|
int OH, int IW, int IH, int OW, int IC,
|
||||||
|
int KH, int KW, int N, int ofs0, int ofs1,
|
||||||
|
int s0, int s1, int p0, int p1, int d0, int d1, cudaStream_t stream) {
|
||||||
|
dim3 block_nums(IC, OH, OW);
|
||||||
|
dim3 block_dims(N, KH, KW);
|
||||||
|
im2col_f32_f16<<<block_nums, block_dims, 0, stream>>>(x, dst, ofs0, ofs1, IW, IH, (IC * KH * KW), s0, s1, p0, p1, d0, d1);
|
||||||
|
}
|
||||||
|
|
||||||
// buffer pool for cuda
|
// buffer pool for cuda
|
||||||
#define MAX_CUDA_BUFFERS 256
|
#define MAX_CUDA_BUFFERS 256
|
||||||
|
|
||||||
@ -6522,8 +6567,7 @@ inline void ggml_cuda_op_mul_mat_cublas(
|
|||||||
src1_as_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &src1_as);
|
src1_as_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &src1_as);
|
||||||
to_fp16_cuda(src1_ddf_i, src1_as_f16, ne, stream);
|
to_fp16_cuda(src1_ddf_i, src1_as_f16, ne, stream);
|
||||||
}
|
}
|
||||||
const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddq_i : src1_as_f16;
|
const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddf_i : src1_as_f16;
|
||||||
|
|
||||||
size_t dst_as = 0;
|
size_t dst_as = 0;
|
||||||
half * dst_f16 = (half *) ggml_cuda_pool_malloc(row_diff*src1_ncols * sizeof(half), &dst_as);
|
half * dst_f16 = (half *) ggml_cuda_pool_malloc(row_diff*src1_ncols * sizeof(half), &dst_as);
|
||||||
|
|
||||||
@ -6698,6 +6742,45 @@ inline void ggml_cuda_op_alibi(
|
|||||||
(void) src1_dd;
|
(void) src1_dd;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
inline void ggml_cuda_op_im2col(
|
||||||
|
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
||||||
|
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
|
||||||
|
|
||||||
|
GGML_ASSERT(src0->type == GGML_TYPE_F16);
|
||||||
|
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
||||||
|
GGML_ASSERT( dst->type == GGML_TYPE_F16);
|
||||||
|
|
||||||
|
const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
|
||||||
|
const int32_t s1 = ((const int32_t*)(dst->op_params))[1];
|
||||||
|
const int32_t p0 = ((const int32_t*)(dst->op_params))[2];
|
||||||
|
const int32_t p1 = ((const int32_t*)(dst->op_params))[3];
|
||||||
|
const int32_t d0 = ((const int32_t*)(dst->op_params))[4];
|
||||||
|
const int32_t d1 = ((const int32_t*)(dst->op_params))[5];
|
||||||
|
|
||||||
|
const bool is_2D = ((const int32_t*)(dst->op_params))[6] == 1;
|
||||||
|
|
||||||
|
const int64_t N = src1->ne[is_2D ? 3 : 2];
|
||||||
|
const int64_t IC = src1->ne[is_2D ? 2 : 1];
|
||||||
|
const int64_t IH = is_2D ? src1->ne[1] : 1;
|
||||||
|
const int64_t IW = src1->ne[0];
|
||||||
|
|
||||||
|
const int64_t KH = is_2D ? src0->ne[1] : 1;
|
||||||
|
const int64_t KW = src0->ne[0];
|
||||||
|
|
||||||
|
const int64_t OH = is_2D ? dst->ne[2] : 1;
|
||||||
|
const int64_t OW = dst->ne[1];
|
||||||
|
|
||||||
|
const size_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4; // nb is byte offset, src is type float32
|
||||||
|
const size_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32
|
||||||
|
|
||||||
|
im2col_f32_f16_cuda(src1_dd, (half*) dst_dd,
|
||||||
|
OH, IW, IH, OW, IC, KH, KW, N,
|
||||||
|
ofs0, ofs1, s0, s1, p0, p1, d0, d1, main_stream);
|
||||||
|
|
||||||
|
(void) src0;
|
||||||
|
(void) src0_dd;
|
||||||
|
}
|
||||||
|
|
||||||
inline void ggml_cuda_op_diag_mask_inf(
|
inline void ggml_cuda_op_diag_mask_inf(
|
||||||
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
||||||
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
|
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
|
||||||
@ -7610,6 +7693,9 @@ static void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, gg
|
|||||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
|
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
|
||||||
ggml_cpy_f32_f16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02,
|
ggml_cpy_f32_f16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02,
|
||||||
ne10, ne11, nb10, nb11, nb12, main_stream);
|
ne10, ne11, nb10, nb11, nb12, main_stream);
|
||||||
|
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
|
||||||
|
ggml_cpy_f16_f16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02,
|
||||||
|
ne10, ne11, nb10, nb11, nb12, main_stream);
|
||||||
} else {
|
} else {
|
||||||
fprintf(stderr, "%s: unsupported type combination (%s to %s)\n", __func__,
|
fprintf(stderr, "%s: unsupported type combination (%s to %s)\n", __func__,
|
||||||
ggml_type_name(src0->type), ggml_type_name(src1->type));
|
ggml_type_name(src0->type), ggml_type_name(src1->type));
|
||||||
@ -7641,6 +7727,10 @@ static void ggml_cuda_alibi(const ggml_tensor * src0, const ggml_tensor * src1,
|
|||||||
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_alibi);
|
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_alibi);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void ggml_cuda_im2col(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||||
|
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_im2col);
|
||||||
|
}
|
||||||
|
|
||||||
static void ggml_cuda_nop(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
static void ggml_cuda_nop(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||||
(void) src0;
|
(void) src0;
|
||||||
(void) src1;
|
(void) src1;
|
||||||
@ -7934,6 +8024,15 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (tensor->op == GGML_OP_MUL_MAT) {
|
||||||
|
if (tensor->src[0]->ne[3] != tensor->src[1]->ne[3]) {
|
||||||
|
#ifndef NDEBUG
|
||||||
|
fprintf(stderr, "%s: cannot compute %s: src0->ne[3] = %d, src1->ne[3] = %d - fallback to CPU\n", __func__, tensor->name, tensor->src[0]->ne[3], tensor->src[1]->ne[3]);
|
||||||
|
#endif
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
switch (tensor->op) {
|
switch (tensor->op) {
|
||||||
case GGML_OP_REPEAT:
|
case GGML_OP_REPEAT:
|
||||||
func = ggml_cuda_repeat;
|
func = ggml_cuda_repeat;
|
||||||
@ -8012,6 +8111,9 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
|
|||||||
case GGML_OP_ALIBI:
|
case GGML_OP_ALIBI:
|
||||||
func = ggml_cuda_alibi;
|
func = ggml_cuda_alibi;
|
||||||
break;
|
break;
|
||||||
|
case GGML_OP_IM2COL:
|
||||||
|
func = ggml_cuda_im2col;
|
||||||
|
break;
|
||||||
default:
|
default:
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
@ -39,12 +39,6 @@ extern "C" {
|
|||||||
#endif
|
#endif
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#undef MIN
|
|
||||||
#undef MAX
|
|
||||||
|
|
||||||
#define MIN(a, b) ((a) < (b) ? (a) : (b))
|
|
||||||
#define MAX(a, b) ((a) > (b) ? (a) : (b))
|
|
||||||
|
|
||||||
// 16-bit float
|
// 16-bit float
|
||||||
// on Arm, we use __fp16
|
// on Arm, we use __fp16
|
||||||
// on x86, we use uint16_t
|
// on x86, we use uint16_t
|
||||||
|
@ -26,7 +26,7 @@
|
|||||||
#include <stdbool.h>
|
#include <stdbool.h>
|
||||||
|
|
||||||
// max memory buffers that can be mapped to the device
|
// max memory buffers that can be mapped to the device
|
||||||
#define GGML_METAL_MAX_BUFFERS 16
|
#define GGML_METAL_MAX_BUFFERS 64
|
||||||
#define GGML_METAL_MAX_COMMAND_BUFFERS 32
|
#define GGML_METAL_MAX_COMMAND_BUFFERS 32
|
||||||
|
|
||||||
struct ggml_tensor;
|
struct ggml_tensor;
|
||||||
|
94
ggml-metal.m
94
ggml-metal.m
@ -86,6 +86,7 @@ struct ggml_metal_context {
|
|||||||
GGML_METAL_DECL_KERNEL(rms_norm);
|
GGML_METAL_DECL_KERNEL(rms_norm);
|
||||||
GGML_METAL_DECL_KERNEL(norm);
|
GGML_METAL_DECL_KERNEL(norm);
|
||||||
GGML_METAL_DECL_KERNEL(mul_mv_f32_f32);
|
GGML_METAL_DECL_KERNEL(mul_mv_f32_f32);
|
||||||
|
GGML_METAL_DECL_KERNEL(mul_mv_f16_f16);
|
||||||
GGML_METAL_DECL_KERNEL(mul_mv_f16_f32);
|
GGML_METAL_DECL_KERNEL(mul_mv_f16_f32);
|
||||||
GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_1row);
|
GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_1row);
|
||||||
GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_l4);
|
GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_l4);
|
||||||
@ -114,6 +115,7 @@ struct ggml_metal_context {
|
|||||||
GGML_METAL_DECL_KERNEL(rope_f32);
|
GGML_METAL_DECL_KERNEL(rope_f32);
|
||||||
GGML_METAL_DECL_KERNEL(rope_f16);
|
GGML_METAL_DECL_KERNEL(rope_f16);
|
||||||
GGML_METAL_DECL_KERNEL(alibi_f32);
|
GGML_METAL_DECL_KERNEL(alibi_f32);
|
||||||
|
GGML_METAL_DECL_KERNEL(im2col_f16);
|
||||||
GGML_METAL_DECL_KERNEL(cpy_f32_f16);
|
GGML_METAL_DECL_KERNEL(cpy_f32_f16);
|
||||||
GGML_METAL_DECL_KERNEL(cpy_f32_f32);
|
GGML_METAL_DECL_KERNEL(cpy_f32_f32);
|
||||||
GGML_METAL_DECL_KERNEL(cpy_f16_f16);
|
GGML_METAL_DECL_KERNEL(cpy_f16_f16);
|
||||||
@ -126,7 +128,7 @@ struct ggml_metal_context {
|
|||||||
// MSL code
|
// MSL code
|
||||||
// TODO: move the contents here when ready
|
// TODO: move the contents here when ready
|
||||||
// for now it is easier to work in a separate file
|
// for now it is easier to work in a separate file
|
||||||
static NSString * const msl_library_source = @"see metal.metal";
|
//static NSString * const msl_library_source = @"see metal.metal";
|
||||||
|
|
||||||
// Here to assist with NSBundle Path Hack
|
// Here to assist with NSBundle Path Hack
|
||||||
@interface GGMLMetalClass : NSObject
|
@interface GGMLMetalClass : NSObject
|
||||||
@ -142,7 +144,8 @@ void ggml_metal_log_set_callback(ggml_log_callback log_callback, void * user_dat
|
|||||||
ggml_metal_log_user_data = user_data;
|
ggml_metal_log_user_data = user_data;
|
||||||
}
|
}
|
||||||
|
|
||||||
static void ggml_metal_log(enum ggml_log_level level, const char* format, ...){
|
GGML_ATTRIBUTE_FORMAT(2, 3)
|
||||||
|
static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
|
||||||
if (ggml_metal_log_callback != NULL) {
|
if (ggml_metal_log_callback != NULL) {
|
||||||
va_list args;
|
va_list args;
|
||||||
va_start(args, format);
|
va_start(args, format);
|
||||||
@ -210,7 +213,13 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|||||||
} else {
|
} else {
|
||||||
GGML_METAL_LOG_INFO("%s: default.metallib not found, loading from source\n", __func__);
|
GGML_METAL_LOG_INFO("%s: default.metallib not found, loading from source\n", __func__);
|
||||||
|
|
||||||
NSString * sourcePath = [bundle pathForResource:@"ggml-metal" ofType:@"metal"];
|
NSString * sourcePath;
|
||||||
|
NSString * ggmlMetalPathResources = [[NSProcessInfo processInfo].environment objectForKey:@"GGML_METAL_PATH_RESOURCES"];
|
||||||
|
if (ggmlMetalPathResources) {
|
||||||
|
sourcePath = [ggmlMetalPathResources stringByAppendingPathComponent:@"ggml-metal.metal"];
|
||||||
|
} else {
|
||||||
|
sourcePath = [bundle pathForResource:@"ggml-metal" ofType:@"metal"];
|
||||||
|
}
|
||||||
if (sourcePath == nil) {
|
if (sourcePath == nil) {
|
||||||
GGML_METAL_LOG_WARN("%s: error: could not use bundle path to find ggml-metal.metal, falling back to trying cwd\n", __func__);
|
GGML_METAL_LOG_WARN("%s: error: could not use bundle path to find ggml-metal.metal, falling back to trying cwd\n", __func__);
|
||||||
sourcePath = @"ggml-metal.metal";
|
sourcePath = @"ggml-metal.metal";
|
||||||
@ -281,6 +290,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|||||||
GGML_METAL_ADD_KERNEL(rms_norm);
|
GGML_METAL_ADD_KERNEL(rms_norm);
|
||||||
GGML_METAL_ADD_KERNEL(norm);
|
GGML_METAL_ADD_KERNEL(norm);
|
||||||
GGML_METAL_ADD_KERNEL(mul_mv_f32_f32);
|
GGML_METAL_ADD_KERNEL(mul_mv_f32_f32);
|
||||||
|
GGML_METAL_ADD_KERNEL(mul_mv_f16_f16);
|
||||||
GGML_METAL_ADD_KERNEL(mul_mv_f16_f32);
|
GGML_METAL_ADD_KERNEL(mul_mv_f16_f32);
|
||||||
GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_1row);
|
GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_1row);
|
||||||
GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_l4);
|
GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_l4);
|
||||||
@ -311,6 +321,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|||||||
GGML_METAL_ADD_KERNEL(rope_f32);
|
GGML_METAL_ADD_KERNEL(rope_f32);
|
||||||
GGML_METAL_ADD_KERNEL(rope_f16);
|
GGML_METAL_ADD_KERNEL(rope_f16);
|
||||||
GGML_METAL_ADD_KERNEL(alibi_f32);
|
GGML_METAL_ADD_KERNEL(alibi_f32);
|
||||||
|
GGML_METAL_ADD_KERNEL(im2col_f16);
|
||||||
GGML_METAL_ADD_KERNEL(cpy_f32_f16);
|
GGML_METAL_ADD_KERNEL(cpy_f32_f16);
|
||||||
GGML_METAL_ADD_KERNEL(cpy_f32_f32);
|
GGML_METAL_ADD_KERNEL(cpy_f32_f32);
|
||||||
GGML_METAL_ADD_KERNEL(cpy_f16_f16);
|
GGML_METAL_ADD_KERNEL(cpy_f16_f16);
|
||||||
@ -329,7 +340,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|||||||
// https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
|
// https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
|
||||||
for (int i = MTLGPUFamilyApple1 + 20; i >= MTLGPUFamilyApple1; --i) {
|
for (int i = MTLGPUFamilyApple1 + 20; i >= MTLGPUFamilyApple1; --i) {
|
||||||
if ([ctx->device supportsFamily:i]) {
|
if ([ctx->device supportsFamily:i]) {
|
||||||
GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyApple%d (%d)\n", __func__, i - MTLGPUFamilyApple1 + 1, i);
|
GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyApple%d (%d)\n", __func__, i - (int) MTLGPUFamilyApple1 + 1, i);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -380,6 +391,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
|
|||||||
GGML_METAL_DEL_KERNEL(rms_norm);
|
GGML_METAL_DEL_KERNEL(rms_norm);
|
||||||
GGML_METAL_DEL_KERNEL(norm);
|
GGML_METAL_DEL_KERNEL(norm);
|
||||||
GGML_METAL_DEL_KERNEL(mul_mv_f32_f32);
|
GGML_METAL_DEL_KERNEL(mul_mv_f32_f32);
|
||||||
|
GGML_METAL_DEL_KERNEL(mul_mv_f16_f16);
|
||||||
GGML_METAL_DEL_KERNEL(mul_mv_f16_f32);
|
GGML_METAL_DEL_KERNEL(mul_mv_f16_f32);
|
||||||
GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_1row);
|
GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_1row);
|
||||||
GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_l4);
|
GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_l4);
|
||||||
@ -410,6 +422,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
|
|||||||
GGML_METAL_DEL_KERNEL(rope_f32);
|
GGML_METAL_DEL_KERNEL(rope_f32);
|
||||||
GGML_METAL_DEL_KERNEL(rope_f16);
|
GGML_METAL_DEL_KERNEL(rope_f16);
|
||||||
GGML_METAL_DEL_KERNEL(alibi_f32);
|
GGML_METAL_DEL_KERNEL(alibi_f32);
|
||||||
|
GGML_METAL_DEL_KERNEL(im2col_f16);
|
||||||
GGML_METAL_DEL_KERNEL(cpy_f32_f16);
|
GGML_METAL_DEL_KERNEL(cpy_f32_f16);
|
||||||
GGML_METAL_DEL_KERNEL(cpy_f32_f32);
|
GGML_METAL_DEL_KERNEL(cpy_f32_f32);
|
||||||
GGML_METAL_DEL_KERNEL(cpy_f16_f16);
|
GGML_METAL_DEL_KERNEL(cpy_f16_f16);
|
||||||
@ -467,6 +480,10 @@ static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_metal_context * ctx, stru
|
|||||||
|
|
||||||
const int64_t tsize = ggml_nbytes(t);
|
const int64_t tsize = ggml_nbytes(t);
|
||||||
|
|
||||||
|
if (t->buffer && t->buffer->backend && t->buffer->backend->context) {
|
||||||
|
ctx = t->buffer->backend->context;
|
||||||
|
}
|
||||||
|
|
||||||
// find the view that contains the tensor fully
|
// find the view that contains the tensor fully
|
||||||
for (int i = 0; i < ctx->n_buffers; ++i) {
|
for (int i = 0; i < ctx->n_buffers; ++i) {
|
||||||
const int64_t ioffs = (int64_t) t->data - (int64_t) ctx->buffers[i].data;
|
const int64_t ioffs = (int64_t) t->data - (int64_t) ctx->buffers[i].data;
|
||||||
@ -567,7 +584,7 @@ bool ggml_metal_add_buffer(
|
|||||||
ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
|
ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
|
||||||
|
|
||||||
if (ctx->device.currentAllocatedSize > ctx->device.recommendedMaxWorkingSetSize) {
|
if (ctx->device.currentAllocatedSize > ctx->device.recommendedMaxWorkingSetSize) {
|
||||||
GGML_METAL_LOG_WARN(", warning: current allocated size is greater than the recommended max working set size\n", __func__);
|
GGML_METAL_LOG_WARN("%s: warning: current allocated size is greater than the recommended max working set size\n", __func__);
|
||||||
} else {
|
} else {
|
||||||
GGML_METAL_LOG_INFO("\n");
|
GGML_METAL_LOG_INFO("\n");
|
||||||
}
|
}
|
||||||
@ -1024,7 +1041,7 @@ void ggml_metal_graph_compute(
|
|||||||
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
|
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
|
||||||
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
|
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
|
||||||
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
|
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
|
||||||
[encoder setThreadgroupMemoryLength:MAX(16, nth/32*sizeof(float)) atIndex:0];
|
[encoder setThreadgroupMemoryLength:GGML_PAD(nth/32*sizeof(float), 16) atIndex:0];
|
||||||
|
|
||||||
[encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
[encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
||||||
} break;
|
} break;
|
||||||
@ -1133,6 +1150,7 @@ void ggml_metal_graph_compute(
|
|||||||
switch (src0t) {
|
switch (src0t) {
|
||||||
case GGML_TYPE_F32:
|
case GGML_TYPE_F32:
|
||||||
{
|
{
|
||||||
|
GGML_ASSERT(src1t == GGML_TYPE_F32);
|
||||||
[encoder setComputePipelineState:ctx->pipeline_mul_mv_f32_f32];
|
[encoder setComputePipelineState:ctx->pipeline_mul_mv_f32_f32];
|
||||||
nrows = 4;
|
nrows = 4;
|
||||||
} break;
|
} break;
|
||||||
@ -1140,6 +1158,7 @@ void ggml_metal_graph_compute(
|
|||||||
{
|
{
|
||||||
nth0 = 32;
|
nth0 = 32;
|
||||||
nth1 = 1;
|
nth1 = 1;
|
||||||
|
if (src1t == GGML_TYPE_F32) {
|
||||||
if (ne11 * ne12 < 4) {
|
if (ne11 * ne12 < 4) {
|
||||||
[encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_1row];
|
[encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_1row];
|
||||||
} else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
|
} else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
|
||||||
@ -1149,6 +1168,10 @@ void ggml_metal_graph_compute(
|
|||||||
[encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32];
|
[encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32];
|
||||||
nrows = 4;
|
nrows = 4;
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
[encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f16];
|
||||||
|
nrows = 4;
|
||||||
|
}
|
||||||
} break;
|
} break;
|
||||||
case GGML_TYPE_Q4_0:
|
case GGML_TYPE_Q4_0:
|
||||||
{
|
{
|
||||||
@ -1336,7 +1359,7 @@ void ggml_metal_graph_compute(
|
|||||||
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
|
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
|
||||||
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
|
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
|
||||||
[encoder setBytes:&eps length:sizeof( float) atIndex:4];
|
[encoder setBytes:&eps length:sizeof( float) atIndex:4];
|
||||||
[encoder setThreadgroupMemoryLength:nth/32*sizeof(float) atIndex:0];
|
[encoder setThreadgroupMemoryLength:GGML_PAD(nth/32*sizeof(float), 16) atIndex:0];
|
||||||
|
|
||||||
const int64_t nrows = ggml_nrows(src0);
|
const int64_t nrows = ggml_nrows(src0);
|
||||||
|
|
||||||
@ -1355,7 +1378,7 @@ void ggml_metal_graph_compute(
|
|||||||
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
|
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
|
||||||
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
|
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
|
||||||
[encoder setBytes:&eps length:sizeof( float) atIndex:4];
|
[encoder setBytes:&eps length:sizeof( float) atIndex:4];
|
||||||
[encoder setThreadgroupMemoryLength:MAX(16, nth*sizeof(float)) atIndex:0];
|
[encoder setThreadgroupMemoryLength:GGML_PAD(nth*sizeof(float), 16) atIndex:0];
|
||||||
|
|
||||||
const int64_t nrows = ggml_nrows(src0);
|
const int64_t nrows = ggml_nrows(src0);
|
||||||
|
|
||||||
@ -1410,8 +1433,7 @@ void ggml_metal_graph_compute(
|
|||||||
const int n_past = ((int32_t *) dst->op_params)[0];
|
const int n_past = ((int32_t *) dst->op_params)[0];
|
||||||
const int n_dims = ((int32_t *) dst->op_params)[1];
|
const int n_dims = ((int32_t *) dst->op_params)[1];
|
||||||
const int mode = ((int32_t *) dst->op_params)[2];
|
const int mode = ((int32_t *) dst->op_params)[2];
|
||||||
// skip 3, n_ctx, used in GLM RoPE, unimplemented in metal
|
const int n_orig_ctx = ((int32_t *) dst->op_params)[3];
|
||||||
const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
|
|
||||||
|
|
||||||
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
|
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
|
||||||
memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
|
memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
|
||||||
@ -1459,6 +1481,58 @@ void ggml_metal_graph_compute(
|
|||||||
|
|
||||||
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
||||||
} break;
|
} break;
|
||||||
|
case GGML_OP_IM2COL:
|
||||||
|
{
|
||||||
|
GGML_ASSERT(src0->type == GGML_TYPE_F16);
|
||||||
|
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
||||||
|
GGML_ASSERT( dst->type == GGML_TYPE_F16);
|
||||||
|
|
||||||
|
const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
|
||||||
|
const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
|
||||||
|
const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
|
||||||
|
const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
|
||||||
|
const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
|
||||||
|
const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
|
||||||
|
const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
|
||||||
|
|
||||||
|
const int32_t N = src1->ne[is_2D ? 3 : 2];
|
||||||
|
const int32_t IC = src1->ne[is_2D ? 2 : 1];
|
||||||
|
const int32_t IH = is_2D ? src1->ne[1] : 1;
|
||||||
|
const int32_t IW = src1->ne[0];
|
||||||
|
|
||||||
|
const int32_t KH = is_2D ? src0->ne[1] : 1;
|
||||||
|
const int32_t KW = src0->ne[0];
|
||||||
|
|
||||||
|
const int32_t OH = is_2D ? dst->ne[2] : 1;
|
||||||
|
const int32_t OW = dst->ne[1];
|
||||||
|
|
||||||
|
const int32_t CHW = IC * KH * KW;
|
||||||
|
|
||||||
|
const int32_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4;
|
||||||
|
const int32_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4;
|
||||||
|
|
||||||
|
switch (src0->type) {
|
||||||
|
case GGML_TYPE_F32: GGML_ASSERT(false && "not implemented"); break;
|
||||||
|
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_im2col_f16]; break;
|
||||||
|
default: GGML_ASSERT(false);
|
||||||
|
};
|
||||||
|
|
||||||
|
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:0];
|
||||||
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||||
|
[encoder setBytes:&ofs0 length:sizeof( int32_t) atIndex:2];
|
||||||
|
[encoder setBytes:&ofs1 length:sizeof( int32_t) atIndex:3];
|
||||||
|
[encoder setBytes:&IW length:sizeof( int32_t) atIndex:4];
|
||||||
|
[encoder setBytes:&IH length:sizeof( int32_t) atIndex:5];
|
||||||
|
[encoder setBytes:&CHW length:sizeof( int32_t) atIndex:6];
|
||||||
|
[encoder setBytes:&s0 length:sizeof( int32_t) atIndex:7];
|
||||||
|
[encoder setBytes:&s1 length:sizeof( int32_t) atIndex:8];
|
||||||
|
[encoder setBytes:&p0 length:sizeof( int32_t) atIndex:9];
|
||||||
|
[encoder setBytes:&p1 length:sizeof( int32_t) atIndex:10];
|
||||||
|
[encoder setBytes:&d0 length:sizeof( int32_t) atIndex:11];
|
||||||
|
[encoder setBytes:&d1 length:sizeof( int32_t) atIndex:12];
|
||||||
|
|
||||||
|
[encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
|
||||||
|
} break;
|
||||||
case GGML_OP_DUP:
|
case GGML_OP_DUP:
|
||||||
case GGML_OP_CPY:
|
case GGML_OP_CPY:
|
||||||
case GGML_OP_CONT:
|
case GGML_OP_CONT:
|
||||||
|
106
ggml-metal.metal
106
ggml-metal.metal
@ -844,6 +844,79 @@ kernel void kernel_mul_mv_f32_f32(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#define N_F16_F16 4
|
||||||
|
|
||||||
|
kernel void kernel_mul_mv_f16_f16(
|
||||||
|
device const char * src0,
|
||||||
|
device const char * src1,
|
||||||
|
device float * dst,
|
||||||
|
constant int64_t & ne00,
|
||||||
|
constant int64_t & ne01,
|
||||||
|
constant int64_t & ne02,
|
||||||
|
constant uint64_t & nb00,
|
||||||
|
constant uint64_t & nb01,
|
||||||
|
constant uint64_t & nb02,
|
||||||
|
constant int64_t & ne10,
|
||||||
|
constant int64_t & ne11,
|
||||||
|
constant int64_t & ne12,
|
||||||
|
constant uint64_t & nb10,
|
||||||
|
constant uint64_t & nb11,
|
||||||
|
constant uint64_t & nb12,
|
||||||
|
constant int64_t & ne0,
|
||||||
|
constant int64_t & ne1,
|
||||||
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
|
uint tiisg[[thread_index_in_simdgroup]]) {
|
||||||
|
|
||||||
|
const int64_t r0 = tgpig.x;
|
||||||
|
const int64_t rb = tgpig.y*N_F16_F16;
|
||||||
|
const int64_t im = tgpig.z;
|
||||||
|
|
||||||
|
device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
|
||||||
|
|
||||||
|
if (ne00 < 128) {
|
||||||
|
for (int row = 0; row < N_F16_F16; ++row) {
|
||||||
|
int r1 = rb + row;
|
||||||
|
if (r1 >= ne11) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
device const half * y = (device const half *) (src1 + r1*nb11 + im*nb12);
|
||||||
|
|
||||||
|
float sumf = 0;
|
||||||
|
for (int i = tiisg; i < ne00; i += 32) {
|
||||||
|
sumf += (half) x[i] * (half) y[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
float all_sum = simd_sum(sumf);
|
||||||
|
if (tiisg == 0) {
|
||||||
|
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
device const half4 * x4 = (device const half4 *)x;
|
||||||
|
for (int row = 0; row < N_F16_F16; ++row) {
|
||||||
|
int r1 = rb + row;
|
||||||
|
if (r1 >= ne11) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
device const half * y = (device const half *) (src1 + r1*nb11 + im*nb12);
|
||||||
|
device const half4 * y4 = (device const half4 *) y;
|
||||||
|
|
||||||
|
float sumf = 0;
|
||||||
|
for (int i = tiisg; i < ne00/4; i += 32) {
|
||||||
|
for (int k = 0; k < 4; ++k) sumf += (half) x4[i][k] * y4[i][k];
|
||||||
|
}
|
||||||
|
|
||||||
|
float all_sum = simd_sum(sumf);
|
||||||
|
if (tiisg == 0) {
|
||||||
|
for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (half) x[i] * y[i];
|
||||||
|
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
kernel void kernel_mul_mv_f16_f32_1row(
|
kernel void kernel_mul_mv_f16_f32_1row(
|
||||||
device const char * src0,
|
device const char * src0,
|
||||||
device const char * src1,
|
device const char * src1,
|
||||||
@ -1229,6 +1302,39 @@ kernel void kernel_rope(
|
|||||||
template [[host_name("kernel_rope_f32")]] kernel rope_t kernel_rope<float>;
|
template [[host_name("kernel_rope_f32")]] kernel rope_t kernel_rope<float>;
|
||||||
template [[host_name("kernel_rope_f16")]] kernel rope_t kernel_rope<half>;
|
template [[host_name("kernel_rope_f16")]] kernel rope_t kernel_rope<half>;
|
||||||
|
|
||||||
|
kernel void kernel_im2col_f16(
|
||||||
|
device const float * x,
|
||||||
|
device half * dst,
|
||||||
|
constant int32_t & ofs0,
|
||||||
|
constant int32_t & ofs1,
|
||||||
|
constant int32_t & IW,
|
||||||
|
constant int32_t & IH,
|
||||||
|
constant int32_t & CHW,
|
||||||
|
constant int32_t & s0,
|
||||||
|
constant int32_t & s1,
|
||||||
|
constant int32_t & p0,
|
||||||
|
constant int32_t & p1,
|
||||||
|
constant int32_t & d0,
|
||||||
|
constant int32_t & d1,
|
||||||
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
|
uint3 tgpg[[threadgroups_per_grid]],
|
||||||
|
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||||
|
uint3 ntg[[threads_per_threadgroup]]) {
|
||||||
|
const int32_t iiw = tgpig[2] * s0 + tpitg[2] * d0 - p0;
|
||||||
|
const int32_t iih = tgpig[1] * s1 + tpitg[1] * d1 - p1;
|
||||||
|
|
||||||
|
const int32_t offset_dst =
|
||||||
|
(tpitg[0] * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * CHW +
|
||||||
|
(tgpig[0] * (ntg[1] * ntg[2]) + tpitg[1] * ntg[2] + tpitg[2]);
|
||||||
|
|
||||||
|
if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
|
||||||
|
dst[offset_dst] = 0.0f;
|
||||||
|
} else {
|
||||||
|
const int32_t offset_src = tpitg[0] * ofs0 + tgpig[0] * ofs1;
|
||||||
|
dst[offset_dst] = x[offset_src + iih * IW + iiw];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
kernel void kernel_cpy_f16_f16(
|
kernel void kernel_cpy_f16_f16(
|
||||||
device const half * src0,
|
device const half * src0,
|
||||||
device half * dst,
|
device half * dst,
|
||||||
|
239
ggml-quants.c
239
ggml-quants.c
@ -14,26 +14,6 @@
|
|||||||
//
|
//
|
||||||
#include <arm_neon.h>
|
#include <arm_neon.h>
|
||||||
|
|
||||||
#if !defined(__aarch64__)
|
|
||||||
inline static int32_t vaddvq_s16(int16x8_t v) {
|
|
||||||
return
|
|
||||||
(int32_t)vgetq_lane_s16(v, 0) + (int32_t)vgetq_lane_s16(v, 1) +
|
|
||||||
(int32_t)vgetq_lane_s16(v, 2) + (int32_t)vgetq_lane_s16(v, 3) +
|
|
||||||
(int32_t)vgetq_lane_s16(v, 4) + (int32_t)vgetq_lane_s16(v, 5) +
|
|
||||||
(int32_t)vgetq_lane_s16(v, 6) + (int32_t)vgetq_lane_s16(v, 7);
|
|
||||||
}
|
|
||||||
|
|
||||||
inline static int16x8_t vpaddq_s16(int16x8_t a, int16x8_t b) {
|
|
||||||
int16x4_t a0 = vpadd_s16(vget_low_s16(a), vget_high_s16(a));
|
|
||||||
int16x4_t b0 = vpadd_s16(vget_low_s16(b), vget_high_s16(b));
|
|
||||||
return vcombine_s16(a0, b0);
|
|
||||||
}
|
|
||||||
|
|
||||||
inline static int32_t vaddvq_s32(int32x4_t v) {
|
|
||||||
return vgetq_lane_s32(v, 0) + vgetq_lane_s32(v, 1) + vgetq_lane_s32(v, 2) + vgetq_lane_s32(v, 3);
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
|
|
||||||
#else
|
#else
|
||||||
|
|
||||||
#ifdef __wasm_simd128__
|
#ifdef __wasm_simd128__
|
||||||
@ -47,13 +27,15 @@ inline static int32_t vaddvq_s32(int32x4_t v) {
|
|||||||
#if defined(_MSC_VER) || defined(__MINGW32__)
|
#if defined(_MSC_VER) || defined(__MINGW32__)
|
||||||
#include <intrin.h>
|
#include <intrin.h>
|
||||||
#else
|
#else
|
||||||
#if !defined(__riscv) && !defined(__s390__)
|
#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) || defined(__SSE3__)
|
||||||
|
#if !defined(__riscv)
|
||||||
#include <immintrin.h>
|
#include <immintrin.h>
|
||||||
#endif
|
#endif
|
||||||
#endif
|
#endif
|
||||||
#endif
|
#endif
|
||||||
#endif
|
#endif
|
||||||
#endif
|
#endif
|
||||||
|
#endif
|
||||||
|
|
||||||
#ifdef __riscv_v_intrinsic
|
#ifdef __riscv_v_intrinsic
|
||||||
#include <riscv_vector.h>
|
#include <riscv_vector.h>
|
||||||
@ -61,6 +43,7 @@ inline static int32_t vaddvq_s32(int32x4_t v) {
|
|||||||
|
|
||||||
#undef MIN
|
#undef MIN
|
||||||
#undef MAX
|
#undef MAX
|
||||||
|
|
||||||
#define MIN(a, b) ((a) < (b) ? (a) : (b))
|
#define MIN(a, b) ((a) < (b) ? (a) : (b))
|
||||||
#define MAX(a, b) ((a) > (b) ? (a) : (b))
|
#define MAX(a, b) ((a) > (b) ? (a) : (b))
|
||||||
|
|
||||||
@ -283,9 +266,31 @@ static inline float hsum_float_4x4(const __m128 a, const __m128 b, const __m128
|
|||||||
#endif // defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__)
|
#endif // defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__)
|
||||||
|
|
||||||
#if defined(__ARM_NEON)
|
#if defined(__ARM_NEON)
|
||||||
|
|
||||||
#if !defined(__aarch64__)
|
#if !defined(__aarch64__)
|
||||||
|
|
||||||
|
// 64-bit compatibility
|
||||||
|
|
||||||
|
// vaddvq_s16
|
||||||
|
// vpaddq_s16
|
||||||
|
// vaddvq_s32
|
||||||
|
// vaddvq_f32
|
||||||
|
// vmaxvq_f32
|
||||||
|
// vcvtnq_s32_f32
|
||||||
|
|
||||||
|
inline static int32_t vaddvq_s16(int16x8_t v) {
|
||||||
|
return
|
||||||
|
(int32_t)vgetq_lane_s16(v, 0) + (int32_t)vgetq_lane_s16(v, 1) +
|
||||||
|
(int32_t)vgetq_lane_s16(v, 2) + (int32_t)vgetq_lane_s16(v, 3) +
|
||||||
|
(int32_t)vgetq_lane_s16(v, 4) + (int32_t)vgetq_lane_s16(v, 5) +
|
||||||
|
(int32_t)vgetq_lane_s16(v, 6) + (int32_t)vgetq_lane_s16(v, 7);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline static int16x8_t vpaddq_s16(int16x8_t a, int16x8_t b) {
|
||||||
|
int16x4_t a0 = vpadd_s16(vget_low_s16(a), vget_high_s16(a));
|
||||||
|
int16x4_t b0 = vpadd_s16(vget_low_s16(b), vget_high_s16(b));
|
||||||
|
return vcombine_s16(a0, b0);
|
||||||
|
}
|
||||||
|
|
||||||
inline static int32_t vaddvq_s32(int32x4_t v) {
|
inline static int32_t vaddvq_s32(int32x4_t v) {
|
||||||
return vgetq_lane_s32(v, 0) + vgetq_lane_s32(v, 1) + vgetq_lane_s32(v, 2) + vgetq_lane_s32(v, 3);
|
return vgetq_lane_s32(v, 0) + vgetq_lane_s32(v, 1) + vgetq_lane_s32(v, 2) + vgetq_lane_s32(v, 3);
|
||||||
}
|
}
|
||||||
@ -311,6 +316,96 @@ inline static int32x4_t vcvtnq_s32_f32(float32x4_t v) {
|
|||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// vld1q_s16_x2
|
||||||
|
// vld1q_u8_x2
|
||||||
|
// vld1q_u8_x4
|
||||||
|
// vld1q_s8_x2
|
||||||
|
// vld1q_s8_x4
|
||||||
|
// TODO: double-check these work correctly
|
||||||
|
|
||||||
|
typedef struct ggml_int16x8x2_t {
|
||||||
|
int16x8_t val[2];
|
||||||
|
} ggml_int16x8x2_t;
|
||||||
|
|
||||||
|
inline static ggml_int16x8x2_t ggml_vld1q_s16_x2(const int16_t * ptr) {
|
||||||
|
ggml_int16x8x2_t res;
|
||||||
|
|
||||||
|
res.val[0] = vld1q_s16(ptr + 0);
|
||||||
|
res.val[1] = vld1q_s16(ptr + 8);
|
||||||
|
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
typedef struct ggml_uint8x16x2_t {
|
||||||
|
uint8x16_t val[2];
|
||||||
|
} ggml_uint8x16x2_t;
|
||||||
|
|
||||||
|
inline static ggml_uint8x16x2_t ggml_vld1q_u8_x2(const uint8_t * ptr) {
|
||||||
|
ggml_uint8x16x2_t res;
|
||||||
|
|
||||||
|
res.val[0] = vld1q_u8(ptr + 0);
|
||||||
|
res.val[1] = vld1q_u8(ptr + 16);
|
||||||
|
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
typedef struct ggml_uint8x16x4_t {
|
||||||
|
uint8x16_t val[4];
|
||||||
|
} ggml_uint8x16x4_t;
|
||||||
|
|
||||||
|
inline static ggml_uint8x16x4_t ggml_vld1q_u8_x4(const uint8_t * ptr) {
|
||||||
|
ggml_uint8x16x4_t res;
|
||||||
|
|
||||||
|
res.val[0] = vld1q_u8(ptr + 0);
|
||||||
|
res.val[1] = vld1q_u8(ptr + 16);
|
||||||
|
res.val[2] = vld1q_u8(ptr + 32);
|
||||||
|
res.val[3] = vld1q_u8(ptr + 48);
|
||||||
|
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
typedef struct ggml_int8x16x2_t {
|
||||||
|
int8x16_t val[2];
|
||||||
|
} ggml_int8x16x2_t;
|
||||||
|
|
||||||
|
inline static ggml_int8x16x2_t ggml_vld1q_s8_x2(const int8_t * ptr) {
|
||||||
|
ggml_int8x16x2_t res;
|
||||||
|
|
||||||
|
res.val[0] = vld1q_s8(ptr + 0);
|
||||||
|
res.val[1] = vld1q_s8(ptr + 16);
|
||||||
|
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
typedef struct ggml_int8x16x4_t {
|
||||||
|
int8x16_t val[4];
|
||||||
|
} ggml_int8x16x4_t;
|
||||||
|
|
||||||
|
inline static ggml_int8x16x4_t ggml_vld1q_s8_x4(const int8_t * ptr) {
|
||||||
|
ggml_int8x16x4_t res;
|
||||||
|
|
||||||
|
res.val[0] = vld1q_s8(ptr + 0);
|
||||||
|
res.val[1] = vld1q_s8(ptr + 16);
|
||||||
|
res.val[2] = vld1q_s8(ptr + 32);
|
||||||
|
res.val[3] = vld1q_s8(ptr + 48);
|
||||||
|
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
#else
|
||||||
|
|
||||||
|
#define ggml_int16x8x2_t int16x8x2_t
|
||||||
|
#define ggml_uint8x16x2_t uint8x16x2_t
|
||||||
|
#define ggml_uint8x16x4_t uint8x16x4_t
|
||||||
|
#define ggml_int8x16x2_t int8x16x2_t
|
||||||
|
#define ggml_int8x16x4_t int8x16x4_t
|
||||||
|
|
||||||
|
#define ggml_vld1q_s16_x2 vld1q_s16_x2
|
||||||
|
#define ggml_vld1q_u8_x2 vld1q_u8_x2
|
||||||
|
#define ggml_vld1q_u8_x4 vld1q_u8_x4
|
||||||
|
#define ggml_vld1q_s8_x2 vld1q_s8_x2
|
||||||
|
#define ggml_vld1q_s8_x4 vld1q_s8_x4
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
@ -3557,7 +3652,7 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
|
|||||||
const int32x4_t vzero = vdupq_n_s32(0);
|
const int32x4_t vzero = vdupq_n_s32(0);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
int8x16x2_t q2bytes;
|
ggml_int8x16x2_t q2bytes;
|
||||||
uint8_t aux[16];
|
uint8_t aux[16];
|
||||||
|
|
||||||
float sum = 0;
|
float sum = 0;
|
||||||
@ -3576,8 +3671,8 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
|
|||||||
vst1q_u8(aux, scales);
|
vst1q_u8(aux, scales);
|
||||||
|
|
||||||
const uint8x16_t mins = vshrq_n_u8(mins_and_scales, 4);
|
const uint8x16_t mins = vshrq_n_u8(mins_and_scales, 4);
|
||||||
const int16x8x2_t q8sums = vld1q_s16_x2(y[i].bsums);
|
const ggml_int16x8x2_t q8sums = ggml_vld1q_s16_x2(y[i].bsums);
|
||||||
const int16x8x2_t mins16 = {vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(mins))), vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(mins)))};
|
const ggml_int16x8x2_t mins16 = {vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(mins))), vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(mins)))};
|
||||||
const int32x4_t s0 = vaddq_s32(vmull_s16(vget_low_s16 (mins16.val[0]), vget_low_s16 (q8sums.val[0])),
|
const int32x4_t s0 = vaddq_s32(vmull_s16(vget_low_s16 (mins16.val[0]), vget_low_s16 (q8sums.val[0])),
|
||||||
vmull_s16(vget_high_s16(mins16.val[0]), vget_high_s16(q8sums.val[0])));
|
vmull_s16(vget_high_s16(mins16.val[0]), vget_high_s16(q8sums.val[0])));
|
||||||
const int32x4_t s1 = vaddq_s32(vmull_s16(vget_low_s16 (mins16.val[1]), vget_low_s16 (q8sums.val[1])),
|
const int32x4_t s1 = vaddq_s32(vmull_s16(vget_low_s16 (mins16.val[1]), vget_low_s16 (q8sums.val[1])),
|
||||||
@ -3605,7 +3700,7 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
|
|||||||
#endif
|
#endif
|
||||||
|
|
||||||
#define SHIFT_MULTIPLY_ACCUM_WITH_SCALE(shift, index)\
|
#define SHIFT_MULTIPLY_ACCUM_WITH_SCALE(shift, index)\
|
||||||
q8bytes = vld1q_s8_x2(q8); q8 += 32;\
|
q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;\
|
||||||
q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.val[0], (shift)), m3));\
|
q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.val[0], (shift)), m3));\
|
||||||
q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.val[1], (shift)), m3));\
|
q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.val[1], (shift)), m3));\
|
||||||
MULTIPLY_ACCUM_WITH_SCALE((index));
|
MULTIPLY_ACCUM_WITH_SCALE((index));
|
||||||
@ -3613,9 +3708,9 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
|
|||||||
|
|
||||||
for (int j = 0; j < QK_K/128; ++j) {
|
for (int j = 0; j < QK_K/128; ++j) {
|
||||||
|
|
||||||
const uint8x16x2_t q2bits = vld1q_u8_x2(q2); q2 += 32;
|
const ggml_uint8x16x2_t q2bits = ggml_vld1q_u8_x2(q2); q2 += 32;
|
||||||
|
|
||||||
int8x16x2_t q8bytes = vld1q_s8_x2(q8); q8 += 32;
|
ggml_int8x16x2_t q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;
|
||||||
q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(q2bits.val[0], m3));
|
q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(q2bits.val[0], m3));
|
||||||
q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(q2bits.val[1], m3));
|
q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(q2bits.val[1], m3));
|
||||||
MULTIPLY_ACCUM_WITH_SCALE(0);
|
MULTIPLY_ACCUM_WITH_SCALE(0);
|
||||||
@ -3949,7 +4044,7 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
|
|||||||
const int32x4_t vzero = vdupq_n_s32(0);
|
const int32x4_t vzero = vdupq_n_s32(0);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
int8x16x4_t q2bytes;
|
ggml_int8x16x4_t q2bytes;
|
||||||
|
|
||||||
uint32_t aux32[2];
|
uint32_t aux32[2];
|
||||||
const uint8_t * scales = (const uint8_t *)aux32;
|
const uint8_t * scales = (const uint8_t *)aux32;
|
||||||
@ -3974,7 +4069,7 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
|
|||||||
|
|
||||||
const uint8x16_t q2bits = vld1q_u8(q2);
|
const uint8x16_t q2bits = vld1q_u8(q2);
|
||||||
|
|
||||||
const int8x16x4_t q8bytes = vld1q_s8_x4(q8);
|
const ggml_int8x16x4_t q8bytes = ggml_vld1q_s8_x4(q8);
|
||||||
|
|
||||||
q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(q2bits, m3));
|
q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(q2bits, m3));
|
||||||
q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits, 2), m3));
|
q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits, 2), m3));
|
||||||
@ -4238,7 +4333,7 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
|
|||||||
const uint8x16_t m3 = vshlq_n_u8(m0, 3);
|
const uint8x16_t m3 = vshlq_n_u8(m0, 3);
|
||||||
const int8_t m32 = 32;
|
const int8_t m32 = 32;
|
||||||
|
|
||||||
int8x16x4_t q3bytes;
|
ggml_int8x16x4_t q3bytes;
|
||||||
|
|
||||||
float sum = 0;
|
float sum = 0;
|
||||||
|
|
||||||
@ -4250,9 +4345,9 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
|
|||||||
const uint8_t * restrict qh = x[i].hmask;
|
const uint8_t * restrict qh = x[i].hmask;
|
||||||
const int8_t * restrict q8 = y[i].qs;
|
const int8_t * restrict q8 = y[i].qs;
|
||||||
|
|
||||||
uint8x16x2_t qhbits = vld1q_u8_x2(qh);
|
ggml_uint8x16x2_t qhbits = ggml_vld1q_u8_x2(qh);
|
||||||
|
|
||||||
uint8x16x4_t q3h;
|
ggml_uint8x16x4_t q3h;
|
||||||
|
|
||||||
int32_t isum = 0;
|
int32_t isum = 0;
|
||||||
|
|
||||||
@ -4268,9 +4363,9 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
|
|||||||
|
|
||||||
for (int j = 0; j < QK_K/128; ++j) {
|
for (int j = 0; j < QK_K/128; ++j) {
|
||||||
|
|
||||||
const uint8x16x2_t q3bits = vld1q_u8_x2(q3); q3 += 32;
|
const ggml_uint8x16x2_t q3bits = ggml_vld1q_u8_x2(q3); q3 += 32;
|
||||||
const int8x16x4_t q8bytes_1 = vld1q_s8_x4(q8); q8 += 64;
|
const ggml_int8x16x4_t q8bytes_1 = ggml_vld1q_s8_x4(q8); q8 += 64;
|
||||||
const int8x16x4_t q8bytes_2 = vld1q_s8_x4(q8); q8 += 64;
|
const ggml_int8x16x4_t q8bytes_2 = ggml_vld1q_s8_x4(q8); q8 += 64;
|
||||||
|
|
||||||
q3h.val[0] = vshlq_n_u8(vbicq_u8(m0, qhbits.val[0]), 2);
|
q3h.val[0] = vshlq_n_u8(vbicq_u8(m0, qhbits.val[0]), 2);
|
||||||
q3h.val[1] = vshlq_n_u8(vbicq_u8(m0, qhbits.val[1]), 2);
|
q3h.val[1] = vshlq_n_u8(vbicq_u8(m0, qhbits.val[1]), 2);
|
||||||
@ -4772,7 +4867,7 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
|
|||||||
const uint8x16_t m3b = vdupq_n_u8(0x3);
|
const uint8x16_t m3b = vdupq_n_u8(0x3);
|
||||||
const uint8x16_t mh = vdupq_n_u8(4);
|
const uint8x16_t mh = vdupq_n_u8(4);
|
||||||
|
|
||||||
int8x16x4_t q3bytes;
|
ggml_int8x16x4_t q3bytes;
|
||||||
|
|
||||||
uint16_t aux16[2];
|
uint16_t aux16[2];
|
||||||
int8_t * scales = (int8_t *)aux16;
|
int8_t * scales = (int8_t *)aux16;
|
||||||
@ -4781,11 +4876,11 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
|
|||||||
|
|
||||||
for (int i = 0; i < nb; ++i) {
|
for (int i = 0; i < nb; ++i) {
|
||||||
|
|
||||||
uint8x16x4_t q3h;
|
ggml_uint8x16x4_t q3h;
|
||||||
|
|
||||||
const uint8x8_t hbits = vld1_u8(x[i].hmask);
|
const uint8x8_t hbits = vld1_u8(x[i].hmask);
|
||||||
const uint8x16_t q3bits = vld1q_u8(x[i].qs);
|
const uint8x16_t q3bits = vld1q_u8(x[i].qs);
|
||||||
const int8x16x4_t q8bytes = vld1q_s8_x4(y[i].qs);
|
const ggml_int8x16x4_t q8bytes = ggml_vld1q_s8_x4(y[i].qs);
|
||||||
|
|
||||||
const uint16_t a = *(const uint16_t *)x[i].scales;
|
const uint16_t a = *(const uint16_t *)x[i].scales;
|
||||||
aux16[0] = a & 0x0f0f;
|
aux16[0] = a & 0x0f0f;
|
||||||
@ -5134,8 +5229,8 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
|
|||||||
const int32x4_t mzero = vdupq_n_s32(0);
|
const int32x4_t mzero = vdupq_n_s32(0);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
int8x16x2_t q4bytes;
|
ggml_int8x16x2_t q4bytes;
|
||||||
int8x16x2_t q8bytes;
|
ggml_int8x16x2_t q8bytes;
|
||||||
|
|
||||||
float sumf = 0;
|
float sumf = 0;
|
||||||
|
|
||||||
@ -5170,17 +5265,17 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
|
|||||||
|
|
||||||
for (int j = 0; j < QK_K/64; ++j) {
|
for (int j = 0; j < QK_K/64; ++j) {
|
||||||
|
|
||||||
const uint8x16x2_t q4bits = vld1q_u8_x2(q4); q4 += 32;
|
const ggml_uint8x16x2_t q4bits = ggml_vld1q_u8_x2(q4); q4 += 32;
|
||||||
|
|
||||||
#ifdef __ARM_FEATURE_DOTPROD
|
#ifdef __ARM_FEATURE_DOTPROD
|
||||||
q8bytes = vld1q_s8_x2(q8); q8 += 32;
|
q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;
|
||||||
q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b));
|
q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b));
|
||||||
q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b));
|
q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b));
|
||||||
|
|
||||||
const int32x4_t p1 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);
|
const int32x4_t p1 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);
|
||||||
sumi1 += vaddvq_s32(p1) * scales[2*j+0];
|
sumi1 += vaddvq_s32(p1) * scales[2*j+0];
|
||||||
|
|
||||||
q8bytes = vld1q_s8_x2(q8); q8 += 32;
|
q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;
|
||||||
q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4));
|
q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4));
|
||||||
q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4));
|
q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4));
|
||||||
|
|
||||||
@ -5188,7 +5283,7 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
|
|||||||
|
|
||||||
sumi2 += vaddvq_s32(p2) * scales[2*j+1];
|
sumi2 += vaddvq_s32(p2) * scales[2*j+1];
|
||||||
#else
|
#else
|
||||||
q8bytes = vld1q_s8_x2(q8); q8 += 32;
|
q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;
|
||||||
q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b));
|
q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b));
|
||||||
q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b));
|
q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b));
|
||||||
const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
|
const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
|
||||||
@ -5197,7 +5292,7 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
|
|||||||
vmull_s8(vget_high_s8(q4bytes.val[1]), vget_high_s8(q8bytes.val[1])));
|
vmull_s8(vget_high_s8(q4bytes.val[1]), vget_high_s8(q8bytes.val[1])));
|
||||||
sumi1 += vaddvq_s16(vaddq_s16(p0, p1)) * scales[2*j+0];
|
sumi1 += vaddvq_s16(vaddq_s16(p0, p1)) * scales[2*j+0];
|
||||||
|
|
||||||
q8bytes = vld1q_s8_x2(q8); q8 += 32;
|
q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;
|
||||||
q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4));
|
q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4));
|
||||||
q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4));
|
q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4));
|
||||||
const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
|
const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
|
||||||
@ -5512,8 +5607,8 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
|
|||||||
|
|
||||||
float sumf = 0;
|
float sumf = 0;
|
||||||
|
|
||||||
int8x16x2_t q4bytes;
|
ggml_int8x16x2_t q4bytes;
|
||||||
int8x16x4_t q8bytes;
|
ggml_int8x16x4_t q8bytes;
|
||||||
|
|
||||||
float sum_mins = 0.f;
|
float sum_mins = 0.f;
|
||||||
|
|
||||||
@ -5534,10 +5629,10 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
|
|||||||
|
|
||||||
const float d = y[i].d * (float)x[i].d[0];
|
const float d = y[i].d * (float)x[i].d[0];
|
||||||
|
|
||||||
const uint8x16x2_t q4bits = vld1q_u8_x2(q4);
|
const ggml_uint8x16x2_t q4bits = ggml_vld1q_u8_x2(q4);
|
||||||
|
|
||||||
#ifdef __ARM_FEATURE_DOTPROD
|
#ifdef __ARM_FEATURE_DOTPROD
|
||||||
q8bytes = vld1q_s8_x4(q8);
|
q8bytes = ggml_vld1q_s8_x4(q8);
|
||||||
q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b));
|
q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b));
|
||||||
q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b));
|
q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b));
|
||||||
|
|
||||||
@ -5551,7 +5646,7 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
|
|||||||
const int32_t sumi2 = vaddvq_s32(p2) * scales[1];
|
const int32_t sumi2 = vaddvq_s32(p2) * scales[1];
|
||||||
|
|
||||||
#else
|
#else
|
||||||
q8bytes = vld1q_s8_x4(q8);
|
q8bytes = ggml_vld1q_s8_x4(q8);
|
||||||
q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b));
|
q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b));
|
||||||
q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b));
|
q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b));
|
||||||
const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
|
const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
|
||||||
@ -5785,7 +5880,7 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
|
|||||||
const int32x4_t mzero = vdupq_n_s32(0);
|
const int32x4_t mzero = vdupq_n_s32(0);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
int8x16x4_t q5bytes;
|
ggml_int8x16x4_t q5bytes;
|
||||||
|
|
||||||
float sumf = 0;
|
float sumf = 0;
|
||||||
|
|
||||||
@ -5815,16 +5910,16 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
|
|||||||
const uint8_t * restrict qh = x[i].qh;
|
const uint8_t * restrict qh = x[i].qh;
|
||||||
const int8_t * restrict q8 = y[i].qs;
|
const int8_t * restrict q8 = y[i].qs;
|
||||||
|
|
||||||
uint8x16x2_t qhbits = vld1q_u8_x2(qh);
|
ggml_uint8x16x2_t qhbits = ggml_vld1q_u8_x2(qh);
|
||||||
|
|
||||||
uint8x16x4_t q5h;
|
ggml_uint8x16x4_t q5h;
|
||||||
|
|
||||||
int32_t sumi = 0;
|
int32_t sumi = 0;
|
||||||
|
|
||||||
for (int j = 0; j < QK_K/64; ++j) {
|
for (int j = 0; j < QK_K/64; ++j) {
|
||||||
|
|
||||||
const uint8x16x2_t q5bits = vld1q_u8_x2(q5); q5 += 32;
|
const ggml_uint8x16x2_t q5bits = ggml_vld1q_u8_x2(q5); q5 += 32;
|
||||||
const int8x16x4_t q8bytes = vld1q_s8_x4(q8); q8 += 64;
|
const ggml_int8x16x4_t q8bytes = ggml_vld1q_s8_x4(q8); q8 += 64;
|
||||||
|
|
||||||
q5h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits.val[0]), 4);
|
q5h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits.val[0]), 4);
|
||||||
q5h.val[1] = vshlq_n_u8(vandq_u8(mone, qhbits.val[1]), 4);
|
q5h.val[1] = vshlq_n_u8(vandq_u8(mone, qhbits.val[1]), 4);
|
||||||
@ -6218,8 +6313,8 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
|
|||||||
const int32x4_t mzero = vdupq_n_s32(0);
|
const int32x4_t mzero = vdupq_n_s32(0);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
int8x16x4_t q5bytes;
|
ggml_int8x16x4_t q5bytes;
|
||||||
uint8x16x4_t q5h;
|
ggml_uint8x16x4_t q5h;
|
||||||
|
|
||||||
float sumf = 0;
|
float sumf = 0;
|
||||||
|
|
||||||
@ -6234,8 +6329,8 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
|
|||||||
|
|
||||||
const uint8x8_t qhbits = vld1_u8(qh);
|
const uint8x8_t qhbits = vld1_u8(qh);
|
||||||
|
|
||||||
const uint8x16x2_t q5bits = vld1q_u8_x2(q5);
|
const ggml_uint8x16x2_t q5bits = ggml_vld1q_u8_x2(q5);
|
||||||
const int8x16x4_t q8bytes = vld1q_s8_x4(q8);
|
const ggml_int8x16x4_t q8bytes = ggml_vld1q_s8_x4(q8);
|
||||||
|
|
||||||
const uint8x16_t htmp = vcombine_u8(qhbits, vshr_n_u8(qhbits, 1));
|
const uint8x16_t htmp = vcombine_u8(qhbits, vshr_n_u8(qhbits, 1));
|
||||||
q5h.val[0] = vbicq_u8(mh, vshlq_n_u8(htmp, 4));
|
q5h.val[0] = vbicq_u8(mh, vshlq_n_u8(htmp, 4));
|
||||||
@ -6511,8 +6606,8 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
|
|||||||
|
|
||||||
const uint8x16_t mone = vdupq_n_u8(3);
|
const uint8x16_t mone = vdupq_n_u8(3);
|
||||||
|
|
||||||
int8x16x4_t q6bytes;
|
ggml_int8x16x4_t q6bytes;
|
||||||
uint8x16x4_t q6h;
|
ggml_uint8x16x4_t q6h;
|
||||||
|
|
||||||
for (int i = 0; i < nb; ++i) {
|
for (int i = 0; i < nb; ++i) {
|
||||||
|
|
||||||
@ -6524,9 +6619,9 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
|
|||||||
|
|
||||||
const int8_t * restrict scale = x[i].scales;
|
const int8_t * restrict scale = x[i].scales;
|
||||||
|
|
||||||
const int16x8x2_t q8sums = vld1q_s16_x2(y[i].bsums);
|
const ggml_int16x8x2_t q8sums = ggml_vld1q_s16_x2(y[i].bsums);
|
||||||
const int8x16_t scales = vld1q_s8(scale);
|
const int8x16_t scales = vld1q_s8(scale);
|
||||||
const int16x8x2_t q6scales = {vmovl_s8(vget_low_s8(scales)), vmovl_s8(vget_high_s8(scales))};
|
const ggml_int16x8x2_t q6scales = {vmovl_s8(vget_low_s8(scales)), vmovl_s8(vget_high_s8(scales))};
|
||||||
|
|
||||||
const int32x4_t prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums.val[0]), vget_low_s16 (q6scales.val[0])),
|
const int32x4_t prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums.val[0]), vget_low_s16 (q6scales.val[0])),
|
||||||
vmull_s16(vget_high_s16(q8sums.val[0]), vget_high_s16(q6scales.val[0]))),
|
vmull_s16(vget_high_s16(q8sums.val[0]), vget_high_s16(q6scales.val[0]))),
|
||||||
@ -6538,9 +6633,9 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
|
|||||||
|
|
||||||
for (int j = 0; j < QK_K/128; ++j) {
|
for (int j = 0; j < QK_K/128; ++j) {
|
||||||
|
|
||||||
uint8x16x2_t qhbits = vld1q_u8_x2(qh); qh += 32;
|
ggml_uint8x16x2_t qhbits = ggml_vld1q_u8_x2(qh); qh += 32;
|
||||||
uint8x16x4_t q6bits = vld1q_u8_x4(q6); q6 += 64;
|
ggml_uint8x16x4_t q6bits = ggml_vld1q_u8_x4(q6); q6 += 64;
|
||||||
int8x16x4_t q8bytes = vld1q_s8_x4(q8); q8 += 64;
|
ggml_int8x16x4_t q8bytes = ggml_vld1q_s8_x4(q8); q8 += 64;
|
||||||
|
|
||||||
q6h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits.val[0]), 4);
|
q6h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits.val[0]), 4);
|
||||||
q6h.val[1] = vshlq_n_u8(vandq_u8(mone, qhbits.val[1]), 4);
|
q6h.val[1] = vshlq_n_u8(vandq_u8(mone, qhbits.val[1]), 4);
|
||||||
@ -6583,7 +6678,7 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
|
|||||||
scale += 2;
|
scale += 2;
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
q8bytes = vld1q_s8_x4(q8); q8 += 64;
|
q8bytes = ggml_vld1q_s8_x4(q8); q8 += 64;
|
||||||
|
|
||||||
shifted = vshrq_n_u8(qhbits.val[0], 4);
|
shifted = vshrq_n_u8(qhbits.val[0], 4);
|
||||||
q6h.val[0] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
|
q6h.val[0] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
|
||||||
@ -6987,8 +7082,8 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
|
|||||||
|
|
||||||
const uint8x16_t mone = vdupq_n_u8(3);
|
const uint8x16_t mone = vdupq_n_u8(3);
|
||||||
|
|
||||||
int8x16x4_t q6bytes;
|
ggml_int8x16x4_t q6bytes;
|
||||||
uint8x16x4_t q6h;
|
ggml_uint8x16x4_t q6h;
|
||||||
|
|
||||||
for (int i = 0; i < nb; ++i) {
|
for (int i = 0; i < nb; ++i) {
|
||||||
|
|
||||||
@ -7003,8 +7098,8 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
|
|||||||
int32_t isum = 0;
|
int32_t isum = 0;
|
||||||
|
|
||||||
uint8x16_t qhbits = vld1q_u8(qh);
|
uint8x16_t qhbits = vld1q_u8(qh);
|
||||||
uint8x16x2_t q6bits = vld1q_u8_x2(q6);
|
ggml_uint8x16x2_t q6bits = ggml_vld1q_u8_x2(q6);
|
||||||
int8x16x4_t q8bytes = vld1q_s8_x4(q8);
|
ggml_int8x16x4_t q8bytes = ggml_vld1q_s8_x4(q8);
|
||||||
|
|
||||||
q6h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits), 4);
|
q6h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits), 4);
|
||||||
uint8x16_t shifted = vshrq_n_u8(qhbits, 2);
|
uint8x16_t shifted = vshrq_n_u8(qhbits, 2);
|
||||||
|
19
ggml.h
19
ggml.h
@ -403,13 +403,8 @@ extern "C" {
|
|||||||
GGML_OP_ROPE_BACK,
|
GGML_OP_ROPE_BACK,
|
||||||
GGML_OP_ALIBI,
|
GGML_OP_ALIBI,
|
||||||
GGML_OP_CLAMP,
|
GGML_OP_CLAMP,
|
||||||
GGML_OP_CONV_1D,
|
|
||||||
GGML_OP_CONV_1D_STAGE_0, // internal
|
|
||||||
GGML_OP_CONV_1D_STAGE_1, // internal
|
|
||||||
GGML_OP_CONV_TRANSPOSE_1D,
|
GGML_OP_CONV_TRANSPOSE_1D,
|
||||||
GGML_OP_CONV_2D,
|
GGML_OP_IM2COL,
|
||||||
GGML_OP_CONV_2D_STAGE_0, // internal
|
|
||||||
GGML_OP_CONV_2D_STAGE_1, // internal
|
|
||||||
GGML_OP_CONV_TRANSPOSE_2D,
|
GGML_OP_CONV_TRANSPOSE_2D,
|
||||||
GGML_OP_POOL_1D,
|
GGML_OP_POOL_1D,
|
||||||
GGML_OP_POOL_2D,
|
GGML_OP_POOL_2D,
|
||||||
@ -1403,6 +1398,18 @@ extern "C" {
|
|||||||
float min,
|
float min,
|
||||||
float max);
|
float max);
|
||||||
|
|
||||||
|
GGML_API struct ggml_tensor * ggml_im2col(
|
||||||
|
struct ggml_context * ctx,
|
||||||
|
struct ggml_tensor * a,
|
||||||
|
struct ggml_tensor * b,
|
||||||
|
int s0,
|
||||||
|
int s1,
|
||||||
|
int p0,
|
||||||
|
int p1,
|
||||||
|
int d0,
|
||||||
|
int d1,
|
||||||
|
bool is_2D);
|
||||||
|
|
||||||
GGML_API struct ggml_tensor * ggml_conv_1d(
|
GGML_API struct ggml_tensor * ggml_conv_1d(
|
||||||
struct ggml_context * ctx,
|
struct ggml_context * ctx,
|
||||||
struct ggml_tensor * a,
|
struct ggml_tensor * a,
|
||||||
|
Loading…
Reference in New Issue
Block a user