Skip to content

Commit d12d883

Browse files
committed
Use the scoped keyword to significantly simplify the caller-allocated-buffer logic for stateful marshallers while also increasing safety/keeping us from shooting ourselves in the foot with lifetimes.
PR feedback Use the official ScopedKeyword API
1 parent 69f58ba commit d12d883

File tree

5 files changed

+53
-44
lines changed

5 files changed

+53
-44
lines changed

src/libraries/System.Formats.Tar/src/System/Formats/Tar/TarHeader.Read.cs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,10 @@ internal sealed partial class TarHeader
2828
archiveStream.ReadExactly(buffer);
2929

3030
TarHeader? header = TryReadAttributes(initialFormat, buffer);
31-
header?.ProcessDataBlock(archiveStream, copyData);
31+
if (header != null)
32+
{
33+
header.ProcessDataBlock(archiveStream, copyData);
34+
}
3235

3336
return header;
3437
}

src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/ManagedTypeInfo.cs

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,17 @@ namespace Microsoft.Interop
1515
/// </summary>
1616
public abstract record ManagedTypeInfo(string FullTypeName, string DiagnosticFormattedName)
1717
{
18-
public TypeSyntax Syntax { get; } = SyntaxFactory.ParseTypeName(FullTypeName);
18+
private TypeSyntax? _syntax;
19+
public TypeSyntax Syntax => _syntax ??= SyntaxFactory.ParseTypeName(FullTypeName);
20+
21+
protected ManagedTypeInfo(ManagedTypeInfo original)
22+
{
23+
FullTypeName = original.FullTypeName;
24+
DiagnosticFormattedName = original.DiagnosticFormattedName;
25+
// Explicitly don't initialize _syntax here. We want Syntax to be recalculated
26+
// from the results of a with-expression, which assigns the new property values
27+
// to the result of this constructor.
28+
}
1929

2030
public static ManagedTypeInfo CreateTypeInfoForTypeSymbol(ITypeSymbol type)
2131
{

src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/AttributedMarshallingModelGeneratorFactory.cs

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,7 @@ private IMarshallingGenerator CreateCustomNativeTypeMarshaller(TypePositionInfo
235235
ICustomTypeMarshallingStrategy marshallingStrategy;
236236
if (marshallerData.HasState)
237237
{
238-
marshallingStrategy = new StatefulValueMarshalling(marshallerData.MarshallerType.Syntax, marshallerData.NativeType.Syntax, marshallerData.Shape);
238+
marshallingStrategy = new StatefulValueMarshalling(marshallerData.MarshallerType, marshallerData.NativeType.Syntax, marshallerData.Shape);
239239
if (marshallerData.Shape.HasFlag(MarshallerShape.CallerAllocatedBuffer))
240240
marshallingStrategy = new StatefulCallerAllocatedBufferMarshalling(marshallingStrategy, marshallerData.MarshallerType.Syntax, marshallerData.BufferElementType.Syntax);
241241
}
@@ -283,16 +283,21 @@ private IMarshallingGenerator CreateNativeCollectionMarshaller(
283283

284284
// Insert the unmanaged element type into the marshaller type
285285
TypeSyntax unmanagedElementType = elementMarshaller.AsNativeType(elementInfo).GetCompatibleGenericTypeParameterSyntax();
286-
TypeSyntax marshallerTypeSyntax = marshallerData.MarshallerType.Syntax;
287-
marshallerTypeSyntax = ReplacePlaceholderSyntaxWithUnmanagedTypeSyntax(marshallerTypeSyntax, marshalInfo, unmanagedElementType);
286+
ManagedTypeInfo marshallerType = marshallerData.MarshallerType;
287+
TypeSyntax marshallerTypeSyntax = ReplacePlaceholderSyntaxWithUnmanagedTypeSyntax(marshallerType.Syntax, marshalInfo, unmanagedElementType);
288+
marshallerType = marshallerType with
289+
{
290+
FullTypeName = marshallerTypeSyntax.ToString(),
291+
DiagnosticFormattedName = marshallerTypeSyntax.ToString(),
292+
};
288293
TypeSyntax nativeTypeSyntax = ReplacePlaceholderSyntaxWithUnmanagedTypeSyntax(marshallerData.NativeType.Syntax, marshalInfo, unmanagedElementType);
289294

290295
ICustomTypeMarshallingStrategy marshallingStrategy;
291296
bool elementIsBlittable = elementMarshaller is BlittableMarshaller;
292297

293298
if (marshallerData.HasState)
294299
{
295-
marshallingStrategy = new StatefulValueMarshalling(marshallerTypeSyntax, nativeTypeSyntax, marshallerData.Shape);
300+
marshallingStrategy = new StatefulValueMarshalling(marshallerType, nativeTypeSyntax, marshallerData.Shape);
296301
if (marshallerData.Shape.HasFlag(MarshallerShape.CallerAllocatedBuffer))
297302
{
298303
// Check if the buffer element type is actually the unmanaged element type

src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/StatefulMarshallingStrategy.cs

Lines changed: 25 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,13 @@ namespace Microsoft.Interop
1414
internal sealed class StatefulValueMarshalling : ICustomTypeMarshallingStrategy
1515
{
1616
internal const string MarshallerIdentifier = "marshaller";
17-
private readonly TypeSyntax _marshallerTypeSyntax;
17+
private readonly ManagedTypeInfo _marshallerType;
1818
private readonly TypeSyntax _nativeTypeSyntax;
1919
private readonly MarshallerShape _shape;
2020

21-
public StatefulValueMarshalling(TypeSyntax marshallerTypeSyntax, TypeSyntax nativeTypeSyntax, MarshallerShape shape)
21+
public StatefulValueMarshalling(ManagedTypeInfo marshallerType, TypeSyntax nativeTypeSyntax, MarshallerShape shape)
2222
{
23-
_marshallerTypeSyntax = marshallerTypeSyntax;
23+
_marshallerType = marshallerType;
2424
_nativeTypeSyntax = nativeTypeSyntax;
2525
_shape = shape;
2626
}
@@ -140,10 +140,23 @@ public IEnumerable<StatementSyntax> GenerateUnmarshalCaptureStatements(TypePosit
140140
public IEnumerable<StatementSyntax> GenerateSetupStatements(TypePositionInfo info, StubCodeContext context)
141141
{
142142
// <marshaller> = new();
143-
yield return MarshallerHelpers.Declare(
144-
_marshallerTypeSyntax,
143+
LocalDeclarationStatementSyntax declaration = MarshallerHelpers.Declare(
144+
_marshallerType.Syntax,
145145
context.GetAdditionalIdentifier(info, MarshallerIdentifier),
146146
ImplicitObjectCreationExpression(ArgumentList(), initializer: null));
147+
148+
// For byref-like marshaller types, we'll mark them as scoped.
149+
// Byref-like types can capture references, so by default the compiler has to worry that
150+
// they could enable those references to escape the current stack frame.
151+
// In particular, this can interact poorly with the caller-allocated-buffer marshalling
152+
// support and make the simple `marshaller.FromManaged(managed, stackalloc X[i])` expression
153+
// illegal. Mark the marshaller type as scoped so the compiler knows that it won't escape.
154+
if (_marshallerType is ValueTypeInfo { IsByRefLike: true })
155+
{
156+
declaration = declaration.AddModifiers(Token(SyntaxKind.ScopedKeyword));
157+
}
158+
159+
yield return declaration;
147160
}
148161

149162
public IEnumerable<StatementSyntax> GeneratePinStatements(TypePositionInfo info, StubCodeContext context)
@@ -218,28 +231,9 @@ public IEnumerable<StatementSyntax> GenerateMarshalStatements(TypePositionInfo i
218231

219232
IEnumerable<StatementSyntax> GenerateCallerAllocatedBufferMarshalStatements()
220233
{
221-
// TODO: Update once we can consume the scoped keword. We should be able to simplify this once we get that API.
222-
string stackPtrIdentifier = context.GetAdditionalIdentifier(info, "stackptr");
223-
// <bufferElementType>* <managedIdentifier>__stackptr = stackalloc <bufferElementType>[<_bufferSize>];
224-
yield return LocalDeclarationStatement(
225-
VariableDeclaration(
226-
PointerType(_bufferElementType),
227-
SingletonSeparatedList(
228-
VariableDeclarator(stackPtrIdentifier)
229-
.WithInitializer(EqualsValueClause(
230-
StackAllocArrayCreationExpression(
231-
ArrayType(
232-
_bufferElementType,
233-
SingletonList(ArrayRankSpecifier(SingletonSeparatedList<ExpressionSyntax>(
234-
MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression,
235-
_marshallerType,
236-
IdentifierName(ShapeMemberNames.BufferSize))
237-
))))))))));
238-
239-
240234
(string managedIdentifier, _) = context.GetIdentifiers(info);
241235

242-
// <marshaller>.FromManaged(<managedIdentifier>, new Span<bufferElementType>(<stackPtrIdentifier>, <marshallerType>.BufferSize));
236+
// <marshaller>.FromManaged(<managedIdentifier>, stackalloc <bufferElementType>[<marshallerType>.BufferSize]);
243237
yield return ExpressionStatement(
244238
InvocationExpression(
245239
MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression,
@@ -249,19 +243,13 @@ IEnumerable<StatementSyntax> GenerateCallerAllocatedBufferMarshalStatements()
249243
new[]
250244
{
251245
Argument(IdentifierName(managedIdentifier)),
252-
Argument(
253-
ObjectCreationExpression(
254-
GenericName(Identifier(TypeNames.System_Span),
255-
TypeArgumentList(SingletonSeparatedList(
256-
_bufferElementType))))
257-
.WithArgumentList(
258-
ArgumentList(SeparatedList(new ArgumentSyntax[]
259-
{
260-
Argument(IdentifierName(stackPtrIdentifier)),
261-
Argument(MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression,
246+
Argument(StackAllocArrayCreationExpression(
247+
ArrayType(
248+
_bufferElementType,
249+
SingletonList(ArrayRankSpecifier(SingletonSeparatedList<ExpressionSyntax>(
250+
MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression,
262251
_marshallerType,
263-
IdentifierName(ShapeMemberNames.BufferSize)))
264-
}))))
252+
IdentifierName(ShapeMemberNames.BufferSize))))))))
265253
}))));
266254
}
267255
}

src/libraries/System.Text.RegularExpressions/gen/UpgradeToGeneratedRegexCodeFixer.cs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,10 @@ private static async Task<Document> ConvertToSourceGenerator(Document document,
100100
// Get the parent type declaration so that we can inspect its methods as well as check if we need to add the partial keyword.
101101
SyntaxNode? typeDeclarationOrCompilationUnit = nodeToFix.Ancestors().OfType<TypeDeclarationSyntax>().FirstOrDefault();
102102

103-
typeDeclarationOrCompilationUnit ??= await nodeToFix.SyntaxTree.GetRootAsync(cancellationToken).ConfigureAwait(false);
103+
if (typeDeclarationOrCompilationUnit is null)
104+
{
105+
typeDeclarationOrCompilationUnit = await nodeToFix.SyntaxTree.GetRootAsync(cancellationToken).ConfigureAwait(false);
106+
}
104107

105108
// Calculate what name should be used for the generated static partial method
106109
string methodName = DefaultRegexMethodName;

0 commit comments

Comments
 (0)