diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu index 895ba4794..2df845234 100644 --- a/ggml/src/ggml-cuda.cu +++ b/ggml/src/ggml-cuda.cu @@ -34,7 +34,8 @@ #include "ggml-cuda/tsembd.cuh" #include "ggml-cuda/unary.cuh" #include "ggml-cuda/upscale.cuh" - +#include "ggml-cuda/ssm_conv.cuh" +#include "ggml-cuda/ssm_scan.cuh" #include #include #include @@ -2342,6 +2343,11 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_OP_FLASH_ATTN_EXT: ggml_cuda_flash_attn_ext(ctx, dst); break; + case GGML_OP_SSM_CONV: + ggml_cuda_op_ssm_conv(ctx, dst); + break; + case GGML_OP_SSM_SCAN: + ggml_cuda_op_ssm_scan(ctx, dst); case GGML_OP_CROSS_ENTROPY_LOSS: ggml_cuda_cross_entropy_loss(ctx, dst); break; @@ -2967,6 +2973,8 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons case GGML_OP_ARANGE: case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_LEAKY_RELU: + case GGML_OP_SSM_CONV: + case GGML_OP_SSM_SCAN: return true; case GGML_OP_FLASH_ATTN_EXT: #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) diff --git a/ggml/src/ggml-cuda/norm.cu b/ggml/src/ggml-cuda/norm.cu index 133e219f0..f2d643e4e 100644 --- a/ggml/src/ggml-cuda/norm.cu +++ b/ggml/src/ggml-cuda/norm.cu @@ -153,9 +153,9 @@ static void group_norm_f32_cuda(const float * x, float * dst, const int num_grou } static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) { - GGML_ASSERT(ncols % WARP_SIZE == 0); + GGML_ASSERT(ncols % WARP_SIZE == 0 || ncols < WARP_SIZE); if (ncols < 1024) { - const dim3 block_dims(WARP_SIZE, 1, 1); + const dim3 block_dims(min(ncols, WARP_SIZE), 1, 1); rms_norm_f32<<>>(x, dst, ncols, eps); } else { const dim3 block_dims(1024, 1, 1); diff --git a/ggml/src/ggml-cuda/ssm_conv.cu b/ggml/src/ggml-cuda/ssm_conv.cu new file mode 100644 index 000000000..abb0177f0 --- /dev/null +++ b/ggml/src/ggml-cuda/ssm_conv.cu @@ -0,0 +1,100 @@ +#include "ssm_conv.cuh" + +template +static __global__ void ssm_conv_f32( + const float * __restrict__ src0, const float * __restrict__ src1, + const int src0_nb0, const int src0_nb1, const int src0_nb2, + const int src1_nb1, + float * __restrict__ dst, + const int dst_nb0, const int dst_nb1, const int dst_nb2, + const int nc, const int ncs, const int nr, const int n_t, const int n_s) { + + const int tid = blockIdx.y; + const int i3 = blockIdx.x; + const int i2 = threadIdx.x; + + const int ith = tid; + const int nth = WARP_SIZE; + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = min(ir0 + dr, nr); + const int ir = ir1 - ir0; + + // {d_conv - 1 + n_t, d_inner, n_seqs} + // sliding window + const float * s = (const float *) ((const char *) src0 + ir0*src0_nb1 + i2*src0_nb0 + i3*src0_nb2); // {d_conv, d_inner, n_s} + const float * c = (const float *) ((const char *) src1 + ir0*src1_nb1); // {d_conv, d_inner} + float * x = (float *) ((char *) dst + ir0*dst_nb0 + i2*dst_nb1 + i3*dst_nb2); // {d_inner, n_t, n_s} + + // TODO: transpose the output for smaller strides for big batches? + // d_inner + for (int i1 = 0; i1 < ir; ++i1) { + // rowwise dot product + // NOTE: not using ggml_vec_dot_f32, because its sum is in double precision + float sumf = 0.0f; + + // d_conv + #pragma unroll + for (int i0 = 0; i0 < nc; ++i0) { + sumf += s[i0 + i1*ncs] * c[i0 + i1*nc]; + } + x[i1] = sumf; + } +} + +static void ssm_conv_f32_cuda( + const float * src0, const float * src1, + const int src0_nb0, const int src0_nb1, const int src0_nb2, + const int src1_nb1, + float * dst, + const int dst_nb0, const int dst_nb1, const int dst_nb2, + const int nc, const int ncs, const int nr, const int n_t, const int n_s, + cudaStream_t stream) { + + const dim3 block_dims(n_t, 1, 1); + //const int nblocks = n_s; // TODO + const dim3 grid_dims(n_s, WARP_SIZE, 1); + + ssm_conv_f32<<>>( + src0, src1, + src0_nb0, src0_nb1, src0_nb2, + src1_nb1, + dst, + dst_nb0, dst_nb1, dst_nb2, + nc, ncs, nr, n_t, n_s); +} + +void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const struct ggml_tensor * src0 = dst->src[0]; // conv_x + const struct ggml_tensor * src1 = dst->src[1]; // conv1d.weight + + const int nc = src1->ne[0]; // d_conv + const int ncs = src0->ne[0]; // d_conv - 1 + n_t + const int nr = src0->ne[1]; // d_inner + const int n_t = dst->ne[1]; // tokens per sequence + const int n_s = dst->ne[2]; // number of sequences in the batch + + GGML_ASSERT( dst->ne[0] == nr); + GGML_ASSERT(src0->nb[0] == sizeof(float)); + GGML_ASSERT(src1->nb[0] == sizeof(float)); + GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float)); + + const float * src0_d = (const float *)src0->data; + const float * src1_d = (const float *)src1->data; + float * dst_d = (float *)dst->data; + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + ssm_conv_f32_cuda(src0_d, src1_d, + src0->nb[0], src0->nb[1], src0->nb[2], + src1->nb[1], + dst_d, + dst->nb[0], dst->nb[1], dst->nb[2], + nc, ncs, nr, n_t, n_s, + stream); +} diff --git a/ggml/src/ggml-cuda/ssm_conv.cuh b/ggml/src/ggml-cuda/ssm_conv.cuh new file mode 100644 index 000000000..8e6c1f00b --- /dev/null +++ b/ggml/src/ggml-cuda/ssm_conv.cuh @@ -0,0 +1,3 @@ +#include "common.cuh" + +void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/ssm_scan.cu b/ggml/src/ggml-cuda/ssm_scan.cu new file mode 100644 index 000000000..cc8dac9e6 --- /dev/null +++ b/ggml/src/ggml-cuda/ssm_scan.cu @@ -0,0 +1,144 @@ +#include "ssm_scan.cuh" + +template +static __global__ void ssm_scan_f32( + const float * __restrict__ src0, const float * __restrict__ src1, const float * __restrict__ src2, const float * __restrict__ src3, + const float * __restrict__ src4, const float * __restrict__ src5, + const int src0_nb1, const int src0_nb2, + const int src1_nb0, const int src1_nb1, const int src1_nb2, const int src1_nb3, + const int src2_nb0, const int src2_nb1, const int src2_nb2, + const int src3_nb1, + const int src4_nb1, const int src4_nb2, + const int src5_nb1, const int src5_nb2, + float * __restrict__ dst, + const int nc, const int nr, const int n_t, const int n_s) { + +// const int row = blockIdx.x*blockDim.y + threadIdx.y; + const int tid = threadIdx.x; + const int i3 = threadIdx.y; + + const int ith = tid; + const int nth = WARP_SIZE; + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = min(ir0 + dr, nr); + const int ir = ir1 - ir0; + for (int i2 = 0; i2 < n_t; ++i2) { + const float * s0 = (const float *) ((const char *) src0 + ir0*src0_nb1 + i3*src0_nb2); // {d_state, d_inner, n_s} + const float * x = (const float *) ((const char *) src1 + ir0*src1_nb0 + i2*src1_nb1 + i3*src1_nb2); // {d_inner, n_t, n_s} + const float * dt = (const float *) ((const char *) src2 + ir0*src2_nb0 + i2*src2_nb1 + i3*src2_nb2); // {d_inner, n_t, n_s} + const float * A = (const float *) ((const char *) src3 + ir0*src3_nb1); // {d_state, d_inner} + const float * B = (const float *) ((const char *) src4 + i2*src4_nb1 + i3*src4_nb2); // {d_state, n_t, n_s} + const float * C = (const float *) ((const char *) src5 + i2*src5_nb1 + i3*src5_nb2); // {d_state, n_t, n_s} + float * y = (float *) ((char *) dst + ir0*src1_nb0 + i2*src1_nb1 + i3*src1_nb2); // {d_inner, n_t, n_s} + float * s = (float *) ((char *) dst + ir0*src0_nb1 + i3*src0_nb2 + src1_nb3); // {d_state, d_inner, n_s} + + // use the output as the source for the next token-wise iterations + if (i2 > 0) { s0 = s; } + + // d_inner + for (int i1 = 0; i1 < ir; ++i1) { + // ref: https://github.com/state-spaces/mamba/blob/34076d664838588a3c97727b263478ab9f621a07/mamba_ssm/ops/triton/selective_state_update.py#L78 + float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1]; + float x_dt = x[i1] * dt_soft_plus; + float sumf = 0.0f; + // d_state + #pragma unroll + for (int i0 = 0; i0 < nc; ++i0) { + int i = i0 + i1*nc; + // state = prev_state * dA + dB * x + float state = (s0[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt); + // y = rowwise_dotprod(state, C) + sumf += state * C[i0]; + s[i] = state; + } + y[i1] = sumf; + } + } +} + +static void ssm_scan_f32_cuda( + const float * src0, const float * src1, const float * src2, const float * src3, + const float * src4, const float * src5, + const int src0_nb1, const int src0_nb2, + const int src1_nb0, const int src1_nb1, const int src1_nb2, const int src1_nb3, + const int src2_nb0, const int src2_nb1, const int src2_nb2, + const int src3_nb1, + const int src4_nb1, const int src4_nb2, + const int src5_nb1, const int src5_nb2, + float * dst, + const int nc, const int nr, const int n_t, const int n_s, + cudaStream_t stream) { + + const dim3 block_dims(WARP_SIZE, n_s, 1); + const int nblocks = 1; // TODO + + ssm_scan_f32<<>>( + src0, src1, src2, src3, + src4, src5, + src0_nb1, src0_nb2, + src1_nb0, src1_nb1, src1_nb2, src1_nb3, + src2_nb0, src2_nb1, src2_nb2, + src3_nb1, + src4_nb1, src4_nb2, + src5_nb1, src5_nb2, + dst, + nc, nr, n_t, n_s); +} + +void ggml_cuda_op_ssm_scan(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const struct ggml_tensor * src0 = dst->src[0]; // s + const struct ggml_tensor * src1 = dst->src[1]; // x + const struct ggml_tensor * src2 = dst->src[2]; // dt + const struct ggml_tensor * src3 = dst->src[3]; // A + const struct ggml_tensor * src4 = dst->src[4]; // B + const struct ggml_tensor * src5 = dst->src[5]; // C + + const int64_t nc = src0->ne[0]; // d_state + const int64_t nr = src0->ne[1]; // d_inner + const int64_t n_t = src1->ne[1]; // number of tokens per sequence + const int64_t n_s = src0->ne[2]; // number of sequences in the batch + + GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst)); + GGML_ASSERT(src0->nb[0] == sizeof(float)); + GGML_ASSERT(src1->nb[0] == sizeof(float)); + GGML_ASSERT(src2->nb[0] == sizeof(float)); + GGML_ASSERT(src3->nb[0] == sizeof(float)); + GGML_ASSERT(src4->nb[0] == sizeof(float)); + GGML_ASSERT(src5->nb[0] == sizeof(float)); + // required for the dot product between s and C + GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float)); + // required for per-sequence offsets for states + GGML_ASSERT(src0->nb[2] == src0->ne[0]*src0->ne[1]*sizeof(float)); + // required to get correct offset for state destination (i.e. src1->nb[3]) + GGML_ASSERT(src1->nb[3] == src1->ne[0]*src1->ne[1]*src1->ne[2]*sizeof(float)); + + const float * src0_d = (const float *)src0->data; + const float * src1_d = (const float *)src1->data; + const float * src2_d = (const float *)src2->data; + const float * src3_d = (const float *)src3->data; + const float * src4_d = (const float *)src4->data; + const float * src5_d = (const float *)src5->data; + float * dst_d = (float *)dst->data; + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + ssm_scan_f32_cuda( + src0_d, src1_d, src2_d, src3_d, + src4_d, src5_d, + src0->nb[1], src0->nb[2], + src1->nb[0], src1->nb[1], src1->nb[2], src1->nb[3], + src2->nb[0], src2->nb[1], src2->nb[2], + src3->nb[1], + src4->nb[1], src4->nb[2], + src5->nb[1], src5->nb[2], + dst_d, + nc, nr, n_t, n_s, + stream); +} diff --git a/ggml/src/ggml-cuda/ssm_scan.cuh b/ggml/src/ggml-cuda/ssm_scan.cuh new file mode 100644 index 000000000..ee078f5eb --- /dev/null +++ b/ggml/src/ggml-cuda/ssm_scan.cuh @@ -0,0 +1,3 @@ +#include "common.cuh" + +void ggml_cuda_op_ssm_scan(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/src/llama.cpp b/src/llama.cpp index af8afd845..be0da6721 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -9590,9 +9590,9 @@ static struct ggml_tensor * llm_build_mamba( // Some Mamba variants (e.g. FalconMamba) apply RMS norm in B, C & Dt layers if (ssm_dt_b_c_rms) { - dt = ggml_rms_norm(ctx, dt, norm_rms_eps); - B = ggml_rms_norm(ctx, B, norm_rms_eps); - C = ggml_rms_norm(ctx, C, norm_rms_eps); + dt = ggml_rms_norm(ctx, ggml_cont(ctx, dt), norm_rms_eps); + B = ggml_rms_norm(ctx, ggml_cont(ctx, B), norm_rms_eps); + C = ggml_rms_norm(ctx, ggml_cont(ctx, C), norm_rms_eps); } // {dt_rank, d_inner} @ {dt_rank, n_seq_tokens, n_seqs} => {d_inner, n_seq_tokens, n_seqs}