From dd64a14fdb015a3be8e502355bc7a0314de935e7 Mon Sep 17 00:00:00 2001 From: ruslandoga <67764432+ruslandoga@users.noreply.github.com> Date: Thu, 11 Jan 2024 13:25:11 +0900 Subject: [PATCH] refactor Ch.stream/4 --- CHANGELOG.md | 1 + lib/ch.ex | 4 +- lib/ch/connection.ex | 118 ++++++++++++++++++++++++++---------- lib/ch/query.ex | 5 +- test/ch/connection_test.exs | 48 +++------------ test/ch/stream_test.exs | 40 ++++++++++++ 6 files changed, 141 insertions(+), 75 deletions(-) create mode 100644 test/ch/stream_test.exs diff --git a/CHANGELOG.md b/CHANGELOG.md index b2c990d..deaa40c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,7 @@ - move rows payload (RowBinary, CSV, etc.) to SQL statement and remove pseudo-positional binds, making param names explicit https://github.com/plausible/ch/pull/143 - drop `:headers` from `%Ch.Result{}` but add `:data` https://github.com/plausible/ch/pull/144 - fix query string escaping for `\t`, `\\`, and `\n` https://github.com/plausible/ch/pull/147 +- make `Ch.stream/4` emit `%Ch.Result{data: iodata}` https://github.com/plausible/ch/pull/148 ## 0.2.2 (2023-12-23) diff --git a/lib/ch.ex b/lib/ch.ex index 0a8ac52..91b7336 100644 --- a/lib/ch.ex +++ b/lib/ch.ex @@ -99,7 +99,9 @@ defmodule Ch do DBConnection.execute!(conn, query, params, opts) end - @doc false + @doc """ + Returns a stream for a query on a connection. + """ @spec stream(DBConnection.t(), statement, params, [query_option]) :: DBConnection.Stream.t() def stream(conn, statement, params \\ [], opts \\ []) do query = Query.build(statement, opts) diff --git a/lib/ch/connection.ex b/lib/ch/connection.ex index ad8ea4b..18442be 100644 --- a/lib/ch/connection.ex +++ b/lib/ch/connection.ex @@ -85,41 +85,97 @@ defmodule Ch.Connection do @impl true def handle_declare(query, params, opts, conn) do - {query_params, extra_headers, body} = params + %Query{command: command, statement: statement} = query + {query_params, extra_headers} = params path = path(conn, query_params, opts) headers = headers(conn, extra_headers, opts) - types = Keyword.get(opts, :types) - with {:ok, conn, ref} <- send_request(conn, "POST", path, headers, body) do - {:ok, query, {types, ref}, conn} + with {:ok, conn, _ref} <- send_request(conn, "POST", path, headers, statement), + {:ok, conn} <- eat_ok_status_and_headers(conn, timeout(conn, opts)) do + {:ok, query, %Result{command: command}, conn} end end + @spec eat_ok_status_and_headers(conn, timeout) :: + {:ok, %{conn: conn, buffer: [Mint.Types.response()]}} + | {:error, Ch.Error.t(), conn} + | {:disconnect, Mint.Types.error(), conn} + defp eat_ok_status_and_headers(conn, timeout) do + case HTTP.recv(conn, 0, timeout) do + {:ok, conn, responses} -> + case eat_ok_status_and_headers(responses) do + {:ok, data} -> + {:ok, %{conn: conn, buffer: data}} + + :more -> + eat_ok_status_and_headers(conn, timeout) + + :error -> + all_responses_result = + case handle_all_responses(responses, []) do + {:ok, responses} -> {:ok, conn, responses} + {:more, acc} -> recv_all(conn, acc, timeout) + end + + with {:ok, conn, responses} <- all_responses_result do + [_status, headers | data] = responses + message = IO.iodata_to_binary(data) + + code = + if code = get_header(headers, "x-clickhouse-exception-code") do + String.to_integer(code) + end + + {:error, Error.exception(code: code, message: message), conn} + end + end + + {:error, conn, error, _responses} -> + {:disconnect, error, conn} + end + end + + defp eat_ok_status_and_headers([{:status, _ref, 200} | rest]) do + eat_ok_status_and_headers(rest) + end + + defp eat_ok_status_and_headers([{:status, _ref, _status} | _rest]), do: :error + defp eat_ok_status_and_headers([{:headers, _ref, _headers} | data]), do: {:ok, data} + defp eat_ok_status_and_headers([]), do: :more + @impl true - def handle_fetch(_query, {types, ref}, opts, conn) do + def handle_fetch(query, result, opts, %{conn: conn, buffer: buffer}) do + case buffer do + [] -> handle_fetch(query, result, opts, conn) + _not_empty -> {halt_or_cont(buffer), %Result{result | data: extract_data(buffer)}, conn} + end + end + + def handle_fetch(_query, result, opts, conn) do case HTTP.recv(conn, 0, timeout(conn, opts)) do {:ok, conn, responses} -> - {halt_or_cont(responses, ref), {:stream, types, responses}, conn} + {halt_or_cont(responses), %Result{result | data: extract_data(responses)}, conn} {:error, conn, reason, _responses} -> {:disconnect, reason, conn} end end - defp halt_or_cont([{:done, ref}], ref), do: :halt + defp halt_or_cont([{:done, _ref}]), do: :halt + defp halt_or_cont([_ | rest]), do: halt_or_cont(rest) + defp halt_or_cont([]), do: :cont - defp halt_or_cont([{tag, ref, _data} | rest], ref) when tag in [:data, :status, :headers] do - halt_or_cont(rest, ref) - end - - defp halt_or_cont([], _ref), do: :cont + defp extract_data([{:data, _ref, data} | rest]), do: [data | extract_data(rest)] + defp extract_data([] = empty), do: empty + defp extract_data([{:done, _ref}]), do: [] @impl true - def handle_deallocate(_query, _ref, _opts, conn) do + def handle_deallocate(_query, result, _opts, conn) do case HTTP.open_request_count(conn) do 0 -> - {:ok, [], conn} + # TODO data: [], anything else? + {:ok, %Result{result | data: []}, conn} 1 -> {:disconnect, Error.exception("cannot stop stream before receiving full response"), conn} @@ -165,8 +221,8 @@ defmodule Ch.Connection do | {:error, Error.t(), conn} | {:disconnect, Mint.Types.error(), conn} defp request(conn, method, path, headers, body, opts) do - with {:ok, conn, ref} <- send_request(conn, method, path, headers, body) do - receive_response(conn, ref, timeout(conn, opts)) + with {:ok, conn, _ref} <- send_request(conn, method, path, headers, body) do + receive_full_response(conn, timeout(conn, opts)) end end @@ -184,7 +240,7 @@ defmodule Ch.Connection do def request_chunked(conn, method, path, headers, stream, opts) do with {:ok, conn, ref} <- send_request(conn, method, path, headers, :stream), {:ok, conn} <- stream_body(conn, ref, stream), - do: receive_response(conn, ref, timeout(conn, opts)) + do: receive_full_response(conn, timeout(conn, opts)) end @spec stream_body(conn, Mint.Types.request_ref(), Enumerable.t()) :: @@ -213,12 +269,12 @@ defmodule Ch.Connection do end end - @spec receive_response(conn, Mint.Types.request_ref(), timeout) :: + @spec receive_full_response(conn, timeout) :: {:ok, conn, [response]} | {:error, Error.t(), conn} | {:disconnect, Mint.Types.error(), conn} - defp receive_response(conn, ref, timeout) do - with {:ok, conn, responses} <- recv(conn, ref, [], timeout) do + defp receive_full_response(conn, timeout) do + with {:ok, conn, responses} <- recv_all(conn, [], timeout) do case responses do [200, headers | _rest] -> conn = ensure_same_server(conn, headers) @@ -237,14 +293,14 @@ defmodule Ch.Connection do end end - @spec recv(conn, Mint.Types.request_ref(), [response], timeout()) :: + @spec recv_all(conn, [response], timeout()) :: {:ok, conn, [response]} | {:disconnect, Mint.Types.error(), conn} - defp recv(conn, ref, acc, timeout) do + defp recv_all(conn, acc, timeout) do case HTTP.recv(conn, 0, timeout) do {:ok, conn, responses} -> - case handle_responses(responses, ref, acc) do + case handle_all_responses(responses, acc) do {:ok, responses} -> {:ok, conn, responses} - {:more, acc} -> recv(conn, ref, acc, timeout) + {:more, acc} -> recv_all(conn, acc, timeout) end {:error, conn, reason, _responses} -> @@ -252,16 +308,14 @@ defmodule Ch.Connection do end end - defp handle_responses([{:done, ref}], ref, acc) do - {:ok, :lists.reverse(acc)} - end - - defp handle_responses([{tag, ref, data} | rest], ref, acc) - when tag in [:data, :status, :headers] do - handle_responses(rest, ref, [data | acc]) + for tag <- [:data, :status, :headers] do + defp handle_all_responses([{unquote(tag), _ref, data} | rest], acc) do + handle_all_responses(rest, [data | acc]) + end end - defp handle_responses([], _ref, acc), do: {:more, acc} + defp handle_all_responses([{:done, _ref}], acc), do: {:ok, :lists.reverse(acc)} + defp handle_all_responses([], acc), do: {:more, acc} defp maybe_put_private(conn, _k, nil), do: conn defp maybe_put_private(conn, k, v), do: HTTP.put_private(conn, k, v) diff --git a/lib/ch/query.ex b/lib/ch/query.ex index 3e9bae2..2f850d4 100644 --- a/lib/ch/query.ex +++ b/lib/ch/query.ex @@ -86,7 +86,7 @@ defimpl DBConnection.Query, for: Ch.Query do @spec decode(Query.t(), [response], [Ch.query_option()]) :: Result.t() when response: Mint.Types.status() | Mint.Types.headers() | binary - def decode(%Query{command: command}, responses, opts) do + def decode(%Query{command: command}, responses, opts) when is_list(responses) do [_status, headers | data] = responses format = get_header(headers, "x-clickhouse-format") decode = Keyword.get(opts, :decode, true) @@ -110,6 +110,9 @@ defimpl DBConnection.Query, for: Ch.Query do end end + # stream result + def decode(_query, %Result{} = result, _opts), do: result + defp get_header(headers, key) do case List.keyfind(headers, key, 0) do {_, value} -> value diff --git a/test/ch/connection_test.exs b/test/ch/connection_test.exs index aa99a4a..729d873 100644 --- a/test/ch/connection_test.exs +++ b/test/ch/connection_test.exs @@ -1309,54 +1309,20 @@ defmodule Ch.ConnectionTest do end describe "stream" do - @tag :skip - test "sends mint http packets", %{conn: conn} do - stmt = "select number from system.numbers limit 1000" - - drop_ref = fn packets -> - Enum.map(packets, fn - {tag, _ref, data} -> {tag, data} - {tag, _ref} -> tag - end) - end - - packets = + test "emits result structs containing raw data", %{conn: conn} do + results = DBConnection.run(conn, fn conn -> conn - |> Ch.stream(stmt) - |> Enum.flat_map(drop_ref) - end) - - assert [{:status, 200}, {:headers, headers} | _rest] = packets - - assert List.keyfind!(headers, "transfer-encoding", 0) == {"transfer-encoding", "chunked"} - - assert data_packets = - packets - |> Enum.filter(&match?({:data, _data}, &1)) - |> Enum.map(fn {:data, data} -> data end) - - assert length(data_packets) >= 2 - assert RowBinary.decode_rows(Enum.join(data_packets)) == Enum.map(0..999, &[&1]) - - assert List.last(packets) == :done - end - - @tag :skip - test "decodes RowBinary", %{conn: conn} do - stmt = "select number from system.numbers limit 1000" - - rows = - DBConnection.run(conn, fn conn -> - conn - |> Ch.stream(stmt, _params = [], types: [:u64]) + |> Ch.stream("select number from system.numbers limit 1000") |> Enum.into([]) end) - assert List.flatten(rows) == Enum.into(0..999, []) + assert length(results) >= 2 + + assert results |> Enum.map(& &1.data) |> IO.iodata_to_binary() |> RowBinary.decode_rows() == + Enum.map(0..999, &[&1]) end - @tag :skip test "disconnects on early halt", %{conn: conn} do logs = ExUnit.CaptureLog.capture_log(fn -> diff --git a/test/ch/stream_test.exs b/test/ch/stream_test.exs new file mode 100644 index 0000000..405398d --- /dev/null +++ b/test/ch/stream_test.exs @@ -0,0 +1,40 @@ +defmodule Ch.StreamTest do + use ExUnit.Case + alias Ch.{Result, RowBinary} + + setup do + {:ok, conn: start_supervised!({Ch, database: Ch.Test.database()})} + end + + describe "Ch.stream/4" do + test "emits %Ch.Result{}", %{conn: conn} do + count = 1_000_000 + + assert [%Result{command: :select, data: header} | rest] = + DBConnection.run(conn, fn conn -> + conn + |> Ch.stream("select * from numbers({count:UInt64})", %{"count" => 1_000_000}) + |> Enum.into([]) + end) + + assert header == [<<1, 6, "number", 6, "UInt64">>] + + decoded = + Enum.flat_map(rest, fn %Result{data: data} -> + data |> IO.iodata_to_binary() |> RowBinary.decode_rows([:u64]) + end) + + assert length(decoded) == count + end + + test "raises on error", %{conn: conn} do + assert_raise Ch.Error, + ~r/Code: 62. DB::Exception: Syntax error: failed at position 8/, + fn -> + DBConnection.run(conn, fn conn -> + conn |> Ch.stream("select ", %{"count" => 1_000_000}) |> Enum.into([]) + end) + end + end + end +end