mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-26 03:14:35 +00:00
parent
73a12a6344
commit
b7f2aa9e51
@ -536,14 +536,27 @@ kernel void kernel_mul_mat_f16_f32_1row(
|
|||||||
device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
|
device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
|
||||||
|
|
||||||
float sumf = 0;
|
float sumf = 0;
|
||||||
|
if (ne00 < 128) {
|
||||||
for (int i = tiisg; i < ne00; i += 32) {
|
for (int i = tiisg; i < ne00; i += 32) {
|
||||||
sumf += (float) x[i] * (float) y[i];
|
sumf += (float) x[i] * (float) y[i];
|
||||||
}
|
}
|
||||||
|
|
||||||
float all_sum = simd_sum(sumf);
|
float all_sum = simd_sum(sumf);
|
||||||
if (tiisg == 0) {
|
if (tiisg == 0) {
|
||||||
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
|
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
device const half4 * x4 = (device const half4 *) x;
|
||||||
|
device const float4 * y4 = (device const float4 *) y;
|
||||||
|
for (int i = tiisg; i < ne00/4; i += 32) {
|
||||||
|
for (int k = 0; k < 4; ++k) sumf += (float)x4[i][k] * y4[i][k];
|
||||||
|
}
|
||||||
|
float all_sum = simd_sum(sumf);
|
||||||
|
if (tiisg == 0) {
|
||||||
|
for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i];
|
||||||
|
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#define N_F16_F32 4
|
#define N_F16_F32 4
|
||||||
@ -570,11 +583,12 @@ kernel void kernel_mul_mat_f16_f32(
|
|||||||
uint tiisg[[thread_index_in_simdgroup]]) {
|
uint tiisg[[thread_index_in_simdgroup]]) {
|
||||||
|
|
||||||
const int64_t r0 = tgpig.x;
|
const int64_t r0 = tgpig.x;
|
||||||
const int64_t rb = N_F16_F32*tgpig.y;
|
const int64_t rb = tgpig.y*N_F16_F32;
|
||||||
const int64_t im = tgpig.z;
|
const int64_t im = tgpig.z;
|
||||||
|
|
||||||
device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
|
device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
|
||||||
|
|
||||||
|
if (ne00 < 128) {
|
||||||
for (int row = 0; row < N_F16_F32; ++row) {
|
for (int row = 0; row < N_F16_F32; ++row) {
|
||||||
int r1 = rb + row;
|
int r1 = rb + row;
|
||||||
if (r1 >= ne11) {
|
if (r1 >= ne11) {
|
||||||
@ -593,6 +607,30 @@ kernel void kernel_mul_mat_f16_f32(
|
|||||||
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
|
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
device const half4 * x4 = (device const half4 *)x;
|
||||||
|
for (int row = 0; row < N_F16_F32; ++row) {
|
||||||
|
int r1 = rb + row;
|
||||||
|
if (r1 >= ne11) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
|
||||||
|
device const float4 * y4 = (device const float4 *) y;
|
||||||
|
|
||||||
|
float sumf = 0;
|
||||||
|
for (int i = tiisg; i < ne00/4; i += 32) {
|
||||||
|
for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k];
|
||||||
|
}
|
||||||
|
|
||||||
|
float all_sum = simd_sum(sumf);
|
||||||
|
if (tiisg == 0) {
|
||||||
|
for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i];
|
||||||
|
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
kernel void kernel_alibi_f32(
|
kernel void kernel_alibi_f32(
|
||||||
|
Loading…
Reference in New Issue
Block a user