From ab336a9d5e5352ecdcdf4c12d2d54cf4ef82ce31 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Sun, 25 Feb 2024 12:09:09 +0200
Subject: [PATCH] code : normalize enum names (#5697)

* coda : normalize enum names

ggml-ci

* code : cont

* code : cont
---
 common/common.cpp                    |  18 +-
 common/common.h                      |   4 +-
 common/train.cpp                     |  10 +-
 examples/baby-llama/baby-llama.cpp   |   2 +-
 examples/finetune/finetune.cpp       |   2 +-
 examples/llama-bench/llama-bench.cpp |  14 +-
 examples/llava/llava.cpp             |   2 +-
 examples/server/server.cpp           |  18 +-
 .../train-text-from-scratch.cpp      |   2 +-
 ggml-cuda.cu                         | 138 +++----
 ggml-metal.m                         |   4 +-
 ggml-opencl.cpp                      |  50 +--
 ggml-sycl.cpp                        | 152 ++++----
 ggml-vulkan.cpp                      | 102 ++---
 ggml.c                               | 350 +++++++++---------
 ggml.h                               |  38 +-
 llama.cpp                            |  64 ++--
 llama.h                              |  28 +-
 tests/test-backend-ops.cpp           |   4 +-
 tests/test-opt.cpp                   |   2 +-
 20 files changed, 502 insertions(+), 502 deletions(-)

diff --git a/common/common.cpp b/common/common.cpp
index 10ef11829..ec596f5a0 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -295,9 +295,9 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
                 break;
             }
             std::string value(argv[i]);
-            /**/ if (value == "none") { params.rope_scaling_type = LLAMA_ROPE_SCALING_NONE; }
-            else if (value == "linear") { params.rope_scaling_type = LLAMA_ROPE_SCALING_LINEAR; }
-            else if (value == "yarn") { params.rope_scaling_type = LLAMA_ROPE_SCALING_YARN; }
+            /**/ if (value == "none") { params.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_NONE; }
+            else if (value == "linear") { params.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_LINEAR; }
+            else if (value == "yarn") { params.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_YARN; }
             else { invalid_param = true; break; }
         } else if (arg == "--rope-scale") {
             if (++i >= argc) {
@@ -630,11 +630,11 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
             }
             std::string arg_next = argv[i];
             if (arg_next == "none") {
-                params.split_mode = LLAMA_SPLIT_NONE;
+                params.split_mode = LLAMA_SPLIT_MODE_NONE;
             } else if (arg_next == "layer") {
-                params.split_mode = LLAMA_SPLIT_LAYER;
+                params.split_mode = LLAMA_SPLIT_MODE_LAYER;
             } else if (arg_next == "row") {
-                params.split_mode = LLAMA_SPLIT_ROW;
+                params.split_mode = LLAMA_SPLIT_MODE_ROW;
             } else {
                 invalid_param = true;
                 break;
@@ -837,15 +837,15 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
             sep++;
             if (strncmp(sep, "int:", 4) == 0) {
                 sep += 4;
-                kvo.tag = LLAMA_KV_OVERRIDE_INT;
+                kvo.tag = LLAMA_KV_OVERRIDE_TYPE_INT;
                 kvo.int_value = std::atol(sep);
             } else if (strncmp(sep, "float:", 6) == 0) {
                 sep += 6;
-                kvo.tag = LLAMA_KV_OVERRIDE_FLOAT;
+                kvo.tag = LLAMA_KV_OVERRIDE_TYPE_FLOAT;
                 kvo.float_value = std::atof(sep);
             } else if (strncmp(sep, "bool:", 5) == 0) {
                 sep += 5;
-                kvo.tag = LLAMA_KV_OVERRIDE_BOOL;
+                kvo.tag = LLAMA_KV_OVERRIDE_TYPE_BOOL;
                 if (std::strcmp(sep, "true") == 0) {
                     kvo.bool_value = true;
                 } else if (std::strcmp(sep, "false") == 0) {
diff --git a/common/common.h b/common/common.h
index 935771d44..3e21579b0 100644
--- a/common/common.h
+++ b/common/common.h
@@ -61,7 +61,7 @@ struct gpt_params {
     float p_split = 0.1f; // speculative decoding split probability
     int32_t n_gpu_layers = -1; // number of layers to store in VRAM (-1 - use default)
    int32_t n_gpu_layers_draft = -1; // number of layers to store in VRAM for the draft model (-1 - use default)
-    llama_split_mode split_mode = LLAMA_SPLIT_LAYER; // how to split the model across GPUs
+    llama_split_mode split_mode = LLAMA_SPLIT_MODE_LAYER; // how to split the model across GPUs
     int32_t
main_gpu = 0; // the GPU that is used for scratch and small tensors float tensor_split[128] = {0}; // how split tensors should be distributed across GPUs int32_t n_beams = 0; // if non-zero then use beam search of given width. @@ -75,7 +75,7 @@ struct gpt_params { float yarn_beta_fast = 32.0f; // YaRN low correction dim float yarn_beta_slow = 1.0f; // YaRN high correction dim int32_t yarn_orig_ctx = 0; // YaRN original context length - int32_t rope_scaling_type = LLAMA_ROPE_SCALING_UNSPECIFIED; + int32_t rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED; ggml_numa_strategy numa = GGML_NUMA_STRATEGY_DISABLED; // // sampling parameters diff --git a/common/train.cpp b/common/train.cpp index e4c3d5df6..0dbfd24df 100644 --- a/common/train.cpp +++ b/common/train.cpp @@ -31,7 +31,7 @@ struct train_state * init_train_state() { state->opt = new struct ggml_opt_context; state->opt->ctx = NULL; - state->opt->params = ggml_opt_default_params(GGML_OPT_ADAM); + state->opt->params = ggml_opt_default_params(GGML_OPT_TYPE_ADAM); state->opt->params.graph_size = LLAMA_TRAIN_MAX_NODES; state->opt->loss_after = 0.0f; @@ -556,7 +556,7 @@ void load_opt_context_gguf(struct gguf_context * fctx, struct ggml_context * f_g std::string opt_type; GGUF_GET_KEY(fctx, opt_type, gguf_get_val_str, GGUF_TYPE_STRING, true, LLM_KV_OPTIMIZER_TYPE); if (opt_type == LLM_KV_OPTIMIZER_TYPE_ADAM) { - opt->params.type = GGML_OPT_ADAM; + opt->params.type = GGML_OPT_TYPE_ADAM; GGUF_GET_KEY(fctx, opt->adam.fx_best, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, LLM_KV_OPTIMIZER_ADAM_BEST_LOSS); GGUF_GET_KEY(fctx, opt->adam.fx_prev, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, LLM_KV_OPTIMIZER_ADAM_PREVIOUS_LOSS); @@ -568,7 +568,7 @@ void load_opt_context_gguf(struct gguf_context * fctx, struct ggml_context * f_g copy_tensor_by_name(opt->adam.v, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_ADAM_SECOND_MOMENTS); copy_tensor_by_name(opt->adam.pf, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_ADAM_PAST_LOSS_VALUES); } else if (opt_type == LLM_KV_OPTIMIZER_TYPE_LBFGS) { - opt->params.type = GGML_OPT_LBFGS; + opt->params.type = GGML_OPT_TYPE_LBFGS; GGUF_GET_KEY(fctx, opt->params.lbfgs.m, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_OPTIMIZER_LBFGS_APPROX_HESSIAN_COUNT); GGUF_GET_KEY(fctx, opt->lbfgs.fx_best, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, LLM_KV_OPTIMIZER_LBFGS_BEST_LOSS); @@ -603,7 +603,7 @@ void save_opt_context_gguf(struct gguf_context * fctx, struct ggml_opt_context * gguf_set_val_bool(fctx, LLM_KV_OPTIMIZER_JUST_INITIALIZED, opt->just_initialized); switch (opt->params.type) { - case GGML_OPT_ADAM: + case GGML_OPT_TYPE_ADAM: { gguf_set_val_str(fctx, LLM_KV_OPTIMIZER_TYPE, LLM_KV_OPTIMIZER_TYPE_ADAM); gguf_set_val_f32(fctx, LLM_KV_OPTIMIZER_ADAM_BEST_LOSS, opt->adam.fx_best); @@ -622,7 +622,7 @@ void save_opt_context_gguf(struct gguf_context * fctx, struct ggml_opt_context * gguf_add_tensor(fctx, opt->adam.pf); } } break; - case GGML_OPT_LBFGS: + case GGML_OPT_TYPE_LBFGS: { gguf_set_val_str(fctx, LLM_KV_OPTIMIZER_TYPE, LLM_KV_OPTIMIZER_TYPE_LBFGS); gguf_set_val_u32(fctx, LLM_KV_OPTIMIZER_LBFGS_APPROX_HESSIAN_COUNT, opt->params.lbfgs.m); diff --git a/examples/baby-llama/baby-llama.cpp b/examples/baby-llama/baby-llama.cpp index 65bb238a0..bf0125e75 100644 --- a/examples/baby-llama/baby-llama.cpp +++ b/examples/baby-llama/baby-llama.cpp @@ -1547,7 +1547,7 @@ int main(int argc, char ** argv) { float error_before_opt = ggml_get_f32_1d(e, 0); - struct ggml_opt_params opt_params_lbfgs = ggml_opt_default_params(GGML_OPT_LBFGS); + struct ggml_opt_params 
opt_params_lbfgs = ggml_opt_default_params(GGML_OPT_TYPE_LBFGS); opt_params_lbfgs.print_forward_graph = false; opt_params_lbfgs.print_backward_graph = false; opt_params_lbfgs.lbfgs.n_iter = 16; diff --git a/examples/finetune/finetune.cpp b/examples/finetune/finetune.cpp index 98bf5a07a..3da5317b3 100644 --- a/examples/finetune/finetune.cpp +++ b/examples/finetune/finetune.cpp @@ -1531,7 +1531,7 @@ int main(int argc, char ** argv) { lora.hparams.n_rank_output = n_rank_output; // set opt params from command line - opt->params = ggml_opt_default_params(GGML_OPT_ADAM); + opt->params = ggml_opt_default_params(GGML_OPT_TYPE_ADAM); opt->params.print_forward_graph = false; opt->params.print_backward_graph = false; opt->params.graph_size = LLAMA_TRAIN_MAX_NODES; diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp index 11410f8ae..8fec3d43d 100644 --- a/examples/llama-bench/llama-bench.cpp +++ b/examples/llama-bench/llama-bench.cpp @@ -157,9 +157,9 @@ static const char * output_format_str(output_formats format) { static const char * split_mode_str(llama_split_mode mode) { switch (mode) { - case LLAMA_SPLIT_NONE: return "none"; - case LLAMA_SPLIT_LAYER: return "layer"; - case LLAMA_SPLIT_ROW: return "row"; + case LLAMA_SPLIT_MODE_NONE: return "none"; + case LLAMA_SPLIT_MODE_LAYER: return "layer"; + case LLAMA_SPLIT_MODE_ROW: return "row"; default: GGML_ASSERT(!"invalid split mode"); } } @@ -193,7 +193,7 @@ static const cmd_params cmd_params_defaults = { /* type_v */ {GGML_TYPE_F16}, /* n_threads */ {get_num_physical_cores()}, /* n_gpu_layers */ {99}, - /* split_mode */ {LLAMA_SPLIT_LAYER}, + /* split_mode */ {LLAMA_SPLIT_MODE_LAYER}, /* main_gpu */ {0}, /* no_kv_offload */ {false}, /* mul_mat_q */ {true}, @@ -358,11 +358,11 @@ static cmd_params parse_cmd_params(int argc, char ** argv) { for (const auto & m : p) { llama_split_mode mode; if (m == "none") { - mode = LLAMA_SPLIT_NONE; + mode = LLAMA_SPLIT_MODE_NONE; } else if (m == "layer") { - mode = LLAMA_SPLIT_LAYER; + mode = LLAMA_SPLIT_MODE_LAYER; } else if (m == "row") { - mode = LLAMA_SPLIT_ROW; + mode = LLAMA_SPLIT_MODE_ROW; } else { invalid_param = true; break; diff --git a/examples/llava/llava.cpp b/examples/llava/llava.cpp index 1a1cf7c78..980128166 100644 --- a/examples/llava/llava.cpp +++ b/examples/llava/llava.cpp @@ -152,7 +152,7 @@ static bool clip_llava_handle_patches(clip_ctx * ctx_clip, std::vector ggml_tensor * newline_tmp = clip_get_newline_tensor(ctx_clip); model.newline = ggml_new_tensor_1d(model.ctx, GGML_TYPE_F32, newline_tmp->ne[0]); - if (newline_tmp->backend != GGML_BACKEND_CPU) { + if (newline_tmp->backend != GGML_BACKEND_TYPE_CPU) { if (newline_tmp->buffer == NULL) { printf("newline_tmp tensor buffer is NULL\n"); } diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 19a8c1067..780862ef6 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -2086,9 +2086,9 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, break; } std::string value(argv[i]); - /**/ if (value == "none") { params.rope_scaling_type = LLAMA_ROPE_SCALING_NONE; } - else if (value == "linear") { params.rope_scaling_type = LLAMA_ROPE_SCALING_LINEAR; } - else if (value == "yarn") { params.rope_scaling_type = LLAMA_ROPE_SCALING_YARN; } + /**/ if (value == "none") { params.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_NONE; } + else if (value == "linear") { params.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_LINEAR; } + else if (value == "yarn") { 
params.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_YARN; } else { invalid_param = true; break; } } else if (arg == "--rope-freq-base") @@ -2212,15 +2212,15 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, std::string arg_next = argv[i]; if (arg_next == "none") { - params.split_mode = LLAMA_SPLIT_NONE; + params.split_mode = LLAMA_SPLIT_MODE_NONE; } else if (arg_next == "layer") { - params.split_mode = LLAMA_SPLIT_LAYER; + params.split_mode = LLAMA_SPLIT_MODE_LAYER; } else if (arg_next == "row") { - params.split_mode = LLAMA_SPLIT_ROW; + params.split_mode = LLAMA_SPLIT_MODE_ROW; } else { invalid_param = true; @@ -2447,15 +2447,15 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, sep++; if (strncmp(sep, "int:", 4) == 0) { sep += 4; - kvo.tag = LLAMA_KV_OVERRIDE_INT; + kvo.tag = LLAMA_KV_OVERRIDE_TYPE_INT; kvo.int_value = std::atol(sep); } else if (strncmp(sep, "float:", 6) == 0) { sep += 6; - kvo.tag = LLAMA_KV_OVERRIDE_FLOAT; + kvo.tag = LLAMA_KV_OVERRIDE_TYPE_FLOAT; kvo.float_value = std::atof(sep); } else if (strncmp(sep, "bool:", 5) == 0) { sep += 5; - kvo.tag = LLAMA_KV_OVERRIDE_BOOL; + kvo.tag = LLAMA_KV_OVERRIDE_TYPE_BOOL; if (std::strcmp(sep, "true") == 0) { kvo.bool_value = true; } else if (std::strcmp(sep, "false") == 0) { diff --git a/examples/train-text-from-scratch/train-text-from-scratch.cpp b/examples/train-text-from-scratch/train-text-from-scratch.cpp index e78ab185d..7eafe8515 100644 --- a/examples/train-text-from-scratch/train-text-from-scratch.cpp +++ b/examples/train-text-from-scratch/train-text-from-scratch.cpp @@ -960,7 +960,7 @@ int main(int argc, char ** argv) { struct ggml_opt_context * opt = train->opt; // set opt params from command line - opt->params = ggml_opt_default_params(GGML_OPT_ADAM); + opt->params = ggml_opt_default_params(GGML_OPT_TYPE_ADAM); opt->params.print_forward_graph = false; opt->params.print_backward_graph = false; opt->params.graph_size = LLAMA_TRAIN_MAX_NODES; diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 21c612cb7..fb6d4f7d2 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -6369,11 +6369,11 @@ static __global__ void k_argsort_f32_i32(const float * x, int * dst, const int n int ixj = col ^ j; if (ixj > col) { if ((col & k) == 0) { - if (order == GGML_SORT_ASC ? x_row[dst_row[col]] > x_row[dst_row[ixj]] : x_row[dst_row[col]] < x_row[dst_row[ixj]]) { + if (order == GGML_SORT_ORDER_ASC ? x_row[dst_row[col]] > x_row[dst_row[ixj]] : x_row[dst_row[col]] < x_row[dst_row[ixj]]) { swap(dst_row[col], dst_row[ixj]); } } else { - if (order == GGML_SORT_ASC ? x_row[dst_row[col]] < x_row[dst_row[ixj]] : x_row[dst_row[col]] > x_row[dst_row[ixj]]) { + if (order == GGML_SORT_ORDER_ASC ? 
x_row[dst_row[col]] < x_row[dst_row[ixj]] : x_row[dst_row[col]] > x_row[dst_row[ixj]]) { swap(dst_row[col], dst_row[ixj]); } } @@ -7927,10 +7927,10 @@ static void argsort_f32_i32_cuda(const float * x, int * dst, const int ncols, co const dim3 block_dims(ncols, 1, 1); const dim3 block_nums(1, nrows, 1); - if (order == GGML_SORT_ASC) { - k_argsort_f32_i32<<>>(x, dst, ncols); - } else if (order == GGML_SORT_DESC) { - k_argsort_f32_i32<<>>(x, dst, ncols); + if (order == GGML_SORT_ORDER_ASC) { + k_argsort_f32_i32<<>>(x, dst, ncols); + } else if (order == GGML_SORT_ORDER_DESC) { + k_argsort_f32_i32<<>>(x, dst, ncols); } else { GGML_ASSERT(false); } @@ -8362,11 +8362,11 @@ static cudaError_t ggml_cuda_cpy_tensor_2d( cudaMemcpyKind kind; char * src_ptr; - if (src->backend == GGML_BACKEND_CPU) { + if (src->backend == GGML_BACKEND_TYPE_CPU) { kind = cudaMemcpyHostToDevice; src_ptr = (char *) src->data; - } else if (src->backend == GGML_BACKEND_GPU || src->backend == GGML_BACKEND_GPU_SPLIT) { - GGML_ASSERT(src->backend != GGML_BACKEND_GPU_SPLIT || (i1_low == 0 && i1_high == src->ne[1])); + } else if (src->backend == GGML_BACKEND_TYPE_GPU || src->backend == GGML_BACKEND_TYPE_GPU_SPLIT) { + GGML_ASSERT(src->backend != GGML_BACKEND_TYPE_GPU_SPLIT || (i1_low == 0 && i1_high == src->ne[1])); kind = cudaMemcpyDeviceToDevice; ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) src->extra; int id; @@ -8771,7 +8771,7 @@ static void ggml_cuda_op_mul_mat_q( // the main device has a larger memory buffer to hold the results from all GPUs // nrows_dst == nrows of the matrix that the kernel writes into - const int64_t nrows_dst = dst->backend == GGML_BACKEND_GPU && id == g_main_device ? ne0 : row_diff; + const int64_t nrows_dst = dst->backend == GGML_BACKEND_TYPE_GPU && id == g_main_device ? ne0 : row_diff; switch (src0->type) { case GGML_TYPE_Q4_0: @@ -8920,7 +8920,7 @@ static void ggml_cuda_op_mul_mat_vec_q( // the main device has a larger memory buffer to hold the results from all GPUs // nrows_dst == nrows of the matrix that the kernel writes into - const int64_t nrows_dst = dst->backend == GGML_BACKEND_GPU && id == g_main_device ? ne0 : row_diff; + const int64_t nrows_dst = dst->backend == GGML_BACKEND_TYPE_GPU && id == g_main_device ? ne0 : row_diff; switch (src0->type) { case GGML_TYPE_Q4_0: @@ -9096,7 +9096,7 @@ static void ggml_cuda_op_mul_mat_cublas( // the main device has a larger memory buffer to hold the results from all GPUs // ldc == nrows of the matrix that cuBLAS writes into - int ldc = dst->backend == GGML_BACKEND_GPU && id == g_main_device ? ne0 : row_diff; + int ldc = dst->backend == GGML_BACKEND_TYPE_GPU && id == g_main_device ? ne0 : row_diff; const int compute_capability = g_device_caps[id].cc; @@ -9444,7 +9444,7 @@ static void ggml_cuda_op_soft_max( const bool use_src2 = src2 != nullptr; if (use_src2) { - const bool src2_on_device = src2->backend == GGML_BACKEND_GPU; + const bool src2_on_device = src2->backend == GGML_BACKEND_TYPE_GPU; if (src2_on_device) { ggml_tensor_extra_gpu * src2_extra = (ggml_tensor_extra_gpu *) src2->extra; @@ -9502,16 +9502,16 @@ static void ggml_cuda_op_flatten(const ggml_tensor * src0, const ggml_tensor * s const bool use_src1 = src1 != nullptr; const int64_t nrows1 = use_src1 ? 
ggml_nrows(src1) : 1; - GGML_ASSERT(!use_src1 || src1->backend != GGML_BACKEND_GPU_SPLIT); - GGML_ASSERT( dst->backend != GGML_BACKEND_GPU_SPLIT); + GGML_ASSERT(!use_src1 || src1->backend != GGML_BACKEND_TYPE_GPU_SPLIT); + GGML_ASSERT( dst->backend != GGML_BACKEND_TYPE_GPU_SPLIT); ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra; ggml_tensor_extra_gpu * src1_extra = use_src1 ? (ggml_tensor_extra_gpu *) src1->extra : nullptr; ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra; - const bool src0_on_device = src0->backend == GGML_BACKEND_GPU || src0->backend == GGML_BACKEND_GPU_SPLIT; - const bool src1_on_device = use_src1 && src1->backend == GGML_BACKEND_GPU; - const bool dst_on_device = dst->backend == GGML_BACKEND_GPU; + const bool src0_on_device = src0->backend == GGML_BACKEND_TYPE_GPU || src0->backend == GGML_BACKEND_TYPE_GPU_SPLIT; + const bool src1_on_device = use_src1 && src1->backend == GGML_BACKEND_TYPE_GPU; + const bool dst_on_device = dst->backend == GGML_BACKEND_TYPE_GPU; // dd = data device float * src0_ddf = nullptr; @@ -9555,7 +9555,7 @@ static void ggml_cuda_op_flatten(const ggml_tensor * src0, const ggml_tensor * s CUDA_CHECK(cudaMemcpyAsync(dst->data, dst_ddf, ggml_nbytes(dst), cudaMemcpyDeviceToHost, main_stream)); } - if (dst->backend == GGML_BACKEND_CPU) { + if (dst->backend == GGML_BACKEND_TYPE_CPU) { CUDA_CHECK(cudaDeviceSynchronize()); } } @@ -9636,8 +9636,8 @@ static void ggml_cuda_op_mul_mat( const int nb2 = dst->nb[2]; const int nb3 = dst->nb[3]; - GGML_ASSERT(dst->backend != GGML_BACKEND_GPU_SPLIT); - GGML_ASSERT(src1->backend != GGML_BACKEND_GPU_SPLIT); + GGML_ASSERT(dst->backend != GGML_BACKEND_TYPE_GPU_SPLIT); + GGML_ASSERT(src1->backend != GGML_BACKEND_TYPE_GPU_SPLIT); GGML_ASSERT(src1->type == GGML_TYPE_F32 || (src1->ne[2] == 1 && src1->ne[3] == 1)); GGML_ASSERT(ne12 >= ne02 && ne12 % ne02 == 0); @@ -9653,20 +9653,20 @@ static void ggml_cuda_op_mul_mat( ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra; ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra; - const bool src0_on_device = src0->backend == GGML_BACKEND_GPU || src0->backend == GGML_BACKEND_GPU_SPLIT; + const bool src0_on_device = src0->backend == GGML_BACKEND_TYPE_GPU || src0->backend == GGML_BACKEND_TYPE_GPU_SPLIT; const bool src0_is_contiguous = ggml_is_contiguous(src0); const bool src1_is_contiguous = ggml_is_contiguous(src1); const int64_t src1_padded_col_size = GGML_PAD(ne10, MATRIX_ROW_PADDING); - const bool split = src0->backend == GGML_BACKEND_GPU_SPLIT; + const bool split = src0->backend == GGML_BACKEND_TYPE_GPU_SPLIT; GGML_ASSERT(!(split && ne02 > 1)); GGML_ASSERT(!(split && ne03 > 1)); GGML_ASSERT(!(split && ne02 < ne12)); std::array tensor_split; if (split) { - // TODO: check that src0->buffer->buft is a split buffer type, replace GGML_BACKEND_GPU_SPLIT check + // TODO: check that src0->buffer->buft is a split buffer type, replace GGML_BACKEND_TYPE_GPU_SPLIT check // GGML_ASSERT(src0->buffer != nullptr && src0->buffer->buft == ...); ggml_backend_cuda_split_buffer_type_context * buft_ctx = (ggml_backend_cuda_split_buffer_type_context *) src0->buffer->buft->context; tensor_split = buft_ctx->tensor_split; @@ -9724,8 +9724,8 @@ static void ggml_cuda_op_mul_mat( used_devices++; - const bool src1_on_device = src1->backend == GGML_BACKEND_GPU && id == g_main_device; - const bool dst_on_device = dst->backend == GGML_BACKEND_GPU && id == g_main_device; + const bool src1_on_device = src1->backend == 
GGML_BACKEND_TYPE_GPU && id == g_main_device; + const bool dst_on_device = dst->backend == GGML_BACKEND_TYPE_GPU && id == g_main_device; ggml_cuda_set_device(id); cudaStream_t stream = g_cudaStreams[id][0]; @@ -9776,8 +9776,8 @@ static void ggml_cuda_op_mul_mat( continue; } - const bool src1_on_device = src1->backend == GGML_BACKEND_GPU && id == g_main_device; - const bool dst_on_device = dst->backend == GGML_BACKEND_GPU && id == g_main_device; + const bool src1_on_device = src1->backend == GGML_BACKEND_TYPE_GPU && id == g_main_device; + const bool dst_on_device = dst->backend == GGML_BACKEND_TYPE_GPU && id == g_main_device; const int64_t row_diff = dev[id].row_high - dev[id].row_low; ggml_cuda_set_device(id); @@ -9802,12 +9802,12 @@ static void ggml_cuda_op_mul_mat( // the main device memory buffer can be on VRAM scratch, with space for all partial results // in that case an offset on dst_ddf_i is needed - if (dst->backend == GGML_BACKEND_GPU && id == g_main_device) { + if (dst->backend == GGML_BACKEND_TYPE_GPU && id == g_main_device) { dst_dd_i += dev[id].row_low; // offset is 0 if no tensor split } // copy src0, src1 to device if necessary - if (src1->backend == GGML_BACKEND_GPU && src1_is_contiguous) { + if (src1->backend == GGML_BACKEND_TYPE_GPU && src1_is_contiguous) { if (id != g_main_device) { if (convert_src1_to_q8_1) { char * src1_ddq_i_source = dev[g_main_device].src1_ddq + src1_ddq_i_offset; @@ -9820,14 +9820,14 @@ static void ggml_cuda_op_mul_mat( src1_ncols*ne10*sizeof(float), stream)); } } - } else if (src1->backend == GGML_BACKEND_CPU || (src1_on_device && !src1_is_contiguous)) { + } else if (src1->backend == GGML_BACKEND_TYPE_CPU || (src1_on_device && !src1_is_contiguous)) { CUDA_CHECK(ggml_cuda_cpy_tensor_2d( src1_ddf_i, src1, i03, i02, src1_col_0, src1_col_0+src1_ncols, stream)); } else { GGML_ASSERT(false); } - if (convert_src1_to_q8_1 && (src1->backend == GGML_BACKEND_CPU || !src1_is_contiguous)) { + if (convert_src1_to_q8_1 && (src1->backend == GGML_BACKEND_TYPE_CPU || !src1_is_contiguous)) { quantize_row_q8_1_cuda(src1_ddf_i, src1_ddq_i, ne10, src1_ncols, src1_padded_col_size, stream); CUDA_CHECK(cudaGetLastError()); } @@ -9845,10 +9845,10 @@ static void ggml_cuda_op_mul_mat( if (!dst_on_device) { void * dst_off_device; cudaMemcpyKind kind; - if (dst->backend == GGML_BACKEND_CPU) { + if (dst->backend == GGML_BACKEND_TYPE_CPU) { dst_off_device = dst->data; kind = cudaMemcpyDeviceToHost; - } else if (dst->backend == GGML_BACKEND_GPU) { + } else if (dst->backend == GGML_BACKEND_TYPE_GPU) { dst_off_device = dst_extra->data_device[g_main_device]; kind = cudaMemcpyDeviceToDevice; } else { @@ -9913,7 +9913,7 @@ static void ggml_cuda_op_mul_mat( } } - if (dst->backend == GGML_BACKEND_CPU) { + if (dst->backend == GGML_BACKEND_TYPE_CPU) { ggml_cuda_set_device(g_main_device); CUDA_CHECK(cudaDeviceSynchronize()); } @@ -10019,7 +10019,7 @@ GGML_CALL bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const stru static void ggml_cuda_mul_mat_vec_p021(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){ GGML_ASSERT(ggml_is_permuted(src0) && ggml_is_permuted(src1)); - GGML_ASSERT(src0->backend != GGML_BACKEND_GPU_SPLIT); + GGML_ASSERT(src0->backend != GGML_BACKEND_TYPE_GPU_SPLIT); GGML_ASSERT(src0->nb[0] <= src0->nb[1] && src0->nb[2] <= src0->nb[3]); // 0213 permutation GGML_ASSERT(src1->nb[0] <= src1->nb[1] && src1->nb[2] <= src1->nb[3]); // 0213 permutation GGML_ASSERT(src0->type == GGML_TYPE_F16); @@ -10050,7 +10050,7 @@ static void 
ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor GGML_ASSERT(!ggml_is_transposed(src0)); GGML_ASSERT(!ggml_is_transposed(src1)); GGML_ASSERT(!ggml_is_permuted(src0)); - GGML_ASSERT(src0->backend != GGML_BACKEND_GPU_SPLIT); + GGML_ASSERT(src0->backend != GGML_BACKEND_TYPE_GPU_SPLIT); GGML_ASSERT(src0->type == GGML_TYPE_F16); GGML_ASSERT(src1->type == GGML_TYPE_F32); @@ -10109,7 +10109,7 @@ static void ggml_cuda_mul_mat_batched_cublas(const ggml_tensor * src0, const ggm GGML_ASSERT(!ggml_is_transposed(src0)); GGML_ASSERT(!ggml_is_transposed(src1)); - GGML_ASSERT(src0->backend != GGML_BACKEND_GPU_SPLIT); + GGML_ASSERT(src0->backend != GGML_BACKEND_TYPE_GPU_SPLIT); GGML_ASSERT(src0->type == GGML_TYPE_F16); GGML_TENSOR_BINARY_OP_LOCALS @@ -10255,11 +10255,11 @@ static void ggml_cuda_mul_mat_batched_cublas(const ggml_tensor * src0, const ggm static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { const bool all_on_device = - (src0->backend == GGML_BACKEND_GPU || src0->backend == GGML_BACKEND_GPU_SPLIT) && - (src1->backend == GGML_BACKEND_GPU) && - ( dst->backend == GGML_BACKEND_GPU); + (src0->backend == GGML_BACKEND_TYPE_GPU || src0->backend == GGML_BACKEND_TYPE_GPU_SPLIT) && + (src1->backend == GGML_BACKEND_TYPE_GPU) && + ( dst->backend == GGML_BACKEND_TYPE_GPU); - const bool split = src0->backend == GGML_BACKEND_GPU_SPLIT; + const bool split = src0->backend == GGML_BACKEND_TYPE_GPU_SPLIT; int64_t min_compute_capability = INT_MAX; @@ -10409,7 +10409,7 @@ static void ggml_cuda_mul_mat_id_cublas(ggml_tensor * dst) { GGML_ASSERT(!ggml_is_transposed(src00)); GGML_ASSERT(!ggml_is_transposed(src1)); - GGML_ASSERT(src00->backend != GGML_BACKEND_GPU_SPLIT); + GGML_ASSERT(src00->backend != GGML_BACKEND_TYPE_GPU_SPLIT); GGML_ASSERT(src1->type == GGML_TYPE_F32); const int64_t ne00 = src00->ne[0]; GGML_UNUSED(ne00); @@ -10553,7 +10553,7 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s cudaStream_t stream = g_cudaStreams[g_main_device][0]; - if (ids->backend == GGML_BACKEND_GPU) { + if (ids->backend == GGML_BACKEND_TYPE_GPU) { const char * ids_dev = (const char *)((const ggml_tensor_extra_gpu *)ids->extra)->data_device[g_main_device]; CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids_dev, ggml_nbytes(ids), cudaMemcpyDeviceToHost, stream)); CUDA_CHECK(cudaStreamSynchronize(stream)); @@ -10570,20 +10570,20 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s ggml_tensor src1_row = *src1; ggml_tensor dst_row = *dst; - src1_row.backend = GGML_BACKEND_GPU; - dst_row.backend = GGML_BACKEND_GPU; + src1_row.backend = GGML_BACKEND_TYPE_GPU; + dst_row.backend = GGML_BACKEND_TYPE_GPU; src1_row.extra = &src1_row_extra; dst_row.extra = &dst_row_extra; - char * src1_original = src1->backend == GGML_BACKEND_CPU ? + char * src1_original = src1->backend == GGML_BACKEND_TYPE_CPU ? (char *) src1->data : (char *) src1_extra->data_device[g_main_device]; - char * dst_original = dst->backend == GGML_BACKEND_CPU ? + char * dst_original = dst->backend == GGML_BACKEND_TYPE_CPU ? 
(char *) dst->data : (char *) dst_extra->data_device[g_main_device]; if (src1->ne[1] == 1) { - GGML_ASSERT(src1->backend == GGML_BACKEND_GPU); - GGML_ASSERT(dst->backend == GGML_BACKEND_GPU); + GGML_ASSERT(src1->backend == GGML_BACKEND_TYPE_GPU); + GGML_ASSERT(dst->backend == GGML_BACKEND_TYPE_GPU); for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) { //int32_t row_id; @@ -10611,9 +10611,9 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s src1_row_extra.data_device[g_main_device] = src1_contiguous.get(); dst_row_extra.data_device[g_main_device] = dst_contiguous.get(); - const cudaMemcpyKind src1_kind = src1->backend == GGML_BACKEND_CPU ? + const cudaMemcpyKind src1_kind = src1->backend == GGML_BACKEND_TYPE_CPU ? cudaMemcpyHostToDevice : cudaMemcpyDeviceToDevice; - const cudaMemcpyKind dst_kind = dst->backend == GGML_BACKEND_CPU ? + const cudaMemcpyKind dst_kind = dst->backend == GGML_BACKEND_TYPE_CPU ? cudaMemcpyDeviceToHost : cudaMemcpyDeviceToDevice; for (int32_t row_id = 0; row_id < n_as; ++row_id) { @@ -10668,7 +10668,7 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s } } - if (dst->backend == GGML_BACKEND_CPU) { + if (dst->backend == GGML_BACKEND_TYPE_CPU) { CUDA_CHECK(cudaStreamSynchronize(stream)); } } @@ -10685,8 +10685,8 @@ static void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, gg const int64_t ne = ggml_nelements(src0); GGML_ASSERT(ne == ggml_nelements(src1)); - GGML_ASSERT(src0->backend == GGML_BACKEND_GPU); - GGML_ASSERT(src1->backend == GGML_BACKEND_GPU); + GGML_ASSERT(src0->backend == GGML_BACKEND_TYPE_GPU); + GGML_ASSERT(src1->backend == GGML_BACKEND_TYPE_GPU); GGML_ASSERT(ggml_nbytes(src0) <= INT_MAX); GGML_ASSERT(ggml_nbytes(src1) <= INT_MAX); @@ -10817,9 +10817,9 @@ GGML_CALL bool ggml_cuda_compute_forward(struct ggml_compute_params * params, st if (!g_cublas_loaded) return false; ggml_cuda_func_t func; - const bool any_on_device = tensor->backend == GGML_BACKEND_GPU - || (tensor->src[0] != nullptr && (tensor->src[0]->backend == GGML_BACKEND_GPU || tensor->src[0]->backend == GGML_BACKEND_GPU_SPLIT)) - || (tensor->src[1] != nullptr && tensor->src[1]->backend == GGML_BACKEND_GPU); + const bool any_on_device = tensor->backend == GGML_BACKEND_TYPE_GPU + || (tensor->src[0] != nullptr && (tensor->src[0]->backend == GGML_BACKEND_TYPE_GPU || tensor->src[0]->backend == GGML_BACKEND_TYPE_GPU_SPLIT)) + || (tensor->src[1] != nullptr && tensor->src[1]->backend == GGML_BACKEND_TYPE_GPU); if (!any_on_device && tensor->op != GGML_OP_MUL_MAT && tensor->op != GGML_OP_MUL_MAT_ID) { return false; @@ -10966,14 +10966,14 @@ GGML_CALL bool ggml_cuda_compute_forward(struct ggml_compute_params * params, st return false; } - if (tensor->src[0] != nullptr && tensor->src[0]->backend == GGML_BACKEND_GPU_SPLIT) { + if (tensor->src[0] != nullptr && tensor->src[0]->backend == GGML_BACKEND_TYPE_GPU_SPLIT) { ggml_cuda_set_peer_access(tensor->src[1]->ne[1]); } if (params->ith != 0) { return true; } - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { return true; } func(tensor->src[0], tensor->src[1], tensor); @@ -11072,7 +11072,7 @@ GGML_CALL static void ggml_backend_cuda_buffer_init_tensor(ggml_backend_buffer_t extra->data_device[ctx->device] = tensor->data; - tensor->backend = GGML_BACKEND_GPU; + tensor->backend = GGML_BACKEND_TYPE_GPU; tensor->extra = extra; if (ggml_is_quantized(tensor->type)) { @@ 
-11087,7 +11087,7 @@ GGML_CALL static void ggml_backend_cuda_buffer_init_tensor(ggml_backend_buffer_t } GGML_CALL static void ggml_backend_cuda_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { - GGML_ASSERT(tensor->backend == GGML_BACKEND_GPU); + GGML_ASSERT(tensor->backend == GGML_BACKEND_TYPE_GPU); ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context; @@ -11098,7 +11098,7 @@ GGML_CALL static void ggml_backend_cuda_buffer_set_tensor(ggml_backend_buffer_t } GGML_CALL static void ggml_backend_cuda_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) { - GGML_ASSERT(tensor->backend == GGML_BACKEND_GPU); + GGML_ASSERT(tensor->backend == GGML_BACKEND_TYPE_GPU); ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context; @@ -11333,7 +11333,7 @@ GGML_CALL static void ggml_backend_cuda_split_buffer_init_tensor(ggml_backend_bu CUDA_CHECK(cudaEventCreateWithFlags(&extra->events[id][is], cudaEventDisableTiming)); } } - tensor->backend = GGML_BACKEND_GPU_SPLIT; + tensor->backend = GGML_BACKEND_TYPE_GPU_SPLIT; tensor->extra = extra; } @@ -11605,7 +11605,7 @@ GGML_CALL static void ggml_backend_cuda_set_tensor_async(ggml_backend_t backend, ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context; GGML_ASSERT(tensor->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) && "unsupported buffer type"); - GGML_ASSERT(tensor->backend == GGML_BACKEND_GPU); + GGML_ASSERT(tensor->backend == GGML_BACKEND_TYPE_GPU); CUDA_CHECK(cudaMemcpyAsync((char *)tensor->data + offset, data, size, cudaMemcpyHostToDevice, g_cudaStreams[cuda_ctx->device][0])); } @@ -11614,7 +11614,7 @@ GGML_CALL static void ggml_backend_cuda_get_tensor_async(ggml_backend_t backend, ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context; GGML_ASSERT(tensor->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) && "unsupported buffer type"); - GGML_ASSERT(tensor->backend == GGML_BACKEND_GPU); + GGML_ASSERT(tensor->backend == GGML_BACKEND_TYPE_GPU); CUDA_CHECK(cudaMemcpyAsync(data, (const char *)tensor->data + offset, size, cudaMemcpyDeviceToHost, g_cudaStreams[cuda_ctx->device][0])); } @@ -11644,7 +11644,7 @@ GGML_CALL static bool ggml_backend_cuda_graph_compute(ggml_backend_t backend, gg ggml_cuda_set_main_device(cuda_ctx->device); ggml_compute_params params = {}; - params.type = GGML_TASK_COMPUTE; + params.type = GGML_TASK_TYPE_COMPUTE; params.ith = 0; for (int i = 0; i < cgraph->n_nodes; i++) { ggml_tensor * node = cgraph->nodes[i]; @@ -11654,13 +11654,13 @@ GGML_CALL static bool ggml_backend_cuda_graph_compute(ggml_backend_t backend, gg } #ifndef NDEBUG - assert(node->backend == GGML_BACKEND_GPU || node->backend == GGML_BACKEND_GPU_SPLIT); + assert(node->backend == GGML_BACKEND_TYPE_GPU || node->backend == GGML_BACKEND_TYPE_GPU_SPLIT); assert(node->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device)); assert(node->extra != nullptr); for (int j = 0; j < GGML_MAX_SRC; j++) { if (node->src[j] != nullptr) { - assert(node->src[j]->backend == GGML_BACKEND_GPU || node->src[j]->backend == GGML_BACKEND_GPU_SPLIT); + assert(node->src[j]->backend == GGML_BACKEND_TYPE_GPU || node->src[j]->backend == GGML_BACKEND_TYPE_GPU_SPLIT); assert(node->src[j]->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) || 
ggml_backend_buffer_is_cuda_split(node->src[j]->buffer)); assert(node->src[j]->extra != nullptr); } diff --git a/ggml-metal.m b/ggml-metal.m index ee584cfa7..3d6b01263 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -2262,8 +2262,8 @@ static bool ggml_metal_graph_compute( id pipeline = nil; switch (order) { - case GGML_SORT_ASC: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC].pipeline; break; - case GGML_SORT_DESC: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC].pipeline; break; + case GGML_SORT_ORDER_ASC: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC].pipeline; break; + case GGML_SORT_ORDER_DESC: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC].pipeline; break; default: GGML_ASSERT(false); }; diff --git a/ggml-opencl.cpp b/ggml-opencl.cpp index 797bee667..df619a884 100644 --- a/ggml-opencl.cpp +++ b/ggml-opencl.cpp @@ -1354,7 +1354,7 @@ static void ggml_cl_pool_free(cl_mem mem, size_t size) { } void ggml_cl_free_data(const struct ggml_tensor* tensor) { - if (tensor->backend != GGML_BACKEND_GPU) { + if (tensor->backend != GGML_BACKEND_TYPE_GPU) { return; } @@ -1412,7 +1412,7 @@ static cl_int ggml_cl_h2d_tensor_2d(cl_command_queue queue, cl_mem dst, size_t o } static void ggml_cl_mul_f32(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { - GGML_ASSERT(src1->backend == GGML_BACKEND_GPU); + GGML_ASSERT(src1->backend == GGML_BACKEND_TYPE_GPU); const int64_t ne00 = src0->ne[0]; const int64_t ne01 = src0->ne[1]; const int64_t ne02 = src0->ne[2]; @@ -1476,7 +1476,7 @@ void ggml_cl_mul(const struct ggml_tensor * src0, const struct ggml_tensor * src } static void ggml_cl_add_f32(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { - GGML_ASSERT(src1->backend == GGML_BACKEND_GPU); + GGML_ASSERT(src1->backend == GGML_BACKEND_TYPE_GPU); const int64_t ne00 = src0->ne[0]; const int64_t ne01 = src0->ne[1]; const int64_t ne02 = src0->ne[2]; @@ -1566,13 +1566,13 @@ static void ggml_cl_mul_mat_f32(const ggml_tensor * src0, const ggml_tensor * sr size_t y_size; size_t d_size; cl_mem d_X; - if (src0->backend == GGML_BACKEND_GPU) { // NOLINT + if (src0->backend == GGML_BACKEND_TYPE_GPU) { // NOLINT d_X = (cl_mem) src0->extra; } else { d_X = ggml_cl_pool_malloc(sizeof(float) * x_ne, &x_size); } - cl_mem d_Y = src1->backend == GGML_BACKEND_GPU ? (cl_mem) src1->extra : ggml_cl_pool_malloc(sizeof(float) * y_ne, &y_size); - cl_mem d_D = dst->backend == GGML_BACKEND_GPU ? (cl_mem) dst->extra : ggml_cl_pool_malloc(sizeof(float) * d_ne, &d_size); + cl_mem d_Y = src1->backend == GGML_BACKEND_TYPE_GPU ? (cl_mem) src1->extra : ggml_cl_pool_malloc(sizeof(float) * y_ne, &y_size); + cl_mem d_D = dst->backend == GGML_BACKEND_TYPE_GPU ? 
(cl_mem) dst->extra : ggml_cl_pool_malloc(sizeof(float) * d_ne, &d_size); size_t x_offset = 0; @@ -1580,7 +1580,7 @@ static void ggml_cl_mul_mat_f32(const ggml_tensor * src0, const ggml_tensor * sr // TODO: copy src0 here when r3>1 for (int64_t i13 = i03 * r3, e13 = i13 + r3; i13 < e13; i13++) { for (int64_t i02 = 0; i02 < ne02; i02++) { - if (src0->backend == GGML_BACKEND_GPU) { + if (src0->backend == GGML_BACKEND_TYPE_GPU) { x_offset = (i03 * ne02 + i02) * x_ne; } else { // copy src0 to device @@ -1589,7 +1589,7 @@ static void ggml_cl_mul_mat_f32(const ggml_tensor * src0, const ggml_tensor * sr for (int64_t i12 = i02 * r2, e12 = i12 + r2; i12 < e12; i12++) { // copy src1 to device - if (src1->backend == GGML_BACKEND_CPU) { + if (src1->backend == GGML_BACKEND_TYPE_CPU) { CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_Y, 0, src1, i13, i12, NULL)); } @@ -1612,7 +1612,7 @@ static void ggml_cl_mul_mat_f32(const ggml_tensor * src0, const ggml_tensor * sr } // copy dst to host - if (dst->backend == GGML_BACKEND_CPU) { + if (dst->backend == GGML_BACKEND_TYPE_CPU) { float * d = (float *) ((char *) dst->data + i12*nb2 + i13*nb3); CL_CHECK(clEnqueueReadBuffer(queue, d_D, true, 0, sizeof(float) * d_ne, d, 1, &ev_sgemm, NULL)); } @@ -1621,13 +1621,13 @@ static void ggml_cl_mul_mat_f32(const ggml_tensor * src0, const ggml_tensor * sr } } - if (src0->backend != GGML_BACKEND_GPU) { + if (src0->backend != GGML_BACKEND_TYPE_GPU) { ggml_cl_pool_free(d_X, x_size); } - if (src1->backend != GGML_BACKEND_GPU) { + if (src1->backend != GGML_BACKEND_TYPE_GPU) { ggml_cl_pool_free(d_Y, y_size); } - if (dst->backend != GGML_BACKEND_GPU) { + if (dst->backend != GGML_BACKEND_TYPE_GPU) { ggml_cl_pool_free(d_D, d_size); } } @@ -1670,7 +1670,7 @@ static void ggml_cl_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * sr size_t y_size; size_t d_size; cl_mem d_X; - if (src0->backend == GGML_BACKEND_GPU) { // NOLINT + if (src0->backend == GGML_BACKEND_TYPE_GPU) { // NOLINT d_X = (cl_mem) src0->extra; } else { d_X = ggml_cl_pool_malloc(sizeof(ggml_fp16_t) * x_ne, &x_size); @@ -1687,7 +1687,7 @@ static void ggml_cl_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * sr // TODO: copy src0 here when r3>1 for (int64_t i13 = i03 * r3, e13 = i13 + r3; i13 < e13; i13++) { for (int64_t i02 = 0; i02 < ne02; i02++) { - if (src0->backend == GGML_BACKEND_GPU) { + if (src0->backend == GGML_BACKEND_TYPE_GPU) { x_offset = (i03 * ne02 + i02) * x_ne; } else { // copy src0 to device @@ -1741,7 +1741,7 @@ static void ggml_cl_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * sr } // copy dst to host, then convert to float - if (dst->backend == GGML_BACKEND_CPU) { + if (dst->backend == GGML_BACKEND_TYPE_CPU) { CL_CHECK(clEnqueueReadBuffer(queue, d_D, true, 0, sizeof(ggml_fp16_t) * d_ne, tmp, 1, &ev_sgemm, NULL)); float * d = (float *) ((char *) dst->data + i12*nb2 + i13*nb3); ggml_fp16_to_fp32_row(tmp, d, d_ne); @@ -1753,7 +1753,7 @@ static void ggml_cl_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * sr } } - if (src0->backend != GGML_BACKEND_GPU) { + if (src0->backend != GGML_BACKEND_TYPE_GPU) { ggml_cl_pool_free(d_X, x_size); } ggml_cl_pool_free(d_Y, y_size); @@ -1798,7 +1798,7 @@ static void ggml_cl_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor * cl_mem d_Y = ggml_cl_pool_malloc(sizeof(float) * y_ne, &y_size); cl_mem d_D = ggml_cl_pool_malloc(sizeof(float) * d_ne, &d_size); cl_mem d_Q; - if (src0->backend == GGML_BACKEND_CPU) { + if (src0->backend == GGML_BACKEND_TYPE_CPU) { d_Q = ggml_cl_pool_malloc(q_sz, 
&q_size); } @@ -1817,10 +1817,10 @@ static void ggml_cl_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor * for (int64_t i13 = i03 * r3, e13 = i13 + r3; i13 < e13; i13++) { for (int64_t i02 = 0; i02 < ne02; i02++) { // copy src0 to device if necessary - if (src0->backend == GGML_BACKEND_CPU) { + if (src0->backend == GGML_BACKEND_TYPE_CPU) { events.emplace_back(); CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_Q, 0, src0, i03, i02, events.data() + ev_idx++)); - } else if (src0->backend == GGML_BACKEND_GPU) { + } else if (src0->backend == GGML_BACKEND_TYPE_GPU) { d_Q = (cl_mem) src0->extra; } else { GGML_ASSERT(false); @@ -1829,7 +1829,7 @@ static void ggml_cl_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor * if (!mul_mat_vec) { // convert src0 to fp32 on device const size_t global = x_ne / global_denom; - const size_t offset = src0->backend == GGML_BACKEND_GPU ? (i03 * ne02 + i02) * x_bps : 0; + const size_t offset = src0->backend == GGML_BACKEND_TYPE_GPU ? (i03 * ne02 + i02) * x_bps : 0; CL_CHECK(clSetKernelArg(*to_fp32_cl, 0, sizeof(cl_mem), &d_Q)); CL_CHECK(clSetKernelArg(*to_fp32_cl, 1, sizeof(cl_mem), &d_X)); CL_CHECK(clEnqueueNDRangeKernel(queue, *to_fp32_cl, 1, &offset, &global, local > 0 ? &local : NULL, events.size(), !events.empty() ? events.data() : NULL, NULL)); @@ -1843,7 +1843,7 @@ static void ggml_cl_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor * // compute const size_t global = ne01 * local; - const size_t offset = src0->backend == GGML_BACKEND_GPU ? (i03 * ne02 + i02) * x_bps : 0; + const size_t offset = src0->backend == GGML_BACKEND_TYPE_GPU ? (i03 * ne02 + i02) * x_bps : 0; const cl_int ncols = ne00; events.emplace_back(); CL_CHECK(clSetKernelArg(*dmmv, 0, sizeof(cl_mem), &d_Q)); @@ -1895,7 +1895,7 @@ static void ggml_cl_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor * } ggml_cl_pool_free(d_Y, y_size); ggml_cl_pool_free(d_D, d_size); - if (src0->backend == GGML_BACKEND_CPU) { + if (src0->backend == GGML_BACKEND_TYPE_CPU) { ggml_cl_pool_free(d_Q, q_size); } } @@ -1911,7 +1911,7 @@ bool ggml_cl_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tens if ((src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 && - ((ne0 >= 32 && ne1 >= 32 && ne10 >= 32) || src0->backend == GGML_BACKEND_GPU)) { + ((ne0 >= 32 && ne1 >= 32 && ne10 >= 32) || src0->backend == GGML_BACKEND_TYPE_GPU)) { return true; } @@ -1993,7 +1993,7 @@ void ggml_cl_transform_tensor(void * data, ggml_tensor * tensor) { CL_CHECK(clFinish(queue)); tensor->extra = dst; - GGML_ASSERT(tensor->backend == GGML_BACKEND_GPU); + GGML_ASSERT(tensor->backend == GGML_BACKEND_TYPE_GPU); } // ggml-backend @@ -2045,7 +2045,7 @@ static void ggml_backend_opencl_buffer_init_tensor(ggml_backend_buffer_t buffer, ctx->sub_buffers.push_back(sub_buffer); tensor->extra = sub_buffer; } - tensor->backend = GGML_BACKEND_GPU; + tensor->backend = GGML_BACKEND_TYPE_GPU; } static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { diff --git a/ggml-sycl.cpp b/ggml-sycl.cpp index b897828f9..c6c3c6e6f 100644 --- a/ggml-sycl.cpp +++ b/ggml-sycl.cpp @@ -3338,7 +3338,7 @@ void print_ggml_tensor(const char*name, struct ggml_tensor *src){ size_t total_elements = ggml_nelements(src); - const bool src_on_device = src->backend == GGML_BACKEND_GPU || src->backend == GGML_BACKEND_GPU_SPLIT; + const bool src_on_device = 
src->backend == GGML_BACKEND_TYPE_GPU || src->backend == GGML_BACKEND_TYPE_GPU_SPLIT; float *src_data =NULL; if(src_on_device) { ggml_tensor_extra_gpu * src_extra = (ggml_tensor_extra_gpu *) src->extra; @@ -8086,11 +8086,11 @@ static void k_argsort_f32_i32(const float * x, int * dst, const int ncols, int ixj = col ^ j; if (ixj > col) { if ((col & k) == 0) { - if (order == GGML_SORT_ASC ? x_row[dst_row[col]] > x_row[dst_row[ixj]] : x_row[dst_row[col]] < x_row[dst_row[ixj]]) { + if (order == GGML_SORT_ORDER_ASC ? x_row[dst_row[col]] > x_row[dst_row[ixj]] : x_row[dst_row[col]] < x_row[dst_row[ixj]]) { swap(dst_row[col], dst_row[ixj]); } } else { - if (order == GGML_SORT_ASC ? x_row[dst_row[col]] < x_row[dst_row[ixj]] : x_row[dst_row[col]] > x_row[dst_row[ixj]]) { + if (order == GGML_SORT_ORDER_ASC ? x_row[dst_row[col]] < x_row[dst_row[ixj]] : x_row[dst_row[col]] > x_row[dst_row[ixj]]) { swap(dst_row[col], dst_row[ixj]); } } @@ -10825,7 +10825,7 @@ static void argsort_f32_i32_sycl(const float *x, int *dst, const int ncols, const sycl::range<3> block_dims(1, 1, ncols); const sycl::range<3> block_nums(1, nrows, 1); - if (order == GGML_SORT_ASC) { + if (order == GGML_SORT_ORDER_ASC) { /* DPCT1049:44: The work-group size passed to the SYCL kernel may exceed the limit. To get the device limit, query @@ -10834,9 +10834,9 @@ static void argsort_f32_i32_sycl(const float *x, int *dst, const int ncols, stream->parallel_for( sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) { - k_argsort_f32_i32(x, dst, ncols, item_ct1); + k_argsort_f32_i32(x, dst, ncols, item_ct1); }); - } else if (order == GGML_SORT_DESC) { + } else if (order == GGML_SORT_ORDER_DESC) { /* DPCT1049:45: The work-group size passed to the SYCL kernel may exceed the limit. To get the device limit, query @@ -10845,7 +10845,7 @@ static void argsort_f32_i32_sycl(const float *x, int *dst, const int ncols, stream->parallel_for( sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) { - k_argsort_f32_i32(x, dst, ncols, item_ct1); + k_argsort_f32_i32(x, dst, ncols, item_ct1); }); } else { GGML_ASSERT(false); @@ -11407,12 +11407,12 @@ static dpct::err0 ggml_sycl_cpy_tensor_2d(void *dst, dpct::memcpy_direction kind; char * src_ptr; - if (src->backend == GGML_BACKEND_CPU) { + if (src->backend == GGML_BACKEND_TYPE_CPU) { kind = dpct::host_to_device; src_ptr = (char *) src->data; - // GGML_SYCL_DEBUG("ggml_sycl_cpy_tensor_2d GGML_BACKEND_CPU src_ptr %p\n", src_ptr); - } else if (src->backend == GGML_BACKEND_GPU || src->backend == GGML_BACKEND_GPU_SPLIT) { - GGML_ASSERT(src->backend != GGML_BACKEND_GPU_SPLIT || (i1_low == 0 && i1_high == src->ne[1])); + // GGML_SYCL_DEBUG("ggml_sycl_cpy_tensor_2d GGML_BACKEND_TYPE_CPU src_ptr %p\n", src_ptr); + } else if (src->backend == GGML_BACKEND_TYPE_GPU || src->backend == GGML_BACKEND_TYPE_GPU_SPLIT) { + GGML_ASSERT(src->backend != GGML_BACKEND_TYPE_GPU_SPLIT || (i1_low == 0 && i1_high == src->ne[1])); kind = dpct::device_to_device; ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) src->extra; int id; @@ -11846,7 +11846,7 @@ inline void ggml_sycl_op_mul_mat_q( // the main device has a larger memory buffer to hold the results from all GPUs // nrows_dst == nrows of the matrix that the dequantize_mul_mat kernel writes into - const int64_t nrows_dst = dst->backend == GGML_BACKEND_GPU && device_id == g_main_device ? ne0 : row_diff; + const int64_t nrows_dst = dst->backend == GGML_BACKEND_TYPE_GPU && device_id == g_main_device ? 
ne0 : row_diff; switch (src0->type) { case GGML_TYPE_Q4_0: @@ -12119,7 +12119,7 @@ inline void ggml_sycl_op_mul_mat_sycl( // the main device has a larger memory buffer to hold the results from all GPUs // ldc == nrows of the matrix that cuBLAS writes into - int ldc = dst->backend == GGML_BACKEND_GPU && device_id == g_main_device ? ne0 : row_diff; + int ldc = dst->backend == GGML_BACKEND_TYPE_GPU && device_id == g_main_device ? ne0 : row_diff; #ifdef GGML_SYCL_F16 bool use_fp16 = true; // TODO(Yu) SYCL capability check @@ -12501,16 +12501,16 @@ static void ggml_sycl_op_flatten(const ggml_tensor *src0, const bool use_src1 = src1 != nullptr; const int64_t nrows1 = use_src1 ? ggml_nrows(src1) : 1; - GGML_ASSERT(!use_src1 || src1->backend != GGML_BACKEND_GPU_SPLIT); - GGML_ASSERT( dst->backend != GGML_BACKEND_GPU_SPLIT); + GGML_ASSERT(!use_src1 || src1->backend != GGML_BACKEND_TYPE_GPU_SPLIT); + GGML_ASSERT( dst->backend != GGML_BACKEND_TYPE_GPU_SPLIT); ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra; ggml_tensor_extra_gpu * src1_extra = use_src1 ? (ggml_tensor_extra_gpu *) src1->extra : nullptr; ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra; - const bool src0_on_device = src0->backend == GGML_BACKEND_GPU || src0->backend == GGML_BACKEND_GPU_SPLIT; - const bool src1_on_device = use_src1 && src1->backend == GGML_BACKEND_GPU; - const bool dst_on_device = dst->backend == GGML_BACKEND_GPU; + const bool src0_on_device = src0->backend == GGML_BACKEND_TYPE_GPU || src0->backend == GGML_BACKEND_TYPE_GPU_SPLIT; + const bool src1_on_device = use_src1 && src1->backend == GGML_BACKEND_TYPE_GPU; + const bool dst_on_device = dst->backend == GGML_BACKEND_TYPE_GPU; // dd = data device float * src0_ddf = nullptr; @@ -12565,7 +12565,7 @@ static void ggml_sycl_op_flatten(const ggml_tensor *src0, main_stream->memcpy(dst->data, dst_ddf, ggml_nbytes(dst)))); } - if (dst->backend == GGML_BACKEND_CPU) { + if (dst->backend == GGML_BACKEND_TYPE_CPU) { SYCL_CHECK(CHECK_TRY_ERROR( dpct::get_current_device().queues_wait_and_throw())); } @@ -12640,8 +12640,8 @@ static void ggml_sycl_op_mul_mat(const ggml_tensor *src0, const int nb2 = dst->nb[2]; const int nb3 = dst->nb[3]; - GGML_ASSERT(dst->backend != GGML_BACKEND_GPU_SPLIT); - GGML_ASSERT(src1->backend != GGML_BACKEND_GPU_SPLIT); + GGML_ASSERT(dst->backend != GGML_BACKEND_TYPE_GPU_SPLIT); + GGML_ASSERT(src1->backend != GGML_BACKEND_TYPE_GPU_SPLIT); GGML_ASSERT(ne12 >= ne02 && ne12 % ne02 == 0); @@ -12656,13 +12656,13 @@ static void ggml_sycl_op_mul_mat(const ggml_tensor *src0, ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra; ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra; - const bool src0_on_device = src0->backend == GGML_BACKEND_GPU || src0->backend == GGML_BACKEND_GPU_SPLIT; + const bool src0_on_device = src0->backend == GGML_BACKEND_TYPE_GPU || src0->backend == GGML_BACKEND_TYPE_GPU_SPLIT; const bool src0_is_contiguous = ggml_is_contiguous(src0); const bool src1_is_contiguous = ggml_is_contiguous(src1); int64_t src1_padded_col_size = GGML_PAD(ne10, MATRIX_ROW_PADDING); - const bool split = src0->backend == GGML_BACKEND_GPU_SPLIT; + const bool split = src0->backend == GGML_BACKEND_TYPE_GPU_SPLIT; GGML_ASSERT(!(split && ne02 > 1)); GGML_ASSERT(!(split && ne03 > 1)); GGML_ASSERT(!(split && ne02 < ne12)); @@ -12717,8 +12717,8 @@ static void ggml_sycl_op_mul_mat(const ggml_tensor *src0, used_devices++; - const bool src1_on_device = src1->backend == GGML_BACKEND_GPU && 
id == g_main_device_index; - const bool dst_on_device = dst->backend == GGML_BACKEND_GPU && id == g_main_device_index; + const bool src1_on_device = src1->backend == GGML_BACKEND_TYPE_GPU && id == g_main_device_index; + const bool dst_on_device = dst->backend == GGML_BACKEND_TYPE_GPU && id == g_main_device_index; ggml_sycl_set_device(get_device_id_by_index(id)); const dpct::queue_ptr stream = g_syclStreams[id][0]; @@ -12782,8 +12782,8 @@ static void ggml_sycl_op_mul_mat(const ggml_tensor *src0, continue; } - const bool src1_on_device = src1->backend == GGML_BACKEND_GPU && id == g_main_device_index; - const bool dst_on_device = dst->backend == GGML_BACKEND_GPU && id == g_main_device_index; + const bool src1_on_device = src1->backend == GGML_BACKEND_TYPE_GPU && id == g_main_device_index; + const bool dst_on_device = dst->backend == GGML_BACKEND_TYPE_GPU && id == g_main_device_index; const int64_t row_diff = row_high[id] - row_low[id]; ggml_sycl_set_device(get_device_id_by_index(id)); @@ -12809,12 +12809,12 @@ static void ggml_sycl_op_mul_mat(const ggml_tensor *src0, // the main device memory buffer can be on VRAM scratch, with space for all partial results // in that case an offset on dst_ddf_i is needed - if (dst->backend == GGML_BACKEND_GPU && id == g_main_device_index) { + if (dst->backend == GGML_BACKEND_TYPE_GPU && id == g_main_device_index) { dst_dd_i += row_low[id]; // offset is 0 if no tensor split } // copy src0, src1 to device if necessary - if (src1->backend == GGML_BACKEND_GPU && src1_is_contiguous) { + if (src1->backend == GGML_BACKEND_TYPE_GPU && src1_is_contiguous) { if (id != g_main_device_index) { if (convert_src1_to_q8_1) { char * src1_ddq_i_source = src1_ddq[g_main_device_index] + src1_ddq_i_offset; @@ -12830,14 +12830,14 @@ static void ggml_sycl_op_mul_mat(const ggml_tensor *src0, src1_ncols * ne10 * sizeof(float)))); } } - } else if (src1->backend == GGML_BACKEND_CPU || (src1_on_device && !src1_is_contiguous)) { + } else if (src1->backend == GGML_BACKEND_TYPE_CPU || (src1_on_device && !src1_is_contiguous)) { SYCL_CHECK(ggml_sycl_cpy_tensor_2d( src1_ddf_i, src1, i03, i02, src1_col_0, src1_col_0+src1_ncols, stream)); } else { GGML_ASSERT(false); } - if (convert_src1_to_q8_1 && (src1->backend == GGML_BACKEND_CPU || !src1_is_contiguous)) { + if (convert_src1_to_q8_1 && (src1->backend == GGML_BACKEND_TYPE_CPU || !src1_is_contiguous)) { quantize_row_q8_1_sycl(src1_ddf_i, src1_ddq_i, ne10, src1_ncols, src1_padded_col_size, stream); /* DPCT1010:92: SYCL uses exceptions to report errors and does @@ -12867,10 +12867,10 @@ static void ggml_sycl_op_mul_mat(const ggml_tensor *src0, if (!dst_on_device) { void * dst_off_device; dpct::memcpy_direction kind; - if (dst->backend == GGML_BACKEND_CPU) { + if (dst->backend == GGML_BACKEND_TYPE_CPU) { dst_off_device = dst->data; kind = dpct::device_to_host; - } else if (dst->backend == GGML_BACKEND_GPU) { + } else if (dst->backend == GGML_BACKEND_TYPE_GPU) { dst_off_device = dst_extra->data_device[g_main_device_index]; kind = dpct::device_to_device; } else { @@ -12954,7 +12954,7 @@ static void ggml_sycl_op_mul_mat(const ggml_tensor *src0, } } - if (dst->backend == GGML_BACKEND_CPU) { + if (dst->backend == GGML_BACKEND_TYPE_CPU) { SYCL_CHECK(ggml_sycl_set_device(g_main_device)); SYCL_CHECK(CHECK_TRY_ERROR( dpct::get_current_device().queues_wait_and_throw())); @@ -13091,7 +13091,7 @@ static void ggml_sycl_mul_mat_vec_p021(const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst) try { GGML_ASSERT(ggml_is_permuted(src0) && 
ggml_is_permuted(src1)); - GGML_ASSERT(src0->backend != GGML_BACKEND_GPU_SPLIT); + GGML_ASSERT(src0->backend != GGML_BACKEND_TYPE_GPU_SPLIT); GGML_ASSERT(src0->nb[0] <= src0->nb[1] && src0->nb[2] <= src0->nb[3]); // 0213 permutation GGML_ASSERT(src1->nb[0] <= src1->nb[1] && src1->nb[2] <= src1->nb[3]); // 0213 permutation GGML_ASSERT(src0->type == GGML_TYPE_F16); @@ -13129,7 +13129,7 @@ static void ggml_sycl_mul_mat_vec_nc(const ggml_tensor *src0, GGML_ASSERT(!ggml_is_transposed(src0)); GGML_ASSERT(!ggml_is_transposed(src1)); GGML_ASSERT(!ggml_is_permuted(src0)); - GGML_ASSERT(src0->backend != GGML_BACKEND_GPU_SPLIT); + GGML_ASSERT(src0->backend != GGML_BACKEND_TYPE_GPU_SPLIT); GGML_ASSERT(src0->type == GGML_TYPE_F16); GGML_ASSERT(src1->type == GGML_TYPE_F32); @@ -13196,7 +13196,7 @@ static void ggml_sycl_mul_mat_mat_batched_sycl(const ggml_tensor *src0, GGML_ASSERT(!ggml_is_transposed(src0)); GGML_ASSERT(!ggml_is_transposed(src1)); - GGML_ASSERT(src0->backend != GGML_BACKEND_GPU_SPLIT); + GGML_ASSERT(src0->backend != GGML_BACKEND_TYPE_GPU_SPLIT); GGML_ASSERT(src0->type == GGML_TYPE_F16); GGML_ASSERT(src1->type == GGML_TYPE_F32); @@ -13372,11 +13372,11 @@ catch (sycl::exception const &exc) { static void ggml_sycl_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { const bool all_on_device = - (src0->backend == GGML_BACKEND_GPU || src0->backend == GGML_BACKEND_GPU_SPLIT) && - (src1->backend == GGML_BACKEND_GPU) && - ( dst->backend == GGML_BACKEND_GPU); + (src0->backend == GGML_BACKEND_TYPE_GPU || src0->backend == GGML_BACKEND_TYPE_GPU_SPLIT) && + (src1->backend == GGML_BACKEND_TYPE_GPU) && + ( dst->backend == GGML_BACKEND_TYPE_GPU); - const bool split = src0->backend == GGML_BACKEND_GPU_SPLIT; + const bool split = src0->backend == GGML_BACKEND_TYPE_GPU_SPLIT; int64_t min_compute_capability = INT_MAX; for (int64_t id = 0; id < g_device_count; ++id) { @@ -13505,7 +13505,7 @@ static void ggml_sycl_mul_mat_id_sycl(ggml_tensor * dst) { GGML_ASSERT(!ggml_is_transposed(src00)); GGML_ASSERT(!ggml_is_transposed(src1)); - GGML_ASSERT(src00->backend != GGML_BACKEND_GPU_SPLIT); + GGML_ASSERT(src00->backend != GGML_BACKEND_TYPE_GPU_SPLIT); GGML_ASSERT(src1->type == GGML_TYPE_F32); GGML_TENSOR_LOCALS(int64_t, ne0, src00, ne); @@ -13643,7 +13643,7 @@ static void ggml_sycl_mul_mat_id(const ggml_tensor *src0, const dpct::queue_ptr stream = g_syclStreams[g_main_device_index][0]; - if (ids->backend == GGML_BACKEND_GPU) { + if (ids->backend == GGML_BACKEND_TYPE_GPU) { const char * ids_dev = (const char *)((const ggml_tensor_extra_gpu *)ids->extra)->data_device[g_main_device_index]; SYCL_CHECK(CHECK_TRY_ERROR( stream->memcpy(ids_host.data(), ids_dev, ggml_nbytes(ids)))); @@ -13661,20 +13661,20 @@ static void ggml_sycl_mul_mat_id(const ggml_tensor *src0, ggml_tensor src1_row = *src1; ggml_tensor dst_row = *dst; - src1_row.backend = GGML_BACKEND_GPU; - dst_row.backend = GGML_BACKEND_GPU; + src1_row.backend = GGML_BACKEND_TYPE_GPU; + dst_row.backend = GGML_BACKEND_TYPE_GPU; src1_row.extra = &src1_row_extra; dst_row.extra = &dst_row_extra; - char * src1_original = src1->backend == GGML_BACKEND_CPU ? + char * src1_original = src1->backend == GGML_BACKEND_TYPE_CPU ? (char *) src1->data : (char *) src1_extra->data_device[g_main_device_index]; - char * dst_original = dst->backend == GGML_BACKEND_CPU ? + char * dst_original = dst->backend == GGML_BACKEND_TYPE_CPU ? 
(char *) dst->data : (char *) dst_extra->data_device[g_main_device_index]; if (src1->ne[1] == 1) { - GGML_ASSERT(src1->backend == GGML_BACKEND_GPU); - GGML_ASSERT(dst->backend == GGML_BACKEND_GPU); + GGML_ASSERT(src1->backend == GGML_BACKEND_TYPE_GPU); + GGML_ASSERT(dst->backend == GGML_BACKEND_TYPE_GPU); for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) { //int32_t row_id; @@ -13756,7 +13756,7 @@ static void ggml_sycl_mul_mat_id(const ggml_tensor *src0, } } - if (dst->backend == GGML_BACKEND_CPU) { + if (dst->backend == GGML_BACKEND_TYPE_CPU) { SYCL_CHECK(CHECK_TRY_ERROR(stream->wait())); } } @@ -13779,8 +13779,8 @@ static void ggml_sycl_cpy(const ggml_tensor *src0, const ggml_tensor *src1, const int64_t ne = ggml_nelements(src0); GGML_ASSERT(ne == ggml_nelements(src1)); - GGML_ASSERT(src0->backend == GGML_BACKEND_GPU); - GGML_ASSERT(src1->backend == GGML_BACKEND_GPU); + GGML_ASSERT(src0->backend == GGML_BACKEND_TYPE_GPU); + GGML_ASSERT(src1->backend == GGML_BACKEND_TYPE_GPU); GGML_ASSERT(ggml_nbytes(src0) <= INT_MAX); GGML_ASSERT(ggml_nbytes(src1) <= INT_MAX); @@ -13887,17 +13887,17 @@ void ggml_sycl_transform_tensor(void *data, struct ggml_tensor *tensor) try { memset(extra, 0, sizeof(*extra)); for (int64_t id = 0; id < g_device_count; ++id) { - if (backend == GGML_BACKEND_GPU && id != g_main_device_index) { + if (backend == GGML_BACKEND_TYPE_GPU && id != g_main_device_index) { continue; } ggml_sycl_set_device(get_device_id_by_index(id)); const dpct::queue_ptr stream = g_syclStreams[id][0]; int64_t row_low, row_high; - if (backend == GGML_BACKEND_GPU) { + if (backend == GGML_BACKEND_TYPE_GPU) { row_low = 0; row_high = nrows; - } else if (backend == GGML_BACKEND_GPU_SPLIT) { + } else if (backend == GGML_BACKEND_TYPE_GPU_SPLIT) { const int64_t rounding = get_row_rounding(tensor->type); row_low = id == 0 ? 
0 : nrows*g_tensor_split[id]; @@ -13946,7 +13946,7 @@ void ggml_sycl_transform_tensor(void *data, struct ggml_tensor *tensor) try { extra->data_device[id] = buf; - if (backend == GGML_BACKEND_GPU_SPLIT) { + if (backend == GGML_BACKEND_TYPE_GPU_SPLIT) { for (int64_t is = 0; is < MAX_STREAMS; ++is) { SYCL_CHECK(CHECK_TRY_ERROR(extra->events[id][is] = new sycl::event())); @@ -13963,7 +13963,7 @@ catch (sycl::exception const &exc) { } void ggml_sycl_free_data(struct ggml_tensor *tensor) try { - if (!tensor || !tensor->extra || (tensor->backend != GGML_BACKEND_GPU && tensor->backend != GGML_BACKEND_GPU_SPLIT) ) { + if (!tensor || !tensor->extra || (tensor->backend != GGML_BACKEND_TYPE_GPU && tensor->backend != GGML_BACKEND_TYPE_GPU_SPLIT) ) { return; } @@ -14016,15 +14016,15 @@ static void ggml_sycl_assign_buffers_impl(struct ggml_tensor *tensor, return; } - tensor->backend = GGML_BACKEND_GPU; + tensor->backend = GGML_BACKEND_TYPE_GPU; - if (tensor->src[0] != nullptr && tensor->src[0]->backend == GGML_BACKEND_CPU) { + if (tensor->src[0] != nullptr && tensor->src[0]->backend == GGML_BACKEND_TYPE_CPU) { const ggml_op src0_op = tensor->src[0]->op; if (src0_op == GGML_OP_RESHAPE || src0_op == GGML_OP_TRANSPOSE || src0_op == GGML_OP_VIEW || src0_op == GGML_OP_PERMUTE) { ggml_sycl_assign_buffers_impl(tensor->src[0], scratch, force_inplace, no_alloc); } } - if (tensor->op == GGML_OP_CPY && tensor->src[1]->backend == GGML_BACKEND_CPU) { + if (tensor->op == GGML_OP_CPY && tensor->src[1]->backend == GGML_BACKEND_TYPE_CPU) { ggml_sycl_assign_buffers_impl(tensor->src[1], scratch, force_inplace, no_alloc); } @@ -14042,7 +14042,7 @@ static void ggml_sycl_assign_buffers_impl(struct ggml_tensor *tensor, SYCL_CHECK(ggml_sycl_set_device(g_main_device)); const dpct::queue_ptr stream = g_syclStreams[g_main_device_index][0]; - if (inplace && (tensor->src[0]->backend == GGML_BACKEND_GPU || tensor->src[0]->backend == GGML_BACKEND_GPU_SPLIT)) { + if (inplace && (tensor->src[0]->backend == GGML_BACKEND_TYPE_GPU || tensor->src[0]->backend == GGML_BACKEND_TYPE_GPU_SPLIT)) { ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu * ) tensor->src[0]->extra; char * src0_ddc = (char *) src0_extra->data_device[g_main_device_index]; size_t offset = 0; @@ -14111,7 +14111,7 @@ void ggml_sycl_assign_scratch_offset(struct ggml_tensor *tensor, const bool inplace = tensor->view_src != nullptr; - if (inplace && (tensor->view_src->backend == GGML_BACKEND_GPU || tensor->view_src->backend == GGML_BACKEND_GPU_SPLIT)) { + if (inplace && (tensor->view_src->backend == GGML_BACKEND_TYPE_GPU || tensor->view_src->backend == GGML_BACKEND_TYPE_GPU_SPLIT)) { ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu * ) tensor->view_src->extra; char * src0_ddc = (char *) src0_extra->data_device[g_main_device_index]; size_t view_offset = 0; @@ -14132,7 +14132,7 @@ catch (sycl::exception const &exc) { } void ggml_sycl_copy_to_device(struct ggml_tensor *tensor) try { - GGML_ASSERT(tensor->backend == GGML_BACKEND_GPU); + GGML_ASSERT(tensor->backend == GGML_BACKEND_TYPE_GPU); GGML_ASSERT(ggml_is_contiguous(tensor)); ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) tensor->extra; @@ -14219,9 +14219,9 @@ bool ggml_sycl_compute_forward(struct ggml_compute_params * params, struct ggml_ if (!g_sycl_loaded) return false; ggml_sycl_func_t func; - const bool any_on_device = tensor->backend == GGML_BACKEND_GPU - || (tensor->src[0] != nullptr && (tensor->src[0]->backend == GGML_BACKEND_GPU || tensor->src[0]->backend == GGML_BACKEND_GPU_SPLIT)) - 
|| (tensor->src[1] != nullptr && tensor->src[1]->backend == GGML_BACKEND_GPU); + const bool any_on_device = tensor->backend == GGML_BACKEND_TYPE_GPU + || (tensor->src[0] != nullptr && (tensor->src[0]->backend == GGML_BACKEND_TYPE_GPU || tensor->src[0]->backend == GGML_BACKEND_TYPE_GPU_SPLIT)) + || (tensor->src[1] != nullptr && tensor->src[1]->backend == GGML_BACKEND_TYPE_GPU); if (!any_on_device && tensor->op != GGML_OP_MUL_MAT && tensor->op != GGML_OP_MUL_MAT_ID) { return false; @@ -14359,14 +14359,14 @@ bool ggml_sycl_compute_forward(struct ggml_compute_params * params, struct ggml_ return false; } - if (tensor->src[0] != nullptr && tensor->src[0]->backend == GGML_BACKEND_GPU_SPLIT) { + if (tensor->src[0] != nullptr && tensor->src[0]->backend == GGML_BACKEND_TYPE_GPU_SPLIT) { ggml_sycl_set_peer_access(tensor->src[1]->ne[1]); } if (params->ith != 0) { return true; } - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { return true; } func(tensor->src[0], tensor->src[1], tensor); @@ -14517,7 +14517,7 @@ static void ggml_backend_sycl_buffer_init_tensor(ggml_backend_buffer_t buffer, extra->data_device[ctx->device] = tensor->data; - tensor->backend = GGML_BACKEND_GPU; + tensor->backend = GGML_BACKEND_TYPE_GPU; tensor->extra = extra; if (ggml_is_quantized(tensor->type)) { @@ -14548,7 +14548,7 @@ static void ggml_backend_sycl_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor *tensor, const void *data, size_t offset, size_t size) try { - GGML_ASSERT(tensor->backend == GGML_BACKEND_GPU); + GGML_ASSERT(tensor->backend == GGML_BACKEND_TYPE_GPU); ggml_backend_sycl_buffer_context * ctx = ( ggml_backend_sycl_buffer_context *)buffer->context; @@ -14573,7 +14573,7 @@ static void ggml_backend_sycl_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor *tensor, void *data, size_t offset, size_t size) try { - GGML_ASSERT(tensor->backend == GGML_BACKEND_GPU); + GGML_ASSERT(tensor->backend == GGML_BACKEND_TYPE_GPU); ggml_backend_sycl_buffer_context * ctx = ( ggml_backend_sycl_buffer_context *)buffer->context; @@ -14809,7 +14809,7 @@ static void ggml_backend_sycl_set_tensor_async(ggml_backend_t backend, ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context; GGML_ASSERT(tensor->buffer->buft == ggml_backend_sycl_buffer_type(sycl_ctx->device) && "unsupported buffer type"); - GGML_ASSERT(tensor->backend == GGML_BACKEND_GPU); + GGML_ASSERT(tensor->backend == GGML_BACKEND_TYPE_GPU); SYCL_CHECK(CHECK_TRY_ERROR(g_syclStreams[sycl_ctx->device][0]->memcpy( (char *)tensor->data + offset, data, size))); @@ -14827,7 +14827,7 @@ static void ggml_backend_sycl_get_tensor_async(ggml_backend_t backend, ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context; GGML_ASSERT(tensor->buffer->buft == ggml_backend_sycl_buffer_type(sycl_ctx->device) && "unsupported buffer type"); - GGML_ASSERT(tensor->backend == GGML_BACKEND_GPU); + GGML_ASSERT(tensor->backend == GGML_BACKEND_TYPE_GPU); SYCL_CHECK(CHECK_TRY_ERROR(g_syclStreams[sycl_ctx->device][0]->memcpy( data, (const char *)tensor->data + offset, size))); @@ -14880,7 +14880,7 @@ static bool ggml_backend_sycl_graph_compute(ggml_backend_t backend, ggml_cgraph ggml_sycl_set_main_device(sycl_ctx->device); ggml_compute_params params = {}; - params.type = GGML_TASK_COMPUTE; + params.type = GGML_TASK_TYPE_COMPUTE; params.ith = 0; for (int i = 0; i < cgraph->n_nodes; i++) { ggml_tensor * node = 
cgraph->nodes[i]; @@ -14888,13 +14888,13 @@ static bool ggml_backend_sycl_graph_compute(ggml_backend_t backend, ggml_cgraph if (node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE) continue; - assert(node->backend == GGML_BACKEND_GPU); + assert(node->backend == GGML_BACKEND_TYPE_GPU); assert(node->buffer->buft == ggml_backend_sycl_buffer_type(sycl_ctx->device)); assert(node->extra != nullptr); for (int j = 0; j < GGML_MAX_SRC; j++) { if (node->src[j] != nullptr) { - assert(node->src[j]->backend == GGML_BACKEND_GPU); + assert(node->src[j]->backend == GGML_BACKEND_TYPE_GPU); assert(node->src[j]->buffer->buft == ggml_backend_sycl_buffer_type(sycl_ctx->device)); assert(node->src[j]->extra != nullptr); } diff --git a/ggml-vulkan.cpp b/ggml-vulkan.cpp index 4e5eaff15..6caafb822 100644 --- a/ggml-vulkan.cpp +++ b/ggml-vulkan.cpp @@ -2320,8 +2320,8 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context * su src1_uma = d_Qy != nullptr; } - const bool load_x = src0->backend != GGML_BACKEND_GPU && !src0_uma; - const bool load_y = src1->backend != GGML_BACKEND_GPU && !src1_uma; + const bool load_x = src0->backend != GGML_BACKEND_TYPE_GPU && !src0_uma; + const bool load_y = src1->backend != GGML_BACKEND_TYPE_GPU && !src1_uma; const bool x_non_contig = !load_x && !ggml_vk_dim01_contiguous(src0); const bool y_non_contig = !load_y && !ggml_vk_dim01_contiguous(src1); @@ -2453,7 +2453,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context * su // compute ggml_vk_matmul(ctx, subctx, *pipeline, { d_X, x_buf_offset, x_sz * ne02 * ne03 }, { d_Y, y_buf_offset, y_sz * ne12 * ne13 }, { d_D, d_buf_offset, d_sz * ne12 * ne13 }, { ctx->prealloc_split_k, 0, d_sz * ne12 * ne13 * split_k }, ne01, ne11, ne10, ne10, ne10, ne01, split_k, ne12*ne13, ne02, ne12, r2, r3, stride_batch_x, stride_batch_y, ne20*ne21); // NOLINT - if (dst->backend == GGML_BACKEND_CPU) { + if (dst->backend == GGML_BACKEND_TYPE_CPU) { // copy dst to host float * d = (float *) ((char *) dst->data); ggml_vk_buffer_read_async(ctx, subctx, d_D, 0, d, sizeof(float) * d_ne * ne12 * ne13); @@ -2506,8 +2506,8 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context src1_uma = d_Qy != nullptr; } - const bool load_x = src0->backend != GGML_BACKEND_GPU && !src0_uma; - const bool load_y = src1->backend != GGML_BACKEND_GPU && !src1_uma; + const bool load_x = src0->backend != GGML_BACKEND_TYPE_GPU && !src0_uma; + const bool load_y = src1->backend != GGML_BACKEND_TYPE_GPU && !src1_uma; const bool x_non_contig = !load_x && !ggml_vk_dim01_contiguous(src0); const bool y_non_contig = !load_y && !ggml_vk_dim01_contiguous(src1); @@ -2630,7 +2630,7 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context ggml_vk_sync_buffers(subctx); ggml_vk_dispatch_pipeline(ctx, subctx, *dmmv, { { d_X, x_offset, x_sz }, { d_Y, y_buffer_offset, y_sz + y_shader_offset }, { d_D, d_buffer_offset, d_sz + d_shader_offset } }, 3 * sizeof(int), &pc, { (uint32_t)ne01, 1, 1}); - if (dst->backend == GGML_BACKEND_CPU) { + if (dst->backend == GGML_BACKEND_TYPE_CPU) { // copy dst to host float * d = (float *) ((char *) dst->data + i12*nb2 + i13*nb3); ggml_vk_sync_buffers(subctx); @@ -2647,7 +2647,7 @@ static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_c std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", backend=" << dst->backend << ", ne0=" << dst->ne[0] << ", ne1=" << 
dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3] << "),)" << std::endl; #endif GGML_ASSERT(ggml_is_permuted(src0) && ggml_is_permuted(src1)); - GGML_ASSERT(src0->backend == GGML_BACKEND_GPU); + GGML_ASSERT(src0->backend == GGML_BACKEND_TYPE_GPU); GGML_ASSERT(src0->nb[0] <= src0->nb[1] && src0->nb[2] <= src0->nb[3]); // NOLINT GGML_ASSERT(src1->nb[0] <= src1->nb[1] && src1->nb[2] <= src1->nb[3]); // NOLINT GGML_ASSERT(src0->type == GGML_TYPE_F16); @@ -2679,7 +2679,7 @@ static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_c src1_uma = d_Qy != nullptr; } - const bool load_y = src1->backend != GGML_BACKEND_GPU && !src1_uma; + const bool load_y = src1->backend != GGML_BACKEND_TYPE_GPU && !src1_uma; const uint64_t x_ne = ne00 * ne01 * ne02; const uint64_t y_ne = ne10 * ne11 * ne12; @@ -2721,7 +2721,7 @@ static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_c ggml_vk_sync_buffers(subctx); ggml_vk_dispatch_pipeline(ctx, subctx, ctx->pipeline_mul_mat_vec_p021_f16_f32, { { d_Qx, qx_buf_offset, qx_sz }, { d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, { d_D, d_buffer_offset, d_sz + d_shader_offset } }, 6 * sizeof(uint32_t), &pc, { 1, (uint32_t)ne01, (uint32_t)ne12 }); - if (dst->backend == GGML_BACKEND_CPU) { + if (dst->backend == GGML_BACKEND_TYPE_CPU) { // copy dst to host float * d = (float *) dst->data; ggml_vk_sync_buffers(subctx); @@ -2738,7 +2738,7 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con GGML_ASSERT(!ggml_is_transposed(src0)); GGML_ASSERT(!ggml_is_transposed(src1)); GGML_ASSERT(!ggml_is_permuted(src0)); - GGML_ASSERT(src0->backend == GGML_BACKEND_GPU); + GGML_ASSERT(src0->backend == GGML_BACKEND_TYPE_GPU); GGML_ASSERT(src0->type == GGML_TYPE_F16); GGML_ASSERT(src1->type == GGML_TYPE_F32); @@ -2771,7 +2771,7 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con src1_uma = d_Qy != nullptr; } - const bool load_y = src1->backend != GGML_BACKEND_GPU && !src1_uma; + const bool load_y = src1->backend != GGML_BACKEND_TYPE_GPU && !src1_uma; const uint64_t d_ne = ne01 * ne11 * ne12; @@ -2814,7 +2814,7 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con ggml_vk_sync_buffers(subctx); ggml_vk_dispatch_pipeline(ctx, subctx, ctx->pipeline_mul_mat_vec_nc_f16_f32, { { d_Qx, qx_buf_offset, qx_sz }, { d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, { d_D, d_buffer_offset, d_sz + d_shader_offset } }, 7 * sizeof(uint32_t), &pc, { 1, (uint32_t)ne01, (uint32_t)ne12 }); - if (dst->backend == GGML_BACKEND_CPU) { + if (dst->backend == GGML_BACKEND_TYPE_CPU) { // copy dst to host float * d = (float *) dst->data; ggml_vk_sync_buffers(subctx); @@ -2832,7 +2832,7 @@ static bool ggml_vk_can_mul_mat(const ggml_tensor * src0, const ggml_tensor * sr return (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && (src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16 || ggml_is_quantized(src1->type)) && dst->type == GGML_TYPE_F32 && - ((ne0 >= 32 && ne1 >= 32 && ne10 >= 32) || src0->backend == GGML_BACKEND_GPU); + ((ne0 >= 32 && ne1 >= 32 && ne10 >= 32) || src0->backend == GGML_BACKEND_TYPE_GPU); } static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context * subctx, const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) { @@ -2880,8 +2880,8 @@ static void 
ggml_vk_op_repeat(ggml_backend_vk_context * ctx, vk_context * subctx // TODO: support for transposed / permuted tensors GGML_ASSERT(nb0 == sizeof(float)); GGML_ASSERT(nb00 == sizeof(float)); - GGML_ASSERT(src0->backend == GGML_BACKEND_GPU); - GGML_ASSERT(dst->backend == GGML_BACKEND_GPU); + GGML_ASSERT(src0->backend == GGML_BACKEND_TYPE_GPU); + GGML_ASSERT(dst->backend == GGML_BACKEND_TYPE_GPU); ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) dst->extra; ggml_tensor_extra_gpu * extra_src0 = (ggml_tensor_extra_gpu *) src0->extra; @@ -3110,8 +3110,8 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context * subctx, c } } - const bool transfer_src0 = src0->backend != GGML_BACKEND_GPU && !src0_uma; - const bool transfer_src1 = use_src1 && src1->backend != GGML_BACKEND_GPU && !src1_uma; + const bool transfer_src0 = src0->backend != GGML_BACKEND_TYPE_GPU && !src0_uma; + const bool transfer_src1 = use_src1 && src1->backend != GGML_BACKEND_TYPE_GPU && !src1_uma; uint64_t x_sz = ggml_vk_align_size(ggml_type_size(src0->type) * ne0, ctx->device.lock()->properties.limits.minStorageBufferOffsetAlignment); uint64_t y_sz = use_src1 ? ggml_vk_align_size(ggml_type_size(src1->type) * ne1, ctx->device.lock()->properties.limits.minStorageBufferOffsetAlignment) : 0; @@ -3120,7 +3120,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context * subctx, c vk_buffer d_D = extra->buffer_gpu.lock(); // Workaround for tiny tensor inputs on ROPE - if (use_src1 && src1->backend == GGML_BACKEND_GPU && y_sz > d_D->size) { + if (use_src1 && src1->backend == GGML_BACKEND_TYPE_GPU && y_sz > d_D->size) { y_sz = VK_WHOLE_SIZE; } @@ -3209,9 +3209,9 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context * subctx, c ggml_vk_sync_buffers(subctx); ggml_vk_dispatch_pipeline(ctx, subctx, *pipeline, { { d_X, x_buf_offset, x_sz }, { d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements); } - if (dst->backend == GGML_BACKEND_CPU && op == GGML_OP_CPY) { + if (dst->backend == GGML_BACKEND_TYPE_CPU && op == GGML_OP_CPY) { ggml_vk_d2h_tensor_2d(ctx, subctx, d_D, 0, dst); - } else if(dst->backend == GGML_BACKEND_CPU) { + } else if(dst->backend == GGML_BACKEND_TYPE_CPU) { // copy dst to host float * d = (float *) dst->data; ggml_vk_buffer_read_async(ctx, subctx, d_D, 0, d, d_sz); @@ -3253,7 +3253,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context * subctx, c ggml_vk_sync_buffers(subctx); ggml_vk_dispatch_pipeline(ctx, subctx, *pipeline, { { d_X, x_buf_offset + x_offset, x_sz }, { d_D, d_buf_offset + d_offset, d_sz } }, sizeof(PC), &pc, elements); } - if (dst->backend == GGML_BACKEND_CPU) { + if (dst->backend == GGML_BACKEND_TYPE_CPU) { // copy dst to host ggml_vk_buffer_read_async(ctx, subctx, d_D, d_buf_offset + d_offset, (char *) dst->data + i02*nb2 + i03*nb3, d_sz); } @@ -3359,7 +3359,7 @@ static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context * subctx, con static void ggml_vk_nop(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, ggml_tensor * dst) { // If backend is CPU, data from src0 has to be copied off the device - if (dst->backend == GGML_BACKEND_CPU) { + if (dst->backend == GGML_BACKEND_TYPE_CPU) { ggml_tensor_extra_gpu * extra_src0 = (ggml_tensor_extra_gpu *) src0->extra; vk_buffer d_D = extra_src0->buffer_gpu.lock(); ggml_vk_sync_buffers(subctx); @@ -3994,9 +3994,9 @@ static void ggml_vk_preallocate_buffers_graph(ggml_backend_vk_context * ctx, ggm #ifdef GGML_VULKAN_DEBUG std::cerr << 
"ggml_vk_preallocate_buffers_graph(" << node << ")" << std::endl; #endif - const bool any_on_device = node->backend == GGML_BACKEND_GPU - || (node->src[0] != nullptr && (node->src[0]->backend == GGML_BACKEND_GPU || node->src[0]->backend == GGML_BACKEND_GPU_SPLIT)) - || (node->src[1] != nullptr && (node->src[1]->backend == GGML_BACKEND_GPU)); + const bool any_on_device = node->backend == GGML_BACKEND_TYPE_GPU + || (node->src[0] != nullptr && (node->src[0]->backend == GGML_BACKEND_TYPE_GPU || node->src[0]->backend == GGML_BACKEND_TYPE_GPU_SPLIT)) + || (node->src[1] != nullptr && (node->src[1]->backend == GGML_BACKEND_TYPE_GPU)); if (ctx->disable || (!any_on_device && node->op != GGML_OP_MUL_MAT)) { return; @@ -4215,9 +4215,9 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) { } static void ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * node, bool last_node){ - const bool any_on_device = node->backend == GGML_BACKEND_GPU - || (node->src[0] != nullptr && (node->src[0]->backend == GGML_BACKEND_GPU || node->src[0]->backend == GGML_BACKEND_GPU_SPLIT)) - || (node->src[1] != nullptr && node->src[1]->backend == GGML_BACKEND_GPU); + const bool any_on_device = node->backend == GGML_BACKEND_TYPE_GPU + || (node->src[0] != nullptr && (node->src[0]->backend == GGML_BACKEND_TYPE_GPU || node->src[0]->backend == GGML_BACKEND_TYPE_GPU_SPLIT)) + || (node->src[1] != nullptr && node->src[1]->backend == GGML_BACKEND_TYPE_GPU); if (ctx->disable || (!any_on_device && node->op != GGML_OP_MUL_MAT) || (node->op == GGML_OP_MUL_MAT && !any_on_device && !ggml_vk_can_mul_mat(node->src[0], node->src[1], node))) { return; @@ -4371,7 +4371,7 @@ static void ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod last_node = true; #endif - if (node->backend == GGML_BACKEND_CPU || last_node) { + if (node->backend == GGML_BACKEND_TYPE_CPU || last_node) { ggml_vk_ctx_end(ctx->compute_ctx); ctx->compute_ctx->exit_tensor = node; ctx->compute_ctx = nullptr; @@ -4379,9 +4379,9 @@ static void ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod } static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_compute_params * params, ggml_tensor * tensor){ - const bool any_on_device = tensor->backend == GGML_BACKEND_GPU - || (tensor->src[0] != nullptr && (tensor->src[0]->backend == GGML_BACKEND_GPU || tensor->src[0]->backend == GGML_BACKEND_GPU_SPLIT)) - || (tensor->src[1] != nullptr && tensor->src[1]->backend == GGML_BACKEND_GPU); + const bool any_on_device = tensor->backend == GGML_BACKEND_TYPE_GPU + || (tensor->src[0] != nullptr && (tensor->src[0]->backend == GGML_BACKEND_TYPE_GPU || tensor->src[0]->backend == GGML_BACKEND_TYPE_GPU_SPLIT)) + || (tensor->src[1] != nullptr && tensor->src[1]->backend == GGML_BACKEND_TYPE_GPU); if (ctx->disable || (!any_on_device && tensor->op != GGML_OP_MUL_MAT)) { return false; @@ -4442,7 +4442,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_compute_ if (params->ith != 0) { return true; } - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { return true; } @@ -4745,7 +4745,7 @@ GGML_CALL static void ggml_backend_vk_buffer_init_tensor(ggml_backend_buffer_t b extra->offset = (uint8_t *) tensor->data - (uint8_t *) vk_ptr_base; } - tensor->backend = GGML_BACKEND_GPU; + tensor->backend = GGML_BACKEND_TYPE_GPU; tensor->extra = extra; } @@ -4753,7 +4753,7 @@ GGML_CALL static void 
ggml_backend_vk_buffer_set_tensor(ggml_backend_buffer_t bu #ifdef GGML_VULKAN_DEBUG std::cerr << "ggml_backend_vk_buffer_set_tensor(" << buffer << ", " << tensor << ", " << data << ", " << offset << ", " << size << ")" << std::endl; #endif - GGML_ASSERT(tensor->backend == GGML_BACKEND_GPU); + GGML_ASSERT(tensor->backend == GGML_BACKEND_TYPE_GPU); ggml_backend_vk_buffer_context * ctx = (ggml_backend_vk_buffer_context *)buffer->context; @@ -4768,7 +4768,7 @@ GGML_CALL static void ggml_backend_vk_buffer_get_tensor(ggml_backend_buffer_t bu #ifdef GGML_VULKAN_DEBUG std::cerr << "ggml_backend_vk_buffer_get_tensor(" << buffer << ", " << tensor << ", " << data << ", " << offset << ", " << size << ")" << std::endl; #endif - GGML_ASSERT(tensor->backend == GGML_BACKEND_GPU); + GGML_ASSERT(tensor->backend == GGML_BACKEND_TYPE_GPU); ggml_backend_vk_buffer_context * ctx = (ggml_backend_vk_buffer_context *)buffer->context; @@ -4999,7 +4999,7 @@ GGML_CALL static void ggml_backend_vk_set_tensor_async(ggml_backend_t backend, g #endif ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; GGML_ASSERT((tensor->buffer->buft == ggml_backend_vk_buffer_type(ctx->idx) || tensor->buffer->buft == ggml_backend_vk_host_buffer_type()) && "unsupported buffer type"); - GGML_ASSERT(tensor->backend == GGML_BACKEND_GPU); + GGML_ASSERT(tensor->backend == GGML_BACKEND_TYPE_GPU); ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) tensor->extra; @@ -5020,7 +5020,7 @@ GGML_CALL static void ggml_backend_vk_get_tensor_async(ggml_backend_t backend, c #endif ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; GGML_ASSERT((tensor->buffer->buft == ggml_backend_vk_buffer_type(ctx->idx) || tensor->buffer->buft == ggml_backend_vk_host_buffer_type()) && "unsupported buffer type"); - GGML_ASSERT(tensor->backend == GGML_BACKEND_GPU); + GGML_ASSERT(tensor->backend == GGML_BACKEND_TYPE_GPU); ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) tensor->extra; @@ -5097,7 +5097,7 @@ GGML_CALL static bool ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml int last_node = cgraph->n_nodes - 1; // If the last op in the cgraph isn't backend GPU, the command buffer doesn't get closed properly - while (last_node > 0 && cgraph->nodes[last_node]->backend != GGML_BACKEND_GPU) { + while (last_node > 0 && cgraph->nodes[last_node]->backend != GGML_BACKEND_TYPE_GPU) { last_node -= 1; } @@ -5106,7 +5106,7 @@ GGML_CALL static bool ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml } ggml_compute_params params = {}; - params.type = GGML_TASK_COMPUTE; + params.type = GGML_TASK_TYPE_COMPUTE; params.ith = 0; for (int i = 0; i < cgraph->n_nodes; i++) { ggml_tensor * node = cgraph->nodes[i]; @@ -5410,7 +5410,7 @@ static void ggml_vk_print_tensor_area(const ggml_tensor * tensor, const void * d static void ggml_vk_print_tensor(ggml_backend_vk_context * ctx, const ggml_tensor * tensor, const char * name) { void * tensor_data = tensor->data; - if (tensor->backend == GGML_BACKEND_GPU) { + if (tensor->backend == GGML_BACKEND_TYPE_GPU) { const size_t tensor_size = ggml_nbytes(tensor); tensor_data = malloc(tensor_size); @@ -5436,14 +5436,14 @@ static void ggml_vk_print_tensor(ggml_backend_vk_context * ctx, const ggml_tenso std::vector done; ggml_vk_print_graph_origin(tensor, done); - if (tensor->backend == GGML_BACKEND_GPU) { + if (tensor->backend == GGML_BACKEND_TYPE_GPU) { free(tensor_data); } } static void ggml_vk_check_tensor(const std::string& name, const ggml_tensor * tensor) { return; - 
GGML_ASSERT(tensor->backend == GGML_BACKEND_CPU); + GGML_ASSERT(tensor->backend == GGML_BACKEND_TYPE_CPU); if (tensor->type != GGML_TYPE_F32 && tensor->type != GGML_TYPE_F16) { return; } @@ -5481,7 +5481,7 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_compute_ if (params->ith != 0) { return; } - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE || tensor->op == GGML_OP_TRANSPOSE) { + if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE || tensor->op == GGML_OP_TRANSPOSE) { return; } @@ -5518,10 +5518,10 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_compute_ src0_buffer = malloc(src0_size); src0_clone->data = src0_buffer; - if (src0->backend == GGML_BACKEND_CPU) { + if (src0->backend == GGML_BACKEND_TYPE_CPU) { memcpy(src0_clone->data, src0->data, src0_size); memcpy(src0_clone->nb, src0->nb, sizeof(size_t) * GGML_MAX_DIMS); - } else if (src0->backend == GGML_BACKEND_GPU) { + } else if (src0->backend == GGML_BACKEND_TYPE_GPU) { ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) src0->extra; uint64_t offset = extra->offset; if (!ggml_is_contiguous(src0) && ggml_vk_dim01_contiguous(src0)) { @@ -5561,10 +5561,10 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_compute_ src1_buffer = malloc(src1_size); src1_clone->data = src1_buffer; - if (src1->backend == GGML_BACKEND_CPU) { + if (src1->backend == GGML_BACKEND_TYPE_CPU) { memcpy(src1_clone->data, src1->data, src1_size); memcpy(src1_clone->nb, src1->nb, sizeof(size_t) * GGML_MAX_DIMS); - } else if (src1->backend == GGML_BACKEND_GPU) { + } else if (src1->backend == GGML_BACKEND_TYPE_GPU) { ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) src1->extra; uint64_t offset = extra->offset; if (!ggml_is_contiguous(src1) && ggml_vk_dim01_contiguous(src1)) { @@ -5723,7 +5723,7 @@ static void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_compute_ if (params->ith != 0) { return; } - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE || tensor->op == GGML_OP_TRANSPOSE) { + if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE || tensor->op == GGML_OP_TRANSPOSE) { return; } if (!(vk_output_tensor > 0 && vk_output_tensor == check_counter) && check_counter <= vk_skip_checks) { @@ -5735,7 +5735,7 @@ static void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_compute_ void * tensor_data = tensor->data; - if (tensor->backend == GGML_BACKEND_GPU) { + if (tensor->backend == GGML_BACKEND_TYPE_GPU) { size_t tensor_size = ggml_nbytes(tensor); tensor_data = malloc(tensor_size); @@ -5868,7 +5868,7 @@ static void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_compute_ comp_result = nullptr; comp_size = 0; - if (tensor->backend == GGML_BACKEND_GPU) { + if (tensor->backend == GGML_BACKEND_TYPE_GPU) { free(tensor_data); } } diff --git a/ggml.c b/ggml.c index c09a3cad6..1d81553f4 100644 --- a/ggml.c +++ b/ggml.c @@ -2721,7 +2721,7 @@ static struct ggml_tensor * ggml_new_tensor_impl( } } - struct ggml_object * const obj_new = ggml_new_object(ctx, GGML_OBJECT_TENSOR, GGML_TENSOR_SIZE + obj_alloc_size); + struct ggml_object * const obj_new = ggml_new_object(ctx, GGML_OBJECT_TYPE_TENSOR, GGML_TENSOR_SIZE + obj_alloc_size); // TODO: for recoverable errors, we would need to free the data allocated from the scratch buffer here @@ -2729,7 +2729,7 @@ static struct ggml_tensor * ggml_new_tensor_impl( *result = (struct ggml_tensor) { /*.type =*/ 
type, - /*.backend =*/ GGML_BACKEND_CPU, + /*.backend =*/ GGML_BACKEND_TYPE_CPU, /*.buffer =*/ NULL, /*.ne =*/ { 1, 1, 1, 1 }, /*.nb =*/ { 0, 0, 0, 0 }, @@ -3302,7 +3302,7 @@ struct ggml_tensor * ggml_get_first_tensor(const struct ggml_context * ctx) { char * const mem_buffer = ctx->mem_buffer; while (obj != NULL) { - if (obj->type == GGML_OBJECT_TENSOR) { + if (obj->type == GGML_OBJECT_TYPE_TENSOR) { return (struct ggml_tensor *)(mem_buffer + obj->offs); } @@ -3319,7 +3319,7 @@ struct ggml_tensor * ggml_get_next_tensor(const struct ggml_context * ctx, struc char * const mem_buffer = ctx->mem_buffer; while (obj != NULL) { - if (obj->type == GGML_OBJECT_TENSOR) { + if (obj->type == GGML_OBJECT_TYPE_TENSOR) { return (struct ggml_tensor *)(mem_buffer + obj->offs); } @@ -3335,7 +3335,7 @@ struct ggml_tensor * ggml_get_tensor(struct ggml_context * ctx, const char * nam char * const mem_buffer = ctx->mem_buffer; while (obj != NULL) { - if (obj->type == GGML_OBJECT_TENSOR) { + if (obj->type == GGML_OBJECT_TYPE_TENSOR) { struct ggml_tensor * cur = (struct ggml_tensor *)(mem_buffer + obj->offs); if (strcmp(cur->name, name) == 0) { return cur; @@ -5879,7 +5879,7 @@ struct ggml_tensor * ggml_top_k( int k) { GGML_ASSERT(a->ne[0] >= k); - struct ggml_tensor * result = ggml_argsort(ctx, a, GGML_SORT_DESC); + struct ggml_tensor * result = ggml_argsort(ctx, a, GGML_SORT_ORDER_DESC); result = ggml_view_4d(ctx, result, k, result->ne[1], result->ne[2], result->ne[3], @@ -6673,7 +6673,7 @@ static void ggml_compute_forward_dup_same_cont( GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0)); GGML_ASSERT(src0->type == dst->type); - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { return; } @@ -6705,7 +6705,7 @@ static void ggml_compute_forward_dup_f16( GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0)); - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { return; } @@ -6978,7 +6978,7 @@ static void ggml_compute_forward_dup_f32( GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0)); - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { return; } @@ -7231,7 +7231,7 @@ static void ggml_compute_forward_dup_bytes( GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0)); GGML_ASSERT(src0->type == dst->type); - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { return; } @@ -7411,7 +7411,7 @@ static void ggml_compute_forward_add_f32( GGML_ASSERT(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst)); - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { return; } @@ -7419,7 +7419,7 @@ static void ggml_compute_forward_add_f32( const int nth = params->nth; #ifdef GGML_USE_CLBLAST - if (src1->backend == GGML_BACKEND_GPU) { + if (src1->backend == GGML_BACKEND_TYPE_GPU) { // TODO: OpenCL kernel support full broadcast GGML_ASSERT(ggml_can_repeat_rows(src1, src0)); if (ith == 0) { @@ -7501,7 +7501,7 @@ static void ggml_compute_forward_add_f16_f32( GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst)); - if 
(params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { return; } @@ -7580,7 +7580,7 @@ static void ggml_compute_forward_add_f16_f16( GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst)); - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { return; } @@ -7636,7 +7636,7 @@ static void ggml_compute_forward_add_q_f32( GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst)); - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { return; } @@ -7774,7 +7774,7 @@ static void ggml_compute_forward_add1_f32( GGML_ASSERT(ggml_are_same_shape(src0, dst)); GGML_ASSERT(ggml_is_scalar(src1)); - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { return; } @@ -7828,7 +7828,7 @@ static void ggml_compute_forward_add1_f16_f32( GGML_ASSERT(ggml_are_same_shape(src0, dst)); GGML_ASSERT(ggml_is_scalar(src1)); - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { return; } @@ -7880,7 +7880,7 @@ static void ggml_compute_forward_add1_f16_f16( GGML_ASSERT(ggml_are_same_shape(src0, dst)); GGML_ASSERT(ggml_is_scalar(src1)); - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { return; } @@ -7932,7 +7932,7 @@ static void ggml_compute_forward_add1_q_f32( GGML_ASSERT(ggml_are_same_shape(src0, dst)); GGML_ASSERT(ggml_is_scalar(src1)); - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { return; } @@ -8062,7 +8062,7 @@ static void ggml_compute_forward_acc_f32( size_t offset = ((int32_t *) dst->op_params)[3]; bool inplace = (bool) ((int32_t *) dst->op_params)[4]; - if (!inplace && (params->type == GGML_TASK_INIT)) { + if (!inplace && (params->type == GGML_TASK_TYPE_INIT)) { if (params->ith != 0) { return; } @@ -8074,7 +8074,7 @@ static void ggml_compute_forward_acc_f32( ggml_nbytes(dst)); } - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { return; } @@ -8176,7 +8176,7 @@ static void ggml_compute_forward_sub_f32( assert(params->ith == 0); assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst)); - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { return; } @@ -8257,14 +8257,14 @@ static void ggml_compute_forward_mul_f32( GGML_ASSERT(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst)); - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { return; } const int ith = params->ith; const int nth = params->nth; #if defined(GGML_USE_CLBLAST) - if (src1->backend == GGML_BACKEND_GPU) { + if (src1->backend == GGML_BACKEND_TYPE_GPU) { // TODO: OpenCL kernel 
support full broadcast GGML_ASSERT(ggml_can_repeat_rows(src1, src0)); if (ith == 0) { @@ -8365,7 +8365,7 @@ static void ggml_compute_forward_div_f32( GGML_ASSERT(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst)); - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { return; } @@ -8460,7 +8460,7 @@ static void ggml_compute_forward_sqr_f32( assert(params->ith == 0); assert(ggml_are_same_shape(src0, dst)); - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { return; } @@ -8506,7 +8506,7 @@ static void ggml_compute_forward_sqrt_f32( assert(params->ith == 0); assert(ggml_are_same_shape(src0, dst)); - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { return; } @@ -8552,7 +8552,7 @@ static void ggml_compute_forward_log_f32( GGML_ASSERT(params->ith == 0); GGML_ASSERT(ggml_are_same_shape(src0, dst)); - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { return; } @@ -8598,7 +8598,7 @@ static void ggml_compute_forward_sum_f32( assert(params->ith == 0); assert(ggml_is_scalar(dst)); - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { return; } @@ -8633,7 +8633,7 @@ static void ggml_compute_forward_sum_f16( assert(params->ith == 0); assert(ggml_is_scalar(dst)); - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { return; } @@ -8690,7 +8690,7 @@ static void ggml_compute_forward_sum_rows_f32( GGML_ASSERT(params->ith == 0); - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { return; } @@ -8745,7 +8745,7 @@ static void ggml_compute_forward_mean_f32( assert(params->ith == 0); - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { return; } @@ -8804,7 +8804,7 @@ static void ggml_compute_forward_argmax_f32( assert(params->ith == 0); - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { return; } @@ -8855,7 +8855,7 @@ static void ggml_compute_forward_repeat_f32( GGML_ASSERT(params->ith == 0); GGML_ASSERT(ggml_can_repeat(src0, dst)); - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { return; } @@ -8900,7 +8900,7 @@ static void ggml_compute_forward_repeat_f16( GGML_ASSERT(params->ith == 0); GGML_ASSERT(ggml_can_repeat(src0, dst)); - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { return; } @@ -8974,7 +8974,7 @@ static void ggml_compute_forward_repeat_back_f32( GGML_ASSERT(params->ith == 0); GGML_ASSERT(ggml_can_repeat(dst, src0)); - if (params->type == GGML_TASK_INIT || 
params->type == GGML_TASK_FINALIZE) { + if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { return; } @@ -9051,7 +9051,7 @@ static void ggml_compute_forward_concat_f32( const struct ggml_tensor * src0 = dst->src[0]; const struct ggml_tensor * src1 = dst->src[1]; - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { return; } @@ -9123,7 +9123,7 @@ static void ggml_compute_forward_abs_f32( assert(params->ith == 0); assert(ggml_are_same_shape(src0, dst)); - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { return; } @@ -9169,7 +9169,7 @@ static void ggml_compute_forward_sgn_f32( assert(params->ith == 0); assert(ggml_are_same_shape(src0, dst)); - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { return; } @@ -9215,7 +9215,7 @@ static void ggml_compute_forward_neg_f32( assert(params->ith == 0); assert(ggml_are_same_shape(src0, dst)); - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { return; } @@ -9261,7 +9261,7 @@ static void ggml_compute_forward_step_f32( assert(params->ith == 0); assert(ggml_are_same_shape(src0, dst)); - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { return; } @@ -9307,7 +9307,7 @@ static void ggml_compute_forward_tanh_f32( assert(params->ith == 0); assert(ggml_are_same_shape(src0, dst)); - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { return; } @@ -9353,7 +9353,7 @@ static void ggml_compute_forward_elu_f32( assert(params->ith == 0); assert(ggml_are_same_shape(src0, dst)); - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { return; } @@ -9399,7 +9399,7 @@ static void ggml_compute_forward_relu_f32( assert(params->ith == 0); assert(ggml_are_same_shape(src0, dst)); - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { return; } @@ -9446,7 +9446,7 @@ static void ggml_compute_forward_gelu_f32( GGML_ASSERT(ggml_is_contiguous_except_dim_1(dst)); GGML_ASSERT(ggml_are_same_shape(src0, dst)); - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { return; } @@ -9509,7 +9509,7 @@ static void ggml_compute_forward_gelu_quick_f32( GGML_ASSERT(ggml_is_contiguous_except_dim_1(dst)); GGML_ASSERT(ggml_are_same_shape(src0, dst)); - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { return; } @@ -9572,7 +9572,7 @@ static void ggml_compute_forward_silu_f32( GGML_ASSERT(ggml_is_contiguous_except_dim_1(dst)); GGML_ASSERT(ggml_are_same_shape(src0, dst)); - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + if 
(params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { return; } @@ -9633,7 +9633,7 @@ static void ggml_compute_forward_leaky_relu_f32( assert(params->ith == 0); assert(ggml_are_same_shape(src0, dst)); - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { return; } @@ -9686,7 +9686,7 @@ static void ggml_compute_forward_silu_back_f32( GGML_ASSERT(ggml_are_same_shape(src0, dst)); GGML_ASSERT(ggml_are_same_shape(src0, grad)); - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { return; } @@ -9748,7 +9748,7 @@ static void ggml_compute_forward_hardswish_f32( assert(params->ith == 0); assert(ggml_are_same_shape(src0, dst)); - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { return; } @@ -9791,7 +9791,7 @@ static void ggml_compute_forward_hardsigmoid_f32( assert(params->ith == 0); assert(ggml_are_same_shape(src0, dst)); - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { return; } @@ -9837,7 +9837,7 @@ static void ggml_compute_forward_norm_f32( GGML_ASSERT(ggml_are_same_shape(src0, dst)); - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { return; } @@ -9912,7 +9912,7 @@ static void ggml_compute_forward_rms_norm_f32( GGML_ASSERT(ggml_are_same_shape(src0, dst)); - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { return; } @@ -9983,7 +9983,7 @@ static void ggml_compute_forward_rms_norm_back_f32( GGML_ASSERT(ggml_are_same_shape(src0, dst) && ggml_are_same_shape(src0, src1)); - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { return; } @@ -10161,7 +10161,7 @@ static void ggml_compute_forward_group_norm_f32( GGML_ASSERT(ggml_are_same_shape(src0, dst)); - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { return; } @@ -10328,7 +10328,7 @@ static void ggml_compute_forward_mul_mat( #if defined(GGML_USE_CLBLAST) if (ggml_cl_can_mul_mat(src0, src1, dst)) { - if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) { + if (params->ith == 0 && params->type == GGML_TASK_TYPE_COMPUTE) { ggml_cl_mul_mat(src0, src1, dst, params->wdata, params->wsize); } return; @@ -10341,7 +10341,7 @@ static void ggml_compute_forward_mul_mat( const size_t desired_wsize = ne13*ne12*ne_plane*sizeof(float); UNUSED(desired_wsize); - if (params->type == GGML_TASK_INIT) { + if (params->type == GGML_TASK_TYPE_INIT) { if (type != GGML_TYPE_F32) { assert(params->wsize >= desired_wsize); // parallelize by src0 rows @@ -10364,7 +10364,7 @@ static void ggml_compute_forward_mul_mat( return; } - if (params->type == GGML_TASK_FINALIZE) { + if (params->type == GGML_TASK_TYPE_FINALIZE) { return; } @@ -10402,7 +10402,7 @@ static void ggml_compute_forward_mul_mat( } #endif - if (params->type == GGML_TASK_INIT) { + if 
(params->type == GGML_TASK_TYPE_INIT) { if (ith != 0) { return; } @@ -10426,7 +10426,7 @@ static void ggml_compute_forward_mul_mat( return; } - if (params->type == GGML_TASK_FINALIZE) { + if (params->type == GGML_TASK_TYPE_FINALIZE) { return; } @@ -10583,7 +10583,7 @@ static void ggml_compute_forward_mul_mat_id( #define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id)*ne11 + (i1)] - if (params->type == GGML_TASK_INIT) { + if (params->type == GGML_TASK_TYPE_INIT) { if (ith != 0) { return; } @@ -10620,7 +10620,7 @@ static void ggml_compute_forward_mul_mat_id( return; } - if (params->type == GGML_TASK_FINALIZE) { + if (params->type == GGML_TASK_TYPE_FINALIZE) { return; } @@ -10768,7 +10768,7 @@ static void ggml_compute_forward_out_prod_f32( (ggml_is_contiguous(src1) || ggml_is_transposed(src1)); #endif - if (params->type == GGML_TASK_INIT) { + if (params->type == GGML_TASK_TYPE_INIT) { #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) // gemm beta will zero dst if (use_blas) { return; @@ -10781,7 +10781,7 @@ static void ggml_compute_forward_out_prod_f32( return; } - if (params->type == GGML_TASK_FINALIZE) { + if (params->type == GGML_TASK_TYPE_FINALIZE) { return; } @@ -10961,7 +10961,7 @@ static void ggml_compute_forward_out_prod_q_f32( // TODO: #if defined(GGML_USE_CUBLAS) ggml_cuda_out_prod // TODO: #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST) - if (params->type == GGML_TASK_INIT) { + if (params->type == GGML_TASK_TYPE_INIT) { if (ith != 0) { return; } @@ -10969,7 +10969,7 @@ static void ggml_compute_forward_out_prod_q_f32( return; } - if (params->type == GGML_TASK_FINALIZE) { + if (params->type == GGML_TASK_TYPE_FINALIZE) { return; } @@ -11087,7 +11087,7 @@ static void ggml_compute_forward_scale_f32( GGML_ASSERT(ggml_is_contiguous(dst)); GGML_ASSERT(ggml_are_same_shape(src0, dst)); - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { return; } @@ -11159,7 +11159,7 @@ static void ggml_compute_forward_set_f32( size_t offset = ((int32_t *) dst->op_params)[3]; bool inplace = (bool) ((int32_t *) dst->op_params)[4]; - if (!inplace && (params->type == GGML_TASK_INIT)) { + if (!inplace && (params->type == GGML_TASK_TYPE_INIT)) { if (params->ith != 0) { return; } @@ -11171,7 +11171,7 @@ static void ggml_compute_forward_set_f32( ggml_nbytes(dst)); } - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { return; } @@ -11319,7 +11319,7 @@ static void ggml_compute_forward_get_rows_q( assert(params->ith == 0); - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { return; } @@ -11359,7 +11359,7 @@ static void ggml_compute_forward_get_rows_f16( assert(params->ith == 0); - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { return; } @@ -11396,7 +11396,7 @@ static void ggml_compute_forward_get_rows_f32( assert(params->ith == 0); - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { return; } @@ -11499,14 +11499,14 @@ static void ggml_compute_forward_get_rows_back_f32_f16( // 
ggml_compute_forward_dup_same_cont(params, opt0, dst); - if (params->type == GGML_TASK_INIT) { + if (params->type == GGML_TASK_TYPE_INIT) { if (params->ith != 0) { return; } memset(dst->data, 0, ggml_nbytes(dst)); } - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { return; } @@ -11538,14 +11538,14 @@ static void ggml_compute_forward_get_rows_back_f32( // ggml_compute_forward_dup_same_cont(params, opt0, dst); - if (params->type == GGML_TASK_INIT) { + if (params->type == GGML_TASK_TYPE_INIT) { if (params->ith != 0) { return; } memset(dst->data, 0, ggml_nbytes(dst)); } - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { return; } @@ -11615,7 +11615,7 @@ static void ggml_compute_forward_diag_f32( GGML_ASSERT(params->ith == 0); - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { return; } @@ -11684,7 +11684,7 @@ static void ggml_compute_forward_diag_mask_f32( GGML_ASSERT(n_past >= 0); - if (!inplace && (params->type == GGML_TASK_INIT)) { + if (!inplace && (params->type == GGML_TASK_TYPE_INIT)) { if (ith != 0) { return; } @@ -11698,7 +11698,7 @@ static void ggml_compute_forward_diag_mask_f32( ggml_nbytes(dst)); } - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { return; } @@ -11772,7 +11772,7 @@ static void ggml_compute_forward_soft_max_f32( assert(ggml_is_contiguous(dst)); assert(ggml_are_same_shape(src0, dst)); - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { return; } @@ -11910,7 +11910,7 @@ static void ggml_compute_forward_soft_max_back_f32( GGML_ASSERT(ggml_are_same_shape(src0, dst)); GGML_ASSERT(ggml_are_same_shape(src1, dst)); - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { return; } @@ -12004,7 +12004,7 @@ static void ggml_compute_forward_alibi_f32( assert(params->ith == 0); - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { return; } @@ -12063,7 +12063,7 @@ static void ggml_compute_forward_alibi_f16( assert(params->ith == 0); - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { return; } @@ -12170,7 +12170,7 @@ static void ggml_compute_forward_clamp_f32( assert(params->ith == 0); - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { return; } @@ -12310,7 +12310,7 @@ static void ggml_compute_forward_rope_f32( const struct ggml_tensor * src0 = dst->src[0]; const struct ggml_tensor * src1 = dst->src[1]; - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { return; } @@ -12488,7 +12488,7 @@ static void ggml_compute_forward_rope_f16( const struct ggml_tensor * src0 = 
dst->src[0]; const struct ggml_tensor * src1 = dst->src[1]; - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { return; } @@ -12719,7 +12719,7 @@ static void ggml_compute_forward_conv_transpose_1d_f16_f32( GGML_ASSERT(nb00 == sizeof(ggml_fp16_t)); GGML_ASSERT(nb10 == sizeof(float)); - if (params->type == GGML_TASK_INIT) { + if (params->type == GGML_TASK_TYPE_INIT) { if (ith != 0) { return; } @@ -12759,7 +12759,7 @@ static void ggml_compute_forward_conv_transpose_1d_f16_f32( return; } - if (params->type == GGML_TASK_FINALIZE) { + if (params->type == GGML_TASK_TYPE_FINALIZE) { return; } @@ -12818,7 +12818,7 @@ static void ggml_compute_forward_conv_transpose_1d_f32( GGML_ASSERT(nb00 == sizeof(float)); GGML_ASSERT(nb10 == sizeof(float)); - if (params->type == GGML_TASK_INIT) { + if (params->type == GGML_TASK_TYPE_INIT) { if (ith != 0) { return; } @@ -12858,7 +12858,7 @@ static void ggml_compute_forward_conv_transpose_1d_f32( return; } - if (params->type == GGML_TASK_FINALIZE) { + if (params->type == GGML_TASK_TYPE_FINALIZE) { return; } @@ -12962,11 +12962,11 @@ static void ggml_compute_forward_im2col_f32( GGML_ASSERT(nb00 == sizeof(ggml_fp16_t)); GGML_ASSERT(nb10 == sizeof(float)); - if (params->type == GGML_TASK_INIT) { + if (params->type == GGML_TASK_TYPE_INIT) { return; } - if (params->type == GGML_TASK_FINALIZE) { + if (params->type == GGML_TASK_TYPE_FINALIZE) { return; } @@ -13050,11 +13050,11 @@ static void ggml_compute_forward_im2col_f16( GGML_ASSERT(nb00 == sizeof(ggml_fp16_t)); GGML_ASSERT(nb10 == sizeof(float)); - if (params->type == GGML_TASK_INIT) { + if (params->type == GGML_TASK_TYPE_INIT) { return; } - if (params->type == GGML_TASK_FINALIZE) { + if (params->type == GGML_TASK_TYPE_FINALIZE) { return; } @@ -13136,7 +13136,7 @@ static void ggml_compute_forward_conv_transpose_2d( GGML_ASSERT(nb00 == sizeof(ggml_fp16_t)); GGML_ASSERT(nb10 == sizeof(float)); - if (params->type == GGML_TASK_INIT) { + if (params->type == GGML_TASK_TYPE_INIT) { if (ith != 0) { return; } @@ -13178,7 +13178,7 @@ static void ggml_compute_forward_conv_transpose_2d( return; } - if (params->type == GGML_TASK_FINALIZE) { + if (params->type == GGML_TASK_TYPE_FINALIZE) { return; } @@ -13230,7 +13230,7 @@ static void ggml_compute_forward_pool_1d_sk_p0( assert(src->type == GGML_TYPE_F32); assert(params->ith == 0); - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { return; } @@ -13299,7 +13299,7 @@ static void ggml_compute_forward_pool_2d( GGML_ASSERT(src->type == GGML_TYPE_F32); GGML_ASSERT(params->ith == 0); - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { return; } @@ -13372,7 +13372,7 @@ static void ggml_compute_forward_upscale_f32( const struct ggml_tensor * src0 = dst->src[0]; - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { return; } @@ -13432,7 +13432,7 @@ static void ggml_compute_forward_pad_f32( const struct ggml_tensor * src0 = dst->src[0]; - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { return; } @@ -13493,7 +13493,7 @@ static void 
ggml_compute_forward_argsort_f32( const struct ggml_tensor * src0 = dst->src[0]; - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { return; } @@ -13519,8 +13519,8 @@ static void ggml_compute_forward_argsort_f32( // C doesn't have a functional sort, so we do a bubble sort instead for (int64_t j = 0; j < ne0; j++) { for (int64_t k = j + 1; k < ne0; k++) { - if ((order == GGML_SORT_ASC && src_data[dst_data[j]] > src_data[dst_data[k]]) || - (order == GGML_SORT_DESC && src_data[dst_data[j]] < src_data[dst_data[k]])) { + if ((order == GGML_SORT_ORDER_ASC && src_data[dst_data[j]] > src_data[dst_data[k]]) || + (order == GGML_SORT_ORDER_DESC && src_data[dst_data[j]] < src_data[dst_data[k]])) { int32_t tmp = dst_data[j]; dst_data[j] = dst_data[k]; dst_data[k] = tmp; @@ -13603,11 +13603,11 @@ static void ggml_compute_forward_flash_attn_f32( GGML_ASSERT(nb1 <= nb2); GGML_ASSERT(nb2 <= nb3); - if (params->type == GGML_TASK_INIT) { + if (params->type == GGML_TASK_TYPE_INIT) { return; } - if (params->type == GGML_TASK_FINALIZE) { + if (params->type == GGML_TASK_TYPE_FINALIZE) { return; } @@ -13795,11 +13795,11 @@ static void ggml_compute_forward_flash_attn_f16( GGML_ASSERT(nb1 <= nb2); GGML_ASSERT(nb2 <= nb3); - if (params->type == GGML_TASK_INIT) { + if (params->type == GGML_TASK_TYPE_INIT) { return; } - if (params->type == GGML_TASK_FINALIZE) { + if (params->type == GGML_TASK_TYPE_FINALIZE) { return; } @@ -14054,11 +14054,11 @@ static void ggml_compute_forward_flash_ff_f16( GGML_ASSERT(nb1 <= nb2); GGML_ASSERT(nb2 <= nb3); - if (params->type == GGML_TASK_INIT) { + if (params->type == GGML_TASK_TYPE_INIT) { return; } - if (params->type == GGML_TASK_FINALIZE) { + if (params->type == GGML_TASK_TYPE_FINALIZE) { return; } @@ -14213,14 +14213,14 @@ static void ggml_compute_forward_flash_attn_back_f32( GGML_ASSERT(nb1 <= nb2); GGML_ASSERT(nb2 <= nb3); - if (params->type == GGML_TASK_INIT) { + if (params->type == GGML_TASK_TYPE_INIT) { if (ith == 0) { memset(dst->data, 0, nb0*ne0*ne1*ne2*ne3); } return; } - if (params->type == GGML_TASK_FINALIZE) { + if (params->type == GGML_TASK_TYPE_FINALIZE) { return; } @@ -14536,7 +14536,7 @@ static void ggml_compute_forward_win_part_f32( const struct ggml_tensor * src0 = dst->src[0]; - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { return; } @@ -14602,7 +14602,7 @@ static void ggml_compute_forward_win_unpart_f32( const struct ggml_tensor * src0 = dst->src[0]; - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { return; } @@ -14730,7 +14730,7 @@ static void ggml_compute_forward_get_rel_pos_f16( const struct ggml_tensor * src0 = dst->src[0]; - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { return; } @@ -14782,14 +14782,14 @@ static void ggml_compute_forward_add_rel_pos_f32( const struct ggml_tensor * src2 = dst->src[2]; const bool inplace = (bool) ((int32_t *) dst->op_params)[0]; - if (!inplace && params->type == GGML_TASK_INIT) { + if (!inplace && params->type == GGML_TASK_TYPE_INIT) { if (params->ith != 0) { return; } memcpy((char *) dst->data, (char *) src0->data, ggml_nbytes(dst)); return; } - if (params->type == 
GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { return; } @@ -14871,7 +14871,7 @@ static void ggml_compute_forward_map_unary_f32( GGML_ASSERT(ggml_are_same_shape(src0, dst)); - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { return; } @@ -14920,7 +14920,7 @@ static void ggml_compute_forward_map_binary_f32( assert(params->ith == 0); assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst)); - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { return; } @@ -14969,7 +14969,7 @@ static void ggml_compute_forward_map_custom1_f32( assert(params->ith == 0); - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { return; } @@ -14988,7 +14988,7 @@ static void ggml_compute_forward_map_custom2_f32( assert(params->ith == 0); - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { return; } @@ -15008,7 +15008,7 @@ static void ggml_compute_forward_map_custom3_f32( assert(params->ith == 0); - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { return; } @@ -15023,7 +15023,7 @@ static void ggml_compute_forward_map_custom1( const struct ggml_tensor * a = dst->src[0]; - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { return; } @@ -15041,7 +15041,7 @@ static void ggml_compute_forward_map_custom2( const struct ggml_tensor * a = dst->src[0]; const struct ggml_tensor * b = dst->src[1]; - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { return; } @@ -15060,7 +15060,7 @@ static void ggml_compute_forward_map_custom3( const struct ggml_tensor * b = dst->src[1]; const struct ggml_tensor * c = dst->src[2]; - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { return; } @@ -15094,14 +15094,14 @@ static void ggml_compute_forward_cross_entropy_loss_f32( GGML_ASSERT(params->wsize >= sizeof(float) * (nth + nth * nc)); - if (params->type == GGML_TASK_INIT) { + if (params->type == GGML_TASK_TYPE_INIT) { if (ith == 0) { memset(sums, 0, sizeof(float) * (nth + nth * nc)); } return; } - if (params->type == GGML_TASK_FINALIZE) { + if (params->type == GGML_TASK_TYPE_FINALIZE) { if (ith == 0) { float * dp = (float *) dst->data; ggml_vec_sum_f32(nth, dp, sums); @@ -15216,7 +15216,7 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32( const int64_t ith = params->ith; const int64_t nth = params->nth; - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { return; } @@ -15323,8 +15323,8 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm if (skip_cpu) { return; } - 
GGML_ASSERT(tensor->src[0] == NULL || tensor->src[0]->backend == GGML_BACKEND_CPU); - GGML_ASSERT(tensor->src[1] == NULL || tensor->src[1]->backend == GGML_BACKEND_CPU); + GGML_ASSERT(tensor->src[0] == NULL || tensor->src[0]->backend == GGML_BACKEND_TYPE_CPU); + GGML_ASSERT(tensor->src[1] == NULL || tensor->src[1]->backend == GGML_BACKEND_TYPE_CPU); #elif defined(GGML_USE_VULKAN) const bool skip_cpu = ggml_vk_compute_forward_cpu_assist(params, tensor); #ifdef GGML_VULKAN_CHECK_RESULTS @@ -15335,8 +15335,8 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm if (skip_cpu) { return; } - GGML_ASSERT(tensor->src[0] == NULL || tensor->src[0]->backend == GGML_BACKEND_CPU); - GGML_ASSERT(tensor->src[1] == NULL || tensor->src[1]->backend == GGML_BACKEND_CPU); + GGML_ASSERT(tensor->src[0] == NULL || tensor->src[0]->backend == GGML_BACKEND_TYPE_CPU); + GGML_ASSERT(tensor->src[1] == NULL || tensor->src[1]->backend == GGML_BACKEND_TYPE_CPU); #endif // GGML_USE_CUBLAS #ifdef GGML_USE_SYCL @@ -16882,7 +16882,7 @@ size_t ggml_graph_overhead(void) { struct ggml_cgraph * ggml_new_graph_custom(struct ggml_context * ctx, size_t size, bool grads) { const size_t obj_size = ggml_graph_nbytes(size, grads); - struct ggml_object * obj = ggml_new_object(ctx, GGML_OBJECT_GRAPH, obj_size); + struct ggml_object * obj = ggml_new_object(ctx, GGML_OBJECT_TYPE_GRAPH, obj_size); struct ggml_cgraph * cgraph = (struct ggml_cgraph *) ((char *) ctx->mem_buffer + obj->offs); struct ggml_tensor ** data_start = (struct ggml_tensor **) (cgraph + 1); @@ -17429,7 +17429,7 @@ static thread_ret_t ggml_graph_compute_thread(void * data) { set_numa_thread_affinity(state->ith); int node_n = -1; - int task_phase = GGML_TASK_FINALIZE; + int task_phase = GGML_TASK_TYPE_FINALIZE; while (true) { if (cplan->abort_callback && cplan->abort_callback(cplan->abort_callback_data)) { @@ -17441,7 +17441,7 @@ static thread_ret_t ggml_graph_compute_thread(void * data) { // all other threads are finished and spinning // do finalize and init here so we don't have synchronize again struct ggml_compute_params params = { - /*.type =*/ GGML_TASK_FINALIZE, + /*.type =*/ GGML_TASK_TYPE_FINALIZE, /*.ith =*/ 0, /*.nth =*/ 0, /*.wsize =*/ cplan->work_size, @@ -17472,17 +17472,17 @@ static thread_ret_t ggml_graph_compute_thread(void * data) { if (n_tasks == 1) { /* INIT */ if (GGML_OP_HAS_INIT[node->op]) { - params.type = GGML_TASK_INIT; + params.type = GGML_TASK_TYPE_INIT; ggml_compute_forward(&params, node); } // TODO: maybe push node_n to the atomic but if other threads see n_tasks is 1, // they do something more efficient than spinning (?)
- params.type = GGML_TASK_COMPUTE; + params.type = GGML_TASK_TYPE_COMPUTE; ggml_compute_forward(&params, node); if (GGML_OP_HAS_FINALIZE[node->op]) { - params.type = GGML_TASK_FINALIZE; + params.type = GGML_TASK_TYPE_FINALIZE; ggml_compute_forward(&params, node); } @@ -17496,7 +17496,7 @@ static thread_ret_t ggml_graph_compute_thread(void * data) { } } - task_phase = GGML_TASK_INIT; + task_phase = GGML_TASK_TYPE_INIT; atomic_store(&state->shared->n_active, n_threads); atomic_store(&state->shared->node_n, node_n); atomic_store(&state->shared->node_task, task_phase); @@ -17513,7 +17513,7 @@ static thread_ret_t ggml_graph_compute_thread(void * data) { const int n_tasks = ggml_get_n_tasks(node, n_threads); struct ggml_compute_params params = { - /*.type =*/ GGML_TASK_INIT, + /*.type =*/ GGML_TASK_TYPE_INIT, /*.ith =*/ state->ith, /*.nth =*/ n_tasks, /*.wsize =*/ cplan->work_size, @@ -17527,7 +17527,7 @@ static thread_ret_t ggml_graph_compute_thread(void * data) { } if (atomic_fetch_sub(&state->shared->n_active, 1) == 1) { - task_phase = GGML_TASK_COMPUTE; + task_phase = GGML_TASK_TYPE_COMPUTE; atomic_store(&state->shared->n_active, n_threads); atomic_store(&state->shared->node_task, task_phase); } @@ -17542,12 +17542,12 @@ static thread_ret_t ggml_graph_compute_thread(void * data) { } if (state->ith < n_tasks) { - params.type = GGML_TASK_COMPUTE; + params.type = GGML_TASK_TYPE_COMPUTE; ggml_compute_forward(&params, node); } if (atomic_fetch_sub(&state->shared->n_active, 1) == 1) { - task_phase = GGML_TASK_FINALIZE; + task_phase = GGML_TASK_TYPE_FINALIZE; atomic_store(&state->shared->n_active, n_threads); atomic_store(&state->shared->node_task, task_phase); } @@ -17783,7 +17783,7 @@ int ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan) { /*.n_threads =*/ n_threads, /*.n_active =*/ n_threads, /*.node_n =*/ -1, - /*.node_task =*/ GGML_TASK_FINALIZE, + /*.node_task =*/ GGML_TASK_TYPE_FINALIZE, /*.abort_callback =*/ NULL, /*.abort_callback_data =*/ NULL, }; @@ -17851,7 +17851,7 @@ int ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan) { void ggml_graph_compute_with_ctx(struct ggml_context * ctx, struct ggml_cgraph * cgraph, int n_threads) { struct ggml_cplan cplan = ggml_graph_plan(cgraph, n_threads); - struct ggml_object * obj = ggml_new_object(ctx, GGML_OBJECT_WORK_BUFFER, cplan.work_size); + struct ggml_object * obj = ggml_new_object(ctx, GGML_OBJECT_TYPE_WORK_BUFFER, cplan.work_size); cplan.work_data = (uint8_t *)ctx->mem_buffer + obj->offs; @@ -18659,7 +18659,7 @@ static enum ggml_opt_result ggml_opt_adam( float * pf = params.past > 0 ?
opt->adam.pf->data : NULL; // past function values struct ggml_cplan cplan = ggml_graph_plan(gb, params.n_threads); - struct ggml_object * obj = ggml_new_object(ctx, GGML_OBJECT_WORK_BUFFER, cplan.work_size); + struct ggml_object * obj = ggml_new_object(ctx, GGML_OBJECT_TYPE_WORK_BUFFER, cplan.work_size); cplan.work_data = (uint8_t *)ctx->mem_buffer + obj->offs; bool cancel = false; @@ -18671,7 +18671,7 @@ static enum ggml_opt_result ggml_opt_adam( if (callback) { callback(callback_data, accum_step, &sched, &cancel); if (cancel) { - return GGML_OPT_CANCEL; + return GGML_OPT_RESULT_CANCEL; } } // ggml_graph_reset (gf); @@ -18762,7 +18762,7 @@ static enum ggml_opt_result ggml_opt_adam( if (callback) { callback(callback_data, accum_step, &sched, &cancel); if (cancel) { - return GGML_OPT_CANCEL;; + return GGML_OPT_RESULT_CANCEL;; } } // ggml_graph_reset (gf); @@ -18779,7 +18779,7 @@ static enum ggml_opt_result ggml_opt_adam( if (fabsf(fx - fx_prev[0])/fx < params.adam.eps_f) { GGML_PRINT_DEBUG("converged\n"); - return GGML_OPT_OK; + return GGML_OPT_RESULT_OK; } // delta-based convergence test @@ -18789,7 +18789,7 @@ static enum ggml_opt_result ggml_opt_adam( const float rate = (pf[(iter0 + t)%params.past] - fx)/fx; if (fabsf(rate) < params.delta) { - return GGML_OPT_OK; + return GGML_OPT_RESULT_OK; } } @@ -18805,7 +18805,7 @@ static enum ggml_opt_result ggml_opt_adam( ++n_no_improvement[0]; if (n_no_improvement[0] >= params.max_no_improvement) { - return GGML_OPT_OK; + return GGML_OPT_RESULT_OK; } } } @@ -18823,7 +18823,7 @@ static enum ggml_opt_result ggml_opt_adam( } } - return GGML_OPT_DID_NOT_CONVERGE; + return GGML_OPT_RESULT_DID_NOT_CONVERGE; } // @@ -18904,7 +18904,7 @@ static enum ggml_opt_result linesearch_backtracking( float sched = 0; callback(callback_data, accum_step, &sched, cancel); if (*cancel) { - return GGML_OPT_CANCEL; + return GGML_OPT_RESULT_CANCEL; } } // ggml_graph_reset (gf); @@ -18977,7 +18977,7 @@ static enum ggml_opt_result ggml_opt_lbfgs( if (params.lbfgs.linesearch == GGML_LINESEARCH_BACKTRACKING_WOLFE || params.lbfgs.linesearch == GGML_LINESEARCH_BACKTRACKING_STRONG_WOLFE) { if (params.lbfgs.wolfe <= params.lbfgs.ftol || 1.f <= params.lbfgs.wolfe) { - return GGML_OPT_INVALID_WOLFE; + return GGML_OPT_RESULT_INVALID_WOLFE; } } @@ -19006,7 +19006,7 @@ static enum ggml_opt_result ggml_opt_lbfgs( } struct ggml_cplan cplan = ggml_graph_plan(gb, params.n_threads); - struct ggml_object * obj = ggml_new_object(ctx, GGML_OBJECT_WORK_BUFFER, cplan.work_size); + struct ggml_object * obj = ggml_new_object(ctx, GGML_OBJECT_TYPE_WORK_BUFFER, cplan.work_size); cplan.work_data = (uint8_t *)ctx->mem_buffer + obj->offs; float * x = opt->lbfgs.x->data; // current parameters @@ -19047,7 +19047,7 @@ static enum ggml_opt_result ggml_opt_lbfgs( float sched = 0; callback(callback_data, accum_step, &sched, &cancel); if (cancel) { - return GGML_OPT_CANCEL; + return GGML_OPT_RESULT_CANCEL; } } // ggml_graph_reset (gf); @@ -19075,7 +19075,7 @@ static enum ggml_opt_result ggml_opt_lbfgs( // already optimized if (gnorm/xnorm <= params.lbfgs.eps) { - return GGML_OPT_OK; + return GGML_OPT_RESULT_OK; } if (opt->just_initialized) { @@ -19120,7 +19120,7 @@ static enum ggml_opt_result ggml_opt_lbfgs( // way to test and don't want to break something with so many changes lined up ls = linesearch_backtracking(&params, nx, x, &fx, g, d, step, xp, f, gb, &cplan, np, ps, &cancel, callback, callback_data); if (cancel) { - return GGML_OPT_CANCEL; + return GGML_OPT_RESULT_CANCEL; } if (ls < 0) { @@ -19143,7
+19143,7 @@ static enum ggml_opt_result ggml_opt_lbfgs( } if (gnorm/xnorm <= params.lbfgs.eps) { // converged - return GGML_OPT_OK; + return GGML_OPT_RESULT_OK; } // delta-based convergence test @@ -19153,7 +19153,7 @@ static enum ggml_opt_result ggml_opt_lbfgs( const float rate = (pf[k[0]%params.past] - fx)/fx; if (fabsf(rate) < params.delta) { - return GGML_OPT_OK; + return GGML_OPT_RESULT_OK; } } @@ -19169,14 +19169,14 @@ static enum ggml_opt_result ggml_opt_lbfgs( n_no_improvement[0]++; if (n_no_improvement[0] >= params.max_no_improvement) { - return GGML_OPT_OK; + return GGML_OPT_RESULT_OK; } } } if (params.lbfgs.n_iter != 0 && params.lbfgs.n_iter < it + 1) { // reached the maximum number of iterations - return GGML_OPT_DID_NOT_CONVERGE; + return GGML_OPT_RESULT_DID_NOT_CONVERGE; } // update vectors s and y: @@ -19232,17 +19232,17 @@ static enum ggml_opt_result ggml_opt_lbfgs( GGML_ASSERT(false && "lbfgs failed"); - return GGML_OPT_DID_NOT_CONVERGE; + return GGML_OPT_RESULT_DID_NOT_CONVERGE; } struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type) { struct ggml_opt_params result; switch (type) { - case GGML_OPT_ADAM: + case GGML_OPT_TYPE_ADAM: { result = (struct ggml_opt_params) { - .type = GGML_OPT_ADAM, + .type = GGML_OPT_TYPE_ADAM, .graph_size = GGML_DEFAULT_GRAPH_SIZE, .n_threads = 1, // FIXME: GGML_DEFAULT_N_THREADS ? .past = 0, @@ -19270,10 +19270,10 @@ struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type) { }, }; } break; - case GGML_OPT_LBFGS: + case GGML_OPT_TYPE_LBFGS: { result = (struct ggml_opt_params) { - .type = GGML_OPT_LBFGS, + .type = GGML_OPT_TYPE_LBFGS, .graph_size = GGML_DEFAULT_GRAPH_SIZE, .n_threads = 1, .past = 0, @@ -19318,12 +19318,12 @@ GGML_API void ggml_opt_init( opt->just_initialized = true; if (opt->ctx == NULL) { struct ggml_init_params ctx_opt_params; - if (opt->params.type == GGML_OPT_ADAM) { + if (opt->params.type == GGML_OPT_TYPE_ADAM) { ctx_opt_params.mem_size = GGML_MEM_ALIGN*3 + ggml_tensor_overhead()*3 + ggml_type_size(GGML_TYPE_F32)*nx*3; if (opt->params.past > 0) { ctx_opt_params.mem_size += GGML_MEM_ALIGN + ggml_tensor_overhead() + ggml_type_size(GGML_TYPE_F32)*opt->params.past; } - } else if (opt->params.type == GGML_OPT_LBFGS) { + } else if (opt->params.type == GGML_OPT_TYPE_LBFGS) { ctx_opt_params.mem_size = GGML_MEM_ALIGN*9 + ggml_tensor_overhead()*9 + ggml_type_size(GGML_TYPE_F32)*(nx*5 + opt->params.lbfgs.m*2 + nx*opt->params.lbfgs.m*2); if (opt->params.past > 0) { ctx_opt_params.mem_size += GGML_MEM_ALIGN + ggml_tensor_overhead() + ggml_type_size(GGML_TYPE_F32)*opt->params.past; @@ -19335,7 +19335,7 @@ GGML_API void ggml_opt_init( opt->ctx = ggml_init(ctx_opt_params); } switch (opt->params.type) { - case GGML_OPT_ADAM: + case GGML_OPT_TYPE_ADAM: { opt->adam.g = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx); opt->adam.m = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx); @@ -19349,7 +19349,7 @@ GGML_API void ggml_opt_init( ggml_set_zero(opt->adam.pf); } } break; - case GGML_OPT_LBFGS: + case GGML_OPT_TYPE_LBFGS: { opt->lbfgs.x = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx); opt->lbfgs.xp = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx); @@ -19393,13 +19393,13 @@ enum ggml_opt_result ggml_opt( ctx = ggml_init(params_ctx); if (ctx == NULL) { - return GGML_OPT_NO_CONTEXT; + return GGML_OPT_RESULT_NO_CONTEXT; } free_ctx = true; } - enum ggml_opt_result result = GGML_OPT_OK; + enum ggml_opt_result result = GGML_OPT_RESULT_OK; struct ggml_opt_context * opt = (struct ggml_opt_context *) 
alloca(sizeof(struct ggml_opt_context)); @@ -19438,14 +19438,14 @@ enum ggml_opt_result ggml_opt_resume_g( void * callback_data) { // build forward + backward compute graphs - enum ggml_opt_result result = GGML_OPT_OK; + enum ggml_opt_result result = GGML_OPT_RESULT_OK; switch (opt->params.type) { - case GGML_OPT_ADAM: + case GGML_OPT_TYPE_ADAM: { result = ggml_opt_adam(ctx, opt, opt->params, f, gf, gb, callback, callback_data); } break; - case GGML_OPT_LBFGS: + case GGML_OPT_TYPE_LBFGS: { result = ggml_opt_lbfgs(ctx, opt, opt->params, f, gf, gb, callback, callback_data); } break; diff --git a/ggml.h b/ggml.h index a4166e1f7..75fd035a4 100644 --- a/ggml.h +++ b/ggml.h @@ -364,9 +364,9 @@ extern "C" { }; enum ggml_backend_type { - GGML_BACKEND_CPU = 0, - GGML_BACKEND_GPU = 10, - GGML_BACKEND_GPU_SPLIT = 20, + GGML_BACKEND_TYPE_CPU = 0, + GGML_BACKEND_TYPE_GPU = 10, + GGML_BACKEND_TYPE_GPU_SPLIT = 20, }; // model file types @@ -498,9 +498,9 @@ extern "C" { }; enum ggml_object_type { - GGML_OBJECT_TENSOR, - GGML_OBJECT_GRAPH, - GGML_OBJECT_WORK_BUFFER + GGML_OBJECT_TYPE_TENSOR, + GGML_OBJECT_TYPE_GRAPH, + GGML_OBJECT_TYPE_WORK_BUFFER }; enum ggml_log_level { @@ -642,9 +642,9 @@ extern "C" { // NOTE: the INIT or FINALIZE pass is not scheduled unless explicitly enabled. // This behavior was changed since https://github.com/ggerganov/llama.cpp/pull/1995. enum ggml_task_type { - GGML_TASK_INIT = 0, - GGML_TASK_COMPUTE, - GGML_TASK_FINALIZE, + GGML_TASK_TYPE_INIT = 0, + GGML_TASK_TYPE_COMPUTE, + GGML_TASK_TYPE_FINALIZE, }; struct ggml_compute_params { @@ -1649,8 +1649,8 @@ extern "C" { // sort rows enum ggml_sort_order { - GGML_SORT_ASC, - GGML_SORT_DESC, + GGML_SORT_ORDER_ASC, + GGML_SORT_ORDER_DESC, }; GGML_API struct ggml_tensor * ggml_argsort( @@ -1943,8 +1943,8 @@ extern "C" { // optimization methods enum ggml_opt_type { - GGML_OPT_ADAM, - GGML_OPT_LBFGS, + GGML_OPT_TYPE_ADAM, + GGML_OPT_TYPE_LBFGS, }; // linesearch methods @@ -1958,12 +1958,12 @@ extern "C" { // optimization return values enum ggml_opt_result { - GGML_OPT_OK = 0, - GGML_OPT_DID_NOT_CONVERGE, - GGML_OPT_NO_CONTEXT, - GGML_OPT_INVALID_WOLFE, - GGML_OPT_FAIL, - GGML_OPT_CANCEL, + GGML_OPT_RESULT_OK = 0, + GGML_OPT_RESULT_DID_NOT_CONVERGE, + GGML_OPT_RESULT_NO_CONTEXT, + GGML_OPT_RESULT_INVALID_WOLFE, + GGML_OPT_RESULT_FAIL, + GGML_OPT_RESULT_CANCEL, GGML_LINESEARCH_FAIL = -128, GGML_LINESEARCH_MINIMUM_STEP, diff --git a/llama.cpp b/llama.cpp index 1f6b6cff4..acd9be08a 100644 --- a/llama.cpp +++ b/llama.cpp @@ -850,9 +850,9 @@ struct LLM_TN { // static std::map LLAMA_ROPE_SCALING_TYPES = { - { LLAMA_ROPE_SCALING_NONE, "none" }, - { LLAMA_ROPE_SCALING_LINEAR, "linear" }, - { LLAMA_ROPE_SCALING_YARN, "yarn" }, + { LLAMA_ROPE_SCALING_TYPE_NONE, "none" }, + { LLAMA_ROPE_SCALING_TYPE_LINEAR, "linear" }, + { LLAMA_ROPE_SCALING_TYPE_YARN, "yarn" }, }; static int32_t llama_rope_scaling_type_from_string(const std::string & name) { @@ -862,7 +862,7 @@ static int32_t llama_rope_scaling_type_from_string(const std::string & name) { } } - return LLAMA_ROPE_SCALING_UNSPECIFIED; + return LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED; } static std::string gguf_data_to_str(enum gguf_type type, const void * data, int i) { @@ -1580,7 +1580,7 @@ struct llama_hparams { bool causal_attn = true; bool need_kq_pos = false; - uint32_t pooling_type = LLAMA_POOLING_NONE; + uint32_t pooling_type = LLAMA_POOLING_TYPE_NONE; bool operator!=(const llama_hparams & other) const { if (this->vocab_only != other.vocab_only) return true; @@ -2345,9 +2345,9 @@ namespace GGUFMeta 
{ static const char * override_type_to_str(const llama_model_kv_override_type ty) { switch (ty) { - case LLAMA_KV_OVERRIDE_BOOL: return "bool"; - case LLAMA_KV_OVERRIDE_INT: return "int"; - case LLAMA_KV_OVERRIDE_FLOAT: return "float"; + case LLAMA_KV_OVERRIDE_TYPE_BOOL: return "bool"; + case LLAMA_KV_OVERRIDE_TYPE_INT: return "int"; + case LLAMA_KV_OVERRIDE_TYPE_FLOAT: return "float"; } return "unknown"; } @@ -2358,13 +2358,13 @@ namespace GGUFMeta { LLAMA_LOG_INFO("%s: Using metadata override (%5s) '%s' = ", __func__, override_type_to_str(override->tag), override->key); switch (override->tag) { - case LLAMA_KV_OVERRIDE_BOOL: { + case LLAMA_KV_OVERRIDE_TYPE_BOOL: { LLAMA_LOG_INFO("%s\n", override->bool_value ? "true" : "false"); } break; - case LLAMA_KV_OVERRIDE_INT: { + case LLAMA_KV_OVERRIDE_TYPE_INT: { LLAMA_LOG_INFO("%" PRId64 "\n", override->int_value); } break; - case LLAMA_KV_OVERRIDE_FLOAT: { + case LLAMA_KV_OVERRIDE_TYPE_FLOAT: { LLAMA_LOG_INFO("%.6f\n", override->float_value); } break; default: @@ -2383,7 +2383,7 @@ namespace GGUFMeta { template static typename std::enable_if::value, bool>::type try_override(OT & target, const struct llama_model_kv_override *override) { - if (validate_override(LLAMA_KV_OVERRIDE_BOOL, override)) { + if (validate_override(LLAMA_KV_OVERRIDE_TYPE_BOOL, override)) { target = override->bool_value; return true; } @@ -2393,7 +2393,7 @@ namespace GGUFMeta { template static typename std::enable_if::value && std::is_integral::value, bool>::type try_override(OT & target, const struct llama_model_kv_override *override) { - if (validate_override(LLAMA_KV_OVERRIDE_INT, override)) { + if (validate_override(LLAMA_KV_OVERRIDE_TYPE_INT, override)) { target = override->int_value; return true; } @@ -2403,7 +2403,7 @@ namespace GGUFMeta { template static typename std::enable_if::value, bool>::type try_override(T & target, const struct llama_model_kv_override *override) { - if (validate_override(LLAMA_KV_OVERRIDE_FLOAT, override)) { + if (validate_override(LLAMA_KV_OVERRIDE_TYPE_FLOAT, override)) { target = override->float_value; return true; } @@ -2999,7 +2999,7 @@ static void llm_load_hparams( std::string rope_scaling("linear"); ml.get_key(LLM_KV_ROPE_SCALING_TYPE, rope_scaling, false); hparams.rope_scaling_type_train = llama_rope_scaling_type_from_string(rope_scaling); - GGML_ASSERT(hparams.rope_scaling_type_train != LLAMA_ROPE_SCALING_UNSPECIFIED); + GGML_ASSERT(hparams.rope_scaling_type_train != LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED); // rope_freq_scale (inverse of the kv) is optional float ropescale = 0.0f; @@ -3643,7 +3643,7 @@ static bool llm_load_tensors( model.buft_layer[i] = llama_default_buffer_type_cpu(true); } - if (split_mode == LLAMA_SPLIT_LAYER) { + if (split_mode == LLAMA_SPLIT_MODE_LAYER) { // calculate the split points int device_count = llama_get_device_count(); bool all_zero = tensor_split == nullptr || std::all_of(tensor_split, tensor_split + device_count, [](float x) { return x == 0.0f; }); @@ -3682,10 +3682,10 @@ static bool llm_load_tensors( } } else { ggml_backend_buffer_type_t split_buft; - if (split_mode == LLAMA_SPLIT_ROW) { + if (split_mode == LLAMA_SPLIT_MODE_ROW) { split_buft = llama_default_buffer_type_split(main_gpu, tensor_split); } else { - // LLAMA_SPLIT_NONE or LLAMA_SPLIT_LAYER in backends where it is not supported + // LLAMA_SPLIT_MODE_NONE or LLAMA_SPLIT_MODE_LAYER in backends where it is not supported split_buft = llama_default_buffer_type_offload(main_gpu); } // assign the repeating layers @@ -5070,7 +5070,7 @@ struct 
llm_build_context { kv_head (worst_case ? n_ctx - n_tokens : kv_self.head), n_orig_ctx (cparams.n_yarn_orig_ctx), do_rope_shift (worst_case || kv_self.has_shift), - pooling_type (cparams.do_pooling ? hparams.pooling_type : (uint32_t)LLAMA_POOLING_NONE), + pooling_type (cparams.do_pooling ? hparams.pooling_type : (uint32_t)LLAMA_POOLING_TYPE_NONE), cb (cb), buf_compute_meta (lctx.buf_compute_meta) { // all initializations should be done in init() @@ -6050,12 +6050,12 @@ struct llm_build_context { cur = inpL; // pooling layer - if (pooling_type == LLAMA_POOLING_MEAN) { + if (pooling_type == LLAMA_POOLING_TYPE_MEAN) { cur = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, cur)), inp_mean); - } else if (pooling_type == LLAMA_POOLING_CLS) { + } else if (pooling_type == LLAMA_POOLING_TYPE_CLS) { cur = ggml_get_rows(ctx0, cur, inp_cls); } else { - GGML_ASSERT(pooling_type == LLAMA_POOLING_NONE && "Invalid pooling type"); + GGML_ASSERT(pooling_type == LLAMA_POOLING_TYPE_NONE && "Invalid pooling type"); } cb(cur, "result_embd", -1); @@ -7754,7 +7754,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { } } - if (cparams.do_pooling && hparams.pooling_type == LLAMA_POOLING_MEAN) { + if (cparams.do_pooling && hparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) { const int64_t n_tokens = batch.n_tokens; GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_mean->buffer)); @@ -7782,7 +7782,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { } } - if (cparams.do_pooling && hparams.pooling_type == LLAMA_POOLING_CLS) { + if (cparams.do_pooling && hparams.pooling_type == LLAMA_POOLING_TYPE_CLS) { const int64_t n_tokens = batch.n_tokens; GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_cls->buffer)); @@ -11351,7 +11351,7 @@ static int llama_apply_lora_from_file_internal( struct llama_model_params llama_model_default_params() { struct llama_model_params result = { /*.n_gpu_layers =*/ 0, - /*.split_mode =*/ LLAMA_SPLIT_LAYER, + /*.split_mode =*/ LLAMA_SPLIT_MODE_LAYER, /*.main_gpu =*/ 0, /*.tensor_split =*/ nullptr, /*.progress_callback =*/ nullptr, @@ -11377,7 +11377,7 @@ struct llama_context_params llama_context_default_params() { /*.n_batch =*/ 512, /*.n_threads =*/ GGML_DEFAULT_N_THREADS, // TODO: better default /*.n_threads_batch =*/ GGML_DEFAULT_N_THREADS, - /*.rope_scaling_type =*/ LLAMA_ROPE_SCALING_UNSPECIFIED, + /*.rope_scaling_type =*/ LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED, /*.rope_freq_base =*/ 0.0f, /*.rope_freq_scale =*/ 0.0f, /*.yarn_ext_factor =*/ -1.0f, @@ -11565,16 +11565,16 @@ struct llama_context * llama_new_context_with_model( cparams.cb_eval_user_data = params.cb_eval_user_data; auto rope_scaling_type = params.rope_scaling_type; - if (rope_scaling_type == LLAMA_ROPE_SCALING_UNSPECIFIED) { + if (rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED) { rope_scaling_type = hparams.rope_scaling_type_train; } - if (rope_scaling_type == LLAMA_ROPE_SCALING_NONE) { + if (rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_NONE) { cparams.rope_freq_scale = 1.0f; // never scale if scaling type is none } if (cparams.yarn_ext_factor < 0.0f) { // negative indicates 'not set' - cparams.yarn_ext_factor = rope_scaling_type == LLAMA_ROPE_SCALING_YARN ? 1.0f : 0.0f; + cparams.yarn_ext_factor = rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_YARN ? 
1.0f : 0.0f; } if (params.seed == LLAMA_DEFAULT_SEED) { @@ -11608,8 +11608,8 @@ struct llama_context * llama_new_context_with_model( } #elif defined(GGML_USE_CUBLAS) if (model->n_gpu_layers > 0) { - // with split_mode LLAMA_SPLIT_NONE or LLAMA_SPLIT_ROW, only the main GPU backend is used - if (model->split_mode == LLAMA_SPLIT_NONE || model->split_mode == LLAMA_SPLIT_ROW) { + // with split_mode LLAMA_SPLIT_MODE_NONE or LLAMA_SPLIT_MODE_ROW, only the main GPU backend is used + if (model->split_mode == LLAMA_SPLIT_MODE_NONE || model->split_mode == LLAMA_SPLIT_MODE_ROW) { ggml_backend_t backend = ggml_backend_cuda_init(model->main_gpu); if (backend == nullptr) { LLAMA_LOG_ERROR("%s: failed to initialize CUDA%d backend\n", __func__, model->main_gpu); @@ -11618,7 +11618,7 @@ struct llama_context * llama_new_context_with_model( } ctx->backends.push_back(backend); } else { - // LLAMA_SPLIT_LAYER requires a backend for each GPU + // LLAMA_SPLIT_MODE_LAYER requires a backend for each GPU for (int device = 0; device < ggml_backend_cuda_get_device_count(); ++device) { ggml_backend_t backend = ggml_backend_cuda_init(device); if (backend == nullptr) { diff --git a/llama.h b/llama.h index 889edf4d9..947284ea2 100644 --- a/llama.h +++ b/llama.h @@ -109,23 +109,23 @@ extern "C" { }; enum llama_rope_scaling_type { - LLAMA_ROPE_SCALING_UNSPECIFIED = -1, - LLAMA_ROPE_SCALING_NONE = 0, - LLAMA_ROPE_SCALING_LINEAR = 1, - LLAMA_ROPE_SCALING_YARN = 2, - LLAMA_ROPE_SCALING_MAX_VALUE = LLAMA_ROPE_SCALING_YARN, + LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED = -1, + LLAMA_ROPE_SCALING_TYPE_NONE = 0, + LLAMA_ROPE_SCALING_TYPE_LINEAR = 1, + LLAMA_ROPE_SCALING_TYPE_YARN = 2, + LLAMA_ROPE_SCALING_TYPE_MAX_VALUE = LLAMA_ROPE_SCALING_TYPE_YARN, }; enum llama_pooling_type { - LLAMA_POOLING_NONE = 0, - LLAMA_POOLING_MEAN = 1, - LLAMA_POOLING_CLS = 2, + LLAMA_POOLING_TYPE_NONE = 0, + LLAMA_POOLING_TYPE_MEAN = 1, + LLAMA_POOLING_TYPE_CLS = 2, }; enum llama_split_mode { - LLAMA_SPLIT_NONE = 0, // single GPU - LLAMA_SPLIT_LAYER = 1, // split layers and KV across GPUs - LLAMA_SPLIT_ROW = 2, // split rows across GPUs + LLAMA_SPLIT_MODE_NONE = 0, // single GPU + LLAMA_SPLIT_MODE_LAYER = 1, // split layers and KV across GPUs + LLAMA_SPLIT_MODE_ROW = 2, // split rows across GPUs }; typedef struct llama_token_data { @@ -173,9 +173,9 @@ extern "C" { } llama_batch; enum llama_model_kv_override_type { - LLAMA_KV_OVERRIDE_INT, - LLAMA_KV_OVERRIDE_FLOAT, - LLAMA_KV_OVERRIDE_BOOL, + LLAMA_KV_OVERRIDE_TYPE_INT, + LLAMA_KV_OVERRIDE_TYPE_FLOAT, + LLAMA_KV_OVERRIDE_TYPE_BOOL, }; struct llama_model_kv_override { diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index f8574588b..24d12ef14 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1264,7 +1264,7 @@ struct test_argsort : public test_case { test_argsort(ggml_type type = GGML_TYPE_F32, std::array ne = {16, 10, 10, 10}, - ggml_sort_order order = GGML_SORT_ASC) + ggml_sort_order order = GGML_SORT_ORDER_ASC) : type(type), ne(ne), order(order) {} ggml_tensor * build_graph(ggml_context * ctx) override { @@ -2116,7 +2116,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op test_cases.emplace_back(new test_concat(GGML_TYPE_F32)); test_cases.emplace_back(new test_concat(GGML_TYPE_I32)); - for (ggml_sort_order order : {GGML_SORT_ASC, GGML_SORT_DESC}) { + for (ggml_sort_order order : {GGML_SORT_ORDER_ASC, GGML_SORT_ORDER_DESC}) { test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {8, 1, 1, 1}, order)); 
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {16, 10, 10, 10}, order)); } diff --git a/tests/test-opt.cpp b/tests/test-opt.cpp index 2c9997fca..546ca230b 100644 --- a/tests/test-opt.cpp +++ b/tests/test-opt.cpp @@ -118,7 +118,7 @@ int main(void) { const float fe = ggml_get_f32_1d(e, 0); printf("%s: e = %.4f\n", __func__, fe); - struct ggml_opt_params opt_params = ggml_opt_default_params(GGML_OPT_ADAM); + struct ggml_opt_params opt_params = ggml_opt_default_params(GGML_OPT_TYPE_ADAM); ggml_opt(ctx, opt_params, e);