Skip to content

Commit

Permalink
Add support for fixed params in GeneralEncoder
Browse files Browse the repository at this point in the history
  • Loading branch information
nikhilkhatri committed Mar 24, 2024
1 parent b7f8c36 commit e2b190f
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion torchquantum/encoding/encodings.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,11 @@ def __init__(self, func_list):
def forward(self, qdev: tq.QuantumDevice, x):
for info in self.func_list:
if tq.op_name_dict[info["func"]].num_params > 0:
params = x[:, info["input_idx"]]
# If params are provided in encoder, use those,
# else use params from x
params = (torch.Tensor(info["params"]).repeat(x.shape[0], 1)
if info.get("params")
else x[:, info["input_idx"]])
else:
params = None
func_name_dict[info["func"]](
Expand Down

0 comments on commit e2b190f

Please sign in to comment.