mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-28 12:24:35 +00:00
attempt to get test-backend-ops working
This commit is contained in:
parent
8a99f69895
commit
50579f27e9
@ -314,6 +314,12 @@ static void ggml_backend_registry_init(void) {
|
||||
extern ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void);
|
||||
ggml_backend_register("Metal", ggml_backend_reg_metal_init, ggml_backend_metal_buffer_type(), NULL);
|
||||
#endif
|
||||
|
||||
#ifdef GGML_USE_KOMPUTE
|
||||
extern ggml_backend_t ggml_backend_reg_kompute_init(const char * params, void * user_data);
|
||||
extern ggml_backend_buffer_type_t ggml_backend_kompute_buffer_type(void);
|
||||
ggml_backend_register("Kompute", ggml_backend_reg_kompute_init, ggml_backend_kompute_buffer_type(), NULL);
|
||||
#endif
|
||||
}
|
||||
|
||||
void ggml_backend_register(const char * name, ggml_backend_init_fn init_fn, ggml_backend_buffer_type_t default_buffer_type, void * user_data) {
|
||||
|
@ -499,7 +499,7 @@ ggml_vk_memory * ggml_vk_find_tensor(struct ggml_kompute_context * ctx, struct g
|
||||
|
||||
const intptr_t ioffs = reinterpret_cast<intptr_t>(t->data) - reinterpret_cast<intptr_t>(buf_ctx->data);
|
||||
|
||||
GGML_ASSERT(ioffs >= 0 && ioffs + ggml_nbytes(t) <= (int64_t)t->buffer->size);
|
||||
GGML_ASSERT(ioffs >= 0 && ioffs + (int64_t)ggml_nbytes(t) <= (int64_t)t->buffer->size);
|
||||
|
||||
offset = (uint64_t)ioffs;
|
||||
return buf_ctx;
|
||||
@ -1344,6 +1344,82 @@ static void ggml_vk_cpy_f16_f32(Args&&... args) {
|
||||
ggml_vk_cpy<2, 4>(spirv, std::forward<Args>(args)...);
|
||||
}
|
||||
|
||||
static bool ggml_kompute_supports_op(const struct ggml_tensor * op) {
|
||||
switch (op->type) {
|
||||
case GGML_TYPE_F16:
|
||||
case GGML_TYPE_F32:
|
||||
case GGML_TYPE_Q4_0:
|
||||
case GGML_TYPE_Q4_1:
|
||||
break;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
|
||||
switch (op->op) {
|
||||
case GGML_OP_UNARY:
|
||||
switch (ggml_get_unary_op(op)) {
|
||||
case GGML_UNARY_OP_RELU:
|
||||
case GGML_UNARY_OP_GELU:
|
||||
case GGML_UNARY_OP_SILU:
|
||||
return true;
|
||||
default:
|
||||
;
|
||||
}
|
||||
break;
|
||||
case GGML_OP_NONE:
|
||||
case GGML_OP_RESHAPE:
|
||||
case GGML_OP_VIEW:
|
||||
case GGML_OP_TRANSPOSE:
|
||||
case GGML_OP_PERMUTE:
|
||||
case GGML_OP_CONCAT:
|
||||
case GGML_OP_ADD:
|
||||
case GGML_OP_ACC:
|
||||
case GGML_OP_MUL:
|
||||
case GGML_OP_DIV:
|
||||
case GGML_OP_SCALE:
|
||||
case GGML_OP_SQR:
|
||||
case GGML_OP_SUM_ROWS:
|
||||
case GGML_OP_SOFT_MAX:
|
||||
case GGML_OP_RMS_NORM:
|
||||
case GGML_OP_GROUP_NORM:
|
||||
case GGML_OP_NORM:
|
||||
case GGML_OP_ALIBI:
|
||||
case GGML_OP_ROPE:
|
||||
case GGML_OP_IM2COL:
|
||||
case GGML_OP_UPSCALE:
|
||||
case GGML_OP_PAD:
|
||||
case GGML_OP_ARGSORT:
|
||||
case GGML_OP_LEAKY_RELU:
|
||||
case GGML_OP_MUL_MAT:
|
||||
case GGML_OP_MUL_MAT_ID:
|
||||
return true;
|
||||
case GGML_OP_DUP:
|
||||
case GGML_OP_CPY:
|
||||
case GGML_OP_CONT:
|
||||
switch (op->src[0]->type) {
|
||||
case GGML_TYPE_F32:
|
||||
case GGML_TYPE_F16:
|
||||
break;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
switch (op->type) {
|
||||
case GGML_TYPE_F32:
|
||||
case GGML_TYPE_F16:
|
||||
break;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
case GGML_OP_DIAG_MASK_INF:
|
||||
case GGML_OP_GET_ROWS:
|
||||
return op->ne[3] == 1;
|
||||
default:
|
||||
;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph * gf) {
|
||||
const int n_seq = 8;
|
||||
|
||||
@ -1362,7 +1438,7 @@ void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph
|
||||
auto& seq = *sequences[seq_idx];
|
||||
|
||||
const int node_start = (seq_idx + 0) * n_nodes_per_seq;
|
||||
const int node_end = (seq_idx == n_seq - 1) ? gf->n_nodes : (seq_idx + 1) * n_nodes_per_seq;
|
||||
const int node_end = std::min((seq_idx == n_seq - 1) ? gf->n_nodes : (seq_idx + 1) * n_nodes_per_seq, gf->n_nodes);
|
||||
|
||||
for (int i = node_start; i < node_end; ++i) {
|
||||
struct ggml_tensor * src0 = gf->nodes[i]->src[0];
|
||||
@ -1381,6 +1457,11 @@ void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph
|
||||
break;
|
||||
}
|
||||
|
||||
if (!ggml_kompute_supports_op(dst)) {
|
||||
fprintf(stderr, "%s: error: unsupported op '%s'\n", __func__, ggml_op_desc(dst));
|
||||
GGML_ASSERT(!"unsupported op");
|
||||
}
|
||||
|
||||
const int32_t ne00 = src0 ? src0->ne[0] : 0;
|
||||
const int32_t ne01 = src0 ? src0->ne[1] : 0;
|
||||
const int32_t ne02 = src0 ? src0->ne[2] : 0;
|
||||
@ -1718,7 +1799,7 @@ static bool ggml_backend_kompute_buffer_type_supports_backend(ggml_backend_buffe
|
||||
return ggml_backend_is_kompute(backend);
|
||||
}
|
||||
|
||||
ggml_backend_buffer_type_t ggml_backend_kompute_buffer_type(void) {
|
||||
ggml_backend_buffer_type_t ggml_backend_kompute_buffer_type() {
|
||||
static struct ggml_backend_buffer_type ggml_backend_buffer_type_kompute = {
|
||||
/* .iface = */ {
|
||||
/* .get_name = */ ggml_backend_kompute_buffer_type_get_name,
|
||||
@ -1761,8 +1842,7 @@ static bool ggml_backend_kompute_graph_compute(ggml_backend_t backend, struct gg
|
||||
|
||||
static bool ggml_backend_kompute_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
|
||||
GGML_UNUSED(backend);
|
||||
GGML_UNUSED(op);
|
||||
return true; // TODO: implement
|
||||
return ggml_kompute_supports_op(op);
|
||||
}
|
||||
|
||||
static struct ggml_backend_i kompute_backend_i = {
|
||||
@ -1800,3 +1880,12 @@ ggml_backend_t ggml_backend_kompute_init() {
|
||||
bool ggml_backend_is_kompute(ggml_backend_t backend) {
|
||||
return backend && backend->iface.get_name == ggml_backend_kompute_name;
|
||||
}
|
||||
|
||||
extern "C" ggml_backend_t ggml_backend_reg_kompute_init(const char * params, void * user_data);
|
||||
|
||||
ggml_backend_t ggml_backend_reg_kompute_init(const char * params, void * user_data) {
|
||||
GGML_UNUSED(params);
|
||||
GGML_UNUSED(user_data);
|
||||
ggml_vk_init_device(0, "gpu");
|
||||
return ggml_backend_kompute_init();
|
||||
}
|
||||
|
@ -63,6 +63,10 @@ void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph
|
||||
// user-code should use only these functions
|
||||
//
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
// forward declaration
|
||||
typedef struct ggml_backend * ggml_backend_t;
|
||||
|
||||
@ -71,3 +75,7 @@ GGML_API ggml_backend_t ggml_backend_kompute_init(void);
|
||||
GGML_API bool ggml_backend_is_kompute(ggml_backend_t backend);
|
||||
|
||||
GGML_API ggml_backend_buffer_type_t ggml_backend_kompute_buffer_type(void);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
Loading…
Reference in New Issue
Block a user