modify unnecessary calculations

This commit is contained in:
lihan 2024-12-04 14:33:55 +08:00
parent 828e4f72a7
commit e52a22d8d8

View File

@ -63,26 +63,25 @@ __global__ void __launch_bounds__(splitD, 2)
__syncthreads(); __syncthreads();
for (int i = 0; i < L; i++) { for (int i = 0; i < L; i++) {
float dt_soft_plus = dt_block[i * stride_dt + wid * warpSize + wtid]; float dt_soft_plus = dt_block[i * stride_dt + tid];
if (dt_soft_plus <= 20.0f) { if (dt_soft_plus <= 20.0f) {
dt_soft_plus = log1pf(exp(dt_soft_plus)); dt_soft_plus = log1pf(exp(dt_soft_plus));
} }
float x_dt = x_block[i * stride_x + wid * warpSize + wtid] * dt_soft_plus; float x_dt = x_block[i * stride_x + tid] * dt_soft_plus;
float sumf = 0.0f; float sumf = 0.0f;
#pragma unroll #pragma unroll
for (int j = 0; j < N; j++) { for (int j = 0; j < N; j++) {
float state = (smem_s0[(wid * warpSize + wtid) * stride_ss0 + j] * float state = (smem_s0[tid * stride_ss0 + j] * expf(dt_soft_plus * smem_A[tid * stride_sA + j])) +
expf(dt_soft_plus * smem_A[(wid * warpSize + wtid) * stride_sA + j])) +
(B_block[i * stride_B + j] * x_dt); (B_block[i * stride_B + j] * x_dt);
sumf += state * C_block[i * stride_C + j]; sumf += state * C_block[i * stride_C + j];
if (i == L - 1) { if (i == L - 1) {
s_block[(wid * warpSize + wtid) * stride_s + j] = state; s_block[tid * stride_s + j] = state;
} else { } else {
smem_s0[(wid * warpSize + wtid) * stride_ss0 + j] = state; smem_s0[tid * stride_ss0 + j] = state;
} }
} }
__syncthreads(); __syncthreads();
y_block[i * stride_y + wid * warpSize + wtid] = sumf; y_block[i * stride_y + tid] = sumf;
} }
} }