metal : attempt to adapt SSM_SCAN for Mamba-2

This commit is contained in:
Francis Couture-Harpin 2024-10-02 10:36:22 -04:00
parent 7d6cb36895
commit 2c77d799f9
2 changed files with 200 additions and 49 deletions

View File

@ -95,6 +95,7 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_NORM,
GGML_METAL_KERNEL_TYPE_SSM_CONV_F32,
GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32,
GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP,
GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16,
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,
@ -591,6 +592,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(void) {
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, ssm_conv_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, ssm_scan_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP, ssm_scan_f32_group, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, ctx->support_simdgroup_reduction);
@ -1629,47 +1631,74 @@ static void ggml_metal_encode_node(
struct ggml_tensor * src3 = node->src[3];
struct ggml_tensor * src4 = node->src[4];
struct ggml_tensor * src5 = node->src[5];
struct ggml_tensor * src6 = node->src[6];
struct ggml_tensor * src7 = node->src[7];
GGML_ASSERT(src3);
GGML_ASSERT(src4);
GGML_ASSERT(src5);
GGML_ASSERT(src6);
GGML_ASSERT(src7);
size_t offs_src3 = 0;
size_t offs_src4 = 0;
size_t offs_src5 = 0;
size_t offs_src6 = 0;
size_t offs_src7 = 0;
id<MTLBuffer> id_src3 = src3 ? ggml_metal_get_buffer(src3, &offs_src3) : nil;
id<MTLBuffer> id_src4 = src4 ? ggml_metal_get_buffer(src4, &offs_src4) : nil;
id<MTLBuffer> id_src5 = src5 ? ggml_metal_get_buffer(src5, &offs_src5) : nil;
id<MTLBuffer> id_src6 = src6 ? ggml_metal_get_buffer(src6, &offs_src6) : nil;
id<MTLBuffer> id_src7 = src7 ? ggml_metal_get_buffer(src7, &offs_src7) : nil;
const int64_t ne30 = src3->ne[0]; GGML_UNUSED(ne30);
const int64_t ne30 = src3->ne[0];
const int64_t ne31 = src3->ne[1]; GGML_UNUSED(ne31);
const uint64_t nb30 = src3->nb[0];
const uint64_t nb31 = src3->nb[1];
const int64_t ne40 = src4->ne[0]; GGML_UNUSED(ne40);
const int64_t ne41 = src4->ne[1]; GGML_UNUSED(ne41);
const int64_t ne41 = src4->ne[1];
const int64_t ne42 = src4->ne[2]; GGML_UNUSED(ne42);
const int64_t ne43 = src4->ne[3]; GGML_UNUSED(ne43);
const uint64_t nb40 = src4->nb[0];
const uint64_t nb41 = src4->nb[1];
const uint64_t nb42 = src4->nb[2];
const uint64_t nb43 = src4->nb[3];
const int64_t ne50 = src5->ne[0]; GGML_UNUSED(ne50);
const int64_t ne51 = src5->ne[1]; GGML_UNUSED(ne51);
const int64_t ne52 = src5->ne[2]; GGML_UNUSED(ne52);
const int64_t ne53 = src5->ne[3]; GGML_UNUSED(ne53);
const uint64_t nb50 = src5->nb[0];
const uint64_t nb51 = src5->nb[1];
const uint64_t nb52 = src5->nb[2];
const uint64_t nb53 = src5->nb[3];
const int64_t ne60 = src6->ne[0]; GGML_UNUSED(ne60);
const uint64_t nb60 = src6->nb[0];
const int64_t ne70 = src7->ne[0]; GGML_UNUSED(ne70);
const uint64_t nb70 = src7->nb[0];
const int64_t d_state = ne00;
const int64_t d_inner = ne01;
const int64_t n_head = ne02;
const int64_t n_group = ne41;
const int64_t n_seq_tokens = ne11;
const int64_t n_seqs = ne02;
const int64_t n_seqs = ne13;
if (ne30 == 1) {
// Mamba-2
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP].pipeline;
} else {
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline;
}
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@ -1678,33 +1707,49 @@ static void ggml_metal_encode_node(
[encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
[encoder setBuffer:id_src4 offset:offs_src4 atIndex:4];
[encoder setBuffer:id_src5 offset:offs_src5 atIndex:5];
[encoder setBuffer:id_dst offset:offs_dst atIndex:6];
[encoder setBuffer:id_src6 offset:offs_src6 atIndex:6];
[encoder setBuffer:id_src7 offset:offs_src7 atIndex:7];
[encoder setBuffer:id_dst offset:offs_dst atIndex:8];
[encoder setBytes:&d_state length:sizeof(d_state) atIndex:7];
[encoder setBytes:&d_inner length:sizeof(d_inner) atIndex:8];
[encoder setBytes:&n_seq_tokens length:sizeof(n_seq_tokens) atIndex:9];
[encoder setBytes:&n_seqs length:sizeof(n_seqs) atIndex:10];
[encoder setBytes:&d_state length:sizeof(d_state) atIndex:9];
[encoder setBytes:&d_inner length:sizeof(d_inner) atIndex:10];
[encoder setBytes:&n_head length:sizeof(n_head) atIndex:11];
[encoder setBytes:&n_group length:sizeof(n_group) atIndex:12];
[encoder setBytes:&n_seq_tokens length:sizeof(n_seq_tokens) atIndex:13];
[encoder setBytes:&n_seqs length:sizeof(n_seqs) atIndex:14];
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:11];
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:12];
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:13];
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
[encoder setBytes:&nb13 length:sizeof(nb13) atIndex:17];
[encoder setBytes:&nb20 length:sizeof(nb20) atIndex:18];
[encoder setBytes:&nb21 length:sizeof(nb21) atIndex:19];
[encoder setBytes:&nb22 length:sizeof(nb22) atIndex:20];
[encoder setBytes:&nb30 length:sizeof(nb30) atIndex:21];
[encoder setBytes:&nb31 length:sizeof(nb31) atIndex:22];
[encoder setBytes:&nb40 length:sizeof(nb40) atIndex:23];
[encoder setBytes:&nb41 length:sizeof(nb41) atIndex:24];
[encoder setBytes:&nb42 length:sizeof(nb42) atIndex:25];
[encoder setBytes:&nb50 length:sizeof(nb50) atIndex:26];
[encoder setBytes:&nb51 length:sizeof(nb51) atIndex:27];
[encoder setBytes:&nb52 length:sizeof(nb52) atIndex:28];
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:15];
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:16];
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:17];
[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:18];
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:19];
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:20];
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:21];
[encoder setBytes:&nb13 length:sizeof(nb13) atIndex:22];
[encoder setBytes:&nb20 length:sizeof(nb20) atIndex:23];
[encoder setBytes:&nb21 length:sizeof(nb21) atIndex:24];
[encoder setBytes:&nb22 length:sizeof(nb22) atIndex:25];
[encoder setBytes:&nb23 length:sizeof(nb23) atIndex:26];
[encoder setBytes:&nb30 length:sizeof(nb30) atIndex:27];
[encoder setBytes:&nb31 length:sizeof(nb31) atIndex:28];
[encoder setBytes:&nb40 length:sizeof(nb40) atIndex:29];
[encoder setBytes:&nb41 length:sizeof(nb41) atIndex:30];
[encoder setBytes:&nb42 length:sizeof(nb42) atIndex:31];
[encoder setBytes:&nb43 length:sizeof(nb43) atIndex:32];
[encoder setBytes:&nb50 length:sizeof(nb50) atIndex:33];
[encoder setBytes:&nb51 length:sizeof(nb51) atIndex:34];
[encoder setBytes:&nb52 length:sizeof(nb52) atIndex:35];
[encoder setBytes:&nb53 length:sizeof(nb53) atIndex:36];
[encoder setBytes:&nb60 length:sizeof(nb60) atIndex:37];
[encoder setBytes:&nb70 length:sizeof(nb70) atIndex:38];
[encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
if (ne30 == 1) {
// Mamba-2
[encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_head, n_seqs) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} else {
GGML_ASSERT(d_inner == 1);
[encoder dispatchThreadgroups:MTLSizeMake(n_head, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
}
} break;
case GGML_OP_MUL_MAT:
{

View File

@ -795,7 +795,7 @@ kernel void kernel_ssm_conv_f32(
x[0] = sumf;
}
// ref: ggml.c:ggml_compute_forward_ssm_scan_f32
// ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-1 part
// TODO: optimize
kernel void kernel_ssm_scan_f32(
device const void * src0,
@ -804,14 +804,19 @@ kernel void kernel_ssm_scan_f32(
device const void * src3,
device const void * src4,
device const void * src5,
device const void * src6,
device const void * src7,
device float * dst,
constant int64_t & d_state,
constant int64_t & d_inner,
constant int64_t & n_head,
constant int64_t & n_group,
constant int64_t & n_seq_tokens,
constant int64_t & n_seqs,
constant uint64_t & nb00,
constant uint64_t & nb01,
constant uint64_t & nb02,
constant uint64_t & nb03,
constant uint64_t & nb10,
constant uint64_t & nb11,
constant uint64_t & nb12,
@ -824,47 +829,148 @@ kernel void kernel_ssm_scan_f32(
constant uint64_t & nb40,
constant uint64_t & nb41,
constant uint64_t & nb42,
constant uint64_t & nb43,
constant uint64_t & nb50,
constant uint64_t & nb51,
constant uint64_t & nb52,
constant uint64_t & nb53,
constant uint64_t & nb60,
constant uint64_t & nb70,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
const int64_t ir = tgpig.x;
const int64_t i3 = tgpig.y;
const int64_t i1 = 0;
const int64_t ir = tgpig.x; // current head
const int64_t i3 = tgpig.y; // current seq
const int64_t nc = d_state;
const int64_t nr = d_inner;
const int64_t nh = n_head;
const int64_t ng = n_group;
const int64_t n_t = n_seq_tokens;
const int64_t n_s = n_seqs;
const int64_t s_off = d_inner * n_head * n_seq_tokens * n_seqs * sizeof(float);
device const int32_t * ids = (device const int32_t *) src7;
device const float * s0 = (device const float *) ((device const char *) src0 + ir*nb01 + ids[i3]*nb03);
device float * s = (device float *) ((device char *) dst + ir*nb01 + i3*nb03 + s_off);
for (int64_t i2 = 0; i2 < n_t; ++i2) {
device const float * s0 = (device const float *) ((device const char *) src0 + ir*nb01 + i3*nb02);
device const float * x = (device const float *) ((device const char *) src1 + ir*nb10 + i2*nb11 + i3*nb12);
device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*nb21 + i3*nb22);
device const float * A = (device const float *) ((device const char *) src3 + ir*nb31);
device const float * B = (device const float *) ((device const char *) src4 + i2*nb41 + i3*nb42);
device const float * C = (device const float *) ((device const char *) src5 + i2*nb51 + i3*nb52);
device float * y = (device float *) ((device char *) dst + ir*nb10 + i2*nb11 + i3*nb12); // TODO: do not use src1 strides
device float * s = (device float *) ((device char *) dst + ir*nb01 + i3*nb02 + nb13);
device const float * x = (device const float *) ((device const char *) src1 + i1*nb10 + ir*nb11 + i2*nb12 + i3*nb13); // {dim, nh, nt, ns}
device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*nb21 + i3*nb22); // {nh, nt, ns}
device const float * A = (device const float *) ((device const char *) src3 + ir*nb31); // {d_state, nh}
device const float * B = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*nb41 + i2*nb42 + i3*nb43); // {d_state, ng, nt, ns}
device const float * C = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*nb51 + i2*nb52 + i3*nb53); // {d_state, ng, nt, ns}
device const float * D = (device const float *) ((device const char *) src6 + ir*nb60); // {nh}
device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*sizeof(float)); // {dim, nh, nt, ns}
if (i2 > 0) {
s0 = s;
}
// i1 == 0
float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
float x_dt = x[0] * dt_soft_plus;
const float dt_soft_plus = dt[0] <= 20.0f ? log1pf(expf(dt[0])) : dt[0];
const float x_dt = x[0] * dt_soft_plus;
float sumf = 0.0f;
for (int64_t i0 = 0; i0 < nc; ++i0) {
int64_t i = i0;
float state = (s0[i] * exp(dt_soft_plus * A[i])) + (B[i0] * x_dt);
const int64_t i = i0 + i1*nc;
const float state = (s0[i] * expf(dt_soft_plus * A[i0])) + (B[i0] * x_dt);
sumf += state * C[i0];
s[i] = state;
}
y[0] = sumf;
y[0] = sumf + x[0] * D[0];
// recurse
s0 = s;
}
}
// ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-2 part
// TODO: optimize (e.g. by parallelizing over d_state)
kernel void kernel_ssm_scan_f32_group(
device const void * src0,
device const void * src1,
device const void * src2,
device const void * src3,
device const void * src4,
device const void * src5,
device const void * src6,
device const void * src7,
device float * dst,
constant int64_t & d_state,
constant int64_t & d_inner,
constant int64_t & n_head,
constant int64_t & n_group,
constant int64_t & n_seq_tokens,
constant int64_t & n_seqs,
constant uint64_t & nb00,
constant uint64_t & nb01,
constant uint64_t & nb02,
constant uint64_t & nb03,
constant uint64_t & nb10,
constant uint64_t & nb11,
constant uint64_t & nb12,
constant uint64_t & nb13,
constant uint64_t & nb20,
constant uint64_t & nb21,
constant uint64_t & nb22,
constant uint64_t & nb30,
constant uint64_t & nb31,
constant uint64_t & nb40,
constant uint64_t & nb41,
constant uint64_t & nb42,
constant uint64_t & nb43,
constant uint64_t & nb50,
constant uint64_t & nb51,
constant uint64_t & nb52,
constant uint64_t & nb53,
constant uint64_t & nb60,
constant uint64_t & nb70,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
const int64_t i1 = tgpig.x;
const int64_t ir = tgpig.y; // current head
const int64_t i3 = tgpig.z; // current seq
const int64_t nc = d_state;
const int64_t nr = d_inner;
const int64_t nh = n_head;
const int64_t ng = n_group;
const int64_t n_t = n_seq_tokens;
const int64_t n_s = n_seqs;
const int64_t s_off = d_inner * n_head * n_seq_tokens * n_seqs * sizeof(float);
device const int32_t * ids = (device const int32_t *) src7;
device const float * s0 = (device const float *) ((device const char *) src0 + ir*nb01 + ids[i3]*nb03);
device float * s = (device float *) ((device char *) dst + ir*nb01 + i3*nb03 + s_off);
for (int64_t i2 = 0; i2 < n_t; ++i2) {
device const float * x = (device const float *) ((device const char *) src1 + i1*nb10 + ir*nb11 + i2*nb12 + i3*nb13); // {dim, nh, nt, ns}
device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*nb21 + i3*nb22); // {nh, nt, ns}
device const float * A = (device const float *) ((device const char *) src3 + ir*nb31); // {1, nh}
device const float * B = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*nb41 + i2*nb42 + i3*nb43); // {d_state, ng, nt, ns}
device const float * C = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*nb51 + i2*nb52 + i3*nb53); // {d_state, ng, nt, ns}
device const float * D = (device const float *) ((device const char *) src6 + ir*nb60); // {nh}
device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*sizeof(float)); // {dim, nh, nt, ns}
const float dt_soft_plus = dt[0] <= 20.0f ? log1pf(expf(dt[0])) : dt[0];
const float x_dt = x[0] * dt_soft_plus;
const float dA = expf(dt_soft_plus * A[0]);
float sumf = 0.0f;
for (int64_t i0 = 0; i0 < nc; ++i0) {
const int64_t i = i0 + i1*nc;
const float state = (s0[i] * dA) + (B[i0] * x_dt);
sumf += state * C[i0];
s[i] = state;
}
y[0] = sumf + x[0] * D[0];
// recurse
s0 = s;
}
}