Skip to content

Commit

Permalink
lshift,rshift,bitwise not, fix return cast
Browse files Browse the repository at this point in the history
  • Loading branch information
dakk committed Oct 19, 2023
1 parent ddb885a commit adec21c
Show file tree
Hide file tree
Showing 8 changed files with 71 additions and 16 deletions.
8 changes: 4 additions & 4 deletions TODO.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,16 +51,16 @@
- [x] Builtin functions: max(), min(), len()
- [x] Function call (to builtin)
- [x] Int arithmetic: +
- [ ] Int: bitwise not
- [ ] Int: shift right / left
- [x] Qtype: bitwise not
- [x] Qtype: shift right / left
- [ ] Int: subtraction
- [ ] Publish doc
- [ ] Publish doc on github

## Month 2:

### Week 1: (23 Oct 23)
- [ ] Inner function
- [ ] Int arithmetic expressions (-, *, /)
- [ ] Int arithmetic expressions (*, /)
- [ ] Parametrized qlassf

### Week 2: (30 Oct 23)
Expand Down
2 changes: 1 addition & 1 deletion qlasskit/ast2logic/t_ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def translate_ast(fun, types) -> LogicFun:

exps = []
for stmt in fun.body:
s_exps, env = translate_statement(stmt, env)
s_exps, env = translate_statement(stmt, env, ret_.ttype)
exps.append(s_exps)

exps_flat = flatten(exps)
Expand Down
22 changes: 18 additions & 4 deletions qlasskit/ast2logic/t_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,13 +110,15 @@ def unfold(v_exps, op):

# Unary: not
elif isinstance(expr, ast.UnaryOp):
if isinstance(expr.op, ast.Not):
texp, exp = translate_expression(expr.operand, env)
texp, exp = translate_expression(expr.operand, env)

if isinstance(expr.op, ast.Not):
if texp != bool:
raise exceptions.TypeErrorException(texp, bool)

return (bool, Not(exp))

elif isinstance(expr.op, ast.Invert) and hasattr(texp, "bitwise_not"):
return texp.bitwise_not((texp, exp))
else:
raise exceptions.ExpressionNotHandledException(expr)

Expand Down Expand Up @@ -203,14 +205,26 @@ def unfold(v_exps, op):

# Binop
elif isinstance(expr, ast.BinOp):
# Add | Sub | Mult | MatMult | Div | Mod | Pow | LShift | RShift
# Sub | Mult | MatMult | Div | Mod | Pow |
# | BitOr | BitXor | BitAnd | FloorDiv
# print(ast.dump(expr))
tleft = translate_expression(expr.left, env)
tright = translate_expression(expr.right, env)

if isinstance(expr.op, ast.Add) and hasattr(tleft[0], "add"):
return tleft[0].add(tleft, tright)
elif (
isinstance(expr.op, ast.LShift)
and hasattr(tleft[0], "shift_left")
and isinstance(expr.right, ast.Constant)
):
return tleft[0].shift_left(tleft, expr.right.value)
elif (
isinstance(expr.op, ast.RShift)
and hasattr(tleft[0], "shift_right")
and isinstance(expr.right, ast.Constant)
):
return tleft[0].shift_right(tleft, expr.right.value)
else:
raise exceptions.ExpressionNotHandledException(expr)

Expand Down
13 changes: 12 additions & 1 deletion qlasskit/ast2logic/t_statement.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,12 @@
from sympy import Symbol
from sympy.logic.boolalg import Boolean

from ..types import TType
from . import Binding, Env, decompose_to_symbols, exceptions, translate_expression


def translate_statement( # noqa: C901
stmt, env: Env
stmt, env: Env, ret_type: TType
) -> Tuple[List[Tuple[str, Boolean]], Env]:
"""Parse a statement"""
# match stmt:
Expand Down Expand Up @@ -65,6 +66,16 @@ def translate_statement( # noqa: C901

elif isinstance(stmt, ast.Return):
texp, vexp = translate_expression(stmt.value, env) # TODO: typecheck

if (
hasattr(texp, "BIT_SIZE")
and hasattr(ret_type, "BIT_SIZE")
and texp.BIT_SIZE < ret_type.BIT_SIZE
):
texp, vexp = ret_type.fill((texp, vexp)) # type: ignore
elif texp != ret_type:
raise exceptions.TypeErrorException(texp, ret_type)

res = decompose_to_symbols(vexp, "_ret")
env.bind(Binding("_ret", texp, [x[0] for x in res]))
res = list(map(lambda x: (Symbol(x[0]), x[1]), res))
Expand Down
17 changes: 16 additions & 1 deletion qlasskit/types/qtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from typing import Any, List, Literal, Tuple

from sympy.logic.boolalg import Boolean
from sympy.logic.boolalg import Boolean, Not
from typing_extensions import TypeAlias

TType: TypeAlias = object
Expand Down Expand Up @@ -101,6 +101,21 @@ def lte(tleft: TExp, tcomp: TExp) -> TExp:

# Operations

@staticmethod
def bitwise_not(v: TExp) -> TExp:
"""Apply a bitwise not"""
return (v[0], list(map(Not, v[1])))

@staticmethod
def shift_right(v: TExp, i: int = 1) -> TExp:
"""Apply a shift right"""
return (v[0], v[1][i:])

@staticmethod
def shift_left(v: TExp, i: int = 1) -> TExp:
"""Apply a shift left"""
return (v[0], [False] * i + v[1])

@staticmethod
def add(tleft: TExp, tcomp: TExp) -> TExp:
raise Exception("abstract")
2 changes: 1 addition & 1 deletion test/test_ast2logic_t_stmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,6 @@ def test_statement_not_handled(self):

self.assertRaises(
exceptions.StatementNotHandledException,
lambda e: ast2logic.translate_statement(e, {}),
lambda e: ast2logic.translate_statement(e, {}, bool),
e,
)
21 changes: 17 additions & 4 deletions test/test_qlassf_int.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,10 +287,23 @@ def test_composed_comparators(self):
qf = qlassf(f, to_compile=COMPILATION_ENABLED)
compute_and_compare_results(self, qf)

# def test(a: Qint2) -> Qint2:
# return a + 1
# def test(a: Qint2, b: Qint2) -> Qint2:
# return a + b
def test_shift_left(self):
f = "def test(n: Qint2) -> Qint4: return n << 1"
qf = qlassf(f, to_compile=COMPILATION_ENABLED)
print(qf.expressions)
compute_and_compare_results(self, qf)

def test_shift_right(self):
f = "def test(n: Qint2) -> Qint4: return n >> 1"
qf = qlassf(f, to_compile=COMPILATION_ENABLED)
print(qf.expressions)
compute_and_compare_results(self, qf)

# Our Qint are unsigned
# def test_invert_bitwise_not(self):
# f = "def test(n: Qint4) -> bool: return ~n"
# qf = qlassf(f, to_compile=COMPILATION_ENABLED)
# compute_and_compare_results(self, qf)


# TODO: parameterize
Expand Down
2 changes: 2 additions & 0 deletions test/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,8 @@ def res_to_str(res):

res_original = qf.original_f(*args)
res_original_str = res_to_str(res_original)
# print(args, res_original, res_original_str, truth_line)
# print (qf.expressions)

cls.assertEqual(len(res_original_str), qf.ret_size)
return res_original_str
Expand Down

0 comments on commit adec21c

Please sign in to comment.