diff --git a/test/encoding/test_encodings.py b/test/encoding/test_encodings.py new file mode 100644 index 00000000..ece4bd76 --- /dev/null +++ b/test/encoding/test_encodings.py @@ -0,0 +1,309 @@ +""" +MIT License + +Copyright (c) 2020-present TorchQuantum Authors + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +""" + +from __future__ import annotations +from typing import Callable +import pytest +from pytest import raises +from unittest import mock +from qiskit.circuit import QuantumCircuit +import torch +from torchquantum import ( + QuantumDevice, + GeneralEncoder, + StateEncoder, + PhaseEncoder, + MultiPhaseEncoder, +) +from torchquantum.functional import func_name_dict + + +class TestGeneralEncoder: + """Test class for General Encoder.""" + + @pytest.mark.parametrize("func_list", [None, 1, 2.4, True, list(range(2))]) + def test_invalid_func_list(self, func_list): + with raises( + TypeError, match=r"The input func_list must be of the type list\[dict\]\." + ): + _ = GeneralEncoder(func_list) + + @pytest.mark.parametrize( + "func_list", + [[{"key1": 1}], [{"func": "rx"}], [{"func": "rx", "input_idx": [0]}]], + ) + def test_func_list_keys(self, func_list): + with raises( + ValueError, + match="The dictionary in func_list must is missing func or wires.", + ): + _ = GeneralEncoder(func_list) + + @pytest.mark.parametrize( + "wires, func_list", + [ + (1, [{"input_idx": [0], "func": "ry", "wires": [0]}]), + ( + 2, + [ + {"input_idx": [0], "func": "ry", "wires": [0]}, + {"input_idx": [1], "func": "ry", "wires": [1]}, + ], + ), + ( + 4, + [ + {"input_idx": [0], "func": "rz", "wires": [0]}, + {"input_idx": None, "func": "sx", "wires": [0]}, + {"input_idx": [2], "func": "rx", "wires": [2]}, + {"input_idx": [3], "func": "ry", "wires": [3]}, + ], + ), + ], + ) + def test_general_encoding(self, wires, func_list): + """Tests the GeneralEncoder class.""" + encoder = GeneralEncoder(func_list) + qdev = QuantumDevice(wires) + mock_func = mock.Mock() + for func_dict in func_list: + func = func_dict["func"] + with mock.patch.dict(func_name_dict, {func: mock_func}): + encoder(qdev, torch.rand(1, pow(2, wires))) + assert mock_func.call_count >= 1 + + @pytest.mark.parametrize( + "batch_size, wires, func_list", + [ + (2, 1, [{"input_idx": [0], "func": "rz", "wires": [0]}]), + ( + 4, + 2, + [ + {"input_idx": [0], "func": "ryy", "wires": [0, 1]}, + {"input_idx": [1], "func": "rx", "wires": [1]}, + ], + ), + ( + 2, + 4, + [ + {"input_idx": [0], "func": "rzz", "wires": [0, 2]}, + {"input_idx": [1], "func": "rxx", "wires": [1, 2]}, + {"input_idx": [2], "func": "ry", "wires": [2]}, + {"input_idx": [3], "func": "rzx", "wires": [1, 3]}, + ], + ), + ], + ) + def test_to_qiskit(self, batch_size, wires, func_list): + """Tests conversion of GeneralEncoder to Qiskit.""" + x = torch.rand(batch_size, pow(2, wires)) + encoder = GeneralEncoder(func_list) + qdev = QuantumDevice(n_wires=wires, bsz=batch_size) + encoder(qdev, x) + resulting_circuit = encoder.to_qiskit(wires, x) + for circuit in resulting_circuit: + assert isinstance(circuit, QuantumCircuit) + + @pytest.mark.parametrize( + "batch_size, wires, func_list", + [ + (2, 1, [{"input_idx": [0], "func": "hadamard", "wires": [0]}]), + (2, 2, [{"input_idx": [0], "func": "xx", "wires": [0, 1]}]), + ], + ) + def test_not_implemeted_qiskit(self, batch_size, wires, func_list): + """Tests conversion of GeneralEncoder to Qiskit.""" + x = torch.rand(batch_size, pow(2, wires)) + encoder = GeneralEncoder(func_list) + qdev = QuantumDevice(n_wires=wires, bsz=batch_size) + encoder(qdev, x) + with raises(NotImplementedError, match=r"([\s\S]*?) is not supported yet\."): + _ = encoder.to_qiskit(wires, x) + + +class TestPhaseEncoder: + """Test class for Phase Encoder.""" + + @pytest.mark.parametrize("func", [None, 1, 2.4, {}, True, list(range(2))]) + def test_func_type(self, func): + """Test the type of func input""" + with raises(TypeError, match="The input func must be of the type str."): + _ = PhaseEncoder(func) + + @pytest.mark.parametrize("func", ["hadamard", "ry", "xx", "paulix", "i"]) + def test_phase_encoding(self, func): + """Tests the PhaseEncoder class.""" + assert func in func_name_dict + encoder = PhaseEncoder(func) + qdev = QuantumDevice(2) + with mock.patch.object(encoder, "func") as mock_func: + encoder(qdev, torch.rand(2, 4)) + assert mock_func.call_count >= 1 + + +class TestMultiPhaseEncoder: + """Test class for Multi-phase Encoder.""" + + @pytest.mark.parametrize( + "wires, funcs", + [ + (10, ["rx", "hadamard"]), + (2, ["swap", "ry"]), + (3, ["xx"]), + (1, ["paulix", "i"]), + ], + ) + def test_invalid_func(self, wires, funcs): + with raises(ValueError, match=r"The func (.*?) is not supported\."): + encoder = MultiPhaseEncoder(funcs) + qdev = QuantumDevice(n_wires=wires) + encoder(qdev, torch.rand(1, pow(2, wires))) + + # NOTE: Test with func = u1 currently fails. + @pytest.mark.parametrize( + "batch_size, wires, funcs", + [(2, 5, ["ry", "phaseshift"]), (1, 4, ["u2"]), (3, 1, ["u3"])], + ) + def test_phase_encoding(self, batch_size, wires, funcs): + """Tests the MultiPhaseEncoder class.""" + # wires = 4 + encoder = MultiPhaseEncoder(funcs) + qdev = QuantumDevice(n_wires=wires, bsz=batch_size) + mock_func = mock.Mock() + for func in encoder.funcs: + with mock.patch.dict(func_name_dict, {func: mock_func}): + encoder(qdev, torch.rand(batch_size, pow(2, wires))) + assert mock_func.call_count >= 1 + + +class TestStateEncoder: + """Test class for State Encoder.""" + + @pytest.mark.parametrize( + "qdev", + [{}, list(range(10)), None, 1, True], + ) + def test_qdev(self, qdev): + with raises( + TypeError, + match=r"The qdev input ([\s\S]*?) must be of the type tq\.QuantumDevice\.", + ): + encoder = StateEncoder() + encoder.forward(qdev, torch.rand(2, 2)) + + @pytest.mark.parametrize( + "wires, x", [(2, {}), (4, list(range(10))), (1, None), (10, True), (5, 1)] + ) + def test_type_x(self, wires, x): + with raises( + TypeError, + match=r"The x input ([\s\S]*?) must be of the type torch\.Tensor\.", + ): + qdev = QuantumDevice(wires) + encoder = StateEncoder() + encoder(qdev, x) + + @pytest.mark.parametrize( + "wires, x", + [(2, torch.rand(2, 7)), (4, torch.rand(1, 20)), (1, torch.rand(1, 10))], + ) + def test_size(self, wires, x): + with raises( + ValueError, + match=r"The size of tensors in x \(\d+\) must be less than or " + r"equal to \d+ for a QuantumDevice with " + r"\d+ wires\.", + ): + qdev = QuantumDevice(wires) + encoder = StateEncoder() + encoder(qdev, x) + + @pytest.mark.parametrize( + "wires, x, x_norm", + [ + ( + 2, + [[0.3211], [0.1947]], + [ + [1.0, 0.0, 0.0, 0.0], + [1.0, 0.0, 0.0, 0.0], + ], + ), + ( + 4, + [ + [ + 0.1287, + 0.9234, + 0.4864, + 0.6410, + 0.4804, + 0.9749, + 0.1846, + 0.3128, + 0.0897, + 0.4703, + ] + ], + [ + [ + 0.0736, + 0.5279, + 0.2781, + 0.3665, + 0.2747, + 0.5574, + 0.1056, + 0.1788, + 0.0513, + 0.2689, + 0.0000, + 0.0000, + 0.0000, + 0.0000, + 0.0000, + 0.0000, + ] + ], + ), + (1, [[0.7275, 0.3252]], [[0.9129, 0.4081]]), + ], + ) + def test_state_encoding(self, wires, x, x_norm): + """ + Tests the state encoding performed + by the StateEncoder class. + """ + x, x_norm = torch.tensor(x), torch.tensor(x_norm) + qdev = QuantumDevice(wires) + encoder = StateEncoder() + encoder(qdev, x) + + assert qdev.states.shape[0] == x.shape[0] + assert qdev.states.reshape(x.shape[0], -1).shape == (x.shape[0], pow(2, wires)) + assert torch.allclose( + qdev.states.reshape(x.shape[0], -1), x_norm.type(torch.complex64), atol=1e-3 + ) diff --git a/torchquantum/encoding/encodings.py b/torchquantum/encoding/encodings.py index f8d2056d..7f06a048 100644 --- a/torchquantum/encoding/encodings.py +++ b/torchquantum/encoding/encodings.py @@ -39,13 +39,24 @@ class Encoder(tq.QuantumModule): - forward(qdev: tq.QuantumDevice, x): Performs the encoding using a quantum device. """ + def __init__(self): super().__init__() pass - def forward(self, qdev: tq.QuantumDevice, x): + def forward(self, qdev: tq.QuantumDevice, x: torch.Tensor): raise NotImplementedError + @staticmethod + def validate_inputs(qdev: tq.QuantumDevice, x: torch.Tensor): + if not isinstance(qdev, tq.QuantumDevice): + raise TypeError( + f"The qdev input {qdev} must be of the type tq.QuantumDevice." + ) + + if not isinstance(x, torch.Tensor): + raise TypeError(f"The x input {x} must be of the type torch.Tensor.") + class GeneralEncoder(Encoder, metaclass=ABCMeta): """func_list list of dict @@ -84,10 +95,28 @@ class GeneralEncoder(Encoder, metaclass=ABCMeta): def __init__(self, func_list): super().__init__() + + if not isinstance(func_list, list) or not all( + isinstance(func_dict, dict) for func_dict in func_list + ): + raise TypeError("The input func_list must be of the type list[dict].") + + if any( + ( + "func" not in func_dict + or ("func" in func_dict and "wires" not in func_dict) + ) + for func_dict in func_list + ): + raise ValueError( + "The dictionary in func_list must is missing func or wires." + ) self.func_list = func_list @tq.static_support def forward(self, qdev: tq.QuantumDevice, x): + # Validate inputs + self.validate_inputs(qdev, x) for info in self.func_list: if tq.op_name_dict[info["func"]].num_params > 0: params = x[:, info["input_idx"]] @@ -124,7 +153,7 @@ def to_qiskit(self, n_wires, x): elif info["func"] == "rzx": circ.rzx(x[k][info["input_idx"][0]].item(), *info["wires"]) else: - raise NotImplementedError(info["func"]) + raise NotImplementedError(f"{info['func']} is not supported yet.") circs.append(circ) return circs @@ -132,25 +161,33 @@ def to_qiskit(self, n_wires, x): class PhaseEncoder(Encoder, metaclass=ABCMeta): """PhaseEncoder is a subclass of Encoder and represents a phase encoder. - It applies a specified quantum function to encode input data using a quantum device.""" - def __init__(self, func): + It applies a specified quantum function to encode input data using a quantum device. + """ + + def __init__(self, func: str): super().__init__() - self.func = func + + if not isinstance(func, str): + raise TypeError("The input func must be of the type str.") + self.func = func_name_dict[func] @tq.static_support def forward(self, qdev: tq.QuantumDevice, x): """ - Performs the encoding using a quantum device. + Performs the encoding using a quantum device. - Args: - qdev (tq.QuantumDevice): The quantum device to be used for encoding. - x (torch.Tensor): The input data to be encoded. + Args: + qdev (tq.QuantumDevice): The quantum device to be used for encoding. + x (torch.Tensor): The input data to be encoded. - Returns: - torch.Tensor: The encoded data. + Returns: + torch.Tensor: The encoded data. - """ + """ + # Validate inputs + self.validate_inputs(qdev, x) for k in range(qdev.n_wires): + print("Calling") self.func( qdev, wires=k, @@ -162,7 +199,9 @@ def forward(self, qdev: tq.QuantumDevice, x): class MultiPhaseEncoder(Encoder, metaclass=ABCMeta): """PhaseEncoder is a subclass of Encoder and represents a phase encoder. - It applies a specified quantum function to encode input data using a quantum device.""" + It applies a specified quantum function to encode input data using a quantum device. + """ + def __init__(self, funcs, wires=None): super().__init__() self.funcs = funcs if isinstance(funcs, Iterable) else [funcs] @@ -171,18 +210,21 @@ def __init__(self, funcs, wires=None): @tq.static_support def forward(self, qdev: tq.QuantumDevice, x): """ - Performs the encoding using a quantum device. + Performs the encoding using a quantum device. - Args: - qdev (tq.QuantumDevice): The quantum device to be used for encoding. - x (torch.Tensor): The input data to be encoded. + Args: + qdev (tq.QuantumDevice): The quantum device to be used for encoding. + x (torch.Tensor): The input data to be encoded. - Returns: - torch.Tensor: The encoded data. + Returns: + torch.Tensor: The encoded data. - """ + """ + # Validate inputs + self.validate_inputs(qdev, x) if self.wires is None: - self.wires = list(range(qdev.n_wires)) * (len(self.funcs) // qdev.n_wires) + # self.wires = list(range(qdev.n_wires)) * (len(self.funcs) // qdev.n_wires) + self.wires = list(range(qdev.n_wires + (len(self.funcs) // qdev.n_wires))) x_id = 0 for k, func in enumerate(self.funcs): @@ -193,7 +235,7 @@ def forward(self, qdev: tq.QuantumDevice, x): elif func == "u3": stride = 3 else: - raise ValueError(func) + raise ValueError(f"The func {func} is not supported.") func_name_dict[func]( qdev, @@ -208,30 +250,31 @@ def forward(self, qdev: tq.QuantumDevice, x): class StateEncoder(Encoder, metaclass=ABCMeta): """StateEncoder is a subclass of Encoder and represents a state encoder. It encodes the input data into the state vector of a quantum device.""" + def __init__(self): super().__init__() def forward(self, qdev: tq.QuantumDevice, x): """ - Performs the encoding by preparing the state vector of the quantum device. + Performs the encoding by preparing the state vector of the quantum device. - Args: - qdev (tq.QuantumDevice): The quantum device to be used for encoding. - x (torch.Tensor): The input data to be encoded. - Returns: - torch.Tensor: The encoded data. + Args: + qdev (tq.QuantumDevice): The quantum device to be used for encoding. + x (torch.Tensor): The input data to be encoded. + Returns: + torch.Tensor: The encoded data. - """ + """ + # Validate inputs + self.validate_inputs(qdev, x) + self.validate_tensor_size(qdev, x) # encoder the x to the statevector of the quantum device - # normalize the input x = x / (torch.sqrt((x.abs() ** 2).sum(dim=-1))).unsqueeze(-1) state = torch.cat( ( x, - torch.zeros( - x.shape[0], 2**qdev.n_wires - x.shape[1], device=x.device - ), + torch.zeros(x.shape[0], 2**qdev.n_wires - x.shape[1], device=x.device), ), dim=-1, ) @@ -239,10 +282,20 @@ def forward(self, qdev: tq.QuantumDevice, x): qdev.states = state.type(C_DTYPE) + @staticmethod + def validate_tensor_size(qdev: tq.QuantumDevice, x): + if any(tensor.size()[0] > pow(2, qdev.n_wires) for tensor in x): + raise ValueError( + f"The size of tensors in x ({x.size()[1]}) must be less than or " + f"equal to {pow(2, qdev.n_wires)} for a QuantumDevice with " + f"{qdev.n_wires} wires." + ) + class MagnitudeEncoder(Encoder, metaclass=ABCMeta): """MagnitudeEncoder is a subclass of Encoder and represents a magnitude encoder. It encodes the input data by considering the magnitudes of the elements.""" + def __init__(self): super().__init__()