llama : add llama_init_backend() API (close #1527)

This commit is contained in:
Georgi Gerganov 2023-05-20 11:06:11 +03:00
parent d2c59b8ba4
commit ec2e10c444
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
7 changed files with 48 additions and 29 deletions

View File

@ -1,6 +1,7 @@
#include <locale.h>
#include "ggml.h" #include "ggml.h"
#include "build-info.h" #include "build-info.h"
#include <locale.h>
#include <assert.h> #include <assert.h>
#include <math.h> #include <math.h>
#include <cstring> #include <cstring>

View File

@ -31,6 +31,8 @@ int main(int argc, char ** argv) {
params.prompt = gpt_random_prompt(rng); params.prompt = gpt_random_prompt(rng);
} }
llama_init_backend();
llama_context * ctx; llama_context * ctx;
// load the model // load the model

View File

@ -96,8 +96,7 @@ int main(int argc, char ** argv) {
params.prompt = gpt_random_prompt(rng); params.prompt = gpt_random_prompt(rng);
} }
// params.prompt = R"(// this function checks if the number n is prime llama_init_backend();
//bool is_prime(int n) {)";
llama_context * ctx; llama_context * ctx;
g_ctx = &ctx; g_ctx = &ctx;

View File

@ -143,6 +143,8 @@ int main(int argc, char ** argv) {
params.prompt = gpt_random_prompt(rng); params.prompt = gpt_random_prompt(rng);
} }
llama_init_backend();
llama_context * ctx; llama_context * ctx;
// load the model and apply lora adapter, if any // load the model and apply lora adapter, if any

View File

@ -1,7 +1,7 @@
#include "ggml.h"
#include "llama.h"
#include "build-info.h" #include "build-info.h"
#include "llama.h"
#include <cstdio> #include <cstdio>
#include <map> #include <map>
#include <string> #include <string>
@ -42,8 +42,6 @@ bool try_parse_ftype(const std::string & ftype_str, llama_ftype & ftype, std::st
// ./quantize models/llama/ggml-model.bin [models/llama/ggml-model-quant.bin] type [nthreads] // ./quantize models/llama/ggml-model.bin [models/llama/ggml-model-quant.bin] type [nthreads]
// //
int main(int argc, char ** argv) { int main(int argc, char ** argv) {
ggml_time_init();
if (argc < 3) { if (argc < 3) {
fprintf(stderr, "usage: %s model-f32.bin [model-quant.bin] type [nthreads]\n", argv[0]); fprintf(stderr, "usage: %s model-f32.bin [model-quant.bin] type [nthreads]\n", argv[0]);
for (auto it = LLAMA_FTYPE_MAP.begin(); it != LLAMA_FTYPE_MAP.end(); it++) { for (auto it = LLAMA_FTYPE_MAP.begin(); it != LLAMA_FTYPE_MAP.end(); it++) {
@ -52,12 +50,7 @@ int main(int argc, char ** argv) {
return 1; return 1;
} }
// needed to initialize f16 tables llama_init_backend();
{
struct ggml_init_params params = { 0, NULL, false };
struct ggml_context * ctx = ggml_init(params);
ggml_free(ctx);
}
// parse command line arguments // parse command line arguments
const std::string fname_inp = argv[1]; const std::string fname_inp = argv[1];
@ -116,25 +109,25 @@ int main(int argc, char ** argv) {
} }
fprintf(stderr, "\n"); fprintf(stderr, "\n");
const int64_t t_main_start_us = ggml_time_us(); const int64_t t_main_start_us = llama_time_us();
int64_t t_quantize_us = 0; int64_t t_quantize_us = 0;
// load the model // load the model
{ {
const int64_t t_start_us = ggml_time_us(); const int64_t t_start_us = llama_time_us();
if (llama_model_quantize(fname_inp.c_str(), fname_out.c_str(), ftype, nthread)) { if (llama_model_quantize(fname_inp.c_str(), fname_out.c_str(), ftype, nthread)) {
fprintf(stderr, "%s: failed to quantize model from '%s'\n", __func__, fname_inp.c_str()); fprintf(stderr, "%s: failed to quantize model from '%s'\n", __func__, fname_inp.c_str());
return 1; return 1;
} }
t_quantize_us = ggml_time_us() - t_start_us; t_quantize_us = llama_time_us() - t_start_us;
} }
// report timing // report timing
{ {
const int64_t t_main_end_us = ggml_time_us(); const int64_t t_main_end_us = llama_time_us();
printf("\n"); printf("\n");
printf("%s: quantize time = %8.2f ms\n", __func__, t_quantize_us/1000.0); printf("%s: quantize time = %8.2f ms\n", __func__, t_quantize_us/1000.0);

View File

@ -839,6 +839,21 @@ bool llama_mlock_supported() {
return llama_mlock::SUPPORTED; return llama_mlock::SUPPORTED;
} }
void llama_init_backend() {
ggml_time_init();
// needed to initialize f16 tables
{
struct ggml_init_params params = { 0, NULL, false };
struct ggml_context * ctx = ggml_init(params);
ggml_free(ctx);
}
}
int64_t llama_time_us() {
return ggml_time_us();
}
// //
// model loading // model loading
// //

31
llama.h
View File

@ -40,9 +40,9 @@ extern "C" {
typedef int llama_token; typedef int llama_token;
typedef struct llama_token_data { typedef struct llama_token_data {
llama_token id; // token id llama_token id; // token id
float logit; // log-odds of the token float logit; // log-odds of the token
float p; // probability of the token float p; // probability of the token
} llama_token_data; } llama_token_data;
typedef struct llama_token_data_array { typedef struct llama_token_data_array {
@ -73,16 +73,16 @@ extern "C" {
// model file types // model file types
enum llama_ftype { enum llama_ftype {
LLAMA_FTYPE_ALL_F32 = 0, LLAMA_FTYPE_ALL_F32 = 0,
LLAMA_FTYPE_MOSTLY_F16 = 1, // except 1d tensors LLAMA_FTYPE_MOSTLY_F16 = 1, // except 1d tensors
LLAMA_FTYPE_MOSTLY_Q4_0 = 2, // except 1d tensors LLAMA_FTYPE_MOSTLY_Q4_0 = 2, // except 1d tensors
LLAMA_FTYPE_MOSTLY_Q4_1 = 3, // except 1d tensors LLAMA_FTYPE_MOSTLY_Q4_1 = 3, // except 1d tensors
LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16 = 4, // tok_embeddings.weight and output.weight are F16 LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16 = 4, // tok_embeddings.weight and output.weight are F16
// LLAMA_FTYPE_MOSTLY_Q4_2 = 5, // support has been removed // LLAMA_FTYPE_MOSTLY_Q4_2 = 5, // support has been removed
// LLAMA_FTYPE_MOSTLY_Q4_3 (6) support has been removed // LLAMA_FTYPE_MOSTLY_Q4_3 = 6, // support has been removed
LLAMA_FTYPE_MOSTLY_Q8_0 = 7, // except 1d tensors LLAMA_FTYPE_MOSTLY_Q8_0 = 7, // except 1d tensors
LLAMA_FTYPE_MOSTLY_Q5_0 = 8, // except 1d tensors LLAMA_FTYPE_MOSTLY_Q5_0 = 8, // except 1d tensors
LLAMA_FTYPE_MOSTLY_Q5_1 = 9, // except 1d tensors LLAMA_FTYPE_MOSTLY_Q5_1 = 9, // except 1d tensors
}; };
LLAMA_API struct llama_context_params llama_context_default_params(); LLAMA_API struct llama_context_params llama_context_default_params();
@ -90,6 +90,13 @@ extern "C" {
LLAMA_API bool llama_mmap_supported(); LLAMA_API bool llama_mmap_supported();
LLAMA_API bool llama_mlock_supported(); LLAMA_API bool llama_mlock_supported();
// TODO: not great API - very likely to change
// Initialize the llama + ggml backend
// Call once at the start of the program
LLAMA_API void llama_init_backend();
LLAMA_API int64_t llama_time_us();
// Various functions for loading a ggml llama model. // Various functions for loading a ggml llama model.
// Allocate (almost) all memory needed for the model. // Allocate (almost) all memory needed for the model.
// Return NULL on failure // Return NULL on failure