Skip to content

Commit

Permalink
refactor Ch.stream/4
Browse files Browse the repository at this point in the history
  • Loading branch information
ruslandoga committed Jan 11, 2024
1 parent 4b0092e commit dd64a14
Show file tree
Hide file tree
Showing 6 changed files with 141 additions and 75 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
- move rows payload (RowBinary, CSV, etc.) to SQL statement and remove pseudo-positional binds, making param names explicit https://github.com/plausible/ch/pull/143
- drop `:headers` from `%Ch.Result{}` but add `:data` https://github.com/plausible/ch/pull/144
- fix query string escaping for `\t`, `\\`, and `\n` https://github.com/plausible/ch/pull/147
- make `Ch.stream/4` emit `%Ch.Result{data: iodata}` https://github.com/plausible/ch/pull/148

## 0.2.2 (2023-12-23)

Expand Down
4 changes: 3 additions & 1 deletion lib/ch.ex
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,9 @@ defmodule Ch do
DBConnection.execute!(conn, query, params, opts)
end

@doc false
@doc """
Returns a stream for a query on a connection.
"""
@spec stream(DBConnection.t(), statement, params, [query_option]) :: DBConnection.Stream.t()
def stream(conn, statement, params \\ [], opts \\ []) do
query = Query.build(statement, opts)
Expand Down
118 changes: 86 additions & 32 deletions lib/ch/connection.ex
Original file line number Diff line number Diff line change
Expand Up @@ -85,41 +85,97 @@ defmodule Ch.Connection do

@impl true
def handle_declare(query, params, opts, conn) do
{query_params, extra_headers, body} = params
%Query{command: command, statement: statement} = query
{query_params, extra_headers} = params

path = path(conn, query_params, opts)
headers = headers(conn, extra_headers, opts)
types = Keyword.get(opts, :types)

with {:ok, conn, ref} <- send_request(conn, "POST", path, headers, body) do
{:ok, query, {types, ref}, conn}
with {:ok, conn, _ref} <- send_request(conn, "POST", path, headers, statement),
{:ok, conn} <- eat_ok_status_and_headers(conn, timeout(conn, opts)) do
{:ok, query, %Result{command: command}, conn}
end
end

@spec eat_ok_status_and_headers(conn, timeout) ::
{:ok, %{conn: conn, buffer: [Mint.Types.response()]}}
| {:error, Ch.Error.t(), conn}
| {:disconnect, Mint.Types.error(), conn}
defp eat_ok_status_and_headers(conn, timeout) do
case HTTP.recv(conn, 0, timeout) do
{:ok, conn, responses} ->
case eat_ok_status_and_headers(responses) do
{:ok, data} ->
{:ok, %{conn: conn, buffer: data}}

:more ->
eat_ok_status_and_headers(conn, timeout)

:error ->
all_responses_result =
case handle_all_responses(responses, []) do
{:ok, responses} -> {:ok, conn, responses}
{:more, acc} -> recv_all(conn, acc, timeout)
end

with {:ok, conn, responses} <- all_responses_result do
[_status, headers | data] = responses
message = IO.iodata_to_binary(data)

code =
if code = get_header(headers, "x-clickhouse-exception-code") do
String.to_integer(code)
end

{:error, Error.exception(code: code, message: message), conn}
end
end

{:error, conn, error, _responses} ->
{:disconnect, error, conn}
end
end

defp eat_ok_status_and_headers([{:status, _ref, 200} | rest]) do
eat_ok_status_and_headers(rest)
end

defp eat_ok_status_and_headers([{:status, _ref, _status} | _rest]), do: :error
defp eat_ok_status_and_headers([{:headers, _ref, _headers} | data]), do: {:ok, data}
defp eat_ok_status_and_headers([]), do: :more

@impl true
def handle_fetch(_query, {types, ref}, opts, conn) do
def handle_fetch(query, result, opts, %{conn: conn, buffer: buffer}) do
case buffer do
[] -> handle_fetch(query, result, opts, conn)
_not_empty -> {halt_or_cont(buffer), %Result{result | data: extract_data(buffer)}, conn}
end
end

def handle_fetch(_query, result, opts, conn) do
case HTTP.recv(conn, 0, timeout(conn, opts)) do
{:ok, conn, responses} ->
{halt_or_cont(responses, ref), {:stream, types, responses}, conn}
{halt_or_cont(responses), %Result{result | data: extract_data(responses)}, conn}

{:error, conn, reason, _responses} ->
{:disconnect, reason, conn}
end
end

defp halt_or_cont([{:done, ref}], ref), do: :halt
defp halt_or_cont([{:done, _ref}]), do: :halt
defp halt_or_cont([_ | rest]), do: halt_or_cont(rest)
defp halt_or_cont([]), do: :cont

defp halt_or_cont([{tag, ref, _data} | rest], ref) when tag in [:data, :status, :headers] do
halt_or_cont(rest, ref)
end

defp halt_or_cont([], _ref), do: :cont
defp extract_data([{:data, _ref, data} | rest]), do: [data | extract_data(rest)]
defp extract_data([] = empty), do: empty
defp extract_data([{:done, _ref}]), do: []

@impl true
def handle_deallocate(_query, _ref, _opts, conn) do
def handle_deallocate(_query, result, _opts, conn) do
case HTTP.open_request_count(conn) do
0 ->
{:ok, [], conn}
# TODO data: [], anything else?
{:ok, %Result{result | data: []}, conn}

1 ->
{:disconnect, Error.exception("cannot stop stream before receiving full response"), conn}
Expand Down Expand Up @@ -165,8 +221,8 @@ defmodule Ch.Connection do
| {:error, Error.t(), conn}
| {:disconnect, Mint.Types.error(), conn}
defp request(conn, method, path, headers, body, opts) do
with {:ok, conn, ref} <- send_request(conn, method, path, headers, body) do
receive_response(conn, ref, timeout(conn, opts))
with {:ok, conn, _ref} <- send_request(conn, method, path, headers, body) do
receive_full_response(conn, timeout(conn, opts))
end
end

Expand All @@ -184,7 +240,7 @@ defmodule Ch.Connection do
def request_chunked(conn, method, path, headers, stream, opts) do
with {:ok, conn, ref} <- send_request(conn, method, path, headers, :stream),
{:ok, conn} <- stream_body(conn, ref, stream),
do: receive_response(conn, ref, timeout(conn, opts))
do: receive_full_response(conn, timeout(conn, opts))
end

@spec stream_body(conn, Mint.Types.request_ref(), Enumerable.t()) ::
Expand Down Expand Up @@ -213,12 +269,12 @@ defmodule Ch.Connection do
end
end

@spec receive_response(conn, Mint.Types.request_ref(), timeout) ::
@spec receive_full_response(conn, timeout) ::
{:ok, conn, [response]}
| {:error, Error.t(), conn}
| {:disconnect, Mint.Types.error(), conn}
defp receive_response(conn, ref, timeout) do
with {:ok, conn, responses} <- recv(conn, ref, [], timeout) do
defp receive_full_response(conn, timeout) do
with {:ok, conn, responses} <- recv_all(conn, [], timeout) do
case responses do
[200, headers | _rest] ->
conn = ensure_same_server(conn, headers)
Expand All @@ -237,31 +293,29 @@ defmodule Ch.Connection do
end
end

@spec recv(conn, Mint.Types.request_ref(), [response], timeout()) ::
@spec recv_all(conn, [response], timeout()) ::
{:ok, conn, [response]} | {:disconnect, Mint.Types.error(), conn}
defp recv(conn, ref, acc, timeout) do
defp recv_all(conn, acc, timeout) do
case HTTP.recv(conn, 0, timeout) do
{:ok, conn, responses} ->
case handle_responses(responses, ref, acc) do
case handle_all_responses(responses, acc) do
{:ok, responses} -> {:ok, conn, responses}
{:more, acc} -> recv(conn, ref, acc, timeout)
{:more, acc} -> recv_all(conn, acc, timeout)
end

{:error, conn, reason, _responses} ->
{:disconnect, reason, conn}
end
end

defp handle_responses([{:done, ref}], ref, acc) do
{:ok, :lists.reverse(acc)}
end

defp handle_responses([{tag, ref, data} | rest], ref, acc)
when tag in [:data, :status, :headers] do
handle_responses(rest, ref, [data | acc])
for tag <- [:data, :status, :headers] do
defp handle_all_responses([{unquote(tag), _ref, data} | rest], acc) do
handle_all_responses(rest, [data | acc])
end
end

defp handle_responses([], _ref, acc), do: {:more, acc}
defp handle_all_responses([{:done, _ref}], acc), do: {:ok, :lists.reverse(acc)}
defp handle_all_responses([], acc), do: {:more, acc}

defp maybe_put_private(conn, _k, nil), do: conn
defp maybe_put_private(conn, k, v), do: HTTP.put_private(conn, k, v)
Expand Down
5 changes: 4 additions & 1 deletion lib/ch/query.ex
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ defimpl DBConnection.Query, for: Ch.Query do

@spec decode(Query.t(), [response], [Ch.query_option()]) :: Result.t()
when response: Mint.Types.status() | Mint.Types.headers() | binary
def decode(%Query{command: command}, responses, opts) do
def decode(%Query{command: command}, responses, opts) when is_list(responses) do
[_status, headers | data] = responses
format = get_header(headers, "x-clickhouse-format")
decode = Keyword.get(opts, :decode, true)
Expand All @@ -110,6 +110,9 @@ defimpl DBConnection.Query, for: Ch.Query do
end
end

# stream result
def decode(_query, %Result{} = result, _opts), do: result

defp get_header(headers, key) do
case List.keyfind(headers, key, 0) do
{_, value} -> value
Expand Down
48 changes: 7 additions & 41 deletions test/ch/connection_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -1309,54 +1309,20 @@ defmodule Ch.ConnectionTest do
end

describe "stream" do
@tag :skip
test "sends mint http packets", %{conn: conn} do
stmt = "select number from system.numbers limit 1000"

drop_ref = fn packets ->
Enum.map(packets, fn
{tag, _ref, data} -> {tag, data}
{tag, _ref} -> tag
end)
end

packets =
test "emits result structs containing raw data", %{conn: conn} do
results =
DBConnection.run(conn, fn conn ->
conn
|> Ch.stream(stmt)
|> Enum.flat_map(drop_ref)
end)

assert [{:status, 200}, {:headers, headers} | _rest] = packets

assert List.keyfind!(headers, "transfer-encoding", 0) == {"transfer-encoding", "chunked"}

assert data_packets =
packets
|> Enum.filter(&match?({:data, _data}, &1))
|> Enum.map(fn {:data, data} -> data end)

assert length(data_packets) >= 2
assert RowBinary.decode_rows(Enum.join(data_packets)) == Enum.map(0..999, &[&1])

assert List.last(packets) == :done
end

@tag :skip
test "decodes RowBinary", %{conn: conn} do
stmt = "select number from system.numbers limit 1000"

rows =
DBConnection.run(conn, fn conn ->
conn
|> Ch.stream(stmt, _params = [], types: [:u64])
|> Ch.stream("select number from system.numbers limit 1000")
|> Enum.into([])
end)

assert List.flatten(rows) == Enum.into(0..999, [])
assert length(results) >= 2

assert results |> Enum.map(& &1.data) |> IO.iodata_to_binary() |> RowBinary.decode_rows() ==
Enum.map(0..999, &[&1])
end

@tag :skip
test "disconnects on early halt", %{conn: conn} do
logs =
ExUnit.CaptureLog.capture_log(fn ->
Expand Down
40 changes: 40 additions & 0 deletions test/ch/stream_test.exs
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
defmodule Ch.StreamTest do
use ExUnit.Case
alias Ch.{Result, RowBinary}

setup do
{:ok, conn: start_supervised!({Ch, database: Ch.Test.database()})}
end

describe "Ch.stream/4" do
test "emits %Ch.Result{}", %{conn: conn} do
count = 1_000_000

assert [%Result{command: :select, data: header} | rest] =
DBConnection.run(conn, fn conn ->
conn
|> Ch.stream("select * from numbers({count:UInt64})", %{"count" => 1_000_000})
|> Enum.into([])
end)

assert header == [<<1, 6, "number", 6, "UInt64">>]

decoded =
Enum.flat_map(rest, fn %Result{data: data} ->
data |> IO.iodata_to_binary() |> RowBinary.decode_rows([:u64])
end)

assert length(decoded) == count
end

test "raises on error", %{conn: conn} do
assert_raise Ch.Error,
~r/Code: 62. DB::Exception: Syntax error: failed at position 8/,
fn ->
DBConnection.run(conn, fn conn ->
conn |> Ch.stream("select ", %{"count" => 1_000_000}) |> Enum.into([])
end)
end
end
end
end

0 comments on commit dd64a14

Please sign in to comment.