Hybrid diffusion #810
Draft · wants to merge 1 commit into base: main
3 changes: 2 additions & 1 deletion src/concrete/ml/common/utils.py
@@ -535,9 +535,10 @@ def to_tuple(x: Any) -> tuple:
tuple: The input as a tuple.
"""
# If the input is a list, convert it element-wise; otherwise, wrap any non-tuple value in a single-element tuple
if isinstance(x, list):
return tuple(x)
if not isinstance(x, tuple):
return (x,)

return x


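For reference, the updated helper converts a list element-wise instead of wrapping it whole. Below is a minimal standalone sketch of the intended behavior (not the library code itself), with a few usage checks:

from typing import Any


def to_tuple(x: Any) -> tuple:
    """Convert the input into a tuple (sketch of the updated behavior)."""
    # A list is converted element-wise rather than wrapped as a whole
    if isinstance(x, list):
        return tuple(x)
    # Any other non-tuple value becomes a single-element tuple
    if not isinstance(x, tuple):
        return (x,)
    return x


assert to_tuple([1, 2]) == (1, 2)   # list -> element-wise conversion
assert to_tuple(3) == (3,)          # scalar -> single-element tuple
assert to_tuple((1, 2)) == (1, 2)   # tuple -> returned unchanged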
13 changes: 11 additions & 2 deletions src/concrete/ml/onnx/convert.py
@@ -86,7 +86,11 @@ def fuse_matmul_bias_to_gemm(onnx_model: onnx.ModelProto):
# Create a GEMM node which combines the MatMul and Add operations
gemm_node = helper.make_node(
"Gemm", # op_type
[matmul_node.input[0], matmul_node.input[1], bias_other_input_node_name], # inputs
[
matmul_node.input[0],
matmul_node.input[1],
bias_other_input_node_name,
], # inputs
[add_node.output[0]], # outputs
name="Gemm_Node",
alpha=1.0,
@@ -149,9 +153,14 @@ def get_equivalent_numpy_forward_from_torch(

arguments = list(inspect.signature(torch_module.forward).parameters)

if isinstance(dummy_input, torch.Tensor):
dummy_input = dummy_input.to("cpu")
else:
dummy_input = tuple(elt.to("cpu") for elt in dummy_input)

# Export to ONNX
torch.onnx.export(
torch_module,
torch_module.to("cpu"),
dummy_input,
str(output_onnx_file_path),
opset_version=OPSET_VERSION_FOR_ONNX_EXPORT,
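The change above moves the dummy input(s) to CPU before export, since torch.onnx.export traces the module with the tensors it receives and the exported graph should not depend on the caller's device. A minimal sketch of the same normalization, assuming a generic torch.nn.Module and an arbitrary output path (illustrative only; the opset and extra export arguments are omitted):

import torch


def export_on_cpu(torch_module: torch.nn.Module, dummy_input, output_path: str):
    """Move the module and its dummy input(s) to CPU, then export to ONNX (sketch)."""
    # Accept either a single tensor or a tuple of tensors
    if isinstance(dummy_input, torch.Tensor):
        dummy_input = dummy_input.to("cpu")
    else:
        dummy_input = tuple(elt.to("cpu") for elt in dummy_input)

    # Trace on CPU so the resulting ONNX graph is device-independent
    torch.onnx.export(torch_module.to("cpu"), dummy_input, output_path)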
70 changes: 55 additions & 15 deletions src/concrete/ml/torch/hybrid_model.py
@@ -192,7 +192,8 @@ def init_fhe_client(
file.write(client_response.content)
# Create the client
client = FHEModelClient(
path_dir=str(path_to_client.resolve()), key_dir=str(self.path_to_keys.resolve())
path_dir=str(path_to_client.resolve()),
key_dir=str(self.path_to_keys.resolve()),
)
# The client first needs to create the private and evaluation keys.
serialized_evaluation_keys = client.get_serialized_evaluation_keys()
@@ -218,7 +219,7 @@ def init_fhe_client(
# towards client lazy loading with caching as done on the server.
self.clients[shape] = (uid, client)

def forward(self, x: torch.Tensor) -> Union[torch.Tensor, QuantTensor]:
def forward(self, *x: torch.Tensor) -> Union[torch.Tensor, QuantTensor]:
"""Forward pass of the remote module.

To change the behavior of this forward function one must change the fhe_local_mode
@@ -242,6 +243,10 @@ def forward(self, *x: torch.Tensor) -> Union[torch.Tensor, QuantTensor]:
# - simulate: compiled simulation
# - calibrate: calibration

# All inputs are expected to live on the same device
devices = [elt.device for elt in x]
device = devices[0]
assert all(elt.device == device for elt in x)

if self.fhe_local_mode not in {
HybridFHEMode.CALIBRATE,
HybridFHEMode.REMOTE,
@@ -250,24 +255,32 @@ def forward(self, *x: torch.Tensor) -> Union[torch.Tensor, QuantTensor]:
}:
# Using quantized module
assert self.private_q_module is not None
y = torch.Tensor(
self.private_q_module.forward(x.detach().numpy(), fhe=self.fhe_local_mode.value)
out = self.private_q_module.forward(
*(elt.to("cpu").detach().numpy() for elt in x),
fhe=self.fhe_local_mode.value,
)

# TODO: support multi-output
y = torch.tensor(
out,
device=device,
dtype=torch.float64 if device == "cpu" else torch.float32,
)

elif self.fhe_local_mode == HybridFHEMode.CALIBRATE:
# Calling torch + gathering calibration data
assert self.private_module is not None
self.calibration_data.append(x.detach())
y = self.private_module(x)
self.calibration_data.append(tuple(elt.to("cpu").detach() for elt in x))
y = self.private_module(*x).to(device)
assert isinstance(y, (QuantTensor, torch.Tensor))

elif self.fhe_local_mode == HybridFHEMode.REMOTE: # pragma:no cover
# Remote call
y = self.remote_call(x)
y = self.remote_call(*x).to(device)
elif self.fhe_local_mode == HybridFHEMode.TORCH:
# Using torch layers
assert self.private_module is not None
y = self.private_module(x)
y = self.private_module(*x).to(device)
else: # pragma:no cover
# Shouldn't happen
raise ValueError(f"{self.fhe_local_mode} is not recognized")
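
Taken together, the variadic forward and the per-mode dispatch let a hybrid model with a multi-input private sub-module be used the same way in every mode. A hedged usage sketch (the TwoInputBlock/Wrapper classes, shapes, and bit-widths below are made up for illustration; the HybridFHEModel calls mirror the test added in this PR):

import torch
from torch import nn
from concrete.ml.torch.hybrid_model import HybridFHEModel


class TwoInputBlock(nn.Module):
    """Hypothetical sub-module taking two tensors."""

    def __init__(self, dim: int):
        super().__init__()
        self.fc = nn.Linear(dim, dim)

    def forward(self, x, y):
        return self.fc(x + y)


class Wrapper(nn.Module):
    """Hypothetical top-level model whose 'block' is outsourced to FHE."""

    def __init__(self, dim: int):
        super().__init__()
        self.block = TwoInputBlock(dim)

    def forward(self, x, y):
        return torch.relu(self.block(x, y))


x, y = torch.randn(16, 32), torch.randn(16, 32)
hybrid = HybridFHEModel(Wrapper(32), module_names="block")

# Calibrate on representative inputs, then compile the remote part
hybrid.compile_model(x, y, n_bits=6, rounding_threshold_bits=6)

# "disable" runs the quantized sub-module in the clear;
# "simulate" runs the compiled circuit in simulation
out_clear = hybrid(x, y, fhe="disable")
out_sim = hybrid(x, y, fhe="simulate")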
@@ -371,6 +384,11 @@ def __init__(
self.verbose = verbose
self._replace_modules()

def __getattr__(self, name: str):
# Fall back to the wrapped model for attributes not found on the wrapper itself
if name in self.__dict__:
return self.__dict__[name]
return getattr(self.model, name)

def _replace_modules(self):
"""Replace the private modules in the model with remote layers."""

@@ -399,7 +417,7 @@ def _replace_modules(self):
)
setattr(parent_module, last, remote_module)

def __call__(self, x: torch.Tensor, fhe: str = "disable") -> torch.Tensor:
def __call__(self, *x: torch.Tensor, fhe: str = "disable") -> torch.Tensor:
"""Call method to run the model locally with a fhe mode.

Args:
@@ -410,8 +428,8 @@ def __call__(self, *x: torch.Tensor, fhe: str = "disable") -> torch.Tensor:
(torch.Tensor): The output tensor.
"""
self.set_fhe_mode(fhe)
x = self.model(x)
return x
y = self.model(*x)
return y

@staticmethod
def _get_module_by_name(model: nn.Module, name: str) -> Union[RemoteModule, nn.Module]:
@@ -434,7 +452,9 @@ def _get_module_by_name(model: nn.Module, name: str) -> Union[RemoteModule, nn.Module]:
raise ValueError(f"No module found for name {name} in {list(model.named_modules())}")

def init_client(
self, path_to_clients: Optional[Path] = None, path_to_keys: Optional[Path] = None
self,
path_to_clients: Optional[Path] = None,
path_to_keys: Optional[Path] = None,
): # pragma:no cover
"""Initialize client for all remote modules.

@@ -452,7 +472,7 @@ def init_client(

def compile_model(
self,
x: torch.Tensor,
*x: torch.Tensor,
n_bits: Union[int, Dict[str, int]] = MAX_BITWIDTH_BACKWARD_COMPATIBLE,
rounding_threshold_bits: Optional[int] = None,
p_error: Optional[float] = None,
@@ -473,15 +493,32 @@ def compile_model(
"""
# We do a forward pass where we accumulate inputs to use for compilation
self.set_fhe_mode(HybridFHEMode.CALIBRATE)
self.model(x)

if self.verbose >= 2:
print("Doing first forward pass")

# Note: is there a way to stop execution right after the last remote module?
with torch.no_grad():
self.model(*x)

if self.verbose >= 2:
print("First forward pass done")

self.configuration = configuration

for name in self.module_names:
remote_module = self._get_module_by_name(self.model, name)
assert isinstance(remote_module, RemoteModule)

calibration_data_tensor = torch.cat(remote_module.calibration_data, dim=0)
assert remote_module.calibration_data
num_arg = len(remote_module.calibration_data[0])
assert all(num_arg == len(elt) for elt in remote_module.calibration_data)
calibration_data_tensor: Tuple[torch.Tensor, ...] = tuple(
torch.cat([elt[arg_index] for elt in remote_module.calibration_data], dim=0)
for arg_index in range(num_arg)
)
if self.verbose >= 2:
print(f"Compiling {name}")

if has_any_qnn_layers(self.private_modules[name]):
self.private_q_modules[name] = compile_brevitas_qat_model(
@@ -504,6 +541,9 @@

self.remote_modules[name].private_q_module = self.private_q_modules[name]

if self.verbose >= 2:
print(f"Done compiling {name}")

def _save_fhe_circuit(self, path: Path, via_mlir=False):
"""Private method that saves the FHE circuits.

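One detail of the updated compile_model above: in calibrate mode each forward call now stores a tuple of per-argument tensors, so compilation needs one concatenated tensor per argument rather than a single stacked tensor. A minimal sketch of that regrouping with made-up data:

import torch

# Each calibration forward pass stores one tuple of inputs, e.g. (x_batch, y_batch)
calibration_data = [
    (torch.randn(4, 8), torch.randn(4, 8)),
    (torch.randn(4, 8), torch.randn(4, 8)),
]

num_args = len(calibration_data[0])
assert all(len(entry) == num_args for entry in calibration_data)

# Regroup by argument position: one concatenated tensor per model input
calibration_tensors = tuple(
    torch.cat([entry[arg_index] for entry in calibration_data], dim=0)
    for arg_index in range(num_args)
)

assert calibration_tensors[0].shape == (8, 8)  # two batches of 4 concatenated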
55 changes: 53 additions & 2 deletions tests/torch/test_hybrid_converter.py
@@ -10,7 +10,7 @@
from concrete.fhe import Configuration
from transformers import GPT2LMHeadModel, GPT2Tokenizer

from concrete.ml.pytest.torch_models import PartialQATModel
from concrete.ml.pytest.torch_models import MultiInputNNConfigurable, PartialQATModel
from concrete.ml.torch.hybrid_model import (
HybridFHEModel,
tuple_to_underscore_str,
@@ -185,6 +185,57 @@ def test_gpt2_hybrid_mlp(
)


def test_multi_input_model():
"""Test hybrid multi-input sub-module."""
input_shape = 32
dataset_size = 100

class Model(torch.nn.Module):
"""Test model"""

def __init__(
self,
): # pylint: disable=unused-argument
super().__init__()
self.sub_model = MultiInputNNConfigurable(
use_conv=False, use_qat=False, input_output=input_shape, n_bits=None
)

def forward(self, x, y):
"""Forward method

Arguments:
x (torch.Tensor): first input
y (torch.Tensor): second input

Returns:
torch.Tensor: output of the model
"""
return self.sub_model(x, y)

inputs = (
torch.randn(
(
dataset_size,
input_shape,
)
),
torch.randn(
(
dataset_size,
input_shape,
)
),
)

model = Model()
# Run the test using a single module in FHE
model(*inputs)
assert isinstance(model, torch.nn.Module)
hybrid_model = HybridFHEModel(model, module_names="sub_model")
hybrid_model.compile_model(*inputs, rounding_threshold_bits=6, n_bits=6)


def test_hybrid_brevitas_qat_model():
"""Test GPT2 hybrid."""
n_bits = 3
Expand All @@ -203,7 +254,7 @@ def test_hybrid_brevitas_qat_model():
model(inputs)
assert isinstance(model, torch.nn.Module)
hybrid_model = HybridFHEModel(model, module_names="sub_module")
hybrid_model.compile_model(x=inputs)
hybrid_model.compile_model(*inputs)


# Dependency 'huggingface-hub' raises a 'FutureWarning' from version 0.23.0 when calling the