10x performance improve 4 cuda ssm conv & scan

This commit is contained in:
pidack 2024-08-26 17:33:23 +08:00
parent fae826fb56
commit 20d390bea4
2 changed files with 58 additions and 61 deletions

View File

@ -7,10 +7,12 @@ static __global__ void ssm_conv_f32(
const int src1_nb1, const int src1_nb1,
float * dst, float * dst,
const int dst_nb0, const int dst_nb1, const int dst_nb2, const int dst_nb0, const int dst_nb1, const int dst_nb2,
const int nc, const int ncs, const int nr, const int n_t, const int n_s) { const int nc, const int ncs, const int nr) {
// const int row = blockIdx.x*blockDim.y + threadIdx.y; // const int row = blockIdx.x*blockDim.y + threadIdx.y;
const int tid = threadIdx.x; const int tid = threadIdx.x;
const int i2 = blockIdx.x;
const int i3 = threadIdx.y;
const int ith = tid; const int ith = tid;
const int nth = WARP_SIZE; const int nth = WARP_SIZE;
@ -19,33 +21,28 @@ static __global__ void ssm_conv_f32(
const int dr = (nr + nth - 1)/nth; const int dr = (nr + nth - 1)/nth;
// row range for this thread // row range for this thread
const int ir0 = dr*ith; const int ir0 = dr * ith;
const int ir1 = min(ir0 + dr, nr); const int ir1 = min(ir0 + dr, nr);
const int ir = ir1 - ir0; const int ir = ir1 - ir0;
for (int i3 = 0; i3 < n_s; ++i3) {
for (int i2 = 0; i2 < n_t; ++i2) {
// {d_conv - 1 + n_t, d_inner, n_seqs} // {d_conv - 1 + n_t, d_inner, n_seqs}
// sliding window // sliding window
const float * s = (const float *) ((const char *) src0 + ir0*src0_nb1 + i2*src0_nb0 + i3*src0_nb2); // {d_conv, d_inner, n_s} const float * s = (const float *) ((const char *) src0 + ir0*src0_nb1 + i2*src0_nb0 + i3*src0_nb2); // {d_conv, d_inner, n_s}
const float * c = (const float *) ((const char *) src1 + ir0*src1_nb1); // {d_conv, d_inner} const float * c = (const float *) ((const char *) src1 + ir0*src1_nb1); // {d_conv, d_inner}
float * x = (float *) ((char *) dst + ir0*dst_nb0 + i2*dst_nb1 + i3*dst_nb2); // {d_inner, n_t, n_s} float * x = (float *) ((char *) dst + ir0*dst_nb0 + i2*dst_nb1 + i3*dst_nb2); // {d_inner, n_t, n_s}
// TODO: transpose the output for smaller strides for big batches? // TODO: transpose the output for smaller strides for big batches?
// d_inner // d_inner
#pragma unroll
for (int i1 = 0; i1 < ir; ++i1) { for (int i1 = 0; i1 < ir; ++i1) {
// rowwise dot product // rowwise dot product
// NOTE: not using ggml_vec_dot_f32, because its sum is in double precision // NOTE: not using ggml_vec_dot_f32, because its sum is in double precision
float sumf = 0.0f; float sumf = 0.0f;
#pragma unroll
// d_conv
for (int i0 = 0; i0 < nc; ++i0) { for (int i0 = 0; i0 < nc; ++i0) {
sumf += s[i0 + i1*ncs] * c[i0 + i1*nc]; sumf += s[i0 + i1*ncs] * c[i0 + i1*nc];
} }
x[i1] = sumf; x[i1] = sumf;
} }
}
}
} }
static void ssm_conv_f32_cuda( static void ssm_conv_f32_cuda(
@ -57,8 +54,8 @@ static void ssm_conv_f32_cuda(
const int nc, const int ncs, const int nr, const int n_t, const int n_s, const int nc, const int ncs, const int nr, const int n_t, const int n_s,
cudaStream_t stream) { cudaStream_t stream) {
const dim3 block_dims(WARP_SIZE, 1, 1); const dim3 block_dims(WARP_SIZE, n_s, 1);
const int nblocks = 1; // TODO const int nblocks = n_t;
ssm_conv_f32<WARP_SIZE><<<nblocks, block_dims, 0, stream>>>( ssm_conv_f32<WARP_SIZE><<<nblocks, block_dims, 0, stream>>>(
src0, src1, src0, src1,
@ -66,7 +63,7 @@ static void ssm_conv_f32_cuda(
src1_nb1, src1_nb1,
dst, dst,
dst_nb0, dst_nb1, dst_nb2, dst_nb0, dst_nb1, dst_nb2,
nc, ncs, nr, n_t, n_s); nc, ncs, nr);
} }
void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
@ -100,3 +97,4 @@ void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
nc, ncs, nr, n_t, n_s, nc, ncs, nr, n_t, n_s,
stream); stream);
} }

View File

@ -11,10 +11,11 @@ static __global__ void ssm_scan_f32(
const int src4_nb1, const int src4_nb2, const int src4_nb1, const int src4_nb2,
const int src5_nb1, const int src5_nb2, const int src5_nb1, const int src5_nb2,
float * dst, float * dst,
const int nc, const int nr, const int n_t, const int n_s) { const int nc, const int nr) {
// const int row = blockIdx.x*blockDim.y + threadIdx.y;
const int tid = threadIdx.x; const int tid = threadIdx.x;
const int i2 = blockIdx.x;
const int i3 = threadIdx.y;
const int ith = tid; const int ith = tid;
const int nth = WARP_SIZE; const int nth = WARP_SIZE;
@ -27,8 +28,6 @@ static __global__ void ssm_scan_f32(
const int ir1 = min(ir0 + dr, nr); const int ir1 = min(ir0 + dr, nr);
const int ir = ir1 - ir0; const int ir = ir1 - ir0;
for (int i3 = 0; i3 < n_s; ++i3) {
for (int i2 = 0; i2 < n_t; ++i2) {
const float * s0 = (const float *) ((const char *) src0 + ir0*src0_nb1 + i3*src0_nb2); // {d_state, d_inner, n_s} const float * s0 = (const float *) ((const char *) src0 + ir0*src0_nb1 + i3*src0_nb2); // {d_state, d_inner, n_s}
const float * x = (const float *) ((const char *) src1 + ir0*src1_nb0 + i2*src1_nb1 + i3*src1_nb2); // {d_inner, n_t, n_s} const float * x = (const float *) ((const char *) src1 + ir0*src1_nb0 + i2*src1_nb1 + i3*src1_nb2); // {d_inner, n_t, n_s}
const float * dt = (const float *) ((const char *) src2 + ir0*src2_nb0 + i2*src2_nb1 + i3*src2_nb2); // {d_inner, n_t, n_s} const float * dt = (const float *) ((const char *) src2 + ir0*src2_nb0 + i2*src2_nb1 + i3*src2_nb2); // {d_inner, n_t, n_s}
@ -42,12 +41,14 @@ static __global__ void ssm_scan_f32(
if (i2 > 0) { s0 = s; } if (i2 > 0) { s0 = s; }
// d_inner // d_inner
#pragma unroll
for (int i1 = 0; i1 < ir; ++i1) { for (int i1 = 0; i1 < ir; ++i1) {
// ref: https://github.com/state-spaces/mamba/blob/34076d664838588a3c97727b263478ab9f621a07/mamba_ssm/ops/triton/selective_state_update.py#L78 // ref: https://github.com/state-spaces/mamba/blob/34076d664838588a3c97727b263478ab9f621a07/mamba_ssm/ops/triton/selective_state_update.py#L78
float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1]; float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1];
float x_dt = x[i1] * dt_soft_plus; float x_dt = x[i1] * dt_soft_plus;
float sumf = 0.0f; float sumf = 0.0f;
// d_state // d_state
#pragma unroll
for (int i0 = 0; i0 < nc; ++i0) { for (int i0 = 0; i0 < nc; ++i0) {
int i = i0 + i1*nc; int i = i0 + i1*nc;
// state = prev_state * dA + dB * x // state = prev_state * dA + dB * x
@ -58,8 +59,6 @@ static __global__ void ssm_scan_f32(
} }
y[i1] = sumf; y[i1] = sumf;
} }
}
}
} }
static void ssm_scan_f32_cuda( static void ssm_scan_f32_cuda(
@ -75,8 +74,8 @@ static void ssm_scan_f32_cuda(
const int nc, const int nr, const int n_t, const int n_s, const int nc, const int nr, const int n_t, const int n_s,
cudaStream_t stream) { cudaStream_t stream) {
const dim3 block_dims(WARP_SIZE, 1, 1); const dim3 block_dims(WARP_SIZE, n_s, 1);
const int nblocks = 1; // TODO const int nblocks = n_t;
ssm_scan_f32<WARP_SIZE><<<nblocks, block_dims, 0, stream>>>( ssm_scan_f32<WARP_SIZE><<<nblocks, block_dims, 0, stream>>>(
src0, src1, src2, src3, src0, src1, src2, src3,
@ -88,7 +87,7 @@ static void ssm_scan_f32_cuda(
src4_nb1, src4_nb2, src4_nb1, src4_nb2,
src5_nb1, src5_nb2, src5_nb1, src5_nb2,
dst, dst,
nc, nr, n_t, n_s); nc, nr);
} }
void ggml_cuda_op_ssm_scan(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { void ggml_cuda_op_ssm_scan(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {