Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Limit imported interfaces for composed components #83

Merged
merged 2 commits into from
Aug 28, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ wasm-compose = "0.215.0"
wasm-metadata = "0.212.0"
wasm-opt = { version = "0.116.1", optional = true }
wit-component = "0.212.0"
wit-parser = "0.212.0"

[build-dependencies]
anyhow = "1"
Expand Down
3 changes: 2 additions & 1 deletion src/bin/wasi-virt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ fn main() -> Result<()> {
let mut virt_opts = WasiVirt::default();

virt_opts.debug = args.debug.unwrap_or_default();
virt_opts.compose = args.compose;

// By default, we virtualize all subsystems
// This ensures full encapsulation in the default (no argument) case
Expand Down Expand Up @@ -193,7 +194,7 @@ fn main() -> Result<()> {

let out_path = PathBuf::from(args.out);

let out_bytes = if let Some(compose_path) = args.compose {
let out_bytes = if let Some(compose_path) = virt_opts.compose {
let compose_path = PathBuf::from(compose_path);
let dir = env::temp_dir();
let tmp_virt = dir.join(format!("virt{}.wasm", timestamp()));
Expand Down
71 changes: 69 additions & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use anyhow::{Context, Result};
use anyhow::{bail, Context, Result};
use serde::Deserialize;
use std::env;
use std::{env, fs};
use virt_deny::{
deny_clocks_virt, deny_exit_virt, deny_http_virt, deny_random_virt, deny_sockets_virt,
};
Expand All @@ -9,6 +9,7 @@ use virt_io::{create_io_virt, VirtStdio};
use walrus_ops::strip_virt;
use wasm_metadata::Producers;
use wit_component::{metadata, ComponentEncoder, DecodedWasm, StringEncoding};
use wit_parser::WorldItem;

mod data;
mod stub_preview1;
Expand Down Expand Up @@ -58,6 +59,9 @@ pub struct WasiVirt {
pub random: Option<bool>,
/// Disable wasm-opt run if desired
pub wasm_opt: Option<bool>,
/// Path to compose component
pub compose: Option<String>,
guybedford marked this conversation as resolved.
Show resolved Hide resolved
compose_imports: Option<Vec<String>>,
}

pub struct VirtResult {
Expand Down Expand Up @@ -138,6 +142,38 @@ impl WasiVirt {
}?;
module.name = Some("wasi_virt".into());

// drop capabilities that are not imported by the composed component
if self.compose.is_some() {
self.collect_compose_imports()?;

if !self.contains_compose_import("wasi:cli/environment") {
self.env = None;
}
if !self.contains_compose_import("wasi:filesystem/") {
self.fs = None;
}
if !(self.contains_compose_import("wasi:cli/std")
|| self.contains_compose_import("wasi:cli/terminal"))
{
self.stdio = None;
}
if !self.contains_compose_import("wasi:cli/exit") {
self.exit = None;
}
if !self.contains_compose_import("wasi:clocks/") {
self.clocks = None;
}
if !self.contains_compose_import("wasi:http/") {
self.http = None;
}
if !self.contains_compose_import("wasi:sockets/") {
self.sockets = None;
}
if !self.contains_compose_import("wasi:random/") {
self.random = None;
}
}

// only env virtualization is independent of io
if let Some(env) = &self.env {
create_env_virt(&mut module, env)?;
Expand Down Expand Up @@ -306,6 +342,37 @@ impl WasiVirt {
virtual_files,
})
}

// parse the compose component to collect its imported interfaces
fn collect_compose_imports(&mut self) -> Result<()> {
let module_bytes = fs::read(self.compose.as_ref().unwrap()).map_err(anyhow::Error::new)?;
let (resolve, world_id) = match wit_component::decode(&module_bytes)? {
DecodedWasm::WitPackages(..) => {
bail!("expected a component, found a WIT package")
}
DecodedWasm::Component(resolve, world_id) => (resolve, world_id),
};

let mut import_ids: Vec<String> = vec![];
for (_, import) in &resolve.worlds[world_id].imports {
if let WorldItem::Interface { id, .. } = import {
if let Some(id) = resolve.id_of(*id) {
import_ids.push(id);
}
}
}

self.compose_imports = Some(import_ids);

Ok(())
}

fn contains_compose_import(&self, prefix: &str) -> bool {
match &self.compose_imports {
Some(imports) => imports.iter().any(|i| i.starts_with(prefix)),
None => false,
}
}
}

fn apply_wasm_opt(bytes: Vec<u8>, debug: bool) -> Result<Vec<u8>> {
Expand Down
18 changes: 18 additions & 0 deletions tests/cases/encapsulate-component.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
component = "do-everything"
compose = true

[virt-opts]
clocks = true
# http will be filtered out because the "do-everything" component doesn't import it
http = true
stdio.stdin = "ignore"
stdio.stdout = "ignore"
stdio.stderr = "ignore"

[expect.imports]
required = [
"wasi:clocks/wall-clock",
]
disallowed = [
"wasi:http/incoming-handler",
]
75 changes: 73 additions & 2 deletions tests/virt.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use anyhow::{anyhow, Context, Result};
use anyhow::{anyhow, bail, Context, Result};
use heck::ToSnakeCase;
use serde::Deserialize;
use std::collections::BTreeMap;
Expand All @@ -13,7 +13,8 @@ use wasmtime::{
Config, Engine, Store, WasmBacktraceDetails,
};
use wasmtime_wasi::{DirPerms, FilePerms, WasiCtx, WasiCtxBuilder, WasiView};
use wit_component::ComponentEncoder;
use wit_component::{ComponentEncoder, DecodedWasm};
use wit_parser::WorldItem;

wasmtime::component::bindgen!({
world: "virt-test",
Expand Down Expand Up @@ -49,12 +50,21 @@ struct TestExpectation {
file_read: Option<String>,
encapsulation: Option<bool>,
stdout: Option<String>,
imports: Option<TestExpectationImports>,
}

#[derive(Deserialize, Debug)]
#[serde(rename_all = "kebab-case", deny_unknown_fields)]
struct TestExpectationImports {
required: Option<Vec<String>>,
disallowed: Option<Vec<String>>,
}

#[derive(Deserialize, Debug)]
#[serde(rename_all = "kebab-case", deny_unknown_fields)]
struct TestCase {
component: String,
compose: Option<bool>,
host_env: Option<BTreeMap<String, String>>,
host_fs_path: Option<String>,
virt_opts: Option<WasiVirt>,
Expand Down Expand Up @@ -136,6 +146,16 @@ async fn virt_test() -> Result<()> {
virt_opts.wasm_opt = Some(false);
}
}
if let Some(compose) = test.compose {
if compose {
let compose_path = generated_component_path
.clone()
.into_os_string()
.into_string()
.unwrap();
virt_opts.compose = Some(compose_path);
}
}

let virt_component = virt_opts.finish().with_context(|| {
format!(
Expand Down Expand Up @@ -267,6 +287,37 @@ async fn virt_test() -> Result<()> {
instance.call_test_stdio(&mut store).await?;
}

if let Some(expect_imports) = &test.expect.imports {
let component_imports = collect_component_imports(component_bytes)?;

if let Some(required_imports) = &expect_imports.required {
for required_import in required_imports {
if !component_imports
.iter()
.any(|i| i.starts_with(required_import))
{
return Err(anyhow!(
"Required import missing {required_import} {:?}",
test_case_path
));
}
}
}
if let Some(disallowed_imports) = &expect_imports.disallowed {
for disallowed_import in disallowed_imports {
if component_imports
.iter()
.any(|i| i.starts_with(disallowed_import))
{
return Err(anyhow!(
"Disallowed import {disallowed_import} {:?}",
test_case_path
));
}
}
}
}

println!("\x1b[1;32m√\x1b[0m {:?}", test_case_path);
}
Ok(())
Expand Down Expand Up @@ -316,3 +367,23 @@ fn has_component_import(bytes: &[u8]) -> Result<Option<String>> {
}
}
}

fn collect_component_imports(component_bytes: Vec<u8>) -> Result<Vec<String>> {
let (resolve, world_id) = match wit_component::decode(&component_bytes)? {
DecodedWasm::WitPackages(..) => {
bail!("expected a component, found a WIT package")
}
DecodedWasm::Component(resolve, world_id) => (resolve, world_id),
};

let mut import_ids: Vec<String> = vec![];
for (_, import) in &resolve.worlds[world_id].imports {
if let WorldItem::Interface { id, .. } = import {
if let Some(id) = resolve.id_of(*id) {
import_ids.push(id);
}
}
}

Ok(import_ids)
}