Skip to content
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions src/Microsoft.IdentityModel.Tokens/CryptoProviderFactory.cs
Comment thread
cpp11nullptr marked this conversation as resolved.
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ public class CryptoProviderFactory
{
private static CryptoProviderFactory _default;
private static readonly ConcurrentDictionary<string, string> _typeToAlgorithmMap = new ConcurrentDictionary<string, string>();
private static readonly ConcurrentDictionary<string, string> _customProviderTypeCache = new ConcurrentDictionary<string, string>();
private static int _defaultSignatureProviderObjectPoolCacheSize = Environment.ProcessorCount * 4;
private static string _typeofAsymmetricSignatureProvider = typeof(AsymmetricSignatureProvider).ToString();
private static string _typeofSymmetricSignatureProvider = typeof(SymmetricSignatureProvider).ToString();
Expand Down Expand Up @@ -92,6 +93,7 @@ public CryptoProviderFactory(CryptoProviderFactory other)
CryptoProviderCache = new InMemoryCryptoProviderCache() { CryptoProviderFactory = this };
CustomCryptoProvider = other.CustomCryptoProvider;
CacheSignatureProviders = other.CacheSignatureProviders;
CacheCustomProviders = other.CacheCustomProviders;
SignatureProviderObjectPoolCacheSize = other.SignatureProviderObjectPoolCacheSize;
}

Expand All @@ -115,6 +117,18 @@ public CryptoProviderFactory(CryptoProviderFactory other)
[DefaultValue(true)]
public bool CacheSignatureProviders { get; set; } = DefaultCacheSignatureProviders;

/// <summary>
/// Gets or sets a bool controlling if <see cref="SignatureProvider"/> instances created by <see cref="CustomCryptoProvider"/>
/// should be cached. Default is <see langword="false"/>.
/// </summary>
/// <remarks>
/// When <see langword="true"/>, signature providers returned by <see cref="ICryptoProvider.Create(string, object[])"/>
/// are cached using the same <see cref="CryptoProviderCache"/> used for built-in providers. This avoids
/// repeated provider creation and key materialisation on every signature validation.
/// </remarks>
[DefaultValue(false)]
public bool CacheCustomProviders { get; set; }
Comment thread
cpp11nullptr marked this conversation as resolved.

/// <summary>
/// Gets or sets the maximum size of the object pool used by the SignatureProvider that are used for crypto objects.
/// </summary>
Expand Down Expand Up @@ -598,6 +612,23 @@ private SignatureProvider CreateSignatureProvider(
SignatureProvider signatureProvider;
if (CustomCryptoProvider != null && CustomCryptoProvider.IsSupportedAlgorithm(algorithm, key, willCreateSignatures))
{
if (CacheCustomProviders && CacheSignatureProviders && cacheProvider)
{
// Try cache lookup first using the remembered provider type from a previous Create call.
string cacheTypeKey = CustomCryptoProvider.GetType().ToString() + "-" + algorithm;
Comment thread
iNinja marked this conversation as resolved.
if (_customProviderTypeCache.TryGetValue(cacheTypeKey, out string providerType)
&& CryptoProviderCache.TryGetSignatureProvider(
key,
algorithm,
providerType,
willCreateSignatures,
out SignatureProvider cachedProvider))
{
cachedProvider.AddRef();
return cachedProvider;
}
}

signatureProvider = CustomCryptoProvider.Create(algorithm, key, willCreateSignatures) as SignatureProvider;
if (signatureProvider == null)
throw LogHelper.LogExceptionMessage(
Expand All @@ -608,6 +639,17 @@ private SignatureProvider CreateSignatureProvider(
LogHelper.MarkAsNonPII(key.KeyId),
LogHelper.MarkAsNonPII(typeof(SignatureProvider)))));

if (CacheCustomProviders && CacheSignatureProviders && cacheProvider)
Comment thread
iNinja marked this conversation as resolved.
{
// Remember the provider type for future cache lookups.
string cacheTypeKey = CustomCryptoProvider.GetType().ToString() + "-" + algorithm;
string providerType = signatureProvider.GetType().ToString();
_customProviderTypeCache.TryAdd(cacheTypeKey, providerType);

if (ShouldCacheSignatureProvider(signatureProvider))
signatureProvider.IsCached = CryptoProviderCache.TryAdd(signatureProvider);
}

return signatureProvider;
}

Expand Down
109 changes: 109 additions & 0 deletions src/Microsoft.IdentityModel.Tokens/InternalAPI.Unshipped.txt

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Microsoft.IdentityModel.Tokens.CryptoProviderFactory.CacheCustomProviders.get -> bool
Microsoft.IdentityModel.Tokens.CryptoProviderFactory.CacheCustomProviders.set -> void
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Microsoft.IdentityModel.Tokens.CryptoProviderFactory.CacheCustomProviders.get -> bool
Microsoft.IdentityModel.Tokens.CryptoProviderFactory.CacheCustomProviders.set -> void
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Microsoft.IdentityModel.Tokens.CryptoProviderFactory.CacheCustomProviders.get -> bool
Microsoft.IdentityModel.Tokens.CryptoProviderFactory.CacheCustomProviders.set -> void
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Microsoft.IdentityModel.Tokens.CryptoProviderFactory.CacheCustomProviders.get -> bool
Microsoft.IdentityModel.Tokens.CryptoProviderFactory.CacheCustomProviders.set -> void
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Microsoft.IdentityModel.Tokens.CryptoProviderFactory.CacheCustomProviders.get -> bool
Microsoft.IdentityModel.Tokens.CryptoProviderFactory.CacheCustomProviders.set -> void
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Microsoft.IdentityModel.Tokens.CryptoProviderFactory.CacheCustomProviders.get -> bool
Microsoft.IdentityModel.Tokens.CryptoProviderFactory.CacheCustomProviders.set -> void
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Microsoft.IdentityModel.Tokens.CryptoProviderFactory.CacheCustomProviders.get -> bool
Microsoft.IdentityModel.Tokens.CryptoProviderFactory.CacheCustomProviders.set -> void
Original file line number Diff line number Diff line change
Expand Up @@ -1271,6 +1271,197 @@ private static IList<SignatureProvider> CreateVerifyingProviders(CryptoProviderF

private static bool GetSignatureProviderIsDisposedByReflect(SignatureProvider signatureProvider) =>
(bool)signatureProvider.GetType().GetField("_disposed", BindingFlags.NonPublic | BindingFlags.Instance).GetValue(signatureProvider);

[Fact]
public void CacheCustomProviders_WhenEnabled_CachesProviderFromCustomCryptoProvider()
{
// Arrange
var signingKey = new SymmetricSecurityKey(KeyingMaterial.DefaultSymmetricKeyBytes_256)
{
KeyId = "test-cache-kid"
};
var algorithm = SecurityAlgorithms.HmacSha256;

int createCount = 0;
var customCrypto = new CountingCryptoProvider(algorithm, () =>
{
createCount++;
return new SymmetricSignatureProvider(signingKey, algorithm, false);
});

var factory = new CryptoProviderFactory(CryptoProviderCacheTests.CreateCacheForTesting())
{
CustomCryptoProvider = customCrypto,
CacheCustomProviders = true,
CacheSignatureProviders = true
};

// Act — first call creates and should cache
var provider1 = factory.CreateForVerifying(signingKey, algorithm);

// Debug: verify the cache state directly
string providerType = provider1.GetType().ToString();
string internalId = signingKey.InternalId;

Assert.True(internalId.Length > 0,
$"Key InternalId should not be empty, got: '{internalId}'");
Assert.True(provider1.IsCached,
$"Provider should be marked as cached. IsCached={provider1.IsCached}");

// Try to retrieve directly from the cache to isolate the issue
// Wait briefly for the async event queue
System.Threading.Thread.Sleep(200);
Comment thread
pmaytak marked this conversation as resolved.
Outdated

Comment thread
pmaytak marked this conversation as resolved.
Outdated
bool tryGetResult = factory.CryptoProviderCache.TryGetSignatureProvider(
signingKey, algorithm, providerType, false, out var fromCache);

// Debug: manually compute the cache keys to compare
string addKey = $"{provider1.Key.GetType()}-{provider1.Key.InternalId}-{provider1.Algorithm}-{provider1.GetType()}";
string getKey = $"{signingKey.GetType()}-{signingKey.InternalId}-{algorithm}-{providerType}";

Assert.True(tryGetResult,
$"TryGetSignatureProvider failed. " +
$"addKey='{addKey}', " +
$"getKey='{getKey}', " +
$"keysMatch={addKey == getKey}, " +
$"createCount={createCount}");

Assert.Same(provider1, fromCache);
}

[Fact]
public void CacheCustomProviders_WhenDisabled_CreatesNewProviderEachTime()
{
// Arrange
var signingKey = KeyingMaterial.DefaultSymmetricSigningCreds_256_Sha2.Key;
var algorithm = SecurityAlgorithms.HmacSha256;

int createCount = 0;
var customCrypto = new CountingCryptoProvider(algorithm, () =>
{
createCount++;
return new SymmetricSignatureProvider(signingKey, algorithm);
});

var factory = new CryptoProviderFactory(CryptoProviderCacheTests.CreateCacheForTesting())
{
CustomCryptoProvider = customCrypto,
CacheCustomProviders = false, // default
CacheSignatureProviders = true
};

// Act
var provider1 = factory.CreateForVerifying(signingKey, algorithm);
var provider2 = factory.CreateForVerifying(signingKey, algorithm);

// Assert — Create called each time, providers are not the same
Assert.Equal(2, createCount);
Assert.NotSame(provider1, provider2);
}

[Fact]
public void CacheCustomProviders_WhenCacheSignatureProvidersDisabled_DoesNotCache()
{
// Arrange — CacheCustomProviders = true but CacheSignatureProviders = false
// The master switch overrides the sub-switch.
var signingKey = KeyingMaterial.DefaultSymmetricSigningCreds_256_Sha2.Key;
var algorithm = SecurityAlgorithms.HmacSha256;

int createCount = 0;
var customCrypto = new CountingCryptoProvider(algorithm, () =>
{
createCount++;
return new SymmetricSignatureProvider(signingKey, algorithm);
});

var factory = new CryptoProviderFactory(CryptoProviderCacheTests.CreateCacheForTesting())
{
CustomCryptoProvider = customCrypto,
CacheCustomProviders = true,
CacheSignatureProviders = false // master switch off
};

// Act
var provider1 = factory.CreateForVerifying(signingKey, algorithm);
var provider2 = factory.CreateForVerifying(signingKey, algorithm);

// Assert — Create called each time despite CacheCustomProviders = true
Assert.Equal(2, createCount);
}

/// <summary>
/// A custom crypto provider that counts how many times Create is called.
/// </summary>
private class CountingCryptoProvider : ICryptoProvider
{
private readonly string _algorithm;
private readonly Func<SignatureProvider> _factory;

public CountingCryptoProvider(string algorithm, Func<SignatureProvider> factory)
{
_algorithm = algorithm;
_factory = factory;
}

public bool IsSupportedAlgorithm(string algorithm, params object[] args) =>
algorithm == _algorithm;

public object Create(string algorithm, params object[] args) =>
_factory();

public void Release(object cryptoInstance) { }
}

[Fact]
public void CacheCustomProviders_ConcurrentAccess_AllThreadsGetSameProvider()
{
// Arrange
var signingKey = new SymmetricSecurityKey(KeyingMaterial.DefaultSymmetricKeyBytes_256)
{
KeyId = "concurrent-test-kid"
};
var algorithm = SecurityAlgorithms.HmacSha256;

int createCount = 0;
var customCrypto = new CountingCryptoProvider(algorithm, () =>
{
Interlocked.Increment(ref createCount);
return new SymmetricSignatureProvider(signingKey, algorithm, false);
});

var factory = new CryptoProviderFactory(CryptoProviderCacheTests.CreateCacheForTesting())
{
CustomCryptoProvider = customCrypto,
CacheCustomProviders = true,
CacheSignatureProviders = true
};

// Warm the cache with an initial call so the provider type is known
// and the provider is in the cache.
var warmup = factory.CreateForVerifying(signingKey, algorithm);
Assert.True(warmup.IsCached, "Warmup provider should be cached.");
int warmupCreateCount = createCount;

// Act — launch many concurrent calls
int threadCount = 20;
var providers = new SignatureProvider[threadCount];
var barrier = new System.Threading.Barrier(threadCount);

Parallel.For(0, threadCount, i =>
{
barrier.SignalAndWait(); // synchronize start
providers[i] = factory.CreateForVerifying(signingKey, algorithm);
});

// Assert — all threads should get the same cached instance
for (int i = 0; i < threadCount; i++)
{
Assert.Same(warmup, providers[i]);
}

// Create should not have been called again after warmup (all cache hits)
Assert.Equal(warmupCreateCount, createCount);
}
}
}
#pragma warning restore CS3016 // Arrays as attribute arguments is not CLS-compliant