metal : opts

This commit is contained in:
Georgi Gerganov 2024-02-07 23:12:22 +02:00
parent 92a0c17474
commit e68e32548f
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
2 changed files with 62 additions and 50 deletions

View File

@ -1343,7 +1343,7 @@ static bool ggml_metal_graph_compute(
const int nsg = 8;
const int nsg0 = 1;
const int nsh0 = 8;
const int nsh0 = 16;
const int nsg1 = 1;
const int nsh1 = 64;

View File

@ -4785,7 +4785,7 @@ void kernel_mul_mm_impl(
}
#define NSG0 1
#define NSH0 8
#define NSH0 16
#define NSG1 1
#define NSH1 64
@ -4815,31 +4815,32 @@ void kernel_mul_mm2_impl(
uint sgitg[[simdgroup_index_in_threadgroup]]) {
const uint nsg = ntg.y; // number of simdgroups
const int64_t im = tgpig[2];
const int64_t i11 = tgpig[1]*(8*NSG1);
const int64_t i01 = tgpig[0]*(8*NSG0*nsg) + sgitg*(8*NSG0);
const int im = tgpig[2];
const int i11 = tgpig[1]*(8*NSG1);
const int i01 = tgpig[0]*(8*NSG0*nsg) + sgitg*(8*NSG0);
const int64_t i12 = im%ne12;
const int64_t i13 = im/ne12;
const int i12 = im%ne12;
const int i13 = im/ne12;
const int64_t ne01 = ne0;
const int64_t ne11 = ne1;
const int ne01 = ne0;
const int ne11 = ne1;
const int64_t NW = N_SIMDWIDTH;
const int NW = N_SIMDWIDTH;
const int64_t SH0 = (8*NSG0)*(8*NSH0); // shread memory per threadgroup for src0 data in (half)
const int64_t SH04 = SH0/4; // shread memory per threadgroup for src0 data in (half4)
const int SH0 = (8*NSG0)*(8*NSH0); // shread memory per threadgroup for src0 data in (half)
const int SH04 = SH0/4; // shread memory per threadgroup for src0 data in (half4)
const int64_t SH1 = (8*NSG1)*(8*NSH1); // shread memory for src1 data in (float)
const int64_t SH14 = SH1/4; // shread memory for src1 data in (float4)
const int SH1 = (8*NSG1)*(8*NSH1); // shread memory for src1 data in (float)
const int SH14 = SH1/4; // shread memory for src1 data in (float4)
const int64_t T1 = 8*NSH1; // row of src1 in shared memory in (float)
const int64_t T14 = T1/4; // row of src1 in shared memory in (float4)
const int T1 = 8*NSH1; // row of src1 in shared memory in (float)
const int T14 = T1/4; // row of src1 in shared memory in (float4)
threadgroup half * shared = (threadgroup half *) shared_u8;
threadgroup half * s0 = (threadgroup half *)(shared + sgitg*SH0);
threadgroup half4 * s04 = (threadgroup half4 *)(shared + sgitg*SH0);
threadgroup half4x4 * s016 = (threadgroup half4x4 *)(shared + sgitg*SH0);
threadgroup float * s1 = (threadgroup float *)(shared + nsg*SH0);
threadgroup float4 * s14 = (threadgroup float4 *)(shared + nsg*SH0);
@ -4850,12 +4851,12 @@ void kernel_mul_mm2_impl(
simdgroup_float8x8 mr[NSG0][NSG1];
// zero out shared memory SH0 for src0
for (int64_t i = tiisg; i < SH04; i += NW) {
for (int i = tiisg; i < SH04; i += NW) {
s04[i] = 0.0h;
}
// zero out shared memory SH1 for src1
for (int64_t i = tiitg; i < SH14; i += nsg*NW) {
for (int i = tiitg; i < SH14; i += nsg*NW) {
s14[i] = 0.0f;
}
@ -4868,24 +4869,27 @@ void kernel_mul_mm2_impl(
}
}
for (int64_t i00 = 0; i00 < ne00; i00 += 8*NSH1) {
for (int i00 = 0; i00 < ne00; i00 += 8*NSH1) {
// load NSG1*NSH1 8x8 blocks of src1 to threadgroup memory
{
threadgroup_barrier(mem_flags::mem_threadgroup);
const int64_t nload = min(8ll*NSG1, ne11 - i11) * (8*NSH1);
const int nload = MIN(8*NSG1, ne11 - i11) * (8*NSH1);
for (int64_t i = tiitg; i < nload; i += nsg*NW) {
const int64_t ic = i%(8*NSH1);
const int64_t ir = i/(8*NSH1);
const size_t offs0 = im*nb12;
// TODO: use float4
device const float * p1 = (device const float *)(src1 + im*nb12 + (i11 + ir)*nb11 + (i00 + ic)*nb10);
for (int i = 4*tiitg; i < nload; i += 4*nsg*NW) {
const int ic = i%(8*NSH1);
const int ir = i/(8*NSH1);
if (i00 + ic < ne00) {
s1[8*NSH1*ir + ic] = *p1;
device const float4 * p1 = (device const float4 *)(src1 + offs0 + (i11 + ir)*nb11 + (i00 + ic)*nb10);
if (i00 + ic + 4 <= ne00) {
s14[(8*NSH1*ir + ic)/4] = *p1;
} else {
s1[8*NSH1*ir + ic] = 0.0f;
for (int k = 0; i00 + ic + k < ne00; k++){
s1[8*NSH1*ir + ic + k] = (*p1)[k];
}
}
}
@ -4895,28 +4899,36 @@ void kernel_mul_mm2_impl(
for (int b0 = 0; b0 < NSH1/NSH0; ++b0) {
// load NSG0*NSH0 8x8 blocks of src0 to threadgroup memory
{
const int64_t nload = min(8ll*NSG0, ne01 - i01) * (8*NSH0);
const int nload = MIN(8*NSG0, ne01 - i01) * (8*NSH0);
half4x4 tmp0;
for (int64_t i = 16*tiisg; i < nload; i += 16*NW) {
const int64_t ic = i%(8*NSH0);
const int64_t ir = i/(8*NSH0);
const size_t offs0 = (i13/r3)*(nb02*ne02) + (i12/r2)*nb02;
const int64_t icc = i00 + 8*b0*NSH0 + ic;
for (int i = 16*tiisg; i < nload; i += 16*NW) {
const int ic = i%(8*NSH0);
const int ir = i/(8*NSH0);
const int64_t ib = (icc/(16*nl));
const int64_t il = (icc%(16*nl))/16;
const int icc = i00 + 8*b0*NSH0 + ic;
device const block_q * p0 = (device const block_q *)(src0 + (i13/r3)*(nb02*ne02) + (i12/r2)*nb02 + (i01 + ir)*nb01) + ib;
const int ib = (icc/(16*nl));
const int il = (icc%(16*nl))/16;
device const block_q * p0 = (device const block_q *)(src0 + offs0 + (i01 + ir)*nb01) + ib;
dequantize_func(p0, il, tmp0);
if (icc + 16 <= ne00) {
s016[(8*NSH0*ir + ic)/16] = tmp0;
} else {
for (int k = 0; k < 4; k++){
if (icc + 4*k < ne00) {
if (icc + 4*k <= ne00) {
s04[(8*NSH0*ir + ic)/4 + k] = tmp0[k];
} else {
s04[(8*NSH0*ir + ic)/4 + k] = 0.0h;
for (int p = 0; icc + 4*k + p < ne00; p++) {
s0[8*NSH0*ir + ic + 4*k + p] = tmp0[k][p];
}
}
}
}
}
@ -4958,12 +4970,12 @@ void kernel_mul_mm2_impl(
device float * pdst = dst + im*ne1*ne0;
for (int is = 0; is < NSG1; is++) {
const int64_t i1 = i11 + is*8;
const int64_t nstore = min(8ll*NSG1, ne1 - i1) * (8*NSG0);
const int i1 = i11 + is*8;
const int nstore = MIN(8*NSG1, ne1 - i1) * (8*NSG0);
for (int64_t i = tiisg; i < nstore; i += NW) {
const int64_t ic = i%(8*NSG0);
const int64_t ir = i/(8*NSG0);
for (int i = tiisg; i < nstore; i += NW) {
const int ic = i%(8*NSG0);
const int ir = i/(8*NSG0);
if (i1 + ir < ne1 && i01 + ic < ne0) {
pdst[(i1 + ir)*ne0 + (i01 + ic)] = r0[(8*is)*(8*NSG0) + 8*NSG0*ir + ic];