Skip to content

Commit

Permalink
Add fine-tuned ai model support
Browse files Browse the repository at this point in the history
  • Loading branch information
HoKim98 committed Nov 3, 2023
1 parent 9fdedcd commit 978b85f
Show file tree
Hide file tree
Showing 32 changed files with 1,042 additions and 231 deletions.
1 change: 0 additions & 1 deletion .cargo/config.toml
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

[target.aarch64-unknown-linux-musl]
linker = "clang"
rustflags = ["-C", "link-arg=-fuse-ld=mold"]
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,8 @@ jobs:
- name: Include target-dependent packages
run: >
sed -i 's/^\( *\)\(.*\# *include *( *[_0-9a-z-]\+ *)\)$/\1# \2/g' ./Cargo.toml
&& sed -i "s/^\( *\)\# *\(.*\# *include *( *$(uname -m) *)\)$/\1\2/g" ./Cargo.toml
find ./ -type f -name Cargo.toml -exec sed -i 's/^\( *\)\(.*\# *include *( *[_0-9a-z-]\+ *)\)$/\1# \2/g' {} +
&& find ./ -type f -name Cargo.toml -exec sed -i "s/^\( *\)\# *\(.*\# *include *( *$(uname -m) *)\)$/\1\2/g" {} +
- name: Build
run: cargo build --all --workspace --verbose
Expand Down
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ minio = { git = "https://github.com/ulagbulag/minio-rs.git", default-features =
"rustls-tls",
] } # not deployed to crates.io
map-macro = { version = "0.2" }
maplit = { version = "1.0" }
octocrab = { version = "0.31", default-features = false, features = ["rustls"] }
opencv = { version = "0.85", default-features = false }
ordered-float = { version = "4.1", default-features = false, features = [
Expand Down
7 changes: 4 additions & 3 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,11 @@ WORKDIR /src
# Build it!
RUN mkdir /out \
# Exclude non-musl packages
&& sed -i 's/^\( *\)\(.*\# *exclude *( *alpine *)\)$/\1# \2/g' ./Cargo.toml \
&& find ./ -type f -name Cargo.toml -exec sed -i 's/^\( *\)\(.*\# *exclude *( *alpine *)\)$/\1# \2/g' {} + \
&& find ./ -type f -name Cargo.toml -exec sed -i 's/^\( *\)\# *\(.*\# *include *( *alpine *)\)$/\1\2/g' {} + \
# Include target-dependent packages
&& sed -i 's/^\( *\)\(.*\# *include *( *[_0-9a-z-]\+ *)\)$/\1# \2/g' ./Cargo.toml \
&& sed -i "s/^\( *\)\# *\(.*\# *include *( *$(uname -m) *)\)$/\1\2/g" ./Cargo.toml \
&& find ./ -type f -name Cargo.toml -exec sed -i 's/^\( *\)\(.*\# *include *( *[_0-9a-z-]\+ *)\)$/\1# \2/g' {} + \
&& find ./ -type f -name Cargo.toml -exec sed -i "s/^\( *\)\# *\(.*\# *include *( *$(uname -m) *)\)$/\1\2/g" {} + \
# Build
&& cargo build --all --workspace --release \
&& find ./target/release/ -maxdepth 1 -type f -perm +a=x -print0 | xargs -0 -I {} mv {} /out \
Expand Down
10 changes: 2 additions & 8 deletions Dockerfile.full
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,6 @@ EXPOSE 80/tcp
WORKDIR /usr/local/bin
CMD [ "/bin/sh" ]

# Configure environment variables
ENV LD_LIBRARY_PATH="${LD_LIBRARY_PATH}:/usr/local/lib"

# Install python dependencies
ADD ./requirements.txt /requirements.txt
RUN apt-get update && apt-get install -y \
Expand Down Expand Up @@ -62,19 +59,16 @@ WORKDIR /src
ARG FUNCTION_HOME
RUN mkdir -p /out/bin /out/lib \
# Include target-dependent packages
&& sed -i 's/^\( *\)\(.*\# *include *( *[_0-9a-z-]\+ *)\)$/\1# \2/g' ./Cargo.toml \
&& sed -i "s/^\( *\)\# *\(.*\# *include *( *$(uname -m) *)\)$/\1\2/g" ./Cargo.toml \
&& find ./ -type f -name Cargo.toml -exec sed -i 's/^\( *\)\(.*\# *include *( *[_0-9a-z-]\+ *)\)$/\1# \2/g' {} + \
&& find ./ -type f -name Cargo.toml -exec sed -i "s/^\( *\)\# *\(.*\# *include *( *$(uname -m) *)\)$/\1\2/g" {} + \
# Build
&& "${CARGO_HOME}/bin/rustup" default stable \
&& "${CARGO_HOME}/bin/cargo" build --all --workspace --release \
&& find ./target/release/ -maxdepth 1 -type f -perm -a=x -print0 | xargs -0 -I {} mv {} /out/bin \
&& mv /out/bin/*.so* /out/lib/ \
&& mv ./LICENSE /LICENSE \
# Copy pipe functions
&& mkdir -p "${FUNCTION_HOME}" \
&& mv ./dash/pipe/functions/python/examples "${FUNCTION_HOME}/python" \
# # Remove duplicated onnxruntime CAPI binary
# && rm -rf /out/libonnxruntime_providers_*.so \
# Cleanup
&& rm -rf /src

Expand Down
12 changes: 11 additions & 1 deletion ark/core/src/tracer.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,19 @@
use std::ffi::OsStr;

pub fn init_once() {
// set default tracing level
const KEY: &str = "RUST_LOG";
if ::std::env::var_os(KEY).is_none() {
::std::env::set_var(KEY, "INFO");
}

::tracing_subscriber::fmt::try_init().ok();
}

pub fn init_once_with(level: impl AsRef<OsStr>) {
// set custom tracing level
::std::env::set_var(KEY, level);

::tracing_subscriber::fmt::try_init().ok();
}

const KEY: &str = "RUST_LOG";
1 change: 1 addition & 0 deletions dash/api/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ anyhow = { workspace = true }
chrono = { workspace = true }
k8s-openapi = { workspace = true }
kube = { workspace = true, features = ["derive"] }
maplit = { workspace = true }
ordered-float = { workspace = true }
schemars = { workspace = true }
serde = { workspace = true }
Expand Down
19 changes: 6 additions & 13 deletions dash/api/src/storage/object.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
use std::{collections::BTreeMap, net::Ipv4Addr};
use std::net::Ipv4Addr;

use ark_core_k8s::data::Url;
use k8s_openapi::{
api::core::v1::ResourceRequirements, apimachinery::pkg::api::resource::Quantity,
};
use maplit::btreemap;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};

Expand Down Expand Up @@ -127,18 +128,10 @@ impl ModelStorageObjectOwnedReplicationSpec {

fn default_resources() -> ResourceRequirements {
ResourceRequirements {
requests: Some({
let mut map = BTreeMap::default();
map.insert("cpu".into(), Quantity(Self::default_resources_cpu().into()));
map.insert(
"memory".into(),
Quantity(Self::default_resources_memory().into()),
);
map.insert(
"storage".into(),
Quantity(Self::default_resources_storage().into()),
);
map
requests: Some(btreemap! {
"cpu".into() => Quantity(Self::default_resources_cpu().into()),
"memory".into() => Quantity(Self::default_resources_memory().into()),
"storage".into() => Quantity(Self::default_resources_storage().into()),
}),
..Default::default()
}
Expand Down
5 changes: 3 additions & 2 deletions dash/controller/src/validator/pipe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use dash_api::{
};
use dash_provider::storage::KubernetesStorageClient;
use kube::Client;
use straw_api::pipe::StrawPipe;
use straw_api::{pipe::StrawPipe, plugin::PluginContext};
use straw_provider::StrawSession;

use super::model::ModelValidator;
Expand Down Expand Up @@ -56,7 +56,8 @@ impl<'namespace, 'kube> PipeValidator<'namespace, 'kube> {
}

async fn validate_exec_straw(&self, exec: StrawPipe) -> Result<StrawPipe> {
let ctx = PluginContext::default();
let session = StrawSession::new(self.kube.clone(), Some(self.namespace.into()));
session.create(&exec).await.map(|()| exec)
session.create(&ctx, &exec).await.map(|()| exec)
}
}
5 changes: 4 additions & 1 deletion dash/pipe/functions/ai/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ version = { workspace = true }
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[features]
straw = ["kube", "straw-api"]
straw = ["k8s-openapi", "kube", "maplit", "straw-api"]

[dependencies]
ark-core-k8s = { path = "../../../../ark/core/k8s", features = ["data"] }
Expand All @@ -29,7 +29,10 @@ straw-api = { path = "../../../../straw/api", optional = true, features = [
anyhow = { workspace = true }
async-trait = { workspace = true }
clap = { workspace = true }
inflector = { workspace = true }
k8s-openapi = { workspace = true, optional = true }
kube = { workspace = true, optional = true, features = ["client"] }
maplit = { workspace = true, optional = true }
pyo3 = { workspace = true, features = ["auto-initialize"] }
serde = { workspace = true }
serde_json = { workspace = true }
Expand Down
10 changes: 5 additions & 5 deletions dash/pipe/functions/ai/src/plugin/huggingface.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
import io
import json
from typing import Any, Callable

from PIL import Image
from transformers import AutoTokenizer, Pipeline, PretrainedConfig, pipeline
from optimum.onnxruntime import ORTModel

Expand Down Expand Up @@ -56,8 +53,11 @@ def load(model_id: str, kind: str) -> Callable:
def preprocess(input: Any, kind: str) -> dict[str, Any]:
match kind:
# NLP
case 'QuestionAnswering' | \
'Translation':
case 'QuestionAnswering' \
| 'Summarization' \
| 'TextGeneration' \
| 'Translation' \
| 'ZeroShotClassification':
return input.value


Expand Down
41 changes: 40 additions & 1 deletion dash/pipe/functions/ai/src/plugin/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,4 +42,43 @@ pub struct ModelLoader<'a> {
}

#[cfg(feature = "straw")]
impl<'a> ::straw_api::plugin::PluginDaemon for ModelLoader<'a> {}
impl<'a> ::straw_api::plugin::PluginDaemon for ModelLoader<'a> {
fn container_default_env(
&self,
node: &::straw_api::pipe::StrawNode,
) -> Vec<::k8s_openapi::api::core::v1::EnvVar> {
use inflector::Inflector;
use k8s_openapi::api::core::v1::EnvVar;

vec![
EnvVar {
name: "PIPE_AI_MODEL".into(),
value: Some(node.src.to_string()),
value_from: None,
},
EnvVar {
name: "PIPE_AI_MODEL_KIND".into(),
value: Some(node.name.to_pascal_case()),
value_from: None,
},
]
}

fn container_command(&self) -> Option<Vec<String>> {
Some(vec!["dash-pipe-function-ai".into()])
}

fn container_resources(&self) -> Option<::k8s_openapi::api::core::v1::ResourceRequirements> {
use k8s_openapi::apimachinery::pkg::api::resource::Quantity;

Some(::k8s_openapi::api::core::v1::ResourceRequirements {
claims: None,
requests: None,
limits: Some(::maplit::btreemap! {
// "cpu".into() => Quantity("1".into()),
// "memory".into() => Quantity("500Mi".into()),
"nvidia.com/gpu".into() => Quantity("1".into()),
}),
})
}
}
2 changes: 1 addition & 1 deletion dash/pipe/functions/identity/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,6 @@ impl ::dash_pipe_provider::Function for Function {
}

fn pack_payload(mut message: PipeMessage<Value>) -> Result<PipeMessage<Value>> {
message.payloads = vec![PipePayload::new("test".into(), message.to_bytes()?)];
message.payloads = vec![PipePayload::new("test".into(), (&message).try_into()?)];
Ok(message)
}
2 changes: 1 addition & 1 deletion dash/pipe/provider/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,6 @@ pub use self::function::{
};
#[cfg(feature = "pyo3")]
pub use self::message::PyPipeMessage;
pub use self::message::{PipeMessage, PipeMessages, PipePayload};
pub use self::message::{Codec, PipeMessage, PipeMessages, PipePayload};
pub use self::messengers::MessengerType;
pub use self::pipe::{DefaultModelIn, PipeArgs};
54 changes: 38 additions & 16 deletions dash/pipe/provider/src/message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use futures::{stream::FuturesOrdered, TryStreamExt};
use schemars::JsonSchema;
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use serde_json::Value as DynValue;
use strum::{Display, EnumString};

use crate::storage::{StorageSet, StorageType};

Expand Down Expand Up @@ -338,12 +339,7 @@ where
type Error = Error;

fn try_from(value: &PipeMessage<Value, Payload>) -> Result<Self> {
// opcode
let mut buf = vec![OpCode::MessagePack as u8];

::rmp_serde::encode::write(&mut buf, value)
.map(|()| buf.into())
.map_err(Into::into)
value.to_bytes(Codec::default())
}
}

Expand Down Expand Up @@ -419,20 +415,24 @@ where
self.timestamp
}

pub fn to_bytes(&self) -> Result<Bytes>
pub fn to_bytes(&self, encoder: Codec) -> Result<Bytes>
where
Payload: Serialize,
Value: Serialize,
{
self.try_into()
}

pub fn to_json(&self) -> Result<DynValue>
where
Payload: Serialize,
Value: Serialize,
{
self.try_into()
match encoder {
Codec::Json => ::serde_json::to_vec(self)
.map(Into::into)
.map_err(Into::into),
Codec::MessagePack => {
// opcode
let mut buf = vec![OpCode::MessagePack as u8];

::rmp_serde::encode::write(&mut buf, self)
.map(|()| buf.into())
.map_err(Into::into)
}
}
}
}

Expand Down Expand Up @@ -592,6 +592,28 @@ impl PipePayload {
}
}

#[derive(
Copy,
Clone,
Debug,
Display,
EnumString,
Default,
PartialEq,
Eq,
PartialOrd,
Ord,
Hash,
Serialize,
Deserialize,
JsonSchema,
)]
pub enum Codec {
Json,
#[default]
MessagePack,
}

#[derive(Copy, Clone, Debug, PartialEq, Eq)]
enum OpCode {
// Special opcodes
Expand Down
Loading

0 comments on commit 978b85f

Please sign in to comment.