Initial KV RingAttention code #684

Open · wants to merge 109 commits into base: master

Commits (109)

b186a77
test minimal changes
joshpopelka20 Aug 12, 2024
152a41c
add to struct
joshpopelka20 Aug 12, 2024
636a48f
Merge branch 'EricLBuehler:master' into master
joshpopelka20 Aug 14, 2024
3e1f47b
add chunks logic
joshpopelka20 Aug 14, 2024
bba0fdd
Merge branch 'master' of https://github.com/joshpopelka20/mistral.rs
joshpopelka20 Aug 14, 2024
04ee7de
clone chunks
joshpopelka20 Aug 14, 2024
83ab8f2
clone x for chunk
joshpopelka20 Aug 14, 2024
36d5eb2
remove chunk to device
joshpopelka20 Aug 14, 2024
c1a973a
push x
joshpopelka20 Aug 14, 2024
e5aee16
fix x move
joshpopelka20 Aug 14, 2024
7aeb80c
dont clone chunks
joshpopelka20 Aug 14, 2024
00f019f
unwrap chunk
joshpopelka20 Aug 14, 2024
4b836bc
change to reference
joshpopelka20 Aug 14, 2024
3763fee
iter
joshpopelka20 Aug 14, 2024
32d3b0d
pop chunks
joshpopelka20 Aug 14, 2024
40a502b
clone x
joshpopelka20 Aug 14, 2024
c0c87e4
change to vec new
joshpopelka20 Aug 14, 2024
728e838
store tensor reference
joshpopelka20 Aug 14, 2024
daf1028
extract by index
joshpopelka20 Aug 14, 2024
ccc6c50
remove unwrap
joshpopelka20 Aug 14, 2024
2ca9acc
clone
joshpopelka20 Aug 14, 2024
e822953
mutably borrow
joshpopelka20 Aug 14, 2024
086e76f
derefernce
joshpopelka20 Aug 14, 2024
ce3d418
create vec of tensors
joshpopelka20 Aug 14, 2024
ffbe3f9
make new vec
joshpopelka20 Aug 14, 2024
9f91594
type tensor
joshpopelka20 Aug 14, 2024
c88cdf5
push to chunks
joshpopelka20 Aug 14, 2024
0eb85f5
self chunks
joshpopelka20 Aug 14, 2024
b7edfbe
create vec of chunks
joshpopelka20 Aug 14, 2024
8732bbc
clone x
joshpopelka20 Aug 14, 2024
9c961f3
remove reference
joshpopelka20 Aug 14, 2024
1aaca72
clone for move
joshpopelka20 Aug 14, 2024
4ab98c8
remove clone
joshpopelka20 Aug 14, 2024
2e0b2fd
add back clone
joshpopelka20 Aug 14, 2024
7bb3cf7
change to copy
joshpopelka20 Aug 14, 2024
f433517
unwrap copy
joshpopelka20 Aug 14, 2024
cf5b204
remove copy
joshpopelka20 Aug 14, 2024
9e0e6c8
use my candle
joshpopelka20 Aug 14, 2024
03be02a
mvoe back to EricLBuehler
joshpopelka20 Aug 14, 2024
d75ee88
move back to josh
joshpopelka20 Aug 14, 2024
f79ef6f
revert candle
joshpopelka20 Aug 14, 2024
4b9ed28
remove copy mapper
joshpopelka20 Aug 14, 2024
b54a5af
clone chunks
joshpopelka20 Aug 14, 2024
a50883b
copy instead of clone
joshpopelka20 Aug 14, 2024
edf82da
move loggers
joshpopelka20 Aug 14, 2024
3e9cc26
add sequence parallelism
joshpopelka20 Aug 22, 2024
30f6b40
add IndexOp import
joshpopelka20 Aug 22, 2024
7e23976
only use chunk on first block index
joshpopelka20 Aug 26, 2024
f20005a
split input into multiple chunks
joshpopelka20 Aug 26, 2024
9d0b6ce
add missing variable block_chunks
joshpopelka20 Aug 26, 2024
0c6a64c
use each chunk first
joshpopelka20 Aug 26, 2024
86e1e54
clone x in accumulated attention
joshpopelka20 Aug 26, 2024
535e5c7
change mapper with block_chunks
joshpopelka20 Aug 26, 2024
8d55784
give block chunks a type
joshpopelka20 Aug 26, 2024
f738105
make as type tensor
joshpopelka20 Aug 26, 2024
4addbb5
move block chunks
joshpopelka20 Aug 26, 2024
d6ffb10
add to accumulated attention
joshpopelka20 Aug 26, 2024
61b9b8a
unwrap x
joshpopelka20 Aug 26, 2024
c1cc882
&tensor
joshpopelka20 Aug 26, 2024
f87ead1
fix block_chunks
joshpopelka20 Aug 26, 2024
f665810
make generic type
joshpopelka20 Aug 26, 2024
8140413
fix blocks_chunks to device
joshpopelka20 Aug 26, 2024
23af80c
another fix for concat block_chunks
joshpopelka20 Aug 26, 2024
dd689e3
remove ? operator
joshpopelka20 Aug 26, 2024
2106933
replace with try_collect
joshpopelka20 Aug 26, 2024
71fdd71
change type of block_chunks
joshpopelka20 Aug 26, 2024
0b129fa
clone to move blcok_chunks
joshpopelka20 Aug 26, 2024
c09b459
remove add
joshpopelka20 Aug 26, 2024
d201134
switch to four devices
joshpopelka20 Aug 27, 2024
c5b4fde
fix compile error with &
joshpopelka20 Aug 27, 2024
79f7606
uodate metadata device
joshpopelka20 Aug 28, 2024
f50a159
add kv cache rotation
joshpopelka20 Aug 28, 2024
b913fee
add missing num_caches
joshpopelka20 Aug 28, 2024
0a7b422
fix compile error
joshpopelka20 Aug 28, 2024
c98dcb7
clone mapper
joshpopelka20 Aug 28, 2024
962f744
remove clone
joshpopelka20 Sep 3, 2024
7cd3503
clone reference
joshpopelka20 Sep 3, 2024
ea04012
return tensor
joshpopelka20 Sep 3, 2024
a57e1c9
remove borrow
joshpopelka20 Sep 3, 2024
b69edcc
fix value moved
joshpopelka20 Sep 3, 2024
7cfb29d
borrow on accumulate
joshpopelka20 Sep 3, 2024
da65eb2
add logging
joshpopelka20 Sep 3, 2024
9c5cd38
more logging
joshpopelka20 Sep 3, 2024
cdd480d
fix chunk to device chunk
joshpopelka20 Sep 3, 2024
4eb4775
remove concat block_chunks
joshpopelka20 Sep 3, 2024
ee27e98
move cache to chunk device
joshpopelka20 Sep 3, 2024
57ae1d8
fix error in masker
joshpopelka20 Sep 3, 2024
34bf2d1
move all to block device
joshpopelka20 Sep 3, 2024
a20d7a4
change to block device
joshpopelka20 Sep 3, 2024
042c0a1
change block device args
joshpopelka20 Sep 3, 2024
770806a
add device to block
joshpopelka20 Sep 3, 2024
0b3c911
fix llama struct
joshpopelka20 Sep 3, 2024
bcf6f84
revert blocks device
joshpopelka20 Sep 3, 2024
9945a8d
revert to device chunk
joshpopelka20 Sep 3, 2024
8d0bc24
add block device
joshpopelka20 Sep 3, 2024
06525fe
add reference
joshpopelka20 Sep 3, 2024
a03670d
update tensor device
joshpopelka20 Sep 3, 2024
deabd31
borrow device chunk
joshpopelka20 Sep 3, 2024
3170304
more logging
joshpopelka20 Sep 3, 2024
1e3e55d
log logits
joshpopelka20 Sep 3, 2024
4fab476
try to clone out all caches
joshpopelka20 Sep 3, 2024
5863802
add logging in cacher
joshpopelka20 Sep 3, 2024
ddcd848
revert clone out cache
joshpopelka20 Sep 3, 2024
d9ac7ec
skip clone out
joshpopelka20 Sep 3, 2024
81cd584
have cache out do nothing
joshpopelka20 Sep 3, 2024
1456c72
fix syntax
joshpopelka20 Sep 3, 2024
d52cdd8
remove clone in cache
joshpopelka20 Sep 3, 2024
bf80940
remove loggers
joshpopelka20 Sep 3, 2024
a4dcd1e
test speculative
joshpopelka20 Sep 3, 2024
mistralrs-core/src/models/llama.rs (207 changes: 180 additions & 27 deletions)
@@ -1,6 +1,6 @@
#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use candle_core::{quantized::QMatMul, DType, Device, Result, Tensor};
use candle_core::{quantized::QMatMul, DType, Device, Result, Tensor, IndexOp};
use candle_nn::{embedding, Embedding, Module, VarBuilder};
use mistralrs_quant::{QuantMethod, QuantizedConfig};
use serde::Deserialize;
@@ -361,10 +361,15 @@ impl Block {
metadata,
)? + residual)?;
let residual = &x;
let x = (self.mlp.forward(&self.rms_2.forward(&x)?)? + residual)?;
// let x = (self.mlp.forward(&self.rms_2.forward(&x)?)? + residual)?;
let x = self.rms_2.forward(&x)?;
Ok(x)
}

fn get_device(&self) -> Device {
self.mlp.dtype_device().1
}

fn load(
vb: VarBuilder,
cfg: &Config,
@@ -405,7 +410,9 @@ pub struct Llama {
blocks: Vec<Block>,
ln_f: RmsNorm,
lm_head: QMatMul,
pub kv_cache: crate::pipeline::Cache,
// pub kv_cache: crate::pipeline::Cache,
pub kv_caches: Vec<crate::pipeline::Cache>,
cuda_devices: Vec<candle_core::Device>,
pub device: Device,
mapper: Box<dyn DeviceMapper + Send + Sync>,
cfg: ModelConfigMetadata,
@@ -421,30 +428,160 @@ impl Llama {
mut metadata: Option<(Vec<(Tensor, Tensor)>, &mut PagedAttentionInputMetadata)>,
) -> Result<Tensor> {
let mut x = self.wte.forward(input_ids)?;
let mut cache = self.kv_cache.lock();
let mask = CausalMasker.make_causal_mask_as_attn_bias(
input_ids,
metadata
.as_ref()
.map(|(_, _)| &seqlen_offsets as &dyn PastKvLenCache)
.unwrap_or(&*cache as &dyn PastKvLenCache),
x.dtype(),
self.blocks[0].attn.num_attention_heads,
)?;
let (batch_size, seq_len, hidden_size) = x.dims3()?;

let num_devices = 4;
let chunk_size = seq_len / num_devices;

let mut chunks: Vec<Tensor> = Vec::with_capacity(num_devices);
// chunks.push(x.copy().unwrap());

// Handle the case where sequence length is less than number of devices
if seq_len <= num_devices {
for j in 0..seq_len {
// let chunk = x.i((.., j..j+1, ..))?;
let chunk = x.clone();
chunks.push(chunk.to_device(&self.cuda_devices[j])?);
}
} else {
for j in 0..num_devices {
let start = j * chunk_size;
let end = if j == num_devices - 1 {
seq_len
} else {
(j+ 1) * chunk_size
};

let chunk = x.i((.., start..end,..))?;
let device = &self.cuda_devices[j];
chunks.push(chunk.to_device(&device)?);
}
}

// let mut cache = self.kv_caches[0].lock();
let mut processed_chunks = Vec::new();
let mut target_device = &self.cuda_devices[0];

let mut block_chunks: Vec<Tensor> = Vec::new();

for (block_idx, block) in self.blocks.iter().enumerate() {
x = self.mapper.map(x, block_idx)?;
x = block.forward(
&x,
&mask.clone().map(|m| m.to_device(x.device()).unwrap()),
seqlen_offsets,
start_offsets_kernel.clone(),
block_idx,
&mut cache,
metadata
.as_mut()
.map(|(kv_cache, metadata)| (kv_cache[block_idx].clone(), &mut **metadata)),
)?;

let device_chunk = block.get_device();
// x = self.mapper.map(x, block_idx)?;
// x = self.mapper.map(&chunks[0], block_idx)?;
// println!("block_idx {:?}", block_idx);
// println!("chunk device {:?}", chunks[0].device());
for (chunk_idx, chunk) in chunks.iter().enumerate() {
// println!("chunk_idx {:?}", chunk_idx);
let mut x = if block_idx == 0 {
let tensor = chunk.clone();
self.mapper.map(tensor.clone(), block_idx)?;
tensor.to_device(&device_chunk)?
} else {
let tensor = block_chunks[chunk_idx].clone();
self.mapper.map(tensor.clone(), block_idx)?;
tensor.to_device(&device_chunk)?
};

let num_caches = self.kv_caches.len();

for cache_rotation in 0..num_caches {
let cache_idx = (chunk_idx + cache_rotation) % num_caches;
let kv_cache = &self.kv_caches[cache_idx];
// println!("cache_idx {:?}", cache_idx);
let mut cache = kv_cache.lock();


// Determine the original device of the cache
let original_cache_device = cache.iter().find_map(|opt| {
opt.as_ref().map(|(k, _)| k.device().clone())
}).unwrap_or_else(|| device_chunk.clone());

// Move cache to chunk device
let mut cache_on_chunk_device: Vec<_> = cache.iter().map(|opt| {
opt.as_ref().map(|(k, v)| {
(k.to_device(&device_chunk).unwrap(), v.to_device(&device_chunk).unwrap())
})
}).collect();

let mask = CausalMasker.make_causal_mask_as_attn_bias(
input_ids,
metadata
.as_ref()
.map(|(_, _)| &seqlen_offsets as &dyn PastKvLenCache)
.unwrap_or(&*cache as &dyn PastKvLenCache),
// x.dtype(),
chunks[0].dtype(),
self.blocks[0].attn.num_attention_heads,
)?;




// x = block.forward(
// &x,
// &mask.clone().map(|m| m.to_device(x.device()).unwrap()),
// seqlen_offsets,
// start_offsets_kernel.clone(),
// block_idx,
// &mut cache,
// metadata
// .as_mut()
// .map(|(kv_cache, metadata)| (kv_cache[block_idx].clone(), &mut **metadata)),
// )?;

// println!("before block forward");
x = block.forward(
&x,
&mask.clone().map(|m| m.to_device(&device_chunk).unwrap()),
seqlen_offsets,
start_offsets_kernel.clone().to_device(&device_chunk)?,
block_idx,
// &mut cache_on_chunk_device,
&mut cache_on_chunk_device,
metadata
.as_mut()
.map(|(kv_cache, metadata)| {
let (tensor1, tensor2) = kv_cache[block_idx].clone();
(
(tensor1.to_device(&device_chunk).unwrap(), tensor2.to_device(&device_chunk).unwrap()),
&mut **metadata
)
}),
)?;

// println!("after block forward");

// Accumulate attention results
if block_chunks.len() <= chunk_idx {
block_chunks.push(x.clone());
} else {
block_chunks[chunk_idx] = x.clone();
}
}
}

// Concatenate chunks for this block
// let block_chunks: Result<Vec<Tensor>> = block_chunks
// .clone()
// .into_iter()
// .map(|chunk| chunk.to_device(&device_chunk))
// .collect();

// let block_chunks = block_chunks?; // Propagate any errors

// println!("concat block chunks");
let mut x = candle_core::Tensor::cat(&block_chunks, 1)?;

// do feedforward after attention has been run for each chunk
let residual = x.clone();
let mut x = block.mlp.forward(&x)?;
x = (x + &residual)?;
x = x.to_device(&target_device)?;
processed_chunks.push(x.clone());
}
// println!("concat processed chunks");
x = candle_core::Tensor::cat(&processed_chunks, 1)?;
let x = x.to_device(&self.device)?;
let mut x = self.ln_f.forward(&x)?;
if matches!(self.lm_head, QMatMul::QTensor(_)) {
@@ -468,6 +605,9 @@ impl Llama {
quant_cfg.bits
);
}

let num_devices = 4;
let mut cuda_devices = Vec::with_capacity(num_devices);
let mapper = normal_loading_metadata
.mapper
.into_mapper(cfg.num_hidden_layers, &normal_loading_metadata.real_device)?;
@@ -514,6 +654,9 @@ impl Llama {
.expect("Failed to create PagedAttention"),
),
};
if !cuda_devices.iter().any(|d| format!("{:?}", d) == format!("{:?}", device)) {
cuda_devices.push(device.clone());
}
Block::load(
vb.pp(&format!("model.layers.{i}")),
cfg,
@@ -527,12 +670,21 @@ impl Llama {
})
.collect();

let mut kv_caches: Vec<crate::pipeline::Cache> = Vec::with_capacity(num_devices);

for device_id in 0..num_devices {
let cache = crate::pipeline::Cache::new(cfg.num_hidden_layers , false);
kv_caches.push(cache);
};

Ok(Self {
wte,
blocks,
ln_f,
lm_head: QMatMul::Tensor(lm_head.weight().clone()),
kv_cache: crate::pipeline::Cache::new(cfg.num_hidden_layers, false),
// kv_cache: crate::pipeline::Cache::new(cfg.num_hidden_layers, false),
kv_caches,
cuda_devices,
device: normal_loading_metadata.real_device,
mapper,
cfg: ModelConfigMetadata {
@@ -623,7 +775,8 @@ impl NormalModel for Llama {
unimplemented!()
}
fn cache(&self) -> &crate::pipeline::Cache {
&self.kv_cache
&self.kv_caches[0]
// &self.kv_cache
}
fn device(&self) -> &Device {
&self.device
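
For orientation, a minimal standalone sketch of the schedule the new forward pass implements: the sequence is split into one chunk per device, the last chunk absorbing any remainder, and for every block each chunk visits every KV cache in ring order via (chunk_idx + cache_rotation) % num_caches. This is an illustration, not code from the PR; the example sequence length and the printed output are assumptions made for the sketch.

// Standalone sketch of the chunk/ring schedule from the forward pass above
// (assumes four devices, matching the hard-coded num_devices = 4 in the diff).
fn chunk_ranges(seq_len: usize, num_devices: usize) -> Vec<(usize, usize)> {
    if seq_len <= num_devices {
        // Short-sequence branch: one range per token position (the commented-out
        // x.i((.., j..j+1, ..)) line in the diff suggests this per-position slice).
        return (0..seq_len).map(|j| (j, j + 1)).collect();
    }
    let chunk_size = seq_len / num_devices;
    (0..num_devices)
        .map(|j| {
            let start = j * chunk_size;
            // The last chunk absorbs the remainder when seq_len % num_devices != 0.
            let end = if j == num_devices - 1 { seq_len } else { (j + 1) * chunk_size };
            (start, end)
        })
        .collect()
}

fn main() {
    let (seq_len, num_devices) = (10, 4);
    let ranges = chunk_ranges(seq_len, num_devices);
    println!("chunk ranges: {ranges:?}"); // [(0, 2), (2, 4), (4, 6), (6, 10)]

    // Ring rotation: for every block, each chunk visits every KV cache once,
    // starting at its own index, as in (chunk_idx + cache_rotation) % num_caches.
    let num_caches = num_devices;
    for chunk_idx in 0..ranges.len() {
        let visits: Vec<usize> = (0..num_caches)
            .map(|rotation| (chunk_idx + rotation) % num_caches)
            .collect();
        println!("chunk {chunk_idx} visits caches {visits:?}");
    }
}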
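
A second sketch, assuming candle_core as used in the repository, illustrates the tensor side of the same idea: slice the hidden states along the sequence dimension with IndexOp, do some per-chunk work, then re-join the pieces with Tensor::cat on dimension 1. The per-chunk work is a placeholder affine (identity) op rather than the block's attention, and everything stays on Device::Cpu so the sketch runs without GPUs.

use candle_core::{DType, Device, IndexOp, Result, Tensor};

fn main() -> Result<()> {
    // Stand-in for the per-chunk CUDA devices the PR collects in cuda_devices.
    let dev = Device::Cpu;
    let (batch, seq_len, hidden) = (1usize, 10usize, 8usize);
    let x = Tensor::zeros((batch, seq_len, hidden), DType::F32, &dev)?;

    let num_devices = 4;
    let chunk_size = seq_len / num_devices;

    // Slice the hidden states along the sequence dimension, one chunk per device.
    let mut chunks: Vec<Tensor> = Vec::with_capacity(num_devices);
    for j in 0..num_devices {
        let start = j * chunk_size;
        let end = if j == num_devices - 1 { seq_len } else { (j + 1) * chunk_size };
        let chunk = x.i((.., start..end, ..))?;
        chunks.push(chunk.to_device(&dev)?); // in the PR: to_device(&self.cuda_devices[j])
    }

    // Placeholder per-chunk work; the PR runs the attention half of each block here.
    let processed: Vec<Tensor> = chunks
        .iter()
        .map(|c| c.affine(1.0, 0.0)) // identity-like op as a stand-in
        .collect::<Result<Vec<_>>>()?;

    // Re-join along the sequence dimension, as the diff does with Tensor::cat(&block_chunks, 1).
    let y = Tensor::cat(&processed, 1)?;
    assert_eq!(y.dims3()?, (batch, seq_len, hidden));
    Ok(())
}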