mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-13 12:10:18 +00:00
parent
82e4f64578
commit
ab558ac2b3
@ -1269,6 +1269,8 @@ void ggml_metal_graph_compute(
|
|||||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||||
if (id_src1) {
|
if (id_src1) {
|
||||||
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
||||||
|
} else {
|
||||||
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
||||||
}
|
}
|
||||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
||||||
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
|
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
|
||||||
@ -1520,7 +1522,7 @@ void ggml_metal_graph_compute(
|
|||||||
else if (src0t == GGML_TYPE_Q6_K) {
|
else if (src0t == GGML_TYPE_Q6_K) {
|
||||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||||
} else {
|
} else {
|
||||||
int64_t ny = (ne11 + nrows - 1)/nrows;
|
const int64_t ny = (ne11 + nrows - 1)/nrows;
|
||||||
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -348,7 +348,7 @@ kernel void kernel_soft_max(
|
|||||||
const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
|
const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
|
||||||
|
|
||||||
device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
||||||
device const float * pmask = src1 ? src1 + i01*ne00 : nullptr;
|
device const float * pmask = src1 != src0 ? src1 + i01*ne00 : nullptr;
|
||||||
device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
||||||
|
|
||||||
// parallel max
|
// parallel max
|
||||||
@ -386,6 +386,8 @@ kernel void kernel_soft_max(
|
|||||||
}
|
}
|
||||||
|
|
||||||
float sum = simd_sum(lsum);
|
float sum = simd_sum(lsum);
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
if (ntg > N_SIMDWIDTH) {
|
if (ntg > N_SIMDWIDTH) {
|
||||||
if (sgitg == 0) {
|
if (sgitg == 0) {
|
||||||
buf[tiisg] = 0.0f;
|
buf[tiisg] = 0.0f;
|
||||||
@ -429,7 +431,7 @@ kernel void kernel_soft_max_4(
|
|||||||
const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
|
const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
|
||||||
|
|
||||||
device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
||||||
device const float4 * pmask = src1 ? (device const float4 *)(src1 + i01*ne00) : nullptr;
|
device const float4 * pmask = src1 != src0 ? (device const float4 *)(src1 + i01*ne00) : nullptr;
|
||||||
device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
||||||
|
|
||||||
// parallel max
|
// parallel max
|
||||||
@ -468,6 +470,8 @@ kernel void kernel_soft_max_4(
|
|||||||
}
|
}
|
||||||
|
|
||||||
const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
|
const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
float sum = simd_sum(lsum);
|
float sum = simd_sum(lsum);
|
||||||
if (ntg > N_SIMDWIDTH) {
|
if (ntg > N_SIMDWIDTH) {
|
||||||
if (sgitg == 0) {
|
if (sgitg == 0) {
|
||||||
|
Loading…
Reference in New Issue
Block a user