diff --git a/src/Middleware/HeaderPropagation/ref/Microsoft.AspNetCore.HeaderPropagation.netcoreapp3.0.cs b/src/Middleware/HeaderPropagation/ref/Microsoft.AspNetCore.HeaderPropagation.netcoreapp3.0.cs index 887719c62484..c0bea7edc63e 100644 --- a/src/Middleware/HeaderPropagation/ref/Microsoft.AspNetCore.HeaderPropagation.netcoreapp3.0.cs +++ b/src/Middleware/HeaderPropagation/ref/Microsoft.AspNetCore.HeaderPropagation.netcoreapp3.0.cs @@ -14,10 +14,10 @@ namespace Microsoft.AspNetCore.HeaderPropagation public readonly partial struct HeaderPropagationContext { private readonly object _dummy; - public HeaderPropagationContext(Microsoft.AspNetCore.Http.HttpContext httpContext, string headerName, Microsoft.Extensions.Primitives.StringValues headerValue) { throw null; } + public HeaderPropagationContext(System.Collections.Generic.IDictionary requestHeaders, string headerName, Microsoft.Extensions.Primitives.StringValues headerValue) { throw null; } public string HeaderName { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } } public Microsoft.Extensions.Primitives.StringValues HeaderValue { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } } - public Microsoft.AspNetCore.Http.HttpContext HttpContext { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } } + public System.Collections.Generic.IDictionary RequestHeaders { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } } } public partial class HeaderPropagationEntry { @@ -58,7 +58,7 @@ public HeaderPropagationMessageHandlerOptions() { } } public partial class HeaderPropagationMiddleware { - public HeaderPropagationMiddleware(Microsoft.AspNetCore.Http.RequestDelegate next, Microsoft.Extensions.Options.IOptions options, Microsoft.AspNetCore.HeaderPropagation.HeaderPropagationValues values) { } + public HeaderPropagationMiddleware(Microsoft.AspNetCore.Http.RequestDelegate next, Microsoft.AspNetCore.HeaderPropagation.IHeaderPropagationProcessor processor) { } public System.Threading.Tasks.Task Invoke(Microsoft.AspNetCore.Http.HttpContext context) { throw null; } } public partial class HeaderPropagationOptions @@ -66,11 +66,20 @@ public partial class HeaderPropagationOptions public HeaderPropagationOptions() { } public Microsoft.AspNetCore.HeaderPropagation.HeaderPropagationEntryCollection Headers { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } [System.Runtime.CompilerServices.CompilerGeneratedAttribute]set { } } } + public partial class HeaderPropagationProcessor : Microsoft.AspNetCore.HeaderPropagation.IHeaderPropagationProcessor + { + public HeaderPropagationProcessor(Microsoft.Extensions.Options.IOptions options, Microsoft.AspNetCore.HeaderPropagation.HeaderPropagationValues values) { } + public void ProcessRequest(System.Collections.Generic.IDictionary requestHeaders) { } + } public partial class HeaderPropagationValues { public HeaderPropagationValues() { } public System.Collections.Generic.IDictionary Headers { get { throw null; } set { } } } + public partial interface IHeaderPropagationProcessor + { + void ProcessRequest(System.Collections.Generic.IDictionary requestHeaders); + } } namespace Microsoft.Extensions.DependencyInjection { diff --git a/src/Middleware/HeaderPropagation/samples/HeaderPropagationSample/SampleHostedService.cs b/src/Middleware/HeaderPropagation/samples/HeaderPropagationSample/SampleHostedService.cs new file mode 100644 index 000000000000..e83757a45435 --- /dev/null +++ b/src/Middleware/HeaderPropagation/samples/HeaderPropagationSample/SampleHostedService.cs @@ -0,0 +1,47 @@ +using System; +using System.Collections.Generic; +using System.Net.Http; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.AspNetCore.HeaderPropagation; +using Microsoft.Extensions.Hosting; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Primitives; + +namespace HeaderPropagationSample +{ + public class SampleHostedService : IHostedService + { + private readonly IHttpClientFactory _httpClientFactory; + private readonly HeaderPropagationProcessor _headerPropagationProcessor; + private readonly ILogger _logger; + + public SampleHostedService(IHttpClientFactory httpClientFactory, HeaderPropagationProcessor headerPropagationProcessor, ILogger logger) + { + _httpClientFactory = httpClientFactory ?? throw new ArgumentNullException(nameof(httpClientFactory)); + _headerPropagationProcessor = headerPropagationProcessor ?? throw new ArgumentNullException(nameof(headerPropagationProcessor)); + _logger = logger ?? throw new ArgumentNullException(nameof(logger)); + } + + public Task StartAsync(CancellationToken cancellationToken) + { + return DoWorkAsync(); + } + + private async Task DoWorkAsync() + { + _logger.LogInformation("Background Service is working."); + + _headerPropagationProcessor.ProcessRequest(new Dictionary()); + var client = _httpClientFactory.CreateClient("test"); + var result = await client.GetAsync("http://localhost:62013/forwarded"); + + _logger.LogInformation("Background Service:\n{result}", result); + } + + public Task StopAsync(CancellationToken cancellationToken) + { + return Task.CompletedTask; + } + } +} diff --git a/src/Middleware/HeaderPropagation/samples/HeaderPropagationSample/Startup.cs b/src/Middleware/HeaderPropagation/samples/HeaderPropagationSample/Startup.cs index 00138f6efe31..e405c1a038c1 100644 --- a/src/Middleware/HeaderPropagation/samples/HeaderPropagationSample/Startup.cs +++ b/src/Middleware/HeaderPropagation/samples/HeaderPropagationSample/Startup.cs @@ -49,6 +49,8 @@ public void ConfigureServices(IServiceCollection services) services .AddHttpClient("another") .AddHeaderPropagation(options => options.Headers.Add("X-BetaFeatures", "X-Experiments")); + + services.AddHostedService(); } public void Configure(IApplicationBuilder app, IWebHostEnvironment env, IHttpClientFactory clientFactory) diff --git a/src/Middleware/HeaderPropagation/src/DependencyInjection/HeaderPropagationServiceCollectionExtensions.cs b/src/Middleware/HeaderPropagation/src/DependencyInjection/HeaderPropagationServiceCollectionExtensions.cs index ab3d60cc0d1e..af7c94451025 100644 --- a/src/Middleware/HeaderPropagation/src/DependencyInjection/HeaderPropagationServiceCollectionExtensions.cs +++ b/src/Middleware/HeaderPropagation/src/DependencyInjection/HeaderPropagationServiceCollectionExtensions.cs @@ -23,6 +23,7 @@ public static IServiceCollection AddHeaderPropagation(this IServiceCollection se } services.TryAddSingleton(); + services.TryAddSingleton(); return services; } diff --git a/src/Middleware/HeaderPropagation/src/HeaderPropagationContext.cs b/src/Middleware/HeaderPropagation/src/HeaderPropagationContext.cs index bacddeae00ab..8ec09abd1639 100644 --- a/src/Middleware/HeaderPropagation/src/HeaderPropagationContext.cs +++ b/src/Middleware/HeaderPropagation/src/HeaderPropagationContext.cs @@ -2,7 +2,7 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; -using Microsoft.AspNetCore.Http; +using System.Collections.Generic; using Microsoft.Extensions.Primitives; namespace Microsoft.AspNetCore.HeaderPropagation @@ -14,32 +14,22 @@ public readonly struct HeaderPropagationContext { /// /// Initializes a new instance of with the provided - /// , and . + /// , and . /// - /// The associated with the current request. + /// The headers associated with the current request. /// The header name. /// The header value present in the current request. - public HeaderPropagationContext(HttpContext httpContext, string headerName, StringValues headerValue) + public HeaderPropagationContext(IDictionary requestHeaders, string headerName, StringValues headerValue) { - if (httpContext == null) - { - throw new ArgumentNullException(nameof(httpContext)); - } - - if (headerName == null) - { - throw new ArgumentNullException(nameof(headerName)); - } - - HttpContext = httpContext; - HeaderName = headerName; + RequestHeaders = requestHeaders ?? throw new ArgumentNullException(nameof(requestHeaders)); + HeaderName = headerName ?? throw new ArgumentNullException(nameof(headerName)); HeaderValue = headerValue; } /// - /// Gets the associated with the current request. + /// Gets the headers associated with the current request. /// - public HttpContext HttpContext { get; } + public IDictionary RequestHeaders { get; } /// /// Gets the header name. diff --git a/src/Middleware/HeaderPropagation/src/HeaderPropagationMessageHandler.cs b/src/Middleware/HeaderPropagation/src/HeaderPropagationMessageHandler.cs index 7d07dc3572f6..06693e980e7c 100644 --- a/src/Middleware/HeaderPropagation/src/HeaderPropagationMessageHandler.cs +++ b/src/Middleware/HeaderPropagation/src/HeaderPropagationMessageHandler.cs @@ -15,8 +15,8 @@ namespace Microsoft.AspNetCore.HeaderPropagation /// public class HeaderPropagationMessageHandler : DelegatingHandler { - private readonly HeaderPropagationValues _values; private readonly HeaderPropagationMessageHandlerOptions _options; + private readonly HeaderPropagationValues _values; /// /// Creates a new instance of the . @@ -47,9 +47,10 @@ protected override Task SendAsync(HttpRequestMessage reques if (captured == null) { var message = - $"The {nameof(HeaderPropagationValues)}.{nameof(HeaderPropagationValues.Headers)} property has not been " + - $"initialized. Register the header propagation middleware by adding 'app.{nameof(HeaderPropagationApplicationBuilderExtensions.UseHeaderPropagation)}() " + - $"in the 'Configure(...)' method."; + $"The {nameof(HeaderPropagationValues)}.{nameof(HeaderPropagationValues.Headers)} property has not been initialized. " + + $"If using this {nameof(HttpClient)} as part of an http request, register the header propagation middleware by adding " + + $"'app.{nameof(HeaderPropagationApplicationBuilderExtensions.UseHeaderPropagation)}() in the 'Configure(...)' method. " + + $"Otherwise, use {nameof(HeaderPropagationProcessor)}.{nameof(HeaderPropagationProcessor.ProcessRequest)}() before using the {nameof(HttpClient)}."; throw new InvalidOperationException(message); } diff --git a/src/Middleware/HeaderPropagation/src/HeaderPropagationMiddleware.cs b/src/Middleware/HeaderPropagation/src/HeaderPropagationMiddleware.cs index f62c9e4a72bd..9ba1362d379b 100644 --- a/src/Middleware/HeaderPropagation/src/HeaderPropagationMiddleware.cs +++ b/src/Middleware/HeaderPropagation/src/HeaderPropagationMiddleware.cs @@ -2,12 +2,9 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; -using System.Collections.Generic; using System.Net.Http; using System.Threading.Tasks; using Microsoft.AspNetCore.Http; -using Microsoft.Extensions.Options; -using Microsoft.Extensions.Primitives; namespace Microsoft.AspNetCore.HeaderPropagation { @@ -17,61 +14,19 @@ namespace Microsoft.AspNetCore.HeaderPropagation public class HeaderPropagationMiddleware { private readonly RequestDelegate _next; - private readonly HeaderPropagationOptions _options; - private readonly HeaderPropagationValues _values; + private readonly IHeaderPropagationProcessor _processor; - public HeaderPropagationMiddleware(RequestDelegate next, IOptions options, HeaderPropagationValues values) + public HeaderPropagationMiddleware(RequestDelegate next, IHeaderPropagationProcessor processor) { _next = next ?? throw new ArgumentNullException(nameof(next)); - - if (options == null) - { - throw new ArgumentNullException(nameof(options)); - } - _options = options.Value; - - _values = values ?? throw new ArgumentNullException(nameof(values)); + _processor = processor ?? throw new ArgumentNullException(nameof(processor)); } public Task Invoke(HttpContext context) { - // We need to intialize the headers because the message handler will use this to detect misconfiguration. - var headers = _values.Headers ??= new Dictionary(StringComparer.OrdinalIgnoreCase); - - // Perf: avoid foreach since we don't define a struct enumerator. - var entries = _options.Headers; - for (var i = 0; i < entries.Count; i++) - { - var entry = entries[i]; - - // We intentionally process entries in order, and allow earlier entries to - // take precedence over later entries when they have the same output name. - if (!headers.ContainsKey(entry.CapturedHeaderName)) - { - var value = GetValue(context, entry); - if (!StringValues.IsNullOrEmpty(value)) - { - headers.Add(entry.CapturedHeaderName, value); - } - } - } + _processor.ProcessRequest(context.Request.Headers); return _next.Invoke(context); } - - private static StringValues GetValue(HttpContext context, HeaderPropagationEntry entry) - { - context.Request.Headers.TryGetValue(entry.InboundHeaderName, out var value); - if (entry.ValueFilter != null) - { - var filtered = entry.ValueFilter(new HeaderPropagationContext(context, entry.InboundHeaderName, value)); - if (!StringValues.IsNullOrEmpty(filtered)) - { - value = filtered; - } - } - - return value; - } } } diff --git a/src/Middleware/HeaderPropagation/src/HeaderPropagationProcessor.cs b/src/Middleware/HeaderPropagation/src/HeaderPropagationProcessor.cs new file mode 100644 index 000000000000..df9be40f45f5 --- /dev/null +++ b/src/Middleware/HeaderPropagation/src/HeaderPropagationProcessor.cs @@ -0,0 +1,79 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using Microsoft.Extensions.Options; +using Microsoft.Extensions.Primitives; + +namespace Microsoft.AspNetCore.HeaderPropagation +{ + public class HeaderPropagationProcessor : IHeaderPropagationProcessor + { + private readonly HeaderPropagationOptions _options; + private readonly HeaderPropagationValues _values; + + public HeaderPropagationProcessor(IOptions options, HeaderPropagationValues values) + { + if (options == null) + { + throw new ArgumentNullException(nameof(options)); + } + _options = options.Value; + + _values = values; + } + + public void ProcessRequest(IDictionary requestHeaders) + { + if (requestHeaders == null) + { + throw new ArgumentNullException(nameof(requestHeaders)); + } + + if (_values.Headers != null) + { + var message = + $"The {nameof(HeaderPropagationValues)}.{nameof(HeaderPropagationValues.Headers)} was already initialized. " + + $"Each invocation of {nameof(HeaderPropagationProcessor)}.{nameof(HeaderPropagationProcessor.ProcessRequest)}() must be in a separate async context."; + throw new InvalidOperationException(message); + } + + // We need to intialize the headers because the message handler will use this to detect misconfiguration. + var headers = _values.Headers = new Dictionary(StringComparer.OrdinalIgnoreCase); + + // Perf: avoid foreach since we don't define a struct enumerator. + var entries = _options.Headers; + for (var i = 0; i < entries.Count; i++) + { + var entry = entries[i]; + + // We intentionally process entries in order, and allow earlier entries to + // take precedence over later entries when they have the same output name. + if (!headers.ContainsKey(entry.CapturedHeaderName)) + { + var value = GetValue(requestHeaders, entry); + if (!StringValues.IsNullOrEmpty(value)) + { + headers.Add(entry.CapturedHeaderName, value); + } + } + } + } + + private static StringValues GetValue(IDictionary requestHeaders, HeaderPropagationEntry entry) + { + requestHeaders.TryGetValue(entry.InboundHeaderName, out var value); + if (entry.ValueFilter != null) + { + var filtered = entry.ValueFilter(new HeaderPropagationContext(requestHeaders, entry.InboundHeaderName, value)); + if (!StringValues.IsNullOrEmpty(filtered)) + { + value = filtered; + } + } + + return value; + } + } +} diff --git a/src/Middleware/HeaderPropagation/src/IHeaderPropagationProcessor.cs b/src/Middleware/HeaderPropagation/src/IHeaderPropagationProcessor.cs new file mode 100644 index 000000000000..db4137acf366 --- /dev/null +++ b/src/Middleware/HeaderPropagation/src/IHeaderPropagationProcessor.cs @@ -0,0 +1,10 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +namespace Microsoft.AspNetCore.HeaderPropagation +{ + public interface IHeaderPropagationProcessor + { + void ProcessRequest(System.Collections.Generic.IDictionary requestHeaders); + } +} \ No newline at end of file diff --git a/src/Middleware/HeaderPropagation/test/HeaderPropagationIntegrationTest.cs b/src/Middleware/HeaderPropagation/test/HeaderPropagationIntegrationTest.cs index 585bb914729a..bed3b3545e52 100644 --- a/src/Middleware/HeaderPropagation/test/HeaderPropagationIntegrationTest.cs +++ b/src/Middleware/HeaderPropagation/test/HeaderPropagationIntegrationTest.cs @@ -63,8 +63,9 @@ public async Task HeaderPropagation_WithoutMiddleware_Throws() Assert.Equal(HttpStatusCode.OK, response.StatusCode); Assert.IsType(captured); Assert.Equal( - "The HeaderPropagationValues.Headers property has not been initialized. Register the header propagation middleware " + - "by adding 'app.UseHeaderPropagation() in the 'Configure(...)' method.", + "The HeaderPropagationValues.Headers property has not been initialized. If using this HttpClient as part of an http request, " + + "register the header propagation middleware by adding 'app.UseHeaderPropagation() in the 'Configure(...)' method. " + + "Otherwise, use HeaderPropagationProcessor.ProcessRequest() before using the HttpClient.", captured.Message); } diff --git a/src/Middleware/HeaderPropagation/test/HeaderPropagationMiddlewareTest.cs b/src/Middleware/HeaderPropagation/test/HeaderPropagationMiddlewareTest.cs index f6576d2d688d..2a9bfa5f421b 100644 --- a/src/Middleware/HeaderPropagation/test/HeaderPropagationMiddlewareTest.cs +++ b/src/Middleware/HeaderPropagation/test/HeaderPropagationMiddlewareTest.cs @@ -1,9 +1,9 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. +using System.Collections.Generic; using System.Threading.Tasks; using Microsoft.AspNetCore.Http; -using Microsoft.Extensions.Options; using Microsoft.Extensions.Primitives; using Xunit; @@ -15,167 +15,52 @@ public HeaderPropagationMiddlewareTest() { Context = new DefaultHttpContext(); Next = ctx => Task.CompletedTask; - Configuration = new HeaderPropagationOptions(); - State = new HeaderPropagationValues(); - Middleware = new HeaderPropagationMiddleware(Next, - new OptionsWrapper(Configuration), - State); + Processor = new HeaderPropagationProcessorMock(); + Middleware = new HeaderPropagationMiddleware(Next, Processor); } public DefaultHttpContext Context { get; set; } public RequestDelegate Next { get; set; } - public HeaderPropagationOptions Configuration { get; set; } - public HeaderPropagationValues State { get; set; } + public HeaderPropagationProcessorMock Processor { get; set; } public HeaderPropagationMiddleware Middleware { get; set; } [Fact] - public async Task HeaderInRequest_AddCorrectValue() + public async Task Invoke_InvokesProcessorWithRequestHeaders() { - // Arrange - Configuration.Headers.Add("in"); - Context.Request.Headers.Add("in", "test"); - - // Act - await Middleware.Invoke(Context); - - // Assert - Assert.Contains("in", State.Headers.Keys); - Assert.Equal(new[] { "test" }, State.Headers["in"]); - } - - [Fact] - public async Task NoHeaderInRequest_DoesNotAddIt() - { - // Arrange - Configuration.Headers.Add("in"); - - // Act - await Middleware.Invoke(Context); - - // Assert - Assert.Empty(State.Headers); - } - - [Fact] - public async Task HeaderInRequest_NotInOptions_DoesNotAddIt() - { - // Arrange - Context.Request.Headers.Add("in", "test"); - // Act await Middleware.Invoke(Context); // Assert - Assert.Empty(State.Headers); + Assert.NotNull(Processor.ReceivedRequestHeaders); + Assert.Same(Processor.ReceivedRequestHeaders, Context.Request.Headers); } [Fact] - public async Task MultipleHeadersInRequest_AddAllHeaders() - { - // Arrange - Configuration.Headers.Add("in"); - Configuration.Headers.Add("another"); - Context.Request.Headers.Add("in", "test"); - Context.Request.Headers.Add("another", "test2"); - - // Act - await Middleware.Invoke(Context); - - // Assert - Assert.Contains("in", State.Headers.Keys); - Assert.Equal(new[] { "test" }, State.Headers["in"]); - Assert.Contains("another", State.Headers.Keys); - Assert.Equal(new[] { "test2" }, State.Headers["another"]); - } - - [Theory] - [InlineData(null)] - [InlineData("")] - public async Task HeaderEmptyInRequest_DoesNotAddIt(string headerValue) - { - // Arrange - Configuration.Headers.Add("in"); - Context.Request.Headers.Add("in", headerValue); - - // Act - await Middleware.Invoke(Context); - - // Assert - Assert.DoesNotContain("in", State.Headers.Keys); - } - - [Theory] - [InlineData(new[] { "default" }, new[] { "default" })] - [InlineData(new[] { "default", "other" }, new[] { "default", "other" })] - public async Task UsesValueFilter(string[] filterValues, string[] expectedValues) + public async Task Invoke_InvokesNextMiddleware() { // Arrange - string receivedName = null; - StringValues receivedValue = default; - HttpContext receivedContext = null; - Configuration.Headers.Add("in", context => + var called = false; + Next = ctx => { - receivedValue = context.HeaderValue; - receivedName = context.HeaderName; - receivedContext = context.HttpContext; - return filterValues; - }); - - Context.Request.Headers.Add("in", "value"); - - // Act - await Middleware.Invoke(Context); - - // Assert - Assert.Contains("in", State.Headers.Keys); - Assert.Equal(expectedValues, State.Headers["in"]); - Assert.Equal("in", receivedName); - Assert.Equal(new StringValues("value"), receivedValue); - Assert.Same(Context, receivedContext); - } - - [Fact] - public async Task PreferValueFilter_OverRequestHeader() - { - // Arrange - Configuration.Headers.Add("in", context => "test"); - Context.Request.Headers.Add("in", "no"); + called = true; + return Task.CompletedTask; + }; // Act await Middleware.Invoke(Context); // Assert - Assert.Contains("in", State.Headers.Keys); - Assert.Equal("test", State.Headers["in"]); + Assert.True(called); } - [Fact] - public async Task EmptyValuesFromValueFilter_DoesNotAddIt() + public class HeaderPropagationProcessorMock : IHeaderPropagationProcessor { - // Arrange - Configuration.Headers.Add("in", (context) => StringValues.Empty); - - // Act - await Middleware.Invoke(Context); + public IDictionary ReceivedRequestHeaders { get; private set; } - // Assert - Assert.DoesNotContain("in", State.Headers.Keys); - } - - [Fact] - public async Task MultipleEntries_AddsFirstToProduceValue() - { - // Arrange - Configuration.Headers.Add("in"); - Configuration.Headers.Add("in", (context) => StringValues.Empty); - Configuration.Headers.Add("in", (context) => "Test"); - - // Act - await Middleware.Invoke(Context); - - // Assert - Assert.Contains("in", State.Headers.Keys); - Assert.Equal("Test", State.Headers["in"]); + public void ProcessRequest(IDictionary requestHeaders) + { + ReceivedRequestHeaders = requestHeaders; + } } } } diff --git a/src/Middleware/HeaderPropagation/test/HeaderPropagationProcessorTest.cs b/src/Middleware/HeaderPropagation/test/HeaderPropagationProcessorTest.cs new file mode 100644 index 000000000000..1caa02e862b6 --- /dev/null +++ b/src/Middleware/HeaderPropagation/test/HeaderPropagationProcessorTest.cs @@ -0,0 +1,196 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using Microsoft.Extensions.Options; +using Microsoft.Extensions.Primitives; +using Xunit; + +namespace Microsoft.AspNetCore.HeaderPropagation.Tests +{ + public class HeaderPropagationProcessorTest + { + public HeaderPropagationProcessorTest() + { + Configuration = new HeaderPropagationOptions(); + State = new HeaderPropagationValues(); + Processor = new HeaderPropagationProcessor(new OptionsWrapper(Configuration), State); + RequestHeaders = new Dictionary(); + } + + public HeaderPropagationOptions Configuration { get; set; } + public HeaderPropagationValues State { get; set; } + public HeaderPropagationProcessor Processor { get; set; } + public IDictionary RequestHeaders { get; set; } + + [Fact] + public void HeaderInRequest_AddCorrectValue() + { + // Arrange + Configuration.Headers.Add("in"); + RequestHeaders.Add("in", "test"); + + // Act + Processor.ProcessRequest(RequestHeaders); + + // Assert + Assert.Contains("in", State.Headers.Keys); + Assert.Equal(new[] { "test" }, State.Headers["in"]); + } + + [Fact] + public void NoHeaderInRequest_DoesNotAddIt() + { + // Arrange + Configuration.Headers.Add("in"); + + // Act + Processor.ProcessRequest(RequestHeaders); + + // Assert + Assert.Empty(State.Headers); + } + + [Fact] + public void HeaderInRequest_NotInOptions_DoesNotAddIt() + { + // Arrange + RequestHeaders.Add("in", "test"); + + // Act + Processor.ProcessRequest(RequestHeaders); + + // Assert + Assert.Empty(State.Headers); + } + + [Fact] + public void MultipleHeadersInRequest_AddAllHeaders() + { + // Arrange + Configuration.Headers.Add("in"); + Configuration.Headers.Add("another"); + RequestHeaders.Add("in", "test"); + RequestHeaders.Add("another", "test2"); + + // Act + Processor.ProcessRequest(RequestHeaders); + + // Assert + Assert.Contains("in", State.Headers.Keys); + Assert.Equal(new[] { "test" }, State.Headers["in"]); + Assert.Contains("another", State.Headers.Keys); + Assert.Equal(new[] { "test2" }, State.Headers["another"]); + } + + [Theory] + [InlineData(null)] + [InlineData("")] + public void HeaderEmptyInRequest_DoesNotAddIt(string headerValue) + { + // Arrange + Configuration.Headers.Add("in"); + RequestHeaders.Add("in", headerValue); + + // Act + Processor.ProcessRequest(RequestHeaders); + + // Assert + Assert.DoesNotContain("in", State.Headers.Keys); + } + + [Theory] + [InlineData(new[] { "default" }, new[] { "default" })] + [InlineData(new[] { "default", "other" }, new[] { "default", "other" })] + public void UsesValueFilter(string[] filterValues, string[] expectedValues) + { + // Arrange + string receivedName = null; + StringValues receivedValue = default; + IDictionary receivedRequestHeaders = null; + Configuration.Headers.Add("in", context => + { + receivedValue = context.HeaderValue; + receivedName = context.HeaderName; + receivedRequestHeaders = context.RequestHeaders; + return filterValues; + }); + + RequestHeaders.Add("in", "value"); + + // Act + Processor.ProcessRequest(RequestHeaders); + + // Assert + Assert.Contains("in", State.Headers.Keys); + Assert.Equal(expectedValues, State.Headers["in"]); + Assert.Equal("in", receivedName); + Assert.Equal(new StringValues("value"), receivedValue); + Assert.Same(RequestHeaders, receivedRequestHeaders); + } + + [Fact] + public void PreferValueFilter_OverRequestHeader() + { + // Arrange + Configuration.Headers.Add("in", context => "test"); + RequestHeaders.Add("in", "no"); + + // Act + Processor.ProcessRequest(RequestHeaders); + + // Assert + Assert.Contains("in", State.Headers.Keys); + Assert.Equal("test", State.Headers["in"]); + } + + [Fact] + public void EmptyValuesFromValueFilter_DoesNotAddIt() + { + // Arrange + Configuration.Headers.Add("in", (context) => StringValues.Empty); + + // Act + Processor.ProcessRequest(RequestHeaders); + + // Assert + Assert.DoesNotContain("in", State.Headers.Keys); + } + + [Fact] + public void MultipleEntries_AddsFirstToProduceValue() + { + // Arrange + Configuration.Headers.Add("in"); + Configuration.Headers.Add("in", (context) => StringValues.Empty); + Configuration.Headers.Add("in", (context) => "Test"); + + // Act + Processor.ProcessRequest(RequestHeaders); + + // Assert + Assert.Contains("in", State.Headers.Keys); + Assert.Equal("Test", State.Headers["in"]); + } + + [Fact] + public void MultipleCalls_ThrowsException() + { + // Arrange + Configuration.Headers.Add("in"); + Configuration.Headers.Add("in", (context) => StringValues.Empty); + Configuration.Headers.Add("in", (context) => "Test"); + + // Act + Processor.ProcessRequest(RequestHeaders); + var exception = Assert.Throws(() => Processor.ProcessRequest(RequestHeaders)); + + // Assert + Assert.Equal( + "The HeaderPropagationValues.Headers was already initialized. " + + "Each invocation of HeaderPropagationProcessor.ProcessRequest() must be in a separate async context.", + exception.Message); + } + } +}