From 6e401e2047a9440c6f03553511bbeb583374e721 Mon Sep 17 00:00:00 2001 From: Daniel Cazzulino Date: Sat, 21 Dec 2024 00:00:43 -0300 Subject: [PATCH] Add built-in support for parsable/formattable values for EF Rather than adding explicit Ulid support as suggested in #18, a more flexible approach is to simply rely on the combination of IParsable plus IFormattable checks on the TValue/TId and emit a standard EF value converter that relies on those, plus the INewable we already have for the struct-id itself. This involved a more flexible resource template selection mechanism, which needs further refactoring for simplification (and perhaps dropping altogether) of the BaseGenerator. --- src/StructId.Analyzer/BaseGenerator.cs | 11 +- .../EntityFrameworkGenerator.cs | 130 +++++++++++++----- .../EntityFrameworkSelector.sbn | 17 ++- src/StructId.FunctionalTests/Functional.cs | 2 +- src/StructId.FunctionalTests/UlidTests.cs | 30 ++-- .../ResourceTemplates/EntityFramework.cs | 8 +- .../EntityFrameworkParsable.cs | 34 +++++ src/StructId/Templates/DapperTypeHandler.cs | 13 +- .../EntityFrameworkValueConverter.cs | 15 +- 9 files changed, 187 insertions(+), 73 deletions(-) create mode 100644 src/StructId/ResourceTemplates/EntityFrameworkParsable.cs diff --git a/src/StructId.Analyzer/BaseGenerator.cs b/src/StructId.Analyzer/BaseGenerator.cs index fb93028..5fff3e6 100644 --- a/src/StructId.Analyzer/BaseGenerator.cs +++ b/src/StructId.Analyzer/BaseGenerator.cs @@ -1,7 +1,6 @@ using System.Linq; using System.Text; using Microsoft.CodeAnalysis; -using Microsoft.CodeAnalysis.CSharp; using Microsoft.CodeAnalysis.Text; namespace StructId; @@ -70,11 +69,13 @@ public virtual void Initialize(IncrementalGeneratorInitializationContext context protected virtual IncrementalValuesProvider OnInitialize(IncrementalGeneratorInitializationContext context, IncrementalValuesProvider source) => source; - void GenerateCode(SourceProductionContext context, TemplateArgs args) => AddFromTemplate( - context, args, $"{args.TSelf.ToFileName()}.cs", - args.TId.Equals(args.KnownTypes.String, SymbolEqualityComparer.Default) ? + void GenerateCode(SourceProductionContext context, TemplateArgs args) + => AddFromTemplate(context, args, $"{args.TSelf.ToFileName()}.cs", SelectTemplate(args)); + + protected virtual SyntaxNode SelectTemplate(TemplateArgs args) + => args.TId.Equals(args.KnownTypes.String, SymbolEqualityComparer.Default) ? (stringSyntax ??= CodeTemplate.Parse(stringTemplate, args.KnownTypes.Compilation.GetParseOptions())) : - (typedSyntax ??= CodeTemplate.Parse(typeTemplate, args.KnownTypes.Compilation.GetParseOptions()))); + (typedSyntax ??= CodeTemplate.Parse(typeTemplate, args.KnownTypes.Compilation.GetParseOptions())); protected static void AddFromTemplate(SourceProductionContext context, TemplateArgs args, string hintName, SyntaxNode template) { diff --git a/src/StructId.Analyzer/EntityFrameworkGenerator.cs b/src/StructId.Analyzer/EntityFrameworkGenerator.cs index ee61ac7..8ec9a44 100644 --- a/src/StructId.Analyzer/EntityFrameworkGenerator.cs +++ b/src/StructId.Analyzer/EntityFrameworkGenerator.cs @@ -2,82 +2,142 @@ using System.Collections.Immutable; using System.Linq; using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp.Syntax; using Scriban; +using static StructId.AnalysisExtensions; namespace StructId; [Generator(LanguageNames.CSharp)] public class EntityFrameworkGenerator() : BaseGenerator( - "Microsoft.EntityFrameworkCore.Storage.ValueConversion.ValueConverter`2", + ValueConverterType, ThisAssembly.Resources.Templates.EntityFramework.Text, ThisAssembly.Resources.Templates.EntityFramework.Text, ReferenceCheck.TypeExists) { - static readonly Template template = Template.Parse(ThisAssembly.Resources.EntityFrameworkSelector.Text); + public const string ValueConverterType = "Microsoft.EntityFrameworkCore.Storage.ValueConversion.ValueConverter`2"; + + static readonly Dictionary builtInTypesMap = new() + { + ["System.String"] = "string", + ["System.Int32"] = "int", + ["System.Int64"] = "long", + ["System.Boolean"] = "bool", + ["System.Single"] = "float", + ["System.Double"] = "double", + ["System.Decimal"] = "decimal", + ["System.DateTime"] = "DateTime", + ["System.Guid"] = "Guid", + ["System.TimeSpan"] = "TimeSpan", + ["System.Byte"] = "byte", + ["System.Byte[]"] = "byte[]", + ["System.Char"] = "char", + ["System.UInt32"] = "uint", + ["System.UInt64"] = "ulong", + ["System.SByte"] = "sbyte", + ["System.UInt16"] = "ushort", + ["System.Int16"] = "short", + ["System.Object"] = "object", + }; + + static readonly Template selectorTemplate = Template.Parse(ThisAssembly.Resources.EntityFrameworkSelector.Text); + + SyntaxNode? idTemplate; + SyntaxNode? parsableIdTemplate; protected override IncrementalValuesProvider OnInitialize(IncrementalGeneratorInitializationContext context, IncrementalValuesProvider source) { var converters = context.CompilationProvider .SelectMany((x, _) => x.Assembly.GetAllTypes().OfType()) - .Combine(context.CompilationProvider.Select((x, _) => x.GetTypeByMetadataName("Microsoft.EntityFrameworkCore.Storage.ValueConversion.ValueConverter`2"))) + .Combine(context.CompilationProvider.Select((x, _) => x.GetTypeByMetadataName(ValueConverterType))) .Where(x => x.Left != null && x.Right != null && x.Left.Is(x.Right) && !x.Left.IsUnboundGenericType && x.Left.BaseType?.TypeArguments.Length == 2 && - // Don't emit as plain converters if they are id templates + // Don't emit as plain converters if they are value templates !x.Left.GetAttributes().Any(a => a.IsValueTemplate())) .Select((x, _) => x.Left) .Collect(); - context.RegisterSourceOutput(source.Collect().Combine(converters), GenerateValueSelector); + var templatizedValues = context.SelectTemplatizedValues() + .Combine(context.CompilationProvider.Select((x, _) => x.GetTypeByMetadataName(ValueConverterType))) + .Where(x => x.Left.Template.TTemplate.Is(x.Right)) + .Select((x, _) => x.Left); + + context.RegisterSourceOutput(source.Collect().Combine(converters).Combine(templatizedValues.Collect()), GenerateValueSelector); return base.OnInitialize(context, source); } - void GenerateValueSelector(SourceProductionContext context, (ImmutableArray, ImmutableArray) args) + protected override SyntaxNode SelectTemplate(TemplateArgs args) + { + if (args.TId.Equals(args.KnownTypes.String, SymbolEqualityComparer.Default) || + builtInTypesMap.ContainsKey(args.TId.ToDisplayString(NamespacedTypeName))) + return idTemplate ??= CodeTemplate.Parse(ThisAssembly.Resources.Templates.EntityFramework.Text, args.KnownTypes.Compilation.GetParseOptions()); + else if (args.TId.Is(args.KnownTypes.Compilation.GetTypeByMetadataName("System.IParsable`1")) && + args.TId.Is(args.KnownTypes.Compilation.GetTypeByMetadataName("System.IFormattable"))) + return parsableIdTemplate ??= CodeTemplate.Parse(ThisAssembly.Resources.Templates.EntityFrameworkParsable.Text, args.KnownTypes.Compilation.GetParseOptions()); + else + return idTemplate ??= CodeTemplate.Parse(ThisAssembly.Resources.Templates.EntityFramework.Text, args.KnownTypes.Compilation.GetParseOptions()); + } + + void GenerateValueSelector(SourceProductionContext context, ((ImmutableArray, ImmutableArray), ImmutableArray) args) { - (var ids, var converters) = args; + ((var structIds, var customConverters), var templatizedConverters) = args; - if (ids.Length == 0) + if (structIds.Length == 0 && customConverters.Length == 0 && templatizedConverters.Length == 0) return; var model = new SelectorModel( - ids.Select(x => new StructIdModel(x.TSelf.ToFullName(), x.TId.ToFullName())), - converters.Select(x => new ConverterModel(x.BaseType!.TypeArguments[0].ToFullName(), x.BaseType!.TypeArguments[1].ToFullName(), x.ToFullName()))); + structIds.Select(x => new StructIdModel(x.TSelf.ToFullName(), + // The TId is used as the ProviderClrType for EF, which should be either a built-in + // supported type or a parsable one. We default to using the type as-is for future-proofing, + // but that may be subject to change. + !builtInTypesMap.ContainsKey(x.TId.ToDisplayString(NamespacedTypeName)) + ? x.TId.Is(x.KnownTypes.Compilation.GetTypeByMetadataName("System.IParsable`1")) && + x.TId.Is(x.KnownTypes.Compilation.GetTypeByMetadataName("System.IFormattable")) + // parsable+formattable will result in the parsable template being used as the converter + // so we use string as the underlying EF type. + ? "string" : x.TId.ToFullName() + : x.TId.ToFullName())), + customConverters.Select(x => new ConverterModel(x.BaseType!.TypeArguments[0].ToFullName(), x.BaseType!.TypeArguments[1].ToFullName(), x.ToFullName())), + templatizedConverters + .Where(x => !builtInTypesMap.ContainsKey(x.TValue.ToDisplayString(NamespacedTypeName))) + .Select(x => new TemplatizedModel(x))); - var output = template.Render(model, member => member.Name); + var output = selectorTemplate.Render(model, member => member.Name); context.AddSource($"ValueConverterSelector.cs", output); } record StructIdModel(string TSelf, string TIdType) { - public string TId => TIdType switch - { - "System.String" => "string", - "System.Int32" => "int", - "System.Int64" => "long", - "System.Boolean" => "bool", - "System.Single" => "float", - "System.Double" => "double", - "System.Decimal" => "decimal", - "System.DateTime" => "DateTime", - "System.Guid" => "Guid", - "System.TimeSpan" => "TimeSpan", - "System.Byte" => "byte", - "System.Byte[]" => "byte[]", - "System.Char" => "char", - "System.UInt32" => "uint", - "System.UInt64" => "ulong", - "System.SByte" => "sbyte", - "System.UInt16" => "ushort", - "System.Int16" => "short", - "System.Object" => "object", - _ => TIdType - }; + public string TId => builtInTypesMap.TryGetValue(TIdType, out var value) ? value : TIdType; } record ConverterModel(string TModel, string TProvider, string TConverter); - record SelectorModel(IEnumerable Ids, IEnumerable Converters); + class TemplatizedModel + { + public TemplatizedModel(TValueTemplate template) + { + var declaration = template.Template.Syntax.ApplyValue(template.TValue) + .DescendantNodes() + .OfType() + .First(); + + TModel = template.TValue.ToFullName(); + TConverter = declaration.Identifier.Text; + Code = declaration.ToFullString(); + } + + public TemplatizedModel(string tvalue, string tconverter, string code) + => (TModel, TConverter, Code) = (tvalue, tconverter, code); + + public string TModel { get; } + public string TConverter { get; } + public string Code { get; } + } + + record SelectorModel(IEnumerable Ids, IEnumerable Converters, IEnumerable Templatized); } \ No newline at end of file diff --git a/src/StructId.Analyzer/EntityFrameworkSelector.sbn b/src/StructId.Analyzer/EntityFrameworkSelector.sbn index c9a02a1..f799fce 100644 --- a/src/StructId.Analyzer/EntityFrameworkSelector.sbn +++ b/src/StructId.Analyzer/EntityFrameworkSelector.sbn @@ -39,6 +39,7 @@ public static class StructIdDbContextOptionsBuilderExtensions modelClrType = Unwrap(modelClrType) ?? modelClrType; providerClrType = Unwrap(providerClrType); + // Struct ID converters {{~ for id in Ids ~}} if (modelClrType == typeof({{ id.TSelf }})) yield return converters.GetOrAdd((modelClrType, providerClrType), key => new ValueConverterInfo( @@ -46,6 +47,7 @@ public static class StructIdDbContextOptionsBuilderExtensions info => new {{ id.TSelf }}.EntityFrameworkValueConverter(info.MappingHints))); {{~ end ~}} + // Custom EF converters {{~ for converter in Converters ~}} if (modelClrType == typeof({{ converter.TModel }})) yield return converters.GetOrAdd((modelClrType, providerClrType), key => new ValueConverterInfo( @@ -53,6 +55,14 @@ public static class StructIdDbContextOptionsBuilderExtensions info => new {{ converter.TConverter }}(info.MappingHints))); {{~ end ~}} + // Templatized converters + {{~ for converter in Templatized ~}} + if (modelClrType == typeof({{ converter.TModel }})) + yield return converters.GetOrAdd((modelClrType, providerClrType), key => new ValueConverterInfo( + key.ModelClrType, key.ProviderClrType ?? typeof(string), + info => new {{ converter.TConverter }}(info.MappingHints))); + + {{~ end ~}} } static Type? Unwrap(Type? type) @@ -63,4 +73,9 @@ public static class StructIdDbContextOptionsBuilderExtensions return Nullable.GetUnderlyingType(type) ?? type; } } -} \ No newline at end of file +} + +// Templatized converters +{{~ for converter in Templatized ~}} +{{ converter.Code }} +{{~ end ~}} \ No newline at end of file diff --git a/src/StructId.FunctionalTests/Functional.cs b/src/StructId.FunctionalTests/Functional.cs index 1b5f915..c7d8fbd 100644 --- a/src/StructId.FunctionalTests/Functional.cs +++ b/src/StructId.FunctionalTests/Functional.cs @@ -91,8 +91,8 @@ public void EntityFramework() var product = new Product(new ProductId(id), "Product"); // Seed data - context.Products.Add(new Product(ProductId.New(), "Product1")); context.Products.Add(product); + context.Products.Add(new Product(ProductId.New(), "Product1")); context.Products.Add(new Product(ProductId.New(), "Product2")); context.SaveChanges(); diff --git a/src/StructId.FunctionalTests/UlidTests.cs b/src/StructId.FunctionalTests/UlidTests.cs index 42c923c..787189e 100644 --- a/src/StructId.FunctionalTests/UlidTests.cs +++ b/src/StructId.FunctionalTests/UlidTests.cs @@ -41,13 +41,15 @@ public override void SetValue(IDbDataParameter parameter, Ulid value) } } -public partial class UlidToStringConverter : ValueConverter -{ - public UlidToStringConverter() : this(null) { } +// showcases a custom EF value converter trumps the built-in templatized +// support for types that provide IParsable and IFormattable +//public partial class UlidToStringConverter : ValueConverter +//{ +// public UlidToStringConverter() : this(null) { } - public UlidToStringConverter(ConverterMappingHints? mappingHints = null) - : base(id => id.ToString(), value => Ulid.Parse(value), mappingHints) { } -} +// public UlidToStringConverter(ConverterMappingHints? mappingHints = null) +// : base(id => id.ToString(), value => Ulid.Parse(value), mappingHints) { } +//} // showcases alternative serialization //public class BinaryUlidHandler : TypeHandler @@ -126,8 +128,8 @@ public void EntityFramework() var product = new UlidProduct(new UlidId(id), "Product"); // Seed data - context.Products.Add(new UlidProduct(UlidId.New(), "Product1")); context.Products.Add(product); + context.Products.Add(new UlidProduct(UlidId.New(), "Product1")); context.Products.Add(new UlidProduct(UlidId.New(), "Product2")); context.SaveChanges(); @@ -152,12 +154,12 @@ public class UlidContext : DbContext public UlidContext(DbContextOptions options) : base(options) { } public DbSet Products { get; set; } = null!; - protected override void OnModelCreating(ModelBuilder modelBuilder) - { - modelBuilder.Entity().Property(x => x.Id) - //.HasConversion(new UlidToStringConverter()) - .HasConversion(new UlidId.EntityFrameworkUlidValueConverter()); - //.HasConversion(new UlidId.EntityFrameworkValueConverter()); - } + //protected override void OnModelCreating(ModelBuilder modelBuilder) + //{ + // //modelBuilder.Entity().Property(x => x.Id) + // //.HasConversion(new UlidToStringConverter()) + // //.HasConversion(new UlidId.EntityFrameworkUlidValueConverter()); + // //.HasConversion(new UlidId.EntityFrameworkValueConverter()); + //} } } diff --git a/src/StructId/ResourceTemplates/EntityFramework.cs b/src/StructId/ResourceTemplates/EntityFramework.cs index 95c5b8d..345b15d 100644 --- a/src/StructId/ResourceTemplates/EntityFramework.cs +++ b/src/StructId/ResourceTemplates/EntityFramework.cs @@ -5,12 +5,12 @@ using StructId; [TStructId] -file partial record struct TSelf(TId Value) : INewable +file partial record struct TSelf(TValue Value) : INewable { /// /// Provides value conversion for Entity Framework Core /// - public partial class EntityFrameworkValueConverter : ValueConverter + public partial class EntityFrameworkValueConverter : ValueConverter { public EntityFrameworkValueConverter() : this(null) { } @@ -21,7 +21,7 @@ public EntityFrameworkValueConverter(ConverterMappingHints? mappingHints = null) file partial record struct TSelf { - public static TSelf New(TId value) => throw new System.NotImplementedException(); + public static TSelf New(TValue value) => throw new System.NotImplementedException(); } -file partial record struct TId; \ No newline at end of file +file partial record struct TValue; \ No newline at end of file diff --git a/src/StructId/ResourceTemplates/EntityFrameworkParsable.cs b/src/StructId/ResourceTemplates/EntityFrameworkParsable.cs new file mode 100644 index 0000000..080c958 --- /dev/null +++ b/src/StructId/ResourceTemplates/EntityFrameworkParsable.cs @@ -0,0 +1,34 @@ +// +#nullable enable + +using System.Diagnostics.CodeAnalysis; +using System; +using Microsoft.EntityFrameworkCore.Storage.ValueConversion; +using StructId; + +[TStructId] +file partial record struct TSelf(TValue Value) : INewable +{ + /// + /// Provides value conversion for Entity Framework Core + /// + public partial class EntityFrameworkValueConverter : ValueConverter + { + public EntityFrameworkValueConverter() : this(null) { } + + public EntityFrameworkValueConverter(ConverterMappingHints? mappingHints = null) + : base(id => id.Value.ToString(null, null), value => TSelf.New(TValue.Parse(value, null)), mappingHints) { } + } +} + +file partial record struct TSelf +{ + public static TSelf New(TValue value) => throw new System.NotImplementedException(); +} + +file partial struct TValue : IParsable, IFormattable +{ + public static TValue Parse(string s, IFormatProvider? provider) => throw new NotImplementedException(); + public static bool TryParse([NotNullWhen(true)] string? s, IFormatProvider? provider, [MaybeNullWhen(false)] out TValue result) => throw new NotImplementedException(); + public string ToString(string? format, IFormatProvider? formatProvider) => throw new NotImplementedException(); +} \ No newline at end of file diff --git a/src/StructId/Templates/DapperTypeHandler.cs b/src/StructId/Templates/DapperTypeHandler.cs index 4473fe0..a6b8c33 100644 --- a/src/StructId/Templates/DapperTypeHandler.cs +++ b/src/StructId/Templates/DapperTypeHandler.cs @@ -3,21 +3,22 @@ using System.Diagnostics.CodeAnalysis; using StructId; +// TODO: pending making it conditionally included at compile-time [TValue] -file class TId_TypeHandler : Dapper.SqlMapper.TypeHandler +file class TId_TypeHandler : Dapper.SqlMapper.TypeHandler { - public override TId Parse(object value) => TId.Parse((string)value, null); + public override TValue Parse(object value) => TValue.Parse((string)value, null); - public override void SetValue(IDbDataParameter parameter, TId value) + public override void SetValue(IDbDataParameter parameter, TValue value) { parameter.DbType = DbType.String; parameter.Value = value.ToString(null, null); } } -file partial struct TId : IParsable, IFormattable +file partial struct TValue : IParsable, IFormattable { - public static TId Parse(string s, IFormatProvider? provider) => throw new NotImplementedException(); - public static bool TryParse([NotNullWhen(true)] string? s, IFormatProvider? provider, [MaybeNullWhen(false)] out TId result) => throw new NotImplementedException(); + public static TValue Parse(string s, IFormatProvider? provider) => throw new NotImplementedException(); + public static bool TryParse([NotNullWhen(true)] string? s, IFormatProvider? provider, [MaybeNullWhen(false)] out TValue result) => throw new NotImplementedException(); public string ToString(string? format, IFormatProvider? formatProvider) => throw new NotImplementedException(); } \ No newline at end of file diff --git a/src/StructId/Templates/EntityFrameworkValueConverter.cs b/src/StructId/Templates/EntityFrameworkValueConverter.cs index dbbbcc6..a201935 100644 --- a/src/StructId/Templates/EntityFrameworkValueConverter.cs +++ b/src/StructId/Templates/EntityFrameworkValueConverter.cs @@ -6,18 +6,19 @@ using Microsoft.EntityFrameworkCore.Storage.ValueConversion; using StructId; +// TODO: pending making it conditionally included at compile-time [TValue] -file class TId_ValueConverter : ValueConverter +file class TValue_ValueConverter : ValueConverter { - public TId_ValueConverter() : this(null) { } + public TValue_ValueConverter() : this(null) { } - public TId_ValueConverter(ConverterMappingHints? mappingHints = null) - : base(id => id.ToString(null, null), value => TId.Parse(value, null), mappingHints) { } + public TValue_ValueConverter(ConverterMappingHints? mappingHints = null) + : base(id => id.ToString(null, null), value => TValue.Parse(value, null), mappingHints) { } } -file partial struct TId : IParsable, IFormattable +file partial struct TValue : IParsable, IFormattable { - public static TId Parse(string s, IFormatProvider? provider) => throw new NotImplementedException(); - public static bool TryParse([NotNullWhen(true)] string? s, IFormatProvider? provider, [MaybeNullWhen(false)] out TId result) => throw new NotImplementedException(); + public static TValue Parse(string s, IFormatProvider? provider) => throw new NotImplementedException(); + public static bool TryParse([NotNullWhen(true)] string? s, IFormatProvider? provider, [MaybeNullWhen(false)] out TValue result) => throw new NotImplementedException(); public string ToString(string? format, IFormatProvider? formatProvider) => throw new NotImplementedException(); } \ No newline at end of file