diff --git a/ggml-threading.c b/ggml-threading.c index fb02b4046..2a5cfa096 100644 --- a/ggml-threading.c +++ b/ggml-threading.c @@ -260,22 +260,17 @@ void ggml_threading_suspend(struct ggml_threading_context *ctx) { return; } - struct ggml_compute_state_shared *shared = &ctx->shared; - - ggml_spin_lock(&shared->spin); + PRINT_DEBUG("[main] wait_now will be set, expect %d workers wait\n", + n_worker_threads); ctx->shared.wait_now = true; - ggml_spin_unlock(&shared->spin); const int n_worker_threads = ctx->n_threads - 1; - while (ctx->shared.n_waiting != n_worker_threads) { ggml_spin_pause(); } - ggml_spin_lock(&shared->spin); - ctx->suspending = true; - ggml_spin_unlock(&shared->spin); PRINT_DEBUG("[main] saw %d workers waiting\n", n_worker_threads); + ctx->suspending = true; } // Wakeup all workers. @@ -291,7 +286,6 @@ void ggml_threading_resume(struct ggml_threading_context *ctx) { } struct ggml_compute_state_shared *shared = &ctx->shared; - ggml_spin_lock(&shared->spin); int64_t perf_cycles_0 = 0; int64_t perf_time_0 = 0; @@ -307,8 +301,6 @@ void ggml_threading_resume(struct ggml_threading_context *ctx) { shared->wait_now = false; while (shared->n_waiting != 0) { - ggml_spin_unlock(&shared->spin); - if (loop_counter > 0) { ggml_spin_pause(); if (loop_counter > 3) { @@ -326,8 +318,6 @@ void ggml_threading_resume(struct ggml_threading_context *ctx) { GGML_ASSERT(pthread_cond_broadcast(&shared->cond) == 0); GGML_ASSERT(pthread_mutex_unlock(&shared->mutex) == 0); last_signal_time = ggml_time_us(); - - ggml_spin_lock(&shared->spin); } ctx->suspending = false; @@ -335,9 +325,7 @@ void ggml_threading_resume(struct ggml_threading_context *ctx) { if (shared->ctx->features & GGML_THREADING_FEATURE_PERF) { ggml_perf_collect(&shared->ctx->wakeup_perf, perf_cycles_0, perf_time_0); - } - - ggml_spin_unlock(&shared->spin); + }; } bool ggml_threading_is_suspending(struct ggml_threading_context *ctx) { @@ -385,8 +373,6 @@ static void ggml_threading_setup_workers(struct ggml_threading_context *ctx, } } else if (current->wait) { if (shared->n_waiting < n_worker_threads) { - PRINT_DEBUG("[main] wait_now will be set, expect %d workers wait\n", - n_worker_threads); ggml_spin_unlock(&ctx->shared.spin); ggml_threading_suspend(ctx); ggml_spin_lock(&ctx->shared.spin); diff --git a/tests/test-ggml-threading.c b/tests/test-ggml-threading.c index f941f4dc3..cb2cca163 100644 --- a/tests/test-ggml-threading.c +++ b/tests/test-ggml-threading.c @@ -214,7 +214,7 @@ lifecycle_runner(const struct ggml_compute_params *params, } // Test thread lifecycle: start -> suspend -> resume -> stop -static int test_lifecycle(void) { +static int test_lifecycle(bool wait_on_done) { struct ggml_tensor node; memset(&node, 0, sizeof(struct ggml_tensor)); @@ -243,14 +243,15 @@ static int test_lifecycle(void) { int threads_arr_len = sizeof(threads_arr) / sizeof(threads_arr[0]); int n_threads = 1; + enum ggml_threading_features features = + wait_on_done ? GGML_THREADING_FEATURE_NONE + : GGML_THREADING_FEATURE_WAIT_ON_DONE; for (int i = 0; i < threads_arr_len; i++) { n_threads = threads_arr[i]; int start_time = (int)ggml_time_ms(); - ctx = ggml_threading_start( - n_threads, NULL, lifecycle_runner, - /*features*/ GGML_THREADING_FEATURE_WAIT_ON_DONE | - GGML_THREADING_FEATURE_PERF, - /*stages_time*/ NULL); + ctx = ggml_threading_start(n_threads, NULL, lifecycle_runner, + features | GGML_THREADING_FEATURE_PERF, + /*stages_time*/ NULL); int elapsed = (int)ggml_time_ms() - start_time; if (elapsed > 5 * n_threads) { printf("[test-ggml-threading] %s: it took %d ms to start %d worker " @@ -547,13 +548,17 @@ int main(void) { } // lifecycle. - { - printf("[test-ggml-threading] test lifecycle ...\n"); + for (int i = 0; i < 2; i++) { + bool wait_on_done = (i == 1); + printf("[test-ggml-threading] test lifecycle (want_on_done = %d) ...\n", + wait_on_done); ++n_tests; - if (test_lifecycle() == 0) { + if (test_lifecycle(wait_on_done) == 0) { ++n_passed; - printf("[test-ggml-threading] test lifecycle: ok\n\n"); + printf("[test-ggml-threading] test lifecycle (want_on_done = %d): " + "ok\n\n", + wait_on_done); } }