feat: implemented sigmoid function (ggml/806)

* added sigmoid function

* implemented metal kernel for sigmoid

* implemented cuda kernel for sigmoid

* added sigmoid unary op and incremented count
This commit is contained in:
Justina Cho 2024-05-01 14:44:26 -07:00 committed by Georgi Gerganov
parent ef0d5e3ec9
commit f5ef34e428
7 changed files with 136 additions and 1 deletions

View File

@ -2204,6 +2204,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_UNARY_OP_RELU: case GGML_UNARY_OP_RELU:
ggml_cuda_op_relu(ctx, dst); ggml_cuda_op_relu(ctx, dst);
break; break;
case GGML_UNARY_OP_SIGMOID:
ggml_cuda_op_sigmoid(ctx, dst);
break;
case GGML_UNARY_OP_HARDSIGMOID: case GGML_UNARY_OP_HARDSIGMOID:
ggml_cuda_op_hardsigmoid(ctx, dst); ggml_cuda_op_hardsigmoid(ctx, dst);
break; break;
@ -2716,6 +2719,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
case GGML_UNARY_OP_GELU: case GGML_UNARY_OP_GELU:
case GGML_UNARY_OP_SILU: case GGML_UNARY_OP_SILU:
case GGML_UNARY_OP_RELU: case GGML_UNARY_OP_RELU:
case GGML_UNARY_OP_SIGMOID:
case GGML_UNARY_OP_HARDSIGMOID: case GGML_UNARY_OP_HARDSIGMOID:
case GGML_UNARY_OP_HARDSWISH: case GGML_UNARY_OP_HARDSWISH:
case GGML_UNARY_OP_GELU_QUICK: case GGML_UNARY_OP_GELU_QUICK:

View File

@ -48,6 +48,15 @@ static __global__ void relu_f32(const float * x, float * dst, const int k) {
dst[i] = fmaxf(x[i], 0); dst[i] = fmaxf(x[i], 0);
} }
static __global__ void sigmoid_f32(const float * x, float * dst, const int k) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;
if (i >= k) {
return;
}
dst[i] = 1.0f / (1.0f + expf(-x[i]));
}
static __global__ void hardsigmoid_f32(const float * x, float * dst, const int k) { static __global__ void hardsigmoid_f32(const float * x, float * dst, const int k) {
const int i = blockDim.x*blockIdx.x + threadIdx.x; const int i = blockDim.x*blockIdx.x + threadIdx.x;
@ -108,6 +117,11 @@ static void relu_f32_cuda(const float * x, float * dst, const int k, cudaStream_
relu_f32<<<num_blocks, CUDA_RELU_BLOCK_SIZE, 0, stream>>>(x, dst, k); relu_f32<<<num_blocks, CUDA_RELU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
} }
static void sigmoid_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_SIGMOID_BLOCK_SIZE - 1) / CUDA_SIGMOID_BLOCK_SIZE;
sigmoid_f32<<<num_blocks, CUDA_SIGMOID_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}
static void hardsigmoid_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) { static void hardsigmoid_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_HARDSIGMOID_BLOCK_SIZE - 1) / CUDA_HARDSIGMOID_BLOCK_SIZE; const int num_blocks = (k + CUDA_HARDSIGMOID_BLOCK_SIZE - 1) / CUDA_HARDSIGMOID_BLOCK_SIZE;
hardsigmoid_f32<<<num_blocks, CUDA_HARDSIGMOID_BLOCK_SIZE, 0, stream>>>(x, dst, k); hardsigmoid_f32<<<num_blocks, CUDA_HARDSIGMOID_BLOCK_SIZE, 0, stream>>>(x, dst, k);
@ -188,6 +202,18 @@ void ggml_cuda_op_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
relu_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream); relu_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
} }
void ggml_cuda_op_sigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *)src0->data;
float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
sigmoid_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
}
void ggml_cuda_op_hardsigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { void ggml_cuda_op_hardsigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0]; const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *)src0->data; const float * src0_d = (const float *)src0->data;

View File

@ -4,6 +4,7 @@
#define CUDA_SILU_BLOCK_SIZE 256 #define CUDA_SILU_BLOCK_SIZE 256
#define CUDA_TANH_BLOCK_SIZE 256 #define CUDA_TANH_BLOCK_SIZE 256
#define CUDA_RELU_BLOCK_SIZE 256 #define CUDA_RELU_BLOCK_SIZE 256
#define CUDA_SIGMOID_BLOCK_SIZE 256
#define CUDA_HARDSIGMOID_BLOCK_SIZE 256 #define CUDA_HARDSIGMOID_BLOCK_SIZE 256
#define CUDA_HARDSWISH_BLOCK_SIZE 256 #define CUDA_HARDSWISH_BLOCK_SIZE 256
#define CUDA_SQR_BLOCK_SIZE 256 #define CUDA_SQR_BLOCK_SIZE 256
@ -18,6 +19,8 @@ void ggml_cuda_op_tanh(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_sigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_hardsigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_hardsigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_hardswish(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_hardswish(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

View File

@ -40,6 +40,7 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_CLAMP, GGML_METAL_KERNEL_TYPE_CLAMP,
GGML_METAL_KERNEL_TYPE_TANH, GGML_METAL_KERNEL_TYPE_TANH,
GGML_METAL_KERNEL_TYPE_RELU, GGML_METAL_KERNEL_TYPE_RELU,
GGML_METAL_KERNEL_TYPE_SIGMOID,
GGML_METAL_KERNEL_TYPE_GELU, GGML_METAL_KERNEL_TYPE_GELU,
GGML_METAL_KERNEL_TYPE_GELU_4, GGML_METAL_KERNEL_TYPE_GELU_4,
GGML_METAL_KERNEL_TYPE_GELU_QUICK, GGML_METAL_KERNEL_TYPE_GELU_QUICK,
@ -493,6 +494,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CLAMP, clamp, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CLAMP, clamp, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TANH, tanh, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TANH, tanh, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU, relu, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU, relu, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIGMOID, sigmoid, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU, gelu, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU, gelu, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_4, gelu_4, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_4, gelu_4, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK, gelu_quick, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK, gelu_quick, true);
@ -730,6 +732,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
switch (ggml_get_unary_op(op)) { switch (ggml_get_unary_op(op)) {
case GGML_UNARY_OP_TANH: case GGML_UNARY_OP_TANH:
case GGML_UNARY_OP_RELU: case GGML_UNARY_OP_RELU:
case GGML_UNARY_OP_SIGMOID:
case GGML_UNARY_OP_GELU: case GGML_UNARY_OP_GELU:
case GGML_UNARY_OP_GELU_QUICK: case GGML_UNARY_OP_GELU_QUICK:
case GGML_UNARY_OP_SILU: case GGML_UNARY_OP_SILU:
@ -1237,6 +1240,18 @@ static enum ggml_status ggml_metal_graph_compute(
const int64_t n = ggml_nelements(dst); const int64_t n = ggml_nelements(dst);
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
case GGML_UNARY_OP_SIGMOID:
{
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SIGMOID].pipeline;
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
const int64_t n = ggml_nelements(dst);
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break; } break;
case GGML_UNARY_OP_GELU: case GGML_UNARY_OP_GELU:

View File

@ -229,6 +229,13 @@ kernel void kernel_relu(
dst[tpig] = max(0.0f, src0[tpig]); dst[tpig] = max(0.0f, src0[tpig]);
} }
kernel void kernel_sigmoid(
device const float * src0,
device float * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = 1.0f / (1.0f + exp(-src0[tpig]));
}
kernel void kernel_tanh( kernel void kernel_tanh(
device const float * src0, device const float * src0,
device float * dst, device float * dst,

73
ggml.c
View File

@ -1949,6 +1949,7 @@ inline static void ggml_vec_tanh_f32 (const int n, float * y, const float * x) {
inline static void ggml_vec_elu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : expf(x[i])-1; } inline static void ggml_vec_elu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : expf(x[i])-1; }
inline static void ggml_vec_relu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : 0.f; } inline static void ggml_vec_relu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : 0.f; }
inline static void ggml_vec_leaky_relu_f32 (const int n, float * y, const float * x, const float ns) { for (int i = 0; i < n; ++i) y[i] = ((x[i] > 0.f) ? x[i] : 0.f) + ns * ((x[i] < 0.0f) ? x[i] : 0.f); } inline static void ggml_vec_leaky_relu_f32 (const int n, float * y, const float * x, const float ns) { for (int i = 0; i < n; ++i) y[i] = ((x[i] > 0.f) ? x[i] : 0.f) + ns * ((x[i] < 0.0f) ? x[i] : 0.f); }
inline static void ggml_vec_sigmoid_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = 1.f / (1.f + expf(-x[i])); }
// TODO: optimize performance // TODO: optimize performance
inline static void ggml_vec_hardswish_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i] * fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f)); } inline static void ggml_vec_hardswish_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i] * fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f)); }
inline static void ggml_vec_hardsigmoid_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f)); } inline static void ggml_vec_hardsigmoid_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f)); }
@ -2329,6 +2330,7 @@ static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = {
"TANH", "TANH",
"ELU", "ELU",
"RELU", "RELU",
"SIGMOID",
"GELU", "GELU",
"GELU_QUICK", "GELU_QUICK",
"SILU", "SILU",
@ -2336,7 +2338,7 @@ static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = {
"HARDSIGMOID", "HARDSIGMOID",
}; };
static_assert(GGML_UNARY_OP_COUNT == 12, "GGML_UNARY_OP_COUNT != 12"); static_assert(GGML_UNARY_OP_COUNT == 13, "GGML_UNARY_OP_COUNT != 13");
static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN"); static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
@ -4561,6 +4563,20 @@ struct ggml_tensor * ggml_leaky_relu(
return result; return result;
} }
// ggml_sigmoid
struct ggml_tensor * ggml_sigmoid(
struct ggml_context * ctx,
struct ggml_tensor * a) {
return ggml_unary(ctx, a, GGML_UNARY_OP_SIGMOID);
}
struct ggml_tensor * ggml_sigmoid_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a) {
return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_SIGMOID);
}
// ggml_gelu // ggml_gelu
struct ggml_tensor * ggml_gelu( struct ggml_tensor * ggml_gelu(
@ -10852,6 +10868,52 @@ static void ggml_compute_forward_relu(
} }
} }
// ggml_compute_forward_sigmoid
static void ggml_compute_forward_sigmoid_f32(
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {
const struct ggml_tensor * src0 = dst->src[0];
assert(params->ith == 0);
assert(ggml_are_same_shape(src0, dst));
if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
return;
}
const int n = ggml_nrows(src0);
const int nc = src0->ne[0];
assert(dst->nb[0] == sizeof(float));
assert(src0->nb[0] == sizeof(float));
for (int i = 0; i < n; i++) {
ggml_vec_sigmoid_f32(nc,
(float *) ((char *) dst->data + i*( dst->nb[1])),
(float *) ((char *) src0->data + i*(src0->nb[1])));
}
}
static void ggml_compute_forward_sigmoid(
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {
const struct ggml_tensor * src0 = dst->src[0];
switch (src0->type) {
case GGML_TYPE_F32:
{
ggml_compute_forward_sigmoid_f32(params, dst);
} break;
default:
{
GGML_ASSERT(false);
} break;
}
}
// ggml_compute_forward_gelu // ggml_compute_forward_gelu
static void ggml_compute_forward_gelu_f32( static void ggml_compute_forward_gelu_f32(
@ -16617,6 +16679,10 @@ static void ggml_compute_forward_unary(
{ {
ggml_compute_forward_relu(params, dst); ggml_compute_forward_relu(params, dst);
} break; } break;
case GGML_UNARY_OP_SIGMOID:
{
ggml_compute_forward_sigmoid(params, dst);
} break;
case GGML_UNARY_OP_GELU: case GGML_UNARY_OP_GELU:
{ {
ggml_compute_forward_gelu(params, dst); ggml_compute_forward_gelu(params, dst);
@ -18601,6 +18667,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
zero_table); zero_table);
} }
} break; } break;
case GGML_UNARY_OP_SIGMOID:
{
GGML_ASSERT(false); // TODO: not implemented
} break;
case GGML_UNARY_OP_GELU: case GGML_UNARY_OP_GELU:
{ {
GGML_ASSERT(false); // TODO: not implemented GGML_ASSERT(false); // TODO: not implemented
@ -19130,6 +19200,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads, int n_cur_
case GGML_UNARY_OP_TANH: case GGML_UNARY_OP_TANH:
case GGML_UNARY_OP_ELU: case GGML_UNARY_OP_ELU:
case GGML_UNARY_OP_RELU: case GGML_UNARY_OP_RELU:
case GGML_UNARY_OP_SIGMOID:
case GGML_UNARY_OP_HARDSWISH: // to opt for multiple threads case GGML_UNARY_OP_HARDSWISH: // to opt for multiple threads
case GGML_UNARY_OP_HARDSIGMOID: // to opt for multiple threads case GGML_UNARY_OP_HARDSIGMOID: // to opt for multiple threads
{ {

9
ggml.h
View File

@ -519,6 +519,7 @@ extern "C" {
GGML_UNARY_OP_TANH, GGML_UNARY_OP_TANH,
GGML_UNARY_OP_ELU, GGML_UNARY_OP_ELU,
GGML_UNARY_OP_RELU, GGML_UNARY_OP_RELU,
GGML_UNARY_OP_SIGMOID,
GGML_UNARY_OP_GELU, GGML_UNARY_OP_GELU,
GGML_UNARY_OP_GELU_QUICK, GGML_UNARY_OP_GELU_QUICK,
GGML_UNARY_OP_SILU, GGML_UNARY_OP_SILU,
@ -1073,6 +1074,14 @@ extern "C" {
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * a); struct ggml_tensor * a);
GGML_API struct ggml_tensor * ggml_sigmoid(
struct ggml_context * ctx,
struct ggml_tensor * a);
GGML_API struct ggml_tensor * ggml_sigmoid_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a);
GGML_API struct ggml_tensor * ggml_gelu( GGML_API struct ggml_tensor * ggml_gelu(
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * a); struct ggml_tensor * a);