Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
namespace NServiceBus.Core.Tests.AssemblyScanner;

using System;
using System.IO;
using System.Linq;
using System.Runtime.Loader;
using Hosting.Helpers;
using NUnit.Framework;

[TestFixture]
public class When_more_than_one_AssemblyLoadContext_has_scannable_types
{
[Test]
public void Should_only_load_one_copy_of_the_assembly()
{
var scanPath = Path.Combine(TestContext.CurrentContext.TestDirectory, "TestDlls", "Messages");
var customAssemblyLoadContext = new AssemblyLoadContext("ScannerTestALC", isCollectible: true);

customAssemblyLoadContext.LoadFromAssemblyPath(Path.Combine(scanPath, "Messages.Referencing.Core.dll"));

var scanner = new AssemblyScanner(scanPath);
var result = scanner.GetScannableAssemblies();

var loadedFromScanPath = result.Assemblies
.Where(a =>
!string.IsNullOrWhiteSpace(a.Location) &&
a.Location.StartsWith(scanPath, StringComparison.OrdinalIgnoreCase))
.ToList();

Assert.That(loadedFromScanPath, Is.Not.Empty, "Expected at least one assembly to be loaded from the scan directory.");

var assemblies = loadedFromScanPath.GroupBy(a => a.FullName);

foreach (var assembly in assemblies)
{
var numberOfTimesLoaded = assembly.Count();
Assert.That(numberOfTimesLoaded, Is.EqualTo(1), $"Assembly {assembly.Key} was loaded from more than one AssemblyLoadContext.");
}

var messagesAssembly = loadedFromScanPath.Single(a => a.FullName.StartsWith("Messages.Referencing.Core"));
var loadContext = AssemblyLoadContext.GetLoadContext(messagesAssembly);

Assert.That(loadContext.Name, Is.EqualTo("ScannerTestALC"), "The wrong AssemblyLoadContext was used to load the assembly.");
}
}
7 changes: 7 additions & 0 deletions src/NServiceBus.Core.Tests/StructConventionsTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ namespace NServiceBus.Core.Tests;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Text;
using NUnit.Framework;
Expand Down Expand Up @@ -38,6 +39,12 @@ public void ApproveStructsWhichDontFollowStructGuidelines()
continue;
}

// readonly structs can probably be ignored
if (type.GetCustomAttribute<IsReadOnlyAttribute>() != null)
{
continue;
}

// For some reason this class's size is different across platforms causing the test to fail on Linux. Disabling here since it won't be used as of v8
if (type.Namespace.StartsWith("NServiceBus.Timeout.Core"))
{
Expand Down
29 changes: 19 additions & 10 deletions src/NServiceBus.Core/Hosting/Helpers/AssemblyScanner.cs
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ internal IReadOnlyCollection<Type> TypesToSkip
public AssemblyScannerResults GetScannableAssemblies()
{
var results = new AssemblyScannerResults();
var processed = new Dictionary<string, bool>(StringComparer.OrdinalIgnoreCase);
var processed = new Dictionary<AssemblyIdentity, bool>();

if (assemblyToScan != null)
{
Expand Down Expand Up @@ -207,50 +207,54 @@ bool TryLoadScannableAssembly(string assemblyPath, AssemblyScannerResults result
}
}

bool ScanAssembly(Assembly assembly, Dictionary<string, bool> processed)
bool ScanAssembly(Assembly assembly, Dictionary<AssemblyIdentity, bool> processed)
{
if (assembly == null)
{
return false;
}

if (processed.TryGetValue(assembly.FullName, out var value))
var identity = GetAssemblyIdentity(assembly);

if (processed.TryGetValue(identity, out var value))
{
return value;
}

processed[assembly.FullName] = false;
processed[identity] = false;

var assemblyName = assembly.GetName();
if (IsCoreOrMessageInterfaceAssembly(assemblyName))
{
return processed[assembly.FullName] = true;
return processed[identity] = true;
}

if (ShouldScanDependencies(assembly))
{
var context = AssemblyLoadContext.GetLoadContext(assembly);

foreach (var referencedAssemblyName in assembly.GetReferencedAssemblies())
{
var referencedAssembly = GetReferencedAssembly(referencedAssemblyName);
var referencedAssembly = GetReferencedAssembly(context, referencedAssemblyName);
var referencesCore = ScanAssembly(referencedAssembly, processed);
if (referencesCore)
{
processed[assembly.FullName] = true;
processed[identity] = true;
break;
}
}
}

return processed[assembly.FullName];
return processed[identity];
}

static Assembly GetReferencedAssembly(AssemblyName assemblyName)
static Assembly GetReferencedAssembly(AssemblyLoadContext context, AssemblyName assemblyName)
{
Assembly referencedAssembly = null;

try
{
referencedAssembly = Assembly.Load(assemblyName);
referencedAssembly = context?.LoadFromAssemblyName(assemblyName);
}
catch (Exception ex) when (ex is FileNotFoundException or BadImageFormatException or FileLoadException) { }

Expand Down Expand Up @@ -439,4 +443,9 @@ bool ShouldScanDependencies(Assembly assembly)
// And other windows azure stuff
"Microsoft.WindowsAzure"
};

static AssemblyIdentity GetAssemblyIdentity(Assembly assembly) =>
new(assembly.FullName!, AssemblyLoadContext.GetLoadContext(assembly));

readonly record struct AssemblyIdentity(string FullName, AssemblyLoadContext LoadContext);
}
44 changes: 41 additions & 3 deletions src/NServiceBus.Core/Hosting/Helpers/AssemblyScannerResults.cs
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
namespace NServiceBus.Hosting.Helpers;
namespace NServiceBus.Hosting.Helpers;

using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;
using System.Runtime.Loader;

/// <summary>
/// Holds <see cref="AssemblyScanner.GetScannableAssemblies" /> results.
Expand Down Expand Up @@ -45,7 +46,44 @@ public AssemblyScannerResults()

internal void RemoveDuplicates()
{
Assemblies = Assemblies.Distinct().ToList();
Types = Types.Distinct().ToList();
if (AssemblyLoadContext.All.Count() == 1)
{
Assemblies = [.. Assemblies.Distinct()];
Types = [.. Types.Distinct()];

return;
}

var preferredAssemblies = new Dictionary<string, Assembly>(StringComparer.OrdinalIgnoreCase);

foreach (var assembly in Assemblies)
{
var fullName = assembly.FullName;

if (fullName is null)
{
continue;
}

if (!preferredAssemblies.TryGetValue(fullName, out var existing))
{
preferredAssemblies[fullName] = assembly;
continue;
}

preferredAssemblies[fullName] = PreferAssembly(existing, assembly);
}

Assemblies = [.. preferredAssemblies.Values];

var assemblySet = Assemblies.ToHashSet();
Types = [.. Types.Where(t => assemblySet.Contains(t.Assembly)).Distinct()];
}

static Assembly PreferAssembly(Assembly left, Assembly right)
{
var leftIsDefault = AssemblyLoadContext.GetLoadContext(left) == AssemblyLoadContext.Default;

return leftIsDefault ? right : left;
}
}