Skip to content

Commit

Permalink
fix qfixed type definition
Browse files Browse the repository at this point in the history
  • Loading branch information
dakk committed Mar 14, 2024
1 parent c6534bd commit 83dfc05
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 43 deletions.
12 changes: 1 addition & 11 deletions qlasskit/ast2ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def _replace_types_annotations(ann, arg=None):
)

# Replace Qlist[T,n] with Tuple[(T,)*n]
if isinstance(ann, ast.Subscript) and ann.value.id == "Qlist":
elif isinstance(ann, ast.Subscript) and ann.value.id == "Qlist":
_elts = ann.slice.elts
_ituple = ast.Tuple(elts=[copy.deepcopy(_elts[0])] * _elts[1].value)

Expand All @@ -61,16 +61,6 @@ def _replace_types_annotations(ann, arg=None):
slice=_ituple,
)

# Replace Qfixed[TI,TF] with Tuple[(TI,TF)]
elif isinstance(ann, ast.Subscript) and ann.value.id == "Qfixed":
_elts = ann.slice.elts
_ituple = ast.Tuple(elts=[copy.deepcopy(_elts[0]), copy.deepcopy(_elts[1])])

ann = ast.Subscript(
value=ast.Name(id="Tuple", ctx=ast.Load()),
slice=_ituple,
)

# Replace Qmatrix[T,n,m] with Tuple[(Tuple[(T,)*m],)*n]
elif isinstance(ann, ast.Subscript) and ann.value.id == "Qmatrix":
_elts = ann.slice.elts
Expand Down
23 changes: 23 additions & 0 deletions qlasskit/ast2logic/t_arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,29 @@ def to_name(a):
ttypes_t = tuple(ttypes)
return Arg(base, Tuple[ttypes_t], al)

# Qfixed
if isinstance(ann, ast.Subscript) and ann.value.id == "Qfixed": # type: ignore
al = []
ind = 0

if hasattr(ann.slice, "elts"):
_elts = ann.slice.elts # type: ignore
else:
_elts = [ann.slice]

for i in _elts: # type: ignore
if isinstance(i, ast.Name) and to_name(i) == "bool":
al.append(f"{base}.{ind}")
ttypes.append(bool)
else:
inner_arg = translate_argument(i, env, base=f"{base}.{ind}")
ttypes.append(inner_arg.ttype)
al.extend(inner_arg.bitvec)
ind += 1
ttypes_t = tuple(ttypes)
return Arg(base, Qfixed[ttypes_t], al)


elif isinstance(ann, ast.Tuple):
al = []
ind = 0
Expand Down
43 changes: 16 additions & 27 deletions qlasskit/types/qfixed.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,40 +12,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Tuple, TypeVar
from typing import Any, Tuple, TypeVar, Generic

from .qint import Qint
from .qtype import TExp
from .qtype import Qtype, TExp

TI = TypeVar("TI") # integer part
TF = TypeVar("TF") # fractional part


class QfixedMeta(type):
def __getitem__(cls, params):
if isinstance(params, tuple) and len(params) == 2:
TI, TF = params

if isinstance(TI, int):
bs = TI
TI = Qint
TI.BIT_SIZE = bs

if isinstance(TF, int):
bs = TF
TF = Qint
TF.BIT_SIZE = bs

assert issubclass(TI, Qint)
assert issubclass(TF, Qint)

return (
TI,
TF,
)


class Qfixed(metaclass=QfixedMeta):
class Qfixed(float, Qtype, Generic[TI, TF]):
def __init__(self, value):
super().__init__()
self.value = value

@classmethod
def const(cls, value: Any) -> TExp:
val_s = str(value).split(".")
Expand All @@ -54,4 +35,12 @@ def const(cls, value: Any) -> TExp:

a = Qint._const(int(val_s[0]))
b = Qint._const(int(val_s[1]))
return Tuple[a[0], b[0]], [a[1], b[1]]
return Qfixed[a[0], b[0]], [a[1], b[1]]


# Operations

@classmethod
def add(cls, tleft: TExp, tright: TExp) -> TExp:
"""Add two Qfixed"""
pass
10 changes: 5 additions & 5 deletions test/qlassf/test_fixed.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,11 @@
@parameterized_class(("compiler"), ENABLED_COMPILERS)
class TestQlassfFixed2(unittest.TestCase):
def test_fixed_const(self):
f = "def test() -> Qfixed[2, 2]:\n\treturn 0.1"
f = "def test() -> Qfixed[Qint2, Qint2]:\n\treturn 0.1"
qf = qlassf(f, to_compile=COMPILATION_ENABLED, compiler=self.compiler)
compute_and_compare_results(self, qf)

# def test_sum(self):
# f = "def test(a: Qfixed[2,2]) -> Qfixed[2, 2]:\n\treturn 0.1 + a"
# qf = qlassf(f, to_compile=COMPILATION_ENABLED, compiler=self.compiler)
# compute_and_compare_results(self, qf)
def test_sum_const(self):
f = "def test(a: Qfixed[2,2]) -> Qfixed[2, 2]:\n\treturn 0.1 + a"
qf = qlassf(f, to_compile=COMPILATION_ENABLED, compiler=self.compiler)
compute_and_compare_results(self, qf)

0 comments on commit 83dfc05

Please sign in to comment.