Compare commits

...

11 Commits

Author SHA1 Message Date
Zhenwei Jin
b42f7a1450
Merge 25d4599e19 into ecd5d6b65b 2024-09-22 07:46:54 +02:00
Shankar
ecd5d6b65b
llama: remove redundant loop when constructing ubatch (#9574)
Some checks are pending
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/full-cuda.Dockerfile platforms:linux/amd64 tag:full-cuda]) (push) Waiting to run
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/full.Dockerfile platforms:linux/amd64,linux/arm64 tag:full]) (push) Waiting to run
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/llama-cli-cuda.Dockerfile platforms:linux/amd64 tag:light-cuda]) (push) Waiting to run
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/llama-cli-intel.Dockerfile platforms:linux/amd64 tag:light-intel]) (push) Waiting to run
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/llama-cli.Dockerfile platforms:linux/amd64,linux/arm64 tag:light]) (push) Waiting to run
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/llama-server-cuda.Dockerfile platforms:linux/amd64 tag:server-cuda]) (push) Waiting to run
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/llama-server-intel.Dockerfile platforms:linux/amd64 tag:server-intel]) (push) Waiting to run
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/llama-server.Dockerfile platforms:linux/amd64,linux/arm64 tag:server]) (push) Waiting to run
Nix CI / nix-eval (macos-latest) (push) Waiting to run
Nix CI / nix-eval (ubuntu-latest) (push) Waiting to run
Nix CI / nix-build (macos-latest) (push) Waiting to run
Nix CI / nix-build (ubuntu-latest) (push) Waiting to run
flake8 Lint / Lint (push) Waiting to run
2024-09-22 04:30:34 +02:00
Molly Sophia
2a63caaa69
RWKV v6: RWKV_WKV op CUDA implementation (#9454)
* ggml: CUDA unary op EXP

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* ggml: rwkv_wkv op CUDA impl

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

---------

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
2024-09-22 04:29:12 +02:00
zhenweijin
25d4599e19 remove unused files 2024-09-20 19:06:39 +08:00
zhenweijin
8be5d11c0d Merge branch 'ggerganov-gg/tokenizer-cleanup' into refactor-tokenizer 2024-09-20 18:48:58 +08:00
Georgi Gerganov
02629d98f1 llama : make llm_tokenizer more private
ggml-ci
2024-09-20 18:44:38 +08:00
zhenweijin
2ec25dbf27 refactor tokenizer 2024-09-20 18:44:38 +08:00
zhenweijin
d653d25116 Merge branch 'gg/tokenizer-cleanup' of https://github.com/ggerganov/llama.cpp into ggerganov-gg/tokenizer-cleanup 2024-09-20 18:38:13 +08:00
zhenweijin
403758f93e refactor tokenizer 2024-09-20 17:16:55 +08:00
Georgi Gerganov
6e873e561a
llama : make llm_tokenizer more private
ggml-ci
2024-09-20 11:41:51 +03:00
zhenweijin
d949c5844d refactor tokenizer 2024-09-20 15:02:44 +08:00
10 changed files with 403 additions and 142 deletions

View File

@ -34,6 +34,7 @@
#include "ggml-cuda/tsembd.cuh"
#include "ggml-cuda/unary.cuh"
#include "ggml-cuda/upscale.cuh"
#include "ggml-cuda/rwkv-wkv.cuh"
#include <algorithm>
#include <array>
@ -2243,6 +2244,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_UNARY_OP_HARDSWISH:
ggml_cuda_op_hardswish(ctx, dst);
break;
case GGML_UNARY_OP_EXP:
ggml_cuda_op_exp(ctx, dst);
break;
default:
return false;
}
@ -2345,6 +2349,8 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_CROSS_ENTROPY_LOSS:
ggml_cuda_cross_entropy_loss(ctx, dst);
break;
case GGML_OP_RWKV_WKV:
ggml_cuda_op_rwkv_wkv(ctx, dst);
case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
ggml_cuda_cross_entropy_loss_back(ctx, dst);
break;
@ -2806,6 +2812,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
case GGML_UNARY_OP_HARDSWISH:
case GGML_UNARY_OP_GELU_QUICK:
case GGML_UNARY_OP_TANH:
case GGML_UNARY_OP_EXP:
return ggml_is_contiguous(op->src[0]);
default:
return false;
@ -2967,6 +2974,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
case GGML_OP_ARANGE:
case GGML_OP_TIMESTEP_EMBEDDING:
case GGML_OP_LEAKY_RELU:
case GGML_OP_RWKV_WKV:
return true;
case GGML_OP_FLASH_ATTN_EXT:
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)

View File

@ -0,0 +1,89 @@
#include "common.cuh"
#include "rwkv-wkv.cuh"
static __global__ void rwkv_wkv_f32(const int B, const int T, const int C, const int H, const float * k, const float * v, const float * r, const float * tf, const float * td, const float * s, float * dst) {
const int tid = threadIdx.x;
const int bid = blockIdx.x;
const int head_size = CUDA_WKV_BLOCK_SIZE;
const int batch_i = bid / H;
const int head_i = bid % H;
const int state_size = C * head_size;
const int n_seq_tokens = T / B;
float state[head_size];
__shared__ float _k[head_size], _r[head_size], _tf[head_size], _td[head_size];
#pragma unroll
for (int i = 0; i < head_size; i++) {
state[i] = s[batch_i * state_size + head_i * head_size * head_size + i * head_size + tid];
}
__syncthreads();
_tf[tid] = tf[head_i * head_size + tid];
__syncthreads();
for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid; t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid; t += C) {
__syncthreads();
_k[tid] = k[t];
_r[tid] = r[t];
_td[tid] = td[t];
__syncthreads();
const float _v = v[t];
float y = 0;
for (int j = 0; j < head_size; j += 4) {
const float4& k = (float4&)(_k[j]);
const float4& r = (float4&)(_r[j]);
const float4& tf = (float4&)(_tf[j]);
const float4& td = (float4&)(_td[j]);
float4& s = (float4&)(state[j]);
float4 kv;
kv.x = k.x * _v;
kv.y = k.y * _v;
kv.z = k.z * _v;
kv.w = k.w * _v;
y += r.x * (tf.x * kv.x + s.x);
y += r.y * (tf.y * kv.y + s.y);
y += r.z * (tf.z * kv.z + s.z);
y += r.w * (tf.w * kv.w + s.w);
s.x = s.x * td.x + kv.x;
s.y = s.y * td.y + kv.y;
s.z = s.z * td.z + kv.z;
s.w = s.w * td.w + kv.w;
}
dst[t] = y;
}
#pragma unroll
for (int i = 0; i < head_size; i++) {
dst[T * C + batch_i * state_size + head_i * head_size * head_size + i * head_size + tid] = state[i];
}
}
void ggml_cuda_op_rwkv_wkv(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const float * k_d = (const float *)dst->src[0]->data;
const float * v_d = (const float *)dst->src[1]->data;
const float * r_d = (const float *)dst->src[2]->data;
const float * tf_d = (const float *)dst->src[3]->data;
const float * td_d = (const float *)dst->src[4]->data;
const float * s_d = (const float *)dst->src[5]->data;
const int64_t B = dst->src[5]->ne[1];
const int64_t T = dst->src[0]->ne[3];
const int64_t C = dst->ne[0];
const int64_t H = dst->src[0]->ne[2];
float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32);
GGML_ASSERT(C % H == 0);
GGML_ASSERT(C / H == CUDA_WKV_BLOCK_SIZE);
rwkv_wkv_f32<<<B * H, C / H, 0, stream>>>(B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d);
}

View File

@ -0,0 +1,5 @@
#include "common.cuh"
#define CUDA_WKV_BLOCK_SIZE 64
void ggml_cuda_op_rwkv_wkv(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

View File

@ -95,6 +95,15 @@ static __global__ void hardswish_f32(const float * x, float * dst, const int k)
dst[i] = x[i] * fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f));
}
static __global__ void exp_f32(const float * x, float * dst, const int k) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;
if (i >= k) {
return;
}
dst[i] = expf(x[i]);
}
static __global__ void leaky_relu_f32(const float * x, float * dst, const int k, const float negative_slope) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;
if (i >= k) {
@ -189,6 +198,11 @@ static void hardswish_f32_cuda(const float * x, float * dst, const int k, cudaSt
hardswish_f32<<<num_blocks, CUDA_HARDSWISH_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}
static void exp_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_EXP_BLOCK_SIZE - 1) / CUDA_EXP_BLOCK_SIZE;
exp_f32<<<num_blocks, CUDA_EXP_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}
static void leaky_relu_f32_cuda(const float * x, float * dst, const int k, const float negative_slope, cudaStream_t stream) {
const int num_blocks = (k + CUDA_RELU_BLOCK_SIZE - 1) / CUDA_RELU_BLOCK_SIZE;
leaky_relu_f32<<<num_blocks, CUDA_RELU_BLOCK_SIZE, 0, stream>>>(x, dst, k, negative_slope);
@ -354,6 +368,20 @@ void ggml_cuda_op_hardswish(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
hardswish_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
}
void ggml_cuda_op_exp(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *)src0->data;
float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
exp_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
}
void ggml_cuda_op_leaky_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *)src0->data;

View File

@ -8,6 +8,7 @@
#define CUDA_RELU_BLOCK_SIZE 256
#define CUDA_SIGMOID_BLOCK_SIZE 256
#define CUDA_HARDSIGMOID_BLOCK_SIZE 256
#define CUDA_EXP_BLOCK_SIZE 256
#define CUDA_HARDSWISH_BLOCK_SIZE 256
#define CUDA_SQR_BLOCK_SIZE 256
#define CUDA_SQRT_BLOCK_SIZE 256
@ -32,6 +33,8 @@ void ggml_cuda_op_sigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_hardsigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_exp(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_hardswish(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_leaky_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

View File

@ -50,7 +50,7 @@ struct naive_trie {
res.first->second.insert(key + 1, len - 1, value);
}
}
std::pair<const char *, size_t> get_longest_prefix(const char * key, size_t len, size_t offset = 0) {
std::pair<const char *, size_t> get_longest_prefix(const char * key, size_t len, size_t offset = 0) const {
if (len == 0 || offset == len) {
return std::make_pair(key, offset);
}
@ -79,6 +79,15 @@ struct naive_trie {
// impl
//
struct llm_tokenizer {
llm_tokenizer() {}
virtual ~llm_tokenizer() = default;
};
llama_vocab::~llama_vocab() {
delete tokenizer;
}
int llama_vocab::find_bpe_rank(const std::string & token_left, const std::string & token_right) const {
GGML_ASSERT(token_left.find(' ') == std::string::npos);
GGML_ASSERT(token_left.find('\n') == std::string::npos);
@ -187,10 +196,16 @@ struct llm_bigram_spm {
size_t size;
};
struct llm_tokenizer_spm {
llm_tokenizer_spm(const llama_vocab & vocab) : vocab(vocab) {}
struct llm_tokenizer_spm : llm_tokenizer {
llm_tokenizer_spm(const llama_vocab & /*vocab*/) : llm_tokenizer() {}
};
struct llm_tokenizer_spm_session {
llm_tokenizer_spm_session(const llama_vocab & vocab) : vocab(vocab),
spm_tokenizer(static_cast<const llm_tokenizer_spm *>(vocab.tokenizer)) {}
void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
// split string into utf8 chars
int index = 0;
size_t offs = 0;
@ -271,7 +286,7 @@ private:
return;
}
resegment(symbols[p->second.first], output);
resegment(symbols[p->second.first], output);
resegment(symbols[p->second.second], output);
}
@ -279,7 +294,6 @@ private:
if (left == -1 || right == -1) {
return;
}
const std::string text = std::string(symbols[left].text, symbols[left].n + symbols[right].n);
auto token = vocab.token_to_id.find(text);
@ -306,10 +320,10 @@ private:
}
const llama_vocab & vocab;
const llm_tokenizer_spm * spm_tokenizer; // currently unused
std::vector<llm_symbol> symbols;
llm_bigram_spm::queue work_queue;
std::map<std::string, std::pair<int, int>> rev_merge;
};
@ -352,8 +366,8 @@ struct llm_bigram_bpe {
size_t size;
};
struct llm_tokenizer_bpe {
llm_tokenizer_bpe(const llama_vocab & vocab): vocab(vocab) {
struct llm_tokenizer_bpe : llm_tokenizer {
llm_tokenizer_bpe(const llama_vocab & vocab) : llm_tokenizer() {
GGML_ASSERT(vocab.type == LLAMA_VOCAB_TYPE_BPE);
switch (vocab.type_pre) {
case LLAMA_VOCAB_PRE_TYPE_LLAMA3:
@ -462,7 +476,14 @@ struct llm_tokenizer_bpe {
}
}
void append(const llama_vocab::id token_id, std::vector<llama_vocab::id> & output) const {
std::vector<std::string> regex_exprs;
};
struct llm_tokenizer_bpe_session {
llm_tokenizer_bpe_session(const llama_vocab & vocab) : vocab(vocab),
bpe_tokenizer(static_cast<const llm_tokenizer_bpe *>(vocab.tokenizer)) {}
static void append(const llama_vocab::id token_id, std::vector<llama_vocab::id> & output) {
output.push_back(token_id);
}
@ -501,12 +522,11 @@ struct llm_tokenizer_bpe {
void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
int final_prev_index = -1;
const auto word_collection = unicode_regex_split(text, regex_exprs);
const auto word_collection = unicode_regex_split(text, bpe_tokenizer->regex_exprs);
symbols_final.clear();
for (auto & word : word_collection) {
for (const auto & word : word_collection) {
work_queue = llm_bigram_bpe::queue();
symbols.clear();
@ -609,7 +629,6 @@ private:
if (left == -1 || right == -1) {
return;
}
std::string left_token = std::string(symbols[left].text, symbols[left].n);
std::string right_token = std::string(symbols[right].text, symbols[right].n);
@ -633,12 +652,10 @@ private:
}
const llama_vocab & vocab;
std::vector<std::string> regex_exprs;
const llm_tokenizer_bpe * bpe_tokenizer;
std::vector<llm_symbol> symbols;
std::vector<llm_symbol> symbols_final;
llm_bigram_bpe::queue work_queue;
};
@ -646,15 +663,18 @@ private:
// WPM tokenizer
//
struct llm_tokenizer_wpm {
llm_tokenizer_wpm(const llama_vocab & vocab): vocab(vocab) {}
struct llm_tokenizer_wpm : llm_tokenizer {
llm_tokenizer_wpm(const llama_vocab & /*vocab*/) : llm_tokenizer() {}
};
void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) const {
struct llm_tokenizer_wpm_session {
llm_tokenizer_wpm_session(const llama_vocab & vocab) : vocab(vocab),
wpm_tokenizer(static_cast<const llm_tokenizer_wpm *>(vocab.tokenizer)) {}
void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
const auto & token_map = vocab.token_to_id;
// normalize and split by whitespace
std::vector<std::string> words = preprocess(text);
// bos token prepended already
// find the longest tokens that form the words
@ -699,7 +719,7 @@ struct llm_tokenizer_wpm {
}
// TODO: reduce string copies by using cpts_offs array
std::vector<std::string> preprocess(const std::string & text) const {
static std::vector<std::string> preprocess(const std::string & text) {
const std::vector<uint32_t> cpts_nfd = unicode_cpts_normalize_nfd(unicode_cpts_from_utf8(text));
std::vector<std::string> words(1, "");
@ -751,15 +771,17 @@ struct llm_tokenizer_wpm {
//(cpt >= 0xFF00 && cpt <= 0xFFEF);
}
private:
const llama_vocab & vocab;
const llm_tokenizer_wpm * wpm_tokenizer;
};
//
// UGM tokenizer
//
struct llm_tokenizer_ugm {
llm_tokenizer_ugm(const llama_vocab & vocab) : vocab(vocab) {
struct llm_tokenizer_ugm : llm_tokenizer {
llm_tokenizer_ugm(const llama_vocab & vocab) : llm_tokenizer() {
if (vocab.precompiled_charsmap.size() > 0) {
size_t charsmap_offset = 0;
@ -805,6 +827,30 @@ struct llm_tokenizer_ugm {
unknown_token_score = min_score - unknown_token_score_penalty;
}
// escaped space symbol - U+2581 (Lower One Eighth Block)
const std::string escaped_space = "\xE2\x96\x81";
const char * prefix_replacements = NULL;
size_t prefix_replacements_size = 0;
const uint32_t * xcda_array = NULL;
size_t xcda_array_size = 0;
struct naive_trie user_defined_token_matcher;
float min_score = FLT_MAX;
float max_score = -FLT_MAX;
float unknown_token_score_penalty = 10.0;
float unknown_token_score;
struct naive_trie token_matcher;
};
struct llm_tokenizer_ugm_session {
llm_tokenizer_ugm_session(const llama_vocab & vocab) : vocab(vocab),
ugm_tokenizer(static_cast<const llm_tokenizer_ugm *>(vocab.tokenizer)) {}
/* This implementation is based on SentencePiece optimized Viterbi algorithm for
* unigram language models. The general idea is to:
* - move along the input sequence in steps of one UTF code point,
@ -843,7 +889,7 @@ struct llm_tokenizer_ugm {
// traverse the token matcher trie to find a matching token
bool single_codepoint_token_found = false;
const struct best_tokenization & current_best = tokenization_results[input_offset];
const struct naive_trie * node = token_matcher.traverse(normalized[prefix_offset++]);
const struct naive_trie * node = ugm_tokenizer->token_matcher.traverse(normalized[prefix_offset++]);
while (prefix_offset <= input_len && node != NULL) {
// check if we found valid token in prefix
@ -873,7 +919,7 @@ struct llm_tokenizer_ugm {
// if we didn't find a valid token corresponding to the whole UTF code point
// then use unknown token as the tokenization of this UTF code point
if (!single_codepoint_token_found) {
const double challenger_score = current_best.score_sum + unknown_token_score;
const double challenger_score = current_best.score_sum + ugm_tokenizer->unknown_token_score;
prefix_offset = input_offset + n_utf8_code_units;
struct best_tokenization & current_champ = tokenization_results[prefix_offset];
if (challenger_score > current_champ.score_sum) {
@ -905,7 +951,6 @@ struct llm_tokenizer_ugm {
}
private:
const llama_vocab & vocab;
// helper structure for returning normalization results
struct normalization_result {
@ -918,7 +963,7 @@ private:
normalized->clear();
normalized->reserve(input.size() * 3);
const std::string space = vocab.tokenizer_escape_whitespaces ? escaped_space : " ";
const std::string space = vocab.tokenizer_escape_whitespaces ? ugm_tokenizer->escaped_space : " ";
bool shall_prepend_space = !vocab.tokenizer_treat_whitespace_as_suffix && vocab.tokenizer_add_space_prefix;
bool shall_append_space = vocab.tokenizer_treat_whitespace_as_suffix && vocab.tokenizer_add_space_prefix;
@ -1000,13 +1045,21 @@ private:
size_t xcda_array_size;
};
// this structure stores the best tokenization so far at input_offset
struct best_tokenization {
llama_token token_id;
size_t input_offset;
float score_sum;
};
struct normalization_result normalize_prefix(const std::string & input, size_t input_offset) {
if (input_offset == input.size()) {
return { &input[input_offset], 0, 0 };
}
// if input prefix matches some user-defined token return this token as normalization result
auto user_defined_token_match = user_defined_token_matcher.get_longest_prefix(&input[input_offset], input.size() - input_offset);
auto user_defined_token_match =
ugm_tokenizer->user_defined_token_matcher.get_longest_prefix(&input[input_offset], input.size() - input_offset);
if (user_defined_token_match.second > 0) {
return { &input[input_offset], user_defined_token_match.second, user_defined_token_match.second };
}
@ -1014,8 +1067,8 @@ private:
size_t longest_prefix_length = 0;
size_t longest_prefix_offset = 0;
if (xcda_array_size > 0) {
struct xcda_array_view xcda_view(xcda_array, xcda_array_size);
if (ugm_tokenizer->xcda_array_size > 0) {
struct xcda_array_view xcda_view(ugm_tokenizer->xcda_array, ugm_tokenizer->xcda_array_size);
// Find the longest normalized sequence matching the input prefix by walking
// the XOR-compressed compact double array (XCDA) starting from the root node
@ -1051,50 +1104,27 @@ private:
if (longest_prefix_length > 0) {
// we have a match, so return the replacement sequence
if (longest_prefix_offset >= prefix_replacements_size) {
if (longest_prefix_offset >= ugm_tokenizer->prefix_replacements_size) {
throw std::runtime_error("Index out of array bounds in precompiled charsmap!");
}
const char * prefix_replacement = &prefix_replacements[longest_prefix_offset];
const char * prefix_replacement = &(ugm_tokenizer->prefix_replacements)[longest_prefix_offset];
return { prefix_replacement, strlen(prefix_replacement), longest_prefix_length };
} else {
// check if the input prefix contains a valid sequence of UTF-8 code units
try {
// if yes, return this sequence unmodified
size_t prefix_offset = input_offset;
unicode_cpt_from_utf8(input, prefix_offset);
return { &input[input_offset], prefix_offset - input_offset, prefix_offset - input_offset };
} catch (std::invalid_argument & /*ex*/) {
// if no, consume 1 byte and return U+FFFD - REPLACEMENT CHARACTER
return { "\xEF\xBF\xBD", 3, 1 };
}
}
// check if the input prefix contains a valid sequence of UTF-8 code units
try {
// if yes, return this sequence unmodified
size_t prefix_offset = input_offset;
unicode_cpt_from_utf8(input, prefix_offset);
return { &input[input_offset], prefix_offset - input_offset, prefix_offset - input_offset };
} catch (std::invalid_argument & /*ex*/) {
// if no, consume 1 byte and return U+FFFD - REPLACEMENT CHARACTER
return { "\xEF\xBF\xBD", 3, 1 };
}
}
// escaped space symbol - U+2581 (Lower One Eighth Block)
const std::string escaped_space = "\xE2\x96\x81";
const char * prefix_replacements = NULL;
size_t prefix_replacements_size = 0;
const uint32_t * xcda_array = NULL;
size_t xcda_array_size = 0;
struct naive_trie user_defined_token_matcher;
// this structure stores the best tokenization so far at input_offset
struct best_tokenization {
llama_token token_id;
size_t input_offset;
float score_sum;
};
float min_score = FLT_MAX;
float max_score = -FLT_MAX;
float unknown_token_score_penalty = 10.0;
float unknown_token_score;
struct naive_trie token_matcher;
const llama_vocab & vocab;
const llm_tokenizer_ugm * ugm_tokenizer;
};
//
@ -1155,8 +1185,8 @@ static std::vector<uint8_t> llama_unescape_rwkv_token(const std::string & escape
return output;
}
struct llm_tokenizer_rwkv {
llm_tokenizer_rwkv(const llama_vocab & vocab): vocab(vocab) {
struct llm_tokenizer_rwkv : llm_tokenizer {
llm_tokenizer_rwkv(const llama_vocab & vocab) : llm_tokenizer() {
// RWKV supports arbitrary byte tokens, but the vocab struct only supports string tokens.
// For now, we decode the vocab here into the lookup we'll use for tokenization.
@ -1168,11 +1198,17 @@ struct llm_tokenizer_rwkv {
}
}
struct naive_trie token_matcher;
};
struct llm_tokenizer_rwkv_session {
llm_tokenizer_rwkv_session(const llama_vocab & vocab) : vocab(vocab),
rwkv_tokenizer(static_cast<const llm_tokenizer_rwkv &>(*vocab.tokenizer)) {}
void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
uint32_t position = 0;
while (position < text.size()) {
const struct naive_trie * node = token_matcher.traverse(text[position]);
const struct naive_trie * node = rwkv_tokenizer.token_matcher.traverse(text[position]);
if (node == NULL) {
// no matching token found, add unknown token
output.push_back(vocab.special_unk_id);
@ -1197,11 +1233,33 @@ struct llm_tokenizer_rwkv {
}
}
private:
const llama_vocab & vocab;
struct naive_trie token_matcher;
const llm_tokenizer_rwkv & rwkv_tokenizer;
};
void llama_vocab::init_tokenizer() {
switch (type) {
case LLAMA_VOCAB_TYPE_SPM:
tokenizer = new llm_tokenizer_spm(*this);
break;
case LLAMA_VOCAB_TYPE_BPE:
tokenizer = new llm_tokenizer_bpe(*this);
break;
case LLAMA_VOCAB_TYPE_WPM:
tokenizer = new llm_tokenizer_wpm(*this);
break;
case LLAMA_VOCAB_TYPE_UGM:
tokenizer = new llm_tokenizer_ugm(*this);
break;
case LLAMA_VOCAB_TYPE_RWKV:
tokenizer = new llm_tokenizer_rwkv(*this);
break;
default:
GGML_ABORT("unsupported vocab type");
}
}
//
// (de-) tokenize
//
@ -1263,7 +1321,7 @@ static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<
// if a fragment is text ( not yet processed )
if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
auto & raw_text = fragment.raw_text;
const auto & raw_text = fragment.raw_text;
auto raw_text_base_offset = fragment.offset;
auto raw_text_base_length = fragment.length;
@ -1362,7 +1420,13 @@ static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<
}
}
std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab, std::string raw_text, bool add_special, bool parse_special) {
std::vector<llama_vocab::id> llama_tokenize_internal(
const llama_vocab & vocab,
std::string raw_text,
bool add_special,
bool parse_special) {
GGML_ASSERT(vocab.tokenizer && "Tokenizer not initialized. Call llama_vocab::init_tokenizer() first.");
std::vector<llama_vocab::id> output;
std::forward_list<fragment_buffer_variant> fragment_buffer;
@ -1399,9 +1463,9 @@ std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab,
#ifdef PRETOKENIZERDEBUG
LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
#endif
llm_tokenizer_spm tokenizer(vocab);
llama_escape_whitespace(raw_text);
tokenizer.tokenize(raw_text, output);
llm_tokenizer_spm_session session(vocab);
session.tokenize(raw_text, output);
is_prev_special = false;
} else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
output.push_back(fragment.token);
@ -1423,10 +1487,11 @@ std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab,
} break;
case LLAMA_VOCAB_TYPE_BPE:
{
llm_tokenizer_bpe tokenizer(vocab);
llm_tokenizer_bpe_session session(vocab);
// it calls some other methods that are not exist in llm_tokenizer,
// here just cast it to bpe tokenizer object
if (add_special) {
tokenizer.append_bos(output);
session.append_bos(output);
}
for (const auto & fragment : fragment_buffer) {
if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
@ -1435,15 +1500,15 @@ std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab,
#ifdef PRETOKENIZERDEBUG
LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
#endif
tokenizer.tokenize(raw_text, output);
session.tokenize(raw_text, output);
} else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
tokenizer.append(fragment.token, output);
session.append(fragment.token, output);
}
}
if (add_special) {
tokenizer.append_eos(output);
tokenizer.check_double_bos_eos(output);
session.append_eos(output);
session.check_double_bos_eos(output);
}
} break;
case LLAMA_VOCAB_TYPE_WPM:
@ -1453,7 +1518,7 @@ std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab,
output.push_back(vocab.special_cls_id);
}
llm_tokenizer_wpm tokenizer(vocab);
llm_tokenizer_wpm_session session(vocab);
for (const auto & fragment : fragment_buffer) {
if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
@ -1462,7 +1527,7 @@ std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab,
#ifdef PRETOKENIZERDEBUG
LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
#endif
tokenizer.tokenize(raw_text, output);
session.tokenize(raw_text, output);
} else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
output.push_back(fragment.token);
}
@ -1475,12 +1540,11 @@ std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab,
} break;
case LLAMA_VOCAB_TYPE_UGM:
{
llm_tokenizer_ugm tokenizer(vocab);
if (add_special && vocab.tokenizer_add_bos != 0) {
GGML_ASSERT(vocab.special_bos_id != -1);
output.push_back(vocab.special_bos_id);
}
llm_tokenizer_ugm_session session(vocab);
for (const auto & fragment : fragment_buffer) {
if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
@ -1488,7 +1552,7 @@ std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab,
#ifdef PRETOKENIZERDEBUG
LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
#endif
tokenizer.tokenize(raw_text, output);
session.tokenize(raw_text, output);
} else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
output.push_back(fragment.token);
}
@ -1508,6 +1572,7 @@ std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab,
} break;
case LLAMA_VOCAB_TYPE_RWKV:
{
llm_tokenizer_rwkv_session session(vocab);
for (const auto & fragment : fragment_buffer) {
if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length);
@ -1516,8 +1581,7 @@ std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab,
LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
#endif
llm_tokenizer_rwkv tokenizer(vocab);
tokenizer.tokenize(raw_text, output);
session.tokenize(raw_text, output);
} else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
output.push_back(fragment.token);
}
@ -1634,13 +1698,13 @@ llama_token llama_token_eom_impl(const struct llama_vocab & vocab) {
}
int32_t llama_tokenize_impl(
const struct llama_vocab & vocab,
const char * text,
int32_t text_len,
llama_token * tokens,
int32_t n_tokens_max,
bool add_special,
bool parse_special) {
const struct llama_vocab & vocab,
const char * text,
int32_t text_len,
llama_token * tokens,
int32_t n_tokens_max,
bool add_special,
bool parse_special) {
auto res = llama_tokenize_internal(vocab, std::string(text, text_len), add_special, parse_special);
if (n_tokens_max < (int) res.size()) {
// LLAMA_LOG_ERROR("%s: too many tokens\n", __func__);
@ -1765,6 +1829,8 @@ int32_t llama_detokenize_impl(
int32_t text_len_max,
bool remove_special,
bool unparse_special) {
GGML_ASSERT(vocab.tokenizer && "Tokenizer not initialized. Call llama_vocab::init_tokenizer() first.");
int32_t avail = text_len_max;
int32_t total = 0;

View File

@ -7,6 +7,8 @@
#include <unordered_map>
#include <map>
struct llm_tokenizer;
struct llama_vocab {
using id = llama_token;
using token = std::string;
@ -61,7 +63,14 @@ struct llama_vocab {
std::vector<char> precompiled_charsmap;
llm_tokenizer * tokenizer = nullptr;
llama_vocab() = default;
~llama_vocab();
int find_bpe_rank(const std::string & token_left, const std::string & token_right) const;
void init_tokenizer();
};
//

View File

@ -3056,18 +3056,14 @@ struct llama_sbatch {
} else {
// simple split
if (batch->n_seq_id) {
for (size_t i = 0; i < length; ++i) {
ubatch.n_seq_id = batch->n_seq_id + seq.offset;
}
ubatch.n_seq_id = batch->n_seq_id + seq.offset;
} else {
for (size_t i = 0; i < length; ++i) {
ubatch.n_seq_id[ubatch.n_seqs + i] = 1;
}
}
if (batch->seq_id) {
for (size_t i = 0; i < length; ++i) {
ubatch.seq_id = batch->seq_id + seq.offset;
}
ubatch.seq_id = batch->seq_id + seq.offset;
} else {
for (size_t i = 0; i < length; ++i) {
ubatch.seq_id[ubatch.n_seqs + i] = &seq.all_seq_id;
@ -6404,6 +6400,8 @@ static void llm_load_vocab(
}
GGML_ASSERT(vocab.id_to_token.size() == vocab.token_to_id.size());
vocab.init_tokenizer();
// determine the newline token: LLaMA "<0x0A>" == 10 == '\n', Falcon 193 == '\n'
if (vocab.type == LLAMA_VOCAB_TYPE_SPM) {
// For Fill-In-the-Middle (FIM)/infill models which where converted
@ -6453,11 +6451,11 @@ static void llm_load_vocab(
} else if (vocab.type == LLAMA_VOCAB_TYPE_WPM) {
vocab.linefeed_id = vocab.special_pad_id;
} else if (vocab.type == LLAMA_VOCAB_TYPE_RWKV) {
const std::vector<int> ids = llama_tokenize_internal(vocab, "\n", false);
const std::vector<int> ids = llama_tokenize_internal(model.vocab, "\n", false);
GGML_ASSERT(!ids.empty() && "model vocab missing newline token");
vocab.linefeed_id = ids[0];
} else {
const std::vector<int> ids = llama_tokenize_internal(vocab, "\xC4\x8A", false); // U+010A
const std::vector<int> ids = llama_tokenize_internal(model.vocab, "\xC4\x8A", false); // U+010A
GGML_ASSERT(!ids.empty() && "model vocab missing newline token");
vocab.linefeed_id = ids[0];
}

View File

@ -1543,6 +1543,36 @@ struct test_ssm_scan : public test_case {
}
};
// GGML_OP_RWKV_WKV
struct test_rwkv_wkv : public test_case {
const ggml_type type;
const int64_t head_count;
const int64_t head_size;
const int64_t n_seq_tokens;
const int64_t n_seqs;
std::string vars() override {
return VARS_TO_STR5(type, head_count, head_size, n_seq_tokens, n_seqs);
}
test_rwkv_wkv(ggml_type type = GGML_TYPE_F32,
int64_t head_count = 32, int64_t head_size = 64, int64_t n_seq_tokens = 32, int64_t n_seqs = 32)
: type(type), head_count(head_count), head_size(head_size), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
const int64_t n_tokens = n_seq_tokens * n_seqs;
ggml_tensor * r = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ 1, head_size, head_count, n_tokens }.data());
ggml_tensor * k = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ head_size, 1, head_count, n_tokens }.data());
ggml_tensor * v = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ 1, head_size, head_count, n_tokens }.data());
ggml_tensor * tf = ggml_new_tensor(ctx, type, 2, std::vector<int64_t>{ head_size, head_count }.data());
ggml_tensor * td = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ 1, head_size, head_count, n_tokens }.data());
ggml_tensor * s = ggml_new_tensor(ctx, type, 2, std::vector<int64_t>{ head_size * head_size * head_count, n_seqs }.data());
ggml_tensor * out = ggml_rwkv_wkv(ctx, k, v, r, tf, td, s);
return out;
}
};
// GGML_OP_MUL_MAT
struct test_mul_mat : public test_case {
const ggml_type type_a;
@ -3337,6 +3367,11 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 16, 1024, 32, 4));
test_cases.emplace_back(new test_rwkv_wkv(GGML_TYPE_F32, 32, 64, 1, 1));
test_cases.emplace_back(new test_rwkv_wkv(GGML_TYPE_F32, 32, 64, 32, 1));
test_cases.emplace_back(new test_rwkv_wkv(GGML_TYPE_F32, 32, 64, 32, 4));
test_cases.emplace_back(new test_rwkv_wkv(GGML_TYPE_F32, 32, 64, 128, 4));
#if 1
for (ggml_type type_a : base_types) {
for (ggml_type type_b : {GGML_TYPE_F32, GGML_TYPE_F16}) {

View File

@ -7,6 +7,7 @@
#include <map>
#include <vector>
#include <fstream>
#include <thread>
//static const std::map<std::string, std::vector<llama_token>> & k_tests() {
// static std::map<std::string, std::vector<llama_token>> _k_tests = {
@ -194,45 +195,64 @@ int main(int argc, char **argv) {
const bool add_special = false;
for (const auto & test_kv : k_tests) {
const std::vector<llama_token> res = llama_tokenize(ctx, test_kv.first, add_special, false);
// multi-threaded tokenization
const int nthread = std::thread::hardware_concurrency();
std::vector<std::thread> threads(nthread);
printf("\n");
printf("src: '%s'\n", test_kv.first.c_str());
printf("res: '%s'\n", llama_detokenize(ctx, res).c_str());
printf("tok: ");
for (const auto & tok : res) {
printf("%d ", tok);
}
printf("\n");
for (int i = 0; i < nthread; i++) {
threads[i] = std::thread([&, i]() {
for (const auto & test_kv : k_tests) {
const std::vector<llama_token> res = llama_tokenize(ctx, test_kv.first, add_special, false);
bool correct = res.size() == test_kv.second.size();
for (int i = 0; i < (int) res.size() && correct; ++i) {
if (test_kv.second[i] != res[i]) {
correct = false;
// here only print the result of the first thread
// because the other threads are running the same tests
if (i != 0) {
continue;
}
printf("\n");
printf("src: '%s'\n", test_kv.first.c_str());
printf("res: '%s'\n", llama_detokenize(ctx, res).c_str());
printf("tok: ");
for (const auto & tok : res) {
printf("%d ", tok);
}
printf("\n");
bool correct = res.size() == test_kv.second.size();
for (int i = 0; i < (int) res.size() && correct; ++i) {
if (test_kv.second[i] != res[i]) {
correct = false;
}
}
if (!correct) {
fprintf(stderr, "%s : failed test: '%s'\n", __func__, test_kv.first.c_str());
fprintf(stderr, "%s : detokenized to: '%s' instead of '%s'\n", __func__,
llama_detokenize(ctx, res).c_str(),
llama_detokenize(ctx, test_kv.second).c_str());
fprintf(stderr, "%s : expected tokens: ", __func__);
for (const auto & t : test_kv.second) {
fprintf(stderr, "%6d '%s', ", t, llama_token_to_piece(ctx, t).c_str());
}
fprintf(stderr, "\n");
fprintf(stderr, "%s : got tokens: ", __func__);
for (const auto & t : res) {
fprintf(stderr, "%6d '%s', ", t, llama_token_to_piece(ctx, t).c_str());
}
fprintf(stderr, "\n");
success = false;
}
}
}
if (!correct) {
fprintf(stderr, "%s : failed test: '%s'\n", __func__, test_kv.first.c_str());
fprintf(stderr, "%s : detokenized to: '%s' instead of '%s'\n", __func__,
llama_detokenize(ctx, res).c_str(),
llama_detokenize(ctx, test_kv.second).c_str());
fprintf(stderr, "%s : expected tokens: ", __func__);
for (const auto & t : test_kv.second) {
fprintf(stderr, "%6d '%s', ", t, llama_token_to_piece(ctx, t).c_str());
}
fprintf(stderr, "\n");
fprintf(stderr, "%s : got tokens: ", __func__);
for (const auto & t : res) {
fprintf(stderr, "%6d '%s', ", t, llama_token_to_piece(ctx, t).c_str());
}
fprintf(stderr, "\n");
success = false;
}
});
}
for (int i = 0; i < nthread; i++) {
threads[i].join();
}
// single threaded tokenization
if (!fname_text.empty()) {
fprintf(stderr, "%s : tokenizing: '%s'\n", __func__, fname_text.c_str());