llama : fix buffer checks for mamba and rwk (#10111)

* llama : fix buffer checks for mamba and rwk

* llama : fix missing worst case flag during reserve

* cuda : fix supports_op for norm

* disable sched SET_CAUSE
This commit is contained in:
Diego Devesa 2024-10-31 22:54:23 +01:00 committed by GitHub
parent ab3d71f97f
commit c02e5ab2a6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 35 additions and 12 deletions

View File

@ -1508,7 +1508,7 @@ static int ggml_backend_sched_backend_from_buffer(ggml_backend_sched_t sched, co
return -1; return -1;
} }
#if 1 #if 0
#define GGML_SCHED_MAX_SPLITS_DEBUG 4096 #define GGML_SCHED_MAX_SPLITS_DEBUG 4096
static char causes[GGML_DEFAULT_GRAPH_SIZE*16 + GGML_SCHED_MAX_SPLITS_DEBUG*GGML_SCHED_MAX_SPLIT_INPUTS][128]; // debug only static char causes[GGML_DEFAULT_GRAPH_SIZE*16 + GGML_SCHED_MAX_SPLITS_DEBUG*GGML_SCHED_MAX_SPLIT_INPUTS][128]; // debug only
#define SET_CAUSE(node, ...) sprintf(causes[hash_id(node)], __VA_ARGS__) #define SET_CAUSE(node, ...) sprintf(causes[hash_id(node)], __VA_ARGS__)

View File

@ -3107,18 +3107,20 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
} }
return false; return false;
} break; } break;
case GGML_OP_NORM:
case GGML_OP_RMS_NORM:
return ggml_is_contiguous(op->src[0]) && op->ne[0] % WARP_SIZE == 0;
break;
case GGML_OP_NONE: case GGML_OP_NONE:
case GGML_OP_RESHAPE: case GGML_OP_RESHAPE:
case GGML_OP_VIEW: case GGML_OP_VIEW:
case GGML_OP_PERMUTE: case GGML_OP_PERMUTE:
case GGML_OP_TRANSPOSE: case GGML_OP_TRANSPOSE:
case GGML_OP_NORM:
case GGML_OP_ADD: case GGML_OP_ADD:
case GGML_OP_ADD1: case GGML_OP_ADD1:
case GGML_OP_SUB: case GGML_OP_SUB:
case GGML_OP_MUL: case GGML_OP_MUL:
case GGML_OP_DIV: case GGML_OP_DIV:
case GGML_OP_RMS_NORM:
case GGML_OP_SCALE: case GGML_OP_SCALE:
case GGML_OP_SQR: case GGML_OP_SQR:
case GGML_OP_SQRT: case GGML_OP_SQRT:

View File

@ -7272,6 +7272,7 @@ struct ggml_tensor * ggml_ssm_conv(
const int64_t n_s = sx->ne[2]; const int64_t n_s = sx->ne[2];
// TODO: maybe support other strides than 1? // TODO: maybe support other strides than 1?
// FIXME: this is always true?
GGML_ASSERT(sx->ne[0] == d_conv - 1 + n_t); GGML_ASSERT(sx->ne[0] == d_conv - 1 + n_t);
GGML_ASSERT(sx->ne[1] == d_inner); GGML_ASSERT(sx->ne[1] == d_inner);
GGML_ASSERT(n_t >= 0); GGML_ASSERT(n_t >= 0);

View File

@ -7127,7 +7127,7 @@ static bool weight_buft_supported(const llama_hparams & hparams, ggml_tensor * w
} break; } break;
case GGML_OP_MUL_MAT: case GGML_OP_MUL_MAT:
{ {
ggml_tensor * b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, w->ne[0], 512); ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, w->ne[0], 512, w->ne[2], w->ne[3]);
op_tensor = ggml_mul_mat(ctx, w, b); op_tensor = ggml_mul_mat(ctx, w, b);
} break; } break;
case GGML_OP_MUL_MAT_ID: case GGML_OP_MUL_MAT_ID:
@ -7167,18 +7167,38 @@ static bool weight_buft_supported(const llama_hparams & hparams, ggml_tensor * w
} break; } break;
case GGML_OP_SSM_CONV: case GGML_OP_SSM_CONV:
{ {
// TODO: ggml_ssm_conv(ctx, conv_x, model.layers[il].ssm_conv1d); // FIXME
op_tensor = ggml_ssm_conv(ctx, nullptr, w); ggml_tensor * conv_x = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 12345, w->ne[1], 6789);
op_tensor = ggml_ssm_conv(ctx, conv_x, w);
} break; } break;
case GGML_OP_SSM_SCAN: case GGML_OP_SSM_SCAN:
{ {
// TODO: ggml_ssm_scan(ctx, ssm, x, dt, model.layers[il].ssm_a, B, C); // FIXME
op_tensor = ggml_ssm_scan(ctx, nullptr, nullptr, nullptr, w, nullptr, nullptr); const int64_t d_state = w->ne[0];
const int64_t d_inner = w->ne[1];
const int64_t n_seq_tokens = 512;
const int64_t n_seqs = 1;
ggml_tensor * s = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_state, d_inner, n_seqs);
ggml_tensor * x = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_inner, n_seq_tokens, n_seqs);
ggml_tensor * dt = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_inner, n_seq_tokens, n_seqs);
ggml_tensor * B = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_state, n_seq_tokens, n_seqs);
ggml_tensor * C = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_state, n_seq_tokens, n_seqs);
op_tensor = ggml_ssm_scan(ctx, s, x, dt, w, B, C);
} break; } break;
case GGML_OP_RWKV_WKV: case GGML_OP_RWKV_WKV:
{ {
// TODO: ggml_rwkv_wkv(ctx, k, v, r, layer->time_mix_first, w, *wkv_state); // FIXME
op_tensor = ggml_rwkv_wkv(ctx, nullptr, nullptr, nullptr, w, nullptr, nullptr); const int64_t S = 123;
const int64_t H = 123;
const int64_t n_tokens = 123;
const int64_t n_seqs = 123;
ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, S, 1, H, n_tokens);
ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, S, H, n_tokens);
ggml_tensor * r = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, S, H, n_tokens);
ggml_tensor * tf = w;
ggml_tensor * td = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, S, H, n_tokens);
ggml_tensor * state = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, S, n_seqs, S, H);
op_tensor = ggml_rwkv_wkv(ctx, k, v, r, tf, td, state);
} break; } break;
default: default:
GGML_ABORT("%s: missing test for op %s for tensor %s", __func__, ggml_op_name(op), w->name); GGML_ABORT("%s: missing test for op %s for tensor %s", __func__, ggml_op_name(op), w->name);
@ -7453,7 +7473,7 @@ static bool llm_load_tensors(
// tensors with "bias" suffix are always used with GGML_OP_ADD // tensors with "bias" suffix are always used with GGML_OP_ADD
ggml_op op; ggml_op op;
bool bias = strcmp(tn.suffix, "bias") == 0; bool bias = tn.suffix != nullptr && strcmp(tn.suffix, "bias") == 0;
if (bias) { if (bias) {
op = GGML_OP_ADD; op = GGML_OP_ADD;
} else { } else {
@ -19681,7 +19701,7 @@ struct llama_context * llama_new_context_with_model(
int n_nodes_tg = ggml_graph_n_nodes(gf_tg); int n_nodes_tg = ggml_graph_n_nodes(gf_tg);
// reserve again with pp graph to avoid ggml-alloc reallocations during inference // reserve again with pp graph to avoid ggml-alloc reallocations during inference
gf_pp = llama_build_graph(*ctx, ubatch_pp, false); gf_pp = llama_build_graph(*ctx, ubatch_pp, true);
if (!ggml_backend_sched_reserve(ctx->sched, gf_pp)) { if (!ggml_backend_sched_reserve(ctx->sched, gf_pp)) {
LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__); LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
llama_free(ctx); llama_free(ctx);