Skip to content

Commit

Permalink
Fix precedence naming in expr_unparse.py
Browse files Browse the repository at this point in the history
Better tests
  • Loading branch information
yunline committed Dec 13, 2024
1 parent e29f4f3 commit fc6c81b
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 44 deletions.
89 changes: 49 additions & 40 deletions oneliner/expr_unparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,14 +46,17 @@
}
enum = itertools.count()
PREC_NAME = next(enum)

PREC_ATTR = next(enum)
PREC_ATTR_SLOT = next(enum)

PREC_AWAIT_SLOT = next(enum)
PREC_AWAIT = next(enum)

PREC_POW_SLOT_LEFT = next(enum)
PREC_POW = next(enum)
PREC_INV = next(enum)
PREC_INV_UADD_USUB = next(enum)
PREC_INV_UADD_USUB_SLOT = next(enum)
PREC_POW_SLOT_RIGHT = next(enum)

PREC_MULT_SLOT_RIGHT = next(enum)
Expand All @@ -80,9 +83,13 @@
PREC_BITOR = next(enum)
PREC_BITOR_SLOT_LEFT = next(enum)

PREC_STARRED_SLOT = next(enum)

PREC_COMPARE_SLOT = next(enum)
PREC_COMPARE = next(enum)

PREC_NOT = next(enum)
PREC_NOT_SLOT = next(enum)

PREC_AND_SLOT = next(enum)
PREC_AND = next(enum)
Expand All @@ -96,9 +103,9 @@
PREC_IFEXP = next(enum)
PREC_IFEXP_SLOT_RIGHT = next(enum)

PREC_FORMAT_EXPR = next(enum)
PREC_FORMAT_EXPR_SLOT = next(enum)
PREC_LAMBDA = next(enum)
PREC_EXPR = next(enum)
PREC_EXPR_SLOT = next(enum)

PREC_CALL_SLOT_KWARG = next(enum)
PREC_NAMEDEXPR = next(enum)
Expand Down Expand Up @@ -134,9 +141,9 @@
}

unaryop_node_prec_map: dict[type[unaryop], prec_t] = {
UAdd: PREC_INV,
USub: PREC_INV,
Invert: PREC_INV,
UAdd: PREC_INV_UADD_USUB,
USub: PREC_INV_UADD_USUB,
Invert: PREC_INV_UADD_USUB,
Not: PREC_NOT,
}

Expand All @@ -160,7 +167,7 @@
Compare: PREC_COMPARE,
IfExp: PREC_IFEXP,
Lambda: PREC_LAMBDA,
Slice: PREC_EXPR,
Slice: PREC_NAME,
NamedExpr: PREC_NAMEDEXPR,
GeneratorExp: PREC_GENEXPR,
Yield: PREC_YIELD,
Expand Down Expand Up @@ -232,7 +239,7 @@ def _unparse_JoinedStr(node: JoinedStr, qm: typing.Literal["'", '"']) -> unparse
s = s.replace("{", "{{").replace("}", "}}")
contents.append(s)
elif isinstance(v, FormattedValue):
contents.append((yield PREC_FORMAT_EXPR, v))
contents.append((yield PREC_FORMAT_EXPR_SLOT, v))
return "".join(contents)


Expand All @@ -244,7 +251,7 @@ def unparse_JoinedStr(node: JoinedStr, qm: typing.Literal["'", '"']) -> unparse_


def unparse_FormattedValue(node: FormattedValue, qm) -> unparse_gen_t:
value = yield PREC_FORMAT_EXPR, node.value
value = yield PREC_FORMAT_EXPR_SLOT, node.value
format_spec = ""
if node.format_spec is not None:
assert isinstance(node.format_spec, JoinedStr)
Expand All @@ -261,15 +268,13 @@ def unparse_FormattedValue(node: FormattedValue, qm) -> unparse_gen_t:


def unparse_Starred(node: Starred) -> unparse_gen_t:
precedence = PREC_NAME
value = yield precedence, node.value
value = yield PREC_STARRED_SLOT, node.value
return f"*{value}"
yield


def unparse_Attribute(node: Attribute) -> unparse_gen_t:
precedence = PREC_NAME
value = yield precedence, node.value
value = yield PREC_ATTR_SLOT, node.value
if value.isdigit():
# 0.a is invalid
# (0).a is valid
Expand All @@ -279,25 +284,24 @@ def unparse_Attribute(node: Attribute) -> unparse_gen_t:


def unparse_Subscript(node: Subscript) -> unparse_gen_t:
value = yield PREC_NAME, node.value
_slice = yield PREC_EXPR, node.slice
value = yield PREC_ATTR_SLOT, node.value
_slice = yield PREC_EXPR_SLOT, node.slice
return f"{value}[{_slice}]"


def unparse_Slice(node: Slice) -> unparse_gen_t:
precedence = PREC_EXPR
upper, lower, step = "", "", ""
if node.upper is not None:
upper = yield precedence, node.upper
upper = yield PREC_EXPR_SLOT, node.upper
if node.lower is not None:
lower = yield precedence, node.lower
lower = yield PREC_EXPR_SLOT, node.lower
if node.step is not None:
step = yield precedence, node.step
step = yield PREC_EXPR_SLOT, node.step
return f"{lower}:{upper}:{step}" # todo: simplify


def unparse_Call(node: Call) -> unparse_gen_t:
func = yield PREC_ATTR, node.func
func = yield PREC_ATTR_SLOT, node.func
if len(node.args) == 1 and len(node.keywords) == 0:
_arg = yield PREC_CALL_SLOT_ONLYARG, node.args[0]
return f"{func}({_arg})"
Expand Down Expand Up @@ -362,7 +366,12 @@ def unparse_BoolOp(node: BoolOp) -> unparse_gen_t:


def unparse_UnaryOp(node: UnaryOp) -> unparse_gen_t:
precedence = unaryop_node_prec_map[type(node.op)]
if isinstance(node.op, Not):
precedence = PREC_NOT_SLOT
elif isinstance(node.op, (Invert, UAdd, USub)):
precedence = PREC_INV_UADD_USUB_SLOT
else: # pragma: no cover
raise SyntaxError(f"Unknown UnaryOp type {type(node.op)}")
op = unaryop_map[type(node.op)]
operand = yield precedence, node.operand
return f"{op}{operand}"
Expand All @@ -371,35 +380,35 @@ def unparse_UnaryOp(node: UnaryOp) -> unparse_gen_t:
def unparse_List(node: List) -> unparse_gen_t:
elts = []
for item in node.elts:
elts.append((yield PREC_EXPR, item))
elts.append((yield PREC_EXPR_SLOT, item))
return f"[{','.join(elts)}]"


def unparse_Set(node: Set) -> unparse_gen_t:
elts = []
for item in node.elts:
elts.append((yield PREC_EXPR, item))
elts.append((yield PREC_EXPR_SLOT, item))
return f"{{{','.join(elts)}}}"


def unparse_Dict(node: Dict) -> unparse_gen_t:
item = []
for k, v in zip(node.keys, node.values):
if k is not None:
value = yield PREC_EXPR, v
key = yield PREC_EXPR, k
value = yield PREC_EXPR_SLOT, v
key = yield PREC_EXPR_SLOT, k
item.append(f"{key}:{value}")
else:
# **value requires a smaller precedence value
value = yield PREC_ATTR, v
value = yield PREC_STARRED_SLOT, v
item.append(f"**{value}")
return f"{{{','.join(item)}}}"


def unparse_Tuple(node: Tuple) -> unparse_gen_t:
elts = []
for item in node.elts:
elts.append((yield PREC_EXPR, item))
elts.append((yield PREC_EXPR_SLOT, item))
if len(elts) == 1:
return f"({elts[0]},)"
return f"({','.join(elts)})"
Expand All @@ -416,12 +425,12 @@ def unparse_Compare(node: Compare) -> unparse_gen_t:


def unparse_NamedExpr(node: NamedExpr) -> unparse_gen_t:
value = yield PREC_EXPR, node.value
value = yield PREC_EXPR_SLOT, node.value
return f"{node.target.id}:={value}"


def unparse_Lambda(node: Lambda) -> unparse_gen_t:
body = yield PREC_LAMBDA, node.body
body = yield PREC_EXPR_SLOT, node.body
arg_def_list = []
default: expr | None

Expand All @@ -436,7 +445,7 @@ def unparse_Lambda(node: Lambda) -> unparse_gen_t:
for default in reversed(node.args.defaults):
ind -= 1
if default is not None:
arg_def_list[ind] += f"={yield PREC_EXPR,default}"
arg_def_list[ind] += f"={yield PREC_EXPR_SLOT,default}"

if node.args.posonlyargs:
arg_def_list.insert(len(node.args.posonlyargs), "/")
Expand All @@ -453,7 +462,7 @@ def unparse_Lambda(node: Lambda) -> unparse_gen_t:
kw_list.append(kwonly.arg)
for ind, default in enumerate(node.args.kw_defaults):
if default is not None:
kw_list[ind] += f"={yield PREC_EXPR,default}"
kw_list[ind] += f"={yield PREC_EXPR_SLOT,default}"
arg_def_list.extend(kw_list)

# handle kwarg
Expand All @@ -471,7 +480,7 @@ def _unparse_comprehensions(generators: list[comprehension]) -> unparse_gen_t:
for gen in generators:
_async = "" if not gen.is_async else "async "
_iter = yield PREC_COMPREHENSION_SLOT_ITER, gen.iter
target = yield PREC_EXPR, gen.target
target = yield PREC_EXPR_SLOT, gen.target
if_list = []
for test in gen.ifs:
if_list.append((yield PREC_COMPREHENSION_SLOT_ITER, test))
Expand All @@ -483,26 +492,26 @@ def _unparse_comprehensions(generators: list[comprehension]) -> unparse_gen_t:


def unparse_ListComp(node: ListComp) -> unparse_gen_t:
elt = yield PREC_EXPR, node.elt
elt = yield PREC_EXPR_SLOT, node.elt
generators = yield from _unparse_comprehensions(node.generators)
return f"[{elt} {generators}]"


def unparse_GeneratorExp(node: GeneratorExp) -> unparse_gen_t:
elt = yield PREC_EXPR, node.elt
elt = yield PREC_EXPR_SLOT, node.elt
generators = yield from _unparse_comprehensions(node.generators)
return f"{elt} {generators}"


def unparse_SetComp(node: SetComp) -> unparse_gen_t:
elt = yield PREC_EXPR, node.elt
elt = yield PREC_EXPR_SLOT, node.elt
generators = yield from _unparse_comprehensions(node.generators)
return f"{{{elt} {generators}}}"


def unparse_DictComp(node: DictComp) -> unparse_gen_t:
key = yield PREC_EXPR, node.key
value = yield PREC_EXPR, node.value
key = yield PREC_EXPR_SLOT, node.key
value = yield PREC_EXPR_SLOT, node.value
generators = yield from _unparse_comprehensions(node.generators)
return f"{{{key}:{value} {generators}}}"

Expand All @@ -517,12 +526,12 @@ def unparse_IfExp(node: IfExp) -> unparse_gen_t:
def unparse_Yield(node: Yield) -> unparse_gen_t:
if node.value is None:
return "yield"
value = yield PREC_EXPR, node.value
value = yield PREC_EXPR_SLOT, node.value
return f"yield {value}"


def unparse_YieldFrom(node: YieldFrom) -> unparse_gen_t:
value = yield PREC_EXPR, node.value
value = yield PREC_EXPR_SLOT, node.value
return f"yield from {value}"


Expand Down Expand Up @@ -603,7 +612,7 @@ def __init__(self, outer_precedence: prec_t, node: expr, outer_str_qm: str):

def expr_unparse(node: expr) -> str:
stack: list[_Node] = []
stack.append(_Node(PREC_EXPR, node, '"'))
stack.append(_Node(PREC_EXPR_SLOT, node, '"'))
converted: str | None = None
while stack:
try:
Expand Down
14 changes: 10 additions & 4 deletions oneliner_tests/unparser_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ def test_Constant_str(self):
self.assertUnparseConsist('"ef\'\\"gh"')
self.assertUnparseConsist('"ji\\\'kl"')
self.assertUnparseConsist("'mn\\\"op'")
self.assertUnparseConsist('b"hello bytes"')
with self.subTest("JoinedStr"):
self.assertUnparseConsist("f'hello{world:fmt}hello{0}hello{awa:{q}}}}'")
if sys.version_info >= (3, 12):
Expand Down Expand Up @@ -123,6 +124,7 @@ def test_Set(self):

def test_SetComp(self):
self.assertUnparseConsist("{a for b in c}")
self.assertUnparseConsist("{a for b, c, d in e}")
self.assertUnparseConsist("{a for b in c if d}")
self.assertUnparseConsist("{a for b in c if d for e in f}")
self.assertUnparseConsist("{a async for b in c if d for e in f}")
Expand Down Expand Up @@ -161,6 +163,7 @@ def test_UnaryOp(self):
self.assertUnparseConsist("~a")
self.assertUnparseConsist("-a")
self.assertUnparseConsist("+a")
self.assertUnparseConsist("not a")

def test_BinOp_Pow(self):
self.assertUnparseConsist("a**b")
Expand Down Expand Up @@ -219,7 +222,7 @@ def test_YieldFrom(self):

class TestComplexExprUnparse(_TestExprUnparse):
slots = {
"List": ["[a,{0},a]"],
"List": ["[a,{0},a]", "[a, *{0}]"],
"Tuple": ["(a,{0},a)"],
"Set": ["{{a,{0},a}}"],
"Dict": ["{{a:b,{0}:e,c:d}}", "{{a:b,e:{0},c:d}}", "{{a:b,**{0}}}"],
Expand All @@ -245,7 +248,7 @@ class TestComplexExprUnparse(_TestExprUnparse):
"USub": ["-{0}"],
"Invert": ["~{0}"],
"Not": ["not {0}"],
"Compare": ["a>{0}", "{0}>a"],
"Compare": ["a>{0}", "{0}>a", "{0} is a", "a is {0}", "{0} in a", "a in {0}"],
"IfExp": ["{0} if b else c", "a if {0} else c", "a if b else {0}"],
"Lambda": ["lambda:{0}", "lambda kw={0}:a"],
"Call": ["{0}()", "a({0})", "a(kw={0})", "a(*{0})", "a(**{0})", "a(a,b,{0},c)"],
Expand All @@ -260,10 +263,11 @@ class TestComplexExprUnparse(_TestExprUnparse):
"Const": ["114514", "0.5", "1j", "'x'"],
"Name": ["a"],
"List": ["[a]"],
"Tuple": ["(a,)"],
"Tuple": ["()", "(a,)", "(a,b,c)"],
"Set": ["{a}"],
"Dict": ["{a:b}"],
"Attr": ["a.b"],
"Subscript": ["a[b]"],
"Pow": ["a**b"],
"Mult": ["a*b"],
"MatMult": ["a@b"],
Expand All @@ -283,7 +287,7 @@ class TestComplexExprUnparse(_TestExprUnparse):
"USub": ["-a"],
"Invert": ["~a"],
"Not": ["not a"],
"Compare": ["a>b"],
"Compare": ["a>b", "a is b", "a in b"],
"IfExp": ["a if b else c"],
"Lambda": ["lambda:a"],
"Call": ["print()"],
Expand All @@ -295,6 +299,8 @@ class TestComplexExprUnparse(_TestExprUnparse):
"SetComp": ["{a for b in c}"],
"GeneratorExp": ["(a for b in c)"],
"DictComp": ["{a:b for c in d}"],
# Slice/Starred are ignored
# because they can only be inside of Subscript/Call
}

# generate test cases
Expand Down

0 comments on commit fc6c81b

Please sign in to comment.