
Commit

Allow setting PagedAttention KV cache allocation from context size (#640)

* Support paged attn memory allocation via context size

* Slightly better logging

* Connect it to the apis

* Clippy
EricLBuehler authored Jul 28, 2024
1 parent 5e8fe09 commit 38fb942
Showing 12 changed files with 229 additions and 109 deletions.
1 change: 0 additions & 1 deletion Cargo.lock


1 change: 0 additions & 1 deletion mistralrs-bench/Cargo.toml
@@ -20,7 +20,6 @@ serde_json.workspace = true
clap.workspace = true
mistralrs-core = { version = "0.2.2", path = "../mistralrs-core" }
tracing.workspace = true
either.workspace = true
tokio.workspace = true
cli-table = "0.4.7"

59 changes: 44 additions & 15 deletions mistralrs-bench/src/main.rs
@@ -1,12 +1,11 @@
use candle_core::Device;
use clap::Parser;
use cli_table::{format::Justify, print_stdout, Cell, CellStruct, Style, Table};
use either::Either;
use mistralrs_core::{
initialize_logging, paged_attn_supported, Constraint, DefaultSchedulerMethod,
DeviceLayerMapMetadata, DeviceMapMetadata, Loader, LoaderBuilder, MistralRs, MistralRsBuilder,
ModelDType, ModelSelected, NormalRequest, PagedAttentionConfig, Request, RequestMessage,
Response, SamplingParams, SchedulerConfig, TokenSource, Usage,
DeviceLayerMapMetadata, DeviceMapMetadata, Loader, LoaderBuilder, MemoryGpuConfig, MistralRs,
MistralRsBuilder, ModelDType, ModelSelected, NormalRequest, PagedAttentionConfig, Request,
RequestMessage, Response, SamplingParams, SchedulerConfig, TokenSource, Usage,
};
use std::fmt::Display;
use std::sync::Arc;
@@ -292,6 +291,12 @@ struct Args {
#[arg(long = "pa-gpu-mem-usage")]
paged_attn_gpu_mem_usage: Option<f32>,

/// Total context length to allocate the KV cache for (total number of tokens which the KV cache can hold)
/// when using PagedAttention, which is only supported on CUDA and is always automatically activated.
/// The priority is as follows: `pa-gpu-mem-usage` (default = 0.9) > `pa-ctxt-len` > `pa-gpu-mem`.
#[arg(long = "pa-ctxt-len")]
paged_ctxt_len: Option<usize>,

/// Block size (number of tokens per block) for PagedAttention. If this is not set and the device is CUDA, it will default to 32.
/// PagedAttention is only supported on CUDA and is always automatically activated.
#[arg(long = "pa-blk-size")]
@@ -383,31 +388,55 @@ fn main() -> anyhow::Result<()> {
args.paged_attn_block_size,
args.paged_attn_gpu_mem,
args.paged_attn_gpu_mem_usage,
args.paged_ctxt_len,
paged_attn_supported(),
args.no_paged_attn,
) {
(block_size, None, None, true, false) => Some(PagedAttentionConfig::new(
(block_size, None, None, None, true, false) => Some(PagedAttentionConfig::new(
block_size,
512,
Either::Right(0.9), // NOTE(EricLBuehler): default is to use 90% of memory
MemoryGpuConfig::Utilization(0.9), // NOTE(EricLBuehler): default is to use 90% of memory
)?),
(block_size, Some(m), None, true, false) => {
Some(PagedAttentionConfig::new(block_size, 512, Either::Left(m))?)
}
(block_size, None, Some(f), true, false) => Some(PagedAttentionConfig::new(
(block_size, None, None, Some(ctxt), true, false) => Some(PagedAttentionConfig::new(
block_size,
512,
MemoryGpuConfig::ContextSize(ctxt),
)?),
(block_size, None, Some(f), None, true, false) => Some(PagedAttentionConfig::new(
block_size,
512,
Either::Right(f),
MemoryGpuConfig::Utilization(f),
)?),
(block_size, Some(_m), Some(f), true, false) => {
info!("Both memory size and usage were specified, defaulting to the usage value.");
(block_size, Some(m), None, None, true, false) => Some(PagedAttentionConfig::new(
block_size,
512,
MemoryGpuConfig::Amount(m),
)?),
(block_size, Some(_m), Some(f), None, true, false) => {
info!("Both memory size, and usage were specified, defaulting to the usage value.");
Some(PagedAttentionConfig::new(
block_size,
512,
MemoryGpuConfig::Utilization(f),
)?)
}
(block_size, Some(_m), None, Some(ctxt), true, false) => {
info!("All memory size and ctxt len, defaulting to the context len value.");
Some(PagedAttentionConfig::new(
block_size,
512,
MemoryGpuConfig::ContextSize(ctxt),
)?)
}
(block_size, None, Some(f), Some(_ctxt), true, false) => {
info!("Both ctxt len and usage were specified, defaulting to the usage value.");
Some(PagedAttentionConfig::new(
block_size,
512,
Either::Right(f),
MemoryGpuConfig::Utilization(f),
)?)
}
(_, _, _, _, _) => None,
(_, _, _, _, _, _) => None,
};

let pipeline = loader.load_model_from_hf(
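The match in `main.rs` above resolves the three PagedAttention flags with the priority `pa-gpu-mem-usage` > `pa-ctxt-len` > `pa-gpu-mem`, defaulting to 90% utilization when none are set. A minimal sketch of how downstream code might build the same configuration from the `mistralrs-core` types this commit exports (the helper function and its argument names are illustrative, not part of the crate's API):

```rust
use mistralrs_core::{MemoryGpuConfig, PagedAttentionConfig};

// Hypothetical helper mirroring the CLI priority: utilization > context length > MB amount.
fn pick_paged_attn_config(
    mem_mb: Option<usize>,
    gpu_mem_usage: Option<f32>,
    ctxt_len: Option<usize>,
) -> anyhow::Result<PagedAttentionConfig> {
    let mem_gpu = match (gpu_mem_usage, ctxt_len, mem_mb) {
        (Some(f), _, _) => MemoryGpuConfig::Utilization(f),
        (None, Some(toks), _) => MemoryGpuConfig::ContextSize(toks),
        (None, None, Some(mb)) => MemoryGpuConfig::Amount(mb),
        // Default mirrors the first match arm above: use 90% of GPU memory.
        (None, None, None) => MemoryGpuConfig::Utilization(0.9),
    };
    // `None` block size falls back to the CUDA default of 32; 512 MB of CPU swap space as above.
    PagedAttentionConfig::new(None, 512, mem_gpu)
}
```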
38 changes: 29 additions & 9 deletions mistralrs-core/src/dummy_paged_attention/mod.rs
@@ -16,7 +16,6 @@ pub use block_engine_sequence::BlockEngineSequence;
pub use cache_engine::{CacheConfig, CacheEngine};
use candle_core::{DType, Device};
pub use config::{ModelConfigLike, ModelConfigMetadata};
use either::Either;
pub use layers::PagedAttention;
pub use scheduler::{
PagedAttentionScheduler, PagedAttentionSchedulerConfig, PagedAttentionSchedulerOutput,
@@ -30,14 +29,14 @@ use tracing::info;
pub struct PagedAttentionConfig {
pub(crate) block_size: Option<usize>,
pub(crate) mem_cpu: usize,
pub(crate) mem_gpu: Either<usize, f32>,
pub(crate) mem_gpu: MemoryGpuConfig,
}

impl PagedAttentionConfig {
pub fn new(
_block_size: Option<usize>,
_mem_cpu: usize,
_mem_gpu: Either<usize, f32>,
_mem_gpu: MemoryGpuConfig,
) -> anyhow::Result<Self> {
anyhow::bail!("PagedAttention is only supported for CUDA, compile with feature `cuda`.")
}
@@ -48,6 +47,14 @@ pub enum AttentionImplementation {
PagedAttention,
}

#[derive(Clone, Copy)]
#[cfg_attr(feature = "pyo3_macros", pyo3::pyclass)]
pub enum MemoryGpuConfig {
Amount(usize),
Utilization(f32),
ContextSize(usize),
}

// See `pagedattention.cu` CALL_V1_LAUNCHER_BLOCK_SIZE
const SUPPORTED_BLOCK_SIZE: &[usize] = &[8, 16, 32];

@@ -65,9 +72,20 @@ macro_rules! mb_to_blocks {
};
}

macro_rules! ctxt_to_blocks {
($context_len:expr, $dtype_size:expr, $block_size:expr, $config:expr) => {
$context_len
* $dtype_size
* $config.num_kv_heads()
* ($config.hidden_size() / $config.num_attn_heads())
* $config.num_layers()
* 2
};
}

/// Memory values are in MBs, a GPU memory utilization in [0,1], or a context length in tokens. Specify block size or the default is 32.
pub fn calculate_cache_config(
mem_gpu: Either<usize, f32>,
mem_gpu: MemoryGpuConfig,
mem_cpu: usize,
block_size: Option<usize>,
dtype: DType,
@@ -82,16 +100,18 @@ pub fn calculate_cache_config(

#[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
let mem_gpu = match mem_gpu {
Either::Left(v) => v,
Either::Right(f) => {
MemoryGpuConfig::Amount(v) => v,
MemoryGpuConfig::Utilization(f) => {
let free = MemoryUsage.get_memory_available(device)? as f32 / SIZE_IN_MB as f32;
let total = MemoryUsage.get_total_memory(device)? as f32 / SIZE_IN_MB as f32;
let used = total - free;
let size = (total * f - used) as usize;
info!("Allocating {size} MB for PagedAttention KV cache");
size
(total * f - used) as usize
}
MemoryGpuConfig::ContextSize(toks) => {
ctxt_to_blocks!(toks, dtype_size, block_size, config) / SIZE_IN_MB
}
};
info!("Allocating {mem_gpu} MB for PagedAttention KV cache");

let num_gpu_blocks = mb_to_blocks!(mem_gpu * SIZE_IN_MB, dtype_size, block_size, config);
let num_cpu_blocks = mb_to_blocks!(mem_cpu * SIZE_IN_MB, dtype_size, block_size, config);
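The new `ctxt_to_blocks!` macro converts a token budget straight into KV-cache bytes: tokens × dtype size × KV heads × head dimension × layers × 2 (key plus value), which `calculate_cache_config` then divides by `SIZE_IN_MB` before the usual block conversion. A standalone sketch of that arithmetic, using hypothetical Llama-8B-style shapes (8 KV heads, head dimension 128, 32 layers, f16):

```rust
// Back-of-the-envelope version of `ctxt_to_blocks!`; all model dimensions below are
// illustrative examples, not read from a real config.
fn kv_cache_bytes(
    ctxt_len: usize,
    dtype_size: usize,
    num_kv_heads: usize,
    head_dim: usize,
    num_layers: usize,
) -> usize {
    // 2x accounts for storing both the key cache and the value cache.
    ctxt_len * dtype_size * num_kv_heads * head_dim * num_layers * 2
}

fn main() {
    // 4096 tokens, f16 (2 bytes), 8 KV heads, 128 head dim, 32 layers -> 512 MB.
    let bytes = kv_cache_bytes(4096, 2, 8, 128, 32);
    println!("{} MB", bytes / (1024 * 1024));
}
```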
2 changes: 1 addition & 1 deletion mistralrs-core/src/lib.rs
@@ -59,7 +59,7 @@ mod xlora_models;

pub use amoe::{AnyMoeConfig, AnyMoeExpertType};
pub use device_map::{DeviceLayerMapMetadata, DeviceMapMetadata, LayerDeviceMapper};
pub use paged_attention::PagedAttentionConfig;
pub use paged_attention::{MemoryGpuConfig, PagedAttentionConfig};
pub use pipeline::{
chat_template::ChatTemplate, AnyMoeLoader, AnyMoePipeline, GGMLLoader, GGMLLoaderBuilder,
GGMLSpecificConfig, GGUFArchitecture, GGUFLoader, GGUFLoaderBuilder, GemmaLoader,
38 changes: 29 additions & 9 deletions mistralrs-core/src/paged_attention/mod.rs
@@ -16,7 +16,6 @@ pub use block_engine_sequence::BlockEngineSequence;
pub use cache_engine::{CacheConfig, CacheEngine};
use candle_core::{DType, Device};
pub use config::{ModelConfigLike, ModelConfigMetadata};
use either::Either;
pub use layers::PagedAttention;
pub use scheduler::{
PagedAttentionScheduler, PagedAttentionSchedulerConfig, PagedAttentionSchedulerOutput,
@@ -30,14 +29,14 @@ use tracing::info;
pub struct PagedAttentionConfig {
pub(crate) block_size: Option<usize>,
pub(crate) mem_cpu: usize,
pub(crate) mem_gpu: Either<usize, f32>,
pub(crate) mem_gpu: MemoryGpuConfig,
}

impl PagedAttentionConfig {
pub fn new(
block_size: Option<usize>,
mem_cpu: usize,
mem_gpu: Either<usize, f32>,
mem_gpu: MemoryGpuConfig,
) -> anyhow::Result<Self> {
Ok(Self {
block_size,
@@ -52,6 +51,14 @@ pub enum AttentionImplementation {
PagedAttention,
}

#[derive(Clone, Copy)]
#[cfg_attr(feature = "pyo3_macros", pyo3::pyclass)]
pub enum MemoryGpuConfig {
Amount(usize),
Utilization(f32),
ContextSize(usize),
}

// See `pagedattention.cu` CALL_V1_LAUNCHER_BLOCK_SIZE
const SUPPORTED_BLOCK_SIZE: &[usize] = &[8, 16, 32];

@@ -69,9 +76,20 @@ macro_rules! mb_to_blocks {
};
}

macro_rules! ctxt_to_blocks {
($context_len:expr, $dtype_size:expr, $block_size:expr, $config:expr) => {
$context_len
* $dtype_size
* $config.num_kv_heads()
* ($config.hidden_size() / $config.num_attn_heads())
* $config.num_layers()
* 2
};
}

/// Memory values are in MBs, a GPU memory utilization in [0,1], or a context length in tokens. Specify block size or the default is 32.
pub fn calculate_cache_config(
mem_gpu: Either<usize, f32>,
mem_gpu: MemoryGpuConfig,
mem_cpu: usize,
block_size: Option<usize>,
dtype: DType,
@@ -86,16 +104,18 @@ pub fn calculate_cache_config(

#[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
let mem_gpu = match mem_gpu {
Either::Left(v) => v,
Either::Right(f) => {
MemoryGpuConfig::Amount(v) => v,
MemoryGpuConfig::Utilization(f) => {
let free = MemoryUsage.get_memory_available(device)? as f32 / SIZE_IN_MB as f32;
let total = MemoryUsage.get_total_memory(device)? as f32 / SIZE_IN_MB as f32;
let used = total - free;
let size = (total * f - used) as usize;
info!("Allocating {size} MB for PagedAttention KV cache");
size
(total * f - used) as usize
}
MemoryGpuConfig::ContextSize(toks) => {
ctxt_to_blocks!(toks, dtype_size, block_size, config) / SIZE_IN_MB
}
};
info!("Allocating {mem_gpu} MB for PagedAttention KV cache");

let num_gpu_blocks = mb_to_blocks!(mem_gpu * SIZE_IN_MB, dtype_size, block_size, config);
let num_cpu_blocks = mb_to_blocks!(mem_cpu * SIZE_IN_MB, dtype_size, block_size, config);
42 changes: 15 additions & 27 deletions mistralrs-paged-attn/build.rs
@@ -5,32 +5,29 @@ const CUDA_NVCC_FLAGS: Option<&'static str> = option_env!("CUDA_NVCC_FLAGS");

#[cfg(all(feature = "cuda", target_family = "unix"))]
fn main() -> Result<()> {
use std::fs;
use std::fs::read_to_string;
use std::fs::OpenOptions;
use std::io::prelude::*;
use std::path::PathBuf;

const OTHER_CONTENT: &str = r#"
#[cfg(all(feature = "cuda", target_family = "unix"))]
mod ffi;
pub const COPY_BLOCKS_KERNEL: &str =
include_str!(concat!(env!("OUT_DIR"), "/copy_blocks_kernel.ptx"));
#[cfg(all(feature = "cuda", target_family = "unix"))]
pub const PAGEDATTENTION: &str = include_str!(concat!(env!("OUT_DIR"), "/pagedattention.ptx"));
#[cfg(all(feature = "cuda", target_family = "unix"))]
pub const RESHAPE_AND_CACHE_KERNEL: &str =
include_str!(concat!(env!("OUT_DIR"), "/reshape_and_cache_kernel.ptx"));
#[cfg(all(feature = "cuda", target_family = "unix"))]
mod backend;
#[cfg(all(feature = "cuda", target_family = "unix"))]
mod ffi;
#[cfg(all(feature = "cuda", target_family = "unix"))]
pub use backend::{{copy_blocks, paged_attention, reshape_and_cache, swap_blocks}};
pub use backend::{copy_blocks, paged_attention, reshape_and_cache, swap_blocks};
"#;

fn read_lines(filename: &str) -> Vec<String> {
let mut result = Vec::new();

for line in read_to_string(filename).unwrap().lines() {
result.push(line.to_string())
}

result
}

println!("cargo:rerun-if-changed=build.rs");
println!("cargo:rerun-if-changed=src/pagedattention.cu");
println!("cargo:rerun-if-changed=src/copy_blocks_kernel.cu");
@@ -57,20 +54,11 @@ pub use backend::{{copy_blocks, paged_attention, reshape_and_cache, swap_blocks}
println!("cargo:rustc-link-lib=pagedattention");
println!("cargo:rustc-link-lib=dylib=cudart");

let contents = read_lines("src/lib.rs");
for line in contents {
if line == "pub mod ffi;" {
return Ok(());
}
}
let ct = fs::read_to_string("src/lib.rs")?;
if !ct.contains(OTHER_CONTENT) {
let mut file = OpenOptions::new().append(true).open("src/lib.rs").unwrap();
let mut file = OpenOptions::new().write(true).open("src/lib.rs").unwrap();

// Add the other stuff back
if let Err(e) = writeln!(file, "{OTHER_CONTENT}") {
anyhow::bail!("Error while building dependencies: {:?}\n", e)
}
// Add the other stuff back
if let Err(e) = writeln!(file, "{OTHER_CONTENT}") {
anyhow::bail!("Error while building dependencies: {:?}\n", e)
}
Ok(())
}
4 changes: 2 additions & 2 deletions mistralrs-paged-attn/src/backend/paged_attention.rs
@@ -260,7 +260,7 @@ impl candle::CustomOp1 for PagedAttention {
///
/// * `q` - Query tensor with shape `(num_sequences, num_heads_q, head_size)`.
/// * `key_cache` - Key cache paged tensor of shape `(num_blocks, num_heads_kv, head_size / x, block_size, x)`
/// with `x` being the size of an element in bytes.
/// with `x` being the size of an element in bytes.
/// * `value_cache` - Value cache paged tensor of shape `(num_blocks, num_heads_kv, head_size, block_size)`.
/// * `block_tables` - Padded table associating blocks to each sequence of shape `(num_sequences, max_context_len // block_size)`
/// * `context_lens` - Tensor associating lengths to each sequence of shape `(num_sequences)`
@@ -439,7 +439,7 @@ fn update_cache<
/// * `key` - Key tensor of shape `(num_tokens, num_heads, head_size)`.
/// * `value` - Value tensor of shape `(num_tokens, num_heads, head_size)`.
/// * `key_cache` - Key cache paged tensor of shape `(num_blocks, num_heads, head_size / x, block_size, x)`
/// with `x` being the size of an element in bytes.
/// with `x` being the size of an element in bytes.
/// * `value_cache` - Value cache paged tensor of shape `(num_blocks, num_heads, head_size, block_size)`.
/// * `slot_mapping` - Mapping associating a slot to each token of shape `(num_tokens)`.
pub fn reshape_and_cache(
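The doc comments above are the shape contract for the paged cache tensors. A small illustration of those layouts (every dimension here is a hypothetical example, with `x` treated as the element size in bytes as the comment states):

```rust
// Illustrative only: returns the paged key/value cache shapes documented above.
fn paged_cache_shapes(
    num_blocks: usize,
    num_heads_kv: usize,
    head_size: usize,
    block_size: usize,
    x: usize, // element size in bytes, per the doc comment
) -> ([usize; 5], [usize; 4]) {
    // key_cache:   (num_blocks, num_heads_kv, head_size / x, block_size, x)
    // value_cache: (num_blocks, num_heads_kv, head_size, block_size)
    (
        [num_blocks, num_heads_kv, head_size / x, block_size, x],
        [num_blocks, num_heads_kv, head_size, block_size],
    )
}
```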
12 changes: 10 additions & 2 deletions mistralrs-pyo3/mistralrs.pyi
@@ -203,8 +203,16 @@ class Runner:
the corresponding number of layers.
- `in_situ_quant` sets the optional in-situ quantization for models that are not quantized (not GGUF or GGML).
- `anymoe_config` specifies the AnyMoE config. If this is set, then the model will be loaded as an AnyMoE model.
- `pa_gpu_mem` sets GPU memory to allocate for KV cache with PagedAttention in MBs *OR* the percentage utilization, from 0 to 1. If this is not set and the device is
CUDA, it will default to using 90% of the total memory after allocation of the KV cache. PagedAttention is only supported on CUDA and is always automatically activated.
- `pa_gpu_mem`: GPU memory to allocate for KV cache with PagedAttention in MBs.
PagedAttention is only supported on CUDA and is always automatically activated.
The priority is as follows: `pa-gpu-mem-usage` (default = 0.9) > `pa-ctxt-len` > `pa-gpu-mem`.
- `pa_gpu_mem_usage`: Percentage of GPU memory to utilize after allocation of KV cache with PagedAttention, from 0 to 1.
If this is not set and the device is CUDA, it will default to `0.9`.
PagedAttention is only supported on CUDA and is always automatically activated.
The priority is as follows: `pa-gpu-mem-usage` (default = 0.9) > `pa-ctxt-len` > `pa-gpu-mem`.
- `pa_ctxt_len`: Total context length to allocate the KV cache for (total number of tokens which the KV cache can hold)
when using PagedAttention, which is only supported on CUDA and is always automatically activated.
The priority is as follows: `pa-gpu-mem-usage` (default = 0.9) > `pa-ctxt-len` > `pa-gpu-mem`.
- `pa_blk_size` sets the block size (number of tokens per block) for PagedAttention. If this is not set and the device is CUDA,
it will default to 32. PagedAttention is only supported on CUDA and is always automatically activated.
- `no_paged_attn` disables PagedAttention on CUDA
(Diffs for the remaining changed files are not shown.)
