diff --git a/cohttp-eio/examples/client1.ml b/cohttp-eio/examples/client1.ml index 6d8a5feb64..8e982361bd 100644 --- a/cohttp-eio/examples/client1.ml +++ b/cohttp-eio/examples/client1.ml @@ -20,6 +20,6 @@ let () = let res = Client.get ~headers:(Http.Header.of_list [ ("Host", "www.example.org") ]) - sw flow "/" + flow "/" in match Client.read_fixed res with Some b -> print_string b | None -> () diff --git a/cohttp-eio/src/body.ml b/cohttp-eio/src/body.ml index ae0a90f0a0..3e96a46710 100644 --- a/cohttp-eio/src/body.ml +++ b/cohttp-eio/src/body.ml @@ -1,7 +1,10 @@ +module Buf_read = Eio.Buf_read +module Buf_write = Eio.Buf_write + type t = | Fixed of string | Chunked of chunk_writer - | Custom of (Eio.Flow.sink -> unit) + | Custom of (Buf_write.t -> unit) | Empty and chunk_writer = { @@ -42,20 +45,15 @@ let pp_chunk fmt = function fmt chunk | Last_chunk extensions -> pp_chunk_extension fmt extensions -open Parser -open Eio.Buf_read - let read_fixed t headers = let ( let* ) o f = Option.bind o f in let ( let+ ) o f = Option.map f o in let* v = Http.Header.get headers "Content-Length" in let+ content_length = int_of_string_opt v in - take content_length t + Buf_read.take content_length t (* Chunked encoding parser *) -open Eio.Buf_read.Syntax - let hex_digit = function | '0' .. '9' -> true | 'a' .. 'f' -> true @@ -63,7 +61,8 @@ let hex_digit = function | _ -> false let quoted_char = - let+ c = any_char in + let open Buf_read.Syntax in + let+ c = Buf_read.any_char in match c with | ' ' | '\t' | '\x21' .. '\x7E' -> c | c -> failwith (Printf.sprintf "Invalid escape \\%C" c) @@ -75,10 +74,10 @@ let qdtext = function (*-- quoted-string = DQUOTE *( qdtext / quoted-pair ) DQUOTE --*) let quoted_string r = - char '"' r; + Buf_read.char '"' r; let buf = Buffer.create 100 in let rec aux () = - match any_char r with + match Buf_read.any_char r with | '"' -> Buffer.contents buf | '\\' -> Buffer.add_char buf (quoted_char r); @@ -90,30 +89,32 @@ let quoted_string r = aux () let optional c x r = - let c2 = peek_char r in + let c2 = Buf_read.peek_char r in if Some c = c2 then ( - consume r 1; + Buf_read.consume r 1; Some (x r)) else None (*-- https://datatracker.ietf.org/doc/html/rfc7230#section-4.1 --*) let chunk_ext_val = - let* c = peek_char in - match c with Some '"' -> quoted_string | _ -> token + let open Buf_read.Syntax in + let* c = Buf_read.peek_char in + match c with Some '"' -> quoted_string | _ -> Parser.token let rec chunk_exts r = - let c = peek_char r in + let c = Buf_read.peek_char r in match c with | Some ';' -> - consume r 1; - let name = token r in + Buf_read.consume r 1; + let name = Parser.token r in let value = optional '=' chunk_ext_val r in { name; value } :: chunk_exts r | _ -> [] let chunk_size = - let* sz = take_while1 hex_digit in - try return (Format.sprintf "0x%s" sz |> int_of_string) + let open Buf_read.Syntax in + let* sz = Parser.take_while1 hex_digit in + try Parser.return (Format.sprintf "0x%s" sz |> int_of_string) with _ -> failwith (Format.sprintf "Invalid chunk_size: %s" sz) (* Be strict about headers allowed in trailer headers to minimize security @@ -148,21 +149,22 @@ let request_trailer_headers headers = (* Chunk decoding algorithm is explained at https://datatracker.ietf.org/doc/html/rfc7230#section-4.1.3 *) let chunk (total_read : int) (headers : Http.Header.t) = + let open Buf_read.Syntax in let* sz = chunk_size in match sz with | sz when sz > 0 -> - let* extensions = chunk_exts <* crlf in - let* data = take sz <* crlf in - return @@ `Chunk (sz, data, extensions) + let* extensions = chunk_exts <* Parser.crlf in + let* data = Buf_read.take sz <* Parser.crlf in + Parser.return @@ `Chunk (sz, data, extensions) | 0 -> - let* extensions = chunk_exts <* crlf in + let* extensions = chunk_exts <* Parser.crlf in (* Read trailer headers if any and append those to request headers. Only headers names appearing in 'Trailer' request headers and "allowed" trailer headers are appended to request. The spec at https://datatracker.ietf.org/doc/html/rfc7230#section-4.1.3 specifies that 'Content-Length' and 'Transfer-Encoding' headers must be updated. *) - let* trailer_headers = http_headers in + let* trailer_headers = Parser.http_headers in let request_trailer_headers = request_trailer_headers headers in let trailer_headers = List.filter @@ -201,7 +203,7 @@ let chunk (total_read : int) (headers : Http.Header.t) = let headers = Http.Header.add headers "Content-Length" (string_of_int total_read) in - return @@ `Last_chunk (extensions, headers) + Parser.return @@ `Last_chunk (extensions, headers) | sz -> failwith (Format.sprintf "Invalid chunk size: %d" sz) let read_chunked reader headers f = @@ -221,3 +223,46 @@ let read_chunked reader headers f = in chunk_loop f | _ -> None + +let write_headers writer headers = + Http.Header.iter + (fun k v -> + Buf_write.string writer k; + Buf_write.string writer ": "; + Buf_write.string writer v; + Buf_write.string writer "\r\n") + headers + +(* https://datatracker.ietf.org/doc/html/rfc7230#section-4.1 *) +let write_chunked writer chunk_writer = + let write_extensions exts = + List.iter + (fun { name; value } -> + let v = + match value with None -> "" | Some v -> Printf.sprintf "=%s" v + in + Buf_write.string writer (Printf.sprintf ";%s%s" name v)) + exts + in + let write_body = function + | Chunk { size; data; extensions = exts } -> + Buf_write.string writer (Printf.sprintf "%X" size); + write_extensions exts; + Buf_write.string writer "\r\n"; + Buf_write.string writer data; + Buf_write.string writer "\r\n" + | Last_chunk exts -> + Buf_write.string writer "0"; + write_extensions exts; + Buf_write.string writer "\r\n" + in + chunk_writer.body_writer write_body; + chunk_writer.trailer_writer (write_headers writer); + Buf_write.string writer "\r\n" + +let write_body writer body = + match body with + | Fixed s -> Buf_write.string writer s + | Chunked chunk_writer -> write_chunked writer chunk_writer + | Custom f -> f writer + | Empty -> () diff --git a/cohttp-eio/src/client.ml b/cohttp-eio/src/client.ml index b71f4247f5..2c3b4e4423 100644 --- a/cohttp-eio/src/client.ml +++ b/cohttp-eio/src/client.ml @@ -1,4 +1,5 @@ module Buf_read = Eio.Buf_read +module Buf_write = Eio.Buf_write type response = Http.Response.t * Buf_read.t type resource_path = string @@ -6,7 +7,6 @@ type resource_path = string type 'a body_disallowed_call = ?version:Http.Version.t -> ?headers:Http.Header.t -> - Eio.Switch.t -> (#Eio.Flow.two_way as 'a) -> resource_path -> response @@ -17,31 +17,29 @@ type 'a body_allowed_call = ?version:Http.Version.t -> ?headers:Http.Header.t -> ?body:Body.t -> - Eio.Switch.t -> (#Eio.Flow.two_way as 'a) -> resource_path -> response (* Request line https://datatracker.ietf.org/doc/html/rfc7230#section-3.1.1 *) let write_request writer (meth, version, headers, resource_path, body) = - Writer.write_string writer (Http.Method.to_string meth); - Writer.write_char writer ' '; - Writer.write_string writer resource_path; - Writer.write_char writer ' '; - Writer.write_string writer (Http.Version.to_string version); - Writer.write_string writer "\r\n"; - Writer.write_headers writer headers; - Writer.write_string writer "\r\n"; - Writer.write_body writer body + Buf_write.string writer (Http.Method.to_string meth); + Buf_write.char writer ' '; + Buf_write.string writer resource_path; + Buf_write.char writer ' '; + Buf_write.string writer (Http.Version.to_string version); + Buf_write.string writer "\r\n"; + Body.write_headers writer headers; + Buf_write.string writer "\r\n"; + Body.write_body writer body (* response parser *) let is_digit = function '0' .. '9' -> true | _ -> false -open Buf_read.Syntax - let status_code = let open Parser in + let open Buf_read.Syntax in let+ status = take_while1 is_digit in Http.Status.of_int (int_of_string status) @@ -52,6 +50,7 @@ let reason_phrase = (* https://datatracker.ietf.org/doc/html/rfc7230#section-3.1.2 *) let response buf_read = + let open Buf_read.Syntax in match Buf_read.at_end_of_input buf_read with | true -> Stdlib.raise_notrace End_of_file | false -> @@ -64,11 +63,9 @@ let response buf_read = (* Generic HTTP call *) let call ?(meth = `GET) ?(version = `HTTP_1_1) ?(headers = Http.Header.init ()) - ?(body = Body.Empty) sw flow uri = - let writer = Writer.create (flow :> Eio.Flow.sink) in - Eio.Fiber.fork ~sw (fun () -> Writer.run writer); + ?(body = Body.Empty) flow uri = + Buf_write.with_flow ~initial_size:0x1000 flow @@ fun writer -> write_request writer (meth, version, headers, uri, body); - Writer.wakeup writer; let reader = Eio.Buf_read.of_flow ~initial_size:0x1000 ~max_size:max_int (flow :> Eio.Flow.source) @@ -78,25 +75,25 @@ let call ?(meth = `GET) ?(version = `HTTP_1_1) ?(headers = Http.Header.init ()) (* HTTP Calls with Body Disallowed *) -let get ?version ?headers sw stream uri = - call ~meth:`GET ?version ?headers sw stream uri +let get ?version ?headers stream uri = + call ~meth:`GET ?version ?headers stream uri -let head ?version ?headers sw stream uri = - call ~meth:`HEAD ?version ?headers sw stream uri +let head ?version ?headers stream uri = + call ~meth:`HEAD ?version ?headers stream uri -let delete ?version ?headers sw stream uri = - call ~meth:`DELETE ?version ?headers sw stream uri +let delete ?version ?headers stream uri = + call ~meth:`DELETE ?version ?headers stream uri (* HTTP Calls with Body Allowed *) -let post ?version ?headers ?body sw stream uri = - call ~meth:`POST ?version ?headers ?body sw stream uri +let post ?version ?headers ?body stream uri = + call ~meth:`POST ?version ?headers ?body stream uri -let put ?version ?headers ?body sw stream uri = - call ~meth:`PUT ?version ?headers ?body sw stream uri +let put ?version ?headers ?body stream uri = + call ~meth:`PUT ?version ?headers ?body stream uri -let patch ?version ?headers ?body sw stream uri = - call ~meth:`PATCH ?version ?headers ?body sw stream uri +let patch ?version ?headers ?body stream uri = + call ~meth:`PATCH ?version ?headers ?body stream uri (* Response Body *) diff --git a/cohttp-eio/src/cohttp_eio.mli b/cohttp-eio/src/cohttp_eio.mli index 38b3444e0b..9daa855b72 100644 --- a/cohttp-eio/src/cohttp_eio.mli +++ b/cohttp-eio/src/cohttp_eio.mli @@ -2,7 +2,7 @@ module Body : sig type t = | Fixed of string | Chunked of chunk_writer - | Custom of (Eio.Flow.sink -> unit) + | Custom of (Eio.Buf_write.t -> unit) | Empty and chunk_writer = { @@ -102,7 +102,6 @@ module Client : sig type 'a body_disallowed_call = ?version:Http.Version.t -> ?headers:Http.Header.t -> - Eio.Switch.t -> (#Eio.Flow.two_way as 'a) -> resource_path -> response @@ -113,7 +112,6 @@ module Client : sig ?version:Http.Version.t -> ?headers:Http.Header.t -> ?body:Body.t -> - Eio.Switch.t -> (#Eio.Flow.two_way as 'a) -> resource_path -> response @@ -127,7 +125,6 @@ module Client : sig ?version:Http.Version.t -> ?headers:Http.Header.t -> ?body:Body.t -> - Eio.Switch.t -> #Eio.Flow.two_way -> resource_path -> response diff --git a/cohttp-eio/src/server.ml b/cohttp-eio/src/server.ml index a3e085671f..c7a6036ed2 100644 --- a/cohttp-eio/src/server.ml +++ b/cohttp-eio/src/server.ml @@ -1,8 +1,10 @@ open Eio.Std +module Buf_read = Eio.Buf_read +module Buf_write = Eio.Buf_write type middleware = handler -> handler and handler = request -> response -and request = Http.Request.t * Eio.Buf_read.t +and request = Http.Request.t * Buf_read.t and response = Http.Response.t * Body.t let domain_count = @@ -58,30 +60,30 @@ let internal_server_error_response = let bad_request_response = (Http.Response.make ~status:`Bad_request (), Body.Empty) -let write_response (writer : Writer.t) - ((response, body) : Http.Response.t * Body.t) = +let write_response writer ((response, body) : Http.Response.t * Body.t) = let version = Http.Version.to_string response.version in let status = Http.Status.to_string response.status in - Writer.write_string writer version; - Writer.write_char writer ' '; - Writer.write_string writer status; - Writer.write_string writer "\r\n"; - Writer.write_headers writer response.headers; - Writer.write_string writer "\r\n"; - Writer.write_body writer body + Buf_write.string writer version; + Buf_write.char writer ' '; + Buf_write.string writer status; + Buf_write.string writer "\r\n"; + Body.write_headers writer response.headers; + Buf_write.string writer "\r\n"; + Body.write_body writer body (* request parsers *) -open Eio.Buf_read.Syntax -module Buf_read = Eio.Buf_read - let meth = + let open Eio.Buf_read.Syntax in let+ meth = Parser.(token <* space) in Http.Method.of_string meth -let resource = Parser.(take_while1 (fun c -> c != ' ') <* space) +let resource = + let open Eio.Buf_read.Syntax in + Parser.(take_while1 (fun c -> c != ' ') <* space) let[@warning "-3"] http_request t = + let open Eio.Buf_read.Syntax in match Buf_read.at_end_of_input t with | true -> Stdlib.raise_notrace End_of_file | false -> @@ -102,17 +104,14 @@ let rec handle_request reader writer flow handler = (* A custom response needs to write the main response before calling the custom function for the body. Response.write wakes the writer for us if that is the case. *) - if not (is_custom body) then Writer.wakeup writer; if Http.Request.is_keep_alive request then handle_request reader writer flow handler | (exception End_of_file) | (exception Eio.Net.Connection_reset _) -> () | exception (Failure _ as ex) -> write_response writer bad_request_response; - Writer.wakeup writer; raise ex | exception ex -> write_response writer internal_server_error_response; - Writer.wakeup writer; raise ex let run_domain ssock handler = @@ -127,9 +126,8 @@ let run_domain ssock handler = Eio.Buf_read.of_flow ~initial_size:0x1000 ~max_size:max_int (flow :> Eio.Flow.source) in - let writer = Writer.create (flow :> Eio.Flow.sink) in - Eio.Fiber.fork ~sw (fun () -> Writer.run writer); - handle_request reader writer flow handler) + Buf_write.with_flow ~initial_size:0x1000 flow (fun writer -> + handle_request reader writer flow handler)) done) let run ?(socket_backlog = 128) ?(domains = domain_count) ~port env sw handler = diff --git a/cohttp-eio/src/writer.ml b/cohttp-eio/src/writer.ml deleted file mode 100644 index 4b4d01d6c6..0000000000 --- a/cohttp-eio/src/writer.ml +++ /dev/null @@ -1,118 +0,0 @@ -(*---------------------------------------------------------------------------- - Copyright (c) 2017 Inhabited Type LLC. - All rights reserved. - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions - are met: - 1. Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - 2. Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - 3. Neither the name of the author nor the names of his contributors - may be used to endorse or promote products derived from this software - without specific prior written permission. - THIS SOFTWARE IS PROVIDED BY THE CONTRIBUTORS ``AS IS'' AND ANY EXPRESS - OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED - WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - DISCLAIMED. IN NO EVENT SHALL THE AUTHORS OR CONTRIBUTORS BE LIABLE FOR - ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS - OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) - HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, - STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN - ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE - POSSIBILITY OF SUCH DAMAGE. - ----------------------------------------------------------------------------*) -module Optional_thunk : sig - type t - - val none : t - val some : (unit -> unit) -> t - val call_if_some : t -> unit -end = struct - type t = unit -> unit - - let none = Sys.opaque_identity (fun () -> ()) - - let some f = - if f == none then - failwith - "Optional_thunk: this function is not representable as a some value"; - f - - let call_if_some t = t () -end - -type t = { - sink : Eio.Flow.sink; - buf : Buffer.t; - mutable wakeup : Optional_thunk.t; -} - -let create sink = - let buf = Buffer.create 0x1000 in - { sink; buf; wakeup = Optional_thunk.none } - -let wakeup t = - let f = t.wakeup in - t.wakeup <- Optional_thunk.none; - Optional_thunk.call_if_some f - -let write_string t s = Buffer.add_string t.buf s -let write_char t c = Buffer.add_char t.buf c - -let write_headers t headers = - Http.Header.iter - (fun k v -> - write_string t k; - write_string t ": "; - write_string t v; - write_string t "\r\n") - headers - -(* https://datatracker.ietf.org/doc/html/rfc7230#section-4.1 *) -let write_chunked t (chunk_writer : Body.chunk_writer) = - let write_extensions exts = - List.iter - (fun { Body.name; value } -> - let v = - match value with None -> "" | Some v -> Printf.sprintf "=%s" v - in - write_string t (Printf.sprintf ";%s%s" name v)) - exts - in - let write_body = function - | Body.Chunk { size; data; extensions = exts } -> - write_string t (Printf.sprintf "%X" size); - write_extensions exts; - write_string t "\r\n"; - write_string t data; - write_string t "\r\n" - | Last_chunk exts -> - write_string t "0"; - write_extensions exts; - write_string t "\r\n" - in - chunk_writer.body_writer write_body; - chunk_writer.trailer_writer (write_headers t); - write_string t "\r\n" - -let write_body t body = - match body with - | Body.Fixed s -> write_string t s - | Chunked chunk_writer -> write_chunked t chunk_writer - | Custom f -> - wakeup t; - f (t.sink :> Eio.Flow.sink) - | Empty -> () - -let run t = - let rec loop () = - if Buffer.length t.buf > 0 then ( - Eio.Flow.copy_string (Buffer.contents t.buf) t.sink; - Buffer.clear t.buf; - loop ()) - else t.wakeup <- Optional_thunk.some loop - in - loop () diff --git a/cohttp-eio/tests/test_client.ml b/cohttp-eio/tests/test_client.ml index c204231e59..d956bdfa4a 100644 --- a/cohttp-eio/tests/test_client.ml +++ b/cohttp-eio/tests/test_client.ml @@ -12,7 +12,7 @@ let get () = let res = Client.get ~headers:(Http.Header.of_list [ ("Accept", "application/json") ]) - sw (get_conn env sw) "/get" + (get_conn env sw) "/get" in match Client.read_fixed res with Some s -> print_string s | None -> () @@ -28,7 +28,7 @@ let post () = [ ("Accept", "application/json"); ("Content-Length", content_length); ]) - ~body:(Body.Fixed content) sw (get_conn env sw) "/post" + ~body:(Body.Fixed content) (get_conn env sw) "/post" in match Client.read_fixed res with Some s -> print_string s | None -> ()