From f51513565bd10df02768a23b97e773738c10ebca Mon Sep 17 00:00:00 2001 From: huandy Date: Mon, 9 Dec 2024 13:06:20 -0500 Subject: [PATCH 1/2] Get empty dict issue --- mypy/checkexpr.py | 26 ++++++++++++++++++++++++++ mypy/checkmember.py | 19 +++++++++++++++++-- test-data/unit/check-dict-get.test | 19 +++++++++++++++++++ 3 files changed, 62 insertions(+), 2 deletions(-) create mode 100644 test-data/unit/check-dict-get.test diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 577576a4e5f8..14d120646ec1 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -473,6 +473,32 @@ def module_type(self, node: MypyFile) -> Instance: def visit_call_expr(self, e: CallExpr, allow_none_return: bool = False) -> Type: """Type check a call expression.""" + if ( + self.refers_to_typeddict(e.callee) + or isinstance(e.callee, IndexExpr) + and self.refers_to_typeddict(e.callee.base) + ): + typeddict_callable = get_proper_type(self.accept(e.callee, is_callee=True)) + if isinstance(typeddict_callable, CallableType): + typeddict_type = get_proper_type(typeddict_callable.ret_type) + assert isinstance(typeddict_type, TypedDictType) + return self.check_typeddict_call( + e, typeddict_type, typeddict_callable + ) + + # Add logic to handle the `get` method + if isinstance(e.callee, MemberExpr) and e.callee.name == 'get': + dict_type = self.accept(e.callee.expr) + if isinstance(dict_type, Instance) and dict_type.type.fullname == 'builtins.dict': + key_type = self.accept(e.args[0]) + if len(e.args) == 2: + default_type = self.accept(e.args[1]) + return UnionType.make_union([dict_type.args[1], default_type]) + return UnionType.make_union([dict_type.args[1], NoneType()]) + elif isinstance(dict_type, Instance) and dict_type.type.fullname == 'builtins.dict': + # Handle empty dictionary case + return AnyType(TypeOfAny.special_form) + if e.analyzed: if isinstance(e.analyzed, NamedTupleExpr) and not e.analyzed.is_typed: # Type check the arguments, but ignore the results. This relies diff --git a/mypy/checkmember.py b/mypy/checkmember.py index 9dc8d5475b1a..bc0ea1eff9ce 100644 --- a/mypy/checkmember.py +++ b/mypy/checkmember.py @@ -3,7 +3,7 @@ from __future__ import annotations from typing import TYPE_CHECKING, Callable, Sequence, cast - +from mypy.nodes import ARG_POS, ARG_OPT from mypy import meet, message_registry, subtypes from mypy.erasetype import erase_typevars from mypy.expandtype import ( @@ -203,6 +203,11 @@ def analyze_member_access( no_deferral=no_deferral, is_self=is_self, ) + + if name == 'get' and isinstance(typ, Instance) and typ.type.fullname == 'builtins.dict': + # Handle overload resolution for dict.get + return analyze_dict_get(typ, context) + result = _analyze_member_access(name, typ, mx, override_info) possible_literal = get_proper_type(result) if ( @@ -214,7 +219,17 @@ def analyze_member_access( else: return result - +def analyze_dict_get(self, typ: Instance, context: Context) -> Type: + key_type = typ.args[0] + value_type = typ.args[1] + return CallableType( + [key_type, value_type], + [ARG_POS, ARG_OPT], + [None, None], + UnionType.make_union([value_type, NoneType()]), + self.named_type('builtins.function') + ) + def _analyze_member_access( name: str, typ: Type, mx: MemberContext, override_info: TypeInfo | None = None ) -> Type: diff --git a/test-data/unit/check-dict-get.test b/test-data/unit/check-dict-get.test new file mode 100644 index 000000000000..84e4dd3432cb --- /dev/null +++ b/test-data/unit/check-dict-get.test @@ -0,0 +1,19 @@ +[case testDictGetEmpty] +def testDictGetEmpty() -> None: + x = {}.get("x") + reveal_type(x) # N: Revealed type is "None" + +[case testDictGetWithDefault] +def testDictGetWithDefault() -> None: + x = {}.get("x", 42) + reveal_type(x) # N: Revealed type is "int" + +[case testDictGetExistingKey] +def testDictGetExistingKey() -> None: + x = {"a": 1}.get("a") + reveal_type(x) # N: Revealed type is "int" + +[case testDictGetMissingKey] +def testDictGetMissingKey() -> None: + x = {"a": 1}.get("b") + reveal_type(x) # N: Revealed type is "None" \ No newline at end of file From 50706fdf0a46176f5730f4fde9232c1e84a37f5f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 9 Dec 2024 18:47:53 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- mypy/checkexpr.py | 14 ++++++-------- mypy/checkmember.py | 15 +++++++++------ test-data/unit/check-dict-get.test | 2 +- 3 files changed, 16 insertions(+), 15 deletions(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 14d120646ec1..fbbcc4c85412 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -482,23 +482,21 @@ def visit_call_expr(self, e: CallExpr, allow_none_return: bool = False) -> Type: if isinstance(typeddict_callable, CallableType): typeddict_type = get_proper_type(typeddict_callable.ret_type) assert isinstance(typeddict_type, TypedDictType) - return self.check_typeddict_call( - e, typeddict_type, typeddict_callable - ) - + return self.check_typeddict_call(e, typeddict_type, typeddict_callable) + # Add logic to handle the `get` method - if isinstance(e.callee, MemberExpr) and e.callee.name == 'get': + if isinstance(e.callee, MemberExpr) and e.callee.name == "get": dict_type = self.accept(e.callee.expr) - if isinstance(dict_type, Instance) and dict_type.type.fullname == 'builtins.dict': + if isinstance(dict_type, Instance) and dict_type.type.fullname == "builtins.dict": key_type = self.accept(e.args[0]) if len(e.args) == 2: default_type = self.accept(e.args[1]) return UnionType.make_union([dict_type.args[1], default_type]) return UnionType.make_union([dict_type.args[1], NoneType()]) - elif isinstance(dict_type, Instance) and dict_type.type.fullname == 'builtins.dict': + elif isinstance(dict_type, Instance) and dict_type.type.fullname == "builtins.dict": # Handle empty dictionary case return AnyType(TypeOfAny.special_form) - + if e.analyzed: if isinstance(e.analyzed, NamedTupleExpr) and not e.analyzed.is_typed: # Type check the arguments, but ignore the results. This relies diff --git a/mypy/checkmember.py b/mypy/checkmember.py index bc0ea1eff9ce..fabf2363b4b5 100644 --- a/mypy/checkmember.py +++ b/mypy/checkmember.py @@ -3,7 +3,7 @@ from __future__ import annotations from typing import TYPE_CHECKING, Callable, Sequence, cast -from mypy.nodes import ARG_POS, ARG_OPT + from mypy import meet, message_registry, subtypes from mypy.erasetype import erase_typevars from mypy.expandtype import ( @@ -14,6 +14,7 @@ from mypy.maptype import map_instance_to_supertype from mypy.messages import MessageBuilder from mypy.nodes import ( + ARG_OPT, ARG_POS, ARG_STAR, ARG_STAR2, @@ -203,11 +204,11 @@ def analyze_member_access( no_deferral=no_deferral, is_self=is_self, ) - - if name == 'get' and isinstance(typ, Instance) and typ.type.fullname == 'builtins.dict': + + if name == "get" and isinstance(typ, Instance) and typ.type.fullname == "builtins.dict": # Handle overload resolution for dict.get return analyze_dict_get(typ, context) - + result = _analyze_member_access(name, typ, mx, override_info) possible_literal = get_proper_type(result) if ( @@ -219,6 +220,7 @@ def analyze_member_access( else: return result + def analyze_dict_get(self, typ: Instance, context: Context) -> Type: key_type = typ.args[0] value_type = typ.args[1] @@ -227,9 +229,10 @@ def analyze_dict_get(self, typ: Instance, context: Context) -> Type: [ARG_POS, ARG_OPT], [None, None], UnionType.make_union([value_type, NoneType()]), - self.named_type('builtins.function') + self.named_type("builtins.function"), ) - + + def _analyze_member_access( name: str, typ: Type, mx: MemberContext, override_info: TypeInfo | None = None ) -> Type: diff --git a/test-data/unit/check-dict-get.test b/test-data/unit/check-dict-get.test index 84e4dd3432cb..7dd6b156bef4 100644 --- a/test-data/unit/check-dict-get.test +++ b/test-data/unit/check-dict-get.test @@ -16,4 +16,4 @@ def testDictGetExistingKey() -> None: [case testDictGetMissingKey] def testDictGetMissingKey() -> None: x = {"a": 1}.get("b") - reveal_type(x) # N: Revealed type is "None" \ No newline at end of file + reveal_type(x) # N: Revealed type is "None"