Update CUDA ops and tests to match implementation from commit 8fb57ac0 (llama : use im2col and mul_mat to perform convolution for Mamba); GPU version breaks with assert because of unsupported MUL_MAT

This commit is contained in:
Jan Ploski 2024-06-03 14:46:50 +02:00
parent 12c913c52c
commit 061e520075
3 changed files with 56 additions and 92 deletions

View File

@ -2,13 +2,12 @@
template <int block_size>
static __global__ void ssm_conv_f32(
const float * src0, const float * src1, const float * src2,
const int src0_nb1, const int src0_nb2,
const int src1_nb0, const int src1_nb1, const int src1_nb2,
const int src2_nb1,
const float * src0, const float * src1,
const int src0_nb0, const int src0_nb1, const int src0_nb2,
const int src1_nb1,
float * dst,
const int dst_nb0, const int dst_nb1, const int dst_nb2,
const int nc, const int nr, const int n_t, const int n_s) {
const int nc, const int ncs, const int nr, const int n_t, const int n_s) {
// const int row = blockIdx.x*blockDim.y + threadIdx.y;
const int tid = threadIdx.x;
@ -24,118 +23,80 @@ static __global__ void ssm_conv_f32(
const int ir1 = min(ir0 + dr, nr);
const int ir = ir1 - ir0;
// TODO: maybe require src0 to have d_conv columns instead of (d_conv - 1)?
// This would avoid having to copy into an intermediate buffer, but the state would be bigger.
// float * s = (float *) params->wdata + (nc*dr + CACHE_LINE_SIZE_F32) * ith;
extern __shared__ float wdata_f32[]; // work buffer for all threads
float * s = (float *) wdata_f32 + nc*dr*ith;
for (int i3 = 0; i3 < n_s; ++i3) {
float * s0 = (float *) ((char *) src0 + ir0*src0_nb1 + i3*src0_nb2); // {d_conv, d_inner, n_s}
// copy the state into working memory
// can't use memcpy because (d_conv) != (d_conv - 1)
for (int i1 = 0; i1 < ir; ++i1) {
for (int i0 = 0; i0 < nc - 1; ++i0) {
s[1 + i0 + i1*nc] = s0[i0 + i1*(nc - 1)];
}
}
for (int i2 = 0; i2 < n_t; ++i2) {
float * x = (float *) ((char *) dst + ir0* dst_nb0 + i2* dst_nb1 + i3* dst_nb2); // {d_inner, n_t, n_s}
float * x0 = (float *) ((char *) src1 + ir0*src1_nb0 + i2*src1_nb1 + i3*src1_nb2); // {d_inner, n_t, n_s}
float * c = (float *) ((char *) src2 + ir0*src2_nb1); // {d_conv, d_inner}
// shift state left
//memmove(s, s + 1, (nc*ir - 1) * sizeof(float));
for (int i4 = 0; i4 < nc*ir - 1; ++i4) {
s[i4] = s[i4+1];
}
// {d_conv - 1 + n_t, d_inner, n_seqs}
// sliding window
const float * s = (const float *) ((const char *) src0 + ir0*src0_nb1 + i2*src0_nb0 + i3*src0_nb2); // {d_conv, d_inner, n_s}
const float * c = (const float *) ((const char *) src1 + ir0*src1_nb1); // {d_conv, d_inner}
float * x = (float *) ((char *) dst + ir0*dst_nb0 + i2*dst_nb1 + i3*dst_nb2); // {d_inner, n_t, n_s}
// TODO: transpose the output for smaller strides for big batches?
// d_inner
for (int i1 = 0; i1 < ir; ++i1) {
// insert x on the last column
s[(nc - 1) + i1*nc] = x0[i1];
}
// it seems a little faster when this is separate from the state shift
for (int i1 = 0; i1 < ir; ++i1) {
// rowwise dot product
// NOTE: not using ggml_vec_dot_f32, because its sum is in double precision
float sumf = 0.0f;
// d_conv
for (int i0 = 0; i0 < nc; ++i0) {
int i = i0 + i1*nc;
sumf += s[i] * c[i];
sumf += s[i0 + i1*ncs] * c[i0 + i1*nc];
}
x[i1] = sumf;
}
}
// copy the state out of it
for (int i1 = 0; i1 < ir; ++i1) {
for (int i0 = 0; i0 < nc - 1; ++i0) {
s0[i0 + i1*(nc - 1)] = s[1 + i0 + i1*nc];
}
}
}
}
static void ssm_conv_f32_cuda(
const float * src0, const float * src1, const float * src2,
const int src0_nb1, const int src0_nb2,
const int src1_nb0, const int src1_nb1, const int src1_nb2,
const int src2_nb1,
const float * src0, const float * src1,
const int src0_nb0, const int src0_nb1, const int src0_nb2,
const int src1_nb1,
float * dst,
const int dst_nb0, const int dst_nb1, const int dst_nb2,
const int nc, const int nr, const int n_t, const int n_s,
const int nc, const int ncs, const int nr, const int n_t, const int n_s,
cudaStream_t stream) {
const dim3 block_dims(WARP_SIZE, 1, 1);
const int nblocks = 1; // TODO
const int shmem_size = nc * (nr + WARP_SIZE - 1) * sizeof(float); // TODO
ssm_conv_f32<WARP_SIZE><<<nblocks, block_dims, shmem_size, stream>>>(
src0, src1, src2,
src0_nb1, src0_nb2,
src1_nb0, src1_nb1, src1_nb2,
src2_nb1,
ssm_conv_f32<WARP_SIZE><<<nblocks, block_dims, 0, stream>>>(
src0, src1,
src0_nb0, src0_nb1, src0_nb2,
src1_nb1,
dst,
dst_nb0, dst_nb1, dst_nb2,
nc, nr, n_t, n_s);
nc, ncs, nr, n_t, n_s);
}
void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const struct ggml_tensor * src0 = dst->src[0]; // conv_state
const struct ggml_tensor * src1 = dst->src[1]; // x
const struct ggml_tensor * src2 = dst->src[2]; // conv1d.weight
const struct ggml_tensor * src0 = dst->src[0]; // conv_x
const struct ggml_tensor * src1 = dst->src[1]; // conv1d.weight
const int nc = src2->ne[0]; // d_conv
const int nc = src1->ne[0]; // d_conv
const int ncs = src0->ne[0]; // d_conv - 1 + n_t
const int nr = src0->ne[1]; // d_inner
const int n_t = src1->ne[1]; // tokens per sequence
const int n_s = src0->ne[2]; // number of sequences in the batch
const int n_t = dst->ne[1]; // tokens per sequence
const int n_s = dst->ne[2]; // number of sequences in the batch
GGML_ASSERT(ggml_are_same_shape(src1, dst));
GGML_ASSERT( dst->ne[0] == nr);
GGML_ASSERT(src0->nb[0] == sizeof(float));
GGML_ASSERT(src1->nb[0] == sizeof(float));
GGML_ASSERT(src2->nb[0] == sizeof(float));
GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));
const float * src0_d = (const float *)src0->data;
const float * src1_d = (const float *)src1->data;
const float * src2_d = (const float *)src2->data;
float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
ssm_conv_f32_cuda(src0_d, src1_d, src2_d,
src0->nb[1], src0->nb[2],
src1->nb[0], src1->nb[1], src1->nb[2],
src2->nb[1],
ssm_conv_f32_cuda(src0_d, src1_d,
src0->nb[0], src0->nb[1], src0->nb[2],
src1->nb[1],
dst_d,
dst->nb[0], dst->nb[1], dst->nb[2],
nc, nr, n_t, n_s,
nc, ncs, nr, n_t, n_s,
stream);
}

View File

@ -5,13 +5,12 @@ static __global__ void ssm_scan_f32(
const float * src0, const float * src1, const float * src2, const float * src3,
const float * src4, const float * src5,
const int src0_nb1, const int src0_nb2,
const int src1_nb0, const int src1_nb1, const int src1_nb2,
const int src1_nb0, const int src1_nb1, const int src1_nb2, const int src1_nb3,
const int src2_nb0, const int src2_nb1, const int src2_nb2,
const int src3_nb1,
const int src4_nb1, const int src4_nb2,
const int src5_nb1, const int src5_nb2,
float * dst,
const int dst_nb0, const int dst_nb1, const int dst_nb2,
const int nc, const int nr, const int n_t, const int n_s) {
// const int row = blockIdx.x*blockDim.y + threadIdx.y;
@ -30,13 +29,17 @@ static __global__ void ssm_scan_f32(
for (int i3 = 0; i3 < n_s; ++i3) {
for (int i2 = 0; i2 < n_t; ++i2) {
float * y = (float *) ((char *) dst + ir0* dst_nb0 + i2* dst_nb1 + i3* dst_nb2); // {d_inner, n_t, n_s}
float * s = (float *) ((char *) src0 + ir0*src0_nb1 + i3*src0_nb2); // {d_state, d_inner, n_s}
float * x = (float *) ((char *) src1 + ir0*src1_nb0 + i2*src1_nb1 + i3*src1_nb2); // {d_inner, n_t, n_s}
float * dt = (float *) ((char *) src2 + ir0*src2_nb0 + i2*src2_nb1 + i3*src2_nb2); // {d_inner, n_t, n_s}
float * A = (float *) ((char *) src3 + ir0*src3_nb1); // {d_state, d_inner}
float * B = (float *) ((char *) src4 + i2*src4_nb1 + i3*src4_nb2); // {d_state, n_t, n_s}
float * C = (float *) ((char *) src5 + i2*src5_nb1 + i3*src5_nb2); // {d_state, n_t, n_s}
const float * s0 = (const float *) ((const char *) src0 + ir0*src0_nb1 + i3*src0_nb2); // {d_state, d_inner, n_s}
const float * x = (const float *) ((const char *) src1 + ir0*src1_nb0 + i2*src1_nb1 + i3*src1_nb2); // {d_inner, n_t, n_s}
const float * dt = (const float *) ((const char *) src2 + ir0*src2_nb0 + i2*src2_nb1 + i3*src2_nb2); // {d_inner, n_t, n_s}
const float * A = (const float *) ((const char *) src3 + ir0*src3_nb1); // {d_state, d_inner}
const float * B = (const float *) ((const char *) src4 + i2*src4_nb1 + i3*src4_nb2); // {d_state, n_t, n_s}
const float * C = (const float *) ((const char *) src5 + i2*src5_nb1 + i3*src5_nb2); // {d_state, n_t, n_s}
float * y = (float *) ((char *) dst + ir0*src1_nb0 + i2*src1_nb1 + i3*src1_nb2); // {d_inner, n_t, n_s}
float * s = (float *) ((char *) dst + ir0*src0_nb1 + i3*src0_nb2 + src1_nb3); // {d_state, d_inner, n_s}
// use the output as the source for the next token-wise iterations
if (i2 > 0) { s0 = s; }
// d_inner
for (int i1 = 0; i1 < ir; ++i1) {
@ -48,7 +51,7 @@ static __global__ void ssm_scan_f32(
for (int i0 = 0; i0 < nc; ++i0) {
int i = i0 + i1*nc;
// state = prev_state * dA + dB * x
float state = (s[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
float state = (s0[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
// y = rowwise_dotprod(state, C)
sumf += state * C[i0];
s[i] = state;
@ -63,13 +66,12 @@ static void ssm_scan_f32_cuda(
const float * src0, const float * src1, const float * src2, const float * src3,
const float * src4, const float * src5,
const int src0_nb1, const int src0_nb2,
const int src1_nb0, const int src1_nb1, const int src1_nb2,
const int src1_nb0, const int src1_nb1, const int src1_nb2, const int src1_nb3,
const int src2_nb0, const int src2_nb1, const int src2_nb2,
const int src3_nb1,
const int src4_nb1, const int src4_nb2,
const int src5_nb1, const int src5_nb2,
float * dst,
const int dst_nb0, const int dst_nb1, const int dst_nb2,
const int nc, const int nr, const int n_t, const int n_s,
cudaStream_t stream) {
@ -80,13 +82,12 @@ static void ssm_scan_f32_cuda(
src0, src1, src2, src3,
src4, src5,
src0_nb1, src0_nb2,
src1_nb0, src1_nb1, src1_nb2,
src1_nb0, src1_nb1, src1_nb2, src1_nb3,
src2_nb0, src2_nb1, src2_nb2,
src3_nb1,
src4_nb1, src4_nb2,
src5_nb1, src5_nb2,
dst,
dst_nb0, dst_nb1, dst_nb2,
nc, nr, n_t, n_s);
}
@ -103,7 +104,7 @@ void ggml_cuda_op_ssm_scan(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const int64_t n_t = src1->ne[1]; // number of tokens per sequence
const int64_t n_s = src0->ne[2]; // number of sequences in the batch
GGML_ASSERT(ggml_nelements(src1) == ggml_nelements(dst));
GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst));
GGML_ASSERT(src0->nb[0] == sizeof(float));
GGML_ASSERT(src1->nb[0] == sizeof(float));
GGML_ASSERT(src2->nb[0] == sizeof(float));
@ -112,6 +113,10 @@ void ggml_cuda_op_ssm_scan(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
GGML_ASSERT(src5->nb[0] == sizeof(float));
// required for the dot product between s and C
GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));
// required for per-sequence offsets for states
GGML_ASSERT(src0->nb[2] == src0->ne[0]*src0->ne[1]*sizeof(float));
// required to get correct offset for state destination (i.e. src1->nb[3])
GGML_ASSERT(src1->nb[3] == src1->ne[0]*src1->ne[1]*src1->ne[2]*sizeof(float));
const float * src0_d = (const float *)src0->data;
const float * src1_d = (const float *)src1->data;
@ -129,13 +134,12 @@ void ggml_cuda_op_ssm_scan(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
src0_d, src1_d, src2_d, src3_d,
src4_d, src5_d,
src0->nb[1], src0->nb[2],
src1->nb[0], src1->nb[1], src1->nb[2],
src1->nb[0], src1->nb[1], src1->nb[2], src1->nb[3],
src2->nb[0], src2->nb[1], src2->nb[2],
src3->nb[1],
src4->nb[1], src4->nb[2],
src5->nb[1], src5->nb[2],
dst_d,
dst->nb[0], dst->nb[1], dst->nb[2],
nc, nr, n_t, n_s,
stream);
}

View File

@ -1662,10 +1662,9 @@ struct test_ssm_conv : public test_case {
: type(type), d_conv(d_conv), d_inner(d_inner), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * s = ggml_new_tensor_3d(ctx, type, d_conv - 1, d_inner, n_seqs);
ggml_tensor * x = ggml_new_tensor_3d(ctx, type, d_inner, n_seq_tokens, n_seqs);
ggml_tensor * sx = ggml_new_tensor_3d(ctx, type, d_conv - 1 + n_seq_tokens, d_inner, n_seqs);
ggml_tensor * c = ggml_new_tensor_2d(ctx, type, d_conv, d_inner);
ggml_tensor * out = ggml_ssm_conv(ctx, s, x, c);
ggml_tensor * out = ggml_ssm_conv(ctx, sx, c);
return out;
}
};