diff --git a/src/Zomp.SyncMethodGenerator/AsyncToSyncRewriter.cs b/src/Zomp.SyncMethodGenerator/AsyncToSyncRewriter.cs index e70c067..ba8d0be 100644 --- a/src/Zomp.SyncMethodGenerator/AsyncToSyncRewriter.cs +++ b/src/Zomp.SyncMethodGenerator/AsyncToSyncRewriter.cs @@ -27,6 +27,7 @@ internal sealed class AsyncToSyncRewriter(SemanticModel semanticModel) : CSharpS private const string IAsyncEnumerator = "System.Collections.Generic.IAsyncEnumerator"; private const string FromResult = "FromResult"; private const string Delay = "Delay"; + private const string CompletedTask = "CompletedTask"; private static readonly HashSet Drops = [IProgressInterface, CancellationTokenType]; private static readonly HashSet InterfacesToDrop = [IProgressInterface, IAsyncResultInterface]; private static readonly Dictionary Replacements = new() @@ -498,11 +499,31 @@ bool IsValidParameter(ParameterSyntax ps) return @base; } + /// + public override SyntaxNode? VisitImplicitObjectCreationExpression(ImplicitObjectCreationExpressionSyntax node) + { + var @base = base.VisitImplicitObjectCreationExpression(node); + var symbol = GetSymbol(node); + + if (TryReplaceObjectCreation(node, symbol, out var replacement)) + { + return replacement; + } + + return @base; + } + /// public override SyntaxNode? VisitObjectCreationExpression(ObjectCreationExpressionSyntax node) { var @base = (ObjectCreationExpressionSyntax)base.VisitObjectCreationExpression(node)!; var symbol = GetSymbol(node); + + if (TryReplaceObjectCreation(node, symbol, out var replacement)) + { + return replacement; + } + if (symbol is null or { ContainingType.IsGenericType: true } or INamedTypeSymbol { IsGenericType: true }) @@ -1158,6 +1179,7 @@ private static bool ShouldRemoveType(ITypeSymbol symbol) private static bool ShouldRemoveArgument(ISymbol symbol) => symbol switch { + IPropertySymbol { Name: CompletedTask } ps => ps.Type.ToString() is TaskType or ValueTaskType, IMethodSymbol ms => IsSpecialMethod(ms) == SpecialMethod.None && ((ShouldRemoveType(ms.ReturnType) && ms.MethodKind != MethodKind.LocalFunction) @@ -1201,6 +1223,20 @@ private static TypeSyntax GetReturnType(TypeSyntax returnType, INamedTypeSymbol private static string Global(string type) => $"global::{type}"; + private static bool TryReplaceObjectCreation(BaseObjectCreationExpressionSyntax node, ISymbol? symbol, out SyntaxNode? replacement) + { + if (symbol is IMethodSymbol { ReceiverType: INamedTypeSymbol { Name: "ValueTask", IsGenericType: true } type } + && GetNameWithoutTypeParams(type) is ValueTaskType + && node.ArgumentList is { Arguments: [var singleArg] }) + { + replacement = singleArg.Expression; + return true; + } + + replacement = default; + return false; + } + private static InvocationExpressionSyntax UnwrapExtension(InvocationExpressionSyntax ies, bool isMemory, IMethodSymbol reducedFrom, ExpressionSyntax expression) { var arguments = ies.ArgumentList.Arguments; @@ -1514,14 +1550,25 @@ private bool ShouldRemoveArrowExpression(ArrowExpressionClauseSyntax? arrowNulla MemberAccessExpressionSyntax mae => ShouldRemoveArgument(mae.Name), PostfixUnaryExpressionSyntax pue => ShouldRemoveArgument(pue.Operand), PrefixUnaryExpressionSyntax pue => ShouldRemoveArgument(pue.Operand), - ObjectCreationExpressionSyntax oe => ShouldRemoveArgument(oe.Type), + ObjectCreationExpressionSyntax oe => ShouldRemoveArgument(oe.Type) || ShouldRemoveObjectCreation(oe), + ImplicitObjectCreationExpressionSyntax oe => ShouldRemoveObjectCreation(oe), ConditionalAccessExpressionSyntax cae => ShouldRemoveArgument(cae.Expression), AwaitExpressionSyntax ae => ShouldRemoveArgument(ae.Expression), AssignmentExpressionSyntax ae => ShouldRemoveArgument(ae.Right), GenericNameSyntax gn => HasSymbolAndShouldBeRemoved(gn), + LiteralExpressionSyntax le => ShouldRemoveLiteral(le), _ => false, }; + private bool ShouldRemoveLiteral(LiteralExpressionSyntax literalExpression) + => literalExpression.Token.IsKind(SyntaxKind.DefaultKeyword) + && semanticModel.GetTypeInfo(literalExpression).Type is INamedTypeSymbol { Name: "ValueTask", IsGenericType: false } t + && t.ToString() == ValueTaskType; + + private bool ShouldRemoveObjectCreation(BaseObjectCreationExpressionSyntax oe) + => GetSymbol(oe) is IMethodSymbol { ReceiverType: INamedTypeSymbol { Name: "ValueTask", IsGenericType: false } type } + && type.ToString() is ValueTaskType; + /// /// Keeps track of nested directives. /// diff --git a/tests/Generator.Tests/Snapshots/UnitTests.DropUnawaitedCompletedValueTask#DoSomethingAsync.g.verified.cs b/tests/Generator.Tests/Snapshots/UnitTests.DropUnawaitedCompletedValueTask#DoSomethingAsync.g.verified.cs new file mode 100644 index 0000000..e6230c5 --- /dev/null +++ b/tests/Generator.Tests/Snapshots/UnitTests.DropUnawaitedCompletedValueTask#DoSomethingAsync.g.verified.cs @@ -0,0 +1,2 @@ +//HintName: Test.Class.DoSomethingAsync.g.cs +public static void DoSomething() { } diff --git a/tests/Generator.Tests/Snapshots/UnitTests.KeepDefaultValueTaskWithResult#ReturnDefault.g.verified.cs b/tests/Generator.Tests/Snapshots/UnitTests.KeepDefaultValueTaskWithResult#ReturnDefault.g.verified.cs new file mode 100644 index 0000000..78aabf8 --- /dev/null +++ b/tests/Generator.Tests/Snapshots/UnitTests.KeepDefaultValueTaskWithResult#ReturnDefault.g.verified.cs @@ -0,0 +1,2 @@ +//HintName: Test.Class.ReturnDefault.g.cs +public static int ReturnDefault() => default; diff --git a/tests/Generator.Tests/Snapshots/UnitTests.ReturnValueTaskInstance#ReturnAsync.g.verified.cs b/tests/Generator.Tests/Snapshots/UnitTests.ReturnValueTaskInstance#ReturnAsync.g.verified.cs new file mode 100644 index 0000000..24f7972 --- /dev/null +++ b/tests/Generator.Tests/Snapshots/UnitTests.ReturnValueTaskInstance#ReturnAsync.g.verified.cs @@ -0,0 +1,2 @@ +//HintName: Test.Class.ReturnAsync.g.cs +public static int Return() { return 1; } diff --git a/tests/Generator.Tests/UnitTests.cs b/tests/Generator.Tests/UnitTests.cs index 060f663..07cb056 100644 --- a/tests/Generator.Tests/UnitTests.cs +++ b/tests/Generator.Tests/UnitTests.cs @@ -106,6 +106,37 @@ public async Task ExecAsync(CancellationToken ct) public Task DropUnawaitedCompletedTask(string statement) => $""" [Zomp.SyncMethodGenerator.CreateSyncVersion] public static Task DoSomethingAsync() {statement} +""".Verify(disableUnique: true); + + [Theory] + [InlineData("{ return default; }")] + [InlineData("{ return new(); }")] + [InlineData("{ return new ValueTask(); }")] +#if NETCOREAPP1_0_OR_GREATER + [InlineData("{ return ValueTask.CompletedTask; }")] + [InlineData("{ return ValueTask.CompletedTask; Console.WriteLine(\"123\"); }")] + [InlineData("=> ValueTask.CompletedTask;")] +#endif + public Task DropUnawaitedCompletedValueTask(string statement) => $""" +[Zomp.SyncMethodGenerator.CreateSyncVersion] +public static ValueTask DoSomethingAsync() {statement} +""".Verify(disableUnique: true); + + [Fact] + public Task KeepDefaultValueTaskWithResult() => $""" +[Zomp.SyncMethodGenerator.CreateSyncVersion] +public static ValueTask ReturnDefault() => default; +""".Verify(); + + [Theory] + [InlineData("{ return new(1); }")] + [InlineData("{ return new ValueTask(1); }")] +#if NETCOREAPP1_0_OR_GREATER + [InlineData("{ return ValueTask.FromResult(1); }")] +#endif + public Task ReturnValueTaskInstance(string statement) => $""" +[Zomp.SyncMethodGenerator.CreateSyncVersion] +public static ValueTask ReturnAsync() {statement} """.Verify(disableUnique: true); [Fact]