Merge branch 'master' into gg/flash-attn

Commit: c3cdfffa88
@@ -846,7 +846,7 @@ install(FILES ${CMAKE_CURRENT_BINARY_DIR}/LlamaConfig.cmake
               ${CMAKE_CURRENT_BINARY_DIR}/LlamaConfigVersion.cmake
         DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/Llama)

-set(GGML_PUBLIC_HEADERS "ggml.h"
+set(GGML_PUBLIC_HEADERS "ggml.h" "ggml-alloc.h" "ggml-backend.h"
     "${GGML_HEADERS_CUDA}" "${GGML_HEADERS_OPENCL}"
     "${GGML_HEADERS_METAL}" "${GGML_HEADERS_MPI}" "${GGML_HEADERS_EXTRA}")

@@ -189,6 +189,8 @@ class Model:
             return StableLMModel
         if model_architecture == "QWenLMHeadModel":
             return QwenModel
+        if model_architecture == "Qwen2ForCausalLM":
+            return Model
         if model_architecture == "MixtralForCausalLM":
             return MixtralModel
         if model_architecture == "GPT2LMHeadModel":
@@ -197,6 +199,8 @@ class Model:
             return Phi2Model
         if model_architecture == "PlamoForCausalLM":
             return PlamoModel
+        if model_architecture == "CodeShellForCausalLM":
+            return CodeShellModel
         return Model

     def _is_model_safetensors(self) -> bool:
@@ -234,6 +238,8 @@ class Model:
             return gguf.MODEL_ARCH.STABLELM
         if arch == "QWenLMHeadModel":
             return gguf.MODEL_ARCH.QWEN
+        if arch == "Qwen2ForCausalLM":
+            return gguf.MODEL_ARCH.QWEN2
         if arch == "MixtralForCausalLM":
             return gguf.MODEL_ARCH.LLAMA
         if arch == "GPT2LMHeadModel":
@@ -242,6 +248,8 @@ class Model:
             return gguf.MODEL_ARCH.PHI2
         if arch == "PlamoForCausalLM":
             return gguf.MODEL_ARCH.PLAMO
+        if arch == "CodeShellForCausalLM":
+            return gguf.MODEL_ARCH.CODESHELL

         raise NotImplementedError(f'Architecture "{arch}" not supported!')

@@ -1176,6 +1184,70 @@ class PlamoModel(Model):
             self.gguf_writer.add_tensor(new_name, data)


+class CodeShellModel(Model):
+    def set_gguf_parameters(self):
+        block_count = self.hparams["n_layer"]
+
+        self.gguf_writer.add_name("CodeShell")
+        self.gguf_writer.add_context_length(self.hparams["n_positions"])
+        self.gguf_writer.add_embedding_length(self.hparams["n_embd"])
+        self.gguf_writer.add_feed_forward_length(4 * self.hparams["n_embd"])
+        self.gguf_writer.add_block_count(block_count)
+        self.gguf_writer.add_head_count(self.hparams["n_head"])
+        self.gguf_writer.add_head_count_kv(self.hparams["num_query_groups"])
+        self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"])
+        self.gguf_writer.add_file_type(self.ftype)
+        self.gguf_writer.add_rope_freq_base(10000.0)
+        self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
+        self.gguf_writer.add_rope_scaling_factor(1.0)
+
+    def write_tensors(self):
+        block_count = self.hparams.get("n_layers", self.hparams.get("num_hidden_layers", self.hparams.get("n_layer")))
+        tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
+        tensors = dict(self.get_tensors())
+        has_lm_head = "lm_head.weight" in tensors.keys() or "output.weight" in tensors.keys()
+        for name, data_torch in tensors.items():
+            # we don't need these
+            if name.endswith((".attn.rotary_emb.inv_freq")):
+                continue
+
+            old_dtype = data_torch.dtype
+
+            # convert any unsupported data types to float32
+            if data_torch.dtype not in (torch.float16, torch.float32):
+                data_torch = data_torch.to(torch.float32)
+
+            data = data_torch.squeeze().numpy()
+
+            # map tensor names
+            new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
+            if new_name is None:
+                print(f"Can not map tensor {name!r}")
+                sys.exit()
+
+            n_dims = len(data.shape)
+            data_dtype = data.dtype
+
+            # if f32 desired, convert any float16 to float32
+            if self.ftype == 0 and data_dtype == np.float16:
+                data = data.astype(np.float32)
+
+            # TODO: Why cant we use these float16 as-is? There should be not reason to store float16 as float32
+            if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
+                data = data.astype(np.float32)
+
+            # if f16 desired, convert any float32 2-dim weight tensors to float16
+            if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
+                data = data.astype(np.float16)
+
+            print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
+
+            self.gguf_writer.add_tensor(new_name, data)
+
+            if not has_lm_head and name == "transformer.wte.weight":
+                self.gguf_writer.add_tensor("output.weight", data)
+                print(name, f"=> output.weight, shape = {data.shape}, {old_dtype} --> {data.dtype}")
+
+
 ###### CONVERSION LOGIC ######

@@ -348,7 +348,7 @@ class Params:
             f_rope_freq_base = 1e6

         return Params(
-            n_vocab=config.get("vocab_size", model["tok_embeddings.weight"].shape[0]),
+            n_vocab=model["tok_embeddings.weight"].shape[0],
            n_embd=config["dim"],
            n_layer=config["n_layers"],
            n_ctx=n_ctx,
examples/imatrix/README.md (new file, 32 lines)
@@ -0,0 +1,32 @@
+# llama.cpp/examples/imatrix
+
+Compute an importance matrix for a model and given text dataset. Can be used during quantization to enchance the quality of the quantum models.
+More information is available here: https://github.com/ggerganov/llama.cpp/pull/4861
+
+## Usage
+
+```
+./imatrix -m <some_fp_model> -f <some_training_data> [-o <output_file>] [--verbosity <verbosity_level>]
+          [-ofreq num_chunks] [-ow <0 or 1>] [other common params]
+```
+
+Here `-m` with a model name and `-f` with a file containing training data (such as e.g. `wiki.train.raw`) are mandatory.
+The parameters in square brackets are optional and have the following meaning:
+* `-o` (or `--output-file`) specifies the name of the file where the computed data will be stored. If missing `imatrix.dat` is used.
+* `--verbosity` specifies the verbosity level. If set to `0`, no output other than the perplexity of the processed chunks will be generated. If set to `1`, each time the results are saved a message is written to `stderr`. If `>=2`, a message is output each time data is collected for any tensor. Default verbosity level is `1`.
+* `-ofreq` (or `--output-frequency`) specifies how often the so far computed result is saved to disk. Default is 10 (i.e., every 10 chunks)
+* `-ow` (or `--output-weight`) specifies if data will be collected for the `output.weight` tensor. My experience is that it is better to not utilize the importance matrix when quantizing `output.weight`, so this is set to `false` by default.
+
+For faster computation, make sure to use GPU offloading via the `-ngl` argument
+
+## Example
+
+```bash
+LLAMA_CUBLAS=1 make -j
+
+# generate importance matrix (imatrix.dat)
+./imatrix -m ggml-model-f16.gguf -f train-data.txt -ngl 99
+
+# use the imatrix to perform a Q4_K_M quantization
+./quantize --imatrix imatrix.dat ggml-model-f16.gguf ./ggml-model-q4_k_m.gguf q4_k_m
+```
@@ -80,7 +80,7 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void *
         // for simplicity, always copy src0 to host, because it is small
         // take into account that src0 is not contiguous!
         GGML_ASSERT(src0->ne[1] == src1->ne[1]);
-        GGML_ASSERT(n_as*ggml_nrows(src0));
+        GGML_ASSERT(n_as*ggml_nrows(src0)*sizeof(int) == GGML_PAD(ggml_nbytes(src0), n_as*sizeof(int)));
        m_ids.resize(ggml_nbytes(src0)/sizeof(int));
        ggml_backend_tensor_get(src0, m_ids.data(), 0, ggml_nbytes(src0));

@@ -8,6 +8,7 @@
 #include <sstream>
 #include <thread>
 #include <mutex>
+#include <atomic>
 #include <vector>
 #include <array>
 #include <fstream>
@@ -324,6 +325,13 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
     double nll  = 0.0;
     double nll2 = 0.0;

+    const int num_batches = (n_ctx + n_batch - 1) / n_batch;
+
+    std::vector<float> logits;
+    if (num_batches > 1) {
+        logits.reserve((size_t)n_ctx * n_vocab);
+    }
+
     fprintf(stderr, "%s: calculating perplexity over %d chunks, batch_size=%d\n", __func__, n_chunk, n_batch);

     std::vector<std::thread> workers(std::thread::hardware_concurrency() - 1);
@@ -332,10 +340,6 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
         const int start = i * n_ctx;
         const int end   = start + n_ctx;

-        const int num_batches = (n_ctx + n_batch - 1) / n_batch;
-
-        std::vector<float> logits;
-
         const auto t_start = std::chrono::high_resolution_clock::now();

         // clear the KV cache
@@ -361,8 +365,10 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
             // restore the original token in case it was set to BOS
             tokens[batch_start] = token_org;

-            const auto * batch_logits = llama_get_logits(ctx);
-            logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab);
+            if (num_batches > 1) {
+                const auto * batch_logits = llama_get_logits(ctx);
+                logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab);
+            }
         }

         const auto t_end = std::chrono::high_resolution_clock::now();
@@ -391,7 +397,8 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
         // last 256 tokens. Then, we split the input up into context window size chunks to
         // process the entire prompt.
         const int first = n_ctx/2;
-        process_logits(n_vocab, logits.data() + first*n_vocab, tokens.data() + start + first, n_ctx - 1 - first,
+        const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits(ctx);
+        process_logits(n_vocab, all_logits + first*n_vocab, tokens.data() + start + first, n_ctx - 1 - first,
                        workers, nll, nll2, logit_history.data() + start + first, prob_history.data() + start + first);
         count += n_ctx - first - 1;

@@ -405,6 +412,8 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
             printf("%8d %.4lf %4lf %4lf\n", i*n_ctx, std::exp(nll / count), av, av2);
         }
         fflush(stdout);
+
+        logits.clear();
     }
     printf("\n");

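For reference, the value printed by `std::exp(nll / count)` in the hunk above is the running perplexity; a sketch of the formula it implements, with N the number of scored tokens and nll the accumulated negative log-likelihood:

$$\mathrm{PPL} = \exp\!\left(\frac{1}{N}\sum_{i=1}^{N} -\log p(t_i \mid t_{<i})\right)$$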
@@ -422,26 +431,73 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
     return {tokens, ppl, logit_history, prob_history};
 }

-static std::vector<float> evaluate_tokens(llama_context * ctx, std::vector<int> & tokens,
-                                          int n_past, int n_batch, int n_vocab) {
-    std::vector<float> result;
-    result.reserve(tokens.size() * n_vocab);
-    size_t n_chunk = (tokens.size() + n_batch - 1)/n_batch;
-    for (size_t i_chunk = 0; i_chunk < n_chunk; ++i_chunk) {
-        size_t n_tokens = tokens.size() - i_chunk * n_batch;
-        n_tokens = std::min(n_tokens, size_t(n_batch));
-        llama_kv_cache_seq_rm(ctx, 0, n_past, -1);
-        if (llama_decode(ctx, llama_batch_get_one(tokens.data() + i_chunk * n_batch, n_tokens, n_past, 0))) {
-            fprintf(stderr, "%s : failed to eval\n", __func__);
-            return {};
+static bool decode_helper(llama_context * ctx, llama_batch & batch, std::vector<float> & batch_logits, int32_t n_batch, int32_t n_vocab) {
+    for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) {
+        const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
+
+        llama_batch batch_view = {
+            n_tokens,
+            batch.token    + i,
+            nullptr,
+            batch.pos      + i,
+            batch.n_seq_id + i,
+            batch.seq_id   + i,
+            batch.logits   + i,
+            0, 0, 0, // unused
+        };
+
+        const int ret = llama_decode(ctx, batch_view);
+        if (ret != 0) {
+            LOG_TEE("failed to decode the batch, n_batch = %d, ret = %d\n", n_batch, ret);
+            return false;
         }
-        const auto logits = llama_get_logits(ctx);
-        result.insert(result.end(), logits, logits + n_tokens * n_vocab);
-        n_past += n_tokens;
+
+        memcpy(batch_logits.data() + i*n_vocab, llama_get_logits(ctx), n_tokens*n_vocab*sizeof(float));
     }
-    return result;
+
+    return true;
+}
+
+static void compute_logprobs(const float * batch_logits, int n_vocab, std::vector<std::thread>& workers,
+        const std::vector<std::pair<size_t, llama_token>>& eval_pairs, std::vector<float>& eval_results) {
+    constexpr int k_token_chunk = 4;
+    if (eval_results.size() != eval_pairs.size()) {
+        eval_results.resize(eval_pairs.size());
+    }
+    if (eval_pairs.empty()) return;
+
+    size_t max_threads = std::min((eval_pairs.size() + k_token_chunk - 1)/k_token_chunk, workers.size());
+
+    std::atomic<int> counter(0);
+    auto compute = [&counter, &eval_pairs, &eval_results, batch_logits, n_vocab] () {
+        float local_logprobs[k_token_chunk];
+        while (true) {
+            size_t first = counter.fetch_add(k_token_chunk, std::memory_order_relaxed);
+            if (first >= eval_results.size()) break;
+            size_t last = std::min(first + k_token_chunk, eval_results.size());
+            for (size_t i = first; i < last; ++i) {
+                auto logits = batch_logits + eval_pairs[i].first * n_vocab;
+                float max_logit = logits[0];
+                for (int j = 1; j < n_vocab; ++j) {
+                    max_logit = std::max(max_logit, logits[j]);
+                }
+                float sum_p = 0.f;
+                for (int j = 0; j < n_vocab; ++j) {
+                    sum_p += expf(logits[j] - max_logit);
+                }
+                local_logprobs[i - first] = logits[eval_pairs[i].second] - max_logit - std::log(sum_p);
+            }
+            std::memcpy(eval_results.data() + first, local_logprobs, (last - first)*sizeof(float));
+        }
+    };
+
+    for (size_t it = 0; it < max_threads; ++it) {
+        workers[it] = std::thread(compute);
+    }
+    for (size_t it = 0; it < max_threads; ++it) {
+        workers[it].join();
+    }
 }

 static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
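The per-pair values that `compute_logprobs` stores in `eval_results` follow the numerically stable log-softmax; as a sketch, for target token t (`eval_pairs[i].second`) and logits z of the corresponding batch row:

$$\log p(t) \;=\; z_t - m - \log\!\sum_{j=1}^{n_\text{vocab}} e^{\,z_j - m}, \qquad m = \max_j z_j$$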
@@ -533,7 +589,6 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {

         // determine the common prefix of the endings
         hs_cur.common_prefix = 0;
-        hs_cur.required_tokens = 0;
         for (size_t k = 0; k < hs_cur.seq_tokens[0].size(); k++) {
             if (hs_cur.seq_tokens[0][k] != hs_cur.seq_tokens[1][k] ||
                 hs_cur.seq_tokens[0][k] != hs_cur.seq_tokens[2][k] ||
@@ -566,40 +621,17 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
     const int n_ctx   = llama_n_ctx(ctx);
     const int n_batch = params.n_batch;

-    const int max_tasks_per_batch = params.n_parallel;
+    const int max_tasks_per_batch = 32;
     const int max_seq = 4*max_tasks_per_batch;

     llama_batch batch = llama_batch_init(n_ctx, 0, max_seq);

     std::vector<float> tok_logits(n_vocab);
-    std::vector<float> batch_logits(n_ctx*n_vocab);
+    std::vector<float> batch_logits(n_vocab*n_ctx);

-    auto decode_helper = [&](llama_context * ctx, llama_batch & batch, int32_t n_batch) {
-        for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) {
-            const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
-
-            llama_batch batch_view = {
-                n_tokens,
-                batch.token    + i,
-                nullptr,
-                batch.pos      + i,
-                batch.n_seq_id + i,
-                batch.seq_id   + i,
-                batch.logits   + i,
-                0, 0, 0, // unused
-            };
-
-            const int ret = llama_decode(ctx, batch_view);
-            if (ret != 0) {
-                LOG_TEE("failed to decode the batch, n_batch = %d, ret = %d\n", n_batch, ret);
-                return false;
-            }
-
-            memcpy(batch_logits.data() + i*n_vocab, llama_get_logits(ctx), n_tokens*n_vocab*sizeof(float));
-        }
-
-        return true;
-    };
+    std::vector<std::pair<size_t, llama_token>> eval_pairs;
+    std::vector<float> eval_results;
+    std::vector<std::thread> workers(std::thread::hardware_concurrency());

     for (size_t i0 = 0; i0 < hs_task_count; i0++) {
         int n_cur = 0;
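The hunks above only set up the shared buffers, so here is a standalone sketch of the scoring idea they enable: given a flat logits buffer for an already-decoded batch and a list of (row, target token) pairs like `eval_pairs`, compute the log-softmax probability of each target. The logits values and pair indices below are made up for illustration; only the computation mirrors `compute_logprobs` from the diff.

```cpp
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <utility>
#include <vector>

int main() {
    const int n_vocab = 4;
    // two fake "rows" of logits, as if two positions of a batch were decoded
    std::vector<float> batch_logits = {
        0.1f, 2.0f, -1.0f, 0.5f,   // row 0
        1.0f, 0.0f,  3.0f, 0.2f,   // row 1
    };
    // (row index, target token) pairs, analogous to eval_pairs in the diff
    std::vector<std::pair<size_t, int>> eval_pairs = { {0, 1}, {1, 2} };
    std::vector<float> eval_results;

    for (const auto & p : eval_pairs) {
        const float * logits = batch_logits.data() + p.first * n_vocab;
        // numerically stable log-softmax of the target token
        float max_logit = logits[0];
        for (int j = 1; j < n_vocab; ++j) max_logit = std::max(max_logit, logits[j]);
        float sum_exp = 0.f;
        for (int j = 0; j < n_vocab; ++j) sum_exp += std::exp(logits[j] - max_logit);
        eval_results.push_back(logits[p.second] - max_logit - std::log(sum_exp));
    }

    for (size_t k = 0; k < eval_results.size(); ++k) {
        std::printf("pair %zu: log-prob = %f\n", k, eval_results[k]);
    }
    return 0;
}
```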
@@ -649,11 +681,29 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
         llama_kv_cache_clear(ctx);

         // decode all tasks [i0, i1)
-        if (!decode_helper(ctx, batch, n_batch)) {
+        if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) {
             fprintf(stderr, "%s: llama_decode() failed\n", __func__);
             return;
         }

+        // Compute log-probs in parallel
+        // First we collect all tasks
+        eval_pairs.clear();
+        for (size_t i = i0; i < i1; ++i) {
+            auto & hs_cur = hs_data[i];
+            size_t li = hs_cur.common_prefix;
+            for (int s = 0; s < 4; ++s) {
+                for (size_t j = hs_cur.common_prefix; j < hs_cur.seq_tokens[s].size() - 1; j++) {
+                    eval_pairs.push_back(std::make_pair(hs_cur.i_batch + li++, hs_cur.seq_tokens[s][j + 1]));
+                }
+                ++li;
+            }
+        }
+        // Then we do the actual calculation
+        compute_logprobs(batch_logits.data(), n_vocab, workers, eval_pairs, eval_results);
+
+        size_t ir = 0;
+
         // compute the logprobs for each ending of the decoded tasks
         for (size_t i = i0; i < i1; ++i) {
             auto & hs_cur = hs_data[i];
@@ -662,26 +712,13 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {

             const auto first_probs = softmax(tok_logits);

-            size_t li = hs_cur.common_prefix; // logits index in the batch
-
             for (int s = 0; s < 4; ++s) {
                 hs_cur.ending_logprob_count[s] = 1;
                 hs_cur.ending_logprob[s] = std::log(first_probs[hs_cur.seq_tokens[s][hs_cur.common_prefix]]);
-
-                // Calculate the logprobs over the ending
                 for (size_t j = hs_cur.common_prefix; j < hs_cur.seq_tokens[s].size() - 1; j++) {
-                    std::memcpy(tok_logits.data(), batch_logits.data() + n_vocab*(hs_cur.i_batch + li++), n_vocab*sizeof(float));
-
-                    const float prob = softmax(tok_logits)[hs_cur.seq_tokens[s][j + 1]];
-
-                    hs_cur.ending_logprob[s] += std::log(prob);
+                    hs_cur.ending_logprob[s] += eval_results[ir++];
                     hs_cur.ending_logprob_count[s]++;
                 }

-                // account that we skip the last token in the ending
-                ++li;
-
-                // Calculate the mean token logprob for acc_norm
                 hs_cur.ending_logprob[s] /= hs_cur.ending_logprob_count[s];
             }

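With the per-token log-probs taken from `eval_results`, the ending score computed above is the mean token log-prob (the acc_norm-style criterion); as a sketch, for ending s with tokens e_{s,1..m_s} after the common prefix, and with the best-scoring ending taken as the prediction:

$$\text{score}(s) = \frac{1}{m_s}\sum_{k=1}^{m_s} \log p\bigl(e_{s,k} \mid \text{context}, e_{s,<k}\bigr), \qquad \hat{s} = \arg\max_s \text{score}(s)$$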
@@ -720,6 +757,13 @@ struct winogrande_entry {
     std::string second;
     std::array<std::string, 2> choices;
     int answer;
+
+    size_t i_batch;
+    size_t common_prefix;
+    size_t required_tokens;
+    size_t n_base1; // number of tokens for context + choice 1
+    size_t n_base2; // number of tokens for context + choice 2
+    std::vector<llama_token> seq_tokens[2];
 };

 static std::vector<winogrande_entry> load_winogrande_from_csv(const std::string& prompt) {
@@ -813,7 +857,7 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) {
         }
         float scale = 1/(1.f + (float)rng.max());
         std::vector<winogrande_entry> selected;
-        selected.reserve(params.winogrande_tasks);
+        selected.resize(params.winogrande_tasks);
         for (int i = 0; i < int(params.winogrande_tasks); ++i) {
             int j = int(scale*rng()*aux.size());
             selected[i] = std::move(data[aux[j]]);
@@ -823,115 +867,159 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) {
         data = std::move(selected);
     }

+    fprintf(stderr, "%s : tokenizing selected tasks\n", __func__);
+
     // This is needed as usual for LLaMA models
     const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx));

+    for (auto & task : data) {
+        task.seq_tokens[0] = ::llama_tokenize(ctx, task.first + task.choices[0] + task.second, add_bos);
+        task.seq_tokens[1] = ::llama_tokenize(ctx, task.first + task.choices[1] + task.second, add_bos);
+
+        task.common_prefix = 0;
+        for (size_t k = 0; k < task.seq_tokens[0].size(); k++) {
+            if (task.seq_tokens[0][k] != task.seq_tokens[1][k]) {
+                break;
+            }
+            task.common_prefix++;
+        }
+
+        task.required_tokens = task.common_prefix +
+            task.seq_tokens[0].size() - task.common_prefix +
+            task.seq_tokens[1].size() - task.common_prefix;
+
+        task.n_base1 = ::llama_tokenize(ctx, task.first + task.choices[0], add_bos).size();
+        task.n_base2 = ::llama_tokenize(ctx, task.first + task.choices[1], add_bos).size();
+    }
+
     fprintf(stderr, "%s : calculating winogrande score over selected tasks.\n", __func__);

     const int n_vocab = llama_n_vocab(llama_get_model(ctx));
     const int n_ctx   = llama_n_ctx(ctx);
+    const int n_batch = params.n_batch;
+
+    const int max_tasks_per_batch = 128;
+    const int max_seq = 2*max_tasks_per_batch;
+
+    llama_batch batch = llama_batch_init(n_ctx, 0, max_seq);

     std::vector<float> tok_logits(n_vocab);
+    std::vector<float> batch_logits(n_vocab*n_ctx);
+
+    std::vector<std::pair<size_t, llama_token>> eval_pairs;
+    std::vector<float> eval_results;
+    std::vector<std::thread> workers(std::thread::hardware_concurrency());

     int n_correct = 0;
     int n_done    = 0;

-    for (size_t task_idx = 0; task_idx < data.size(); task_idx++) {
-        const auto& task = data[task_idx];
+    for (size_t i0 = 0; i0 < data.size(); i0++) {
+        int n_cur = 0;

-        auto base_context = ::llama_tokenize(ctx, task.first, add_bos);
-        auto base_ctx_1st = ::llama_tokenize(ctx, task.first + task.choices[0], add_bos);
-        auto base_ctx_2nd = ::llama_tokenize(ctx, task.first + task.choices[1], add_bos);
-
-        auto sentence_1st = task.first + task.choices[0] + task.second;
-        auto sentence_2nd = task.first + task.choices[1] + task.second;
-        auto query_1st = ::llama_tokenize(ctx, sentence_1st, add_bos);
-        auto query_2nd = ::llama_tokenize(ctx, sentence_2nd, add_bos);
-
-        if (query_1st.size() > (size_t)n_ctx || query_2nd.size() > (size_t)n_ctx) {
-            fprintf(stderr, "%s : number of tokens in queries %zu, %zu > n_ctxl\n", __func__, query_1st.size(), query_2nd.size());
+        size_t i1 = i0;
+        size_t i_batch = 0;
+
+        llama_batch_clear(batch);
+
+        while (n_cur + (int) data[i1].required_tokens <= n_ctx) {
+            const int s0 = 2*(i1 - i0);
+            if (s0 + 2 > max_seq) {
+                break;
+            }
+
+            for (size_t i = 0; i < data[i1].common_prefix; ++i) {
+                llama_batch_add(batch, data[i1].seq_tokens[0][i], i, { s0 + 0, s0 + 1}, false);
+            }
+            batch.logits[batch.n_tokens - 1] = true;
+
+            for (int s = 0; s < 2; ++s) {
+                for (size_t i = data[i1].common_prefix; i < data[i1].seq_tokens[s].size(); ++i) {
+                    llama_batch_add(batch, data[i1].seq_tokens[s][i], i, { s0 + s }, true);
+                }
+            }
+
+            data[i1].i_batch = i_batch;
+            i_batch += data[i1].required_tokens;
+
+            n_cur += data[i1].required_tokens;
+            if (++i1 == data.size()) {
+                break;
+            }
+        }
+
+        if (i0 == i1) {
+            fprintf(stderr, "%s : task %zu does not fit in the context window\n", __func__, i0);
             return;
         }

-        auto query_1st_size = query_1st.size();
-        auto query_2nd_size = query_2nd.size();
-
-        // Speedup small evaluations by evaluating atleast 32 tokens
-        // For Winogrande this seems to slow it down rather than speed it up.
-        //if (query_1st.size() < 32) query_1st.resize(32);
-        //if (query_2nd.size() < 32) query_2nd.resize(32);
-
         llama_kv_cache_clear(ctx);
-        auto logits_1st = evaluate_tokens(ctx, query_1st, 0, params.n_batch, n_vocab);

-        llama_kv_cache_clear(ctx);
-        auto logits_2nd = evaluate_tokens(ctx, query_2nd, 0, params.n_batch, n_vocab);
-
-        if (logits_1st.empty() || logits_2nd.empty()) {
-            fprintf(stderr, "%s : failed to eval\n", __func__);
+
+        // decode all tasks [i0, i1)
+        if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) {
+            fprintf(stderr, "%s: llama_decode() failed\n", __func__);
             return;
         }

-        bool skip_choice = query_1st_size - base_ctx_1st.size() > k_min_trailing_ctx &&
-                           query_2nd_size - base_ctx_2nd.size() > k_min_trailing_ctx;
-
-        float score_1st = 0;
-        bool is_nan_1st = false;
-        const auto& base_1 = skip_choice ? base_ctx_1st : base_context;
-        const int last_1st = query_1st_size - base_1.size() > 1 ? 1 : 0;
-        for (size_t j = base_1.size()-1; j < query_1st_size-1-last_1st; ++j) {
-            std::memcpy(tok_logits.data(), logits_1st.data() + j*n_vocab, n_vocab*sizeof(float));
-            const float prob = softmax(tok_logits)[query_1st[j+1]];
-            if (std::isnan(prob) || !prob) {
-                fprintf(stderr, "%s: %g probability for token %zu when evaluating <%s>. Base context has %zu tokens\n", __func__,
-                        prob, j, sentence_1st.c_str(), base_context.size());
-                is_nan_1st = true;
-                break;
-            }
-            score_1st += std::log(prob);
-        }
-        score_1st /= (query_1st_size - base_1.size() - last_1st);
-
-        float score_2nd = 0;
-        bool is_nan_2nd = false;
-        const auto& base_2 = skip_choice ? base_ctx_2nd : base_context;
-        const int last_2nd = query_2nd_size - base_2.size() > 1 ? 1 : 0;
-        for (size_t j = base_2.size()-1; j < query_2nd_size-1-last_2nd; ++j) {
-            std::memcpy(tok_logits.data(), logits_2nd.data() + j*n_vocab, n_vocab*sizeof(float));
-            const float prob = softmax(tok_logits)[query_2nd[j+1]];
-            if (std::isnan(prob) || !prob) {
-                fprintf(stderr, "%s: %g probability for token %zu when evaluating <%s>. Base context has %zu tokens\n", __func__,
-                        prob, j, sentence_2nd.c_str(), base_context.size());
-                is_nan_2nd = true;
-                break;
-            }
-            score_2nd += std::log(prob);
-        }
-        score_2nd /= (query_2nd_size - base_2.size() - last_2nd);
-
-        if (is_nan_1st || is_nan_2nd) {
-            continue;
-        }
-
-        if (std::isnan(score_1st) || std::isnan(score_2nd)) {
-            printf("================== NaN score %g, %g) for:\n", score_1st, score_2nd);
-            printf("Q1: <%s> - %zu tokens\n", sentence_1st.c_str(), query_1st_size);
-            printf("Q2: <%s> - %zu tokens\n", sentence_2nd.c_str(), query_2nd_size);
-            printf("B : <%s> - %zu tokens\n", task.first.c_str(), base_context.size());
-            printf("base_1 has %zu tokens, base_2 has %zu tokens, skip_choice = %d\n", base_1.size(), base_2.size(), skip_choice);
-            continue;
-        }
-
-        int result = score_1st > score_2nd ? 1 : 2;
-
-        if (result == task.answer) {
-            ++n_correct;
-        }
-        ++n_done;
-
-        // Print the accumulated accuracy mean x 100
-        printf("%zu\t%.4lf\t%10.6f %10.6f %d %d\n",task_idx+1, 100.0 * n_correct/n_done,score_1st,score_2nd,result,task.answer);
-        fflush(stdout);
+        eval_pairs.clear();
+        for (size_t i = i0; i < i1; ++i) {
+            auto & task = data[i];
+
+            const bool skip_choice =
+                task.seq_tokens[0].size() - task.common_prefix > k_min_trailing_ctx &&
+                task.seq_tokens[1].size() - task.common_prefix > k_min_trailing_ctx;
+
+            const auto& n_base1 = skip_choice ? task.n_base1 : task.common_prefix;
+            const int last_1st = task.seq_tokens[0].size() - n_base1 > 1 ? 1 : 0;
+            size_t li = n_base1 - 1;
+            for (size_t j = n_base1-1; j < task.seq_tokens[0].size()-1-last_1st; ++j) {
+                eval_pairs.push_back(std::make_pair(task.i_batch + li++, task.seq_tokens[0][j+1]));
+            }
+            const auto& n_base2 = skip_choice ? task.n_base2 : task.common_prefix;
+            const int last_2nd = task.seq_tokens[1].size() - n_base2 > 1 ? 1 : 0;
+            li = task.seq_tokens[0].size() - task.common_prefix + n_base2 - 1;
+            for (size_t j = n_base2-1; j < task.seq_tokens[1].size()-1-last_2nd; ++j) {
+                eval_pairs.push_back(std::make_pair(task.i_batch + li++, task.seq_tokens[1][j+1]));
+            }
+        }
+        compute_logprobs(batch_logits.data(), n_vocab, workers, eval_pairs, eval_results);
+
+        size_t ir = 0;
+        for (size_t i = i0; i < i1; ++i) {
+            auto & task = data[i];
+
+            const bool skip_choice =
+                task.seq_tokens[0].size() - task.common_prefix > k_min_trailing_ctx &&
+                task.seq_tokens[1].size() - task.common_prefix > k_min_trailing_ctx;
+
+            float score_1st = 0;
+            const auto& n_base1 = skip_choice ? task.n_base1 : task.common_prefix;
+            const int last_1st = task.seq_tokens[0].size() - n_base1 > 1 ? 1 : 0;
+            for (size_t j = n_base1-1; j < task.seq_tokens[0].size()-1-last_1st; ++j) {
+                score_1st += eval_results[ir++];
+            }
+            score_1st /= (task.seq_tokens[0].size() - n_base1 - last_1st);
+
+            float score_2nd = 0;
+            const auto& n_base2 = skip_choice ? task.n_base2 : task.common_prefix;
+            const int last_2nd = task.seq_tokens[1].size() - n_base2 > 1 ? 1 : 0;
+            for (size_t j = n_base2-1; j < task.seq_tokens[1].size()-1-last_2nd; ++j) {
+                score_2nd += eval_results[ir++];
+            }
+            score_2nd /= (task.seq_tokens[1].size() - n_base2 - last_2nd);
+
+            int result = score_1st > score_2nd ? 1 : 2;
+
+            if (result == task.answer) {
+                ++n_correct;
+            }
+            ++n_done;
+
+            // print the accumulated accuracy mean x 100
+            printf("%zu\t%.4lf\t%10.6f %10.6f %d %d\n", i+1, 100.0 * n_correct/n_done, score_1st, score_2nd, result, task.answer);
+            fflush(stdout);
+        }
+
+        i0 = i1 - 1;
     }

     printf("\n");
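The two candidate scores compared above are mean token log-probs over the trailing part of each full sentence (optionally starting after the choice tokens when enough trailing context exists); as a sketch of what the sums over `eval_results` compute for candidate c with N_c scored positions:

$$\text{score}_c = \frac{1}{N_c}\sum_{j} \log p\bigl(t_{j+1}^{(c)} \mid t_{\le j}^{(c)}\bigr), \qquad \text{prediction} = \arg\max_{c \in \{1,2\}} \text{score}_c$$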
@@ -1558,6 +1558,7 @@ struct llama_server_context
     void process_tasks()
     {
         std::unique_lock<std::mutex> lock(mutex_tasks);
+        std::vector<task_server> deferred_tasks;
         while (!queue_tasks.empty())
         {
             task_server task = queue_tasks.front();
@@ -1568,9 +1569,8 @@ struct llama_server_context
                 llama_client_slot *slot = get_slot(json_value(task.data, "slot_id", -1));
                 if (slot == nullptr)
                 {
-                    LOG_TEE("slot unavailable\n");
-                    // send error result
-                    send_error(task, "slot unavailable");
+                    // if no slot is available, we defer this task for processing later
+                    deferred_tasks.push_back(task);
                     break;
                 }

@@ -1616,6 +1616,12 @@ struct llama_server_context
             }
         }

+        // add all the deferred tasks back the the queue
+        for (task_server &task : deferred_tasks)
+        {
+            queue_tasks.push_back(task);
+        }
+
         // remove finished multitasks from the queue of multitasks, and add the corresponding result to the result queue
         std::vector<task_result> agg_results;
         auto queue_iterator = queue_multitasks.begin();
@@ -263,7 +263,6 @@ static void init_model(struct my_llama_model * model) {
     model->data.resize(size + tensor_alignment);
     alloc = ggml_allocr_new(model->data.data(), model->data.size(), tensor_alignment);
     alloc_model(alloc, model);
-    ggml_allocr_free(alloc);
 }

 static void randomize_model(struct my_llama_model * model, int seed, float mean, float std, float min, float max) {
@@ -1102,7 +1101,6 @@ int main(int argc, char ** argv) {
     alloc = ggml_allocr_new(mem_input_data.data(), mem_input_data.size(), tensor_alignment);
     ggml_allocr_alloc(alloc, tokens_input);
     ggml_allocr_alloc(alloc, target_probs);
-    ggml_allocr_free(alloc);

     // context for compute tensors without their data
     const size_t estimated_compute_size_wo_data = (
@@ -1149,7 +1147,6 @@ int main(int argc, char ** argv) {
             best_compute_size = max_compute_size;
             best_order = gf->order;
         }
-        ggml_allocr_free(alloc);
         ggml_free(ctx_compute);
     }
     size_t max_compute_size = best_compute_size;
@@ -1177,7 +1174,6 @@ int main(int argc, char ** argv) {
         params.common.use_flash,
         params.common.use_checkpointing
     );
-    ggml_allocr_free(alloc);

     std::vector<llama_token> train_tokens;
     std::vector<size_t> train_samples_begin;
@@ -12,9 +12,6 @@
 #include <vector>
 #include <map>
 #include <array>
-#include "ggml-cuda.h"
-#include "ggml.h"
-#include "ggml-backend-impl.h"

 #if defined(GGML_USE_HIPBLAS)
 #include <hip/hip_runtime.h>
@@ -118,6 +115,11 @@

 #endif // defined(GGML_USE_HIPBLAS)

+// ggml-cuda need half type so keep ggml headers include at last
+#include "ggml-cuda.h"
+#include "ggml.h"
+#include "ggml-backend-impl.h"
+
 #define CUDART_HMAX 11070 // CUDA 11.7, min. ver. for which __hmax and __hmax2 are known to work (may be higher than needed)

 #define CC_PASCAL 600
@@ -97,8 +97,10 @@ class MODEL_ARCH(IntEnum):
     BLOOM = auto()
     STABLELM = auto()
     QWEN = auto()
+    QWEN2 = auto()
     PHI2 = auto()
     PLAMO = auto()
+    CODESHELL = auto()


 class MODEL_TENSOR(IntEnum):
@@ -145,8 +147,10 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
     MODEL_ARCH.BLOOM: "bloom",
     MODEL_ARCH.STABLELM: "stablelm",
     MODEL_ARCH.QWEN: "qwen",
+    MODEL_ARCH.QWEN2: "qwen2",
     MODEL_ARCH.PHI2: "phi2",
     MODEL_ARCH.PLAMO: "plamo",
+    MODEL_ARCH.CODESHELL: "codeshell",
 }

 TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
@@ -356,6 +360,20 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.FFN_DOWN,
         MODEL_TENSOR.FFN_UP,
     ],
+    MODEL_ARCH.QWEN2: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_GATE,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+    ],
     MODEL_ARCH.PLAMO: [
         MODEL_TENSOR.TOKEN_EMBD,
         MODEL_TENSOR.OUTPUT_NORM,
@@ -396,6 +414,19 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.FFN_NORM,
         MODEL_TENSOR.FFN_DOWN,
         MODEL_TENSOR.FFN_UP,
+    ],
+    MODEL_ARCH.CODESHELL: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.POS_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_QKV,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.ATTN_ROT_EMBD,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
     ]
     # TODO
 }
@@ -417,6 +448,10 @@ MODEL_TENSOR_SKIP: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.ROPE_FREQS,
         MODEL_TENSOR.ATTN_ROT_EMBD,
     ],
+    MODEL_ARCH.CODESHELL: [
+        MODEL_TENSOR.ROPE_FREQS,
+        MODEL_TENSOR.ATTN_ROT_EMBD,
+    ],
 }

 #
@@ -154,6 +154,7 @@ class TensorNameMap:
             "model.layers.{bid}.self_attn.rotary_emb.inv_freq",        # llama-hf
             "layers.{bid}.attention.inner_attention.rope.freqs",       # llama-pth
             "model.layers.layers.{bid}.self_attn.rotary_emb.inv_freq", # plamo
+            "transformer.h.{bid}.attn.rotary_emb.inv_freq",            # codeshell
         ),

         # Feed-forward norm

llama.cpp (387 changes)
@@ -192,8 +192,10 @@ enum llm_arch {
     LLM_ARCH_BLOOM,
     LLM_ARCH_STABLELM,
     LLM_ARCH_QWEN,
+    LLM_ARCH_QWEN2,
     LLM_ARCH_PHI2,
     LLM_ARCH_PLAMO,
+    LLM_ARCH_CODESHELL,
     LLM_ARCH_UNKNOWN,
 };

@@ -211,8 +213,10 @@ static std::map<llm_arch, std::string> LLM_ARCH_NAMES = {
     { LLM_ARCH_BLOOM, "bloom" },
     { LLM_ARCH_STABLELM, "stablelm" },
     { LLM_ARCH_QWEN, "qwen" },
+    { LLM_ARCH_QWEN2, "qwen2" },
     { LLM_ARCH_PHI2, "phi2" },
     { LLM_ARCH_PLAMO, "plamo" },
+    { LLM_ARCH_CODESHELL, "codeshell" },
 };

 enum llm_kv {
@@ -566,6 +570,23 @@ static std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NAMES =
             { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
         },
     },
+    {
+        LLM_ARCH_QWEN2,
+        {
+            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+            { LLM_TENSOR_OUTPUT, "output" },
+            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+            { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
+        },
+    },
     {
         LLM_ARCH_PHI2,
         {
@@ -600,6 +621,26 @@ static std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NAMES =
             { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
         },
     },
+    {
+        LLM_ARCH_CODESHELL,
+        {
+            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+            { LLM_TENSOR_OUTPUT, "output" },
+            { LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
+            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
+            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+            { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" },
+            { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
+        },
+    },
+
     {
         LLM_ARCH_UNKNOWN,
@@ -1599,7 +1640,7 @@ struct llama_model {
     std::unique_ptr<llama_mmap> mapping;

     // objects representing data potentially being locked in memory
-    llama_mlock mlock_buf;
+    std::vector<std::unique_ptr<llama_mlock>> mlock_bufs;
     llama_mlock mlock_mmap;

     // for quantize-stats only
@@ -2847,6 +2888,17 @@ static void llm_load_hparams(
                     default: model.type = e_model::MODEL_UNKNOWN;
                 }
             } break;
+        case LLM_ARCH_QWEN2:
+            {
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+                switch (hparams.n_layer) {
+                    case 24: model.type = e_model::MODEL_1B; break;
+                    case 32: model.type = e_model::MODEL_7B; break;
+                    case 40: model.type = e_model::MODEL_13B; break;
+                    case 80: model.type = e_model::MODEL_70B; break;
+                    default: model.type = e_model::MODEL_UNKNOWN;
+                }
+            } break;
         case LLM_ARCH_PHI2:
             {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
@@ -2877,6 +2929,14 @@ static void llm_load_hparams(
                     default: model.type = e_model::MODEL_UNKNOWN;
                 }
             } break;
+        case LLM_ARCH_CODESHELL:
+            {
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
+                switch (hparams.n_layer) {
+                    case 42: model.type = e_model::MODEL_SMALL; break;
+                    default: model.type = e_model::MODEL_UNKNOWN;
+                }
+            } break;

         default: (void)0;
     }
@ -3438,7 +3498,12 @@ static bool llm_load_tensors(
|
|||||||
{
|
{
|
||||||
model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
|
model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
|
||||||
model.output_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd});
|
model.output_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd});
|
||||||
model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab});
|
if (gguf_find_tensor(ml.ctx_gguf, tn(LLM_TENSOR_OUTPUT, "weight").c_str()) >= 0) {
|
||||||
|
model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab});
|
||||||
|
} else {
|
||||||
|
model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); // needs to be on GPU
|
||||||
|
ml.n_created--; // artificial tensor
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for (int i = 0; i < n_layer; ++i) {
|
for (int i = 0; i < n_layer; ++i) {
|
||||||
@ -3669,6 +3734,41 @@ static bool llm_load_tensors(
|
|||||||
layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff/2});
|
layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff/2});
|
||||||
}
|
}
|
||||||
} break;
|
} break;
|
||||||
|
case LLM_ARCH_QWEN2:
|
||||||
|
{
|
||||||
|
model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
|
||||||
|
|
||||||
|
// output
|
||||||
|
{
|
||||||
|
model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
|
||||||
|
model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab});
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i = 0; i < n_layer; ++i) {
|
||||||
|
ggml_context * ctx_layer = ctx_for_layer(i);
|
||||||
|
ggml_context * ctx_split = ctx_for_layer_split(i);
|
||||||
|
|
||||||
|
auto & layer = model.layers[i];
|
||||||
|
|
||||||
|
layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
|
||||||
|
|
||||||
|
layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd});
|
||||||
|
layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa});
|
||||||
|
layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa});
|
||||||
|
layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
|
||||||
|
|
||||||
|
// optional bias tensors
|
||||||
|
layer.bq = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd});
|
||||||
|
layer.bk = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa});
|
||||||
|
layer.bv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa});
|
||||||
|
|
||||||
|
layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
|
||||||
|
|
||||||
|
layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff});
|
||||||
|
layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd});
|
||||||
|
layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff});
|
||||||
|
}
|
||||||
|
} break;
|
||||||
case LLM_ARCH_PHI2:
|
case LLM_ARCH_PHI2:
|
||||||
{
|
{
|
||||||
model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
|
model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
|
||||||
@ -3779,6 +3879,42 @@ static bool llm_load_tensors(
|
|||||||
                     layer.ffn_up_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff});
                 }
             } break;
+        case LLM_ARCH_CODESHELL:
+            {
+                model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
+
+                // output
+                {
+                    model.output_norm   = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
+                    model.output_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd});
+                    model.output        = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab});
+                }
+
+                for (int i = 0; i < n_layer; ++i) {
+                    ggml_context * ctx_layer = ctx_for_layer(i);
+                    ggml_context * ctx_split = ctx_for_layer_split(i);
+
+                    auto & layer = model.layers[i];
+
+                    layer.attn_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
+                    layer.attn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd});
+
+                    layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa});
+                    layer.bqkv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa});
+
+                    layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
+                    layer.bo = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd});
+
+                    layer.ffn_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
+                    layer.ffn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd});
+
+                    layer.ffn_down   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd});
+                    layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd});
+
+                    layer.ffn_up   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff});
+                    layer.ffn_up_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff});
+                }
+            } break;
         default:
             throw std::runtime_error("unknown architecture");
     }
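The CodeShell block above stores the attention projection as a single fused tensor: wqkv has width n_embd + 2*n_embd_gqa, with the first n_embd output rows producing Q, the next n_embd_gqa producing K, and the last n_embd_gqa producing V. A short sketch of that bookkeeping; the sizes are illustrative assumptions:

    #include <cstdio>

    int main() {
        // illustrative sizes (assumptions, not from the diff)
        const int n_embd     = 4096;
        const int n_embd_gqa = 512;

        const int qkv_width = n_embd + 2*n_embd_gqa; // width of wqkv / bqkv

        // row ranges of the fused output that correspond to Q, K and V
        const int q_off = 0;
        const int k_off = n_embd;
        const int v_off = n_embd + n_embd_gqa;

        printf("fused width = %d, Q rows [%d,%d), K rows [%d,%d), V rows [%d,%d)\n",
               qkv_width, q_off, q_off + n_embd, k_off, k_off + n_embd_gqa, v_off, v_off + n_embd_gqa);
        return 0;
    }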
@ -3815,8 +3951,10 @@ static bool llm_load_tensors(
         else {
             buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
             if (buf != nullptr && use_mlock && ggml_backend_buffer_is_host(buf)) {
-                model.mlock_buf.init (ggml_backend_buffer_get_base(buf));
-                model.mlock_buf.grow_to(ggml_backend_buffer_get_size(buf));
+                model.mlock_bufs.emplace_back(new llama_mlock);
+                auto & mlock_buf = model.mlock_bufs.back();
+                mlock_buf->init (ggml_backend_buffer_get_base(buf));
+                mlock_buf->grow_to(ggml_backend_buffer_get_size(buf));
             }
         }
         if (buf == nullptr) {
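The mlock change above keeps one lock object per backend buffer in a vector instead of a single shared lock, so each host buffer can be pinned independently. A simplified POSIX sketch of that per-buffer pattern; the pinned_region type and its grow_to semantics are assumptions for illustration, not the actual llama_mlock class:

    #include <sys/mman.h>   // mlock, munlock (POSIX)
    #include <memory>
    #include <vector>

    // minimal lock-one-region helper (illustrative only)
    struct pinned_region {
        void * addr = nullptr;
        size_t size = 0;

        void init(void * a) { addr = a; }
        void grow_to(size_t n) {
            if (addr != nullptr && n > size) {
                mlock(static_cast<char *>(addr) + size, n - size); // errors ignored in this sketch
                size = n;
            }
        }
        ~pinned_region() {
            if (addr && size) { munlock(addr, size); }
        }
    };

    // one lock object per buffer, mirroring the shape of the change above
    std::vector<std::unique_ptr<pinned_region>> mlock_bufs;

    void pin_buffer(void * base, size_t size) {
        mlock_bufs.emplace_back(new pinned_region);
        auto & buf = mlock_bufs.back();
        buf->init(base);
        buf->grow_to(size);
    }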
@ -5638,6 +5776,128 @@ struct llm_build_context {
 
         return gf;
     }
 
+    struct ggml_cgraph * build_qwen2() {
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+        struct ggml_tensor * cur;
+        struct ggml_tensor * inpL;
+
+        inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, cb);
+        cb(inpL, "inp_embd", -1);
+
+        // inp_pos - contains the positions
+        struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
+        cb(inp_pos, "inp_pos", -1);
+
+        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
+        struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
+        cb(KQ_mask, "KQ_mask", -1);
+
+        // shift the entire K-cache if needed
+        if (do_rope_shift) {
+            llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, freq_base, freq_scale, cb);
+        }
+
+        for (int il = 0; il < n_layer; ++il) {
+            struct ggml_tensor * inpSA = inpL;
+
+            // norm
+            cur = llm_build_norm(ctx0, inpL, hparams,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, cb, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                // compute Q and K and RoPE them
+                struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+                Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+                cb(Qcur, "Qcur", il);
+
+                struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+                Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+                cb(Kcur, "Kcur", il);
+
+                struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+                Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+                cb(Vcur, "Vcur", il);
+
+                // these nodes are added to the graph together so that they are not reordered
+                // by doing so, the number of splits in the graph is reduced
+                ggml_build_forward_expand(gf, Qcur);
+                ggml_build_forward_expand(gf, Kcur);
+                ggml_build_forward_expand(gf, Vcur);
+
+                Qcur = ggml_rope_custom(
+                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos,
+                    hparams.n_rot, 2, 0, n_orig_ctx, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                );
+                cb(Qcur, "Qcur", il);
+
+                Kcur = ggml_rope_custom(
+                    ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
+                    hparams.n_rot, 2, 0, n_orig_ctx, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                );
+                cb(Kcur, "Kcur", il);
+
+                llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
+
+                cur = llm_build_kqv(ctx0, model, hparams, kv_self,
+                        model.layers[il].wo, model.layers[il].bo,
+                        Qcur, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                cb(cur, "kqv_out", il);
+            }
+
+            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // feed-forward network
+            cur = llm_build_norm(ctx0, ffn_inp, hparams,
+                    model.layers[il].ffn_norm, NULL,
+                    LLM_NORM_RMS, cb, il);
+            cb(cur, "ffn_norm", il);
+
+            cur = llm_build_ffn(ctx0, cur,
+                    model.layers[il].ffn_up, NULL,
+                    model.layers[il].ffn_gate, NULL,
+                    model.layers[il].ffn_down, NULL,
+                    NULL,
+                    LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
+            cb(cur, "ffn_out", il);
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = llm_build_norm(ctx0, cur, hparams,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, cb, -1);
+        cb(cur, "result_norm", -1);
+
+        // lm_head
+        cur = ggml_mul_mat(ctx0, model.output, cur);
+        cb(cur, "result_output", -1);
+
+        ggml_build_forward_expand(gf, cur);
+
+        return gf;
+    }
+
     struct ggml_cgraph * build_phi2() {
         struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
 
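build_qwen2 above applies rotary position embeddings to Q and K via ggml_rope_custom (the mode argument 2 selects NEOX-style rotation) before they enter the KV cache. The core operation rotates pairs of channels by a position-dependent angle. A standalone sketch of that rotation, assuming NEOX-style pairing of channel i with channel i + d/2 and theta = pos * base^(-2i/d); it deliberately ignores the frequency-scaling parameters (freq_scale, ext_factor, attn_factor, beta_fast, beta_slow) that ggml_rope_custom takes:

    #include <cmath>
    #include <vector>

    // Rotate one head vector x (dimension d, d even) for position `pos`.
    // NEOX-style pairing: channel i is rotated together with channel i + d/2.
    void rope_neox_sketch(std::vector<float> & x, int pos, float freq_base = 10000.0f) {
        const int d = (int) x.size();
        for (int i = 0; i < d/2; ++i) {
            const float theta = pos * std::pow(freq_base, -2.0f*i/d);
            const float c = std::cos(theta);
            const float s = std::sin(theta);
            const float x0 = x[i];
            const float x1 = x[i + d/2];
            x[i]       = x0*c - x1*s;
            x[i + d/2] = x0*s + x1*c;
        }
    }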
@ -5971,6 +6231,117 @@ struct llm_build_context {
 
         return gf;
     }
 
+    struct ggml_cgraph * build_codeshell() {
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+        const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+        struct ggml_tensor * cur;
+        struct ggml_tensor * inpL;
+
+        inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, cb);
+        cb(inpL, "inp_embd", -1);
+
+        // inp_pos - contains the positions
+        struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
+        cb(inp_pos, "inp_pos", -1);
+
+        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
+        struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
+        cb(KQ_mask, "KQ_mask", -1);
+
+        // shift the entire K-cache if needed
+        if (do_rope_shift) {
+            llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE, n_ctx, freq_base, freq_scale, cb);
+        }
+
+        for (int il = 0; il < n_layer; ++il) {
+            cur = llm_build_norm(ctx0, inpL, hparams,
+                    model.layers[il].attn_norm,
+                    model.layers[il].attn_norm_b,
+                    LLM_NORM, cb, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur);
+                cb(cur, "wqkv", il);
+
+                cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
+                cb(cur, "bqkv", il);
+
+                struct ggml_tensor * tmpq = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
+                struct ggml_tensor * tmpk = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
+                struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
+
+                cb(tmpq, "tmpq", il);
+                cb(tmpk, "tmpk", il);
+                cb(Vcur, "Vcur", il);
+
+                struct ggml_tensor * Qcur = ggml_rope_custom(
+                    ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, n_tokens), inp_pos,
+                    hparams.n_rot, 2, 0, n_orig_ctx, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                );
+                cb(Qcur, "Qcur", il);
+
+                struct ggml_tensor * Kcur = ggml_rope_custom(
+                    ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, n_tokens), inp_pos,
+                    hparams.n_rot, 2, 0, n_orig_ctx, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                );
+                cb(Kcur, "Kcur", il);
+
+                llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
+
+                cur = llm_build_kqv(ctx0, model, hparams, kv_self,
+                        model.layers[il].wo, model.layers[il].bo,
+                        Qcur, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                cb(cur, "kqv_out", il);
+            }
+
+            // add the input
+            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // FF
+            {
+                cur = llm_build_norm(ctx0, ffn_inp, hparams,
+                        model.layers[il].ffn_norm,
+                        model.layers[il].ffn_norm_b,
+                        LLM_NORM, cb, il);
+                cb(cur, "ffn_norm", il);
+
+                cur = llm_build_ffn(ctx0, cur,
+                        model.layers[il].ffn_up, model.layers[il].ffn_up_b,
+                        NULL, NULL,
+                        model.layers[il].ffn_down, model.layers[il].ffn_down_b,
+                        NULL,
+                        LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
+                cb(cur, "ffn_out", il);
+            }
+
+            inpL = ggml_add(ctx0, cur, ffn_inp);
+            cb(inpL, "l_out", il);
+        }
+
+        cur = llm_build_norm(ctx0, inpL, hparams,
+                model.output_norm,
+                model.output_norm_b,
+                LLM_NORM, cb, -1);
+        cb(cur, "result_norm", -1);
+
+        cur = ggml_mul_mat(ctx0, model.output, cur);
+        cb(cur, "result_output", -1);
+
+        ggml_build_forward_expand(gf, cur);
+
+        return gf;
+    }
 };
 
 static struct ggml_cgraph * llama_build_graph(
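Both build_qwen2 and build_codeshell above create a KQ_mask of shape [n_kv, n_tokens] holding 0 for allowed positions and -INFINITY for disallowed ones; adding it to the attention scores before the softmax enforces causal attention. A simplified sketch of building such a mask on the host; it assumes a single sequence and a contiguous KV cache, and ignores the sequence ids and cache shifting that the real llama.cpp mask also encodes:

    #include <cmath>
    #include <vector>

    // Row j (query token at absolute position pos[j]) may attend to cache
    // entries 0..pos[j]; everything later is masked with -INFINITY.
    std::vector<float> build_kq_mask(int n_kv, const std::vector<int> & pos) {
        const int n_tokens = (int) pos.size();
        std::vector<float> mask((size_t) n_kv * n_tokens, 0.0f);
        for (int j = 0; j < n_tokens; ++j) {
            for (int i = 0; i < n_kv; ++i) {
                if (i > pos[j]) {
                    mask[(size_t) j*n_kv + i] = -INFINITY;
                }
            }
        }
        return mask;
    }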
@ -6153,6 +6524,10 @@ static struct ggml_cgraph * llama_build_graph(
             {
                 result = llm.build_qwen();
             } break;
+        case LLM_ARCH_QWEN2:
+            {
+                result = llm.build_qwen2();
+            } break;
         case LLM_ARCH_PHI2:
             {
                 result = llm.build_phi2();
@ -6165,6 +6540,10 @@ static struct ggml_cgraph * llama_build_graph(
             {
                 result = llm.build_gpt2();
             } break;
+        case LLM_ARCH_CODESHELL:
+            {
+                result = llm.build_codeshell();
+            } break;
         default:
             GGML_ASSERT(false);
     }
@ -4,7 +4,7 @@ wget https://raw.githubusercontent.com/klosax/hellaswag_text_data/main/hellaswag
 
 echo "Usage:"
 echo ""
-echo " ./perplexity --hellaswag --hellaswag-tasks N -f hellaswag_val_full.txt -m modelfile.gguf"
+echo " ./perplexity -m model.gguf -f hellaswag_val_full.txt --hellaswag [--hellaswag-tasks N] [other params]"
 echo ""
 
 exit 0
@ -1,3 +1,10 @@
 #!/bin/bash
 
 wget https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip
+
+echo "Usage:"
+echo ""
+echo " ./perplexity -m model.gguf -f wiki.test.raw [other params]"
+echo ""
+
+exit 0
scripts/get-winogrande.sh (new executable file, 10 lines)
@ -0,0 +1,10 @@
+#!/bin/bash
+
+wget https://huggingface.co/datasets/ikawrakow/winogrande-eval-for-llama.cpp/raw/main/winogrande-debiased-eval.csv
+
+echo "Usage:"
+echo ""
+echo " ./perplexity -m model.gguf -f winogrande-debiased-eval.csv --winogrande [--winogrande-tasks N] [other params]"
+echo ""
+
+exit 0