diff --git a/Cargo.lock b/Cargo.lock index 19982a4..bda573f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -19,6 +19,54 @@ dependencies = [ "cpufeatures", ] +[[package]] +name = "anstream" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2ab91ebe16eb252986481c5b62f6098f3b698a45e34b5b98200cf20dd2484a44" +dependencies = [ + "anstyle", + "anstyle-parse", + "anstyle-query", + "anstyle-wincon", + "colorchoice", + "utf8parse", +] + +[[package]] +name = "anstyle" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7079075b41f533b8c61d2a4d073c4676e1f8b249ff94a393b0595db304e0dd87" + +[[package]] +name = "anstyle-parse" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "317b9a89c1868f5ea6ff1d9539a69f45dffc21ce321ac1fd1160dfa48c8e2140" +dependencies = [ + "utf8parse", +] + +[[package]] +name = "anstyle-query" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5ca11d4be1bab0c8bc8734a9aa7bf4ee8316d462a08c6ac5052f888fef5b494b" +dependencies = [ + "windows-sys", +] + +[[package]] +name = "anstyle-wincon" +version = "3.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0699d10d2f4d628a98ee7b57b289abbc98ff3bad977cb3152709d4bf2330628" +dependencies = [ + "anstyle", + "windows-sys", +] + [[package]] name = "anyhow" version = "1.0.75" @@ -33,7 +81,7 @@ checksum = "b9ccdd8f2a161be9bd5c023df56f1b2a0bd1d83872ae53b71a84a12c9bf6e842" dependencies = [ "proc-macro2", "quote", - "syn 2.0.26", + "syn 2.0.32", ] [[package]] @@ -339,6 +387,52 @@ dependencies = [ "inout", ] +[[package]] +name = "clap" +version = "4.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d04704f56c2cde07f43e8e2c154b43f216dc5c92fc98ada720177362f953b956" +dependencies = [ + "clap_builder", + "clap_derive", +] + +[[package]] +name = "clap_builder" +version = "4.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0e231faeaca65ebd1ea3c737966bf858971cd38c3849107aa3ea7de90a804e45" +dependencies = [ + "anstream", + "anstyle", + "clap_lex", + "strsim", +] + +[[package]] +name = "clap_derive" +version = "4.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0862016ff20d69b84ef8247369fabf5c008a7417002411897d40ee1f4532b873" +dependencies = [ + "heck", + "proc-macro2", + "quote", + "syn 2.0.32", +] + +[[package]] +name = "clap_lex" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cd7cc57abe963c6d3b9d8be5b06ba7c8957a930305ca90304f24ef040aa6f961" + +[[package]] +name = "colorchoice" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "acbf1af155f9b9ef647e42cdc158db4b64a1b61f743629225fde6f3e0be2a7c7" + [[package]] name = "constant_time_eq" version = "0.1.5" @@ -439,6 +533,12 @@ version = "1.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7fcaabb2fef8c910e7f4c7ce9f67a1283a1715879a7c230ca9d6d1ae31f16d91" +[[package]] +name = "equivalent" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5" + [[package]] name = "errno" version = "0.3.1" @@ -473,6 +573,19 @@ dependencies = [ name = "ferrix" version = "0.1.0" +[[package]] +name = "ferrix-cli" +version = "0.1.0" +dependencies = [ + "clap", + "ferrix-model-api", + "ferrix-model-pytorch", + "ferrix-protos", + "ferrix-server", + "tokio", + "toml", +] + [[package]] name = "ferrix-model-api" version = "0.1.0" @@ -480,7 +593,9 @@ dependencies = [ "anyhow", "ferrix-protos", "pyo3", + "serde", "thiserror", + "toml", ] [[package]] @@ -554,6 +669,7 @@ dependencies = [ "axum", "ferrix-model-api", "ferrix-protos", + "ferrix-python-hooks", "prost", "prost-build", "protoc-bin-vendored", @@ -678,7 +794,7 @@ dependencies = [ "futures-sink", "futures-util", "http", - "indexmap", + "indexmap 1.9.3", "slab", "tokio", "tokio-util", @@ -704,6 +820,12 @@ version = "0.12.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" +[[package]] +name = "hashbrown" +version = "0.14.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7dfda62a12f55daeae5015f81b0baea145391cb4520f86c248fc615d72640d12" + [[package]] name = "heck" version = "0.4.1" @@ -821,7 +943,17 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bd070e393353796e801d209ad339e89596eb4c8d430d18ede6a1cced8fafbd99" dependencies = [ "autocfg", - "hashbrown", + "hashbrown 0.12.3", +] + +[[package]] +name = "indexmap" +version = "2.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8adf3ddd720272c6ea8bf59463c04e0f93d0bbf7c5439b691bca2987e0270897" +dependencies = [ + "equivalent", + "hashbrown 0.14.1", ] [[package]] @@ -1163,7 +1295,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4dd7d28ee937e54fe3080c91faa1c3a46c06de6252988a7f4592ba2310ef22a4" dependencies = [ "fixedbitset", - "indexmap", + "indexmap 1.9.3", ] [[package]] @@ -1183,7 +1315,7 @@ checksum = "39407670928234ebc5e6e580247dd567ad73a3578460c5990f9503df207e8f07" dependencies = [ "proc-macro2", "quote", - "syn 2.0.26", + "syn 2.0.32", ] [[package]] @@ -1620,22 +1752,22 @@ checksum = "a3f0bf26fd526d2a95683cd0f87bf103b8539e2ca1ef48ce002d67aad59aa0b4" [[package]] name = "serde" -version = "1.0.164" +version = "1.0.188" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9e8c8cf938e98f769bc164923b06dce91cea1751522f46f8466461af04c9027d" +checksum = "cf9e0fcba69a370eed61bcf2b728575f726b50b55cba78064753d708ddc7549e" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.164" +version = "1.0.188" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d9735b638ccc51c28bf6914d90a2e9725b377144fc612c49a611fddd1b631d68" +checksum = "4eca7ac642d82aa35b60049a6eccb4be6be75e599bd2e9adb5f875a737654af2" dependencies = [ "proc-macro2", "quote", - "syn 2.0.26", + "syn 2.0.32", ] [[package]] @@ -1658,6 +1790,15 @@ dependencies = [ "serde", ] +[[package]] +name = "serde_spanned" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "96426c9936fd7a0124915f9185ea1d20aa9445cc9821142f0a73bc9207a2e186" +dependencies = [ + "serde", +] + [[package]] name = "serde_urlencoded" version = "0.7.1" @@ -1729,6 +1870,12 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" +[[package]] +name = "strsim" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623" + [[package]] name = "subtle" version = "2.5.0" @@ -1748,9 +1895,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.26" +version = "2.0.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "45c3457aacde3c65315de5031ec191ce46604304d2446e803d71ade03308d970" +checksum = "239814284fd6f1a4ffe4ca893952cdd93c224b6a1571c9a9eadd670295c0c9e2" dependencies = [ "proc-macro2", "quote", @@ -1771,7 +1918,7 @@ checksum = "285ba80e733fac80aa4270fbcdf83772a79b80aa35c97075320abfee4a915b06" dependencies = [ "proc-macro2", "quote", - "syn 2.0.26", + "syn 2.0.32", "unicode-xid", ] @@ -1840,7 +1987,7 @@ checksum = "463fe12d7993d3b327787537ce8dd4dfa058de32fc2b195ef3cde03dc4771e8f" dependencies = [ "proc-macro2", "quote", - "syn 2.0.26", + "syn 2.0.32", ] [[package]] @@ -1909,7 +2056,7 @@ checksum = "630bdcf245f78637c13ec01ffae6187cca34625e8c63150d424b59e55af2675e" dependencies = [ "proc-macro2", "quote", - "syn 2.0.26", + "syn 2.0.32", ] [[package]] @@ -1937,6 +2084,40 @@ dependencies = [ "tracing", ] +[[package]] +name = "toml" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "185d8ab0dfbb35cf1399a6344d8484209c088f75f8f68230da55d48d95d43e3d" +dependencies = [ + "serde", + "serde_spanned", + "toml_datetime", + "toml_edit", +] + +[[package]] +name = "toml_datetime" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7cda73e2f1397b1262d6dfdcef8aafae14d1de7748d66822d3bfeeb6d03e5e4b" +dependencies = [ + "serde", +] + +[[package]] +name = "toml_edit" +version = "0.20.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "396e4d48bbb2b7554c944bde63101b5ae446cff6ec4a24227428f15eb72ef338" +dependencies = [ + "indexmap 2.0.2", + "serde", + "serde_spanned", + "toml_datetime", + "winnow", +] + [[package]] name = "tonic" version = "0.9.2" @@ -2001,7 +2182,7 @@ checksum = "b8fa9be0de6cf49e536ce1851f987bd21a43b771b09473c3549a6c853db37c1c" dependencies = [ "futures-core", "futures-util", - "indexmap", + "indexmap 1.9.3", "pin-project", "pin-project-lite", "rand", @@ -2046,7 +2227,7 @@ checksum = "0f57e3ca2a01450b1a921183a9c9cbfda207fd822cef4ccb00a65402cbba7a74" dependencies = [ "proc-macro2", "quote", - "syn 2.0.26", + "syn 2.0.32", ] [[package]] @@ -2138,6 +2319,12 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "utf8parse" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "711b9620af191e0cdc7468a8d14e709c3dcdb115b36f838e601583af800a370a" + [[package]] name = "version_check" version = "0.9.4" @@ -2201,7 +2388,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn 2.0.26", + "syn 2.0.32", "wasm-bindgen-shared", ] @@ -2223,7 +2410,7 @@ checksum = "54681b18a46765f095758388f2d0cf16eb8d4169b639ab575a8f5693af210c7b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.26", + "syn 2.0.32", "wasm-bindgen-backend", "wasm-bindgen-shared", ] @@ -2352,6 +2539,15 @@ version = "0.48.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1a515f5799fe4961cb532f983ce2b23082366b898e52ffbce459c86f67c8378a" +[[package]] +name = "winnow" +version = "0.5.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "037711d82167854aff2018dfd193aa0fef5370f456732f0d5a0c59b0f1b4b907" +dependencies = [ + "memchr", +] + [[package]] name = "xattr" version = "1.0.0" @@ -2381,7 +2577,7 @@ checksum = "d5e19fb6ed40002bab5403ffa37e53e0e56f914a4450c8765f533018db1db35f" dependencies = [ "proc-macro2", "quote", - "syn 2.0.26", + "syn 2.0.32", "synstructure", ] @@ -2402,7 +2598,7 @@ checksum = "e6a647510471d372f2e6c2e6b7219e44d8c574d24fdc11c610a61455782f18c3" dependencies = [ "proc-macro2", "quote", - "syn 2.0.26", + "syn 2.0.32", "synstructure", ] diff --git a/Cargo.toml b/Cargo.toml index 8c464ca..3ae6281 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,6 +5,7 @@ edition = "2021" [workspace] members = [ + "ferrix-cli", "ferrix-model-api", "ferrix-protos", "ferrix-server", diff --git a/ferrix-model-api/Cargo.toml b/ferrix-model-api/Cargo.toml index 58c3220..580735d 100644 --- a/ferrix-model-api/Cargo.toml +++ b/ferrix-model-api/Cargo.toml @@ -7,4 +7,6 @@ edition = "2021" anyhow = "1.0.75" ferrix-protos = { path = "../ferrix-protos" } pyo3 = "0.19.1" +serde = "1.0.188" thiserror = "1.0.43" +toml = "0.8.2" diff --git a/ferrix-model-api/src/internal.rs b/ferrix-model-api/src/internal.rs index e2ed7db..f0882a2 100644 --- a/ferrix-model-api/src/internal.rs +++ b/ferrix-model-api/src/internal.rs @@ -414,16 +414,11 @@ impl TensorData { // Utils fn data_to_py(datatype: String, contents: TensorData, py: pyo3::Python<'_>) -> Py { match datatype.as_str() { - "INT8" => todo!(), - "INT16" => todo!(), - "INT32" => contents.int_contents.into_py(py).extract(py).unwrap(), + "INT8" | "INT16" | "INT32" => contents.int_contents.into_py(py).extract(py).unwrap(), "INT64" => contents.int64_contents.into_py(py).extract(py).unwrap(), - "UINT8" => todo!(), - "UINT16" => todo!(), - "UINT32" => contents.uint_contents.into_py(py).extract(py).unwrap(), + "UINT8" | "UINT16" | "UINT32" => contents.uint_contents.into_py(py).extract(py).unwrap(), "UINT64" => contents.uint64_contents.into_py(py).extract(py).unwrap(), - "FP16" => todo!(), - "FP32" => contents.fp32_contents.into_py(py).extract(py).unwrap(), + "FP16" | "FP32" => contents.fp32_contents.into_py(py).extract(py).unwrap(), "FP64" => contents.fp64_contents.into_py(py).extract(py).unwrap(), "BYTES" => todo!(), _ => todo!(), diff --git a/ferrix-model-api/src/lib.rs b/ferrix-model-api/src/lib.rs index 984c1ef..1b6b1f5 100644 --- a/ferrix-model-api/src/lib.rs +++ b/ferrix-model-api/src/lib.rs @@ -1,5 +1,7 @@ use internal::{InferRequest, InferResponse}; +use serde::Deserialize; use thiserror::Error; +use toml::Value; pub mod internal; pub mod python; @@ -10,6 +12,13 @@ pub trait Model: Send + Sync { fn predict(&self, request: InferRequest) -> ModelResult; } +#[derive(Deserialize)] +pub struct ModelConfig { + pub model_name: String, + pub base_path: String, + pub extended_config: Option, +} + pub type ModelResult = std::result::Result; #[derive(Error, Debug)] diff --git a/ferrix-model-candle/src/lib.rs b/ferrix-model-candle/src/lib.rs index 9892b02..b8625f0 100644 --- a/ferrix-model-candle/src/lib.rs +++ b/ferrix-model-candle/src/lib.rs @@ -10,6 +10,12 @@ struct CandleModel { module: Arc, } +impl CandleModel { + fn new(config: ferrix_model_api::ModelConfig) -> Self { + todo!() + } +} + impl Model for CandleModel { fn load(&mut self) -> ferrix_model_api::ModelResult<()> { todo!() diff --git a/ferrix-model-onnx/src/lib.rs b/ferrix-model-onnx/src/lib.rs index 3c340d2..ba5ff5b 100644 --- a/ferrix-model-onnx/src/lib.rs +++ b/ferrix-model-onnx/src/lib.rs @@ -3,6 +3,12 @@ use ort::*; struct OnnxModel; +impl OnnxModel { + fn new(config: ferrix_model_api::ModelConfig) -> Self { + todo!() + } +} + impl Model for OnnxModel { fn load(&mut self) -> ferrix_model_api::ModelResult<()> { todo!() diff --git a/ferrix-model-pytorch/src/lib.rs b/ferrix-model-pytorch/src/lib.rs index 1ed977b..180bbe3 100644 --- a/ferrix-model-pytorch/src/lib.rs +++ b/ferrix-model-pytorch/src/lib.rs @@ -6,37 +6,33 @@ use anyhow::bail; use atomic_option::AtomicOption; use ferrix_model_api::internal::*; use ferrix_model_api::Model; +use ferrix_model_api::ModelConfig; use ferrix_model_api::ModelError; use ferrix_model_api::ModelResult; use tch::CModule; use tch::Kind; use tch::Tensor as PyTorchTensor; -struct PyTorchModel { +pub struct PyTorchModel { module: AtomicOption, - model_config: PyTorchModelConfig, + model_config: ModelConfig, } impl PyTorchModel { - fn new(model_config: PyTorchModelConfig) -> Self { + pub fn new(config: ferrix_model_api::ModelConfig) -> Self { PyTorchModel { module: AtomicOption::empty(), - model_config, + model_config: config, } } } -#[derive(Clone)] -struct PyTorchModelConfig { - model: String, -} - impl Model for PyTorchModel { fn load(&mut self) -> ModelResult<()> { - let file_name = self.model_config.model.to_string(); + let file_name = self.model_config.base_path.to_string(); println!("Loading file {}", file_name); - let result = tch::CModule::load(self.model_config.model.to_string()); + let result = tch::CModule::load(self.model_config.base_path.to_string()); let model = match result { Ok(module) => module, Err(error) => bail!(ModelError::Load(error.to_string())), @@ -141,8 +137,10 @@ mod tests { fn test_basic_pytorch_inference() { let resource_dir = format!("{}/resource", env!("CARGO_MANIFEST_DIR")); let saved_model_filename = format!("{}/model.pt", resource_dir); - let mut model = PyTorchModel::new(PyTorchModelConfig { - model: saved_model_filename, + let mut model = PyTorchModel::new(ModelConfig { + model_name: String::from(""), + base_path: saved_model_filename, + extended_config: None, }); let load_result = model.load(); diff --git a/ferrix-python-hooks/src/lib.rs b/ferrix-python-hooks/src/lib.rs index 6488a4b..ed16a9c 100644 --- a/ferrix-python-hooks/src/lib.rs +++ b/ferrix-python-hooks/src/lib.rs @@ -44,43 +44,37 @@ fn ferrix(_py: Python, module: &PyModule) -> PyResult<()> { Ok(()) } -pub fn eval() { +pub fn eval(code: String) { pyo3::append_to_inittab!(ferrix); pyo3::prepare_freethreaded_python(); - Python::with_gil(|py| { - py.run(CODE, None, None).unwrap(); - }); + Python::with_gil(|py| py.run(&code, None, None).unwrap()) } -pub fn preprocess(input: InferRequest) -> InferRequest { +pub fn preprocess(input: InferRequest) -> PyResult { Python::with_gil(|py| { let input = input.to_object(py); let args = PyTuple::new(py, &[input]); let response = PREPROCESSOR .get() .unwrap() - .call(py, args, Some(PyDict::new(py))) - .unwrap(); + .call(py, args, Some(PyDict::new(py)))?; response.extract::(py) }) - .unwrap() } -pub fn postprocess(input: InferResponse) -> InferResponse { +pub fn postprocess(input: InferResponse) -> PyResult { Python::with_gil(|py| { let input = input.to_object(py); let args = PyTuple::new(py, &[input]); let response = PREPROCESSOR .get() .unwrap() - .call(py, args, Some(PyDict::new(py))) - .unwrap(); + .call(py, args, Some(PyDict::new(py)))?; response.extract::(py) }) - .unwrap() } #[cfg(test)] @@ -91,7 +85,7 @@ mod tests { #[test] fn test() { - eval(); + eval(CODE.to_string()); PREPROCESSOR.get().unwrap(); @@ -123,6 +117,6 @@ mod tests { let response = preprocess(infer_request); - assert_eq!("1".to_string(), response.id) + assert_eq!("1".to_string(), response.unwrap().id) } } diff --git a/ferrix-server/Cargo.toml b/ferrix-server/Cargo.toml index 1e07ef1..5172845 100644 --- a/ferrix-server/Cargo.toml +++ b/ferrix-server/Cargo.toml @@ -11,6 +11,7 @@ axum = "0.6.18" serde = { version = "1.0.164", features = ["derive"] } ferrix-model-api = { path = "../ferrix-model-api" } ferrix-protos = { path = "../ferrix-protos" } +ferrix-python-hooks = { path = "../ferrix-python-hooks" } [build-dependencies] prost-build = "0.11.9" diff --git a/ferrix-server/src/lib.rs b/ferrix-server/src/lib.rs index c36b6b8..b951f2a 100644 --- a/ferrix-server/src/lib.rs +++ b/ferrix-server/src/lib.rs @@ -1,22 +1,32 @@ -use std::sync::Arc; - use ferrix_model_api::internal::InferRequest; +use inference::Inference; use tonic::Response; use ferrix_model_api::Model; -use ferrix_protos::grpc_inference_service_server::GrpcInferenceService; +use ferrix_protos::grpc_inference_service_server::{ + GrpcInferenceService, GrpcInferenceServiceServer, +}; use ferrix_protos::*; +use tonic::transport::Server; + +pub mod inference; // #[derive(Default)] pub struct GrpcInferenceServiceImpl { - model: Arc, + model: Inference, +} + +impl GrpcInferenceServiceImpl { + pub fn with_model(model: Inference) -> Self { + GrpcInferenceServiceImpl { model } + } } #[tonic::async_trait] impl GrpcInferenceService for GrpcInferenceServiceImpl { async fn server_live( &self, - request: tonic::Request, + _: tonic::Request, ) -> std::result::Result, tonic::Status> { return Ok(Response::new(ServerLiveResponse { live: true })); } @@ -24,7 +34,7 @@ impl GrpcInferenceService for GrpcInferenceServiceImpl { /// The ServerReady API indicates if the server is ready for inferencing. async fn server_ready( &self, - request: tonic::Request, + _: tonic::Request, ) -> std::result::Result, tonic::Status> { return Ok(Response::new(ServerReadyResponse { ready: true })); } @@ -32,7 +42,7 @@ impl GrpcInferenceService for GrpcInferenceServiceImpl { /// The ModelReady API indicates if a specific model is ready for inferencing. async fn model_ready( &self, - request: tonic::Request, + _: tonic::Request, ) -> std::result::Result, tonic::Status> { return Ok(Response::new(ModelReadyResponse { ready: self.model.loaded(), @@ -77,8 +87,27 @@ impl GrpcInferenceService for GrpcInferenceServiceImpl { request: tonic::Request, ) -> std::result::Result, tonic::Status> { let infer_request = InferRequest::from_proto(request.into_inner()); - return Ok(Response::new( - self.model.predict(infer_request).unwrap().to_proto(), - )); + let infer_result = self.model.predict(infer_request); + + match infer_result { + Ok(infer_response) => Ok(Response::new(infer_response.to_proto())), + Err(error) => Err(tonic::Status::internal(error.to_string())), + } } } + +pub async fn serve( + port: i16, + service: GrpcInferenceServiceImpl, +) -> Result<(), Box> { + let addr = format!("[::1]:{}", port).parse().unwrap(); + + println!("GreeterServer listening on {}", addr); + + Server::builder() + .add_service(GrpcInferenceServiceServer::new(service)) + .serve(addr) + .await?; + + Ok(()) +}