Skip to content

Commit

Permalink
[Python] [photonics] Enable 'get_state' API (#2201)
Browse files Browse the repository at this point in the history
  • Loading branch information
khalatepradnya authored Sep 25, 2024
1 parent c86eb8a commit d8f23aa
Show file tree
Hide file tree
Showing 8 changed files with 127 additions and 14 deletions.
3 changes: 3 additions & 0 deletions docs/sphinx/examples/python/providers/photonics.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,6 @@ def photonicsKernel():

counts = cudaq.sample(photonicsKernel)
print(counts)

state = cudaq.get_state(photonicsKernel)
print(state)
3 changes: 3 additions & 0 deletions docs/sphinx/examples/python/providers/photonics_tbi.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,3 +43,6 @@ def TBI(
loop_lengths,
shots_count=1000000)
counts.dump()

state = cudaq.get_state(TBI, bs_angles, ps_angles, input_state, loop_lengths)
state.dump()
12 changes: 10 additions & 2 deletions python/cudaq/handlers/photonics_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@

from ..mlir._mlir_libs._quakeDialects import cudaq_runtime

_TARGET_NAME = 'photonics'

# The qudit level must be explicitly defined
globalQuditLevel = None

Expand All @@ -33,7 +35,13 @@ class PyQudit:
id: int

def __del__(self):
cudaq_runtime.photonics.release_qudit(self.level, self.id)
try:
cudaq_runtime.photonics.release_qudit(self.level, self.id)
except Exception as e:
if _TARGET_NAME == cudaq_runtime.get_target().name:
raise e
else:
pass


def _is_qudit_type(q: any) -> bool:
Expand Down Expand Up @@ -194,7 +202,7 @@ class PhotonicsHandler(object):

def __init__(self, function):

if 'photonics' != cudaq_runtime.get_target().name:
if _TARGET_NAME != cudaq_runtime.get_target().name:
raise RuntimeError(
"A photonics kernel can only be used with 'photonics' target.")

Expand Down
12 changes: 11 additions & 1 deletion python/cudaq/kernel/kernel_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,17 @@ def __call__(self, *args):
raise RuntimeError(
"The 'photonics' target must be used with a valid function."
)
PhotonicsHandler(self.kernelFunction)(*args)
# NOTE: Since this handler does not support MLIR mode (yet), just
# invoke the kernel. If calling from a bound function, need to
# unpack the arguments, for example, see `pyGetStateLibraryMode`
try:
context_name = cudaq_runtime.getExecutionContextName()
except RuntimeError:
context_name = None
callable_args = args
if "extract-state" == context_name and len(args) == 1:
callable_args = args[0]
PhotonicsHandler(self.kernelFunction)(*callable_args)
return

# Prepare captured state storage for the run
Expand Down
4 changes: 4 additions & 0 deletions python/runtime/common/py_ExecutionContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,5 +48,9 @@ void bindExecutionContext(py::module &mod) {
auto &platform = cudaq::get_platform();
return platform.supports_conditional_feedback();
});
mod.def("getExecutionContextName", []() {
auto &self = cudaq::get_platform();
return self.get_exec_ctx()->name;
});
}
} // namespace cudaq
17 changes: 17 additions & 0 deletions python/runtime/cudaq/algorithms/py_state.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,21 @@ state pyGetStateRemote(py::object kernel, py::args args) {
size, returnOffset));
}

state pyGetStateLibraryMode(py::object kernel, py::args args) {
return details::extractState([&]() mutable {
if (0 == args.size())
cudaq::invokeKernel(std::forward<py::object>(kernel));
else {
std::vector<py::object> argsData;
for (size_t i = 0; i < args.size(); i++) {
py::object arg = args[i];
argsData.emplace_back(std::forward<py::object>(arg));
}
cudaq::invokeKernel(std::forward<py::object>(kernel), argsData);
}
});
}

/// @brief Bind the get_state cudaq function
void bindPyState(py::module &mod, LinkedLibraryHolder &holder) {

Expand Down Expand Up @@ -629,6 +644,8 @@ index pair.
if (holder.getTarget().name == "remote-mqpu" ||
holder.getTarget().name == "nvqc")
return pyGetStateRemote(kernel, args);
if (holder.getTarget().name == "photonics")
return pyGetStateLibraryMode(kernel, args);
return pyGetState(kernel, args);
},
R"#(Return the :class:`State` of the system after execution of the provided `kernel`.
Expand Down
56 changes: 56 additions & 0 deletions python/tests/handlers/test_photonics_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@
# ============================================================================ #

import pytest

import gc
from typing import List

import cudaq


Expand All @@ -16,6 +20,8 @@ def do_something():
yield
cudaq.reset_target()
cudaq.__clearKernelRegistries()
# Make the tests stable by enforcing resource release
gc.collect()


def test_qudit():
Expand All @@ -32,6 +38,10 @@ def kernel():
assert len(counts) == 1
assert '3' in counts

state = cudaq.get_state(kernel)
state.dump()
assert 4 == state.__len__()


def test_qudit_list():

Expand Down Expand Up @@ -78,6 +88,52 @@ def kernel():
counts.dump()


def test_kernel_with_args():
"""Test that `PhotonicsHandler` supports basic arguments.
The check here is that all the test kernels run successfully."""

@cudaq.kernel
def kernel_1f(theta: float):
q = qudit(4)
plus(q)
phase_shift(q, theta)
mz(q)

result = cudaq.sample(kernel_1f, 0.5)
result.dump()

state = cudaq.get_state(kernel_1f, 0.5)
state.dump()

@cudaq.kernel
def kernel_2f(theta: float, phi: float):
quds = [qudit(3) for _ in range(2)]
plus(quds[0])
phase_shift(quds[0], theta)
beam_splitter(quds[0], quds[1], phi)
mz(quds)

result = cudaq.sample(kernel_2f, 0.7854, 0.3927)
result.dump()

state = cudaq.get_state(kernel_2f, 0.7854, 0.3927)
state.dump()

@cudaq.kernel
def kernel_list(angles: List[float]):
quds = [qudit(2) for _ in range(3)]
plus(quds[0])
phase_shift(quds[1], angles[0])
phase_shift(quds[2], angles[1])
mz(quds)

result = cudaq.sample(kernel_list, [0.5236, 1.0472])
result.dump()

state = cudaq.get_state(kernel_list, [0.5236, 1.0472])
state.dump()


def test_target_change():

@cudaq.kernel
Expand Down
34 changes: 23 additions & 11 deletions runtime/cudaq/qis/managers/photonics/PhotonicsExecutionManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,9 @@ struct PhotonicsState : public cudaq::SimulationState {
PhotonicsState(qpp::ket &&data, std::size_t lvl)
: state(std::move(data)), levels(lvl) {}

/// TODO: Rename the API to be generic
std::size_t getNumQubits() const override {
throw "not supported for this photonics simulator";
return (std::log2(state.size()) / std::log2(levels));
}

std::complex<double> overlap(const cudaq::SimulationState &other) override {
Expand All @@ -39,7 +40,11 @@ struct PhotonicsState : public cudaq::SimulationState {

std::complex<double>
getAmplitude(const std::vector<int> &basisState) override {
/// TODO: Check basisState.size() matches qudit count
if (getNumQubits() != basisState.size())
throw std::runtime_error(fmt::format(
"[photonics] getAmplitude with an invalid number of bits in the "
"basis state: expected {}, provided {}.",
getNumQubits(), basisState.size()));

// Convert the basis state to an index value
const std::size_t idx = std::accumulate(
Expand All @@ -50,27 +55,34 @@ struct PhotonicsState : public cudaq::SimulationState {
}

Tensor getTensor(std::size_t tensorIdx = 0) const override {
throw "not supported for this photonics simulator";
if (tensorIdx != 0)
throw std::runtime_error("[photonics] invalid tensor requested.");
return Tensor{
reinterpret_cast<void *>(
const_cast<std::complex<double> *>(state.data())),
std::vector<std::size_t>{static_cast<std::size_t>(state.size())},
getPrecision()};
}

std::vector<Tensor> getTensors() const override {
throw "not supported for this photonics simulator";
}
/// @brief Return all tensors that represent this state
std::vector<Tensor> getTensors() const override { return {getTensor()}; }

std::size_t getNumTensors() const override {
throw "not supported for this photonics simulator";
}
/// @brief Return the number of tensors that represent this state.
std::size_t getNumTensors() const override { return 1; }

std::complex<double>
operator()(std::size_t tensorIdx,
const std::vector<std::size_t> &indices) override {
throw "not supported for this photonics simulator";
if (tensorIdx != 0)
throw std::runtime_error("[photonics] invalid tensor requested.");
if (indices.size() != 1)
throw std::runtime_error("[photonics] invalid element extraction.");
return state[indices[0]];
}

std::unique_ptr<SimulationState>
createFromSizeAndPtr(std::size_t size, void *ptr, std::size_t) override {
throw "not supported for this photonics simulator";
;
}

void dump(std::ostream &os) const override { os << state << "\n"; }
Expand Down

0 comments on commit d8f23aa

Please sign in to comment.