Skip to content

Commit

Permalink
fix: do not generalise enumerable dictionary value types (#1155)
Browse files Browse the repository at this point in the history
  • Loading branch information
latonz authored Mar 9, 2024
1 parent da4115c commit b4d5e74
Show file tree
Hide file tree
Showing 9 changed files with 196 additions and 65 deletions.
27 changes: 0 additions & 27 deletions src/Riok.Mapperly/Descriptors/Enumerables/CollectionInfo.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
using System.Diagnostics.CodeAnalysis;
using Microsoft.CodeAnalysis;
using Riok.Mapperly.Helpers;

namespace Riok.Mapperly.Descriptors.Enumerables;

Expand All @@ -22,30 +21,4 @@ bool IsImmutableCollectionType

[MemberNotNullWhen(true, nameof(CountPropertyName))]
public bool CountIsKnown => CountPropertyName != null;

public (ITypeSymbol, ITypeSymbol)? GetDictionaryKeyValueTypes(MappingBuilderContext ctx)
{
if (Type.ImplementsGeneric(ctx.Types.Get(typeof(IDictionary<,>)), out var dictionaryImpl))
{
return (dictionaryImpl.TypeArguments[0], dictionaryImpl.TypeArguments[1]);
}

if (Type.ImplementsGeneric(ctx.Types.Get(typeof(IReadOnlyDictionary<,>)), out var readOnlyDictionaryImpl))
{
return (readOnlyDictionaryImpl.TypeArguments[0], readOnlyDictionaryImpl.TypeArguments[1]);
}

return null;
}

public (ITypeSymbol, ITypeSymbol)? GetEnumeratedKeyValueTypes(WellKnownTypes types)
{
if (
EnumeratedType is not INamedTypeSymbol namedEnumeratedType
|| !SymbolEqualityComparer.Default.Equals(namedEnumeratedType.ConstructedFrom, types.Get(typeof(KeyValuePair<,>)))
)
return null;

return (namedEnumeratedType.TypeArguments[0], namedEnumeratedType.TypeArguments[1]);
}
}
5 changes: 5 additions & 0 deletions src/Riok.Mapperly/Descriptors/Enumerables/DictionaryInfo.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
using Microsoft.CodeAnalysis;

namespace Riok.Mapperly.Descriptors.Enumerables;

public record DictionaryInfo(CollectionInfo Collection, ITypeSymbol Key, ITypeSymbol Value);
65 changes: 65 additions & 0 deletions src/Riok.Mapperly/Descriptors/Enumerables/DictionaryInfoBuilder.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
using Microsoft.CodeAnalysis;
using Riok.Mapperly.Helpers;

namespace Riok.Mapperly.Descriptors.Enumerables;

public static class DictionaryInfoBuilder
{
public static DictionaryInfos? Build(WellKnownTypes types, CollectionInfos? collectionInfos)
{
if (collectionInfos == null)
return null;

var source = BuildSource(types, collectionInfos);
if (source == null)
return null;

var target = BuildTarget(types, collectionInfos);
if (target == null)
return null;

return new DictionaryInfos(source, target);
}

private static DictionaryInfo? BuildSource(WellKnownTypes types, CollectionInfos infos)
{
if (GetEnumeratedKeyValueTypes(types, infos.Source) is not var (key, value))
return null;

return new DictionaryInfo(infos.Source, key, value);
}

private static DictionaryInfo? BuildTarget(WellKnownTypes types, CollectionInfos infos)
{
if (GetDictionaryKeyValueTypes(types, infos.Target) is not var (key, value))
return null;

return new DictionaryInfo(infos.Target, key, value);
}

private static (ITypeSymbol, ITypeSymbol)? GetDictionaryKeyValueTypes(WellKnownTypes types, CollectionInfo info)
{
if (info.Type.ImplementsGeneric(types.Get(typeof(IDictionary<,>)), out var dictionaryImpl))
{
return (dictionaryImpl.TypeArguments[0], dictionaryImpl.TypeArguments[1]);
}

if (info.Type.ImplementsGeneric(types.Get(typeof(IReadOnlyDictionary<,>)), out var readOnlyDictionaryImpl))
{
return (readOnlyDictionaryImpl.TypeArguments[0], readOnlyDictionaryImpl.TypeArguments[1]);
}

return null;
}

private static (ITypeSymbol, ITypeSymbol)? GetEnumeratedKeyValueTypes(WellKnownTypes types, CollectionInfo info)
{
if (
info.EnumeratedType is not INamedTypeSymbol namedEnumeratedType
|| !SymbolEqualityComparer.Default.Equals(namedEnumeratedType.ConstructedFrom, types.Get(typeof(KeyValuePair<,>)))
)
return null;

return (namedEnumeratedType.TypeArguments[0], namedEnumeratedType.TypeArguments[1]);
}
}
3 changes: 3 additions & 0 deletions src/Riok.Mapperly/Descriptors/Enumerables/DictionaryInfos.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
namespace Riok.Mapperly.Descriptors.Enumerables;

public record DictionaryInfos(DictionaryInfo Source, DictionaryInfo Target);
3 changes: 3 additions & 0 deletions src/Riok.Mapperly/Descriptors/MappingBuilderContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ public class MappingBuilderContext : SimpleMappingBuilderContext
{
private readonly FormatProviderCollection _formatProviders;
private CollectionInfos? _collectionInfos;
private DictionaryInfos? _dictionaryInfos;

public MappingBuilderContext(
SimpleMappingBuilderContext parentCtx,
Expand Down Expand Up @@ -58,6 +59,8 @@ bool ignoreDerivedTypes

public CollectionInfos? CollectionInfos => _collectionInfos ??= CollectionInfoBuilder.Build(Types, SymbolAccessor, Source, Target);

public DictionaryInfos? DictionaryInfos => _dictionaryInfos ??= DictionaryInfoBuilder.Build(Types, CollectionInfos);

protected IMethodSymbol? UserSymbol { get; }

public bool HasUserSymbol => UserSymbol != null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ public static class DictionaryMappingBuilder
if (!ctx.IsConversionEnabled(MappingConversionType.Dictionary))
return null;

if (ctx.CollectionInfos == null)
if (ctx.DictionaryInfos == null)
return null;

if (BuildKeyValueMapping(ctx) is not var (keyMapping, valueMapping))
Expand Down Expand Up @@ -69,12 +69,7 @@ INewInstanceMapping valueMapping
var sourceCollectionInfo = ctx.CollectionInfos.Source;
if (!hasObjectFactory)
{
sourceCollectionInfo = BuildCollectionTypeForIDictionary(
ctx,
ctx.CollectionInfos!.Source,
keyMapping.SourceType,
valueMapping.SourceType
);
sourceCollectionInfo = BuildCollectionTypeForSourceIDictionary(ctx);
ctx.ObjectFactories.TryFindObjectFactory(ctx.Source, ctx.Target, out objectFactory);

var existingMapping = ctx.BuildDelegatedMapping(sourceCollectionInfo.Type, ctx.Target);
Expand Down Expand Up @@ -116,16 +111,10 @@ INewInstanceMapping valueMapping
var targetType = ctx.Target;
if (!hasObjectFactory)
{
sourceCollectionInfo = BuildCollectionTypeForIDictionary(
ctx,
sourceCollectionInfo,
keyMapping.SourceType,
valueMapping.SourceType
);

sourceCollectionInfo = BuildCollectionTypeForSourceIDictionary(ctx);
targetType = ctx
.Types.Get(typeof(Dictionary<,>))
.Construct(keyMapping.TargetType, valueMapping.TargetType)
.Construct(ctx.DictionaryInfos!.Target.Key, ctx.DictionaryInfos.Target.Value)
.WithNullableAnnotation(NullableAnnotation.NotAnnotated);

ctx.ObjectFactories.TryFindObjectFactory(sourceCollectionInfo.Type, targetType, out objectFactory);
Expand Down Expand Up @@ -181,17 +170,11 @@ INewInstanceMapping valueMapping

private static (INewInstanceMapping, INewInstanceMapping)? BuildKeyValueMapping(MappingBuilderContext ctx)
{
if (ctx.CollectionInfos!.Target.GetDictionaryKeyValueTypes(ctx) is not var (targetKeyType, targetValueType))
return null;

if (ctx.CollectionInfos.Source.GetEnumeratedKeyValueTypes(ctx.Types) is not var (sourceKeyType, sourceValueType))
return null;

var keyMapping = ctx.FindOrBuildMapping(sourceKeyType, targetKeyType);
var keyMapping = ctx.FindOrBuildMapping(ctx.DictionaryInfos!.Source.Key, ctx.DictionaryInfos.Target.Key);
if (keyMapping == null)
return null;

var valueMapping = ctx.FindOrBuildMapping(sourceValueType, targetValueType);
var valueMapping = ctx.FindOrBuildMapping(ctx.DictionaryInfos.Source.Value, ctx.DictionaryInfos.Target.Value);
if (valueMapping == null)
return null;

Expand Down Expand Up @@ -260,38 +243,33 @@ or CollectionType.IImmutableDictionary
};
}

private static CollectionInfo BuildCollectionTypeForIDictionary(
MappingBuilderContext ctx,
CollectionInfo info,
ITypeSymbol key,
ITypeSymbol value
)
private static CollectionInfo BuildCollectionTypeForSourceIDictionary(MappingBuilderContext ctx)
{
var info = ctx.CollectionInfos!.Source;

// the types cannot be changed for mappings with a user symbol
// as the types are defined by the user
if (ctx.HasUserSymbol)
return info;

var dictionaryType = info.ImplementedTypes.HasFlag(CollectionType.IReadOnlyDictionary)
? BuildDictionaryType(ctx, CollectionType.IReadOnlyDictionary, key, value)
? BuildSourceDictionaryType(ctx, CollectionType.IReadOnlyDictionary)
: info.ImplementedTypes.HasFlag(CollectionType.IDictionary)
? BuildDictionaryType(ctx, CollectionType.IDictionary, key, value)
? BuildSourceDictionaryType(ctx, CollectionType.IDictionary)
: null;

return dictionaryType == null
? info
: CollectionInfoBuilder.BuildCollectionInfo(ctx.Types, ctx.SymbolAccessor, dictionaryType, info.EnumeratedType);
}

private static INamedTypeSymbol BuildDictionaryType(
MappingBuilderContext ctx,
CollectionType type,
ITypeSymbol keyType,
ITypeSymbol valueType
)
private static INamedTypeSymbol BuildSourceDictionaryType(MappingBuilderContext ctx, CollectionType type)
{
var genericType = CollectionInfoBuilder.GetGenericClrCollectionType(type);
return (INamedTypeSymbol)
ctx.Types.Get(genericType).Construct(keyType, valueType).WithNullableAnnotation(NullableAnnotation.NotAnnotated);
ctx
.Types.Get(genericType)
.Construct(ctx.DictionaryInfos!.Source.Key, ctx.DictionaryInfos!.Source.Value)
.WithNullableAnnotation(NullableAnnotation.NotAnnotated);
}
}
28 changes: 28 additions & 0 deletions test/Riok.Mapperly.Tests/Mapping/DictionaryTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,34 @@ public void MapToExistingCustomDictionary()
);
}

[Fact]
public Task DictionaryWithList()
{
var source = TestSourceBuilder.Mapping(
"A",
"B",
"record A(Dictionary<int, List<C>> Dict);",
"record B(Dictionary<int, List<D>> Dict);",
"record C(int Value);",
"record D(int Value);"
);
return TestHelper.VerifyGenerator(source);
}

[Fact]
public Task CustomDictionaryWithList()
{
var source = TestSourceBuilder.Mapping(
"A",
"B",
"class A : Dictionary<int, List<C>>;",
"class B : Dictionary<int, List<D>>;",
"record C(int Value);",
"record D(int Value);"
);
return TestHelper.VerifyGenerator(source);
}

[Fact]
public void KeyValueEnumerableToExistingDictionary()
{
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
//HintName: Mapper.g.cs
// <auto-generated />
#nullable enable
public partial class Mapper
{
[global::System.CodeDom.Compiler.GeneratedCode("Riok.Mapperly", "0.0.1.0")]
private partial global::B Map(global::A source)
{
var target = new global::B();
target.EnsureCapacity(source.Count + target.Count);
foreach (var item in source)
{
target[item.Key] = MapToList(item.Value);
}
return target;
}

[global::System.CodeDom.Compiler.GeneratedCode("Riok.Mapperly", "0.0.1.0")]
private global::D MapToD(global::C source)
{
var target = new global::D(source.Value);
return target;
}

[global::System.CodeDom.Compiler.GeneratedCode("Riok.Mapperly", "0.0.1.0")]
private global::System.Collections.Generic.List<global::D> MapToList(global::System.Collections.Generic.IReadOnlyCollection<global::C> source)
{
var target = new global::System.Collections.Generic.List<global::D>(source.Count);
foreach (var item in source)
{
target.Add(MapToD(item));
}
return target;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
//HintName: Mapper.g.cs
// <auto-generated />
#nullable enable
public partial class Mapper
{
[global::System.CodeDom.Compiler.GeneratedCode("Riok.Mapperly", "0.0.1.0")]
private partial global::B Map(global::A source)
{
var target = new global::B(MapToDictionary(source.Dict));
return target;
}

[global::System.CodeDom.Compiler.GeneratedCode("Riok.Mapperly", "0.0.1.0")]
private global::D MapToD(global::C source)
{
var target = new global::D(source.Value);
return target;
}

[global::System.CodeDom.Compiler.GeneratedCode("Riok.Mapperly", "0.0.1.0")]
private global::System.Collections.Generic.List<global::D> MapToList(global::System.Collections.Generic.IReadOnlyCollection<global::C> source)
{
var target = new global::System.Collections.Generic.List<global::D>(source.Count);
foreach (var item in source)
{
target.Add(MapToD(item));
}
return target;
}

[global::System.CodeDom.Compiler.GeneratedCode("Riok.Mapperly", "0.0.1.0")]
private global::System.Collections.Generic.Dictionary<int, global::System.Collections.Generic.List<global::D>> MapToDictionary(global::System.Collections.Generic.IReadOnlyDictionary<int, global::System.Collections.Generic.List<global::C>> source)
{
var target = new global::System.Collections.Generic.Dictionary<int, global::System.Collections.Generic.List<global::D>>(source.Count);
foreach (var item in source)
{
target[item.Key] = MapToList(item.Value);
}
return target;
}
}

0 comments on commit b4d5e74

Please sign in to comment.