Skip to content

Commit

Permalink
* Unpack arguments based on the caller's execution context
Browse files Browse the repository at this point in the history
* Make tests stable by enforcing garbage collection as part of set up
  • Loading branch information
khalatepradnya committed Sep 23, 2024
1 parent c61a21a commit c13b934
Show file tree
Hide file tree
Showing 6 changed files with 75 additions and 16 deletions.
12 changes: 10 additions & 2 deletions python/cudaq/handlers/photonics_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@

from ..mlir._mlir_libs._quakeDialects import cudaq_runtime

_TARGET_NAME = 'photonics'

# The qudit level must be explicitly defined
globalQuditLevel = None

Expand All @@ -33,7 +35,13 @@ class PyQudit:
id: int

def __del__(self):
cudaq_runtime.photonics.release_qudit(self.level, self.id)
try:
cudaq_runtime.photonics.release_qudit(self.level, self.id)
except Exception as e:
if _TARGET_NAME == cudaq_runtime.get_target().name:
raise e
else:
pass


def _is_qudit_type(q: any) -> bool:
Expand Down Expand Up @@ -194,7 +202,7 @@ class PhotonicsHandler(object):

def __init__(self, function):

if 'photonics' != cudaq_runtime.get_target().name:
if _TARGET_NAME != cudaq_runtime.get_target().name:
raise RuntimeError(
"A photonics kernel can only be used with 'photonics' target.")

Expand Down
12 changes: 11 additions & 1 deletion python/cudaq/kernel/kernel_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,17 @@ def __call__(self, *args):
raise RuntimeError(
"The 'photonics' target must be used with a valid function."
)
PhotonicsHandler(self.kernelFunction)(*args)
# NOTE: Since this handler does not support MLIR mode (yet), just
# invoke the kernel. If calling from a bound function, need to
# unpack the arguments, for example, see `pyGetStateLibraryMode`
try:
context_name = cudaq_runtime.getExecutionContextName()
except RuntimeError:
context_name = None
callable_args = args
if "extract-state" == context_name and len(args) == 1:
callable_args = args[0]
PhotonicsHandler(self.kernelFunction)(*callable_args)
return

# Prepare captured state storage for the run
Expand Down
4 changes: 4 additions & 0 deletions python/runtime/common/py_ExecutionContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,5 +48,9 @@ void bindExecutionContext(py::module &mod) {
auto &platform = cudaq::get_platform();
return platform.supports_conditional_feedback();
});
mod.def("getExecutionContextName", []() {
auto &self = cudaq::get_platform();
return self.get_exec_ctx()->name;
});
}
} // namespace cudaq
14 changes: 8 additions & 6 deletions python/runtime/cudaq/algorithms/py_state.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -138,15 +138,17 @@ state pyGetStateRemote(py::object kernel, py::args args) {
}

state pyGetStateLibraryMode(py::object kernel, py::args args) {
cudaq::info("Size of arguments = {}", args.size());

/// TODO: Pack / unpack arguments
return details::extractState([&]() mutable {
if (0 == args.size())
cudaq::invokeKernel(std::forward<py::object>(kernel));
else
cudaq::invokeKernel(std::forward<py::object>(kernel),
std::forward<py::args>(args));
else {
std::vector<py::object> argsData;
for (size_t i = 0; i < args.size(); i++) {
py::object arg = args[i];
argsData.emplace_back(std::forward<py::object>(arg));
}
cudaq::invokeKernel(std::forward<py::object>(kernel), argsData);
}
});
}

Expand Down
44 changes: 40 additions & 4 deletions python/tests/handlers/test_photonics_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@
# ============================================================================ #

import pytest

import gc
from typing import List

import cudaq


Expand All @@ -16,6 +20,8 @@ def do_something():
yield
cudaq.reset_target()
cudaq.__clearKernelRegistries()
# Make the tests stable by enforcing resource release
gc.collect()


def test_qudit():
Expand Down Expand Up @@ -83,18 +89,48 @@ def kernel():


def test_kernel_with_args():
"""Test that `PhotonicsHandler` supports basic arguments.
The check here is that all the test kernels run successfully."""

@cudaq.kernel
def kernel(theta: float):
def kernel_1f(theta: float):
q = qudit(4)
plus(q)
phase_shift(q, theta)
mz(q)

result = cudaq.sample(kernel, 0.5)
result = cudaq.sample(kernel_1f, 0.5)
result.dump()

state = cudaq.get_state(kernel_1f, 0.5)
state.dump()

@cudaq.kernel
def kernel_2f(theta: float, phi: float):
quds = [qudit(3) for _ in range(2)]
plus(quds[0])
phase_shift(quds[0], theta)
beam_splitter(quds[0], quds[1], phi)
mz(quds)

result = cudaq.sample(kernel_2f, 0.7854, 0.3927)
result.dump()

state = cudaq.get_state(kernel, 0.5)

state = cudaq.get_state(kernel_2f, 0.7854, 0.3927)
state.dump()

@cudaq.kernel
def kernel_list(angles: List[float]):
quds = [qudit(2) for _ in range(3)]
plus(quds[0])
phase_shift(quds[1], angles[0])
phase_shift(quds[2], angles[1])
mz(quds)

result = cudaq.sample(kernel_list, [0.5236, 1.0472])
result.dump()

state = cudaq.get_state(kernel_list, [0.5236, 1.0472])
state.dump()


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,10 @@ struct PhotonicsState : public cudaq::SimulationState {
getPrecision()};
}

// /// @brief Return all tensors that represent this state
/// @brief Return all tensors that represent this state
std::vector<Tensor> getTensors() const override { return {getTensor()}; }

// /// @brief Return the number of tensors that represent this state.
/// @brief Return the number of tensors that represent this state.
std::size_t getNumTensors() const override { return 1; }

std::complex<double>
Expand All @@ -77,7 +77,6 @@ struct PhotonicsState : public cudaq::SimulationState {
throw std::runtime_error("[photonics] invalid tensor requested.");
if (indices.size() != 1)
throw std::runtime_error("[photonics] invalid element extraction.");

return state[indices[0]];
}

Expand Down

0 comments on commit c13b934

Please sign in to comment.