code : normalize enum names (#5697)

* coda : normalize enum names

ggml-ci

* code : cont

* code : cont
This commit is contained in:
Georgi Gerganov 2024-02-25 12:09:09 +02:00 committed by GitHub
parent 69917dfa55
commit ab336a9d5e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
20 changed files with 502 additions and 502 deletions

View File

@ -295,9 +295,9 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
break;
}
std::string value(argv[i]);
/**/ if (value == "none") { params.rope_scaling_type = LLAMA_ROPE_SCALING_NONE; }
else if (value == "linear") { params.rope_scaling_type = LLAMA_ROPE_SCALING_LINEAR; }
else if (value == "yarn") { params.rope_scaling_type = LLAMA_ROPE_SCALING_YARN; }
/**/ if (value == "none") { params.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_NONE; }
else if (value == "linear") { params.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_LINEAR; }
else if (value == "yarn") { params.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_YARN; }
else { invalid_param = true; break; }
} else if (arg == "--rope-scale") {
if (++i >= argc) {
@ -630,11 +630,11 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
}
std::string arg_next = argv[i];
if (arg_next == "none") {
params.split_mode = LLAMA_SPLIT_NONE;
params.split_mode = LLAMA_SPLIT_MODE_NONE;
} else if (arg_next == "layer") {
params.split_mode = LLAMA_SPLIT_LAYER;
params.split_mode = LLAMA_SPLIT_MODE_LAYER;
} else if (arg_next == "row") {
params.split_mode = LLAMA_SPLIT_ROW;
params.split_mode = LLAMA_SPLIT_MODE_ROW;
} else {
invalid_param = true;
break;
@ -837,15 +837,15 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
sep++;
if (strncmp(sep, "int:", 4) == 0) {
sep += 4;
kvo.tag = LLAMA_KV_OVERRIDE_INT;
kvo.tag = LLAMA_KV_OVERRIDE_TYPE_INT;
kvo.int_value = std::atol(sep);
} else if (strncmp(sep, "float:", 6) == 0) {
sep += 6;
kvo.tag = LLAMA_KV_OVERRIDE_FLOAT;
kvo.tag = LLAMA_KV_OVERRIDE_TYPE_FLOAT;
kvo.float_value = std::atof(sep);
} else if (strncmp(sep, "bool:", 5) == 0) {
sep += 5;
kvo.tag = LLAMA_KV_OVERRIDE_BOOL;
kvo.tag = LLAMA_KV_OVERRIDE_TYPE_BOOL;
if (std::strcmp(sep, "true") == 0) {
kvo.bool_value = true;
} else if (std::strcmp(sep, "false") == 0) {

View File

@ -61,7 +61,7 @@ struct gpt_params {
float p_split = 0.1f; // speculative decoding split probability
int32_t n_gpu_layers = -1; // number of layers to store in VRAM (-1 - use default)
int32_t n_gpu_layers_draft = -1; // number of layers to store in VRAM for the draft model (-1 - use default)
llama_split_mode split_mode = LLAMA_SPLIT_LAYER; // how to split the model across GPUs
llama_split_mode split_mode = LLAMA_SPLIT_MODE_LAYER; // how to split the model across GPUs
int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors
float tensor_split[128] = {0}; // how split tensors should be distributed across GPUs
int32_t n_beams = 0; // if non-zero then use beam search of given width.
@ -75,7 +75,7 @@ struct gpt_params {
float yarn_beta_fast = 32.0f; // YaRN low correction dim
float yarn_beta_slow = 1.0f; // YaRN high correction dim
int32_t yarn_orig_ctx = 0; // YaRN original context length
int32_t rope_scaling_type = LLAMA_ROPE_SCALING_UNSPECIFIED;
int32_t rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED;
ggml_numa_strategy numa = GGML_NUMA_STRATEGY_DISABLED;
// // sampling parameters

View File

@ -31,7 +31,7 @@ struct train_state * init_train_state() {
state->opt = new struct ggml_opt_context;
state->opt->ctx = NULL;
state->opt->params = ggml_opt_default_params(GGML_OPT_ADAM);
state->opt->params = ggml_opt_default_params(GGML_OPT_TYPE_ADAM);
state->opt->params.graph_size = LLAMA_TRAIN_MAX_NODES;
state->opt->loss_after = 0.0f;
@ -556,7 +556,7 @@ void load_opt_context_gguf(struct gguf_context * fctx, struct ggml_context * f_g
std::string opt_type;
GGUF_GET_KEY(fctx, opt_type, gguf_get_val_str, GGUF_TYPE_STRING, true, LLM_KV_OPTIMIZER_TYPE);
if (opt_type == LLM_KV_OPTIMIZER_TYPE_ADAM) {
opt->params.type = GGML_OPT_ADAM;
opt->params.type = GGML_OPT_TYPE_ADAM;
GGUF_GET_KEY(fctx, opt->adam.fx_best, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, LLM_KV_OPTIMIZER_ADAM_BEST_LOSS);
GGUF_GET_KEY(fctx, opt->adam.fx_prev, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, LLM_KV_OPTIMIZER_ADAM_PREVIOUS_LOSS);
@ -568,7 +568,7 @@ void load_opt_context_gguf(struct gguf_context * fctx, struct ggml_context * f_g
copy_tensor_by_name(opt->adam.v, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_ADAM_SECOND_MOMENTS);
copy_tensor_by_name(opt->adam.pf, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_ADAM_PAST_LOSS_VALUES);
} else if (opt_type == LLM_KV_OPTIMIZER_TYPE_LBFGS) {
opt->params.type = GGML_OPT_LBFGS;
opt->params.type = GGML_OPT_TYPE_LBFGS;
GGUF_GET_KEY(fctx, opt->params.lbfgs.m, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_OPTIMIZER_LBFGS_APPROX_HESSIAN_COUNT);
GGUF_GET_KEY(fctx, opt->lbfgs.fx_best, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, LLM_KV_OPTIMIZER_LBFGS_BEST_LOSS);
@ -603,7 +603,7 @@ void save_opt_context_gguf(struct gguf_context * fctx, struct ggml_opt_context *
gguf_set_val_bool(fctx, LLM_KV_OPTIMIZER_JUST_INITIALIZED, opt->just_initialized);
switch (opt->params.type) {
case GGML_OPT_ADAM:
case GGML_OPT_TYPE_ADAM:
{
gguf_set_val_str(fctx, LLM_KV_OPTIMIZER_TYPE, LLM_KV_OPTIMIZER_TYPE_ADAM);
gguf_set_val_f32(fctx, LLM_KV_OPTIMIZER_ADAM_BEST_LOSS, opt->adam.fx_best);
@ -622,7 +622,7 @@ void save_opt_context_gguf(struct gguf_context * fctx, struct ggml_opt_context *
gguf_add_tensor(fctx, opt->adam.pf);
}
} break;
case GGML_OPT_LBFGS:
case GGML_OPT_TYPE_LBFGS:
{
gguf_set_val_str(fctx, LLM_KV_OPTIMIZER_TYPE, LLM_KV_OPTIMIZER_TYPE_LBFGS);
gguf_set_val_u32(fctx, LLM_KV_OPTIMIZER_LBFGS_APPROX_HESSIAN_COUNT, opt->params.lbfgs.m);

View File

@ -1547,7 +1547,7 @@ int main(int argc, char ** argv) {
float error_before_opt = ggml_get_f32_1d(e, 0);
struct ggml_opt_params opt_params_lbfgs = ggml_opt_default_params(GGML_OPT_LBFGS);
struct ggml_opt_params opt_params_lbfgs = ggml_opt_default_params(GGML_OPT_TYPE_LBFGS);
opt_params_lbfgs.print_forward_graph = false;
opt_params_lbfgs.print_backward_graph = false;
opt_params_lbfgs.lbfgs.n_iter = 16;

View File

@ -1531,7 +1531,7 @@ int main(int argc, char ** argv) {
lora.hparams.n_rank_output = n_rank_output;
// set opt params from command line
opt->params = ggml_opt_default_params(GGML_OPT_ADAM);
opt->params = ggml_opt_default_params(GGML_OPT_TYPE_ADAM);
opt->params.print_forward_graph = false;
opt->params.print_backward_graph = false;
opt->params.graph_size = LLAMA_TRAIN_MAX_NODES;

View File

@ -157,9 +157,9 @@ static const char * output_format_str(output_formats format) {
static const char * split_mode_str(llama_split_mode mode) {
switch (mode) {
case LLAMA_SPLIT_NONE: return "none";
case LLAMA_SPLIT_LAYER: return "layer";
case LLAMA_SPLIT_ROW: return "row";
case LLAMA_SPLIT_MODE_NONE: return "none";
case LLAMA_SPLIT_MODE_LAYER: return "layer";
case LLAMA_SPLIT_MODE_ROW: return "row";
default: GGML_ASSERT(!"invalid split mode");
}
}
@ -193,7 +193,7 @@ static const cmd_params cmd_params_defaults = {
/* type_v */ {GGML_TYPE_F16},
/* n_threads */ {get_num_physical_cores()},
/* n_gpu_layers */ {99},
/* split_mode */ {LLAMA_SPLIT_LAYER},
/* split_mode */ {LLAMA_SPLIT_MODE_LAYER},
/* main_gpu */ {0},
/* no_kv_offload */ {false},
/* mul_mat_q */ {true},
@ -358,11 +358,11 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
for (const auto & m : p) {
llama_split_mode mode;
if (m == "none") {
mode = LLAMA_SPLIT_NONE;
mode = LLAMA_SPLIT_MODE_NONE;
} else if (m == "layer") {
mode = LLAMA_SPLIT_LAYER;
mode = LLAMA_SPLIT_MODE_LAYER;
} else if (m == "row") {
mode = LLAMA_SPLIT_ROW;
mode = LLAMA_SPLIT_MODE_ROW;
} else {
invalid_param = true;
break;

View File

@ -152,7 +152,7 @@ static bool clip_llava_handle_patches(clip_ctx * ctx_clip, std::vector<float *>
ggml_tensor * newline_tmp = clip_get_newline_tensor(ctx_clip);
model.newline = ggml_new_tensor_1d(model.ctx, GGML_TYPE_F32, newline_tmp->ne[0]);
if (newline_tmp->backend != GGML_BACKEND_CPU) {
if (newline_tmp->backend != GGML_BACKEND_TYPE_CPU) {
if (newline_tmp->buffer == NULL) {
printf("newline_tmp tensor buffer is NULL\n");
}

View File

@ -2086,9 +2086,9 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
break;
}
std::string value(argv[i]);
/**/ if (value == "none") { params.rope_scaling_type = LLAMA_ROPE_SCALING_NONE; }
else if (value == "linear") { params.rope_scaling_type = LLAMA_ROPE_SCALING_LINEAR; }
else if (value == "yarn") { params.rope_scaling_type = LLAMA_ROPE_SCALING_YARN; }
/**/ if (value == "none") { params.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_NONE; }
else if (value == "linear") { params.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_LINEAR; }
else if (value == "yarn") { params.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_YARN; }
else { invalid_param = true; break; }
}
else if (arg == "--rope-freq-base")
@ -2212,15 +2212,15 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
std::string arg_next = argv[i];
if (arg_next == "none")
{
params.split_mode = LLAMA_SPLIT_NONE;
params.split_mode = LLAMA_SPLIT_MODE_NONE;
}
else if (arg_next == "layer")
{
params.split_mode = LLAMA_SPLIT_LAYER;
params.split_mode = LLAMA_SPLIT_MODE_LAYER;
}
else if (arg_next == "row")
{
params.split_mode = LLAMA_SPLIT_ROW;
params.split_mode = LLAMA_SPLIT_MODE_ROW;
}
else {
invalid_param = true;
@ -2447,15 +2447,15 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
sep++;
if (strncmp(sep, "int:", 4) == 0) {
sep += 4;
kvo.tag = LLAMA_KV_OVERRIDE_INT;
kvo.tag = LLAMA_KV_OVERRIDE_TYPE_INT;
kvo.int_value = std::atol(sep);
} else if (strncmp(sep, "float:", 6) == 0) {
sep += 6;
kvo.tag = LLAMA_KV_OVERRIDE_FLOAT;
kvo.tag = LLAMA_KV_OVERRIDE_TYPE_FLOAT;
kvo.float_value = std::atof(sep);
} else if (strncmp(sep, "bool:", 5) == 0) {
sep += 5;
kvo.tag = LLAMA_KV_OVERRIDE_BOOL;
kvo.tag = LLAMA_KV_OVERRIDE_TYPE_BOOL;
if (std::strcmp(sep, "true") == 0) {
kvo.bool_value = true;
} else if (std::strcmp(sep, "false") == 0) {

View File

@ -960,7 +960,7 @@ int main(int argc, char ** argv) {
struct ggml_opt_context * opt = train->opt;
// set opt params from command line
opt->params = ggml_opt_default_params(GGML_OPT_ADAM);
opt->params = ggml_opt_default_params(GGML_OPT_TYPE_ADAM);
opt->params.print_forward_graph = false;
opt->params.print_backward_graph = false;
opt->params.graph_size = LLAMA_TRAIN_MAX_NODES;

View File

@ -6369,11 +6369,11 @@ static __global__ void k_argsort_f32_i32(const float * x, int * dst, const int n
int ixj = col ^ j;
if (ixj > col) {
if ((col & k) == 0) {
if (order == GGML_SORT_ASC ? x_row[dst_row[col]] > x_row[dst_row[ixj]] : x_row[dst_row[col]] < x_row[dst_row[ixj]]) {
if (order == GGML_SORT_ORDER_ASC ? x_row[dst_row[col]] > x_row[dst_row[ixj]] : x_row[dst_row[col]] < x_row[dst_row[ixj]]) {
swap(dst_row[col], dst_row[ixj]);
}
} else {
if (order == GGML_SORT_ASC ? x_row[dst_row[col]] < x_row[dst_row[ixj]] : x_row[dst_row[col]] > x_row[dst_row[ixj]]) {
if (order == GGML_SORT_ORDER_ASC ? x_row[dst_row[col]] < x_row[dst_row[ixj]] : x_row[dst_row[col]] > x_row[dst_row[ixj]]) {
swap(dst_row[col], dst_row[ixj]);
}
}
@ -7927,10 +7927,10 @@ static void argsort_f32_i32_cuda(const float * x, int * dst, const int ncols, co
const dim3 block_dims(ncols, 1, 1);
const dim3 block_nums(1, nrows, 1);
if (order == GGML_SORT_ASC) {
k_argsort_f32_i32<GGML_SORT_ASC><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
} else if (order == GGML_SORT_DESC) {
k_argsort_f32_i32<GGML_SORT_DESC><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
if (order == GGML_SORT_ORDER_ASC) {
k_argsort_f32_i32<GGML_SORT_ORDER_ASC><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
} else if (order == GGML_SORT_ORDER_DESC) {
k_argsort_f32_i32<GGML_SORT_ORDER_DESC><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
} else {
GGML_ASSERT(false);
}
@ -8362,11 +8362,11 @@ static cudaError_t ggml_cuda_cpy_tensor_2d(
cudaMemcpyKind kind;
char * src_ptr;
if (src->backend == GGML_BACKEND_CPU) {
if (src->backend == GGML_BACKEND_TYPE_CPU) {
kind = cudaMemcpyHostToDevice;
src_ptr = (char *) src->data;
} else if (src->backend == GGML_BACKEND_GPU || src->backend == GGML_BACKEND_GPU_SPLIT) {
GGML_ASSERT(src->backend != GGML_BACKEND_GPU_SPLIT || (i1_low == 0 && i1_high == src->ne[1]));
} else if (src->backend == GGML_BACKEND_TYPE_GPU || src->backend == GGML_BACKEND_TYPE_GPU_SPLIT) {
GGML_ASSERT(src->backend != GGML_BACKEND_TYPE_GPU_SPLIT || (i1_low == 0 && i1_high == src->ne[1]));
kind = cudaMemcpyDeviceToDevice;
ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) src->extra;
int id;
@ -8771,7 +8771,7 @@ static void ggml_cuda_op_mul_mat_q(
// the main device has a larger memory buffer to hold the results from all GPUs
// nrows_dst == nrows of the matrix that the kernel writes into
const int64_t nrows_dst = dst->backend == GGML_BACKEND_GPU && id == g_main_device ? ne0 : row_diff;
const int64_t nrows_dst = dst->backend == GGML_BACKEND_TYPE_GPU && id == g_main_device ? ne0 : row_diff;
switch (src0->type) {
case GGML_TYPE_Q4_0:
@ -8920,7 +8920,7 @@ static void ggml_cuda_op_mul_mat_vec_q(
// the main device has a larger memory buffer to hold the results from all GPUs
// nrows_dst == nrows of the matrix that the kernel writes into
const int64_t nrows_dst = dst->backend == GGML_BACKEND_GPU && id == g_main_device ? ne0 : row_diff;
const int64_t nrows_dst = dst->backend == GGML_BACKEND_TYPE_GPU && id == g_main_device ? ne0 : row_diff;
switch (src0->type) {
case GGML_TYPE_Q4_0:
@ -9096,7 +9096,7 @@ static void ggml_cuda_op_mul_mat_cublas(
// the main device has a larger memory buffer to hold the results from all GPUs
// ldc == nrows of the matrix that cuBLAS writes into
int ldc = dst->backend == GGML_BACKEND_GPU && id == g_main_device ? ne0 : row_diff;
int ldc = dst->backend == GGML_BACKEND_TYPE_GPU && id == g_main_device ? ne0 : row_diff;
const int compute_capability = g_device_caps[id].cc;
@ -9444,7 +9444,7 @@ static void ggml_cuda_op_soft_max(
const bool use_src2 = src2 != nullptr;
if (use_src2) {
const bool src2_on_device = src2->backend == GGML_BACKEND_GPU;
const bool src2_on_device = src2->backend == GGML_BACKEND_TYPE_GPU;
if (src2_on_device) {
ggml_tensor_extra_gpu * src2_extra = (ggml_tensor_extra_gpu *) src2->extra;
@ -9502,16 +9502,16 @@ static void ggml_cuda_op_flatten(const ggml_tensor * src0, const ggml_tensor * s
const bool use_src1 = src1 != nullptr;
const int64_t nrows1 = use_src1 ? ggml_nrows(src1) : 1;
GGML_ASSERT(!use_src1 || src1->backend != GGML_BACKEND_GPU_SPLIT);
GGML_ASSERT( dst->backend != GGML_BACKEND_GPU_SPLIT);
GGML_ASSERT(!use_src1 || src1->backend != GGML_BACKEND_TYPE_GPU_SPLIT);
GGML_ASSERT( dst->backend != GGML_BACKEND_TYPE_GPU_SPLIT);
ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
ggml_tensor_extra_gpu * src1_extra = use_src1 ? (ggml_tensor_extra_gpu *) src1->extra : nullptr;
ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
const bool src0_on_device = src0->backend == GGML_BACKEND_GPU || src0->backend == GGML_BACKEND_GPU_SPLIT;
const bool src1_on_device = use_src1 && src1->backend == GGML_BACKEND_GPU;
const bool dst_on_device = dst->backend == GGML_BACKEND_GPU;
const bool src0_on_device = src0->backend == GGML_BACKEND_TYPE_GPU || src0->backend == GGML_BACKEND_TYPE_GPU_SPLIT;
const bool src1_on_device = use_src1 && src1->backend == GGML_BACKEND_TYPE_GPU;
const bool dst_on_device = dst->backend == GGML_BACKEND_TYPE_GPU;
// dd = data device
float * src0_ddf = nullptr;
@ -9555,7 +9555,7 @@ static void ggml_cuda_op_flatten(const ggml_tensor * src0, const ggml_tensor * s
CUDA_CHECK(cudaMemcpyAsync(dst->data, dst_ddf, ggml_nbytes(dst), cudaMemcpyDeviceToHost, main_stream));
}
if (dst->backend == GGML_BACKEND_CPU) {
if (dst->backend == GGML_BACKEND_TYPE_CPU) {
CUDA_CHECK(cudaDeviceSynchronize());
}
}
@ -9636,8 +9636,8 @@ static void ggml_cuda_op_mul_mat(
const int nb2 = dst->nb[2];
const int nb3 = dst->nb[3];
GGML_ASSERT(dst->backend != GGML_BACKEND_GPU_SPLIT);
GGML_ASSERT(src1->backend != GGML_BACKEND_GPU_SPLIT);
GGML_ASSERT(dst->backend != GGML_BACKEND_TYPE_GPU_SPLIT);
GGML_ASSERT(src1->backend != GGML_BACKEND_TYPE_GPU_SPLIT);
GGML_ASSERT(src1->type == GGML_TYPE_F32 || (src1->ne[2] == 1 && src1->ne[3] == 1));
GGML_ASSERT(ne12 >= ne02 && ne12 % ne02 == 0);
@ -9653,20 +9653,20 @@ static void ggml_cuda_op_mul_mat(
ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra;
ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
const bool src0_on_device = src0->backend == GGML_BACKEND_GPU || src0->backend == GGML_BACKEND_GPU_SPLIT;
const bool src0_on_device = src0->backend == GGML_BACKEND_TYPE_GPU || src0->backend == GGML_BACKEND_TYPE_GPU_SPLIT;
const bool src0_is_contiguous = ggml_is_contiguous(src0);
const bool src1_is_contiguous = ggml_is_contiguous(src1);
const int64_t src1_padded_col_size = GGML_PAD(ne10, MATRIX_ROW_PADDING);
const bool split = src0->backend == GGML_BACKEND_GPU_SPLIT;
const bool split = src0->backend == GGML_BACKEND_TYPE_GPU_SPLIT;
GGML_ASSERT(!(split && ne02 > 1));
GGML_ASSERT(!(split && ne03 > 1));
GGML_ASSERT(!(split && ne02 < ne12));
std::array<float, GGML_CUDA_MAX_DEVICES> tensor_split;
if (split) {
// TODO: check that src0->buffer->buft is a split buffer type, replace GGML_BACKEND_GPU_SPLIT check
// TODO: check that src0->buffer->buft is a split buffer type, replace GGML_BACKEND_TYPE_GPU_SPLIT check
// GGML_ASSERT(src0->buffer != nullptr && src0->buffer->buft == ...);
ggml_backend_cuda_split_buffer_type_context * buft_ctx = (ggml_backend_cuda_split_buffer_type_context *) src0->buffer->buft->context;
tensor_split = buft_ctx->tensor_split;
@ -9724,8 +9724,8 @@ static void ggml_cuda_op_mul_mat(
used_devices++;
const bool src1_on_device = src1->backend == GGML_BACKEND_GPU && id == g_main_device;
const bool dst_on_device = dst->backend == GGML_BACKEND_GPU && id == g_main_device;
const bool src1_on_device = src1->backend == GGML_BACKEND_TYPE_GPU && id == g_main_device;
const bool dst_on_device = dst->backend == GGML_BACKEND_TYPE_GPU && id == g_main_device;
ggml_cuda_set_device(id);
cudaStream_t stream = g_cudaStreams[id][0];
@ -9776,8 +9776,8 @@ static void ggml_cuda_op_mul_mat(
continue;
}
const bool src1_on_device = src1->backend == GGML_BACKEND_GPU && id == g_main_device;
const bool dst_on_device = dst->backend == GGML_BACKEND_GPU && id == g_main_device;
const bool src1_on_device = src1->backend == GGML_BACKEND_TYPE_GPU && id == g_main_device;
const bool dst_on_device = dst->backend == GGML_BACKEND_TYPE_GPU && id == g_main_device;
const int64_t row_diff = dev[id].row_high - dev[id].row_low;
ggml_cuda_set_device(id);
@ -9802,12 +9802,12 @@ static void ggml_cuda_op_mul_mat(
// the main device memory buffer can be on VRAM scratch, with space for all partial results
// in that case an offset on dst_ddf_i is needed
if (dst->backend == GGML_BACKEND_GPU && id == g_main_device) {
if (dst->backend == GGML_BACKEND_TYPE_GPU && id == g_main_device) {
dst_dd_i += dev[id].row_low; // offset is 0 if no tensor split
}
// copy src0, src1 to device if necessary
if (src1->backend == GGML_BACKEND_GPU && src1_is_contiguous) {
if (src1->backend == GGML_BACKEND_TYPE_GPU && src1_is_contiguous) {
if (id != g_main_device) {
if (convert_src1_to_q8_1) {
char * src1_ddq_i_source = dev[g_main_device].src1_ddq + src1_ddq_i_offset;
@ -9820,14 +9820,14 @@ static void ggml_cuda_op_mul_mat(
src1_ncols*ne10*sizeof(float), stream));
}
}
} else if (src1->backend == GGML_BACKEND_CPU || (src1_on_device && !src1_is_contiguous)) {
} else if (src1->backend == GGML_BACKEND_TYPE_CPU || (src1_on_device && !src1_is_contiguous)) {
CUDA_CHECK(ggml_cuda_cpy_tensor_2d(
src1_ddf_i, src1, i03, i02, src1_col_0, src1_col_0+src1_ncols, stream));
} else {
GGML_ASSERT(false);
}
if (convert_src1_to_q8_1 && (src1->backend == GGML_BACKEND_CPU || !src1_is_contiguous)) {
if (convert_src1_to_q8_1 && (src1->backend == GGML_BACKEND_TYPE_CPU || !src1_is_contiguous)) {
quantize_row_q8_1_cuda(src1_ddf_i, src1_ddq_i, ne10, src1_ncols, src1_padded_col_size, stream);
CUDA_CHECK(cudaGetLastError());
}
@ -9845,10 +9845,10 @@ static void ggml_cuda_op_mul_mat(
if (!dst_on_device) {
void * dst_off_device;
cudaMemcpyKind kind;
if (dst->backend == GGML_BACKEND_CPU) {
if (dst->backend == GGML_BACKEND_TYPE_CPU) {
dst_off_device = dst->data;
kind = cudaMemcpyDeviceToHost;
} else if (dst->backend == GGML_BACKEND_GPU) {
} else if (dst->backend == GGML_BACKEND_TYPE_GPU) {
dst_off_device = dst_extra->data_device[g_main_device];
kind = cudaMemcpyDeviceToDevice;
} else {
@ -9913,7 +9913,7 @@ static void ggml_cuda_op_mul_mat(
}
}
if (dst->backend == GGML_BACKEND_CPU) {
if (dst->backend == GGML_BACKEND_TYPE_CPU) {
ggml_cuda_set_device(g_main_device);
CUDA_CHECK(cudaDeviceSynchronize());
}
@ -10019,7 +10019,7 @@ GGML_CALL bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const stru
static void ggml_cuda_mul_mat_vec_p021(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){
GGML_ASSERT(ggml_is_permuted(src0) && ggml_is_permuted(src1));
GGML_ASSERT(src0->backend != GGML_BACKEND_GPU_SPLIT);
GGML_ASSERT(src0->backend != GGML_BACKEND_TYPE_GPU_SPLIT);
GGML_ASSERT(src0->nb[0] <= src0->nb[1] && src0->nb[2] <= src0->nb[3]); // 0213 permutation
GGML_ASSERT(src1->nb[0] <= src1->nb[1] && src1->nb[2] <= src1->nb[3]); // 0213 permutation
GGML_ASSERT(src0->type == GGML_TYPE_F16);
@ -10050,7 +10050,7 @@ static void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor
GGML_ASSERT(!ggml_is_transposed(src0));
GGML_ASSERT(!ggml_is_transposed(src1));
GGML_ASSERT(!ggml_is_permuted(src0));
GGML_ASSERT(src0->backend != GGML_BACKEND_GPU_SPLIT);
GGML_ASSERT(src0->backend != GGML_BACKEND_TYPE_GPU_SPLIT);
GGML_ASSERT(src0->type == GGML_TYPE_F16);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
@ -10109,7 +10109,7 @@ static void ggml_cuda_mul_mat_batched_cublas(const ggml_tensor * src0, const ggm
GGML_ASSERT(!ggml_is_transposed(src0));
GGML_ASSERT(!ggml_is_transposed(src1));
GGML_ASSERT(src0->backend != GGML_BACKEND_GPU_SPLIT);
GGML_ASSERT(src0->backend != GGML_BACKEND_TYPE_GPU_SPLIT);
GGML_ASSERT(src0->type == GGML_TYPE_F16);
GGML_TENSOR_BINARY_OP_LOCALS
@ -10255,11 +10255,11 @@ static void ggml_cuda_mul_mat_batched_cublas(const ggml_tensor * src0, const ggm
static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
const bool all_on_device =
(src0->backend == GGML_BACKEND_GPU || src0->backend == GGML_BACKEND_GPU_SPLIT) &&
(src1->backend == GGML_BACKEND_GPU) &&
( dst->backend == GGML_BACKEND_GPU);
(src0->backend == GGML_BACKEND_TYPE_GPU || src0->backend == GGML_BACKEND_TYPE_GPU_SPLIT) &&
(src1->backend == GGML_BACKEND_TYPE_GPU) &&
( dst->backend == GGML_BACKEND_TYPE_GPU);
const bool split = src0->backend == GGML_BACKEND_GPU_SPLIT;
const bool split = src0->backend == GGML_BACKEND_TYPE_GPU_SPLIT;
int64_t min_compute_capability = INT_MAX;
@ -10409,7 +10409,7 @@ static void ggml_cuda_mul_mat_id_cublas(ggml_tensor * dst) {
GGML_ASSERT(!ggml_is_transposed(src00));
GGML_ASSERT(!ggml_is_transposed(src1));
GGML_ASSERT(src00->backend != GGML_BACKEND_GPU_SPLIT);
GGML_ASSERT(src00->backend != GGML_BACKEND_TYPE_GPU_SPLIT);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
const int64_t ne00 = src00->ne[0]; GGML_UNUSED(ne00);
@ -10553,7 +10553,7 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s
cudaStream_t stream = g_cudaStreams[g_main_device][0];
if (ids->backend == GGML_BACKEND_GPU) {
if (ids->backend == GGML_BACKEND_TYPE_GPU) {
const char * ids_dev = (const char *)((const ggml_tensor_extra_gpu *)ids->extra)->data_device[g_main_device];
CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids_dev, ggml_nbytes(ids), cudaMemcpyDeviceToHost, stream));
CUDA_CHECK(cudaStreamSynchronize(stream));
@ -10570,20 +10570,20 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s
ggml_tensor src1_row = *src1;
ggml_tensor dst_row = *dst;
src1_row.backend = GGML_BACKEND_GPU;
dst_row.backend = GGML_BACKEND_GPU;
src1_row.backend = GGML_BACKEND_TYPE_GPU;
dst_row.backend = GGML_BACKEND_TYPE_GPU;
src1_row.extra = &src1_row_extra;
dst_row.extra = &dst_row_extra;
char * src1_original = src1->backend == GGML_BACKEND_CPU ?
char * src1_original = src1->backend == GGML_BACKEND_TYPE_CPU ?
(char *) src1->data : (char *) src1_extra->data_device[g_main_device];
char * dst_original = dst->backend == GGML_BACKEND_CPU ?
char * dst_original = dst->backend == GGML_BACKEND_TYPE_CPU ?
(char *) dst->data : (char *) dst_extra->data_device[g_main_device];
if (src1->ne[1] == 1) {
GGML_ASSERT(src1->backend == GGML_BACKEND_GPU);
GGML_ASSERT(dst->backend == GGML_BACKEND_GPU);
GGML_ASSERT(src1->backend == GGML_BACKEND_TYPE_GPU);
GGML_ASSERT(dst->backend == GGML_BACKEND_TYPE_GPU);
for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
//int32_t row_id;
@ -10611,9 +10611,9 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s
src1_row_extra.data_device[g_main_device] = src1_contiguous.get();
dst_row_extra.data_device[g_main_device] = dst_contiguous.get();
const cudaMemcpyKind src1_kind = src1->backend == GGML_BACKEND_CPU ?
const cudaMemcpyKind src1_kind = src1->backend == GGML_BACKEND_TYPE_CPU ?
cudaMemcpyHostToDevice : cudaMemcpyDeviceToDevice;
const cudaMemcpyKind dst_kind = dst->backend == GGML_BACKEND_CPU ?
const cudaMemcpyKind dst_kind = dst->backend == GGML_BACKEND_TYPE_CPU ?
cudaMemcpyDeviceToHost : cudaMemcpyDeviceToDevice;
for (int32_t row_id = 0; row_id < n_as; ++row_id) {
@ -10668,7 +10668,7 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s
}
}
if (dst->backend == GGML_BACKEND_CPU) {
if (dst->backend == GGML_BACKEND_TYPE_CPU) {
CUDA_CHECK(cudaStreamSynchronize(stream));
}
}
@ -10685,8 +10685,8 @@ static void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, gg
const int64_t ne = ggml_nelements(src0);
GGML_ASSERT(ne == ggml_nelements(src1));
GGML_ASSERT(src0->backend == GGML_BACKEND_GPU);
GGML_ASSERT(src1->backend == GGML_BACKEND_GPU);
GGML_ASSERT(src0->backend == GGML_BACKEND_TYPE_GPU);
GGML_ASSERT(src1->backend == GGML_BACKEND_TYPE_GPU);
GGML_ASSERT(ggml_nbytes(src0) <= INT_MAX);
GGML_ASSERT(ggml_nbytes(src1) <= INT_MAX);
@ -10817,9 +10817,9 @@ GGML_CALL bool ggml_cuda_compute_forward(struct ggml_compute_params * params, st
if (!g_cublas_loaded) return false;
ggml_cuda_func_t func;
const bool any_on_device = tensor->backend == GGML_BACKEND_GPU
|| (tensor->src[0] != nullptr && (tensor->src[0]->backend == GGML_BACKEND_GPU || tensor->src[0]->backend == GGML_BACKEND_GPU_SPLIT))
|| (tensor->src[1] != nullptr && tensor->src[1]->backend == GGML_BACKEND_GPU);
const bool any_on_device = tensor->backend == GGML_BACKEND_TYPE_GPU
|| (tensor->src[0] != nullptr && (tensor->src[0]->backend == GGML_BACKEND_TYPE_GPU || tensor->src[0]->backend == GGML_BACKEND_TYPE_GPU_SPLIT))
|| (tensor->src[1] != nullptr && tensor->src[1]->backend == GGML_BACKEND_TYPE_GPU);
if (!any_on_device && tensor->op != GGML_OP_MUL_MAT && tensor->op != GGML_OP_MUL_MAT_ID) {
return false;
@ -10966,14 +10966,14 @@ GGML_CALL bool ggml_cuda_compute_forward(struct ggml_compute_params * params, st
return false;
}
if (tensor->src[0] != nullptr && tensor->src[0]->backend == GGML_BACKEND_GPU_SPLIT) {
if (tensor->src[0] != nullptr && tensor->src[0]->backend == GGML_BACKEND_TYPE_GPU_SPLIT) {
ggml_cuda_set_peer_access(tensor->src[1]->ne[1]);
}
if (params->ith != 0) {
return true;
}
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
return true;
}
func(tensor->src[0], tensor->src[1], tensor);
@ -11072,7 +11072,7 @@ GGML_CALL static void ggml_backend_cuda_buffer_init_tensor(ggml_backend_buffer_t
extra->data_device[ctx->device] = tensor->data;
tensor->backend = GGML_BACKEND_GPU;
tensor->backend = GGML_BACKEND_TYPE_GPU;
tensor->extra = extra;
if (ggml_is_quantized(tensor->type)) {
@ -11087,7 +11087,7 @@ GGML_CALL static void ggml_backend_cuda_buffer_init_tensor(ggml_backend_buffer_t
}
GGML_CALL static void ggml_backend_cuda_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
GGML_ASSERT(tensor->backend == GGML_BACKEND_GPU);
GGML_ASSERT(tensor->backend == GGML_BACKEND_TYPE_GPU);
ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context;
@ -11098,7 +11098,7 @@ GGML_CALL static void ggml_backend_cuda_buffer_set_tensor(ggml_backend_buffer_t
}
GGML_CALL static void ggml_backend_cuda_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
GGML_ASSERT(tensor->backend == GGML_BACKEND_GPU);
GGML_ASSERT(tensor->backend == GGML_BACKEND_TYPE_GPU);
ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context;
@ -11333,7 +11333,7 @@ GGML_CALL static void ggml_backend_cuda_split_buffer_init_tensor(ggml_backend_bu
CUDA_CHECK(cudaEventCreateWithFlags(&extra->events[id][is], cudaEventDisableTiming));
}
}
tensor->backend = GGML_BACKEND_GPU_SPLIT;
tensor->backend = GGML_BACKEND_TYPE_GPU_SPLIT;
tensor->extra = extra;
}
@ -11605,7 +11605,7 @@ GGML_CALL static void ggml_backend_cuda_set_tensor_async(ggml_backend_t backend,
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
GGML_ASSERT(tensor->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) && "unsupported buffer type");
GGML_ASSERT(tensor->backend == GGML_BACKEND_GPU);
GGML_ASSERT(tensor->backend == GGML_BACKEND_TYPE_GPU);
CUDA_CHECK(cudaMemcpyAsync((char *)tensor->data + offset, data, size, cudaMemcpyHostToDevice, g_cudaStreams[cuda_ctx->device][0]));
}
@ -11614,7 +11614,7 @@ GGML_CALL static void ggml_backend_cuda_get_tensor_async(ggml_backend_t backend,
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
GGML_ASSERT(tensor->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) && "unsupported buffer type");
GGML_ASSERT(tensor->backend == GGML_BACKEND_GPU);
GGML_ASSERT(tensor->backend == GGML_BACKEND_TYPE_GPU);
CUDA_CHECK(cudaMemcpyAsync(data, (const char *)tensor->data + offset, size, cudaMemcpyDeviceToHost, g_cudaStreams[cuda_ctx->device][0]));
}
@ -11644,7 +11644,7 @@ GGML_CALL static bool ggml_backend_cuda_graph_compute(ggml_backend_t backend, gg
ggml_cuda_set_main_device(cuda_ctx->device);
ggml_compute_params params = {};
params.type = GGML_TASK_COMPUTE;
params.type = GGML_TASK_TYPE_COMPUTE;
params.ith = 0;
for (int i = 0; i < cgraph->n_nodes; i++) {
ggml_tensor * node = cgraph->nodes[i];
@ -11654,13 +11654,13 @@ GGML_CALL static bool ggml_backend_cuda_graph_compute(ggml_backend_t backend, gg
}
#ifndef NDEBUG
assert(node->backend == GGML_BACKEND_GPU || node->backend == GGML_BACKEND_GPU_SPLIT);
assert(node->backend == GGML_BACKEND_TYPE_GPU || node->backend == GGML_BACKEND_TYPE_GPU_SPLIT);
assert(node->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device));
assert(node->extra != nullptr);
for (int j = 0; j < GGML_MAX_SRC; j++) {
if (node->src[j] != nullptr) {
assert(node->src[j]->backend == GGML_BACKEND_GPU || node->src[j]->backend == GGML_BACKEND_GPU_SPLIT);
assert(node->src[j]->backend == GGML_BACKEND_TYPE_GPU || node->src[j]->backend == GGML_BACKEND_TYPE_GPU_SPLIT);
assert(node->src[j]->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) || ggml_backend_buffer_is_cuda_split(node->src[j]->buffer));
assert(node->src[j]->extra != nullptr);
}

View File

@ -2262,8 +2262,8 @@ static bool ggml_metal_graph_compute(
id<MTLComputePipelineState> pipeline = nil;
switch (order) {
case GGML_SORT_ASC: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC].pipeline; break;
case GGML_SORT_DESC: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC].pipeline; break;
case GGML_SORT_ORDER_ASC: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC].pipeline; break;
case GGML_SORT_ORDER_DESC: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC].pipeline; break;
default: GGML_ASSERT(false);
};

View File

@ -1354,7 +1354,7 @@ static void ggml_cl_pool_free(cl_mem mem, size_t size) {
}
void ggml_cl_free_data(const struct ggml_tensor* tensor) {
if (tensor->backend != GGML_BACKEND_GPU) {
if (tensor->backend != GGML_BACKEND_TYPE_GPU) {
return;
}
@ -1412,7 +1412,7 @@ static cl_int ggml_cl_h2d_tensor_2d(cl_command_queue queue, cl_mem dst, size_t o
}
static void ggml_cl_mul_f32(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
GGML_ASSERT(src1->backend == GGML_BACKEND_GPU);
GGML_ASSERT(src1->backend == GGML_BACKEND_TYPE_GPU);
const int64_t ne00 = src0->ne[0];
const int64_t ne01 = src0->ne[1];
const int64_t ne02 = src0->ne[2];
@ -1476,7 +1476,7 @@ void ggml_cl_mul(const struct ggml_tensor * src0, const struct ggml_tensor * src
}
static void ggml_cl_add_f32(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
GGML_ASSERT(src1->backend == GGML_BACKEND_GPU);
GGML_ASSERT(src1->backend == GGML_BACKEND_TYPE_GPU);
const int64_t ne00 = src0->ne[0];
const int64_t ne01 = src0->ne[1];
const int64_t ne02 = src0->ne[2];
@ -1566,13 +1566,13 @@ static void ggml_cl_mul_mat_f32(const ggml_tensor * src0, const ggml_tensor * sr
size_t y_size;
size_t d_size;
cl_mem d_X;
if (src0->backend == GGML_BACKEND_GPU) { // NOLINT
if (src0->backend == GGML_BACKEND_TYPE_GPU) { // NOLINT
d_X = (cl_mem) src0->extra;
} else {
d_X = ggml_cl_pool_malloc(sizeof(float) * x_ne, &x_size);
}
cl_mem d_Y = src1->backend == GGML_BACKEND_GPU ? (cl_mem) src1->extra : ggml_cl_pool_malloc(sizeof(float) * y_ne, &y_size);
cl_mem d_D = dst->backend == GGML_BACKEND_GPU ? (cl_mem) dst->extra : ggml_cl_pool_malloc(sizeof(float) * d_ne, &d_size);
cl_mem d_Y = src1->backend == GGML_BACKEND_TYPE_GPU ? (cl_mem) src1->extra : ggml_cl_pool_malloc(sizeof(float) * y_ne, &y_size);
cl_mem d_D = dst->backend == GGML_BACKEND_TYPE_GPU ? (cl_mem) dst->extra : ggml_cl_pool_malloc(sizeof(float) * d_ne, &d_size);
size_t x_offset = 0;
@ -1580,7 +1580,7 @@ static void ggml_cl_mul_mat_f32(const ggml_tensor * src0, const ggml_tensor * sr
// TODO: copy src0 here when r3>1
for (int64_t i13 = i03 * r3, e13 = i13 + r3; i13 < e13; i13++) {
for (int64_t i02 = 0; i02 < ne02; i02++) {
if (src0->backend == GGML_BACKEND_GPU) {
if (src0->backend == GGML_BACKEND_TYPE_GPU) {
x_offset = (i03 * ne02 + i02) * x_ne;
} else {
// copy src0 to device
@ -1589,7 +1589,7 @@ static void ggml_cl_mul_mat_f32(const ggml_tensor * src0, const ggml_tensor * sr
for (int64_t i12 = i02 * r2, e12 = i12 + r2; i12 < e12; i12++) {
// copy src1 to device
if (src1->backend == GGML_BACKEND_CPU) {
if (src1->backend == GGML_BACKEND_TYPE_CPU) {
CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_Y, 0, src1, i13, i12, NULL));
}
@ -1612,7 +1612,7 @@ static void ggml_cl_mul_mat_f32(const ggml_tensor * src0, const ggml_tensor * sr
}
// copy dst to host
if (dst->backend == GGML_BACKEND_CPU) {
if (dst->backend == GGML_BACKEND_TYPE_CPU) {
float * d = (float *) ((char *) dst->data + i12*nb2 + i13*nb3);
CL_CHECK(clEnqueueReadBuffer(queue, d_D, true, 0, sizeof(float) * d_ne, d, 1, &ev_sgemm, NULL));
}
@ -1621,13 +1621,13 @@ static void ggml_cl_mul_mat_f32(const ggml_tensor * src0, const ggml_tensor * sr
}
}
if (src0->backend != GGML_BACKEND_GPU) {
if (src0->backend != GGML_BACKEND_TYPE_GPU) {
ggml_cl_pool_free(d_X, x_size);
}
if (src1->backend != GGML_BACKEND_GPU) {
if (src1->backend != GGML_BACKEND_TYPE_GPU) {
ggml_cl_pool_free(d_Y, y_size);
}
if (dst->backend != GGML_BACKEND_GPU) {
if (dst->backend != GGML_BACKEND_TYPE_GPU) {
ggml_cl_pool_free(d_D, d_size);
}
}
@ -1670,7 +1670,7 @@ static void ggml_cl_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * sr
size_t y_size;
size_t d_size;
cl_mem d_X;
if (src0->backend == GGML_BACKEND_GPU) { // NOLINT
if (src0->backend == GGML_BACKEND_TYPE_GPU) { // NOLINT
d_X = (cl_mem) src0->extra;
} else {
d_X = ggml_cl_pool_malloc(sizeof(ggml_fp16_t) * x_ne, &x_size);
@ -1687,7 +1687,7 @@ static void ggml_cl_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * sr
// TODO: copy src0 here when r3>1
for (int64_t i13 = i03 * r3, e13 = i13 + r3; i13 < e13; i13++) {
for (int64_t i02 = 0; i02 < ne02; i02++) {
if (src0->backend == GGML_BACKEND_GPU) {
if (src0->backend == GGML_BACKEND_TYPE_GPU) {
x_offset = (i03 * ne02 + i02) * x_ne;
} else {
// copy src0 to device
@ -1741,7 +1741,7 @@ static void ggml_cl_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * sr
}
// copy dst to host, then convert to float
if (dst->backend == GGML_BACKEND_CPU) {
if (dst->backend == GGML_BACKEND_TYPE_CPU) {
CL_CHECK(clEnqueueReadBuffer(queue, d_D, true, 0, sizeof(ggml_fp16_t) * d_ne, tmp, 1, &ev_sgemm, NULL));
float * d = (float *) ((char *) dst->data + i12*nb2 + i13*nb3);
ggml_fp16_to_fp32_row(tmp, d, d_ne);
@ -1753,7 +1753,7 @@ static void ggml_cl_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * sr
}
}
if (src0->backend != GGML_BACKEND_GPU) {
if (src0->backend != GGML_BACKEND_TYPE_GPU) {
ggml_cl_pool_free(d_X, x_size);
}
ggml_cl_pool_free(d_Y, y_size);
@ -1798,7 +1798,7 @@ static void ggml_cl_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor *
cl_mem d_Y = ggml_cl_pool_malloc(sizeof(float) * y_ne, &y_size);
cl_mem d_D = ggml_cl_pool_malloc(sizeof(float) * d_ne, &d_size);
cl_mem d_Q;
if (src0->backend == GGML_BACKEND_CPU) {
if (src0->backend == GGML_BACKEND_TYPE_CPU) {
d_Q = ggml_cl_pool_malloc(q_sz, &q_size);
}
@ -1817,10 +1817,10 @@ static void ggml_cl_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor *
for (int64_t i13 = i03 * r3, e13 = i13 + r3; i13 < e13; i13++) {
for (int64_t i02 = 0; i02 < ne02; i02++) {
// copy src0 to device if necessary
if (src0->backend == GGML_BACKEND_CPU) {
if (src0->backend == GGML_BACKEND_TYPE_CPU) {
events.emplace_back();
CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_Q, 0, src0, i03, i02, events.data() + ev_idx++));
} else if (src0->backend == GGML_BACKEND_GPU) {
} else if (src0->backend == GGML_BACKEND_TYPE_GPU) {
d_Q = (cl_mem) src0->extra;
} else {
GGML_ASSERT(false);
@ -1829,7 +1829,7 @@ static void ggml_cl_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor *
if (!mul_mat_vec) {
// convert src0 to fp32 on device
const size_t global = x_ne / global_denom;
const size_t offset = src0->backend == GGML_BACKEND_GPU ? (i03 * ne02 + i02) * x_bps : 0;
const size_t offset = src0->backend == GGML_BACKEND_TYPE_GPU ? (i03 * ne02 + i02) * x_bps : 0;
CL_CHECK(clSetKernelArg(*to_fp32_cl, 0, sizeof(cl_mem), &d_Q));
CL_CHECK(clSetKernelArg(*to_fp32_cl, 1, sizeof(cl_mem), &d_X));
CL_CHECK(clEnqueueNDRangeKernel(queue, *to_fp32_cl, 1, &offset, &global, local > 0 ? &local : NULL, events.size(), !events.empty() ? events.data() : NULL, NULL));
@ -1843,7 +1843,7 @@ static void ggml_cl_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor *
// compute
const size_t global = ne01 * local;
const size_t offset = src0->backend == GGML_BACKEND_GPU ? (i03 * ne02 + i02) * x_bps : 0;
const size_t offset = src0->backend == GGML_BACKEND_TYPE_GPU ? (i03 * ne02 + i02) * x_bps : 0;
const cl_int ncols = ne00;
events.emplace_back();
CL_CHECK(clSetKernelArg(*dmmv, 0, sizeof(cl_mem), &d_Q));
@ -1895,7 +1895,7 @@ static void ggml_cl_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor *
}
ggml_cl_pool_free(d_Y, y_size);
ggml_cl_pool_free(d_D, d_size);
if (src0->backend == GGML_BACKEND_CPU) {
if (src0->backend == GGML_BACKEND_TYPE_CPU) {
ggml_cl_pool_free(d_Q, q_size);
}
}
@ -1911,7 +1911,7 @@ bool ggml_cl_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tens
if ((src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) &&
src1->type == GGML_TYPE_F32 &&
dst->type == GGML_TYPE_F32 &&
((ne0 >= 32 && ne1 >= 32 && ne10 >= 32) || src0->backend == GGML_BACKEND_GPU)) {
((ne0 >= 32 && ne1 >= 32 && ne10 >= 32) || src0->backend == GGML_BACKEND_TYPE_GPU)) {
return true;
}
@ -1993,7 +1993,7 @@ void ggml_cl_transform_tensor(void * data, ggml_tensor * tensor) {
CL_CHECK(clFinish(queue));
tensor->extra = dst;
GGML_ASSERT(tensor->backend == GGML_BACKEND_GPU);
GGML_ASSERT(tensor->backend == GGML_BACKEND_TYPE_GPU);
}
// ggml-backend
@ -2045,7 +2045,7 @@ static void ggml_backend_opencl_buffer_init_tensor(ggml_backend_buffer_t buffer,
ctx->sub_buffers.push_back(sub_buffer);
tensor->extra = sub_buffer;
}
tensor->backend = GGML_BACKEND_GPU;
tensor->backend = GGML_BACKEND_TYPE_GPU;
}
static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {

View File

@ -3338,7 +3338,7 @@ void print_ggml_tensor(const char*name, struct ggml_tensor *src){
size_t total_elements = ggml_nelements(src);
const bool src_on_device = src->backend == GGML_BACKEND_GPU || src->backend == GGML_BACKEND_GPU_SPLIT;
const bool src_on_device = src->backend == GGML_BACKEND_TYPE_GPU || src->backend == GGML_BACKEND_TYPE_GPU_SPLIT;
float *src_data =NULL;
if(src_on_device) {
ggml_tensor_extra_gpu * src_extra = (ggml_tensor_extra_gpu *) src->extra;
@ -8086,11 +8086,11 @@ static void k_argsort_f32_i32(const float * x, int * dst, const int ncols,
int ixj = col ^ j;
if (ixj > col) {
if ((col & k) == 0) {
if (order == GGML_SORT_ASC ? x_row[dst_row[col]] > x_row[dst_row[ixj]] : x_row[dst_row[col]] < x_row[dst_row[ixj]]) {
if (order == GGML_SORT_ORDER_ASC ? x_row[dst_row[col]] > x_row[dst_row[ixj]] : x_row[dst_row[col]] < x_row[dst_row[ixj]]) {
swap(dst_row[col], dst_row[ixj]);
}
} else {
if (order == GGML_SORT_ASC ? x_row[dst_row[col]] < x_row[dst_row[ixj]] : x_row[dst_row[col]] > x_row[dst_row[ixj]]) {
if (order == GGML_SORT_ORDER_ASC ? x_row[dst_row[col]] < x_row[dst_row[ixj]] : x_row[dst_row[col]] > x_row[dst_row[ixj]]) {
swap(dst_row[col], dst_row[ixj]);
}
}
@ -10825,7 +10825,7 @@ static void argsort_f32_i32_sycl(const float *x, int *dst, const int ncols,
const sycl::range<3> block_dims(1, 1, ncols);
const sycl::range<3> block_nums(1, nrows, 1);
if (order == GGML_SORT_ASC) {
if (order == GGML_SORT_ORDER_ASC) {
/*
DPCT1049:44: The work-group size passed to the SYCL kernel may exceed
the limit. To get the device limit, query
@ -10834,9 +10834,9 @@ static void argsort_f32_i32_sycl(const float *x, int *dst, const int ncols,
stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
k_argsort_f32_i32<GGML_SORT_ASC>(x, dst, ncols, item_ct1);
k_argsort_f32_i32<GGML_SORT_ORDER_ASC>(x, dst, ncols, item_ct1);
});
} else if (order == GGML_SORT_DESC) {
} else if (order == GGML_SORT_ORDER_DESC) {
/*
DPCT1049:45: The work-group size passed to the SYCL kernel may exceed
the limit. To get the device limit, query
@ -10845,7 +10845,7 @@ static void argsort_f32_i32_sycl(const float *x, int *dst, const int ncols,
stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
k_argsort_f32_i32<GGML_SORT_DESC>(x, dst, ncols, item_ct1);
k_argsort_f32_i32<GGML_SORT_ORDER_DESC>(x, dst, ncols, item_ct1);
});
} else {
GGML_ASSERT(false);
@ -11407,12 +11407,12 @@ static dpct::err0 ggml_sycl_cpy_tensor_2d(void *dst,
dpct::memcpy_direction kind;
char * src_ptr;
if (src->backend == GGML_BACKEND_CPU) {
if (src->backend == GGML_BACKEND_TYPE_CPU) {
kind = dpct::host_to_device;
src_ptr = (char *) src->data;
// GGML_SYCL_DEBUG("ggml_sycl_cpy_tensor_2d GGML_BACKEND_CPU src_ptr %p\n", src_ptr);
} else if (src->backend == GGML_BACKEND_GPU || src->backend == GGML_BACKEND_GPU_SPLIT) {
GGML_ASSERT(src->backend != GGML_BACKEND_GPU_SPLIT || (i1_low == 0 && i1_high == src->ne[1]));
// GGML_SYCL_DEBUG("ggml_sycl_cpy_tensor_2d GGML_BACKEND_TYPE_CPU src_ptr %p\n", src_ptr);
} else if (src->backend == GGML_BACKEND_TYPE_GPU || src->backend == GGML_BACKEND_TYPE_GPU_SPLIT) {
GGML_ASSERT(src->backend != GGML_BACKEND_TYPE_GPU_SPLIT || (i1_low == 0 && i1_high == src->ne[1]));
kind = dpct::device_to_device;
ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) src->extra;
int id;
@ -11846,7 +11846,7 @@ inline void ggml_sycl_op_mul_mat_q(
// the main device has a larger memory buffer to hold the results from all GPUs
// nrows_dst == nrows of the matrix that the dequantize_mul_mat kernel writes into
const int64_t nrows_dst = dst->backend == GGML_BACKEND_GPU && device_id == g_main_device ? ne0 : row_diff;
const int64_t nrows_dst = dst->backend == GGML_BACKEND_TYPE_GPU && device_id == g_main_device ? ne0 : row_diff;
switch (src0->type) {
case GGML_TYPE_Q4_0:
@ -12119,7 +12119,7 @@ inline void ggml_sycl_op_mul_mat_sycl(
// the main device has a larger memory buffer to hold the results from all GPUs
// ldc == nrows of the matrix that cuBLAS writes into
int ldc = dst->backend == GGML_BACKEND_GPU && device_id == g_main_device ? ne0 : row_diff;
int ldc = dst->backend == GGML_BACKEND_TYPE_GPU && device_id == g_main_device ? ne0 : row_diff;
#ifdef GGML_SYCL_F16
bool use_fp16 = true; // TODO(Yu) SYCL capability check
@ -12501,16 +12501,16 @@ static void ggml_sycl_op_flatten(const ggml_tensor *src0,
const bool use_src1 = src1 != nullptr;
const int64_t nrows1 = use_src1 ? ggml_nrows(src1) : 1;
GGML_ASSERT(!use_src1 || src1->backend != GGML_BACKEND_GPU_SPLIT);
GGML_ASSERT( dst->backend != GGML_BACKEND_GPU_SPLIT);
GGML_ASSERT(!use_src1 || src1->backend != GGML_BACKEND_TYPE_GPU_SPLIT);
GGML_ASSERT( dst->backend != GGML_BACKEND_TYPE_GPU_SPLIT);
ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
ggml_tensor_extra_gpu * src1_extra = use_src1 ? (ggml_tensor_extra_gpu *) src1->extra : nullptr;
ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
const bool src0_on_device = src0->backend == GGML_BACKEND_GPU || src0->backend == GGML_BACKEND_GPU_SPLIT;
const bool src1_on_device = use_src1 && src1->backend == GGML_BACKEND_GPU;
const bool dst_on_device = dst->backend == GGML_BACKEND_GPU;
const bool src0_on_device = src0->backend == GGML_BACKEND_TYPE_GPU || src0->backend == GGML_BACKEND_TYPE_GPU_SPLIT;
const bool src1_on_device = use_src1 && src1->backend == GGML_BACKEND_TYPE_GPU;
const bool dst_on_device = dst->backend == GGML_BACKEND_TYPE_GPU;
// dd = data device
float * src0_ddf = nullptr;
@ -12565,7 +12565,7 @@ static void ggml_sycl_op_flatten(const ggml_tensor *src0,
main_stream->memcpy(dst->data, dst_ddf, ggml_nbytes(dst))));
}
if (dst->backend == GGML_BACKEND_CPU) {
if (dst->backend == GGML_BACKEND_TYPE_CPU) {
SYCL_CHECK(CHECK_TRY_ERROR(
dpct::get_current_device().queues_wait_and_throw()));
}
@ -12640,8 +12640,8 @@ static void ggml_sycl_op_mul_mat(const ggml_tensor *src0,
const int nb2 = dst->nb[2];
const int nb3 = dst->nb[3];
GGML_ASSERT(dst->backend != GGML_BACKEND_GPU_SPLIT);
GGML_ASSERT(src1->backend != GGML_BACKEND_GPU_SPLIT);
GGML_ASSERT(dst->backend != GGML_BACKEND_TYPE_GPU_SPLIT);
GGML_ASSERT(src1->backend != GGML_BACKEND_TYPE_GPU_SPLIT);
GGML_ASSERT(ne12 >= ne02 && ne12 % ne02 == 0);
@ -12656,13 +12656,13 @@ static void ggml_sycl_op_mul_mat(const ggml_tensor *src0,
ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra;
ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
const bool src0_on_device = src0->backend == GGML_BACKEND_GPU || src0->backend == GGML_BACKEND_GPU_SPLIT;
const bool src0_on_device = src0->backend == GGML_BACKEND_TYPE_GPU || src0->backend == GGML_BACKEND_TYPE_GPU_SPLIT;
const bool src0_is_contiguous = ggml_is_contiguous(src0);
const bool src1_is_contiguous = ggml_is_contiguous(src1);
int64_t src1_padded_col_size = GGML_PAD(ne10, MATRIX_ROW_PADDING);
const bool split = src0->backend == GGML_BACKEND_GPU_SPLIT;
const bool split = src0->backend == GGML_BACKEND_TYPE_GPU_SPLIT;
GGML_ASSERT(!(split && ne02 > 1));
GGML_ASSERT(!(split && ne03 > 1));
GGML_ASSERT(!(split && ne02 < ne12));
@ -12717,8 +12717,8 @@ static void ggml_sycl_op_mul_mat(const ggml_tensor *src0,
used_devices++;
const bool src1_on_device = src1->backend == GGML_BACKEND_GPU && id == g_main_device_index;
const bool dst_on_device = dst->backend == GGML_BACKEND_GPU && id == g_main_device_index;
const bool src1_on_device = src1->backend == GGML_BACKEND_TYPE_GPU && id == g_main_device_index;
const bool dst_on_device = dst->backend == GGML_BACKEND_TYPE_GPU && id == g_main_device_index;
ggml_sycl_set_device(get_device_id_by_index(id));
const dpct::queue_ptr stream = g_syclStreams[id][0];
@ -12782,8 +12782,8 @@ static void ggml_sycl_op_mul_mat(const ggml_tensor *src0,
continue;
}
const bool src1_on_device = src1->backend == GGML_BACKEND_GPU && id == g_main_device_index;
const bool dst_on_device = dst->backend == GGML_BACKEND_GPU && id == g_main_device_index;
const bool src1_on_device = src1->backend == GGML_BACKEND_TYPE_GPU && id == g_main_device_index;
const bool dst_on_device = dst->backend == GGML_BACKEND_TYPE_GPU && id == g_main_device_index;
const int64_t row_diff = row_high[id] - row_low[id];
ggml_sycl_set_device(get_device_id_by_index(id));
@ -12809,12 +12809,12 @@ static void ggml_sycl_op_mul_mat(const ggml_tensor *src0,
// the main device memory buffer can be on VRAM scratch, with space for all partial results
// in that case an offset on dst_ddf_i is needed
if (dst->backend == GGML_BACKEND_GPU && id == g_main_device_index) {
if (dst->backend == GGML_BACKEND_TYPE_GPU && id == g_main_device_index) {
dst_dd_i += row_low[id]; // offset is 0 if no tensor split
}
// copy src0, src1 to device if necessary
if (src1->backend == GGML_BACKEND_GPU && src1_is_contiguous) {
if (src1->backend == GGML_BACKEND_TYPE_GPU && src1_is_contiguous) {
if (id != g_main_device_index) {
if (convert_src1_to_q8_1) {
char * src1_ddq_i_source = src1_ddq[g_main_device_index] + src1_ddq_i_offset;
@ -12830,14 +12830,14 @@ static void ggml_sycl_op_mul_mat(const ggml_tensor *src0,
src1_ncols * ne10 * sizeof(float))));
}
}
} else if (src1->backend == GGML_BACKEND_CPU || (src1_on_device && !src1_is_contiguous)) {
} else if (src1->backend == GGML_BACKEND_TYPE_CPU || (src1_on_device && !src1_is_contiguous)) {
SYCL_CHECK(ggml_sycl_cpy_tensor_2d(
src1_ddf_i, src1, i03, i02, src1_col_0, src1_col_0+src1_ncols, stream));
} else {
GGML_ASSERT(false);
}
if (convert_src1_to_q8_1 && (src1->backend == GGML_BACKEND_CPU || !src1_is_contiguous)) {
if (convert_src1_to_q8_1 && (src1->backend == GGML_BACKEND_TYPE_CPU || !src1_is_contiguous)) {
quantize_row_q8_1_sycl(src1_ddf_i, src1_ddq_i, ne10, src1_ncols, src1_padded_col_size, stream);
/*
DPCT1010:92: SYCL uses exceptions to report errors and does
@ -12867,10 +12867,10 @@ static void ggml_sycl_op_mul_mat(const ggml_tensor *src0,
if (!dst_on_device) {
void * dst_off_device;
dpct::memcpy_direction kind;
if (dst->backend == GGML_BACKEND_CPU) {
if (dst->backend == GGML_BACKEND_TYPE_CPU) {
dst_off_device = dst->data;
kind = dpct::device_to_host;
} else if (dst->backend == GGML_BACKEND_GPU) {
} else if (dst->backend == GGML_BACKEND_TYPE_GPU) {
dst_off_device = dst_extra->data_device[g_main_device_index];
kind = dpct::device_to_device;
} else {
@ -12954,7 +12954,7 @@ static void ggml_sycl_op_mul_mat(const ggml_tensor *src0,
}
}
if (dst->backend == GGML_BACKEND_CPU) {
if (dst->backend == GGML_BACKEND_TYPE_CPU) {
SYCL_CHECK(ggml_sycl_set_device(g_main_device));
SYCL_CHECK(CHECK_TRY_ERROR(
dpct::get_current_device().queues_wait_and_throw()));
@ -13091,7 +13091,7 @@ static void ggml_sycl_mul_mat_vec_p021(const ggml_tensor *src0,
const ggml_tensor *src1,
ggml_tensor *dst) try {
GGML_ASSERT(ggml_is_permuted(src0) && ggml_is_permuted(src1));
GGML_ASSERT(src0->backend != GGML_BACKEND_GPU_SPLIT);
GGML_ASSERT(src0->backend != GGML_BACKEND_TYPE_GPU_SPLIT);
GGML_ASSERT(src0->nb[0] <= src0->nb[1] && src0->nb[2] <= src0->nb[3]); // 0213 permutation
GGML_ASSERT(src1->nb[0] <= src1->nb[1] && src1->nb[2] <= src1->nb[3]); // 0213 permutation
GGML_ASSERT(src0->type == GGML_TYPE_F16);
@ -13129,7 +13129,7 @@ static void ggml_sycl_mul_mat_vec_nc(const ggml_tensor *src0,
GGML_ASSERT(!ggml_is_transposed(src0));
GGML_ASSERT(!ggml_is_transposed(src1));
GGML_ASSERT(!ggml_is_permuted(src0));
GGML_ASSERT(src0->backend != GGML_BACKEND_GPU_SPLIT);
GGML_ASSERT(src0->backend != GGML_BACKEND_TYPE_GPU_SPLIT);
GGML_ASSERT(src0->type == GGML_TYPE_F16);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
@ -13196,7 +13196,7 @@ static void ggml_sycl_mul_mat_mat_batched_sycl(const ggml_tensor *src0,
GGML_ASSERT(!ggml_is_transposed(src0));
GGML_ASSERT(!ggml_is_transposed(src1));
GGML_ASSERT(src0->backend != GGML_BACKEND_GPU_SPLIT);
GGML_ASSERT(src0->backend != GGML_BACKEND_TYPE_GPU_SPLIT);
GGML_ASSERT(src0->type == GGML_TYPE_F16);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
@ -13372,11 +13372,11 @@ catch (sycl::exception const &exc) {
static void ggml_sycl_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
const bool all_on_device =
(src0->backend == GGML_BACKEND_GPU || src0->backend == GGML_BACKEND_GPU_SPLIT) &&
(src1->backend == GGML_BACKEND_GPU) &&
( dst->backend == GGML_BACKEND_GPU);
(src0->backend == GGML_BACKEND_TYPE_GPU || src0->backend == GGML_BACKEND_TYPE_GPU_SPLIT) &&
(src1->backend == GGML_BACKEND_TYPE_GPU) &&
( dst->backend == GGML_BACKEND_TYPE_GPU);
const bool split = src0->backend == GGML_BACKEND_GPU_SPLIT;
const bool split = src0->backend == GGML_BACKEND_TYPE_GPU_SPLIT;
int64_t min_compute_capability = INT_MAX;
for (int64_t id = 0; id < g_device_count; ++id) {
@ -13505,7 +13505,7 @@ static void ggml_sycl_mul_mat_id_sycl(ggml_tensor * dst) {
GGML_ASSERT(!ggml_is_transposed(src00));
GGML_ASSERT(!ggml_is_transposed(src1));
GGML_ASSERT(src00->backend != GGML_BACKEND_GPU_SPLIT);
GGML_ASSERT(src00->backend != GGML_BACKEND_TYPE_GPU_SPLIT);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_TENSOR_LOCALS(int64_t, ne0, src00, ne);
@ -13643,7 +13643,7 @@ static void ggml_sycl_mul_mat_id(const ggml_tensor *src0,
const dpct::queue_ptr stream = g_syclStreams[g_main_device_index][0];
if (ids->backend == GGML_BACKEND_GPU) {
if (ids->backend == GGML_BACKEND_TYPE_GPU) {
const char * ids_dev = (const char *)((const ggml_tensor_extra_gpu *)ids->extra)->data_device[g_main_device_index];
SYCL_CHECK(CHECK_TRY_ERROR(
stream->memcpy(ids_host.data(), ids_dev, ggml_nbytes(ids))));
@ -13661,20 +13661,20 @@ static void ggml_sycl_mul_mat_id(const ggml_tensor *src0,
ggml_tensor src1_row = *src1;
ggml_tensor dst_row = *dst;
src1_row.backend = GGML_BACKEND_GPU;
dst_row.backend = GGML_BACKEND_GPU;
src1_row.backend = GGML_BACKEND_TYPE_GPU;
dst_row.backend = GGML_BACKEND_TYPE_GPU;
src1_row.extra = &src1_row_extra;
dst_row.extra = &dst_row_extra;
char * src1_original = src1->backend == GGML_BACKEND_CPU ?
char * src1_original = src1->backend == GGML_BACKEND_TYPE_CPU ?
(char *) src1->data : (char *) src1_extra->data_device[g_main_device_index];
char * dst_original = dst->backend == GGML_BACKEND_CPU ?
char * dst_original = dst->backend == GGML_BACKEND_TYPE_CPU ?
(char *) dst->data : (char *) dst_extra->data_device[g_main_device_index];
if (src1->ne[1] == 1) {
GGML_ASSERT(src1->backend == GGML_BACKEND_GPU);
GGML_ASSERT(dst->backend == GGML_BACKEND_GPU);
GGML_ASSERT(src1->backend == GGML_BACKEND_TYPE_GPU);
GGML_ASSERT(dst->backend == GGML_BACKEND_TYPE_GPU);
for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
//int32_t row_id;
@ -13756,7 +13756,7 @@ static void ggml_sycl_mul_mat_id(const ggml_tensor *src0,
}
}
if (dst->backend == GGML_BACKEND_CPU) {
if (dst->backend == GGML_BACKEND_TYPE_CPU) {
SYCL_CHECK(CHECK_TRY_ERROR(stream->wait()));
}
}
@ -13779,8 +13779,8 @@ static void ggml_sycl_cpy(const ggml_tensor *src0, const ggml_tensor *src1,
const int64_t ne = ggml_nelements(src0);
GGML_ASSERT(ne == ggml_nelements(src1));
GGML_ASSERT(src0->backend == GGML_BACKEND_GPU);
GGML_ASSERT(src1->backend == GGML_BACKEND_GPU);
GGML_ASSERT(src0->backend == GGML_BACKEND_TYPE_GPU);
GGML_ASSERT(src1->backend == GGML_BACKEND_TYPE_GPU);
GGML_ASSERT(ggml_nbytes(src0) <= INT_MAX);
GGML_ASSERT(ggml_nbytes(src1) <= INT_MAX);
@ -13887,17 +13887,17 @@ void ggml_sycl_transform_tensor(void *data, struct ggml_tensor *tensor) try {
memset(extra, 0, sizeof(*extra));
for (int64_t id = 0; id < g_device_count; ++id) {
if (backend == GGML_BACKEND_GPU && id != g_main_device_index) {
if (backend == GGML_BACKEND_TYPE_GPU && id != g_main_device_index) {
continue;
}
ggml_sycl_set_device(get_device_id_by_index(id));
const dpct::queue_ptr stream = g_syclStreams[id][0];
int64_t row_low, row_high;
if (backend == GGML_BACKEND_GPU) {
if (backend == GGML_BACKEND_TYPE_GPU) {
row_low = 0;
row_high = nrows;
} else if (backend == GGML_BACKEND_GPU_SPLIT) {
} else if (backend == GGML_BACKEND_TYPE_GPU_SPLIT) {
const int64_t rounding = get_row_rounding(tensor->type);
row_low = id == 0 ? 0 : nrows*g_tensor_split[id];
@ -13946,7 +13946,7 @@ void ggml_sycl_transform_tensor(void *data, struct ggml_tensor *tensor) try {
extra->data_device[id] = buf;
if (backend == GGML_BACKEND_GPU_SPLIT) {
if (backend == GGML_BACKEND_TYPE_GPU_SPLIT) {
for (int64_t is = 0; is < MAX_STREAMS; ++is) {
SYCL_CHECK(CHECK_TRY_ERROR(extra->events[id][is] =
new sycl::event()));
@ -13963,7 +13963,7 @@ catch (sycl::exception const &exc) {
}
void ggml_sycl_free_data(struct ggml_tensor *tensor) try {
if (!tensor || !tensor->extra || (tensor->backend != GGML_BACKEND_GPU && tensor->backend != GGML_BACKEND_GPU_SPLIT) ) {
if (!tensor || !tensor->extra || (tensor->backend != GGML_BACKEND_TYPE_GPU && tensor->backend != GGML_BACKEND_TYPE_GPU_SPLIT) ) {
return;
}
@ -14016,15 +14016,15 @@ static void ggml_sycl_assign_buffers_impl(struct ggml_tensor *tensor,
return;
}
tensor->backend = GGML_BACKEND_GPU;
tensor->backend = GGML_BACKEND_TYPE_GPU;
if (tensor->src[0] != nullptr && tensor->src[0]->backend == GGML_BACKEND_CPU) {
if (tensor->src[0] != nullptr && tensor->src[0]->backend == GGML_BACKEND_TYPE_CPU) {
const ggml_op src0_op = tensor->src[0]->op;
if (src0_op == GGML_OP_RESHAPE || src0_op == GGML_OP_TRANSPOSE || src0_op == GGML_OP_VIEW || src0_op == GGML_OP_PERMUTE) {
ggml_sycl_assign_buffers_impl(tensor->src[0], scratch, force_inplace, no_alloc);
}
}
if (tensor->op == GGML_OP_CPY && tensor->src[1]->backend == GGML_BACKEND_CPU) {
if (tensor->op == GGML_OP_CPY && tensor->src[1]->backend == GGML_BACKEND_TYPE_CPU) {
ggml_sycl_assign_buffers_impl(tensor->src[1], scratch, force_inplace, no_alloc);
}
@ -14042,7 +14042,7 @@ static void ggml_sycl_assign_buffers_impl(struct ggml_tensor *tensor,
SYCL_CHECK(ggml_sycl_set_device(g_main_device));
const dpct::queue_ptr stream = g_syclStreams[g_main_device_index][0];
if (inplace && (tensor->src[0]->backend == GGML_BACKEND_GPU || tensor->src[0]->backend == GGML_BACKEND_GPU_SPLIT)) {
if (inplace && (tensor->src[0]->backend == GGML_BACKEND_TYPE_GPU || tensor->src[0]->backend == GGML_BACKEND_TYPE_GPU_SPLIT)) {
ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu * ) tensor->src[0]->extra;
char * src0_ddc = (char *) src0_extra->data_device[g_main_device_index];
size_t offset = 0;
@ -14111,7 +14111,7 @@ void ggml_sycl_assign_scratch_offset(struct ggml_tensor *tensor,
const bool inplace = tensor->view_src != nullptr;
if (inplace && (tensor->view_src->backend == GGML_BACKEND_GPU || tensor->view_src->backend == GGML_BACKEND_GPU_SPLIT)) {
if (inplace && (tensor->view_src->backend == GGML_BACKEND_TYPE_GPU || tensor->view_src->backend == GGML_BACKEND_TYPE_GPU_SPLIT)) {
ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu * ) tensor->view_src->extra;
char * src0_ddc = (char *) src0_extra->data_device[g_main_device_index];
size_t view_offset = 0;
@ -14132,7 +14132,7 @@ catch (sycl::exception const &exc) {
}
void ggml_sycl_copy_to_device(struct ggml_tensor *tensor) try {
GGML_ASSERT(tensor->backend == GGML_BACKEND_GPU);
GGML_ASSERT(tensor->backend == GGML_BACKEND_TYPE_GPU);
GGML_ASSERT(ggml_is_contiguous(tensor));
ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) tensor->extra;
@ -14219,9 +14219,9 @@ bool ggml_sycl_compute_forward(struct ggml_compute_params * params, struct ggml_
if (!g_sycl_loaded) return false;
ggml_sycl_func_t func;
const bool any_on_device = tensor->backend == GGML_BACKEND_GPU
|| (tensor->src[0] != nullptr && (tensor->src[0]->backend == GGML_BACKEND_GPU || tensor->src[0]->backend == GGML_BACKEND_GPU_SPLIT))
|| (tensor->src[1] != nullptr && tensor->src[1]->backend == GGML_BACKEND_GPU);
const bool any_on_device = tensor->backend == GGML_BACKEND_TYPE_GPU
|| (tensor->src[0] != nullptr && (tensor->src[0]->backend == GGML_BACKEND_TYPE_GPU || tensor->src[0]->backend == GGML_BACKEND_TYPE_GPU_SPLIT))
|| (tensor->src[1] != nullptr && tensor->src[1]->backend == GGML_BACKEND_TYPE_GPU);
if (!any_on_device && tensor->op != GGML_OP_MUL_MAT && tensor->op != GGML_OP_MUL_MAT_ID) {
return false;
@ -14359,14 +14359,14 @@ bool ggml_sycl_compute_forward(struct ggml_compute_params * params, struct ggml_
return false;
}
if (tensor->src[0] != nullptr && tensor->src[0]->backend == GGML_BACKEND_GPU_SPLIT) {
if (tensor->src[0] != nullptr && tensor->src[0]->backend == GGML_BACKEND_TYPE_GPU_SPLIT) {
ggml_sycl_set_peer_access(tensor->src[1]->ne[1]);
}
if (params->ith != 0) {
return true;
}
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
return true;
}
func(tensor->src[0], tensor->src[1], tensor);
@ -14517,7 +14517,7 @@ static void ggml_backend_sycl_buffer_init_tensor(ggml_backend_buffer_t buffer,
extra->data_device[ctx->device] = tensor->data;
tensor->backend = GGML_BACKEND_GPU;
tensor->backend = GGML_BACKEND_TYPE_GPU;
tensor->extra = extra;
if (ggml_is_quantized(tensor->type)) {
@ -14548,7 +14548,7 @@ static void ggml_backend_sycl_buffer_set_tensor(ggml_backend_buffer_t buffer,
ggml_tensor *tensor,
const void *data, size_t offset,
size_t size) try {
GGML_ASSERT(tensor->backend == GGML_BACKEND_GPU);
GGML_ASSERT(tensor->backend == GGML_BACKEND_TYPE_GPU);
ggml_backend_sycl_buffer_context * ctx = ( ggml_backend_sycl_buffer_context *)buffer->context;
@ -14573,7 +14573,7 @@ static void ggml_backend_sycl_buffer_get_tensor(ggml_backend_buffer_t buffer,
const ggml_tensor *tensor,
void *data, size_t offset,
size_t size) try {
GGML_ASSERT(tensor->backend == GGML_BACKEND_GPU);
GGML_ASSERT(tensor->backend == GGML_BACKEND_TYPE_GPU);
ggml_backend_sycl_buffer_context * ctx = ( ggml_backend_sycl_buffer_context *)buffer->context;
@ -14809,7 +14809,7 @@ static void ggml_backend_sycl_set_tensor_async(ggml_backend_t backend,
ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context;
GGML_ASSERT(tensor->buffer->buft == ggml_backend_sycl_buffer_type(sycl_ctx->device) && "unsupported buffer type");
GGML_ASSERT(tensor->backend == GGML_BACKEND_GPU);
GGML_ASSERT(tensor->backend == GGML_BACKEND_TYPE_GPU);
SYCL_CHECK(CHECK_TRY_ERROR(g_syclStreams[sycl_ctx->device][0]->memcpy(
(char *)tensor->data + offset, data, size)));
@ -14827,7 +14827,7 @@ static void ggml_backend_sycl_get_tensor_async(ggml_backend_t backend,
ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context;
GGML_ASSERT(tensor->buffer->buft == ggml_backend_sycl_buffer_type(sycl_ctx->device) && "unsupported buffer type");
GGML_ASSERT(tensor->backend == GGML_BACKEND_GPU);
GGML_ASSERT(tensor->backend == GGML_BACKEND_TYPE_GPU);
SYCL_CHECK(CHECK_TRY_ERROR(g_syclStreams[sycl_ctx->device][0]->memcpy(
data, (const char *)tensor->data + offset, size)));
@ -14880,7 +14880,7 @@ static bool ggml_backend_sycl_graph_compute(ggml_backend_t backend, ggml_cgraph
ggml_sycl_set_main_device(sycl_ctx->device);
ggml_compute_params params = {};
params.type = GGML_TASK_COMPUTE;
params.type = GGML_TASK_TYPE_COMPUTE;
params.ith = 0;
for (int i = 0; i < cgraph->n_nodes; i++) {
ggml_tensor * node = cgraph->nodes[i];
@ -14888,13 +14888,13 @@ static bool ggml_backend_sycl_graph_compute(ggml_backend_t backend, ggml_cgraph
if (node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE)
continue;
assert(node->backend == GGML_BACKEND_GPU);
assert(node->backend == GGML_BACKEND_TYPE_GPU);
assert(node->buffer->buft == ggml_backend_sycl_buffer_type(sycl_ctx->device));
assert(node->extra != nullptr);
for (int j = 0; j < GGML_MAX_SRC; j++) {
if (node->src[j] != nullptr) {
assert(node->src[j]->backend == GGML_BACKEND_GPU);
assert(node->src[j]->backend == GGML_BACKEND_TYPE_GPU);
assert(node->src[j]->buffer->buft == ggml_backend_sycl_buffer_type(sycl_ctx->device));
assert(node->src[j]->extra != nullptr);
}

View File

@ -2320,8 +2320,8 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context * su
src1_uma = d_Qy != nullptr;
}
const bool load_x = src0->backend != GGML_BACKEND_GPU && !src0_uma;
const bool load_y = src1->backend != GGML_BACKEND_GPU && !src1_uma;
const bool load_x = src0->backend != GGML_BACKEND_TYPE_GPU && !src0_uma;
const bool load_y = src1->backend != GGML_BACKEND_TYPE_GPU && !src1_uma;
const bool x_non_contig = !load_x && !ggml_vk_dim01_contiguous(src0);
const bool y_non_contig = !load_y && !ggml_vk_dim01_contiguous(src1);
@ -2453,7 +2453,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context * su
// compute
ggml_vk_matmul(ctx, subctx, *pipeline, { d_X, x_buf_offset, x_sz * ne02 * ne03 }, { d_Y, y_buf_offset, y_sz * ne12 * ne13 }, { d_D, d_buf_offset, d_sz * ne12 * ne13 }, { ctx->prealloc_split_k, 0, d_sz * ne12 * ne13 * split_k }, ne01, ne11, ne10, ne10, ne10, ne01, split_k, ne12*ne13, ne02, ne12, r2, r3, stride_batch_x, stride_batch_y, ne20*ne21); // NOLINT
if (dst->backend == GGML_BACKEND_CPU) {
if (dst->backend == GGML_BACKEND_TYPE_CPU) {
// copy dst to host
float * d = (float *) ((char *) dst->data);
ggml_vk_buffer_read_async(ctx, subctx, d_D, 0, d, sizeof(float) * d_ne * ne12 * ne13);
@ -2506,8 +2506,8 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context
src1_uma = d_Qy != nullptr;
}
const bool load_x = src0->backend != GGML_BACKEND_GPU && !src0_uma;
const bool load_y = src1->backend != GGML_BACKEND_GPU && !src1_uma;
const bool load_x = src0->backend != GGML_BACKEND_TYPE_GPU && !src0_uma;
const bool load_y = src1->backend != GGML_BACKEND_TYPE_GPU && !src1_uma;
const bool x_non_contig = !load_x && !ggml_vk_dim01_contiguous(src0);
const bool y_non_contig = !load_y && !ggml_vk_dim01_contiguous(src1);
@ -2630,7 +2630,7 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context
ggml_vk_sync_buffers(subctx);
ggml_vk_dispatch_pipeline(ctx, subctx, *dmmv, { { d_X, x_offset, x_sz }, { d_Y, y_buffer_offset, y_sz + y_shader_offset }, { d_D, d_buffer_offset, d_sz + d_shader_offset } }, 3 * sizeof(int), &pc, { (uint32_t)ne01, 1, 1});
if (dst->backend == GGML_BACKEND_CPU) {
if (dst->backend == GGML_BACKEND_TYPE_CPU) {
// copy dst to host
float * d = (float *) ((char *) dst->data + i12*nb2 + i13*nb3);
ggml_vk_sync_buffers(subctx);
@ -2647,7 +2647,7 @@ static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_c
std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", backend=" << dst->backend << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3] << "),)" << std::endl;
#endif
GGML_ASSERT(ggml_is_permuted(src0) && ggml_is_permuted(src1));
GGML_ASSERT(src0->backend == GGML_BACKEND_GPU);
GGML_ASSERT(src0->backend == GGML_BACKEND_TYPE_GPU);
GGML_ASSERT(src0->nb[0] <= src0->nb[1] && src0->nb[2] <= src0->nb[3]); // NOLINT
GGML_ASSERT(src1->nb[0] <= src1->nb[1] && src1->nb[2] <= src1->nb[3]); // NOLINT
GGML_ASSERT(src0->type == GGML_TYPE_F16);
@ -2679,7 +2679,7 @@ static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_c
src1_uma = d_Qy != nullptr;
}
const bool load_y = src1->backend != GGML_BACKEND_GPU && !src1_uma;
const bool load_y = src1->backend != GGML_BACKEND_TYPE_GPU && !src1_uma;
const uint64_t x_ne = ne00 * ne01 * ne02;
const uint64_t y_ne = ne10 * ne11 * ne12;
@ -2721,7 +2721,7 @@ static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_c
ggml_vk_sync_buffers(subctx);
ggml_vk_dispatch_pipeline(ctx, subctx, ctx->pipeline_mul_mat_vec_p021_f16_f32, { { d_Qx, qx_buf_offset, qx_sz }, { d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, { d_D, d_buffer_offset, d_sz + d_shader_offset } }, 6 * sizeof(uint32_t), &pc, { 1, (uint32_t)ne01, (uint32_t)ne12 });
if (dst->backend == GGML_BACKEND_CPU) {
if (dst->backend == GGML_BACKEND_TYPE_CPU) {
// copy dst to host
float * d = (float *) dst->data;
ggml_vk_sync_buffers(subctx);
@ -2738,7 +2738,7 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con
GGML_ASSERT(!ggml_is_transposed(src0));
GGML_ASSERT(!ggml_is_transposed(src1));
GGML_ASSERT(!ggml_is_permuted(src0));
GGML_ASSERT(src0->backend == GGML_BACKEND_GPU);
GGML_ASSERT(src0->backend == GGML_BACKEND_TYPE_GPU);
GGML_ASSERT(src0->type == GGML_TYPE_F16);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
@ -2771,7 +2771,7 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con
src1_uma = d_Qy != nullptr;
}
const bool load_y = src1->backend != GGML_BACKEND_GPU && !src1_uma;
const bool load_y = src1->backend != GGML_BACKEND_TYPE_GPU && !src1_uma;
const uint64_t d_ne = ne01 * ne11 * ne12;
@ -2814,7 +2814,7 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con
ggml_vk_sync_buffers(subctx);
ggml_vk_dispatch_pipeline(ctx, subctx, ctx->pipeline_mul_mat_vec_nc_f16_f32, { { d_Qx, qx_buf_offset, qx_sz }, { d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, { d_D, d_buffer_offset, d_sz + d_shader_offset } }, 7 * sizeof(uint32_t), &pc, { 1, (uint32_t)ne01, (uint32_t)ne12 });
if (dst->backend == GGML_BACKEND_CPU) {
if (dst->backend == GGML_BACKEND_TYPE_CPU) {
// copy dst to host
float * d = (float *) dst->data;
ggml_vk_sync_buffers(subctx);
@ -2832,7 +2832,7 @@ static bool ggml_vk_can_mul_mat(const ggml_tensor * src0, const ggml_tensor * sr
return (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) &&
(src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16 || ggml_is_quantized(src1->type)) &&
dst->type == GGML_TYPE_F32 &&
((ne0 >= 32 && ne1 >= 32 && ne10 >= 32) || src0->backend == GGML_BACKEND_GPU);
((ne0 >= 32 && ne1 >= 32 && ne10 >= 32) || src0->backend == GGML_BACKEND_TYPE_GPU);
}
static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context * subctx, const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
@ -2880,8 +2880,8 @@ static void ggml_vk_op_repeat(ggml_backend_vk_context * ctx, vk_context * subctx
// TODO: support for transposed / permuted tensors
GGML_ASSERT(nb0 == sizeof(float));
GGML_ASSERT(nb00 == sizeof(float));
GGML_ASSERT(src0->backend == GGML_BACKEND_GPU);
GGML_ASSERT(dst->backend == GGML_BACKEND_GPU);
GGML_ASSERT(src0->backend == GGML_BACKEND_TYPE_GPU);
GGML_ASSERT(dst->backend == GGML_BACKEND_TYPE_GPU);
ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) dst->extra;
ggml_tensor_extra_gpu * extra_src0 = (ggml_tensor_extra_gpu *) src0->extra;
@ -3110,8 +3110,8 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context * subctx, c
}
}
const bool transfer_src0 = src0->backend != GGML_BACKEND_GPU && !src0_uma;
const bool transfer_src1 = use_src1 && src1->backend != GGML_BACKEND_GPU && !src1_uma;
const bool transfer_src0 = src0->backend != GGML_BACKEND_TYPE_GPU && !src0_uma;
const bool transfer_src1 = use_src1 && src1->backend != GGML_BACKEND_TYPE_GPU && !src1_uma;
uint64_t x_sz = ggml_vk_align_size(ggml_type_size(src0->type) * ne0, ctx->device.lock()->properties.limits.minStorageBufferOffsetAlignment);
uint64_t y_sz = use_src1 ? ggml_vk_align_size(ggml_type_size(src1->type) * ne1, ctx->device.lock()->properties.limits.minStorageBufferOffsetAlignment) : 0;
@ -3120,7 +3120,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context * subctx, c
vk_buffer d_D = extra->buffer_gpu.lock();
// Workaround for tiny tensor inputs on ROPE
if (use_src1 && src1->backend == GGML_BACKEND_GPU && y_sz > d_D->size) {
if (use_src1 && src1->backend == GGML_BACKEND_TYPE_GPU && y_sz > d_D->size) {
y_sz = VK_WHOLE_SIZE;
}
@ -3209,9 +3209,9 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context * subctx, c
ggml_vk_sync_buffers(subctx);
ggml_vk_dispatch_pipeline(ctx, subctx, *pipeline, { { d_X, x_buf_offset, x_sz }, { d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
}
if (dst->backend == GGML_BACKEND_CPU && op == GGML_OP_CPY) {
if (dst->backend == GGML_BACKEND_TYPE_CPU && op == GGML_OP_CPY) {
ggml_vk_d2h_tensor_2d(ctx, subctx, d_D, 0, dst);
} else if(dst->backend == GGML_BACKEND_CPU) {
} else if(dst->backend == GGML_BACKEND_TYPE_CPU) {
// copy dst to host
float * d = (float *) dst->data;
ggml_vk_buffer_read_async(ctx, subctx, d_D, 0, d, d_sz);
@ -3253,7 +3253,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context * subctx, c
ggml_vk_sync_buffers(subctx);
ggml_vk_dispatch_pipeline(ctx, subctx, *pipeline, { { d_X, x_buf_offset + x_offset, x_sz }, { d_D, d_buf_offset + d_offset, d_sz } }, sizeof(PC), &pc, elements);
}
if (dst->backend == GGML_BACKEND_CPU) {
if (dst->backend == GGML_BACKEND_TYPE_CPU) {
// copy dst to host
ggml_vk_buffer_read_async(ctx, subctx, d_D, d_buf_offset + d_offset, (char *) dst->data + i02*nb2 + i03*nb3, d_sz);
}
@ -3359,7 +3359,7 @@ static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context * subctx, con
static void ggml_vk_nop(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, ggml_tensor * dst) {
// If backend is CPU, data from src0 has to be copied off the device
if (dst->backend == GGML_BACKEND_CPU) {
if (dst->backend == GGML_BACKEND_TYPE_CPU) {
ggml_tensor_extra_gpu * extra_src0 = (ggml_tensor_extra_gpu *) src0->extra;
vk_buffer d_D = extra_src0->buffer_gpu.lock();
ggml_vk_sync_buffers(subctx);
@ -3994,9 +3994,9 @@ static void ggml_vk_preallocate_buffers_graph(ggml_backend_vk_context * ctx, ggm
#ifdef GGML_VULKAN_DEBUG
std::cerr << "ggml_vk_preallocate_buffers_graph(" << node << ")" << std::endl;
#endif
const bool any_on_device = node->backend == GGML_BACKEND_GPU
|| (node->src[0] != nullptr && (node->src[0]->backend == GGML_BACKEND_GPU || node->src[0]->backend == GGML_BACKEND_GPU_SPLIT))
|| (node->src[1] != nullptr && (node->src[1]->backend == GGML_BACKEND_GPU));
const bool any_on_device = node->backend == GGML_BACKEND_TYPE_GPU
|| (node->src[0] != nullptr && (node->src[0]->backend == GGML_BACKEND_TYPE_GPU || node->src[0]->backend == GGML_BACKEND_TYPE_GPU_SPLIT))
|| (node->src[1] != nullptr && (node->src[1]->backend == GGML_BACKEND_TYPE_GPU));
if (ctx->disable || (!any_on_device && node->op != GGML_OP_MUL_MAT)) {
return;
@ -4215,9 +4215,9 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) {
}
static void ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * node, bool last_node){
const bool any_on_device = node->backend == GGML_BACKEND_GPU
|| (node->src[0] != nullptr && (node->src[0]->backend == GGML_BACKEND_GPU || node->src[0]->backend == GGML_BACKEND_GPU_SPLIT))
|| (node->src[1] != nullptr && node->src[1]->backend == GGML_BACKEND_GPU);
const bool any_on_device = node->backend == GGML_BACKEND_TYPE_GPU
|| (node->src[0] != nullptr && (node->src[0]->backend == GGML_BACKEND_TYPE_GPU || node->src[0]->backend == GGML_BACKEND_TYPE_GPU_SPLIT))
|| (node->src[1] != nullptr && node->src[1]->backend == GGML_BACKEND_TYPE_GPU);
if (ctx->disable || (!any_on_device && node->op != GGML_OP_MUL_MAT) || (node->op == GGML_OP_MUL_MAT && !any_on_device && !ggml_vk_can_mul_mat(node->src[0], node->src[1], node))) {
return;
@ -4371,7 +4371,7 @@ static void ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
last_node = true;
#endif
if (node->backend == GGML_BACKEND_CPU || last_node) {
if (node->backend == GGML_BACKEND_TYPE_CPU || last_node) {
ggml_vk_ctx_end(ctx->compute_ctx);
ctx->compute_ctx->exit_tensor = node;
ctx->compute_ctx = nullptr;
@ -4379,9 +4379,9 @@ static void ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
}
static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_compute_params * params, ggml_tensor * tensor){
const bool any_on_device = tensor->backend == GGML_BACKEND_GPU
|| (tensor->src[0] != nullptr && (tensor->src[0]->backend == GGML_BACKEND_GPU || tensor->src[0]->backend == GGML_BACKEND_GPU_SPLIT))
|| (tensor->src[1] != nullptr && tensor->src[1]->backend == GGML_BACKEND_GPU);
const bool any_on_device = tensor->backend == GGML_BACKEND_TYPE_GPU
|| (tensor->src[0] != nullptr && (tensor->src[0]->backend == GGML_BACKEND_TYPE_GPU || tensor->src[0]->backend == GGML_BACKEND_TYPE_GPU_SPLIT))
|| (tensor->src[1] != nullptr && tensor->src[1]->backend == GGML_BACKEND_TYPE_GPU);
if (ctx->disable || (!any_on_device && tensor->op != GGML_OP_MUL_MAT)) {
return false;
@ -4442,7 +4442,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_compute_
if (params->ith != 0) {
return true;
}
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
return true;
}
@ -4745,7 +4745,7 @@ GGML_CALL static void ggml_backend_vk_buffer_init_tensor(ggml_backend_buffer_t b
extra->offset = (uint8_t *) tensor->data - (uint8_t *) vk_ptr_base;
}
tensor->backend = GGML_BACKEND_GPU;
tensor->backend = GGML_BACKEND_TYPE_GPU;
tensor->extra = extra;
}
@ -4753,7 +4753,7 @@ GGML_CALL static void ggml_backend_vk_buffer_set_tensor(ggml_backend_buffer_t bu
#ifdef GGML_VULKAN_DEBUG
std::cerr << "ggml_backend_vk_buffer_set_tensor(" << buffer << ", " << tensor << ", " << data << ", " << offset << ", " << size << ")" << std::endl;
#endif
GGML_ASSERT(tensor->backend == GGML_BACKEND_GPU);
GGML_ASSERT(tensor->backend == GGML_BACKEND_TYPE_GPU);
ggml_backend_vk_buffer_context * ctx = (ggml_backend_vk_buffer_context *)buffer->context;
@ -4768,7 +4768,7 @@ GGML_CALL static void ggml_backend_vk_buffer_get_tensor(ggml_backend_buffer_t bu
#ifdef GGML_VULKAN_DEBUG
std::cerr << "ggml_backend_vk_buffer_get_tensor(" << buffer << ", " << tensor << ", " << data << ", " << offset << ", " << size << ")" << std::endl;
#endif
GGML_ASSERT(tensor->backend == GGML_BACKEND_GPU);
GGML_ASSERT(tensor->backend == GGML_BACKEND_TYPE_GPU);
ggml_backend_vk_buffer_context * ctx = (ggml_backend_vk_buffer_context *)buffer->context;
@ -4999,7 +4999,7 @@ GGML_CALL static void ggml_backend_vk_set_tensor_async(ggml_backend_t backend, g
#endif
ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
GGML_ASSERT((tensor->buffer->buft == ggml_backend_vk_buffer_type(ctx->idx) || tensor->buffer->buft == ggml_backend_vk_host_buffer_type()) && "unsupported buffer type");
GGML_ASSERT(tensor->backend == GGML_BACKEND_GPU);
GGML_ASSERT(tensor->backend == GGML_BACKEND_TYPE_GPU);
ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) tensor->extra;
@ -5020,7 +5020,7 @@ GGML_CALL static void ggml_backend_vk_get_tensor_async(ggml_backend_t backend, c
#endif
ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
GGML_ASSERT((tensor->buffer->buft == ggml_backend_vk_buffer_type(ctx->idx) || tensor->buffer->buft == ggml_backend_vk_host_buffer_type()) && "unsupported buffer type");
GGML_ASSERT(tensor->backend == GGML_BACKEND_GPU);
GGML_ASSERT(tensor->backend == GGML_BACKEND_TYPE_GPU);
ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) tensor->extra;
@ -5097,7 +5097,7 @@ GGML_CALL static bool ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml
int last_node = cgraph->n_nodes - 1;
// If the last op in the cgraph isn't backend GPU, the command buffer doesn't get closed properly
while (last_node > 0 && cgraph->nodes[last_node]->backend != GGML_BACKEND_GPU) {
while (last_node > 0 && cgraph->nodes[last_node]->backend != GGML_BACKEND_TYPE_GPU) {
last_node -= 1;
}
@ -5106,7 +5106,7 @@ GGML_CALL static bool ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml
}
ggml_compute_params params = {};
params.type = GGML_TASK_COMPUTE;
params.type = GGML_TASK_TYPE_COMPUTE;
params.ith = 0;
for (int i = 0; i < cgraph->n_nodes; i++) {
ggml_tensor * node = cgraph->nodes[i];
@ -5410,7 +5410,7 @@ static void ggml_vk_print_tensor_area(const ggml_tensor * tensor, const void * d
static void ggml_vk_print_tensor(ggml_backend_vk_context * ctx, const ggml_tensor * tensor, const char * name) {
void * tensor_data = tensor->data;
if (tensor->backend == GGML_BACKEND_GPU) {
if (tensor->backend == GGML_BACKEND_TYPE_GPU) {
const size_t tensor_size = ggml_nbytes(tensor);
tensor_data = malloc(tensor_size);
@ -5436,14 +5436,14 @@ static void ggml_vk_print_tensor(ggml_backend_vk_context * ctx, const ggml_tenso
std::vector<const ggml_tensor *> done;
ggml_vk_print_graph_origin(tensor, done);
if (tensor->backend == GGML_BACKEND_GPU) {
if (tensor->backend == GGML_BACKEND_TYPE_GPU) {
free(tensor_data);
}
}
static void ggml_vk_check_tensor(const std::string& name, const ggml_tensor * tensor) {
return;
GGML_ASSERT(tensor->backend == GGML_BACKEND_CPU);
GGML_ASSERT(tensor->backend == GGML_BACKEND_TYPE_CPU);
if (tensor->type != GGML_TYPE_F32 && tensor->type != GGML_TYPE_F16) {
return;
}
@ -5481,7 +5481,7 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_compute_
if (params->ith != 0) {
return;
}
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE || tensor->op == GGML_OP_TRANSPOSE) {
if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE || tensor->op == GGML_OP_TRANSPOSE) {
return;
}
@ -5518,10 +5518,10 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_compute_
src0_buffer = malloc(src0_size);
src0_clone->data = src0_buffer;
if (src0->backend == GGML_BACKEND_CPU) {
if (src0->backend == GGML_BACKEND_TYPE_CPU) {
memcpy(src0_clone->data, src0->data, src0_size);
memcpy(src0_clone->nb, src0->nb, sizeof(size_t) * GGML_MAX_DIMS);
} else if (src0->backend == GGML_BACKEND_GPU) {
} else if (src0->backend == GGML_BACKEND_TYPE_GPU) {
ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) src0->extra;
uint64_t offset = extra->offset;
if (!ggml_is_contiguous(src0) && ggml_vk_dim01_contiguous(src0)) {
@ -5561,10 +5561,10 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_compute_
src1_buffer = malloc(src1_size);
src1_clone->data = src1_buffer;
if (src1->backend == GGML_BACKEND_CPU) {
if (src1->backend == GGML_BACKEND_TYPE_CPU) {
memcpy(src1_clone->data, src1->data, src1_size);
memcpy(src1_clone->nb, src1->nb, sizeof(size_t) * GGML_MAX_DIMS);
} else if (src1->backend == GGML_BACKEND_GPU) {
} else if (src1->backend == GGML_BACKEND_TYPE_GPU) {
ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) src1->extra;
uint64_t offset = extra->offset;
if (!ggml_is_contiguous(src1) && ggml_vk_dim01_contiguous(src1)) {
@ -5723,7 +5723,7 @@ static void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_compute_
if (params->ith != 0) {
return;
}
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE || tensor->op == GGML_OP_TRANSPOSE) {
if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE || tensor->op == GGML_OP_TRANSPOSE) {
return;
}
if (!(vk_output_tensor > 0 && vk_output_tensor == check_counter) && check_counter <= vk_skip_checks) {
@ -5735,7 +5735,7 @@ static void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_compute_
void * tensor_data = tensor->data;
if (tensor->backend == GGML_BACKEND_GPU) {
if (tensor->backend == GGML_BACKEND_TYPE_GPU) {
size_t tensor_size = ggml_nbytes(tensor);
tensor_data = malloc(tensor_size);
@ -5868,7 +5868,7 @@ static void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_compute_
comp_result = nullptr;
comp_size = 0;
if (tensor->backend == GGML_BACKEND_GPU) {
if (tensor->backend == GGML_BACKEND_TYPE_GPU) {
free(tensor_data);
}
}

350
ggml.c

File diff suppressed because it is too large Load Diff

38
ggml.h
View File

@ -364,9 +364,9 @@ extern "C" {
};
enum ggml_backend_type {
GGML_BACKEND_CPU = 0,
GGML_BACKEND_GPU = 10,
GGML_BACKEND_GPU_SPLIT = 20,
GGML_BACKEND_TYPE_CPU = 0,
GGML_BACKEND_TYPE_GPU = 10,
GGML_BACKEND_TYPE_GPU_SPLIT = 20,
};
// model file types
@ -498,9 +498,9 @@ extern "C" {
};
enum ggml_object_type {
GGML_OBJECT_TENSOR,
GGML_OBJECT_GRAPH,
GGML_OBJECT_WORK_BUFFER
GGML_OBJECT_TYPE_TENSOR,
GGML_OBJECT_TYPE_GRAPH,
GGML_OBJECT_TYPE_WORK_BUFFER
};
enum ggml_log_level {
@ -642,9 +642,9 @@ extern "C" {
// NOTE: the INIT or FINALIZE pass is not scheduled unless explicitly enabled.
// This behavior was changed since https://github.com/ggerganov/llama.cpp/pull/1995.
enum ggml_task_type {
GGML_TASK_INIT = 0,
GGML_TASK_COMPUTE,
GGML_TASK_FINALIZE,
GGML_TASK_TYPE_INIT = 0,
GGML_TASK_TYPE_COMPUTE,
GGML_TASK_TYPE_FINALIZE,
};
struct ggml_compute_params {
@ -1649,8 +1649,8 @@ extern "C" {
// sort rows
enum ggml_sort_order {
GGML_SORT_ASC,
GGML_SORT_DESC,
GGML_SORT_ORDER_ASC,
GGML_SORT_ORDER_DESC,
};
GGML_API struct ggml_tensor * ggml_argsort(
@ -1943,8 +1943,8 @@ extern "C" {
// optimization methods
enum ggml_opt_type {
GGML_OPT_ADAM,
GGML_OPT_LBFGS,
GGML_OPT_TYPE_ADAM,
GGML_OPT_TYPE_LBFGS,
};
// linesearch methods
@ -1958,12 +1958,12 @@ extern "C" {
// optimization return values
enum ggml_opt_result {
GGML_OPT_OK = 0,
GGML_OPT_DID_NOT_CONVERGE,
GGML_OPT_NO_CONTEXT,
GGML_OPT_INVALID_WOLFE,
GGML_OPT_FAIL,
GGML_OPT_CANCEL,
GGML_OPT_RESULT_OK = 0,
GGML_OPT_RESULT_DID_NOT_CONVERGE,
GGML_OPT_RESULT_NO_CONTEXT,
GGML_OPT_RESULT_INVALID_WOLFE,
GGML_OPT_RESULT_FAIL,
GGML_OPT_RESULT_CANCEL,
GGML_LINESEARCH_FAIL = -128,
GGML_LINESEARCH_MINIMUM_STEP,

View File

@ -850,9 +850,9 @@ struct LLM_TN {
//
static std::map<int32_t, const char *> LLAMA_ROPE_SCALING_TYPES = {
{ LLAMA_ROPE_SCALING_NONE, "none" },
{ LLAMA_ROPE_SCALING_LINEAR, "linear" },
{ LLAMA_ROPE_SCALING_YARN, "yarn" },
{ LLAMA_ROPE_SCALING_TYPE_NONE, "none" },
{ LLAMA_ROPE_SCALING_TYPE_LINEAR, "linear" },
{ LLAMA_ROPE_SCALING_TYPE_YARN, "yarn" },
};
static int32_t llama_rope_scaling_type_from_string(const std::string & name) {
@ -862,7 +862,7 @@ static int32_t llama_rope_scaling_type_from_string(const std::string & name) {
}
}
return LLAMA_ROPE_SCALING_UNSPECIFIED;
return LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED;
}
static std::string gguf_data_to_str(enum gguf_type type, const void * data, int i) {
@ -1580,7 +1580,7 @@ struct llama_hparams {
bool causal_attn = true;
bool need_kq_pos = false;
uint32_t pooling_type = LLAMA_POOLING_NONE;
uint32_t pooling_type = LLAMA_POOLING_TYPE_NONE;
bool operator!=(const llama_hparams & other) const {
if (this->vocab_only != other.vocab_only) return true;
@ -2345,9 +2345,9 @@ namespace GGUFMeta {
static const char * override_type_to_str(const llama_model_kv_override_type ty) {
switch (ty) {
case LLAMA_KV_OVERRIDE_BOOL: return "bool";
case LLAMA_KV_OVERRIDE_INT: return "int";
case LLAMA_KV_OVERRIDE_FLOAT: return "float";
case LLAMA_KV_OVERRIDE_TYPE_BOOL: return "bool";
case LLAMA_KV_OVERRIDE_TYPE_INT: return "int";
case LLAMA_KV_OVERRIDE_TYPE_FLOAT: return "float";
}
return "unknown";
}
@ -2358,13 +2358,13 @@ namespace GGUFMeta {
LLAMA_LOG_INFO("%s: Using metadata override (%5s) '%s' = ",
__func__, override_type_to_str(override->tag), override->key);
switch (override->tag) {
case LLAMA_KV_OVERRIDE_BOOL: {
case LLAMA_KV_OVERRIDE_TYPE_BOOL: {
LLAMA_LOG_INFO("%s\n", override->bool_value ? "true" : "false");
} break;
case LLAMA_KV_OVERRIDE_INT: {
case LLAMA_KV_OVERRIDE_TYPE_INT: {
LLAMA_LOG_INFO("%" PRId64 "\n", override->int_value);
} break;
case LLAMA_KV_OVERRIDE_FLOAT: {
case LLAMA_KV_OVERRIDE_TYPE_FLOAT: {
LLAMA_LOG_INFO("%.6f\n", override->float_value);
} break;
default:
@ -2383,7 +2383,7 @@ namespace GGUFMeta {
template<typename OT>
static typename std::enable_if<std::is_same<OT, bool>::value, bool>::type
try_override(OT & target, const struct llama_model_kv_override *override) {
if (validate_override(LLAMA_KV_OVERRIDE_BOOL, override)) {
if (validate_override(LLAMA_KV_OVERRIDE_TYPE_BOOL, override)) {
target = override->bool_value;
return true;
}
@ -2393,7 +2393,7 @@ namespace GGUFMeta {
template<typename OT>
static typename std::enable_if<!std::is_same<OT, bool>::value && std::is_integral<OT>::value, bool>::type
try_override(OT & target, const struct llama_model_kv_override *override) {
if (validate_override(LLAMA_KV_OVERRIDE_INT, override)) {
if (validate_override(LLAMA_KV_OVERRIDE_TYPE_INT, override)) {
target = override->int_value;
return true;
}
@ -2403,7 +2403,7 @@ namespace GGUFMeta {
template<typename OT>
static typename std::enable_if<std::is_floating_point<OT>::value, bool>::type
try_override(T & target, const struct llama_model_kv_override *override) {
if (validate_override(LLAMA_KV_OVERRIDE_FLOAT, override)) {
if (validate_override(LLAMA_KV_OVERRIDE_TYPE_FLOAT, override)) {
target = override->float_value;
return true;
}
@ -2999,7 +2999,7 @@ static void llm_load_hparams(
std::string rope_scaling("linear");
ml.get_key(LLM_KV_ROPE_SCALING_TYPE, rope_scaling, false);
hparams.rope_scaling_type_train = llama_rope_scaling_type_from_string(rope_scaling);
GGML_ASSERT(hparams.rope_scaling_type_train != LLAMA_ROPE_SCALING_UNSPECIFIED);
GGML_ASSERT(hparams.rope_scaling_type_train != LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED);
// rope_freq_scale (inverse of the kv) is optional
float ropescale = 0.0f;
@ -3643,7 +3643,7 @@ static bool llm_load_tensors(
model.buft_layer[i] = llama_default_buffer_type_cpu(true);
}
if (split_mode == LLAMA_SPLIT_LAYER) {
if (split_mode == LLAMA_SPLIT_MODE_LAYER) {
// calculate the split points
int device_count = llama_get_device_count();
bool all_zero = tensor_split == nullptr || std::all_of(tensor_split, tensor_split + device_count, [](float x) { return x == 0.0f; });
@ -3682,10 +3682,10 @@ static bool llm_load_tensors(
}
} else {
ggml_backend_buffer_type_t split_buft;
if (split_mode == LLAMA_SPLIT_ROW) {
if (split_mode == LLAMA_SPLIT_MODE_ROW) {
split_buft = llama_default_buffer_type_split(main_gpu, tensor_split);
} else {
// LLAMA_SPLIT_NONE or LLAMA_SPLIT_LAYER in backends where it is not supported
// LLAMA_SPLIT_MODE_NONE or LLAMA_SPLIT_MODE_LAYER in backends where it is not supported
split_buft = llama_default_buffer_type_offload(main_gpu);
}
// assign the repeating layers
@ -5070,7 +5070,7 @@ struct llm_build_context {
kv_head (worst_case ? n_ctx - n_tokens : kv_self.head),
n_orig_ctx (cparams.n_yarn_orig_ctx),
do_rope_shift (worst_case || kv_self.has_shift),
pooling_type (cparams.do_pooling ? hparams.pooling_type : (uint32_t)LLAMA_POOLING_NONE),
pooling_type (cparams.do_pooling ? hparams.pooling_type : (uint32_t)LLAMA_POOLING_TYPE_NONE),
cb (cb),
buf_compute_meta (lctx.buf_compute_meta) {
// all initializations should be done in init()
@ -6050,12 +6050,12 @@ struct llm_build_context {
cur = inpL;
// pooling layer
if (pooling_type == LLAMA_POOLING_MEAN) {
if (pooling_type == LLAMA_POOLING_TYPE_MEAN) {
cur = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, cur)), inp_mean);
} else if (pooling_type == LLAMA_POOLING_CLS) {
} else if (pooling_type == LLAMA_POOLING_TYPE_CLS) {
cur = ggml_get_rows(ctx0, cur, inp_cls);
} else {
GGML_ASSERT(pooling_type == LLAMA_POOLING_NONE && "Invalid pooling type");
GGML_ASSERT(pooling_type == LLAMA_POOLING_TYPE_NONE && "Invalid pooling type");
}
cb(cur, "result_embd", -1);
@ -7754,7 +7754,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
}
}
if (cparams.do_pooling && hparams.pooling_type == LLAMA_POOLING_MEAN) {
if (cparams.do_pooling && hparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) {
const int64_t n_tokens = batch.n_tokens;
GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_mean->buffer));
@ -7782,7 +7782,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
}
}
if (cparams.do_pooling && hparams.pooling_type == LLAMA_POOLING_CLS) {
if (cparams.do_pooling && hparams.pooling_type == LLAMA_POOLING_TYPE_CLS) {
const int64_t n_tokens = batch.n_tokens;
GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_cls->buffer));
@ -11351,7 +11351,7 @@ static int llama_apply_lora_from_file_internal(
struct llama_model_params llama_model_default_params() {
struct llama_model_params result = {
/*.n_gpu_layers =*/ 0,
/*.split_mode =*/ LLAMA_SPLIT_LAYER,
/*.split_mode =*/ LLAMA_SPLIT_MODE_LAYER,
/*.main_gpu =*/ 0,
/*.tensor_split =*/ nullptr,
/*.progress_callback =*/ nullptr,
@ -11377,7 +11377,7 @@ struct llama_context_params llama_context_default_params() {
/*.n_batch =*/ 512,
/*.n_threads =*/ GGML_DEFAULT_N_THREADS, // TODO: better default
/*.n_threads_batch =*/ GGML_DEFAULT_N_THREADS,
/*.rope_scaling_type =*/ LLAMA_ROPE_SCALING_UNSPECIFIED,
/*.rope_scaling_type =*/ LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED,
/*.rope_freq_base =*/ 0.0f,
/*.rope_freq_scale =*/ 0.0f,
/*.yarn_ext_factor =*/ -1.0f,
@ -11565,16 +11565,16 @@ struct llama_context * llama_new_context_with_model(
cparams.cb_eval_user_data = params.cb_eval_user_data;
auto rope_scaling_type = params.rope_scaling_type;
if (rope_scaling_type == LLAMA_ROPE_SCALING_UNSPECIFIED) {
if (rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED) {
rope_scaling_type = hparams.rope_scaling_type_train;
}
if (rope_scaling_type == LLAMA_ROPE_SCALING_NONE) {
if (rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_NONE) {
cparams.rope_freq_scale = 1.0f; // never scale if scaling type is none
}
if (cparams.yarn_ext_factor < 0.0f) { // negative indicates 'not set'
cparams.yarn_ext_factor = rope_scaling_type == LLAMA_ROPE_SCALING_YARN ? 1.0f : 0.0f;
cparams.yarn_ext_factor = rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_YARN ? 1.0f : 0.0f;
}
if (params.seed == LLAMA_DEFAULT_SEED) {
@ -11608,8 +11608,8 @@ struct llama_context * llama_new_context_with_model(
}
#elif defined(GGML_USE_CUBLAS)
if (model->n_gpu_layers > 0) {
// with split_mode LLAMA_SPLIT_NONE or LLAMA_SPLIT_ROW, only the main GPU backend is used
if (model->split_mode == LLAMA_SPLIT_NONE || model->split_mode == LLAMA_SPLIT_ROW) {
// with split_mode LLAMA_SPLIT_MODE_NONE or LLAMA_SPLIT_MODE_ROW, only the main GPU backend is used
if (model->split_mode == LLAMA_SPLIT_MODE_NONE || model->split_mode == LLAMA_SPLIT_MODE_ROW) {
ggml_backend_t backend = ggml_backend_cuda_init(model->main_gpu);
if (backend == nullptr) {
LLAMA_LOG_ERROR("%s: failed to initialize CUDA%d backend\n", __func__, model->main_gpu);
@ -11618,7 +11618,7 @@ struct llama_context * llama_new_context_with_model(
}
ctx->backends.push_back(backend);
} else {
// LLAMA_SPLIT_LAYER requires a backend for each GPU
// LLAMA_SPLIT_MODE_LAYER requires a backend for each GPU
for (int device = 0; device < ggml_backend_cuda_get_device_count(); ++device) {
ggml_backend_t backend = ggml_backend_cuda_init(device);
if (backend == nullptr) {

28
llama.h
View File

@ -109,23 +109,23 @@ extern "C" {
};
enum llama_rope_scaling_type {
LLAMA_ROPE_SCALING_UNSPECIFIED = -1,
LLAMA_ROPE_SCALING_NONE = 0,
LLAMA_ROPE_SCALING_LINEAR = 1,
LLAMA_ROPE_SCALING_YARN = 2,
LLAMA_ROPE_SCALING_MAX_VALUE = LLAMA_ROPE_SCALING_YARN,
LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED = -1,
LLAMA_ROPE_SCALING_TYPE_NONE = 0,
LLAMA_ROPE_SCALING_TYPE_LINEAR = 1,
LLAMA_ROPE_SCALING_TYPE_YARN = 2,
LLAMA_ROPE_SCALING_TYPE_MAX_VALUE = LLAMA_ROPE_SCALING_TYPE_YARN,
};
enum llama_pooling_type {
LLAMA_POOLING_NONE = 0,
LLAMA_POOLING_MEAN = 1,
LLAMA_POOLING_CLS = 2,
LLAMA_POOLING_TYPE_NONE = 0,
LLAMA_POOLING_TYPE_MEAN = 1,
LLAMA_POOLING_TYPE_CLS = 2,
};
enum llama_split_mode {
LLAMA_SPLIT_NONE = 0, // single GPU
LLAMA_SPLIT_LAYER = 1, // split layers and KV across GPUs
LLAMA_SPLIT_ROW = 2, // split rows across GPUs
LLAMA_SPLIT_MODE_NONE = 0, // single GPU
LLAMA_SPLIT_MODE_LAYER = 1, // split layers and KV across GPUs
LLAMA_SPLIT_MODE_ROW = 2, // split rows across GPUs
};
typedef struct llama_token_data {
@ -173,9 +173,9 @@ extern "C" {
} llama_batch;
enum llama_model_kv_override_type {
LLAMA_KV_OVERRIDE_INT,
LLAMA_KV_OVERRIDE_FLOAT,
LLAMA_KV_OVERRIDE_BOOL,
LLAMA_KV_OVERRIDE_TYPE_INT,
LLAMA_KV_OVERRIDE_TYPE_FLOAT,
LLAMA_KV_OVERRIDE_TYPE_BOOL,
};
struct llama_model_kv_override {

View File

@ -1264,7 +1264,7 @@ struct test_argsort : public test_case {
test_argsort(ggml_type type = GGML_TYPE_F32,
std::array<int64_t, 4> ne = {16, 10, 10, 10},
ggml_sort_order order = GGML_SORT_ASC)
ggml_sort_order order = GGML_SORT_ORDER_ASC)
: type(type), ne(ne), order(order) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
@ -2116,7 +2116,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
test_cases.emplace_back(new test_concat(GGML_TYPE_F32));
test_cases.emplace_back(new test_concat(GGML_TYPE_I32));
for (ggml_sort_order order : {GGML_SORT_ASC, GGML_SORT_DESC}) {
for (ggml_sort_order order : {GGML_SORT_ORDER_ASC, GGML_SORT_ORDER_DESC}) {
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {8, 1, 1, 1}, order));
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {16, 10, 10, 10}, order));
}

View File

@ -118,7 +118,7 @@ int main(void) {
const float fe = ggml_get_f32_1d(e, 0);
printf("%s: e = %.4f\n", __func__, fe);
struct ggml_opt_params opt_params = ggml_opt_default_params(GGML_OPT_ADAM);
struct ggml_opt_params opt_params = ggml_opt_default_params(GGML_OPT_TYPE_ADAM);
ggml_opt(ctx, opt_params, e);