Skip to content

Commit 0a4d35c

Browse files
Implement public client application global cache
1 parent 4107f24 commit 0a4d35c

File tree

1 file changed

+179
-38
lines changed

1 file changed

+179
-38
lines changed

src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ActiveDirectoryAuthenticationProvider.cs

+179-38
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
// See the LICENSE file in the project root for more information.
44

55
using System;
6+
using System.Collections.Concurrent;
7+
using System.Collections.Generic;
68
using System.Linq;
79
using System.Security;
810
using System.Threading;
@@ -15,6 +17,9 @@ namespace Microsoft.Data.SqlClient
1517
/// <include file='../../../../../../doc/snippets/Microsoft.Data.SqlClient/ActiveDirectoryAuthenticationProvider.xml' path='docs/members[@name="ActiveDirectoryAuthenticationProvider"]/ActiveDirectoryAuthenticationProvider/*'/>
1618
public sealed class ActiveDirectoryAuthenticationProvider : SqlAuthenticationProvider
1719
{
20+
private static ConcurrentDictionary<PublicClientAppKey, IPublicClientApplication> s_pcaMap
21+
= new ConcurrentDictionary<PublicClientAppKey, IPublicClientApplication>();
22+
private static readonly string s_nativeClientRedirectUri = "https://login.microsoftonline.com/common/oauth2/nativeclient";
1823
private static readonly string s_defaultScopeSuffix = "/.default";
1924
private readonly string _type = typeof(ActiveDirectoryAuthenticationProvider).Name;
2025
private readonly SqlClientLogger _logger = new SqlClientLogger();
@@ -67,10 +72,10 @@ public override void BeforeUnload(SqlAuthenticationMethod authentication)
6772
}
6873

6974
#if NETSTANDARD
70-
private Func<object> parentActivityOrWindowFunc = null;
75+
private Func<object> _parentActivityOrWindowFunc = null;
7176

7277
/// <include file='../../../../../../doc/snippets/Microsoft.Data.SqlClient/ActiveDirectoryAuthenticationProvider.xml' path='docs/members[@name="ActiveDirectoryAuthenticationProvider"]/SetParentActivityOrWindowFunc/*'/>
73-
public void SetParentActivityOrWindowFunc(Func<object> parentActivityOrWindowFunc) => this.parentActivityOrWindowFunc = parentActivityOrWindowFunc;
78+
public void SetParentActivityOrWindowFunc(Func<object> parentActivityOrWindowFunc) => this._parentActivityOrWindowFunc = parentActivityOrWindowFunc;
7479
#endif
7580

7681
#if NETFRAMEWORK
@@ -108,51 +113,24 @@ public override Task<SqlAuthenticationToken> AcquireTokenAsync(SqlAuthentication
108113
*
109114
* https://docs.microsoft.com/en-us/azure/active-directory/develop/scenario-desktop-app-registration#redirect-uris
110115
*/
111-
string redirectURI = "https://login.microsoftonline.com/common/oauth2/nativeclient";
116+
string redirectUri = s_nativeClientRedirectUri;
112117

113118
#if NETCOREAPP
114119
if (parameters.AuthenticationMethod != SqlAuthenticationMethod.ActiveDirectoryDeviceCodeFlow)
115120
{
116-
redirectURI = "http://localhost";
117-
}
118-
#endif
119-
IPublicClientApplication app;
120-
121-
#if NETSTANDARD
122-
if (parentActivityOrWindowFunc != null)
123-
{
124-
app = PublicClientApplicationBuilder.Create(_applicationClientId)
125-
.WithAuthority(parameters.Authority)
126-
.WithClientName(Common.DbConnectionStringDefaults.ApplicationName)
127-
.WithClientVersion(Common.ADP.GetAssemblyVersion().ToString())
128-
.WithRedirectUri(redirectURI)
129-
.WithParentActivityOrWindow(parentActivityOrWindowFunc)
130-
.Build();
121+
redirectUri = "http://localhost";
131122
}
132123
#endif
124+
PublicClientAppKey pcaKey = new PublicClientAppKey(parameters.Authority, redirectUri, _applicationClientId
133125
#if NETFRAMEWORK
134-
if (_iWin32WindowFunc != null)
135-
{
136-
app = PublicClientApplicationBuilder.Create(_applicationClientId)
137-
.WithAuthority(parameters.Authority)
138-
.WithClientName(Common.DbConnectionStringDefaults.ApplicationName)
139-
.WithClientVersion(Common.ADP.GetAssemblyVersion().ToString())
140-
.WithRedirectUri(redirectURI)
141-
.WithParentActivityOrWindow(_iWin32WindowFunc)
142-
.Build();
143-
}
126+
, _iWin32WindowFunc
144127
#endif
145-
#if !NETCOREAPP
146-
else
128+
#if NETSTANDARD
129+
, _parentActivityOrWindowFunc
147130
#endif
148-
{
149-
app = PublicClientApplicationBuilder.Create(_applicationClientId)
150-
.WithAuthority(parameters.Authority)
151-
.WithClientName(Common.DbConnectionStringDefaults.ApplicationName)
152-
.WithClientVersion(Common.ADP.GetAssemblyVersion().ToString())
153-
.WithRedirectUri(redirectURI)
154-
.Build();
155-
}
131+
);
132+
133+
IPublicClientApplication app = GetPublicClientAppInstance(pcaKey);
156134

157135
if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryIntegrated)
158136
{
@@ -320,5 +298,168 @@ private class CustomWebUi : ICustomWebUi
320298
public Task<Uri> AcquireAuthorizationCodeAsync(Uri authorizationUri, Uri redirectUri, CancellationToken cancellationToken)
321299
=> _acquireAuthorizationCodeAsyncCallback.Invoke(authorizationUri, redirectUri, cancellationToken);
322300
}
301+
302+
private IPublicClientApplication GetPublicClientAppInstance(PublicClientAppKey publicClientAppKey)
303+
{
304+
IPublicClientApplication clientApplicationInstance;
305+
306+
if (s_pcaMap.ContainsKey(publicClientAppKey))
307+
{
308+
s_pcaMap.TryGetValue(publicClientAppKey, out clientApplicationInstance);
309+
}
310+
else
311+
{
312+
clientApplicationInstance = CreateClientAppInstance(publicClientAppKey);
313+
s_pcaMap.TryAdd(publicClientAppKey, clientApplicationInstance);
314+
}
315+
return clientApplicationInstance;
316+
}
317+
318+
private IPublicClientApplication CreateClientAppInstance(PublicClientAppKey publicClientAppKey)
319+
{
320+
IPublicClientApplication publicClientApplication;
321+
322+
#if NETSTANDARD
323+
if (_parentActivityOrWindowFunc != null)
324+
{
325+
publicClientApplication = PublicClientApplicationBuilder.Create(publicClientAppKey._applicationClientId)
326+
.WithAuthority(publicClientAppKey._authority)
327+
.WithClientName(Common.DbConnectionStringDefaults.ApplicationName)
328+
.WithClientVersion(Common.ADP.GetAssemblyVersion().ToString())
329+
.WithRedirectUri(publicClientAppKey._redirectUri)
330+
.WithParentActivityOrWindow(_parentActivityOrWindowFunc)
331+
.Build();
332+
}
333+
#endif
334+
#if NETFRAMEWORK
335+
if (_iWin32WindowFunc != null)
336+
{
337+
publicClientApplication = PublicClientApplicationBuilder.Create(publicClientAppKey._applicationClientId)
338+
.WithAuthority(publicClientAppKey._authority)
339+
.WithClientName(Common.DbConnectionStringDefaults.ApplicationName)
340+
.WithClientVersion(Common.ADP.GetAssemblyVersion().ToString())
341+
.WithRedirectUri(publicClientAppKey._redirectUri)
342+
.WithParentActivityOrWindow(_iWin32WindowFunc)
343+
.Build();
344+
}
345+
#endif
346+
#if !NETCOREAPP
347+
else
348+
#endif
349+
{
350+
publicClientApplication = PublicClientApplicationBuilder.Create(publicClientAppKey._applicationClientId)
351+
.WithAuthority(publicClientAppKey._authority)
352+
.WithClientName(Common.DbConnectionStringDefaults.ApplicationName)
353+
.WithClientVersion(Common.ADP.GetAssemblyVersion().ToString())
354+
.WithRedirectUri(publicClientAppKey._redirectUri)
355+
.Build();
356+
}
357+
358+
return publicClientApplication;
359+
}
360+
361+
internal class PublicClientAppKey
362+
{
363+
public readonly string _authority;
364+
public readonly string _redirectUri;
365+
public readonly string _applicationClientId;
366+
#if NETFRAMEWORK
367+
public readonly Func<System.Windows.Forms.IWin32Window> _iWin32WindowFunc;
368+
#endif
369+
#if NETSTANDARD
370+
public readonly Func<object> _parentActivityOrWindowFunc;
371+
#endif
372+
private int _hashValue;
373+
374+
public PublicClientAppKey(string authority, string redirectUri, string applicationClientId
375+
#if NETFRAMEWORK
376+
, Func<System.Windows.Forms.IWin32Window> iWin32WindowFunc
377+
#endif
378+
#if NETSTANDARD
379+
, Func<object> parentActivityOrWindowFunc
380+
#endif
381+
)
382+
{
383+
_authority = authority;
384+
_redirectUri = redirectUri;
385+
_applicationClientId = applicationClientId;
386+
#if NETFRAMEWORK
387+
_iWin32WindowFunc = iWin32WindowFunc;
388+
#endif
389+
#if NETSTANDARD
390+
_parentActivityOrWindowFunc = parentActivityOrWindowFunc;
391+
#endif
392+
}
393+
394+
public override bool Equals(object obj)
395+
{
396+
if (obj == null)
397+
{
398+
return false;
399+
}
400+
401+
PublicClientAppKey pcaKey = obj as PublicClientAppKey;
402+
return (string.CompareOrdinal(_authority, pcaKey._authority) == 0
403+
&& string.CompareOrdinal(_redirectUri, pcaKey._redirectUri) == 0
404+
&& string.CompareOrdinal(_applicationClientId, pcaKey._applicationClientId) == 0
405+
#if NETFRAMEWORK
406+
&& pcaKey._iWin32WindowFunc == _iWin32WindowFunc
407+
#endif
408+
#if NETSTANDARD
409+
&& pcaKey._parentActivityOrWindowFunc == _parentActivityOrWindowFunc
410+
#endif
411+
);
412+
}
413+
414+
public override int GetHashCode()
415+
{
416+
return _hashValue;
417+
}
418+
419+
private void CalculateHashCode()
420+
{
421+
_hashValue = base.GetHashCode();
422+
423+
if (_authority != null)
424+
{
425+
unchecked
426+
{
427+
_hashValue = _hashValue * 17 + _authority.GetHashCode();
428+
}
429+
}
430+
if (_redirectUri != null)
431+
{
432+
unchecked
433+
{
434+
_hashValue = _hashValue * 17 + _redirectUri.GetHashCode();
435+
}
436+
}
437+
if (_applicationClientId != null)
438+
{
439+
unchecked
440+
{
441+
_hashValue = _hashValue * 17 + _applicationClientId.GetHashCode();
442+
}
443+
}
444+
#if NETFRAMEWORK
445+
if (_iWin32WindowFunc != null)
446+
{
447+
unchecked
448+
{
449+
_hashValue = _hashValue * 17 + _iWin32WindowFunc.GetHashCode();
450+
}
451+
}
452+
#endif
453+
#if NETSTANDARD
454+
if (_parentActivityOrWindowFunc != null)
455+
{
456+
unchecked
457+
{
458+
_hashValue = _hashValue * 17 + _parentActivityOrWindowFunc.GetHashCode();
459+
}
460+
}
461+
#endif
462+
}
463+
}
323464
}
324465
}

0 commit comments

Comments
 (0)