From 9ba399dfa7f115effc63d48e6860a94c9faa31b2 Mon Sep 17 00:00:00 2001 From: Reza Kakhki Date: Tue, 24 Dec 2024 21:33:04 +0100 Subject: [PATCH] server : add support for "encoding_format": "base64" to the */embeddings endpoints (#10967) * add support for base64 * fix base64 test * improve test --------- Co-authored-by: Xuan Son Nguyen --- examples/server/CMakeLists.txt | 1 + examples/server/server.cpp | 13 ++++++- examples/server/tests/unit/test_embedding.py | 41 ++++++++++++++++++++ examples/server/utils.hpp | 28 ++++++++++--- 4 files changed, 76 insertions(+), 7 deletions(-) diff --git a/examples/server/CMakeLists.txt b/examples/server/CMakeLists.txt index a27597cbc..1b7cc8c13 100644 --- a/examples/server/CMakeLists.txt +++ b/examples/server/CMakeLists.txt @@ -34,6 +34,7 @@ endforeach() add_executable(${TARGET} ${TARGET_SRCS}) install(TARGETS ${TARGET} RUNTIME) +target_include_directories(${TARGET} PRIVATE ${CMAKE_SOURCE_DIR}) target_link_libraries(${TARGET} PRIVATE common ${CMAKE_THREAD_LIBS_INIT}) if (LLAMA_SERVER_SSL) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 3fbfb13c4..30ff3b149 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -3790,6 +3790,17 @@ int main(int argc, char ** argv) { return; } + bool use_base64 = false; + if (body.count("encoding_format") != 0) { + const std::string& format = body.at("encoding_format"); + if (format == "base64") { + use_base64 = true; + } else if (format != "float") { + res_error(res, format_error_response("The format to return the embeddings in. Can be either float or base64", ERROR_TYPE_INVALID_REQUEST)); + return; + } + } + std::vector tokenized_prompts = tokenize_input_prompts(ctx_server.ctx, prompt, true, true); for (const auto & tokens : tokenized_prompts) { // this check is necessary for models that do not add BOS token to the input @@ -3841,7 +3852,7 @@ int main(int argc, char ** argv) { } // write JSON response - json root = oaicompat ? format_embeddings_response_oaicompat(body, responses) : json(responses); + json root = oaicompat ? format_embeddings_response_oaicompat(body, responses, use_base64) : json(responses); res_ok(res, root); }; diff --git a/examples/server/tests/unit/test_embedding.py b/examples/server/tests/unit/test_embedding.py index 43e372fc7..8b0eb42b0 100644 --- a/examples/server/tests/unit/test_embedding.py +++ b/examples/server/tests/unit/test_embedding.py @@ -1,3 +1,5 @@ +import base64 +import struct import pytest from openai import OpenAI from utils import * @@ -194,3 +196,42 @@ def test_embedding_usage_multiple(): assert res.status_code == 200 assert res.body['usage']['prompt_tokens'] == res.body['usage']['total_tokens'] assert res.body['usage']['prompt_tokens'] == 2 * 9 + + +def test_embedding_openai_library_base64(): + server.start() + test_input = "Test base64 embedding output" + + # get embedding in default format + res = server.make_request("POST", "/v1/embeddings", data={ + "input": test_input + }) + assert res.status_code == 200 + vec0 = res.body["data"][0]["embedding"] + + # get embedding in base64 format + res = server.make_request("POST", "/v1/embeddings", data={ + "input": test_input, + "encoding_format": "base64" + }) + + assert res.status_code == 200 + assert "data" in res.body + assert len(res.body["data"]) == 1 + + embedding_data = res.body["data"][0] + assert "embedding" in embedding_data + assert isinstance(embedding_data["embedding"], str) + + # Verify embedding is valid base64 + decoded = base64.b64decode(embedding_data["embedding"]) + # Verify decoded data can be converted back to float array + float_count = len(decoded) // 4 # 4 bytes per float + floats = struct.unpack(f'{float_count}f', decoded) + assert len(floats) > 0 + assert all(isinstance(x, float) for x in floats) + assert len(floats) == len(vec0) + + # make sure the decoded data is the same as the original + for x, y in zip(floats, vec0): + assert abs(x - y) < EPSILON diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp index 043d8b528..334f2f192 100644 --- a/examples/server/utils.hpp +++ b/examples/server/utils.hpp @@ -3,6 +3,7 @@ #include "common.h" #include "log.h" #include "llama.h" +#include "common/base64.hpp" #ifndef NDEBUG // crash the server in debug mode, otherwise send an http 500 error @@ -613,16 +614,31 @@ static json oaicompat_completion_params_parse( return llama_params; } -static json format_embeddings_response_oaicompat(const json & request, const json & embeddings) { +static json format_embeddings_response_oaicompat(const json & request, const json & embeddings, bool use_base64 = false) { json data = json::array(); int32_t n_tokens = 0; int i = 0; for (const auto & elem : embeddings) { - data.push_back(json{ - {"embedding", json_value(elem, "embedding", json::array())}, - {"index", i++}, - {"object", "embedding"} - }); + json embedding_obj; + + if (use_base64) { + const auto& vec = json_value(elem, "embedding", json::array()).get>(); + const char* data_ptr = reinterpret_cast(vec.data()); + size_t data_size = vec.size() * sizeof(float); + embedding_obj = { + {"embedding", base64::encode(data_ptr, data_size)}, + {"index", i++}, + {"object", "embedding"}, + {"encoding_format", "base64"} + }; + } else { + embedding_obj = { + {"embedding", json_value(elem, "embedding", json::array())}, + {"index", i++}, + {"object", "embedding"} + }; + } + data.push_back(embedding_obj); n_tokens += json_value(elem, "tokens_evaluated", 0); }