diff --git a/ggml/src/ggml-cuda/ssm_scan.cu b/ggml/src/ggml-cuda/ssm_scan.cu index 357d3eba5..25dfd1ead 100644 --- a/ggml/src/ggml-cuda/ssm_scan.cu +++ b/ggml/src/ggml-cuda/ssm_scan.cu @@ -63,26 +63,25 @@ __global__ void __launch_bounds__(splitD, 2) __syncthreads(); for (int i = 0; i < L; i++) { - float dt_soft_plus = dt_block[i * stride_dt + wid * warpSize + wtid]; + float dt_soft_plus = dt_block[i * stride_dt + tid]; if (dt_soft_plus <= 20.0f) { dt_soft_plus = log1pf(exp(dt_soft_plus)); } - float x_dt = x_block[i * stride_x + wid * warpSize + wtid] * dt_soft_plus; + float x_dt = x_block[i * stride_x + tid] * dt_soft_plus; float sumf = 0.0f; #pragma unroll for (int j = 0; j < N; j++) { - float state = (smem_s0[(wid * warpSize + wtid) * stride_ss0 + j] * - expf(dt_soft_plus * smem_A[(wid * warpSize + wtid) * stride_sA + j])) + + float state = (smem_s0[tid * stride_ss0 + j] * expf(dt_soft_plus * smem_A[tid * stride_sA + j])) + (B_block[i * stride_B + j] * x_dt); sumf += state * C_block[i * stride_C + j]; if (i == L - 1) { - s_block[(wid * warpSize + wtid) * stride_s + j] = state; + s_block[tid * stride_s + j] = state; } else { - smem_s0[(wid * warpSize + wtid) * stride_ss0 + j] = state; + smem_s0[tid * stride_ss0 + j] = state; } } __syncthreads(); - y_block[i * stride_y + wid * warpSize + wtid] = sumf; + y_block[i * stride_y + tid] = sumf; } }