From 901ada22a2fb0c3f7963074910a6c5b97e84a0a7 Mon Sep 17 00:00:00 2001 From: Truman Mulholland <59830782+trumully@users.noreply.github.com> Date: Wed, 27 Nov 2024 01:42:57 +1300 Subject: [PATCH] Add struct equality methods (#452) This adds `Equals()`, and `GetHashCode()` methods and `==/!=` operators for value structs and value unions. They can be enabled with the `--generate-methods` flag. These methods fix the analyzer warning: `CA1815: Override equals and operator equals on value types`` --- .../SchemaModel/ValueStructSchemaModel.cs | 14 ++++ .../SchemaModel/ValueUnionSchemaModel.cs | 13 ++++ src/NuGet.config | 6 +- .../FlatSharpEndToEndTests.csproj | 5 +- .../StructEquality/StructEquality.cs | 77 +++++++++++++++++++ .../StructEquality/StructEquality.fbs | 12 +++ 6 files changed, 122 insertions(+), 5 deletions(-) create mode 100644 src/Tests/FlatSharpEndToEndTests/StructEquality/StructEquality.cs create mode 100644 src/Tests/FlatSharpEndToEndTests/StructEquality/StructEquality.fbs diff --git a/src/FlatSharp.Compiler/SchemaModel/ValueStructSchemaModel.cs b/src/FlatSharp.Compiler/SchemaModel/ValueStructSchemaModel.cs index 21889c6f..420c80ff 100644 --- a/src/FlatSharp.Compiler/SchemaModel/ValueStructSchemaModel.cs +++ b/src/FlatSharp.Compiler/SchemaModel/ValueStructSchemaModel.cs @@ -137,6 +137,10 @@ protected override void OnWriteCode(CodeWriter writer, CompileContext context) this.Attributes.EmitAsMetadata(writer); writer.AppendLine($"[System.Runtime.InteropServices.StructLayout(System.Runtime.InteropServices.LayoutKind.Explicit{size})]"); writer.AppendLine($"public partial struct {this.Name}"); + if (context.Options.GenerateMethods) + { + writer.AppendLine($": System.IEquatable<{this.Name}>"); + } using (writer.WithBlock()) { foreach (var field in this.fields) @@ -151,6 +155,16 @@ protected override void OnWriteCode(CodeWriter writer, CompileContext context) if (context.Options.GenerateMethods) { + string typeNames = string.Join(", ", this.fields.Select(x => x.TypeName)); + string names = string.Join(", ", this.fields.Select(x => x.Name)); + string tupleType = this.fields.Count == 0 ? "System.ValueTuple" : this.fields.Count == 1 ? $"System.ValueTuple<{typeNames}>" : $"({typeNames})"; + string tupleValue = this.fields.Count < 2 ? $"System.ValueTuple.Create({names})" : $"({names})"; + writer.AppendLine($"public {tupleType} ToTuple() => {tupleValue};"); + writer.AppendLine($"public override bool Equals(object? obj) => obj is {this.Name} other && this.Equals(other);"); + writer.AppendLine($"public bool Equals({this.Name} other) => ToTuple().Equals(other.ToTuple());"); + writer.AppendLine($"public static bool operator ==({this.Name} left, {this.Name} right) => left.Equals(right);"); + writer.AppendLine($"public static bool operator !=({this.Name} left, {this.Name} right) => !left.Equals(right);"); + writer.AppendLine("public override int GetHashCode() => ToTuple().GetHashCode();"); // This matches C# records string fieldStrings = string.Join(", ", this.fields.Select(x => $"{x.Name} = {{this.{x.Name}}}")); string fieldStringsWithSpace = this.fields.Count == 0 ? " " : $" {fieldStrings} "; diff --git a/src/FlatSharp.Compiler/SchemaModel/ValueUnionSchemaModel.cs b/src/FlatSharp.Compiler/SchemaModel/ValueUnionSchemaModel.cs index f323711a..332eba41 100644 --- a/src/FlatSharp.Compiler/SchemaModel/ValueUnionSchemaModel.cs +++ b/src/FlatSharp.Compiler/SchemaModel/ValueUnionSchemaModel.cs @@ -123,6 +123,12 @@ protected override void OnWriteCode(CodeWriter writer, CompileContext context) this.Attributes.EmitAsMetadata(writer); writer.AppendLine("[System.Runtime.CompilerServices.CompilerGenerated]"); writer.AppendLine($"public {@unsafe} partial struct {this.Name} : {interfaceName}"); + + if (!generateUnsafeItems && context.Options.GenerateMethods) + { + writer.AppendLine($", System.IEquatable<{this.Name}>"); + } + using (writer.WithBlock()) { // Generate an internal type enum. @@ -154,6 +160,13 @@ protected override void OnWriteCode(CodeWriter writer, CompileContext context) if (!generateUnsafeItems && context.Options.GenerateMethods) { + writer.AppendLine(); + writer.AppendLine($"public override bool Equals(object? obj) => obj is {this.Name} other && this.Equals(other);"); + writer.AppendLine($"public bool Equals({this.Name} other) => (this.Discriminator, this.value).Equals((other.Discriminator, other.value));"); + writer.AppendLine($"public static bool operator ==({this.Name} left, {this.Name} right) => left.Equals(right);"); + writer.AppendLine($"public static bool operator !=({this.Name} left, {this.Name} right) => !left.Equals(right);"); + writer.AppendLine("public override int GetHashCode() => (this.Discriminator, this.value).GetHashCode();"); + writer.AppendLine(); string item = this.union.Values.Count == 0 ? " " : $" this.value "; writer.AppendLine($"public override string ToString() => $\"{this.Name} {{{{ {{{item}}} }}}}\";"); } diff --git a/src/NuGet.config b/src/NuGet.config index 0401ce52..be592680 100644 --- a/src/NuGet.config +++ b/src/NuGet.config @@ -1,8 +1,9 @@ - + + @@ -14,5 +15,4 @@ - - + \ No newline at end of file diff --git a/src/Tests/FlatSharpEndToEndTests/FlatSharpEndToEndTests.csproj b/src/Tests/FlatSharpEndToEndTests/FlatSharpEndToEndTests.csproj index 8b2b7ef2..11181bce 100644 --- a/src/Tests/FlatSharpEndToEndTests/FlatSharpEndToEndTests.csproj +++ b/src/Tests/FlatSharpEndToEndTests/FlatSharpEndToEndTests.csproj @@ -57,6 +57,7 @@ + @@ -67,7 +68,7 @@ $(IntermediateOutputPath) $(MSBuildProjectDirectory)/ - $(FlatSharpOutput)ToStringMethods + $(FlatSharpOutput)WithGenerateMethodsOption @@ -76,7 +77,7 @@ $([System.IO.Path]::GetFullPath('$(MSBuildThisFileDirectory)\..\tools\$(CompilerVersion)\FlatSharp.Compiler.dll')) $(FlatSharpCompilerPath) - dotnet "$(CompilerPath)" --input "ToStringMethods/ToStringMethods.fbs" --output "$(FlatSharpOutput)" --generate-methods + dotnet "$(CompilerPath)" --input "ToStringMethods/ToStringMethods.fbs;StructEquality/StructEquality.fbs" --output "$(FlatSharpOutput)" --generate-methods diff --git a/src/Tests/FlatSharpEndToEndTests/StructEquality/StructEquality.cs b/src/Tests/FlatSharpEndToEndTests/StructEquality/StructEquality.cs new file mode 100644 index 00000000..f55de98b --- /dev/null +++ b/src/Tests/FlatSharpEndToEndTests/StructEquality/StructEquality.cs @@ -0,0 +1,77 @@ +using FlatSharpEndToEndTests.Unions; + +namespace FlatSharpEndToEndTests.StructEquality; + +[TestClass] +public class StructEqualityTests +{ + [TestMethod] + public void ToTupleMethod() + { + MixedValueStruct mixedValueStruct = new MixedValueStruct { X = 1, Y = 2f, Z = (ushort)5 }; + + var expectedValueStructTuple = (1, 2f, (ushort)5); + + Assert.AreEqual(expectedValueStructTuple, mixedValueStruct.ToTuple()); + } + + [TestMethod] + public void EqualsMethod() + { + MixedValueStruct mixedValueStruct = new MixedValueStruct { X = 1, Y = 2f, Z = (ushort)5 }; + StructUnion structUnion = new StructUnion(new A { V = 1 }); + + var mismatchedStruct = new MixedValueStruct { X = 20 }; + var mismatchedUnion = new StructUnion(new B { V = 1 }); + var mismatchedObject = "hi"; + + Assert.IsTrue(mixedValueStruct.Equals(new MixedValueStruct { X = 1, Y = 2f, Z = (ushort)5 })); + Assert.IsFalse(mixedValueStruct.Equals(mismatchedStruct)); + Assert.IsFalse(mixedValueStruct.Equals(mismatchedObject)); + Assert.IsTrue(structUnion.Equals(new StructUnion(new A { V = 1 }))); + Assert.IsFalse(structUnion.Equals(mismatchedUnion)); + Assert.IsFalse(mixedValueStruct.Equals(mismatchedObject)); + } + + [TestMethod] + public void EqualityOperator() + { + MixedValueStruct mixedValueStruct = new MixedValueStruct { X = 1, Y = 2f, Z = (ushort)5 }; + StructUnion structUnion = new StructUnion(new A { V = 1 }); + + var mismatchedStruct = new MixedValueStruct { X = 10 }; + var mismatchedUnion = new StructUnion(new B { V = 21 }); + + Assert.IsTrue(mixedValueStruct == new MixedValueStruct { X = 1, Y = 2f, Z = (ushort)5 }); + Assert.IsFalse(mixedValueStruct == mismatchedStruct); + Assert.IsTrue(structUnion == new StructUnion(new A { V = 1 })); + Assert.IsFalse(structUnion == mismatchedUnion); + } + + [TestMethod] + public void InequalityOperator() + { + MixedValueStruct mixedValueStruct = new MixedValueStruct { X = 1, Y = 2f, Z = (ushort)5 }; + StructUnion structUnion = new StructUnion(new A { V = 1 }); + + var mismatchedStruct = new MixedValueStruct { Y = 13f }; + var mismatchedUnion = new StructUnion(new C { V = 42 }); + var mirrorStruct = mixedValueStruct; + var mirrorUnion = structUnion; + + Assert.IsTrue(mixedValueStruct != mismatchedStruct); + Assert.IsFalse(mixedValueStruct != mirrorStruct); + Assert.IsTrue(structUnion != mismatchedUnion); + Assert.IsFalse(structUnion != mirrorUnion); + } + + [TestMethod] + public void GetHashCodeMethod() + { + MixedValueStruct mixedValueStruct = new MixedValueStruct { X = 1, Y = 2f, Z = (ushort)5 }; + StructUnion structUnion = new StructUnion(new A { V = 1 }); + + Assert.AreEqual(new MixedValueStruct { X = 1, Y = 2f, Z = (ushort)5 }.GetHashCode(), mixedValueStruct.GetHashCode()); + Assert.AreEqual(new StructUnion(new A { V = 1 }).GetHashCode(), structUnion.GetHashCode()); + } +} \ No newline at end of file diff --git a/src/Tests/FlatSharpEndToEndTests/StructEquality/StructEquality.fbs b/src/Tests/FlatSharpEndToEndTests/StructEquality/StructEquality.fbs new file mode 100644 index 00000000..2867e0a8 --- /dev/null +++ b/src/Tests/FlatSharpEndToEndTests/StructEquality/StructEquality.fbs @@ -0,0 +1,12 @@ +attribute "fs_valueStruct"; + +namespace FlatSharpEndToEndTests.StructEquality; + +struct MixedValueStruct (fs_valueStruct) { X : int; Y : float; Z : ushort; } + +struct A (fs_valueStruct) { V : uint; } +struct B (fs_valueStruct) { V : uint; } +struct C (fs_valueStruct) { V : uint; } +struct D (fs_valueStruct) { V : uint; } + +union StructUnion { A, B, C, D } \ No newline at end of file