Skip to content

Commit

Permalink
Consider cases when ExpressionConvertible returns float, int, duration (
Browse files Browse the repository at this point in the history
#89)

* consider case when ExpressionConvertible returns float, int, duration

* use to_ast to create the ast node after _to_oqpy_expression

* allow use of str in OQPyBinaryExpression

* change type hints

* black

---------

Co-authored-by: Phil Reinhold <[email protected]>
  • Loading branch information
jcjaskula-aws and PhilReinhold authored Jun 17, 2024
1 parent 08671da commit f9d6fea
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 14 deletions.
22 changes: 13 additions & 9 deletions oqpy/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,7 @@ def expr_matches(a: Any, b: Any) -> bool:
class ExpressionConvertible(Protocol):
"""This is the protocol an object can implement in order to be usable as an expression."""

def _to_oqpy_expression(self) -> HasToAst: ... # pragma: no cover
def _to_oqpy_expression(self) -> AstConvertible: ... # pragma: no cover


@runtime_checkable
Expand Down Expand Up @@ -379,12 +379,17 @@ class OQPyBinaryExpression(OQPyExpression):

def __init__(
self,
op: ast.BinaryOperator,
op: ast.BinaryOperator | str,
lhs: AstConvertible,
rhs: AstConvertible,
ast_type: ast.ClassicalType | None = None,
):
super().__init__()
if isinstance(op, str):
try:
op = ast.BinaryOperator[op]
except KeyError as e:
raise ValueError(f"Invalid binary operator {op}") from e
self.op = op
self.lhs = lhs
self.rhs = rhs
Expand All @@ -396,7 +401,9 @@ def __init__(
elif isinstance(rhs, OQPyExpression):
ast_type = rhs.type
else:
raise TypeError("Neither lhs nor rhs is an expression?")
raise TypeError(
"Cannot infer ast_type from lhs or rhs. Please provide it if possible."
)
self.type = ast_type

# Adding floats to durations is not allowed. So we promote types as necessary.
Expand Down Expand Up @@ -468,17 +475,14 @@ def to_ast(self, program: Program) -> ast.Expression:
def to_ast(program: Program, item: AstConvertible) -> ast.Expression:
"""Convert an object to an AST node."""
if hasattr(item, "_to_oqpy_expression"):
item = cast(ExpressionConvertible, item)
return item._to_oqpy_expression().to_ast(program)
item = cast(ExpressionConvertible, item)._to_oqpy_expression()
if hasattr(item, "_to_cached_oqpy_expression"):
item = cast(CachedExpressionConvertible, item)
if item._oqpy_cache_key is None:
item._oqpy_cache_key = uuid.uuid1()
if item._oqpy_cache_key not in program.expr_cache:
program.expr_cache[item._oqpy_cache_key] = item._to_cached_oqpy_expression().to_ast(
program
)
return program.expr_cache[item._oqpy_cache_key]
program.expr_cache[item._oqpy_cache_key] = item._to_cached_oqpy_expression()
item = program.expr_cache[item._oqpy_cache_key]
if isinstance(item, (complex, np.complexfloating)):
if item.imag == 0:
return to_ast(program, item.real)
Expand Down
8 changes: 4 additions & 4 deletions oqpy/timing.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,16 +68,16 @@ def convert_float_to_duration(time: AstConvertible, require_nonnegative: bool =
require_nonnegative: if True, raise an exception if the time value is known to
be negative.
"""
if isinstance(time, (float, int)):
if require_nonnegative and time < 0:
raise ValueError(f"Expected a non-negative duration, but got {time}")
return OQDurationLiteral(time)
if hasattr(time, "_to_oqpy_expression"):
time = cast(ExpressionConvertible, time)
time = time._to_oqpy_expression()
if hasattr(time, "_to_cached_oqpy_expression"):
time = cast(CachedExpressionConvertible, time)
time = time._to_cached_oqpy_expression()
if isinstance(time, (float, int)):
if require_nonnegative and time < 0:
raise ValueError(f"Expected a non-negative duration, but got {time}")
return OQDurationLiteral(time)
if isinstance(time, OQPyExpression):
if isinstance(time.type, (ast.UintType, ast.IntType, ast.FloatType)):
time = time * OQDurationLiteral(1)
Expand Down
19 changes: 18 additions & 1 deletion tests/test_directives.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@

import oqpy
from oqpy import *
from oqpy.base import OQPyExpression, expr_matches, logical_and, logical_or
from oqpy.base import OQPyBinaryExpression, OQPyExpression, expr_matches, logical_and, logical_or
from oqpy.classical_types import OQIndexExpression
from oqpy.quantum_types import PhysicalQubits
from oqpy.timing import OQDurationLiteral
Expand Down Expand Up @@ -421,6 +421,7 @@ def test_binary_expressions():
prog.set(d, 5e-9 - d)
prog.set(d, d + convert_float_to_duration(10e-9))
prog.set(f, d / convert_float_to_duration(1))
prog.set(k, OQPyBinaryExpression("+", 2, k))

with pytest.raises(ValueError):
prog.set(f, "a" * i)
Expand All @@ -436,6 +437,8 @@ def test_binary_expressions():
prog.set(d, 5j / d)
with pytest.raises(TypeError):
prog.set(d, 5j * d)
with pytest.raises(ValueError):
OQPyBinaryExpression(".", d, d)

expected = textwrap.dedent(
"""
Expand Down Expand Up @@ -479,6 +482,7 @@ def test_binary_expressions():
d = 5.0ns - d;
d = d + 10.0ns;
f = d / 1s;
k = 2 + k;
"""
).strip()

Expand Down Expand Up @@ -1583,21 +1587,34 @@ class B:
def _to_oqpy_expression(self):
return FloatVar(1e-7, self.name)

@dataclass
class C:
def _to_oqpy_expression(self):
return 1e-7

def __rmul__(self, other):
return other * self._to_oqpy_expression()

frame = FrameVar(name="f1")
prog = Program()
prog.set(A("a1"), 2)
prog.set(FloatVar(name="c1"), 3 * C())
prog.delay(A("a2"), frame)
prog.delay(B("b1"), frame)
prog.delay(C(), frame)
expected = textwrap.dedent(
"""
OPENQASM 3.0;
duration a1 = 100.0ns;
float[64] c1;
duration a2 = 100.0ns;
frame f1;
float[64] b1 = 1e-07;
a1 = 2;
c1 = 3e-07;
delay[a2] f1;
delay[b1 * 1s] f1;
delay[100.0ns] f1;
"""
).strip()
assert prog.to_qasm() == expected
Expand Down

0 comments on commit f9d6fea

Please sign in to comment.