metal : use log and exp instead of log1pf and expf in SSM_SCAN

This commit is contained in:
Francis Couture-Harpin 2024-10-02 10:58:41 -04:00
parent 87b97d08f4
commit 03d0e6eabe

View File

@@ -866,13 +866,13 @@ kernel void kernel_ssm_scan_f32(
device const float * D = (device const float *) ((device const char *) src6 + ir*nb60); // {nh} device const float * D = (device const float *) ((device const char *) src6 + ir*nb60); // {nh}
device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*sizeof(float)); // {dim, nh, nt, ns} device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*sizeof(float)); // {dim, nh, nt, ns}
const float dt_soft_plus = dt[0] <= 20.0f ? log1pf(expf(dt[0])) : dt[0]; const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
const float x_dt = x[0] * dt_soft_plus; const float x_dt = x[0] * dt_soft_plus;
float sumf = 0.0f; float sumf = 0.0f;
for (int64_t i0 = 0; i0 < nc; ++i0) { for (int64_t i0 = 0; i0 < nc; ++i0) {
const int64_t i = i0 + i1*nc; const int64_t i = i0 + i1*nc;
const float state = (s0[i] * expf(dt_soft_plus * A[i0])) + (B[i0] * x_dt); const float state = (s0[i] * exp(dt_soft_plus * A[i0])) + (B[i0] * x_dt);
sumf += state * C[i0]; sumf += state * C[i0];
s[i] = state; s[i] = state;
} }
@@ -955,9 +955,9 @@ kernel void kernel_ssm_scan_f32_group(
device const float * D = (device const float *) ((device const char *) src6 + ir*nb60); // {nh} device const float * D = (device const float *) ((device const char *) src6 + ir*nb60); // {nh}
device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*sizeof(float)); // {dim, nh, nt, ns} device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*sizeof(float)); // {dim, nh, nt, ns}
const float dt_soft_plus = dt[0] <= 20.0f ? log1pf(expf(dt[0])) : dt[0]; const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
const float x_dt = x[0] * dt_soft_plus; const float x_dt = x[0] * dt_soft_plus;
const float dA = expf(dt_soft_plus * A[0]); const float dA = exp(dt_soft_plus * A[0]);
float sumf = 0.0f; float sumf = 0.0f;
for (int64_t i0 = 0; i0 < nc; ++i0) { for (int64_t i0 = 0; i0 < nc; ++i0) {