diff --git a/mistralrs-core/src/diffusion_models/flux/model.rs b/mistralrs-core/src/diffusion_models/flux/model.rs
index f9e8694bd..6b60add03 100644
--- a/mistralrs-core/src/diffusion_models/flux/model.rs
+++ b/mistralrs-core/src/diffusion_models/flux/model.rs
@@ -1,9 +1,14 @@
 #![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
 
+use std::sync::Arc;
+
 use candle_core::{DType, Device, IndexOp, Result, Tensor, D};
-use candle_nn::{layer_norm::RmsNormNonQuantized, LayerNorm, Linear, RmsNorm, VarBuilder};
+use candle_nn::{layer_norm::RmsNormNonQuantized, LayerNorm, RmsNorm, VarBuilder};
+use mistralrs_quant::QuantMethod;
 use serde::Deserialize;
 
+use crate::{device_map::DeviceMapper, pipeline::IsqModel, DeviceMapMetadata};
+
 const MLP_RATIO: f64 = 4.;
 const HIDDEN_SIZE: usize = 3072;
 const AXES_DIM: &[usize] = &[16, 56, 56];
@@ -134,16 +139,17 @@ impl candle_core::Module for EmbedNd {
     }
 }
 
+// Don't ISQ, so remember to do ISQ device map outside of here!
 #[derive(Debug, Clone)]
 pub struct MlpEmbedder {
-    in_layer: Linear,
-    out_layer: Linear,
+    in_layer: Arc<dyn QuantMethod>,
+    out_layer: Arc<dyn QuantMethod>,
 }
 
 impl MlpEmbedder {
     fn new(in_sz: usize, h_sz: usize, vb: VarBuilder) -> Result<Self> {
-        let in_layer = candle_nn::linear(in_sz, h_sz, vb.pp("in_layer"))?;
-        let out_layer = candle_nn::linear(h_sz, h_sz, vb.pp("out_layer"))?;
+        let in_layer = mistralrs_quant::linear(in_sz, h_sz, &None, vb.pp("in_layer"))?;
+        let out_layer = mistralrs_quant::linear(h_sz, h_sz, &None, vb.pp("out_layer"))?;
         Ok(Self {
             in_layer,
             out_layer,
@@ -153,7 +159,14 @@ impl MlpEmbedder {
 
 impl candle_core::Module for MlpEmbedder {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
-        xs.apply(&self.in_layer)?.silu()?.apply(&self.out_layer)
+        let original_dtype = xs.dtype();
+        let mut xs = xs.clone();
+        if let Some(t) = self.in_layer.quantized_act_type() {
+            xs = xs.to_dtype(t)?;
+        }
+        self.out_layer
+            .forward(&self.in_layer.forward(&xs)?.silu()?)?
+            .to_dtype(original_dtype)
     }
 }
 
@@ -195,19 +208,25 @@ impl ModulationOut {
 
 #[derive(Debug, Clone)]
 struct Modulation1 {
-    lin: Linear,
+    lin: Arc<dyn QuantMethod>,
 }
 
 impl Modulation1 {
     fn new(dim: usize, vb: VarBuilder) -> Result<Self> {
-        let lin = candle_nn::linear(dim, 3 * dim, vb.pp("lin"))?;
+        let lin = mistralrs_quant::linear(dim, 3 * dim, &None, vb.pp("lin"))?;
         Ok(Self { lin })
     }
 
     fn forward(&self, vec_: &Tensor) -> Result<ModulationOut> {
-        let ys = vec_
-            .silu()?
-            .apply(&self.lin)?
+        let original_dtype = vec_.dtype();
+        let mut vec_ = vec_.clone();
+        if let Some(t) = self.lin.quantized_act_type() {
+            vec_ = vec_.to_dtype(t)?;
+        }
+        let ys = self
+            .lin
+            .forward(&vec_.silu()?)?
+            .to_dtype(original_dtype)?
             .unsqueeze(1)?
             .chunk(3, D::Minus1)?;
         if ys.len() != 3 {
@@ -223,19 +242,25 @@ impl Modulation1 {
 
 #[derive(Debug, Clone)]
 struct Modulation2 {
-    lin: Linear,
+    lin: Arc<dyn QuantMethod>,
 }
 
 impl Modulation2 {
     fn new(dim: usize, vb: VarBuilder) -> Result<Self> {
-        let lin = candle_nn::linear(dim, 6 * dim, vb.pp("lin"))?;
+        let lin = mistralrs_quant::linear(dim, 6 * dim, &None, vb.pp("lin"))?;
         Ok(Self { lin })
     }
 
     fn forward(&self, vec_: &Tensor) -> Result<(ModulationOut, ModulationOut)> {
-        let ys = vec_
-            .silu()?
-            .apply(&self.lin)?
+        let original_dtype = vec_.dtype();
+        let mut vec_ = vec_.clone();
+        if let Some(t) = self.lin.quantized_act_type() {
+            vec_ = vec_.to_dtype(t)?;
+        }
+        let ys = self
+            .lin
+            .forward(&vec_.silu()?)?
+            .to_dtype(original_dtype)?
             .unsqueeze(1)?
             .chunk(6, D::Minus1)?;
         if ys.len() != 6 {
@@ -257,18 +282,36 @@ impl Modulation2 {
 
 #[derive(Debug, Clone)]
 pub struct SelfAttention {
-    qkv: Linear,
+    qkv: Arc<dyn QuantMethod>,
     norm: QkNorm,
-    proj: Linear,
+    proj: Arc<dyn QuantMethod>,
     num_attention_heads: usize,
 }
 
 impl SelfAttention {
-    fn new(dim: usize, num_attention_heads: usize, qkv_bias: bool, vb: VarBuilder) -> Result<Self> {
+    fn new(
+        dim: usize,
+        num_attention_heads: usize,
+        qkv_bias: bool,
+        vb: VarBuilder,
+        mapper: &dyn DeviceMapper,
+        loading_isq: bool,
+    ) -> Result<Self> {
         let head_dim = dim / num_attention_heads;
-        let qkv = candle_nn::linear_b(dim, dim * 3, qkv_bias, vb.pp("qkv"))?;
-        let norm = QkNorm::new(head_dim, vb.pp("norm"))?;
-        let proj = candle_nn::linear(dim, dim, vb.pp("proj"))?;
+        let qkv = mistralrs_quant::linear_b(
+            dim,
+            dim * 3,
+            qkv_bias,
+            &None,
+            mapper.set_nm_device(vb.pp("qkv"), loading_isq),
+        )?;
+        let norm = QkNorm::new(head_dim, mapper.set_nm_device(vb.pp("norm"), false))?;
+        let proj = mistralrs_quant::linear(
+            dim,
+            dim,
+            &None,
+            mapper.set_nm_device(vb.pp("proj"), loading_isq),
+        )?;
         Ok(Self {
             qkv,
             norm,
@@ -278,7 +321,12 @@ impl SelfAttention {
     }
 
     fn qkv(&self, xs: &Tensor) -> Result<(Tensor, Tensor, Tensor)> {
-        let qkv = xs.apply(&self.qkv)?;
+        let original_dtype = xs.dtype();
+        let mut xs = xs.clone();
+        if let Some(t) = self.qkv.quantized_act_type() {
+            xs = xs.to_dtype(t)?;
+        }
+        let qkv = self.qkv.forward(&xs)?.to_dtype(original_dtype)?;
         let (b, l, _khd) = qkv.dims3()?;
         let qkv = qkv.reshape((b, l, 3, self.num_attention_heads, ()))?;
         let q = qkv.i((.., .., 0))?.transpose(1, 2)?;
@@ -292,61 +340,39 @@ impl SelfAttention {
 
     #[allow(unused)]
     fn forward(&self, xs: &Tensor, pe: &Tensor) -> Result<Tensor> {
         let (q, k, v) = self.qkv(xs)?;
-        attention(&q, &k, &v, pe)?.apply(&self.proj)
-    }
-
-    fn cast_to(&mut self, device: &Device) -> Result<()> {
-        self.qkv = Linear::new(
-            self.qkv.weight().to_device(device)?,
-            self.qkv.bias().map(|x| x.to_device(device).unwrap()),
-        );
-        self.proj = Linear::new(
-            self.proj.weight().to_device(device)?,
-            self.proj.bias().map(|x| x.to_device(device).unwrap()),
-        );
-        self.norm = QkNorm {
-            query_norm: RmsNorm::<RmsNormNonQuantized>::new(
-                self.norm.query_norm.inner().weight().to_device(device)?,
-                1e-6,
-            ),
-            key_norm: RmsNorm::<RmsNormNonQuantized>::new(
-                self.norm.key_norm.inner().weight().to_device(device)?,
-                1e-6,
-            ),
-        };
-        Ok(())
+        let mut attn_weights = attention(&q, &k, &v, pe)?;
+        let original_dtype = attn_weights.dtype();
+        if let Some(t) = self.proj.quantized_act_type() {
+            attn_weights = attn_weights.to_dtype(t)?;
+        }
+        self.proj.forward(&attn_weights)?.to_dtype(original_dtype)
     }
 }
 
 #[derive(Debug, Clone)]
 struct Mlp {
-    lin1: Linear,
-    lin2: Linear,
+    lin1: Arc<dyn QuantMethod>,
+    lin2: Arc<dyn QuantMethod>,
 }
 
 impl Mlp {
     fn new(in_sz: usize, mlp_sz: usize, vb: VarBuilder) -> Result<Self> {
-        let lin1 = candle_nn::linear(in_sz, mlp_sz, vb.pp("0"))?;
-        let lin2 = candle_nn::linear(mlp_sz, in_sz, vb.pp("2"))?;
+        let lin1 = mistralrs_quant::linear(in_sz, mlp_sz, &None, vb.pp("0"))?;
+        let lin2 = mistralrs_quant::linear(mlp_sz, in_sz, &None, vb.pp("2"))?;
         Ok(Self { lin1, lin2 })
     }
-
-    fn cast_to(&mut self, device: &Device) -> Result<()> {
-        self.lin1 = Linear::new(
-            self.lin1.weight().to_device(device)?,
-            self.lin1.bias().map(|x| x.to_device(device).unwrap()),
-        );
-        self.lin2 = Linear::new(
-            self.lin2.weight().to_device(device)?,
-            self.lin2.bias().map(|x| x.to_device(device).unwrap()),
-        );
-        Ok(())
-    }
 }
 
 impl candle_core::Module for Mlp {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
-        xs.apply(&self.lin1)?.gelu()?.apply(&self.lin2)
+        let original_dtype = xs.dtype();
+        let mut xs = xs.clone();
+        if let Some(t) = self.lin1.quantized_act_type() {
+            xs = xs.to_dtype(t)?;
+        }
+        self.lin2
+            .forward(&self.lin1.forward(&xs)?.gelu()?)?
+            .to_dtype(original_dtype)
     }
 }
 
@@ -365,19 +391,46 @@ pub struct DoubleStreamBlock {
 }
 
 impl DoubleStreamBlock {
-    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+    fn new(
+        cfg: &Config,
+        vb: VarBuilder,
+        mapper: &dyn DeviceMapper,
+        loading_isq: bool,
+    ) -> Result<Self> {
         let h_sz = HIDDEN_SIZE;
         let mlp_sz = (h_sz as f64 * MLP_RATIO) as usize;
-        let img_mod = Modulation2::new(h_sz, vb.pp("img_mod"))?;
-        let img_norm1 = layer_norm(h_sz, vb.pp("img_norm1"))?;
-        let img_attn = SelfAttention::new(h_sz, cfg.num_attention_heads, true, vb.pp("img_attn"))?;
-        let img_norm2 = layer_norm(h_sz, vb.pp("img_norm2"))?;
-        let img_mlp = Mlp::new(h_sz, mlp_sz, vb.pp("img_mlp"))?;
-        let txt_mod = Modulation2::new(h_sz, vb.pp("txt_mod"))?;
-        let txt_norm1 = layer_norm(h_sz, vb.pp("txt_norm1"))?;
-        let txt_attn = SelfAttention::new(h_sz, cfg.num_attention_heads, true, vb.pp("txt_attn"))?;
-        let txt_norm2 = layer_norm(h_sz, vb.pp("txt_norm2"))?;
-        let txt_mlp = Mlp::new(h_sz, mlp_sz, vb.pp("txt_mlp"))?;
+        let img_mod = Modulation2::new(h_sz, mapper.set_nm_device(vb.pp("img_mod"), loading_isq))?;
+        let img_norm1 = layer_norm(h_sz, mapper.set_nm_device(vb.pp("img_norm1"), false))?;
+        let img_attn = SelfAttention::new(
+            h_sz,
+            cfg.num_attention_heads,
+            true,
+            vb.pp("img_attn"),
+            mapper,
+            loading_isq,
+        )?;
+        let img_norm2 = layer_norm(h_sz, mapper.set_nm_device(vb.pp("img_norm2"), false))?;
+        let img_mlp = Mlp::new(
+            h_sz,
+            mlp_sz,
+            mapper.set_nm_device(vb.pp("img_mlp"), loading_isq),
+        )?;
+        let txt_mod = Modulation2::new(h_sz, mapper.set_nm_device(vb.pp("txt_mod"), loading_isq))?;
+        let txt_norm1 = layer_norm(h_sz, mapper.set_nm_device(vb.pp("txt_norm1"), false))?;
+        let txt_attn = SelfAttention::new(
+            h_sz,
+            cfg.num_attention_heads,
+            true,
+            vb.pp("txt_attn"),
+            mapper,
+            loading_isq,
+        )?;
+        let txt_norm2 = layer_norm(h_sz, mapper.set_nm_device(vb.pp("txt_norm2"), false))?;
+        let txt_mlp = Mlp::new(
+            h_sz,
+            mlp_sz,
+            mapper.set_nm_device(vb.pp("txt_mlp"), loading_isq),
+        )?;
         Ok(Self {
             img_mod,
             img_norm1,
@@ -414,10 +467,23 @@ impl DoubleStreamBlock {
         let v = Tensor::cat(&[txt_v, img_v], 2)?;
 
         let attn = attention(&q, &k, &v, pe)?;
-        let txt_attn = attn.narrow(1, 0, txt.dim(1)?)?;
-        let img_attn = attn.narrow(1, txt.dim(1)?, attn.dim(1)? - txt.dim(1)?)?;
+        let mut txt_attn = attn.narrow(1, 0, txt.dim(1)?)?;
+        let mut img_attn = attn.narrow(1, txt.dim(1)?, attn.dim(1)? - txt.dim(1)?)?;
+
+        let original_dtype = img_attn.dtype();
+        if let Some(t) = self.img_attn.proj.quantized_act_type() {
+            img_attn = img_attn.to_dtype(t)?;
+            txt_attn = txt_attn.to_dtype(t)?;
+        }
 
-        let img = (img + img_mod1.gate(&img_attn.apply(&self.img_attn.proj)?))?;
+        let img = (img
+            + img_mod1.gate(
+                &self
+                    .img_attn
+                    .proj
+                    .forward(&img_attn)?
+                    .to_dtype(original_dtype)?,
+            ))?;
         let img = (&img
             + img_mod2.gate(
                 &img_mod2
@@ -425,7 +491,14 @@ impl DoubleStreamBlock {
                     .apply(&self.img_mlp)?,
             )?)?;
 
-        let txt = (txt + txt_mod1.gate(&txt_attn.apply(&self.txt_attn.proj)?))?;
+        let txt = (txt
+            + txt_mod1.gate(
+                &self
+                    .txt_attn
+                    .proj
+                    .forward(&txt_attn)?
+                    .to_dtype(original_dtype)?,
+            ))?;
         let txt = (&txt
             + txt_mod2.gate(
                 &txt_mod2
@@ -435,56 +508,12 @@ impl DoubleStreamBlock {
 
         Ok((img, txt))
     }
-
-    fn cast_to(&mut self, device: &Device) -> Result<()> {
-        self.img_mod.lin = Linear::new(
-            self.img_mod.lin.weight().to_device(device)?,
-            self.img_mod
-                .lin
-                .bias()
-                .map(|x| x.to_device(device).unwrap()),
-        );
-        self.img_norm1 = LayerNorm::new(
-            self.img_norm1.weight().to_device(device)?,
-            self.img_norm1.bias().to_device(device)?,
-            1e-6,
-        );
-        self.img_attn.cast_to(device)?;
-        self.img_norm2 = LayerNorm::new(
-            self.img_norm2.weight().to_device(device)?,
-            self.img_norm2.bias().to_device(device)?,
-            1e-6,
-        );
-        self.img_mlp.cast_to(device)?;
-
-        self.txt_mod.lin = Linear::new(
-            self.txt_mod.lin.weight().to_device(device)?,
-            self.txt_mod
-                .lin
-                .bias()
-                .map(|x| x.to_device(device).unwrap()),
-        );
-        self.txt_norm1 = LayerNorm::new(
-            self.txt_norm1.weight().to_device(device)?,
-            self.txt_norm1.bias().to_device(device)?,
-            1e-6,
-        );
-        self.txt_attn.cast_to(device)?;
-        self.txt_norm2 = LayerNorm::new(
-            self.txt_norm2.weight().to_device(device)?,
-            self.txt_norm2.bias().to_device(device)?,
-            1e-6,
-        );
-        self.txt_mlp.cast_to(device)?;
-
-        Ok(())
-    }
 }
 
 #[derive(Debug, Clone)]
 pub struct SingleStreamBlock {
-    linear1: Linear,
-    linear2: Linear,
+    linear1: Arc<dyn QuantMethod>,
+    linear2: Arc<dyn QuantMethod>,
     norm: QkNorm,
     pre_norm: LayerNorm,
     modulation: Modulation1,
@@ -494,15 +523,31 @@ pub struct SingleStreamBlock {
 }
 
 impl SingleStreamBlock {
-    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+    fn new(
+        cfg: &Config,
+        vb: VarBuilder,
+        mapper: &dyn DeviceMapper,
+        loading_isq: bool,
+    ) -> Result<Self> {
         let h_sz = HIDDEN_SIZE;
         let mlp_sz = (h_sz as f64 * MLP_RATIO) as usize;
         let head_dim = h_sz / cfg.num_attention_heads;
-        let linear1 = candle_nn::linear(h_sz, h_sz * 3 + mlp_sz, vb.pp("linear1"))?;
-        let linear2 = candle_nn::linear(h_sz + mlp_sz, h_sz, vb.pp("linear2"))?;
-        let norm = QkNorm::new(head_dim, vb.pp("norm"))?;
-        let pre_norm = layer_norm(h_sz, vb.pp("pre_norm"))?;
-        let modulation = Modulation1::new(h_sz, vb.pp("modulation"))?;
+        let linear1 = mistralrs_quant::linear(
+            h_sz,
+            h_sz * 3 + mlp_sz,
+            &None,
+            mapper.set_nm_device(vb.pp("linear1"), loading_isq),
+        )?;
+        let linear2 = mistralrs_quant::linear(
+            h_sz + mlp_sz,
+            h_sz,
+            &None,
+            mapper.set_nm_device(vb.pp("linear2"), loading_isq),
+        )?;
+        let norm = QkNorm::new(head_dim, mapper.set_nm_device(vb.pp("norm"), false))?;
+        let pre_norm = layer_norm(h_sz, mapper.set_nm_device(vb.pp("pre_norm"), false))?;
+        let modulation =
+            Modulation1::new(h_sz, mapper.set_nm_device(vb.pp("modulation"), loading_isq))?;
         Ok(Self {
             linear1,
             linear2,
@@ -517,8 +562,14 @@ impl SingleStreamBlock {
 
     fn forward(&self, xs: &Tensor, vec_: &Tensor, pe: &Tensor) -> Result<Tensor> {
         let mod_ = self.modulation.forward(vec_)?;
-        let x_mod = mod_.scale_shift(&xs.apply(&self.pre_norm)?)?;
-        let x_mod = x_mod.apply(&self.linear1)?;
+        let mut x_mod = mod_.scale_shift(&xs.apply(&self.pre_norm)?)?;
+
+        let original_dtype = x_mod.dtype();
+        if let Some(t) = self.linear1.quantized_act_type() {
+            x_mod = x_mod.to_dtype(t)?;
+        }
+
+        let x_mod = self.linear1.forward(&x_mod)?.to_dtype(original_dtype)?;
         let qkv = x_mod.narrow(D::Minus1, 0, 3 * self.h_sz)?;
         let (b, l, _khd) = qkv.dims3()?;
         let qkv = qkv.reshape((b, l, 3, self.num_attention_heads, ()))?;
@@ -529,57 +580,47 @@ impl SingleStreamBlock {
         let q = q.apply(&self.norm.query_norm)?;
         let k = k.apply(&self.norm.key_norm)?;
         let attn = attention(&q, &k, &v, pe)?;
-        let output = Tensor::cat(&[attn, mlp.gelu()?], 2)?.apply(&self.linear2)?;
-        xs + mod_.gate(&output)
-    }
-
-    fn cast_to(&mut self, device: &Device) -> Result<()> {
-        self.linear1 = Linear::new(
-            self.linear1.weight().to_device(device)?,
-            self.linear1.bias().map(|x| x.to_device(device).unwrap()),
-        );
-        self.linear2 = Linear::new(
-            self.linear2.weight().to_device(device)?,
-            self.linear2.bias().map(|x| x.to_device(device).unwrap()),
-        );
-        self.norm = QkNorm {
-            query_norm: RmsNorm::<RmsNormNonQuantized>::new(
-                self.norm.query_norm.inner().weight().to_device(device)?,
-                1e-6,
-            ),
-            key_norm: RmsNorm::<RmsNormNonQuantized>::new(
-                self.norm.key_norm.inner().weight().to_device(device)?,
-                1e-6,
-            ),
-        };
-        self.pre_norm = LayerNorm::new(
-            self.pre_norm.weight().to_device(device)?,
-            self.pre_norm.bias().to_device(device)?,
-            1e-6,
-        );
-        self.modulation.lin = Linear::new(
-            self.modulation.lin.weight().to_device(device)?,
-            self.modulation
-                .lin
-                .bias()
-                .map(|x| x.to_device(device).unwrap()),
-        );
-        Ok(())
+        let mut xs2 = Tensor::cat(&[attn, mlp.gelu()?], 2)?;
+        let original_dtype = xs2.dtype();
+        if let Some(t) = self.linear2.quantized_act_type() {
+            xs2 = xs2.to_dtype(t)?;
+        }
+
+        let output = self.linear2.forward(&xs2)?.to_dtype(original_dtype)?;
+        xs + mod_.gate(&output)
     }
 }
 
 #[derive(Debug, Clone)]
 pub struct LastLayer {
     norm_final: LayerNorm,
-    linear: Linear,
-    ada_ln_modulation: Linear,
+    linear: Arc<dyn QuantMethod>,
+    ada_ln_modulation: Arc<dyn QuantMethod>,
 }
 
 impl LastLayer {
-    fn new(h_sz: usize, p_sz: usize, out_c: usize, vb: VarBuilder) -> Result<Self> {
-        let norm_final = layer_norm(h_sz, vb.pp("norm_final"))?;
-        let linear = candle_nn::linear(h_sz, p_sz * p_sz * out_c, vb.pp("linear"))?;
-        let ada_ln_modulation = candle_nn::linear(h_sz, 2 * h_sz, vb.pp("adaLN_modulation.1"))?;
+    fn new(
+        h_sz: usize,
+        p_sz: usize,
+        out_c: usize,
+        vb: VarBuilder,
+        mapper: &dyn DeviceMapper,
+        loading_isq: bool,
+    ) -> Result<Self> {
+        let norm_final = layer_norm(h_sz, mapper.set_nm_device(vb.pp("norm_final"), false))?;
+        let linear = mistralrs_quant::linear(
+            h_sz,
+            p_sz * p_sz * out_c,
+            &None,
+            mapper.set_nm_device(vb.pp("linear"), loading_isq),
+        )?;
+        let ada_ln_modulation = mistralrs_quant::linear(
+            h_sz,
+            2 * h_sz,
+            &None,
+            mapper.set_nm_device(vb.pp("adaLN_modulation.1"), loading_isq),
+        )?;
         Ok(Self {
             norm_final,
             linear,
@@ -588,20 +629,34 @@ impl LastLayer {
     }
 
     fn forward(&self, xs: &Tensor, vec: &Tensor) -> Result<Tensor> {
-        let chunks = vec.silu()?.apply(&self.ada_ln_modulation)?.chunk(2, 1)?;
+        let original_dtype = vec.dtype();
+        let mut vec = vec.clone();
+        if let Some(t) = self.ada_ln_modulation.quantized_act_type() {
+            vec = vec.to_dtype(t)?;
+        }
+        let chunks = self
+            .ada_ln_modulation
+            .forward(&vec.silu()?)?
+            .to_dtype(original_dtype)?
+            .chunk(2, 1)?;
         let (shift, scale) = (&chunks[0], &chunks[1]);
-        let xs = xs
+        let mut xs = xs
             .apply(&self.norm_final)?
             .broadcast_mul(&(scale.unsqueeze(1)? + 1.0)?)?
             .broadcast_add(&shift.unsqueeze(1)?)?;
-        xs.apply(&self.linear)
+
+        let original_dtype = xs.dtype();
+        if let Some(t) = self.linear.quantized_act_type() {
+            xs = xs.to_dtype(t)?;
+        }
+        self.linear.forward(&xs)?.to_dtype(original_dtype)
     }
 }
 
-#[derive(Debug, Clone)]
+#[derive(Debug)]
 pub struct Flux {
-    img_in: Linear,
-    txt_in: Linear,
+    img_in: Arc<dyn QuantMethod>,
+    txt_in: Arc<dyn QuantMethod>,
     time_in: MlpEmbedder,
     vector_in: MlpEmbedder,
     guidance_in: Option<MlpEmbedder>,
@@ -609,49 +664,52 @@ pub struct Flux {
     double_blocks: Vec<DoubleStreamBlock>,
     single_blocks: Vec<SingleStreamBlock>,
     final_layer: LastLayer,
-    device: Device,
-    offloaded: bool,
+    mapper: Box<dyn DeviceMapper + Send + Sync>,
 }
 
 impl Flux {
-    pub fn new(cfg: &Config, vb: VarBuilder, device: Device, offloaded: bool) -> Result<Self> {
-        let img_in = candle_nn::linear(
+    pub fn new(cfg: &Config, vb: VarBuilder, device: Device, loading_isq: bool) -> Result<Self> {
+        let mapper = DeviceMapMetadata::dummy().into_mapper(0, &device, None)?;
+
+        let img_in = mistralrs_quant::linear(
             cfg.in_channels,
             HIDDEN_SIZE,
-            vb.pp("img_in").set_device(device.clone()),
+            &None,
+            mapper.set_nm_device(vb.pp("img_in"), loading_isq),
         )?;
-        let txt_in = candle_nn::linear(
+        let txt_in = mistralrs_quant::linear(
             cfg.joint_attention_dim,
             HIDDEN_SIZE,
-            vb.pp("txt_in").set_device(device.clone()),
+            &None,
+            mapper.set_nm_device(vb.pp("txt_in"), loading_isq),
         )?;
         let mut double_blocks = Vec::with_capacity(cfg.num_layers);
         let vb_d = vb.pp("double_blocks");
         for idx in 0..cfg.num_layers {
-            let db = DoubleStreamBlock::new(cfg, vb_d.pp(idx))?;
+            let db = DoubleStreamBlock::new(cfg, vb_d.pp(idx), &*mapper, loading_isq)?;
             double_blocks.push(db)
         }
         let mut single_blocks = Vec::with_capacity(cfg.num_single_layers);
         let vb_s = vb.pp("single_blocks");
         for idx in 0..cfg.num_single_layers {
-            let sb = SingleStreamBlock::new(cfg, vb_s.pp(idx))?;
+            let sb = SingleStreamBlock::new(cfg, vb_s.pp(idx), &*mapper, loading_isq)?;
            single_blocks.push(sb)
         }
         let time_in = MlpEmbedder::new(
             256,
             HIDDEN_SIZE,
-            vb.pp("time_in").set_device(device.clone()),
+            mapper.set_nm_device(vb.pp("time_in"), loading_isq),
         )?;
         let vector_in = MlpEmbedder::new(
             cfg.pooled_projection_dim,
             HIDDEN_SIZE,
-            vb.pp("vector_in").set_device(device.clone()),
+            mapper.set_nm_device(vb.pp("vector_in"), loading_isq),
         )?;
         let guidance_in = if cfg.guidance_embeds {
             let mlp = MlpEmbedder::new(
                 256,
                 HIDDEN_SIZE,
-                vb.pp("guidance_in").set_device(device.clone()),
+                mapper.set_nm_device(vb.pp("guidance_in"), false),
             )?;
             Some(mlp)
         } else {
@@ -661,7 +719,9 @@ impl Flux {
             HIDDEN_SIZE,
             1,
             cfg.in_channels,
-            vb.pp("final_layer").set_device(device.clone()),
+            vb.pp("final_layer"),
+            &*mapper,
+            loading_isq,
         )?;
         let pe_dim = HIDDEN_SIZE / cfg.num_attention_heads;
         let pe_embedder = EmbedNd::new(pe_dim, THETA, AXES_DIM.to_vec());
@@ -675,8 +735,7 @@ impl Flux {
             double_blocks,
             single_blocks,
             final_layer,
-            device: device.clone(),
-            offloaded,
+            mapper,
         })
     }
 
@@ -702,8 +761,15 @@ impl Flux {
             let ids = Tensor::cat(&[txt_ids, img_ids], 1)?;
             ids.apply(&self.pe_embedder)?
         };
-        let mut txt = txt.apply(&self.txt_in)?;
-        let mut img = img.apply(&self.img_in)?;
+        let original_dtype = txt.dtype();
+        let mut txt = txt.clone();
+        let mut img = img.clone();
+        if let Some(t) = self.txt_in.quantized_act_type() {
+            txt = txt.to_dtype(t)?;
+            img = img.to_dtype(t)?;
+        }
+        let mut txt = self.txt_in.forward(&txt)?.to_dtype(original_dtype)?;
+        let mut img = self.img_in.forward(&img)?.to_dtype(original_dtype)?;
         let vec_ = timestep_embedding(timesteps, 256, dtype)?.apply(&self.time_in)?;
         let vec_ = match (self.guidance_in.as_ref(), guidance) {
             (Some(g_in), Some(guidance)) => {
@@ -715,26 +781,59 @@ impl Flux {
 
         // Double blocks
         for block in self.double_blocks.iter_mut() {
-            if self.offloaded {
-                block.cast_to(&self.device)?;
-            }
             (img, txt) = block.forward(&img, &txt, &vec_, &pe)?;
-            if self.offloaded {
-                block.cast_to(&Device::Cpu)?;
-            }
         }
         // Single blocks
         let mut img = Tensor::cat(&[&txt, &img], 1)?;
         for block in self.single_blocks.iter_mut() {
-            if self.offloaded {
-                block.cast_to(&self.device)?;
-            }
             img = block.forward(&img, &vec_, &pe)?;
-            if self.offloaded {
-                block.cast_to(&Device::Cpu)?;
-            }
         }
         let img = img.i((.., txt.dim(1)?..))?;
         self.final_layer.forward(&img, &vec_)
     }
 }
+
+impl IsqModel for Flux {
+    fn get_layers(
+        &mut self,
+    ) -> (
+        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
+        &dyn DeviceMapper,
+    ) {
+        let mut layers = vec![
+            &mut self.img_in,
+            &mut self.txt_in,
+            &mut self.time_in.in_layer,
+            &mut self.time_in.out_layer,
+            &mut self.vector_in.in_layer,
+            &mut self.vector_in.out_layer,
+            &mut self.final_layer.linear,
+            &mut self.final_layer.ada_ln_modulation,
+        ];
+
+        for double_layer in &mut self.double_blocks {
+            layers.push(&mut double_layer.img_attn.proj);
+            layers.push(&mut double_layer.img_attn.qkv);
+            layers.push(&mut double_layer.img_mlp.lin1);
+            layers.push(&mut double_layer.img_mlp.lin2);
+            layers.push(&mut double_layer.img_mod.lin);
+
+            layers.push(&mut double_layer.txt_attn.proj);
+            layers.push(&mut double_layer.txt_attn.qkv);
+            layers.push(&mut double_layer.txt_mlp.lin1);
+            layers.push(&mut double_layer.txt_mlp.lin2);
+            layers.push(&mut double_layer.txt_mod.lin);
+        }
+
+        for single_layer in &mut self.single_blocks {
+            layers.push(&mut single_layer.linear1);
+            layers.push(&mut single_layer.linear2);
+            layers.push(&mut single_layer.modulation.lin);
+        }
+
+        (
+            layers.into_iter().map(|l| (l, None)).collect(),
+            &*self.mapper,
+        )
+    }
+}
diff --git a/mistralrs-core/src/diffusion_models/flux/stepper.rs b/mistralrs-core/src/diffusion_models/flux/stepper.rs
index 3a7501a9a..cace459b2 100644
--- a/mistralrs-core/src/diffusion_models/flux/stepper.rs
+++ b/mistralrs-core/src/diffusion_models/flux/stepper.rs
@@ -13,7 +13,7 @@ use crate::{
         t5::{self, T5EncoderModel},
         DiffusionGenerationParams,
     },
-    pipeline::DiffusionModel,
+    pipeline::{DiffusionModel, IsqModel},
     utils::varbuilder_utils::from_mmaped_safetensors,
 };
 
@@ -159,7 +159,12 @@ impl FluxStepper {
         device: &Device,
         silent: bool,
         offloaded: bool,
+        loading_isq: bool,
     ) -> anyhow::Result<Self> {
+        if offloaded {
+            anyhow::bail!("`offloaded` is not supported.");
+        }
+
         let api = Api::new()?;
 
         info!("Loading T5 XXL tokenizer.");
@@ -172,7 +177,7 @@ impl FluxStepper {
             t5_tok: t5_tokenizer,
             clip_tok: clip_tokenizer,
             clip_text: clip_encoder,
-            flux_model: Flux::new(flux_cfg, flux_vb, device.clone(), offloaded)?,
+            flux_model: Flux::new(flux_cfg, flux_vb, device.clone(), loading_isq)?,
             flux_vae: AutoEncoder::new(flux_ae_cfg, flux_ae_vb)?,
             is_guidance: cfg.is_guidance,
             device: device.clone(),
@@ -281,3 +286,17 @@ impl DiffusionModel for FluxStepper {
         }
     }
 }
+
+impl IsqModel for FluxStepper {
+    fn get_layers(
+        &mut self,
+    ) -> (
+        Vec<(
+            &mut std::sync::Arc<dyn mistralrs_quant::QuantMethod>,
+            Option<usize>,
+        )>,
+        &dyn crate::device_map::DeviceMapper,
+    ) {
+        self.flux_model.get_layers()
+    }
+}
diff --git a/mistralrs-core/src/pipeline/diffusion.rs b/mistralrs-core/src/pipeline/diffusion.rs
index cdb73ca78..13383c5bb 100644
--- a/mistralrs-core/src/pipeline/diffusion.rs
+++ b/mistralrs-core/src/pipeline/diffusion.rs
@@ -2,7 +2,7 @@ use super::loaders::{DiffusionModelPaths, DiffusionModelPathsInner};
 use super::{
     AdapterActivationMixin, AnyMoePipelineMixin, Cache, CacheManagerMixin, DiffusionLoaderType,
     DiffusionModel, DiffusionModelLoader, FluxLoader, ForwardInputsResult, GeneralMetadata,
-    IsqPipelineMixin, Loader, MetadataMixin, ModelCategory, ModelKind, ModelPaths,
+    IsqOrganization, IsqPipelineMixin, Loader, MetadataMixin, ModelCategory, ModelKind, ModelPaths,
     PreProcessingMixin, Processor, TokenSource,
 };
 use crate::diffusion_models::processor::{DiffusionProcessor, ModelInputs};
@@ -150,10 +150,6 @@ impl Loader for DiffusionLoader {
             anyhow::bail!("Device mapping is not supported for Diffusion models.");
         }
 
-        if in_situ_quant.is_some() {
-            anyhow::bail!("ISQ is not supported for Diffusion models.");
-        }
-
         if paged_attn_config.is_some() {
             warn!("PagedAttention is not supported for Diffusion models, disabling it.");
 
@@ -175,12 +171,12 @@ impl Loader for DiffusionLoader {
             AttentionImplementation::Eager
         };
 
-        let model = match self.kind {
+        let mut model = match self.kind {
            ModelKind::Normal => {
                let vbs = paths
                    .filenames
                    .iter()
-                    .zip(self.inner.force_cpu_vb())
+                    .zip(self.inner.force_cpu_vb(in_situ_quant.is_some()))
                    .map(|(path, force_cpu)| {
                        from_mmaped_safetensors(
                            vec![path.clone()],
@@ -200,7 +196,7 @@ impl Loader for DiffusionLoader {
                     vbs,
                     crate::pipeline::NormalLoadingMetadata {
                         mapper,
-                        loading_isq: false,
+                        loading_isq: in_situ_quant.is_some(),
                         real_device: device.clone(),
                     },
                     attention_mechanism,
@@ -210,6 +206,17 @@ impl Loader for DiffusionLoader {
             _ => unreachable!(),
         };
 
+        if in_situ_quant.is_some() {
+            model.quantize(
+                in_situ_quant,
+                device.clone(),
+                None,
+                silent,
+                IsqOrganization::Default,
+                None, // self.config.write_uqff.as_ref(),
+            )?;
+        }
+
         let max_seq_len = model.max_seq_len();
         Ok(Arc::new(Mutex::new(DiffusionPipeline {
             model,
diff --git a/mistralrs-core/src/pipeline/loaders/diffusion_loaders.rs b/mistralrs-core/src/pipeline/loaders/diffusion_loaders.rs
index eaf4262ff..f937ca47c 100644
--- a/mistralrs-core/src/pipeline/loaders/diffusion_loaders.rs
+++ b/mistralrs-core/src/pipeline/loaders/diffusion_loaders.rs
@@ -30,11 +30,12 @@ use crate::{
     },
     lora::LoraConfig,
     paged_attention::AttentionImplementation,
+    pipeline::IsqModel,
     xlora_models::XLoraConfig,
     Ordering,
 };
 
-pub trait DiffusionModel {
+pub trait DiffusionModel: IsqModel {
     /// This returns a tensor of shape (bs, c, h, w), with values in [0, 255].
     fn forward(
         &mut self,
@@ -50,7 +51,7 @@ pub trait DiffusionModelLoader {
     fn get_model_paths(&self, api: &ApiRepo, model_id: &Path) -> Result<Vec<PathBuf>>;
     /// If the model is being loaded with `load_model_from_hf` (so manual paths not provided), this will be called.
     fn get_config_filenames(&self, api: &ApiRepo, model_id: &Path) -> Result<Vec<PathBuf>>;
-    fn force_cpu_vb(&self) -> Vec<bool>;
+    fn force_cpu_vb(&self, loading_isq: bool) -> Vec<bool>;
     // `configs` and `vbs` should be corresponding. It is up to the implementer to maintain this invariant.
     fn load(
         &self,
@@ -166,8 +167,8 @@ impl DiffusionModelLoader for FluxLoader {
         // NOTE(EricLBuehler): disgusting way of doing this but the 0th path is the flux, 1 is ae
         Ok(vec![flux_file, ae_file])
     }
-    fn force_cpu_vb(&self) -> Vec<bool> {
-        vec![self.offload, false]
+    fn force_cpu_vb(&self, loading_isq: bool) -> Vec<bool> {
+        vec![self.offload || loading_isq, false]
     }
     fn load(
         &self,
@@ -200,6 +201,7 @@ impl DiffusionModelLoader for FluxLoader {
             &normal_loading_metadata.real_device,
             silent,
             self.offload,
+            normal_loading_metadata.loading_isq,
         )?))
     }
 }
diff --git a/mistralrs-server/src/interactive_mode.rs b/mistralrs-server/src/interactive_mode.rs
index 3575f3fea..643015cd9 100644
--- a/mistralrs-server/src/interactive_mode.rs
+++ b/mistralrs-server/src/interactive_mode.rs
@@ -418,7 +418,7 @@ async fn diffusion_interactive_mode(mistralrs: Arc<MistralRs>) {
     info!("Starting interactive loop with generation params: {diffusion_params:?}");
 
     println!(
-        "{}{TEXT_INTERACTIVE_HELP}{}",
+        "{}{DIFFUSION_INTERACTIVE_HELP}{}",
         "=".repeat(20),
         "=".repeat(20)
     );