Skip to content

Commit

Permalink
Merge pull request #4 from dakk/if-then-else
Browse files Browse the repository at this point in the history
If-then-else and bool optimizer refactoring
  • Loading branch information
dakk authored Nov 9, 2023
2 parents 9739070 + 1c8a2b1 commit f5db491
Show file tree
Hide file tree
Showing 20 changed files with 587 additions and 261 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,5 @@ htmlcov
coverage.xml
docs/build
*.egg
*.egg-info
*.egg-info
.t_statistics
2 changes: 1 addition & 1 deletion TODO.md
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@

### Week 3: (6 Nov 23)

- [x] Ast2logic: if-then-else statement
- [ ] Midterm call

### Week 4: (13 Nov 23)
Expand Down Expand Up @@ -119,7 +120,6 @@

### Language support

- [ ] Ast2logic: if-then-else statement
- [ ] Datatype: Dict
- [ ] Datatype: Fixed
- [ ] Datatype: Enum
Expand Down
68 changes: 68 additions & 0 deletions qlasskit/ast2ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,13 @@ def __init__(self, env={}, ret=None):
self.env = {}
self.const = {}
self.ret = None
self._uniqd = 1

@property
def uniqd(self):
"""Return an unique identifier as str"""
self._uniqd += 1
return f"{hex(self._uniqd)[2:]}"

def __unroll_arg(self, arg):
if isinstance(arg, ast.Tuple):
Expand Down Expand Up @@ -110,6 +117,66 @@ def visit_Name(self, node):

return node

def visit_If(self, node):
body = flatten([self.visit(n) for n in node.body])
orelse = flatten([self.visit(n) for n in node.orelse])
test_name = "_iftarg" + self.uniqd

if_l = [
ast.Assign(
targets=[ast.Name(id=test_name)],
value=self.visit(node.test),
)
]

for b in body:
if not isinstance(b, ast.Assign):
raise Exception("if body only allows assigns: ", ast.dump(b))

if len(b.targets) != 1:
raise Exception("if targets only allow one: ", ast.dump(b))

target_0id = b.targets[0].id

if target_0id[0:2] == "__" and target_0id not in self.env:
orelse_inner = ast.Name(id=target_0id[2:])
else:
orelse_inner = ast.Name(id=target_0id)

if_l.append(
ast.Assign(
targets=b.targets,
value=ast.IfExp(
test=ast.Name(id=test_name), body=b.value, orelse=orelse_inner
),
)
)

for b in orelse:
if not isinstance(b, ast.Assign):
raise Exception("if body only allows assigns: ", ast.dump(b))

if len(b.targets) != 1:
raise Exception("if targets only allow one: ", ast.dump(b))

target_0id = b.targets[0].id

if target_0id[0:2] == "__" and target_0id not in self.env:
orelse_inner = ast.Name(id=target_0id[2:])
else:
orelse_inner = ast.Name(id=target_0id)

if_l.append(
ast.Assign(
targets=b.targets,
value=ast.IfExp(
test=ast.Name(id=test_name), orelse=b.value, body=orelse_inner
),
)
)

return if_l

def visit_List(self, node):
return ast.Tuple(elts=[self.visit(el) for el in node.elts])

Expand Down Expand Up @@ -336,4 +403,5 @@ def ast2ast(a_tree):
a_tree = IndexReplacer().visit(a_tree)

a_tree = ASTRewriter().visit(a_tree)
# print(ast.dump(a_tree))
return a_tree
16 changes: 16 additions & 0 deletions qlasskit/boolopt/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Copyright 2023 Davide Gessa

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# isort:skip_file

from .sympytransformer import SympyTransformer # noqa: F401
100 changes: 100 additions & 0 deletions qlasskit/boolopt/bool_optimizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
# Copyright 2023 Davide Gessa

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict

from sympy import Symbol, cse
from sympy.logic.boolalg import And, Boolean, Not, Or, Xor, simplify_logic

from ..ast2logic import BoolExpList
from . import SympyTransformer, deprecated
from .exp_transformers import (
remove_Implies,
remove_ITE,
transform_or2and,
transform_or2xor,
)


def custom_simplify_logic(expr):
if isinstance(expr, Xor):
return expr
elif isinstance(expr, (And, Or, Not)):
args = [custom_simplify_logic(arg) for arg in expr.args]
return type(expr)(*args)
else:
return simplify_logic(expr)


def merge_expressions(exps: BoolExpList) -> BoolExpList:
n_exps = []
emap: Dict[Symbol, Boolean] = {}

for s, e in exps:
e = e.xreplace(emap)
e = custom_simplify_logic(e)

if s.name[0:4] != "_ret":
emap[s] = e
else:
n_exps.append((s, e))

return n_exps


def apply_cse(exps: BoolExpList) -> BoolExpList:
lsts = list(zip(*exps))
repl, red = cse(list(lsts[1]))
res = repl + list(zip(lsts[0], red))
return res


class BoolOptimizerProfile:
def __init__(self, steps):
self.steps = steps

def apply(self, exps):
for opt in self.steps:
if isinstance(opt, SympyTransformer):
exps = list(map(lambda e: (e[0], opt.visit(e[1])), exps))
else:
exps = opt(exps)
return exps


bestWorkingOptimizer = BoolOptimizerProfile(
[
merge_expressions,
apply_cse,
remove_ITE(),
remove_Implies(),
transform_or2xor(),
transform_or2and(),
]
)


deprecatedWorkingOptimizer = BoolOptimizerProfile(
[
deprecated.remove_const_exps,
deprecated.remove_unnecessary_assigns,
deprecated.merge_unnecessary_assigns,
merge_expressions,
apply_cse,
remove_ITE(),
remove_Implies(),
transform_or2xor(),
transform_or2and(),
]
)
120 changes: 68 additions & 52 deletions qlasskit/bool_optimizer.py → qlasskit/boolopt/deprecated.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@
from typing import Dict

from sympy import Symbol
from sympy.logic.boolalg import Boolean
from sympy.logic.boolalg import Boolean, simplify_logic, to_anf

from .ast2logic import BoolExpList
from ..ast2logic import BoolExpList


def remove_const_exps(exps: BoolExpList) -> BoolExpList:
Expand All @@ -37,34 +37,6 @@ def remove_const_exps(exps: BoolExpList) -> BoolExpList:
return n_exps


# def subsitute_exps(exps: BoolExpList) -> BoolExpList:
# """Subsitute exps (replace a = ~a, a = ~a, a = ~a => a = ~a)"""
# const: Dict[Symbol, Boolean] = {}
# n_exps: BoolExpList = []
# print(exps)

# for i in range(len(exps)):
# (s, e) = exps[i]
# e = e.subs(const)
# const[s] = e

# for x in e.free_symbols:
# if x in const:
# n_exps.append((x, const[x]))
# del const[x]

# for (s,e) in const.items():
# if s == e:
# continue

# n_exps.append((s,e))

# print(n_exps)
# print()
# print()
# return n_exps


def remove_unnecessary_assigns(exps: BoolExpList) -> BoolExpList:
"""Remove exp like: __a.0 = a.0, ..., a.0 = __a.0"""
n_exps: BoolExpList = []
Expand All @@ -89,40 +61,84 @@ def should_add(s, e, n_exps2):

return n_exps

# for s, e in exps:
# n_exps2 = []
# ename = f"__{s.name}"
# n_exps.append((s, e))

# for s_, e_ in reversed(n_exps):
# if s_.name == ename:
# continue
# else:
# _replaced = e_.subs(Symbol(ename), Symbol(s.name))
# if s_ != _replaced:
# n_exps2.append((s_, _replaced))
def merge_unnecessary_assigns(exps: BoolExpList) -> BoolExpList:
"""Translate exp like: __a.0 = !a, a = __a.0 ===> a = !a"""
n_exps: BoolExpList = []
rep_d = {}

# n_exps = n_exps2[::-1]
for s, e in exps:
if len(n_exps) >= 1 and n_exps[-1][0] == e: # and n_exps[-1][0].name[2:] == s:
old = n_exps.pop()
rep_d[old[0]] = old[1]
n_exps.append((s, e.subs(rep_d)))
else:
n_exps.append((s, e.subs(rep_d)))

# return n_exps
return n_exps


def merge_unnecessary_assigns(exps: BoolExpList) -> BoolExpList:
"""Translate exp like: __a.0 = !a, a = __a.0 ===> a = !a"""
def remove_unnecessary_aliases(exps: BoolExpList) -> BoolExpList:
"""Translate exps like: (__d.0, a), (d.0, __d.0 & a) to => (d.0, a & a)"""
n_exps: BoolExpList = []
rep_d = {}

for s, e in exps:
if len(n_exps) >= 1 and n_exps[-1][0] == e:
if len(n_exps) >= 1 and n_exps[-1][0] in e.free_symbols:
old = n_exps.pop()
n_exps.append((s, old[1]))
rep_d[old[0]] = old[1]
n_exps.append((s, e.subs(rep_d)))
else:
n_exps.append((s, e.subs(rep_d)))

return n_exps


def remove_aliases(exps: BoolExpList) -> BoolExpList:
aliases = {}
n_exps = []
for s, e in exps:
if isinstance(e, Symbol):
aliases[s] = e
elif s in aliases:
del aliases[s]
n_exps.append((s, e.subs(aliases)))
else:
n_exps.append((s, e.subs(aliases)))

return n_exps


def s2_mega(exps: BoolExpList) -> BoolExpList:
n_exps: BoolExpList = []
exp_d = {}

for s, e in exps:
exp_d[s] = e
n_exps.append((s, e.subs(exp_d)))

s_count = {}
exps = n_exps

for s, e in exps:
if s.name not in s_count:
s_count[s.name] = 0

for x in e.free_symbols:
if x.name in s_count:
s_count[x.name] += 1

n_exps = []
for s, e in exps:
if s_count[s.name] > 0 or s.name[0:4] == "_ret":
n_exps.append((s, e))

return n_exps


# [(h, a_list.0.0 & a_list.0.1), (h, a_list.1.0 & a_list.1.1 & h),
# (h, a_list.2.0 & a_list.2.1 & h), (_ret, a_list.3.0 & a_list.3.1 & h)]
# TO
# (_ret, a_list_3_0 & a_list_3_1 & a_list_2_0 & a_list_2_1 & a_list_1_0 & a_list_1_1 &
# a_list_0_0 & a_list_0_1)
def exps_simplify(exps: BoolExpList) -> BoolExpList:
return list(map(lambda e: (e[0], simplify_logic(e[1])), exps))


def exps_to_anf(exps: BoolExpList) -> BoolExpList:
return list(map(lambda e: (e[0], to_anf(e[1])), exps))
Loading

0 comments on commit f5db491

Please sign in to comment.