mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-12 03:31:46 +00:00
modify unnecessary calculations
This commit is contained in:
parent
828e4f72a7
commit
e52a22d8d8
@ -63,26 +63,25 @@ __global__ void __launch_bounds__(splitD, 2)
|
||||
__syncthreads();
|
||||
|
||||
for (int i = 0; i < L; i++) {
|
||||
float dt_soft_plus = dt_block[i * stride_dt + wid * warpSize + wtid];
|
||||
float dt_soft_plus = dt_block[i * stride_dt + tid];
|
||||
if (dt_soft_plus <= 20.0f) {
|
||||
dt_soft_plus = log1pf(exp(dt_soft_plus));
|
||||
}
|
||||
float x_dt = x_block[i * stride_x + wid * warpSize + wtid] * dt_soft_plus;
|
||||
float x_dt = x_block[i * stride_x + tid] * dt_soft_plus;
|
||||
float sumf = 0.0f;
|
||||
#pragma unroll
|
||||
for (int j = 0; j < N; j++) {
|
||||
float state = (smem_s0[(wid * warpSize + wtid) * stride_ss0 + j] *
|
||||
expf(dt_soft_plus * smem_A[(wid * warpSize + wtid) * stride_sA + j])) +
|
||||
float state = (smem_s0[tid * stride_ss0 + j] * expf(dt_soft_plus * smem_A[tid * stride_sA + j])) +
|
||||
(B_block[i * stride_B + j] * x_dt);
|
||||
sumf += state * C_block[i * stride_C + j];
|
||||
if (i == L - 1) {
|
||||
s_block[(wid * warpSize + wtid) * stride_s + j] = state;
|
||||
s_block[tid * stride_s + j] = state;
|
||||
} else {
|
||||
smem_s0[(wid * warpSize + wtid) * stride_ss0 + j] = state;
|
||||
smem_s0[tid * stride_ss0 + j] = state;
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
y_block[i * stride_y + wid * warpSize + wtid] = sumf;
|
||||
y_block[i * stride_y + tid] = sumf;
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user