From 03d0e6eabe6172a56a7d470bfd844012f2c2b291 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Wed, 2 Oct 2024 10:58:41 -0400 Subject: [PATCH] metal : use log and exp instead of log1pf and expf in SSM_SCAN --- ggml/src/ggml-metal.metal | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index c75fa25c3..cee9980a7 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -866,13 +866,13 @@ kernel void kernel_ssm_scan_f32( device const float * D = (device const float *) ((device const char *) src6 + ir*nb60); // {nh} device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*sizeof(float)); // {dim, nh, nt, ns} - const float dt_soft_plus = dt[0] <= 20.0f ? log1pf(expf(dt[0])) : dt[0]; + const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0]; const float x_dt = x[0] * dt_soft_plus; float sumf = 0.0f; for (int64_t i0 = 0; i0 < nc; ++i0) { const int64_t i = i0 + i1*nc; - const float state = (s0[i] * expf(dt_soft_plus * A[i0])) + (B[i0] * x_dt); + const float state = (s0[i] * exp(dt_soft_plus * A[i0])) + (B[i0] * x_dt); sumf += state * C[i0]; s[i] = state; } @@ -955,9 +955,9 @@ kernel void kernel_ssm_scan_f32_group( device const float * D = (device const float *) ((device const char *) src6 + ir*nb60); // {nh} device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*sizeof(float)); // {dim, nh, nt, ns} - const float dt_soft_plus = dt[0] <= 20.0f ? log1pf(expf(dt[0])) : dt[0]; + const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0]; const float x_dt = x[0] * dt_soft_plus; - const float dA = expf(dt_soft_plus * A[0]); + const float dA = exp(dt_soft_plus * A[0]); float sumf = 0.0f; for (int64_t i0 = 0; i0 < nc; ++i0) {