Initial KV RingAttention code #684

Open

joshpopelka20 wants to merge 109 commits into base: master.

Changes from 45 commits

Commits (109)
b186a77
test minimal changes
joshpopelka20 Aug 12, 2024
152a41c
add to struct
joshpopelka20 Aug 12, 2024
636a48f
Merge branch 'EricLBuehler:master' into master
joshpopelka20 Aug 14, 2024
3e1f47b
add chunks logic
joshpopelka20 Aug 14, 2024
bba0fdd
Merge branch 'master' of https://github.com/joshpopelka20/mistral.rs
joshpopelka20 Aug 14, 2024
04ee7de
clone chunks
joshpopelka20 Aug 14, 2024
83ab8f2
clone x for chunk
joshpopelka20 Aug 14, 2024
36d5eb2
remove chunk to device
joshpopelka20 Aug 14, 2024
c1a973a
push x
joshpopelka20 Aug 14, 2024
e5aee16
fix x move
joshpopelka20 Aug 14, 2024
7aeb80c
dont clone chunks
joshpopelka20 Aug 14, 2024
00f019f
unwrap chunk
joshpopelka20 Aug 14, 2024
4b836bc
change to reference
joshpopelka20 Aug 14, 2024
3763fee
iter
joshpopelka20 Aug 14, 2024
32d3b0d
pop chunks
joshpopelka20 Aug 14, 2024
40a502b
clone x
joshpopelka20 Aug 14, 2024
c0c87e4
change to vec new
joshpopelka20 Aug 14, 2024
728e838
store tensor reference
joshpopelka20 Aug 14, 2024
daf1028
extract by index
joshpopelka20 Aug 14, 2024
ccc6c50
remove unwrap
joshpopelka20 Aug 14, 2024
2ca9acc
clone
joshpopelka20 Aug 14, 2024
e822953
mutably borrow
joshpopelka20 Aug 14, 2024
086e76f
derefernce
joshpopelka20 Aug 14, 2024
ce3d418
create vec of tensors
joshpopelka20 Aug 14, 2024
ffbe3f9
make new vec
joshpopelka20 Aug 14, 2024
9f91594
type tensor
joshpopelka20 Aug 14, 2024
c88cdf5
push to chunks
joshpopelka20 Aug 14, 2024
0eb85f5
self chunks
joshpopelka20 Aug 14, 2024
b7edfbe
create vec of chunks
joshpopelka20 Aug 14, 2024
8732bbc
clone x
joshpopelka20 Aug 14, 2024
9c961f3
remove reference
joshpopelka20 Aug 14, 2024
1aaca72
clone for move
joshpopelka20 Aug 14, 2024
4ab98c8
remove clone
joshpopelka20 Aug 14, 2024
2e0b2fd
add back clone
joshpopelka20 Aug 14, 2024
7bb3cf7
change to copy
joshpopelka20 Aug 14, 2024
f433517
unwrap copy
joshpopelka20 Aug 14, 2024
cf5b204
remove copy
joshpopelka20 Aug 14, 2024
9e0e6c8
use my candle
joshpopelka20 Aug 14, 2024
03be02a
mvoe back to EricLBuehler
joshpopelka20 Aug 14, 2024
d75ee88
move back to josh
joshpopelka20 Aug 14, 2024
f79ef6f
revert candle
joshpopelka20 Aug 14, 2024
4b9ed28
remove copy mapper
joshpopelka20 Aug 14, 2024
b54a5af
clone chunks
joshpopelka20 Aug 14, 2024
a50883b
copy instead of clone
joshpopelka20 Aug 14, 2024
edf82da
move loggers
joshpopelka20 Aug 14, 2024
3e9cc26
add sequence parallelism
joshpopelka20 Aug 22, 2024
30f6b40
add IndexOp import
joshpopelka20 Aug 22, 2024
7e23976
only use chunk on first block index
joshpopelka20 Aug 26, 2024
f20005a
split input into multiple chunks
joshpopelka20 Aug 26, 2024
9d0b6ce
add missing variable block_chunks
joshpopelka20 Aug 26, 2024
0c6a64c
use each chunk first
joshpopelka20 Aug 26, 2024
86e1e54
clone x in accumulated attention
joshpopelka20 Aug 26, 2024
535e5c7
change mapper with block_chunks
joshpopelka20 Aug 26, 2024
8d55784
give block chunks a type
joshpopelka20 Aug 26, 2024
f738105
make as type tensor
joshpopelka20 Aug 26, 2024
4addbb5
move block chunks
joshpopelka20 Aug 26, 2024
d6ffb10
add to accumulated attention
joshpopelka20 Aug 26, 2024
61b9b8a
unwrap x
joshpopelka20 Aug 26, 2024
c1cc882
&tensor
joshpopelka20 Aug 26, 2024
f87ead1
fix block_chunks
joshpopelka20 Aug 26, 2024
f665810
make generic type
joshpopelka20 Aug 26, 2024
8140413
fix blocks_chunks to device
joshpopelka20 Aug 26, 2024
23af80c
another fix for concat block_chunks
joshpopelka20 Aug 26, 2024
dd689e3
remove ? operator
joshpopelka20 Aug 26, 2024
2106933
replace with try_collect
joshpopelka20 Aug 26, 2024
71fdd71
change type of block_chunks
joshpopelka20 Aug 26, 2024
0b129fa
clone to move blcok_chunks
joshpopelka20 Aug 26, 2024
c09b459
remove add
joshpopelka20 Aug 26, 2024
d201134
switch to four devices
joshpopelka20 Aug 27, 2024
c5b4fde
fix compile error with &
joshpopelka20 Aug 27, 2024
79f7606
uodate metadata device
joshpopelka20 Aug 28, 2024
f50a159
add kv cache rotation
joshpopelka20 Aug 28, 2024
b913fee
add missing num_caches
joshpopelka20 Aug 28, 2024
0a7b422
fix compile error
joshpopelka20 Aug 28, 2024
c98dcb7
clone mapper
joshpopelka20 Aug 28, 2024
962f744
remove clone
joshpopelka20 Sep 3, 2024
7cd3503
clone reference
joshpopelka20 Sep 3, 2024
ea04012
return tensor
joshpopelka20 Sep 3, 2024
a57e1c9
remove borrow
joshpopelka20 Sep 3, 2024
b69edcc
fix value moved
joshpopelka20 Sep 3, 2024
7cfb29d
borrow on accumulate
joshpopelka20 Sep 3, 2024
da65eb2
add logging
joshpopelka20 Sep 3, 2024
9c5cd38
more logging
joshpopelka20 Sep 3, 2024
cdd480d
fix chunk to device chunk
joshpopelka20 Sep 3, 2024
4eb4775
remove concat block_chunks
joshpopelka20 Sep 3, 2024
ee27e98
move cache to chunk device
joshpopelka20 Sep 3, 2024
57ae1d8
fix error in masker
joshpopelka20 Sep 3, 2024
34bf2d1
move all to block device
joshpopelka20 Sep 3, 2024
a20d7a4
change to block device
joshpopelka20 Sep 3, 2024
042c0a1
change block device args
joshpopelka20 Sep 3, 2024
770806a
add device to block
joshpopelka20 Sep 3, 2024
0b3c911
fix llama struct
joshpopelka20 Sep 3, 2024
bcf6f84
revert blocks device
joshpopelka20 Sep 3, 2024
9945a8d
revert to device chunk
joshpopelka20 Sep 3, 2024
8d0bc24
add block device
joshpopelka20 Sep 3, 2024
06525fe
add reference
joshpopelka20 Sep 3, 2024
a03670d
update tensor device
joshpopelka20 Sep 3, 2024
deabd31
borrow device chunk
joshpopelka20 Sep 3, 2024
3170304
more logging
joshpopelka20 Sep 3, 2024
1e3e55d
log logits
joshpopelka20 Sep 3, 2024
4fab476
try to clone out all caches
joshpopelka20 Sep 3, 2024
5863802
add logging in cacher
joshpopelka20 Sep 3, 2024
ddcd848
revert clone out cache
joshpopelka20 Sep 3, 2024
d9ac7ec
skip clone out
joshpopelka20 Sep 3, 2024
81cd584
have cache out do nothing
joshpopelka20 Sep 3, 2024
1456c72
fix syntax
joshpopelka20 Sep 3, 2024
d52cdd8
remove clone in cache
joshpopelka20 Sep 3, 2024
bf80940
remove loggers
joshpopelka20 Sep 3, 2024
a4dcd1e
test speculative
joshpopelka20 Sep 3, 2024
47 changes: 41 additions & 6 deletions mistralrs-core/src/models/llama.rs
@@ -405,7 +405,9 @@ pub struct Llama {
blocks: Vec<Block>,
ln_f: RmsNorm,
lm_head: QMatMul,
pub kv_cache: crate::pipeline::Cache,
// pub kv_cache: crate::pipeline::Cache,
pub kv_caches: Vec<crate::pipeline::Cache>,
cuda_devices: Vec<candle_core::Device>,
pub device: Device,
mapper: Box<dyn DeviceMapper + Send + Sync>,
cfg: ModelConfigMetadata,
@@ -421,18 +423,35 @@ impl Llama {
mut metadata: Option<(Vec<(Tensor, Tensor)>, &mut PagedAttentionInputMetadata)>,
) -> Result<Tensor> {
let mut x = self.wte.forward(input_ids)?;
let mut cache = self.kv_cache.lock();
let (batch_size, seq_len, hidden_size) = x.dims3()?;

let num_devices = 1;
let chunk_size = seq_len / num_devices;

// let mut chunks = Vec::new();
// let mut chunks = Vec::<Tensor>;
// let chunk = x.clone();
// chunks.push(chunk.to_device(&self.cuda_devices[0])?);
let mut chunks: Vec<Tensor> = Vec::with_capacity(num_devices);
println!("x device {:?}", x.device());
chunks.push(x.copy().unwrap());
println!("chunk device {:?}", chunks[0].device());

let mut cache = self.kv_caches[0].lock();
let mask = CausalMasker.make_causal_mask_as_attn_bias(
input_ids,
metadata
.as_ref()
.map(|(_, _)| &seqlen_offsets as &dyn PastKvLenCache)
.unwrap_or(&*cache as &dyn PastKvLenCache),
x.dtype(),
// x.dtype(),
chunks[0].dtype(),
self.blocks[0].attn.num_attention_heads,
)?;
for (block_idx, block) in self.blocks.iter().enumerate() {
x = self.mapper.map(x, block_idx)?;
// x = self.mapper.map(x, block_idx)?;
// x = self.mapper.map(&chunks[0], block_idx)?;
x = self.mapper.map(chunks[0].copy().unwrap(), block_idx)?;

EricLBuehler (Owner) commented:
I think this is the bit which is causing the issue. It looks like we aren't using the values from the last block as the inputs are always from the embeddings.
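
For illustration, a minimal sketch of what this comment points at, using hypothetical placeholder types rather than candle tensors: the block loop needs to thread each block's output into the next block, instead of re-mapping the embedding chunk (`chunks[0]`) on every iteration.

```rust
// Sketch only: `Act` stands in for a tensor of activations, not candle_core::Tensor.
type Act = Vec<f32>;

// Thread the running activation through the blocks: only the first iteration
// sees the embeddings; every later block consumes its predecessor's output.
fn forward_blocks<F>(blocks: &[F], embeddings: Act) -> Act
where
    F: Fn(Act) -> Act,
{
    let mut x = embeddings;
    for block in blocks {
        x = block(x);
    }
    x
}
```

In the PR this likely means assigning the mapped and forwarded result back into the per-device chunk, rather than copying the original embedding chunk on each block.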

joshpopelka20 (Contributor, PR author) replied:
So that's going to be a problem. From the algorithm in the paper, I need to use the local block on each host. So I need to iterate through the layers and then through the local blocks.

[image: algorithm figure from the paper]

Any suggestions on how to redesign it?
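
As a rough orientation for that redesign, a structural sketch of the loop nesting the paper implies: an outer loop over layers, and an inner loop over ring steps in which each host attends its local query chunk against whichever KV chunk it currently holds, then passes that chunk along the ring. Types here are hypothetical placeholders, the rotation stands in for device-to-device transfer, and the causal mask and online-softmax accumulation are omitted.

```rust
// Sketch only: these stand in for per-device shards, not candle tensors.
struct QueryChunk(Vec<f32>);
struct KvChunk(Vec<f32>);

fn ring_attention_forward(num_layers: usize, q_chunks: &[QueryChunk], mut kv_chunks: Vec<KvChunk>) {
    let num_hosts = q_chunks.len();
    assert_eq!(num_hosts, kv_chunks.len());

    for _layer in 0..num_layers {
        // One full ring pass per layer: after `num_hosts` steps every query
        // chunk has attended over every KV chunk exactly once.
        for _step in 0..num_hosts {
            for (host, q) in q_chunks.iter().enumerate() {
                let kv = &kv_chunks[host];
                // attend(q, kv) and fold the partial result into this host's
                // running softmax accumulators; elided in this sketch.
                let _ = (q, kv);
            }
            // Ring step: each host hands its KV chunk to its neighbour
            // (a device-to-device copy in the real implementation, ideally
            // overlapped with the attention compute).
            kv_chunks.rotate_right(1);
        }
        // The accumulated per-host outputs then go through the MLP and become
        // the next layer's inputs, so the per-device caches advance layer by layer.
    }
}
```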

x = block.forward(
&x,
&mask.clone().map(|m| m.to_device(x.device()).unwrap()),
@@ -468,6 +487,9 @@ impl Llama {
quant_cfg.bits
);
}

let num_devices = 1;
let mut cuda_devices = Vec::with_capacity(num_devices);
let mapper = normal_loading_metadata
.mapper
.into_mapper(cfg.num_hidden_layers, &normal_loading_metadata.real_device)?;
@@ -514,6 +536,9 @@ impl Llama {
.expect("Failed to create PagedAttention"),
),
};
if !cuda_devices.iter().any(|d| format!("{:?}", d) == format!("{:?}", device)) {
cuda_devices.push(device.clone());
}
Block::load(
vb.pp(&format!("model.layers.{i}")),
cfg,
@@ -527,12 +552,21 @@ impl Llama {
})
.collect();

let mut kv_caches: Vec<crate::pipeline::Cache> = Vec::with_capacity(num_devices);

for device_id in 0..num_devices {
let cache = crate::pipeline::Cache::new(cfg.num_hidden_layers , false);
kv_caches.push(cache);
};

Ok(Self {
wte,
blocks,
ln_f,
lm_head: QMatMul::Tensor(lm_head.weight().clone()),
kv_cache: crate::pipeline::Cache::new(cfg.num_hidden_layers, false),
// kv_cache: crate::pipeline::Cache::new(cfg.num_hidden_layers, false),
kv_caches,
cuda_devices,
device: normal_loading_metadata.real_device,
mapper,
cfg: ModelConfigMetadata {
@@ -623,7 +657,8 @@ impl NormalModel for Llama {
unimplemented!()
}
fn cache(&self) -> &crate::pipeline::Cache {
&self.kv_cache
&self.kv_caches[0]
// &self.kv_cache
}
fn device(&self) -> &Device {
&self.device