Skip to content

Commit

Permalink
Add struct equality methods (#452)
Browse files Browse the repository at this point in the history
This adds `Equals()`, and `GetHashCode()` methods and `==/!=` operators
for value structs and value unions. They can be enabled with the
`--generate-methods` flag. These methods fix the analyzer warning:
`CA1815: Override equals and operator equals on value types``
  • Loading branch information
trumully authored Nov 26, 2024
1 parent 08ecde1 commit 901ada2
Show file tree
Hide file tree
Showing 6 changed files with 122 additions and 5 deletions.
14 changes: 14 additions & 0 deletions src/FlatSharp.Compiler/SchemaModel/ValueStructSchemaModel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,10 @@ protected override void OnWriteCode(CodeWriter writer, CompileContext context)
this.Attributes.EmitAsMetadata(writer);
writer.AppendLine($"[System.Runtime.InteropServices.StructLayout(System.Runtime.InteropServices.LayoutKind.Explicit{size})]");
writer.AppendLine($"public partial struct {this.Name}");
if (context.Options.GenerateMethods)
{
writer.AppendLine($": System.IEquatable<{this.Name}>");
}
using (writer.WithBlock())
{
foreach (var field in this.fields)
Expand All @@ -151,6 +155,16 @@ protected override void OnWriteCode(CodeWriter writer, CompileContext context)

if (context.Options.GenerateMethods)
{
string typeNames = string.Join(", ", this.fields.Select(x => x.TypeName));
string names = string.Join(", ", this.fields.Select(x => x.Name));
string tupleType = this.fields.Count == 0 ? "System.ValueTuple" : this.fields.Count == 1 ? $"System.ValueTuple<{typeNames}>" : $"({typeNames})";
string tupleValue = this.fields.Count < 2 ? $"System.ValueTuple.Create({names})" : $"({names})";
writer.AppendLine($"public {tupleType} ToTuple() => {tupleValue};");
writer.AppendLine($"public override bool Equals(object? obj) => obj is {this.Name} other && this.Equals(other);");
writer.AppendLine($"public bool Equals({this.Name} other) => ToTuple().Equals(other.ToTuple());");
writer.AppendLine($"public static bool operator ==({this.Name} left, {this.Name} right) => left.Equals(right);");
writer.AppendLine($"public static bool operator !=({this.Name} left, {this.Name} right) => !left.Equals(right);");
writer.AppendLine("public override int GetHashCode() => ToTuple().GetHashCode();");
// This matches C# records
string fieldStrings = string.Join(", ", this.fields.Select(x => $"{x.Name} = {{this.{x.Name}}}"));
string fieldStringsWithSpace = this.fields.Count == 0 ? " " : $" {fieldStrings} ";
Expand Down
13 changes: 13 additions & 0 deletions src/FlatSharp.Compiler/SchemaModel/ValueUnionSchemaModel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,12 @@ protected override void OnWriteCode(CodeWriter writer, CompileContext context)
this.Attributes.EmitAsMetadata(writer);
writer.AppendLine("[System.Runtime.CompilerServices.CompilerGenerated]");
writer.AppendLine($"public {@unsafe} partial struct {this.Name} : {interfaceName}");

if (!generateUnsafeItems && context.Options.GenerateMethods)
{
writer.AppendLine($", System.IEquatable<{this.Name}>");
}

using (writer.WithBlock())
{
// Generate an internal type enum.
Expand Down Expand Up @@ -154,6 +160,13 @@ protected override void OnWriteCode(CodeWriter writer, CompileContext context)

if (!generateUnsafeItems && context.Options.GenerateMethods)
{
writer.AppendLine();
writer.AppendLine($"public override bool Equals(object? obj) => obj is {this.Name} other && this.Equals(other);");
writer.AppendLine($"public bool Equals({this.Name} other) => (this.Discriminator, this.value).Equals((other.Discriminator, other.value));");
writer.AppendLine($"public static bool operator ==({this.Name} left, {this.Name} right) => left.Equals(right);");
writer.AppendLine($"public static bool operator !=({this.Name} left, {this.Name} right) => !left.Equals(right);");
writer.AppendLine("public override int GetHashCode() => (this.Discriminator, this.value).GetHashCode();");
writer.AppendLine();
string item = this.union.Values.Count == 0 ? " " : $" this.value ";
writer.AppendLine($"public override string ToString() => $\"{this.Name} {{{{ {{{item}}} }}}}\";");
}
Expand Down
6 changes: 3 additions & 3 deletions src/NuGet.config
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
<?xml version="1.0" encoding="utf-8"?>
<?xml version="1.0" encoding="utf-8"?>
<configuration>
<packageSources>
<!-- MSTest early access packages. See: https://aka.ms/mstest/preview -->
<add key="test-tools" value="https://pkgs.dev.azure.com/dnceng/public/_packaging/test-tools/nuget/v3/index.json" />
<add key="nuget.org" value="https://api.nuget.org/v3/index.json" />
</packageSources>
<packageSourceMapping>
<!-- key value for <packageSource> should match key values from <packageSources> element -->
Expand All @@ -14,5 +15,4 @@
<package pattern="Microsoft.Testing.*" />
</packageSource>
</packageSourceMapping>
</configuration>

</configuration>
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
<ItemGroup>
<FlatSharpSchema Include="**\*.fbs" />
<FlatSharpSchema Remove="ToStringMethods/ToStringMethods.fbs" />
<FlatSharpSchema Remove="StructEquality/StructEquality.fbs" />
</ItemGroup>

<Target Name="FlatSharpFbsCompileToString" BeforeTargets="ResolveAssemblyReferences">
Expand All @@ -67,7 +68,7 @@
<PropertyGroup>
<FlatSharpOutput>$(IntermediateOutputPath)</FlatSharpOutput>
<FlatSharpOutput Condition=" '$(FlatSharpMutationTestingMode)' == 'true' ">$(MSBuildProjectDirectory)/</FlatSharpOutput>
<FlatSharpOutput>$(FlatSharpOutput)ToStringMethods</FlatSharpOutput>
<FlatSharpOutput>$(FlatSharpOutput)WithGenerateMethodsOption</FlatSharpOutput>
</PropertyGroup>

<MakeDir Directories="$(FlatSharpOutput)" Condition="!Exists('$(FlatSharpOutput)')" />
Expand All @@ -76,7 +77,7 @@
<PropertyGroup>
<CompilerPath>$([System.IO.Path]::GetFullPath('$(MSBuildThisFileDirectory)\..\tools\$(CompilerVersion)\FlatSharp.Compiler.dll'))</CompilerPath>
<CompilerPath Condition=" '$(FlatSharpCompilerPath)' != '' ">$(FlatSharpCompilerPath)</CompilerPath>
<CompilerCommand>dotnet &quot;$(CompilerPath)&quot; --input &quot;ToStringMethods/ToStringMethods.fbs&quot; --output &quot;$(FlatSharpOutput)&quot; --generate-methods</CompilerCommand>
<CompilerCommand>dotnet &quot;$(CompilerPath)&quot; --input &quot;ToStringMethods/ToStringMethods.fbs;StructEquality/StructEquality.fbs&quot; --output &quot;$(FlatSharpOutput)&quot; --generate-methods</CompilerCommand>
</PropertyGroup>

<Message Text="$(CompilerCommand)" Importance="high" />
Expand Down
77 changes: 77 additions & 0 deletions src/Tests/FlatSharpEndToEndTests/StructEquality/StructEquality.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
using FlatSharpEndToEndTests.Unions;

namespace FlatSharpEndToEndTests.StructEquality;

[TestClass]
public class StructEqualityTests
{
[TestMethod]
public void ToTupleMethod()
{
MixedValueStruct mixedValueStruct = new MixedValueStruct { X = 1, Y = 2f, Z = (ushort)5 };

var expectedValueStructTuple = (1, 2f, (ushort)5);

Assert.AreEqual(expectedValueStructTuple, mixedValueStruct.ToTuple());
}

[TestMethod]
public void EqualsMethod()
{
MixedValueStruct mixedValueStruct = new MixedValueStruct { X = 1, Y = 2f, Z = (ushort)5 };
StructUnion structUnion = new StructUnion(new A { V = 1 });

var mismatchedStruct = new MixedValueStruct { X = 20 };
var mismatchedUnion = new StructUnion(new B { V = 1 });
var mismatchedObject = "hi";

Assert.IsTrue(mixedValueStruct.Equals(new MixedValueStruct { X = 1, Y = 2f, Z = (ushort)5 }));
Assert.IsFalse(mixedValueStruct.Equals(mismatchedStruct));
Assert.IsFalse(mixedValueStruct.Equals(mismatchedObject));
Assert.IsTrue(structUnion.Equals(new StructUnion(new A { V = 1 })));
Assert.IsFalse(structUnion.Equals(mismatchedUnion));
Assert.IsFalse(mixedValueStruct.Equals(mismatchedObject));
}

[TestMethod]
public void EqualityOperator()
{
MixedValueStruct mixedValueStruct = new MixedValueStruct { X = 1, Y = 2f, Z = (ushort)5 };
StructUnion structUnion = new StructUnion(new A { V = 1 });

var mismatchedStruct = new MixedValueStruct { X = 10 };
var mismatchedUnion = new StructUnion(new B { V = 21 });

Assert.IsTrue(mixedValueStruct == new MixedValueStruct { X = 1, Y = 2f, Z = (ushort)5 });
Assert.IsFalse(mixedValueStruct == mismatchedStruct);
Assert.IsTrue(structUnion == new StructUnion(new A { V = 1 }));
Assert.IsFalse(structUnion == mismatchedUnion);
}

[TestMethod]
public void InequalityOperator()
{
MixedValueStruct mixedValueStruct = new MixedValueStruct { X = 1, Y = 2f, Z = (ushort)5 };
StructUnion structUnion = new StructUnion(new A { V = 1 });

var mismatchedStruct = new MixedValueStruct { Y = 13f };
var mismatchedUnion = new StructUnion(new C { V = 42 });
var mirrorStruct = mixedValueStruct;
var mirrorUnion = structUnion;

Assert.IsTrue(mixedValueStruct != mismatchedStruct);
Assert.IsFalse(mixedValueStruct != mirrorStruct);
Assert.IsTrue(structUnion != mismatchedUnion);
Assert.IsFalse(structUnion != mirrorUnion);
}

[TestMethod]
public void GetHashCodeMethod()
{
MixedValueStruct mixedValueStruct = new MixedValueStruct { X = 1, Y = 2f, Z = (ushort)5 };
StructUnion structUnion = new StructUnion(new A { V = 1 });

Assert.AreEqual(new MixedValueStruct { X = 1, Y = 2f, Z = (ushort)5 }.GetHashCode(), mixedValueStruct.GetHashCode());
Assert.AreEqual(new StructUnion(new A { V = 1 }).GetHashCode(), structUnion.GetHashCode());
}
}
12 changes: 12 additions & 0 deletions src/Tests/FlatSharpEndToEndTests/StructEquality/StructEquality.fbs
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
attribute "fs_valueStruct";

namespace FlatSharpEndToEndTests.StructEquality;

struct MixedValueStruct (fs_valueStruct) { X : int; Y : float; Z : ushort; }

struct A (fs_valueStruct) { V : uint; }
struct B (fs_valueStruct) { V : uint; }
struct C (fs_valueStruct) { V : uint; }
struct D (fs_valueStruct) { V : uint; }

union StructUnion { A, B, C, D }

0 comments on commit 901ada2

Please sign in to comment.