More gardening for open source release (#22)
* removed comments, refactored compilation code

* pinning versions

* fp8 module

* pinned nightlies + fp8 paths

* removing half-implemented profiling

* ruff
daanelson authored Sep 23, 2024
1 parent 5c8c435 commit 99cfbb7
Showing 14 changed files with 71 additions and 148 deletions.
10 changes: 6 additions & 4 deletions cog.yaml.template
@@ -20,13 +20,15 @@ build:
- "tokenizers==0.19.1"
- "protobuf==5.27.2"
- "diffusers==0.29.2"
-- "loguru"
-- "pybase64"
+- "loguru==0.7.2"
+- "pybase64==1.4.0"
+- "pydash==8.0.3"


# commands run after the environment is setup
run:
- curl -o /usr/local/bin/pget -L "https://github.com/replicate/pget/releases/download/v0.8.2/pget_linux_x86_64" && chmod +x /usr/local/bin/pget
- pip uninstall -y torch torchvision torchaudio
-- pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu124
-- pip install pydash
+# - pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu124
+# pinning to specific nightlies for release
+- pip3 install https://download.pytorch.org/whl/nightly/cu124/torch-2.6.0.dev20240918%2Bcu124-cp311-cp311-linux_x86_64.whl https://download.pytorch.org/whl/nightly/cu124/torchaudio-2.5.0.dev20240918%2Bcu124-cp311-cp311-linux_x86_64.whl https://download.pytorch.org/whl/nightly/cu124/torchvision-0.20.0.dev20240918%2Bcu124-cp311-cp311-linux_x86_64.whl https://download.pytorch.org/whl/nightly/pytorch_triton-3.1.0%2B5fe38ffd73-cp311-cp311-linux_x86_64.whl
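A quick way to confirm that the pinned nightlies actually took effect in the built image is to import the three packages and compare their reported versions against the wheel filenames pinned above. A minimal sketch; the expected values are read off the wheel names, not captured from a real build:

# Sanity-check the pinned PyTorch nightlies after the `run` steps complete.
# Expected values are inferred from the wheel filenames pinned above.
import torch
import torchaudio
import torchvision

print(torch.__version__)        # expected to start with "2.6.0.dev20240918+cu124"
print(torchaudio.__version__)   # expected to start with "2.5.0.dev20240918+cu124"
print(torchvision.__version__)  # expected to start with "0.20.0.dev20240918+cu124"
print(torch.version.cuda)       # expected "12.4"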
Empty file added fp8/__init__.py
Empty file.
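The empty __init__.py is what turns fp8/ into an importable package, which the fp8.* import rewrites throughout this commit rely on. A minimal import check, using only names that appear in the hunks below; it assumes the repo root is on PYTHONPATH and the dependencies from cog.yaml.template are installed:

# All symbols below are taken from the import rewrites in this commit.
from fp8.float8_quantize import F8Linear
from fp8.flux_pipeline import FluxPipeline
from fp8.modules.flux_model import Flux, Modulation
from fp8.util import ModelSpec, load_config_from_path

print(F8Linear, FluxPipeline, Flux, Modulation, ModelSpec, load_config_from_path)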
File renamed without changes.
File renamed without changes.
2 changes: 1 addition & 1 deletion float8_quantize.py → fp8/float8_quantize.py
@@ -7,7 +7,7 @@
from torch import __version__
from torch.version import cuda

-from modules.flux_model import Modulation
+from fp8.modules.flux_model import Modulation

IS_TORCH_2_4 = __version__ < (2, 4, 9)
LT_TORCH_2_4 = __version__ < (2, 4)
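The IS_TORCH_2_4 / LT_TORCH_2_4 gates above work because torch exposes its version as a TorchVersion, a str subclass that also compares against tuples. A standalone illustration, assuming torch is installed; the version string is the nightly pinned in cog.yaml.template, and the printed results follow from that string rather than from a captured run:

from torch.torch_version import TorchVersion

v = TorchVersion("2.6.0.dev20240918+cu124")  # version string of the pinned nightly wheel
print(v < (2, 4))     # False: the 2.6 nightly is newer than 2.4
print(v < (2, 4, 9))  # False: also newer than 2.4.9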
19 changes: 7 additions & 12 deletions flux_pipeline.py → fp8/flux_pipeline.py
@@ -15,11 +15,6 @@
import torch
from einops import rearrange, repeat

-torch.backends.cuda.matmul.allow_tf32 = True
-torch.backends.cudnn.allow_tf32 = True
-torch.backends.cudnn.benchmark = True
-torch.backends.cudnn.benchmark_limit = 20
-torch.set_float32_matmul_precision("high")
from pybase64 import standard_b64decode
from torch._dynamo import config
from torch._inductor import config as ind_config
@@ -32,9 +27,9 @@
from torchvision.transforms import functional as TF
from tqdm import tqdm

-import lora_loading
-from image_encoder import ImageEncoder
-from util import (
+import fp8.lora_loading as lora_loading
+from fp8.image_encoder import ImageEncoder
+from fp8.util import (
LoadedModels,
ModelSpec,
ModelVersion,
@@ -51,9 +46,9 @@


if TYPE_CHECKING:
-from modules.autoencoder import AutoEncoder
-from modules.conditioner import HFEmbedder
-from modules.flux_model import Flux
+from fp8.modules.autoencoder import AutoEncoder
+from fp8.modules.conditioner import HFEmbedder
+from fp8.modules.flux_model import Flux


class FluxPipeline:
@@ -738,7 +733,7 @@ def load_pipeline_from_config_path(
def load_pipeline_from_config(
cls, config: ModelSpec, debug: bool = False, shared_models: LoadedModels = None
) -> "FluxPipeline":
-from float8_quantize import quantize_flow_transformer_and_dispatch_float8
+from fp8.float8_quantize import quantize_flow_transformer_and_dispatch_float8

with torch.inference_mode():
if debug:
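The TYPE_CHECKING block above imports the fp8.modules classes for annotations only, so flux_pipeline.py does not pay for (or circularly depend on) those imports at runtime. A self-contained sketch of the same idiom; the describe function and its use of Flux are invented for the example:

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated by type checkers only; skipped entirely at runtime.
    from fp8.modules.flux_model import Flux

def describe(model: "Flux") -> str:
    # The quoted annotation means Flux never has to be importable at runtime here.
    return type(model).__name__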
File renamed without changes.
4 changes: 2 additions & 2 deletions lora_loading.py → fp8/lora_loading.py
@@ -7,8 +7,8 @@
from cublas_ops import CublasLinear
except Exception as e:
CublasLinear = type(None)
-from float8_quantize import F8Linear
-from modules.flux_model import Flux
+from fp8.float8_quantize import F8Linear
+from fp8.modules.flux_model import Flux


def swap_scale_shift(weight):
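The try/except above is an optional-dependency fallback: if cublas_ops is missing, CublasLinear is bound to type(None), so later isinstance checks quietly return False instead of raising. A standalone illustration; only the cublas_ops import and the type(None) fallback come from the hunk, the nn.Linear check is invented for the example:

import torch.nn as nn

try:
    from cublas_ops import CublasLinear  # optional dependency, as in the hunk above
except Exception:
    CublasLinear = type(None)            # NoneType; ordinary modules are never instances of it

layer = nn.Linear(4, 4)
print(isinstance(layer, CublasLinear))   # False whenever cublas_ops is not installed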
File renamed without changes.
File renamed without changes.
16 changes: 8 additions & 8 deletions modules/flux_model.py → fp8/modules/flux_model.py
@@ -4,7 +4,7 @@
import torch

if TYPE_CHECKING:
-from util import ModelSpec
+from fp8.util import ModelSpec

DISABLE_COMPILE = os.getenv("DISABLE_COMPILE", "0") == "1"
torch.backends.cuda.matmul.allow_tf32 = True
@@ -118,7 +118,7 @@ class MLPEmbedder(nn.Module):
def __init__(
self, in_dim: int, hidden_dim: int, prequantized: bool = False, quantized=False
):
-from float8_quantize import F8Linear
+from fp8.float8_quantize import F8Linear

super().__init__()
self.in_layer = (
@@ -188,7 +188,7 @@ def __init__(
prequantized: bool = False,
):
super().__init__()
-from float8_quantize import F8Linear
+from fp8.float8_quantize import F8Linear

self.num_heads = num_heads
head_dim = dim // num_heads
@@ -236,7 +236,7 @@ def forward(self, x: Tensor, pe: Tensor) -> Tensor:
class Modulation(nn.Module):
def __init__(self, dim: int, double: bool, quantized_modulation: bool = False):
super().__init__()
-from float8_quantize import F8Linear
+from fp8.float8_quantize import F8Linear

self.is_double = double
self.multiplier = 6 if double else 3
@@ -272,7 +272,7 @@ def __init__(
prequantized: bool = False,
):
super().__init__()
-from float8_quantize import F8Linear
+from fp8.float8_quantize import F8Linear

self.dtype = dtype

@@ -417,7 +417,7 @@ def __init__(
prequantized: bool = False,
):
super().__init__()
-from float8_quantize import F8Linear
+from fp8.float8_quantize import F8Linear

self.dtype = dtype
self.hidden_dim = hidden_size
@@ -515,7 +515,7 @@ def __init__(self, config: "ModelSpec", dtype: torch.dtype = torch.float16):
prequantized_flow = config.prequantized_flow
quantized_embedders = config.quantize_flow_embedder_layers and prequantized_flow
quantized_modulation = config.quantize_modulation and prequantized_flow
-from float8_quantize import F8Linear
+from fp8.float8_quantize import F8Linear

if config.params.hidden_size % config.params.num_heads != 0:
raise ValueError(
@@ -671,7 +671,7 @@ def forward(
def from_pretrained(
cls: "Flux", path: str, dtype: torch.dtype = torch.float16
) -> "Flux":
-from util import load_config_from_path
+from fp8.util import load_config_from_path
from safetensors.torch import load_file

config = load_config_from_path(path)
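Each constructor in flux_model.py imports F8Linear inside the function body rather than at module scope. That is a deferred-import pattern: fp8.float8_quantize imports Modulation from this module at import time (see its hunk above), so a top-level import in the opposite direction would be circular. A two-file sketch of the same situation, with invented module and class names standing in for the real ones:

# widget_mod.py -- plays the role of fp8/modules/flux_model.py
class Widget:                                 # stands in for Modulation
    def __init__(self, quantized: bool = False):
        # Deferred import: quant_mod imports Widget from this module at import
        # time, so importing QuantLinear at module scope here would be circular.
        from quant_mod import QuantLinear     # stands in for F8Linear
        self.layer = QuantLinear() if quantized else None

# quant_mod.py -- plays the role of fp8/float8_quantize.py
from widget_mod import Widget                 # module-level import in one direction only

class QuantLinear:
    pass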
53 changes: 4 additions & 49 deletions util.py → fp8/util.py
@@ -4,9 +4,9 @@
from typing import Any, Literal, Optional

import torch
-from modules.autoencoder import AutoEncoder, AutoEncoderParams
-from modules.conditioner import HFEmbedder
-from modules.flux_model import Flux, FluxParams
+from fp8.modules.autoencoder import AutoEncoder, AutoEncoderParams
+from fp8.modules.conditioner import HFEmbedder
+from fp8.modules.flux_model import Flux, FluxParams
from safetensors.torch import load_file as load_sft

try:
@@ -34,55 +34,10 @@ class QuantizationDtype(StrEnum):
qint8 = "qint8"


-# @dataclass
-# class ModelSpec:
-# version: ModelVersion
-# params: FluxParams
-# ae_params: AutoEncoderParams
-# ckpt_path: str | None
-# ae_path: str | None
-# repo_id: str | None
-# repo_flow: str | None
-# repo_ae: str | None
-# text_enc_max_length: int = 512
-# text_enc_path: str | None = None
-# text_enc_device: str | torch.device | None = "cuda:0"
-# ae_device: str | torch.device | None = "cuda:0"
-# flux_device: str | torch.device | None = "cuda:0"
-# flow_dtype: str = "float16"
-# ae_dtype: str = "bfloat16"
-# text_enc_dtype: str = "bfloat16"
-# num_to_quant: Optional[int] = 20
-# quantize_extras: bool = False
-# compile_extras: bool = False
-# compile_blocks: bool = False
-# flow_quantization_dtype: Optional[QuantizationDtype] = QuantizationDtype.qfloat8
-# text_enc_quantization_dtype: Optional[QuantizationDtype] = QuantizationDtype.qfloat8
-# ae_quantization_dtype: Optional[QuantizationDtype] = None
-# clip_quantization_dtype: Optional[QuantizationDtype] = None
-# offload_text_encoder: bool = False
-# offload_vae: bool = False
-# offload_flow: bool = False
-# prequantized_flow: bool = False
-# quantize_modulation: bool = True
-# quantize_flow_embedder_layers: bool = False
-
-# @dataclass
-# class LoadedModels:
-# flow: Flux
-# ae: AutoEncoder
-# clip: HFEmbedder
-# t5: HFEmbedder
-# config: ModelSpec
-
class ModelSpec(BaseModel):
class Config:
arbitrary_types_allowed = True
use_enum_values = True
-# model_config: ConfigDict = {
-# "arbitrary_types_allowed": True,
-# "use_enum_values": True,
-# }
version: ModelVersion
params: FluxParams
ae_params: AutoEncoderParams
@@ -325,7 +280,7 @@ def load_autoencoder(config: ModelSpec) -> AutoEncoder:
print_load_warning(missing, unexpected)
ae.to(device=into_device(config.ae_device), dtype=into_dtype(config.ae_dtype))
if config.ae_quantization_dtype is not None:
-from float8_quantize import recursive_swap_linears
+from fp8.float8_quantize import recursive_swap_linears

recursive_swap_linears(ae)
if config.offload_vae:
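With the commented-out dataclass variants deleted, ModelSpec exists only as a pydantic BaseModel, and its inner Config is what lets it hold torch-flavored field types and enum-valued settings. A minimal sketch of what those two flags do; Holder, CustomDevice, and Color are invented for the example, and it assumes the pydantic v1-style class Config shown above:

from enum import Enum
from pydantic import BaseModel

class Color(str, Enum):
    red = "red"

class CustomDevice:
    """Stands in for non-pydantic field types such as torch devices or params."""

class Holder(BaseModel):
    class Config:
        arbitrary_types_allowed = True  # allow fields typed with plain classes like CustomDevice
        use_enum_values = True          # store the enum's value ("red"), not the member (Color.red)

    device: CustomDevice
    color: Color

h = Holder(device=CustomDevice(), color=Color.red)
print(h.color)  # -> red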
