From e2b190f39d45b5e4828b0f0a1429fa68f733d7bc Mon Sep 17 00:00:00 2001 From: nikhilkhatri Date: Sun, 24 Mar 2024 12:49:29 +0000 Subject: [PATCH] Add support for fixed params in GeneralEncoder --- torchquantum/encoding/encodings.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/torchquantum/encoding/encodings.py b/torchquantum/encoding/encodings.py index f8d2056d..e73c526e 100644 --- a/torchquantum/encoding/encodings.py +++ b/torchquantum/encoding/encodings.py @@ -90,7 +90,11 @@ def __init__(self, func_list): def forward(self, qdev: tq.QuantumDevice, x): for info in self.func_list: if tq.op_name_dict[info["func"]].num_params > 0: - params = x[:, info["input_idx"]] + # If params are provided in encoder, use those, + # else use params from x + params = (torch.Tensor(info["params"]).repeat(x.shape[0], 1) + if info.get("params") + else x[:, info["input_idx"]]) else: params = None func_name_dict[info["func"]](