-
Notifications
You must be signed in to change notification settings - Fork 211
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #270 from AbdullahKazi500/AbdullahKazi500-patch-2
- Loading branch information
Showing
5 changed files
with
766 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
# Quantum Generative Adversarial Network (QGAN) Example | ||
|
||
This repository contains an example implementation of a Quantum Generative Adversarial Network (QGAN) using PyTorch and TorchQuantum. The example is provided in a Jupyter Notebook for interactive exploration. | ||
|
||
## Overview | ||
|
||
A QGAN consists of two main components: | ||
|
||
1. **Generator:** This network generates fake quantum data samples. | ||
2. **Discriminator:** This network tries to distinguish between real and fake quantum data samples. | ||
|
||
The goal is to train the generator to produce quantum data that is indistinguishable from real data, according to the discriminator. This is achieved through an adversarial training process, where the generator and discriminator are trained simultaneously in a competitive manner. | ||
|
||
## Repository Contents | ||
|
||
- `qgan_notebook.ipynb`: Jupyter Notebook demonstrating the QGAN implementation. | ||
- `qgan_script.py`: Python script containing the QGAN model and a main function for initializing the model with command-line arguments. | ||
|
||
## Installation | ||
|
||
To run the examples, you need to have the following dependencies installed: | ||
|
||
- Python 3 | ||
- PyTorch | ||
- TorchQuantum | ||
- Jupyter Notebook | ||
- ipywidgets | ||
|
||
You can install the required Python packages using pip: | ||
|
||
```bash | ||
pip install torch torchquantum jupyter ipywidgets | ||
``` | ||
|
||
|
||
Running the Examples | ||
Jupyter Notebook | ||
Open the qgan_notebook.ipynb file in Jupyter Notebook. | ||
Execute the notebook cells to see the QGAN model in action. | ||
Python Script | ||
You can also run the QGAN model using the Python script. The script uses argparse to handle command-line arguments. | ||
|
||
bash | ||
Copy code | ||
python qgan_script.py <n_qubits> <latent_dim> | ||
Replace <n_qubits> and <latent_dim> with the desired number of qubits and latent dimensions. | ||
|
||
Notebook Details | ||
The Jupyter Notebook is structured as follows: | ||
|
||
Introduction: Provides an overview of the QGAN and its components. | ||
Import Libraries: Imports the necessary libraries, including PyTorch and TorchQuantum. | ||
Generator Class: Defines the quantum generator model. | ||
Discriminator Class: Defines the quantum discriminator model. | ||
QGAN Class: Combines the generator and discriminator into a single QGAN model. | ||
Main Function: Initializes the QGAN model and prints its structure. | ||
Interactive Model Creation: Uses ipywidgets to create an interactive interface for adjusting the number of qubits and latent dimensions. | ||
Understanding QGANs | ||
QGANs are a type of Generative Adversarial Network (GAN) that operate in the quantum domain. They leverage quantum circuits to generate and evaluate data samples. The adversarial training process involves two competing networks: | ||
|
||
The Generator creates fake quantum data samples from a latent space. | ||
The Discriminator attempts to distinguish these fake samples from real quantum data. | ||
Through training, the generator improves its ability to create realistic quantum data, while the discriminator enhances its ability to identify fake data. This process results in a generator that can produce high-quality quantum data samples. | ||
|
||
|
||
## QGAN Implementation for CIFAR-10 Dataset | ||
This implementation trains a QGAN on the CIFAR-10 dataset to generate fake images. It follows a similar structure to the TorchQuantum QGAN, with the addition of data loading and processing specific to the CIFAR-10 dataset. | ||
Generated images can be seen in the folder | ||
|
||
This `README.md` file explains the purpose of the repository, the structure of the notebook, and how to run the examples, along with a brief overview of the QGAN concept for those unfamiliar with it. | ||
|
||
|
||
## Reference | ||
- [ ] https://arxiv.org/abs/2312.09939 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
import argparse | ||
import torch | ||
import torch.nn as nn | ||
import torch.optim as optim | ||
import torchquantum as tq | ||
|
||
class Generator(nn.Module): | ||
def __init__(self, n_qubits: int, latent_dim: int): | ||
super().__init__() | ||
self.n_qubits = n_qubits | ||
self.latent_dim = latent_dim | ||
|
||
# Quantum encoder | ||
self.encoder = tq.GeneralEncoder([ | ||
{'input_idx': [i], 'func': 'rx', 'wires': [i]} | ||
for i in range(self.n_qubits) | ||
]) | ||
|
||
# RX gates | ||
self.rxs = nn.ModuleList([ | ||
tq.RX(has_params=True, trainable=True) for _ in range(self.n_qubits) | ||
]) | ||
|
||
def forward(self, x): | ||
qdev = tq.QuantumDevice(n_wires=self.n_qubits, bsz=x.shape[0], device=x.device) | ||
self.encoder(qdev, x) | ||
|
||
for i in range(self.n_qubits): | ||
self.rxs[i](qdev, wires=i) | ||
|
||
return tq.measure(qdev) | ||
|
||
class Discriminator(nn.Module): | ||
def __init__(self, n_qubits: int): | ||
super().__init__() | ||
self.n_qubits = n_qubits | ||
|
||
# Quantum encoder | ||
self.encoder = tq.GeneralEncoder([ | ||
{'input_idx': [i], 'func': 'rx', 'wires': [i]} | ||
for i in range(self.n_qubits) | ||
]) | ||
|
||
# RX gates | ||
self.rxs = nn.ModuleList([ | ||
tq.RX(has_params=True, trainable=True) for _ in range(self.n_qubits) | ||
]) | ||
|
||
# Quantum measurement | ||
self.measure = tq.MeasureAll(tq.PauliZ) | ||
|
||
def forward(self, x): | ||
qdev = tq.QuantumDevice(n_wires=self.n_qubits, bsz=x.shape[0], device=x.device) | ||
self.encoder(qdev, x) | ||
|
||
for i in range(self.n_qubits): | ||
self.rxs[i](qdev, wires=i) | ||
|
||
return self.measure(qdev) | ||
|
||
class QGAN(nn.Module): | ||
def __init__(self, n_qubits: int, latent_dim: int): | ||
super().__init__() | ||
self.generator = Generator(n_qubits, latent_dim) | ||
self.discriminator = Discriminator(n_qubits) | ||
|
||
def forward(self, z): | ||
fake_data = self.generator(z) | ||
fake_output = self.discriminator(fake_data) | ||
return fake_output | ||
|
||
def main(n_qubits, latent_dim): | ||
model = QGAN(n_qubits, latent_dim) | ||
print(model) | ||
|
||
if __name__ == "__main__": | ||
parser = argparse.ArgumentParser(description="Quantum Generative Adversarial Network (QGAN) Example") | ||
parser.add_argument('n_qubits', type=int, help='Number of qubits') | ||
parser.add_argument('latent_dim', type=int, help='Dimension of the latent space') | ||
|
||
args = parser.parse_args() | ||
|
||
main(args.n_qubits, args.latent_dim) | ||
|
Oops, something went wrong.