memory access pattern

This commit is contained in:
pidack 2024-08-27 20:51:26 +08:00
parent e53b14f152
commit eec0e8ca81

View File

@ -2,17 +2,16 @@
template <int block_size>
static __global__ void ssm_conv_f32(
const float * __restrict__ src0, const float * __restrict__ src1,
const float * src0, const float * src1,
const int src0_nb0, const int src0_nb1, const int src0_nb2,
const int src1_nb1,
float * dst,
const int dst_nb0, const int dst_nb1, const int dst_nb2,
const int nc, const int ncs, const int nr) {
const int nc, const int ncs, const int nr, const int n_t, const int n_s) {
// const int row = blockIdx.x*blockDim.y + threadIdx.y;
const int tid = threadIdx.x;
const int i2 = blockIdx.x;
const int i3 = threadIdx.y;
const int tid = blockIdx.y;
const int i3 = blockIdx.x;
const int i2 = threadIdx.x;
const int ith = tid;
const int nth = WARP_SIZE;
@ -21,7 +20,7 @@ static __global__ void ssm_conv_f32(
const int dr = (nr + nth - 1)/nth;
// row range for this thread
const int ir0 = dr * ith;
const int ir0 = dr*ith;
const int ir1 = min(ir0 + dr, nr);
const int ir = ir1 - ir0;
@ -30,12 +29,15 @@ static __global__ void ssm_conv_f32(
const float * s = (const float *) ((const char *) src0 + ir0*src0_nb1 + i2*src0_nb0 + i3*src0_nb2); // {d_conv, d_inner, n_s}
const float * c = (const float *) ((const char *) src1 + ir0*src1_nb1); // {d_conv, d_inner}
float * x = (float *) ((char *) dst + ir0*dst_nb0 + i2*dst_nb1 + i3*dst_nb2); // {d_inner, n_t, n_s}
// TODO: transpose the output for smaller strides for big batches?
// d_inner
for (int i1 = 0; i1 < ir; ++i1) {
// rowwise dot product
// NOTE: not using ggml_vec_dot_f32, because its sum is in double precision
float sumf = 0.0f;
// d_conv
#pragma unroll
for (int i0 = 0; i0 < nc; ++i0) {
sumf += s[i0 + i1*ncs] * c[i0 + i1*nc];
@ -53,15 +55,17 @@ static void ssm_conv_f32_cuda(
const int nc, const int ncs, const int nr, const int n_t, const int n_s,
cudaStream_t stream) {
const dim3 block_dims(WARP_SIZE, n_s, 1);
const int nblocks = n_t;
ssm_conv_f32<WARP_SIZE><<<nblocks, block_dims, 0, stream>>>(
const dim3 block_dims(n_t, 1, 1);
//const int nblocks = n_s; // TODO
const dim3 grid_dims(n_s, WARP_SIZE, 1);
ssm_conv_f32<WARP_SIZE><<<grid_dims, block_dims, 0, stream>>>(
src0, src1,
src0_nb0, src0_nb1, src0_nb2,
src1_nb1,
dst,
dst_nb0, dst_nb1, dst_nb2,
nc, ncs, nr);
nc, ncs, nr, n_t, n_s);
}
void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
@ -86,7 +90,6 @@ void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
ssm_conv_f32_cuda(src0_d, src1_d,
src0->nb[0], src0->nb[1], src0->nb[2],
src1->nb[1],