metal : alibi for arbitrary number of heads (#3426)

This commit is contained in:
Jiahao Li 2023-10-04 00:55:21 +08:00 committed by GitHub
parent 017efe899d
commit f56e1baec3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 13 additions and 7 deletions

View File

@ -1213,12 +1213,9 @@ void ggml_metal_graph_compute(
float max_bias; float max_bias;
memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float)); memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
if (__builtin_popcount(n_head) != 1) {
GGML_ASSERT(false && "only power-of-two n_head implemented");
}
const int n_heads_log2_floor = 1 << (int) floor(log2(n_head)); const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor); const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
[encoder setComputePipelineState:ctx->pipeline_alibi_f32]; [encoder setComputePipelineState:ctx->pipeline_alibi_f32];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@ -1239,7 +1236,9 @@ void ggml_metal_graph_compute(
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15]; [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16]; [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17]; [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
[encoder setBytes:&m0 length:sizeof( float) atIndex:18]; [encoder setBytes:&m0 length:sizeof( float) atIndex:18];
[encoder setBytes:&m1 length:sizeof( float) atIndex:19];
[encoder setBytes:&n_heads_log2_floor length:sizeof(int) atIndex:20];
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
} break; } break;

View File

@ -830,7 +830,9 @@ kernel void kernel_alibi_f32(
constant uint64_t & nb1, constant uint64_t & nb1,
constant uint64_t & nb2, constant uint64_t & nb2,
constant uint64_t & nb3, constant uint64_t & nb3,
constant float & m0, constant float & m0,
constant float & m1,
constant int & n_heads_log2_floor,
uint3 tgpig[[threadgroup_position_in_grid]], uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]], uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) { uint3 ntg[[threads_per_threadgroup]]) {
@ -846,7 +848,12 @@ kernel void kernel_alibi_f32(
const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
float m_k = pow(m0, i2 + 1); float m_k;
if (i2 < n_heads_log2_floor) {
m_k = pow(m0, i2 + 1);
} else {
m_k = pow(m1, 2 * (i2 - n_heads_log2_floor) + 1);
}
for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) { for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
dst_data[i00] = src[0] + m_k * (i00 - ne00 + 1); dst_data[i00] = src[0] + m_k * (i00 - ne00 + 1);