From b186a774bdd784ae71e15201f3ac203485bca432 Mon Sep 17 00:00:00 2001 From: joshpopelka20 <107133849+joshpopelka20@users.noreply.github.com> Date: Mon, 12 Aug 2024 14:03:31 -0400 Subject: [PATCH 001/107] test minimal changes --- mistralrs-core/src/models/llama.rs | 22 +++++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/mistralrs-core/src/models/llama.rs b/mistralrs-core/src/models/llama.rs index c3075952b..86227baaa 100644 --- a/mistralrs-core/src/models/llama.rs +++ b/mistralrs-core/src/models/llama.rs @@ -421,7 +421,7 @@ impl Llama { mut metadata: Option<(Vec<(Tensor, Tensor)>, &mut PagedAttentionInputMetadata)>, ) -> Result { let mut x = self.wte.forward(input_ids)?; - let mut cache = self.kv_cache.lock(); + let mut cache = self.kv_caches[0].lock(); let mask = CausalMasker.make_causal_mask_as_attn_bias( input_ids, metadata @@ -468,6 +468,9 @@ impl Llama { quant_cfg.bits ); } + + let num_devices = 1; + let mut cuda_devices = Vec::with_capacity(num_devices); let mapper = normal_loading_metadata .mapper .into_mapper(cfg.num_hidden_layers, &normal_loading_metadata.real_device)?; @@ -514,6 +517,9 @@ impl Llama { .expect("Failed to create PagedAttention"), ), }; + if !cuda_devices.iter().any(|d| format!("{:?}", d) == format!("{:?}", device)) { + cuda_devices.push(device.clone()); + } Block::load( vb.pp(&format!("model.layers.{i}")), cfg, @@ -527,12 +533,21 @@ impl Llama { }) .collect(); + let mut kv_caches: Vec = Vec::with_capacity(num_devices); + + for device_id in 0..num_devices { + let cache = crate::pipeline::Cache::new(cfg.num_hidden_layers , false); + kv_caches.push(cache); + }; + Ok(Self { wte, blocks, ln_f, lm_head: QMatMul::Tensor(lm_head.weight().clone()), - kv_cache: crate::pipeline::Cache::new(cfg.num_hidden_layers, false), + // kv_cache: crate::pipeline::Cache::new(cfg.num_hidden_layers, false), + kv_caches, + cuda_devices, device: normal_loading_metadata.real_device, mapper, cfg: ModelConfigMetadata { @@ -623,7 +638,8 @@ impl NormalModel for Llama { unimplemented!() } fn cache(&self) -> &crate::pipeline::Cache { - &self.kv_cache + &self.kv_caches[0] + // &self.kv_cache } fn device(&self) -> &Device { &self.device From 152a41ce992427aed1d7104779d0222df4e99071 Mon Sep 17 00:00:00 2001 From: joshpopelka20 <107133849+joshpopelka20@users.noreply.github.com> Date: Mon, 12 Aug 2024 14:15:05 -0400 Subject: [PATCH 002/107] add to struct --- mistralrs-core/src/models/llama.rs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/mistralrs-core/src/models/llama.rs b/mistralrs-core/src/models/llama.rs index 86227baaa..e3a6c14a9 100644 --- a/mistralrs-core/src/models/llama.rs +++ b/mistralrs-core/src/models/llama.rs @@ -405,7 +405,9 @@ pub struct Llama { blocks: Vec, ln_f: RmsNorm, lm_head: QMatMul, - pub kv_cache: crate::pipeline::Cache, + // pub kv_cache: crate::pipeline::Cache, + pub kv_caches: Vec, + cuda_devices: Vec, pub device: Device, mapper: Box, cfg: ModelConfigMetadata, From 3e1f47bed8966238f5383f0989486e63183cba8a Mon Sep 17 00:00:00 2001 From: joshpopelka20 <107133849+joshpopelka20@users.noreply.github.com> Date: Wed, 14 Aug 2024 11:06:22 -0400 Subject: [PATCH 003/107] add chunks logic --- mistralrs-core/src/models/llama.rs | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/mistralrs-core/src/models/llama.rs b/mistralrs-core/src/models/llama.rs index e3a6c14a9..4e32670fc 100644 --- a/mistralrs-core/src/models/llama.rs +++ b/mistralrs-core/src/models/llama.rs @@ -423,6 +423,15 @@ impl Llama { mut 
metadata: Option<(Vec<(Tensor, Tensor)>, &mut PagedAttentionInputMetadata)>, ) -> Result { let mut x = self.wte.forward(input_ids)?; + let (batch_size, seq_len, hidden_size) = x.dims3()?; + + let num_devices = 1; + let chunk_size = seq_len / num_devices; + + let mut chunks = Vec::with_capacity(num_devices); + let chunk = x; + chunks.push(chunk.to_device(&self.cuda_devices[0])?); + let mut cache = self.kv_caches[0].lock(); let mask = CausalMasker.make_causal_mask_as_attn_bias( input_ids, @@ -434,7 +443,8 @@ impl Llama { self.blocks[0].attn.num_attention_heads, )?; for (block_idx, block) in self.blocks.iter().enumerate() { - x = self.mapper.map(x, block_idx)?; + // x = self.mapper.map(x, block_idx)?; + x = self.mapper.map(chunks[0], block_idx)?; x = block.forward( &x, &mask.clone().map(|m| m.to_device(x.device()).unwrap()), From 04ee7dee41612dda6c4ce0ab90528f4cc291bca8 Mon Sep 17 00:00:00 2001 From: joshpopelka20 <107133849+joshpopelka20@users.noreply.github.com> Date: Wed, 14 Aug 2024 11:08:15 -0400 Subject: [PATCH 004/107] clone chunks --- mistralrs-core/src/models/llama.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mistralrs-core/src/models/llama.rs b/mistralrs-core/src/models/llama.rs index 3f4eb4171..aa3ef4e4f 100644 --- a/mistralrs-core/src/models/llama.rs +++ b/mistralrs-core/src/models/llama.rs @@ -444,7 +444,7 @@ impl Llama { )?; for (block_idx, block) in self.blocks.iter().enumerate() { // x = self.mapper.map(x, block_idx)?; - x = self.mapper.map(chunks[0], block_idx)?; + x = self.mapper.map(chunks[0].clone(), block_idx)?; x = block.forward( &x, &mask.clone().map(|m| m.to_device(x.device()).unwrap()), From 83ab8f2a034df22fa99a9a4c881d818a13c05ab6 Mon Sep 17 00:00:00 2001 From: joshpopelka20 <107133849+joshpopelka20@users.noreply.github.com> Date: Wed, 14 Aug 2024 11:16:04 -0400 Subject: [PATCH 005/107] clone x for chunk --- mistralrs-core/src/models/llama.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mistralrs-core/src/models/llama.rs b/mistralrs-core/src/models/llama.rs index aa3ef4e4f..a6f1745ba 100644 --- a/mistralrs-core/src/models/llama.rs +++ b/mistralrs-core/src/models/llama.rs @@ -429,7 +429,7 @@ impl Llama { let chunk_size = seq_len / num_devices; let mut chunks = Vec::with_capacity(num_devices); - let chunk = x; + let chunk = x.clone(); chunks.push(chunk.to_device(&self.cuda_devices[0])?); let mut cache = self.kv_caches[0].lock(); From 36d5eb2c45515102ff071617da1d9c340c4378d4 Mon Sep 17 00:00:00 2001 From: joshpopelka20 <107133849+joshpopelka20@users.noreply.github.com> Date: Wed, 14 Aug 2024 11:24:38 -0400 Subject: [PATCH 006/107] remove chunk to device --- mistralrs-core/src/models/llama.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mistralrs-core/src/models/llama.rs b/mistralrs-core/src/models/llama.rs index a6f1745ba..4327361d7 100644 --- a/mistralrs-core/src/models/llama.rs +++ b/mistralrs-core/src/models/llama.rs @@ -430,7 +430,8 @@ impl Llama { let mut chunks = Vec::with_capacity(num_devices); let chunk = x.clone(); - chunks.push(chunk.to_device(&self.cuda_devices[0])?); + // chunks.push(chunk.to_device(&self.cuda_devices[0])?); + chunks.push(chunk); let mut cache = self.kv_caches[0].lock(); let mask = CausalMasker.make_causal_mask_as_attn_bias( From c1a973a8504b8a9bff3ec8bd4ff72632c7c8b271 Mon Sep 17 00:00:00 2001 From: joshpopelka20 <107133849+joshpopelka20@users.noreply.github.com> Date: Wed, 14 Aug 2024 11:32:30 -0400 Subject: [PATCH 007/107] push x --- mistralrs-core/src/models/llama.rs | 4 
++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mistralrs-core/src/models/llama.rs b/mistralrs-core/src/models/llama.rs index 4327361d7..f81432a4f 100644 --- a/mistralrs-core/src/models/llama.rs +++ b/mistralrs-core/src/models/llama.rs @@ -429,9 +429,9 @@ impl Llama { let chunk_size = seq_len / num_devices; let mut chunks = Vec::with_capacity(num_devices); - let chunk = x.clone(); + // let chunk = x.clone(); // chunks.push(chunk.to_device(&self.cuda_devices[0])?); - chunks.push(chunk); + chunks.push(x); let mut cache = self.kv_caches[0].lock(); let mask = CausalMasker.make_causal_mask_as_attn_bias( From e5aee1649911051e05ad015d0e5032b0b696b096 Mon Sep 17 00:00:00 2001 From: joshpopelka20 <107133849+joshpopelka20@users.noreply.github.com> Date: Wed, 14 Aug 2024 11:33:46 -0400 Subject: [PATCH 008/107] fix x move --- mistralrs-core/src/models/llama.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mistralrs-core/src/models/llama.rs b/mistralrs-core/src/models/llama.rs index f81432a4f..ff26cc230 100644 --- a/mistralrs-core/src/models/llama.rs +++ b/mistralrs-core/src/models/llama.rs @@ -440,7 +440,8 @@ impl Llama { .as_ref() .map(|(_, _)| &seqlen_offsets as &dyn PastKvLenCache) .unwrap_or(&*cache as &dyn PastKvLenCache), - x.dtype(), + // x.dtype(), + chunks[0].dtype(), self.blocks[0].attn.num_attention_heads, )?; for (block_idx, block) in self.blocks.iter().enumerate() { From 7aeb80cd777524ac1a7d574f952fda4f045dec01 Mon Sep 17 00:00:00 2001 From: joshpopelka20 <107133849+joshpopelka20@users.noreply.github.com> Date: Wed, 14 Aug 2024 11:35:14 -0400 Subject: [PATCH 009/107] dont clone chunks --- mistralrs-core/src/models/llama.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mistralrs-core/src/models/llama.rs b/mistralrs-core/src/models/llama.rs index ff26cc230..512db6aab 100644 --- a/mistralrs-core/src/models/llama.rs +++ b/mistralrs-core/src/models/llama.rs @@ -446,7 +446,7 @@ impl Llama { )?; for (block_idx, block) in self.blocks.iter().enumerate() { // x = self.mapper.map(x, block_idx)?; - x = self.mapper.map(chunks[0].clone(), block_idx)?; + x = self.mapper.map(chunks[0], block_idx)?; x = block.forward( &x, &mask.clone().map(|m| m.to_device(x.device()).unwrap()), From 00f019f870a1ff9fc214225a8ff6b2fe8ea84cc4 Mon Sep 17 00:00:00 2001 From: joshpopelka20 <107133849+joshpopelka20@users.noreply.github.com> Date: Wed, 14 Aug 2024 11:37:35 -0400 Subject: [PATCH 010/107] unwrap chunk --- mistralrs-core/src/models/llama.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mistralrs-core/src/models/llama.rs b/mistralrs-core/src/models/llama.rs index 512db6aab..daeefc8c1 100644 --- a/mistralrs-core/src/models/llama.rs +++ b/mistralrs-core/src/models/llama.rs @@ -446,7 +446,7 @@ impl Llama { )?; for (block_idx, block) in self.blocks.iter().enumerate() { // x = self.mapper.map(x, block_idx)?; - x = self.mapper.map(chunks[0], block_idx)?; + x = self.mapper.map(chunks[0].unwrap(), block_idx)?; x = block.forward( &x, &mask.clone().map(|m| m.to_device(x.device()).unwrap()), From 4b836bc7a09f977600144d35fdaf272cd1281cfe Mon Sep 17 00:00:00 2001 From: joshpopelka20 <107133849+joshpopelka20@users.noreply.github.com> Date: Wed, 14 Aug 2024 11:38:57 -0400 Subject: [PATCH 011/107] change to reference --- mistralrs-core/src/models/llama.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mistralrs-core/src/models/llama.rs b/mistralrs-core/src/models/llama.rs index daeefc8c1..8e7adca2d 100644 --- 
a/mistralrs-core/src/models/llama.rs +++ b/mistralrs-core/src/models/llama.rs @@ -446,7 +446,7 @@ impl Llama { )?; for (block_idx, block) in self.blocks.iter().enumerate() { // x = self.mapper.map(x, block_idx)?; - x = self.mapper.map(chunks[0].unwrap(), block_idx)?; + x = self.mapper.map(&chunks[0], block_idx)?; x = block.forward( &x, &mask.clone().map(|m| m.to_device(x.device()).unwrap()), From 3763fee6fcc976754e03a1227281df9d26a0f54f Mon Sep 17 00:00:00 2001 From: joshpopelka20 <107133849+joshpopelka20@users.noreply.github.com> Date: Wed, 14 Aug 2024 11:40:06 -0400 Subject: [PATCH 012/107] iter --- mistralrs-core/src/models/llama.rs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/mistralrs-core/src/models/llama.rs b/mistralrs-core/src/models/llama.rs index 8e7adca2d..7c76bef5d 100644 --- a/mistralrs-core/src/models/llama.rs +++ b/mistralrs-core/src/models/llama.rs @@ -446,7 +446,9 @@ impl Llama { )?; for (block_idx, block) in self.blocks.iter().enumerate() { // x = self.mapper.map(x, block_idx)?; - x = self.mapper.map(&chunks[0], block_idx)?; + // x = self.mapper.map(&chunks[0], block_idx)?; + let mut chunks_iter = chunks.into_iter(); + x = self.mapper.map(chunks_iter.next().unwrap(), block_idx)?; x = block.forward( &x, &mask.clone().map(|m| m.to_device(x.device()).unwrap()), From 32d3b0dabc6c484618217671b05a8be2bd5ef828 Mon Sep 17 00:00:00 2001 From: joshpopelka20 <107133849+joshpopelka20@users.noreply.github.com> Date: Wed, 14 Aug 2024 11:41:00 -0400 Subject: [PATCH 013/107] pop chunks --- mistralrs-core/src/models/llama.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/mistralrs-core/src/models/llama.rs b/mistralrs-core/src/models/llama.rs index 7c76bef5d..90b66e982 100644 --- a/mistralrs-core/src/models/llama.rs +++ b/mistralrs-core/src/models/llama.rs @@ -447,8 +447,7 @@ impl Llama { for (block_idx, block) in self.blocks.iter().enumerate() { // x = self.mapper.map(x, block_idx)?; // x = self.mapper.map(&chunks[0], block_idx)?; - let mut chunks_iter = chunks.into_iter(); - x = self.mapper.map(chunks_iter.next().unwrap(), block_idx)?; + x = self.mapper.map(chunks.pop().unwrap(), block_idx)?; x = block.forward( &x, &mask.clone().map(|m| m.to_device(x.device()).unwrap()), From 40a502b537c69cfb01e2b4d6c5adb23601bab731 Mon Sep 17 00:00:00 2001 From: joshpopelka20 <107133849+joshpopelka20@users.noreply.github.com> Date: Wed, 14 Aug 2024 11:42:05 -0400 Subject: [PATCH 014/107] clone x --- mistralrs-core/src/models/llama.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mistralrs-core/src/models/llama.rs b/mistralrs-core/src/models/llama.rs index 90b66e982..f9af43fb2 100644 --- a/mistralrs-core/src/models/llama.rs +++ b/mistralrs-core/src/models/llama.rs @@ -431,7 +431,7 @@ impl Llama { let mut chunks = Vec::with_capacity(num_devices); // let chunk = x.clone(); // chunks.push(chunk.to_device(&self.cuda_devices[0])?); - chunks.push(x); + chunks.push(x.clone()); let mut cache = self.kv_caches[0].lock(); let mask = CausalMasker.make_causal_mask_as_attn_bias( From c0c87e48f0631bdac23ad4e1bfe58235ad6aabf3 Mon Sep 17 00:00:00 2001 From: joshpopelka20 <107133849+joshpopelka20@users.noreply.github.com> Date: Wed, 14 Aug 2024 11:51:04 -0400 Subject: [PATCH 015/107] change to vec new --- mistralrs-core/src/models/llama.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mistralrs-core/src/models/llama.rs b/mistralrs-core/src/models/llama.rs index f9af43fb2..41d5e2c1b 100644 --- a/mistralrs-core/src/models/llama.rs +++ 
b/mistralrs-core/src/models/llama.rs @@ -428,10 +428,10 @@ impl Llama { let num_devices = 1; let chunk_size = seq_len / num_devices; - let mut chunks = Vec::with_capacity(num_devices); + let mut chunks = Vec::new(); // let chunk = x.clone(); // chunks.push(chunk.to_device(&self.cuda_devices[0])?); - chunks.push(x.clone()); + chunks.push(x); let mut cache = self.kv_caches[0].lock(); let mask = CausalMasker.make_causal_mask_as_attn_bias( From 728e838daffd69b5c964e7c48935d13ed24339d8 Mon Sep 17 00:00:00 2001 From: joshpopelka20 <107133849+joshpopelka20@users.noreply.github.com> Date: Wed, 14 Aug 2024 13:06:27 -0400 Subject: [PATCH 016/107] store tensor reference --- mistralrs-core/src/models/llama.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mistralrs-core/src/models/llama.rs b/mistralrs-core/src/models/llama.rs index 41d5e2c1b..e9723693f 100644 --- a/mistralrs-core/src/models/llama.rs +++ b/mistralrs-core/src/models/llama.rs @@ -431,7 +431,7 @@ impl Llama { let mut chunks = Vec::new(); // let chunk = x.clone(); // chunks.push(chunk.to_device(&self.cuda_devices[0])?); - chunks.push(x); + chunks.push(&x); let mut cache = self.kv_caches[0].lock(); let mask = CausalMasker.make_causal_mask_as_attn_bias( From daf102857d01bb856031290e8599f00426686c76 Mon Sep 17 00:00:00 2001 From: joshpopelka20 <107133849+joshpopelka20@users.noreply.github.com> Date: Wed, 14 Aug 2024 13:07:20 -0400 Subject: [PATCH 017/107] extract by index --- mistralrs-core/src/models/llama.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mistralrs-core/src/models/llama.rs b/mistralrs-core/src/models/llama.rs index e9723693f..eb748d7c1 100644 --- a/mistralrs-core/src/models/llama.rs +++ b/mistralrs-core/src/models/llama.rs @@ -447,7 +447,7 @@ impl Llama { for (block_idx, block) in self.blocks.iter().enumerate() { // x = self.mapper.map(x, block_idx)?; // x = self.mapper.map(&chunks[0], block_idx)?; - x = self.mapper.map(chunks.pop().unwrap(), block_idx)?; + x = self.mapper.map(chunks[0].unwrap(), block_idx)?; x = block.forward( &x, &mask.clone().map(|m| m.to_device(x.device()).unwrap()), From ccc6c501be53775b2d6d8fafb4a2a70fa3cd52e5 Mon Sep 17 00:00:00 2001 From: joshpopelka20 <107133849+joshpopelka20@users.noreply.github.com> Date: Wed, 14 Aug 2024 13:07:59 -0400 Subject: [PATCH 018/107] remove unwrap --- mistralrs-core/src/models/llama.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mistralrs-core/src/models/llama.rs b/mistralrs-core/src/models/llama.rs index eb748d7c1..f321bdbea 100644 --- a/mistralrs-core/src/models/llama.rs +++ b/mistralrs-core/src/models/llama.rs @@ -447,7 +447,7 @@ impl Llama { for (block_idx, block) in self.blocks.iter().enumerate() { // x = self.mapper.map(x, block_idx)?; // x = self.mapper.map(&chunks[0], block_idx)?; - x = self.mapper.map(chunks[0].unwrap(), block_idx)?; + x = self.mapper.map(chunks[0], block_idx)?; x = block.forward( &x, &mask.clone().map(|m| m.to_device(x.device()).unwrap()), From 2ca9acc78b67e7c4a1a02b4e56cbeb1b45a772a4 Mon Sep 17 00:00:00 2001 From: joshpopelka20 <107133849+joshpopelka20@users.noreply.github.com> Date: Wed, 14 Aug 2024 13:08:35 -0400 Subject: [PATCH 019/107] clone --- mistralrs-core/src/models/llama.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mistralrs-core/src/models/llama.rs b/mistralrs-core/src/models/llama.rs index f321bdbea..c7cf8c831 100644 --- a/mistralrs-core/src/models/llama.rs +++ b/mistralrs-core/src/models/llama.rs @@ -447,7 +447,7 @@ impl Llama { for (block_idx, 
block) in self.blocks.iter().enumerate() { // x = self.mapper.map(x, block_idx)?; // x = self.mapper.map(&chunks[0], block_idx)?; - x = self.mapper.map(chunks[0], block_idx)?; + x = self.mapper.map(chunks[0].clone(), block_idx)?; x = block.forward( &x, &mask.clone().map(|m| m.to_device(x.device()).unwrap()), From e8229535ef016947335417223a4ff19a4b0167f7 Mon Sep 17 00:00:00 2001 From: joshpopelka20 <107133849+joshpopelka20@users.noreply.github.com> Date: Wed, 14 Aug 2024 13:11:01 -0400 Subject: [PATCH 020/107] mutably borrow --- mistralrs-core/src/models/llama.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mistralrs-core/src/models/llama.rs b/mistralrs-core/src/models/llama.rs index c7cf8c831..2ac5f91d8 100644 --- a/mistralrs-core/src/models/llama.rs +++ b/mistralrs-core/src/models/llama.rs @@ -447,7 +447,7 @@ impl Llama { for (block_idx, block) in self.blocks.iter().enumerate() { // x = self.mapper.map(x, block_idx)?; // x = self.mapper.map(&chunks[0], block_idx)?; - x = self.mapper.map(chunks[0].clone(), block_idx)?; + &mut x = self.mapper.map(chunks[0].clone(), block_idx)?; x = block.forward( &x, &mask.clone().map(|m| m.to_device(x.device()).unwrap()), From 086e76f9e882235afd9c35af824b9d9b95ca1980 Mon Sep 17 00:00:00 2001 From: joshpopelka20 <107133849+joshpopelka20@users.noreply.github.com> Date: Wed, 14 Aug 2024 13:11:37 -0400 Subject: [PATCH 021/107] derefernce --- mistralrs-core/src/models/llama.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mistralrs-core/src/models/llama.rs b/mistralrs-core/src/models/llama.rs index 2ac5f91d8..2fc83f3bf 100644 --- a/mistralrs-core/src/models/llama.rs +++ b/mistralrs-core/src/models/llama.rs @@ -447,7 +447,7 @@ impl Llama { for (block_idx, block) in self.blocks.iter().enumerate() { // x = self.mapper.map(x, block_idx)?; // x = self.mapper.map(&chunks[0], block_idx)?; - &mut x = self.mapper.map(chunks[0].clone(), block_idx)?; + *&mut x = self.mapper.map(chunks[0].clone(), block_idx)?; x = block.forward( &x, &mask.clone().map(|m| m.to_device(x.device()).unwrap()), From ce3d4187028f06f18a62fe3f865a9c4f344cf079 Mon Sep 17 00:00:00 2001 From: joshpopelka20 <107133849+joshpopelka20@users.noreply.github.com> Date: Wed, 14 Aug 2024 13:24:12 -0400 Subject: [PATCH 022/107] create vec of tensors --- mistralrs-core/src/models/llama.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mistralrs-core/src/models/llama.rs b/mistralrs-core/src/models/llama.rs index 2fc83f3bf..65f4e467f 100644 --- a/mistralrs-core/src/models/llama.rs +++ b/mistralrs-core/src/models/llama.rs @@ -428,7 +428,8 @@ impl Llama { let num_devices = 1; let chunk_size = seq_len / num_devices; - let mut chunks = Vec::new(); + // let mut chunks = Vec::new(); + let mut chunks = Vec; // let chunk = x.clone(); // chunks.push(chunk.to_device(&self.cuda_devices[0])?); chunks.push(&x); From ffbe3f93c62ad30201a466103744a63001f7a159 Mon Sep 17 00:00:00 2001 From: joshpopelka20 <107133849+joshpopelka20@users.noreply.github.com> Date: Wed, 14 Aug 2024 13:27:15 -0400 Subject: [PATCH 023/107] make new vec --- mistralrs-core/src/models/llama.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mistralrs-core/src/models/llama.rs b/mistralrs-core/src/models/llama.rs index 65f4e467f..54d99a158 100644 --- a/mistralrs-core/src/models/llama.rs +++ b/mistralrs-core/src/models/llama.rs @@ -429,7 +429,7 @@ impl Llama { let chunk_size = seq_len / num_devices; // let mut chunks = Vec::new(); - let mut chunks = Vec; + let mut chunks = 
Vec::new(); // let chunk = x.clone(); // chunks.push(chunk.to_device(&self.cuda_devices[0])?); chunks.push(&x); From 9f915940a352c6b48d21650300601c528fcf560c Mon Sep 17 00:00:00 2001 From: joshpopelka20 <107133849+joshpopelka20@users.noreply.github.com> Date: Wed, 14 Aug 2024 13:27:50 -0400 Subject: [PATCH 024/107] type tensor --- mistralrs-core/src/models/llama.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mistralrs-core/src/models/llama.rs b/mistralrs-core/src/models/llama.rs index 54d99a158..622868a2b 100644 --- a/mistralrs-core/src/models/llama.rs +++ b/mistralrs-core/src/models/llama.rs @@ -429,7 +429,7 @@ impl Llama { let chunk_size = seq_len / num_devices; // let mut chunks = Vec::new(); - let mut chunks = Vec::new(); + let mut chunks = Vec::; // let chunk = x.clone(); // chunks.push(chunk.to_device(&self.cuda_devices[0])?); chunks.push(&x); From c88cdf5197ca56df7e9a8ebc08d7076b7002458d Mon Sep 17 00:00:00 2001 From: joshpopelka20 <107133849+joshpopelka20@users.noreply.github.com> Date: Wed, 14 Aug 2024 13:31:26 -0400 Subject: [PATCH 025/107] push to chunks --- mistralrs-core/src/models/llama.rs | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/mistralrs-core/src/models/llama.rs b/mistralrs-core/src/models/llama.rs index 622868a2b..6a4c113e0 100644 --- a/mistralrs-core/src/models/llama.rs +++ b/mistralrs-core/src/models/llama.rs @@ -411,6 +411,7 @@ pub struct Llama { pub device: Device, mapper: Box, cfg: ModelConfigMetadata, + chunks: Vec } impl Llama { @@ -429,10 +430,10 @@ impl Llama { let chunk_size = seq_len / num_devices; // let mut chunks = Vec::new(); - let mut chunks = Vec::; + // let mut chunks = Vec::; // let chunk = x.clone(); // chunks.push(chunk.to_device(&self.cuda_devices[0])?); - chunks.push(&x); + self.chunks.push(x); let mut cache = self.kv_caches[0].lock(); let mask = CausalMasker.make_causal_mask_as_attn_bias( From 0eb85f5c6cfccde0adc189e3344e39a89d88a25a Mon Sep 17 00:00:00 2001 From: joshpopelka20 <107133849+joshpopelka20@users.noreply.github.com> Date: Wed, 14 Aug 2024 13:32:03 -0400 Subject: [PATCH 026/107] self chunks --- mistralrs-core/src/models/llama.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mistralrs-core/src/models/llama.rs b/mistralrs-core/src/models/llama.rs index 6a4c113e0..bc59a252c 100644 --- a/mistralrs-core/src/models/llama.rs +++ b/mistralrs-core/src/models/llama.rs @@ -443,13 +443,13 @@ impl Llama { .map(|(_, _)| &seqlen_offsets as &dyn PastKvLenCache) .unwrap_or(&*cache as &dyn PastKvLenCache), // x.dtype(), - chunks[0].dtype(), + self.chunks[0].dtype(), self.blocks[0].attn.num_attention_heads, )?; for (block_idx, block) in self.blocks.iter().enumerate() { // x = self.mapper.map(x, block_idx)?; // x = self.mapper.map(&chunks[0], block_idx)?; - *&mut x = self.mapper.map(chunks[0].clone(), block_idx)?; + *&mut x = self.mapper.map(self.chunks[0], block_idx)?; x = block.forward( &x, &mask.clone().map(|m| m.to_device(x.device()).unwrap()), From b7edfbe6b598ef837b778a081c253a4a8d10a189 Mon Sep 17 00:00:00 2001 From: joshpopelka20 <107133849+joshpopelka20@users.noreply.github.com> Date: Wed, 14 Aug 2024 13:34:46 -0400 Subject: [PATCH 027/107] create vec of chunks --- mistralrs-core/src/models/llama.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mistralrs-core/src/models/llama.rs b/mistralrs-core/src/models/llama.rs index bc59a252c..7d33c4f2a 100644 --- a/mistralrs-core/src/models/llama.rs +++ b/mistralrs-core/src/models/llama.rs @@ -411,7 +411,6 
@@ pub struct Llama { pub device: Device, mapper: Box, cfg: ModelConfigMetadata, - chunks: Vec } impl Llama { @@ -433,7 +432,8 @@ impl Llama { // let mut chunks = Vec::; // let chunk = x.clone(); // chunks.push(chunk.to_device(&self.cuda_devices[0])?); - self.chunks.push(x); + let mut chunks: Vec = Vec::with_capacity(num_devices); + chunks.push(x); let mut cache = self.kv_caches[0].lock(); let mask = CausalMasker.make_causal_mask_as_attn_bias( @@ -443,13 +443,13 @@ impl Llama { .map(|(_, _)| &seqlen_offsets as &dyn PastKvLenCache) .unwrap_or(&*cache as &dyn PastKvLenCache), // x.dtype(), - self.chunks[0].dtype(), + chunks[0].dtype(), self.blocks[0].attn.num_attention_heads, )?; for (block_idx, block) in self.blocks.iter().enumerate() { // x = self.mapper.map(x, block_idx)?; // x = self.mapper.map(&chunks[0], block_idx)?; - *&mut x = self.mapper.map(self.chunks[0], block_idx)?; + *&mut x = self.mapper.map(chunks[0], block_idx)?; x = block.forward( &x, &mask.clone().map(|m| m.to_device(x.device()).unwrap()), From 8732bbca5240bfae5dc49cbf0bb488294a3b0a1c Mon Sep 17 00:00:00 2001 From: joshpopelka20 <107133849+joshpopelka20@users.noreply.github.com> Date: Wed, 14 Aug 2024 13:36:55 -0400 Subject: [PATCH 028/107] clone x --- mistralrs-core/src/models/llama.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mistralrs-core/src/models/llama.rs b/mistralrs-core/src/models/llama.rs index 7d33c4f2a..209a89183 100644 --- a/mistralrs-core/src/models/llama.rs +++ b/mistralrs-core/src/models/llama.rs @@ -433,7 +433,7 @@ impl Llama { // let chunk = x.clone(); // chunks.push(chunk.to_device(&self.cuda_devices[0])?); let mut chunks: Vec = Vec::with_capacity(num_devices); - chunks.push(x); + chunks.push(x.clone()); let mut cache = self.kv_caches[0].lock(); let mask = CausalMasker.make_causal_mask_as_attn_bias( From 9c961f33da753f9f3536850114db41301aa2e596 Mon Sep 17 00:00:00 2001 From: joshpopelka20 <107133849+joshpopelka20@users.noreply.github.com> Date: Wed, 14 Aug 2024 13:37:36 -0400 Subject: [PATCH 029/107] remove reference --- mistralrs-core/src/models/llama.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mistralrs-core/src/models/llama.rs b/mistralrs-core/src/models/llama.rs index 209a89183..6a57171be 100644 --- a/mistralrs-core/src/models/llama.rs +++ b/mistralrs-core/src/models/llama.rs @@ -449,7 +449,7 @@ impl Llama { for (block_idx, block) in self.blocks.iter().enumerate() { // x = self.mapper.map(x, block_idx)?; // x = self.mapper.map(&chunks[0], block_idx)?; - *&mut x = self.mapper.map(chunks[0], block_idx)?; + x = self.mapper.map(chunks[0], block_idx)?; x = block.forward( &x, &mask.clone().map(|m| m.to_device(x.device()).unwrap()), From 1aaca7234b79f0dea9f5f6511ccffe35a8ec9fee Mon Sep 17 00:00:00 2001 From: joshpopelka20 <107133849+joshpopelka20@users.noreply.github.com> Date: Wed, 14 Aug 2024 13:38:41 -0400 Subject: [PATCH 030/107] clone for move --- mistralrs-core/src/models/llama.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mistralrs-core/src/models/llama.rs b/mistralrs-core/src/models/llama.rs index 6a57171be..2adb087ee 100644 --- a/mistralrs-core/src/models/llama.rs +++ b/mistralrs-core/src/models/llama.rs @@ -449,7 +449,7 @@ impl Llama { for (block_idx, block) in self.blocks.iter().enumerate() { // x = self.mapper.map(x, block_idx)?; // x = self.mapper.map(&chunks[0], block_idx)?; - x = self.mapper.map(chunks[0], block_idx)?; + x = self.mapper.map(chunks[0].clone(), block_idx)?; x = block.forward( &x, &mask.clone().map(|m| 
m.to_device(x.device()).unwrap()), From 4ab98c8f0373eab354cf75c9b524c16e8e76bde0 Mon Sep 17 00:00:00 2001 From: joshpopelka20 <107133849+joshpopelka20@users.noreply.github.com> Date: Wed, 14 Aug 2024 13:54:35 -0400 Subject: [PATCH 031/107] remove clone --- mistralrs-core/src/models/llama.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mistralrs-core/src/models/llama.rs b/mistralrs-core/src/models/llama.rs index 2adb087ee..5863c4177 100644 --- a/mistralrs-core/src/models/llama.rs +++ b/mistralrs-core/src/models/llama.rs @@ -433,7 +433,7 @@ impl Llama { // let chunk = x.clone(); // chunks.push(chunk.to_device(&self.cuda_devices[0])?); let mut chunks: Vec = Vec::with_capacity(num_devices); - chunks.push(x.clone()); + chunks.push(x); let mut cache = self.kv_caches[0].lock(); let mask = CausalMasker.make_causal_mask_as_attn_bias( @@ -449,7 +449,7 @@ impl Llama { for (block_idx, block) in self.blocks.iter().enumerate() { // x = self.mapper.map(x, block_idx)?; // x = self.mapper.map(&chunks[0], block_idx)?; - x = self.mapper.map(chunks[0].clone(), block_idx)?; + x = self.mapper.map(chunks[0], block_idx)?; x = block.forward( &x, &mask.clone().map(|m| m.to_device(x.device()).unwrap()), From 2e0b2fd0af96231e9cd1905a5fd0690f6cd9b83f Mon Sep 17 00:00:00 2001 From: joshpopelka20 <107133849+joshpopelka20@users.noreply.github.com> Date: Wed, 14 Aug 2024 13:58:43 -0400 Subject: [PATCH 032/107] add back clone --- mistralrs-core/src/models/llama.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mistralrs-core/src/models/llama.rs b/mistralrs-core/src/models/llama.rs index 5863c4177..2adb087ee 100644 --- a/mistralrs-core/src/models/llama.rs +++ b/mistralrs-core/src/models/llama.rs @@ -433,7 +433,7 @@ impl Llama { // let chunk = x.clone(); // chunks.push(chunk.to_device(&self.cuda_devices[0])?); let mut chunks: Vec = Vec::with_capacity(num_devices); - chunks.push(x); + chunks.push(x.clone()); let mut cache = self.kv_caches[0].lock(); let mask = CausalMasker.make_causal_mask_as_attn_bias( @@ -449,7 +449,7 @@ impl Llama { for (block_idx, block) in self.blocks.iter().enumerate() { // x = self.mapper.map(x, block_idx)?; // x = self.mapper.map(&chunks[0], block_idx)?; - x = self.mapper.map(chunks[0], block_idx)?; + x = self.mapper.map(chunks[0].clone(), block_idx)?; x = block.forward( &x, &mask.clone().map(|m| m.to_device(x.device()).unwrap()), From 7bb3cf743d2c4dc70f701ccc2a37ccd0f673c546 Mon Sep 17 00:00:00 2001 From: joshpopelka20 <107133849+joshpopelka20@users.noreply.github.com> Date: Wed, 14 Aug 2024 14:08:20 -0400 Subject: [PATCH 033/107] change to copy --- mistralrs-core/src/models/llama.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mistralrs-core/src/models/llama.rs b/mistralrs-core/src/models/llama.rs index 2adb087ee..0668ee6af 100644 --- a/mistralrs-core/src/models/llama.rs +++ b/mistralrs-core/src/models/llama.rs @@ -433,7 +433,7 @@ impl Llama { // let chunk = x.clone(); // chunks.push(chunk.to_device(&self.cuda_devices[0])?); let mut chunks: Vec = Vec::with_capacity(num_devices); - chunks.push(x.clone()); + chunks.push(x.copy()); let mut cache = self.kv_caches[0].lock(); let mask = CausalMasker.make_causal_mask_as_attn_bias( @@ -449,7 +449,7 @@ impl Llama { for (block_idx, block) in self.blocks.iter().enumerate() { // x = self.mapper.map(x, block_idx)?; // x = self.mapper.map(&chunks[0], block_idx)?; - x = self.mapper.map(chunks[0].clone(), block_idx)?; + x = self.mapper.map(chunks[0].copy(), block_idx)?; x = block.forward( &x, 
&mask.clone().map(|m| m.to_device(x.device()).unwrap()), From f433517ee8c2123f56aba0ecad15fbe6637683b7 Mon Sep 17 00:00:00 2001 From: joshpopelka20 <107133849+joshpopelka20@users.noreply.github.com> Date: Wed, 14 Aug 2024 14:10:41 -0400 Subject: [PATCH 034/107] unwrap copy --- mistralrs-core/src/models/llama.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mistralrs-core/src/models/llama.rs b/mistralrs-core/src/models/llama.rs index 0668ee6af..50e60c852 100644 --- a/mistralrs-core/src/models/llama.rs +++ b/mistralrs-core/src/models/llama.rs @@ -433,7 +433,7 @@ impl Llama { // let chunk = x.clone(); // chunks.push(chunk.to_device(&self.cuda_devices[0])?); let mut chunks: Vec = Vec::with_capacity(num_devices); - chunks.push(x.copy()); + chunks.push(x.copy().unwrap()); let mut cache = self.kv_caches[0].lock(); let mask = CausalMasker.make_causal_mask_as_attn_bias( @@ -449,7 +449,7 @@ impl Llama { for (block_idx, block) in self.blocks.iter().enumerate() { // x = self.mapper.map(x, block_idx)?; // x = self.mapper.map(&chunks[0], block_idx)?; - x = self.mapper.map(chunks[0].copy(), block_idx)?; + x = self.mapper.map(chunks[0].copy().unwrap(), block_idx)?; x = block.forward( &x, &mask.clone().map(|m| m.to_device(x.device()).unwrap()), From cf5b2043e0adb38bf3e3b9c5cdd262900a9e3f41 Mon Sep 17 00:00:00 2001 From: joshpopelka20 <107133849+joshpopelka20@users.noreply.github.com> Date: Wed, 14 Aug 2024 14:40:52 -0400 Subject: [PATCH 035/107] remove copy --- mistralrs-core/src/models/llama.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mistralrs-core/src/models/llama.rs b/mistralrs-core/src/models/llama.rs index 50e60c852..e35187351 100644 --- a/mistralrs-core/src/models/llama.rs +++ b/mistralrs-core/src/models/llama.rs @@ -433,7 +433,7 @@ impl Llama { // let chunk = x.clone(); // chunks.push(chunk.to_device(&self.cuda_devices[0])?); let mut chunks: Vec = Vec::with_capacity(num_devices); - chunks.push(x.copy().unwrap()); + chunks.push(x); let mut cache = self.kv_caches[0].lock(); let mask = CausalMasker.make_causal_mask_as_attn_bias( From 9e0e6c849ff2ea3eb149f8fc98ed902e9b3e7ccb Mon Sep 17 00:00:00 2001 From: joshpopelka20 <107133849+joshpopelka20@users.noreply.github.com> Date: Wed, 14 Aug 2024 14:50:37 -0400 Subject: [PATCH 036/107] use my candle --- Cargo.toml | 4 ++-- mistralrs-core/Cargo.toml | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 939335c61..ce57c7425 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -25,8 +25,8 @@ license = "MIT" [workspace.dependencies] anyhow = "1.0.80" -candle-core = { git = "https://github.com/EricLBuehler/candle.git", version = "0.6.0", rev = "2386e4e" } -candle-nn = { git = "https://github.com/EricLBuehler/candle.git", version = "0.6.0", rev = "2386e4e" } +candle-core = { git = "https://github.com/joshpopelka20/candle.git" } +candle-nn = { git = "https://github.com/joshpopelka20/candle.git" } serde = "1.0.197" serde_json = "1.0.114" indexmap = { version = "2.2.5", features = ["serde"] } diff --git a/mistralrs-core/Cargo.toml b/mistralrs-core/Cargo.toml index dd0456a6c..236dab266 100644 --- a/mistralrs-core/Cargo.toml +++ b/mistralrs-core/Cargo.toml @@ -17,7 +17,7 @@ candle-core.workspace = true candle-nn.workspace = true serde.workspace = true serde_json.workspace = true -candle-flash-attn = { git = "https://github.com/EricLBuehler/candle.git", version = "0.6.0", rev = "2386e4e", optional = true } +candle-flash-attn = { git = "https://github.com/joshpopelka20/candle.git", 
optional = true } dirs = "5.0.1" hf-hub = "0.3.2" thiserror = "1.0.57" From 03be02ac18cde4235bfa36811ce60f1da826bd19 Mon Sep 17 00:00:00 2001 From: joshpopelka20 <107133849+joshpopelka20@users.noreply.github.com> Date: Wed, 14 Aug 2024 15:12:18 -0400 Subject: [PATCH 037/107] mvoe back to EricLBuehler --- Cargo.toml | 4 ++-- mistralrs-core/Cargo.toml | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index ce57c7425..84426d6ce 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -25,8 +25,8 @@ license = "MIT" [workspace.dependencies] anyhow = "1.0.80" -candle-core = { git = "https://github.com/joshpopelka20/candle.git" } -candle-nn = { git = "https://github.com/joshpopelka20/candle.git" } +candle-core = { git = "https://github.com/EricLBuehler/candle.git" } +candle-nn = { git = "https://github.com/EricLBuehler/candle.git" } serde = "1.0.197" serde_json = "1.0.114" indexmap = { version = "2.2.5", features = ["serde"] } diff --git a/mistralrs-core/Cargo.toml b/mistralrs-core/Cargo.toml index 236dab266..535d1004c 100644 --- a/mistralrs-core/Cargo.toml +++ b/mistralrs-core/Cargo.toml @@ -17,7 +17,7 @@ candle-core.workspace = true candle-nn.workspace = true serde.workspace = true serde_json.workspace = true -candle-flash-attn = { git = "https://github.com/joshpopelka20/candle.git", optional = true } +candle-flash-attn = { git = "https://github.com/EricLBuehler/candle.git", optional = true } dirs = "5.0.1" hf-hub = "0.3.2" thiserror = "1.0.57" From d75ee88bae6f27d50b18906d5f215551c280b5dc Mon Sep 17 00:00:00 2001 From: joshpopelka20 <107133849+joshpopelka20@users.noreply.github.com> Date: Wed, 14 Aug 2024 15:13:01 -0400 Subject: [PATCH 038/107] move back to josh --- Cargo.toml | 4 ++-- mistralrs-core/Cargo.toml | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 84426d6ce..ce57c7425 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -25,8 +25,8 @@ license = "MIT" [workspace.dependencies] anyhow = "1.0.80" -candle-core = { git = "https://github.com/EricLBuehler/candle.git" } -candle-nn = { git = "https://github.com/EricLBuehler/candle.git" } +candle-core = { git = "https://github.com/joshpopelka20/candle.git" } +candle-nn = { git = "https://github.com/joshpopelka20/candle.git" } serde = "1.0.197" serde_json = "1.0.114" indexmap = { version = "2.2.5", features = ["serde"] } diff --git a/mistralrs-core/Cargo.toml b/mistralrs-core/Cargo.toml index 535d1004c..236dab266 100644 --- a/mistralrs-core/Cargo.toml +++ b/mistralrs-core/Cargo.toml @@ -17,7 +17,7 @@ candle-core.workspace = true candle-nn.workspace = true serde.workspace = true serde_json.workspace = true -candle-flash-attn = { git = "https://github.com/EricLBuehler/candle.git", optional = true } +candle-flash-attn = { git = "https://github.com/joshpopelka20/candle.git", optional = true } dirs = "5.0.1" hf-hub = "0.3.2" thiserror = "1.0.57" From f79ef6fd54b7eb14a7c8d2f3d443651986d76a0d Mon Sep 17 00:00:00 2001 From: joshpopelka20 <107133849+joshpopelka20@users.noreply.github.com> Date: Wed, 14 Aug 2024 15:15:18 -0400 Subject: [PATCH 039/107] revert candle --- Cargo.toml | 4 ++-- mistralrs-core/Cargo.toml | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index ce57c7425..939335c61 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -25,8 +25,8 @@ license = "MIT" [workspace.dependencies] anyhow = "1.0.80" -candle-core = { git = "https://github.com/joshpopelka20/candle.git" } -candle-nn = { git = "https://github.com/joshpopelka20/candle.git" 
} +candle-core = { git = "https://github.com/EricLBuehler/candle.git", version = "0.6.0", rev = "2386e4e" } +candle-nn = { git = "https://github.com/EricLBuehler/candle.git", version = "0.6.0", rev = "2386e4e" } serde = "1.0.197" serde_json = "1.0.114" indexmap = { version = "2.2.5", features = ["serde"] } diff --git a/mistralrs-core/Cargo.toml b/mistralrs-core/Cargo.toml index 236dab266..dd0456a6c 100644 --- a/mistralrs-core/Cargo.toml +++ b/mistralrs-core/Cargo.toml @@ -17,7 +17,7 @@ candle-core.workspace = true candle-nn.workspace = true serde.workspace = true serde_json.workspace = true -candle-flash-attn = { git = "https://github.com/joshpopelka20/candle.git", optional = true } +candle-flash-attn = { git = "https://github.com/EricLBuehler/candle.git", version = "0.6.0", rev = "2386e4e", optional = true } dirs = "5.0.1" hf-hub = "0.3.2" thiserror = "1.0.57" From 4b9ed28339048ddd2993d129c9d11a5efe1ab8fc Mon Sep 17 00:00:00 2001 From: joshpopelka20 <107133849+joshpopelka20@users.noreply.github.com> Date: Wed, 14 Aug 2024 15:30:35 -0400 Subject: [PATCH 040/107] remove copy mapper --- mistralrs-core/src/models/llama.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mistralrs-core/src/models/llama.rs b/mistralrs-core/src/models/llama.rs index e35187351..6a57171be 100644 --- a/mistralrs-core/src/models/llama.rs +++ b/mistralrs-core/src/models/llama.rs @@ -433,7 +433,7 @@ impl Llama { // let chunk = x.clone(); // chunks.push(chunk.to_device(&self.cuda_devices[0])?); let mut chunks: Vec = Vec::with_capacity(num_devices); - chunks.push(x); + chunks.push(x.clone()); let mut cache = self.kv_caches[0].lock(); let mask = CausalMasker.make_causal_mask_as_attn_bias( @@ -449,7 +449,7 @@ impl Llama { for (block_idx, block) in self.blocks.iter().enumerate() { // x = self.mapper.map(x, block_idx)?; // x = self.mapper.map(&chunks[0], block_idx)?; - x = self.mapper.map(chunks[0].copy().unwrap(), block_idx)?; + x = self.mapper.map(chunks[0], block_idx)?; x = block.forward( &x, &mask.clone().map(|m| m.to_device(x.device()).unwrap()), From b54a5af012e274327d87153f7f5f435a3f2ad909 Mon Sep 17 00:00:00 2001 From: joshpopelka20 <107133849+joshpopelka20@users.noreply.github.com> Date: Wed, 14 Aug 2024 15:41:01 -0400 Subject: [PATCH 041/107] clone chunks --- mistralrs-core/src/models/llama.rs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/mistralrs-core/src/models/llama.rs b/mistralrs-core/src/models/llama.rs index 6a57171be..4a6c41c4a 100644 --- a/mistralrs-core/src/models/llama.rs +++ b/mistralrs-core/src/models/llama.rs @@ -449,7 +449,9 @@ impl Llama { for (block_idx, block) in self.blocks.iter().enumerate() { // x = self.mapper.map(x, block_idx)?; // x = self.mapper.map(&chunks[0], block_idx)?; - x = self.mapper.map(chunks[0], block_idx)?; + println!("x device {:?}", x.device()); + println!("chunk device {:?}", chunks[0].device()); + x = self.mapper.map(chunks[0].clone(), block_idx)?; x = block.forward( &x, &mask.clone().map(|m| m.to_device(x.device()).unwrap()), From a50883bf1a14ea870ec33a07d0b52f18a14b1534 Mon Sep 17 00:00:00 2001 From: joshpopelka20 <107133849+joshpopelka20@users.noreply.github.com> Date: Wed, 14 Aug 2024 16:30:44 -0400 Subject: [PATCH 042/107] copy instead of clone --- mistralrs-core/src/models/llama.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mistralrs-core/src/models/llama.rs b/mistralrs-core/src/models/llama.rs index 4a6c41c4a..0a2a1985e 100644 --- a/mistralrs-core/src/models/llama.rs +++ 
b/mistralrs-core/src/models/llama.rs @@ -433,7 +433,7 @@ impl Llama { // let chunk = x.clone(); // chunks.push(chunk.to_device(&self.cuda_devices[0])?); let mut chunks: Vec = Vec::with_capacity(num_devices); - chunks.push(x.clone()); + chunks.push(x.copy().unwrap()); let mut cache = self.kv_caches[0].lock(); let mask = CausalMasker.make_causal_mask_as_attn_bias( @@ -451,7 +451,7 @@ impl Llama { // x = self.mapper.map(&chunks[0], block_idx)?; println!("x device {:?}", x.device()); println!("chunk device {:?}", chunks[0].device()); - x = self.mapper.map(chunks[0].clone(), block_idx)?; + x = self.mapper.map(chunks[0].copy().unwrap(), block_idx)?; x = block.forward( &x, &mask.clone().map(|m| m.to_device(x.device()).unwrap()), From edf82dac57db4ea2092346367bed3f12da7da293 Mon Sep 17 00:00:00 2001 From: joshpopelka20 <107133849+joshpopelka20@users.noreply.github.com> Date: Wed, 14 Aug 2024 16:40:02 -0400 Subject: [PATCH 043/107] move loggers --- mistralrs-core/src/models/llama.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mistralrs-core/src/models/llama.rs b/mistralrs-core/src/models/llama.rs index 0a2a1985e..a10cfc306 100644 --- a/mistralrs-core/src/models/llama.rs +++ b/mistralrs-core/src/models/llama.rs @@ -433,7 +433,9 @@ impl Llama { // let chunk = x.clone(); // chunks.push(chunk.to_device(&self.cuda_devices[0])?); let mut chunks: Vec = Vec::with_capacity(num_devices); + println!("x device {:?}", x.device()); chunks.push(x.copy().unwrap()); + println!("chunk device {:?}", chunks[0].device()); let mut cache = self.kv_caches[0].lock(); let mask = CausalMasker.make_causal_mask_as_attn_bias( @@ -449,8 +451,6 @@ impl Llama { for (block_idx, block) in self.blocks.iter().enumerate() { // x = self.mapper.map(x, block_idx)?; // x = self.mapper.map(&chunks[0], block_idx)?; - println!("x device {:?}", x.device()); - println!("chunk device {:?}", chunks[0].device()); x = self.mapper.map(chunks[0].copy().unwrap(), block_idx)?; x = block.forward( &x, From 3e9cc265802278514d46896dc508059fcade8c0d Mon Sep 17 00:00:00 2001 From: joshpopelka20 <107133849+joshpopelka20@users.noreply.github.com> Date: Thu, 22 Aug 2024 10:57:00 -0400 Subject: [PATCH 044/107] add sequence parallelism --- mistralrs-core/src/models/llama.rs | 32 +- .../src/models/llama_ring_attention.rs | 815 ++++++++++++++++++ 2 files changed, 840 insertions(+), 7 deletions(-) create mode 100644 mistralrs-core/src/models/llama_ring_attention.rs diff --git a/mistralrs-core/src/models/llama.rs b/mistralrs-core/src/models/llama.rs index a10cfc306..5560ff0e4 100644 --- a/mistralrs-core/src/models/llama.rs +++ b/mistralrs-core/src/models/llama.rs @@ -428,14 +428,30 @@ impl Llama { let num_devices = 1; let chunk_size = seq_len / num_devices; - // let mut chunks = Vec::new(); - // let mut chunks = Vec::; - // let chunk = x.clone(); - // chunks.push(chunk.to_device(&self.cuda_devices[0])?); let mut chunks: Vec = Vec::with_capacity(num_devices); - println!("x device {:?}", x.device()); - chunks.push(x.copy().unwrap()); - println!("chunk device {:?}", chunks[0].device()); + // chunks.push(x.copy().unwrap()); + + // Handle the case where sequence length is less than number of devices + if seq_len <= num_devices { + for j in 0..seq_len { + // let chunk = x.i((.., j..j+1, ..))?; + let chunk = x.clone(); + chunks.push(chunk.to_device(&self.cuda_devices[j])?); + } + } else { + for j in 0..num_devices { + let start = j * chunk_size; + let end = if j == num_devices - 1 { + seq_len + } else { + (j+ 1) * chunk_size + }; + + let chunk = 
x.i((.., start..end,..))?; + let device = &self.cuda_devices[j]; + chunks.push(chunk.to_device(&device)?); + } + } let mut cache = self.kv_caches[0].lock(); let mask = CausalMasker.make_causal_mask_as_attn_bias( @@ -451,6 +467,8 @@ impl Llama { for (block_idx, block) in self.blocks.iter().enumerate() { // x = self.mapper.map(x, block_idx)?; // x = self.mapper.map(&chunks[0], block_idx)?; + println!("x device {:?}", x.device()); + println!("chunk device {:?}", chunks[0].device()); x = self.mapper.map(chunks[0].copy().unwrap(), block_idx)?; x = block.forward( &x, diff --git a/mistralrs-core/src/models/llama_ring_attention.rs b/mistralrs-core/src/models/llama_ring_attention.rs new file mode 100644 index 000000000..5560ff0e4 --- /dev/null +++ b/mistralrs-core/src/models/llama_ring_attention.rs @@ -0,0 +1,815 @@ +#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)] + +use candle_core::{quantized::QMatMul, DType, Device, Result, Tensor}; +use candle_nn::{embedding, Embedding, Module, VarBuilder}; +use mistralrs_quant::{QuantMethod, QuantizedConfig}; +use serde::Deserialize; +use std::sync::Arc; + +use crate::{ + amoe::{ + AnyMoeBaseModelMixin, AnyMoeConfig, AnyMoeExpertType, AnyMoeTrainableLayer, MlpLayer, + MoeMlp, + }, + device_map::DeviceMapper, + get_delta_from_lora_ab, + layers::{ + repeat_kv, CausalMasker, Llama3RopeConfig, Llama3RotaryEmbedding, MatMul, RmsNorm, + ScaledDotProductAttention, + }, + layers_masker::PastKvLenCache, + paged_attention::{AttentionImplementation, ModelConfigMetadata, PagedAttention}, + pipeline::{ + extract_logits, text_models_inputs_processor::PagedAttentionInputMetadata, IsqModel, + NormalLoadingMetadata, NormalModel, + }, + utils::progress::NiceProgressBar, +}; + +#[derive(Debug, Clone, Deserialize, Default)] +pub struct Config { + pub hidden_size: usize, + pub intermediate_size: usize, + pub vocab_size: usize, + pub num_hidden_layers: usize, + pub num_attention_heads: usize, + pub num_key_value_heads: usize, + pub use_flash_attn: bool, + pub rms_norm_eps: f64, + pub rope_theta: f32, + pub max_position_embeddings: usize, + pub rope_scaling: Option, + pub quantization_config: Option, +} + +struct CausalSelfAttention { + q_proj: Arc, + k_proj: Arc, + v_proj: Arc, + o_proj: Arc, + num_attention_heads: usize, + num_key_value_heads: usize, + head_dim: usize, + use_flash_attn: bool, + rotary_emb: Arc, + max_seq_len: usize, + paged_attn: Option, +} + +impl CausalSelfAttention { + #[allow(clippy::too_many_arguments)] + fn forward( + &self, + x: &Tensor, + attention_mask: &Option, + seqlen_offsets: &[usize], + start_offsets_kernel: Tensor, + block_idx: usize, + kv_cache: &mut crate::pipeline::LayerCaches, + metadata: Option<((Tensor, Tensor), &mut PagedAttentionInputMetadata)>, + ) -> Result { + let (b_sz, seq_len, _) = x.dims3()?; + + let original_dtype = x.dtype(); + let mut x = x.clone(); + if let Some(t) = self.q_proj.quantized_act_type() { + x = x.to_dtype(t)?; + } + let mut q = MatMul.qmethod_matmul(&x, &*self.q_proj)?; + let mut k = MatMul.qmethod_matmul(&x, &*self.k_proj)?; + let mut v = MatMul.qmethod_matmul(&x, &*self.v_proj)?; + if self.q_proj.quantized_act_type().is_some() { + q = q.to_dtype(original_dtype)?; + k = k.to_dtype(original_dtype)?; + v = v.to_dtype(original_dtype)?; + } + + let mut q = q.reshape((b_sz * seq_len, self.num_attention_heads, self.head_dim))?; + let mut k = k.reshape((b_sz * seq_len, self.num_key_value_heads, self.head_dim))?; + let v = if seq_len != 1 { + v.reshape((b_sz, seq_len, self.num_key_value_heads, 
self.head_dim))? + .transpose(1, 2)? + } else { + // Optimization for seqlen = 1, avoid transpose and just modify reshape dims + v.reshape((b_sz, self.num_key_value_heads, seq_len, self.head_dim))? + }; + + self.rotary_emb + .forward(seqlen_offsets, &start_offsets_kernel, &mut q, &mut k, b_sz)?; + + if q.rank() == 3 && seq_len != 1 { + q = q + .reshape((b_sz, seq_len, self.num_attention_heads, self.head_dim))? + .transpose(1, 2)? + .contiguous()?; + k = k + .reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))? + .transpose(1, 2)? + .contiguous()?; + } else if q.rank() == 3 { + // Optimization for seqlen = 1, avoid transpose and just modify reshape dims + q = q + .reshape((b_sz, self.num_attention_heads, seq_len, self.head_dim))? + .contiguous()?; + k = k + .reshape((b_sz, self.num_key_value_heads, seq_len, self.head_dim))? + .contiguous()?; + } + + let mut y = match &self.paged_attn { + Some(paged_attn) => { + let ((key_cache, value_cache), input_metadata) = metadata.unwrap(); + paged_attn.forward( + &q, + &k, + &v, + attention_mask.clone().as_ref(), + Some(key_cache), + Some(value_cache), + input_metadata, + )? + } + None => { + let (k, v) = + crate::pipeline::Cache::update_kv_cache(&mut kv_cache[block_idx], k, v, false)?; + + let k = repeat_kv(k, self.num_attention_heads / self.num_key_value_heads)? + .contiguous()?; + let v = repeat_kv(v, self.num_attention_heads / self.num_key_value_heads)? + .contiguous()?; + + ScaledDotProductAttention.run_attention( + &q, + &k, + &v, + self.num_attention_heads, + self.head_dim, + attention_mask.clone().as_ref(), + self.use_flash_attn, + b_sz, + seq_len, + )? + } + }; + + if let Some(t) = self.q_proj.quantized_act_type() { + y = y.to_dtype(t)?; + } + y = if attention_mask.is_some() { + y.transpose(1, 2)?.reshape((b_sz, seq_len, ()))? + } else { + y.reshape((b_sz, seq_len, ()))? 
+ }; + let mut res = MatMul.qmethod_matmul(&y, &*self.o_proj)?; + if self.q_proj.quantized_act_type().is_some() { + res = res.to_dtype(original_dtype)?; + } + Ok(res) + } + + fn load( + vb: VarBuilder, + cfg: &Config, + rope: Arc, + paged_attn: Option, + ) -> Result { + let size_in = cfg.hidden_size; + let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads; + let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads; + let q_proj = mistralrs_quant::linear_no_bias( + size_in, + size_q, + &cfg.quantization_config, + vb.pp("q_proj"), + )?; + let k_proj = mistralrs_quant::linear_no_bias( + size_in, + size_kv, + &cfg.quantization_config, + vb.pp("k_proj"), + )?; + let v_proj = mistralrs_quant::linear_no_bias( + size_in, + size_kv, + &cfg.quantization_config, + vb.pp("v_proj"), + )?; + let o_proj = mistralrs_quant::linear_no_bias( + size_q, + size_in, + &cfg.quantization_config, + vb.pp("o_proj"), + )?; + Ok(Self { + q_proj, + k_proj, + v_proj, + o_proj, + num_attention_heads: cfg.num_attention_heads, + num_key_value_heads: cfg.num_key_value_heads, + head_dim: cfg.hidden_size / cfg.num_attention_heads, + use_flash_attn: cfg.use_flash_attn, + rotary_emb: rope, + max_seq_len: cfg.max_position_embeddings, + paged_attn, + }) + } +} + +#[derive(Clone)] +struct Mlp { + c_fc1: Arc, + c_fc2: Arc, + c_proj: Arc, + params: Vec, +} + +impl Mlp { + fn load(vb: VarBuilder, cfg: &Config) -> Result { + let h_size = cfg.hidden_size; + let i_size = cfg.intermediate_size; + let c_fc1 = mistralrs_quant::linear_no_bias( + h_size, + i_size, + &cfg.quantization_config, + vb.pp("gate_proj"), + )?; + let c_fc2 = mistralrs_quant::linear_no_bias( + h_size, + i_size, + &cfg.quantization_config, + vb.pp("up_proj"), + )?; + let c_proj = mistralrs_quant::linear_no_bias( + i_size, + h_size, + &cfg.quantization_config, + vb.pp("down_proj"), + )?; + Ok(Self { + c_fc1, + c_fc2, + c_proj, + params: vec![h_size, i_size], + }) + } +} + +impl AnyMoeTrainableLayer for Mlp {} + +impl MlpLayer for Mlp { + fn forward(&self, x: &Tensor) -> Result { + let original_dtype = x.dtype(); + let mut x = x.clone(); + if let Some(t) = self.c_fc1.quantized_act_type() { + x = x.to_dtype(t)?; + } + let x = (candle_nn::ops::silu(&MatMul.qmethod_matmul(&x, &*self.c_fc1)?)? + * MatMul.qmethod_matmul(&x, &*self.c_fc2)?)?; + let mut res = MatMul.qmethod_matmul(&x, &*self.c_proj)?; + if self.c_fc1.quantized_act_type().is_some() { + res = res.to_dtype(original_dtype)?; + } + Ok(res) + } + fn get_isq_tensors(&mut self) -> Vec<&mut QMatMul> { + { + let c_fc1 = self.c_fc1.clone().convert_to_isq().unwrap(); + self.c_fc1 = c_fc1; + let c_fc2 = self.c_fc2.clone().convert_to_isq().unwrap(); + self.c_fc2 = c_fc2; + let c_proj = self.c_proj.clone().convert_to_isq().unwrap(); + self.c_proj = c_proj; + } + vec![ + Arc::get_mut(&mut self.c_fc1).unwrap().get_qmatmul(), + Arc::get_mut(&mut self.c_fc2).unwrap().get_qmatmul(), + Arc::get_mut(&mut self.c_proj).unwrap().get_qmatmul(), + ] + .into_iter() + .flatten() + .collect::>() + } + fn get_isq_biases(&mut self) -> Vec> { + vec![None, None, None] + } + fn clone(&self) -> Box { + Box::new(Clone::clone(self)) + } + fn get_params(&self) -> &[usize] { + &self.params + } + // c_fc1, c_fc2, c_proj + fn new_added_delta(&self, deltas: Vec>) -> Result> { + let new_c_fc1 = if let Some(ref delta) = deltas[0] { + self.c_fc1.add_delta_w(delta)? + } else { + self.c_fc1.clone() + }; + let new_c_fc2 = if let Some(ref delta) = deltas[1] { + self.c_fc2.add_delta_w(delta)? 
+ } else { + self.c_fc2.clone() + }; + let new_c_proj = if let Some(ref delta) = deltas[2] { + self.c_proj.add_delta_w(delta)? + } else { + self.c_proj.clone() + }; + + Ok(Box::new(Self { + c_fc1: new_c_fc1, + c_fc2: new_c_fc2, + c_proj: new_c_proj, + params: self.params.clone(), + })) + } + + fn dtype_device(&self) -> (DType, Device) { + self.c_fc1.dtype_and_device() + } +} + +struct Block { + rms_1: RmsNorm, + attn: CausalSelfAttention, + rms_2: RmsNorm, + mlp: Box, +} + +impl Block { + #[allow(clippy::too_many_arguments)] + fn forward( + &self, + x: &Tensor, + attention_mask: &Option, + seqlen_offsets: &[usize], + start_offsets_kernel: Tensor, + block_idx: usize, + kv_cache: &mut crate::pipeline::LayerCaches, + metadata: Option<((Tensor, Tensor), &mut PagedAttentionInputMetadata)>, + ) -> Result { + let residual = x; + let x = self.rms_1.forward(x)?; + let x = (self.attn.forward( + &x, + attention_mask, + seqlen_offsets, + start_offsets_kernel, + block_idx, + kv_cache, + metadata, + )? + residual)?; + let residual = &x; + let x = (self.mlp.forward(&self.rms_2.forward(&x)?)? + residual)?; + Ok(x) + } + + fn load( + vb: VarBuilder, + cfg: &Config, + mapper: &dyn DeviceMapper, + layer_idx: usize, + loading_isq: bool, + rope: Arc, + paged_attn: Option, + ) -> Result { + let attn = CausalSelfAttention::load( + mapper.set_device(layer_idx, vb.pp("self_attn"), loading_isq), + cfg, + rope, + paged_attn, + )?; + let mlp = Mlp::load(mapper.set_device(layer_idx, vb.pp("mlp"), loading_isq), cfg)?; + let rms_1 = RmsNorm::new( + cfg.hidden_size, + cfg.rms_norm_eps, + mapper.set_device(layer_idx, vb.pp("input_layernorm"), false), + )?; + let rms_2 = RmsNorm::new( + cfg.hidden_size, + cfg.rms_norm_eps, + mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false), + )?; + Ok(Self { + rms_1, + attn, + rms_2, + mlp: Box::new(mlp), + }) + } +} + +pub struct Llama { + wte: Embedding, + blocks: Vec, + ln_f: RmsNorm, + lm_head: QMatMul, + // pub kv_cache: crate::pipeline::Cache, + pub kv_caches: Vec, + cuda_devices: Vec, + pub device: Device, + mapper: Box, + cfg: ModelConfigMetadata, +} + +impl Llama { + pub fn forward( + &self, + input_ids: &Tensor, + seqlen_offsets: &[usize], + start_offsets_kernel: Tensor, + context_lens: Vec<(usize, usize)>, + mut metadata: Option<(Vec<(Tensor, Tensor)>, &mut PagedAttentionInputMetadata)>, + ) -> Result { + let mut x = self.wte.forward(input_ids)?; + let (batch_size, seq_len, hidden_size) = x.dims3()?; + + let num_devices = 1; + let chunk_size = seq_len / num_devices; + + let mut chunks: Vec = Vec::with_capacity(num_devices); + // chunks.push(x.copy().unwrap()); + + // Handle the case where sequence length is less than number of devices + if seq_len <= num_devices { + for j in 0..seq_len { + // let chunk = x.i((.., j..j+1, ..))?; + let chunk = x.clone(); + chunks.push(chunk.to_device(&self.cuda_devices[j])?); + } + } else { + for j in 0..num_devices { + let start = j * chunk_size; + let end = if j == num_devices - 1 { + seq_len + } else { + (j+ 1) * chunk_size + }; + + let chunk = x.i((.., start..end,..))?; + let device = &self.cuda_devices[j]; + chunks.push(chunk.to_device(&device)?); + } + } + + let mut cache = self.kv_caches[0].lock(); + let mask = CausalMasker.make_causal_mask_as_attn_bias( + input_ids, + metadata + .as_ref() + .map(|(_, _)| &seqlen_offsets as &dyn PastKvLenCache) + .unwrap_or(&*cache as &dyn PastKvLenCache), + // x.dtype(), + chunks[0].dtype(), + self.blocks[0].attn.num_attention_heads, + )?; + for (block_idx, block) in 
self.blocks.iter().enumerate() { + // x = self.mapper.map(x, block_idx)?; + // x = self.mapper.map(&chunks[0], block_idx)?; + println!("x device {:?}", x.device()); + println!("chunk device {:?}", chunks[0].device()); + x = self.mapper.map(chunks[0].copy().unwrap(), block_idx)?; + x = block.forward( + &x, + &mask.clone().map(|m| m.to_device(x.device()).unwrap()), + seqlen_offsets, + start_offsets_kernel.clone(), + block_idx, + &mut cache, + metadata + .as_mut() + .map(|(kv_cache, metadata)| (kv_cache[block_idx].clone(), &mut **metadata)), + )?; + } + let x = x.to_device(&self.device)?; + let mut x = self.ln_f.forward(&x)?; + if matches!(self.lm_head, QMatMul::QTensor(_)) { + x = x.to_dtype(DType::F32)?; + } + let logits = MatMul.qmatmul(&x, &self.lm_head)?; + extract_logits(&logits, context_lens) + } + + pub fn new( + cfg: &Config, + vb: VarBuilder, + is_gptx: bool, + normal_loading_metadata: NormalLoadingMetadata, + attention_mechanism: AttentionImplementation, + ) -> Result { + if let Some(ref quant_cfg) = &cfg.quantization_config { + tracing::info!( + "Using {} quantization in {} bits.", + quant_cfg.quant_method.to_string(), + quant_cfg.bits + ); + } + + let num_devices = 1; + let mut cuda_devices = Vec::with_capacity(num_devices); + let mapper = normal_loading_metadata + .mapper + .into_mapper(cfg.num_hidden_layers, &normal_loading_metadata.real_device)?; + + let wte = embedding( + cfg.vocab_size, + cfg.hidden_size, + mapper.set_nm_device(vb.pp("model.embed_tokens"), false), + )?; + let lm_head = candle_nn::linear_no_bias( + cfg.hidden_size, + cfg.vocab_size, + mapper.set_nm_device(vb.pp("lm_head"), normal_loading_metadata.loading_isq), + )?; + let ln_f = RmsNorm::new( + cfg.hidden_size, + cfg.rms_norm_eps, + mapper.set_nm_device(vb.pp("model.norm"), false), + )?; + let head_dim = cfg.hidden_size / cfg.num_attention_heads; + let blocks: Vec<_> = + NiceProgressBar::<_, 'b'>(0..cfg.num_hidden_layers, "Loading repeating layers") + .into_iter() + .map(|i| { + let device = mapper + .device_for(i, false) + .unwrap_or(&normal_loading_metadata.real_device); + let rotary_emb = Arc::new( + Llama3RotaryEmbedding::new(vb.dtype(), cfg, device, is_gptx) + .expect("Failed to create RoPE"), + ); + let paged_attn = match &attention_mechanism { + AttentionImplementation::Eager => None, + AttentionImplementation::PagedAttention => Some( + PagedAttention::new( + cfg.num_attention_heads, + head_dim, + (1.0 / (head_dim as f64).sqrt()) as f32, + Some(cfg.num_key_value_heads), + None, + device, + None, + ) + .expect("Failed to create PagedAttention"), + ), + }; + if !cuda_devices.iter().any(|d| format!("{:?}", d) == format!("{:?}", device)) { + cuda_devices.push(device.clone()); + } + Block::load( + vb.pp(&format!("model.layers.{i}")), + cfg, + &*mapper, + i, + normal_loading_metadata.loading_isq, + rotary_emb, + paged_attn, + ) + .expect("Failed to load block.") + }) + .collect(); + + let mut kv_caches: Vec = Vec::with_capacity(num_devices); + + for device_id in 0..num_devices { + let cache = crate::pipeline::Cache::new(cfg.num_hidden_layers , false); + kv_caches.push(cache); + }; + + Ok(Self { + wte, + blocks, + ln_f, + lm_head: QMatMul::Tensor(lm_head.weight().clone()), + // kv_cache: crate::pipeline::Cache::new(cfg.num_hidden_layers, false), + kv_caches, + cuda_devices, + device: normal_loading_metadata.real_device, + mapper, + cfg: ModelConfigMetadata { + num_layers: cfg.num_hidden_layers, + hidden_size: cfg.hidden_size, + num_kv_heads: cfg.num_key_value_heads, + num_attn_heads: cfg.num_attention_heads, 
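The chunking logic in the patched `forward` above splits the embedded hidden states along the sequence axis and parks one slice per CUDA device. A minimal standalone sketch of that idea, assuming only candle's `Tensor`/`Device` API; the function name and the handling of the short-sequence case are illustrative, not lifted from the patch:

use candle_core::{Device, IndexOp, Result, Tensor};

// Split a (batch, seq, hidden) tensor into one contiguous slice per device
// along the sequence axis and move each slice onto its device.
fn split_seq_across_devices(x: &Tensor, devices: &[Device]) -> Result<Vec<Tensor>> {
    assert!(!devices.is_empty(), "need at least one device");
    let (_batch, seq_len, _hidden) = x.dims3()?;
    let n = devices.len();
    let mut chunks = Vec::with_capacity(n);
    if seq_len <= n {
        // Short sequences: one single-position slice per device that gets work.
        for (j, dev) in devices.iter().take(seq_len).enumerate() {
            chunks.push(x.i((.., j..j + 1, ..))?.to_device(dev)?);
        }
    } else {
        let chunk_size = seq_len / n;
        for (j, dev) in devices.iter().enumerate() {
            let start = j * chunk_size;
            // The last device takes the remainder so every position is covered.
            let end = if j == n - 1 { seq_len } else { (j + 1) * chunk_size };
            chunks.push(x.i((.., start..end, ..))?.to_device(dev)?);
        }
    }
    Ok(chunks)
}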
+ sliding_window: None, + }, + }) + } +} + +impl IsqModel for Llama { + fn get_matmuls(&mut self) -> (Vec<(&mut QMatMul, Option)>, &dyn DeviceMapper) { + let mut tensors = Vec::new(); + tensors.push((&mut self.lm_head, None)); + for (i, layer) in self.blocks.iter_mut().enumerate() { + { + let q_proj = layer.attn.q_proj.clone().convert_to_isq().unwrap(); + layer.attn.q_proj = q_proj; + let k_proj = layer.attn.k_proj.clone().convert_to_isq().unwrap(); + layer.attn.k_proj = k_proj; + let v_proj = layer.attn.v_proj.clone().convert_to_isq().unwrap(); + layer.attn.v_proj = v_proj; + let o_proj = layer.attn.o_proj.clone().convert_to_isq().unwrap(); + layer.attn.o_proj = o_proj; + } + if let Some(q) = Arc::get_mut(&mut layer.attn.q_proj).unwrap().get_qmatmul() { + tensors.push((q, Some(i))); + } + if let Some(k) = Arc::get_mut(&mut layer.attn.k_proj).unwrap().get_qmatmul() { + tensors.push((k, Some(i))); + } + if let Some(b) = Arc::get_mut(&mut layer.attn.v_proj).unwrap().get_qmatmul() { + tensors.push((b, Some(i))); + } + if let Some(o) = Arc::get_mut(&mut layer.attn.o_proj).unwrap().get_qmatmul() { + tensors.push((o, Some(i))); + } + tensors.extend( + layer + .mlp + .get_isq_tensors() + .into_iter() + .map(|m| (m, Some(i))) + .collect::>(), + ); + } + (tensors, &*self.mapper) + } + fn get_biases(&mut self) -> (Vec<(Option<&mut Tensor>, Option)>, &dyn DeviceMapper) { + (Vec::new(), &*self.mapper) + } +} + +impl NormalModel for Llama { + fn forward( + &self, + input_ids: &Tensor, + seqlen_offsets: &[usize], + start_offsets_kernel: Tensor, + context_lens: Vec<(usize, usize)>, + _position_ids: Vec, + metadata: Option<(Vec<(Tensor, Tensor)>, &mut PagedAttentionInputMetadata)>, + ) -> Result { + self.forward( + input_ids, + seqlen_offsets, + start_offsets_kernel, + context_lens, + metadata, + ) + } + fn xlora_forward( + &self, + _input_ids: &Tensor, + _input_ids_full: &Tensor, + _seqlen_offsets: &[usize], + _seqlen_offsets_full: &[usize], + _start_offsets_kernel: Tensor, + _start_offsets_kernel_full: Tensor, + _no_kv_cache: bool, + _non_granular_state: &Option, + _context_lens: Vec<(usize, usize)>, + _position_ids: Vec, + ) -> Result { + unimplemented!() + } + fn cache(&self) -> &crate::pipeline::Cache { + &self.kv_caches[0] + // &self.kv_cache + } + fn device(&self) -> &Device { + &self.device + } + fn is_xlora(&self) -> bool { + false + } + fn max_seq_len(&self) -> usize { + self.blocks[0].attn.max_seq_len + } + fn config(&self) -> &ModelConfigMetadata { + &self.cfg + } +} + +impl AnyMoeBaseModelMixin for Llama { + fn get_mlps(&self) -> Vec<&dyn MlpLayer> { + let mut mlps = Vec::new(); + for layer in &self.blocks { + mlps.push(&*layer.mlp); + } + mlps + } + fn get_mlps_mut(&mut self) -> Vec<&mut Box> { + let mut mlps = Vec::new(); + for layer in &mut self.blocks { + mlps.push(&mut layer.mlp); + } + mlps + } + fn create_anymoe_layers( + &mut self, + additional_vbs: Vec, + config: AnyMoeConfig, + (prefix, mlp): (String, String), + mut layers: Vec, + expert_type: AnyMoeExpertType, + gate_vb: Option, + ) -> Result<()> { + let mut experts: Vec>> = Vec::new(); + if layers.is_empty() { + layers = (0..self.blocks.len()).collect::>(); + } + for _ in 0..layers.len() { + experts.push(Vec::new()); + } + for vb in additional_vbs { + let vb = vb.pp(&prefix); + for (layer, row) in experts.iter_mut().enumerate() { + if !layers.contains(&layer) { + continue; + } + + let intermediate_size = self.blocks[layer].mlp.get_params()[1]; + let hidden_size = self.blocks[layer].mlp.get_params()[0]; + match expert_type { + 
AnyMoeExpertType::FineTuned => { + let (dtype, device) = self.blocks[layer].mlp.dtype_device(); + row.push(Box::new(Mlp::load( + vb.pp(layer).pp(&mlp).set_dtype(dtype).set_device(device), + &Config { + intermediate_size: self.blocks[layer].mlp.get_params()[1], + hidden_size: self.blocks[layer].mlp.get_params()[0], + ..Default::default() + }, + )?)); + } + AnyMoeExpertType::LoraAdapter { + rank, + alpha, + ref target_modules, + } => { + let vb_mlp = vb.pp(layer).pp(&mlp); + + let c_fc1_delta = if target_modules.contains(&"c_fc1".to_string()) { + Some(get_delta_from_lora_ab!( + vb_mlp, + rank, + alpha, + (hidden_size, intermediate_size), + "c_fc1" + )) + } else { + None + }; + let c_fc2_delta = if target_modules.contains(&"c_fc2".to_string()) { + Some(get_delta_from_lora_ab!( + vb_mlp, + rank, + alpha, + (hidden_size, intermediate_size), + "c_fc2" + )) + } else { + None + }; + let c_proj_delta = if target_modules.contains(&"c_proj".to_string()) { + Some(get_delta_from_lora_ab!( + vb_mlp, + rank, + alpha, + (intermediate_size, hidden_size), + "c_proj" + )) + } else { + None + }; + + row.push(self.blocks[layer].mlp.new_added_delta(vec![ + c_fc1_delta, + c_fc2_delta, + c_proj_delta, + ])?); + } + } + } + } + for (layer, expert) in layers.into_iter().zip(experts) { + let mut experts_all = vec![self.blocks[layer].mlp.clone()]; + experts_all.extend(expert); + let (dtype, device) = self.blocks[layer].mlp.dtype_device(); + self.blocks[layer].mlp = Box::new(MoeMlp::new( + experts_all, + config.clone(), + dtype, + &device, + layer, + gate_vb.as_ref(), + )?); + } + Ok(()) + } + fn amoe_supported(&self) -> bool { + true + } +} From 30f6b40a6a448f342a11c1a3ca13362bad50c9dc Mon Sep 17 00:00:00 2001 From: joshpopelka20 <107133849+joshpopelka20@users.noreply.github.com> Date: Thu, 22 Aug 2024 10:58:55 -0400 Subject: [PATCH 045/107] add IndexOp import --- mistralrs-core/src/models/llama.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mistralrs-core/src/models/llama.rs b/mistralrs-core/src/models/llama.rs index 5560ff0e4..ce5874b58 100644 --- a/mistralrs-core/src/models/llama.rs +++ b/mistralrs-core/src/models/llama.rs @@ -1,6 +1,6 @@ #![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)] -use candle_core::{quantized::QMatMul, DType, Device, Result, Tensor}; +use candle_core::{quantized::QMatMul, DType, Device, Result, Tensor, IndexOp}; use candle_nn::{embedding, Embedding, Module, VarBuilder}; use mistralrs_quant::{QuantMethod, QuantizedConfig}; use serde::Deserialize; @@ -430,7 +430,7 @@ impl Llama { let mut chunks: Vec = Vec::with_capacity(num_devices); // chunks.push(x.copy().unwrap()); - + // Handle the case where sequence length is less than number of devices if seq_len <= num_devices { for j in 0..seq_len { From 7e239761397fc252cb5d66b3dfd6fb1c2b07229e Mon Sep 17 00:00:00 2001 From: joshpopelka20 <107133849+joshpopelka20@users.noreply.github.com> Date: Mon, 26 Aug 2024 11:05:26 -0400 Subject: [PATCH 046/107] only use chunk on first block index --- mistralrs-core/src/models/llama.rs | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/mistralrs-core/src/models/llama.rs b/mistralrs-core/src/models/llama.rs index ce5874b58..040ea015f 100644 --- a/mistralrs-core/src/models/llama.rs +++ b/mistralrs-core/src/models/llama.rs @@ -469,7 +469,11 @@ impl Llama { // x = self.mapper.map(&chunks[0], block_idx)?; println!("x device {:?}", x.device()); println!("chunk device {:?}", chunks[0].device()); - x = self.mapper.map(chunks[0].copy().unwrap(), 
block_idx)?; + if block_idx == 0 { + x = self.mapper.map(chunks[0].copy().unwrap(), block_idx)?; + } else { + x = self.mapper.map(x, block_idx)?; + } x = block.forward( &x, &mask.clone().map(|m| m.to_device(x.device()).unwrap()), From f20005a336200e9fb849b965a2ccf2c5842d3828 Mon Sep 17 00:00:00 2001 From: joshpopelka20 <107133849+joshpopelka20@users.noreply.github.com> Date: Mon, 26 Aug 2024 11:33:07 -0400 Subject: [PATCH 047/107] split input into multiple chunks --- mistralrs-core/src/models/llama.rs | 58 +++++++++++++++++++++--------- 1 file changed, 41 insertions(+), 17 deletions(-) diff --git a/mistralrs-core/src/models/llama.rs b/mistralrs-core/src/models/llama.rs index 040ea015f..7365fbce0 100644 --- a/mistralrs-core/src/models/llama.rs +++ b/mistralrs-core/src/models/llama.rs @@ -464,28 +464,52 @@ impl Llama { chunks[0].dtype(), self.blocks[0].attn.num_attention_heads, )?; + + let mut processed_chunks = Vec::new(); + let mut target_device = &self.cuda_devices[0]; + for (block_idx, block) in self.blocks.iter().enumerate() { + // x = self.mapper.map(x, block_idx)?; // x = self.mapper.map(&chunks[0], block_idx)?; - println!("x device {:?}", x.device()); - println!("chunk device {:?}", chunks[0].device()); - if block_idx == 0 { - x = self.mapper.map(chunks[0].copy().unwrap(), block_idx)?; - } else { - x = self.mapper.map(x, block_idx)?; + // println!("x device {:?}", x.device()); + // println!("chunk device {:?}", chunks[0].device()); + for (chunk_idx, chunk) in chunks.iter().enumerate() { + let mut accumulated_attention: Option = None; + + if block_idx == 0 { + x = self.mapper.map(chunks[0].copy().unwrap(), block_idx)?; + } else { + x = self.mapper.map(x, block_idx)?; + } + x = block.forward( + &x, + &mask.clone().map(|m| m.to_device(x.device()).unwrap()), + seqlen_offsets, + start_offsets_kernel.clone(), + block_idx, + &mut cache, + metadata + .as_mut() + .map(|(kv_cache, metadata)| (kv_cache[block_idx].clone(), &mut **metadata)), + )?; + + // Accumulate attention results + if let Some(ref mut acc) = accumulated_attention { + *acc = acc.add(&x.to_device(acc.device())?)?; + } else { + accumulated_attention = Some(x); + } + + // Add the accumulated attention for this chunk to block_chunks + if let Some(acc) = accumulated_attention { + block_chunks.push(acc); + } } - x = block.forward( - &x, - &mask.clone().map(|m| m.to_device(x.device()).unwrap()), - seqlen_offsets, - start_offsets_kernel.clone(), - block_idx, - &mut cache, - metadata - .as_mut() - .map(|(kv_cache, metadata)| (kv_cache[block_idx].clone(), &mut **metadata)), - )?; + x = x.to_device(&target_device)?; + processed_chunks.push(x.clone()); } + x = candle_core::Tensor::cat(&processed_chunks, 1)?; let x = x.to_device(&self.device)?; let mut x = self.ln_f.forward(&x)?; if matches!(self.lm_head, QMatMul::QTensor(_)) { From 9d0b6ce32c9bd230af2f73ce374fb334bd156b6c Mon Sep 17 00:00:00 2001 From: joshpopelka20 <107133849+joshpopelka20@users.noreply.github.com> Date: Mon, 26 Aug 2024 11:34:26 -0400 Subject: [PATCH 048/107] add missing variable block_chunks --- mistralrs-core/src/models/llama.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/mistralrs-core/src/models/llama.rs b/mistralrs-core/src/models/llama.rs index 7365fbce0..8e3df9a55 100644 --- a/mistralrs-core/src/models/llama.rs +++ b/mistralrs-core/src/models/llama.rs @@ -470,6 +470,7 @@ impl Llama { for (block_idx, block) in self.blocks.iter().enumerate() { + let mut block_chunks = Vec::new(); // x = self.mapper.map(x, block_idx)?; // x = self.mapper.map(&chunks[0], 
block_idx)?; // println!("x device {:?}", x.device()); From 0c6a64c21e770715448cbcfc6961af847cdb0774 Mon Sep 17 00:00:00 2001 From: joshpopelka20 <107133849+joshpopelka20@users.noreply.github.com> Date: Mon, 26 Aug 2024 11:36:56 -0400 Subject: [PATCH 049/107] use each chunk first --- mistralrs-core/src/models/llama.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mistralrs-core/src/models/llama.rs b/mistralrs-core/src/models/llama.rs index 8e3df9a55..56cb00980 100644 --- a/mistralrs-core/src/models/llama.rs +++ b/mistralrs-core/src/models/llama.rs @@ -479,7 +479,8 @@ impl Llama { let mut accumulated_attention: Option = None; if block_idx == 0 { - x = self.mapper.map(chunks[0].copy().unwrap(), block_idx)?; + // x = self.mapper.map(chunks[0].copy().unwrap(), block_idx)?; + let mut x = self.mapper.map(chunk.clone(), block_idx)?.clone(); } else { x = self.mapper.map(x, block_idx)?; } From 86e1e549f7856506290b0108eb01ecac212dd3ac Mon Sep 17 00:00:00 2001 From: joshpopelka20 <107133849+joshpopelka20@users.noreply.github.com> Date: Mon, 26 Aug 2024 11:37:48 -0400 Subject: [PATCH 050/107] clone x in accumulated attention --- mistralrs-core/src/models/llama.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mistralrs-core/src/models/llama.rs b/mistralrs-core/src/models/llama.rs index 56cb00980..9702fd529 100644 --- a/mistralrs-core/src/models/llama.rs +++ b/mistralrs-core/src/models/llama.rs @@ -500,7 +500,7 @@ impl Llama { if let Some(ref mut acc) = accumulated_attention { *acc = acc.add(&x.to_device(acc.device())?)?; } else { - accumulated_attention = Some(x); + accumulated_attention = Some(x.clone()); } // Add the accumulated attention for this chunk to block_chunks From 535e5c7299c5c3653c16a393929d699f031ffb00 Mon Sep 17 00:00:00 2001 From: joshpopelka20 <107133849+joshpopelka20@users.noreply.github.com> Date: Mon, 26 Aug 2024 12:52:52 -0400 Subject: [PATCH 051/107] change mapper with block_chunks --- mistralrs-core/src/models/llama.rs | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/mistralrs-core/src/models/llama.rs b/mistralrs-core/src/models/llama.rs index 9702fd529..8aa323c73 100644 --- a/mistralrs-core/src/models/llama.rs +++ b/mistralrs-core/src/models/llama.rs @@ -361,7 +361,8 @@ impl Block { metadata, )? + residual)?; let residual = &x; - let x = (self.mlp.forward(&self.rms_2.forward(&x)?)? + residual)?; + // let x = (self.mlp.forward(&self.rms_2.forward(&x)?)? + residual)?; + let x = self.rms_2.forward(&x)?; Ok(x) } @@ -478,12 +479,12 @@ impl Llama { for (chunk_idx, chunk) in chunks.iter().enumerate() { let mut accumulated_attention: Option = None; - if block_idx == 0 { - // x = self.mapper.map(chunks[0].copy().unwrap(), block_idx)?; - let mut x = self.mapper.map(chunk.clone(), block_idx)?.clone(); + let x = if block_idx == 0 { + self.mapper.map(chunk.clone(), block_idx)? } else { - x = self.mapper.map(x, block_idx)?; - } + self.mapper.map(block_chunks[chunk_idx].clone(), block_idx)? 
+ }; + x = block.forward( &x, &mask.clone().map(|m| m.to_device(x.device()).unwrap()), @@ -508,6 +509,11 @@ impl Llama { block_chunks.push(acc); } } + + // do feedforward after attention has been run for each chunk + let residual = x.clone(); + let mut x = block.mlp.forward(&x)?; + x = (x + &residual)?; x = x.to_device(&target_device)?; processed_chunks.push(x.clone()); } From 8d55784f8d6b4e242cc1c47915063697a22b676c Mon Sep 17 00:00:00 2001 From: joshpopelka20 <107133849+joshpopelka20@users.noreply.github.com> Date: Mon, 26 Aug 2024 12:55:33 -0400 Subject: [PATCH 052/107] give block chunks a type --- mistralrs-core/src/models/llama.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mistralrs-core/src/models/llama.rs b/mistralrs-core/src/models/llama.rs index 8aa323c73..e80043ac9 100644 --- a/mistralrs-core/src/models/llama.rs +++ b/mistralrs-core/src/models/llama.rs @@ -471,7 +471,7 @@ impl Llama { for (block_idx, block) in self.blocks.iter().enumerate() { - let mut block_chunks = Vec::new(); + let mut block_chunks: Vec = Vec::new(); // x = self.mapper.map(x, block_idx)?; // x = self.mapper.map(&chunks[0], block_idx)?; // println!("x device {:?}", x.device()); @@ -479,12 +479,12 @@ impl Llama { for (chunk_idx, chunk) in chunks.iter().enumerate() { let mut accumulated_attention: Option = None; - let x = if block_idx == 0 { + let mut x = if block_idx == 0 { self.mapper.map(chunk.clone(), block_idx)? } else { self.mapper.map(block_chunks[chunk_idx].clone(), block_idx)? }; - + x = block.forward( &x, &mask.clone().map(|m| m.to_device(x.device()).unwrap()), From f7381057030ccd24285063131df3205ee8432b64 Mon Sep 17 00:00:00 2001 From: joshpopelka20 <107133849+joshpopelka20@users.noreply.github.com> Date: Mon, 26 Aug 2024 13:00:00 -0400 Subject: [PATCH 053/107] make as type tensor --- mistralrs-core/src/models/llama.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mistralrs-core/src/models/llama.rs b/mistralrs-core/src/models/llama.rs index e80043ac9..5ece4f8ce 100644 --- a/mistralrs-core/src/models/llama.rs +++ b/mistralrs-core/src/models/llama.rs @@ -471,7 +471,7 @@ impl Llama { for (block_idx, block) in self.blocks.iter().enumerate() { - let mut block_chunks: Vec = Vec::new(); + let mut block_chunks: Vec = Vec::new(); // x = self.mapper.map(x, block_idx)?; // x = self.mapper.map(&chunks[0], block_idx)?; // println!("x device {:?}", x.device()); From 4addbb5edfa7e7790ce70f7ae822e66393fb7ee4 Mon Sep 17 00:00:00 2001 From: joshpopelka20 <107133849+joshpopelka20@users.noreply.github.com> Date: Mon, 26 Aug 2024 13:07:42 -0400 Subject: [PATCH 054/107] move block chunks --- mistralrs-core/src/models/llama.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mistralrs-core/src/models/llama.rs b/mistralrs-core/src/models/llama.rs index 5ece4f8ce..37d8dc702 100644 --- a/mistralrs-core/src/models/llama.rs +++ b/mistralrs-core/src/models/llama.rs @@ -469,9 +469,10 @@ impl Llama { let mut processed_chunks = Vec::new(); let mut target_device = &self.cuda_devices[0]; + let mut block_chunks: Vec = Vec::new(); + for (block_idx, block) in self.blocks.iter().enumerate() { - let mut block_chunks: Vec = Vec::new(); // x = self.mapper.map(x, block_idx)?; // x = self.mapper.map(&chunks[0], block_idx)?; // println!("x device {:?}", x.device()); From d6ffb1015355c9d5a0f6bf2bdb9d115fdf2eb898 Mon Sep 17 00:00:00 2001 From: joshpopelka20 <107133849+joshpopelka20@users.noreply.github.com> Date: Mon, 26 Aug 2024 13:16:56 -0400 Subject: [PATCH 055/107] add to 
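Patch 051 above moves the MLP out of `Block::forward`, so only the attention half runs once per chunk; the chunks are then concatenated and the feed-forward plus residual run once over the full sequence. A hedged sketch of that reordering, with `attn_half` and `mlp` as stand-in closures rather than the crate's real types:

use candle_core::{Result, Tensor};

// Run the attention half of a block once per chunk, gather the results back on
// one device, then apply the feed-forward and residual once over the full
// sequence. Assumes at least one chunk and that chunks were split on dim 1.
fn block_forward_chunked(
    chunks: &[Tensor],
    attn_half: impl Fn(&Tensor) -> Result<Tensor>, // rms_1 -> attention -> rms_2
    mlp: impl Fn(&Tensor) -> Result<Tensor>,
) -> Result<Tensor> {
    let dev = chunks[0].device().clone();
    let mut attended = Vec::with_capacity(chunks.len());
    for chunk in chunks {
        attended.push(attn_half(chunk)?.to_device(&dev)?);
    }
    let x = Tensor::cat(&attended, 1)?; // reassemble the sequence axis
    let residual = x.clone();
    let x = (mlp(&x)? + residual)?; // feed-forward runs once, not per chunk
    Ok(x)
}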
accumulated attention --- mistralrs-core/src/models/llama.rs | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/mistralrs-core/src/models/llama.rs b/mistralrs-core/src/models/llama.rs index 37d8dc702..b00d2c29b 100644 --- a/mistralrs-core/src/models/llama.rs +++ b/mistralrs-core/src/models/llama.rs @@ -470,7 +470,7 @@ impl Llama { let mut target_device = &self.cuda_devices[0]; let mut block_chunks: Vec = Vec::new(); - + for (block_idx, block) in self.blocks.iter().enumerate() { // x = self.mapper.map(x, block_idx)?; @@ -500,7 +500,8 @@ impl Llama { // Accumulate attention results if let Some(ref mut acc) = accumulated_attention { - *acc = acc.add(&x.to_device(acc.device())?)?; + // *acc = acc.add(&x.to_device(acc.device())?)?; + *acc = acc.add(x); } else { accumulated_attention = Some(x.clone()); } From 61b9b8a16acf0eb1fe7a418bd4d3e4af1263a043 Mon Sep 17 00:00:00 2001 From: joshpopelka20 <107133849+joshpopelka20@users.noreply.github.com> Date: Mon, 26 Aug 2024 13:19:31 -0400 Subject: [PATCH 056/107] unwrap x --- mistralrs-core/src/models/llama.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mistralrs-core/src/models/llama.rs b/mistralrs-core/src/models/llama.rs index b00d2c29b..8a9111fd0 100644 --- a/mistralrs-core/src/models/llama.rs +++ b/mistralrs-core/src/models/llama.rs @@ -501,7 +501,7 @@ impl Llama { // Accumulate attention results if let Some(ref mut acc) = accumulated_attention { // *acc = acc.add(&x.to_device(acc.device())?)?; - *acc = acc.add(x); + *acc = acc.add(x)?; } else { accumulated_attention = Some(x.clone()); } From c1cc882cb7696bee86fa73a3532dc81278e1ccb5 Mon Sep 17 00:00:00 2001 From: joshpopelka20 <107133849+joshpopelka20@users.noreply.github.com> Date: Mon, 26 Aug 2024 13:20:05 -0400 Subject: [PATCH 057/107] &tensor --- mistralrs-core/src/models/llama.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mistralrs-core/src/models/llama.rs b/mistralrs-core/src/models/llama.rs index 8a9111fd0..c12b42a84 100644 --- a/mistralrs-core/src/models/llama.rs +++ b/mistralrs-core/src/models/llama.rs @@ -501,7 +501,7 @@ impl Llama { // Accumulate attention results if let Some(ref mut acc) = accumulated_attention { // *acc = acc.add(&x.to_device(acc.device())?)?; - *acc = acc.add(x)?; + *acc = acc.add(&x)?; } else { accumulated_attention = Some(x.clone()); } From f87ead198d22ed19895bfe4e40f785db0894d1c3 Mon Sep 17 00:00:00 2001 From: joshpopelka20 <107133849+joshpopelka20@users.noreply.github.com> Date: Mon, 26 Aug 2024 15:10:13 -0400 Subject: [PATCH 058/107] fix block_chunks --- mistralrs-core/src/models/llama.rs | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/mistralrs-core/src/models/llama.rs b/mistralrs-core/src/models/llama.rs index c12b42a84..1feb36541 100644 --- a/mistralrs-core/src/models/llama.rs +++ b/mistralrs-core/src/models/llama.rs @@ -478,14 +478,12 @@ impl Llama { // println!("x device {:?}", x.device()); // println!("chunk device {:?}", chunks[0].device()); for (chunk_idx, chunk) in chunks.iter().enumerate() { - let mut accumulated_attention: Option = None; - let mut x = if block_idx == 0 { self.mapper.map(chunk.clone(), block_idx)? } else { self.mapper.map(block_chunks[chunk_idx].clone(), block_idx)? 
}; - + x = block.forward( &x, &mask.clone().map(|m| m.to_device(x.device()).unwrap()), @@ -497,20 +495,22 @@ impl Llama { .as_mut() .map(|(kv_cache, metadata)| (kv_cache[block_idx].clone(), &mut **metadata)), )?; - + // Accumulate attention results - if let Some(ref mut acc) = accumulated_attention { - // *acc = acc.add(&x.to_device(acc.device())?)?; - *acc = acc.add(&x)?; + if block_chunks.len() <= chunk_idx { + block_chunks.push(x); } else { - accumulated_attention = Some(x.clone()); + block_chunks[chunk_idx] = block_chunks[chunk_idx].add(&x)?; } + } - // Add the accumulated attention for this chunk to block_chunks - if let Some(acc) = accumulated_attention { - block_chunks.push(acc); - } - } + // Concatenate chunks for this block + let block_chunks: Vec = block_chunks + .into_iter() + .map(|chunk| chunk.to_device(&target_device)) + .collect::, _>>()?; + + let mut x = candle_core::Tensor::cat(&block_chunks, 1)?; // do feedforward after attention has been run for each chunk let residual = x.clone(); From f665810d30848a1d214c048062da38a8977f8892 Mon Sep 17 00:00:00 2001 From: joshpopelka20 <107133849+joshpopelka20@users.noreply.github.com> Date: Mon, 26 Aug 2024 15:11:48 -0400 Subject: [PATCH 059/107] make generic type --- mistralrs-core/src/models/llama.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mistralrs-core/src/models/llama.rs b/mistralrs-core/src/models/llama.rs index 1feb36541..cbb150778 100644 --- a/mistralrs-core/src/models/llama.rs +++ b/mistralrs-core/src/models/llama.rs @@ -508,7 +508,7 @@ impl Llama { let block_chunks: Vec = block_chunks .into_iter() .map(|chunk| chunk.to_device(&target_device)) - .collect::, _>>()?; + .collect::>()?; let mut x = candle_core::Tensor::cat(&block_chunks, 1)?; From 8140413c6c35447feea17899296f26c8a7eb85f3 Mon Sep 17 00:00:00 2001 From: joshpopelka20 <107133849+joshpopelka20@users.noreply.github.com> Date: Mon, 26 Aug 2024 15:13:11 -0400 Subject: [PATCH 060/107] fix blocks_chunks to device --- mistralrs-core/src/models/llama.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mistralrs-core/src/models/llama.rs b/mistralrs-core/src/models/llama.rs index cbb150778..02d63f01e 100644 --- a/mistralrs-core/src/models/llama.rs +++ b/mistralrs-core/src/models/llama.rs @@ -508,7 +508,7 @@ impl Llama { let block_chunks: Vec = block_chunks .into_iter() .map(|chunk| chunk.to_device(&target_device)) - .collect::>()?; + .collect()?; let mut x = candle_core::Tensor::cat(&block_chunks, 1)?; From 23af80c27bba9663b2b9eaaa7d60d873625fa32c Mon Sep 17 00:00:00 2001 From: joshpopelka20 <107133849+joshpopelka20@users.noreply.github.com> Date: Mon, 26 Aug 2024 15:13:58 -0400 Subject: [PATCH 061/107] another fix for concat block_chunks --- mistralrs-core/src/models/llama.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mistralrs-core/src/models/llama.rs b/mistralrs-core/src/models/llama.rs index 02d63f01e..f05bb7173 100644 --- a/mistralrs-core/src/models/llama.rs +++ b/mistralrs-core/src/models/llama.rs @@ -508,7 +508,7 @@ impl Llama { let block_chunks: Vec = block_chunks .into_iter() .map(|chunk| chunk.to_device(&target_device)) - .collect()?; + .collect::>()?; let mut x = candle_core::Tensor::cat(&block_chunks, 1)?; From dd689e3c38882d457c628dc9f3412b314eee0fee Mon Sep 17 00:00:00 2001 From: joshpopelka20 <107133849+joshpopelka20@users.noreply.github.com> Date: Mon, 26 Aug 2024 15:14:32 -0400 Subject: [PATCH 062/107] remove ? 
operator --- mistralrs-core/src/models/llama.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mistralrs-core/src/models/llama.rs b/mistralrs-core/src/models/llama.rs index f05bb7173..6db8651cb 100644 --- a/mistralrs-core/src/models/llama.rs +++ b/mistralrs-core/src/models/llama.rs @@ -508,7 +508,7 @@ impl Llama { let block_chunks: Vec = block_chunks .into_iter() .map(|chunk| chunk.to_device(&target_device)) - .collect::>()?; + .collect::>(); let mut x = candle_core::Tensor::cat(&block_chunks, 1)?; From 2106933b5644b7b31292dd86bfe69ecfe8f07e4a Mon Sep 17 00:00:00 2001 From: joshpopelka20 <107133849+joshpopelka20@users.noreply.github.com> Date: Mon, 26 Aug 2024 15:16:17 -0400 Subject: [PATCH 063/107] replace with try_collect --- mistralrs-core/src/models/llama.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mistralrs-core/src/models/llama.rs b/mistralrs-core/src/models/llama.rs index 6db8651cb..46bfd25a6 100644 --- a/mistralrs-core/src/models/llama.rs +++ b/mistralrs-core/src/models/llama.rs @@ -508,7 +508,7 @@ impl Llama { let block_chunks: Vec = block_chunks .into_iter() .map(|chunk| chunk.to_device(&target_device)) - .collect::>(); + .try_collect()?; let mut x = candle_core::Tensor::cat(&block_chunks, 1)?; From 71fdd71513b86e411ef1504be21442ec78f52675 Mon Sep 17 00:00:00 2001 From: joshpopelka20 <107133849+joshpopelka20@users.noreply.github.com> Date: Mon, 26 Aug 2024 15:17:50 -0400 Subject: [PATCH 064/107] change type of block_chunks --- mistralrs-core/src/models/llama.rs | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/mistralrs-core/src/models/llama.rs b/mistralrs-core/src/models/llama.rs index 46bfd25a6..ba423b473 100644 --- a/mistralrs-core/src/models/llama.rs +++ b/mistralrs-core/src/models/llama.rs @@ -505,10 +505,12 @@ impl Llama { } // Concatenate chunks for this block - let block_chunks: Vec = block_chunks + let block_chunks: Result> = block_chunks .into_iter() .map(|chunk| chunk.to_device(&target_device)) - .try_collect()?; + .collect(); + + let block_chunks = block_chunks?; // Propagate any errors let mut x = candle_core::Tensor::cat(&block_chunks, 1)?; From 0b129fad5ef069ba90e651d236d801ddf458b577 Mon Sep 17 00:00:00 2001 From: joshpopelka20 <107133849+joshpopelka20@users.noreply.github.com> Date: Mon, 26 Aug 2024 15:18:57 -0400 Subject: [PATCH 065/107] clone to move blcok_chunks --- mistralrs-core/src/models/llama.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/mistralrs-core/src/models/llama.rs b/mistralrs-core/src/models/llama.rs index ba423b473..9f322cd81 100644 --- a/mistralrs-core/src/models/llama.rs +++ b/mistralrs-core/src/models/llama.rs @@ -506,6 +506,7 @@ impl Llama { // Concatenate chunks for this block let block_chunks: Result> = block_chunks + .clone() .into_iter() .map(|chunk| chunk.to_device(&target_device)) .collect(); From c09b4596616ffbffe4ee17efb85db963afed235d Mon Sep 17 00:00:00 2001 From: joshpopelka20 <107133849+joshpopelka20@users.noreply.github.com> Date: Mon, 26 Aug 2024 15:30:51 -0400 Subject: [PATCH 066/107] remove add --- mistralrs-core/src/models/llama.rs | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/mistralrs-core/src/models/llama.rs b/mistralrs-core/src/models/llama.rs index 9f322cd81..208efdc7b 100644 --- a/mistralrs-core/src/models/llama.rs +++ b/mistralrs-core/src/models/llama.rs @@ -499,9 +499,10 @@ impl Llama { // Accumulate attention results if block_chunks.len() <= chunk_idx { block_chunks.push(x); - } else { - block_chunks[chunk_idx] = 
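Patches 059 through 064 above converge on the standard way to turn an iterator of fallible device moves into a single `Result`: collect into `Result<Vec<_>>` and apply `?` once. A small sketch under that assumption, using only candle's `to_device` and `Tensor::cat`:

use candle_core::{Device, Result, Tensor};

// Move every chunk to one target device, propagating the first failure,
// then concatenate along the sequence axis.
fn gather_on(chunks: Vec<Tensor>, target: &Device) -> Result<Tensor> {
    let moved = chunks
        .into_iter()
        .map(|c| c.to_device(target))
        .collect::<Result<Vec<_>>>()?; // one fallible collect, no per-item unwrap
    Tensor::cat(&moved, 1)
}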
block_chunks[chunk_idx].add(&x)?; - } + } + // else { + // block_chunks[chunk_idx] = block_chunks[chunk_idx].add(&x)?; + // } } // Concatenate chunks for this block From d201134d41b518ad63aa3860569de37ceb35eb13 Mon Sep 17 00:00:00 2001 From: joshpopelka20 <107133849+joshpopelka20@users.noreply.github.com> Date: Tue, 27 Aug 2024 10:20:10 -0400 Subject: [PATCH 067/107] switch to four devices --- mistralrs-core/src/models/llama.rs | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/mistralrs-core/src/models/llama.rs b/mistralrs-core/src/models/llama.rs index 208efdc7b..96a28f4d2 100644 --- a/mistralrs-core/src/models/llama.rs +++ b/mistralrs-core/src/models/llama.rs @@ -426,7 +426,7 @@ impl Llama { let mut x = self.wte.forward(input_ids)?; let (batch_size, seq_len, hidden_size) = x.dims3()?; - let num_devices = 1; + let num_devices = 4; let chunk_size = seq_len / num_devices; let mut chunks: Vec = Vec::with_capacity(num_devices); @@ -499,10 +499,9 @@ impl Llama { // Accumulate attention results if block_chunks.len() <= chunk_idx { block_chunks.push(x); - } - // else { - // block_chunks[chunk_idx] = block_chunks[chunk_idx].add(&x)?; - // } + } else { + block_chunks[chunk_idx] = &x; + } } // Concatenate chunks for this block @@ -548,7 +547,7 @@ impl Llama { ); } - let num_devices = 1; + let num_devices = 4; let mut cuda_devices = Vec::with_capacity(num_devices); let mapper = normal_loading_metadata .mapper From c5b4fde88f45d2e879dfbf79c7e0039691133fcc Mon Sep 17 00:00:00 2001 From: joshpopelka20 <107133849+joshpopelka20@users.noreply.github.com> Date: Tue, 27 Aug 2024 10:55:02 -0400 Subject: [PATCH 068/107] fix compile error with & --- mistralrs-core/src/models/llama.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mistralrs-core/src/models/llama.rs b/mistralrs-core/src/models/llama.rs index 96a28f4d2..64753b180 100644 --- a/mistralrs-core/src/models/llama.rs +++ b/mistralrs-core/src/models/llama.rs @@ -500,7 +500,7 @@ impl Llama { if block_chunks.len() <= chunk_idx { block_chunks.push(x); } else { - block_chunks[chunk_idx] = &x; + block_chunks[chunk_idx] = x; } } From 79f76062b1bf01cd1716e7b0663d4f978aebcc53 Mon Sep 17 00:00:00 2001 From: joshpopelka20 <107133849+joshpopelka20@users.noreply.github.com> Date: Wed, 28 Aug 2024 10:12:57 -0400 Subject: [PATCH 069/107] uodate metadata device --- mistralrs-core/src/models/llama.rs | 27 ++++++++++++++++++++++++--- 1 file changed, 24 insertions(+), 3 deletions(-) diff --git a/mistralrs-core/src/models/llama.rs b/mistralrs-core/src/models/llama.rs index 64753b180..39db8256b 100644 --- a/mistralrs-core/src/models/llama.rs +++ b/mistralrs-core/src/models/llama.rs @@ -483,17 +483,38 @@ impl Llama { } else { self.mapper.map(block_chunks[chunk_idx].clone(), block_idx)? 
}; + + let device_chunk = chunk.device(); + // x = block.forward( + // &x, + // &mask.clone().map(|m| m.to_device(x.device()).unwrap()), + // seqlen_offsets, + // start_offsets_kernel.clone(), + // block_idx, + // &mut cache, + // metadata + // .as_mut() + // .map(|(kv_cache, metadata)| (kv_cache[block_idx].clone(), &mut **metadata)), + // )?; + x = block.forward( &x, - &mask.clone().map(|m| m.to_device(x.device()).unwrap()), + &mask.clone().map(|m| m.to_device(&device_chunk).unwrap()), seqlen_offsets, - start_offsets_kernel.clone(), + start_offsets_kernel.clone().to_device(&device_chunk)?, block_idx, + // &mut cache_on_chunk_device, &mut cache, metadata .as_mut() - .map(|(kv_cache, metadata)| (kv_cache[block_idx].clone(), &mut **metadata)), + .map(|(kv_cache, metadata)| { + let (tensor1, tensor2) = kv_cache[block_idx].clone(); + ( + (tensor1.to_device(&device_chunk).unwrap(), tensor2.to_device(&device_chunk).unwrap()), + &mut **metadata + ) + }), )?; // Accumulate attention results From f50a159875ae8bba7d6746b40e942e25d8dc6a85 Mon Sep 17 00:00:00 2001 From: joshpopelka20 <107133849+joshpopelka20@users.noreply.github.com> Date: Wed, 28 Aug 2024 11:49:04 -0400 Subject: [PATCH 070/107] add kv cache rotation --- mistralrs-core/src/models/llama.rs | 84 ++++++++++++++++-------------- 1 file changed, 45 insertions(+), 39 deletions(-) diff --git a/mistralrs-core/src/models/llama.rs b/mistralrs-core/src/models/llama.rs index 39db8256b..2fa90c7aa 100644 --- a/mistralrs-core/src/models/llama.rs +++ b/mistralrs-core/src/models/llama.rs @@ -454,7 +454,7 @@ impl Llama { } } - let mut cache = self.kv_caches[0].lock(); + // let mut cache = self.kv_caches[0].lock(); let mask = CausalMasker.make_causal_mask_as_attn_bias( input_ids, metadata @@ -484,44 +484,50 @@ impl Llama { self.mapper.map(block_chunks[chunk_idx].clone(), block_idx)? 
}; - let device_chunk = chunk.device(); - - // x = block.forward( - // &x, - // &mask.clone().map(|m| m.to_device(x.device()).unwrap()), - // seqlen_offsets, - // start_offsets_kernel.clone(), - // block_idx, - // &mut cache, - // metadata - // .as_mut() - // .map(|(kv_cache, metadata)| (kv_cache[block_idx].clone(), &mut **metadata)), - // )?; - - x = block.forward( - &x, - &mask.clone().map(|m| m.to_device(&device_chunk).unwrap()), - seqlen_offsets, - start_offsets_kernel.clone().to_device(&device_chunk)?, - block_idx, - // &mut cache_on_chunk_device, - &mut cache, - metadata - .as_mut() - .map(|(kv_cache, metadata)| { - let (tensor1, tensor2) = kv_cache[block_idx].clone(); - ( - (tensor1.to_device(&device_chunk).unwrap(), tensor2.to_device(&device_chunk).unwrap()), - &mut **metadata - ) - }), - )?; - - // Accumulate attention results - if block_chunks.len() <= chunk_idx { - block_chunks.push(x); - } else { - block_chunks[chunk_idx] = x; + for cache_rotation in 0..num_caches { + let cache_idx = (chunk_idx + cache_rotation) % num_caches; + let kv_cache = &self.kv_caches[cache_idx]; + let mut cache = kv_cache.lock(); + + let device_chunk = chunk.device(); + + // x = block.forward( + // &x, + // &mask.clone().map(|m| m.to_device(x.device()).unwrap()), + // seqlen_offsets, + // start_offsets_kernel.clone(), + // block_idx, + // &mut cache, + // metadata + // .as_mut() + // .map(|(kv_cache, metadata)| (kv_cache[block_idx].clone(), &mut **metadata)), + // )?; + + x = block.forward( + &x, + &mask.clone().map(|m| m.to_device(&device_chunk).unwrap()), + seqlen_offsets, + start_offsets_kernel.clone().to_device(&device_chunk)?, + block_idx, + // &mut cache_on_chunk_device, + &mut cache, + metadata + .as_mut() + .map(|(kv_cache, metadata)| { + let (tensor1, tensor2) = kv_cache[block_idx].clone(); + ( + (tensor1.to_device(&device_chunk).unwrap(), tensor2.to_device(&device_chunk).unwrap()), + &mut **metadata + ) + }), + )?; + + // Accumulate attention results + if block_chunks.len() <= chunk_idx { + block_chunks.push(x); + } else { + block_chunks[chunk_idx] = x; + } } } From b913fee6199c805de5a8ffcef0fd2d31503f3a4b Mon Sep 17 00:00:00 2001 From: joshpopelka20 <107133849+joshpopelka20@users.noreply.github.com> Date: Wed, 28 Aug 2024 11:54:38 -0400 Subject: [PATCH 071/107] add missing num_caches --- mistralrs-core/src/models/llama.rs | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/mistralrs-core/src/models/llama.rs b/mistralrs-core/src/models/llama.rs index 2fa90c7aa..9fa647d80 100644 --- a/mistralrs-core/src/models/llama.rs +++ b/mistralrs-core/src/models/llama.rs @@ -455,17 +455,6 @@ impl Llama { } // let mut cache = self.kv_caches[0].lock(); - let mask = CausalMasker.make_causal_mask_as_attn_bias( - input_ids, - metadata - .as_ref() - .map(|(_, _)| &seqlen_offsets as &dyn PastKvLenCache) - .unwrap_or(&*cache as &dyn PastKvLenCache), - // x.dtype(), - chunks[0].dtype(), - self.blocks[0].attn.num_attention_heads, - )?; - let mut processed_chunks = Vec::new(); let mut target_device = &self.cuda_devices[0]; @@ -484,12 +473,26 @@ impl Llama { self.mapper.map(block_chunks[chunk_idx].clone(), block_idx)? 
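The rotation introduced in patch 070 is plain modular indexing: each chunk starts at its own per-device cache and then wraps around the remaining caches. The arithmetic can be checked in isolation with a short, standalone snippet:

// For a given chunk, visit every per-device cache once, starting at the cache
// with the same index as the chunk and wrapping around.
fn rotated_cache_indices(chunk_idx: usize, num_caches: usize) -> impl Iterator<Item = usize> {
    (0..num_caches).map(move |rotation| (chunk_idx + rotation) % num_caches)
}

fn main() {
    // With four per-device caches, chunk 2 visits caches 2, 3, 0, 1 in order.
    let order: Vec<usize> = rotated_cache_indices(2, 4).collect();
    assert_eq!(order, vec![2, 3, 0, 1]);
}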
}; + let num_caches = self.kv_caches.len(); + for cache_rotation in 0..num_caches { let cache_idx = (chunk_idx + cache_rotation) % num_caches; let kv_cache = &self.kv_caches[cache_idx]; let mut cache = kv_cache.lock(); let device_chunk = chunk.device(); + + let mask = CausalMasker.make_causal_mask_as_attn_bias( + input_ids, + metadata + .as_ref() + .map(|(_, _)| &seqlen_offsets as &dyn PastKvLenCache) + .unwrap_or(&*cache as &dyn PastKvLenCache), + // x.dtype(), + chunks[0].dtype(), + self.blocks[0].attn.num_attention_heads, + )?; + // x = block.forward( // &x, From 0a7b422763ca03778398351716c771a57d1228f3 Mon Sep 17 00:00:00 2001 From: joshpopelka20 <107133849+joshpopelka20@users.noreply.github.com> Date: Wed, 28 Aug 2024 11:59:41 -0400 Subject: [PATCH 072/107] fix compile error --- mistralrs-core/src/models/llama.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mistralrs-core/src/models/llama.rs b/mistralrs-core/src/models/llama.rs index 9fa647d80..e938011d5 100644 --- a/mistralrs-core/src/models/llama.rs +++ b/mistralrs-core/src/models/llama.rs @@ -507,7 +507,7 @@ impl Llama { // )?; x = block.forward( - &x, + &x.clone(), &mask.clone().map(|m| m.to_device(&device_chunk).unwrap()), seqlen_offsets, start_offsets_kernel.clone().to_device(&device_chunk)?, From c98dcb7682cf32944816257b77c94c7b2c654aab Mon Sep 17 00:00:00 2001 From: joshpopelka20 <107133849+joshpopelka20@users.noreply.github.com> Date: Wed, 28 Aug 2024 12:02:06 -0400 Subject: [PATCH 073/107] clone mapper --- mistralrs-core/src/models/llama.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mistralrs-core/src/models/llama.rs b/mistralrs-core/src/models/llama.rs index e938011d5..5e0756442 100644 --- a/mistralrs-core/src/models/llama.rs +++ b/mistralrs-core/src/models/llama.rs @@ -468,9 +468,9 @@ impl Llama { // println!("chunk device {:?}", chunks[0].device()); for (chunk_idx, chunk) in chunks.iter().enumerate() { let mut x = if block_idx == 0 { - self.mapper.map(chunk.clone(), block_idx)? + self.mapper.map(chunk.clone(), block_idx)?.clone(); } else { - self.mapper.map(block_chunks[chunk_idx].clone(), block_idx)? 
+ self.mapper.map(block_chunks[chunk_idx].clone(), block_idx)?.clone(); }; let num_caches = self.kv_caches.len(); @@ -507,7 +507,7 @@ impl Llama { // )?; x = block.forward( - &x.clone(), + &x, &mask.clone().map(|m| m.to_device(&device_chunk).unwrap()), seqlen_offsets, start_offsets_kernel.clone().to_device(&device_chunk)?, From 962f7440d7f5621d917fb73896271c858a7fcd85 Mon Sep 17 00:00:00 2001 From: joshpopelka20 <107133849+joshpopelka20@users.noreply.github.com> Date: Tue, 3 Sep 2024 10:53:19 -0400 Subject: [PATCH 074/107] remove clone --- mistralrs-core/src/models/llama.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mistralrs-core/src/models/llama.rs b/mistralrs-core/src/models/llama.rs index 5e0756442..471634765 100644 --- a/mistralrs-core/src/models/llama.rs +++ b/mistralrs-core/src/models/llama.rs @@ -468,9 +468,9 @@ impl Llama { // println!("chunk device {:?}", chunks[0].device()); for (chunk_idx, chunk) in chunks.iter().enumerate() { let mut x = if block_idx == 0 { - self.mapper.map(chunk.clone(), block_idx)?.clone(); + self.mapper.map(chunk, block_idx)?; } else { - self.mapper.map(block_chunks[chunk_idx].clone(), block_idx)?.clone(); + self.mapper.map(block_chunks[chunk_idx], block_idx)?; }; let num_caches = self.kv_caches.len(); From 7cd35035177eb7f6e485169e4fb2f96bb0757e48 Mon Sep 17 00:00:00 2001 From: joshpopelka20 <107133849+joshpopelka20@users.noreply.github.com> Date: Tue, 3 Sep 2024 10:55:37 -0400 Subject: [PATCH 075/107] clone reference --- mistralrs-core/src/models/llama.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mistralrs-core/src/models/llama.rs b/mistralrs-core/src/models/llama.rs index 471634765..0f189863d 100644 --- a/mistralrs-core/src/models/llama.rs +++ b/mistralrs-core/src/models/llama.rs @@ -468,7 +468,7 @@ impl Llama { // println!("chunk device {:?}", chunks[0].device()); for (chunk_idx, chunk) in chunks.iter().enumerate() { let mut x = if block_idx == 0 { - self.mapper.map(chunk, block_idx)?; + self.mapper.map(chunk.clone(), block_idx)?; } else { self.mapper.map(block_chunks[chunk_idx], block_idx)?; }; From ea04012862147387e60c6a1ec24141764a429971 Mon Sep 17 00:00:00 2001 From: joshpopelka20 <107133849+joshpopelka20@users.noreply.github.com> Date: Tue, 3 Sep 2024 10:57:59 -0400 Subject: [PATCH 076/107] return tensor --- mistralrs-core/src/models/llama.rs | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/mistralrs-core/src/models/llama.rs b/mistralrs-core/src/models/llama.rs index 0f189863d..422bc16ac 100644 --- a/mistralrs-core/src/models/llama.rs +++ b/mistralrs-core/src/models/llama.rs @@ -468,9 +468,13 @@ impl Llama { // println!("chunk device {:?}", chunks[0].device()); for (chunk_idx, chunk) in chunks.iter().enumerate() { let mut x = if block_idx == 0 { - self.mapper.map(chunk.clone(), block_idx)?; + let tensor = chunk.clone(); + self.mapper.map(&tensor, block_idx)?; + tensor } else { - self.mapper.map(block_chunks[chunk_idx], block_idx)?; + let tensor = block_chunks[chunk_idx].clone(); + self.mapper.map(&tensor, block_idx)?; + tensor }; let num_caches = self.kv_caches.len(); From a57e1c99c5c234f9f83f75249d5fa7eb589695ca Mon Sep 17 00:00:00 2001 From: joshpopelka20 <107133849+joshpopelka20@users.noreply.github.com> Date: Tue, 3 Sep 2024 10:58:39 -0400 Subject: [PATCH 077/107] remove borrow --- mistralrs-core/src/models/llama.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mistralrs-core/src/models/llama.rs b/mistralrs-core/src/models/llama.rs index 
422bc16ac..6d02ffebc 100644 --- a/mistralrs-core/src/models/llama.rs +++ b/mistralrs-core/src/models/llama.rs @@ -469,11 +469,11 @@ impl Llama { for (chunk_idx, chunk) in chunks.iter().enumerate() { let mut x = if block_idx == 0 { let tensor = chunk.clone(); - self.mapper.map(&tensor, block_idx)?; + self.mapper.map(tensor, block_idx)?; tensor } else { let tensor = block_chunks[chunk_idx].clone(); - self.mapper.map(&tensor, block_idx)?; + self.mapper.map(tensor, block_idx)?; tensor }; From b69edccb5ebe8f09649e5bdeb712115beb4c0a17 Mon Sep 17 00:00:00 2001 From: joshpopelka20 <107133849+joshpopelka20@users.noreply.github.com> Date: Tue, 3 Sep 2024 10:59:58 -0400 Subject: [PATCH 078/107] fix value moved --- mistralrs-core/src/models/llama.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mistralrs-core/src/models/llama.rs b/mistralrs-core/src/models/llama.rs index 6d02ffebc..ee345f944 100644 --- a/mistralrs-core/src/models/llama.rs +++ b/mistralrs-core/src/models/llama.rs @@ -469,11 +469,11 @@ impl Llama { for (chunk_idx, chunk) in chunks.iter().enumerate() { let mut x = if block_idx == 0 { let tensor = chunk.clone(); - self.mapper.map(tensor, block_idx)?; + self.mapper.map(tensor.clone(), block_idx)?; tensor } else { let tensor = block_chunks[chunk_idx].clone(); - self.mapper.map(tensor, block_idx)?; + self.mapper.map(tensor.clone(), block_idx)?; tensor }; From 7cfb29d51fd2435942e9a430e83b68c8e808519b Mon Sep 17 00:00:00 2001 From: joshpopelka20 <107133849+joshpopelka20@users.noreply.github.com> Date: Tue, 3 Sep 2024 11:01:50 -0400 Subject: [PATCH 079/107] borrow on accumulate --- mistralrs-core/src/models/llama.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mistralrs-core/src/models/llama.rs b/mistralrs-core/src/models/llama.rs index ee345f944..33b388cf8 100644 --- a/mistralrs-core/src/models/llama.rs +++ b/mistralrs-core/src/models/llama.rs @@ -531,9 +531,9 @@ impl Llama { // Accumulate attention results if block_chunks.len() <= chunk_idx { - block_chunks.push(x); + block_chunks.push(x.clone()); } else { - block_chunks[chunk_idx] = x; + block_chunks[chunk_idx] = x.clone(); } } } From da65eb2de5ebb33abf4b13216befcab3796b658a Mon Sep 17 00:00:00 2001 From: joshpopelka20 <107133849+joshpopelka20@users.noreply.github.com> Date: Tue, 3 Sep 2024 11:11:30 -0400 Subject: [PATCH 080/107] add logging --- mistralrs-core/src/models/llama.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mistralrs-core/src/models/llama.rs b/mistralrs-core/src/models/llama.rs index 33b388cf8..ff382b6cd 100644 --- a/mistralrs-core/src/models/llama.rs +++ b/mistralrs-core/src/models/llama.rs @@ -464,9 +464,10 @@ impl Llama { // x = self.mapper.map(x, block_idx)?; // x = self.mapper.map(&chunks[0], block_idx)?; - // println!("x device {:?}", x.device()); + println!("block_idx {:?}", block_idx); // println!("chunk device {:?}", chunks[0].device()); for (chunk_idx, chunk) in chunks.iter().enumerate() { + println!("chunk_idx {:?}", chunk_idx); let mut x = if block_idx == 0 { let tensor = chunk.clone(); self.mapper.map(tensor.clone(), block_idx)?; From 9c5cd3844fac3d882d31d80a66a6900a74bf9157 Mon Sep 17 00:00:00 2001 From: joshpopelka20 <107133849+joshpopelka20@users.noreply.github.com> Date: Tue, 3 Sep 2024 11:20:39 -0400 Subject: [PATCH 081/107] more logging --- mistralrs-core/src/models/llama.rs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/mistralrs-core/src/models/llama.rs b/mistralrs-core/src/models/llama.rs index ff382b6cd..4b7a0e467 100644 --- 
a/mistralrs-core/src/models/llama.rs +++ b/mistralrs-core/src/models/llama.rs @@ -483,6 +483,7 @@ impl Llama { for cache_rotation in 0..num_caches { let cache_idx = (chunk_idx + cache_rotation) % num_caches; let kv_cache = &self.kv_caches[cache_idx]; + println!("cache_idx {:?}", cache_idx); let mut cache = kv_cache.lock(); let device_chunk = chunk.device(); @@ -511,6 +512,7 @@ impl Llama { // .map(|(kv_cache, metadata)| (kv_cache[block_idx].clone(), &mut **metadata)), // )?; + println!("before block forward"); x = block.forward( &x, &mask.clone().map(|m| m.to_device(&device_chunk).unwrap()), @@ -530,6 +532,8 @@ impl Llama { }), )?; + println!("after block forward"); + // Accumulate attention results if block_chunks.len() <= chunk_idx { block_chunks.push(x.clone()); From cdd480de0da98aa69ff626f05946eaefe2101037 Mon Sep 17 00:00:00 2001 From: joshpopelka20 <107133849+joshpopelka20@users.noreply.github.com> Date: Tue, 3 Sep 2024 11:31:41 -0400 Subject: [PATCH 082/107] fix chunk to device chunk --- mistralrs-core/src/models/llama.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mistralrs-core/src/models/llama.rs b/mistralrs-core/src/models/llama.rs index 4b7a0e467..12dd37793 100644 --- a/mistralrs-core/src/models/llama.rs +++ b/mistralrs-core/src/models/llama.rs @@ -547,7 +547,7 @@ impl Llama { let block_chunks: Result> = block_chunks .clone() .into_iter() - .map(|chunk| chunk.to_device(&target_device)) + .map(|chunk| chunk.to_device(&device_chunk)) .collect(); let block_chunks = block_chunks?; // Propagate any errors From 4eb47755876f445f364ebd4f7512f4616505c913 Mon Sep 17 00:00:00 2001 From: joshpopelka20 <107133849+joshpopelka20@users.noreply.github.com> Date: Tue, 3 Sep 2024 11:34:17 -0400 Subject: [PATCH 083/107] remove concat block_chunks --- mistralrs-core/src/models/llama.rs | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/mistralrs-core/src/models/llama.rs b/mistralrs-core/src/models/llama.rs index 12dd37793..300be6b76 100644 --- a/mistralrs-core/src/models/llama.rs +++ b/mistralrs-core/src/models/llama.rs @@ -544,13 +544,13 @@ impl Llama { } // Concatenate chunks for this block - let block_chunks: Result> = block_chunks - .clone() - .into_iter() - .map(|chunk| chunk.to_device(&device_chunk)) - .collect(); + // let block_chunks: Result> = block_chunks + // .clone() + // .into_iter() + // .map(|chunk| chunk.to_device(&device_chunk)) + // .collect(); - let block_chunks = block_chunks?; // Propagate any errors + // let block_chunks = block_chunks?; // Propagate any errors let mut x = candle_core::Tensor::cat(&block_chunks, 1)?; From ee27e9825749625bf78693bb9ce81b608adf6f3a Mon Sep 17 00:00:00 2001 From: joshpopelka20 <107133849+joshpopelka20@users.noreply.github.com> Date: Tue, 3 Sep 2024 11:44:37 -0400 Subject: [PATCH 084/107] move cache to chunk device --- mistralrs-core/src/models/llama.rs | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/mistralrs-core/src/models/llama.rs b/mistralrs-core/src/models/llama.rs index 300be6b76..f62ceb539 100644 --- a/mistralrs-core/src/models/llama.rs +++ b/mistralrs-core/src/models/llama.rs @@ -488,16 +488,30 @@ impl Llama { let device_chunk = chunk.device(); + // Determine the original device of the cache + let original_cache_device = cache.iter().find_map(|opt| { + opt.as_ref().map(|(k, _)| k.device().clone()) + }).unwrap_or_else(|| device_chunk.clone()); + + // Move cache to chunk device + let mut cache_on_chunk_device: Vec<_> = cache.iter().map(|opt| { + 
opt.as_ref().map(|(k, v)| { + (k.to_device(device_chunk).unwrap(), v.to_device(device_chunk).unwrap()) + }) + }).collect(); + let mask = CausalMasker.make_causal_mask_as_attn_bias( input_ids, metadata .as_ref() .map(|(_, _)| &seqlen_offsets as &dyn PastKvLenCache) - .unwrap_or(&*cache as &dyn PastKvLenCache), + .unwrap_or(&*cache_on_chunk_device as &dyn PastKvLenCache), // x.dtype(), chunks[0].dtype(), self.blocks[0].attn.num_attention_heads, )?; + + // x = block.forward( @@ -520,7 +534,7 @@ impl Llama { start_offsets_kernel.clone().to_device(&device_chunk)?, block_idx, // &mut cache_on_chunk_device, - &mut cache, + &mut cache_on_chunk_device, metadata .as_mut() .map(|(kv_cache, metadata)| { From 57ae1d87311af0b84356a00e565026ab9993381b Mon Sep 17 00:00:00 2001 From: joshpopelka20 <107133849+joshpopelka20@users.noreply.github.com> Date: Tue, 3 Sep 2024 11:48:22 -0400 Subject: [PATCH 085/107] fix error in masker --- mistralrs-core/src/models/llama.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mistralrs-core/src/models/llama.rs b/mistralrs-core/src/models/llama.rs index f62ceb539..1266b30e9 100644 --- a/mistralrs-core/src/models/llama.rs +++ b/mistralrs-core/src/models/llama.rs @@ -505,7 +505,7 @@ impl Llama { metadata .as_ref() .map(|(_, _)| &seqlen_offsets as &dyn PastKvLenCache) - .unwrap_or(&*cache_on_chunk_device as &dyn PastKvLenCache), + .unwrap_or(&*cache as &dyn PastKvLenCache), // x.dtype(), chunks[0].dtype(), self.blocks[0].attn.num_attention_heads, From 34bf2d1c9ba86dc73bbde00107cc5107d4c19fa3 Mon Sep 17 00:00:00 2001 From: joshpopelka20 <107133849+joshpopelka20@users.noreply.github.com> Date: Tue, 3 Sep 2024 12:52:53 -0400 Subject: [PATCH 086/107] move all to block device --- mistralrs-core/src/models/llama.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mistralrs-core/src/models/llama.rs b/mistralrs-core/src/models/llama.rs index 1266b30e9..0d8c9f2d7 100644 --- a/mistralrs-core/src/models/llama.rs +++ b/mistralrs-core/src/models/llama.rs @@ -486,7 +486,7 @@ impl Llama { println!("cache_idx {:?}", cache_idx); let mut cache = kv_cache.lock(); - let device_chunk = chunk.device(); + let device_chunk = block.device(); // Determine the original device of the cache let original_cache_device = cache.iter().find_map(|opt| { From a20d7a4b526990af3e6e924999cb6d51942bf219 Mon Sep 17 00:00:00 2001 From: joshpopelka20 <107133849+joshpopelka20@users.noreply.github.com> Date: Tue, 3 Sep 2024 12:56:21 -0400 Subject: [PATCH 087/107] change to block device --- mistralrs-core/src/models/llama.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mistralrs-core/src/models/llama.rs b/mistralrs-core/src/models/llama.rs index 0d8c9f2d7..e99029afa 100644 --- a/mistralrs-core/src/models/llama.rs +++ b/mistralrs-core/src/models/llama.rs @@ -486,7 +486,8 @@ impl Llama { println!("cache_idx {:?}", cache_idx); let mut cache = kv_cache.lock(); - let device_chunk = block.device(); + // let device_chunk = &block.device(); + let device_chunk = self.mapper.device_for(block)?; // Determine the original device of the cache let original_cache_device = cache.iter().find_map(|opt| { From 042c0a1982a3ad9c7882c6fae4e88a9daf562aee Mon Sep 17 00:00:00 2001 From: joshpopelka20 <107133849+joshpopelka20@users.noreply.github.com> Date: Tue, 3 Sep 2024 12:57:55 -0400 Subject: [PATCH 088/107] change block device args --- mistralrs-core/src/models/llama.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mistralrs-core/src/models/llama.rs 
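Patch 084 moves each layer's cached key/value pair onto the chunk's device inside a closure that unwraps. A sketch of the same move that keeps the errors as `Result`, assuming the per-layer cache is an `Option<(Tensor, Tensor)>` of (key, value):

use candle_core::{Device, Result, Tensor};

// Move one layer's cached (key, value) pair, if present, onto `dev`.
fn cache_layer_to_device(
    layer: &Option<(Tensor, Tensor)>,
    dev: &Device,
) -> Result<Option<(Tensor, Tensor)>> {
    match layer {
        Some((k, v)) => Ok(Some((k.to_device(dev)?, v.to_device(dev)?))),
        None => Ok(None),
    }
}

// Move the whole per-layer cache, collecting the fallible moves into one Result.
fn cache_to_device(
    layers: &[Option<(Tensor, Tensor)>],
    dev: &Device,
) -> Result<Vec<Option<(Tensor, Tensor)>>> {
    layers.iter().map(|layer| cache_layer_to_device(layer, dev)).collect()
}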
b/mistralrs-core/src/models/llama.rs index e99029afa..6f9431331 100644 --- a/mistralrs-core/src/models/llama.rs +++ b/mistralrs-core/src/models/llama.rs @@ -487,7 +487,7 @@ impl Llama { let mut cache = kv_cache.lock(); // let device_chunk = &block.device(); - let device_chunk = self.mapper.device_for(block)?; + let device_chunk = self.mapper.device_for(block_idx, false)?; // Determine the original device of the cache let original_cache_device = cache.iter().find_map(|opt| { From 770806ac12fdd901c0ca5ac4f0bcad09ecad59b2 Mon Sep 17 00:00:00 2001 From: joshpopelka20 <107133849+joshpopelka20@users.noreply.github.com> Date: Tue, 3 Sep 2024 13:06:15 -0400 Subject: [PATCH 089/107] add device to block --- mistralrs-core/src/models/llama.rs | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/mistralrs-core/src/models/llama.rs b/mistralrs-core/src/models/llama.rs index 6f9431331..83bc827d6 100644 --- a/mistralrs-core/src/models/llama.rs +++ b/mistralrs-core/src/models/llama.rs @@ -460,7 +460,7 @@ impl Llama { let mut block_chunks: Vec = Vec::new(); - for (block_idx, block) in self.blocks.iter().enumerate() { + for (block_idx, (block, block_device)) in self.blocks.iter().enumerate() { // x = self.mapper.map(x, block_idx)?; // x = self.mapper.map(&chunks[0], block_idx)?; @@ -487,7 +487,7 @@ impl Llama { let mut cache = kv_cache.lock(); // let device_chunk = &block.device(); - let device_chunk = self.mapper.device_for(block_idx, false)?; + let device_chunk = block_device; // Determine the original device of the cache let original_cache_device = cache.iter().find_map(|opt| { @@ -623,7 +623,7 @@ impl Llama { mapper.set_nm_device(vb.pp("model.norm"), false), )?; let head_dim = cfg.hidden_size / cfg.num_attention_heads; - let blocks: Vec<_> = + let blocks: Vec<(Block, Device)> = NiceProgressBar::<_, 'b'>(0..cfg.num_hidden_layers, "Loading repeating layers") .into_iter() .map(|i| { @@ -652,7 +652,7 @@ impl Llama { if !cuda_devices.iter().any(|d| format!("{:?}", d) == format!("{:?}", device)) { cuda_devices.push(device.clone()); } - Block::load( + let block = Block::load( vb.pp(&format!("model.layers.{i}")), cfg, &*mapper, @@ -661,7 +661,8 @@ impl Llama { rotary_emb, paged_attn, ) - .expect("Failed to load block.") + .expect("Failed to load block."); + (block, device.clone()) }) .collect(); From 0b3c91117dc0cb03a74eb6fa5442e02bed67a4a2 Mon Sep 17 00:00:00 2001 From: joshpopelka20 <107133849+joshpopelka20@users.noreply.github.com> Date: Tue, 3 Sep 2024 13:07:46 -0400 Subject: [PATCH 090/107] fix llama struct --- mistralrs-core/src/models/llama.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mistralrs-core/src/models/llama.rs b/mistralrs-core/src/models/llama.rs index 83bc827d6..2866a75cd 100644 --- a/mistralrs-core/src/models/llama.rs +++ b/mistralrs-core/src/models/llama.rs @@ -403,7 +403,7 @@ impl Block { pub struct Llama { wte: Embedding, - blocks: Vec, + blocks: Vec<(Block, Device)>, ln_f: RmsNorm, lm_head: QMatMul, // pub kv_cache: crate::pipeline::Cache, From bcf6f847435579db683297b8db15ded49d25f4a8 Mon Sep 17 00:00:00 2001 From: joshpopelka20 <107133849+joshpopelka20@users.noreply.github.com> Date: Tue, 3 Sep 2024 13:15:45 -0400 Subject: [PATCH 091/107] revert blocks device --- mistralrs-core/src/models/llama.rs | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/mistralrs-core/src/models/llama.rs b/mistralrs-core/src/models/llama.rs index 2866a75cd..6f9431331 100644 --- a/mistralrs-core/src/models/llama.rs +++ 
b/mistralrs-core/src/models/llama.rs @@ -403,7 +403,7 @@ impl Block { pub struct Llama { wte: Embedding, - blocks: Vec<(Block, Device)>, + blocks: Vec, ln_f: RmsNorm, lm_head: QMatMul, // pub kv_cache: crate::pipeline::Cache, @@ -460,7 +460,7 @@ impl Llama { let mut block_chunks: Vec = Vec::new(); - for (block_idx, (block, block_device)) in self.blocks.iter().enumerate() { + for (block_idx, block) in self.blocks.iter().enumerate() { // x = self.mapper.map(x, block_idx)?; // x = self.mapper.map(&chunks[0], block_idx)?; @@ -487,7 +487,7 @@ impl Llama { let mut cache = kv_cache.lock(); // let device_chunk = &block.device(); - let device_chunk = block_device; + let device_chunk = self.mapper.device_for(block_idx, false)?; // Determine the original device of the cache let original_cache_device = cache.iter().find_map(|opt| { @@ -623,7 +623,7 @@ impl Llama { mapper.set_nm_device(vb.pp("model.norm"), false), )?; let head_dim = cfg.hidden_size / cfg.num_attention_heads; - let blocks: Vec<(Block, Device)> = + let blocks: Vec<_> = NiceProgressBar::<_, 'b'>(0..cfg.num_hidden_layers, "Loading repeating layers") .into_iter() .map(|i| { @@ -652,7 +652,7 @@ impl Llama { if !cuda_devices.iter().any(|d| format!("{:?}", d) == format!("{:?}", device)) { cuda_devices.push(device.clone()); } - let block = Block::load( + Block::load( vb.pp(&format!("model.layers.{i}")), cfg, &*mapper, @@ -661,8 +661,7 @@ impl Llama { rotary_emb, paged_attn, ) - .expect("Failed to load block."); - (block, device.clone()) + .expect("Failed to load block.") }) .collect(); From 9945a8dbf398abca99e62fa5337bd0e46ce70f14 Mon Sep 17 00:00:00 2001 From: joshpopelka20 <107133849+joshpopelka20@users.noreply.github.com> Date: Tue, 3 Sep 2024 13:21:03 -0400 Subject: [PATCH 092/107] revert to device chunk --- mistralrs-core/src/models/llama.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/mistralrs-core/src/models/llama.rs b/mistralrs-core/src/models/llama.rs index 6f9431331..1266b30e9 100644 --- a/mistralrs-core/src/models/llama.rs +++ b/mistralrs-core/src/models/llama.rs @@ -486,8 +486,7 @@ impl Llama { println!("cache_idx {:?}", cache_idx); let mut cache = kv_cache.lock(); - // let device_chunk = &block.device(); - let device_chunk = self.mapper.device_for(block_idx, false)?; + let device_chunk = chunk.device(); // Determine the original device of the cache let original_cache_device = cache.iter().find_map(|opt| { From 8d0bc2442ffc1cfa63a809c5e9ce04e2326ff410 Mon Sep 17 00:00:00 2001 From: joshpopelka20 <107133849+joshpopelka20@users.noreply.github.com> Date: Tue, 3 Sep 2024 14:03:59 -0400 Subject: [PATCH 093/107] add block device --- mistralrs-core/src/models/llama.rs | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/mistralrs-core/src/models/llama.rs b/mistralrs-core/src/models/llama.rs index 1266b30e9..4a25a8f3e 100644 --- a/mistralrs-core/src/models/llama.rs +++ b/mistralrs-core/src/models/llama.rs @@ -366,6 +366,10 @@ impl Block { Ok(x) } + fn get_device(&self) -> Device { + self.mlp.dtype_device().1 + } + fn load( vb: VarBuilder, cfg: &Config, @@ -486,7 +490,7 @@ impl Llama { println!("cache_idx {:?}", cache_idx); let mut cache = kv_cache.lock(); - let device_chunk = chunk.device(); + let device_chunk = block.get_device(); // Determine the original device of the cache let original_cache_device = cache.iter().find_map(|opt| { From 06525fe814c43b971ea376531502720dc739931e Mon Sep 17 00:00:00 2001 From: joshpopelka20 <107133849+joshpopelka20@users.noreply.github.com> Date: Tue, 3 Sep 2024 
14:05:23 -0400 Subject: [PATCH 094/107] add reference --- mistralrs-core/src/models/llama.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mistralrs-core/src/models/llama.rs b/mistralrs-core/src/models/llama.rs index 4a25a8f3e..6b72be055 100644 --- a/mistralrs-core/src/models/llama.rs +++ b/mistralrs-core/src/models/llama.rs @@ -500,7 +500,7 @@ impl Llama { // Move cache to chunk device let mut cache_on_chunk_device: Vec<_> = cache.iter().map(|opt| { opt.as_ref().map(|(k, v)| { - (k.to_device(device_chunk).unwrap(), v.to_device(device_chunk).unwrap()) + (k.to_device(&device_chunk).unwrap(), v.to_device(&device_chunk).unwrap()) }) }).collect(); From a03670d67ce24c2033396e70d2e3fe6d6d7e9ac2 Mon Sep 17 00:00:00 2001 From: joshpopelka20 <107133849+joshpopelka20@users.noreply.github.com> Date: Tue, 3 Sep 2024 14:15:24 -0400 Subject: [PATCH 095/107] update tensor device --- mistralrs-core/src/models/llama.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mistralrs-core/src/models/llama.rs b/mistralrs-core/src/models/llama.rs index 6b72be055..17918c6c1 100644 --- a/mistralrs-core/src/models/llama.rs +++ b/mistralrs-core/src/models/llama.rs @@ -466,6 +466,7 @@ impl Llama { for (block_idx, block) in self.blocks.iter().enumerate() { + let device_chunk = block.get_device(); // x = self.mapper.map(x, block_idx)?; // x = self.mapper.map(&chunks[0], block_idx)?; println!("block_idx {:?}", block_idx); @@ -475,11 +476,11 @@ impl Llama { let mut x = if block_idx == 0 { let tensor = chunk.clone(); self.mapper.map(tensor.clone(), block_idx)?; - tensor + tensor.to_device(device_chunk)? } else { let tensor = block_chunks[chunk_idx].clone(); self.mapper.map(tensor.clone(), block_idx)?; - tensor + tensor.to_device(device_chunk)? }; let num_caches = self.kv_caches.len(); @@ -490,7 +491,6 @@ impl Llama { println!("cache_idx {:?}", cache_idx); let mut cache = kv_cache.lock(); - let device_chunk = block.get_device(); // Determine the original device of the cache let original_cache_device = cache.iter().find_map(|opt| { From deabd31a788fa645be7c67478984b43a0dc0425d Mon Sep 17 00:00:00 2001 From: joshpopelka20 <107133849+joshpopelka20@users.noreply.github.com> Date: Tue, 3 Sep 2024 14:16:14 -0400 Subject: [PATCH 096/107] borrow device chunk --- mistralrs-core/src/models/llama.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mistralrs-core/src/models/llama.rs b/mistralrs-core/src/models/llama.rs index 17918c6c1..a293a72eb 100644 --- a/mistralrs-core/src/models/llama.rs +++ b/mistralrs-core/src/models/llama.rs @@ -476,11 +476,11 @@ impl Llama { let mut x = if block_idx == 0 { let tensor = chunk.clone(); self.mapper.map(tensor.clone(), block_idx)?; - tensor.to_device(device_chunk)? + tensor.to_device(&device_chunk)? } else { let tensor = block_chunks[chunk_idx].clone(); self.mapper.map(tensor.clone(), block_idx)?; - tensor.to_device(device_chunk)? + tensor.to_device(&device_chunk)? 
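
PATCH 095/096 above move each chunk onto the owning block's device with tensor.to_device(&device_chunk)? before calling the block. Below is a self-contained sketch of that per-layer device hop; forward_block and run_layers are placeholders for this sketch and leave out the rotary offsets, KV cache, and paged-attention metadata the real Block::forward takes.

    use candle_core::{Device, Result, Tensor};

    // Stand-in for the real attention + MLP block.
    fn forward_block(x: &Tensor, _mask: Option<&Tensor>) -> Result<Tensor> {
        Ok(x.clone())
    }

    // Move the running activation (and the causal mask) onto each layer's
    // device, then run the layer there.
    fn run_layers(mut x: Tensor, layer_devices: &[Device], mask: Option<&Tensor>) -> Result<Tensor> {
        for device in layer_devices {
            x = x.to_device(device)?;
            let mask_on_device = mask.map(|m| m.to_device(device)).transpose()?;
            x = forward_block(&x, mask_on_device.as_ref())?;
        }
        Ok(x)
    }
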
}; let num_caches = self.kv_caches.len(); From 3170304417384a714d8fe6c91b1a93e7d96f2d66 Mon Sep 17 00:00:00 2001 From: joshpopelka20 <107133849+joshpopelka20@users.noreply.github.com> Date: Tue, 3 Sep 2024 14:25:37 -0400 Subject: [PATCH 097/107] more logging --- mistralrs-core/src/models/llama.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mistralrs-core/src/models/llama.rs b/mistralrs-core/src/models/llama.rs index a293a72eb..06416a6be 100644 --- a/mistralrs-core/src/models/llama.rs +++ b/mistralrs-core/src/models/llama.rs @@ -570,6 +570,7 @@ impl Llama { // let block_chunks = block_chunks?; // Propagate any errors + println!("concat block chunks"); let mut x = candle_core::Tensor::cat(&block_chunks, 1)?; // do feedforward after attention has been run for each chunk @@ -579,6 +580,7 @@ impl Llama { x = x.to_device(&target_device)?; processed_chunks.push(x.clone()); } + println!("concat processed chunks"); x = candle_core::Tensor::cat(&processed_chunks, 1)?; let x = x.to_device(&self.device)?; let mut x = self.ln_f.forward(&x)?; From 1e3e55d2f9d17b49ff1a8823976b2e9b16ac3663 Mon Sep 17 00:00:00 2001 From: joshpopelka20 <107133849+joshpopelka20@users.noreply.github.com> Date: Tue, 3 Sep 2024 14:39:13 -0400 Subject: [PATCH 098/107] log logits --- mistralrs-core/src/pipeline/mod.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mistralrs-core/src/pipeline/mod.rs b/mistralrs-core/src/pipeline/mod.rs index 183eec1b3..d850594b1 100644 --- a/mistralrs-core/src/pipeline/mod.rs +++ b/mistralrs-core/src/pipeline/mod.rs @@ -692,6 +692,8 @@ pub trait Pipeline: }) .collect::>>()?; + println!("get logits"); + match post_op { CacheInstruction::Out => self.clone_out_cache(input_seqs, false), CacheInstruction::Nothing(_) => (), From 4fab476e79cd1adc3a71e251ee4e6f80ce467db2 Mon Sep 17 00:00:00 2001 From: joshpopelka20 <107133849+joshpopelka20@users.noreply.github.com> Date: Tue, 3 Sep 2024 14:47:45 -0400 Subject: [PATCH 099/107] try to clone out all caches --- mistralrs-core/src/pipeline/mod.rs | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/mistralrs-core/src/pipeline/mod.rs b/mistralrs-core/src/pipeline/mod.rs index d850594b1..5cb5a3294 100644 --- a/mistralrs-core/src/pipeline/mod.rs +++ b/mistralrs-core/src/pipeline/mod.rs @@ -695,7 +695,17 @@ pub trait Pipeline: println!("get logits"); match post_op { - CacheInstruction::Out => self.clone_out_cache(input_seqs, false), + // CacheInstruction::Out => self.clone_out_cache(input_seqs, false), + CacheInstruction::Out => { + let mut caches = Vec::new(); + + for cache in &self.caches { + let cloned_cache = cache.clone_out_cache(input_seqs, false); + caches.push(cloned_cache); + } + + caches // Return the collection of cloned caches + }, CacheInstruction::Nothing(_) => (), CacheInstruction::Reset { reset_non_granular, From 5863802ee70d660ea1b326d25683b88605cf3bb1 Mon Sep 17 00:00:00 2001 From: joshpopelka20 <107133849+joshpopelka20@users.noreply.github.com> Date: Tue, 3 Sep 2024 14:58:52 -0400 Subject: [PATCH 100/107] add logging in cacher --- mistralrs-core/src/pipeline/cache_manager.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mistralrs-core/src/pipeline/cache_manager.rs b/mistralrs-core/src/pipeline/cache_manager.rs index 9d2ed5838..009c7d8c0 100644 --- a/mistralrs-core/src/pipeline/cache_manager.rs +++ b/mistralrs-core/src/pipeline/cache_manager.rs @@ -207,10 +207,12 @@ fn clone_out_cache( seqs: &mut [&mut crate::sequence::Sequence], target: SeqCache, ) { + println!("clone_out_cache"); for layer in 
0..num_hidden_layers { let cache = cache.get(layer).unwrap(); let k_cache = cache.as_ref().unwrap().0.clone(); let v_cache = cache.as_ref().unwrap().1.clone(); + println!("after v_cache"); let k_caches = k_cache.chunk(seqs.len(), 0).unwrap(); debug_assert_eq!(k_caches.len(), seqs.len()); From ddcd84875417a4393c22631cce46a75f4766dcbd Mon Sep 17 00:00:00 2001 From: joshpopelka20 <107133849+joshpopelka20@users.noreply.github.com> Date: Tue, 3 Sep 2024 14:59:27 -0400 Subject: [PATCH 101/107] revert clone out cache --- mistralrs-core/src/pipeline/mod.rs | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/mistralrs-core/src/pipeline/mod.rs b/mistralrs-core/src/pipeline/mod.rs index 5cb5a3294..d850594b1 100644 --- a/mistralrs-core/src/pipeline/mod.rs +++ b/mistralrs-core/src/pipeline/mod.rs @@ -695,17 +695,7 @@ pub trait Pipeline: println!("get logits"); match post_op { - // CacheInstruction::Out => self.clone_out_cache(input_seqs, false), - CacheInstruction::Out => { - let mut caches = Vec::new(); - - for cache in &self.caches { - let cloned_cache = cache.clone_out_cache(input_seqs, false); - caches.push(cloned_cache); - } - - caches // Return the collection of cloned caches - }, + CacheInstruction::Out => self.clone_out_cache(input_seqs, false), CacheInstruction::Nothing(_) => (), CacheInstruction::Reset { reset_non_granular, From d9ac7ecc99de5a2bc55d3755797675674cc37eca Mon Sep 17 00:00:00 2001 From: joshpopelka20 <107133849+joshpopelka20@users.noreply.github.com> Date: Tue, 3 Sep 2024 15:12:32 -0400 Subject: [PATCH 102/107] skip clone out --- mistralrs-core/src/pipeline/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mistralrs-core/src/pipeline/mod.rs b/mistralrs-core/src/pipeline/mod.rs index d850594b1..82f57f3a7 100644 --- a/mistralrs-core/src/pipeline/mod.rs +++ b/mistralrs-core/src/pipeline/mod.rs @@ -695,7 +695,7 @@ pub trait Pipeline: println!("get logits"); match post_op { - CacheInstruction::Out => self.clone_out_cache(input_seqs, false), + // CacheInstruction::Out => self.clone_out_cache(input_seqs, false), CacheInstruction::Nothing(_) => (), CacheInstruction::Reset { reset_non_granular, From 81cd58483d9b36e6a387f007d5536c569f689e50 Mon Sep 17 00:00:00 2001 From: joshpopelka20 <107133849+joshpopelka20@users.noreply.github.com> Date: Tue, 3 Sep 2024 15:19:52 -0400 Subject: [PATCH 103/107] have cache out do nothing --- mistralrs-core/src/pipeline/mod.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/mistralrs-core/src/pipeline/mod.rs b/mistralrs-core/src/pipeline/mod.rs index 82f57f3a7..7080e1185 100644 --- a/mistralrs-core/src/pipeline/mod.rs +++ b/mistralrs-core/src/pipeline/mod.rs @@ -696,6 +696,7 @@ pub trait Pipeline: match post_op { // CacheInstruction::Out => self.clone_out_cache(input_seqs, false), + CacheInstruction::Out(_) => (), CacheInstruction::Nothing(_) => (), CacheInstruction::Reset { reset_non_granular, From 1456c72aaf566af41b5db244e50eddb473ab0b17 Mon Sep 17 00:00:00 2001 From: joshpopelka20 <107133849+joshpopelka20@users.noreply.github.com> Date: Tue, 3 Sep 2024 15:20:26 -0400 Subject: [PATCH 104/107] fix syntax --- mistralrs-core/src/pipeline/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mistralrs-core/src/pipeline/mod.rs b/mistralrs-core/src/pipeline/mod.rs index 7080e1185..429e415f7 100644 --- a/mistralrs-core/src/pipeline/mod.rs +++ b/mistralrs-core/src/pipeline/mod.rs @@ -696,7 +696,7 @@ pub trait Pipeline: match post_op { // CacheInstruction::Out => 
self.clone_out_cache(input_seqs, false), - CacheInstruction::Out(_) => (), + CacheInstruction::Out => (), CacheInstruction::Nothing(_) => (), CacheInstruction::Reset { reset_non_granular, From d52cdd87dfc0e74640bc9d22848a6bdabc2dde54 Mon Sep 17 00:00:00 2001 From: joshpopelka20 <107133849+joshpopelka20@users.noreply.github.com> Date: Tue, 3 Sep 2024 15:28:20 -0400 Subject: [PATCH 105/107] remove clone in cache --- mistralrs-core/src/pipeline/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mistralrs-core/src/pipeline/mod.rs b/mistralrs-core/src/pipeline/mod.rs index 429e415f7..ea1ddd721 100644 --- a/mistralrs-core/src/pipeline/mod.rs +++ b/mistralrs-core/src/pipeline/mod.rs @@ -639,7 +639,7 @@ pub trait Pipeline: } AdapterInstruction::None => 0, }; - self.clone_in_cache(input_seqs, false) + // self.clone_in_cache(input_seqs, false) } CacheInstruction::Nothing(ref adapter_inst) => { match adapter_inst { From bf80940adaf4ba0f85549890ab24aaffe20e9bcd Mon Sep 17 00:00:00 2001 From: joshpopelka20 <107133849+joshpopelka20@users.noreply.github.com> Date: Tue, 3 Sep 2024 15:36:05 -0400 Subject: [PATCH 106/107] remove loggers --- mistralrs-core/src/models/llama.rs | 14 +++++++------- mistralrs-core/src/pipeline/cache_manager.rs | 4 ++-- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/mistralrs-core/src/models/llama.rs b/mistralrs-core/src/models/llama.rs index 06416a6be..39ecfef7b 100644 --- a/mistralrs-core/src/models/llama.rs +++ b/mistralrs-core/src/models/llama.rs @@ -469,10 +469,10 @@ impl Llama { let device_chunk = block.get_device(); // x = self.mapper.map(x, block_idx)?; // x = self.mapper.map(&chunks[0], block_idx)?; - println!("block_idx {:?}", block_idx); + // println!("block_idx {:?}", block_idx); // println!("chunk device {:?}", chunks[0].device()); for (chunk_idx, chunk) in chunks.iter().enumerate() { - println!("chunk_idx {:?}", chunk_idx); + // println!("chunk_idx {:?}", chunk_idx); let mut x = if block_idx == 0 { let tensor = chunk.clone(); self.mapper.map(tensor.clone(), block_idx)?; @@ -488,7 +488,7 @@ impl Llama { for cache_rotation in 0..num_caches { let cache_idx = (chunk_idx + cache_rotation) % num_caches; let kv_cache = &self.kv_caches[cache_idx]; - println!("cache_idx {:?}", cache_idx); + // println!("cache_idx {:?}", cache_idx); let mut cache = kv_cache.lock(); @@ -530,7 +530,7 @@ impl Llama { // .map(|(kv_cache, metadata)| (kv_cache[block_idx].clone(), &mut **metadata)), // )?; - println!("before block forward"); + // println!("before block forward"); x = block.forward( &x, &mask.clone().map(|m| m.to_device(&device_chunk).unwrap()), @@ -550,7 +550,7 @@ impl Llama { }), )?; - println!("after block forward"); + // println!("after block forward"); // Accumulate attention results if block_chunks.len() <= chunk_idx { @@ -570,7 +570,7 @@ impl Llama { // let block_chunks = block_chunks?; // Propagate any errors - println!("concat block chunks"); + // println!("concat block chunks"); let mut x = candle_core::Tensor::cat(&block_chunks, 1)?; // do feedforward after attention has been run for each chunk @@ -580,7 +580,7 @@ impl Llama { x = x.to_device(&target_device)?; processed_chunks.push(x.clone()); } - println!("concat processed chunks"); + // println!("concat processed chunks"); x = candle_core::Tensor::cat(&processed_chunks, 1)?; let x = x.to_device(&self.device)?; let mut x = self.ln_f.forward(&x)?; diff --git a/mistralrs-core/src/pipeline/cache_manager.rs b/mistralrs-core/src/pipeline/cache_manager.rs index 009c7d8c0..e42d27289 100644 
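
Several hunks above pick a KV cache per chunk with (chunk_idx + cache_rotation) % num_caches, walking every per-device cache starting from the chunk's own. A small sketch of just that rotation; ChunkCache and use_cache are placeholders here instead of the crate's pipeline Cache and the block forward call.

    // Placeholder for a per-device KV cache handle.
    struct ChunkCache {
        id: usize,
    }

    fn use_cache(chunk_idx: usize, cache: &ChunkCache) {
        println!("chunk {} -> cache {}", chunk_idx, cache.id);
    }

    // Same indexing as the patch: start at the chunk's own cache, then visit
    // the remaining caches in order, wrapping around with the modulo.
    fn rotate_caches(num_chunks: usize, caches: &[ChunkCache]) {
        let num_caches = caches.len();
        for chunk_idx in 0..num_chunks {
            for cache_rotation in 0..num_caches {
                let cache_idx = (chunk_idx + cache_rotation) % num_caches;
                use_cache(chunk_idx, &caches[cache_idx]);
            }
        }
    }

    fn main() {
        let caches: Vec<ChunkCache> = (0..2).map(|id| ChunkCache { id }).collect();
        rotate_caches(2, &caches);
    }
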
--- a/mistralrs-core/src/pipeline/cache_manager.rs +++ b/mistralrs-core/src/pipeline/cache_manager.rs @@ -207,12 +207,12 @@ fn clone_out_cache( seqs: &mut [&mut crate::sequence::Sequence], target: SeqCache, ) { - println!("clone_out_cache"); + // println!("clone_out_cache"); for layer in 0..num_hidden_layers { let cache = cache.get(layer).unwrap(); let k_cache = cache.as_ref().unwrap().0.clone(); let v_cache = cache.as_ref().unwrap().1.clone(); - println!("after v_cache"); + // println!("after v_cache"); let k_caches = k_cache.chunk(seqs.len(), 0).unwrap(); debug_assert_eq!(k_caches.len(), seqs.len()); From a4dcd1ebd6b44471f1c115b7f885c7752c4b8be0 Mon Sep 17 00:00:00 2001 From: joshpopelka20 <107133849+joshpopelka20@users.noreply.github.com> Date: Tue, 3 Sep 2024 16:07:43 -0400 Subject: [PATCH 107/107] test speculative --- mistralrs-core/src/pipeline/mod.rs | 8 ++++---- mistralrs-core/src/pipeline/speculative.rs | 1 + 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/mistralrs-core/src/pipeline/mod.rs b/mistralrs-core/src/pipeline/mod.rs index ea1ddd721..7b9e52819 100644 --- a/mistralrs-core/src/pipeline/mod.rs +++ b/mistralrs-core/src/pipeline/mod.rs @@ -639,7 +639,7 @@ pub trait Pipeline: } AdapterInstruction::None => 0, }; - // self.clone_in_cache(input_seqs, false) + self.clone_in_cache(input_seqs, false) } CacheInstruction::Nothing(ref adapter_inst) => { match adapter_inst { @@ -692,11 +692,11 @@ pub trait Pipeline: }) .collect::>>()?; - println!("get logits"); + // println!("get logits"); match post_op { - // CacheInstruction::Out => self.clone_out_cache(input_seqs, false), - CacheInstruction::Out => (), + CacheInstruction::Out => self.clone_out_cache(input_seqs, false), + // CacheInstruction::Out => (), CacheInstruction::Nothing(_) => (), CacheInstruction::Reset { reset_non_granular, diff --git a/mistralrs-core/src/pipeline/speculative.rs b/mistralrs-core/src/pipeline/speculative.rs index b1a9bf58b..2ee006df0 100644 --- a/mistralrs-core/src/pipeline/speculative.rs +++ b/mistralrs-core/src/pipeline/speculative.rs @@ -229,6 +229,7 @@ impl IsqPipelineMixin for SpeculativePipeline { impl CacheManagerMixin for SpeculativePipeline { fn clone_in_cache(&self, seqs: &mut [&mut Sequence], modify_draft_cache: bool) { + println!("in speculative"); DefaultCacheManager.clone_in_cache( &*get_mut_arcmutex!(self.draft), seqs,
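
Taken together, the series splits the sequence into per-device chunks, runs the attention blocks chunk by chunk, and reassembles the results with Tensor::cat on dim 1 before the final norm, as in the block_chunks and processed_chunks concatenations earlier in the series. The sketch below is one way to express that outer flow with candle; process_chunk, chunked_forward, the even split of seq_len, and running everything on the CPU in main are assumptions of this sketch, not the crate's API.

    use candle_core::{DType, Device, Result, Tensor};

    // Stand-in for running the attention blocks on this chunk's device.
    fn process_chunk(chunk: &Tensor, device: &Device) -> Result<Tensor> {
        chunk.to_device(device)
    }

    // Split along the sequence dimension, process each chunk on its device,
    // then bring everything back to `home` and concatenate.
    fn chunked_forward(x: &Tensor, devices: &[Device], home: &Device) -> Result<Tensor> {
        let (_bsz, seq_len, _hidden) = x.dims3()?;
        let chunk_size = seq_len / devices.len(); // assumes an even split
        let mut outputs = Vec::with_capacity(devices.len());
        for (i, device) in devices.iter().enumerate() {
            let chunk = x.narrow(1, i * chunk_size, chunk_size)?;
            outputs.push(process_chunk(&chunk, device)?.to_device(home)?);
        }
        Tensor::cat(&outputs, 1)
    }

    fn main() -> Result<()> {
        let home = Device::Cpu;
        let x = Tensor::zeros((1, 8, 16), DType::F32, &home)?;
        let y = chunked_forward(&x, &[home.clone(), home.clone()], &home)?;
        assert_eq!(y.dims3()?, (1, 8, 16));
        Ok(())
    }
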