Skip to content

Commit

Permalink
feat(extensions): include Substrait core extensions (#187)
Browse files Browse the repository at this point in the history
Include the core extensions from `Substrait`.
The majority of the code originates from the unmerged PR #89.
---------

Co-authored-by: Matthijs Brobbel <[email protected]>
  • Loading branch information
shanretoo and mbrobbel authored May 21, 2024
1 parent d1c7318 commit b9fba0f
Show file tree
Hide file tree
Showing 5 changed files with 123 additions and 36 deletions.
7 changes: 3 additions & 4 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,15 @@ include = [

[features]
default = []
# The module generated for `extensions` imports `url::Url`, so the optional
# `url` dependency has to be enabled here as well; otherwise building with
# only `--features extensions` cannot resolve the `url` crate.
extensions = ["dep:once_cell", "dep:serde_yaml", "dep:url"]
parse = ["dep:hex", "dep:thiserror", "dep:url", "semver"]
protoc = ["dep:protobuf-src"]
semver = ["dep:semver"]
serde = ["dep:pbjson", "dep:pbjson-build", "dep:pbjson-types"]

[dependencies]
hex = { version = "0.4.3", optional = true }
once_cell = { version = "1.19.0", optional = true }
pbjson = { version = "0.6.0", optional = true }
pbjson-types = { version = "0.6.0", optional = true }
prost = "0.12.3"
Expand All @@ -41,6 +43,7 @@ url = { version = "2.5.0", optional = true }
semver = { version = "1.0.22", optional = true }
serde = { version = "1.0.197", features = ["derive"] }
serde_json = "1.0.114"
serde_yaml = { version = "0.9.32", optional = true }
thiserror = { version = "1.0.57", optional = true }

[build-dependencies]
Expand All @@ -56,10 +59,6 @@ syn = "2.0.11"
typify = "0.1.0"
walkdir = "2.5.0"

[dev-dependencies]
serde_yaml = "0.9.32"
walkdir = "2.5.0"

[package.metadata.docs.rs]
all-features = true
rustdoc-args = ["--cfg", "docsrs"]
100 changes: 96 additions & 4 deletions build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,19 @@ use std::{
use walkdir::{DirEntry, WalkDir};

const SUBMODULE_ROOT: &str = "substrait";
#[cfg(feature = "extensions")]
const EXTENSIONS_ROOT: &str = "substrait/extensions";
const PROTO_ROOT: &str = "substrait/proto";
const TEXT_ROOT: &str = "substrait/text";
const GEN_ROOT: &str = "gen";

/// Add Substrait version information to the build
fn substrait_version() -> Result<(), Box<dyn Error>> {
fn substrait_version() -> Result<semver::Version, Box<dyn Error>> {
let gen_dir = Path::new(GEN_ROOT);
fs::create_dir_all(gen_dir)?;

let version_in_file = gen_dir.join("version.in");
let substrait_version_file = gen_dir.join("version");

// Rerun if the Substrait submodule changed (to allow setting `dirty`)
println!(
Expand Down Expand Up @@ -98,14 +101,22 @@ pub const SUBSTRAIT_GIT_DIRTY: bool = {git_dirty};
"#
),
)?;

// Also write the version to a file
fs::write(substrait_version_file, version.to_string())?;

Ok(version)
} else {
// If we don't have a version file yet we fail the build.
if !version_in_file.exists() {
panic!("Couldn't find the substrait submodule. Please clone the submodule: `git submodule update --init`.")
}
}

Ok(())
// File exists we should get the version and return it.
Ok(semver::Version::parse(&fs::read_to_string(
substrait_version_file,
)?)?)
}
}

/// `text` type generation
Expand Down Expand Up @@ -168,6 +179,84 @@ pub mod {title} {{
Ok(())
}

#[cfg(feature = "extensions")]
/// Add Substrait core extensions
///
/// Walks `EXTENSIONS_ROOT` for `*.yaml` files and writes `extensions.in` into
/// `out_dir`. The generated file contains:
///
/// * one `const <FILE_STEM_UPPERCASE>: &str = include_str!(..)` per extension
///   file, embedding its YAML source, and
/// * a lazily initialized `EXTENSIONS` map from the extension's GitHub URL
///   (pinned to `version`) to the parsed `SimpleExtensions`.
///
/// The output is deterministic: files are processed in sorted path order and
/// the URL map is emitted in sorted key order.
///
/// # Errors
///
/// Returns an error if `CARGO_MANIFEST_DIR` is not set or not valid Unicode,
/// or if the generated file cannot be written.
fn extensions(version: semver::Version, out_dir: &Path) -> Result<(), Box<dyn Error>> {
    use std::collections::BTreeMap;
    use std::fmt::Write as _;

    let substrait_extensions_file = out_dir.join("extensions.in");

    // The generated `include_str!` calls use absolute paths rooted at the
    // crate manifest directory. Resolve it once, and propagate a missing
    // variable as an error instead of panicking.
    let manifest_dir = PathBuf::from(env::var("CARGO_MANIFEST_DIR")?);

    // Collect and sort the extension files so the generated source does not
    // depend on directory iteration order.
    let mut extension_files: Vec<PathBuf> = WalkDir::new(EXTENSIONS_ROOT)
        .into_iter()
        .filter_map(Result::ok)
        .filter(|entry| entry.file_type().is_file())
        .filter(|entry| {
            entry
                .path()
                .extension()
                .filter(|&extension| extension == "yaml")
                .is_some()
        })
        .map(DirEntry::into_path)
        .collect();
    extension_files.sort();

    let mut output = String::from(
        r#"// SPDX-License-Identifier: Apache-2.0
// Note that this file is auto-generated and auto-synced using `build.rs`. It is
// included in `extensions.rs`.
"#,
    );

    // URL -> name of the generated constant. A `BTreeMap` keeps the emitted
    // `map.insert` calls in a stable order.
    let mut map = BTreeMap::<String, String>::new();
    for extension in &extension_files {
        println!("cargo:rerun-if-changed={}", extension.display());

        let name = extension.file_stem().unwrap_or_default().to_string_lossy();
        let file_name = extension.file_name().unwrap_or_default().to_string_lossy();
        let url = format!(
            "https://github.com/substrait-io/substrait/raw/v{version}/extensions/{file_name}"
        );
        let var_name = name.to_uppercase();

        // `{:?}` renders the path as a quoted, escaped string literal, so
        // backslashes (e.g. Windows path separators) and quotes cannot break
        // the generated source.
        let include_path = manifest_dir.join(extension);
        let include_path = include_path.to_string_lossy();
        write!(
            output,
            r#"
/// Included source of [`{name}`]({url}).
const {var_name}: &str = include_str!({include_path:?});
"#
        )?;
        map.insert(url, var_name);
    }

    // Add static lookup map.
    output.push_str(
        r#"
use std::collections::HashMap;
use std::str::FromStr;
use once_cell::sync::Lazy;
use crate::text::simple_extensions::SimpleExtensions;
use url::Url;
/// Map with Substrait core extensions. Maps URIs to included extensions.
pub static EXTENSIONS: Lazy<HashMap<Url, SimpleExtensions>> = Lazy::new(|| {
    let mut map = HashMap::new();"#,
    );

    for (url, var_name) in &map {
        write!(
            output,
            r#"
    map.insert(Url::from_str("{url}").expect("a valid url"), serde_yaml::from_str({var_name}).expect("a valid core extension"));"#
        )?;
    }

    output.push_str(
        r#"
    map
});"#,
    );

    // Write the file.
    fs::write(substrait_extensions_file, output)?;

    Ok(())
}

#[cfg(feature = "serde")]
/// Serialize and deserialize implementations for proto types using `pbjson`
fn serde(protos: &[impl AsRef<Path>], out_dir: PathBuf) -> Result<(), Box<dyn Error>> {
Expand All @@ -191,7 +280,7 @@ fn main() -> Result<(), Box<dyn Error>> {
// for use in docker build where file changes can be wonky
println!("cargo:rerun-if-env-changed=FORCE_REBUILD");

substrait_version()?;
let version = substrait_version()?;

#[cfg(feature = "protoc")]
std::env::set_var("PROTOC", protobuf_src::protoc());
Expand All @@ -200,6 +289,9 @@ fn main() -> Result<(), Box<dyn Error>> {

text(out_dir.as_path())?;

#[cfg(feature = "extensions")]
extensions(version, out_dir.as_path())?;

let protos = WalkDir::new(PROTO_ROOT)
.into_iter()
.filter_map(Result::ok)
Expand Down
22 changes: 22 additions & 0 deletions src/extensions.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
// SPDX-License-Identifier: Apache-2.0

//! Substrait core extensions
//!
//! The contents of this module are auto-generated by `build.rs`, which writes
//! `extensions.in` into the build's `OUT_DIR`; the file is not part of the
//! source tree. It provides one string constant per core extension YAML file
//! and the [`EXTENSIONS`] map from extension URL to the parsed extension.
include!(concat!(env!("OUT_DIR"), "/extensions.in"));

#[cfg(test)]
mod tests {
    use super::*;

    use once_cell::sync::Lazy;

    /// Every embedded core extension must deserialize, and at least one must
    /// have been embedded.
    #[test]
    fn core_extensions() {
        // Forcing the lazy static parses every embedded YAML document; a
        // malformed one panics via the `expect` calls in the generated code.
        let extensions = Lazy::force(&EXTENSIONS);

        // Guard against a vacuous pass: an empty map means the build script
        // found no extension files to embed.
        assert!(
            !extensions.is_empty(),
            "no core extensions were embedded by the build script"
        );
    }
}
2 changes: 2 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,8 @@
#![cfg_attr(docsrs, feature(doc_auto_cfg))]
#![deny(missing_docs)]

#[cfg(feature = "extensions")]
pub mod extensions;
#[allow(missing_docs)]
pub mod proto;
#[allow(missing_docs)]
Expand Down
28 changes: 0 additions & 28 deletions src/text.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,31 +11,3 @@
//! Generated types for text-based definitions.
include!(concat!(env!("OUT_DIR"), "/substrait_text.rs"));

#[cfg(test)]
mod tests {
use crate::text::simple_extensions::SimpleExtensions;
use std::{fs, path::PathBuf};
use walkdir::{DirEntry, WalkDir};

#[test]
fn deserialize_core_extensions() {
WalkDir::new(PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("substrait/extensions"))
.into_iter()
.filter_map(Result::ok)
.filter(|entry| entry.file_type().is_file())
.filter(|entry| {
entry
.path()
.extension()
.filter(|extension| extension == &"yaml")
.is_some()
})
.map(DirEntry::into_path)
.for_each(|path| {
let file = fs::read_to_string(path).unwrap();
let simple_extension = serde_yaml::from_str::<SimpleExtensions>(&file);
assert!(simple_extension.is_ok());
});
}
}

0 comments on commit b9fba0f

Please sign in to comment.