Skip to content

Commit

Permalink
Added rust_prost_transform rule for modifying granular proto_library
Browse files Browse the repository at this point in the history
  • Loading branch information
UebelAndre committed Dec 11, 2024
1 parent c77ab04 commit d3a6108
Show file tree
Hide file tree
Showing 16 changed files with 474 additions and 9 deletions.
5 changes: 5 additions & 0 deletions extensions/prost/defs.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,11 @@ load(
_rust_prost_library = "rust_prost_library",
_rust_prost_toolchain = "rust_prost_toolchain",
)
load(
"//private:prost_transform.bzl",
_rust_prost_transform = "rust_prost_transform",
)

rust_prost_library = _rust_prost_library
rust_prost_toolchain = _rust_prost_toolchain
rust_prost_transform = _rust_prost_transform
54 changes: 46 additions & 8 deletions extensions/prost/private/prost.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ load("@rules_rust//rust/private:rustc.bzl", "rustc_compile_action")
# buildifier: disable=bzl-visibility
load("@rules_rust//rust/private:utils.bzl", "can_build_metadata")
load("//:providers.bzl", "ProstProtoInfo")
load(":prost_transform.bzl", "ProstTransformInfo")

RUST_EDITION = "2021"

Expand All @@ -39,7 +40,15 @@ def _create_proto_lang_toolchain(ctx, prost_toolchain):

return proto_lang_toolchain

def _compile_proto(ctx, crate_name, proto_info, deps, prost_toolchain, rustfmt_toolchain = None):
def _compile_proto(
*,
ctx,
crate_name,
proto_info,
transform_infos,
deps,
prost_toolchain,
rustfmt_toolchain = None):
deps_info_file = ctx.actions.declare_file(ctx.label.name + ".prost_deps_info")
dep_package_infos = [dep[ProstProtoInfo].package_info for dep in deps]
ctx.actions.write(
Expand All @@ -53,6 +62,15 @@ def _compile_proto(ctx, crate_name, proto_info, deps, prost_toolchain, rustfmt_t
proto_compiler = prost_toolchain.proto_compiler
tools = depset([proto_compiler.executable])

tonic_opts = []
prost_opts = []
additional_srcs = []
for transform_info in transform_infos:
tonic_opts.extend(transform_info.tonic_opts)
prost_opts.extend(transform_info.prost_opts)
additional_srcs.append(transform_info.srcs)

all_additional_srcs = depset(transitive = additional_srcs)
direct_crate_names = [dep[ProstProtoInfo].dep_variant_info.crate_info.name for dep in deps]
additional_args = ctx.actions.args()

Expand All @@ -65,22 +83,27 @@ def _compile_proto(ctx, crate_name, proto_info, deps, prost_toolchain, rustfmt_t
additional_args.add("--direct_dep_crate_names={}".format(",".join(direct_crate_names)))
additional_args.add("--prost_opt=compile_well_known_types")
additional_args.add("--descriptor_set={}".format(proto_info.direct_descriptor_set.path))
additional_args.add_all(prost_toolchain.prost_opts, format_each = "--prost_opt=%s")
additional_args.add("--additional_srcs={}".format(",".join([f.path for f in all_additional_srcs.to_list()])))
additional_args.add_all(prost_toolchain.prost_opts + prost_opts, format_each = "--prost_opt=%s")

if prost_toolchain.tonic_plugin:
tonic_plugin = prost_toolchain.tonic_plugin[DefaultInfo].files_to_run
additional_args.add(prost_toolchain.tonic_plugin_flag % tonic_plugin.executable.path)
additional_args.add("--tonic_opt=no_include")
additional_args.add("--tonic_opt=compile_well_known_types")
additional_args.add("--is_tonic")
additional_args.add_all(prost_toolchain.tonic_opts, format_each = "--tonic_opt=%s")

additional_args.add_all(prost_toolchain.tonic_opts + tonic_opts, format_each = "--tonic_opt=%s")
tools = depset([tonic_plugin.executable], transitive = [tools])

if rustfmt_toolchain:
additional_args.add("--rustfmt={}".format(rustfmt_toolchain.rustfmt.path))
tools = depset(transitive = [tools, rustfmt_toolchain.all_files])

additional_inputs = depset([deps_info_file, proto_info.direct_descriptor_set] + [dep[ProstProtoInfo].package_info for dep in deps])
additional_inputs = depset(
[deps_info_file, proto_info.direct_descriptor_set] + [dep[ProstProtoInfo].package_info for dep in deps],
transitive = [all_additional_srcs],
)

proto_common.compile(
actions = ctx.actions,
Expand Down Expand Up @@ -116,7 +139,14 @@ def _get_cc_info(providers):
return provider
fail("Couldn't find a CcInfo in the list of providers")

def _compile_rust(ctx, attr, crate_name, src, deps, edition):
def _compile_rust(
*,
ctx,
attr,
crate_name,
src,
deps,
edition):
"""Compiles a Rust source file.
Args:
Expand Down Expand Up @@ -233,7 +263,14 @@ def _rust_prost_aspect_impl(target, ctx):
if RustAnalyzerInfo in proto_dep:
rust_analyzer_deps.append(proto_dep[RustAnalyzerInfo])

deps = runtime_deps + direct_deps
transform_infos = []
for data_target in getattr(ctx.rule.attr, "data", []):
if ProstTransformInfo in data_target:
transform_infos.append(data_target[ProstTransformInfo])

rust_deps = runtime_deps + direct_deps
for transform_info in transform_infos:
rust_deps.extend(transform_info.deps)

crate_name = ctx.label.name.replace("-", "_").replace("/", "_")

Expand All @@ -243,6 +280,7 @@ def _rust_prost_aspect_impl(target, ctx):
ctx = ctx,
crate_name = crate_name,
proto_info = proto_info,
transform_infos = transform_infos,
deps = proto_deps,
prost_toolchain = prost_toolchain,
rustfmt_toolchain = rustfmt_toolchain,
Expand All @@ -253,7 +291,7 @@ def _rust_prost_aspect_impl(target, ctx):
attr = ctx.rule.attr,
crate_name = crate_name,
src = lib_rs,
deps = deps,
deps = rust_deps,
edition = RUST_EDITION,
)

Expand Down Expand Up @@ -495,7 +533,7 @@ def _current_prost_runtime_impl(ctx):
)]

current_prost_runtime = rule(
doc = "A rule for accessing the current Prost toolchain components needed by the process wrapper",
doc = "A rule for accessing the current Prost toolchain components needed by the process wrapper.",
provides = [rust_common.crate_group_info],
implementation = _current_prost_runtime_impl,
toolchains = [TOOLCHAIN_TYPE],
Expand Down
88 changes: 88 additions & 0 deletions extensions/prost/private/prost_transform.bzl
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
"""Prost rules."""

load("@rules_rust//rust:defs.bzl", "rust_common")

ProstTransformInfo = provider(
doc = "Info about transformations to apply to Prost generated source code.",
fields = {
"deps": "List[DepVariantInfo]: Additional dependencies to compile into the Prost target.",
"prost_opts": "List[str]: Additional prost flags.",
"srcs": "Depset[File]: Additional source files to include in generated Prost source code.",
"tonic_opts": "List[str]: Additional tonic flags.",
},
)

def _rust_prost_transform_impl(ctx):
deps = []
for target in ctx.attr.deps:
deps.append(rust_common.dep_variant_info(
crate_info = target[rust_common.crate_info] if rust_common.crate_info in target else None,
dep_info = target[rust_common.dep_info] if rust_common.dep_info in target else None,
cc_info = target[CcInfo] if CcInfo in target else None,
build_info = None,
))

# DefaultInfo is intentionally not returned here to avoid impacting other
# consumers of the `proto_library` target this rule is expected to be passed
# to.
return [ProstTransformInfo(
deps = deps,
prost_opts = ctx.attr.prost_opts,
srcs = depset(ctx.files.srcs),
tonic_opts = ctx.attr.tonic_opts,
)]

rust_prost_transform = rule(
doc = """\
A rule for transforming the outputs of `ProstGenProto` actions.
This rule is used by adding it to the `data` attribute of `proto_library` targets. E.g.
```python
load("@rules_proto//proto:defs.bzl", "proto_library")
load("@rules_rust_prost//:defs.bzl", "rust_prost_library", "rust_prost_transform")
rust_prost_transform(
name = "a_transform",
srcs = [
"a_src.rs",
],
)
proto_library(
name = "a_proto",
srcs = [
"a.proto",
],
data = [
":transform",
],
)
rust_prost_library(
name = "a_rs_proto",
proto = ":a_proto",
)
```
The `rust_prost_library` will spawn an action on the `a_proto` target which consumes the
`a_transform` rule to provide a means of granularly modifying a proto library for `ProstGenProto`
actions with minimal impact to other consumers.
""",
implementation = _rust_prost_transform_impl,
attrs = {
"deps": attr.label_list(
doc = "Additional dependencies to add to the compiled crate.",
providers = [[rust_common.crate_info], [rust_common.crate_group_info]],
),
"prost_opts": attr.string_list(
doc = "Additional options to add to Prost.",
),
"srcs": attr.label_list(
doc = "Additional source files to include in generated Prost source code.",
allow_files = True,
),
"tonic_opts": attr.string_list(
doc = "Additional options to add to Tonic.",
),
},
)
44 changes: 43 additions & 1 deletion extensions/prost/private/protoc_wrapper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,9 @@ impl Module {
}
}

const ADDITIONAL_CONTENT_HEADER: &str =
"// A D D I T I O N A L S O U R C E S ========================================";

/// Generate a lib.rs file with all prost/tonic outputs embeeded in modules which
/// mirror the proto packages. For the example proto file we would expect to see
/// the Rust output that follows it.
Expand Down Expand Up @@ -152,6 +155,7 @@ fn generate_lib_rs(
prost_outputs: &BTreeSet<PathBuf>,
is_tonic: bool,
direct_dep_crate_names: Vec<String>,
additional_content: String,
) -> String {
let mut contents = vec!["// @generated".to_string(), "".to_string()];
for crate_name in direct_dep_crate_names {
Expand Down Expand Up @@ -193,6 +197,14 @@ fn generate_lib_rs(

let mut content = String::new();
write_module(&mut content, &module_info, 0);

if !additional_content.is_empty() {
return format!(
"{}\n\n{}\n\n{}",
content, ADDITIONAL_CONTENT_HEADER, additional_content
);
}

content
}

Expand Down Expand Up @@ -421,6 +433,9 @@ struct Args {
/// The proto files to compile.
proto_files: Vec<PathBuf>,

/// Additional source files to append to the generated rust source.
additional_srcs: Vec<PathBuf>,

/// The include directories.
includes: Vec<String>,

Expand Down Expand Up @@ -454,6 +469,7 @@ impl Args {
let mut crate_name: Option<String> = None;
let mut package_info_file: Option<PathBuf> = None;
let mut proto_files: Vec<PathBuf> = Vec::new();
let mut additional_srcs: Vec<PathBuf> = Vec::new();
let mut includes = Vec::new();
let mut descriptor_set = None;
let mut out_librs: Option<PathBuf> = None;
Expand Down Expand Up @@ -521,6 +537,12 @@ impl Args {
}
}
}
("--additional_srcs", value) => {
if !value.is_empty() {
additional_srcs
.extend(value.split(',').map(PathBuf::from).collect::<Vec<_>>());
}
}
("--direct_dep_crate_names", value) => {
if value.trim().is_empty() {
return;
Expand Down Expand Up @@ -614,6 +636,7 @@ impl Args {
crate_name: crate_name.unwrap(),
package_info_file: package_info_file.unwrap(),
proto_files,
additional_srcs,
includes,
descriptor_set: descriptor_set.unwrap(),
out_librs: out_librs.unwrap(),
Expand Down Expand Up @@ -717,6 +740,7 @@ fn main() {
label,
package_info_file,
proto_files,
additional_srcs,
includes,
descriptor_set,
out_librs,
Expand All @@ -733,6 +757,19 @@ fn main() {
let package_name = get_package_name(&descriptor_set).unwrap_or_default();
let expect_rs = expect_fs_file_to_be_generated(&descriptor_set);
let has_services = has_services(&descriptor_set);
let additional_content = additional_srcs
.into_iter()
.map(|f| {
fs::read_to_string(&f).unwrap_or_else(|e| {
panic!(
"Failed to read additional source file: `{}`\n{:?}",
f.display(),
e
)
})
})
.collect::<Vec<_>>()
.join("\n");

if has_services && !is_tonic {
eprintln!("Warning: Service definitions will not be generated because the prost toolchain did not define a tonic plugin.");
Expand Down Expand Up @@ -875,7 +912,12 @@ fn main() {
// Write outputs
fs::write(
&out_librs,
generate_lib_rs(&rust_files, is_tonic, direct_dep_crate_names),
generate_lib_rs(
&rust_files,
is_tonic,
direct_dep_crate_names,
additional_content,
),
)
.expect("Failed to write file.");
fs::write(
Expand Down
43 changes: 43 additions & 0 deletions extensions/prost/private/tests/transform/BUILD.bazel
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
load("@rules_proto//proto:defs.bzl", "proto_library")
load("@rules_rust//rust:defs.bzl", "rust_test")
load("//:defs.bzl", "rust_prost_library", "rust_prost_transform")

package(default_visibility = ["//private/tests:__subpackages__"])

rust_prost_transform(
name = "transform",
srcs = ["a_src.rs"],
)

proto_library(
name = "a_proto",
srcs = [
"a.proto",
],
data = [
":transform",
],
strip_import_prefix = "/private/tests/transform",
deps = [
"//private/tests/transform/b:b_proto",
"//private/tests/types:types_proto",
"@com_google_protobuf//:duration_proto",
"@com_google_protobuf//:timestamp_proto",
],
)

rust_prost_library(
name = "a_rs_proto",
proto = ":a_proto",
)

rust_test(
name = "a_test",
srcs = ["a_test.rs"],
edition = "2021",
deps = [
":a_rs_proto",
# Add b_proto as a dependency directly to ensure compatibility with `a.proto`'s imports.
"//private/tests/transform/b:b_rs_proto",
],
)
Loading

0 comments on commit d3a6108

Please sign in to comment.