Skip to content

Commit

Permalink
fix type inconsistency
Browse files Browse the repository at this point in the history
  • Loading branch information
dakk committed Apr 11, 2024
1 parent 5a1fc4d commit 1d2a778
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 28 deletions.
6 changes: 3 additions & 3 deletions qlasskit/types/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,9 +114,9 @@ def const_to_qtype(value: Any) -> TExp:
return Qchar.const(value)

elif isinstance(value, float):
for det_type in QFIXED_TYPES: # type: ignore
v = det_type.const(value) # type: ignore
c_val = det_type.from_bool(v[1])
for fdet_type in QFIXED_TYPES:
v = fdet_type.const(value)
c_val = fdet_type.from_bool(v[1])
if c_val > value - 0.05 and c_val < value + 0.05:
return v

Expand Down
35 changes: 23 additions & 12 deletions qlasskit/types/qfixed.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List
from typing import List, cast

from sympy import Symbol
from sympy.logic import And, Not, Or, false, true
Expand Down Expand Up @@ -126,7 +126,7 @@ def fractional_part(v: TExp):
return v[1][v[0].BIT_SIZE_INTEGER :]

@staticmethod
def _to_qint_repr(v: TExp):
def _to_qint_repr(v: Qtype):
if not issubclass(v[0], QfixedImp):
raise TypeErrorException(v[0], QfixedImp)

Expand Down Expand Up @@ -157,8 +157,16 @@ def neq(tleft: TExp, tcomp: TExp) -> TExp:

@staticmethod
def gt(tleft: TExp, tcomp: TExp) -> TExp:
tl_v = QfixedImp._to_qint_repr(tleft)
tc_v = QfixedImp._to_qint_repr(tcomp)
if not issubclass(tleft[0], QfixedImp):
raise TypeErrorException(tleft[0], QfixedImp)
if not issubclass(tcomp[0], QfixedImp):
raise TypeErrorException(tcomp[0], QfixedImp)

tleft_e = cast(Qtype, tleft)
tcomp_e = cast(Qtype, tcomp)

tl_v = QfixedImp._to_qint_repr(tleft_e)
tc_v = QfixedImp._to_qint_repr(tcomp_e)

prev: List[Symbol] = []

Expand Down Expand Up @@ -205,18 +213,21 @@ def add(cls, tleft: TExp, tright: TExp) -> TExp:
if not issubclass(tleft[0], QfixedImp):
raise TypeErrorException(tleft[0], QfixedImp)

if len(tleft[1]) > len(tright[1]):
tright = tleft[0].fill(tright)
tright_e = cast(Qtype, tright)
tleft_e = cast(Qtype, tleft)

if len(tleft_e[1]) > len(tright_e[1]):
tright_e = tleft_e[0].fill(tright_e)

elif len(tleft[1]) < len(tright[1]):
tleft = tright[0].fill(tleft) # type: ignore
elif len(tleft_e[1]) < len(tright_e[1]):
tleft_e = tright_e[0].fill(tleft_e)

tl_v = QfixedImp._to_qint_repr(tleft)
tr_v = QfixedImp._to_qint_repr(tright)
tl_v = QfixedImp._to_qint_repr(tleft_e)
tr_v = QfixedImp._to_qint_repr(tright_e)

res = QintImp.add((tleft[0], tl_v), (tright[0], tr_v))
res = QintImp.add((tleft_e[0], tl_v), (tright_e[0], tr_v))

return (tleft[0], QfixedImp._from_qint_repr((tleft[0], res[1])))
return (tleft_e[0], QfixedImp._from_qint_repr((tleft_e[0], res[1])))

@classmethod
def sub(cls, tleft: TExp, tright: TExp) -> TExp:
Expand Down
32 changes: 19 additions & 13 deletions qlasskit/types/qint.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List
from typing import List, cast

from sympy import Symbol
from sympy.logic import And, Not, Or, Xor, false, true
Expand Down Expand Up @@ -147,19 +147,22 @@ def add(cls, tleft: TExp, tright: TExp) -> TExp:
if not issubclass(tright[0], Qtype):
raise TypeErrorException(tright[0], Qtype)

if len(tleft[1]) > len(tright[1]):
tright = tleft[0].fill(tright)
tright_e = cast(Qtype, tright)
tleft_e = cast(Qtype, tleft)

elif len(tleft[1]) < len(tright[1]):
tleft = tright[0].fill(tleft) # type: ignore
if len(tleft_e[1]) > len(tright_e[1]):
tright_e = tleft_e[0].fill(tright_e)

elif len(tleft_e[1]) < len(tright_e[1]):
tleft_e = tright_e[0].fill(tleft_e)

carry = False
sums = []
for x in zip(tleft[1], tright[1]):
for x in zip(tleft_e[1], tright_e[1]):
carry, sum = _full_adder(carry, x[0], x[1])
sums.append(sum)

return (cls if cls.BIT_SIZE > tleft[0].BIT_SIZE else tleft[0], sums) # type: ignore
return (cls if cls.BIT_SIZE > tleft_e[0].BIT_SIZE else tleft_e[0], sums)

@classmethod
def mul(cls, tleft: TExp, tright: TExp) -> TExp: # noqa: C901
Expand Down Expand Up @@ -230,14 +233,17 @@ def bitwise_generic(cls, op, tleft: TExp, tright: TExp) -> TExp:
if not issubclass(tright[0], Qtype):
raise TypeErrorException(tright[0], Qtype)

if len(tleft[1]) > len(tright[1]):
tright = tleft[0].fill(tright)
tright_e = cast(Qtype, tright)
tleft_e = cast(Qtype, tleft)

if len(tleft_e[1]) > len(tright_e[1]):
tright_e = tleft_e[0].fill(tright_e)

elif len(tleft[1]) < len(tright[1]):
tleft = tright[0].fill(tleft) # type: ignore
elif len(tleft_e[1]) < len(tright_e[1]):
tleft_e = tright_e[0].fill(tleft_e)

newl = [op(a, b) for (a, b) in zip(tleft[1], tright[1])]
return (tright[0], newl)
newl = [op(a, b) for (a, b) in zip(tleft_e[1], tright_e[1])]
return (tright_e[0], newl)

@classmethod
def bitwise_xor(cls, tleft: TExp, tright: TExp) -> TExp:
Expand Down

0 comments on commit 1d2a778

Please sign in to comment.