From cf03ea35814d59a3297b777fc8c92e4c45c5c40c Mon Sep 17 00:00:00 2001 From: Simon Cruanes Date: Fri, 8 Sep 2023 23:53:00 -0400 Subject: [PATCH 01/46] start typechecking services --- src/compilerlib/pb_typing.ml | 21 ++-- src/compilerlib/pb_typing.mli | 3 +- src/compilerlib/pb_typing_recursion.mli | 3 +- src/compilerlib/pb_typing_resolution.ml | 114 ++++++++++++------- src/compilerlib/pb_typing_resolution.mli | 7 +- src/compilerlib/pb_typing_type_tree.ml | 24 +++- src/compilerlib/pb_typing_util.ml | 10 +- src/compilerlib/pb_typing_validation.ml | 59 ++++++++-- src/compilerlib/pb_typing_validation.mli | 2 +- src/compilerlib/pb_util.ml | 9 ++ src/compilerlib/pb_util.mli | 3 + src/ocaml-protoc/ocaml_protoc_compilation.ml | 27 +++-- 12 files changed, 206 insertions(+), 76 deletions(-) diff --git a/src/compilerlib/pb_typing.ml b/src/compilerlib/pb_typing.ml index c5725749..f185b27c 100644 --- a/src/compilerlib/pb_typing.ml +++ b/src/compilerlib/pb_typing.ml @@ -23,16 +23,23 @@ *) +module Pt = Pb_parsing_parse_tree module Tt = Pb_typing_type_tree -let perform_typing protos = - let typed_protos = +let perform_typing (protos : Pt.proto list) : Pb_field_type.resolved Tt.proto = + let validated_types, validated_services = List.fold_left - (fun typed_protos proto -> - typed_protos @ Pb_typing_validation.validate proto) - [] protos + (fun (typed_protos, services) proto -> + let val_proto = Pb_typing_validation.validate proto in + ( typed_protos @ List.flatten val_proto.proto_types, + services @ val_proto.proto_services )) + ([], []) protos in - let typed_protos = Pb_typing_resolution.resolve_types typed_protos in + let t, types = Pb_typing_resolution.resolve_types validated_types in + let services = Pb_typing_resolution.resolve_services t validated_services in - List.rev @@ Pb_typing_recursion.group typed_protos + { + Tt.proto_types = List.rev @@ Pb_typing_recursion.group types; + proto_services = services; + } diff --git a/src/compilerlib/pb_typing.mli b/src/compilerlib/pb_typing.mli index a84db4a9..8bec414a 100644 --- a/src/compilerlib/pb_typing.mli +++ b/src/compilerlib/pb_typing.mli @@ -37,8 +37,7 @@ module Tt = Pb_typing_type_tree val perform_typing : - Pb_parsing_parse_tree.proto list -> - Pb_field_type.resolved Tt.proto_type list list + Pb_parsing_parse_tree.proto list -> Pb_field_type.resolved Tt.proto (** [perform_typing parsed_tree] returned the type tree organized in groups of fully resolved types. Each group contains all the mutually recursive types and the type group by reverse dependency order. *) diff --git a/src/compilerlib/pb_typing_recursion.mli b/src/compilerlib/pb_typing_recursion.mli index 6e1891ee..8afbdd5d 100644 --- a/src/compilerlib/pb_typing_recursion.mli +++ b/src/compilerlib/pb_typing_recursion.mli @@ -39,7 +39,8 @@ module Tt = Pb_typing_type_tree val group : - Pb_field_type.resolved Tt.proto -> Pb_field_type.resolved Tt.proto list + Pb_field_type.resolved Tt.proto_type list -> + Pb_field_type.resolved Tt.proto_type list list (** [group types] returns the list of all the mutually recursive group of types in reverse order of dependency. In other the last group of types of the returned list don't depend on any other types. *) diff --git a/src/compilerlib/pb_typing_resolution.ml b/src/compilerlib/pb_typing_resolution.ml index 370460eb..5720ca5a 100644 --- a/src/compilerlib/pb_typing_resolution.ml +++ b/src/compilerlib/pb_typing_resolution.ml @@ -89,11 +89,11 @@ end (* Types_by_scope *) (* this function returns the type path of a message which is the - * packages followed by the enclosing message names and eventually - * the message name of the given type. - * - * If the type is an enum then [Failure] is raised. - * TODO: change [Failure] to a [Pb_exception.Compilation_error] *) + packages followed by the enclosing message names and eventually + the message name of the given type. + + If the type is an enum then [Failure] is raised. + TODO: change [Failure] to a [Pb_exception.Compilation_error] *) let type_path_of_type { Tt.scope; spec; _ } = match spec with | Tt.Enum _ -> assert false @@ -102,15 +102,15 @@ let type_path_of_type { Tt.scope; spec; _ } = packages @ message_names @ [ message_name ] (* this function returns all the scope to search for a type starting - * by the most innner one first. - * - * If [message_scope] = ['Msg1'; 'Msg2'] and [field_scope] = ['Msg3'] then - * the following scopes will be returned: - * [ - * ['Msg1'; 'Msg2'; 'Msg3']; // This would be the scope of the current msg - * ['Msg1'; 'Msg3'; ]; // Outer message scope - * ['Msg3'; ] // Top level scope - * ] *) + by the most innner one first. + + If [message_scope] = ['Msg1'; 'Msg2'] and [field_scope] = ['Msg3'] then + the following scopes will be returned: + [ + ['Msg1'; 'Msg2'; 'Msg3']; // This would be the scope of the current msg + ['Msg1'; 'Msg3'; ]; // Outer message scope + ['Msg3'; ] // Top level scope + ] *) let compute_search_type_paths unresolved_field_type message_type_path = let { Pb_field_type.type_path; type_name = _; from_root } = unresolved_field_type @@ -126,14 +126,14 @@ let compute_search_type_paths unresolved_field_type message_type_path = List.rev @@ loop [] message_type_path ) -(* this function ensure that the default value of the field is correct - * with respect to its type when this latter is a builtin one. - * - * in case the default value is invalid then an - * [Pb_exception.Compilation_error] is raised. - * - * Note that this function also does type coersion when the default value - * is an int and the builtin type is a float or double. *) +(** this function ensure that the default value of the field is correct + with respect to its type when this latter is a builtin one. + + in case the default value is invalid then an + [Pb_exception.Compilation_error] is raised. + + Note that this function also does type coersion when the default value + is an int and the builtin type is a float or double. *) let resolve_builtin_type_field_default field_name builtin_type field_default = match field_default with | None -> None @@ -180,16 +180,16 @@ let resolve_builtin_type_field_default field_name builtin_type field_default = E.invalid_default_value ~field_name ~info:"default value not supported for bytes" ()) -(* This function verifies that the default value for a used defined - * field is correct. - * - * In protobuf, only field which type is [enum] can have a default - * value. Field of type [message] can't. - * - * In the case the field is an enum then the default value must be - * a litteral value which is one of the enum value. - * - * If the validation fails then [Pb_exception.Compilation_error] is raised *) +(** This function verifies that the default value for a used defined + field is correct. + + In protobuf, only field which type is [enum] can have a default + value. Field of type [message] can't. + + In the case the field is an enum then the default value must be + a litteral value which is one of the enum value. + + If the validation fails then [Pb_exception.Compilation_error] is raised *) let resolve_enum_field_default field_name type_ field_default = match field_default with | None -> None @@ -221,14 +221,14 @@ let resolve_enum_field_default field_name type_ field_default = E.invalid_default_value ~field_name ~info:"default value not supported for message" () -(* this function resolves both the type and the defaut value of a field - * type. Note that it is necessary to verify both at the same time since - * the default value must be of the same type as the field type in order - * to be valid. - * - * For builtin the type the validation is trivial while for user defined - * type a search must be done for all the possible scopes the type - * might be in. *) +(** this function resolves both the type and the defaut value of a field + type. Note that it is necessary to verify both at the same time since + the default value must be of the same type as the field type in order + to be valid. + + For builtin the type the validation is trivial while for user defined + type a search must be done for all the possible scopes the type + might be in. *) let resolve_field_type_and_default t field_name field_type field_default message_type_path = match field_type with @@ -327,6 +327,36 @@ let resolve_type t type_ : int Tt.proto_type = in { Tt.scope; id; file_name; file_options; spec } -let resolve_types types = +let resolve_types types : Types_by_scope.t * _ list = let t = List.fold_left Types_by_scope.add Types_by_scope.empty types in - List.map (resolve_type t) types + t, List.map (resolve_type t) types + +let resolve_service t (service : _ Tt.service) : + Pb_field_type.resolved Tt.service = + let resolve_ty ~rpc_name ~name ty : Pb_field_type.resolved Pb_field_type.t = + let rpc_type, _field_default = + let do_resolve () = + resolve_field_type_and_default t name ty None service.service_packages + in + match do_resolve () with + | ret -> ret + | exception Not_found -> + E.unresolved_type ~field_name:name ~type_:"" ~message_name:rpc_name () + in + rpc_type + in + + let resolve_rpc (rpc : _ Tt.rpc) : _ Tt.rpc = + let rpc_name = rpc.rpc_name in + { + rpc with + Tt.rpc_req = resolve_ty ~rpc_name ~name:"req" rpc.rpc_req; + rpc_res = resolve_ty ~rpc_name ~name:"res" rpc.rpc_res; + } + in + + { service with Tt.service_body = List.map resolve_rpc service.service_body } + +let resolve_services (t : Types_by_scope.t) (services : _ Tt.service list) : + _ Tt.service list = + List.map (resolve_service t) services diff --git a/src/compilerlib/pb_typing_resolution.mli b/src/compilerlib/pb_typing_resolution.mli index ca292a32..17e5f790 100644 --- a/src/compilerlib/pb_typing_resolution.mli +++ b/src/compilerlib/pb_typing_resolution.mli @@ -69,7 +69,12 @@ end val resolve_types : Pb_field_type.unresolved Tt.proto_type list -> - Pb_field_type.resolved Tt.proto_type list + Types_by_scope.t * Pb_field_type.resolved Tt.proto_type list (** [resolve_types types] resolves all the field types for all the [types]. If a field cannot be resolved then [Pb_exception.Compilation_error] is raised. *) + +val resolve_services : + Types_by_scope.t -> + Pb_field_type.unresolved Tt.service list -> + Pb_field_type.resolved Tt.service list diff --git a/src/compilerlib/pb_typing_type_tree.ml b/src/compilerlib/pb_typing_type_tree.ml index 97cd4edc..580a7778 100644 --- a/src/compilerlib/pb_typing_type_tree.ml +++ b/src/compilerlib/pb_typing_type_tree.ml @@ -108,4 +108,26 @@ type 'a proto_type = { spec: 'a proto_type_spec; } -type 'a proto = 'a proto_type list +type 'a rpc = { + rpc_name: string; + rpc_options: Pb_option.set; + rpc_req_stream: bool; + rpc_req: 'a; + rpc_res_stream: bool; + rpc_res: 'a; +} +(** A RPC specification. *) + +type 'a service = { + service_name: string; + service_packages: string list; (** Package in which this belongs *) + service_body: 'a Pb_field_type.t rpc list; +} +(** A service, composed of multiple RPCs. *) + +type 'a proto = { + proto_types: 'a proto_type list list; + (** List of strongly connected type definitions *) + proto_services: 'a service list; +} +(** A proto file is composed of a list of types and a list of services. *) diff --git a/src/compilerlib/pb_typing_util.ml b/src/compilerlib/pb_typing_util.ml index eab918b6..0b6b5f6f 100644 --- a/src/compilerlib/pb_typing_util.ml +++ b/src/compilerlib/pb_typing_util.ml @@ -42,8 +42,14 @@ let field_option { Tt.field_options; _ } option_name = let empty_scope = { Tt.packages = []; message_names = [] } let type_id_of_type { Tt.id; _ } = id -let type_of_id all_types id = - List.find (fun t -> type_id_of_type t = id) all_types +let type_of_id (p : _ Tt.proto) id = + match + Pb_util.List.find_map + (fun tys -> Pb_util.List.find_opt (fun t -> type_id_of_type t = id) tys) + p.proto_types + with + | Some ty -> ty + | None -> raise Not_found let string_of_type_scope { Tt.packages; message_names } = Printf.sprintf "scope:{packages:%s, message_names:%s}" diff --git a/src/compilerlib/pb_typing_validation.ml b/src/compilerlib/pb_typing_validation.ml index 4dd017fa..206b98cb 100644 --- a/src/compilerlib/pb_typing_validation.ml +++ b/src/compilerlib/pb_typing_validation.ml @@ -265,8 +265,47 @@ let rec validate_message ?(parent_options = Pb_option.empty) file_name acc.Acc.all_types @ [ make_proto_type ~file_name ~file_options ~id ~scope:message_scope ~spec ] -let validate (proto : Pt.proto) : _ Tt.proto_type list = - let { Pt.package; Pt.proto_file_name; messages; enums; file_options; _ } = +let validate_service (scope : Tt.type_scope) (service : Pt.service) : + _ Tt.service = + let { Pt.service_name; service_body } = service in + let service_body = + List.filter_map + (function + | Pt.Service_option _ -> None + | Pt.Service_rpc + { + rpc_name; + rpc_options; + rpc_req_stream; + rpc_req; + rpc_res_stream; + rpc_res; + } -> + let rpc = + { + Tt.rpc_name; + rpc_options; + rpc_req_stream; + rpc_req; + rpc_res_stream; + rpc_res; + } + in + Some rpc) + service_body + in + { Tt.service_packages = scope.packages; service_name; service_body } + +let validate (proto : Pt.proto) : _ Tt.proto = + let { + Pt.package; + Pt.proto_file_name; + messages; + enums; + file_options; + services; + _; + } = proto in @@ -276,11 +315,17 @@ let validate (proto : Pt.proto) : _ Tt.proto_type list = let pbtt_msgs = List.fold_right (fun e pbtt_msgs -> - compile_enum_p1 file_name file_options scope e :: pbtt_msgs) + [ compile_enum_p1 file_name file_options scope e ] :: pbtt_msgs) enums [] in - List.fold_left - (fun pbtt_msgs pbpt_msg -> - pbtt_msgs @ validate_message file_name file_options scope pbpt_msg) - pbtt_msgs messages + let proto_types = + List.fold_left + (fun pbtt_msgs pbpt_msg -> + let tys = validate_message file_name file_options scope pbpt_msg in + tys :: pbtt_msgs) + pbtt_msgs messages + in + + let proto_services = List.map (validate_service scope) services in + { Tt.proto_types; proto_services } diff --git a/src/compilerlib/pb_typing_validation.mli b/src/compilerlib/pb_typing_validation.mli index b2b53dcd..796a104e 100644 --- a/src/compilerlib/pb_typing_validation.mli +++ b/src/compilerlib/pb_typing_validation.mli @@ -50,4 +50,4 @@ val validate_message : (* file options *) Tt.type_scope -> Pt.message -> - Pb_field_type.unresolved Tt.proto + Pb_field_type.unresolved Tt.proto_type list diff --git a/src/compilerlib/pb_util.ml b/src/compilerlib/pb_util.ml index e45840e3..2163d46b 100644 --- a/src/compilerlib/pb_util.ml +++ b/src/compilerlib/pb_util.ml @@ -127,6 +127,15 @@ module List = struct (match f hd with | None -> filter_map f tl | Some x -> x :: filter_map f tl) + + let find_opt f l = try Some (List.find f l) with Not_found -> None + + let rec find_map f = function + | [] -> None + | x :: tl -> + (match f x with + | Some _ as r -> r + | None -> find_map f tl) end module Int_map = Map.Make (struct diff --git a/src/compilerlib/pb_util.mli b/src/compilerlib/pb_util.mli index ac88b694..4320c7a7 100644 --- a/src/compilerlib/pb_util.mli +++ b/src/compilerlib/pb_util.mli @@ -81,6 +81,9 @@ module List : sig (** [filter_map f l] returns the list of element [x] for which [f] returned [Some x]. The length of the returned list will be less or equal than the length of the input list [l]. *) + + val find_opt : ('a -> bool) -> 'a list -> 'a option + val find_map : ('a -> 'b option) -> 'a list -> 'b option end module Str_map : Map.S with type key = string diff --git a/src/ocaml-protoc/ocaml_protoc_compilation.ml b/src/ocaml-protoc/ocaml_protoc_compilation.ml index d34c339f..1fc2cdee 100644 --- a/src/ocaml-protoc/ocaml_protoc_compilation.ml +++ b/src/ocaml-protoc/ocaml_protoc_compilation.ml @@ -87,17 +87,20 @@ let compile cmdline cmd_line_files_options = in (* typing *) - let grouped_protos = Pb_typing.perform_typing protos in - let all_typed_protos = List.flatten grouped_protos in + let typed_proto = Pb_typing.perform_typing protos in - (* Only get the types which are part of the given proto file - * (compilation unit) *) - let grouped_proto = - List.filter - (function - | { Tt.file_name; _ } :: _ when file_name = proto_file_name -> true - | _ -> false) - grouped_protos + (* Only get the types which are part of the given proto file + (compilation unit) *) + let typed_proto = + { + typed_proto with + Tt.proto_types = + List.filter + (function + | { Tt.file_name; _ } :: _ when file_name = proto_file_name -> true + | _ -> false) + typed_proto.proto_types; + } in (* -- OCaml Backend -- *) @@ -110,11 +113,11 @@ let compile cmdline cmd_line_files_options = List.flatten @@ List.map (fun t -> - BO.compile ~unsigned_tag:!unsigned_tag all_typed_protos t) + BO.compile ~unsigned_tag:!unsigned_tag typed_proto t) types in l :: ocaml_types) - [] grouped_proto + [] typed_proto.proto_types in ocaml_types, proto_file_options From c4458584ee7546b18596ec5bc10dc6a489077df8 Mon Sep 17 00:00:00 2001 From: Simon Cruanes Date: Sat, 9 Sep 2023 16:04:12 -0400 Subject: [PATCH 02/46] details --- src/runtime/pbrt.ml | 2 -- src/runtime/pbrt.mli | 4 ++-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/src/runtime/pbrt.ml b/src/runtime/pbrt.ml index 73f1d569..1b7d7cdd 100644 --- a/src/runtime/pbrt.ml +++ b/src/runtime/pbrt.ml @@ -756,7 +756,6 @@ module Repeated_field = struct let to_list t = map_to_list identity t end -(* Repeated_field*) module Pp = struct module F = Format @@ -816,4 +815,3 @@ module Pp = struct let pp_brk pp_record (fmt : F.formatter) r : unit = F.fprintf fmt "@[{ %a@;<1 -2>@]}" pp_record r end -(* Pp *) diff --git a/src/runtime/pbrt.mli b/src/runtime/pbrt.mli index 2425f514..a6197b90 100644 --- a/src/runtime/pbrt.mli +++ b/src/runtime/pbrt.mli @@ -362,10 +362,11 @@ module Encoder : sig val wrapper_bytes_value : bytes option -> t -> unit end +(** Optimized representation for repeated fields *) module Repeated_field : sig type 'a t (** optimized data structure for fast inserts so that decoding - can be efficient + of repeated fields can be efficient. Type can be constructed at no cost from an existing array. *) @@ -510,4 +511,3 @@ module Pp : sig val pp_brk : (formatter -> 'a -> unit) -> formatter -> 'a -> unit (** [pp_brk fmt r] formats record value [r] with curly brakets. *) end -(* Pp *) From 8a301d82de4dc907c9674eec1e885f9d70fbeeb5 Mon Sep 17 00:00:00 2001 From: Simon Cruanes Date: Sat, 9 Sep 2023 22:16:39 -0400 Subject: [PATCH 03/46] wip: code generation for services --- src/compilerlib/dune | 2 +- src/compilerlib/pb_codegen_all.ml | 41 +++++- src/compilerlib/pb_codegen_all.mli | 2 +- src/compilerlib/pb_codegen_backend.ml | 136 ++++++++++++------- src/compilerlib/pb_codegen_backend.mli | 6 +- src/compilerlib/pb_codegen_ocaml_type.ml | 30 +++- src/compilerlib/pb_codegen_types.mli | 2 +- src/compilerlib/pb_codegen_util.ml | 8 ++ src/compilerlib/pb_codegen_util.mli | 9 ++ src/compilerlib/pb_typing_type_tree.ml | 1 + src/compilerlib/pb_typing_validation.ml | 11 +- src/ocaml-protoc/ocaml_protoc.ml | 4 +- src/ocaml-protoc/ocaml_protoc_compilation.ml | 20 +-- src/ocaml-protoc/ocaml_protoc_generation.ml | 10 +- src/ocaml-protoc/ocaml_protoc_generation.mli | 2 +- src/runtime/pbrt.ml | 24 ++++ src/runtime/pbrt.mli | 30 ++++ 17 files changed, 245 insertions(+), 93 deletions(-) diff --git a/src/compilerlib/dune b/src/compilerlib/dune index 3064fcb5..7d2edf59 100644 --- a/src/compilerlib/dune +++ b/src/compilerlib/dune @@ -10,7 +10,7 @@ (modules pb_codegen_all pb_codegen_backend pb_codegen_decode_binary pb_codegen_decode_bs pb_codegen_decode_yojson pb_codegen_default pb_codegen_encode_binary pb_codegen_encode_bs pb_codegen_encode_yojson pb_codegen_formatting - pb_codegen_ocaml_type pb_codegen_pp pb_codegen_plugin pb_codegen_types + pb_codegen_ocaml_type pb_codegen_pp pb_codegen_plugin pb_codegen_types pb_codegen_services pb_codegen_util pb_exception pb_field_type pb_location pb_logger pb_option pb_parsing pb_parsing_lexer pb_parsing_parser pb_parsing_parse_tree pb_parsing_util pb_typing_graph pb_typing pb_typing_recursion diff --git a/src/compilerlib/pb_codegen_all.ml b/src/compilerlib/pb_codegen_all.ml index 2f2f5f09..b23df735 100644 --- a/src/compilerlib/pb_codegen_all.ml +++ b/src/compilerlib/pb_codegen_all.ml @@ -1,8 +1,8 @@ (* The MIT License (MIT) - + Copyright (c) 2016 Maxime Ransan - + Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights @@ -29,6 +29,7 @@ module F = Pb_codegen_formatting module Plugin = Pb_codegen_plugin type codegen_f = Plugin.codegen_f +type codegen_service_f = Ot.service -> F.scope -> unit type ocaml_mod = { ml: F.scope; @@ -88,10 +89,20 @@ let generate_for_all_types (ocaml_types : Ot.type_ list list) (sc : F.scope) ()) ocaml_types +let generate_for_all_services (services : Ot.service list) (sc : F.scope) + (f : codegen_service_f) ocamldoc_title : unit = + (match ocamldoc_title with + | None -> () + | Some ocamldoc_title -> + F.empty_line sc; + F.linep sc "(** {2 %s} *)" ocamldoc_title; + F.empty_line sc); + + List.iter (fun (service : Ot.service) -> f service sc) services + let generate_type_and_default (self : ocaml_mod) ocaml_types : unit = generate_for_all_types ocaml_types self.ml Pb_codegen_types.gen_struct None; generate_for_all_types ocaml_types self.ml Pb_codegen_default.gen_struct None; - generate_for_all_types ocaml_types self.mli Pb_codegen_types.gen_sig (Some Pb_codegen_types.ocamldoc_title); generate_for_all_types ocaml_types self.mli Pb_codegen_default.gen_sig @@ -111,6 +122,19 @@ let generate_mutable_records (self : ocaml_mod) ocaml_types : unit = | _ -> ()) ocaml_types +let generate_service_struct service sc : unit = + Pb_codegen_services.gen_service_client_struct service sc; + Pb_codegen_services.gen_service_server_struct service sc + +let generate_service_sig service sc : unit = + Pb_codegen_services.gen_service_client_sig service sc; + Pb_codegen_services.gen_service_server_sig service sc + +let generate_services (self : ocaml_mod) services : unit = + generate_for_all_services services self.ml generate_service_struct None; + generate_for_all_services services self.mli generate_service_sig + (Some "Services") + let generate_plugin (self : ocaml_mod) ocaml_types (p : Plugin.t) : unit = let (module P) = p in F.line self.ml "[@@@ocaml.warning \"-27-30-39\"]"; @@ -119,11 +143,14 @@ let generate_plugin (self : ocaml_mod) ocaml_types (p : Plugin.t) : unit = generate_for_all_types ocaml_types self.mli P.gen_sig (Some P.ocamldoc_title); () -let codegen ocaml_types ~proto_file_options ~proto_file_name +let codegen (proto : Ot.proto) ~proto_file_options ~proto_file_name (plugins : Plugin.t list) : ocaml_mod = let self = new_ocaml_mod ~proto_file_options ~proto_file_name () in - generate_type_and_default self ocaml_types; + generate_type_and_default self proto.proto_types; if List.exists Pb_codegen_plugin.requires_mutable_records plugins then - generate_mutable_records self ocaml_types; - List.iter (generate_plugin self ocaml_types) plugins; + generate_mutable_records self proto.proto_types; + List.iter (generate_plugin self proto.proto_types) plugins; + + (* services come last, they need binary and json *) + generate_services self proto.proto_services; self diff --git a/src/compilerlib/pb_codegen_all.mli b/src/compilerlib/pb_codegen_all.mli index 6394720d..c8cf69ed 100644 --- a/src/compilerlib/pb_codegen_all.mli +++ b/src/compilerlib/pb_codegen_all.mli @@ -10,7 +10,7 @@ type ocaml_mod = { } val codegen : - Ot.type_ list list -> + Ot.proto -> proto_file_options:Pb_option.set -> proto_file_name:string -> Plugin.t list -> diff --git a/src/compilerlib/pb_codegen_backend.ml b/src/compilerlib/pb_codegen_backend.ml index 10243c86..3c6e5a29 100644 --- a/src/compilerlib/pb_codegen_backend.ml +++ b/src/compilerlib/pb_codegen_backend.ml @@ -91,7 +91,7 @@ let module_prefix_of_file_name file_name = | dot_index -> module_name @@ String.sub file_name 0 dot_index | exception Not_found -> E.invalid_file_name file_name -let type_name message_scope name = +let type_name message_scope name : string = let module S = String in let all_names = message_scope @ [ name ] in let all_names = @@ -132,7 +132,7 @@ let wrapper_type_of_type_name = function with the module name. (This is essentially expecting (rightly) a sub module with the same name. *) -let user_defined_type_of_id all_types file_name i = +let user_defined_type_of_id all_types file_name i : Ot.field_type = let module_prefix = module_prefix_of_file_name file_name in match Typing_util.type_of_id all_types i with | exception Not_found -> E.programmatic_error E.No_type_found_for_id @@ -169,7 +169,7 @@ let user_defined_type_of_id all_types file_name i = ) ) -let encoding_info_of_field_type all_types field_type = +let encoding_info_of_field_type all_types field_type : Ot.payload_kind = match field_type with | `Double -> Ot.Pk_bits64 | `Float -> Ot.Pk_bits32 @@ -206,7 +206,7 @@ let encoding_of_field all_types (field : (Pb_field_type.resolved, 'a) Tt.field) pk, Typing_util.field_number field, packed, Typing_util.field_default field let compile_field_type ~unsigned_tag all_types file_options field_options - file_name field_type = + file_name field_type : Ot.field_type = let ocaml_type = match Pb_option.get field_options "ocaml_type" with | Some Pb_option.(Scalar_value (Constant_literal "int_t")) -> `Int_t @@ -301,7 +301,7 @@ let ocaml_container field_options = | Some _ -> None let variant_of_oneof ?include_oneof_name ~outer_message_names ~unsigned_tag - all_types file_options file_name oneof_field = + all_types file_options file_name oneof_field : Ot.variant = let v_constructors = List.map (fun field -> @@ -339,30 +339,30 @@ let variant_of_oneof ?include_oneof_name ~outer_message_names ~unsigned_tag Ot.{ v_name; v_constructors } (* - * Notes on type level PPX extension handling. - * - * ocaml-protoc supports 2 custom options for defining type level ppx - * extensions: - * a) message option called ocaml_type_ppx - * b) file option called ocaml_all_types_ppx - * - * 'ocaml_type_ppx' has priority over 'ocaml_all_types_ppx' extension. - * This means that if a message contains 'ocaml_type_ppx' extension then the - * associated string will be used for the OCaml generated type ppx extension. - * - * 'ocaml_all_types_ppx' is a file option which is a convenient workflow when - * the ppx extensions are the same for all types. (Most likely the case). - * + Notes on type level PPX extension handling. + + ocaml-protoc supports 2 custom options for defining type level ppx + extensions: + a) message option called ocaml_type_ppx + b) file option called ocaml_all_types_ppx + + 'ocaml_type_ppx' has priority over 'ocaml_all_types_ppx' extension. + This means that if a message contains 'ocaml_type_ppx' extension then the + associated string will be used for the OCaml generated type ppx extension. + + 'ocaml_all_types_ppx' is a file option which is a convenient workflow when + the ppx extensions are the same for all types. (Most likely the case). + *) -(* utility function to return the string value from a sring option +(** utility function to return the string value from a sring option *) let string_of_string_option message_name = function | None -> None | Some Pb_option.(Scalar_value (Constant_string s)) -> Some s | _ -> E.invalid_ppx_extension_option message_name -(* utility function to implement the priority logic defined in the notes above. +(** utility function to implement the priority logic defined in the notes above. *) let process_all_types_ppx_extension file_name file_options type_level_ppx_extension = @@ -373,7 +373,7 @@ let process_all_types_ppx_extension file_name file_options |> string_of_string_option file_name let compile_message ~(unsigned_tag : bool) (file_options : Pb_option.set) - (all_types : Pb_field_type.resolved Tt.proto) (file_name : string) + (proto : Pb_field_type.resolved Tt.proto) (file_name : string) (scope : Tt.type_scope) (message : Pb_field_type.resolved Tt.message) : Ot.type_ list = let module_prefix = module_prefix_of_file_name file_name in @@ -408,7 +408,7 @@ let compile_message ~(unsigned_tag : bool) (file_options : Pb_option.set) | Tt.Message_oneof_field f :: [] -> let outer_message_names = message_names @ [ message_name ] in let variant = - variant_of_oneof ~unsigned_tag ~outer_message_names all_types file_options + variant_of_oneof ~unsigned_tag ~outer_message_names proto file_options file_name f in [ Ot.{ module_prefix; spec = Variant variant; type_level_ppx_extension } ] @@ -418,7 +418,7 @@ let compile_message ~(unsigned_tag : bool) (file_options : Pb_option.set) (fun (variants, fields) -> function | Tt.Message_field field -> let pk, encoding_number, packed, _ = - encoding_of_field all_types field + encoding_of_field proto field in let field_name = Typing_util.field_name field in @@ -428,8 +428,8 @@ let compile_message ~(unsigned_tag : bool) (file_options : Pb_option.set) let field_type = Typing_util.field_type field in let ocaml_field_type = - compile_field_type ~unsigned_tag all_types file_options - field_options file_name field_type + compile_field_type ~unsigned_tag proto file_options field_options + file_name field_type in let field_default = Typing_util.field_default field in @@ -440,17 +440,17 @@ let compile_message ~(unsigned_tag : bool) (file_options : Pb_option.set) match Typing_util.field_label field with | `Nolabel -> (* From proto3 section on default value: - * https://goo.gl/NKt9Cc - * - * -- - * For message fields, the field is not set. Its exact value is - * language-dependent. See the generated code guide for details. - * -- - * - * Since we must support the face that the message won't be sent - * we always make such a field an OCaml option. It's the - * responsability of the application to check for [None] and - * perform any error handling if required. *) + https://goo.gl/NKt9Cc + + -- + For message fields, the field is not set. Its exact value is + language-dependent. See the generated code guide for details. + -- + + Since we must support the face that the message won't be sent + we always make such a field an OCaml option. It's the + responsability of the application to check for [None] and + perform any error handling if required. *) let is_message = match ocaml_field_type with | Ot.Ft_user_defined_type { Ot.udt_type = `Message; _ } -> @@ -493,7 +493,7 @@ let compile_message ~(unsigned_tag : bool) (file_options : Pb_option.set) let outer_message_names = message_names @ [ message_name ] in let variant = variant_of_oneof ~unsigned_tag ~include_oneof_name:() - ~outer_message_names all_types file_options file_name field + ~outer_message_names proto file_options file_name field in let record_field = @@ -538,11 +538,11 @@ let compile_message ~(unsigned_tag : bool) (file_options : Pb_option.set) in let key_type = - compile_field_type ~unsigned_tag all_types file_options - map_options file_name map_key_type + compile_field_type ~unsigned_tag proto file_options map_options + file_name map_key_type in - let key_pk = encoding_info_of_field_type all_types map_key_type in + let key_pk = encoding_info_of_field_type proto map_key_type in let key_type = match key_type with @@ -551,13 +551,11 @@ let compile_message ~(unsigned_tag : bool) (file_options : Pb_option.set) in let value_type = - compile_field_type ~unsigned_tag all_types file_options - map_options file_name map_value_type + compile_field_type ~unsigned_tag proto file_options map_options + file_name map_value_type in - let value_pk = - encoding_info_of_field_type all_types map_value_type - in + let value_pk = encoding_info_of_field_type proto map_value_type in let associative_type = match ocaml_container map_options with @@ -636,12 +634,54 @@ let compile_enum file_options file_name scope enum = type_level_ppx_extension; } -let compile ~unsigned_tag all_types = function +let compile_rpc ~unsigned_tag ~(file_name : string) proto + (rpc : Pb_field_type.resolved_t Tt.rpc) : Ot.rpc = + let compile_ty ~stream ty : Ot.rpc_type = + let ty : Ot.field_type = + compile_field_type ~unsigned_tag proto Pb_option.empty rpc.rpc_options + file_name ty + in + if stream then + Ot.Rpc_stream ty + else + Ot.Rpc_scalar ty + in + + { + Ot.rpc_name = rpc.rpc_name; + rpc_req = compile_ty ~stream:rpc.rpc_req_stream rpc.rpc_req; + rpc_res = compile_ty ~stream:rpc.rpc_res_stream rpc.rpc_res; + } + +let compile_service ~(unsigned_tag : bool) proto + (service : Pb_field_type.resolved Tt.service) : Ot.service = + { + Ot.service_name = service.service_name; + service_packages = service.service_packages; + service_body = + List.map + (compile_rpc ~unsigned_tag ~file_name:service.service_file_name proto) + service.service_body; + } + +let compile1 ~unsigned_tag proto : _ -> Ot.type_ list = function | { Tt.spec = Tt.Message m; file_name; file_options; scope; _ } -> - compile_message ~unsigned_tag file_options all_types file_name scope m + compile_message ~unsigned_tag file_options proto file_name scope m | { Tt.spec = Tt.Enum e; file_name; scope; file_options; _ } -> [ compile_enum file_options file_name scope e ] +let compile ~unsigned_tag (proto : Pb_field_type.resolved Tt.proto) : Ot.proto = + let tys = + List.map + (fun t -> List.flatten @@ List.map (compile1 ~unsigned_tag proto) t) + proto.proto_types + in + + let services = + List.map (compile_service ~unsigned_tag proto) proto.proto_services + in + { Ot.proto_services = services; proto_types = tys } + module Internal = struct let is_mutable = is_mutable let constructor_name = constructor_name diff --git a/src/compilerlib/pb_codegen_backend.mli b/src/compilerlib/pb_codegen_backend.mli index fc0e1bff..ed65dfac 100644 --- a/src/compilerlib/pb_codegen_backend.mli +++ b/src/compilerlib/pb_codegen_backend.mli @@ -36,11 +36,7 @@ module Ot = Pb_codegen_ocaml_type (** {2 Compilation } *) -val compile : - unsigned_tag:bool -> - Pb_field_type.resolved Tt.proto -> - Pb_field_type.resolved Tt.proto_type -> - Ot.type_ list +val compile : unsigned_tag:bool -> Pb_field_type.resolved Tt.proto -> Ot.proto (** Internal helpers. diff --git a/src/compilerlib/pb_codegen_ocaml_type.ml b/src/compilerlib/pb_codegen_ocaml_type.ml index de7edecd..0730eae9 100644 --- a/src/compilerlib/pb_codegen_ocaml_type.ml +++ b/src/compilerlib/pb_codegen_ocaml_type.ml @@ -65,10 +65,10 @@ type field_type = | Ft_unit | Ft_basic_type of basic_type | Ft_user_defined_type of user_defined_type - (* New wrapper type which indicates that the corresponding ocaml + | Ft_wrapper_type of wrapper_type + (** New wrapper type which indicates that the corresponding ocaml Type should be an `option` along with the fact that it is encoded with special rules *) - | Ft_wrapper_type of wrapper_type type default_value = Pb_option.constant option @@ -155,3 +155,29 @@ type type_ = { spec: type_spec; type_level_ppx_extension: string option; } + +(** RPC argument or return type *) +type rpc_type = + | Rpc_scalar of field_type + | Rpc_stream of field_type + +type rpc = { + rpc_name: string; + rpc_req: rpc_type; + rpc_res: rpc_type; +} +(** A RPC specification, ie the signature for one remote procedure. *) + +type service = { + service_name: string; + service_packages: string list; (** Package in which this belongs *) + service_body: rpc list; +} +(** A service, composed of multiple RPCs. *) + +type proto = { + proto_types: type_ list list; + (** List of strongly connected type definitions *) + proto_services: service list; +} +(** A proto file is composed of a list of types and a list of services. *) diff --git a/src/compilerlib/pb_codegen_types.mli b/src/compilerlib/pb_codegen_types.mli index f1912016..5c748b11 100644 --- a/src/compilerlib/pb_codegen_types.mli +++ b/src/compilerlib/pb_codegen_types.mli @@ -1,4 +1,4 @@ -(** Code generator for the OCaml type *) +(** Code generator for the OCaml types *) include Pb_codegen_plugin.S diff --git a/src/compilerlib/pb_codegen_util.ml b/src/compilerlib/pb_codegen_util.ml index d11a3180..79c87521 100644 --- a/src/compilerlib/pb_codegen_util.ml +++ b/src/compilerlib/pb_codegen_util.ml @@ -84,6 +84,14 @@ let function_name_of_user_defined ~function_prefix = function | { Ot.udt_module_prefix = None; Ot.udt_type_name; _ } -> sp "%s_%s" function_prefix udt_type_name +let module_type_name_of_service_client (service : Ot.service) : string = + String.uppercase_ascii service.service_name ^ "_CLIENT" + +let module_type_name_of_service_server (service : Ot.service) : string = + String.uppercase_ascii service.service_name ^ "_SERVER" + +let function_name_of_rpc (rpc : Ot.rpc) = String.uncapitalize_ascii rpc.rpc_name + let caml_file_name_of_proto_file_name ~proto_file_name = let splitted = Pb_util.rev_split_by_char '.' proto_file_name in if List.length splitted < 2 || List.hd splitted <> "proto" then diff --git a/src/compilerlib/pb_codegen_util.mli b/src/compilerlib/pb_codegen_util.mli index fa52b10b..97fcaf9a 100644 --- a/src/compilerlib/pb_codegen_util.mli +++ b/src/compilerlib/pb_codegen_util.mli @@ -30,6 +30,15 @@ val function_name_of_user_defined : user defined field type. *) +val module_type_name_of_service_client : Pb_codegen_ocaml_type.service -> string +(** Name of the module type for this service (client) *) + +val module_type_name_of_service_server : Pb_codegen_ocaml_type.service -> string +(** Name of the module type for this service (server) *) + +val function_name_of_rpc : Pb_codegen_ocaml_type.rpc -> string +(** Name of the function for this RPC *) + val caml_file_name_of_proto_file_name : proto_file_name:string -> string (** [caml_file_name_of_proto_file_name filename] returns the OCaml file name from the protobuf file name diff --git a/src/compilerlib/pb_typing_type_tree.ml b/src/compilerlib/pb_typing_type_tree.ml index 580a7778..c2d2cb7f 100644 --- a/src/compilerlib/pb_typing_type_tree.ml +++ b/src/compilerlib/pb_typing_type_tree.ml @@ -120,6 +120,7 @@ type 'a rpc = { type 'a service = { service_name: string; + service_file_name: string; service_packages: string list; (** Package in which this belongs *) service_body: 'a Pb_field_type.t rpc list; } diff --git a/src/compilerlib/pb_typing_validation.ml b/src/compilerlib/pb_typing_validation.ml index 206b98cb..29de7d0b 100644 --- a/src/compilerlib/pb_typing_validation.ml +++ b/src/compilerlib/pb_typing_validation.ml @@ -265,7 +265,7 @@ let rec validate_message ?(parent_options = Pb_option.empty) file_name acc.Acc.all_types @ [ make_proto_type ~file_name ~file_options ~id ~scope:message_scope ~spec ] -let validate_service (scope : Tt.type_scope) (service : Pt.service) : +let validate_service (scope : Tt.type_scope) ~file_name (service : Pt.service) : _ Tt.service = let { Pt.service_name; service_body } = service in let service_body = @@ -294,7 +294,12 @@ let validate_service (scope : Tt.type_scope) (service : Pt.service) : Some rpc) service_body in - { Tt.service_packages = scope.packages; service_name; service_body } + { + Tt.service_packages = scope.packages; + service_file_name = file_name; + service_name; + service_body; + } let validate (proto : Pt.proto) : _ Tt.proto = let { @@ -327,5 +332,5 @@ let validate (proto : Pt.proto) : _ Tt.proto = pbtt_msgs messages in - let proto_services = List.map (validate_service scope) services in + let proto_services = List.map (validate_service scope ~file_name) services in { Tt.proto_types; proto_services } diff --git a/src/ocaml-protoc/ocaml_protoc.ml b/src/ocaml-protoc/ocaml_protoc.ml index 09deb011..e89fb84a 100644 --- a/src/ocaml-protoc/ocaml_protoc.ml +++ b/src/ocaml-protoc/ocaml_protoc.ml @@ -46,11 +46,11 @@ let () = File_options.to_file_options cmdline.Cmdline.cmd_line_file_options in - let ocaml_types, proto_file_options = + let ocaml_proto, proto_file_options = Compilation.compile cmdline cmd_line_file_options in - Generation.generate_code ocaml_types ~proto_file_options cmdline + Generation.generate_code ocaml_proto ~proto_file_options cmdline with exn -> Printf.eprintf "%s\n" (Printexc.to_string exn); exit 1 diff --git a/src/ocaml-protoc/ocaml_protoc_compilation.ml b/src/ocaml-protoc/ocaml_protoc_compilation.ml index 1fc2cdee..4621bb5f 100644 --- a/src/ocaml-protoc/ocaml_protoc_compilation.ml +++ b/src/ocaml-protoc/ocaml_protoc_compilation.ml @@ -57,7 +57,7 @@ let find_imported_file include_dirs file_name = | Some file_name -> file_name ) -let compile cmdline cmd_line_files_options = +let compile cmdline cmd_line_files_options : Ot.proto * _ = let { Cmdline.include_dirs; proto_file_name; unsigned_tag; _ } = cmdline in (* parsing *) @@ -105,19 +105,5 @@ let compile cmdline cmd_line_files_options = (* -- OCaml Backend -- *) let module BO = Pb_codegen_backend in - let ocaml_types = - List.rev - @@ List.fold_left - (fun ocaml_types types -> - let l = - List.flatten - @@ List.map - (fun t -> - BO.compile ~unsigned_tag:!unsigned_tag typed_proto t) - types - in - l :: ocaml_types) - [] typed_proto.proto_types - in - - ocaml_types, proto_file_options + let ocaml_proto = BO.compile ~unsigned_tag:!unsigned_tag typed_proto in + ocaml_proto, proto_file_options diff --git a/src/ocaml-protoc/ocaml_protoc_generation.ml b/src/ocaml-protoc/ocaml_protoc_generation.ml index 55ba78b7..51a385f3 100644 --- a/src/ocaml-protoc/ocaml_protoc_generation.ml +++ b/src/ocaml-protoc/ocaml_protoc_generation.ml @@ -53,19 +53,19 @@ let open_files cmdline (f : ml:out_channel -> mli:out_channel -> 'a) : 'a = (fun () -> f ~ml:oc_ml ~mli:oc_mli) let generate_code ocaml_types ~proto_file_options cmdline : unit = - let plugins = + let plugins : Pb_codegen_plugin.t list = List.flatten [ - (if !(cmdline.Cmdline.yojson) then - [ Pb_codegen_encode_yojson.plugin; Pb_codegen_decode_yojson.plugin ] + (if !(cmdline.Cmdline.pp) then + [ Pb_codegen_pp.plugin ] else []); (if !(cmdline.Cmdline.binary) then [ Pb_codegen_encode_binary.plugin; Pb_codegen_decode_binary.plugin ] else []); - (if !(cmdline.Cmdline.pp) then - [ Pb_codegen_pp.plugin ] + (if !(cmdline.Cmdline.yojson) then + [ Pb_codegen_encode_yojson.plugin; Pb_codegen_decode_yojson.plugin ] else []); (if !(cmdline.Cmdline.bs) then diff --git a/src/ocaml-protoc/ocaml_protoc_generation.mli b/src/ocaml-protoc/ocaml_protoc_generation.mli index 1586458c..9a4295f0 100644 --- a/src/ocaml-protoc/ocaml_protoc_generation.mli +++ b/src/ocaml-protoc/ocaml_protoc_generation.mli @@ -4,4 +4,4 @@ module Ot = Pb_codegen_ocaml_type module Cmdline = Ocaml_protoc_cmdline.Cmdline val generate_code : - Ot.type_ list list -> proto_file_options:Pb_option.set -> Cmdline.t -> unit + Ot.proto -> proto_file_options:Pb_option.set -> Cmdline.t -> unit diff --git a/src/runtime/pbrt.ml b/src/runtime/pbrt.ml index 1b7d7cdd..4b1044b1 100644 --- a/src/runtime/pbrt.ml +++ b/src/runtime/pbrt.ml @@ -815,3 +815,27 @@ module Pp = struct let pp_brk pp_record (fmt : F.formatter) r : unit = F.fprintf fmt "@[{ %a@;<1 -2>@]}" pp_record r end + +(** Client end of services *) +module Client = struct + type transport = { query: 'ret. string -> on_result:(string -> 'ret) -> 'ret } + + type ('req, 'ret) rpc = { + call: + 'actual_ret. + transport -> 'req -> on_result:('ret -> 'actual_ret) -> 'actual_ret; + } +end + +(** Server end of services *) +module Server = struct + type rpc = { + rpc_name: string; + rpc_handler: [ `JSON | `BINARY ] -> string -> string; + } + + type t = { + name: string; + handlers: rpc list; + } +end diff --git a/src/runtime/pbrt.mli b/src/runtime/pbrt.mli index a6197b90..42e9b0d2 100644 --- a/src/runtime/pbrt.mli +++ b/src/runtime/pbrt.mli @@ -511,3 +511,33 @@ module Pp : sig val pp_brk : (formatter -> 'a -> unit) -> formatter -> 'a -> unit (** [pp_brk fmt r] formats record value [r] with curly brakets. *) end + +(** Service stubs, client side *) +module Client : sig + type transport = { query: 'ret. string -> on_result:(string -> 'ret) -> 'ret } + (** A transport method, ie. a way to query a remote service + by sending it a query, and register a callback to be + called when the response is received. *) + + type ('req, 'ret) rpc = { + call: + 'actual_ret. + transport -> 'req -> on_result:('ret -> 'actual_ret) -> 'actual_ret; + } + (** A RPC. By calling it with a concrete transport, one gets a future result. *) +end + +(** Service stubs, server side *) +module Server : sig + type rpc = { + rpc_name: string; + rpc_handler: [ `JSON | `BINARY ] -> string -> string; + } + (** A RPC implementation. *) + + type t = { + name: string; + handlers: rpc list; + } + (** A service with fixed set of methods. *) +end From 1921918f97a85b3a3d1f51ab569dc4687ed94757 Mon Sep 17 00:00:00 2001 From: Simon Cruanes Date: Sat, 9 Sep 2023 23:11:43 -0400 Subject: [PATCH 04/46] add pbrt_services runtime library --- dune-project | 8 ++ pbrt_services.opam | 31 +++++++ src/runtime-services/dune | 7 ++ src/runtime-services/pbrt_services.ml | 108 +++++++++++++++++++++++++ src/runtime-services/pbrt_services.mli | 83 +++++++++++++++++++ 5 files changed, 237 insertions(+) create mode 100644 pbrt_services.opam create mode 100644 src/runtime-services/dune create mode 100644 src/runtime-services/pbrt_services.ml create mode 100644 src/runtime-services/pbrt_services.mli diff --git a/dune-project b/dune-project index 057276d1..491bce85 100644 --- a/dune-project +++ b/dune-project @@ -36,4 +36,12 @@ base64) (tags (protobuf encode decode))) +(package + (name pbrt_services) + (synopsis "Runtime library for ocaml-protoc to support RPC services") + (depends + (ocaml (>= 4.03)) + (pbrt (= :version)) + (pbrt_yojson (= :version))) + (tags (protobuf encode decode services rpc))) diff --git a/pbrt_services.opam b/pbrt_services.opam new file mode 100644 index 00000000..4a92a90f --- /dev/null +++ b/pbrt_services.opam @@ -0,0 +1,31 @@ +# This file is generated by dune, edit dune-project instead +opam-version: "2.0" +version: "2.4" +synopsis: "Runtime library for ocaml-protoc to support RPC services" +maintainer: ["Maxime Ransan "] +authors: ["Maxime Ransan "] +license: "MIT" +tags: ["protobuf" "encode" "decode" "services" "rpc"] +homepage: "https://github.com/mransan/ocaml-protoc" +bug-reports: "https://github.com/mransan/ocaml-protoc/issues" +depends: [ + "dune" {>= "2.0"} + "ocaml" {>= "4.03"} + "pbrt" {= version} + "pbrt_yojson" {= version} +] +build: [ + ["dune" "subst"] {pinned} + [ + "dune" + "build" + "-p" + name + "-j" + jobs + "@install" + "@runtest" {with-test} + "@doc" {with-doc} + ] +] +dev-repo: "git+https://github.com/mransan/ocaml-protoc.git" diff --git a/src/runtime-services/dune b/src/runtime-services/dune new file mode 100644 index 00000000..77e192ed --- /dev/null +++ b/src/runtime-services/dune @@ -0,0 +1,7 @@ + +(library + (name pbrt_services) + (public_name pbrt_services) + (wrapped true) + (synopsis "Runtime library for services generated by ocaml-protoc") + (libraries pbrt pbrt_yojson yojson)) diff --git a/src/runtime-services/pbrt_services.ml b/src/runtime-services/pbrt_services.ml new file mode 100644 index 00000000..6cae53fe --- /dev/null +++ b/src/runtime-services/pbrt_services.ml @@ -0,0 +1,108 @@ +(** Client end of services *) +module Client = struct + type error = + | Transport_error of string + | Timeout + | Decode_error_json of string + | Decode_error_binary of Pbrt.Decoder.error + + type transport = { + query: + 'ret. + service_name:string -> + rpc_name:string -> + [ `JSON | `BINARY ] -> + string -> + on_result:((string, error) result -> 'ret) -> + 'ret; + } + + type ('req, 'ret) rpc = { + call: + 'actual_ret. + [ `JSON | `BINARY ] -> + transport -> + 'req -> + on_result:(('ret, error) result -> 'actual_ret) -> + 'actual_ret; + } + + let mk_rpc ~service_name ~rpc_name ~encode_json_req ~encode_pb_req + ~decode_json_res ~decode_pb_res () : _ rpc = + { + call = + (fun encoding (transport : transport) req ~on_result -> + let req_str = + match encoding with + | `JSON -> encode_json_req req |> Yojson.Safe.to_string + | `BINARY -> + let enc = Pbrt.Encoder.create () in + encode_pb_req req enc; + Pbrt.Encoder.to_string enc + in + + let on_result = function + | Error err -> on_result (Error err) + | Ok res_str -> + (match encoding with + | `JSON -> + (match decode_json_res @@ Yojson.Safe.from_string res_str with + | res -> on_result (Ok res) + | exception exn -> + on_result (Error (Decode_error_json (Printexc.to_string exn)))) + | `BINARY -> + let dec = Pbrt.Decoder.of_string res_str in + (match decode_pb_res dec with + | v -> on_result (Ok v) + | exception Pbrt.Decoder.Failure err -> + on_result (Error (Decode_error_binary err)))) + in + + transport.query ~service_name ~rpc_name encoding req_str ~on_result); + } +end + +(** Server end of services *) +module Server = struct + type error = + | Invalid_json + | Invalid_pb of Pbrt.Decoder.error + | Handler_failed of string + + type rpc = { + rpc_name: string; + rpc_handler: [ `JSON | `BINARY ] -> string -> (string, error) result; + } + + let mk_rpc ~name ~(f : 'req -> 'res) ~encode_json_res ~encode_pb_res + ~decode_json_req ~decode_pb_req () : rpc = + let handler fmt req : _ result = + match fmt with + | `JSON -> + (match Yojson.Safe.from_string req with + | exception _ -> Error Invalid_json + | j -> + let req = decode_json_req j in + (match f req with + | res -> Ok (encode_json_res res) + | exception exn -> Error (Handler_failed (Printexc.to_string exn)))) + | `BINARY -> + let decoder = Pbrt.Decoder.of_string req in + (match decode_pb_req decoder with + | exception Pbrt.Decoder.Failure e -> Error (Invalid_pb e) + | req -> + (match f req with + | res -> + let enc = Pbrt.Encoder.create () in + encode_pb_res res enc; + Ok (Pbrt.Encoder.to_string enc) + | exception exn -> Error (Handler_failed (Printexc.to_string exn)))) + in + + { rpc_name = name; rpc_handler = handler } + + type t = { + name: string; + handlers: rpc list; + } +end diff --git a/src/runtime-services/pbrt_services.mli b/src/runtime-services/pbrt_services.mli new file mode 100644 index 00000000..688c3e5a --- /dev/null +++ b/src/runtime-services/pbrt_services.mli @@ -0,0 +1,83 @@ +(** Service stubs, client side *) +module Client : sig + type error = + | Transport_error of string + | Timeout + | Decode_error_json of string + | Decode_error_binary of Pbrt.Decoder.error + + type transport = { + query: + 'ret. + service_name:string -> + rpc_name:string -> + [ `JSON | `BINARY ] -> + string -> + on_result:((string, error) result -> 'ret) -> + 'ret; + } + (** A transport method, ie. a way to query a remote service + by sending it a query, and register a callback to be + called when the response is received. + + The [query] function is called like so: + [transport.query ~service_name ~rpc_name encoding req ~on_result], + where [rpc_name] is the name of the method in the service [service_name], + [req] is the encoded query using [encoding], and [on_result] is a callback + that will be called after the service comes back with a response. + *) + + type ('req, 'ret) rpc = { + call: + 'actual_ret. + [ `JSON | `BINARY ] -> + transport -> + 'req -> + on_result:(('ret, error) result -> 'actual_ret) -> + 'actual_ret; + } + (** A RPC. By calling it with a concrete transport, one gets a future result. *) + + val mk_rpc : + service_name:string -> + rpc_name:string -> + encode_json_req:('req -> Yojson.Safe.t) -> + encode_pb_req:('req -> Pbrt.Encoder.t -> unit) -> + decode_json_res:(Yojson.Safe.t -> 'res) -> + decode_pb_res:(Pbrt.Decoder.t -> 'res) -> + unit -> + ('req, 'res) rpc +end + +(** Service stubs, server side *) +module Server : sig + (** Errors that can arise during request processing. *) + type error = + | Invalid_json + | Invalid_pb of Pbrt.Decoder.error + | Handler_failed of string + + type rpc = { + rpc_name: string; + rpc_handler: [ `JSON | `BINARY ] -> string -> (string, error) result; + } + + val mk_rpc : + name:string -> + f:('req -> 'res) -> + encode_json_res:('res -> string) -> + encode_pb_res:('res -> Pbrt.Encoder.t -> unit) -> + decode_json_req:(Yojson.Safe.t -> 'req) -> + decode_pb_req:(Pbrt.Decoder.t -> 'req) -> + unit -> + rpc + (** Helper to build a RPC *) + + (** A RPC implementation. *) + + type t = { + name: string; + handlers: rpc list; + } + (** A service with fixed set of methods. *) +end From 06acfe14a94d83f4bef2d531606457dd4121f63a Mon Sep 17 00:00:00 2001 From: Simon Cruanes Date: Sat, 9 Sep 2023 23:12:14 -0400 Subject: [PATCH 05/46] cleanup pbrt --- src/runtime-yojson/dune | 2 +- src/runtime/pbrt.ml | 24 ------------------------ src/runtime/pbrt.mli | 30 ------------------------------ 3 files changed, 1 insertion(+), 55 deletions(-) diff --git a/src/runtime-yojson/dune b/src/runtime-yojson/dune index 7f53410b..c7e629dc 100644 --- a/src/runtime-yojson/dune +++ b/src/runtime-yojson/dune @@ -2,4 +2,4 @@ (library (public_name pbrt_yojson) (wrapped false) - (libraries yojson base64)) + (libraries (re_export yojson) base64)) diff --git a/src/runtime/pbrt.ml b/src/runtime/pbrt.ml index 4b1044b1..1b7d7cdd 100644 --- a/src/runtime/pbrt.ml +++ b/src/runtime/pbrt.ml @@ -815,27 +815,3 @@ module Pp = struct let pp_brk pp_record (fmt : F.formatter) r : unit = F.fprintf fmt "@[{ %a@;<1 -2>@]}" pp_record r end - -(** Client end of services *) -module Client = struct - type transport = { query: 'ret. string -> on_result:(string -> 'ret) -> 'ret } - - type ('req, 'ret) rpc = { - call: - 'actual_ret. - transport -> 'req -> on_result:('ret -> 'actual_ret) -> 'actual_ret; - } -end - -(** Server end of services *) -module Server = struct - type rpc = { - rpc_name: string; - rpc_handler: [ `JSON | `BINARY ] -> string -> string; - } - - type t = { - name: string; - handlers: rpc list; - } -end diff --git a/src/runtime/pbrt.mli b/src/runtime/pbrt.mli index 42e9b0d2..a6197b90 100644 --- a/src/runtime/pbrt.mli +++ b/src/runtime/pbrt.mli @@ -511,33 +511,3 @@ module Pp : sig val pp_brk : (formatter -> 'a -> unit) -> formatter -> 'a -> unit (** [pp_brk fmt r] formats record value [r] with curly brakets. *) end - -(** Service stubs, client side *) -module Client : sig - type transport = { query: 'ret. string -> on_result:(string -> 'ret) -> 'ret } - (** A transport method, ie. a way to query a remote service - by sending it a query, and register a callback to be - called when the response is received. *) - - type ('req, 'ret) rpc = { - call: - 'actual_ret. - transport -> 'req -> on_result:('ret -> 'actual_ret) -> 'actual_ret; - } - (** A RPC. By calling it with a concrete transport, one gets a future result. *) -end - -(** Service stubs, server side *) -module Server : sig - type rpc = { - rpc_name: string; - rpc_handler: [ `JSON | `BINARY ] -> string -> string; - } - (** A RPC implementation. *) - - type t = { - name: string; - handlers: rpc list; - } - (** A service with fixed set of methods. *) -end From f7bc76a003481584fa50b84fb798302f6f6c6c3d Mon Sep 17 00:00:00 2001 From: Simon Cruanes Date: Sat, 9 Sep 2023 23:12:27 -0400 Subject: [PATCH 06/46] bulk of the codegen for services --- src/compilerlib/pb_codegen_services.ml | 119 +++++++++++++++++++++++ src/compilerlib/pb_codegen_services.mli | 11 +++ src/ocaml-protoc/ocaml_protoc_cmdline.ml | 10 ++ 3 files changed, 140 insertions(+) create mode 100644 src/compilerlib/pb_codegen_services.ml create mode 100644 src/compilerlib/pb_codegen_services.mli diff --git a/src/compilerlib/pb_codegen_services.ml b/src/compilerlib/pb_codegen_services.ml new file mode 100644 index 00000000..0fbfb659 --- /dev/null +++ b/src/compilerlib/pb_codegen_services.ml @@ -0,0 +1,119 @@ +module Ot = Pb_codegen_ocaml_type +module F = Pb_codegen_formatting + +let string_of_rpc_type (ty : Ot.rpc_type) : string = + let f = Pb_codegen_util.string_of_field_type in + match ty with + | Ot.Rpc_scalar ty -> f ty + | Ot.Rpc_stream ty -> Printf.sprintf "%s Seq.t" (f ty) + +let mod_name_for_client (service : Ot.service) : string = + String.capitalize_ascii service.service_name + +let gen_service_client_struct (service : Ot.service) sc : unit = + F.linep sc "module %s = struct" (mod_name_for_client service); + F.sub_scope sc (fun sc -> + F.linep sc "open Pbrt_services.Client"; + List.iter + (fun (rpc : Ot.rpc) -> + F.empty_line sc; + F.linep sc "let %s : (%s, %s) rpc =" + (Pb_codegen_util.function_name_of_rpc rpc) + (string_of_rpc_type rpc.rpc_req) + (string_of_rpc_type rpc.rpc_res); + F.linep sc " mk_rpc ~service_name:%S ~rpc_name:%S" + service.service_name rpc.rpc_name; + F.linep sc " ~encode_json_req:encode_json_%s" + (string_of_rpc_type rpc.rpc_req); + F.linep sc " ~encode_pb_req:encode_pb_%s" + (string_of_rpc_type rpc.rpc_req); + F.linep sc " ~decode_json_res:decode_json_%s" + (string_of_rpc_type rpc.rpc_req); + F.linep sc " ~decode_pb_res:decode_pb_%s" + (string_of_rpc_type rpc.rpc_req); + F.line sc "()") + service.service_body); + + F.line sc "end"; + F.empty_line sc + +let gen_service_client_sig (service : Ot.service) sc : unit = + F.linep sc "(** Client for %s *)" service.service_name; + F.linep sc "module %s : sig" (mod_name_for_client service); + F.sub_scope sc (fun sc -> + F.linep sc "open Pbrt_services.Client"; + List.iter + (fun (rpc : Ot.rpc) -> + F.empty_line sc; + F.linep sc "val %s : (%s, %s) rpc" + (Pb_codegen_util.function_name_of_rpc rpc) + (string_of_rpc_type rpc.rpc_req) + (string_of_rpc_type rpc.rpc_res)) + service.service_body); + F.line sc "end"; + F.empty_line sc + +(** generate the module type for the server (shared between .ml and .mli) *) +let gen_mod_type_of_service (service : Ot.service) sc : unit = + let mod_type_name = + Pb_codegen_util.module_type_name_of_service_server service + in + + F.linep sc "module type %s = sig" mod_type_name; + F.sub_scope sc (fun sc -> + List.iter + (fun (rpc : Ot.rpc) -> + F.linep sc "val %s : %s -> %s" + (Pb_codegen_util.function_name_of_rpc rpc) + (string_of_rpc_type rpc.rpc_req) + (string_of_rpc_type rpc.rpc_res)) + service.service_body); + F.line sc "end" + +let gen_service_server_struct (service : Ot.service) sc : unit = + let mod_type_name = + Pb_codegen_util.module_type_name_of_service_server service + in + + gen_mod_type_of_service service sc; + F.empty_line sc; + + (* now generate a function from the module type to a [Service_server.t] *) + F.linep sc "let service_impl_of_%s (module M:%s) : Pbrt_services.Server.t =" + (String.lowercase_ascii service.service_name) + mod_type_name; + F.sub_scope sc (fun sc -> + F.line sc "let open Pbrt_services.Server in"; + F.linep sc "{ name=%S;" service.service_name; + F.line sc " handlers=["; + List.iter + (fun (rpc : Ot.rpc) -> + F.linep sc " mk_rpc ~name:%S ~f:M.%s" rpc.rpc_name + (Pb_codegen_util.function_name_of_rpc rpc); + F.linep sc " ~encode_json_res:encode_json_%s" + (string_of_rpc_type rpc.rpc_res); + F.linep sc " ~encode_pb_res:encode_pb_%s" + (string_of_rpc_type rpc.rpc_res); + F.linep sc " ~decode_json_req:decode_json_%s" + (string_of_rpc_type rpc.rpc_req); + F.linep sc " ~decode_pb_req:decode_pb_%s" + (string_of_rpc_type rpc.rpc_req); + F.linep sc " ();") + service.service_body; + F.line sc "]; }"); + F.empty_line sc + +let gen_service_server_sig service sc : unit = + let mod_type_name = + Pb_codegen_util.module_type_name_of_service_server service + in + + F.linep sc "(** Server interface for %s *)" service.service_name; + gen_mod_type_of_service service sc; + F.empty_line sc; + + F.linep sc "(** Convert {!%s} to a generic runtime service *)" mod_type_name; + F.linep sc "val service_impl_of_%s : (module %s) -> Pbrt.Server.t" + (String.lowercase_ascii service.service_name) + mod_type_name; + () diff --git a/src/compilerlib/pb_codegen_services.mli b/src/compilerlib/pb_codegen_services.mli new file mode 100644 index 00000000..52cd9a1c --- /dev/null +++ b/src/compilerlib/pb_codegen_services.mli @@ -0,0 +1,11 @@ +val gen_service_client_sig : + Pb_codegen_ocaml_type.service -> Pb_codegen_formatting.scope -> unit + +val gen_service_client_struct : + Pb_codegen_ocaml_type.service -> Pb_codegen_formatting.scope -> unit + +val gen_service_server_sig : + Pb_codegen_ocaml_type.service -> Pb_codegen_formatting.scope -> unit + +val gen_service_server_struct : + Pb_codegen_ocaml_type.service -> Pb_codegen_formatting.scope -> unit diff --git a/src/ocaml-protoc/ocaml_protoc_cmdline.ml b/src/ocaml-protoc/ocaml_protoc_cmdline.ml index 4002a6eb..b68e1086 100644 --- a/src/ocaml-protoc/ocaml_protoc_cmdline.ml +++ b/src/ocaml-protoc/ocaml_protoc_cmdline.ml @@ -110,6 +110,7 @@ module Cmdline = struct yojson: bool ref; (** whether yojson encoding is enabled *) bs: bool ref; (** whether BuckleScript encoding is enabled *) pp: bool ref; (** whether pretty printing is enabled *) + services: bool ref; (** whether services code generation is enabled *) mutable cmd_line_file_options: File_options.t; (** file options override from the cmd line *) unsigned_tag: bool ref; @@ -127,6 +128,7 @@ module Cmdline = struct yojson = ref false; bs = ref false; pp = ref false; + services = ref false; cmd_line_file_options = File_options.make (); unsigned_tag = ref false; } @@ -137,6 +139,9 @@ module Cmdline = struct "--bs", Arg.Set t.bs, " generate BuckleScript encoding"; "--binary", Arg.Set t.binary, " generate binary encoding"; "--pp", Arg.Set t.pp, " generate pretty print functions"; + ( "--services", + Arg.Set t.services, + " generate code for services (requires json+binary)" ); ( "-I", Arg.String (fun s -> t.include_dirs <- s :: t.include_dirs), " include directories" ); @@ -160,6 +165,11 @@ module Cmdline = struct t.pp := true ); + if !(t.services) then ( + t.binary := true; + t.yojson := true + ); + if t.proto_file_name = "" then failwith "Missing proto file name from command line argument"; From 13a969929b3ec30f7b6656b2245477ddbb45ef2d Mon Sep 17 00:00:00 2001 From: Simon Cruanes Date: Sun, 10 Sep 2023 00:30:01 -0400 Subject: [PATCH 07/46] fix --- src/tests/unit-tests/test_typing.ml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/tests/unit-tests/test_typing.ml b/src/tests/unit-tests/test_typing.ml index c963f2d5..6fa41497 100644 --- a/src/tests/unit-tests/test_typing.ml +++ b/src/tests/unit-tests/test_typing.ml @@ -17,7 +17,8 @@ let () = let t = List.fold_left (fun t type_ -> Pb_typing_resolution.Types_by_scope.add t type_) - Pb_typing_resolution.Types_by_scope.empty proto + Pb_typing_resolution.Types_by_scope.empty + (List.flatten proto.proto_types) in assert (is_found t [ "foo"; "bar" ] "M1"); assert (is_found t [ "foo"; "bar" ] "M2"); From c699e1df5eeb56d4c91ecf6fc803285053f2cc7d Mon Sep 17 00:00:00 2001 From: Simon Cruanes Date: Sun, 10 Sep 2023 00:55:14 -0400 Subject: [PATCH 08/46] compilerlib: fix issue related to type lookup in imports --- src/compilerlib/pb_codegen_backend.ml | 74 +++++++++++--------- src/compilerlib/pb_codegen_backend.mli | 8 ++- src/compilerlib/pb_typing_util.ml | 10 +-- src/compilerlib/pb_typing_util.mli | 2 +- src/ocaml-protoc/ocaml_protoc_compilation.ml | 6 +- 5 files changed, 53 insertions(+), 47 deletions(-) diff --git a/src/compilerlib/pb_codegen_backend.ml b/src/compilerlib/pb_codegen_backend.ml index 3c6e5a29..3ffa241c 100644 --- a/src/compilerlib/pb_codegen_backend.ml +++ b/src/compilerlib/pb_codegen_backend.ml @@ -132,7 +132,7 @@ let wrapper_type_of_type_name = function with the module name. (This is essentially expecting (rightly) a sub module with the same name. *) -let user_defined_type_of_id all_types file_name i : Ot.field_type = +let user_defined_type_of_id ~(all_types : _ list) file_name i : Ot.field_type = let module_prefix = module_prefix_of_file_name file_name in match Typing_util.type_of_id all_types i with | exception Not_found -> E.programmatic_error E.No_type_found_for_id @@ -169,7 +169,7 @@ let user_defined_type_of_id all_types file_name i : Ot.field_type = ) ) -let encoding_info_of_field_type all_types field_type : Ot.payload_kind = +let encoding_info_of_field_type ~all_types field_type : Ot.payload_kind = match field_type with | `Double -> Ot.Pk_bits64 | `Float -> Ot.Pk_bits32 @@ -191,7 +191,7 @@ let encoding_info_of_field_type all_types field_type : Ot.payload_kind = | { Tt.spec = Tt.Enum _; _ } -> Ot.Pk_varint false | { Tt.spec = Tt.Message _; _ } -> Ot.Pk_bytes) -let encoding_of_field all_types (field : (Pb_field_type.resolved, 'a) Tt.field) +let encoding_of_field ~all_types (field : (Pb_field_type.resolved, 'a) Tt.field) = let packed = match Typing_util.field_option field "packed" with @@ -201,12 +201,12 @@ let encoding_of_field all_types (field : (Pb_field_type.resolved, 'a) Tt.field) in let pk = - encoding_info_of_field_type all_types (Typing_util.field_type field) + encoding_info_of_field_type ~all_types (Typing_util.field_type field) in pk, Typing_util.field_number field, packed, Typing_util.field_default field -let compile_field_type ~unsigned_tag all_types file_options field_options - file_name field_type : Ot.field_type = +let compile_field_type ~unsigned_tag ~(all_types : _ Tt.proto_type list) + file_options field_options file_name field_type : Ot.field_type = let ocaml_type = match Pb_option.get field_options "ocaml_type" with | Some Pb_option.(Scalar_value (Constant_literal "int_t")) -> `Int_t @@ -285,7 +285,7 @@ let compile_field_type ~unsigned_tag all_types file_options field_options | `Bool, _ -> Ot.(Ft_basic_type Bt_bool) | `String, _ -> Ot.(Ft_basic_type Bt_string) | `Bytes, _ -> Ot.(Ft_basic_type Bt_bytes) - | `User_defined id, _ -> user_defined_type_of_id all_types file_name id + | `User_defined id, _ -> user_defined_type_of_id ~all_types file_name id let is_mutable ?field_name field_options = match Pb_option.get field_options "ocaml_mutable" with @@ -301,19 +301,19 @@ let ocaml_container field_options = | Some _ -> None let variant_of_oneof ?include_oneof_name ~outer_message_names ~unsigned_tag - all_types file_options file_name oneof_field : Ot.variant = + ~all_types file_options file_name oneof_field : Ot.variant = let v_constructors = List.map (fun field -> let pbtt_field_type = Typing_util.field_type field in let field_type = - compile_field_type ~unsigned_tag all_types file_options + compile_field_type ~unsigned_tag ~all_types file_options (Typing_util.field_options field) file_name pbtt_field_type in let vc_payload_kind, vc_encoding_number, _, _ = - encoding_of_field all_types field + encoding_of_field ~all_types field in let vc_constructor = constructor_name (Typing_util.field_name field) in @@ -373,9 +373,9 @@ let process_all_types_ppx_extension file_name file_options |> string_of_string_option file_name let compile_message ~(unsigned_tag : bool) (file_options : Pb_option.set) - (proto : Pb_field_type.resolved Tt.proto) (file_name : string) - (scope : Tt.type_scope) (message : Pb_field_type.resolved Tt.message) : - Ot.type_ list = + ~(all_types : Pb_field_type.resolved Tt.proto_type list) + (file_name : string) (scope : Tt.type_scope) + (message : Pb_field_type.resolved Tt.message) : Ot.type_ list = let module_prefix = module_prefix_of_file_name file_name in (* TODO maybe module_ should be resolved before `compile_message` since @@ -408,8 +408,8 @@ let compile_message ~(unsigned_tag : bool) (file_options : Pb_option.set) | Tt.Message_oneof_field f :: [] -> let outer_message_names = message_names @ [ message_name ] in let variant = - variant_of_oneof ~unsigned_tag ~outer_message_names proto file_options - file_name f + variant_of_oneof ~unsigned_tag ~outer_message_names ~all_types + file_options file_name f in [ Ot.{ module_prefix; spec = Variant variant; type_level_ppx_extension } ] | _ -> @@ -418,7 +418,7 @@ let compile_message ~(unsigned_tag : bool) (file_options : Pb_option.set) (fun (variants, fields) -> function | Tt.Message_field field -> let pk, encoding_number, packed, _ = - encoding_of_field proto field + encoding_of_field ~all_types field in let field_name = Typing_util.field_name field in @@ -428,8 +428,8 @@ let compile_message ~(unsigned_tag : bool) (file_options : Pb_option.set) let field_type = Typing_util.field_type field in let ocaml_field_type = - compile_field_type ~unsigned_tag proto file_options field_options - file_name field_type + compile_field_type ~unsigned_tag ~all_types file_options + field_options file_name field_type in let field_default = Typing_util.field_default field in @@ -493,7 +493,7 @@ let compile_message ~(unsigned_tag : bool) (file_options : Pb_option.set) let outer_message_names = message_names @ [ message_name ] in let variant = variant_of_oneof ~unsigned_tag ~include_oneof_name:() - ~outer_message_names proto file_options file_name field + ~outer_message_names ~all_types file_options file_name field in let record_field = @@ -538,11 +538,11 @@ let compile_message ~(unsigned_tag : bool) (file_options : Pb_option.set) in let key_type = - compile_field_type ~unsigned_tag proto file_options map_options - file_name map_key_type + compile_field_type ~unsigned_tag ~all_types file_options + map_options file_name map_key_type in - let key_pk = encoding_info_of_field_type proto map_key_type in + let key_pk = encoding_info_of_field_type ~all_types map_key_type in let key_type = match key_type with @@ -551,11 +551,13 @@ let compile_message ~(unsigned_tag : bool) (file_options : Pb_option.set) in let value_type = - compile_field_type ~unsigned_tag proto file_options map_options - file_name map_value_type + compile_field_type ~unsigned_tag ~all_types file_options + map_options file_name map_value_type in - let value_pk = encoding_info_of_field_type proto map_value_type in + let value_pk = + encoding_info_of_field_type ~all_types map_value_type + in let associative_type = match ocaml_container map_options with @@ -634,12 +636,12 @@ let compile_enum file_options file_name scope enum = type_level_ppx_extension; } -let compile_rpc ~unsigned_tag ~(file_name : string) proto +let compile_rpc ~unsigned_tag ~(file_name : string) ~all_types (rpc : Pb_field_type.resolved_t Tt.rpc) : Ot.rpc = let compile_ty ~stream ty : Ot.rpc_type = let ty : Ot.field_type = - compile_field_type ~unsigned_tag proto Pb_option.empty rpc.rpc_options - file_name ty + compile_field_type ~unsigned_tag ~all_types Pb_option.empty + rpc.rpc_options file_name ty in if stream then Ot.Rpc_stream ty @@ -653,32 +655,34 @@ let compile_rpc ~unsigned_tag ~(file_name : string) proto rpc_res = compile_ty ~stream:rpc.rpc_res_stream rpc.rpc_res; } -let compile_service ~(unsigned_tag : bool) proto +let compile_service ~(unsigned_tag : bool) ~all_types (service : Pb_field_type.resolved Tt.service) : Ot.service = { Ot.service_name = service.service_name; service_packages = service.service_packages; service_body = List.map - (compile_rpc ~unsigned_tag ~file_name:service.service_file_name proto) + (compile_rpc ~unsigned_tag ~file_name:service.service_file_name + ~all_types) service.service_body; } -let compile1 ~unsigned_tag proto : _ -> Ot.type_ list = function +let compile1 ~unsigned_tag ~all_types : _ -> Ot.type_ list = function | { Tt.spec = Tt.Message m; file_name; file_options; scope; _ } -> - compile_message ~unsigned_tag file_options proto file_name scope m + compile_message ~unsigned_tag file_options ~all_types file_name scope m | { Tt.spec = Tt.Enum e; file_name; scope; file_options; _ } -> [ compile_enum file_options file_name scope e ] -let compile ~unsigned_tag (proto : Pb_field_type.resolved Tt.proto) : Ot.proto = +let compile ~unsigned_tag ~all_types (proto : Pb_field_type.resolved Tt.proto) : + Ot.proto = let tys = List.map - (fun t -> List.flatten @@ List.map (compile1 ~unsigned_tag proto) t) + (fun t -> List.flatten @@ List.map (compile1 ~unsigned_tag ~all_types) t) proto.proto_types in let services = - List.map (compile_service ~unsigned_tag proto) proto.proto_services + List.map (compile_service ~unsigned_tag ~all_types) proto.proto_services in { Ot.proto_services = services; proto_types = tys } diff --git a/src/compilerlib/pb_codegen_backend.mli b/src/compilerlib/pb_codegen_backend.mli index ed65dfac..13c41c42 100644 --- a/src/compilerlib/pb_codegen_backend.mli +++ b/src/compilerlib/pb_codegen_backend.mli @@ -36,7 +36,11 @@ module Ot = Pb_codegen_ocaml_type (** {2 Compilation } *) -val compile : unsigned_tag:bool -> Pb_field_type.resolved Tt.proto -> Ot.proto +val compile : + unsigned_tag:bool -> + all_types:Pb_field_type.resolved Tt.proto_type list -> + Pb_field_type.resolved Tt.proto -> + Ot.proto (** Internal helpers. @@ -56,7 +60,7 @@ module Internal : sig ?include_oneof_name:unit -> outer_message_names:string list -> unsigned_tag:bool -> - 'a Tt.proto -> + all_types:'a Tt.proto_type list -> Pb_option.set -> string -> int Tt.oneof -> diff --git a/src/compilerlib/pb_typing_util.ml b/src/compilerlib/pb_typing_util.ml index 0b6b5f6f..eab918b6 100644 --- a/src/compilerlib/pb_typing_util.ml +++ b/src/compilerlib/pb_typing_util.ml @@ -42,14 +42,8 @@ let field_option { Tt.field_options; _ } option_name = let empty_scope = { Tt.packages = []; message_names = [] } let type_id_of_type { Tt.id; _ } = id -let type_of_id (p : _ Tt.proto) id = - match - Pb_util.List.find_map - (fun tys -> Pb_util.List.find_opt (fun t -> type_id_of_type t = id) tys) - p.proto_types - with - | Some ty -> ty - | None -> raise Not_found +let type_of_id all_types id = + List.find (fun t -> type_id_of_type t = id) all_types let string_of_type_scope { Tt.packages; message_names } = Printf.sprintf "scope:{packages:%s, message_names:%s}" diff --git a/src/compilerlib/pb_typing_util.mli b/src/compilerlib/pb_typing_util.mli index 2f79c8f1..98724474 100644 --- a/src/compilerlib/pb_typing_util.mli +++ b/src/compilerlib/pb_typing_util.mli @@ -56,7 +56,7 @@ val field_option : ('a, 'b) Tt.field -> string -> Pb_option.value option is returned. *) -val type_of_id : 'a Tt.proto -> int -> 'a Tt.proto_type +val type_of_id : 'a Tt.proto_type list -> int -> 'a Tt.proto_type (** [type_of_id all_types id] returns the type associated with the given id, raise [Not_found] if the type is not in the all_types. *) diff --git a/src/ocaml-protoc/ocaml_protoc_compilation.ml b/src/ocaml-protoc/ocaml_protoc_compilation.ml index 4621bb5f..f6ab4eb4 100644 --- a/src/ocaml-protoc/ocaml_protoc_compilation.ml +++ b/src/ocaml-protoc/ocaml_protoc_compilation.ml @@ -88,6 +88,7 @@ let compile cmdline cmd_line_files_options : Ot.proto * _ = (* typing *) let typed_proto = Pb_typing.perform_typing protos in + let all_typed_protos = List.flatten typed_proto.proto_types in (* Only get the types which are part of the given proto file (compilation unit) *) @@ -105,5 +106,8 @@ let compile cmdline cmd_line_files_options : Ot.proto * _ = (* -- OCaml Backend -- *) let module BO = Pb_codegen_backend in - let ocaml_proto = BO.compile ~unsigned_tag:!unsigned_tag typed_proto in + let ocaml_proto = + BO.compile ~unsigned_tag:!unsigned_tag ~all_types:all_typed_protos + typed_proto + in ocaml_proto, proto_file_options From cdff03e35c9ea7694273283a0d9dd112b76c749b Mon Sep 17 00:00:00 2001 From: Simon Cruanes Date: Sun, 10 Sep 2023 01:39:52 -0400 Subject: [PATCH 09/46] pbrt_services: use yojson basic must be consistent with pbrt_yojson. --- src/runtime-services/pbrt_services.ml | 6 +++--- src/runtime-services/pbrt_services.mli | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/runtime-services/pbrt_services.ml b/src/runtime-services/pbrt_services.ml index 6cae53fe..564a9cd1 100644 --- a/src/runtime-services/pbrt_services.ml +++ b/src/runtime-services/pbrt_services.ml @@ -34,7 +34,7 @@ module Client = struct (fun encoding (transport : transport) req ~on_result -> let req_str = match encoding with - | `JSON -> encode_json_req req |> Yojson.Safe.to_string + | `JSON -> encode_json_req req |> Yojson.Basic.to_string | `BINARY -> let enc = Pbrt.Encoder.create () in encode_pb_req req enc; @@ -46,7 +46,7 @@ module Client = struct | Ok res_str -> (match encoding with | `JSON -> - (match decode_json_res @@ Yojson.Safe.from_string res_str with + (match decode_json_res @@ Yojson.Basic.from_string res_str with | res -> on_result (Ok res) | exception exn -> on_result (Error (Decode_error_json (Printexc.to_string exn)))) @@ -79,7 +79,7 @@ module Server = struct let handler fmt req : _ result = match fmt with | `JSON -> - (match Yojson.Safe.from_string req with + (match Yojson.Basic.from_string req with | exception _ -> Error Invalid_json | j -> let req = decode_json_req j in diff --git a/src/runtime-services/pbrt_services.mli b/src/runtime-services/pbrt_services.mli index 688c3e5a..3be18fa1 100644 --- a/src/runtime-services/pbrt_services.mli +++ b/src/runtime-services/pbrt_services.mli @@ -41,9 +41,9 @@ module Client : sig val mk_rpc : service_name:string -> rpc_name:string -> - encode_json_req:('req -> Yojson.Safe.t) -> + encode_json_req:('req -> Yojson.Basic.t) -> encode_pb_req:('req -> Pbrt.Encoder.t -> unit) -> - decode_json_res:(Yojson.Safe.t -> 'res) -> + decode_json_res:(Yojson.Basic.t -> 'res) -> decode_pb_res:(Pbrt.Decoder.t -> 'res) -> unit -> ('req, 'res) rpc @@ -67,7 +67,7 @@ module Server : sig f:('req -> 'res) -> encode_json_res:('res -> string) -> encode_pb_res:('res -> Pbrt.Encoder.t -> unit) -> - decode_json_req:(Yojson.Safe.t -> 'req) -> + decode_json_req:(Yojson.Basic.t -> 'req) -> decode_pb_req:(Pbrt.Decoder.t -> 'req) -> unit -> rpc From f86d426cfd36dcdad2d7a223d04a568ee68778f3 Mon Sep 17 00:00:00 2001 From: Simon Cruanes Date: Sun, 10 Sep 2023 01:40:15 -0400 Subject: [PATCH 10/46] services RPC should take messages in and out? --- src/compilerlib/pb_codegen_backend.ml | 26 +++++------- src/compilerlib/pb_codegen_decode_binary.ml | 2 +- src/compilerlib/pb_codegen_decode_yojson.ml | 2 +- src/compilerlib/pb_codegen_decode_yojson.mli | 7 ++++ src/compilerlib/pb_codegen_ocaml_type.ml | 2 +- src/compilerlib/pb_codegen_services.ml | 27 ++++++++++--- src/compilerlib/pb_exception.ml | 18 +++++++++ src/compilerlib/pb_exception.mli | 2 + src/compilerlib/pb_typing_resolution.ml | 42 ++++++++++++-------- src/compilerlib/pb_typing_type_tree.ml | 2 +- src/compilerlib/pb_typing_validation.ml | 10 +++++ 11 files changed, 97 insertions(+), 43 deletions(-) diff --git a/src/compilerlib/pb_codegen_backend.ml b/src/compilerlib/pb_codegen_backend.ml index 3ffa241c..84f12e2e 100644 --- a/src/compilerlib/pb_codegen_backend.ml +++ b/src/compilerlib/pb_codegen_backend.ml @@ -132,7 +132,7 @@ let wrapper_type_of_type_name = function with the module name. (This is essentially expecting (rightly) a sub module with the same name. *) -let user_defined_type_of_id ~(all_types : _ list) file_name i : Ot.field_type = +let user_defined_type_of_id ~(all_types : _ list) ~file_name i : Ot.field_type = let module_prefix = module_prefix_of_file_name file_name in match Typing_util.type_of_id all_types i with | exception Not_found -> E.programmatic_error E.No_type_found_for_id @@ -285,7 +285,7 @@ let compile_field_type ~unsigned_tag ~(all_types : _ Tt.proto_type list) | `Bool, _ -> Ot.(Ft_basic_type Bt_bool) | `String, _ -> Ot.(Ft_basic_type Bt_string) | `Bytes, _ -> Ot.(Ft_basic_type Bt_bytes) - | `User_defined id, _ -> user_defined_type_of_id ~all_types file_name id + | `User_defined id, _ -> user_defined_type_of_id ~all_types ~file_name id let is_mutable ?field_name field_options = match Pb_option.get field_options "ocaml_mutable" with @@ -636,13 +636,10 @@ let compile_enum file_options file_name scope enum = type_level_ppx_extension; } -let compile_rpc ~unsigned_tag ~(file_name : string) ~all_types - (rpc : Pb_field_type.resolved_t Tt.rpc) : Ot.rpc = - let compile_ty ~stream ty : Ot.rpc_type = - let ty : Ot.field_type = - compile_field_type ~unsigned_tag ~all_types Pb_option.empty - rpc.rpc_options file_name ty - in +let compile_rpc ~(file_name : string) ~all_types + (rpc : Pb_field_type.resolved Tt.rpc) : Ot.rpc = + let compile_ty ~stream (ty : int) : Ot.rpc_type = + let ty = user_defined_type_of_id ~all_types ~file_name ty in if stream then Ot.Rpc_stream ty else @@ -655,15 +652,14 @@ let compile_rpc ~unsigned_tag ~(file_name : string) ~all_types rpc_res = compile_ty ~stream:rpc.rpc_res_stream rpc.rpc_res; } -let compile_service ~(unsigned_tag : bool) ~all_types - (service : Pb_field_type.resolved Tt.service) : Ot.service = +let compile_service ~all_types (service : Pb_field_type.resolved Tt.service) : + Ot.service = { Ot.service_name = service.service_name; service_packages = service.service_packages; service_body = List.map - (compile_rpc ~unsigned_tag ~file_name:service.service_file_name - ~all_types) + (compile_rpc ~file_name:service.service_file_name ~all_types) service.service_body; } @@ -681,9 +677,7 @@ let compile ~unsigned_tag ~all_types (proto : Pb_field_type.resolved Tt.proto) : proto.proto_types in - let services = - List.map (compile_service ~unsigned_tag ~all_types) proto.proto_services - in + let services = List.map (compile_service ~all_types) proto.proto_services in { Ot.proto_services = services; proto_types = tys } module Internal = struct diff --git a/src/compilerlib/pb_codegen_decode_binary.ml b/src/compilerlib/pb_codegen_decode_binary.ml index b13bdef4..f9e98225 100644 --- a/src/compilerlib/pb_codegen_decode_binary.ml +++ b/src/compilerlib/pb_codegen_decode_binary.ml @@ -39,7 +39,7 @@ let runtime_function_for_wrapper_type { Ot.wt_type; wt_pk } = | Ot.Bt_bytes, Ot.Pk_bytes -> "Pbrt.Decoder.wrapper_bytes_value" | _ -> assert false -let decode_field_expression field_type pk = +let decode_field_expression field_type pk : string = match field_type with | Ot.Ft_user_defined_type t -> let f_name = diff --git a/src/compilerlib/pb_codegen_decode_yojson.ml b/src/compilerlib/pb_codegen_decode_yojson.ml index c755ea67..00c3a00c 100644 --- a/src/compilerlib/pb_codegen_decode_yojson.ml +++ b/src/compilerlib/pb_codegen_decode_yojson.ml @@ -3,7 +3,7 @@ module F = Pb_codegen_formatting let sp = Pb_codegen_util.sp -(* Function which returns all the possible pattern match for reading a JSON +(** Function which returns all the possible pattern match for reading a JSON value into an OCaml value. The protobuf JSON encoding rules are defined here: https://developers.google.com/protocol-buffers/docs/proto3#json *) diff --git a/src/compilerlib/pb_codegen_decode_yojson.mli b/src/compilerlib/pb_codegen_decode_yojson.mli index fcff7c26..3cd0c82c 100644 --- a/src/compilerlib/pb_codegen_decode_yojson.mli +++ b/src/compilerlib/pb_codegen_decode_yojson.mli @@ -2,4 +2,11 @@ include Pb_codegen_plugin.S +val field_pattern_match : + r_name:string -> + rf_label:string -> + Pb_codegen_ocaml_type.field_type -> + string * string +(** How to decode a field type *) + val plugin : Pb_codegen_plugin.t diff --git a/src/compilerlib/pb_codegen_ocaml_type.ml b/src/compilerlib/pb_codegen_ocaml_type.ml index 0730eae9..7a53f1d8 100644 --- a/src/compilerlib/pb_codegen_ocaml_type.ml +++ b/src/compilerlib/pb_codegen_ocaml_type.ml @@ -156,7 +156,7 @@ type type_ = { type_level_ppx_extension: string option; } -(** RPC argument or return type *) +(** RPC argument or return type. We require message types in RPC. *) type rpc_type = | Rpc_scalar of field_type | Rpc_stream of field_type diff --git a/src/compilerlib/pb_codegen_services.ml b/src/compilerlib/pb_codegen_services.ml index 0fbfb659..6321ca4d 100644 --- a/src/compilerlib/pb_codegen_services.ml +++ b/src/compilerlib/pb_codegen_services.ml @@ -1,12 +1,25 @@ module Ot = Pb_codegen_ocaml_type module F = Pb_codegen_formatting +let spf = Printf.sprintf + let string_of_rpc_type (ty : Ot.rpc_type) : string = let f = Pb_codegen_util.string_of_field_type in match ty with | Ot.Rpc_scalar ty -> f ty | Ot.Rpc_stream ty -> Printf.sprintf "%s Seq.t" (f ty) +let decode_json_ty ~r_name ~rf_label (ty : Ot.rpc_type) : string = + let f ty = + let x, body = + Pb_codegen_decode_yojson.field_pattern_match ~r_name ~rf_label ty + in + spf "(fun %s -> %s)" x body + in + match ty with + | Ot.Rpc_scalar ty -> f ty + | Ot.Rpc_stream ty -> Printf.sprintf "Seq.map %s" (f ty) + let mod_name_for_client (service : Ot.service) : string = String.capitalize_ascii service.service_name @@ -27,10 +40,11 @@ let gen_service_client_struct (service : Ot.service) sc : unit = (string_of_rpc_type rpc.rpc_req); F.linep sc " ~encode_pb_req:encode_pb_%s" (string_of_rpc_type rpc.rpc_req); - F.linep sc " ~decode_json_res:decode_json_%s" - (string_of_rpc_type rpc.rpc_req); + F.linep sc " ~decode_json_res:%s" + (decode_json_ty ~r_name:service.service_name ~rf_label:rpc.rpc_name + rpc.rpc_res); F.linep sc " ~decode_pb_res:decode_pb_%s" - (string_of_rpc_type rpc.rpc_req); + (string_of_rpc_type rpc.rpc_res); F.line sc "()") service.service_body); @@ -94,8 +108,9 @@ let gen_service_server_struct (service : Ot.service) sc : unit = (string_of_rpc_type rpc.rpc_res); F.linep sc " ~encode_pb_res:encode_pb_%s" (string_of_rpc_type rpc.rpc_res); - F.linep sc " ~decode_json_req:decode_json_%s" - (string_of_rpc_type rpc.rpc_req); + F.linep sc " ~decode_json_req:(%s)" + (decode_json_ty ~r_name:service.service_name ~rf_label:rpc.rpc_name + rpc.rpc_req); F.linep sc " ~decode_pb_req:decode_pb_%s" (string_of_rpc_type rpc.rpc_req); F.linep sc " ();") @@ -113,7 +128,7 @@ let gen_service_server_sig service sc : unit = F.empty_line sc; F.linep sc "(** Convert {!%s} to a generic runtime service *)" mod_type_name; - F.linep sc "val service_impl_of_%s : (module %s) -> Pbrt.Server.t" + F.linep sc "val service_impl_of_%s : (module %s) -> Pbrt_services.Server.t" (String.lowercase_ascii service.service_name) mod_type_name; () diff --git a/src/compilerlib/pb_exception.ml b/src/compilerlib/pb_exception.ml index 8cf2de4a..29255d3e 100644 --- a/src/compilerlib/pb_exception.ml +++ b/src/compilerlib/pb_exception.ml @@ -92,6 +92,8 @@ type error = | Invalid_first_enum_value_proto3 of string * string option | Invalid_key_type_for_map of string | Unsupported_wrapper_type of string + | Invalid_rpc_req_type of string * string + | Invalid_rpc_res_type of string * string exception Compilation_error of error (** Exception raised when a compilation error occurs *) @@ -158,6 +160,16 @@ let string_of_error = function ("Invalid field label for field: %s in message: %s. " ^^ "[Required|Optional] are not supported.") field_name message_name + | Invalid_rpc_req_type (service_name, rpc_name) -> + P.sprintf + ("Invalid type for RPC request: %s in service: %s. " + ^^ "The type must be user-defined..") + service_name rpc_name + | Invalid_rpc_res_type (service_name, rpc_name) -> + P.sprintf + ("Invalid type for RPC response: %s in service: %s. " + ^^ "The type must be user-defined..") + service_name rpc_name | Default_field_option_not_supported (field_name, message_name) -> P.sprintf ("Explicit default values are not allowed in proto3. " @@ -252,6 +264,12 @@ let invalid_proto3_field_label ~field_name ~message_name = raise (Compilation_error (Invalid_proto3_field_label (field_name, message_name))) +let invalid_rpc_req_type ~service_name ~rpc_name () = + raise (Compilation_error (Invalid_rpc_req_type (service_name, rpc_name))) + +let invalid_rpc_res_type ~service_name ~rpc_name () = + raise (Compilation_error (Invalid_rpc_res_type (service_name, rpc_name))) + let default_field_option_not_supported ~field_name ~message_name = raise (Compilation_error diff --git a/src/compilerlib/pb_exception.mli b/src/compilerlib/pb_exception.mli index 1faa82ca..50c755cd 100644 --- a/src/compilerlib/pb_exception.mli +++ b/src/compilerlib/pb_exception.mli @@ -69,6 +69,8 @@ val protoc_parsing_error : error -> Pb_location.t -> string -> 'a val unknown_parsing_error : msg:string -> context:string -> Pb_location.t -> 'a val invalid_protobuf_syntax : string -> 'a val invalid_proto3_field_label : field_name:string -> message_name:string -> 'a +val invalid_rpc_req_type : service_name:string -> rpc_name:string -> unit -> 'a +val invalid_rpc_res_type : service_name:string -> rpc_name:string -> unit -> 'a val default_field_option_not_supported : field_name:string -> message_name:string -> 'a diff --git a/src/compilerlib/pb_typing_resolution.ml b/src/compilerlib/pb_typing_resolution.ml index 5720ca5a..85fb75b5 100644 --- a/src/compilerlib/pb_typing_resolution.ml +++ b/src/compilerlib/pb_typing_resolution.ml @@ -221,6 +221,24 @@ let resolve_enum_field_default field_name type_ field_default = E.invalid_default_value ~field_name ~info:"default value not supported for message" () +let resolve_user_defined_type t field_name + (unresolved_field_type : Pb_field_type.unresolved) field_default + message_type_path : int * _ = + let { Pb_field_type.type_name; _ } = unresolved_field_type in + let rec aux = function + | [] -> raise Not_found + | type_path :: tl -> + (match Types_by_scope.find t type_path type_name with + | type_ -> + let id = type_.Tt.id in + let field_default = + resolve_enum_field_default field_name type_ field_default + in + id, field_default + | exception Not_found -> aux tl) + in + aux (compute_search_type_paths unresolved_field_type message_type_path) + (** this function resolves both the type and the defaut value of a field type. Note that it is necessary to verify both at the same time since the default value must be of the same type as the field type in order @@ -238,20 +256,11 @@ let resolve_field_type_and_default t field_name field_type field_default in builtin_type, field_default | `User_defined unresolved_field_type -> - let { Pb_field_type.type_name; _ } = unresolved_field_type in - let rec aux = function - | [] -> raise Not_found - | type_path :: tl -> - (match Types_by_scope.find t type_path type_name with - | type_ -> - let id = type_.Tt.id in - let field_default = - resolve_enum_field_default field_name type_ field_default - in - `User_defined id, field_default - | exception Not_found -> aux tl) + let id, default = + resolve_user_defined_type t field_name unresolved_field_type field_default + message_type_path in - aux (compute_search_type_paths unresolved_field_type message_type_path) + `User_defined id, default (** this function resolves all the field type of the given type *) let resolve_type t type_ : int Tt.proto_type = @@ -331,12 +340,11 @@ let resolve_types types : Types_by_scope.t * _ list = let t = List.fold_left Types_by_scope.add Types_by_scope.empty types in t, List.map (resolve_type t) types -let resolve_service t (service : _ Tt.service) : - Pb_field_type.resolved Tt.service = - let resolve_ty ~rpc_name ~name ty : Pb_field_type.resolved Pb_field_type.t = +let resolve_service t (service : _ Tt.service) : _ Tt.service = + let resolve_ty ~rpc_name ~name ty : Pb_field_type.resolved = let rpc_type, _field_default = let do_resolve () = - resolve_field_type_and_default t name ty None service.service_packages + resolve_user_defined_type t name ty None service.service_packages in match do_resolve () with | ret -> ret diff --git a/src/compilerlib/pb_typing_type_tree.ml b/src/compilerlib/pb_typing_type_tree.ml index c2d2cb7f..000cb294 100644 --- a/src/compilerlib/pb_typing_type_tree.ml +++ b/src/compilerlib/pb_typing_type_tree.ml @@ -122,7 +122,7 @@ type 'a service = { service_name: string; service_file_name: string; service_packages: string list; (** Package in which this belongs *) - service_body: 'a Pb_field_type.t rpc list; + service_body: 'a rpc list; } (** A service, composed of multiple RPCs. *) diff --git a/src/compilerlib/pb_typing_validation.ml b/src/compilerlib/pb_typing_validation.ml index 29de7d0b..2ca05778 100644 --- a/src/compilerlib/pb_typing_validation.ml +++ b/src/compilerlib/pb_typing_validation.ml @@ -281,6 +281,16 @@ let validate_service (scope : Tt.type_scope) ~file_name (service : Pt.service) : rpc_res_stream; rpc_res; } -> + let rpc_req = + match rpc_req with + | `User_defined ty -> ty + | _ -> E.invalid_rpc_req_type ~service_name ~rpc_name () + in + let rpc_res = + match rpc_res with + | `User_defined ty -> ty + | _ -> E.invalid_rpc_res_type ~service_name ~rpc_name () + in let rpc = { Tt.rpc_name; From 50e0ff4d5b8c2c974dd4f0a099082e2dd933b77f Mon Sep 17 00:00:00 2001 From: Simon Cruanes Date: Sun, 10 Sep 2023 01:40:48 -0400 Subject: [PATCH 11/46] wip: example of basic service --- src/examples/calculator.proto | 25 +++++++++++++++++++++++++ src/examples/dune | 12 ++++++++++++ src/examples/t_calculator.ml | 1 + 3 files changed, 38 insertions(+) create mode 100644 src/examples/calculator.proto create mode 100644 src/examples/t_calculator.ml diff --git a/src/examples/calculator.proto b/src/examples/calculator.proto new file mode 100644 index 00000000..fc8e18e2 --- /dev/null +++ b/src/examples/calculator.proto @@ -0,0 +1,25 @@ +syntax = "proto3"; + +message DivByZero {} + +message AddReq { + int32 a = 1; + int32 b = 2; +} + +message AddAllReq { + repeated int32 ints = 1; +} + +message Empty {} + +service Calculator { + rpc add(AddReq) returns (int32); + + rpc add_all(AddAllReq) returns (int32); + + rpc ping(Empty) returns (Empty); + + rpc get_pings(Empty) returns (int32); +} + diff --git a/src/examples/dune b/src/examples/dune index 882a7fa1..20d4ae54 100644 --- a/src/examples/dune +++ b/src/examples/dune @@ -41,3 +41,15 @@ (name example05) (modules t_example05 example05) (libraries pbrt)) + +(rule + (targets calculator.ml calculator.mli) + (deps calculator.proto) + (action + (run ocaml-protoc --binary --pp --yojson --services --ml_out ./ %{deps}))) + +(test + (name calculator) + (modules t_calculator calculator) + (package pbrt_services) + (libraries pbrt pbrt_yojson pbrt_services)) diff --git a/src/examples/t_calculator.ml b/src/examples/t_calculator.ml new file mode 100644 index 00000000..306831a0 --- /dev/null +++ b/src/examples/t_calculator.ml @@ -0,0 +1 @@ +let () = () From db9ab84afa1bec120d18c43f3095b83e5daf17f0 Mon Sep 17 00:00:00 2001 From: Simon Cruanes Date: Sun, 10 Sep 2023 13:29:15 -0400 Subject: [PATCH 12/46] formatting --- dune-project | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/dune-project b/dune-project index 491bce85..0c1b97cf 100644 --- a/dune-project +++ b/dune-project @@ -20,7 +20,7 @@ (package (name pbrt) (synopsis "Runtime library for Protobuf tooling") - (depends + (depends stdlib-shims (odoc :with-doc) (ocaml (>= 4.03))) @@ -29,7 +29,7 @@ (package (name pbrt_yojson) (synopsis "Runtime library for ocaml-protoc to support JSON encoding/decoding") - (depends + (depends (ocaml (>= 4.03)) (odoc :with-doc) (yojson (>= 1.6)) @@ -39,7 +39,7 @@ (package (name pbrt_services) (synopsis "Runtime library for ocaml-protoc to support RPC services") - (depends + (depends (ocaml (>= 4.03)) (pbrt (= :version)) (pbrt_yojson (= :version))) From d66a420b16c699331593da3197cc1bf6d80132ef Mon Sep 17 00:00:00 2001 From: Simon Cruanes Date: Sun, 10 Sep 2023 14:41:27 -0400 Subject: [PATCH 13/46] make it compile baby!!! --- src/compilerlib/pb_codegen_encode_yojson.ml | 2 +- src/compilerlib/pb_codegen_ocaml_type.ml | 2 +- src/compilerlib/pb_codegen_services.ml | 125 +++++++++++++++----- src/runtime-services/pbrt_services.ml | 2 +- src/runtime-services/pbrt_services.mli | 2 +- 5 files changed, 99 insertions(+), 34 deletions(-) diff --git a/src/compilerlib/pb_codegen_encode_yojson.ml b/src/compilerlib/pb_codegen_encode_yojson.ml index c3e672c5..d0f80fb0 100644 --- a/src/compilerlib/pb_codegen_encode_yojson.ml +++ b/src/compilerlib/pb_codegen_encode_yojson.ml @@ -36,7 +36,7 @@ let runtime_function_for_basic_type json_label basic_type pk = return the runtime function name which accepts an option field and return the YoJson value (ie Null when value is None *) -let gen_field var_name json_label field_type pk = +let gen_field var_name json_label field_type pk : string option = match field_type, pk with | Ot.Ft_unit, _ -> None (* Basic types *) diff --git a/src/compilerlib/pb_codegen_ocaml_type.ml b/src/compilerlib/pb_codegen_ocaml_type.ml index 7a53f1d8..6d883975 100644 --- a/src/compilerlib/pb_codegen_ocaml_type.ml +++ b/src/compilerlib/pb_codegen_ocaml_type.ml @@ -87,7 +87,7 @@ type is_packed = bool type record_field_type = | Rft_nolabel of (field_type * encoding_number * payload_kind) - (* no default values in proto3 no label fields *) + (** no default values in proto3 no label fields *) | Rft_required of (field_type * encoding_number * payload_kind * default_value) | Rft_optional of diff --git a/src/compilerlib/pb_codegen_services.ml b/src/compilerlib/pb_codegen_services.ml index 6321ca4d..106f6725 100644 --- a/src/compilerlib/pb_codegen_services.ml +++ b/src/compilerlib/pb_codegen_services.ml @@ -1,51 +1,115 @@ module Ot = Pb_codegen_ocaml_type module F = Pb_codegen_formatting -let spf = Printf.sprintf - let string_of_rpc_type (ty : Ot.rpc_type) : string = let f = Pb_codegen_util.string_of_field_type in match ty with | Ot.Rpc_scalar ty -> f ty | Ot.Rpc_stream ty -> Printf.sprintf "%s Seq.t" (f ty) -let decode_json_ty ~r_name ~rf_label (ty : Ot.rpc_type) : string = +let function_name_encode_json ~service_name ~rpc_name (ty : Ot.rpc_type) : + string = + let f ty = + match ty with + | Ot.Ft_unit -> "(fun () -> `Assoc [])" + | Ot.Ft_user_defined_type udt -> + let function_prefix = "encode_json" in + Pb_codegen_util.function_name_of_user_defined ~function_prefix udt + | _ -> + Printf.eprintf "cannot json-encode request for %s in service %s\n%!" + rpc_name service_name; + exit 1 + in + match ty with + | Ot.Rpc_scalar ty -> f ty + | Ot.Rpc_stream ty -> Printf.sprintf "Seq.map %s" (f ty) + +let function_name_decode_json ~service_name ~rpc_name (ty : Ot.rpc_type) : + string = + let f ty = + match ty with + | Ot.Ft_unit -> "(fun _ -> ())" + | Ot.Ft_user_defined_type udt -> + let function_prefix = "decode_json" in + Pb_codegen_util.function_name_of_user_defined ~function_prefix udt + | _ -> + Printf.eprintf "cannot decode json request for %s in service %s\n%!" + rpc_name service_name; + exit 1 + in + match ty with + | Ot.Rpc_scalar ty -> f ty + | Ot.Rpc_stream ty -> Printf.sprintf "Seq.map %s" (f ty) + +let function_name_encode_pb ~service_name ~rpc_name (ty : Ot.rpc_type) : string + = + let f ty = + match ty with + | Ot.Ft_unit -> "(fun () enc -> Pbrt.Encoder.empty_nested enc)" + | Ot.Ft_user_defined_type udt -> + let function_prefix = "encode_pb" in + Pb_codegen_util.function_name_of_user_defined ~function_prefix udt + | _ -> + Printf.eprintf "cannot binary-encode request for %s in service %s\n%!" + rpc_name service_name; + exit 1 + in + match ty with + | Ot.Rpc_scalar ty -> f ty + | Ot.Rpc_stream ty -> Printf.sprintf "Seq.map %s" (f ty) + +let function_name_decode_pb ~service_name ~rpc_name (ty : Ot.rpc_type) : string + = let f ty = - let x, body = - Pb_codegen_decode_yojson.field_pattern_match ~r_name ~rf_label ty - in - spf "(fun %s -> %s)" x body + match ty with + | Ot.Ft_unit -> "(fun d -> Pbrt.Decoder.empty_nested d)" + | Ot.Ft_user_defined_type udt -> + let function_prefix = "decode_pb" in + Pb_codegen_util.function_name_of_user_defined ~function_prefix udt + | _ -> + Printf.eprintf "cannot decode binary request for %s in service %s\n%!" + rpc_name service_name; + exit 1 in match ty with | Ot.Rpc_scalar ty -> f ty | Ot.Rpc_stream ty -> Printf.sprintf "Seq.map %s" (f ty) +let ocaml_type_of_rpc_type (rpc : Ot.rpc_type) : string = + match rpc with + | Rpc_scalar ty -> Pb_codegen_util.string_of_field_type ty + | Rpc_stream ty -> + Printf.sprintf "(%s Seq.t)" (Pb_codegen_util.string_of_field_type ty) + let mod_name_for_client (service : Ot.service) : string = String.capitalize_ascii service.service_name let gen_service_client_struct (service : Ot.service) sc : unit = + let service_name = service.service_name in F.linep sc "module %s = struct" (mod_name_for_client service); F.sub_scope sc (fun sc -> F.linep sc "open Pbrt_services.Client"; List.iter (fun (rpc : Ot.rpc) -> + let rpc_name = rpc.rpc_name in F.empty_line sc; F.linep sc "let %s : (%s, %s) rpc =" (Pb_codegen_util.function_name_of_rpc rpc) (string_of_rpc_type rpc.rpc_req) (string_of_rpc_type rpc.rpc_res); - F.linep sc " mk_rpc ~service_name:%S ~rpc_name:%S" + F.linep sc " (mk_rpc ~service_name:%S ~rpc_name:%S" service.service_name rpc.rpc_name; - F.linep sc " ~encode_json_req:encode_json_%s" - (string_of_rpc_type rpc.rpc_req); - F.linep sc " ~encode_pb_req:encode_pb_%s" - (string_of_rpc_type rpc.rpc_req); + F.linep sc " ~encode_json_req:%s" + (function_name_encode_json ~service_name ~rpc_name rpc.rpc_req); + F.linep sc " ~encode_pb_req:%s" + (function_name_encode_pb ~service_name ~rpc_name rpc.rpc_req); F.linep sc " ~decode_json_res:%s" - (decode_json_ty ~r_name:service.service_name ~rf_label:rpc.rpc_name - rpc.rpc_res); - F.linep sc " ~decode_pb_res:decode_pb_%s" - (string_of_rpc_type rpc.rpc_res); - F.line sc "()") + (function_name_decode_json ~service_name ~rpc_name rpc.rpc_res); + F.linep sc " ~decode_pb_res:%s" + (function_name_decode_pb ~service_name ~rpc_name rpc.rpc_res); + F.linep sc "() : (%s, %s) rpc)" + (ocaml_type_of_rpc_type rpc.rpc_req) + (ocaml_type_of_rpc_type rpc.rpc_res)) service.service_body); F.line sc "end"; @@ -85,6 +149,7 @@ let gen_mod_type_of_service (service : Ot.service) sc : unit = F.line sc "end" let gen_service_server_struct (service : Ot.service) sc : unit = + let service_name = service.service_name in let mod_type_name = Pb_codegen_util.module_type_name_of_service_server service in @@ -94,26 +159,26 @@ let gen_service_server_struct (service : Ot.service) sc : unit = (* now generate a function from the module type to a [Service_server.t] *) F.linep sc "let service_impl_of_%s (module M:%s) : Pbrt_services.Server.t =" - (String.lowercase_ascii service.service_name) + (String.lowercase_ascii service_name) mod_type_name; F.sub_scope sc (fun sc -> F.line sc "let open Pbrt_services.Server in"; - F.linep sc "{ name=%S;" service.service_name; + F.linep sc "{ name=%S;" service_name; F.line sc " handlers=["; List.iter (fun (rpc : Ot.rpc) -> - F.linep sc " mk_rpc ~name:%S ~f:M.%s" rpc.rpc_name + let rpc_name = rpc.rpc_name in + F.linep sc " (mk_rpc ~name:%S ~f:M.%s" rpc.rpc_name (Pb_codegen_util.function_name_of_rpc rpc); - F.linep sc " ~encode_json_res:encode_json_%s" - (string_of_rpc_type rpc.rpc_res); - F.linep sc " ~encode_pb_res:encode_pb_%s" - (string_of_rpc_type rpc.rpc_res); - F.linep sc " ~decode_json_req:(%s)" - (decode_json_ty ~r_name:service.service_name ~rf_label:rpc.rpc_name - rpc.rpc_req); - F.linep sc " ~decode_pb_req:decode_pb_%s" - (string_of_rpc_type rpc.rpc_req); - F.linep sc " ();") + F.linep sc " ~encode_json_res:%s" + (function_name_encode_json ~service_name ~rpc_name rpc.rpc_res); + F.linep sc " ~encode_pb_res:%s" + (function_name_encode_pb ~service_name ~rpc_name rpc.rpc_res); + F.linep sc " ~decode_json_req:%s" + (function_name_decode_json ~service_name ~rpc_name rpc.rpc_req); + F.linep sc " ~decode_pb_req:%s" + (function_name_decode_pb ~service_name ~rpc_name rpc.rpc_req); + F.linep sc " () : rpc);") service.service_body; F.line sc "]; }"); F.empty_line sc diff --git a/src/runtime-services/pbrt_services.ml b/src/runtime-services/pbrt_services.ml index 564a9cd1..bcd5d880 100644 --- a/src/runtime-services/pbrt_services.ml +++ b/src/runtime-services/pbrt_services.ml @@ -84,7 +84,7 @@ module Server = struct | j -> let req = decode_json_req j in (match f req with - | res -> Ok (encode_json_res res) + | res -> Ok (Yojson.Basic.to_string @@ encode_json_res res) | exception exn -> Error (Handler_failed (Printexc.to_string exn)))) | `BINARY -> let decoder = Pbrt.Decoder.of_string req in diff --git a/src/runtime-services/pbrt_services.mli b/src/runtime-services/pbrt_services.mli index 3be18fa1..a15eaa4d 100644 --- a/src/runtime-services/pbrt_services.mli +++ b/src/runtime-services/pbrt_services.mli @@ -65,7 +65,7 @@ module Server : sig val mk_rpc : name:string -> f:('req -> 'res) -> - encode_json_res:('res -> string) -> + encode_json_res:('res -> Yojson.Basic.t) -> encode_pb_res:('res -> Pbrt.Encoder.t -> unit) -> decode_json_req:(Yojson.Basic.t -> 'req) -> decode_pb_req:(Pbrt.Decoder.t -> 'req) -> From 984fdcb6da486b8a94f05f6ed39ea2f3b2bbdfa0 Mon Sep 17 00:00:00 2001 From: Simon Cruanes Date: Sun, 10 Sep 2023 14:50:46 -0400 Subject: [PATCH 14/46] change RPC stub to have no behavior; just list encoders/decoders the transport libraries will have more flexbility this way. --- src/compilerlib/pb_codegen_services.ml | 2 +- src/examples/calculator.proto | 10 +- src/runtime-services/dune | 5 + src/runtime-services/errors.proto | 19 ++++ src/runtime-services/pbrt_services.ml | 134 ++++++++++--------------- src/runtime-services/pbrt_services.mli | 84 ++++++++-------- 6 files changed, 124 insertions(+), 130 deletions(-) create mode 100644 src/runtime-services/errors.proto diff --git a/src/compilerlib/pb_codegen_services.ml b/src/compilerlib/pb_codegen_services.ml index 106f6725..1b17264e 100644 --- a/src/compilerlib/pb_codegen_services.ml +++ b/src/compilerlib/pb_codegen_services.ml @@ -163,7 +163,7 @@ let gen_service_server_struct (service : Ot.service) sc : unit = mod_type_name; F.sub_scope sc (fun sc -> F.line sc "let open Pbrt_services.Server in"; - F.linep sc "{ name=%S;" service_name; + F.linep sc "{ service_name=%S;" service_name; F.line sc " handlers=["; List.iter (fun (rpc : Ot.rpc) -> diff --git a/src/examples/calculator.proto b/src/examples/calculator.proto index fc8e18e2..3082757f 100644 --- a/src/examples/calculator.proto +++ b/src/examples/calculator.proto @@ -2,6 +2,10 @@ syntax = "proto3"; message DivByZero {} +message I32 { + int32 value = 0; +} + message AddReq { int32 a = 1; int32 b = 2; @@ -14,12 +18,12 @@ message AddAllReq { message Empty {} service Calculator { - rpc add(AddReq) returns (int32); + rpc add(AddReq) returns (I32); - rpc add_all(AddAllReq) returns (int32); + rpc add_all(AddAllReq) returns (I32); rpc ping(Empty) returns (Empty); - rpc get_pings(Empty) returns (int32); + rpc get_pings(Empty) returns (I32); } diff --git a/src/runtime-services/dune b/src/runtime-services/dune index 77e192ed..8c6c9614 100644 --- a/src/runtime-services/dune +++ b/src/runtime-services/dune @@ -5,3 +5,8 @@ (wrapped true) (synopsis "Runtime library for services generated by ocaml-protoc") (libraries pbrt pbrt_yojson yojson)) + +(rule + (targets errors.ml errors.mli) + (deps (:file errors.proto)) + (action (run ../ocaml-protoc/ocaml_protoc.exe --pp --binary --yojson --ml_out . %{file}))) diff --git a/src/runtime-services/errors.proto b/src/runtime-services/errors.proto new file mode 100644 index 00000000..6eb3b123 --- /dev/null +++ b/src/runtime-services/errors.proto @@ -0,0 +1,19 @@ +syntax = "proto3"; + +message Empty {} + +message TimeoutInfo { + // Timeout, in seconds + float timeout_s = 1; +} + +message RpcError { + oneof error { + Empty unknown_error = 0; + string transport_error = 1; + string server_error = 2; + TimeoutInfo timeout = 3; + string invalid_json = 4; + string invalid_binary = 5; + } +} diff --git a/src/runtime-services/pbrt_services.ml b/src/runtime-services/pbrt_services.ml index bcd5d880..d1090428 100644 --- a/src/runtime-services/pbrt_services.ml +++ b/src/runtime-services/pbrt_services.ml @@ -1,108 +1,76 @@ +module Errors = Errors + (** Client end of services *) module Client = struct - type error = + type error = Errors.rpc_error = + | Invalid_binary of string + | Invalid_json of string + | Timeout of Errors.timeout_info + | Server_error of string | Transport_error of string - | Timeout - | Decode_error_json of string - | Decode_error_binary of Pbrt.Decoder.error + | Unknown_error - type transport = { - query: - 'ret. - service_name:string -> - rpc_name:string -> - [ `JSON | `BINARY ] -> - string -> - on_result:((string, error) result -> 'ret) -> - 'ret; - } + let pp_error = Errors.pp_rpc_error type ('req, 'ret) rpc = { - call: - 'actual_ret. - [ `JSON | `BINARY ] -> - transport -> - 'req -> - on_result:(('ret, error) result -> 'actual_ret) -> - 'actual_ret; + service_name: string; + rpc_name: string; + encode_json_req: 'req -> Yojson.Basic.t; + encode_pb_req: 'req -> Pbrt.Encoder.t -> unit; + decode_json_res: Yojson.Basic.t -> 'ret; + decode_pb_res: Pbrt.Decoder.t -> 'ret; } let mk_rpc ~service_name ~rpc_name ~encode_json_req ~encode_pb_req ~decode_json_res ~decode_pb_res () : _ rpc = { - call = - (fun encoding (transport : transport) req ~on_result -> - let req_str = - match encoding with - | `JSON -> encode_json_req req |> Yojson.Basic.to_string - | `BINARY -> - let enc = Pbrt.Encoder.create () in - encode_pb_req req enc; - Pbrt.Encoder.to_string enc - in - - let on_result = function - | Error err -> on_result (Error err) - | Ok res_str -> - (match encoding with - | `JSON -> - (match decode_json_res @@ Yojson.Basic.from_string res_str with - | res -> on_result (Ok res) - | exception exn -> - on_result (Error (Decode_error_json (Printexc.to_string exn)))) - | `BINARY -> - let dec = Pbrt.Decoder.of_string res_str in - (match decode_pb_res dec with - | v -> on_result (Ok v) - | exception Pbrt.Decoder.Failure err -> - on_result (Error (Decode_error_binary err)))) - in - - transport.query ~service_name ~rpc_name encoding req_str ~on_result); + service_name; + rpc_name; + encode_pb_req; + encode_json_req; + decode_pb_res; + decode_json_res; } end (** Server end of services *) module Server = struct - type error = - | Invalid_json - | Invalid_pb of Pbrt.Decoder.error - | Handler_failed of string + type error = Errors.rpc_error = + | Invalid_binary of string + | Invalid_json of string + | Timeout of Errors.timeout_info + | Server_error of string + | Transport_error of string + | Unknown_error - type rpc = { - rpc_name: string; - rpc_handler: [ `JSON | `BINARY ] -> string -> (string, error) result; - } + let pp_error = Errors.pp_rpc_error + + (** A RPC endpoint. *) + type rpc = + | RPC : { + name: string; + f: 'req -> 'res; + encode_json_res: 'res -> Yojson.Basic.t; + encode_pb_res: 'res -> Pbrt.Encoder.t -> unit; + decode_json_req: Yojson.Basic.t -> 'req; + decode_pb_req: Pbrt.Decoder.t -> 'req; + } + -> rpc let mk_rpc ~name ~(f : 'req -> 'res) ~encode_json_res ~encode_pb_res ~decode_json_req ~decode_pb_req () : rpc = - let handler fmt req : _ result = - match fmt with - | `JSON -> - (match Yojson.Basic.from_string req with - | exception _ -> Error Invalid_json - | j -> - let req = decode_json_req j in - (match f req with - | res -> Ok (Yojson.Basic.to_string @@ encode_json_res res) - | exception exn -> Error (Handler_failed (Printexc.to_string exn)))) - | `BINARY -> - let decoder = Pbrt.Decoder.of_string req in - (match decode_pb_req decoder with - | exception Pbrt.Decoder.Failure e -> Error (Invalid_pb e) - | req -> - (match f req with - | res -> - let enc = Pbrt.Encoder.create () in - encode_pb_res res enc; - Ok (Pbrt.Encoder.to_string enc) - | exception exn -> Error (Handler_failed (Printexc.to_string exn)))) - in - - { rpc_name = name; rpc_handler = handler } + RPC + { + name; + f; + decode_pb_req; + decode_json_req; + encode_pb_res; + encode_json_res; + } type t = { - name: string; + service_name: string; handlers: rpc list; } end diff --git a/src/runtime-services/pbrt_services.mli b/src/runtime-services/pbrt_services.mli index a15eaa4d..f5045e3d 100644 --- a/src/runtime-services/pbrt_services.mli +++ b/src/runtime-services/pbrt_services.mli @@ -1,42 +1,30 @@ +(** Runtime for Protobuf services. *) + +module Errors = Errors +(** RPC errors. These are printable and serializable. *) + (** Service stubs, client side *) module Client : sig - type error = + type error = Errors.rpc_error = + | Invalid_binary of string + | Invalid_json of string + | Timeout of Errors.timeout_info + | Server_error of string | Transport_error of string - | Timeout - | Decode_error_json of string - | Decode_error_binary of Pbrt.Decoder.error + | Unknown_error - type transport = { - query: - 'ret. - service_name:string -> - rpc_name:string -> - [ `JSON | `BINARY ] -> - string -> - on_result:((string, error) result -> 'ret) -> - 'ret; - } - (** A transport method, ie. a way to query a remote service - by sending it a query, and register a callback to be - called when the response is received. - - The [query] function is called like so: - [transport.query ~service_name ~rpc_name encoding req ~on_result], - where [rpc_name] is the name of the method in the service [service_name], - [req] is the encoded query using [encoding], and [on_result] is a callback - that will be called after the service comes back with a response. - *) + val pp_error : Format.formatter -> error -> unit type ('req, 'ret) rpc = { - call: - 'actual_ret. - [ `JSON | `BINARY ] -> - transport -> - 'req -> - on_result:(('ret, error) result -> 'actual_ret) -> - 'actual_ret; + service_name: string; + rpc_name: string; + encode_json_req: 'req -> Yojson.Basic.t; + encode_pb_req: 'req -> Pbrt.Encoder.t -> unit; + decode_json_res: Yojson.Basic.t -> 'ret; + decode_pb_res: Pbrt.Decoder.t -> 'ret; } - (** A RPC. By calling it with a concrete transport, one gets a future result. *) + (** A RPC description. You need a transport library + that knows where to send the bytes to actually use it. *) val mk_rpc : service_name:string -> @@ -52,15 +40,27 @@ end (** Service stubs, server side *) module Server : sig (** Errors that can arise during request processing. *) - type error = - | Invalid_json - | Invalid_pb of Pbrt.Decoder.error - | Handler_failed of string + type error = Errors.rpc_error = + | Invalid_binary of string + | Invalid_json of string + | Timeout of Errors.timeout_info + | Server_error of string + | Transport_error of string + | Unknown_error - type rpc = { - rpc_name: string; - rpc_handler: [ `JSON | `BINARY ] -> string -> (string, error) result; - } + val pp_error : Format.formatter -> error -> unit + + (** A RPC endpoint. *) + type rpc = + | RPC : { + name: string; + f: 'req -> 'res; + encode_json_res: 'res -> Yojson.Basic.t; + encode_pb_res: 'res -> Pbrt.Encoder.t -> unit; + decode_json_req: Yojson.Basic.t -> 'req; + decode_pb_req: Pbrt.Decoder.t -> 'req; + } + -> rpc val mk_rpc : name:string -> @@ -73,10 +73,8 @@ module Server : sig rpc (** Helper to build a RPC *) - (** A RPC implementation. *) - type t = { - name: string; + service_name: string; handlers: rpc list; } (** A service with fixed set of methods. *) From 6f4d5e5be681e61f7e5e0edd74706438c8735843 Mon Sep 17 00:00:00 2001 From: Simon Cruanes Date: Sun, 10 Sep 2023 14:56:50 -0400 Subject: [PATCH 15/46] streamline pbrt_services, with one place where errors are defined --- src/runtime-services/pbrt_services.ml | 30 +++++++++---------------- src/runtime-services/pbrt_services.mli | 31 +++++++++----------------- 2 files changed, 20 insertions(+), 41 deletions(-) diff --git a/src/runtime-services/pbrt_services.ml b/src/runtime-services/pbrt_services.ml index d1090428..425a0ead 100644 --- a/src/runtime-services/pbrt_services.ml +++ b/src/runtime-services/pbrt_services.ml @@ -1,17 +1,17 @@ module Errors = Errors -(** Client end of services *) -module Client = struct - type error = Errors.rpc_error = - | Invalid_binary of string - | Invalid_json of string - | Timeout of Errors.timeout_info - | Server_error of string - | Transport_error of string - | Unknown_error +type rpc_error = Errors.rpc_error = + | Invalid_binary of string + | Invalid_json of string + | Timeout of Errors.timeout_info + | Server_error of string + | Transport_error of string + | Unknown_error - let pp_error = Errors.pp_rpc_error +let pp_rpc_error = Errors.pp_rpc_error +(** Client end of services *) +module Client = struct type ('req, 'ret) rpc = { service_name: string; rpc_name: string; @@ -35,16 +35,6 @@ end (** Server end of services *) module Server = struct - type error = Errors.rpc_error = - | Invalid_binary of string - | Invalid_json of string - | Timeout of Errors.timeout_info - | Server_error of string - | Transport_error of string - | Unknown_error - - let pp_error = Errors.pp_rpc_error - (** A RPC endpoint. *) type rpc = | RPC : { diff --git a/src/runtime-services/pbrt_services.mli b/src/runtime-services/pbrt_services.mli index f5045e3d..6e28ddfc 100644 --- a/src/runtime-services/pbrt_services.mli +++ b/src/runtime-services/pbrt_services.mli @@ -3,18 +3,18 @@ module Errors = Errors (** RPC errors. These are printable and serializable. *) -(** Service stubs, client side *) -module Client : sig - type error = Errors.rpc_error = - | Invalid_binary of string - | Invalid_json of string - | Timeout of Errors.timeout_info - | Server_error of string - | Transport_error of string - | Unknown_error +type rpc_error = Errors.rpc_error = + | Invalid_binary of string + | Invalid_json of string + | Timeout of Errors.timeout_info + | Server_error of string + | Transport_error of string + | Unknown_error - val pp_error : Format.formatter -> error -> unit +val pp_rpc_error : Format.formatter -> rpc_error -> unit +(** Service stubs, client side *) +module Client : sig type ('req, 'ret) rpc = { service_name: string; rpc_name: string; @@ -39,17 +39,6 @@ end (** Service stubs, server side *) module Server : sig - (** Errors that can arise during request processing. *) - type error = Errors.rpc_error = - | Invalid_binary of string - | Invalid_json of string - | Timeout of Errors.timeout_info - | Server_error of string - | Transport_error of string - | Unknown_error - - val pp_error : Format.formatter -> error -> unit - (** A RPC endpoint. *) type rpc = | RPC : { From 986a201c3e46918492c6182d868e5806b36bd125 Mon Sep 17 00:00:00 2001 From: Simon Cruanes Date: Sun, 10 Sep 2023 21:24:49 -0400 Subject: [PATCH 16/46] handle package properly in services --- src/compilerlib/pb_codegen_services.ml | 14 ++++++++++++-- src/runtime-services/pbrt_services.ml | 7 +++++-- src/runtime-services/pbrt_services.mli | 5 +++++ 3 files changed, 22 insertions(+), 4 deletions(-) diff --git a/src/compilerlib/pb_codegen_services.ml b/src/compilerlib/pb_codegen_services.ml index 1b17264e..04404bf9 100644 --- a/src/compilerlib/pb_codegen_services.ml +++ b/src/compilerlib/pb_codegen_services.ml @@ -1,6 +1,8 @@ module Ot = Pb_codegen_ocaml_type module F = Pb_codegen_formatting +let spf = Printf.sprintf + let string_of_rpc_type (ty : Ot.rpc_type) : string = let f = Pb_codegen_util.string_of_field_type in match ty with @@ -84,6 +86,9 @@ let ocaml_type_of_rpc_type (rpc : Ot.rpc_type) : string = let mod_name_for_client (service : Ot.service) : string = String.capitalize_ascii service.service_name +let string_list_of_package (path : string list) : string = + spf "[%s]" (String.concat ";" @@ List.map (fun s -> spf "%S" s) path) + let gen_service_client_struct (service : Ot.service) sc : unit = let service_name = service.service_name in F.linep sc "module %s = struct" (mod_name_for_client service); @@ -97,8 +102,11 @@ let gen_service_client_struct (service : Ot.service) sc : unit = (Pb_codegen_util.function_name_of_rpc rpc) (string_of_rpc_type rpc.rpc_req) (string_of_rpc_type rpc.rpc_res); - F.linep sc " (mk_rpc ~service_name:%S ~rpc_name:%S" - service.service_name rpc.rpc_name; + F.linep sc " (mk_rpc "; + F.linep sc " ~package:%s" + (string_list_of_package service.service_packages); + F.linep sc " ~service_name:%S ~rpc_name:%S" service.service_name + rpc.rpc_name; F.linep sc " ~encode_json_req:%s" (function_name_encode_json ~service_name ~rpc_name rpc.rpc_req); F.linep sc " ~encode_pb_req:%s" @@ -164,6 +172,8 @@ let gen_service_server_struct (service : Ot.service) sc : unit = F.sub_scope sc (fun sc -> F.line sc "let open Pbrt_services.Server in"; F.linep sc "{ service_name=%S;" service_name; + F.linep sc " package=%s;" + (string_list_of_package service.service_packages); F.line sc " handlers=["; List.iter (fun (rpc : Ot.rpc) -> diff --git a/src/runtime-services/pbrt_services.ml b/src/runtime-services/pbrt_services.ml index 425a0ead..ab7ce449 100644 --- a/src/runtime-services/pbrt_services.ml +++ b/src/runtime-services/pbrt_services.ml @@ -14,6 +14,7 @@ let pp_rpc_error = Errors.pp_rpc_error module Client = struct type ('req, 'ret) rpc = { service_name: string; + package: string list; (** Package for the service *) rpc_name: string; encode_json_req: 'req -> Yojson.Basic.t; encode_pb_req: 'req -> Pbrt.Encoder.t -> unit; @@ -21,10 +22,11 @@ module Client = struct decode_pb_res: Pbrt.Decoder.t -> 'ret; } - let mk_rpc ~service_name ~rpc_name ~encode_json_req ~encode_pb_req - ~decode_json_res ~decode_pb_res () : _ rpc = + let mk_rpc ?(package = []) ~service_name ~rpc_name ~encode_json_req + ~encode_pb_req ~decode_json_res ~decode_pb_res () : _ rpc = { service_name; + package; rpc_name; encode_pb_req; encode_json_req; @@ -61,6 +63,7 @@ module Server = struct type t = { service_name: string; + package: string list; handlers: rpc list; } end diff --git a/src/runtime-services/pbrt_services.mli b/src/runtime-services/pbrt_services.mli index 6e28ddfc..c3fb05d9 100644 --- a/src/runtime-services/pbrt_services.mli +++ b/src/runtime-services/pbrt_services.mli @@ -17,6 +17,7 @@ val pp_rpc_error : Format.formatter -> rpc_error -> unit module Client : sig type ('req, 'ret) rpc = { service_name: string; + package: string list; (** Package for the service *) rpc_name: string; encode_json_req: 'req -> Yojson.Basic.t; encode_pb_req: 'req -> Pbrt.Encoder.t -> unit; @@ -27,6 +28,7 @@ module Client : sig that knows where to send the bytes to actually use it. *) val mk_rpc : + ?package:string list -> service_name:string -> rpc_name:string -> encode_json_req:('req -> Yojson.Basic.t) -> @@ -64,6 +66,9 @@ module Server : sig type t = { service_name: string; + package: string list; + (** The package this belongs in (e.g. "bigco.auth.secretpasswordstash"), + split along "." *) handlers: rpc list; } (** A service with fixed set of methods. *) From 32caabd3e5c5e82dfcb8e5619e7069b0da2acc27 Mon Sep 17 00:00:00 2001 From: Simon Cruanes Date: Mon, 11 Sep 2023 23:39:42 -0400 Subject: [PATCH 17/46] handle streaming more seriously in pbrt_services --- src/runtime-services/pbrt_services.ml | 50 ++++++++++++++++++++---- src/runtime-services/pbrt_services.mli | 53 ++++++++++++++++++++++---- 2 files changed, 89 insertions(+), 14 deletions(-) diff --git a/src/runtime-services/pbrt_services.ml b/src/runtime-services/pbrt_services.ml index ab7ce449..84f7512e 100644 --- a/src/runtime-services/pbrt_services.ml +++ b/src/runtime-services/pbrt_services.ml @@ -10,24 +10,52 @@ type rpc_error = Errors.rpc_error = let pp_rpc_error = Errors.pp_rpc_error +module Value_mode = struct + type unary + type stream +end + +module Pull_stream = struct + type 'a t = { pull: 'ret. unit -> on_result:('a option -> 'ret) -> unit } +end + +module Push_stream = struct + type 'a t = { + push: 'a -> unit; + close: unit -> unit; + } + + let push self x = self.push x + let close self = self.close () +end + (** Client end of services *) module Client = struct - type ('req, 'ret) rpc = { + type _ mode = + | Unary : Value_mode.unary mode + | Stream : Value_mode.stream mode + + type ('req, 'req_mode, 'res, 'res_mode) rpc = { service_name: string; package: string list; (** Package for the service *) rpc_name: string; + req_mode: 'req_mode mode; + res_mode: 'res_mode mode; encode_json_req: 'req -> Yojson.Basic.t; encode_pb_req: 'req -> Pbrt.Encoder.t -> unit; - decode_json_res: Yojson.Basic.t -> 'ret; - decode_pb_res: Pbrt.Decoder.t -> 'ret; + decode_json_res: Yojson.Basic.t -> 'res; + decode_pb_res: Pbrt.Decoder.t -> 'res; } - let mk_rpc ?(package = []) ~service_name ~rpc_name ~encode_json_req - ~encode_pb_req ~decode_json_res ~decode_pb_res () : _ rpc = + let mk_rpc ?(package = []) ~service_name ~rpc_name ~req_mode ~res_mode + ~encode_json_req ~encode_pb_req ~decode_json_res ~decode_pb_res () : _ rpc + = { service_name; package; rpc_name; + req_mode; + res_mode; encode_pb_req; encode_json_req; decode_pb_res; @@ -37,11 +65,19 @@ end (** Server end of services *) module Server = struct + type ('req, 'res) handler = + | Unary : ('req -> 'res) -> ('req, 'res) handler + | Client_stream : ('req Pull_stream.t -> 'res) -> ('req, 'res) handler + | Server_stream : ('req -> 'res Push_stream.t) -> ('req, 'res) handler + | Both_stream : + ('req Pull_stream.t -> 'res Push_stream.t) + -> ('req, 'res) handler + (** A RPC endpoint. *) type rpc = | RPC : { name: string; - f: 'req -> 'res; + f: ('req, 'res) handler; encode_json_res: 'res -> Yojson.Basic.t; encode_pb_res: 'res -> Pbrt.Encoder.t -> unit; decode_json_req: Yojson.Basic.t -> 'req; @@ -49,7 +85,7 @@ module Server = struct } -> rpc - let mk_rpc ~name ~(f : 'req -> 'res) ~encode_json_res ~encode_pb_res + let mk_rpc ~name ~(f : _ handler) ~encode_json_res ~encode_pb_res ~decode_json_req ~decode_pb_req () : rpc = RPC { diff --git a/src/runtime-services/pbrt_services.mli b/src/runtime-services/pbrt_services.mli index c3fb05d9..d6305547 100644 --- a/src/runtime-services/pbrt_services.mli +++ b/src/runtime-services/pbrt_services.mli @@ -13,39 +13,78 @@ type rpc_error = Errors.rpc_error = val pp_rpc_error : Format.formatter -> rpc_error -> unit +(** Whether there's a single value or a stream of them *) +module Value_mode : sig + type unary + type stream +end + +module Pull_stream : sig + type 'a t = { pull: 'ret. unit -> on_result:('a option -> 'ret) -> unit } + (** Stream of incoming values, we can pull them out + one by one until [None] is returned. *) +end + +module Push_stream : sig + type 'a t = { + push: 'a -> unit; + close: unit -> unit; + } + (** Stream of outgoing values, we can push new ones until we close it *) + + val push : 'a t -> 'a -> unit + val close : _ t -> unit +end + (** Service stubs, client side *) module Client : sig - type ('req, 'ret) rpc = { + type _ mode = + | Unary : Value_mode.unary mode + | Stream : Value_mode.stream mode + + type ('req, 'req_mode, 'res, 'res_mode) rpc = { service_name: string; package: string list; (** Package for the service *) rpc_name: string; + req_mode: 'req_mode mode; + res_mode: 'res_mode mode; encode_json_req: 'req -> Yojson.Basic.t; encode_pb_req: 'req -> Pbrt.Encoder.t -> unit; - decode_json_res: Yojson.Basic.t -> 'ret; - decode_pb_res: Pbrt.Decoder.t -> 'ret; + decode_json_res: Yojson.Basic.t -> 'res; + decode_pb_res: Pbrt.Decoder.t -> 'res; } (** A RPC description. You need a transport library - that knows where to send the bytes to actually use it. *) + that knows where to send the bytes to actually use it. *) val mk_rpc : ?package:string list -> service_name:string -> rpc_name:string -> + req_mode:'req_mode mode -> + res_mode:'res_mode mode -> encode_json_req:('req -> Yojson.Basic.t) -> encode_pb_req:('req -> Pbrt.Encoder.t -> unit) -> decode_json_res:(Yojson.Basic.t -> 'res) -> decode_pb_res:(Pbrt.Decoder.t -> 'res) -> unit -> - ('req, 'res) rpc + ('req, 'req_mode, 'res, 'res_mode) rpc end (** Service stubs, server side *) module Server : sig + type ('req, 'res) handler = + | Unary : ('req -> 'res) -> ('req, 'res) handler + | Client_stream : ('req Pull_stream.t -> 'res) -> ('req, 'res) handler + | Server_stream : ('req -> 'res Push_stream.t) -> ('req, 'res) handler + | Both_stream : + ('req Pull_stream.t -> 'res Push_stream.t) + -> ('req, 'res) handler + (** A RPC endpoint. *) type rpc = | RPC : { name: string; - f: 'req -> 'res; + f: ('req, 'res) handler; encode_json_res: 'res -> Yojson.Basic.t; encode_pb_res: 'res -> Pbrt.Encoder.t -> unit; decode_json_req: Yojson.Basic.t -> 'req; @@ -55,7 +94,7 @@ module Server : sig val mk_rpc : name:string -> - f:('req -> 'res) -> + f:('req, 'res) handler -> encode_json_res:('res -> Yojson.Basic.t) -> encode_pb_res:('res -> Pbrt.Encoder.t -> unit) -> decode_json_req:(Yojson.Basic.t -> 'req) -> From ee4c134f57d2952963c5c2031705825604dd768b Mon Sep 17 00:00:00 2001 From: Simon Cruanes Date: Mon, 11 Sep 2023 23:39:56 -0400 Subject: [PATCH 18/46] codegen for new pbrt_services --- src/compilerlib/pb_codegen_services.ml | 73 ++++++++++++++++---------- 1 file changed, 46 insertions(+), 27 deletions(-) diff --git a/src/compilerlib/pb_codegen_services.ml b/src/compilerlib/pb_codegen_services.ml index 04404bf9..4264ce06 100644 --- a/src/compilerlib/pb_codegen_services.ml +++ b/src/compilerlib/pb_codegen_services.ml @@ -3,11 +3,17 @@ module F = Pb_codegen_formatting let spf = Printf.sprintf -let string_of_rpc_type (ty : Ot.rpc_type) : string = +let string_of_rpc_type_pull (ty : Ot.rpc_type) : string = let f = Pb_codegen_util.string_of_field_type in match ty with | Ot.Rpc_scalar ty -> f ty - | Ot.Rpc_stream ty -> Printf.sprintf "%s Seq.t" (f ty) + | Ot.Rpc_stream ty -> Printf.sprintf "%s Pbrt_services.Pull_stream.t" (f ty) + +let string_of_rpc_type_push (ty : Ot.rpc_type) : string = + let f = Pb_codegen_util.string_of_field_type in + match ty with + | Ot.Rpc_scalar ty -> f ty + | Ot.Rpc_stream ty -> Printf.sprintf "%s Pbrt_services.Push_stream.t" (f ty) let function_name_encode_json ~service_name ~rpc_name (ty : Ot.rpc_type) : string = @@ -23,8 +29,7 @@ let function_name_encode_json ~service_name ~rpc_name (ty : Ot.rpc_type) : exit 1 in match ty with - | Ot.Rpc_scalar ty -> f ty - | Ot.Rpc_stream ty -> Printf.sprintf "Seq.map %s" (f ty) + | Ot.Rpc_scalar ty | Ot.Rpc_stream ty -> f ty let function_name_decode_json ~service_name ~rpc_name (ty : Ot.rpc_type) : string = @@ -40,8 +45,7 @@ let function_name_decode_json ~service_name ~rpc_name (ty : Ot.rpc_type) : exit 1 in match ty with - | Ot.Rpc_scalar ty -> f ty - | Ot.Rpc_stream ty -> Printf.sprintf "Seq.map %s" (f ty) + | Ot.Rpc_scalar ty | Ot.Rpc_stream ty -> f ty let function_name_encode_pb ~service_name ~rpc_name (ty : Ot.rpc_type) : string = @@ -57,8 +61,7 @@ let function_name_encode_pb ~service_name ~rpc_name (ty : Ot.rpc_type) : string exit 1 in match ty with - | Ot.Rpc_scalar ty -> f ty - | Ot.Rpc_stream ty -> Printf.sprintf "Seq.map %s" (f ty) + | Ot.Rpc_scalar ty | Ot.Rpc_stream ty -> f ty let function_name_decode_pb ~service_name ~rpc_name (ty : Ot.rpc_type) : string = @@ -74,14 +77,12 @@ let function_name_decode_pb ~service_name ~rpc_name (ty : Ot.rpc_type) : string exit 1 in match ty with - | Ot.Rpc_scalar ty -> f ty - | Ot.Rpc_stream ty -> Printf.sprintf "Seq.map %s" (f ty) + | Ot.Rpc_scalar ty | Ot.Rpc_stream ty -> f ty -let ocaml_type_of_rpc_type (rpc : Ot.rpc_type) : string = +let ocaml_type_of_rpc_type (rpc : Ot.rpc_type) : string * string = match rpc with - | Rpc_scalar ty -> Pb_codegen_util.string_of_field_type ty - | Rpc_stream ty -> - Printf.sprintf "(%s Seq.t)" (Pb_codegen_util.string_of_field_type ty) + | Rpc_scalar ty -> Pb_codegen_util.string_of_field_type ty, "unary" + | Rpc_stream ty -> Pb_codegen_util.string_of_field_type ty, "stream" let mod_name_for_client (service : Ot.service) : string = String.capitalize_ascii service.service_name @@ -94,19 +95,25 @@ let gen_service_client_struct (service : Ot.service) sc : unit = F.linep sc "module %s = struct" (mod_name_for_client service); F.sub_scope sc (fun sc -> F.linep sc "open Pbrt_services.Client"; + F.linep sc "open Pbrt_services.Value_mode"; List.iter (fun (rpc : Ot.rpc) -> let rpc_name = rpc.rpc_name in + let req, req_mode = ocaml_type_of_rpc_type rpc.rpc_req in + let req_mode_witness = String.capitalize_ascii req_mode in + let res, res_mode = ocaml_type_of_rpc_type rpc.rpc_res in + let res_mode_witness = String.capitalize_ascii res_mode in F.empty_line sc; - F.linep sc "let %s : (%s, %s) rpc =" + F.linep sc "let %s : (%s, %s, %s, %s) rpc =" (Pb_codegen_util.function_name_of_rpc rpc) - (string_of_rpc_type rpc.rpc_req) - (string_of_rpc_type rpc.rpc_res); + req req_mode res res_mode; F.linep sc " (mk_rpc "; F.linep sc " ~package:%s" (string_list_of_package service.service_packages); F.linep sc " ~service_name:%S ~rpc_name:%S" service.service_name rpc.rpc_name; + F.linep sc " ~req_mode:%s" req_mode_witness; + F.linep sc " ~res_mode:%s" res_mode_witness; F.linep sc " ~encode_json_req:%s" (function_name_encode_json ~service_name ~rpc_name rpc.rpc_req); F.linep sc " ~encode_pb_req:%s" @@ -115,9 +122,9 @@ let gen_service_client_struct (service : Ot.service) sc : unit = (function_name_decode_json ~service_name ~rpc_name rpc.rpc_res); F.linep sc " ~decode_pb_res:%s" (function_name_decode_pb ~service_name ~rpc_name rpc.rpc_res); - F.linep sc "() : (%s, %s) rpc)" - (ocaml_type_of_rpc_type rpc.rpc_req) - (ocaml_type_of_rpc_type rpc.rpc_res)) + let req, req_mode = ocaml_type_of_rpc_type rpc.rpc_req in + let res, res_mode = ocaml_type_of_rpc_type rpc.rpc_res in + F.linep sc " () : (%s, %s, %s, %s) rpc)" req req_mode res res_mode) service.service_body); F.line sc "end"; @@ -128,13 +135,15 @@ let gen_service_client_sig (service : Ot.service) sc : unit = F.linep sc "module %s : sig" (mod_name_for_client service); F.sub_scope sc (fun sc -> F.linep sc "open Pbrt_services.Client"; + F.linep sc "open Pbrt_services.Value_mode"; List.iter (fun (rpc : Ot.rpc) -> F.empty_line sc; - F.linep sc "val %s : (%s, %s) rpc" + let req, req_mode = ocaml_type_of_rpc_type rpc.rpc_req in + let res, res_mode = ocaml_type_of_rpc_type rpc.rpc_res in + F.linep sc "val %s : (%s, %s, %s, %s) rpc" (Pb_codegen_util.function_name_of_rpc rpc) - (string_of_rpc_type rpc.rpc_req) - (string_of_rpc_type rpc.rpc_res)) + req req_mode res res_mode) service.service_body); F.line sc "end"; F.empty_line sc @@ -151,8 +160,8 @@ let gen_mod_type_of_service (service : Ot.service) sc : unit = (fun (rpc : Ot.rpc) -> F.linep sc "val %s : %s -> %s" (Pb_codegen_util.function_name_of_rpc rpc) - (string_of_rpc_type rpc.rpc_req) - (string_of_rpc_type rpc.rpc_res)) + (string_of_rpc_type_pull rpc.rpc_req) + (string_of_rpc_type_push rpc.rpc_res)) service.service_body); F.line sc "end" @@ -178,8 +187,18 @@ let gen_service_server_struct (service : Ot.service) sc : unit = List.iter (fun (rpc : Ot.rpc) -> let rpc_name = rpc.rpc_name in - F.linep sc " (mk_rpc ~name:%S ~f:M.%s" rpc.rpc_name - (Pb_codegen_util.function_name_of_rpc rpc); + + let handler = + let f = Pb_codegen_util.function_name_of_rpc rpc in + match rpc.rpc_req, rpc.rpc_res with + | Rpc_scalar _, Rpc_scalar _ -> spf "(Unary %s)" f + | Rpc_scalar _, Rpc_stream _ -> spf "(Server_stream %s)" f + | Rpc_stream _, Rpc_scalar _ -> spf "(Client_stream %s)" f + | Rpc_stream _, Rpc_stream _ -> spf "(Both_stream %s)" f + in + + F.linep sc " (mk_rpc ~name:%S" rpc.rpc_name; + F.linep sc " ~f:M.%s" handler; F.linep sc " ~encode_json_res:%s" (function_name_encode_json ~service_name ~rpc_name rpc.rpc_res); F.linep sc " ~encode_pb_res:%s" From fde6751fd911197952b1ccc0b33960a5cc85c191 Mon Sep 17 00:00:00 2001 From: Simon Cruanes Date: Mon, 11 Sep 2023 23:40:07 -0400 Subject: [PATCH 19/46] add an example of service with all combinations of unary/streaming --- src/examples/dune | 12 +++++++++++ src/examples/file_server.proto | 39 ++++++++++++++++++++++++++++++++++ 2 files changed, 51 insertions(+) create mode 100644 src/examples/file_server.proto diff --git a/src/examples/dune b/src/examples/dune index 20d4ae54..f4919db5 100644 --- a/src/examples/dune +++ b/src/examples/dune @@ -53,3 +53,15 @@ (modules t_calculator calculator) (package pbrt_services) (libraries pbrt pbrt_yojson pbrt_services)) + +(rule + (targets file_server.ml file_server.mli) + (deps file_server.proto) + (action + (run ocaml-protoc --binary --pp --yojson --services --ml_out ./ %{deps}))) + +(test + (name file_server) + (modules file_server) ; just check that it compiles + (package pbrt_services) + (libraries pbrt pbrt_yojson pbrt_services)) diff --git a/src/examples/file_server.proto b/src/examples/file_server.proto new file mode 100644 index 00000000..4cea46cf --- /dev/null +++ b/src/examples/file_server.proto @@ -0,0 +1,39 @@ + +// test that streaming variants all compile + +syntax = "proto3"; + +message FileChunk { + string path = 1; + bytes data = 2; + int32 crc = 3; +} + +message FilePath { + string path = 1; +} + +message FileCrc { + /// CRC of the entire file + int32 crc = 1; +} + +message Empty {} + +message Ping {} + +message Pong {} + +service FileServer { + rpc touch_file(FilePath) returns (Empty); + + /// Upload a file + rpc upload_file(stream FileChunk) returns (FileCrc); + + /// Download a file + rpc download_file(FilePath) returns (stream FileChunk); + + // keepalive + rpc ping_pong(stream Ping) returns (stream Pong); +} + From 46b49ba2b61991427d37c398ae0d24d1d6a66058 Mon Sep 17 00:00:00 2001 From: Simon Cruanes Date: Sat, 7 Oct 2023 11:26:18 -0400 Subject: [PATCH 20/46] fix warning --- src/tests/google_unittest/dune | 1 + 1 file changed, 1 insertion(+) diff --git a/src/tests/google_unittest/dune b/src/tests/google_unittest/dune index 62913197..b56b15c8 100644 --- a/src/tests/google_unittest/dune +++ b/src/tests/google_unittest/dune @@ -1,5 +1,6 @@ (test (name google_unittest) + (flags :standard -w -11) (libraries pbrt)) (rule From 506e5f99c1d546332220d9ab381ff2120d89a4c5 Mon Sep 17 00:00:00 2001 From: Simon Cruanes Date: Sun, 15 Oct 2023 23:51:07 -0400 Subject: [PATCH 21/46] change how server services are encoded. Use state machines for handlers that take client streams. --- src/compilerlib/pb_codegen_services.ml | 27 ++++++----- src/runtime-services/pbrt_services.ml | 56 ++++++++++++++--------- src/runtime-services/pbrt_services.mli | 62 ++++++++++++++++++-------- 3 files changed, 93 insertions(+), 52 deletions(-) diff --git a/src/compilerlib/pb_codegen_services.ml b/src/compilerlib/pb_codegen_services.ml index 4264ce06..ad2ed047 100644 --- a/src/compilerlib/pb_codegen_services.ml +++ b/src/compilerlib/pb_codegen_services.ml @@ -3,17 +3,17 @@ module F = Pb_codegen_formatting let spf = Printf.sprintf -let string_of_rpc_type_pull (ty : Ot.rpc_type) : string = - let f = Pb_codegen_util.string_of_field_type in - match ty with - | Ot.Rpc_scalar ty -> f ty - | Ot.Rpc_stream ty -> Printf.sprintf "%s Pbrt_services.Pull_stream.t" (f ty) - -let string_of_rpc_type_push (ty : Ot.rpc_type) : string = +let string_of_rpc_handler_type (req : Ot.rpc_type) (res : Ot.rpc_type) : string + = let f = Pb_codegen_util.string_of_field_type in - match ty with - | Ot.Rpc_scalar ty -> f ty - | Ot.Rpc_stream ty -> Printf.sprintf "%s Pbrt_services.Push_stream.t" (f ty) + match req, res with + | Ot.Rpc_scalar req, Ot.Rpc_scalar res -> spf "%s -> %s" (f req) (f res) + | Ot.Rpc_stream req, Ot.Rpc_scalar res -> + spf "(%s, %s) Pbrt_services.Server.client_stream_handler" (f req) (f res) + | Ot.Rpc_scalar req, Ot.Rpc_stream res -> + spf "(%s, %s) Pbrt_services.Server.server_stream_handler" (f req) (f res) + | Ot.Rpc_stream req, Ot.Rpc_stream res -> + spf "(%s, %s) Pbrt_services.Server.both_stream_handler" (f req) (f res) let function_name_encode_json ~service_name ~rpc_name (ty : Ot.rpc_type) : string = @@ -158,10 +158,9 @@ let gen_mod_type_of_service (service : Ot.service) sc : unit = F.sub_scope sc (fun sc -> List.iter (fun (rpc : Ot.rpc) -> - F.linep sc "val %s : %s -> %s" + F.linep sc "val %s : %s" (Pb_codegen_util.function_name_of_rpc rpc) - (string_of_rpc_type_pull rpc.rpc_req) - (string_of_rpc_type_push rpc.rpc_res)) + (string_of_rpc_handler_type rpc.rpc_req rpc.rpc_res)) service.service_body); F.line sc "end" @@ -207,7 +206,7 @@ let gen_service_server_struct (service : Ot.service) sc : unit = (function_name_decode_json ~service_name ~rpc_name rpc.rpc_req); F.linep sc " ~decode_pb_req:%s" (function_name_decode_pb ~service_name ~rpc_name rpc.rpc_req); - F.linep sc " () : rpc);") + F.linep sc " () : any_rpc);") service.service_body; F.line sc "]; }"); F.empty_line sc diff --git a/src/runtime-services/pbrt_services.ml b/src/runtime-services/pbrt_services.ml index 84f7512e..88ebd5c8 100644 --- a/src/runtime-services/pbrt_services.ml +++ b/src/runtime-services/pbrt_services.ml @@ -65,28 +65,44 @@ end (** Server end of services *) module Server = struct - type ('req, 'res) handler = - | Unary : ('req -> 'res) -> ('req, 'res) handler - | Client_stream : ('req Pull_stream.t -> 'res) -> ('req, 'res) handler - | Server_stream : ('req -> 'res Push_stream.t) -> ('req, 'res) handler - | Both_stream : - ('req Pull_stream.t -> 'res Push_stream.t) - -> ('req, 'res) handler - - (** A RPC endpoint. *) - type rpc = - | RPC : { - name: string; - f: ('req, 'res) handler; - encode_json_res: 'res -> Yojson.Basic.t; - encode_pb_res: 'res -> Pbrt.Encoder.t -> unit; - decode_json_req: Yojson.Basic.t -> 'req; - decode_pb_req: Pbrt.Decoder.t -> 'req; + type ('req, 'res) client_stream_handler = + | Client_stream_handler : { + init: unit -> 'state; + on_input: + 'state -> 'req -> [ `Update of 'state | `Return_early of 'res ]; + on_close: 'state -> 'res; } - -> rpc + -> ('req, 'res) client_stream_handler + + type ('req, 'res) server_stream_handler = 'req -> 'res Push_stream.t -> unit + + type ('req, 'res) both_stream_handler = + | Both_stream_handler : { + init: unit -> 'res Push_stream.t -> 'state; + on_input: 'state -> 'res Push_stream.t -> 'req -> 'state; + n_close: 'state -> 'res Push_stream.t -> unit; + } + -> ('req, 'res) both_stream_handler + + type ('req, 'res) handler = + | Unary of ('req -> 'res) + | Client_stream of ('req, 'res) client_stream_handler + | Server_stream of ('req, 'res) server_stream_handler + | Both_stream of ('req, 'res) both_stream_handler + + type ('req, 'res) rpc = { + name: string; + f: ('req, 'res) handler; + encode_json_res: 'res -> Yojson.Basic.t; + encode_pb_res: 'res -> Pbrt.Encoder.t -> unit; + decode_json_req: Yojson.Basic.t -> 'req; + decode_pb_req: Pbrt.Decoder.t -> 'req; + } + + type any_rpc = RPC : ('req, 'res) rpc -> any_rpc [@@unboxed] let mk_rpc ~name ~(f : _ handler) ~encode_json_res ~encode_pb_res - ~decode_json_req ~decode_pb_req () : rpc = + ~decode_json_req ~decode_pb_req () : any_rpc = RPC { name; @@ -100,6 +116,6 @@ module Server = struct type t = { service_name: string; package: string list; - handlers: rpc list; + handlers: any_rpc list; } end diff --git a/src/runtime-services/pbrt_services.mli b/src/runtime-services/pbrt_services.mli index d6305547..7d2cee40 100644 --- a/src/runtime-services/pbrt_services.mli +++ b/src/runtime-services/pbrt_services.mli @@ -72,25 +72,51 @@ end (** Service stubs, server side *) module Server : sig + (** Handler that receives a client stream *) + type ('req, 'res) client_stream_handler = + | Client_stream_handler : { + init: unit -> 'state; (** When a stream starts *) + on_input: + 'state -> 'req -> [ `Update of 'state | `Return_early of 'res ]; + (** When an element of the stream is received. This can either + update the internal state, or return a value early and + stop reading from the input stream. *) + on_close: 'state -> 'res; (** When the stream is over *) + } + -> ('req, 'res) client_stream_handler + + type ('req, 'res) server_stream_handler = 'req -> 'res Push_stream.t -> unit + (** Takes the input value and a push stream (to send items to + the caller, and then close the stream at the end). + The stream's [close] function must be called exactly once. *) + + type ('req, 'res) both_stream_handler = + | Both_stream_handler : { + init: unit -> 'res Push_stream.t -> 'state; + on_input: 'state -> 'res Push_stream.t -> 'req -> 'state; + n_close: 'state -> 'res Push_stream.t -> unit; + } + -> ('req, 'res) both_stream_handler + (** Handler taking a stream of values and returning a stream as well. *) + type ('req, 'res) handler = - | Unary : ('req -> 'res) -> ('req, 'res) handler - | Client_stream : ('req Pull_stream.t -> 'res) -> ('req, 'res) handler - | Server_stream : ('req -> 'res Push_stream.t) -> ('req, 'res) handler - | Both_stream : - ('req Pull_stream.t -> 'res Push_stream.t) - -> ('req, 'res) handler + | Unary of ('req -> 'res) + (** Simple unary handler, gets a value, returns a value. *) + | Client_stream of ('req, 'res) client_stream_handler + | Server_stream of ('req, 'res) server_stream_handler + | Both_stream of ('req, 'res) both_stream_handler + + type ('req, 'res) rpc = { + name: string; + f: ('req, 'res) handler; + encode_json_res: 'res -> Yojson.Basic.t; + encode_pb_res: 'res -> Pbrt.Encoder.t -> unit; + decode_json_req: Yojson.Basic.t -> 'req; + decode_pb_req: Pbrt.Decoder.t -> 'req; + } (** A RPC endpoint. *) - type rpc = - | RPC : { - name: string; - f: ('req, 'res) handler; - encode_json_res: 'res -> Yojson.Basic.t; - encode_pb_res: 'res -> Pbrt.Encoder.t -> unit; - decode_json_req: Yojson.Basic.t -> 'req; - decode_pb_req: Pbrt.Decoder.t -> 'req; - } - -> rpc + type any_rpc = RPC : ('req, 'res) rpc -> any_rpc [@@unboxed] val mk_rpc : name:string -> @@ -100,7 +126,7 @@ module Server : sig decode_json_req:(Yojson.Basic.t -> 'req) -> decode_pb_req:(Pbrt.Decoder.t -> 'req) -> unit -> - rpc + any_rpc (** Helper to build a RPC *) type t = { @@ -108,7 +134,7 @@ module Server : sig package: string list; (** The package this belongs in (e.g. "bigco.auth.secretpasswordstash"), split along "." *) - handlers: rpc list; + handlers: any_rpc list; } (** A service with fixed set of methods. *) end From 5854827848b377bee64ab54c58ee4421e3ee9595 Mon Sep 17 00:00:00 2001 From: Simon Cruanes Date: Tue, 17 Oct 2023 14:37:22 -0400 Subject: [PATCH 22/46] pbrt_services: promote the error module --- src/runtime-services/dune | 3 +- src/runtime-services/errors.ml | 210 ++++++++++++++++++++++++++++++++ src/runtime-services/errors.mli | 100 +++++++++++++++ 3 files changed, 312 insertions(+), 1 deletion(-) create mode 100644 src/runtime-services/errors.ml create mode 100644 src/runtime-services/errors.mli diff --git a/src/runtime-services/dune b/src/runtime-services/dune index 8c6c9614..6e925292 100644 --- a/src/runtime-services/dune +++ b/src/runtime-services/dune @@ -9,4 +9,5 @@ (rule (targets errors.ml errors.mli) (deps (:file errors.proto)) - (action (run ../ocaml-protoc/ocaml_protoc.exe --pp --binary --yojson --ml_out . %{file}))) + (mode promote) + (action (run %{bin:ocaml-protoc} --pp --binary --yojson --ml_out . %{file}))) diff --git a/src/runtime-services/errors.ml b/src/runtime-services/errors.ml new file mode 100644 index 00000000..6e012103 --- /dev/null +++ b/src/runtime-services/errors.ml @@ -0,0 +1,210 @@ +[@@@ocaml.warning "-27-30-39"] + +type empty = unit + +type timeout_info = { + timeout_s : float; +} + +type rpc_error = + | Invalid_binary of string + | Invalid_json of string + | Timeout of timeout_info + | Server_error of string + | Transport_error of string + | Unknown_error + +let rec default_empty = () + +let rec default_timeout_info + ?timeout_s:((timeout_s:float) = 0.) + () : timeout_info = { + timeout_s; +} + +let rec default_rpc_error () : rpc_error = Invalid_binary ("") + +type timeout_info_mutable = { + mutable timeout_s : float; +} + +let default_timeout_info_mutable () : timeout_info_mutable = { + timeout_s = 0.; +} + +[@@@ocaml.warning "-27-30-39"] + +(** {2 Formatters} *) + +let rec pp_empty fmt (v:empty) = + let pp_i fmt () = + Pbrt.Pp.pp_unit fmt () + in + Pbrt.Pp.pp_brk pp_i fmt () + +let rec pp_timeout_info fmt (v:timeout_info) = + let pp_i fmt () = + Pbrt.Pp.pp_record_field ~first:true "timeout_s" Pbrt.Pp.pp_float fmt v.timeout_s; + in + Pbrt.Pp.pp_brk pp_i fmt () + +let rec pp_rpc_error fmt (v:rpc_error) = + match v with + | Invalid_binary x -> Format.fprintf fmt "@[Invalid_binary(@,%a)@]" Pbrt.Pp.pp_string x + | Invalid_json x -> Format.fprintf fmt "@[Invalid_json(@,%a)@]" Pbrt.Pp.pp_string x + | Timeout x -> Format.fprintf fmt "@[Timeout(@,%a)@]" pp_timeout_info x + | Server_error x -> Format.fprintf fmt "@[Server_error(@,%a)@]" Pbrt.Pp.pp_string x + | Transport_error x -> Format.fprintf fmt "@[Transport_error(@,%a)@]" Pbrt.Pp.pp_string x + | Unknown_error -> Format.fprintf fmt "Unknown_error" + +[@@@ocaml.warning "-27-30-39"] + +(** {2 Protobuf Encoding} *) + +let rec encode_pb_empty (v:empty) encoder = +() + +let rec encode_pb_timeout_info (v:timeout_info) encoder = + Pbrt.Encoder.key (1, Pbrt.Bits32) encoder; + Pbrt.Encoder.float_as_bits32 v.timeout_s encoder; + () + +let rec encode_pb_rpc_error (v:rpc_error) encoder = + begin match v with + | Invalid_binary x -> + Pbrt.Encoder.key (5, Pbrt.Bytes) encoder; + Pbrt.Encoder.string x encoder; + | Invalid_json x -> + Pbrt.Encoder.key (4, Pbrt.Bytes) encoder; + Pbrt.Encoder.string x encoder; + | Timeout x -> + Pbrt.Encoder.key (3, Pbrt.Bytes) encoder; + Pbrt.Encoder.nested (encode_pb_timeout_info x) encoder; + | Server_error x -> + Pbrt.Encoder.key (2, Pbrt.Bytes) encoder; + Pbrt.Encoder.string x encoder; + | Transport_error x -> + Pbrt.Encoder.key (1, Pbrt.Bytes) encoder; + Pbrt.Encoder.string x encoder; + | Unknown_error -> + Pbrt.Encoder.key (0, Pbrt.Bytes) encoder; + Pbrt.Encoder.empty_nested encoder + end + +[@@@ocaml.warning "-27-30-39"] + +(** {2 Protobuf Decoding} *) + +let rec decode_pb_empty d = + match Pbrt.Decoder.key d with + | None -> (); + | Some (_, pk) -> + Pbrt.Decoder.unexpected_payload "Unexpected fields in empty message(empty)" pk + +let rec decode_pb_timeout_info d = + let v = default_timeout_info_mutable () in + let continue__= ref true in + while !continue__ do + match Pbrt.Decoder.key d with + | None -> ( + ); continue__ := false + | Some (1, Pbrt.Bits32) -> begin + v.timeout_s <- Pbrt.Decoder.float_as_bits32 d; + end + | Some (1, pk) -> + Pbrt.Decoder.unexpected_payload "Message(timeout_info), field(1)" pk + | Some (_, payload_kind) -> Pbrt.Decoder.skip d payload_kind + done; + ({ + timeout_s = v.timeout_s; + } : timeout_info) + +let rec decode_pb_rpc_error d = + let rec loop () = + let ret:rpc_error = match Pbrt.Decoder.key d with + | None -> Pbrt.Decoder.malformed_variant "rpc_error" + | Some (5, _) -> (Invalid_binary (Pbrt.Decoder.string d) : rpc_error) + | Some (4, _) -> (Invalid_json (Pbrt.Decoder.string d) : rpc_error) + | Some (3, _) -> (Timeout (decode_pb_timeout_info (Pbrt.Decoder.nested d)) : rpc_error) + | Some (2, _) -> (Server_error (Pbrt.Decoder.string d) : rpc_error) + | Some (1, _) -> (Transport_error (Pbrt.Decoder.string d) : rpc_error) + | Some (0, _) -> begin + Pbrt.Decoder.empty_nested d ; + (Unknown_error : rpc_error) + end + | Some (n, payload_kind) -> ( + Pbrt.Decoder.skip d payload_kind; + loop () + ) + in + ret + in + loop () + +[@@@ocaml.warning "-27-30-39"] + +(** {2 Protobuf YoJson Encoding} *) + +let rec encode_json_empty (v:empty) = +Pbrt_yojson.make_unit v + +let rec encode_json_timeout_info (v:timeout_info) = + let assoc = [] in + let assoc = ("timeoutS", Pbrt_yojson.make_float v.timeout_s) :: assoc in + `Assoc assoc + +let rec encode_json_rpc_error (v:rpc_error) = + begin match v with + | Invalid_binary v -> `Assoc [("invalidBinary", Pbrt_yojson.make_string v)] + | Invalid_json v -> `Assoc [("invalidJson", Pbrt_yojson.make_string v)] + | Timeout v -> `Assoc [("timeout", encode_json_timeout_info v)] + | Server_error v -> `Assoc [("serverError", Pbrt_yojson.make_string v)] + | Transport_error v -> `Assoc [("transportError", Pbrt_yojson.make_string v)] + | Unknown_error -> `Assoc [("unknownError", `Null)] + end + +[@@@ocaml.warning "-27-30-39"] + +(** {2 JSON Decoding} *) + +let rec decode_json_empty d = +Pbrt_yojson.unit d "empty" "empty record" + +let rec decode_json_timeout_info d = + let v = default_timeout_info_mutable () in + let assoc = match d with + | `Assoc assoc -> assoc + | _ -> assert(false) + in + List.iter (function + | ("timeoutS", json_value) -> + v.timeout_s <- Pbrt_yojson.float json_value "timeout_info" "timeout_s" + + | (_, _) -> () (*Unknown fields are ignored*) + ) assoc; + ({ + timeout_s = v.timeout_s; + } : timeout_info) + +let rec decode_json_rpc_error json = + let assoc = match json with + | `Assoc assoc -> assoc + | _ -> assert(false) + in + let rec loop = function + | [] -> Pbrt_yojson.E.malformed_variant "rpc_error" + | ("invalidBinary", json_value)::_ -> + (Invalid_binary (Pbrt_yojson.string json_value "rpc_error" "Invalid_binary") : rpc_error) + | ("invalidJson", json_value)::_ -> + (Invalid_json (Pbrt_yojson.string json_value "rpc_error" "Invalid_json") : rpc_error) + | ("timeout", json_value)::_ -> + (Timeout ((decode_json_timeout_info json_value)) : rpc_error) + | ("serverError", json_value)::_ -> + (Server_error (Pbrt_yojson.string json_value "rpc_error" "Server_error") : rpc_error) + | ("transportError", json_value)::_ -> + (Transport_error (Pbrt_yojson.string json_value "rpc_error" "Transport_error") : rpc_error) + | ("unknownError", _)::_-> (Unknown_error : rpc_error) + + | _ :: tl -> loop tl + in + loop assoc diff --git a/src/runtime-services/errors.mli b/src/runtime-services/errors.mli new file mode 100644 index 00000000..b3bfbd90 --- /dev/null +++ b/src/runtime-services/errors.mli @@ -0,0 +1,100 @@ + +(** Code for errors.proto *) + +(* generated from "errors.proto", do not edit *) + + + +(** {2 Types} *) + +type empty = unit + +type timeout_info = { + timeout_s : float; +} + +type rpc_error = + | Invalid_binary of string + | Invalid_json of string + | Timeout of timeout_info + | Server_error of string + | Transport_error of string + | Unknown_error + + +(** {2 Basic values} *) + +val default_empty : unit +(** [default_empty ()] is the default value for type [empty] *) + +val default_timeout_info : + ?timeout_s:float -> + unit -> + timeout_info +(** [default_timeout_info ()] is the default value for type [timeout_info] *) + +val default_rpc_error : unit -> rpc_error +(** [default_rpc_error ()] is the default value for type [rpc_error] *) + + +(** {2 Formatters} *) + +val pp_empty : Format.formatter -> empty -> unit +(** [pp_empty v] formats v *) + +val pp_timeout_info : Format.formatter -> timeout_info -> unit +(** [pp_timeout_info v] formats v *) + +val pp_rpc_error : Format.formatter -> rpc_error -> unit +(** [pp_rpc_error v] formats v *) + + +(** {2 Protobuf Encoding} *) + +val encode_pb_empty : empty -> Pbrt.Encoder.t -> unit +(** [encode_pb_empty v encoder] encodes [v] with the given [encoder] *) + +val encode_pb_timeout_info : timeout_info -> Pbrt.Encoder.t -> unit +(** [encode_pb_timeout_info v encoder] encodes [v] with the given [encoder] *) + +val encode_pb_rpc_error : rpc_error -> Pbrt.Encoder.t -> unit +(** [encode_pb_rpc_error v encoder] encodes [v] with the given [encoder] *) + + +(** {2 Protobuf Decoding} *) + +val decode_pb_empty : Pbrt.Decoder.t -> empty +(** [decode_pb_empty decoder] decodes a [empty] binary value from [decoder] *) + +val decode_pb_timeout_info : Pbrt.Decoder.t -> timeout_info +(** [decode_pb_timeout_info decoder] decodes a [timeout_info] binary value from [decoder] *) + +val decode_pb_rpc_error : Pbrt.Decoder.t -> rpc_error +(** [decode_pb_rpc_error decoder] decodes a [rpc_error] binary value from [decoder] *) + + +(** {2 Protobuf YoJson Encoding} *) + +val encode_json_empty : empty -> Yojson.Basic.t +(** [encode_json_empty v encoder] encodes [v] to to json *) + +val encode_json_timeout_info : timeout_info -> Yojson.Basic.t +(** [encode_json_timeout_info v encoder] encodes [v] to to json *) + +val encode_json_rpc_error : rpc_error -> Yojson.Basic.t +(** [encode_json_rpc_error v encoder] encodes [v] to to json *) + + +(** {2 JSON Decoding} *) + +val decode_json_empty : Yojson.Basic.t -> empty +(** [decode_json_empty decoder] decodes a [empty] value from [decoder] *) + +val decode_json_timeout_info : Yojson.Basic.t -> timeout_info +(** [decode_json_timeout_info decoder] decodes a [timeout_info] value from [decoder] *) + +val decode_json_rpc_error : Yojson.Basic.t -> rpc_error +(** [decode_json_rpc_error decoder] decodes a [rpc_error] value from [decoder] *) + + +(** {2 Services} *) From 60013c6829deae0e25846be21caa103ff9637294 Mon Sep 17 00:00:00 2001 From: Simon Cruanes Date: Tue, 17 Oct 2023 14:38:20 -0400 Subject: [PATCH 23/46] remove errors from pbrt_services altogether --- src/runtime-services/dune | 6 - src/runtime-services/errors.ml | 210 ------------------------- src/runtime-services/errors.mli | 100 ------------ src/runtime-services/errors.proto | 19 --- src/runtime-services/pbrt_services.ml | 12 -- src/runtime-services/pbrt_services.mli | 13 -- 6 files changed, 360 deletions(-) delete mode 100644 src/runtime-services/errors.ml delete mode 100644 src/runtime-services/errors.mli delete mode 100644 src/runtime-services/errors.proto diff --git a/src/runtime-services/dune b/src/runtime-services/dune index 6e925292..77e192ed 100644 --- a/src/runtime-services/dune +++ b/src/runtime-services/dune @@ -5,9 +5,3 @@ (wrapped true) (synopsis "Runtime library for services generated by ocaml-protoc") (libraries pbrt pbrt_yojson yojson)) - -(rule - (targets errors.ml errors.mli) - (deps (:file errors.proto)) - (mode promote) - (action (run %{bin:ocaml-protoc} --pp --binary --yojson --ml_out . %{file}))) diff --git a/src/runtime-services/errors.ml b/src/runtime-services/errors.ml deleted file mode 100644 index 6e012103..00000000 --- a/src/runtime-services/errors.ml +++ /dev/null @@ -1,210 +0,0 @@ -[@@@ocaml.warning "-27-30-39"] - -type empty = unit - -type timeout_info = { - timeout_s : float; -} - -type rpc_error = - | Invalid_binary of string - | Invalid_json of string - | Timeout of timeout_info - | Server_error of string - | Transport_error of string - | Unknown_error - -let rec default_empty = () - -let rec default_timeout_info - ?timeout_s:((timeout_s:float) = 0.) - () : timeout_info = { - timeout_s; -} - -let rec default_rpc_error () : rpc_error = Invalid_binary ("") - -type timeout_info_mutable = { - mutable timeout_s : float; -} - -let default_timeout_info_mutable () : timeout_info_mutable = { - timeout_s = 0.; -} - -[@@@ocaml.warning "-27-30-39"] - -(** {2 Formatters} *) - -let rec pp_empty fmt (v:empty) = - let pp_i fmt () = - Pbrt.Pp.pp_unit fmt () - in - Pbrt.Pp.pp_brk pp_i fmt () - -let rec pp_timeout_info fmt (v:timeout_info) = - let pp_i fmt () = - Pbrt.Pp.pp_record_field ~first:true "timeout_s" Pbrt.Pp.pp_float fmt v.timeout_s; - in - Pbrt.Pp.pp_brk pp_i fmt () - -let rec pp_rpc_error fmt (v:rpc_error) = - match v with - | Invalid_binary x -> Format.fprintf fmt "@[Invalid_binary(@,%a)@]" Pbrt.Pp.pp_string x - | Invalid_json x -> Format.fprintf fmt "@[Invalid_json(@,%a)@]" Pbrt.Pp.pp_string x - | Timeout x -> Format.fprintf fmt "@[Timeout(@,%a)@]" pp_timeout_info x - | Server_error x -> Format.fprintf fmt "@[Server_error(@,%a)@]" Pbrt.Pp.pp_string x - | Transport_error x -> Format.fprintf fmt "@[Transport_error(@,%a)@]" Pbrt.Pp.pp_string x - | Unknown_error -> Format.fprintf fmt "Unknown_error" - -[@@@ocaml.warning "-27-30-39"] - -(** {2 Protobuf Encoding} *) - -let rec encode_pb_empty (v:empty) encoder = -() - -let rec encode_pb_timeout_info (v:timeout_info) encoder = - Pbrt.Encoder.key (1, Pbrt.Bits32) encoder; - Pbrt.Encoder.float_as_bits32 v.timeout_s encoder; - () - -let rec encode_pb_rpc_error (v:rpc_error) encoder = - begin match v with - | Invalid_binary x -> - Pbrt.Encoder.key (5, Pbrt.Bytes) encoder; - Pbrt.Encoder.string x encoder; - | Invalid_json x -> - Pbrt.Encoder.key (4, Pbrt.Bytes) encoder; - Pbrt.Encoder.string x encoder; - | Timeout x -> - Pbrt.Encoder.key (3, Pbrt.Bytes) encoder; - Pbrt.Encoder.nested (encode_pb_timeout_info x) encoder; - | Server_error x -> - Pbrt.Encoder.key (2, Pbrt.Bytes) encoder; - Pbrt.Encoder.string x encoder; - | Transport_error x -> - Pbrt.Encoder.key (1, Pbrt.Bytes) encoder; - Pbrt.Encoder.string x encoder; - | Unknown_error -> - Pbrt.Encoder.key (0, Pbrt.Bytes) encoder; - Pbrt.Encoder.empty_nested encoder - end - -[@@@ocaml.warning "-27-30-39"] - -(** {2 Protobuf Decoding} *) - -let rec decode_pb_empty d = - match Pbrt.Decoder.key d with - | None -> (); - | Some (_, pk) -> - Pbrt.Decoder.unexpected_payload "Unexpected fields in empty message(empty)" pk - -let rec decode_pb_timeout_info d = - let v = default_timeout_info_mutable () in - let continue__= ref true in - while !continue__ do - match Pbrt.Decoder.key d with - | None -> ( - ); continue__ := false - | Some (1, Pbrt.Bits32) -> begin - v.timeout_s <- Pbrt.Decoder.float_as_bits32 d; - end - | Some (1, pk) -> - Pbrt.Decoder.unexpected_payload "Message(timeout_info), field(1)" pk - | Some (_, payload_kind) -> Pbrt.Decoder.skip d payload_kind - done; - ({ - timeout_s = v.timeout_s; - } : timeout_info) - -let rec decode_pb_rpc_error d = - let rec loop () = - let ret:rpc_error = match Pbrt.Decoder.key d with - | None -> Pbrt.Decoder.malformed_variant "rpc_error" - | Some (5, _) -> (Invalid_binary (Pbrt.Decoder.string d) : rpc_error) - | Some (4, _) -> (Invalid_json (Pbrt.Decoder.string d) : rpc_error) - | Some (3, _) -> (Timeout (decode_pb_timeout_info (Pbrt.Decoder.nested d)) : rpc_error) - | Some (2, _) -> (Server_error (Pbrt.Decoder.string d) : rpc_error) - | Some (1, _) -> (Transport_error (Pbrt.Decoder.string d) : rpc_error) - | Some (0, _) -> begin - Pbrt.Decoder.empty_nested d ; - (Unknown_error : rpc_error) - end - | Some (n, payload_kind) -> ( - Pbrt.Decoder.skip d payload_kind; - loop () - ) - in - ret - in - loop () - -[@@@ocaml.warning "-27-30-39"] - -(** {2 Protobuf YoJson Encoding} *) - -let rec encode_json_empty (v:empty) = -Pbrt_yojson.make_unit v - -let rec encode_json_timeout_info (v:timeout_info) = - let assoc = [] in - let assoc = ("timeoutS", Pbrt_yojson.make_float v.timeout_s) :: assoc in - `Assoc assoc - -let rec encode_json_rpc_error (v:rpc_error) = - begin match v with - | Invalid_binary v -> `Assoc [("invalidBinary", Pbrt_yojson.make_string v)] - | Invalid_json v -> `Assoc [("invalidJson", Pbrt_yojson.make_string v)] - | Timeout v -> `Assoc [("timeout", encode_json_timeout_info v)] - | Server_error v -> `Assoc [("serverError", Pbrt_yojson.make_string v)] - | Transport_error v -> `Assoc [("transportError", Pbrt_yojson.make_string v)] - | Unknown_error -> `Assoc [("unknownError", `Null)] - end - -[@@@ocaml.warning "-27-30-39"] - -(** {2 JSON Decoding} *) - -let rec decode_json_empty d = -Pbrt_yojson.unit d "empty" "empty record" - -let rec decode_json_timeout_info d = - let v = default_timeout_info_mutable () in - let assoc = match d with - | `Assoc assoc -> assoc - | _ -> assert(false) - in - List.iter (function - | ("timeoutS", json_value) -> - v.timeout_s <- Pbrt_yojson.float json_value "timeout_info" "timeout_s" - - | (_, _) -> () (*Unknown fields are ignored*) - ) assoc; - ({ - timeout_s = v.timeout_s; - } : timeout_info) - -let rec decode_json_rpc_error json = - let assoc = match json with - | `Assoc assoc -> assoc - | _ -> assert(false) - in - let rec loop = function - | [] -> Pbrt_yojson.E.malformed_variant "rpc_error" - | ("invalidBinary", json_value)::_ -> - (Invalid_binary (Pbrt_yojson.string json_value "rpc_error" "Invalid_binary") : rpc_error) - | ("invalidJson", json_value)::_ -> - (Invalid_json (Pbrt_yojson.string json_value "rpc_error" "Invalid_json") : rpc_error) - | ("timeout", json_value)::_ -> - (Timeout ((decode_json_timeout_info json_value)) : rpc_error) - | ("serverError", json_value)::_ -> - (Server_error (Pbrt_yojson.string json_value "rpc_error" "Server_error") : rpc_error) - | ("transportError", json_value)::_ -> - (Transport_error (Pbrt_yojson.string json_value "rpc_error" "Transport_error") : rpc_error) - | ("unknownError", _)::_-> (Unknown_error : rpc_error) - - | _ :: tl -> loop tl - in - loop assoc diff --git a/src/runtime-services/errors.mli b/src/runtime-services/errors.mli deleted file mode 100644 index b3bfbd90..00000000 --- a/src/runtime-services/errors.mli +++ /dev/null @@ -1,100 +0,0 @@ - -(** Code for errors.proto *) - -(* generated from "errors.proto", do not edit *) - - - -(** {2 Types} *) - -type empty = unit - -type timeout_info = { - timeout_s : float; -} - -type rpc_error = - | Invalid_binary of string - | Invalid_json of string - | Timeout of timeout_info - | Server_error of string - | Transport_error of string - | Unknown_error - - -(** {2 Basic values} *) - -val default_empty : unit -(** [default_empty ()] is the default value for type [empty] *) - -val default_timeout_info : - ?timeout_s:float -> - unit -> - timeout_info -(** [default_timeout_info ()] is the default value for type [timeout_info] *) - -val default_rpc_error : unit -> rpc_error -(** [default_rpc_error ()] is the default value for type [rpc_error] *) - - -(** {2 Formatters} *) - -val pp_empty : Format.formatter -> empty -> unit -(** [pp_empty v] formats v *) - -val pp_timeout_info : Format.formatter -> timeout_info -> unit -(** [pp_timeout_info v] formats v *) - -val pp_rpc_error : Format.formatter -> rpc_error -> unit -(** [pp_rpc_error v] formats v *) - - -(** {2 Protobuf Encoding} *) - -val encode_pb_empty : empty -> Pbrt.Encoder.t -> unit -(** [encode_pb_empty v encoder] encodes [v] with the given [encoder] *) - -val encode_pb_timeout_info : timeout_info -> Pbrt.Encoder.t -> unit -(** [encode_pb_timeout_info v encoder] encodes [v] with the given [encoder] *) - -val encode_pb_rpc_error : rpc_error -> Pbrt.Encoder.t -> unit -(** [encode_pb_rpc_error v encoder] encodes [v] with the given [encoder] *) - - -(** {2 Protobuf Decoding} *) - -val decode_pb_empty : Pbrt.Decoder.t -> empty -(** [decode_pb_empty decoder] decodes a [empty] binary value from [decoder] *) - -val decode_pb_timeout_info : Pbrt.Decoder.t -> timeout_info -(** [decode_pb_timeout_info decoder] decodes a [timeout_info] binary value from [decoder] *) - -val decode_pb_rpc_error : Pbrt.Decoder.t -> rpc_error -(** [decode_pb_rpc_error decoder] decodes a [rpc_error] binary value from [decoder] *) - - -(** {2 Protobuf YoJson Encoding} *) - -val encode_json_empty : empty -> Yojson.Basic.t -(** [encode_json_empty v encoder] encodes [v] to to json *) - -val encode_json_timeout_info : timeout_info -> Yojson.Basic.t -(** [encode_json_timeout_info v encoder] encodes [v] to to json *) - -val encode_json_rpc_error : rpc_error -> Yojson.Basic.t -(** [encode_json_rpc_error v encoder] encodes [v] to to json *) - - -(** {2 JSON Decoding} *) - -val decode_json_empty : Yojson.Basic.t -> empty -(** [decode_json_empty decoder] decodes a [empty] value from [decoder] *) - -val decode_json_timeout_info : Yojson.Basic.t -> timeout_info -(** [decode_json_timeout_info decoder] decodes a [timeout_info] value from [decoder] *) - -val decode_json_rpc_error : Yojson.Basic.t -> rpc_error -(** [decode_json_rpc_error decoder] decodes a [rpc_error] value from [decoder] *) - - -(** {2 Services} *) diff --git a/src/runtime-services/errors.proto b/src/runtime-services/errors.proto deleted file mode 100644 index 6eb3b123..00000000 --- a/src/runtime-services/errors.proto +++ /dev/null @@ -1,19 +0,0 @@ -syntax = "proto3"; - -message Empty {} - -message TimeoutInfo { - // Timeout, in seconds - float timeout_s = 1; -} - -message RpcError { - oneof error { - Empty unknown_error = 0; - string transport_error = 1; - string server_error = 2; - TimeoutInfo timeout = 3; - string invalid_json = 4; - string invalid_binary = 5; - } -} diff --git a/src/runtime-services/pbrt_services.ml b/src/runtime-services/pbrt_services.ml index 88ebd5c8..b39c7e71 100644 --- a/src/runtime-services/pbrt_services.ml +++ b/src/runtime-services/pbrt_services.ml @@ -1,15 +1,3 @@ -module Errors = Errors - -type rpc_error = Errors.rpc_error = - | Invalid_binary of string - | Invalid_json of string - | Timeout of Errors.timeout_info - | Server_error of string - | Transport_error of string - | Unknown_error - -let pp_rpc_error = Errors.pp_rpc_error - module Value_mode = struct type unary type stream diff --git a/src/runtime-services/pbrt_services.mli b/src/runtime-services/pbrt_services.mli index 7d2cee40..d9c29f4d 100644 --- a/src/runtime-services/pbrt_services.mli +++ b/src/runtime-services/pbrt_services.mli @@ -1,18 +1,5 @@ (** Runtime for Protobuf services. *) -module Errors = Errors -(** RPC errors. These are printable and serializable. *) - -type rpc_error = Errors.rpc_error = - | Invalid_binary of string - | Invalid_json of string - | Timeout of Errors.timeout_info - | Server_error of string - | Transport_error of string - | Unknown_error - -val pp_rpc_error : Format.formatter -> rpc_error -> unit - (** Whether there's a single value or a stream of them *) module Value_mode : sig type unary From 75ae0099a38301e76f46efc4d7ea2d4302abd818 Mon Sep 17 00:00:00 2001 From: Simon Cruanes Date: Wed, 18 Oct 2023 22:11:40 -0400 Subject: [PATCH 24/46] update pbrt_services for stream-taking handlers --- src/compilerlib/pb_codegen_services.ml | 5 +-- src/runtime-services/pbrt_services.ml | 35 ++++++++++++-------- src/runtime-services/pbrt_services.mli | 46 ++++++++++++++++---------- 3 files changed, 52 insertions(+), 34 deletions(-) diff --git a/src/compilerlib/pb_codegen_services.ml b/src/compilerlib/pb_codegen_services.ml index ad2ed047..58f97827 100644 --- a/src/compilerlib/pb_codegen_services.ml +++ b/src/compilerlib/pb_codegen_services.ml @@ -13,7 +13,8 @@ let string_of_rpc_handler_type (req : Ot.rpc_type) (res : Ot.rpc_type) : string | Ot.Rpc_scalar req, Ot.Rpc_stream res -> spf "(%s, %s) Pbrt_services.Server.server_stream_handler" (f req) (f res) | Ot.Rpc_stream req, Ot.Rpc_stream res -> - spf "(%s, %s) Pbrt_services.Server.both_stream_handler" (f req) (f res) + spf "(%s, %s) Pbrt_services.Server.bidirectional_stream_handler" (f req) + (f res) let function_name_encode_json ~service_name ~rpc_name (ty : Ot.rpc_type) : string = @@ -193,7 +194,7 @@ let gen_service_server_struct (service : Ot.service) sc : unit = | Rpc_scalar _, Rpc_scalar _ -> spf "(Unary %s)" f | Rpc_scalar _, Rpc_stream _ -> spf "(Server_stream %s)" f | Rpc_stream _, Rpc_scalar _ -> spf "(Client_stream %s)" f - | Rpc_stream _, Rpc_stream _ -> spf "(Both_stream %s)" f + | Rpc_stream _, Rpc_stream _ -> spf "(Bidirectional_stream %s)" f in F.linep sc " (mk_rpc ~name:%S" rpc.rpc_name; diff --git a/src/runtime-services/pbrt_services.ml b/src/runtime-services/pbrt_services.ml index b39c7e71..3815d266 100644 --- a/src/runtime-services/pbrt_services.ml +++ b/src/runtime-services/pbrt_services.ml @@ -53,30 +53,37 @@ end (** Server end of services *) module Server = struct + type ('req, 'res, 'state) client_stream_handler_with_state = { + init: unit -> 'state; + on_item: 'state -> 'req -> unit; + on_close: 'state -> 'res; + } + type ('req, 'res) client_stream_handler = - | Client_stream_handler : { - init: unit -> 'state; - on_input: - 'state -> 'req -> [ `Update of 'state | `Return_early of 'res ]; - on_close: 'state -> 'res; - } + | Client_stream_handler : + ('req, 'res, 'state) client_stream_handler_with_state -> ('req, 'res) client_stream_handler + [@@unboxed] type ('req, 'res) server_stream_handler = 'req -> 'res Push_stream.t -> unit - type ('req, 'res) both_stream_handler = - | Both_stream_handler : { - init: unit -> 'res Push_stream.t -> 'state; - on_input: 'state -> 'res Push_stream.t -> 'req -> 'state; - n_close: 'state -> 'res Push_stream.t -> unit; - } - -> ('req, 'res) both_stream_handler + type ('req, 'res, 'state) bidirectional_stream_handler_with_state = { + init: unit -> 'res Push_stream.t -> 'state; + on_item: 'state -> 'req -> unit; + on_close: 'state -> unit; + } + + type ('req, 'res) bidirectional_stream_handler = + | Bidirectional_stream_handler : + ('req, 'res, 'state) bidirectional_stream_handler_with_state + -> ('req, 'res) bidirectional_stream_handler + [@@unboxed] type ('req, 'res) handler = | Unary of ('req -> 'res) | Client_stream of ('req, 'res) client_stream_handler | Server_stream of ('req, 'res) server_stream_handler - | Both_stream of ('req, 'res) both_stream_handler + | Bidirectional_stream of ('req, 'res) bidirectional_stream_handler type ('req, 'res) rpc = { name: string; diff --git a/src/runtime-services/pbrt_services.mli b/src/runtime-services/pbrt_services.mli index d9c29f4d..682e5e2b 100644 --- a/src/runtime-services/pbrt_services.mli +++ b/src/runtime-services/pbrt_services.mli @@ -59,39 +59,49 @@ end (** Service stubs, server side *) module Server : sig + type ('req, 'res, 'state) client_stream_handler_with_state = { + init: unit -> 'state; (** When a stream starts *) + on_item: 'state -> 'req -> unit; + (** When an element of the stream is received. This can either + update the internal state by mutation, performing side effects, + or choose to return a value early and stop reading from the input stream. *) + on_close: 'state -> 'res; (** When the stream is over *) + } (** Handler that receives a client stream *) + type ('req, 'res) client_stream_handler = - | Client_stream_handler : { - init: unit -> 'state; (** When a stream starts *) - on_input: - 'state -> 'req -> [ `Update of 'state | `Return_early of 'res ]; - (** When an element of the stream is received. This can either - update the internal state, or return a value early and - stop reading from the input stream. *) - on_close: 'state -> 'res; (** When the stream is over *) - } + | Client_stream_handler : + ('req, 'res, 'state) client_stream_handler_with_state -> ('req, 'res) client_stream_handler + [@@unboxed] type ('req, 'res) server_stream_handler = 'req -> 'res Push_stream.t -> unit (** Takes the input value and a push stream (to send items to the caller, and then close the stream at the end). The stream's [close] function must be called exactly once. *) - type ('req, 'res) both_stream_handler = - | Both_stream_handler : { - init: unit -> 'res Push_stream.t -> 'state; - on_input: 'state -> 'res Push_stream.t -> 'req -> 'state; - n_close: 'state -> 'res Push_stream.t -> unit; - } - -> ('req, 'res) both_stream_handler - (** Handler taking a stream of values and returning a stream as well. *) + type ('req, 'res, 'state) bidirectional_stream_handler_with_state = { + init: unit -> 'res Push_stream.t -> 'state; + on_item: 'state -> 'req -> unit; + on_close: 'state -> unit; + } + (** Handler taking a stream of values and returning a stream as well. *) + + type ('req, 'res) bidirectional_stream_handler = + | Bidirectional_stream_handler : + ('req, 'res, 'state) bidirectional_stream_handler_with_state + -> ('req, 'res) bidirectional_stream_handler + [@@unboxed] + (** A handler, i.e the server side implementation of a single RPC method. + Handlers come in various flavors because they make take, or return, + streams of values. *) type ('req, 'res) handler = | Unary of ('req -> 'res) (** Simple unary handler, gets a value, returns a value. *) | Client_stream of ('req, 'res) client_stream_handler | Server_stream of ('req, 'res) server_stream_handler - | Both_stream of ('req, 'res) both_stream_handler + | Bidirectional_stream of ('req, 'res) bidirectional_stream_handler type ('req, 'res) rpc = { name: string; From c7fca43aa4ad7fe24acf69219a08e85b803e92c3 Mon Sep 17 00:00:00 2001 From: Simon Cruanes Date: Wed, 18 Oct 2023 22:19:28 -0400 Subject: [PATCH 25/46] more doc --- src/runtime-services/pbrt_services.mli | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/src/runtime-services/pbrt_services.mli b/src/runtime-services/pbrt_services.mli index 682e5e2b..766a23b2 100644 --- a/src/runtime-services/pbrt_services.mli +++ b/src/runtime-services/pbrt_services.mli @@ -67,8 +67,11 @@ module Server : sig or choose to return a value early and stop reading from the input stream. *) on_close: 'state -> 'res; (** When the stream is over *) } - (** Handler that receives a client stream *) + (** Handler that receives a client stream and produces a value at the end. + It has an internal (mutable) state that is updated + every time an item is received from the client. *) + (** A client stream handler with hidden internal state. *) type ('req, 'res) client_stream_handler = | Client_stream_handler : ('req, 'res, 'state) client_stream_handler_with_state @@ -85,8 +88,11 @@ module Server : sig on_item: 'state -> 'req -> unit; on_close: 'state -> unit; } - (** Handler taking a stream of values and returning a stream as well. *) + (** Handler taking a stream of values and returning a stream as well. + It has an internal (mutable) state that can be updated everytime + an item is received from the client. *) + (** A bidirectional handler with the internal state hidden *) type ('req, 'res) bidirectional_stream_handler = | Bidirectional_stream_handler : ('req, 'res, 'state) bidirectional_stream_handler_with_state @@ -100,8 +106,11 @@ module Server : sig | Unary of ('req -> 'res) (** Simple unary handler, gets a value, returns a value. *) | Client_stream of ('req, 'res) client_stream_handler + (** Handler that takes a client stream *) | Server_stream of ('req, 'res) server_stream_handler + (** Handler that returns a stream to the client *) | Bidirectional_stream of ('req, 'res) bidirectional_stream_handler + (** Handler that takes and returns a stream *) type ('req, 'res) rpc = { name: string; @@ -111,6 +120,8 @@ module Server : sig decode_json_req: Yojson.Basic.t -> 'req; decode_pb_req: Pbrt.Decoder.t -> 'req; } + (** A single RPC method, alongside encoders and decoders for + input and output types. . *) (** A RPC endpoint. *) type any_rpc = RPC : ('req, 'res) rpc -> any_rpc [@@unboxed] From 94d6e411823f17d83aeb6394cab80158c338bbe6 Mon Sep 17 00:00:00 2001 From: Simon Cruanes Date: Wed, 18 Oct 2023 22:32:48 -0400 Subject: [PATCH 26/46] pbrt_services: remove redundant .ml file --- src/runtime-services/dune | 1 + src/runtime-services/pbrt_services.ml | 116 ------------------------- src/runtime-services/pbrt_services.mli | 11 +-- src/runtime-services/push_stream.ml | 7 ++ src/runtime-services/push_stream.mli | 10 +++ 5 files changed, 19 insertions(+), 126 deletions(-) delete mode 100644 src/runtime-services/pbrt_services.ml create mode 100644 src/runtime-services/push_stream.ml create mode 100644 src/runtime-services/push_stream.mli diff --git a/src/runtime-services/dune b/src/runtime-services/dune index 77e192ed..7075ad1f 100644 --- a/src/runtime-services/dune +++ b/src/runtime-services/dune @@ -4,4 +4,5 @@ (public_name pbrt_services) (wrapped true) (synopsis "Runtime library for services generated by ocaml-protoc") + (modules_without_implementation pbrt_services) (libraries pbrt pbrt_yojson yojson)) diff --git a/src/runtime-services/pbrt_services.ml b/src/runtime-services/pbrt_services.ml deleted file mode 100644 index 3815d266..00000000 --- a/src/runtime-services/pbrt_services.ml +++ /dev/null @@ -1,116 +0,0 @@ -module Value_mode = struct - type unary - type stream -end - -module Pull_stream = struct - type 'a t = { pull: 'ret. unit -> on_result:('a option -> 'ret) -> unit } -end - -module Push_stream = struct - type 'a t = { - push: 'a -> unit; - close: unit -> unit; - } - - let push self x = self.push x - let close self = self.close () -end - -(** Client end of services *) -module Client = struct - type _ mode = - | Unary : Value_mode.unary mode - | Stream : Value_mode.stream mode - - type ('req, 'req_mode, 'res, 'res_mode) rpc = { - service_name: string; - package: string list; (** Package for the service *) - rpc_name: string; - req_mode: 'req_mode mode; - res_mode: 'res_mode mode; - encode_json_req: 'req -> Yojson.Basic.t; - encode_pb_req: 'req -> Pbrt.Encoder.t -> unit; - decode_json_res: Yojson.Basic.t -> 'res; - decode_pb_res: Pbrt.Decoder.t -> 'res; - } - - let mk_rpc ?(package = []) ~service_name ~rpc_name ~req_mode ~res_mode - ~encode_json_req ~encode_pb_req ~decode_json_res ~decode_pb_res () : _ rpc - = - { - service_name; - package; - rpc_name; - req_mode; - res_mode; - encode_pb_req; - encode_json_req; - decode_pb_res; - decode_json_res; - } -end - -(** Server end of services *) -module Server = struct - type ('req, 'res, 'state) client_stream_handler_with_state = { - init: unit -> 'state; - on_item: 'state -> 'req -> unit; - on_close: 'state -> 'res; - } - - type ('req, 'res) client_stream_handler = - | Client_stream_handler : - ('req, 'res, 'state) client_stream_handler_with_state - -> ('req, 'res) client_stream_handler - [@@unboxed] - - type ('req, 'res) server_stream_handler = 'req -> 'res Push_stream.t -> unit - - type ('req, 'res, 'state) bidirectional_stream_handler_with_state = { - init: unit -> 'res Push_stream.t -> 'state; - on_item: 'state -> 'req -> unit; - on_close: 'state -> unit; - } - - type ('req, 'res) bidirectional_stream_handler = - | Bidirectional_stream_handler : - ('req, 'res, 'state) bidirectional_stream_handler_with_state - -> ('req, 'res) bidirectional_stream_handler - [@@unboxed] - - type ('req, 'res) handler = - | Unary of ('req -> 'res) - | Client_stream of ('req, 'res) client_stream_handler - | Server_stream of ('req, 'res) server_stream_handler - | Bidirectional_stream of ('req, 'res) bidirectional_stream_handler - - type ('req, 'res) rpc = { - name: string; - f: ('req, 'res) handler; - encode_json_res: 'res -> Yojson.Basic.t; - encode_pb_res: 'res -> Pbrt.Encoder.t -> unit; - decode_json_req: Yojson.Basic.t -> 'req; - decode_pb_req: Pbrt.Decoder.t -> 'req; - } - - type any_rpc = RPC : ('req, 'res) rpc -> any_rpc [@@unboxed] - - let mk_rpc ~name ~(f : _ handler) ~encode_json_res ~encode_pb_res - ~decode_json_req ~decode_pb_req () : any_rpc = - RPC - { - name; - f; - decode_pb_req; - decode_json_req; - encode_pb_res; - encode_json_res; - } - - type t = { - service_name: string; - package: string list; - handlers: any_rpc list; - } -end diff --git a/src/runtime-services/pbrt_services.mli b/src/runtime-services/pbrt_services.mli index 766a23b2..cf78c1c1 100644 --- a/src/runtime-services/pbrt_services.mli +++ b/src/runtime-services/pbrt_services.mli @@ -12,16 +12,7 @@ module Pull_stream : sig one by one until [None] is returned. *) end -module Push_stream : sig - type 'a t = { - push: 'a -> unit; - close: unit -> unit; - } - (** Stream of outgoing values, we can push new ones until we close it *) - - val push : 'a t -> 'a -> unit - val close : _ t -> unit -end +module Push_stream = Push_stream (** Service stubs, client side *) module Client : sig diff --git a/src/runtime-services/push_stream.ml b/src/runtime-services/push_stream.ml new file mode 100644 index 00000000..7487943e --- /dev/null +++ b/src/runtime-services/push_stream.ml @@ -0,0 +1,7 @@ +type 'a t = { + push: 'a -> unit; + close: unit -> unit; +} + +let push self x = self.push x +let close self = self.close () diff --git a/src/runtime-services/push_stream.mli b/src/runtime-services/push_stream.mli new file mode 100644 index 00000000..d62e8c11 --- /dev/null +++ b/src/runtime-services/push_stream.mli @@ -0,0 +1,10 @@ +(** Producer end of a stream, into which we can push values *) + +type 'a t = { + push: 'a -> unit; + close: unit -> unit; +} +(** Stream of outgoing values, we can push new ones until we close it *) + +val push : 'a t -> 'a -> unit +val close : _ t -> unit From 837128684bc201960fcb8633308106092c15696c Mon Sep 17 00:00:00 2001 From: Simon Cruanes Date: Wed, 18 Oct 2023 22:34:36 -0400 Subject: [PATCH 27/46] remove pull stream --- src/runtime-services/pbrt_services.mli | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/runtime-services/pbrt_services.mli b/src/runtime-services/pbrt_services.mli index cf78c1c1..0ea5f543 100644 --- a/src/runtime-services/pbrt_services.mli +++ b/src/runtime-services/pbrt_services.mli @@ -6,12 +6,6 @@ module Value_mode : sig type stream end -module Pull_stream : sig - type 'a t = { pull: 'ret. unit -> on_result:('a option -> 'ret) -> unit } - (** Stream of incoming values, we can pull them out - one by one until [None] is returned. *) -end - module Push_stream = Push_stream (** Service stubs, client side *) From aa8e55cb1381cda75b8a2143b5524d45926152f0 Mon Sep 17 00:00:00 2001 From: Simon Cruanes Date: Wed, 18 Oct 2023 22:58:10 -0400 Subject: [PATCH 28/46] need .ml file apparently? it seems that `modules_without_implementation` isn't enough, once you try to use Pbrt_services in a library. --- src/runtime-services/dune | 1 - .../{pbrt_services.mli => pbrt_services.ml} | 73 +++++++++++++------ 2 files changed, 49 insertions(+), 25 deletions(-) rename src/runtime-services/{pbrt_services.mli => pbrt_services.ml} (74%) diff --git a/src/runtime-services/dune b/src/runtime-services/dune index 7075ad1f..77e192ed 100644 --- a/src/runtime-services/dune +++ b/src/runtime-services/dune @@ -4,5 +4,4 @@ (public_name pbrt_services) (wrapped true) (synopsis "Runtime library for services generated by ocaml-protoc") - (modules_without_implementation pbrt_services) (libraries pbrt pbrt_yojson yojson)) diff --git a/src/runtime-services/pbrt_services.mli b/src/runtime-services/pbrt_services.ml similarity index 74% rename from src/runtime-services/pbrt_services.mli rename to src/runtime-services/pbrt_services.ml index 0ea5f543..1821bd50 100644 --- a/src/runtime-services/pbrt_services.mli +++ b/src/runtime-services/pbrt_services.ml @@ -1,7 +1,7 @@ (** Runtime for Protobuf services. *) (** Whether there's a single value or a stream of them *) -module Value_mode : sig +module Value_mode = struct type unary type stream end @@ -9,7 +9,7 @@ end module Push_stream = Push_stream (** Service stubs, client side *) -module Client : sig +module Client = struct type _ mode = | Unary : Value_mode.unary mode | Stream : Value_mode.stream mode @@ -28,22 +28,36 @@ module Client : sig (** A RPC description. You need a transport library that knows where to send the bytes to actually use it. *) - val mk_rpc : - ?package:string list -> - service_name:string -> - rpc_name:string -> - req_mode:'req_mode mode -> - res_mode:'res_mode mode -> - encode_json_req:('req -> Yojson.Basic.t) -> - encode_pb_req:('req -> Pbrt.Encoder.t -> unit) -> - decode_json_res:(Yojson.Basic.t -> 'res) -> - decode_pb_res:(Pbrt.Decoder.t -> 'res) -> - unit -> - ('req, 'req_mode, 'res, 'res_mode) rpc + let mk_rpc : + ?package:string list -> + service_name:string -> + rpc_name:string -> + req_mode:'req_mode mode -> + res_mode:'res_mode mode -> + encode_json_req:('req -> Yojson.Basic.t) -> + encode_pb_req:('req -> Pbrt.Encoder.t -> unit) -> + decode_json_res:(Yojson.Basic.t -> 'res) -> + decode_pb_res:(Pbrt.Decoder.t -> 'res) -> + unit -> + ('req, 'req_mode, 'res, 'res_mode) rpc = + fun ?(package = []) ~service_name ~rpc_name ~req_mode ~res_mode + ~encode_json_req ~encode_pb_req ~decode_json_res ~decode_pb_res () : + _ rpc -> + { + service_name; + package; + rpc_name; + req_mode; + res_mode; + encode_pb_req; + encode_json_req; + decode_pb_res; + decode_json_res; + } end (** Service stubs, server side *) -module Server : sig +module Server = struct type ('req, 'res, 'state) client_stream_handler_with_state = { init: unit -> 'state; (** When a stream starts *) on_item: 'state -> 'req -> unit; @@ -111,16 +125,27 @@ module Server : sig (** A RPC endpoint. *) type any_rpc = RPC : ('req, 'res) rpc -> any_rpc [@@unboxed] - val mk_rpc : - name:string -> - f:('req, 'res) handler -> - encode_json_res:('res -> Yojson.Basic.t) -> - encode_pb_res:('res -> Pbrt.Encoder.t -> unit) -> - decode_json_req:(Yojson.Basic.t -> 'req) -> - decode_pb_req:(Pbrt.Decoder.t -> 'req) -> - unit -> - any_rpc (** Helper to build a RPC *) + let mk_rpc : + name:string -> + f:('req, 'res) handler -> + encode_json_res:('res -> Yojson.Basic.t) -> + encode_pb_res:('res -> Pbrt.Encoder.t -> unit) -> + decode_json_req:(Yojson.Basic.t -> 'req) -> + decode_pb_req:(Pbrt.Decoder.t -> 'req) -> + unit -> + any_rpc = + fun ~name ~(f : _ handler) ~encode_json_res ~encode_pb_res ~decode_json_req + ~decode_pb_req () : any_rpc -> + RPC + { + name; + f; + decode_pb_req; + decode_json_req; + encode_pb_res; + encode_json_res; + } type t = { service_name: string; From 2d00624eefcd202fd004c4623dfb2aff8675ae53 Mon Sep 17 00:00:00 2001 From: Simon Cruanes Date: Thu, 26 Oct 2023 22:27:29 -0400 Subject: [PATCH 29/46] radical simplification of server side of Pbrt_services we don't try to decide what streams or futures look like. Let the RPC libraries do it. Instead we just provide a way to turn all the methods in a service into uniform handlers. --- src/compilerlib/pb_codegen_all.ml | 6 +- src/compilerlib/pb_codegen_services.ml | 245 ++++++++++++------------ src/compilerlib/pb_codegen_services.mli | 10 +- src/runtime-services/pbrt_services.ml | 106 ++++------ 4 files changed, 157 insertions(+), 210 deletions(-) diff --git a/src/compilerlib/pb_codegen_all.ml b/src/compilerlib/pb_codegen_all.ml index b23df735..17e733c0 100644 --- a/src/compilerlib/pb_codegen_all.ml +++ b/src/compilerlib/pb_codegen_all.ml @@ -123,12 +123,10 @@ let generate_mutable_records (self : ocaml_mod) ocaml_types : unit = ocaml_types let generate_service_struct service sc : unit = - Pb_codegen_services.gen_service_client_struct service sc; - Pb_codegen_services.gen_service_server_struct service sc + Pb_codegen_services.gen_service_struct service sc let generate_service_sig service sc : unit = - Pb_codegen_services.gen_service_client_sig service sc; - Pb_codegen_services.gen_service_server_sig service sc + Pb_codegen_services.gen_service_sig service sc let generate_services (self : ocaml_mod) services : unit = generate_for_all_services services self.ml generate_service_struct None; diff --git a/src/compilerlib/pb_codegen_services.ml b/src/compilerlib/pb_codegen_services.ml index 58f97827..86e8db09 100644 --- a/src/compilerlib/pb_codegen_services.ml +++ b/src/compilerlib/pb_codegen_services.ml @@ -3,18 +3,15 @@ module F = Pb_codegen_formatting let spf = Printf.sprintf -let string_of_rpc_handler_type (req : Ot.rpc_type) (res : Ot.rpc_type) : string - = - let f = Pb_codegen_util.string_of_field_type in - match req, res with - | Ot.Rpc_scalar req, Ot.Rpc_scalar res -> spf "%s -> %s" (f req) (f res) - | Ot.Rpc_stream req, Ot.Rpc_scalar res -> - spf "(%s, %s) Pbrt_services.Server.client_stream_handler" (f req) (f res) - | Ot.Rpc_scalar req, Ot.Rpc_stream res -> - spf "(%s, %s) Pbrt_services.Server.server_stream_handler" (f req) (f res) - | Ot.Rpc_stream req, Ot.Rpc_stream res -> - spf "(%s, %s) Pbrt_services.Server.bidirectional_stream_handler" (f req) - (f res) +let ocaml_type_of_rpc_type (rpc : Ot.rpc_type) : string * string = + match rpc with + | Rpc_scalar ty -> Pb_codegen_util.string_of_field_type ty, "unary" + | Rpc_stream ty -> Pb_codegen_util.string_of_field_type ty, "stream" + +let string_of_server_rpc (req : Ot.rpc_type) (res : Ot.rpc_type) : string = + let req, req_mode = ocaml_type_of_rpc_type req in + let res, res_mode = ocaml_type_of_rpc_type res in + spf "(%s, %s, %s, %s) Server.rpc" req req_mode res res_mode let function_name_encode_json ~service_name ~rpc_name (ty : Ot.rpc_type) : string = @@ -80,11 +77,6 @@ let function_name_decode_pb ~service_name ~rpc_name (ty : Ot.rpc_type) : string match ty with | Ot.Rpc_scalar ty | Ot.Rpc_stream ty -> f ty -let ocaml_type_of_rpc_type (rpc : Ot.rpc_type) : string * string = - match rpc with - | Rpc_scalar ty -> Pb_codegen_util.string_of_field_type ty, "unary" - | Rpc_stream ty -> Pb_codegen_util.string_of_field_type ty, "stream" - let mod_name_for_client (service : Ot.service) : string = String.capitalize_ascii service.service_name @@ -93,136 +85,137 @@ let string_list_of_package (path : string list) : string = let gen_service_client_struct (service : Ot.service) sc : unit = let service_name = service.service_name in + List.iter + (fun (rpc : Ot.rpc) -> + let rpc_name = rpc.rpc_name in + let req, req_mode = ocaml_type_of_rpc_type rpc.rpc_req in + let req_mode_witness = String.capitalize_ascii req_mode in + let res, res_mode = ocaml_type_of_rpc_type rpc.rpc_res in + let res_mode_witness = String.capitalize_ascii res_mode in + F.empty_line sc; + F.linep sc "let %s : (%s, %s, %s, %s) Client.rpc =" + (Pb_codegen_util.function_name_of_rpc rpc) + req req_mode res res_mode; + F.linep sc " (Client.mk_rpc "; + F.linep sc " ~package:%s" + (string_list_of_package service.service_packages); + F.linep sc " ~service_name:%S ~rpc_name:%S" service.service_name + rpc.rpc_name; + F.linep sc " ~req_mode:%s" req_mode_witness; + F.linep sc " ~res_mode:%s" res_mode_witness; + F.linep sc " ~encode_json_req:%s" + (function_name_encode_json ~service_name ~rpc_name rpc.rpc_req); + F.linep sc " ~encode_pb_req:%s" + (function_name_encode_pb ~service_name ~rpc_name rpc.rpc_req); + F.linep sc " ~decode_json_res:%s" + (function_name_decode_json ~service_name ~rpc_name rpc.rpc_res); + F.linep sc " ~decode_pb_res:%s" + (function_name_decode_pb ~service_name ~rpc_name rpc.rpc_res); + let req, req_mode = ocaml_type_of_rpc_type rpc.rpc_req in + let res, res_mode = ocaml_type_of_rpc_type rpc.rpc_res in + F.linep sc " () : (%s, %s, %s, %s) Client.rpc)" req req_mode res + res_mode) + service.service_body + +let gen_service_server_struct (service : Ot.service) sc : unit = + let service_name = service.service_name in + + (* generate rpc descriptions for the server side *) + List.iter + (fun (rpc : Ot.rpc) -> + F.empty_line sc; + let rpc_name = rpc.rpc_name in + let name = Pb_codegen_util.function_name_of_rpc rpc in + let req, req_mode = ocaml_type_of_rpc_type rpc.rpc_req in + let res, res_mode = ocaml_type_of_rpc_type rpc.rpc_res in + let req_mode_witness = String.capitalize_ascii req_mode in + let res_mode_witness = String.capitalize_ascii res_mode in + + F.linep sc "let _rpc_%s : (%s,%s,%s,%s) Server.rpc = " name req req_mode + res res_mode; + F.linep sc " (Server.mk_rpc ~name:%S" rpc.rpc_name; + F.linep sc " ~req_mode:%s ~res_mode:%s" req_mode_witness + res_mode_witness; + F.linep sc " ~encode_json_res:%s" + (function_name_encode_json ~service_name ~rpc_name rpc.rpc_res); + F.linep sc " ~encode_pb_res:%s" + (function_name_encode_pb ~service_name ~rpc_name rpc.rpc_res); + F.linep sc " ~decode_json_req:%s" + (function_name_decode_json ~service_name ~rpc_name rpc.rpc_req); + F.linep sc " ~decode_pb_req:%s" + (function_name_decode_pb ~service_name ~rpc_name rpc.rpc_req); + F.linep sc " () : _ Server.rpc)") + service.service_body; + + (* now generate a function from the module type to a [Service_server.t] *) + F.empty_line sc; + F.linep sc "let make_server"; + List.iter + (fun (rpc : Ot.rpc) -> + let name = Pb_codegen_util.function_name_of_rpc rpc in + F.linep sc " ~%s" name) + service.service_body; + F.line sc " () : _ Server.t ="; + F.linep sc " { Server."; + F.linep sc " service_name=%S;" service_name; + F.linep sc " package=%s;" + (string_list_of_package service.service_packages); + F.line sc " handlers=["; + List.iter + (fun (rpc : Ot.rpc) -> + let f = Pb_codegen_util.function_name_of_rpc rpc in + F.linep sc " {Server.name=%S; handle=%s %s};" rpc.rpc_name f + (spf "_rpc_%s" f)) + service.service_body; + F.line sc " ];"; + F.line sc " }"; + F.empty_line sc + +let gen_service_struct (service : Ot.service) sc : unit = F.linep sc "module %s = struct" (mod_name_for_client service); F.sub_scope sc (fun sc -> - F.linep sc "open Pbrt_services.Client"; + F.linep sc "open Pbrt_services"; F.linep sc "open Pbrt_services.Value_mode"; - List.iter - (fun (rpc : Ot.rpc) -> - let rpc_name = rpc.rpc_name in - let req, req_mode = ocaml_type_of_rpc_type rpc.rpc_req in - let req_mode_witness = String.capitalize_ascii req_mode in - let res, res_mode = ocaml_type_of_rpc_type rpc.rpc_res in - let res_mode_witness = String.capitalize_ascii res_mode in - F.empty_line sc; - F.linep sc "let %s : (%s, %s, %s, %s) rpc =" - (Pb_codegen_util.function_name_of_rpc rpc) - req req_mode res res_mode; - F.linep sc " (mk_rpc "; - F.linep sc " ~package:%s" - (string_list_of_package service.service_packages); - F.linep sc " ~service_name:%S ~rpc_name:%S" service.service_name - rpc.rpc_name; - F.linep sc " ~req_mode:%s" req_mode_witness; - F.linep sc " ~res_mode:%s" res_mode_witness; - F.linep sc " ~encode_json_req:%s" - (function_name_encode_json ~service_name ~rpc_name rpc.rpc_req); - F.linep sc " ~encode_pb_req:%s" - (function_name_encode_pb ~service_name ~rpc_name rpc.rpc_req); - F.linep sc " ~decode_json_res:%s" - (function_name_decode_json ~service_name ~rpc_name rpc.rpc_res); - F.linep sc " ~decode_pb_res:%s" - (function_name_decode_pb ~service_name ~rpc_name rpc.rpc_res); - let req, req_mode = ocaml_type_of_rpc_type rpc.rpc_req in - let res, res_mode = ocaml_type_of_rpc_type rpc.rpc_res in - F.linep sc " () : (%s, %s, %s, %s) rpc)" req req_mode res res_mode) - service.service_body); + + gen_service_client_struct service sc; + + (* now the server side *) + gen_service_server_struct service sc); F.line sc "end"; F.empty_line sc -let gen_service_client_sig (service : Ot.service) sc : unit = - F.linep sc "(** Client for %s *)" service.service_name; +let gen_service_sig (service : Ot.service) sc : unit = + F.linep sc "(** %s service *)" service.service_name; F.linep sc "module %s : sig" (mod_name_for_client service); F.sub_scope sc (fun sc -> - F.linep sc "open Pbrt_services.Client"; + F.linep sc "open Pbrt_services"; F.linep sc "open Pbrt_services.Value_mode"; + + (* client *) List.iter (fun (rpc : Ot.rpc) -> F.empty_line sc; let req, req_mode = ocaml_type_of_rpc_type rpc.rpc_req in let res, res_mode = ocaml_type_of_rpc_type rpc.rpc_res in - F.linep sc "val %s : (%s, %s, %s, %s) rpc" + F.linep sc "val %s : (%s, %s, %s, %s) Client.rpc" (Pb_codegen_util.function_name_of_rpc rpc) req req_mode res res_mode) - service.service_body); - F.line sc "end"; - F.empty_line sc - -(** generate the module type for the server (shared between .ml and .mli) *) -let gen_mod_type_of_service (service : Ot.service) sc : unit = - let mod_type_name = - Pb_codegen_util.module_type_name_of_service_server service - in + service.service_body; - F.linep sc "module type %s = sig" mod_type_name; - F.sub_scope sc (fun sc -> + (* server *) + F.empty_line sc; + F.line sc "(** Produce a server implementation from handlers *)"; + F.linep sc "val make_server : "; List.iter (fun (rpc : Ot.rpc) -> - F.linep sc "val %s : %s" + F.linep sc " %s:(%s -> 'handler) ->" (Pb_codegen_util.function_name_of_rpc rpc) - (string_of_rpc_handler_type rpc.rpc_req rpc.rpc_res)) - service.service_body); - F.line sc "end" - -let gen_service_server_struct (service : Ot.service) sc : unit = - let service_name = service.service_name in - let mod_type_name = - Pb_codegen_util.module_type_name_of_service_server service - in - - gen_mod_type_of_service service sc; - F.empty_line sc; - - (* now generate a function from the module type to a [Service_server.t] *) - F.linep sc "let service_impl_of_%s (module M:%s) : Pbrt_services.Server.t =" - (String.lowercase_ascii service_name) - mod_type_name; - F.sub_scope sc (fun sc -> - F.line sc "let open Pbrt_services.Server in"; - F.linep sc "{ service_name=%S;" service_name; - F.linep sc " package=%s;" - (string_list_of_package service.service_packages); - F.line sc " handlers=["; - List.iter - (fun (rpc : Ot.rpc) -> - let rpc_name = rpc.rpc_name in - - let handler = - let f = Pb_codegen_util.function_name_of_rpc rpc in - match rpc.rpc_req, rpc.rpc_res with - | Rpc_scalar _, Rpc_scalar _ -> spf "(Unary %s)" f - | Rpc_scalar _, Rpc_stream _ -> spf "(Server_stream %s)" f - | Rpc_stream _, Rpc_scalar _ -> spf "(Client_stream %s)" f - | Rpc_stream _, Rpc_stream _ -> spf "(Bidirectional_stream %s)" f - in - - F.linep sc " (mk_rpc ~name:%S" rpc.rpc_name; - F.linep sc " ~f:M.%s" handler; - F.linep sc " ~encode_json_res:%s" - (function_name_encode_json ~service_name ~rpc_name rpc.rpc_res); - F.linep sc " ~encode_pb_res:%s" - (function_name_encode_pb ~service_name ~rpc_name rpc.rpc_res); - F.linep sc " ~decode_json_req:%s" - (function_name_decode_json ~service_name ~rpc_name rpc.rpc_req); - F.linep sc " ~decode_pb_req:%s" - (function_name_decode_pb ~service_name ~rpc_name rpc.rpc_req); - F.linep sc " () : any_rpc);") + (string_of_server_rpc rpc.rpc_req rpc.rpc_res)) service.service_body; - F.line sc "]; }"); - F.empty_line sc + F.linep sc " unit -> 'handler Server.t"; -let gen_service_server_sig service sc : unit = - let mod_type_name = - Pb_codegen_util.module_type_name_of_service_server service - in + ()); - F.linep sc "(** Server interface for %s *)" service.service_name; - gen_mod_type_of_service service sc; - F.empty_line sc; - - F.linep sc "(** Convert {!%s} to a generic runtime service *)" mod_type_name; - F.linep sc "val service_impl_of_%s : (module %s) -> Pbrt_services.Server.t" - (String.lowercase_ascii service.service_name) - mod_type_name; - () + F.line sc "end"; + F.empty_line sc diff --git a/src/compilerlib/pb_codegen_services.mli b/src/compilerlib/pb_codegen_services.mli index 52cd9a1c..997b4c57 100644 --- a/src/compilerlib/pb_codegen_services.mli +++ b/src/compilerlib/pb_codegen_services.mli @@ -1,11 +1,5 @@ -val gen_service_client_sig : +val gen_service_sig : Pb_codegen_ocaml_type.service -> Pb_codegen_formatting.scope -> unit -val gen_service_client_struct : - Pb_codegen_ocaml_type.service -> Pb_codegen_formatting.scope -> unit - -val gen_service_server_sig : - Pb_codegen_ocaml_type.service -> Pb_codegen_formatting.scope -> unit - -val gen_service_server_struct : +val gen_service_struct : Pb_codegen_ocaml_type.service -> Pb_codegen_formatting.scope -> unit diff --git a/src/runtime-services/pbrt_services.ml b/src/runtime-services/pbrt_services.ml index 1821bd50..909c5e41 100644 --- a/src/runtime-services/pbrt_services.ml +++ b/src/runtime-services/pbrt_services.ml @@ -58,62 +58,14 @@ end (** Service stubs, server side *) module Server = struct - type ('req, 'res, 'state) client_stream_handler_with_state = { - init: unit -> 'state; (** When a stream starts *) - on_item: 'state -> 'req -> unit; - (** When an element of the stream is received. This can either - update the internal state by mutation, performing side effects, - or choose to return a value early and stop reading from the input stream. *) - on_close: 'state -> 'res; (** When the stream is over *) - } - (** Handler that receives a client stream and produces a value at the end. - It has an internal (mutable) state that is updated - every time an item is received from the client. *) - - (** A client stream handler with hidden internal state. *) - type ('req, 'res) client_stream_handler = - | Client_stream_handler : - ('req, 'res, 'state) client_stream_handler_with_state - -> ('req, 'res) client_stream_handler - [@@unboxed] - - type ('req, 'res) server_stream_handler = 'req -> 'res Push_stream.t -> unit - (** Takes the input value and a push stream (to send items to - the caller, and then close the stream at the end). - The stream's [close] function must be called exactly once. *) - - type ('req, 'res, 'state) bidirectional_stream_handler_with_state = { - init: unit -> 'res Push_stream.t -> 'state; - on_item: 'state -> 'req -> unit; - on_close: 'state -> unit; - } - (** Handler taking a stream of values and returning a stream as well. - It has an internal (mutable) state that can be updated everytime - an item is received from the client. *) - - (** A bidirectional handler with the internal state hidden *) - type ('req, 'res) bidirectional_stream_handler = - | Bidirectional_stream_handler : - ('req, 'res, 'state) bidirectional_stream_handler_with_state - -> ('req, 'res) bidirectional_stream_handler - [@@unboxed] - - (** A handler, i.e the server side implementation of a single RPC method. - Handlers come in various flavors because they make take, or return, - streams of values. *) - type ('req, 'res) handler = - | Unary of ('req -> 'res) - (** Simple unary handler, gets a value, returns a value. *) - | Client_stream of ('req, 'res) client_stream_handler - (** Handler that takes a client stream *) - | Server_stream of ('req, 'res) server_stream_handler - (** Handler that returns a stream to the client *) - | Bidirectional_stream of ('req, 'res) bidirectional_stream_handler - (** Handler that takes and returns a stream *) + type 'm mode = 'm Client.mode = + | Unary : Value_mode.unary mode + | Stream : Value_mode.stream mode - type ('req, 'res) rpc = { + type ('req, 'req_mode, 'res, 'res_mode) rpc = { name: string; - f: ('req, 'res) handler; + req_mode: 'req_mode mode; + res_mode: 'res_mode mode; encode_json_res: 'res -> Yojson.Basic.t; encode_pb_res: 'res -> Pbrt.Encoder.t -> unit; decode_json_req: Yojson.Basic.t -> 'req; @@ -123,36 +75,46 @@ module Server = struct input and output types. . *) (** A RPC endpoint. *) - type any_rpc = RPC : ('req, 'res) rpc -> any_rpc [@@unboxed] + type any_rpc = RPC : ('req, 'req_mode, 'res, 'res_mode) rpc -> any_rpc + [@@unboxed] (** Helper to build a RPC *) let mk_rpc : name:string -> - f:('req, 'res) handler -> + req_mode:'req_mode mode -> + res_mode:'res_mode mode -> encode_json_res:('res -> Yojson.Basic.t) -> encode_pb_res:('res -> Pbrt.Encoder.t -> unit) -> decode_json_req:(Yojson.Basic.t -> 'req) -> decode_pb_req:(Pbrt.Decoder.t -> 'req) -> unit -> - any_rpc = - fun ~name ~(f : _ handler) ~encode_json_res ~encode_pb_res ~decode_json_req - ~decode_pb_req () : any_rpc -> - RPC - { - name; - f; - decode_pb_req; - decode_json_req; - encode_pb_res; - encode_json_res; - } + ('req, 'req_mode, 'res, 'res_mode) rpc = + fun ~name ~req_mode ~res_mode ~encode_json_res ~encode_pb_res + ~decode_json_req ~decode_pb_req () -> + { + name; + req_mode; + res_mode; + decode_pb_req; + decode_json_req; + encode_pb_res; + encode_json_res; + } - type t = { - service_name: string; + type 'h handler = { + name: string; + handle: 'h; + } + (** A handler of some runtime-specific type. This + might be synchronous, monadic, streaming, etc. *) + + type 'h t = { + service_name: string; (** Name of the service *) package: string list; (** The package this belongs in (e.g. "bigco.auth.secretpasswordstash"), split along "." *) - handlers: any_rpc list; + handlers: 'h handler list; (** A list of handlers *) } - (** A service with fixed set of methods. *) + (** A service with fixed set of methods, which depends on the concrete RPC + implementation. *) end From fefa233486701077e0cb0ce0ebb1aab3ce844d9a Mon Sep 17 00:00:00 2001 From: Simon Cruanes Date: Thu, 26 Oct 2023 22:41:02 -0400 Subject: [PATCH 30/46] refine code generation of service handlers --- src/compilerlib/pb_codegen_services.ml | 20 +++++++++----------- src/runtime-services/pbrt_services.ml | 11 ++--------- 2 files changed, 11 insertions(+), 20 deletions(-) diff --git a/src/compilerlib/pb_codegen_services.ml b/src/compilerlib/pb_codegen_services.ml index 86e8db09..0e5ca554 100644 --- a/src/compilerlib/pb_codegen_services.ml +++ b/src/compilerlib/pb_codegen_services.ml @@ -101,8 +101,8 @@ let gen_service_client_struct (service : Ot.service) sc : unit = (string_list_of_package service.service_packages); F.linep sc " ~service_name:%S ~rpc_name:%S" service.service_name rpc.rpc_name; - F.linep sc " ~req_mode:%s" req_mode_witness; - F.linep sc " ~res_mode:%s" res_mode_witness; + F.linep sc " ~req_mode:Client.%s" req_mode_witness; + F.linep sc " ~res_mode:Client.%s" res_mode_witness; F.linep sc " ~encode_json_req:%s" (function_name_encode_json ~service_name ~rpc_name rpc.rpc_req); F.linep sc " ~encode_pb_req:%s" @@ -134,8 +134,8 @@ let gen_service_server_struct (service : Ot.service) sc : unit = F.linep sc "let _rpc_%s : (%s,%s,%s,%s) Server.rpc = " name req req_mode res res_mode; F.linep sc " (Server.mk_rpc ~name:%S" rpc.rpc_name; - F.linep sc " ~req_mode:%s ~res_mode:%s" req_mode_witness - res_mode_witness; + F.linep sc " ~req_mode:Server.%s" req_mode_witness; + F.linep sc " ~res_mode:Server.%s" res_mode_witness; F.linep sc " ~encode_json_res:%s" (function_name_encode_json ~service_name ~rpc_name rpc.rpc_res); F.linep sc " ~encode_pb_res:%s" @@ -156,16 +156,14 @@ let gen_service_server_struct (service : Ot.service) sc : unit = F.linep sc " ~%s" name) service.service_body; F.line sc " () : _ Server.t ="; - F.linep sc " { Server."; - F.linep sc " service_name=%S;" service_name; - F.linep sc " package=%s;" - (string_list_of_package service.service_packages); - F.line sc " handlers=["; + F.linep sc " { Server."; + F.linep sc " service_name=%S;" service_name; + F.linep sc " package=%s;" (string_list_of_package service.service_packages); + F.line sc " handlers=["; List.iter (fun (rpc : Ot.rpc) -> let f = Pb_codegen_util.function_name_of_rpc rpc in - F.linep sc " {Server.name=%S; handle=%s %s};" rpc.rpc_name f - (spf "_rpc_%s" f)) + F.linep sc " (%s %s);" f (spf "_rpc_%s" f)) service.service_body; F.line sc " ];"; F.line sc " }"; diff --git a/src/runtime-services/pbrt_services.ml b/src/runtime-services/pbrt_services.ml index 909c5e41..59b5179a 100644 --- a/src/runtime-services/pbrt_services.ml +++ b/src/runtime-services/pbrt_services.ml @@ -101,20 +101,13 @@ module Server = struct encode_json_res; } - type 'h handler = { - name: string; - handle: 'h; - } - (** A handler of some runtime-specific type. This - might be synchronous, monadic, streaming, etc. *) - type 'h t = { service_name: string; (** Name of the service *) package: string list; (** The package this belongs in (e.g. "bigco.auth.secretpasswordstash"), split along "." *) - handlers: 'h handler list; (** A list of handlers *) + handlers: 'h list; (** A list of handlers *) } (** A service with fixed set of methods, which depends on the concrete RPC - implementation. *) + implementation. Each method is a handler of some type ['h]. *) end From de170cb811a6774343cdb9495c889fda614bc4c6 Mon Sep 17 00:00:00 2001 From: Simon Cruanes Date: Sun, 10 Sep 2023 22:11:48 -0400 Subject: [PATCH 31/46] =?UTF-8?q?add=20codegen=20for=20`make=5F=E2=80=A6`?= =?UTF-8?q?=20functions,=20more=20explicit=20than=20`default=5F=E2=80=A6`?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit default arguments are only for actually optional fields. Everything else must be passed. --- src/compilerlib/dune | 2 +- src/compilerlib/pb_codegen_all.ml | 11 ++- src/compilerlib/pb_codegen_all.mli | 1 + src/compilerlib/pb_codegen_default.mli | 5 ++ src/compilerlib/pb_codegen_make.ml | 116 +++++++++++++++++++++++++ src/compilerlib/pb_codegen_make.mli | 3 + 6 files changed, 135 insertions(+), 3 deletions(-) create mode 100644 src/compilerlib/pb_codegen_make.ml create mode 100644 src/compilerlib/pb_codegen_make.mli diff --git a/src/compilerlib/dune b/src/compilerlib/dune index 7d2edf59..65a23289 100644 --- a/src/compilerlib/dune +++ b/src/compilerlib/dune @@ -8,7 +8,7 @@ (synopsis "Compiler library for ocaml-protoc, to turn .proto files into OCaml code") (wrapped true) (modules pb_codegen_all pb_codegen_backend pb_codegen_decode_binary pb_codegen_decode_bs - pb_codegen_decode_yojson pb_codegen_default pb_codegen_encode_binary + pb_codegen_decode_yojson pb_codegen_default pb_codegen_make pb_codegen_encode_binary pb_codegen_encode_bs pb_codegen_encode_yojson pb_codegen_formatting pb_codegen_ocaml_type pb_codegen_pp pb_codegen_plugin pb_codegen_types pb_codegen_services pb_codegen_util pb_exception pb_field_type pb_location pb_logger pb_option diff --git a/src/compilerlib/pb_codegen_all.ml b/src/compilerlib/pb_codegen_all.ml index 17e733c0..c88392cb 100644 --- a/src/compilerlib/pb_codegen_all.ml +++ b/src/compilerlib/pb_codegen_all.ml @@ -109,6 +109,12 @@ let generate_type_and_default (self : ocaml_mod) ocaml_types : unit = (Some Pb_codegen_default.ocamldoc_title); () +let generate_make (self : ocaml_mod) ocaml_types : unit = + generate_for_all_types ocaml_types self.ml Pb_codegen_make.gen_struct + (Some Pb_codegen_make.ocamldoc_title); + generate_for_all_types ocaml_types self.mli Pb_codegen_make.gen_sig + (Some Pb_codegen_make.ocamldoc_title) + let generate_mutable_records (self : ocaml_mod) ocaml_types : unit = let ocaml_types = List.flatten ocaml_types in List.iter @@ -141,12 +147,13 @@ let generate_plugin (self : ocaml_mod) ocaml_types (p : Plugin.t) : unit = generate_for_all_types ocaml_types self.mli P.gen_sig (Some P.ocamldoc_title); () -let codegen (proto : Ot.proto) ~proto_file_options ~proto_file_name - (plugins : Plugin.t list) : ocaml_mod = +let codegen (proto : Ot.proto) ~generate_make:gen_make ~proto_file_options + ~proto_file_name (plugins : Plugin.t list) : ocaml_mod = let self = new_ocaml_mod ~proto_file_options ~proto_file_name () in generate_type_and_default self proto.proto_types; if List.exists Pb_codegen_plugin.requires_mutable_records plugins then generate_mutable_records self proto.proto_types; + if gen_make then generate_make self proto.proto_types; List.iter (generate_plugin self proto.proto_types) plugins; (* services come last, they need binary and json *) diff --git a/src/compilerlib/pb_codegen_all.mli b/src/compilerlib/pb_codegen_all.mli index c8cf69ed..55fcc669 100644 --- a/src/compilerlib/pb_codegen_all.mli +++ b/src/compilerlib/pb_codegen_all.mli @@ -11,6 +11,7 @@ type ocaml_mod = { val codegen : Ot.proto -> + generate_make:bool -> proto_file_options:Pb_option.set -> proto_file_name:string -> Plugin.t list -> diff --git a/src/compilerlib/pb_codegen_default.mli b/src/compilerlib/pb_codegen_default.mli index 67b71e8b..8d20e084 100644 --- a/src/compilerlib/pb_codegen_default.mli +++ b/src/compilerlib/pb_codegen_default.mli @@ -2,5 +2,10 @@ include Pb_codegen_plugin.S +val record_field_default_info : + Pb_codegen_ocaml_type.record_field -> string * string * string +(** This function returns [(field_name, field_default_value, field_type)] for + a record field. *) + val gen_record_mutable : Pb_codegen_ocaml_type.record -> Pb_codegen_formatting.scope -> unit diff --git a/src/compilerlib/pb_codegen_make.ml b/src/compilerlib/pb_codegen_make.ml new file mode 100644 index 00000000..a9e7ce58 --- /dev/null +++ b/src/compilerlib/pb_codegen_make.ml @@ -0,0 +1,116 @@ +module Ot = Pb_codegen_ocaml_type +module F = Pb_codegen_formatting +open Pb_codegen_util + +(** Is this field optional enough that we give it a default value? *) +let field_is_optional (r_field : Ot.record_field) : bool = + match r_field.rf_field_type with + | Rft_optional _ -> true + | _ -> true + +(** Obtain information about the fields *) +let fields_of_record { Ot.r_fields; _ } : + (string * string * [ `Optional of _ | `Required ]) list = + List.map + (fun r_field -> + let fname, fdefault, ftype = + Pb_codegen_default.record_field_default_info r_field + in + if field_is_optional r_field then + fname, ftype, `Optional fdefault + else + fname, ftype, `Required) + r_fields + +let gen_record ?and_ ({ Ot.r_name; _ } as r) sc : unit = + let fields = fields_of_record r in + + F.linep sc "%s make_%s " (let_decl_of_and and_) r_name; + + F.sub_scope sc (fun sc -> + List.iter + (fun (fname, ftype, d) -> + match d with + | `Required -> F.linep sc "~(%s:%s)" fname ftype + | `Optional fvalue -> + F.linep sc "?%s:((%s:%s) = %s)" fname fname ftype fvalue) + fields; + F.linep sc "() : %s = {" r_name); + + F.sub_scope sc (fun sc -> + List.iter (fun (fname, _, _) -> F.linep sc "%s;" fname) fields); + + F.line sc "}" + +let gen_unit ?and_ { Ot.er_name } sc = + F.linep sc "%s make_%s = ()" (let_decl_of_and and_) er_name + +let gen_struct ?and_ t sc = + let { Ot.spec; _ } = t in + + let has_encoded = + match spec with + | Ot.Record r -> + gen_record ?and_ r sc; + true + | Ot.Const_variant _ | Ot.Variant _ -> + (* nothing for variants *) + false + | Ot.Unit u -> + gen_unit ?and_ u sc; + true + in + has_encoded + +let gen_sig_record sc ({ Ot.r_name; _ } as r) = + F.linep sc "val make_%s : " r_name; + + let fields : _ list = fields_of_record r in + + F.sub_scope sc (fun sc -> + List.iter + (fun (field_name, field_type, d) -> + match d with + | `Optional _ -> F.linep sc "?%s:%s ->" field_name field_type + | `Required -> F.linep sc "%s:%s ->" field_name field_type) + fields; + F.line sc "unit ->"; + F.line sc r_name); + let rn = r_name in + F.linep sc "(** [make_%s … ()] is a builder for type [%s] *)" rn rn + +let gen_sig_unit sc { Ot.er_name } = + F.linep sc "val make_%s : unit" er_name; + + let rn = er_name in + F.linep sc "(** [make_%s ()] is a builder for type [%s] *)" rn rn + +let gen_sig ?and_:_ t sc = + let f type_name = + F.linep sc "val make_%s : unit -> %s" type_name type_name; + F.linep sc "(** [make_%s … ()] is a builder for type [%s] *)" type_name + type_name + in + + let { Ot.spec; _ } = t in + + let has_encoded = + match spec with + | Ot.Record r -> + gen_sig_record sc r; + true + | Ot.Variant v -> + f v.Ot.v_name; + true + | Ot.Const_variant { Ot.cv_name; _ } -> + f cv_name; + true + | Ot.Unit u -> + gen_sig_unit sc u; + true + in + + has_encoded + +let ocamldoc_title = "Make functions" +let requires_mutable_records = false diff --git a/src/compilerlib/pb_codegen_make.mli b/src/compilerlib/pb_codegen_make.mli new file mode 100644 index 00000000..573153b0 --- /dev/null +++ b/src/compilerlib/pb_codegen_make.mli @@ -0,0 +1,3 @@ +(** Code generator for the [make] functions (i.e builders, but stricter than [default]) *) + +include Pb_codegen_plugin.S From 1d6afcc631a3ae23c7d9a3a0a01ddccda04294fa Mon Sep 17 00:00:00 2001 From: Simon Cruanes Date: Sun, 10 Sep 2023 22:12:11 -0400 Subject: [PATCH 32/46] add `--make` option to main executable --- src/ocaml-protoc/ocaml_protoc_cmdline.ml | 3 +++ src/ocaml-protoc/ocaml_protoc_generation.ml | 4 ++-- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/src/ocaml-protoc/ocaml_protoc_cmdline.ml b/src/ocaml-protoc/ocaml_protoc_cmdline.ml index b68e1086..52f3894a 100644 --- a/src/ocaml-protoc/ocaml_protoc_cmdline.ml +++ b/src/ocaml-protoc/ocaml_protoc_cmdline.ml @@ -111,6 +111,7 @@ module Cmdline = struct bs: bool ref; (** whether BuckleScript encoding is enabled *) pp: bool ref; (** whether pretty printing is enabled *) services: bool ref; (** whether services code generation is enabled *) + make: bool ref; (** whether to generate "make" functions *) mutable cmd_line_file_options: File_options.t; (** file options override from the cmd line *) unsigned_tag: bool ref; @@ -129,6 +130,7 @@ module Cmdline = struct bs = ref false; pp = ref false; services = ref false; + make = ref false; cmd_line_file_options = File_options.make (); unsigned_tag = ref false; } @@ -152,6 +154,7 @@ module Cmdline = struct ( "--unsigned", Arg.Set t.unsigned_tag, " tag uint32 and uint64 types with `unsigned" ); + "--make", Arg.Set t.make, " generate `make` functions"; ] @ File_options.cmd_line_args t.cmd_line_file_options diff --git a/src/ocaml-protoc/ocaml_protoc_generation.ml b/src/ocaml-protoc/ocaml_protoc_generation.ml index 51a385f3..73b46946 100644 --- a/src/ocaml-protoc/ocaml_protoc_generation.ml +++ b/src/ocaml-protoc/ocaml_protoc_generation.ml @@ -76,8 +76,8 @@ let generate_code ocaml_types ~proto_file_options cmdline : unit = in let ocaml_mod : CG_all.ocaml_mod = - CG_all.codegen ocaml_types ~proto_file_options - ~proto_file_name:cmdline.proto_file_name plugins + CG_all.codegen ocaml_types ~generate_make:!(cmdline.make) + ~proto_file_options ~proto_file_name:cmdline.proto_file_name plugins in (* now write the files *) From a5f8a687063ce42a6ee217c826e2fefe21eba906 Mon Sep 17 00:00:00 2001 From: Simon Cruanes Date: Tue, 17 Oct 2023 14:56:05 -0400 Subject: [PATCH 33/46] add basic test for `--make` --- src/tests/integration-tests/dune | 13 +++++++++++++ src/tests/integration-tests/test_make.proto | 21 +++++++++++++++++++++ 2 files changed, 34 insertions(+) create mode 100644 src/tests/integration-tests/test_make.proto diff --git a/src/tests/integration-tests/dune b/src/tests/integration-tests/dune index 1ac2574e..dee780fc 100644 --- a/src/tests/integration-tests/dune +++ b/src/tests/integration-tests/dune @@ -324,3 +324,16 @@ (name test_proto3_optional) (libraries pbrt) (modules test_proto3_optional_ml test_proto3_optional)) + +(rule + (targets test_make.ml test_make.mli) + (deps + (:proto test_make.proto) + ../../include/ocaml-protoc/ocamloptions.proto) + (action + (run ocaml-protoc --binary --make --ml_out ./ %{proto}))) + +(executable + (name test_make) + (libraries pbrt) + (modules test_make)) diff --git a/src/tests/integration-tests/test_make.proto b/src/tests/integration-tests/test_make.proto new file mode 100644 index 00000000..7928842f --- /dev/null +++ b/src/tests/integration-tests/test_make.proto @@ -0,0 +1,21 @@ +syntax = "proto3"; + +enum FooEnum { + A = 0; + B = 1; +} + +message FooOuter { + string x = 1; + optional int32 y = 2; + repeated FooEnum enums = 3; + + message FooInner { + string inner_x = 1; + optional int64 inner_y = 2; + repeated float inner_z = 3; + repeated float inner_z2 = 4 [packed=true]; + } + + repeated FooInner foos = 10; +} From db1d45267941ecb36ebe497427dc9615a5d136a3 Mon Sep 17 00:00:00 2001 From: Simon Cruanes Date: Tue, 17 Oct 2023 14:58:46 -0400 Subject: [PATCH 34/46] fix: only codegen `make` functions for records --- src/compilerlib/pb_codegen_make.ml | 30 ++---------------------------- 1 file changed, 2 insertions(+), 28 deletions(-) diff --git a/src/compilerlib/pb_codegen_make.ml b/src/compilerlib/pb_codegen_make.ml index a9e7ce58..ac2b9c51 100644 --- a/src/compilerlib/pb_codegen_make.ml +++ b/src/compilerlib/pb_codegen_make.ml @@ -42,9 +42,6 @@ let gen_record ?and_ ({ Ot.r_name; _ } as r) sc : unit = F.line sc "}" -let gen_unit ?and_ { Ot.er_name } sc = - F.linep sc "%s make_%s = ()" (let_decl_of_and and_) er_name - let gen_struct ?and_ t sc = let { Ot.spec; _ } = t in @@ -53,12 +50,9 @@ let gen_struct ?and_ t sc = | Ot.Record r -> gen_record ?and_ r sc; true - | Ot.Const_variant _ | Ot.Variant _ -> + | Ot.Const_variant _ | Ot.Variant _ | Ot.Unit _ -> (* nothing for variants *) false - | Ot.Unit u -> - gen_unit ?and_ u sc; - true in has_encoded @@ -79,19 +73,7 @@ let gen_sig_record sc ({ Ot.r_name; _ } as r) = let rn = r_name in F.linep sc "(** [make_%s … ()] is a builder for type [%s] *)" rn rn -let gen_sig_unit sc { Ot.er_name } = - F.linep sc "val make_%s : unit" er_name; - - let rn = er_name in - F.linep sc "(** [make_%s ()] is a builder for type [%s] *)" rn rn - let gen_sig ?and_:_ t sc = - let f type_name = - F.linep sc "val make_%s : unit -> %s" type_name type_name; - F.linep sc "(** [make_%s … ()] is a builder for type [%s] *)" type_name - type_name - in - let { Ot.spec; _ } = t in let has_encoded = @@ -99,15 +81,7 @@ let gen_sig ?and_:_ t sc = | Ot.Record r -> gen_sig_record sc r; true - | Ot.Variant v -> - f v.Ot.v_name; - true - | Ot.Const_variant { Ot.cv_name; _ } -> - f cv_name; - true - | Ot.Unit u -> - gen_sig_unit sc u; - true + | Ot.Variant _ | Ot.Const_variant _ | Ot.Unit _ -> false in has_encoded From 3f46137cc70e913a47696d08e366a0286f92bb1e Mon Sep 17 00:00:00 2001 From: Simon Cruanes Date: Tue, 17 Oct 2023 15:04:28 -0400 Subject: [PATCH 35/46] =?UTF-8?q?make=20make=20functions=20actually=20non-?= =?UTF-8?q?default=20=F0=9F=99=84?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/compilerlib/pb_codegen_make.ml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/compilerlib/pb_codegen_make.ml b/src/compilerlib/pb_codegen_make.ml index ac2b9c51..17447876 100644 --- a/src/compilerlib/pb_codegen_make.ml +++ b/src/compilerlib/pb_codegen_make.ml @@ -6,7 +6,7 @@ open Pb_codegen_util let field_is_optional (r_field : Ot.record_field) : bool = match r_field.rf_field_type with | Rft_optional _ -> true - | _ -> true + | _ -> false (** Obtain information about the fields *) let fields_of_record { Ot.r_fields; _ } : From e82dee6b372506887983a1bc9afcb21f1fb787f3 Mon Sep 17 00:00:00 2001 From: Simon Cruanes Date: Thu, 7 Sep 2023 23:40:24 -0400 Subject: [PATCH 36/46] add more benchs; update the bench for encode-backward encoding backward in a buffer is now the fastest method! --- benchs/benchs.ml | 199 +++++++++++++++++++++++++++++++++++++---------- 1 file changed, 156 insertions(+), 43 deletions(-) diff --git a/benchs/benchs.ml b/benchs/benchs.ml index 2177d71c..5a6abb4d 100644 --- a/benchs/benchs.ml +++ b/benchs/benchs.ml @@ -108,7 +108,7 @@ module Dec = struct let of_bytes source = { source; offset = 0; limit = Bytes.length source } let of_string source = of_bytes (Bytes.unsafe_of_string source) - let byte d = + let[@inline] byte d = if d.offset >= d.limit then raise (Failure Incomplete); let byte = int_of_char (Bytes.get d.source d.offset) in d.offset <- d.offset + 1; @@ -192,7 +192,7 @@ module Dec = struct (let dec = of_string s in for _i = 0 to n do let _n = varint_imp dec in - () + Sys.opaque_identity (ignore _n) done) let test_imp2 n (s : string) = @@ -201,7 +201,7 @@ module Dec = struct (let dec = of_string s in for _i = 0 to n do let _n = varint_imp_noinline dec in - () + Sys.opaque_identity (ignore _n) done) let test_rec n (s : string) = @@ -210,7 +210,7 @@ module Dec = struct (let dec = of_string s in for _i = 0 to n do let _n = varint_rec dec in - () + Sys.opaque_identity (ignore _n) done) let test_rec2 n (s : string) = @@ -219,7 +219,7 @@ module Dec = struct (let dec = of_string s in for _i = 0 to n do let _n = varint_rec_noinline dec in - () + Sys.opaque_identity (ignore _n) done) (* sanity check *) @@ -262,6 +262,117 @@ let () = "enc" @>>> [ test_enc 5; test_enc 10; test_enc 50; test_enc 1000 ]; ] +module Dec_bits64 = struct + open Dec + + (* put the int64 integers from 0 to n in a dec *) + let mk_buf_n n : string = + let enc = Pbrt.Encoder.create () in + for i = 0 to n do + Pbrt.Encoder.int_as_bits64 i enc + done; + Pbrt.Encoder.to_string enc + + let test_imp n (s : string) = + mk_t "dec-varint-imp" @@ fun () -> + Sys.opaque_identity + (let dec = of_string s in + for _i = 0 to n do + let _n = varint_imp dec in + Sys.opaque_identity (ignore _n) + done) + + let bits64_basic (d : t) = + let b1 = byte d in + let b2 = byte d in + let b3 = byte d in + let b4 = byte d in + let b5 = byte d in + let b6 = byte d in + let b7 = byte d in + let b8 = byte d in + Int64.( + add + (shift_left (of_int b8) 56) + (add + (shift_left (of_int b7) 48) + (add + (shift_left (of_int b6) 40) + (add + (shift_left (of_int b5) 32) + (add + (shift_left (of_int b4) 24) + (add + (shift_left (of_int b3) 16) + (add (shift_left (of_int b2) 8) (of_int b1)))))))) + + let bits64_loop (d : t) = + let res = ref 0L in + for i = 0 to 7 do + let byte = byte d in + res := Int64.(logor !res (shift_left (of_int byte) (8 * i))) + done; + !res + + let test_basic s n = + mk_t "dec-bits64-basic" @@ fun () -> + Sys.opaque_identity + (let dec = of_string s in + for _i = 0 to n do + let _n = bits64_basic dec in + Sys.opaque_identity (ignore _n) + done) + + let test_loop s n = + mk_t "dec-bits64-loop" @@ fun () -> + Sys.opaque_identity + (let dec = of_string s in + for _i = 0 to n do + let _n = bits64_loop dec in + Sys.opaque_identity (ignore _n) + done) + + (* sanity check *) + let () = + let n = 5 in + let s = mk_buf_n n in + + let dec_to_l f = + let dec = of_string s in + let l = ref [] in + for _i = 0 to n do + let n = f dec in + l := Int64.to_int n :: !l + done; + List.rev !l + in + assert (dec_to_l bits64_basic = [ 0; 1; 2; 3; 4; 5 ]); + assert (dec_to_l bits64_loop = [ 0; 1; 2; 3; 4; 5 ]); + () +end + +let test_dec_bits64 n = + let open B.Tree in + let s = Dec_bits64.mk_buf_n n in + Printf.sprintf "%d" n + @> lazy + (B.throughputN ~repeat:3 4 + [ Dec_bits64.test_basic s n; Dec_bits64.test_loop s n ]) + +let () = + let open B.Tree in + register @@ "bits64" + @>>> [ + "dec" + @>>> [ + test_dec_bits64 5; + test_dec_bits64 10; + test_dec_bits64 50; + test_dec_bits64 1000; + ]; + (* "enc" @>>> [ test_enc 5; test_enc 10; test_enc 50; test_enc 1000 ]; *) + ] + module Nested = struct type person = Foo.person = { name: string; @@ -503,17 +614,19 @@ module Nested = struct self.b <- b'; self.start <- newcap - n - let[@inline never] grow_ self = - assert (self.start = 0); + let next_cap_ (self : t) : int = let n = cap self in - let newcap = n + (n lsr 1) + 3 in - grow_to_ self newcap; - assert (self.start > 0) - - let[@inline] add_char (self : t) (c : char) : unit = - if self.start = 0 then grow_ self; - self.start <- self.start - 1; - Bytes.unsafe_set self.b self.start c + n + (n lsr 1) + 3 + + (** Reserve [n] bytes, return the offset at which we can write them. *) + let reserve_n (self : t) (n : int) : int = + if self.start < n then ( + let newcap = max (cap self + n) (next_cap_ self) in + grow_to_ self newcap; + assert (self.start >= n) + ); + self.start <- self.start - n; + self.start let add_bytes (self : t) (b : bytes) = let n = Bytes.length b in @@ -522,36 +635,36 @@ module Nested = struct self.start <- self.start - n; () - let varint i (e : t) = - let[@unroll 2] rec write i = - let cur = Int64.(logand i 0x7fL) in - if cur = i then - add_char e (Char.unsafe_chr Int64.(to_int cur)) + (** Number of bytes to encode [i] *) + let varint_size (i : int64) : int = + let i = ref i in + let n = ref 0 in + let continue = ref true in + while !continue do + incr n; + let cur = Int64.(logand !i 0x7fL) in + if cur = !i then + continue := false + else + i := Int64.shift_right_logical !i 7 + done; + !n + + let[@inline] varint (i : int64) (e : t) : unit = + let n_bytes = varint_size i in + let start = reserve_n e n_bytes in + + let i = ref i in + for j = 0 to n_bytes - 1 do + let cur = Int64.(logand !i 0x7fL) in + if j = n_bytes - 1 then + Bytes.set e.b (start + j) (Char.unsafe_chr Int64.(to_int cur)) else ( - write (Int64.shift_right_logical i 7); - add_char e (Char.unsafe_chr Int64.(to_int (logor 0x80L cur))) + Bytes.set e.b (start + j) + (Char.unsafe_chr Int64.(to_int (logor 0x80L cur))); + i := Int64.shift_right_logical !i 7 ) - in - write i - - (* TODO: can we do this in a loop? - let varint (i:int64) (e:t) = - let i = ref i in - let continue = ref true in - while !continue do - let cur = Int64.(logand !i 0x7fL) in - if cur = !i - then ( - continue := false; - add_char e (Char.unsafe_chr Int64.(to_int cur)) - ) else ( - add_char e - (Char.unsafe_chr Int64.( to_int (logor 0x80L cur) - )); - i := Int64.shift_right_logical !i 7; - ) - done - *) + done let int64_as_varint = varint let int_as_varint i e = varint (Int64.of_int i) e From 75d83ede719a402349a2efac455be8e2aa49d6ba Mon Sep 17 00:00:00 2001 From: Simon Cruanes Date: Thu, 7 Sep 2023 23:40:53 -0400 Subject: [PATCH 37/46] perf: make `Decoder.byte` inline --- src/runtime/pbrt.ml | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/runtime/pbrt.ml b/src/runtime/pbrt.ml index 1b7d7cdd..f6c39f56 100644 --- a/src/runtime/pbrt.ml +++ b/src/runtime/pbrt.ml @@ -96,11 +96,14 @@ module Decoder = struct let unexpected_payload field_name pk = raise (Failure (Unexpected_payload (field_name, pk))) - let missing_field field_name = raise (Failure (Missing_field field_name)) + let[@inline never] missing_field field_name = + raise (Failure (Missing_field field_name)) + + let[@inline never] incomplete () = raise (Failure Incomplete) let at_end d = d.limit = d.offset - let byte d = - if d.offset >= d.limit then raise (Failure Incomplete); + let[@inline] byte d = + if d.offset >= d.limit then incomplete (); let byte = int_of_char (Bytes.get d.source d.offset) in d.offset <- d.offset + 1; byte From d2f9542da732f8654913f484902e7594260d6ad8 Mon Sep 17 00:00:00 2001 From: Simon Cruanes Date: Fri, 8 Sep 2023 00:15:13 -0400 Subject: [PATCH 38/46] more benchs --- benchs/benchs.ml | 215 ++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 192 insertions(+), 23 deletions(-) diff --git a/benchs/benchs.ml b/benchs/benchs.ml index 5a6abb4d..dce6d6ca 100644 --- a/benchs/benchs.ml +++ b/benchs/benchs.ml @@ -263,7 +263,8 @@ let () = ] module Dec_bits64 = struct - open Dec + open Pbrt.Decoder + open! Dec (* put the int64 integers from 0 to n in a dec *) let mk_buf_n n : string = @@ -273,16 +274,14 @@ module Dec_bits64 = struct done; Pbrt.Encoder.to_string enc - let test_imp n (s : string) = - mk_t "dec-varint-imp" @@ fun () -> - Sys.opaque_identity - (let dec = of_string s in - for _i = 0 to n do - let _n = varint_imp dec in - Sys.opaque_identity (ignore _n) - done) + (** Read 8 bytes at once, return offset of the first one. *) + let get8 (self : t) : int = + if self.offset + 8 > self.limit then raise (Failure Incomplete); + let n = self.offset in + self.offset <- self.offset + 8; + n - let bits64_basic (d : t) = + let bits64_unrolled (d : t) = let b1 = byte d in let b2 = byte d in let b3 = byte d in @@ -306,6 +305,31 @@ module Dec_bits64 = struct (shift_left (of_int b3) 16) (add (shift_left (of_int b2) 8) (of_int b1)))))))) + let bits64_unrolled_single_read (d : t) = + let off = get8 d in + let b1 = int_of_char @@ Bytes.unsafe_get d.source off in + let b2 = int_of_char @@ Bytes.unsafe_get d.source (off + 1) in + let b3 = int_of_char @@ Bytes.unsafe_get d.source (off + 2) in + let b4 = int_of_char @@ Bytes.unsafe_get d.source (off + 3) in + let b5 = int_of_char @@ Bytes.unsafe_get d.source (off + 4) in + let b6 = int_of_char @@ Bytes.unsafe_get d.source (off + 5) in + let b7 = int_of_char @@ Bytes.unsafe_get d.source (off + 6) in + let b8 = int_of_char @@ Bytes.unsafe_get d.source (off + 7) in + Int64.( + add + (shift_left (of_int b8) 56) + (add + (shift_left (of_int b7) 48) + (add + (shift_left (of_int b6) 40) + (add + (shift_left (of_int b5) 32) + (add + (shift_left (of_int b4) 24) + (add + (shift_left (of_int b3) 16) + (add (shift_left (of_int b2) 8) (of_int b1)))))))) + let bits64_loop (d : t) = let res = ref 0L in for i = 0 to 7 do @@ -314,12 +338,25 @@ module Dec_bits64 = struct done; !res - let test_basic s n = - mk_t "dec-bits64-basic" @@ fun () -> + let bits64_from_stdlib (d : t) : int64 = + let off = get8 d in + Bytes.get_int64_le d.source off + + let test_unrolled s n = + mk_t "dec-bits64-unrolled" @@ fun () -> + Sys.opaque_identity + (let dec = of_string s in + for _i = 0 to n do + let _n = bits64_unrolled dec in + Sys.opaque_identity (ignore _n) + done) + + let test_unrolled_single_read s n = + mk_t "dec-bits64-unrolled-single-read" @@ fun () -> Sys.opaque_identity (let dec = of_string s in for _i = 0 to n do - let _n = bits64_basic dec in + let _n = bits64_unrolled_single_read dec in Sys.opaque_identity (ignore _n) done) @@ -332,6 +369,15 @@ module Dec_bits64 = struct Sys.opaque_identity (ignore _n) done) + let test_stdlib s n = + mk_t "dec-bits64-stdlib" @@ fun () -> + Sys.opaque_identity + (let dec = of_string s in + for _i = 0 to n do + let _n = bits64_from_stdlib dec in + Sys.opaque_identity (ignore _n) + done) + (* sanity check *) let () = let n = 5 in @@ -346,8 +392,10 @@ module Dec_bits64 = struct done; List.rev !l in - assert (dec_to_l bits64_basic = [ 0; 1; 2; 3; 4; 5 ]); + assert (dec_to_l bits64_unrolled = [ 0; 1; 2; 3; 4; 5 ]); + assert (dec_to_l bits64_unrolled_single_read = [ 0; 1; 2; 3; 4; 5 ]); assert (dec_to_l bits64_loop = [ 0; 1; 2; 3; 4; 5 ]); + assert (dec_to_l bits64_from_stdlib = [ 0; 1; 2; 3; 4; 5 ]); () end @@ -357,7 +405,12 @@ let test_dec_bits64 n = Printf.sprintf "%d" n @> lazy (B.throughputN ~repeat:3 4 - [ Dec_bits64.test_basic s n; Dec_bits64.test_loop s n ]) + [ + Dec_bits64.test_unrolled s n; + Dec_bits64.test_unrolled_single_read s n; + Dec_bits64.test_loop s n; + Dec_bits64.test_stdlib s n; + ]) let () = let open B.Tree in @@ -618,25 +671,26 @@ module Nested = struct let n = cap self in n + (n lsr 1) + 3 + let[@inline never] grow_reserve_n (self : t) n : unit = + let newcap = max (cap self + n) (next_cap_ self) in + grow_to_ self newcap; + assert (self.start >= n) + (** Reserve [n] bytes, return the offset at which we can write them. *) - let reserve_n (self : t) (n : int) : int = - if self.start < n then ( - let newcap = max (cap self + n) (next_cap_ self) in - grow_to_ self newcap; - assert (self.start >= n) - ); + let[@inline] reserve_n (self : t) (n : int) : int = + if self.start < n then grow_reserve_n self n; self.start <- self.start - n; self.start let add_bytes (self : t) (b : bytes) = let n = Bytes.length b in - if self.start - n <= 0 then grow_to_ self (cap self + n + (n lsr 1) + 1); + if self.start <= n then grow_to_ self (cap self + n + (n lsr 1) + 1); Bytes.blit b 0 self.b (self.start - n) n; self.start <- self.start - n; () (** Number of bytes to encode [i] *) - let varint_size (i : int64) : int = + let[@inline] varint_size (i : int64) : int = let i = ref i in let n = ref 0 in let continue = ref true in @@ -707,15 +761,127 @@ module Nested = struct int_as_varint size e end) + module From_back2 = Make_bench (struct + let name_of_enc = "write-backward2" + + type t = from_back_end + + let create () : t = { b = Bytes.create 16; start = 16 } + let[@inline] clear self = self.start <- Bytes.length self.b + let[@inline] cap self = Bytes.length self.b + let[@inline] length self = cap self - self.start + + let to_string self : string = + Bytes.sub_string self.b self.start (length self) + + let grow_to_ self newcap = + let n = length self in + let b' = Bytes.create newcap in + Bytes.blit self.b self.start b' (newcap - n) n; + self.b <- b'; + self.start <- newcap - n + + let next_cap_ (self : t) : int = + let n = cap self in + n + (n lsr 1) + 3 + + let[@inline never] grow_reserve_n (self : t) n : unit = + let newcap = max (cap self + n) (next_cap_ self) in + grow_to_ self newcap; + assert (self.start >= n) + + (** Reserve [n] bytes, return the offset at which we can write them. *) + let[@inline] reserve_n (self : t) (n : int) : int = + if self.start < n then grow_reserve_n self n; + self.start <- self.start - n; + self.start + + let add_bytes (self : t) (b : bytes) = + let n = Bytes.length b in + if self.start <= n then grow_to_ self (cap self + n + (n lsr 1) + 1); + Bytes.blit b 0 self.b (self.start - n) n; + self.start <- self.start - n; + () + + let[@inline] varint (i : int64) (e : t) : unit = + let n_bytes = ref 0 in + (let i = ref i in + let continue = ref true in + while !continue do + incr n_bytes; + let cur = Int64.(logand !i 0x7fL) in + if cur = !i then + continue := false + else + i := Int64.shift_right_logical !i 7 + done); + + let start = reserve_n e !n_bytes in + + let i = ref i in + for j = 0 to !n_bytes - 1 do + let cur = Int64.(logand !i 0x7fL) in + if j = !n_bytes - 1 then + Bytes.set e.b (start + j) (Char.unsafe_chr Int64.(to_int cur)) + else ( + Bytes.set e.b (start + j) + (Char.unsafe_chr Int64.(to_int (logor 0x80L cur))); + i := Int64.shift_right_logical !i 7 + ) + done + + let int64_as_varint = varint + let int_as_varint i e = varint (Int64.of_int i) e + + let[@inline] key k pk f e = + let pk' = + match pk with + | Varint -> 0 + | Bits64 -> 1 + | Bytes -> 2 + | Bits32 -> 5 + in + f e; + int_as_varint (pk' lor (k lsl 3)) e; + (* write this after the data *) + () + + let bytes b (e : t) = + add_bytes e b; + int_as_varint (Bytes.length b) e; + () + + let string s e = bytes (Bytes.unsafe_of_string s) e + + (* encode lists in reverse order *) + let list f l e = + let rec loop = function + | [] -> () + | [ x ] -> f x e + | x :: tl -> + loop tl; + f x e + in + loop l + + let nested f (e : t) = + let s0 = length e in + f e; + let size = length e - s0 in + int_as_varint size e + end) + let bench_basic = Basic.bench let bench_buffers_nested = Buffers_nested.bench let bench_from_back = From_back.bench + let bench_from_back2 = From_back2.bench (* sanity check *) let () = let s_basic = Basic.string_of_company (mk_company 1) in let s_buffers_nested = Buffers_nested.string_of_company (mk_company 1) in let s_from_back = From_back.string_of_company (mk_company 1) in + let s_from_back2 = From_back2.string_of_company (mk_company 1) in (* Printf.printf "basic: (len=%d) %S\n" (String.length s_basic) s_basic; Printf.printf "from_back: (len=%d) %S\n" (String.length s_from_back) s_from_back; @@ -728,12 +894,14 @@ module Nested = struct let c_basic = dec_s s_basic in let c_buffers_nested = dec_s s_buffers_nested in let c_from_back = dec_s s_from_back in + let c_from_back2 = dec_s s_from_back2 in (* Format.printf "c_basic=%a@." Foo_pp.pp_company c_basic; Format.printf "c_from_back=%a@." Foo_pp.pp_company c_from_back; *) assert (c_basic = c_buffers_nested); assert (c_basic = c_from_back); + assert (c_basic = c_from_back2); () end @@ -747,6 +915,7 @@ let test_nested_enc n = Nested.bench_basic company; Nested.bench_buffers_nested company; Nested.bench_from_back company; + Nested.bench_from_back2 company; ]) let () = From a9ed5a86819e9accec128d8d501b530825b38d7f Mon Sep 17 00:00:00 2001 From: Simon Cruanes Date: Fri, 6 Oct 2023 22:53:31 -0400 Subject: [PATCH 39/46] compare with noinline for backward encoding --- benchs/benchs.ml | 45 ++++++++++++++++++++++++--------------------- 1 file changed, 24 insertions(+), 21 deletions(-) diff --git a/benchs/benchs.ml b/benchs/benchs.ml index dce6d6ca..dc0f26bc 100644 --- a/benchs/benchs.ml +++ b/benchs/benchs.ml @@ -761,8 +761,8 @@ module Nested = struct int_as_varint size e end) - module From_back2 = Make_bench (struct - let name_of_enc = "write-backward2" + module From_back_noinline = Make_bench (struct + let name_of_enc = "write-backward_noinline" type t = from_back_end @@ -803,25 +803,28 @@ module Nested = struct self.start <- self.start - n; () - let[@inline] varint (i : int64) (e : t) : unit = - let n_bytes = ref 0 in - (let i = ref i in - let continue = ref true in - while !continue do - incr n_bytes; - let cur = Int64.(logand !i 0x7fL) in - if cur = !i then - continue := false - else - i := Int64.shift_right_logical !i 7 - done); - - let start = reserve_n e !n_bytes in + let[@inline never] varint_size (i : int64) : int = + let i = ref i in + let n = ref 0 in + let continue = ref true in + while !continue do + incr n; + let cur = Int64.(logand !i 0x7fL) in + if cur = !i then + continue := false + else + i := Int64.shift_right_logical !i 7 + done; + !n + + let[@inline never] varint (i : int64) (e : t) : unit = + let n_bytes = varint_size i in + let start = reserve_n e n_bytes in let i = ref i in - for j = 0 to !n_bytes - 1 do + for j = 0 to n_bytes - 1 do let cur = Int64.(logand !i 0x7fL) in - if j = !n_bytes - 1 then + if j = n_bytes - 1 then Bytes.set e.b (start + j) (Char.unsafe_chr Int64.(to_int cur)) else ( Bytes.set e.b (start + j) @@ -874,14 +877,14 @@ module Nested = struct let bench_basic = Basic.bench let bench_buffers_nested = Buffers_nested.bench let bench_from_back = From_back.bench - let bench_from_back2 = From_back2.bench + let bench_from_back_noinline = From_back_noinline.bench (* sanity check *) let () = let s_basic = Basic.string_of_company (mk_company 1) in let s_buffers_nested = Buffers_nested.string_of_company (mk_company 1) in let s_from_back = From_back.string_of_company (mk_company 1) in - let s_from_back2 = From_back2.string_of_company (mk_company 1) in + let s_from_back2 = From_back_noinline.string_of_company (mk_company 1) in (* Printf.printf "basic: (len=%d) %S\n" (String.length s_basic) s_basic; Printf.printf "from_back: (len=%d) %S\n" (String.length s_from_back) s_from_back; @@ -915,7 +918,7 @@ let test_nested_enc n = Nested.bench_basic company; Nested.bench_buffers_nested company; Nested.bench_from_back company; - Nested.bench_from_back2 company; + Nested.bench_from_back_noinline company; ]) let () = From 45c5f1dce699137ede8bc15d05fa38fa60bc0357 Mon Sep 17 00:00:00 2001 From: Simon Cruanes Date: Sat, 7 Oct 2023 00:45:06 -0400 Subject: [PATCH 40/46] renforce the C bindings (do varint in C, not varint_size) --- benchs/benchs.ml | 138 ++++++++++++++++++++++++++++++++++++++++++++--- benchs/dune | 1 + 2 files changed, 132 insertions(+), 7 deletions(-) diff --git a/benchs/benchs.ml b/benchs/benchs.ml index dc0f26bc..f436c031 100644 --- a/benchs/benchs.ml +++ b/benchs/benchs.ml @@ -514,7 +514,7 @@ module Nested = struct let bench company = let enc = E.create () in - mk_t (spf "nested-enc-%s" E.name_of_enc) @@ fun () -> + mk_t (spf "nenc-%s" E.name_of_enc) @@ fun () -> Sys.opaque_identity (E.clear enc; enc_company company enc) @@ -721,7 +721,7 @@ module Nested = struct done let int64_as_varint = varint - let int_as_varint i e = varint (Int64.of_int i) e + let[@inline] int_as_varint i e = varint (Int64.of_int i) e let[@inline] key k pk f e = let pk' = @@ -762,7 +762,7 @@ module Nested = struct end) module From_back_noinline = Make_bench (struct - let name_of_enc = "write-backward_noinline" + let name_of_enc = "write-backward-noinline" type t = from_back_end @@ -834,7 +834,121 @@ module Nested = struct done let int64_as_varint = varint - let int_as_varint i e = varint (Int64.of_int i) e + let[@inline] int_as_varint i e = varint (Int64.of_int i) e + + let[@inline] key k pk f e = + let pk' = + match pk with + | Varint -> 0 + | Bits64 -> 1 + | Bytes -> 2 + | Bits32 -> 5 + in + f e; + int_as_varint (pk' lor (k lsl 3)) e; + (* write this after the data *) + () + + let bytes b (e : t) = + add_bytes e b; + int_as_varint (Bytes.length b) e; + () + + let string s e = bytes (Bytes.unsafe_of_string s) e + + (* encode lists in reverse order *) + let list f l e = + let rec loop = function + | [] -> () + | [ x ] -> f x e + | x :: tl -> + loop tl; + f x e + in + loop l + + let nested f (e : t) = + let s0 = length e in + f e; + let size = length e - s0 in + int_as_varint size e + end) + + module From_back_c = Make_bench (struct + let name_of_enc = "write-backward-c" + + type t = from_back_end + + let create () : t = { b = Bytes.create 16; start = 16 } + let[@inline] clear self = self.start <- Bytes.length self.b + let[@inline] cap self = Bytes.length self.b + let[@inline] length self = cap self - self.start + + let to_string self : string = + Bytes.sub_string self.b self.start (length self) + + let grow_to_ self newcap = + let n = length self in + let b' = Bytes.create newcap in + Bytes.blit self.b self.start b' (newcap - n) n; + self.b <- b'; + self.start <- newcap - n + + let next_cap_ (self : t) : int = + let n = cap self in + n + (n lsr 1) + 3 + + let[@inline never] grow_reserve_n (self : t) n : unit = + let newcap = max (cap self + n) (next_cap_ self) in + grow_to_ self newcap; + assert (self.start >= n) + + (** Reserve [n] bytes, return the offset at which we can write them. *) + let[@inline] reserve_n (self : t) (n : int) : int = + if self.start < n then grow_reserve_n self n; + self.start <- self.start - n; + self.start + + let add_bytes (self : t) (b : bytes) = + let n = Bytes.length b in + if self.start <= n then grow_to_ self (cap self + n + (n lsr 1) + 1); + Bytes.blit b 0 self.b (self.start - n) n; + self.start <- self.start - n; + () + + (* + external varint_size : (int64[@unboxed]) -> int + = "caml_pbrt_varint_size_byte" "caml_pbrt_varint_size" + [@@noalloc] + *) + + (* keep this in OCaml because the C overhead is non trivial *) + let[@inline] varint_size (i : int64) : int = + let i = ref i in + let n = ref 0 in + let continue = ref true in + while !continue do + incr n; + let cur = Int64.(logand !i 0x7fL) in + if cur = !i then + continue := false + else + i := Int64.shift_right_logical !i 7 + done; + !n + + external varint_slice : + bytes -> (int[@untagged]) -> (int64[@unboxed]) -> unit + = "caml_pbrt_varint_byte" "caml_pbrt_varint" + [@@noalloc] + + let[@inline] varint (i : int64) (e : t) : unit = + let n_bytes = varint_size i in + let start = reserve_n e n_bytes in + varint_slice e.b start i + + let int64_as_varint = varint + let[@inline] int_as_varint i e = varint (Int64.of_int i) e let[@inline] key k pk f e = let pk' = @@ -878,6 +992,7 @@ module Nested = struct let bench_buffers_nested = Buffers_nested.bench let bench_from_back = From_back.bench let bench_from_back_noinline = From_back_noinline.bench + let bench_from_back_c = From_back_c.bench (* sanity check *) let () = @@ -885,10 +1000,16 @@ module Nested = struct let s_buffers_nested = Buffers_nested.string_of_company (mk_company 1) in let s_from_back = From_back.string_of_company (mk_company 1) in let s_from_back2 = From_back_noinline.string_of_company (mk_company 1) in + let s_from_backc = From_back_c.string_of_company (mk_company 1) in (* - Printf.printf "basic: (len=%d) %S\n" (String.length s_basic) s_basic; - Printf.printf "from_back: (len=%d) %S\n" (String.length s_from_back) s_from_back; - *) + Printf.printf "basic:\n(len=%d) %S\n" (String.length s_basic) s_basic; + Printf.printf "from_back:\n(len=%d) %S\n" + (String.length s_from_back) + s_from_back; + Printf.printf "from_back_c:\n(len=%d) %S\n" + (String.length s_from_backc) + s_from_backc; + *) let dec_s s = Pbrt.Decoder.( let dec = of_string s in @@ -898,6 +1019,7 @@ module Nested = struct let c_buffers_nested = dec_s s_buffers_nested in let c_from_back = dec_s s_from_back in let c_from_back2 = dec_s s_from_back2 in + let c_from_backc = dec_s s_from_backc in (* Format.printf "c_basic=%a@." Foo_pp.pp_company c_basic; Format.printf "c_from_back=%a@." Foo_pp.pp_company c_from_back; @@ -905,6 +1027,7 @@ module Nested = struct assert (c_basic = c_buffers_nested); assert (c_basic = c_from_back); assert (c_basic = c_from_back2); + assert (c_basic = c_from_backc); () end @@ -919,6 +1042,7 @@ let test_nested_enc n = Nested.bench_buffers_nested company; Nested.bench_from_back company; Nested.bench_from_back_noinline company; + Nested.bench_from_back_c company; ]) let () = diff --git a/benchs/dune b/benchs/dune index 8a8d0641..207e4ef7 100644 --- a/benchs/dune +++ b/benchs/dune @@ -1,6 +1,7 @@ (executable (name benchs) (ocamlopt_flags :standard -inline 100) + (foreign_stubs (language c) (flags :standard -std=c99 -O2) (names stubs)) (libraries ocaml-protoc benchmark)) (rule From cac2a34030666283bcd0bf9e4a45d3dcc3c72212 Mon Sep 17 00:00:00 2001 From: Simon Cruanes Date: Sat, 7 Oct 2023 16:55:55 -0400 Subject: [PATCH 41/46] add missing stubs --- benchs/stubs.c | 80 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 80 insertions(+) create mode 100644 benchs/stubs.c diff --git a/benchs/stubs.c b/benchs/stubs.c new file mode 100644 index 00000000..ce2398ac --- /dev/null +++ b/benchs/stubs.c @@ -0,0 +1,80 @@ + +#include +#include +#include +#include +#include + +inline int pbrt_varint_size(int64_t i) { + int n = 0; + while (1) { + n++; + int64_t cur = i & 0x7f; + if (cur == i) + break; + i = i >> 7; + } + return n; +} + +// number of bytes for i +CAMLprim value caml_pbrt_varint_size(int64_t i) { + int res = pbrt_varint_size(i); + return Val_int(res); +} + +CAMLprim value caml_pbrt_varint_size_byte(value v_i) { + CAMLparam1(v_i); + + int64_t i = Int64_val(v_i); + int res = pbrt_varint_size(i); + CAMLreturn(Val_int(res)); +} + +// write i at str[idx…] +inline void pbrt_varint(unsigned char *str, int64_t i) { + while (true) { + int64_t cur = i & 0x7f; + if (cur == i) { + *str = (unsigned char)cur; + break; + } else { + *str = (unsigned char)(cur | 0x80); + i = i >> 7; + ++str; + } + } +} + +// let[@inline] varint (i : int64) (e : t) : unit = +// let n_bytes = varint_size i in +// let start = reserve_n e n_bytes in +// +// let i = ref i in +// for j = 0 to n_bytes - 1 do +// let cur = Int64.(logand !i 0x7fL) in +// if j = n_bytes - 1 then +// Bytes.set e.b (start + j) (Char.unsafe_chr Int64.(to_int cur)) +// else ( +// Bytes.set e.b (start + j) +// (Char.unsafe_chr Int64.(to_int (logor 0x80L cur))); +// i := Int64.shift_right_logical !i 7 +// ) +// done + +// write `i` starting at `idx` +CAMLprim value caml_pbrt_varint(value _str, intnat idx, int64_t i) { + CAMLparam1(_str); + char *str = Bytes_val(_str); + pbrt_varint(str + idx, i); + CAMLreturn(Val_unit); +} + +CAMLprim value caml_pbrt_varint_bytes(value _str, value _idx, value _i) { + CAMLparam3(_str, _idx, _i); + char *str = Bytes_val(_str); + int idx = Int_val(_idx); + int64_t i = Int64_val(_idx); + pbrt_varint(str + idx, i); + CAMLreturn(Val_unit); +} From e9f76aada4c0786f48ffe7251143454b65e44a85 Mon Sep 17 00:00:00 2001 From: Simon Cruanes Date: Tue, 21 Nov 2023 21:20:48 -0500 Subject: [PATCH 42/46] benchs: have more nesting in encoding bench --- benchs/benchs.ml | 106 +++++++++++++++++++++++++++++++++++------------ benchs/foo.proto | 1 + 2 files changed, 81 insertions(+), 26 deletions(-) diff --git a/benchs/benchs.ml b/benchs/benchs.ml index f436c031..3aee461a 100644 --- a/benchs/benchs.ml +++ b/benchs/benchs.ml @@ -441,6 +441,7 @@ module Nested = struct type company = Foo.company = { name: string; stores: store list; + subsidiaries: company list; } type payload_kind = Pbrt.payload_kind = @@ -479,29 +480,40 @@ module Nested = struct E.list (fun p e -> E.key 3 Bytes (E.nested (enc_person p)) e) st.clients e; () - let enc_company (c : company) (e : E.t) : unit = + let rec enc_company (c : company) (e : E.t) : unit = E.key 1 Bytes (E.string c.name) e; E.list (fun st e -> E.key 2 Bytes (E.nested (enc_store st)) e) c.stores e; + E.list + (fun st c -> E.key 3 Bytes (E.nested (enc_company st)) c) + c.subsidiaries e; () end let spf = Printf.sprintf - let mk_company n = + (* company, with [n] stores and [2^depth] subsidiaries *) + let rec mk_company ~n ~depth : company = { name = "bigcorp"; + subsidiaries = + (if depth = 0 then + [] + else ( + let c = mk_company ~n ~depth:(depth - 1) in + [ c; c ] + )); stores = List.init n (fun i -> { address = spf "%d foobar street" i; clients = - List.init 30 (fun j -> + List.init 2 (fun j -> { name = spf "client_%d_%d" i j; age = Int64.of_int ((j mod 30) + 15); }); employees = - List.init 5 (fun j -> + List.init 2 (fun j -> { name = spf "employee_%d_%d" i j; age = Int64.of_int ((j mod 30) + 18); @@ -513,11 +525,15 @@ module Nested = struct include Make_enc (E) let bench company = - let enc = E.create () in mk_t (spf "nenc-%s" E.name_of_enc) @@ fun () -> - Sys.opaque_identity - (E.clear enc; - enc_company company enc) + for _i = 1 to 10 do + let enc = E.create () in + for _j = 1 to 10 do + Sys.opaque_identity + (E.clear enc; + enc_company company enc) + done + done let string_of_company c = let e = E.create () in @@ -994,13 +1010,22 @@ module Nested = struct let bench_from_back_noinline = From_back_noinline.bench let bench_from_back_c = From_back_c.bench + let pp_size ~n ~depth = + Printf.printf "bench nested enc: length for n=%d, depth=%d is %d B\n" n + depth + (String.length (Basic.string_of_company @@ mk_company ~n ~depth)) + (* sanity check *) - let () = - let s_basic = Basic.string_of_company (mk_company 1) in - let s_buffers_nested = Buffers_nested.string_of_company (mk_company 1) in - let s_from_back = From_back.string_of_company (mk_company 1) in - let s_from_back2 = From_back_noinline.string_of_company (mk_company 1) in - let s_from_backc = From_back_c.string_of_company (mk_company 1) in + let check ~n ~depth () = + let s_basic = Basic.string_of_company (mk_company ~n ~depth) in + let s_buffers_nested = + Buffers_nested.string_of_company (mk_company ~n ~depth) + in + let s_from_back = From_back.string_of_company (mk_company ~n ~depth) in + let s_from_back2 = + From_back_noinline.string_of_company (mk_company ~n ~depth) + in + let s_from_backc = From_back_c.string_of_company (mk_company ~n ~depth) in (* Printf.printf "basic:\n(len=%d) %S\n" (String.length s_basic) s_basic; Printf.printf "from_back:\n(len=%d) %S\n" @@ -1029,20 +1054,28 @@ module Nested = struct assert (c_basic = c_from_back2); assert (c_basic = c_from_backc); () + + let () = + List.iter + (fun (n, depth) -> check ~n ~depth ()) + [ 1, 3; 2, 4; 10, 1; 20, 2 ] end -let test_nested_enc n = +let test_nested_enc ~n ~depth = let open B.Tree in - let company = Nested.mk_company n in - Printf.sprintf "%d" n + let company = Nested.mk_company ~n ~depth in + Printf.sprintf "n=%d,depth=%d" n depth @> lazy - (B.throughputN ~repeat:4 3 + (Nested.pp_size ~n ~depth; + B.throughputN ~repeat:4 3 [ Nested.bench_basic company; Nested.bench_buffers_nested company; Nested.bench_from_back company; + (* Nested.bench_from_back_noinline company; Nested.bench_from_back_c company; + *) ]) let () = @@ -1050,14 +1083,35 @@ let () = register @@ "nested" @>>> [ "enc" - @>>> [ - test_nested_enc 2; - test_nested_enc 5; - test_nested_enc 10; - test_nested_enc 20; - test_nested_enc 50; - test_nested_enc 100; - ]; + @>>> List.map + (fun (n, depth) -> test_nested_enc ~n ~depth) + [ + 1, 1; + 1, 1; + 1, 4; + 1, 6; + 1, 10; + 2, 1; + 2, 4; + 2, 6; + 2, 10; + 5, 1; + 5, 4; + 5, 6; + 10, 1; + 10, 2; + 10, 3; + 10, 4; + 20, 1; + 50, 1; + 20, 3; + 20, 4; + 50, 1; + 50, 3; + 50, 4; + 100, 1; + 100, 3; + ]; ] let () = B.Tree.run_global () diff --git a/benchs/foo.proto b/benchs/foo.proto index ba1e202e..13a07d46 100644 --- a/benchs/foo.proto +++ b/benchs/foo.proto @@ -15,4 +15,5 @@ message Store { message Company { string name = 1; repeated Store stores = 2; + repeated Company subsidiaries = 3; } From 19e166c57a8a0d2a119510aab30dbcdb1b61e8c2 Mon Sep 17 00:00:00 2001 From: Simon Cruanes Date: Tue, 21 Nov 2023 22:10:52 -0500 Subject: [PATCH 43/46] more benchs --- benchs/benchs.ml | 134 +++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 130 insertions(+), 4 deletions(-) diff --git a/benchs/benchs.ml b/benchs/benchs.ml index 3aee461a..4a27da3d 100644 --- a/benchs/benchs.ml +++ b/benchs/benchs.ml @@ -1,5 +1,6 @@ module B = Benchmark +let spf = Printf.sprintf let mk_t name f = name, f, () module Enc = struct @@ -426,6 +427,135 @@ let () = (* "enc" @>>> [ test_enc 5; test_enc 10; test_enc 50; test_enc 1000 ]; *) ] +module Varint_size = struct + type run_loop = n:int -> unit + + module While_inline = struct + (** Number of bytes to encode [i] *) + let[@inline] varint_size (i : int64) : int = + let i = ref i in + let n = ref 0 in + let continue = ref true in + while !continue do + incr n; + let cur = Int64.(logand !i 0x7fL) in + if cur = !i then + continue := false + else + i := Int64.shift_right_logical !i 7 + done; + !n + + let loop ~n = + for i = 1 to n do + ignore (Sys.opaque_identity (varint_size (Int64.of_int i)) : int) + done + end + + module While_noinline = struct + let[@inline never] varint_size (i : int64) : int = + let i = ref i in + let n = ref 0 in + let continue = ref true in + while !continue do + incr n; + let cur = Int64.(logand !i 0x7fL) in + if cur = !i then + continue := false + else + i := Int64.shift_right_logical !i 7 + done; + !n + + let loop ~n = + for i = 1 to n do + ignore (Sys.opaque_identity (varint_size (Int64.of_int i)) : int) + done + end + + module For_loop = struct + external int_of_bool : bool -> int = "%identity" + + let[@inline] varint_size (i : int64) : int = + let i = ref i in + let n = ref 0 in + for _j = 0 to 10 do + n := !n + int_of_bool (not (Int64.equal !i 0L)); + i := Int64.shift_right_logical !i 7 + done; + !n + + let loop ~n = + for i = 1 to n do + ignore (Sys.opaque_identity (varint_size (Int64.of_int i)) : int) + done + end + + module C_while = struct + external varint_size : (int64[@unboxed]) -> int + = "caml_pbrt_varint_size_byte" "caml_pbrt_varint_size" + [@@noalloc] + + let loop ~n = + for i = 1 to n do + ignore (Sys.opaque_identity (varint_size (Int64.of_int i)) : int) + done + end + + (* sanity checks *) + let () = + List.iter + (fun i -> + let i = Int64.of_int i in + let c1 = While_inline.varint_size i in + let c2 = While_noinline.varint_size i in + let c3 = For_loop.varint_size i in + let c4 = C_while.varint_size i in + assert (c1 = c2); + assert (c1 = c3); + assert (c1 = c4)) + [ + 1; + 2; + 3; + 10; + 15; + 20; + 21; + 22; + 30; + 50; + 100; + 300; + 1000; + 2000; + 100_000; + 1_000_000_000; + max_int - 10; + max_int; + ] +end + +let test_varint_size n = + let open B.Tree in + let mkbench name (run : Varint_size.run_loop) = + mk_t (spf "varint-size-%s" name) @@ fun () -> Sys.opaque_identity (run ~n) + in + + spf "%d" n + @> lazy + (B.throughputN ~repeat:4 3 + [ + mkbench "while-inline" Varint_size.While_inline.loop; + mkbench "while-noinline" Varint_size.While_noinline.loop; + mkbench "for-loop" Varint_size.For_loop.loop; + mkbench "c-while" Varint_size.C_while.loop; + ]) + +let () = + let open B.Tree in + register @@ "varint-size" @>>> List.map test_varint_size [ 1000; 100_000 ] + module Nested = struct type person = Foo.person = { name: string; @@ -489,8 +619,6 @@ module Nested = struct () end - let spf = Printf.sprintf - (* company, with [n] stores and [2^depth] subsidiaries *) let rec mk_company ~n ~depth : company = { @@ -1072,10 +1200,8 @@ let test_nested_enc ~n ~depth = Nested.bench_basic company; Nested.bench_buffers_nested company; Nested.bench_from_back company; - (* Nested.bench_from_back_noinline company; Nested.bench_from_back_c company; - *) ]) let () = From 47add5983f775bd7a34f56e30f96b4afff96cd17 Mon Sep 17 00:00:00 2001 From: Simon Cruanes Date: Fri, 24 Nov 2023 21:45:48 -0500 Subject: [PATCH 44/46] services: add Client/Server sub-modules to namespace stuff --- src/compilerlib/pb_codegen_services.ml | 207 ++++++++++++++----------- 1 file changed, 113 insertions(+), 94 deletions(-) diff --git a/src/compilerlib/pb_codegen_services.ml b/src/compilerlib/pb_codegen_services.ml index 0e5ca554..5cb39e82 100644 --- a/src/compilerlib/pb_codegen_services.ml +++ b/src/compilerlib/pb_codegen_services.ml @@ -85,88 +85,99 @@ let string_list_of_package (path : string list) : string = let gen_service_client_struct (service : Ot.service) sc : unit = let service_name = service.service_name in - List.iter - (fun (rpc : Ot.rpc) -> - let rpc_name = rpc.rpc_name in - let req, req_mode = ocaml_type_of_rpc_type rpc.rpc_req in - let req_mode_witness = String.capitalize_ascii req_mode in - let res, res_mode = ocaml_type_of_rpc_type rpc.rpc_res in - let res_mode_witness = String.capitalize_ascii res_mode in - F.empty_line sc; - F.linep sc "let %s : (%s, %s, %s, %s) Client.rpc =" - (Pb_codegen_util.function_name_of_rpc rpc) - req req_mode res res_mode; - F.linep sc " (Client.mk_rpc "; - F.linep sc " ~package:%s" - (string_list_of_package service.service_packages); - F.linep sc " ~service_name:%S ~rpc_name:%S" service.service_name - rpc.rpc_name; - F.linep sc " ~req_mode:Client.%s" req_mode_witness; - F.linep sc " ~res_mode:Client.%s" res_mode_witness; - F.linep sc " ~encode_json_req:%s" - (function_name_encode_json ~service_name ~rpc_name rpc.rpc_req); - F.linep sc " ~encode_pb_req:%s" - (function_name_encode_pb ~service_name ~rpc_name rpc.rpc_req); - F.linep sc " ~decode_json_res:%s" - (function_name_decode_json ~service_name ~rpc_name rpc.rpc_res); - F.linep sc " ~decode_pb_res:%s" - (function_name_decode_pb ~service_name ~rpc_name rpc.rpc_res); - let req, req_mode = ocaml_type_of_rpc_type rpc.rpc_req in - let res, res_mode = ocaml_type_of_rpc_type rpc.rpc_res in - F.linep sc " () : (%s, %s, %s, %s) Client.rpc)" req req_mode res - res_mode) - service.service_body + F.line sc "module Client = struct"; + let gen_rpc sc (rpc : Ot.rpc) = + let rpc_name = rpc.rpc_name in + let req, req_mode = ocaml_type_of_rpc_type rpc.rpc_req in + let req_mode_witness = String.capitalize_ascii req_mode in + let res, res_mode = ocaml_type_of_rpc_type rpc.rpc_res in + let res_mode_witness = String.capitalize_ascii res_mode in + F.empty_line sc; + F.linep sc "let %s : (%s, %s, %s, %s) Client.rpc =" + (Pb_codegen_util.function_name_of_rpc rpc) + req req_mode res res_mode; + F.linep sc " (Client.mk_rpc "; + F.linep sc " ~package:%s" + (string_list_of_package service.service_packages); + F.linep sc " ~service_name:%S ~rpc_name:%S" service.service_name + rpc.rpc_name; + F.linep sc " ~req_mode:Client.%s" req_mode_witness; + F.linep sc " ~res_mode:Client.%s" res_mode_witness; + F.linep sc " ~encode_json_req:%s" + (function_name_encode_json ~service_name ~rpc_name rpc.rpc_req); + F.linep sc " ~encode_pb_req:%s" + (function_name_encode_pb ~service_name ~rpc_name rpc.rpc_req); + F.linep sc " ~decode_json_res:%s" + (function_name_decode_json ~service_name ~rpc_name rpc.rpc_res); + F.linep sc " ~decode_pb_res:%s" + (function_name_decode_pb ~service_name ~rpc_name rpc.rpc_res); + let req, req_mode = ocaml_type_of_rpc_type rpc.rpc_req in + let res, res_mode = ocaml_type_of_rpc_type rpc.rpc_res in + F.linep sc " () : (%s, %s, %s, %s) Client.rpc)" req req_mode res res_mode + in + F.sub_scope sc (fun sc -> List.iter (gen_rpc sc) service.service_body); + F.line sc "end" let gen_service_server_struct (service : Ot.service) sc : unit = let service_name = service.service_name in (* generate rpc descriptions for the server side *) - List.iter - (fun (rpc : Ot.rpc) -> - F.empty_line sc; - let rpc_name = rpc.rpc_name in - let name = Pb_codegen_util.function_name_of_rpc rpc in - let req, req_mode = ocaml_type_of_rpc_type rpc.rpc_req in - let res, res_mode = ocaml_type_of_rpc_type rpc.rpc_res in - let req_mode_witness = String.capitalize_ascii req_mode in - let res_mode_witness = String.capitalize_ascii res_mode in - - F.linep sc "let _rpc_%s : (%s,%s,%s,%s) Server.rpc = " name req req_mode - res res_mode; - F.linep sc " (Server.mk_rpc ~name:%S" rpc.rpc_name; - F.linep sc " ~req_mode:Server.%s" req_mode_witness; - F.linep sc " ~res_mode:Server.%s" res_mode_witness; - F.linep sc " ~encode_json_res:%s" - (function_name_encode_json ~service_name ~rpc_name rpc.rpc_res); - F.linep sc " ~encode_pb_res:%s" - (function_name_encode_pb ~service_name ~rpc_name rpc.rpc_res); - F.linep sc " ~decode_json_req:%s" - (function_name_decode_json ~service_name ~rpc_name rpc.rpc_req); - F.linep sc " ~decode_pb_req:%s" - (function_name_decode_pb ~service_name ~rpc_name rpc.rpc_req); - F.linep sc " () : _ Server.rpc)") - service.service_body; - - (* now generate a function from the module type to a [Service_server.t] *) + let gen_rpc sc (rpc : Ot.rpc) = + F.empty_line sc; + let rpc_name = rpc.rpc_name in + let name = Pb_codegen_util.function_name_of_rpc rpc in + let req, req_mode = ocaml_type_of_rpc_type rpc.rpc_req in + let res, res_mode = ocaml_type_of_rpc_type rpc.rpc_res in + let req_mode_witness = String.capitalize_ascii req_mode in + let res_mode_witness = String.capitalize_ascii res_mode in + + F.linep sc "let _rpc_%s : (%s,%s,%s,%s) Server.rpc = " name req req_mode res + res_mode; + F.linep sc " (Server.mk_rpc ~name:%S" rpc.rpc_name; + F.linep sc " ~req_mode:Server.%s" req_mode_witness; + F.linep sc " ~res_mode:Server.%s" res_mode_witness; + F.linep sc " ~encode_json_res:%s" + (function_name_encode_json ~service_name ~rpc_name rpc.rpc_res); + F.linep sc " ~encode_pb_res:%s" + (function_name_encode_pb ~service_name ~rpc_name rpc.rpc_res); + F.linep sc " ~decode_json_req:%s" + (function_name_decode_json ~service_name ~rpc_name rpc.rpc_req); + F.linep sc " ~decode_pb_req:%s" + (function_name_decode_pb ~service_name ~rpc_name rpc.rpc_req); + F.linep sc " () : _ Server.rpc)" + in + + let gen_server sc = + F.line sc "open Pbrt_services"; + List.iter (gen_rpc sc) service.service_body; + + (* now generate a function from the module type to a [Service_server.t] *) + F.empty_line sc; + F.linep sc "let make"; + List.iter + (fun (rpc : Ot.rpc) -> + let name = Pb_codegen_util.function_name_of_rpc rpc in + F.linep sc " ~%s" name) + service.service_body; + F.line sc " () : _ Server.t ="; + F.linep sc " { Server."; + F.linep sc " service_name=%S;" service_name; + F.linep sc " package=%s;" + (string_list_of_package service.service_packages); + F.line sc " handlers=["; + List.iter + (fun (rpc : Ot.rpc) -> + let f = Pb_codegen_util.function_name_of_rpc rpc in + F.linep sc " (%s %s);" f (spf "_rpc_%s" f)) + service.service_body; + F.line sc " ];"; + F.line sc " }" + in + F.empty_line sc; - F.linep sc "let make_server"; - List.iter - (fun (rpc : Ot.rpc) -> - let name = Pb_codegen_util.function_name_of_rpc rpc in - F.linep sc " ~%s" name) - service.service_body; - F.line sc " () : _ Server.t ="; - F.linep sc " { Server."; - F.linep sc " service_name=%S;" service_name; - F.linep sc " package=%s;" (string_list_of_package service.service_packages); - F.line sc " handlers=["; - List.iter - (fun (rpc : Ot.rpc) -> - let f = Pb_codegen_util.function_name_of_rpc rpc in - F.linep sc " (%s %s);" f (spf "_rpc_%s" f)) - service.service_body; - F.line sc " ];"; - F.line sc " }"; + F.line sc "module Server = struct"; + F.sub_scope sc gen_server; + F.line sc "end"; F.empty_line sc let gen_service_struct (service : Ot.service) sc : unit = @@ -191,27 +202,35 @@ let gen_service_sig (service : Ot.service) sc : unit = F.linep sc "open Pbrt_services.Value_mode"; (* client *) - List.iter - (fun (rpc : Ot.rpc) -> - F.empty_line sc; - let req, req_mode = ocaml_type_of_rpc_type rpc.rpc_req in - let res, res_mode = ocaml_type_of_rpc_type rpc.rpc_res in - F.linep sc "val %s : (%s, %s, %s, %s) Client.rpc" - (Pb_codegen_util.function_name_of_rpc rpc) - req req_mode res res_mode) - service.service_body; + let gen_client_rpc sc (rpc : Ot.rpc) = + F.empty_line sc; + let req, req_mode = ocaml_type_of_rpc_type rpc.rpc_req in + let res, res_mode = ocaml_type_of_rpc_type rpc.rpc_res in + F.linep sc "val %s : (%s, %s, %s, %s) Client.rpc" + (Pb_codegen_util.function_name_of_rpc rpc) + req req_mode res res_mode + in + + F.empty_line sc; + F.line sc "module Client : sig"; + F.sub_scope sc (fun sc -> + List.iter (gen_client_rpc sc) service.service_body); + F.line sc "end"; (* server *) F.empty_line sc; - F.line sc "(** Produce a server implementation from handlers *)"; - F.linep sc "val make_server : "; - List.iter - (fun (rpc : Ot.rpc) -> - F.linep sc " %s:(%s -> 'handler) ->" - (Pb_codegen_util.function_name_of_rpc rpc) - (string_of_server_rpc rpc.rpc_req rpc.rpc_res)) - service.service_body; - F.linep sc " unit -> 'handler Server.t"; + F.line sc "module Server : sig"; + F.sub_scope sc (fun sc -> + F.line sc "(** Produce a server implementation from handlers *)"; + F.linep sc "val make : "; + List.iter + (fun (rpc : Ot.rpc) -> + F.linep sc " %s:(%s -> 'handler) ->" + (Pb_codegen_util.function_name_of_rpc rpc) + (string_of_server_rpc rpc.rpc_req rpc.rpc_res)) + service.service_body; + F.linep sc " unit -> 'handler Pbrt_services.Server.t"); + F.line sc "end"; ()); From 84d0dbd4d03309f8b0ad2b9ae57780479db60c41 Mon Sep 17 00:00:00 2001 From: Simon Cruanes Date: Fri, 24 Nov 2023 21:46:02 -0500 Subject: [PATCH 45/46] chore: add myself as maintainer to opam files :^) --- dune-project | 4 ++-- ocaml-protoc.opam | 4 ++-- pbrt.opam | 4 ++-- pbrt_services.opam | 4 ++-- pbrt_yojson.opam | 4 ++-- 5 files changed, 10 insertions(+), 10 deletions(-) diff --git a/dune-project b/dune-project index 0c1b97cf..ec9ff447 100644 --- a/dune-project +++ b/dune-project @@ -3,8 +3,8 @@ (generate_opam_files true) (version 2.4) -(maintainers "Maxime Ransan ") -(authors "Maxime Ransan ") +(maintainers "Maxime Ransan " "Simon Cruanes") +(authors "Maxime Ransan " "Simon Cruanes") (source (github mransan/ocaml-protoc)) (license MIT) diff --git a/ocaml-protoc.opam b/ocaml-protoc.opam index 24c8e06c..618cb62d 100644 --- a/ocaml-protoc.opam +++ b/ocaml-protoc.opam @@ -2,8 +2,8 @@ opam-version: "2.0" version: "2.4" synopsis: "Pure OCaml compiler for .proto files" -maintainer: ["Maxime Ransan "] -authors: ["Maxime Ransan "] +maintainer: ["Maxime Ransan " "Simon Cruanes"] +authors: ["Maxime Ransan " "Simon Cruanes"] license: "MIT" tags: ["protoc" "protobuf" "codegen"] homepage: "https://github.com/mransan/ocaml-protoc" diff --git a/pbrt.opam b/pbrt.opam index 46b39530..b23272a2 100644 --- a/pbrt.opam +++ b/pbrt.opam @@ -2,8 +2,8 @@ opam-version: "2.0" version: "2.4" synopsis: "Runtime library for Protobuf tooling" -maintainer: ["Maxime Ransan "] -authors: ["Maxime Ransan "] +maintainer: ["Maxime Ransan " "Simon Cruanes"] +authors: ["Maxime Ransan " "Simon Cruanes"] license: "MIT" tags: ["protobuf" "encode" "decode"] homepage: "https://github.com/mransan/ocaml-protoc" diff --git a/pbrt_services.opam b/pbrt_services.opam index 4a92a90f..6fd466f3 100644 --- a/pbrt_services.opam +++ b/pbrt_services.opam @@ -2,8 +2,8 @@ opam-version: "2.0" version: "2.4" synopsis: "Runtime library for ocaml-protoc to support RPC services" -maintainer: ["Maxime Ransan "] -authors: ["Maxime Ransan "] +maintainer: ["Maxime Ransan " "Simon Cruanes"] +authors: ["Maxime Ransan " "Simon Cruanes"] license: "MIT" tags: ["protobuf" "encode" "decode" "services" "rpc"] homepage: "https://github.com/mransan/ocaml-protoc" diff --git a/pbrt_yojson.opam b/pbrt_yojson.opam index 6fde790b..6fb78650 100644 --- a/pbrt_yojson.opam +++ b/pbrt_yojson.opam @@ -3,8 +3,8 @@ opam-version: "2.0" version: "2.4" synopsis: "Runtime library for ocaml-protoc to support JSON encoding/decoding" -maintainer: ["Maxime Ransan "] -authors: ["Maxime Ransan "] +maintainer: ["Maxime Ransan " "Simon Cruanes"] +authors: ["Maxime Ransan " "Simon Cruanes"] license: "MIT" tags: ["protobuf" "encode" "decode"] homepage: "https://github.com/mransan/ocaml-protoc" From c1c83928d1e5a6d6b924fef63ffff208915b112c Mon Sep 17 00:00:00 2001 From: Simon Cruanes Date: Fri, 24 Nov 2023 22:02:46 -0500 Subject: [PATCH 46/46] avoid warning --- src/compilerlib/pb_codegen_services.ml | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/compilerlib/pb_codegen_services.ml b/src/compilerlib/pb_codegen_services.ml index 5cb39e82..1979eaf9 100644 --- a/src/compilerlib/pb_codegen_services.ml +++ b/src/compilerlib/pb_codegen_services.ml @@ -115,7 +115,9 @@ let gen_service_client_struct (service : Ot.service) sc : unit = let res, res_mode = ocaml_type_of_rpc_type rpc.rpc_res in F.linep sc " () : (%s, %s, %s, %s) Client.rpc)" req req_mode res res_mode in - F.sub_scope sc (fun sc -> List.iter (gen_rpc sc) service.service_body); + F.sub_scope sc (fun sc -> + F.linep sc "open Pbrt_services"; + List.iter (gen_rpc sc) service.service_body); F.line sc "end" let gen_service_server_struct (service : Ot.service) sc : unit = @@ -183,7 +185,6 @@ let gen_service_server_struct (service : Ot.service) sc : unit = let gen_service_struct (service : Ot.service) sc : unit = F.linep sc "module %s = struct" (mod_name_for_client service); F.sub_scope sc (fun sc -> - F.linep sc "open Pbrt_services"; F.linep sc "open Pbrt_services.Value_mode"; gen_service_client_struct service sc;