From b423a6df5ee4610cc4828978aee375991615fc6f Mon Sep 17 00:00:00 2001 From: pidack Date: Tue, 27 Aug 2024 16:51:21 +0800 Subject: [PATCH] fix ssm_scan numerical error & others update --- ggml/src/ggml-cuda/ssm_conv.cu | 6 +-- ggml/src/ggml-cuda/ssm_scan.cu | 68 +++++++++++++++++----------------- 2 files changed, 36 insertions(+), 38 deletions(-) diff --git a/ggml/src/ggml-cuda/ssm_conv.cu b/ggml/src/ggml-cuda/ssm_conv.cu index df89b4cf5..eefe4f45e 100644 --- a/ggml/src/ggml-cuda/ssm_conv.cu +++ b/ggml/src/ggml-cuda/ssm_conv.cu @@ -2,7 +2,7 @@ template static __global__ void ssm_conv_f32( - const float * src0, const float * src1, + const float * __restrict__ src0, const float * __restrict__ src1, const int src0_nb0, const int src0_nb1, const int src0_nb2, const int src1_nb1, float * dst, @@ -32,7 +32,6 @@ static __global__ void ssm_conv_f32( float * x = (float *) ((char *) dst + ir0*dst_nb0 + i2*dst_nb1 + i3*dst_nb2); // {d_inner, n_t, n_s} // TODO: transpose the output for smaller strides for big batches? // d_inner - #pragma unroll for (int i1 = 0; i1 < ir; ++i1) { // rowwise dot product // NOTE: not using ggml_vec_dot_f32, because its sum is in double precision @@ -56,7 +55,7 @@ static void ssm_conv_f32_cuda( const dim3 block_dims(WARP_SIZE, n_s, 1); const int nblocks = n_t; - + printf("size is %d\n",nr); ssm_conv_f32<<>>( src0, src1, src0_nb0, src0_nb1, src0_nb2, @@ -97,4 +96,3 @@ void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { nc, ncs, nr, n_t, n_s, stream); } - diff --git a/ggml/src/ggml-cuda/ssm_scan.cu b/ggml/src/ggml-cuda/ssm_scan.cu index dd912856d..cf08f6e0f 100644 --- a/ggml/src/ggml-cuda/ssm_scan.cu +++ b/ggml/src/ggml-cuda/ssm_scan.cu @@ -2,8 +2,8 @@ template static __global__ void ssm_scan_f32( - const float * src0, const float * src1, const float * src2, const float * src3, - const float * src4, const float * src5, + const float * __restrict__ src0, const float * __restrict__ src1, const float * __restrict__ src2, const float * __restrict__ src3, + const float * __restrict__ src4, const float * __restrict__ src5, const int src0_nb1, const int src0_nb2, const int src1_nb0, const int src1_nb1, const int src1_nb2, const int src1_nb3, const int src2_nb0, const int src2_nb1, const int src2_nb2, @@ -11,10 +11,10 @@ static __global__ void ssm_scan_f32( const int src4_nb1, const int src4_nb2, const int src5_nb1, const int src5_nb2, float * dst, - const int nc, const int nr) { + const int nc, const int nr, const int n_t, const int n_s) { +// const int row = blockIdx.x*blockDim.y + threadIdx.y; const int tid = threadIdx.x; - const int i2 = blockIdx.x; const int i3 = threadIdx.y; const int ith = tid; @@ -27,37 +27,37 @@ static __global__ void ssm_scan_f32( const int ir0 = dr*ith; const int ir1 = min(ir0 + dr, nr); const int ir = ir1 - ir0; + for (int i2 = 0; i2 < n_t; ++i2) { + const float * s0 = (const float *) ((const char *) src0 + ir0*src0_nb1 + i3*src0_nb2); // {d_state, d_inner, n_s} + const float * x = (const float *) ((const char *) src1 + ir0*src1_nb0 + i2*src1_nb1 + i3*src1_nb2); // {d_inner, n_t, n_s} + const float * dt = (const float *) ((const char *) src2 + ir0*src2_nb0 + i2*src2_nb1 + i3*src2_nb2); // {d_inner, n_t, n_s} + const float * A = (const float *) ((const char *) src3 + ir0*src3_nb1); // {d_state, d_inner} + const float * B = (const float *) ((const char *) src4 + i2*src4_nb1 + i3*src4_nb2); // {d_state, n_t, n_s} + const float * C = (const float *) ((const char *) src5 + i2*src5_nb1 + i3*src5_nb2); // {d_state, n_t, n_s} + float * y = (float *) ((char *) dst + ir0*src1_nb0 + i2*src1_nb1 + i3*src1_nb2); // {d_inner, n_t, n_s} + float * s = (float *) ((char *) dst + ir0*src0_nb1 + i3*src0_nb2 + src1_nb3); // {d_state, d_inner, n_s} - const float * s0 = (const float *) ((const char *) src0 + ir0*src0_nb1 + i3*src0_nb2); // {d_state, d_inner, n_s} - const float * x = (const float *) ((const char *) src1 + ir0*src1_nb0 + i2*src1_nb1 + i3*src1_nb2); // {d_inner, n_t, n_s} - const float * dt = (const float *) ((const char *) src2 + ir0*src2_nb0 + i2*src2_nb1 + i3*src2_nb2); // {d_inner, n_t, n_s} - const float * A = (const float *) ((const char *) src3 + ir0*src3_nb1); // {d_state, d_inner} - const float * B = (const float *) ((const char *) src4 + i2*src4_nb1 + i3*src4_nb2); // {d_state, n_t, n_s} - const float * C = (const float *) ((const char *) src5 + i2*src5_nb1 + i3*src5_nb2); // {d_state, n_t, n_s} - float * y = (float *) ((char *) dst + ir0*src1_nb0 + i2*src1_nb1 + i3*src1_nb2); // {d_inner, n_t, n_s} - float * s = (float *) ((char *) dst + ir0*src0_nb1 + i3*src0_nb2 + src1_nb3); // {d_state, d_inner, n_s} + // use the output as the source for the next token-wise iterations + if (i2 > 0) { s0 = s; } - // use the output as the source for the next token-wise iterations - if (i2 > 0) { s0 = s; } - - // d_inner - #pragma unroll - for (int i1 = 0; i1 < ir; ++i1) { - // ref: https://github.com/state-spaces/mamba/blob/34076d664838588a3c97727b263478ab9f621a07/mamba_ssm/ops/triton/selective_state_update.py#L78 - float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1]; - float x_dt = x[i1] * dt_soft_plus; - float sumf = 0.0f; - // d_state - #pragma unroll - for (int i0 = 0; i0 < nc; ++i0) { - int i = i0 + i1*nc; - // state = prev_state * dA + dB * x - float state = (s0[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt); - // y = rowwise_dotprod(state, C) - sumf += state * C[i0]; - s[i] = state; + // d_inner + for (int i1 = 0; i1 < ir; ++i1) { + // ref: https://github.com/state-spaces/mamba/blob/34076d664838588a3c97727b263478ab9f621a07/mamba_ssm/ops/triton/selective_state_update.py#L78 + float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1]; + float x_dt = x[i1] * dt_soft_plus; + float sumf = 0.0f; + // d_state + #pragma unroll + for (int i0 = 0; i0 < nc; ++i0) { + int i = i0 + i1*nc; + // state = prev_state * dA + dB * x + float state = (s0[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt); + // y = rowwise_dotprod(state, C) + sumf += state * C[i0]; + s[i] = state; + } + y[i1] = sumf; } - y[i1] = sumf; } } @@ -75,7 +75,7 @@ static void ssm_scan_f32_cuda( cudaStream_t stream) { const dim3 block_dims(WARP_SIZE, n_s, 1); - const int nblocks = n_t; + const int nblocks = 1; // TODO ssm_scan_f32<<>>( src0, src1, src2, src3, @@ -87,7 +87,7 @@ static void ssm_scan_f32_cuda( src4_nb1, src4_nb2, src5_nb1, src5_nb2, dst, - nc, nr); + nc, nr, n_t, n_s); } void ggml_cuda_op_ssm_scan(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {