-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
11 changed files
with
344 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
### use openai | ||
|
||
POST https://api.openai.com/v1/chat/completions | ||
Authorization: Bearer {{$processEnv OPENAI_KEY}} | ||
Content-Type: application/json | ||
|
||
{ | ||
"model": "gpt-4o", | ||
"messages": [ | ||
{ | ||
"role": "system", | ||
"content": "You are a friendly assistant that answers questions based on your knowledge. Your reply will be limited to 100 words. reply with simplified Chinese, unless the question asks for a specific language." | ||
}, | ||
{ | ||
"role": "user", | ||
"content": "中国上最长的河流是哪条?" | ||
} | ||
] | ||
} | ||
|
||
### use ollama | ||
|
||
POST http://localhost:11434/api/chat | ||
Content-Type: application/json | ||
|
||
|
||
{ | ||
"model": "llama3.2", | ||
"messages": [ | ||
{ | ||
"role": "system", | ||
"content": "You are a friendly assistant that answers questions based on your knowledge. Your reply will be limited to 100 words. reply with simplified Chinese, unless the question asks for a specific language." | ||
}, | ||
{ | ||
"role": "user", | ||
"content": "中国上最长的河流是哪条?" | ||
} | ||
], | ||
"stream": false | ||
} |
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
[package] | ||
name = "ai_sdk" | ||
version = "0.1.0" | ||
edition = "2021" | ||
|
||
[dependencies] | ||
anyhow = { workspace = true } | ||
reqwest = { version = "0.12.7", default-features = false, features = ["rustls-tls", "json"] } | ||
serde = { workspace = true} | ||
|
||
[dev-dependencies] | ||
tokio = { workspace = true} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
use ai_sdk::{Message, OllamaAdapter, Role, AiService}; | ||
|
||
#[tokio::main] | ||
async fn main() { | ||
let adapter = OllamaAdapter::default(); | ||
let messages = vec![Message { role: Role::User, content: "世界上最长的河流是什么?".to_string() }]; | ||
let response = adapter.complete(&messages).await.unwrap(); | ||
println!("response: {}", response); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
use ai_sdk::*; | ||
use std::env; | ||
|
||
#[tokio::main] | ||
async fn main() { | ||
let api_key = env::var("OPENAI_API_KEY").unwrap(); | ||
let adapter = OpenAIAdapter::new(api_key, "gpt-4o"); | ||
let messages = vec![Message { role: Role::User, content: "世界上最长的河流是什么?".to_string() }]; | ||
let response = adapter.complete(&messages).await.unwrap(); | ||
println!("response: {}", response); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
mod openai; | ||
mod ollama; | ||
|
||
pub use openai::*; | ||
pub use ollama::*; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
use reqwest::Client; | ||
use serde::{Deserialize, Serialize}; | ||
use crate::{Message, AiService}; | ||
|
||
pub struct OllamaAdapter { | ||
pub host: String, | ||
pub model: String, | ||
pub client: Client, | ||
} | ||
|
||
#[derive(Serialize)] | ||
pub struct OllamaChatCompletionRequest { | ||
pub model: String, | ||
pub messages: Vec<OllamaMessage>, | ||
pub stream: bool, | ||
} | ||
|
||
#[derive(Serialize, Deserialize)] | ||
pub struct OllamaMessage { | ||
pub role: String, | ||
pub content: String, | ||
} | ||
|
||
#[derive(Deserialize)] | ||
pub struct OllamaChatCompletionResponse { | ||
pub model: String, | ||
pub created_at: String, | ||
pub message: OllamaMessage, | ||
pub done: bool, | ||
pub total_duration: u64, | ||
pub load_duration: u64, | ||
pub prompt_eval_count: u32, | ||
pub prompt_eval_duration: u64, | ||
pub eval_count: u32, | ||
pub eval_duration: u64, | ||
} | ||
|
||
impl OllamaAdapter { | ||
pub fn new(host: impl Into<String>, model: impl Into<String>) -> Self { | ||
let host = host.into(); | ||
let model = model.into(); | ||
let client = Client::new(); | ||
Self { host, model, client } | ||
} | ||
|
||
pub fn new_local(model: impl Into<String>) -> Self { | ||
let model = model.into(); | ||
let client = Client::new(); | ||
Self { host: "http://localhost:11434".to_string(), model, client } | ||
} | ||
} | ||
|
||
impl Default for OllamaAdapter { | ||
fn default() -> Self { | ||
Self::new_local("llama3.2") | ||
} | ||
} | ||
|
||
impl AiService for OllamaAdapter { | ||
async fn complete(&self, messages: &[Message]) -> anyhow::Result<String> { | ||
let request = OllamaChatCompletionRequest { | ||
model: self.model.clone(), | ||
messages: messages.iter().map(|m| m.into()).collect(), | ||
stream: false, | ||
}; | ||
let url = format!("{}/api/chat", self.host); | ||
let response = self.client.post(url) | ||
.json(&request) | ||
.send().await?; | ||
let response: OllamaChatCompletionResponse = response.json().await?; | ||
Ok(response.message.content) | ||
} | ||
} | ||
|
||
impl From<Message> for OllamaMessage { | ||
fn from(message: Message) -> Self { | ||
OllamaMessage { role: message.role.to_string(), content: message.content } | ||
} | ||
} | ||
|
||
impl From<&Message> for OllamaMessage { | ||
fn from(message: &Message) -> Self { | ||
OllamaMessage { role: message.role.to_string(), content: message.content.clone() } | ||
} | ||
} | ||
|
||
#[cfg(test)] | ||
mod tests { | ||
use crate::Role; | ||
use super::*; | ||
|
||
#[ignore] | ||
#[tokio::test] | ||
async fn ollama_complete_should_work() { | ||
let adapter = OllamaAdapter::new_local("llama3.2"); | ||
let messages = vec![Message { role: Role::User, content: "Hello".to_string() }]; | ||
let response = adapter.complete(&messages).await.unwrap(); | ||
println!("response: {}", response); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,120 @@ | ||
use reqwest::Client; | ||
use serde::{Deserialize, Serialize}; | ||
use anyhow::anyhow; | ||
use crate::{AiService, Message}; | ||
|
||
pub struct OpenAIAdapter { | ||
host: String, | ||
api_key: String, | ||
model: String, | ||
client: Client, | ||
} | ||
|
||
#[derive(Serialize)] | ||
pub struct OpenAIChatCompletionRequest { | ||
pub model: String, | ||
pub messages: Vec<OpenAIMessage>, | ||
} | ||
|
||
#[derive(Serialize, Deserialize)] | ||
pub struct OpenAIMessage { | ||
pub role: String, | ||
pub content: String, | ||
} | ||
#[derive(Deserialize)] | ||
pub struct OpenAIChatCompletionResponse { | ||
pub id: String, | ||
pub object: String, | ||
pub created: u64, | ||
pub model: String, | ||
pub system_fingerprint: String, | ||
pub choices: Vec<OpenAIChoice>, | ||
pub usage: OpenAIUsage, | ||
} | ||
|
||
#[derive(Deserialize)] | ||
pub struct OpenAIChoice { | ||
pub index: u32, | ||
pub message: OpenAIMessage, | ||
pub logprobs: Option<i64>, | ||
pub finish_reason: String, | ||
} | ||
|
||
#[derive(Deserialize)] | ||
pub struct OpenAIUsage { | ||
pub prompt_tokens: u32, | ||
pub completion_tokens: u32, | ||
pub total_tokens: u32, | ||
pub completion_tokens_details: Option<OpenAICompletionTokensDetails>, | ||
} | ||
|
||
#[derive(Deserialize)] | ||
pub struct OpenAICompletionTokensDetails { | ||
pub reasoning_tokens: u32, | ||
} | ||
|
||
impl OpenAIAdapter { | ||
pub fn new(api_key: impl Into<String>, model: impl Into<String>) -> Self { | ||
let client = Client::new(); | ||
Self { | ||
host: "https://api.openai.com/v1".to_string(), | ||
api_key: api_key.into(), | ||
model: model.into(), | ||
client, | ||
} | ||
} | ||
} | ||
|
||
impl AiService for OpenAIAdapter { | ||
async fn complete(&self, messages: &[Message]) -> anyhow::Result<String> { | ||
let request = OpenAIChatCompletionRequest { | ||
model: self.model.clone(), | ||
messages: messages.iter().map(|m| m.into()).collect(), | ||
}; | ||
|
||
let url = format!("{}/chat/completions", self.host); | ||
let response = self.client.post(url) | ||
.json(&request) | ||
.header("Authorization", format!("Bearer {}", self.api_key)) | ||
.send() | ||
.await?; | ||
let mut data: OpenAIChatCompletionResponse = response.json().await?; | ||
let content = data.choices.pop().ok_or(anyhow!("No response"))?.message.content; | ||
Ok(content) | ||
} | ||
} | ||
|
||
impl From<Message> for OpenAIMessage { | ||
fn from(message: Message) -> Self { | ||
OpenAIMessage { | ||
role: message.role.to_string(), | ||
content: message.content, | ||
} | ||
} | ||
} | ||
|
||
impl From<&Message> for OpenAIMessage { | ||
fn from(message: &Message) -> Self { | ||
OpenAIMessage { | ||
role: message.role.to_string(), | ||
content: message.content.clone(), | ||
} | ||
} | ||
} | ||
|
||
#[cfg(test)] | ||
mod tests { | ||
use crate::Role; | ||
use std::env; | ||
use super::*; | ||
|
||
#[ignore] | ||
#[tokio::test] | ||
async fn openai_complete_should_work() { | ||
let api_key = env::var("OPENAI_API_KEY").unwrap(); | ||
let adapter = OpenAIAdapter::new(api_key, "gpt-4o"); | ||
let messages = vec![Message { role: Role::User, content: "Hello".to_string() }]; | ||
let response = adapter.complete(&messages).await.unwrap(); | ||
assert!(response.len() > 0); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
mod adapters; | ||
|
||
pub use adapters::*; | ||
|
||
use std::fmt; | ||
|
||
#[derive(Debug, Clone)] | ||
pub enum Role { | ||
User, | ||
Assistant, | ||
System, | ||
} | ||
|
||
#[derive(Debug, Clone)] | ||
pub struct Message { | ||
pub role: Role, | ||
pub content: String, | ||
} | ||
|
||
|
||
#[allow(async_fn_in_trait)] | ||
pub trait AiService { | ||
async fn complete(&self, messages: &[Message]) -> anyhow::Result<String>; | ||
// other common functions | ||
} | ||
|
||
impl fmt::Display for Role { | ||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { | ||
match self { | ||
Role::User => write!(f, "user"), | ||
Role::Assistant => write!(f, "assistant"), | ||
Role::System => write!(f, "system"), | ||
} | ||
} | ||
} |