Skip to content

Commit

Permalink
allow custom types and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
dakk committed Oct 17, 2023
1 parent cc0936e commit 0975213
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 10 deletions.
3 changes: 2 additions & 1 deletion qlasskit/ast2logic/t_ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,12 @@
from .typing import Args, LogicFun


def translate_ast(fun) -> LogicFun:
def translate_ast(fun, types) -> LogicFun:
fun_name: str = fun.name

# env contains names visible from the current scope
env = Env()
[env.bind_type((t.__name__, t)) for t in types]

args: Args = translate_arguments(fun.args.args, env)

Expand Down
2 changes: 1 addition & 1 deletion qlasskit/qlassf.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def from_function(
fun_ast = ast.parse(f if isinstance(f, str) else inspect.getsource(f))
fun = fun_ast.body[0]

fun_name, args, fun_ret, exps = translate_ast(fun)
fun_name, args, fun_ret, exps = translate_ast(fun, types)
original_f = eval(fun_name) if isinstance(f, str) else f

qf = QlassF(fun_name, original_f, args, fun_ret, exps)
Expand Down
17 changes: 16 additions & 1 deletion test/test_qlassf.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import unittest

from qlasskit import Qint4, Qint12, QlassF, qlassf
from qlasskit import Qint, Qint4, Qint12, QlassF, exceptions, qlassf

from . import utils
from .utils import COMPILATION_ENABLED, compute_and_compare_results
Expand All @@ -34,6 +34,21 @@ def test_decorator(self):
self.assertTrue(isinstance(c, QlassF))


class TestQlassfCustomTypes(unittest.TestCase):
def test_custom_qint3(self):
qf = qlassf(
utils.test_qint3, types=[utils.Qint3], to_compile=COMPILATION_ENABLED
)
compute_and_compare_results(self, qf)

def test_custom_qint3_notfound(self):
self.assertRaises(
exceptions.UnknownTypeException,
lambda f: qlassf(f, types=[], to_compile=COMPILATION_ENABLED),
utils.test_qint3,
)


class TestQlassfTruthTable(unittest.TestCase):
def test_not_truth(self):
f = "def test(a: bool) -> bool:\n\treturn not a"
Expand Down
11 changes: 4 additions & 7 deletions test/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,17 +29,14 @@ def test_not(a: bool) -> bool:
return not a


# def get_qlassf_input_bits(qf: QlassF) -> int:
# pass
class Qint3(Qint):
BIT_SIZE = 3


# def get_input_combinations(n_bits: int) -> List[List[bool]]:
# pass
def test_qint3(a: Qint3) -> bool:
return not a[0]


# def compute_originalf_results(qf: QlassF) -> List[List[bool]]:
# pass

aer_simulator = Aer.get_backend("aer_simulator")


Expand Down

0 comments on commit 0975213

Please sign in to comment.