mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-11-15 07:19:53 +00:00
metal : use log and exp instead of log1pf and expf in SSM_SCAN
This commit is contained in:
parent
87b97d08f4
commit
03d0e6eabe
@ -866,13 +866,13 @@ kernel void kernel_ssm_scan_f32(
|
|||||||
device const float * D = (device const float *) ((device const char *) src6 + ir*nb60); // {nh}
|
device const float * D = (device const float *) ((device const char *) src6 + ir*nb60); // {nh}
|
||||||
device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*sizeof(float)); // {dim, nh, nt, ns}
|
device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*sizeof(float)); // {dim, nh, nt, ns}
|
||||||
|
|
||||||
const float dt_soft_plus = dt[0] <= 20.0f ? log1pf(expf(dt[0])) : dt[0];
|
const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
|
||||||
const float x_dt = x[0] * dt_soft_plus;
|
const float x_dt = x[0] * dt_soft_plus;
|
||||||
float sumf = 0.0f;
|
float sumf = 0.0f;
|
||||||
|
|
||||||
for (int64_t i0 = 0; i0 < nc; ++i0) {
|
for (int64_t i0 = 0; i0 < nc; ++i0) {
|
||||||
const int64_t i = i0 + i1*nc;
|
const int64_t i = i0 + i1*nc;
|
||||||
const float state = (s0[i] * expf(dt_soft_plus * A[i0])) + (B[i0] * x_dt);
|
const float state = (s0[i] * exp(dt_soft_plus * A[i0])) + (B[i0] * x_dt);
|
||||||
sumf += state * C[i0];
|
sumf += state * C[i0];
|
||||||
s[i] = state;
|
s[i] = state;
|
||||||
}
|
}
|
||||||
@ -955,9 +955,9 @@ kernel void kernel_ssm_scan_f32_group(
|
|||||||
device const float * D = (device const float *) ((device const char *) src6 + ir*nb60); // {nh}
|
device const float * D = (device const float *) ((device const char *) src6 + ir*nb60); // {nh}
|
||||||
device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*sizeof(float)); // {dim, nh, nt, ns}
|
device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*sizeof(float)); // {dim, nh, nt, ns}
|
||||||
|
|
||||||
const float dt_soft_plus = dt[0] <= 20.0f ? log1pf(expf(dt[0])) : dt[0];
|
const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
|
||||||
const float x_dt = x[0] * dt_soft_plus;
|
const float x_dt = x[0] * dt_soft_plus;
|
||||||
const float dA = expf(dt_soft_plus * A[0]);
|
const float dA = exp(dt_soft_plus * A[0]);
|
||||||
float sumf = 0.0f;
|
float sumf = 0.0f;
|
||||||
|
|
||||||
for (int64_t i0 = 0; i0 < nc; ++i0) {
|
for (int64_t i0 = 0; i0 < nc; ++i0) {
|
||||||
|
Loading…
Reference in New Issue
Block a user