From 901ada22a2fb0c3f7963074910a6c5b97e84a0a7 Mon Sep 17 00:00:00 2001
From: Truman Mulholland <59830782+trumully@users.noreply.github.com>
Date: Wed, 27 Nov 2024 01:42:57 +1300
Subject: [PATCH] Add struct equality methods (#452)
This adds `Equals()`, and `GetHashCode()` methods and `==/!=` operators
for value structs and value unions. They can be enabled with the
`--generate-methods` flag. These methods fix the analyzer warning:
`CA1815: Override equals and operator equals on value types``
---
.../SchemaModel/ValueStructSchemaModel.cs | 14 ++++
.../SchemaModel/ValueUnionSchemaModel.cs | 13 ++++
src/NuGet.config | 6 +-
.../FlatSharpEndToEndTests.csproj | 5 +-
.../StructEquality/StructEquality.cs | 77 +++++++++++++++++++
.../StructEquality/StructEquality.fbs | 12 +++
6 files changed, 122 insertions(+), 5 deletions(-)
create mode 100644 src/Tests/FlatSharpEndToEndTests/StructEquality/StructEquality.cs
create mode 100644 src/Tests/FlatSharpEndToEndTests/StructEquality/StructEquality.fbs
diff --git a/src/FlatSharp.Compiler/SchemaModel/ValueStructSchemaModel.cs b/src/FlatSharp.Compiler/SchemaModel/ValueStructSchemaModel.cs
index 21889c6f..420c80ff 100644
--- a/src/FlatSharp.Compiler/SchemaModel/ValueStructSchemaModel.cs
+++ b/src/FlatSharp.Compiler/SchemaModel/ValueStructSchemaModel.cs
@@ -137,6 +137,10 @@ protected override void OnWriteCode(CodeWriter writer, CompileContext context)
this.Attributes.EmitAsMetadata(writer);
writer.AppendLine($"[System.Runtime.InteropServices.StructLayout(System.Runtime.InteropServices.LayoutKind.Explicit{size})]");
writer.AppendLine($"public partial struct {this.Name}");
+ if (context.Options.GenerateMethods)
+ {
+ writer.AppendLine($": System.IEquatable<{this.Name}>");
+ }
using (writer.WithBlock())
{
foreach (var field in this.fields)
@@ -151,6 +155,16 @@ protected override void OnWriteCode(CodeWriter writer, CompileContext context)
if (context.Options.GenerateMethods)
{
+ string typeNames = string.Join(", ", this.fields.Select(x => x.TypeName));
+ string names = string.Join(", ", this.fields.Select(x => x.Name));
+ string tupleType = this.fields.Count == 0 ? "System.ValueTuple" : this.fields.Count == 1 ? $"System.ValueTuple<{typeNames}>" : $"({typeNames})";
+ string tupleValue = this.fields.Count < 2 ? $"System.ValueTuple.Create({names})" : $"({names})";
+ writer.AppendLine($"public {tupleType} ToTuple() => {tupleValue};");
+ writer.AppendLine($"public override bool Equals(object? obj) => obj is {this.Name} other && this.Equals(other);");
+ writer.AppendLine($"public bool Equals({this.Name} other) => ToTuple().Equals(other.ToTuple());");
+ writer.AppendLine($"public static bool operator ==({this.Name} left, {this.Name} right) => left.Equals(right);");
+ writer.AppendLine($"public static bool operator !=({this.Name} left, {this.Name} right) => !left.Equals(right);");
+ writer.AppendLine("public override int GetHashCode() => ToTuple().GetHashCode();");
// This matches C# records
string fieldStrings = string.Join(", ", this.fields.Select(x => $"{x.Name} = {{this.{x.Name}}}"));
string fieldStringsWithSpace = this.fields.Count == 0 ? " " : $" {fieldStrings} ";
diff --git a/src/FlatSharp.Compiler/SchemaModel/ValueUnionSchemaModel.cs b/src/FlatSharp.Compiler/SchemaModel/ValueUnionSchemaModel.cs
index f323711a..332eba41 100644
--- a/src/FlatSharp.Compiler/SchemaModel/ValueUnionSchemaModel.cs
+++ b/src/FlatSharp.Compiler/SchemaModel/ValueUnionSchemaModel.cs
@@ -123,6 +123,12 @@ protected override void OnWriteCode(CodeWriter writer, CompileContext context)
this.Attributes.EmitAsMetadata(writer);
writer.AppendLine("[System.Runtime.CompilerServices.CompilerGenerated]");
writer.AppendLine($"public {@unsafe} partial struct {this.Name} : {interfaceName}");
+
+ if (!generateUnsafeItems && context.Options.GenerateMethods)
+ {
+ writer.AppendLine($", System.IEquatable<{this.Name}>");
+ }
+
using (writer.WithBlock())
{
// Generate an internal type enum.
@@ -154,6 +160,13 @@ protected override void OnWriteCode(CodeWriter writer, CompileContext context)
if (!generateUnsafeItems && context.Options.GenerateMethods)
{
+ writer.AppendLine();
+ writer.AppendLine($"public override bool Equals(object? obj) => obj is {this.Name} other && this.Equals(other);");
+ writer.AppendLine($"public bool Equals({this.Name} other) => (this.Discriminator, this.value).Equals((other.Discriminator, other.value));");
+ writer.AppendLine($"public static bool operator ==({this.Name} left, {this.Name} right) => left.Equals(right);");
+ writer.AppendLine($"public static bool operator !=({this.Name} left, {this.Name} right) => !left.Equals(right);");
+ writer.AppendLine("public override int GetHashCode() => (this.Discriminator, this.value).GetHashCode();");
+ writer.AppendLine();
string item = this.union.Values.Count == 0 ? " " : $" this.value ";
writer.AppendLine($"public override string ToString() => $\"{this.Name} {{{{ {{{item}}} }}}}\";");
}
diff --git a/src/NuGet.config b/src/NuGet.config
index 0401ce52..be592680 100644
--- a/src/NuGet.config
+++ b/src/NuGet.config
@@ -1,8 +1,9 @@
-
+
+
@@ -14,5 +15,4 @@
-
-
+
\ No newline at end of file
diff --git a/src/Tests/FlatSharpEndToEndTests/FlatSharpEndToEndTests.csproj b/src/Tests/FlatSharpEndToEndTests/FlatSharpEndToEndTests.csproj
index 8b2b7ef2..11181bce 100644
--- a/src/Tests/FlatSharpEndToEndTests/FlatSharpEndToEndTests.csproj
+++ b/src/Tests/FlatSharpEndToEndTests/FlatSharpEndToEndTests.csproj
@@ -57,6 +57,7 @@
+
@@ -67,7 +68,7 @@
$(IntermediateOutputPath)
$(MSBuildProjectDirectory)/
- $(FlatSharpOutput)ToStringMethods
+ $(FlatSharpOutput)WithGenerateMethodsOption
@@ -76,7 +77,7 @@
$([System.IO.Path]::GetFullPath('$(MSBuildThisFileDirectory)\..\tools\$(CompilerVersion)\FlatSharp.Compiler.dll'))
$(FlatSharpCompilerPath)
- dotnet "$(CompilerPath)" --input "ToStringMethods/ToStringMethods.fbs" --output "$(FlatSharpOutput)" --generate-methods
+ dotnet "$(CompilerPath)" --input "ToStringMethods/ToStringMethods.fbs;StructEquality/StructEquality.fbs" --output "$(FlatSharpOutput)" --generate-methods
diff --git a/src/Tests/FlatSharpEndToEndTests/StructEquality/StructEquality.cs b/src/Tests/FlatSharpEndToEndTests/StructEquality/StructEquality.cs
new file mode 100644
index 00000000..f55de98b
--- /dev/null
+++ b/src/Tests/FlatSharpEndToEndTests/StructEquality/StructEquality.cs
@@ -0,0 +1,77 @@
+using FlatSharpEndToEndTests.Unions;
+
+namespace FlatSharpEndToEndTests.StructEquality;
+
+[TestClass]
+public class StructEqualityTests
+{
+ [TestMethod]
+ public void ToTupleMethod()
+ {
+ MixedValueStruct mixedValueStruct = new MixedValueStruct { X = 1, Y = 2f, Z = (ushort)5 };
+
+ var expectedValueStructTuple = (1, 2f, (ushort)5);
+
+ Assert.AreEqual(expectedValueStructTuple, mixedValueStruct.ToTuple());
+ }
+
+ [TestMethod]
+ public void EqualsMethod()
+ {
+ MixedValueStruct mixedValueStruct = new MixedValueStruct { X = 1, Y = 2f, Z = (ushort)5 };
+ StructUnion structUnion = new StructUnion(new A { V = 1 });
+
+ var mismatchedStruct = new MixedValueStruct { X = 20 };
+ var mismatchedUnion = new StructUnion(new B { V = 1 });
+ var mismatchedObject = "hi";
+
+ Assert.IsTrue(mixedValueStruct.Equals(new MixedValueStruct { X = 1, Y = 2f, Z = (ushort)5 }));
+ Assert.IsFalse(mixedValueStruct.Equals(mismatchedStruct));
+ Assert.IsFalse(mixedValueStruct.Equals(mismatchedObject));
+ Assert.IsTrue(structUnion.Equals(new StructUnion(new A { V = 1 })));
+ Assert.IsFalse(structUnion.Equals(mismatchedUnion));
+ Assert.IsFalse(mixedValueStruct.Equals(mismatchedObject));
+ }
+
+ [TestMethod]
+ public void EqualityOperator()
+ {
+ MixedValueStruct mixedValueStruct = new MixedValueStruct { X = 1, Y = 2f, Z = (ushort)5 };
+ StructUnion structUnion = new StructUnion(new A { V = 1 });
+
+ var mismatchedStruct = new MixedValueStruct { X = 10 };
+ var mismatchedUnion = new StructUnion(new B { V = 21 });
+
+ Assert.IsTrue(mixedValueStruct == new MixedValueStruct { X = 1, Y = 2f, Z = (ushort)5 });
+ Assert.IsFalse(mixedValueStruct == mismatchedStruct);
+ Assert.IsTrue(structUnion == new StructUnion(new A { V = 1 }));
+ Assert.IsFalse(structUnion == mismatchedUnion);
+ }
+
+ [TestMethod]
+ public void InequalityOperator()
+ {
+ MixedValueStruct mixedValueStruct = new MixedValueStruct { X = 1, Y = 2f, Z = (ushort)5 };
+ StructUnion structUnion = new StructUnion(new A { V = 1 });
+
+ var mismatchedStruct = new MixedValueStruct { Y = 13f };
+ var mismatchedUnion = new StructUnion(new C { V = 42 });
+ var mirrorStruct = mixedValueStruct;
+ var mirrorUnion = structUnion;
+
+ Assert.IsTrue(mixedValueStruct != mismatchedStruct);
+ Assert.IsFalse(mixedValueStruct != mirrorStruct);
+ Assert.IsTrue(structUnion != mismatchedUnion);
+ Assert.IsFalse(structUnion != mirrorUnion);
+ }
+
+ [TestMethod]
+ public void GetHashCodeMethod()
+ {
+ MixedValueStruct mixedValueStruct = new MixedValueStruct { X = 1, Y = 2f, Z = (ushort)5 };
+ StructUnion structUnion = new StructUnion(new A { V = 1 });
+
+ Assert.AreEqual(new MixedValueStruct { X = 1, Y = 2f, Z = (ushort)5 }.GetHashCode(), mixedValueStruct.GetHashCode());
+ Assert.AreEqual(new StructUnion(new A { V = 1 }).GetHashCode(), structUnion.GetHashCode());
+ }
+}
\ No newline at end of file
diff --git a/src/Tests/FlatSharpEndToEndTests/StructEquality/StructEquality.fbs b/src/Tests/FlatSharpEndToEndTests/StructEquality/StructEquality.fbs
new file mode 100644
index 00000000..2867e0a8
--- /dev/null
+++ b/src/Tests/FlatSharpEndToEndTests/StructEquality/StructEquality.fbs
@@ -0,0 +1,12 @@
+attribute "fs_valueStruct";
+
+namespace FlatSharpEndToEndTests.StructEquality;
+
+struct MixedValueStruct (fs_valueStruct) { X : int; Y : float; Z : ushort; }
+
+struct A (fs_valueStruct) { V : uint; }
+struct B (fs_valueStruct) { V : uint; }
+struct C (fs_valueStruct) { V : uint; }
+struct D (fs_valueStruct) { V : uint; }
+
+union StructUnion { A, B, C, D }
\ No newline at end of file