diff --git a/ggml.c b/ggml.c index 58ac97026..7a3a5fa94 100644 --- a/ggml.c +++ b/ggml.c @@ -7103,40 +7103,35 @@ struct ggml_tensor * ggml_ssm_conv( struct ggml_context * ctx, struct ggml_tensor * s, struct ggml_tensor * x, - struct ggml_tensor * c, - struct ggml_tensor * sq) { + struct ggml_tensor * c) { GGML_ASSERT(ggml_is_3d(s)); - GGML_ASSERT(ggml_is_matrix(x)); + GGML_ASSERT(ggml_is_3d(x)); GGML_ASSERT(ggml_is_matrix(c)); - GGML_ASSERT(ggml_is_vector(sq)); - GGML_ASSERT(sq->type == GGML_TYPE_I32); - const int64_t d_conv = c->ne[0]; - const int64_t d_inner = c->ne[1]; - const int64_t n_tokens = x->ne[1]; - const int64_t n_rs = s->ne[2]; + const int64_t d_conv = c->ne[0]; + const int64_t d_inner = c->ne[1]; + const int64_t n_t = x->ne[1]; // tokens per sequence + const int64_t n_s = s->ne[2]; - GGML_ASSERT( s->ne[0] == d_conv - 1); - GGML_ASSERT( s->ne[1] == d_inner); - GGML_ASSERT( x->ne[0] == d_inner); - GGML_ASSERT(sq->ne[0] == n_tokens); + GGML_ASSERT(s->ne[0] == d_conv - 1); + GGML_ASSERT(s->ne[1] == d_inner); + GGML_ASSERT(x->ne[0] == d_inner); + GGML_ASSERT(x->ne[2] == n_s); bool is_node = false; - if (s->grad || x->grad || c->grad || sq->grad) { + if (s->grad || x->grad || c->grad) { GGML_ASSERT(false); // TODO: implement is_node = true; } - // 2-in-1 concatenated x and conv_states, {d_inner, n_tokens} with {d_conv, d_inner, n_rs} - struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, (d_inner*n_tokens) + (d_conv*d_inner*n_rs)); + struct ggml_tensor * result = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_inner, n_t, n_s); result->op = GGML_OP_SSM_CONV; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = s; result->src[1] = x; result->src[2] = c; - result->src[3] = sq; return result; } @@ -7150,40 +7145,43 @@ struct ggml_tensor * ggml_ssm_scan( struct ggml_tensor * dt, struct ggml_tensor * A, struct ggml_tensor * B, - struct ggml_tensor * C, - struct ggml_tensor * sq) { + struct ggml_tensor * C) { GGML_ASSERT(ggml_is_contiguous(s)); GGML_ASSERT(ggml_is_contiguous(x)); GGML_ASSERT(ggml_is_contiguous(dt)); GGML_ASSERT(ggml_is_contiguous(A)); - GGML_ASSERT(sq->type == GGML_TYPE_I32); + GGML_ASSERT(ggml_is_matrix(A)); + GGML_ASSERT(ggml_is_3d(B)); + GGML_ASSERT(ggml_is_3d(s)); GGML_ASSERT(B->nb[0] == ggml_type_size(B->type)); GGML_ASSERT(C->nb[0] == ggml_type_size(C->type)); GGML_ASSERT(ggml_are_same_shape(x, dt)); + GGML_ASSERT(ggml_are_same_shape(B, C)); { - const int64_t d_state = s->ne[0]; - const int64_t d_inner = s->ne[1]; - const int64_t n_tokens = x->ne[1]; + const int64_t d_state = s->ne[0]; + const int64_t d_inner = s->ne[1]; + const int64_t n_seq_tokens = x->ne[1]; + const int64_t n_seqs = x->ne[2]; + GGML_ASSERT(s->ne[2] == n_seqs); GGML_ASSERT(x->ne[0] == d_inner); GGML_ASSERT(A->ne[0] == d_state); GGML_ASSERT(A->ne[1] == d_inner); GGML_ASSERT(B->ne[0] == d_state); - GGML_ASSERT(B->ne[1] == n_tokens); - GGML_ASSERT(C->ne[0] == d_state); - GGML_ASSERT(C->ne[1] == n_tokens); + GGML_ASSERT(B->ne[1] == n_seq_tokens); + GGML_ASSERT(B->ne[2] == n_seqs); } bool is_node = false; - if (s->grad || x->grad || dt->grad || A->grad || B->grad || C->grad || sq->grad) { + if (s->grad || x->grad || dt->grad || A->grad || B->grad || C->grad) { GGML_ASSERT(false); // TODO: implement is_node = true; } - // 2-in-1 concatenated y and ssm_states, {d_inner, n_tokens} with {d_state, d_inner, n_rs} - struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ggml_nelements(x) + ggml_nelements(s)); + // y + struct ggml_tensor * result = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, x->ne[0], x->ne[1], x->ne[2]); result->op = GGML_OP_SSM_SCAN; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; @@ -7193,7 +7191,6 @@ struct ggml_tensor * ggml_ssm_scan( result->src[3] = A; result->src[4] = B; result->src[5] = C; - result->src[6] = sq; return result; } @@ -16249,24 +16246,20 @@ static void ggml_compute_forward_ssm_conv_f32( const struct ggml_tensor * src0 = dst->src[0]; // conv_state const struct ggml_tensor * src1 = dst->src[1]; // x const struct ggml_tensor * src2 = dst->src[2]; // conv1d.weight - const struct ggml_tensor * src3 = dst->src[3]; // state_seq const int ith = params->ith; const int nth = params->nth; - const int nc = src2->ne[0]; // d_conv - const int nr = src0->ne[1]; // d_inner - const int n_t = src1->ne[1]; // n_tokens - const int n_rs = src0->ne[2]; // max number of sequences in the batch + const int nc = src2->ne[0]; // d_conv + const int nr = src0->ne[1]; // d_inner + const int n_t = src1->ne[1]; // tokens per sequence + const int n_s = src0->ne[2]; // number of sequences in the batch - GGML_ASSERT((nr*n_t) + (nc*nr*n_rs) == ggml_nelements(dst)); + GGML_ASSERT(ggml_are_same_shape(src1, dst)); GGML_ASSERT(src0->nb[0] == sizeof(float)); GGML_ASSERT(src1->nb[0] == sizeof(float)); GGML_ASSERT(src2->nb[0] == sizeof(float)); - GGML_ASSERT(src3->nb[0] == sizeof(int32_t)); GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float)); - // for use with the destination state offset between sequences - GGML_ASSERT(src2->nb[2] == src2->ne[1]*src2->ne[0]*sizeof(float)); // rows per thread const int dr = (nr + nth - 1)/nth; @@ -16276,64 +16269,53 @@ static void ggml_compute_forward_ssm_conv_f32( const int ir1 = MIN(ir0 + dr, nr); const int ir = ir1 - ir0; - const int32_t * sq = src3->data; // {n_tokens} + // TODO: maybe require src0 to have d_conv columns instead of (d_conv - 1)? + // This would avoid having to copy into an intermediate buffer, but the state would be bigger. + float * s = (float *) params->wdata + (nc*dr + CACHE_LINE_SIZE_F32) * ith; - if (n_rs > 1) { - // multiple sequences means it's hard to know when it's the first time a state is read, - // so copy them all over to the destination, just to be sure. - for (int i3 = 0; i3 < n_rs; ++i3) { - float * s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); - float * s = (float *) ((char *) dst->data + ir0*(src2->nb[1]) + i3*(src2->nb[2]) + nr*n_t*sizeof(float)); - // can't use memcpy because of d_conv vs d_conv - 1 - for (int i1 = 0; i1 < ir; ++i1) { - for (int i0 = 0; i0 < nc - 1; ++i0) { - // copy s0 to last (d_conv - 1) columns of s - s[1 + i0 + i1*nc] = s0[i0 + i1*(nc - 1)]; - } - } - } - } + for (int i3 = 0; i3 < n_s; ++i3) { + float * s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_conv, d_inner, n_s} - for (int i2 = 0; i2 < n_t; ++i2) { - int32_t sq_i = sq[i2]; - float * x = (float *) ((char *) dst->data + ir0*sizeof(float) + i2*(nr*sizeof(float))); // {d_inner, n_tokens} - float * s = (float *) ((char *) dst->data + ir0*(src2->nb[1]) + sq_i*(src2->nb[2]) + nr*n_t*sizeof(float)); // {d_conv, d_inner, n_rs} - float * s0; // {d_conv - 1, d_inner, n_rs} - float * x0 = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens} - float * c = (float *) ((char *) src2->data + ir0*(src2->nb[1])); // {d_conv, d_inner} - int ne0s0; - - GGML_ASSERT(0 <= sq_i && sq_i < n_rs); - - // avoid needing to copy the state for the first token - if (i2 == 0) { - s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + sq_i*(src0->nb[2])); // {d_conv - 1, d_inner, n_rs} - ne0s0 = src0->ne[0]; - } else { - // the source is the last (d_conv - 1) columns of the destination - s0 = s + 1; - ne0s0 = nc; - } - - // d_inner + // copy the state into working memory + // can't use memcpy because (d_conv) != (d_conv - 1) for (int i1 = 0; i1 < ir; ++i1) { - // shift state left for (int i0 = 0; i0 < nc - 1; ++i0) { - s[i0 + i1*nc] = s0[i0 + i1*ne0s0]; + s[1 + i0 + i1*nc] = s0[i0 + i1*(nc - 1)]; } - // insert x on the last column - s[(nc - 1) + i1*nc] = x0[i1]; } - // it seems a little faster when this is separate from the state shift - for (int i1 = 0; i1 < ir; ++i1) { - // rowwise dot product - float sumf = 0.0f; - for (int i0 = 0; i0 < nc; ++i0) { - int i = i0 + i1*nc; - sumf += s[i] * c[i]; + for (int i2 = 0; i2 < n_t; ++i2) { + float * x = (float *) ((char *) dst->data + ir0*( dst->nb[0]) + i2*( dst->nb[1]) + i3*( dst->nb[2])); // {d_inner, n_t, n_s} + float * x0 = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s} + float * c = (float *) ((char *) src2->data + ir0*(src2->nb[1])); // {d_conv, d_inner} + + // shift state left + memmove(s, s + 1, (nc*ir - 1) * sizeof(float)); + + // d_inner + for (int i1 = 0; i1 < ir; ++i1) { + // insert x on the last column + s[(nc - 1) + i1*nc] = x0[i1]; + } + + // it seems a little faster when this is separate from the state shift + for (int i1 = 0; i1 < ir; ++i1) { + // rowwise dot product + // NOTE: not using ggml_vec_dot_f32, because its sum is in double precision + float sumf = 0.0f; + for (int i0 = 0; i0 < nc; ++i0) { + int i = i0 + i1*nc; + sumf += s[i] * c[i]; + } + x[i1] = sumf; + } + } + + // copy the state out of it + for (int i1 = 0; i1 < ir; ++i1) { + for (int i0 = 0; i0 < nc - 1; ++i0) { + s0[i0 + i1*(nc - 1)] = s[1 + i0 + i1*nc]; } - x[i1] = sumf; } } } @@ -16368,30 +16350,24 @@ static void ggml_compute_forward_ssm_scan_f32( const struct ggml_tensor * src3 = dst->src[3]; // A const struct ggml_tensor * src4 = dst->src[4]; // B const struct ggml_tensor * src5 = dst->src[5]; // C - const struct ggml_tensor * src6 = dst->src[6]; // sq const int ith = params->ith; const int nth = params->nth; - const int64_t nc = src0->ne[0]; // d_state - const int64_t nr = src0->ne[1]; // d_inner - const int64_t n_t = src1->ne[1]; // number of tokens in the batch - const int64_t n_rs = src0->ne[2]; // max number of sequences in the batch + const int64_t nc = src0->ne[0]; // d_state + const int64_t nr = src0->ne[1]; // d_inner + const int64_t n_t = src1->ne[1]; // number of tokens per sequence + const int64_t n_s = src0->ne[2]; // number of sequences in the batch - GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst)); + GGML_ASSERT(ggml_nelements(src1) == ggml_nelements(dst)); GGML_ASSERT(src0->nb[0] == sizeof(float)); GGML_ASSERT(src1->nb[0] == sizeof(float)); GGML_ASSERT(src2->nb[0] == sizeof(float)); GGML_ASSERT(src3->nb[0] == sizeof(float)); GGML_ASSERT(src4->nb[0] == sizeof(float)); GGML_ASSERT(src5->nb[0] == sizeof(float)); - GGML_ASSERT(src6->nb[0] == sizeof(int32_t)); - // required for the dot product between s and C, and when copying the states + // required for the dot product between s and C GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float)); - // required for per-sequence offsets for states - GGML_ASSERT(src0->nb[2] == src0->ne[0]*src0->ne[1]*sizeof(float)); - // required to get correct offset for state destination (i.e. src1->nb[2]) - GGML_ASSERT(src1->nb[2] == src1->ne[0]*src1->ne[1]*sizeof(float)); // rows per thread const int dr = (nr + nth - 1)/nth; @@ -16401,55 +16377,33 @@ static void ggml_compute_forward_ssm_scan_f32( const int ir1 = MIN(ir0 + dr, nr); const int ir = ir1 - ir0; - const int32_t * sq = src6->data; // {n_tokens} + for (int i3 = 0; i3 < n_s; ++i3) { + for (int i2 = 0; i2 < n_t; ++i2) { + float * y = (float *) ((char *) dst->data + ir0*( dst->nb[0]) + i2*( dst->nb[1]) + i3*( dst->nb[2])); // {d_inner, n_t, n_s} + float * s = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s} + float * x = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s} + float * dt = (float *) ((char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s} + float * A = (float *) ((char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner} + float * B = (float *) ((char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s} + float * C = (float *) ((char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s} - if (n_rs > 1) { - // it's hard to know if the source states have already been copied - // when there are multiple, so copy them already. - for (int i3 = 0; i3 < n_rs; ++i3) { - float * s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); - float * s = (float *) ((char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[2]); - memcpy(s, s0, nc*ir*sizeof(float)); - } - } - - for (int i2 = 0; i2 < n_t; ++i2) { - int32_t sq_i = sq[i2]; - float * y = (float *) ((char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens} - float * s = (float *) ((char *) dst->data + ir0*(src0->nb[1]) + sq_i*(src0->nb[2]) + src1->nb[2]); // {d_state, d_inner, n_rs} - float * s0; - float * x = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens} - float * dt = (float *) ((char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1])); // {d_inner, n_tokens} - float * A = (float *) ((char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner} - float * B = (float *) ((char *) src4->data + i2*(src4->nb[1])); // {d_state, n_tokens} - float * C = (float *) ((char *) src5->data + i2*(src5->nb[1])); // {d_state, n_tokens} - - GGML_ASSERT(0 <= sq_i && sq_i < n_rs); - - // avoid needing to copy the state for the first token - if (i2 == 0) { - s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + sq_i*(src0->nb[2])); // {d_state, d_inner, n_rs} - } else { - // otherwise the source is the same as the destination - s0 = s; - } - - // d_inner - for (int i1 = 0; i1 < ir; ++i1) { - // ref: https://github.com/state-spaces/mamba/blob/34076d664838588a3c97727b263478ab9f621a07/mamba_ssm/ops/triton/selective_state_update.py#L78 - float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1]; - float x_dt = x[i1] * dt_soft_plus; - float sumf = 0.0f; - // d_state - for (int i0 = 0; i0 < nc; ++i0) { - int i = i0 + i1*nc; - // state = prev_state * dA + dB * x - float state = (s0[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt); - // y = rowwise_dotprod(state, C) - sumf += state * C[i0]; - s[i] = state; + // d_inner + for (int i1 = 0; i1 < ir; ++i1) { + // ref: https://github.com/state-spaces/mamba/blob/34076d664838588a3c97727b263478ab9f621a07/mamba_ssm/ops/triton/selective_state_update.py#L78 + float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1]; + float x_dt = x[i1] * dt_soft_plus; + float sumf = 0.0f; + // d_state + for (int i0 = 0; i0 < nc; ++i0) { + int i = i0 + i1*nc; + // state = prev_state * dA + dB * x + float state = (s[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt); + // y = rowwise_dotprod(state, C) + sumf += state * C[i0]; + s[i] = state; + } + y[i1] = sumf; } - y[i1] = sumf; } } } @@ -19614,7 +19568,13 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2 } } break; + case GGML_OP_SSM_CONV: + { + const int64_t d_conv = node->src[2]->ne[0]; + const int64_t d_inner = node->src[0]->ne[1]; + cur += sizeof(float)*d_conv*(d_inner + n_tasks - 1); + } break; case GGML_OP_CROSS_ENTROPY_LOSS: { cur = ggml_type_size(node->type)*(n_tasks + node->src[0]->ne[0]*n_tasks); diff --git a/ggml.h b/ggml.h index 4e6bcb30f..bdf05a311 100644 --- a/ggml.h +++ b/ggml.h @@ -1793,8 +1793,7 @@ extern "C" { struct ggml_context * ctx, struct ggml_tensor * s, struct ggml_tensor * x, - struct ggml_tensor * c, - struct ggml_tensor * sq); + struct ggml_tensor * c); GGML_API struct ggml_tensor * ggml_ssm_scan( struct ggml_context * ctx, @@ -1803,8 +1802,7 @@ extern "C" { struct ggml_tensor * dt, struct ggml_tensor * A, struct ggml_tensor * B, - struct ggml_tensor * C, - struct ggml_tensor * sq); + struct ggml_tensor * C); // partition into non-overlapping windows with padding if needed // example: diff --git a/llama.cpp b/llama.cpp index 27374c185..ca64b7e29 100644 --- a/llama.cpp +++ b/llama.cpp @@ -2114,6 +2114,24 @@ struct llama_layer { struct ggml_tensor * rope_short = nullptr; }; +// very similar to llama_batch, +// but has more metadata about sequences +struct llama_ubatch { + bool equal_seqs; + + uint32_t n_tokens; // total tokens (n_seq_tokens * n_seqs) + uint32_t n_seq_tokens; // tokens per sequence + uint32_t n_seqs; + + llama_token * token; // [n_tokens] + float * embd; // [n_embd, n_tokens] + llama_pos * pos; // [n_tokens] + // FIXME: make all uses of this use n_seqs + int32_t * n_seq_id; // [n_seqs] + llama_seq_id ** seq_id; // [n_seqs] + int8_t * output; // [n_tokens] +}; + struct llama_kv_cell { llama_pos pos = -1; llama_pos delta = 0; @@ -2223,17 +2241,15 @@ struct llama_rs_cell { } }; - struct llama_rs_seq_meta { // cell id of the latest state of this seq_id int32_t tail = -1; // number of cells for which this seq_id is the first // (useful to know if cells in this sequence should be pruned) int32_t n_cells = 0; - // changing the tail cell of a sequence can only be done at batch boundary, - // this guards against changing the cell when it shouldn't be; - // should be cleared when done finding a slot - bool in_ubatch = false; + // the last pos of this sequence if it is in the current ubatch, + // only set and used when finding a slot. + llama_pos ubatch_end_pos = -1; }; // ring-buffered tree of cached recurrent state data @@ -2261,6 +2277,10 @@ struct llama_rs_cache { // find tail cells faster std::vector seq_tails; // map seq_ids to cell ids + // freeable cell ids, computed when finding a slot + // useful to find the smallest range to defrag + std::vector freeable; + // per layer // NOTE: the naming of r and s is arbitrary std::vector r_l; // rolling/shift states @@ -2399,8 +2419,8 @@ struct llama_rs_cache { if (seq_node->next_cell != next) { // TODO: relax the error when multiple cells have the same pos if (debug) { - LLAMA_LOG_ERROR("%s: invalid next cell for cells[%u] (%d instead of %d)\n", - __func__, cell_id, seq_node->next_cell, next); + LLAMA_LOG_ERROR("%s: invalid next cell for seq_id %d in cells[%u] (%d instead of %d)\n", + __func__, seq_id, cell_id, seq_node->next_cell, next); } seq_node->next_cell = next; was_valid = false; @@ -2414,15 +2434,6 @@ struct llama_rs_cache { } seq.n_cells = n_cells; } - // in_batch should only be true when in the process of finding a slot - if (seq.in_ubatch != false) { - if (debug) { - LLAMA_LOG_ERROR("%s: in_ubatch was true while it should have been false for seq_id %d\n", - __func__, seq_id); - } - seq.in_ubatch = false; - was_valid = false; - } } // tail_rc for (uint32_t cell_id = 0; cell_id < size; ++cell_id) { @@ -2475,6 +2486,88 @@ struct llama_rs_cache { return was_valid; } + // each seq_id should have access to at least this many cells + // (to use when pruning (to avoid over-pruning)) + uint32_t min_cells_per_seq(const llama_ubatch & batch) const { + uint32_t seqs = n_seqs; + for (uint32_t i = 0; i < batch.n_seqs; ++i) { + llama_seq_id seq_id = batch.seq_id[i][0]; + const llama_rs_seq_meta & new_seq = seq_tails[seq_id]; + if (new_seq.tail < 0 || new_seq.n_cells == 0) { + seqs += 1; + } + } + return (size - n_shared_tail_cells) / (seqs > 0 ? seqs : 1); + } + + void freeable_for_batch(const llama_ubatch & batch, llama_pos checkpoint_interval) { + GGML_ASSERT(batch.equal_seqs); + int32_t min_cells = min_cells_per_seq(batch); + + // TODO: minimize work required to find freeable cells + // currently, this finds freeable cells by excluding non-freeable cells, + // because some conditions are more easily expressed this way. + + freeable.assign(size, 1); + + for (llama_rs_seq_meta & seq : seq_tails) { + seq.ubatch_end_pos = -1; + } + + for (uint32_t i = 0; i < batch.n_seqs; ++i) { + int32_t n_seq_id = batch.n_seq_id[i]; + for (int32_t j = 0; j < n_seq_id; j++) { + llama_seq_id seq_id = batch.seq_id[i][j]; + GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_tails.size()); + llama_rs_seq_meta & seq = seq_tails[seq_id]; + seq.ubatch_end_pos = batch.pos[i * batch.n_seq_tokens + batch.n_seq_tokens - 1]; + } + } + + for (llama_rs_seq_meta & seq : seq_tails) { + if (seq.tail >= 0 && freeable[seq.tail] != 0) { + llama_pos end_pos = seq.ubatch_end_pos; + // When is a tail cell not freeable? + if (end_pos < 0) { + // when any of its tails are not in the batch + freeable[seq.tail] = 0; + } else if (min_cells > 1) { + // TODO: fallback to this less often + llama_rs_cell & tail = cells[seq.tail]; + GGML_ASSERT(tail.pos < end_pos); + if (tail.prev < 0 || tail.pos + checkpoint_interval <= end_pos) { + // make a checkpoint before prompt processing + // TODO: should it always be done after instead? + freeable[seq.tail] = 0; + } else { + llama_rs_cell & prev = cells[tail.prev]; + if (prev.pos + checkpoint_interval <= end_pos) { + // make a checkpoint during text generation + freeable[seq.tail] = 0; + } + } + } + } + } + + for (uint32_t i = 0; i < size; ++i) { + llama_rs_cell & cell = cells[i]; + if (!cell.is_empty() && cell.tail_rc == 0) { + // TODO: reduce indirection here + llama_rs_seq_node & seq_node = cell.seq_nodes[0]; + llama_rs_seq_meta & seq = seq_tails[seq_node.seq_id]; + bool keep_tail = freeable[seq.tail] == 0; + // kept tails use an additional cell, so make them allow freeing a checkpoint + int32_t really_min_cells = keep_tail ? min_cells - 1 : min_cells; + // A checkpoint is kept if there's enough alloted space for this sequence + // or if it's the state right before the tail + if (seq.n_cells <= really_min_cells || (really_min_cells > 1 && seq_node.next_cell == seq.tail)) { + freeable[i] = 0; + } + } + } + } + // returns an iterator to the seq_node after the removed one, or the same which was passed if it wasn't removed. // Why an iterator? Because it allows using std::vector::erase. std::vector::iterator remove_seq_node_from_cell(llama_rs_cell & rs_cell, std::vector::iterator node_iter) { @@ -2496,22 +2589,30 @@ struct llama_rs_cache { prev_node->next_cell = node.next_cell; if (node.is_tail()) { // move the tail back to the previous cell + prev_cell.tail_rc += 1; if (prev_cell.seq_nodes.size() > 1) { if (rs_cell.tail_rc == rs_cell.seq_nodes.size()) { - if (prev_cell.tail_rc == 0) { + if (prev_cell.tail_rc == 1) { n_shared_tail_cells += 1; } - // o oo oo - // |/ -> o/ - // | | - // e.g. when removing the leaf with a single tail - if (rs_cell.tail_rc == 1 && prev_cell.tail_rc != prev_cell.seq_nodes.size()) { - n_seqs -= 1; + if (rs_cell.tail_rc == 1) { + if (prev_cell.tail_rc != prev_cell.seq_nodes.size()) { + // o oo oo + // |/ -> o/ + // | | + // e.g. when removing the leaf of a split tree + n_seqs -= 1; + } else { + // o + // o -> oo + // | | + // e.g. when merging back with a previous tail + n_shared_tail_cells -= 1; + } } } } - prev_cell.tail_rc += 1; } } if ((uint32_t) node.seq_id < seq_tails.size()) { @@ -2534,6 +2635,7 @@ struct llama_rs_cache { // will fully become a tail cell if (rs_cell.tail_rc > 0) { n_seqs += 1; + n_shared_tail_cells -= 1; } } if (node_iter == rs_cell.seq_nodes.begin()) { @@ -2583,14 +2685,107 @@ struct llama_rs_cache { return false; } - bool insert_seq_tail_to_cell_id(uint32_t i_cell, const llama_seq_id & id) { + bool swap_cells(uint32_t i_src, uint32_t i_dst) { + if (i_src < size && i_dst < size && i_src != i_dst) { + llama_rs_cell & src = cells[i_src]; + llama_rs_cell & dst = cells[i_dst]; + + for (llama_rs_seq_node & seq_node : src.seq_nodes) { + if (seq_node.next_cell >= 0) { + llama_rs_cell & next = cells[seq_node.next_cell]; + next.prev = i_dst; + if ((uint32_t) seq_node.next_cell == i_dst) { + seq_node.next_cell = i_src; + } + } else { + // this is a tail + seq_tails[seq_node.seq_id].tail = i_dst; + } + } + for (llama_rs_seq_node & seq_node : dst.seq_nodes) { + if (seq_node.next_cell >= 0) { + llama_rs_cell & next = cells[seq_node.next_cell]; + next.prev = i_src; + if ((uint32_t) seq_node.next_cell == i_src) { + seq_node.next_cell = i_dst; + } + } else { + // this is a tail + seq_tails[seq_node.seq_id].tail = i_src; + } + } + + if (src.prev == dst.prev) { + // avoid swapping them twice + if (src.prev >= 0) { + llama_rs_cell & prev = cells[src.prev]; + for (llama_rs_seq_node & seq_node : prev.seq_nodes) { + if ((uint32_t) seq_node.next_cell == i_src) { + seq_node.next_cell = i_dst; + } else if ((uint32_t) seq_node.next_cell == i_dst) { + seq_node.next_cell = i_src; + } + } + } + } else { + if (src.prev >= 0) { + llama_rs_cell & prev = cells[src.prev]; + for (llama_rs_seq_node & seq_node : prev.seq_nodes) { + if ((uint32_t) seq_node.next_cell == i_src) { + seq_node.next_cell = i_dst; + } + } + } + if (dst.prev >= 0) { + llama_rs_cell & prev = cells[dst.prev]; + for (llama_rs_seq_node & seq_node : prev.seq_nodes) { + if ((uint32_t) seq_node.next_cell == i_dst) { + seq_node.next_cell = i_src; + } + } + } + } + + std::swap(src.pos, dst.pos); + std::swap(src.src, dst.src); + std::swap(src.prev, dst.prev); + std::swap(src.tail_rc, dst.tail_rc); + std::swap(src.seq_nodes, dst.seq_nodes); + + return true; + } + return false; + } + + bool insert_seq_tail_to_cell_id(uint32_t i_cell, llama_seq_id id, llama_pos end_pos = -1) { if (i_cell < size && (size_t) id < seq_tails.size()) { llama_rs_cell & rs_cell = cells[i_cell]; auto & seq = seq_tails[id]; int32_t prev = rs_cell.prev; + if (end_pos >= 0) { + if (end_pos <= rs_cell.pos) { + // What should happen when the pos backtracks or skips a value? + // Clearing the state mid-batch would require special-casing which isn't done. + LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d\n", + __func__, end_pos, rs_cell.pos, id); + } + rs_cell.pos = end_pos; + } else { + // if no pos was specified, then the target cell should already have a valid one. + GGML_ASSERT(!rs_cell.is_empty()); + } if ((uint32_t) seq.tail == i_cell) { // the cell is already the tail of this seq_id - return false; + if (rs_cell.tail_rc != rs_cell.seq_nodes.size()) { + GGML_ASSERT(end_pos >= 0); // make sure this is the first re-added seq_id + // remove non-tail seq_ids (branch off them) + for (size_t i = rs_cell.seq_nodes.size(); i-- > 0;) { + if (!rs_cell.seq_nodes[i].is_tail()) { + remove_seq_node_from_cell(rs_cell, rs_cell.seq_nodes.begin() + i); + } + } + } + return true; } if (rs_cell.is_empty()) { prev = seq.tail; @@ -2603,9 +2798,7 @@ struct llama_rs_cache { auto prev_node = std::find(prev_cell.seq_nodes.begin(), prev_cell.seq_nodes.end(), id); GGML_ASSERT(prev_node != prev_cell.seq_nodes.end()); // TODO: recursive insert instead of failing GGML_ASSERT(prev_node->next_cell == -1); // or else a chain is broken - if (rs_cell.pos < 0) { - GGML_ASSERT(rs_cell.is_empty()); - rs_cell.pos = prev_cell.pos + 1; + if (rs_cell.is_empty()) { rs_cell.src = prev_cell.src; } prev_node->next_cell = i_cell; @@ -2650,8 +2843,7 @@ struct llama_rs_cache { if (seq.tail < 0) { // from empty to unique n_seqs += 1; - // pos was not yet set - rs_cell.pos = 0; + // make sure it's cleared rs_cell.src = -1; } used += 1; @@ -2671,16 +2863,6 @@ struct llama_rs_cache { return false; } - // each seq_id should have access to at least this many cells - // (to use when pruning (to avoid over-pruning)) - size_t min_cells_per_seq(const llama_rs_seq_meta & new_seq) const { - uint32_t seqs = n_seqs; - if (new_seq.tail < 0 || new_seq.n_cells == 0) { - seqs += 1; - } - return (size - n_shared_tail_cells) / (seqs > 0 ? seqs : 1); - } - size_t total_size() const { size_t size = 0; for (struct ggml_tensor * r : r_l) { @@ -2883,22 +3065,6 @@ struct llama_model { } }; -// very similar to llama_batch, -// but has more metadata about sequences -struct llama_ubatch { - bool equal_seqs; - - int32_t n_tokens; - int32_t n_seqs; - - llama_token * token; - float * embd; - llama_pos * pos; - int32_t * n_seq_id; - llama_seq_id ** seq_id; - int8_t * output; -}; - struct llama_sbatch_seq { int32_t n_seq_id; llama_seq_id * seq_id; @@ -2954,6 +3120,7 @@ struct llama_sbatch { true, 0, 0, + 0, !has_embd ? ubatch_token.data() : nullptr, has_embd ? ubatch_embd.data() : nullptr, ubatch_pos.data(), @@ -2967,16 +3134,14 @@ struct llama_sbatch { void add_seq_to_ubatch(llama_ubatch & ubatch, llama_sbatch_seq & seq, size_t length) { GGML_ASSERT(batch != nullptr); GGML_ASSERT(length <= seq.length); - if (ubatch.equal_seqs) { - // is the new sequence of a different size than expected? - if (ubatch.n_seqs > 0 && length != (size_t) ubatch.n_tokens / ubatch.n_seqs) { - ubatch.equal_seqs = false; - } - } + // Can only add sequences of equal lengths to a batch, + // otherwise it isn't clear to which sequence a token belongs + GGML_ASSERT(seq.n_seq_id == 0 || ubatch.n_seqs == 0 || length == (size_t) ubatch.n_tokens / ubatch.n_seqs); + GGML_ASSERT((seq.n_seq_id != 0) == ubatch.equal_seqs); // NOTE: loops are separated for cache-friendliness if (batch->token) { for (size_t i = 0; i < length; ++i) { - ubatch.token[ubatch.n_tokens + i] = batch->token[ids[seq.offset + i]]; + ubatch.token[ubatch.n_tokens + i] = batch->token[ids[seq.offset + i]]; } } else { ubatch.token = nullptr; @@ -3004,22 +3169,32 @@ struct llama_sbatch { ubatch.pos[ubatch.n_tokens + i] = batch->all_pos_0 + (bi * batch->all_pos_1); } } - if (batch->n_seq_id) { - for (size_t i = 0; i < length; ++i) { - ubatch.n_seq_id[ubatch.n_tokens + i] = batch->n_seq_id[ids[seq.offset + i]]; + if (seq.n_seq_id > 0) { + ubatch.n_seq_id[ubatch.n_seqs] = seq.n_seq_id; + if (seq.seq_id) { + ubatch.seq_id[ubatch.n_seqs] = seq.seq_id; + } else { + GGML_ASSERT(seq.n_seq_id == 1); + ubatch.seq_id[ubatch.n_seqs] = &seq.all_seq_id; } } else { - for (size_t i = 0; i < length; ++i) { - ubatch.n_seq_id[ubatch.n_tokens + i] = 1; + if (batch->n_seq_id) { + for (size_t i = 0; i < length; ++i) { + ubatch.n_seq_id[ubatch.n_seqs + i] = batch->n_seq_id[ids[seq.offset + i]]; + } + } else { + for (size_t i = 0; i < length; ++i) { + ubatch.n_seq_id[ubatch.n_seqs + i] = 1; + } } - } - if (batch->seq_id) { - for (size_t i = 0; i < length; ++i) { - ubatch.seq_id[ubatch.n_tokens + i] = batch->seq_id[ids[seq.offset + i]]; - } - } else { - for (size_t i = 0; i < length; ++i) { - ubatch.seq_id[ubatch.n_tokens + i] = &seq.all_seq_id; + if (batch->seq_id) { + for (size_t i = 0; i < length; ++i) { + ubatch.seq_id[ubatch.n_seqs + i] = batch->seq_id[ids[seq.offset + i]]; + } + } else { + for (size_t i = 0; i < length; ++i) { + ubatch.seq_id[ubatch.n_seqs + i] = &seq.all_seq_id; + } } } if (batch->logits) { @@ -3043,11 +3218,15 @@ struct llama_sbatch { if (is_last) { out_ids.push_back(id); } } } + if (ubatch.n_tokens == 0 && ubatch.n_seqs == 0) { + ubatch.n_seq_tokens = seq.n_seq_id > 0 ? length : 1; + } ubatch.n_tokens += length; - ubatch.n_seqs += seq.n_seq_id != 0; // don't count seq_ids for legacy splits + ubatch.n_seqs += seq.n_seq_id > 0 ? 1 : length; // virtual sequences for legacy splits seq.offset += length; seq.length -= length; n_tokens -= length; + GGML_ASSERT(ubatch.n_tokens == ubatch.n_seq_tokens * ubatch.n_seqs); } // legacy split, unknown number of sequences of unequal lengths @@ -3283,7 +3462,6 @@ struct llama_context { struct ggml_tensor * inp_cls; // I32 [n_batch] struct ggml_tensor * inp_s_copy; // I32 [n_rs] struct ggml_tensor * inp_s_mask; // F32 [1, n_rs] - struct ggml_tensor * inp_s_seq; // I32 [n_batch] // control vectors struct llama_control_vector cvec; @@ -3426,6 +3604,7 @@ static bool llama_cache_init( cache.rs.cells.resize(rs_size); cache.rs.seq_tails.clear(); cache.rs.seq_tails.resize(rs_size); + cache.rs.freeable.reserve(rs_size); #ifdef GGML_USE_CLBLAST offload = false; @@ -3524,11 +3703,13 @@ static bool llama_cache_find_slot( const uint32_t kv_size = cache.kv.size; const uint32_t rs_size = cache.rs.size; const uint32_t n_tokens = batch.n_tokens; + const uint32_t n_seqs = batch.n_seqs; + const uint32_t n_seq_tokens = batch.n_seq_tokens; // only check first, to allow failing gracefully if (rs_size > 0) { // everything should fit if all seq_ids are smaller than the max - for (uint32_t i = 0; i < n_tokens; ++i) { + for (uint32_t i = 0; i < n_seqs; ++i) { int32_t n_seq_id = batch.n_seq_id[i]; for (int32_t j = 0; j < n_seq_id; ++j) { llama_seq_id seq_id = batch.seq_id[i][j]; @@ -3541,6 +3722,23 @@ static bool llama_cache_find_slot( } } } + // TODO: configurable checkpoint interval + cache.rs.freeable_for_batch(batch, 8); + { + uint32_t freeable_rs_cell_count = 0; + for (uint32_t is_freeable : cache.rs.freeable) { + freeable_rs_cell_count += (uint32_t) (is_freeable != 0); + if (freeable_rs_cell_count >= n_seqs) { + // there's enough, no need to count them all + break; + } + } + if (n_seqs > freeable_rs_cell_count) { + // This should not happen + LLAMA_LOG_ERROR("%s: n_seqs=%d > freeable_rs_cell_count=%d\n", __func__, n_seqs, freeable_rs_cell_count); + return false; + } + } } if (kv_size > 0) { @@ -3591,172 +3789,146 @@ static bool llama_cache_find_slot( if (rs_size > 0) { // For recurrent state architectures (like Mamba), // each cache cell can store the state for a whole sequence. - // TODO: find a way to always make the rs slot contiguous + // A slot should be always be contiguous. - llama_seq_id min_seq = cache.rs.size - 1; - llama_seq_id max_seq = 0; - uint32_t min_cell = cache.rs.size - 1; - uint32_t max_cell = 0; + uint32_t min_head = 0; + uint32_t min_n = cache.rs.size; + uint32_t min_free = 0; - for (uint32_t i = 0; i < n_tokens; ++i) { - int32_t target_cell = -1; // ensure all the sequences of a token get the same cell - int32_t n_seq_ids = batch.n_seq_id[i]; - for (int32_t j = 0; j < n_seq_ids; ++j) { - llama_seq_id seq_id = batch.seq_id[i][j]; - bool need_new_cell = false; - // Everything should fit assuming the biggest seq_id < rs_size - GGML_ASSERT((uint32_t) seq_id < rs_size); - llama_rs_seq_meta & seq = cache.rs.seq_tails[seq_id]; - if (seq_id > max_seq) { max_seq = seq_id; } - if (seq_id < min_seq) { min_seq = seq_id; } + // compact the freeable cell list + // e.g. 0,1,0,0,1,1,0,1,0,1 -> 1,4,5,7,9 + // while also finding the smallest cell range for the slot + { + uint32_t next_free = 0; + for (size_t i = 0; i < cache.rs.freeable.size(); ++i) { + if (cache.rs.freeable[i]) { + cache.rs.freeable[next_free] = i; + next_free += 1; - if (!seq.in_ubatch && target_cell >= 0) { - // never saw this seq_id before, - // but there's already a cell reserved for this token, use it - cache.rs.insert_seq_tail_to_cell_id(target_cell, seq_id); - } else if (seq.tail < 0) { - // this seq_id has no tail (and is empty) - need_new_cell = true; - } else { - llama_rs_cell & tail = cache.rs.cells[seq.tail]; - if (seq.in_ubatch) { - // this seq_id was already seen before in the batch - // assuming the tail cell already "has" this seq_id - tail.pos += 1; - target_cell = seq.tail; - } else { - // first time this sequence is seen, - // there's no reserved cell yet; - // if it's not the first sequence of the token, how could it even get here? - GGML_ASSERT(j == 0); - - bool has_same_seqs = tail.seq_nodes.size() == (size_t) n_seq_ids; - if (has_same_seqs) { - // the tail cell of a seq_id is assumed to already be part of the seq_id, - // hence the skip of the first seq_id - for (int32_t k = 1; k < n_seq_ids; ++k) { - if (batch.seq_id[i][k] != tail.seq_nodes[k].seq_id) { - has_same_seqs = false; - } + if (next_free >= n_seqs) { + uint32_t head = cache.rs.freeable[next_free - n_seqs]; + // i is the last seen freeable cell id + uint32_t n = i - head + 1; + // keep the first smallest big enough slot + if (n < min_n) { + min_free = next_free - n_seqs; + min_head = head; + min_n = n; + if (n == n_seqs) { + // it's the smallest it can be + break; } } - - // TODO: make the checkpoint interval configurable - if (!has_same_seqs || tail.prev < 0 || tail.pos - cache.rs.cells[tail.prev].pos >= 16) { - // a checkpoint should be saved - need_new_cell = true; - } else { - // re-use last tail - tail.pos += 1; - target_cell = seq.tail; - } } } - - // reserve a cell for this seq_id - if (need_new_cell && target_cell < 0) { - const int32_t min_cells_per_seq = cache.rs.min_cells_per_seq(seq); - - uint32_t cell_id = cache.rs.size; - bool looped_once = false; - - while (true) { - if (cache.rs.head >= cache.rs.size) { - cache.rs.head = 0; - // avoid infinite loop - // NOTE: this should not fail; if it does, it's a bug. - GGML_ASSERT(!looped_once && "recurrent state cache seems full, but should not."); - looped_once = true; - } - cell_id = cache.rs.head; - llama_rs_cell & candidate = cache.rs.cells[cell_id]; - if (candidate.is_empty()) { break; } - if (candidate.tail_rc == 1 && seq.tail == (int32_t) cell_id) { - // the candidate is the old tail - if (candidate.seq_nodes.size() > 1) { - // prune out the other seq_ids, because they diverge - // TODO(maybe): hande this in insert_seq_tail_to_cell_id - // (hopefully doesn't happen too often) - for (auto node_iter = candidate.seq_nodes.begin(); node_iter != candidate.seq_nodes.end();) { - if (node_iter->seq_id == seq_id) { - node_iter = std::next(node_iter); - } else { - node_iter = cache.rs.remove_seq_node_from_cell(candidate, node_iter); - } - } - } - // re-use the tail cell to avoid not finding anything - candidate.pos += 1; - break; - } - if (candidate.tail_rc > 0) { - // skip tails of other sequences - cache.rs.head += 1; - continue; - } - if (candidate.seq_nodes.size() > 1) { - // shared prompts are not usually backtracked, so they can be pruned - cache.rs.clear_cell(candidate); - break; - } - - // prune too-long sequences - llama_seq_id seq_id_to_prune = candidate.seq_nodes[0].seq_id; - if (seq_id_to_prune == seq_id) { - // TODO: selectively skip some cells to keep older states - cache.rs.clear_cell(candidate); - break; - } - GGML_ASSERT((size_t) seq_id_to_prune < cache.rs.seq_tails.size()); - auto & seq_to_prune = cache.rs.seq_tails[seq_id_to_prune]; - if (seq_to_prune.n_cells > min_cells_per_seq) { - cache.rs.clear_cell(candidate); - break; - } - cache.rs.head += 1; - } - if (cell_id < cache.rs.size) { - cache.rs.insert_seq_tail_to_cell_id(cell_id, seq_id); - target_cell = cell_id; - } - } - - if (seq.tail >= 0) { - if (min_cell > (uint32_t) seq.tail) { min_cell = seq.tail; } - if (max_cell < (uint32_t) seq.tail) { max_cell = seq.tail; } - seq.in_ubatch = true; - } - - // Assuming the tokens are in-order - if (batch.pos[i] != cache.rs.cells[seq.tail].pos) { - // What should happen when the pos backtracks or skips a value? - // Clearing the state mid-batch would require special-casing which isn't done. - LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d\n", - __func__, batch.pos[i], cache.rs.cells[cache.rs.head].pos - 1, seq_id); - } } - cache.rs.head = target_cell + 1; } - for (llama_seq_id i = min_seq; i <= max_seq; ++i) { - // make sure it's cleared for next time - cache.rs.seq_tails[i].in_ubatch = false; + // sanity check + GGML_ASSERT(min_head + min_n <= cache.rs.size); + + // keep only the necessary range + cache.rs.freeable.resize(min_free + n_seqs); + cache.rs.freeable.erase(cache.rs.freeable.begin(), cache.rs.freeable.begin() + min_free); + GGML_ASSERT(cache.rs.freeable.size() == n_seqs); + GGML_ASSERT(min_n >= n_seqs); + cache.rs.freeable.resize(min_n); + + // expand the free list + // e.g. 2,4,5,8 -> 1,0,1,1,0,0,1 + for (uint32_t i = n_seqs; i-- > 0;) { + uint32_t dst = cache.rs.freeable[i] - min_head; + if (dst != i) { + cache.rs.freeable[i] = 0; + } + GGML_ASSERT(dst >= i); + cache.rs.freeable[dst] = 1; + } + + // coalesce the free cells together + // e.g. 1,0,1,1,0,0,1 -> 1,1,1,1,0,0,0 + // or 1,0,1,1,1,1 -> 1,1,1,1,1,0 + { + uint32_t top_free = min_n - 1; + for (uint32_t i = min_n; i-- > 1;) { + uint32_t is_free = cache.rs.freeable[i]; + if (!is_free) { + GGML_ASSERT(top_free > i); + cache.rs.swap_cells(min_head + i, min_head + top_free); + std::swap(cache.rs.freeable[i], cache.rs.freeable[top_free]); + // the previous one has to be free, + // otherwise it would already have been swapped. + top_free -= 1; + } + // stop early if all freeable cells have already been put at the beginning + if (top_free < n_seqs) { break; } + } + } + + // order the re-used cells identically to their batch order + // (and clear the non-reused cells) + { + for (uint32_t i = 0; i < n_seqs; ++i) { + // ignore the already-swapped cells + if (cache.rs.freeable[i]) { + llama_rs_cell & cell = cache.rs.cells[min_head + i]; + if (!cell.is_empty()) { + if (cell.tail_rc == 0) { + cache.rs.clear_cell(cell); + } else { + // TODO: does this always work correctly + // even if there are more than one seq_node in this cell? + + // Which seq_id of the batch is it? + llama_seq_id seq_id = cell.seq_nodes[0].seq_id; + int32_t nth_seq_id = -1; + for (int32_t s = 0; (uint32_t) s < n_seqs; ++s) { + if (seq_id == batch.seq_id[s][0]) { + nth_seq_id = s; + break; + } + } + GGML_ASSERT(nth_seq_id != -1); + + cache.rs.swap_cells(min_head + i, min_head + nth_seq_id); + cache.rs.freeable[i] = 0; + std::swap(cache.rs.freeable[i], cache.rs.freeable[nth_seq_id]); + i -= 1; // check this cell again, now that it was swapped + } + } + } + } + } + + // reserve + { + for (uint32_t i = 0; i < n_seqs; ++i) { + uint32_t i_cell = min_head + i; + int32_t n_seq_id = batch.n_seq_id[i]; + llama_pos end_pos = batch.pos[(i * n_seq_tokens) + n_seq_tokens - 1]; + // set the pos with the first seq_id + cache.rs.insert_seq_tail_to_cell_id(i_cell, batch.seq_id[i][0], end_pos); + // insert the rest of the seq_ids by re-using the cell's pos + for (int j = 1; j < n_seq_id; ++j) { + cache.rs.insert_seq_tail_to_cell_id(i_cell, batch.seq_id[i][j]); + } + } } // allow getting the range of used cells, from head to head + n - cache.rs.head = min_cell; - cache.rs.n = max_cell - min_cell + 1; - - // sanity check - GGML_ASSERT(min_seq <= max_seq && min_cell <= max_cell); + cache.rs.head = min_head; + cache.rs.n = min_n; } if (kv_size > 0) { - for (uint32_t i = 0; i < n_tokens; i++) { - cache.kv.cells[cache.kv.head + i].pos = batch.pos[i]; + for (uint32_t s = 0; s < n_seqs; s++) { + for (uint32_t i = 0; i < n_seq_tokens; ++i) { + uint32_t k = s*n_seq_tokens + i; + cache.kv.cells[cache.kv.head + k].pos = batch.pos[k]; - for (int32_t j = 0; j < batch.n_seq_id[i]; j++) { - cache.kv.cells[cache.kv.head + i].seq_id.insert(batch.seq_id[i][j]); + for (int32_t j = 0; j < batch.n_seq_id[s]; j++) { + cache.kv.cells[cache.kv.head + k].seq_id.insert(batch.seq_id[s][j]); + } } } @@ -8492,16 +8664,15 @@ static struct ggml_tensor * llm_build_mamba( struct ggml_context * ctx, const llama_model & model, const llama_hparams & hparams, + const llama_ubatch & batch, const llama_rs_cache & rs, struct ggml_cgraph * graph, struct ggml_tensor * cur, struct ggml_tensor * state_copy, struct ggml_tensor * state_mask, - struct ggml_tensor * state_seq, struct ggml_tensor * w_dt_norm, struct ggml_tensor * w_b_norm, struct ggml_tensor * w_c_norm, - int32_t n_tokens, int32_t rs_head, int32_t n_rs, const llm_build_cb & cb, @@ -8510,14 +8681,23 @@ static struct ggml_tensor * llm_build_mamba( const int64_t d_inner = hparams.ssm_d_inner; const int64_t d_state = hparams.ssm_d_state; const int64_t dt_rank = hparams.ssm_dt_rank; + const int64_t n_seqs = batch.n_seqs; - struct ggml_tensor * conv_states = ggml_reshape_2d(ctx, rs.r_l[il], hparams.n_embd_r(il), rs.size); - struct ggml_tensor * ssm_states = ggml_reshape_2d(ctx, rs.s_l[il], hparams.n_embd_s(il), rs.size); + const int64_t n_seq_tokens = batch.n_seq_tokens; + + GGML_ASSERT(n_seqs != 0); + GGML_ASSERT(batch.equal_seqs); + GGML_ASSERT(batch.n_tokens == n_seq_tokens * n_seqs); + + struct ggml_tensor * conv_states_all = rs.r_l[il]; + struct ggml_tensor * ssm_states_all = rs.s_l[il]; + + struct ggml_tensor * conv_states = ggml_reshape_2d(ctx, conv_states_all, hparams.n_embd_r(il), rs.size); + struct ggml_tensor * ssm_states = ggml_reshape_2d(ctx, ssm_states_all, hparams.n_embd_s(il), rs.size); // copy states { - // TODO: use some sort of read-only head and n to pass smaller tensors to ggml_get_rows - // NOTE: assuming the copy destinations are ALL contained in the current batch + // NOTE: assuming the copy destinations are ALL contained between rs_head and rs_head + n_rs // this shrinks the tensors's ne[1] to n_rs conv_states = ggml_get_rows(ctx, conv_states, state_copy); ssm_states = ggml_get_rows(ctx, ssm_states, state_copy); @@ -8532,17 +8712,24 @@ static struct ggml_tensor * llm_build_mamba( conv_states = ggml_reshape_3d(ctx, conv_states, d_conv - 1, d_inner, n_rs); ssm_states = ggml_reshape_3d(ctx, ssm_states, d_state, d_inner, n_rs); - // {n_embd, 2*d_inner} * {n_embd, n_tokens} => {2*d_inner, n_tokens} + struct ggml_tensor * conv = ggml_view_3d(ctx, conv_states, d_conv - 1, d_inner, n_seqs, conv_states->nb[1], conv_states->nb[2], 0); + struct ggml_tensor * ssm = ggml_view_3d(ctx, ssm_states, d_state, d_inner, n_seqs, ssm_states->nb[1], ssm_states->nb[2], 0); + + // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs} + cur = ggml_reshape_3d(ctx, cur, cur->ne[0], n_seq_tokens, n_seqs); + + // {n_embd, 2*d_inner} @ {n_embd, n_seq_tokens, n_seqs} => {2*d_inner, n_seq_tokens, n_seqs} struct ggml_tensor * xz = ggml_mul_mat(ctx, model.layers[il].ssm_in, cur); // split the above in two - // => {d_inner, n_tokens} - struct ggml_tensor * x = ggml_view_2d(ctx, xz, d_inner, xz->ne[1], xz->nb[1], 0); - struct ggml_tensor * z = ggml_view_2d(ctx, xz, d_inner, xz->ne[1], xz->nb[1], ggml_element_size(xz)*d_inner); + // => {d_inner, n_seq_tokens, n_seqs} + struct ggml_tensor * x = ggml_view_3d(ctx, xz, d_inner, xz->ne[1], xz->ne[2], xz->nb[1], xz->nb[2], 0); + struct ggml_tensor * z = ggml_view_3d(ctx, xz, d_inner, xz->ne[1], xz->ne[2], xz->nb[1], xz->nb[2], d_inner*ggml_element_size(xz)); // conv { - // Custom operator which is needed only to ease simultaneous sequence processing. - // For a single sequence, the equivalent is to concatenate the columns of conv_states and x, + // Custom operator, which is needed because self-overlapping views aren't yet well supported by ggml. + // And also because this uses much less memory for large batches (4 times less when d_conv is 4). + // The equivalent is to concatenate the columns of conv_states and x, // then make a self-overlapping view of that over d_conv columns at each stride in the 3rd dimension, // then element-wise multiply that with the conv1d weigth, // then sum the elements of each row, @@ -8551,17 +8738,17 @@ static struct ggml_tensor * llm_build_mamba( // and then you're left with the resulting x tensor. // The new conv_states is the last (d_conv - 1) columns // of the last 3rd dimensional "layer" of the self-overlapping view. - // For simultaneous sequences, it's more complicated. - struct ggml_tensor * x_conv = ggml_ssm_conv(ctx, conv_states, x, model.layers[il].ssm_conv1d, state_seq); + // For simultaneous sequences, all sequences need to have the same length. + x = ggml_ssm_conv(ctx, conv, x, model.layers[il].ssm_conv1d); + + // ensure conv is updated before copying into the recurrent state cache + ggml_build_forward_expand(graph, x); - // store last (d_conv - 1) columns of the conv_state part of x_conv back into the KV cache ggml_build_forward_expand(graph, - ggml_cpy(ctx, - ggml_view_2d(ctx, x_conv, d_conv - 1, d_inner*n_rs, d_conv*ggml_element_size(x_conv), (1+d_inner*n_tokens)*ggml_element_size(x_conv)), - ggml_view_1d(ctx, rs.r_l[il], (d_conv - 1)*(d_inner)*(n_rs), rs_head*(d_conv - 1)*(d_inner)*ggml_element_size(x_conv)))); - - // extract x from x_conv - x = ggml_view_2d(ctx, x_conv, d_inner, n_tokens, d_inner*ggml_element_size(x_conv), 0); + ggml_cpy(ctx, conv_states, + ggml_view_1d(ctx, conv_states_all, + (d_conv - 1)*(d_inner)*(n_rs), + rs_head*(d_conv - 1)*(d_inner)*ggml_element_size(conv_states_all)))); // bias x = ggml_add(ctx, x, model.layers[il].ssm_conv1d_b); @@ -8571,45 +8758,47 @@ static struct ggml_tensor * llm_build_mamba( // ssm { - // {d_inner, dt_rank + 2*d_state} * {d_inner, n_tokens} => {dt_rank + 2*d_state, n_tokens} + // {d_inner, dt_rank + 2*d_state} @ {d_inner, n_seq_tokens, n_seqs} => {dt_rank + 2*d_state, n_seq_tokens, n_seqs} struct ggml_tensor * x_db = ggml_mul_mat(ctx, model.layers[il].ssm_x, x); // split - struct ggml_tensor * dt = ggml_view_2d(ctx, x_db, dt_rank, n_tokens, x_db->nb[1], 0); - struct ggml_tensor * B = ggml_view_2d(ctx, x_db, d_state, n_tokens, x_db->nb[1], ggml_element_size(x_db)*dt_rank); - struct ggml_tensor * C = ggml_view_2d(ctx, x_db, d_state, n_tokens, x_db->nb[1], ggml_element_size(x_db)*(dt_rank+d_state)); + struct ggml_tensor * dt = ggml_view_3d(ctx, x_db, dt_rank, n_seq_tokens, n_seqs, x_db->nb[1], x_db->nb[2], 0); + struct ggml_tensor * B = ggml_view_3d(ctx, x_db, d_state, n_seq_tokens, n_seqs, x_db->nb[1], x_db->nb[2], ggml_element_size(x_db)*dt_rank); + struct ggml_tensor * C = ggml_view_3d(ctx, x_db, d_state, n_seq_tokens, n_seqs, x_db->nb[1], x_db->nb[2], ggml_element_size(x_db)*(dt_rank+d_state)); if (w_dt_norm) { dt = llm_build_norm(ctx, dt, hparams, w_dt_norm, NULL, LLM_NORM_RMS, cb, il); } if (w_b_norm) { B = llm_build_norm(ctx, B, hparams, w_b_norm, NULL, LLM_NORM_RMS, cb, il); } if (w_c_norm) { C = llm_build_norm(ctx, C, hparams, w_b_norm, NULL, LLM_NORM_RMS, cb, il); } - // {dt_rank, d_inner} * {dt_rank, n_tokens} => {d_inner, n_tokens} + // {dt_rank, d_inner} @ {dt_rank, n_seq_tokens, n_seqs} => {d_inner, n_seq_tokens, n_seqs} dt = ggml_mul_mat(ctx, model.layers[il].ssm_dt, dt); dt = ggml_add(ctx, dt, model.layers[il].ssm_dt_b); // Custom operator to optimize the parallel associative scan // as described in the Annex D of the Mamba paper. - // => {d_inner, n_tokens} and {d_state, d_inner, n_rs} combined, - // because only a single tensor can be returned. - struct ggml_tensor * y_ssm_states = ggml_ssm_scan(ctx, ssm_states, x, dt, model.layers[il].ssm_a, B, C, state_seq); + // => {d_inner, n_seq_tokens, n_seqs} + struct ggml_tensor * y = ggml_ssm_scan(ctx, ssm, x, dt, model.layers[il].ssm_a, B, C); - // store last states (the second part of y_ssm_states) + // The ssm scan also changes the state, ensure it's done before copying to the recurrent state cache + ggml_build_forward_expand(graph, y); + + // store last states ggml_build_forward_expand(graph, - ggml_cpy(ctx, - ggml_view_1d(ctx, y_ssm_states, d_state*d_inner*n_rs, d_inner*n_tokens*ggml_element_size(y_ssm_states)), - ggml_view_1d(ctx, rs.s_l[il], d_state*d_inner*n_rs, rs_head*d_state*d_inner*ggml_element_size(ssm_states)))); + ggml_cpy(ctx, ssm_states, + ggml_view_1d(ctx, ssm_states_all, d_state*d_inner*n_rs, rs_head*d_state*d_inner*ggml_element_size(ssm_states_all)))); - struct ggml_tensor * y = ggml_view_2d(ctx, y_ssm_states, d_inner, n_tokens, d_inner*ggml_element_size(y_ssm_states), 0); + // TODO: skip computing output earlier for unused tokens - // TODO: skip computing output for unused tokens - - // {d_inner, n_tokens} * {d_inner} => {d_inner, n_tokens} + // {d_inner, n_seq_tokens, n_seqs} * {d_inner} => {d_inner, n_seq_tokens, n_seqs} y = ggml_add(ctx, y, ggml_mul(ctx, x, model.layers[il].ssm_d)); y = ggml_mul(ctx, y, ggml_silu(ctx, z)); - // {d_inner, n_embd} * {d_inner, n_tokens} => {n_embd, n_tokens} + // {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs} cur = ggml_mul_mat(ctx, model.layers[il].ssm_out, y); } + // {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens} + cur = ggml_reshape_2d(ctx, cur, cur->ne[0], n_seq_tokens * n_seqs); + return cur; } @@ -8642,6 +8831,8 @@ struct llm_build_context { const float norm_eps; const float norm_rms_eps; + const int32_t n_seqs; + const int32_t n_seq_tokens; const int32_t n_tokens; const int32_t n_kv; // size of KV cache to consider (n_kv <= kv_self.size) const int32_t n_rs; @@ -8692,6 +8883,8 @@ struct llm_build_context { beta_slow (cparams.yarn_beta_slow), norm_eps (hparams.f_norm_eps), norm_rms_eps (hparams.f_norm_rms_eps), + n_seqs (batch.n_seqs), + n_seq_tokens (batch.n_seq_tokens), n_tokens (batch.n_tokens), n_kv (worst_case ? kv_self.size : kv_self.n), n_rs (worst_case ? rs_self.size : rs_self.n), @@ -8726,7 +8919,6 @@ struct llm_build_context { lctx.inp_cls = nullptr; lctx.inp_s_copy = nullptr; lctx.inp_s_mask = nullptr; - lctx.inp_s_seq = nullptr; } void free() { @@ -8898,13 +9090,6 @@ struct llm_build_context { return lctx.inp_s_mask; } - struct ggml_tensor * build_inp_s_seq() { - lctx.inp_s_seq = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); - cb(lctx.inp_s_seq, "inp_s_seq", -1); - ggml_set_input(lctx.inp_s_seq); - return lctx.inp_s_seq; - } - struct ggml_cgraph * build_llama() { struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); @@ -12017,7 +12202,6 @@ struct llm_build_context { struct ggml_tensor * state_copy = build_inp_s_copy(); struct ggml_tensor * state_mask = build_inp_s_mask(); - struct ggml_tensor * state_seq = build_inp_s_seq(); for (int il = 0; il < n_layer; ++il) { // norm @@ -12026,9 +12210,9 @@ struct llm_build_context { LLM_NORM_RMS, cb, il); cb(cur, "attn_norm", il); - cur = llm_build_mamba(ctx0, model, hparams, rs_self, gf, cur, - state_copy, state_mask, state_seq, NULL, NULL, NULL, - n_tokens, rs_head, n_rs, cb, il); + cur = llm_build_mamba(ctx0, model, hparams, batch, rs_self, gf, cur, + state_copy, state_mask, NULL, NULL, NULL, + rs_head, n_rs, cb, il); if (il == n_layer - 1) { // skip computing output for unused tokens @@ -12074,7 +12258,6 @@ struct llm_build_context { struct ggml_tensor * state_copy = build_inp_s_copy(); struct ggml_tensor * state_mask = build_inp_s_mask(); - struct ggml_tensor * state_seq = build_inp_s_seq(); // KQ_mask (mask for 1 head, it will be broadcasted to all heads) struct ggml_tensor * KQ_mask = build_inp_KQ_mask(); @@ -12089,10 +12272,9 @@ struct llm_build_context { if (n_head_kv == 0) { // Mamba - cur = llm_build_mamba(ctx0, model, hparams, rs_self, gf, cur, - state_copy, state_mask, state_seq, + cur = llm_build_mamba(ctx0, model, hparams, batch, rs_self, gf, cur, state_copy, state_mask, model.layers[il].ssm_dt_norm, model.layers[il].ssm_b_norm, model.layers[il].ssm_c_norm, - n_tokens, rs_head, n_rs, cb, il); + rs_head, n_rs, cb, il); } else { // Attention @@ -12152,6 +12334,7 @@ struct llm_build_context { model.layers[il].ffn_down_exps, n_expert, n_expert_used, LLM_FFN_SILU, false, + false, 0.0, cb, il); cb(cur, "ffn_moe_out", il); } @@ -13234,8 +13417,10 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) { if (lctx.inp_KQ_mask) { // NOTE: hparams.causal_attn indicates the model is capable of generation and uses the kv cache. if (cparams.causal_attn) { - const int64_t n_kv = kv_self.n; - const int64_t n_tokens = batch.n_tokens; + const int64_t n_kv = kv_self.n; + const int64_t n_tokens = batch.n_tokens; + const int64_t n_seq_tokens = batch.n_seq_tokens; + const int64_t n_seqs = batch.n_seqs; GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer)); @@ -13245,22 +13430,25 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) { // of the correct sequence for each token of the batch. // It's assumed that if a token in the batch has multiple sequences, they are equivalent. for (int h = 0; h < 1; ++h) { - for (int j = 0; j < n_tokens; ++j) { - const llama_pos pos = batch.pos[j]; - const llama_seq_id seq_id = batch.seq_id[j][0]; + for (int s = 0; s < n_seqs; ++s) { + const llama_seq_id seq_id = batch.seq_id[s][0]; - for (int i = 0; i < n_kv; ++i) { - float f; - if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) { - f = -INFINITY; - } else { - if (hparams.use_alibi) { - f = -fabs(kv_self.cells[i].pos - pos); + for (int j = 0; j < n_seq_tokens; ++j) { + const llama_pos pos = batch.pos[s*n_seq_tokens + j]; + + for (int i = 0; i < n_kv; ++i) { + float f; + if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) { + f = -INFINITY; } else { - f = 0.0f; + if (hparams.use_alibi) { + f = -fabs(kv_self.cells[i].pos - pos); + } else { + f = 0.0f; + } } + data[h*(n_kv*n_seq_tokens*n_seqs) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f; } - data[h*(n_kv*n_tokens) + j*n_kv + i] = f; } } @@ -13271,8 +13459,10 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) { } } } else { + const int64_t n_tokens = batch.n_tokens; + const int64_t n_seq_tokens = batch.n_seq_tokens; + const int64_t n_seqs = batch.n_seqs; // when using kv cache, the mask needs to match the kv cache size - const int64_t n_tokens = batch.n_tokens; const int64_t n_stride = hparams.causal_attn ? kv_self.n : n_tokens; GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer)); @@ -13280,27 +13470,35 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) { float * data = (float *) lctx.inp_KQ_mask->data; for (int h = 0; h < 1; ++h) { - for (int j = 0; j < n_tokens; ++j) { - const llama_seq_id seq_id = batch.seq_id[j][0]; + for (int s1 = 0; s1 < n_seqs; ++s1) { + const llama_seq_id seq_id = batch.seq_id[s1][0]; - for (int i = 0; i < n_tokens; ++i) { - float f = -INFINITY; - for (int s = 0; s < batch.n_seq_id[i]; ++s) { - if (batch.seq_id[i][s] == seq_id) { - if (hparams.use_alibi) { - f = -fabs(batch.pos[i] - batch.pos[j]); - } else { - f = 0.0f; + for (int j = 0; j < n_seq_tokens; ++j) { + const int32_t tj = s1*n_seq_tokens + j; + + for (int s0 = 0; s0 < n_seqs; ++s0) { + for (int i = 0; i < n_seq_tokens; ++i) { + const int32_t ti = s0*n_seq_tokens + i; + float f = -INFINITY; + + for (int s = 0; s < batch.n_seq_id[s0]; ++s) { + if (batch.seq_id[s0][s] == seq_id) { + if (hparams.use_alibi) { + f = -fabs(batch.pos[ti] - batch.pos[tj]); + } else { + f = 0.0f; + } + break; + } } - break; + + data[h*(n_tokens*n_tokens) + tj*n_stride + ti] = f; } } - data[h*(n_tokens*n_tokens) + j*n_stride + i] = f; - } - - for (int i = n_tokens; i < n_stride; ++i) { - data[h*(n_tokens*n_tokens) + j*n_stride + i] = -INFINITY; + for (int i = n_tokens; i < n_stride; ++i) { + data[h*(n_tokens*n_tokens) + tj*n_stride + i] = -INFINITY; + } } } } @@ -13308,7 +13506,9 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) { } if (cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) { - const int64_t n_tokens = batch.n_tokens; + const int64_t n_tokens = batch.n_tokens; + const int64_t n_seq_tokens = batch.n_seq_tokens; + const int64_t n_seqs = batch.n_seqs; GGML_ASSERT(lctx.inp_mean); GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_mean->buffer)); @@ -13317,12 +13517,14 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) { memset(lctx.inp_mean->data, 0, n_tokens * n_tokens * ggml_element_size(lctx.inp_mean)); std::vector sum(n_tokens, 0); - for (int i = 0; i < n_tokens; ++i) { - const llama_seq_id seq_id = batch.seq_id[i][0]; + for (int s = 0; s < n_seqs; ++s) { + const llama_seq_id seq_id = batch.seq_id[s][0]; + + // TODO: adapt limits to n_seqs when batch.equal_seqs is true GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == MEAN"); - sum[seq_id] += 1; + sum[seq_id] += batch.n_seq_tokens; } std::vector div(n_tokens, 0.0f); @@ -13333,14 +13535,19 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) { } } - for (int i = 0; i < n_tokens; ++i) { - const llama_seq_id seq_id = batch.seq_id[i][0]; - data[seq_id*n_tokens + i] = div[seq_id]; + for (int s = 0; s < n_seqs; ++s) { + const llama_seq_id seq_id = batch.seq_id[s][0]; + + for (int i = 0; i < n_seq_tokens; ++i) { + data[seq_id*n_tokens + s*n_seq_tokens + i] = div[seq_id]; + } } } if (cparams.pooling_type == LLAMA_POOLING_TYPE_CLS) { - const int64_t n_tokens = batch.n_tokens; + const int64_t n_tokens = batch.n_tokens; + const int64_t n_seq_tokens = batch.n_seq_tokens; + const int64_t n_seqs = batch.n_seqs; GGML_ASSERT(lctx.inp_cls); GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_cls->buffer)); @@ -13348,14 +13555,18 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) { uint32_t * data = (uint32_t *) lctx.inp_cls->data; memset(lctx.inp_cls->data, 0, n_tokens * ggml_element_size(lctx.inp_cls)); - for (int i = 0; i < n_tokens; ++i) { - const llama_seq_id seq_id = batch.seq_id[i][0]; - const llama_pos pos = batch.pos[i]; + for (int s = 0; s < n_seqs; ++s) { + const llama_seq_id seq_id = batch.seq_id[s][0]; + // TODO: adapt limits to n_seqs when batch.equal_seqs is true GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == CLS"); - if (pos == 0) { - data[seq_id] = i; + for (int i = 0; i < n_seq_tokens; ++i) { + const llama_pos pos = batch.pos[s*n_seq_tokens + i]; + + if (pos == 0) { + data[seq_id] = s*n_seq_tokens + i; + } } } } @@ -13372,7 +13583,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) { uint32_t cell_id = i + rs_self.head; llama_rs_cell & rs_cell = lctx.cache.rs.cells[cell_id]; - data[i] = (float) rs_cell.src >= 0; + data[i] = (float) (rs_cell.src >= 0); // only clear once if (rs_cell.src < 0) { @@ -13404,29 +13615,6 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) { } } } - - // For Mamba (and other recurrent architectures), - // update the correct state(s)/sequence(s) for each token of the batch. - // Each row contains relative cell ids of the sequences for the associated token. - // Like with the KQ_mask, if a token in the batch has multiple sequences, - // they are assumed to be equivalent (not here, but in ggml_ssm_scan and ggml_ssm_conv). - if (lctx.inp_s_seq) { - const int64_t n_tokens = batch.n_tokens; - - GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_s_seq->buffer)); - int32_t * data = (int32_t *) lctx.inp_s_seq->data; - - for (int i = 0; i < n_tokens; ++i) { - const llama_seq_id seq_id = batch.seq_id[i][0]; - GGML_ASSERT((uint32_t) seq_id < rs_self.seq_tails.size()); - const auto & seq = rs_self.seq_tails[seq_id]; - // ensure the relative cell id will be positive but not too big - GGML_ASSERT((uint32_t) seq.tail >= rs_self.head); - GGML_ASSERT((uint32_t) seq.tail < rs_self.head + rs_self.n); - - data[i] = seq.tail - rs_self.head; - } - } } } @@ -13598,7 +13786,7 @@ static int llama_decode_internal( } else { GGML_ASSERT(u_batch.output); for (uint32_t i = 0; i < n_tokens; i++) { - n_outputs_new += u_batch.output[i] != 0; + n_outputs_new += (int32_t) (u_batch.output[i] != 0); } } @@ -14077,10 +14265,10 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) { if (need_reserve) { // TODO: extract to a function // build worst-case graph - int n_seqs = 1; // TODO: worst-case number of sequences - int n_tokens = (int)std::min(lctx.cparams.n_ctx, lctx.cparams.n_ubatch); + uint32_t n_seqs = 1; // TODO: worst-case number of sequences + uint32_t n_tokens = std::min(lctx.cparams.n_ctx, lctx.cparams.n_ubatch); llama_token token = llama_token_bos(&lctx.model); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph - llama_ubatch ubatch = { true, n_tokens, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr}; + llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr}; ggml_cgraph * gf = llama_build_graph(lctx, ubatch, true); // initialize scheduler with the worst-case graph @@ -18026,10 +18214,10 @@ struct llama_context * llama_new_context_with_model( } // build worst-case graph - int n_seqs = 1; // TODO: worst-case number of sequences - int n_tokens = (int)std::min(cparams.n_ctx, cparams.n_ubatch); + uint32_t n_seqs = 1; // TODO: worst-case number of sequences + uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch); llama_token token = llama_token_bos(&ctx->model); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph - llama_ubatch ubatch = { true, n_tokens, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr}; + llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr}; ggml_cgraph * gf = llama_build_graph(*ctx, ubatch, true); // initialize scheduler with the worst-case graph @@ -19347,6 +19535,7 @@ size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src, if (cell_count) { llama_ubatch batch = ctx->sbatch.reserve_ubatch(cell_count, /* has_embd */ false); batch.n_tokens = cell_count; + batch.n_seq_tokens = cell_count; batch.n_seqs = 1; for (uint32_t i = 0; i < cell_count; ++i) { llama_pos pos; @@ -19354,9 +19543,9 @@ size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src, inp += sizeof(pos); batch.pos[i] = pos; - batch.n_seq_id[i] = 1; - batch.seq_id[i][0] = dest_seq_id; } + batch.n_seq_id[0] = 1; + batch.seq_id[0] = &dest_seq_id; if (!llama_cache_find_slot(cache, batch)) { LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__); return 0; @@ -19680,9 +19869,54 @@ void llama_synchronize(struct llama_context * ctx) { ctx->t_compute_start_us = 0; } +// make the outputs have the same order they had in the user-provided batch +static void llama_reorder_outputs(struct llama_context * ctx) { + std::vector & out_ids = ctx->sbatch.out_ids; + if (!out_ids.empty()) { + std::vector logits_tmp; + std::vector embd_tmp; + uint32_t n_vocab = ctx->model.hparams.n_vocab; + uint32_t n_embd = ctx->model.hparams.n_embd; + int32_t n_outputs = ctx->n_outputs; + GGML_ASSERT((size_t) n_outputs == out_ids.size()); + // insertion sort (from https://en.wikipedia.org/wiki/Insertion_sort, but using memmove) + for (int32_t i = 1; i < n_outputs; ++i) { + int32_t j = i; + size_t out_id_tmp = out_ids[i]; + while (j > 0 && out_ids[j - 1] > out_id_tmp) { j -= 1; } + if (i - j == 0) { continue; } + memmove(out_ids.data() + j + 1, out_ids.data() + j, (i - j)*sizeof(out_ids[0])); + out_ids[j] = out_id_tmp; + if (ctx->logits_size > 0) { + // only allocate once something needs to be moved + if (logits_tmp.empty()) { logits_tmp.resize(n_vocab); } + memcpy(logits_tmp.data(), ctx->logits + i*n_vocab, n_vocab*sizeof(float)); + memmove(ctx->logits + (j + 1)*n_vocab, ctx->logits + j*n_vocab, (i - j)*n_vocab*sizeof(float)); + memcpy(ctx->logits + j*n_vocab, logits_tmp.data(), n_vocab*sizeof(float)); + } + if (ctx->embd_size > 0) { + // only allocate once something needs to be moved + if (embd_tmp.empty()) { embd_tmp.resize(n_embd); } + memcpy(embd_tmp.data(), ctx->embd + i*n_embd, n_embd*sizeof(float)); + memmove(ctx->embd + (j + 1)*n_embd, ctx->embd + j*n_embd, (i - j)*n_embd*sizeof(float)); + memcpy(ctx->embd + j*n_embd, embd_tmp.data(), n_embd*sizeof(float)); + } + } + std::fill(ctx->output_ids.begin(), ctx->output_ids.end(), -1); + for (int32_t i = 0; i < n_outputs; ++i) { + ctx->output_ids[out_ids[i]] = i; + } + out_ids.clear(); + } +} + float * llama_get_logits(struct llama_context * ctx) { llama_synchronize(ctx); + // reorder logits for backward compatibility + // TODO: maybe deprecate this + llama_reorder_outputs(ctx); + return ctx->logits; } @@ -19727,6 +19961,10 @@ float * llama_get_logits_ith(struct llama_context * ctx, int32_t i) { float * llama_get_embeddings(struct llama_context * ctx) { llama_synchronize(ctx); + // reorder embeddings for backward compatibility + // TODO: maybe deprecate this + llama_reorder_outputs(ctx); + return ctx->embd; }