Commit
Merge pull request #94 from mit-han-lab/dev
Dev revise examples
Hanrui-Wang authored Mar 18, 2023
2 parents 2d2268c + dfc4514 commit ac1edd8
Showing 2 changed files with 164 additions and 23 deletions.
134 changes: 134 additions & 0 deletions examples/amplitude_encoding_mnist/mnist_new.py
@@ -0,0 +1,134 @@
"""
author: Vivek Yanamadula @Vivekyy
"""

import torch
import torch.nn.functional as F

import torchquantum as tq
import torchquantum.functional as tqf

from torchquantum.datasets import MNIST
from torchquantum.operators import op_name_dict
from typing import List


class TQNet(tq.QuantumModule):
def __init__(self, layers: List[tq.QuantumModule], encoder=None, use_softmax=False):
super().__init__()

self.encoder = encoder
self.use_softmax = use_softmax

self.layers = tq.QuantumModuleList()

for layer in layers:
self.layers.append(layer)

self.service = "TorchQuantum"
self.measure = tq.MeasureAll(tq.PauliZ)

def forward(self, device, x):
bsz = x.shape[0]
device.reset_states(bsz)

        # downsample each 28x28 MNIST image to 4x4 = 16 classical features
        x = F.avg_pool2d(x, 6)
        x = x.view(bsz, 16)

if self.encoder:
self.encoder(device, x)

for layer in self.layers:
layer(device)

meas = self.measure(device)

if self.use_softmax:
meas = F.log_softmax(meas, dim=1)

return meas

class TQLayer(tq.QuantumModule):
def __init__(self, gates: List[tq.QuantumModule]):
super().__init__()

self.service = "TorchQuantum"

self.layer = tq.QuantumModuleList()
for gate in gates:
self.layer.append(gate)

@tq.static_support
def forward(self, q_device):
for gate in self.layer:
gate(q_device)

def train_tq(model, device, train_dl, epochs, loss_fn, optimizer):
    # `device` is the tq.QuantumDevice holding the qubit states;
    # `torch_device` (defined at module level below) is the classical torch device
losses = []
for epoch in range(epochs):
running_loss = 0.0
batches = 0
for batch_dict in train_dl:
x = batch_dict['image']
y = batch_dict['digit']

y = y.to(torch.long)

x = x.to(torch_device)
y = y.to(torch_device)

optimizer.zero_grad()

preds = model(device, x)

loss = loss_fn(preds, y)
loss.backward()

optimizer.step()

running_loss += loss.item()
batches += 1

print(f"Epoch {epoch + 1} | Loss: {running_loss/batches}", end="\r")

print(f"Epoch {epoch + 1} | Loss: {running_loss/batches}")
losses.append(running_loss/batches)

return losses

torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# choose an encoder for mapping the classical features onto the qubits
# encoder = None
# encoder = tq.AmplitudeEncoder()
encoder = tq.MultiPhaseEncoder(['u3', 'u3', 'u3', 'u3'])


# a fixed random circuit layer followed by a layer of trainable gates
random_layer = tq.RandomLayer(n_ops=50, wires=list(range(4)))
trainable_layer = [op_name_dict['rx'](trainable=True, has_params=True, wires=[0]),
op_name_dict['ry'](trainable=True, has_params=True, wires=[1]),
op_name_dict['rz'](trainable=True, has_params=True, wires=[3]),
op_name_dict['crx'](trainable=True, has_params=True, wires=[0,2])]
trainable_layer = TQLayer(trainable_layer)
layers = [random_layer, trainable_layer]

device = tq.QuantumDevice(n_wires=4).to(torch_device)

model = TQNet(layers=layers, encoder=encoder, use_softmax=True).to(torch_device)

loss_fn = F.nll_loss
optimizer = torch.optim.SGD(model.parameters(), lr=0.05)

dataset = MNIST(
root='./mnist_data',
train_valid_split_ratio=[.9, .1],
digits_of_interest=[0, 1, 3, 6],
n_test_samples=200,
)

train_dl = torch.utils.data.DataLoader(dataset['train'], batch_size=32, sampler=torch.utils.data.RandomSampler(dataset['train']))
val_dl = torch.utils.data.DataLoader(dataset['valid'], batch_size=32, sampler=torch.utils.data.RandomSampler(dataset['valid']))
test_dl = torch.utils.data.DataLoader(dataset['test'], batch_size=32, sampler=torch.utils.data.RandomSampler(dataset['test']))

print("--Training--")
train_losses = train_tq(model, device, train_dl, 1, loss_fn, optimizer)

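Note: the script builds val_dl and test_dl but only runs training. A minimal evaluation sketch under the same setup (hypothetical, not part of this commit; eval_tq is an illustrative name, and it assumes, as the training loop does, that batch_dict['digit'] holds class indices for the four digits of interest):

def eval_tq(model, device, dl):
    # accuracy over one dataloader, mirroring the conventions of train_tq above
    correct, total = 0, 0
    with torch.no_grad():
        for batch_dict in dl:
            x = batch_dict['image'].to(torch_device)
            y = batch_dict['digit'].to(torch.long).to(torch_device)
            preds = model(device, x)  # log-probabilities, since use_softmax=True
            correct += (preds.argmax(dim=1) == y).sum().item()
            total += y.shape[0]
    return correct / total

print("--Evaluation--")
print(f"valid accuracy: {eval_tq(model, device, val_dl):.3f}")
print(f"test accuracy: {eval_tq(model, device, test_dl):.3f}")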
53 changes: 30 additions & 23 deletions examples/regression/run_regression.py
@@ -68,7 +68,7 @@ def __init__(self, n_train, n_valid, n_wires):


class QModel(tq.QuantumModule):
-    def __init__(self, n_wires, n_blocks):
+    def __init__(self, n_wires, n_blocks, add_fc=False):
super().__init__()
        # inside one block, we have one u3 layer on each qubit and one
        # cu3 layer with ring connection
@@ -95,48 +95,56 @@ def __init__(self, n_wires, n_blocks):
)
)
self.measure = tq.MeasureAll(tq.PauliZ)

-    def forward(self, q_device: tq.QuantumDevice, input_states):
-        # firstly set the q_device states
-        q_device.set_states(input_states)
+        self.add_fc = add_fc
+        if add_fc:
+            self.fc_layer = torch.nn.Linear(n_wires, 1)
+
+    def forward(self, input_states):
+        qdev = tq.QuantumDevice(n_wires=self.n_wires, bsz=input_states.shape[0], device=input_states.device)
+        # firstly set the qdev states
+        qdev.set_states(input_states)
for k in range(self.n_blocks):
-            self.u3_layers[k](q_device)
-            self.cu3_layers[k](q_device)
+            self.u3_layers[k](qdev)
+            self.cu3_layers[k](qdev)

-        res = self.measure(q_device)
+        res = self.measure(qdev)
+        if self.add_fc:
+            res = self.fc_layer(res)
+        else:
+            res = res[:, 1]
return res


-def train(dataflow, q_device, model, device, optimizer):
+def train(dataflow, model, device, optimizer):
for feed_dict in dataflow["train"]:
inputs = feed_dict["states"].to(device).to(torch.complex64)
targets = feed_dict["Xlabel"].to(device).to(torch.float)

-        outputs = model(q_device, inputs)
+        outputs = model(inputs)

-        loss = F.mse_loss(outputs[:, 1], targets)
+        loss = F.mse_loss(outputs, targets)
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(f"loss: {loss.item()}")


-def valid_test(dataflow, q_device, split, model, device):
+def valid_test(dataflow, split, model, device):
target_all = []
output_all = []
with torch.no_grad():
for feed_dict in dataflow[split]:
inputs = feed_dict["states"].to(device).to(torch.complex64)
targets = feed_dict["Xlabel"].to(device).to(torch.float)

-            outputs = model(q_device, inputs)
+            outputs = model(inputs)

target_all.append(targets)
output_all.append(outputs)
target_all = torch.cat(target_all, dim=0)
output_all = torch.cat(output_all, dim=0)

-    loss = F.mse_loss(output_all[:, 1], target_all)
+    loss = F.mse_loss(output_all, target_all)

print(f"{split} set loss: {loss}")

@@ -165,6 +173,9 @@ def main():
parser.add_argument(
"--epochs", type=int, default=100, help="number of training epochs"
)
+    parser.add_argument(
+        "--addfc", action="store_true", help="add a final classical FC layer"
+    )

args = parser.parse_args()

@@ -202,27 +213,23 @@ def main():
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

-    model = QModel(n_wires=args.n_wires, n_blocks=args.n_blocks).to(device)
+    model = QModel(n_wires=args.n_wires, n_blocks=args.n_blocks, add_fc=args.addfc).to(device)

n_epochs = args.epochs
optimizer = optim.Adam(model.parameters(), lr=5e-3, weight_decay=1e-4)
scheduler = CosineAnnealingLR(optimizer, T_max=n_epochs)

-    q_device = tq.QuantumDevice(n_wires=args.n_wires)
-    q_device.reset_states(bsz=args.bsz)

for epoch in range(1, n_epochs + 1):
# train
print(f"Epoch {epoch}, RL: {optimizer.param_groups[0]['lr']}")
train(dataflow, q_device, model, device, optimizer)
print(f"Epoch {epoch}, LR: {optimizer.param_groups[0]['lr']}")
train(dataflow, model, device, optimizer)

# valid
valid_test(dataflow, q_device, "valid", model, device)
valid_test(dataflow,"valid", model, device)
scheduler.step()

# final valid
valid_test(dataflow, q_device, "valid", model, device)

valid_test(dataflow, "valid", model, device)

if __name__ == "__main__":
main()

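Note: with this change QModel builds its own tq.QuantumDevice per batch, so callers no longer pass one in. A minimal smoke test of the revised interface (a sketch, not part of the commit; it assumes QModel is importable from examples/regression/run_regression.py; equivalently, run python run_regression.py --addfc from the CLI):

import torch
from run_regression import QModel  # hypothetical import path

model = QModel(n_wires=3, n_blocks=2, add_fc=True)
# eight random 3-qubit states, normalized so each row is a valid state vector
states = torch.randn(8, 2 ** 3, dtype=torch.complex64)
states = states / states.norm(dim=-1, keepdim=True)
out = model(states)   # no external QuantumDevice needed anymore
print(out.shape)      # torch.Size([8, 1]) with add_fc=True; torch.Size([8]) otherwise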