diff --git a/analyzers/src/SonarAnalyzer.CSharp/Rules/OperatorsShouldBeOverloadedConsistently.cs b/analyzers/src/SonarAnalyzer.CSharp/Rules/OperatorsShouldBeOverloadedConsistently.cs index d5232bda296..7a637fa7dc7 100644 --- a/analyzers/src/SonarAnalyzer.CSharp/Rules/OperatorsShouldBeOverloadedConsistently.cs +++ b/analyzers/src/SonarAnalyzer.CSharp/Rules/OperatorsShouldBeOverloadedConsistently.cs @@ -18,34 +18,21 @@ * Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. */ -namespace SonarAnalyzer.Rules.CSharp +namespace SonarAnalyzer.Rules.CSharp; + +[DiagnosticAnalyzer(LanguageNames.CSharp)] +public sealed class OperatorsShouldBeOverloadedConsistently : SonarDiagnosticAnalyzer { - [DiagnosticAnalyzer(LanguageNames.CSharp)] - public sealed class OperatorsShouldBeOverloadedConsistently : SonarDiagnosticAnalyzer - { - private const string DiagnosticId = "S4050"; - private const string MessageFormat = "Provide an implementation for: {0}."; + private const string DiagnosticId = "S4050"; + private const string MessageFormat = "Provide an implementation for: {0}."; - private static readonly DiagnosticDescriptor Rule = - DescriptorFactory.Create(DiagnosticId, MessageFormat); - public override ImmutableArray SupportedDiagnostics { get; } = ImmutableArray.Create(Rule); + private static readonly DiagnosticDescriptor Rule = + DescriptorFactory.Create(DiagnosticId, MessageFormat); - private static class MethodName - { - public const string OperatorPlus = "operator+"; - public const string OperatorMinus = "operator-"; - public const string OperatorMultiply = "operator*"; - public const string OperatorDivide = "operator/"; - public const string OperatorReminder = "operator%"; - public const string OperatorEquals = "operator=="; - public const string OperatorNotEquals = "operator!="; - - public const string ObjectEquals = "Object.Equals"; - public const string ObjectGetHashCode = "Object.GetHashCode"; - } + public override ImmutableArray SupportedDiagnostics { get; } = ImmutableArray.Create(Rule); - protected override void Initialize(SonarAnalysisContext context) => - context.RegisterNodeAction(c => + protected override void Initialize(SonarAnalysisContext context) => + context.RegisterNodeAction(c => { var classDeclaration = (ClassDeclarationSyntax)c.Node; var classSymbol = (INamedTypeSymbol)c.ContainingSymbol; @@ -65,71 +52,82 @@ protected override void Initialize(SonarAnalysisContext context) => // This rule is not applicable for records, as for records it is not possible to override the == operator. SyntaxKind.ClassDeclaration); - private static IEnumerable FindMissingMethods(INamedTypeSymbol classSymbol) + private static IEnumerable FindMissingMethods(INamedTypeSymbol classSymbol) + { + var implementedMethods = GetImplementedMethods(classSymbol).ToHashSet(); + var requiredMethods = new HashSet(); + + if (implementedMethods.Contains(MethodName.OperatorPlus) + || implementedMethods.Contains(MethodName.OperatorMinus) + || implementedMethods.Contains(MethodName.OperatorMultiply) + || implementedMethods.Contains(MethodName.OperatorDivide) + || implementedMethods.Contains(MethodName.OperatorRemainder)) + { + requiredMethods.Add(MethodName.OperatorEquals); + requiredMethods.Add(MethodName.OperatorNotEquals); + requiredMethods.Add(MethodName.ObjectEquals); + requiredMethods.Add(MethodName.ObjectGetHashCode); + } + + if (implementedMethods.Contains(MethodName.OperatorEquals)) + { + requiredMethods.Add(MethodName.ObjectEquals); + requiredMethods.Add(MethodName.ObjectGetHashCode); + } + + if (implementedMethods.Contains(MethodName.OperatorNotEquals)) + { + requiredMethods.Add(MethodName.ObjectEquals); + requiredMethods.Add(MethodName.ObjectGetHashCode); + } + + return requiredMethods.Except(implementedMethods); + } + + private static IEnumerable GetImplementedMethods(INamedTypeSymbol classSymbol) + { + foreach (var member in classSymbol.GetMembers().OfType().Where(x => !x.IsConstructor())) { - var implementedMethods = GetImplementedMethods(classSymbol).ToHashSet(); - var requiredMethods = new HashSet(); - - if (implementedMethods.Contains(MethodName.OperatorPlus) - || implementedMethods.Contains(MethodName.OperatorMinus) - || implementedMethods.Contains(MethodName.OperatorMultiply) - || implementedMethods.Contains(MethodName.OperatorDivide) - || implementedMethods.Contains(MethodName.OperatorReminder)) + if (ImplementedOperator(member) is { } name) { - requiredMethods.Add(MethodName.OperatorEquals); - requiredMethods.Add(MethodName.OperatorNotEquals); - requiredMethods.Add(MethodName.ObjectEquals); - requiredMethods.Add(MethodName.ObjectGetHashCode); + yield return name; } - - if (implementedMethods.Contains(MethodName.OperatorEquals)) + else if (KnownMethods.IsObjectEquals(member)) { - requiredMethods.Add(MethodName.OperatorNotEquals); - requiredMethods.Add(MethodName.ObjectEquals); - requiredMethods.Add(MethodName.ObjectGetHashCode); + yield return MethodName.ObjectEquals; } - - if (implementedMethods.Contains(MethodName.OperatorNotEquals)) + else if (KnownMethods.IsObjectGetHashCode(member)) { - requiredMethods.Add(MethodName.OperatorEquals); - requiredMethods.Add(MethodName.ObjectEquals); - requiredMethods.Add(MethodName.ObjectGetHashCode); + yield return MethodName.ObjectGetHashCode; } - - return requiredMethods.Except(implementedMethods); } + } - private static IEnumerable GetImplementedMethods(INamedTypeSymbol classSymbol) + private static string ImplementedOperator(IMethodSymbol member) => + member switch { - foreach (var member in classSymbol.GetMembers().OfType().Where(x => !x.IsConstructor())) - { - if (ImplementedOperator(member) is { } name) - { - yield return name; - } - else if (KnownMethods.IsObjectEquals(member)) - { - yield return MethodName.ObjectEquals; - } - else if (KnownMethods.IsObjectGetHashCode(member)) - { - yield return MethodName.ObjectGetHashCode; - } - } - } + { MethodKind: not MethodKind.UserDefinedOperator } => null, + _ when KnownMethods.IsOperatorBinaryPlus(member) => MethodName.OperatorPlus, + _ when KnownMethods.IsOperatorBinaryMinus(member) => MethodName.OperatorMinus, + _ when KnownMethods.IsOperatorBinaryMultiply(member) => MethodName.OperatorMultiply, + _ when KnownMethods.IsOperatorBinaryDivide(member) => MethodName.OperatorDivide, + _ when KnownMethods.IsOperatorBinaryModulus(member) => MethodName.OperatorRemainder, + _ when KnownMethods.IsOperatorEquals(member) => MethodName.OperatorEquals, + _ when KnownMethods.IsOperatorNotEquals(member) => MethodName.OperatorNotEquals, + _ => null + }; - private static string ImplementedOperator(IMethodSymbol member) => - member switch - { - { MethodKind: not MethodKind.UserDefinedOperator } => null, - _ when KnownMethods.IsOperatorBinaryPlus(member) => MethodName.OperatorPlus, - _ when KnownMethods.IsOperatorBinaryMinus(member) => MethodName.OperatorMinus, - _ when KnownMethods.IsOperatorBinaryMultiply(member) => MethodName.OperatorMultiply, - _ when KnownMethods.IsOperatorBinaryDivide(member) => MethodName.OperatorDivide, - _ when KnownMethods.IsOperatorBinaryModulus(member) => MethodName.OperatorReminder, - _ when KnownMethods.IsOperatorEquals(member) => MethodName.OperatorEquals, - _ when KnownMethods.IsOperatorNotEquals(member) => MethodName.OperatorNotEquals, - _ => null - }; + private static class MethodName + { + public const string OperatorPlus = "operator+"; + public const string OperatorMinus = "operator-"; + public const string OperatorMultiply = "operator*"; + public const string OperatorDivide = "operator/"; + public const string OperatorRemainder = "operator%"; + public const string OperatorEquals = "operator=="; + public const string OperatorNotEquals = "operator!="; + + public const string ObjectEquals = "Object.Equals"; + public const string ObjectGetHashCode = "Object.GetHashCode"; } } diff --git a/analyzers/tests/SonarAnalyzer.Test/TestCases/OperatorsShouldBeOverloadedConsistently.cs b/analyzers/tests/SonarAnalyzer.Test/TestCases/OperatorsShouldBeOverloadedConsistently.cs index ae08e5db4ea..7499e3ac0d0 100644 --- a/analyzers/tests/SonarAnalyzer.Test/TestCases/OperatorsShouldBeOverloadedConsistently.cs +++ b/analyzers/tests/SonarAnalyzer.Test/TestCases/OperatorsShouldBeOverloadedConsistently.cs @@ -74,8 +74,7 @@ public class Foo4 public override int GetHashCode() => 0; } - public class Foo5 -// ^^^^ Noncompliant {{Provide an implementation for: 'operator=='.}} + public class Foo5 // Compliant - Covered by CS0216 { public static object operator !=(Foo5 a, Foo5 b) => new object(); // Error [CS0216] - requires == operator @@ -84,7 +83,7 @@ public class Foo5 } public class Foo6 -// ^^^^ Noncompliant {{Provide an implementation for: 'operator!=', 'Object.Equals' and 'Object.GetHashCode'.}} +// ^^^^ Noncompliant {{Provide an implementation for: 'Object.Equals' and 'Object.GetHashCode'.}} { public static object operator ==(Foo6 a, Foo6 b) => new object(); // Error [CS0216] - requires != operator }