mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-11-15 07:19:53 +00:00
metal : use log and exp instead of log1pf and expf in SSM_SCAN
This commit is contained in:
parent
87b97d08f4
commit
03d0e6eabe
@ -866,13 +866,13 @@ kernel void kernel_ssm_scan_f32(
|
||||
device const float * D = (device const float *) ((device const char *) src6 + ir*nb60); // {nh}
|
||||
device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*sizeof(float)); // {dim, nh, nt, ns}
|
||||
|
||||
const float dt_soft_plus = dt[0] <= 20.0f ? log1pf(expf(dt[0])) : dt[0];
|
||||
const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
|
||||
const float x_dt = x[0] * dt_soft_plus;
|
||||
float sumf = 0.0f;
|
||||
|
||||
for (int64_t i0 = 0; i0 < nc; ++i0) {
|
||||
const int64_t i = i0 + i1*nc;
|
||||
const float state = (s0[i] * expf(dt_soft_plus * A[i0])) + (B[i0] * x_dt);
|
||||
const float state = (s0[i] * exp(dt_soft_plus * A[i0])) + (B[i0] * x_dt);
|
||||
sumf += state * C[i0];
|
||||
s[i] = state;
|
||||
}
|
||||
@ -955,9 +955,9 @@ kernel void kernel_ssm_scan_f32_group(
|
||||
device const float * D = (device const float *) ((device const char *) src6 + ir*nb60); // {nh}
|
||||
device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*sizeof(float)); // {dim, nh, nt, ns}
|
||||
|
||||
const float dt_soft_plus = dt[0] <= 20.0f ? log1pf(expf(dt[0])) : dt[0];
|
||||
const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
|
||||
const float x_dt = x[0] * dt_soft_plus;
|
||||
const float dA = expf(dt_soft_plus * A[0]);
|
||||
const float dA = exp(dt_soft_plus * A[0]);
|
||||
float sumf = 0.0f;
|
||||
|
||||
for (int64_t i0 = 0; i0 < nc; ++i0) {
|
||||
|
Loading…
Reference in New Issue
Block a user