Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix dropping default ValueTask #49

Merged
merged 10 commits into from
Jan 26, 2024
49 changes: 48 additions & 1 deletion src/Zomp.SyncMethodGenerator/AsyncToSyncRewriter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ internal sealed class AsyncToSyncRewriter(SemanticModel semanticModel) : CSharpS
private const string IAsyncEnumerator = "System.Collections.Generic.IAsyncEnumerator";
private const string FromResult = "FromResult";
private const string Delay = "Delay";
private const string CompletedTask = "CompletedTask";
private static readonly HashSet<string> Drops = [IProgressInterface, CancellationTokenType];
private static readonly HashSet<string> InterfacesToDrop = [IProgressInterface, IAsyncResultInterface];
private static readonly Dictionary<string, string?> Replacements = new()
Expand Down Expand Up @@ -498,11 +499,31 @@ bool IsValidParameter(ParameterSyntax ps)
return @base;
}

/// <inheritdoc/>
public override SyntaxNode? VisitImplicitObjectCreationExpression(ImplicitObjectCreationExpressionSyntax node)
{
var @base = base.VisitImplicitObjectCreationExpression(node);
var symbol = GetSymbol(node);

if (TryReplaceObjectCreation(node, symbol, out var replacement))
{
return replacement;
}

return @base;
}

/// <inheritdoc/>
public override SyntaxNode? VisitObjectCreationExpression(ObjectCreationExpressionSyntax node)
{
var @base = (ObjectCreationExpressionSyntax)base.VisitObjectCreationExpression(node)!;
var symbol = GetSymbol(node);

if (TryReplaceObjectCreation(node, symbol, out var replacement))
{
return replacement;
}

if (symbol is null
or { ContainingType.IsGenericType: true }
or INamedTypeSymbol { IsGenericType: true })
Expand Down Expand Up @@ -1158,6 +1179,7 @@ private static bool ShouldRemoveType(ITypeSymbol symbol)

private static bool ShouldRemoveArgument(ISymbol symbol) => symbol switch
{
IPropertySymbol { Name: CompletedTask } ps => ps.Type.ToString() is TaskType or ValueTaskType,
IMethodSymbol ms =>
IsSpecialMethod(ms) == SpecialMethod.None
&& ((ShouldRemoveType(ms.ReturnType) && ms.MethodKind != MethodKind.LocalFunction)
Expand Down Expand Up @@ -1201,6 +1223,20 @@ private static TypeSyntax GetReturnType(TypeSyntax returnType, INamedTypeSymbol

private static string Global(string type) => $"global::{type}";

private static bool TryReplaceObjectCreation(BaseObjectCreationExpressionSyntax node, ISymbol? symbol, out SyntaxNode? replacement)
{
if (symbol is IMethodSymbol { ReceiverType: INamedTypeSymbol { Name: "ValueTask", IsGenericType: true } type }
&& GetNameWithoutTypeParams(type) is ValueTaskType
&& node.ArgumentList is { Arguments: [var singleArg] })
{
replacement = singleArg.Expression;
return true;
}

replacement = default;
return false;
}

private static InvocationExpressionSyntax UnwrapExtension(InvocationExpressionSyntax ies, bool isMemory, IMethodSymbol reducedFrom, ExpressionSyntax expression)
{
var arguments = ies.ArgumentList.Arguments;
Expand Down Expand Up @@ -1514,14 +1550,25 @@ private bool ShouldRemoveArrowExpression(ArrowExpressionClauseSyntax? arrowNulla
MemberAccessExpressionSyntax mae => ShouldRemoveArgument(mae.Name),
PostfixUnaryExpressionSyntax pue => ShouldRemoveArgument(pue.Operand),
PrefixUnaryExpressionSyntax pue => ShouldRemoveArgument(pue.Operand),
ObjectCreationExpressionSyntax oe => ShouldRemoveArgument(oe.Type),
ObjectCreationExpressionSyntax oe => ShouldRemoveArgument(oe.Type) || ShouldRemoveObjectCreation(oe),
ImplicitObjectCreationExpressionSyntax oe => ShouldRemoveObjectCreation(oe),
ConditionalAccessExpressionSyntax cae => ShouldRemoveArgument(cae.Expression),
AwaitExpressionSyntax ae => ShouldRemoveArgument(ae.Expression),
AssignmentExpressionSyntax ae => ShouldRemoveArgument(ae.Right),
GenericNameSyntax gn => HasSymbolAndShouldBeRemoved(gn),
LiteralExpressionSyntax le => ShouldRemoveLiteral(le),
_ => false,
};

private bool ShouldRemoveLiteral(LiteralExpressionSyntax literalExpression)
=> literalExpression.Token.IsKind(SyntaxKind.DefaultKeyword)
&& semanticModel.GetTypeInfo(literalExpression).Type is INamedTypeSymbol { Name: "ValueTask", IsGenericType: false } t
&& t.ToString() == ValueTaskType;

private bool ShouldRemoveObjectCreation(BaseObjectCreationExpressionSyntax oe)
=> GetSymbol(oe) is IMethodSymbol { ReceiverType: INamedTypeSymbol { Name: "ValueTask", IsGenericType: false } type }
&& type.ToString() is ValueTaskType;

/// <summary>
/// Keeps track of nested directives.
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
//HintName: Test.Class.DoSomethingAsync.g.cs
public static void DoSomething() { }
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
//HintName: Test.Class.ReturnDefault.g.cs
public static int ReturnDefault() => default;
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
//HintName: Test.Class.ReturnAsync.g.cs
public static int Return() { return 1; }
31 changes: 31 additions & 0 deletions tests/Generator.Tests/UnitTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,37 @@ public async Task ExecAsync(CancellationToken ct)
public Task DropUnawaitedCompletedTask(string statement) => $"""
[Zomp.SyncMethodGenerator.CreateSyncVersion]
public static Task DoSomethingAsync() {statement}
""".Verify(disableUnique: true);

[Theory]
[InlineData("{ return default; }")]
[InlineData("{ return new(); }")]
[InlineData("{ return new ValueTask(); }")]
#if NETCOREAPP1_0_OR_GREATER
[InlineData("{ return ValueTask.CompletedTask; }")]
[InlineData("{ return ValueTask.CompletedTask; Console.WriteLine(\"123\"); }")]
[InlineData("=> ValueTask.CompletedTask;")]
#endif
public Task DropUnawaitedCompletedValueTask(string statement) => $"""
[Zomp.SyncMethodGenerator.CreateSyncVersion]
public static ValueTask DoSomethingAsync() {statement}
""".Verify(disableUnique: true);

[Fact]
public Task KeepDefaultValueTaskWithResult() => $"""
[Zomp.SyncMethodGenerator.CreateSyncVersion]
public static ValueTask<int> ReturnDefault() => default;
""".Verify();

[Theory]
[InlineData("{ return new(1); }")]
[InlineData("{ return new ValueTask<int>(1); }")]
#if NETCOREAPP1_0_OR_GREATER
[InlineData("{ return ValueTask.FromResult(1); }")]
#endif
public Task ReturnValueTaskInstance(string statement) => $"""
[Zomp.SyncMethodGenerator.CreateSyncVersion]
public static ValueTask<int> ReturnAsync() {statement}
""".Verify(disableUnique: true);

[Fact]
Expand Down
Loading