From 88b2eaa10e04d2fd8697651260f20f441252016f Mon Sep 17 00:00:00 2001 From: SaashaJoshi Date: Sun, 2 Jun 2024 21:15:41 -0700 Subject: [PATCH 1/8] add unit tests for encoding module --- test/encoding/test_encodings.py | 139 +++++++++++++++++++++++++++++ torchquantum/encoding/encodings.py | 18 +++- 2 files changed, 155 insertions(+), 2 deletions(-) create mode 100644 test/encoding/test_encodings.py diff --git a/test/encoding/test_encodings.py b/test/encoding/test_encodings.py new file mode 100644 index 00000000..74f4d61d --- /dev/null +++ b/test/encoding/test_encodings.py @@ -0,0 +1,139 @@ +""" +MIT License + +Copyright (c) 2020-present TorchQuantum Authors + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +""" + +from __future__ import annotations +import pytest +from pytest import raises +import torch +from torchquantum import QuantumDevice +from torchquantum.encoding import StateEncoder + + +class TestStateEncoding: + """Test class for State Encoder.""" + + @pytest.mark.parametrize( + "qdev", + [{}, list(range(10)), None, 1, True], + ) + def test_qdev(self, qdev): + with raises( + TypeError, + match=r"The qdev input ([\s\S]*?) must be of the type tq\.QuantumDevice\.", + ): + encoder = StateEncoder() + encoder(qdev, torch.rand(2, 2)) + + @pytest.mark.parametrize( + "wires, x", + [(2, {}), (4, list(range(10))), (1, None), (10, True), (5, 1)] + ) + def test_type_x(self, wires, x): + with raises( + TypeError, + match=r"The x input ([\s\S]*?) must be of the type torch\.Tensor\.", + ): + qdev = QuantumDevice(wires) + encoder = StateEncoder() + encoder(qdev, x) + + @pytest.mark.parametrize( + "wires, x", + [(2, torch.rand(2, 7)), (4, torch.rand(1, 20)), (1, torch.rand(1, 10))], + ) + def test_size(self, wires, x): + with raises( + ValueError, + match=r"The size of tensors in x \(\d+\) must be less than or " + r"equal to \d+ for a QuantumDevice with " + r"\d+ wires\.", + ): + qdev = QuantumDevice(wires) + encoder = StateEncoder() + encoder(qdev, x) + + @pytest.mark.parametrize( + "wires, x, x_norm", + [ + ( + 2, + [[0.3211], [0.1947]], + [ + [1.0, 0.0, 0.0, 0.0], + [1.0, 0.0, 0.0, 0.0], + ], + ), + ( + 4, + [ + [ + 0.1287, + 0.9234, + 0.4864, + 0.6410, + 0.4804, + 0.9749, + 0.1846, + 0.3128, + 0.0897, + 0.4703, + ] + ], + [ + [ + 0.0736, + 0.5279, + 0.2781, + 0.3665, + 0.2747, + 0.5574, + 0.1056, + 0.1788, + 0.0513, + 0.2689, + 0.0000, + 0.0000, + 0.0000, + 0.0000, + 0.0000, + 0.0000, + ] + ], + ), + (1, [[0.7275, 0.3252]], [[0.9129, 0.4081]]), + ], + ) + def test_state_encoding(self, wires, x, x_norm): + """ + Tests the state encoding performed + by the StateEncoder class. + """ + x, x_norm = torch.tensor(x), torch.tensor(x_norm) + qdev = QuantumDevice(wires) + encoder = StateEncoder() + encoder(qdev, x) + + assert qdev.states.shape[0] == x.shape[0] + assert qdev.states.reshape(x.shape[0], -1).shape == (x.shape[0], pow(2, wires)) + assert torch.allclose(qdev.states.reshape(x.shape[0], -1), x_norm.type(torch.complex64), atol=1e-3) diff --git a/torchquantum/encoding/encodings.py b/torchquantum/encoding/encodings.py index f8d2056d..123ac343 100644 --- a/torchquantum/encoding/encodings.py +++ b/torchquantum/encoding/encodings.py @@ -43,9 +43,22 @@ def __init__(self): super().__init__() pass - def forward(self, qdev: tq.QuantumDevice, x): + def forward(self, qdev: tq.QuantumDevice, x: torch.Tensor): raise NotImplementedError + @staticmethod + def validate_inputs(qdev: tq.QuantumDevice, x: torch.Tensor): + if not isinstance(qdev, tq.QuantumDevice): + raise TypeError(f"The qdev input {qdev} must be of the type tq.QuantumDevice.") + + if not isinstance(x, torch.Tensor): + raise TypeError(f"The x input {x} must be of the type torch.Tensor.") + + if any(tensor.size()[0] > pow(2, qdev.n_wires) for tensor in x): + raise ValueError(f"The size of tensors in x ({x.size()[1]}) must be less than or " + f"equal to {pow(2, qdev.n_wires)} for a QuantumDevice with " + f"{qdev.n_wires} wires.") + class GeneralEncoder(Encoder, metaclass=ABCMeta): """func_list list of dict @@ -222,8 +235,9 @@ def forward(self, qdev: tq.QuantumDevice, x): torch.Tensor: The encoded data. """ + # Validate inputs + self.validate_inputs(qdev, x) # encoder the x to the statevector of the quantum device - # normalize the input x = x / (torch.sqrt((x.abs() ** 2).sum(dim=-1))).unsqueeze(-1) state = torch.cat( From b99d198b78ae5f8b02c13e4063e509df387dc7b7 Mon Sep 17 00:00:00 2001 From: SaashaJoshi Date: Sun, 2 Jun 2024 23:10:32 -0700 Subject: [PATCH 2/8] Adding unit test for phase encoder --- test/encoding/test_encodings.py | 28 ++++++++++++++++++++++++++-- torchquantum/encoding/encodings.py | 8 ++++++-- 2 files changed, 32 insertions(+), 4 deletions(-) diff --git a/test/encoding/test_encodings.py b/test/encoding/test_encodings.py index 74f4d61d..2dd67ace 100644 --- a/test/encoding/test_encodings.py +++ b/test/encoding/test_encodings.py @@ -23,11 +23,14 @@ """ from __future__ import annotations +from typing import Callable import pytest from pytest import raises +from unittest import mock import torch -from torchquantum import QuantumDevice +from torchquantum import QuantumDevice, PhaseEncoder from torchquantum.encoding import StateEncoder +from torchquantum.functional import func_name_dict class TestStateEncoding: @@ -43,7 +46,7 @@ def test_qdev(self, qdev): match=r"The qdev input ([\s\S]*?) must be of the type tq\.QuantumDevice\.", ): encoder = StateEncoder() - encoder(qdev, torch.rand(2, 2)) + encoder.forward(qdev, torch.rand(2, 2)) @pytest.mark.parametrize( "wires, x", @@ -137,3 +140,24 @@ def test_state_encoding(self, wires, x, x_norm): assert qdev.states.shape[0] == x.shape[0] assert qdev.states.reshape(x.shape[0], -1).shape == (x.shape[0], pow(2, wires)) assert torch.allclose(qdev.states.reshape(x.shape[0], -1), x_norm.type(torch.complex64), atol=1e-3) + + +class TestPhaseEncoding: + """Test class for Phase Encoder.""" + + @pytest.mark.parametrize("func", [None, 1, 2.4, {}, True, list(range(2))]) + def test_func_type(self, func): + """Test the type of func input""" + with raises(TypeError, match="The input func must be of the type str."): + _ = PhaseEncoder(func) + + # + @pytest.mark.parametrize("func", ["hadamard", "ry", "xx", "paulix", "i"]) + def test_phase_encoding(self, func): + """Tests the PhaseEncoder class.""" + assert func in func_name_dict + encoder = PhaseEncoder(func) + with mock.patch.object(encoder, "func") as mock_func: + qdev = QuantumDevice(2) + encoder.forward(qdev, torch.rand(2, 4)) + assert mock_func.call_count > 1 diff --git a/torchquantum/encoding/encodings.py b/torchquantum/encoding/encodings.py index 123ac343..ad27ce9c 100644 --- a/torchquantum/encoding/encodings.py +++ b/torchquantum/encoding/encodings.py @@ -146,9 +146,12 @@ def to_qiskit(self, n_wires, x): class PhaseEncoder(Encoder, metaclass=ABCMeta): """PhaseEncoder is a subclass of Encoder and represents a phase encoder. It applies a specified quantum function to encode input data using a quantum device.""" - def __init__(self, func): + def __init__(self, func: str): super().__init__() - self.func = func + + if not isinstance(func, str): + raise TypeError("The input func must be of the type str.") + self.func = func_name_dict[func] @tq.static_support def forward(self, qdev: tq.QuantumDevice, x): @@ -164,6 +167,7 @@ def forward(self, qdev: tq.QuantumDevice, x): """ for k in range(qdev.n_wires): + print("Calling") self.func( qdev, wires=k, From 466e0029771c6396d1a28ee08805967aff32df63 Mon Sep 17 00:00:00 2001 From: SaashaJoshi Date: Mon, 3 Jun 2024 00:00:42 -0700 Subject: [PATCH 3/8] Adding unit tets for multi-phase encoder --- test/encoding/test_encodings.py | 59 ++++++++++++++++++------ torchquantum/encoding/encodings.py | 74 +++++++++++++++++------------- 2 files changed, 88 insertions(+), 45 deletions(-) diff --git a/test/encoding/test_encodings.py b/test/encoding/test_encodings.py index 2dd67ace..5248fbdc 100644 --- a/test/encoding/test_encodings.py +++ b/test/encoding/test_encodings.py @@ -28,8 +28,7 @@ from pytest import raises from unittest import mock import torch -from torchquantum import QuantumDevice, PhaseEncoder -from torchquantum.encoding import StateEncoder +from torchquantum import QuantumDevice, StateEncoder, PhaseEncoder, MultiPhaseEncoder from torchquantum.functional import func_name_dict @@ -42,20 +41,19 @@ class TestStateEncoding: ) def test_qdev(self, qdev): with raises( - TypeError, - match=r"The qdev input ([\s\S]*?) must be of the type tq\.QuantumDevice\.", + TypeError, + match=r"The qdev input ([\s\S]*?) must be of the type tq\.QuantumDevice\.", ): encoder = StateEncoder() encoder.forward(qdev, torch.rand(2, 2)) @pytest.mark.parametrize( - "wires, x", - [(2, {}), (4, list(range(10))), (1, None), (10, True), (5, 1)] + "wires, x", [(2, {}), (4, list(range(10))), (1, None), (10, True), (5, 1)] ) def test_type_x(self, wires, x): with raises( - TypeError, - match=r"The x input ([\s\S]*?) must be of the type torch\.Tensor\.", + TypeError, + match=r"The x input ([\s\S]*?) must be of the type torch\.Tensor\.", ): qdev = QuantumDevice(wires) encoder = StateEncoder() @@ -139,7 +137,9 @@ def test_state_encoding(self, wires, x, x_norm): assert qdev.states.shape[0] == x.shape[0] assert qdev.states.reshape(x.shape[0], -1).shape == (x.shape[0], pow(2, wires)) - assert torch.allclose(qdev.states.reshape(x.shape[0], -1), x_norm.type(torch.complex64), atol=1e-3) + assert torch.allclose( + qdev.states.reshape(x.shape[0], -1), x_norm.type(torch.complex64), atol=1e-3 + ) class TestPhaseEncoding: @@ -151,13 +151,46 @@ def test_func_type(self, func): with raises(TypeError, match="The input func must be of the type str."): _ = PhaseEncoder(func) - # @pytest.mark.parametrize("func", ["hadamard", "ry", "xx", "paulix", "i"]) def test_phase_encoding(self, func): """Tests the PhaseEncoder class.""" assert func in func_name_dict encoder = PhaseEncoder(func) + qdev = QuantumDevice(2) with mock.patch.object(encoder, "func") as mock_func: - qdev = QuantumDevice(2) - encoder.forward(qdev, torch.rand(2, 4)) - assert mock_func.call_count > 1 + encoder(qdev, torch.rand(2, 4)) + assert mock_func.call_count >= 1 + + +class TestMultiPhaseEncoding: + """Test class for Multi-phase Encoder.""" + + @pytest.mark.parametrize( + "wires, funcs", + [ + (10, ["rx", "hadamard"]), + (2, ["swap", "ry"]), + (3, ["xx"]), + (1, ["paulix", "i"]), + ], + ) + def test_invalid_func(self, wires, funcs): + with raises(ValueError, match=r"The func (.*?) is not supported\."): + encoder = MultiPhaseEncoder(funcs) + qdev = QuantumDevice(n_wires=wires) + encoder(qdev, torch.rand(1, pow(2, wires))) + + # NOTE: Test with func = u1 currently fails. + @pytest.mark.parametrize( + "wires, funcs", [(5, ["ry", "phaseshift"]), (4, ["u2"]), (1, ["u3"])] + ) + def test_phase_encoding(self, wires, funcs): + """Tests the MultiPhaseEncoder class.""" + # wires = 4 + encoder = MultiPhaseEncoder(funcs) + qdev = QuantumDevice(n_wires=wires) + mock_func = mock.Mock() + for func in encoder.funcs: + with mock.patch.dict(func_name_dict, {func: mock_func}): + encoder(qdev, torch.rand(1, pow(2, wires))) + assert mock_func.call_count >= 1 diff --git a/torchquantum/encoding/encodings.py b/torchquantum/encoding/encodings.py index ad27ce9c..0adb68d3 100644 --- a/torchquantum/encoding/encodings.py +++ b/torchquantum/encoding/encodings.py @@ -39,6 +39,7 @@ class Encoder(tq.QuantumModule): - forward(qdev: tq.QuantumDevice, x): Performs the encoding using a quantum device. """ + def __init__(self): super().__init__() pass @@ -49,15 +50,19 @@ def forward(self, qdev: tq.QuantumDevice, x: torch.Tensor): @staticmethod def validate_inputs(qdev: tq.QuantumDevice, x: torch.Tensor): if not isinstance(qdev, tq.QuantumDevice): - raise TypeError(f"The qdev input {qdev} must be of the type tq.QuantumDevice.") + raise TypeError( + f"The qdev input {qdev} must be of the type tq.QuantumDevice." + ) if not isinstance(x, torch.Tensor): raise TypeError(f"The x input {x} must be of the type torch.Tensor.") if any(tensor.size()[0] > pow(2, qdev.n_wires) for tensor in x): - raise ValueError(f"The size of tensors in x ({x.size()[1]}) must be less than or " - f"equal to {pow(2, qdev.n_wires)} for a QuantumDevice with " - f"{qdev.n_wires} wires.") + raise ValueError( + f"The size of tensors in x ({x.size()[1]}) must be less than or " + f"equal to {pow(2, qdev.n_wires)} for a QuantumDevice with " + f"{qdev.n_wires} wires." + ) class GeneralEncoder(Encoder, metaclass=ABCMeta): @@ -145,7 +150,9 @@ def to_qiskit(self, n_wires, x): class PhaseEncoder(Encoder, metaclass=ABCMeta): """PhaseEncoder is a subclass of Encoder and represents a phase encoder. - It applies a specified quantum function to encode input data using a quantum device.""" + It applies a specified quantum function to encode input data using a quantum device. + """ + def __init__(self, func: str): super().__init__() @@ -156,16 +163,16 @@ def __init__(self, func: str): @tq.static_support def forward(self, qdev: tq.QuantumDevice, x): """ - Performs the encoding using a quantum device. + Performs the encoding using a quantum device. - Args: - qdev (tq.QuantumDevice): The quantum device to be used for encoding. - x (torch.Tensor): The input data to be encoded. + Args: + qdev (tq.QuantumDevice): The quantum device to be used for encoding. + x (torch.Tensor): The input data to be encoded. - Returns: - torch.Tensor: The encoded data. + Returns: + torch.Tensor: The encoded data. - """ + """ for k in range(qdev.n_wires): print("Calling") self.func( @@ -179,7 +186,9 @@ def forward(self, qdev: tq.QuantumDevice, x): class MultiPhaseEncoder(Encoder, metaclass=ABCMeta): """PhaseEncoder is a subclass of Encoder and represents a phase encoder. - It applies a specified quantum function to encode input data using a quantum device.""" + It applies a specified quantum function to encode input data using a quantum device. + """ + def __init__(self, funcs, wires=None): super().__init__() self.funcs = funcs if isinstance(funcs, Iterable) else [funcs] @@ -188,18 +197,19 @@ def __init__(self, funcs, wires=None): @tq.static_support def forward(self, qdev: tq.QuantumDevice, x): """ - Performs the encoding using a quantum device. + Performs the encoding using a quantum device. - Args: - qdev (tq.QuantumDevice): The quantum device to be used for encoding. - x (torch.Tensor): The input data to be encoded. + Args: + qdev (tq.QuantumDevice): The quantum device to be used for encoding. + x (torch.Tensor): The input data to be encoded. - Returns: - torch.Tensor: The encoded data. + Returns: + torch.Tensor: The encoded data. - """ + """ if self.wires is None: - self.wires = list(range(qdev.n_wires)) * (len(self.funcs) // qdev.n_wires) + # self.wires = list(range(qdev.n_wires)) * (len(self.funcs) // qdev.n_wires) + self.wires = list(range(qdev.n_wires + (len(self.funcs) // qdev.n_wires))) x_id = 0 for k, func in enumerate(self.funcs): @@ -210,7 +220,7 @@ def forward(self, qdev: tq.QuantumDevice, x): elif func == "u3": stride = 3 else: - raise ValueError(func) + raise ValueError(f"The func {func} is not supported.") func_name_dict[func]( qdev, @@ -225,20 +235,21 @@ def forward(self, qdev: tq.QuantumDevice, x): class StateEncoder(Encoder, metaclass=ABCMeta): """StateEncoder is a subclass of Encoder and represents a state encoder. It encodes the input data into the state vector of a quantum device.""" + def __init__(self): super().__init__() def forward(self, qdev: tq.QuantumDevice, x): """ - Performs the encoding by preparing the state vector of the quantum device. + Performs the encoding by preparing the state vector of the quantum device. - Args: - qdev (tq.QuantumDevice): The quantum device to be used for encoding. - x (torch.Tensor): The input data to be encoded. - Returns: - torch.Tensor: The encoded data. + Args: + qdev (tq.QuantumDevice): The quantum device to be used for encoding. + x (torch.Tensor): The input data to be encoded. + Returns: + torch.Tensor: The encoded data. - """ + """ # Validate inputs self.validate_inputs(qdev, x) # encoder the x to the statevector of the quantum device @@ -247,9 +258,7 @@ def forward(self, qdev: tq.QuantumDevice, x): state = torch.cat( ( x, - torch.zeros( - x.shape[0], 2**qdev.n_wires - x.shape[1], device=x.device - ), + torch.zeros(x.shape[0], 2**qdev.n_wires - x.shape[1], device=x.device), ), dim=-1, ) @@ -261,6 +270,7 @@ def forward(self, qdev: tq.QuantumDevice, x): class MagnitudeEncoder(Encoder, metaclass=ABCMeta): """MagnitudeEncoder is a subclass of Encoder and represents a magnitude encoder. It encodes the input data by considering the magnitudes of the elements.""" + def __init__(self): super().__init__() From 59f36d5e9bf52e47bd079e7f2edc44fbfdd72c2c Mon Sep 17 00:00:00 2001 From: SaashaJoshi Date: Mon, 3 Jun 2024 22:25:03 -0700 Subject: [PATCH 4/8] Add unit tets for general encoder. --- test/encoding/test_encodings.py | 224 +++++++++++++++++++++-------- torchquantum/encoding/encodings.py | 18 ++- 2 files changed, 185 insertions(+), 57 deletions(-) diff --git a/test/encoding/test_encodings.py b/test/encoding/test_encodings.py index 5248fbdc..c24e240f 100644 --- a/test/encoding/test_encodings.py +++ b/test/encoding/test_encodings.py @@ -27,12 +27,178 @@ import pytest from pytest import raises from unittest import mock +from qiskit.circuit import QuantumCircuit import torch -from torchquantum import QuantumDevice, StateEncoder, PhaseEncoder, MultiPhaseEncoder +from torchquantum import ( + QuantumDevice, + GeneralEncoder, + StateEncoder, + PhaseEncoder, + MultiPhaseEncoder, +) from torchquantum.functional import func_name_dict -class TestStateEncoding: +class TestGeneralEncoder: + """Test class for General Encoder.""" + + @pytest.mark.parametrize("func_list", [None, 1, 2.4, True, list(range(2))]) + def test_invalid_func_list(self, func_list): + with raises( + TypeError, match=r"The input func_list must be of the type list\[dict\]\." + ): + _ = GeneralEncoder(func_list) + + @pytest.mark.parametrize( + "func_list", [[{"key1": 1}], [{"func": "rx"}], [{"func": "rx", "key2": None}]] + ) + def test_func_list_keys(self, func_list): + with raises( + ValueError, + match="The dictionary in func_list must contain the " + "keys: input_idx, func, and wires.", + ): + _ = GeneralEncoder(func_list) + + @pytest.mark.parametrize( + "wires, func_list", + [ + (1, [{"input_idx": [0], "func": "ry", "wires": [0]}]), + ( + 2, + [ + {"input_idx": [0], "func": "ry", "wires": [0]}, + {"input_idx": [1], "func": "ry", "wires": [1]}, + ], + ), + ( + 4, + [ + {"input_idx": [0], "func": "rz", "wires": [0]}, + {"input_idx": None, "func": "sx", "wires": [0]}, + {"input_idx": [2], "func": "rx", "wires": [2]}, + {"input_idx": [3], "func": "ry", "wires": [3]}, + ], + ), + ], + ) + def test_general_encoding(self, wires, func_list): + """Tests the GeneralEncoder class.""" + encoder = GeneralEncoder(func_list) + qdev = QuantumDevice(wires) + mock_func = mock.Mock() + for func_dict in func_list: + func = func_dict["func"] + with mock.patch.dict(func_name_dict, {func: mock_func}): + encoder(qdev, torch.rand(1, pow(2, wires))) + assert mock_func.call_count >= 1 + + @pytest.mark.parametrize( + "batch_size, wires, func_list", + [ + (2, 1, [{"input_idx": [0], "func": "rz", "wires": [0]}]), + ( + 4, + 2, + [ + {"input_idx": [0], "func": "ryy", "wires": [0, 1]}, + {"input_idx": [1], "func": "rx", "wires": [1]}, + ], + ), + ( + 2, + 4, + [ + {"input_idx": [0], "func": "rzz", "wires": [0, 2]}, + {"input_idx": [1], "func": "rxx", "wires": [1, 2]}, + {"input_idx": [2], "func": "ry", "wires": [2]}, + {"input_idx": [3], "func": "rzx", "wires": [1, 3]}, + ], + ), + ], + ) + def test_to_qiskit(self, batch_size, wires, func_list): + """Tests conversion of GeneralEncoder to Qiskit.""" + x = torch.rand(batch_size, pow(2, wires)) + encoder = GeneralEncoder(func_list) + qdev = QuantumDevice(n_wires=wires, bsz=batch_size) + encoder(qdev, x) + resulting_circuit = encoder.to_qiskit(wires, x) + for circuit in resulting_circuit: + assert isinstance(circuit, QuantumCircuit) + + @pytest.mark.parametrize( + "batch_size, wires, func_list", + [ + (2, 1, [{"input_idx": [0], "func": "hadamard", "wires": [0]}]), + (2, 2, [{"input_idx": [0], "func": "xx", "wires": [0, 1]}]), + ], + ) + def test_not_implemeted_qiskit(self, batch_size, wires, func_list): + """Tests conversion of GeneralEncoder to Qiskit.""" + x = torch.rand(batch_size, pow(2, wires)) + encoder = GeneralEncoder(func_list) + qdev = QuantumDevice(n_wires=wires, bsz=batch_size) + encoder(qdev, x) + with raises(NotImplementedError, match=r"([\s\S]*?) is not supported yet\."): + _ = encoder.to_qiskit(wires, x) + + +class TestPhaseEncoder: + """Test class for Phase Encoder.""" + + @pytest.mark.parametrize("func", [None, 1, 2.4, {}, True, list(range(2))]) + def test_func_type(self, func): + """Test the type of func input""" + with raises(TypeError, match="The input func must be of the type str."): + _ = PhaseEncoder(func) + + @pytest.mark.parametrize("func", ["hadamard", "ry", "xx", "paulix", "i"]) + def test_phase_encoding(self, func): + """Tests the PhaseEncoder class.""" + assert func in func_name_dict + encoder = PhaseEncoder(func) + qdev = QuantumDevice(2) + with mock.patch.object(encoder, "func") as mock_func: + encoder(qdev, torch.rand(2, 4)) + assert mock_func.call_count >= 1 + + +class TestMultiPhaseEncoder: + """Test class for Multi-phase Encoder.""" + + @pytest.mark.parametrize( + "wires, funcs", + [ + (10, ["rx", "hadamard"]), + (2, ["swap", "ry"]), + (3, ["xx"]), + (1, ["paulix", "i"]), + ], + ) + def test_invalid_func(self, wires, funcs): + with raises(ValueError, match=r"The func (.*?) is not supported\."): + encoder = MultiPhaseEncoder(funcs) + qdev = QuantumDevice(n_wires=wires) + encoder(qdev, torch.rand(1, pow(2, wires))) + + # NOTE: Test with func = u1 currently fails. + @pytest.mark.parametrize( + "batch_size, wires, funcs", [(2, 5, ["ry", "phaseshift"]), (1, 4, ["u2"]), (3, 1, ["u3"])] + ) + def test_phase_encoding(self, batch_size, wires, funcs): + """Tests the MultiPhaseEncoder class.""" + # wires = 4 + encoder = MultiPhaseEncoder(funcs) + qdev = QuantumDevice(n_wires=wires, bsz=batch_size) + mock_func = mock.Mock() + for func in encoder.funcs: + with mock.patch.dict(func_name_dict, {func: mock_func}): + encoder(qdev, torch.rand(batch_size, pow(2, wires))) + assert mock_func.call_count >= 1 + + +class TestStateEncoder: """Test class for State Encoder.""" @pytest.mark.parametrize( @@ -140,57 +306,3 @@ def test_state_encoding(self, wires, x, x_norm): assert torch.allclose( qdev.states.reshape(x.shape[0], -1), x_norm.type(torch.complex64), atol=1e-3 ) - - -class TestPhaseEncoding: - """Test class for Phase Encoder.""" - - @pytest.mark.parametrize("func", [None, 1, 2.4, {}, True, list(range(2))]) - def test_func_type(self, func): - """Test the type of func input""" - with raises(TypeError, match="The input func must be of the type str."): - _ = PhaseEncoder(func) - - @pytest.mark.parametrize("func", ["hadamard", "ry", "xx", "paulix", "i"]) - def test_phase_encoding(self, func): - """Tests the PhaseEncoder class.""" - assert func in func_name_dict - encoder = PhaseEncoder(func) - qdev = QuantumDevice(2) - with mock.patch.object(encoder, "func") as mock_func: - encoder(qdev, torch.rand(2, 4)) - assert mock_func.call_count >= 1 - - -class TestMultiPhaseEncoding: - """Test class for Multi-phase Encoder.""" - - @pytest.mark.parametrize( - "wires, funcs", - [ - (10, ["rx", "hadamard"]), - (2, ["swap", "ry"]), - (3, ["xx"]), - (1, ["paulix", "i"]), - ], - ) - def test_invalid_func(self, wires, funcs): - with raises(ValueError, match=r"The func (.*?) is not supported\."): - encoder = MultiPhaseEncoder(funcs) - qdev = QuantumDevice(n_wires=wires) - encoder(qdev, torch.rand(1, pow(2, wires))) - - # NOTE: Test with func = u1 currently fails. - @pytest.mark.parametrize( - "wires, funcs", [(5, ["ry", "phaseshift"]), (4, ["u2"]), (1, ["u3"])] - ) - def test_phase_encoding(self, wires, funcs): - """Tests the MultiPhaseEncoder class.""" - # wires = 4 - encoder = MultiPhaseEncoder(funcs) - qdev = QuantumDevice(n_wires=wires) - mock_func = mock.Mock() - for func in encoder.funcs: - with mock.patch.dict(func_name_dict, {func: mock_func}): - encoder(qdev, torch.rand(1, pow(2, wires))) - assert mock_func.call_count >= 1 diff --git a/torchquantum/encoding/encodings.py b/torchquantum/encoding/encodings.py index 0adb68d3..3208f4d4 100644 --- a/torchquantum/encoding/encodings.py +++ b/torchquantum/encoding/encodings.py @@ -102,6 +102,22 @@ class GeneralEncoder(Encoder, metaclass=ABCMeta): def __init__(self, func_list): super().__init__() + + if not isinstance(func_list, list) or not all( + isinstance(func_dict, dict) for func_dict in func_list + ): + raise TypeError("The input func_list must be of the type list[dict].") + + if any( + "input_idx" not in func_dict.keys() + or "func" not in func_dict.keys() + or "wires" not in func_dict.keys() + for func_dict in func_list + ): + raise ValueError( + "The dictionary in func_list must contain the keys: " + "input_idx, func, and wires." + ) self.func_list = func_list @tq.static_support @@ -142,7 +158,7 @@ def to_qiskit(self, n_wires, x): elif info["func"] == "rzx": circ.rzx(x[k][info["input_idx"][0]].item(), *info["wires"]) else: - raise NotImplementedError(info["func"]) + raise NotImplementedError(f"{info['func']} is not supported yet.") circs.append(circ) return circs From 96b8725442289164bffbe6332785f42d921b489c Mon Sep 17 00:00:00 2001 From: SaashaJoshi Date: Mon, 3 Jun 2024 22:29:00 -0700 Subject: [PATCH 5/8] add input validation to all forward methods --- test/encoding/test_encodings.py | 3 ++- torchquantum/encoding/encodings.py | 6 ++++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/test/encoding/test_encodings.py b/test/encoding/test_encodings.py index c24e240f..4e2e9e53 100644 --- a/test/encoding/test_encodings.py +++ b/test/encoding/test_encodings.py @@ -184,7 +184,8 @@ def test_invalid_func(self, wires, funcs): # NOTE: Test with func = u1 currently fails. @pytest.mark.parametrize( - "batch_size, wires, funcs", [(2, 5, ["ry", "phaseshift"]), (1, 4, ["u2"]), (3, 1, ["u3"])] + "batch_size, wires, funcs", + [(2, 5, ["ry", "phaseshift"]), (1, 4, ["u2"]), (3, 1, ["u3"])], ) def test_phase_encoding(self, batch_size, wires, funcs): """Tests the MultiPhaseEncoder class.""" diff --git a/torchquantum/encoding/encodings.py b/torchquantum/encoding/encodings.py index 3208f4d4..70bc4ca2 100644 --- a/torchquantum/encoding/encodings.py +++ b/torchquantum/encoding/encodings.py @@ -122,6 +122,8 @@ def __init__(self, func_list): @tq.static_support def forward(self, qdev: tq.QuantumDevice, x): + # Validate inputs + self.validate_inputs(qdev, x) for info in self.func_list: if tq.op_name_dict[info["func"]].num_params > 0: params = x[:, info["input_idx"]] @@ -189,6 +191,8 @@ def forward(self, qdev: tq.QuantumDevice, x): torch.Tensor: The encoded data. """ + # Validate inputs + self.validate_inputs(qdev, x) for k in range(qdev.n_wires): print("Calling") self.func( @@ -223,6 +227,8 @@ def forward(self, qdev: tq.QuantumDevice, x): torch.Tensor: The encoded data. """ + # Validate inputs + self.validate_inputs(qdev, x) if self.wires is None: # self.wires = list(range(qdev.n_wires)) * (len(self.funcs) // qdev.n_wires) self.wires = list(range(qdev.n_wires + (len(self.funcs) // qdev.n_wires))) From d58250c4ba1473a2fc5516317ca490909a36ca62 Mon Sep 17 00:00:00 2001 From: SaashaJoshi Date: Mon, 3 Jun 2024 23:20:48 -0700 Subject: [PATCH 6/8] tensor size validation is valid only for state encoder --- torchquantum/encoding/encodings.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/torchquantum/encoding/encodings.py b/torchquantum/encoding/encodings.py index 70bc4ca2..42c9f8d9 100644 --- a/torchquantum/encoding/encodings.py +++ b/torchquantum/encoding/encodings.py @@ -57,13 +57,6 @@ def validate_inputs(qdev: tq.QuantumDevice, x: torch.Tensor): if not isinstance(x, torch.Tensor): raise TypeError(f"The x input {x} must be of the type torch.Tensor.") - if any(tensor.size()[0] > pow(2, qdev.n_wires) for tensor in x): - raise ValueError( - f"The size of tensors in x ({x.size()[1]}) must be less than or " - f"equal to {pow(2, qdev.n_wires)} for a QuantumDevice with " - f"{qdev.n_wires} wires." - ) - class GeneralEncoder(Encoder, metaclass=ABCMeta): """func_list list of dict @@ -274,6 +267,7 @@ def forward(self, qdev: tq.QuantumDevice, x): """ # Validate inputs self.validate_inputs(qdev, x) + self.validate_tensor_size(qdev, x) # encoder the x to the statevector of the quantum device # normalize the input x = x / (torch.sqrt((x.abs() ** 2).sum(dim=-1))).unsqueeze(-1) @@ -288,6 +282,15 @@ def forward(self, qdev: tq.QuantumDevice, x): qdev.states = state.type(C_DTYPE) + @staticmethod + def validate_tensor_size(qdev: tq.QuantumDevice, x): + if any(tensor.size()[0] > pow(2, qdev.n_wires) for tensor in x): + raise ValueError( + f"The size of tensors in x ({x.size()[1]}) must be less than or " + f"equal to {pow(2, qdev.n_wires)} for a QuantumDevice with " + f"{qdev.n_wires} wires." + ) + class MagnitudeEncoder(Encoder, metaclass=ABCMeta): """MagnitudeEncoder is a subclass of Encoder and represents a magnitude encoder. From 2bcff54179dc580d91b045afe8a4f4931d109489 Mon Sep 17 00:00:00 2001 From: SaashaJoshi Date: Mon, 3 Jun 2024 23:39:31 -0700 Subject: [PATCH 7/8] edit error handling logic for func_lists --- test/encoding/test_encodings.py | 5 ++--- torchquantum/encoding/encodings.py | 10 +++++----- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/test/encoding/test_encodings.py b/test/encoding/test_encodings.py index 4e2e9e53..f3b6ae1e 100644 --- a/test/encoding/test_encodings.py +++ b/test/encoding/test_encodings.py @@ -50,13 +50,12 @@ def test_invalid_func_list(self, func_list): _ = GeneralEncoder(func_list) @pytest.mark.parametrize( - "func_list", [[{"key1": 1}], [{"func": "rx"}], [{"func": "rx", "key2": None}]] + "func_list", [[{"key1": 1}], [{"func": "rx"}], [{"func": "rx", "input_idx": [0]}]] ) def test_func_list_keys(self, func_list): with raises( ValueError, - match="The dictionary in func_list must contain the " - "keys: input_idx, func, and wires.", + match="The dictionary in func_list must is missing func or wires.", ): _ = GeneralEncoder(func_list) diff --git a/torchquantum/encoding/encodings.py b/torchquantum/encoding/encodings.py index 42c9f8d9..7f06a048 100644 --- a/torchquantum/encoding/encodings.py +++ b/torchquantum/encoding/encodings.py @@ -102,14 +102,14 @@ def __init__(self, func_list): raise TypeError("The input func_list must be of the type list[dict].") if any( - "input_idx" not in func_dict.keys() - or "func" not in func_dict.keys() - or "wires" not in func_dict.keys() + ( + "func" not in func_dict + or ("func" in func_dict and "wires" not in func_dict) + ) for func_dict in func_list ): raise ValueError( - "The dictionary in func_list must contain the keys: " - "input_idx, func, and wires." + "The dictionary in func_list must is missing func or wires." ) self.func_list = func_list From 3d381c863b9ed57689ee094d68aab7d395e98d66 Mon Sep 17 00:00:00 2001 From: SaashaJoshi Date: Mon, 3 Jun 2024 23:40:18 -0700 Subject: [PATCH 8/8] black --- test/encoding/test_encodings.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/encoding/test_encodings.py b/test/encoding/test_encodings.py index f3b6ae1e..ece4bd76 100644 --- a/test/encoding/test_encodings.py +++ b/test/encoding/test_encodings.py @@ -50,7 +50,8 @@ def test_invalid_func_list(self, func_list): _ = GeneralEncoder(func_list) @pytest.mark.parametrize( - "func_list", [[{"key1": 1}], [{"func": "rx"}], [{"func": "rx", "input_idx": [0]}]] + "func_list", + [[{"key1": 1}], [{"func": "rx"}], [{"func": "rx", "input_idx": [0]}]], ) def test_func_list_keys(self, func_list): with raises(