metal : fix build errors and kernel sig after #2268 (#3898)

This commit is contained in:
Georgi Gerganov 2023-11-02 08:33:37 +02:00 committed by GitHub
parent 2fffa0d61f
commit 183b3fac6c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 40 additions and 33 deletions

View File

@ -1441,12 +1441,13 @@ void ggml_metal_graph_compute(
[encoder setBytes:&n_past length:sizeof( int) atIndex:19];
[encoder setBytes:&n_dims length:sizeof( int) atIndex:20];
[encoder setBytes:&mode length:sizeof( int) atIndex:21];
[encoder setBytes:&freq_base length:sizeof(float) atIndex:22];
[encoder setBytes:&freq_scale length:sizeof(float) atIndex:23];
[encoder setBytes:&ext_factor length:sizeof(float) atIndex:24];
[encoder setBytes:&attn_factor length:sizeof(float) atIndex:25];
[encoder setBytes:&beta_fast length:sizeof(float) atIndex:26];
[encoder setBytes:&beta_slow length:sizeof(float) atIndex:27];
[encoder setBytes:&n_orig_ctx length:sizeof( int) atIndex:22];
[encoder setBytes:&freq_base length:sizeof( float) atIndex:23];
[encoder setBytes:&freq_scale length:sizeof( float) atIndex:24];
[encoder setBytes:&ext_factor length:sizeof( float) atIndex:25];
[encoder setBytes:&attn_factor length:sizeof( float) atIndex:26];
[encoder setBytes:&beta_fast length:sizeof( float) atIndex:27];
[encoder setBytes:&beta_slow length:sizeof( float) atIndex:28];
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
} break;

View File

@ -1070,20 +1070,20 @@ static float rope_yarn_ramp(const float low, const float high, const int i0) {
// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
static void rope_yarn(
float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale,
float * cos_theta, float * sin_theta
thread float * cos_theta, thread float * sin_theta
) {
// Get n-d rotational scaling corrected for extrapolation
float theta_interp = freq_scale * theta_extrap;
float theta = theta_interp;
if (ext_factor != 0.0f) {
ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
// Get n-d magnitude scaling corrected for interpolation
mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale);
mscale *= 1.0f + 0.1f * log(1.0f / freq_scale);
}
*cos_theta = cosf(theta) * mscale;
*sin_theta = sinf(theta) * mscale;
*cos_theta = cos(theta) * mscale;
*sin_theta = sin(theta) * mscale;
}
// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
@ -1123,8 +1123,13 @@ typedef void (rope_t)(
constant int & n_past,
constant int & n_dims,
constant int & mode,
constant int & n_orig_ctx,
constant float & freq_base,
constant float & freq_scale,
constant float & ext_factor,
constant float & attn_factor,
constant float & beta_fast,
constant float & beta_slow,
uint tiitg[[thread_index_in_threadgroup]],
uint3 tptg[[threads_per_threadgroup]],
uint3 tgpig[[threadgroup_position_in_grid]]);
@ -1153,6 +1158,7 @@ kernel void kernel_rope(
constant int & n_past,
constant int & n_dims,
constant int & mode,
constant int & n_orig_ctx,
constant float & freq_base,
constant float & freq_scale,
constant float & ext_factor,