Skip to content

Commit

Permalink
Fix conditional returns
Browse files Browse the repository at this point in the history
  • Loading branch information
GerardSmit committed Jan 27, 2024
1 parent 6e2ab0d commit 52d2675
Show file tree
Hide file tree
Showing 5 changed files with 86 additions and 8 deletions.
43 changes: 35 additions & 8 deletions src/Zomp.SyncMethodGenerator/AsyncToSyncRewriter.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory;
using System.Linq.Expressions;
using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory;

namespace Zomp.SyncMethodGenerator;

Expand Down Expand Up @@ -520,17 +521,15 @@ bool IsValidParameter(ParameterSyntax ps)
semanticModel.GetTypeInfo(returnExpression).Type is INamedTypeSymbol { Name: "Task" or "ValueTask", IsGenericType: false } returnType &&
returnType.ToString() is TaskType or ValueTaskType)
{
var result = !ShouldRemoveArgument(returnExpression)
? (ExpressionSyntax)Visit(returnExpression)
: null;
var result = ExpressionToStatement(returnExpression);

if (result is not null && node.Parent is not BlockSyntax)
{
// The parent is not a block, for example: if (true) return ReturnAsync();
// We need to create a block with the expression and the return statement.
return Block(List(new StatementSyntax[]
{
ExpressionStatement(result).WithLeadingTrivia(Space).WithTrailingTrivia(Space),
result.WithLeadingTrivia(Space).WithTrailingTrivia(Space),
ReturnStatement().WithTrailingTrivia(Space),
}))
.WithLeadingTrivia(node.GetLeadingTrivia())
Expand All @@ -546,7 +545,7 @@ bool IsValidParameter(ParameterSyntax ps)
return null;
}

return ExpressionStatement(result)
return result
.WithLeadingTrivia(node.GetLeadingTrivia())
.WithTrailingTrivia(node.GetTrailingTrivia());
}
Expand All @@ -561,7 +560,7 @@ bool IsValidParameter(ParameterSyntax ps)
// Create a block without the braces (eg. Return(); return;)
return Block(List(new StatementSyntax[]
{
ExpressionStatement(result).WithTrailingTrivia(Space),
result.WithTrailingTrivia(Space),
ReturnStatement().WithTrailingTrivia(node.GetTrailingTrivia()),
}))
.WithOpenBraceToken(MissingToken(SyntaxKind.OpenBraceToken))
Expand Down Expand Up @@ -1596,6 +1595,34 @@ private bool DropInvocation(InvocationExpressionSyntax invocation)
private bool ShouldRemoveArrowExpression(ArrowExpressionClauseSyntax? arrowNullable)
=> arrowNullable is { } arrow && ShouldRemoveArgument(arrow.Expression);

private StatementSyntax? ExpressionToStatement(ExpressionSyntax result)
{
// Conditional expression to if statement
if (result is ConditionalExpressionSyntax conditionalExpression)
{
var condition = conditionalExpression.Condition.WithoutTrailingTrivia();

IfStatementSyntax? syntax = (ExpressionToStatement(conditionalExpression.WhenTrue), ExpressionToStatement(conditionalExpression.WhenFalse)) switch
{
(null, null) => null,

Check warning on line 1607 in src/Zomp.SyncMethodGenerator/AsyncToSyncRewriter.cs

View check run for this annotation

Codecov / codecov/patch

src/Zomp.SyncMethodGenerator/AsyncToSyncRewriter.cs#L1607

Added line #L1607 was not covered by tests
(null, { } elseStatement) => IfStatement(PrefixUnaryExpression(SyntaxKind.LogicalNotExpression, condition), elseStatement),
({ } statement, null) => IfStatement(condition, statement),
({ } statement, { } elseStatement) => IfStatement(condition, statement, ElseClause(elseStatement).WithElseKeyword(Token(SyntaxKind.ElseKeyword).PrependSpace().AppendSpace())),
};

return syntax?
.WithIfKeyword(syntax.IfKeyword.AppendSpace())
.WithCloseParenToken(syntax.CloseParenToken.AppendSpace());
}

if (ShouldRemoveArgument(result))
{
return null;
}

return ExpressionStatement((ExpressionSyntax)Visit(result).WithoutTrivia());
}

private bool ShouldRemoveArgument(ExpressionSyntax expr) => expr switch
{
ElementAccessExpressionSyntax ee => ShouldRemoveArgument(ee.Expression),
Expand All @@ -1604,7 +1631,7 @@ private bool ShouldRemoveArrowExpression(ArrowExpressionClauseSyntax? arrowNulla
ParenthesizedExpressionSyntax pe => ShouldRemoveArgument(pe.Expression),
IdentifierNameSyntax id => !id.Identifier.ValueText.EndsWithAsync() && HasSymbolAndShouldBeRemoved(id),
InvocationExpressionSyntax ie => DropInvocation(ie),
ConditionalExpressionSyntax ce => ShouldRemoveArgument(ce.WhenTrue) || ShouldRemoveArgument(ce.WhenFalse),
ConditionalExpressionSyntax ce => ShouldRemoveArgument(ce.WhenTrue) && ShouldRemoveArgument(ce.WhenFalse),
MemberAccessExpressionSyntax mae => ShouldRemoveArgument(mae.Name),
PostfixUnaryExpressionSyntax pue => ShouldRemoveArgument(pue.Operand),
PrefixUnaryExpressionSyntax pue => ShouldRemoveArgument(pue.Operand),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
//HintName: Test.Class.ReturnAsync.g.cs
private void Return(bool input)
{
if (input) Return(); else Return();
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
//HintName: Test.Class.ReturnFalseAsync.g.cs
private void ReturnFalse(bool input)
{
if (!input) Return();
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
//HintName: Test.Class.ReturnTrueAsync.g.cs
private void ReturnTrue(bool input)
{
if (input) Return();
}
36 changes: 36 additions & 0 deletions tests/Generator.Tests/UnitTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,42 @@ private ValueTask ReturnAsync(bool input)
return ReturnAsync();
}

private ValueTask ReturnAsync() => default;
private void Return() { }
""".Verify();

[Fact]
public Task ReturnValueTaskConditional() => """
[CreateSyncVersion]
private ValueTask ReturnAsync(bool input)
{
return input ? ReturnAsync() : ReturnAsync();
}

private ValueTask ReturnAsync() => ReturnAsync();
private void Return() { }
""".Verify();

[Fact]
public Task ReturnValueTaskConditionalTrue() => """
[CreateSyncVersion]
private ValueTask ReturnTrueAsync(bool input)
{
return input ? ReturnAsync() : default;
}

private ValueTask ReturnAsync() => default;
private void Return() { }
""".Verify();

[Fact]
public Task ReturnValueTaskConditionalFalse() => """
[CreateSyncVersion]
private ValueTask ReturnFalseAsync(bool input)
{
return input ? default : ReturnAsync();
}

private ValueTask ReturnAsync() => default;
private void Return() { }
""".Verify();
Expand Down

0 comments on commit 52d2675

Please sign in to comment.