diff --git a/src/Http/Headers/src/HeaderNames.cs b/src/Http/Headers/src/HeaderNames.cs index 8b1b8af991b3..2528c481cd6f 100644 --- a/src/Http/Headers/src/HeaderNames.cs +++ b/src/Http/Headers/src/HeaderNames.cs @@ -225,6 +225,9 @@ public static class HeaderNames /// Gets the Sec-WebSocket-Version HTTP header name. public static readonly string SecWebSocketVersion = "Sec-WebSocket-Version"; + /// Gets the Sec-WebSocket-Extensions HTTP header name. + public static readonly string SecWebSocketExtensions = "Sec-WebSocket-Extensions"; + /// Gets the Server HTTP header name. public static readonly string Server = "Server"; diff --git a/src/Http/Headers/src/PublicAPI.Unshipped.txt b/src/Http/Headers/src/PublicAPI.Unshipped.txt index 9035747ce6c4..b5ecbe22ebbb 100644 --- a/src/Http/Headers/src/PublicAPI.Unshipped.txt +++ b/src/Http/Headers/src/PublicAPI.Unshipped.txt @@ -5,6 +5,7 @@ Microsoft.Net.Http.Headers.RangeConditionHeaderValue.RangeConditionHeaderValue(M static readonly Microsoft.Net.Http.Headers.HeaderNames.Baggage -> string! static readonly Microsoft.Net.Http.Headers.HeaderNames.Link -> string! static readonly Microsoft.Net.Http.Headers.HeaderNames.ProxyConnection -> string! +static readonly Microsoft.Net.Http.Headers.HeaderNames.SecWebSocketExtensions -> string! static readonly Microsoft.Net.Http.Headers.HeaderNames.XContentTypeOptions -> string! static readonly Microsoft.Net.Http.Headers.HeaderNames.XPoweredBy -> string! static readonly Microsoft.Net.Http.Headers.HeaderNames.XUACompatible -> string! diff --git a/src/Http/Http.Abstractions/src/PublicAPI.Unshipped.txt b/src/Http/Http.Abstractions/src/PublicAPI.Unshipped.txt index 687f1f1e034b..0d6ad2ae7320 100644 --- a/src/Http/Http.Abstractions/src/PublicAPI.Unshipped.txt +++ b/src/Http/Http.Abstractions/src/PublicAPI.Unshipped.txt @@ -23,3 +23,4 @@ abstract Microsoft.AspNetCore.Http.HttpRequest.ContentType.get -> string? static Microsoft.AspNetCore.Builder.UseExtensions.Use(this Microsoft.AspNetCore.Builder.IApplicationBuilder! app, System.Func! middleware) -> Microsoft.AspNetCore.Builder.IApplicationBuilder! static Microsoft.AspNetCore.Builder.UseMiddlewareExtensions.UseMiddleware(this Microsoft.AspNetCore.Builder.IApplicationBuilder! app, System.Type! middleware, params object?[]! args) -> Microsoft.AspNetCore.Builder.IApplicationBuilder! static Microsoft.AspNetCore.Builder.UseMiddlewareExtensions.UseMiddleware(this Microsoft.AspNetCore.Builder.IApplicationBuilder! app, params object?[]! args) -> Microsoft.AspNetCore.Builder.IApplicationBuilder! +virtual Microsoft.AspNetCore.Http.WebSocketManager.AcceptWebSocketAsync(Microsoft.AspNetCore.Http.WebSocketAcceptContext! acceptContext) -> System.Threading.Tasks.Task! diff --git a/src/Http/Http.Abstractions/src/WebSocketManager.cs b/src/Http/Http.Abstractions/src/WebSocketManager.cs index 1fe47d59587d..79fd84bc184c 100644 --- a/src/Http/Http.Abstractions/src/WebSocketManager.cs +++ b/src/Http/Http.Abstractions/src/WebSocketManager.cs @@ -1,6 +1,7 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. +using System; using System.Collections.Generic; using System.Net.WebSockets; using System.Threading.Tasks; @@ -8,7 +9,7 @@ namespace Microsoft.AspNetCore.Http { /// - /// Manages the establishment of WebSocket connections for a specific HTTP request. + /// Manages the establishment of WebSocket connections for a specific HTTP request. /// public abstract class WebSocketManager { @@ -37,5 +38,12 @@ public virtual Task AcceptWebSocketAsync() /// The sub-protocol to use. /// A task representing the completion of the transition. public abstract Task AcceptWebSocketAsync(string? subProtocol); + + /// + /// + /// + /// + /// + public virtual Task AcceptWebSocketAsync(WebSocketAcceptContext acceptContext) => throw new NotImplementedException(); } } diff --git a/src/Http/Http.Features/src/IHeaderDictionary.Keyed.cs b/src/Http/Http.Features/src/IHeaderDictionary.Keyed.cs index 7f7739599ad9..fd98c8ae4d1d 100644 --- a/src/Http/Http.Features/src/IHeaderDictionary.Keyed.cs +++ b/src/Http/Http.Features/src/IHeaderDictionary.Keyed.cs @@ -202,6 +202,9 @@ public partial interface IHeaderDictionary /// Gets or sets the Sec-WebSocket-Version HTTP header. StringValues SecWebSocketVersion { get => this[HeaderNames.SecWebSocketVersion]; set => this[HeaderNames.SecWebSocketVersion] = value; } + /// Gets or sets the Sec-WebSocket-Extensions HTTP header. + StringValues SecWebSocketExtensions { get => this[HeaderNames.SecWebSocketExtensions]; set => this[HeaderNames.SecWebSocketExtensions] = value; } + /// Gets or sets the Server HTTP header. StringValues Server { get => this[HeaderNames.Server]; set => this[HeaderNames.Server] = value; } diff --git a/src/Http/Http.Features/src/PublicAPI.Unshipped.txt b/src/Http/Http.Features/src/PublicAPI.Unshipped.txt index 6911dbe5b627..a391a2f049ad 100644 --- a/src/Http/Http.Features/src/PublicAPI.Unshipped.txt +++ b/src/Http/Http.Features/src/PublicAPI.Unshipped.txt @@ -161,6 +161,8 @@ Microsoft.AspNetCore.Http.IHeaderDictionary.RetryAfter.get -> Microsoft.Extensio Microsoft.AspNetCore.Http.IHeaderDictionary.RetryAfter.set -> void Microsoft.AspNetCore.Http.IHeaderDictionary.SecWebSocketAccept.get -> Microsoft.Extensions.Primitives.StringValues Microsoft.AspNetCore.Http.IHeaderDictionary.SecWebSocketAccept.set -> void +Microsoft.AspNetCore.Http.IHeaderDictionary.SecWebSocketExtensions.get -> Microsoft.Extensions.Primitives.StringValues +Microsoft.AspNetCore.Http.IHeaderDictionary.SecWebSocketExtensions.set -> void Microsoft.AspNetCore.Http.IHeaderDictionary.SecWebSocketKey.get -> Microsoft.Extensions.Primitives.StringValues Microsoft.AspNetCore.Http.IHeaderDictionary.SecWebSocketKey.set -> void Microsoft.AspNetCore.Http.IHeaderDictionary.SecWebSocketProtocol.get -> Microsoft.Extensions.Primitives.StringValues @@ -232,6 +234,14 @@ Microsoft.AspNetCore.Http.Features.FeatureCollection.IsReadOnly.get -> bool (for Microsoft.AspNetCore.Http.Features.FeatureCollection.Set(TFeature? instance) -> void (forwarded, contained in Microsoft.Extensions.Features) Microsoft.AspNetCore.Http.Features.FeatureCollection.this[System.Type! key].get -> object? (forwarded, contained in Microsoft.Extensions.Features) Microsoft.AspNetCore.Http.Features.FeatureCollection.this[System.Type! key].set -> void (forwarded, contained in Microsoft.Extensions.Features) +Microsoft.AspNetCore.Http.WebSocketAcceptContext.DangerousEnableCompression.get -> bool +Microsoft.AspNetCore.Http.WebSocketAcceptContext.DangerousEnableCompression.set -> void +Microsoft.AspNetCore.Http.WebSocketAcceptContext.DisableServerContextTakeover.get -> bool +Microsoft.AspNetCore.Http.WebSocketAcceptContext.DisableServerContextTakeover.set -> void +Microsoft.AspNetCore.Http.WebSocketAcceptContext.ServerMaxWindowBits.get -> int +Microsoft.AspNetCore.Http.WebSocketAcceptContext.ServerMaxWindowBits.set -> void virtual Microsoft.AspNetCore.Http.Features.FeatureCollection.Revision.get -> int (forwarded, contained in Microsoft.Extensions.Features) +virtual Microsoft.AspNetCore.Http.WebSocketAcceptContext.KeepAliveInterval.get -> System.TimeSpan? +virtual Microsoft.AspNetCore.Http.WebSocketAcceptContext.KeepAliveInterval.set -> void ~Microsoft.AspNetCore.Http.Features.FeatureReference<> (forwarded, contained in Microsoft.Extensions.Features) ~Microsoft.AspNetCore.Http.Features.FeatureReferences<> (forwarded, contained in Microsoft.Extensions.Features) diff --git a/src/Http/Http.Features/src/WebSocketAcceptContext.cs b/src/Http/Http.Features/src/WebSocketAcceptContext.cs index b293ad874be7..5f2c9d7a049a 100644 --- a/src/Http/Http.Features/src/WebSocketAcceptContext.cs +++ b/src/Http/Http.Features/src/WebSocketAcceptContext.cs @@ -1,7 +1,8 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. -using Microsoft.AspNetCore.Http.Features; +using System; +using System.Net.WebSockets; namespace Microsoft.AspNetCore.Http { @@ -10,9 +11,63 @@ namespace Microsoft.AspNetCore.Http /// public class WebSocketAcceptContext { + private int _serverMaxWindowBits = 15; + /// /// Gets or sets the subprotocol being negotiated. /// public virtual string? SubProtocol { get; set; } + + /// + /// The interval to send pong frames. This is a heart-beat that keeps the connection alive. + /// + public virtual TimeSpan? KeepAliveInterval { get; set; } + + /// + /// Enables support for the 'permessage-deflate' WebSocket extension. + /// Be aware that enabling compression over encrypted connections makes the application subject to CRIME/BREACH type attacks. + /// It is strongly advised to turn off compression when sending data containing secrets by + /// specifying when sending such messages. + /// + public bool DangerousEnableCompression { get; set; } + + /// + /// Disables server context takeover when using compression. + /// This setting reduces the memory overhead of compression at the cost of a potentially worse compresson ratio. + /// + /// + /// This property does nothing when is false, + /// or when the client does not use compression. + /// + /// + /// false + /// + public bool DisableServerContextTakeover { get; set; } + + /// + /// Sets the maximum base-2 logarithm of the LZ77 sliding window size that can be used for compression. + /// This setting reduces the memory overhead of compression at the cost of a potentially worse compresson ratio. + /// + /// + /// This property does nothing when is false, + /// or when the client does not use compression. + /// Valid values are 9 through 15. + /// + /// + /// 15 + /// + public int ServerMaxWindowBits + { + get => _serverMaxWindowBits; + set + { + if (value < 9 || value > 15) + { + throw new ArgumentOutOfRangeException(nameof(ServerMaxWindowBits), + "The argument must be a value from 9 to 15."); + } + _serverMaxWindowBits = value; + } + } } } diff --git a/src/Http/Http/src/Internal/DefaultWebSocketManager.cs b/src/Http/Http/src/Internal/DefaultWebSocketManager.cs index 43ad4c0907ea..8697904b952c 100644 --- a/src/Http/Http/src/Internal/DefaultWebSocketManager.cs +++ b/src/Http/Http/src/Internal/DefaultWebSocketManager.cs @@ -17,6 +17,7 @@ internal sealed class DefaultWebSocketManager : WebSocketManager private readonly static Func _nullWebSocketFeature = f => null; private FeatureReferences _features; + private readonly static WebSocketAcceptContext _defaultWebSocketAcceptContext = new WebSocketAcceptContext(); public DefaultWebSocketManager(IFeatureCollection features) { @@ -61,12 +62,19 @@ public override IList WebSocketRequestedProtocols } public override Task AcceptWebSocketAsync(string? subProtocol) + { + var acceptContext = subProtocol is null ? _defaultWebSocketAcceptContext : + new WebSocketAcceptContext() { SubProtocol = subProtocol }; + return AcceptWebSocketAsync(acceptContext); + } + + public override Task AcceptWebSocketAsync(WebSocketAcceptContext acceptContext) { if (WebSocketFeature == null) { throw new NotSupportedException("WebSockets are not supported"); } - return WebSocketFeature.AcceptAsync(new WebSocketAcceptContext() { SubProtocol = subProtocol }); + return WebSocketFeature.AcceptAsync(acceptContext); } struct FeatureInterfaces diff --git a/src/Middleware/WebSockets/src/ExtendedWebSocketAcceptContext.cs b/src/Middleware/WebSockets/src/ExtendedWebSocketAcceptContext.cs index 7bfd7de4963a..2e85ef2d41f8 100644 --- a/src/Middleware/WebSockets/src/ExtendedWebSocketAcceptContext.cs +++ b/src/Middleware/WebSockets/src/ExtendedWebSocketAcceptContext.cs @@ -9,6 +9,7 @@ namespace Microsoft.AspNetCore.WebSockets /// /// Extends the class with additional properties. /// + [Obsolete("This type is obsolete and will be removed in a future version. The recommended alternative is Microsoft.AspNetCore.Http.WebSocketAcceptContext.")] public class ExtendedWebSocketAcceptContext : WebSocketAcceptContext { /// @@ -23,6 +24,6 @@ public class ExtendedWebSocketAcceptContext : WebSocketAcceptContext /// /// The interval to send pong frames. This is a heart-beat that keeps the connection alive. /// - public TimeSpan? KeepAliveInterval { get; set; } + public new TimeSpan? KeepAliveInterval { get; set; } } } diff --git a/src/Middleware/WebSockets/src/HandshakeHelpers.cs b/src/Middleware/WebSockets/src/HandshakeHelpers.cs index 05c5ac5363a3..a242a281154a 100644 --- a/src/Middleware/WebSockets/src/HandshakeHelpers.cs +++ b/src/Middleware/WebSockets/src/HandshakeHelpers.cs @@ -2,6 +2,10 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Globalization; +using System.Net.WebSockets; using System.Security.Cryptography; using System.Text; using Microsoft.AspNetCore.Http; @@ -72,5 +76,210 @@ public static string CreateResponseKey(string requestKey) return Convert.ToBase64String(hashedBytes); } + + // https://datatracker.ietf.org/doc/html/rfc7692#section-7.1 + public static bool ParseDeflateOptions(ReadOnlySpan extension, bool serverContextTakeover, + int serverMaxWindowBits, out WebSocketDeflateOptions parsedOptions, [NotNullWhen(true)] out string? response) + { + bool hasServerMaxWindowBits = false; + bool hasClientMaxWindowBits = false; + bool hasClientNoContext = false; + bool hasServerNoContext = false; + response = null; + parsedOptions = new WebSocketDeflateOptions() + { + ServerContextTakeover = serverContextTakeover, + ServerMaxWindowBits = serverMaxWindowBits + }; + + using var builder = new ValueStringBuilder(WebSocketDeflateConstants.MaxExtensionLength); + builder.Append(WebSocketDeflateConstants.Extension); + + while (true) + { + int end = extension.IndexOf(';'); + ReadOnlySpan value = (end >= 0 ? extension[..end] : extension).Trim(); + + if (value.Length == 0) + { + break; + } + + if (value.SequenceEqual(WebSocketDeflateConstants.ClientNoContextTakeover)) + { + // https://datatracker.ietf.org/doc/html/rfc7692#section-7 + // MUST decline if: + // The negotiation offer contains multiple extension parameters with + // the same name. + if (hasClientNoContext) + { + return false; + } + + hasClientNoContext = true; + parsedOptions.ClientContextTakeover = false; + builder.Append(';'); + builder.Append(' '); + builder.Append(WebSocketDeflateConstants.ClientNoContextTakeover); + } + else if (value.SequenceEqual(WebSocketDeflateConstants.ServerNoContextTakeover)) + { + // https://datatracker.ietf.org/doc/html/rfc7692#section-7 + // MUST decline if: + // The negotiation offer contains multiple extension parameters with + // the same name. + if (hasServerNoContext) + { + return false; + } + + hasServerNoContext = true; + parsedOptions.ServerContextTakeover = false; + } + else if (value.StartsWith(WebSocketDeflateConstants.ClientMaxWindowBits)) + { + // https://datatracker.ietf.org/doc/html/rfc7692#section-7 + // MUST decline if: + // The negotiation offer contains multiple extension parameters with + // the same name. + if (hasClientMaxWindowBits) + { + return false; + } + + hasClientMaxWindowBits = true; + if (!ParseWindowBits(value, WebSocketDeflateConstants.ClientMaxWindowBits, out var clientMaxWindowBits)) + { + return false; + } + + // 8 is a valid value according to the spec, but our zlib implementation does not support it + if (clientMaxWindowBits == 8) + { + return false; + } + + // https://tools.ietf.org/html/rfc7692#section-7.1.2.2 + // the server may either ignore this + // value or use this value to avoid allocating an unnecessarily big LZ77 + // sliding window by including the "client_max_window_bits" extension + // parameter in the corresponding extension negotiation response to the + // offer with a value equal to or smaller than the received value. + parsedOptions.ClientMaxWindowBits = clientMaxWindowBits ?? 15; + + // If a received extension negotiation offer doesn't have the + // "client_max_window_bits" extension parameter, the corresponding + // extension negotiation response to the offer MUST NOT include the + // "client_max_window_bits" extension parameter. + builder.Append(';'); + builder.Append(' '); + builder.Append(WebSocketDeflateConstants.ClientMaxWindowBits); + builder.Append('='); + var len = (parsedOptions.ClientMaxWindowBits > 9) ? 2 : 1; + var span = builder.AppendSpan(len); + var ret = parsedOptions.ClientMaxWindowBits.TryFormat(span, out var written); + Debug.Assert(ret); + Debug.Assert(written == len); + } + else if (value.StartsWith(WebSocketDeflateConstants.ServerMaxWindowBits)) + { + // https://datatracker.ietf.org/doc/html/rfc7692#section-7 + // MUST decline if: + // The negotiation offer contains multiple extension parameters with + // the same name. + if (hasServerMaxWindowBits) + { + return false; + } + + hasServerMaxWindowBits = true; + if (!ParseWindowBits(value, WebSocketDeflateConstants.ServerMaxWindowBits, out var parsedServerMaxWindowBits)) + { + return false; + } + + // 8 is a valid value according to the spec, but our zlib implementation does not support it + if (parsedServerMaxWindowBits == 8) + { + return false; + } + + // https://tools.ietf.org/html/rfc7692#section-7.1.2.1 + // A server accepts an extension negotiation offer with this parameter + // by including the "server_max_window_bits" extension parameter in the + // extension negotiation response to send back to the client with the + // same or smaller value as the offer. + parsedOptions.ServerMaxWindowBits = Math.Min(parsedServerMaxWindowBits ?? 15, serverMaxWindowBits); + } + + static bool ParseWindowBits(ReadOnlySpan value, string propertyName, out int? parsedValue) + { + var startIndex = value.IndexOf('='); + + // parameters can be sent without a value by the client, we'll use the values set by the app developer or the default of 15 + if (startIndex < 0) + { + parsedValue = null; + return true; + } + + value = value[(startIndex + 1)..].TrimEnd(); + + if (value.Length == 0) + { + parsedValue = null; + return false; + } + + // https://datatracker.ietf.org/doc/html/rfc7692#section-5.2 + // check for value in quotes and pull the value out without the quotes + if (value[0] == '"' && value.EndsWith("\"".AsSpan()) && value.Length > 1) + { + value = value[1..^1]; + } + + if (!int.TryParse(value, NumberStyles.Integer, CultureInfo.InvariantCulture, out int windowBits) || + windowBits < 8 || + windowBits > 15) + { + parsedValue = null; + return false; + } + + parsedValue = windowBits; + return true; + } + + if (end < 0) + { + break; + } + extension = extension[(end + 1)..]; + } + + if (!parsedOptions.ServerContextTakeover) + { + builder.Append(';'); + builder.Append(' '); + builder.Append(WebSocketDeflateConstants.ServerNoContextTakeover); + } + + if (hasServerMaxWindowBits || parsedOptions.ServerMaxWindowBits != 15) + { + builder.Append(';'); + builder.Append(' '); + builder.Append(WebSocketDeflateConstants.ServerMaxWindowBits); + builder.Append('='); + var len = (parsedOptions.ServerMaxWindowBits > 9) ? 2 : 1; + var span = builder.AppendSpan(len); + var ret = parsedOptions.ServerMaxWindowBits.TryFormat(span, out var written); + Debug.Assert(ret); + Debug.Assert(written == len); + } + + response = builder.ToString(); + + return true; + } } } diff --git a/src/Middleware/WebSockets/src/Microsoft.AspNetCore.WebSockets.csproj b/src/Middleware/WebSockets/src/Microsoft.AspNetCore.WebSockets.csproj index 931295880832..1443b5fae957 100644 --- a/src/Middleware/WebSockets/src/Microsoft.AspNetCore.WebSockets.csproj +++ b/src/Middleware/WebSockets/src/Microsoft.AspNetCore.WebSockets.csproj @@ -17,6 +17,10 @@ + + + + diff --git a/src/Middleware/WebSockets/src/WebSocketDeflateConstants.cs b/src/Middleware/WebSockets/src/WebSocketDeflateConstants.cs new file mode 100644 index 000000000000..aff0a93dccf4 --- /dev/null +++ b/src/Middleware/WebSockets/src/WebSocketDeflateConstants.cs @@ -0,0 +1,23 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +namespace Microsoft.AspNetCore.WebSockets +{ + internal static class WebSocketDeflateConstants + { + /// + /// The maximum length that this extension can have, assuming that we're not using extra white space. + /// + /// "permessage-deflate; client_max_window_bits=15; client_no_context_takeover; server_max_window_bits=15; server_no_context_takeover" + /// + public const int MaxExtensionLength = 128; + + public const string Extension = "permessage-deflate"; + + public const string ClientMaxWindowBits = "client_max_window_bits"; + public const string ClientNoContextTakeover = "client_no_context_takeover"; + + public const string ServerMaxWindowBits = "server_max_window_bits"; + public const string ServerNoContextTakeover = "server_no_context_takeover"; + } +} diff --git a/src/Middleware/WebSockets/src/WebSocketMiddleware.cs b/src/Middleware/WebSockets/src/WebSocketMiddleware.cs index 72fbc077b33f..c640ed4b368d 100644 --- a/src/Middleware/WebSockets/src/WebSocketMiddleware.cs +++ b/src/Middleware/WebSockets/src/WebSocketMiddleware.cs @@ -68,7 +68,7 @@ public Task Invoke(HttpContext context) var upgradeFeature = context.Features.Get(); if (upgradeFeature != null && context.Features.Get() == null) { - var webSocketFeature = new UpgradeHandshake(context, upgradeFeature, _options); + var webSocketFeature = new UpgradeHandshake(context, upgradeFeature, _options, _logger); context.Features.Set(webSocketFeature); if (!_anyOriginAllowed) @@ -97,13 +97,15 @@ private class UpgradeHandshake : IHttpWebSocketFeature private readonly HttpContext _context; private readonly IHttpUpgradeFeature _upgradeFeature; private readonly WebSocketOptions _options; + private readonly ILogger _logger; private bool? _isWebSocketRequest; - public UpgradeHandshake(HttpContext context, IHttpUpgradeFeature upgradeFeature, WebSocketOptions options) + public UpgradeHandshake(HttpContext context, IHttpUpgradeFeature upgradeFeature, WebSocketOptions options, ILogger logger) { _context = context; _upgradeFeature = upgradeFeature; _options = options; + _logger = logger; } public bool IsWebSocketRequest @@ -133,13 +135,22 @@ public async Task AcceptAsync(WebSocketAcceptContext acceptContext) } string? subProtocol = null; + bool enableCompression = false; + bool serverContextTakeover = true; + int serverMaxWindowBits = 15; + TimeSpan keepAliveInterval = _options.KeepAliveInterval; if (acceptContext != null) { subProtocol = acceptContext.SubProtocol; + enableCompression = acceptContext.DangerousEnableCompression; + serverContextTakeover = !acceptContext.DisableServerContextTakeover; + serverMaxWindowBits = acceptContext.ServerMaxWindowBits; + keepAliveInterval = acceptContext.KeepAliveInterval ?? keepAliveInterval; } - TimeSpan keepAliveInterval = _options.KeepAliveInterval; +#pragma warning disable CS0618 // Type or member is obsolete if (acceptContext is ExtendedWebSocketAcceptContext advancedAcceptContext) +#pragma warning restore CS0618 // Type or member is obsolete { if (advancedAcceptContext.KeepAliveInterval.HasValue) { @@ -151,9 +162,45 @@ public async Task AcceptAsync(WebSocketAcceptContext acceptContext) HandshakeHelpers.GenerateResponseHeaders(key, subProtocol, _context.Response.Headers); + WebSocketDeflateOptions? deflateOptions = null; + if (enableCompression) + { + var ext = _context.Request.Headers.SecWebSocketExtensions; + if (ext.Count != 0) + { + // loop over each extension offer, extensions can have multiple offers, we can accept any + foreach (var extension in _context.Request.Headers.GetCommaSeparatedValues(HeaderNames.SecWebSocketExtensions)) + { + if (extension.AsSpan().TrimStart().StartsWith("permessage-deflate", StringComparison.Ordinal)) + { + if (HandshakeHelpers.ParseDeflateOptions(extension.AsSpan().TrimStart(), serverContextTakeover, serverMaxWindowBits, out var parsedOptions, out var response)) + { + Log.CompressionAccepted(_logger, response); + deflateOptions = parsedOptions; + // If more extension types are added, this would need to be a header append + // and we wouldn't want to break out of the loop + _context.Response.Headers.SecWebSocketExtensions = response; + break; + } + } + } + + if (deflateOptions is null) + { + Log.CompressionNotAccepted(_logger); + } + } + } + Stream opaqueTransport = await _upgradeFeature.UpgradeAsync(); // Sets status code to 101 - return WebSocket.CreateFromStream(opaqueTransport, isServer: true, subProtocol: subProtocol, keepAliveInterval: keepAliveInterval); + return WebSocket.CreateFromStream(opaqueTransport, new WebSocketCreationOptions() + { + IsServer = true, + KeepAliveInterval = keepAliveInterval, + SubProtocol = subProtocol, + DangerousDeflateOptions = deflateOptions + }); } public static bool CheckSupportedWebSocketRequest(string method, IHeaderDictionary requestHeaders) @@ -227,5 +274,24 @@ public static bool CheckSupportedWebSocketRequest(string method, IHeaderDictiona return HandshakeHelpers.IsRequestKeyValid(requestHeaders.SecWebSocketKey.ToString()); } } + + private static class Log + { + private static readonly Action _compressionAccepted = + LoggerMessage.Define(LogLevel.Debug, new EventId(1, "CompressionAccepted"), "WebSocket compression negotiation accepted with values '{CompressionResponse}'."); + + private static readonly Action _compressionNotAccepted = + LoggerMessage.Define(LogLevel.Debug, new EventId(2, "CompressionNotAccepted"), "Compression negotiation not accepted by server."); + + public static void CompressionAccepted(ILogger logger, string response) + { + _compressionAccepted(logger, response, null); + } + + public static void CompressionNotAccepted(ILogger logger) + { + _compressionNotAccepted(logger, null); + } + } } } diff --git a/src/Middleware/WebSockets/test/ConformanceTests/AutobahnTestApp/Startup.cs b/src/Middleware/WebSockets/test/ConformanceTests/AutobahnTestApp/Startup.cs index f8b75268c28b..c351ae285e6f 100644 --- a/src/Middleware/WebSockets/test/ConformanceTests/AutobahnTestApp/Startup.cs +++ b/src/Middleware/WebSockets/test/ConformanceTests/AutobahnTestApp/Startup.cs @@ -22,7 +22,10 @@ public void Configure(IApplicationBuilder app, ILoggerFactory loggerFactory) if (context.WebSockets.IsWebSocketRequest) { logger.LogInformation("Received WebSocket request"); - using (var webSocket = await context.WebSockets.AcceptWebSocketAsync()) + using (var webSocket = await context.WebSockets.AcceptWebSocketAsync(new WebSocketAcceptContext() + { + DangerousEnableCompression = true + })) { await Echo(webSocket, context.RequestAborted); } diff --git a/src/Middleware/WebSockets/test/UnitTests/HandshakeTests.cs b/src/Middleware/WebSockets/test/UnitTests/HandshakeTests.cs index ec19793ddc2e..163fc71db4c3 100644 --- a/src/Middleware/WebSockets/test/UnitTests/HandshakeTests.cs +++ b/src/Middleware/WebSockets/test/UnitTests/HandshakeTests.cs @@ -1,6 +1,7 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. +using System; using Xunit; namespace Microsoft.AspNetCore.WebSockets.Tests @@ -37,5 +38,62 @@ public void RejectsInvalidRequestKeys(string key) { Assert.False(HandshakeHelpers.IsRequestKeyValid(key)); } + + [Theory] + [InlineData("permessage-deflate", "permessage-deflate")] + [InlineData("permessage-deflate; server_no_context_takeover", "permessage-deflate; server_no_context_takeover")] + [InlineData("permessage-deflate; client_no_context_takeover", "permessage-deflate; client_no_context_takeover")] + [InlineData("permessage-deflate; client_max_window_bits=9", "permessage-deflate; client_max_window_bits=9")] + [InlineData("permessage-deflate; client_max_window_bits=\"9\"", "permessage-deflate; client_max_window_bits=9")] + [InlineData("permessage-deflate; client_max_window_bits", "permessage-deflate; client_max_window_bits=15")] + [InlineData("permessage-deflate; server_max_window_bits", "permessage-deflate; server_max_window_bits=15")] + [InlineData("permessage-deflate; server_max_window_bits=10", "permessage-deflate; server_max_window_bits=10")] + [InlineData("permessage-deflate; server_max_window_bits=10; server_no_context_takeover", "permessage-deflate; server_no_context_takeover; server_max_window_bits=10")] + [InlineData("permessage-deflate; server_max_window_bits=10; server_no_context_takeover; client_no_context_takeover; client_max_window_bits=12", "permessage-deflate; client_no_context_takeover; client_max_window_bits=12; server_no_context_takeover; server_max_window_bits=10")] + public void CompressionNegotiationProducesCorrectHeaderWithDefaultOptions(string clientHeader, string expectedResponse) + { + Assert.True(HandshakeHelpers.ParseDeflateOptions(clientHeader.AsSpan(), serverContextTakeover: true, serverMaxWindowBits: 15, + out var _, out var response)); + Assert.Equal(expectedResponse, response); + } + + [Theory] + [InlineData("permessage-deflate", "permessage-deflate; server_no_context_takeover; server_max_window_bits=14")] + [InlineData("permessage-deflate; server_no_context_takeover", "permessage-deflate; server_no_context_takeover; server_max_window_bits=14")] + [InlineData("permessage-deflate; client_no_context_takeover", "permessage-deflate; client_no_context_takeover; server_no_context_takeover; server_max_window_bits=14")] + [InlineData("permessage-deflate; client_max_window_bits=9", "permessage-deflate; client_max_window_bits=9; server_no_context_takeover; server_max_window_bits=14")] + [InlineData("permessage-deflate; client_max_window_bits", "permessage-deflate; client_max_window_bits=15; server_no_context_takeover; server_max_window_bits=14")] + [InlineData("permessage-deflate; server_max_window_bits", "permessage-deflate; server_no_context_takeover; server_max_window_bits=14")] + [InlineData("permessage-deflate; server_max_window_bits=14", "permessage-deflate; server_no_context_takeover; server_max_window_bits=14")] + [InlineData("permessage-deflate; server_max_window_bits=10", "permessage-deflate; server_no_context_takeover; server_max_window_bits=10")] + [InlineData("permessage-deflate; server_max_window_bits=10; server_no_context_takeover", "permessage-deflate; server_no_context_takeover; server_max_window_bits=10")] + [InlineData("permessage-deflate; server_max_window_bits=10; client_no_context_takeover; client_max_window_bits=12", "permessage-deflate; client_no_context_takeover; client_max_window_bits=12; server_no_context_takeover; server_max_window_bits=10")] + public void CompressionNegotiationProducesCorrectHeaderWithCustomOptions(string clientHeader, string expectedResponse) + { + Assert.True(HandshakeHelpers.ParseDeflateOptions(clientHeader.AsSpan(), serverContextTakeover: false, serverMaxWindowBits: 14, + out var _, out var response)); + Assert.Equal(expectedResponse, response); + } + + [Theory] + [InlineData("permessage-deflate; server_max_window_bits=8")] + [InlineData("permessage-deflate; client_max_window_bits=8")] + [InlineData("permessage-deflate; server_max_window_bits=16")] + [InlineData("permessage-deflate; client_max_window_bits=16")] + [InlineData("permessage-deflate; client_max_window_bits=\"15")] + [InlineData("permessage-deflate; client_max_window_bits=14\"")] + [InlineData("permessage-deflate; client_max_window_bits=\"")] + [InlineData("permessage-deflate; client_max_window_bits=\"13")] + [InlineData("permessage-deflate; client_max_window_bits=")] + [InlineData("permessage-deflate; client_max_window_bits=\"\"")] + [InlineData("permessage-deflate; client_max_window_bits=14; client_max_window_bits=14")] + [InlineData("permessage-deflate; server_max_window_bits=14; server_max_window_bits=14")] + [InlineData("permessage-deflate; server_no_context_takeover; server_no_context_takeover")] + [InlineData("permessage-deflate; client_no_context_takeover; client_no_context_takeover")] + public void CompressionNegotiateNotAccepted(string clientHeader) + { + Assert.False(HandshakeHelpers.ParseDeflateOptions(clientHeader.AsSpan(), serverContextTakeover: true, serverMaxWindowBits: 15, + out var _, out var response)); + } } } diff --git a/src/Middleware/WebSockets/test/UnitTests/WebSocketCompressionMiddlewareTests.cs b/src/Middleware/WebSockets/test/UnitTests/WebSocketCompressionMiddlewareTests.cs new file mode 100644 index 000000000000..3803e52f9808 --- /dev/null +++ b/src/Middleware/WebSockets/test/UnitTests/WebSocketCompressionMiddlewareTests.cs @@ -0,0 +1,184 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Linq; +using System.Net; +using System.Net.Http; +using System.Net.WebSockets; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Testing; +using Microsoft.Net.Http.Headers; +using Xunit; + +namespace Microsoft.AspNetCore.WebSockets.Test +{ + public class WebSocketCompressionMiddlewareTests : LoggedTest + { + [Fact] + public async Task CompressionNegotiationServerCanChooseSevrverNoContextTakeover() + { + await using (var server = KestrelWebSocketHelpers.CreateServer(LoggerFactory, out var port, async context => + { + Assert.True(context.WebSockets.IsWebSocketRequest); + var webSocket = await context.WebSockets.AcceptWebSocketAsync(new WebSocketAcceptContext() + { + DangerousEnableCompression = true, + DisableServerContextTakeover = true + }); + })) + { + using (var client = new HttpClient()) + { + var uri = new UriBuilder(new Uri($"ws://127.0.0.1:{port}/")); + uri.Scheme = "http"; + + // Craft a valid WebSocket Upgrade request + using (var request = new HttpRequestMessage(HttpMethod.Get, uri.ToString())) + { + SetGenericWebSocketRequest(request); + request.Headers.Add(HeaderNames.SecWebSocketExtensions, "permessage-deflate"); + + var response = await client.SendAsync(request); + Assert.Equal(HttpStatusCode.SwitchingProtocols, response.StatusCode); + Assert.Equal("permessage-deflate; server_no_context_takeover", response.Headers.GetValues(HeaderNames.SecWebSocketExtensions).Aggregate((l, r) => $"{l}; {r}")); + } + } + } + } + + [Fact] + public async Task CompressionNegotiationIgnoredIfNotEnabledOnServer() + { + await using (var server = KestrelWebSocketHelpers.CreateServer(LoggerFactory, out var port, async context => + { + Assert.True(context.WebSockets.IsWebSocketRequest); + var webSocket = await context.WebSockets.AcceptWebSocketAsync(); + })) + { + using (var client = new HttpClient()) + { + var uri = new UriBuilder(new Uri($"ws://127.0.0.1:{port}/")); + uri.Scheme = "http"; + + // Craft a valid WebSocket Upgrade request + using (var request = new HttpRequestMessage(HttpMethod.Get, uri.ToString())) + { + SetGenericWebSocketRequest(request); + request.Headers.Add(HeaderNames.SecWebSocketExtensions, "permessage-deflate"); + + var response = await client.SendAsync(request); + Assert.Equal(HttpStatusCode.SwitchingProtocols, response.StatusCode); + Assert.False(response.Headers.Contains(HeaderNames.SecWebSocketExtensions)); + } + } + } + } + + [Theory] + [InlineData("permessage-deflate; server_max_window_bits=14, permessage-deflate; server_max_window_bits=13", "permessage-deflate; server_max_window_bits=13")] + [InlineData("permessage-deflate; client_max_window_bits=8, permessage-deflate; client_max_window_bits=13", "permessage-deflate; client_max_window_bits=13; server_max_window_bits=13")] + public async Task CompressionNegotiationCanChooseExtension(string clientHeader, string expectedResponse) + { + await using (var server = KestrelWebSocketHelpers.CreateServer(LoggerFactory, out var port, async context => + { + Assert.True(context.WebSockets.IsWebSocketRequest); + var webSocket = await context.WebSockets.AcceptWebSocketAsync(new WebSocketAcceptContext() + { + DangerousEnableCompression = true, + ServerMaxWindowBits = 13 + }); + })) + { + using (var client = new HttpClient()) + { + var uri = new UriBuilder(new Uri($"ws://127.0.0.1:{port}/")); + uri.Scheme = "http"; + + // Craft a valid WebSocket Upgrade request + using (var request = new HttpRequestMessage(HttpMethod.Get, uri.ToString())) + { + SetGenericWebSocketRequest(request); + request.Headers.Add(HeaderNames.SecWebSocketExtensions, clientHeader); + + var response = await client.SendAsync(request); + Assert.Equal(HttpStatusCode.SwitchingProtocols, response.StatusCode); + Assert.Equal(expectedResponse, response.Headers.GetValues(HeaderNames.SecWebSocketExtensions).Aggregate((l, r) => $"{l}; {r}")); + } + } + } + } + + // Smoke test that compression works, we aren't responsible for the specifics of the compression frames + [Fact] + public async Task CanSendAndReceiveCompressedData() + { + await using (var server = KestrelWebSocketHelpers.CreateServer(LoggerFactory, out var port, async context => + { + Assert.True(context.WebSockets.IsWebSocketRequest); + using var webSocket = await context.WebSockets.AcceptWebSocketAsync(new WebSocketAcceptContext() + { + DangerousEnableCompression = true, + ServerMaxWindowBits = 13 + }); + + var serverBuffer = new byte[1024]; + while (true) + { + var result = await webSocket.ReceiveAsync(serverBuffer, CancellationToken.None); + if (result.MessageType == WebSocketMessageType.Close) + { + break; + } + await webSocket.SendAsync(serverBuffer.AsMemory(0, result.Count), result.MessageType, result.EndOfMessage, default); + } + await webSocket.CloseAsync(WebSocketCloseStatus.NormalClosure, null, default); + })) + { + using (var client = new ClientWebSocket()) + { + client.Options.DangerousDeflateOptions = new WebSocketDeflateOptions() + { + ServerMaxWindowBits = 12, + ClientMaxWindowBits = 11, + }; + await client.ConnectAsync(new Uri($"ws://127.0.0.1:{port}/"), CancellationToken.None); + var sendCount = 8193; + var clientBuf = new byte[sendCount]; + var receiveBuf = new byte[sendCount]; + Random.Shared.NextBytes(clientBuf); + await client.SendAsync(clientBuf.AsMemory(0, sendCount), WebSocketMessageType.Binary, true, default); + var totalRecv = 0; + while (totalRecv < sendCount) + { + var result = await client.ReceiveAsync(receiveBuf.AsMemory(totalRecv), default); + totalRecv += result.Count; + if (result.EndOfMessage) + { + Assert.Equal(sendCount, totalRecv); + for (var i = 0; i < sendCount; ++i) + { + Assert.True(clientBuf[i] == receiveBuf[i], $"offset {i} not equal: {clientBuf[i]} == {receiveBuf[i]}"); + } + } + } + + await client.CloseAsync(WebSocketCloseStatus.NormalClosure, null, default); + } + } + } + + private static void SetGenericWebSocketRequest(HttpRequestMessage request) + { + request.Headers.Connection.Clear(); + request.Headers.Connection.Add("Upgrade"); + request.Headers.Connection.Add("keep-alive"); + request.Headers.Upgrade.Add(new System.Net.Http.Headers.ProductHeaderValue("websocket")); + request.Headers.Add(HeaderNames.SecWebSocketVersion, "13"); + // SecWebSocketKey required to be 16 bytes + request.Headers.Add(HeaderNames.SecWebSocketKey, Convert.ToBase64String(new byte[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16 }, Base64FormattingOptions.None)); + } + } +} diff --git a/src/Middleware/WebSockets/test/UnitTests/WebSocketMiddlewareTests.cs b/src/Middleware/WebSockets/test/UnitTests/WebSocketMiddlewareTests.cs index 1fb5686a909b..4fe10f556e4d 100644 --- a/src/Middleware/WebSockets/test/UnitTests/WebSocketMiddlewareTests.cs +++ b/src/Middleware/WebSockets/test/UnitTests/WebSocketMiddlewareTests.cs @@ -36,7 +36,7 @@ public async Task Connect_Success() [Fact] public async Task NegotiateSubProtocol_Success() { - await using(var server = KestrelWebSocketHelpers.CreateServer(LoggerFactory, out var port, async context => + await using (var server = KestrelWebSocketHelpers.CreateServer(LoggerFactory, out var port, async context => { Assert.True(context.WebSockets.IsWebSocketRequest); Assert.Equal("alpha, bravo, charlie", context.Request.Headers["Sec-WebSocket-Protocol"]); @@ -64,7 +64,7 @@ public async Task NegotiateSubProtocol_Success() [Fact] public async Task SendEmptyData_Success() { - await using(var server = KestrelWebSocketHelpers.CreateServer(LoggerFactory, out var port, async context => + await using (var server = KestrelWebSocketHelpers.CreateServer(LoggerFactory, out var port, async context => { Assert.True(context.WebSockets.IsWebSocketRequest); var webSocket = await context.WebSockets.AcceptWebSocketAsync(); @@ -89,7 +89,7 @@ public async Task SendEmptyData_Success() public async Task SendShortData_Success() { var orriginalData = Encoding.UTF8.GetBytes("Hello World"); - await using(var server = KestrelWebSocketHelpers.CreateServer(LoggerFactory, out var port, async context => + await using (var server = KestrelWebSocketHelpers.CreateServer(LoggerFactory, out var port, async context => { Assert.True(context.WebSockets.IsWebSocketRequest); var webSocket = await context.WebSockets.AcceptWebSocketAsync(); @@ -114,7 +114,7 @@ public async Task SendShortData_Success() public async Task SendMediumData_Success() { var orriginalData = Encoding.UTF8.GetBytes(new string('a', 130)); - await using(var server = KestrelWebSocketHelpers.CreateServer(LoggerFactory, out var port, async context => + await using (var server = KestrelWebSocketHelpers.CreateServer(LoggerFactory, out var port, async context => { Assert.True(context.WebSockets.IsWebSocketRequest); var webSocket = await context.WebSockets.AcceptWebSocketAsync(); @@ -140,7 +140,7 @@ public async Task SendLongData_Success() { var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); var orriginalData = Encoding.UTF8.GetBytes(new string('a', 0x1FFFF)); - await using(var server = KestrelWebSocketHelpers.CreateServer(LoggerFactory, out var port, async context => + await using (var server = KestrelWebSocketHelpers.CreateServer(LoggerFactory, out var port, async context => { Assert.True(context.WebSockets.IsWebSocketRequest); var webSocket = await context.WebSockets.AcceptWebSocketAsync(); @@ -171,7 +171,7 @@ public async Task SendFragmentedData_Success() { var orriginalData = Encoding.UTF8.GetBytes("Hello World"); var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - await using(var server = KestrelWebSocketHelpers.CreateServer(LoggerFactory, out var port, async context => + await using (var server = KestrelWebSocketHelpers.CreateServer(LoggerFactory, out var port, async context => { Assert.True(context.WebSockets.IsWebSocketRequest); var webSocket = await context.WebSockets.AcceptWebSocketAsync(); @@ -220,7 +220,7 @@ public async Task SendFragmentedData_Success() public async Task ReceiveShortData_Success() { var orriginalData = Encoding.UTF8.GetBytes("Hello World"); - await using(var server = KestrelWebSocketHelpers.CreateServer(LoggerFactory, out var port, async context => + await using (var server = KestrelWebSocketHelpers.CreateServer(LoggerFactory, out var port, async context => { Assert.True(context.WebSockets.IsWebSocketRequest); var webSocket = await context.WebSockets.AcceptWebSocketAsync(); @@ -245,7 +245,7 @@ public async Task ReceiveShortData_Success() public async Task ReceiveMediumData_Success() { var orriginalData = Encoding.UTF8.GetBytes(new string('a', 130)); - await using(var server = KestrelWebSocketHelpers.CreateServer(LoggerFactory, out var port, async context => + await using (var server = KestrelWebSocketHelpers.CreateServer(LoggerFactory, out var port, async context => { Assert.True(context.WebSockets.IsWebSocketRequest); var webSocket = await context.WebSockets.AcceptWebSocketAsync(); @@ -270,7 +270,7 @@ public async Task ReceiveMediumData_Success() public async Task ReceiveLongData() { var orriginalData = Encoding.UTF8.GetBytes(new string('a', 0x1FFFF)); - await using(var server = KestrelWebSocketHelpers.CreateServer(LoggerFactory, out var port, async context => + await using (var server = KestrelWebSocketHelpers.CreateServer(LoggerFactory, out var port, async context => { Assert.True(context.WebSockets.IsWebSocketRequest); var webSocket = await context.WebSockets.AcceptWebSocketAsync(); @@ -303,7 +303,7 @@ public async Task ReceiveLongData() public async Task ReceiveFragmentedData_Success() { var orriginalData = Encoding.UTF8.GetBytes("Hello World"); - await using(var server = KestrelWebSocketHelpers.CreateServer(LoggerFactory, out var port, async context => + await using (var server = KestrelWebSocketHelpers.CreateServer(LoggerFactory, out var port, async context => { Assert.True(context.WebSockets.IsWebSocketRequest); var webSocket = await context.WebSockets.AcceptWebSocketAsync(); @@ -346,7 +346,7 @@ public async Task ReceiveFragmentedData_Success() public async Task SendClose_Success() { string closeDescription = "Test Closed"; - await using(var server = KestrelWebSocketHelpers.CreateServer(LoggerFactory, out var port, async context => + await using (var server = KestrelWebSocketHelpers.CreateServer(LoggerFactory, out var port, async context => { Assert.True(context.WebSockets.IsWebSocketRequest); var webSocket = await context.WebSockets.AcceptWebSocketAsync(); @@ -374,7 +374,7 @@ public async Task SendClose_Success() public async Task ReceiveClose_Success() { string closeDescription = "Test Closed"; - await using(var server = KestrelWebSocketHelpers.CreateServer(LoggerFactory, out var port, async context => + await using (var server = KestrelWebSocketHelpers.CreateServer(LoggerFactory, out var port, async context => { Assert.True(context.WebSockets.IsWebSocketRequest); var webSocket = await context.WebSockets.AcceptWebSocketAsync(); @@ -402,7 +402,7 @@ public async Task ReceiveClose_Success() public async Task CloseFromOpen_Success() { string closeDescription = "Test Closed"; - await using(var server = KestrelWebSocketHelpers.CreateServer(LoggerFactory, out var port, async context => + await using (var server = KestrelWebSocketHelpers.CreateServer(LoggerFactory, out var port, async context => { Assert.True(context.WebSockets.IsWebSocketRequest); var webSocket = await context.WebSockets.AcceptWebSocketAsync(); @@ -432,7 +432,7 @@ public async Task CloseFromOpen_Success() public async Task CloseFromCloseSent_Success() { string closeDescription = "Test Closed"; - await using(var server = KestrelWebSocketHelpers.CreateServer(LoggerFactory, out var port, async context => + await using (var server = KestrelWebSocketHelpers.CreateServer(LoggerFactory, out var port, async context => { Assert.True(context.WebSockets.IsWebSocketRequest); var webSocket = await context.WebSockets.AcceptWebSocketAsync(); @@ -464,7 +464,7 @@ public async Task CloseFromCloseSent_Success() public async Task CloseFromCloseReceived_Success() { string closeDescription = "Test Closed"; - await using(var server = KestrelWebSocketHelpers.CreateServer(LoggerFactory, out var port, async context => + await using (var server = KestrelWebSocketHelpers.CreateServer(LoggerFactory, out var port, async context => { Assert.True(context.WebSockets.IsWebSocketRequest); var webSocket = await context.WebSockets.AcceptWebSocketAsync(); @@ -510,7 +510,7 @@ public async Task CloseFromCloseReceived_Success() [InlineData(HttpStatusCode.OK, "http://ExAmPLE.cOm")] public async Task OriginIsValidatedForWebSocketRequests(HttpStatusCode expectedCode, params string[] origins) { - await using(var server = KestrelWebSocketHelpers.CreateServer(LoggerFactory, out var port, context => + await using (var server = KestrelWebSocketHelpers.CreateServer(LoggerFactory, out var port, context => { Assert.True(context.WebSockets.IsWebSocketRequest); return Task.CompletedTask; @@ -553,7 +553,7 @@ public async Task OriginIsValidatedForWebSocketRequests(HttpStatusCode expectedC [Fact] public async Task OriginIsNotValidatedForNonWebSocketRequests() { - await using(var server = KestrelWebSocketHelpers.CreateServer(LoggerFactory, out var port, context => + await using (var server = KestrelWebSocketHelpers.CreateServer(LoggerFactory, out var port, context => { Assert.False(context.WebSockets.IsWebSocketRequest); return Task.CompletedTask; @@ -579,7 +579,7 @@ public async Task OriginIsNotValidatedForNonWebSocketRequests() [Fact] public async Task CommonHeadersAreSetToInternedStrings() { - await using(var server = KestrelWebSocketHelpers.CreateServer(LoggerFactory, out var port, async context => + await using (var server = KestrelWebSocketHelpers.CreateServer(LoggerFactory, out var port, async context => { Assert.True(context.WebSockets.IsWebSocketRequest); var webSocket = await context.WebSockets.AcceptWebSocketAsync(); @@ -600,7 +600,7 @@ public async Task CommonHeadersAreSetToInternedStrings() [Fact] public async Task MultipleValueHeadersNotOverridden() { - await using(var server = KestrelWebSocketHelpers.CreateServer(LoggerFactory, out var port, async context => + await using (var server = KestrelWebSocketHelpers.CreateServer(LoggerFactory, out var port, async context => { Assert.True(context.WebSockets.IsWebSocketRequest); var webSocket = await context.WebSockets.AcceptWebSocketAsync(); diff --git a/src/Servers/Kestrel/Core/src/Internal/Http/HttpHeaders.Generated.cs b/src/Servers/Kestrel/Core/src/Internal/Http/HttpHeaders.Generated.cs index 5760be03b633..fa499588ba0e 100644 --- a/src/Servers/Kestrel/Core/src/Internal/Http/HttpHeaders.Generated.cs +++ b/src/Servers/Kestrel/Core/src/Internal/Http/HttpHeaders.Generated.cs @@ -101,7 +101,7 @@ internal enum KnownHeaderType internal partial class HttpHeaders { - private readonly static HashSet _internedHeaderNames = new HashSet(95, StringComparer.OrdinalIgnoreCase) + private readonly static HashSet _internedHeaderNames = new HashSet(96, StringComparer.OrdinalIgnoreCase) { HeaderNames.Accept, HeaderNames.AcceptCharset, @@ -174,6 +174,7 @@ internal partial class HttpHeaders HeaderNames.SecWebSocketKey, HeaderNames.SecWebSocketProtocol, HeaderNames.SecWebSocketVersion, + HeaderNames.SecWebSocketExtensions, HeaderNames.Server, HeaderNames.SetCookie, HeaderNames.Status, @@ -2127,6 +2128,24 @@ StringValues IHeaderDictionary.SecWebSocketVersion SetValueUnknown(HeaderNames.SecWebSocketVersion, value); } } + StringValues IHeaderDictionary.SecWebSocketExtensions + { + get + { + StringValues value = default; + if (!TryGetUnknown(HeaderNames.SecWebSocketExtensions, ref value)) + { + value = default; + } + return value; + } + set + { + if (_isReadOnly) { ThrowHeadersReadOnlyException(); } + + SetValueUnknown(HeaderNames.SecWebSocketExtensions, value); + } + } StringValues IHeaderDictionary.Server { get @@ -10012,6 +10031,24 @@ StringValues IHeaderDictionary.SecWebSocketVersion SetValueUnknown(HeaderNames.SecWebSocketVersion, value); } } + StringValues IHeaderDictionary.SecWebSocketExtensions + { + get + { + StringValues value = default; + if (!TryGetUnknown(HeaderNames.SecWebSocketExtensions, ref value)) + { + value = default; + } + return value; + } + set + { + if (_isReadOnly) { ThrowHeadersReadOnlyException(); } + + SetValueUnknown(HeaderNames.SecWebSocketExtensions, value); + } + } StringValues IHeaderDictionary.StrictTransportSecurity { get @@ -16387,6 +16424,24 @@ StringValues IHeaderDictionary.SecWebSocketVersion SetValueUnknown(HeaderNames.SecWebSocketVersion, value); } } + StringValues IHeaderDictionary.SecWebSocketExtensions + { + get + { + StringValues value = default; + if (!TryGetUnknown(HeaderNames.SecWebSocketExtensions, ref value)) + { + value = default; + } + return value; + } + set + { + if (_isReadOnly) { ThrowHeadersReadOnlyException(); } + + SetValueUnknown(HeaderNames.SecWebSocketExtensions, value); + } + } StringValues IHeaderDictionary.Server { get diff --git a/src/Shared/ValueStringBuilder/ValueStringBuilder.cs b/src/Shared/ValueStringBuilder/ValueStringBuilder.cs new file mode 100644 index 000000000000..90c65f9cfd72 --- /dev/null +++ b/src/Shared/ValueStringBuilder/ValueStringBuilder.cs @@ -0,0 +1,314 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Buffers; +using System.Diagnostics; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; + +namespace System.Text +{ + // Copied from https://github.com/dotnet/runtime/blob/a9c5eadd951dcba73167f72cc624eb790573663a/src/libraries/Common/src/System/Text/ValueStringBuilder.cs + internal ref partial struct ValueStringBuilder + { + private char[]? _arrayToReturnToPool; + private Span _chars; + private int _pos; + + public ValueStringBuilder(Span initialBuffer) + { + _arrayToReturnToPool = null; + _chars = initialBuffer; + _pos = 0; + } + + public ValueStringBuilder(int initialCapacity) + { + _arrayToReturnToPool = ArrayPool.Shared.Rent(initialCapacity); + _chars = _arrayToReturnToPool; + _pos = 0; + } + + public int Length + { + get => _pos; + set + { + Debug.Assert(value >= 0); + Debug.Assert(value <= _chars.Length); + _pos = value; + } + } + + public int Capacity => _chars.Length; + + public void EnsureCapacity(int capacity) + { + // This is not expected to be called this with negative capacity + Debug.Assert(capacity >= 0); + + // If the caller has a bug and calls this with negative capacity, make sure to call Grow to throw an exception. + if ((uint)capacity > (uint)_chars.Length) + Grow(capacity - _pos); + } + + /// + /// Get a pinnable reference to the builder. + /// Does not ensure there is a null char after + /// This overload is pattern matched in the C# 7.3+ compiler so you can omit + /// the explicit method call, and write eg "fixed (char* c = builder)" + /// + public ref char GetPinnableReference() + { + return ref MemoryMarshal.GetReference(_chars); + } + + /// + /// Get a pinnable reference to the builder. + /// + /// Ensures that the builder has a null char after + public ref char GetPinnableReference(bool terminate) + { + if (terminate) + { + EnsureCapacity(Length + 1); + _chars[Length] = '\0'; + } + return ref MemoryMarshal.GetReference(_chars); + } + + public ref char this[int index] + { + get + { + Debug.Assert(index < _pos); + return ref _chars[index]; + } + } + + public override string ToString() + { + string s = _chars.Slice(0, _pos).ToString(); + Dispose(); + return s; + } + + /// Returns the underlying storage of the builder. + public Span RawChars => _chars; + + /// + /// Returns a span around the contents of the builder. + /// + /// Ensures that the builder has a null char after + public ReadOnlySpan AsSpan(bool terminate) + { + if (terminate) + { + EnsureCapacity(Length + 1); + _chars[Length] = '\0'; + } + return _chars.Slice(0, _pos); + } + + public ReadOnlySpan AsSpan() => _chars.Slice(0, _pos); + public ReadOnlySpan AsSpan(int start) => _chars.Slice(start, _pos - start); + public ReadOnlySpan AsSpan(int start, int length) => _chars.Slice(start, length); + + public bool TryCopyTo(Span destination, out int charsWritten) + { + if (_chars.Slice(0, _pos).TryCopyTo(destination)) + { + charsWritten = _pos; + Dispose(); + return true; + } + else + { + charsWritten = 0; + Dispose(); + return false; + } + } + + public void Insert(int index, char value, int count) + { + if (_pos > _chars.Length - count) + { + Grow(count); + } + + int remaining = _pos - index; + _chars.Slice(index, remaining).CopyTo(_chars.Slice(index + count)); + _chars.Slice(index, count).Fill(value); + _pos += count; + } + + public void Insert(int index, string? s) + { + if (s == null) + { + return; + } + + int count = s.Length; + + if (_pos > (_chars.Length - count)) + { + Grow(count); + } + + int remaining = _pos - index; + _chars.Slice(index, remaining).CopyTo(_chars.Slice(index + count)); + s.AsSpan().CopyTo(_chars.Slice(index)); + _pos += count; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public void Append(char c) + { + int pos = _pos; + if ((uint)pos < (uint)_chars.Length) + { + _chars[pos] = c; + _pos = pos + 1; + } + else + { + GrowAndAppend(c); + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public void Append(string? s) + { + if (s == null) + { + return; + } + + int pos = _pos; + if (s.Length == 1 && (uint)pos < (uint)_chars.Length) // very common case, e.g. appending strings from NumberFormatInfo like separators, percent symbols, etc. + { + _chars[pos] = s[0]; + _pos = pos + 1; + } + else + { + AppendSlow(s); + } + } + + private void AppendSlow(string s) + { + int pos = _pos; + if (pos > _chars.Length - s.Length) + { + Grow(s.Length); + } + + s.AsSpan().CopyTo(_chars.Slice(pos)); + _pos += s.Length; + } + + public void Append(char c, int count) + { + if (_pos > _chars.Length - count) + { + Grow(count); + } + + Span dst = _chars.Slice(_pos, count); + for (int i = 0; i < dst.Length; i++) + { + dst[i] = c; + } + _pos += count; + } + + public unsafe void Append(char* value, int length) + { + int pos = _pos; + if (pos > _chars.Length - length) + { + Grow(length); + } + + Span dst = _chars.Slice(_pos, length); + for (int i = 0; i < dst.Length; i++) + { + dst[i] = *value++; + } + _pos += length; + } + + public void Append(ReadOnlySpan value) + { + int pos = _pos; + if (pos > _chars.Length - value.Length) + { + Grow(value.Length); + } + + value.CopyTo(_chars.Slice(_pos)); + _pos += value.Length; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public Span AppendSpan(int length) + { + int origPos = _pos; + if (origPos > _chars.Length - length) + { + Grow(length); + } + + _pos = origPos + length; + return _chars.Slice(origPos, length); + } + + [MethodImpl(MethodImplOptions.NoInlining)] + private void GrowAndAppend(char c) + { + Grow(1); + Append(c); + } + + /// + /// Resize the internal buffer either by doubling current buffer size or + /// by adding to + /// whichever is greater. + /// + /// + /// Number of chars requested beyond current position. + /// + [MethodImpl(MethodImplOptions.NoInlining)] + private void Grow(int additionalCapacityBeyondPos) + { + Debug.Assert(additionalCapacityBeyondPos > 0); + Debug.Assert(_pos > _chars.Length - additionalCapacityBeyondPos, "Grow called incorrectly, no resize is needed."); + + // Make sure to let Rent throw an exception if the caller has a bug and the desired capacity is negative + char[] poolArray = ArrayPool.Shared.Rent((int)Math.Max((uint)(_pos + additionalCapacityBeyondPos), (uint)_chars.Length * 2)); + + _chars.Slice(0, _pos).CopyTo(poolArray); + + char[]? toReturn = _arrayToReturnToPool; + _chars = _arrayToReturnToPool = poolArray; + if (toReturn != null) + { + ArrayPool.Shared.Return(toReturn); + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public void Dispose() + { + char[]? toReturn = _arrayToReturnToPool; + this = default; // for safety, to avoid using pooled array if this instance is erroneously appended to again + if (toReturn != null) + { + ArrayPool.Shared.Return(toReturn); + } + } + } +}