From 6e2ab0d853d915a6cd083d1ff702cf2df70ee7f6 Mon Sep 17 00:00:00 2001 From: Gerard Smit Date: Fri, 26 Jan 2024 20:37:58 +0100 Subject: [PATCH] Fix return statement in child blocks --- .../AsyncToSyncRewriter.cs | 22 +++++++++++++++---- ...DefaultValueTask#ReturnAsync.g.verified.cs | 9 ++++++++ tests/Generator.Tests/UnitTests.cs | 14 ++++++++++++ 3 files changed, 41 insertions(+), 4 deletions(-) create mode 100644 tests/Generator.Tests/Snapshots/UnitTests.ReturnDefaultValueTask#ReturnAsync.g.verified.cs diff --git a/src/Zomp.SyncMethodGenerator/AsyncToSyncRewriter.cs b/src/Zomp.SyncMethodGenerator/AsyncToSyncRewriter.cs index cf79f0a..22fb4a5 100644 --- a/src/Zomp.SyncMethodGenerator/AsyncToSyncRewriter.cs +++ b/src/Zomp.SyncMethodGenerator/AsyncToSyncRewriter.cs @@ -520,9 +520,11 @@ bool IsValidParameter(ParameterSyntax ps) semanticModel.GetTypeInfo(returnExpression).Type is INamedTypeSymbol { Name: "Task" or "ValueTask", IsGenericType: false } returnType && returnType.ToString() is TaskType or ValueTaskType) { - var result = (ExpressionSyntax)Visit(returnExpression); + var result = !ShouldRemoveArgument(returnExpression) + ? (ExpressionSyntax)Visit(returnExpression) + : null; - if (node.Parent is not BlockSyntax) + if (result is not null && node.Parent is not BlockSyntax) { // The parent is not a block, for example: if (true) return ReturnAsync(); // We need to create a block with the expression and the return statement. @@ -536,14 +538,26 @@ bool IsValidParameter(ParameterSyntax ps) } // Don't return if the return statement is the last statement in the method. - if (node.Parent.Parent is MethodDeclarationSyntax { Body.Statements: [.., var lastStatement] } && + if (node.Parent?.Parent is MethodDeclarationSyntax { Body.Statements: [.., var lastStatement] } && lastStatement == node) { + if (result is null) + { + return null; + } + return ExpressionStatement(result) .WithLeadingTrivia(node.GetLeadingTrivia()) .WithTrailingTrivia(node.GetTrailingTrivia()); } + if (result is null) + { + return ReturnStatement() + .WithTrailingTrivia(node.GetTrailingTrivia()) + .WithLeadingTrivia(node.GetLeadingTrivia()); + } + // Create a block without the braces (eg. Return(); return;) return Block(List(new StatementSyntax[] { @@ -1550,7 +1564,7 @@ private TypeSyntax ProcessSyntaxUsingSymbol(TypeSyntax typeSyntax) IfStatementSyntax @if => ShouldRemoveArgument(@if.Condition), ExpressionStatementSyntax e => ShouldRemoveArgument(e.Expression), LocalDeclarationStatementSyntax l => CanDropDeclaration(l), - ReturnStatementSyntax { Expression: { } re } => ShouldRemoveArgument(re), + ReturnStatementSyntax { Parent.Parent: MethodDeclarationSyntax, Expression: { } re } => ShouldRemoveArgument(re), _ => false, }; diff --git a/tests/Generator.Tests/Snapshots/UnitTests.ReturnDefaultValueTask#ReturnAsync.g.verified.cs b/tests/Generator.Tests/Snapshots/UnitTests.ReturnDefaultValueTask#ReturnAsync.g.verified.cs new file mode 100644 index 0000000..4d4d80e --- /dev/null +++ b/tests/Generator.Tests/Snapshots/UnitTests.ReturnDefaultValueTask#ReturnAsync.g.verified.cs @@ -0,0 +1,9 @@ +//HintName: Test.Class.ReturnAsync.g.cs +private void Return(bool input) +{ + if (input) + { + return; + } + global::System.Console.WriteLine("123"); +} diff --git a/tests/Generator.Tests/UnitTests.cs b/tests/Generator.Tests/UnitTests.cs index 31e7bae..b8ce225 100644 --- a/tests/Generator.Tests/UnitTests.cs +++ b/tests/Generator.Tests/UnitTests.cs @@ -195,6 +195,20 @@ private Task ReturnAsync(bool input) private Task ReturnAsync() => Task.CompletedTask; private void Return() { } +""".Verify(); + + [Fact] + public Task ReturnDefaultValueTask() => """ +[CreateSyncVersion] +private ValueTask ReturnAsync(bool input) +{ + if (input) + { + return default; + } + Console.WriteLine("123"); + return default; +} """.Verify(); [Fact]