mirror of https://github.com/ggerganov/llama.cpp.git
synced 2024-12-27 03:44:35 +00:00
44c117f41e
* fix track_max_mem in forward_batch_wo_cache_flash_attn_train
* remove unnecessary Adam(W) optimizer tensors.
  reduces optimizer memory overhead from 7*modelsize to 2*modelsize.
  additionally allows to optimize models with more than 2^31 parameters by replacing int with int64_t.
  bumps training checkpoint file version, but old checkpoints can still be read.
  new version with fewer tensors is saved.
* add gradient clipping to AdamW
* Fix reset of unused g->nodes and g->grads to NULL
* implement gradient checkpointing for training
  reduces memory overhead from O(n_layer) to O(sqrt(n_layer))
  as explained in readme of https://github.com/cybertronai/gradient-checkpointing
* remove unused compute buffer 3
* add and use function ggml_build_backward_expand to avoid stack overflows with large maximum number of nodes
  GGML_API void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph * gf, struct ggml_cgraph * gb, bool keep);
* change AdamW decay parameter to work like the torch AdamW decay parameter
  It is now relative to Adam learning rate `alpha*sched`.
  Before that it was relative to `sched` only.
  `alpha` being the maximum learning rate and `sched` being a scaling parameter in [0..1]
* change default AdamW weight decay parameter used in training to 0.1 as used in nanoGPT
* change default AdamW weight decay parameter defined in ggml to 0.0, making Adam default instead of AdamW
  btw: the default weight decay parameter for torch.optim.AdamW is 0.01
* bug fixes for cross entropy loss
  ggml_cross_entropy_loss: sums were not correctly added in workload of each thread
  ggml_cross_entropy_loss_back: simplify backward process, reducing numerical issues
  guard usage of exp f16 lookup in cross entropy by #define GGML_CROSS_ENTROPY_EXP_FP16
  cross entropy loss is only used once during training, but it is quite sensitive to numerical errors introduced by exp-f16-lookup.
  so exp-f16-lookup for cross entropy loss is disabled by default, trading better gradients for very slightly worse runtime performance.
* fix test-grad0 for cross_entropy_loss
  the second argument to cross_entropy_loss must sum up to 1 for each row
* fix test-grad0 for soft_max
  don't use only sum as aggregation, because sum of softmax is always 1 -> finite differences should not work
  instead use sum(log(soft_max()*(1-eps)+eps)); use eps to avoid log(0)
* improve finite differences of test-grad0 by using double instead of float
* change cross_entropy_loss to output average over all rows
  this helps keeping the loss and gradients in a sane range
* improve gradient checkpointing
  sqrt(n_layers) is only the best checkpoint step when mem size of checkpoints and mem size of layers are equal.
  since layers require more memory than the single-tensor-checkpoint we use, the optimal values are computed differently:
  ```
  given: n, u, v
  objective: minimize(a*u+b*v) where a*b=n, a>0, b>0
  b=n/a
  minimize(a*u+v*n/a)
  diff(a*u+v*n/a, a) = u - (v*n/a)/a
  diff(a*u+v*n/a, a) == 0
  u - (v*n/a)/a == 0
  u == v*n/(a*a)
  u*a*a = v*n
  a*a = v*n/u
  a = sqrt(n*v/u)
  ```
  this change results in more checkpoints, requiring fewer layers to store between checkpoints, overall improving memory usage.
* disable gradient checkpointing debug output
* llama : fix rope usage in train-text-from-scratch after ChatGLM change
* add more training parameters:
  --enable-restart N          Only for Adam optimizer. Enable restarts of cos-decay
  --disable-restart N         Only for Adam optimizer. Disable restarts of cos-decay
  --opt-past N                Number of optimization iterations to track for delta convergence test. Disabled when zero.
  --opt-delta N               Maximum delta for delta convergence test. Disabled when <= zero.
  --opt-max-no-improvement N  Maximum number of optimization iterations with no improvement. Disabled when <= zero.
  --adam-epsf N               AdamW epsilon for convergence test. Disabled when <= zero.
  --adam-min-alpha N          Adam minimum learning rate alpha, usually 0.1 * alpha
* replace memcpy with reshape operation so that the graph is not cut at the input
  this makes it possible to store other values into the input tensor and then simply recompute the graph without rebuilding it
* remove unused function argument from get_example_targets_batch
* measure and print total training time
* add optimization callback to ggml_opt_resume_g
  this callback is called before each iteration with custom data and pointer to learning schedule parameter (only used in Adam(W)).
  can be used for dynamic learning schedule and setting input data for batches before each iteration
* use optimization callback in training
  allows dynamic learning schedule and different batch data for each iteration without relying on low n_iter and high n_examples parameters
  reduces runtime by avoiding restart of optimization function and improves training convergence by providing a different batch for each iteration
* add minimum number of tensor dimensions to apply weight decay (default 2)
  this allows to not apply weight decay to bias parameters
* rename training parameter cos-decay-alpha to cos-decay-min and clarify that adam-min-alpha also applies to warmup
* fix increase of model.train_samples and model.train_tokens
  now that each optimizer iteration gets its own batch we need to multiply by number of opt iterations
* change sampling parameters for prediction after training to defaults of common.h
  and clarify what is context for prediction and what are generated tokens
* tighten abs error bounds for cross_entropy_loss in test-grad0
* add conditional compilation of using F16 exp in flash attention
  uncomment `// #define GGML_FLASH_ATTN_EXP_FP16` to enable usage of f16 exp in flash attention
* tighten abs error bounds for flash_attn in test-grad0
* tighten abs error bounds for sqrt in test-grad0
* remove out-commented vectorized code of opt_adam
  the vectorized code might be a bit faster for low number of parameters, but it had a big memory usage overhead
* ggml : update ggml_rms_norm_back with configurable eps
* llama training : fix ggml_rms_norm_back calls to pass configurable eps
* remove trailing whitespace
* add train function using automatic gradient checkpointing backward pass and allocator
* in train function replace add_inplace by regular add
  because using add_inplace seems to result in different gradients
* don't allocate hash_map on context
  because the context has no_alloc=True when using memory allocator resulting in NULL data pointers
* correctly clone reshape and permute operations by also cloning tensor->nb values
* fix variable name and add missing type cast
* terminate recursive tensor cloning when reaching tensor without src tensors
* correctly clone view tensors by setting data pointers
  without this the checkpointing would only work when being used together with memory allocator
* fix variable names
* swap arguments to commutative ops to be the same as in `forward_batch_wo_cache_flash_attn`
* add input tensors as checkpoints
  so that recursive tensor cloning of gradient checkpointing terminates on input tensors
* fix variable name and add missing boolean negation
* make sure some tensors are not reallocated by inserting new temporary nodes depending on them:
  output and parameter gradient tensors need to be available at the end of the graph execution
  parameter gradient tensors also need to be available before the graph execution because they are set to zero before each optimizer iteration
  checkpoint tensors are allocated all together to reduce memory allocator fragmentation
  afterwards, in addition to the temporary nodes, we also need to reset the temporary leafs
* fix ASSERT to work with zero layers
* add training options whether to use allocator and/or unified training function
* integrate unified training function which may use memory allocator
  the unified training function also supports arguments whether to use flash attention and/or gradient checkpointing
* format name of cloned tensors with " (clone)" suffix
* set names for tensors in unified train function for easier debugging
* allocate graph on context using ggml_new_graph
* remove handwritten training functions
* remove unused training parameters "use_scratch" and "use_unified"
* remove trailing whitespace
* remove unused train params: mem_compute1_gb & mem_compute2_gb
  mem_compute_gb is used for compute when automatic memory allocator is not enabled, otherwise it can be very small to only hold the tensor definitions
  mem_compute0_gb is used for automatic memory allocator (as long as measurement of max required size is not implemented)
* remove unused forward_batch function
* add debug asserts in ggml_allocr_alloc to catch some common pitfalls when using this function directly
* only use ggml_allocr_alloc when tensor has NULL data and is no view
* fix test when to create temporary backward graph
  temporary backward graph is only necessary when using checkpointing
* fix memory "leak" in optimizers
  each iteration a new cplan with new memory for work data was allocated.
  now cplan creation only happens at the start of optimization, with each iteration reusing the cplan and its work data.
* reverse order of for loop in ggml_build_backward_expand to save memory when using gradient checkpointing and allocator
  with this loop order gradient checkpointing with allocator on a 16 layer model saves 13% memory; on a 2 layer model it saves 2% memory.
  the computation results are the same
* add missing lctx argument to get_example_targets_batch
* implement llama model file saving using gguf
  checkpoint loading and saving disabled, to be replaced by loading and saving via gguf
* implement loading/saving of checkpointing files using GGUF
* bug fixes
* add checkpoint file version for future compatibility
* update readme with gguf filenames
* save & load opt->just_initialized value
* add first draft for checkpoint conversion script
* add gguf arch and ftype
* save opt parameter counter as uint64
* add gguf key and tensor names for optimizer and training
* add layer_norm_rms_eps to checkpoint convert script
* use same GGUF_GET_KEY macro as in llama.cpp
* use norm_rms_eps, and rope parameters and command line options to set them
* fix memory corruption bug in gguf
  ctx->kv and ctx->infos was reallocated using not-aligned realloc, but freed with aligned free.
  to fix this a GGML_ALIGNED_REALLOC was added, but there is no posix_memalign_realloc function.
  so on non-windows and non-mingw32 platforms we fall back to aligned malloc, followed by copying and freeing the old data.
* add gguf example cmake file
* bug fixes in tokenize_file
* bug fixes in load_llama_model_gguf
* bug fix: init model when no checkpoint was loaded
* bug fix in read_tensor_by_name
* bug fix in load_opt_context_gguf
* avoid printing lots of spaces in the unusual case that loss gets nan
* set name of tensors with empty name from what was read from gguf
* remove trailing whitespace
* print data checksums before saving and after loading to verify correctness
* bug fixes for convert-train-checkpoint-to-gguf
* temporarily add code to write old checkpoint files
  used to verify that old checkpoint files are correctly converted to gguf
* bug fixes for convert-train-checkpoint-to-gguf.py loading checkpoints with opt_version=0
* remove code used to verify correctness of checkpoint file conversion
* remove trailing whitespace
* remove prediction related code
  use main for prediction, it is better optimized
* update train-text-from-scratch README.md
* fix non-windows GGML_ALIGNED_REALLOC
* add missing blank line at end of file
* remove GGML_ALIGNED_REALLOC and use normal malloc/realloc/free for gguf ctx->kv & ctx->infos
* train : fix compile warnings

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
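The AdamW bullets in this commit (only two moment tensors per parameter, gradient clipping, weight decay relative to `alpha*sched`, and a minimum tensor rank for decay) can be illustrated with a small standalone sketch. It is not ggml's optimizer code: the structs `adamw_params`/`adamw_state`, the helpers `clip_grad_norm` and `adamw_step`, and the choice of clipping by L2 norm are hypothetical. The decay term follows the torch convention `x *= 1 - lr*decay` with `lr = alpha*sched`, which is how the commit describes the new behavior.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical hyperparameters; names are illustrative, not ggml's.
struct adamw_params {
    float alpha          = 1e-3f; // maximum learning rate
    float beta1          = 0.9f;
    float beta2          = 0.999f;
    float eps            = 1e-8f;
    float decay          = 0.1f;  // weight decay, scaled by alpha*sched (torch-style)
    float gclip          = 1.0f;  // max gradient L2 norm; <= 0 disables clipping
    int   decay_min_ndim = 2;     // tensors with fewer dims (e.g. biases) get no decay
};

// Optimizer state: only the first and second moments, i.e. 2 floats per parameter.
struct adamw_state {
    std::vector<float> m, v;
    long long t = 0;
};

// Rescale g in place so that its L2 norm does not exceed max_norm.
void clip_grad_norm(std::vector<float> & g, float max_norm) {
    double sum = 0.0;
    for (float gi : g) sum += (double) gi * gi;
    const double norm = std::sqrt(sum);
    if (max_norm > 0.0f && norm > max_norm) {
        const float s = (float) (max_norm / norm);
        for (float & gi : g) gi *= s;
    }
}

// One AdamW step on a flat parameter tensor x; g is taken by value so it can be clipped.
void adamw_step(std::vector<float> & x, std::vector<float> g, int n_dims,
                adamw_state & st, const adamw_params & p, float sched) {
    assert(g.size() == x.size());
    if (st.m.empty()) {
        st.m.assign(x.size(), 0.0f);
        st.v.assign(x.size(), 0.0f);
    }
    clip_grad_norm(g, p.gclip);

    st.t += 1;
    const float lr    = p.alpha * sched;
    const float decay = (n_dims >= p.decay_min_ndim) ? p.decay : 0.0f;
    const float b1t   = 1.0f - std::pow(p.beta1, (float) st.t); // bias corrections
    const float b2t   = 1.0f - std::pow(p.beta2, (float) st.t);

    for (std::size_t i = 0; i < x.size(); ++i) {
        st.m[i] = p.beta1 * st.m[i] + (1.0f - p.beta1) * g[i];
        st.v[i] = p.beta2 * st.v[i] + (1.0f - p.beta2) * g[i] * g[i];
        const float mh = st.m[i] / b1t;
        const float vh = st.v[i] / b2t;
        // decoupled weight decay uses the same lr = alpha*sched as the Adam update
        x[i] = x[i] * (1.0f - lr * decay) - lr * mh / (std::sqrt(vh) + p.eps);
    }
}
```

With a zero gradient, a 2-d tensor simply shrinks by the factor `1 - alpha*sched*decay`, while a 1-d tensor (below `decay_min_ndim`) is left untouched by decay. Keeping only `m` and `v` is what brings the optimizer state down to two extra floats per parameter.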
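The derivation in the "improve gradient checkpointing" bullet can be checked numerically. The sketch below reads `n` as the number of layers, `u` as the memory of one checkpoint and `v` as the memory of one layer kept alive between checkpoints (an interpretation of the commit text), and compares the closed form `a = sqrt(n*v/u)` with a brute-force search over integer checkpoint counts, assuming a segment then holds `ceil(n/a)` layers. Both function names are hypothetical.

```cpp
#include <cassert>
#include <cmath>

// Closed form from the derivation: minimizing a*u + (n/a)*v gives a = sqrt(n*v/u).
double optimal_checkpoint_count(double n, double u, double v) {
    return std::sqrt(n * v / u);
}

// Integer search: a checkpoints, ceil(n/a) layers stored between two checkpoints.
int best_checkpoint_count(int n, double u, double v) {
    int    best_a    = 1;
    double best_cost = u + n * v; // a = 1: one checkpoint, all n layers in one segment
    for (int a = 2; a <= n; ++a) {
        const int    b    = (n + a - 1) / a;
        const double cost = a * u + b * v;
        if (cost < best_cost) {
            best_cost = cost;
            best_a    = a;
        }
    }
    return best_a;
}
```

With `u == v` the optimum is `a = sqrt(n)`, e.g. 4 checkpoints for 16 layers. If a layer costs four times as much as a checkpoint, the optimum moves to 8 checkpoints, matching the commit's remark that cheaper checkpoints lead to more of them and fewer layers stored in between.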
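The gradient checkpointing technique referenced in this commit (from the cybertronai/gradient-checkpointing readme) can be shown on a toy chain of scalar layers: the forward pass keeps only every `step`-th activation, and the backward pass recomputes each segment from its checkpoint before backpropagating through it. This is a simplified sketch with assumed toy layers `y -> tanh(w*y)`; it does not use ggml graphs or `ggml_recompute_graph_node`, and the function names are hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Toy layer: y -> tanh(w*y); its derivative w.r.t. y is w * (1 - tanh(w*y)^2).
static double layer_fwd(double w, double y) { return std::tanh(w * y); }
static double layer_dy(double w, double y) {
    const double t = std::tanh(w * y);
    return w * (1.0 - t * t);
}

// Baseline: store all n+1 activations, then backpropagate dL/dy0 for L = y_n.
double grad_full(const std::vector<double> & w, double y0) {
    std::vector<double> ys{y0};
    for (double wi : w) ys.push_back(layer_fwd(wi, ys.back()));
    double g = 1.0;
    for (std::size_t i = w.size(); i-- > 0;) g *= layer_dy(w[i], ys[i]);
    return g;
}

// Checkpointed: keep y at layers 0, step, 2*step, ... during the forward pass and
// recompute each segment from its checkpoint while walking backwards.
double grad_checkpointed(const std::vector<double> & w, double y0, std::size_t step,
                         std::size_t * peak_stored) {
    assert(step > 0);
    const std::size_t n = w.size();
    std::vector<double> ckpt{y0}; // ckpt[k] holds the activation entering layer k*step
    double y = y0;
    for (std::size_t i = 0; i < n; ++i) {
        y = layer_fwd(w[i], y);
        if ((i + 1) % step == 0 && i + 1 < n) ckpt.push_back(y);
    }
    std::size_t peak = 0;
    double g = 1.0;
    for (std::size_t k = ckpt.size(); k-- > 0;) {
        const std::size_t begin = k * step;
        const std::size_t end   = std::min(n, begin + step);
        std::vector<double> seg{ckpt[k]}; // seg[j] = activation entering layer begin+j
        for (std::size_t i = begin; i + 1 < end; ++i) seg.push_back(layer_fwd(w[i], seg.back()));
        peak = std::max(peak, ckpt.size() + seg.size()); // values alive at this point
        for (std::size_t i = end; i-- > begin;) g *= layer_dy(w[i], seg[i - begin]);
    }
    if (peak_stored) *peak_stored = peak;
    return g;
}
```

For 16 layers and `step = 4` (that is, sqrt(16)), at most 8 activations are alive at once instead of 17, and the gradient matches the baseline because the recomputed activations are the same values as in the first forward pass. The price is roughly one extra forward pass.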
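The cross entropy bullets (target rows that sum to 1, a loss averaged over all rows, and a simplified backward pass) correspond to the following reference computation in double precision. It is a sketch, not ggml's `ggml_cross_entropy_loss` kernel; the function name and the row-major layout are assumptions. Because each target row sums to 1, the gradient with respect to the logits reduces to `(softmax - target) / n_rows`.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// logits and target are row-major with nc columns; each target row must sum to 1.
// Returns the mean over rows of -sum_c target[r][c] * log_softmax(logits[r])[c].
// If grad is non-null it receives d(loss)/d(logits) = (softmax - target) / n_rows.
double cross_entropy_mean(const std::vector<double> & logits, const std::vector<double> & target,
                          std::size_t nc, std::vector<double> * grad) {
    assert(nc > 0 && logits.size() == target.size() && logits.size() % nc == 0);
    const std::size_t nr = logits.size() / nc;
    if (grad) grad->assign(logits.size(), 0.0);
    double loss = 0.0;
    for (std::size_t r = 0; r < nr; ++r) {
        const double * a = &logits[r * nc];
        const double * b = &target[r * nc];
        const double mx = *std::max_element(a, a + nc); // subtract the row max for stability
        double sum = 0.0;
        for (std::size_t c = 0; c < nc; ++c) sum += std::exp(a[c] - mx);
        const double log_sum = std::log(sum);
        for (std::size_t c = 0; c < nc; ++c) {
            const double log_p = a[c] - mx - log_sum; // log softmax
            loss -= b[c] * log_p;
            if (grad) (*grad)[r * nc + c] = (std::exp(log_p) - b[c]) / (double) nr;
        }
    }
    return loss / (double) nr;
}
```

If a target row did not sum to 1, the true gradient would be `sum(target_row) * softmax - target`, which is why the test-grad0 fix requires normalized targets.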
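The test-grad0 fix for `soft_max` can be reproduced with double-precision central finite differences: the plain sum of a softmax is constantly 1, so its numerical gradient is zero and checks nothing, while `sum(log(softmax(x)*(1-eps)+eps))` has a non-trivial gradient that can be compared against an analytic one. The helpers below are hypothetical and independent of the real test-grad0 harness; the analytic gradient is derived here from the chain rule, with `q = p*(1-eps)+eps`.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

std::vector<double> softmax(const std::vector<double> & x) {
    const double mx = *std::max_element(x.begin(), x.end());
    std::vector<double> p(x.size());
    double sum = 0.0;
    for (std::size_t i = 0; i < x.size(); ++i) { p[i] = std::exp(x[i] - mx); sum += p[i]; }
    for (double & pi : p) pi /= sum;
    return p;
}

// Naive aggregation: always 1, so every gradient entry is 0 and the check is vacuous.
double agg_sum(const std::vector<double> & x) {
    double s = 0.0;
    for (double pi : softmax(x)) s += pi;
    return s;
}

// Aggregation from the commit: sum(log(softmax(x)*(1-eps)+eps)); eps avoids log(0).
double agg_log(const std::vector<double> & x, double eps) {
    double s = 0.0;
    for (double pi : softmax(x)) s += std::log(pi * (1.0 - eps) + eps);
    return s;
}

// Analytic gradient of agg_log: (1-eps) * (p_k/q_k - p_k * sum_j p_j/q_j).
std::vector<double> agg_log_grad(const std::vector<double> & x, double eps) {
    const std::vector<double> p = softmax(x);
    double s = 0.0;
    for (double pj : p) s += pj / (pj * (1.0 - eps) + eps);
    std::vector<double> g(p.size());
    for (std::size_t k = 0; k < p.size(); ++k) {
        g[k] = (1.0 - eps) * (p[k] / (p[k] * (1.0 - eps) + eps) - p[k] * s);
    }
    return g;
}

// Central finite difference of f along coordinate k, in double precision.
template <typename F>
double finite_diff(F f, std::vector<double> x, std::size_t k, double h) {
    x[k] += h;
    const double fp = f(x);
    x[k] -= 2 * h;
    const double fm = f(x);
    return (fp - fm) / (2 * h);
}
```

Using `double` for the finite differences keeps the rounding error (about machine epsilon divided by `h`) far below the tolerance, which is the motivation given for the "use double instead of float" bullet.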
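The gguf memory corruption fix describes a copy-based fallback for aligned reallocation. A minimal sketch of that pattern in standard C++17 is shown below; it assumes the caller knows the old size, uses aligned `operator new`/`operator delete` instead of ggml's `GGML_ALIGNED_MALLOC`/`GGML_ALIGNED_FREE`, and the helper names are hypothetical. Note that the commit's final bullets removed `GGML_ALIGNED_REALLOC` again in favor of plain `malloc`/`realloc`/`free` for `ctx->kv` and `ctx->infos`.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <new>

// Hypothetical helpers built on C++17 aligned operator new/delete;
// alignment must be a power of two.
void * aligned_malloc(std::size_t size, std::size_t alignment) {
    return ::operator new(size, std::align_val_t(alignment));
}

void aligned_free(void * p, std::size_t alignment) {
    ::operator delete(p, std::align_val_t(alignment));
}

// There is no aligned realloc primitive, so: allocate a new aligned block, copy the
// bytes that fit, and release the old block with the matching aligned free.
void * aligned_realloc(void * p, std::size_t old_size, std::size_t new_size, std::size_t alignment) {
    void * q = aligned_malloc(new_size, alignment);
    if (p != nullptr) {
        std::memcpy(q, p, old_size < new_size ? old_size : new_size);
        aligned_free(p, alignment);
    }
    return q;
}
```

The essential rule is pairing: a block must be released by the deallocation function that matches how it was allocated. Growing an aligned block with plain `realloc` and later freeing it with an aligned free is exactly the mismatch this bug fix addressed.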
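The stopping options `--opt-past`, `--opt-delta` and `--opt-max-no-improvement` can be modeled as a small tracker: a ring buffer of the last `past` losses drives a relative-change test, and a counter tracks consecutive iterations without a new best loss. The exact formulas inside ggml's optimizer may differ; the relative rate `(loss[t-past] - loss[t]) / loss[t]` and the `convergence_tracker` struct are assumptions for illustration, following the "disabled when zero / <= zero" semantics listed in the commit.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical stopping-rule tracker mirroring the three options described above.
struct convergence_tracker {
    int   past;               // window length; 0 disables the delta test
    float delta;              // relative-change threshold; <= 0 disables the delta test
    int   max_no_improvement; // <= 0 disables the no-improvement test
    std::vector<float> ring;  // losses of the last `past` iterations
    long long iter           = 0;
    float     best           = INFINITY;
    int       no_improvement = 0;

    convergence_tracker(int past_, float delta_, int max_no_improvement_)
        : past(past_), delta(delta_), max_no_improvement(max_no_improvement_),
          ring(past_ > 0 ? past_ : 0, 0.0f) {}

    // Record the loss of the current iteration; returns true if optimization should stop.
    bool update(float loss) {
        bool stop = false;
        if (past > 0 && delta > 0.0f && iter >= past) {
            const float old  = ring[iter % past];   // loss from `past` iterations ago
            const float rate = (old - loss) / loss; // relative change over the window
            if (std::fabs(rate) < delta) stop = true;
        }
        if (past > 0) ring[iter % past] = loss;
        if (max_no_improvement > 0) {
            if (loss < best) {
                best = loss;
                no_improvement = 0;
            } else if (++no_improvement >= max_no_improvement) {
                stop = true;
            }
        }
        ++iter;
        return stop;
    }
};
```

With `past = 2` and `delta = 1e-3`, the tracker stops once the loss equals the loss from two iterations earlier; with `max_no_improvement = 2`, it stops after two consecutive iterations without a new best loss.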
2298 lines
98 KiB
C++
#include "ggml.h"
#include "ggml-alloc.h"
#include "common.h"
#include "llama.h"
#include <unordered_map>
#include <vector>
#include <cassert>
#include <climits>
#include <cstring>
#include <cstdarg>
#include <ctime>
#include <random>
#include <stdexcept>
#include <algorithm>
#include <string>

#if defined(_MSC_VER)
#pragma warning(disable: 4244 4267) // possible loss of data
#endif

struct random_normal_distribution {
    std::mt19937 gen;
    std::normal_distribution<float> rd;
    float min;
    float max;
};

struct random_uniform_distribution {
    std::mt19937 gen;
    std::uniform_real_distribution<float> rd;
};

void init_random_normal_distribution(struct random_normal_distribution * rnd, int seed, float mean, float std, float min, float max) {
    rnd->gen = std::mt19937(seed);
    rnd->rd = std::normal_distribution<float>{mean, std};
    rnd->min = min;
    rnd->max = max;
}

void init_random_uniform_distribution(struct random_uniform_distribution * rnd, int seed, float min, float max) {
    rnd->gen = std::mt19937(seed);
    rnd->rd = std::uniform_real_distribution<float>{min, max};
}

int clamp(const int v, const int min, const int max) {
    return ((v < min) ? (min) : (v > max) ? (max) : v);
}

float fclamp(const float v, const float min, const float max) {
    return ((v < min) ? (min) : (v > max) ? (max) : v);
}

float frand() {
    return (float)rand()/(float)RAND_MAX;
}

float frand_normal(struct random_normal_distribution * rnd) {
    return fclamp(rnd->rd(rnd->gen), rnd->min, rnd->max);
}

float frand_uniform(struct random_uniform_distribution * rnd) {
    return rnd->rd(rnd->gen);
}

struct ggml_tensor * randomize_tensor_normal(struct ggml_tensor * tensor, struct random_normal_distribution * rnd) {
    float scale = 1.0f; // xavier
    switch (tensor->n_dims) {
        case 1:
            scale /= sqrtf(tensor->ne[0]);
            for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
                float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0]);
                *dst = scale * frand_normal(rnd);
            }
            break;
        case 2:
            scale /= sqrtf(tensor->ne[0]+tensor->ne[1]);
            for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
                for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
                    float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1]);
                    *dst = scale * frand_normal(rnd);
                }
            }
            break;
        case 3:
            scale /= sqrtf(tensor->ne[0]+tensor->ne[1]);
            for (int i2 = 0; i2 < tensor->ne[2]; i2++) {
                for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
                    for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
                        float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2]);
                        *dst = scale * frand_normal(rnd);
                    }
                }
            }
            break;
        case 4:
            scale /= sqrtf(tensor->ne[0]+tensor->ne[1]);
            for (int i3 = 0; i3 < tensor->ne[3]; i3++) {
                for (int i2 = 0; i2 < tensor->ne[2]; i2++) {
                    for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
                        for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
                            float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3]);
                            *dst = scale * frand_normal(rnd);
                        }
                    }
                }
            }
            break;
        default:
            assert(false);
    };
    return tensor;
}

struct ggml_tensor * randomize_tensor_uniform(struct ggml_tensor * tensor, struct random_uniform_distribution * rnd) {
    switch (tensor->n_dims) {
        case 1:
            for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
                float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0]);
                *dst = frand_uniform(rnd);
            }
            break;
        case 2:
            for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
                for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
                    float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1]);
                    *dst = frand_uniform(rnd);
                }
            }
            break;
        case 3:
            for (int i2 = 0; i2 < tensor->ne[2]; i2++) {
                for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
                    for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
                        float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2]);
                        *dst = frand_uniform(rnd);
                    }
                }
            }
            break;
        case 4:
            for (int i3 = 0; i3 < tensor->ne[3]; i3++) {
                for (int i2 = 0; i2 < tensor->ne[2]; i2++) {
                    for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
                        for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
                            float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3]);
                            *dst = frand_uniform(rnd);
                        }
                    }
                }
            }
            break;
        default:
            assert(false);
    };
    return tensor;
}

struct my_llama_hparams {
    uint32_t n_vocab = 32000;
    uint32_t n_ctx = 512;
    uint32_t n_embd = 4096;
    uint32_t n_head = 32;
    uint32_t n_layer = 32;
    uint32_t n_rot = 64;
    uint32_t n_ff = 11008;

    // float f_norm_eps = 1e-5; // falcon
    float f_norm_rms_eps = 1e-5; // llama

    float rope_freq_base = 10000.0f;
    float rope_freq_scale = 1.0f;

    bool operator!=(const my_llama_hparams& other) const {
        return memcmp(this, &other, sizeof(my_llama_hparams));
    }
};

struct my_llama_layer {
    // normalization
    struct ggml_tensor * attention_norm;

    // attention
    struct ggml_tensor * wq;
    struct ggml_tensor * wk;
    struct ggml_tensor * wv;
    struct ggml_tensor * wo;

    // normalization
    struct ggml_tensor * ffn_norm;

    // ff
    struct ggml_tensor * w1;
    struct ggml_tensor * w2;
    struct ggml_tensor * w3;
};

struct my_llama_model {
    struct ggml_context * ctx = NULL;

    my_llama_hparams hparams;

    struct ggml_tensor * tok_embeddings;

    struct ggml_tensor * norm;
    struct ggml_tensor * output;

    std::vector<my_llama_layer> layers;

    uint32_t train_its = 0;
    uint32_t train_samples = 0;
    uint32_t train_tokens = 0;
};

// gguf constants
const char * LLM_KV_OPTIMIZER_TYPE = "optimizer.type";
const char * LLM_KV_OPTIMIZER_TYPE_ADAM = "adam";
const char * LLM_KV_OPTIMIZER_TYPE_LBFGS = "lbfgs";
const char * LLM_KV_OPTIMIZER_FILE_VERSION = "optimizer.file_version";
const char * LLM_KV_OPTIMIZER_CONVERGENCE_PAST_COUNT = "optimizer.convergence_past_count";
const char * LLM_KV_OPTIMIZER_PARAMETER_COUNT = "optimizer.parameter_count";
const char * LLM_KV_OPTIMIZER_ITERATION_COUNT = "optimizer.iteration_count";
const char * LLM_KV_OPTIMIZER_JUST_INITIALIZED = "optimizer.just_initialized";
const char * LLM_KV_OPTIMIZER_ADAM_BEST_LOSS = "optimizer.adam.best_loss";
const char * LLM_KV_OPTIMIZER_ADAM_PREVIOUS_LOSS = "optimizer.adam.previous_loss";
const char * LLM_KV_OPTIMIZER_ADAM_NO_IMPROVEMENT_COUNT = "optimizer.adam.no_improvement_count";
const char * LLM_KV_OPTIMIZER_LBFGS_APPROX_HESSIAN_COUNT = "optimizer.lbfgs.approx_hessian_count";
const char * LLM_KV_OPTIMIZER_LBFGS_BEST_LOSS = "optimizer.lbfgs.best_loss";
const char * LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_STEP = "optimizer.lbfgs.line_search_step";
const char * LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_J = "optimizer.lbfgs.line_search_j";
const char * LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_K = "optimizer.lbfgs.line_search_k";
const char * LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_END = "optimizer.lbfgs.line_search_end";
const char * LLM_KV_OPTIMIZER_LBFGS_NO_IMPROVEMENT_COUNT = "optimizer.lbfgs.no_improvement_count";

const char * LLM_TENSOR_OPTIMIZER_ADAM_FIRST_MOMENTS = "optimizer.adam.first_moments";
const char * LLM_TENSOR_OPTIMIZER_ADAM_SECOND_MOMENTS = "optimizer.adam.second_moments";
const char * LLM_TENSOR_OPTIMIZER_ADAM_PAST_LOSS_VALUES = "optimizer.adam.past_loss_values";

const char * LLM_TENSOR_OPTIMIZER_LBFGS_CURRENT_PARAMETERS = "optimizer.lbfgs.current_parameters";
const char * LLM_TENSOR_OPTIMIZER_LBFGS_PREVIOUS_PARAMETERS = "optimizer.lbfgs.previous_parameters";
const char * LLM_TENSOR_OPTIMIZER_LBFGS_CURRENT_GRADIENTS = "optimizer.lbfgs.current_gradients";
const char * LLM_TENSOR_OPTIMIZER_LBFGS_PREVIOUS_GRADIENTS = "optimizer.lbfgs.previous_gradients";
const char * LLM_TENSOR_OPTIMIZER_LBFGS_SEARCH_DIRECTION = "optimizer.lbfgs.search_direction";
const char * LLM_TENSOR_OPTIMIZER_LBFGS_PAST_LOSS_VALUES = "optimizer.lbfgs.past_loss_values";
const char * LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_ALPHA = "optimizer.lbfgs.memory_alpha";
const char * LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_YS = "optimizer.lbfgs.memory_ys";
const char * LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_S = "optimizer.lbfgs.memory_s";
const char * LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_Y = "optimizer.lbfgs.memory_y";

const char * LLM_KV_TRAINING_FILE_VERSION = "training.file_version";
const char * LLM_KV_TRAINING_ITERATION_COUNT = "training.iteration_count";
const char * LLM_KV_TRAINING_SAMPLE_COUNT = "training.sample_count";
const char * LLM_KV_TRAINING_TOKEN_COUNT = "training.token_count";

// gguf constants (sync with gguf.py)

const char * LLM_KV_GENERAL_ARCHITECTURE = "general.architecture";
const char * LLM_KV_GENERAL_FILE_TYPE = "general.file_type";

const char * LLM_KV_CONTEXT_LENGTH = "%s.context_length";
const char * LLM_KV_EMBEDDING_LENGTH = "%s.embedding_length";
const char * LLM_KV_BLOCK_COUNT = "%s.block_count";
const char * LLM_KV_FEED_FORWARD_LENGTH = "%s.feed_forward_length";
const char * LLM_KV_ATTENTION_HEAD_COUNT = "%s.attention.head_count";
const char * LLM_KV_ATTENTION_LAYERNORM_RMS_EPS = "%s.attention.layer_norm_rms_epsilon";
const char * LLM_KV_ROPE_DIMENSION_COUNT = "%s.rope.dimension_count";
const char * LLM_KV_ROPE_FREQ_BASE = "%s.rope.freq_base"; // TODO load in llama.cpp
const char * LLM_KV_ROPE_SCALE_LINEAR = "%s.rope.scale_linear";

const char * LLM_KV_TOKENIZER_MODEL = "tokenizer.ggml.model";
const char * LLM_KV_TOKENIZER_LIST = "tokenizer.ggml.tokens";
const char * LLM_KV_TOKENIZER_TOKEN_TYPE = "tokenizer.ggml.token_type";
const char * LLM_KV_TOKENIZER_SCORES = "tokenizer.ggml.scores";
const char * LLM_KV_TOKENIZER_MERGES = "tokenizer.ggml.merges";
const char * LLM_KV_TOKENIZER_BOS_ID = "tokenizer.ggml.bos_token_id";
const char * LLM_KV_TOKENIZER_EOS_ID = "tokenizer.ggml.eos_token_id";
const char * LLM_KV_TOKENIZER_UNK_ID = "tokenizer.ggml.unknown_token_id";
const char * LLM_KV_TOKENIZER_SEP_ID = "tokenizer.ggml.seperator_token_id";
const char * LLM_KV_TOKENIZER_PAD_ID = "tokenizer.ggml.padding_token_id";

const char * LLM_TENSOR_TOKEN_EMBD = "token_embd";
const char * LLM_TENSOR_OUTPUT_NORM = "output_norm";
const char * LLM_TENSOR_OUTPUT = "output";
const char * LLM_TENSOR_ATTN_NORM = "blk.%d.attn_norm";
const char * LLM_TENSOR_ATTN_Q = "blk.%d.attn_q";
const char * LLM_TENSOR_ATTN_K = "blk.%d.attn_k";
const char * LLM_TENSOR_ATTN_V = "blk.%d.attn_v";
const char * LLM_TENSOR_ATTN_OUT = "blk.%d.attn_output";
const char * LLM_TENSOR_FFN_NORM = "blk.%d.ffn_norm";
const char * LLM_TENSOR_FFN_GATE = "blk.%d.ffn_gate";
const char * LLM_TENSOR_FFN_DOWN = "blk.%d.ffn_down";
const char * LLM_TENSOR_FFN_UP = "blk.%d.ffn_up";

void print_params(struct my_llama_hparams * params) {
    printf("%s: n_vocab: %d\n", __func__, params->n_vocab);
    printf("%s: n_ctx: %d\n", __func__, params->n_ctx);
    printf("%s: n_embd: %d\n", __func__, params->n_embd);
    printf("%s: n_head: %d\n", __func__, params->n_head);
    printf("%s: n_ff: %d\n", __func__, params->n_ff);
    printf("%s: n_layer: %d\n", __func__, params->n_layer);
    printf("%s: n_rot: %d\n", __func__, params->n_rot);
}

void init_model(struct my_llama_model * model) {
    const auto & hparams = model->hparams;

    const uint32_t n_embd = hparams.n_embd;
    const uint32_t n_layer = hparams.n_layer;
    const uint32_t n_vocab = hparams.n_vocab;
    const uint32_t n_ff = hparams.n_ff;

    struct ggml_context * ctx = model->ctx;

    model->train_its = 0;
    model->train_samples = 0;
    model->train_tokens = 0;

    std::vector<char> tn_buf;
    tn_buf.resize(GGML_MAX_NAME);
    auto tn = [&tn_buf](const char * key) -> const char * {
        snprintf(tn_buf.data(), tn_buf.size(), "%s.weight", key);
        return tn_buf.data();
    };
    auto tni = [&tn_buf](const char * key, int bid) -> const char * {
        snprintf(tn_buf.data(), tn_buf.size(), key, bid);
        std::string s = tn_buf.data();
        snprintf(tn_buf.data(), tn_buf.size(), "%s.weight", s.c_str());
        return tn_buf.data();
    };

    model->tok_embeddings = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_vocab);
    model->norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
    model->output = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_vocab);

    ggml_set_name(model->tok_embeddings, tn(LLM_TENSOR_TOKEN_EMBD));
    ggml_set_name(model->norm, tn(LLM_TENSOR_OUTPUT_NORM));
    ggml_set_name(model->output, tn(LLM_TENSOR_OUTPUT));

    model->layers.resize(n_layer);
    for (uint32_t i = 0; i < n_layer; ++i) {
        auto & layer = model->layers[i];

        layer.attention_norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);

        layer.wq = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_embd);
        layer.wk = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_embd);
        layer.wv = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_embd);
        layer.wo = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_embd);

        layer.ffn_norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);

        layer.w1 = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_ff);
        layer.w2 = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_ff, n_embd);
        layer.w3 = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_ff);

        ggml_set_name(layer.attention_norm, tni(LLM_TENSOR_ATTN_NORM, i));

        ggml_set_name(layer.wq, tni(LLM_TENSOR_ATTN_Q, i));
        ggml_set_name(layer.wk, tni(LLM_TENSOR_ATTN_K, i));
        ggml_set_name(layer.wv, tni(LLM_TENSOR_ATTN_V, i));
        ggml_set_name(layer.wo, tni(LLM_TENSOR_ATTN_OUT, i));

        ggml_set_name(layer.ffn_norm, tni(LLM_TENSOR_FFN_NORM, i));

        ggml_set_name(layer.w1, tni(LLM_TENSOR_FFN_GATE, i));
        ggml_set_name(layer.w2, tni(LLM_TENSOR_FFN_DOWN, i));
        ggml_set_name(layer.w3, tni(LLM_TENSOR_FFN_UP, i));
    }
}

void set_param_model(struct my_llama_model * model) {
    const auto& hparams = model->hparams;

    const uint32_t n_layer = hparams.n_layer;

    struct ggml_context* ctx = model->ctx;

    ggml_set_param(ctx, model->tok_embeddings);
    ggml_set_param(ctx, model->norm);
    ggml_set_param(ctx, model->output);

    for (uint32_t i = 0; i < n_layer; ++i) {
        auto & layer = model->layers[i];

        ggml_set_param(ctx, layer.attention_norm);
        ggml_set_param(ctx, layer.wq);
        ggml_set_param(ctx, layer.wk);
        ggml_set_param(ctx, layer.wv);
        ggml_set_param(ctx, layer.wo);
        ggml_set_param(ctx, layer.ffn_norm);
        ggml_set_param(ctx, layer.w1);
        ggml_set_param(ctx, layer.w2);
        ggml_set_param(ctx, layer.w3);
    }
}

void randomize_model(struct my_llama_model * model, int seed, float mean, float std, float min, float max) {
    const auto & hparams = model->hparams;

    const uint32_t n_layer = hparams.n_layer;

    struct random_normal_distribution rnd;
    init_random_normal_distribution(&rnd, seed, mean, std, min, max);

    randomize_tensor_normal(model->tok_embeddings, &rnd);
    randomize_tensor_normal(model->norm, &rnd);
    randomize_tensor_normal(model->output, &rnd);

    for (uint32_t i = 0; i < n_layer; ++i) {
        auto & layer = model->layers[i];
        randomize_tensor_normal(layer.attention_norm, &rnd);

        randomize_tensor_normal(layer.wq, &rnd);
        randomize_tensor_normal(layer.wk, &rnd);
        randomize_tensor_normal(layer.wv, &rnd);
        randomize_tensor_normal(layer.wo, &rnd);

        randomize_tensor_normal(layer.ffn_norm, &rnd);

        randomize_tensor_normal(layer.w1, &rnd);
        randomize_tensor_normal(layer.w2, &rnd);
        randomize_tensor_normal(layer.w3, &rnd);
    }
}

void assert_shape_1d(struct ggml_tensor * tensor, int64_t ne0) {
    GGML_ASSERT(tensor->n_dims == 1);
    GGML_ASSERT(tensor->ne[0] == ne0);
}

void assert_shape_2d(struct ggml_tensor * tensor, int64_t ne0, int64_t ne1) {
    GGML_ASSERT(tensor->n_dims == 2);
    GGML_ASSERT(tensor->ne[0] == ne0);
    GGML_ASSERT(tensor->ne[1] == ne1);
}

void assert_shape_3d(struct ggml_tensor * tensor, int64_t ne0, int64_t ne1, int64_t ne2) {
    GGML_ASSERT(tensor->n_dims == 3);
    GGML_ASSERT(tensor->ne[0] == ne0);
    GGML_ASSERT(tensor->ne[1] == ne1);
    GGML_ASSERT(tensor->ne[2] == ne2);
}

void assert_shape_4d(struct ggml_tensor * tensor, int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3) {
    GGML_ASSERT(tensor->n_dims == 4);
    GGML_ASSERT(tensor->ne[0] == ne0);
    GGML_ASSERT(tensor->ne[1] == ne1);
    GGML_ASSERT(tensor->ne[2] == ne2);
    GGML_ASSERT(tensor->ne[3] == ne3);
}

static size_t hash(void * p) {
    return (size_t)p % GGML_GRAPH_HASHTABLE_SIZE;
}

static size_t hash_find(void * hash_table[], void * p) {
    size_t h = hash(p);

    // linear probing
    size_t i = h;
    while (hash_table[i] != NULL && hash_table[i] != p) {
        i = (i + 1) % GGML_GRAPH_HASHTABLE_SIZE;
        if (i == h) {
            // visited all hash table entries -> not found
            return GGML_GRAPH_HASHTABLE_SIZE;
        }
    }
    return i;
}

static bool hash_insert(void * hash_table[], void * p) {
    //size_t h = hash(p);
    size_t i = hash_find(hash_table, p);

    GGML_ASSERT(i < GGML_GRAPH_HASHTABLE_SIZE); // assert that not full

    if (hash_table[i] == p) {
        return true;
    }

    // insert
    GGML_ASSERT(hash_table[i] == NULL);
    hash_table[i] = p;
    return false;
}

static bool hash_contains(void * hash_table[], void * p) {
    size_t i = hash_find(hash_table, p);
    return (i < GGML_GRAPH_HASHTABLE_SIZE) && (hash_table[i] == p);
}

struct hash_map {
    void * keys[GGML_GRAPH_HASHTABLE_SIZE];
    void * vals[GGML_GRAPH_HASHTABLE_SIZE];
};
//static const size_t HASH_MAP_SIZE = sizeof(struct hash_map);

struct hash_map * new_hash_map() {
    struct hash_map * result = new struct hash_map;
    for (int i=0; i<GGML_GRAPH_HASHTABLE_SIZE; ++i) {
        result->keys[i] = NULL;
        result->vals[i] = NULL;
    }
    return result;
};

void free_hash_map(struct hash_map * map) {
    delete map;
}

static bool ggml_is_view(struct ggml_tensor * t) {
    return t->op == GGML_OP_RESHAPE || t->op == GGML_OP_VIEW || t->op == GGML_OP_TRANSPOSE ||
           t->op == GGML_OP_PERMUTE || t->op == GGML_OP_CPY;
}

static struct ggml_tensor * get_view_parent(struct ggml_tensor * t) {
    switch (t->op) {
        case GGML_OP_PERMUTE:
        case GGML_OP_RESHAPE:
        case GGML_OP_TRANSPOSE:
        case GGML_OP_VIEW:
            return t->src[0];
        case GGML_OP_CPY:
            return t->src[1];
        default:
            return NULL;
    }
}

static struct ggml_tensor * get_view_source(struct ggml_tensor * t) {
    struct ggml_tensor * parent = t;
    do {
        parent = get_view_parent(parent);
    } while (ggml_is_view(parent));
    return parent;
}

struct ggml_tensor * ggml_recompute_graph_node(
|
|
struct ggml_context * ctx,
|
|
struct ggml_cgraph * graph,
|
|
struct hash_map * replacements,
|
|
struct ggml_tensor * node) {
|
|
|
|
if (node == NULL) {
|
|
return NULL;
|
|
}
|
|
|
|
if (node->is_param) {
|
|
return node;
|
|
}
|
|
|
|
if (!hash_contains(graph->visited_hash_table, node)) {
|
|
return node;
|
|
}
|
|
|
|
int count_children = 0;
|
|
for (int k = 0; k < GGML_MAX_SRC; ++k) {
|
|
if (node->src[k]) {
|
|
++count_children;
|
|
}
|
|
}
|
|
|
|
if (count_children == 0) {
|
|
return node;
|
|
}
|
|
|
|
size_t i = hash_find(replacements->keys, node);
|
|
GGML_ASSERT(i < GGML_GRAPH_HASHTABLE_SIZE); // assert that not full
|
|
if (replacements->keys[i] == node) {
|
|
return (struct ggml_tensor *) replacements->vals[i];
|
|
}
|
|
|
|
struct ggml_tensor * clone = ggml_new_tensor(ctx, node->type, node->n_dims, node->ne);
|
|
|
|
// insert clone into replacements
|
|
GGML_ASSERT(replacements->keys[i] == NULL); // assert that we don't overwrite
|
|
replacements->keys[i] = node;
|
|
replacements->vals[i] = clone;
|
|
|
|
clone->op = node->op;
|
|
clone->grad = node->grad;
|
|
clone->is_param = node->is_param;
|
|
clone->extra = node->extra;
|
|
for (int k = 0; k < GGML_MAX_DIMS; ++k) {
|
|
clone->nb[k] = node->nb[k];
|
|
}
|
|
for (int k = 0; k < GGML_MAX_SRC; ++k) {
|
|
clone->src[k] = ggml_recompute_graph_node(ctx, graph, replacements, node->src[k]);
|
|
}
|
|
if (ggml_is_view(clone)) {
|
|
struct ggml_tensor * source = get_view_source(clone);
|
|
GGML_ASSERT(source != NULL);
|
|
clone->data = source->data;
|
|
}
|
|
|
|
GGML_ASSERT(sizeof(node->op_params) == sizeof(int32_t) * (GGML_MAX_OP_PARAMS / sizeof(int32_t)));
|
|
GGML_ASSERT(sizeof(node->name) == GGML_MAX_NAME);
|
|
memcpy(clone->op_params, node->op_params, sizeof(node->op_params));
|
|
ggml_format_name(clone, "%s (clone)", ggml_get_name(node));
|
|
|
|
return clone;
|
|
}
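// Summary: ggml_recompute_graph_node returns `node` itself (no recomputation) in these cases:
//   - node is NULL or a trainable parameter (is_param),
//   - node was not visited while building `graph`, for example a tensor created only for the
//     backward pass,
//   - node has no sources, so there is nothing to recompute it from.
// In every other case it returns a clone memoized in `replacements`. Checkpoints are pre-seeded
// there as identity mappings (checkpoint -> checkpoint), and that is what stops the recursion at
// checkpoints. Example: for a forward chain cp -> f1 -> f2, where cp is a checkpoint, recomputing
// f2 yields clone(f2) with src = clone(f1), and clone(f1) in turn has src = cp.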
|
|
|
|
void ggml_build_backward_gradient_checkpointing(
|
|
struct ggml_context * ctx,
|
|
struct ggml_cgraph * gf,
|
|
struct ggml_cgraph * gb,
|
|
struct ggml_cgraph * gb_tmp,
|
|
struct ggml_tensor * * checkpoints,
|
|
int n_checkpoints) {
|
|
*gb_tmp = *gf;
|
|
ggml_build_backward_expand(ctx, gf, gb_tmp, true);
|
|
|
|
if (n_checkpoints <= 0) {
|
|
*gb = *gb_tmp;
|
|
return;
|
|
}
|
|
|
|
struct hash_map * replacements = new_hash_map();
|
|
|
|
// insert checkpoints in replacements
|
|
for (int i = 0; i < n_checkpoints; ++i) {
|
|
size_t k = hash_find(replacements->keys, checkpoints[i]);
|
|
GGML_ASSERT(k < GGML_GRAPH_HASHTABLE_SIZE); // assert that not full
|
|
GGML_ASSERT(replacements->keys[k] == NULL); // assert that we don't overwrite
|
|
replacements->keys[k] = checkpoints[i];
|
|
replacements->vals[k] = checkpoints[i];
|
|
}
|
|
|
|
*gb = *gf;
|
|
// rewrite gb_tmp->nodes[gf->n_nodes:gb_tmp->n_nodes],
|
|
// replacing references to gb_tmp->nodes[0:gf->n_nodes] ( == gf->nodes[0:gf->n_nodes]),
|
|
// by recomputing them from checkpoints
|
|
for (int i = gf->n_nodes; i<gb_tmp->n_nodes; ++i) {
|
|
struct ggml_tensor * node = gb_tmp->nodes[i];
|
|
for (int k = 0; k < GGML_MAX_SRC; ++k) {
|
|
// recompute src[k] from the nearest checkpoints, reusing any replacements already made:
// every new tensor is recorded in `replacements`, keyed by the gf node it stands in for;
// the recursion descends into input tensors and terminates at checkpoints
|
|
node->src[k] = ggml_recompute_graph_node(ctx, gf, replacements, node->src[k]);
|
|
}
|
|
// insert rewritten backward node with replacements made into resulting backward graph gb
|
|
ggml_build_forward_expand(gb, node);
|
|
}
|
|
|
|
free_hash_map(replacements);
|
|
}
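// Usage sketch. This mirrors the call in llama_build_train_graphs below; `input`, `layer_out`
// and `loss` are placeholders. First collect the tensors to keep in memory, then build the
// forward graph, then let this function derive the backward graph:
//
//     std::vector<struct ggml_tensor *> checkpoints;
//     checkpoints.push_back(input);
//     for (/* each layer */) { /* build layer */ checkpoints.push_back(layer_out); }
//     ggml_build_forward_expand(gf, loss);
//     ggml_build_backward_gradient_checkpointing(ctx, gf, gb, gb_tmp,
//         checkpoints.data(), (int) checkpoints.size());
//
// gb_tmp serves as scratch space for the plain backward graph. gb receives the forward nodes,
// followed by the recomputation clones and the rewritten backward nodes. With
// n_checkpoints <= 0, gb is the plain gf + ggml_build_backward_expand result and no
// recomputation is done.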
|
|
|
|
struct ggml_tensor * llama_build_train_graphs(
|
|
struct my_llama_model * model,
|
|
struct ggml_allocr * alloc,
|
|
struct ggml_context * ctx,
|
|
struct ggml_cgraph * gf,
|
|
struct ggml_cgraph * gb,
|
|
struct ggml_cgraph * gb_tmp,
|
|
struct ggml_tensor * * logits,
|
|
struct ggml_tensor * tokens_input,
|
|
struct ggml_tensor * targets,
|
|
const int n_tokens,
|
|
const int n_batch,
|
|
const bool enable_flash_attn,
|
|
const bool enable_checkpointing) {
|
|
|
|
ggml_set_scratch(ctx, { 0, 0, nullptr, });
|
|
const int n_past = 0;
|
|
const int N = n_tokens;
|
|
const auto & hparams = model->hparams;
|
|
const int n_ctx = hparams.n_ctx;
|
|
const int n_vocab = hparams.n_vocab;
|
|
const int n_embd = hparams.n_embd;
|
|
const int n_layer = hparams.n_layer;
|
|
const int n_head = hparams.n_head;
|
|
const int n_rot = hparams.n_rot;
|
|
const int n_ff = hparams.n_ff;
|
|
const float f_norm_rms_eps = hparams.f_norm_rms_eps;
|
|
const float rope_freq_base = hparams.rope_freq_base;
|
|
const float rope_freq_scale = hparams.rope_freq_scale;
|
|
|
|
auto set_name = [](struct ggml_tensor * t, const char * n) {
|
|
ggml_set_name(t, n);
|
|
if (t->grad) {
|
|
ggml_format_name(t->grad, "%s->grad", n);
|
|
}
|
|
};
|
|
|
|
// rope has so many parameters that we make a custom function for it
|
|
auto rope = [ctx, n_rot, n_ctx, rope_freq_base, rope_freq_scale]
|
|
(struct ggml_tensor * t) -> struct ggml_tensor * {
|
|
// not capturing these, to silence warnings
|
|
const int n_past = 0;
|
|
const int rope_mode = 0;
|
|
|
|
return ggml_rope_custom(ctx,
|
|
t, n_past, n_rot, rope_mode, n_ctx,
|
|
rope_freq_base, rope_freq_scale);
|
|
};
|
|
|
|
set_name(tokens_input, "tokens_input");
|
|
set_name(targets, "targets");
|
|
|
|
GGML_ASSERT(tokens_input->type == GGML_TYPE_I32);
|
|
struct ggml_tensor * t00 = ggml_reshape_1d(ctx, tokens_input, N*n_batch); set_name(t00, "t00"); assert_shape_1d(t00, N*n_batch);
|
|
struct ggml_tensor * t01 = ggml_get_rows(ctx, model->tok_embeddings, t00); set_name(t01, "t01"); assert_shape_2d(t01, n_embd, N*n_batch);
|
|
|
|
struct ggml_tensor * cur = t01;
|
|
|
|
std::vector<struct ggml_tensor *> checkpoints;
|
|
checkpoints.push_back(tokens_input);
|
|
checkpoints.push_back(targets);
|
|
checkpoints.push_back(t00);
|
|
checkpoints.push_back(t01);
|
|
|
|
struct ggml_tensor * kv_scale = NULL; // only used when flash attention is disabled
|
|
if (!enable_flash_attn) {
|
|
kv_scale = ggml_new_f32(ctx, 1.0f/sqrtf(float(n_embd)/n_head));
|
|
}
|
|
|
|
for (int il = 0; il < n_layer; ++il) {
|
|
struct my_llama_layer & layer = model->layers[il];
|
|
struct ggml_tensor * t02 = ggml_rms_norm (ctx, cur, f_norm_rms_eps); set_name(t02, "t02"); assert_shape_2d(t02, n_embd, N*n_batch);
|
|
struct ggml_tensor * t03 = ggml_repeat (ctx, layer.attention_norm, t02); set_name(t03, "t03"); assert_shape_2d(t03, n_embd, N*n_batch);
|
|
struct ggml_tensor * t04 = ggml_mul (ctx, t03, t02); set_name(t04, "t04"); assert_shape_2d(t04, n_embd, N*n_batch);
|
|
struct ggml_tensor * t05 = ggml_mul_mat (ctx, layer.wq, t04); set_name(t05, "t05"); assert_shape_2d(t05, n_embd, N*n_batch);
|
|
struct ggml_tensor * t06 = ggml_reshape_4d (ctx, t05, n_embd/n_head, n_head, N, n_batch); set_name(t06, "t06"); assert_shape_4d(t06, n_embd/n_head, n_head, N, n_batch);
|
|
struct ggml_tensor * t07 = rope (t06); set_name(t07, "t07"); assert_shape_4d(t07, n_embd/n_head, n_head, N, n_batch);
|
|
struct ggml_tensor * t08 = ggml_mul_mat (ctx, layer.wk, t04); set_name(t08, "t08"); assert_shape_2d(t08, n_embd, N*n_batch);
|
|
struct ggml_tensor * t09 = ggml_reshape_4d (ctx, t08, n_embd/n_head, n_head, N, n_batch); set_name(t09, "t09"); assert_shape_4d(t09, n_embd/n_head, n_head, N, n_batch);
|
|
struct ggml_tensor * t10 = rope (t09); set_name(t10, "t10"); assert_shape_4d(t10, n_embd/n_head, n_head, N, n_batch);
|
|
struct ggml_tensor * t11 = ggml_mul_mat (ctx, t04, layer.wv); set_name(t11, "t11"); assert_shape_2d(t11, N*n_batch, n_embd);
|
|
struct ggml_tensor * t12 = ggml_reshape_4d (ctx, t11, N, n_batch, n_embd/n_head, n_head); set_name(t12, "t12"); assert_shape_4d(t12, N, n_batch, n_embd/n_head, n_head);
|
|
struct ggml_tensor * t13 = ggml_permute (ctx, t07, 0, 2, 1, 3); set_name(t13, "t13"); assert_shape_4d(t13, n_embd/n_head, N, n_head, n_batch);
|
|
struct ggml_tensor * t14 = ggml_permute (ctx, t10, 0, 2, 1, 3); set_name(t14, "t14"); assert_shape_4d(t14, n_embd/n_head, N, n_head, n_batch);
|
|
struct ggml_tensor * t15 = ggml_permute (ctx, t12, 0, 3, 1, 2); set_name(t15, "t15"); assert_shape_4d(t15, N, n_embd/n_head, n_head, n_batch);
|
|
struct ggml_tensor * t16;
|
|
if (enable_flash_attn) {
|
|
t16 = ggml_flash_attn(ctx, t13, t14, t15, true); set_name(t16, "t16"); assert_shape_4d(t16, n_embd/n_head, N, n_head, n_batch);
|
|
} else {
|
|
struct ggml_tensor * t16_0 = ggml_mul_mat (ctx, t14, t13); set_name(t16_0, "t16_0"); assert_shape_4d(t16_0, N, N, n_head, n_batch);
|
|
struct ggml_tensor * t16_1 = ggml_scale_inplace (ctx, t16_0, kv_scale); set_name(t16_1, "t16_1"); assert_shape_4d(t16_1, N, N, n_head, n_batch);
|
|
struct ggml_tensor * t16_2 = ggml_diag_mask_inf_inplace(ctx, t16_1, n_past); set_name(t16_2, "t16_2"); assert_shape_4d(t16_2, N, N, n_head, n_batch);
|
|
struct ggml_tensor * t16_3 = ggml_soft_max_inplace (ctx, t16_2); set_name(t16_3, "t16_3"); assert_shape_4d(t16_3, N, N, n_head, n_batch);
|
|
t16 = ggml_mul_mat(ctx, t15, t16_3); set_name(t16, "t16"); assert_shape_4d(t16, n_embd/n_head, N, n_head, n_batch);
|
|
}
|
|
struct ggml_tensor * t17 = ggml_permute (ctx, t16, 0, 2, 1, 3); set_name(t17, "t17"); assert_shape_4d(t17, n_embd/n_head, n_head, N, n_batch);
|
|
struct ggml_tensor * t18 = ggml_cont (ctx, t17); set_name(t18, "t18"); assert_shape_4d(t18, n_embd/n_head, n_head, N, n_batch);
|
|
struct ggml_tensor * t19 = ggml_reshape_2d (ctx, t18, n_embd, N*n_batch); set_name(t19, "t19"); assert_shape_2d(t19, n_embd, N*n_batch);
|
|
struct ggml_tensor * t20 = ggml_mul_mat (ctx, layer.wo, t19); set_name(t20, "t20"); assert_shape_2d(t20, n_embd, N*n_batch);
|
|
struct ggml_tensor * t21 = ggml_add (ctx, t20, cur); set_name(t21, "t21"); assert_shape_2d(t21, n_embd, N*n_batch);
|
|
struct ggml_tensor * t22 = ggml_rms_norm (ctx, t21, f_norm_rms_eps); set_name(t22, "t22"); assert_shape_2d(t22, n_embd, N*n_batch);
|
|
struct ggml_tensor * t23 = ggml_repeat (ctx, layer.ffn_norm, t22); set_name(t23, "t23"); assert_shape_2d(t23, n_embd, N*n_batch);
|
|
struct ggml_tensor * t24 = ggml_mul (ctx, t23, t22); set_name(t24, "t24"); assert_shape_2d(t24, n_embd, N*n_batch);
|
|
struct ggml_tensor * t25 = ggml_mul_mat (ctx, layer.w3, t24); set_name(t25, "t25"); assert_shape_2d(t25, n_ff, N*n_batch);
|
|
struct ggml_tensor * t26 = ggml_mul_mat (ctx, layer.w1, t24); set_name(t26, "t26"); assert_shape_2d(t26, n_ff, N*n_batch);
|
|
struct ggml_tensor * t27 = ggml_silu (ctx, t26); set_name(t27, "t27"); assert_shape_2d(t27, n_ff, N*n_batch);
|
|
struct ggml_tensor * t28 = ggml_mul (ctx, t27, t25); set_name(t28, "t28"); assert_shape_2d(t28, n_ff, N*n_batch);
|
|
struct ggml_tensor * t29 = ggml_mul_mat (ctx, layer.w2, t28); set_name(t29, "t29"); assert_shape_2d(t29, n_embd, N*n_batch);
|
|
struct ggml_tensor * t30 = ggml_add (ctx, t29, t21); set_name(t30, "t30"); assert_shape_2d(t30, n_embd, N*n_batch);
|
|
cur = t30;
|
|
checkpoints.push_back(cur);
|
|
}
|
|
struct ggml_tensor * t31 = ggml_rms_norm (ctx, cur, f_norm_rms_eps); set_name(t31, "t31"); assert_shape_2d(t31, n_embd, N*n_batch);
|
|
struct ggml_tensor * t32 = ggml_repeat (ctx, model->norm, t31); set_name(t32, "t32"); assert_shape_2d(t32, n_embd, N*n_batch);
|
|
struct ggml_tensor * t33 = ggml_mul (ctx, t32, t31); set_name(t33, "t33"); assert_shape_2d(t33, n_embd, N*n_batch);
|
|
struct ggml_tensor * t34 = ggml_mul_mat (ctx, model->output, t33); set_name(t34, "t34"); assert_shape_2d(t34, n_vocab, N*n_batch);
|
|
struct ggml_tensor * t35 = ggml_reshape_3d (ctx, t34, n_vocab, N, n_batch); set_name(t35, "t35"); assert_shape_3d(t35, n_vocab, N, n_batch);
|
|
struct ggml_tensor * t36 = ggml_cross_entropy_loss(ctx, t35, targets); set_name(t36, "t36"); assert_shape_1d(t36, 1);
|
|
|
|
checkpoints.push_back(t31);
|
|
checkpoints.push_back(t32);
|
|
checkpoints.push_back(t33);
|
|
checkpoints.push_back(t34);
|
|
checkpoints.push_back(t35);
|
|
checkpoints.push_back(t36);
|
|
|
|
ggml_build_forward_expand(gf, t36);
|
|
|
|
if (enable_checkpointing) {
|
|
ggml_build_backward_gradient_checkpointing(ctx, gf, gb, gb_tmp, checkpoints.data(), (int) checkpoints.size());
|
|
} else {
|
|
*gb = *gf;
|
|
ggml_build_backward_expand(ctx, gf, gb, true);
|
|
}
|
|
|
|
if (alloc) {
|
|
// make sure some tensors are not reallocated by inserting new temporary nodes depending on them
|
|
int n_leafs_before = gb->n_leafs;
|
|
int n_nodes_before = gb->n_nodes;
|
|
struct ggml_tensor * one = ggml_new_f32(ctx, 1.0f);
|
|
// output tensors
|
|
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t35, one));
|
|
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t36, one));
|
|
// input gradient
|
|
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t36->grad, one));
|
|
GGML_ASSERT(t36->grad->data == NULL && !ggml_is_view(t36->grad));
|
|
ggml_allocr_alloc(alloc, t36->grad);
|
|
// gradient tensors (will be set to zero by ggml_graph_reset)
|
|
// pinning these produces large unnecessary memory overhead, which will be resolved by PR 2632
|
|
for (int i = 0; i < gf->n_nodes; ++i) {
|
|
if (!gf->grads[i]) continue;
|
|
if (gf->grads[i]->data == NULL && !ggml_is_view(gf->grads[i])) {
|
|
ggml_allocr_alloc(alloc, gf->grads[i]);
|
|
}
|
|
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, gf->grads[i], one));
|
|
}
|
|
// allocating checkpoints in one block to reduce memory fragmentation
|
|
// note: they will be freed in reverse order
|
|
for (int i = 0; i < (int) checkpoints.size(); ++i) {
|
|
if (checkpoints[i]->data == NULL && !ggml_is_view(checkpoints[i])) {
|
|
ggml_allocr_alloc(alloc, checkpoints[i]);
|
|
}
|
|
}
|
|
|
|
//int n_leafs_after = gb->n_leafs;
|
|
//int n_nodes_after = gb->n_nodes;
|
|
|
|
ggml_allocr_alloc_graph(alloc, gb);
|
|
|
|
// remove the additional nodes and leafs
|
|
for (int i = n_leafs_before; i < gb->n_leafs; ++i) {
|
|
gb->leafs[i] = NULL;
|
|
}
|
|
for (int i = n_nodes_before; i < gb->n_nodes; ++i) {
|
|
gb->nodes[i] = NULL;
|
|
}
|
|
gb->n_leafs = n_leafs_before;
|
|
gb->n_nodes = n_nodes_before;
|
|
}
|
|
|
|
*logits = t35;
|
|
return t36;
|
|
}
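// Usage sketch (variable names are hypothetical). The caller creates the input tensors with the
// shapes asserted above: tokens_input is I32 with N*n_batch elements (typically [N, n_batch]),
// and targets presumably is F32 with the same shape as the logits, [n_vocab, N, n_batch]:
//
//     struct ggml_tensor * logits = NULL;
//     struct ggml_tensor * loss = llama_build_train_graphs(
//         &model, alloc, ctx, gf, gb, gb_tmp, &logits,
//         tokens_input, targets, n_tokens, n_batch,
//         /*enable_flash_attn=*/ true, /*enable_checkpointing=*/ true);
//
// `loss` is the scalar cross-entropy loss (t36), and `logits` is the [n_vocab, N, n_batch]
// output (t35). Passing alloc == NULL skips the ggml_allocr-based allocation step.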
|
|
|
|
void set_f32_3d(struct ggml_tensor * tensor, int64_t i0, int64_t i1, int64_t i2, float value) {
|
|
float * ptr = (float *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2]);
|
|
*ptr = value;
|
|
}
|
|
|
|
void set_f32_2d(struct ggml_tensor * tensor, int64_t i0, int64_t i1, float value) {
|
|
float * ptr = (float *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1]);
|
|
*ptr = value;
|
|
}
|
|
|
|
void set_i32_2d(struct ggml_tensor * tensor, int64_t i0, int64_t i1, int32_t value) {
|
|
int32_t * ptr = (int32_t *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1]);
|
|
*ptr = value;
|
|
}
|
|
|
|
float get_f32_2d(struct ggml_tensor * tensor, int64_t i0, int64_t i1) {
|
|
float * ptr = (float *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1]);
|
|
return *ptr;
|
|
}
|
|
|
|
int32_t get_i32_2d(struct ggml_tensor * tensor, int64_t i0, int64_t i1) {
|
|
int32_t * ptr = (int32_t *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1]);
|
|
return *ptr;
|
|
}
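// Worked example of the byte-offset arithmetic these helpers use (assuming a contiguous F32
// tensor with ne = {4, 3}): nb[0] = sizeof(float) = 4 and nb[1] = ne[0]*nb[0] = 16. Element
// (i0 = 2, i1 = 1) therefore lives at
//     2*nb[0] + 1*nb[1] = 8 + 16 = 24 bytes past tensor->data,
// which is float index 6 = i1*ne[0] + i0. Because the offsets come from nb[], the same helpers
// also address non-contiguous views such as permuted tensors correctly.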
|
|
|
|
void print_row(struct ggml_tensor * probs, int i) {
|
|
for (int k = 0; k < probs->ne[0]; ++k) {
|
|
float p = get_f32_2d(probs, k, i);
|
|
printf(" %.2f", p);
|
|
}
|
|
printf("\n");
|
|
}
|
|
|
|
void print_matrix(struct ggml_tensor * probs) {
|
|
assert(probs->n_dims == 2);
|
|
for (int i = 0; i < probs->ne[1]; ++i) {
|
|
for (int k = 0; k < probs->ne[0]; ++k) {
|
|
float p = get_f32_2d(probs, k, i);
|
|
printf(" %.2f", p);
|
|
}
|
|
printf("\n");
|
|
}
|
|
}
|
|
|
|
void get_example_targets(struct llama_context * lctx, const int * train_samples, size_t n_train_samples, const llama_token * train_data, size_t n_train_data, int example_id, struct ggml_tensor * tokens_input, struct ggml_tensor * target_logits, struct ggml_tensor * target_probs) {
|
|
int n_tokens = tokens_input->ne[0];
|
|
int n_vocab = target_logits->ne[0];
|
|
|
|
size_t sample = train_samples[example_id % n_train_samples];
|
|
GGML_ASSERT(sample+n_tokens-1 < n_train_data);
|
|
|
|
ggml_set_f32(target_logits, -1.0f/n_vocab);
|
|
ggml_set_f32(target_probs, 0.0f);
|
|
ggml_set_i32_1d(tokens_input, 0, llama_token_bos(lctx));
|
|
for (int i=1; i<n_tokens+1; ++i) {
|
|
int token = clamp(train_data[sample+i-1], 0, n_vocab-1);
|
|
set_f32_2d(target_logits, token, i-1, +1.0f);
|
|
set_f32_2d(target_probs, token, i-1, +1.0f);
|
|
if (i<n_tokens) {
|
|
ggml_set_i32_1d(tokens_input, i, token);
|
|
}
|
|
}
|
|
}
|
|
|
|
void get_example_targets_batch(struct llama_context * lctx, const int * train_samples, size_t n_train_samples, const llama_token * train_data, size_t n_train_data, int example_id, struct ggml_tensor * tokens_input, struct ggml_tensor * target_logits, struct ggml_tensor * target_probs) {
|
|
GGML_ASSERT(tokens_input->n_dims == 2);
|
|
GGML_ASSERT(target_logits->n_dims == 3);
|
|
GGML_ASSERT(target_probs->n_dims == 3);
|
|
int n_vocab = target_logits->ne[0];
|
|
int n_tokens = tokens_input->ne[0];
|
|
int n_batch = tokens_input->ne[1];
|
|
GGML_ASSERT(n_tokens == target_logits->ne[1]);
|
|
GGML_ASSERT(n_batch == target_logits->ne[2]);
|
|
GGML_ASSERT(n_vocab == target_probs->ne[0]);
|
|
GGML_ASSERT(n_tokens == target_probs->ne[1]);
|
|
GGML_ASSERT(n_batch == target_probs->ne[2]);
|
|
|
|
ggml_set_f32(target_logits, -1.0f/n_vocab);
|
|
ggml_set_f32(target_probs, 0.0f);
|
|
// printf("%s: example_id=%d n_batch=%d n_train_samples=%zu\n", __func__, example_id, n_batch, n_train_samples);
|
|
for (int k=0; k<n_batch; ++k) {
|
|
// printf("%s: batch %d\n", __func__, k);
|
|
size_t sample_idx = (example_id*n_batch + k) % n_train_samples;
|
|
size_t sample = train_samples[sample_idx];
|
|
// printf("%s: sample_idx=%zu sample=%zu\n", __func__, sample_idx, sample);
|
|
GGML_ASSERT(sample+n_tokens-1 < n_train_data);
|
|
|
|
set_i32_2d(tokens_input, 0, k, llama_token_bos(lctx));
|
|
for (int i=1; i<n_tokens+1; ++i) {
|
|
int token = clamp(train_data[sample+i-1], 0, n_vocab-1);
|
|
set_f32_3d(target_logits, token, i-1, k, +1.0f);
|
|
set_f32_3d(target_probs, token, i-1, k, +1.0f);
|
|
if (i<n_tokens) {
|
|
set_i32_2d(tokens_input, i, k, token);
|
|
}
|
|
}
|
|
}
|
|
}
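// Worked example of the input/target layout built above. Take n_tokens = 4, and suppose the
// window of train_data starting at `sample` is [a, b, c, d]. Then:
//     tokens_input[:, k]    = [BOS, a, b, c]
//     target token per pos  = [a,   b, c, d]
// so position j is trained to predict the token that follows input j. target_probs is one-hot
// (1 at the target token, 0 elsewhere). target_logits holds +1 at the target token and
// -1/n_vocab elsewhere. Batch row k uses train_samples[(example_id*n_batch + k) % n_train_samples].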
|
|
|
|
|
|
#ifdef __GNUC__
|
|
#ifdef __MINGW32__
|
|
__attribute__((format(gnu_printf, 1, 2)))
|
|
#else
|
|
__attribute__((format(printf, 1, 2)))
|
|
#endif
|
|
#endif
|
|
static std::string format(const char * fmt, ...) {
|
|
va_list ap, ap2;
|
|
va_start(ap, fmt);
|
|
va_copy(ap2, ap);
|
|
int size = vsnprintf(NULL, 0, fmt, ap);
|
|
GGML_ASSERT(size >= 0 && size < INT_MAX);
|
|
std::vector<char> buf(size + 1);
|
|
int size2 = vsnprintf(buf.data(), size + 1, fmt, ap2);
|
|
GGML_ASSERT(size2 == size);
|
|
va_end(ap2);
|
|
va_end(ap);
|
|
return std::string(buf.data(), size);
|
|
}
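// Example: format("key %s has wrong type: %s", "general.architecture", "u32") returns
// std::string("key general.architecture has wrong type: u32").
// The first vsnprintf call, made with a NULL buffer, only measures the output length.
// va_copy is needed because a va_list must not be reused after the first call has consumed it.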
|
|
|
|
int tokenize_file(struct llama_context * lctx, const char * filename, std::vector<llama_token>& out) {
|
|
FILE * fp = std::fopen(filename, "rb");
|
|
if (fp == NULL) {
|
|
return 0;
|
|
}
|
|
|
|
#ifdef _WIN32
|
|
GGML_ASSERT(_fseeki64(fp, (__int64) 0, SEEK_END) == 0);
|
|
#else
|
|
GGML_ASSERT(std::fseek(fp, (long) 0, SEEK_END) == 0);
|
|
#endif
|
|
|
|
size_t size = 0;
|
|
#ifdef _WIN32
|
|
__int64 ret = _ftelli64(fp);
|
|
size = ret;
|
|
#else
|
|
long ret = std::ftell(fp);
|
|
size = ret;
|
|
#endif
|
|
|
|
#ifdef _WIN32
|
|
GGML_ASSERT(_fseeki64(fp, (__int64) 0, SEEK_SET) == 0);
|
|
#else
|
|
GGML_ASSERT(std::fseek(fp, (long) 0, SEEK_SET) == 0);
|
|
#endif
|
|
|
|
std::vector<char> buf;
|
|
buf.resize(size+1);
|
|
out.resize(size+1);
|
|
|
|
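if (std::fread(buf.data(), size, 1, fp) != 1) {
std::fclose(fp);
throw std::runtime_error(std::string("unexpectedly reached end of file"));
}
if (ferror(fp)) {
const int read_errno = errno; // save before fclose can overwrite errno
std::fclose(fp);
throw std::runtime_error(format("read error: %s", strerror(read_errno)));
}
std::fclose(fp); // the file is no longer needed once its contents are in buf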
|
|
|
|
buf[size] = '\0';
|
|
|
|
int n_tokens = llama_tokenize(lctx, buf.data(), out.data(), out.size(), false);
|
|
if (n_tokens < 0) {
|
|
out.resize(-n_tokens);
|
|
n_tokens = llama_tokenize(lctx, buf.data(), out.data(), out.size(), false);
|
|
}
|
|
GGML_ASSERT(n_tokens >= 0);
|
|
out.resize(n_tokens);
|
|
|
|
bool verify = false;
|
|
if (verify) {
|
|
const char * in = buf.data();
|
|
const char * end = buf.data() + buf.size();
|
|
for (int i = 0; i < (int) out.size(); ++i) {
|
|
std::string s = llama_token_to_piece(lctx, out[i]);
|
|
int len = s.length();
|
|
if (in >= end) {
|
|
printf("%s: unexpected end of original text.\n", __func__);
|
|
break;
|
|
}
|
|
const bool matches = (strncmp(in, s.c_str(), len) == 0);
|
|
if (matches) {
|
|
in += len;
|
|
} else {
|
|
printf("%s: mismatch: expected '%s', but got '%s'\n", __func__, std::string(in, len).c_str(), s.c_str());
|
|
}
|
|
}
|
|
}
|
|
|
|
return n_tokens;
|
|
}
|
|
|
|
void shuffle_ints(int * begin, int * end) {
|
|
if (end <= begin) return;
|
|
int max=begin[0];
|
|
for (int i=1; i<end-begin; ++i) {
|
|
if (begin[i] > max) {
|
|
max = begin[i];
|
|
}
|
|
}
|
|
std::vector<float> vals;
|
|
vals.resize(max+1);
|
|
for (int i=0; i<max+1; ++i) {
|
|
vals[i] = frand();
|
|
}
|
|
std::sort(begin, end, [&vals](int a, int b){
|
|
return vals.at(a) < vals.at(b);
|
|
});
|
|
}
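// shuffle_ints permutes [begin, end) by sorting on a random key drawn per value
// (vals[v] = frand()). It therefore assumes non-negative values and allocates max+1 floats.
// A standard-library alternative is std::shuffle, which is typically a Fisher-Yates shuffle.
// This is a swapped-in technique, not what this file uses; `seed` is a hypothetical
// caller-chosen value:
//
//     #include <algorithm>
//     #include <random>
//     std::mt19937 rng(seed);
//     std::shuffle(begin, end, rng);
//
// For a given seed, this would not reproduce the order produced by the frand()-based version.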
|
|
|
|
#define GGUF_GET_KEY(ctx, dst, func, type, req, key) \
|
|
{ \
|
|
const std::string skey(key); \
|
|
const int kid = gguf_find_key(ctx, skey.c_str()); \
|
|
if (kid >= 0) { \
|
|
enum gguf_type ktype = gguf_get_kv_type(ctx, kid); \
|
|
if (ktype != (type)) { \
|
|
throw std::runtime_error(format("key %s has wrong type: %s", skey.c_str(), gguf_type_name(ktype))); \
|
|
} \
|
|
(dst) = func(ctx, kid); \
|
|
} else if (req) { \
|
|
throw std::runtime_error(format("key not found in model: %s", skey.c_str())); \
|
|
} \
|
|
}
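// Usage example (illustrative; `fctx` is an open gguf_context). It reads one required u32 and
// one optional f32. When an optional key is missing, the destination is left untouched:
//
//     uint32_t n_embd = 0;
//     GGUF_GET_KEY(fctx, n_embd, gguf_get_val_u32, GGUF_TYPE_UINT32, true, "llama.embedding_length");
//     float rms_eps = 1e-5f; // this default is kept if the key is absent
//     GGUF_GET_KEY(fctx, rms_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, "llama.attention.layer_norm_rms_epsilon");
//
// A missing required key, or a key stored with a different type, throws std::runtime_error.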
|
|
|
|
|
|
bool are_same_layout(struct ggml_tensor * a, struct ggml_tensor * b) {
|
|
GGML_ASSERT(a != NULL);
|
|
GGML_ASSERT(b != NULL);
|
|
GGML_ASSERT(a->type == b->type);
|
|
GGML_ASSERT(ggml_are_same_shape(a, b));
|
|
GGML_ASSERT(ggml_is_contiguous(a) && ggml_is_contiguous(b));
|
|
|
|
return true;
|
|
}
|
|
|
|
void read_tensor_by_name(struct ggml_tensor * dst, struct ggml_context * ctx, const char * name) {
|
|
if (dst == NULL) {
|
|
return;
|
|
}
|
|
struct ggml_tensor * t = ggml_get_tensor(ctx, name);
|
|
GGML_ASSERT(are_same_layout(dst, t));
|
|
memcpy(dst->data, t->data, ggml_nbytes(t));
|
|
|
|
if (strlen(ggml_get_name(dst)) == 0) {
|
|
ggml_set_name(dst, name);
|
|
}
|
|
}
|
|
|
|
void load_opt_context_gguf(struct gguf_context * fctx, struct ggml_context * f_ggml_ctx, struct ggml_opt_context * opt) {
|
|
// NOTE: gguf_context must be initialized with f_ggml_ctx and no_alloc=false, otherwise tensor data cannot be read
|
|
|
|
uint32_t file_version;
|
|
GGUF_GET_KEY(fctx, file_version, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_OPTIMIZER_FILE_VERSION);
|
|
GGML_ASSERT(file_version == 0);
|
|
|
|
GGUF_GET_KEY(fctx, opt->params.past, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_OPTIMIZER_CONVERGENCE_PAST_COUNT);
|
|
GGUF_GET_KEY(fctx, opt->iter, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_OPTIMIZER_ITERATION_COUNT);
|
|
GGUF_GET_KEY(fctx, opt->just_initialized, gguf_get_val_bool, GGUF_TYPE_BOOL, true, LLM_KV_OPTIMIZER_JUST_INITIALIZED);
|
|
|
|
uint64_t nx;
|
|
GGUF_GET_KEY(fctx, nx, gguf_get_val_u64, GGUF_TYPE_UINT64, true, LLM_KV_OPTIMIZER_PARAMETER_COUNT);
|
|
opt->nx = (size_t) nx;
|
|
|
|
// don't call ggml_opt_init until the optimizer type and optimizer-specific parameters are known
|
|
|
|
std::string opt_type;
|
|
GGUF_GET_KEY(fctx, opt_type, gguf_get_val_str, GGUF_TYPE_STRING, true, LLM_KV_OPTIMIZER_TYPE);
|
|
if (opt_type == LLM_KV_OPTIMIZER_TYPE_ADAM) {
|
|
opt->params.type = GGML_OPT_ADAM;
|
|
|
|
GGUF_GET_KEY(fctx, opt->adam.fx_best, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, LLM_KV_OPTIMIZER_ADAM_BEST_LOSS);
|
|
GGUF_GET_KEY(fctx, opt->adam.fx_prev, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, LLM_KV_OPTIMIZER_ADAM_PREVIOUS_LOSS);
|
|
GGUF_GET_KEY(fctx, opt->adam.n_no_improvement, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_OPTIMIZER_ADAM_NO_IMPROVEMENT_COUNT);
|
|
|
|
GGML_ASSERT(opt->ctx != NULL);
|
|
ggml_opt_init(opt->ctx, opt, opt->params, opt->nx);
|
|
|
|
read_tensor_by_name(opt->adam.m, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_ADAM_FIRST_MOMENTS);
|
|
read_tensor_by_name(opt->adam.v, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_ADAM_SECOND_MOMENTS);
|
|
read_tensor_by_name(opt->adam.pf, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_ADAM_PAST_LOSS_VALUES);
|
|
} else if (opt_type == LLM_KV_OPTIMIZER_TYPE_LBFGS) {
|
|
opt->params.type = GGML_OPT_LBFGS;
|
|
|
|
GGUF_GET_KEY(fctx, opt->params.lbfgs.m, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_OPTIMIZER_LBFGS_APPROX_HESSIAN_COUNT);
|
|
GGUF_GET_KEY(fctx, opt->lbfgs.fx_best, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, LLM_KV_OPTIMIZER_LBFGS_BEST_LOSS);
|
|
GGUF_GET_KEY(fctx, opt->lbfgs.step, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_STEP);
|
|
GGUF_GET_KEY(fctx, opt->lbfgs.j, gguf_get_val_i32, GGUF_TYPE_INT32, true, LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_J);
|
|
GGUF_GET_KEY(fctx, opt->lbfgs.k, gguf_get_val_i32, GGUF_TYPE_INT32, true, LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_K);
|
|
GGUF_GET_KEY(fctx, opt->lbfgs.end, gguf_get_val_i32, GGUF_TYPE_INT32, true, LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_END);
|
|
GGUF_GET_KEY(fctx, opt->lbfgs.n_no_improvement, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_OPTIMIZER_LBFGS_NO_IMPROVEMENT_COUNT);
|
|
|
|
GGML_ASSERT(opt->ctx != NULL);
|
|
ggml_opt_init(opt->ctx, opt, opt->params, opt->nx);
|
|
|
|
read_tensor_by_name(opt->lbfgs.x, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_CURRENT_PARAMETERS);
|
|
read_tensor_by_name(opt->lbfgs.xp, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_PREVIOUS_PARAMETERS);
|
|
read_tensor_by_name(opt->lbfgs.g, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_CURRENT_GRADIENTS);
|
|
read_tensor_by_name(opt->lbfgs.gp, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_PREVIOUS_GRADIENTS);
|
|
read_tensor_by_name(opt->lbfgs.d, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_SEARCH_DIRECTION);
|
|
read_tensor_by_name(opt->lbfgs.pf, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_PAST_LOSS_VALUES);
|
|
read_tensor_by_name(opt->lbfgs.lmal, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_ALPHA);
|
|
read_tensor_by_name(opt->lbfgs.lmys, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_YS);
|
|
read_tensor_by_name(opt->lbfgs.lms, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_S);
|
|
read_tensor_by_name(opt->lbfgs.lmy, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_Y);
|
|
} else {
|
|
throw std::runtime_error("unknown optimizer type: " + opt_type);
|
|
}
|
|
}
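// Loading sketch. It assumes the gguf API of this ggml version, in which gguf_init_params has
// the fields `no_alloc` and `ctx`; `fn_checkpoint` and `opt` are hypothetical. As the NOTE above
// says, tensor data is only readable with no_alloc = false:
//
//     struct ggml_context * f_ggml_ctx = NULL;
//     struct gguf_init_params params;
//     params.no_alloc = false;
//     params.ctx      = &f_ggml_ctx;
//     struct gguf_context * fctx = gguf_init_from_file(fn_checkpoint, params);
//     if (fctx != NULL) {
//         load_opt_context_gguf(fctx, f_ggml_ctx, opt); // opt->ctx must already be set
//         gguf_free(fctx);
//         ggml_free(f_ggml_ctx);
//     }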
|
|
|
|
void save_opt_context_gguf(struct gguf_context * fctx, struct ggml_opt_context * opt) {
|
|
gguf_set_val_u32(fctx, LLM_KV_OPTIMIZER_FILE_VERSION, 0);
|
|
gguf_set_val_u32(fctx, LLM_KV_OPTIMIZER_CONVERGENCE_PAST_COUNT, opt->params.past);
|
|
gguf_set_val_u64(fctx, LLM_KV_OPTIMIZER_PARAMETER_COUNT, (uint64_t) opt->nx);
|
|
gguf_set_val_u32(fctx, LLM_KV_OPTIMIZER_ITERATION_COUNT, opt->iter);
|
|
gguf_set_val_bool(fctx, LLM_KV_OPTIMIZER_JUST_INITIALIZED, opt->just_initialized);
|
|
|
|
switch (opt->params.type) {
|
|
case GGML_OPT_ADAM:
|
|
{
|
|
gguf_set_val_str(fctx, LLM_KV_OPTIMIZER_TYPE, LLM_KV_OPTIMIZER_TYPE_ADAM);
|
|
gguf_set_val_f32(fctx, LLM_KV_OPTIMIZER_ADAM_BEST_LOSS, opt->adam.fx_best);
|
|
gguf_set_val_f32(fctx, LLM_KV_OPTIMIZER_ADAM_PREVIOUS_LOSS, opt->adam.fx_prev);
|
|
gguf_set_val_u32(fctx, LLM_KV_OPTIMIZER_ADAM_NO_IMPROVEMENT_COUNT, opt->adam.n_no_improvement);
|
|
|
|
ggml_set_name(opt->adam.m, LLM_TENSOR_OPTIMIZER_ADAM_FIRST_MOMENTS);
|
|
ggml_set_name(opt->adam.v, LLM_TENSOR_OPTIMIZER_ADAM_SECOND_MOMENTS);
|
|
if (opt->adam.pf) {
|
|
ggml_set_name(opt->adam.pf, LLM_TENSOR_OPTIMIZER_ADAM_PAST_LOSS_VALUES);
|
|
}
|
|
|
|
gguf_add_tensor(fctx, opt->adam.m);
|
|
gguf_add_tensor(fctx, opt->adam.v);
|
|
if (opt->adam.pf) {
|
|
gguf_add_tensor(fctx, opt->adam.pf);
|
|
}
|
|
} break;
|
|
case GGML_OPT_LBFGS:
|
|
{
|
|
gguf_set_val_str(fctx, LLM_KV_OPTIMIZER_TYPE, LLM_KV_OPTIMIZER_TYPE_LBFGS);
|
|
gguf_set_val_u32(fctx, LLM_KV_OPTIMIZER_LBFGS_APPROX_HESSIAN_COUNT, opt->params.lbfgs.m);
|
|
gguf_set_val_f32(fctx, LLM_KV_OPTIMIZER_LBFGS_BEST_LOSS, opt->lbfgs.fx_best);
|
|
gguf_set_val_f32(fctx, LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_STEP, opt->lbfgs.step);
|
|
gguf_set_val_i32(fctx, LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_J, opt->lbfgs.j);
|
|
gguf_set_val_i32(fctx, LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_K, opt->lbfgs.k);
|
|
gguf_set_val_i32(fctx, LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_END, opt->lbfgs.end);
|
|
gguf_set_val_u32(fctx, LLM_KV_OPTIMIZER_LBFGS_NO_IMPROVEMENT_COUNT, opt->lbfgs.n_no_improvement);
|
|
|
|
ggml_set_name(opt->lbfgs.x, LLM_TENSOR_OPTIMIZER_LBFGS_CURRENT_PARAMETERS);
|
|
ggml_set_name(opt->lbfgs.xp, LLM_TENSOR_OPTIMIZER_LBFGS_PREVIOUS_PARAMETERS);
|
|
ggml_set_name(opt->lbfgs.g, LLM_TENSOR_OPTIMIZER_LBFGS_CURRENT_GRADIENTS);
|
|
ggml_set_name(opt->lbfgs.gp, LLM_TENSOR_OPTIMIZER_LBFGS_PREVIOUS_GRADIENTS);
|
|
ggml_set_name(opt->lbfgs.d, LLM_TENSOR_OPTIMIZER_LBFGS_SEARCH_DIRECTION);
|
|
if (opt->lbfgs.pf) {
|
|
ggml_set_name(opt->lbfgs.pf, LLM_TENSOR_OPTIMIZER_LBFGS_PAST_LOSS_VALUES);
|
|
}
|
|
ggml_set_name(opt->lbfgs.lmal, LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_ALPHA);
|
|
ggml_set_name(opt->lbfgs.lmys, LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_YS);
|
|
ggml_set_name(opt->lbfgs.lms, LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_S);
|
|
ggml_set_name(opt->lbfgs.lmy, LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_Y);
|
|
|
|
gguf_add_tensor(fctx, opt->lbfgs.x);
|
|
gguf_add_tensor(fctx, opt->lbfgs.xp);
|
|
gguf_add_tensor(fctx, opt->lbfgs.g);
|
|
gguf_add_tensor(fctx, opt->lbfgs.gp);
|
|
gguf_add_tensor(fctx, opt->lbfgs.d);
|
|
if (opt->lbfgs.pf) {
|
|
gguf_add_tensor(fctx, opt->lbfgs.pf);
|
|
}
|
|
gguf_add_tensor(fctx, opt->lbfgs.lmal);
|
|
gguf_add_tensor(fctx, opt->lbfgs.lmys);
|
|
gguf_add_tensor(fctx, opt->lbfgs.lms);
|
|
gguf_add_tensor(fctx, opt->lbfgs.lmy);
|
|
} break;
|
|
}
|
|
}
|
|
|
|
void load_llama_model_gguf(struct gguf_context * fctx, struct ggml_context * f_ggml_ctx, struct my_llama_model * model) {
|
|
// NOTE: gguf_context must be initialized with f_ggml_ctx and no_alloc=false, otherwise tensor data cannot be read
|
|
std::string arch;
|
|
|
|
std::vector<char> keybuf;
|
|
keybuf.resize(512);
|
|
auto kv = [&arch, &keybuf](const char * key) -> const char * {
|
|
snprintf(keybuf.data(), keybuf.size(), key, arch.c_str());
|
|
return keybuf.data();
|
|
};
|
|
|
|
std::vector<char> tn_buf;
|
|
tn_buf.resize(GGML_MAX_NAME);
|
|
auto tn = [&tn_buf](const char * key) -> const char * {
|
|
snprintf(tn_buf.data(), tn_buf.size(), "%s.weight", key);
|
|
return tn_buf.data();
|
|
};
|
|
auto tni = [&tn_buf](const char * key, int bid) -> const char * {
|
|
snprintf(tn_buf.data(), tn_buf.size(), key, bid);
|
|
std::string s = tn_buf.data();
|
|
snprintf(tn_buf.data(), tn_buf.size(), "%s.weight", s.c_str());
|
|
return tn_buf.data();
|
|
};
|
|
|
|
GGUF_GET_KEY(fctx, arch, gguf_get_val_str, GGUF_TYPE_STRING, true, LLM_KV_GENERAL_ARCHITECTURE);
|
|
GGML_ASSERT(arch == "llama");
|
|
|
|
uint32_t ftype_u;
|
|
GGUF_GET_KEY(fctx, ftype_u, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_GENERAL_FILE_TYPE);
|
|
GGML_ASSERT((enum llama_ftype) ftype_u == LLAMA_FTYPE_ALL_F32);
|
|
|
|
// n_ctx was not saved in earlier checkpoint file versions, so we make it optional here
|
|
GGUF_GET_KEY(fctx, model->hparams.n_ctx, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_CONTEXT_LENGTH));
|
|
|
|
GGUF_GET_KEY(fctx, model->hparams.n_embd, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_EMBEDDING_LENGTH));
|
|
GGUF_GET_KEY(fctx, model->hparams.n_ff, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_FEED_FORWARD_LENGTH));
|
|
GGUF_GET_KEY(fctx, model->hparams.n_head, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_ATTENTION_HEAD_COUNT));
|
|
GGUF_GET_KEY(fctx, model->hparams.n_layer, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_BLOCK_COUNT));
|
|
|
|
model->hparams.n_rot = model->hparams.n_embd / model->hparams.n_head;
|
|
GGUF_GET_KEY(fctx, model->hparams.n_rot, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_ROPE_DIMENSION_COUNT));
|
|
|
|
float rope_freq_scale = 1.0f;
|
|
GGUF_GET_KEY(fctx, model->hparams.f_norm_rms_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS));
|
|
GGUF_GET_KEY(fctx, model->hparams.rope_freq_base, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_FREQ_BASE));
|
|
GGUF_GET_KEY(fctx, rope_freq_scale, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_SCALE_LINEAR));
|
|
if (rope_freq_scale != 1.0f) {
|
|
model->hparams.rope_freq_scale = 1.0f / rope_freq_scale;
|
|
}
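// Worked example: the file stores the linear RoPE scale factor s (LLM_KV_ROPE_SCALE_LINEAR),
// while hparams.rope_freq_scale holds its reciprocal. A stored s = 2.0f therefore gives
// rope_freq_scale = 1.0f / 2.0f = 0.5f. When the key is absent, s stays 1.0f and
// rope_freq_scale keeps its previous value.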
|
|
|
|
init_model(model);
|
|
|
|
read_tensor_by_name(model->tok_embeddings, f_ggml_ctx, tn(LLM_TENSOR_TOKEN_EMBD));
|
|
read_tensor_by_name(model->norm, f_ggml_ctx, tn(LLM_TENSOR_OUTPUT_NORM));
|
|
read_tensor_by_name(model->output, f_ggml_ctx, tn(LLM_TENSOR_OUTPUT));
|
|
|
|
for (uint32_t i = 0; i < model->hparams.n_layer; ++i) {
|
|
auto & layer = model->layers[i];
|
|
|
|
read_tensor_by_name(layer.attention_norm, f_ggml_ctx, tni(LLM_TENSOR_ATTN_NORM, i));
|
|
read_tensor_by_name(layer.wq, f_ggml_ctx, tni(LLM_TENSOR_ATTN_Q, i));
|
|
read_tensor_by_name(layer.wk, f_ggml_ctx, tni(LLM_TENSOR_ATTN_K, i));
|
|
read_tensor_by_name(layer.wv, f_ggml_ctx, tni(LLM_TENSOR_ATTN_V, i));
|
|
read_tensor_by_name(layer.wo, f_ggml_ctx, tni(LLM_TENSOR_ATTN_OUT, i));
|
|
read_tensor_by_name(layer.ffn_norm, f_ggml_ctx, tni(LLM_TENSOR_FFN_NORM, i));
|
|
read_tensor_by_name(layer.w1, f_ggml_ctx, tni(LLM_TENSOR_FFN_GATE, i));
|
|
read_tensor_by_name(layer.w2, f_ggml_ctx, tni(LLM_TENSOR_FFN_DOWN, i));
|
|
read_tensor_by_name(layer.w3, f_ggml_ctx, tni(LLM_TENSOR_FFN_UP, i));
|
|
}
|
|
}
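
// Sketch (hypothetical helper, not used by this program): restates what the tni lambda in
// load_llama_model_gguf does to build a per-layer tensor name. The block index is expanded
// into the name pattern first, then ".weight" is appended. The pattern "blk.%d.attn_q" in
// the example is an assumption for illustration; the real LLM_TENSOR_* name patterns are
// defined elsewhere in this file.
#include <cstdio>
#include <string>

static inline std::string tensor_name_with_block_sketch(const char * pattern, int block_id) {
    char buf[256];
    // first pass: "blk.%d.attn_q" with block_id 3 -> "blk.3.attn_q"
    snprintf(buf, sizeof(buf), pattern, block_id);
    // second pass: append the ".weight" suffix, as tni does
    return std::string(buf) + ".weight";
}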

void save_llama_model_gguf(struct gguf_context * fctx, const char * fn_vocab_model, struct my_llama_model * model) {
    const char * arch = "llama";
    enum llama_ftype ftype = LLAMA_FTYPE_ALL_F32;

    std::vector<char> keybuf;
    keybuf.resize(512);
    auto kv = [arch, &keybuf](const char * key) -> const char * {
        snprintf(keybuf.data(), keybuf.size(), key, arch);
        return keybuf.data();
    };

    // set arch
    gguf_set_val_str(fctx, LLM_KV_GENERAL_ARCHITECTURE, arch);
    gguf_set_val_u32(fctx, LLM_KV_GENERAL_FILE_TYPE, ftype);

    // set hparams
    gguf_set_val_u32(fctx, kv(LLM_KV_CONTEXT_LENGTH), model->hparams.n_ctx);
    gguf_set_val_u32(fctx, kv(LLM_KV_EMBEDDING_LENGTH), model->hparams.n_embd);
    gguf_set_val_u32(fctx, kv(LLM_KV_FEED_FORWARD_LENGTH), model->hparams.n_ff);
    gguf_set_val_u32(fctx, kv(LLM_KV_ATTENTION_HEAD_COUNT), model->hparams.n_head);
    gguf_set_val_u32(fctx, kv(LLM_KV_BLOCK_COUNT), model->hparams.n_layer);
    gguf_set_val_u32(fctx, kv(LLM_KV_ROPE_DIMENSION_COUNT), model->hparams.n_rot);

    gguf_set_val_f32(fctx, kv(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS), model->hparams.f_norm_rms_eps);
    gguf_set_val_f32(fctx, kv(LLM_KV_ROPE_FREQ_BASE), model->hparams.rope_freq_base); // TODO load in llama.cpp
    gguf_set_val_f32(fctx, kv(LLM_KV_ROPE_SCALE_LINEAR), 1.0f / model->hparams.rope_freq_scale);

    // set vocab by copying from vocab_model gguf file
    {
        struct gguf_init_params params = {
            /*.no_alloc = */ false,
            /*.ctx = */ NULL,
        };
        struct gguf_context * vctx = gguf_init_from_file(fn_vocab_model, params);
        if (vctx == NULL) {
            throw std::runtime_error("cannot open vocab model file\n");
        }

        const int token_idx = gguf_find_key(vctx, kv(LLM_KV_TOKENIZER_LIST));
        if (token_idx == -1) {
            throw std::runtime_error("cannot find tokenizer vocab in model file\n");
        }
        const uint32_t n_vocab = gguf_get_arr_n(vctx, token_idx);

        const int score_idx = gguf_find_key(vctx, kv(LLM_KV_TOKENIZER_SCORES));
        if (score_idx == -1) {
            throw std::runtime_error("cannot find tokenizer scores in model file\n");
        }

        const float * scores = (const float *) gguf_get_arr_data(vctx, score_idx);

        const int toktype_idx = gguf_find_key(vctx, kv(LLM_KV_TOKENIZER_TOKEN_TYPE));
        if (toktype_idx == -1) {
            throw std::runtime_error("cannot find token type list in GGUF file\n");
        }

        const int * toktypes = (const int *) gguf_get_arr_data(vctx, toktype_idx);

        std::string tokenizer_name;
        GGUF_GET_KEY(vctx, tokenizer_name, gguf_get_val_str, GGUF_TYPE_STRING, true, kv(LLM_KV_TOKENIZER_MODEL));

        gguf_set_val_str(fctx, kv(LLM_KV_TOKENIZER_MODEL), tokenizer_name.c_str());
        gguf_set_arr_data(fctx, kv(LLM_KV_TOKENIZER_SCORES), GGUF_TYPE_FLOAT32, scores, n_vocab);
        gguf_set_arr_data(fctx, kv(LLM_KV_TOKENIZER_TOKEN_TYPE), GGUF_TYPE_INT32, toktypes, n_vocab);

        int32_t special_bos_id = 1;
        int32_t special_eos_id = 2;
        int32_t special_unk_id = 0;
        int32_t special_sep_id = -1;
        int32_t special_pad_id = -1;
        if (tokenizer_name == "llama") {
            // default special tokens
            special_bos_id = 1;
            special_eos_id = 2;
            special_unk_id = 0;
            special_sep_id = -1;
            special_pad_id = -1;
        } else if (tokenizer_name == "gpt2") {
            // read and copy bpe merges
            const int merges_keyidx = gguf_find_key(vctx, kv(LLM_KV_TOKENIZER_MERGES));
            if (merges_keyidx == -1) {
                throw std::runtime_error("cannot find tokenizer merges in model file\n");
            }

            const int n_merges = gguf_get_arr_n(vctx, merges_keyidx);

            std::vector<const char*> merges;
            merges.resize(n_merges);
            for (int i = 0; i < n_merges; i++) {
                merges[i] = gguf_get_arr_str(vctx, merges_keyidx, i);
            }
            gguf_set_arr_str(fctx, kv(LLM_KV_TOKENIZER_MERGES), merges.data(), n_merges);

            // default special tokens
            special_bos_id = 11;
            special_eos_id = 11;
            special_unk_id = -1;
            special_sep_id = -1;
            special_pad_id = -1;
        } else {
            fprintf(stderr, "%s: unknown tokenizer: '%s'\n", __func__, tokenizer_name.c_str());
            fprintf(stderr, "%s: using default tokenizer: 'llama'\n", __func__);
        }

        std::vector<const char*> tokens;
        tokens.resize(n_vocab);
        for (uint32_t i = 0; i < n_vocab; i++) {
            tokens[i] = gguf_get_arr_str(vctx, token_idx, i);
        }
        gguf_set_arr_str(fctx, kv(LLM_KV_TOKENIZER_LIST), tokens.data(), n_vocab);

        GGUF_GET_KEY(vctx, special_bos_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_TOKENIZER_BOS_ID));
        GGUF_GET_KEY(vctx, special_eos_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_TOKENIZER_EOS_ID));
        GGUF_GET_KEY(vctx, special_unk_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_TOKENIZER_UNK_ID));
        GGUF_GET_KEY(vctx, special_sep_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_TOKENIZER_SEP_ID));
        GGUF_GET_KEY(vctx, special_pad_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_TOKENIZER_PAD_ID));

        gguf_set_val_u32(fctx, kv(LLM_KV_TOKENIZER_BOS_ID), special_bos_id);
        gguf_set_val_u32(fctx, kv(LLM_KV_TOKENIZER_EOS_ID), special_eos_id);
        gguf_set_val_u32(fctx, kv(LLM_KV_TOKENIZER_UNK_ID), special_unk_id);
        gguf_set_val_u32(fctx, kv(LLM_KV_TOKENIZER_SEP_ID), special_sep_id);
        gguf_set_val_u32(fctx, kv(LLM_KV_TOKENIZER_PAD_ID), special_pad_id);

        gguf_free(vctx);
    }

    // add tensors
    gguf_add_tensor(fctx, model->tok_embeddings);
    gguf_add_tensor(fctx, model->norm);
    gguf_add_tensor(fctx, model->output);
    for (uint32_t i = 0; i < model->hparams.n_layer; ++i) {
        auto & layer = model->layers[i];

        gguf_add_tensor(fctx, layer.attention_norm);
        gguf_add_tensor(fctx, layer.wq);
        gguf_add_tensor(fctx, layer.wk);
        gguf_add_tensor(fctx, layer.wv);
        gguf_add_tensor(fctx, layer.wo);
        gguf_add_tensor(fctx, layer.ffn_norm);
        gguf_add_tensor(fctx, layer.w1);
        gguf_add_tensor(fctx, layer.w2);
        gguf_add_tensor(fctx, layer.w3);
    }
}
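
// Sketch (hypothetical helpers, not used by this program): the model keeps rope_freq_scale
// as a frequency multiplier, while save_llama_model_gguf writes its reciprocal under
// LLM_KV_ROPE_SCALE_LINEAR, and load_llama_model_gguf inverts the stored value again when
// it differs from 1.0f. These two functions restate that round trip. Returning 1.0f for a
// stored 1.0f is an assumption made here; the loader itself leaves the field unchanged.
static inline float rope_scale_to_gguf_sketch(float rope_freq_scale) {
    // value written to the file, e.g. a frequency scale of 0.5 is stored as 2.0
    return 1.0f / rope_freq_scale;
}

static inline float rope_scale_from_gguf_sketch(float stored_linear_scale) {
    // value restored into hparams.rope_freq_scale
    return stored_linear_scale != 1.0f ? 1.0f / stored_linear_scale : 1.0f;
}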

void save_llama_model_file(const char * filename, const char * fn_vocab_model, struct my_llama_model * model) {
    struct gguf_context * fctx = gguf_init_empty();

    save_llama_model_gguf(fctx, fn_vocab_model, model);

    // write file
    const bool only_meta = false;
    gguf_write_to_file(fctx, filename, only_meta);
    gguf_free(fctx);
}

void load_checkpoint_gguf(struct gguf_context * fctx, struct ggml_context * f_ggml_ctx, struct my_llama_model * model, struct ggml_opt_context * opt) {
    load_llama_model_gguf(fctx, f_ggml_ctx, model);

    uint32_t file_version;
    GGUF_GET_KEY(fctx, file_version, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_FILE_VERSION);
    GGML_ASSERT(file_version == 0);

    GGUF_GET_KEY(fctx, model->train_its, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_ITERATION_COUNT);
    GGUF_GET_KEY(fctx, model->train_samples, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_SAMPLE_COUNT);
    GGUF_GET_KEY(fctx, model->train_tokens, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_TOKEN_COUNT);

    load_opt_context_gguf(fctx, f_ggml_ctx, opt);
}

void save_checkpoint_gguf(struct gguf_context * fctx, const char * fn_vocab_model, struct my_llama_model * model, struct ggml_opt_context * opt) {
    save_llama_model_gguf(fctx, fn_vocab_model, model);

    gguf_set_val_u32(fctx, LLM_KV_TRAINING_FILE_VERSION, 0);
    gguf_set_val_u32(fctx, LLM_KV_TRAINING_ITERATION_COUNT, model->train_its);
    gguf_set_val_u32(fctx, LLM_KV_TRAINING_SAMPLE_COUNT, model->train_samples);
    gguf_set_val_u32(fctx, LLM_KV_TRAINING_TOKEN_COUNT, model->train_tokens);

    save_opt_context_gguf(fctx, opt);
}

bool load_checkpoint_file(const char * filename, struct my_llama_model * model, struct ggml_opt_context * opt) {
    struct ggml_context * f_ggml_ctx;
    struct gguf_init_params params;
    params.no_alloc = false;
    params.ctx = &f_ggml_ctx;
    struct gguf_context * fctx = gguf_init_from_file(filename, params);
    if (fctx == NULL) {
        return false;
    }

    load_checkpoint_gguf(fctx, f_ggml_ctx, model, opt);
    gguf_free(fctx);

    return true;
}

void save_checkpoint_file(const char * filename, const char * fn_vocab_model, struct my_llama_model * model, struct ggml_opt_context * opt) {
    struct gguf_context * fctx = gguf_init_empty();

    save_checkpoint_gguf(fctx, fn_vocab_model, model, opt);

    // write file
    const bool only_meta = false;
    gguf_write_to_file(fctx, filename, only_meta);
    gguf_free(fctx);
}

float cosine_decay(const int decay_steps, const float minimum, int step) {
    if (step > decay_steps) {
        step = decay_steps;
    }
    const float cosine_decay = 0.50f*(1.0f + cosf(3.14159265359f*step/decay_steps));
    const float decay = (1 - minimum)*cosine_decay + minimum;
    return decay;
}

float cosine_decay_restart(int decay_steps, const float minimum, int step, float restart_step_mult, bool enable_restart) {
    if (enable_restart) {
        while (step > decay_steps) {
            step -= decay_steps;
            // multiply before truncating; casting restart_step_mult first would turn 1.1 into 1
            // and the decay period would never grow
            decay_steps = (int) (restart_step_mult * decay_steps);
        }
    }
    return cosine_decay(decay_steps, minimum, step);
}
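
// Sketch (hypothetical helper, not used by this program): restates cosine_decay_restart with
// restarts enabled, so the growth of the decay period is easy to check. After each completed
// period the leftover steps carry over and the period is multiplied by restart_step_mult,
// e.g. 1000 -> 1100 -> 1210 steps with the default 1.1f. It assumes decay_steps > 0 and
// restart_step_mult >= 1.0f; a shrinking period could reach zero and never terminate.
#include <cmath>

static inline float cosine_decay_restart_sketch(int decay_steps, float minimum, int step, float restart_step_mult) {
    while (step > decay_steps) {
        step -= decay_steps;
        decay_steps = (int) (restart_step_mult * decay_steps); // multiply first, then truncate
    }
    // half cosine from 1 down to 0 over the current period, rescaled to [minimum, 1]
    const float cosine = 0.5f * (1.0f + std::cos(3.14159265359f * step / decay_steps));
    return (1.0f - minimum) * cosine + minimum;
}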

struct train_params {
    const char * fn_vocab_model;
    const char * fn_train_data;
    const char * fn_checkpoint_in;
    const char * fn_checkpoint_out;
    const char * fn_model_out;

    uint32_t seed;

    int n_ctx;
    int n_embd;
    int n_head;
    int n_layer;
    int n_ff;

    int n_threads;
    int n_batch;
    int n_examples;

    float f_norm_rms_eps;
    float rope_freq_base;
    float rope_freq_scale;

    int print_info_interval;

    bool samples_start_after_nl;
    bool use_adam;
    bool use_flash;
    bool use_checkpointing;
    bool use_alloc;

    // only adam
    int warmup;
    int cos_decay_steps;
    float cos_decay_restart;
    float cos_decay_min;
    bool enable_restart;

    int opt_past;
    float opt_delta;
    int opt_max_no_improvement;

    int lbfgs_n_iter;
    int adam_n_iter;
    float adam_alpha;
    float adam_min_alpha;
    float adam_decay;
    int adam_decay_min_ndim;
    float adam_beta1;
    float adam_beta2;
    float adam_gclip;
    float adam_eps_f;

    int mem_model_gb;
    int mem_compute_gb;
    int mem_compute0_gb;
};

struct train_params get_default_train_params() {
    struct train_params params;
    params.fn_vocab_model = "ggml-vic7b-uncensored-q4_0.bin";
    params.fn_train_data = "shakespeare.txt";
    params.fn_checkpoint_in = "checkpoint.bin";
    params.fn_checkpoint_out = "checkpoint.bin";
    params.fn_model_out = "ggml-checkpoint-f32.bin";

    params.seed = -1;

    params.n_ctx = 128;
    params.n_embd = 256;
    params.n_head = 8;
    params.n_layer = 16;
    params.n_ff = 768;

    params.n_threads = 6;
    params.n_batch = 8;
    params.n_examples = 1;

    params.f_norm_rms_eps = 1e-5;
    params.rope_freq_base = 10000.0f;
    params.rope_freq_scale = 1.0f;

    params.print_info_interval = 1;

    params.samples_start_after_nl = false;
    params.use_adam = true;
    params.use_flash = true;
    params.use_checkpointing = true;
    params.use_alloc = true;

    params.opt_past = 0;
    params.opt_delta = 1e-5f;
    params.opt_max_no_improvement = 0;

    // only adam
    params.warmup = 100;
    params.cos_decay_steps = 1000;
    params.cos_decay_restart = 1.1f;
    params.cos_decay_min = 0.1f;
    params.enable_restart = false;

    params.lbfgs_n_iter = 256;
    params.adam_n_iter = 256;
    params.adam_alpha = 1e-3f;
    params.adam_min_alpha = 0;
    params.adam_decay = 1e-1f;
    params.adam_decay_min_ndim = 2;
    params.adam_beta1 = 0.9f;
    params.adam_beta2 = 0.999f;
    params.adam_gclip = 1.0f;
    params.adam_eps_f = 0.0f;

    params.mem_model_gb = 2;
    params.mem_compute_gb = 24;
    params.mem_compute0_gb = 8;
    return params;
}

void train_print_usage(int /*argc*/, char ** argv, const struct train_params * params) {
    fprintf(stderr, "usage: %s [options]\n", argv[0]);
    fprintf(stderr, "\n");
    fprintf(stderr, "options:\n");
    fprintf(stderr, " -h, --help show this help message and exit\n");
    fprintf(stderr, " --vocab-model FNAME model path from which to load vocab (default '%s')\n", params->fn_vocab_model);
    fprintf(stderr, " --train-data FNAME path from which to load training data (default '%s')\n", params->fn_train_data);
    fprintf(stderr, " --checkpoint-in FNAME path from which to load training checkpoint (default '%s')\n", params->fn_checkpoint_in);
    fprintf(stderr, " --checkpoint-out FNAME path to save training checkpoint (default '%s')\n", params->fn_checkpoint_out);
    fprintf(stderr, " --model-out FNAME path to save ggml model (default '%s')\n", params->fn_model_out);
    fprintf(stderr, " -s SEED, --seed SEED RNG seed (default: -1, use random seed for -1)\n");
    fprintf(stderr, " -c N, --ctx N Context size used during training (default %d)\n", params->n_ctx);
    fprintf(stderr, " --embd N Embedding size used for new models (default %d)\n", params->n_embd);
    fprintf(stderr, " --ff N Feedforward size used for new models (default %d)\n", params->n_ff);
    fprintf(stderr, " --head N Number of heads for new models (default %d)\n", params->n_head);
    fprintf(stderr, " --layer N Number of layers for new models (default %d)\n", params->n_layer);
    fprintf(stderr, " --norm-rms-eps F RMS-Norm epsilon value (default %f)\n", params->f_norm_rms_eps);
    fprintf(stderr, " --rope-freq-base F Frequency base for ROPE (default %f)\n", params->rope_freq_base);
    fprintf(stderr, " --rope-freq-scale F Frequency scale for ROPE (default %f)\n", params->rope_freq_scale);
    fprintf(stderr, " -t N, --threads N Number of threads (default %d)\n", params->n_threads);
    fprintf(stderr, " -b N, --batch N Parallel batch size (default %d)\n", params->n_batch);
    fprintf(stderr, " -n N, --examples N Number of examples to train (default %d)\n", params->n_examples);
    fprintf(stderr, " --print-info-interval N Print info during training every N examples (default %d)\n", params->print_info_interval);
    fprintf(stderr, " --samples-after-nl Training samples start after newlines. (default %s)\n", params->samples_start_after_nl ? "on" : "off");
    fprintf(stderr, " --use-lbfgs Use LBFGS optimizer instead of default Adam\n");
    fprintf(stderr, " --use-adam Use Adam optimizer (default)\n");
    fprintf(stderr, " --no-flash Don't use flash attention\n");
    fprintf(stderr, " --use-flash Use flash attention (default)\n");
    fprintf(stderr, " --no-checkpointing Don't use gradient checkpointing\n");
    fprintf(stderr, " --use-checkpointing Use gradient checkpointing (default)\n");
    fprintf(stderr, " --no-alloc Don't use allocator\n");
    fprintf(stderr, " --use-alloc Use allocator (default)\n");
    fprintf(stderr, " --warmup N Only for Adam optimizer. Number of warmup steps (default %d)\n", params->warmup);
    fprintf(stderr, " --cos-decay-steps N Only for Adam optimizer. Number of cosine decay steps (default %d)\n", params->cos_decay_steps);
    fprintf(stderr, " --cos-decay-restart N Only for Adam optimizer. Factor by which the number of cosine decay steps grows after each restart (default %f)\n", params->cos_decay_restart);
    fprintf(stderr, " --cos-decay-min N Only for Adam optimizer. Cosine decay minimum (default %f)\n", params->cos_decay_min);
    fprintf(stderr, " --enable-restart Only for Adam optimizer. Enable restarts of cos-decay %s\n", params->enable_restart ? "(default)" : "");
    fprintf(stderr, " --disable-restart Only for Adam optimizer. Disable restarts of cos-decay %s\n", !params->enable_restart ? "(default)" : "");
    fprintf(stderr, " --opt-past N Number of optimization iterations to track for delta convergence test. Disabled when zero. (default %d)\n", params->opt_past);
    fprintf(stderr, " --opt-delta N Maximum delta for delta convergence test. Disabled when <= zero. (default %f)\n", params->opt_delta);
    fprintf(stderr, " --opt-max-no-improvement N Maximum number of optimization iterations with no improvement. Disabled when <= zero. (default %d)\n", params->opt_max_no_improvement);
    fprintf(stderr, " --adam-epsf N AdamW epsilon for convergence test. Disabled when <= zero. (default %f)\n", params->adam_eps_f);
    fprintf(stderr, " --adam-iter N Maximum number of Adam optimization iterations for each batch (default %d)\n", params->adam_n_iter);
    fprintf(stderr, " --adam-alpha N Adam learning rate alpha (default %f)\n", params->adam_alpha);
    fprintf(stderr, " --adam-min-alpha N Adam minimum learning rate alpha - including warmup phase (default %f)\n", params->adam_min_alpha);
    fprintf(stderr, " --adam-decay N AdamW weight decay. Values greater than zero enable AdamW instead of regular Adam. (default %f)\n", params->adam_decay);
    fprintf(stderr, " --adam-decay-min-ndim N Minimum number of tensor dimensions to apply AdamW weight decay. Weight decay is not applied to tensors with fewer n_dims. (default %d)\n", params->adam_decay_min_ndim);
    fprintf(stderr, " --adam-beta1 N AdamW beta1 in interval [0,1). How much to smooth the first moment of gradients. (default %f)\n", params->adam_beta1);
    fprintf(stderr, " --adam-beta2 N AdamW beta2 in interval [0,1). How much to smooth the second moment of gradients. (default %f)\n", params->adam_beta2);
    fprintf(stderr, " --adam-gclip N AdamW gradient clipping. Disabled when zero. (default %f)\n", params->adam_gclip);
    fprintf(stderr, " --lbfgs-iter N Maximum number of LBFGS optimization iterations for each batch (default %d)\n", params->lbfgs_n_iter);
    fprintf(stderr, " --mem-model N Memory to allocate for model and cache in gigabytes. (default %d)\n", params->mem_model_gb);
    fprintf(stderr, " --mem-compute N Memory to allocate for compute in gigabytes. (default %d)\n", params->mem_compute_gb);
    fprintf(stderr, " --mem-compute0 N Memory to allocate for automatic memory allocator in gigabytes. (default %d)\n", params->mem_compute0_gb);
    fprintf(stderr, "\n");
}

bool train_params_parse(int argc, char ** argv, struct train_params * params) {
    bool invalid_param = false;
    std::string arg;
    struct train_params default_params = get_default_train_params();
    const std::string arg_prefix = "--";

    for (int i = 1; i < argc; i++) {
        arg = argv[i];
        if (arg.compare(0, arg_prefix.size(), arg_prefix) == 0) {
            std::replace(arg.begin(), arg.end(), '_', '-');
        }

        if (arg == "--vocab-model") {
            if (++i >= argc) {
                invalid_param = true;
                break;
            }
            params->fn_vocab_model = argv[i];
        } else if (arg == "--train-data") {
            if (++i >= argc) {
                invalid_param = true;
                break;
            }
            params->fn_train_data = argv[i];
        } else if (arg == "--checkpoint-in") {
            if (++i >= argc) {
                invalid_param = true;
                break;
            }
            params->fn_checkpoint_in = argv[i];
        } else if (arg == "--checkpoint-out") {
            if (++i >= argc) {
                invalid_param = true;
                break;
            }
            params->fn_checkpoint_out = argv[i];
        } else if (arg == "--model-out") {
            if (++i >= argc) {
                invalid_param = true;
                break;
            }
            params->fn_model_out = argv[i];
        } else if (arg == "-s" || arg == "--seed") {
            if (++i >= argc) {
                invalid_param = true;
                break;
            }
            params->seed = std::stoi(argv[i]);
        } else if (arg == "-c" || arg == "--ctx") {
            if (++i >= argc) {
                invalid_param = true;
                break;
            }
            params->n_ctx = std::stoi(argv[i]);
        } else if (arg == "--embd") {
            if (++i >= argc) {
                invalid_param = true;
                break;
            }
            params->n_embd = std::stoi(argv[i]);
        } else if (arg == "--ff") {
            if (++i >= argc) {
                invalid_param = true;
                break;
            }
            params->n_ff = std::stoi(argv[i]);
        } else if (arg == "--head") {
            if (++i >= argc) {
                invalid_param = true;
                break;
            }
            params->n_head = std::stoi(argv[i]);
        } else if (arg == "--layer") {
            if (++i >= argc) {
                invalid_param = true;
                break;
            }
            params->n_layer = std::stoi(argv[i]);
        } else if (arg == "--norm-rms-eps") {
            if (++i >= argc) {
                invalid_param = true;
                break;
            }
            params->f_norm_rms_eps = std::stof(argv[i]);
        } else if (arg == "--rope-freq-base") {
            if (++i >= argc) {
                invalid_param = true;
                break;
            }
            params->rope_freq_base = std::stof(argv[i]);
        } else if (arg == "--rope-freq-scale") {
            if (++i >= argc) {
                invalid_param = true;
                break;
            }
            params->rope_freq_scale = std::stof(argv[i]);
        } else if (arg == "-t" || arg == "--threads") {
            if (++i >= argc) {
                invalid_param = true;
                break;
            }
            params->n_threads = std::stoi(argv[i]);
        } else if (arg == "-b" || arg == "--batch") {
            if (++i >= argc) {
                invalid_param = true;
                break;
            }
            params->n_batch = std::stoi(argv[i]);
        } else if (arg == "-n" || arg == "--examples") {
            if (++i >= argc) {
                invalid_param = true;
                break;
            }
            params->n_examples = std::stoi(argv[i]);
        } else if (arg == "--print-info-interval") {
            if (++i >= argc) {
                invalid_param = true;
                break;
            }
            params->print_info_interval = std::stoi(argv[i]);
        } else if (arg == "--samples-after-nl") {
            params->samples_start_after_nl = true;
        } else if (arg == "--use-lbfgs") {
            params->use_adam = false;
        } else if (arg == "--use-adam") {
            params->use_adam = true;
        } else if (arg == "--no-flash") {
            params->use_flash = false;
        } else if (arg == "--use-flash") {
            params->use_flash = true;
        } else if (arg == "--no-checkpointing") {
            params->use_checkpointing = false;
        } else if (arg == "--use-checkpointing") {
            params->use_checkpointing = true;
        } else if (arg == "--no-alloc") {
            params->use_alloc = false;
        } else if (arg == "--use-alloc") {
            params->use_alloc = true;
        } else if (arg == "--warmup") {
            if (++i >= argc) {
                invalid_param = true;
                break;
            }
            params->warmup = std::stoi(argv[i]);
        } else if (arg == "--cos-decay-steps") {
            if (++i >= argc) {
                invalid_param = true;
                break;
            }
            params->cos_decay_steps = std::stoi(argv[i]);
        } else if (arg == "--cos-decay-restart") {
            if (++i >= argc) {
                invalid_param = true;
                break;
            }
            params->cos_decay_restart = std::stof(argv[i]);
        } else if (arg == "--cos-decay-min") {
            if (++i >= argc) {
                invalid_param = true;
                break;
            }
            params->cos_decay_min = std::stof(argv[i]);
        } else if (arg == "--enable-restart") {
            params->enable_restart = true;
        } else if (arg == "--disable-restart") {
            params->enable_restart = false;
        } else if (arg == "--opt-past") {
            if (++i >= argc) {
                invalid_param = true;
                break;
            }
            params->opt_past = std::stoi(argv[i]);
        } else if (arg == "--opt-delta") {
            if (++i >= argc) {
                invalid_param = true;
                break;
            }
            params->opt_delta = std::stof(argv[i]);
        } else if (arg == "--opt-max-no-improvement") {
            if (++i >= argc) {
                invalid_param = true;
                break;
            }
            params->opt_max_no_improvement = std::stoi(argv[i]);
        } else if (arg == "--adam-epsf") {
            if (++i >= argc) {
                invalid_param = true;
                break;
            }
            params->adam_eps_f = std::stof(argv[i]);
        } else if (arg == "--adam-iter") {
            if (++i >= argc) {
                invalid_param = true;
                break;
            }
            params->adam_n_iter = std::stoi(argv[i]);
        } else if (arg == "--adam-alpha") {
            if (++i >= argc) {
                invalid_param = true;
                break;
            }
            params->adam_alpha = std::stof(argv[i]);
        } else if (arg == "--adam-min-alpha") {
            if (++i >= argc) {
                invalid_param = true;
                break;
            }
            params->adam_min_alpha = std::stof(argv[i]);
        } else if (arg == "--adam-decay") {
            if (++i >= argc) {
                invalid_param = true;
                break;
            }
            params->adam_decay = std::stof(argv[i]);
        } else if (arg == "--adam-decay-min-ndim") {
            if (++i >= argc) {
                invalid_param = true;
                break;
            }
            params->adam_decay_min_ndim = std::stoi(argv[i]);
        } else if (arg == "--adam-beta1") {
            if (++i >= argc) {
                invalid_param = true;
                break;
            }
            params->adam_beta1 = std::stof(argv[i]);
        } else if (arg == "--adam-beta2") {
            if (++i >= argc) {
                invalid_param = true;
                break;
            }
            params->adam_beta2 = std::stof(argv[i]);
        } else if (arg == "--adam-gclip") {
            if (++i >= argc) {
                invalid_param = true;
                break;
            }
            params->adam_gclip = std::stof(argv[i]);
        } else if (arg == "--lbfgs-iter") {
            if (++i >= argc) {
                invalid_param = true;
                break;
            }
            params->lbfgs_n_iter = std::stoi(argv[i]);
        } else if (arg == "--mem-model") {
            if (++i >= argc) {
                invalid_param = true;
                break;
            }
            params->mem_model_gb = std::stoi(argv[i]);
        } else if (arg == "--mem-compute") {
            if (++i >= argc) {
                invalid_param = true;
                break;
            }
            params->mem_compute_gb = std::stoi(argv[i]);
        } else if (arg == "--mem-compute0") {
            if (++i >= argc) {
                invalid_param = true;
                break;
            }
            params->mem_compute0_gb = std::stoi(argv[i]);
        } else if (arg == "-h" || arg == "--help") {
            train_print_usage(argc, argv, &default_params);
            exit(0);
        } else {
            fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
            train_print_usage(argc, argv, &default_params);
            exit(1);
        }
    }
    if (invalid_param) {
        fprintf(stderr, "error: invalid parameter for argument: %s\n", arg.c_str());
        train_print_usage(argc, argv, &default_params);
        exit(1);
    }

    return true;
}
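
// Sketch (hypothetical helper, not used by this program): the option-name normalization done
// at the top of the loop in train_params_parse. Only arguments starting with "--" get '_'
// replaced by '-', so "--adam_alpha" is accepted as "--adam-alpha", while short options are
// left as they are. Option values (file names, numbers) are never normalized, because the
// parser reads them straight from argv after advancing i.
#include <algorithm>
#include <string>

static inline std::string normalize_option_name_sketch(std::string arg) {
    const std::string arg_prefix = "--";
    if (arg.compare(0, arg_prefix.size(), arg_prefix) == 0) {
        std::replace(arg.begin(), arg.end(), '_', '-');
    }
    return arg;
}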

struct opt_callback_data {
    struct train_params * params;
    struct ggml_opt_context * opt;
    struct llama_context * lctx;
    llama_token * tokens_data;
    size_t tokens_size;
    int * samples_data;
    size_t samples_size;
    int shuffle_countdown;
    struct ggml_tensor * tokens_input;
    struct ggml_tensor * target_logits;
    struct ggml_tensor * target_probs;
};

void opt_callback(void * vdata, float * sched) {
    struct opt_callback_data * data = (struct opt_callback_data *) vdata;
    struct train_params * params = data->params;
    struct ggml_opt_context * opt = data->opt;
    int n_batch = params->n_batch;

    *sched = (opt->iter < params->warmup)
                ? (float) opt->iter / (float) params->warmup
                : cosine_decay_restart(
                    params->cos_decay_steps,
                    params->cos_decay_min,
                    opt->iter - params->warmup,
                    params->cos_decay_restart,
                    params->enable_restart);
    float min_sched = params->adam_min_alpha / params->adam_alpha;
    *sched = min_sched + *sched * (1.0f - min_sched);

    int impr_plot = std::isnan(opt->loss_after) ? 0 : -(int)(1 + (opt->loss_before - opt->loss_after) * 10.0f + 0.5f);
    printf("%s: iter=%*d, sched=%f loss0=%f loss=%f | improvement: %*d>\n", __func__, 6, opt->iter, *sched, opt->loss_before, opt->loss_after, impr_plot, (int)0);

    if (data->shuffle_countdown < n_batch) {
        printf("%s: reshuffle samples\n", __func__);
        shuffle_ints(data->samples_data, data->samples_data + data->samples_size);
        for (int i = 0; i < (int) data->samples_size; ++i) {
            GGML_ASSERT(data->samples_data[i]+params->n_ctx-1 < (int) data->tokens_size);
        }
        data->shuffle_countdown = data->samples_size;
    }

    get_example_targets_batch(
        data->lctx,
        data->samples_data,
        data->samples_size,
        data->tokens_data,
        data->tokens_size,
        opt->iter,
        data->tokens_input,
        data->target_logits,
        data->target_probs);

    data->shuffle_countdown -= n_batch;
}
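
// Sketch (hypothetical helper, not used by this program): how opt_callback turns the
// iteration count into the Adam schedule factor. During the first `warmup` iterations sched
// ramps linearly from 0 to 1; afterwards it takes the cosine decay value passed in as
// `decayed` (in opt_callback this comes from cosine_decay_restart). The result is then mapped
// from [0, 1] onto [adam_min_alpha / adam_alpha, 1], so, assuming decayed >= 0, the effective
// learning rate adam_alpha * sched never drops below adam_min_alpha, warmup included.
static inline float adam_sched_sketch(int iter, int warmup, float decayed, float adam_alpha, float adam_min_alpha) {
    const float sched = (iter < warmup) ? (float) iter / (float) warmup : decayed;
    const float min_sched = adam_min_alpha / adam_alpha;
    return min_sched + sched * (1.0f - min_sched);
}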

int main(int argc, char ** argv) {
    struct train_params params = get_default_train_params();

    if (!train_params_parse(argc, argv, &params)) {
        return 1;
    }
|
|
|
|
if (params.seed == LLAMA_DEFAULT_SEED) {
|
|
params.seed = time(NULL);
|
|
}
|
|
printf("%s: seed: %u\n", __func__, params.seed);
|
|
srand(params.seed);
|
|
|
|
struct llama_context_params llama_params = llama_context_default_params();
|
|
llama_params.vocab_only = true;
|
|
|
|
struct llama_model * lmodel = llama_load_model_from_file(params.fn_vocab_model, llama_params);
|
|
struct llama_context * lctx = llama_new_context_with_model(lmodel, llama_params);
|
|
|
|
printf("%s: tokenize training data\n", __func__);
|
|
std::vector<llama_token> train_tokens;
|
|
if (tokenize_file(lctx, params.fn_train_data, train_tokens) < 0) {
|
|
fprintf(stderr, "%s: failed to tokenize file '%s'\n", __func__, params.fn_train_data);
|
|
}
|
|
printf("%s: number of training tokens: %d\n", __func__, (int) train_tokens.size());
|
|
|
|
struct my_llama_model model;
|
|
model.hparams.n_vocab = llama_n_vocab(lctx);
|
|
model.hparams.n_ctx = params.n_ctx;
|
|
model.hparams.n_embd = params.n_embd;
|
|
model.hparams.n_head = params.n_head;
|
|
model.hparams.n_layer = params.n_layer;
|
|
model.hparams.n_ff = params.n_ff;
|
|
// llama.cpp requires n_rot to be exactly n_embd / n_head
|
|
model.hparams.n_rot = model.hparams.n_embd / model.hparams.n_head;
|
|
model.hparams.f_norm_rms_eps = params.f_norm_rms_eps;
|
|
model.hparams.rope_freq_base = params.rope_freq_base;
|
|
model.hparams.rope_freq_scale = params.rope_freq_scale;
|
|
|
|
print_params(&model.hparams);
|
|
|
|
std::vector<size_t> token_noccurs;
|
|
std::vector<bool> token_notavail;
|
|
token_noccurs.resize(model.hparams.n_vocab, 0);
|
|
token_notavail.resize(model.hparams.n_vocab, true);
|
|
for (int i = 0; i < (int) train_tokens.size(); ++i) {
|
|
++token_noccurs[train_tokens[i]];
|
|
token_notavail[train_tokens[i]] = false;
|
|
}
|
|
|
|
std::vector<float> token_freq;
|
|
token_freq.resize(model.hparams.n_vocab, 0);
|
|
int n_unique_tokens = 0;
|
|
for (int i = 0; i < (int) token_noccurs.size(); ++i) {
|
|
token_freq[i] = (float) token_noccurs[i] / (float) train_tokens.size();
|
|
n_unique_tokens += (token_noccurs[i] > 0) ? 1 : 0;
|
|
}
|
|
printf("%s: number of unique tokens: %d\n", __func__, n_unique_tokens);
|
|
|
|
struct ggml_init_params lcparams;
|
|
lcparams.mem_size = 1024ll*1024ll*1024ll*((size_t) params.mem_model_gb);
|
|
lcparams.mem_buffer = NULL;
|
|
lcparams.no_alloc = false;
|
|
|
|
model.ctx = ggml_init(lcparams);
|
|
|
|
int n_tokens = model.hparams.n_ctx;
|
|
int n_vocab = model.hparams.n_vocab;
|
|
int n_batch = params.n_batch;
|
|
|
|
struct ggml_opt_context * opt = (struct ggml_opt_context *) alloca(sizeof(struct ggml_opt_context));
|
|
memset(opt, 0, sizeof(struct ggml_opt_context));

    struct ggml_opt_params opt_params_adam  = ggml_opt_default_params(GGML_OPT_ADAM);
    struct ggml_opt_params opt_params_lbfgs = ggml_opt_default_params(GGML_OPT_LBFGS);
    opt_params_adam.print_forward_graph  = false;
    opt_params_adam.print_backward_graph = false;
    opt_params_adam.n_threads            = params.n_threads;
    opt_params_adam.past                 = params.opt_past;
    opt_params_adam.delta                = params.opt_delta;
    opt_params_adam.max_no_improvement   = params.opt_max_no_improvement;
    opt_params_adam.adam.n_iter          = params.adam_n_iter;
    opt_params_adam.adam.sched           = 1.0f;
    opt_params_adam.adam.alpha           = params.adam_alpha;
    opt_params_adam.adam.decay           = params.adam_decay;
    opt_params_adam.adam.decay_min_ndim  = params.adam_decay_min_ndim;
    opt_params_adam.adam.beta1           = params.adam_beta1;
    opt_params_adam.adam.beta2           = params.adam_beta2;
    opt_params_adam.adam.gclip           = params.adam_gclip;
    opt_params_adam.adam.eps_f           = params.adam_eps_f;
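    // Rough sketch of how these fields interact in one AdamW step. This is written as
    // torch-style AdamW math and is an assumption for illustration; ggml_opt_adam in
    // ggml.c has the exact implementation. Per parameter x with gradient g at step t:
    //   if gclip > 0 and ||G|| > gclip:  g *= gclip / ||G||   (G = all gradients, global norm)
    //   m  = beta1*m + (1 - beta1)*g
    //   v  = beta2*v + (1 - beta2)*g*g
    //   mh = m / (1 - beta1^t),   vh = v / (1 - beta2^t)
    //   x  = x*(1 - alpha*sched*decay) - alpha*sched*mh / (sqrt(vh) + eps)
    // Weight decay is scaled by alpha*sched, as in torch.optim.AdamW. It is presumably
    // applied only to tensors with at least decay_min_ndim dimensions. eps_f is not part of
    // the update: it is the threshold of the convergence test.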

    opt_params_lbfgs.print_forward_graph  = false;
    opt_params_lbfgs.print_backward_graph = false;
    opt_params_lbfgs.n_threads            = params.n_threads;
    opt_params_lbfgs.past                 = params.opt_past;
    opt_params_lbfgs.delta                = params.opt_delta;
    opt_params_lbfgs.max_no_improvement   = params.opt_max_no_improvement;
    opt_params_lbfgs.lbfgs.n_iter         = params.lbfgs_n_iter;

    opt->ctx    = model.ctx;
    opt->params = params.use_adam ? opt_params_adam : opt_params_lbfgs;

    printf("%s: init model\n", __func__);
    bool existed = load_checkpoint_file(params.fn_checkpoint_in, &model, opt);
    if (!existed) {
        init_model(&model);
    }
    set_param_model(&model);

    opt->params = params.use_adam ? opt_params_adam : opt_params_lbfgs;

    opt->iter = model.train_its;
    printf("%s: opt iter %d\n", __func__, opt->iter);

    bool from_scratch = !existed;
    if (from_scratch) {
        randomize_model(&model, params.seed, 0.0f, 1.0f, -1.0f, +1.0f);
    }

    printf("used_mem model: %zu bytes\n", ggml_used_mem(model.ctx));
    // ggml_print_tensor_objects(model.ctx);

    // TODO: use std::vector<uint8_t> instead of "new"
    size_t    compute_size = 1024ll*1024ll*1024ll*((size_t) params.mem_compute_gb);
    uint8_t * compute_addr = new uint8_t[compute_size];

    size_t    size_buf_0    = 1024ll*1024ll*1024ll*((size_t) params.mem_compute0_gb);
    uint8_t * compute_buf_0 = new uint8_t[size_buf_0];

    ggml_allocr * alloc = NULL;
    if (params.use_alloc) {
        static const size_t tensor_alignment = 32;
        alloc = ggml_allocr_new(compute_buf_0, size_buf_0, tensor_alignment);
    }
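    // Illustration only: an allocator created with tensor_alignment = 32 is expected to
    // place tensors at offsets rounded up to a multiple of 32, using the generic rule
    //   aligned = (offset + 31) & ~31   (offset 100 -> 128, offset 128 -> 128)
    // This is the usual power-of-two round-up. The exact bookkeeping lives in ggml-alloc.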

    GGML_ASSERT(n_tokens < (int) train_tokens.size());
    std::vector<int> train_samples;
    train_samples.push_back(0);
    for (int i = 1; i < (int) train_tokens.size() - n_tokens; ++i) {
        if (!params.samples_start_after_nl || (train_tokens[i-1] == llama_token_nl(lctx))) {
            train_samples.push_back(i);
        }
    }
    shuffle_ints(train_samples.data(), train_samples.data() + train_samples.size());
    for (int i = 0; i < (int) train_samples.size(); ++i) {
        GGML_ASSERT(train_samples[i]+n_tokens-1 < (int) train_tokens.size());
    }
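    // Worked example of the sample selection above (hypothetical tokens a..e, NL = newline).
    // Take train_tokens = { a, NL, b, c, NL, d, e } (size 7), n_tokens = 2, and
    // samples_start_after_nl enabled. The loop visits i = 1..4 (i < 7 - 2):
    //   i = 2 follows NL and is kept; i = 1, 3 and 4 do not and are skipped.
    // Offset 0 is always pushed first, so train_samples = { 0, 2 } before shuffling.
    // The bound i < size - n_tokens leaves room for n_tokens tokens plus one more after
    // them (presumably the next-token target of the last position).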

    printf("%s: begin training\n", __func__);

    struct opt_callback_data opt_cb_data;
    opt_cb_data.params            = &params;
    opt_cb_data.opt               = opt;
    opt_cb_data.lctx              = lctx;
    opt_cb_data.tokens_data       = train_tokens.data();
    opt_cb_data.tokens_size       = train_tokens.size();
    opt_cb_data.samples_data      = train_samples.data();
    opt_cb_data.samples_size      = train_samples.size();
    opt_cb_data.shuffle_countdown = train_samples.size();
    opt_cb_data.tokens_input      = NULL;
    opt_cb_data.target_logits     = NULL;
    opt_cb_data.target_probs      = NULL;

    int64_t t0 = ggml_time_ms();

    for (int ex = 0; ex < params.n_examples; ++ex) {
        if (ex*n_batch >= (int) train_samples.size()) {
            shuffle_ints(train_samples.data(), train_samples.data() + train_samples.size());
            for (int i = 0; i < (int) train_samples.size(); ++i) {
                GGML_ASSERT(train_samples[i]+n_tokens-1 < (int) train_tokens.size());
            }
        }

        struct ggml_init_params cparams = {
            compute_size, // mem_size
            compute_addr, // mem_buffer
            false,        // no_alloc
        };
        struct ggml_context * ctx0 = ggml_init(cparams);

        ggml_set_no_alloc(ctx0, false);

        // don't use alloc for input tensors, so we can safely fill them with data
        //struct ggml_tensor * after_opt_best_samples = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_tokens, n_batch);
        //struct ggml_tensor * after_opt_probs = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_vocab, n_tokens, n_batch);
        struct ggml_tensor * tokens_input  = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_tokens, n_batch);
        struct ggml_tensor * target_logits = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_vocab, n_tokens, n_batch);
        struct ggml_tensor * target_probs  = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_vocab, n_tokens, n_batch);
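        // Note on target_probs: ggml_cross_entropy_loss expects each row of its second
        // argument to sum to 1. Here a row is the n_vocab values for one (position, batch)
        // pair. For plain next-token training such a row would be one-hot. For example,
        // with n_vocab = 4 and target token id 2 the row is { 0, 0, 1, 0 }, and that row
        // contributes
        //   -sum_j target_probs[j] * log(softmax(logits)[j]) = -log(softmax(logits)[2])
        // to the loss, which is averaged over all n_tokens*n_batch rows.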

        ggml_set_no_alloc(ctx0, (alloc != NULL));

        if (alloc) {
            ggml_allocr_reset(alloc);
        }

        opt_cb_data.tokens_input  = tokens_input;
        opt_cb_data.target_logits = target_logits;
        opt_cb_data.target_probs  = target_probs;

        int n_past = 0;

        struct ggml_cgraph * gf     = ggml_new_graph(ctx0);
        struct ggml_cgraph * gb     = ggml_new_graph(ctx0);
        struct ggml_cgraph * gb_tmp = params.use_checkpointing
            ? ggml_new_graph(ctx0)
            : NULL;
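        // gb_tmp is only created when gradient checkpointing is enabled. Checkpointing
        // chooses the number of checkpoints a for n layers so that a*u + b*v is minimal
        // subject to a*b = n. Here u is the memory of one checkpoint and v is the memory of
        // one layer kept between checkpoints. Setting the derivative of a*u + v*n/a to zero
        // gives a = sqrt(n*v/u).
        // Worked example (made-up sizes): n = 16 and v = 4u give a = sqrt(64) = 8 checkpoints
        // with b = 2 layers each, costing 8u + 2*4u = 16u. Plain a = sqrt(n) = 4 would cost
        // 4u + 4*4u = 20u instead.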

        GGML_ASSERT(n_past == 0);

        struct ggml_tensor * loss   = NULL;
        struct ggml_tensor * logits = NULL;

        loss = llama_build_train_graphs(
            &model, alloc, ctx0,
            gf, gb, gb_tmp,
            &logits, tokens_input, target_probs,
            n_tokens, n_batch,
            params.use_flash,
            params.use_checkpointing
        );

        size_t used_mem_before_opt = ggml_used_mem(ctx0);

        opt->params.adam.sched = (opt->iter < params.warmup)
            ? (float) opt->iter / (float) params.warmup
            : cosine_decay_restart(
                params.cos_decay_steps,
                params.cos_decay_min,
                opt->iter - params.warmup,
                params.cos_decay_restart,
                params.enable_restart);

        float min_sched = params.adam_min_alpha / params.adam_alpha;
        opt->params.adam.sched = min_sched + opt->params.adam.sched * (1.0f - min_sched);
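        // Worked example of the schedule (illustrative values, not defaults):
        // adam_alpha = 1e-3 and adam_min_alpha = 1e-4 give min_sched = 0.1.
        // With warmup = 100 at opt->iter = 50, the warmup ramp yields sched = 0.5. That value
        // is remapped to 0.1 + 0.5*(1 - 0.1) = 0.55, an effective learning rate of
        // alpha*sched = 5.5e-4. After warmup, cosine_decay_restart() supplies the raw value
        // instead. As long as that raw value lies in [0, 1], the remapping keeps the rate
        // within [adam_min_alpha, adam_alpha].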

        printf("%s: opt->params.adam.sched %.5f\n", __func__, opt->params.adam.sched);

        ggml_opt_resume_g(ctx0, opt, loss, gf, gb, &opt_callback, (void *) &opt_cb_data);

        size_t used_mem_after_opt = ggml_used_mem(ctx0);

        int n_iter = params.use_adam ? params.adam_n_iter : params.lbfgs_n_iter;
        model.train_its      = opt->iter;
        model.train_samples += n_batch * n_iter;
        model.train_tokens  += n_batch * n_tokens * n_iter;
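        // These counters assume every optimizer iteration consumes one full batch. For
        // example, n_batch = 8, n_tokens = 64 and n_iter = 16 add 8*16 = 128 samples and
        // 8*64*16 = 8192 tokens per example. They use the configured n_iter rather than the
        // iterations actually run, so they may overcount if the optimizer stops early
        // (e.g. on a convergence test).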

        if (params.print_info_interval > 0 && ex % params.print_info_interval == 0) {
            printf("Example %d, opt iter %d\n", ex, opt->iter);
            printf("error_before_opt: %.6f\n", opt->loss_before);
            printf("error_after_opt: %.6f\n", opt->loss_after);
            printf("used_mem_before_opt: %zu bytes\n", used_mem_before_opt);
            printf("used_mem_after_opt: %zu bytes\n", used_mem_after_opt);
        }

        ggml_free(ctx0);
    }

    int64_t t1 = ggml_time_ms();
    int64_t d  = t1-t0;
    double  dd = (double) d * 1e-3;
    printf("%s: total training time=%f seconds\n", __func__, dd);

    if (params.n_examples > 0) {
        save_checkpoint_file(params.fn_checkpoint_out, params.fn_vocab_model, &model, opt);
    }

    if (strlen(params.fn_model_out) > 0) {
        save_llama_model_file(params.fn_model_out, params.fn_vocab_model, &model);
    }

    if (alloc) {
        ggml_allocr_free(alloc);
    }

    delete[] compute_addr;
    delete[] compute_buf_0;
    ggml_free(model.ctx);
    llama_free(lctx);
    llama_free_model(lmodel);
    return 0;
}