Skip to content

Commit

Permalink
crop return type
Browse files Browse the repository at this point in the history
  • Loading branch information
dakk committed Oct 28, 2023
1 parent 3084751 commit 7e81cc7
Show file tree
Hide file tree
Showing 7 changed files with 45 additions and 2 deletions.
1 change: 1 addition & 0 deletions TODO.md
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@
- [x] Inner function
- [x] Groover algorithm
- [x] Tuple-tuple comparison
- [x] Multi var assign

### Week 2: (30 Oct 23)

Expand Down
22 changes: 21 additions & 1 deletion qlasskit/ast2ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,21 @@
from .ast2logic import flatten


class NameConstantReplacer(ast.NodeTransformer):
def __init__(self, name_id, constant):
self.name_id = name_id
self.constant = constant

def generic_visit(self, node):
return super().generic_visit(node)

def visit_Name(self, node):
if node.id == self.name_id:
return ast.Constant(value=self.constant)

return node


class ASTRewriter(ast.NodeTransformer):
def __init__(self, env={}, ret=None):
self.env = {}
Expand Down Expand Up @@ -145,7 +160,12 @@ def visit_For(self, node):
ast.Assign(targets=[node.target], value=ast.Constant(value=i))
)
rolls.extend(flatten([tar_assign]))
rolls.extend(flatten([self.visit(copy.deepcopy(b)) for b in node.body]))

new_body = [
NameConstantReplacer(node.target.id, i).visit(copy.deepcopy(b))
for b in node.body
]
rolls.extend(flatten([self.visit(copy.deepcopy(b)) for b in new_body]))

return rolls

Expand Down
2 changes: 1 addition & 1 deletion qlasskit/ast2logic/t_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ def unfold(v_exps, op):
elif isinstance(expr.op, ast.BitXor) and hasattr(tleft[0], "bitwise_xor"):
return tleft[0].bitwise_xor(tleft, tright)
elif isinstance(expr.op, ast.BitAnd) and hasattr(tleft[0], "bitwise_and"):
return tleft[0].bitwise_and(tleft, tright)
return tright[0].bitwise_and(tleft, tright) # type: ignore
elif isinstance(expr.op, ast.BitOr) and hasattr(tleft[0], "bitwise_or"):
return tleft[0].bitwise_or(tleft, tright)
elif (
Expand Down
6 changes: 6 additions & 0 deletions qlasskit/ast2logic/t_statement.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,12 @@ def translate_statement( # noqa: C901
and texp.BIT_SIZE < ret_type.BIT_SIZE
):
texp, vexp = ret_type.fill((texp, vexp)) # type: ignore
elif (
hasattr(texp, "BIT_SIZE")
and hasattr(ret_type, "BIT_SIZE")
and texp.BIT_SIZE > ret_type.BIT_SIZE
):
texp, vexp = ret_type.crop((texp, vexp)) # type: ignore
elif texp != ret_type:
raise exceptions.TypeErrorException(texp, ret_type)

Expand Down
1 change: 1 addition & 0 deletions qlasskit/qlassf.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,7 @@ def from_function(

# Return the qlassf object
qf = QlassF(fun_name, original_f, args, fun_ret, exps)

if to_compile:
qf.compile()
return qf
Expand Down
10 changes: 10 additions & 0 deletions qlasskit/types/qint.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,16 @@ def fill(cls, v: TExp) -> TExp:
)
return v

@classmethod
def crop(cls, v: TExp) -> TExp:
"""Crop a Qint to reach its bit_size"""
if len(v[1]) > cls.BIT_SIZE: # type: ignore
v = (
cls,
v[1][0 : cls.BIT_SIZE], # type: ignore
)
return v

# Comparators

@staticmethod
Expand Down
5 changes: 5 additions & 0 deletions qlasskit/types/qtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,11 @@ def fill(v: TExp) -> TExp:
"""Fill with leading false"""
raise Exception("abstract")

@staticmethod
def crop(v: TExp) -> TExp:
"""Crop to right size"""
raise Exception("abstract")

# Comparators

@staticmethod
Expand Down

0 comments on commit 7e81cc7

Please sign in to comment.