mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-31 22:04:35 +00:00
metal : opts
This commit is contained in:
parent
92a0c17474
commit
e68e32548f
@ -1343,7 +1343,7 @@ static bool ggml_metal_graph_compute(
|
|||||||
const int nsg = 8;
|
const int nsg = 8;
|
||||||
|
|
||||||
const int nsg0 = 1;
|
const int nsg0 = 1;
|
||||||
const int nsh0 = 8;
|
const int nsh0 = 16;
|
||||||
const int nsg1 = 1;
|
const int nsg1 = 1;
|
||||||
const int nsh1 = 64;
|
const int nsh1 = 64;
|
||||||
|
|
||||||
|
110
ggml-metal.metal
110
ggml-metal.metal
@ -4785,7 +4785,7 @@ void kernel_mul_mm_impl(
|
|||||||
}
|
}
|
||||||
|
|
||||||
#define NSG0 1
|
#define NSG0 1
|
||||||
#define NSH0 8
|
#define NSH0 16
|
||||||
#define NSG1 1
|
#define NSG1 1
|
||||||
#define NSH1 64
|
#define NSH1 64
|
||||||
|
|
||||||
@ -4815,33 +4815,34 @@ void kernel_mul_mm2_impl(
|
|||||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||||
const uint nsg = ntg.y; // number of simdgroups
|
const uint nsg = ntg.y; // number of simdgroups
|
||||||
|
|
||||||
const int64_t im = tgpig[2];
|
const int im = tgpig[2];
|
||||||
const int64_t i11 = tgpig[1]*(8*NSG1);
|
const int i11 = tgpig[1]*(8*NSG1);
|
||||||
const int64_t i01 = tgpig[0]*(8*NSG0*nsg) + sgitg*(8*NSG0);
|
const int i01 = tgpig[0]*(8*NSG0*nsg) + sgitg*(8*NSG0);
|
||||||
|
|
||||||
const int64_t i12 = im%ne12;
|
const int i12 = im%ne12;
|
||||||
const int64_t i13 = im/ne12;
|
const int i13 = im/ne12;
|
||||||
|
|
||||||
const int64_t ne01 = ne0;
|
const int ne01 = ne0;
|
||||||
const int64_t ne11 = ne1;
|
const int ne11 = ne1;
|
||||||
|
|
||||||
const int64_t NW = N_SIMDWIDTH;
|
const int NW = N_SIMDWIDTH;
|
||||||
|
|
||||||
const int64_t SH0 = (8*NSG0)*(8*NSH0); // shread memory per threadgroup for src0 data in (half)
|
const int SH0 = (8*NSG0)*(8*NSH0); // shread memory per threadgroup for src0 data in (half)
|
||||||
const int64_t SH04 = SH0/4; // shread memory per threadgroup for src0 data in (half4)
|
const int SH04 = SH0/4; // shread memory per threadgroup for src0 data in (half4)
|
||||||
|
|
||||||
const int64_t SH1 = (8*NSG1)*(8*NSH1); // shread memory for src1 data in (float)
|
const int SH1 = (8*NSG1)*(8*NSH1); // shread memory for src1 data in (float)
|
||||||
const int64_t SH14 = SH1/4; // shread memory for src1 data in (float4)
|
const int SH14 = SH1/4; // shread memory for src1 data in (float4)
|
||||||
|
|
||||||
const int64_t T1 = 8*NSH1; // row of src1 in shared memory in (float)
|
const int T1 = 8*NSH1; // row of src1 in shared memory in (float)
|
||||||
const int64_t T14 = T1/4; // row of src1 in shared memory in (float4)
|
const int T14 = T1/4; // row of src1 in shared memory in (float4)
|
||||||
|
|
||||||
threadgroup half * shared = (threadgroup half *) shared_u8;
|
threadgroup half * shared = (threadgroup half *) shared_u8;
|
||||||
|
|
||||||
threadgroup half * s0 = (threadgroup half *)(shared + sgitg*SH0);
|
threadgroup half * s0 = (threadgroup half *)(shared + sgitg*SH0);
|
||||||
threadgroup half4 * s04 = (threadgroup half4 *)(shared + sgitg*SH0);
|
threadgroup half4 * s04 = (threadgroup half4 *)(shared + sgitg*SH0);
|
||||||
threadgroup float * s1 = (threadgroup float *)(shared + nsg*SH0);
|
threadgroup half4x4 * s016 = (threadgroup half4x4 *)(shared + sgitg*SH0);
|
||||||
threadgroup float4 * s14 = (threadgroup float4 *)(shared + nsg*SH0);
|
threadgroup float * s1 = (threadgroup float *)(shared + nsg*SH0);
|
||||||
|
threadgroup float4 * s14 = (threadgroup float4 *)(shared + nsg*SH0);
|
||||||
|
|
||||||
threadgroup float * r0 = (threadgroup float *)(shared + 2*sgitg*(8*NSG0)*(8*NSG1));
|
threadgroup float * r0 = (threadgroup float *)(shared + 2*sgitg*(8*NSG0)*(8*NSG1));
|
||||||
|
|
||||||
@ -4850,12 +4851,12 @@ void kernel_mul_mm2_impl(
|
|||||||
simdgroup_float8x8 mr[NSG0][NSG1];
|
simdgroup_float8x8 mr[NSG0][NSG1];
|
||||||
|
|
||||||
// zero out shared memory SH0 for src0
|
// zero out shared memory SH0 for src0
|
||||||
for (int64_t i = tiisg; i < SH04; i += NW) {
|
for (int i = tiisg; i < SH04; i += NW) {
|
||||||
s04[i] = 0.0h;
|
s04[i] = 0.0h;
|
||||||
}
|
}
|
||||||
|
|
||||||
// zero out shared memory SH1 for src1
|
// zero out shared memory SH1 for src1
|
||||||
for (int64_t i = tiitg; i < SH14; i += nsg*NW) {
|
for (int i = tiitg; i < SH14; i += nsg*NW) {
|
||||||
s14[i] = 0.0f;
|
s14[i] = 0.0f;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -4868,24 +4869,27 @@ void kernel_mul_mm2_impl(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for (int64_t i00 = 0; i00 < ne00; i00 += 8*NSH1) {
|
for (int i00 = 0; i00 < ne00; i00 += 8*NSH1) {
|
||||||
// load NSG1*NSH1 8x8 blocks of src1 to threadgroup memory
|
// load NSG1*NSH1 8x8 blocks of src1 to threadgroup memory
|
||||||
{
|
{
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
const int64_t nload = min(8ll*NSG1, ne11 - i11) * (8*NSH1);
|
const int nload = MIN(8*NSG1, ne11 - i11) * (8*NSH1);
|
||||||
|
|
||||||
for (int64_t i = tiitg; i < nload; i += nsg*NW) {
|
const size_t offs0 = im*nb12;
|
||||||
const int64_t ic = i%(8*NSH1);
|
|
||||||
const int64_t ir = i/(8*NSH1);
|
|
||||||
|
|
||||||
// TODO: use float4
|
for (int i = 4*tiitg; i < nload; i += 4*nsg*NW) {
|
||||||
device const float * p1 = (device const float *)(src1 + im*nb12 + (i11 + ir)*nb11 + (i00 + ic)*nb10);
|
const int ic = i%(8*NSH1);
|
||||||
|
const int ir = i/(8*NSH1);
|
||||||
|
|
||||||
if (i00 + ic < ne00) {
|
device const float4 * p1 = (device const float4 *)(src1 + offs0 + (i11 + ir)*nb11 + (i00 + ic)*nb10);
|
||||||
s1[8*NSH1*ir + ic] = *p1;
|
|
||||||
|
if (i00 + ic + 4 <= ne00) {
|
||||||
|
s14[(8*NSH1*ir + ic)/4] = *p1;
|
||||||
} else {
|
} else {
|
||||||
s1[8*NSH1*ir + ic] = 0.0f;
|
for (int k = 0; i00 + ic + k < ne00; k++){
|
||||||
|
s1[8*NSH1*ir + ic + k] = (*p1)[k];
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -4895,28 +4899,36 @@ void kernel_mul_mm2_impl(
|
|||||||
for (int b0 = 0; b0 < NSH1/NSH0; ++b0) {
|
for (int b0 = 0; b0 < NSH1/NSH0; ++b0) {
|
||||||
// load NSG0*NSH0 8x8 blocks of src0 to threadgroup memory
|
// load NSG0*NSH0 8x8 blocks of src0 to threadgroup memory
|
||||||
{
|
{
|
||||||
const int64_t nload = min(8ll*NSG0, ne01 - i01) * (8*NSH0);
|
const int nload = MIN(8*NSG0, ne01 - i01) * (8*NSH0);
|
||||||
|
|
||||||
half4x4 tmp0;
|
half4x4 tmp0;
|
||||||
|
|
||||||
for (int64_t i = 16*tiisg; i < nload; i += 16*NW) {
|
const size_t offs0 = (i13/r3)*(nb02*ne02) + (i12/r2)*nb02;
|
||||||
const int64_t ic = i%(8*NSH0);
|
|
||||||
const int64_t ir = i/(8*NSH0);
|
|
||||||
|
|
||||||
const int64_t icc = i00 + 8*b0*NSH0 + ic;
|
for (int i = 16*tiisg; i < nload; i += 16*NW) {
|
||||||
|
const int ic = i%(8*NSH0);
|
||||||
|
const int ir = i/(8*NSH0);
|
||||||
|
|
||||||
const int64_t ib = (icc/(16*nl));
|
const int icc = i00 + 8*b0*NSH0 + ic;
|
||||||
const int64_t il = (icc%(16*nl))/16;
|
|
||||||
|
|
||||||
device const block_q * p0 = (device const block_q *)(src0 + (i13/r3)*(nb02*ne02) + (i12/r2)*nb02 + (i01 + ir)*nb01) + ib;
|
const int ib = (icc/(16*nl));
|
||||||
|
const int il = (icc%(16*nl))/16;
|
||||||
|
|
||||||
|
device const block_q * p0 = (device const block_q *)(src0 + offs0 + (i01 + ir)*nb01) + ib;
|
||||||
|
|
||||||
dequantize_func(p0, il, tmp0);
|
dequantize_func(p0, il, tmp0);
|
||||||
|
|
||||||
for (int k = 0; k < 4; k++){
|
if (icc + 16 <= ne00) {
|
||||||
if (icc + 4*k < ne00) {
|
s016[(8*NSH0*ir + ic)/16] = tmp0;
|
||||||
s04[(8*NSH0*ir + ic)/4 + k] = tmp0[k];
|
} else {
|
||||||
} else {
|
for (int k = 0; k < 4; k++){
|
||||||
s04[(8*NSH0*ir + ic)/4 + k] = 0.0h;
|
if (icc + 4*k <= ne00) {
|
||||||
|
s04[(8*NSH0*ir + ic)/4 + k] = tmp0[k];
|
||||||
|
} else {
|
||||||
|
for (int p = 0; icc + 4*k + p < ne00; p++) {
|
||||||
|
s0[8*NSH0*ir + ic + 4*k + p] = tmp0[k][p];
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -4958,12 +4970,12 @@ void kernel_mul_mm2_impl(
|
|||||||
device float * pdst = dst + im*ne1*ne0;
|
device float * pdst = dst + im*ne1*ne0;
|
||||||
|
|
||||||
for (int is = 0; is < NSG1; is++) {
|
for (int is = 0; is < NSG1; is++) {
|
||||||
const int64_t i1 = i11 + is*8;
|
const int i1 = i11 + is*8;
|
||||||
const int64_t nstore = min(8ll*NSG1, ne1 - i1) * (8*NSG0);
|
const int nstore = MIN(8*NSG1, ne1 - i1) * (8*NSG0);
|
||||||
|
|
||||||
for (int64_t i = tiisg; i < nstore; i += NW) {
|
for (int i = tiisg; i < nstore; i += NW) {
|
||||||
const int64_t ic = i%(8*NSG0);
|
const int ic = i%(8*NSG0);
|
||||||
const int64_t ir = i/(8*NSG0);
|
const int ir = i/(8*NSG0);
|
||||||
|
|
||||||
if (i1 + ir < ne1 && i01 + ic < ne0) {
|
if (i1 + ir < ne1 && i01 + ic < ne0) {
|
||||||
pdst[(i1 + ir)*ne0 + (i01 + ic)] = r0[(8*is)*(8*NSG0) + 8*NSG0*ir + ic];
|
pdst[(i1 + ir)*ne0 + (i01 + ic)] = r0[(8*is)*(8*NSG0) + 8*NSG0*ir + ic];
|
||||||
|
Loading…
Reference in New Issue
Block a user