metal : fix build errors and kernel sig after #2268 (#3898)

This commit is contained in:
Georgi Gerganov 2023-11-02 08:33:37 +02:00 committed by GitHub
parent 2fffa0d61f
commit 183b3fac6c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 40 additions and 33 deletions

View File

@ -1441,12 +1441,13 @@ void ggml_metal_graph_compute(
[encoder setBytes:&n_past length:sizeof( int) atIndex:19]; [encoder setBytes:&n_past length:sizeof( int) atIndex:19];
[encoder setBytes:&n_dims length:sizeof( int) atIndex:20]; [encoder setBytes:&n_dims length:sizeof( int) atIndex:20];
[encoder setBytes:&mode length:sizeof( int) atIndex:21]; [encoder setBytes:&mode length:sizeof( int) atIndex:21];
[encoder setBytes:&freq_base length:sizeof(float) atIndex:22]; [encoder setBytes:&n_orig_ctx length:sizeof( int) atIndex:22];
[encoder setBytes:&freq_scale length:sizeof(float) atIndex:23]; [encoder setBytes:&freq_base length:sizeof( float) atIndex:23];
[encoder setBytes:&ext_factor length:sizeof(float) atIndex:24]; [encoder setBytes:&freq_scale length:sizeof( float) atIndex:24];
[encoder setBytes:&attn_factor length:sizeof(float) atIndex:25]; [encoder setBytes:&ext_factor length:sizeof( float) atIndex:25];
[encoder setBytes:&beta_fast length:sizeof(float) atIndex:26]; [encoder setBytes:&attn_factor length:sizeof( float) atIndex:26];
[encoder setBytes:&beta_slow length:sizeof(float) atIndex:27]; [encoder setBytes:&beta_fast length:sizeof( float) atIndex:27];
[encoder setBytes:&beta_slow length:sizeof( float) atIndex:28];
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
} break; } break;

View File

@ -1070,20 +1070,20 @@ static float rope_yarn_ramp(const float low, const float high, const int i0) {
// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng. // MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
static void rope_yarn( static void rope_yarn(
float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale, float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale,
float * cos_theta, float * sin_theta thread float * cos_theta, thread float * sin_theta
) { ) {
// Get n-d rotational scaling corrected for extrapolation // Get n-d rotational scaling corrected for extrapolation
float theta_interp = freq_scale * theta_extrap; float theta_interp = freq_scale * theta_extrap;
float theta = theta_interp; float theta = theta_interp;
if (ext_factor != 0.0f) { if (ext_factor != 0.0f) {
ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor; float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix; theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
// Get n-d magnitude scaling corrected for interpolation // Get n-d magnitude scaling corrected for interpolation
mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale); mscale *= 1.0f + 0.1f * log(1.0f / freq_scale);
} }
*cos_theta = cosf(theta) * mscale; *cos_theta = cos(theta) * mscale;
*sin_theta = sinf(theta) * mscale; *sin_theta = sin(theta) * mscale;
} }
// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get // Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
@ -1123,8 +1123,13 @@ typedef void (rope_t)(
constant int & n_past, constant int & n_past,
constant int & n_dims, constant int & n_dims,
constant int & mode, constant int & mode,
constant int & n_orig_ctx,
constant float & freq_base, constant float & freq_base,
constant float & freq_scale, constant float & freq_scale,
constant float & ext_factor,
constant float & attn_factor,
constant float & beta_fast,
constant float & beta_slow,
uint tiitg[[thread_index_in_threadgroup]], uint tiitg[[thread_index_in_threadgroup]],
uint3 tptg[[threads_per_threadgroup]], uint3 tptg[[threads_per_threadgroup]],
uint3 tgpig[[threadgroup_position_in_grid]]); uint3 tgpig[[threadgroup_position_in_grid]]);
@ -1153,6 +1158,7 @@ kernel void kernel_rope(
constant int & n_past, constant int & n_past,
constant int & n_dims, constant int & n_dims,
constant int & mode, constant int & mode,
constant int & n_orig_ctx,
constant float & freq_base, constant float & freq_base,
constant float & freq_scale, constant float & freq_scale,
constant float & ext_factor, constant float & ext_factor,