@@ -14,13 +14,13 @@ namespace Microsoft.Interop
14
14
internal sealed class StatefulValueMarshalling : ICustomTypeMarshallingStrategy
15
15
{
16
16
internal const string MarshallerIdentifier = "marshaller" ;
17
- private readonly TypeSyntax _marshallerTypeSyntax ;
17
+ private readonly ManagedTypeInfo _marshallerType ;
18
18
private readonly TypeSyntax _nativeTypeSyntax ;
19
19
private readonly MarshallerShape _shape ;
20
20
21
- public StatefulValueMarshalling ( TypeSyntax marshallerTypeSyntax , TypeSyntax nativeTypeSyntax , MarshallerShape shape )
21
+ public StatefulValueMarshalling ( ManagedTypeInfo marshallerType , TypeSyntax nativeTypeSyntax , MarshallerShape shape )
22
22
{
23
- _marshallerTypeSyntax = marshallerTypeSyntax ;
23
+ _marshallerType = marshallerType ;
24
24
_nativeTypeSyntax = nativeTypeSyntax ;
25
25
_shape = shape ;
26
26
}
@@ -140,10 +140,23 @@ public IEnumerable<StatementSyntax> GenerateUnmarshalCaptureStatements(TypePosit
140
140
public IEnumerable < StatementSyntax > GenerateSetupStatements ( TypePositionInfo info , StubCodeContext context )
141
141
{
142
142
// <marshaller> = new();
143
- yield return MarshallerHelpers . Declare (
144
- _marshallerTypeSyntax ,
143
+ LocalDeclarationStatementSyntax declaration = MarshallerHelpers . Declare (
144
+ _marshallerType . Syntax ,
145
145
context . GetAdditionalIdentifier ( info , MarshallerIdentifier ) ,
146
146
ImplicitObjectCreationExpression ( ArgumentList ( ) , initializer : null ) ) ;
147
+
148
+ // For byref-like marshaller types, we'll mark them as scoped.
149
+ // Byref-like types can capture references, so by default the compiler has to worry that
150
+ // they could enable those references to escape the current stack frame.
151
+ // In particular, this can interact poorly with the caller-allocated-buffer marshalling
152
+ // support and make the simple `marshaller.FromManaged(managed, stackalloc X[i])` expression
153
+ // illegal. Mark the marshaller type as scoped so the compiler knows that it won't escape.
154
+ if ( _marshallerType is ValueTypeInfo { IsByRefLike : true } )
155
+ {
156
+ declaration = declaration . AddModifiers ( Token ( SyntaxKind . ScopedKeyword ) ) ;
157
+ }
158
+
159
+ yield return declaration ;
147
160
}
148
161
149
162
public IEnumerable < StatementSyntax > GeneratePinStatements ( TypePositionInfo info , StubCodeContext context )
@@ -218,28 +231,9 @@ public IEnumerable<StatementSyntax> GenerateMarshalStatements(TypePositionInfo i
218
231
219
232
IEnumerable < StatementSyntax > GenerateCallerAllocatedBufferMarshalStatements ( )
220
233
{
221
- // TODO: Update once we can consume the scoped keword. We should be able to simplify this once we get that API.
222
- string stackPtrIdentifier = context . GetAdditionalIdentifier ( info , "stackptr" ) ;
223
- // <bufferElementType>* <managedIdentifier>__stackptr = stackalloc <bufferElementType>[<_bufferSize>];
224
- yield return LocalDeclarationStatement (
225
- VariableDeclaration (
226
- PointerType ( _bufferElementType ) ,
227
- SingletonSeparatedList (
228
- VariableDeclarator ( stackPtrIdentifier )
229
- . WithInitializer ( EqualsValueClause (
230
- StackAllocArrayCreationExpression (
231
- ArrayType (
232
- _bufferElementType ,
233
- SingletonList ( ArrayRankSpecifier ( SingletonSeparatedList < ExpressionSyntax > (
234
- MemberAccessExpression ( SyntaxKind . SimpleMemberAccessExpression ,
235
- _marshallerType ,
236
- IdentifierName ( ShapeMemberNames . BufferSize ) )
237
- ) ) ) ) ) ) ) ) ) ) ;
238
-
239
-
240
234
( string managedIdentifier , _ ) = context . GetIdentifiers ( info ) ;
241
235
242
- // <marshaller>.FromManaged(<managedIdentifier>, new Span <bufferElementType>(<stackPtrIdentifier>, < marshallerType>.BufferSize) );
236
+ // <marshaller>.FromManaged(<managedIdentifier>, stackalloc <bufferElementType>[< marshallerType>.BufferSize] );
243
237
yield return ExpressionStatement (
244
238
InvocationExpression (
245
239
MemberAccessExpression ( SyntaxKind . SimpleMemberAccessExpression ,
@@ -249,19 +243,13 @@ IEnumerable<StatementSyntax> GenerateCallerAllocatedBufferMarshalStatements()
249
243
new [ ]
250
244
{
251
245
Argument ( IdentifierName ( managedIdentifier ) ) ,
252
- Argument (
253
- ObjectCreationExpression (
254
- GenericName ( Identifier ( TypeNames . System_Span ) ,
255
- TypeArgumentList ( SingletonSeparatedList (
256
- _bufferElementType ) ) ) )
257
- . WithArgumentList (
258
- ArgumentList ( SeparatedList ( new ArgumentSyntax [ ]
259
- {
260
- Argument ( IdentifierName ( stackPtrIdentifier ) ) ,
261
- Argument ( MemberAccessExpression ( SyntaxKind . SimpleMemberAccessExpression ,
246
+ Argument ( StackAllocArrayCreationExpression (
247
+ ArrayType (
248
+ _bufferElementType ,
249
+ SingletonList ( ArrayRankSpecifier ( SingletonSeparatedList < ExpressionSyntax > (
250
+ MemberAccessExpression ( SyntaxKind . SimpleMemberAccessExpression ,
262
251
_marshallerType ,
263
- IdentifierName ( ShapeMemberNames . BufferSize ) ) )
264
- } ) ) ) )
252
+ IdentifierName ( ShapeMemberNames . BufferSize ) ) ) ) ) ) ) )
265
253
} ) ) ) ) ;
266
254
}
267
255
}
0 commit comments