Skip to content

Commit

Permalink
Unwrap union (#176)
Browse files Browse the repository at this point in the history
* Implement Unwrap method for variants.
  • Loading branch information
domn1995 authored Oct 26, 2023
1 parent 8063fe5 commit 8eed569
Show file tree
Hide file tree
Showing 5 changed files with 184 additions and 14 deletions.
2 changes: 1 addition & 1 deletion src/UnionGeneration/UnionDeclaration.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ ImmutableEquatableArray<Property> Properties
{
// Extension methods cannot be generated for a union declared in a top level program (no namespace).
// It also doesn't make sense to generate Match extensions if there are no variants to match against.
public bool SupportsAsyncMatchExtensionMethods() => Namespace is not null && Variants.Count > 0;
public bool SupportsExtensionMethods() => Namespace is not null && Variants.Count > 0;

public bool SupportsImplicitConversions()
{
Expand Down
7 changes: 4 additions & 3 deletions src/UnionGeneration/UnionGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ public void Initialize(IncrementalGeneratorInitializationContext context)

var parsedModel = compilation.Select(static (x, token) => Parse(x.Left, x.Right, token));
var splitModel = parsedModel.SelectMany(static (result, _) => result);

context.RegisterSourceOutput(splitModel, Emit);
}

Expand Down Expand Up @@ -53,7 +53,7 @@ private static void Emit(SourceProductionContext context, UnionDeclaration union
return;
}

if (unionRecord.SupportsAsyncMatchExtensionMethods())
if (unionRecord.SupportsExtensionMethods())
{
var matchExtensions = UnionExtensionsSourceBuilder.GenerateExtensions(unionRecord);
context.AddSource(
Expand Down Expand Up @@ -100,7 +100,8 @@ CancellationToken cancellation
Namespace: @namespace,
Accessibility: recordSymbol.DeclaredAccessibility,
Name: recordSymbol.Name,
TypeParameters: typeParameters?.ToImmutableEquatableArray() ?? ImmutableEquatableArray.Empty<TypeParameter>(),
TypeParameters: typeParameters?.ToImmutableEquatableArray()
?? ImmutableEquatableArray.Empty<TypeParameter>(),
TypeParameterConstraints: typeParameterConstraints.ToImmutableEquatableArray(),
Variants: variants.ToImmutableEquatableArray(),
ParentTypes: parentTypes.ToImmutableEquatableArray(),
Expand Down
73 changes: 66 additions & 7 deletions src/UnionGeneration/UnionSourceBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ public static string Build(UnionDeclaration union)

builder.AppendAbstractMatchMethods(union);
builder.AppendAbstractSpecificMatchMethods(union);
builder.AppendAbstractUnwrapMethods(union);

if (union.SupportsImplicitConversions())
{
Expand All @@ -62,6 +63,8 @@ public static string Build(UnionDeclaration union)

builder.AppendVariantMatchMethodImplementations(union, variant);
builder.AppendVariantSpecificMatchMethodImplementations(union, variant);
builder.AppendUnwrapMethodImplementations(union, variant);
builder.AppendLine(" }");
}

builder.AppendLine("}");
Expand Down Expand Up @@ -218,7 +221,9 @@ UnionDeclaration union
// System.Func<TState, Specific<T1, T2, ...>, TMatchOutput> @specific,
// System.Func<TState, TMatchOutput> @else
// );
builder.AppendLine($" public abstract TMatchOutput Match{variant.Identifier}<TState, TMatchOutput>(");
builder.AppendLine(
$" public abstract TMatchOutput Match{variant.Identifier}<TState, TMatchOutput>("
);
builder.Append($" TState state");
builder.AppendLine(union.Variants.Count > 0 ? "," : string.Empty);
builder.Append($" System.Func<TState, {variant.Identifier}");
Expand Down Expand Up @@ -319,7 +324,9 @@ VariantDeclaration variant
}
builder.AppendLine();
}
builder.AppendLine($" ) => {variant.Identifier.ToMethodParameterCase()}(state, this);");
builder.AppendLine(
$" ) => {variant.Identifier.ToMethodParameterCase()}(state, this);"
);

// public override void Match<TState>(
// TState state,
Expand All @@ -342,7 +349,9 @@ VariantDeclaration variant
}
builder.AppendLine();
}
builder.AppendLine($" ) => {variant.Identifier.ToMethodParameterCase()}(state, this);");
builder.AppendLine(
$" ) => {variant.Identifier.ToMethodParameterCase()}(state, this);"
);

return builder;
}
Expand Down Expand Up @@ -425,7 +434,9 @@ VariantDeclaration variant
builder.Append(" ) => ");
if (specificVariant.Identifier == variant.Identifier)
{
builder.AppendLine($"{specificVariant.Identifier.ToMethodParameterCase()}(state, this);");
builder.AppendLine(
$"{specificVariant.Identifier.ToMethodParameterCase()}(state, this);"
);
}
else
{
Expand All @@ -441,7 +452,9 @@ VariantDeclaration variant
// ) => unionVariantX(state, this);
foreach (var specificVariant in union.Variants)
{
builder.AppendLine($" public override void Match{specificVariant.Identifier}<TState>(");
builder.AppendLine(
$" public override void Match{specificVariant.Identifier}<TState>("
);
builder.Append($" TState state");
builder.AppendLine(union.Variants.Count > 0 ? "," : string.Empty);
builder.Append($" System.Action<TState, {specificVariant.Identifier}");
Expand All @@ -451,17 +464,63 @@ VariantDeclaration variant
builder.Append(" ) => ");
if (specificVariant.Identifier == variant.Identifier)
{
builder.AppendLine($"{specificVariant.Identifier.ToMethodParameterCase()}(state, this);");
builder.AppendLine(
$"{specificVariant.Identifier.ToMethodParameterCase()}(state, this);"
);
}
else
{
builder.AppendLine("@else(state);");
}
}

builder.AppendLine(" }");
builder.AppendLine();

return builder;
}

private static StringBuilder AppendAbstractUnwrapMethods(
this StringBuilder builder,
UnionDeclaration union
)
{
foreach (var variant in union.Variants)
{
// public abstract Variant UnwrapVariant();
builder.AppendLine(
$" public abstract {variant.Identifier} Unwrap{variant.Identifier}();"
);
}

builder.AppendLine();

return builder;
}

private static StringBuilder AppendUnwrapMethodImplementations(
this StringBuilder builder,
UnionDeclaration union,
VariantDeclaration variant
)
{
foreach (var specificVariant in union.Variants)
{
builder.Append(
$" public override {specificVariant.Identifier} Unwrap{specificVariant.Identifier}() => "
);

if (specificVariant.Identifier == variant.Identifier)
{
builder.AppendLine("this;");
}
else
{
builder.AppendLine(
$"throw new System.InvalidOperationException($\"Called `{union.Name}.Unwrap{specificVariant.Identifier}()` on `{variant.Identifier}` value.\");"
);
}
}

return builder;
}
}
4 changes: 1 addition & 3 deletions test/Runtime/AssemblyExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,14 @@ namespace Dunet.Test.Runtime;
/// </remarks>
internal static class AssemblyExtensions
{
public static T? ExecuteStaticMethod<T>(this Assembly assembly, string methodName)
where T : notnull =>
public static T? ExecuteStaticMethod<T>(this Assembly assembly, string methodName) =>
(T?)
assembly.DefinedTypes
.SelectMany(type => type.DeclaredMethods)
.FirstOrDefault(method => method.Name.Contains(methodName))
?.Invoke(null, null);

public static T? ExecuteStaticAsyncMethod<T>(this Assembly assembly, string methodName)
where T : notnull
{
var task =
assembly.DefinedTypes
Expand Down
112 changes: 112 additions & 0 deletions test/UnionGeneration/UnwrapTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
using System.Reflection;

namespace Dunet.Test.UnionGeneration;

public sealed class UnwrapTests
{
[Fact]
public void CanUseUnwrapMethodToUnsafelyGetVariantValue()
{
// Arrange.
var programCs = """
using Dunet;

var value = GetValue();

static int GetValue()
{
var option = new Option.Some(1);
return option.UnwrapSome().Value;
}

[Union]
public partial record Option
{
public partial record Some(int Value);
public partial record None;
}
""";

// Act.
var compilation = Compiler.Compile(programCs);
var value = compilation.Assembly?.ExecuteStaticMethod<int>("GetValue");

// Assert.
using var scope = new AssertionScope();
compilation.CompilationErrors.Should().BeEmpty();
compilation.GenerationErrors.Should().BeEmpty();
value.Should().Be(1);
}

[Fact]
public void CanUseUnwrapMethodToUnsafelyGetGenericVariantValue()
{
// Arrange.
var programCs = """
using Dunet;

var value = GetValue();

static int GetValue()
{
var option = new Option<int>.Some(1);
return option.UnwrapSome().Value;
}

[Union]
public partial record Option<T>
{
public partial record Some(T Value);
public partial record None;
}
""";

// Act.
var compilation = Compiler.Compile(programCs);
var value = compilation.Assembly?.ExecuteStaticMethod<int>("GetValue");

// Assert.
using var scope = new AssertionScope();
compilation.CompilationErrors.Should().BeEmpty();
compilation.GenerationErrors.Should().BeEmpty();
value.Should().Be(1);
}

[Fact]
public void UnwrapMethodThrowsWhenCalledWithWrongUnderlyingValue()
{
// Arrange.
var programCs = """
using Dunet;

var value = GetValue();

static int GetValue()
{
var option = new Option.None();
return option.UnwrapSome().Value;
}

[Union]
public partial record Option
{
public partial record Some(int Value);
public partial record None;
}
""";

// Act.
var compilation = Compiler.Compile(programCs);
var action = () => compilation.Assembly?.ExecuteStaticMethod<int>("GetValue");

// Assert.
using var scope = new AssertionScope();
compilation.CompilationErrors.Should().BeEmpty();
compilation.GenerationErrors.Should().BeEmpty();
action
.Should()
.Throw<TargetInvocationException>()
.WithInnerExceptionExactly<InvalidOperationException>()
.WithMessage("Called `Option.UnwrapSome()` on `None` value.");
}
}

0 comments on commit 8eed569

Please sign in to comment.