From d319547ad22492444ed5dacd1a2849e43e938216 Mon Sep 17 00:00:00 2001 From: Hinome <57831472+RealHinome@users.noreply.github.com> Date: Wed, 27 Dec 2023 00:29:54 +0100 Subject: [PATCH] feat: add new ai model --- src/main.rs | 53 +++++++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 45 insertions(+), 8 deletions(-) diff --git a/src/main.rs b/src/main.rs index fd8c2c2..ce4043b 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,15 +1,28 @@ +pub mod corpus; pub mod helpers; +use corpus::CorpusManager; use remini::remini_server::{Remini as Rem, ReminiServer}; use remini::{Reply as ReminiReply, Request as ReminiRequest}; use tonic::{transport::Server, Request, Response, Status}; +pub type Model = tract_onnx::prelude::SimplePlan< + tract_onnx::prelude::TypedFact, + Box, + tract_onnx::prelude::Graph< + tract_onnx::prelude::TypedFact, + Box, + >, +>; + pub mod remini { tonic::include_proto!("remini"); } -#[derive(Debug, Default)] -pub struct Remini {} +struct Remini { + /// Corpus model to detect nodity on a content. + corpus: corpus::Corpus, +} #[tonic::async_trait] impl Rem for Remini { @@ -19,11 +32,29 @@ impl Rem for Remini { ) -> Result, Status> { let content = request.into_inner(); - Ok(Response::new(ReminiReply { - model: content.model, - message: "OK".to_string(), - error: false, - })) + match content.model.as_str() { + "corpus" => match self.corpus.predict(&content.data) { + Ok(result) => Ok(Response::new(ReminiReply { + model: content.model, + message: result, + error: false, + })), + Err(error) => { + log::error!("Corpus model got an error; {}", error); + + Ok(Response::new(ReminiReply { + model: content.model, + message: "Internal server error".to_string(), + error: true, + })) + } + }, + _ => Ok(Response::new(ReminiReply { + model: content.model, + message: "Unknown model".to_string(), + error: true, + })), + } } } @@ -58,7 +89,13 @@ async fn main() -> Result<(), Box> { std::env::var("port").unwrap_or("50051".to_string()) ) .parse()?; - let remini = Remini::default(); + + // Init every models. + let remini = Remini { + corpus: corpus::Corpus { + model: corpus::init()?, + }, + }; log::info!("Server started on {}", addr);