server : refactor (#5882)

* server : refactoring (wip)

* server : remove llava/clip objects from build

* server : fix empty prompt handling + all slots idle logic

* server : normalize id vars

* server : code style

* server : simplify model chat template validation

* server : code style

* server : minor

* llama : llama_chat_apply_template support null buf

* server : do not process embedding requests when disabled

* server : reorganize structs and enums + naming fixes

* server : merge oai.hpp in utils.hpp

* server : refactor system prompt update at start

* server : disable cached prompts with self-extend

* server : do not process more than n_batch tokens per iter

* server: tests: embeddings use a real embeddings model (#5908)

* server, tests : bump batch to fit 1 embedding prompt

* server: tests: embeddings fix build type Debug is randomly failing (#5911)

* server: tests: embeddings, use different KV Cache size

* server: tests: embeddings, fixed prompt do not exceed n_batch, increase embedding timeout, reduce number of concurrent embeddings

* server: tests: embeddings, no need to wait for server idle as it can timout

* server: refactor: clean up http code (#5912)

* server : avoid n_available var

ggml-ci

* server: refactor: better http codes

* server : simplify json parsing + add comment about t_last

* server : rename server structs

* server : allow to override FQDN in tests

ggml-ci

* server : add comments

---------

Co-authored-by: Pierrick Hymbert <pierrick.hymbert@gmail.com>
This commit is contained in:
Georgi Gerganov 2024-03-07 11:41:53 +02:00 committed by GitHub
parent ceca1aef07
commit 2002bc96bf
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 2327 additions and 2773 deletions

View File

@ -58,7 +58,8 @@ jobs:
cmake \ cmake \
python3-pip \ python3-pip \
wget \ wget \
psmisc psmisc \
language-pack-en
- name: Build - name: Build
id: cmake_build id: cmake_build

View File

@ -724,10 +724,9 @@ save-load-state: examples/save-load-state/save-load-state.cpp ggml.o llama.o $(C
$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
server: examples/server/server.cpp examples/server/oai.hpp examples/server/utils.hpp examples/server/httplib.h examples/server/json.hpp examples/server/index.html.hpp examples/server/index.js.hpp examples/server/completion.js.hpp examples/llava/clip.cpp examples/llava/clip.h examples/llava/llava.h examples/llava/llava.cpp common/stb_image.h ggml.o llama.o $(COMMON_DEPS) grammar-parser.o $(OBJS) server: examples/server/server.cpp examples/server/utils.hpp examples/server/httplib.h examples/server/json.hpp examples/server/index.html.hpp examples/server/index.js.hpp examples/server/completion.js.hpp common/stb_image.h ggml.o llama.o $(COMMON_DEPS) grammar-parser.o $(OBJS)
$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
$(CXX) $(CXXFLAGS) -c examples/llava/clip.cpp -o $(call GET_OBJ_FILE, examples/llava/clip.cpp) -Wno-cast-qual $(CXX) $(CXXFLAGS) $(filter-out %.h %.hpp $<,$^) -Iexamples/server $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) $(LWINSOCK2)
$(CXX) $(CXXFLAGS) -Iexamples/server $(filter-out %.h %.hpp $< examples/llava/clip.cpp,$^) $(call GET_OBJ_FILE, $<) $(call GET_OBJ_FILE, examples/llava/clip.cpp) -o $@ $(LDFLAGS) $(LWINSOCK2)
gguf: examples/gguf/gguf.cpp ggml.o $(OBJS) gguf: examples/gguf/gguf.cpp ggml.o $(OBJS)
$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)

View File

@ -13,7 +13,7 @@ async def main():
model_url = "http://127.0.0.1:6900" model_url = "http://127.0.0.1:6900"
responses: list[requests.Response] = await asyncio.gather(*[requests_post_async( responses: list[requests.Response] = await asyncio.gather(*[requests_post_async(
url= f"{model_url}/embedding", url= f"{model_url}/embedding",
json= {"content": str(i)*1024} json= {"content": str(0)*1024}
) for i in range(n)]) ) for i in range(n)])
for response in responses: for response in responses:

View File

@ -1,12 +1,12 @@
set(TARGET server) set(TARGET server)
option(LLAMA_SERVER_VERBOSE "Build verbose logging option for Server" ON) option(LLAMA_SERVER_VERBOSE "Build verbose logging option for Server" ON)
include_directories(${CMAKE_CURRENT_SOURCE_DIR}) include_directories(${CMAKE_CURRENT_SOURCE_DIR})
add_executable(${TARGET} server.cpp oai.hpp utils.hpp json.hpp httplib.h) add_executable(${TARGET} server.cpp utils.hpp json.hpp httplib.h)
install(TARGETS ${TARGET} RUNTIME) install(TARGETS ${TARGET} RUNTIME)
target_compile_definitions(${TARGET} PRIVATE target_compile_definitions(${TARGET} PRIVATE
SERVER_VERBOSE=$<BOOL:${LLAMA_SERVER_VERBOSE}> SERVER_VERBOSE=$<BOOL:${LLAMA_SERVER_VERBOSE}>
) )
target_link_libraries(${TARGET} PRIVATE common llava ${CMAKE_THREAD_LIBS_INIT}) target_link_libraries(${TARGET} PRIVATE common ${CMAKE_THREAD_LIBS_INIT})
if (WIN32) if (WIN32)
TARGET_LINK_LIBRARIES(${TARGET} PRIVATE ws2_32) TARGET_LINK_LIBRARIES(${TARGET} PRIVATE ws2_32)
endif() endif()

View File

@ -436,7 +436,7 @@ Notice that each `probs` is an array of length `n_probs`.
"next_token": { "next_token": {
"has_next_token": true, "has_next_token": true,
"n_remain": -1, "n_remain": -1,
"num_tokens_predicted": 0, "n_decoded": 0,
"stopped_eos": false, "stopped_eos": false,
"stopped_limit": false, "stopped_limit": false,
"stopped_word": false, "stopped_word": false,

View File

@ -1,225 +0,0 @@
#pragma once
#include <string>
#include <vector>
#include <set>
#include <mutex>
#include <condition_variable>
#include <unordered_map>
#include "json.hpp"
#include "utils.hpp"
#define DEFAULT_OAICOMPAT_MODEL "gpt-3.5-turbo-0613"
using json = nlohmann::json;
inline static json oaicompat_completion_params_parse(
const struct llama_model * model,
const json &body, /* openai api json semantics */
const std::string &chat_template)
{
json llama_params;
llama_params["__oaicompat"] = true;
// Map OpenAI parameters to llama.cpp parameters
//
// For parameters that are defined by the OpenAI documentation (e.g.
// temperature), we explicitly specify OpenAI's intended default; we
// need to do that because sometimes OpenAI disagrees with llama.cpp
//
// https://platform.openai.com/docs/api-reference/chat/create
llama_sampling_params default_sparams;
llama_params["model"] = json_value(body, "model", std::string("unknown"));
llama_params["prompt"] = format_chat(model, chat_template, body["messages"]);
llama_params["cache_prompt"] = json_value(body, "cache_prompt", false);
llama_params["temperature"] = json_value(body, "temperature", 0.0);
llama_params["top_k"] = json_value(body, "top_k", default_sparams.top_k);
llama_params["top_p"] = json_value(body, "top_p", 1.0);
llama_params["n_predict"] = json_value(body, "max_tokens", -1);
llama_params["logit_bias"] = json_value(body, "logit_bias",json::object());
llama_params["frequency_penalty"] = json_value(body, "frequency_penalty", 0.0);
llama_params["presence_penalty"] = json_value(body, "presence_penalty", 0.0);
llama_params["seed"] = json_value(body, "seed", LLAMA_DEFAULT_SEED);
llama_params["stream"] = json_value(body, "stream", false);
llama_params["mirostat"] = json_value(body, "mirostat", default_sparams.mirostat);
llama_params["mirostat_tau"] = json_value(body, "mirostat_tau", default_sparams.mirostat_tau);
llama_params["mirostat_eta"] = json_value(body, "mirostat_eta", default_sparams.mirostat_eta);
llama_params["penalize_nl"] = json_value(body, "penalize_nl", default_sparams.penalize_nl);
llama_params["typical_p"] = json_value(body, "typical_p", default_sparams.typical_p);
llama_params["repeat_last_n"] = json_value(body, "repeat_last_n", default_sparams.penalty_last_n);
llama_params["ignore_eos"] = json_value(body, "ignore_eos", false);
llama_params["tfs_z"] = json_value(body, "tfs_z", default_sparams.tfs_z);
if (body.count("grammar") != 0) {
llama_params["grammar"] = json_value(body, "grammar", json::object());
}
// Handle 'stop' field
if (body.contains("stop") && body["stop"].is_string()) {
llama_params["stop"] = json::array({body["stop"].get<std::string>()});
} else {
llama_params["stop"] = json_value(body, "stop", json::array());
}
// Ensure there is ChatML-specific end sequence among stop words
llama_params["stop"].push_back("<|im_end|>");
return llama_params;
}
inline static json format_final_response_oaicompat(const json &request, const task_result &response, bool streaming = false)
{
json result = response.result_json;
bool stopped_word = result.count("stopped_word") != 0;
bool stopped_eos = json_value(result, "stopped_eos", false);
int num_tokens_predicted = json_value(result, "tokens_predicted", 0);
int num_prompt_tokens = json_value(result, "tokens_evaluated", 0);
std::string content = json_value(result, "content", std::string(""));
std::string finish_reason = "length";
if (stopped_word || stopped_eos) {
finish_reason = "stop";
}
json choices =
streaming ? json::array({json{{"finish_reason", finish_reason},
{"index", 0},
{"delta", json::object()}}})
: json::array({json{{"finish_reason", finish_reason},
{"index", 0},
{"message", json{{"content", content},
{"role", "assistant"}}}}});
std::time_t t = std::time(0);
json res =
json{{"choices", choices},
{"created", t},
{"model",
json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
{"object", streaming ? "chat.completion.chunk" : "chat.completion"},
{"usage",
json{{"completion_tokens", num_tokens_predicted},
{"prompt_tokens", num_prompt_tokens},
{"total_tokens", num_tokens_predicted + num_prompt_tokens}}},
{"id", gen_chatcmplid()}};
if (server_verbose) {
res["__verbose"] = result;
}
if (result.contains("completion_probabilities")) {
res["completion_probabilities"] = json_value(result, "completion_probabilities", json::array());
}
return res;
}
// return value is vector as there is one case where we might need to generate two responses
inline static std::vector<json> format_partial_response_oaicompat(const task_result &response) {
json result = response.result_json;
if (!result.contains("model") || !result.contains("oaicompat_token_ctr")) {
return std::vector<json>({response.result_json});
}
bool first = json_value(result, "oaicompat_token_ctr", 0) == 0;
std::string modelname = json_value(result, "model", std::string(DEFAULT_OAICOMPAT_MODEL));
bool stopped_word = json_value(result, "stopped_word", false);
bool stopped_eos = json_value(result, "stopped_eos", false);
bool stopped_limit = json_value(result, "stopped_limit", false);
std::string content = json_value(result, "content", std::string(""));
std::string finish_reason;
if (stopped_word || stopped_eos) {
finish_reason = "stop";
}
if (stopped_limit) {
finish_reason = "length";
}
std::time_t t = std::time(0);
json choices;
if (!finish_reason.empty()) {
choices = json::array({json{{"finish_reason", finish_reason},
{"index", 0},
{"delta", json::object()}}});
} else {
if (first) {
if (content.empty()) {
choices = json::array({json{{"finish_reason", nullptr},
{"index", 0},
{"delta", json{{"role", "assistant"}}}}});
} else {
// We have to send this as two updates to conform to openai behavior
json initial_ret = json{{"choices", json::array({json{
{"finish_reason", nullptr},
{"index", 0},
{"delta", json{
{"role", "assistant"}
}}}})},
{"created", t},
{"id", gen_chatcmplid()},
{"model", modelname},
{"object", "chat.completion.chunk"}};
json second_ret = json{
{"choices", json::array({json{{"finish_reason", nullptr},
{"index", 0},
{"delta", json{
{"content", content}}}
}})},
{"created", t},
{"id", gen_chatcmplid()},
{"model", modelname},
{"object", "chat.completion.chunk"}};
return std::vector<json>({initial_ret, second_ret});
}
} else {
// Some idiosyncrasy in task processing logic makes several trailing calls
// with empty content, we ignore these at the calee site.
if (content.empty()) {
return std::vector<json>({json::object()});
}
choices = json::array({json{
{"finish_reason", nullptr},
{"index", 0},
{"delta",
json{
{"content", content},
}},
}});
}
}
json ret = json{{"choices", choices},
{"created", t},
{"id", gen_chatcmplid()},
{"model", modelname},
{"object", "chat.completion.chunk"}};
return std::vector<json>({ret});
}
inline static json format_embeddings_response_oaicompat(const json &request, const json &embeddings)
{
json res =
json{
{"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
{"object", "list"},
{"usage",
json{{"prompt_tokens", 0},
{"total_tokens", 0}}},
{"data", embeddings}
};
return res;
}

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,94 @@
@llama.cpp
@embeddings
Feature: llama.cpp server
Background: Server startup
Given a server listening on localhost:8080
And a model file bert-bge-small/ggml-model-f16.gguf from HF repo ggml-org/models
And a model alias bert-bge-small
And 42 as server seed
And 2 slots
And 1024 as batch size
And 2048 KV cache size
And embeddings extraction
Then the server is starting
Then the server is healthy
Scenario: Embedding
When embeddings are computed for:
"""
What is the capital of Bulgaria ?
"""
Then embeddings are generated
Scenario: OAI Embeddings compatibility
Given a model bert-bge-small
When an OAI compatible embeddings computation request for:
"""
What is the capital of Spain ?
"""
Then embeddings are generated
Scenario: OAI Embeddings compatibility with multiple inputs
Given a model bert-bge-small
Given a prompt:
"""
In which country Paris is located ?
"""
And a prompt:
"""
Is Madrid the capital of Spain ?
"""
When an OAI compatible embeddings computation request for multiple inputs
Then embeddings are generated
Scenario: Multi users embeddings
Given a prompt:
"""
Write a very long story about AI.
"""
And a prompt:
"""
Write another very long music lyrics.
"""
And a prompt:
"""
Write a very long poem.
"""
And a prompt:
"""
Write a very long joke.
"""
Given concurrent embedding requests
Then the server is busy
Then the server is idle
Then all embeddings are generated
Scenario: Multi users OAI compatibility embeddings
Given a prompt:
"""
In which country Paris is located ?
"""
And a prompt:
"""
Is Madrid the capital of Spain ?
"""
And a prompt:
"""
What is the biggest US city ?
"""
And a prompt:
"""
What is the capital of Bulgaria ?
"""
And a model bert-bge-small
Given concurrent OAI embedding requests
Then the server is busy
Then the server is idle
Then all embeddings are generated
Scenario: All embeddings should be the same
Given 10 fixed prompts
And a model bert-bge-small
Given concurrent OAI embedding requests
Then all embeddings are the same

View File

@ -9,7 +9,6 @@ Feature: Parallel
And 512 as batch size And 512 as batch size
And 64 KV cache size And 64 KV cache size
And 2 slots And 2 slots
And embeddings extraction
And continuous batching And continuous batching
Then the server is starting Then the server is starting
Then the server is healthy Then the server is healthy
@ -99,48 +98,3 @@ Feature: Parallel
Then the server is busy Then the server is busy
Then the server is idle Then the server is idle
Then all prompts are predicted Then all prompts are predicted
Scenario: Multi users embeddings
Given a prompt:
"""
Write a very long story about AI.
"""
And a prompt:
"""
Write another very long music lyrics.
"""
And a prompt:
"""
Write a very long poem.
"""
And a prompt:
"""
Write a very long joke.
"""
Given concurrent embedding requests
Then the server is busy
Then the server is idle
Then all embeddings are generated
Scenario: Multi users OAI compatibility embeddings
Given a prompt:
"""
In which country Paris is located ?
"""
And a prompt:
"""
Is Madrid the capital of Spain ?
"""
And a prompt:
"""
What is the biggest US city ?
"""
And a prompt:
"""
What is the capital of Bulgaria ?
"""
And a model tinyllama-2
Given concurrent OAI embedding requests
Then the server is busy
Then the server is idle
Then all embeddings are generated

View File

@ -49,34 +49,6 @@ Feature: llama.cpp server
| llama-2 | Book | What is the best book | 8 | (Mom\|what)+ | 8 | disabled | | llama-2 | Book | What is the best book | 8 | (Mom\|what)+ | 8 | disabled |
| codellama70b | You are a coding assistant. | Write the fibonacci function in c++. | 64 | (thanks\|happy\|bird)+ | 32 | enabled | | codellama70b | You are a coding assistant. | Write the fibonacci function in c++. | 64 | (thanks\|happy\|bird)+ | 32 | enabled |
Scenario: Embedding
When embeddings are computed for:
"""
What is the capital of Bulgaria ?
"""
Then embeddings are generated
Scenario: OAI Embeddings compatibility
Given a model tinyllama-2
When an OAI compatible embeddings computation request for:
"""
What is the capital of Spain ?
"""
Then embeddings are generated
Scenario: OAI Embeddings compatibility with multiple inputs
Given a model tinyllama-2
Given a prompt:
"""
In which country Paris is located ?
"""
And a prompt:
"""
Is Madrid the capital of Spain ?
"""
When an OAI compatible embeddings computation request for multiple inputs
Then embeddings are generated
Scenario: Tokenize / Detokenize Scenario: Tokenize / Detokenize
When tokenizing: When tokenizing:
""" """

View File

@ -10,6 +10,7 @@ from contextlib import closing
from re import RegexFlag from re import RegexFlag
import aiohttp import aiohttp
import numpy as np
import openai import openai
from behave import step from behave import step
from behave.api.async_step import async_run_until_complete from behave.api.async_step import async_run_until_complete
@ -24,6 +25,9 @@ def step_server_config(context, server_fqdn, server_port):
if 'PORT' in os.environ: if 'PORT' in os.environ:
context.server_port = int(os.environ['PORT']) context.server_port = int(os.environ['PORT'])
print(f"$PORT set, overriding server port with to {context.server_port}") print(f"$PORT set, overriding server port with to {context.server_port}")
if 'FQDN' in os.environ:
context.server_fqdn = os.environ['FQDN']
print(f"$FQDN set, overriding server fqdn with to {context.server_fqdn}")
context.base_url = f'http://{context.server_fqdn}:{context.server_port}' context.base_url = f'http://{context.server_fqdn}:{context.server_port}'
@ -34,6 +38,7 @@ def step_server_config(context, server_fqdn, server_port):
context.n_ga_w = None context.n_ga_w = None
context.n_gpu_layer = None context.n_gpu_layer = None
context.n_predict = None context.n_predict = None
context.n_prompts = 0
context.n_server_predict = None context.n_server_predict = None
context.n_slots = None context.n_slots = None
context.prompt_prefix = None context.prompt_prefix = None
@ -202,6 +207,7 @@ def step_n_tokens_predicted(context, predicted_n):
@step(u'a user prompt {user_prompt}') @step(u'a user prompt {user_prompt}')
def step_user_prompt(context, user_prompt): def step_user_prompt(context, user_prompt):
context.prompts.append(user_prompt) context.prompts.append(user_prompt)
context.n_prompts = len(context.prompts)
@step(u'a system prompt {system_prompt}') @step(u'a system prompt {system_prompt}')
@ -290,6 +296,12 @@ def step_prompt_passkey(context):
context.prompt_passkey = context.text context.prompt_passkey = context.text
@step(u'{n_prompts:d} fixed prompts')
def step_fixed_prompts(context, n_prompts):
context.prompts.extend([str(0)*(context.n_batch if context.n_batch is not None else 512) for i in range(n_prompts)])
context.n_prompts = n_prompts
@step(u'a "{passkey}" passkey challenge prompt with the passkey inserted every {i_pos:d} junk') @step(u'a "{passkey}" passkey challenge prompt with the passkey inserted every {i_pos:d} junk')
def step_prompt_passkey(context, passkey, i_pos): def step_prompt_passkey(context, passkey, i_pos):
prompt = "" prompt = ""
@ -301,6 +313,7 @@ def step_prompt_passkey(context, passkey, i_pos):
passkey_highlight = "\x1b[33m" + passkey + "\x1b[0m" passkey_highlight = "\x1b[33m" + passkey + "\x1b[0m"
print(f"Passkey challenge:\n```{prompt.replace(passkey, passkey_highlight)}```\n") print(f"Passkey challenge:\n```{prompt.replace(passkey, passkey_highlight)}```\n")
context.prompts.append(context.prompt_prefix + prompt + context.prompt_suffix) context.prompts.append(context.prompt_prefix + prompt + context.prompt_suffix)
context.n_prompts = len(context.prompts)
@step(u'an OAI compatible chat completions request with {api_error} api error') @step(u'an OAI compatible chat completions request with {api_error} api error')
@ -341,11 +354,13 @@ async def step_oai_chat_completions(context, api_error):
@step(u'a prompt') @step(u'a prompt')
def step_a_prompt(context): def step_a_prompt(context):
context.prompts.append(context.text) context.prompts.append(context.text)
context.n_prompts = len(context.prompts)
@step(u'a prompt {prompt}') @step(u'a prompt {prompt}')
def step_a_prompt_prompt(context, prompt): def step_a_prompt_prompt(context, prompt):
context.prompts.append(prompt) context.prompts.append(prompt)
context.n_prompts = len(context.prompts)
@step(u'concurrent completion requests') @step(u'concurrent completion requests')
@ -430,25 +445,47 @@ async def all_prompts_are_predicted(context, expected_predicted_n=None):
@step(u'embeddings are computed for') @step(u'embeddings are computed for')
@async_run_until_complete @async_run_until_complete
async def step_compute_embedding(context): async def step_compute_embedding(context):
context.n_prompts = 1
context.embeddings = await request_embedding(context.text, base_url=context.base_url) context.embeddings = await request_embedding(context.text, base_url=context.base_url)
@step(u'all embeddings are the same')
@async_run_until_complete
async def step_all_embeddings_are_the_same(context):
n_embedding_requests = await gather_tasks_results(context)
assert n_embedding_requests > 0
embeddings = []
for i in range(n_embedding_requests):
embedding = context.tasks_result.pop().pop()
embeddings.append(embedding)
assert_embeddings(embedding)
n = len(embeddings)
for i in range(n-1):
for j in range(i+1, n):
embedding1 = np.array(embeddings[i])
embedding2 = np.array(embeddings[j])
if context.debug:
print(f"embedding1: {embedding1[-8:]}\n")
print(f"embedding2: {embedding2[-8:]}\n")
similarity = np.dot(embedding1, embedding2) / (np.linalg.norm(embedding1) * np.linalg.norm(embedding2))
msg = f"Similarity between {i} and {j}: {similarity:.10f}"
if context.debug:
print(f"{msg}\n")
assert np.isclose(similarity, 1.0, rtol=1e-05, atol=1e-08, equal_nan=False), msg
@step(u'embeddings are generated') @step(u'embeddings are generated')
def step_assert_embeddings(context): def step_assert_embeddings(context):
if len(context.prompts) == 0: assert context.n_prompts == len(context.embeddings), (f"unexpected response:\n"
assert_embeddings(context.embeddings) f"context.n_prompts={context.n_prompts}\n"
else: f"context.embeddings={context.embeddings}")
assert len(context.embeddings) == len(context.prompts), (f"unexpected response:\n" for embedding in context.embeddings:
f"context.prompts={context.prompts}\n" assert_embeddings(embedding)
f"context.embeddings={context.embeddings}")
for embedding in context.embeddings:
context.prompts.pop()
assert_embeddings(embedding)
@step(u'an OAI compatible embeddings computation request for') @step(u'an OAI compatible embeddings computation request for')
@async_run_until_complete @async_run_until_complete
async def step_oai_compute_embeddings(context): async def step_oai_compute_embeddings(context):
context.n_prompts = 1
context.embeddings = await request_oai_embeddings(context.text, context.embeddings = await request_oai_embeddings(context.text,
base_url=context.base_url, base_url=context.base_url,
user_api_key=context.user_api_key, user_api_key=context.user_api_key,
@ -462,6 +499,7 @@ async def step_oai_compute_embeddings_multiple_inputs(context):
base_url=context.base_url, base_url=context.base_url,
user_api_key=context.user_api_key, user_api_key=context.user_api_key,
model=context.model) model=context.model)
context.prompts.clear()
@step(u'concurrent embedding requests') @step(u'concurrent embedding requests')
@ -488,9 +526,9 @@ async def step_concurrent_oai_embedding_requests(context):
@async_run_until_complete() @async_run_until_complete()
async def all_embeddings_are_generated(context): async def all_embeddings_are_generated(context):
n_embedding_requests = await gather_tasks_results(context) n_embedding_requests = await gather_tasks_results(context)
assert n_embedding_requests > 0 assert n_embedding_requests == context.n_prompts
for i in range(n_embedding_requests): for i in range(n_embedding_requests):
assert_embeddings(context.tasks_result.pop()) assert_embeddings(context.tasks_result.pop().pop())
@step(u'tokenizing') @step(u'tokenizing')
@ -588,11 +626,11 @@ def step_supported_models(context, i_model, param, preposition, param_value):
async def concurrent_requests(context, f_completion, *args, **kwargs): async def concurrent_requests(context, f_completion, *args, **kwargs):
n_prompts = len(context.prompts) context.n_prompts = len(context.prompts)
if context.debug: if context.debug:
print(f"starting {n_prompts} concurrent completion requests...") print(f"starting {context.n_prompts} concurrent completion requests...")
assert n_prompts > 0 assert context.n_prompts > 0
for prompt_no in range(n_prompts): for prompt_no in range(context.n_prompts):
shifted_args = [context.prompts.pop(), *args] shifted_args = [context.prompts.pop(), *args]
context.concurrent_tasks.append(asyncio.create_task(f_completion(*shifted_args, **kwargs))) context.concurrent_tasks.append(asyncio.create_task(f_completion(*shifted_args, **kwargs)))
await asyncio.sleep(0.1) await asyncio.sleep(0.1)
@ -765,7 +803,7 @@ async def request_embedding(content, base_url=None):
}) as response: }) as response:
assert response.status == 200 assert response.status == 200
response_json = await response.json() response_json = await response.json()
return response_json['embedding'] return [response_json['embedding']]
async def request_oai_embeddings(input, async def request_oai_embeddings(input,
@ -775,6 +813,7 @@ async def request_oai_embeddings(input,
user_api_key = user_api_key if user_api_key is not None else 'nope' user_api_key = user_api_key if user_api_key is not None else 'nope'
if async_client: if async_client:
origin = 'llama.cpp' origin = 'llama.cpp'
headers=[]
if user_api_key is not None: if user_api_key is not None:
headers = {'Authorization': f'Bearer {user_api_key}', 'Origin': origin} headers = {'Authorization': f'Bearer {user_api_key}', 'Origin': origin}
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
@ -783,14 +822,21 @@ async def request_oai_embeddings(input,
"input": input, "input": input,
"model": model, "model": model,
}, },
headers=headers) as response: headers=headers,
timeout=3600) as response:
assert response.status == 200, f"received status code not expected: {response.status}" assert response.status == 200, f"received status code not expected: {response.status}"
assert response.headers['Access-Control-Allow-Origin'] == origin assert response.headers['Access-Control-Allow-Origin'] == origin
assert response.headers['Content-Type'] == "application/json; charset=utf-8" assert response.headers['Content-Type'] == "application/json; charset=utf-8"
response_json = await response.json() response_json = await response.json()
assert response_json['model'] == model, f"invalid model received: {response_json['model']}" assert response_json['model'] == model, f"invalid model received: {response_json['model']}"
assert response_json['object'] == 'list' assert response_json['object'] == 'list'
return response_json['data'] if isinstance(input, collections.abc.Sequence):
embeddings = []
for an_oai_embeddings in response_json['data']:
embeddings.append(an_oai_embeddings['embedding'])
else:
embeddings = [response_json['data']['embedding']]
return embeddings
else: else:
openai.api_key = user_api_key openai.api_key = user_api_key
openai.api_base = f'{base_url}/v1' openai.api_base = f'{base_url}/v1'
@ -804,7 +850,7 @@ async def request_oai_embeddings(input,
for an_oai_embeddings in oai_embeddings.data: for an_oai_embeddings in oai_embeddings.data:
embeddings.append(an_oai_embeddings.embedding) embeddings.append(an_oai_embeddings.embedding)
else: else:
embeddings = oai_embeddings.data.embedding embeddings = [oai_embeddings.data.embedding]
return embeddings return embeddings
@ -899,6 +945,8 @@ def assert_embeddings(embeddings):
assert len(embeddings) > 0 assert len(embeddings) > 0
embeddings_computed = False embeddings_computed = False
for emb in embeddings: for emb in embeddings:
if not isinstance(emb, float):
assert False, f"Bad embeddings: {embeddings}"
if emb != 0: if emb != 0:
embeddings_computed = True embeddings_computed = True
assert embeddings_computed, f"Embeddings: {embeddings}" assert embeddings_computed, f"Embeddings: {embeddings}"

View File

@ -1,5 +1,6 @@
aiohttp~=3.9.3 aiohttp~=3.9.3
behave~=1.2.6 behave~=1.2.6
huggingface_hub~=0.20.3 huggingface_hub~=0.20.3
numpy~=1.24.4
openai~=0.25.0 openai~=0.25.0
prometheus-client~=0.20.0 prometheus-client~=0.20.0

View File

@ -1,15 +1,16 @@
#pragma once #pragma once
#include <string> #include "llama.h"
#include <vector> #include "common.h"
#include <set>
#include <mutex>
#include <condition_variable>
#include <unordered_map>
#include "json.hpp" #include "json.hpp"
#include "../llava/clip.h" #include <string>
#include <vector>
#include <sstream>
#include <random>
#define DEFAULT_OAICOMPAT_MODEL "gpt-3.5-turbo-0613"
using json = nlohmann::json; using json = nlohmann::json;
@ -37,83 +38,35 @@ extern bool server_log_json;
#define LOG_WARNING(MSG, ...) server_log("WARN", __func__, __LINE__, MSG, __VA_ARGS__) #define LOG_WARNING(MSG, ...) server_log("WARN", __func__, __LINE__, MSG, __VA_ARGS__)
#define LOG_INFO( MSG, ...) server_log("INFO", __func__, __LINE__, MSG, __VA_ARGS__) #define LOG_INFO( MSG, ...) server_log("INFO", __func__, __LINE__, MSG, __VA_ARGS__)
enum server_state { template <typename T>
SERVER_STATE_LOADING_MODEL, // Server is starting up, model not fully loaded yet static T json_value(const json &body, const std::string &key, const T &default_value) {
SERVER_STATE_READY, // Server is ready and model is loaded // Fallback null to default value
SERVER_STATE_ERROR // An error occurred, load_model failed return body.contains(key) && !body.at(key).is_null()
}; ? body.value(key, default_value)
: default_value;
enum task_type { }
TASK_TYPE_COMPLETION,
TASK_TYPE_CANCEL,
TASK_TYPE_NEXT_RESPONSE,
TASK_TYPE_METRICS
};
struct task_server {
int id = -1; // to be filled by llama_server_queue
int target_id;
task_type type;
json data;
bool infill_mode = false;
bool embedding_mode = false;
int multitask_id = -1;
};
struct task_result {
int id;
int multitask_id = -1;
bool stop;
bool error;
json result_json;
};
struct task_multi {
int id;
std::set<int> subtasks_remaining{};
std::vector<task_result> results{};
};
// completion token output with probabilities
struct completion_token_output {
struct token_prob
{
llama_token tok;
float prob;
};
std::vector<token_prob> probs;
llama_token tok;
std::string text_to_send;
};
struct token_translator {
llama_context * ctx;
std::string operator()(llama_token tok) const { return llama_token_to_piece(ctx, tok); }
std::string operator()(const completion_token_output &cto) const { return (*this)(cto.tok); }
};
static inline void server_log(const char *level, const char *function, int line, const char *message, const nlohmann::ordered_json &extra) { static inline void server_log(const char *level, const char *function, int line, const char *message, const nlohmann::ordered_json &extra) {
std::stringstream ss_tid; std::stringstream ss_tid;
ss_tid << std::this_thread::get_id(); ss_tid << std::this_thread::get_id();
json log = nlohmann::ordered_json{ json log = nlohmann::ordered_json{
{"tid", ss_tid.str()}, {"tid", ss_tid.str()},
{"timestamp", time(nullptr)}, {"timestamp", time(nullptr)},
}; };
if (server_log_json) { if (server_log_json) {
log.merge_patch( log.merge_patch( {
{ {"level", level},
{"level", level}, {"function", function},
{"function", function}, {"line", line},
{"line", line}, {"msg", message},
{"msg", message}, });
});
if (!extra.empty()) { if (!extra.empty()) {
log.merge_patch(extra); log.merge_patch(extra);
} }
std::cout << log.dump(-1, ' ', false, json::error_handler_t::replace) << "\n" << std::flush; printf("%s\n", log.dump(-1, ' ', false, json::error_handler_t::replace).c_str());
} else { } else {
char buf[1024]; char buf[1024];
snprintf(buf, 1024, "%4s [%24s] %s", level, function, message); snprintf(buf, 1024, "%4s [%24s] %s", level, function, message);
@ -136,22 +89,13 @@ static inline void server_log(const char *level, const char *function, int line,
} }
// //
// server utils // chat template utils
// //
template <typename T>
static T json_value(const json &body, const std::string &key, const T &default_value) {
// Fallback null to default value
return body.contains(key) && !body.at(key).is_null()
? body.value(key, default_value)
: default_value;
}
// Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid // Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid
inline bool verify_custom_template(const std::string & tmpl) { inline bool verify_custom_template(const std::string & tmpl) {
llama_chat_message chat[] = {{"user", "test"}}; llama_chat_message chat[] = {{"user", "test"}};
std::vector<char> buf(1); int res = llama_chat_apply_template(nullptr, tmpl.c_str(), chat, 1, true, nullptr, 0);
int res = llama_chat_apply_template(nullptr, tmpl.c_str(), chat, 1, true, buf.data(), buf.size());
return res >= 0; return res >= 0;
} }
@ -163,7 +107,7 @@ inline std::string format_chat(const struct llama_model * model, const std::stri
std::vector<llama_chat_message> chat(messages.size()); std::vector<llama_chat_message> chat(messages.size());
for (size_t i = 0; i < messages.size(); ++i) { for (size_t i = 0; i < messages.size(); ++i) {
auto &curr_msg = messages[i]; const auto & curr_msg = messages[i];
str[i*2 + 0] = json_value(curr_msg, "role", std::string("")); str[i*2 + 0] = json_value(curr_msg, "role", std::string(""));
str[i*2 + 1] = json_value(curr_msg, "content", std::string("")); str[i*2 + 1] = json_value(curr_msg, "content", std::string(""));
alloc_size += str[i*2 + 1].length(); alloc_size += str[i*2 + 1].length();
@ -183,261 +127,13 @@ inline std::string format_chat(const struct llama_model * model, const std::stri
res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), true, buf.data(), buf.size()); res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), true, buf.data(), buf.size());
} }
std::string formatted_chat(buf.data(), res); const std::string formatted_chat(buf.data(), res);
LOG_VERBOSE("formatted_chat", {{"text", formatted_chat.c_str()}}); LOG_VERBOSE("formatted_chat", {{"text", formatted_chat.c_str()}});
return formatted_chat; return formatted_chat;
} }
//
// work queue utils
//
struct llama_server_queue {
int id = 0;
std::mutex mutex_tasks;
bool running;
// queues
std::vector<task_server> queue_tasks;
std::vector<task_server> queue_tasks_deferred;
std::vector<task_multi> queue_multitasks;
std::condition_variable condition_tasks;
// callback functions
std::function<void(task_server&)> callback_new_task;
std::function<void(task_multi&)> callback_finish_multitask;
std::function<void(void)> callback_run_slots;
// Add a new task to the end of the queue
int post(task_server task) {
std::unique_lock<std::mutex> lock(mutex_tasks);
if (task.id == -1) {
task.id = id++;
LOG_VERBOSE("new task id", {{"new_id", task.id}});
}
queue_tasks.push_back(std::move(task));
condition_tasks.notify_one();
return task.id;
}
// Add a new task, but defer until one slot is available
void defer(task_server task) {
std::unique_lock<std::mutex> lock(mutex_tasks);
queue_tasks_deferred.push_back(std::move(task));
}
// Get the next id for creating anew task
int get_new_id() {
std::unique_lock<std::mutex> lock(mutex_tasks);
int new_id = id++;
LOG_VERBOSE("new task id", {{"new_id", new_id}});
return new_id;
}
// Register function to process a new task
void on_new_task(std::function<void(task_server&)> callback) {
callback_new_task = callback;
}
// Register function to process a multitask when it is finished
void on_finish_multitask(std::function<void(task_multi&)> callback) {
callback_finish_multitask = callback;
}
// Register the function to be called when all slots data is ready to be processed
void on_run_slots(std::function<void(void)> callback) {
callback_run_slots = callback;
}
// Call when the state of one slot is changed
void notify_slot_changed() {
// move deferred tasks back to main loop
std::unique_lock<std::mutex> lock(mutex_tasks);
for (auto & task : queue_tasks_deferred) {
queue_tasks.push_back(std::move(task));
}
queue_tasks_deferred.clear();
}
// end the start_loop routine
void terminate() {
{
std::unique_lock<std::mutex> lock(mutex_tasks);
running = false;
}
condition_tasks.notify_all();
}
/**
* Main loop consists of these steps:
* - Wait until a new task arrives
* - Process the task (i.e. maybe copy data into slot)
* - Check if multitask is finished
* - Run all slots
*/
void start_loop() {
running = true;
while (true) {
LOG_VERBOSE("new task may arrive", {});
{
while (true)
{
std::unique_lock<std::mutex> lock(mutex_tasks);
if (queue_tasks.empty()) {
lock.unlock();
break;
}
task_server task = queue_tasks.front();
queue_tasks.erase(queue_tasks.begin());
lock.unlock();
LOG_VERBOSE("callback_new_task", {{"task_id", task.id}});
callback_new_task(task);
}
LOG_VERBOSE("update_multitasks", {});
// check if we have any finished multitasks
auto queue_iterator = queue_multitasks.begin();
while (queue_iterator != queue_multitasks.end())
{
if (queue_iterator->subtasks_remaining.empty())
{
// all subtasks done == multitask is done
task_multi current_multitask = *queue_iterator;
callback_finish_multitask(current_multitask);
// remove this multitask
queue_iterator = queue_multitasks.erase(queue_iterator);
}
else
{
++queue_iterator;
}
}
// all tasks in the current loop is processed, slots data is now ready
LOG_VERBOSE("callback_run_slots", {});
callback_run_slots();
}
LOG_VERBOSE("wait for new task", {});
// wait for new task
{
std::unique_lock<std::mutex> lock(mutex_tasks);
if (queue_tasks.empty()) {
if (!running) {
LOG_VERBOSE("ending start_loop", {});
return;
}
condition_tasks.wait(lock, [&]{
return (!queue_tasks.empty() || !running);
});
}
}
}
}
//
// functions to manage multitasks
//
// add a multitask by specifying the id of all subtask (subtask is a task_server)
void add_multitask(int multitask_id, std::vector<int>& sub_ids)
{
std::lock_guard<std::mutex> lock(mutex_tasks);
task_multi multi;
multi.id = multitask_id;
std::copy(sub_ids.begin(), sub_ids.end(), std::inserter(multi.subtasks_remaining, multi.subtasks_remaining.end()));
queue_multitasks.push_back(multi);
}
// updatethe remaining subtasks, while appending results to multitask
void update_multitask(int multitask_id, int subtask_id, task_result& result)
{
std::lock_guard<std::mutex> lock(mutex_tasks);
for (auto& multitask : queue_multitasks)
{
if (multitask.id == multitask_id)
{
multitask.subtasks_remaining.erase(subtask_id);
multitask.results.push_back(result);
}
}
}
};
struct llama_server_response {
typedef std::function<void(int, int, task_result&)> callback_multitask_t;
callback_multitask_t callback_update_multitask;
// for keeping track of all tasks waiting for the result
std::set<int> waiting_task_ids;
// the main result queue
std::vector<task_result> queue_results;
std::mutex mutex_results;
std::condition_variable condition_results;
// add the task_id to the list of tasks waiting for response
void add_waiting_task_id(int task_id) {
LOG_VERBOSE("waiting for task id", {{"task_id", task_id}});
std::unique_lock<std::mutex> lock(mutex_results);
waiting_task_ids.insert(task_id);
}
// when the request is finished, we can remove task associated with it
void remove_waiting_task_id(int task_id) {
LOG_VERBOSE("remove waiting for task id", {{"task_id", task_id}});
std::unique_lock<std::mutex> lock(mutex_results);
waiting_task_ids.erase(task_id);
}
// This function blocks the thread until there is a response for this task_id
task_result recv(int task_id) {
while (true)
{
std::unique_lock<std::mutex> lock(mutex_results);
condition_results.wait(lock, [&]{
return !queue_results.empty();
});
for (int i = 0; i < (int) queue_results.size(); i++)
{
if (queue_results[i].id == task_id)
{
assert(queue_results[i].multitask_id == -1);
task_result res = queue_results[i];
queue_results.erase(queue_results.begin() + i);
return res;
}
}
}
// should never reach here
}
// Register the function to update multitask
void on_multitask_update(callback_multitask_t callback) {
callback_update_multitask = callback;
}
// Send a new result to a waiting task_id
void send(task_result result) {
std::unique_lock<std::mutex> lock(mutex_results);
LOG_VERBOSE("send new result", {{"task_id", result.id}});
for (auto& task_id : waiting_task_ids) {
// LOG_TEE("waiting task id %i \n", task_id);
// for now, tasks that have associated parent multitasks just get erased once multitask picks up the result
if (result.multitask_id == task_id)
{
LOG_VERBOSE("callback_update_multitask", {{"task_id", task_id}});
callback_update_multitask(task_id, result.id, result);
continue;
}
if (result.id == task_id)
{
LOG_VERBOSE("queue_results.push_back", {{"task_id", task_id}});
queue_results.push_back(result);
condition_results.notify_all();
return;
}
}
}
};
// //
// base64 utils (TODO: move to common in the future) // base64 utils (TODO: move to common in the future)
// //
@ -447,13 +143,11 @@ static const std::string base64_chars =
"abcdefghijklmnopqrstuvwxyz" "abcdefghijklmnopqrstuvwxyz"
"0123456789+/"; "0123456789+/";
static inline bool is_base64(uint8_t c) static inline bool is_base64(uint8_t c) {
{
return (isalnum(c) || (c == '+') || (c == '/')); return (isalnum(c) || (c == '+') || (c == '/'));
} }
static inline std::vector<uint8_t> base64_decode(const std::string & encoded_string) static inline std::vector<uint8_t> base64_decode(const std::string & encoded_string) {
{
int i = 0; int i = 0;
int j = 0; int j = 0;
int in_ = 0; int in_ = 0;
@ -465,13 +159,10 @@ static inline std::vector<uint8_t> base64_decode(const std::string & encoded_str
std::vector<uint8_t> ret; std::vector<uint8_t> ret;
while (in_len-- && (encoded_string[in_] != '=') && is_base64(encoded_string[in_])) while (in_len-- && (encoded_string[in_] != '=') && is_base64(encoded_string[in_])) {
{
char_array_4[i++] = encoded_string[in_]; in_++; char_array_4[i++] = encoded_string[in_]; in_++;
if (i == 4) if (i == 4) {
{ for (i = 0; i < 4; i++) {
for (i = 0; i <4; i++)
{
char_array_4[i] = base64_chars.find(char_array_4[i]); char_array_4[i] = base64_chars.find(char_array_4[i]);
} }
@ -479,23 +170,20 @@ static inline std::vector<uint8_t> base64_decode(const std::string & encoded_str
char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2); char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2);
char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3]; char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3];
for (i = 0; (i < 3); i++) for (i = 0; (i < 3); i++) {
{
ret.push_back(char_array_3[i]); ret.push_back(char_array_3[i]);
} }
i = 0; i = 0;
} }
} }
if (i) if (i) {
{ for (j = i; j < 4; j++) {
for (j = i; j <4; j++)
{
char_array_4[j] = 0; char_array_4[j] = 0;
} }
for (j = 0; j <4; j++) for (j = 0; j < 4; j++) {
{
char_array_4[j] = base64_chars.find(char_array_4[j]); char_array_4[j] = base64_chars.find(char_array_4[j]);
} }
@ -503,8 +191,7 @@ static inline std::vector<uint8_t> base64_decode(const std::string & encoded_str
char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2); char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2);
char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3]; char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3];
for (j = 0; (j < i - 1); j++) for (j = 0; j < i - 1; j++) {
{
ret.push_back(char_array_3[j]); ret.push_back(char_array_3[j]);
} }
} }
@ -516,8 +203,7 @@ static inline std::vector<uint8_t> base64_decode(const std::string & encoded_str
// random string / id // random string / id
// //
static std::string random_string() static std::string random_string() {
{
static const std::string str("0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"); static const std::string str("0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz");
std::random_device rd; std::random_device rd;
@ -532,10 +218,10 @@ static std::string random_string()
return result; return result;
} }
static std::string gen_chatcmplid() static std::string gen_chatcmplid() {
{
std::stringstream chatcmplid; std::stringstream chatcmplid;
chatcmplid << "chatcmpl-" << random_string(); chatcmplid << "chatcmpl-" << random_string();
return chatcmplid.str(); return chatcmplid.str();
} }
@ -543,91 +229,316 @@ static std::string gen_chatcmplid()
// other common utils // other common utils
// //
static size_t common_part(const std::vector<llama_token> &a, const std::vector<llama_token> &b) static size_t common_part(const std::vector<llama_token> & a, const std::vector<llama_token> & b) {
{
size_t i; size_t i;
for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++) for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++) {}
{
}
return i; return i;
} }
static bool ends_with(const std::string &str, const std::string &suffix) static bool ends_with(const std::string & str, const std::string & suffix) {
{ return str.size() >= suffix.size() && 0 == str.compare(str.size() - suffix.size(), suffix.size(), suffix);
return str.size() >= suffix.size() &&
0 == str.compare(str.size() - suffix.size(), suffix.size(), suffix);
} }
static size_t find_partial_stop_string(const std::string &stop, static size_t find_partial_stop_string(const std::string &stop, const std::string &text) {
const std::string &text) if (!text.empty() && !stop.empty()) {
{
if (!text.empty() && !stop.empty())
{
const char text_last_char = text.back(); const char text_last_char = text.back();
for (int64_t char_index = stop.size() - 1; char_index >= 0; char_index--) for (int64_t char_index = stop.size() - 1; char_index >= 0; char_index--) {
{ if (stop[char_index] == text_last_char) {
if (stop[char_index] == text_last_char)
{
const std::string current_partial = stop.substr(0, char_index + 1); const std::string current_partial = stop.substr(0, char_index + 1);
if (ends_with(text, current_partial)) if (ends_with(text, current_partial)) {
{
return text.size() - char_index - 1; return text.size() - char_index - 1;
} }
} }
} }
} }
return std::string::npos; return std::string::npos;
} }
// TODO: reuse llama_detokenize // TODO: reuse llama_detokenize
template <class Iter> template <class Iter>
static std::string tokens_to_str(llama_context *ctx, Iter begin, Iter end) static std::string tokens_to_str(llama_context * ctx, Iter begin, Iter end) {
{
std::string ret; std::string ret;
for (; begin != end; ++begin) for (; begin != end; ++begin) {
{
ret += llama_token_to_piece(ctx, *begin); ret += llama_token_to_piece(ctx, *begin);
} }
return ret; return ret;
} }
// format incomplete utf-8 multibyte character for output // format incomplete utf-8 multibyte character for output
static std::string tokens_to_output_formatted_string(const llama_context *ctx, const llama_token token) static std::string tokens_to_output_formatted_string(const llama_context * ctx, const llama_token token) {
{
std::string out = token == -1 ? "" : llama_token_to_piece(ctx, token); std::string out = token == -1 ? "" : llama_token_to_piece(ctx, token);
// if the size is 1 and first bit is 1, meaning it's a partial character // if the size is 1 and first bit is 1, meaning it's a partial character
// (size > 1 meaning it's already a known token) // (size > 1 meaning it's already a known token)
if (out.size() == 1 && (out[0] & 0x80) == 0x80) if (out.size() == 1 && (out[0] & 0x80) == 0x80) {
{
std::stringstream ss; std::stringstream ss;
ss << std::hex << (out[0] & 0xff); ss << std::hex << (out[0] & 0xff);
std::string res(ss.str()); std::string res(ss.str());
out = "byte: \\x" + res; out = "byte: \\x" + res;
} }
return out; return out;
} }
struct completion_token_output {
llama_token tok;
std::string text_to_send;
struct token_prob {
llama_token tok;
float prob;
};
std::vector<token_prob> probs;
};
// convert a vector of completion_token_output to json // convert a vector of completion_token_output to json
static json probs_vector_to_json(const llama_context *ctx, const std::vector<completion_token_output> &probs) static json probs_vector_to_json(const llama_context * ctx, const std::vector<completion_token_output> & probs) {
{
json out = json::array(); json out = json::array();
for (const auto &prob : probs)
{ for (const auto & prob : probs) {
json probs_for_token = json::array(); json probs_for_token = json::array();
for (const auto &p : prob.probs)
{ for (const auto & p : prob.probs) {
std::string tok_str = tokens_to_output_formatted_string(ctx, p.tok); const std::string tok_str = tokens_to_output_formatted_string(ctx, p.tok);
probs_for_token.push_back(json probs_for_token.push_back(json {
{
{"tok_str", tok_str}, {"tok_str", tok_str},
{"prob", p.prob}, {"prob", p.prob},
}); });
} }
std::string tok_str = tokens_to_output_formatted_string(ctx, prob.tok);
out.push_back(json{ const std::string tok_str = tokens_to_output_formatted_string(ctx, prob.tok);
out.push_back(json {
{"content", tok_str}, {"content", tok_str},
{"probs", probs_for_token}, {"probs", probs_for_token},
}); });
} }
return out; return out;
} }
//
// OAI utils
//
static json oaicompat_completion_params_parse(
const struct llama_model * model,
const json & body, /* openai api json semantics */
const std::string & chat_template) {
json llama_params;
llama_params["__oaicompat"] = true;
// Map OpenAI parameters to llama.cpp parameters
//
// For parameters that are defined by the OpenAI documentation (e.g.
// temperature), we explicitly specify OpenAI's intended default; we
// need to do that because sometimes OpenAI disagrees with llama.cpp
//
// https://platform.openai.com/docs/api-reference/chat/create
llama_sampling_params default_sparams;
llama_params["model"] = json_value(body, "model", std::string("unknown"));
llama_params["prompt"] = format_chat(model, chat_template, body["messages"]);
llama_params["cache_prompt"] = json_value(body, "cache_prompt", false);
llama_params["temperature"] = json_value(body, "temperature", 0.0);
llama_params["top_k"] = json_value(body, "top_k", default_sparams.top_k);
llama_params["top_p"] = json_value(body, "top_p", 1.0);
llama_params["n_predict"] = json_value(body, "max_tokens", -1);
llama_params["logit_bias"] = json_value(body, "logit_bias", json::object());
llama_params["frequency_penalty"] = json_value(body, "frequency_penalty", 0.0);
llama_params["presence_penalty"] = json_value(body, "presence_penalty", 0.0);
llama_params["seed"] = json_value(body, "seed", LLAMA_DEFAULT_SEED);
llama_params["stream"] = json_value(body, "stream", false);
llama_params["mirostat"] = json_value(body, "mirostat", default_sparams.mirostat);
llama_params["mirostat_tau"] = json_value(body, "mirostat_tau", default_sparams.mirostat_tau);
llama_params["mirostat_eta"] = json_value(body, "mirostat_eta", default_sparams.mirostat_eta);
llama_params["penalize_nl"] = json_value(body, "penalize_nl", default_sparams.penalize_nl);
llama_params["typical_p"] = json_value(body, "typical_p", default_sparams.typical_p);
llama_params["repeat_last_n"] = json_value(body, "repeat_last_n", default_sparams.penalty_last_n);
llama_params["ignore_eos"] = json_value(body, "ignore_eos", false);
llama_params["tfs_z"] = json_value(body, "tfs_z", default_sparams.tfs_z);
if (body.count("grammar") != 0) {
llama_params["grammar"] = json_value(body, "grammar", json::object());
}
// Handle 'stop' field
if (body.contains("stop") && body["stop"].is_string()) {
llama_params["stop"] = json::array({body["stop"].get<std::string>()});
} else {
llama_params["stop"] = json_value(body, "stop", json::array());
}
// Ensure there is ChatML-specific end sequence among stop words
llama_params["stop"].push_back("<|im_end|>");
return llama_params;
}
static json format_final_response_oaicompat(const json & request, json result, bool streaming = false) {
bool stopped_word = result.count("stopped_word") != 0;
bool stopped_eos = json_value(result, "stopped_eos", false);
int num_tokens_predicted = json_value(result, "tokens_predicted", 0);
int num_prompt_tokens = json_value(result, "tokens_evaluated", 0);
std::string content = json_value(result, "content", std::string(""));
std::string finish_reason = "length";
if (stopped_word || stopped_eos) {
finish_reason = "stop";
}
json choices =
streaming ? json::array({json{{"finish_reason", finish_reason},
{"index", 0},
{"delta", json::object()}}})
: json::array({json{{"finish_reason", finish_reason},
{"index", 0},
{"message", json{{"content", content},
{"role", "assistant"}}}}});
std::time_t t = std::time(0);
json res = json {
{"choices", choices},
{"created", t},
{"model",
json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
{"object", streaming ? "chat.completion.chunk" : "chat.completion"},
{"usage", json {
{"completion_tokens", num_tokens_predicted},
{"prompt_tokens", num_prompt_tokens},
{"total_tokens", num_tokens_predicted + num_prompt_tokens}
}},
{"id", gen_chatcmplid()}
};
if (server_verbose) {
res["__verbose"] = result;
}
if (result.contains("completion_probabilities")) {
res["completion_probabilities"] = json_value(result, "completion_probabilities", json::array());
}
return res;
}
// return value is vector as there is one case where we might need to generate two responses
static std::vector<json> format_partial_response_oaicompat(json result) {
if (!result.contains("model") || !result.contains("oaicompat_token_ctr")) {
return std::vector<json>({result});
}
bool first = json_value(result, "oaicompat_token_ctr", 0) == 0;
std::string modelname = json_value(result, "model", std::string(DEFAULT_OAICOMPAT_MODEL));
bool stopped_word = json_value(result, "stopped_word", false);
bool stopped_eos = json_value(result, "stopped_eos", false);
bool stopped_limit = json_value(result, "stopped_limit", false);
std::string content = json_value(result, "content", std::string(""));
std::string finish_reason;
if (stopped_word || stopped_eos) {
finish_reason = "stop";
}
if (stopped_limit) {
finish_reason = "length";
}
std::time_t t = std::time(0);
json choices;
if (!finish_reason.empty()) {
choices = json::array({json{{"finish_reason", finish_reason},
{"index", 0},
{"delta", json::object()}}});
} else {
if (first) {
if (content.empty()) {
choices = json::array({json{{"finish_reason", nullptr},
{"index", 0},
{"delta", json{{"role", "assistant"}}}}});
} else {
// We have to send this as two updates to conform to openai behavior
json initial_ret = json{{"choices", json::array({json{
{"finish_reason", nullptr},
{"index", 0},
{"delta", json{
{"role", "assistant"}
}}}})},
{"created", t},
{"id", gen_chatcmplid()},
{"model", modelname},
{"object", "chat.completion.chunk"}};
json second_ret = json{
{"choices", json::array({json{{"finish_reason", nullptr},
{"index", 0},
{"delta", json{
{"content", content}}}
}})},
{"created", t},
{"id", gen_chatcmplid()},
{"model", modelname},
{"object", "chat.completion.chunk"}};
return std::vector<json>({initial_ret, second_ret});
}
} else {
// Some idiosyncrasy in task processing logic makes several trailing calls
// with empty content, we ignore these at the calee site.
if (content.empty()) {
return std::vector<json>({json::object()});
}
choices = json::array({json{
{"finish_reason", nullptr},
{"index", 0},
{"delta",
json{
{"content", content},
}},
}});
}
}
json ret = json {
{"choices", choices},
{"created", t},
{"id", gen_chatcmplid()},
{"model", modelname},
{"object", "chat.completion.chunk"}
};
return std::vector<json>({ret});
}
static json format_embeddings_response_oaicompat(const json & request, const json & embeddings) {
json res = json {
{"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
{"object", "list"},
{"usage", json {
{"prompt_tokens", 0},
{"total_tokens", 0}
}},
{"data", embeddings}
};
return res;
}
static json format_tokenizer_response(const std::vector<llama_token> & tokens) {
return json {
{"tokens", tokens}
};
}
static json format_detokenized_response(const std::string & content) {
return json {
{"content", content}
};
}

View File

@ -13541,18 +13541,22 @@ LLAMA_API int32_t llama_chat_apply_template(
curr_tmpl = std::string(model_template.data(), model_template.size()); curr_tmpl = std::string(model_template.data(), model_template.size());
} }
} }
// format the chat to string // format the chat to string
std::vector<const llama_chat_message *> chat_vec; std::vector<const llama_chat_message *> chat_vec;
chat_vec.resize(n_msg); chat_vec.resize(n_msg);
for (size_t i = 0; i < n_msg; i++) { for (size_t i = 0; i < n_msg; i++) {
chat_vec[i] = &chat[i]; chat_vec[i] = &chat[i];
} }
std::string formatted_chat; std::string formatted_chat;
int32_t res = llama_chat_apply_template_internal(curr_tmpl, chat_vec, formatted_chat, add_ass); int32_t res = llama_chat_apply_template_internal(curr_tmpl, chat_vec, formatted_chat, add_ass);
if (res < 0) { if (res < 0) {
return res; return res;
} }
strncpy(buf, formatted_chat.c_str(), length); if (buf && length > 0) {
strncpy(buf, formatted_chat.c_str(), length);
}
return res; return res;
} }