Mirror of https://github.com/ggerganov/llama.cpp.git, synced 2024-12-26 03:14:35 +00:00
metal : pad n_ctx by 32 (#6177)
* metal : require ne00 >= 128 for mat-mat kernels

ggml-ci

* llama : pad n_ctx by 32

ggml-ci
parent 59c17f02de
commit 95d576b48e
@@ -48,6 +48,8 @@ int main(int argc, char ** argv) {
         params.prompt = "Hello my name is";
     }
 
+    process_escapes(params.prompt);
+
     // init LLM
 
     llama_backend_init();
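The new call runs the prompt through process_escapes, which turns escape sequences typed on the command line (e.g. \n) into the actual characters before the prompt reaches tokenization. Below is a minimal sketch of the idea, assuming only a few common sequences; the real helper in common/common.cpp handles more and rewrites the string in place.

#include <cstdio>
#include <string>

// Minimal sketch of escape processing, assuming only \n, \t and \\;
// the real process_escapes covers additional sequences.
static std::string process_escapes_sketch(const std::string & in) {
    std::string out;
    out.reserve(in.size());
    for (size_t i = 0; i < in.size(); ++i) {
        if (in[i] == '\\' && i + 1 < in.size()) {
            switch (in[++i]) {
                case 'n':  out += '\n'; break;
                case 't':  out += '\t'; break;
                case '\\': out += '\\'; break;
                default:   out += '\\'; out += in[i]; break; // unknown: keep verbatim
            }
        } else {
            out += in[i];
        }
    }
    return out;
}

int main() {
    // "Hello\nworld" typed on a shell command line arrives as a literal
    // backslash followed by 'n'; the helper turns it into a real newline.
    std::printf("%s\n", process_escapes_sketch("Hello\\nworld").c_str());
    return 0;
}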
@@ -13044,6 +13044,9 @@ struct llama_context * llama_new_context_with_model(
     cparams.rope_freq_base  = params.rope_freq_base  == 0.0f ? hparams.rope_freq_base_train  : params.rope_freq_base;
     cparams.rope_freq_scale = params.rope_freq_scale == 0.0f ? hparams.rope_freq_scale_train : params.rope_freq_scale;
 
+    // this is necessary due to kv_self.n being padded later during inference
+    cparams.n_ctx = GGML_PAD(cparams.n_ctx, 32);
+
     // with causal attention, the batch size is limited by the context size
     cparams.n_batch = hparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch;
     cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch);
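The padding rounds the context size up to the next multiple of 32, since kv_self.n gets padded the same way during inference. A small self-contained check of the rounding behavior, assuming the usual power-of-two bit-trick definition of GGML_PAD from ggml.h:

#include <cassert>

// Round x up to the next multiple of n (n must be a power of two);
// assumed to match the GGML_PAD definition in ggml.h.
#define GGML_PAD(x, n) (((x) + (n) - 1) & ~((n) - 1))

int main() {
    assert(GGML_PAD(   1, 32) ==   32); // tiny contexts are bumped up to 32
    assert(GGML_PAD( 100, 32) ==  128); // 100 is not a multiple of 32 -> 128
    assert(GGML_PAD(4096, 32) == 4096); // typical n_ctx values are already aligned
    return 0;
}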
@@ -2091,6 +2091,13 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
         }
     }
 
+    test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32,  64,  2, 128, { 8, 1}, {1, 1}));
+    test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32,  83,  2, 128, { 8, 1}, {4, 1}));
+    test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32,  64,  2,  64, { 8, 1}, {4, 1}));
+    test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32,  83,  2,  64, { 8, 1}, {4, 1}));
+    test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32,  64, 45, 128, { 8, 1}, {4, 1}));
+    test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 128, 45,  64, { 8, 1}, {4, 1}));
+
     for (ggml_type type_a : all_types) {
         for (ggml_type type_b : {GGML_TYPE_F32 /*, GGML_TYPE_F16 */}) {
             for (int n_mats : {2, 4, 8}) {
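The added cases cover inner dimensions of 64 and 128, i.e. both sides of the new mat-mat kernel requirement, with both aligned (64, 128) and unaligned (83) row counts. The predicate below sketches the dispatch rule the commit message implies; it is illustrative only, not the actual ggml-metal dispatch code:

#include <cstdio>

// Illustrative sketch: models the rule from the commit message that the
// Metal matrix-matrix kernels now require ne00 (the shared inner dimension
// k) >= 128, with smaller k falling back to the matrix-vector kernels.
static bool can_use_mat_mat_kernel(long ne00) {
    return ne00 >= 128;
}

int main() {
    // the k values exercised by the new test cases above
    for (long k : {64L, 128L}) {
        std::printf("k = %3ld -> %s path\n", k,
                    can_use_mat_mat_kernel(k) ? "mat-mat" : "mat-vec");
    }
    return 0;
}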