Skip to content

Commit

Permalink
int and float builting cast (#25)
Browse files Browse the repository at this point in the history
  • Loading branch information
dakk authored Mar 20, 2024
1 parent dfdb310 commit 82dbf56
Show file tree
Hide file tree
Showing 6 changed files with 102 additions and 20 deletions.
2 changes: 2 additions & 0 deletions docs/source/supported.rst
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,8 @@ Bultin functions:
- `any(Tuple)`, `any(Qlist)`: returns True if any of the elemnts are True
- `ord(Qchar)`: returns the integer value of the given Qchar
- `chr(Qint)`: returns the char given its ascii code
- `int(Qfixed | Qint)`: returns the integer part of a Qfixed
- `float(Qint | Qfixed)`: returns a Qfixed representing the Qint


Statements
Expand Down
31 changes: 21 additions & 10 deletions qlasskit/ast2ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,20 @@ def _replace_types_annotations(ann, arg=None):
slice=_ituple,
)

# Replace QintX with Qint[X]
if isinstance(ann, ast.Name) and ann.id[:4] == "Qint":
ann = ast.Subscript(
value=ast.Name(id="Qint", ctx=ast.Load()),
slice=ast.Constant(value=int(ann.id[4:])),
)

# Replace QfixedX with Qfixed[X]
if isinstance(ann, ast.Name) and ann.id[:6] == "Qfixed":
ann = ast.Subscript(
value=ast.Name(id="Qfixed", ctx=ast.Load()),
slice=ast.Constant(value=int(ann.id[6:])),
)

# Replace Qlist[T,n] with Tuple[(T,)*n]
if isinstance(ann, ast.Subscript) and ann.value.id == "Qlist":
_elts = ann.slice.elts
Expand Down Expand Up @@ -471,11 +485,11 @@ def iterif(arg_l):
return iterif(args)

def __call_sum(self, node):
if len(node.args) == 1:
args = self.__unroll_arg(node.args[0])
else:
if len(node.args) != 1:
raise Exception(f"sum() takes at most 1 argument ({len(node.args)} given)")

args = self.__unroll_arg(node.args[0])

def iterif(arg_l):
if len(arg_l) == 1:
return arg_l[0]
Expand All @@ -485,25 +499,22 @@ def iterif(arg_l):
return iterif(args)

def __call_anyall(self, node):
if len(node.args) == 1:
args = self.__unroll_arg(node.args[0])
else:
if len(node.args) != 1:
raise Exception(f"any() takes exactly 1 argument ({len(node.args)} given)")

args = self.__unroll_arg(node.args[0])
op = ast.Or() if node.func.id == "any" else ast.And()
return ast.BoolOp(op=op, values=args)

def __call_chr(self, node):
if len(node.args) != 1:
raise Exception(f"chr() takes exactly 1 argument ({len(node.args)} given)")
args = self.__unroll_arg(node.args[0])
return args[0]
return node.args[0]

def __call_ord(self, node):
if len(node.args) != 1:
raise Exception(f"ord() takes exactly 1 argument ({len(node.args)} given)")
args = self.__unroll_arg(node.args[0])
return args[0]
return node.args[0]

def visit_Call(self, node):
node.args = [self.visit(ar) for ar in node.args]
Expand Down
52 changes: 44 additions & 8 deletions qlasskit/ast2logic/t_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from sympy.logic import ITE, And, Not, Or, Xor, false, true

from ..boolquant import QuantumBooleanGate
from ..types import Qbool, Qtype, TExp, const_to_qtype
from ..types import Qbool, Qfixed, Qint, Qtype, TExp, const_to_qtype
from . import Env, exceptions


Expand Down Expand Up @@ -313,19 +313,55 @@ def unfold(v_exps, op):
else:
return args[0][0], q_gate(*args_v)

if not hasattr(expr.func, "id"):
elif not hasattr(expr.func, "id"):
raise exceptions.ExpressionNotHandledException(expr)

# Typecast
if (
env.know_type(expr.func.id)
and len(expr.args) == 1
and isinstance(expr.args[0], ast.Constant)
):
elif env.know_type(expr.func.id):
if len(expr.args) != 1:
raise Exception(
f"{expr.func.id}() takes exactly 1 argument ({len(expr.args)} given)"
)

if not isinstance(expr.args[0], ast.Constant):
raise Exception(f"{expr.func.id}() accepts only constant values")

return env.gettype(expr.func.id).const(expr.args[0].value)

# int()
elif expr.func.id == "int":
if len(expr.args) != 1:
raise Exception(
f"int() takes exactly 1 argument ({len(expr.args)} given)"
)

(ta, te) = translate_expression(expr.args[0], env)
if ta.__name__[:4] == "Qint": # type: ignore
return (ta, te)
elif ta.__name__[:6] == "Qfixed": # type: ignore
ip = ta.integer_part((ta, te)) # type: ignore
return (Qint.type_for_size(len(ip)), ip)
else:
raise Exception(f"int() accepts only Qfixed and Qint: {ta} given")

# float()
elif expr.func.id == "float":
if len(expr.args) != 1:
raise Exception(
f"float() takes exactly 1 argument ({len(expr.args)} given)"
)

(ta, te) = translate_expression(expr.args[0], env)
if ta.__name__[:6] == "Qfixed": # type: ignore
return (ta, te)
elif ta.__name__[:4] == "Qint": # type: ignore
tf = Qfixed.type_for_size(len(te))
return tf.fill((tf, te))
else:
raise Exception(f"float() accepts only Qfixed and Qint: {ta} given")

# Known function
if env.know_function(expr.func.id):
elif env.know_function(expr.func.id):
def_f = env.getdef(expr.func.id)
args = [translate_expression(e, env) for e in expr.args]

Expand Down
6 changes: 5 additions & 1 deletion qlasskit/types/qfixed.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,4 +361,8 @@ def __getitem__(cls, params):


class Qfixed(metaclass=QfixedMeta):
pass
@staticmethod
def type_for_size(s: int):
for det_type in QFIXED_TYPES:
if det_type.BIT_SIZE_INTEGER == s:
return det_type
6 changes: 5 additions & 1 deletion qlasskit/types/qint.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,4 +296,8 @@ def __getitem__(cls, params):


class Qint(metaclass=QintMeta):
pass
@staticmethod
def type_for_size(s: int):
for det_type in QINT_TYPES:
if det_type.BIT_SIZE == s:
return det_type
25 changes: 25 additions & 0 deletions test/qlassf/test_builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,3 +144,28 @@ def test_max_in_list(self):
# f = "def test(a: Qlist[bool, 3]) -> Qint[4]:\n\treturn range(len(a))"
# qf = qlassf(f, to_compile=COMPILATION_ENABLED, compiler=self.compiler)
# compute_and_compare_results(self, qf)

def test_int_of_int(self):
f = "def test(a: Qint[2]) -> Qint[4]:\n\treturn int(a) * 2"
qf = qlassf(f, to_compile=COMPILATION_ENABLED, compiler=self.compiler)
compute_and_compare_results(self, qf)

def test_int_of_int2(self):
f = "def test(a: Qint2) -> Qint[4]:\n\treturn int(a) * 2"
qf = qlassf(f, to_compile=COMPILATION_ENABLED, compiler=self.compiler)
compute_and_compare_results(self, qf)

def test_int_of_fixed(self):
f = "def test(a: Qfixed[2,2]) -> Qint[4]:\n\treturn int(a) * 2"
qf = qlassf(f, to_compile=COMPILATION_ENABLED, compiler=self.compiler)
compute_and_compare_results(self, qf)

def test_float_of_fixed(self):
f = "def test(a: Qfixed[2,2]) -> Qfixed[2,2]:\n\treturn float(a)"
qf = qlassf(f, to_compile=COMPILATION_ENABLED, compiler=self.compiler)
compute_and_compare_results(self, qf)

def test_float_of_int(self):
f = "def test(a: Qint[2]) -> Qfixed[2,2]:\n\treturn float(a)"
qf = qlassf(f, to_compile=COMPILATION_ENABLED, compiler=self.compiler)
compute_and_compare_results(self, qf)

0 comments on commit 82dbf56

Please sign in to comment.