diff --git a/src/Http/Headers/src/HeaderNames.cs b/src/Http/Headers/src/HeaderNames.cs
index 8b1b8af991b3..2528c481cd6f 100644
--- a/src/Http/Headers/src/HeaderNames.cs
+++ b/src/Http/Headers/src/HeaderNames.cs
@@ -225,6 +225,9 @@ public static class HeaderNames
/// Gets the Sec-WebSocket-Version HTTP header name.
public static readonly string SecWebSocketVersion = "Sec-WebSocket-Version";
+ /// Gets the Sec-WebSocket-Extensions HTTP header name.
+ public static readonly string SecWebSocketExtensions = "Sec-WebSocket-Extensions";
+
/// Gets the Server HTTP header name.
public static readonly string Server = "Server";
diff --git a/src/Http/Headers/src/PublicAPI.Unshipped.txt b/src/Http/Headers/src/PublicAPI.Unshipped.txt
index 9035747ce6c4..b5ecbe22ebbb 100644
--- a/src/Http/Headers/src/PublicAPI.Unshipped.txt
+++ b/src/Http/Headers/src/PublicAPI.Unshipped.txt
@@ -5,6 +5,7 @@ Microsoft.Net.Http.Headers.RangeConditionHeaderValue.RangeConditionHeaderValue(M
static readonly Microsoft.Net.Http.Headers.HeaderNames.Baggage -> string!
static readonly Microsoft.Net.Http.Headers.HeaderNames.Link -> string!
static readonly Microsoft.Net.Http.Headers.HeaderNames.ProxyConnection -> string!
+static readonly Microsoft.Net.Http.Headers.HeaderNames.SecWebSocketExtensions -> string!
static readonly Microsoft.Net.Http.Headers.HeaderNames.XContentTypeOptions -> string!
static readonly Microsoft.Net.Http.Headers.HeaderNames.XPoweredBy -> string!
static readonly Microsoft.Net.Http.Headers.HeaderNames.XUACompatible -> string!
diff --git a/src/Http/Http.Abstractions/src/PublicAPI.Unshipped.txt b/src/Http/Http.Abstractions/src/PublicAPI.Unshipped.txt
index 687f1f1e034b..0d6ad2ae7320 100644
--- a/src/Http/Http.Abstractions/src/PublicAPI.Unshipped.txt
+++ b/src/Http/Http.Abstractions/src/PublicAPI.Unshipped.txt
@@ -23,3 +23,4 @@ abstract Microsoft.AspNetCore.Http.HttpRequest.ContentType.get -> string?
static Microsoft.AspNetCore.Builder.UseExtensions.Use(this Microsoft.AspNetCore.Builder.IApplicationBuilder! app, System.Func! middleware) -> Microsoft.AspNetCore.Builder.IApplicationBuilder!
static Microsoft.AspNetCore.Builder.UseMiddlewareExtensions.UseMiddleware(this Microsoft.AspNetCore.Builder.IApplicationBuilder! app, System.Type! middleware, params object?[]! args) -> Microsoft.AspNetCore.Builder.IApplicationBuilder!
static Microsoft.AspNetCore.Builder.UseMiddlewareExtensions.UseMiddleware(this Microsoft.AspNetCore.Builder.IApplicationBuilder! app, params object?[]! args) -> Microsoft.AspNetCore.Builder.IApplicationBuilder!
+virtual Microsoft.AspNetCore.Http.WebSocketManager.AcceptWebSocketAsync(Microsoft.AspNetCore.Http.WebSocketAcceptContext! acceptContext) -> System.Threading.Tasks.Task!
diff --git a/src/Http/Http.Abstractions/src/WebSocketManager.cs b/src/Http/Http.Abstractions/src/WebSocketManager.cs
index 1fe47d59587d..79fd84bc184c 100644
--- a/src/Http/Http.Abstractions/src/WebSocketManager.cs
+++ b/src/Http/Http.Abstractions/src/WebSocketManager.cs
@@ -1,6 +1,7 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
+using System;
using System.Collections.Generic;
using System.Net.WebSockets;
using System.Threading.Tasks;
@@ -8,7 +9,7 @@
namespace Microsoft.AspNetCore.Http
{
///
- /// Manages the establishment of WebSocket connections for a specific HTTP request.
+ /// Manages the establishment of WebSocket connections for a specific HTTP request.
///
public abstract class WebSocketManager
{
@@ -37,5 +38,12 @@ public virtual Task AcceptWebSocketAsync()
/// The sub-protocol to use.
/// A task representing the completion of the transition.
public abstract Task AcceptWebSocketAsync(string? subProtocol);
+
+ ///
+ ///
+ ///
+ ///
+ ///
+ public virtual Task AcceptWebSocketAsync(WebSocketAcceptContext acceptContext) => throw new NotImplementedException();
}
}
diff --git a/src/Http/Http.Features/src/IHeaderDictionary.Keyed.cs b/src/Http/Http.Features/src/IHeaderDictionary.Keyed.cs
index 7f7739599ad9..fd98c8ae4d1d 100644
--- a/src/Http/Http.Features/src/IHeaderDictionary.Keyed.cs
+++ b/src/Http/Http.Features/src/IHeaderDictionary.Keyed.cs
@@ -202,6 +202,9 @@ public partial interface IHeaderDictionary
/// Gets or sets the Sec-WebSocket-Version HTTP header.
StringValues SecWebSocketVersion { get => this[HeaderNames.SecWebSocketVersion]; set => this[HeaderNames.SecWebSocketVersion] = value; }
+ /// Gets or sets the Sec-WebSocket-Extensions HTTP header.
+ StringValues SecWebSocketExtensions { get => this[HeaderNames.SecWebSocketExtensions]; set => this[HeaderNames.SecWebSocketExtensions] = value; }
+
/// Gets or sets the Server HTTP header.
StringValues Server { get => this[HeaderNames.Server]; set => this[HeaderNames.Server] = value; }
diff --git a/src/Http/Http.Features/src/PublicAPI.Unshipped.txt b/src/Http/Http.Features/src/PublicAPI.Unshipped.txt
index 6911dbe5b627..a391a2f049ad 100644
--- a/src/Http/Http.Features/src/PublicAPI.Unshipped.txt
+++ b/src/Http/Http.Features/src/PublicAPI.Unshipped.txt
@@ -161,6 +161,8 @@ Microsoft.AspNetCore.Http.IHeaderDictionary.RetryAfter.get -> Microsoft.Extensio
Microsoft.AspNetCore.Http.IHeaderDictionary.RetryAfter.set -> void
Microsoft.AspNetCore.Http.IHeaderDictionary.SecWebSocketAccept.get -> Microsoft.Extensions.Primitives.StringValues
Microsoft.AspNetCore.Http.IHeaderDictionary.SecWebSocketAccept.set -> void
+Microsoft.AspNetCore.Http.IHeaderDictionary.SecWebSocketExtensions.get -> Microsoft.Extensions.Primitives.StringValues
+Microsoft.AspNetCore.Http.IHeaderDictionary.SecWebSocketExtensions.set -> void
Microsoft.AspNetCore.Http.IHeaderDictionary.SecWebSocketKey.get -> Microsoft.Extensions.Primitives.StringValues
Microsoft.AspNetCore.Http.IHeaderDictionary.SecWebSocketKey.set -> void
Microsoft.AspNetCore.Http.IHeaderDictionary.SecWebSocketProtocol.get -> Microsoft.Extensions.Primitives.StringValues
@@ -232,6 +234,14 @@ Microsoft.AspNetCore.Http.Features.FeatureCollection.IsReadOnly.get -> bool (for
Microsoft.AspNetCore.Http.Features.FeatureCollection.Set(TFeature? instance) -> void (forwarded, contained in Microsoft.Extensions.Features)
Microsoft.AspNetCore.Http.Features.FeatureCollection.this[System.Type! key].get -> object? (forwarded, contained in Microsoft.Extensions.Features)
Microsoft.AspNetCore.Http.Features.FeatureCollection.this[System.Type! key].set -> void (forwarded, contained in Microsoft.Extensions.Features)
+Microsoft.AspNetCore.Http.WebSocketAcceptContext.DangerousEnableCompression.get -> bool
+Microsoft.AspNetCore.Http.WebSocketAcceptContext.DangerousEnableCompression.set -> void
+Microsoft.AspNetCore.Http.WebSocketAcceptContext.DisableServerContextTakeover.get -> bool
+Microsoft.AspNetCore.Http.WebSocketAcceptContext.DisableServerContextTakeover.set -> void
+Microsoft.AspNetCore.Http.WebSocketAcceptContext.ServerMaxWindowBits.get -> int
+Microsoft.AspNetCore.Http.WebSocketAcceptContext.ServerMaxWindowBits.set -> void
virtual Microsoft.AspNetCore.Http.Features.FeatureCollection.Revision.get -> int (forwarded, contained in Microsoft.Extensions.Features)
+virtual Microsoft.AspNetCore.Http.WebSocketAcceptContext.KeepAliveInterval.get -> System.TimeSpan?
+virtual Microsoft.AspNetCore.Http.WebSocketAcceptContext.KeepAliveInterval.set -> void
~Microsoft.AspNetCore.Http.Features.FeatureReference<> (forwarded, contained in Microsoft.Extensions.Features)
~Microsoft.AspNetCore.Http.Features.FeatureReferences<> (forwarded, contained in Microsoft.Extensions.Features)
diff --git a/src/Http/Http.Features/src/WebSocketAcceptContext.cs b/src/Http/Http.Features/src/WebSocketAcceptContext.cs
index b293ad874be7..5f2c9d7a049a 100644
--- a/src/Http/Http.Features/src/WebSocketAcceptContext.cs
+++ b/src/Http/Http.Features/src/WebSocketAcceptContext.cs
@@ -1,7 +1,8 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
-using Microsoft.AspNetCore.Http.Features;
+using System;
+using System.Net.WebSockets;
namespace Microsoft.AspNetCore.Http
{
@@ -10,9 +11,63 @@ namespace Microsoft.AspNetCore.Http
///
public class WebSocketAcceptContext
{
+ private int _serverMaxWindowBits = 15;
+
///
/// Gets or sets the subprotocol being negotiated.
///
public virtual string? SubProtocol { get; set; }
+
+ ///
+ /// The interval to send pong frames. This is a heart-beat that keeps the connection alive.
+ ///
+ public virtual TimeSpan? KeepAliveInterval { get; set; }
+
+ ///
+ /// Enables support for the 'permessage-deflate' WebSocket extension.
+ /// Be aware that enabling compression over encrypted connections makes the application subject to CRIME/BREACH type attacks.
+ /// It is strongly advised to turn off compression when sending data containing secrets by
+ /// specifying when sending such messages.
+ ///
+ public bool DangerousEnableCompression { get; set; }
+
+ ///
+ /// Disables server context takeover when using compression.
+ /// This setting reduces the memory overhead of compression at the cost of a potentially worse compresson ratio.
+ ///
+ ///
+ /// This property does nothing when is false,
+ /// or when the client does not use compression.
+ ///
+ ///
+ /// false
+ ///
+ public bool DisableServerContextTakeover { get; set; }
+
+ ///
+ /// Sets the maximum base-2 logarithm of the LZ77 sliding window size that can be used for compression.
+ /// This setting reduces the memory overhead of compression at the cost of a potentially worse compresson ratio.
+ ///
+ ///
+ /// This property does nothing when is false,
+ /// or when the client does not use compression.
+ /// Valid values are 9 through 15.
+ ///
+ ///
+ /// 15
+ ///
+ public int ServerMaxWindowBits
+ {
+ get => _serverMaxWindowBits;
+ set
+ {
+ if (value < 9 || value > 15)
+ {
+ throw new ArgumentOutOfRangeException(nameof(ServerMaxWindowBits),
+ "The argument must be a value from 9 to 15.");
+ }
+ _serverMaxWindowBits = value;
+ }
+ }
}
}
diff --git a/src/Http/Http/src/Internal/DefaultWebSocketManager.cs b/src/Http/Http/src/Internal/DefaultWebSocketManager.cs
index 43ad4c0907ea..8697904b952c 100644
--- a/src/Http/Http/src/Internal/DefaultWebSocketManager.cs
+++ b/src/Http/Http/src/Internal/DefaultWebSocketManager.cs
@@ -17,6 +17,7 @@ internal sealed class DefaultWebSocketManager : WebSocketManager
private readonly static Func _nullWebSocketFeature = f => null;
private FeatureReferences _features;
+ private readonly static WebSocketAcceptContext _defaultWebSocketAcceptContext = new WebSocketAcceptContext();
public DefaultWebSocketManager(IFeatureCollection features)
{
@@ -61,12 +62,19 @@ public override IList WebSocketRequestedProtocols
}
public override Task AcceptWebSocketAsync(string? subProtocol)
+ {
+ var acceptContext = subProtocol is null ? _defaultWebSocketAcceptContext :
+ new WebSocketAcceptContext() { SubProtocol = subProtocol };
+ return AcceptWebSocketAsync(acceptContext);
+ }
+
+ public override Task AcceptWebSocketAsync(WebSocketAcceptContext acceptContext)
{
if (WebSocketFeature == null)
{
throw new NotSupportedException("WebSockets are not supported");
}
- return WebSocketFeature.AcceptAsync(new WebSocketAcceptContext() { SubProtocol = subProtocol });
+ return WebSocketFeature.AcceptAsync(acceptContext);
}
struct FeatureInterfaces
diff --git a/src/Middleware/WebSockets/src/ExtendedWebSocketAcceptContext.cs b/src/Middleware/WebSockets/src/ExtendedWebSocketAcceptContext.cs
index 7bfd7de4963a..2e85ef2d41f8 100644
--- a/src/Middleware/WebSockets/src/ExtendedWebSocketAcceptContext.cs
+++ b/src/Middleware/WebSockets/src/ExtendedWebSocketAcceptContext.cs
@@ -9,6 +9,7 @@ namespace Microsoft.AspNetCore.WebSockets
///
/// Extends the class with additional properties.
///
+ [Obsolete("This type is obsolete and will be removed in a future version. The recommended alternative is Microsoft.AspNetCore.Http.WebSocketAcceptContext.")]
public class ExtendedWebSocketAcceptContext : WebSocketAcceptContext
{
///
@@ -23,6 +24,6 @@ public class ExtendedWebSocketAcceptContext : WebSocketAcceptContext
///
/// The interval to send pong frames. This is a heart-beat that keeps the connection alive.
///
- public TimeSpan? KeepAliveInterval { get; set; }
+ public new TimeSpan? KeepAliveInterval { get; set; }
}
}
diff --git a/src/Middleware/WebSockets/src/HandshakeHelpers.cs b/src/Middleware/WebSockets/src/HandshakeHelpers.cs
index 05c5ac5363a3..a242a281154a 100644
--- a/src/Middleware/WebSockets/src/HandshakeHelpers.cs
+++ b/src/Middleware/WebSockets/src/HandshakeHelpers.cs
@@ -2,6 +2,10 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
+using System.Diagnostics;
+using System.Diagnostics.CodeAnalysis;
+using System.Globalization;
+using System.Net.WebSockets;
using System.Security.Cryptography;
using System.Text;
using Microsoft.AspNetCore.Http;
@@ -72,5 +76,210 @@ public static string CreateResponseKey(string requestKey)
return Convert.ToBase64String(hashedBytes);
}
+
+ // https://datatracker.ietf.org/doc/html/rfc7692#section-7.1
+ public static bool ParseDeflateOptions(ReadOnlySpan extension, bool serverContextTakeover,
+ int serverMaxWindowBits, out WebSocketDeflateOptions parsedOptions, [NotNullWhen(true)] out string? response)
+ {
+ bool hasServerMaxWindowBits = false;
+ bool hasClientMaxWindowBits = false;
+ bool hasClientNoContext = false;
+ bool hasServerNoContext = false;
+ response = null;
+ parsedOptions = new WebSocketDeflateOptions()
+ {
+ ServerContextTakeover = serverContextTakeover,
+ ServerMaxWindowBits = serverMaxWindowBits
+ };
+
+ using var builder = new ValueStringBuilder(WebSocketDeflateConstants.MaxExtensionLength);
+ builder.Append(WebSocketDeflateConstants.Extension);
+
+ while (true)
+ {
+ int end = extension.IndexOf(';');
+ ReadOnlySpan value = (end >= 0 ? extension[..end] : extension).Trim();
+
+ if (value.Length == 0)
+ {
+ break;
+ }
+
+ if (value.SequenceEqual(WebSocketDeflateConstants.ClientNoContextTakeover))
+ {
+ // https://datatracker.ietf.org/doc/html/rfc7692#section-7
+ // MUST decline if:
+ // The negotiation offer contains multiple extension parameters with
+ // the same name.
+ if (hasClientNoContext)
+ {
+ return false;
+ }
+
+ hasClientNoContext = true;
+ parsedOptions.ClientContextTakeover = false;
+ builder.Append(';');
+ builder.Append(' ');
+ builder.Append(WebSocketDeflateConstants.ClientNoContextTakeover);
+ }
+ else if (value.SequenceEqual(WebSocketDeflateConstants.ServerNoContextTakeover))
+ {
+ // https://datatracker.ietf.org/doc/html/rfc7692#section-7
+ // MUST decline if:
+ // The negotiation offer contains multiple extension parameters with
+ // the same name.
+ if (hasServerNoContext)
+ {
+ return false;
+ }
+
+ hasServerNoContext = true;
+ parsedOptions.ServerContextTakeover = false;
+ }
+ else if (value.StartsWith(WebSocketDeflateConstants.ClientMaxWindowBits))
+ {
+ // https://datatracker.ietf.org/doc/html/rfc7692#section-7
+ // MUST decline if:
+ // The negotiation offer contains multiple extension parameters with
+ // the same name.
+ if (hasClientMaxWindowBits)
+ {
+ return false;
+ }
+
+ hasClientMaxWindowBits = true;
+ if (!ParseWindowBits(value, WebSocketDeflateConstants.ClientMaxWindowBits, out var clientMaxWindowBits))
+ {
+ return false;
+ }
+
+ // 8 is a valid value according to the spec, but our zlib implementation does not support it
+ if (clientMaxWindowBits == 8)
+ {
+ return false;
+ }
+
+ // https://tools.ietf.org/html/rfc7692#section-7.1.2.2
+ // the server may either ignore this
+ // value or use this value to avoid allocating an unnecessarily big LZ77
+ // sliding window by including the "client_max_window_bits" extension
+ // parameter in the corresponding extension negotiation response to the
+ // offer with a value equal to or smaller than the received value.
+ parsedOptions.ClientMaxWindowBits = clientMaxWindowBits ?? 15;
+
+ // If a received extension negotiation offer doesn't have the
+ // "client_max_window_bits" extension parameter, the corresponding
+ // extension negotiation response to the offer MUST NOT include the
+ // "client_max_window_bits" extension parameter.
+ builder.Append(';');
+ builder.Append(' ');
+ builder.Append(WebSocketDeflateConstants.ClientMaxWindowBits);
+ builder.Append('=');
+ var len = (parsedOptions.ClientMaxWindowBits > 9) ? 2 : 1;
+ var span = builder.AppendSpan(len);
+ var ret = parsedOptions.ClientMaxWindowBits.TryFormat(span, out var written);
+ Debug.Assert(ret);
+ Debug.Assert(written == len);
+ }
+ else if (value.StartsWith(WebSocketDeflateConstants.ServerMaxWindowBits))
+ {
+ // https://datatracker.ietf.org/doc/html/rfc7692#section-7
+ // MUST decline if:
+ // The negotiation offer contains multiple extension parameters with
+ // the same name.
+ if (hasServerMaxWindowBits)
+ {
+ return false;
+ }
+
+ hasServerMaxWindowBits = true;
+ if (!ParseWindowBits(value, WebSocketDeflateConstants.ServerMaxWindowBits, out var parsedServerMaxWindowBits))
+ {
+ return false;
+ }
+
+ // 8 is a valid value according to the spec, but our zlib implementation does not support it
+ if (parsedServerMaxWindowBits == 8)
+ {
+ return false;
+ }
+
+ // https://tools.ietf.org/html/rfc7692#section-7.1.2.1
+ // A server accepts an extension negotiation offer with this parameter
+ // by including the "server_max_window_bits" extension parameter in the
+ // extension negotiation response to send back to the client with the
+ // same or smaller value as the offer.
+ parsedOptions.ServerMaxWindowBits = Math.Min(parsedServerMaxWindowBits ?? 15, serverMaxWindowBits);
+ }
+
+ static bool ParseWindowBits(ReadOnlySpan value, string propertyName, out int? parsedValue)
+ {
+ var startIndex = value.IndexOf('=');
+
+ // parameters can be sent without a value by the client, we'll use the values set by the app developer or the default of 15
+ if (startIndex < 0)
+ {
+ parsedValue = null;
+ return true;
+ }
+
+ value = value[(startIndex + 1)..].TrimEnd();
+
+ if (value.Length == 0)
+ {
+ parsedValue = null;
+ return false;
+ }
+
+ // https://datatracker.ietf.org/doc/html/rfc7692#section-5.2
+ // check for value in quotes and pull the value out without the quotes
+ if (value[0] == '"' && value.EndsWith("\"".AsSpan()) && value.Length > 1)
+ {
+ value = value[1..^1];
+ }
+
+ if (!int.TryParse(value, NumberStyles.Integer, CultureInfo.InvariantCulture, out int windowBits) ||
+ windowBits < 8 ||
+ windowBits > 15)
+ {
+ parsedValue = null;
+ return false;
+ }
+
+ parsedValue = windowBits;
+ return true;
+ }
+
+ if (end < 0)
+ {
+ break;
+ }
+ extension = extension[(end + 1)..];
+ }
+
+ if (!parsedOptions.ServerContextTakeover)
+ {
+ builder.Append(';');
+ builder.Append(' ');
+ builder.Append(WebSocketDeflateConstants.ServerNoContextTakeover);
+ }
+
+ if (hasServerMaxWindowBits || parsedOptions.ServerMaxWindowBits != 15)
+ {
+ builder.Append(';');
+ builder.Append(' ');
+ builder.Append(WebSocketDeflateConstants.ServerMaxWindowBits);
+ builder.Append('=');
+ var len = (parsedOptions.ServerMaxWindowBits > 9) ? 2 : 1;
+ var span = builder.AppendSpan(len);
+ var ret = parsedOptions.ServerMaxWindowBits.TryFormat(span, out var written);
+ Debug.Assert(ret);
+ Debug.Assert(written == len);
+ }
+
+ response = builder.ToString();
+
+ return true;
+ }
}
}
diff --git a/src/Middleware/WebSockets/src/Microsoft.AspNetCore.WebSockets.csproj b/src/Middleware/WebSockets/src/Microsoft.AspNetCore.WebSockets.csproj
index 931295880832..1443b5fae957 100644
--- a/src/Middleware/WebSockets/src/Microsoft.AspNetCore.WebSockets.csproj
+++ b/src/Middleware/WebSockets/src/Microsoft.AspNetCore.WebSockets.csproj
@@ -17,6 +17,10 @@
+
+
+
+
diff --git a/src/Middleware/WebSockets/src/WebSocketDeflateConstants.cs b/src/Middleware/WebSockets/src/WebSocketDeflateConstants.cs
new file mode 100644
index 000000000000..aff0a93dccf4
--- /dev/null
+++ b/src/Middleware/WebSockets/src/WebSocketDeflateConstants.cs
@@ -0,0 +1,23 @@
+// Copyright (c) .NET Foundation. All rights reserved.
+// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
+
+namespace Microsoft.AspNetCore.WebSockets
+{
+ internal static class WebSocketDeflateConstants
+ {
+ ///
+ /// The maximum length that this extension can have, assuming that we're not using extra white space.
+ ///
+ /// "permessage-deflate; client_max_window_bits=15; client_no_context_takeover; server_max_window_bits=15; server_no_context_takeover"
+ ///
+ public const int MaxExtensionLength = 128;
+
+ public const string Extension = "permessage-deflate";
+
+ public const string ClientMaxWindowBits = "client_max_window_bits";
+ public const string ClientNoContextTakeover = "client_no_context_takeover";
+
+ public const string ServerMaxWindowBits = "server_max_window_bits";
+ public const string ServerNoContextTakeover = "server_no_context_takeover";
+ }
+}
diff --git a/src/Middleware/WebSockets/src/WebSocketMiddleware.cs b/src/Middleware/WebSockets/src/WebSocketMiddleware.cs
index 72fbc077b33f..c640ed4b368d 100644
--- a/src/Middleware/WebSockets/src/WebSocketMiddleware.cs
+++ b/src/Middleware/WebSockets/src/WebSocketMiddleware.cs
@@ -68,7 +68,7 @@ public Task Invoke(HttpContext context)
var upgradeFeature = context.Features.Get();
if (upgradeFeature != null && context.Features.Get() == null)
{
- var webSocketFeature = new UpgradeHandshake(context, upgradeFeature, _options);
+ var webSocketFeature = new UpgradeHandshake(context, upgradeFeature, _options, _logger);
context.Features.Set(webSocketFeature);
if (!_anyOriginAllowed)
@@ -97,13 +97,15 @@ private class UpgradeHandshake : IHttpWebSocketFeature
private readonly HttpContext _context;
private readonly IHttpUpgradeFeature _upgradeFeature;
private readonly WebSocketOptions _options;
+ private readonly ILogger _logger;
private bool? _isWebSocketRequest;
- public UpgradeHandshake(HttpContext context, IHttpUpgradeFeature upgradeFeature, WebSocketOptions options)
+ public UpgradeHandshake(HttpContext context, IHttpUpgradeFeature upgradeFeature, WebSocketOptions options, ILogger logger)
{
_context = context;
_upgradeFeature = upgradeFeature;
_options = options;
+ _logger = logger;
}
public bool IsWebSocketRequest
@@ -133,13 +135,22 @@ public async Task AcceptAsync(WebSocketAcceptContext acceptContext)
}
string? subProtocol = null;
+ bool enableCompression = false;
+ bool serverContextTakeover = true;
+ int serverMaxWindowBits = 15;
+ TimeSpan keepAliveInterval = _options.KeepAliveInterval;
if (acceptContext != null)
{
subProtocol = acceptContext.SubProtocol;
+ enableCompression = acceptContext.DangerousEnableCompression;
+ serverContextTakeover = !acceptContext.DisableServerContextTakeover;
+ serverMaxWindowBits = acceptContext.ServerMaxWindowBits;
+ keepAliveInterval = acceptContext.KeepAliveInterval ?? keepAliveInterval;
}
- TimeSpan keepAliveInterval = _options.KeepAliveInterval;
+#pragma warning disable CS0618 // Type or member is obsolete
if (acceptContext is ExtendedWebSocketAcceptContext advancedAcceptContext)
+#pragma warning restore CS0618 // Type or member is obsolete
{
if (advancedAcceptContext.KeepAliveInterval.HasValue)
{
@@ -151,9 +162,45 @@ public async Task AcceptAsync(WebSocketAcceptContext acceptContext)
HandshakeHelpers.GenerateResponseHeaders(key, subProtocol, _context.Response.Headers);
+ WebSocketDeflateOptions? deflateOptions = null;
+ if (enableCompression)
+ {
+ var ext = _context.Request.Headers.SecWebSocketExtensions;
+ if (ext.Count != 0)
+ {
+ // loop over each extension offer, extensions can have multiple offers, we can accept any
+ foreach (var extension in _context.Request.Headers.GetCommaSeparatedValues(HeaderNames.SecWebSocketExtensions))
+ {
+ if (extension.AsSpan().TrimStart().StartsWith("permessage-deflate", StringComparison.Ordinal))
+ {
+ if (HandshakeHelpers.ParseDeflateOptions(extension.AsSpan().TrimStart(), serverContextTakeover, serverMaxWindowBits, out var parsedOptions, out var response))
+ {
+ Log.CompressionAccepted(_logger, response);
+ deflateOptions = parsedOptions;
+ // If more extension types are added, this would need to be a header append
+ // and we wouldn't want to break out of the loop
+ _context.Response.Headers.SecWebSocketExtensions = response;
+ break;
+ }
+ }
+ }
+
+ if (deflateOptions is null)
+ {
+ Log.CompressionNotAccepted(_logger);
+ }
+ }
+ }
+
Stream opaqueTransport = await _upgradeFeature.UpgradeAsync(); // Sets status code to 101
- return WebSocket.CreateFromStream(opaqueTransport, isServer: true, subProtocol: subProtocol, keepAliveInterval: keepAliveInterval);
+ return WebSocket.CreateFromStream(opaqueTransport, new WebSocketCreationOptions()
+ {
+ IsServer = true,
+ KeepAliveInterval = keepAliveInterval,
+ SubProtocol = subProtocol,
+ DangerousDeflateOptions = deflateOptions
+ });
}
public static bool CheckSupportedWebSocketRequest(string method, IHeaderDictionary requestHeaders)
@@ -227,5 +274,24 @@ public static bool CheckSupportedWebSocketRequest(string method, IHeaderDictiona
return HandshakeHelpers.IsRequestKeyValid(requestHeaders.SecWebSocketKey.ToString());
}
}
+
+ private static class Log
+ {
+ private static readonly Action _compressionAccepted =
+ LoggerMessage.Define(LogLevel.Debug, new EventId(1, "CompressionAccepted"), "WebSocket compression negotiation accepted with values '{CompressionResponse}'.");
+
+ private static readonly Action _compressionNotAccepted =
+ LoggerMessage.Define(LogLevel.Debug, new EventId(2, "CompressionNotAccepted"), "Compression negotiation not accepted by server.");
+
+ public static void CompressionAccepted(ILogger logger, string response)
+ {
+ _compressionAccepted(logger, response, null);
+ }
+
+ public static void CompressionNotAccepted(ILogger logger)
+ {
+ _compressionNotAccepted(logger, null);
+ }
+ }
}
}
diff --git a/src/Middleware/WebSockets/test/ConformanceTests/AutobahnTestApp/Startup.cs b/src/Middleware/WebSockets/test/ConformanceTests/AutobahnTestApp/Startup.cs
index f8b75268c28b..c351ae285e6f 100644
--- a/src/Middleware/WebSockets/test/ConformanceTests/AutobahnTestApp/Startup.cs
+++ b/src/Middleware/WebSockets/test/ConformanceTests/AutobahnTestApp/Startup.cs
@@ -22,7 +22,10 @@ public void Configure(IApplicationBuilder app, ILoggerFactory loggerFactory)
if (context.WebSockets.IsWebSocketRequest)
{
logger.LogInformation("Received WebSocket request");
- using (var webSocket = await context.WebSockets.AcceptWebSocketAsync())
+ using (var webSocket = await context.WebSockets.AcceptWebSocketAsync(new WebSocketAcceptContext()
+ {
+ DangerousEnableCompression = true
+ }))
{
await Echo(webSocket, context.RequestAborted);
}
diff --git a/src/Middleware/WebSockets/test/UnitTests/HandshakeTests.cs b/src/Middleware/WebSockets/test/UnitTests/HandshakeTests.cs
index ec19793ddc2e..163fc71db4c3 100644
--- a/src/Middleware/WebSockets/test/UnitTests/HandshakeTests.cs
+++ b/src/Middleware/WebSockets/test/UnitTests/HandshakeTests.cs
@@ -1,6 +1,7 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
+using System;
using Xunit;
namespace Microsoft.AspNetCore.WebSockets.Tests
@@ -37,5 +38,62 @@ public void RejectsInvalidRequestKeys(string key)
{
Assert.False(HandshakeHelpers.IsRequestKeyValid(key));
}
+
+ [Theory]
+ [InlineData("permessage-deflate", "permessage-deflate")]
+ [InlineData("permessage-deflate; server_no_context_takeover", "permessage-deflate; server_no_context_takeover")]
+ [InlineData("permessage-deflate; client_no_context_takeover", "permessage-deflate; client_no_context_takeover")]
+ [InlineData("permessage-deflate; client_max_window_bits=9", "permessage-deflate; client_max_window_bits=9")]
+ [InlineData("permessage-deflate; client_max_window_bits=\"9\"", "permessage-deflate; client_max_window_bits=9")]
+ [InlineData("permessage-deflate; client_max_window_bits", "permessage-deflate; client_max_window_bits=15")]
+ [InlineData("permessage-deflate; server_max_window_bits", "permessage-deflate; server_max_window_bits=15")]
+ [InlineData("permessage-deflate; server_max_window_bits=10", "permessage-deflate; server_max_window_bits=10")]
+ [InlineData("permessage-deflate; server_max_window_bits=10; server_no_context_takeover", "permessage-deflate; server_no_context_takeover; server_max_window_bits=10")]
+ [InlineData("permessage-deflate; server_max_window_bits=10; server_no_context_takeover; client_no_context_takeover; client_max_window_bits=12", "permessage-deflate; client_no_context_takeover; client_max_window_bits=12; server_no_context_takeover; server_max_window_bits=10")]
+ public void CompressionNegotiationProducesCorrectHeaderWithDefaultOptions(string clientHeader, string expectedResponse)
+ {
+ Assert.True(HandshakeHelpers.ParseDeflateOptions(clientHeader.AsSpan(), serverContextTakeover: true, serverMaxWindowBits: 15,
+ out var _, out var response));
+ Assert.Equal(expectedResponse, response);
+ }
+
+ [Theory]
+ [InlineData("permessage-deflate", "permessage-deflate; server_no_context_takeover; server_max_window_bits=14")]
+ [InlineData("permessage-deflate; server_no_context_takeover", "permessage-deflate; server_no_context_takeover; server_max_window_bits=14")]
+ [InlineData("permessage-deflate; client_no_context_takeover", "permessage-deflate; client_no_context_takeover; server_no_context_takeover; server_max_window_bits=14")]
+ [InlineData("permessage-deflate; client_max_window_bits=9", "permessage-deflate; client_max_window_bits=9; server_no_context_takeover; server_max_window_bits=14")]
+ [InlineData("permessage-deflate; client_max_window_bits", "permessage-deflate; client_max_window_bits=15; server_no_context_takeover; server_max_window_bits=14")]
+ [InlineData("permessage-deflate; server_max_window_bits", "permessage-deflate; server_no_context_takeover; server_max_window_bits=14")]
+ [InlineData("permessage-deflate; server_max_window_bits=14", "permessage-deflate; server_no_context_takeover; server_max_window_bits=14")]
+ [InlineData("permessage-deflate; server_max_window_bits=10", "permessage-deflate; server_no_context_takeover; server_max_window_bits=10")]
+ [InlineData("permessage-deflate; server_max_window_bits=10; server_no_context_takeover", "permessage-deflate; server_no_context_takeover; server_max_window_bits=10")]
+ [InlineData("permessage-deflate; server_max_window_bits=10; client_no_context_takeover; client_max_window_bits=12", "permessage-deflate; client_no_context_takeover; client_max_window_bits=12; server_no_context_takeover; server_max_window_bits=10")]
+ public void CompressionNegotiationProducesCorrectHeaderWithCustomOptions(string clientHeader, string expectedResponse)
+ {
+ Assert.True(HandshakeHelpers.ParseDeflateOptions(clientHeader.AsSpan(), serverContextTakeover: false, serverMaxWindowBits: 14,
+ out var _, out var response));
+ Assert.Equal(expectedResponse, response);
+ }
+
+ [Theory]
+ [InlineData("permessage-deflate; server_max_window_bits=8")]
+ [InlineData("permessage-deflate; client_max_window_bits=8")]
+ [InlineData("permessage-deflate; server_max_window_bits=16")]
+ [InlineData("permessage-deflate; client_max_window_bits=16")]
+ [InlineData("permessage-deflate; client_max_window_bits=\"15")]
+ [InlineData("permessage-deflate; client_max_window_bits=14\"")]
+ [InlineData("permessage-deflate; client_max_window_bits=\"")]
+ [InlineData("permessage-deflate; client_max_window_bits=\"13")]
+ [InlineData("permessage-deflate; client_max_window_bits=")]
+ [InlineData("permessage-deflate; client_max_window_bits=\"\"")]
+ [InlineData("permessage-deflate; client_max_window_bits=14; client_max_window_bits=14")]
+ [InlineData("permessage-deflate; server_max_window_bits=14; server_max_window_bits=14")]
+ [InlineData("permessage-deflate; server_no_context_takeover; server_no_context_takeover")]
+ [InlineData("permessage-deflate; client_no_context_takeover; client_no_context_takeover")]
+ public void CompressionNegotiateNotAccepted(string clientHeader)
+ {
+ Assert.False(HandshakeHelpers.ParseDeflateOptions(clientHeader.AsSpan(), serverContextTakeover: true, serverMaxWindowBits: 15,
+ out var _, out var response));
+ }
}
}
diff --git a/src/Middleware/WebSockets/test/UnitTests/WebSocketCompressionMiddlewareTests.cs b/src/Middleware/WebSockets/test/UnitTests/WebSocketCompressionMiddlewareTests.cs
new file mode 100644
index 000000000000..3803e52f9808
--- /dev/null
+++ b/src/Middleware/WebSockets/test/UnitTests/WebSocketCompressionMiddlewareTests.cs
@@ -0,0 +1,184 @@
+// Copyright (c) .NET Foundation. All rights reserved.
+// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
+
+using System;
+using System.Linq;
+using System.Net;
+using System.Net.Http;
+using System.Net.WebSockets;
+using System.Threading;
+using System.Threading.Tasks;
+using Microsoft.AspNetCore.Http;
+using Microsoft.AspNetCore.Testing;
+using Microsoft.Net.Http.Headers;
+using Xunit;
+
+namespace Microsoft.AspNetCore.WebSockets.Test
+{
+ public class WebSocketCompressionMiddlewareTests : LoggedTest
+ {
+ [Fact]
+ public async Task CompressionNegotiationServerCanChooseSevrverNoContextTakeover()
+ {
+ await using (var server = KestrelWebSocketHelpers.CreateServer(LoggerFactory, out var port, async context =>
+ {
+ Assert.True(context.WebSockets.IsWebSocketRequest);
+ var webSocket = await context.WebSockets.AcceptWebSocketAsync(new WebSocketAcceptContext()
+ {
+ DangerousEnableCompression = true,
+ DisableServerContextTakeover = true
+ });
+ }))
+ {
+ using (var client = new HttpClient())
+ {
+ var uri = new UriBuilder(new Uri($"ws://127.0.0.1:{port}/"));
+ uri.Scheme = "http";
+
+ // Craft a valid WebSocket Upgrade request
+ using (var request = new HttpRequestMessage(HttpMethod.Get, uri.ToString()))
+ {
+ SetGenericWebSocketRequest(request);
+ request.Headers.Add(HeaderNames.SecWebSocketExtensions, "permessage-deflate");
+
+ var response = await client.SendAsync(request);
+ Assert.Equal(HttpStatusCode.SwitchingProtocols, response.StatusCode);
+ Assert.Equal("permessage-deflate; server_no_context_takeover", response.Headers.GetValues(HeaderNames.SecWebSocketExtensions).Aggregate((l, r) => $"{l}; {r}"));
+ }
+ }
+ }
+ }
+
+ [Fact]
+ public async Task CompressionNegotiationIgnoredIfNotEnabledOnServer()
+ {
+ await using (var server = KestrelWebSocketHelpers.CreateServer(LoggerFactory, out var port, async context =>
+ {
+ Assert.True(context.WebSockets.IsWebSocketRequest);
+ var webSocket = await context.WebSockets.AcceptWebSocketAsync();
+ }))
+ {
+ using (var client = new HttpClient())
+ {
+ var uri = new UriBuilder(new Uri($"ws://127.0.0.1:{port}/"));
+ uri.Scheme = "http";
+
+ // Craft a valid WebSocket Upgrade request
+ using (var request = new HttpRequestMessage(HttpMethod.Get, uri.ToString()))
+ {
+ SetGenericWebSocketRequest(request);
+ request.Headers.Add(HeaderNames.SecWebSocketExtensions, "permessage-deflate");
+
+ var response = await client.SendAsync(request);
+ Assert.Equal(HttpStatusCode.SwitchingProtocols, response.StatusCode);
+ Assert.False(response.Headers.Contains(HeaderNames.SecWebSocketExtensions));
+ }
+ }
+ }
+ }
+
+ [Theory]
+ [InlineData("permessage-deflate; server_max_window_bits=14, permessage-deflate; server_max_window_bits=13", "permessage-deflate; server_max_window_bits=13")]
+ [InlineData("permessage-deflate; client_max_window_bits=8, permessage-deflate; client_max_window_bits=13", "permessage-deflate; client_max_window_bits=13; server_max_window_bits=13")]
+ public async Task CompressionNegotiationCanChooseExtension(string clientHeader, string expectedResponse)
+ {
+ await using (var server = KestrelWebSocketHelpers.CreateServer(LoggerFactory, out var port, async context =>
+ {
+ Assert.True(context.WebSockets.IsWebSocketRequest);
+ var webSocket = await context.WebSockets.AcceptWebSocketAsync(new WebSocketAcceptContext()
+ {
+ DangerousEnableCompression = true,
+ ServerMaxWindowBits = 13
+ });
+ }))
+ {
+ using (var client = new HttpClient())
+ {
+ var uri = new UriBuilder(new Uri($"ws://127.0.0.1:{port}/"));
+ uri.Scheme = "http";
+
+ // Craft a valid WebSocket Upgrade request
+ using (var request = new HttpRequestMessage(HttpMethod.Get, uri.ToString()))
+ {
+ SetGenericWebSocketRequest(request);
+ request.Headers.Add(HeaderNames.SecWebSocketExtensions, clientHeader);
+
+ var response = await client.SendAsync(request);
+ Assert.Equal(HttpStatusCode.SwitchingProtocols, response.StatusCode);
+ Assert.Equal(expectedResponse, response.Headers.GetValues(HeaderNames.SecWebSocketExtensions).Aggregate((l, r) => $"{l}; {r}"));
+ }
+ }
+ }
+ }
+
+ // Smoke test that compression works, we aren't responsible for the specifics of the compression frames
+ [Fact]
+ public async Task CanSendAndReceiveCompressedData()
+ {
+ await using (var server = KestrelWebSocketHelpers.CreateServer(LoggerFactory, out var port, async context =>
+ {
+ Assert.True(context.WebSockets.IsWebSocketRequest);
+ using var webSocket = await context.WebSockets.AcceptWebSocketAsync(new WebSocketAcceptContext()
+ {
+ DangerousEnableCompression = true,
+ ServerMaxWindowBits = 13
+ });
+
+ var serverBuffer = new byte[1024];
+ while (true)
+ {
+ var result = await webSocket.ReceiveAsync(serverBuffer, CancellationToken.None);
+ if (result.MessageType == WebSocketMessageType.Close)
+ {
+ break;
+ }
+ await webSocket.SendAsync(serverBuffer.AsMemory(0, result.Count), result.MessageType, result.EndOfMessage, default);
+ }
+ await webSocket.CloseAsync(WebSocketCloseStatus.NormalClosure, null, default);
+ }))
+ {
+ using (var client = new ClientWebSocket())
+ {
+ client.Options.DangerousDeflateOptions = new WebSocketDeflateOptions()
+ {
+ ServerMaxWindowBits = 12,
+ ClientMaxWindowBits = 11,
+ };
+ await client.ConnectAsync(new Uri($"ws://127.0.0.1:{port}/"), CancellationToken.None);
+ var sendCount = 8193;
+ var clientBuf = new byte[sendCount];
+ var receiveBuf = new byte[sendCount];
+ Random.Shared.NextBytes(clientBuf);
+ await client.SendAsync(clientBuf.AsMemory(0, sendCount), WebSocketMessageType.Binary, true, default);
+ var totalRecv = 0;
+ while (totalRecv < sendCount)
+ {
+ var result = await client.ReceiveAsync(receiveBuf.AsMemory(totalRecv), default);
+ totalRecv += result.Count;
+ if (result.EndOfMessage)
+ {
+ Assert.Equal(sendCount, totalRecv);
+ for (var i = 0; i < sendCount; ++i)
+ {
+ Assert.True(clientBuf[i] == receiveBuf[i], $"offset {i} not equal: {clientBuf[i]} == {receiveBuf[i]}");
+ }
+ }
+ }
+
+ await client.CloseAsync(WebSocketCloseStatus.NormalClosure, null, default);
+ }
+ }
+ }
+
+ private static void SetGenericWebSocketRequest(HttpRequestMessage request)
+ {
+ request.Headers.Connection.Clear();
+ request.Headers.Connection.Add("Upgrade");
+ request.Headers.Connection.Add("keep-alive");
+ request.Headers.Upgrade.Add(new System.Net.Http.Headers.ProductHeaderValue("websocket"));
+ request.Headers.Add(HeaderNames.SecWebSocketVersion, "13");
+ // SecWebSocketKey required to be 16 bytes
+ request.Headers.Add(HeaderNames.SecWebSocketKey, Convert.ToBase64String(new byte[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16 }, Base64FormattingOptions.None));
+ }
+ }
+}
diff --git a/src/Middleware/WebSockets/test/UnitTests/WebSocketMiddlewareTests.cs b/src/Middleware/WebSockets/test/UnitTests/WebSocketMiddlewareTests.cs
index 1fb5686a909b..4fe10f556e4d 100644
--- a/src/Middleware/WebSockets/test/UnitTests/WebSocketMiddlewareTests.cs
+++ b/src/Middleware/WebSockets/test/UnitTests/WebSocketMiddlewareTests.cs
@@ -36,7 +36,7 @@ public async Task Connect_Success()
[Fact]
public async Task NegotiateSubProtocol_Success()
{
- await using(var server = KestrelWebSocketHelpers.CreateServer(LoggerFactory, out var port, async context =>
+ await using (var server = KestrelWebSocketHelpers.CreateServer(LoggerFactory, out var port, async context =>
{
Assert.True(context.WebSockets.IsWebSocketRequest);
Assert.Equal("alpha, bravo, charlie", context.Request.Headers["Sec-WebSocket-Protocol"]);
@@ -64,7 +64,7 @@ public async Task NegotiateSubProtocol_Success()
[Fact]
public async Task SendEmptyData_Success()
{
- await using(var server = KestrelWebSocketHelpers.CreateServer(LoggerFactory, out var port, async context =>
+ await using (var server = KestrelWebSocketHelpers.CreateServer(LoggerFactory, out var port, async context =>
{
Assert.True(context.WebSockets.IsWebSocketRequest);
var webSocket = await context.WebSockets.AcceptWebSocketAsync();
@@ -89,7 +89,7 @@ public async Task SendEmptyData_Success()
public async Task SendShortData_Success()
{
var orriginalData = Encoding.UTF8.GetBytes("Hello World");
- await using(var server = KestrelWebSocketHelpers.CreateServer(LoggerFactory, out var port, async context =>
+ await using (var server = KestrelWebSocketHelpers.CreateServer(LoggerFactory, out var port, async context =>
{
Assert.True(context.WebSockets.IsWebSocketRequest);
var webSocket = await context.WebSockets.AcceptWebSocketAsync();
@@ -114,7 +114,7 @@ public async Task SendShortData_Success()
public async Task SendMediumData_Success()
{
var orriginalData = Encoding.UTF8.GetBytes(new string('a', 130));
- await using(var server = KestrelWebSocketHelpers.CreateServer(LoggerFactory, out var port, async context =>
+ await using (var server = KestrelWebSocketHelpers.CreateServer(LoggerFactory, out var port, async context =>
{
Assert.True(context.WebSockets.IsWebSocketRequest);
var webSocket = await context.WebSockets.AcceptWebSocketAsync();
@@ -140,7 +140,7 @@ public async Task SendLongData_Success()
{
var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
var orriginalData = Encoding.UTF8.GetBytes(new string('a', 0x1FFFF));
- await using(var server = KestrelWebSocketHelpers.CreateServer(LoggerFactory, out var port, async context =>
+ await using (var server = KestrelWebSocketHelpers.CreateServer(LoggerFactory, out var port, async context =>
{
Assert.True(context.WebSockets.IsWebSocketRequest);
var webSocket = await context.WebSockets.AcceptWebSocketAsync();
@@ -171,7 +171,7 @@ public async Task SendFragmentedData_Success()
{
var orriginalData = Encoding.UTF8.GetBytes("Hello World");
var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
- await using(var server = KestrelWebSocketHelpers.CreateServer(LoggerFactory, out var port, async context =>
+ await using (var server = KestrelWebSocketHelpers.CreateServer(LoggerFactory, out var port, async context =>
{
Assert.True(context.WebSockets.IsWebSocketRequest);
var webSocket = await context.WebSockets.AcceptWebSocketAsync();
@@ -220,7 +220,7 @@ public async Task SendFragmentedData_Success()
public async Task ReceiveShortData_Success()
{
var orriginalData = Encoding.UTF8.GetBytes("Hello World");
- await using(var server = KestrelWebSocketHelpers.CreateServer(LoggerFactory, out var port, async context =>
+ await using (var server = KestrelWebSocketHelpers.CreateServer(LoggerFactory, out var port, async context =>
{
Assert.True(context.WebSockets.IsWebSocketRequest);
var webSocket = await context.WebSockets.AcceptWebSocketAsync();
@@ -245,7 +245,7 @@ public async Task ReceiveShortData_Success()
public async Task ReceiveMediumData_Success()
{
var orriginalData = Encoding.UTF8.GetBytes(new string('a', 130));
- await using(var server = KestrelWebSocketHelpers.CreateServer(LoggerFactory, out var port, async context =>
+ await using (var server = KestrelWebSocketHelpers.CreateServer(LoggerFactory, out var port, async context =>
{
Assert.True(context.WebSockets.IsWebSocketRequest);
var webSocket = await context.WebSockets.AcceptWebSocketAsync();
@@ -270,7 +270,7 @@ public async Task ReceiveMediumData_Success()
public async Task ReceiveLongData()
{
var orriginalData = Encoding.UTF8.GetBytes(new string('a', 0x1FFFF));
- await using(var server = KestrelWebSocketHelpers.CreateServer(LoggerFactory, out var port, async context =>
+ await using (var server = KestrelWebSocketHelpers.CreateServer(LoggerFactory, out var port, async context =>
{
Assert.True(context.WebSockets.IsWebSocketRequest);
var webSocket = await context.WebSockets.AcceptWebSocketAsync();
@@ -303,7 +303,7 @@ public async Task ReceiveLongData()
public async Task ReceiveFragmentedData_Success()
{
var orriginalData = Encoding.UTF8.GetBytes("Hello World");
- await using(var server = KestrelWebSocketHelpers.CreateServer(LoggerFactory, out var port, async context =>
+ await using (var server = KestrelWebSocketHelpers.CreateServer(LoggerFactory, out var port, async context =>
{
Assert.True(context.WebSockets.IsWebSocketRequest);
var webSocket = await context.WebSockets.AcceptWebSocketAsync();
@@ -346,7 +346,7 @@ public async Task ReceiveFragmentedData_Success()
public async Task SendClose_Success()
{
string closeDescription = "Test Closed";
- await using(var server = KestrelWebSocketHelpers.CreateServer(LoggerFactory, out var port, async context =>
+ await using (var server = KestrelWebSocketHelpers.CreateServer(LoggerFactory, out var port, async context =>
{
Assert.True(context.WebSockets.IsWebSocketRequest);
var webSocket = await context.WebSockets.AcceptWebSocketAsync();
@@ -374,7 +374,7 @@ public async Task SendClose_Success()
public async Task ReceiveClose_Success()
{
string closeDescription = "Test Closed";
- await using(var server = KestrelWebSocketHelpers.CreateServer(LoggerFactory, out var port, async context =>
+ await using (var server = KestrelWebSocketHelpers.CreateServer(LoggerFactory, out var port, async context =>
{
Assert.True(context.WebSockets.IsWebSocketRequest);
var webSocket = await context.WebSockets.AcceptWebSocketAsync();
@@ -402,7 +402,7 @@ public async Task ReceiveClose_Success()
public async Task CloseFromOpen_Success()
{
string closeDescription = "Test Closed";
- await using(var server = KestrelWebSocketHelpers.CreateServer(LoggerFactory, out var port, async context =>
+ await using (var server = KestrelWebSocketHelpers.CreateServer(LoggerFactory, out var port, async context =>
{
Assert.True(context.WebSockets.IsWebSocketRequest);
var webSocket = await context.WebSockets.AcceptWebSocketAsync();
@@ -432,7 +432,7 @@ public async Task CloseFromOpen_Success()
public async Task CloseFromCloseSent_Success()
{
string closeDescription = "Test Closed";
- await using(var server = KestrelWebSocketHelpers.CreateServer(LoggerFactory, out var port, async context =>
+ await using (var server = KestrelWebSocketHelpers.CreateServer(LoggerFactory, out var port, async context =>
{
Assert.True(context.WebSockets.IsWebSocketRequest);
var webSocket = await context.WebSockets.AcceptWebSocketAsync();
@@ -464,7 +464,7 @@ public async Task CloseFromCloseSent_Success()
public async Task CloseFromCloseReceived_Success()
{
string closeDescription = "Test Closed";
- await using(var server = KestrelWebSocketHelpers.CreateServer(LoggerFactory, out var port, async context =>
+ await using (var server = KestrelWebSocketHelpers.CreateServer(LoggerFactory, out var port, async context =>
{
Assert.True(context.WebSockets.IsWebSocketRequest);
var webSocket = await context.WebSockets.AcceptWebSocketAsync();
@@ -510,7 +510,7 @@ public async Task CloseFromCloseReceived_Success()
[InlineData(HttpStatusCode.OK, "http://ExAmPLE.cOm")]
public async Task OriginIsValidatedForWebSocketRequests(HttpStatusCode expectedCode, params string[] origins)
{
- await using(var server = KestrelWebSocketHelpers.CreateServer(LoggerFactory, out var port, context =>
+ await using (var server = KestrelWebSocketHelpers.CreateServer(LoggerFactory, out var port, context =>
{
Assert.True(context.WebSockets.IsWebSocketRequest);
return Task.CompletedTask;
@@ -553,7 +553,7 @@ public async Task OriginIsValidatedForWebSocketRequests(HttpStatusCode expectedC
[Fact]
public async Task OriginIsNotValidatedForNonWebSocketRequests()
{
- await using(var server = KestrelWebSocketHelpers.CreateServer(LoggerFactory, out var port, context =>
+ await using (var server = KestrelWebSocketHelpers.CreateServer(LoggerFactory, out var port, context =>
{
Assert.False(context.WebSockets.IsWebSocketRequest);
return Task.CompletedTask;
@@ -579,7 +579,7 @@ public async Task OriginIsNotValidatedForNonWebSocketRequests()
[Fact]
public async Task CommonHeadersAreSetToInternedStrings()
{
- await using(var server = KestrelWebSocketHelpers.CreateServer(LoggerFactory, out var port, async context =>
+ await using (var server = KestrelWebSocketHelpers.CreateServer(LoggerFactory, out var port, async context =>
{
Assert.True(context.WebSockets.IsWebSocketRequest);
var webSocket = await context.WebSockets.AcceptWebSocketAsync();
@@ -600,7 +600,7 @@ public async Task CommonHeadersAreSetToInternedStrings()
[Fact]
public async Task MultipleValueHeadersNotOverridden()
{
- await using(var server = KestrelWebSocketHelpers.CreateServer(LoggerFactory, out var port, async context =>
+ await using (var server = KestrelWebSocketHelpers.CreateServer(LoggerFactory, out var port, async context =>
{
Assert.True(context.WebSockets.IsWebSocketRequest);
var webSocket = await context.WebSockets.AcceptWebSocketAsync();
diff --git a/src/Servers/Kestrel/Core/src/Internal/Http/HttpHeaders.Generated.cs b/src/Servers/Kestrel/Core/src/Internal/Http/HttpHeaders.Generated.cs
index 5760be03b633..fa499588ba0e 100644
--- a/src/Servers/Kestrel/Core/src/Internal/Http/HttpHeaders.Generated.cs
+++ b/src/Servers/Kestrel/Core/src/Internal/Http/HttpHeaders.Generated.cs
@@ -101,7 +101,7 @@ internal enum KnownHeaderType
internal partial class HttpHeaders
{
- private readonly static HashSet _internedHeaderNames = new HashSet(95, StringComparer.OrdinalIgnoreCase)
+ private readonly static HashSet _internedHeaderNames = new HashSet(96, StringComparer.OrdinalIgnoreCase)
{
HeaderNames.Accept,
HeaderNames.AcceptCharset,
@@ -174,6 +174,7 @@ internal partial class HttpHeaders
HeaderNames.SecWebSocketKey,
HeaderNames.SecWebSocketProtocol,
HeaderNames.SecWebSocketVersion,
+ HeaderNames.SecWebSocketExtensions,
HeaderNames.Server,
HeaderNames.SetCookie,
HeaderNames.Status,
@@ -2127,6 +2128,24 @@ StringValues IHeaderDictionary.SecWebSocketVersion
SetValueUnknown(HeaderNames.SecWebSocketVersion, value);
}
}
+ StringValues IHeaderDictionary.SecWebSocketExtensions
+ {
+ get
+ {
+ StringValues value = default;
+ if (!TryGetUnknown(HeaderNames.SecWebSocketExtensions, ref value))
+ {
+ value = default;
+ }
+ return value;
+ }
+ set
+ {
+ if (_isReadOnly) { ThrowHeadersReadOnlyException(); }
+
+ SetValueUnknown(HeaderNames.SecWebSocketExtensions, value);
+ }
+ }
StringValues IHeaderDictionary.Server
{
get
@@ -10012,6 +10031,24 @@ StringValues IHeaderDictionary.SecWebSocketVersion
SetValueUnknown(HeaderNames.SecWebSocketVersion, value);
}
}
+ StringValues IHeaderDictionary.SecWebSocketExtensions
+ {
+ get
+ {
+ StringValues value = default;
+ if (!TryGetUnknown(HeaderNames.SecWebSocketExtensions, ref value))
+ {
+ value = default;
+ }
+ return value;
+ }
+ set
+ {
+ if (_isReadOnly) { ThrowHeadersReadOnlyException(); }
+
+ SetValueUnknown(HeaderNames.SecWebSocketExtensions, value);
+ }
+ }
StringValues IHeaderDictionary.StrictTransportSecurity
{
get
@@ -16387,6 +16424,24 @@ StringValues IHeaderDictionary.SecWebSocketVersion
SetValueUnknown(HeaderNames.SecWebSocketVersion, value);
}
}
+ StringValues IHeaderDictionary.SecWebSocketExtensions
+ {
+ get
+ {
+ StringValues value = default;
+ if (!TryGetUnknown(HeaderNames.SecWebSocketExtensions, ref value))
+ {
+ value = default;
+ }
+ return value;
+ }
+ set
+ {
+ if (_isReadOnly) { ThrowHeadersReadOnlyException(); }
+
+ SetValueUnknown(HeaderNames.SecWebSocketExtensions, value);
+ }
+ }
StringValues IHeaderDictionary.Server
{
get
diff --git a/src/Shared/ValueStringBuilder/ValueStringBuilder.cs b/src/Shared/ValueStringBuilder/ValueStringBuilder.cs
new file mode 100644
index 000000000000..90c65f9cfd72
--- /dev/null
+++ b/src/Shared/ValueStringBuilder/ValueStringBuilder.cs
@@ -0,0 +1,314 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System.Buffers;
+using System.Diagnostics;
+using System.Runtime.CompilerServices;
+using System.Runtime.InteropServices;
+
+namespace System.Text
+{
+ // Copied from https://github.com/dotnet/runtime/blob/a9c5eadd951dcba73167f72cc624eb790573663a/src/libraries/Common/src/System/Text/ValueStringBuilder.cs
+ internal ref partial struct ValueStringBuilder
+ {
+ private char[]? _arrayToReturnToPool;
+ private Span _chars;
+ private int _pos;
+
+ public ValueStringBuilder(Span initialBuffer)
+ {
+ _arrayToReturnToPool = null;
+ _chars = initialBuffer;
+ _pos = 0;
+ }
+
+ public ValueStringBuilder(int initialCapacity)
+ {
+ _arrayToReturnToPool = ArrayPool.Shared.Rent(initialCapacity);
+ _chars = _arrayToReturnToPool;
+ _pos = 0;
+ }
+
+ public int Length
+ {
+ get => _pos;
+ set
+ {
+ Debug.Assert(value >= 0);
+ Debug.Assert(value <= _chars.Length);
+ _pos = value;
+ }
+ }
+
+ public int Capacity => _chars.Length;
+
+ public void EnsureCapacity(int capacity)
+ {
+ // This is not expected to be called this with negative capacity
+ Debug.Assert(capacity >= 0);
+
+ // If the caller has a bug and calls this with negative capacity, make sure to call Grow to throw an exception.
+ if ((uint)capacity > (uint)_chars.Length)
+ Grow(capacity - _pos);
+ }
+
+ ///
+ /// Get a pinnable reference to the builder.
+ /// Does not ensure there is a null char after
+ /// This overload is pattern matched in the C# 7.3+ compiler so you can omit
+ /// the explicit method call, and write eg "fixed (char* c = builder)"
+ ///
+ public ref char GetPinnableReference()
+ {
+ return ref MemoryMarshal.GetReference(_chars);
+ }
+
+ ///
+ /// Get a pinnable reference to the builder.
+ ///
+ /// Ensures that the builder has a null char after
+ public ref char GetPinnableReference(bool terminate)
+ {
+ if (terminate)
+ {
+ EnsureCapacity(Length + 1);
+ _chars[Length] = '\0';
+ }
+ return ref MemoryMarshal.GetReference(_chars);
+ }
+
+ public ref char this[int index]
+ {
+ get
+ {
+ Debug.Assert(index < _pos);
+ return ref _chars[index];
+ }
+ }
+
+ public override string ToString()
+ {
+ string s = _chars.Slice(0, _pos).ToString();
+ Dispose();
+ return s;
+ }
+
+ /// Returns the underlying storage of the builder.
+ public Span RawChars => _chars;
+
+ ///
+ /// Returns a span around the contents of the builder.
+ ///
+ /// Ensures that the builder has a null char after
+ public ReadOnlySpan AsSpan(bool terminate)
+ {
+ if (terminate)
+ {
+ EnsureCapacity(Length + 1);
+ _chars[Length] = '\0';
+ }
+ return _chars.Slice(0, _pos);
+ }
+
+ public ReadOnlySpan AsSpan() => _chars.Slice(0, _pos);
+ public ReadOnlySpan AsSpan(int start) => _chars.Slice(start, _pos - start);
+ public ReadOnlySpan AsSpan(int start, int length) => _chars.Slice(start, length);
+
+ public bool TryCopyTo(Span destination, out int charsWritten)
+ {
+ if (_chars.Slice(0, _pos).TryCopyTo(destination))
+ {
+ charsWritten = _pos;
+ Dispose();
+ return true;
+ }
+ else
+ {
+ charsWritten = 0;
+ Dispose();
+ return false;
+ }
+ }
+
+ public void Insert(int index, char value, int count)
+ {
+ if (_pos > _chars.Length - count)
+ {
+ Grow(count);
+ }
+
+ int remaining = _pos - index;
+ _chars.Slice(index, remaining).CopyTo(_chars.Slice(index + count));
+ _chars.Slice(index, count).Fill(value);
+ _pos += count;
+ }
+
+ public void Insert(int index, string? s)
+ {
+ if (s == null)
+ {
+ return;
+ }
+
+ int count = s.Length;
+
+ if (_pos > (_chars.Length - count))
+ {
+ Grow(count);
+ }
+
+ int remaining = _pos - index;
+ _chars.Slice(index, remaining).CopyTo(_chars.Slice(index + count));
+ s.AsSpan().CopyTo(_chars.Slice(index));
+ _pos += count;
+ }
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public void Append(char c)
+ {
+ int pos = _pos;
+ if ((uint)pos < (uint)_chars.Length)
+ {
+ _chars[pos] = c;
+ _pos = pos + 1;
+ }
+ else
+ {
+ GrowAndAppend(c);
+ }
+ }
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public void Append(string? s)
+ {
+ if (s == null)
+ {
+ return;
+ }
+
+ int pos = _pos;
+ if (s.Length == 1 && (uint)pos < (uint)_chars.Length) // very common case, e.g. appending strings from NumberFormatInfo like separators, percent symbols, etc.
+ {
+ _chars[pos] = s[0];
+ _pos = pos + 1;
+ }
+ else
+ {
+ AppendSlow(s);
+ }
+ }
+
+ private void AppendSlow(string s)
+ {
+ int pos = _pos;
+ if (pos > _chars.Length - s.Length)
+ {
+ Grow(s.Length);
+ }
+
+ s.AsSpan().CopyTo(_chars.Slice(pos));
+ _pos += s.Length;
+ }
+
+ public void Append(char c, int count)
+ {
+ if (_pos > _chars.Length - count)
+ {
+ Grow(count);
+ }
+
+ Span dst = _chars.Slice(_pos, count);
+ for (int i = 0; i < dst.Length; i++)
+ {
+ dst[i] = c;
+ }
+ _pos += count;
+ }
+
+ public unsafe void Append(char* value, int length)
+ {
+ int pos = _pos;
+ if (pos > _chars.Length - length)
+ {
+ Grow(length);
+ }
+
+ Span dst = _chars.Slice(_pos, length);
+ for (int i = 0; i < dst.Length; i++)
+ {
+ dst[i] = *value++;
+ }
+ _pos += length;
+ }
+
+ public void Append(ReadOnlySpan value)
+ {
+ int pos = _pos;
+ if (pos > _chars.Length - value.Length)
+ {
+ Grow(value.Length);
+ }
+
+ value.CopyTo(_chars.Slice(_pos));
+ _pos += value.Length;
+ }
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public Span AppendSpan(int length)
+ {
+ int origPos = _pos;
+ if (origPos > _chars.Length - length)
+ {
+ Grow(length);
+ }
+
+ _pos = origPos + length;
+ return _chars.Slice(origPos, length);
+ }
+
+ [MethodImpl(MethodImplOptions.NoInlining)]
+ private void GrowAndAppend(char c)
+ {
+ Grow(1);
+ Append(c);
+ }
+
+ ///
+ /// Resize the internal buffer either by doubling current buffer size or
+ /// by adding to
+ /// whichever is greater.
+ ///
+ ///
+ /// Number of chars requested beyond current position.
+ ///
+ [MethodImpl(MethodImplOptions.NoInlining)]
+ private void Grow(int additionalCapacityBeyondPos)
+ {
+ Debug.Assert(additionalCapacityBeyondPos > 0);
+ Debug.Assert(_pos > _chars.Length - additionalCapacityBeyondPos, "Grow called incorrectly, no resize is needed.");
+
+ // Make sure to let Rent throw an exception if the caller has a bug and the desired capacity is negative
+ char[] poolArray = ArrayPool.Shared.Rent((int)Math.Max((uint)(_pos + additionalCapacityBeyondPos), (uint)_chars.Length * 2));
+
+ _chars.Slice(0, _pos).CopyTo(poolArray);
+
+ char[]? toReturn = _arrayToReturnToPool;
+ _chars = _arrayToReturnToPool = poolArray;
+ if (toReturn != null)
+ {
+ ArrayPool.Shared.Return(toReturn);
+ }
+ }
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public void Dispose()
+ {
+ char[]? toReturn = _arrayToReturnToPool;
+ this = default; // for safety, to avoid using pooled array if this instance is erroneously appended to again
+ if (toReturn != null)
+ {
+ ArrayPool.Shared.Return(toReturn);
+ }
+ }
+ }
+}