Skip to content

Get DataContractSerializer to behave nicely with unloadable AssemblyLoadContext #88791

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@
<Compile Include="$(CommonPath)System\CodeDom\CodeTypeReference.cs" />
<Compile Include="$(CommonPath)System\CodeDom\CodeTypeReferenceCollection.cs" />
<Compile Include="$(CommonPath)System\CodeDom\CodeObject.cs" />
<Compile Include="System\Runtime\Serialization\ContextAware.cs" />
</ItemGroup>

<ItemGroup>
Expand All @@ -163,6 +164,7 @@
<Reference Include="System.Reflection.Primitives" />
<Reference Include="System.Runtime" />
<Reference Include="System.Runtime.Intrinsics" />
<Reference Include="System.Runtime.Loader" />
<Reference Include="System.Runtime.Serialization.Formatters" />
<Reference Include="System.Runtime.Serialization.Primitives" />
<Reference Include="System.Text.Encoding.Extensions" />
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Collections;
using System.Collections.Concurrent;
using System.Diagnostics.CodeAnalysis;
using System.Reflection;
using System.Runtime.CompilerServices;
using System.Runtime.Loader;
using System.Runtime.Serialization.DataContracts;

namespace System.Runtime.Serialization
{
internal sealed class ContextAwareDataContractIndex
{
private (DataContract? strong, WeakReference<DataContract>? weak)[] _contracts;
private ConditionalWeakTable<Type, DataContract> _keepAlive;

public int Length => _contracts.Length;

public ContextAwareDataContractIndex(int size)
{
_contracts = new (DataContract?, WeakReference<DataContract>?)[size];
_keepAlive = new ConditionalWeakTable<Type, DataContract>();
}

public DataContract? GetItem(int index) => _contracts[index].strong ?? (_contracts[index].weak?.TryGetTarget(out DataContract? ret) == true ? ret : null);

public void SetItem(int index, DataContract dataContract)
{
// Check for unloadability to decide how to store the value
AssemblyLoadContext? alc = AssemblyLoadContext.GetLoadContext(dataContract.UnderlyingType.Assembly);
if (alc == null || !alc.IsCollectible)
{
_contracts[index].strong = dataContract;
}
else
{
_contracts[index].weak = new WeakReference<DataContract>(dataContract);
_keepAlive.Add(dataContract.UnderlyingType, dataContract);
}
}

public void Resize(int newSize)
{
Array.Resize<(DataContract?, WeakReference<DataContract>?)>(ref _contracts, newSize);
}
}

internal sealed class ContextAwareDictionary<TKey, [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)] TValue>
where TKey : Type
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does TKey have to be generic? You could just use Type.

where TValue : class?
{
private readonly ConcurrentDictionary<TKey, TValue> _fastDictionary = new();
private readonly ConditionalWeakTable<TKey, TValue> _collectibleTable = new();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ConditionalWeakTable is pretty fast itself. Have you ran any benchmarks to see if non-unloadable assemblies show an improvement with ConcurrentDictionary?



internal TValue GetOrAdd(TKey t, Func<TKey, TValue> f)
{
TValue? ret;

// The fast and most common default case
if (_fastDictionary.TryGetValue(t, out ret))
return ret;

// Common case for collectible contexts
if (_collectibleTable.TryGetValue(t, out ret))
return ret;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Need to hold the lock on lookup too


// Not found. Do the slower work of creating the value in the correct collection.
AssemblyLoadContext? alc = AssemblyLoadContext.GetLoadContext(t.Assembly);

// Null and non-collectible load contexts use the default table
if (alc == null || !alc.IsCollectible)
{
return _fastDictionary.GetOrAdd(t, f);
}

// Collectible load contexts should use the ConditionalWeakTable so they can be unloaded
else
{
lock (_collectibleTable)
{
if (!_collectibleTable.TryGetValue(t, out ret))
{
ret = f(t);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Running the delegate inside the lock is prone to deadlocks. Can you move it outside?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The delegate isn't necessarily a simple amount of work depending on your OM design. It's something you really want to avoid executing multiple times. For example, if you instantiate a DCS on an incoming request on a REST api for example, having 100 concurrent requests could result in redoing this expensive work 100 times. You could negatively impact your first request time significantly.
Can you provide me information about it being prone to deadlocks? The code run by the delegate isn't going to do anything async so any reentrance will occur on the same thread. I don't believe this specific scenario is able to deadlock, but I'm open to learning about ways it can deadlock that I might not be aware of.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You know better if you think that creating a DCS will not execute arbitrary code. But note that ConcurrentDictionary.GetOrAdd will not execute the delegate inside a lock. In the most common case of non-unloadable assemblies there is still the possibility that the delegate will run many times.

There is also ConditionalWeakTable.CreateValue that you can use like Concurrent.Dictionary.GetOrAdd.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree, we should place a lock around ConcurrentDictionary.GetOrAdd. We don't need to check a second time if we are always holding a lock when adding as that would guarantee the prior add has completed before calling GetOrAdd and it will act like a Get.

_collectibleTable.AddOrUpdate(t, ret);
}
}
}

return ret;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using System;
using System.Buffers.Binary;
using System.Collections;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.Diagnostics;
Expand Down Expand Up @@ -297,19 +298,16 @@ internal virtual bool IsValidContract()

internal class DataContractCriticalHelper
{
private static readonly Hashtable s_typeToIDCache = new Hashtable(new HashTableEqualityComparer());
private static DataContract[] s_dataContractCache = new DataContract[32];
private static readonly ConcurrentDictionary<nint, int> s_typeToIDCache = new();
private static readonly ContextAwareDataContractIndex s_dataContractCache = new(32);
private static int s_dataContractID;
private static Dictionary<Type, DataContract?>? s_typeToBuiltInContract;
private static readonly ContextAwareDictionary<Type, DataContract?> s_typeToBuiltInContract = new();
private static Dictionary<XmlQualifiedName, DataContract?>? s_nameToBuiltInContract;
private static Dictionary<string, DataContract?>? s_typeNameToBuiltInContract;
private static readonly Hashtable s_namespaces = new Hashtable();
private static Dictionary<string, XmlDictionaryString>? s_clrTypeStrings;
private static XmlDictionary? s_clrTypeStringsDictionary;

[ThreadStatic]
private static TypeHandleRef? s_typeHandleRef;

private static readonly object s_cacheLock = new object();
private static readonly object s_createDataContractLock = new object();
private static readonly object s_initBuiltInContractsLock = new object();
Expand Down Expand Up @@ -337,7 +335,7 @@ internal class DataContractCriticalHelper
[RequiresUnreferencedCode(DataContract.SerializerTrimmerWarning)]
internal static DataContract GetDataContractSkipValidation(int id, RuntimeTypeHandle typeHandle, Type? type)
{
DataContract dataContract = s_dataContractCache[id];
DataContract? dataContract = s_dataContractCache.GetItem(id);
if (dataContract == null)
{
dataContract = CreateDataContract(id, typeHandle, type);
Expand All @@ -353,13 +351,13 @@ internal static DataContract GetDataContractSkipValidation(int id, RuntimeTypeHa
[RequiresUnreferencedCode(DataContract.SerializerTrimmerWarning)]
internal static DataContract GetGetOnlyCollectionDataContractSkipValidation(int id, RuntimeTypeHandle typeHandle, Type? type)
{
DataContract dataContract = s_dataContractCache[id] ?? CreateGetOnlyCollectionDataContract(id, typeHandle, type);
DataContract dataContract = s_dataContractCache.GetItem(id) ?? CreateGetOnlyCollectionDataContract(id, typeHandle, type);
return dataContract;
}

internal static DataContract GetDataContractForInitialization(int id)
{
DataContract dataContract = s_dataContractCache[id];
DataContract? dataContract = s_dataContractCache.GetItem(id);
if (dataContract == null)
{
throw new SerializationException(SR.DataContractCacheOverflow);
Expand All @@ -370,59 +368,52 @@ internal static DataContract GetDataContractForInitialization(int id)
internal static int GetIdForInitialization(ClassDataContract classContract)
{
int id = DataContract.GetId(classContract.TypeForInitialization.TypeHandle);
if (id < s_dataContractCache.Length && ContractMatches(classContract, s_dataContractCache[id]))
if (id < s_dataContractCache.Length && ContractMatches(classContract, s_dataContractCache.GetItem(id)))
{
return id;
}
int currentDataContractId = DataContractCriticalHelper.s_dataContractID;
for (int i = 0; i < currentDataContractId; i++)
{
if (ContractMatches(classContract, s_dataContractCache[i]))
if (ContractMatches(classContract, s_dataContractCache.GetItem(id)))
{
return i;
}
}
throw new SerializationException(SR.DataContractCacheOverflow);
}

private static bool ContractMatches(DataContract contract, DataContract cachedContract)
private static bool ContractMatches(DataContract contract, DataContract? cachedContract)
{
return (cachedContract != null && cachedContract.UnderlyingType == contract.UnderlyingType);
}

internal static int GetId(RuntimeTypeHandle typeHandle)
{
typeHandle = GetDataContractAdapterTypeHandle(typeHandle);
s_typeHandleRef ??= new TypeHandleRef();
s_typeHandleRef.Value = typeHandle;

object? value = s_typeToIDCache[s_typeHandleRef];
if (value != null)
return ((IntRef)value).Value;
if (s_typeToIDCache.TryGetValue(typeHandle.Value, out int id))
return id;

try
{
lock (s_cacheLock)
{
value = s_typeToIDCache[s_typeHandleRef];
if (value != null)
return ((IntRef)value).Value;

int nextId = s_dataContractID++;
if (nextId >= s_dataContractCache.Length)
return s_typeToIDCache.GetOrAdd(typeHandle.Value, static _ =>
{
int newSize = (nextId < int.MaxValue / 2) ? nextId * 2 : int.MaxValue;
if (newSize <= nextId)
int nextId = s_dataContractID++;
if (nextId >= s_dataContractCache.Length)
{
Debug.Fail("DataContract cache overflow");
throw new SerializationException(SR.DataContractCacheOverflow);
int newSize = (nextId < int.MaxValue / 2) ? nextId * 2 : int.MaxValue;
if (newSize <= nextId)
{
Debug.Fail("DataContract cache overflow");
throw new SerializationException(SR.DataContractCacheOverflow);
}
s_dataContractCache.Resize(newSize);
}
Array.Resize<DataContract>(ref s_dataContractCache, newSize);
}
IntRef id = new IntRef(nextId);

s_typeToIDCache.Add(new TypeHandleRef(typeHandle), id);
return id.Value;
return nextId;
});
}
}
catch (Exception ex) when (!ExceptionUtility.IsFatal(ex))
Expand All @@ -436,12 +427,12 @@ internal static int GetId(RuntimeTypeHandle typeHandle)
[RequiresUnreferencedCode(DataContract.SerializerTrimmerWarning)]
private static DataContract CreateDataContract(int id, RuntimeTypeHandle typeHandle, Type? type)
{
DataContract? dataContract = s_dataContractCache[id];
DataContract? dataContract = s_dataContractCache.GetItem(id);
if (dataContract == null)
{
lock (s_createDataContractLock)
{
dataContract = s_dataContractCache[id];
dataContract = s_dataContractCache.GetItem(id);
if (dataContract == null)
{
type ??= Type.GetTypeFromHandle(typeHandle)!;
Expand Down Expand Up @@ -508,7 +499,7 @@ private static void AssignDataContractToId(DataContract dataContract, int id)
{
lock (s_cacheLock)
{
s_dataContractCache[id] = dataContract;
s_dataContractCache.SetItem(id, dataContract);
}
}

Expand All @@ -519,7 +510,7 @@ private static DataContract CreateGetOnlyCollectionDataContract(int id, RuntimeT
DataContract? dataContract = null;
lock (s_createDataContractLock)
{
dataContract = s_dataContractCache[id];
dataContract = s_dataContractCache.GetItem(id);
if (dataContract == null)
{
type ??= Type.GetTypeFromHandle(typeHandle)!;
Expand Down Expand Up @@ -589,17 +580,11 @@ private static RuntimeTypeHandle GetDataContractAdapterTypeHandle(RuntimeTypeHan
if (type.IsInterface && !CollectionDataContract.IsCollectionInterface(type))
type = Globals.TypeOfObject;

lock (s_initBuiltInContractsLock)
return s_typeToBuiltInContract.GetOrAdd(type, static (Type key) =>
{
s_typeToBuiltInContract ??= new Dictionary<Type, DataContract?>();

if (!s_typeToBuiltInContract.TryGetValue(type, out DataContract? dataContract))
{
TryCreateBuiltInDataContract(type, out dataContract);
s_typeToBuiltInContract.Add(type, dataContract);
}
TryCreateBuiltInDataContract(key, out DataContract? dataContract);
return dataContract;
}
});
}

[RequiresDynamicCode(DataContract.SerializerAOTWarning)]
Expand Down Expand Up @@ -945,19 +930,8 @@ internal static void ThrowInvalidDataContractException(string? message, Type? ty
{
if (type != null)
{
lock (s_cacheLock)
{
s_typeHandleRef ??= new TypeHandleRef();
s_typeHandleRef.Value = GetDataContractAdapterTypeHandle(type.TypeHandle);

if (s_typeToIDCache.ContainsKey(s_typeHandleRef))
{
lock (s_cacheLock)
{
s_typeToIDCache.Remove(s_typeHandleRef);
}
}
}
RuntimeTypeHandle runtimeTypeHandle = GetDataContractAdapterTypeHandle(type.TypeHandle);
s_typeToIDCache.TryRemove(runtimeTypeHandle.Value, out _);
}

throw new InvalidDataContractException(message);
Expand Down Expand Up @@ -2515,62 +2489,4 @@ public override int GetHashCode()
return _object1.GetHashCode() ^ _object2.GetHashCode();
}
}

internal sealed class HashTableEqualityComparer : IEqualityComparer
{
bool IEqualityComparer.Equals(object? x, object? y)
{
return ((TypeHandleRef)x!).Value.Equals(((TypeHandleRef)y!).Value);
}

public int GetHashCode(object obj)
{
return ((TypeHandleRef)obj).Value.GetHashCode();
}
}

internal sealed class TypeHandleRefEqualityComparer : IEqualityComparer<TypeHandleRef>
{
public bool Equals(TypeHandleRef? x, TypeHandleRef? y)
{
return x!.Value.Equals(y!.Value);
}

public int GetHashCode(TypeHandleRef obj)
{
return obj.Value.GetHashCode();
}
}

internal sealed class TypeHandleRef
{
private RuntimeTypeHandle _value;

public TypeHandleRef()
{
}

public TypeHandleRef(RuntimeTypeHandle value)
{
_value = value;
}

public RuntimeTypeHandle Value
{
get => _value;
set => _value = value;
}
}

internal sealed class IntRef
{
private readonly int _value;

public IntRef(int value)
{
_value = value;
}

public int Value => _value;
}
}
Loading