
Add Falcon Support #313

Merged · 12 commits · Jun 28, 2023
6 changes: 6 additions & 0 deletions binaries/llm-cli/src/cli_args.rs
@@ -44,6 +44,12 @@ pub enum Args {
#[command(subcommand)]
args: BaseArgs,
},
/// Use a Falcon model
#[clap(id = "falcon")]
Falcon {
#[command(subcommand)]
args: BaseArgs,
},
}

#[derive(Subcommand, Debug)]
1 change: 1 addition & 0 deletions binaries/llm-cli/src/main.rs
@@ -33,6 +33,7 @@ fn main() -> Result<()> {
Args::GptJ { args } => handle_args::<llm::models::GptJ>(args),
Args::GptNeoX { args } => handle_args::<llm::models::GptNeoX>(args),
Args::Mpt { args } => handle_args::<llm::models::Mpt>(args),
Args::Falcon { args } => handle_args::<llm::models::Falcon>(args),
}
}

4 changes: 3 additions & 1 deletion crates/llm/Cargo.toml
@@ -15,6 +15,7 @@ llm-gptj = { path = "../models/gptj", optional = true, version = "0.2.0-dev" }
llm-bloom = { path = "../models/bloom", optional = true, version = "0.2.0-dev" }
llm-gptneox = { path = "../models/gptneox", optional = true, version = "0.2.0-dev" }
llm-mpt = { path = "../models/mpt", optional = true, version = "0.2.0-dev" }
llm-falcon = { path = "../models/falcon", optional = true, version = "0.2.0-dev" }

serde = { workspace = true }

@@ -28,10 +29,11 @@ serde_json = { workspace = true }
clap = { workspace = true }

[features]
default = ["llama", "gpt2", "gptj", "bloom", "gptneox", "mpt"]
default = ["llama", "gpt2", "gptj", "bloom", "gptneox", "mpt", "falcon"]
llama = ["dep:llm-llama"]
gpt2 = ["dep:llm-gpt2"]
gptj = ["dep:llm-gptj"]
bloom = ["dep:llm-bloom"]
gptneox = ["dep:llm-gptneox"]
mpt = ["dep:llm-mpt"]
falcon = ["dep:llm-falcon"]
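
Falcon joins the default feature set above, so downstream users pick it up automatically. For reference, a minimal sketch of a downstream `Cargo.toml` that compiles in only the Falcon backend; the dependency spec here is hypothetical and simply mirrors the `0.2.0-dev` version used in this workspace:

```toml
# Hypothetical downstream manifest: turn off the default model set and
# enable only the new "falcon" feature.
[dependencies]
llm = { version = "0.2.0-dev", default-features = false, features = ["falcon"] }
```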
15 changes: 15 additions & 0 deletions crates/llm/src/lib.rs
@@ -7,6 +7,7 @@
//! - [GPT-NeoX](llm_gptneox)
//! - [LLaMA](llm_llama)
//! - [MPT](llm_mpt)
//! - [Falcon](llm_falcon)
//!
//! At present, the only supported backend is [GGML](https://github.com/ggerganov/ggml), but this is expected to
//! change in the future.
@@ -101,6 +102,8 @@ pub mod models {
pub use llm_llama::{self as llama, Llama};
#[cfg(feature = "mpt")]
pub use llm_mpt::{self as mpt, Mpt};
#[cfg(feature = "falcon")]
pub use llm_falcon::{self as falcon, Falcon};
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize)]
@@ -124,6 +127,9 @@ pub enum ModelArchitecture {
#[cfg(feature = "mpt")]
/// [MPT](llm_mpt)
Mpt,
#[cfg(feature = "falcon")]
/// [Falcon](llm_falcon)
Falcon,
}

impl ModelArchitecture {
@@ -141,6 +147,8 @@ impl ModelArchitecture {
Self::Llama,
#[cfg(feature = "mpt")]
Self::Mpt,
#[cfg(feature = "falcon")]
Self::Falcon,
];
}

@@ -184,6 +192,8 @@ impl FromStr for ModelArchitecture {
"llama" => Ok(Llama),
#[cfg(feature = "mpt")]
"mpt" => Ok(Mpt),
#[cfg(feature = "falcon")]
"falcon" => Ok(Falcon),

_ => Err(UnsupportedModelArchitecture(format!(
"{s} is not a supported model architecture"
@@ -209,6 +219,8 @@ impl Display for ModelArchitecture {
Llama => write!(f, "LLaMA"),
#[cfg(feature = "mpt")]
Mpt => write!(f, "MPT"),
#[cfg(feature = "falcon")]
Falcon => write!(f, "Falcon"),
}
}
}
@@ -263,6 +275,9 @@ pub fn load_dynamic(
}
#[cfg(feature = "mpt")]
Mpt => load_model::<models::Mpt>(path, vocabulary_source, params, load_progress_callback)?,
#[cfg(feature = "falcon")]
Falcon => load_model::<models::Falcon>(path, vocabulary_source, params, load_progress_callback)?,

};

Ok(model)
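
The new variant slots into the existing `FromStr` and `Display` plumbing, so string-based architecture selection keeps working unchanged. A small sketch of what that looks like from a downstream crate, assuming the default features (which now include `falcon`) are enabled:

```rust
use std::str::FromStr;

use llm::ModelArchitecture;

fn main() {
    // The lowercase name "falcon" matches the new `FromStr` arm added above.
    match ModelArchitecture::from_str("falcon") {
        // `Display` renders the canonical name, so this prints "Falcon".
        Ok(arch) => println!("selected architecture: {arch}"),
        // Without the `falcon` feature the arm is compiled out and parsing fails.
        Err(_) => eprintln!("falcon support was not compiled in"),
    }
}
```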
13 changes: 13 additions & 0 deletions crates/models/falcon/Cargo.toml
@@ -0,0 +1,13 @@
[package]
name = "llm-falcon"
version = "0.2.0-dev"
license = { workspace = true }
repository = { workspace = true }
description = "An implementation of tiiuae falcon model for the `llm` ecosystem."
edition = "2021"
readme = "../../../README.md"

[dependencies]
llm-base = { path = "../../llm-base", version = "0.2.0-dev" }

bytemuck = { workspace = true }
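
To round the picture off, a rough sketch of loading a Falcon model through the `load_dynamic` entry point extended above. The argument names (`path`, `vocabulary_source`, `params`, `load_progress_callback`) come straight from the diff, but the concrete parameter types and the model file name are assumptions, not confirmed signatures; check the crate documentation for the exact API:

```rust
use std::path::Path;

use llm::ModelArchitecture;

fn main() {
    // ASSUMPTION: `load_dynamic` takes the architecture first, followed by the
    // arguments named in the diff. The values below (a hypothetical model path,
    // `Default::default()` for the vocabulary source and model parameters, and
    // a no-op progress callback) are illustrative guesses only.
    let model = llm::load_dynamic(
        ModelArchitecture::Falcon,
        Path::new("falcon-7b-q4_0.bin"), // hypothetical model file
        Default::default(),              // assumed vocabulary-source default
        Default::default(),              // assumed ModelParameters default
        |_progress| {},                  // ignore load-progress updates
    );

    match model {
        Ok(_model) => println!("Falcon model loaded"),
        Err(err) => eprintln!("failed to load the model: {err}"),
    }
}
```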