Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Check ref safety of arg mixing in interpolated strings #76263

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 66 additions & 0 deletions src/Compilers/CSharp/Portable/Binder/Binder.ValueChecks.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1940,6 +1940,12 @@ internal SafeContext GetInterpolatedStringHandlerConversionEscapeScope(
escapeScope = escapeScope.Intersect(argEscape);
}

if (!scopeOfTheContainingExpression.IsConvertibleTo(escapeScope) &&
!CheckInterpolatedStringHandlerInvocationArgMixing(expression, escapeFrom: scopeOfTheContainingExpression, escapeTo: escapeScope, BindingDiagnosticBag.Discarded))
{
escapeScope = scopeOfTheContainingExpression;
}

arguments.Free();
return escapeScope;
}
Expand Down Expand Up @@ -4474,6 +4480,11 @@ internal SafeContext GetValEscape(BoundExpression expr, SafeContext scopeOfTheCo
return scopeOfTheContainingExpression;

case BoundKind.InterpolatedStringHandlerPlaceholder:
if (_placeholderScopes?.TryPeek((BoundInterpolatedStringHandlerPlaceholder)expr, out var scope) == true)
{
return scope;
}

// The handler placeholder cannot escape out of the current expression, as it's a compiler-synthesized
// location.
return scopeOfTheContainingExpression;
Expand Down Expand Up @@ -4767,6 +4778,7 @@ internal bool CheckValEscape(SyntaxNode node, BoundExpression expr, SafeContext
case BoundKind.DeconstructValuePlaceholder:
case BoundKind.AwaitableValuePlaceholder:
case BoundKind.InterpolatedStringArgumentPlaceholder:
case BoundKind.InterpolatedStringHandlerPlaceholder:
if (!GetPlaceholderScope((BoundValuePlaceholderBase)expr).IsConvertibleTo(escapeTo))
{
Error(diagnostics, inUnsafeRegion ? ErrorCode.WRN_EscapeVariable : ErrorCode.ERR_EscapeVariable, node, expr.Syntax);
Expand Down Expand Up @@ -5610,6 +5622,11 @@ private bool CheckInterpolatedStringHandlerConversionEscape(BoundExpression expr
}
}

if (result)
{
result = CheckInterpolatedStringHandlerInvocationArgMixing(expression, escapeFrom, escapeTo, diagnostics);
}

arguments.Free();
return result;
}
Expand Down Expand Up @@ -5661,5 +5678,54 @@ void getParts(BoundInterpolatedString interpolatedString)
}
}
}

private bool CheckInterpolatedStringHandlerInvocationArgMixing(BoundExpression expression, SafeContext escapeFrom, SafeContext escapeTo, BindingDiagnosticBag diagnostics)
{
bool result = true;

while (true)
{
switch (expression)
{
case BoundBinaryOperator binary:
result &= CheckInterpolatedStringHandlerInvocationArgMixing(binary.Right, escapeFrom, escapeTo, diagnostics);
expression = binary.Left;
break;

case BoundInterpolatedString interpolatedString:
result &= CheckInterpolatedStringHandlerInvocationArgMixingParts(interpolatedString, escapeFrom, escapeTo, diagnostics);
return result;

default:
throw ExceptionUtilities.UnexpectedValue(expression.Kind);
}
}
}

private bool CheckInterpolatedStringHandlerInvocationArgMixingParts(BoundInterpolatedString interpolatedString, SafeContext escapeFrom, SafeContext escapeTo, BindingDiagnosticBag diagnostics)
{
bool result = true;

foreach (var part in interpolatedString.Parts)
{
if (part is BoundCall { Method.Name: BoundInterpolatedString.AppendFormattedMethod } call)
{
using var _ = new PlaceholderRegion(this, [((BoundInterpolatedStringHandlerPlaceholder)call.ReceiverOpt, escapeTo)], overwriteExistingTemporarily: true);
result &= CheckInvocationArgMixing(
call.Syntax,
MethodInfo.Create(call.Method),
receiverOpt: call.ReceiverOpt,
receiverIsSubjectToCloning: call.InitialBindingReceiverIsSubjectToCloning,
parameters: call.Method.Parameters,
argsOpt: call.Arguments,
argRefKindsOpt: call.ArgumentRefKindsOpt,
argsToParamsOpt: call.ArgsToParamsOpt,
scopeOfTheContainingExpression: escapeFrom,
diagnostics);
}
}

return result;
}
}
}
30 changes: 17 additions & 13 deletions src/Compilers/CSharp/Portable/Binder/RefSafetyAnalysis.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
using System.Collections.Immutable;
using System.Diagnostics;
using System.Linq;
using Microsoft.CodeAnalysis.Collections;
using Microsoft.CodeAnalysis.CSharp.Symbols;
using Microsoft.CodeAnalysis.PooledObjects;
using Roslyn.Utilities;
Expand Down Expand Up @@ -61,7 +62,7 @@ private static bool InUnsafeMethod(Symbol symbol)
private bool _inUnsafeRegion;
private SafeContext _localScopeDepth;
private Dictionary<LocalSymbol, (SafeContext RefEscapeScope, SafeContext ValEscapeScope)>? _localEscapeScopes;
private Dictionary<BoundValuePlaceholderBase, SafeContext>? _placeholderScopes;
private KeyedStack<BoundValuePlaceholderBase, SafeContext>? _placeholderScopes;
private SafeContext _patternInputValEscape;
#if DEBUG
private const int MaxTrackVisited = 100; // Avoid tracking if too many expressions.
Expand Down Expand Up @@ -152,22 +153,24 @@ private ref struct PlaceholderRegion
{
private readonly RefSafetyAnalysis _analysis;
private readonly ArrayBuilder<(BoundValuePlaceholderBase, SafeContext)> _placeholders;
private readonly bool _overwriteExistingTemporarily;

public PlaceholderRegion(RefSafetyAnalysis analysis, ArrayBuilder<(BoundValuePlaceholderBase, SafeContext)> placeholders)
public PlaceholderRegion(RefSafetyAnalysis analysis, ArrayBuilder<(BoundValuePlaceholderBase, SafeContext)> placeholders, bool overwriteExistingTemporarily = false)
{
_analysis = analysis;
_placeholders = placeholders;
_overwriteExistingTemporarily = overwriteExistingTemporarily;
foreach (var (placeholder, valEscapeScope) in placeholders)
{
_analysis.AddPlaceholderScope(placeholder, valEscapeScope);
_analysis.AddPlaceholderScope(placeholder, valEscapeScope, canExist: overwriteExistingTemporarily);
}
}

public void Dispose()
{
foreach (var (placeholder, _) in _placeholders)
{
_analysis.RemovePlaceholderScope(placeholder);
_analysis.RemovePlaceholderScope(placeholder, forcePop: _overwriteExistingTemporarily);
}
_placeholders.Free();
}
Expand All @@ -189,33 +192,34 @@ private void SetLocalScopes(LocalSymbol local, SafeContext refEscapeScope, SafeC
AddOrSetLocalScopes(local, refEscapeScope, valEscapeScope);
}

private void AddPlaceholderScope(BoundValuePlaceholderBase placeholder, SafeContext valEscapeScope)
private void AddPlaceholderScope(BoundValuePlaceholderBase placeholder, SafeContext valEscapeScope, bool canExist = false)
{
Debug.Assert(_placeholderScopes?.ContainsKey(placeholder) != true);
Debug.Assert(canExist || _placeholderScopes?.ContainsKey(placeholder) != true);

// Consider not adding the placeholder to the dictionary if the escape scope is
// CallingMethod, and simply fallback to that value in GetPlaceholderScope().

_placeholderScopes ??= new Dictionary<BoundValuePlaceholderBase, SafeContext>();
_placeholderScopes[placeholder] = valEscapeScope;
_placeholderScopes ??= new KeyedStack<BoundValuePlaceholderBase, SafeContext>();
_placeholderScopes.Push(placeholder, valEscapeScope);
}

#pragma warning disable IDE0060
private void RemovePlaceholderScope(BoundValuePlaceholderBase placeholder)
private void RemovePlaceholderScope(BoundValuePlaceholderBase placeholder, bool forcePop = false)
{
Debug.Assert(_placeholderScopes?.ContainsKey(placeholder) == true);

// https://github.com/dotnet/roslyn/issues/65961: Currently, analysis may require subsequent calls
// to GetRefEscape(), etc. for the same expression so we cannot remove placeholders eagerly.
//_placeholderScopes.Remove(placeholder);
if (forcePop)
{
_placeholderScopes?.TryPop(placeholder, out _);
}
}
#pragma warning restore IDE0060

private SafeContext GetPlaceholderScope(BoundValuePlaceholderBase placeholder)
{
Debug.Assert(_placeholderScopes?.ContainsKey(placeholder) == true);

return _placeholderScopes?.TryGetValue(placeholder, out var scope) == true
return _placeholderScopes?.TryPeek(placeholder, out var scope) == true
? scope
: SafeContext.CallingMethod;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15213,7 +15213,6 @@ static CustomHandler F3()
Diagnostic(ErrorCode.ERR_EscapeVariable, "h3").WithArguments("h3").WithLocation(29, 16));
}

[WorkItem(63306, "https://github.com/dotnet/roslyn/issues/63306")]
[Theory]
[InlineData(LanguageVersion.CSharp10)]
[InlineData(LanguageVersion.CSharp11)]
Expand Down Expand Up @@ -15251,7 +15250,6 @@ static CustomHandler F4()
}
}
";
// https://github.com/dotnet/roslyn/issues/63306: Should report an error in each case.
var comp = CreateCompilation(new[] { code, InterpolatedStringHandlerAttribute }, parseOptions: TestOptions.Regular.WithLanguageVersion(languageVersion), targetFramework: TargetFramework.Net50);
comp.VerifyDiagnostics();
}
Expand Down Expand Up @@ -15515,9 +15513,17 @@ static CustomHandler F2()
}
}
""";
// https://github.com/dotnet/roslyn/issues/63306: Should report an error that a reference to y will escape F1() and F2().
var comp = CreateCompilation(source, targetFramework: TargetFramework.Net70);
comp.VerifyDiagnostics();
comp.VerifyDiagnostics(
// (14,18): error CS8156: An expression cannot be used in this context because it may not be passed or returned by reference
// return $"{1}";
Diagnostic(ErrorCode.ERR_RefReturnLvalueExpected, "{1}").WithLocation(14, 18),
// (14,18): error CS8350: This combination of arguments to 'CustomHandler.AppendFormatted(int, in int)' is disallowed because it may expose variables referenced by parameter 'y' outside of their declaration scope
// return $"{1}";
Diagnostic(ErrorCode.ERR_CallArgMixing, "{1}").WithArguments("CustomHandler.AppendFormatted(int, in int)", "y").WithLocation(14, 18),
// (19,16): error CS8352: Cannot use variable 'h2' in this context because it may expose referenced variables outside of their declaration scope
// return h2;
Diagnostic(ErrorCode.ERR_EscapeVariable, "h2").WithArguments("h2").WithLocation(19, 16));
}

[WorkItem(67070, "https://github.com/dotnet/roslyn/issues/67070")]
Expand Down
117 changes: 117 additions & 0 deletions src/Compilers/CSharp/Test/Semantic/Semantics/RefEscapingTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10227,6 +10227,123 @@ public void Utf8Addition()
CreateCompilation(code, targetFramework: TargetFramework.Net70).VerifyDiagnostics();
}

[Fact, WorkItem("https://github.com/dotnet/roslyn/issues/63306")]
public void InterpolatedString_UnscopedRef_Return()
{
var source = """
using System.Diagnostics.CodeAnalysis;
using System.Runtime.CompilerServices;

class C
{
R M()
{
var local = 1;
return $"{local}";
}
}

[InterpolatedStringHandlerAttribute]
ref struct R
{
public R(int literalLength, int formattedCount) { }

public void AppendFormatted([UnscopedRef] in int x) { }
}
""";
CreateCompilation([source, UnscopedRefAttributeDefinition, InterpolatedStringHandlerAttribute]).VerifyDiagnostics(
// (9,18): error CS8350: This combination of arguments to 'R.AppendFormatted(in int)' is disallowed because it may expose variables referenced by parameter 'x' outside of their declaration scope
// return $"{local}";
Diagnostic(ErrorCode.ERR_CallArgMixing, "{local}").WithArguments("R.AppendFormatted(in int)", "x").WithLocation(9, 18),
// (9,19): error CS8168: Cannot return local 'local' by reference because it is not a ref local
// return $"{local}";
Diagnostic(ErrorCode.ERR_RefReturnLocal, "local").WithArguments("local").WithLocation(9, 19));
}

[Fact, WorkItem("https://github.com/dotnet/roslyn/issues/63306")]
public void InterpolatedString_UnscopedRef_Assignment()
{
var source = """
using System.Diagnostics.CodeAnalysis;
using System.Runtime.CompilerServices;

class C
{
R M()
{
var local = 1;
R r = $"{local}";
return r;
}
}

[InterpolatedStringHandlerAttribute]
ref struct R
{
public R(int literalLength, int formattedCount) { }

public void AppendFormatted([UnscopedRef] in int x) { }
}
""";
CreateCompilation([source, UnscopedRefAttributeDefinition, InterpolatedStringHandlerAttribute]).VerifyDiagnostics(
// (10,16): error CS8352: Cannot use variable 'r' in this context because it may expose referenced variables outside of their declaration scope
// return r;
Diagnostic(ErrorCode.ERR_EscapeVariable, "r").WithArguments("r").WithLocation(10, 16));
}

[Fact, WorkItem("https://github.com/dotnet/roslyn/issues/63306")]
public void InterpolatedString_ScopedRef_Return()
{
var source = """
using System.Runtime.CompilerServices;

class C
{
R M()
{
var local = 1;
return $"{local}";
}
}

[InterpolatedStringHandlerAttribute]
ref struct R
{
public R(int literalLength, int formattedCount) { }

public void AppendFormatted(in int x) { }
}
""";
CreateCompilation([source, InterpolatedStringHandlerAttribute]).VerifyDiagnostics();
}

[Fact, WorkItem("https://github.com/dotnet/roslyn/issues/63306")]
public void InterpolatedString_ScopedRef_Assignment()
{
var source = """
using System.Runtime.CompilerServices;

class C
{
R M()
{
var local = 1;
R r = $"{local}";
return r;
}
}

[InterpolatedStringHandlerAttribute]
ref struct R
{
public R(int literalLength, int formattedCount) { }

public void AppendFormatted(in int x) { }
}
""";
CreateCompilation([source, InterpolatedStringHandlerAttribute]).VerifyDiagnostics();
}

[Fact, WorkItem("https://github.com/dotnet/roslyn/issues/75592")]
public void SelfAssignment_ReturnOnly()
{
Expand Down
24 changes: 18 additions & 6 deletions src/Compilers/Core/Portable/Collections/KeyedStack.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,21 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using Microsoft.CodeAnalysis.Text;

namespace Microsoft.CodeAnalysis.Collections
{
internal class KeyedStack<T, R>
internal sealed class KeyedStack<T, R>
where T : notnull
{
private readonly Dictionary<T, Stack<R>> _dict = new Dictionary<T, Stack<R>>();

public bool ContainsKey(T key)
{
return _dict.ContainsKey(key);
}

public void Push(T key, R value)
{
Stack<R>? store;
Expand All @@ -29,6 +29,18 @@ public void Push(T key, R value)
store.Push(value);
}

public bool TryPeek(T key, [MaybeNullWhen(returnValue: false)] out R value)
{
if (_dict.TryGetValue(key, out var stack) && stack.Count > 0)
{
value = stack.Peek();
return true;
}

value = default;
return false;
}

public bool TryPop(T key, [MaybeNullWhen(returnValue: false)] out R value)
{
Stack<R>? store;
Expand Down