mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-25 02:44:36 +00:00
metal : parallel RoPE on Metal (#3024)
* Parallel RoPE on metal * PR suggestion --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
parent
be6beeb8d7
commit
be8c9c245b
@ -1141,7 +1141,7 @@ void ggml_metal_graph_compute(
|
|||||||
[encoder setBytes:&freq_base length:sizeof(float) atIndex:21];
|
[encoder setBytes:&freq_base length:sizeof(float) atIndex:21];
|
||||||
[encoder setBytes:&freq_scale length:sizeof(float) atIndex:22];
|
[encoder setBytes:&freq_scale length:sizeof(float) atIndex:22];
|
||||||
|
|
||||||
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
|
||||||
} break;
|
} break;
|
||||||
case GGML_OP_DUP:
|
case GGML_OP_DUP:
|
||||||
case GGML_OP_CPY:
|
case GGML_OP_CPY:
|
||||||
|
@ -682,25 +682,27 @@ kernel void kernel_rope(
|
|||||||
constant int & mode,
|
constant int & mode,
|
||||||
constant float & freq_base,
|
constant float & freq_base,
|
||||||
constant float & freq_scale,
|
constant float & freq_scale,
|
||||||
uint3 tpig[[thread_position_in_grid]]) {
|
uint tiitg[[thread_index_in_threadgroup]],
|
||||||
const int64_t i3 = tpig[2];
|
uint3 tptg[[threads_per_threadgroup]],
|
||||||
const int64_t i2 = tpig[1];
|
uint3 tgpig[[threadgroup_position_in_grid]]) {
|
||||||
const int64_t i1 = tpig[0];
|
const int64_t i3 = tgpig[2];
|
||||||
|
const int64_t i2 = tgpig[1];
|
||||||
|
const int64_t i1 = tgpig[0];
|
||||||
|
|
||||||
const bool is_neox = mode & 2;
|
const bool is_neox = mode & 2;
|
||||||
const float theta_scale = pow(freq_base, -2.0f/n_dims);
|
|
||||||
|
|
||||||
const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2);
|
const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2);
|
||||||
|
|
||||||
float theta = freq_scale * (float)p;
|
const float theta_0 = freq_scale * (float)p;
|
||||||
|
const float inv_ndims = -1.f/n_dims;
|
||||||
|
|
||||||
if (!is_neox) {
|
if (!is_neox) {
|
||||||
for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
|
for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
|
||||||
|
|
||||||
|
const float theta = theta_0 * pow(freq_base, inv_ndims*i0);
|
||||||
const float cos_theta = cos(theta);
|
const float cos_theta = cos(theta);
|
||||||
const float sin_theta = sin(theta);
|
const float sin_theta = sin(theta);
|
||||||
|
|
||||||
theta *= theta_scale;
|
|
||||||
|
|
||||||
device const float * const src = (device float *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
device const float * const src = (device float *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
||||||
device float * dst_data = (device float *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
device float * dst_data = (device float *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
||||||
|
|
||||||
@ -712,12 +714,12 @@ kernel void kernel_rope(
|
|||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
|
for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
|
||||||
for (int64_t ic = 0; ic < n_dims; ic += 2) {
|
for (int64_t ic = 2*tiitg; ic < n_dims; ic += 2*tptg.x) {
|
||||||
|
|
||||||
|
const float theta = theta_0 * pow(freq_base, inv_ndims*ic - ib);
|
||||||
const float cos_theta = cos(theta);
|
const float cos_theta = cos(theta);
|
||||||
const float sin_theta = sin(theta);
|
const float sin_theta = sin(theta);
|
||||||
|
|
||||||
theta *= theta_scale;
|
|
||||||
|
|
||||||
const int64_t i0 = ib*n_dims + ic/2;
|
const int64_t i0 = ib*n_dims + ic/2;
|
||||||
|
|
||||||
device const float * const src = (device float *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
device const float * const src = (device float *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
||||||
|
Loading…
Reference in New Issue
Block a user