mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-26 11:24:35 +00:00
Merge branch 'master' into gguf
ggml-ci
This commit is contained in:
commit
1e7a0092dd
13
README.md
13
README.md
@ -9,13 +9,13 @@
|
||||
|
||||
Inference of [LLaMA](https://arxiv.org/abs/2302.13971) model in pure C/C++
|
||||
|
||||
**Hot topics:**
|
||||
### 🚧 Incoming breaking change + refactoring:
|
||||
|
||||
- Simple web chat example: https://github.com/ggerganov/llama.cpp/pull/1998
|
||||
- k-quants now support super-block size of 64: https://github.com/ggerganov/llama.cpp/pull/2001
|
||||
- New roadmap: https://github.com/users/ggerganov/projects/7
|
||||
- Azure CI brainstorming: https://github.com/ggerganov/llama.cpp/discussions/1985
|
||||
- p1 : LLM-based code completion engine at the edge : https://github.com/ggml-org/p1/discussions/1
|
||||
See PR https://github.com/ggerganov/llama.cpp/pull/2398 for more info.
|
||||
|
||||
To devs: avoid making big changes to `llama.h` / `llama.cpp` until merged
|
||||
|
||||
----
|
||||
|
||||
<details>
|
||||
<summary>Table of Contents</summary>
|
||||
@ -99,6 +99,7 @@ as the main playground for developing new features for the [ggml](https://github
|
||||
- Rust: [mdrokz/rust-llama.cpp](https://github.com/mdrokz/rust-llama.cpp)
|
||||
- C#/.NET: [SciSharp/LLamaSharp](https://github.com/SciSharp/LLamaSharp)
|
||||
- Scala 3: [donderom/llm4s](https://github.com/donderom/llm4s)
|
||||
- Clojure: [phronmophobic/llama.clj](https://github.com/phronmophobic/llama.clj)
|
||||
|
||||
**UI:**
|
||||
|
||||
|
@ -5,6 +5,7 @@
|
||||
#include <cmath>
|
||||
#include <ctime>
|
||||
#include <sstream>
|
||||
#include <cstring>
|
||||
|
||||
#if defined(_MSC_VER)
|
||||
#pragma warning(disable: 4244 4267) // possible loss of data
|
||||
@ -121,6 +122,27 @@ void perplexity(llama_context * ctx, const gpt_params & params) {
|
||||
printf("\n");
|
||||
}
|
||||
|
||||
std::vector<float> hellaswag_evaluate_tokens(llama_context * ctx, const std::vector<int>& tokens, int n_past, int n_batch,
|
||||
int n_vocab, int n_thread) {
|
||||
std::vector<float> result;
|
||||
result.reserve(tokens.size() * n_vocab);
|
||||
size_t n_chunk = (tokens.size() + n_batch - 1)/n_batch;
|
||||
for (size_t i_chunk = 0; i_chunk < n_chunk; ++i_chunk) {
|
||||
size_t n_tokens = tokens.size() - i_chunk * n_batch;
|
||||
n_tokens = std::min(n_tokens, size_t(n_batch));
|
||||
if (llama_eval(ctx, tokens.data() + i_chunk * n_batch, n_tokens, n_past, n_thread)) {
|
||||
fprintf(stderr, "%s : failed to eval\n", __func__);
|
||||
return {};
|
||||
}
|
||||
|
||||
const auto logits = llama_get_logits(ctx);
|
||||
result.insert(result.end(), logits, logits + n_tokens * n_vocab);
|
||||
|
||||
n_past += n_tokens;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
void hellaswag_score(llama_context * ctx, const gpt_params & params) {
|
||||
// Calculates hellaswag score (acc_norm) from prompt
|
||||
//
|
||||
@ -209,17 +231,19 @@ void hellaswag_score(llama_context * ctx, const gpt_params & params) {
|
||||
double acc = 0.0f;
|
||||
const int n_vocab = llama_n_vocab(ctx);
|
||||
|
||||
std::vector<float> tok_logits(n_vocab);
|
||||
|
||||
for (size_t task_idx = 0; task_idx < hs_task_count; task_idx++) {
|
||||
|
||||
// Tokenize the context to count tokens
|
||||
std::vector<int> context_embd = ::llama_tokenize(ctx, hs_data[task_idx].context, prepend_bos);
|
||||
size_t context_size = context_embd.size();
|
||||
|
||||
for (size_t ending_idx=0;ending_idx<4;ending_idx++) {
|
||||
|
||||
// Tokenize the query
|
||||
std::vector<int> query_embd = ::llama_tokenize(ctx, hs_data[task_idx].context + hs_data[task_idx].ending[ending_idx], prepend_bos);
|
||||
size_t query_size = query_embd.size();
|
||||
// Do the 1st ending
|
||||
// In this case we include the context when evaluating
|
||||
auto query_embd = ::llama_tokenize(ctx, hs_data[task_idx].context + hs_data[task_idx].ending[0], prepend_bos);
|
||||
auto query_size = query_embd.size();
|
||||
//printf("First query: %d\n",(int)query_size);
|
||||
|
||||
// Stop if query wont fit the ctx window
|
||||
if (query_size > (size_t)params.n_ctx) {
|
||||
@ -232,27 +256,68 @@ void hellaswag_score(llama_context * ctx, const gpt_params & params) {
|
||||
query_embd.resize(32);
|
||||
}
|
||||
|
||||
// Evaluate the query
|
||||
if (llama_eval(ctx, query_embd.data(), query_embd.size(), 0, params.n_threads)) {
|
||||
auto logits = hellaswag_evaluate_tokens(ctx, query_embd, 0, params.n_batch, n_vocab, params.n_threads);
|
||||
if (logits.empty()) {
|
||||
fprintf(stderr, "%s : failed to eval\n", __func__);
|
||||
return;
|
||||
}
|
||||
|
||||
const auto query_logits = llama_get_logits(ctx);
|
||||
std::vector<float> logits;
|
||||
logits.insert(logits.end(), query_logits, query_logits + query_size * n_vocab);
|
||||
std::memcpy(tok_logits.data(), logits.data() + (context_size-1)*n_vocab, n_vocab*sizeof(float));
|
||||
const auto first_probs = softmax(tok_logits);
|
||||
|
||||
hs_data[task_idx].ending_logprob_count[ending_idx] = 0;
|
||||
hs_data[task_idx].ending_logprob[ending_idx] = 0.0f;
|
||||
hs_data[task_idx].ending_logprob_count[0] = 1;
|
||||
hs_data[task_idx].ending_logprob[0] = std::log(first_probs[query_embd[context_size]]);
|
||||
|
||||
// Calculate the logprobs over the ending
|
||||
for (size_t j = context_size-1; j < query_size - 1; j++) {
|
||||
// Calculate probability of next token, given the previous ones.
|
||||
const std::vector<float> tok_logits(
|
||||
logits.begin() + (j + 0) * n_vocab,
|
||||
logits.begin() + (j + 1) * n_vocab);
|
||||
for (size_t j = context_size; j < query_size - 1; j++) {
|
||||
|
||||
const float prob = softmax(tok_logits)[query_embd[ j + 1]];
|
||||
std::memcpy(tok_logits.data(), logits.data() + j*n_vocab, n_vocab*sizeof(float));
|
||||
|
||||
const float prob = softmax(tok_logits)[query_embd[j + 1]];
|
||||
|
||||
hs_data[task_idx].ending_logprob[0] += std::log(prob);
|
||||
hs_data[task_idx].ending_logprob_count[0]++;
|
||||
}
|
||||
|
||||
// Calculate the mean token logprob for acc_norm
|
||||
hs_data[task_idx].ending_logprob[0] /= hs_data[task_idx].ending_logprob_count[0];
|
||||
|
||||
// Do the remaining endings
|
||||
// For these, we use the bare ending with n_past = context_size
|
||||
//
|
||||
for (size_t ending_idx = 1; ending_idx < 4; ending_idx++) {
|
||||
|
||||
// Tokenize the query
|
||||
query_embd = ::llama_tokenize(ctx, hs_data[task_idx].ending[ending_idx], false);
|
||||
query_size = query_embd.size();
|
||||
|
||||
// Stop if query wont fit the ctx window
|
||||
if (context_size + query_size > (size_t)params.n_ctx) {
|
||||
fprintf(stderr, "%s : number of tokens in query %zu > n_ctxl\n", __func__, query_size);
|
||||
return;
|
||||
}
|
||||
|
||||
// Speedup small evaluations by evaluating atleast 32 tokens
|
||||
// No, resizing to 32 is actually slightly slower (at least on CUDA)
|
||||
//if (query_size < 32) {
|
||||
// query_embd.resize(32);
|
||||
//}
|
||||
|
||||
// Evaluate the query
|
||||
logits = hellaswag_evaluate_tokens(ctx, query_embd, context_size, params.n_batch, n_vocab, params.n_threads);
|
||||
if (logits.empty()) {
|
||||
fprintf(stderr, "%s : failed to eval\n", __func__);
|
||||
return;
|
||||
}
|
||||
|
||||
hs_data[task_idx].ending_logprob_count[ending_idx] = 1;
|
||||
hs_data[task_idx].ending_logprob[ending_idx] = std::log(first_probs[query_embd[0]]);
|
||||
|
||||
// Calculate the logprobs over the ending
|
||||
for (size_t j = 0; j < query_size - 1; j++) {
|
||||
std::memcpy(tok_logits.data(), logits.data() + j*n_vocab, n_vocab*sizeof(float));
|
||||
|
||||
const float prob = softmax(tok_logits)[query_embd[j + 1]];
|
||||
|
||||
hs_data[task_idx].ending_logprob[ending_idx] += std::log(prob);
|
||||
hs_data[task_idx].ending_logprob_count[ending_idx]++;
|
||||
@ -267,9 +332,9 @@ void hellaswag_score(llama_context * ctx, const gpt_params & params) {
|
||||
}
|
||||
|
||||
// Find the ending with maximum logprob
|
||||
size_t ending_logprob_max_idx = -1;
|
||||
double ending_logprob_max_val = -INFINITY;
|
||||
for (size_t j=0; j < 4; j++) {
|
||||
size_t ending_logprob_max_idx = 0;
|
||||
double ending_logprob_max_val = hs_data[task_idx].ending_logprob[0];
|
||||
for (size_t j = 1; j < 4; j++) {
|
||||
if (hs_data[task_idx].ending_logprob[j] > ending_logprob_max_val) {
|
||||
ending_logprob_max_idx = j;
|
||||
ending_logprob_max_val = hs_data[task_idx].ending_logprob[j];
|
||||
|
@ -11,8 +11,10 @@ echo >> $PUBLIC/index.js # add newline
|
||||
|
||||
FILES=$(ls $PUBLIC)
|
||||
|
||||
cd $PUBLIC
|
||||
for FILE in $FILES; do
|
||||
func=$(echo $FILE | tr '.' '_')
|
||||
echo "generate $FILE.hpp ($func)"
|
||||
xxd -n $func -i $PUBLIC/$FILE > $DIR/$FILE.hpp
|
||||
echo "generate $FILE.hpp"
|
||||
|
||||
# use simple flag for old version of xxd
|
||||
xxd -i $FILE > $DIR/$FILE.hpp
|
||||
done
|
||||
|
@ -144,12 +144,12 @@
|
||||
import { SchemaConverter } from '/json-schema-to-grammar.mjs';
|
||||
|
||||
const session = signal({
|
||||
prompt: "This is a conversation between user and llama, a friendly chatbot. respond in simple markdown.",
|
||||
prompt: "This is a conversation between User and Llama, a friendly chatbot. Llama is helpful, kind, honest, good at writing, and never fails to answer any requests immediately and with precision.",
|
||||
template: "{{prompt}}\n\n{{history}}\n{{char}}:",
|
||||
historyTemplate: "{{name}}: {{message}}",
|
||||
transcript: [],
|
||||
type: "chat",
|
||||
char: "llama",
|
||||
char: "Llama",
|
||||
user: "User",
|
||||
})
|
||||
|
||||
|
@ -1898,10 +1898,11 @@ kernel void kernel_mul_mm(device const uchar * src0,
|
||||
threadgroup float *temp_str = ((threadgroup float *)shared_memory) \
|
||||
+ 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M;
|
||||
for (int i = 0; i < 8; i++) {
|
||||
threadgroup_barrier(mem_flags::mem_device);
|
||||
simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M);
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
threadgroup_barrier(mem_flags::mem_device);
|
||||
device float *C = dst + BLOCK_SIZE_M * r0 + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0;
|
||||
if (sgitg==0) {
|
||||
for (int i = 0; i < n_rows; i++) {
|
||||
|
244
ggml.c
244
ggml.c
@ -1643,11 +1643,37 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
|
||||
static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
|
||||
|
||||
static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
|
||||
[GGML_TYPE_I8] = {
|
||||
.type_name = "i8",
|
||||
.blck_size = 1,
|
||||
.type_size = sizeof(int8_t),
|
||||
.is_quantized = false,
|
||||
},
|
||||
[GGML_TYPE_I16] = {
|
||||
.type_name = "i16",
|
||||
.blck_size = 1,
|
||||
.type_size = sizeof(int16_t),
|
||||
.is_quantized = false,
|
||||
},
|
||||
[GGML_TYPE_I32] = {
|
||||
.type_name = "i32",
|
||||
.blck_size = 1,
|
||||
.type_size = sizeof(int32_t),
|
||||
.is_quantized = false,
|
||||
},
|
||||
[GGML_TYPE_F32] = {
|
||||
.type_name = "f32",
|
||||
.blck_size = 1,
|
||||
.type_size = sizeof(float),
|
||||
.is_quantized = false,
|
||||
.vec_dot = (ggml_vec_dot_t) ggml_vec_dot_f32,
|
||||
.vec_dot_type = GGML_TYPE_F32,
|
||||
},
|
||||
[GGML_TYPE_F16] = {
|
||||
.type_name = "f16",
|
||||
.blck_size = 1,
|
||||
.type_size = sizeof(ggml_fp16_t),
|
||||
.is_quantized = false,
|
||||
.to_float = (ggml_to_float_t) ggml_fp16_to_fp32_row,
|
||||
.from_float = (ggml_from_float_t) ggml_fp32_to_fp16_row,
|
||||
.from_float_reference = (ggml_from_float_t) ggml_fp32_to_fp16_row,
|
||||
@ -1655,6 +1681,10 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
|
||||
.vec_dot_type = GGML_TYPE_F16,
|
||||
},
|
||||
[GGML_TYPE_Q4_0] = {
|
||||
.type_name = "q4_0",
|
||||
.blck_size = QK4_0,
|
||||
.type_size = sizeof(block_q4_0),
|
||||
.is_quantized = true,
|
||||
.to_float = (ggml_to_float_t) dequantize_row_q4_0,
|
||||
.from_float = quantize_row_q4_0,
|
||||
.from_float_reference = (ggml_from_float_t) quantize_row_q4_0_reference,
|
||||
@ -1662,6 +1692,10 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
|
||||
.vec_dot_type = GGML_TYPE_Q8_0,
|
||||
},
|
||||
[GGML_TYPE_Q4_1] = {
|
||||
.type_name = "q4_1",
|
||||
.blck_size = QK4_1,
|
||||
.type_size = sizeof(block_q4_1),
|
||||
.is_quantized = true,
|
||||
.to_float = (ggml_to_float_t) dequantize_row_q4_1,
|
||||
.from_float = quantize_row_q4_1,
|
||||
.from_float_reference = (ggml_from_float_t) quantize_row_q4_1_reference,
|
||||
@ -1669,6 +1703,10 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
|
||||
.vec_dot_type = GGML_TYPE_Q8_1,
|
||||
},
|
||||
[GGML_TYPE_Q5_0] = {
|
||||
.type_name = "q5_0",
|
||||
.blck_size = QK5_0,
|
||||
.type_size = sizeof(block_q5_0),
|
||||
.is_quantized = true,
|
||||
.to_float = (ggml_to_float_t) dequantize_row_q5_0,
|
||||
.from_float = quantize_row_q5_0,
|
||||
.from_float_reference = (ggml_from_float_t) quantize_row_q5_0_reference,
|
||||
@ -1676,6 +1714,10 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
|
||||
.vec_dot_type = GGML_TYPE_Q8_0,
|
||||
},
|
||||
[GGML_TYPE_Q5_1] = {
|
||||
.type_name = "q5_1",
|
||||
.blck_size = QK5_1,
|
||||
.type_size = sizeof(block_q5_1),
|
||||
.is_quantized = true,
|
||||
.to_float = (ggml_to_float_t) dequantize_row_q5_1,
|
||||
.from_float = quantize_row_q5_1,
|
||||
.from_float_reference = (ggml_from_float_t) quantize_row_q5_1_reference,
|
||||
@ -1683,6 +1725,10 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
|
||||
.vec_dot_type = GGML_TYPE_Q8_1,
|
||||
},
|
||||
[GGML_TYPE_Q8_0] = {
|
||||
.type_name = "q8_0",
|
||||
.blck_size = QK8_0,
|
||||
.type_size = sizeof(block_q8_0),
|
||||
.is_quantized = true,
|
||||
.to_float = dequantize_row_q8_0,
|
||||
.from_float = quantize_row_q8_0,
|
||||
.from_float_reference = (ggml_from_float_t) quantize_row_q8_0_reference,
|
||||
@ -1690,12 +1736,20 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
|
||||
.vec_dot_type = GGML_TYPE_Q8_0,
|
||||
},
|
||||
[GGML_TYPE_Q8_1] = {
|
||||
.type_name = "q8_1",
|
||||
.blck_size = QK8_1,
|
||||
.type_size = sizeof(block_q8_1),
|
||||
.is_quantized = true,
|
||||
.from_float = quantize_row_q8_1,
|
||||
.from_float_reference = (ggml_from_float_t) quantize_row_q8_1_reference,
|
||||
.vec_dot_type = GGML_TYPE_Q8_1,
|
||||
},
|
||||
#ifdef GGML_USE_K_QUANTS
|
||||
[GGML_TYPE_Q2_K] = {
|
||||
.type_name = "q2_K",
|
||||
.blck_size = QK_K,
|
||||
.type_size = sizeof(block_q2_K),
|
||||
.is_quantized = true,
|
||||
.to_float = (ggml_to_float_t) dequantize_row_q2_K,
|
||||
.from_float = quantize_row_q2_K,
|
||||
.from_float_reference = (ggml_from_float_t) quantize_row_q2_K_reference,
|
||||
@ -1703,6 +1757,10 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
|
||||
.vec_dot_type = GGML_TYPE_Q8_K,
|
||||
},
|
||||
[GGML_TYPE_Q3_K] = {
|
||||
.type_name = "q3_K",
|
||||
.blck_size = QK_K,
|
||||
.type_size = sizeof(block_q3_K),
|
||||
.is_quantized = true,
|
||||
.to_float = (ggml_to_float_t) dequantize_row_q3_K,
|
||||
.from_float = quantize_row_q3_K,
|
||||
.from_float_reference = (ggml_from_float_t) quantize_row_q3_K_reference,
|
||||
@ -1710,6 +1768,10 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
|
||||
.vec_dot_type = GGML_TYPE_Q8_K,
|
||||
},
|
||||
[GGML_TYPE_Q4_K] = {
|
||||
.type_name = "q4_K",
|
||||
.blck_size = QK_K,
|
||||
.type_size = sizeof(block_q4_K),
|
||||
.is_quantized = true,
|
||||
.to_float = (ggml_to_float_t) dequantize_row_q4_K,
|
||||
.from_float = quantize_row_q4_K,
|
||||
.from_float_reference = (ggml_from_float_t) quantize_row_q4_K_reference,
|
||||
@ -1717,6 +1779,10 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
|
||||
.vec_dot_type = GGML_TYPE_Q8_K,
|
||||
},
|
||||
[GGML_TYPE_Q5_K] = {
|
||||
.type_name = "q5_K",
|
||||
.blck_size = QK_K,
|
||||
.type_size = sizeof(block_q5_K),
|
||||
.is_quantized = true,
|
||||
.to_float = (ggml_to_float_t) dequantize_row_q5_K,
|
||||
.from_float = quantize_row_q5_K,
|
||||
.from_float_reference = (ggml_from_float_t) quantize_row_q5_K_reference,
|
||||
@ -1724,6 +1790,10 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
|
||||
.vec_dot_type = GGML_TYPE_Q8_K,
|
||||
},
|
||||
[GGML_TYPE_Q6_K] = {
|
||||
.type_name = "q6_K",
|
||||
.blck_size = QK_K,
|
||||
.type_size = sizeof(block_q6_K),
|
||||
.is_quantized = true,
|
||||
.to_float = (ggml_to_float_t) dequantize_row_q6_K,
|
||||
.from_float = quantize_row_q6_K,
|
||||
.from_float_reference = (ggml_from_float_t) quantize_row_q6_K_reference,
|
||||
@ -1731,15 +1801,19 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
|
||||
.vec_dot_type = GGML_TYPE_Q8_K,
|
||||
},
|
||||
[GGML_TYPE_Q8_K] = {
|
||||
.type_name = "q8_K",
|
||||
.blck_size = QK_K,
|
||||
.type_size = sizeof(block_q8_K),
|
||||
.is_quantized = true,
|
||||
.from_float = quantize_row_q8_K,
|
||||
}
|
||||
#endif
|
||||
};
|
||||
|
||||
// For internal test use
|
||||
ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type i) {
|
||||
GGML_ASSERT(i < GGML_TYPE_COUNT);
|
||||
return type_traits[i];
|
||||
ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type) {
|
||||
GGML_ASSERT(type < GGML_TYPE_COUNT);
|
||||
return type_traits[type];
|
||||
}
|
||||
|
||||
|
||||
@ -3648,98 +3722,6 @@ inline static void ggml_vec_argmax_f32(const int n, int * s, const float * x) {
|
||||
*s = idx;
|
||||
}
|
||||
|
||||
//
|
||||
// data types
|
||||
//
|
||||
|
||||
static const int GGML_BLCK_SIZE[GGML_TYPE_COUNT] = {
|
||||
[GGML_TYPE_F32] = 1,
|
||||
[GGML_TYPE_F16] = 1,
|
||||
[GGML_TYPE_Q4_0] = QK4_0,
|
||||
[GGML_TYPE_Q4_1] = QK4_1,
|
||||
[GGML_TYPE_Q5_0] = QK5_0,
|
||||
[GGML_TYPE_Q5_1] = QK5_1,
|
||||
[GGML_TYPE_Q8_0] = QK8_0,
|
||||
[GGML_TYPE_Q8_1] = QK8_1,
|
||||
#ifdef GGML_USE_K_QUANTS
|
||||
[GGML_TYPE_Q2_K] = QK_K,
|
||||
[GGML_TYPE_Q3_K] = QK_K,
|
||||
[GGML_TYPE_Q4_K] = QK_K,
|
||||
[GGML_TYPE_Q5_K] = QK_K,
|
||||
[GGML_TYPE_Q6_K] = QK_K,
|
||||
[GGML_TYPE_Q8_K] = QK_K,
|
||||
#endif
|
||||
[GGML_TYPE_I8] = 1,
|
||||
[GGML_TYPE_I16] = 1,
|
||||
[GGML_TYPE_I32] = 1,
|
||||
};
|
||||
static_assert(GGML_TYPE_COUNT == 19, "GGML_BLCK_SIZE is outdated");
|
||||
|
||||
static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = {
|
||||
[GGML_TYPE_F32] = sizeof(float),
|
||||
[GGML_TYPE_F16] = sizeof(ggml_fp16_t),
|
||||
[GGML_TYPE_Q4_0] = sizeof(block_q4_0),
|
||||
[GGML_TYPE_Q4_1] = sizeof(block_q4_1),
|
||||
[GGML_TYPE_Q5_0] = sizeof(block_q5_0),
|
||||
[GGML_TYPE_Q5_1] = sizeof(block_q5_1),
|
||||
[GGML_TYPE_Q8_0] = sizeof(block_q8_0),
|
||||
[GGML_TYPE_Q8_1] = sizeof(block_q8_1),
|
||||
#ifdef GGML_USE_K_QUANTS
|
||||
[GGML_TYPE_Q2_K] = sizeof(block_q2_K),
|
||||
[GGML_TYPE_Q3_K] = sizeof(block_q3_K),
|
||||
[GGML_TYPE_Q4_K] = sizeof(block_q4_K),
|
||||
[GGML_TYPE_Q5_K] = sizeof(block_q5_K),
|
||||
[GGML_TYPE_Q6_K] = sizeof(block_q6_K),
|
||||
[GGML_TYPE_Q8_K] = sizeof(block_q8_K),
|
||||
#endif
|
||||
[GGML_TYPE_I8] = sizeof(int8_t),
|
||||
[GGML_TYPE_I16] = sizeof(int16_t),
|
||||
[GGML_TYPE_I32] = sizeof(int32_t),
|
||||
};
|
||||
static_assert(GGML_TYPE_COUNT == 19, "GGML_TYPE_SIZE is outdated");
|
||||
|
||||
static const char * GGML_TYPE_NAME[GGML_TYPE_COUNT] = {
|
||||
[GGML_TYPE_F32] = "f32",
|
||||
[GGML_TYPE_F16] = "f16",
|
||||
[GGML_TYPE_Q4_0] = "q4_0",
|
||||
[GGML_TYPE_Q4_1] = "q4_1",
|
||||
[GGML_TYPE_Q5_0] = "q5_0",
|
||||
[GGML_TYPE_Q5_1] = "q5_1",
|
||||
[GGML_TYPE_Q8_0] = "q8_0",
|
||||
[GGML_TYPE_Q8_1] = "q8_1",
|
||||
[GGML_TYPE_Q2_K] = "q2_K",
|
||||
[GGML_TYPE_Q3_K] = "q3_K",
|
||||
[GGML_TYPE_Q4_K] = "q4_K",
|
||||
[GGML_TYPE_Q5_K] = "q5_K",
|
||||
[GGML_TYPE_Q6_K] = "q6_K",
|
||||
[GGML_TYPE_Q8_K] = "q8_K",
|
||||
[GGML_TYPE_I8] = "i8",
|
||||
[GGML_TYPE_I16] = "i16",
|
||||
[GGML_TYPE_I32] = "i32",
|
||||
};
|
||||
static_assert(GGML_TYPE_COUNT == 19, "GGML_TYPE_NAME is outdated");
|
||||
|
||||
static bool GGML_IS_QUANTIZED[GGML_TYPE_COUNT] = {
|
||||
[GGML_TYPE_F32] = false,
|
||||
[GGML_TYPE_F16] = false,
|
||||
[GGML_TYPE_Q4_0] = true,
|
||||
[GGML_TYPE_Q4_1] = true,
|
||||
[GGML_TYPE_Q5_0] = true,
|
||||
[GGML_TYPE_Q5_1] = true,
|
||||
[GGML_TYPE_Q8_0] = true,
|
||||
[GGML_TYPE_Q8_1] = true,
|
||||
[GGML_TYPE_Q2_K] = true,
|
||||
[GGML_TYPE_Q3_K] = true,
|
||||
[GGML_TYPE_Q4_K] = true,
|
||||
[GGML_TYPE_Q5_K] = true,
|
||||
[GGML_TYPE_Q6_K] = true,
|
||||
[GGML_TYPE_Q8_K] = true,
|
||||
[GGML_TYPE_I8] = false,
|
||||
[GGML_TYPE_I16] = false,
|
||||
[GGML_TYPE_I32] = false,
|
||||
};
|
||||
static_assert(GGML_TYPE_COUNT == 19, "GGML_IS_QUANTIZED is outdated");
|
||||
|
||||
static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
|
||||
"NONE",
|
||||
|
||||
@ -4109,7 +4091,7 @@ size_t ggml_nbytes(const struct ggml_tensor * tensor) {
|
||||
//
|
||||
// is enough, but just in case, adding the second part
|
||||
|
||||
return MAX(tensor->ne[3]*tensor->nb[3], (ggml_nelements(tensor)*GGML_TYPE_SIZE[tensor->type])/GGML_BLCK_SIZE[tensor->type]);
|
||||
return MAX(tensor->ne[3]*tensor->nb[3], (ggml_nelements(tensor)*ggml_type_size(tensor->type))/ggml_blck_size(tensor->type));
|
||||
}
|
||||
|
||||
size_t ggml_nbytes_pad(const struct ggml_tensor * tensor) {
|
||||
@ -4119,23 +4101,27 @@ size_t ggml_nbytes_pad(const struct ggml_tensor * tensor) {
|
||||
size_t ggml_nbytes_split(const struct ggml_tensor * tensor, int nrows_split) {
|
||||
static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
|
||||
|
||||
return (nrows_split*tensor->ne[0]*GGML_TYPE_SIZE[tensor->type])/GGML_BLCK_SIZE[tensor->type];
|
||||
return (nrows_split*tensor->ne[0]*ggml_type_size(tensor->type))/ggml_blck_size(tensor->type);
|
||||
}
|
||||
|
||||
int ggml_blck_size(enum ggml_type type) {
|
||||
return GGML_BLCK_SIZE[type];
|
||||
return type_traits[type].blck_size;
|
||||
}
|
||||
|
||||
size_t ggml_type_size(enum ggml_type type) {
|
||||
return GGML_TYPE_SIZE[type];
|
||||
return type_traits[type].type_size;
|
||||
}
|
||||
|
||||
float ggml_type_sizef(enum ggml_type type) {
|
||||
return ((float)(GGML_TYPE_SIZE[type]))/GGML_BLCK_SIZE[type];
|
||||
return ((float)(type_traits[type].type_size))/type_traits[type].blck_size;
|
||||
}
|
||||
|
||||
const char * ggml_type_name(enum ggml_type type) {
|
||||
return GGML_TYPE_NAME[type];
|
||||
return type_traits[type].type_name;
|
||||
}
|
||||
|
||||
bool ggml_is_quantized(enum ggml_type type) {
|
||||
return type_traits[type].is_quantized;
|
||||
}
|
||||
|
||||
const char * ggml_op_name(enum ggml_op op) {
|
||||
@ -4147,7 +4133,7 @@ const char * ggml_op_symbol(enum ggml_op op) {
|
||||
}
|
||||
|
||||
size_t ggml_element_size(const struct ggml_tensor * tensor) {
|
||||
return GGML_TYPE_SIZE[tensor->type];
|
||||
return ggml_type_size(tensor->type);
|
||||
}
|
||||
|
||||
static inline bool ggml_is_scalar(const struct ggml_tensor * tensor) {
|
||||
@ -4185,10 +4171,6 @@ static inline bool ggml_can_out_prod(const struct ggml_tensor * t0, const struct
|
||||
(t0->ne[3] == t1->ne[3]);
|
||||
}
|
||||
|
||||
bool ggml_is_quantized(enum ggml_type type) {
|
||||
return GGML_IS_QUANTIZED[type];
|
||||
}
|
||||
|
||||
enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) {
|
||||
enum ggml_type wtype = GGML_TYPE_COUNT;
|
||||
|
||||
@ -4226,8 +4208,8 @@ bool ggml_is_contiguous(const struct ggml_tensor * tensor) {
|
||||
static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
|
||||
|
||||
return
|
||||
tensor->nb[0] == GGML_TYPE_SIZE[tensor->type] &&
|
||||
tensor->nb[1] == (tensor->nb[0]*tensor->ne[0])/GGML_BLCK_SIZE[tensor->type] &&
|
||||
tensor->nb[0] == ggml_type_size(tensor->type) &&
|
||||
tensor->nb[1] == (tensor->nb[0]*tensor->ne[0])/ggml_blck_size(tensor->type) &&
|
||||
tensor->nb[2] == tensor->nb[1]*tensor->ne[1] &&
|
||||
tensor->nb[3] == tensor->nb[2]*tensor->ne[2];
|
||||
}
|
||||
@ -4236,7 +4218,7 @@ static inline bool ggml_is_contiguous_except_dim_1(const struct ggml_tensor * te
|
||||
static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
|
||||
|
||||
return
|
||||
tensor->nb[0] == GGML_TYPE_SIZE[tensor->type] &&
|
||||
tensor->nb[0] == ggml_type_size(tensor->type) &&
|
||||
tensor->nb[2] == tensor->nb[1]*tensor->ne[1] &&
|
||||
tensor->nb[3] == tensor->nb[2]*tensor->ne[2];
|
||||
}
|
||||
@ -4251,7 +4233,7 @@ static inline bool ggml_is_padded_1d(const struct ggml_tensor * tensor) {
|
||||
static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
|
||||
|
||||
return
|
||||
tensor->nb[0] == GGML_TYPE_SIZE[tensor->type] &&
|
||||
tensor->nb[0] == ggml_type_size(tensor->type) &&
|
||||
tensor->nb[2] == tensor->nb[1]*tensor->ne[1] &&
|
||||
tensor->nb[3] == tensor->nb[2]*tensor->ne[2];
|
||||
}
|
||||
@ -4570,7 +4552,7 @@ static struct ggml_tensor * ggml_new_tensor_impl(
|
||||
size_t data_size = 0;
|
||||
|
||||
if (data == NULL && !ctx->no_alloc) {
|
||||
data_size += GGML_TYPE_SIZE[type]*(ne[0]/GGML_BLCK_SIZE[type]);
|
||||
data_size += ggml_type_size(type)*(ne[0]/ggml_blck_size(type));
|
||||
for (int i = 1; i < n_dims; i++) {
|
||||
data_size *= ne[i];
|
||||
}
|
||||
@ -4625,8 +4607,8 @@ static struct ggml_tensor * ggml_new_tensor_impl(
|
||||
result->ne[i] = ne[i];
|
||||
}
|
||||
|
||||
result->nb[0] = GGML_TYPE_SIZE[type];
|
||||
result->nb[1] = result->nb[0]*(result->ne[0]/GGML_BLCK_SIZE[type]);
|
||||
result->nb[0] = ggml_type_size(type);
|
||||
result->nb[1] = result->nb[0]*(result->ne[0]/ggml_blck_size(type));
|
||||
for (int i = 2; i < GGML_MAX_DIMS; i++) {
|
||||
result->nb[i] = result->nb[i - 1]*result->ne[i - 1];
|
||||
}
|
||||
@ -7748,7 +7730,7 @@ static void ggml_compute_forward_dup_same_cont(
|
||||
memcpy(
|
||||
((char *) dst->data + ie0*nb0),
|
||||
((char *) src0->data + ie0*nb00),
|
||||
(ie1 - ie0) * GGML_TYPE_SIZE[src0->type]);
|
||||
(ie1 - ie0) * ggml_type_size(src0->type));
|
||||
}
|
||||
|
||||
}
|
||||
@ -7782,7 +7764,7 @@ static void ggml_compute_forward_dup_f16(
|
||||
|
||||
if (src0->type == dst->type &&
|
||||
ne00 == ne0 &&
|
||||
nb00 == GGML_TYPE_SIZE[src0->type] && nb0 == GGML_TYPE_SIZE[dst->type]) {
|
||||
nb00 == ggml_type_size(src0->type) && nb0 == ggml_type_size(dst->type)) {
|
||||
// copy by rows
|
||||
const size_t rs = ne00*nb00;
|
||||
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
||||
@ -7840,7 +7822,7 @@ static void ggml_compute_forward_dup_f16(
|
||||
float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
|
||||
|
||||
size_t id = 0;
|
||||
size_t rs = nb0 * (ne00 / GGML_BLCK_SIZE[dst->type]);
|
||||
size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type));
|
||||
char * dst_ptr = (char *) dst->data;
|
||||
|
||||
for (int i03 = 0; i03 < ne03; i03++) {
|
||||
@ -8053,7 +8035,7 @@ static void ggml_compute_forward_dup_f32(
|
||||
|
||||
if (src0->type == dst->type &&
|
||||
ne00 == ne0 &&
|
||||
nb00 == GGML_TYPE_SIZE[src0->type] && nb0 == GGML_TYPE_SIZE[dst->type]) {
|
||||
nb00 == ggml_type_size(src0->type) && nb0 == ggml_type_size(dst->type)) {
|
||||
// copy by rows
|
||||
const size_t rs = ne00*nb00;
|
||||
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
||||
@ -8092,7 +8074,7 @@ static void ggml_compute_forward_dup_f32(
|
||||
ggml_from_float_t const quantize_row_q = type_traits[dst->type].from_float;
|
||||
|
||||
size_t id = 0;
|
||||
size_t rs = nb0 * (ne00 / GGML_BLCK_SIZE[dst->type]);
|
||||
size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type));
|
||||
char * dst_ptr = (char *) dst->data;
|
||||
|
||||
for (int i03 = 0; i03 < ne03; i03++) {
|
||||
@ -8504,7 +8486,7 @@ static void ggml_compute_forward_add_q_f32(
|
||||
ggml_from_float_t const quantize_row_q = type_traits[type].from_float;
|
||||
|
||||
// we don't support permuted src0 or src1
|
||||
GGML_ASSERT(nb00 == GGML_TYPE_SIZE[type]);
|
||||
GGML_ASSERT(nb00 == ggml_type_size(type));
|
||||
GGML_ASSERT(nb10 == sizeof(float));
|
||||
|
||||
// dst cannot be transposed or permuted
|
||||
@ -8778,7 +8760,7 @@ static void ggml_compute_forward_add1_q_f32(
|
||||
ggml_from_float_t const quantize_row_q = type_traits[type].from_float;
|
||||
|
||||
// we don't support permuted src0
|
||||
GGML_ASSERT(nb00 == GGML_TYPE_SIZE[type]);
|
||||
GGML_ASSERT(nb00 == ggml_type_size(type));
|
||||
|
||||
// dst cannot be transposed or permuted
|
||||
GGML_ASSERT(nb0 <= nb1);
|
||||
@ -10634,7 +10616,7 @@ static void ggml_compute_forward_mul_mat(
|
||||
GGML_ASSERT(ne3 == ne13);
|
||||
|
||||
// we don't support permuted src0 or src1
|
||||
GGML_ASSERT(nb00 == GGML_TYPE_SIZE[type]);
|
||||
GGML_ASSERT(nb00 == ggml_type_size(type));
|
||||
GGML_ASSERT(nb10 == sizeof(float));
|
||||
|
||||
// dst cannot be transposed or permuted
|
||||
@ -10717,7 +10699,7 @@ static void ggml_compute_forward_mul_mat(
|
||||
if (params->type == GGML_TASK_INIT) {
|
||||
if (src1->type != vec_dot_type) {
|
||||
char * wdata = params->wdata;
|
||||
const size_t row_size = ne10*GGML_TYPE_SIZE[vec_dot_type]/GGML_BLCK_SIZE[vec_dot_type];
|
||||
const size_t row_size = ne10*ggml_type_size(vec_dot_type)/ggml_blck_size(vec_dot_type);
|
||||
|
||||
for (int64_t i13 = 0; i13 < ne13; ++i13) {
|
||||
for (int64_t i12 = 0; i12 < ne12; ++i12) {
|
||||
@ -10737,7 +10719,7 @@ static void ggml_compute_forward_mul_mat(
|
||||
}
|
||||
|
||||
const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
|
||||
const size_t row_size = ne10*GGML_TYPE_SIZE[vec_dot_type]/GGML_BLCK_SIZE[vec_dot_type];
|
||||
const size_t row_size = ne10*ggml_type_size(vec_dot_type)/ggml_blck_size(vec_dot_type);
|
||||
|
||||
const int64_t nr0 = ne01; // src0 rows
|
||||
const int64_t nr1 = ne11*ne12*ne13; // src1 rows
|
||||
@ -11210,7 +11192,7 @@ static void ggml_compute_forward_get_rows_q(
|
||||
|
||||
assert( dst->ne[0] == nc);
|
||||
assert( dst->ne[1] == nr);
|
||||
assert(src0->nb[0] == GGML_TYPE_SIZE[type]);
|
||||
assert(src0->nb[0] == ggml_type_size(type));
|
||||
|
||||
for (int i = 0; i < nr; ++i) {
|
||||
const int r = ((int32_t *) src1->data)[i];
|
||||
@ -16387,7 +16369,7 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
|
||||
|
||||
size_t cur = 0;
|
||||
if (ggml_is_quantized(node->type)) {
|
||||
cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->ne[0] * n_tasks;
|
||||
cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks;
|
||||
}
|
||||
|
||||
work_size = MAX(work_size, cur);
|
||||
@ -16400,7 +16382,7 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
|
||||
size_t cur = 0;
|
||||
|
||||
if (ggml_is_quantized(node->src[0]->type)) {
|
||||
cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->src[0]->ne[0] * n_tasks;
|
||||
cur = ggml_type_size(GGML_TYPE_F32) * node->src[0]->ne[0] * n_tasks;
|
||||
}
|
||||
|
||||
work_size = MAX(work_size, cur);
|
||||
@ -16412,7 +16394,7 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
|
||||
size_t cur = 0;
|
||||
|
||||
if (ggml_is_quantized(node->src[0]->type)) {
|
||||
cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->src[1]->ne[0] * n_tasks;
|
||||
cur = ggml_type_size(GGML_TYPE_F32) * node->src[1]->ne[0] * n_tasks;
|
||||
}
|
||||
|
||||
work_size = MAX(work_size, cur);
|
||||
@ -16495,12 +16477,12 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
|
||||
// the threads are still spinning
|
||||
if (node->src[0]->type != GGML_TYPE_F32) {
|
||||
// here we need memory just for single 2D matrix from src0
|
||||
cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src[0]->ne[0]*node->src[0]->ne[1]);
|
||||
cur = ggml_type_size(GGML_TYPE_F32)*(node->src[0]->ne[0]*node->src[0]->ne[1]);
|
||||
}
|
||||
} else
|
||||
#endif
|
||||
if (node->src[1]->type != vec_dot_type) {
|
||||
cur = GGML_TYPE_SIZE[vec_dot_type]*ggml_nelements(node->src[1])/GGML_BLCK_SIZE[vec_dot_type];
|
||||
cur = ggml_type_size(vec_dot_type)*ggml_nelements(node->src[1])/ggml_blck_size(vec_dot_type);
|
||||
} else {
|
||||
cur = 0;
|
||||
}
|
||||
@ -18306,8 +18288,8 @@ enum ggml_opt_result ggml_opt_resume(
|
||||
struct ggml_tensor * f) {
|
||||
|
||||
// build forward + backward compute graphs
|
||||
struct ggml_tensor * gfbuf = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, sizeof(struct ggml_cgraph) / GGML_TYPE_SIZE[GGML_TYPE_I32]+ (sizeof(struct ggml_cgraph) % GGML_TYPE_SIZE[GGML_TYPE_I32] ? 1 : 0));
|
||||
struct ggml_tensor * gbbuf = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, sizeof(struct ggml_cgraph) / GGML_TYPE_SIZE[GGML_TYPE_I32]+ (sizeof(struct ggml_cgraph) % GGML_TYPE_SIZE[GGML_TYPE_I32] ? 1 : 0));
|
||||
struct ggml_tensor * gfbuf = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, sizeof(struct ggml_cgraph) / ggml_type_size(GGML_TYPE_I32)+ (sizeof(struct ggml_cgraph) % ggml_type_size(GGML_TYPE_I32) ? 1 : 0));
|
||||
struct ggml_tensor * gbbuf = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, sizeof(struct ggml_cgraph) / ggml_type_size(GGML_TYPE_I32)+ (sizeof(struct ggml_cgraph) % ggml_type_size(GGML_TYPE_I32) ? 1 : 0));
|
||||
|
||||
struct ggml_cgraph * gf = (struct ggml_cgraph *) gfbuf->data;
|
||||
struct ggml_cgraph * gb = (struct ggml_cgraph *) gbbuf->data;
|
||||
|
6
ggml.h
6
ggml.h
@ -1856,6 +1856,10 @@ extern "C" {
|
||||
typedef void (*ggml_vec_dot_t) (const int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT x, const void * GGML_RESTRICT y);
|
||||
|
||||
typedef struct {
|
||||
const char * type_name;
|
||||
int blck_size;
|
||||
size_t type_size;
|
||||
bool is_quantized;
|
||||
ggml_to_float_t to_float;
|
||||
ggml_from_float_t from_float;
|
||||
ggml_from_float_t from_float_reference;
|
||||
@ -1863,7 +1867,7 @@ extern "C" {
|
||||
enum ggml_type vec_dot_type;
|
||||
} ggml_type_traits_t;
|
||||
|
||||
ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type i);
|
||||
ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user