diff --git a/cog.yaml.template b/cog.yaml.template
index 4be2f0b..dc304e5 100644
--- a/cog.yaml.template
+++ b/cog.yaml.template
@@ -20,13 +20,15 @@ build:
     - "tokenizers==0.19.1"
     - "protobuf==5.27.2"
     - "diffusers==0.29.2"
-    - "loguru"
-    - "pybase64"
+    - "loguru==0.7.2"
+    - "pybase64==1.4.0"
+    - "pydash==8.0.3"
   # commands run after the environment is setup
   run:
     - curl -o /usr/local/bin/pget -L "https://github.com/replicate/pget/releases/download/v0.8.2/pget_linux_x86_64" && chmod +x /usr/local/bin/pget
     - pip uninstall -y torch torchvision torchaudio
-    - pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu124
-    - pip install pydash
\ No newline at end of file
+    # - pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu124
+    # pinning to specific nightlies for release
+    - pip3 install https://download.pytorch.org/whl/nightly/cu124/torch-2.6.0.dev20240918%2Bcu124-cp311-cp311-linux_x86_64.whl https://download.pytorch.org/whl/nightly/cu124/torchaudio-2.5.0.dev20240918%2Bcu124-cp311-cp311-linux_x86_64.whl https://download.pytorch.org/whl/nightly/cu124/torchvision-0.20.0.dev20240918%2Bcu124-cp311-cp311-linux_x86_64.whl https://download.pytorch.org/whl/nightly/pytorch_triton-3.1.0%2B5fe38ffd73-cp311-cp311-linux_x86_64.whl
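
A minimal sanity-check sketch for the pinned nightlies above, assuming the image built cleanly; the expected version strings are read off the wheel filenames in the diff, and none of this code is part of the repo:

```python
# Verify that the pinned cu124 nightlies actually landed in the image.
import torch
import torchaudio
import torchvision

assert torch.__version__ == "2.6.0.dev20240918+cu124", torch.__version__
assert torchaudio.__version__ == "2.5.0.dev20240918+cu124", torchaudio.__version__
assert torchvision.__version__ == "0.20.0.dev20240918+cu124", torchvision.__version__
assert torch.version.cuda == "12.4"  # cu124 wheels are built against CUDA 12.4
print("pinned nightlies OK")
```
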
diff --git a/fp8/__init__.py b/fp8/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/configs/config-1-flux-dev-h100.json b/fp8/configs/config-1-flux-dev-h100.json
similarity index 100%
rename from configs/config-1-flux-dev-h100.json
rename to fp8/configs/config-1-flux-dev-h100.json
diff --git a/configs/config-1-flux-schnell-h100.json b/fp8/configs/config-1-flux-schnell-h100.json
similarity index 100%
rename from configs/config-1-flux-schnell-h100.json
rename to fp8/configs/config-1-flux-schnell-h100.json
diff --git a/float8_quantize.py b/fp8/float8_quantize.py
similarity index 99%
rename from float8_quantize.py
rename to fp8/float8_quantize.py
index 4161eb0..3e48e91 100644
--- a/float8_quantize.py
+++ b/fp8/float8_quantize.py
@@ -7,7 +7,7 @@
 from torch import __version__
 from torch.version import cuda
 
-from modules.flux_model import Modulation
+from fp8.modules.flux_model import Modulation
 
 IS_TORCH_2_4 = __version__ < (2, 4, 9)
 LT_TORCH_2_4 = __version__ < (2, 4)
diff --git a/flux_pipeline.py b/fp8/flux_pipeline.py
similarity index 98%
rename from flux_pipeline.py
rename to fp8/flux_pipeline.py
index dfd80bd..6e42193 100644
--- a/flux_pipeline.py
+++ b/fp8/flux_pipeline.py
@@ -15,11 +15,6 @@
 import torch
 from einops import rearrange, repeat
 
-torch.backends.cuda.matmul.allow_tf32 = True
-torch.backends.cudnn.allow_tf32 = True
-torch.backends.cudnn.benchmark = True
-torch.backends.cudnn.benchmark_limit = 20
-torch.set_float32_matmul_precision("high")
 from pybase64 import standard_b64decode
 from torch._dynamo import config
 from torch._inductor import config as ind_config
@@ -32,9 +27,9 @@
 from torchvision.transforms import functional as TF
 from tqdm import tqdm
 
-import lora_loading
-from image_encoder import ImageEncoder
-from util import (
+import fp8.lora_loading as lora_loading
+from fp8.image_encoder import ImageEncoder
+from fp8.util import (
     LoadedModels,
     ModelSpec,
     ModelVersion,
@@ -51,9 +46,9 @@
 
 if TYPE_CHECKING:
-    from modules.autoencoder import AutoEncoder
-    from modules.conditioner import HFEmbedder
-    from modules.flux_model import Flux
+    from fp8.modules.autoencoder import AutoEncoder
+    from fp8.modules.conditioner import HFEmbedder
+    from fp8.modules.flux_model import Flux
 
 
 class FluxPipeline:
@@ -738,7 +733,7 @@ def load_pipeline_from_config_path(
     def load_pipeline_from_config(
         cls, config: ModelSpec, debug: bool = False, shared_models: LoadedModels = None
    ) -> "FluxPipeline":
-        from float8_quantize import quantize_flow_transformer_and_dispatch_float8
+        from fp8.float8_quantize import quantize_flow_transformer_and_dispatch_float8
 
         with torch.inference_mode():
             if debug:
diff --git a/image_encoder.py b/fp8/image_encoder.py
similarity index 100%
rename from image_encoder.py
rename to fp8/image_encoder.py
diff --git a/lora_loading.py b/fp8/lora_loading.py
similarity index 99%
rename from lora_loading.py
rename to fp8/lora_loading.py
index b5f65e4..935df0c 100644
--- a/lora_loading.py
+++ b/fp8/lora_loading.py
@@ -7,8 +7,8 @@
     from cublas_ops import CublasLinear
 except Exception as e:
     CublasLinear = type(None)
-from float8_quantize import F8Linear
-from modules.flux_model import Flux
+from fp8.float8_quantize import F8Linear
+from fp8.modules.flux_model import Flux
 
 
 def swap_scale_shift(weight):
diff --git a/modules/autoencoder.py b/fp8/modules/autoencoder.py
similarity index 100%
rename from modules/autoencoder.py
rename to fp8/modules/autoencoder.py
diff --git a/modules/conditioner.py b/fp8/modules/conditioner.py
similarity index 100%
rename from modules/conditioner.py
rename to fp8/modules/conditioner.py
diff --git a/modules/flux_model.py b/fp8/modules/flux_model.py
similarity index 98%
rename from modules/flux_model.py
rename to fp8/modules/flux_model.py
index cf4e69f..b34cd3e 100644
--- a/modules/flux_model.py
+++ b/fp8/modules/flux_model.py
@@ -4,7 +4,7 @@
 import torch
 
 if TYPE_CHECKING:
-    from util import ModelSpec
+    from fp8.util import ModelSpec
 
 DISABLE_COMPILE = os.getenv("DISABLE_COMPILE", "0") == "1"
 torch.backends.cuda.matmul.allow_tf32 = True
@@ -118,7 +118,7 @@ class MLPEmbedder(nn.Module):
     def __init__(
         self, in_dim: int, hidden_dim: int, prequantized: bool = False, quantized=False
     ):
-        from float8_quantize import F8Linear
+        from fp8.float8_quantize import F8Linear
 
         super().__init__()
         self.in_layer = (
@@ -188,7 +188,7 @@ def __init__(
         prequantized: bool = False,
     ):
         super().__init__()
-        from float8_quantize import F8Linear
+        from fp8.float8_quantize import F8Linear
 
         self.num_heads = num_heads
         head_dim = dim // num_heads
@@ -236,7 +236,7 @@ def forward(self, x: Tensor, pe: Tensor) -> Tensor:
 class Modulation(nn.Module):
     def __init__(self, dim: int, double: bool, quantized_modulation: bool = False):
         super().__init__()
-        from float8_quantize import F8Linear
+        from fp8.float8_quantize import F8Linear
 
         self.is_double = double
         self.multiplier = 6 if double else 3
@@ -272,7 +272,7 @@ def __init__(
         prequantized: bool = False,
     ):
         super().__init__()
-        from float8_quantize import F8Linear
+        from fp8.float8_quantize import F8Linear
 
         self.dtype = dtype
 
@@ -417,7 +417,7 @@ def __init__(
         prequantized: bool = False,
     ):
         super().__init__()
-        from float8_quantize import F8Linear
+        from fp8.float8_quantize import F8Linear
 
         self.dtype = dtype
         self.hidden_dim = hidden_size
@@ -515,7 +515,7 @@ def __init__(self, config: "ModelSpec", dtype: torch.dtype = torch.float16):
         prequantized_flow = config.prequantized_flow
         quantized_embedders = config.quantize_flow_embedder_layers and prequantized_flow
         quantized_modulation = config.quantize_modulation and prequantized_flow
-        from float8_quantize import F8Linear
+        from fp8.float8_quantize import F8Linear
 
         if config.params.hidden_size % config.params.num_heads != 0:
             raise ValueError(
@@ -671,7 +671,7 @@ def forward(
     def from_pretrained(
         cls: "Flux", path: str, dtype: torch.dtype = torch.float16
     ) -> "Flux":
-        from util import load_config_from_path
+        from fp8.util import load_config_from_path
         from safetensors.torch import load_file
 
         config = load_config_from_path(path)
diff --git a/util.py b/fp8/util.py
similarity index 85%
rename from util.py
rename to fp8/util.py
index e7caca9..73c8146 100644
--- a/util.py
+++ b/fp8/util.py
@@ -4,9 +4,9 @@
 from typing import Any, Literal, Optional
 
 import torch
-from modules.autoencoder import AutoEncoder, AutoEncoderParams
-from modules.conditioner import HFEmbedder
-from modules.flux_model import Flux, FluxParams
+from fp8.modules.autoencoder import AutoEncoder, AutoEncoderParams
+from fp8.modules.conditioner import HFEmbedder
+from fp8.modules.flux_model import Flux, FluxParams
 from safetensors.torch import load_file as load_sft
 
 try:
@@ -34,55 +34,10 @@ class QuantizationDtype(StrEnum):
     qint8 = "qint8"
 
 
-# @dataclass
-# class ModelSpec:
-#     version: ModelVersion
-#     params: FluxParams
-#     ae_params: AutoEncoderParams
-#     ckpt_path: str | None
-#     ae_path: str | None
-#     repo_id: str | None
-#     repo_flow: str | None
-#     repo_ae: str | None
-#     text_enc_max_length: int = 512
-#     text_enc_path: str | None = None
-#     text_enc_device: str | torch.device | None = "cuda:0"
-#     ae_device: str | torch.device | None = "cuda:0"
-#     flux_device: str | torch.device | None = "cuda:0"
-#     flow_dtype: str = "float16"
-#     ae_dtype: str = "bfloat16"
-#     text_enc_dtype: str = "bfloat16"
-#     num_to_quant: Optional[int] = 20
-#     quantize_extras: bool = False
-#     compile_extras: bool = False
-#     compile_blocks: bool = False
-#     flow_quantization_dtype: Optional[QuantizationDtype] = QuantizationDtype.qfloat8
-#     text_enc_quantization_dtype: Optional[QuantizationDtype] = QuantizationDtype.qfloat8
-#     ae_quantization_dtype: Optional[QuantizationDtype] = None
-#     clip_quantization_dtype: Optional[QuantizationDtype] = None
-#     offload_text_encoder: bool = False
-#     offload_vae: bool = False
-#     offload_flow: bool = False
-#     prequantized_flow: bool = False
-#     quantize_modulation: bool = True
-#     quantize_flow_embedder_layers: bool = False
-
-# @dataclass
-# class LoadedModels:
-#     flow: Flux
-#     ae: AutoEncoder
-#     clip: HFEmbedder
-#     t5: HFEmbedder
-#     config: ModelSpec
-
 class ModelSpec(BaseModel):
     class Config:
         arbitrary_types_allowed = True
         use_enum_values = True
-    # model_config: ConfigDict = {
-    #     "arbitrary_types_allowed": True,
-    #     "use_enum_values": True,
-    # }
     version: ModelVersion
     params: FluxParams
     ae_params: AutoEncoderParams
@@ -325,7 +280,7 @@ def load_autoencoder(config: ModelSpec) -> AutoEncoder:
         print_load_warning(missing, unexpected)
     ae.to(device=into_device(config.ae_device), dtype=into_dtype(config.ae_dtype))
     if config.ae_quantization_dtype is not None:
-        from float8_quantize import recursive_swap_linears
+        from fp8.float8_quantize import recursive_swap_linears
 
         recursive_swap_linears(ae)
     if config.offload_vae:
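
As a side note on the ModelSpec context above, here is a minimal, hypothetical sketch of what the retained pydantic `Config` buys (the `Widget` and `Color` classes are invented for illustration; they stand in for non-pydantic field types such as torch devices/param containers):

```python
# Hypothetical illustration of the Config options kept on ModelSpec above.
from enum import Enum

from pydantic import BaseModel


class Widget:  # a plain, non-pydantic class, like several ModelSpec field types
    pass


class Color(str, Enum):
    red = "red"


class Spec(BaseModel):
    class Config:
        arbitrary_types_allowed = True  # permit plain classes as field types
        use_enum_values = True          # store "red" rather than Color.red

    widget: Widget
    color: Color = Color.red
```
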
diff --git a/predict.py b/predict.py
index 792e77d..82efdec 100644
--- a/predict.py
+++ b/predict.py
@@ -13,8 +13,8 @@
 from attr import dataclass
 from flux.sampling import denoise, get_noise, get_schedule, prepare, unpack
 
-from flux_pipeline import FluxPipeline
-from util import LoadedModels
+from fp8.flux_pipeline import FluxPipeline
+from fp8.util import LoadedModels
 
 import numpy as np
 from einops import rearrange
@@ -105,7 +105,12 @@ class Predictor(BasePredictor):
     def setup(self) -> None:
         return
 
-    def base_setup(self, flow_model_name: str) -> None:
+    def base_setup(
+        self,
+        flow_model_name: str,
+        compile_fp8: bool = False,
+        compile_bf16: bool = False,
+    ) -> None:
         self.flow_model_name = flow_model_name
         print(f"Booting model {self.flow_model_name}")
@@ -157,10 +162,18 @@ def base_setup(self, flow_model_name: str) -> None:
         )
 
         self.fp8_pipe = FluxPipeline.load_pipeline_from_config_path(
-            f"configs/config-1-{flow_model_name}-h100.json", shared_models=shared_models
+            f"fp8/configs/config-1-{flow_model_name}-h100.json",
+            shared_models=shared_models,
         )
 
-        print("compiling")
+        if compile_fp8:
+            self.compile_fp8()
+
+        if compile_bf16:
+            self.compile_bf16()
+
+    def compile_fp8(self):
+        print("compiling fp8 model")
         st = time.time()
         self.fp8_pipe.generate(
             prompt="a cool dog",
@@ -181,6 +194,21 @@ def base_setup(self, flow_model_name: str) -> None:
 
         print("compiled in ", time.time() - st)
 
+    def compile_bf16(self):
+        print("compiling bf16 model")
+        st = time.time()
+
+        self.compile_run = True
+        self.base_predict(
+            prompt="a cool dog",
+            aspect_ratio="1:1",
+            num_outputs=1,
+            num_inference_steps=self.num_steps,
+            guidance=3.5,
+            seed=123,
+        )
+        print("compiled in ", time.time() - st)
+
     def aspect_ratio_to_width_height(self, aspect_ratio: str):
         return ASPECT_RATIOS.get(aspect_ratio)
 
@@ -211,7 +239,6 @@ def base_predict(
         image: Path = None,  # img2img for flux-dev
         prompt_strength: float = 0.8,
         seed: Optional[int] = None,
-        profile: bool = None,
     ) -> List[Path]:
         """Run a single prediction on the model"""
         torch_device = torch.device("cuda")
@@ -280,37 +307,15 @@ def base_predict(
             torch.cuda.empty_cache()
             self.flux = self.flux.to(torch_device)
 
-        if self.compile_run:
-            print("Compiling")
-            st = time.time()
-
-        if profile:
-            with torch.profiler.profile(
-                activities=[
-                    torch.profiler.ProfilerActivity.CPU,
-                    torch.profiler.ProfilerActivity.CUDA,
-                ]
-            ) as p:
-                x, flux = denoise(
-                    self.flux,
-                    **inp,
-                    timesteps=timesteps,
-                    guidance=guidance,
-                    compile_run=self.compile_run,
-                )
-
-            p.export_chrome_trace("trace.json")
-        else:
-            x, flux = denoise(
-                self.flux,
-                **inp,
-                timesteps=timesteps,
-                guidance=guidance,
-                compile_run=self.compile_run,
-            )
+        x, flux = denoise(
+            self.flux,
+            **inp,
+            timesteps=timesteps,
+            guidance=guidance,
+            compile_run=self.compile_run,
+        )
 
         if self.compile_run:
-            print(f"Compiled in {time.time() - st}")
             self.compile_run = False
             self.flux = flux
@@ -377,7 +382,6 @@ def postprocess(
         output_format: str,
         output_quality: int,
         np_images: Optional[List[Image]] = None,
-        profile: bool = False,
     ) -> List[Path]:
         has_nsfw_content = [False] * len(images)
@@ -414,8 +418,6 @@ def postprocess(
         )
 
         print(f"Total safe images: {len(output_paths)} out of {len(images)}")
-        if profile:
-            output_paths.append(Path("trace.json"))
 
         return output_paths
 
     def run_safety_checker(self, images, np_images):
@@ -441,19 +443,7 @@ def run_falcon_safety_checker(self, image):
 
 class SchnellPredictor(Predictor):
     def setup(self) -> None:
-        self.base_setup("flux-schnell")
-
-        # this is how we compile the bf16 model
-        # self.compile_run = True
-        # self.predict(
-        #     prompt="a cool dog",
-        #     aspect_ratio="1:1",
-        #     num_outputs=1,
-        #     output_format='png',
-        #     output_quality=80,
-        #     disable_safety_checker=True,
-        #     seed=123
-        # )
+        self.base_setup("flux-schnell", compile_fp8=True)
 
     def predict(
         self,
@@ -494,23 +484,7 @@ def predict(
 
 class DevPredictor(Predictor):
     def setup(self) -> None:
-        self.base_setup("flux-dev")
-
-        # this is how we compile the bf16 model
-        # self.compile_run = True
-        # self.predict(
-        #     prompt="a cool dog",
-        #     aspect_ratio="1:1",
-        #     image=None,
-        #     prompt_strength=1,
-        #     num_outputs=1,
-        #     num_inference_steps=self.num_steps,
-        #     guidance=3.5,
-        #     output_format='png',
-        #     output_quality=80,
-        #     disable_safety_checker=True,
-        #     seed=123
-        # )
+        self.base_setup("flux-dev", compile_fp8=True)
 
     def predict(
         self,
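
For reference, a usage sketch of the new opt-in compile flags; the predictor classes and flag names come straight from the diff, while the driver lines themselves are illustrative:

```python
# Illustrative driver for the warm-up/compile flow introduced above.
from predict import SchnellPredictor

predictor = SchnellPredictor()
predictor.setup()  # calls base_setup("flux-schnell", compile_fp8=True),
                   # which triggers compile_fp8()'s warm-up generation;
                   # passing compile_bf16=True would also warm up the bf16
                   # path via compile_bf16() -> base_predict(...)
```
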
diff --git a/ruff.toml b/ruff.toml
index 9499594..33bd358 100644
--- a/ruff.toml
+++ b/ruff.toml
@@ -26,11 +26,8 @@ exclude = [
     "node_modules",
     "site-packages",
     "venv",
-    "modules/",
     "flux/",
-    "util.py",
-    "lora_loading.py",
-    "flux_pipeline.py",
+    "fp8/",
 ]
 
 # Same as Black.