llama : speedup tokenization (#2831)
* Speedup tokenization

  On current master it takes ~3.2 seconds to tokenize Wikitext. With this change it becomes ~525 ms.

* Fix: it was missing the piece after the last found occurrence

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
commit 463173a6c0
parent eaa13a48ff
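Where the time went: the old replace_all in llama.cpp (second hunk below) mutates the string in place with erase() and insert(), each of which shifts the entire tail of the string, so a pass with m matches over an n-character string does O(n*m) work; the patched version copies each unmatched span into a result buffer once, which is O(n). A minimal standalone benchmark sketch (not part of the commit; input size and match density are arbitrary illustrative choices) contrasting the two strategies:

// Micro-benchmark sketch (not from the commit): contrasts the old in-place
// erase/insert strategy with the patched single-pass copy.
#include <chrono>
#include <cstdio>
#include <string>

// Old version: each erase()/insert() shifts the tail of s -> O(n*m) overall.
static void replace_all_erase(std::string & s, const std::string & search, const std::string & replace) {
    for (size_t pos = 0; ; pos += replace.length()) {
        pos = s.find(search, pos);
        if (pos == std::string::npos) break;
        s.erase(pos, search.length());
        s.insert(pos, replace);
    }
}

// Patched version: append unmatched spans and replacements to a buffer -> O(n).
static void replace_all_copy(std::string & s, const std::string & search, const std::string & replace) {
    std::string result;
    for (size_t pos = 0; ; pos += search.length()) {
        auto new_pos = s.find(search, pos);
        if (new_pos == std::string::npos) {
            result += s.substr(pos, s.size() - pos);
            break;
        }
        result += s.substr(pos, new_pos - pos) + replace;
        pos = new_pos;
    }
    s = std::move(result);
}

int main() {
    // 1 MB of text with a match every 100 characters (~10k matches).
    std::string input(1000000, 'a');
    for (size_t i = 0; i < input.size(); i += 100) input[i] = ' ';

    void (*fns[])(std::string &, const std::string &, const std::string &) = {
        replace_all_erase, replace_all_copy };
    for (auto fn : fns) {
        std::string s = input;
        auto t0 = std::chrono::high_resolution_clock::now();
        fn(s, " ", "_");
        auto t1 = std::chrono::high_resolution_clock::now();
        printf("%g ms\n", 1e-3 * std::chrono::duration_cast<std::chrono::microseconds>(t1 - t0).count());
    }
    return 0;
}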
examples/perplexity/perplexity.cpp

@@ -190,10 +190,14 @@ void perplexity(llama_context * ctx, const gpt_params & params) {
     const bool is_spm = llama_vocab_type(ctx) == LLAMA_VOCAB_TYPE_SPM;
     const bool add_bos = is_spm;
 
+    auto tim1 = std::chrono::high_resolution_clock::now();
     fprintf(stderr, "%s: tokenizing the input ..\n", __func__);
 
     auto tokens = ::llama_tokenize(ctx, params.prompt, add_bos);
 
+    auto tim2 = std::chrono::high_resolution_clock::now();
+    fprintf(stderr, "%s: tokenization took %g ms\n",__func__,1e-3*std::chrono::duration_cast<std::chrono::microseconds>(tim2-tim1).count());
+
     const int n_chunk_max = tokens.size() / params.n_ctx;
 
     const int n_chunk = params.n_chunks < 0 ? n_chunk_max : std::min(params.n_chunks, n_chunk_max);
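The instrumentation above is the standard std::chrono pattern: take two high_resolution_clock timestamps around the work, duration_cast the difference to microseconds, and scale by 1e-3 to report milliseconds. A self-contained sketch of the same idiom wrapped in a helper (the helper name time_ms is hypothetical, not from llama.cpp):

// Sketch of the timing idiom used in perplexity() above; time_ms is a
// hypothetical helper, not part of llama.cpp.
#include <chrono>
#include <cstdio>

template <typename F>
static double time_ms(F && f) {
    auto t0 = std::chrono::high_resolution_clock::now();
    f();
    auto t1 = std::chrono::high_resolution_clock::now();
    // microseconds * 1e-3 = milliseconds, matching the fprintf above
    return 1e-3 * std::chrono::duration_cast<std::chrono::microseconds>(t1 - t0).count();
}

int main() {
    volatile long sum = 0;
    double ms = time_ms([&] {
        for (long i = 0; i < 10000000; ++i) sum += i;
    });
    printf("loop took %g ms\n", ms);
    return 0;
}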
llama.cpp (15 changed lines)
@@ -114,12 +114,17 @@ static size_t utf8_len(char src) {
 }
 
 void replace_all(std::string & s, const std::string & search, const std::string & replace) {
-    for (size_t pos = 0; ; pos += replace.length()) {
-        pos = s.find(search, pos);
-        if (pos == std::string::npos) break;
-        s.erase(pos, search.length());
-        s.insert(pos, replace);
+    std::string result;
+    for (size_t pos = 0; ; pos += search.length()) {
+        auto new_pos = s.find(search, pos);
+        if (new_pos == std::string::npos) {
+            result += s.substr(pos, s.size() - pos);
+            break;
+        }
+        result += s.substr(pos, new_pos - pos) + replace;
+        pos = new_pos;
     }
+    s = std::move(result);
 }
 
 static void zeros(std::ofstream & file, size_t n) {
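The "Fix" in the commit message refers to the tail handling in this hunk: the first version of the rewrite dropped the text after the last occurrence of search. A small regression-check sketch (not from the repository; replace_all is copied verbatim from the patched code above) that would catch that bug:

// Regression check sketch (not in the repo) for the tail fix mentioned in
// the commit message; replace_all is the patched version from llama.cpp.
#include <cassert>
#include <string>

void replace_all(std::string & s, const std::string & search, const std::string & replace) {
    std::string result;
    for (size_t pos = 0; ; pos += search.length()) {
        auto new_pos = s.find(search, pos);
        if (new_pos == std::string::npos) {
            result += s.substr(pos, s.size() - pos);  // keep the piece after the last occurrence
            break;
        }
        result += s.substr(pos, new_pos - pos) + replace;
        pos = new_pos;
    }
    s = std::move(result);
}

int main() {
    std::string s = "a.b.c tail";
    replace_all(s, ".", "-");
    assert(s == "a-b-c tail");      // text after the last '.' must survive

    std::string t = "no match";
    replace_all(t, "x", "y");
    assert(t == "no match");        // no occurrences: string unchanged

    return 0;
}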