Skip to content

Commit

Permalink
feat: add multi-argument support in hybrid model
Browse files Browse the repository at this point in the history
This commit adds a small example of running a diffusion model with the
Hybrid FHE model approach.

This needed supporting multi-input hybrid models.

Multi-output isn't supported yet.

Also remote-running hybrid models isn't tested.
  • Loading branch information
fd0r committed Jul 25, 2024
1 parent 23cb7dc commit da36c0b
Show file tree
Hide file tree
Showing 7 changed files with 290 additions and 30 deletions.
1 change: 1 addition & 0 deletions .gitleaksignore
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,4 @@ a99389ee01cbb972e46a892d3d0e9c7f8ee23f59:use_case_examples/training/analyze.ipyn
f41de03048a9ed27946b875e81b34138bb4bb17b:use_case_examples/training/analyze.ipynb:aws-access-token:6404
e2904473898ddd325f245f4faca526a0e9520f49:builders/Dockerfile.zamalang-env:generic-api-key:5
7d5e885816f1f1e432dd94da38c5c8267292056a:docs/advanced_examples/XGBRegressor.ipynb:aws-access-token:1026
25c5e7abaa7382520af3fb7a64266e193b1f6a59:poetry.lock:square-access-token:6401
3 changes: 2 additions & 1 deletion src/concrete/ml/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -535,9 +535,10 @@ def to_tuple(x: Any) -> tuple:
tuple: The input as a tuple.
"""
# If the input is not a tuple, return a tuple of a single element
if isinstance(x, list):
return tuple(x)
if not isinstance(x, tuple):
return (x,)

return x


Expand Down
13 changes: 11 additions & 2 deletions src/concrete/ml/onnx/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,11 @@ def fuse_matmul_bias_to_gemm(onnx_model: onnx.ModelProto):
# Create a GEMM node which combines the MatMul and Add operations
gemm_node = helper.make_node(
"Gemm", # op_type
[matmul_node.input[0], matmul_node.input[1], bias_other_input_node_name], # inputs
[
matmul_node.input[0],
matmul_node.input[1],
bias_other_input_node_name,
], # inputs
[add_node.output[0]], # outputs
name="Gemm_Node",
alpha=1.0,
Expand Down Expand Up @@ -142,9 +146,14 @@ def get_equivalent_numpy_forward_from_torch(

arguments = list(inspect.signature(torch_module.forward).parameters)

if isinstance(dummy_input, torch.Tensor):
dummy_input = dummy_input.to("cpu")
else:
dummy_input = tuple(elt.to("cpu") for elt in dummy_input)

# Export to ONNX
torch.onnx.export(
torch_module,
torch_module.to("cpu"),
dummy_input,
str(output_onnx_file_path),
opset_version=OPSET_VERSION_FOR_ONNX_EXPORT,
Expand Down
70 changes: 55 additions & 15 deletions src/concrete/ml/torch/hybrid_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,8 @@ def init_fhe_client(
file.write(client_response.content)
# Create the client
client = FHEModelClient(
path_dir=str(path_to_client.resolve()), key_dir=str(self.path_to_keys.resolve())
path_dir=str(path_to_client.resolve()),
key_dir=str(self.path_to_keys.resolve()),
)
# The client first need to create the private and evaluation keys.
serialized_evaluation_keys = client.get_serialized_evaluation_keys()
Expand All @@ -218,7 +219,7 @@ def init_fhe_client(
# towards client lazy loading with caching as done on the server.
self.clients[shape] = (uid, client)

def forward(self, x: torch.Tensor) -> Union[torch.Tensor, QuantTensor]:
def forward(self, *x: torch.Tensor) -> Union[torch.Tensor, QuantTensor]:
"""Forward pass of the remote module.
To change the behavior of this forward function one must change the fhe_local_mode
Expand All @@ -242,6 +243,10 @@ def forward(self, x: torch.Tensor) -> Union[torch.Tensor, QuantTensor]:
# - simulate: compiled simulation
# - calibrate: calibration

devices = [elt.device for elt in x]
device = devices[0]
assert all(elt.device == device for elt in x)

if self.fhe_local_mode not in {
HybridFHEMode.CALIBRATE,
HybridFHEMode.REMOTE,
Expand All @@ -250,24 +255,32 @@ def forward(self, x: torch.Tensor) -> Union[torch.Tensor, QuantTensor]:
}:
# Using quantized module
assert self.private_q_module is not None
y = torch.Tensor(
self.private_q_module.forward(x.detach().numpy(), fhe=self.fhe_local_mode.value)
out = self.private_q_module.forward(
*(elt.to("cpu").detach().numpy() for elt in x),
fhe=self.fhe_local_mode.value,
)

# TODO: support multi-output
y = torch.tensor(
out,
device=device,
dtype=torch.float64 if device == "cpu" else torch.float32,
)

elif self.fhe_local_mode == HybridFHEMode.CALIBRATE:
# Calling torch + gathering calibration data
assert self.private_module is not None
self.calibration_data.append(x.detach())
y = self.private_module(x)
self.calibration_data.append(tuple(elt.to("cpu").detach() for elt in x))
y = self.private_module(*x).to(device)
assert isinstance(y, (QuantTensor, torch.Tensor))

elif self.fhe_local_mode == HybridFHEMode.REMOTE: # pragma:no cover
# Remote call
y = self.remote_call(x)
y = self.remote_call(*x).to(device)
elif self.fhe_local_mode == HybridFHEMode.TORCH:
# Using torch layers
assert self.private_module is not None
y = self.private_module(x)
y = self.private_module(*x).to(device)
else: # pragma:no cover
# Shouldn't happen
raise ValueError(f"{self.fhe_local_mode} is not recognized")
Expand Down Expand Up @@ -371,6 +384,11 @@ def __init__(
self.verbose = verbose
self._replace_modules()

def __getattr__(self, name: str):
if name in self.__dict__:
return self.__dict__[name]
return getattr(self.model, name)

def _replace_modules(self):
"""Replace the private modules in the model with remote layers."""

Expand Down Expand Up @@ -399,7 +417,7 @@ def _replace_modules(self):
)
setattr(parent_module, last, remote_module)

def __call__(self, x: torch.Tensor, fhe: str = "disable") -> torch.Tensor:
def __call__(self, *x: torch.Tensor, fhe: str = "disable") -> torch.Tensor:
"""Call method to run the model locally with a fhe mode.
Args:
Expand All @@ -410,8 +428,8 @@ def __call__(self, x: torch.Tensor, fhe: str = "disable") -> torch.Tensor:
(torch.Tensor): The output tensor.
"""
self.set_fhe_mode(fhe)
x = self.model(x)
return x
y = self.model(*x)
return y

@staticmethod
def _get_module_by_name(model: nn.Module, name: str) -> Union[RemoteModule, nn.Module]:
Expand All @@ -434,7 +452,9 @@ def _get_module_by_name(model: nn.Module, name: str) -> Union[RemoteModule, nn.M
raise ValueError(f"No module found for name {name} in {list(model.named_modules())}")

def init_client(
self, path_to_clients: Optional[Path] = None, path_to_keys: Optional[Path] = None
self,
path_to_clients: Optional[Path] = None,
path_to_keys: Optional[Path] = None,
): # pragma:no cover
"""Initialize client for all remote modules.
Expand All @@ -452,7 +472,7 @@ def init_client(

def compile_model(
self,
x: torch.Tensor,
*x: torch.Tensor,
n_bits: Union[int, Dict[str, int]] = MAX_BITWIDTH_BACKWARD_COMPATIBLE,
rounding_threshold_bits: Optional[int] = None,
p_error: Optional[float] = None,
Expand All @@ -473,15 +493,32 @@ def compile_model(
"""
# We do a forward pass where we accumulate inputs to use for compilation
self.set_fhe_mode(HybridFHEMode.CALIBRATE)
self.model(x)

if self.verbose >= 2:
print("Doing first forward pass")

# Would there be a way to finish the execution after the last remote module
with torch.no_grad():
self.model(*x)

if self.verbose >= 2:
print("First forward pass done")

self.configuration = configuration

for name in self.module_names:
remote_module = self._get_module_by_name(self.model, name)
assert isinstance(remote_module, RemoteModule)

calibration_data_tensor = torch.cat(remote_module.calibration_data, dim=0)
assert remote_module.calibration_data
num_arg = len(remote_module.calibration_data[0])
assert all(num_arg == len(elt) for elt in remote_module.calibration_data)
calibration_data_tensor: Tuple[torch.Tensor, ...] = tuple(
torch.cat([elt[arg_index] for elt in remote_module.calibration_data], dim=0)
for arg_index in range(num_arg)
)
if self.verbose >= 2:
print(f"Compiling {name}")

if has_any_qnn_layers(self.private_modules[name]):
self.private_q_modules[name] = compile_brevitas_qat_model(
Expand All @@ -504,6 +541,9 @@ def compile_model(

self.remote_modules[name].private_q_module = self.private_q_modules[name]

if self.verbose >= 2:
print(f"Done compiling {name}")

def _save_fhe_circuit(self, path: Path, via_mlir=False):
"""Private method that saves the FHE circuits.
Expand Down
61 changes: 58 additions & 3 deletions tests/torch/test_hybrid_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from concrete.fhe import Configuration
from transformers import GPT2LMHeadModel, GPT2Tokenizer

from concrete.ml.pytest.torch_models import PartialQATModel
from concrete.ml.pytest.torch_models import MultiInputNNConfigurable, PartialQATModel
from concrete.ml.torch.hybrid_model import (
HybridFHEModel,
tuple_to_underscore_str,
Expand Down Expand Up @@ -50,7 +50,11 @@ def run_hybrid_llm_test(
# Create a hybrid model
hybrid_model = HybridFHEModel(model, module_names)
hybrid_model.compile_model(
inputs, p_error=0.1, n_bits=9, rounding_threshold_bits=8, configuration=configuration
inputs,
p_error=0.1,
n_bits=9,
rounding_threshold_bits=8,
configuration=configuration,
)

if has_pbs:
Expand Down Expand Up @@ -148,6 +152,57 @@ def test_gpt2_hybrid_mlp(list_or_str_private_modules_names, expected_accuracy, h
)


def test_multi_input_model():
"""Test hybrid multi-input sub-module."""
input_shape = 32
dataset_size = 100

class Model(torch.nn.Module):
"""Test model"""

def __init__(
self,
): # pylint: disable=unused-argument
super().__init__()
self.sub_model = MultiInputNNConfigurable(
use_conv=False, use_qat=False, input_output=input_shape, n_bits=None
)

def forward(self, x, y):
"""Forward method
Arguments:
x (torch.Tensor): first input
y (torch.Tensor): second input
Returns:
torch.Tensor: output of the model
"""
return self.sub_model(x, y)

inputs = (
torch.randn(
(
dataset_size,
input_shape,
)
),
torch.randn(
(
dataset_size,
input_shape,
)
),
)

model = Model()
# Run the test with using a single module in FHE
model(*inputs)
assert isinstance(model, torch.nn.Module)
hybrid_model = HybridFHEModel(model, module_names="sub_model")
hybrid_model.compile_model(*inputs, rounding_threshold_bits=6, n_bits=6)


def test_hybrid_brevitas_qat_model():
"""Test GPT2 hybrid."""
n_bits = 3
Expand All @@ -166,7 +221,7 @@ def test_hybrid_brevitas_qat_model():
model(inputs)
assert isinstance(model, torch.nn.Module)
hybrid_model = HybridFHEModel(model, module_names="sub_module")
hybrid_model.compile_model(x=inputs)
hybrid_model.compile_model(*inputs)


# Dependency 'huggingface-hub' raises a 'FutureWarning' from version 0.23.0 when calling the
Expand Down
Loading

0 comments on commit da36c0b

Please sign in to comment.