metal : opts

This commit is contained in:
Georgi Gerganov 2024-02-07 23:12:22 +02:00
parent 92a0c17474
commit e68e32548f
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
2 changed files with 62 additions and 50 deletions

View File

@ -1343,7 +1343,7 @@ static bool ggml_metal_graph_compute(
const int nsg = 8; const int nsg = 8;
const int nsg0 = 1; const int nsg0 = 1;
const int nsh0 = 8; const int nsh0 = 16;
const int nsg1 = 1; const int nsg1 = 1;
const int nsh1 = 64; const int nsh1 = 64;

View File

@ -4785,7 +4785,7 @@ void kernel_mul_mm_impl(
} }
#define NSG0 1 #define NSG0 1
#define NSH0 8 #define NSH0 16
#define NSG1 1 #define NSG1 1
#define NSH1 64 #define NSH1 64
@ -4815,33 +4815,34 @@ void kernel_mul_mm2_impl(
uint sgitg[[simdgroup_index_in_threadgroup]]) { uint sgitg[[simdgroup_index_in_threadgroup]]) {
const uint nsg = ntg.y; // number of simdgroups const uint nsg = ntg.y; // number of simdgroups
const int64_t im = tgpig[2]; const int im = tgpig[2];
const int64_t i11 = tgpig[1]*(8*NSG1); const int i11 = tgpig[1]*(8*NSG1);
const int64_t i01 = tgpig[0]*(8*NSG0*nsg) + sgitg*(8*NSG0); const int i01 = tgpig[0]*(8*NSG0*nsg) + sgitg*(8*NSG0);
const int64_t i12 = im%ne12; const int i12 = im%ne12;
const int64_t i13 = im/ne12; const int i13 = im/ne12;
const int64_t ne01 = ne0; const int ne01 = ne0;
const int64_t ne11 = ne1; const int ne11 = ne1;
const int64_t NW = N_SIMDWIDTH; const int NW = N_SIMDWIDTH;
const int64_t SH0 = (8*NSG0)*(8*NSH0); // shread memory per threadgroup for src0 data in (half) const int SH0 = (8*NSG0)*(8*NSH0); // shread memory per threadgroup for src0 data in (half)
const int64_t SH04 = SH0/4; // shread memory per threadgroup for src0 data in (half4) const int SH04 = SH0/4; // shread memory per threadgroup for src0 data in (half4)
const int64_t SH1 = (8*NSG1)*(8*NSH1); // shread memory for src1 data in (float) const int SH1 = (8*NSG1)*(8*NSH1); // shread memory for src1 data in (float)
const int64_t SH14 = SH1/4; // shread memory for src1 data in (float4) const int SH14 = SH1/4; // shread memory for src1 data in (float4)
const int64_t T1 = 8*NSH1; // row of src1 in shared memory in (float) const int T1 = 8*NSH1; // row of src1 in shared memory in (float)
const int64_t T14 = T1/4; // row of src1 in shared memory in (float4) const int T14 = T1/4; // row of src1 in shared memory in (float4)
threadgroup half * shared = (threadgroup half *) shared_u8; threadgroup half * shared = (threadgroup half *) shared_u8;
threadgroup half * s0 = (threadgroup half *)(shared + sgitg*SH0); threadgroup half * s0 = (threadgroup half *)(shared + sgitg*SH0);
threadgroup half4 * s04 = (threadgroup half4 *)(shared + sgitg*SH0); threadgroup half4 * s04 = (threadgroup half4 *)(shared + sgitg*SH0);
threadgroup float * s1 = (threadgroup float *)(shared + nsg*SH0); threadgroup half4x4 * s016 = (threadgroup half4x4 *)(shared + sgitg*SH0);
threadgroup float4 * s14 = (threadgroup float4 *)(shared + nsg*SH0); threadgroup float * s1 = (threadgroup float *)(shared + nsg*SH0);
threadgroup float4 * s14 = (threadgroup float4 *)(shared + nsg*SH0);
threadgroup float * r0 = (threadgroup float *)(shared + 2*sgitg*(8*NSG0)*(8*NSG1)); threadgroup float * r0 = (threadgroup float *)(shared + 2*sgitg*(8*NSG0)*(8*NSG1));
@ -4850,12 +4851,12 @@ void kernel_mul_mm2_impl(
simdgroup_float8x8 mr[NSG0][NSG1]; simdgroup_float8x8 mr[NSG0][NSG1];
// zero out shared memory SH0 for src0 // zero out shared memory SH0 for src0
for (int64_t i = tiisg; i < SH04; i += NW) { for (int i = tiisg; i < SH04; i += NW) {
s04[i] = 0.0h; s04[i] = 0.0h;
} }
// zero out shared memory SH1 for src1 // zero out shared memory SH1 for src1
for (int64_t i = tiitg; i < SH14; i += nsg*NW) { for (int i = tiitg; i < SH14; i += nsg*NW) {
s14[i] = 0.0f; s14[i] = 0.0f;
} }
@ -4868,24 +4869,27 @@ void kernel_mul_mm2_impl(
} }
} }
for (int64_t i00 = 0; i00 < ne00; i00 += 8*NSH1) { for (int i00 = 0; i00 < ne00; i00 += 8*NSH1) {
// load NSG1*NSH1 8x8 blocks of src1 to threadgroup memory // load NSG1*NSH1 8x8 blocks of src1 to threadgroup memory
{ {
threadgroup_barrier(mem_flags::mem_threadgroup); threadgroup_barrier(mem_flags::mem_threadgroup);
const int64_t nload = min(8ll*NSG1, ne11 - i11) * (8*NSH1); const int nload = MIN(8*NSG1, ne11 - i11) * (8*NSH1);
for (int64_t i = tiitg; i < nload; i += nsg*NW) { const size_t offs0 = im*nb12;
const int64_t ic = i%(8*NSH1);
const int64_t ir = i/(8*NSH1);
// TODO: use float4 for (int i = 4*tiitg; i < nload; i += 4*nsg*NW) {
device const float * p1 = (device const float *)(src1 + im*nb12 + (i11 + ir)*nb11 + (i00 + ic)*nb10); const int ic = i%(8*NSH1);
const int ir = i/(8*NSH1);
if (i00 + ic < ne00) { device const float4 * p1 = (device const float4 *)(src1 + offs0 + (i11 + ir)*nb11 + (i00 + ic)*nb10);
s1[8*NSH1*ir + ic] = *p1;
if (i00 + ic + 4 <= ne00) {
s14[(8*NSH1*ir + ic)/4] = *p1;
} else { } else {
s1[8*NSH1*ir + ic] = 0.0f; for (int k = 0; i00 + ic + k < ne00; k++){
s1[8*NSH1*ir + ic + k] = (*p1)[k];
}
} }
} }
@ -4895,28 +4899,36 @@ void kernel_mul_mm2_impl(
for (int b0 = 0; b0 < NSH1/NSH0; ++b0) { for (int b0 = 0; b0 < NSH1/NSH0; ++b0) {
// load NSG0*NSH0 8x8 blocks of src0 to threadgroup memory // load NSG0*NSH0 8x8 blocks of src0 to threadgroup memory
{ {
const int64_t nload = min(8ll*NSG0, ne01 - i01) * (8*NSH0); const int nload = MIN(8*NSG0, ne01 - i01) * (8*NSH0);
half4x4 tmp0; half4x4 tmp0;
for (int64_t i = 16*tiisg; i < nload; i += 16*NW) { const size_t offs0 = (i13/r3)*(nb02*ne02) + (i12/r2)*nb02;
const int64_t ic = i%(8*NSH0);
const int64_t ir = i/(8*NSH0);
const int64_t icc = i00 + 8*b0*NSH0 + ic; for (int i = 16*tiisg; i < nload; i += 16*NW) {
const int ic = i%(8*NSH0);
const int ir = i/(8*NSH0);
const int64_t ib = (icc/(16*nl)); const int icc = i00 + 8*b0*NSH0 + ic;
const int64_t il = (icc%(16*nl))/16;
device const block_q * p0 = (device const block_q *)(src0 + (i13/r3)*(nb02*ne02) + (i12/r2)*nb02 + (i01 + ir)*nb01) + ib; const int ib = (icc/(16*nl));
const int il = (icc%(16*nl))/16;
device const block_q * p0 = (device const block_q *)(src0 + offs0 + (i01 + ir)*nb01) + ib;
dequantize_func(p0, il, tmp0); dequantize_func(p0, il, tmp0);
for (int k = 0; k < 4; k++){ if (icc + 16 <= ne00) {
if (icc + 4*k < ne00) { s016[(8*NSH0*ir + ic)/16] = tmp0;
s04[(8*NSH0*ir + ic)/4 + k] = tmp0[k]; } else {
} else { for (int k = 0; k < 4; k++){
s04[(8*NSH0*ir + ic)/4 + k] = 0.0h; if (icc + 4*k <= ne00) {
s04[(8*NSH0*ir + ic)/4 + k] = tmp0[k];
} else {
for (int p = 0; icc + 4*k + p < ne00; p++) {
s0[8*NSH0*ir + ic + 4*k + p] = tmp0[k][p];
}
}
} }
} }
} }
@ -4958,12 +4970,12 @@ void kernel_mul_mm2_impl(
device float * pdst = dst + im*ne1*ne0; device float * pdst = dst + im*ne1*ne0;
for (int is = 0; is < NSG1; is++) { for (int is = 0; is < NSG1; is++) {
const int64_t i1 = i11 + is*8; const int i1 = i11 + is*8;
const int64_t nstore = min(8ll*NSG1, ne1 - i1) * (8*NSG0); const int nstore = MIN(8*NSG1, ne1 - i1) * (8*NSG0);
for (int64_t i = tiisg; i < nstore; i += NW) { for (int i = tiisg; i < nstore; i += NW) {
const int64_t ic = i%(8*NSG0); const int ic = i%(8*NSG0);
const int64_t ir = i/(8*NSG0); const int ir = i/(8*NSG0);
if (i1 + ir < ne1 && i01 + ic < ne0) { if (i1 + ir < ne1 && i01 + ic < ne0) {
pdst[(i1 + ir)*ne0 + (i01 + ic)] = r0[(8*is)*(8*NSG0) + 8*NSG0*ir + ic]; pdst[(i1 + ir)*ne0 + (i01 + ic)] = r0[(8*is)*(8*NSG0) + 8*NSG0*ir + ic];