From 430edf96e03c5d1244e9f1d2b6772424412a74a9 Mon Sep 17 00:00:00 2001
From: Luis Montero
Date: Wed, 19 Jun 2024 15:31:55 +0200
Subject: [PATCH] feat: add multi-argument support in hybrid model

This commit adds a small example of running a diffusion model with the
Hybrid FHE model approach, which required supporting multi-input hybrid
models. Multi-output is not supported yet, and running hybrid models
remotely is not tested.
---
 src/concrete/ml/common/utils.py       |   3 +-
 src/concrete/ml/onnx/convert.py       |  13 +-
 src/concrete/ml/torch/hybrid_model.py |  70 ++++++--
 tests/torch/test_hybrid_converter.py  |  55 +++++-
 .../diffusion_unet_to_hybrid.py       | 157 ++++++++++++++++++
 5 files changed, 278 insertions(+), 20 deletions(-)
 create mode 100644 use_case_examples/hybrid_diffusion/diffusion_unet_to_hybrid.py

diff --git a/src/concrete/ml/common/utils.py b/src/concrete/ml/common/utils.py
index 484676602..5712574ae 100644
--- a/src/concrete/ml/common/utils.py
+++ b/src/concrete/ml/common/utils.py
@@ -535,9 +535,10 @@ def to_tuple(x: Any) -> tuple:
         tuple: The input as a tuple.
     """
     # If the input is not a tuple, return a tuple of a single element
+    if isinstance(x, list):
+        return tuple(x)
     if not isinstance(x, tuple):
         return (x,)
-
     return x
diff --git a/src/concrete/ml/onnx/convert.py b/src/concrete/ml/onnx/convert.py
index 56545e5df..0529cea1c 100644
--- a/src/concrete/ml/onnx/convert.py
+++ b/src/concrete/ml/onnx/convert.py
@@ -86,7 +86,11 @@ def fuse_matmul_bias_to_gemm(onnx_model: onnx.ModelProto):
         # Create a GEMM node which combines the MatMul and Add operations
         gemm_node = helper.make_node(
             "Gemm",  # op_type
-            [matmul_node.input[0], matmul_node.input[1], bias_other_input_node_name],  # inputs
+            [
+                matmul_node.input[0],
+                matmul_node.input[1],
+                bias_other_input_node_name,
+            ],  # inputs
             [add_node.output[0]],  # outputs
             name="Gemm_Node",
             alpha=1.0,
@@ -149,9 +153,14 @@ def get_equivalent_numpy_forward_from_torch(
     arguments = list(inspect.signature(torch_module.forward).parameters)
 
+    if isinstance(dummy_input, torch.Tensor):
+        dummy_input = dummy_input.to("cpu")
+    else:
+        dummy_input = tuple(elt.to("cpu") for elt in dummy_input)
+
     # Export to ONNX
     torch.onnx.export(
-        torch_module,
+        torch_module.to("cpu"),
         dummy_input,
         str(output_onnx_file_path),
         opset_version=OPSET_VERSION_FOR_ONNX_EXPORT,
diff --git a/src/concrete/ml/torch/hybrid_model.py b/src/concrete/ml/torch/hybrid_model.py
index 0da5eadee..715d303fb 100644
--- a/src/concrete/ml/torch/hybrid_model.py
+++ b/src/concrete/ml/torch/hybrid_model.py
@@ -192,7 +192,8 @@ def init_fhe_client(
             file.write(client_response.content)
         # Create the client
         client = FHEModelClient(
-            path_dir=str(path_to_client.resolve()), key_dir=str(self.path_to_keys.resolve())
+            path_dir=str(path_to_client.resolve()),
+            key_dir=str(self.path_to_keys.resolve()),
         )
         # The client first needs to create the private and evaluation keys.
         serialized_evaluation_keys = client.get_serialized_evaluation_keys()
@@ -218,7 +219,7 @@ def init_fhe_client(
         # towards client lazy loading with caching as done on the server.
         self.clients[shape] = (uid, client)
 
-    def forward(self, x: torch.Tensor) -> Union[torch.Tensor, QuantTensor]:
+    def forward(self, *x: torch.Tensor) -> Union[torch.Tensor, QuantTensor]:
         """Forward pass of the remote module.
 
         To change the behavior of this forward function one must change the fhe_local_mode
@@ -242,6 +243,10 @@ def forward(self, x: torch.Tensor) -> Union[torch.Tensor, QuantTensor]:
         # - simulate: compiled simulation
         # - calibrate: calibration
 
+        devices = [elt.device for elt in x]
+        device = devices[0]
+        assert all(elt.device == device for elt in x)
+
         if self.fhe_local_mode not in {
             HybridFHEMode.CALIBRATE,
             HybridFHEMode.REMOTE,
@@ -250,24 +255,32 @@ def forward(self, x: torch.Tensor) -> Union[torch.Tensor, QuantTensor]:
         }:
             # Using quantized module
             assert self.private_q_module is not None
-            y = torch.Tensor(
-                self.private_q_module.forward(x.detach().numpy(), fhe=self.fhe_local_mode.value)
+            out = self.private_q_module.forward(
+                *(elt.to("cpu").detach().numpy() for elt in x),
+                fhe=self.fhe_local_mode.value,
+            )
+
+            # TODO: support multi-output
+            y = torch.tensor(
+                out,
+                device=device,
+                dtype=torch.float64 if device.type == "cpu" else torch.float32,
             )
         elif self.fhe_local_mode == HybridFHEMode.CALIBRATE:
             # Calling torch + gathering calibration data
             assert self.private_module is not None
-            self.calibration_data.append(x.detach())
-            y = self.private_module(x)
+            self.calibration_data.append(tuple(elt.to("cpu").detach() for elt in x))
+            y = self.private_module(*x).to(device)
             assert isinstance(y, (QuantTensor, torch.Tensor))
         elif self.fhe_local_mode == HybridFHEMode.REMOTE:  # pragma:no cover
             # Remote call
-            y = self.remote_call(x)
+            y = self.remote_call(*x).to(device)
         elif self.fhe_local_mode == HybridFHEMode.TORCH:
             # Using torch layers
             assert self.private_module is not None
-            y = self.private_module(x)
+            y = self.private_module(*x).to(device)
         else:  # pragma:no cover
             # Shouldn't happen
             raise ValueError(f"{self.fhe_local_mode} is not recognized")
@@ -371,6 +384,11 @@ def __init__(
         self.verbose = verbose
         self._replace_modules()
 
+    def __getattr__(self, name: str):
+        # Only called when normal attribute lookup fails: delegate to the
+        # wrapped model, guarding against recursion if 'model' is not set yet
+        if "model" not in self.__dict__:
+            raise AttributeError(name)
+        return getattr(self.model, name)
+
     def _replace_modules(self):
         """Replace the private modules in the model with remote layers."""
 
@@ -399,7 +417,7 @@ def _replace_modules(self):
             )
             setattr(parent_module, last, remote_module)
 
-    def __call__(self, x: torch.Tensor, fhe: str = "disable") -> torch.Tensor:
+    def __call__(self, *x: torch.Tensor, fhe: str = "disable") -> torch.Tensor:
         """Call method to run the model locally with an FHE mode.
 
         Args:
@@ -410,8 +428,8 @@ def __call__(self, x: torch.Tensor, fhe: str = "disable") -> torch.Tensor:
             (torch.Tensor): The output tensor.
         """
         self.set_fhe_mode(fhe)
-        x = self.model(x)
-        return x
+        y = self.model(*x)
+        return y
 
     @staticmethod
     def _get_module_by_name(model: nn.Module, name: str) -> Union[RemoteModule, nn.Module]:
@@ -434,7 +452,9 @@ def _get_module_by_name(model: nn.Module, name: str) -> Union[RemoteModule, nn.M
         raise ValueError(f"No module found for name {name} in {list(model.named_modules())}")
 
     def init_client(
-        self, path_to_clients: Optional[Path] = None, path_to_keys: Optional[Path] = None
+        self,
+        path_to_clients: Optional[Path] = None,
+        path_to_keys: Optional[Path] = None,
     ):  # pragma:no cover
         """Initialize client for all remote modules.
@@ -452,7 +472,7 @@ def init_client(
 
     def compile_model(
         self,
-        x: torch.Tensor,
+        *x: torch.Tensor,
         n_bits: Union[int, Dict[str, int]] = MAX_BITWIDTH_BACKWARD_COMPATIBLE,
         rounding_threshold_bits: Optional[int] = None,
         p_error: Optional[float] = None,
@@ -473,7 +493,16 @@ def compile_model(
         """
         # We do a forward pass where we accumulate inputs to use for compilation
        self.set_fhe_mode(HybridFHEMode.CALIBRATE)
-        self.model(x)
+
+        if self.verbose >= 2:
+            print("Doing first forward pass")
+
+        # TODO: find a way to stop execution right after the last remote module
+        with torch.no_grad():
+            self.model(*x)
+
+        if self.verbose >= 2:
+            print("First forward pass done")
 
         self.configuration = configuration
 
@@ -481,7 +510,15 @@ def compile_model(
             remote_module = self._get_module_by_name(self.model, name)
             assert isinstance(remote_module, RemoteModule)
 
-            calibration_data_tensor = torch.cat(remote_module.calibration_data, dim=0)
+            assert remote_module.calibration_data
+            num_arg = len(remote_module.calibration_data[0])
+            assert all(num_arg == len(elt) for elt in remote_module.calibration_data)
+            calibration_data_tensor: Tuple[torch.Tensor, ...] = tuple(
+                torch.cat([elt[arg_index] for elt in remote_module.calibration_data], dim=0)
+                for arg_index in range(num_arg)
+            )
+            if self.verbose >= 2:
+                print(f"Compiling {name}")
 
             if has_any_qnn_layers(self.private_modules[name]):
                 self.private_q_modules[name] = compile_brevitas_qat_model(
@@ -504,6 +541,9 @@ def compile_model(
 
             self.remote_modules[name].private_q_module = self.private_q_modules[name]
 
+            if self.verbose >= 2:
+                print(f"Done compiling {name}")
+
     def _save_fhe_circuit(self, path: Path, via_mlir=False):
         """Private method that saves the FHE circuits.
diff --git a/tests/torch/test_hybrid_converter.py b/tests/torch/test_hybrid_converter.py
index f29a2f323..4aa5501ee 100644
--- a/tests/torch/test_hybrid_converter.py
+++ b/tests/torch/test_hybrid_converter.py
@@ -10,7 +10,7 @@
 from concrete.fhe import Configuration
 from transformers import GPT2LMHeadModel, GPT2Tokenizer
 
-from concrete.ml.pytest.torch_models import PartialQATModel
+from concrete.ml.pytest.torch_models import MultiInputNNConfigurable, PartialQATModel
 from concrete.ml.torch.hybrid_model import (
     HybridFHEModel,
     tuple_to_underscore_str,
@@ -185,6 +185,57 @@ def test_gpt2_hybrid_mlp(
     )
 
 
+def test_multi_input_model():
+    """Test hybrid multi-input sub-module."""
+    input_shape = 32
+    dataset_size = 100
+
+    class Model(torch.nn.Module):
+        """Test model"""
+
+        def __init__(self):
+            super().__init__()
+            self.sub_model = MultiInputNNConfigurable(
+                use_conv=False, use_qat=False, input_output=input_shape, n_bits=None
+            )
+
+        def forward(self, x, y):
+            """Forward method
+
+            Arguments:
+                x (torch.Tensor): first input
+                y (torch.Tensor): second input
+
+            Returns:
+                torch.Tensor: output of the model
+            """
+            return self.sub_model(x, y)
+
+    inputs = (
+        torch.randn((dataset_size, input_shape)),
+        torch.randn((dataset_size, input_shape)),
+    )
+
+    model = Model()
+    # Run the test using a single module in FHE
+    model(*inputs)
+    assert isinstance(model, torch.nn.Module)
+    hybrid_model = HybridFHEModel(model, module_names="sub_model")
+    hybrid_model.compile_model(*inputs, rounding_threshold_bits=6, n_bits=6)
+
+
 def test_hybrid_brevitas_qat_model():
     """Test hybrid Brevitas QAT model."""
     n_bits = 3
@@ -203,7 +254,7 @@ def test_hybrid_brevitas_qat_model():
     model(inputs)
     assert isinstance(model, torch.nn.Module)
     hybrid_model = HybridFHEModel(model, module_names="sub_module")
-    hybrid_model.compile_model(x=inputs)
+    hybrid_model.compile_model(*inputs)
 
 
 # Dependency 'huggingface-hub' raises a 'FutureWarning' from version 0.23.0 when calling the
diff --git a/use_case_examples/hybrid_diffusion/diffusion_unet_to_hybrid.py b/use_case_examples/hybrid_diffusion/diffusion_unet_to_hybrid.py
new file mode 100644
index 000000000..b19d152d6
--- /dev/null
+++ b/use_case_examples/hybrid_diffusion/diffusion_unet_to_hybrid.py
@@ -0,0 +1,157 @@
+"""Hybrid diffusion model in FHE.
+
+Original file is located at
+https://colab.research.google.com/drive/1DFq9AI85hPOmHs2EFR0-gX2YdNpTA5z9
+"""
+
+import random
+
+import numpy as np
+import torch
+from diffusers import DDPMScheduler, UNet2DModel
+from PIL import Image
+from tqdm.auto import tqdm
+
+from concrete.ml.torch.hybrid_model import HybridFHEModel
+
+
+# We need to wrap the model to set default parameters so that
+# only one parameter is encrypted
+class Wrapper(torch.nn.Module):
+    def __init__(self, model, timestep, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.submodule = model
+        # Fall back to the wrapped model's config
+        self.config = model.config
+        self.timestep = timestep
+
+    def forward(self, inputs, timestep, **kwargs):
+        return self.submodule.forward(inputs, timestep=timestep, **kwargs)
+
+
+def seed_everything(seed):
+    random.seed(seed)
+    seed += 1
+    np.random.seed(seed % 2**32)
+    seed += 1
+    torch.manual_seed(seed)
+    seed += 1
+    torch.use_deterministic_algorithms(True)
+    return seed
+
+
+def generate_image(
+    model,
+    scheduler,
+    output: str = "generated.png",
+    timestep: int = 1,
+    seed=None,
+    fhe=None,
+    device=None,
+):
+    if seed is not None:
+        seed_everything(seed)
+
+    sample_size = model.config.sample_size
+    scheduler.set_timesteps(timestep)
+    noise = torch.randn((1, 3, sample_size, sample_size), device=device)
+    input = noise
+
+    for t in tqdm(scheduler.timesteps, total=timestep):
+        with torch.no_grad():
+            kwargs = {}
+            if fhe is not None:
+                kwargs = {"fhe": fhe}
+            noisy_residual = model(input, t, **kwargs).sample
+            prev_noisy_sample = scheduler.step(noisy_residual, t, input).prev_sample
+            input = prev_noisy_sample
+
+    image = (input / 2 + 0.5).clamp(0, 1)
+    image = image.cpu().permute(0, 2, 3, 1).numpy()[0]
+    image = Image.fromarray((image * 255).round().astype("uint8"))
+
+    with open(output, "wb") as file:
+        image.save(file)
+
+    return image
+
+
+if __name__ == "__main__":
+    # TODO: handle multi-device execution in HybridFHEModel;
+    # for now the FHE computation itself runs on CPU
+    device_str = "cpu"  # default
+    if torch.backends.mps.is_available():
+        device_str = "mps"
+    if torch.cuda.is_available():
+        device_str = "cuda"
+    device = torch.device(device_str)
+
+    # Create objects
+    scheduler = DDPMScheduler.from_pretrained("google/ddpm-cat-256")
+    model = UNet2DModel.from_pretrained("google/ddpm-cat-256")
+    model.eval()
+    model = model.to(device)
+    sample_size = model.config.sample_size
+    timestep = 100
+
+    wrapped_model = Wrapper(model, timestep=timestep)
+    wrapped_model.to(device)
+    submodule_names = [
+        layer_name
+        for (layer_name, layer) in wrapped_model.named_modules()
+        if isinstance(layer, (torch.nn.Linear, torch.nn.Conv2d))
+    ]
+    print(f"Found {len(submodule_names)} linear and convolutional submodules")
+
+    # Random selection of a submodule -> we could try with all submodules
+    index = 25
+    print(submodule_names[index])
+    fhe_submodules = [submodule_names[index]]
+
+    hybrid_model = HybridFHEModel(
+        wrapped_model,
+        fhe_submodules,
+        verbose=2,
+    )
+
+    # Compile the hybrid model on a small batch of random inputs
+    compile_size = 3
+    inputs = torch.randn((compile_size, 3, sample_size, sample_size), device=device)
+    timesteps = torch.arange(0, compile_size, device=device)
+    print("Compiling model")
+    hybrid_model.compile_model(
+        inputs,
+        timesteps,
+        n_bits=8,
+    )
+
+    print(hybrid_model)
+
+    # Generate an image with the hybrid model in FHE simulation
+    generate_image(
+        hybrid_model,
+        scheduler,
+        output="hybrid.png",
+        timestep=timestep,
+        fhe="simulate",
+        device=device,
+    )
+    # Generate reference images with the original and the wrapped torch models
+    generate_image(
+        model,
+        scheduler,
+        output="debug.png",
+        timestep=timestep,
+        device=device,
+    )
+    generate_image(
+        wrapped_model,
+        scheduler,
+        timestep=timestep,
+        output="torch.png",
+        device=device,
+    )
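
For reviewers who want to try the new multi-argument API outside the diffusion example, below is a minimal, self-contained sketch. The TwoInputSub and TopModel modules, their sizes, and the quantization settings are illustrative assumptions, not part of the changeset; only HybridFHEModel, compile_model and the fhe keyword come from this patch.

import torch

from concrete.ml.torch.hybrid_model import HybridFHEModel


class TwoInputSub(torch.nn.Module):
    """Illustrative sub-module that takes two tensors and will run in FHE."""

    def __init__(self, dim: int):
        super().__init__()
        self.fc_x = torch.nn.Linear(dim, dim)
        self.fc_y = torch.nn.Linear(dim, dim)

    def forward(self, x, y):
        # Linear layers and an addition, which compile to a quantized module
        return self.fc_x(x) + self.fc_y(y)


class TopModel(torch.nn.Module):
    """Illustrative clear-text model that delegates to the FHE sub-module."""

    def __init__(self, dim: int):
        super().__init__()
        self.sub_module = TwoInputSub(dim)

    def forward(self, x, y):
        return self.sub_module(x, y)


if __name__ == "__main__":
    dim = 32
    model = TopModel(dim)
    x, y = torch.randn(100, dim), torch.randn(100, dim)

    # Inputs are now passed positionally to both compile_model and __call__
    hybrid = HybridFHEModel(model, module_names="sub_module")
    hybrid.compile_model(x, y, rounding_threshold_bits=6, n_bits=8)
    out = hybrid(x, y, fhe="simulate")
    print(out.shape)

As in test_multi_input_model, module_names selects the sub-module that runs in FHE; all of its inputs are passed positionally, and the rest of the model keeps running in clear torch.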