Skip to content

Commit

Permalink
don't report reportUnnecessaryComparison whern using assert_never
Browse files Browse the repository at this point in the history
… in an intentionally unreachable case in a match statement
  • Loading branch information
DetachHead committed Aug 19, 2024
1 parent 562a2c0 commit 70c15c1
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 3 deletions.
39 changes: 37 additions & 2 deletions packages/pyright-internal/src/analyzer/patternMatching.ts
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ import {
NeverType,
Type,
TypeBase,
TypeCategory,
TypedDictEntry,
UnknownType,
combineTypes,
Expand Down Expand Up @@ -169,9 +170,43 @@ export function narrowTypeBasedOnPattern(

// Determines whether this pattern (or part of the pattern) in
// this case statement will never be matched.
export function checkForUnusedPattern(evaluator: TypeEvaluator, pattern: PatternAtomNode, subjectType: Type): void {
export function checkForUnusedPattern(
evaluator: TypeEvaluator,
pattern: PatternAtomNode,
subjectType: Type,
subjectExpression: ExpressionNode
): void {
if (isNever(subjectType)) {
reportUnnecessaryPattern(evaluator, pattern, subjectType);
// don't report unnecessary pattern if the suite contains an assert_never, because that means it's intentional
const parentNode = pattern.parent;
if (
!parentNode ||
parentNode.nodeType !== ParseNodeType.Case ||
!parentNode.d.suite.d.statements.some(
(statement) =>
statement.nodeType === ParseNodeType.StatementList &&
// this check is probably overkill and we could instead just special-case `typing.assert_never`, but we want to support
// user-defined "assert never" functions to be more flexible. note that there is a very similar check for the same thing
// in _walkStatementsAndReportUnreachable, but this one doesn't check the type of the argument, but rather whether or not
// it's the same variable that was being matched against.
statement.d.statements.some(
(statement) =>
statement.nodeType === ParseNodeType.Call &&
evaluator.matchCallArgsToParams(statement)?.find((result) =>
result.match.argParams.some(
(param) =>
//check the function parameter type:
param.paramType.category === TypeCategory.Never &&
// check that the argument is the same symbol being matched against
param.argument.valueExpression &&
isMatchingExpression(param.argument.valueExpression, subjectExpression)
)
)
)
)
) {
reportUnnecessaryPattern(evaluator, pattern, subjectType);
}
} else if (pattern.nodeType === ParseNodeType.PatternAs && pattern.d.orPatterns.length > 1) {
// Check each of the or patterns separately.
pattern.d.orPatterns.forEach((orPattern) => {
Expand Down
2 changes: 1 addition & 1 deletion packages/pyright-internal/src/analyzer/typeEvaluator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19437,7 +19437,7 @@ export function createTypeEvaluator(
if (caseStatement === node) {
if (fileInfo.diagnosticRuleSet.reportUnnecessaryComparison !== 'none') {
if (!subjectTypeResult.isIncomplete) {
checkForUnusedPattern(evaluatorInterface, node.d.pattern, subjectType);
checkForUnusedPattern(evaluatorInterface, node.d.pattern, subjectType, node.parent.d.expr);
}
}
break;
Expand Down
17 changes: 17 additions & 0 deletions packages/pyright-internal/src/tests/patternMatching.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import { ConfigOptions } from '../common/configOptions';
import { DiagnosticRule } from '../common/diagnosticRules';
import { Uri } from '../common/uri/uri';
import { typeAnalyzeSampleFiles, validateResultsButBased } from './testUtils';

test('assert_never', () => {
const configOptions = new ConfigOptions(Uri.empty());
configOptions.diagnosticRuleSet.reportUnnecessaryComparison = 'error';

const analysisResults = typeAnalyzeSampleFiles(['reportUnnecessaryComparison.py'], configOptions);
validateResultsButBased(analysisResults, {
errors: [
{ code: DiagnosticRule.reportUnnecessaryComparison, line: 16 },
{ code: DiagnosticRule.reportUnnecessaryComparison, line: 24 },
],
});
});
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from typing import Never, assert_never

def _(something_else: Never, subject: int | str):
match subject:
case int():
...
case str():
...
case _: # no error, intentional due to assert_never on the subject
assert_never(subject)

match subject:
case int():
...
case str():
...
case _: # error, the argument passed to assert_never is unrelated to the match subject
assert_never(something_else)

match subject:
case int():
...
case str():
...
case _: # error, not an assert_never
print(subject)

0 comments on commit 70c15c1

Please sign in to comment.