metal : add kernel arg structs (wip)

This commit is contained in:
Georgi Gerganov 2024-11-09 15:28:55 +02:00
parent 6423c65aa8
commit 996e479780
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
3 changed files with 116 additions and 125 deletions

View File

@ -418,6 +418,36 @@ typedef struct {
} block_iq4_xs; } block_iq4_xs;
static_assert(sizeof(block_iq4_xs) == sizeof(ggml_half) + sizeof(uint16_t) + QK_K/64 + QK_K/2, "wrong iq4_xs block size/padding"); static_assert(sizeof(block_iq4_xs) == sizeof(ggml_half) + sizeof(uint16_t) + QK_K/64 + QK_K/2, "wrong iq4_xs block size/padding");
#if defined(GGML_COMMON_DECL_METAL_KARGS)
typedef struct {
int32_t ne00;
int32_t ne01;
int32_t ne02;
int32_t ne03;
uint64_t nb00;
uint64_t nb01;
uint64_t nb02;
uint64_t nb03;
int32_t ne0;
int32_t ne1;
int32_t ne2;
int32_t ne3;
uint64_t nb0;
uint64_t nb1;
uint64_t nb2;
uint64_t nb3;
int32_t n_past;
int32_t n_dims;
int32_t n_ctx_orig;
float freq_base;
float freq_scale;
float ext_factor;
float attn_factor;
float beta_fast;
float beta_slow;
} ggml_metal_kargs_rope;
#endif
#endif // GGML_COMMON_DECL #endif // GGML_COMMON_DECL
#endif // GGML_COMMON_DECL #endif // GGML_COMMON_DECL

View File

@ -3,6 +3,10 @@
#import "ggml-impl.h" #import "ggml-impl.h"
#import "ggml-backend-impl.h" #import "ggml-backend-impl.h"
#define GGML_COMMON_DECL_C
#define GGML_COMMON_DECL_METAL_KARGS
#include "ggml-common.h"
#import <Foundation/Foundation.h> #import <Foundation/Foundation.h>
#import <Metal/Metal.h> #import <Metal/Metal.h>
@ -2702,6 +2706,34 @@ static void ggml_metal_encode_node(
}; };
} }
ggml_metal_kargs_rope args = {
.ne00 = ne00,
.ne01 = ne01,
.ne02 = ne02,
.ne03 = ne03,
.nb00 = nb00,
.nb01 = nb01,
.nb02 = nb02,
.nb03 = nb03,
.ne0 = ne0,
.ne1 = ne1,
.ne2 = ne2,
.ne3 = ne3,
.nb0 = nb0,
.nb1 = nb1,
.nb2 = nb2,
.nb3 = nb3,
.n_past = n_past,
.n_dims = n_dims,
.n_ctx_orig = n_ctx_orig,
.freq_base = freq_base,
.freq_scale = freq_scale,
.ext_factor = ext_factor,
.attn_factor = attn_factor,
.beta_fast = beta_fast,
.beta_slow = beta_slow,
};
[encoder setComputePipelineState:pipeline]; [encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
@ -2711,31 +2743,7 @@ static void ggml_metal_encode_node(
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:2]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:2];
} }
[encoder setBuffer:id_dst offset:offs_dst atIndex:3]; [encoder setBuffer:id_dst offset:offs_dst atIndex:3];
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:4]; [encoder setBytes:&args length:sizeof(args) atIndex:4];
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:5];
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:6];
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:7];
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:8];
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:9];
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:10];
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:11];
[encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:12];
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:13];
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:14];
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:15];
[encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:16];
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:17];
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:18];
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:19];
[encoder setBytes:&n_past length:sizeof( int) atIndex:20];
[encoder setBytes:&n_dims length:sizeof( int) atIndex:21];
[encoder setBytes:&n_ctx_orig length:sizeof( int) atIndex:22];
[encoder setBytes:&freq_base length:sizeof( float) atIndex:23];
[encoder setBytes:&freq_scale length:sizeof( float) atIndex:24];
[encoder setBytes:&ext_factor length:sizeof( float) atIndex:25];
[encoder setBytes:&attn_factor length:sizeof( float) atIndex:26];
[encoder setBytes:&beta_fast length:sizeof( float) atIndex:27];
[encoder setBytes:&beta_slow length:sizeof( float) atIndex:28];
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
} break; } break;

View File

@ -1,4 +1,5 @@
#define GGML_COMMON_DECL_METAL #define GGML_COMMON_DECL_METAL
#define GGML_COMMON_DECL_METAL_KARGS
#define GGML_COMMON_IMPL_METAL #define GGML_COMMON_IMPL_METAL
#include "ggml-common.h" #include "ggml-common.h"
@ -2229,7 +2230,7 @@ static float rope_yarn_ramp(const float low, const float high, const int i0) {
// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn // YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng. // MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
static void rope_yarn( static void rope_yarn(
float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale, float theta_extrap, float freq_scale, float corr_dims[2], int i0, float ext_factor, float mscale,
thread float * cos_theta, thread float * sin_theta) { thread float * cos_theta, thread float * sin_theta) {
// Get n-d rotational scaling corrected for extrapolation // Get n-d rotational scaling corrected for extrapolation
float theta_interp = freq_scale * theta_extrap; float theta_interp = freq_scale * theta_extrap;
@ -2261,65 +2262,41 @@ static void rope_yarn_corr_dims(
template<typename T> template<typename T>
kernel void kernel_rope_norm( kernel void kernel_rope_norm(
device const void * src0, device const char * src0,
device const int32_t * src1, device const char * src1,
device const float * src2, device const char * src2,
device float * dst, device char * dst,
constant int64_t & ne00, constant ggml_metal_kargs_rope & args,
constant int64_t & ne01,
constant int64_t & ne02,
constant int64_t & ne03,
constant uint64_t & nb00,
constant uint64_t & nb01,
constant uint64_t & nb02,
constant uint64_t & nb03,
constant int64_t & ne0,
constant int64_t & ne1,
constant int64_t & ne2,
constant int64_t & ne3,
constant uint64_t & nb0,
constant uint64_t & nb1,
constant uint64_t & nb2,
constant uint64_t & nb3,
constant int & n_past,
constant int & n_dims,
constant int & n_ctx_orig,
constant float & freq_base,
constant float & freq_scale,
constant float & ext_factor,
constant float & attn_factor,
constant float & beta_fast,
constant float & beta_slow,
uint tiitg[[thread_index_in_threadgroup]], uint tiitg[[thread_index_in_threadgroup]],
uint3 tptg [[threads_per_threadgroup]], uint3 tptg [[threads_per_threadgroup]],
uint3 tgpig[[threadgroup_position_in_grid]]) { uint3 tgpig[[threadgroup_position_in_grid]]) {
const int64_t i3 = tgpig[2]; const int i3 = tgpig[2];
const int64_t i2 = tgpig[1]; const int i2 = tgpig[1];
const int64_t i1 = tgpig[0]; const int i1 = tgpig[0];
float corr_dims[2]; float corr_dims[2];
rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims); rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims);
device const int32_t * pos = src1; device const int32_t * pos = (device const int32_t *) src1;
const float theta_base = (float) pos[i2]; const float theta_base = (float) pos[i2];
const float inv_ndims = -1.f/n_dims; const float inv_ndims = -1.f/args.n_dims;
float cos_theta; float cos_theta;
float sin_theta; float sin_theta;
for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) { for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) {
if (i0 < n_dims) { if (i0 < args.n_dims) {
const int64_t ic = i0/2; const int ic = i0/2;
const float theta = theta_base * pow(freq_base, inv_ndims*i0); const float theta = theta_base * pow(args.freq_base, inv_ndims*i0);
const float freq_factor = src2 != src0 ? src2[ic] : 1.0f; const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f;
rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta); rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);
device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00);
device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
const float x0 = src[0]; const float x0 = src[0];
const float x1 = src[1]; const float x1 = src[1];
@ -2327,8 +2304,8 @@ kernel void kernel_rope_norm(
dst_data[0] = x0*cos_theta - x1*sin_theta; dst_data[0] = x0*cos_theta - x1*sin_theta;
dst_data[1] = x0*sin_theta + x1*cos_theta; dst_data[1] = x0*sin_theta + x1*cos_theta;
} else { } else {
device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00);
device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
dst_data[0] = src[0]; dst_data[0] = src[0];
dst_data[1] = src[1]; dst_data[1] = src[1];
@ -2338,74 +2315,50 @@ kernel void kernel_rope_norm(
template<typename T> template<typename T>
kernel void kernel_rope_neox( kernel void kernel_rope_neox(
device const void * src0, device const char * src0,
device const int32_t * src1, device const char * src1,
device const float * src2, device const char * src2,
device float * dst, device char * dst,
constant int64_t & ne00, constant ggml_metal_kargs_rope & args,
constant int64_t & ne01,
constant int64_t & ne02,
constant int64_t & ne03,
constant uint64_t & nb00,
constant uint64_t & nb01,
constant uint64_t & nb02,
constant uint64_t & nb03,
constant int64_t & ne0,
constant int64_t & ne1,
constant int64_t & ne2,
constant int64_t & ne3,
constant uint64_t & nb0,
constant uint64_t & nb1,
constant uint64_t & nb2,
constant uint64_t & nb3,
constant int & n_past,
constant int & n_dims,
constant int & n_ctx_orig,
constant float & freq_base,
constant float & freq_scale,
constant float & ext_factor,
constant float & attn_factor,
constant float & beta_fast,
constant float & beta_slow,
uint tiitg[[thread_index_in_threadgroup]], uint tiitg[[thread_index_in_threadgroup]],
uint3 tptg[[threads_per_threadgroup]], uint3 tptg[[threads_per_threadgroup]],
uint3 tgpig[[threadgroup_position_in_grid]]) { uint3 tgpig[[threadgroup_position_in_grid]]) {
const int64_t i3 = tgpig[2]; const int i3 = tgpig[2];
const int64_t i2 = tgpig[1]; const int i2 = tgpig[1];
const int64_t i1 = tgpig[0]; const int i1 = tgpig[0];
float corr_dims[2]; float corr_dims[2];
rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims); rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims);
device const int32_t * pos = src1; device const int32_t * pos = (device const int32_t *) src1;
const float theta_base = (float) pos[i2]; const float theta_base = (float) pos[i2];
const float inv_ndims = -1.f/n_dims; const float inv_ndims = -1.f/args.n_dims;
float cos_theta; float cos_theta;
float sin_theta; float sin_theta;
for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) { for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) {
if (i0 < n_dims) { if (i0 < args.n_dims) {
const int64_t ic = i0/2; const int ic = i0/2;
const float theta = theta_base * pow(freq_base, inv_ndims*i0); const float theta = theta_base * pow(args.freq_base, inv_ndims*i0);
const float freq_factor = src2 != src0 ? src2[ic] : 1.0f; const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f;
rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta); rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);
device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00); device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + ic*args.nb00);
device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0); device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + ic*args.nb0);
const float x0 = src[0]; const float x0 = src[0];
const float x1 = src[n_dims/2]; const float x1 = src[args.n_dims/2];
dst_data[0] = x0*cos_theta - x1*sin_theta; dst_data[0] = x0*cos_theta - x1*sin_theta;
dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta; dst_data[args.n_dims/2] = x0*sin_theta + x1*cos_theta;
} else { } else {
device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00);
device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
dst_data[0] = src[0]; dst_data[0] = src[0];
dst_data[1] = src[1]; dst_data[1] = src[1];