Skip to content

Commit

Permalink
move type operations to type classes
Browse files Browse the repository at this point in the history
  • Loading branch information
dakk committed Oct 7, 2023
1 parent dfbf3c6 commit 6ac38b0
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 46 deletions.
1 change: 0 additions & 1 deletion qlasskit/ast2logic/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,3 @@
from .t_ast import translate_ast # noqa: F401, E402
from .typing import Qint, Qint2, Qint4, Qint8, Qint12, Qint16, Qtype # noqa: F401
from . import exceptions # noqa: F401, E402

56 changes: 11 additions & 45 deletions qlasskit/ast2logic/t_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,21 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import ast
from typing import List, Tuple, Type, get_args
from typing import List, Tuple, get_args

from sympy import Symbol
from sympy.logic import ITE, And, Not, Or, false, true
from sympy.logic.boolalg import Boolean
from typing_extensions import TypeAlias

from . import Env, exceptions
from .typing import Qint, Qint2, Qint4, Qint8, Qint12, Qint16

TType: TypeAlias = object


def xnor(a, b):
return Or(And(a, b), And(Not(a), Not(b)))
from .typing import Qint, Qint2, Qint4, Qint8, Qint12, Qint16, TType


def type_of_exp(vlist, base, res=[]) -> List[Symbol]:
Expand Down Expand Up @@ -147,27 +140,12 @@ def unfold(v_exps, op):
return (bool, false)
elif isinstance(expr.value, int):
v = expr.value
vb = map(lambda c: True if c == "1" else False, bin(v)[2:])
v_ret: Tuple[Type[Qint], List[bool]]
if v < 2**2:
v_ret = (Qint2, list(vb))
elif v < 2**4:
v_ret = (Qint4, list(vb))
elif v < 2**8:
v_ret = (Qint8, list(vb))
elif v < 2**12:
v_ret = (Qint12, list(vb))
elif v < 2**16:
v_ret = (Qint16, list(vb))
else:
raise Exception("Constant value is too big")

if len(v_ret[1]) < v_ret[0].BIT_SIZE:
v_ret = (
v_ret[0],
[False] * (v_ret[0].BIT_SIZE - len(v_ret[1])) + v_ret[1],
)
return v_ret

for t in [Qint2, Qint4, Qint8, Qint12, Qint16]:
if v < 2**t.BIT_SIZE:
return Qint.fill((t, Qint.const(v)))

raise Exception(f"Constant value is too big: {v}")
else:
raise exceptions.ExpressionNotHandledException(expr)

Expand All @@ -187,24 +165,12 @@ def unfold(v_exps, op):
tleft = translate_expression(expr.left, env)
tcomp = translate_expression(expr.comparators[0], env)

if tleft[0] != tcomp[0] and tleft[0].__bases__ != tcomp[0].__bases__: # type: ignore
raise exceptions.TypeErrorException(tcomp[0], tleft[0])

# Eq
if isinstance(expr.ops[0], ast.Eq):
ex = true
for x in zip(tleft[1], tcomp[1]):
ex = And(ex, xnor(x[0], x[1]))

if len(tleft[1]) > len(tcomp[1]):
for x in tleft[1][len(tcomp[1]) :]:
ex = And(ex, Not(x))
if issubclass(tleft[0], Qint) and issubclass(tcomp[0], Qint): # type: ignore
return Qint.eq(tleft, tcomp)

if len(tleft[1]) < len(tcomp[1]):
for x in tcomp[1][len(tleft[1]) :]:
ex = And(ex, Not(x))

return (bool, ex)
raise exceptions.TypeErrorException(tcomp[0], tleft[0])

# NotEq | Lt | LtE | Gt | GtE | Is | IsNot | In | NotIn
else:
Expand Down
42 changes: 42 additions & 0 deletions qlasskit/ast2logic/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,18 @@
from typing import List, Tuple

from sympy import Symbol
from sympy.logic import And, Not, Or, true # false
from sympy.logic.boolalg import Boolean
from typing_extensions import TypeAlias

# from .ast2logic.t_expression import TType

TType: TypeAlias = object


def xnor(a, b):
return Or(And(a, b), And(Not(a), Not(b)))


class Arg:
def __init__(self, name: str, ttype: object, bitvec: List[str]):
Expand All @@ -44,6 +52,8 @@ def to_exp(self) -> List[Symbol]:


class Qtype:
BIT_SIZE = 8

def __init__(self):
pass

Expand All @@ -61,13 +71,45 @@ def to_bool(self):
# return self.value


# TODO: use generics for bitsize
class Qint(int, Qtype):
BIT_SIZE = 8

def __init__(self, value):
super().__init__()
self.value = value

@staticmethod
def const(v: int) -> List[bool]:
return list(map(lambda c: True if c == "1" else False, bin(v)[2:]))

@staticmethod
def fill(v: Tuple[TType, List[bool]]) -> Tuple[TType, List[bool]]:
if len(v[1]) < v[0].BIT_SIZE: # type: ignore
v = (
v[0],
[False] * (v[0].BIT_SIZE - len(v[1])) + v[1], # type: ignore
)
return v

@staticmethod
def eq(
tleft: Tuple[TType, Boolean], tcomp: Tuple[TType, Boolean]
) -> Tuple[TType, Boolean]:
ex = true
for x in zip(tleft[1], tcomp[1]):
ex = And(ex, xnor(x[0], x[1]))

if len(tleft[1]) > len(tcomp[1]):
for x in tleft[1][len(tcomp[1]) :]:
ex = And(ex, Not(x))

if len(tleft[1]) < len(tcomp[1]):
for x in tcomp[1][len(tleft[1]) :]:
ex = And(ex, Not(x))

return (bool, ex)


class Qint2(Qint):
BIT_SIZE = 2
Expand Down

0 comments on commit 6ac38b0

Please sign in to comment.