From 3232db628c8faf595f022ba19203acc104efddb0 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 9 Jul 2023 14:08:53 +0300 Subject: [PATCH] mpi : trying to move more MPI stuff into ggml-mpi (WIP) (#2099) --- examples/embd-input/embd-input-lib.cpp | 2 +- examples/embedding/embedding.cpp | 4 +- examples/main/main.cpp | 4 +- examples/perplexity/perplexity.cpp | 4 +- examples/quantize/quantize.cpp | 4 +- examples/server/server.cpp | 4 +- examples/simple/simple.cpp | 4 +- ggml-mpi.c | 70 +++++++++++++++++++++--- ggml-mpi.h | 28 ++++++++-- llama.cpp | 73 +++++++++++--------------- llama.h | 4 +- 11 files changed, 134 insertions(+), 67 deletions(-) diff --git a/examples/embd-input/embd-input-lib.cpp b/examples/embd-input/embd-input-lib.cpp index 5fa4942be..26563821a 100644 --- a/examples/embd-input/embd-input-lib.cpp +++ b/examples/embd-input/embd-input-lib.cpp @@ -34,7 +34,7 @@ struct MyModel* create_mymodel(int argc, char ** argv) { } fprintf(stderr, "%s: seed = %d\n", __func__, params.seed); - llama_init_backend(params.numa); + llama_backend_init(params.numa); llama_model * model; llama_context * ctx; diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp index 03e801c2a..5192d6df5 100644 --- a/examples/embedding/embedding.cpp +++ b/examples/embedding/embedding.cpp @@ -35,7 +35,7 @@ int main(int argc, char ** argv) { params.prompt = gpt_random_prompt(rng); } - llama_init_backend(params.numa); + llama_backend_init(params.numa); llama_model * model; llama_context * ctx; @@ -93,5 +93,7 @@ int main(int argc, char ** argv) { llama_free(ctx); llama_free_model(model); + llama_backend_free(); + return 0; } diff --git a/examples/main/main.cpp b/examples/main/main.cpp index ef57a8982..07d8fc6ac 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -105,7 +105,7 @@ int main(int argc, char ** argv) { params.prompt = gpt_random_prompt(rng); } - llama_init_backend(params.numa); + llama_backend_init(params.numa); llama_model * model; llama_context * ctx; @@ -671,7 +671,7 @@ int main(int argc, char ** argv) { llama_free(ctx); llama_free_model(model); - llama_finalize_backend(); + llama_backend_free(); return 0; } diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp index 68f44ba80..7e120ff12 100644 --- a/examples/perplexity/perplexity.cpp +++ b/examples/perplexity/perplexity.cpp @@ -147,7 +147,7 @@ int main(int argc, char ** argv) { params.prompt = gpt_random_prompt(rng); } - llama_init_backend(params.numa); + llama_backend_init(params.numa); llama_model * model; llama_context * ctx; @@ -172,7 +172,7 @@ int main(int argc, char ** argv) { llama_free(ctx); llama_free_model(model); - llama_finalize_backend(); + llama_backend_free(); return 0; } diff --git a/examples/quantize/quantize.cpp b/examples/quantize/quantize.cpp index 1eb0f75d6..797d2f0c5 100644 --- a/examples/quantize/quantize.cpp +++ b/examples/quantize/quantize.cpp @@ -180,7 +180,7 @@ int main(int argc, char ** argv) { usage(argv[0]); } - llama_init_backend(false); + llama_backend_init(false); // parse command line arguments const std::string fname_inp = argv[arg_idx]; @@ -257,5 +257,7 @@ int main(int argc, char ** argv) { printf("%s: total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us)/1000.0); } + llama_backend_free(); + return 0; } diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 2cbfc0018..296c5d646 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -1079,7 +1079,7 @@ int main(int argc, char **argv) params.model_alias = params.model; } - llama_init_backend(params.numa); + llama_backend_init(params.numa); LOG_INFO("build info", {{"build", BUILD_NUMBER}, {"commit", BUILD_COMMIT}}); @@ -1309,5 +1309,7 @@ int main(int argc, char **argv) return 1; } + llama_backend_free(); + return 0; } diff --git a/examples/simple/simple.cpp b/examples/simple/simple.cpp index 57a0fb7c5..aa2c4352d 100644 --- a/examples/simple/simple.cpp +++ b/examples/simple/simple.cpp @@ -66,7 +66,7 @@ int main(int argc, char ** argv) // Init LLM : //--------------------------------- - llama_init_backend(params.numa); + llama_backend_init(params.numa); llama_model * model; llama_context * ctx; @@ -173,7 +173,7 @@ int main(int argc, char ** argv) llama_free( ctx ); llama_free_model( model ); - llama_finalize_backend(); + llama_backend_free(); return 0; } diff --git a/ggml-mpi.c b/ggml-mpi.c index bf301d08b..b68e2c42b 100644 --- a/ggml-mpi.c +++ b/ggml-mpi.c @@ -2,9 +2,11 @@ #include "ggml.h" +#include + #include #include -#include + #define UNUSED GGML_UNUSED struct ggml_mpi_tensor_info { @@ -52,9 +54,8 @@ static void ggml_mpi_compute_forward_recv( struct ggml_tensor * ggml_mpi_send_tensor( struct ggml_context * ctx, - struct ggml_tensor *src, - int dst_rank) { - + struct ggml_tensor * src, + int dst_rank) { struct ggml_tensor * result = ggml_map_custom1_inplace_f32(ctx, src, ggml_mpi_compute_forward_send); // TODO how/when to free this struct? @@ -67,9 +68,9 @@ struct ggml_tensor * ggml_mpi_send_tensor( struct ggml_tensor * ggml_mpi_recv_tensor( struct ggml_context * ctx, - struct ggml_tensor *parent, - struct ggml_tensor *dst, - int src_rank) { + struct ggml_tensor * parent, + struct ggml_tensor * dst, + int src_rank) { struct ggml_tensor * result = ggml_map_custom2_inplace_f32(ctx, dst, parent, ggml_mpi_compute_forward_recv); // TODO how/when to free this struct? @@ -79,3 +80,58 @@ struct ggml_tensor * ggml_mpi_recv_tensor( return result; } + +struct ggml_mpi_context { + int mpi_rank; + int mpi_size; +}; + +void ggml_mpi_backend_init(void) { + MPI_Init(NULL, NULL); +} + +void ggml_mpi_backend_free(void) { + MPI_Finalize(); +} + +struct ggml_mpi_context * ggml_mpi_init(void) { + struct ggml_mpi_context * ctx = calloc(1, sizeof(struct ggml_mpi_context)); + + MPI_Comm_rank(MPI_COMM_WORLD, &ctx->mpi_rank); + MPI_Comm_size(MPI_COMM_WORLD, &ctx->mpi_size); + + return ctx; +} + +void ggml_mpi_free(struct ggml_mpi_context * ctx) { + free(ctx); +} + +int ggml_mpi_rank(struct ggml_mpi_context * ctx) { + return ctx->mpi_rank; +} + +struct ggml_tensor * ggml_mpi_eval_init( + struct ggml_mpi_context * ctx_mpi, + struct ggml_context * ctx, + int n_embd, + int * n_tokens, + int * n_past, + int * n_threads) { + struct ggml_tensor * res = NULL; + + // synchronize the worker node parameters with the root node + MPI_Barrier(MPI_COMM_WORLD); + + MPI_Bcast(n_tokens, 1, MPI_INT, 0, MPI_COMM_WORLD); + MPI_Bcast(n_past, 1, MPI_INT, 0, MPI_COMM_WORLD); + MPI_Bcast(n_threads, 1, MPI_INT, 0, MPI_COMM_WORLD); + + if (ctx_mpi->mpi_rank > 0) { + res = ggml_mpi_recv_tensor(ctx, NULL, + ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, *n_tokens), ctx_mpi->mpi_rank - 1); + ggml_set_name(res, "mpi_recv"); + } + + return res; +} diff --git a/ggml-mpi.h b/ggml-mpi.h index ef5269dc5..157c6255d 100644 --- a/ggml-mpi.h +++ b/ggml-mpi.h @@ -9,13 +9,31 @@ extern "C" { struct ggml_tensor * ggml_mpi_send_tensor( struct ggml_context * ctx, - struct ggml_tensor *src, - int dst_rank); + struct ggml_tensor * src, + int dst_rank); struct ggml_tensor * ggml_mpi_recv_tensor( struct ggml_context * ctx, - struct ggml_tensor *parent, - struct ggml_tensor *dst, - int src_rank); + struct ggml_tensor * parent, + struct ggml_tensor * dst, + int src_rank); + +struct ggml_mpi_context; + +void ggml_mpi_backend_init(void); +void ggml_mpi_backend_free(void); + +struct ggml_mpi_context * ggml_mpi_init(void); +void ggml_mpi_free(struct ggml_mpi_context * ctx); + +int ggml_mpi_rank(struct ggml_mpi_context * ctx); + +struct ggml_tensor * ggml_mpi_eval_init( + struct ggml_mpi_context * ctx_mpi, + struct ggml_context * ctx, + int n_embd, + int * n_tokens, + int * n_past, + int * n_threads); #ifdef __cplusplus } diff --git a/llama.cpp b/llama.cpp index 42b2f6155..d84e827c3 100644 --- a/llama.cpp +++ b/llama.cpp @@ -52,10 +52,6 @@ #include #include -#ifdef GGML_USE_MPI -#include -#endif - #if defined(_MSC_VER) #pragma warning(disable: 4244 4267) // possible loss of data #endif @@ -337,8 +333,9 @@ struct llama_context { ggml_metal_context * ctx_metal = NULL; #endif - int mpi_rank; - int mpi_size; +#ifdef GGML_USE_MPI + ggml_mpi_context * ctx_mpi = NULL; +#endif int buf_last = 0; size_t buf_max_size[LLAMA_MAX_SCRATCH_BUFFERS] = { 0 }; @@ -859,7 +856,7 @@ bool llama_mlock_supported() { return llama_mlock::SUPPORTED; } -void llama_init_backend(bool numa) { +void llama_backend_init(bool numa) { ggml_time_init(); // needed to initialize f16 tables @@ -872,14 +869,15 @@ void llama_init_backend(bool numa) { if (numa) { ggml_numa_init(); } + #ifdef GGML_USE_MPI - MPI_Init(NULL, NULL); + ggml_mpi_backend_init(); #endif } -void llama_finalize_backend() { +void llama_backend_free() { #ifdef GGML_USE_MPI - MPI_Finalize(); + ggml_mpi_backend_free(); #endif } @@ -1282,9 +1280,9 @@ static bool llama_eval_internal( llama_context & lctx, const llama_token * tokens, const float * embd, - const int n_tokens, - const int n_past, - const int n_threads, + int n_tokens, + int n_past, + int n_threads, const char * cgraph_fname) { LLAMA_ASSERT((!tokens && embd) || (tokens && !embd)); @@ -1333,16 +1331,14 @@ static bool llama_eval_internal( struct ggml_tensor * cur; struct ggml_tensor * inpL; - if (lctx.mpi_rank > 0) { #ifdef GGML_USE_MPI - inpL = ggml_mpi_recv_tensor(ctx0, NULL, - ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N), - lctx.mpi_rank-1); - ggml_set_name(inpL, "mpi_recv"); -#else - GGML_ASSERT(false); + inpL = ggml_mpi_eval_init(lctx.ctx_mpi, ctx0, n_embd, &n_tokens, &n_past, &n_threads); + + if (inpL) { + // only rank 0 loads uses the input + } else #endif - } else if (tokens) { + if (tokens) { struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N); ggml_set_name(embd, "embd"); memcpy(embd->data, tokens, N*ggml_element_size(embd)); @@ -1585,7 +1581,6 @@ static bool llama_eval_internal( // input for next layer inpL = cur; - } lctx.use_buf(ctx0, 0); @@ -1601,6 +1596,7 @@ static bool llama_eval_internal( GGML_ASSERT(false); #endif } + if (lctx.mpi_rank == 0) { if (lctx.mpi_size > 1) { #ifdef GGML_USE_MPI @@ -1688,7 +1684,11 @@ static bool llama_eval_internal( // update kv token count lctx.kv_self.n = n_past + N; - if (lctx.mpi_rank == 0) { +#ifdef GGML_USE_MPI + if (ggml_mpi_rank(lctx.ctx_mpi) == 0) { +#else + { +#endif // extract logits { auto & logits_out = lctx.logits; @@ -2659,14 +2659,6 @@ struct llama_context * llama_new_context_with_model( ctx->rng = std::mt19937(params.seed); ctx->logits_all = params.logits_all; -#ifdef GGML_USE_MPI - MPI_Comm_size(MPI_COMM_WORLD, &ctx->mpi_size); - MPI_Comm_rank(MPI_COMM_WORLD, &ctx->mpi_rank); -#else - ctx->mpi_size = 1; - ctx->mpi_rank = 0; -#endif - ggml_type memory_type = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32; // reserve memory for context buffers @@ -2739,15 +2731,17 @@ struct llama_context * llama_new_context_with_model( } #endif - if (ctx->mpi_rank > 0) { +#ifdef GGML_USE_MPI + ctx->ctx_mpi = ggml_mpi_init(); + + if (ggml_mpi_rank(ctx->ctx_mpi) > 0) { // Enter a blocking eval loop with dummy input, letting rank=0 drive the process const std::vector tmp = { llama_token_bos(), }; - while (!llama_eval(ctx, tmp.data(), tmp.size(), 0, 0)); -#ifdef GGML_USE_MPI - MPI_Finalize(); -#endif + while (!llama_eval(ctx, tmp.data(), tmp.size(), 0, 0)) {}; + llama_backend_free(); exit(1); } +#endif return ctx; } @@ -3425,13 +3419,6 @@ int llama_eval( int n_tokens, int n_past, int n_threads) { -#ifdef GGML_USE_MPI - // Synchronize the worker node parameters with the root node - MPI_Barrier(MPI_COMM_WORLD); - MPI_Bcast(&n_past, 1, MPI_INT, 0, MPI_COMM_WORLD); - MPI_Bcast(&n_tokens, 1, MPI_INT, 0, MPI_COMM_WORLD); - MPI_Bcast(&n_threads, 1, MPI_INT, 0, MPI_COMM_WORLD); -#endif if (!llama_eval_internal(*ctx, tokens, nullptr, n_tokens, n_past, n_threads, nullptr)) { fprintf(stderr, "%s: failed to eval\n", __func__); return 1; diff --git a/llama.h b/llama.h index b90c52355..686463aa2 100644 --- a/llama.h +++ b/llama.h @@ -158,9 +158,9 @@ extern "C" { // Initialize the llama + ggml backend // If numa is true, use NUMA optimizations // Call once at the start of the program - LLAMA_API void llama_init_backend(bool numa); + LLAMA_API void llama_backend_init(bool numa); // Call once at the end of the program - currently only used for MPI - LLAMA_API void llama_finalize_backend(); + LLAMA_API void llama_backend_free(); LLAMA_API int64_t llama_time_us();