mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-07 17:21:46 +00:00
metal : add kernel arg structs (wip)
This commit is contained in:
parent
6423c65aa8
commit
996e479780
@ -418,6 +418,36 @@ typedef struct {
|
|||||||
} block_iq4_xs;
|
} block_iq4_xs;
|
||||||
static_assert(sizeof(block_iq4_xs) == sizeof(ggml_half) + sizeof(uint16_t) + QK_K/64 + QK_K/2, "wrong iq4_xs block size/padding");
|
static_assert(sizeof(block_iq4_xs) == sizeof(ggml_half) + sizeof(uint16_t) + QK_K/64 + QK_K/2, "wrong iq4_xs block size/padding");
|
||||||
|
|
||||||
|
#if defined(GGML_COMMON_DECL_METAL_KARGS)
|
||||||
|
typedef struct {
|
||||||
|
int32_t ne00;
|
||||||
|
int32_t ne01;
|
||||||
|
int32_t ne02;
|
||||||
|
int32_t ne03;
|
||||||
|
uint64_t nb00;
|
||||||
|
uint64_t nb01;
|
||||||
|
uint64_t nb02;
|
||||||
|
uint64_t nb03;
|
||||||
|
int32_t ne0;
|
||||||
|
int32_t ne1;
|
||||||
|
int32_t ne2;
|
||||||
|
int32_t ne3;
|
||||||
|
uint64_t nb0;
|
||||||
|
uint64_t nb1;
|
||||||
|
uint64_t nb2;
|
||||||
|
uint64_t nb3;
|
||||||
|
int32_t n_past;
|
||||||
|
int32_t n_dims;
|
||||||
|
int32_t n_ctx_orig;
|
||||||
|
float freq_base;
|
||||||
|
float freq_scale;
|
||||||
|
float ext_factor;
|
||||||
|
float attn_factor;
|
||||||
|
float beta_fast;
|
||||||
|
float beta_slow;
|
||||||
|
} ggml_metal_kargs_rope;
|
||||||
|
#endif
|
||||||
|
|
||||||
#endif // GGML_COMMON_DECL
|
#endif // GGML_COMMON_DECL
|
||||||
#endif // GGML_COMMON_DECL
|
#endif // GGML_COMMON_DECL
|
||||||
|
|
||||||
|
@ -3,6 +3,10 @@
|
|||||||
#import "ggml-impl.h"
|
#import "ggml-impl.h"
|
||||||
#import "ggml-backend-impl.h"
|
#import "ggml-backend-impl.h"
|
||||||
|
|
||||||
|
#define GGML_COMMON_DECL_C
|
||||||
|
#define GGML_COMMON_DECL_METAL_KARGS
|
||||||
|
#include "ggml-common.h"
|
||||||
|
|
||||||
#import <Foundation/Foundation.h>
|
#import <Foundation/Foundation.h>
|
||||||
|
|
||||||
#import <Metal/Metal.h>
|
#import <Metal/Metal.h>
|
||||||
@ -2702,40 +2706,44 @@ static void ggml_metal_encode_node(
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ggml_metal_kargs_rope args = {
|
||||||
|
.ne00 = ne00,
|
||||||
|
.ne01 = ne01,
|
||||||
|
.ne02 = ne02,
|
||||||
|
.ne03 = ne03,
|
||||||
|
.nb00 = nb00,
|
||||||
|
.nb01 = nb01,
|
||||||
|
.nb02 = nb02,
|
||||||
|
.nb03 = nb03,
|
||||||
|
.ne0 = ne0,
|
||||||
|
.ne1 = ne1,
|
||||||
|
.ne2 = ne2,
|
||||||
|
.ne3 = ne3,
|
||||||
|
.nb0 = nb0,
|
||||||
|
.nb1 = nb1,
|
||||||
|
.nb2 = nb2,
|
||||||
|
.nb3 = nb3,
|
||||||
|
.n_past = n_past,
|
||||||
|
.n_dims = n_dims,
|
||||||
|
.n_ctx_orig = n_ctx_orig,
|
||||||
|
.freq_base = freq_base,
|
||||||
|
.freq_scale = freq_scale,
|
||||||
|
.ext_factor = ext_factor,
|
||||||
|
.attn_factor = attn_factor,
|
||||||
|
.beta_fast = beta_fast,
|
||||||
|
.beta_slow = beta_slow,
|
||||||
|
};
|
||||||
|
|
||||||
[encoder setComputePipelineState:pipeline];
|
[encoder setComputePipelineState:pipeline];
|
||||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||||
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
||||||
if (id_src2 != nil) {
|
if (id_src2 != nil) {
|
||||||
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
|
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
|
||||||
} else {
|
} else {
|
||||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:2];
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:2];
|
||||||
}
|
}
|
||||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:3];
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:3];
|
||||||
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:4];
|
[encoder setBytes:&args length:sizeof(args) atIndex:4];
|
||||||
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:5];
|
|
||||||
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:6];
|
|
||||||
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:7];
|
|
||||||
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:8];
|
|
||||||
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:9];
|
|
||||||
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:10];
|
|
||||||
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:11];
|
|
||||||
[encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:12];
|
|
||||||
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:13];
|
|
||||||
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:14];
|
|
||||||
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:15];
|
|
||||||
[encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:16];
|
|
||||||
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:17];
|
|
||||||
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:18];
|
|
||||||
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:19];
|
|
||||||
[encoder setBytes:&n_past length:sizeof( int) atIndex:20];
|
|
||||||
[encoder setBytes:&n_dims length:sizeof( int) atIndex:21];
|
|
||||||
[encoder setBytes:&n_ctx_orig length:sizeof( int) atIndex:22];
|
|
||||||
[encoder setBytes:&freq_base length:sizeof( float) atIndex:23];
|
|
||||||
[encoder setBytes:&freq_scale length:sizeof( float) atIndex:24];
|
|
||||||
[encoder setBytes:&ext_factor length:sizeof( float) atIndex:25];
|
|
||||||
[encoder setBytes:&attn_factor length:sizeof( float) atIndex:26];
|
|
||||||
[encoder setBytes:&beta_fast length:sizeof( float) atIndex:27];
|
|
||||||
[encoder setBytes:&beta_slow length:sizeof( float) atIndex:28];
|
|
||||||
|
|
||||||
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
||||||
} break;
|
} break;
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
#define GGML_COMMON_DECL_METAL
|
#define GGML_COMMON_DECL_METAL
|
||||||
|
#define GGML_COMMON_DECL_METAL_KARGS
|
||||||
#define GGML_COMMON_IMPL_METAL
|
#define GGML_COMMON_IMPL_METAL
|
||||||
#include "ggml-common.h"
|
#include "ggml-common.h"
|
||||||
|
|
||||||
@ -2229,7 +2230,7 @@ static float rope_yarn_ramp(const float low, const float high, const int i0) {
|
|||||||
// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
|
// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
|
||||||
// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
|
// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
|
||||||
static void rope_yarn(
|
static void rope_yarn(
|
||||||
float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale,
|
float theta_extrap, float freq_scale, float corr_dims[2], int i0, float ext_factor, float mscale,
|
||||||
thread float * cos_theta, thread float * sin_theta) {
|
thread float * cos_theta, thread float * sin_theta) {
|
||||||
// Get n-d rotational scaling corrected for extrapolation
|
// Get n-d rotational scaling corrected for extrapolation
|
||||||
float theta_interp = freq_scale * theta_extrap;
|
float theta_interp = freq_scale * theta_extrap;
|
||||||
@ -2261,65 +2262,41 @@ static void rope_yarn_corr_dims(
|
|||||||
|
|
||||||
template<typename T>
|
template<typename T>
|
||||||
kernel void kernel_rope_norm(
|
kernel void kernel_rope_norm(
|
||||||
device const void * src0,
|
device const char * src0,
|
||||||
device const int32_t * src1,
|
device const char * src1,
|
||||||
device const float * src2,
|
device const char * src2,
|
||||||
device float * dst,
|
device char * dst,
|
||||||
constant int64_t & ne00,
|
constant ggml_metal_kargs_rope & args,
|
||||||
constant int64_t & ne01,
|
|
||||||
constant int64_t & ne02,
|
|
||||||
constant int64_t & ne03,
|
|
||||||
constant uint64_t & nb00,
|
|
||||||
constant uint64_t & nb01,
|
|
||||||
constant uint64_t & nb02,
|
|
||||||
constant uint64_t & nb03,
|
|
||||||
constant int64_t & ne0,
|
|
||||||
constant int64_t & ne1,
|
|
||||||
constant int64_t & ne2,
|
|
||||||
constant int64_t & ne3,
|
|
||||||
constant uint64_t & nb0,
|
|
||||||
constant uint64_t & nb1,
|
|
||||||
constant uint64_t & nb2,
|
|
||||||
constant uint64_t & nb3,
|
|
||||||
constant int & n_past,
|
|
||||||
constant int & n_dims,
|
|
||||||
constant int & n_ctx_orig,
|
|
||||||
constant float & freq_base,
|
|
||||||
constant float & freq_scale,
|
|
||||||
constant float & ext_factor,
|
|
||||||
constant float & attn_factor,
|
|
||||||
constant float & beta_fast,
|
|
||||||
constant float & beta_slow,
|
|
||||||
uint tiitg[[thread_index_in_threadgroup]],
|
uint tiitg[[thread_index_in_threadgroup]],
|
||||||
uint3 tptg[[threads_per_threadgroup]],
|
uint3 tptg [[threads_per_threadgroup]],
|
||||||
uint3 tgpig[[threadgroup_position_in_grid]]) {
|
uint3 tgpig[[threadgroup_position_in_grid]]) {
|
||||||
const int64_t i3 = tgpig[2];
|
const int i3 = tgpig[2];
|
||||||
const int64_t i2 = tgpig[1];
|
const int i2 = tgpig[1];
|
||||||
const int64_t i1 = tgpig[0];
|
const int i1 = tgpig[0];
|
||||||
|
|
||||||
float corr_dims[2];
|
float corr_dims[2];
|
||||||
rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
|
rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims);
|
||||||
|
|
||||||
device const int32_t * pos = src1;
|
device const int32_t * pos = (device const int32_t *) src1;
|
||||||
|
|
||||||
const float theta_base = (float) pos[i2];
|
const float theta_base = (float) pos[i2];
|
||||||
const float inv_ndims = -1.f/n_dims;
|
const float inv_ndims = -1.f/args.n_dims;
|
||||||
|
|
||||||
float cos_theta;
|
float cos_theta;
|
||||||
float sin_theta;
|
float sin_theta;
|
||||||
|
|
||||||
for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
|
for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) {
|
||||||
if (i0 < n_dims) {
|
if (i0 < args.n_dims) {
|
||||||
const int64_t ic = i0/2;
|
const int ic = i0/2;
|
||||||
|
|
||||||
const float theta = theta_base * pow(freq_base, inv_ndims*i0);
|
const float theta = theta_base * pow(args.freq_base, inv_ndims*i0);
|
||||||
|
|
||||||
const float freq_factor = src2 != src0 ? src2[ic] : 1.0f;
|
const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f;
|
||||||
|
|
||||||
rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
|
rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);
|
||||||
|
|
||||||
device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00);
|
||||||
device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
|
||||||
|
|
||||||
const float x0 = src[0];
|
const float x0 = src[0];
|
||||||
const float x1 = src[1];
|
const float x1 = src[1];
|
||||||
@ -2327,8 +2304,8 @@ kernel void kernel_rope_norm(
|
|||||||
dst_data[0] = x0*cos_theta - x1*sin_theta;
|
dst_data[0] = x0*cos_theta - x1*sin_theta;
|
||||||
dst_data[1] = x0*sin_theta + x1*cos_theta;
|
dst_data[1] = x0*sin_theta + x1*cos_theta;
|
||||||
} else {
|
} else {
|
||||||
device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00);
|
||||||
device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
|
||||||
|
|
||||||
dst_data[0] = src[0];
|
dst_data[0] = src[0];
|
||||||
dst_data[1] = src[1];
|
dst_data[1] = src[1];
|
||||||
@ -2338,74 +2315,50 @@ kernel void kernel_rope_norm(
|
|||||||
|
|
||||||
template<typename T>
|
template<typename T>
|
||||||
kernel void kernel_rope_neox(
|
kernel void kernel_rope_neox(
|
||||||
device const void * src0,
|
device const char * src0,
|
||||||
device const int32_t * src1,
|
device const char * src1,
|
||||||
device const float * src2,
|
device const char * src2,
|
||||||
device float * dst,
|
device char * dst,
|
||||||
constant int64_t & ne00,
|
constant ggml_metal_kargs_rope & args,
|
||||||
constant int64_t & ne01,
|
|
||||||
constant int64_t & ne02,
|
|
||||||
constant int64_t & ne03,
|
|
||||||
constant uint64_t & nb00,
|
|
||||||
constant uint64_t & nb01,
|
|
||||||
constant uint64_t & nb02,
|
|
||||||
constant uint64_t & nb03,
|
|
||||||
constant int64_t & ne0,
|
|
||||||
constant int64_t & ne1,
|
|
||||||
constant int64_t & ne2,
|
|
||||||
constant int64_t & ne3,
|
|
||||||
constant uint64_t & nb0,
|
|
||||||
constant uint64_t & nb1,
|
|
||||||
constant uint64_t & nb2,
|
|
||||||
constant uint64_t & nb3,
|
|
||||||
constant int & n_past,
|
|
||||||
constant int & n_dims,
|
|
||||||
constant int & n_ctx_orig,
|
|
||||||
constant float & freq_base,
|
|
||||||
constant float & freq_scale,
|
|
||||||
constant float & ext_factor,
|
|
||||||
constant float & attn_factor,
|
|
||||||
constant float & beta_fast,
|
|
||||||
constant float & beta_slow,
|
|
||||||
uint tiitg[[thread_index_in_threadgroup]],
|
uint tiitg[[thread_index_in_threadgroup]],
|
||||||
uint3 tptg[[threads_per_threadgroup]],
|
uint3 tptg[[threads_per_threadgroup]],
|
||||||
uint3 tgpig[[threadgroup_position_in_grid]]) {
|
uint3 tgpig[[threadgroup_position_in_grid]]) {
|
||||||
const int64_t i3 = tgpig[2];
|
const int i3 = tgpig[2];
|
||||||
const int64_t i2 = tgpig[1];
|
const int i2 = tgpig[1];
|
||||||
const int64_t i1 = tgpig[0];
|
const int i1 = tgpig[0];
|
||||||
|
|
||||||
float corr_dims[2];
|
float corr_dims[2];
|
||||||
rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
|
rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims);
|
||||||
|
|
||||||
device const int32_t * pos = src1;
|
device const int32_t * pos = (device const int32_t *) src1;
|
||||||
|
|
||||||
const float theta_base = (float) pos[i2];
|
const float theta_base = (float) pos[i2];
|
||||||
const float inv_ndims = -1.f/n_dims;
|
const float inv_ndims = -1.f/args.n_dims;
|
||||||
|
|
||||||
float cos_theta;
|
float cos_theta;
|
||||||
float sin_theta;
|
float sin_theta;
|
||||||
|
|
||||||
for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
|
for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) {
|
||||||
if (i0 < n_dims) {
|
if (i0 < args.n_dims) {
|
||||||
const int64_t ic = i0/2;
|
const int ic = i0/2;
|
||||||
|
|
||||||
const float theta = theta_base * pow(freq_base, inv_ndims*i0);
|
const float theta = theta_base * pow(args.freq_base, inv_ndims*i0);
|
||||||
|
|
||||||
const float freq_factor = src2 != src0 ? src2[ic] : 1.0f;
|
const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f;
|
||||||
|
|
||||||
rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
|
rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);
|
||||||
|
|
||||||
device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
|
device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + ic*args.nb00);
|
||||||
device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
|
device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + ic*args.nb0);
|
||||||
|
|
||||||
const float x0 = src[0];
|
const float x0 = src[0];
|
||||||
const float x1 = src[n_dims/2];
|
const float x1 = src[args.n_dims/2];
|
||||||
|
|
||||||
dst_data[0] = x0*cos_theta - x1*sin_theta;
|
dst_data[0] = x0*cos_theta - x1*sin_theta;
|
||||||
dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
|
dst_data[args.n_dims/2] = x0*sin_theta + x1*cos_theta;
|
||||||
} else {
|
} else {
|
||||||
device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00);
|
||||||
device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
|
||||||
|
|
||||||
dst_data[0] = src[0];
|
dst_data[0] = src[0];
|
||||||
dst_data[1] = src[1];
|
dst_data[1] = src[1];
|
||||||
|
Loading…
Reference in New Issue
Block a user