Skip to content

Avoid stack overflows in CompositeEndpointDataSource #44392

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 18 additions & 28 deletions src/Http/Routing/src/CompositeEndpointDataSource.cs
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ public sealed class CompositeEndpointDataSource : EndpointDataSource, IDisposabl

internal CompositeEndpointDataSource(ObservableCollection<EndpointDataSource> dataSources)
{
dataSources.CollectionChanged += OnDataSourcesChanged;
_dataSources = dataSources;
dataSources.CollectionChanged += OnDataSourcesChanged;
}

/// <summary>
Expand All @@ -38,15 +38,10 @@ internal CompositeEndpointDataSource(ObservableCollection<EndpointDataSource> da
/// <returns>A <see cref="CompositeEndpointDataSource"/>.</returns>
public CompositeEndpointDataSource(IEnumerable<EndpointDataSource> endpointDataSources)
{
_dataSources = new List<EndpointDataSource>();

foreach (var dataSource in endpointDataSources)
{
_dataSources.Add(dataSource);
}
_dataSources = new List<EndpointDataSource>(endpointDataSources);
}

private void OnDataSourcesChanged(object? sender, NotifyCollectionChangedEventArgs e) => HandleChange(collectionChanged: true);
private void OnDataSourcesChanged(object? sender, NotifyCollectionChangedEventArgs e) => HandleChange();

/// <summary>
/// Returns the collection of <see cref="EndpointDataSource"/> instances associated with the object.
Expand Down Expand Up @@ -183,11 +178,11 @@ private void EnsureChangeTokenInitialized()
}

// This is our first time initializing the change token, so the collection has "changed" from nothing.
CreateChangeTokenUnsynchronized(collectionChanged: true);
CreateChangeTokenUnsynchronized();
}
}

private void HandleChange(bool collectionChanged)
private void HandleChange()
{
CancellationTokenSource? oldTokenSource = null;
List<IDisposable>? oldChangeTokenRegistrations = null;
Expand All @@ -199,13 +194,7 @@ private void HandleChange(bool collectionChanged)
return;
}

// Prevent consumers from re-registering callback to in-flight events as that can
// cause a stack overflow.
// Example:
// 1. B registers A.
// 2. A fires event causing B's callback to get called.
// 3. B executes some code in its callback, but needs to re-register callback
// in the same callback.
// Register for new changes before disposing old registrations to ensure no changes are missed.
oldTokenSource = _cts;
oldChangeTokenRegistrations = _changeTokenRegistrations;

Expand All @@ -214,7 +203,7 @@ private void HandleChange(bool collectionChanged)
{
// We have to hook to any OnChange callbacks before caching endpoints,
// otherwise we might miss changes that occurred to one of the _dataSources after caching.
CreateChangeTokenUnsynchronized(collectionChanged);
CreateChangeTokenUnsynchronized();
}

// Don't update endpoints if no one has read them yet.
Expand All @@ -226,7 +215,7 @@ private void HandleChange(bool collectionChanged)
}

// Disposing registrations can block on user defined code on running on other threads that could try to acquire the _lock.
if (collectionChanged && oldChangeTokenRegistrations is not null)
if (oldChangeTokenRegistrations is not null)
{
foreach (var registration in oldChangeTokenRegistrations)
{
Expand All @@ -240,19 +229,15 @@ private void HandleChange(bool collectionChanged)
}

[MemberNotNull(nameof(_consumerChangeToken))]
private void CreateChangeTokenUnsynchronized(bool collectionChanged)
private void CreateChangeTokenUnsynchronized()
{
var cts = new CancellationTokenSource();

if (collectionChanged)
_changeTokenRegistrations = new();
foreach (var dataSource in _dataSources)
{
_changeTokenRegistrations = new();
foreach (var dataSource in _dataSources)
{
_changeTokenRegistrations.Add(ChangeToken.OnChange(
dataSource.GetChangeToken,
() => HandleChange(collectionChanged: false)));
}
_changeTokenRegistrations.Add(dataSource.GetChangeToken()
.RegisterChangeCallback(DispatchHandleChange, this));
}

_cts = cts;
Expand All @@ -274,6 +259,11 @@ private void CreateEndpointsUnsynchronized()
_endpoints = endpoints;
}

private static void DispatchHandleChange(object? state)
{
ThreadPool.UnsafeQueueUserWorkItem(static innerState => ((CompositeEndpointDataSource)innerState!).HandleChange(), state);
}

// Use private variable '_endpoints' to avoid initialization
private string DebuggerDisplayString => GetDebuggerDisplayStringForEndpoints(_endpoints);
}