mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-31 22:04:35 +00:00
metal : add GGML_UNARY_OP_ELU
kernel (ggml/1018)
This commit is contained in:
parent
2a11b6b094
commit
4ac059838d
@ -126,6 +126,7 @@ enum ggml_metal_kernel_type {
|
|||||||
GGML_METAL_KERNEL_TYPE_GELU_QUICK_4,
|
GGML_METAL_KERNEL_TYPE_GELU_QUICK_4,
|
||||||
GGML_METAL_KERNEL_TYPE_SILU,
|
GGML_METAL_KERNEL_TYPE_SILU,
|
||||||
GGML_METAL_KERNEL_TYPE_SILU_4,
|
GGML_METAL_KERNEL_TYPE_SILU_4,
|
||||||
|
GGML_METAL_KERNEL_TYPE_ELU,
|
||||||
GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16,
|
GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16,
|
||||||
GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4,
|
GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4,
|
||||||
GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32,
|
GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32,
|
||||||
@ -649,6 +650,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
|||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, gelu_quick_4, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, gelu_quick_4, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_4, silu_4, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_4, silu_4, true);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ELU, elu, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16, soft_max_f16, has_simdgroup_reduction);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16, soft_max_f16, has_simdgroup_reduction);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4, soft_max_f16_4, has_simdgroup_reduction);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4, soft_max_f16_4, has_simdgroup_reduction);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32, soft_max_f32, has_simdgroup_reduction);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32, soft_max_f32, has_simdgroup_reduction);
|
||||||
@ -968,6 +970,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
|
|||||||
case GGML_UNARY_OP_GELU:
|
case GGML_UNARY_OP_GELU:
|
||||||
case GGML_UNARY_OP_GELU_QUICK:
|
case GGML_UNARY_OP_GELU_QUICK:
|
||||||
case GGML_UNARY_OP_SILU:
|
case GGML_UNARY_OP_SILU:
|
||||||
|
case GGML_UNARY_OP_ELU:
|
||||||
return ggml_is_contiguous(op->src[0]);
|
return ggml_is_contiguous(op->src[0]);
|
||||||
default:
|
default:
|
||||||
return false;
|
return false;
|
||||||
@ -1589,6 +1592,18 @@ static void ggml_metal_encode_node(
|
|||||||
|
|
||||||
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||||
} break;
|
} break;
|
||||||
|
case GGML_UNARY_OP_ELU:
|
||||||
|
{
|
||||||
|
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ELU].pipeline;
|
||||||
|
|
||||||
|
[encoder setComputePipelineState:pipeline];
|
||||||
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||||
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||||
|
|
||||||
|
const int64_t n = ggml_nelements(dst);
|
||||||
|
|
||||||
|
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||||
|
} break;
|
||||||
default:
|
default:
|
||||||
{
|
{
|
||||||
GGML_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(dst->op));
|
GGML_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(dst->op));
|
||||||
|
@ -782,6 +782,14 @@ kernel void kernel_silu_4(
|
|||||||
dst[tpig] = x / (1.0f + exp(-x));
|
dst[tpig] = x / (1.0f + exp(-x));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
kernel void kernel_elu(
|
||||||
|
device const float * src0,
|
||||||
|
device float * dst,
|
||||||
|
uint tpig[[thread_position_in_grid]]) {
|
||||||
|
device const float & x = src0[tpig];
|
||||||
|
dst[tpig] = (x > 0.0f) ? x : (exp(x) - 1.0f);
|
||||||
|
}
|
||||||
|
|
||||||
kernel void kernel_sqr(
|
kernel void kernel_sqr(
|
||||||
device const float * src0,
|
device const float * src0,
|
||||||
device float * dst,
|
device float * dst,
|
||||||
|
Loading…
Reference in New Issue
Block a user