Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

S4050: Promote to SonarWay #9630

Merged
merged 1 commit into from
Aug 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,34 +18,21 @@
* Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
*/

namespace SonarAnalyzer.Rules.CSharp
namespace SonarAnalyzer.Rules.CSharp;

[DiagnosticAnalyzer(LanguageNames.CSharp)]
public sealed class OperatorsShouldBeOverloadedConsistently : SonarDiagnosticAnalyzer
{
[DiagnosticAnalyzer(LanguageNames.CSharp)]
public sealed class OperatorsShouldBeOverloadedConsistently : SonarDiagnosticAnalyzer
{
private const string DiagnosticId = "S4050";
private const string MessageFormat = "Provide an implementation for: {0}.";
private const string DiagnosticId = "S4050";
private const string MessageFormat = "Provide an implementation for: {0}.";

private static readonly DiagnosticDescriptor Rule =
DescriptorFactory.Create(DiagnosticId, MessageFormat);
public override ImmutableArray<DiagnosticDescriptor> SupportedDiagnostics { get; } = ImmutableArray.Create(Rule);
private static readonly DiagnosticDescriptor Rule =
DescriptorFactory.Create(DiagnosticId, MessageFormat);

private static class MethodName
{
public const string OperatorPlus = "operator+";
public const string OperatorMinus = "operator-";
public const string OperatorMultiply = "operator*";
public const string OperatorDivide = "operator/";
public const string OperatorReminder = "operator%";
public const string OperatorEquals = "operator==";
public const string OperatorNotEquals = "operator!=";

public const string ObjectEquals = "Object.Equals";
public const string ObjectGetHashCode = "Object.GetHashCode";
}
public override ImmutableArray<DiagnosticDescriptor> SupportedDiagnostics { get; } = ImmutableArray.Create(Rule);

protected override void Initialize(SonarAnalysisContext context) =>
context.RegisterNodeAction(c =>
protected override void Initialize(SonarAnalysisContext context) =>
context.RegisterNodeAction(c =>
{
var classDeclaration = (ClassDeclarationSyntax)c.Node;
var classSymbol = (INamedTypeSymbol)c.ContainingSymbol;
Expand All @@ -65,71 +52,82 @@ protected override void Initialize(SonarAnalysisContext context) =>
// This rule is not applicable for records, as for records it is not possible to override the == operator.
SyntaxKind.ClassDeclaration);

private static IEnumerable<string> FindMissingMethods(INamedTypeSymbol classSymbol)
private static IEnumerable<string> FindMissingMethods(INamedTypeSymbol classSymbol)
{
var implementedMethods = GetImplementedMethods(classSymbol).ToHashSet();
var requiredMethods = new HashSet<string>();

if (implementedMethods.Contains(MethodName.OperatorPlus)
|| implementedMethods.Contains(MethodName.OperatorMinus)
|| implementedMethods.Contains(MethodName.OperatorMultiply)
|| implementedMethods.Contains(MethodName.OperatorDivide)
|| implementedMethods.Contains(MethodName.OperatorRemainder))
{
requiredMethods.Add(MethodName.OperatorEquals);
requiredMethods.Add(MethodName.OperatorNotEquals);
requiredMethods.Add(MethodName.ObjectEquals);
requiredMethods.Add(MethodName.ObjectGetHashCode);
}

if (implementedMethods.Contains(MethodName.OperatorEquals))
{
requiredMethods.Add(MethodName.ObjectEquals);
requiredMethods.Add(MethodName.ObjectGetHashCode);
}

if (implementedMethods.Contains(MethodName.OperatorNotEquals))
{
requiredMethods.Add(MethodName.ObjectEquals);
requiredMethods.Add(MethodName.ObjectGetHashCode);
}

return requiredMethods.Except(implementedMethods);
}

private static IEnumerable<string> GetImplementedMethods(INamedTypeSymbol classSymbol)
{
foreach (var member in classSymbol.GetMembers().OfType<IMethodSymbol>().Where(x => !x.IsConstructor()))
{
var implementedMethods = GetImplementedMethods(classSymbol).ToHashSet();
var requiredMethods = new HashSet<string>();

if (implementedMethods.Contains(MethodName.OperatorPlus)
|| implementedMethods.Contains(MethodName.OperatorMinus)
|| implementedMethods.Contains(MethodName.OperatorMultiply)
|| implementedMethods.Contains(MethodName.OperatorDivide)
|| implementedMethods.Contains(MethodName.OperatorReminder))
if (ImplementedOperator(member) is { } name)
{
requiredMethods.Add(MethodName.OperatorEquals);
requiredMethods.Add(MethodName.OperatorNotEquals);
requiredMethods.Add(MethodName.ObjectEquals);
requiredMethods.Add(MethodName.ObjectGetHashCode);
yield return name;
}

if (implementedMethods.Contains(MethodName.OperatorEquals))
else if (KnownMethods.IsObjectEquals(member))
{
requiredMethods.Add(MethodName.OperatorNotEquals);
requiredMethods.Add(MethodName.ObjectEquals);
requiredMethods.Add(MethodName.ObjectGetHashCode);
yield return MethodName.ObjectEquals;
}

if (implementedMethods.Contains(MethodName.OperatorNotEquals))
else if (KnownMethods.IsObjectGetHashCode(member))
{
requiredMethods.Add(MethodName.OperatorEquals);
requiredMethods.Add(MethodName.ObjectEquals);
requiredMethods.Add(MethodName.ObjectGetHashCode);
yield return MethodName.ObjectGetHashCode;
}

return requiredMethods.Except(implementedMethods);
}
}

private static IEnumerable<string> GetImplementedMethods(INamedTypeSymbol classSymbol)
private static string ImplementedOperator(IMethodSymbol member) =>
member switch
{
foreach (var member in classSymbol.GetMembers().OfType<IMethodSymbol>().Where(x => !x.IsConstructor()))
{
if (ImplementedOperator(member) is { } name)
{
yield return name;
}
else if (KnownMethods.IsObjectEquals(member))
{
yield return MethodName.ObjectEquals;
}
else if (KnownMethods.IsObjectGetHashCode(member))
{
yield return MethodName.ObjectGetHashCode;
}
}
}
{ MethodKind: not MethodKind.UserDefinedOperator } => null,
_ when KnownMethods.IsOperatorBinaryPlus(member) => MethodName.OperatorPlus,
_ when KnownMethods.IsOperatorBinaryMinus(member) => MethodName.OperatorMinus,
_ when KnownMethods.IsOperatorBinaryMultiply(member) => MethodName.OperatorMultiply,
_ when KnownMethods.IsOperatorBinaryDivide(member) => MethodName.OperatorDivide,
_ when KnownMethods.IsOperatorBinaryModulus(member) => MethodName.OperatorRemainder,
_ when KnownMethods.IsOperatorEquals(member) => MethodName.OperatorEquals,
_ when KnownMethods.IsOperatorNotEquals(member) => MethodName.OperatorNotEquals,
_ => null
};

private static string ImplementedOperator(IMethodSymbol member) =>
member switch
{
{ MethodKind: not MethodKind.UserDefinedOperator } => null,
_ when KnownMethods.IsOperatorBinaryPlus(member) => MethodName.OperatorPlus,
_ when KnownMethods.IsOperatorBinaryMinus(member) => MethodName.OperatorMinus,
_ when KnownMethods.IsOperatorBinaryMultiply(member) => MethodName.OperatorMultiply,
_ when KnownMethods.IsOperatorBinaryDivide(member) => MethodName.OperatorDivide,
_ when KnownMethods.IsOperatorBinaryModulus(member) => MethodName.OperatorReminder,
_ when KnownMethods.IsOperatorEquals(member) => MethodName.OperatorEquals,
_ when KnownMethods.IsOperatorNotEquals(member) => MethodName.OperatorNotEquals,
_ => null
};
private static class MethodName
{
public const string OperatorPlus = "operator+";
public const string OperatorMinus = "operator-";
public const string OperatorMultiply = "operator*";
public const string OperatorDivide = "operator/";
public const string OperatorRemainder = "operator%";
public const string OperatorEquals = "operator==";
public const string OperatorNotEquals = "operator!=";

public const string ObjectEquals = "Object.Equals";
public const string ObjectGetHashCode = "Object.GetHashCode";
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,7 @@ public class Foo4
public override int GetHashCode() => 0;
}

public class Foo5
// ^^^^ Noncompliant {{Provide an implementation for: 'operator=='.}}
public class Foo5 // Compliant - Covered by CS0216
{
public static object operator !=(Foo5 a, Foo5 b) => new object(); // Error [CS0216] - requires == operator

Expand All @@ -84,7 +83,7 @@ public class Foo5
}

public class Foo6
// ^^^^ Noncompliant {{Provide an implementation for: 'operator!=', 'Object.Equals' and 'Object.GetHashCode'.}}
// ^^^^ Noncompliant {{Provide an implementation for: 'Object.Equals' and 'Object.GetHashCode'.}}
{
public static object operator ==(Foo6 a, Foo6 b) => new object(); // Error [CS0216] - requires != operator
}
Expand Down
Loading