feature: support AI sdk
tyrchen committed Sep 29, 2024
1 parent 77f634e commit 39972b7
Showing 11 changed files with 344 additions and 2 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -58,7 +58,7 @@ repos:
- id: cargo-test
name: cargo test
description: unit test for the project
-entry: bash -c 'for dir in chat chatapp/src-tauri; do (cd $dir && cargo nextest run --all-features); done'
+entry: bash -c 'for dir in chat chatapp/src-tauri; do (cd $dir && cargo nextest run --all-features -- --include-ignored); done'
language: rust
files: \.rs$
pass_filenames: false
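
Note: the new "-- --include-ignored" flag makes cargo-nextest also run tests marked #[ignore], such as the network-dependent adapter tests introduced in this commit. To run just those tests locally, something like "cargo nextest run -p ai_sdk -- --include-ignored" should work (scoping with -p ai_sdk is an assumption about how you would target the new crate).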
40 changes: 40 additions & 0 deletions ai.rest
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
### use openai

POST https://api.openai.com/v1/chat/completions
Authorization: Bearer {{$processEnv OPENAI_KEY}}
Content-Type: application/json

{
"model": "gpt-4o",
"messages": [
{
"role": "system",
"content": "You are a friendly assistant that answers questions based on your knowledge. Your reply will be limited to 100 words. reply with simplified Chinese, unless the question asks for a specific language."
},
{
"role": "user",
"content": "中国上最长的河流是哪条?"
}
]
}

### use ollama

POST http://localhost:11434/api/chat
Content-Type: application/json


{
"model": "llama3.2",
"messages": [
{
"role": "system",
"content": "You are a friendly assistant that answers questions based on your knowledge. Your reply will be limited to 100 words. reply with simplified Chinese, unless the question asks for a specific language."
},
{
"role": "user",
"content": "中国上最长的河流是哪条?"
}
],
"stream": false
}
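
These two requests (in VS Code REST Client .rest syntax, where {{$processEnv OPENAI_KEY}} reads the OPENAI_KEY environment variable) mirror the payloads the OpenAI and Ollama adapters below send, so they are handy for probing both APIs before running the SDK itself.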
10 changes: 10 additions & 0 deletions chat/Cargo.lock

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion chat/Cargo.toml
@@ -1,5 +1,5 @@
[workspace]
-members = ["chat_server", "chat_core", "notify_server", "chat_test"]
+members = ["chat_server", "chat_core", "notify_server", "chat_test", "ai_sdk"]
resolver = "2"

[workspace.dependencies]
12 changes: 12 additions & 0 deletions chat/ai_sdk/Cargo.toml
@@ -0,0 +1,12 @@
[package]
name = "ai_sdk"
version = "0.1.0"
edition = "2021"

[dependencies]
anyhow = { workspace = true }
reqwest = { version = "0.12.7", default-features = false, features = ["rustls-tls", "json"] }
serde = { workspace = true }

[dev-dependencies]
tokio = { workspace = true }
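
Note: serde and tokio are inherited from workspace.dependencies (not shown in this diff); the code below assumes serde's derive feature and tokio's macros and rt-multi-thread features are enabled at the workspace level, since the adapters use #[derive(Serialize, Deserialize)] and the examples use #[tokio::main].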
9 changes: 9 additions & 0 deletions chat/ai_sdk/examples/ollama.rs
@@ -0,0 +1,9 @@
use ai_sdk::{AiService, Message, OllamaAdapter, Role};

#[tokio::main]
async fn main() {
let adapter = OllamaAdapter::default();
let messages = vec![Message { role: Role::User, content: "What is the longest river in the world?".to_string() }];
let response = adapter.complete(&messages).await.unwrap();
println!("response: {}", response);
}
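
A usage note (assuming a local Ollama install): pull the model first with "ollama pull llama3.2", make sure the server is listening on localhost:11434, then run "cargo run -p ai_sdk --example ollama".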
11 changes: 11 additions & 0 deletions chat/ai_sdk/examples/openai.rs
@@ -0,0 +1,11 @@
use ai_sdk::*;
use std::env;

#[tokio::main]
async fn main() {
let api_key = env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY must be set");
let adapter = OpenAIAdapter::new(api_key, "gpt-4o");
let messages = vec![Message { role: Role::User, content: "What is the longest river in the world?".to_string() }];
let response = adapter.complete(&messages).await.unwrap();
println!("response: {}", response);
}
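
This example reads OPENAI_API_KEY, while ai.rest above reads OPENAI_KEY; set whichever you are using, e.g. OPENAI_API_KEY=sk-... cargo run -p ai_sdk --example openai.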
5 changes: 5 additions & 0 deletions chat/ai_sdk/src/adapters/mod.rs
@@ -0,0 +1,5 @@
mod openai;
mod ollama;

pub use openai::*;
pub use ollama::*;
100 changes: 100 additions & 0 deletions chat/ai_sdk/src/adapters/ollama.rs
@@ -0,0 +1,100 @@
use crate::{AiService, Message};
use reqwest::Client;
use serde::{Deserialize, Serialize};

pub struct OllamaAdapter {
pub host: String,
pub model: String,
pub client: Client,
}

#[derive(Serialize)]
pub struct OllamaChatCompletionRequest {
pub model: String,
pub messages: Vec<OllamaMessage>,
pub stream: bool,
}

#[derive(Serialize, Deserialize)]
pub struct OllamaMessage {
pub role: String,
pub content: String,
}

#[derive(Deserialize)]
pub struct OllamaChatCompletionResponse {
pub model: String,
pub created_at: String,
pub message: OllamaMessage,
pub done: bool,
pub total_duration: u64,
pub load_duration: u64,
pub prompt_eval_count: u32,
pub prompt_eval_duration: u64,
pub eval_count: u32,
pub eval_duration: u64,
}

impl OllamaAdapter {
pub fn new(host: impl Into<String>, model: impl Into<String>) -> Self {
let host = host.into();
let model = model.into();
let client = Client::new();
Self { host, model, client }
}

pub fn new_local(model: impl Into<String>) -> Self {
let model = model.into();
let client = Client::new();
Self { host: "http://localhost:11434".to_string(), model, client }
}
}

impl Default for OllamaAdapter {
fn default() -> Self {
Self::new_local("llama3.2")
}
}

impl AiService for OllamaAdapter {
async fn complete(&self, messages: &[Message]) -> anyhow::Result<String> {
let request = OllamaChatCompletionRequest {
model: self.model.clone(),
messages: messages.iter().map(|m| m.into()).collect(),
stream: false,
};
let url = format!("{}/api/chat", self.host);
let response = self.client.post(url)
.json(&request)
.send().await?;
let response: OllamaChatCompletionResponse = response.json().await?;
Ok(response.message.content)
}
}

impl From<Message> for OllamaMessage {
fn from(message: Message) -> Self {
OllamaMessage { role: message.role.to_string(), content: message.content }
}
}

impl From<&Message> for OllamaMessage {
fn from(message: &Message) -> Self {
OllamaMessage { role: message.role.to_string(), content: message.content.clone() }
}
}

#[cfg(test)]
mod tests {
use crate::Role;
use super::*;

#[ignore]
#[tokio::test]
async fn ollama_complete_should_work() {
let adapter = OllamaAdapter::new_local("llama3.2");
let messages = vec![Message { role: Role::User, content: "Hello".to_string() }];
let response = adapter.complete(&messages).await.unwrap();
println!("response: {}", response);
}
}
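
One caveat: OllamaChatCompletionResponse requires every stats field, so deserialization fails if the server omits any of them (Ollama reportedly skips the prompt_eval fields when the prompt is served from cache). A more defensive variant (a sketch, not part of this commit) would wrap the stats in Option, which serde deserializes to None when a field is missing:

#[derive(Deserialize)]
pub struct OllamaChatCompletionResponse {
    pub model: String,
    pub created_at: String,
    pub message: OllamaMessage,
    pub done: bool,
    // Timing and token stats only appear on final, non-streaming responses.
    pub total_duration: Option<u64>,
    pub load_duration: Option<u64>,
    pub prompt_eval_count: Option<u32>,
    pub prompt_eval_duration: Option<u64>,
    pub eval_count: Option<u32>,
    pub eval_duration: Option<u64>,
}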
120 changes: 120 additions & 0 deletions chat/ai_sdk/src/adapters/openai.rs
@@ -0,0 +1,120 @@
use crate::{AiService, Message};
use anyhow::anyhow;
use reqwest::Client;
use serde::{Deserialize, Serialize};

pub struct OpenAIAdapter {
host: String,
api_key: String,
model: String,
client: Client,
}

#[derive(Serialize)]
pub struct OpenAIChatCompletionRequest {
pub model: String,
pub messages: Vec<OpenAIMessage>,
}

#[derive(Serialize, Deserialize)]
pub struct OpenAIMessage {
pub role: String,
pub content: String,
}

#[derive(Deserialize)]
pub struct OpenAIChatCompletionResponse {
pub id: String,
pub object: String,
pub created: u64,
pub model: String,
// May be absent or null in some API responses, so Option keeps deserialization robust.
pub system_fingerprint: Option<String>,
pub choices: Vec<OpenAIChoice>,
pub usage: OpenAIUsage,
}

#[derive(Deserialize)]
pub struct OpenAIChoice {
pub index: u32,
pub message: OpenAIMessage,
pub logprobs: Option<i64>,
pub finish_reason: String,
}

#[derive(Deserialize)]
pub struct OpenAIUsage {
pub prompt_tokens: u32,
pub completion_tokens: u32,
pub total_tokens: u32,
pub completion_tokens_details: Option<OpenAICompletionTokensDetails>,
}

#[derive(Deserialize)]
pub struct OpenAICompletionTokensDetails {
pub reasoning_tokens: u32,
}

impl OpenAIAdapter {
pub fn new(api_key: impl Into<String>, model: impl Into<String>) -> Self {
let client = Client::new();
Self {
host: "https://api.openai.com/v1".to_string(),
api_key: api_key.into(),
model: model.into(),
client,
}
}
}

impl AiService for OpenAIAdapter {
async fn complete(&self, messages: &[Message]) -> anyhow::Result<String> {
let request = OpenAIChatCompletionRequest {
model: self.model.clone(),
messages: messages.iter().map(|m| m.into()).collect(),
};

let url = format!("{}/chat/completions", self.host);
let response = self.client.post(url)
.json(&request)
.header("Authorization", format!("Bearer {}", self.api_key))
.send()
.await?;
let mut data: OpenAIChatCompletionResponse = response.json().await?;
let content = data.choices.pop().ok_or(anyhow!("No response"))?.message.content;
Ok(content)
}
}

impl From<Message> for OpenAIMessage {
fn from(message: Message) -> Self {
OpenAIMessage {
role: message.role.to_string(),
content: message.content,
}
}
}

impl From<&Message> for OpenAIMessage {
fn from(message: &Message) -> Self {
OpenAIMessage {
role: message.role.to_string(),
content: message.content.clone(),
}
}
}

#[cfg(test)]
mod tests {
use crate::Role;
use std::env;
use super::*;

#[ignore]
#[tokio::test]
async fn openai_complete_should_work() {
let api_key = env::var("OPENAI_API_KEY").unwrap();
let adapter = OpenAIAdapter::new(api_key, "gpt-4o");
let messages = vec![Message { role: Role::User, content: "Hello".to_string() }];
let response = adapter.complete(&messages).await.unwrap();
assert!(!response.is_empty());
}
}
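
Note on choices.pop(): it takes the last choice, which is also the only one since the request never sets n above its default of 1, and popping yields an owned message so the content does not need to be cloned.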
35 changes: 35 additions & 0 deletions chat/ai_sdk/src/lib.rs
@@ -0,0 +1,35 @@
mod adapters;

pub use adapters::*;

use std::fmt;

#[derive(Debug, Clone)]
pub enum Role {
User,
Assistant,
System,
}

#[derive(Debug, Clone)]
pub struct Message {
pub role: Role,
pub content: String,
}


#[allow(async_fn_in_trait)]
pub trait AiService {
async fn complete(&self, messages: &[Message]) -> anyhow::Result<String>;
// other common functions
}

impl fmt::Display for Role {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Role::User => write!(f, "user"),
Role::Assistant => write!(f, "assistant"),
Role::System => write!(f, "system"),
}
}
}
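
Since AiService is an ordinary trait with an inherent async fn (stable since Rust 1.75), adding a backend is just one impl. A minimal sketch of a hypothetical EchoAdapter (not part of this commit) that can serve as an offline test double:

pub struct EchoAdapter;

impl AiService for EchoAdapter {
    async fn complete(&self, messages: &[Message]) -> anyhow::Result<String> {
        // Echo the last message back; no network needed, so tests stay fast and offline.
        Ok(messages
            .last()
            .map(|m| m.content.clone())
            .unwrap_or_default())
    }
}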
