import pytest
from openai import OpenAI
from utils import *

server = ServerPreset.tinyllama2()


@pytest.fixture(scope="module", autouse=True)
def create_server():
    """Rebind the module-level ``server`` to a fresh tinyllama2 preset.

    Declared ``autouse`` with module scope, so pytest runs it once for this
    module without tests having to request it.
    """
    global server
    server = ServerPreset.tinyllama2()


@pytest.mark.parametrize(
    "model,system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,truncated",
    [
        ("llama-2", "Book", "What is the best book", 8, "(Suddenly)+", 77, 8, False),
        ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, False),
    ]
)
def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, truncated):
    """POST a non-streaming chat completion and check usage, role, content and finish reason."""
    server.start()
    res = server.make_request("POST", "/chat/completions", data={
        "model": model,
        "max_tokens": max_tokens,
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
    })
    assert res.status_code == 200
    assert res.body["usage"]["prompt_tokens"] == n_prompt
    assert res.body["usage"]["completion_tokens"] == n_predicted
    choice = res.body["choices"][0]
    assert "assistant" == choice["message"]["role"]
    assert match_regex(re_content, choice["message"]["content"])
    expected_finish_reason = "length" if truncated else "stop"
    assert choice["finish_reason"] == expected_finish_reason


@pytest.mark.parametrize(
    "model,system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,truncated",
    [
        ("llama-2", "Book", "What is the best book", 8, "(Suddenly)+", 77, 8, False),
        ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, False),
    ]
)
def test_chat_completion_stream(model, system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, truncated):
    """POST a streaming chat completion and validate every chunk.

    Intermediate chunks must have ``finish_reason`` of ``None`` and carry a
    content delta; the chunk whose ``finish_reason`` is ``"stop"`` or
    ``"length"`` must carry the usage counters and no content delta.
    """
    server.start()
    res = server.make_stream_request("POST", "/chat/completions", data={
        "model": model,
        "max_tokens": max_tokens,
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
        "stream": True,
    })
    content = ""
    saw_final_chunk = False
    for data in res:
        choice = data["choices"][0]
        if choice["finish_reason"] in ["stop", "length"]:
            saw_final_chunk = True
            assert data["usage"]["prompt_tokens"] == n_prompt
            assert data["usage"]["completion_tokens"] == n_predicted
            assert "content" not in choice["delta"]
            assert match_regex(re_content, content)
            # FIXME: the exact finish_reason ("length" when truncated, else
            # "stop") is not asserted here; that check was reported as
            # incorrect in stream mode, so `truncated` is currently unused.
        else:
            assert choice["finish_reason"] is None
            content += choice["delta"]["content"]
    # Without this guard the test passes vacuously when the stream is empty
    # or ends before a chunk with a finish_reason arrives.
    assert saw_final_chunk, "stream ended without a chunk carrying a finish_reason"


def test_chat_completion_with_openai_library():
    """Run a chat completion through the official ``openai`` client pointed at the test server."""
    server.start()
    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}")
    res = client.chat.completions.create(
        model="gpt-3.5-turbo-instruct",
        messages=[
            {"role": "system", "content": "Book"},
            {"role": "user", "content": "What is the best book"},
        ],
        max_tokens=8,
        seed=42,
        temperature=0.8,
    )
    print(res)
    first_choice = res.choices[0]
    assert first_choice.finish_reason == "stop"
    assert first_choice.message.content is not None
    assert match_regex("(Suddenly)+", first_choice.message.content)


@pytest.mark.parametrize("response_format,n_predicted,re_content", [
    ({"type": "json_object", "schema": {"const": "42"}}, 6, "\"42\""),
    ({"type": "json_object", "schema": {"items": [{"type": "integer"}]}}, 10, "[ -3000 ]"),
    ({"type": "json_object"}, 10, "(\\{|John)+"),
    ({"type": "sound"}, 0, None),  # invalid response format (expected to fail)
    ({"type": "json_object", "schema": 123}, 0, None),
    ({"type": "json_object", "schema": {"type": 123}}, 0, None),
    ({"type": "json_object", "schema": {"type": "hiccup"}}, 0, None),
])
def test_completion_with_response_format(response_format: dict, n_predicted: int, re_content: str | None):
    """Check ``response_format`` handling on the chat completion endpoint.

    When ``re_content`` is given, the request must succeed and the generated
    content must match it. When it is ``None``, the request must be rejected
    with a non-200 status and an ``error`` entry in the body.
    """
    server.start()
    res = server.make_request("POST", "/chat/completions", data={
        "max_tokens": n_predicted,
        "messages": [
            {"role": "system", "content": "You are a coding assistant."},
            {"role": "user", "content": "Write an example"},
        ],
        "response_format": response_format,
    })
    if re_content is not None:
        assert res.status_code == 200
        choice = res.body["choices"][0]
        assert match_regex(re_content, choice["message"]["content"])
    else:
        assert res.status_code != 200
        assert "error" in res.body


@pytest.mark.parametrize("messages", [
    None,
    "string",
    [123],
    [{}],
    [{"role": 123}],
    [{"role": "system", "content": 123}],
    # [{"content": "hello"}], # TODO: should not be a valid case
    [{"role": "system", "content": "test"}, {}],
])
def test_invalid_chat_completion_req(messages):
    """Malformed ``messages`` payloads must yield a 400 or 500 with an ``error`` entry."""
    server.start()
    res = server.make_request("POST", "/chat/completions", data={
        "messages": messages,
    })
    assert res.status_code in (400, 500)
    assert "error" in res.body