Skip to content

Commit

Permalink
Add built-in support for parsable/formattable values for EF
Browse files Browse the repository at this point in the history
Rather than adding explicit Ulid support as suggested in #18, a more flexible approach is to simply rely on the combination of IParsable<T> plus IFormattable checks on the TValue/TId and emit a standard EF value converter that relies on those, plus the INewable<TSelf> we already have for the struct-id itself.

This involved a more flexible resource template selection mechanism, which needs further refactoring for simplification (and perhaps dropping altogether) of the BaseGenerator.
  • Loading branch information
kzu committed Dec 21, 2024
1 parent c526032 commit 6e401e2
Show file tree
Hide file tree
Showing 9 changed files with 187 additions and 73 deletions.
11 changes: 6 additions & 5 deletions src/StructId.Analyzer/BaseGenerator.cs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
using System.Linq;
using System.Text;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.Text;

namespace StructId;
Expand Down Expand Up @@ -70,11 +69,13 @@ public virtual void Initialize(IncrementalGeneratorInitializationContext context

protected virtual IncrementalValuesProvider<TemplateArgs> OnInitialize(IncrementalGeneratorInitializationContext context, IncrementalValuesProvider<TemplateArgs> source) => source;

void GenerateCode(SourceProductionContext context, TemplateArgs args) => AddFromTemplate(
context, args, $"{args.TSelf.ToFileName()}.cs",
args.TId.Equals(args.KnownTypes.String, SymbolEqualityComparer.Default) ?
void GenerateCode(SourceProductionContext context, TemplateArgs args)
=> AddFromTemplate(context, args, $"{args.TSelf.ToFileName()}.cs", SelectTemplate(args));

protected virtual SyntaxNode SelectTemplate(TemplateArgs args)
=> args.TId.Equals(args.KnownTypes.String, SymbolEqualityComparer.Default) ?
(stringSyntax ??= CodeTemplate.Parse(stringTemplate, args.KnownTypes.Compilation.GetParseOptions())) :
(typedSyntax ??= CodeTemplate.Parse(typeTemplate, args.KnownTypes.Compilation.GetParseOptions())));
(typedSyntax ??= CodeTemplate.Parse(typeTemplate, args.KnownTypes.Compilation.GetParseOptions()));

protected static void AddFromTemplate(SourceProductionContext context, TemplateArgs args, string hintName, SyntaxNode template)
{
Expand Down
130 changes: 95 additions & 35 deletions src/StructId.Analyzer/EntityFrameworkGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,82 +2,142 @@
using System.Collections.Immutable;
using System.Linq;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Scriban;
using static StructId.AnalysisExtensions;

namespace StructId;

[Generator(LanguageNames.CSharp)]
public class EntityFrameworkGenerator() : BaseGenerator(
"Microsoft.EntityFrameworkCore.Storage.ValueConversion.ValueConverter`2",
ValueConverterType,
ThisAssembly.Resources.Templates.EntityFramework.Text,
ThisAssembly.Resources.Templates.EntityFramework.Text,
ReferenceCheck.TypeExists)
{
static readonly Template template = Template.Parse(ThisAssembly.Resources.EntityFrameworkSelector.Text);
public const string ValueConverterType = "Microsoft.EntityFrameworkCore.Storage.ValueConversion.ValueConverter`2";

static readonly Dictionary<string, string> builtInTypesMap = new()
{
["System.String"] = "string",
["System.Int32"] = "int",
["System.Int64"] = "long",
["System.Boolean"] = "bool",
["System.Single"] = "float",
["System.Double"] = "double",
["System.Decimal"] = "decimal",
["System.DateTime"] = "DateTime",
["System.Guid"] = "Guid",
["System.TimeSpan"] = "TimeSpan",
["System.Byte"] = "byte",
["System.Byte[]"] = "byte[]",
["System.Char"] = "char",
["System.UInt32"] = "uint",
["System.UInt64"] = "ulong",
["System.SByte"] = "sbyte",
["System.UInt16"] = "ushort",
["System.Int16"] = "short",
["System.Object"] = "object",
};

static readonly Template selectorTemplate = Template.Parse(ThisAssembly.Resources.EntityFrameworkSelector.Text);

SyntaxNode? idTemplate;
SyntaxNode? parsableIdTemplate;

protected override IncrementalValuesProvider<TemplateArgs> OnInitialize(IncrementalGeneratorInitializationContext context, IncrementalValuesProvider<TemplateArgs> source)
{
var converters = context.CompilationProvider
.SelectMany((x, _) => x.Assembly.GetAllTypes().OfType<INamedTypeSymbol>())
.Combine(context.CompilationProvider.Select((x, _) => x.GetTypeByMetadataName("Microsoft.EntityFrameworkCore.Storage.ValueConversion.ValueConverter`2")))
.Combine(context.CompilationProvider.Select((x, _) => x.GetTypeByMetadataName(ValueConverterType)))
.Where(x => x.Left != null && x.Right != null &&
x.Left.Is(x.Right) &&
!x.Left.IsUnboundGenericType &&
x.Left.BaseType?.TypeArguments.Length == 2 &&
// Don't emit as plain converters if they are id templates
// Don't emit as plain converters if they are value templates
!x.Left.GetAttributes().Any(a => a.IsValueTemplate()))
.Select((x, _) => x.Left)
.Collect();

context.RegisterSourceOutput(source.Collect().Combine(converters), GenerateValueSelector);
var templatizedValues = context.SelectTemplatizedValues()
.Combine(context.CompilationProvider.Select((x, _) => x.GetTypeByMetadataName(ValueConverterType)))
.Where(x => x.Left.Template.TTemplate.Is(x.Right))
.Select((x, _) => x.Left);

context.RegisterSourceOutput(source.Collect().Combine(converters).Combine(templatizedValues.Collect()), GenerateValueSelector);

return base.OnInitialize(context, source);
}

void GenerateValueSelector(SourceProductionContext context, (ImmutableArray<TemplateArgs>, ImmutableArray<INamedTypeSymbol>) args)
protected override SyntaxNode SelectTemplate(TemplateArgs args)
{
if (args.TId.Equals(args.KnownTypes.String, SymbolEqualityComparer.Default) ||
builtInTypesMap.ContainsKey(args.TId.ToDisplayString(NamespacedTypeName)))
return idTemplate ??= CodeTemplate.Parse(ThisAssembly.Resources.Templates.EntityFramework.Text, args.KnownTypes.Compilation.GetParseOptions());
else if (args.TId.Is(args.KnownTypes.Compilation.GetTypeByMetadataName("System.IParsable`1")) &&
args.TId.Is(args.KnownTypes.Compilation.GetTypeByMetadataName("System.IFormattable")))
return parsableIdTemplate ??= CodeTemplate.Parse(ThisAssembly.Resources.Templates.EntityFrameworkParsable.Text, args.KnownTypes.Compilation.GetParseOptions());
else
return idTemplate ??= CodeTemplate.Parse(ThisAssembly.Resources.Templates.EntityFramework.Text, args.KnownTypes.Compilation.GetParseOptions());
}

void GenerateValueSelector(SourceProductionContext context, ((ImmutableArray<TemplateArgs>, ImmutableArray<INamedTypeSymbol>), ImmutableArray<TValueTemplate>) args)
{
(var ids, var converters) = args;
((var structIds, var customConverters), var templatizedConverters) = args;

if (ids.Length == 0)
if (structIds.Length == 0 && customConverters.Length == 0 && templatizedConverters.Length == 0)
return;

var model = new SelectorModel(
ids.Select(x => new StructIdModel(x.TSelf.ToFullName(), x.TId.ToFullName())),
converters.Select(x => new ConverterModel(x.BaseType!.TypeArguments[0].ToFullName(), x.BaseType!.TypeArguments[1].ToFullName(), x.ToFullName())));
structIds.Select(x => new StructIdModel(x.TSelf.ToFullName(),
// The TId is used as the ProviderClrType for EF, which should be either a built-in
// supported type or a parsable one. We default to using the type as-is for future-proofing,
// but that may be subject to change.
!builtInTypesMap.ContainsKey(x.TId.ToDisplayString(NamespacedTypeName))
? x.TId.Is(x.KnownTypes.Compilation.GetTypeByMetadataName("System.IParsable`1")) &&
x.TId.Is(x.KnownTypes.Compilation.GetTypeByMetadataName("System.IFormattable"))
// parsable+formattable will result in the parsable template being used as the converter
// so we use string as the underlying EF type.
? "string" : x.TId.ToFullName()
: x.TId.ToFullName())),
customConverters.Select(x => new ConverterModel(x.BaseType!.TypeArguments[0].ToFullName(), x.BaseType!.TypeArguments[1].ToFullName(), x.ToFullName())),
templatizedConverters
.Where(x => !builtInTypesMap.ContainsKey(x.TValue.ToDisplayString(NamespacedTypeName)))
.Select(x => new TemplatizedModel(x)));

var output = template.Render(model, member => member.Name);
var output = selectorTemplate.Render(model, member => member.Name);

context.AddSource($"ValueConverterSelector.cs", output);
}

record StructIdModel(string TSelf, string TIdType)
{
public string TId => TIdType switch
{
"System.String" => "string",
"System.Int32" => "int",
"System.Int64" => "long",
"System.Boolean" => "bool",
"System.Single" => "float",
"System.Double" => "double",
"System.Decimal" => "decimal",
"System.DateTime" => "DateTime",
"System.Guid" => "Guid",
"System.TimeSpan" => "TimeSpan",
"System.Byte" => "byte",
"System.Byte[]" => "byte[]",
"System.Char" => "char",
"System.UInt32" => "uint",
"System.UInt64" => "ulong",
"System.SByte" => "sbyte",
"System.UInt16" => "ushort",
"System.Int16" => "short",
"System.Object" => "object",
_ => TIdType
};
public string TId => builtInTypesMap.TryGetValue(TIdType, out var value) ? value : TIdType;
}

record ConverterModel(string TModel, string TProvider, string TConverter);

record SelectorModel(IEnumerable<StructIdModel> Ids, IEnumerable<ConverterModel> Converters);
class TemplatizedModel
{
public TemplatizedModel(TValueTemplate template)
{
var declaration = template.Template.Syntax.ApplyValue(template.TValue)
.DescendantNodes()
.OfType<TypeDeclarationSyntax>()
.First();

TModel = template.TValue.ToFullName();
TConverter = declaration.Identifier.Text;
Code = declaration.ToFullString();
}

public TemplatizedModel(string tvalue, string tconverter, string code)
=> (TModel, TConverter, Code) = (tvalue, tconverter, code);

public string TModel { get; }
public string TConverter { get; }
public string Code { get; }
}

record SelectorModel(IEnumerable<StructIdModel> Ids, IEnumerable<ConverterModel> Converters, IEnumerable<TemplatizedModel> Templatized);
}
17 changes: 16 additions & 1 deletion src/StructId.Analyzer/EntityFrameworkSelector.sbn
Original file line number Diff line number Diff line change
Expand Up @@ -39,20 +39,30 @@ public static class StructIdDbContextOptionsBuilderExtensions
modelClrType = Unwrap(modelClrType) ?? modelClrType;
providerClrType = Unwrap(providerClrType);

// Struct ID converters
{{~ for id in Ids ~}}
if (modelClrType == typeof({{ id.TSelf }}))
yield return converters.GetOrAdd((modelClrType, providerClrType), key => new ValueConverterInfo(
key.ModelClrType, key.ProviderClrType ?? typeof({{ id.TId }}),
info => new {{ id.TSelf }}.EntityFrameworkValueConverter(info.MappingHints)));

{{~ end ~}}
// Custom EF converters
{{~ for converter in Converters ~}}
if (modelClrType == typeof({{ converter.TModel }}))
yield return converters.GetOrAdd((modelClrType, providerClrType), key => new ValueConverterInfo(
key.ModelClrType, key.ProviderClrType ?? typeof({{ converter.TProvider }}),
info => new {{ converter.TConverter }}(info.MappingHints)));

{{~ end ~}}
// Templatized converters
{{~ for converter in Templatized ~}}
if (modelClrType == typeof({{ converter.TModel }}))
yield return converters.GetOrAdd((modelClrType, providerClrType), key => new ValueConverterInfo(
key.ModelClrType, key.ProviderClrType ?? typeof(string),
info => new {{ converter.TConverter }}(info.MappingHints)));

{{~ end ~}}
}

static Type? Unwrap(Type? type)
Expand All @@ -63,4 +73,9 @@ public static class StructIdDbContextOptionsBuilderExtensions
return Nullable.GetUnderlyingType(type) ?? type;
}
}
}
}

// Templatized converters
{{~ for converter in Templatized ~}}
{{ converter.Code }}
{{~ end ~}}
2 changes: 1 addition & 1 deletion src/StructId.FunctionalTests/Functional.cs
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,8 @@ public void EntityFramework()
var product = new Product(new ProductId(id), "Product");

// Seed data
context.Products.Add(new Product(ProductId.New(), "Product1"));
context.Products.Add(product);
context.Products.Add(new Product(ProductId.New(), "Product1"));
context.Products.Add(new Product(ProductId.New(), "Product2"));

context.SaveChanges();
Expand Down
30 changes: 16 additions & 14 deletions src/StructId.FunctionalTests/UlidTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,15 @@ public override void SetValue(IDbDataParameter parameter, Ulid value)
}
}

public partial class UlidToStringConverter : ValueConverter<Ulid, string>
{
public UlidToStringConverter() : this(null) { }
// showcases a custom EF value converter trumps the built-in templatized
// support for types that provide IParsable<T> and IFormattable
//public partial class UlidToStringConverter : ValueConverter<Ulid, string>
//{
// public UlidToStringConverter() : this(null) { }

public UlidToStringConverter(ConverterMappingHints? mappingHints = null)
: base(id => id.ToString(), value => Ulid.Parse(value), mappingHints) { }
}
// public UlidToStringConverter(ConverterMappingHints? mappingHints = null)
// : base(id => id.ToString(), value => Ulid.Parse(value), mappingHints) { }
//}

// showcases alternative serialization
//public class BinaryUlidHandler : TypeHandler<Ulid>
Expand Down Expand Up @@ -126,8 +128,8 @@ public void EntityFramework()
var product = new UlidProduct(new UlidId(id), "Product");

// Seed data
context.Products.Add(new UlidProduct(UlidId.New(), "Product1"));
context.Products.Add(product);
context.Products.Add(new UlidProduct(UlidId.New(), "Product1"));
context.Products.Add(new UlidProduct(UlidId.New(), "Product2"));

context.SaveChanges();
Expand All @@ -152,12 +154,12 @@ public class UlidContext : DbContext
public UlidContext(DbContextOptions<UlidContext> options) : base(options) { }
public DbSet<UlidProduct> Products { get; set; } = null!;

protected override void OnModelCreating(ModelBuilder modelBuilder)
{
modelBuilder.Entity<UlidProduct>().Property(x => x.Id)
//.HasConversion(new UlidToStringConverter())
.HasConversion(new UlidId.EntityFrameworkUlidValueConverter());
//.HasConversion(new UlidId.EntityFrameworkValueConverter());
}
//protected override void OnModelCreating(ModelBuilder modelBuilder)
//{
// //modelBuilder.Entity<UlidProduct>().Property(x => x.Id)
// //.HasConversion(new UlidToStringConverter())
// //.HasConversion(new UlidId.EntityFrameworkUlidValueConverter());
// //.HasConversion(new UlidId.EntityFrameworkValueConverter());
//}
}
}
8 changes: 4 additions & 4 deletions src/StructId/ResourceTemplates/EntityFramework.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@
using StructId;

[TStructId]
file partial record struct TSelf(TId Value) : INewable<TSelf, TId>
file partial record struct TSelf(TValue Value) : INewable<TSelf, TValue>
{
/// <summary>
/// Provides value conversion for Entity Framework Core
/// </summary>
public partial class EntityFrameworkValueConverter : ValueConverter<TSelf, TId>
public partial class EntityFrameworkValueConverter : ValueConverter<TSelf, TValue>
{
public EntityFrameworkValueConverter() : this(null) { }

Expand All @@ -21,7 +21,7 @@ public EntityFrameworkValueConverter(ConverterMappingHints? mappingHints = null)

file partial record struct TSelf
{
public static TSelf New(TId value) => throw new System.NotImplementedException();
public static TSelf New(TValue value) => throw new System.NotImplementedException();
}

file partial record struct TId;
file partial record struct TValue;
34 changes: 34 additions & 0 deletions src/StructId/ResourceTemplates/EntityFrameworkParsable.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
// <auto-generated />
#nullable enable

using System.Diagnostics.CodeAnalysis;
using System;
using Microsoft.EntityFrameworkCore.Storage.ValueConversion;
using StructId;

[TStructId]
file partial record struct TSelf(TValue Value) : INewable<TSelf, TValue>
{
/// <summary>
/// Provides value conversion for Entity Framework Core
/// </summary>
public partial class EntityFrameworkValueConverter : ValueConverter<TSelf, string>
{
public EntityFrameworkValueConverter() : this(null) { }

public EntityFrameworkValueConverter(ConverterMappingHints? mappingHints = null)
: base(id => id.Value.ToString(null, null), value => TSelf.New(TValue.Parse(value, null)), mappingHints) { }
}
}

file partial record struct TSelf
{
public static TSelf New(TValue value) => throw new System.NotImplementedException();
}

file partial struct TValue : IParsable<TValue>, IFormattable
{
public static TValue Parse(string s, IFormatProvider? provider) => throw new NotImplementedException();
public static bool TryParse([NotNullWhen(true)] string? s, IFormatProvider? provider, [MaybeNullWhen(false)] out TValue result) => throw new NotImplementedException();
public string ToString(string? format, IFormatProvider? formatProvider) => throw new NotImplementedException();
}
Loading

0 comments on commit 6e401e2

Please sign in to comment.