Skip to content

Commit

Permalink
gh-125710: [Enum] fix hashable<->nonhashable comparisons for member v…
Browse files Browse the repository at this point in the history
…alues (GH-125735)
  • Loading branch information
ethanfurman authored Oct 22, 2024
1 parent 079875e commit aaed91c
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 6 deletions.
26 changes: 20 additions & 6 deletions Lib/enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,8 @@ def __set_name__(self, enum_class, member_name):
# to the map, and by-value lookups for this value will be
# linear.
enum_class._value2member_map_.setdefault(value, enum_member)
if value not in enum_class._hashable_values_:
enum_class._hashable_values_.append(value)
except TypeError:
# keep track of the value in a list so containment checks are quick
enum_class._unhashable_values_.append(value)
Expand Down Expand Up @@ -538,7 +540,8 @@ def __new__(metacls, cls, bases, classdict, *, boundary=None, _simple=False, **k
classdict['_member_names_'] = []
classdict['_member_map_'] = {}
classdict['_value2member_map_'] = {}
classdict['_unhashable_values_'] = []
classdict['_hashable_values_'] = [] # for comparing with non-hashable types
classdict['_unhashable_values_'] = [] # e.g. frozenset() with set()
classdict['_unhashable_values_map_'] = {}
classdict['_member_type_'] = member_type
# now set the __repr__ for the value
Expand Down Expand Up @@ -748,7 +751,10 @@ def __contains__(cls, value):
try:
return value in cls._value2member_map_
except TypeError:
return value in cls._unhashable_values_
return (
value in cls._unhashable_values_ # both structures are lists
or value in cls._hashable_values_
)

def __delattr__(cls, attr):
# nicer error message when someone tries to delete an attribute
Expand Down Expand Up @@ -1166,8 +1172,11 @@ def __new__(cls, value):
pass
except TypeError:
# not there, now do long search -- O(n) behavior
for name, values in cls._unhashable_values_map_.items():
if value in values:
for name, unhashable_values in cls._unhashable_values_map_.items():
if value in unhashable_values:
return cls[name]
for name, member in cls._member_map_.items():
if value == member._value_:
return cls[name]
# still not found -- verify that members exist, in-case somebody got here mistakenly
# (such as via super when trying to override __new__)
Expand Down Expand Up @@ -1233,6 +1242,7 @@ def _add_value_alias_(self, value):
# to the map, and by-value lookups for this value will be
# linear.
cls._value2member_map_.setdefault(value, self)
cls._hashable_values_.append(value)
except TypeError:
# keep track of the value in a list so containment checks are quick
cls._unhashable_values_.append(value)
Expand Down Expand Up @@ -1763,6 +1773,7 @@ def convert_class(cls):
body['_member_names_'] = member_names = []
body['_member_map_'] = member_map = {}
body['_value2member_map_'] = value2member_map = {}
body['_hashable_values_'] = hashable_values = []
body['_unhashable_values_'] = unhashable_values = []
body['_unhashable_values_map_'] = {}
body['_member_type_'] = member_type = etype._member_type_
Expand Down Expand Up @@ -1826,7 +1837,7 @@ def convert_class(cls):
contained = value2member_map.get(member._value_)
except TypeError:
contained = None
if member._value_ in unhashable_values:
if member._value_ in unhashable_values or member.value in hashable_values:
for m in enum_class:
if m._value_ == member._value_:
contained = m
Expand All @@ -1846,6 +1857,7 @@ def convert_class(cls):
else:
enum_class._add_member_(name, member)
value2member_map[value] = member
hashable_values.append(value)
if _is_single_bit(value):
# not a multi-bit alias, record in _member_names_ and _flag_mask_
member_names.append(name)
Expand Down Expand Up @@ -1882,7 +1894,7 @@ def convert_class(cls):
contained = value2member_map.get(member._value_)
except TypeError:
contained = None
if member._value_ in unhashable_values:
if member._value_ in unhashable_values or member._value_ in hashable_values:
for m in enum_class:
if m._value_ == member._value_:
contained = m
Expand All @@ -1908,6 +1920,8 @@ def convert_class(cls):
# to the map, and by-value lookups for this value will be
# linear.
enum_class._value2member_map_.setdefault(value, member)
if value not in hashable_values:
hashable_values.append(value)
except TypeError:
# keep track of the value in a list so containment checks are quick
enum_class._unhashable_values_.append(value)
Expand Down
7 changes: 7 additions & 0 deletions Lib/test/test_enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -3460,6 +3460,13 @@ def test_empty_names(self):
self.assertRaisesRegex(TypeError, '.int. object is not iterable', Enum, 'bad_enum', names=0)
self.assertRaisesRegex(TypeError, '.int. object is not iterable', Enum, 'bad_enum', 0, type=int)

def test_nonhashable_matches_hashable(self): # issue 125710
class Directions(Enum):
DOWN_ONLY = frozenset({"sc"})
UP_ONLY = frozenset({"cs"})
UNRESTRICTED = frozenset({"sc", "cs"})
self.assertIs(Directions({"sc"}), Directions.DOWN_ONLY)


class TestOrder(unittest.TestCase):
"test usage of the `_order_` attribute"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
[Enum] fix hashable<->nonhashable comparisons for member values

0 comments on commit aaed91c

Please sign in to comment.