From 0975213a277e06f1b6afff6e312418b7fbe04063 Mon Sep 17 00:00:00 2001 From: "Davide Gessa (dakk)" Date: Tue, 17 Oct 2023 16:07:25 +0200 Subject: [PATCH] allow custom types and tests --- qlasskit/ast2logic/t_ast.py | 3 ++- qlasskit/qlassf.py | 2 +- test/test_qlassf.py | 17 ++++++++++++++++- test/utils.py | 11 ++++------- 4 files changed, 23 insertions(+), 10 deletions(-) diff --git a/qlasskit/ast2logic/t_ast.py b/qlasskit/ast2logic/t_ast.py index 5196685a..5ff82268 100644 --- a/qlasskit/ast2logic/t_ast.py +++ b/qlasskit/ast2logic/t_ast.py @@ -24,11 +24,12 @@ from .typing import Args, LogicFun -def translate_ast(fun) -> LogicFun: +def translate_ast(fun, types) -> LogicFun: fun_name: str = fun.name # env contains names visible from the current scope env = Env() + [env.bind_type((t.__name__, t)) for t in types] args: Args = translate_arguments(fun.args.args, env) diff --git a/qlasskit/qlassf.py b/qlasskit/qlassf.py index cb07363f..6d161f35 100644 --- a/qlasskit/qlassf.py +++ b/qlasskit/qlassf.py @@ -162,7 +162,7 @@ def from_function( fun_ast = ast.parse(f if isinstance(f, str) else inspect.getsource(f)) fun = fun_ast.body[0] - fun_name, args, fun_ret, exps = translate_ast(fun) + fun_name, args, fun_ret, exps = translate_ast(fun, types) original_f = eval(fun_name) if isinstance(f, str) else f qf = QlassF(fun_name, original_f, args, fun_ret, exps) diff --git a/test/test_qlassf.py b/test/test_qlassf.py index 6fab4b91..f659330d 100644 --- a/test/test_qlassf.py +++ b/test/test_qlassf.py @@ -14,7 +14,7 @@ import unittest -from qlasskit import Qint4, Qint12, QlassF, qlassf +from qlasskit import Qint, Qint4, Qint12, QlassF, exceptions, qlassf from . import utils from .utils import COMPILATION_ENABLED, compute_and_compare_results @@ -34,6 +34,21 @@ def test_decorator(self): self.assertTrue(isinstance(c, QlassF)) +class TestQlassfCustomTypes(unittest.TestCase): + def test_custom_qint3(self): + qf = qlassf( + utils.test_qint3, types=[utils.Qint3], to_compile=COMPILATION_ENABLED + ) + compute_and_compare_results(self, qf) + + def test_custom_qint3_notfound(self): + self.assertRaises( + exceptions.UnknownTypeException, + lambda f: qlassf(f, types=[], to_compile=COMPILATION_ENABLED), + utils.test_qint3, + ) + + class TestQlassfTruthTable(unittest.TestCase): def test_not_truth(self): f = "def test(a: bool) -> bool:\n\treturn not a" diff --git a/test/utils.py b/test/utils.py index 6e03d0ed..e9e69c66 100644 --- a/test/utils.py +++ b/test/utils.py @@ -29,17 +29,14 @@ def test_not(a: bool) -> bool: return not a -# def get_qlassf_input_bits(qf: QlassF) -> int: -# pass +class Qint3(Qint): + BIT_SIZE = 3 -# def get_input_combinations(n_bits: int) -> List[List[bool]]: -# pass +def test_qint3(a: Qint3) -> bool: + return not a[0] -# def compute_originalf_results(qf: QlassF) -> List[List[bool]]: -# pass - aer_simulator = Aer.get_backend("aer_simulator")