threading: add suspend/resume APIs, so it's possible to run a thread pool at session level

This commit is contained in:
mqy 2023-06-18 18:57:33 +08:00
parent 5abb8aefea
commit 5feefb32b3
3 changed files with 204 additions and 40 deletions

View File

@ -194,6 +194,8 @@ struct ggml_threading_context {
struct ggml_perf_stats wait_perf;
struct ggml_perf_stats wakeup_perf;
atomic_bool suspending;
int64_t *stages_time;
};
@ -252,6 +254,30 @@ static void ggml_threading_cond_wait(struct ggml_compute_state *state) {
}
}
// Suspend all worker threads.
//
// Signals workers to park (via shared->wait_now), then blocks until all
// n_threads - 1 workers are waiting, and finally publishes ctx->suspending.
// No-op when n_threads == 1 (there are no worker threads to suspend).
//
// NOTE(review): caller must NOT hold shared->spin — this function acquires
// and releases it internally.
void ggml_threading_suspend(struct ggml_threading_context *ctx) {
    // Single-threaded context: nothing to suspend.
    if (ctx->n_threads == 1) {
        return;
    }

    struct ggml_compute_state_shared *shared = &ctx->shared;

    // Ask workers to wait; set under the spin lock so workers observe it.
    ggml_spin_lock(&shared->spin);
    ctx->shared.wait_now = true;
    ggml_spin_unlock(&shared->spin);

    const int n_worker_threads = ctx->n_threads - 1;

    // Busy-wait (lock released) until every worker has parked itself.
    // NOTE(review): n_waiting is read here without holding the spin lock —
    // presumably updated atomically / under the mutex by workers; confirm.
    while (ctx->shared.n_waiting != n_worker_threads) {
        ggml_spin_pause();
    }

    // All workers are parked: publish the suspended state.
    ggml_spin_lock(&shared->spin);
    ctx->suspending = true;
    ggml_spin_unlock(&shared->spin);

    PRINT_DEBUG("[main] saw %d workers waiting\n", n_worker_threads);
}
// Wakeup all workers.
//
// Workers take some time to wake up, and have to lock spin after wakeup. Yield
@ -259,8 +285,14 @@ static void ggml_threading_cond_wait(struct ggml_compute_state *state) {
// experimental. See tests/test-ggml-threading.c for details.
//
// NOTE: must be protected by shared->spin
static void
ggml_threading_wakeup_workers(struct ggml_compute_state_shared *shared) {
void ggml_threading_resume(struct ggml_threading_context *ctx) {
if (ctx->n_threads == 1) {
return;
}
struct ggml_compute_state_shared *shared = &ctx->shared;
ggml_spin_lock(&shared->spin);
int64_t perf_cycles_0 = 0;
int64_t perf_time_0 = 0;
@ -269,12 +301,11 @@ ggml_threading_wakeup_workers(struct ggml_compute_state_shared *shared) {
perf_time_0 = ggml_time_us();
}
shared->wait_now = false;
int loop_counter = 0;
int notify_counter = 0;
int64_t last_signal_time = 0;
shared->wait_now = false;
while (shared->n_waiting != 0) {
ggml_spin_unlock(&shared->spin);
@ -294,22 +325,23 @@ ggml_threading_wakeup_workers(struct ggml_compute_state_shared *shared) {
GGML_ASSERT(pthread_mutex_lock(&shared->mutex) == 0);
GGML_ASSERT(pthread_cond_broadcast(&shared->cond) == 0);
GGML_ASSERT(pthread_mutex_unlock(&shared->mutex) == 0);
++notify_counter;
last_signal_time = ggml_time_us();
ggml_spin_lock(&shared->spin);
}
ctx->suspending = false;
if (shared->ctx->features & GGML_THREADING_FEATURE_PERF) {
ggml_perf_collect(&shared->ctx->wakeup_perf, perf_cycles_0,
perf_time_0);
}
// if (notify_counter > 1) {
// printf("%s: loop counter: %d, notify counter: %d\n", __func__,
// loop_counter, notify_counter);
// }
UNUSED(notify_counter);
ggml_spin_unlock(&shared->spin);
}
// Report whether all worker threads are currently suspended.
//
// Reads ctx->suspending, which is set by ggml_threading_suspend() and
// cleared by ggml_threading_resume().
bool ggml_threading_is_suspending(struct ggml_threading_context *ctx) {
    const bool suspended = ctx->suspending;
    return suspended;
}
// Setup workers for a task stage.
@ -329,7 +361,9 @@ static void ggml_threading_setup_workers(struct ggml_threading_context *ctx,
if (current->parallel) {
if (shared->n_waiting > 0) {
ggml_threading_wakeup_workers(shared);
ggml_spin_unlock(&shared->spin);
ggml_threading_resume(ctx);
ggml_spin_lock(&shared->spin);
}
if ((ctx->features & GGML_THREADING_FEATURE_WAIT_ON_DONE) > 0) {
@ -351,17 +385,11 @@ static void ggml_threading_setup_workers(struct ggml_threading_context *ctx,
}
} else if (current->wait) {
if (shared->n_waiting < n_worker_threads) {
shared->wait_now = true;
PRINT_DEBUG("[main] wait_now was set, expect %d workers wait\n",
PRINT_DEBUG("[main] wait_now will be set, expect %d workers wait\n",
n_worker_threads);
ggml_spin_unlock(&shared->spin);
while (shared->n_waiting != n_worker_threads) {
ggml_spin_pause();
}
ggml_spin_lock(&shared->spin);
PRINT_DEBUG("[main] saw %d workers waiting\n", n_worker_threads);
ggml_spin_unlock(&ctx->shared.spin);
ggml_threading_suspend(ctx);
ggml_spin_lock(&ctx->shared.spin);
}
}
@ -376,7 +404,7 @@ ggml_thread_ret_t ggml_threading_graph_compute_thread(void *data) {
struct ggml_compute_state_shared *shared = state->shared;
GGML_ASSERT(shared);
//GGML_ASSERT(shared->task_runner);
// GGML_ASSERT(shared->task_runner);
shared->n_ready++;
@ -527,7 +555,7 @@ START:
GGML_ASSERT(profiles[0].id == 1);
memcpy(&node->task_profile, &profiles[0],
sizeof(struct ggml_task_profile));
sizeof(struct ggml_task_profile));
runner = ctx->shared.task_runner;
GGML_ASSERT(runner);
@ -572,6 +600,7 @@ ggml_threading_start(int n_threads, ggml_threading_thread_runner *thread_runner,
ctx->n_threads = n_threads;
ctx->features = features;
ctx->stages_time = stages_time;
ctx->suspending = false;
int n_workers = n_threads - 1;
if (n_workers > 0) {
@ -633,9 +662,7 @@ void ggml_threading_stop(struct ggml_threading_context *ctx) {
PRINT_DEBUG("[main] stopping thread pool ...\n");
ctx->shared.stop = true;
ggml_spin_lock(&ctx->shared.spin);
ggml_threading_wakeup_workers(&ctx->shared);
ggml_spin_unlock(&ctx->shared.spin);
ggml_threading_resume(ctx);
for (int j = 0; j < ctx->n_threads - 1; j++) {
GGML_ASSERT(pthread_join(ctx->workers[j].thrd, NULL) == 0);

View File

@ -25,11 +25,12 @@ enum ggml_threading_features {
typedef ggml_thread_ret_t(ggml_threading_thread_runner)(void *data);
// Init and start underlying workers if n_threads > 1.
// n_threads: number of threads (including caller) involved in computing tasks.
//
// thread: optional OS thread runner, default value:
// `ggml_threading_graph_compute_thread`.
//
// task_runner: default task runner, nullable wheen tensor.runner is not NULL.
// task_runner: default task runner, nullable when tensor.runner is not NULL.
// Overridden by tensor.runner.
// features: configure threading behaviour, optional.
// threading additional features. see `ggml_threading_feature`, default 0.
@ -41,9 +42,18 @@ ggml_threading_start(int n_threads, ggml_threading_thread_runner *thread,
enum ggml_threading_features features,
int64_t stages_time[3]);
// Suspend worker threads.
void ggml_threading_suspend(struct ggml_threading_context *ctx);
// Resume worker threads.
void ggml_threading_resume(struct ggml_threading_context *ctx);
// Stop workers (if exist), free memories (including the ctx).
void ggml_threading_stop(struct ggml_threading_context *ctx);
// Are all worker threads suspended?
bool ggml_threading_is_suspending(struct ggml_threading_context *ctx);
// The default implementation of `ggml_threading_thread_runner`
ggml_thread_ret_t ggml_threading_graph_compute_thread(void *data);

View File

@ -41,8 +41,9 @@ static const int n_repeat = 10;
// counter with array.
static int work_done_arr[MAX_N_THREADS];
static enum ggml_compute_error mock_task_runner(const struct ggml_compute_params *params,
struct ggml_tensor *node) {
static enum ggml_compute_error
mock_task_runner(const struct ggml_compute_params *params,
struct ggml_tensor *node) {
int64_t loops = node->task_profile.dev_flags[1] * 1000 * 1000;
if (node->task_profile.stages[params->type].parallel) {
loops /= params->nth;
@ -59,7 +60,7 @@ static enum ggml_compute_error mock_task_runner(const struct ggml_compute_params
return GGML_COMPUTE_OK;
}
int test_driver(int id, struct ggml_tensor *node, int n_threads) {
static int test_driver(int id, struct ggml_tensor *node, int n_threads) {
uint8_t loops = node->task_profile.dev_flags[1];
printf(
"\n[test-ggml-threading] #%02d, workload: %2d million(s), n_threads: "
@ -81,8 +82,8 @@ int test_driver(int id, struct ggml_tensor *node, int n_threads) {
node->task_profile.runner = mock_task_runner;
struct ggml_threading_context *ctx =
ggml_threading_start(n_threads, NULL, NULL, features, /*stages_time*/ NULL);
struct ggml_threading_context *ctx = ggml_threading_start(
n_threads, NULL, NULL, features, /*stages_time*/ NULL);
int t1 = (int)ggml_time_us();
@ -148,7 +149,7 @@ mock_task_runner_fallback(const struct ggml_compute_params *params,
// By design, fallback should happen when attempt computing tensor in GPU,
// thus it is not parallelled.
int test_fallback(struct ggml_tensor *node) {
static int test_fallback(struct ggml_tensor *node) {
struct ggml_threading_context *ctx = ggml_threading_start(
1, NULL, mock_task_runner_fallback,
/*features*/ GGML_THREADING_FEATURE_NONE, /*stages_time*/ NULL);
@ -182,7 +183,7 @@ customized_node_runner(const struct ggml_compute_params *params,
}
// Test when node->task_profile.runner is not NULL.
int test_customized_node_runner(struct ggml_tensor *node) {
static int test_customized_node_runner(struct ggml_tensor *node) {
struct ggml_threading_context *ctx = ggml_threading_start(
1, NULL, mock_task_runner,
/*features*/ GGML_THREADING_FEATURE_NONE, /*stages_time*/ NULL);
@ -204,6 +205,121 @@ int test_customized_node_runner(struct ggml_tensor *node) {
return 0;
}
// No-op task runner for the lifecycle test: ignores its arguments, does no
// work, and always reports success.
static enum ggml_compute_error
lifecycle_runner(const struct ggml_compute_params *params,
                 struct ggml_tensor *node) {
    (void)params;
    (void)node;
    return GGML_COMPUTE_OK;
}
// Test thread lifecycle: start -> suspend -> resume -> stop.
//
// Starts a thread pool (retrying with fewer threads when startup is slow),
// exercises suspend/resume around real tensor computation, then loops
// suspend/resume to stress the state machine. Returns 0 on success or
// skip; aborts on a suspend/resume state mismatch.
static int test_lifecycle(void) {
    struct ggml_tensor node;
    memset(&node, 0, sizeof(struct ggml_tensor));

    struct ggml_task_stage *stages = node.task_profile.stages;
    stages[0].valid = true;
    stages[1].valid = true;
    stages[1].parallel = true;

    node.op = GGML_OP_MUL_MAT;
    struct ggml_tensor src0 = {
        .type = GGML_TYPE_Q4_0,
    };
    struct ggml_tensor src1 = {
        .type = GGML_TYPE_F32,
    };
    node.src0 = &src0;
    node.src1 = &src1;

    int t0 = (int)ggml_time_ms();

    // Suppose creating threading when entering session.
    // Try progressively fewer threads until startup is fast enough.
    struct ggml_threading_context *ctx = NULL;
    int threads_arr[] = {4, 2};
    int threads_arr_len = sizeof(threads_arr) / sizeof(threads_arr[0]);
    int n_threads = 1;

    for (int i = 0; i < threads_arr_len; i++) {
        n_threads = threads_arr[i];
        int start_time = (int)ggml_time_ms();
        ctx = ggml_threading_start(
            n_threads, NULL, lifecycle_runner,
            /*features*/ GGML_THREADING_FEATURE_WAIT_ON_DONE |
                GGML_THREADING_FEATURE_PERF,
            /*stages_time*/ NULL);
        int elapsed = (int)ggml_time_ms() - start_time;
        if (elapsed > 5 * n_threads) {
            printf("[test-ggml-threading] %s: it took %d ms to start %d worker "
                   "thread(s), skip\n",
                   __func__, elapsed, n_threads - 1);
            ggml_threading_stop(ctx);
            // BUGFIX: clear the handle after stopping so a fully-slow run
            // cannot fall through and reuse a freed context below.
            ctx = NULL;
        } else {
            break;
        }
    }

    // BUGFIX: the original guard was `n_threads == 1`, which can never be
    // true (threads_arr only holds 4 and 2). On a machine where every
    // attempt was slow, execution continued with the already-stopped
    // (freed) ctx — a use-after-free. Gate on the handle itself instead.
    if (ctx == NULL) {
        printf("[test-ggml-threading] %s: too slow to start at least 1 worker "
               "thread(s), skip\n",
               __func__);
        return 0;
    }

    // Suppose exiting from previous compute graph ...
    printf("[test-ggml-threading] %s: %d workers, suspending ...\n", __func__,
           n_threads - 1);
    ggml_threading_suspend(ctx);

    // Suppose entering new compute graph ...
    printf("[test-ggml-threading] test lifecycle: resuming ...\n");
    ggml_threading_resume(ctx);

    const int m = 2;
    const int n = 50;
    printf("[test-ggml-threading] %s: computing %d tensors (half wait)...\n",
           __func__, m * n);

    // First batch with stage-0 wait enabled, second without, so both the
    // waiting and non-waiting compute paths are exercised.
    for (int i = 0; i < m; i++) {
        stages[0].wait = (i == 0);
        for (int j = 0; j < n; j++) {
            ggml_threading_compute_tensor(ctx, &node, /*wdata*/ NULL,
                                          /*wsize*/ 0);
        }
    }

    printf("[test-ggml-threading] %s: compute done, resuming...\n", __func__);
    ggml_threading_resume(ctx);

    const int loops = 100;
    printf("[test-ggml-threading] %s: try %d loops of suspend-resume ...\n",
           __func__, loops);

    // Stress the suspend/resume state machine; abort on any mismatch.
    for (int i = 0; i < loops; i++) {
        ggml_threading_suspend(ctx);
        if (!ggml_threading_is_suspending(ctx)) {
            abort();
        }
        ggml_threading_resume(ctx);
        if (ggml_threading_is_suspending(ctx)) {
            abort();
        }
    }

    printf("[test-ggml-threading] %s: stopping ...\n", __func__);
    ggml_threading_stop(ctx);

    int elapsed_ms = (int)ggml_time_ms() - t0;
    printf("[test-ggml-threading] %s: elapsed %d ms\n", __func__, elapsed_ms);

    return 0;
}
int main(void) {
ggml_time_init();
@ -268,21 +384,21 @@ int main(void) {
}
// skip this n_threads when too slow.
int t0 = (int)ggml_time_us();
int t0 = (int)ggml_time_ms();
struct ggml_threading_context *ctx =
ggml_threading_start(n_threads, ggml_threading_graph_compute_thread,
NULL, 0, /*stages_time*/ NULL);
int t1 = (int)ggml_time_us();
int t1 = (int)ggml_time_ms();
ggml_threading_stop(ctx);
int elapsed_us = t1 - t0;
if (elapsed_us > 500 * n_threads) {
int elapsed_ms = t1 - t0;
if (elapsed_ms > 5 * n_threads) {
printf("[test-ggml-threading] warning: it took took %7.3f "
"ms to start %2d worker thread(s). Too slow, skip.\n",
1.0 * elapsed_us / 1000, n_threads - 1);
1.0 * elapsed_ms, n_threads - 1);
threads_arr[i] = 0;
++n_slow;
} else {
@ -430,6 +546,17 @@ int main(void) {
}
}
// lifecycle.
{
printf("[test-ggml-threading] test lifecycle ...\n");
++n_tests;
if (test_lifecycle() == 0) {
++n_passed;
printf("[test-ggml-threading] test lifecycle: ok\n\n");
}
}
printf("[test-ggml-threading] %d/%d passed.\n", n_passed, n_tests);
return (n_passed == n_tests) ? 0 : 1;