diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 2ec13d7d2f536..5905032b65951 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -924,7 +924,7 @@ struct server_context {
         slot.params.sampling.mirostat_eta = json_value(data, "mirostat_eta", defaults.sampling.mirostat_eta);
         slot.params.sampling.penalize_nl  = json_value(data, "penalize_nl",  defaults.sampling.penalize_nl);
         slot.params.sampling.seed         = json_value(data, "seed",         defaults.sampling.seed);
-        slot.params.sampling.n_probs      = json_value(data, "n_probs",      defaults.sampling.n_probs);
+        slot.params.sampling.n_probs      = json_value(data, "n_probs",      json_value(data, "logprobs", defaults.sampling.n_probs));
         slot.params.sampling.min_keep     = json_value(data, "min_keep",     defaults.sampling.min_keep);
 
         slot.params.speculative.n_min = json_value(data, "speculative.n_min", defaults.speculative.n_min);
@@ -1340,7 +1340,8 @@ struct server_context {
             }
             slot.n_sent_token_probs = probs_stop_pos;
 
-            res.data["completion_probabilities"] = probs_vector_to_json(ctx, probs_output);
+            // TODO: bool to determine if we are using the new format (/chat/completions) or the legacy format (/completions) for logprobs
+            res.data["completion_probabilities"] = probs_vector_to_json(ctx, probs_output, true);
         }
 
         if (slot.oaicompat) {
@@ -1379,7 +1380,7 @@ struct server_context {
             {"timings",  slot.get_formated_timings()},
             {"index",    slot.index},
         };
-        
+
         if (slot.params.sampling.n_probs > 0) {
             std::vector<completion_token_output> probs;
             if (!slot.params.stream && slot.stopped_word) {
@@ -1395,7 +1396,8 @@ struct server_context {
                         slot.generated_token_probs.end());
             }
 
-            res.data["completion_probabilities"] = probs_vector_to_json(ctx, probs);
+            // TODO: bool to determine if we are using the new format (/chat/completions) or the legacy format (/completions) for logprobs
+            res.data["completion_probabilities"] = probs_vector_to_json(ctx, probs, true);
         }
 
         if (slot.oaicompat) {
@@ -2901,31 +2903,63 @@ int main(int argc, char ** argv) {
         res_ok(res, {{ "success", true }});
     };
 
-    const auto handle_completions_generic = [&ctx_server, &res_error, &res_ok](server_task_inf_type inf_type, json & data, httplib::Response & res) {
+    const auto handle_completions_generic = [&ctx_server, &params, &res_error, &res_ok, verbose](server_task_inf_type inf_type, json & data, httplib::Response & res, bool is_chat = false) {
         if (ctx_server.params_base.embedding) {
             res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
             return;
         }
 
+        // Parse data for /chat/completions format if needed
+        if (is_chat) {
+            data = oaicompat_completion_params_parse(ctx_server.model, data, params.chat_template);
+        }
+
         std::vector<server_task> tasks = ctx_server.create_tasks_inference(data, inf_type);
         ctx_server.queue_results.add_waiting_tasks(tasks);
         ctx_server.queue_tasks.post(tasks);
 
         bool stream = json_value(data, "stream", false);
+        bool oai_compat = json_value(data, "oai_compat", true);
         const auto task_ids = server_task::get_list_id(tasks);
+        const auto completion_id = gen_chatcmplid();
 
         if (!stream) {
             ctx_server.receive_cmpl_results(task_ids, [&](std::vector<server_task_result> & results) {
-                if (results.size() == 1) {
+                if (inf_type == SERVER_TASK_INF_TYPE_COMPLETION && (oai_compat || is_chat)) {
+                    if (is_chat) {
+                        // multitask is never supported in chat completion, there is only one result
+                        json result_oai = format_final_response_oaicompat(data, results[0].data, completion_id,
+                            /*.streaming =*/ false, verbose, /*.legacy_format =*/ !is_chat);
+                        res_ok(res, result_oai);
+                    } else {
+                        if (results.size() == 1) {
+                            // single result
+                            json result_oai = format_final_response_oaicompat(data, results[0].data, completion_id,
+                                /*.streaming =*/ false, verbose, /*.legacy_format =*/ true);
+                            res_ok(res, result_oai);
+                        } else {
+                            // multiple results (multitask)
+                            json arr = json::array();
+                            for (const auto & result : results) {
+                                arr.push_back(format_final_response_oaicompat(data, result.data, completion_id,
+                                    /*.streaming =*/ false, verbose, /*.legacy_format =*/ true));
+                            }
+                            res_ok(res, arr);
+                        }
+                    }
+                }
+                else{
+                    if (results.size() == 1) {
                     // single result
-                    res_ok(res, results[0].data);
-                } else {
-                    // multiple results (multitask)
-                    json arr = json::array();
-                    for (const auto & res : results) {
-                        arr.push_back(res.data);
+                        res_ok(res, results[0].data);
+                    } else {
+                        // multiple results (multitask)
+                        json arr = json::array();
+                        for (const auto & res : results) {
+                            arr.push_back(res.data);
+                        }
+                        res_ok(res, arr);
                     }
-                    res_ok(res, arr);
                 }
             }, [&](const json & error_data) {
                 res_error(res, error_data);
@@ -2933,14 +2967,35 @@ int main(int argc, char ** argv) {
 
             ctx_server.queue_results.remove_waiting_task_ids(task_ids);
         } else {
-            const auto chunked_content_provider = [task_ids, &ctx_server](size_t, httplib::DataSink & sink) {
+            const auto chunked_content_provider = [task_ids, &ctx_server, completion_id, is_chat, inf_type, oai_compat](size_t, httplib::DataSink & sink) {
                 ctx_server.receive_cmpl_results_stream(task_ids, [&](const server_task_result & result) -> bool {
-                    return server_sent_event(sink, "data", result.data);
+                    if (inf_type == SERVER_TASK_INF_TYPE_COMPLETION && (oai_compat || is_chat)) {
+
+                        std::vector<json> result_array = format_partial_response_oaicompat(result.data, completion_id, !is_chat);
+                        for (auto & event_data : result_array) {
+                            if (event_data.empty()) {
+                                continue; // skip the stop token
+                            }
+                            if (!server_sent_event(sink, "data", event_data)) {
+                                return false; // connection is closed
+                            }
+                        }
+                        return true; // ok
+
+                    }
+                    return server_sent_event(sink, "data", result.data);
+
                 }, [&](const json & error_data) {
                     server_sent_event(sink, "error", error_data);
+
                 });
+
+                if (is_chat) {
+                    static const std::string ev_done = "data: [DONE]\n\n";
+                    sink.write(ev_done.data(), ev_done.size());
+                }
                 sink.done();
-                return false;
+                return true;
             };
 
             auto on_complete = [task_ids, &ctx_server] (bool) {
@@ -2953,7 +3008,12 @@ int main(int argc, char ** argv) {
     const auto handle_completions = [&handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
         json data = json::parse(req.body);
-        return handle_completions_generic(SERVER_TASK_INF_TYPE_COMPLETION, data, res);
+        return handle_completions_generic(SERVER_TASK_INF_TYPE_COMPLETION, data, res, false);
+    };
+
+    const auto handle_chat_completions = [&handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
+        json data = json::parse(req.body);
+        return handle_completions_generic(SERVER_TASK_INF_TYPE_COMPLETION, data, res, true);
     };
 
     const auto handle_infill = [&ctx_server, &res_error, &handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
@@ -3006,63 +3066,6 @@ int main(int argc, char ** argv) {
         return handle_completions_generic(SERVER_TASK_INF_TYPE_INFILL, data, res);
     };
 
-    // TODO: maybe merge this function with "handle_completions_generic"
-    const auto handle_chat_completions = [&ctx_server, &params, &res_error, &res_ok, verbose](const httplib::Request & req, httplib::Response & res) {
-        if (ctx_server.params_base.embedding) {
-            res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
-            return;
-        }
-
-        json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template);
-
-        std::vector<server_task> tasks = ctx_server.create_tasks_inference(data, SERVER_TASK_INF_TYPE_COMPLETION);
-        ctx_server.queue_results.add_waiting_tasks(tasks);
-        ctx_server.queue_tasks.post(tasks);
-
-        bool stream = json_value(data, "stream", false);
-        const auto task_ids = server_task::get_list_id(tasks);
-        const auto completion_id = gen_chatcmplid();
-
-        if (!stream) {
-            ctx_server.receive_cmpl_results(task_ids, [&](const std::vector<server_task_result> & results) {
-                // multitask is never support in chat completion, there is only one result
-                json result_oai = format_final_response_oaicompat(data, results[0].data, completion_id, /*.streaming =*/ false, verbose);
-                res_ok(res, result_oai);
-            }, [&](const json & error_data) {
-                res_error(res, error_data);
-            });
-
-            ctx_server.queue_results.remove_waiting_task_ids(task_ids);
-        } else {
-            const auto chunked_content_provider = [task_ids, &ctx_server, completion_id](size_t, httplib::DataSink & sink) {
-                ctx_server.receive_cmpl_results_stream(task_ids, [&](const server_task_result & result) -> bool {
-                    std::vector<json> result_array = format_partial_response_oaicompat(result.data, completion_id);
-                    for (auto & event_data : result_array) {
-                        if (event_data.empty()) {
-                            continue; // skip the stop token
-                        }
-                        if (!server_sent_event(sink, "data", event_data)) {
-                            return false; // connection is closed
-                        }
-                    }
-                    return true; // ok
-                }, [&](const json & error_data) {
-                    server_sent_event(sink, "error", error_data);
-                });
-                static const std::string ev_done = "data: [DONE]\n\n";
-                sink.write(ev_done.data(), ev_done.size());
-                sink.done();
-                return true;
-            };
-
-            auto on_complete = [task_ids, &ctx_server] (bool) {
-                ctx_server.queue_results.remove_waiting_task_ids(task_ids);
-            };
-
-            res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
-        }
-    };
-
     const auto handle_models = [&params, &ctx_server](const httplib::Request &, httplib::Response & res) {
         json models = {
             {"object", "list"},
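Reviewer note: with the server changes above, POST /completion and /completions now default to the OpenAI-style text_completion shape, and a request field "oai_compat": false keeps the legacy llama.cpp body (this is what the test updates below rely on). A minimal sketch of both call styles using the make_request helper from the test suite; prompt text and token counts here are illustrative only:

    # Default path: OpenAI-compatible response (object == "text_completion")
    res = server.make_request("POST", "/completion", data={
        "prompt": "I believe the meaning of life is",
        "max_tokens": 8,
    })
    assert res.body["object"] == "text_completion"
    text = res.body["choices"][0]["text"]

    # Opt out and keep the legacy response body
    res = server.make_request("POST", "/completion", data={
        "prompt": "I believe the meaning of life is",
        "n_predict": 8,
        "oai_compat": False,
    })
    text = res.body["content"]
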
diff --git a/examples/server/tests/unit/test_basic.py b/examples/server/tests/unit/test_basic.py
index d82d54a5a6f47..976bbaa8e2635 100644
--- a/examples/server/tests/unit/test_basic.py
+++ b/examples/server/tests/unit/test_basic.py
@@ -40,9 +40,19 @@ def test_load_split_model():
     server.model_alias = "tinyllama-split"
     server.start()
     res = server.make_request("POST", "/completion", data={
-        "n_predict": 16,
+        "max_tokens": 16,
         "prompt": "Hello",
         "temperature": 0.0,
     })
     assert res.status_code == 200
-    assert match_regex("(little|girl)+", res.body["content"])
+    # Verify response structure
+    assert "id" in res.body
+    assert "object" in res.body
+    assert "created" in res.body
+    assert "model" in res.body
+    assert "choices" in res.body
+    assert isinstance(res.body["choices"], list)
+    assert len(res.body["choices"]) > 0
+    assert "text" in res.body["choices"][0]
+    # Verify the actual content
+    assert match_regex("(little|girl)+", res.body["choices"][0]["text"])
diff --git a/examples/server/tests/unit/test_completion.py b/examples/server/tests/unit/test_completion.py
index 2fa30dd033431..b0278a08181eb 100644
--- a/examples/server/tests/unit/test_completion.py
+++ b/examples/server/tests/unit/test_completion.py
@@ -18,13 +18,13 @@ def test_completion(prompt: str, n_predict: int, re_content: str, n_prompt: int,
     global server
     server.start()
     res = server.make_request("POST", "/completion", data={
-        "n_predict": n_predict,
+        "max_tokens": n_predict,
         "prompt": prompt,
+        "oai_compat": False,
     })
     assert res.status_code == 200
     assert res.body["timings"]["prompt_n"] == n_prompt
     assert res.body["timings"]["predicted_n"] == n_predicted
-    assert res.body["truncated"] == truncated
     assert match_regex(re_content, res.body["content"])
 
 
@@ -36,16 +36,16 @@ def test_completion_stream(prompt: str, n_predict: int, re_content: str, n_promp
     global server
     server.start()
     res = server.make_stream_request("POST", "/completion", data={
-        "n_predict": n_predict,
+        "max_tokens": n_predict,
         "prompt": prompt,
         "stream": True,
+        "oai_compat": False,
     })
     content = ""
     for data in res:
         if data["stop"]:
             assert data["timings"]["prompt_n"] == n_prompt
             assert data["timings"]["predicted_n"] == n_predicted
-            assert data["truncated"] == truncated
             assert match_regex(re_content, content)
         else:
             content += data["content"]
@@ -63,6 +63,7 @@ def test_consistent_result_same_seed(n_slots: int):
             "seed": 42,
             "temperature": 1.0,
             "cache_prompt": False,  # TODO: remove this once test_cache_vs_nocache_prompt is fixed
+            "oai_compat": False,
         })
         if last_res is not None:
             assert res.body["content"] == last_res.body["content"]
@@ -81,6 +82,7 @@ def test_different_result_different_seed(n_slots: int):
             "seed": seed,
             "temperature": 1.0,
             "cache_prompt": False,  # TODO: remove this once test_cache_vs_nocache_prompt is fixed
+            "oai_compat": False,
         })
         if last_res is not None:
             assert res.body["content"] != last_res.body["content"]
@@ -100,6 +102,7 @@ def test_consistent_result_different_batch_size(n_batch: int, temperature: float
             "seed": 42,
             "temperature": temperature,
             "cache_prompt": False,  # TODO: remove this once test_cache_vs_nocache_prompt is fixed
+            "oai_compat": False,
         })
         if last_res is not None:
             assert res.body["content"] == last_res.body["content"]
@@ -115,12 +118,14 @@ def test_cache_vs_nocache_prompt():
         "seed": 42,
         "temperature": 1.0,
         "cache_prompt": True,
+        "oai_compat": False,
     })
     res_no_cache = server.make_request("POST", "/completion", data={
         "prompt": "I believe the meaning of life is",
         "seed": 42,
         "temperature": 1.0,
         "cache_prompt": False,
+        "oai_compat": False,
     })
     assert res_cache.body["content"] == res_no_cache.body["content"]
@@ -140,6 +145,7 @@ def test_completion_with_tokens_input():
     # single completion
     res = server.make_request("POST", "/completion", data={
         "prompt": tokens,
+        "oai_compat": False,
     })
     assert res.status_code == 200
     assert type(res.body["content"]) == str
@@ -147,6 +153,7 @@ def test_completion_with_tokens_input():
     # batch completion
     res = server.make_request("POST", "/completion", data={
         "prompt": [tokens, tokens],
+        "oai_compat": False,
     })
     assert res.status_code == 200
     assert type(res.body) == list
@@ -156,6 +163,7 @@ def test_completion_with_tokens_input():
     # mixed string and tokens
     res = server.make_request("POST", "/completion", data={
         "prompt": [tokens, prompt_str],
+        "oai_compat": False,
     })
     assert res.status_code == 200
     assert type(res.body) == list
@@ -165,6 +173,7 @@ def test_completion_with_tokens_input():
     # mixed string and tokens in one sequence
     res = server.make_request("POST", "/completion", data={
         "prompt": [1, 2, 3, 4, 5, 6, prompt_str, 7, 8, 9, 10, prompt_str],
+        "oai_compat": False,
     })
     assert res.status_code == 200
     assert type(res.body["content"]) == str
@@ -208,6 +217,7 @@ def check_slots_status():
             "prompt": prompt,
             "seed": 42,
             "temperature": 1.0,
+            "oai_compat": False,
         })))
     tasks.append((check_slots_status, ()))
     results = parallel_function_calls(tasks)
@@ -221,3 +231,122 @@ def check_slots_status():
     assert len(res.body["content"]) > 10  # FIXME: the result is not deterministic when using other slot than slot 0
     # assert match_regex(re_content, res.body["content"])
+
+# OpenAI legacy completion endpoint tests
+@pytest.mark.parametrize("prompt,n_predict,expected_text,n_prompt,n_predicted", [
+    ("I believe the meaning of life is", 8, "going to bed", 18, 8),
+    ("Write a joke about", 16, "Why did the AI", 14, 16),
+])
+def test_completion_openai(prompt: str, n_predict: int, expected_text: str, n_prompt: int, n_predicted: int):
+    global server
+    server.start()
+
+    # Test non-streaming response
+    res = server.make_request("POST", "/completions", data={
+        "model": "local-model",
+        "prompt": prompt,
+        "max_tokens": n_predict,
+        "logprobs": 3,
+        "echo": True
+    })
+
+    assert res.status_code == 200
+    assert res.body["object"] == "text_completion"
+    assert isinstance(res.body["id"], str)
+    assert isinstance(res.body["created"], int)
+    assert res.body["model"] == "local-model"
+
+    # Check choices array
+    assert len(res.body["choices"]) == 1
+    choice = res.body["choices"][0]
+    assert choice["index"] == 0
+    assert isinstance(choice["text"], str)
+    assert choice["finish_reason"] in ["stop", "length"]
+
+    # Check logprobs
+    assert choice["logprobs"] is not None
+    assert "tokens" in choice["logprobs"]
+    assert "token_logprobs" in choice["logprobs"]
+    assert "top_logprobs" in choice["logprobs"]
+    assert len(choice["logprobs"]["top_logprobs"]) == len(choice["logprobs"]["tokens"])
+
+    # Check usage statistics
+    assert "usage" in res.body
+    assert res.body["usage"]["prompt_tokens"] == n_prompt
+    assert res.body["usage"]["completion_tokens"] == n_predicted
+    assert res.body["usage"]["total_tokens"] == n_prompt + n_predicted
+
+@pytest.mark.parametrize("prompt,n_predict,expected_text,n_prompt,n_predicted", [
+    ("I believe the meaning of life is", 8, "going to bed", 18, 8),
+    ("Write a joke about", 16, "Why did the AI", 14, 16),
+])
+def test_completion_openai_stream(prompt: str, n_predict: int, expected_text: str, n_prompt: int, n_predicted: int):
+    global server
+    server.start()
+
+    res = server.make_stream_request("POST", "/v1/completions", data={
+        "prompt": prompt,
+        "max_tokens": n_predict,
+        "stream": True,
+    })
+
+    collected_text = ""
+    is_first_chunk = True
+    for data in res:
+        assert "id" in data
+        assert data["object"] == "text_completion"
+        assert isinstance(data["created"], int)
+
+        assert len(data["choices"]) == 1
+        choice = data["choices"][0]
+        assert choice["index"] == 0
+        assert isinstance(choice["text"], str)
+        collected_text += choice["text"]
+
+        if is_first_chunk:
+            # First chunk should have model info
+            is_first_chunk = False
+
+        if choice["finish_reason"] is not None:
+            # This is the last chunk
+            assert choice["finish_reason"] in ["stop", "length"]
+            assert "usage" in data
+            assert data["usage"]["prompt_tokens"] == n_prompt
+            assert data["usage"]["completion_tokens"] == n_predicted
+            assert data["usage"]["total_tokens"] == n_prompt + n_predicted
+
+@pytest.mark.parametrize("prompt,n_predict,expected_text,n_prompt,n_predicted", [
+    ("I believe the meaning of life is", 8, "going to bed", 18, 8),
+    ("Write a joke about", 16, "Why did the AI", 14, 16),
+])
+def test_completion_openai_no_logprobs(prompt: str, n_predict: int, expected_text: str, n_prompt: int, n_predicted: int):
+    global server
+    server.start()
+
+    # Test non-streaming response
+    res = server.make_request("POST", "/completions", data={
+        "prompt": prompt,
+        "max_tokens": n_predict,
+        "echo": True
+    })
+
+    assert res.status_code == 200
+    assert res.body["object"] == "text_completion"
+    assert isinstance(res.body["id"], str)
+    assert isinstance(res.body["created"], int)
+
+    # Check choices array
+    assert len(res.body["choices"]) == 1
+    choice = res.body["choices"][0]
+    assert choice["index"] == 0
+    assert isinstance(choice["text"], str)
+    assert choice["finish_reason"] in ["stop", "length"]
+
+    # Verify logprobs is None when not requested
+    assert choice["logprobs"] is None
+
+    # Check usage statistics
+    assert "usage" in res.body
+    assert res.body["usage"]["prompt_tokens"] == n_prompt
+    assert res.body["usage"]["completion_tokens"] == n_predicted
+    assert res.body["usage"]["total_tokens"] == n_prompt + n_predicted
\ No newline at end of file
diff --git a/examples/server/tests/unit/test_ctx_shift.py b/examples/server/tests/unit/test_ctx_shift.py
index be93a6d31f410..b370f73cc3467 100644
--- a/examples/server/tests/unit/test_ctx_shift.py
+++ b/examples/server/tests/unit/test_ctx_shift.py
@@ -29,6 +29,7 @@ def test_ctx_shift_enabled():
     res = server.make_request("POST", "/completion", data={
         "n_predict": 64,
         "prompt": LONG_TEXT,
+        "oai_compat": False,
     })
     assert res.status_code == 200
     assert res.body["timings"]["prompt_n"] == 109
@@ -48,6 +49,7 @@ def test_ctx_shift_disabled_short_prompt(n_predict: int, n_token_output: int, tr
     res = server.make_request("POST", "/completion", data={
         "n_predict": n_predict,
         "prompt": "Hi how are you",
+        "oai_compat": False,
     })
     assert res.status_code == 200
     assert res.body["timings"]["predicted_n"] == n_token_output
@@ -61,6 +63,7 @@ def test_ctx_shift_disabled_long_prompt():
     res = server.make_request("POST", "/completion", data={
         "n_predict": 64,
         "prompt": LONG_TEXT,
+        "oai_compat": False,
     })
     assert res.status_code != 200
     assert "error" in res.body
diff --git a/examples/server/tests/unit/test_lora.py b/examples/server/tests/unit/test_lora.py
index 7496154493917..e451b333e9bcf 100644
--- a/examples/server/tests/unit/test_lora.py
+++ b/examples/server/tests/unit/test_lora.py
@@ -36,6 +36,7 @@ def test_lora(scale: float, re_content: str):
     assert res_lora_control.status_code == 200
     res = server.make_request("POST", "/completion", data={
         "prompt": "Look in thy glass",
+        "oai_compat": False,
     })
     assert res.status_code == 200
     assert match_regex(re_content, res.body["content"])
diff --git a/examples/server/tests/unit/test_security.py b/examples/server/tests/unit/test_security.py
index 620b25376bd81..8620e2c2b9c76 100644
--- a/examples/server/tests/unit/test_security.py
+++ b/examples/server/tests/unit/test_security.py
@@ -41,6 +41,7 @@ def test_correct_api_key():
     server.start()
     res = server.make_request("POST", "/completions", data={
         "prompt": "I believe the meaning of life is",
+        "oai_compat": False,
     }, headers={
         "Authorization": f"Bearer {TEST_API_KEY}",
     })
diff --git a/examples/server/tests/unit/test_slot_save.py b/examples/server/tests/unit/test_slot_save.py
index 38704f5ece35a..67ed21a1c4636 100644
--- a/examples/server/tests/unit/test_slot_save.py
+++ b/examples/server/tests/unit/test_slot_save.py
@@ -20,6 +20,7 @@ def test_slot_save_restore():
         "prompt": "What is the capital of France?",
         "id_slot": 1,
         "cache_prompt": True,
+        "oai_compat": False,
     })
     assert res.status_code == 200
     assert match_regex("(Whiskers|Flana)+", res.body["content"])
@@ -37,6 +38,7 @@ def test_slot_save_restore():
         "prompt": "What is the capital of Germany?",
         "id_slot": 1,
         "cache_prompt": True,
+        "oai_compat": False,
     })
     assert res.status_code == 200
     assert match_regex("(Jack|said)+", res.body["content"])
@@ -54,6 +56,7 @@ def test_slot_save_restore():
         "prompt": "What is the capital of Germany?",
         "id_slot": 0,
         "cache_prompt": True,
+        "oai_compat": False,
     })
     assert res.status_code == 200
     assert match_regex("(Jack|said)+", res.body["content"])
@@ -64,6 +67,7 @@ def test_slot_save_restore():
         "prompt": "What is the capital of Germany?",
         "id_slot": 1,
         "cache_prompt": True,
+        "oai_compat": False,
     })
     assert res.status_code == 200
     assert match_regex("(Jack|said)+", res.body["content"])
@@ -78,6 +82,7 @@ def test_slot_erase():
         "prompt": "What is the capital of France?",
         "id_slot": 1,
         "cache_prompt": True,
+        "oai_compat": False,
     })
     assert res.status_code == 200
     assert match_regex("(Whiskers|Flana)+", res.body["content"])
@@ -94,5 +99,5 @@ def test_slot_erase():
         "cache_prompt": True,
     })
     assert res.status_code == 200
-    assert match_regex("(Whiskers|Flana)+", res.body["content"])
+    assert match_regex("(Whiskers|Flana)+", res.body["choices"][0]["text"])
     assert res.body["timings"]["prompt_n"] == 21  # all tokens are processed
diff --git a/examples/server/tests/unit/test_speculative.py b/examples/server/tests/unit/test_speculative.py
index 982d6abb45f5f..ad269e3d96b01 100644
--- a/examples/server/tests/unit/test_speculative.py
+++ b/examples/server/tests/unit/test_speculative.py
@@ -37,6 +37,7 @@ def test_with_and_without_draft():
         "prompt": "I believe the meaning of life is",
         "temperature": 0.0,
         "top_k": 1,
+        "oai_compat": False,
     })
     assert res.status_code == 200
     content_no_draft = res.body["content"]
@@ -49,6 +50,7 @@ def test_with_and_without_draft():
         "prompt": "I believe the meaning of life is",
        "temperature": 0.0,
         "top_k": 1,
+        "oai_compat": False,
     })
     assert res.status_code == 200
     content_draft = res.body["content"]
@@ -75,6 +77,7 @@ def test_different_draft_min_draft_max():
         "prompt": "I believe the meaning of life is",
         "temperature": 0.0,
         "top_k": 1,
+        "oai_compat": False,
     })
     assert res.status_code == 200
     if last_content is not None:
@@ -96,6 +99,7 @@ def test_multi_requests_parallel(n_slots: int, n_requests: int):
             "prompt": "I believe the meaning of life is",
             "temperature": 0.0,
             "top_k": 1,
+            "oai_compat": False,
         })))
     results = parallel_function_calls(tasks)
     for res in results:
diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp
index e4451532c9d0c..ea1426f61b906 100644
--- a/examples/server/utils.hpp
+++ b/examples/server/utils.hpp
@@ -498,28 +498,105 @@ struct completion_token_output {
 };
 
 // convert a vector of completion_token_output to json
-static json probs_vector_to_json(const llama_context * ctx, const std::vector<completion_token_output> & probs) {
-    json out = json::array();
-
-    for (const auto & prob : probs) {
-        json probs_for_token = json::array();
-
-        for (const auto & p : prob.probs) {
-            const std::string tok_str = tokens_to_output_formatted_string(ctx, p.tok);
-            probs_for_token.push_back(json {
-                {"tok_str", tok_str},
-                {"prob",    p.prob},
-            });
+static json probs_vector_to_json(llama_context * ctx, const std::vector<completion_token_output> & probs, bool legacy_format = true) {
+    if (legacy_format) {
+        // Legacy format (text_completion endpoint)
+        json logprobs;
+        std::vector<std::string> tokens;
+        std::vector<json> token_logprobs;  // Changed to json to allow null values
+        std::vector<json> top_logprobs;    // Changed to allow null values
+        std::vector<int> text_offset;
+
+        int current_offset = 0;
+
+        for (const auto & prob : probs) {
+            std::string token_str = tokens_to_output_formatted_string(ctx, prob.tok);
+            tokens.push_back(token_str);
+            text_offset.push_back(current_offset);
+
+            // Handle token logprobs
+            if (!prob.probs.empty() && prob.probs[0].prob > 0) {
+                token_logprobs.push_back(std::log(prob.probs[0].prob));
+            } else {
+                token_logprobs.push_back(nullptr);
+            }
+
+            // Handle top logprobs
+            json token_top_logprobs = json::object();
+            for (const auto & p : prob.probs) {
+                if (p.prob > 0) {
+                    token_top_logprobs[tokens_to_output_formatted_string(ctx, p.tok)] = std::log(p.prob);
+                }
+            }
+            top_logprobs.push_back(token_top_logprobs.empty() ? nullptr : token_top_logprobs);
+
+            current_offset += token_str.length();
         }
+
+        logprobs = {
+            {"tokens", tokens},
+            {"token_logprobs", token_logprobs},
+            {"top_logprobs", top_logprobs},
+            {"text_offset", text_offset}
+        };
+
+        return logprobs;
+    } else {
+        // New format (GPT-4 style)
+        json logprobs;
+        std::vector<json> content;
+
+        for (const auto & prob : probs) {
+            std::string token_str = tokens_to_output_formatted_string(ctx, prob.tok);
+
+            // Create top_logprobs array for this token
+            json token_top_logprobs = json::array();
+            for (const auto & p : prob.probs) {
+                if (p.prob > 0) {
+                    // Get UTF-8 bytes for the token
+                    std::vector<int> bytes;
+                    std::string tok_str = tokens_to_output_formatted_string(ctx, p.tok);
+                    for (unsigned char c : tok_str) {
+                        bytes.push_back(static_cast<int>(c));
+                    }
+
+                    json entry = {
+                        {"token", tok_str},
+                        {"logprob", std::log(p.prob)},
+                        {"bytes", bytes.empty() ? json(nullptr) : json(bytes)}
+                    };
+                    token_top_logprobs.push_back(entry);
+                }
+            }
-        const std::string tok_str = tokens_to_output_formatted_string(ctx, prob.tok);
-        out.push_back(json {
-            {"content", tok_str},
-            {"probs",   probs_for_token},
-        });
-    }
+            // Get main token logprob
+            float main_logprob = (!prob.probs.empty() && prob.probs[0].prob > 0)
+                ? std::log(prob.probs[0].prob)
+                : -9999.0f;
-    return out;
+            // Get UTF-8 bytes for the main token
+            std::vector<int> main_bytes;
+            for (unsigned char c : token_str) {
+                main_bytes.push_back(static_cast<int>(c));
+            }
+
+            // Add token info to content array
+            json token_info = {
+                {"token", token_str},
+                {"logprob", main_logprob},
+                {"bytes", main_bytes.empty() ? json(nullptr) : json(main_bytes)},
+                {"top_logprobs", token_top_logprobs}
+            };
+            content.push_back(token_info);
+        }
+
+        logprobs = {
+            {"content", content},
+            {"refusal", nullptr}  // Add refusal field as null since we don't implement content filtering
+        };
+
+        return logprobs;
+    }
 }
 
 static bool server_sent_event(httplib::DataSink & sink, const char * event, const json & data) {
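Reviewer note: probs_vector_to_json now returns one of two shapes depending on legacy_format. A sketch of each, written as Python literals for readability; the tokens, probabilities and byte values below are invented for illustration:

    # legacy_format == true  ->  /completions-style "logprobs" object
    legacy_logprobs = {
        "tokens":         [" Hello", ","],
        "token_logprobs": [-0.12, -1.05],   # null when no probability is available
        "top_logprobs":   [{" Hello": -0.12, " Hi": -2.30}, {",": -1.05, ".": -1.90}],
        "text_offset":    [0, 6],
    }

    # legacy_format == false  ->  /chat/completions-style "logprobs" object
    chat_logprobs = {
        "content": [
            {
                "token": " Hello",
                "logprob": -0.12,
                "bytes": [32, 72, 101, 108, 108, 111],
                "top_logprobs": [
                    {"token": " Hello", "logprob": -0.12, "bytes": [32, 72, 101, 108, 108, 111]},
                ],
            },
        ],
        "refusal": None,
    }
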
@@ -540,7 +617,8 @@ static bool server_sent_event(httplib::DataSink & sink, const char * event, cons
 static json oaicompat_completion_params_parse(
     const struct llama_model * model,
     const json & body, /* openai api json semantics */
-    const std::string & chat_template) {
+    const std::string & chat_template
+    ) {
     json llama_params;
 
     llama_params["__oaicompat"] = true;
@@ -604,43 +682,71 @@ static json oaicompat_completion_params_parse(
     return llama_params;
 }
 
-static json format_final_response_oaicompat(const json & request, const json & result, const std::string & completion_id, bool streaming = false, bool verbose = false) {
+static json format_final_response_oaicompat(const json & request, const json & result, const std::string & completion_id, bool streaming = false, bool verbose = false, bool legacy_format = false) {
     bool stopped_word        = result.count("stopped_word") != 0;
     bool stopped_eos         = json_value(result, "stopped_eos", false);
     int num_tokens_predicted = json_value(result, "tokens_predicted", 0);
     int num_prompt_tokens    = json_value(result, "tokens_evaluated", 0);
     std::string content      = json_value(result, "content", std::string(""));
+    bool truncated           = json_value(result, "truncated", false);
 
     std::string finish_reason = "length";
     if (stopped_word || stopped_eos) {
         finish_reason = "stop";
     }
 
-    json choices =
-        streaming ? json::array({json{{"finish_reason", finish_reason},
-                                        {"index", 0},
-                                        {"delta", json::object()}}})
-                  : json::array({json{{"finish_reason", finish_reason},
-                                        {"index", 0},
-                                        {"message", json{{"content", content},
-                                                         {"role", "assistant"}}}}});
+    json choices;
+    // Use the pre-formatted logprobs directly
+    json logprobs = result.contains("completion_probabilities") ?
+        result["completion_probabilities"] : nullptr;
+    if (legacy_format) {
+
+        choices = json::array({json{
+            {"finish_reason", finish_reason},
+            {"index", 0},
+            {"logprobs", logprobs},
+            {"text", content}
+        }});
+    } else {
+        // Format for chat completions endpoint
+        choices = streaming ?
+            json::array({json{
+                {"finish_reason", finish_reason},
+                {"index", 0},
+                {"delta", json::object()}
+            }}) :
+            json::array({json{
+                {"finish_reason", finish_reason},
+                {"index", 0},
+                {"message", json{
+                    {"content", content},
+                    {"role", "assistant"}
+                }}
+            }});
+    }
 
     std::time_t t = std::time(0);
 
     json res = json {
         {"choices", choices},
         {"created", t},
-        {"model",
-            json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
-        {"object", streaming ? "chat.completion.chunk" : "chat.completion"},
+        {"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
"chat.completion.chunk" : "chat.completion")}, {"usage", json { {"completion_tokens", num_tokens_predicted}, - {"prompt_tokens", num_prompt_tokens}, - {"total_tokens", num_tokens_predicted + num_prompt_tokens} + {"prompt_tokens", num_prompt_tokens}, + {"total_tokens", num_tokens_predicted + num_prompt_tokens} }}, - {"id", completion_id} + {"id", completion_id}, + {"truncated", truncated} }; + // Add system_fingerprint if provided + if (result.contains("system_fingerprint")) { + res["system_fingerprint"] = result["system_fingerprint"]; + } + // extra fields for debugging purposes if (verbose) { res["__verbose"] = result; @@ -658,105 +764,127 @@ static json format_final_response_oaicompat(const json & request, const json & r } // return value is vector as there is one case where we might need to generate two responses -static std::vector format_partial_response_oaicompat(const json & result, const std::string & completion_id) { - if (!result.contains("model") || !result.contains("oaicompat_token_ctr")) { - return std::vector({result}); - } +static std::vector format_partial_response_oaicompat(const json & result, const std::string & completion_id, bool legacy_format = false) { + // Early return if required fields are missing + // if (!result.contains("model") || !result.contains("oaicompat_token_ctr")) { + // return std::vector({result}); + // } bool first = json_value(result, "oaicompat_token_ctr", 0) == 0; std::string modelname = json_value(result, "model", std::string(DEFAULT_OAICOMPAT_MODEL)); + std::string content = json_value(result, "content", std::string("")); + std::time_t t = std::time(0); - bool stopped_word = json_value(result, "stopped_word", false); - bool stopped_eos = json_value(result, "stopped_eos", false); - bool stopped_limit = json_value(result, "stopped_limit", false); - std::string content = json_value(result, "content", std::string("")); - + // Determine finish reason std::string finish_reason; - if (stopped_word || stopped_eos) { + if (json_value(result, "stopped_word", false) || json_value(result, "stopped_eos", false)) { finish_reason = "stop"; } - if (stopped_limit) { + if (json_value(result, "stopped_limit", false)) { finish_reason = "length"; } - std::time_t t = std::time(0); - json choices; - if (!finish_reason.empty()) { - choices = json::array({json{{"finish_reason", finish_reason}, - {"index", 0}, - {"delta", json::object()}}}); - } else { - if (first) { - if (content.empty()) { - choices = json::array({json{{"finish_reason", nullptr}, - {"index", 0}, - {"delta", json{{"role", "assistant"}}}}}); - } else { - // We have to send this as two updates to conform to openai behavior - json initial_ret = json{{"choices", json::array({json{ - {"finish_reason", nullptr}, - {"index", 0}, - {"delta", json{ - {"role", "assistant"} - }}}})}, - {"created", t}, - {"id", completion_id}, - {"model", modelname}, - {"object", "chat.completion.chunk"}}; - - json second_ret = json{ - {"choices", json::array({json{{"finish_reason", nullptr}, - {"index", 0}, - {"delta", json{ - {"content", content}}} - }})}, - {"created", t}, - {"id", completion_id}, - {"model", modelname}, - {"object", "chat.completion.chunk"}}; - - return std::vector({initial_ret, second_ret}); - } + // Final message with finish reason + if (legacy_format) { + choices = json::array({json{ + {"finish_reason", finish_reason}, + {"index", 0}, + {"logprobs", result.contains("completion_probabilities") ? 
@@ -658,105 +764,127 @@ static json format_final_response_oaicompat(const json & request, const json & r
 }
 
 // return value is vector as there is one case where we might need to generate two responses
-static std::vector<json> format_partial_response_oaicompat(const json & result, const std::string & completion_id) {
-    if (!result.contains("model") || !result.contains("oaicompat_token_ctr")) {
-        return std::vector<json>({result});
-    }
+static std::vector<json> format_partial_response_oaicompat(const json & result, const std::string & completion_id, bool legacy_format = false) {
+    // Early return if required fields are missing
+    // if (!result.contains("model") || !result.contains("oaicompat_token_ctr")) {
+    //     return std::vector<json>({result});
+    // }
 
     bool first = json_value(result, "oaicompat_token_ctr", 0) == 0;
     std::string modelname = json_value(result, "model", std::string(DEFAULT_OAICOMPAT_MODEL));
+    std::string content = json_value(result, "content", std::string(""));
+    std::time_t t = std::time(0);
 
-    bool stopped_word   = json_value(result, "stopped_word",  false);
-    bool stopped_eos    = json_value(result, "stopped_eos",   false);
-    bool stopped_limit  = json_value(result, "stopped_limit", false);
-    std::string content = json_value(result, "content", std::string(""));
-
+    // Determine finish reason
     std::string finish_reason;
-    if (stopped_word || stopped_eos) {
+    if (json_value(result, "stopped_word", false) || json_value(result, "stopped_eos", false)) {
         finish_reason = "stop";
     }
-    if (stopped_limit) {
+    if (json_value(result, "stopped_limit", false)) {
         finish_reason = "length";
     }
 
-    std::time_t t = std::time(0);
-
     json choices;
+
     if (!finish_reason.empty()) {
-        choices = json::array({json{{"finish_reason", finish_reason},
-                                    {"index", 0},
-                                    {"delta", json::object()}}});
-    } else {
-        if (first) {
-            if (content.empty()) {
-                choices = json::array({json{{"finish_reason", nullptr},
-                                            {"index", 0},
-                                            {"delta", json{{"role", "assistant"}}}}});
-            } else {
-                // We have to send this as two updates to conform to openai behavior
-                json initial_ret = json{{"choices", json::array({json{
-                                        {"finish_reason", nullptr},
-                                        {"index", 0},
-                                        {"delta", json{
-                                            {"role", "assistant"}
-                                        }}}})},
-                            {"created", t},
-                            {"id", completion_id},
-                            {"model", modelname},
-                            {"object", "chat.completion.chunk"}};
-
-                json second_ret = json{
-                            {"choices", json::array({json{{"finish_reason", nullptr},
-                                                            {"index", 0},
-                                                            {"delta", json{
-                                                            {"content", content}}}
-                                                            }})},
-                            {"created", t},
-                            {"id", completion_id},
-                            {"model", modelname},
-                            {"object", "chat.completion.chunk"}};
-
-                return std::vector<json>({initial_ret, second_ret});
-            }
+        // Final message with finish reason
+        if (legacy_format) {
+            choices = json::array({json{
+                {"finish_reason", finish_reason},
+                {"index", 0},
+                {"logprobs", result.contains("completion_probabilities") ?
+                    result["completion_probabilities"] : nullptr},
+                {"text", content}
+            }});
         } else {
-            // Some idiosyncrasy in task processing logic makes several trailing calls
-            // with empty content, we ignore these at the calee site.
-            if (content.empty()) {
-                return std::vector<json>({json::object()});
-            }
-
+            choices = json::array({json{
+                {"finish_reason", finish_reason},
+                {"index", 0},
+                {"delta", json::object()}
+            }});
+        }
+    } else {
+        // Content message
+        if (legacy_format) {
             choices = json::array({json{
                 {"finish_reason", nullptr},
                 {"index", 0},
-                {"delta",
-                json{
-                    {"content", content},
-                }},
+                {"logprobs", result.contains("completion_probabilities") ?
+                    result["completion_probabilities"] : nullptr},
+                {"text", content}
             }});
+        } else {
+            if (first) {
+                if (content.empty()) {
+                    choices = json::array({json{
+                        {"finish_reason", nullptr},
+                        {"index", 0},
+                        {"delta", json{{"role", "assistant"}}}
+                    }});
+                } else {
+                    if (content.empty()) {
+                        return std::vector<json>({json::object()});
+                    }
+
+                    // Split into two messages for first content in chat mode
+                    json initial_ret = json{
+                        {"choices", json::array({json{
+                            {"finish_reason", nullptr},
+                            {"index", 0},
+                            {"delta", json{{"role", "assistant"}}}
+                        }})},
+                        {"created", t},
+                        {"id", completion_id},
+                        {"model", modelname},
+                        {"object", "chat.completion.chunk"}
+                    };
+
+                    json second_ret = json{
+                        {"choices", json::array({json{
+                            {"finish_reason", nullptr},
+                            {"index", 0},
+                            {"delta", json{{"content", content}}}
+                        }})},
+                        {"created", t},
+                        {"id", completion_id},
+                        {"model", modelname},
+                        {"object", "chat.completion.chunk"}
+                    };
+
+                    return std::vector<json>({initial_ret, second_ret});
+                }
+            } else {
+                choices = json::array({json{
+                    {"finish_reason", nullptr},
+                    {"index", 0},
+                    {"delta", json{{"content", content}}}
+                }});
+            }
         }
     }
 
-    json ret = json {
+    // Construct the response
+    json ret = json{
         {"choices", choices},
         {"created", t},
-        {"id",      completion_id},
-        {"model",   modelname},
-        {"object",  "chat.completion.chunk"}
+        {"id", completion_id},
+        {"model", modelname},
+        {"object", legacy_format ? "text_completion" : "chat.completion.chunk"}
     };
 
+    // Add timings if present
     if (result.contains("timings")) {
-        ret.push_back({"timings", json_value(result, "timings", json::object())});
+        ret["timings"] = json_value(result, "timings", json::object());
     }
 
+    // Add usage statistics for final messages
     if (!finish_reason.empty()) {
         int num_tokens_predicted = json_value(result, "tokens_predicted", 0);
-        int num_prompt_tokens    = json_value(result, "tokens_evaluated", 0);
-        ret.push_back({"usage", json {
+        int num_prompt_tokens = json_value(result, "tokens_evaluated", 0);
+        ret["usage"] = json{
             {"completion_tokens", num_tokens_predicted},
-            {"prompt_tokens",     num_prompt_tokens},
-            {"total_tokens",      num_tokens_predicted + num_prompt_tokens}
-        }});
+            {"prompt_tokens", num_prompt_tokens},
+            {"total_tokens", num_tokens_predicted + num_prompt_tokens}
+        };
     }
 
     return std::vector<json>({ret});