refactoring and integrating inference flow in server
rickfast committed Oct 12, 2023
1 parent e1e811f commit 0ba8602
Showing 11 changed files with 303 additions and 66 deletions.
238 changes: 217 additions & 21 deletions Cargo.lock

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions Cargo.toml
@@ -5,6 +5,7 @@ edition = "2021"

[workspace]
members = [
"ferrix-cli",
"ferrix-model-api",
"ferrix-protos",
"ferrix-server",
2 changes: 2 additions & 0 deletions ferrix-model-api/Cargo.toml
@@ -7,4 +7,6 @@ edition = "2021"
anyhow = "1.0.75"
ferrix-protos = { path = "../ferrix-protos" }
pyo3 = "0.19.1"
serde = "1.0.188"
thiserror = "1.0.43"
toml = "0.8.2"
11 changes: 3 additions & 8 deletions ferrix-model-api/src/internal.rs
@@ -414,16 +414,11 @@ impl TensorData {
// Utils
fn data_to_py(datatype: String, contents: TensorData, py: pyo3::Python<'_>) -> Py<PyList> {
match datatype.as_str() {
"INT8" => todo!(),
"INT16" => todo!(),
"INT32" => contents.int_contents.into_py(py).extract(py).unwrap(),
"INT8" | "INT16" | "INT32" => contents.int_contents.into_py(py).extract(py).unwrap(),
"INT64" => contents.int64_contents.into_py(py).extract(py).unwrap(),
"UINT8" => todo!(),
"UINT16" => todo!(),
"UINT32" => contents.uint_contents.into_py(py).extract(py).unwrap(),
"UINT8" | "UINT16" | "UINT32" => contents.uint_contents.into_py(py).extract(py).unwrap(),
"UINT64" => contents.uint64_contents.into_py(py).extract(py).unwrap(),
"FP16" => todo!(),
"FP32" => contents.fp32_contents.into_py(py).extract(py).unwrap(),
"FP16" | "FP32" => contents.fp32_contents.into_py(py).extract(py).unwrap(),
"FP64" => contents.fp64_contents.into_py(py).extract(py).unwrap(),
"BYTES" => todo!(),
_ => todo!(),
9 changes: 9 additions & 0 deletions ferrix-model-api/src/lib.rs
@@ -1,5 +1,7 @@
use internal::{InferRequest, InferResponse};
use serde::Deserialize;
use thiserror::Error;
use toml::Value;

pub mod internal;
pub mod python;
@@ -10,6 +12,13 @@ pub trait Model: Send + Sync {
fn predict(&self, request: InferRequest) -> ModelResult<InferResponse>;
}

#[derive(Deserialize)]
pub struct ModelConfig {
pub model_name: String,
pub base_path: String,
pub extended_config: Option<Value>,
}

pub type ModelResult<T> = std::result::Result<T, anyhow::Error>;

#[derive(Error, Debug)]
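The new ModelConfig struct gives every backend one serde-deserializable configuration shape, which is why serde and toml land as dependencies above. A rough sketch of how a caller might build one from TOML; the field values and the inline string are illustrative assumptions, not code from this commit:

use ferrix_model_api::ModelConfig;

fn load_config() -> Result<ModelConfig, toml::de::Error> {
    // Any TOML source works; reading from a file instead of an inline
    // string is an application-level choice.
    let raw = r#"
        model_name = "example-model"
        base_path = "/models/example/model.pt"

        [extended_config]
        device = "cpu"
    "#;

    // extended_config stays an open-ended toml::Value, so backend-specific
    // keys (like the hypothetical `device` above) pass through untyped.
    toml::from_str(raw)
}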
6 changes: 6 additions & 0 deletions ferrix-model-candle/src/lib.rs
@@ -10,6 +10,12 @@ struct CandleModel {
module: Arc<dyn Module + Send + Sync>,
}

impl CandleModel {
fn new(config: ferrix_model_api::ModelConfig) -> Self {
todo!()
}
}

impl Model for CandleModel {
fn load(&mut self) -> ferrix_model_api::ModelResult<()> {
todo!()
6 changes: 6 additions & 0 deletions ferrix-model-onnx/src/lib.rs
@@ -3,6 +3,12 @@ use ort::*;

struct OnnxModel;

impl OnnxModel {
fn new(config: ferrix_model_api::ModelConfig) -> Self {
todo!()
}
}

impl Model for OnnxModel {
fn load(&mut self) -> ferrix_model_api::ModelResult<()> {
todo!()
24 changes: 11 additions & 13 deletions ferrix-model-pytorch/src/lib.rs
@@ -6,37 +6,33 @@ use anyhow::bail;
use atomic_option::AtomicOption;
use ferrix_model_api::internal::*;
use ferrix_model_api::Model;
use ferrix_model_api::ModelConfig;
use ferrix_model_api::ModelError;
use ferrix_model_api::ModelResult;
use tch::CModule;
use tch::Kind;
use tch::Tensor as PyTorchTensor;

struct PyTorchModel {
pub struct PyTorchModel {
module: AtomicOption<CModule>,
model_config: PyTorchModelConfig,
model_config: ModelConfig,
}

impl PyTorchModel {
fn new(model_config: PyTorchModelConfig) -> Self {
pub fn new(config: ferrix_model_api::ModelConfig) -> Self {
PyTorchModel {
module: AtomicOption::empty(),
model_config,
model_config: config,
}
}
}

#[derive(Clone)]
struct PyTorchModelConfig {
model: String,
}

impl Model for PyTorchModel {
fn load(&mut self) -> ModelResult<()> {
let file_name = self.model_config.model.to_string();
let file_name = self.model_config.base_path.to_string();

println!("Loading file {}", file_name);
let result = tch::CModule::load(self.model_config.model.to_string());
let result = tch::CModule::load(self.model_config.base_path.to_string());
let model = match result {
Ok(module) => module,
Err(error) => bail!(ModelError::Load(error.to_string())),
@@ -141,8 +137,10 @@ mod tests {
fn test_basic_pytorch_inference() {
let resource_dir = format!("{}/resource", env!("CARGO_MANIFEST_DIR"));
let saved_model_filename = format!("{}/model.pt", resource_dir);
let mut model = PyTorchModel::new(PyTorchModelConfig {
model: saved_model_filename,
let mut model = PyTorchModel::new(ModelConfig {
model_name: String::from(""),
base_path: saved_model_filename,
extended_config: None,
});
let load_result = model.load();

22 changes: 8 additions & 14 deletions ferrix-python-hooks/src/lib.rs
@@ -44,43 +44,37 @@ fn ferrix(_py: Python, module: &PyModule) -> PyResult<()> {
Ok(())
}

pub fn eval() {
pub fn eval(code: String) {
pyo3::append_to_inittab!(ferrix);
pyo3::prepare_freethreaded_python();

Python::with_gil(|py| {
py.run(CODE, None, None).unwrap();
});
Python::with_gil(|py| py.run(&code, None, None).unwrap())
}

pub fn preprocess(input: InferRequest) -> InferRequest {
pub fn preprocess(input: InferRequest) -> PyResult<InferRequest> {
Python::with_gil(|py| {
let input = input.to_object(py);
let args = PyTuple::new(py, &[input]);
let response = PREPROCESSOR
.get()
.unwrap()
.call(py, args, Some(PyDict::new(py)))
.unwrap();
.call(py, args, Some(PyDict::new(py)))?;

response.extract::<InferRequest>(py)
})
.unwrap()
}

pub fn postprocess(input: InferResponse) -> InferResponse {
pub fn postprocess(input: InferResponse) -> PyResult<InferResponse> {
Python::with_gil(|py| {
let input = input.to_object(py);
let args = PyTuple::new(py, &[input]);
let response = PREPROCESSOR
.get()
.unwrap()
.call(py, args, Some(PyDict::new(py)))
.unwrap();
.call(py, args, Some(PyDict::new(py)))?;

response.extract::<InferResponse>(py)
})
.unwrap()
}

#[cfg(test)]
Expand All @@ -91,7 +85,7 @@ mod tests {

#[test]
fn test() {
eval();
eval(CODE.to_string());

PREPROCESSOR.get().unwrap();

@@ -123,6 +117,6 @@

let response = preprocess(infer_request);

assert_eq!("1".to_string(), response.id)
assert_eq!("1".to_string(), response.unwrap().id)
}
}
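Because eval now takes the hook source as an argument and preprocess/postprocess return PyResult, the caller decides how Python errors surface instead of the hook layer panicking. A minimal sketch of that calling side, assuming the caller depends on pyo3 directly and that the hook source registers a preprocessor through the embedded ferrix module (as the test's CODE constant does); none of this wiring is shown in the commit itself:

use ferrix_model_api::internal::InferRequest;
use ferrix_python_hooks::{eval, preprocess};

fn run_preprocess(hook_source: String, request: InferRequest) -> pyo3::PyResult<InferRequest> {
    // Run the user-supplied hook code once; it is expected to register a
    // preprocessor via the embedded `ferrix` module.
    eval(hook_source);

    // An exception raised inside the Python hook now comes back as a PyErr
    // that the server can map to a gRPC status rather than a panic.
    preprocess(request)
}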
1 change: 1 addition & 0 deletions ferrix-server/Cargo.toml
@@ -11,6 +11,7 @@ axum = "0.6.18"
serde = { version = "1.0.164", features = ["derive"] }
ferrix-model-api = { path = "../ferrix-model-api" }
ferrix-protos = { path = "../ferrix-protos" }
ferrix-python-hooks = { path = "../ferrix-python-hooks" }

[build-dependencies]
prost-build = "0.11.9"
49 changes: 39 additions & 10 deletions ferrix-server/src/lib.rs
@@ -1,38 +1,48 @@
use std::sync::Arc;

use ferrix_model_api::internal::InferRequest;
use inference::Inference;
use tonic::Response;

use ferrix_model_api::Model;
use ferrix_protos::grpc_inference_service_server::GrpcInferenceService;
use ferrix_protos::grpc_inference_service_server::{
GrpcInferenceService, GrpcInferenceServiceServer,
};
use ferrix_protos::*;
use tonic::transport::Server;

pub mod inference;

// #[derive(Default)]
pub struct GrpcInferenceServiceImpl {
model: Arc<dyn Model + Send + Sync>,
model: Inference,
}

impl GrpcInferenceServiceImpl {
pub fn with_model(model: Inference) -> Self {
GrpcInferenceServiceImpl { model }
}
}

#[tonic::async_trait]
impl GrpcInferenceService for GrpcInferenceServiceImpl {
async fn server_live(
&self,
request: tonic::Request<ServerLiveRequest>,
_: tonic::Request<ServerLiveRequest>,
) -> std::result::Result<tonic::Response<ServerLiveResponse>, tonic::Status> {
return Ok(Response::new(ServerLiveResponse { live: true }));
}

/// The ServerReady API indicates if the server is ready for inferencing.
async fn server_ready(
&self,
request: tonic::Request<ServerReadyRequest>,
_: tonic::Request<ServerReadyRequest>,
) -> std::result::Result<tonic::Response<ServerReadyResponse>, tonic::Status> {
return Ok(Response::new(ServerReadyResponse { ready: true }));
}

/// The ModelReady API indicates if a specific model is ready for inferencing.
async fn model_ready(
&self,
request: tonic::Request<ModelReadyRequest>,
_: tonic::Request<ModelReadyRequest>,
) -> std::result::Result<tonic::Response<ModelReadyResponse>, tonic::Status> {
return Ok(Response::new(ModelReadyResponse {
ready: self.model.loaded(),
@@ -77,8 +87,27 @@ impl GrpcInferenceService for GrpcInferenceServiceImpl {
request: tonic::Request<ModelInferRequest>,
) -> std::result::Result<tonic::Response<ModelInferResponse>, tonic::Status> {
let infer_request = InferRequest::from_proto(request.into_inner());
return Ok(Response::new(
self.model.predict(infer_request).unwrap().to_proto(),
));
let infer_result = self.model.predict(infer_request);

match infer_result {
Ok(infer_response) => Ok(Response::new(infer_response.to_proto())),
Err(error) => Err(tonic::Status::internal(error.to_string())),
}
}
}

pub async fn serve(
port: i16,
service: GrpcInferenceServiceImpl,
) -> Result<(), Box<dyn std::error::Error>> {
let addr = format!("[::1]:{}", port).parse().unwrap();

println!("GreeterServer listening on {}", addr);

Server::builder()
.add_service(GrpcInferenceServiceServer::new(service))
.serve(addr)
.await?;

Ok(())
}
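With serve owning the tonic bootstrap, a binary only has to build the Inference wrapper and pass it in. A minimal sketch of that wiring, assuming a tokio runtime; Inference::new and the port value are hypothetical, since the inference module added by this commit is not rendered on this page and its real constructor may differ:

use ferrix_server::inference::Inference;
use ferrix_server::{serve, GrpcInferenceServiceImpl};

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Hypothetical constructor: the inference module presumably wires a
    // Model together with the Python pre/post-processing hooks.
    let inference = Inference::new();

    let service = GrpcInferenceServiceImpl::with_model(inference);

    // serve() binds to [::1]:<port> and runs the gRPC server to completion.
    serve(8001, service).await
}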
