diff --git a/tests/torchtune/models/flux/__init__.py b/tests/torchtune/models/flux/__init__.py
new file mode 100644
index 0000000000..2e41cd717f
--- /dev/null
+++ b/tests/torchtune/models/flux/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
diff --git a/tests/torchtune/models/flux/test_flux_autoencoder.py b/tests/torchtune/models/flux/test_flux_autoencoder.py
new file mode 100644
index 0000000000..bb385dbc94
--- /dev/null
+++ b/tests/torchtune/models/flux/test_flux_autoencoder.py
@@ -0,0 +1,81 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import pytest
+import torch
+
+from torchtune.models.flux import flux_1_autoencoder
+from torchtune.training.seed import set_seed
+
+BSZ = 32
+CH_IN = 3
+RESOLUTION = 16
+CH_MULTS = [1, 2]
+CH_Z = 4
+RES_Z = RESOLUTION // len(CH_MULTS)
+
+
+@pytest.fixture(autouse=True)
+def random():
+    set_seed(0)
+
+
+class TestFluxAutoencoder:
+    @pytest.fixture
+    def model(self):
+        model = flux_1_autoencoder(
+            resolution=RESOLUTION,
+            ch_in=CH_IN,
+            ch_out=3,
+            ch_base=32,
+            ch_mults=CH_MULTS,
+            ch_z=CH_Z,
+            n_layers_per_resample_block=2,
+            scale_factor=1.0,
+            shift_factor=0.0,
+        )
+
+        for param in model.parameters():
+            param.data.uniform_(0, 0.1)
+
+        return model
+
+    @pytest.fixture
+    def img(self):
+        return torch.randn(BSZ, CH_IN, RESOLUTION, RESOLUTION)
+
+    @pytest.fixture
+    def z(self):
+        return torch.randn(BSZ, CH_Z, RES_Z, RES_Z)
+
+    def test_forward(self, model, img):
+        actual = model(img)
+        assert actual.shape == (BSZ, CH_IN, RESOLUTION, RESOLUTION)
+
+        actual = torch.mean(actual, dim=(0, 2, 3))
+        expected = torch.tensor([0.4286, 0.4276, 0.4054])
+        torch.testing.assert_close(actual, expected, atol=1e-4, rtol=1e-4)
+
+    def test_backward(self, model, img):
+        y = model(img)
+        loss = y.mean()
+        loss.backward()
+
+    def test_encode(self, model, img):
+        actual = model.encode(img)
+        assert actual.shape == (BSZ, CH_Z, RES_Z, RES_Z)
+
+        actual = torch.mean(actual, dim=(0, 2, 3))
+        expected = torch.tensor([0.6150, 0.7959, 0.7178, 0.7011])
+        torch.testing.assert_close(actual, expected, atol=1e-4, rtol=1e-4)
+
+    def test_decode(self, model, z):
+        actual = model.decode(z)
+        assert actual.shape == (BSZ, CH_IN, RESOLUTION, RESOLUTION)
+
+        actual = torch.mean(actual, dim=(0, 2, 3))
+        expected = torch.tensor([0.4246, 0.4241, 0.4014])
+        torch.testing.assert_close(actual, expected, atol=1e-4, rtol=1e-4)
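A note on the reference values in these tests: they are pinned by `set_seed(0)` (the autouse fixture) plus the `uniform_(0, 0.1)` parameter init, and each test consumes RNG state independently. Below is a minimal sketch (not part of this diff) of how such reference means could be regenerated; exact numbers depend on RNG consumption order, so the committed values remain authoritative:

```python
# Hypothetical helper for regenerating the reference means used above.
# Reseed before each model call to mirror the per-test fixture setup.
import torch

from torchtune.models.flux import flux_1_autoencoder
from torchtune.training.seed import set_seed


def make_model_and_img():
    set_seed(0)
    model = flux_1_autoencoder(
        resolution=16,
        ch_in=3,
        ch_out=3,
        ch_base=32,
        ch_mults=[1, 2],
        ch_z=4,
        n_layers_per_resample_block=2,
        scale_factor=1.0,
        shift_factor=0.0,
    )
    for param in model.parameters():
        param.data.uniform_(0, 0.1)
    return model, torch.randn(32, 3, 16, 16)


model, img = make_model_and_img()
print(torch.mean(model(img), dim=(0, 2, 3)))  # candidate `expected` for test_forward

model, img = make_model_and_img()
print(torch.mean(model.encode(img), dim=(0, 2, 3)))  # candidate for test_encode
```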
diff --git a/torchtune/models/flux/__init__.py b/torchtune/models/flux/__init__.py
new file mode 100644
index 0000000000..3d08ac24fc
--- /dev/null
+++ b/torchtune/models/flux/__init__.py
@@ -0,0 +1,10 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+from ._model_builders import flux_1_autoencoder
+
+__all__ = [
+    "flux_1_autoencoder",
+]
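With the builder exported here, end-to-end usage looks like the following sketch (default FLUX.1 config; shapes follow from the channel multipliers defined in `_model_builders.py` below):

```python
# Usage sketch for the new public API (default FLUX.1 config).
import torch

from torchtune.models.flux import flux_1_autoencoder

ae = flux_1_autoencoder()  # defaults: resolution=256, ch_z=16, ch_mults=[1, 2, 4, 4]

img = torch.randn(1, 3, 256, 256)
z = ae.encode(img)  # three downsamples (one per block except the last): 256 -> 32
print(z.shape)      # torch.Size([1, 16, 32, 32])
# note: encode() draws a reparameterized sample, so repeated encodes differ

out = ae.decode(z)
print(out.shape)    # torch.Size([1, 3, 256, 256])
```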
diff --git a/torchtune/models/flux/_autoencoder.py b/torchtune/models/flux/_autoencoder.py
new file mode 100644
index 0000000000..666178d1d8
--- /dev/null
+++ b/torchtune/models/flux/_autoencoder.py
@@ -0,0 +1,321 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+from typing import List, Tuple
+
+import torch
+import torch.nn.functional as F
+from torch import nn, Tensor
+
+from torchtune.modules.attention import MultiHeadAttention
+
+# ch = number of channels (size of the channel dimension)
+
+
+class FluxAutoencoder(nn.Module):
+    """
+    The image autoencoder for Flux diffusion models.
+
+    Args:
+        img_shape (Tuple[int, int, int]): The shape of the input image (without the batch dimension).
+        encoder (nn.Module): The encoder module.
+        decoder (nn.Module): The decoder module.
+    """
+
+    def __init__(
+        self,
+        img_shape: Tuple[int, int, int],
+        encoder: nn.Module,
+        decoder: nn.Module,
+    ):
+        super().__init__()
+        self._img_shape = img_shape
+        self.encoder = encoder
+        self.decoder = decoder
+
+    def forward(self, x: Tensor) -> Tensor:
+        """
+        Args:
+            x (Tensor): input image of shape [bsz, ch_in, img resolution, img resolution]
+
+        Returns:
+            Tensor: output image of the same shape
+        """
+        return self.decode(self.encode(x))
+
+    def encode(self, x: Tensor) -> Tensor:
+        """
+        Encode images into their latent representations.
+
+        Args:
+            x (Tensor): input images (shape = [bsz, ch_in, img resolution, img resolution])
+
+        Returns:
+            Tensor: latent encodings (shape = [bsz, ch_z, latent resolution, latent resolution])
+        """
+        assert x.shape[1:] == self._img_shape
+        return self.encoder(x)
+
+    def decode(self, z: Tensor) -> Tensor:
+        """
+        Decode latent representations into images.
+
+        Args:
+            z (Tensor): latent encodings (shape = [bsz, ch_z, latent resolution, latent resolution])
+
+        Returns:
+            Tensor: output images (shape = [bsz, ch_out, img resolution, img resolution])
+        """
+        return self.decoder(z)
+
+
+class FluxEncoder(nn.Module):
+    """
+    The encoder half of the Flux diffusion model's image autoencoder.
+
+    Args:
+        ch_in (int): The number of channels of the input image.
+        ch_z (int): The number of latent channels (dimension of the latent vector `z`).
+        channels (List[int]): The number of output channels for each downsample block.
+        n_layers_per_down_block (int): Number of resnet layers per downsample block.
+        scale_factor (float): Constant for scaling `z`.
+        shift_factor (float): Constant for shifting `z`.
+    """
+
+    def __init__(
+        self,
+        ch_in: int,
+        ch_z: int,
+        channels: List[int],
+        n_layers_per_down_block: int,
+        scale_factor: float,
+        shift_factor: float,
+    ):
+        super().__init__()
+        self.scale_factor = scale_factor
+        self.shift_factor = shift_factor
+
+        self.conv_in = nn.Conv2d(ch_in, channels[0], kernel_size=3, stride=1, padding=1)
+
+        self.down = nn.ModuleList(
+            [
+                DownBlock(
+                    n_layers=n_layers_per_down_block,
+                    ch_in=channels[i - 1] if i > 0 else channels[0],
+                    ch_out=channels[i],
+                    downsample=i < len(channels) - 1,
+                )
+                for i in range(len(channels))
+            ]
+        )
+
+        self.mid = mid_block(channels[-1])
+
+        self.end = end_block(channels[-1], 2 * ch_z)
+
+    def forward(self, x: Tensor) -> Tensor:
+        """
+        Args:
+            x (Tensor): input images (shape = [bsz, ch_in, img resolution, img resolution])
+
+        Returns:
+            Tensor: latent encodings (shape = [bsz, ch_z, latent resolution, latent resolution])
+        """
+        h = self.conv_in(x)
+        for block in self.down:
+            h = block(h)
+        h = self.mid(h)
+        h = self.end(h)
+        z = diagonal_gaussian(h)
+        return self.scale_factor * (z - self.shift_factor)
+
+
+class FluxDecoder(nn.Module):
+    """
+    The decoder half of the Flux diffusion model's image autoencoder.
+
+    Args:
+        ch_out (int): The number of channels of the output image.
+        ch_z (int): The number of latent channels (dimension of the latent vector `z`).
+        channels (List[int]): The number of output channels for each upsample block.
+        n_layers_per_up_block (int): Number of resnet layers per upsample block.
+        scale_factor (float): Constant for scaling `z`.
+        shift_factor (float): Constant for shifting `z`.
+    """
+
+    def __init__(
+        self,
+        ch_out: int,
+        ch_z: int,
+        channels: List[int],
+        n_layers_per_up_block: int,
+        scale_factor: float,
+        shift_factor: float,
+    ):
+        super().__init__()
+        self.scale_factor = scale_factor
+        self.shift_factor = shift_factor
+
+        self.conv_in = nn.Conv2d(ch_z, channels[0], kernel_size=3, stride=1, padding=1)
+
+        self.mid = mid_block(channels[0])
+
+        self.up = nn.ModuleList(
+            [
+                UpBlock(
+                    n_layers=n_layers_per_up_block,
+                    ch_in=channels[i - 1] if i > 0 else channels[0],
+                    ch_out=channels[i],
+                    upsample=i < len(channels) - 1,
+                )
+                for i in range(len(channels))
+            ]
+        )
+
+        self.end = end_block(channels[-1], ch_out)
+
+    def forward(self, z: Tensor) -> Tensor:
+        """
+        Args:
+            z (Tensor): latent encodings (shape = [bsz, ch_z, latent resolution, latent resolution])
+
+        Returns:
+            Tensor: output images (shape = [bsz, ch_out, img resolution, img resolution])
+        """
+        z = z / self.scale_factor + self.shift_factor
+        h = self.conv_in(z)
+        h = self.mid(h)
+        for block in self.up:
+            h = block(h)
+        x = self.end(h)
+        return x
+
+
+def mid_block(ch: int) -> nn.Module:
+    return nn.Sequential(
+        ResnetLayer(ch_in=ch, ch_out=ch),
+        AttnLayer(ch),
+        ResnetLayer(ch_in=ch, ch_out=ch),
+    )
+
+
+def end_block(ch_in: int, ch_out: int) -> nn.Module:
+    return nn.Sequential(
+        nn.GroupNorm(num_groups=32, num_channels=ch_in, eps=1e-6, affine=True),
+        nn.SiLU(),
+        nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1),
+    )
+
+
+class DownBlock(nn.Module):
+    def __init__(self, n_layers: int, ch_in: int, ch_out: int, downsample: bool):
+        super().__init__()
+        self.layers = resnet_layers(n_layers, ch_in, ch_out)
+        self.downsample = Downsample(ch_out) if downsample else nn.Identity()
+
+    def forward(self, x: Tensor) -> Tensor:
+        return self.downsample(self.layers(x))
+
+
+class UpBlock(nn.Module):
+    def __init__(self, n_layers: int, ch_in: int, ch_out: int, upsample: bool):
+        super().__init__()
+        self.layers = resnet_layers(n_layers, ch_in, ch_out)
+        self.upsample = Upsample(ch_out) if upsample else nn.Identity()
+
+    def forward(self, x: Tensor) -> Tensor:
+        return self.upsample(self.layers(x))
+
+
+def resnet_layers(n: int, ch_in: int, ch_out: int) -> nn.Module:
+    return nn.Sequential(
+        *[
+            ResnetLayer(ch_in=ch_in if i == 0 else ch_out, ch_out=ch_out)
+            for i in range(n)
+        ]
+    )
+
+
+class AttnLayer(nn.Module):
+    def __init__(self, dim: int):
+        super().__init__()
+        self.dim = dim
+        self.norm = nn.GroupNorm(num_groups=32, num_channels=dim, eps=1e-6, affine=True)
+        self.attn = MultiHeadAttention(
+            embed_dim=dim,
+            num_heads=1,
+            num_kv_heads=1,
+            head_dim=dim,
+            q_proj=nn.Linear(dim, dim),
+            k_proj=nn.Linear(dim, dim),
+            v_proj=nn.Linear(dim, dim),
+            output_proj=nn.Linear(dim, dim),
+            is_causal=False,
+        )
+
+    def forward(self, x: Tensor) -> Tensor:
+        b, c, h, w = x.shape
+        residual = x
+
+        x = self.norm(x)
+
+        # b c h w -> b (h w) c
+        x = torch.einsum("bchw -> bhwc", x)
+        x = x.reshape(b, h * w, c)
+
+        x = self.attn(x, x)
+
+        # b (h w) c -> b c h w
+        x = x.reshape(b, h, w, c)
+        x = torch.einsum("bhwc -> bchw", x)
+
+        return x + residual
+
+
+class ResnetLayer(nn.Module):
+    def __init__(self, ch_in: int, ch_out: int):
+        super().__init__()
+        self.main = nn.Sequential(
+            *[
+                nn.GroupNorm(num_groups=32, num_channels=ch_in, eps=1e-6, affine=True),
+                nn.SiLU(),
+                nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1),
+                nn.GroupNorm(num_groups=32, num_channels=ch_out, eps=1e-6, affine=True),
+                nn.SiLU(),
+                nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1),
+            ]
+        )
+        self.shortcut = (
+            nn.Identity()
+            if ch_in == ch_out
+            else nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=1, padding=0)
+        )
+
+    def forward(self, x: Tensor) -> Tensor:
+        return self.main(x) + self.shortcut(x)
+
+
+class Downsample(nn.Module):
+    def __init__(self, ch: int):
+        super().__init__()
+        self.conv = nn.Conv2d(ch, ch, kernel_size=3, stride=2, padding=0)
+
+    def forward(self, x: Tensor) -> Tensor:
+        return self.conv(F.pad(x, (0, 1, 0, 1), mode="constant", value=0))
+
+
+class Upsample(nn.Module):
+    def __init__(self, ch: int):
+        super().__init__()
+        self.conv = nn.Conv2d(ch, ch, kernel_size=3, stride=1, padding=1)
+
+    def forward(self, x: Tensor) -> Tensor:
+        return self.conv(F.interpolate(x, scale_factor=2.0, mode="nearest"))
+
+
+def diagonal_gaussian(z: Tensor) -> Tensor:
+    mean, logvar = torch.chunk(z, 2, dim=1)
+    std = torch.exp(0.5 * logvar)
+    return mean + std * torch.randn_like(mean)
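The encoder's `end_block` emits `2 * ch_z` channels, which `diagonal_gaussian` splits into a mean and a log-variance before drawing one reparameterized sample, `z = mean + exp(0.5 * logvar) * eps` with `eps ~ N(0, I)`. A tiny standalone check of that shape contract (hypothetical tensor sizes, `ch_z = 4`):

```python
# Standalone restatement of diagonal_gaussian's shape contract.
import torch

h = torch.randn(2, 8, 32, 32)  # [bsz, 2 * ch_z, res, res] from the encoder's end block
mean, logvar = torch.chunk(h, 2, dim=1)
eps = torch.randn_like(mean)
z = mean + torch.exp(0.5 * logvar) * eps
print(z.shape)  # torch.Size([2, 4, 32, 32]) -- matches diagonal_gaussian's output
```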
diff --git a/torchtune/models/flux/_convert_weights.py b/torchtune/models/flux/_convert_weights.py
new file mode 100644
index 0000000000..0277769ef6
--- /dev/null
+++ b/torchtune/models/flux/_convert_weights.py
@@ -0,0 +1,137 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+import re
+from typing import List
+
+REGEX_CONVERSIONS = [
+    (r"^(encoder|decoder)\.norm_out\.(weight|bias)$", r"\1.end.0.\2"),
+    (r"^(encoder|decoder)\.conv_out\.(weight|bias)$", r"\1.end.2.\2"),
+]
+
+REGEX_UNCHANGED = [r"^(encoder|decoder)\.conv_in\.(weight|bias)$"]
+
+RESNET_LAYER_CONVERSION = {
+    "norm1": "main.0",
+    "conv1": "main.2",
+    "norm2": "main.3",
+    "conv2": "main.5",
+    "nin_shortcut": "shortcut",
+}
+
+ATTN_LAYER_CONVERSION = {
+    "q": "attn.q_proj",
+    "k": "attn.k_proj",
+    "v": "attn.v_proj",
+    "proj_out": "attn.output_proj",
+    "norm": "norm",
+}
+
+
+def flux_ae_hf_to_tune(state_dict: dict) -> dict:
+    new_state_dict = {}
+    for key, value in state_dict.items():
+        new_key = _convert_key(key)
+        if "proj" in new_key:
+            value = value.squeeze()
+        new_state_dict[new_key] = value
+    return new_state_dict
+
+
+class ConversionError(Exception):
+    pass
+
+
+def _convert_key(key: str) -> str:
+    # check if we should leave this key unchanged
+    for pattern in REGEX_UNCHANGED:
+        if re.match(pattern, key):
+            return key
+
+    # check if we can do a simple regex conversion
+    for pattern, replacement in REGEX_CONVERSIONS:
+        if re.match(pattern, key):
+            return re.sub(pattern, replacement, key)
+
+    # build the new key part-by-part
+    parts = key.split(".")
+    new_parts = []
+    i = 0
+
+    # add the first encoder/decoder model part unchanged
+    model = parts[i]
+    assert model in ["encoder", "decoder"]
+    new_parts.append(model)
+    i += 1
+
+    # add the next mid/down/up section part unchanged
+    section = parts[i]
+    new_parts.append(section)
+    i += 1
+
+    # convert mid section keys
+    if section == "mid":
+        layer = parts[i]
+        i += 1
+        for layer_idx, layer_name in enumerate(["block_1", "attn_1", "block_2"]):
+            if layer == layer_name:
+                new_parts.append(str(layer_idx))
+                if layer_name.startswith("attn"):
+                    _convert_attn_layer(new_parts, parts, i)
+                else:
+                    _convert_resnet_layer(new_parts, parts, i)
+                break
+        else:
+            raise ConversionError(key)
+
+    # convert down section keys
+    elif section == "down":
+        new_parts.append(parts[i])  # add the down block idx
+        i += 1
+        if parts[i] == "block":
+            new_parts.append("layers")
+            i += 1
+            new_parts.append(parts[i])  # add the resnet layer idx
+            i += 1
+            _convert_resnet_layer(new_parts, parts, i)
+        elif parts[i] == "downsample":
+            new_parts.extend(parts[i:])  # the downsampling layer is left unchanged
+        else:
+            raise ConversionError(key)
+
+    # convert up section keys
+    elif section == "up":
+        # the first part in the "up" section is the block idx: one of [0, 1, 2, 3]
+        # up blocks are in reverse order in the original state dict
+        # so we need to convert [0, 1, 2, 3] -> [3, 2, 1, 0]
+        new_parts.append(str(3 - int(parts[i])))
+        i += 1
+        if parts[i] == "block":
+            new_parts.append("layers")
+            i += 1
+            new_parts.append(parts[i])  # add the resnet layer idx
+            i += 1
+            _convert_resnet_layer(new_parts, parts, i)
+        elif parts[i] == "upsample":
+            new_parts.extend(parts[i:])  # the upsampling layer is left unchanged
+        else:
+            raise ConversionError(key)
+
+    else:
+        raise ConversionError(f"unknown section: {key}")
+
+    return ".".join(new_parts)
+
+
+def _convert_attn_layer(new_parts: List[str], parts: List[str], i: int):
+    new_parts.append(ATTN_LAYER_CONVERSION[parts[i]])
+    i += 1
+    new_parts.append(parts[i])
+
+
+def _convert_resnet_layer(new_parts: List[str], parts: List[str], i: int):
+    new_parts.append(RESNET_LAYER_CONVERSION[parts[i]])
+    i += 1
+    new_parts.append(parts[i])
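A few representative key conversions, useful for sanity-checking `_convert_key`. The source key names here are assumed to follow the upstream Flux autoencoder naming (`norm1`/`conv1`, `attn_1`, reverse-ordered `up` blocks); they are illustrative, not taken from a real checkpoint:

```python
# Illustrative spot-checks of the key mapping (values omitted).
from torchtune.models.flux._convert_weights import _convert_key

for key in [
    "encoder.conv_in.weight",               # left unchanged
    "encoder.norm_out.weight",              # -> encoder.end.0.weight
    "encoder.down.0.block.1.conv1.weight",  # -> encoder.down.0.layers.1.main.2.weight
    "decoder.up.0.block.0.norm1.bias",      # up idx reversed -> decoder.up.3.layers.0.main.0.bias
    "encoder.mid.attn_1.q.weight",          # -> encoder.mid.1.attn.q_proj.weight (then squeezed)
]:
    print(key, "->", _convert_key(key))
```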
layer_name.startswith("attn"): + _convert_attn_layer(new_parts, parts, i) + else: + _convert_resnet_layer(new_parts, parts, i) + break + else: + raise ConversionError(key) + + # convert down section keys + elif section == "down": + new_parts.append(parts[i]) # add the down block idx + i += 1 + if parts[i] == "block": + new_parts.append("layers") + i += 1 + new_parts.append(parts[i]) # add the resnet layer idx + i += 1 + _convert_resnet_layer(new_parts, parts, i) + elif parts[i] == "downsample": + new_parts.extend(parts[i:]) # the downsampling layer is left unchanged + else: + raise ConversionError(key) + + # convert up section keys + elif section == "up": + # the first part in the "up" section is the block idx: one of [0, 1, 2, 3] + # up blocks are in reverse order in the original state dict + # so we need to convert [0, 1, 2, 3] -> [3, 2, 1, 0] + new_parts.append(str(3 - int(parts[i]))) + i += 1 + if parts[i] == "block": + new_parts.append("layers") + i += 1 + new_parts.append(parts[i]) # add the resnet layer idx + i += 1 + _convert_resnet_layer(new_parts, parts, i) + elif parts[i] == "upsample": + new_parts.extend(parts[i:]) # the upsampling layer is left unchanged + else: + raise ConversionError(key) + + else: + raise ConversionError("unknown section:", key) + + return ".".join(new_parts) + + +def _convert_attn_layer(new_parts: List[str], parts: List[str], i: int): + new_parts.append(ATTN_LAYER_CONVERSION[parts[i]]) + i += 1 + new_parts.append(parts[i]) + + +def _convert_resnet_layer(new_parts: List[str], parts: List[str], i: int): + new_parts.append(RESNET_LAYER_CONVERSION[parts[i]]) + i += 1 + new_parts.append(parts[i]) diff --git a/torchtune/models/flux/_model_builders.py b/torchtune/models/flux/_model_builders.py new file mode 100644 index 0000000000..84ed08c060 --- /dev/null +++ b/torchtune/models/flux/_model_builders.py @@ -0,0 +1,73 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +from typing import List + +from torchtune.models.flux._autoencoder import FluxAutoencoder, FluxDecoder, FluxEncoder + + +def flux_1_autoencoder( + resolution: int = 256, + ch_in: int = 3, + ch_out: int = 3, + ch_base: int = 128, + ch_mults: List[int] = [1, 2, 4, 4], + ch_z: int = 16, + n_layers_per_resample_block: int = 2, + scale_factor: float = 0.3611, + shift_factor: float = 0.1159, +) -> FluxAutoencoder: + """ + The image autoencoder for all current Flux diffusion models: + - FLUX.1-dev + - FLUX.1-schnell + - FLUX.1-Canny-dev + - FLUX.1-Depth-dev + - FLUX.1-Fill-dev + + ch = number of channels (size of the channel dimension) + + Args: + resolution (int): The height/width of the square input image. + ch_in (int): The number of channels of the input image. + ch_out (int): The number of channels of the output image. + ch_base (int): The base number of channels. + This gets multiplied by `ch_mult` values to get the number of inner channels during downsampling/upsampling. + ch_mults (List[int]): The channel multiple per downsample/upsample block. + This gets multiplied by `ch_base` to get the number of inner channels during downsampling/upsampling. + ch_z (int): The number of latent channels (dimension of the latent vector `z`). + n_layers_per_resample_block (int): Number of resnet layers per downsample/upsample block. + scale_factor (float): Constant for scaling `z`. + shift_factor (float): Constant for shifting `z`. 
+
+    Returns:
+        FluxAutoencoder
+    """
+    channels = [ch_base * mult for mult in ch_mults]
+
+    encoder = FluxEncoder(
+        ch_in=ch_in,
+        ch_z=ch_z,
+        channels=channels,
+        n_layers_per_down_block=n_layers_per_resample_block,
+        scale_factor=scale_factor,
+        shift_factor=shift_factor,
+    )
+
+    decoder = FluxDecoder(
+        ch_out=ch_out,
+        ch_z=ch_z,
+        channels=list(reversed(channels)),
+        # decoder gets one more layer per up block than the encoder's down blocks
+        n_layers_per_up_block=n_layers_per_resample_block + 1,
+        scale_factor=scale_factor,
+        shift_factor=shift_factor,
+    )
+
+    return FluxAutoencoder(
+        img_shape=(ch_in, resolution, resolution),
+        encoder=encoder,
+        decoder=decoder,
+    )
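Putting the builder and the converter together, loading pretrained weights would look roughly like the sketch below. The checkpoint path is hypothetical and `safetensors` is assumed to be available; note that `flux_ae_hf_to_tune` also squeezes the 1x1-conv attention projection weights down to the 2-D shape the `nn.Linear` projections expect:

```python
# Sketch of wiring the converter to the builder (hypothetical checkpoint path).
from safetensors.torch import load_file

from torchtune.models.flux import flux_1_autoencoder
from torchtune.models.flux._convert_weights import flux_ae_hf_to_tune

ae = flux_1_autoencoder()  # default FLUX.1 config: channels = [128, 256, 512, 512]
hf_state_dict = load_file("ae.safetensors")  # hypothetical local path
ae.load_state_dict(flux_ae_hf_to_tune(hf_state_dict))
```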