mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-26 03:14:35 +00:00
metal : more optimizations (#2959)
* Very minor speedup via simd-group synchronization in f16 x f32 * Another very minor speedup on metal * Quite significant PP speedup on metal * Another attempt * Minor * Massive improvement for TG for fp16 * ~4-5% improvement for Q8_0 TG on metal --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com> Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
parent
6a31a3bd98
commit
ca82cf7bac
20
ggml-metal.m
20
ggml-metal.m
@ -76,6 +76,7 @@ struct ggml_metal_context {
|
|||||||
GGML_METAL_DECL_KERNEL(rms_norm);
|
GGML_METAL_DECL_KERNEL(rms_norm);
|
||||||
GGML_METAL_DECL_KERNEL(norm);
|
GGML_METAL_DECL_KERNEL(norm);
|
||||||
GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
|
GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
|
||||||
|
GGML_METAL_DECL_KERNEL(mul_mat_f16_f32_1row);
|
||||||
GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
|
GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
|
||||||
GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32);
|
GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32);
|
||||||
GGML_METAL_DECL_KERNEL(mul_mat_q8_0_f32);
|
GGML_METAL_DECL_KERNEL(mul_mat_q8_0_f32);
|
||||||
@ -219,6 +220,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|||||||
GGML_METAL_ADD_KERNEL(rms_norm);
|
GGML_METAL_ADD_KERNEL(rms_norm);
|
||||||
GGML_METAL_ADD_KERNEL(norm);
|
GGML_METAL_ADD_KERNEL(norm);
|
||||||
GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
|
GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
|
||||||
|
GGML_METAL_ADD_KERNEL(mul_mat_f16_f32_1row);
|
||||||
GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
|
GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
|
||||||
GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32);
|
GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32);
|
||||||
GGML_METAL_ADD_KERNEL(mul_mat_q8_0_f32);
|
GGML_METAL_ADD_KERNEL(mul_mat_q8_0_f32);
|
||||||
@ -284,6 +286,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
|
|||||||
GGML_METAL_DEL_KERNEL(rms_norm);
|
GGML_METAL_DEL_KERNEL(rms_norm);
|
||||||
GGML_METAL_DEL_KERNEL(norm);
|
GGML_METAL_DEL_KERNEL(norm);
|
||||||
GGML_METAL_DEL_KERNEL(mul_mat_f16_f32);
|
GGML_METAL_DEL_KERNEL(mul_mat_f16_f32);
|
||||||
|
GGML_METAL_DEL_KERNEL(mul_mat_f16_f32_1row);
|
||||||
GGML_METAL_DEL_KERNEL(mul_mat_q4_0_f32);
|
GGML_METAL_DEL_KERNEL(mul_mat_q4_0_f32);
|
||||||
GGML_METAL_DEL_KERNEL(mul_mat_q4_1_f32);
|
GGML_METAL_DEL_KERNEL(mul_mat_q4_1_f32);
|
||||||
GGML_METAL_DEL_KERNEL(mul_mat_q8_0_f32);
|
GGML_METAL_DEL_KERNEL(mul_mat_q8_0_f32);
|
||||||
@ -868,7 +871,11 @@ void ggml_metal_graph_compute(
|
|||||||
{
|
{
|
||||||
nth0 = 32;
|
nth0 = 32;
|
||||||
nth1 = 1;
|
nth1 = 1;
|
||||||
|
if (ne11 * ne12 < 4) {
|
||||||
|
[encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_1row];
|
||||||
|
} else {
|
||||||
[encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
|
[encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
|
||||||
|
}
|
||||||
} break;
|
} break;
|
||||||
case GGML_TYPE_Q4_0:
|
case GGML_TYPE_Q4_0:
|
||||||
{
|
{
|
||||||
@ -920,8 +927,8 @@ void ggml_metal_graph_compute(
|
|||||||
GGML_ASSERT(ne02 == 1);
|
GGML_ASSERT(ne02 == 1);
|
||||||
GGML_ASSERT(ne12 == 1);
|
GGML_ASSERT(ne12 == 1);
|
||||||
|
|
||||||
nth0 = 2;
|
nth0 = 4; //1;
|
||||||
nth1 = 32;
|
nth1 = 8; //32;
|
||||||
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_K_f32];
|
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_K_f32];
|
||||||
} break;
|
} break;
|
||||||
case GGML_TYPE_Q5_K:
|
case GGML_TYPE_Q5_K:
|
||||||
@ -969,9 +976,12 @@ void ggml_metal_graph_compute(
|
|||||||
[encoder setBytes:&gqa length:sizeof(gqa) atIndex:17];
|
[encoder setBytes:&gqa length:sizeof(gqa) atIndex:17];
|
||||||
|
|
||||||
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q8_0 ||
|
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q8_0 ||
|
||||||
src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) {
|
src0t == GGML_TYPE_Q2_K) {// || src0t == GGML_TYPE_Q4_K) {
|
||||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||||
}
|
}
|
||||||
|
else if (src0t == GGML_TYPE_Q4_K) {
|
||||||
|
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||||
|
}
|
||||||
else if (src0t == GGML_TYPE_Q3_K) {
|
else if (src0t == GGML_TYPE_Q3_K) {
|
||||||
#ifdef GGML_QKK_64
|
#ifdef GGML_QKK_64
|
||||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||||
@ -985,8 +995,8 @@ void ggml_metal_graph_compute(
|
|||||||
else if (src0t == GGML_TYPE_Q6_K) {
|
else if (src0t == GGML_TYPE_Q6_K) {
|
||||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||||
} else {
|
} else {
|
||||||
[encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0];
|
int64_t ny = (ne11 + 3)/4;
|
||||||
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} break;
|
} break;
|
||||||
|
218
ggml-metal.metal
218
ggml-metal.metal
@ -133,19 +133,24 @@ kernel void kernel_soft_max(
|
|||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
}
|
}
|
||||||
|
|
||||||
// broadcast
|
//// broadcast - not needed. There is a threadgroup barrier above in the last iteration of
|
||||||
if (tpitg[0] == 0) {
|
// the loop, and when that is done, buf[0] has the correct (synchronized) value
|
||||||
buf[0] = buf[0];
|
//if (tpitg[0] == 0) {
|
||||||
}
|
// buf[0] = buf[0];
|
||||||
|
//}
|
||||||
|
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
//threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
const float max = buf[0];
|
const float max = buf[0];
|
||||||
|
|
||||||
// parallel sum
|
// parallel sum
|
||||||
buf[tpitg[0]] = 0.0f;
|
buf[tpitg[0]] = 0.0f;
|
||||||
for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
|
for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
|
||||||
buf[tpitg[0]] += exp(psrc0[i00] - max);
|
const float exp_psrc0 = exp(psrc0[i00] - max);
|
||||||
|
buf[tpitg[0]] += exp_psrc0;
|
||||||
|
// Remember the result of exp here. exp is expensive, so we really do not
|
||||||
|
// wish to compute it twice.
|
||||||
|
pdst[i00] = exp_psrc0;
|
||||||
}
|
}
|
||||||
|
|
||||||
// reduce
|
// reduce
|
||||||
@ -157,17 +162,18 @@ kernel void kernel_soft_max(
|
|||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
}
|
}
|
||||||
|
|
||||||
// broadcast
|
// broadcast - not needed, see above
|
||||||
if (tpitg[0] == 0) {
|
//// broadcast
|
||||||
buf[0] = buf[0];
|
//if (tpitg[0] == 0) {
|
||||||
}
|
// buf[0] = buf[0];
|
||||||
|
//}
|
||||||
|
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
//threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
const float sum = buf[0];
|
const float sum = buf[0];
|
||||||
|
|
||||||
for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
|
for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
|
||||||
pdst[i00] = exp(psrc0[i00] - max) / sum;
|
pdst[i00] /= sum;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -214,25 +220,27 @@ kernel void kernel_norm(
|
|||||||
}
|
}
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
}
|
}
|
||||||
// broadcast
|
//// broadcast
|
||||||
if (tpitg == 0) {
|
//if (tpitg == 0) {
|
||||||
sum[0] /= ne00;
|
// sum[0] /= ne00;
|
||||||
}
|
//}
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
//threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
const float mean = sum[0];
|
const float mean = sum[0];
|
||||||
|
|
||||||
// recenter
|
// recenter and VARIANCE
|
||||||
device float * y = dst + tgpig*ne00;
|
device float * y = dst + tgpig*ne00;
|
||||||
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
|
|
||||||
y[i00] = x[i00] - mean;
|
|
||||||
}
|
|
||||||
|
|
||||||
// VARIANCE
|
|
||||||
// parallel sum
|
|
||||||
sum[tpitg] = 0.0f;
|
sum[tpitg] = 0.0f;
|
||||||
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
|
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
|
||||||
|
y[i00] = x[i00] - mean;
|
||||||
sum[tpitg] += y[i00] * y[i00];
|
sum[tpitg] += y[i00] * y[i00];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//// VARIANCE
|
||||||
|
//// parallel sum
|
||||||
|
//sum[tpitg] = 0.0f;
|
||||||
|
//for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
|
||||||
|
// sum[tpitg] += y[i00] * y[i00];
|
||||||
|
//}
|
||||||
// reduce
|
// reduce
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
for (uint i = ntg/2; i > 0; i /= 2) {
|
for (uint i = ntg/2; i > 0; i /= 2) {
|
||||||
@ -241,11 +249,11 @@ kernel void kernel_norm(
|
|||||||
}
|
}
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
}
|
}
|
||||||
// broadcast
|
//// broadcast
|
||||||
if (tpitg == 0) {
|
//if (tpitg == 0) {
|
||||||
sum[0] /= ne00;
|
// sum[0] /= ne00;
|
||||||
}
|
//}
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
//threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
const float variance = sum[0];
|
const float variance = sum[0];
|
||||||
|
|
||||||
const float scale = 1.0f/sqrt(variance + eps);
|
const float scale = 1.0f/sqrt(variance + eps);
|
||||||
@ -435,6 +443,8 @@ kernel void kernel_mul_mat_q4_1_f32(
|
|||||||
mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
|
mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#define NB_Q8_0 8
|
||||||
|
|
||||||
kernel void kernel_mul_mat_q8_0_f32(
|
kernel void kernel_mul_mat_q8_0_f32(
|
||||||
device const void * src0,
|
device const void * src0,
|
||||||
device const float * src1,
|
device const float * src1,
|
||||||
@ -463,30 +473,30 @@ kernel void kernel_mul_mat_q8_0_f32(
|
|||||||
device const block_q8_0 * x = (device const block_q8_0 *) src0 + offset0;
|
device const block_q8_0 * x = (device const block_q8_0 *) src0 + offset0;
|
||||||
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
||||||
|
|
||||||
float yl[16];
|
float yl[NB_Q8_0];
|
||||||
float sumf[nr]={0.f};
|
float sumf[nr]={0.f};
|
||||||
|
|
||||||
const int ix = tiisg/2;
|
const int ix = tiisg/4;
|
||||||
const int il = tiisg%2;
|
const int il = tiisg%4;
|
||||||
|
|
||||||
device const float * yb = y + ix * QK8_0 + 16*il;
|
device const float * yb = y + ix * QK8_0 + NB_Q8_0*il;
|
||||||
|
|
||||||
// each thread in a SIMD group deals with half a block.
|
// each thread in a SIMD group deals with NB_Q8_0 quants at a time
|
||||||
for (int ib = ix; ib < nb; ib += nw/2) {
|
for (int ib = ix; ib < nb; ib += nw/4) {
|
||||||
for (int i = 0; i < 16; ++i) {
|
for (int i = 0; i < NB_Q8_0; ++i) {
|
||||||
yl[i] = yb[i];
|
yl[i] = yb[i];
|
||||||
}
|
}
|
||||||
|
|
||||||
for (int row = 0; row < nr; row++) {
|
for (int row = 0; row < nr; row++) {
|
||||||
device const int8_t * qs = x[ib+row*nb].qs + 16*il;
|
device const int8_t * qs = x[ib+row*nb].qs + NB_Q8_0*il;
|
||||||
float sumq = 0.f;
|
float sumq = 0.f;
|
||||||
for (int iq = 0; iq < 16; ++iq) {
|
for (int iq = 0; iq < NB_Q8_0; ++iq) {
|
||||||
sumq += qs[iq] * yl[iq];
|
sumq += qs[iq] * yl[iq];
|
||||||
}
|
}
|
||||||
sumf[row] += sumq*x[ib+row*nb].d;
|
sumf[row] += sumq*x[ib+row*nb].d;
|
||||||
}
|
}
|
||||||
|
|
||||||
yb += QK8_0 * 16;
|
yb += NB_Q8_0 * nw;
|
||||||
}
|
}
|
||||||
|
|
||||||
for (int row = 0; row < nr; ++row) {
|
for (int row = 0; row < nr; ++row) {
|
||||||
@ -497,6 +507,60 @@ kernel void kernel_mul_mat_q8_0_f32(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// F16 x F32 dot-product kernel that produces one dst element per threadgroup.
//
// Threadgroup (r0, r1, im) reads row r0 of src0 as half values and row r1 of
// src1 as float values, multiplies them element-wise over ne00 elements, and
// stores the sum at dst[im*ne1*ne0 + r1*ne0 + r0].
//
// Parameters used by the body:
//   src0, src1       - raw byte pointers; rows are located with byte strides
//   dst              - float output
//   ne00             - number of elements in the reduced dimension
//   ne02, ne12       - im/(ne12/ne02) selects the src0 slice used for batch im
//   nb01, nb02       - byte strides of src0 (row, slice)
//   nb11, nb12       - byte strides of src1 (row, slice)
//   ne0, ne1         - dst extents used to compute the output index
//   tgpig            - threadgroup position: x = r0, y = r1, z = im
//   tiisg            - lane index inside the simdgroup
// The remaining parameters (ne01, nb00, ne10, ne11, nb10) are not read here;
// they are kept so the argument layout stays unchanged.
//
// NOTE(review): the loops advance by 32 and only lane 0 writes the result, so
// this assumes each threadgroup is a single 32-wide simdgroup - confirm
// against the host-side dispatch.
// NOTE(review): r1 is taken directly from tgpig.y, so the host must launch one
// threadgroup per src1 row for this kernel - confirm against the dispatch.
kernel void kernel_mul_mat_f16_f32_1row(
        device const  char * src0,
        device const  char * src1,
        device       float * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01,
        constant   int64_t & ne02,
        constant  uint64_t & nb00,
        constant  uint64_t & nb01,
        constant  uint64_t & nb02,
        constant   int64_t & ne10,
        constant   int64_t & ne11,
        constant   int64_t & ne12,
        constant  uint64_t & nb10,
        constant  uint64_t & nb11,
        constant  uint64_t & nb12,
        constant   int64_t & ne0,
        constant   int64_t & ne1,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint  tiisg[[thread_index_in_simdgroup]]) {

    const int64_t r0 = tgpig.x;
    const int64_t r1 = tgpig.y;
    const int64_t im = tgpig.z;

    device const half  * x = (device const half  *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
    device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);

    float sumf = 0;
    if (ne00 < 128) {
        // Short rows: plain scalar loop, each lane takes every 32nd element.
        for (int i = tiisg; i < ne00; i += 32) {
            sumf += (float) x[i] * (float) y[i];
        }
        float all_sum = simd_sum(sumf);
        if (tiisg == 0) {
            dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
        }
    } else {
        // Long rows: process four elements per step through vector loads.
        // NOTE(review): the half4/float4 casts presumably require suitably
        // aligned row starts - verify against the tensor layout.
        device const half4  * x4 = (device const half4  *) x;
        device const float4 * y4 = (device const float4 *) y;
        for (int i = tiisg; i < ne00/4; i += 32) {
            for (int k = 0; k < 4; ++k) sumf += (float)x4[i][k] * y4[i][k];
        }
        float all_sum = simd_sum(sumf);
        if (tiisg == 0) {
            // Add the ne00 % 4 leftover elements to the already-reduced total.
            // They must go into all_sum (not sumf): the simdgroup reduction
            // has already been taken, so anything added to sumf here would
            // never reach dst.
            for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i];
            dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
        }
    }

}
|
||||||
|
|
||||||
|
#define N_F16_F32 4
|
||||||
|
|
||||||
kernel void kernel_mul_mat_f16_f32(
|
kernel void kernel_mul_mat_f16_f32(
|
||||||
device const char * src0,
|
device const char * src0,
|
||||||
device const char * src1,
|
device const char * src1,
|
||||||
@ -515,55 +579,58 @@ kernel void kernel_mul_mat_f16_f32(
|
|||||||
constant uint64_t & nb12,
|
constant uint64_t & nb12,
|
||||||
constant int64_t & ne0,
|
constant int64_t & ne0,
|
||||||
constant int64_t & ne1,
|
constant int64_t & ne1,
|
||||||
threadgroup float * sum [[threadgroup(0)]],
|
|
||||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
uint3 tpig[[thread_position_in_grid]],
|
uint tiisg[[thread_index_in_simdgroup]]) {
|
||||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
|
||||||
uint3 tptg[[threads_per_threadgroup]]) {
|
|
||||||
|
|
||||||
const int64_t r0 = tgpig.x;
|
const int64_t r0 = tgpig.x;
|
||||||
const int64_t r1 = tgpig.y;
|
const int64_t rb = N_F16_F32*tgpig.y;
|
||||||
const int64_t im = tgpig.z;
|
const int64_t im = tgpig.z;
|
||||||
|
|
||||||
device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
|
device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
|
||||||
|
|
||||||
|
if (ne00 < 128) {
|
||||||
|
for (int row = 0; row < N_F16_F32; ++row) {
|
||||||
|
int r1 = rb + row;
|
||||||
|
if (r1 >= ne11) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
|
device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
|
||||||
|
|
||||||
uint ith = tpitg.x;
|
float sumf = 0;
|
||||||
uint nth = tptg.x;
|
for (int i = tiisg; i < ne00; i += 32) {
|
||||||
|
sumf += (float) x[i] * (float) y[i];
|
||||||
sum[ith] = 0.0f;
|
|
||||||
|
|
||||||
for (int i = ith; i < ne00; i += nth) {
|
|
||||||
sum[ith] += (float) x[i] * (float) y[i];
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// accumulate the sum from all threads in the threadgroup
|
float all_sum = simd_sum(sumf);
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
if (tiisg == 0) {
|
||||||
if (ith%4 == 0) {
|
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
|
||||||
for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
device const half4 * x4 = (device const half4 *)x;
|
||||||
|
for (int row = 0; row < N_F16_F32; ++row) {
|
||||||
|
int r1 = rb + row;
|
||||||
|
if (r1 >= ne11) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
|
||||||
|
device const float4 * y4 = (device const float4 *) y;
|
||||||
|
|
||||||
|
float sumf = 0;
|
||||||
|
for (int i = tiisg; i < ne00/4; i += 32) {
|
||||||
|
for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k];
|
||||||
|
}
|
||||||
|
|
||||||
|
float all_sum = simd_sum(sumf);
|
||||||
|
if (tiisg == 0) {
|
||||||
|
for (int i = 4*(ne00/4); i < ne00; ++i) sumf += (float) x[i] * y[i];
|
||||||
|
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
|
||||||
}
|
}
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
||||||
if (ith%16 == 0) {
|
|
||||||
for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
|
|
||||||
}
|
}
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
||||||
if (ith == 0) {
|
|
||||||
for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
|
|
||||||
dst[im*ne1*ne0 + r1*ne0 + r0] = sum[0];
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Original implementation. Left behind commented out for now
|
|
||||||
//threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
||||||
//for (uint i = tptg.x/2; i > 0; i /= 2) {
|
|
||||||
// if (tpitg.x < i) {
|
|
||||||
// sum[tpitg.x] += sum[tpitg.x + i];
|
|
||||||
// }
|
|
||||||
// threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
||||||
//}
|
|
||||||
//
|
|
||||||
//if (tpitg.x == 0) {
|
|
||||||
// dst[im*ne1*ne0 + r1*ne0 + r0] = sum[0];
|
|
||||||
//}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
kernel void kernel_alibi_f32(
|
kernel void kernel_alibi_f32(
|
||||||
@ -1262,7 +1329,8 @@ kernel void kernel_mul_mat_q4_K_f32(
|
|||||||
const int r0 = tgpig.x;
|
const int r0 = tgpig.x;
|
||||||
const int r1 = tgpig.y;
|
const int r1 = tgpig.y;
|
||||||
const int r2 = tgpig.z;
|
const int r2 = tgpig.z;
|
||||||
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
|
//const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
|
||||||
|
const int first_row = r0 * N_DST;
|
||||||
const int ib_row = first_row * nb;
|
const int ib_row = first_row * nb;
|
||||||
const uint offset0 = r2/gqa*(nb*ne0);
|
const uint offset0 = r2/gqa*(nb*ne0);
|
||||||
device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0;
|
device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0;
|
||||||
|
Loading…
Reference in New Issue
Block a user