-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
f2c8b89
commit 69906ab
Showing
46 changed files
with
1,238 additions
and
11 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,3 +10,5 @@ build | |
*.egg-info | ||
*.pyc | ||
.idea | ||
target | ||
Cargo.lock |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
[workspace] | ||
members = ["crates/*"] | ||
resolver = "2" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
/target | ||
|
||
# Byte-compiled / optimized / DLL files | ||
__pycache__/ | ||
.pytest_cache/ | ||
*.py[cod] | ||
|
||
# C extensions | ||
*.so | ||
|
||
# Distribution / packaging | ||
.Python | ||
.venv/ | ||
env/ | ||
bin/ | ||
build/ | ||
develop-eggs/ | ||
dist/ | ||
eggs/ | ||
lib/ | ||
lib64/ | ||
parts/ | ||
sdist/ | ||
var/ | ||
include/ | ||
man/ | ||
venv/ | ||
*.egg-info/ | ||
.installed.cfg | ||
*.egg | ||
|
||
# Installer logs | ||
pip-log.txt | ||
pip-delete-this-directory.txt | ||
pip-selfcheck.json | ||
|
||
# Unit test / coverage reports | ||
htmlcov/ | ||
.tox/ | ||
.coverage | ||
.cache | ||
nosetests.xml | ||
coverage.xml | ||
|
||
# Translations | ||
*.mo | ||
|
||
# Mr Developer | ||
.mr.developer.cfg | ||
.project | ||
.pydevproject | ||
|
||
# Rope | ||
.ropeproject | ||
|
||
# Django stuff: | ||
*.log | ||
*.pot | ||
|
||
.DS_Store | ||
|
||
# Sphinx documentation | ||
docs/_build/ | ||
|
||
# PyCharm | ||
.idea/ | ||
|
||
# VSCode | ||
.vscode/ | ||
|
||
# Pyenv | ||
.python-version |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
[package] | ||
name = "bh_agent_client" | ||
version = "0.1.0" | ||
edition = "2021" | ||
|
||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html | ||
[lib] | ||
name = "bh_agent_client" | ||
crate-type = ["cdylib"] | ||
|
||
[dependencies] | ||
pyo3 = "0.19.0" | ||
bh_agent_common = { path = "../bh_agent_common" } | ||
tokio = "1.32.0" | ||
anyhow = "1.0.75" | ||
tarpc = { version = "0.33.0", features = ["full"] } |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,245 @@ | ||
use crate::client::build_client; | ||
use anyhow::Result; | ||
use bh_agent_common::{ | ||
AgentError, BhAgentServiceClient, EnvironmentId, FileId, FileOpenMode, FileOpenType, | ||
ProcessChannel, ProcessId, Redirection, RemotePOpenConfig, | ||
}; | ||
use pyo3::exceptions::PyRuntimeError; | ||
use pyo3::prelude::*; | ||
use pyo3::{pyclass, pymethods, pymodule, PyResult, Python}; | ||
use std::future::Future; | ||
use std::net::{IpAddr, SocketAddr}; | ||
use std::str::FromStr; | ||
use tarpc::client::RpcError; | ||
use tarpc::context; | ||
use tokio::runtime; | ||
|
||
#[pyclass] | ||
struct BhAgentClient { | ||
tokio_runtime: runtime::Runtime, | ||
client: BhAgentServiceClient, | ||
} | ||
|
||
fn run_in_runtime<F, R>(client: &BhAgentClient, fut: F) -> PyResult<R> | ||
where | ||
F: Future<Output = Result<Result<R, AgentError>, RpcError>> + Sized, | ||
{ | ||
client | ||
.tokio_runtime | ||
.block_on(fut) | ||
.map_err(|e| PyRuntimeError::new_err(e.to_string())) | ||
.map(|r| r.map_err(|e| PyRuntimeError::new_err(e.to_string()))) | ||
.and_then(|r| r) | ||
} | ||
|
||
#[pymethods] | ||
impl BhAgentClient { | ||
#[staticmethod] | ||
fn initialize_client(ip_addr: String, port: u16) -> PyResult<Self> { | ||
let ip_addr = IpAddr::from_str(&ip_addr)?; | ||
let socket_addr = SocketAddr::new(ip_addr, port); | ||
|
||
let tokio_runtime = runtime::Builder::new_current_thread() | ||
.enable_all() | ||
.build() | ||
.unwrap(); | ||
match tokio_runtime.block_on(build_client(socket_addr)) { | ||
Ok(client) => Ok(Self { | ||
tokio_runtime, | ||
client, | ||
}), | ||
Err(e) => Err(PyRuntimeError::new_err(format!( | ||
"Failed to initialize client: {}", | ||
e | ||
))), | ||
} | ||
} | ||
|
||
fn get_environments(&self) -> PyResult<Vec<EnvironmentId>> { | ||
self.tokio_runtime | ||
.block_on(self.client.get_environments(context::current())) | ||
.map_err(|e| PyRuntimeError::new_err(e.to_string())) | ||
} | ||
|
||
fn get_tempdir(&self, env_id: EnvironmentId) -> PyResult<String> { | ||
run_in_runtime(self, self.client.get_tempdir(context::current(), env_id)) | ||
} | ||
|
||
fn run_process( | ||
&self, | ||
env_id: EnvironmentId, | ||
argv: Vec<String>, | ||
stdin: bool, | ||
stdout: bool, | ||
stderr: bool, | ||
executable: Option<String>, | ||
env: Option<Vec<(String, String)>>, | ||
cwd: Option<String>, | ||
setuid: Option<u32>, | ||
setgid: Option<u32>, | ||
setpgid: bool, | ||
) -> PyResult<ProcessId> { | ||
let config = RemotePOpenConfig { | ||
argv, | ||
stdin: match stdin { | ||
true => Redirection::Save, | ||
false => Redirection::None, | ||
}, | ||
stdout: match stdout { | ||
true => Redirection::Save, | ||
false => Redirection::None, | ||
}, | ||
stderr: match stderr { | ||
true => Redirection::Save, | ||
false => Redirection::None, | ||
}, | ||
executable, | ||
env, | ||
cwd, | ||
setuid, | ||
setgid, | ||
setpgid, | ||
}; | ||
run_in_runtime( | ||
self, | ||
self.client.run_command(context::current(), env_id, config), | ||
) | ||
} | ||
|
||
fn get_process_channel( | ||
&self, | ||
env_id: EnvironmentId, | ||
proc_id: ProcessId, | ||
channel: i32, // TODO: This is just 0, 1, 2 for now | ||
) -> PyResult<FileId> { | ||
run_in_runtime( | ||
self, | ||
self.client.get_process_channel( | ||
context::current(), | ||
env_id, | ||
proc_id, | ||
match channel { | ||
0 => ProcessChannel::Stdin, | ||
1 => ProcessChannel::Stdout, | ||
2 => ProcessChannel::Stderr, | ||
_ => return Err(PyRuntimeError::new_err("Invalid channel")), | ||
}, | ||
), | ||
) | ||
} | ||
|
||
// File IO | ||
fn file_open( | ||
&self, | ||
env_id: EnvironmentId, | ||
path: String, | ||
mode_and_type: String, | ||
) -> PyResult<FileId> { | ||
// Mode parsing | ||
let mut mode = FileOpenMode::Read; | ||
mode_and_type.chars().for_each(|c| match c { | ||
'r' => mode = FileOpenMode::Read, | ||
'w' => mode = FileOpenMode::Write, | ||
'x' => mode = FileOpenMode::ExclusiveWrite, | ||
'a' => mode = FileOpenMode::Append, | ||
'+' => mode = FileOpenMode::Update, | ||
_ => {} | ||
}); | ||
|
||
// Type parsing | ||
let mut type_ = FileOpenType::Text; | ||
if mode_and_type.contains("b") { | ||
type_ = FileOpenType::Binary; | ||
} | ||
|
||
run_in_runtime( | ||
self, | ||
self.client | ||
.file_open(context::current(), env_id, path, mode, type_), | ||
) | ||
} | ||
|
||
fn file_close(&self, env_id: EnvironmentId, fd: FileId) -> PyResult<()> { | ||
run_in_runtime(self, self.client.file_close(context::current(), env_id, fd)) | ||
} | ||
|
||
fn file_is_closed(&self, env_id: EnvironmentId, fd: FileId) -> PyResult<bool> { | ||
run_in_runtime( | ||
self, | ||
self.client.file_is_closed(context::current(), env_id, fd), | ||
) | ||
} | ||
|
||
fn file_is_readable(&self, env_id: EnvironmentId, fd: FileId) -> PyResult<bool> { | ||
run_in_runtime( | ||
self, | ||
self.client.file_is_readable(context::current(), env_id, fd), | ||
) | ||
} | ||
|
||
fn file_read(&self, env_id: EnvironmentId, fd: FileId, num_bytes: u32) -> PyResult<Vec<u8>> { | ||
run_in_runtime( | ||
self, | ||
self.client | ||
.file_read(context::current(), env_id, fd, num_bytes), | ||
) | ||
} | ||
|
||
fn file_read_lines( | ||
&self, | ||
env_id: EnvironmentId, | ||
fd: FileId, | ||
hint: u32, | ||
) -> PyResult<Vec<Vec<u8>>> { | ||
run_in_runtime( | ||
self, | ||
self.client | ||
.file_read_lines(context::current(), env_id, fd, hint), | ||
) | ||
} | ||
|
||
fn file_is_seekable(&self, env_id: EnvironmentId, fd: FileId) -> PyResult<bool> { | ||
run_in_runtime( | ||
self, | ||
self.client.file_is_seekable(context::current(), env_id, fd), | ||
) | ||
} | ||
|
||
fn file_seek( | ||
&self, | ||
env_id: EnvironmentId, | ||
fd: FileId, | ||
offset: i32, | ||
whence: i32, | ||
) -> PyResult<()> { | ||
run_in_runtime( | ||
self, | ||
self.client | ||
.file_seek(context::current(), env_id, fd, offset, whence), | ||
) | ||
} | ||
|
||
fn file_tell(&self, env_id: EnvironmentId, fd: FileId) -> PyResult<i32> { | ||
run_in_runtime(self, self.client.file_tell(context::current(), env_id, fd)) | ||
} | ||
|
||
fn file_is_writable(&self, env_id: EnvironmentId, fd: FileId) -> PyResult<bool> { | ||
run_in_runtime( | ||
self, | ||
self.client.file_is_writable(context::current(), env_id, fd), | ||
) | ||
} | ||
|
||
fn file_write(&self, env_id: EnvironmentId, fd: FileId, data: Vec<u8>) -> PyResult<()> { | ||
run_in_runtime( | ||
self, | ||
self.client.file_write(context::current(), env_id, fd, data), | ||
) | ||
} | ||
} | ||
|
||
#[pymodule] | ||
pub fn bh_agent_client(_py: Python, m: &PyModule) -> PyResult<()> { | ||
m.add_class::<BhAgentClient>()?; | ||
Ok(()) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
use bh_agent_common::BhAgentServiceClient; | ||
use tarpc::{client, tokio_serde::formats::Json}; | ||
use tokio::net::ToSocketAddrs; | ||
|
||
pub async fn build_client<A>(socket_addr: A) -> anyhow::Result<BhAgentServiceClient> | ||
where | ||
A: ToSocketAddrs, | ||
{ | ||
let mut transport = tarpc::serde_transport::tcp::connect(socket_addr, Json::default); | ||
transport.config_mut().max_frame_length(usize::MAX); | ||
|
||
let client = BhAgentServiceClient::new(client::Config::default(), transport.await?).spawn(); | ||
|
||
Ok(client) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
mod bindings; | ||
mod client; | ||
|
||
pub use bindings::bh_agent_client; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
[package] | ||
name = "bh_agent_common" | ||
version = "0.1.0" | ||
edition = "2021" | ||
|
||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html | ||
|
||
[dependencies] | ||
anyhow = { version = "1.0.75", features = [] } | ||
tarpc = { version = "0.33.0", features = ["tokio1"] } | ||
serde = { version = "1.0.188", features = ["derive"] } | ||
thiserror = "1.0.48" |
Oops, something went wrong.