Skip to content

Commit

Permalink
[Sim][#50] Have a mean to see simplified expressions.
Browse files Browse the repository at this point in the history
  • Loading branch information
kosarev committed Jul 10, 2022
1 parent 7f9a85f commit 3941644
Showing 1 changed file with 116 additions and 45 deletions.
161 changes: 116 additions & 45 deletions tests/z80sim/z80sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,9 +151,8 @@ class Literal(object):
__shorten_ids = {}

@staticmethod
def get(id, sign=False):
key = id, sign
t = __class__.__literals.get(key)
def get(id):
t = __class__.__literals.get(id)
if t is not None:
return t

Expand All @@ -167,16 +166,27 @@ def get(id, sign=False):
shorten_id, __class__.__shorten_ids[shorten_id], id)

t = __class__()
t.id, t.sign = id, sign
t.shorten_id = shorten_id
t.hash = HASH(str((id, sign)).encode()).digest()
t.__expr = None
__class__.__literals[key] = t
i = __class__()

t.id, t.sign = id, False
i.id, i.sign = id, True
t.shorten_id = i.shorten_id = shorten_id
t.value_expr = None
i.value_expr = ('not', t)
t.__inversion = i
i.__inversion = t
__class__.__literals[id] = t

return t

@staticmethod
def from_hash(hash):
return __class__.get(__class__.__HASH_PREFIX + hash.hex())
def get_intermediate(op, *args):
assert op in ('or', 'and', 'not', 'ifelse')
value_expr = (op,) + args
hash = HASH(str(value_expr).encode()).hexdigest()
t = __class__.get(__class__.__HASH_PREFIX + hash)
t.value_expr = value_expr
return t

def __repr__(self):
SIGNS = {False: '', True: '~'}
Expand All @@ -193,36 +203,52 @@ def cast(x):

class Storage(object):
def __init__(self, *, image=None):
self.__ids = []
self.__id_indexes = {}
self.__literals = []
self.__indexes = {}

if image is not None:
for i, id in enumerate(image):
assert self.__add_id(id) == i
for i, (id, value_expr) in enumerate(image):
assert self.add(Literal.get(id)) == i * 2

def __add_id(self, id):
i = self.__id_indexes.get(id)
if i is None:
i = len(self.__ids)
self.__ids.append(id)
self.__id_indexes[id] = i
for id, value_expr in image:
if value_expr is not None:
op, *args = value_expr
value_expr = (op,) + tuple(self.get(a) for a in args)

return i
Literal.get(id).value_expr = value_expr

def add(self, literal):
i = self.__add_id(literal.id)
return i * 2 + int(literal.sign)
if literal.sign:
return self.add(~literal) + 1

i = self.__indexes.get(literal)
if i is None:
i = len(self.__literals)
self.__literals.append(literal)
self.__indexes[literal] = i

return i * 2

def get(self, image):
i, sign = image // 2, bool(image % 2)
return Literal.get(self.__ids[i], sign)
i, sign = image // 2, image % 2
t = self.__literals[i]
return ~t if sign else t

@property
def image(self):
return tuple(self.__ids)
im = []
for t in self.__literals:
e = t.value_expr
if e is not None:
op, *args = e
e = (op,) + tuple(self.add(a) for a in args)

im.append((t.id, e))

return tuple(im)

def __invert__(self):
return __class__.get(self.id, not self.sign)
return self.__inversion


class Clause(object):
Expand Down Expand Up @@ -371,11 +397,6 @@ def size(self):
def __eq__(self, other):
assert 0, "Bool's should not be compared; use is_equiv() instead."

@staticmethod
def __get_op_symbol(kind, *ops):
key = kind + b''.join(op.hash for op in ops)
return Literal.from_hash(HASH(key).digest())

@staticmethod
def get_or(*args):
args = tuple(a for a in args if a.value is not False)
Expand All @@ -389,7 +410,7 @@ def get_or(*args):
# TODO: Optimise the case of two pure symbols.

syms = sorted(a.symbol for a in args)
r = __class__.__get_op_symbol(b'(or)', *syms)
r = Literal.get_intermediate('or', *syms)

or_clauses = [Clause.get(*(syms + [~r]))]
or_clauses.extend(Clause.get(~s, r) for s in syms)
Expand All @@ -412,7 +433,7 @@ def get_and(*args):
# TODO: Optimise the case of two pure symbols.

syms = sorted(a.symbol for a in args)
r = __class__.__get_op_symbol(b'(and)', *syms)
r = Literal.get_intermediate('and', *syms)

and_clauses = [Clause.get(*([~s for s in syms] + [r]))]
and_clauses.extend(Clause.get(s, ~r) for s in syms)
Expand All @@ -431,7 +452,7 @@ def __invert__(self):
# TODO: Optimise the case of a pure symbol.

a = self.symbol
r = __class__.__get_op_symbol(b'(not)', a)
r = Literal.get_intermediate('not', a)
return __class__.from_clauses(r, self.clauses,
(Clause.get(a, r),
Clause.get(~a, ~r)))
Expand All @@ -452,7 +473,7 @@ def ifelse(cond, a, b):
# TODO: Optimise the case of pure symbols.

i, t, e = cond.symbol, a.symbol, b.symbol
r = __class__.__get_op_symbol(b'(ifelse)', i, t, e)
r = Literal.get_intermediate('ifelse', i, t, e)
return __class__.from_clauses(r, cond.clauses,
a.clauses, b.clauses,
(Clause.get(~i, t, ~r),
Expand Down Expand Up @@ -558,6 +579,58 @@ def simplified(self):
.__get_constraints())
cache.store((s.image,))

def simplified_sexpr(self):
if self.value is not None:
return z3.BoolVal(self.value).sexpr()

key = self.__get_value_or_symbol_expr().sexpr()
cache = Cache.get_entry('simplified', key)
s = cache.load()
if s is not None:
s, = s
return s

temps = []

def add_temps(t):
if t.value_expr is None or t in temps:
return

op, *args = t.value_expr
for a in args:
add_temps(a)

temps.append(t)

for c in self.clauses:
for t in c.literals:
add_temps(t)

e = z3.And(self.__get_constraints_expr(),
self.__get_value_or_symbol_expr())

tactic = z3.Tactic('qe2')
e = tactic.apply(e).as_expr()

while temps:
with Status.do(f'{len(temps)} temps to eliminate'):
t = temps.pop()
op, *args = t.value_expr

s = __class__.__get_literal_expr(t)

OPS = {'or': z3.Or, 'and': z3.And, 'not': z3.Not,
'ifelse': z3.If}
v = OPS[op](*(__class__.__get_literal_expr(a) for a in args))

e = z3.substitute(e, (s, v))
e = tactic.apply(e).as_expr()

s = e.sexpr()
cache.store((s,))

return s

def reduced(self):
# TODO: It seems we do much faster without trying to
# reduce at every step.
Expand Down Expand Up @@ -1458,7 +1531,7 @@ def __get_cache(steps):
# Whenever we make changes that invalidate cached states,
# e.g., the names of the nodes are changed, the version
# number must be bumped.
VERSION = 5
VERSION = 6

key = VERSION, __class__.__get_steps_image(steps)
return Cache.get_entry('states', key)
Expand Down Expand Up @@ -1654,27 +1727,25 @@ def report(self, id):
# Generate image before printing results.
image = self.image

Status.clear()

print(id)
print(self.hash)
Status.print(id)
Status.print(self.hash)

s = Z80Simulator(image=image)

def print_bit(id, with_pull=False):
n = s.get_node(id)
print(f' {id}: {n.state}')
Status.print(f' {id}: {n.state.simplified_sexpr()}')
if with_pull:
print(f' {id} pull: {n.pull}')
Status.print(f' {id} pull: {n.pull.simplified_sexpr()}')

def print_bits(id, width):
bits = s.read_nodes(id, width)
if not isinstance(bits.value, tuple):
print(f' {id}: {bits}')
Status.print(f' {id}: {bits}')
return

for i in reversed(range(bits.width)):
print(f' {id}{i}: {bits[i]}')
Status.print(f' {id}{i}: {bits[i].simplified_sexpr()}')

for pin in _PINS:
print_bit(pin, with_pull=True)
Expand Down

0 comments on commit 3941644

Please sign in to comment.