diff --git a/extensions/prost/defs.bzl b/extensions/prost/defs.bzl index 6d3ddb7041..40b3af086c 100644 --- a/extensions/prost/defs.bzl +++ b/extensions/prost/defs.bzl @@ -146,6 +146,11 @@ load( _rust_prost_library = "rust_prost_library", _rust_prost_toolchain = "rust_prost_toolchain", ) +load( + "//private:prost_transform.bzl", + _rust_prost_transform = "rust_prost_transform", +) rust_prost_library = _rust_prost_library rust_prost_toolchain = _rust_prost_toolchain +rust_prost_transform = _rust_prost_transform diff --git a/extensions/prost/private/prost.bzl b/extensions/prost/private/prost.bzl index 1a87d6c187..8713d7c86a 100644 --- a/extensions/prost/private/prost.bzl +++ b/extensions/prost/private/prost.bzl @@ -19,6 +19,7 @@ load("@rules_rust//rust/private:rustc.bzl", "rustc_compile_action") # buildifier: disable=bzl-visibility load("@rules_rust//rust/private:utils.bzl", "can_build_metadata") load("//:providers.bzl", "ProstProtoInfo") +load(":prost_transform.bzl", "ProstTransformInfo") RUST_EDITION = "2021" @@ -39,7 +40,15 @@ def _create_proto_lang_toolchain(ctx, prost_toolchain): return proto_lang_toolchain -def _compile_proto(ctx, crate_name, proto_info, deps, prost_toolchain, rustfmt_toolchain = None): +def _compile_proto( + *, + ctx, + crate_name, + proto_info, + transform_infos, + deps, + prost_toolchain, + rustfmt_toolchain = None): deps_info_file = ctx.actions.declare_file(ctx.label.name + ".prost_deps_info") dep_package_infos = [dep[ProstProtoInfo].package_info for dep in deps] ctx.actions.write( @@ -53,6 +62,14 @@ def _compile_proto(ctx, crate_name, proto_info, deps, prost_toolchain, rustfmt_t proto_compiler = prost_toolchain.proto_compiler tools = depset([proto_compiler.executable]) + tonic_opts = [] + prost_opts = [] + additional_srcs = [] + for transform_info in transform_infos: + tonic_opts.extend(transform_info.tonic_opts) + prost_opts.extend(transform_info.prost_opts) + additional_srcs.append(transform_info.srcs) + direct_crate_names = [dep[ProstProtoInfo].dep_variant_info.crate_info.name for dep in deps] additional_args = ctx.actions.args() @@ -65,7 +82,12 @@ def _compile_proto(ctx, crate_name, proto_info, deps, prost_toolchain, rustfmt_t additional_args.add("--direct_dep_crate_names={}".format(",".join(direct_crate_names))) additional_args.add("--prost_opt=compile_well_known_types") additional_args.add("--descriptor_set={}".format(proto_info.direct_descriptor_set.path)) - additional_args.add_all(prost_toolchain.prost_opts, format_each = "--prost_opt=%s") + + all_additional_srcs = depset(transitive = additional_srcs) + if additional_srcs: + additional_args.add("--additional_srcs={}".format(",".join([f.path for f in all_additional_srcs.to_list()]))) + + additional_args.add_all(prost_toolchain.prost_opts + prost_opts, format_each = "--prost_opt=%s") if prost_toolchain.tonic_plugin: tonic_plugin = prost_toolchain.tonic_plugin[DefaultInfo].files_to_run @@ -73,14 +95,18 @@ def _compile_proto(ctx, crate_name, proto_info, deps, prost_toolchain, rustfmt_t additional_args.add("--tonic_opt=no_include") additional_args.add("--tonic_opt=compile_well_known_types") additional_args.add("--is_tonic") - additional_args.add_all(prost_toolchain.tonic_opts, format_each = "--tonic_opt=%s") + + additional_args.add_all(prost_toolchain.tonic_opts + tonic_opts, format_each = "--tonic_opt=%s") tools = depset([tonic_plugin.executable], transitive = [tools]) if rustfmt_toolchain: additional_args.add("--rustfmt={}".format(rustfmt_toolchain.rustfmt.path)) tools = depset(transitive = [tools, rustfmt_toolchain.all_files]) - additional_inputs = depset([deps_info_file, proto_info.direct_descriptor_set] + [dep[ProstProtoInfo].package_info for dep in deps]) + additional_inputs = depset( + [deps_info_file, proto_info.direct_descriptor_set] + [dep[ProstProtoInfo].package_info for dep in deps], + transitive = [all_additional_srcs], + ) proto_common.compile( actions = ctx.actions, @@ -116,7 +142,14 @@ def _get_cc_info(providers): return provider fail("Couldn't find a CcInfo in the list of providers") -def _compile_rust(ctx, attr, crate_name, src, deps, edition): +def _compile_rust( + *, + ctx, + attr, + crate_name, + src, + deps, + edition): """Compiles a Rust source file. Args: @@ -233,7 +266,14 @@ def _rust_prost_aspect_impl(target, ctx): if RustAnalyzerInfo in proto_dep: rust_analyzer_deps.append(proto_dep[RustAnalyzerInfo]) - deps = runtime_deps + direct_deps + transform_infos = [] + for data_target in getattr(ctx.rule.attr, "data", []): + if ProstTransformInfo in data_target: + transform_infos.append(data_target[ProstTransformInfo]) + + rust_deps = runtime_deps + direct_deps + for transform_info in transform_infos: + rust_deps.extend(transform_info.deps) crate_name = ctx.label.name.replace("-", "_").replace("/", "_") @@ -243,6 +283,7 @@ def _rust_prost_aspect_impl(target, ctx): ctx = ctx, crate_name = crate_name, proto_info = proto_info, + transform_infos = transform_infos, deps = proto_deps, prost_toolchain = prost_toolchain, rustfmt_toolchain = rustfmt_toolchain, @@ -253,7 +294,7 @@ def _rust_prost_aspect_impl(target, ctx): attr = ctx.rule.attr, crate_name = crate_name, src = lib_rs, - deps = deps, + deps = rust_deps, edition = RUST_EDITION, ) @@ -495,7 +536,7 @@ def _current_prost_runtime_impl(ctx): )] current_prost_runtime = rule( - doc = "A rule for accessing the current Prost toolchain components needed by the process wrapper", + doc = "A rule for accessing the current Prost toolchain components needed by the process wrapper.", provides = [rust_common.crate_group_info], implementation = _current_prost_runtime_impl, toolchains = [TOOLCHAIN_TYPE], diff --git a/extensions/prost/private/prost_transform.bzl b/extensions/prost/private/prost_transform.bzl new file mode 100644 index 0000000000..1712a49e86 --- /dev/null +++ b/extensions/prost/private/prost_transform.bzl @@ -0,0 +1,88 @@ +"""Prost rules.""" + +load("@rules_rust//rust:defs.bzl", "rust_common") + +ProstTransformInfo = provider( + doc = "Info about transformations to apply to Prost generated source code.", + fields = { + "deps": "List[DepVariantInfo]: Additional dependencies to compile into the Prost target.", + "prost_opts": "List[str]: Additional prost flags.", + "srcs": "Depset[File]: Additional source files to include in generated Prost source code.", + "tonic_opts": "List[str]: Additional tonic flags.", + }, +) + +def _rust_prost_transform_impl(ctx): + deps = [] + for target in ctx.attr.deps: + deps.append(rust_common.dep_variant_info( + crate_info = target[rust_common.crate_info] if rust_common.crate_info in target else None, + dep_info = target[rust_common.dep_info] if rust_common.dep_info in target else None, + cc_info = target[CcInfo] if CcInfo in target else None, + build_info = None, + )) + + # DefaultInfo is intentionally not returned here to avoid impacting other + # consumers of the `proto_library` target this rule is expected to be passed + # to. + return [ProstTransformInfo( + deps = deps, + prost_opts = ctx.attr.prost_opts, + srcs = depset(ctx.files.srcs), + tonic_opts = ctx.attr.tonic_opts, + )] + +rust_prost_transform = rule( + doc = """\ +A rule for transforming the outputs of `ProstGenProto` actions. + +This rule is used by adding it to the `data` attribute of `proto_library` targets. E.g. +```python +load("@rules_proto//proto:defs.bzl", "proto_library") +load("@rules_rust_prost//:defs.bzl", "rust_prost_library", "rust_prost_transform") + +rust_prost_transform( + name = "a_transform", + srcs = [ + "a_src.rs", + ], +) + +proto_library( + name = "a_proto", + srcs = [ + "a.proto", + ], + data = [ + ":transform", + ], +) + +rust_prost_library( + name = "a_rs_proto", + proto = ":a_proto", +) +``` + +The `rust_prost_library` will spawn an action on the `a_proto` target which consumes the +`a_transform` rule to provide a means of granularly modifying a proto library for `ProstGenProto` +actions with minimal impact to other consumers. +""", + implementation = _rust_prost_transform_impl, + attrs = { + "deps": attr.label_list( + doc = "Additional dependencies to add to the compiled crate.", + providers = [[rust_common.crate_info], [rust_common.crate_group_info]], + ), + "prost_opts": attr.string_list( + doc = "Additional options to add to Prost.", + ), + "srcs": attr.label_list( + doc = "Additional source files to include in generated Prost source code.", + allow_files = True, + ), + "tonic_opts": attr.string_list( + doc = "Additional options to add to Tonic.", + ), + }, +) diff --git a/extensions/prost/private/protoc_wrapper.rs b/extensions/prost/private/protoc_wrapper.rs index 0facb982b8..842845a840 100644 --- a/extensions/prost/private/protoc_wrapper.rs +++ b/extensions/prost/private/protoc_wrapper.rs @@ -102,6 +102,9 @@ impl Module { } } +const ADDITIONAL_CONTENT_HEADER: &str = + "// A D D I T I O N A L S O U R C E S ========================================"; + /// Generate a lib.rs file with all prost/tonic outputs embeeded in modules which /// mirror the proto packages. For the example proto file we would expect to see /// the Rust output that follows it. @@ -152,6 +155,7 @@ fn generate_lib_rs( prost_outputs: &BTreeSet, is_tonic: bool, direct_dep_crate_names: Vec, + additional_content: String, ) -> String { let mut contents = vec!["// @generated".to_string(), "".to_string()]; for crate_name in direct_dep_crate_names { @@ -193,6 +197,14 @@ fn generate_lib_rs( let mut content = String::new(); write_module(&mut content, &module_info, 0); + + if !additional_content.is_empty() { + return format!( + "{}\n\n{}\n\n{}", + content, ADDITIONAL_CONTENT_HEADER, additional_content + ); + } + content } @@ -421,6 +433,9 @@ struct Args { /// The proto files to compile. proto_files: Vec, + /// Additional source files to append to the generated rust source. + additional_srcs: Vec, + /// The include directories. includes: Vec, @@ -454,6 +469,7 @@ impl Args { let mut crate_name: Option = None; let mut package_info_file: Option = None; let mut proto_files: Vec = Vec::new(); + let mut additional_srcs: Vec = Vec::new(); let mut includes = Vec::new(); let mut descriptor_set = None; let mut out_librs: Option = None; @@ -521,6 +537,9 @@ impl Args { } } } + ("--additional_srcs", value) => { + additional_srcs.extend(value.split(',').map(PathBuf::from).collect::>()); + } ("--direct_dep_crate_names", value) => { if value.trim().is_empty() { return; @@ -614,6 +633,7 @@ impl Args { crate_name: crate_name.unwrap(), package_info_file: package_info_file.unwrap(), proto_files, + additional_srcs, includes, descriptor_set: descriptor_set.unwrap(), out_librs: out_librs.unwrap(), @@ -717,6 +737,7 @@ fn main() { label, package_info_file, proto_files, + additional_srcs, includes, descriptor_set, out_librs, @@ -733,6 +754,19 @@ fn main() { let package_name = get_package_name(&descriptor_set).unwrap_or_default(); let expect_rs = expect_fs_file_to_be_generated(&descriptor_set); let has_services = has_services(&descriptor_set); + let additional_content = additional_srcs + .into_iter() + .map(|f| { + fs::read_to_string(&f).unwrap_or_else(|e| { + panic!( + "Failed to read additional source file: `{}`\n{:?}", + f.display(), + e + ) + }) + }) + .collect::>() + .join("\n"); if has_services && !is_tonic { eprintln!("Warning: Service definitions will not be generated because the prost toolchain did not define a tonic plugin."); @@ -875,7 +909,12 @@ fn main() { // Write outputs fs::write( &out_librs, - generate_lib_rs(&rust_files, is_tonic, direct_dep_crate_names), + generate_lib_rs( + &rust_files, + is_tonic, + direct_dep_crate_names, + additional_content, + ), ) .expect("Failed to write file."); fs::write( diff --git a/extensions/prost/private/tests/transform/BUILD.bazel b/extensions/prost/private/tests/transform/BUILD.bazel new file mode 100644 index 0000000000..b7b0dce435 --- /dev/null +++ b/extensions/prost/private/tests/transform/BUILD.bazel @@ -0,0 +1,43 @@ +load("@rules_proto//proto:defs.bzl", "proto_library") +load("@rules_rust//rust:defs.bzl", "rust_test") +load("//:defs.bzl", "rust_prost_library", "rust_prost_transform") + +package(default_visibility = ["//private/tests:__subpackages__"]) + +rust_prost_transform( + name = "transform", + srcs = ["a_src.rs"], +) + +proto_library( + name = "a_proto", + srcs = [ + "a.proto", + ], + data = [ + ":transform", + ], + strip_import_prefix = "/private/tests/transform", + deps = [ + "//private/tests/transform/b:b_proto", + "//private/tests/types:types_proto", + "@com_google_protobuf//:duration_proto", + "@com_google_protobuf//:timestamp_proto", + ], +) + +rust_prost_library( + name = "a_rs_proto", + proto = ":a_proto", +) + +rust_test( + name = "a_test", + srcs = ["a_test.rs"], + edition = "2021", + deps = [ + ":a_rs_proto", + # Add b_proto as a dependency directly to ensure compatibility with `a.proto`'s imports. + "//private/tests/transform/b:b_rs_proto", + ], +) diff --git a/extensions/prost/private/tests/transform/a.proto b/extensions/prost/private/tests/transform/a.proto new file mode 100644 index 0000000000..b619ef1a58 --- /dev/null +++ b/extensions/prost/private/tests/transform/a.proto @@ -0,0 +1,20 @@ +syntax = "proto3"; + +import "google/protobuf/timestamp.proto"; +import "google/protobuf/duration.proto"; +import "b/b.proto"; +import "types/types.proto"; + +package a; + +message A { + string name = 1; + + a.b.B b = 2; + + google.protobuf.Timestamp timestamp = 3; + + google.protobuf.Duration duration = 4; + + Types types = 5; +} diff --git a/extensions/prost/private/tests/transform/a_src.rs b/extensions/prost/private/tests/transform/a_src.rs new file mode 100644 index 0000000000..857d9491af --- /dev/null +++ b/extensions/prost/private/tests/transform/a_src.rs @@ -0,0 +1,9 @@ +// Additional source code for `a.proto`. + +use std::fmt::{Display, Formatter, Result}; + +impl Display for crate::a::A { + fn fmt(&self, f: &mut Formatter<'_>) -> Result { + write!(f, "Display of: A") + } +} diff --git a/extensions/prost/private/tests/transform/a_test.rs b/extensions/prost/private/tests/transform/a_test.rs new file mode 100644 index 0000000000..161a977b5c --- /dev/null +++ b/extensions/prost/private/tests/transform/a_test.rs @@ -0,0 +1,58 @@ +//! Tests transitive dependencies. + +use a_proto::a::A; +use a_proto::b_proto::c_proto::a::b::c::C; +use a_proto::duration_proto::google::protobuf::Duration; +use a_proto::timestamp_proto::google::protobuf::Timestamp; +use a_proto::types_proto::Types; + +#[test] +fn test_a() { + let duration = Duration { + seconds: 1, + nanos: 2, + }; + + let a = A { + name: "a".to_string(), + // Ensure the external `b_proto` dependency is compatible with `a_proto`'s `B`. + b: Some(b_proto::a::b::B { + name: "b".to_string(), + c: Some(C { + name: "c".to_string(), + duration: Some(duration), + ..Default::default() + }), + ..Default::default() + }), + timestamp: Some(Timestamp { + seconds: 1, + nanos: 2, + }), + duration: Some(duration), + types: Some(Types::default()), + }; + + assert_eq!( + "Display of: A", + format!("{}", a), + "Unexpected `Display` implementation for {:#?}", + a + ); +} + +#[test] +fn test_b() { + use b_proto::Greeting; + + let b = b_proto::a::b::B { + name: "b".to_string(), + c: Some(C { + name: "c".to_string(), + ..Default::default() + }), + ..Default::default() + }; + + assert_eq!("Hallo, Bazel, my name is B!", b.greet("Bazel")); +} diff --git a/extensions/prost/private/tests/transform/b/BUILD.bazel b/extensions/prost/private/tests/transform/b/BUILD.bazel new file mode 100644 index 0000000000..73c46e499f --- /dev/null +++ b/extensions/prost/private/tests/transform/b/BUILD.bazel @@ -0,0 +1,47 @@ +load("@rules_proto//proto:defs.bzl", "proto_library") +load("@rules_rust//rust:defs.bzl", "rust_library", "rust_test") +load("//:defs.bzl", "rust_prost_library", "rust_prost_transform") + +package(default_visibility = ["//private/tests:__subpackages__"]) + +rust_library( + name = "greeting", + srcs = ["greeting.rs"], + edition = "2021", +) + +rust_prost_transform( + name = "transform", + srcs = ["b_src.rs"], + deps = [":greeting"], +) + +proto_library( + name = "b_proto", + srcs = [ + "b.proto", + ], + data = [ + ":transform", + ], + strip_import_prefix = "/private/tests/transform", + deps = [ + "//private/tests/transform/b/c:c_proto", + "@com_google_protobuf//:empty_proto", + ], +) + +rust_prost_library( + name = "b_rs_proto", + proto = ":b_proto", +) + +rust_test( + name = "b_test", + srcs = ["b_test.rs"], + edition = "2021", + deps = [ + ":b_rs_proto", + ":greeting", + ], +) diff --git a/extensions/prost/private/tests/transform/b/b.proto b/extensions/prost/private/tests/transform/b/b.proto new file mode 100644 index 0000000000..f28db6a735 --- /dev/null +++ b/extensions/prost/private/tests/transform/b/b.proto @@ -0,0 +1,14 @@ +syntax = "proto3"; + +import "google/protobuf/empty.proto"; +import "b/c/c.proto"; + +package a.b; + +message B { + string name = 1; + + google.protobuf.Empty empty = 2; + + a.b.c.C c = 3; +} diff --git a/extensions/prost/private/tests/transform/b/b_src.rs b/extensions/prost/private/tests/transform/b/b_src.rs new file mode 100644 index 0000000000..63e11c4bd9 --- /dev/null +++ b/extensions/prost/private/tests/transform/b/b_src.rs @@ -0,0 +1,9 @@ +// Additional source code for `b.proto`. + +pub use greeting::Greeting; + +impl Greeting for crate::a::b::B { + fn greet(&self, name: &str) -> String { + format!("Hallo, {}, my name is B!", name) + } +} diff --git a/extensions/prost/private/tests/transform/b/b_test.rs b/extensions/prost/private/tests/transform/b/b_test.rs new file mode 100644 index 0000000000..0d3483bd79 --- /dev/null +++ b/extensions/prost/private/tests/transform/b/b_test.rs @@ -0,0 +1,20 @@ +//! Tests transitive dependencies. + +use b_proto::a::b::B; +use b_proto::c_proto::a::b::c::C; + +use greeting::Greeting; + +#[test] +fn test_b() { + let b = B { + name: "b".to_string(), + c: Some(C { + name: "c".to_string(), + ..Default::default() + }), + ..Default::default() + }; + + assert_eq!("Hallo, Bazel, my name is B!", b.greet("Bazel")); +} diff --git a/extensions/prost/private/tests/transform/b/c/BUILD.bazel b/extensions/prost/private/tests/transform/b/c/BUILD.bazel new file mode 100644 index 0000000000..1cc73826d5 --- /dev/null +++ b/extensions/prost/private/tests/transform/b/c/BUILD.bazel @@ -0,0 +1,31 @@ +load("@rules_proto//proto:defs.bzl", "proto_library") +load("@rules_rust//rust:defs.bzl", "rust_test") +load("//:defs.bzl", "rust_prost_library") + +package(default_visibility = ["//private/tests:__subpackages__"]) + +proto_library( + name = "c_proto", + srcs = [ + "c.proto", + ], + strip_import_prefix = "/private/tests/transform", + deps = [ + "@com_google_protobuf//:any_proto", + "@com_google_protobuf//:duration_proto", + ], +) + +rust_prost_library( + name = "c_rs_proto", + proto = ":c_proto", +) + +rust_test( + name = "c_test", + srcs = ["c_test.rs"], + edition = "2021", + deps = [ + ":c_rs_proto", + ], +) diff --git a/extensions/prost/private/tests/transform/b/c/c.proto b/extensions/prost/private/tests/transform/b/c/c.proto new file mode 100644 index 0000000000..c8f06d0e08 --- /dev/null +++ b/extensions/prost/private/tests/transform/b/c/c.proto @@ -0,0 +1,14 @@ +syntax = "proto3"; + +import "google/protobuf/any.proto"; +import "google/protobuf/duration.proto"; + +package a.b.c; + +message C { + string name = 1; + + google.protobuf.Any any = 3; + + google.protobuf.Duration duration = 4; +} diff --git a/extensions/prost/private/tests/transform/b/c/c_test.rs b/extensions/prost/private/tests/transform/b/c/c_test.rs new file mode 100644 index 0000000000..549c440772 --- /dev/null +++ b/extensions/prost/private/tests/transform/b/c/c_test.rs @@ -0,0 +1,20 @@ +//! Tests transitive dependencies. + +use c_proto::a::b::c::C; +use c_proto::any_proto::google::protobuf::Any; +use c_proto::duration_proto::google::protobuf::Duration; + +#[test] +fn test_c() { + let c = C { + name: "c".to_string(), + any: Some(Any::default()), + duration: Some(Duration { + seconds: 1, + nanos: 0, + }), + }; + + assert_eq!(c.name, "c"); + assert_eq!(c.duration.unwrap().seconds, 1); +} diff --git a/extensions/prost/private/tests/transform/b/greeting.rs b/extensions/prost/private/tests/transform/b/greeting.rs new file mode 100644 index 0000000000..5ebfe16233 --- /dev/null +++ b/extensions/prost/private/tests/transform/b/greeting.rs @@ -0,0 +1,5 @@ +//! Implement a trait which is used to generate greetings. + +pub trait Greeting { + fn greet(&self, name: &str) -> String; +}