diff --git a/src/driver.ml b/src/driver.ml index 6c1673cbf..320e9d05d 100644 --- a/src/driver.ml +++ b/src/driver.ml @@ -114,52 +114,123 @@ module Instrument = struct end module Transform = struct - type t = { + type meta = { name : string; aliases : string list; - impl : - (Expansion_context.Base.t -> - Parsetree.structure -> - Parsetree.structure With_errors.t) - option; - intf : - (Expansion_context.Base.t -> - Parsetree.signature -> - Parsetree.signature With_errors.t) - option; - lint_impl : - (Expansion_context.Base.t -> Parsetree.structure -> Lint_error.t list) - option; - lint_intf : - (Expansion_context.Base.t -> Parsetree.signature -> Lint_error.t list) - option; - preprocess_impl : - (Expansion_context.Base.t -> - Parsetree.structure -> - Parsetree.structure With_errors.t) - option; - preprocess_intf : - (Expansion_context.Base.t -> - Parsetree.signature -> - Parsetree.signature With_errors.t) - option; - enclose_impl : - (Expansion_context.Base.t -> - Location.t option -> - Parsetree.structure * Parsetree.structure) - option; - enclose_intf : - (Expansion_context.Base.t -> - Location.t option -> - Parsetree.signature * Parsetree.signature) - option; - instrument : Instrument.t option; - rules : Context_free.Rule.t list; registered_at : Caller_id.t; } - let has_name t name = - String.equal name t.name || List.exists ~f:(String.equal name) t.aliases + type 'result struct_fun = + Expansion_context.Base.t -> Parsetree.structure -> 'result + + type 'result sig_fun = + Expansion_context.Base.t -> Parsetree.signature -> 'result + + type impl_intf_pass = + [ `Impl of Parsetree.structure With_errors.t struct_fun + | `Intf of Parsetree.signature With_errors.t sig_fun ] + + type lint_pass = + [ `Lint_impl of Lint_error.t list struct_fun + | `Lint_intf of Lint_error.t list sig_fun ] + + type preprocess_pass = + [ `Preprocess_impl of Parsetree.structure With_errors.t struct_fun + | `Preprocess_intf of Parsetree.signature With_errors.t sig_fun ] + + type enclose_pass = + [ `Enclose_impl of + Expansion_context.Base.t -> + Location.t option -> + Parsetree.structure * Parsetree.structure + | `Enclose_intf of + Expansion_context.Base.t -> + Location.t option -> + Parsetree.signature * Parsetree.signature ] + + (* When registering passes with the driver, context-free rules + are merged into one pass *) + type pass = + [ enclose_pass + | impl_intf_pass + | lint_pass + | preprocess_pass + | `Instrument of Instrument.t + | `Ctx_free of Context_free.Rule.t list ] + + (* Passes are registered as a collection of individual passes associated + with some specific meta shared between them *) + type t = pass list * meta + + let create ?impl ?intf ?lint_impl ?lint_intf ?preprocess_impl ?preprocess_intf + ?enclose_impl ?enclose_intf ?instrument ?rules meta = + let impl = Option.map ~f:(fun f -> `Impl f) impl in + let intf = Option.map ~f:(fun f -> `Intf f) intf in + let preprocess_impl = + Option.map ~f:(fun f -> `Preprocess_impl f) preprocess_impl + in + let preprocess_intf = + Option.map ~f:(fun f -> `Preprocess_intf f) preprocess_intf + in + let lint_impl = Option.map ~f:(fun f -> `Lint_impl f) lint_impl in + let lint_intf = Option.map ~f:(fun f -> `Lint_intf f) lint_intf in + let enclose_impl = Option.map ~f:(fun f -> `Enclose_impl f) enclose_impl in + let enclose_intf = Option.map ~f:(fun f -> `Enclose_intf f) enclose_intf in + let instrument = Option.map ~f:(fun f -> `Instrument f) instrument in + let rules = Option.map ~f:(fun f -> `Ctx_free f) rules in + let rec filter_none acc = function + | [] -> List.rev acc + | None :: rest -> filter_none acc rest + | Some v :: rest -> filter_none (v :: acc) rest + in + let passes = + filter_none [] + [ + rules; + impl; + intf; + preprocess_impl; + preprocess_intf; + lint_impl; + lint_intf; + enclose_impl; + enclose_intf; + instrument; + ] + in + (passes, meta) + + let find_impl (passes, _) = + List.find_map ~f:(function `Impl i -> Some i | _ -> None) passes + + let find_intf (passes, _) = + List.find_map ~f:(function `Intf i -> Some i | _ -> None) passes + + let find_lint_impl (passes, _) = + List.find_map ~f:(function `Lint_impl i -> Some i | _ -> None) passes + + let find_lint_intf (passes, _) = + List.find_map ~f:(function `Lint_intf i -> Some i | _ -> None) passes + + let find_enclose_impl (passes, _) = + List.find_map ~f:(function `Enclose_impl i -> Some i | _ -> None) passes + + let find_enclose_intf (passes, _) = + List.find_map ~f:(function `Enclose_intf i -> Some i | _ -> None) passes + + let find_rules (passes, _) = + match + List.find_map ~f:(function `Ctx_free i -> Some i | _ -> None) passes + with + | None -> [] + | Some lst -> lst + + let find_instruments (passes, _) = + List.find_map ~f:(function `Instrument i -> Some i | _ -> None) passes + + let has_name (_, meta) name = + String.equal name meta.name + || List.exists ~f:(String.equal name) meta.aliases let all : t list ref = ref [] @@ -171,42 +242,30 @@ module Transform = struct let register ?(extensions = []) ?(rules = []) ?enclose_impl ?enclose_intf ?impl ?intf ?lint_impl ?lint_intf ?preprocess_impl ?preprocess_intf ?instrument ?(aliases = []) name = - let rules = List.map extensions ~f:Context_free.Rule.extension @ rules in + let rules = + List.map extensions ~f:(fun ctx -> Context_free.Rule.extension ctx) + @ rules + in let caller_id = Caller_id.get ~skip:[ Stdlib.__FILE__ ] in - (match List.filter !all ~f:(fun ct -> has_name ct name) with - | [] -> () - | ct :: _ -> + (match List.find_opt !all ~f:(fun ct -> has_name ct name) with + | None -> () + | Some (_, meta) -> Printf.eprintf "Warning: code transformation %s registered twice.\n" name; Printf.eprintf " - first time was at %a\n" print_caller_id - ct.registered_at; + meta.registered_at; Printf.eprintf " - second time is at %a\n" print_caller_id caller_id); - let impl = Option.map impl ~f:(fun f ctx ast -> return (f ctx ast)) in - let intf = Option.map intf ~f:(fun f ctx ast -> return (f ctx ast)) in - let preprocess_impl = - Option.map preprocess_impl ~f:(fun f ctx ast -> return (f ctx ast)) + let with_errors f ctx ast = return (f ctx ast) in + let impl = Option.map impl ~f:with_errors in + let intf = Option.map intf ~f:with_errors in + let preprocess_impl = Option.map preprocess_impl ~f:with_errors in + let preprocess_intf = Option.map preprocess_intf ~f:with_errors in + let new_ct_meta = { name; aliases; registered_at = caller_id } in + let transform = + create ?impl ?intf ?preprocess_impl ?preprocess_intf ?lint_impl ?lint_intf + ?enclose_impl ?enclose_intf ~rules ?instrument new_ct_meta in - let preprocess_intf = - Option.map preprocess_intf ~f:(fun f ctx ast -> return (f ctx ast)) - in - let ct = - { - name; - aliases; - rules; - enclose_impl; - enclose_intf; - impl; - intf; - lint_impl; - preprocess_impl; - preprocess_intf; - lint_intf; - instrument; - registered_at = caller_id; - } - in - all := ct :: !all + all := transform :: !all let rec last prev l = match l with [] -> prev | x :: l -> last x l @@ -218,9 +277,13 @@ module Transform = struct let last = get_loc (last x l) in Some { first with loc_end = last.loc_end } - let merge_into_generic_mappers t ~embed_errors ~hook ~expect_mismatch_handler - ~tool_name ~input_name = - let { rules; enclose_impl; enclose_intf; impl; intf; _ } = t in + let merge_into_generic_mappers (t : t) ~embed_errors ~hook + ~expect_mismatch_handler ~tool_name ~input_name = + let rules = find_rules t + and enclose_impl = find_enclose_impl t + and enclose_intf = find_enclose_intf t + and impl = find_impl t + and intf = find_intf t in let map = new Context_free.map_top_down rules ~embed_errors ~generated_code_hook:hook ~expect_mismatch_handler @@ -302,113 +365,102 @@ module Transform = struct map#signature base_ctxt (List.concat [ attrs; header; sg; footer ]) >>= fun sg -> match intf with None -> return sg | Some f -> f ctxt sg in - { t with impl = Some map_impl; intf = Some map_intf } + let passes, meta = t in + (`Impl map_impl :: `Intf map_intf :: passes, meta) let builtin_context_free_name = "" let builtin_of_context_free_rewriters ~hook ~rules ~enclose_impl ~enclose_intf ~input_name = - merge_into_generic_mappers ~hook ~input_name + let meta = { name = builtin_context_free_name; aliases = []; - impl = None; - intf = None; - lint_impl = None; - lint_intf = None; - preprocess_impl = None; - preprocess_intf = None; - enclose_impl; - enclose_intf; - instrument = None; - rules; registered_at = Caller_id.get ~skip:[]; } + in + let t = create ~rules ?enclose_impl ?enclose_intf meta in + merge_into_generic_mappers ~hook ~input_name t (* Meant to be used after partitioning *) - let rewrites_not_context_free t = - match t with + let rewrites_not_context_free (t, meta) = + match meta with | { name; _ } when String.equal name builtin_context_free_name -> false - | { - impl = None; - intf = None; - instrument = None; - preprocess_impl = None; - preprocess_intf = None; - _; - } -> - false - | _ -> true + | _ -> + let check_not_context_free = function + | #impl_intf_pass | `Instrument _ | #preprocess_pass -> true + | _ -> false + in + List.exists ~f:check_not_context_free t - let partition_transformations ts = + let partition_transformations (ts : t list) = let before_instrs, after_instrs, rest = - List.fold_left ts ~init:([], [], []) ~f:(fun (bef_i, aft_i, rest) t -> + List.fold_left ts ~init:([], [], []) + ~f:(fun (bef_i, aft_i, rest) ((t, meta) : t) -> let reduced_t = - { - t with - lint_impl = None; - lint_intf = None; - preprocess_impl = None; - preprocess_intf = None; - } + List.filter + ~f:(function #lint_pass | #preprocess_pass -> false | _ -> true) + t + in + let remove_rules t = + List.filter ~f:(function `Ctx_free _ -> false | _ -> true) t in let f instr = (instr.Instrument.position, instr.Instrument.transformation) in - match Option.map t.instrument ~f with + let instrument = find_instruments (t, meta) in + match Option.map instrument ~f with | Some (Before, transf) -> - ( { reduced_t with impl = Some transf; rules = [] } :: bef_i, + ( (`Impl transf :: remove_rules reduced_t, meta) :: bef_i, aft_i, - reduced_t :: rest ) + (reduced_t, meta) :: rest ) | Some (After, transf) -> ( bef_i, - { reduced_t with impl = Some transf; rules = [] } :: aft_i, - reduced_t :: rest ) - | None -> (bef_i, aft_i, reduced_t :: rest)) + (`Impl transf :: remove_rules reduced_t, meta) :: aft_i, + (reduced_t, meta) :: rest ) + | None -> (bef_i, aft_i, (reduced_t, meta) :: rest)) + in + let linters = + List.filter_map ts ~f:(fun (t, meta) -> + let linters = + List.fold_left + ~f:(fun acc -> function #lint_pass as t -> t :: acc | _ -> acc) + ~init:[] t + in + match linters with + | [] -> None + | linters -> + let new_meta = + { + meta with + name = Printf.sprintf "" meta.name; + aliases = []; + } + in + Some (linters, new_meta)) + in + let preprocessors = + List.filter_map ts ~f:(fun (t, meta) -> + let linters = + List.fold_left + ~f:(fun acc -> function + | #preprocess_pass as t -> t :: acc | _ -> acc) + ~init:[] t + in + match linters with + | [] -> None + | linters -> + let new_meta = + { + meta with + name = Printf.sprintf "" meta.name; + aliases = []; + } + in + Some (linters, new_meta)) in - ( `Linters - (List.filter_map ts ~f:(fun t -> - if Option.is_some t.lint_impl || Option.is_some t.lint_intf then - Some - { - name = Printf.sprintf "" t.name; - aliases = []; - impl = None; - intf = None; - lint_impl = t.lint_impl; - lint_intf = t.lint_intf; - enclose_impl = None; - enclose_intf = None; - preprocess_impl = None; - preprocess_intf = None; - instrument = None; - rules = []; - registered_at = t.registered_at; - } - else None)), - `Preprocess - (List.filter_map ts ~f:(fun t -> - if - Option.is_some t.preprocess_impl - || Option.is_some t.preprocess_intf - then - Some - { - name = Printf.sprintf "" t.name; - aliases = []; - impl = t.preprocess_impl; - intf = t.preprocess_intf; - lint_impl = None; - lint_intf = None; - enclose_impl = None; - enclose_intf = None; - preprocess_impl = None; - preprocess_intf = None; - instrument = None; - rules = []; - registered_at = t.registered_at; - } - else None)), + ( `Linters linters, + `Preprocess preprocessors, `Before_instrs before_instrs, `After_instrs after_instrs, `Rest rest ) @@ -492,7 +544,8 @@ let get_whole_ast_passes ~embed_errors ~hook ~expect_mismatch_handler ~tool_name (* Allow only one preprocessor to assure deterministic order *) (if List.length preprocess > 1 then let pp = - String.concat ~sep:", " (List.map preprocess ~f:(fun t -> t.name)) + String.concat ~sep:", " + (List.map preprocess ~f:(fun (_, meta) -> meta.name)) in let err = Printf.sprintf "At most one preprocessor is allowed, while got: %s" pp @@ -506,18 +559,16 @@ let get_whole_ast_passes ~embed_errors ~hook ~expect_mismatch_handler ~tool_name ~expect_mismatch_handler ~input_name) else (let get_enclosers ~f = - List.filter_map transforms ~f:(fun (ct : Transform.t) -> - match f ct with None -> None | Some x -> Some (ct.name, x)) + List.filter_map transforms ~f:(fun ((_, meta) as ct : Transform.t) -> + match f ct with None -> None | Some x -> Some (meta.name, x)) (* Sort them to ensure deterministic ordering *) |> List.sort ~cmp:(fun (a, _) (b, _) -> String.compare a b) |> List.map ~f:snd in - let rules = - List.map transforms ~f:(fun (ct : Transform.t) -> ct.rules) - |> List.concat - and impl_enclosers = get_enclosers ~f:(fun ct -> ct.enclose_impl) - and intf_enclosers = get_enclosers ~f:(fun ct -> ct.enclose_intf) in + let rules = List.map transforms ~f:Transform.find_rules |> List.concat + and impl_enclosers = get_enclosers ~f:Transform.find_enclose_impl + and intf_enclosers = get_enclosers ~f:Transform.find_enclose_intf in match (rules, impl_enclosers, intf_enclosers) with | [], [], [] -> transforms | _ -> @@ -540,9 +591,12 @@ let get_whole_ast_passes ~embed_errors ~hook ~expect_mismatch_handler ~tool_name ~tool_name ~input_name :: transforms) |> List.filter ~f:(fun (ct : Transform.t) -> - match (ct.impl, ct.intf) with None, None -> false | _ -> true) + match (Transform.find_impl ct, Transform.find_intf ct) with + | None, None -> false + | _ -> true) in - linters @ preprocess @ before_instrs @ make_generic cts @ after_instrs + let generic = make_generic cts in + linters @ preprocess @ before_instrs @ generic @ after_instrs let apply_transforms ~tool_name ~file_path ~field ~lint_field ~dropped_so_far ~hook ~expect_mismatch_handler ~input_name ~embed_errors ?rewritten ast = @@ -568,7 +622,9 @@ let apply_transforms ~tool_name ~file_path ~field ~lint_field ~dropped_so_far let acc = List.fold_left cts ~init:(ast, [], [], []) ~f:(fun - (ast, dropped, (lint_errors : _ list), errors) (ct : Transform.t) -> + (ast, dropped, (lint_errors : _ list), errors) + ((ct, meta) : Transform.t) + -> let input_name = match input_name with | Some input_name -> input_name @@ -579,14 +635,14 @@ let apply_transforms ~tool_name ~file_path ~field ~lint_field ~dropped_so_far in let lint_errors, errors = - match lint_field ct with + match lint_field (ct, meta) with | None -> (lint_errors, errors) | Some f -> ( try (lint_errors @ f ctxt ast, errors) with exn when embed_errors -> (lint_errors, exn_to_loc_error exn :: errors)) in - match field ct with + match field (ct, meta) with | None -> (ast, dropped, lint_errors, errors) | Some f -> let (ast, more_errors), errors = @@ -597,7 +653,7 @@ let apply_transforms ~tool_name ~file_path ~field ~lint_field ~dropped_so_far let dropped = if !debug_attribute_drop then ( let new_dropped = dropped_so_far ast in - debug_dropped_attribute ct.name ~old_dropped:dropped + debug_dropped_attribute meta.name ~old_dropped:dropped ~new_dropped; new_dropped) else [] @@ -645,7 +701,7 @@ let print_passes () = in if !perform_checks then Printf.printf "\n"; - List.iter cts ~f:(fun ct -> Printf.printf "%s\n" ct.Transform.name); + List.iter cts ~f:(fun (_, meta) -> Printf.printf "%s\n" meta.Transform.name); if !perform_checks then ( Printf.printf "\n"; if !perform_checks_on_extensions then @@ -715,9 +771,8 @@ let map_structure_gen ~tool_name ~hook ~expect_mismatch_handler ~input_name in let file_path = get_default_path_str st in let st, lint_errors, errors = - apply_transforms st ~tool_name ~file_path - ~field:(fun (ct : Transform.t) -> ct.impl) - ~lint_field:(fun (ct : Transform.t) -> ct.lint_impl) + apply_transforms st ~tool_name ~file_path ~field:Transform.find_impl + ~lint_field:Transform.find_lint_impl ~dropped_so_far:Attribute.dropped_so_far_structure ~hook ~expect_mismatch_handler ~input_name ~embed_errors ?rewritten in @@ -791,9 +846,8 @@ let map_signature_gen ~tool_name ~hook ~expect_mismatch_handler ~input_name in let file_path = get_default_path_sig sg in let sg, lint_errors, errors = - apply_transforms sg ~tool_name ~file_path - ~field:(fun (ct : Transform.t) -> ct.intf) - ~lint_field:(fun (ct : Transform.t) -> ct.lint_intf) + apply_transforms sg ~tool_name ~file_path ~field:Transform.find_intf + ~lint_field:Transform.find_lint_intf ~dropped_so_far:Attribute.dropped_so_far_signature ~hook ~expect_mismatch_handler ~input_name ~embed_errors ?rewritten in @@ -1242,8 +1296,8 @@ let set_output_mode mode = (arg_of_output_mode y))) let print_transformations () = - List.iter !Transform.all ~f:(fun (ct : Transform.t) -> - Printf.printf "%s\n" ct.name) + List.iter !Transform.all ~f:(fun ((_, meta) : Transform.t) -> + Printf.printf "%s\n" meta.name) let parse_apply_list s = let names = @@ -1282,7 +1336,7 @@ let handle_dont_apply s = let interpret_mask () = if Option.is_some mask.apply || Option.is_some mask.dont_apply then - let selected_transform_name ct = + let selected_transform_name ((_, meta) as ct : Transform.t) = let is_candidate = match mask.apply with | None -> true @@ -1294,7 +1348,7 @@ let interpret_mask () = | Some names -> is_candidate && not (List.exists names ~f:(Transform.has_name ct)) in - if is_selected then Some ct.name else None + if is_selected then Some meta.Transform.name else None in apply_list := Some (List.filter_map !Transform.all ~f:selected_transform_name) diff --git a/stdppx/stdppx.ml b/stdppx/stdppx.ml index 2bcc4b8e0..5741f25b8 100644 --- a/stdppx/stdppx.ml +++ b/stdppx/stdppx.ml @@ -241,6 +241,12 @@ module List = struct let find_map_exn list ~f = match find_map list ~f with Some x -> x | None -> raise Not_found + let rec find_opt list ~f = + match list with + | [] -> None + | head :: tail -> + if f head then Some head else find_opt tail ~f + let rec last = function | [] -> None | [ x ] -> Some x