move ggml_amx_init from ggml.c to ggml-amx/mmq.cpp

ggml-ci
This commit is contained in:
mingfeima 2024-08-13 22:06:53 -07:00
parent 0b4de32e61
commit 37ccb9d324
2 changed files with 17 additions and 23 deletions

View File

@ -2363,8 +2363,18 @@ bool ggml_amx_init() {
}
bool ggml_compute_forward_mul_mat_use_amx(struct ggml_tensor * dst) {
// load tile config
ggml_tile_config_init();
static thread_local bool is_first_time = true;
if (is_first_time) {
#pragma omp single
{
ggml_amx_init();
}
// load tile config
ggml_tile_config_init();
}
is_first_time = false;
const struct ggml_tensor * src0 = dst->src[0];
const struct ggml_tensor * src1 = dst->src[1];
@ -2464,7 +2474,7 @@ void ggml_mul_mat_amx(struct ggml_tensor * dst, int nth, int ith, void * wdata,
return;
}
#pragma omp master
#pragma omp single
{
GGML_DISPATCH_QTYPES(TYPE, [&] {
const size_t row_size_A = K / blck_size * sizeof(vec_dot_type);
@ -2479,20 +2489,13 @@ void ggml_mul_mat_amx(struct ggml_tensor * dst, int nth, int ith, void * wdata,
src0->extra = aligned_alloc(64, N * row_size_B);
convert_B_packed_format<type, blck_size>((void *)src0->extra, (const type *)src0->data, N, K);
}
});
}
#pragma omp barrier
const float * A_data = static_cast<const float *>(src1->data);
parallel_for(nth, ith, M, [&](int begin, int end) {
GGML_DISPATCH_QTYPES(TYPE, [&] {
const size_t row_size_A = K / blck_size * sizeof(vec_dot_type);
for (int m = begin; m < end; ++m) {
const float * A_data = static_cast<const float *>(src1->data);
for (int m = 0; m < M; ++m) {
from_float<vec_dot_type>(A_data + m * K, (char *)wdata + m * row_size_A, K);
}
});
});
#pragma omp barrier
}
GGML_ASSERT(src0->extra != nullptr);
if (M == 1) {

View File

@ -411,11 +411,6 @@ static ggml_fp16_t ggml_table_gelu_quick_f16[1 << 16];
// precomputed f32 table for f16 (256 KB) (ggml-impl.h)
float ggml_table_f32_f16[1 << 16];
#if GGML_USE_AMX
// global flag for amx init
static bool ggml_amx_initialized = false;
#endif
GGML_CALL const char * ggml_status_to_string(enum ggml_status status) {
switch (status) {
case GGML_STATUS_ALLOC_FAILED: return "GGML status: error (failed to allocate memory)";
@ -3530,10 +3525,6 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
GGML_PRINT_DEBUG("%s: g_state initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f);
}
#if GGML_USE_AMX
ggml_amx_initialized = ggml_amx_init();
#endif
is_first_call = false;
}
@ -12334,7 +12325,7 @@ static void ggml_compute_forward_mul_mat(
// compute by src0 rows
#if GGML_USE_AMX
if (ggml_compute_forward_mul_mat_use_amx(dst) && ggml_amx_initialized) {
if (ggml_compute_forward_mul_mat_use_amx(dst)) {
ggml_mul_mat_amx(dst, nth, ith, params->wdata, params->wsize);
return;
}