Skip to content

Add IConnectionCompleteFeature.OnCompleted #9754

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Apr 30, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,10 @@ public enum TransferFormat
}
namespace Microsoft.AspNetCore.Connections.Features
{
public partial interface IConnectionCompleteFeature
{
void OnCompleted(System.Func<object, System.Threading.Tasks.Task> callback, object state);
}
public partial interface IConnectionHeartbeatFeature
{
void OnHeartbeat(System.Action<object> action, object state);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
using System.Threading.Tasks;

namespace Microsoft.AspNetCore.Connections.Features
{
/// <summary>
/// Represents the completion action for a connection.
/// </summary>
public interface IConnectionCompleteFeature
{
/// <summary>
/// Registers a callback to be invoked after a connection has fully completed processing. This is
/// intended for resource cleanup.
/// </summary>
/// <param name="callback">The callback to invoke after the connection has completed processing.</param>
/// <param name="state">The state to pass into the callback.</param>
void OnCompleted(Func<object, Task> callback, object state);
}
}
2 changes: 2 additions & 0 deletions src/Servers/Kestrel/Core/src/Internal/ConnectionDispatcher.cs
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@ private async Task Execute(KestrelConnection connection)
}
finally
{
await connectionContext.CompleteAsync();

Log.ConnectionStop(connectionContext.ConnectionId);
KestrelEventSource.Log.ConnectionStop(connectionContext);

Expand Down
47 changes: 47 additions & 0 deletions src/Servers/Kestrel/Core/test/ConnectionDispatcherTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Connections.Features;
using Microsoft.AspNetCore.Server.Kestrel.Core.Internal;
using Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal;
using Microsoft.AspNetCore.Testing;
using Microsoft.Extensions.Logging;
using Moq;
using Xunit;

Expand Down Expand Up @@ -69,5 +71,50 @@ public async Task OnConnectionCompletesTransportPipesAfterReturning()
mockPipeWriter.Verify(m => m.Complete(It.IsAny<Exception>()), Times.Once());
mockPipeReader.Verify(m => m.Complete(It.IsAny<Exception>()), Times.Once());
}

[Fact]
public async Task OnConnectionFiresOnCompleted()
{
var serviceContext = new TestServiceContext();
var dispatcher = new ConnectionDispatcher(serviceContext, _ => Task.CompletedTask);

var connection = new Mock<TransportConnection> { CallBase = true }.Object;
connection.ConnectionClosed = new CancellationToken(canceled: true);
var completeFeature = connection.Features.Get<IConnectionCompleteFeature>();

Assert.NotNull(completeFeature);
object stateObject = new object();
object callbackState = null;
completeFeature.OnCompleted(state => { callbackState = state; return Task.CompletedTask; }, stateObject);

await dispatcher.OnConnection(connection);

Assert.Equal(stateObject, callbackState);
}

[Fact]
public async Task OnConnectionOnCompletedExceptionCaught()
{
var serviceContext = new TestServiceContext();
var dispatcher = new ConnectionDispatcher(serviceContext, _ => Task.CompletedTask);

var connection = new Mock<TransportConnection> { CallBase = true }.Object;
connection.ConnectionClosed = new CancellationToken(canceled: true);
var completeFeature = connection.Features.Get<IConnectionCompleteFeature>();
var mockLogger = new Mock<ILogger>();
connection.Logger = mockLogger.Object;

Assert.NotNull(completeFeature);
object stateObject = new object();
object callbackState = null;
completeFeature.OnCompleted(state => { callbackState = state; throw new InvalidTimeZoneException(); }, stateObject);

await dispatcher.OnConnection(connection);

Assert.Equal(stateObject, callbackState);
var log = mockLogger.Invocations.First();
Assert.Equal("An error occured running an IConnectionCompleteFeature.OnCompleted callback.", log.Arguments[2].ToString());
Assert.IsType<InvalidTimeZoneException>(log.Arguments[3]);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,6 @@
<ItemGroup Condition="'$(TargetFramework)' == 'netcoreapp3.0'">
<Compile Include="Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.netcoreapp3.0.cs" />
<Reference Include="Microsoft.AspNetCore.Connections.Abstractions" />
<Reference Include="Microsoft.Extensions.Logging.Abstractions" />
</ItemGroup>
</Project>
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ public enum SchedulingMode
ThreadPool = 1,
Inline = 2,
}
public abstract partial class TransportConnection : Microsoft.AspNetCore.Connections.ConnectionContext, Microsoft.AspNetCore.Connections.Features.IConnectionHeartbeatFeature, Microsoft.AspNetCore.Connections.Features.IConnectionIdFeature, Microsoft.AspNetCore.Connections.Features.IConnectionItemsFeature, Microsoft.AspNetCore.Connections.Features.IConnectionLifetimeFeature, Microsoft.AspNetCore.Connections.Features.IConnectionLifetimeNotificationFeature, Microsoft.AspNetCore.Connections.Features.IConnectionTransportFeature, Microsoft.AspNetCore.Connections.Features.IMemoryPoolFeature, Microsoft.AspNetCore.Http.Features.IFeatureCollection, Microsoft.AspNetCore.Http.Features.IHttpConnectionFeature, Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal.IApplicationTransportFeature, Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal.ITransportSchedulerFeature, System.Collections.Generic.IEnumerable<System.Collections.Generic.KeyValuePair<System.Type, object>>, System.Collections.IEnumerable
public abstract partial class TransportConnection : Microsoft.AspNetCore.Connections.ConnectionContext, Microsoft.AspNetCore.Connections.Features.IConnectionCompleteFeature, Microsoft.AspNetCore.Connections.Features.IConnectionHeartbeatFeature, Microsoft.AspNetCore.Connections.Features.IConnectionIdFeature, Microsoft.AspNetCore.Connections.Features.IConnectionItemsFeature, Microsoft.AspNetCore.Connections.Features.IConnectionLifetimeFeature, Microsoft.AspNetCore.Connections.Features.IConnectionLifetimeNotificationFeature, Microsoft.AspNetCore.Connections.Features.IConnectionTransportFeature, Microsoft.AspNetCore.Connections.Features.IMemoryPoolFeature, Microsoft.AspNetCore.Http.Features.IFeatureCollection, Microsoft.AspNetCore.Http.Features.IHttpConnectionFeature, Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal.IApplicationTransportFeature, Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal.ITransportSchedulerFeature, System.Collections.Generic.IEnumerable<System.Collections.Generic.KeyValuePair<System.Type, object>>, System.Collections.IEnumerable
{
protected readonly System.Threading.CancellationTokenSource _connectionClosingCts;
public TransportConnection() { }
Expand All @@ -73,6 +73,7 @@ public TransportConnection() { }
public override System.Collections.Generic.IDictionary<object, object> Items { get { throw null; } set { } }
public System.Net.IPAddress LocalAddress { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } [System.Runtime.CompilerServices.CompilerGeneratedAttribute]set { } }
public int LocalPort { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } [System.Runtime.CompilerServices.CompilerGeneratedAttribute]set { } }
protected internal virtual Microsoft.Extensions.Logging.ILogger Logger { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } [System.Runtime.CompilerServices.CompilerGeneratedAttribute]set { } }
public virtual System.Buffers.MemoryPool<byte> MemoryPool { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } }
System.Collections.Generic.IDictionary<object, object> Microsoft.AspNetCore.Connections.Features.IConnectionItemsFeature.Items { get { throw null; } set { } }
System.Threading.CancellationToken Microsoft.AspNetCore.Connections.Features.IConnectionLifetimeFeature.ConnectionClosed { get { throw null; } set { } }
Expand All @@ -96,6 +97,8 @@ public TransportConnection() { }
public int RemotePort { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } [System.Runtime.CompilerServices.CompilerGeneratedAttribute]set { } }
public override System.IO.Pipelines.IDuplexPipe Transport { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } [System.Runtime.CompilerServices.CompilerGeneratedAttribute]set { } }
public override void Abort(Microsoft.AspNetCore.Connections.ConnectionAbortedException abortReason) { }
public System.Threading.Tasks.Task CompleteAsync() { throw null; }
void Microsoft.AspNetCore.Connections.Features.IConnectionCompleteFeature.OnCompleted(System.Func<object, System.Threading.Tasks.Task> callback, object state) { }
void Microsoft.AspNetCore.Connections.Features.IConnectionHeartbeatFeature.OnHeartbeat(System.Action<object> action, object state) { }
void Microsoft.AspNetCore.Connections.Features.IConnectionLifetimeFeature.Abort() { }
void Microsoft.AspNetCore.Connections.Features.IConnectionLifetimeNotificationFeature.RequestClose() { }
Expand Down
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
using System.Buffers;
using System.Collections.Generic;
using System.IO.Pipelines;
using System.Net;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Connections;
using Microsoft.AspNetCore.Connections.Features;
using Microsoft.AspNetCore.Http.Features;
using Microsoft.Extensions.Logging;

namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal
{
Expand All @@ -21,12 +24,16 @@ public partial class TransportConnection : IHttpConnectionFeature,
ITransportSchedulerFeature,
IConnectionLifetimeFeature,
IConnectionHeartbeatFeature,
IConnectionLifetimeNotificationFeature
IConnectionLifetimeNotificationFeature,
IConnectionCompleteFeature
{
// NOTE: When feature interfaces are added to or removed from this TransportConnection class implementation,
// then the list of `features` in the generated code project MUST also be updated.
// See also: tools/CodeGenerator/TransportConnectionFeatureCollection.cs

private Stack<KeyValuePair<Func<object, Task>, object>> _onCompleted;
private bool _completed;

string IHttpConnectionFeature.ConnectionId
{
get => ConnectionId;
Expand Down Expand Up @@ -100,5 +107,82 @@ void IConnectionHeartbeatFeature.OnHeartbeat(System.Action<object> action, objec
{
OnHeartbeat(action, state);
}

void IConnectionCompleteFeature.OnCompleted(Func<object, Task> callback, object state)
{
if (_completed)
{
throw new InvalidOperationException("The connection is already complete.");
}

if (_onCompleted == null)
{
_onCompleted = new Stack<KeyValuePair<Func<object, Task>, object>>();
}
_onCompleted.Push(new KeyValuePair<Func<object, Task>, object>(callback, state));
}

public Task CompleteAsync()
{
if (_completed)
{
throw new InvalidOperationException("The connection is already complete.");
}

_completed = true;
var onCompleted = _onCompleted;

if (onCompleted == null || onCompleted.Count == 0)
{
return Task.CompletedTask;
}

return CompleteAsyncMayAwait(onCompleted);
}

private Task CompleteAsyncMayAwait(Stack<KeyValuePair<Func<object, Task>, object>> onCompleted)
{
while (onCompleted.TryPop(out var entry))
{
try
{
var task = entry.Key.Invoke(entry.Value);
if (!ReferenceEquals(task, Task.CompletedTask))
{
return CompleteAsyncAwaited(task, onCompleted);
}
}
catch (Exception ex)
{
Logger?.LogError(ex, "An error occured running an IConnectionCompleteFeature.OnCompleted callback.");
}
}

return Task.CompletedTask;
}

private async Task CompleteAsyncAwaited(Task currentTask, Stack<KeyValuePair<Func<object, Task>, object>> onCompleted)
{
try
{
await currentTask;
}
catch (Exception ex)
{
Logger?.LogError(ex, "An error occured running an IConnectionCompleteFeature.OnCompleted callback.");
}

while (onCompleted.TryPop(out var entry))
{
try
{
await entry.Key.Invoke(entry.Value);
}
catch (Exception ex)
{
Logger?.LogError(ex, "An error occured running an IConnectionCompleteFeature.OnCompleted callback.");
}
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ public partial class TransportConnection : IFeatureCollection
private static readonly Type IConnectionLifetimeFeatureType = typeof(IConnectionLifetimeFeature);
private static readonly Type IConnectionHeartbeatFeatureType = typeof(IConnectionHeartbeatFeature);
private static readonly Type IConnectionLifetimeNotificationFeatureType = typeof(IConnectionLifetimeNotificationFeature);
private static readonly Type IConnectionCompleteFeatureType = typeof(IConnectionCompleteFeature);

private object _currentIHttpConnectionFeature;
private object _currentIConnectionIdFeature;
Expand All @@ -33,6 +34,7 @@ public partial class TransportConnection : IFeatureCollection
private object _currentIConnectionLifetimeFeature;
private object _currentIConnectionHeartbeatFeature;
private object _currentIConnectionLifetimeNotificationFeature;
private object _currentIConnectionCompleteFeature;

private int _featureRevision;

Expand All @@ -50,6 +52,7 @@ private void FastReset()
_currentIConnectionLifetimeFeature = this;
_currentIConnectionHeartbeatFeature = this;
_currentIConnectionLifetimeNotificationFeature = this;
_currentIConnectionCompleteFeature = this;

}

Expand Down Expand Up @@ -145,6 +148,10 @@ object IFeatureCollection.this[Type key]
{
feature = _currentIConnectionLifetimeNotificationFeature;
}
else if (key == IConnectionCompleteFeatureType)
{
feature = _currentIConnectionCompleteFeature;
}
else if (MaybeExtra != null)
{
feature = ExtraFeatureGet(key);
Expand Down Expand Up @@ -197,6 +204,10 @@ object IFeatureCollection.this[Type key]
{
_currentIConnectionLifetimeNotificationFeature = value;
}
else if (key == IConnectionCompleteFeatureType)
{
_currentIConnectionCompleteFeature = value;
}
else
{
ExtraFeatureSet(key, value);
Expand Down Expand Up @@ -247,6 +258,10 @@ TFeature IFeatureCollection.Get<TFeature>()
{
feature = (TFeature)_currentIConnectionLifetimeNotificationFeature;
}
else if (typeof(TFeature) == typeof(IConnectionCompleteFeature))
{
feature = (TFeature)_currentIConnectionCompleteFeature;
}
else if (MaybeExtra != null)
{
feature = (TFeature)(ExtraFeatureGet(typeof(TFeature)));
Expand Down Expand Up @@ -298,6 +313,10 @@ void IFeatureCollection.Set<TFeature>(TFeature feature)
{
_currentIConnectionLifetimeNotificationFeature = feature;
}
else if (typeof(TFeature) == typeof(IConnectionCompleteFeature))
{
_currentIConnectionCompleteFeature = feature;
}
else
{
ExtraFeatureSet(typeof(TFeature), feature);
Expand Down Expand Up @@ -346,6 +365,10 @@ private IEnumerable<KeyValuePair<Type, object>> FastEnumerable()
{
yield return new KeyValuePair<Type, object>(IConnectionLifetimeNotificationFeatureType, _currentIConnectionLifetimeNotificationFeature);
}
if (_currentIConnectionCompleteFeature != null)
{
yield return new KeyValuePair<Type, object>(IConnectionCompleteFeatureType, _currentIConnectionCompleteFeature);
}

if (MaybeExtra != null)
{
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
Expand All @@ -9,6 +9,7 @@
using System.Threading;
using Microsoft.AspNetCore.Connections;
using Microsoft.AspNetCore.Http.Features;
using Microsoft.Extensions.Logging;

namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal
{
Expand All @@ -35,6 +36,8 @@ public TransportConnection()

public override IFeatureCollection Features => this;

protected internal virtual ILogger Logger { get; set; }

public virtual MemoryPool<byte> MemoryPool { get; }
public virtual PipeScheduler InputWriterScheduler { get; }
public virtual PipeScheduler OutputReaderScheduler { get; }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

<ItemGroup>
<Reference Include="Microsoft.AspNetCore.Connections.Abstractions" />
<Reference Include="Microsoft.Extensions.Logging.Abstractions" />
</ItemGroup>

</Project>
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ public LibuvConnection(UvStreamHandle socket, ILibuvTrace log, LibuvThread threa
LocalPort = localEndPoint?.Port ?? 0;

ConnectionClosed = _connectionClosedTokenSource.Token;
Logger = log;
Log = log;
Thread = thread;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ internal SocketConnection(Socket socket, MemoryPool<byte> memoryPool, PipeSchedu
_socket = socket;
MemoryPool = memoryPool;
_scheduler = scheduler;
Logger = trace;
_trace = trace;

var localEndPoint = (IPEndPoint)_socket.LocalEndPoint;
Expand Down
Loading