Skip to content

Commit

Permalink
type abi, and comparison simplification
Browse files Browse the repository at this point in the history
  • Loading branch information
dakk committed Oct 16, 2023
1 parent dc80c2d commit cc0936e
Show file tree
Hide file tree
Showing 8 changed files with 138 additions and 88 deletions.
1 change: 0 additions & 1 deletion qlasskit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
from .ast2logic import exceptions # noqa: F401
from .types import ( # noqa: F401, F403
Qtype,
Qbool,
Qint,
Qint2,
Qint4,
Expand Down
5 changes: 5 additions & 0 deletions qlasskit/ast2logic/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,11 @@
import ast


class OperationNotSupportedException(Exception):
def __init__(self, tt, op):
super().__init__(f"Operation '{op}' not supported by type {tt}")


class TypeErrorException(Exception):
def __init__(self, got, excepted):
super().__init__(f"Got '{got}' excepted '{excepted}'")
Expand Down
84 changes: 26 additions & 58 deletions qlasskit/ast2logic/t_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from sympy import Symbol
from sympy.logic import ITE, And, Not, Or, false, true

from ..types import Qbool, Qint, TExp, const_to_qtype
from ..types import Qbool, Qtype, TExp, const_to_qtype
from . import Env, exceptions


Expand Down Expand Up @@ -167,65 +167,33 @@ def unfold(v_exps, op):
tleft = translate_expression(expr.left, env)
tcomp = translate_expression(expr.comparators[0], env)

# TODO: check comparability here

# Eq
if isinstance(expr.ops[0], ast.Eq):
if tleft[0] == bool and tcomp[0] == bool:
return (bool, Qbool.eq(tleft[1], tcomp[1]))

# TODO: get here method from type class automatically, and use a type comparison table
elif issubclass(tleft[0], Qint) and issubclass(tcomp[0], Qint): # type: ignore
return Qint.eq(tleft, tcomp)

raise exceptions.TypeErrorException(tcomp[0], tleft[0])

# NotEq
elif isinstance(expr.ops[0], ast.NotEq):
if tleft[0] == bool and tcomp[0] == bool:
return (bool, Qbool.neq(tleft[1], tcomp[1]))

# TODO: get here method from type class automatically, and use a type comparison table
elif issubclass(tleft[0], Qint) and issubclass(tcomp[0], Qint): # type: ignore
return Qint.neq(tleft, tcomp)

raise exceptions.TypeErrorException(tcomp[0], tleft[0])

# Lt
elif isinstance(expr.ops[0], ast.Lt):
# TODO: get here method from type class automatically, and use a type comparison table
if issubclass(tleft[0], Qint) and issubclass(tcomp[0], Qint): # type: ignore
return Qint.lt(tleft, tcomp)

raise exceptions.TypeErrorException(tcomp[0], tleft[0])

# LtE
elif isinstance(expr.ops[0], ast.LtE):
# TODO: get here method from type class automatically, and use a type comparison table
if issubclass(tleft[0], Qint) and issubclass(tcomp[0], Qint): # type: ignore
return Qint.lte(tleft, tcomp)

raise exceptions.TypeErrorException(tcomp[0], tleft[0])

# Gt
elif isinstance(expr.ops[0], ast.Gt):
# TODO: get here method from type class automatically, and use a type comparison table
if issubclass(tleft[0], Qint) and issubclass(tcomp[0], Qint): # type: ignore
return Qint.gt(tleft, tcomp)

raise exceptions.TypeErrorException(tcomp[0], tleft[0])

# GtE
elif isinstance(expr.ops[0], ast.GtE):
# TODO: get here method from type class automatically, and use a type comparison table
if issubclass(tleft[0], Qint) and issubclass(tcomp[0], Qint): # type: ignore
return Qint.gte(tleft, tcomp)

raise exceptions.TypeErrorException(tcomp[0], tleft[0])
# Check comparability
if tleft[0] == bool and tcomp[0] == bool:
op_type = Qbool
elif issubclass(tleft[0], Qtype) and issubclass(tcomp[0], Qtype): # type: ignore
if not tleft[0].comparable(tcomp[0]): # type: ignore
raise exceptions.TypeErrorException(tcomp[0], tleft[0])
op_type = tleft[0] # type: ignore

# Call the comparator
comparators = [
(ast.Eq, "eq"),
(ast.NotEq, "neq"),
(ast.Lt, "lt"),
(ast.LtE, "lte"),
(ast.Gt, "gt"),
(ast.GtE, "gte"),
]

for ast_comp, comp_name in comparators:
if isinstance(expr.ops[0], ast_comp):
if not hasattr(op_type, comp_name):
raise exceptions.OperationNotSupportedException(op_type, comp_name)

return getattr(op_type, comp_name)(tleft, tcomp)

# Is | IsNot | In | NotIn
else:
raise exceptions.ExpressionNotHandledException(expr)
raise exceptions.ExpressionNotHandledException(expr)

# Binop
elif isinstance(expr, ast.BinOp):
Expand Down
18 changes: 15 additions & 3 deletions qlasskit/types/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,24 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# isort:skip_file

from typing import Any

from .qbool import Qbool # noqa: F401
from .qint import Qint, Qint2, Qint4, Qint8, Qint12, Qint16 # noqa: F401
from .qtype import Qtype, TExp, TType # noqa: F401
from sympy.logic import Not, Xor


def _neq(a, b):
return Xor(a, b)


def _eq(a, b):
return Not(_neq(a, b))


from .qtype import Qtype, TExp, TType # noqa: F401, E402
from .qbool import Qbool # noqa: F401, E402
from .qint import Qint, Qint2, Qint4, Qint8, Qint12, Qint16 # noqa: F401, E402

BUILTIN_TYPES = [Qint2, Qint4, Qint8, Qint12, Qint16]

Expand Down
11 changes: 5 additions & 6 deletions qlasskit/types/qbool.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from sympy.logic import Not, Xor
from . import TExp, _eq, _neq


class Qbool:
@staticmethod
def neq(a, b):
return Xor(a, b)
def eq(tleft: TExp, tcomp: TExp) -> TExp:
return (tleft[0], _eq(tleft[1], tcomp[1]))

@staticmethod
def eq(a, b):
return Not(Qbool.neq(a, b))
def neq(tleft: TExp, tcomp: TExp) -> TExp:
return (tleft[0], _neq(tleft[1], tcomp[1]))
38 changes: 21 additions & 17 deletions qlasskit/types/qint.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Tuple
from typing import List

from sympy import Symbol
from sympy.logic import And, Not, Or, false, true

from .qbool import Qbool
from .qtype import Qtype, TExp, TType
from . import _eq, _neq
from .qtype import Qtype, TExp


class Qint(int, Qtype):
Expand All @@ -28,27 +28,29 @@ def __init__(self, value):
super().__init__()
self.value = value

def __getitem__(self, i):
if i > self.BIT_SIZE:
raise Exception("Unbound")

return self.to_bool_str()[i] == "1"

@classmethod
def from_bool(cls, v: List[bool]):
return cls(int("".join(map(lambda x: "1" if x else "0", v[::-1])), 2))

def to_bool_str(self) -> str:
def to_bin(self) -> str:
s = bin(self.value)[2:][0 : self.BIT_SIZE]
return ("0" * (self.BIT_SIZE - len(s)) + s)[::-1]

@staticmethod
def const(v: int) -> List[bool]:
@classmethod
def comparable(cls, other_type=None) -> bool:
"""Return true if the type is comparable with itself or
with [other_type]"""
if not other_type or issubclass(other_type, Qint):
return True
return False

@classmethod
def const(cls, v: int) -> TExp:
"""Return the list of bool representing an int"""
return list(map(lambda c: True if c == "1" else False, bin(v)[2:]))[::-1]
return (cls, list(map(lambda c: True if c == "1" else False, bin(v)[2:]))[::-1])

@staticmethod
def fill(v: Tuple[TType, List[bool]]) -> Tuple[TType, List[bool]]:
def fill(v: TExp) -> TExp:
"""Fill a Qint to reach its bit_size"""
if len(v[1]) < v[0].BIT_SIZE: # type: ignore
v = (
Expand All @@ -57,12 +59,14 @@ def fill(v: Tuple[TType, List[bool]]) -> Tuple[TType, List[bool]]:
)
return v

# Comparators

@staticmethod
def eq(tleft: TExp, tcomp: TExp) -> TExp:
"""Compare two Qint for equality"""
ex = true
for x in zip(tleft[1], tcomp[1]):
ex = And(ex, Qbool.eq(x[0], x[1]))
ex = And(ex, _eq(x[0], x[1]))

if len(tleft[1]) > len(tcomp[1]):
for x in tleft[1][len(tcomp[1]) :]:
Expand All @@ -79,7 +83,7 @@ def neq(tleft: TExp, tcomp: TExp) -> TExp:
"""Compare two Qint for inequality"""
ex = false
for x in zip(tleft[1], tcomp[1]):
ex = Or(ex, Qbool.neq(x[0], x[1]))
ex = Or(ex, _neq(x[0], x[1]))

if len(tleft[1]) > len(tcomp[1]):
for x in tleft[1][len(tcomp[1]) :]:
Expand All @@ -102,7 +106,7 @@ def gt(tleft: TExp, tcomp: TExp) -> TExp:
else:
ex = Or(ex, And(*(prev + [a, Not(b)])))

prev.append(Qbool.eq(a, b))
prev.append(_eq(a, b))

if len(tleft[1]) > len(tcomp[1]):
for x in tleft[1][len(tcomp[1]) :]:
Expand Down
65 changes: 64 additions & 1 deletion qlasskit/types/qtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Tuple
from typing import Any, List, Tuple

from sympy.logic.boolalg import Boolean
from typing_extensions import TypeAlias
Expand All @@ -23,3 +23,66 @@

class Qtype:
BIT_SIZE = 0

def __getitem__(self, i):
"""Return the i-nth bit value"""
if i > self.BIT_SIZE:
raise Exception("Unbound")

return self.to_bin()[i] == "1"

def to_bin(self) -> str:
"""Return the binary representation of the value"""
raise Exception("abstract")

@classmethod
def from_bool(cls, v: List[bool]) -> "Qtype":
"""Return the Qtype object from a list of booleans"""
raise Exception("abstract")

@classmethod
def comparable(cls, other_type=None) -> bool:
"""Return true if the type is comparable with itself or
with [other_type]"""
raise Exception("abstract")

@classmethod
def size(cls) -> int:
"""Return the size in bit"""
return cls.BIT_SIZE

@classmethod
def const(cls, value: Any) -> TExp:
"""Return a list of bool representing the value"""
raise Exception("abstract")

@staticmethod
def fill(v: TExp) -> TExp:
"""Fill with leading false"""
raise Exception("abstract")

# Comparators

@staticmethod
def eq(tleft: TExp, tcomp: TExp) -> TExp:
raise Exception("abstract")

@staticmethod
def neq(tleft: TExp, tcomp: TExp) -> TExp:
raise Exception("abstract")

@staticmethod
def gt(tleft: TExp, tcomp: TExp) -> TExp:
raise Exception("abstract")

@staticmethod
def gte(tleft: TExp, tcomp: TExp) -> TExp:
raise Exception("abstract")

@staticmethod
def lt(tleft: TExp, tcomp: TExp) -> TExp:
raise Exception("abstract")

@staticmethod
def lte(tleft: TExp, tcomp: TExp) -> TExp:
raise Exception("abstract")
4 changes: 2 additions & 2 deletions test/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,9 +115,9 @@ def res_to_str(res):
elif type(res) == int:
qi = Qint(res)
qi.BIT_SIZE = len(bin(res)) - 2
return qi.to_bool_str()
return qi.to_bin()
else:
return res.to_bool_str()
return res.to_bin()

args = []
i = 0
Expand Down

0 comments on commit cc0936e

Please sign in to comment.