mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-11-15 15:29:53 +00:00
llama : fix Mamba-2 conv state saving
* ggml : make the ggml_mul fast broadcast path more consistently formatted
This commit is contained in:
parent
2bfe9de6d3
commit
aff96920f9
@@ -10226,7 +10226,7 @@ static void ggml_compute_forward_mul_f32(
|
|||||||
if (scale == 0.0f) {
|
if (scale == 0.0f) {
|
||||||
// NOTE: this also sets NANs to zero, which is not compliant with IEEE754,
|
// NOTE: this also sets NANs to zero, which is not compliant with IEEE754,
|
||||||
// but it is useful when resetting the state of recurrent models.
|
// but it is useful when resetting the state of recurrent models.
|
||||||
memset((char *)dst->data + ir*nb1, 0, nb1);
|
memset((char *) dst->data + ir*nb1, 0, ne0 * sizeof(float));
|
||||||
} else {
|
} else {
|
||||||
if (dst->data != src0->data) {
|
if (dst->data != src0->data) {
|
||||||
// src0 is same shape as dst => same indices
|
// src0 is same shape as dst => same indices
|
||||||
|
@@ -9335,7 +9335,7 @@ static struct ggml_tensor * llm_build_mamba2(
|
|||||||
ggml_cpy(ctx, last_conv,
|
ggml_cpy(ctx, last_conv,
|
||||||
ggml_view_1d(ctx, conv_states_all,
|
ggml_view_1d(ctx, conv_states_all,
|
||||||
(d_conv - 1)*(d_inner + 2*n_group*d_state)*(n_seqs),
|
(d_conv - 1)*(d_inner + 2*n_group*d_state)*(n_seqs),
|
||||||
kv_head*(d_conv - 1)*(d_inner)*ggml_element_size(conv_states_all))));
|
kv_head*(d_conv - 1)*(d_inner + 2*n_group*d_state)*ggml_element_size(conv_states_all))));
|
||||||
|
|
||||||
// 1D convolution
|
// 1D convolution
|
||||||
// The equivalent is to make a self-overlapping view of conv_x
|
// The equivalent is to make a self-overlapping view of conv_x
|
||||||
|
Loading…
Reference in New Issue
Block a user