metal : first working version of the inference without prompt processing

Bonus: supports partial inference on the CPU
This commit is contained in:
Georgi Gerganov 2023-07-20 14:56:29 +03:00
parent 290cb700bf
commit cb82adadb8
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
2 changed files with 11 additions and 6 deletions

View File

@ -237,7 +237,7 @@ static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_metal_context * ctx, stru
}
}
fprintf(stderr, "%s: error: buffer is nil\n", __func__);
fprintf(stderr, "%s: error: buffer is nil for tensor '%s'\n", __func__, t->name);
return nil;
}
@ -877,15 +877,15 @@ void ggml_metal_graph_compute(
encoder = [command_buffer computeCommandEncoder];
}
const int n_dims = ((int32_t *) src1->data)[1];
const int mode = ((int32_t *) src1->data)[2];
const int n_dims = ((int32_t *) dst->op_params)[1];
const int mode = ((int32_t *) dst->op_params)[2];
const int n_past = ((int32_t *)(src1->data))[0];
const int n_past = ((int32_t *)(dst->op_params))[0];
float freq_base;
float freq_scale;
memcpy(&freq_base, (int32_t *) src1->data + 4, sizeof(float));
memcpy(&freq_scale, (int32_t *) src1->data + 5, sizeof(float));
memcpy(&freq_base, (int32_t *) dst->op_params + 4, sizeof(float));
memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
[encoder setComputePipelineState:ctx->pipeline_rope];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];

View File

@ -2845,10 +2845,15 @@ struct llama_context * llama_new_context_with_model(
struct ggml_backend_buffer * buf_compute = ctx->buf_compute_metal->backend_buffer;
struct ggml_backend_buffer * buf_kv = ctx->kv_self.buf->backend_buffer;
struct ggml_backend_buffer * buf_input = ctx->buf_input->backend_buffer;
struct ggml_backend_buffer * buf_output = ctx->buf_output->backend_buffer;
LLAMA_METAL_CHECK_BUF(ggml_backend_metal_map_buffer(ctx->model.backend_metal, "eval", buf_compute->backend_data, buf_compute->backend_size, 0));
LLAMA_METAL_CHECK_BUF(ggml_backend_metal_map_buffer(ctx->model.backend_metal, "kv", buf_kv->backend_data, buf_kv->backend_size, 0));
LLAMA_METAL_CHECK_BUF(ggml_backend_metal_map_buffer(ctx->model.backend_metal, "inp", buf_input->backend_data, buf_input->backend_size, 0));
LLAMA_METAL_CHECK_BUF(ggml_backend_metal_map_buffer(ctx->model.backend_metal, "inp", buf_output->backend_data, buf_output->backend_size, 0));
//LLAMA_METAL_CHECK_BUF(ggml_backend_metal_map_buffer(ctx->model.backend_metal, "scr0", ctx->buf_scratch[0].addr, ctx->buf_scratch[0].size, 0));
//LLAMA_METAL_CHECK_BUF(ggml_backend_metal_map_buffer(ctx->model.backend_metal, "scr1", ctx->buf_scratch[1].addr, ctx->buf_scratch[1].size, 0));
#undef LLAMA_METAL_CHECK_BUF