diff --git a/lib/openai_ex/http_sse.ex b/lib/openai_ex/http_sse.ex index 69d1700..51859d9 100644 --- a/lib/openai_ex/http_sse.ex +++ b/lib/openai_ex/http_sse.ex @@ -8,44 +8,34 @@ defmodule OpenaiEx.HttpSse do # and # https://html.spec.whatwg.org/multipage/server-sent-events.html#parsing-an-event-stream - @doc false def post(openai = %OpenaiEx{}, url, json: json) do - request = OpenaiEx.Http.build_post(openai, url, json: json) - me = self() ref = make_ref() - - task = - Task.async(fn -> - on_chunk = create_chunk_handler(me, ref) - options = Http.request_options(openai) - request |> Finch.stream(Map.get(openai, :finch_name), nil, on_chunk, options) - send(me, {:done, ref}) - end) - + task = Task.async(fn -> post_sse(openai, url, json, me, ref) end) status = receive(do: ({:chunk, {:status, status}, ^ref} -> status)) headers = receive(do: ({:chunk, {:headers, headers}, ^ref} -> headers)) if status in 200..299 do - body_stream = - Stream.resource(fn -> {"", ref, task} end, &next_sse/1, fn {_, _, task} -> - Task.shutdown(task) - end) - + body_stream = Stream.resource(fn -> {"", ref} end, &next_sse/1, fn _ -> :ok end) %{status: status, headers: headers, body_stream: body_stream, task_pid: task.pid} else error_message = collect_error_message(ref, "") - Task.shutdown(task) %{status: status, headers: headers, error: Jason.decode!(error_message)} end end - @doc false def cancel_request(task_pid) when is_pid(task_pid) do send(task_pid, :cancel_request) end - @doc false + defp post_sse(openai = %OpenaiEx{}, url, json, me, ref) do + request = OpenaiEx.Http.build_post(openai, url, json: json) + on_chunk = create_chunk_handler(me, ref) + options = Http.request_options(openai) + request |> Finch.stream(Map.get(openai, :finch_name), nil, on_chunk, options) + send(me, {:done, ref}) + end + defp create_chunk_handler(me, ref) do fn chunk, _acc -> receive do @@ -58,33 +48,31 @@ defmodule OpenaiEx.HttpSse do end end - @doc false - defp next_sse({acc, ref, task}) do + defp next_sse({acc, ref}) do receive do {:chunk, {:data, evt_data}, ^ref} -> {events, next_acc} = extract_events(evt_data, acc) - {[events], {next_acc, ref, task}} + {[events], {next_acc, ref}} # some 3rd party providers seem to be ending the stream with eof, # rather than 2 line terminators. Hopefully those will be fixed and this # can be removed in the future {:done, ^ref} when acc == "data: [DONE]" -> - {:halt, {acc, ref, task}} + {:halt, {acc, ref}} {:done, ^ref} -> - if acc != "", do: Logger.error("residual!: #{acc}") - {:halt, {acc, ref, task}} + if acc != "", do: Logger.error("Residue!: #{acc}") + {:halt, {acc, ref}} {:canceled, ^ref} -> Logger.info("Request canceled by user") - {:halt, {acc, ref, task}} + {:halt, {acc, ref}} end end @double_eol ~r/(\r?\n|\r){2}/ @double_eol_eos ~r/(\r?\n|\r){2}$/ - @doc false defp extract_events(evt_data, acc) do all_data = acc <> evt_data @@ -97,40 +85,45 @@ defmodule OpenaiEx.HttpSse do end end - @doc false defp extract_lines(data) do lines = String.split(data, @double_eol) incomplete_line = !Regex.match?(@double_eol_eos, data) if incomplete_line, do: lines |> List.pop_at(-1), else: {"", lines} end - @doc false defp process_fields(lines) do lines |> Enum.map(&extract_field/1) - |> Enum.filter(fn - %{data: "[DONE]"} -> false - %{data: _} -> true - %{eventType: "done\ndata: [DONE]"} -> false - %{eventType: _} -> true - _ -> false - end) - |> Enum.map(fn + |> Enum.filter(&filter_field/1) + |> Enum.map(&decode_field/1) + end + + defp decode_field(field) do + case field do %{data: data} -> %{data: Jason.decode!(data)} %{eventType: value} -> [event_id, data] = String.split(value, "\ndata: ", parts: 2) %{event: event_id, data: Jason.decode!(data)} - end) + end + end + + defp filter_field(field) do + case field do + %{data: "[DONE]"} -> false + %{data: _} -> true + %{eventType: "done\ndata: [DONE]"} -> false + %{eventType: _} -> true + _ -> false + end end - @doc false defp extract_field(line) do - [field | rest] = String.split(line, ":", parts: 2) + [name | rest] = String.split(line, ":", parts: 2) value = Enum.join(rest, "") |> String.replace_prefix(" ", "") - case field do + case name do "data" -> %{data: value} "event" -> %{eventType: value} "id" -> %{lastEventId: value} @@ -140,7 +133,6 @@ defmodule OpenaiEx.HttpSse do end end - @doc false defp collect_error_message(ref, acc) do receive do {:chunk, {:data, chunk}, ^ref} ->