diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 9fb436c2a..19a8c1067 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -1836,7 +1836,7 @@ struct llama_server_context send_embedding(slot); slot.release(); slot.i_batch = -1; - return true; + continue; } completion_token_output result; diff --git a/examples/server/tests/features/issues.feature b/examples/server/tests/features/issues.feature index 542006d9a..bf5a175a3 100644 --- a/examples/server/tests/features/issues.feature +++ b/examples/server/tests/features/issues.feature @@ -1,36 +1,4 @@ # List of ongoing issues @bug Feature: Issues - # Issue #5655 - Scenario: Multi users embeddings - Given a server listening on localhost:8080 - And a model file stories260K.gguf - And a model alias tinyllama-2 - And 42 as server seed - And 64 KV cache size - And 2 slots - And continuous batching - And embeddings extraction - Then the server is starting - Then the server is healthy - - Given a prompt: - """ - Write a very long story about AI. - """ - And a prompt: - """ - Write another very long music lyrics. - """ - And a prompt: - """ - Write a very long poem. - """ - And a prompt: - """ - Write a very long joke. - """ - Given concurrent embedding requests - Then the server is busy - Then the server is idle - Then all embeddings are generated + # No confirmed issue at the moment diff --git a/examples/server/tests/features/parallel.feature b/examples/server/tests/features/parallel.feature index 802d624ff..c85f9de1d 100644 --- a/examples/server/tests/features/parallel.feature +++ b/examples/server/tests/features/parallel.feature @@ -8,6 +8,7 @@ Feature: Parallel And 42 as server seed And 64 KV cache size And 2 slots + And embeddings extraction And continuous batching Then the server is starting Then the server is healthy @@ -75,3 +76,48 @@ Feature: Parallel Then the server is busy Then the server is idle Then all prompts are predicted + + Scenario: Multi users embeddings + Given a prompt: + """ + Write a very long story about AI. + """ + And a prompt: + """ + Write another very long music lyrics. + """ + And a prompt: + """ + Write a very long poem. + """ + And a prompt: + """ + Write a very long joke. + """ + Given concurrent embedding requests + Then the server is busy + Then the server is idle + Then all embeddings are generated + + Scenario: Multi users OAI compatibility embeddings + Given a prompt: + """ + In which country Paris is located ? + """ + And a prompt: + """ + Is Madrid the capital of Spain ? + """ + And a prompt: + """ + What is the biggest US city ? + """ + And a prompt: + """ + What is the capital of Bulgaria ? + """ + And a model tinyllama-2 + Given concurrent OAI embedding requests + Then the server is busy + Then the server is idle + Then all embeddings are generated diff --git a/examples/server/tests/features/server.feature b/examples/server/tests/features/server.feature index fedcfe5ae..5f81d256a 100644 --- a/examples/server/tests/features/server.feature +++ b/examples/server/tests/features/server.feature @@ -60,6 +60,19 @@ Feature: llama.cpp server """ Then embeddings are generated + Scenario: OAI Embeddings compatibility with multiple inputs + Given a model tinyllama-2 + Given a prompt: + """ + In which country Paris is located ? + """ + And a prompt: + """ + Is Madrid the capital of Spain ? + """ + When an OAI compatible embeddings computation request for multiple inputs + Then embeddings are generated + Scenario: Tokenize / Detokenize When tokenizing: diff --git a/examples/server/tests/features/steps/steps.py b/examples/server/tests/features/steps/steps.py index 50f2b641e..9c825fdbc 100644 --- a/examples/server/tests/features/steps/steps.py +++ b/examples/server/tests/features/steps/steps.py @@ -1,4 +1,5 @@ import asyncio +import collections import json import os import re @@ -261,35 +262,35 @@ def step_a_prompt_prompt(context, prompt): @step(u'concurrent completion requests') @async_run_until_complete() async def step_concurrent_completion_requests(context): - await concurrent_completion_requests(context, - request_completion, - # prompt is inserted automatically - context.base_url, - debug=context.debug, - n_predict=context.n_predict if hasattr(context, 'n_predict') else None, - server_seed=context.server_seed if hasattr(context, 'server_seed') else None, - user_api_key=context.user_api_key if hasattr(context, - 'user_api_key') else None) + await concurrent_requests(context, + request_completion, + # prompt is inserted automatically + context.base_url, + debug=context.debug, + n_predict=context.n_predict if hasattr(context, 'n_predict') else None, + server_seed=context.server_seed if hasattr(context, 'server_seed') else None, + user_api_key=context.user_api_key if hasattr(context, + 'user_api_key') else None) @step(u'concurrent OAI completions requests') @async_run_until_complete async def step_oai_chat_completions(context): - await concurrent_completion_requests(context, oai_chat_completions, - # user_prompt is inserted automatically - context.system_prompt, - context.base_url, - True, # async_client - model=context.model - if hasattr(context, 'model') else None, - n_predict=context.n_predict - if hasattr(context, 'n_predict') else None, - enable_streaming=context.enable_streaming - if hasattr(context, 'enable_streaming') else None, - server_seed=context.server_seed - if hasattr(context, 'server_seed') else None, - user_api_key=context.user_api_key - if hasattr(context, 'user_api_key') else None) + await concurrent_requests(context, oai_chat_completions, + # user_prompt is inserted automatically + context.system_prompt, + context.base_url, + True, # async_client + model=context.model + if hasattr(context, 'model') else None, + n_predict=context.n_predict + if hasattr(context, 'n_predict') else None, + enable_streaming=context.enable_streaming + if hasattr(context, 'enable_streaming') else None, + server_seed=context.server_seed + if hasattr(context, 'server_seed') else None, + user_api_key=context.user_api_key + if hasattr(context, 'user_api_key') else None) @step(u'all prompts are predicted') @@ -316,36 +317,58 @@ async def all_prompts_are_predicted(context, expected_predicted_n=None): @step(u'embeddings are computed for') @async_run_until_complete async def step_compute_embedding(context): - content = context.text - base_url = context.base_url - context.embeddings = await request_embedding(content, base_url) + context.embeddings = await request_embedding(context.text, base_url=context.base_url) @step(u'embeddings are generated') def step_assert_embeddings(context): - assert_embeddings(context.embeddings) + if len(context.prompts) == 0: + assert_embeddings(context.embeddings) + else: + assert len(context.embeddings) == len(context.prompts), (f"unexpected response:\n" + f"context.prompts={context.prompts}\n" + f"context.embeddings={context.embeddings}") + for embedding in context.embeddings: + context.prompts.pop() + assert_embeddings(embedding) @step(u'an OAI compatible embeddings computation request for') -def step_oai_compute_embedding(context): - openai.api_key = 'nope' # openai client always expects an api_keu - if context.user_api_key is not None: - openai.api_key = context.user_api_key - openai.api_base = f'{context.base_url}/v1' - embeddings = openai.Embedding.create( - model=context.model, - input=context.text, - ) - context.embeddings = embeddings +@async_run_until_complete +async def step_oai_compute_embeddings(context): + context.embeddings = await request_oai_embeddings(context.text, + base_url=context.base_url, + user_api_key=context.user_api_key, + model=context.model) + + +@step(u'an OAI compatible embeddings computation request for multiple inputs') +@async_run_until_complete +async def step_oai_compute_embeddings_multiple_inputs(context): + context.embeddings = await request_oai_embeddings(context.prompts, + base_url=context.base_url, + user_api_key=context.user_api_key, + model=context.model) @step(u'concurrent embedding requests') @async_run_until_complete() async def step_concurrent_embedding_requests(context): - await concurrent_completion_requests(context, - request_embedding, - # prompt is inserted automatically - context.base_url) + await concurrent_requests(context, + request_embedding, + # prompt is inserted automatically + base_url=context.base_url) + + +@step(u'concurrent OAI embedding requests') +@async_run_until_complete() +async def step_concurrent_oai_embedding_requests(context): + await concurrent_requests(context, + request_oai_embeddings, + # prompt is inserted automatically + base_url=context.base_url, + async_client=True, + model=context.model) @step(u'all embeddings are generated') @@ -401,7 +424,7 @@ def step_check_options_header_value(context, cors_header, cors_header_value): assert context.options_response.headers[cors_header] == cors_header_value -async def concurrent_completion_requests(context, f_completion, *args, **kwargs): +async def concurrent_requests(context, f_completion, *args, **kwargs): n_prompts = len(context.prompts) if context.debug: print(f"starting {n_prompts} concurrent completion requests...") @@ -565,7 +588,7 @@ async def oai_chat_completions(user_prompt, return completion_response -async def request_embedding(content, base_url): +async def request_embedding(content, base_url=None): async with aiohttp.ClientSession() as session: async with session.post(f'{base_url}/embedding', json={ @@ -576,6 +599,46 @@ async def request_embedding(content, base_url): return response_json['embedding'] +async def request_oai_embeddings(input, + base_url=None, user_api_key=None, + model=None, async_client=False): + # openai client always expects an api_key + user_api_key = user_api_key if user_api_key is not None else 'nope' + if async_client: + origin = 'llama.cpp' + if user_api_key is not None: + headers = {'Authorization': f'Bearer {user_api_key}', 'Origin': origin} + async with aiohttp.ClientSession() as session: + async with session.post(f'{base_url}/v1/embeddings', + json={ + "input": input, + "model": model, + }, + headers=headers) as response: + assert response.status == 200, f"received status code not expected: {response.status}" + assert response.headers['Access-Control-Allow-Origin'] == origin + assert response.headers['Content-Type'] == "application/json; charset=utf-8" + response_json = await response.json() + assert response_json['model'] == model, f"invalid model received: {response_json['model']}" + assert response_json['object'] == 'list' + return response_json['data'] + else: + openai.api_key = user_api_key + openai.api_base = f'{base_url}/v1' + oai_embeddings = openai.Embedding.create( + model=model, + input=input, + ) + + if isinstance(input, collections.abc.Sequence): + embeddings = [] + for an_oai_embeddings in oai_embeddings.data: + embeddings.append(an_oai_embeddings.embedding) + else: + embeddings = oai_embeddings.data.embedding + return embeddings + + def assert_n_tokens_predicted(completion_response, expected_predicted_n=None, re_content=None): content = completion_response['content'] n_predicted = completion_response['timings']['predicted_n']